示例#1
0
文件: print_geno.py 项目: 2BH/NAS_K49
def main():
    """Load trained search weights and print the network's genotype.

    Reads the module-level globals ``args`` (seed, gpu, init_channels,
    input_channels, layers), ``num_classes`` and ``log_path``, builds
    ``Network`` on the selected GPU, loads ``<log_path>/weights.pt`` into it
    and prints ``model.genotype()``.
    """
    np.random.seed(args.seed)
    torch.cuda.set_device(args.gpu)
    cudnn.benchmark = True
    torch.manual_seed(args.seed)
    cudnn.enabled = True
    torch.cuda.manual_seed(args.seed)
    # print() does not perform logging-style "%s" interpolation, so format
    # explicitly (the tuple wrapper keeps this safe for any args object).
    print("args = %s" % (args,))

    # The criterion is handed to the network, so it is built once, up front.
    criterion = nn.CrossEntropyLoss().cuda()

    model = Network(args.init_channels, args.input_channels, num_classes,
                    args.layers, criterion)
    model = model.cuda()

    model.load_state_dict(torch.load(log_path + '/weights.pt'))

    print(model.genotype())
示例#2
0
def start(args):
    """Run the architecture search on one of the supported SBM datasets.

    Seeds the RNGs and builds the super-network and its criterion on the GPU.
    It then splits the training set into two parts by ``args.train_portion``.
    Each epoch calls ``train`` with both parts and an ``Architect``, then
    ``infer`` on the validation and test sets. Every ``args.save_freq``
    epochs the current genotypes are printed, logged and appended to
    ``args.save_result``.

    Args:
        args: options namespace (presumably from argparse). Reads seed,
            data_name, layers, nodes, feature_dim, data_type, readout,
            batch_size, workers, train_portion, learning_rate,
            learning_rate_min, momentum, weight_decay, epochs, save_freq and
            save_result.

    Raises:
        ValueError: if ``args.data_name`` is not SBM_PATTERN or SBM_CLUSTER.
    """
    if not torch.cuda.is_available():
        logging.info('no gpu device available')
        sys.exit(1)

    # seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    cudnn.benchmark = True
    cudnn.enabled = True
    logging.info("args = %s", args)

    # (input dimension, number of classes) for each supported dataset.
    # Unknown names fail here instead of hitting an unbound variable later.
    dataset_dims = {
        'SBM_PATTERN': (3, 2),
        'SBM_CLUSTER': (7, 6),
    }
    if args.data_name not in dataset_dims:
        raise ValueError(
            f"unsupported data_name {args.data_name!r}; "
            f"expected one of {sorted(dataset_dims)}")
    in_dim, num_classes = dataset_dims[args.data_name]
    dataset = LoadData(args.data_name)
    print(f"input dimension: {in_dim}, number classes: {num_classes}")

    criterion = MyCriterion(num_classes)
    criterion = criterion.cuda()

    model = Network(args.layers, args.nodes, in_dim, args.feature_dim, num_classes, criterion, args.data_type, args.readout)
    model = model.cuda()
    logging.info("param size = %fMB", count_parameters_in_MB(model))

    train_data, val_data, test_data = dataset.train, dataset.val, dataset.test

    # The first `split` training samples feed train_queue and the rest feed
    # valid_queue. Both are passed to train() together with the Architect.
    num_train = len(train_data)
    indices = list(range(num_train))
    split = int(np.floor(args.train_portion * num_train))
    print(f"train set full size : {num_train}; split train set size : {split}")
    train_queue = torch.utils.data.DataLoader(
        train_data, batch_size=args.batch_size,
        sampler=torch.utils.data.sampler.SubsetRandomSampler(indices[:split]),
        pin_memory=True,
        num_workers=args.workers,
        collate_fn=dataset.collate)

    valid_queue = torch.utils.data.DataLoader(
        train_data, batch_size=args.batch_size,
        sampler=torch.utils.data.sampler.SubsetRandomSampler(indices[split:num_train]),
        pin_memory=True,
        num_workers=args.workers,
        collate_fn=dataset.collate)

    true_valid_queue = torch.utils.data.DataLoader(
        val_data, batch_size=args.batch_size,
        pin_memory=True,
        num_workers=args.workers,
        collate_fn=dataset.collate)

    test_queue = torch.utils.data.DataLoader(
        test_data, batch_size=args.batch_size,
        pin_memory=True,
        num_workers=args.workers,
        collate_fn=dataset.collate)

    optimizer = torch.optim.SGD(model.parameters(), args.learning_rate, momentum=args.momentum, weight_decay=args.weight_decay)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, float(args.epochs), eta_min=args.learning_rate_min)
    architect = Architect(model, args)

    # Visdom visualisation is disabled; train() receives None.
    viz = None
    # `with` guarantees the result file is closed even if training raises.
    with open(args.save_result, "w") as save_file:
        for epoch in range(args.epochs):
            # get_last_lr() is the LR in effect for this epoch. get_lr() must
            # not be called outside step(): on PyTorch >= 1.4 it re-applies
            # the recursive cosine factor and returns a wrong value.
            lr = scheduler.get_last_lr()[0]
            logging.info('[LR]\t%f', lr)

            if epoch % args.save_freq == 0:
                genotypes = model.show_genotypes()
                print(genotypes)
                save_file.write(f"Epoch : {epoch}\n{genotypes}\n")
                # Flush so partial results survive an interrupted run.
                save_file.flush()
                for i in range(args.layers):
                    logging.info('layer = %d', i)
                    genotype = model.show_genotype(i)
                    logging.info('genotype = %s', genotype)

            # training
            train(train_queue, valid_queue, model, architect, criterion, optimizer, lr, epoch, viz)
            # true validation
            infer(true_valid_queue, model, criterion, stage='validating')
            # testing
            infer(test_queue, model, criterion, stage=' testing  ')

            # Advance the schedule after this epoch's optimizer updates
            # (PyTorch >= 1.1 order). Stepping first skips the initial LR.
            scheduler.step()
