# Ejemplo n.º 1
def main():
    """Evaluate a trained bathymetry network on the 29TNE and 29SMD test sets.

    Relies on the module-level ``args`` namespace plus the project helpers
    ``utils``, ``genotypes``, ``Network`` and ``infer`` defined elsewhere.
    Writes per-set prediction CSVs under ``args.save``.
    """
    if not torch.cuda.is_available():
        logging.info('no gpu device available')
        sys.exit(1)

    # Seed every RNG source and pin the requested GPU for a reproducible run.
    np.random.seed(args.seed)
    torch.cuda.set_device(args.gpu)
    cudnn.benchmark = True
    torch.manual_seed(args.seed)
    cudnn.enabled = True
    torch.cuda.manual_seed(args.seed)
    logging.info('gpu device = %d' % args.gpu)
    logging.info("args = %s", args)

    # Recover the architecture genotype saved next to the checkpoint, falling
    # back to a named genotype, then to the BATH default.
    # NOTE(review): eval() on file contents assumes the checkpoint directory
    # is trusted; kept for compatibility with the on-disk format.
    model_dir = os.path.split(args.model_path)[0]
    genotype_path = os.path.join(utils.get_dir(), model_dir, 'genotype.txt')
    if os.path.isfile(genotype_path):
        with open(genotype_path, "r") as f:
            genotype = eval(f.read())
    else:
        genoname = os.path.join(utils.get_dir(), model_dir, 'genoname.txt')
        if os.path.isfile(genoname):
            with open(genoname, "r") as f:
                args.arch = f.read()
            genotype = eval("genotypes.%s" % args.arch)
        else:
            genotype = eval("genotypes.BATH")

    # Single-output regression net over 4-channel inputs.
    model = Network(args.init_channels, 1, args.layers, args.auxiliary, genotype, input_channels=4)
    model = model.cuda()
    # Was a bare print(); use logging like the rest of this function.
    logging.info('%s', os.path.join(utils.get_dir(), args.model_path))
    utils.load(model, os.path.join(utils.get_dir(), args.model_path))

    logging.info("param size = %fMB", utils.count_parameters_in_MB(model))

    criterion = nn.MSELoss()
    criterion = criterion.cuda()

    model.drop_path_prob = args.drop_path_prob

    def _evaluate(csv_name, root_dir, to_trim, tag, out_name):
        # One pass over a single regional test set; writes predictions to CSV.
        data = utils.BathymetryDataset(args, csv_name, root_dir=root_dir,
                                       to_trim=to_trim, to_filter=False)
        queue = torch.utils.data.DataLoader(
            data, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=2)
        test_obj, targets, preds = infer(queue, model, criterion, args.depth_normalization)
        logging.info('test_obj %s %f', tag, test_obj)
        data.write_results(targets, preds, os.path.join(args.save, out_name))

    _evaluate("../29TNE.csv", "dataset/bathymetry/29TNE/dataset_29TNE",
              "/tmp/pbs.6233542.admin01/tmp_portugal/", 'tne', 'tne_results.csv')
    _evaluate("../29SMD.csv", "dataset/bathymetry/29SMD/dataset_29SMD",
              "/tmp/pbs.6233565.admin01/tmp_portugal/", 'smd', 'smd_results.csv')
# Ejemplo n.º 2
def main():
    """Resume training from the checkpoint in args.output_dir and run to args.epochs."""
    if not torch.cuda.is_available():
        logging.info('No GPU found!')
        sys.exit(1)

    # Fully deterministic setup: seed every RNG and force deterministic cuDNN.
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    cudnn.enabled = True
    cudnn.benchmark = False
    cudnn.deterministic = True

    # Total optimizer steps over the whole run (50k-sample training set).
    args.steps = int(np.ceil(50000 / args.batch_size)) * args.epochs
    logging.info("Args = %s", args)

    # Restore any saved state: model/optimizer weights plus progress counters.
    _, model_state_dict, epoch, step, optimizer_state_dict, best_acc_top1 = utils.load(args.output_dir)
    build_fn = get_builder(args.dataset)
    train_queue, valid_queue, model, train_criterion, eval_criterion, optimizer, scheduler = build_fn(model_state_dict, optimizer_state_dict, epoch=epoch-1)

    while epoch < args.epochs:
        scheduler.step()
        logging.info('epoch %d lr %e', epoch, scheduler.get_lr()[0])

        train_acc, train_obj, step = train(train_queue, model, optimizer, step, train_criterion)
        logging.info('train_acc %f', train_acc)

        valid_acc_top1, valid_obj = valid(valid_queue, model, eval_criterion)
        logging.info('valid_acc %f', valid_acc_top1)

        epoch += 1

        # Track the best top-1 accuracy so far and flag it for checkpointing.
        is_best = valid_acc_top1 > best_acc_top1
        if is_best:
            best_acc_top1 = valid_acc_top1
        utils.save(args.output_dir, args, model, epoch, step, optimizer, best_acc_top1, is_best)
