Example #1
    def get_train_val_loaders(self):
        if self.args.dataset == 'cifar10':
            train_transform, valid_transform = utils._data_transforms_cifar10(self.args)
            train_data = dset.CIFAR10(root=self.args.data, train=True, download=True, transform=train_transform)
        elif self.args.dataset == 'cifar100':
            train_transform, valid_transform = utils._data_transforms_cifar100(self.args)
            train_data = dset.CIFAR100(root=self.args.data, train=True, download=True, transform=train_transform)
        elif self.args.dataset == 'svhn':
            train_transform, valid_transform = utils._data_transforms_svhn(self.args)
            train_data = dset.SVHN(root=self.args.data, split='train', download=True, transform=train_transform)
        else:
            raise ValueError('unknown dataset: {}'.format(self.args.dataset))

        num_train = len(train_data)
        indices = list(range(num_train))
        split = int(np.floor(self.args.train_portion * num_train))

        train_queue = torch.utils.data.DataLoader(
            train_data, batch_size=self.args.batch_size,
            sampler=torch.utils.data.sampler.SubsetRandomSampler(indices[:split]),
            pin_memory=True, num_workers=2)

        valid_queue = torch.utils.data.DataLoader(
            train_data, batch_size=self.args.batch_size,
            sampler=torch.utils.data.sampler.SubsetRandomSampler(indices[split:num_train]),
            pin_memory=True, num_workers=2)

        return train_queue, valid_queue, train_transform, valid_transform
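Note: the split-based loaders above all share one idiom: a single training set is partitioned into disjoint index ranges, and each range is wrapped in a SubsetRandomSampler. A minimal, self-contained sketch of just that idiom (the dataset size and train_portion value below are illustrative, not taken from any example here):

import numpy as np
import torch

num_train = 50000      # e.g. the CIFAR-10 training set
train_portion = 0.8    # hypothetical stand-in for args.train_portion

indices = list(range(num_train))
split = int(np.floor(train_portion * num_train))  # 40000

# Two disjoint samplers over the same underlying dataset.
train_sampler = torch.utils.data.sampler.SubsetRandomSampler(indices[:split])
valid_sampler = torch.utils.data.sampler.SubsetRandomSampler(indices[split:])

assert len(train_sampler) + len(valid_sampler) == num_train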
Example #2
def get_train_validation_loader(args):
    train_transform, valid_transform = utils._data_transforms_cifar10(args)
    train_data = dset.CIFAR10(root=args.data,
                              train=True,
                              download=True,
                              transform=train_transform)

    num_train = len(train_data)
    indices = list(range(num_train))
    split = int(np.floor(args.train_portion * num_train))

    # train[0:split] as training data
    train_sampler = torch.utils.data.sampler.SubsetRandomSampler(
        indices[:split])
    train_queue = torch.utils.data.DataLoader(train_data,
                                              batch_size=args.train_batch_size,
                                              num_workers=args.workers,
                                              pin_memory=True,
                                              sampler=train_sampler)

    # train[split:] as validation data
    valid_sampler = torch.utils.data.sampler.SubsetRandomSampler(
        indices[split:])
    valid_queue = torch.utils.data.DataLoader(train_data,
                                              batch_size=args.valid_batch_size,
                                              num_workers=args.workers,
                                              pin_memory=True,
                                              sampler=valid_sampler)
    valid_queue.name = 'valid'

    return train_queue, train_sampler, valid_queue
Example #3
def get_test_loader(args):
    _, test_transform = utils._data_transforms_cifar10(args)
    test_data = dset.CIFAR10(root=args.data,
                             train=False,
                             download=True,
                             transform=test_transform)
    test_queue = torch.utils.data.DataLoader(test_data,
                                             batch_size=args.valid_batch_size,
                                             shuffle=False,
                                             pin_memory=True,
                                             num_workers=args.workers)
    test_queue.name = 'test'
    return test_queue
Example #4
    def get_train_val_loaders(self):
        if self.args.dataset == 'cifar10':
            train_transform, valid_transform = utils._data_transforms_cifar10(
                self.args)
            train_data = dset.CIFAR10(root=self.args.data,
                                      train=True,
                                      download=True,
                                      transform=train_transform)
            valid_data = dset.CIFAR10(root=self.args.data,
                                      train=False,
                                      download=True,
                                      transform=valid_transform)
        elif self.args.dataset == 'cifar100':
            train_transform, valid_transform = utils._data_transforms_cifar100(
                self.args)
            train_data = dset.CIFAR100(root=self.args.data,
                                       train=True,
                                       download=True,
                                       transform=train_transform)
            valid_data = dset.CIFAR100(root=self.args.data,
                                       train=False,
                                       download=True,
                                       transform=valid_transform)
        elif self.args.dataset == 'svhn':
            train_transform, valid_transform = utils._data_transforms_svhn(
                self.args)
            train_data = dset.SVHN(root=self.args.data,
                                   split='train',
                                   download=True,
                                   transform=train_transform)
            valid_data = dset.SVHN(root=self.args.data,
                                   split='test',
                                   download=True,
                                   transform=valid_transform)
        else:
            raise ValueError('unknown dataset: {}'.format(self.args.dataset))

        train_queue = torch.utils.data.DataLoader(
            train_data,
            batch_size=self.args.batch_size,
            shuffle=True,
            pin_memory=True,
            num_workers=2)

        valid_queue = torch.utils.data.DataLoader(
            valid_data,
            batch_size=self.args.batch_size,
            shuffle=False,
            pin_memory=True,
            num_workers=2)

        return train_queue, valid_queue, train_transform, valid_transform
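Unlike Examples #1 and #2, Example #4 validates on the held-out test split, so it passes shuffle= directly instead of a sampler. The two options are mutually exclusive in torch.utils.data.DataLoader; a small sketch of the error raised when both are supplied:

import torch
from torch.utils.data import DataLoader, TensorDataset

data = TensorDataset(torch.zeros(10, 1))
sampler = torch.utils.data.SubsetRandomSampler(range(5))

try:
    DataLoader(data, sampler=sampler, shuffle=True)
except ValueError as err:
    print(err)  # sampler option is mutually exclusive with shuffle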
Example #5
def main():
    if not torch.cuda.is_available():
        logging.info('no gpu device available')
        sys.exit(1)

    np.random.seed(args.seed)
    torch.cuda.set_device(args.gpu)
    cudnn.benchmark = True
    torch.manual_seed(args.seed)
    cudnn.enabled = True
    torch.cuda.manual_seed(args.seed)
    logging.info('gpu device = %d' % args.gpu)
    logging.info("args = %s", args)

    genotype = eval("genotypes.%s" % args.arch)
    model = Network(args.init_channels, CIFAR_CLASSES, args.layers,
                    args.auxiliary, genotype)
    model.drop_path_prob = 0.0  # disable drop-path while profiling FLOPs
    flops, params = profile(model,
                            inputs=(torch.randn(1, 3, 32, 32), ),
                            verbose=False)
    logging.info('flops = %fM', flops / 1e6)
    logging.info("param size = %fMB", utils.count_parameters_in_MB(model))

    model = model.cuda()
    utils.load(model, args.model_path)

    criterion = nn.CrossEntropyLoss()
    criterion = criterion.cuda()

    _, test_transform = utils._data_transforms_cifar10(args)
    test_data = dset.CIFAR10(root=args.data,
                             train=False,
                             download=True,
                             transform=test_transform)

    test_queue = torch.utils.data.DataLoader(test_data,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             pin_memory=True,
                                             num_workers=2)

    model.drop_path_prob = args.drop_path_prob
    with torch.no_grad():
        test_acc, test_obj = infer(test_queue, model, criterion)
    logging.info('test_acc %f', test_acc)
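The profile() call in Example #5 matches the signature of thop.profile, which is presumably the source of the import; a minimal sketch under that assumption:

import torch
from thop import profile  # assumption: Example #5 imports profile from thop

model = torch.nn.Conv2d(3, 16, 3, padding=1)
flops, params = profile(model, inputs=(torch.randn(1, 3, 32, 32),), verbose=False)
print('flops = %.3fM, params = %.3fK' % (flops / 1e6, params / 1e3))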
Example #6
def main():
    wandb.init(project="automl-gradient-based-nas",
               name="GDAS-" + "Opt: " + str(args.optimization) + "Search: " +
               str(args.arch_search_method),
               config=args,
               entity="automl")

    wandb.config.update(args)  # adds all of the arguments as config variables

    global is_multi_gpu

    gpus = [int(i) for i in args.gpu.split(',')]
    logging.info('gpus = %s' % gpus)
    if not torch.cuda.is_available():
        logging.info('no gpu device available')
        sys.exit(1)

    cudnn.benchmark = True
    torch.manual_seed(args.seed)
    cudnn.enabled = True
    torch.cuda.manual_seed(args.seed)
    logging.info('gpu device = %s' % args.gpu)
    logging.info("args = %s", args)

    criterion = nn.CrossEntropyLoss()
    criterion = criterion.cuda()

    # default: args.init_channels = 16, CIFAR_CLASSES = 10, args.layers = 8
    if args.arch_search_method == "DARTS":
        model = Network(args.init_channels, CIFAR_CLASSES, args.layers,
                        criterion)
    elif args.arch_search_method == "GDAS":
        model = Network_GumbelSoftmax(args.init_channels, CIFAR_CLASSES,
                                      args.layers, criterion)
    else:
        raise Exception("search space does not exist!")

    if len(gpus) > 1:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        # dim = 0 [30, xxx] -> [10, ...], [10, ...], [10, ...] on 3 GPUs
        model = nn.DataParallel(model)
        is_multi_gpu = True

    model.cuda()

    wandb.watch(model)

    # Split the architecture parameters (alphas) from the network weights so
    # that the SGD optimizer below only updates the weights.
    arch_parameters = (model.module.arch_parameters()
                       if is_multi_gpu else model.arch_parameters())
    arch_params = list(map(id, arch_parameters))

    parameters = (model.module.parameters()
                  if is_multi_gpu else model.parameters())
    weight_params = filter(lambda p: id(p) not in arch_params, parameters)

    logging.info("param size = %fMB", utils.count_parameters_in_MB(model))

    optimizer = torch.optim.SGD(
        weight_params,  # model.parameters(),
        args.learning_rate,
        momentum=args.momentum,
        weight_decay=args.weight_decay)

    train_transform, valid_transform = utils._data_transforms_cifar10(args)

    # will cost time to download the data
    train_data = dset.CIFAR10(root=args.data,
                              train=True,
                              download=True,
                              transform=train_transform)

    num_train = len(train_data)
    indices = list(range(num_train))
    split = int(np.floor(args.train_portion * num_train))  # split index

    train_queue = torch.utils.data.DataLoader(
        train_data,
        batch_size=args.batch_size,
        sampler=torch.utils.data.sampler.SubsetRandomSampler(indices[:split]),
        pin_memory=True,
        num_workers=2)

    valid_queue = torch.utils.data.DataLoader(
        train_data,
        batch_size=args.batch_size,
        sampler=torch.utils.data.sampler.SubsetRandomSampler(
            indices[split:num_train]),
        pin_memory=True,
        num_workers=2)

    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, float(args.epochs), eta_min=args.learning_rate_min)

    architect = Architect(model, criterion, args)

    best_accuracy = 0

    table = wandb.Table(columns=["Epoch", "Searched Architecture"])

    for epoch in range(args.epochs):
        lr = scheduler.get_last_lr()[0]
        logging.info('epoch %d lr %e', epoch, lr)
        genotype = model.module.genotype() if is_multi_gpu else model.genotype()
        logging.info('genotype = %s', genotype)
        wandb.log({"genotype": str(genotype)}, step=epoch)

        table.add_data(str(epoch), str(genotype))
        wandb.log({"Searched Architecture": table})

        print(
            F.softmax(model.module.alphas_normal
                      if is_multi_gpu else model.alphas_normal,
                      dim=-1))
        print(
            F.softmax(model.module.alphas_reduce
                      if is_multi_gpu else model.alphas_reduce,
                      dim=-1))

        # training
        train_acc, train_obj = train(epoch, train_queue, valid_queue, model,
                                     architect, criterion, optimizer, lr)
        logging.info('train_acc %f', train_acc)
        wandb.log({"searching_train_acc": train_acc, "epoch": epoch})

        # validation
        with torch.no_grad():
            valid_acc, valid_obj = infer(valid_queue, model, criterion)
        logging.info('valid_acc %f', valid_acc)
        wandb.log({"searching_valid_acc": valid_acc, "epoch": epoch})

        scheduler.step()

        if valid_acc > best_accuracy:
            wandb.run.summary["best_valid_accuracy"] = valid_acc
            best_accuracy = valid_acc

        # utils.save(model, os.path.join(args.save, 'weights.pt'))
        utils.save(model, os.path.join(wandb.run.dir, 'weights.pt'))
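The id()-based filter in Example #6 is how DARTS-style code keeps the weight optimizer away from the architecture parameters. A stripped-down sketch of the idiom (the Toy module and alpha attribute are made-up names for illustration):

import torch
import torch.nn as nn

class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, 3)                # ordinary network weight
        self.alpha = nn.Parameter(torch.zeros(4, 7))  # architecture parameter

    def arch_parameters(self):
        return [self.alpha]

model = Toy()
arch_ids = list(map(id, model.arch_parameters()))
weight_params = [p for p in model.parameters() if id(p) not in arch_ids]

# Only the conv weight and bias survive the filter; alpha is excluded.
assert all(id(p) not in arch_ids for p in weight_params)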
Example #7
def main():
    if not torch.cuda.is_available():
        print('No GPU device available')
        sys.exit(1)
    np.random.seed(args.seed)
    cudnn.benchmark = True
    torch.manual_seed(args.seed)
    cudnn.enabled = True
    torch.cuda.manual_seed(args.seed)
    print("args = %s", args)
    print("unparsed args = %s", unparsed)

    # prepare dataset
    if args.cifar100:
        train_transform, valid_transform = utils._data_transforms_cifar100(args)
    else:
        train_transform, valid_transform = utils._data_transforms_cifar10(args)
    if args.cifar100:
        train_data = dset.CIFAR100(root=args.tmp_data_dir, train=True, download=False, transform=train_transform)
        valid_data = dset.CIFAR100(root=args.tmp_data_dir, train=False, download=False, transform=valid_transform)
    else:
        train_data = dset.CIFAR10(root=args.tmp_data_dir, train=True, download=False, transform=train_transform)
        valid_data = dset.CIFAR10(root=args.tmp_data_dir, train=False, download=False, transform=valid_transform)

    train_queue = torch.utils.data.DataLoader(
        train_data, batch_size=args.batch_size, shuffle=True, pin_memory=True, num_workers=args.workers)

    valid_queue = torch.utils.data.DataLoader(
        valid_data, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=args.workers)


    # build Network
    criterion = nn.CrossEntropyLoss()
    criterion = criterion.cuda()
    supernet = Network(
        args.init_channels, CIFAR_CLASSES, args.layers
    )
    supernet.cuda()

    ckpt = torch.load(args.load_at)
    print(args.load_at)
    supernet.load_state_dict(ckpt)
    supernet.generate_share_alphas()

    alphas = supernet.cells[0].ops_alphas
    print(alphas)
    out_dir = args.save + '{}/eval_out/{}'.format(args.load_at.split('/')[2], args.seed)
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)
    torch.save(alphas, os.path.join(out_dir, 'alphas.pt'))
    with open(os.path.join(out_dir, 'alphas.txt'), 'w') as f:
        for i in alphas.cpu().detach().numpy():
            for j in i:
                f.write('{:d}'.format(int(j)))
            f.write('\n')

    # Getting subnet according to sample alpha
    subnet = supernet.get_sub_net(alphas)

    init_valid_acc, _ = infer(valid_queue, subnet, criterion)
    print('Initial Valid Acc {:.2f}'.format(init_valid_acc))

    if args.fine_tune:
        if args.cifar100:
            weight_decay = 5e-4
        else:
            weight_decay = 3e-4

        # Fine tuning whole network:
        subnet = supernet.get_sub_net(alphas)
        optimizer = torch.optim.SGD(
            subnet.parameters(),
            args.finetune_lr,
            momentum=args.momentum,
            weight_decay=weight_decay,
        )

        for epoch in range(args.epochs):
            # scheduler.step()
            print('epoch {} lr {:.4f}'.format(epoch, args.finetune_lr))

            train_acc, _ = train(train_queue, subnet, criterion, optimizer)
            print('train_acc {:.2f}'.format(train_acc))

            whole_valid_acc, _ = infer(valid_queue, subnet, criterion)
            print('valid_acc after whole fine-tune {:.2f}'.format(whole_valid_acc))

            fly_whole_valid_acc, _ = infer(valid_queue, subnet, criterion, use_fly_bn=False)
            print('valid_acc after whole fine-tune (use_fly_bn=False) {:.2f}'.format(fly_whole_valid_acc))

        # Fine-tuning only classifier:
        subnet = supernet.get_sub_net(alphas)
        # Freeze all weights except the classifier:
        for name, param in subnet.named_parameters():
            if 'classifier' not in name:
                param.requires_grad_(requires_grad=False)

        optimizer = torch.optim.SGD(
            subnet.classifier.parameters(),
            args.finetune_lr,
            momentum=args.momentum,
            weight_decay=weight_decay,
        )

        for epoch in range(args.epochs):
            # scheduler.step()
            print('epoch {} lr {:.4f}'.format(epoch, args.finetune_lr))

            train_acc, _ = train(train_queue, subnet, criterion, optimizer)
            print('train_acc {:.2f}'.format(train_acc))

            part_valid_acc, _ = infer(valid_queue, subnet, criterion)
            print('valid_acc after fine-tuning classifier {:.2f}'.format(part_valid_acc))

            fly_part_valid_acc, _ = infer(valid_queue, subnet, criterion, use_fly_bn=False)
            print('valid_acc after fine-tuning classifier (use_fly_bn=False) {:.2f}'.format(fly_part_valid_acc))

        with open(os.path.join(out_dir, 'results.txt'), 'w') as f:
            f.write('-'.join([str(init_valid_acc), str(whole_valid_acc),
                    str(fly_whole_valid_acc), str(part_valid_acc), str(fly_part_valid_acc)]))

    if not args.fine_tune:
        with open(os.path.join(out_dir, 'results.txt'), 'w') as f:
            f.write(str(init_valid_acc))
Example #8
def main():
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if args.gpu != -1:
        if not torch.cuda.is_available():
            logging.info('no gpu device available')
            sys.exit(1)
        torch.cuda.set_device(args.gpu)
        cudnn.benchmark = True
        cudnn.enabled = True
        torch.cuda.manual_seed(args.seed)
        logging.info('gpu device = %d' % args.gpu)
    else:
        logging.info('using cpu')

    if args.dyno_schedule:
        args.threshold_divider = np.exp(-np.log(args.threshold_multiplier) *
                                        args.schedfreq)
        print(args.threshold_divider,
              -np.log(args.threshold_multiplier) / np.log(args.threshold_divider))
    if args.dyno_split:
        args.train_portion = 1 - 1 / (1 + args.schedfreq)

    logging.info("args = %s", args)

    criterion = nn.CrossEntropyLoss()
    if args.gpu != -1:
        criterion = criterion.cuda()
    model = Network(args.init_channels,
                    CIFAR_CLASSES,
                    args.layers,
                    criterion,
                    args.rho,
                    args.crb,
                    args.epochs,
                    args.gpu,
                    ewma=args.ewma,
                    reg=args.reg)
    if args.gpu != -1:
        model = model.cuda()
    logging.info("param size = %fMB", utils.count_parameters_in_MB(model))

    optimizer = torch.optim.SGD(model.parameters(),
                                args.learning_rate,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    datapath = os.path.join(utils.get_dir(), args.data)
    if args.task == "CIFAR100cf":
        train_transform, valid_transform = utils._data_transforms_cifar100(
            args)
        train_data = utils.CIFAR100C2F(root=datapath,
                                       train=True,
                                       download=True,
                                       transform=train_transform)
        num_train = len(train_data)
        indices = list(range(num_train))

        split = int(np.floor(args.train_portion * len(indices)))

        orig_num_train = len(indices[:split])
        orig_num_valid = len(indices[split:num_train])

        train_indices = train_data.filter_by_fine(args.train_filter,
                                                  indices[:split])
        valid_indices = train_data.filter_by_fine(args.valid_filter,
                                                  indices[split:num_train])

        train_queue = torch.utils.data.DataLoader(
            train_data,
            batch_size=args.batch_size,
            sampler=utils.FillingSubsetRandomSampler(train_indices,
                                                     orig_num_train,
                                                     reshuffle=True),
            pin_memory=True,
            num_workers=2)

        valid_queue = torch.utils.data.DataLoader(
            train_data,
            batch_size=args.batch_size,
            sampler=utils.FillingSubsetRandomSampler(valid_indices,
                                                     orig_num_valid,
                                                     reshuffle=True),
            pin_memory=True,
            num_workers=2)
        # TODO: extend each epoch or multiply number of epochs by 20%*args.class_filter
    elif args.task == "CIFAR100split":
        train_transform, valid_transform = utils._data_transforms_cifar100(
            args)
        train_data = utils.CIFAR100C2F(root=datapath,
                                       train=True,
                                       download=True,
                                       transform=train_transform)
        if not args.evensplit:
            train_indices, valid_indices = train_data.split(args.train_portion)
        else:
            num_train = len(train_data)
            indices = list(range(num_train))

            split = int(np.floor(args.train_portion * num_train))

            train_indices = indices[:split]
            valid_indices = indices[split:num_train]

        train_queue = torch.utils.data.DataLoader(
            train_data,
            batch_size=args.batch_size,
            sampler=torch.utils.data.sampler.SubsetRandomSampler(
                train_indices),
            pin_memory=True,
            num_workers=2)

        valid_queue = torch.utils.data.DataLoader(
            train_data,
            batch_size=args.batch_size,
            sampler=torch.utils.data.sampler.SubsetRandomSampler(
                valid_indices),
            pin_memory=True,
            num_workers=2)
    else:
        if args.task == "CIFAR100":
            train_transform, valid_transform = utils._data_transforms_cifar100(
                args)
            train_data = dset.CIFAR100(root=datapath,
                                       train=True,
                                       download=True,
                                       transform=train_transform)
        else:
            train_transform, valid_transform = utils._data_transforms_cifar10(
                args)
            train_data = dset.CIFAR10(root=datapath,
                                      train=True,
                                      download=True,
                                      transform=train_transform)
        num_train = len(train_data)
        indices = list(range(num_train))

        split = int(np.floor(args.train_portion * num_train))

        train_indices = indices[:split]
        valid_indices = indices[split:num_train]

        train_queue = torch.utils.data.DataLoader(
            train_data,
            batch_size=args.batch_size,
            sampler=torch.utils.data.sampler.SubsetRandomSampler(
                train_indices),
            pin_memory=True,
            num_workers=4)

        valid_queue = torch.utils.data.DataLoader(
            train_data,
            batch_size=args.batch_size,
            sampler=torch.utils.data.sampler.SubsetRandomSampler(
                valid_indices),
            pin_memory=True,
            num_workers=4)

    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, int(args.epochs), eta_min=args.learning_rate_min)

    architect = Architect(model, args)

    loggers = {
        "train": {
            "loss": [],
            "acc": [],
            "step": []
        },
        "val": {
            "loss": [],
            "acc": [],
            "step": []
        },
        "infer": {
            "loss": [],
            "acc": [],
            "step": []
        },
        "ath": {
            "threshold": [],
            "step": []
        },
        "astep": [],
        "zustep": []
    }

    alpha_threshold = args.init_alpha_threshold
    alpha_counter = 0
    ewma = -1

    for epoch in range(args.epochs):
        lr = scheduler.get_last_lr()[0]

        logging.info('epoch %d lr %e', epoch, lr)

        genotype = model.genotype()
        logging.info('genotype = %s', genotype)
        if args.ckpt_interval > 0 and epoch > 0 and epoch % args.ckpt_interval == 0:
            logging.info('checkpointing genotype')
            os.makedirs(os.path.join(args.save, 'genotypes', str(epoch)),
                        exist_ok=True)
            with open(
                    os.path.join(args.save, 'genotypes', str(epoch),
                                 'genotype.txt'), "w") as f:
                f.write(str(genotype))

        print(model.activate(model.alphas_normal))
        print(model.activate(model.alphas_reduce))

        # training
        train_acc, train_obj, alpha_threshold, alpha_counter, ewma = train(
            train_queue, valid_queue, model, architect, criterion, optimizer,
            loggers, alpha_threshold, alpha_counter, ewma, args)
        logging.info('train_acc %f', train_acc)

        # validation
        valid_acc, valid_obj = infer(valid_queue, model, criterion)
        utils.log_loss(loggers["infer"], valid_obj, valid_acc, model.clock)
        logging.info('valid_acc %f', valid_acc)

        utils.plot_loss_acc(loggers, args.save)

        utils.save_file(recoder=model.alphas_normal_history,
                        path=os.path.join(args.save, 'Normalalpha'),
                        steps=loggers["train"]["step"])
        utils.save_file(recoder=model.alphas_reduce_history,
                        path=os.path.join(args.save, 'Reducealpha'),
                        steps=loggers["train"]["step"])

        utils.plot_FI(loggers["train"]["step"], model.FI_history, args.save,
                      "FI", loggers["ath"], loggers['astep'])
        utils.plot_FI(loggers["train"]["step"], model.FI_ewma_history,
                      args.save, "FI_ewma", loggers["ath"], loggers['astep'])

        utils.save(model, os.path.join(args.save, 'weights.pt'))

        scheduler.step()

    genotype = model.genotype()
    logging.info('genotype = %s', genotype)

    with open(os.path.join(args.save, 'genotype.txt'), "w") as f:
        f.write(str(genotype))
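For reference, the scheduler ordering that PyTorch 1.1+ expects, as used in the last two examples: read the current rate with get_last_lr(), run the epoch, then call scheduler.step() once. A minimal sketch:

import torch

param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.SGD([param], lr=0.025)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer, T_max=50, eta_min=0.001)

for epoch in range(50):
    lr = scheduler.get_last_lr()[0]  # rate in effect for this epoch
    # ... one epoch of optimizer.step() calls goes here ...
    scheduler.step()                 # anneal once, after the epoch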