示例#3
0
def main():
    """Run the architecture search configured by ``cfg`` on the ASV toy data.

    Parses CLI args into the global ``cfg`` and builds the super-network with
    an Adam optimizer over its non-architecture parameters. It optionally
    resumes from ``<args.load_path>/Model/checkpoint_best.pth``. Search
    epochs then run from ``begin_epoch`` to ``cfg.TRAIN.END_EPOCH``, with
    evaluation and checkpointing every ``cfg.VAL_FREQ`` epochs.

    Raises:
        FileNotFoundError: if ``args.load_path`` exists but has no
            ``Model/checkpoint_best.pth``.
    """
    args = parse_args()
    update_config(cfg, args)

    # cudnn related setting
    cudnn.benchmark = cfg.CUDNN.BENCHMARK
    torch.backends.cudnn.deterministic = cfg.CUDNN.DETERMINISTIC
    torch.backends.cudnn.enabled = cfg.CUDNN.ENABLED

    # Set the random seed manually for reproducibility.
    np.random.seed(cfg.SEED)
    torch.manual_seed(cfg.SEED)
    torch.cuda.manual_seed_all(cfg.SEED)

    # Loss
    criterion = CrossEntropyLoss(cfg.MODEL.NUM_CLASSES).cuda()

    # model and optimizer
    print(f"Definining network with {cfg.MODEL.LAYERS} layers...")
    model = Network(cfg.MODEL.INIT_CHANNELS, cfg.MODEL.NUM_CLASSES, cfg.MODEL.LAYERS, criterion, primitives_2,
                    drop_path_prob=cfg.TRAIN.DROPPATH_PROB)
    model = model.cuda()

    # Weight params: everything except the architecture parameters. Those
    # are presumably updated by the Architect built below.
    arch_param_ids = {id(p) for p in model.arch_parameters()}
    weight_params = [p for p in model.parameters() if id(p) not in arch_param_ids]

    # Optimizer
    optimizer = optim.Adam(
        weight_params,
        lr=cfg.TRAIN.LR
    )

    # resume && make log dir and logger
    if args.load_path and os.path.exists(args.load_path):
        # NOTE(review): resumes from the *best* checkpoint, not the latest.
        checkpoint_file = os.path.join(args.load_path, 'Model', 'checkpoint_best.pth')
        # Explicit check: `assert` is stripped when Python runs with -O.
        if not os.path.exists(checkpoint_file):
            raise FileNotFoundError(checkpoint_file)
        checkpoint = torch.load(checkpoint_file)

        # load checkpoint; 'epoch' stores the next epoch to run (epoch + 1).
        begin_epoch = checkpoint['epoch']
        # The scheduler constructor performs one step(). Seeding it with the
        # last completed epoch yields the LR for begin_epoch.
        last_epoch = begin_epoch - 1
        model.load_state_dict(checkpoint['state_dict'])
        best_acc1 = checkpoint['best_acc1']
        optimizer.load_state_dict(checkpoint['optimizer'])
        args.path_helper = checkpoint['path_helper']

        logger = create_logger(args.path_helper['log_path'])
        logger.info("=> loaded checkpoint '{}'".format(checkpoint_file))
    else:
        exp_name = args.cfg.split('/')[-1].split('.')[0]
        args.path_helper = set_path('logs_search', exp_name)
        logger = create_logger(args.path_helper['log_path'])
        begin_epoch = cfg.TRAIN.BEGIN_EPOCH
        best_acc1 = 0.0
        last_epoch = -1

    logger.info(args)
    logger.info(cfg)

    # copy model file
    this_dir = os.path.dirname(__file__)
    shutil.copy2(
        os.path.join(this_dir, 'models', cfg.MODEL.NAME + '.py'),
        args.path_helper['ckpt_path'])

    # Datasets and dataloaders

    # The toy dataset is downloaded with 10 items for each partition. Remove the sample_size parameters to use the full toy dataset
    asv_train, asv_dev, asv_eval = asv_toys(sample_size=10)

    train_dataset = asv_train
    val_dataset = asv_dev
    train_loader = torch.utils.data.DataLoader(
        dataset=train_dataset,
        batch_size=cfg.TRAIN.BATCH_SIZE,
        num_workers=cfg.DATASET.NUM_WORKERS,
        pin_memory=True,
        shuffle=True,
        drop_last=True,
    )
    print(f'search.py: Train loader of {len(train_loader)} batches')
    print(f'Tot train set: {len(train_dataset)}')
    val_loader = torch.utils.data.DataLoader(
        dataset=val_dataset,
        batch_size=cfg.TRAIN.BATCH_SIZE,
        num_workers=cfg.DATASET.NUM_WORKERS,
        pin_memory=True,
        shuffle=True,
        drop_last=True,
    )
    print(f'search.py: Val loader of {len(val_loader)} batches')
    print(f'Tot val set {len(val_dataset)}')
    test_dataset = asv_eval
    # Evaluation loader: no shuffling and no dropped tail batch. With
    # drop_last=True, samples are silently excluded from the metric. A split
    # smaller than the batch size (e.g. the 10-item toy data) then yields no
    # batches at all.
    test_loader = torch.utils.data.DataLoader(
        dataset=test_dataset,
        batch_size=cfg.TRAIN.BATCH_SIZE,
        num_workers=cfg.DATASET.NUM_WORKERS,
        pin_memory=True,
        shuffle=False,
        drop_last=False,
    )

    # training setting
    writer_dict = {
        'writer': SummaryWriter(args.path_helper['log_path']),
        'train_global_steps': begin_epoch * len(train_loader),
        'valid_global_steps': begin_epoch // cfg.VAL_FREQ,
    }

    # training loop
    architect = Architect(model, cfg)
    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, cfg.TRAIN.END_EPOCH, cfg.TRAIN.LR_MIN,
        last_epoch=last_epoch
    )

    for epoch in tqdm(range(begin_epoch, cfg.TRAIN.END_EPOCH), desc='search progress'):
        model.train()

        genotype = model.genotype()
        logger.info('genotype = %s', genotype)

        if cfg.TRAIN.DROPPATH_PROB != 0:
            # Linear ramp from 0 to DROPPATH_PROB over the run. max() avoids
            # a ZeroDivisionError when END_EPOCH == 1.
            model.drop_path_prob = cfg.TRAIN.DROPPATH_PROB * epoch / max(cfg.TRAIN.END_EPOCH - 1, 1)

        train(cfg, model, optimizer, train_loader, val_loader, criterion, architect, epoch, writer_dict)

        if epoch % cfg.VAL_FREQ == 0:
            # NOTE(review): model selection uses test_loader (asv_eval), while
            # val_loader is handed to train(). Confirm this is intended:
            # picking the best checkpoint on the evaluation split leaks it
            # into model selection.
            acc = validate_identification(cfg, model, test_loader, criterion)

            # remember best acc@1 and save checkpoint
            is_best = acc > best_acc1
            best_acc1 = max(acc, best_acc1)

            # Re-derive the genotype so it matches the architecture
            # parameters saved with it (train() may have changed them).
            genotype = model.genotype()

            # save
            logger.info('=> saving checkpoint to {}'.format(args.path_helper['ckpt_path']))
            save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'best_acc1': best_acc1,
                'optimizer': optimizer.state_dict(),
                'arch': model.arch_parameters(),
                'genotype': genotype,
                'path_helper': args.path_helper
            }, is_best, args.path_helper['ckpt_path'], 'checkpoint_{}.pth'.format(epoch))

        # Advance the cosine schedule once per completed epoch. The deprecated
        # step(epoch) form left epoch k running at the LR for epoch k - 1.
        lr_scheduler.step()