# Ejemplo n.º 3
def main():
    """Load a trained network, report FLOPs/params, and score it on the CIFAR-10 test set."""
    if not torch.cuda.is_available():
        logging.info('no gpu device available')
        sys.exit(1)

    # Seed all RNG sources and pin the requested GPU for a reproducible run.
    np.random.seed(args.seed)
    torch.cuda.set_device(args.gpu)
    cudnn.benchmark = True
    torch.manual_seed(args.seed)
    cudnn.enabled = True
    torch.cuda.manual_seed(args.seed)
    logging.info('gpu device = %d' % args.gpu)
    logging.info("args = %s", args)

    # Instantiate the architecture named by --arch and measure its cost
    # on a single dummy CIFAR image before moving it to the GPU.
    genotype = eval("genotypes.%s" % args.arch)
    model = Network(args.init_channels, CIFAR_CLASSES, args.layers, args.auxiliary, genotype)
    model.drop_path_prob = args.drop_path_prob * 0 / args.epochs  # zero drop-path for profiling
    dummy_input = torch.randn(1, 3, 32, 32)
    flops, params = profile(model, inputs=(dummy_input, ), verbose=False)
    logging.info('flops = %fM', flops / 1e6)
    logging.info("param size = %fMB", utils.count_parameters_in_MB(model))

    model = model.cuda()
    utils.load(model, args.model_path)

    criterion = nn.CrossEntropyLoss().cuda()

    # Standard CIFAR-10 test split with the evaluation transform only.
    _, test_transform = utils._data_transforms_cifar10(args)
    test_data = dset.CIFAR10(root=args.data, train=False, download=True, transform=test_transform)
    test_queue = torch.utils.data.DataLoader(
        test_data, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=2)

    # Evaluate with the configured drop-path probability, gradients disabled.
    model.drop_path_prob = args.drop_path_prob
    with torch.no_grad():
        test_acc, test_obj = infer(test_queue, model, criterion)
    logging.info('test_acc %f', test_acc)
