def main():

    # set up logging

    args = get_args()
    args.save = '{}/eval-{}-{}'.format(args.save, args.note, time.strftime("%Y%m%d-%H%M%S"))
    # if not os.path.exists(args.save):
    #     os.path.mkdir(args.save)
    tools.create_exp_dir(args.save, scripts_to_save=glob.glob('*.py'))
    log_format = '%(asctime)s %(message)s'
    logging.basicConfig(stream=sys.stdout, level=logging.INFO,
        format=log_format, datefmt='%m/%d %I:%M:%S %p')
    fh = logging.FileHandler(os.path.join(args.save, 'log.txt'))
    fh.setFormatter(logging.Formatter(log_format))
    logger = logging.getLogger('Train Search')
    logger.addHandler(fh)

    # monitor
    pymonitor = ProgressMonitor(logger)
    tbmonitor = TensorBoardMonitor(logger, args.save)
    monitors = [pymonitor, tbmonitor]

    # set random seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    args.use_cuda = args.gpus > 0 and torch.cuda.is_available()
    args.device = torch.device('cuda:0' if args.use_cuda else 'cpu')
    if args.use_cuda:
        torch.cuda.manual_seed(args.seed)
        cudnn.enabled = True
        cudnn.benchmark = True
    setting = {k: v for k, v in args._get_kwargs()}
    logger.info(setting)
    with open(os.path.join(args.save, "args.yaml"), "w") as yaml_file:  # dump experiment config
        yaml.dump(args, yaml_file)

    if args.cifar100:
        CIFAR_CLASSES = 100
        data_folder = 'cifar-100-python'
    else:
        CIFAR_CLASSES = 10
        data_folder = 'cifar-10-batches-py'
        
    # load the model and loss function
    genotype = getattr(binary_genotypes, args.arch)
    logger.info('---------Genotype---------')
    logger.info(genotype)
    logger.info('--------------------------')
    model = Network(args.init_channels, CIFAR_CLASSES, args.layers, args.auxiliary, genotype, args.group)
    model = model.to(args.device)
    logger.info("param size = %fMB", tools.count_parameters_in_MB(model))
Example #2
def main():

    # set up logging

    args = get_args()
    args.save = '{}/search-{}-{}'.format(args.save, args.note,
                                         time.strftime("%Y%m%d-%H%M%S"))
    # if not os.path.exists(args.save):
    #     os.path.mkdir(args.save)
    tools.create_exp_dir(args.save, scripts_to_save=glob.glob('*.py'))
    log_format = '%(asctime)s %(message)s'
    logging.basicConfig(stream=sys.stdout,
                        level=logging.INFO,
                        format=log_format,
                        datefmt='%m/%d %I:%M:%S %p')
    fh = logging.FileHandler(os.path.join(args.save, 'log.txt'))
    fh.setFormatter(logging.Formatter(log_format))
    logger = logging.getLogger('Train Search')
    logger.addHandler(fh)

    # monitor
    pymonitor = ProgressMonitor(logger)
    tbmonitor = TensorBoardMonitor(logger, args.save)
    monitors = [pymonitor, tbmonitor]

    # set random seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    args.use_cuda = args.gpus > 0 and torch.cuda.is_available()
    args.device = torch.device('cuda:0' if args.use_cuda else 'cpu')
    if args.use_cuda:
        torch.cuda.manual_seed(args.seed)
        cudnn.enabled = True
        cudnn.benchmark = True
    setting = {k: v for k, v in args._get_kwargs()}
    logger.info(setting)
    with open(os.path.join(args.save, "args.yaml"),
              "w") as yaml_file:  # dump experiment config
        yaml.dump(args, yaml_file)

    if args.cifar100:
        CIFAR_CLASSES = 100
        data_folder = 'cifar-100-python'
    else:
        CIFAR_CLASSES = 10
        data_folder = 'cifar-10-batches-py'

    # prepare dataset
    if args.cifar100:
        train_transform, valid_transform = tools._data_transforms_cifar100(args)
        train_data = dset.CIFAR100(root=args.tmp_data_dir,
                                   train=True,
                                   download=False,
                                   transform=train_transform)
        valid_data = dset.CIFAR100(root=args.tmp_data_dir,
                                   train=False,
                                   download=False,
                                   transform=valid_transform)
    else:
        train_transform, valid_transform = tools._data_transforms_cifar10(args)
        train_data = dset.CIFAR10(root=args.tmp_data_dir,
                                  train=True,
                                  download=False,
                                  transform=train_transform)
        valid_data = dset.CIFAR10(root=args.tmp_data_dir,
                                  train=False,
                                  download=False,
                                  transform=valid_transform)

    num_train = len(train_data)
    indices = list(range(num_train))
    split = int(np.floor(args.train_portion * num_train))
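    # bi-level data split (DARTS-style): the first train_portion of the
    # training set optimizes the network weights, and the held-out remainder
    # optimizes the architecture parameters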

    train_queue = torch.utils.data.DataLoader(
        train_data,
        batch_size=args.batch_size,
        sampler=torch.utils.data.sampler.SubsetRandomSampler(indices[:split]),
        pin_memory=True,
        num_workers=args.workers)

    valid_queue = torch.utils.data.DataLoader(
        train_data,
        batch_size=args.batch_size,
        sampler=torch.utils.data.sampler.SubsetRandomSampler(
            indices[split:num_train]),
        pin_memory=True,
        num_workers=args.workers)

    valLoader = torch.utils.data.DataLoader(valid_data,
                                            batch_size=args.batch_size,
                                            pin_memory=True,
                                            num_workers=args.workers)

    # build Network
    criterion = nn.CrossEntropyLoss()
    criterion = criterion.to(args.device)
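    # One switch row per cell edge: a cell with 4 intermediate nodes has
    # 2 + 3 + 4 + 5 = 14 incoming edges, and each row masks which candidate
    # PRIMITIVES remain active (P-DARTS prunes these progressively).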
    switches = [[True] * len(PRIMITIVES) for _ in range(14)]
    switches_normal = copy.deepcopy(switches)
    switches_reduce = copy.deepcopy(switches)
    # To be moved to args
    num_to_keep = [5, 3, 1]
    num_to_drop = [3, 2, 2]
    if len(args.add_width) == 3:
        add_width = args.add_width
    else:
        add_width = [0, 0, 0]
    if len(args.add_layers) == 3:
        add_layers = args.add_layers
    else:
        add_layers = [0, 6, 12]
    if len(args.dropout_rate) == 3:
        drop_rate = args.dropout_rate
    else:
        drop_rate = [0.1, 0.3, 0.4]
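
    # Three progressive search stages (P-DARTS): each stage keeps fewer
    # candidate ops per edge (num_to_keep), while the supernet grows wider
    # (add_width) and deeper (add_layers) with a larger dropout rate.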

    eps_no_archs = [10, 10, 10]
    state_epochs = 0
    for sp in range(len(num_to_keep)):
        model = Network(args.init_channels + int(add_width[sp]),
                        CIFAR_CLASSES,
                        args.layers + int(add_layers[sp]),
                        criterion,
                        steps=args.nodes,
                        multiplier=args.multiplier,
                        stem_multiplier=args.stem_multiplier,
                        switches_normal=switches_normal,
                        switches_reduce=switches_reduce,
                        group=args.group,
                        p=float(drop_rate[sp]))
        model = model.to(args.device)
        logger.info("stage:{} param size:{}MB".format(
            sp, tools.count_parameters_in_MB(model)))

        optimizer = torch.optim.SGD(model.weight_parameters(),
                                    args.learning_rate,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)
        optimizer_a = torch.optim.Adam(model.arch_parameters(),
                                       lr=args.arch_learning_rate,
                                       betas=(0.5, 0.999),
                                       weight_decay=args.arch_weight_decay)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, float(args.epochs), eta_min=args.learning_rate_min)

        epochs = args.epochs
        eps_no_arch = eps_no_archs[sp]
        scale_factor = 0.2
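        # For the first eps_no_arch epochs only the weights are trained and
        # the op dropout decays linearly; afterwards the architecture
        # parameters are trained jointly and the dropout decays exponentially.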
        for epoch in range(epochs):
            lr = scheduler.get_lr()[0]
            logger.info('Epoch: %d lr: %e', epoch, lr)
            epoch_start = time.time()
            # training and logging
            if epoch < eps_no_arch:
                p = float(drop_rate[sp]) * (epochs - epoch - 1) / epochs
                logger.info("drop rate:{}".format(p))
                model.p = p
                model.update_p()
                t_top1, t_top5, t_loss = train(state_epochs + epoch,
                                               train_queue,
                                               valid_queue,
                                               model,
                                               criterion,
                                               optimizer,
                                               optimizer_a,
                                               args,
                                               monitors,
                                               logger,
                                               train_arch=False)
            else:
                p = float(drop_rate[sp]) * np.exp(
                    -(epoch - eps_no_arch) * scale_factor)
                logger.info("drop rate:{}".format(p))
                model.p = p
                model.update_p()
                t_top1, t_top5, t_loss = train(state_epochs + epoch,
                                               train_queue,
                                               valid_queue,
                                               model,
                                               criterion,
                                               optimizer,
                                               optimizer_a,
                                               args,
                                               monitors,
                                               logger,
                                               train_arch=True)

            v_top1, v_top5, v_loss = infer(
                state_epochs + epoch,
                valLoader,
                model,
                criterion,
                args,
                monitors,
                logger,
            )

            if epoch >= eps_no_arch:
                # log the genotype parsed from this epoch's architecture weights
                arch_param = model.arch_parameters()
                normal_prob = F.softmax(arch_param[0],
                                        dim=-1).data.cpu().numpy()
                reduce_prob = F.softmax(arch_param[1],
                                        dim=-1).data.cpu().numpy()
                logger.info('Genotype: {}'.format(
                    parse_genotype(switches_normal.copy(),
                                   switches_reduce.copy(), normal_prob.copy(),
                                   reduce_prob.copy())))
            scheduler.step()

        tools.save(model,
                   os.path.join(args.save, 'state{}_weights.pt'.format(sp)))
        state_epochs += args.epochs
        # Save switches info for skip-connect refinement.
        if sp == len(num_to_keep) - 1:
            switches_normal_2 = copy.deepcopy(switches_normal)
            switches_reduce_2 = copy.deepcopy(switches_reduce)
        arch_param = model.arch_parameters()
        normal_prob = F.softmax(arch_param[0], dim=-1).data.cpu().numpy()
        reduce_prob = F.softmax(arch_param[1], dim=-1).data.cpu().numpy()

        logger.info('------Stage %d end!------' % sp)
        logger.info("normal: \n{}".format(normal_prob))
        logger.info("reduce: \n{}".format(reduce_prob))
        logger.info('Genotype: {}'.format(
            parse_genotype(switches_normal.copy(), switches_reduce.copy(),
                           normal_prob.copy(), reduce_prob.copy())))

        # prune the search space using the latest architecture weights, the
        # previous switches, the number of ops to drop, and the current stage
        switches_normal = update_switches(normal_prob.copy(),
                                          switches_normal, num_to_drop[sp], sp,
                                          len(num_to_keep))
        switches_reduce = update_switches(reduce_prob.copy(),
                                          switches_reduce, num_to_drop[sp], sp,
                                          len(num_to_keep))

        logger.info('------Dropping %d paths------' % num_to_drop[sp])
        logger.info('switches_normal = %s', switches_normal)
        logging_switches(switches_normal, logger)
        logger.info('switches_reduce = %s', switches_reduce)
        logging_switches(switches_reduce, logger)

        if sp == len(num_to_keep) - 1:
            normal_final = [0 for idx in range(14)]
            reduce_final = [0 for idx in range(14)]
            # remove all Zero operations
            for i in range(14):
                if switches_normal_2[i][0]:
                    normal_prob[i][0] = 0
                normal_final[i] = max(normal_prob[i])
                if switches_reduce_2[i][0]:
                    reduce_prob[i][0] = 0
                reduce_final[i] = max(reduce_prob[i])
            # Generate Architecture, similar to DARTS
            keep_normal = [0, 1]
            keep_reduce = [0, 1]
            n = 3
            start = 2
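            # the first node's two input edges are kept outright; later nodes
            # have 3, 4, then 5 candidate input edges, and the two with the
            # highest op scores are kept per node, as in DARTS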
            for i in range(3):
                end = start + n
                tbsn = normal_final[start:end]
                tbsr = reduce_final[start:end]
                edge_n = sorted(range(n), key=lambda x: tbsn[x])
                keep_normal.append(edge_n[-1] + start)
                keep_normal.append(edge_n[-2] + start)
                edge_r = sorted(range(n), key=lambda x: tbsr[x])
                keep_reduce.append(edge_r[-1] + start)
                keep_reduce.append(edge_r[-2] + start)
                start = end
                n = n + 1
            # set switches according to the ranking of arch parameters
            for i in range(14):
                if i not in keep_normal:
                    for j in range(len(PRIMITIVES)):
                        switches_normal[i][j] = False
                if i not in keep_reduce:
                    for j in range(len(PRIMITIVES)):
                        switches_reduce[i][j] = False
            # translate switches into genotype
            genotype = parse_network(switches_normal, switches_reduce)
            logger.info(genotype)
            ## restrict skip-connect (normal cell only)
            logger.info('Restricting skipconnect...')
            # generating genotypes with different numbers of skip-connect operations
            for sks in range(0, 9):
                max_sk = 8 - sks
                num_sk = check_sk_number(switches_normal)
                if num_sk <= max_sk:
                    continue
                while num_sk > max_sk:
                    normal_prob = delete_min_sk_prob(switches_normal,
                                                     switches_normal_2,
                                                     normal_prob)
                    switches_normal = keep_1_on(switches_normal_2, normal_prob)
                    switches_normal = keep_2_branches(switches_normal,
                                                      normal_prob)
                    num_sk = check_sk_number(switches_normal)
                logger.info('Number of skip-connect: %d', max_sk)
                genotype = parse_network(switches_normal, switches_reduce)
                logger.info(genotype)
Example #3
def main():
    args = get_args()

    # set up logging
    args.save = 'search-{}-{}'.format(args.save,
                                      time.strftime("%Y%m%d-%H%M%S"))
    tools.create_exp_dir(args.save, scripts_to_save=glob.glob('*.py'))
    log_format = '%(asctime)s %(message)s'
    logging.basicConfig(stream=sys.stdout,
                        level=logging.INFO,
                        format=log_format,
                        datefmt='%m/%d %I:%M:%S %p')
    fh = logging.FileHandler(os.path.join(args.save, 'log.txt'))
    fh.setFormatter(logging.Formatter(log_format))
    logger = logging.getLogger('Train Search')
    logger.addHandler(fh)

    # monitor
    pymonitor = ProgressMonitor(logger)
    tbmonitor = TensorBoardMonitor(logger, args.save)
    monitors = [pymonitor, tbmonitor]

    # set random seed
    if args.seed is None:
        args.seed = random.randint(1, 10000)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    args.use_cuda = args.gpus > 0 and torch.cuda.is_available()
    args.multi_gpu = args.gpus > 1 and torch.cuda.is_available()
    args.device = torch.device('cuda:0' if args.use_cuda else 'cpu')
    if args.use_cuda:
        torch.cuda.manual_seed(args.seed)
        cudnn.enabled = True
        cudnn.benchmark = True
    setting = {k: v for k, v in args._get_kwargs()}
    logger.info(setting)
    with open(os.path.join(args.save, "args.yaml"),
              "w") as yaml_file:  # dump experiment config
        yaml.dump(args, yaml_file)

    # get dataloader
    if args.dataset_name == "cifar10":
        train_transform, valid_transform = tools._data_transforms_cifar10(args)
        traindata = dset.CIFAR10(root=args.dataset,
                                 train=True,
                                 download=False,
                                 transform=train_transform)
        valdata = dset.CIFAR10(root=args.dataset,
                               train=False,
                               download=False,
                               transform=valid_transform)
    else:
        train_transform, valid_transform = tools._data_transforms_mnist(args)
        traindata = dset.MNIST(root=args.dataset,
                               train=True,
                               download=False,
                               transform=train_transform)
        valdata = dset.MNIST(root=args.dataset,
                             train=False,
                             download=False,
                             transform=valid_transform)
    trainLoader = torch.utils.data.DataLoader(traindata,
                                              batch_size=args.batch_size,
                                              pin_memory=True,
                                              shuffle=True,
                                              num_workers=args.workers)
    valLoader = torch.utils.data.DataLoader(valdata,
                                            batch_size=args.batch_size,
                                            pin_memory=True,
                                            num_workers=args.workers)

    # load pretrained model
    model_t = Network(C=args.init_channels,
                      num_classes=args.class_num,
                      layers=args.layers,
                      steps=args.nodes,
                      multiplier=args.nodes,
                      stem_multiplier=args.stem_multiplier,
                      group=args.group)
    model_t, _, _ = loadCheckpoint(args.model_path, model_t, args)
    model_t.freeze_arch_parameters()
    # freeze the teacher network (all but the last two parameters, i.e. the
    # final classifier weight and bias)
    for para in list(model_t.parameters())[:-2]:
        para.requires_grad = False

    model_s = Network(C=args.init_channels,
                      num_classes=args.class_num,
                      layers=args.layers,
                      steps=args.nodes,
                      multiplier=args.nodes,
                      stem_multiplier=args.stem_multiplier,
                      group=args.group)
    model_s, _, _ = loadCheckpoint(args.model_path, model_s, args)
    model_s._initialize_alphas()

    criterion = nn.CrossEntropyLoss().to(args.device)
    model_d = Discriminator().to(args.device)
    model_s = model_s.to(args.device)
    logger.info("param size = %fMB", tools.count_parameters_in_MB(model_s))

    optimizer_d = optim.SGD(model_d.parameters(),
                            lr=args.learning_rate,
                            momentum=args.momentum,
                            weight_decay=args.weight_decay)
    optimizer_s = optim.SGD(model_s.weight_parameters(),
                            lr=args.learning_rate,
                            momentum=args.momentum,
                            weight_decay=args.weight_decay)
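    # FISTA is a proximal (soft-thresholding) optimizer: with sparsity weight
    # gamma=args.sparse_lambda it drives architecture parameters to exact
    # zeros, so whole candidate ops can later be pruned away.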
    optimizer_m = FISTA(model_s.arch_parameters(),
                        lr=args.learning_rate,
                        gamma=args.sparse_lambda)

    scheduler_d = StepLR(optimizer_d, step_size=args.lr_decay_step, gamma=0.1)
    scheduler_s = StepLR(optimizer_s, step_size=args.lr_decay_step, gamma=0.1)
    scheduler_m = StepLR(optimizer_m, step_size=args.lr_decay_step, gamma=0.1)

    perf_scoreboard = PerformanceScoreboard(args.num_best_scores)

    start_epoch = 0
    if args.resume:
        logger.info('=> Resuming from ckpt {}'.format(args.resume_path))
        ckpt = torch.load(args.resume_path, map_location=args.device)
        start_epoch = ckpt['epoch']
        model_s.load_state_dict(ckpt['state_dict_s'])
        model_d.load_state_dict(ckpt['state_dict_d'])
        optimizer_d.load_state_dict(ckpt['optimizer_d'])
        optimizer_s.load_state_dict(ckpt['optimizer_s'])
        optimizer_m.load_state_dict(ckpt['optimizer_m'])
        scheduler_d.load_state_dict(ckpt['scheduler_d'])
        scheduler_s.load_state_dict(ckpt['scheduler_s'])
        scheduler_m.load_state_dict(ckpt['scheduler_m'])
        perf_scoreboard = ckpt['perf_scoreboard']
        logger.info('=> Continue from epoch {}...'.format(start_epoch))

    models = [model_t, model_s, model_d]
    optimizers = [optimizer_d, optimizer_s, optimizer_m]
    schedulers = [scheduler_d, scheduler_s, scheduler_m]

    for epoch in range(start_epoch, args.num_epochs):
        for s in schedulers:
            logger.info('epoch %d lr %e ', epoch, s.get_lr()[0])

        _, _, _ = train(trainLoader, models, epoch, optimizers, monitors, args,
                        logger)
        v_top1, v_top5, v_loss = validate(valLoader, model_s, criterion, epoch,
                                          monitors, args, logger)

        l, board = perf_scoreboard.update(v_top1, v_top5, epoch)
        for idx in range(l):
            score = board[idx]
            logger.info(
                'Scoreboard best %d ==> Epoch [%d][Top1: %.3f   Top5: %.3f]',
                idx + 1, score['epoch'], score['top1'], score['top5'])

        logger.info("normal: \n{}".format(
            model_s.alphas_normal.data.cpu().numpy()))
        logger.info("reduce: \n{}".format(
            model_s.alphas_reduce.data.cpu().numpy()))
        logger.info('Genotypev1: {}'.format(model_s.genotypev1()))
        logger.info('Genotypev2: {}'.format(model_s.genotypev2()))
        logger.info('Genotypev3: {}'.format(model_s.genotypev3()))
        pruned = 0
        num = 0
        for param in model_s.arch_parameters():
            param_array = param.detach().cpu().numpy()
            pruned += int((param_array == 0).sum())
            num += param_array.size
        logger.info("Epoch:{} Pruned {} / {}".format(epoch, pruned, num))

        is_best = perf_scoreboard.is_best(epoch)
        if epoch % args.save_freq == 0:
            model_state_dict = (model_s.module.state_dict()
                                if args.multi_gpu else model_s.state_dict())
            state = {
                'state_dict_s': model_state_dict,
                'state_dict_d': model_d.state_dict(),
                'optimizer_d': optimizer_d.state_dict(),
                'optimizer_s': optimizer_s.state_dict(),
                'optimizer_m': optimizer_m.state_dict(),
                'scheduler_d': scheduler_d.state_dict(),
                'scheduler_s': scheduler_s.state_dict(),
                'scheduler_m': scheduler_m.state_dict(),
                "perf_scoreboard": perf_scoreboard,
                'epoch': epoch + 1
            }
            tools.save_model(state,
                             epoch + 1,
                             is_best,
                             path=os.path.join(args.save, "ckpt"))
        # update learning rate
        for s in schedulers:
            s.step(epoch)
Example #4
def main():
    args = get_args()

    # set up logging
    args.save = 'search-{}-{}'.format(args.save,
                                      time.strftime("%Y%m%d-%H%M%S"))
    tools.create_exp_dir(args.save, scripts_to_save=glob.glob('*.py'))
    log_format = '%(asctime)s %(message)s'
    logging.basicConfig(stream=sys.stdout,
                        level=logging.INFO,
                        format=log_format,
                        datefmt='%m/%d %I:%M:%S %p')
    fh = logging.FileHandler(os.path.join(args.save, 'log.txt'))
    fh.setFormatter(logging.Formatter(log_format))
    logger = logging.getLogger('Train Search')
    logger.addHandler(fh)

    # monitor
    pymonitor = ProgressMonitor(logger)
    tbmonitor = TensorBoardMonitor(logger, args.save)
    monitors = [pymonitor, tbmonitor]

    # set random seed
    if args.seed is None:
        args.seed = random.randint(1, 10000)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    args.use_cuda = args.gpus > 0 and torch.cuda.is_available()
    args.multi_gpu = args.gpus > 1 and torch.cuda.is_available()
    args.device = torch.device('cuda:0' if args.use_cuda else 'cpu')
    if args.use_cuda:
        torch.cuda.manual_seed(args.seed)
        cudnn.enabled = True
        cudnn.benchmark = True
    setting = {k: v for k, v in args._get_kwargs()}
    logger.info(setting)
    with open(os.path.join(args.save, "args.yaml"),
              "w") as yaml_file:  # dump experiment config
        yaml.dump(args, yaml_file)

    # load pretrained model
    criterion = nn.CrossEntropyLoss()
    model = Network(C=args.init_channels,
                    num_classes=args.class_num,
                    layers=args.layers,
                    steps=args.nodes,
                    multiplier=args.nodes,
                    stem_multiplier=args.stem_multiplier,
                    group=args.group)
    model, _, _ = loadCheckpoint(args.model_path, model, args)

    if args.multi_gpu:
        logger.info('use: %d gpus', args.gpus)
        model = nn.DataParallel(model)
    model = model.to(args.device)
    criterion = criterion.to(args.device)
    logger.info("param size = %fMB", tools.count_parameters_in_MB(model))

    # get dataloader
    if args.dataset_name == "cifar10":
        train_transform, valid_transform = tools._data_transforms_cifar10(args)
        traindata = dset.CIFAR10(root=args.dataset,
                                 train=True,
                                 download=False,
                                 transform=train_transform)
        valdata = dset.CIFAR10(root=args.dataset,
                               train=False,
                               download=False,
                               transform=valid_transform)
    else:
        train_transform, valid_transform = tools._data_transforms_mnist(args)
        traindata = dset.MNIST(root=args.dataset,
                               train=True,
                               download=False,
                               transform=train_transform)
        valdata = dset.MNIST(root=args.dataset,
                             train=False,
                             download=False,
                             transform=valid_transform)

    trainLoader = torch.utils.data.DataLoader(traindata,
                                              batch_size=args.batch_size,
                                              pin_memory=True,
                                              shuffle=True,
                                              num_workers=args.workers)
    valLoader = torch.utils.data.DataLoader(valdata,
                                            batch_size=args.batch_size,
                                            pin_memory=True,
                                            num_workers=args.workers)

    # weight optimizer and architecture-parameter (mask) optimizer
    optimizer_w = torch.optim.SGD(model.weight_parameters(),
                                  args.learning_rate,
                                  momentum=args.momentum,
                                  weight_decay=args.weight_decay)
    # scheduler_w = torch.optim.lr_scheduler.CosineAnnealingLR(
    #   optimizer_w, float(args.epochs), eta_min=args.learning_rate_min)
    optimizer_alpha = FISTA(model.arch_parameters(),
                            lr=args.arch_learning_rate,
                            gamma=args.sparse_lambda)
    # scheduler_alpha = torch.optim.lr_scheduler.CosineAnnealingLR(
    #   optimizer_alpha, float(args.epochs))
    scheduler_w = StepLR(optimizer_w, step_size=args.lr_decay_step, gamma=0.1)
    scheduler_alpha = StepLR(optimizer_alpha,
                             step_size=args.lr_decay_step,
                             gamma=0.1)
    perf_scoreboard = PerformanceScoreboard(args.num_best_scores)

    # resume
    start_epoch = 0
    if args.resume:
        if os.path.isfile(args.resume_path):
            model, extras, start_epoch = loadCheckpoint(
                args.resume_path, model, args)
            scheduler_w = extras["scheduler_w"]
            scheduler_alpha = extras["scheduler_alpha"]
            optimizer_w = extras["optimizer_w"]
            optimizer_alpha = extras["optimizer_alpha"]
            perf_scoreboard = extras["perf_scoreboard"]
        else:
            raise FileNotFoundError("No checkpoint found at '{}'".format(
                args.resume_path))
    for epoch in range(start_epoch, args.epochs):
        weight_lr = scheduler_w.get_lr()[0]
        arch_lr = scheduler_alpha.get_lr()[0]
        logger.info('epoch %d weight lr %e   arch lr %e', epoch, weight_lr,
                    arch_lr)

        t_top1, t_top5, t_loss = train(trainLoader, valLoader, model,
                                       criterion, epoch, optimizer_w,
                                       optimizer_alpha, monitors, args, logger)
        v_top1, v_top5, v_loss = validate(valLoader, model, criterion, epoch,
                                          monitors, args, logger)

        tbmonitor.writer.add_scalars('Train_vs_Validation/Loss', {
            'train': t_loss,
            'val': v_loss
        }, epoch)
        tbmonitor.writer.add_scalars('Train_vs_Validation/Top1', {
            'train': t_top1,
            'val': v_top1
        }, epoch)
        tbmonitor.writer.add_scalars('Train_vs_Validation/Top5', {
            'train': t_top5,
            'val': v_top5
        }, epoch)

        l, board = perf_scoreboard.update(v_top1, v_top5, epoch)
        for idx in range(l):
            score = board[idx]
            logger.info(
                'Scoreboard best %d ==> Epoch [%d][Top1: %.3f   Top5: %.3f]',
                idx + 1, score['epoch'], score['top1'], score['top5'])

        logger.info("normal: \n{}".format(
            model.alphas_normal.data.cpu().numpy()))
        logger.info("reduce: \n{}".format(
            model.alphas_reduce.data.cpu().numpy()))
        logger.info('Genotypev1: {}'.format(model.genotypev1()))
        logger.info('Genotypev2: {}'.format(model.genotypev2()))
        logger.info('Genotypev3: {}'.format(model.genotypev3()))
        pruned = 0
        num = 0
        for param in model.arch_parameters():
            param_array = param.detach().cpu().numpy()
            pruned += int((param_array == 0).sum())
            num += param_array.size
        logger.info("Epoch:{} Pruned {} / {}".format(epoch, pruned, num))

        is_best = perf_scoreboard.is_best(epoch)
        # save model
        if epoch % args.save_freq == 0:
            saveCheckpoint(
                epoch, args.model, model, {
                    'scheduler_w': scheduler_w,
                    "scheduler_alpha": scheduler_alpha,
                    "optimizer_w": optimizer_w,
                    'optimizer_alpha': optimizer_alpha,
                    'perf_scoreboard': perf_scoreboard
                }, is_best, os.path.join(args.save, "ckpts"))
        # update lr
        scheduler_w.step()
        scheduler_alpha.step()
Example #5
def main():

  args = get_args()
  args.save = '{}/search-{}'.format(args.save, time.strftime("%Y%m%d-%H%M%S"))
  tools.create_exp_dir(args.save, scripts_to_save=glob.glob('*.py'))

  log_format = '%(asctime)s %(message)s'
  logging.basicConfig(stream=sys.stdout, level=logging.INFO,
      format=log_format, datefmt='%m/%d %I:%M:%S %p')
  fh = logging.FileHandler(os.path.join(args.save, 'log.txt'))
  fh.setFormatter(logging.Formatter(log_format))
  logging.getLogger().addHandler(fh)


  CIFAR_CLASSES = 10

  if not torch.cuda.is_available():
    logging.info('no gpu device available')
    sys.exit(1)

  np.random.seed(args.seed)
  torch.cuda.set_device(args.gpu)
  cudnn.benchmark = True
  torch.manual_seed(args.seed)
  cudnn.enabled=True
  torch.cuda.manual_seed(args.seed)
  logging.info('gpu device = %d' % args.gpu)
  logging.info("args = %s", args)

  criterion = nn.CrossEntropyLoss()
  criterion = criterion.cuda()
  model = Network(args.init_channels, CIFAR_CLASSES, args.layers, criterion)
  model = model.cuda()
  logging.info("param size = %fMB", tools.count_parameters_in_MB(model))

  optimizer = torch.optim.SGD(
      model.weight_parameters(),
      args.learning_rate,
      momentum=args.momentum,
      weight_decay=args.weight_decay)

  train_transform, valid_transform = tools._data_transforms_cifar10(args)
  train_data = dset.CIFAR10(root=args.data, train=True, download=True, transform=train_transform)

  num_train = len(train_data)
  indices = list(range(num_train))
  split = int(np.floor(args.train_portion * num_train))

  train_queue = torch.utils.data.DataLoader(
      train_data, batch_size=args.batch_size,
      sampler=torch.utils.data.sampler.SubsetRandomSampler(indices[:split]),
      pin_memory=True, num_workers=2)

  valid_queue = torch.utils.data.DataLoader(
      train_data, batch_size=args.batch_size,
      sampler=torch.utils.data.sampler.SubsetRandomSampler(indices[split:num_train]),
      pin_memory=True, num_workers=2)

  scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, float(args.epochs), eta_min=args.learning_rate_min)

  architect = Architect(model, args)
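  # Architect (from DARTS) owns the architecture update: it steps the alphas
  # on held-out validation batches, typically via a one-step unrolled
  # approximation of the bi-level objective.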

  for epoch in range(args.epochs):
    scheduler.step()
    lr = scheduler.get_lr()[0]
    logging.info('epoch %d lr %e', epoch, lr)

    genotype = model.genotype()
    logging.info('genotype = %s', genotype)

    print(F.softmax(model.alphas_normal, dim=-1))
    print(F.softmax(model.alphas_reduce, dim=-1))

    # training
    train_acc, train_obj = train(train_queue, valid_queue, model, architect, criterion, optimizer, lr, args)
    logging.info('train_acc %f', train_acc)

    # validation
    valid_acc, valid_obj = infer(valid_queue, model, criterion, args)
    logging.info('valid_acc %f', valid_acc)

    tools.save(model, os.path.join(args.save, 'weights.pt'))
Example #6
def main():
    # set up logging
    args = get_args()
    args.save = '{}/eval-{}-{}'.format(args.save, args.note, time.strftime("%Y%m%d-%H%M%S"))
    # if not os.path.exists(args.save):
    #     os.path.mkdir(args.save)
    tools.create_exp_dir(args.save, scripts_to_save=glob.glob('*.py'))
    log_format = '%(asctime)s %(message)s'
    logging.basicConfig(stream=sys.stdout, level=logging.INFO,
        format=log_format, datefmt='%m/%d %I:%M:%S %p')
    fh = logging.FileHandler(os.path.join(args.save, 'log.txt'))
    fh.setFormatter(logging.Formatter(log_format))
    logger = logging.getLogger('Train Search')
    logger.addHandler(fh)

    # monitor
    pymonitor = ProgressMonitor(logger)
    tbmonitor = TensorBoardMonitor(logger, args.save)
    monitors = [pymonitor, tbmonitor]

    # set random seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    args.use_cuda = args.gpus > 0 and torch.cuda.is_available()
    args.device = torch.device('cuda:0' if args.use_cuda else 'cpu')
    if args.use_cuda:
        torch.cuda.manual_seed(args.seed)
        cudnn.enabled = True
        cudnn.benchmark = True
    setting = {k: v for k, v in args._get_kwargs()}
    logger.info(setting)
    with open(os.path.join(args.save, "args.yaml"), "w") as yaml_file:  # dump experiment config
        yaml.dump(args, yaml_file)

    if args.cifar100:
        CIFAR_CLASSES = 100
        data_folder = 'cifar-100-python'
    else:
        CIFAR_CLASSES = 10
        data_folder = 'cifar-10-batches-py'

    # load the model and loss function
    genotype = getattr(binary_genotypes, args.arch)
    logger.info('---------Genotype---------')
    logger.info(genotype)
    logger.info('--------------------------')
    model = Network(args.init_channels, CIFAR_CLASSES, args.layers, args.auxiliary, genotype, args.group)
    model = model.to(args.device)
    logger.info("param size = %fMB", tools.count_parameters_in_MB(model))
    criterion = nn.CrossEntropyLoss()
    criterion = criterion.to(args.device)


    if args.cifar100:
        train_transform, valid_transform = tools._data_transforms_cifar100(args)
        train_data = dset.CIFAR100(root=args.tmp_data_dir, train=True, download=True, transform=train_transform)
        valid_data = dset.CIFAR100(root=args.tmp_data_dir, train=False, download=True, transform=valid_transform)
    else:
        train_transform, valid_transform = tools._data_transforms_cifar10(args)
        train_data = dset.CIFAR10(root=args.tmp_data_dir, train=True, download=True, transform=train_transform)
        valid_data = dset.CIFAR10(root=args.tmp_data_dir, train=False, download=True, transform=valid_transform)

    trainLoader = torch.utils.data.DataLoader(
        train_data, batch_size=args.batch_size, shuffle=True, pin_memory=True, num_workers=args.workers)

    valLoader = torch.utils.data.DataLoader(
        valid_data, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=args.workers)

    if args.optimizer.lower() == 'sgd':
        optimizer = torch.optim.SGD(model.parameters(), args.learning_rate,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)

    elif args.optimizer.lower() == 'adam':
        optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate,
                                     betas=(0.9, 0.999),
                                     weight_decay=args.weight_decay)

    elif args.optimizer.lower() == 'radam':
        optimizer = RAdam(model.parameters(), lr=args.learning_rate,
                          betas=(0.9, 0.999),
                          weight_decay=args.weight_decay)
    else:
        raise NotImplementedError(args.optimizer)
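    # 'warm_up_cos' below ramps the lr linearly over the first warm_up_epochs,
    # then follows a cosine decay; LambdaLR multiplies the optimizer's base lr
    # by whatever the lambda returns for the current epoch.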


    if args.scheduler.lower() == 'warm_up_cos':
        warm_up_epochs = 5
        warm_up_with_cos = lambda epoch: (epoch + 1) / warm_up_epochs if epoch < warm_up_epochs \
            else 0.5 * (1 + math.cos(math.pi * epoch / args.epochs))
        scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=warm_up_with_cos)

    elif args.scheduler.lower() == "cos":
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, args.epochs,eta_min=0,last_epoch=-1)

    elif args.scheduler.lower() == "step":
        scheduler=torch.optim.lr_scheduler.StepLR(optimizer=optimizer,step_size=args.step_size,gamma=0.1,last_epoch=-1)
    
    elif  args.scheduler.lower() == "mstep":
        scheduler=torch.optim.lr_scheduler.MultiStepLR(optimizer, args.steplist, gamma=args.gamma)
    else:
        NotImplementedError()


    # best-result recorder
    perf_scoreboard = PerformanceScoreboard(args.num_best_scores)

    # resume
    start_epoch = 0
    if args.resume:
        if os.path.isfile(args.resume_path):
            model, extras, start_epoch = loadCheckpoint(args.resume_path, model, args)
            optimizer, perf_scoreboard = extras["optimizer"], extras["perf_scoreboard"]
        else:
            raise FileNotFoundError("No checkpoint found at '{}'".format(args.resume_path))

    # just eval model
    if args.eval:
        validate(valLoader, model, criterion, -1, monitors, args, logger)
    else:
        # when resuming or starting from a pretrained model, evaluate it first
        if args.resume:
            logger.info('>>>>>>>> Epoch -1 (pre-trained model evaluation)')
            top1, top5, _ = validate(valLoader, model, criterion,
                                     start_epoch - 1, monitors, args, logger)
            l, board = perf_scoreboard.update(top1, top5, start_epoch - 1)
            for idx in range(l):
                score = board[idx]
                logger.info('Scoreboard best %d ==> Epoch [%d][Top1: %.3f   Top5: %.3f]',
                            idx + 1, score['epoch'], score['top1'], score['top5'])

        # start training
        for _ in range(start_epoch):
            scheduler.step()

        for epoch in range(start_epoch, args.epochs):
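            # drop-path probability is annealed linearly from 0 up to
            # args.drop_path_prob over training, as in DARTS evaluation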
            drop_prob = args.drop_path_prob * epoch / args.epochs
            model.drop_path_prob = drop_prob
            logger.info('>>>> Epoch {} Lr {} Drop:{} '.format(epoch,
                                                            optimizer.param_groups[0]['lr'],
                                                            drop_prob))

            t_top1, t_top5, t_loss = train(trainLoader, model, criterion, optimizer,
                                           scheduler, epoch, monitors, args, logger)
            v_top1, v_top5, v_loss = validate(valLoader, model, criterion, epoch, monitors, args, logger)

            tbmonitor.writer.add_scalars('Train_vs_Validation/Loss', {'train': t_loss, 'val': v_loss}, epoch)
            tbmonitor.writer.add_scalars('Train_vs_Validation/Top1', {'train': t_top1, 'val': v_top1}, epoch)
            tbmonitor.writer.add_scalars('Train_vs_Validation/Top5', {'train': t_top5, 'val': v_top5}, epoch)

            l, board = perf_scoreboard.update(v_top1, v_top5, epoch)
            for idx in range(l):
                score = board[idx]
                logger.info('Scoreboard best %d ==> Epoch [%d][Top1: %.3f   Top5: %.3f]',
                            idx + 1, score['epoch'], score['top1'], score['top5'])


            is_best = perf_scoreboard.is_best(epoch)
            # save model
            if (epoch + 1) % 5 == 0:
                saveCheckpoint(epoch, "search", model,
                               {
                                   # 'scheduler': scheduler,
                                   'optimizer': optimizer,
                                   'perf_scoreboard': perf_scoreboard
                               },
                               is_best, os.path.join(args.save, "ckpts"))
            # update lr
            scheduler.step()

        logger.info('>>>>>>>> Epoch -1 (final model evaluation)')
        validate(valLoader, model, criterion, -1, monitors, args, logger)

    tbmonitor.writer.close()  # close the TensorBoard
    logger.info('Program completed successfully ... exiting ...')
    logger.info('If you have any questions or suggestions, please visit: github.com/lswzjuer/nas-pruning-quantize')
Example #7
def main():

    args = get_args()
    # set up logging
    args.save = '{}/search-{}'.format(args.save,
                                      time.strftime("%Y%m%d-%H%M%S"))
    tools.create_exp_dir(args.save, scripts_to_save=glob.glob('*.py'))
    log_format = '%(asctime)s %(message)s'
    logging.basicConfig(stream=sys.stdout,
                        level=logging.INFO,
                        format=log_format,
                        datefmt='%m/%d %I:%M:%S %p')
    fh = logging.FileHandler(os.path.join(args.save, 'log.txt'))
    fh.setFormatter(logging.Formatter(log_format))
    logger = logging.getLogger('Train Search')
    logger.addHandler(fh)

    # monitor
    pymonitor = ProgressMonitor(logger)
    tbmonitor = TensorBoardMonitor(logger, args.save)
    monitors = [pymonitor, tbmonitor]

    if not torch.cuda.is_available():
        logging.info('no gpu device available')
        sys.exit(1)

    # set random seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    args.use_cuda = args.gpus > 0 and torch.cuda.is_available()
    args.device = torch.device('cuda:0' if args.use_cuda else 'cpu')
    if args.use_cuda:
        torch.cuda.manual_seed(args.seed)
        cudnn.enabled = True
        cudnn.benchmark = True
    setting = {k: v for k, v in args._get_kwargs()}
    logger.info(setting)
    with open(os.path.join(args.save, "args.yaml"),
              "w") as yaml_file:  # dump experiment config
        yaml.dump(args, yaml_file)

    criterion = nn.CrossEntropyLoss()
    criterion = criterion.to(args.device)
    model = Network(C=args.init_channels,
                    num_classes=args.class_num,
                    layers=args.layers,
                    criterion=criterion,
                    steps=args.nodes,
                    multiplier=args.nodes,
                    stem_multiplier=args.stem_multiplier,
                    group=args.group)
    model = model.to(args.device)
    logger.info("param size = %fMB", tools.count_parameters_in_MB(model))

    # get dataloader
    if args.dataset_name == "cifar10":
        train_transform, valid_transform = tools._data_transforms_cifar10(args)
        traindata = dset.CIFAR10(root=args.dataset,
                                 train=True,
                                 download=False,
                                 transform=train_transform)
        valdata = dset.CIFAR10(root=args.dataset,
                               train=False,
                               download=False,
                               transform=valid_transform)

    else:
        train_transform, valid_transform = tools._data_transforms_mnist(args)
        traindata = dset.MNIST(root=args.dataset,
                               train=True,
                               download=False,
                               transform=train_transform)
        valdata = dset.MNIST(root=args.dataset,
                             train=False,
                             download=False,
                             transform=valid_transform)

    num_train = len(traindata)
    indices = list(range(num_train))
    split = int(np.floor(args.train_portion * num_train))
    train_queue = torch.utils.data.DataLoader(
        traindata,
        batch_size=args.batch_size,
        sampler=torch.utils.data.sampler.SubsetRandomSampler(indices[:split]),
        pin_memory=True,
        num_workers=args.workers)
    valid_queue = torch.utils.data.DataLoader(
        traindata,
        batch_size=args.batch_size,
        sampler=torch.utils.data.sampler.SubsetRandomSampler(
            indices[split:num_train]),
        pin_memory=True,
        num_workers=args.workers)
    valLoader = torch.utils.data.DataLoader(valdata,
                                            batch_size=args.batch_size,
                                            pin_memory=True,
                                            num_workers=args.workers)

    # weight optimizer and architecture-parameter (mask) optimizer
    optimizer_w = torch.optim.SGD(model.weight_parameters(),
                                  args.learning_rate,
                                  momentum=args.momentum,
                                  weight_decay=args.weight_decay)
    scheduler_w = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer_w, float(args.epochs), eta_min=args.learning_rate_min)

    optimizer_a = torch.optim.Adam(model.arch_parameters(),
                                   lr=args.arch_learning_rate,
                                   betas=(0.5, 0.999),
                                   weight_decay=args.arch_weight_decay)
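    # the architecture parameters get their own Adam optimizer with
    # betas=(0.5, 0.999), as in DARTS-style search; weights and alphas are
    # updated alternately inside train()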

    perf_scoreboard = PerformanceScoreboard(args.num_best_scores)

    # resume
    start_epoch = 0
    if args.resume:
        if os.path.isfile(args.resume_path):
            model, extras, start_epoch = loadCheckpoint(
                args.resume_path, model, args)
            optimizer_w, optimizer_a, perf_scoreboard = extras[
                "optimizer_w"], extras["optimizer_a"], extras[
                    "perf_scoreboard"]
        else:
            raise FileNotFoundError("No checkpoint found at '{}'".format(
                args.resume_path))

    if args.resume:
        logger.info('>>>>>>>> Epoch -1 (pre-trained model evaluation)')
        top1, top5, _ = validate(valLoader, model, criterion, start_epoch - 1,
                                 monitors, args, logger)
        l, board = perf_scoreboard.update(top1, top5, start_epoch - 1)
        for idx in range(l):
            score = board[idx]
            logger.info(
                'Scoreboard best %d ==> Epoch [%d][Top1: %.3f   Top5: %.3f]',
                idx + 1, score['epoch'], score['top1'], score['top5'])

    # start training
    for _ in range(start_epoch):
        scheduler_w.step()

    for epoch in range(start_epoch, args.epochs):
        lr = scheduler_w.get_lr()[0]
        logger.info('epoch %d lr %e', epoch, lr)

        if epoch < args.arch_after:
            model.p = float(
                args.drop_rate) * (args.epochs - epoch - 1) / args.epochs
            model.update_p()
            t_top1, t_top5, t_loss = train(train_queue,
                                           valid_queue,
                                           model,
                                           criterion,
                                           epoch,
                                           optimizer_w,
                                           optimizer_a,
                                           args,
                                           monitors,
                                           logger,
                                           train_arch=False)
        else:
            model.p = float(args.drop_rate) * np.exp(
                -(epoch - args.arch_after) * 0.2)
            model.update_p()
            t_top1, t_top5, t_loss = train(train_queue,
                                           valid_queue,
                                           model,
                                           criterion,
                                           epoch,
                                           optimizer_w,
                                           optimizer_a,
                                           args,
                                           monitors,
                                           logger,
                                           train_arch=True)

        v_top1, v_top5, v_loss = validate(valLoader, model, criterion, epoch,
                                          monitors, args, logger)

        l, board = perf_scoreboard.update(v_top1, v_top5, epoch)
        logger.info("normal: \n{}".format(
            model.alphas_normal.data.cpu().numpy()))
        logger.info("reduce: \n{}".format(
            model.alphas_reduce.data.cpu().numpy()))
        logger.info('Genotype: {}'.format(model.genotype()))
        is_best = perf_scoreboard.is_best(epoch)
        # save model
        if epoch % args.save_freq == 0:
            saveCheckpoint(
                epoch, args.model, model, {
                    'optimizer_w': optimizer_w,
                    'optimizer_a': optimizer_a,
                    'perf_scoreboard': perf_scoreboard
                }, is_best, os.path.join(args.save, "ckpts"))
        # update lr
        scheduler_w.step()
Example #8
def main():
    args = get_args()

    # set up logging
    args.save = '{}/search-{}'.format(args.save,
                                      time.strftime("%Y%m%d-%H%M%S"))
    tools.create_exp_dir(args.save, scripts_to_save=glob.glob('*.py'))
    log_format = '%(asctime)s %(message)s'
    logging.basicConfig(stream=sys.stdout,
                        level=logging.INFO,
                        format=log_format,
                        datefmt='%m/%d %I:%M:%S %p')
    fh = logging.FileHandler(os.path.join(args.save, 'log.txt'))
    fh.setFormatter(logging.Formatter(log_format))
    logger = logging.getLogger('Train Search')
    logger.addHandler(fh)

    # monitor
    pymonitor = ProgressMonitor(logger)
    tbmonitor = TensorBoardMonitor(logger, args.save)
    monitors = [pymonitor, tbmonitor]

    # set random seed
    if args.seed is None:
        args.seed = random.randint(1, 10000)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    args.use_cuda = args.gpus > 0 and torch.cuda.is_available()
    args.multi_gpu = args.gpus > 1 and torch.cuda.is_available()
    args.device = torch.device('cuda:0' if args.use_cuda else 'cpu')
    if args.use_cuda:
        torch.cuda.manual_seed(args.seed)
        cudnn.enabled = True
        cudnn.benchmark = True
    setting = {k: v for k, v in args._get_kwargs()}
    logger.info(setting)
    with open(os.path.join(args.save, "args.yaml"),
              "w") as yaml_file:  # dump experiment config
        yaml.dump(args, yaml_file)

    criterion = nn.CrossEntropyLoss()
    model = Network(C=args.init_channels,
                    num_classes=args.class_num,
                    layers=args.layers,
                    steps=args.nodes,
                    multiplier=args.nodes,
                    stem_multiplier=args.stem_multiplier,
                    group=args.group)
    if args.multi_gpu:
        logger.info('use: %d gpus', args.gpus)
        model = nn.DataParallel(model)
    #model = model.freeze_arch_parameters()
    model = model.to(args.device)
    criterion = criterion.to(args.device)
    logger.info("param size = %fMB", tools.count_parameters_in_MB(model))

    # get dataloader
    if args.dataset_name == "cifar10":
        train_transform, valid_transform = tools._data_transforms_cifar10(args)
        traindata = dset.CIFAR10(root=args.dataset,
                                 train=True,
                                 download=False,
                                 transform=train_transform)
        valdata = dset.CIFAR10(root=args.dataset,
                               train=False,
                               download=False,
                               transform=valid_transform)
    else:
        train_transform, valid_transform = tools._data_transforms_mnist(args)
        traindata = dset.MNIST(root=args.dataset,
                               train=True,
                               download=False,
                               transform=train_transform)
        valdata = dset.MNIST(root=args.dataset,
                             train=False,
                             download=False,
                             transform=valid_transform)

    num_train = len(traindata)
    indices = list(range(num_train))
    split = int(np.floor(args.train_portion * num_train))
    train_queue = torch.utils.data.DataLoader(
        traindata,
        batch_size=args.batch_size,
        sampler=torch.utils.data.sampler.SubsetRandomSampler(indices[:split]),
        pin_memory=True,
        num_workers=args.workers)
    valid_queue = torch.utils.data.DataLoader(
        traindata,
        batch_size=args.batch_size,
        sampler=torch.utils.data.sampler.SubsetRandomSampler(
            indices[split:num_train]),
        pin_memory=True,
        num_workers=args.workers)
    valLoader = torch.utils.data.DataLoader(valdata,
                                            batch_size=args.batch_size,
                                            pin_memory=True,
                                            num_workers=args.workers)

    # weight optimizer and architecture-parameter (mask) optimizer
    optimizer_w = torch.optim.SGD(model.weight_parameters(),
                                  args.learning_rate,
                                  momentum=args.momentum,
                                  weight_decay=args.weight_decay)
    scheduler_w = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer_w, float(args.epochs), eta_min=args.learning_rate_min)

    optimizer_a = torch.optim.Adam(model.arch_parameters(),
                                   lr=args.arch_learning_rate,
                                   betas=(0.5, 0.999),
                                   weight_decay=0)

    perf_scoreboard = PerformanceScoreboard(args.num_best_scores)

    for epoch in range(args.epochs):
        lr = scheduler_w.get_lr()[0]
        logger.info('epoch %d lr %e', epoch, lr)

        flag = epoch >= args.arch_after
        t_top1, t_top5, t_loss = train(train_queue,
                                       valid_queue,
                                       model,
                                       criterion,
                                       epoch,
                                       optimizer_w,
                                       optimizer_a,
                                       monitors,
                                       args,
                                       logger,
                                       train_arch=flag)
        v_top1, v_top5, v_loss = validate(valLoader, model, criterion, epoch,
                                          monitors, args, logger)

        tbmonitor.writer.add_scalars('Train_vs_Validation/Loss', {
            'train': t_loss,
            'val': v_loss
        }, epoch)
        tbmonitor.writer.add_scalars('Train_vs_Validation/Top1', {
            'train': t_top1,
            'val': v_top1
        }, epoch)
        tbmonitor.writer.add_scalars('Train_vs_Validation/Top5', {
            'train': t_top5,
            'val': v_top5
        }, epoch)

        # for name,param in model.named_parameters():
        #     if ("conv" in name.lower() or "fc" in name.lower()) and "weight" in name.lower():
        #         tbmonitor.writer.add_histogram(name,param.data.cpu(),epoch)
        l, board = perf_scoreboard.update(v_top1, v_top5, epoch)
        for idx in range(l):
            score = board[idx]
            logger.info(
                'Scoreboard best %d ==> Epoch [%d][Top1: %.3f   Top5: %.3f]',
                idx + 1, score['epoch'], score['top1'], score['top5'])

        logger.info("normal: \n{}".format(
            model.alphas_normal.data.cpu().numpy()))
        logger.info("reduce: \n{}".format(
            model.alphas_reduce.data.cpu().numpy()))
        logger.info('Genotypev1: {}'.format(model.genotypev1()))
        logger.info('Genotypev2: {}'.format(model.genotypev2()))
        pruned = 0
        num = 0
        for param in model.arch_parameters():
            param_array = param.detach().cpu().numpy()
            pruned += int((param_array == 0).sum())
            num += param_array.size
        logger.info("Epoch:{} Pruned {} / {}".format(epoch, pruned, num))

        is_best = perf_scoreboard.is_best(epoch)
        # save model
        if epoch % args.save_freq == 0:
            saveCheckpoint(
                epoch, args.model, model, {
                    'scheduler': scheduler_w,
                    'optimizer_w': optimizer_w,
                    'optimizer_a': optimizer_a,
                    'perf_scoreboard': perf_scoreboard
                }, is_best, os.path.join(args.save, "ckpts"))
        # update lr
        scheduler_w.step()