示例#4
0
def main():
    """Search per-op channel masks under a lookup-table latency target.

    All weights live in full-size state dicts saved as
    ``<args.save>/searched_model_XX.pth.tar``, together with the current
    ``mc_mask_dddict``.

    1. Build a super-network from the channel counts returned for
       ``is_max=True`` and save its weights as epoch 00.
    2. Precompute a cosine-annealed weight learning rate for every epoch.
    3. For each epoch:

       - Rebuild a network sized by the current masks.
       - Copy the masked channel slices of the saved full-size weights into
         it.
       - Train with ``train_wo_arch`` for the first 10 epochs, and with
         ``train_w_arch`` plus an Adam optimizer over the architecture
         parameters afterwards.
       - Write the trained slices back into the full-size state dict.
       - From epoch 10 on, adjust the channel counts with
         ``fit_mc_num_by_latency`` toward ``args.target_lat`` and rebuild the
         masks.

    Reads the module-level globals ``args``, ``mc_mask_dddict`` and
    ``lat_lookup_key_dddict``. Mutates ``mc_mask_dddict`` and ``args.T``.
    Exits with status 1 when no CUDA device is available.
    """
    if not torch.cuda.is_available():
        logging.info('No GPU device available')
        sys.exit(1)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    cudnn.enabled = True
    cudnn.benchmark = True
    logging.info("args = %s", args)

    # NOTE(review): pickle.load can execute arbitrary code; args.lookup_path
    # must point at a trusted latency lookup table.
    with open(args.lookup_path, 'rb') as f:
        lat_lookup = pickle.load(f)

    # Super-network at the is_max=True channel counts. It is only used to
    # create/save the initial full-size weights and to derive the LR
    # schedule, then deleted.
    mc_maxnum_dddict = get_mc_num_dddict(mc_mask_dddict, is_max=True)
    model = Network(args.num_classes, mc_maxnum_dddict, lat_lookup)
    model = torch.nn.DataParallel(model).cuda()
    logging.info("param size = %fMB", count_parameters_in_MB(model))

    # save initial model (epoch 00). Each epoch below loads file `epoch`
    # and saves file `epoch + 1`.
    model_path = os.path.join(args.save, 'searched_model_00.pth.tar')
    torch.save(
        {
            'state_dict': model.state_dict(),
            'mc_mask_dddict': mc_mask_dddict,
        }, model_path)

    # get lr list: the network and its optimizer are rebuilt every epoch, so
    # the cosine schedule is precomputed here with a throwaway
    # optimizer/scheduler pair.
    lr_list = []
    optimizer_w = torch.optim.SGD(model.module.weight_parameters(),
                                  lr=args.w_lr,
                                  momentum=args.w_mom,
                                  weight_decay=args.w_wd)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer_w, float(args.epochs))
    for _ in range(args.epochs):
        # get_last_lr() is the LR set by the latest step(). Calling get_lr()
        # outside step() (PyTorch >= 1.4) re-applies the recursive cosine
        # factor to the already-updated LR and produces wrong values.
        lr = scheduler.get_last_lr()[0]
        lr_list.append(lr)
        scheduler.step()
    del model
    del optimizer_w
    del scheduler

    criterion = nn.CrossEntropyLoss()
    criterion = criterion.cuda()

    normalize = transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)
    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(brightness=0.4,
                               contrast=0.4,
                               saturation=0.4,
                               hue=0.2),
        transforms.ToTensor(),
        normalize,
    ])
    val_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize,
    ])

    train_queue = torch.utils.data.DataLoader(
        ImageList(root=args.img_root,
                  list_path=args.train_list,
                  transform=train_transform),
        batch_size=args.batch_size,
        shuffle=True,
        pin_memory=True,
        num_workers=args.workers)

    # Shuffled too: val_queue is also passed to train_w_arch below.
    val_queue = torch.utils.data.DataLoader(
        ImageList(root=args.img_root,
                  list_path=args.val_list,
                  transform=val_transform),
        batch_size=args.batch_size,
        shuffle=True,
        pin_memory=True,
        num_workers=args.workers)

    for epoch in range(args.epochs):
        # Rebuild a network sized by the current masks (they may have changed
        # at the end of the previous epoch).
        mc_num_dddict = get_mc_num_dddict(mc_mask_dddict)
        model = Network(args.num_classes, mc_num_dddict, lat_lookup)
        model = torch.nn.DataParallel(model).cuda()
        model.module.set_temperature(args.T)

        # load model: full-size weights saved by the previous epoch (or the
        # initial save above).
        model_path = os.path.join(args.save,
                                  'searched_model_{:02}.pth.tar'.format(epoch))
        state_dict = torch.load(model_path)['state_dict']
        # Tensors outside the searchable ops ('m_ops') are copied unchanged.
        # NOTE(review): the exec'd attribute path assumes no key component is
        # a bare integer (e.g. 'layers.0'), which would be a syntax error.
        for key in state_dict:
            if 'm_ops' not in key:
                exec('model.{}.data = state_dict[key].data'.format(key))
        # For each searchable op, copy only the channels its mask selects:
        # dim 0 of the inverted_bottleneck and depth_conv weights, dim 1 of
        # the point_linear weight.
        for stage in mc_mask_dddict:
            for block in mc_mask_dddict[stage]:
                for op_idx in mc_mask_dddict[stage][block]:
                    index = torch.nonzero(
                        mc_mask_dddict[stage][block][op_idx]).view(-1)
                    index = index.cuda()
                    iw = 'model.module.{}.{}.m_ops[{}].inverted_bottleneck.conv.weight.data'.format(
                        stage, block, op_idx)
                    iw_key = 'module.{}.{}.m_ops.{}.inverted_bottleneck.conv.weight'.format(
                        stage, block, op_idx)
                    exec(
                        iw +
                        ' = torch.index_select(state_dict[iw_key], 0, index).data'
                    )
                    dw = 'model.module.{}.{}.m_ops[{}].depth_conv.conv.weight.data'.format(
                        stage, block, op_idx)
                    dw_key = 'module.{}.{}.m_ops.{}.depth_conv.conv.weight'.format(
                        stage, block, op_idx)
                    exec(
                        dw +
                        ' = torch.index_select(state_dict[dw_key], 0, index).data'
                    )
                    pw = 'model.module.{}.{}.m_ops[{}].point_linear.conv.weight.data'.format(
                        stage, block, op_idx)
                    pw_key = 'module.{}.{}.m_ops.{}.point_linear.conv.weight'.format(
                        stage, block, op_idx)
                    exec(
                        pw +
                        ' = torch.index_select(state_dict[pw_key], 1, index).data'
                    )
                    if op_idx >= 4:
                        # Ops with index >= 4 also carry squeeze_excite weights:
                        # conv_reduce is sliced on dim 1 (its bias is copied
                        # whole), conv_expand weight and bias on dim 0.
                        se_cr_w = 'model.module.{}.{}.m_ops[{}].squeeze_excite.conv_reduce.weight.data'.format(
                            stage, block, op_idx)
                        se_cr_w_key = 'module.{}.{}.m_ops.{}.squeeze_excite.conv_reduce.weight'.format(
                            stage, block, op_idx)
                        exec(
                            se_cr_w +
                            ' = torch.index_select(state_dict[se_cr_w_key], 1, index).data'
                        )
                        se_cr_b = 'model.module.{}.{}.m_ops[{}].squeeze_excite.conv_reduce.bias.data'.format(
                            stage, block, op_idx)
                        se_cr_b_key = 'module.{}.{}.m_ops.{}.squeeze_excite.conv_reduce.bias'.format(
                            stage, block, op_idx)
                        exec(se_cr_b + ' = state_dict[se_cr_b_key].data')
                        se_ce_w = 'model.module.{}.{}.m_ops[{}].squeeze_excite.conv_expand.weight.data'.format(
                            stage, block, op_idx)
                        se_ce_w_key = 'module.{}.{}.m_ops.{}.squeeze_excite.conv_expand.weight'.format(
                            stage, block, op_idx)
                        exec(
                            se_ce_w +
                            ' = torch.index_select(state_dict[se_ce_w_key], 0, index).data'
                        )
                        se_ce_b = 'model.module.{}.{}.m_ops[{}].squeeze_excite.conv_expand.bias.data'.format(
                            stage, block, op_idx)
                        se_ce_b_key = 'module.{}.{}.m_ops.{}.squeeze_excite.conv_expand.bias'.format(
                            stage, block, op_idx)
                        exec(
                            se_ce_b +
                            ' = torch.index_select(state_dict[se_ce_b_key], 0, index).data'
                        )
        # NOTE(review): `index` is only bound if at least one op was visited.
        del index

        lr = lr_list[epoch]
        # New optimizers every epoch: the parameters belong to the freshly
        # built network, so no optimizer state carries over.
        optimizer_w = torch.optim.SGD(model.module.weight_parameters(),
                                      lr=lr,
                                      momentum=args.w_mom,
                                      weight_decay=args.w_wd)
        optimizer_a = torch.optim.Adam(model.module.arch_parameters(),
                                       lr=args.a_lr,
                                       betas=(args.a_beta1, args.a_beta2),
                                       weight_decay=args.a_wd)
        logging.info('Epoch: %d lr: %e T: %e', epoch, lr, args.T)

        # training: train_wo_arch for epochs 0-9. From epoch 10 on,
        # train_w_arch also gets val_queue and the arch optimizer, and the
        # temperature decays.
        epoch_start = time.time()
        if epoch < 10:
            train_acc = train_wo_arch(train_queue, model, criterion,
                                      optimizer_w)
        else:
            train_acc = train_w_arch(train_queue, val_queue, model, criterion,
                                     optimizer_w, optimizer_a)
            args.T *= args.T_decay
        # logging arch parameters: log-alphas are exponentiated and betas are
        # softmax-normalised before printing.
        logging.info('The current arch parameters are:')
        for param in model.module.log_alphas_parameters():
            param = np.exp(param.detach().cpu().numpy())
            logging.info(' '.join(['{:.6f}'.format(p) for p in param]))
        for param in model.module.betas_parameters():
            param = F.softmax(param.detach().cpu(), dim=-1)
            param = param.numpy()
            logging.info(' '.join(['{:.6f}'.format(p) for p in param]))
        logging.info('Train_acc %f', train_acc)
        epoch_duration = time.time() - epoch_start
        logging.info('Epoch time: %ds', epoch_duration)

        # validation for last 5 epochs (epochs - 5 .. epochs - 1). `<=` is
        # required; `< 5` only covers the last four.
        if args.epochs - epoch <= 5:
            val_acc = validate(val_queue, model, criterion)
            logging.info('Val_acc %f', val_acc)

        # update state_dict: write trained weights back into the full-size
        # tensors. Non-'m_ops' tensors are replaced wholesale; op tensors are
        # written only at the masked channel positions.
        state_dict_from_model = model.state_dict()
        for key in state_dict:
            if 'm_ops' not in key:
                state_dict[key].data = state_dict_from_model[key].data
        for stage in mc_mask_dddict:
            for block in mc_mask_dddict[stage]:
                for op_idx in mc_mask_dddict[stage][block]:
                    index = torch.nonzero(
                        mc_mask_dddict[stage][block][op_idx]).view(-1)
                    index = index.cuda()
                    iw_key = 'module.{}.{}.m_ops.{}.inverted_bottleneck.conv.weight'.format(
                        stage, block, op_idx)
                    state_dict[iw_key].data[
                        index, :, :, :] = state_dict_from_model[iw_key]
                    dw_key = 'module.{}.{}.m_ops.{}.depth_conv.conv.weight'.format(
                        stage, block, op_idx)
                    state_dict[dw_key].data[
                        index, :, :, :] = state_dict_from_model[dw_key]
                    pw_key = 'module.{}.{}.m_ops.{}.point_linear.conv.weight'.format(
                        stage, block, op_idx)
                    state_dict[
                        pw_key].data[:, index, :, :] = state_dict_from_model[
                            pw_key]
                    if op_idx >= 4:
                        se_cr_w_key = 'module.{}.{}.m_ops.{}.squeeze_excite.conv_reduce.weight'.format(
                            stage, block, op_idx)
                        state_dict[
                            se_cr_w_key].data[:,
                                              index, :, :] = state_dict_from_model[
                                                  se_cr_w_key]
                        se_cr_b_key = 'module.{}.{}.m_ops.{}.squeeze_excite.conv_reduce.bias'.format(
                            stage, block, op_idx)
                        state_dict[
                            se_cr_b_key].data[:] = state_dict_from_model[
                                se_cr_b_key]
                        se_ce_w_key = 'module.{}.{}.m_ops.{}.squeeze_excite.conv_expand.weight'.format(
                            stage, block, op_idx)
                        state_dict[se_ce_w_key].data[
                            index, :, :, :] = state_dict_from_model[
                                se_ce_w_key]
                        se_ce_b_key = 'module.{}.{}.m_ops.{}.squeeze_excite.conv_expand.bias'.format(
                            stage, block, op_idx)
                        state_dict[se_ce_b_key].data[
                            index] = state_dict_from_model[se_ce_b_key]
        del state_dict_from_model, index

        # shrink and expand (only once architecture training has started)
        if epoch >= 10:
            logging.info('Now shrinking or expanding the arch')
            op_weights, depth_weights = get_op_and_depth_weights(model)
            parsed_arch = parse_architecture(op_weights, depth_weights)
            mc_num_dddict = get_mc_num_dddict(mc_mask_dddict)
            before_lat = get_lookup_latency(parsed_arch, mc_num_dddict,
                                            lat_lookup_key_dddict, lat_lookup)
            logging.info(
                'Before, the current lat: {:.4f}, the target lat: {:.4f}'.
                format(before_lat, args.target_lat))

            if before_lat > args.target_lat:
                logging.info('Shrinking......')
                # One sign=-1 pass over stages 1-6, then sign=1 passes over
                # stages start..6 for start = 2..6.
                stages = ['stage{}'.format(x) for x in range(1, 7)]
                mc_num_dddict, after_lat = fit_mc_num_by_latency(
                    parsed_arch,
                    mc_num_dddict,
                    mc_maxnum_dddict,
                    lat_lookup_key_dddict,
                    lat_lookup,
                    args.target_lat,
                    stages,
                    sign=-1)
                for start in range(2, 7):
                    stages = ['stage{}'.format(x) for x in range(start, 7)]
                    mc_num_dddict, after_lat = fit_mc_num_by_latency(
                        parsed_arch,
                        mc_num_dddict,
                        mc_maxnum_dddict,
                        lat_lookup_key_dddict,
                        lat_lookup,
                        args.target_lat,
                        stages,
                        sign=1)
            elif before_lat < args.target_lat:
                logging.info('Expanding......')
                # sign=1 over stages 1-6, then over stages start..6 for
                # start = 2..6.
                stages = ['stage{}'.format(x) for x in range(1, 7)]
                mc_num_dddict, after_lat = fit_mc_num_by_latency(
                    parsed_arch,
                    mc_num_dddict,
                    mc_maxnum_dddict,
                    lat_lookup_key_dddict,
                    lat_lookup,
                    args.target_lat,
                    stages,
                    sign=1)
                for start in range(2, 7):
                    stages = ['stage{}'.format(x) for x in range(start, 7)]
                    mc_num_dddict, after_lat = fit_mc_num_by_latency(
                        parsed_arch,
                        mc_num_dddict,
                        mc_maxnum_dddict,
                        lat_lookup_key_dddict,
                        lat_lookup,
                        args.target_lat,
                        stages,
                        sign=1)
            else:
                logging.info('No opeartion')
                after_lat = before_lat

            # change mc_mask_dddict based on mc_num_dddict: for each op chosen
            # in parsed_arch whose target count differs from its mask, keep the
            # mc_num channels with the largest L1 norm of depth_conv weights.
            for stage in parsed_arch:
                for block in parsed_arch[stage]:
                    op_idx = parsed_arch[stage][block]
                    if mc_num_dddict[stage][block][op_idx] != int(
                            sum(mc_mask_dddict[stage][block][op_idx]).item()):
                        mc_num = mc_num_dddict[stage][block][op_idx]
                        max_mc_num = mc_mask_dddict[stage][block][op_idx].size(
                            0)
                        # An all-True boolean index clears the whole mask.
                        mc_mask_dddict[stage][block][op_idx].data[
                            [True] * max_mc_num] = 0.0
                        key = 'module.{}.{}.m_ops.{}.depth_conv.conv.weight'.format(
                            stage, block, op_idx)
                        weight_copy = state_dict[key].clone().abs().cpu(
                        ).numpy()
                        weight_l1_norm = np.sum(weight_copy, axis=(1, 2, 3))
                        weight_l1_order = np.argsort(weight_l1_norm)
                        weight_l1_order_rev = weight_l1_order[::-1][:mc_num]
                        mc_mask_dddict[stage][block][op_idx].data[
                            weight_l1_order_rev.tolist()] = 1.0

            logging.info(
                'After, the current lat: {:.4f}, the target lat: {:.4f}'.
                format(after_lat, args.target_lat))

        # save model: full-size weights plus the (possibly updated) masks,
        # loaded at the start of the next epoch.
        model_path = os.path.join(
            args.save, 'searched_model_{:02}.pth.tar'.format(epoch + 1))
        torch.save(
            {
                'state_dict': state_dict,
                'mc_mask_dddict': mc_mask_dddict,
            }, model_path)