def main():
    """Run gradient-based (DARTS/GDAS) architecture search on CIFAR-10.

    Reads the module-level ``args``; optionally logs metrics and the searched
    genotype to Weights & Biases when ``is_wandb_used`` is set.
    """
    if is_wandb_used:
        # Run name encodes run id, epochs, lr and both regularizer lambdas.
        wandb.init(project="automl-gradient-based-nas",
                   name="r" + str(args.run_id) + "-e" + str(args.epochs) +
                   "-lr" + str(args.learning_rate) + "-l(" +
                   str(args.lambda_train_regularizer) + "," +
                   str(args.lambda_valid_regularizer) + ")",
                   config=args,
                   entity="automl")

    global is_multi_gpu

    # args.gpu is a comma-separated list of device ids, e.g. "0,1,2".
    gpus = [int(i) for i in args.gpu.split(',')]
    logging.info('gpus = %s' % gpus)
    if not torch.cuda.is_available():
        logging.info('no gpu device available')
        sys.exit(1)

    # Seed all RNG sources for reproducibility.
    np.random.seed(args.seed)

    cudnn.benchmark = True
    torch.manual_seed(args.seed)
    cudnn.enabled = True
    torch.cuda.manual_seed(args.seed)
    logging.info('gpu device = %s' % args.gpu)
    logging.info("args = %s", args)

    criterion = nn.CrossEntropyLoss()
    criterion = criterion.cuda()

    # default: args.init_channels = 16, CIFAR_CLASSES = 10, args.layers = 8
    # Pick the supernetwork variant for the requested search method; any
    # unknown method falls back to the DARTS network.
    if args.arch_search_method == "DARTS":
        model = Network(args.init_channels, CIFAR_CLASSES, args.layers,
                        criterion)
    elif args.arch_search_method == "GDAS":
        model = Network_GumbelSoftmax(args.init_channels, CIFAR_CLASSES,
                                      args.layers, criterion)
    else:
        model = Network(args.init_channels, CIFAR_CLASSES, args.layers,
                        criterion)

    if len(gpus) > 1:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        # dim = 0 [30, xxx] -> [10, ...], [10, ...], [10, ...] on 3 GPUs
        model = nn.DataParallel(model)
        is_multi_gpu = True

    model.cuda()
    # "saved_models" is the sentinel default meaning "no checkpoint to load".
    if args.model_path != "saved_models":
        utils.load(model, args.model_path)

    # Split the parameters by identity: architecture alphas are updated by
    # the Architect, every other weight by the SGD optimizer below.
    arch_parameters = model.module.arch_parameters(
    ) if is_multi_gpu else model.arch_parameters()
    arch_params = list(map(id, arch_parameters))

    parameters = model.module.parameters(
    ) if is_multi_gpu else model.parameters()
    weight_params = filter(lambda p: id(p) not in arch_params, parameters)

    logging.info("param size = %fMB", utils.count_parameters_in_MB(model))

    optimizer = torch.optim.SGD(
        weight_params,  # model.parameters(),
        args.learning_rate,
        momentum=args.momentum,
        weight_decay=args.weight_decay)

    train_transform, valid_transform = utils._data_transforms_cifar10(args)

    # will cost time to download the data
    train_data = dset.CIFAR10(root=args.data,
                              train=True,
                              download=True,
                              transform=train_transform)

    # Bi-level split of the training set: the first args.train_portion trains
    # the weights, the remainder drives the architecture updates.
    num_train = len(train_data)
    indices = list(range(num_train))
    split = int(np.floor(args.train_portion * num_train))  # split index

    train_queue = torch.utils.data.DataLoader(
        train_data,
        batch_size=args.batch_size * len(gpus),
        sampler=torch.utils.data.sampler.SubsetRandomSampler(indices[:split]),
        pin_memory=True,
        num_workers=2)

    valid_queue = torch.utils.data.DataLoader(
        train_data,
        batch_size=args.batch_size * len(gpus),
        sampler=torch.utils.data.sampler.SubsetRandomSampler(
            indices[split:num_train]),
        pin_memory=True,
        num_workers=2)

    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, float(args.epochs), eta_min=args.learning_rate_min)

    architect = Architect(model, criterion, args)

    best_accuracy = 0
    # Maps cnn_count (normal*10 + reduce) -> best valid accuracy seen for it.
    best_accuracy_different_cnn_counts = dict()

    if is_wandb_used:
        table = wandb.Table(columns=["Epoch", "Searched Architecture"])

    for epoch in range(args.epochs):
        lr = scheduler.get_lr()[0]
        logging.info('epoch %d lr %e', epoch, lr)

        # training
        train_acc, train_obj, train_loss = train(epoch, train_queue,
                                                 valid_queue, model, architect,
                                                 criterion, optimizer, lr)
        logging.info('train_acc %f', train_acc)
        if is_wandb_used:
            wandb.log({"searching_train_acc": train_acc, "epoch": epoch})
            wandb.log({"searching_train_loss": train_loss, "epoch": epoch})

        # validation
        with torch.no_grad():
            valid_acc, valid_obj, valid_loss = infer(valid_queue, model,
                                                     criterion)
        logging.info('valid_acc %f', valid_acc)

        scheduler.step()

        if is_wandb_used:
            wandb.log({"searching_valid_acc": valid_acc, "epoch": epoch})
            wandb.log({"searching_valid_loss": valid_loss, "epoch": epoch})
            wandb.log({
                "search_train_valid_acc_gap": train_acc - valid_acc,
                "epoch": epoch
            })
            wandb.log({
                "search_train_valid_loss_gap": train_loss - valid_loss,
                "epoch": epoch
            })

        # save the structure
        genotype, normal_cnn_count, reduce_cnn_count = model.module.genotype(
        ) if is_multi_gpu else model.genotype()

        # early stopping
        # NOTE(review): stops on the exact shape (n=6, r=0) — presumably a
        # degenerate/collapsed architecture; confirm the criterion's meaning.
        if args.early_stopping == 1:
            if normal_cnn_count == 6 and reduce_cnn_count == 0:
                break

        print("(n:%d,r:%d)" % (normal_cnn_count, reduce_cnn_count))
        print(
            F.softmax(model.module.alphas_normal
                      if is_multi_gpu else model.alphas_normal,
                      dim=-1))
        print(
            F.softmax(model.module.alphas_reduce
                      if is_multi_gpu else model.alphas_reduce,
                      dim=-1))
        logging.info('genotype = %s', genotype)
        if is_wandb_used:
            # NOTE(review): step=epoch - 1 is -1 on the first epoch; wandb
            # rejects negative steps — confirm the intended step offset.
            wandb.log({"genotype": str(genotype)}, step=epoch - 1)
            table.add_data(str(epoch), str(genotype))
            wandb.log({"Searched Architecture": table})

            # save the cnn architecture according to the CNN count
            cnn_count = normal_cnn_count * 10 + reduce_cnn_count
            wandb.log({
                "searching_cnn_count(%s)" % cnn_count: valid_acc,
                "epoch": epoch
            })
            # Record the best accuracy (and its epoch) per architecture shape.
            if cnn_count not in best_accuracy_different_cnn_counts.keys():
                best_accuracy_different_cnn_counts[cnn_count] = valid_acc
                summary_key_cnn_structure = "best_acc_for_cnn_structure(n:%d,r:%d)" % (
                    normal_cnn_count, reduce_cnn_count)
                wandb.run.summary[summary_key_cnn_structure] = valid_acc

                summary_key_best_cnn_structure = "epoch_of_best_acc_for_cnn_structure(n:%d,r:%d)" % (
                    normal_cnn_count, reduce_cnn_count)
                wandb.run.summary[summary_key_best_cnn_structure] = epoch
            else:
                if valid_acc > best_accuracy_different_cnn_counts[cnn_count]:
                    best_accuracy_different_cnn_counts[cnn_count] = valid_acc
                    summary_key_cnn_structure = "best_acc_for_cnn_structure(n:%d,r:%d)" % (
                        normal_cnn_count, reduce_cnn_count)
                    wandb.run.summary[summary_key_cnn_structure] = valid_acc

                    summary_key_best_cnn_structure = "epoch_of_best_acc_for_cnn_structure(n:%d,r:%d)" % (
                        normal_cnn_count, reduce_cnn_count)
                    wandb.run.summary[summary_key_best_cnn_structure] = epoch

            # Checkpoint whenever the overall best validation accuracy improves.
            if valid_acc > best_accuracy:
                best_accuracy = valid_acc
                wandb.run.summary["best_valid_accuracy"] = valid_acc
                wandb.run.summary["epoch_of_best_accuracy"] = epoch
                utils.save(model, os.path.join(wandb.run.dir, 'weights.pt'))
