コード例 #1
0
def main(num_epoch, *, weight_lr_decay_milestones=None, device_ids=None):
    """Build the dataset, the AutoSTG network and its RunManager, then train.

    Almost every hyper-parameter is read from the module-level ``cfg`` object
    (``cfg.data``, ``cfg.model``, ``cfg.trainer``).

    Args:
        num_epoch: Passed straight through to ``RunManager.train``.
        weight_lr_decay_milestones: Milestones handed to ``RunManager`` for the
            weight learning-rate schedule. Defaults to ``[20, 40, 60, 80]``,
            which this script has historically used *instead of*
            ``cfg.trainer.weight_lr_decay_milestones``. Pass the config value
            explicitly to honour it.
        device_ids: Device ids handed to ``RunManager``. Defaults to ``[0]``.
            ``TrafficDataset`` receives ``num_gpus=len(device_ids)`` so the two
            settings cannot drift apart.

    Raises:
        ValueError: If ``device_ids`` is given but empty.
    """
    # Build fresh lists on every call rather than using mutable defaults.
    if weight_lr_decay_milestones is None:
        weight_lr_decay_milestones = [20, 40, 60, 80]
    else:
        weight_lr_decay_milestones = list(weight_lr_decay_milestones)

    if device_ids is None:
        device_ids = [0]
    else:
        device_ids = list(device_ids)
    if not device_ids:
        raise ValueError('device_ids must contain at least one device id')

    system_init()

    # load data
    dataset = TrafficDataset(path=cfg.data.path,
                             train_prop=cfg.data.train_prop,
                             valid_prop=cfg.data.valid_prop,
                             num_sensors=cfg.data.num_sensors,
                             in_length=cfg.data.in_length,
                             out_length=cfg.data.out_length,
                             batch_size_per_gpu=cfg.data.batch_size_per_gpu,
                             # Derived from device_ids so it always matches
                             # the devices given to RunManager below.
                             num_gpus=len(device_ids))

    # The first entry of each hidden-size list is taken from the loaded data.
    # Presumably node_fts.shape[1] is the per-node feature width and
    # adj_mats.shape[2] the number of adjacency channels -- TODO confirm
    # against TrafficDataset.
    net = AutoSTG(in_length=cfg.data.in_length,
                  out_length=cfg.data.out_length,
                  node_hiddens=[dataset.node_fts.shape[1]]
                  + cfg.model.node_hiddens,
                  edge_hiddens=[dataset.adj_mats.shape[2]]
                  + cfg.model.edge_hiddens,
                  in_channels=cfg.data.in_channels,
                  out_channels=cfg.data.out_channels,
                  hidden_channels=cfg.model.hidden_channels,
                  skip_channels=cfg.model.skip_channels,
                  end_channels=cfg.model.end_channels,
                  layer_names=cfg.model.layer_names,
                  num_mixed_ops=cfg.model.num_mixed_ops,
                  candidate_op_profiles=cfg.model.candidate_op_profiles)

    run_manager = RunManager(
        name=cfg.model.name,
        net=net,
        dataset=dataset,
        arch_lr=cfg.trainer.arch_lr,
        arch_lr_decay_milestones=cfg.trainer.arch_lr_decay_milestones,
        arch_lr_decay_ratio=cfg.trainer.arch_lr_decay_ratio,
        arch_decay=cfg.trainer.arch_decay,
        arch_clip_gradient=cfg.trainer.arch_clip_gradient,
        weight_lr=cfg.trainer.weight_lr,
        weight_lr_decay_milestones=weight_lr_decay_milestones,
        weight_lr_decay_ratio=cfg.trainer.weight_lr_decay_ratio,
        weight_decay=cfg.trainer.weight_decay,
        weight_clip_gradient=cfg.trainer.weight_clip_gradient,
        num_search_iterations=cfg.trainer.num_search_iterations,
        num_search_arch_samples=cfg.trainer.num_search_arch_samples,
        num_train_iterations=cfg.trainer.num_train_iterations,
        criterion=cfg.trainer.criterion,
        metric_names=cfg.trainer.metric_names,
        metric_indexes=cfg.trainer.metric_indexes,
        print_frequency=cfg.trainer.print_frequency,
        device_ids=device_ids)

    run_manager.load(mode='train')
    run_manager.clear_records()
    run_manager.initialize()
    # NOTE(review): reads RunManager's private ``_net`` attribute; it may wrap
    # ``net`` (e.g. for multi-device use), so ``net`` is not used directly.
    print('# of params', run_manager._net.num_weight_parameters())
    run_manager.train(num_epoch)