示例#5
0
def main():
    """Run the architecture search configured by ``config`` on MPIIDataset.

    Parses CLI args into the global ``config`` and sets up TensorBoard and
    file logging under ``config.SEARCH.PATH``. It builds ``Network(config)``,
    using ``nn.DataParallel`` when more than one GPU is listed, and trains
    with Adam and a MultiStepLR schedule against ``JointsMSELoss``. After
    each epoch it validates, logs the genotype and softmaxed alphas, and
    saves a checkpoint (flagged as best when top-1 improves).
    """
    args = parse_args()
    reset_config(config, args)

    # tensorboard / file logging
    os.makedirs(config.SEARCH.PATH, exist_ok=True)
    writer = SummaryWriter(log_dir=os.path.join(config.SEARCH.PATH, "log"))
    logger = utils.get_logger(os.path.join(config.SEARCH.PATH, "{}.log".format(config.SEARCH.NAME)))
    logger.info("Logger is set - training start")

    # cudnn related setting. These config values are no longer overridden
    # below, so CUDNN.BENCHMARK / DETERMINISTIC actually take effect.
    cudnn.benchmark = config.CUDNN.BENCHMARK
    torch.backends.cudnn.deterministic = config.CUDNN.DETERMINISTIC
    torch.backends.cudnn.enabled = config.CUDNN.ENABLED

    # NOTE: no RNG seeding is applied here (config.SEARCH.SEED is not read),
    # so runs are not reproducible.

    gpus = [int(i) for i in config.GPUS.split(',')]
    # Put the loss on the default CUDA device, like the model below.
    criterion = JointsMSELoss(use_target_weight=config.LOSS.USE_TARGET_WEIGHT).cuda()
    model = Network(config)
    if len(gpus) > 1:
        model = nn.DataParallel(model)
    model = model.cuda()
    # Underlying network for genotype()/alphas access. `.module` exists only
    # when the DataParallel wrap above happened.
    net = model.module if isinstance(model, nn.DataParallel) else model

    mb_params = utils.param_size(model)
    logger.info("Model size = {:.3f} MB".format(mb_params))

    # weights optimizer over everything model.parameters() yields
    # (presumably including the architecture alphas).
    optimizer = torch.optim.Adam(model.parameters(), config.SEARCH.W_LR)

    # split data to train/validation
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    train_data = MPIIDataset(config,
                             config.DATASET.ROOT,
                             config.SEARCH.TRAIN_SET,
                             True,
                             transforms.Compose([
                                 transforms.ToTensor(),
                                 normalize,
                             ]))
    valid_data = MPIIDataset(config,
                             config.DATASET.ROOT,
                             config.SEARCH.TEST_SET,
                             False,
                             transforms.Compose([
                                 transforms.ToTensor(),
                                 normalize,
                             ]))

    print(len(train_data), len(valid_data))

    train_loader = torch.utils.data.DataLoader(train_data,
                                               batch_size=config.SEARCH.BATCH_SIZE,
                                               shuffle=True,
                                               num_workers=config.WORKERS,
                                               pin_memory=True)

    valid_loader = torch.utils.data.DataLoader(valid_data,
                                               batch_size=config.SEARCH.BATCH_SIZE,
                                               shuffle=False,
                                               num_workers=config.WORKERS,
                                               pin_memory=True)

    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, config.SEARCH.LR_STEP, config.SEARCH.LR_FACTOR)

    # training loop
    best_top1 = 0.
    # Stays None if top-1 never improves on 0, so the final log cannot hit an
    # unbound name.
    best_genotype = None
    for epoch in range(config.SEARCH.EPOCHS):
        # training
        train(config, train_loader, model, criterion, optimizer, epoch, logger, writer)
        # Step after this epoch's optimizer updates (PyTorch >= 1.1 order).
        # Stepping first applies each LR_STEP milestone one epoch early.
        lr_scheduler.step()

        # validation
        top1 = validate(config, valid_loader, valid_data, epoch + 1, model, criterion, logger, writer)

        # log genotype and architecture weights
        genotype = net.genotype()
        logger.info(F.softmax(net.alphas_normal, dim=-1))
        logger.info(F.softmax(net.alphas_reduce, dim=-1))
        logger.info("genotype = {}".format(genotype))

        # save
        state = {'state_dict': model.state_dict(),
                 'schedule': lr_scheduler.state_dict(),
                 'epoch': epoch + 1}
        is_best = best_top1 < top1
        if is_best:
            best_top1 = top1
            best_genotype = genotype
        utils.save_checkpoint(state, config.SEARCH.PATH, is_best)

    logger.info("Final best Accuracy = {:.3f}".format(best_top1))
    logger.info("Best Genotype = {}".format(best_genotype))