Code example #1
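Searching and training an AutoSTG traffic-forecasting model: a TrafficDataset and an AutoSTG network are built from a global cfg object, handed to a RunManager that owns both the architecture and the weight optimizers, and trained for num_epoch epochs.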
def main(num_epoch):
    system_init()

    # load data
    dataset = TrafficDataset(path=cfg.data.path,
                             train_prop=cfg.data.train_prop,
                             valid_prop=cfg.data.valid_prop,
                             num_sensors=cfg.data.num_sensors,
                             in_length=cfg.data.in_length,
                             out_length=cfg.data.out_length,
                             batch_size_per_gpu=cfg.data.batch_size_per_gpu,
                             num_gpus=1)

    net = AutoSTG(in_length=cfg.data.in_length,
                  out_length=cfg.data.out_length,
                  node_hiddens=[dataset.node_fts.shape[1]] + cfg.model.node_hiddens,
                  edge_hiddens=[dataset.adj_mats.shape[2]] + cfg.model.edge_hiddens,
                  in_channels=cfg.data.in_channels,
                  out_channels=cfg.data.out_channels,
                  hidden_channels=cfg.model.hidden_channels,
                  skip_channels=cfg.model.skip_channels,
                  end_channels=cfg.model.end_channels,
                  layer_names=cfg.model.layer_names,
                  num_mixed_ops=cfg.model.num_mixed_ops,
                  candidate_op_profiles=cfg.model.candidate_op_profiles)

    run_manager = RunManager(
        name=cfg.model.name,
        net=net,
        dataset=dataset,
        arch_lr=cfg.trainer.arch_lr,
        arch_lr_decay_milestones=cfg.trainer.arch_lr_decay_milestones,
        arch_lr_decay_ratio=cfg.trainer.arch_lr_decay_ratio,
        arch_decay=cfg.trainer.arch_decay,
        arch_clip_gradient=cfg.trainer.arch_clip_gradient,
        weight_lr=cfg.trainer.weight_lr,
        weight_lr_decay_milestones=[20, 40, 60, 80],  # overrides cfg.trainer.weight_lr_decay_milestones
        weight_lr_decay_ratio=cfg.trainer.weight_lr_decay_ratio,
        weight_decay=cfg.trainer.weight_decay,
        weight_clip_gradient=cfg.trainer.weight_clip_gradient,
        num_search_iterations=cfg.trainer.num_search_iterations,
        num_search_arch_samples=cfg.trainer.num_search_arch_samples,
        num_train_iterations=cfg.trainer.num_train_iterations,
        criterion=cfg.trainer.criterion,
        metric_names=cfg.trainer.metric_names,
        metric_indexes=cfg.trainer.metric_indexes,
        print_frequency=cfg.trainer.print_frequency,
        device_ids=[0])

    run_manager.load(mode='train')
    run_manager.clear_records()
    run_manager.initialize()
    print('# of params', run_manager._net.num_weight_parameters())
    run_manager.train(num_epoch)
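Everything above is driven by a module-level cfg; a minimal sketch of what such a config object could look like, built here with types.SimpleNamespace (the structure is inferred from the attribute accesses in the snippet, and every value is illustrative, not from the AutoSTG project):

from types import SimpleNamespace

# hypothetical stand-in for the project's cfg; values are illustrative only
cfg = SimpleNamespace(
    data=SimpleNamespace(path='data/traffic.h5', train_prop=0.7, valid_prop=0.1,
                         num_sensors=207, in_length=12, out_length=12,
                         in_channels=2, out_channels=1, batch_size_per_gpu=32),
    model=SimpleNamespace(name='AutoSTG', node_hiddens=[16], edge_hiddens=[16],
                          hidden_channels=32, skip_channels=64, end_channels=128,
                          layer_names=['stcell'] * 4, num_mixed_ops=4,
                          candidate_op_profiles=[]),
    trainer=SimpleNamespace(arch_lr=3e-4, arch_lr_decay_milestones=[50],
                            arch_lr_decay_ratio=0.1, arch_decay=1e-3,
                            arch_clip_gradient=5.0, weight_lr=1e-2,
                            weight_lr_decay_ratio=0.1, weight_decay=1e-4,
                            weight_clip_gradient=5.0, num_search_iterations=200,
                            num_search_arch_samples=10, num_train_iterations=500,
                            criterion='masked_mae', metric_names=['mae', 'rmse'],
                            metric_indexes=[2, 5, 11], print_frequency=50))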
Code example #2
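Building a CIFAR network from a remote ProxylessNAS config, wrapping it in a RunManager, and restoring weights either from a training checkpoint or from an initial weight file.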
# build net from args
from search.models.normal_nets import *

net_config_url = "https://hanlab.mit.edu/files/proxylessNAS/proxyless_cifar.config"
# net_config_url2 = 'https://hanlab.mit.edu/files/proxylessNAS/proxyless_cpu.config'
net = pyramid_base(net_config=net_config_url,
                   n_classes=run_config.data_provider.n_classes,
                   bn_param=(args.bn_momentum, args.bn_eps),
                   dropout_rate=args.dropout)
# net2 = proxyless_base(net_config=net_config_url2, n_classes=run_config.data_provider.n_classes,
#                       bn_param=(args.bn_momentum, args.bn_eps), dropout_rate=args.dropout)

# build run manager
run_manager = RunManager(args.path,
                         net,
                         run_config,
                         measure_latency=args.latency)
# run_manager.save_config(print_info=True)

# load checkpoints
init_path = '%s/init' % args.path
if args.resume:
    run_manager.load_model()
    if args.train and run_manager.best_acc == 0:
        loss, acc1, acc5 = run_manager.validate(is_test=False,
                                                return_top5=True)
        run_manager.best_acc = acc1
elif os.path.isfile(init_path):
    if torch.cuda.is_available():
        checkpoint = torch.load(init_path)
    else:
        checkpoint = torch.load(init_path, map_location='cpu')
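The excerpt stops before the loaded checkpoint is applied; code example #8 below finishes the identical pattern. A minimal sketch of the missing tail, assuming the snippet's variables and the same RunManager API:

# continuation sketch (uses `checkpoint` and `run_manager` from above;
# mirrors the tail of code example #8)
if 'state_dict' in checkpoint:
    checkpoint = checkpoint['state_dict']  # unwrap a full training checkpoint
run_manager.net.module.load_state_dict(checkpoint)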
Code example #3
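A hyper-parameter sweep over VAE configurations: RunBuilder expands a parameter grid into runs, SLFDatasetUnsampled supplies train and validation data, and a RunManager tracks each run.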
# the excerpt opens mid-call; the grid's leading keys (device, batch_size)
# are assumed from their use in the loop below, and their values are illustrative
from collections import OrderedDict

params = OrderedDict(device=['cuda'],
                     batch_size=[64],
                     network=['vae64'],
                     shuffle=[True],
                     num_workers=[5],
                     loss='vae',
                     model='vae64_1')

train_set = SLFDatasetUnsampled(root_dir=os.path.join(ROOT, 'slf_mat'),
                                csv_file=os.path.join(ROOT, 'details.csv'),
                                total_data=500000)

validation_set = SLFDatasetUnsampled(
    root_dir=os.path.join(VALIDATION_SET_PATH, 'slf_mat'),
    csv_file=os.path.join(VALIDATION_SET_PATH, 'details.csv'),
    total_data=5000)

m = RunManager()
vae_loss = VAELoss()
mse_loss = torch.nn.MSELoss()
for run in RunBuilder.get_runs(params):
    device = torch.device(run.device)
    network = networks[run.network]().to(run.device)
    loader = torch.utils.data.DataLoader(train_set,
                                         batch_size=run.batch_size,
                                         shuffle=run.shuffle,
                                         num_workers=run.num_workers)
    validation_loader = torch.utils.data.DataLoader(
        validation_set,
        batch_size=run.batch_size,
        shuffle=run.shuffle,
        num_workers=run.num_workers)
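The excerpt ends after the loaders are built. A sketch of how such a run is usually driven, borrowing the begin_run/begin_epoch/track_loss/end_epoch/end_run API from code example #5 (the optimizer, epoch count, VAE forward signature, and bare-tensor batches are all assumptions, not from the source):

optimizer = torch.optim.Adam(network.parameters(), lr=1e-3)  # lr assumed
m.begin_run(run, network, loader)
for epoch in range(10):  # epoch count not in the excerpt
    m.begin_epoch(epoch + 1)
    for batch in loader:
        x = batch.to(device)
        recon, mu, logvar = network(x)            # assumed VAE outputs
        loss = vae_loss(recon, x, mu, logvar)     # assumed VAELoss signature
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        m.track_loss(loss)
    m.end_epoch()
m.end_run()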
Code example #4
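Retraining a pruned network (CAP): the pruned net and its unpruned counterpart are built per model type, a RunManager drives training, and the original weights are restored from model_best.pth.tar.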
File: train.py Project: paperscodes/CAP
    # the excerpt opens mid-branch: the resnet50 condition and the two lines
    # above net_origin are reconstructed from the parallel branches below
    net_origin = None  # assumed default so the assert further down fails cleanly
    if args.model == "resnet50":
        assert args.dataset == 'imagenet', 'resnet50 only supports imagenet dataset'
        net = ResNet_ImageNet(depth=50, num_classes=run_config.data_provider.n_classes, cfg=eval(args.cfg))
        if args.base_path is not None:
            weight_path = args.base_path + '/checkpoint/model_best.pth.tar'
            net_origin = nn.DataParallel(ResNet_ImageNet(depth=50, num_classes=run_config.data_provider.n_classes))
    elif args.model == "mobilenetv2":
        assert args.dataset == 'imagenet', 'mobilenetv2 only supports imagenet dataset'
        net = MobileNetV2(num_classes=run_config.data_provider.n_classes, cfg=eval(args.cfg))
        if args.base_path is not None:
            weight_path = args.base_path + '/checkpoint/model_best.pth.tar'
            net_origin = nn.DataParallel(MobileNetV2(num_classes=run_config.data_provider.n_classes))
    elif args.model == "mobilenet":
        assert args.dataset == 'imagenet', 'mobilenet only supports imagenet dataset'
        net = MobileNet(num_classes=run_config.data_provider.n_classes, cfg=eval(args.cfg))
        if args.base_path is not None:
            weight_path = args.base_path + '/checkpoint/model_best.pth.tar'
            net_origin = nn.DataParallel(MobileNet(num_classes=run_config.data_provider.n_classes))

    # build run manager
    run_manager = RunManager(args.path, net, run_config)
    if args.local_rank == 0:
        run_manager.save_config(print_info=True)

    # load checkpoints
    weight_path = None  # default guards the elif below against an unbound name
    if args.base_path is not None:
        weight_path = args.base_path + '/checkpoint/model_best.pth.tar'
    if args.resume:
        run_manager.load_model()
        if args.train and run_manager.best_acc == 0:
            loss, acc1, acc5 = run_manager.validate(is_test=True, return_top5=True)
            run_manager.best_acc = acc1
    elif weight_path is not None and os.path.isfile(weight_path):
        assert net_origin is not None, "original network is None"
        net_origin.load_state_dict(torch.load(weight_path)['state_dict'])
        net_origin = net_origin.module
Code example #5
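A complete training routine for TransICD: a RunManager records per-epoch loss while the model is optimized with binary cross-entropy (on LDAM outputs when available), after which the train, dev, and test splits are scored with evaluate() and compute_scores().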
File: trainer.py Project: AIMedLab/TransICD
def train(model, train_set, dev_set, test_set, hyper_params, batch_size,
          device):
    train_loader = DataLoader(train_set,
                              batch_size=batch_size,
                              shuffle=True,
                              num_workers=1)
    m = RunManager()
    optimizer = optim.AdamW(model.parameters(), lr=hyper_params.learning_rate)

    logging.info("Training Started...")
    m.begin_run(hyper_params, model, train_loader)
    for epoch in range(hyper_params.num_epoch):
        m.begin_epoch(epoch + 1)
        model.train()
        for batch in train_loader:
            texts = batch['text']
            lens = batch['length']
            targets = batch['codes']

            texts = texts.to(device)
            targets = targets.to(device)
            outputs, ldam_outputs, _ = model(texts, targets)

            if ldam_outputs is not None:
                loss = F.binary_cross_entropy_with_logits(
                    ldam_outputs, targets)
            else:
                loss = F.binary_cross_entropy_with_logits(outputs, targets)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            m.track_loss(loss)
            # m.track_num_correct(preds, affinities)

        m.end_epoch()
    m.end_run()
    hype = '_'.join([f'{k}_{v}' for k, v in hyper_params._asdict().items()])
    m.save(f'../results/train_results_{hype}')
    logging.info("Training finished.\n")

    # Evaluate on the training set
    train_loader = DataLoader(train_set,
                              batch_size=batch_size,
                              shuffle=True,
                              num_workers=1)
    probabs, targets, _, _ = evaluate(model,
                                      train_loader,
                                      device,
                                      dtset='train')
    compute_scores(probabs, targets, hyper_params, dtset='train')

    # Validation
    dev_loader = DataLoader(dev_set,
                            batch_size=batch_size,
                            shuffle=True,
                            num_workers=1)
    probabs, targets, _, _ = evaluate(model, dev_loader, device, dtset='dev')
    compute_scores(probabs, targets, hyper_params, dtset='dev')

    # Evaluate on the test set
    test_loader = DataLoader(test_set,
                             batch_size=batch_size,
                             shuffle=True,
                             num_workers=1)
    probabs, targets, full_hadm_ids, full_attn_weights = evaluate(model,
                                                                  test_loader,
                                                                  device,
                                                                  dtset='test')
    compute_scores(probabs,
                   targets,
                   hyper_params,
                   dtset='test',
                   full_hadm_ids=full_hadm_ids,
                   full_attn_weights=full_attn_weights)
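The hyper_params._asdict() call above implies a namedtuple. A minimal sketch of invoking train() under that assumption; every concrete value here is illustrative, not from the source:

from collections import namedtuple

# hypothetical container; only the fields the function reads are required
HyperParams = namedtuple('HyperParams', ['learning_rate', 'num_epoch'])
hyper_params = HyperParams(learning_rate=1e-3, num_epoch=10)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
train(model, train_set, dev_set, test_set, hyper_params,
      batch_size=16, device=device)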
Code example #6
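A minimal game bootstrap: a RunManager and a Character are stored on a global flags module, the level is loaded through the manager's display_manager, and the game is started.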
File: game.py Project: TheoTomalty/colossus
import flags
from run_manager import RunManager
from character import Character
import level

#################################
flags.run_manager = RunManager()#
flags.character = Character()   #
#################################

level.load_level(flags.run_manager.display_manager.map)
flags.run_manager.start_game()
Code example #7
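The retrain phase of a Gumbel AutoDeeplab pipeline: configs are resumed either from a previous retrain or from the search phase, a NewGumbelAutoDeeplab network is rebuilt from the saved actual_path and cell_genotypes, and a RunManager performs training (the quoted-out tail shows the corresponding test phase).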
def main(args):
    assert torch.cuda.is_available(), 'CUDA is not available'
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.deterministic = True

    if args.retrain_resume:
        config_file_path = os.path.join(args.resume_file, 'retrain.config')
        assert os.path.exists(
            config_file_path
        ), 'cannot find config_file {:} from the last retrain phase'.format(
            config_file_path)
        with open(config_file_path, 'r') as f:
            config_dict = json.load(f)
        configs_resume(args, config_dict, 'retrain')
        # get EXP_time in last_retrain for flag
        EXP_time_last_retrain = config_dict['path'].split('/')[-1]
        EXP_time = time_for_file()
        args.path = os.path.join(
            args.path, args.exp_name,
            EXP_time + '-resume-{:}'.format(EXP_time_last_retrain))
        torch.set_num_threads(args.workers)
        set_manual_seed(
            args.random_seed)  # from the last retrain phase or search phase.
        os.makedirs(args.path, exist_ok=True)
        create_exp_dir(args.path, scripts_to_save=glob.glob('./*/*.py'))
        save_configs(args.__dict__, args.path, 'retrain')
        logger = prepare_logger(args)
        logger.log(
            '=> loading configs {:} from the last retrain phase.'.format(
                config_file_path),
            mode='info')
        if args.search_space == 'autodeeplab':
            conv_candidates = autodeeplab
        elif args.search_space == 'proxyless':
            conv_candidates = proxyless
        elif args.search_space == 'my_search_space':
            conv_candidates = my_search_space
        else:
            raise ValueError('search space {:} is not supported'.format(
                args.search_space))
    else:
        # resume partial configs setting and arch_checkpoint from the search phase by default.
        config_file_path = os.path.join(args.checkpoint_file, 'search.config')
        assert os.path.exists(
            config_file_path
        ), 'cannot find config_file {:} from the search phase'.format(
            config_file_path)
        with open(config_file_path, 'r') as f:
            config_dict = json.load(f)
        args.random_seed = config_dict['random_seed']
        # get EXP_time in search phase, for flag
        EXP_time_search = config_dict['path'].split('/')[-1]
        EXP_time = time_for_file()
        args.path = os.path.join(
            args.path, args.exp_name,
            EXP_time + '-resume-{:}'.format(EXP_time_search))
        torch.set_num_threads(args.workers)
        set_manual_seed(
            args.random_seed)  # from the last retrain phase or search phase.
        os.makedirs(args.path, exist_ok=True)
        create_exp_dir(args.path, scripts_to_save=glob.glob('./*/*.py'))
        save_configs(args.__dict__, args.path, 'retrain')
        logger = prepare_logger(args)
        logger.log(
            '=> starting retrain from the search phase config {:}.'.format(
                config_file_path),
            mode='info')

        # optimizer params
        if args.weight_optimizer_type == 'SGD':
            weight_optimizer_params = {
                'momentum': args.momentum,
                'nesterov': args.nesterov,
                'weight_decay': args.weight_decay,
            }
        elif args.weight_optimizer_type == 'RMSprop':
            weight_optimizer_params = {
                'momentum': args.momentum,
                'weight_decay': args.weight_decay,
            }
        else:
            weight_optimizer_params = None
        # scheduler params
        if args.scheduler == 'cosine':
            scheduler_params = {'T_max': args.T_max, 'eta_min': args.eta_min}
        elif args.scheduler == 'multistep':
            scheduler_params = {
                'milestones': args.milestones,
                'gammas': args.gammas
            }
        elif args.scheduler == 'exponential':
            scheduler_params = {'gamma': args.gamma}
        elif args.scheduler == 'linear':
            scheduler_params = {'min_lr': args.min_lr}
        else:
            scheduler_params = None
        # criterion params
        if args.criterion == 'SmoothSoftmax':
            criterion_params = {'label_smooth': args.label_smoothing}
        else:
            criterion_params = None

        args.optimizer_config = {
            'optimizer_type': args.weight_optimizer_type,
            'optimizer_params': weight_optimizer_params,
            'scheduler': args.scheduler,
            'scheduler_params': scheduler_params,
            'criterion': args.criterion,
            'criterion_params': criterion_params,
            'init_lr': args.init_lr,
            'epochs': args.epochs,
            'class_num': args.nb_classes,
        }
        if args.search_space == 'autodeeplab':
            conv_candidates = autodeeplab
        elif args.search_space == 'proxyless':
            conv_candidates = proxyless
        elif args.search_space == 'counter':
            conv_candidates = counter
        elif args.search_space == 'my_search_space':
            conv_candidates = my_search_space
        else:
            raise ValueError('search_space : {:} is not supported'.format(
                args.search_space))

        # related to entropy constraint loss
        if args.reg_loss_type == 'add#linear':
            args.reg_loss_params = {'lambda': args.reg_loss_lambda}
        elif args.reg_loss_type == 'mul#log':
            args.reg_loss_params = {
                'alpha': args.reg_loss_alpha,
                'beta': args.reg_loss_beta
            }
        else:
            args.reg_loss_params = None

    # create run_config
    run_config = RunConfig(**args.__dict__)

    #if args.open_test == False: # retrain and validate
    if args.open_vis:  # only open_vis in re-train phase, rather than both re-train and test.
        vis = visdomer(args.port,
                       args.server,
                       args.exp_name,
                       args.compare_phase,
                       args.elements,
                       init_params=None)
    else:
        vis = None
    if args.retrain_resume:
        logger.log(
            '=> Loading checkpoint from {:} of the last retrain phase'.format(
                args.resume_file),
            mode='info')
        # checkpoint_file from the last retrain phase.
        checkpoint_path = os.path.join(
            args.resume_file, 'checkpoints',
            'seed-{:}-retrain.pth'.format(args.random_seed))
        assert os.path.exists(
            checkpoint_path), 'cannot find retrain checkpoint file {:}'.format(
                checkpoint_path)
        checkpoint = torch.load(checkpoint_path)
        actual_path, cell_genotypes = checkpoint['actual_path'], checkpoint[
            'cell_genotypes']
        args.actual_path = actual_path
        args.cell_genotypes = cell_genotypes
        normal_network = NewGumbelAutoDeeplab(args.nb_layers,
                                              args.filter_multiplier,
                                              args.block_multiplier,
                                              args.steps,
                                              args.nb_classes,
                                              actual_path,
                                              cell_genotypes,
                                              args.search_space,
                                              affine=True)
        retrain_run_manager = RunManager(args.path,
                                         normal_network,
                                         logger,
                                         run_config,
                                         vis,
                                         out_log=True)
        normal_network.load_state_dict(checkpoint['state_dict'])
        display_all_families_information(args, 'retrain', retrain_run_manager,
                                         logger)
        retrain_run_manager.optimizer.load_state_dict(
            checkpoint['weight_optimizer'])
        retrain_run_manager.scheduler.load_state_dict(checkpoint['scheduler'])
        retrain_run_manager.monitor_metric = checkpoint['best_monitor'][0]
        retrain_run_manager.best_monitor = checkpoint['best_monitor'][1]
        retrain_run_manager.start_epoch = checkpoint['start_epoch']
        logger.log(
            '=> loaded checkpoint file {:} from the last retrain phase, starts with {:}-th epoch'
            .format(checkpoint_path, checkpoint['start_epoch']),
            mode='info')
    else:
        # todo from the search phase, read the last arch_checkpoint, rather than the best one.
        arch_checkpoint_path = os.path.join(
            args.checkpoint_file, 'checkpoints',
            'seed-{:}-arch.pth'.format(args.random_seed))
        assert os.path.exists(
            arch_checkpoint_path
        ), 'cannot find arch_checkpoint file {:} from search phase'.format(
            arch_checkpoint_path)
        checkpoint = torch.load(arch_checkpoint_path)
        actual_path, cell_genotypes = checkpoint['actual_path'], checkpoint[
            'cell_genotypes']
        new_genotypes = []
        for _index, genotype in cell_genotypes:
            xlist = []
            for edge_genotype in genotype:
                for (node_str, select_index) in edge_genotype:
                    xlist.append((node_str, conv_candidates[select_index]))
            new_genotypes.append((_index, xlist))
        log_str = 'Obtained actual_path and cell_genotypes:\n' \
                  'Actual_path: {:}\n' \
                  'Genotype:\n'.format(actual_path)
        for _index, genotype in new_genotypes:
            log_str += 'index: {:} arch: {:}\n'.format(_index, genotype)
        logger.log(log_str, mode='info')
        args.actual_path = actual_path
        args.cell_genotypes = cell_genotypes
        normal_network = NewGumbelAutoDeeplab(args.nb_layers,
                                              args.filter_multiplier,
                                              args.block_multiplier,
                                              args.steps,
                                              args.nb_classes,
                                              actual_path,
                                              cell_genotypes,
                                              args.search_space,
                                              affine=True)
        retrain_run_manager = RunManager(args.path,
                                         normal_network,
                                         logger,
                                         run_config,
                                         vis,
                                         out_log=True)
        #normal_network.load_state_dict(checkpoint['state_dict'])
        display_all_families_information(args, 'retrain', retrain_run_manager,
                                         logger)
        logger.log(
            '=> Construct NewGumbelAutoDeeplab according to the last-arch obtained from search phase',
            mode='info')
    # perform train and validation in train() method
    retrain_run_manager.train()
    '''
    else: # test phase
        checkpoint_path = os.path.join(args.resume_file, 'checkpoints', 'seed-{:}-retrain-best.pth'.format(args.random_seed))
        assert os.path.exists(checkpoint_path), 'cannot find best checkpoint {:} from the retrain phase'.format(checkpoint_path)
        checkpoint = torch.load(checkpoint_path)
        actual_path, cell_genotypes = checkpoint['actual_path'], checkpoint['cell_genotypes']
        normal_network = NewGumbelAutoDeeplab(args.nb_layers, args.filter_multiplier, args.block_multiplier,
                                              args.steps, args.nb_classes, actual_path, cell_genotypes, args.search_space, affine=True)
        normal_network.load_state_dict(checkpoint['state_dict'])
        test_manager = RunManager(args.path, normal_network, logger, run_config, vis=None, out_log=True)
        display_all_families_information(args, 'retrain', test_manager, logger)

        # save testing configs
        save_configs(args.__dict__, args.path, 'test')
        test_manager.validate(epoch=None, is_test=True, use_train_mode=False)
    '''
    logger.close()
Code example #8
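The ProxylessNAS evaluation entry point: run config and network are loaded from files or built from args, a RunManager restores a checkpoint or downloads pretrained weights, then trains and writes validation and test metrics to a JSON summary.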
def main(args, myargs):

    torch.manual_seed(args.manual_seed)
    torch.cuda.manual_seed_all(args.manual_seed)
    np.random.seed(args.manual_seed)

    # os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu

    os.makedirs(args.path, exist_ok=True)

    # prepare run config
    run_config_path = '%s/run.config' % args.path
    if os.path.isfile(run_config_path):
        # load run config from file
        with open(run_config_path, 'r') as f:
            run_config = json.load(f)
        run_config = ImagenetRunConfig(**run_config)
        if args.valid_size:
            run_config.valid_size = args.valid_size
    else:
        # build run config from args
        args.lr_schedule_param = None
        args.opt_param = {
            'momentum': args.momentum,
            'nesterov': not args.no_nesterov,
        }
        if args.no_decay_keys == 'None':
            args.no_decay_keys = None
        run_config = ImagenetRunConfig(**args.__dict__)
    print('Run config:')
    for k, v in run_config.config.items():
        print('\t%s: %s' % (k, v))

    # prepare network
    net_config_path = '%s/net.config' % args.path
    if os.path.isfile(net_config_path):
        # load net from file
        from models import get_net_by_name
        with open(net_config_path, 'r') as f:
            net_config = json.load(f)
        net = get_net_by_name(net_config['name']).build_from_config(net_config)
    else:
        # build net from args
        if 'proxyless' in args.net:
            from models.normal_nets.proxyless_nets import proxyless_base
            net_config_url = 'https://hanlab.mit.edu/files/proxylessNAS/%s.config' % args.net
            net = proxyless_base(
                net_config=net_config_url,
                n_classes=run_config.data_provider.n_classes,
                bn_param=(args.bn_momentum, args.bn_eps),
                dropout_rate=args.dropout,
            )
        else:
            raise ValueError('do not support: %s' % args.net)

    # build run manager
    run_manager = RunManager(args.path,
                             net,
                             run_config,
                             measure_latency=args.latency)
    run_manager.save_config(print_info=True)

    # load checkpoints
    init_path = '%s/init' % args.path
    if args.resume:
        run_manager.load_model()
        if args.train and run_manager.best_acc == 0:
            loss, acc1, acc5 = run_manager.validate(is_test=False,
                                                    return_top5=True)
            run_manager.best_acc = acc1
    elif os.path.isfile(init_path):
        if torch.cuda.is_available():
            checkpoint = torch.load(init_path)
        else:
            checkpoint = torch.load(init_path, map_location='cpu')
        if 'state_dict' in checkpoint:
            checkpoint = checkpoint['state_dict']
        run_manager.net.module.load_state_dict(checkpoint)
    elif 'proxyless' in args.net and not args.train:
        from utils.latency_estimator import download_url
        pretrained_weight_url = 'https://hanlab.mit.edu/files/proxylessNAS/%s.pth' % args.net
        print('Load pretrained weights from %s' % pretrained_weight_url)
        init_path = download_url(pretrained_weight_url)
        init = torch.load(init_path, map_location='cpu')
        net.load_state_dict(init['state_dict'])
    else:
        print('Random initialization')

    # train
    if args.train:
        print('Start training')
        run_manager.train(print_top5=True)
        run_manager.save_model()

    output_dict = {}
    # validate
    if run_config.valid_size:
        print('Test on validation set')
        loss, acc1, acc5 = run_manager.validate(is_test=False,
                                                return_top5=True)
        log = 'valid_loss: %f\t valid_acc1: %f\t valid_acc5: %f' % (loss, acc1,
                                                                    acc5)
        run_manager.write_log(log, prefix='valid')
        output_dict = {
            **output_dict, 'valid_loss': '%f' % loss,
            'valid_acc1': '%f' % acc1,
            'valid_acc5': '%f' % acc5,
            'valid_size': run_config.valid_size
        }

    # test
    print('Test on test set')
    loss, acc1, acc5 = run_manager.validate(is_test=True, return_top5=True)
    log = 'test_loss: %f\t test_acc1: %f\t test_acc5: %f' % (loss, acc1, acc5)
    run_manager.write_log(log, prefix='test')
    output_dict = {
        **output_dict, 'test_loss': '%f' % loss,
        'test_acc1': '%f' % acc1,
        'test_acc5': '%f' % acc5
    }
    with open('%s/output' % args.path, 'w') as f:
        json.dump(output_dict, f, indent=4)
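The run ends by dumping output_dict next to the checkpoints; a short sketch of reading that summary back afterwards (the directory layout is taken from the line above, the path itself is a placeholder):

import json
import os

run_dir = '/path/to/run'  # wherever args.path pointed
with open(os.path.join(run_dir, 'output')) as f:
    results = json.load(f)
print(results.get('test_acc1'), results.get('test_acc5'))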
Code example #9
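The search phase of an Efficient-AutoDeeplab variant: optimizer, scheduler, criterion, and regularization settings are assembled from args into RunConfig and ArchSearchConfig, and a RunManager trains a CounterMBConvNet super-network.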
def main(args):

    assert torch.cuda.is_available(), 'CUDA is not available'
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    torch.set_num_threads(args.workers)
    set_manual_seed(args.random_seed)
    #print_experiment_environment()
    EXP_time = time_for_file()
    args.path = os.path.join(args.path, args.exp_name, EXP_time)
    os.makedirs(args.path, exist_ok=True)
    create_exp_dir(args.path, scripts_to_save='../Efficient_AutoDeeplab')

    # weight optimizer config, related to network_weight_optimizer, scheduler, and criterion
    if args.weight_optimizer_type == 'SGD':
        weight_optimizer_params = {
            'momentum': args.momentum,
            'nesterov': args.nesterov,
            'weight_decay': args.weight_decay,
        }
    elif args.weight_optimizer_type == 'RMSprop':
        weight_optimizer_params = {
            'momentum': args.momentum,
            'weight_decay': args.weight_decay,
        }
    else:
        weight_optimizer_params = None
    if args.scheduler == 'cosine':
        scheduler_params = {'T_max': args.T_max, 'eta_min': args.eta_min}
    elif args.scheduler == 'multistep':
        scheduler_params = {
            'milestones': args.milestones,
            'gammas': args.gammas
        }
    elif args.scheduler == 'exponential':
        scheduler_params = {'gamma': args.gamma}
    elif args.scheduler == 'linear':
        scheduler_params = {'min_lr': args.min_lr}
    else:
        scheduler_params = None
    if args.criterion == 'SmoothSoftmax':
        criterion_params = {'label_smooth': args.label_smoothing}
    else:
        criterion_params = None
    # weight_optimizer_config, used in run_manager to get weight_optimizer, scheduler, and criterion.
    args.optimizer_config = {
        'optimizer_type': args.weight_optimizer_type,
        'optimizer_params': weight_optimizer_params,
        'scheduler': args.scheduler,
        'scheduler_params': scheduler_params,
        'criterion': args.criterion,
        'criterion_params': criterion_params,
        'init_lr': args.init_lr,
        'warmup_epoch': args.warmup_epochs,
        'epochs': args.epochs,
        'class_num': args.nb_classes,
    }
    # TODO need modification, not need in counter_network
    args.conv_candidates = [
        '3x3_MBConv3',
        '3x3_MBConv6',
        '5x5_MBConv3',
        '5x5_MBConv6',
        '7x7_MBConv3',
        '7x7_MBConv6',
        'Zero',  #'Identity'
    ]
    run_config = RunConfig(**args.__dict__)
    # arch_optimizer_config
    if args.arch_optimizer_type == 'adam':
        args.arch_optimizer_params = {
            'betas': (args.arch_adam_beta1, args.arch_adam_beta2),
            'eps': args.arch_adam_eps
        }
    else:
        args.arch_optimizer_params = None

    # related to hardware constraint
    # TODO: get rid of
    if args.reg_loss_type == 'add#linear':
        args.reg_loss_params = {'lambda': args.reg_loss_lambda}
    elif args.reg_loss_type == 'mul#log':
        args.reg_loss_params = {
            'alpha': args.reg_loss_alpha,
            'beta': args.reg_loss_beta
        }
    else:
        args.reg_loss_params = None

    arch_search_config = ArchSearchConfig(**args.__dict__)
    # perform config save, for run_configs and arch_search_configs
    save_configs(run_config.config, arch_search_config.config, args.path,
                 'search')
    logger = prepare_logger(args)
    if args.open_vis:
        vis = visdomer(args.port,
                       args.server,
                       args.exp_name,
                       args.compare_phase,
                       args.elements,
                       init_params=None)
    else:
        vis = None
    '''
    super_network = GumbelAutoDeepLab(
        args.filter_multiplier, args.block_multiplier, args.steps,
        args.nb_classes, args.nb_layers, args.bn_momentum, args.bn_eps, args.conv_candidates, logger
    )
    '''
    super_network = CounterMBConvNet(2, search_space=args.search_space)
    train_manager = RunManager(args.path,
                               super_network,
                               logger,
                               run_config,
                               vis=vis,
                               out_log=True)
    # train search phase
    train_manager.train()
    logger.close()