Example #1
def main(args):
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    logging.info("args = %s", args)

    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")

    kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
    train_loader, valid_loader = utils.search_dataloader(args, kwargs)

    criterion = nn.CrossEntropyLoss().to(device)
    model = Network(device, nodes=2).to(device)

    logging.info(
        "param size = %fMB",
        sum(np.prod(v.size())
            for name, v in model.named_parameters()) / 1e6)

    optimizer = torch.optim.SGD(model.parameters(),
                                args.learning_rate,
                                momentum=0.9,
                                weight_decay=args.weight_decay)

    architect = Architect(model)

    for epoch in range(args.epochs):
        logging.info("Starting epoch %d/%d", epoch + 1, args.epochs)

        # training
        train_acc, train_obj = train(train_loader, valid_loader, model,
                                     architect, criterion, optimizer, device)
        logging.info('train_acc %f', train_acc)

        # validation
        valid_acc, valid_obj = infer(valid_loader, model, criterion, device)
        logging.info('valid_acc %f', valid_acc)

        # compute the discrete architecture from the current alphas
        genotype = model.genotype()
        logging.info('genotype = %s', genotype)

        print(F.softmax(model.alphas_normal, dim=-1))
        print(F.softmax(model.alphas_reduce, dim=-1))

        with open(args.save + '/architecture', 'w') as f:
            f.write(str(genotype))
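Example #1 calls a helper infer(valid_loader, model, criterion, device) that is not shown. A minimal sketch under the usual DARTS conventions (a plain-logits model(input) output and the utils.AvgrageMeter / utils.accuracy helpers; all of these names are assumptions) could look like this:

def infer(valid_loader, model, criterion, device):
    # hypothetical sketch; assumes model(input) returns plain logits
    objs = utils.AvgrageMeter()   # running average of the loss
    top1 = utils.AvgrageMeter()   # running average of top-1 accuracy
    model.eval()
    with torch.no_grad():
        for step, (input, target) in enumerate(valid_loader):
            input = input.to(device)
            target = target.to(device, non_blocking=True)
            logits = model(input)
            loss = criterion(logits, target)
            prec1, = utils.accuracy(logits, target, topk=(1,))
            n = input.size(0)
            objs.update(loss.item(), n)
            top1.update(prec1.item(), n)
    return top1.avg, objs.avg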
Example #2
def main():
    if not torch.cuda.is_available():
        logging.info('No GPU device available')
        sys.exit(1)
    np.random.seed(args.seed)
    torch.cuda.set_device(args.gpu)
    cudnn.benchmark = True
    torch.manual_seed(args.seed)
    cudnn.enabled = True
    torch.cuda.manual_seed(args.seed)
    logging.info('GPU device = %d', args.gpu)
    logging.info("args = %s", args)
    #  prepare dataset
    train_transform, valid_transform = utils.data_transforms(args.dataset, args.cutout, args.cutout_length)
    if args.dataset == "CIFAR100":
        train_data = dset.CIFAR100(root=args.tmp_data_dir, train=True, download=True, transform=train_transform)
    elif args.dataset == "CIFAR10":
        train_data = dset.CIFAR10(root=args.tmp_data_dir, train=True, download=True, transform=train_transform)
    elif args.dataset == 'mit67':
        dset_cls = dset.ImageFolder
        data_path = '%s/MIT67/train' % args.tmp_data_dir  # 'data/MIT67/train'
        val_path = '%s/MIT67/test' % args.tmp_data_dir  # 'data/MIT67/val'
        train_data = dset_cls(root=data_path, transform=train_transform)
        valid_data = dset_cls(root=val_path, transform=valid_transform)
    elif args.dataset == 'sport8':
        dset_cls = dset.ImageFolder
        data_path = '%s/Sport8/train' % args.tmp_data_dir  # 'data/Sport8/train'
        val_path = '%s/Sport8/test' % args.tmp_data_dir  # 'data/Sport8/val'
        train_data = dset_cls(root=data_path, transform=train_transform)
        valid_data = dset_cls(root=val_path, transform=valid_transform)

    num_train = len(train_data)
    indices = list(range(num_train))
    split = int(np.floor(args.train_portion * num_train))
    random.shuffle(indices)
    
    train_queue = torch.utils.data.DataLoader(
        train_data, batch_size=args.batch_size,
        sampler=torch.utils.data.sampler.SubsetRandomSampler(indices[:split]),
        pin_memory=True, num_workers=args.workers)

    valid_queue = torch.utils.data.DataLoader(
        train_data, batch_size=args.batch_size,
        sampler=torch.utils.data.sampler.SubsetRandomSampler(indices[split:num_train]),
        pin_memory=True, num_workers=args.workers)
    
    # build Network
    criterion = nn.CrossEntropyLoss()
    criterion = criterion.cuda()
    switches = []
    for i in range(14):
        switches.append([True for j in range(len(PRIMITIVES))])
    switches_normal = copy.deepcopy(switches)
    switches_reduce = copy.deepcopy(switches)
    # To be moved to args
    num_to_keep = [5, 3, 1]
    num_to_drop = [3, 2, 2]
    if len(args.add_width) == 3:
        add_width = args.add_width
    else:
        add_width = [0, 0, 0]
    if len(args.add_layers) == 3:
        add_layers = args.add_layers
    else:
        add_layers = [0, 3, 6]
    if len(args.dropout_rate) == 3:
        drop_rate = args.dropout_rate
    else:
        drop_rate = [0.0, 0.0, 0.0]
    eps_no_archs = [10, 10, 10]
    for sp in range(len(num_to_keep)):
        model = Network(args.init_channels + int(add_width[sp]), CLASSES,
                        args.layers + int(add_layers[sp]), criterion,
                        switches_normal=switches_normal,
                        switches_reduce=switches_reduce,
                        p=float(drop_rate[sp]),
                        largemode=args.dataset in utils.LARGE_DATASETS)
        
        model = model.cuda()
        logging.info("param size = %fMB", utils.count_parameters_in_MB(model))
        network_params = []
        for k, v in model.named_parameters():
            if not (k.endswith('alphas_normal') or k.endswith('alphas_reduce')):
                network_params.append(v)       
        optimizer = torch.optim.SGD(
                network_params,
                args.learning_rate,
                momentum=args.momentum,
                weight_decay=args.weight_decay)
        optimizer_a = torch.optim.Adam(model.arch_parameters(),
                                       lr=args.arch_learning_rate,
                                       betas=(0.5, 0.999),
                                       weight_decay=args.arch_weight_decay)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
                optimizer, float(args.epochs), eta_min=args.learning_rate_min)
        sm_dim = -1
        epochs = args.epochs
        eps_no_arch = eps_no_archs[sp]
        scale_factor = 0.2
        for epoch in range(epochs):
            scheduler.step()
            lr = scheduler.get_lr()[0]
            logging.info('Epoch: %d lr: %e', epoch, lr)
            epoch_start = time.time()
            # training
            if epoch < eps_no_arch:
                model.p = float(drop_rate[sp]) * (epochs - epoch - 1) / epochs
                model.update_p()
                train_acc, train_obj = train(train_queue, valid_queue, model,
                                             network_params, criterion,
                                             optimizer, optimizer_a, lr,
                                             train_arch=False)
            else:
                model.p = float(drop_rate[sp]) * np.exp(-(epoch - eps_no_arch) * scale_factor)
                model.update_p()
                train_acc, train_obj = train(train_queue, valid_queue, model,
                                             network_params, criterion,
                                             optimizer, optimizer_a, lr,
                                             train_arch=True)
            logging.info('Train_acc %f', train_acc)
            epoch_duration = time.time() - epoch_start
            logging.info('Epoch time: %ds', epoch_duration)
            # validation
            if epochs - epoch < 5:
                valid_acc, valid_obj = infer(valid_queue, model, criterion)
                logging.info('Valid_acc %f', valid_acc)
        utils.save(model, os.path.join(args.save, 'weights.pt'))
        print('------Dropping %d paths------' % num_to_drop[sp])
        # Save a copy of the switches for the skip-connect refinement step below.
        if sp == len(num_to_keep) - 1:
            switches_normal_2 = copy.deepcopy(switches_normal)
            switches_reduce_2 = copy.deepcopy(switches_reduce)
        # drop operations with low architecture weights
        arch_param = model.arch_parameters()
        normal_prob = F.softmax(arch_param[0], dim=sm_dim).data.cpu().numpy()        
        for i in range(14):
            idxs = []
            for j in range(len(PRIMITIVES)):
                if switches_normal[i][j]:
                    idxs.append(j)
            if sp == len(num_to_keep) - 1:
                drop = get_min_k_no_zero(normal_prob[i, :], idxs, num_to_drop[sp])
            else:
                drop = get_min_k(normal_prob[i, :], num_to_drop[sp])
            for idx in drop:
                switches_normal[i][idxs[idx]] = False
        reduce_prob = F.softmax(arch_param[1], dim=-1).data.cpu().numpy()
        for i in range(14):
            idxs = []
            for j in range(len(PRIMITIVES)):
                if switches_reduce[i][j]:
                    idxs.append(j)
            if sp == len(num_to_keep) - 1:
                drop = get_min_k_no_zero(reduce_prob[i, :], idxs, num_to_drop[sp])
            else:
                drop = get_min_k(reduce_prob[i, :], num_to_drop[sp])
            for idx in drop:
                switches_reduce[i][idxs[idx]] = False
        logging.info('switches_normal = %s', switches_normal)
        logging_switches(switches_normal)
        logging.info('switches_reduce = %s', switches_reduce)
        logging_switches(switches_reduce)
        
        if sp == len(num_to_keep) - 1:
            arch_param = model.arch_parameters()
            normal_prob = F.softmax(arch_param[0], dim=sm_dim).data.cpu().numpy()
            reduce_prob = F.softmax(arch_param[1], dim=sm_dim).data.cpu().numpy()
            normal_final = [0 for idx in range(14)]
            reduce_final = [0 for idx in range(14)]
            # remove all Zero operations
            for i in range(14):
                if switches_normal_2[i][0]:
                    normal_prob[i][0] = 0
                normal_final[i] = max(normal_prob[i])
                if switches_reduce_2[i][0]:
                    reduce_prob[i][0] = 0
                reduce_final[i] = max(reduce_prob[i])                
            # Generate Architecture
            keep_normal = [0, 1]
            keep_reduce = [0, 1]
            n = 3
            start = 2
            for i in range(3):
                end = start + n
                tbsn = normal_final[start:end]
                tbsr = reduce_final[start:end]
                edge_n = sorted(range(n), key=lambda x: tbsn[x])
                keep_normal.append(edge_n[-1] + start)
                keep_normal.append(edge_n[-2] + start)
                edge_r = sorted(range(n), key=lambda x: tbsr[x])
                keep_reduce.append(edge_r[-1] + start)
                keep_reduce.append(edge_r[-2] + start)
                start = end
                n = n + 1
            for i in range(14):
                if i not in keep_normal:
                    for j in range(len(PRIMITIVES)):
                        switches_normal[i][j] = False
                if i not in keep_reduce:
                    for j in range(len(PRIMITIVES)):
                        switches_reduce[i][j] = False
            # translate switches into genotype
            genotype = parse_network(switches_normal, switches_reduce)
            logging.info(genotype)
            ## restrict skipconnect (normal cell only)
            logging.info('Restricting skipconnect...')
            for sks in range(0, len(PRIMITIVES)+1):
                max_sk = len(PRIMITIVES) - sks
                num_sk = check_sk_number(switches_normal)
                if num_sk < max_sk:
                    continue
                while num_sk > max_sk:
                    normal_prob = delete_min_sk_prob(switches_normal, switches_normal_2, normal_prob)
                    switches_normal = keep_1_on(switches_normal_2, normal_prob)
                    switches_normal = keep_2_branches(switches_normal, normal_prob)
                    num_sk = check_sk_number(switches_normal)
                logging.info('Number of skip-connect: %d', max_sk)
                genotype = parse_network(switches_normal, switches_reduce)
                logging.info(genotype)
    with open(args.save + "/best_genotype.txt", "w") as f:
        f.write(str(genotype))
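Example #2 relies on two pruning helpers, get_min_k and get_min_k_no_zero, that are not shown. A minimal sketch consistent with how they are called above (the exact behaviour of the originals is an assumption; the names follow the P-DARTS code):

def get_min_k(input_in, k):
    # return the positions of the k smallest architecture weights
    input_copy = list(input_in)
    index = []
    for _ in range(k):
        idx = int(np.argmin(input_copy))
        index.append(idx)
        input_copy[idx] = float('inf')   # exclude it from the next round
    return index

def get_min_k_no_zero(input_in, idxs, k):
    # like get_min_k, but if the 'none' op (index 0) is still switched on,
    # drop it first and then remove the remaining k - 1 weakest operations
    input_copy = list(input_in)
    index = []
    if 0 in idxs:
        zero_pos = idxs.index(0)
        index.append(zero_pos)
        input_copy[zero_pos] = float('inf')
        k = k - 1
    for _ in range(k):
        idx = int(np.argmin(input_copy))
        index.append(idx)
        input_copy[idx] = float('inf')
    return index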
Example #3
def main():
    if not torch.cuda.is_available():
        logging.info('No GPU device available')
        sys.exit(1)
    np.random.seed(args.seed)
    cudnn.benchmark = True
    torch.manual_seed(args.seed)
    cudnn.enabled = True
    torch.cuda.manual_seed(args.seed)
    logging.info("args = %s", args)
    #  prepare dataset
    if args.cifar100:
        train_transform, valid_transform = utils._data_transforms_cifar100(
            args)
    else:
        train_transform, valid_transform = utils._data_transforms_cifar10(args)
    if args.cifar100:
        train_data = dset.CIFAR100(root=args.tmp_data_dir,
                                   train=True,
                                   download=True,
                                   transform=train_transform)
    else:
        train_data = dset.CIFAR10(root=args.tmp_data_dir,
                                  train=True,
                                  download=True,
                                  transform=train_transform)

    num_train = len(train_data)
    indices = list(range(num_train))
    split = int(np.floor(args.train_portion * num_train))

    train_queue = torch.utils.data.DataLoader(
        train_data,
        batch_size=args.batch_size,
        sampler=torch.utils.data.sampler.SubsetRandomSampler(indices[:split]),
        pin_memory=True,
        num_workers=args.workers)

    valid_queue = torch.utils.data.DataLoader(
        train_data,
        batch_size=args.batch_size,
        sampler=torch.utils.data.sampler.SubsetRandomSampler(
            indices[split:num_train]),
        pin_memory=True,
        num_workers=args.workers)

    # build Network
    criterion = nn.CrossEntropyLoss()
    criterion = criterion.cuda()
    switches = []
    for i in range(14):
        switches.append([True for j in range(len(PRIMITIVES))])
    switches_normal = copy.deepcopy(switches)
    switches_reduce = copy.deepcopy(switches)

    # eps_no_archs = [10, 10, 10]
    eps_no_archs = [2, 2, 2]
    for sp in range(len(num_to_keep)):
        # if sp < 1:
        #     continue
        model = Network(args.init_channels + int(add_width[sp]),
                        CIFAR_CLASSES,
                        args.layers + int(add_layers[sp]),
                        criterion,
                        switches_normal=switches_normal,
                        switches_reduce=switches_reduce,
                        p=float(drop_rate[sp]))
        model = nn.DataParallel(model)
        model = model.cuda()
        logging.info("param size = %fMB", utils.count_parameters_in_MB(model))
        network_params = []
        for k, v in model.named_parameters():
            if not (k.endswith('alphas_normal')
                    or k.endswith('alphas_reduce')):
                network_params.append(v)
        optimizer = torch.optim.SGD(network_params,
                                    args.learning_rate,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)
        # optimizer_a = torch.optim.Adam(model.module.arch_parameters(),
        #             lr=args.arch_learning_rate, betas=(0.5, 0.999), weight_decay=args.arch_weight_decay)
        optimizer_a = torch.optim.Adam(model.module.arch_parameters(),
                                       lr=args.arch_learning_rate,
                                       betas=(0, 0.999),
                                       weight_decay=args.arch_weight_decay)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, float(args.epochs), eta_min=args.learning_rate_min)
        sm_dim = -1
        epochs = args.epochs
        eps_no_arch = eps_no_archs[sp]
        scale_factor = 0.2
        # cur_sub_model = get_cur_model(model,switches_normal,switches_reduce,num_to_keep,num_to_drop,sp)
        for epoch in range(epochs):
            scheduler.step()
            lr = scheduler.get_lr()[0]
            logging.info('Epoch: %d lr: %e', epoch, lr)
            epoch_start = time.time()
            # training
            if epoch < eps_no_arch:
                # if 0:
                model.module.p = float(
                    drop_rate[sp]) * (epochs - epoch - 1) / epochs
                model.module.update_p()
                train_acc, train_obj = train(train_queue,
                                             valid_queue,
                                             model,
                                             network_params,
                                             criterion,
                                             optimizer,
                                             optimizer_a,
                                             lr,
                                             train_arch=False)
            else:
                model.module.p = float(drop_rate[sp]) * np.exp(
                    -(epoch - eps_no_arch) * scale_factor)
                model.module.update_p()
                train_acc, train_obj = train(train_queue,
                                             valid_queue,
                                             model,
                                             network_params,
                                             criterion,
                                             optimizer,
                                             optimizer_a,
                                             lr,
                                             train_arch=True)
            logging.info('Train_acc %f', train_acc)
            epoch_duration = time.time() - epoch_start
            logging.info('Epoch time: %ds', epoch_duration)
            # validation
            if epochs - epoch < 5:
                valid_acc, valid_obj = infer(valid_queue, model, criterion)
                logging.info('Valid_acc %f', valid_acc)
        utils.save(model, os.path.join(args.save, 'weights.pt'))
        print('------Dropping %d paths------' % num_to_drop[sp])
        # Save a copy of the switches for the skip-connect refinement step below.
        if sp == len(num_to_keep) - 1:
            switches_normal_2 = copy.deepcopy(switches_normal)
            switches_reduce_2 = copy.deepcopy(switches_reduce)
        # drop operations with low architecture weights
        arch_param = model.module.arch_parameters()
        normal_prob = F.softmax(arch_param[0], dim=sm_dim).data.cpu().numpy()
        for i in range(14):
            idxs = []
            for j in range(len(PRIMITIVES)):
                if switches_normal[i][j]:
                    idxs.append(j)
            if sp == len(num_to_keep) - 1:
                # for the last stage, drop all Zero operations
                drop = get_min_k_no_zero(normal_prob[i, :], idxs,
                                         num_to_drop[sp])
            else:
                drop = get_min_k(normal_prob[i, :], num_to_drop[sp])
            for idx in drop:
                switches_normal[i][idxs[idx]] = False
        reduce_prob = F.softmax(arch_param[1], dim=-1).data.cpu().numpy()
        for i in range(14):
            idxs = []
            for j in range(len(PRIMITIVES)):
                if switches_reduce[i][j]:
                    idxs.append(j)
            if sp == len(num_to_keep) - 1:
                drop = get_min_k_no_zero(reduce_prob[i, :], idxs,
                                         num_to_drop[sp])
            else:
                drop = get_min_k(reduce_prob[i, :], num_to_drop[sp])
            for idx in drop:
                switches_reduce[i][idxs[idx]] = False
        logging.info('switches_normal = %s', switches_normal)
        logging_switches(switches_normal)
        logging.info('switches_reduce = %s', switches_reduce)
        logging_switches(switches_reduce)

        if sp == len(num_to_keep) - 1:
            arch_param = model.module.arch_parameters()
            normal_prob = F.softmax(arch_param[0],
                                    dim=sm_dim).data.cpu().numpy()
            reduce_prob = F.softmax(arch_param[1],
                                    dim=sm_dim).data.cpu().numpy()
            normal_final = [0 for idx in range(14)]
            reduce_final = [0 for idx in range(14)]
            # remove all Zero operations
            for i in range(14):
                if switches_normal_2[i][0]:
                    normal_prob[i][0] = 0
                normal_final[i] = max(normal_prob[i])
                if switches_reduce_2[i][0]:
                    reduce_prob[i][0] = 0
                reduce_final[i] = max(reduce_prob[i])
            # Generate Architecture, similar to DARTS
            keep_normal = [0, 1]
            keep_reduce = [0, 1]
            n = 3
            start = 2
            for i in range(3):  # pick the two predecessor edges with the largest weights
                end = start + n
                tbsn = normal_final[start:end]
                tbsr = reduce_final[start:end]
                edge_n = sorted(range(n), key=lambda x: tbsn[x])
                keep_normal.append(edge_n[-1] + start)
                keep_normal.append(edge_n[-2] + start)
                edge_r = sorted(range(n), key=lambda x: tbsr[x])
                keep_reduce.append(edge_r[-1] + start)
                keep_reduce.append(edge_r[-2] + start)
                start = end
                n = n + 1
            # set switches according the ranking of arch parameters
            for i in range(14):
                if i not in keep_normal:
                    for j in range(len(PRIMITIVES)):
                        switches_normal[i][j] = False
                if i not in keep_reduce:
                    for j in range(len(PRIMITIVES)):
                        switches_reduce[i][j] = False
            # translate switches into genotype
            genotype = parse_network(switches_normal, switches_reduce)
            logging.info(genotype)
            ## restrict skipconnect (normal cell only)
            logging.info('Restricting skipconnect...')
            # generating genotypes with different numbers of skip-connect operations
            for sks in range(0, 9):
                max_sk = 8 - sks
                num_sk = check_sk_number(switches_normal)
                if not num_sk > max_sk:
                    continue
                while num_sk > max_sk:
                    normal_prob = delete_min_sk_prob(switches_normal,
                                                     switches_normal_2,
                                                     normal_prob)
                    switches_normal = keep_1_on(switches_normal_2, normal_prob)
                    switches_normal = keep_2_branches(switches_normal,
                                                      normal_prob)
                    num_sk = check_sk_number(switches_normal)
                logging.info('Number of skip-connect: %d', max_sk)
                genotype = parse_network(switches_normal, switches_reduce)
                logging.info(genotype)
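The skip-connect restriction loops in Examples #2 and #3 count active skip connections with check_sk_number, which is also not shown. A minimal sketch, assuming the standard DARTS PRIMITIVES ordering in which 'skip_connect' sits at index 3:

def check_sk_number(switches):
    # count how many edges still have the skip_connect op switched on
    count = 0
    for i in range(len(switches)):
        if switches[i][3]:   # index 3 == 'skip_connect' under the assumed ordering
            count = count + 1
    return count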
Example #4
class neural_architecture_search():
    def __init__(self, args):
        self.args = args

        if not torch.cuda.is_available():
            logging.info('no gpu device available')
            sys.exit(1)

        if self.args.distributed:
            # Init distributed environment
            self.rank, self.world_size, self.device = init_dist(
                port=self.args.port)
            self.seed = self.rank * self.args.seed
        else:
            torch.cuda.set_device(self.args.gpu)
            self.device = torch.device("cuda")
            self.rank = 0
            self.seed = self.args.seed
            self.world_size = 1

        if self.args.fix_seedcudnn:
            random.seed(self.seed)
            torch.backends.cudnn.deterministic = True
            np.random.seed(self.seed)
            cudnn.benchmark = False
            torch.manual_seed(self.seed)
            cudnn.enabled = True
            torch.cuda.manual_seed(self.seed)
            torch.cuda.manual_seed_all(self.seed)
        else:
            np.random.seed(self.seed)
            cudnn.benchmark = True
            torch.manual_seed(self.seed)
            cudnn.enabled = True
            torch.cuda.manual_seed(self.seed)
            torch.cuda.manual_seed_all(self.seed)

        self.path = os.path.join(generate_date, self.args.save)
        if self.rank == 0:
            utils.create_exp_dir(generate_date,
                                 self.path,
                                 scripts_to_save=glob.glob('*.py'))
            logging.basicConfig(stream=sys.stdout,
                                level=logging.INFO,
                                format=log_format,
                                datefmt='%m/%d %I:%M:%S %p')
            fh = logging.FileHandler(os.path.join(self.path, 'log.txt'))
            fh.setFormatter(logging.Formatter(log_format))
            logging.getLogger().addHandler(fh)
            logging.info("self.args = %s", self.args)
            self.logger = tensorboardX.SummaryWriter(
                './runs/' + generate_date + '/nas_{}'.format(self.args.remark))
        else:
            self.logger = None

        # set default resource_lambda for different methods
        if self.args.resource_efficient:
            if self.args.method == 'policy_gradient':
                if self.args.log_penalty:
                    default_resource_lambda = 1e-4
                else:
                    default_resource_lambda = 1e-5
            elif self.args.method == 'reparametrization':
                if self.args.log_penalty:
                    default_resource_lambda = 1e-2
                else:
                    default_resource_lambda = 1e-5
            elif self.args.method == 'discrete':
                if self.args.log_penalty:
                    default_resource_lambda = 1e-2
                else:
                    default_resource_lambda = 1e-4
            if self.args.resource_lambda == default_lambda:  # default_lambda is assumed to be the module-level CLI default
                self.args.resource_lambda = default_resource_lambda

        #initialize loss function
        self.criterion = nn.CrossEntropyLoss().to(self.device)

        #initialize model
        self.init_model()

        #calculate model param size
        if self.rank == 0:
            logging.info("param size = %fMB",
                         utils.count_parameters_in_MB(self.model))
            self.model._logger = self.logger
            self.model._logging = logging

        #initialize optimizer
        self.init_optimizer()

        # initialize dataset loader
        self.init_loaddata()

        self.update_theta = True
        self.update_alpha = True

    def init_model(self):

        self.model = Network(self.args.init_channels, CIFAR_CLASSES,
                             self.args.layers, self.criterion, self.args,
                             self.rank, self.world_size)
        self.model.to(self.device)
        if self.args.distributed:
            broadcast_params(self.model)
        for v in self.model.parameters():
            if v.requires_grad:
                if v.grad is None:
                    v.grad = torch.zeros_like(v)
        self.model.normal_log_alpha.grad = torch.zeros_like(
            self.model.normal_log_alpha)
        self.model.reduce_log_alpha.grad = torch.zeros_like(
            self.model.reduce_log_alpha)

    def init_optimizer(self):

        if self.args.distributed:
            self.optimizer = torch.optim.SGD(
                [
                    param for name, param in self.model.named_parameters() if
                    name != 'normal_log_alpha' and name != 'reduce_log_alpha'
                ],
                self.args.learning_rate,
                momentum=self.args.momentum,
                weight_decay=self.args.weight_decay)
            self.arch_optimizer = torch.optim.Adam(
                [
                    param for name, param in self.model.named_parameters()
                    if name == 'normal_log_alpha' or name == 'reduce_log_alpha'
                ],
                lr=self.args.arch_learning_rate,
                betas=(0.5, 0.999),
                weight_decay=self.args.arch_weight_decay)
        else:
            self.optimizer = torch.optim.SGD(self.model.parameters(),
                                             self.args.learning_rate,
                                             momentum=self.args.momentum,
                                             weight_decay=self.args.weight_decay)

            self.arch_optimizer = torch.optim.SGD(
                self.model.arch_parameters(), lr=self.args.arch_learning_rate)

    def init_loaddata(self):

        train_transform, valid_transform = utils._data_transforms_cifar10(
            self.args)
        train_data = dset.CIFAR10(root=self.args.data,
                                  train=True,
                                  download=True,
                                  transform=train_transform)
        valid_data = dset.CIFAR10(root=self.args.data,
                                  train=False,
                                  download=True,
                                  transform=valid_transform)

        if self.args.seed:

            def worker_init_fn(worker_id):  # DataLoader passes the worker id
                seed = self.seed
                np.random.seed(seed)
                random.seed(seed)
                torch.manual_seed(seed)
                return
        else:
            worker_init_fn = None

        if self.args.distributed:
            train_sampler = DistributedSampler(train_data)
            valid_sampler = DistributedSampler(valid_data)

            self.train_queue = torch.utils.data.DataLoader(
                train_data,
                batch_size=self.args.batch_size // self.world_size,
                shuffle=False,
                num_workers=0,
                pin_memory=False,
                sampler=train_sampler)
            self.valid_queue = torch.utils.data.DataLoader(
                valid_data,
                batch_size=self.args.batch_size // self.world_size,
                shuffle=False,
                num_workers=0,
                pin_memory=False,
                sampler=valid_sampler)

        else:
            self.train_queue = torch.utils.data.DataLoader(
                train_data,
                batch_size=self.args.batch_size,
                shuffle=True,
                pin_memory=False,
                num_workers=2)

            self.valid_queue = torch.utils.data.DataLoader(
                valid_data,
                batch_size=self.args.batch_size,
                shuffle=False,
                pin_memory=False,
                num_workers=2)

    def main(self):
        # lr scheduler: cosine annealing
        # temp scheduler: linear annealing (self-defined in utils)
        self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            self.optimizer,
            float(self.args.epochs),
            eta_min=self.args.learning_rate_min)

        self.temp_scheduler = utils.Temp_Scheduler(self.args.epochs,
                                                   self.model._temp,
                                                   self.args.temp,
                                                   temp_min=self.args.temp_min)

        for epoch in range(self.args.epochs):
            if self.args.random_sample_pretrain:
                if epoch < self.args.random_sample_pretrain_epoch:
                    self.args.random_sample = True
                else:
                    self.args.random_sample = False

            self.scheduler.step()
            if self.args.temp_annealing:
                self.model._temp = self.temp_scheduler.step()
            self.lr = self.scheduler.get_lr()[0]

            if self.rank == 0:
                logging.info('epoch %d lr %e temp %e', epoch, self.lr,
                             self.model._temp)
                self.logger.add_scalar('epoch_temp', self.model._temp, epoch)
                logging.info(self.model.normal_log_alpha)
                logging.info(self.model.reduce_log_alpha)
                logging.info(
                    self.model._get_weights(self.model.normal_log_alpha[0]))
                logging.info(
                    self.model._get_weights(self.model.reduce_log_alpha[0]))

            genotype_edge_all = self.model.genotype_edge_all()

            if self.rank == 0:
                logging.info('genotype_edge_all = %s', genotype_edge_all)
                # create genotypes.txt file
                txt_name = self.args.remark + '_genotype_edge_all_epoch' + str(
                    epoch)
                utils.txt('genotype', self.args.save, txt_name,
                          str(genotype_edge_all), generate_date)

            self.model.train()
            train_acc, loss, error_loss, loss_alpha = self.train(
                epoch, logging)
            if self.rank == 0:
                logging.info('train_acc %f', train_acc)
                self.logger.add_scalar("epoch_train_acc", train_acc, epoch)
                self.logger.add_scalar("epoch_train_error_loss", error_loss,
                                       epoch)
                if self.args.dsnas:
                    self.logger.add_scalar("epoch_train_alpha_loss",
                                           loss_alpha, epoch)

            # validation
            self.model.eval()
            valid_acc, valid_obj = self.infer(epoch)
            if self.args.gen_max_child:
                self.args.gen_max_child_flag = True
                valid_acc_max_child, valid_obj_max_child = self.infer(epoch)
                self.args.gen_max_child_flag = False

            if self.rank == 0:
                logging.info('valid_acc %f', valid_acc)
                self.logger.add_scalar("epoch_valid_acc", valid_acc, epoch)
                if self.args.gen_max_child:
                    logging.info('valid_acc_argmax_alpha %f',
                                 valid_acc_max_child)
                    self.logger.add_scalar("epoch_valid_acc_argmax_alpha",
                                           valid_acc_max_child, epoch)

                utils.save(self.model, os.path.join(self.path, 'weights.pt'))

        if self.rank == 0:
            logging.info(self.model.normal_log_alpha)
            logging.info(self.model.reduce_log_alpha)
            genotype_edge_all = self.model.genotype_edge_all()
            logging.info('genotype_edge_all = %s', genotype_edge_all)

    def train(self, epoch, logging):
        objs = utils.AvgrageMeter()
        top1 = utils.AvgrageMeter()
        top5 = utils.AvgrageMeter()
        grad = utils.AvgrageMeter()

        normal_resource_gradient = 0
        reduce_resource_gradient = 0
        normal_loss_gradient = 0
        reduce_loss_gradient = 0
        normal_total_gradient = 0
        reduce_total_gradient = 0

        loss_alpha = None

        count = 0
        for step, (input, target) in enumerate(self.train_queue):
            if self.args.alternate_update:
                if step % 2 == 0:
                    self.update_theta = True
                    self.update_alpha = False
                else:
                    self.update_theta = False
                    self.update_alpha = True

            n = input.size(0)
            input = input.to(self.device)
            target = target.to(self.device, non_blocking=True)
            if self.args.snas:
                logits, logits_aux, penalty, op_normal, op_reduce = self.model(
                    input)
                error_loss = self.criterion(logits, target)
                if self.args.auxiliary:
                    loss_aux = self.criterion(logits_aux, target)
                    error_loss += self.args.auxiliary_weight * loss_aux

            if self.args.dsnas:
                logits, error_loss, loss_alpha, penalty = self.model(
                    input, target, self.criterion)

            num_normal = self.model.num_normal
            num_reduce = self.model.num_reduce
            normal_arch_entropy = self.model._arch_entropy(
                self.model.normal_log_alpha)
            reduce_arch_entropy = self.model._arch_entropy(
                self.model.reduce_log_alpha)

            if self.args.resource_efficient:
                if self.args.method == 'policy_gradient':
                    resource_penalty = (penalty[2]) / 6 + self.args.ratio * (
                        penalty[7]) / 2
                    log_resource_penalty = (
                        penalty[35]) / 6 + self.args.ratio * (penalty[36]) / 2
                elif self.args.method == 'reparametrization':
                    resource_penalty = (penalty[26]) / 6 + self.args.ratio * (
                        penalty[25]) / 2
                    log_resource_penalty = (
                        penalty[37]) / 6 + self.args.ratio * (penalty[38]) / 2
                elif self.args.method == 'discrete':
                    resource_penalty = (penalty[28]) / 6 + self.args.ratio * (
                        penalty[27]) / 2
                    log_resource_penalty = (
                        penalty[39]) / 6 + self.args.ratio * (penalty[40]) / 2
                elif self.args.method == 'none':
                    # TODO
                    resource_penalty = torch.zeros(1).cuda()
                    log_resource_penalty = torch.zeros(1).cuda()
                else:
                    logging.info(
                        "invalid --method; choose one of 'policy_gradient', "
                        "'discrete', 'reparametrization', 'none'")
                    sys.exit(1)
            else:
                resource_penalty = torch.zeros(1).cuda()
                log_resource_penalty = torch.zeros(1).cuda()

            if self.args.log_penalty:
                resource_loss = self.model._resource_lambda * log_resource_penalty
            else:
                resource_loss = self.model._resource_lambda * resource_penalty

            if self.args.loss:
                if self.args.snas:
                    loss = resource_loss.clone() + error_loss.clone()
                elif self.args.dsnas:
                    loss = resource_loss.clone()
                else:
                    loss = resource_loss.clone() + -child_coef * (
                        torch.log(normal_one_hot_prob) +
                        torch.log(reduce_one_hot_prob)).sum()
            else:
                if self.args.snas or self.args.dsnas:
                    loss = error_loss.clone()

            if self.args.distributed:
                loss.div_(self.world_size)
                error_loss.div_(self.world_size)
                resource_loss.div_(self.world_size)
                if self.args.dsnas:
                    loss_alpha.div_(self.world_size)

            # logging gradient
            count += 1
            if self.args.resource_efficient:
                self.optimizer.zero_grad()
                self.arch_optimizer.zero_grad()
                resource_loss.backward(retain_graph=True)
                if not self.args.random_sample:
                    normal_resource_gradient += self.model.normal_log_alpha.grad
                    reduce_resource_gradient += self.model.reduce_log_alpha.grad
            if self.args.snas:
                self.optimizer.zero_grad()
                self.arch_optimizer.zero_grad()
                error_loss.backward(retain_graph=True)
                if not self.args.random_sample:
                    normal_loss_gradient += self.model.normal_log_alpha.grad
                    reduce_loss_gradient += self.model.reduce_log_alpha.grad
                self.optimizer.zero_grad()
                self.arch_optimizer.zero_grad()

            if self.args.snas or (not self.args.random_sample and not self.args.dsnas):
                loss.backward()
            if not self.args.random_sample:
                normal_total_gradient += self.model.normal_log_alpha.grad
                reduce_total_gradient += self.model.reduce_log_alpha.grad

            if self.args.distributed:
                reduce_tensorgradients(self.model.parameters(), sync=True)
                nn.utils.clip_grad_norm_([
                    param for name, param in self.model.named_parameters() if
                    name != 'normal_log_alpha' and name != 'reduce_log_alpha'
                ], self.args.grad_clip)
                arch_grad_norm = nn.utils.clip_grad_norm_([
                    param for name, param in self.model.named_parameters()
                    if name == 'normal_log_alpha' or name == 'reduce_log_alpha'
                ], 10.)
            else:
                nn.utils.clip_grad_norm_(self.model.parameters(),
                                         self.args.grad_clip)
                arch_grad_norm = nn.utils.clip_grad_norm_(
                    self.model.arch_parameters(), 10.)

            grad.update(arch_grad_norm)
            if not self.args.fix_weight and self.update_theta:
                self.optimizer.step()
            self.optimizer.zero_grad()
            if not self.args.random_sample and self.update_alpha:
                self.arch_optimizer.step()
            self.arch_optimizer.zero_grad()

            if self.rank == 0:
                self.logger.add_scalar(
                    "iter_train_loss", error_loss,
                    step + len(self.train_queue.dataset) * epoch)
                self.logger.add_scalar(
                    "normal_arch_entropy", normal_arch_entropy,
                    step + len(self.train_queue.dataset) * epoch)
                self.logger.add_scalar(
                    "reduce_arch_entropy", reduce_arch_entropy,
                    step + len(self.train_queue.dataset) * epoch)
                self.logger.add_scalar(
                    "total_arch_entropy",
                    normal_arch_entropy + reduce_arch_entropy,
                    step + len(self.train_queue.dataset) * epoch)
                if self.args.dsnas:
                    # reward_normal_edge: per-edge reward for the 14 normal-cell edges
                    for edge in range(14):
                        self.logger.add_scalar(
                            "reward_normal_edge_{}".format(edge),
                            self.model.normal_edge_reward[edge],
                            step + len(self.train_queue.dataset) * epoch)
                    # reward_reduce_edge: per-edge reward for the 14 reduce-cell edges
                    for edge in range(14):
                        self.logger.add_scalar(
                            "reward_reduce_edge_{}".format(edge),
                            self.model.reduce_edge_reward[edge],
                            step + len(self.train_queue.dataset) * epoch)
                #policy size
                self.logger.add_scalar(
                    "iter_normal_size_policy", penalty[2] / num_normal,
                    step + len(self.train_queue.dataset) * epoch)
                self.logger.add_scalar(
                    "iter_reduce_size_policy", penalty[7] / num_reduce,
                    step + len(self.train_queue.dataset) * epoch)
                # baseline: discrete_probability
                self.logger.add_scalar(
                    "iter_normal_size_baseline", penalty[3] / num_normal,
                    step + len(self.train_queue.dataset) * epoch)
                self.logger.add_scalar(
                    "iter_normal_flops_baseline", penalty[5] / num_normal,
                    step + len(self.train_queue.dataset) * epoch)
                self.logger.add_scalar(
                    "iter_normal_mac_baseline", penalty[6] / num_normal,
                    step + len(self.train_queue.dataset) * epoch)
                self.logger.add_scalar(
                    "iter_reduce_size_baseline", penalty[8] / num_reduce,
                    step + len(self.train_queue.dataset) * epoch)
                self.logger.add_scalar(
                    "iter_reduce_flops_baseline", penalty[9] / num_reduce,
                    step + len(self.train_queue.dataset) * epoch)
                self.logger.add_scalar(
                    "iter_reduce_mac_baseline", penalty[10] / num_reduce,
                    step + len(self.train_queue.dataset) * epoch)
                # R - median(R)
                self.logger.add_scalar(
                    "iter_normal_size-avg", penalty[60] / num_normal,
                    step + len(self.train_queue.dataset) * epoch)
                self.logger.add_scalar(
                    "iter_normal_flops-avg", penalty[61] / num_normal,
                    step + len(self.train_queue.dataset) * epoch)
                self.logger.add_scalar(
                    "iter_normal_mac-avg", penalty[62] / num_normal,
                    step + len(self.train_queue.dataset) * epoch)
                self.logger.add_scalar(
                    "iter_reduce_size-avg", penalty[63] / num_reduce,
                    step + len(self.train_queue.dataset) * epoch)
                self.logger.add_scalar(
                    "iter_reduce_flops-avg", penalty[64] / num_reduce,
                    step + len(self.train_queue.dataset) * epoch)
                self.logger.add_scalar(
                    "iter_reduce_mac-avg", penalty[65] / num_reduce,
                    step + len(self.train_queue.dataset) * epoch)
                # lnR - ln(median)
                self.logger.add_scalar(
                    "iter_normal_ln_size-ln_avg", penalty[66] / num_normal,
                    step + len(self.train_queue.dataset) * epoch)
                self.logger.add_scalar(
                    "iter_normal_ln_flops-ln_avg", penalty[67] / num_normal,
                    step + len(self.train_queue.dataset) * epoch)
                self.logger.add_scalar(
                    "iter_normal_ln_mac-ln_avg", penalty[68] / num_normal,
                    step + len(self.train_queue.dataset) * epoch)
                self.logger.add_scalar(
                    "iter_reduce_ln_size-ln_avg", penalty[69] / num_reduce,
                    step + len(self.train_queue.dataset) * epoch)
                self.logger.add_scalar(
                    "iter_reduce_ln_flops-ln_avg", penalty[70] / num_reduce,
                    step + len(self.train_queue.dataset) * epoch)
                self.logger.add_scalar(
                    "iter_reduce_ln_mac-ln_avg", penalty[71] / num_reduce,
                    step + len(self.train_queue.dataset) * epoch)
                '''
                self.logger.add_scalar("iter_normal_size_normalized", penalty[17] / 6, step + len(self.train_queue.dataset) * epoch)
                self.logger.add_scalar("iter_normal_flops_normalized", penalty[18] / 6, step + len(self.train_queue.dataset) * epoch)
                self.logger.add_scalar("iter_normal_mac_normalized", penalty[19] / 6, step + len(self.train_queue.dataset) * epoch)
                self.logger.add_scalar("iter_reduce_size_normalized", penalty[20] / 2, step + len(self.train_queue.dataset) * epoch)
                self.logger.add_scalar("iter_reduce_flops_normalized", penalty[21] / 2, step + len(self.train_queue.dataset) * epoch)
                self.logger.add_scalar("iter_reduce_mac_normalized", penalty[22] / 2, step + len(self.train_queue.dataset) * epoch)
                self.logger.add_scalar("iter_normal_penalty_normalized", penalty[23] / 6,
                                  step + len(self.train_queue.dataset) * epoch)
                self.logger.add_scalar("iter_reduce_penalty_normalized", penalty[24] / 2,
                                  step + len(self.train_queue.dataset) * epoch)
                '''
                # Monte_Carlo(R_i)
                self.logger.add_scalar(
                    "iter_normal_size_mc", penalty[29] / num_normal,
                    step + len(self.train_queue.dataset) * epoch)
                self.logger.add_scalar(
                    "iter_normal_flops_mc", penalty[30] / num_normal,
                    step + len(self.train_queue.dataset) * epoch)
                self.logger.add_scalar(
                    "iter_normal_mac_mc", penalty[31] / num_normal,
                    step + len(self.train_queue.dataset) * epoch)
                self.logger.add_scalar(
                    "iter_reduce_size_mc", penalty[32] / num_reduce,
                    step + len(self.train_queue.dataset) * epoch)
                self.logger.add_scalar(
                    "iter_reduce_flops_mc", penalty[33] / num_reduce,
                    step + len(self.train_queue.dataset) * epoch)
                self.logger.add_scalar(
                    "iter_reduce_mac_mc", penalty[34] / num_reduce,
                    step + len(self.train_queue.dataset) * epoch)
                # log(|R_i|)
                self.logger.add_scalar(
                    "iter_normal_log_size", penalty[41] / num_normal,
                    step + len(self.train_queue.dataset) * epoch)
                self.logger.add_scalar(
                    "iter_normal_log_flops", penalty[42] / num_normal,
                    step + len(self.train_queue.dataset) * epoch)
                self.logger.add_scalar(
                    "iter_normal_log_mac", penalty[43] / num_normal,
                    step + len(self.train_queue.dataset) * epoch)
                self.logger.add_scalar(
                    "iter_reduce_log_size", penalty[44] / num_reduce,
                    step + len(self.train_queue.dataset) * epoch)
                self.logger.add_scalar(
                    "iter_reduce_log_flops", penalty[45] / num_reduce,
                    step + len(self.train_queue.dataset) * epoch)
                self.logger.add_scalar(
                    "iter_reduce_log_mac", penalty[46] / num_reduce,
                    step + len(self.train_queue.dataset) * epoch)
                # log(P)R_i
                self.logger.add_scalar(
                    "iter_normal_logP_size", penalty[47] / num_normal,
                    step + len(self.train_queue.dataset) * epoch)
                self.logger.add_scalar(
                    "iter_normal_logP_flops", penalty[48] / num_normal,
                    step + len(self.train_queue.dataset) * epoch)
                self.logger.add_scalar(
                    "iter_normal_logP_mac", penalty[49] / num_normal,
                    step + len(self.train_queue.dataset) * epoch)
                self.logger.add_scalar(
                    "iter_reduce_logP_size", penalty[50] / num_reduce,
                    step + len(self.train_queue.dataset) * epoch)
                self.logger.add_scalar(
                    "iter_reduce_logP_flops", penalty[51] / num_reduce,
                    step + len(self.train_queue.dataset) * epoch)
                self.logger.add_scalar(
                    "iter_reduce_logP_mac", penalty[52] / num_reduce,
                    step + len(self.train_queue.dataset) * epoch)
                # log(P)log(R_i)
                self.logger.add_scalar(
                    "iter_normal_logP_log_size", penalty[53] / num_normal,
                    step + len(self.train_queue.dataset) * epoch)
                self.logger.add_scalar(
                    "iter_normal_logP_log_flops", penalty[54] / num_normal,
                    step + len(self.train_queue.dataset) * epoch)
                self.logger.add_scalar(
                    "iter_normal_logP_log_mac", penalty[55] / num_normal,
                    step + len(self.train_queue.dataset) * epoch)
                self.logger.add_scalar(
                    "iter_reduce_logP_log_size", penalty[56] / num_reduce,
                    step + len(self.train_queue.dataset) * epoch)
                self.logger.add_scalar(
                    "iter_reduce_logP_log_flops", penalty[57] / num_reduce,
                    step + len(self.train_queue.dataset) * epoch)
                self.logger.add_scalar(
                    "iter_reduce_logP_log_mac", penalty[58] / num_reduce,
                    step + len(self.train_queue.dataset) * epoch)

            prec1, prec5 = utils.accuracy(logits, target, topk=(1, 5))

            if self.args.distributed:
                loss = loss.detach()
                dist.all_reduce(error_loss)
                dist.all_reduce(prec1)
                dist.all_reduce(prec5)
                prec1.div_(self.world_size)
                prec5.div_(self.world_size)
                #dist_util.all_reduce([loss, prec1, prec5], 'mean')
            objs.update(error_loss.item(), n)
            top1.update(prec1.item(), n)
            top5.update(prec5.item(), n)

            if step % self.args.report_freq == 0 and self.rank == 0:
                logging.info('train %03d %e %f %f', step, objs.avg, top1.avg,
                             top5.avg)
                self.logger.add_scalar(
                    "iter_train_top1_acc", top1.avg,
                    step + len(self.train_queue.dataset) * epoch)

        if self.rank == 0:
            logging.info('-------resource gradient--------')
            logging.info(normal_resource_gradient / count)
            logging.info(reduce_resource_gradient / count)
            logging.info('-------loss gradient--------')
            logging.info(normal_loss_gradient / count)
            logging.info(reduce_loss_gradient / count)
            logging.info('-------total gradient--------')
            logging.info(normal_total_gradient / count)
            logging.info(reduce_total_gradient / count)

        return top1.avg, loss, error_loss, loss_alpha

    def infer(self, epoch):
        objs = utils.AvgrageMeter()
        top1 = utils.AvgrageMeter()
        top5 = utils.AvgrageMeter()

        self.model.eval()
        with torch.no_grad():
            for step, (input, target) in enumerate(self.valid_queue):
                input = input.to(self.device)
                target = target.to(self.device)
                if self.args.snas:
                    logits, logits_aux, resource_loss, op_normal, op_reduce = self.model(
                        input)
                    loss = self.criterion(logits, target)
                elif self.args.dsnas:
                    logits, error_loss, loss_alpha, resource_loss = self.model(
                        input, target, self.criterion)
                    loss = error_loss

                prec1, prec5 = utils.accuracy(logits, target, topk=(1, 5))

                if self.args.distributed:
                    loss.div_(self.world_size)
                    loss = loss.detach()
                    dist.all_reduce(loss)
                    dist.all_reduce(prec1)
                    dist.all_reduce(prec5)
                    prec1.div_(self.world_size)
                    prec5.div_(self.world_size)
                objs.update(loss.item(), input.size(0))
                top1.update(prec1.item(), input.size(0))
                top5.update(prec5.item(), input.size(0))

                if step % self.args.report_freq == 0 and self.rank == 0:
                    logging.info('valid %03d %e %f %f', step, objs.avg,
                                 top1.avg, top5.avg)
                    self.logger.add_scalar(
                        "iter_valid_loss", loss,
                        step + len(self.valid_queue.dataset) * epoch)
                    self.logger.add_scalar(
                        "iter_valid_top1_acc", top1.avg,
                        step + len(self.valid_queue.dataset) * epoch)

        return top1.avg, objs.avg
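

# --- Illustrative sketch (added; not part of the original example) ---
# The loops above rely on utils.accuracy(logits, target, topk=(1, 5)). In
# DARTS-style repositories this helper usually looks roughly like the sketch
# below; the project's actual utils.accuracy may differ, so treat this as an
# assumption rather than the repository's code.
def _accuracy_sketch(output, target, topk=(1,)):
    """Return top-k accuracy percentages for a batch of logits."""
    maxk = max(topk)
    batch_size = target.size(0)
    # indices of the maxk highest-scoring classes per sample, shape (maxk, batch)
    _, pred = output.topk(maxk, dim=1, largest=True, sorted=True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))
    res = []
    for k in topk:
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res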
Beispiel #5
0
if args.cuda:
    # Move model to GPU.
    model.cuda()

# Horovod: scale learning rate by the number of GPUs.
optimizer = optim.SGD(model.parameters(),
                      lr=args.base_lr * hvd.size(),
                      momentum=args.momentum,
                      weight_decay=args.wd)  #, nesterov=True)

# Horovod: (optional) compression algorithm.
compression = hvd.Compression.fp16 if args.fp16_allreduce else hvd.Compression.none

# Horovod: wrap optimizer with DistributedOptimizer.
optimizer = hvd.DistributedOptimizer(optimizer,
                                     named_parameters=model.named_parameters(),
                                     compression=compression)

arch_optimizer = torch.optim.Adam(model.arch_parameters(),
                                  lr=args.arch_learning_rate,
                                  betas=(0.5, 0.999),
                                  weight_decay=args.arch_weight_decay)

# Restore from a previous checkpoint, if initial_epoch is specified.
# Horovod: restore on the first worker which will broadcast weights to other workers.
if resume_from_epoch > 0 and hvd.rank() == 0:
    filepath = args.checkpoint_format.format(exp=args.save,
                                             epoch=resume_from_epoch)
    checkpoint = torch.load(filepath)
    model.load_state_dict(checkpoint['model'])
    optimizer.load_state_dict(checkpoint['optimizer'])
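
# Hedged addition (not in the original example): after restoring on rank 0,
# Horovod training scripts typically broadcast the restored state so that every
# worker starts from identical weights and optimizer state. These are standard
# Horovod APIs; whether this snippet performs the broadcast elsewhere is unknown.
hvd.broadcast_parameters(model.state_dict(), root_rank=0)
hvd.broadcast_optimizer_state(optimizer, root_rank=0)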
Beispiel #6
0
def model_search(args):
    if not os.path.isdir(args.save):
        os.makedirs(args.save)
    save_dir = '{}search-{}-{}'.format(args.save, args.note,
                                       time.strftime("%Y%m%d-%H%M%S"))
    utils.create_exp_dir(save_dir, scripts_to_save=glob.glob('*.py'))

    log_format = '%(asctime)s %(message)s'
    logging.basicConfig(stream=sys.stdout,
                        level=logging.INFO,
                        format=log_format,
                        datefmt='%m/%d %I:%M:%S %p')
    fh = logging.FileHandler(os.path.join(save_dir, 'log.txt'))
    fh.setFormatter(logging.Formatter(log_format))
    logging.getLogger().addHandler(fh)

    if args.cifar100:
        CIFAR_CLASSES = 100
        data_folder = 'cifar-100-python'
    else:
        CIFAR_CLASSES = 10
        data_folder = 'cifar-10-batches-py'

    if not torch.cuda.is_available():
        logging.info('No GPU device available')
        sys.exit(1)
    np.random.seed(args.seed)
    cudnn.benchmark = True
    torch.manual_seed(args.seed)
    cudnn.enabled = True
    torch.cuda.manual_seed(args.seed)
    logging.info("args = %s", args)

    #  prepare dataset
    if args.cifar100:
        train_transform, valid_transform = utils._data_transforms_cifar100(
            args)
    else:
        train_transform, valid_transform = utils._data_transforms_cifar10(args)
    if args.cifar100:
        train_data = dset.CIFAR100(root=args.train_data_dir,
                                   train=True,
                                   download=True,
                                   transform=train_transform)
    else:
        train_data = dset.CIFAR10(root=args.train_data_dir,
                                  train=True,
                                  download=True,
                                  transform=train_transform)

    ###### if dataset is too small #######
    num_train = len(train_data)
    iter_per_one_epoch = num_train // (2 * args.batch_size)
    if iter_per_one_epoch >= 100:
        train_extend_rate = 1
    else:
        train_extend_rate = (100 // iter_per_one_epoch) + 1
    iter_per_one_epoch = iter_per_one_epoch * train_extend_rate
    logging.info('num original train data: %d', num_train)
    logging.info('iter per one epoch: %d', iter_per_one_epoch)
    ######################################
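    # Worked example of the extension logic above (the numbers are illustrative
    # assumptions, not values from the original run): with num_train = 1000 and
    # batch_size = 64, iter_per_one_epoch = 1000 // 128 = 7; since 7 < 100,
    # train_extend_rate = (100 // 7) + 1 = 15, so the extended "epoch" has
    # 7 * 15 = 105 iterations and the training subset is concatenated 15 times below.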

    indices = list(range(num_train))
    random.shuffle(indices)
    split = int(np.floor(args.train_portion * num_train))
    train_set = torch.utils.data.Subset(train_data, indices[:split])
    valid_set = torch.utils.data.Subset(train_data, indices[split:num_train])

    train_set = torch.utils.data.ConcatDataset([train_set] * train_extend_rate)
    # valid_set = torch.utils.data.ConcatDataset([valid_set]*train_extend_rate)

    train_queue = torch.utils.data.DataLoader(
        train_set,
        batch_size=args.batch_size,
        sampler=torch.utils.data.sampler.RandomSampler(train_set),
        pin_memory=True,
        num_workers=args.workers)

    valid_queue = torch.utils.data.DataLoader(
        valid_set,
        batch_size=args.batch_size,
        sampler=torch.utils.data.sampler.RandomSampler(valid_set),
        pin_memory=True,
        num_workers=args.workers)

    # build Network
    criterion = nn.CrossEntropyLoss()
    criterion = criterion.cuda()

    eps_no_arch = args.eps_no_archs
    epochs = args.epochs

    model = Network(args.init_channels,
                    CIFAR_CLASSES,
                    args.layers,
                    criterion,
                    steps=args.inter_nodes,
                    multiplier=args.inter_nodes,
                    stem_multiplier=args.stem_multiplier,
                    residual_connection=args.residual_connection)
    model = nn.DataParallel(model)
    model = model.cuda()
    logging.info("param size = %fMB", utils.count_parameters_in_MB(model))
    network_params = []
    for k, v in model.named_parameters():
        if not (k.endswith('alphas_normal') or k.endswith('alphas_reduce')):
            network_params.append(v)

    optimizer = torch.optim.SGD(network_params,
                                args.learning_rate,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
    optimizer_a = torch.optim.Adam(model.module.arch_parameters(),
                                   lr=args.arch_learning_rate,
                                   betas=(0.5, 0.999),
                                   weight_decay=args.arch_weight_decay)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, float(epochs), eta_min=args.learning_rate_min)
    scheduler_a = torch.optim.lr_scheduler.StepLR(optimizer_a, 30, gamma=0.2)
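    # Note (added for clarity): the weight optimizer follows cosine annealing down
    # to learning_rate_min over all epochs, while the architecture optimizer's LR
    # is decayed stepwise, multiplied by 0.2 every 30 epochs.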

    train_epoch_record = -1
    arch_train_count = 0
    prev_geno = ''
    prev_rank = None
    rank_geno = None
    result_geno = None
    arch_stable = 0
    best_arch_stable = 0

    for epoch in range(epochs):

        lr = scheduler.get_lr()[0]
        logging.info('Epoch: %d lr: %e', epoch, lr)
        epoch_start = time.time()
        # training
        if epoch < eps_no_arch:
            train_acc, train_obj = train(train_queue,
                                         valid_queue,
                                         model,
                                         network_params,
                                         criterion,
                                         optimizer,
                                         optimizer_a,
                                         lr,
                                         train_arch=False)
        else:
            ops, probs = parse_ops_without_none(model)

            parsed_ops = []
            for i in range(len(probs)):
                parsed_op = parse_network(model.module._steps, probs[i])
                parsed_ops.append(parsed_op)

            concat = range(2, 2 + model.module._steps)
            genotype = Genotype(
                normal=parsed_ops[0],
                normal_concat=concat,
                reduce=parsed_ops[1],
                reduce_concat=concat,
            )

            if str(prev_geno) != str(genotype):
                prev_geno = genotype
                logging.info(genotype)

            # early stopping
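            # Added summary: the per-edge operation ranking is recomputed every
            # epoch; once it has stayed unchanged for args.stable_arch - 1
            # consecutive architecture-training epochs, the corresponding
            # genotype is taken as the result and the search loop exits below.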

            stable_cond = True
            rank = []
            for i in range(len(probs)):
                rank_tmp = ranking(probs[i])
                rank.append(rank_tmp)

            if prev_rank != rank:
                stable_cond = False
                arch_stable = 0
                prev_rank = rank
                rank_geno = genotype
                logging.info('rank: %s', rank)

            if stable_cond:
                arch_stable += 1

            if arch_stable > best_arch_stable:
                best_arch_stable = arch_stable
                result_geno = rank_geno
                logging.info('arch_stable: %d', arch_stable)
                logging.info('best genotype: %s', rank_geno)

            if arch_stable >= args.stable_arch - 1:
                logging.info('stable genotype: %s', rank_geno)
                result_geno = rank_geno
                break

            train_acc, train_obj = train(train_queue,
                                         valid_queue,
                                         model,
                                         network_params,
                                         criterion,
                                         optimizer,
                                         optimizer_a,
                                         lr,
                                         train_arch=True)

            arch_train_count += 1

            scheduler_a.step()

        scheduler.step()
        logging.info('Train_acc %f, Objs: %e', train_acc, train_obj)
        epoch_duration = time.time() - epoch_start
        logging.info('Epoch time: %ds', epoch_duration)

        # validation
        if epoch >= eps_no_arch:
            valid_acc, valid_obj = infer(valid_queue, model, criterion)
            logging.info('Valid_acc %f, Objs: %e', valid_acc, valid_obj)

        # # early arch training
        # if train_epoch_record == -1:
        #     if train_acc > 70:
        #         arch_train_num = args.epochs - args.eps_no_archs
        #         eps_no_arch = 0
        #         train_epoch_record = epoch
        # else:
        #     if epoch >= train_epoch_record + arch_train_num:
        #         break

        utils.save(model, os.path.join(save_dir, 'weights.pt'))

    # last geno parser
    ops, probs = parse_ops_without_none(model)

    parsed_ops = []
    for i in range(len(probs)):
        parsed_op = parse_network(model.module._steps, probs[i])
        parsed_ops.append(parsed_op)

    concat = range(2, 2 + model.module._steps)
    genotype = Genotype(
        normal=parsed_ops[0],
        normal_concat=concat,
        reduce=parsed_ops[1],
        reduce_concat=concat,
    )
    logging.info('Last geno: %s', genotype)

    if result_geno is None:
        result_geno = genotype

    return result_geno, best_arch_stable
Beispiel #7
0
def main():
    if not torch.cuda.is_available():
        logging.info('No GPU device available')
        sys.exit(1)
    np.random.seed(args.seed)
    cudnn.benchmark = True
    torch.manual_seed(args.seed)
    cudnn.enabled = True
    torch.cuda.manual_seed(args.seed)
    # torch.autograd.set_detect_anomaly(True)
    logging.info("args = %s", args)
    #  prepare dataset
    if args.cifar100:
        train_transform, test_transform = utils._data_transforms_cifar100(args)
    else:
        train_transform, test_transform = utils._data_transforms_cifar10(args)
    if args.cifar100:
        train_data = dset.CIFAR100(root=args.tmp_data_dir,
                                   train=True,
                                   download=True,
                                   transform=train_transform)
    else:
        train_data = dset.CIFAR10(root=args.tmp_data_dir,
                                  train=True,
                                  download=True,
                                  transform=train_transform)
        test_data = dset.CIFAR10(root=args.tmp_data_dir,
                                 train=False,
                                 download=True,
                                 transform=test_transform)

    num_train = len(train_data)
    indices = list(range(num_train))
    split = int(np.floor(args.train_portion * num_train))

    train_queue = torch.utils.data.DataLoader(
        train_data,
        batch_size=args.batch_size,
        sampler=torch.utils.data.sampler.SubsetRandomSampler(indices[:split]),
        pin_memory=True,
        num_workers=args.workers)

    valid_queue = torch.utils.data.DataLoader(
        train_data,
        batch_size=args.batch_size,
        sampler=torch.utils.data.sampler.SubsetRandomSampler(
            indices[split:num_train]),
        pin_memory=True,
        num_workers=args.workers)

    test_queue = torch.utils.data.DataLoader(test_data,
                                             batch_size=args.batch_size,
                                             pin_memory=True,
                                             num_workers=args.workers)

    # build Network
    criterion = nn.CrossEntropyLoss()
    criterion = criterion.cuda()

    path_num = sum(1 for i in range(args.nodes) for n in range(2 + i))
    switches = []
    for i in range(path_num):
        switches.append([True for j in range(len(PRIMITIVES))])
        if args.drop_none:
            switches[i][0] = False  # switch off zero operator
        if args.drop_skip:
            switches[i][3] = False  # switch off identity operator
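    # Worked example of the path count above (assuming the usual DARTS cell with
    # args.nodes = 4): intermediate node i has 2 + i incoming edges, so
    # path_num = 2 + 3 + 4 + 5 = 14, which matches the hard-coded 14 used in the
    # later examples.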

    switches_normal = copy.deepcopy(switches)
    switches_reduce = copy.deepcopy(switches)
    # To be moved to args
    num_to_keep = [5, 3, 1]
    num_to_drop = [2, 2, 2]
    if len(args.add_width) == 3:
        add_width = args.add_width
    else:
        add_width = [0, 0, 0]
    if len(args.add_layers) == 3:
        add_layers = args.add_layers
    else:
        add_layers = [0, 6, 12]
    if len(args.dropout_rate) == 3:
        drop_rate = args.dropout_rate
    else:
        drop_rate = [0.0, 0.0, 0.0]
    eps_no_archs = [10, 10, 10]
    for sp in range(len(num_to_keep)):

        # if sp == len(num_to_keep)-1: # switch on zero operator in the last stage
        #     for i in range(path_num):
        #         switches_normal[i][0]=True
        #     for i in range(path_num):
        #         switches_reduce[i][0]=True

        model = Network(args.init_channels + int(add_width[sp]),
                        CIFAR_CLASSES,
                        args.layers + int(add_layers[sp]),
                        criterion,
                        steps=args.nodes,
                        multiplier=args.nodes,
                        switches_normal=switches_normal,
                        switches_reduce=switches_reduce,
                        p=float(drop_rate[sp]))
        model = nn.DataParallel(model)
        # print(model)

        # if sp==0:
        #     utils.save(model, os.path.join(args.save, 'cell_weights.pt')) # keep initial weights
        # else:
        #     utils.load(model.module.cells, os.path.join(args.save, 'cell_weights.pt')) # strict=False

        # print('copying weight....')
        # state_dict = torch.load(os.path.join(args.save, 'cell_weights.pt'))
        # for key in state_dict.keys():
        #     print(key)
        # for key in state_dict.keys():
        #     if 'm_ops' in key and 'op0' not in key:
        #         s = re.split('op\d', key)
        #         copy_key = s[0]+'op0'+s[1]
        #         state_dict[key] = state_dict[copy_key]
        #         print(key)
        # model.load_state_dict(state_dict)
        # print('done!')

        model = model.cuda()
        logging.info("param size = %fMB", utils.count_parameters_in_MB(model))
        network_params = []
        arch_params = []
        for k, v in model.named_parameters():
            if 'alpha' in k:
                print(k)
                arch_params.append(v)
            else:
                network_params.append(v)
            # if not (k.endswith('alphas_normal_source') or k.endswith('alphas_reduce')):
            #     network_params.append(v)

        optimizer = torch.optim.SGD(network_params,
                                    args.learning_rate,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)

        optimizer_a = torch.optim.Adam(arch_params,
                                       lr=args.arch_learning_rate,
                                       betas=(0.5, 0.999),
                                       weight_decay=args.arch_weight_decay)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, float(args.epochs), eta_min=args.learning_rate_min)
        sm_dim = -1
        epochs = args.epochs
        eps_no_arch = eps_no_archs[sp]
        scale_factor = 0.2
        for epoch in range(epochs):  #epochs
            scheduler.step()
            lr = scheduler.get_lr()[0]  #args.learning_rate#
            logging.info('Epoch: %d lr: %e', epoch, lr)
            epoch_start = time.time()
            # training
            if epoch < eps_no_arch:
                model.module.p = float(
                    drop_rate[sp]) * (epochs - epoch - 1) / epochs
                model.module.update_p()
                train_acc, train_obj = train(train_queue,
                                             valid_queue,
                                             model,
                                             network_params,
                                             criterion,
                                             optimizer,
                                             optimizer_a,
                                             lr,
                                             train_arch=False,
                                             train_weight=True)
            elif epoch < epochs:
                model.module.p = float(drop_rate[sp]) * np.exp(
                    -(epoch - eps_no_arch) * scale_factor)
                model.module.update_p()
                train_acc, train_obj = train(train_queue,
                                             valid_queue,
                                             model,
                                             network_params,
                                             criterion,
                                             optimizer,
                                             optimizer_a,
                                             lr,
                                             train_arch=True,
                                             train_weight=True)
            else:  # train arch only
                train_acc, train_obj = train(train_queue,
                                             valid_queue,
                                             model,
                                             network_params,
                                             criterion,
                                             optimizer,
                                             optimizer_a,
                                             lr,
                                             train_arch=True,
                                             train_weight=False)

            logging.info('Train_acc %f', train_acc)
            epoch_duration = time.time() - epoch_start
            logging.info('Epoch time: %ds', epoch_duration)
            # validation
            # if epochs - epoch < 5:
            valid_acc, valid_obj = infer(valid_queue, model, criterion)
            logging.info('Valid_acc %f', valid_acc)
            test_acc, test_obj = infer(test_queue, model, criterion)
            logging.info('Test_acc %f', test_acc)

        utils.save(model, os.path.join(args.save, 'weights.pt'))
        print('------Dropping %d paths------' % num_to_drop[sp])
        # Save switches info for s-c refinement.
        if sp == len(num_to_keep) - 1:
            switches_normal_2 = copy.deepcopy(switches_normal)
            switches_reduce_2 = copy.deepcopy(switches_reduce)
        # drop operations with low architecture weights
        arch_param = model.module.arch_parameters()

        # n = 3
        # start = 2
        # weightsn2 = F.softmax(arch_param[2][0:2], dim=-1)
        # weightsr2 = F.softmax(arch_param[3][0:2], dim=-1)
        weightsn2 = torch.sigmoid(arch_param[2])
        weightsr2 = torch.sigmoid(arch_param[3])
        # for i in range(args.nodes-1):
        #     end = start + n
        #     tn2 = F.softmax(arch_param[2][start:end], dim=-1)
        #     tr2 = F.softmax(arch_param[3][start:end], dim=-1)
        #     start = end
        #     n += 1
        #     weightsn2 = torch.cat([weightsn2, tn2],dim=0)
        #     weightsr2 = torch.cat([weightsr2, tr2],dim=0)
        weightsn2 = weightsn2.data.cpu().numpy()
        weightsr2 = weightsr2.data.cpu().numpy()

        normal_prob = F.softmax(arch_param[0], dim=sm_dim).data.cpu().numpy()
        for i in range(path_num):
            normal_prob[i] = normal_prob[i] * weightsn2[i]
            idxs = []
            for j in range(len(PRIMITIVES)):
                if switches_normal[i][j]:
                    idxs.append(j)
            if sp == len(num_to_keep) - 1:
                # for the last stage, drop all Zero operations
                drop = get_min_k_no_zero(normal_prob[i, :], idxs,
                                         num_to_drop[sp])
            else:
                drop = get_min_k(normal_prob[i, :], num_to_drop[sp])
            for idx in drop:
                switches_normal[i][idxs[idx]] = False
        reduce_prob = F.softmax(arch_param[1], dim=-1).data.cpu().numpy()
        for i in range(path_num):
            reduce_prob[i] = reduce_prob[i] * weightsr2[i]
            idxs = []
            for j in range(len(PRIMITIVES)):
                if switches_reduce[i][j]:
                    idxs.append(j)
            if sp == len(num_to_keep) - 1:
                drop = get_min_k_no_zero(reduce_prob[i, :], idxs,
                                         num_to_drop[sp])
            else:
                drop = get_min_k(reduce_prob[i, :], num_to_drop[sp])
            for idx in drop:
                switches_reduce[i][idxs[idx]] = False
        logging.info('switches_normal = %s', switches_normal)
        logging_switches(switches_normal)
        logging.info('switches_reduce = %s', switches_reduce)
        logging_switches(switches_reduce)

        if sp == len(num_to_keep) - 1:

            # n = 3
            # start = 2
            # weightsn2 = F.softmax(arch_param[2][0:2], dim=-1)
            # weightsr2 = F.softmax(arch_param[3][0:2], dim=-1)
            weightsn2 = torch.sigmoid(arch_param[2])
            weightsr2 = torch.sigmoid(arch_param[3])
            # for i in range(args.nodes-1):
            #     end = start + n
            #     tn2 = F.softmax(arch_param[2][start:end], dim=-1)
            #     tr2 = F.softmax(arch_param[3][start:end], dim=-1)
            #     start = end
            #     n += 1
            #     weightsn2 = torch.cat([weightsn2, tn2],dim=0)
            #     weightsr2 = torch.cat([weightsr2, tr2],dim=0)
            weightsn2 = weightsn2.data.cpu().numpy()
            weightsr2 = weightsr2.data.cpu().numpy()

            arch_param = model.module.arch_parameters()
            normal_prob = F.softmax(arch_param[0],
                                    dim=sm_dim).data.cpu().numpy()
            reduce_prob = F.softmax(arch_param[1],
                                    dim=sm_dim).data.cpu().numpy()
            normal_final = [0 for idx in range(path_num)]
            reduce_final = [0 for idx in range(path_num)]
            # remove all Zero operations
            for i in range(path_num):
                normal_prob[i] = normal_prob[i] * weightsn2[i]
                if switches_normal_2[i][0]:
                    normal_prob[i][0] = 0
                normal_final[i] = max(normal_prob[i])
                reduce_prob[i] = reduce_prob[i] * weightsr2[i]
                if switches_reduce_2[i][0]:
                    reduce_prob[i][0] = 0
                reduce_final[i] = max(reduce_prob[i])
            # Generate Architecture, similar to DARTS
            keep_normal = [0, 1]
            keep_reduce = [0, 1]
            n = 3
            start = 2
            for i in range(args.nodes - 1):
                end = start + n
                tbsn = normal_final[start:end]
                tbsr = reduce_final[start:end]
                edge_n = sorted(range(n), key=lambda x: tbsn[x])
                keep_normal.append(edge_n[-1] + start)
                keep_normal.append(edge_n[-2] + start)
                edge_r = sorted(range(n), key=lambda x: tbsr[x])
                keep_reduce.append(edge_r[-1] + start)
                keep_reduce.append(edge_r[-2] + start)
                start = end
                n = n + 1
            # set switches according the ranking of arch parameters
            for i in range(path_num):
                if i not in keep_normal:
                    for j in range(len(PRIMITIVES)):
                        switches_normal[i][j] = False
                if i not in keep_reduce:
                    for j in range(len(PRIMITIVES)):
                        switches_reduce[i][j] = False
            # translate switches into genotype
            genotype = parse_network(switches_normal,
                                     switches_reduce,
                                     steps=args.nodes)
            logging.info(genotype)
            ## restrict skipconnect (normal cell only)
            logging.info('Restricting skipconnect...')
            # generating genotypes with different numbers of skip-connect operations
            for sks in range(0, 9):
                max_sk = 8 - sks
                num_sk = check_sk_number(switches_normal)
                if not num_sk > max_sk:
                    continue
                while num_sk > max_sk:
                    normal_prob = delete_min_sk_prob(switches_normal,
                                                     switches_normal_2,
                                                     normal_prob)
                    switches_normal = keep_1_on(switches_normal_2, normal_prob)
                    switches_normal = keep_2_branches(switches_normal,
                                                      normal_prob)
                    num_sk = check_sk_number(switches_normal)
                logging.info('Number of skip-connect: %d', max_sk)
                genotype = parse_network(switches_normal,
                                         switches_reduce,
                                         steps=args.nodes)
                logging.info(genotype)
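

# --- Illustrative sketch (added; not part of the original example) ---
# The pruning loops above call get_min_k(probs, k) to decide which ops to
# disable. A plausible implementation (an assumption about the helper, not the
# repository's actual code) returns the indices of the k smallest values:
def _get_min_k_sketch(values, k):
    """Return the indices of the k smallest entries of a 1-D sequence."""
    order = sorted(range(len(values)), key=lambda idx: values[idx])
    return order[:k]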
Beispiel #8
0
def main():
    if not torch.cuda.is_available():
        logging.info('No GPU device available')
        sys.exit(1)
    np.random.seed(args.seed)
    cudnn.benchmark = True
    torch.manual_seed(args.seed)
    cudnn.enabled = True
    torch.cuda.manual_seed(args.seed)
    logging.info("args = %s", args)
    #  prepare dataset
    if args.cifar100:
        train_transform, valid_transform = utils._data_transforms_cifar100(
            args)
    else:
        train_transform, valid_transform = utils._data_transforms_cifar10(args)
    if args.cifar100:
        train_data = dset.CIFAR100(root=args.tmp_data_dir,
                                   train=True,
                                   download=True,
                                   transform=train_transform)
    else:
        train_data = dset.CIFAR10(root=args.tmp_data_dir,
                                  train=True,
                                  download=True,
                                  transform=train_transform)

    num_train = len(train_data)
    indices = list(range(num_train))
    split = int(np.floor(args.train_portion * num_train))

    train_queue = torch.utils.data.DataLoader(
        train_data,
        batch_size=args.batch_size,
        sampler=torch.utils.data.sampler.SubsetRandomSampler(indices[:split]),
        pin_memory=True,
        num_workers=args.workers)

    valid_queue = torch.utils.data.DataLoader(
        train_data,
        batch_size=args.batch_size,
        sampler=torch.utils.data.sampler.SubsetRandomSampler(
            indices[split:num_train]),
        pin_memory=True,
        num_workers=args.workers)

    # build Network
    criterion = nn.CrossEntropyLoss()
    criterion = criterion.cuda()
    switches = []  # switch flags marking which ops on each path are still enabled
    for i in range(14):  # a cell has 4 intermediate nodes and 2 input nodes, giving 14 paths in total
        switches.append([True for j in range(len(PRIMITIVES))])  # each path offers len(PRIMITIVES) ops; all enabled at the start
    switches_normal = copy.deepcopy(switches)  # op on/off states for the normal cell
    switches_reduce = copy.deepcopy(switches)  # op on/off states for the reduce cell
    # To be moved to args
    num_to_keep = [5, 3, 1]
    num_to_drop = [3, 2, 2]
    if len(args.add_width) == 3:  # defaults to 0
        add_width = args.add_width
    else:
        add_width = [0, 0, 0]
    if len(args.add_layers) == 3:  # two add_layers values are passed in
        add_layers = args.add_layers
    else:
        add_layers = [0, 6, 12]
    if len(args.dropout_rate) == 3:  # three dropout values are passed in
        drop_rate = args.dropout_rate
    else:
        drop_rate = [0.0, 0.0, 0.0]
    eps_no_archs = [10, 10, 10]
    for sp in range(len(num_to_keep)):  # training proceeds in 3 stages
        # args.init_channels defaults to 16, i.e. the stem outputs 16 channels; args.add_width can widen it per stage
        # args.layers defaults to 5 (3 normal cells + 2 reduce cells); stage 2 uses 11 layers and stage 3 uses 17, controlled by add_layers
        model = Network(args.init_channels + int(add_width[sp]),
                        CIFAR_CLASSES,
                        args.layers + int(add_layers[sp]),
                        criterion,
                        switches_normal=switches_normal,
                        switches_reduce=switches_reduce,
                        p=float(drop_rate[sp]))
        model = nn.DataParallel(model)
        model = model.cuda()
        logging.info("param size = %fMB", utils.count_parameters_in_MB(model))
        network_params = []  # collects the network weight tensors
        for k, v in model.named_parameters():
            if not (k.endswith('alphas_normal')
                    or k.endswith('alphas_reduce')):
                network_params.append(v)
        optimizer = torch.optim.SGD(
            network_params,  # updates the network weights
            args.learning_rate,
            momentum=args.momentum,
            weight_decay=args.weight_decay)
        optimizer_a = torch.optim.Adam(
            model.module.arch_parameters(),  # updates the architecture parameters
            lr=args.arch_learning_rate,
            betas=(0.5, 0.999),
            weight_decay=args.arch_weight_decay)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, float(args.epochs), eta_min=args.learning_rate_min)
        sm_dim = -1
        epochs = args.epochs
        eps_no_arch = eps_no_archs[sp]
        scale_factor = 0.2
        for epoch in range(epochs):  # each stage trains for 25 epochs by default
            scheduler.step()
            lr = scheduler.get_lr()[0]
            logging.info('Epoch: %d lr: %e', epoch, lr)
            epoch_start = time.time()
            # training
            if epoch < eps_no_arch:  # first eps_no_arch (10) epochs: train network weights only
                model.module.p = float(
                    drop_rate[sp]) * (epochs - epoch - 1) / epochs
                model.module.update_p()
                train_acc, train_obj = train(train_queue,
                                             valid_queue,
                                             model,
                                             network_params,
                                             criterion,
                                             optimizer,
                                             optimizer_a,
                                             lr,
                                             train_arch=False)
            else:  # afterwards also train the architecture parameters
                model.module.p = float(drop_rate[sp]) * np.exp(
                    -(epoch - eps_no_arch) * scale_factor)
                model.module.update_p()
                train_acc, train_obj = train(train_queue,
                                             valid_queue,
                                             model,
                                             network_params,
                                             criterion,
                                             optimizer,
                                             optimizer_a,
                                             lr,
                                             train_arch=True)
            logging.info('Train_acc %f', train_acc)
            epoch_duration = time.time() - epoch_start
            logging.info('Epoch time: %ds', epoch_duration)
            # validation
            if epochs - epoch < 5:  # with epochs=25, validate only in the last 5 epochs
                valid_acc, valid_obj = infer(valid_queue, model, criterion)
                logging.info('Valid_acc %f', valid_acc)
        utils.save(model, os.path.join(args.save, 'weights.pt'))  # save the model at the end of each stage
        # after each stage, drop some ops
        print('------Dropping %d paths------' %
              num_to_drop[sp])  # num_to_drop=[3,2,2]: of the 8 candidate ops, drop 3, then 2, then 2, leaving a single op
        # Save switches info for s-c refinement.
        if sp == len(num_to_keep) - 1:  # num_to_keep=[5,3,1]; only the last stage (sp=2) runs this block
            switches_normal_2 = copy.deepcopy(switches_normal)  # each path still has 3 ops; 2 more will be removed, leaving one
            switches_reduce_2 = copy.deepcopy(switches_reduce)
        # drop operations with low architecture weights
        arch_param = model.module.arch_parameters()  # fetch the architecture parameters
        # process alphas_normal
        normal_prob = F.softmax(
            arch_param[0],
            dim=sm_dim).data.cpu().numpy()  # softmax over the normal-cell alphas
        for i in range(14):  # a cell has 14 paths in total
            idxs = []  # indices of the ops still enabled on this path
            for j in range(len(PRIMITIVES)):  # iterate over the ops on this path
                if switches_normal[i][j]:  # True means the op is still enabled
                    idxs.append(j)  # three entries remain by the last stage
            if sp == len(num_to_keep) - 1:  # last training stage
                # for the last stage, drop all Zero operations
                drop = get_min_k_no_zero(
                    normal_prob[i, :], idxs,
                    num_to_drop[sp])  # after this drop only one op remains per path
            else:
                drop = get_min_k(normal_prob[i, :], num_to_drop[sp])
            for idx in drop:
                switches_normal[i][idxs[idx]] = False  # switch off the lowest-probability ops; note switches_normal is updated here
        # process alphas_reduce
        reduce_prob = F.softmax(arch_param[1], dim=-1).data.cpu().numpy()
        for i in range(14):
            idxs = []
            for j in range(len(PRIMITIVES)):
                if switches_reduce[i][j]:
                    idxs.append(j)
            if sp == len(num_to_keep) - 1:
                drop = get_min_k_no_zero(reduce_prob[i, :], idxs,
                                         num_to_drop[sp])
            else:
                drop = get_min_k(reduce_prob[i, :], num_to_drop[sp])
            for idx in drop:
                switches_reduce[i][idxs[idx]] = False  # note switches_reduce is updated here
        logging.info('switches_normal = %s', switches_normal)
        logging_switches(switches_normal)
        logging.info('switches_reduce = %s', switches_reduce)
        logging_switches(switches_reduce)

        if sp == len(num_to_keep) - 1:  # last stage
            arch_param = model.module.arch_parameters()
            normal_prob = F.softmax(arch_param[0],
                                    dim=sm_dim).data.cpu().numpy()  # per-op probabilities
            reduce_prob = F.softmax(arch_param[1],
                                    dim=sm_dim).data.cpu().numpy()
            normal_final = [0 for idx in range(14)]  # highest op probability kept for each path
            reduce_final = [0 for idx in range(14)]
            # remove all Zero operations
            for i in range(14):
                if switches_normal_2[i][0]:  # if Zero survived the last stage (each path still has 3 ops), zero out its probability
                    normal_prob[i][0] = 0
                normal_final[i] = max(normal_prob[i])  # highest remaining op probability on path i

                if switches_reduce_2[i][0]:
                    reduce_prob[i][0] = 0
                reduce_final[i] = max(reduce_prob[i])
            # Generate Architecture, similar to DARTS
            # pick two input paths for each intermediate node; node 0's inputs are fixed, so no selection is needed there
            # node 1 has 3 candidate paths, node 2 has 4, node 3 has 5
            keep_normal = [0, 1]
            keep_reduce = [0, 1]
            n = 3
            start = 2
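            # Added note: in the flattened 14-path layout, node 1's candidate
            # edges are indices 2-4, node 2's are 5-8 and node 3's are 9-13;
            # the loop below keeps the two highest-scoring edges per node.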
            for i in range(3):
                end = start + n
                tbsn = normal_final[start:end]
                tbsr = reduce_final[start:end]
                edge_n = sorted(range(n),
                                key=lambda x: tbsn[x])  # pick two of the candidate paths
                keep_normal.append(edge_n[-1] + start)
                keep_normal.append(edge_n[-2] + start)

                edge_r = sorted(range(n), key=lambda x: tbsr[x])
                keep_reduce.append(edge_r[-1] + start)
                keep_reduce.append(edge_r[-2] + start)
                start = end
                n = n + 1
            # set switches according to the ranking of arch parameters
            # for paths that were not kept, switch off all of their ops
            for i in range(14):
                if i not in keep_normal:
                    for j in range(len(PRIMITIVES)):
                        switches_normal[i][j] = False
                if i not in keep_reduce:
                    for j in range(len(PRIMITIVES)):
                        switches_reduce[i][j] = False
            # translate switches into genotype
            genotype = parse_network(switches_normal, switches_reduce)
            logging.info(genotype)
            ## restrict skipconnect (normal cell only)
            logging.info('Restricting skipconnect...')
            # generating genotypes with different numbers of skip-connect operations
            for sks in range(0, 9):
                max_sk = 8 - sks
                num_sk = check_sk_number(switches_normal)
                if not num_sk > max_sk:
                    continue
                while num_sk > max_sk:  # remove the surplus skip-connections
                    normal_prob = delete_min_sk_prob(switches_normal,
                                                     switches_normal_2,
                                                     normal_prob)
                    switches_normal = keep_1_on(switches_normal_2, normal_prob)
                    switches_normal = keep_2_branches(switches_normal,
                                                      normal_prob)
                    num_sk = check_sk_number(switches_normal)
                logging.info('Number of skip-connect: %d', max_sk)
                genotype = parse_network(switches_normal, switches_reduce)
                logging.info(genotype)
Beispiel #9
0
def main():
    if not torch.cuda.is_available():
        logging.info('No GPU device available')
        sys.exit(1)
    np.random.seed(args.seed)
    cudnn.benchmark = True  # let cuDNN pick the fastest convolution algorithms
    torch.manual_seed(args.seed)
    cudnn.enabled = True
    torch.cuda.manual_seed(args.seed)
    logging.info("args = %s", args)
    #  prepare dataset
    if args.cifar100:
        train_transform, valid_transform = utils._data_transforms_cifar100(
            args)
    else:
        train_transform, valid_transform = utils._data_transforms_cifar10(args)
    if args.cifar100:
        train_data = dset.CIFAR100(root=args.tmp_data_dir,
                                   train=True,
                                   download=False,
                                   transform=train_transform)
    else:
        train_data = dset.CIFAR10(root=args.tmp_data_dir,
                                  train=True,
                                  download=False,
                                  transform=train_transform)

    num_train = len(train_data)
    indices = list(range(num_train))
    split = int(np.floor(args.train_portion * num_train))

    train_queue = torch.utils.data.DataLoader(
        train_data,
        batch_size=args.batch_size,
        sampler=torch.utils.data.sampler.SubsetRandomSampler(indices[:split]),
        pin_memory=True,
        num_workers=args.workers)

    valid_queue = torch.utils.data.DataLoader(
        train_data,
        batch_size=args.batch_size,
        sampler=torch.utils.data.sampler.SubsetRandomSampler(
            indices[split:num_train]),
        pin_memory=True,
        num_workers=args.workers)  # pin_memory uses page-locked host memory for faster host-to-GPU transfers

    #************************************************************************************************#
    # Stage 2: configure the network logic (losses and search model)
    # criterion = nn.CrossEntropyLoss()       # a specific training-phase loss could be chosen here
    # criterion = criterion.cuda()

    # separation loss for training: ConvSeparateLoss when sep_loss == 'l2', otherwise TriSeparateLoss
    criterion_train = ConvSeparateLoss(
        weight=args.aux_loss_weight
    ) if args.sep_loss == 'l2' else TriSeparateLoss(
        weight=args.aux_loss_weight)
    criterion_train = criterion_train.cuda()
    criterion_val = nn.CrossEntropyLoss().cuda()
    switches = []  # flags marking which ops have been pruned
    for i in range(14):
        switches.append([True for j in range(len(PRIMITIVES))])
    switches_normal = copy.deepcopy(switches)
    switches_reduce = copy.deepcopy(switches)
    # To be moved to args
    num_to_keep = [5, 3, 1]
    num_to_drop = [2, 2, 2]  # number of ops to prune per stage
    if len(args.add_width) == args.stages:
        add_width = args.add_width
    else:
        add_width = [[0, 16], [0, 8, 16]][args.stages - 2]  # default add_width per stage
    if len(args.add_layers) == args.stages:
        add_layers = args.add_layers
    else:
        add_layers = [[0, 7], [0, 6, 12]][args.stages - 2]  # default add_layers per stage
    if len(args.dropout_rate) == args.stages:
        drop_rate = args.dropout_rate
    else:
        drop_rate = [0.0] * args.stages  # default dropout schedule

    eps_no_archs = [args.noarc] * args.stages  # for the first n epochs only the network weights are updated

    if len(args.sample) == args.stages:
        sample = args.sample
    else:
        sample = [[4, 8], [4, 4, 4]][args.stages - 2]
    epochs = [25, 25, 25]

    #***************************************************************************************#
    # Stage 3: the training loop

    for sp in range(len(num_to_keep)):
        model = Network(args.init_channels + int(add_width[sp]),
                        CIFAR_CLASSES,
                        args.layers + int(add_layers[sp]),
                        criterion_val,
                        switches_normal=switches_normal,
                        switches_reduce=switches_reduce,
                        p=float(drop_rate[sp]),
                        K=int(sample[sp]),
                        use_baidu=args.use_baidu,
                        use_EN=args.use_EN)
        model = nn.DataParallel(model)  # multi-GPU data parallelism
        model = model.cuda()
        logging.info("param size = %fMB", utils.count_parameters_in_MB(model))
        logging.info("layers=%d", args.layers + int(add_layers[sp]))
        logging.info("channels=%d", args.init_channels + int(add_width[sp]))
        logging.info("K=%d", int(sample[sp]))
        network_params = []
        for k, v in model.named_parameters():
            if not (k.endswith('alphas_normal') or k.endswith('alphas_reduce')
                    or k.endswith('betas_reduce')
                    or k.endswith('betas_normal')):  # skip the alpha/beta architecture parameters
                network_params.append(v)
        optimizer = torch.optim.SGD(network_params,
                                    args.learning_rate,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)
        optimizer_a = torch.optim.Adam(model.module.arch_parameters(),
                                       lr=args.arch_learning_rate,
                                       betas=(0.5, 0.999),
                                       weight_decay=args.arch_weight_decay)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, float(args.epochs), eta_min=args.learning_rate_min)

        sm_dim = -1
        # epochs = args.epochs
        eps_no_arch = eps_no_archs[sp]
        scale_factor = 0.2
        for epoch in range(epochs[sp]):
            scheduler.step()
            lr = scheduler.get_lr()[0]
            logging.info('Epoch: %d lr: %e', epoch, lr)
            epoch_start = time.time()
            # training
            if epoch < eps_no_arch:
                model.module.p = float(drop_rate[sp]) * (epochs[sp] - epoch -
                                                         1) / epochs[sp]  #!!!
                model.module.update_p()
                train_acc, train_obj = train(train_queue,
                                             valid_queue,
                                             model,
                                             network_params,
                                             criterion_train,
                                             optimizer,
                                             optimizer_a,
                                             lr,
                                             train_arch=False)
            else:
                model.module.p = float(drop_rate[sp]) * np.exp(
                    -(epoch - eps_no_arch) * scale_factor)
                model.module.update_p()
                train_acc, train_obj = train(train_queue,
                                             valid_queue,
                                             model,
                                             network_params,
                                             criterion_train,
                                             optimizer,
                                             optimizer_a,
                                             lr,
                                             train_arch=True)

            logging.info('Train_acc %f', train_acc)
            epoch_duration = time.time() - epoch_start
            logging.info('Epoch time: %ds', epoch_duration)

            # print("beats",model.module.arch_parameters()[1])

            # validation
            if epochs[sp] - epoch < 5:
                valid_acc, valid_obj = infer(valid_queue, model, criterion_val)
                logging.info('Valid_acc %f', valid_acc)
            # print("epoch=",epoch,'weights_normal=',model.module.weights_normal,'weights_reduce=',model.module.weights_reduce)
            # print('weights2_normal=',model.module.weights2_normal,'\n','weights2_reduce=',model.module.weights2_reduce)
            #/************************************************************/
            arch_normal = model.module.arch_parameters()[0]
            arch_reduce = model.module.arch_parameters()[1]
            betas_nor = model.module.weights2_normal
            betas_redu = model.module.weights2_reduce
            shengcheng(arch_normal, arch_reduce, switches_normal,
                       switches_reduce, betas_nor, betas_redu)
            #/***********************************************************/

        utils.save(model, os.path.join(args.save, 'weights.pt'))
        print('------Dropping %d paths------' % num_to_drop[sp])

        #************************************************************************************8

        # Save switches info for s-c refinement.
        if sp == len(num_to_keep) - 1:
            switches_normal_2 = copy.deepcopy(switches_normal)
            switches_reduce_2 = copy.deepcopy(switches_reduce)

        # drop operations with low architecture weights
        arch_param = model.module.arch_parameters()

        normal_prob = torch.sigmoid(arch_param[0]).data.cpu().numpy()  # turn arch weights into per-op probabilities
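        # Added note: unlike the softmax used in the earlier examples, an
        # independent sigmoid scores each op here, so the probabilities on an
        # edge do not need to sum to one.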
        for i in range(14):
            idxs = []
            for j in range(len(PRIMITIVES)):
                if switches_normal[i][j]:
                    idxs.append(j)

            # for the last stage, drop all Zero operations
            # drop1 = get_min_k_no_zero(normal_prob[i, :], idxs, num_to_drop[sp])  # TODO: check this helper against the theory
            drop2 = get_min_k(normal_prob[i, :], num_to_drop[sp])
            # if sp == len(num_to_keep) - 1:
            #     for idx in drop1:
            #         switches_normal[i][idxs[idx]] = False
            # else:
            for idx in drop2:
                switches_normal[i][idxs[idx]] = False  # progressively switch off weak ops, acting as a form of regularization
        logging.info('switches_normal = %s', switches_normal)
        logging_switches(switches_normal)

        if not args.use_baidu:
            reduce_prob = torch.sigmoid(arch_param[1]).data.cpu().numpy()
            #reduce_prob = F.softmax(arch_param[1], dim=-1).data.cpu().numpy()
            for i in range(14):
                idxs = []
                for j in range(len(PRIMITIVES)):
                    if switches_reduce[i][j]:
                        idxs.append(j)
                if sp == len(num_to_keep) - 1:
                    drop = get_min_k_no_zero(reduce_prob[i, :], idxs,
                                             num_to_drop[sp])
                else:
                    drop = get_min_k(reduce_prob[i, :], num_to_drop[sp])
                for idx in drop:
                    switches_reduce[i][idxs[idx]] = False
            logging.info('switches_reduce = %s', switches_reduce)
            logging_switches(switches_reduce)
Beispiel #10
0
def main():
    if not torch.cuda.is_available():
        logging.info('No GPU device available')
        sys.exit(1)
    np.random.seed(args.seed)
    cudnn.benchmark = True
    torch.manual_seed(args.seed)
    cudnn.enabled=True
    torch.cuda.manual_seed(args.seed)
    logging.info("args = %s", args)
    #  prepare dataset
    if args.cifar100:
        train_transform, valid_transform = utils._data_transforms_cifar100(args)
    else:
        # train_transform, valid_transform = utils._data_transforms_cifar10(args)
        train_transform = transforms.Compose([
            transforms.Resize(32),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])

        valid_transform = transforms.Compose([
            transforms.Resize(32),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
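
        # Added note: Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) maps inputs from
        # [0, 1] to [-1, 1] via (x - 0.5) / 0.5, the range typically expected when
        # a GAN generator uses a tanh output.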

    if args.cifar100:
        train_data = dset.CIFAR100(root=args.tmp_data_dir, train=True, download=True, transform=train_transform)
    else:
        train_data = dset.CIFAR10(root=args.tmp_data_dir, train=True, download=True, transform=train_transform)

    label_dim = 10 
    image_size = 32
    # label preprocess
    onehot = torch.zeros(label_dim, label_dim)
    onehot = onehot.scatter_(1, torch.LongTensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]).view(label_dim, 1), 1).view(label_dim, label_dim, 1, 1)
    fill = torch.zeros([label_dim, label_dim, image_size, image_size])
    for i in range(label_dim):
        fill[i, i, :, :] = 1
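    # Added note: `onehot` is the 10x10 identity reshaped to (10, 10, 1, 1), so
    # onehot[y] is a (10, 1, 1) one-hot label vector for the generator input,
    # while fill[y] is a (10, 32, 32) map whose y-th channel is all ones, the
    # spatial label conditioning commonly fed to a cGAN discriminator.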
        
    num_train = len(train_data)
    indices = list(range(num_train))
    split = int(np.floor(args.train_portion * num_train))

    train_queue = torch.utils.data.DataLoader(
        train_data, batch_size=args.batch_size,
        sampler=torch.utils.data.sampler.SubsetRandomSampler(indices[:split]),
        pin_memory=True, num_workers=args.workers)

    valid_queue = torch.utils.data.DataLoader(
        train_data, batch_size=args.batch_size,
        sampler=torch.utils.data.sampler.SubsetRandomSampler(indices[split:num_train]),
        pin_memory=True, num_workers=args.workers)
    
    adversarial_loss = nn.MSELoss()
    adversarial_loss.cuda()

    # build Network
    criterion = nn.CrossEntropyLoss()
    criterion = criterion.cuda()
    switches = []

    for i in range(14):
        switches.append([True for j in range(len(PRIMITIVES))])
    switches_normal = copy.deepcopy(switches)
    switches_reduce = copy.deepcopy(switches)
    # To be moved to args
    num_to_keep = [5, 3, 1]
    num_to_drop = [3, 2, 2]
    if len(args.add_width) == 3:
        add_width = args.add_width
    else:
        add_width = [0, 0, 0]
    if len(args.add_layers) == 3:
        add_layers = args.add_layers
    else:
        add_layers = [0, 6, 12]
    if len(args.dropout_rate) == 3:
        drop_rate = args.dropout_rate
    else:
        drop_rate = [0.0, 0.0, 0.0]
    eps_no_archs = [10, 10, 10]
        
    
    # Progressive search: one stage per entry in num_to_keep, pruning the set of
    # candidate operations between stages.
    for sp in range(len(num_to_keep)):

        # Generator (100 appears to be the latent dimension) and a ResNet-18
        # trained alongside it as an auxiliary classifier.
        gen = Generator(100)
        gen.cuda()

        model = Resnet18()
        model.cuda()

        logging.info("param size gen= %fMB", utils.count_parameters_in_MB(gen))
        logging.info("param size model= %fMB", utils.count_parameters_in_MB(model))

        optimizer_gen = torch.optim.Adam(gen.parameters(), lr=args.lr, 
                            betas=(args.b1, args.b2))

        optimizer_model = torch.optim.SGD(model.parameters(), lr=0.1,
                            momentum=0.9, weight_decay=5e-4)
        scheduler_model = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer_model, T_max=200)

        # Searched supernet used as the GAN discriminator; width and depth grow
        # with the stage index sp.
        disc = Network(args.init_channels + int(add_width[sp]), CIFAR_CLASSES, args.layers + int(add_layers[sp]), criterion, switches_normal=switches_normal, switches_reduce=switches_reduce, p=float(drop_rate[sp]))
        disc = nn.DataParallel(disc)
        disc = disc.cuda()
        logging.info("param size disc= %fMB", utils.count_parameters_in_MB(disc))
        # Collect every weight of the discriminator except the architecture
        # parameters (alphas), which get their own optimizer below.
        network_params = []

        for k, v in disc.named_parameters():
            if not (k.endswith('alphas_normal') or k.endswith('alphas_reduce')):
                network_params.append(v)
            
        # optimizer_disc = torch.optim.SGD(
        #         network_params,
        #         args.learning_rate,
        #         momentum=args.momentum,
        #         weight_decay=args.weight_decay)
        optimizer_disc = torch.optim.Adam(network_params, lr=args.lr, 
                            betas=(args.b1, args.b2))
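        # Separate optimizer for the architecture parameters (alphas), kept apart
        # from the weight optimizer above (DARTS-style bilevel setup).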
        optimizer_a = torch.optim.Adam(disc.module.arch_parameters(),
                    lr=args.arch_learning_rate, betas=(0.5, 0.999), weight_decay=args.arch_weight_decay)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
                optimizer_disc, float(args.epochs), eta_min=args.learning_rate_min)

        sm_dim = -1                      # softmax over the operation dimension
        epochs = args.epochs
        eps_no_arch = eps_no_archs[sp]   # warm-up epochs without architecture updates
        scale_factor = 0.2               # decay rate of the operation-dropout schedule

        architect = Architect(gen, disc, model, network_params, criterion, adversarial_loss, CIFAR_CLASSES, args)


        # Fixed-length GAN training phase for this stage (the epoch count is
        # hardcoded here rather than taken from args.epochs).
        for epoch in range(100):

            logging.info('Epoch: %d', epoch)
            epoch_start = time.time()
            train_acc, train_obj = train_gan(epoch, train_queue, valid_queue, gen, disc, network_params, criterion, adversarial_loss, optimizer_gen, optimizer_disc, 0, 0, 0, 0, train_arch=True)
            epoch_duration = time.time() - epoch_start
            logging.info('Epoch time: %ds', epoch_duration)

        # Architecture-search epochs for this stage.
        for epoch in range(epochs):

            scheduler.step()
            scheduler_model.step()
            lr_gen = args.lr
            lr_disc = args.learning_rate
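            # NOTE: recent PyTorch versions expect scheduler.step() to be called
            # after the optimizer steps, and get_last_lr() in place of get_lr().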
            lr = scheduler.get_lr()[0]
            lr_model = scheduler_model.get_lr()[0]
            logging.info('Epoch: %d lr: %e lr_model: %e', epoch, lr, lr_model)
            epoch_start = time.time()
            # training
            if epoch < eps_no_arch:
                # warm-up: linearly anneal the operation-dropout probability and
                # train only the network weights (architecture params frozen)
                disc.module.p = float(drop_rate[sp]) * (epochs - epoch - 1) / epochs
                disc.module.update_p()
                train_acc, train_obj = train(train_queue, valid_queue, architect, gen, model, disc, network_params, criterion, adversarial_loss, optimizer_gen, optimizer_disc, optimizer_model, optimizer_a, lr, lr_model, lr_gen, lr_disc, train_arch=False)
            else:
                # after warm-up: exponentially decay the dropout probability and
                # also update the architecture parameters
                disc.module.p = float(drop_rate[sp]) * np.exp(-(epoch - eps_no_arch) * scale_factor)
                disc.module.update_p()
                train_acc, train_obj = train(train_queue, valid_queue, architect, gen, model, disc, network_params, criterion, adversarial_loss, optimizer_gen, optimizer_disc, optimizer_model, optimizer_a, lr, lr_model, lr_gen, lr_disc, train_arch=True)
            logging.info('Train_acc %f', train_acc)
            epoch_duration = time.time() - epoch_start
            logging.info('Epoch time: %ds', epoch_duration)
            # validation only during the last few epochs of the stage
            if epochs - epoch < 5:
                valid_acc, valid_obj = infer(valid_queue, model, criterion)
                logging.info('Valid_acc %f', valid_acc)
        utils.save(disc, os.path.join(args.save, 'disc.pt'))
        utils.save(gen, os.path.join(args.save, 'gen.pt'))
        utils.save(model, os.path.join(args.save, 'model.pt'))
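
        # Prune the candidate operations for this stage: softmax the architecture
        # parameters and switch off the lowest-probability ops on each edge.
        # get_min_k presumably returns the indices of the k smallest entries of a
        # probability row, e.g. roughly np.argsort(probs)[:k] (sketch only; the
        # actual helper is defined elsewhere).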
        logging.info('------Dropping %d paths------', num_to_drop[sp])
        # Save a copy of the switches for the skip-connect (s-c) refinement step below.
        if sp == len(num_to_keep) - 1:
            switches_normal_2 = copy.deepcopy(switches_normal)
            switches_reduce_2 = copy.deepcopy(switches_reduce)
        # drop operations with low architecture weights
        arch_param = disc.module.arch_parameters()
        normal_prob = F.softmax(arch_param[0], dim=sm_dim).data.cpu().numpy()        
        for i in range(14):
            idxs = []
            for j in range(len(PRIMITIVES)):
                if switches_normal[i][j]:
                    idxs.append(j)
            if sp == len(num_to_keep) - 1:
                # for the last stage, drop all Zero operations
                drop = get_min_k_no_zero(normal_prob[i, :], idxs, num_to_drop[sp])
            else:
                drop = get_min_k(normal_prob[i, :], num_to_drop[sp])
            for idx in drop:
                switches_normal[i][idxs[idx]] = False
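        # repeat the same pruning for the reduction cell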
        reduce_prob = F.softmax(arch_param[1], dim=-1).data.cpu().numpy()
        for i in range(14):
            idxs = []
            for j in range(len(PRIMITIVES)):
                if switches_reduce[i][j]:
                    idxs.append(j)
            if sp == len(num_to_keep) - 1:
                drop = get_min_k_no_zero(reduce_prob[i, :], idxs, num_to_drop[sp])
            else:
                drop = get_min_k(reduce_prob[i, :], num_to_drop[sp])
            for idx in drop:
                switches_reduce[i][idxs[idx]] = False
        logging.info('switches_normal = %s', switches_normal)
        logging_switches(switches_normal)
        logging.info('switches_reduce = %s', switches_reduce)
        logging_switches(switches_reduce)
        
        if sp == len(num_to_keep) - 1:
            arch_param = disc.module.arch_parameters()
            normal_prob = F.softmax(arch_param[0], dim=sm_dim).data.cpu().numpy()
            reduce_prob = F.softmax(arch_param[1], dim=sm_dim).data.cpu().numpy()
            normal_final = [0 for idx in range(14)]
            reduce_final = [0 for idx in range(14)]
            # ignore the Zero op (index 0) when ranking edges: zero its probability
            # if it is still switched on, then take the max over the remaining ops
            # as the edge strength
            for i in range(14):
                if switches_normal_2[i][0]:
                    normal_prob[i][0] = 0
                normal_final[i] = max(normal_prob[i])
                if switches_reduce_2[i][0]:
                    reduce_prob[i][0] = 0
                reduce_final[i] = max(reduce_prob[i])
            # Generate Architecture, similar to DARTS
            keep_normal = [0, 1]
            keep_reduce = [0, 1]
            n = 3
            start = 2
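            # The first intermediate node's 2 incoming edges (0 and 1) are always
            # kept; the remaining nodes have 3, 4 and 5 candidate incoming edges,
            # and only the 2 strongest per node are kept, as in DARTS.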
            for i in range(3):
                end = start + n
                tbsn = normal_final[start:end]
                tbsr = reduce_final[start:end]
                edge_n = sorted(range(n), key=lambda x: tbsn[x])
                keep_normal.append(edge_n[-1] + start)
                keep_normal.append(edge_n[-2] + start)
                edge_r = sorted(range(n), key=lambda x: tbsr[x])
                keep_reduce.append(edge_r[-1] + start)
                keep_reduce.append(edge_r[-2] + start)
                start = end
                n = n + 1
            # set switches according to the ranking of arch parameters
            for i in range(14):
                if i not in keep_normal:
                    for j in range(len(PRIMITIVES)):
                        switches_normal[i][j] = False
                if i not in keep_reduce:
                    for j in range(len(PRIMITIVES)):
                        switches_reduce[i][j] = False
            # translate switches into genotype
            genotype = parse_network(switches_normal, switches_reduce)
            logging.info(genotype)
            ## restrict skipconnect (normal cell only)
            logging.info('Restricting skipconnect...')
            # generating genotypes with different numbers of skip-connect operations
            for sks in range(0, 9):
                max_sk = 8 - sks
                num_sk = check_sk_number(switches_normal)
                if num_sk <= max_sk:
                    continue
                while num_sk > max_sk:
                    normal_prob = delete_min_sk_prob(switches_normal, switches_normal_2, normal_prob)
                    switches_normal = keep_1_on(switches_normal_2, normal_prob)
                    switches_normal = keep_2_branches(switches_normal, normal_prob)
                    num_sk = check_sk_number(switches_normal)
                logging.info('Number of skip-connect: %d', max_sk)
                genotype = parse_network(switches_normal, switches_reduce)
                logging.info(genotype)