# Ejemplo n.º 5
def main():
    """Train a found architecture read from the archs config file.

    Looks up the genotype for (space, dataset, search settings, task id) in
    ``args.archs_config_file``, trains it for ``args.epochs`` and dumps the
    train/valid error history as JSON under ``args.save``.
    """
    if not torch.cuda.is_available():
        logging.info('no gpu device available')
        sys.exit(1)

    # Seed all RNG sources; set_device only exists on the torch-1.x path.
    np.random.seed(args.seed)
    if TORCH_VERSION.startswith('1'):
        torch.cuda.set_device(args.gpu)
    cudnn.benchmark = True
    torch.manual_seed(args.seed)
    cudnn.enabled = True
    torch.cuda.manual_seed(args.seed)
    logging.info('gpu device = %d' % args.gpu)
    logging.info("args = %s", args)

    # load search configuration file holding the found architectures
    configuration = '_'.join([args.space, args.dataset])
    settings = '_'.join([str(args.search_dp), str(args.search_wd)])
    with open(args.archs_config_file, 'r') as f:
        # safe_load: yaml.load() without an explicit Loader is deprecated and
        # a TypeError on PyYAML >= 6; the config is plain data, so SafeLoader
        # suffices (the arch itself is a string eval'd below).
        cfg = yaml.safe_load(f)
    arch = cfg[configuration][settings][args.search_task_id]

    print(arch)
    # NOTE(review): eval() of the arch string assumes a trusted config file.
    genotype = eval(arch)
    model = Network(args.init_channels, args.n_classes, args.layers,
                    args.auxiliary, genotype)
    if TORCH_VERSION.startswith('1'):
        model = model.to(device)
    else:
        model = model.cuda()

    if args.model_path is not None:
        utils.load(model, args.model_path, genotype)

    logging.info("param size = %fMB", utils.count_parameters_in_MB(model))

    criterion = nn.CrossEntropyLoss()
    if TORCH_VERSION.startswith('1'):
        criterion = criterion.to(device)
    else:
        criterion = criterion.cuda()

    optimizer = torch.optim.SGD(model.parameters(),
                                args.learning_rate,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    scheduler = CosineAnnealingLR(optimizer, float(args.epochs))

    train_queue, valid_queue, _, _ = helper.get_train_val_loaders()

    # Per-epoch history; accuracies are stored as error rates (100 - acc).
    errors_dict = {
        'train_acc': [],
        'train_loss': [],
        'valid_acc': [],
        'valid_loss': []
    }

    for epoch in range(args.epochs):
        scheduler.step()
        logging.info('epoch %d lr %e', epoch, scheduler.get_lr()[0])
        # Linearly ramp drop-path probability over training.
        model.drop_path_prob = args.drop_path_prob * epoch / args.epochs

        # training
        train_acc, train_obj = train(train_queue, model, criterion, optimizer)
        logging.info('train_acc %f', train_acc)

        # evaluation
        valid_acc, valid_obj = infer(valid_queue, model, criterion)
        logging.info('valid_acc %f', valid_acc)

        # update the errors dictionary
        errors_dict['train_acc'].append(100 - train_acc)
        errors_dict['train_loss'].append(train_obj)
        errors_dict['valid_acc'].append(100 - valid_acc)
        errors_dict['valid_loss'].append(valid_obj)

    # Persist the error curves keyed by search and evaluation task ids.
    with codecs.open(os.path.join(
            args.save, 'errors_{}_{}.json'.format(args.search_task_id,
                                                  args.task_id)),
                     'w',
                     encoding='utf-8') as file:
        json.dump(errors_dict, file, separators=(',', ':'))

    utils.write_yaml_results_eval(args, args.results_test, 100 - valid_acc)
# Ejemplo n.º 6
def main():
    """Evaluate a trained model on the test split selected by args.task.

    Supports CIFAR-10 (default), CIFAR-100, and the coarse-to-fine
    CIFAR100cf variant with fine-label filtering.
    """
    if not torch.cuda.is_available():
        logging.info('no gpu device available')
        sys.exit(1)

    # Seed all RNG sources and pin/clear the requested GPU.
    np.random.seed(args.seed)
    torch.cuda.set_device(args.gpu)
    torch.cuda.empty_cache()
    cudnn.benchmark = True
    torch.manual_seed(args.seed)
    cudnn.enabled = True
    torch.cuda.manual_seed(args.seed)
    logging.info('gpu device = %d' % args.gpu)
    logging.info("args = %s", args)

    # Recover the genotype stored next to the checkpoint; fall back to a
    # named genotype, then to the ADMM default.
    # NOTE(review): eval() on checkpoint-directory files assumes they are trusted.
    genotype_path = os.path.join(utils.get_dir(), os.path.split(args.model_path)[0], 'genotype.txt')
    print(genotype_path)
    if os.path.isfile(genotype_path):
        with open(genotype_path, "r") as f:
            genotype = eval(f.read())
    else:
        genoname = os.path.join(utils.get_dir(), os.path.split(args.model_path)[0], 'genoname.txt')
        if os.path.isfile(genoname):
            with open(genoname, "r") as f:
                args.arch = f.read()
            genotype = eval("genotypes.%s" % args.arch)
        else:
            genotype = eval("genotypes.ADMM")

    model = Network(args.init_channels, CIFAR_CLASSES, args.layers, args.auxiliary, genotype)
    model = model.cuda()
    utils.load(model, os.path.join(utils.get_dir(), args.model_path))

    logging.info("param size = %fMB", utils.count_parameters_in_MB(model))

    criterion = nn.CrossEntropyLoss()
    criterion = criterion.cuda()

    datapath = os.path.join(utils.get_dir(), args.data)

    # Build the test loader for the requested task.  (The unconditional
    # CIFAR-10 dataset/loader previously built here was dead code — every
    # branch below rebuilds test_queue — and triggered a useless download.)
    if args.task == "CIFAR100cf":
        _, test_transform = utils._data_transforms_cifar100(args)
        test_data = utils.CIFAR100C2F(root=datapath, train=False, download=True, transform=test_transform)

        # Keep only the fine-label subset requested for the coarse-to-fine task.
        test_indices = test_data.filter_by_fine(args.test_filter)
        test_queue = torch.utils.data.DataLoader(
            torch.utils.data.Subset(test_data, test_indices), batch_size=args.batch_size,
            shuffle=False, pin_memory=True, num_workers=2)

        # TODO: extend each epoch or multiply number of epochs by 20%*args.class_filter
    else:
        if args.task == "CIFAR100":
            _, test_transform = utils._data_transforms_cifar100(args)
            test_data = dset.CIFAR100(root=datapath, train=False, download=True, transform=test_transform)
        else:
            _, test_transform = utils._data_transforms_cifar10(args)
            test_data = dset.CIFAR10(root=datapath, train=False, download=True, transform=test_transform)

        test_queue = torch.utils.data.DataLoader(
            test_data, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=2)

    model.drop_path_prob = args.drop_path_prob
    test_acc, test_obj = infer(test_queue, model, criterion)
    logging.info('test_acc %f', test_acc)
# Ejemplo n.º 7
def segmentation(path):
    """Run the pretrained segmentation network over the dataset rooted at *path*.

    Builds its configuration from argparse defaults (sys.argv is ignored),
    overrides the data location with *path*, loads the checkpointed model and
    returns the result of ``predict`` over the 'predict' subset.
    """
    parser = argparse.ArgumentParser("BGB_dataset")
    parser.add_argument('--data',
                        type=str,
                        default='/home/zhangmingwei/srbrain/media',
                        help='location of the data for model')
    parser.add_argument('--output',
                        type=str,
                        default='/home/zhangmingwei/srbrain/media/output',
                        help='location of the output')
    parser.add_argument('--data_folder_name',
                        type=str,
                        default='image',
                        help='data_folder_name')
    parser.add_argument('--target_folder_name',
                        type=str,
                        default='label',
                        help='target_folder_name')
    parser.add_argument('--input_size',
                        type=int,
                        default=512,
                        help='the size of the dataset')
    parser.add_argument('--nb_classes',
                        type=int,
                        default=2,
                        help='the classes of the dataset')
    parser.add_argument('--batch_size', type=int, default=1, help='batch size')
    parser.add_argument('--learning_rate',
                        type=float,
                        default=0.025,
                        help='init learning rate')
    parser.add_argument('--momentum', type=float, default=0.9, help='momentum')
    parser.add_argument('--weight_decay',
                        type=float,
                        default=3e-4,
                        help='weight decay')
    parser.add_argument('--report_freq',
                        type=float,
                        default=50,
                        help='report frequency')
    parser.add_argument('--gpu', type=int, default=0, help='gpu device id')
    parser.add_argument('--epochs',
                        type=int,
                        default=100,
                        help='num of training epochs')
    parser.add_argument('--init_channels',
                        type=int,
                        default=12,
                        help='num of init channels')
    parser.add_argument('--layers',
                        type=int,
                        default=8,
                        help='total number of layers')
    parser.add_argument(
        '--model_path',
        type=str,
        default=
        '/home/zhangmingwei/NAS/NAS-RSI1/train-WHU_train-20200119-114304/weights.pt',
        help='path of pretrained model')
    parser.add_argument('--auxiliary',
                        action='store_true',
                        default=False,
                        help='use auxiliary tower')
    parser.add_argument('--drop_path_prob',
                        type=float,
                        default=0.2,
                        help='drop path probability')
    parser.add_argument('--save',
                        type=str,
                        default='DARTS',
                        help='experiment name')
    parser.add_argument('--seed', type=int, default=0, help='random seed')
    parser.add_argument('--arch',
                        type=str,
                        default='DARTS_WHU',
                        help='which architecture to use')
    parser.add_argument('--grad_clip',
                        type=float,
                        default=5,
                        help='gradient clipping')
    # BUG FIX: the original declared two required positional arguments
    # ('runserver' and '0.0.0.0:8001'), apparently to swallow a web
    # framework's CLI arguments.  They made any direct invocation fail, so
    # they were removed; parse_args([]) ignores sys.argv entirely and uses
    # the defaults above, which this function then overrides.
    args = parser.parse_args([])

    # Resolve the compute device; --gpu -1 selects CPU.
    if args.gpu == -1:
        device = torch.device('cpu')
    else:
        device = torch.device('cuda:{}'.format(args.gpu))
    args.data = path

    # Seed all RNG sources for reproducible prediction.
    np.random.seed(args.seed)
    cudnn.benchmark = True
    torch.manual_seed(args.seed)
    cudnn.enabled = True
    torch.cuda.manual_seed(args.seed)

    # Build the architecture named by --arch and load the pretrained weights.
    genotype = eval("genotypes.%s" % args.arch)
    model = Network(args.init_channels, args.nb_classes, args.layers,
                    args.auxiliary, genotype)
    model = model.to(device)
    utils.load(model, args.model_path)

    test_data = MyDataset(args=args, subset='predict')
    test_queue = torch.utils.data.DataLoader(test_data,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             pin_memory=True,
                                             num_workers=2)
    model.drop_path_prob = args.drop_path_prob
    result = predict(args, test_queue, model)

    return result
# Ejemplo n.º 8
def main():
    """Train the architecture named by args.arch on CIFAR-10, logging to wandb."""
    wandb.init(
        project="automl-gradient-based-nas",
        name=f"{args.arch}-lr{args.learning_rate}",
        config=args,
        entity="automl"
    )
    wandb.config.update(args)  # adds all of the arguments as config variables

    global is_multi_gpu
    if not torch.cuda.is_available():
        logging.info('no gpu device available')
        sys.exit(1)

    # Seed all RNG sources; args.gpu is a comma-separated device-id list.
    np.random.seed(args.seed)
    gpus = [int(i) for i in args.gpu.split(',')]
    logging.info('gpus = %s' % gpus)

    cudnn.benchmark = True
    torch.manual_seed(args.seed)
    cudnn.enabled = True
    torch.cuda.manual_seed(args.seed)
    logging.info('gpu device = %s' % args.gpu)
    logging.info("args = %s", args)

    # Build the network from the named genotype; wrap for multi-GPU if needed.
    genotype = eval("genotypes.%s" % args.arch)
    model = Network(args.init_channels, CIFAR_CLASSES, args.layers, args.auxiliary, genotype)
    if len(gpus) > 1:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        # dim = 0 [30, xxx] -> [10, ...], [10, ...], [10, ...] on 3 GPUs
        model = nn.DataParallel(model)
        is_multi_gpu = True

    model.cuda()
    # "saved_models" is the sentinel default meaning "no checkpoint to load".
    if args.model_path != "saved_models":
        utils.load(model, args.model_path)

    weight_params = model.module.parameters() if is_multi_gpu else model.parameters()

    logging.info("param size = %fMB", utils.count_parameters_in_MB(model))
    wandb.run.summary["param_size"] = utils.count_parameters_in_MB(model)

    criterion = nn.CrossEntropyLoss()
    criterion = criterion.cuda()
    optimizer = torch.optim.SGD(
        weight_params,  # model.parameters(),
        args.learning_rate,
        momentum=args.momentum,
        weight_decay=args.weight_decay
    )

    # CIFAR-10: train split for training, test split for validation.
    train_transform, valid_transform = utils._data_transforms_cifar10(args)
    train_data = dset.CIFAR10(root=args.data, train=True, download=True, transform=train_transform)
    valid_data = dset.CIFAR10(root=args.data, train=False, download=True, transform=valid_transform)

    train_queue = torch.utils.data.DataLoader(
        train_data, batch_size=args.batch_size, shuffle=True, pin_memory=True, num_workers=2)
    valid_queue = torch.utils.data.DataLoader(
        valid_data, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=2)

    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, float(args.epochs),
                                                           eta_min=args.learning_rate_min)

    best_accuracy = 0

    for epoch in range(args.epochs):
        scheduler.step()
        logging.info('epoch %d lr %e', epoch, scheduler.get_lr()[0])
        # Linearly ramp drop-path probability over training.
        model.drop_path_prob = args.drop_path_prob * epoch / args.epochs

        train_acc, train_obj = train(train_queue, model, criterion, optimizer)
        logging.info('train_acc %f', train_acc)
        wandb.log({"evaluation_train_acc": train_acc, 'epoch': epoch})

        valid_acc, valid_obj = infer(valid_queue, model, criterion)
        logging.info('valid_acc %f', valid_acc)
        wandb.log({"evaluation_valid_acc": valid_acc, 'epoch': epoch})

        # Keep a separate checkpoint of the best-so-far weights.
        if valid_acc > best_accuracy:
            wandb.run.summary["best_valid_accuracy"] = valid_acc
            wandb.run.summary["epoch_of_best_accuracy"] = epoch
            best_accuracy = valid_acc
            utils.save(model, os.path.join(wandb.run.dir, 'weights-best.pt'))

        utils.save(model, os.path.join(wandb.run.dir, 'weights.pt'))
