Example 1
# Imports used by this snippet. The standard-library and third-party imports are
# exact; the project-local names (utils, PRIMITIVES, Network) are assumed from a
# PDARTS-style layout. Helpers such as train, infer, parse_network, get_min_k,
# logging_switches and the skip-connect utilities, plus args and CLASSES, are
# defined elsewhere in the same script.
import copy
import logging
import os
import sys
import time

import numpy as np
import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.nn.functional as F
import torchvision.datasets as dset

import utils
from genotypes import PRIMITIVES
from model_search import Network

def main():
    if not torch.cuda.is_available():
        logging.info('No GPU device available')
        sys.exit(1)
    np.random.seed(args.seed)
    torch.cuda.set_device(args.gpu)
    cudnn.benchmark = True
    torch.manual_seed(args.seed)
    cudnn.enabled = True
    torch.cuda.manual_seed(args.seed)
    logging.info('GPU device = %d' % args.gpu)
    logging.info("args = %s", args)
    #  prepare dataset
    train_transform, valid_transform = utils.data_transforms(args.dataset, args.cutout, args.cutout_length)
    if args.dataset == "CIFAR100":
        train_data = dset.CIFAR100(root=args.tmp_data_dir, train=True, download=True, transform=train_transform)
    elif args.dataset == "CIFAR10":
        train_data = dset.CIFAR10(root=args.tmp_data_dir, train=True, download=True, transform=train_transform)
    elif args.dataset == 'mit67':
        dset_cls = dset.ImageFolder
        data_path = '%s/MIT67/train' % args.tmp_data_dir  # e.g. data/MIT67/train
        val_path = '%s/MIT67/test' % args.tmp_data_dir  # e.g. data/MIT67/test
        train_data = dset_cls(root=data_path, transform=train_transform)
        valid_data = dset_cls(root=val_path, transform=valid_transform)
    elif args.dataset == 'sport8':
        dset_cls = dset.ImageFolder
        data_path = '%s/Sport8/train' % args.tmp_data_dir  # e.g. data/Sport8/train
        val_path = '%s/Sport8/test' % args.tmp_data_dir  # e.g. data/Sport8/test
        train_data = dset_cls(root=data_path, transform=train_transform)
        valid_data = dset_cls(root=val_path, transform=valid_transform)

    num_train = len(train_data)
    indices = list(range(num_train))
    split = int(np.floor(args.train_portion * num_train))
    np.random.shuffle(indices)  # use the seeded NumPy RNG so the split is reproducible
    
    train_queue = torch.utils.data.DataLoader(
        train_data, batch_size=args.batch_size,
        sampler=torch.utils.data.sampler.SubsetRandomSampler(indices[:split]),
        pin_memory=True, num_workers=args.workers)

    valid_queue = torch.utils.data.DataLoader(
        train_data, batch_size=args.batch_size,
        sampler=torch.utils.data.sampler.SubsetRandomSampler(indices[split:num_train]),
        pin_memory=True, num_workers=args.workers)
    
    # build Network
    criterion = nn.CrossEntropyLoss()
    criterion = criterion.cuda()
    # one boolean switch per candidate op on each of the 14 cell edges (2 + 3 + 4 + 5)
    switches = [[True for j in range(len(PRIMITIVES))] for i in range(14)]
    switches_normal = copy.deepcopy(switches)
    switches_reduce = copy.deepcopy(switches)
    # To be moved to args.
    # Three progressive stages: keep 5, then 3, then 1 candidate ops per edge,
    # dropping 3, 2, and 2 ops at the end of each stage.
    num_to_keep = [5, 3, 1]
    num_to_drop = [3, 2, 2]
    if len(args.add_width) == 3:
        add_width = args.add_width
    else:
        add_width = [0, 0, 0]
    if len(args.add_layers) == 3:
        add_layers = args.add_layers
    else:
        add_layers = [0, 3, 6]
    if len(args.dropout_rate) == 3:
        drop_rate = args.dropout_rate
    else:
        drop_rate = [0.0, 0.0, 0.0]
    # warm-up epochs per stage during which the arch parameters are frozen
    eps_no_archs = [10, 10, 10]
    for sp in range(len(num_to_keep)):
        model = Network(args.init_channels + int(add_width[sp]),
                        CLASSES,
                        args.layers + int(add_layers[sp]),
                        criterion,
                        switches_normal=switches_normal,
                        switches_reduce=switches_reduce,
                        p=float(drop_rate[sp]),
                        largemode=args.dataset in utils.LARGE_DATASETS)
        
        model = model.cuda()
        logging.info("param size = %fMB", utils.count_parameters_in_MB(model))
        network_params = []
        for k, v in model.named_parameters():
            if not (k.endswith('alphas_normal') or k.endswith('alphas_reduce')):
                network_params.append(v)       
        optimizer = torch.optim.SGD(
                network_params,
                args.learning_rate,
                momentum=args.momentum,
                weight_decay=args.weight_decay)
        optimizer_a = torch.optim.Adam(model.arch_parameters(),
                    lr=args.arch_learning_rate, betas=(0.5, 0.999), weight_decay=args.arch_weight_decay)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
                optimizer, float(args.epochs), eta_min=args.learning_rate_min)
        sm_dim = -1
        epochs = args.epochs
        eps_no_arch = eps_no_archs[sp]
        scale_factor = 0.2
        for epoch in range(epochs):
            # the LR scheduler is stepped at the end of the epoch body (see below),
            # matching the post-1.1 PyTorch convention of optimizer step first
            lr = scheduler.get_lr()[0]
            logging.info('Epoch: %d lr: %e', epoch, lr)
            epoch_start = time.time()
            # training
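            # Dropout schedule: the rate p decays linearly toward 0 during the
            # warm-up epochs, then exponentially (scale_factor = 0.2) once
            # architecture training starts.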
            if epoch < eps_no_arch:
                model.p = float(drop_rate[sp]) * (epochs - epoch - 1) / epochs
                model.update_p()
                train_acc, train_obj = train(train_queue, valid_queue, model, network_params, criterion, optimizer, optimizer_a, lr, train_arch=False)
            else:
                model.p = float(drop_rate[sp]) * np.exp(-(epoch - eps_no_arch) * scale_factor) 
                model.update_p()                
                train_acc, train_obj = train(train_queue, valid_queue, model, network_params, criterion, optimizer, optimizer_a, lr, train_arch=True)
            logging.info('Train_acc %f', train_acc)
            epoch_duration = time.time() - epoch_start
            logging.info('Epoch time: %ds', epoch_duration)
            # validation (only during the last five epochs, to save search time)
            if epochs - epoch < 5:
                valid_acc, valid_obj = infer(valid_queue, model, criterion)
                logging.info('Valid_acc %f', valid_acc)
            scheduler.step()
        utils.save(model, os.path.join(args.save, 'weights.pt'))
        logging.info('------Dropping %d paths------', num_to_drop[sp])
        # Save switches info for the final skip-connect refinement.
        if sp == len(num_to_keep) - 1:
            switches_normal_2 = copy.deepcopy(switches_normal)
            switches_reduce_2 = copy.deepcopy(switches_reduce)
        # drop operations with low architecture weights
        arch_param = model.arch_parameters()
        normal_prob = F.softmax(arch_param[0], dim=sm_dim).data.cpu().numpy()        
        for i in range(14):
            idxs = []
            for j in range(len(PRIMITIVES)):
                if switches_normal[i][j]:
                    idxs.append(j)
            if sp == len(num_to_keep) - 1:
                drop = get_min_k_no_zero(normal_prob[i, :], idxs, num_to_drop[sp])
            else:
                drop = get_min_k(normal_prob[i, :], num_to_drop[sp])
            for idx in drop:
                switches_normal[i][idxs[idx]] = False
        reduce_prob = F.softmax(arch_param[1], dim=-1).data.cpu().numpy()
        for i in range(14):
            idxs = []
            for j in range(len(PRIMITIVES)):
                if switches_reduce[i][j]:
                    idxs.append(j)
            if sp == len(num_to_keep) - 1:
                drop = get_min_k_no_zero(reduce_prob[i, :], idxs, num_to_drop[sp])
            else:
                drop = get_min_k(reduce_prob[i, :], num_to_drop[sp])
            for idx in drop:
                switches_reduce[i][idxs[idx]] = False
        logging.info('switches_normal = %s', switches_normal)
        logging_switches(switches_normal)
        logging.info('switches_reduce = %s', switches_reduce)
        logging_switches(switches_reduce)
        
        if sp == len(num_to_keep) - 1:
            arch_param = model.arch_parameters()
            normal_prob = F.softmax(arch_param[0], dim=sm_dim).data.cpu().numpy()
            reduce_prob = F.softmax(arch_param[1], dim=sm_dim).data.cpu().numpy()
            normal_final = [0 for idx in range(14)]
            reduce_final = [0 for idx in range(14)]
            # zero out the probability of the 'none' op (index 0) so it is never selected
            for i in range(14):
                if switches_normal_2[i][0]:
                    normal_prob[i][0] = 0
                normal_final[i] = max(normal_prob[i])
                if switches_reduce_2[i][0]:
                    reduce_prob[i][0] = 0
                reduce_final[i] = max(reduce_prob[i])
            # Generate the architecture: for every intermediate node, keep the two
            # incoming edges with the highest op probability, as in DARTS. Edges 0
            # and 1 (the first node's only candidates) are always kept.
            keep_normal = [0, 1]
            keep_reduce = [0, 1]
            n = 3
            start = 2
            for i in range(3):
                end = start + n
                tbsn = normal_final[start:end]
                tbsr = reduce_final[start:end]
                edge_n = sorted(range(n), key=lambda x: tbsn[x])
                keep_normal.append(edge_n[-1] + start)
                keep_normal.append(edge_n[-2] + start)
                edge_r = sorted(range(n), key=lambda x: tbsr[x])
                keep_reduce.append(edge_r[-1] + start)
                keep_reduce.append(edge_r[-2] + start)
                start = end
                n = n + 1
            for i in range(14):
                if i not in keep_normal:
                    for j in range(len(PRIMITIVES)):
                        switches_normal[i][j] = False
                if i not in keep_reduce:
                    for j in range(len(PRIMITIVES)):
                        switches_reduce[i][j] = False
            # translate switches into genotype
            genotype = parse_network(switches_normal, switches_reduce)
            logging.info(genotype)
            ## restrict skip-connect (normal cell only)
            logging.info('Restricting skipconnect...')
            for sks in range(0, len(PRIMITIVES)+1):
                max_sk = len(PRIMITIVES) - sks
                num_sk = check_sk_number(switches_normal)
                if num_sk < max_sk:
                    continue
                while num_sk > max_sk:
                    normal_prob = delete_min_sk_prob(switches_normal, switches_normal_2, normal_prob)
                    switches_normal = keep_1_on(switches_normal_2, normal_prob)
                    switches_normal = keep_2_branches(switches_normal, normal_prob)
                    num_sk = check_sk_number(switches_normal)
                logging.info('Number of skip-connect: %d', max_sk)
                genotype = parse_network(switches_normal, switches_reduce)
                logging.info(genotype)
    with open(os.path.join(args.save, "best_genotype.txt"), "w") as f:
        f.write(str(genotype))
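
The drop step above relies on two helpers, get_min_k and get_min_k_no_zero, that are defined elsewhere in the script. Below is a minimal sketch consistent with the call sites (both return positions within the list of currently active ops; the "no zero" variant always drops the 'none' op first when it is still active); the exact bodies in the original source may differ.

import copy

import numpy as np


def get_min_k(input_in, k):
    # Positions of the k smallest values in a 1-D array of softmax probabilities.
    values = copy.deepcopy(input_in)
    index = []
    for _ in range(k):
        idx = np.argmin(values)
        index.append(idx)
        values[idx] = 1  # mask the picked entry (softmax probabilities are < 1)
    return index


def get_min_k_no_zero(w_in, idxs, k):
    # Like get_min_k, but if the 'none' op (global index 0) is still active,
    # drop it unconditionally and choose the remaining k - 1 among the rest.
    w = copy.deepcopy(w_in)
    index = []
    offset = 0
    if 0 in idxs:            # idxs is sorted, so 'none' sits at position 0
        index.append(0)
        w = w[1:]
        k -= 1
        offset = 1
    for _ in range(k):
        idx = np.argmin(w)
        index.append(idx + offset)
        w[idx] = 1
    return index
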
Example 2
# Imports used by this snippet; project-local names are assumed. Helpers such as
# get_args, train, infer, parse_genotype, parse_network, update_switches and the
# skip-connect utilities are defined elsewhere in the same script.
import copy
import glob
import logging
import os
import sys
import time

import numpy as np
import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.nn.functional as F
import torchvision.datasets as dset
import yaml

import tools
from genotypes import PRIMITIVES
from model_search import Network
from monitor import ProgressMonitor, TensorBoardMonitor  # module name assumed

def main():

    args = get_args()
    # get log
    args.save = '{}/search-{}'.format(args.save,
                                      time.strftime("%Y%m%d-%H%M%S"))
    tools.create_exp_dir(args.save, scripts_to_save=glob.glob('*.py'))
    log_format = '%(asctime)s %(message)s'
    logging.basicConfig(stream=sys.stdout,
                        level=logging.INFO,
                        format=log_format,
                        datefmt='%m/%d %I:%M:%S %p')
    fh = logging.FileHandler(os.path.join(args.save, 'log.txt'))
    fh.setFormatter(logging.Formatter(log_format))
    logger = logging.getLogger('Train Search')
    logger.addHandler(fh)

    # monitor
    pymonitor = ProgressMonitor(logger)
    tbmonitor = TensorBoardMonitor(logger, args.save)
    monitors = [pymonitor, tbmonitor]

    if not torch.cuda.is_available():
        logger.info('no gpu device available')
        sys.exit(1)
    # set random seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    args.use_cuda = args.gpus > 0 and torch.cuda.is_available()
    args.device = torch.device('cuda:0' if args.use_cuda else 'cpu')
    if args.use_cuda:
        torch.cuda.manual_seed(args.seed)
        cudnn.enabled = True
        cudnn.benchmark = True
    setting = {k: v for k, v in args._get_kwargs()}
    logger.info(setting)
    # dump the experiment config
    with open(os.path.join(args.save, "args.yaml"), "w") as yaml_file:
        yaml.dump(args, yaml_file)

    if args.cifar100:
        CIFAR_CLASSES = 100
        data_folder = 'cifar-100-python'
    else:
        CIFAR_CLASSES = 10
        data_folder = 'cifar-10-batches-py'

    #  prepare dataset
    if args.cifar100:
        train_transform, valid_transform = tools._data_transforms_cifar100(
            args)
    else:
        train_transform, valid_transform = tools._data_transforms_cifar10(args)

    if args.cifar100:
        train_data = dset.CIFAR100(root=args.tmp_data_dir,
                                   train=True,
                                   download=True,
                                   transform=train_transform)
        valid_data = dset.CIFAR100(root=args.tmp_data_dir,
                                   train=False,
                                   download=False,
                                   transform=valid_transform)
    else:
        train_data = dset.CIFAR10(root=args.tmp_data_dir,
                                  train=True,
                                  download=True,
                                  transform=train_transform)
        valid_data = dset.CIFAR10(root=args.tmp_data_dir,
                                  train=False,
                                  download=False,
                                  transform=valid_transform)

    num_train = len(train_data)
    indices = list(range(num_train))
    split = int(np.floor(args.train_portion * num_train))

    train_queue = torch.utils.data.DataLoader(
        train_data,
        batch_size=args.batch_size,
        sampler=torch.utils.data.sampler.SubsetRandomSampler(indices[:split]),
        pin_memory=True,
        num_workers=args.workers)

    valid_queue = torch.utils.data.DataLoader(
        train_data,
        batch_size=args.batch_size,
        sampler=torch.utils.data.sampler.SubsetRandomSampler(
            indices[split:num_train]),
        pin_memory=True,
        num_workers=args.workers)

    valLoader = torch.utils.data.DataLoader(valid_data,
                                            batch_size=args.batch_size,
                                            pin_memory=True,
                                            num_workers=args.workers)

    # build Network
    criterion = nn.CrossEntropyLoss()
    criterion = criterion.to(args.device)
    # one boolean switch per candidate op on each of the 14 cell edges (2 + 3 + 4 + 5)
    switches = [[True for j in range(len(PRIMITIVES))] for i in range(14)]
    switches_normal = copy.deepcopy(switches)
    switches_reduce = copy.deepcopy(switches)
    # To be moved to args.
    # Three progressive stages: keep 5, then 3, then 1 candidate ops per edge,
    # dropping 3, 2, and 2 ops at the end of each stage.
    num_to_keep = [5, 3, 1]
    num_to_drop = [3, 2, 2]
    if len(args.add_width) == 3:
        add_width = args.add_width
    else:
        add_width = [0, 0, 0]
    if len(args.add_layers) == 3:
        add_layers = args.add_layers
    else:
        add_layers = [0, 6, 12]
    if len(args.dropout_rate) == 3:
        drop_rate = args.dropout_rate
    else:
        drop_rate = [0.1, 0.4, 0.7]
    # warm-up epochs per stage during which the arch parameters are frozen
    eps_no_archs = [10, 10, 10]
    state_epochs = 0
    for sp in range(len(num_to_keep)):
        model = Network(args.init_channels + int(add_width[sp]),
                        CIFAR_CLASSES,
                        args.layers + int(add_layers[sp]),
                        criterion,
                        steps=args.nodes,
                        multiplier=args.multiplier,
                        stem_multiplier=args.stem_multiplier,
                        switches_normal=switches_normal,
                        switches_reduce=switches_reduce,
                        p=float(drop_rate[sp]))

        model = model.to(args.device)
        logger.info("stage:{} param size:{}MB".format(
            sp, tools.count_parameters_in_MB(model)))

        optimizer = torch.optim.SGD(model.weight_parameters(),
                                    args.learning_rate,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)
        optimizer_a = torch.optim.Adam(model.arch_parameters(),
                                       lr=args.arch_learning_rate,
                                       betas=(0.5, 0.999),
                                       weight_decay=args.arch_weight_decay)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, float(args.epochs), eta_min=args.learning_rate_min)

        sm_dim = -1
        epochs = args.epochs
        eps_no_arch = eps_no_archs[sp]
        scale_factor = 0.2
        for epoch in range(epochs):
            lr = scheduler.get_lr()[0]
            logger.info('Epoch: %d lr: %e', epoch, lr)
            epoch_start = time.time()
            # training
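            # Dropout schedule: the rate p decays linearly toward 0 during the
            # warm-up epochs, then exponentially (scale_factor = 0.2) once
            # architecture training starts.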
            if epoch < eps_no_arch:
                model.p = float(drop_rate[sp]) * (epochs - epoch - 1) / epochs
                model.update_p()
                train_acc, train_obj = train(state_epochs + epoch,
                                             train_queue,
                                             valid_queue,
                                             model,
                                             criterion,
                                             optimizer,
                                             optimizer_a,
                                             args,
                                             monitors,
                                             logger,
                                             train_arch=False)
            else:
                model.p = float(drop_rate[sp]) * np.exp(
                    -(epoch - eps_no_arch) * scale_factor)
                model.update_p()
                train_acc, train_obj = train(state_epochs + epoch,
                                             train_queue,
                                             valid_queue,
                                             model,
                                             criterion,
                                             optimizer,
                                             optimizer_a,
                                             args,
                                             monitors,
                                             logger,
                                             train_arch=True)

            # validation
            valid_acc, valid_obj = infer(state_epochs + epoch, valLoader,
                                         model, criterion, args, monitors,
                                         logger)

            if epoch >= eps_no_arch:
                # save the genotype parsed from this epoch's architecture weights
                arch_param = model.arch_parameters()
                normal_prob = F.softmax(arch_param[0],
                                        dim=-1).data.cpu().numpy()
                reduce_prob = F.softmax(arch_param[1],
                                        dim=-1).data.cpu().numpy()
                logger.info('Genotype: {}'.format(
                    parse_genotype(switches_normal.copy(),
                                   switches_reduce.copy(), normal_prob.copy(),
                                   reduce_prob.copy())))

            scheduler.step()

        tools.save(model,
                   os.path.join(args.save, 'state{}_weights.pt'.format(sp)))
        state_epochs += args.epochs

        # Save switches info for the final skip-connect refinement.
        if sp == len(num_to_keep) - 1:
            switches_normal_2 = copy.deepcopy(switches_normal)
            switches_reduce_2 = copy.deepcopy(switches_reduce)
        arch_param = model.arch_parameters()
        normal_prob = F.softmax(arch_param[0], dim=-1).data.cpu().numpy()
        reduce_prob = F.softmax(arch_param[1], dim=-1).data.cpu().numpy()

        logger.info('------Stage %d end!------' % sp)
        logger.info("normal: \n{}".format(normal_prob))
        logger.info("reduce: \n{}".format(reduce_prob))
        logger.info('Genotype: {}'.format(
            parse_genotype(switches_normal.copy(), switches_reduce.copy(),
                           normal_prob.copy(), reduce_prob.copy())))

        # Prune the search space using the latest architecture weights, the
        # previous search space, the number of ops to drop, and the current stage.
        switches_normal = update_switches(normal_prob.copy(),
                                          switches_normal, num_to_drop[sp], sp,
                                          len(num_to_keep))
        switches_reduce = update_switches(reduce_prob.copy(),
                                          switches_reduce, num_to_drop[sp], sp,
                                          len(num_to_keep))

        logger.info('------Dropping %d paths------' % num_to_drop[sp])
        logger.info('switches_normal = %s', switches_normal)
        logging_switches(switches_normal, logger)
        logger.info('switches_reduce = %s', switches_reduce)
        logging_switches(switches_reduce, logger)

        if sp == len(num_to_keep) - 1:
            # arch_param = model.arch_parameters()
            # normal_prob = F.softmax(arch_param[0], dim=sm_dim).data.cpu().numpy()
            # reduce_prob = F.softmax(arch_param[1], dim=sm_dim).data.cpu().numpy()
            normal_final = [0 for idx in range(14)]
            reduce_final = [0 for idx in range(14)]
            # zero out the probability of the 'none' op (index 0) so it is never selected
            for i in range(14):
                if switches_normal_2[i][0]:
                    normal_prob[i][0] = 0
                normal_final[i] = max(normal_prob[i])
                if switches_reduce_2[i][0]:
                    reduce_prob[i][0] = 0
                reduce_final[i] = max(reduce_prob[i])
            # Generate Architecture, similar to DARTS
            keep_normal = [0, 1]
            keep_reduce = [0, 1]
            n = 3
            start = 2
            for i in range(3):
                end = start + n
                tbsn = normal_final[start:end]
                tbsr = reduce_final[start:end]
                edge_n = sorted(range(n), key=lambda x: tbsn[x])
                keep_normal.append(edge_n[-1] + start)
                keep_normal.append(edge_n[-2] + start)
                edge_r = sorted(range(n), key=lambda x: tbsr[x])
                keep_reduce.append(edge_r[-1] + start)
                keep_reduce.append(edge_r[-2] + start)
                start = end
                n = n + 1
            # set switches according to the ranking of arch parameters
            for i in range(14):
                if i not in keep_normal:
                    for j in range(len(PRIMITIVES)):
                        switches_normal[i][j] = False
                if i not in keep_reduce:
                    for j in range(len(PRIMITIVES)):
                        switches_reduce[i][j] = False
            # translate switches into genotype
            genotype = parse_network(switches_normal, switches_reduce)
            logger.info(genotype)
            ## restrict skip-connect (normal cell only)
            logger.info('Restricting skipconnect...')
            # generating genotypes with different numbers of skip-connect operations
            for sks in range(0, 9):
                max_sk = 8 - sks
                num_sk = check_sk_number(switches_normal)
                if num_sk <= max_sk:
                    continue
                while num_sk > max_sk:
                    normal_prob = delete_min_sk_prob(switches_normal,
                                                     switches_normal_2,
                                                     normal_prob)
                    switches_normal = keep_1_on(switches_normal_2, normal_prob)
                    switches_normal = keep_2_branches(switches_normal,
                                                      normal_prob)
                    num_sk = check_sk_number(switches_normal)
                logger.info('Number of skip-connect: %d', max_sk)
                genotype = parse_network(switches_normal, switches_reduce)
                logger.info(genotype)
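
The skip-connect restriction in both examples depends on check_sk_number, plus the pruning helpers delete_min_sk_prob, keep_1_on and keep_2_branches defined elsewhere in the script. Below is a minimal sketch of the counting helper, assuming switches is the 14 x len(PRIMITIVES) boolean table used above and that 'skip_connect' is one of the PRIMITIVES entries; the original body may differ.

from genotypes import PRIMITIVES  # project-local module, name assumed

SK_IDX = PRIMITIVES.index('skip_connect')  # position of skip-connect among candidate ops


def check_sk_number(switches):
    # Count how many of the 14 cell edges still have skip-connect enabled.
    return sum(1 for edge in switches if edge[SK_IDX])

With this helper, the loop "for sks in range(0, 9): max_sk = 8 - sks" lowers the skip-connect ceiling from 8 down to 0 and, whenever the ceiling forces pruning, logs a genotype with progressively fewer skip-connects in the normal cell.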