def main(num_epoch, *, weight_lr_decay_milestones=None):
    """Build the dataset, AutoSTG network and RunManager from ``cfg``, then train.

    All hyper-parameters are read from the global ``cfg`` object, except the
    weight learning-rate decay milestones (see below). The run is set up for a
    single device: the dataset is built with ``num_gpus=1`` and the run
    manager with ``device_ids=[0]``.

    Args:
        num_epoch: Passed unchanged to ``run_manager.train``; presumably the
            number of training epochs.
        weight_lr_decay_milestones: Optional iterable of milestones for the
            weight learning-rate decay, keyword-only. Defaults to
            ``[20, 40, 60, 80]``, the schedule this script previously
            hard-coded in place of ``cfg.trainer.weight_lr_decay_milestones``.
            Pass ``cfg.trainer.weight_lr_decay_milestones`` to use the
            configured schedule instead.
    """
    if weight_lr_decay_milestones is None:
        # Keep the previously hard-coded schedule as the default. It overrides
        # cfg.trainer.weight_lr_decay_milestones.
        # NOTE(review): confirm whether the config value should be the default.
        weight_lr_decay_milestones = [20, 40, 60, 80]
    else:
        # Copy into a fresh list so the caller's object is never shared or
        # mutated downstream.
        weight_lr_decay_milestones = list(weight_lr_decay_milestones)

    system_init()

    # load data
    dataset = TrafficDataset(
        path=cfg.data.path,
        train_prop=cfg.data.train_prop,
        valid_prop=cfg.data.valid_prop,
        num_sensors=cfg.data.num_sensors,
        in_length=cfg.data.in_length,
        out_length=cfg.data.out_length,
        batch_size_per_gpu=cfg.data.batch_size_per_gpu,
        num_gpus=1,
    )

    # The first entry of node_hiddens / edge_hiddens is taken from the dataset:
    # dim 1 of node_fts and dim 2 of adj_mats. These are presumably the
    # per-node and per-edge input feature sizes (TODO confirm against
    # TrafficDataset).
    net = AutoSTG(
        in_length=cfg.data.in_length,
        out_length=cfg.data.out_length,
        node_hiddens=[dataset.node_fts.shape[1]] + cfg.model.node_hiddens,
        edge_hiddens=[dataset.adj_mats.shape[2]] + cfg.model.edge_hiddens,
        in_channels=cfg.data.in_channels,
        out_channels=cfg.data.out_channels,
        hidden_channels=cfg.model.hidden_channels,
        skip_channels=cfg.model.skip_channels,
        end_channels=cfg.model.end_channels,
        layer_names=cfg.model.layer_names,
        num_mixed_ops=cfg.model.num_mixed_ops,
        candidate_op_profiles=cfg.model.candidate_op_profiles,
    )

    run_manager = RunManager(
        name=cfg.model.name,
        net=net,
        dataset=dataset,
        arch_lr=cfg.trainer.arch_lr,
        arch_lr_decay_milestones=cfg.trainer.arch_lr_decay_milestones,
        arch_lr_decay_ratio=cfg.trainer.arch_lr_decay_ratio,
        arch_decay=cfg.trainer.arch_decay,
        arch_clip_gradient=cfg.trainer.arch_clip_gradient,
        weight_lr=cfg.trainer.weight_lr,
        weight_lr_decay_milestones=weight_lr_decay_milestones,
        weight_lr_decay_ratio=cfg.trainer.weight_lr_decay_ratio,
        weight_decay=cfg.trainer.weight_decay,
        weight_clip_gradient=cfg.trainer.weight_clip_gradient,
        num_search_iterations=cfg.trainer.num_search_iterations,
        num_search_arch_samples=cfg.trainer.num_search_arch_samples,
        num_train_iterations=cfg.trainer.num_train_iterations,
        criterion=cfg.trainer.criterion,
        metric_names=cfg.trainer.metric_names,
        metric_indexes=cfg.trainer.metric_indexes,
        print_frequency=cfg.trainer.print_frequency,
        device_ids=[0],
    )

    # Load state saved under mode='train'. This is presumably the searched
    # architecture or a checkpoint (confirm in RunManager.load). Then reset
    # the run records and initialize before training.
    run_manager.load(mode='train')
    run_manager.clear_records()
    run_manager.initialize()

    # NOTE(review): reaches into RunManager's private `_net`; the local `net`
    # may be the same object, but that is not visible here.
    print('# of params', run_manager._net.num_weight_parameters())
    run_manager.train(num_epoch)