# Ejemplo n.º 9
def main():
    """Train (or with --eval, only evaluate) the model given by ``args.model``.

    Experiment outputs go to a timestamped directory under the checkpoint
    directory when resuming, otherwise under logs/.
    """
    # Timestamped experiment directory; snapshot the python sources into it.
    subdir = 'eval-{}-{}'.format(args.save, time.strftime("%Y%m%d-%H%M%S"))
    args.save = (Path(args.load_path) if args.load_path else Path('logs')) / subdir
    utils.create_exp_dir(args.save, scripts_to_save=glob.glob('*.py'))

    # Log to stdout and to <save>/log.txt.
    log_format = '%(asctime)s %(message)s'
    logging.basicConfig(stream=sys.stdout,
                        level=logging.INFO,
                        format=log_format,
                        datefmt='%m/%d %I:%M:%S %p')
    fh = logging.FileHandler(args.save / 'log.txt')
    fh.setFormatter(logging.Formatter(log_format))
    logging.getLogger().addHandler(fh)

    # Seed all RNG sources for reproducibility.
    np.random.seed(args.seed)
    cudnn.benchmark = True
    torch.manual_seed(args.seed)
    cudnn.enabled = True
    torch.cuda.manual_seed(args.seed)
    logging.info("args = %s", args)

    # NOTE(review): eval() of args.model executes arbitrary code from the
    # command line; acceptable only for trusted, local invocations.
    model = eval(args.model)
    # NOTE(review): `if args.gpu:` treats GPU id 0 as "no gpu" if args.gpu is
    # an int device id — confirm whether args.gpu is a boolean flag.
    if args.gpu:
        model = model.cuda()

    if args.load_path:
        utils.load(model, os.path.join(args.load_path, 'weights.pt'))
        print("loaded")

    # Keep a handle on the unwrapped network: drop_path_prob lives on it,
    # and it is what gets checkpointed below.
    direct_model = model
    if args.gpu:
        model = torch.nn.DataParallel(model)

    logging.info("param size = %fMB", utils.count_parameters_in_MB(model))

    optimizer = torch.optim.SGD(model.parameters(),
                                args.learning_rate,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # CIFAR-10: train split for training, test split for validation.
    train_transform, valid_transform = utils._data_transforms_cifar10(args)
    train_data = dset.CIFAR10(root=args.data,
                              train=True,
                              download=True,
                              transform=train_transform)
    valid_data = dset.CIFAR10(root=args.data,
                              train=False,
                              download=True,
                              transform=valid_transform)

    train_queue = torch.utils.data.DataLoader(train_data,
                                              batch_size=args.batch_size,
                                              shuffle=True,
                                              pin_memory=True,
                                              num_workers=args.num_workers)

    valid_queue = torch.utils.data.DataLoader(valid_data,
                                              batch_size=args.batch_size,
                                              shuffle=False,
                                              pin_memory=True,
                                              num_workers=args.num_workers)

    if args.eval:
        # Evaluation-only path: no drop-path, one pass over the valid set.
        direct_model.drop_path_prob = 0
        valid_acc, valid_obj = infer(valid_queue, model, args.gpu)
        logging.info('valid_acc %f', valid_acc)
        return

    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, float(args.epochs))

    for epoch in range(args.start_epoch, args.epochs):
        scheduler.step(epoch)
        logging.info('epoch %d lr %e', epoch, scheduler.get_lr()[0])
        # Linearly ramp drop-path probability over training.
        direct_model.drop_path_prob = args.drop_path_prob * epoch / args.epochs

        train_acc, train_obj = train(train_queue, model, optimizer, args.gpu)
        logging.info('train_acc %f', train_acc)

        valid_acc, valid_obj = infer(valid_queue, model, args.gpu)
        logging.info('valid_acc %f', valid_acc)

        if epoch >= args.epochs - 50 or epoch % args.save_frequency == 0:
            # BUG FIX: the original saved model.module, which only exists when
            # the model was wrapped in DataParallel (args.gpu truthy) and
            # raised AttributeError otherwise; direct_model is the unwrapped
            # network in both cases.
            utils.save(direct_model,
                       os.path.join(args.save, f'weights_{epoch}.pt'))