Example #1
def main():
    import argparse
    import numpy as np
    parser = argparse.ArgumentParser()
    parser.add_argument('--batch_size', type=int, default=16)
    parser.add_argument('--prefix', type=str, default='default')
    parser.add_argument('--checkpoint', type=str, default=None)
    parser.add_argument('--dataset', type=str, default='ImageNet', choices=['ImageNet'])
    parser.add_argument('--learning_rate', type=float, default=1e-4)
    parser.add_argument('--lr_weight_decay', action='store_true', default=False)
    config = parser.parse_args()

    if config.dataset == 'ImageNet':
        import datasets.ImageNet as dataset
    elif config.dataset == 'SVHN':
        import datasets.svhn as dataset
    elif config.dataset == 'CIFAR10':
        import datasets.cifar10 as dataset
    else:
        raise ValueError(config.dataset)

    dataset_train, dataset_test = dataset.create_default_splits()

    image, _, label, _ = dataset_train.get_data(dataset_train.ids[0], dataset_train.ids[0])
    config.data_info = np.concatenate([np.asarray(image.shape), np.asarray(label.shape)])

    trainer = Trainer(config,
                      dataset_train, dataset_test)

    log.warning("dataset: %s, learning_rate: %f",
                config.dataset, config.learning_rate)
    trainer.train(dataset_train)
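
config.data_info above simply packs the per-sample image and label shapes into one array so that downstream code can rebuild its input tensors. A small illustration with hypothetical shapes (the actual loader defines them):

import numpy as np

image = np.zeros((224, 224, 3))   # hypothetical 224x224 RGB image
label = np.zeros((1000,))         # hypothetical 1000-way one-hot label
print(np.concatenate([np.asarray(image.shape), np.asarray(label.shape)]))
# -> [ 224  224    3 1000]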
Example #2
def main():
    global args, best_prec1
    args = parser.parse_args()
    print(args)
    
    # load the pre-trained weights
    # model = torch.nn.DataParallel(model).cuda()
    model = torch.nn.DataParallel(NNet(regr=True))
    model.load_state_dict(torch.load('models/model_gan.pth.tar', map_location=torch.device('cpu'))['state_dict_G'])
    accuracy = True
    total = 0
    acc = 0 
    test_dataset = ImageNet(args.test_root, paths=True)
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=False,num_workers=8, pin_memory=True)
    print("=> Loaded data, length = ", len(test_dataset))
    model.eval()
    for i, (img, target, imgInfo) in enumerate(test_loader):
        imgPath = imgInfo[0]
        dir_name, file_name = imgPath.split('/val/')[1].split('/')
        if img is None:
            continue
        # var = Variable(img.float(), requires_grad=True).cuda()
        var = Variable(img.float(), requires_grad=True)
        output = model(var)
        # decoded_output = utils.decode(output)
        decoded_output = output
        lab = np.zeros((256,256,3))
        # lab[:,:,0] = cv2.resize((img+50.0).squeeze(0).squeeze(0).numpy(), (256,256))
        # lab[:,:,1:] = cv2.resize(decoded_output.squeeze(0).detach().numpy().transpose((1,2,0)),(256,256))
        lab[:,:,0] = resize((img+50.0).squeeze(0).squeeze(0).numpy(),(256,256))
        lab[:,:,1:] = resize(decoded_output.squeeze(0).detach().numpy().transpose((1,2,0)),(256,256))
        rgb = lab2rgb(lab)
        try:
            plt.imsave("img/imagenet-mini/generated-gan/"+ dir_name+ '/'+ file_name, rgb)
            #plt.savefig("img/imagenet-mini/generated/"+ dir_name+ '/'+ file_name)
        except FileNotFoundError:
            os.mkdir("img/imagenet-mini/generated-gan/"+dir_name)
            plt.imsave("img/imagenet-mini/generated-gan/"+ dir_name+ '/'+ file_name, rgb)
            #plt.savefig("img/imagenet-mini/generated/"+ dir_name+ '/'+ file_name)
        print("Forwarded image number: " + str(i+1))
        total += 1
        if accuracy:
            # Per-pixel accuracy: for each output pixel, count how many error
            # thresholds in 0..150 the L2 distance to the target falls below.
            count = 0
            for j in range(56):
                for k in range(56):
                    err = np.linalg.norm(target[0,:,j,k].detach().numpy() - decoded_output[0,:,j,k].detach().numpy())
                    count += (err < np.arange(151)).sum()
            print('Accuracy is: ', count/(150*56*56))
            acc += count/(150*56*56)

    print('Acc is: ', acc/total)
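
The per-pixel threshold accuracy above loops over all 56x56 positions in Python; a vectorized sketch of the same metric (assuming the same (C, 56, 56) target/output tensors and the 0..150 threshold sweep) avoids the double loop:

import numpy as np

def pixel_threshold_accuracy(target, output, thresholds=np.arange(151)):
    # target, output: numpy arrays of shape (C, H, W), e.g. (2, 56, 56)
    err = np.linalg.norm(target - output, axis=0)        # (H, W) L2 error per pixel
    hits = (err[..., None] < thresholds).sum()           # thresholds passed, summed over pixels
    return hits / (150 * err.size)                       # same normalization as the loop above

# e.g. pixel_threshold_accuracy(target[0].detach().numpy(), decoded_output[0].detach().numpy())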
Example #3
def main():
    import argparse
    import numpy as np
    parser = argparse.ArgumentParser()
    parser.add_argument('--batch_size', type=int, default=16)
    parser.add_argument('--checkpoint_path', type=str)
    parser.add_argument('--train_dir', type=str)
    parser.add_argument('--dataset', type=str, default='ImageNet', choices=['ImageNet'])
    parser.add_argument('--data_id', nargs='*', default=None)
    config = parser.parse_args()

    if config.dataset == 'ImageNet':
        import datasets.ImageNet as dataset
    else:
        raise ValueError(config.dataset)

    _, dataset = dataset.create_default_splits(ratio=0.999)

    image, _, label, _ = dataset.get_data(dataset.ids[0], dataset.ids[0])
    config.data_info = np.concatenate([np.asarray(image.shape), np.asarray(label.shape)])

    evaler = Evaler(config, dataset)

    log.warning("dataset: %s", dataset)
    evaler.eval_run()
Example #4
def main(config):
    tf.reset_default_graph()  # make sure we start from a clean graph

    log_dir = config.log_dir
    learning_rate = config.lr
    va_batch_size = 10

    print('Setup dataset')

    tr_provider = ImageNet(num_threads=config.num_threads)
    va_provider = ImageNet(num_threads=config.num_threads)
    tr_dataset = tr_provider.get_dataset(config.imagenet_dir, phase='train', batch_size=config.batch_size, 
                                is_training=True, shuffle=True)
    va_dataset = va_provider.get_dataset(config.imagenet_dir, phase='val', batch_size=va_batch_size, 
                                is_training=False, shuffle=True, seed=1234)
    tr_num_examples = tr_provider.num_examples
    va_num_examples = min(va_provider.num_examples, 10000)
    print('#examples = {}, {}'.format(tr_num_examples, va_num_examples))

    handle = tf.placeholder(tf.string, shape=[])

    dataset_iter = tf.data.Iterator.from_string_handle(handle, tr_dataset.output_types, tr_dataset.output_shapes)  # feedable iterator selected at run time via the string handle
    next_batch = list(dataset_iter.get_next())  # tuple --> list so that each element can be modified

    tr_iter = tr_dataset.make_one_shot_iterator() # infinite loop
    va_iter = va_dataset.make_initializable_iterator() # requires initialization at every epoch

    is_training = tf.placeholder(tf.bool, name='is_training')
    global_step = tf.Variable(0, name='global_step', trainable=False)

    print('Build network')
    loss, endpoints = build_network(config, next_batch, is_training, num_classes=tr_provider.NUM_CLASSES)

    if config.lr_decay:
        # copy from official/resnet
        batch_denom = 256
        initial_learning_rate = 0.1 * config.batch_size / batch_denom
        batches_per_epoch = tr_num_examples / config.batch_size
        boundary_epochs = [30, 60, 80, 90]
        decay_rates=[1, 0.1, 0.01, 0.001, 1e-4]
        boundaries = [int(batches_per_epoch * epoch) for epoch in boundary_epochs]
        lr_values = [initial_learning_rate * decay for decay in decay_rates]
        learning_rate = get_piecewise_lr(global_step, boundaries, lr_values, show_summary=True)

        # max_epoch = 50
        # boundaries = list((np.arange(max_epoch, dtype=np.int32)+1) * 5000)
        # lr_values = list(np.logspace(-1, -5, max_epoch))
        # learning_rate = get_piecewise_lr(global_step, boundaries, lr_values, show_summary=True)
        print('Learning-rate decay enabled: LR takes values {} at iteration boundaries {}'.format(lr_values, boundaries))

    minimize_op = get_optimizer(config.optim_method, global_step, learning_rate, loss, endpoints['var_list'], show_var_and_grad=config.show_histogram)
    print('Done.')

    tfconfig = tf.ConfigProto()
    tfconfig.gpu_options.allow_growth = True  # grow GPU memory usage as needed instead of pre-allocating it all
    sess = tf.Session(config=tfconfig)

    summary = tf.summary.merge_all()
    sess.run(tf.global_variables_initializer())
    tr_handle = sess.run(tr_iter.string_handle())
    va_handle = sess.run(va_iter.string_handle())

    if config.clear_logs and tf.gfile.Exists(log_dir):
        print('Clear all files in {}'.format(log_dir))
        try:
            tf.gfile.DeleteRecursively(log_dir)
        except Exception:
            print('Failed to delete {}. You may have to kill the TensorBoard process first.'.format(log_dir))

    best_saver = tf.train.Saver(max_to_keep=10, save_relative_paths=True)
    latest_saver = tf.train.Saver(max_to_keep=1, save_relative_paths=True)

    latest_checkpoint = tf.train.latest_checkpoint(log_dir)
    best_score_filename = os.path.join(log_dir, 'valid', 'best_score.txt')
    best_score = 0 # larger is better
    if latest_checkpoint is not None:
        from parse import parse
        print('Resume the previous model...')
        latest_saver.restore(sess, latest_checkpoint)
        curr_step = sess.run(global_step)
        if os.path.exists(best_score_filename):
            with open(best_score_filename, 'r') as f:
                dump_res = f.read()
            dump_res = parse('{step:d} {best_score:g}\n', dump_res)
            best_score = dump_res['best_score']
            print('Previous best score = {} @ #step={}'.format(best_score, curr_step))

    train_writer = tf.summary.FileWriter(
        os.path.join(log_dir, 'train'), graph=sess.graph
    )
    valid_writer = tf.summary.FileWriter(
        os.path.join(log_dir, 'valid'), graph=sess.graph
    )    

    if SAVE_MODEL:
        latest_saver.export_meta_graph(os.path.join(log_dir, "models.meta"))
    # Save config
    with open(os.path.join(log_dir, 'config.pkl'), 'wb') as f:
        pickle.dump(config, f)    

    ops = {
        'is_training': is_training,
        'handle': handle,
        'step': global_step,
        'summary': summary,
        'minimize_op': minimize_op,
    }
    for k, v in endpoints.items():
        if isinstance(v, tf.Tensor):
            ops[k] = v

    #----------------------
    # Start Training
    #----------------------
    save_summary_interval = 1000
    save_model_interval = 5000
    valid_interval = 5000

    va_params = {
        'batch_size': va_batch_size,
        'num_examples': va_num_examples,
        'summary_writer': valid_writer,
        'handle': va_handle,
        'ev_init_op': va_iter.initializer,
    }

    def check_counter(counter, interval):
        return (interval > 0 and counter % interval == 0)

    start_itr = sess.run(ops['step'])

    for _ in range(start_itr, config.max_itr):

        feed_dict = {
            ops['is_training']: True,
            ops['handle']: tr_handle,
        }        

        try:
            step, _ = sess.run([ops['step'], ops['minimize_op']], feed_dict=feed_dict)
        except Exception:
            print('Error during the training step, but keep training...')
            continue

        if check_counter(step, save_summary_interval):
            feed_dict = {
                ops['is_training']: False,
                ops['handle']: tr_handle,
            }
            fetch_dict = {
                'loss': ops['loss'],
                'top1': ops['top1'],
                'top5': ops['top5'],
                'summary': ops['summary'],
            }
            try:
                outs = sess.run(fetch_dict, feed_dict=feed_dict)
                start_time = time.time()
                outs = sess.run(fetch_dict, feed_dict=feed_dict)
                elapsed_time = time.time() - start_time
                train_writer.add_summary(outs['summary'], step) # save summary
                summaries = [tf.Summary.Value(tag='sec/step', simple_value=elapsed_time)]
                train_writer.add_summary(tf.Summary(value=summaries), global_step=step)
                train_writer.flush()

                print('[Train] {}step Loss: {:g}, Top1: {:g}, Top5: {:g} ({:.1f}sec)'.format(
                            step,
                            outs['loss'], outs['top1'], outs['top5'],
                            elapsed_time))
            except Exception:
                print('Error during summary evaluation, but keep training...')

            if SAVE_MODEL and latest_saver is not None:
                latest_saver.save(sess, os.path.join(log_dir, 'models-latest'), global_step=step, write_meta_graph=False)

        # if SAVE_MODEL and best_saver is not None and check_counter(step, save_model_interval):
        #     # print('#{}step Save latest model'.format(step))
        #     best_saver.save(sess, os.path.join(log_dir, 'models-best'), global_step=step, write_meta_graph=False)

        if check_counter(step, valid_interval):
            eval_one_epoch(sess, ops, va_params)
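
get_piecewise_lr is a project-specific helper; under the assumption that it is a thin wrapper around TF1's built-in piecewise-constant schedule, the boundaries/lr_values built above would be consumed roughly like this (a sketch, not the repo's actual implementation):

import tensorflow as tf  # TF1-style API, as in the script above

def get_piecewise_lr(global_step, boundaries, lr_values, show_summary=False):
    # len(lr_values) must be len(boundaries) + 1, as in the block above
    learning_rate = tf.train.piecewise_constant(global_step, boundaries, lr_values)
    if show_summary:
        tf.summary.scalar('learning_rate', learning_rate)
    return learning_rate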
Example #5
def main():
    global args, best_prec1
    args = parser.parse_args()
    print(args)

    if args.run_dir == '':
        writer = SummaryWriter()
    else:
        print("=> Logs can be found in", args.run_dir)
        writer = SummaryWriter(args.run_dir)

    # create model
    print("=> creating model")

    model_G = nn.DataParallel(NNet(regr=False)).cuda().float()
    model_D = nn.DataParallel(DCGAN()).cuda().float()

    weights_init(model_G, args)
    weights_init(model_D, args)
    print("=> model weights initialized")
    print(model_G)
    print(model_D)

    # optionally resume from a checkpoint
    if args.resume:
        for (path, net_G, net_D) in [(args.resume, model_G, model_D)]:
            if os.path.isfile(path):
                print("=> loading checkpoint '{}'".format(args.resume))
                checkpoint = torch.load(path)
                args.start_epoch = checkpoint['epoch']

                net_G.load_state_dict(checkpoint['state_dict_G'])
                net_D.load_state_dict(checkpoint['state_dict_D'])
                print("=> loaded checkpoint '{}' (epoch {})".format(
                    path, checkpoint['epoch']))
            else:
                print("=> no checkpoint found at '{}'".format(path))

    # Data loading code
    train_root = args.train_root
    train_dataset = ImageNet(train_root, output_full=True)

    if not args.evaluate:
        train_loader = torch.utils.data.DataLoader(train_dataset,
                                                   batch_size=args.batch_size,
                                                   shuffle=True,
                                                   num_workers=8,
                                                   pin_memory=True)
        print("=> Loaded data, length = ", len(train_dataset))

    # define loss function (criterion) and optimizer
    criterion_G = loss.classificationLoss
    gan_loss = nn.BCEWithLogitsLoss().cuda()

    def GANLoss(pred, is_real):
        if is_real:
            target = torch.ones_like(pred)
        else:
            target = torch.zeros_like(pred)
        return gan_loss(pred, target)

    criterion_GAN = GANLoss

    optimizer_G = torch.optim.Adam([
        {
            'params': model_G.parameters()
        },
    ],
                                   args.lr,
                                   weight_decay=args.weight_decay,
                                   betas=(0.9, 0.99))
    optimizer_D = torch.optim.Adam([
        {
            'params': model_D.parameters()
        },
    ],
                                   args.lr,
                                   weight_decay=args.weight_decay,
                                   betas=(0.9, 0.999))

    for epoch in range(args.start_epoch, args.epochs):
        print("=> Epoch", epoch, "started.")
        adjust_learning_rate(optimizer_G, optimizer_D, epoch)
        # train for one epoch
        train(train_loader, model_G, model_D, criterion_G, criterion_GAN,
              optimizer_G, optimizer_D, epoch)

        save_checkpoint(
            {
                'epoch': epoch + 1,
                'state_dict_G': model_G.state_dict(),
                'state_dict_D': model_D.state_dict(),
            }, args.reduced)
        print("=> Epoch", epoch, "finished.")
Example #6
def get_datasets(args, input_size, cutout=-1):

    name = args.dataset
    root = args.data_path

    assert len(input_size) in [3, 4]
    if len(input_size) == 4:
        input_size = input_size[1:]
    assert input_size[1] == input_size[2]

    if name == 'cifar10':
        mean = [x / 255 for x in [125.3, 123.0, 113.9]]
        std = [x / 255 for x in [63.0, 62.1, 66.7]]
    elif name == 'cifar100':
        mean = [x / 255 for x in [129.3, 124.1, 112.4]]
        std = [x / 255 for x in [68.2, 65.4, 70.4]]
    elif name.startswith('imagenet-1k'):
        mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
    elif name.startswith('ImageNet16'):
        mean = [x / 255 for x in [122.68, 116.66, 104.01]]
        std = [x / 255 for x in [63.22, 61.26, 65.09]]
    elif name in [
            'MiniImageNet', 'MetaMiniImageNet', 'TieredImageNet',
            'MetaTieredImageNet'
    ]:
        pass
    else:
        raise TypeError("Unknown dataset : {:}".format(name))

    # Data Augmentation
    if name == 'cifar10' or name == 'cifar100':
        lists = [
            transforms.RandomCrop(input_size[1], padding=0),
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
            RandChannel(input_size[0])
        ]
        if cutout > 0: lists += [CUTOUT(cutout)]
        train_transform = transforms.Compose(lists)
        test_transform = transforms.Compose(
            [transforms.ToTensor(),
             transforms.Normalize(mean, std)])
    elif name.startswith('ImageNet16'):
        lists = [
            transforms.RandomCrop(input_size[1], padding=0),
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
            RandChannel(input_size[0])
        ]
        if cutout > 0: lists += [CUTOUT(cutout)]
        train_transform = transforms.Compose(lists)
        test_transform = transforms.Compose(
            [transforms.ToTensor(),
             transforms.Normalize(mean, std)])
    elif name.startswith('imagenet-1k'):
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])
        if name == 'imagenet-1k':
            xlists = []
            xlists.append(transforms.Resize((32, 32), interpolation=2))
            xlists.append(transforms.RandomCrop(input_size[1], padding=0))
        elif name == 'imagenet-1k-s':
            xlists = [transforms.RandomResizedCrop(32, scale=(0.2, 1.0))]
        else:
            raise ValueError('invalid name : {:}'.format(name))
        xlists.append(transforms.ToTensor())
        xlists.append(normalize)
        xlists.append(RandChannel(input_size[0]))
        train_transform = transforms.Compose(xlists)
        test_transform = transforms.Compose([
            transforms.Resize(40),
            transforms.CenterCrop(32),
            transforms.ToTensor(), normalize
        ])
    elif name in [
            'MiniImageNet', 'MetaMiniImageNet', 'TieredImageNet',
            'MetaTieredImageNet'
    ]:
        mean = [
            120.39586422 / 255.0, 115.59361427 / 255.0, 104.54012653 / 255.0
        ]
        std = [70.68188272 / 255.0, 68.27635443 / 255.0, 72.54505529 / 255.0]
        normalize = transforms.Normalize(mean=mean, std=std)
        xlists = []
        xlists.append(lambda x: Image.fromarray(x))
        xlists.append(transforms.Resize((32, 32), interpolation=2))
        xlists.append(transforms.RandomCrop(input_size[1], padding=0))
        xlists.append(lambda x: np.asarray(x))
        xlists.append(transforms.ToTensor())
        xlists.append(normalize)
        xlists.append(RandChannel(input_size[0]))
        train_transform = transforms.Compose(xlists)
        test_transform = transforms.Compose([
            lambda x: Image.fromarray(x),
            transforms.Resize(40),
            transforms.CenterCrop(32),
            transforms.ToTensor(), normalize
        ])
    else:
        raise TypeError("Unknown dataset : {:}".format(name))

    if name == 'cifar10':
        train_data = dset.CIFAR10(root,
                                  train=True,
                                  transform=train_transform,
                                  download=True)
        test_data = dset.CIFAR10(root,
                                 train=False,
                                 transform=test_transform,
                                 download=True)
        assert len(train_data) == 50000 and len(test_data) == 10000
    elif name == 'cifar100':
        train_data = dset.CIFAR100(root,
                                   train=True,
                                   transform=train_transform,
                                   download=True)
        test_data = dset.CIFAR100(root,
                                  train=False,
                                  transform=test_transform,
                                  download=True)
        assert len(train_data) == 50000 and len(test_data) == 10000
    elif name.startswith('imagenet-1k'):
        train_data = dset.ImageFolder(osp.join(root, 'train'), train_transform)
        test_data = dset.ImageFolder(osp.join(root, 'val'), test_transform)
    elif name == 'ImageNet16':
        train_data = ImageNet16(root, True, train_transform)
        test_data = ImageNet16(root, False, test_transform)
        assert len(train_data) == 1281167 and len(test_data) == 50000
    elif name == 'ImageNet16-120':
        train_data = ImageNet16(root, True, train_transform, 120)
        test_data = ImageNet16(root, False, test_transform, 120)
        assert len(train_data) == 151700 and len(test_data) == 6000
    elif name == 'ImageNet16-150':
        train_data = ImageNet16(root, True, train_transform, 150)
        test_data = ImageNet16(root, False, test_transform, 150)
        assert len(train_data) == 190272 and len(test_data) == 7500
    elif name == 'ImageNet16-200':
        train_data = ImageNet16(root, True, train_transform, 200)
        test_data = ImageNet16(root, False, test_transform, 200)
        assert len(train_data) == 254775 and len(test_data) == 10000
    elif name == 'MiniImageNet':
        train_data = ImageNet(args=args,
                              partition='train',
                              transform=train_transform)
        test_data = ImageNet(args=args,
                             partition='val',
                             transform=test_transform)
    elif name == 'MetaMiniImageNet':
        train_data = MetaImageNet(args=args,
                                  partition='test',
                                  train_transform=train_transform,
                                  test_transform=test_transform)
        test_data = MetaImageNet(args=args,
                                 partition='val',
                                 train_transform=train_transform,
                                 test_transform=test_transform)
    elif name == 'TieredImageNet':
        train_data = TieredImageNet(args=args,
                                    partition='train',
                                    transform=train_transform)
        test_data = TieredImageNet(args=args,
                                   partition='train_phase_val',
                                   transform=test_transform)
    elif name == 'MetaTieredImageNet':
        train_data = MetaTieredImageNet(args=args,
                                        partition='test',
                                        train_transform=train_transform,
                                        test_transform=test_transform)
        test_data = MetaTieredImageNet(args=args,
                                       partition='val',
                                       train_transform=train_transform,
                                       test_transform=test_transform)

    else:
        raise TypeError("Unknown dataset : {:}".format(name))

    class_num = Dataset2Class[name]
    return train_data, test_data, class_num
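
get_datasets reads args.dataset and args.data_path (the Mini-/TieredImageNet branches read further fields from args). A minimal CIFAR-10 call might look like the sketch below; the field values are illustrative:

from argparse import Namespace

args = Namespace(dataset='cifar10', data_path='./data')
train_data, test_data, class_num = get_datasets(args, input_size=(3, 32, 32), cutout=16)
print(len(train_data), len(test_data), class_num)   # 50000 10000 and, presumably, 10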
Example #7
def main(args):
    curr_time = time.time()

    args.distributed = args.world_size > 1

    if args.distributed:
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_url,
                                world_size=args.world_size)

    print("#############  Read in Database   ##############")
    # Data loading code (From PyTorch example https://github.com/pytorch/examples/blob/master/imagenet/main.py)
    traindir = os.path.join(args.data_path, 'train')
    valdir = os.path.join(args.data_path, 'validation')
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    print("Generating Validation Dataset")
    valid_dataset = ImageNet(
        valdir,
        transforms.Compose([
            transforms.Resize((299, 299)),
            # transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ]))

    print("Generating Training Dataset")
    train_dataset = ImageNet(
        traindir,
        transforms.Compose([
            transforms.RandomResizedCrop(299),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]))

    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset)
    else:
        train_sampler = None

    print("Generating Data Loaders")
    # shuffle must be off when a DistributedSampler is supplied
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=(train_sampler is None),
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               sampler=train_sampler)

    valid_loader = torch.utils.data.DataLoader(valid_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=False,
                                               num_workers=args.workers,
                                               pin_memory=True)

    print("Time taken:  {} seconds".format(time.time() - curr_time))
    curr_time = time.time()

    print("######## Initiate Model and Optimizer   ##############")
    # Model - inception_v3 as specified in the paper
    # Note: this is slightly different from the model used in the paper,
    # but the differences should be minor in terms of implementation and impact on results
    model = models.inception_v3(pretrained=False)
    # Train on GPU if available
    if not args.distributed:
        model = torch.nn.DataParallel(model).cuda()
    else:
        model.cuda()
        model = torch.nn.parallel.DistributedDataParallel(model)

    # if torch.cuda.is_available():
    #     model.cuda()

    # The criterion is not specified in the paper; cross entropy is assumed (as commonly used)
    criterion = torch.nn.CrossEntropyLoss().cuda()  # Loss function
    params = list(model.parameters())  # Parameters to train

    # Optimizer -- the optimizer is not specified in the paper and is assumed to
    # be SGD. The parameters of the model were also not specified and were set
    # to values commonly used with PyTorch (lr = 0.1, momentum = 0.3, decay = 1e-4)
    optimizer = torch.optim.SGD(params,
                                lr=args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
    # The paper does not specify an annealing factor; we set it to 1.0 (no annealing)
    scheduler = MultiStepLR(optimizer,
                            milestones=list(range(0, args.num_epochs, 1)),
                            gamma=args.annealing_factor)

    print("Time taken:  {} seconds".format(time.time() - curr_time))
    curr_time = time.time()

    print("#############  Start Training     ##############")
    total_step = len(train_loader)

    # curr_loss, curr_wacc = eval_step(   model       = model,
    #                                     data_loader = valid_loader,
    #                                     criterion   = criterion,
    #                                     step        = epoch * total_step,
    #                                     datasplit   = "valid")

    for epoch in range(0, args.num_epochs):

        if args.evaluate_only: exit()
        if args.optimizer == 'sgd': scheduler.step()

        logger.add_scalar("Misc/Epoch Number", epoch, epoch * total_step)
        train_step(model=model,
                   train_loader=train_loader,
                   criterion=criterion,
                   optimizer=optimizer,
                   epoch=epoch,
                   step=epoch * total_step,
                   valid_loader=valid_loader)

        curr_loss, curr_wacc = eval_step(model=model,
                                         data_loader=valid_loader,
                                         criterion=criterion,
                                         step=epoch * total_step,
                                         datasplit="valid")

        args = save_checkpoint(model=model,
                               optimizer=optimizer,
                               curr_epoch=epoch,
                               curr_loss=curr_loss,
                               curr_step=(total_step * epoch),
                               args=args,
                               curr_acc=curr_wacc,
                               filename=('model@epoch%d.pkl' % (epoch)))

    # Final save of the model
    args = save_checkpoint(model=model,
                           optimizer=optimizer,
                           curr_epoch=epoch,
                           curr_loss=curr_loss,
                           curr_step=(total_step * epoch),
                           args=args,
                           curr_acc=curr_wacc,
                           filename=('model@epoch%d.pkl' % (epoch)))
Example #8
def load_partition_data_ImageNet(dataset, data_dir, 
                            partition_method=None, partition_alpha=None, client_number=100, batch_size=10):

    train_dataset = ImageNet(data_dir=data_dir,
                             dataidxs=None,
                             train=True)

    test_dataset = ImageNet(data_dir=data_dir,
                             dataidxs=None,
                             train=False)

    net_dataidx_map = train_dataset.get_net_dataidx_map()

    class_num = 1000

    # logging.info("traindata_cls_counts = " + str(traindata_cls_counts))
    # train_data_num = sum([len(net_dataidx_map[r]) for r in range(client_number)])
    train_data_num = len(train_dataset)
    test_data_num = len(test_dataset)
    class_num_dict = train_dataset.get_data_local_num_dict()


    # train_data_global, test_data_global = get_dataloader(dataset, data_dir, batch_size, batch_size)

    train_data_global, test_data_global = get_dataloader_ImageNet_truncated(train_dataset, test_dataset, 
                train_bs=batch_size, test_bs=batch_size, 
                dataidxs=None, net_dataidx_map=None,)

    logging.info("train_dl_global number = " + str(len(train_data_global)))
    logging.info("test_dl_global number = " + str(len(test_data_global)))



    # get local dataset
    data_local_num_dict = dict() 
    train_data_local_dict = dict()
    test_data_local_dict = dict()


    for client_idx in range(client_number):
        if client_number == 1000:
            dataidxs = client_idx
            data_local_num_dict = class_num_dict
        elif client_number == 100:
            dataidxs = [client_idx*10 + i for i in range(10)]
            data_local_num_dict[client_idx] = sum(class_num_dict[client_idx*10 + i] for i in range(10))
        else:
            raise NotImplementedError("Only client_number of 100 or 1000 is supported for now!")

        local_data_num = data_local_num_dict[client_idx]

        logging.info("client_idx = %d, local_sample_number = %d" % (client_idx, local_data_num))

        # training batch size = 64; algorithms batch size = 32
        # train_data_local, test_data_local = get_dataloader(dataset, data_dir, batch_size, batch_size,
        #                                          dataidxs)
        train_data_local, test_data_local = get_dataloader_ImageNet_truncated(train_dataset, test_dataset, 
                train_bs=batch_size, test_bs=batch_size, 
                dataidxs=dataidxs, net_dataidx_map=net_dataidx_map)

        logging.info("client_idx = %d, batch_num_train_local = %d, batch_num_test_local = %d" % (
            client_idx, len(train_data_local), len(test_data_local)))
        train_data_local_dict[client_idx] = train_data_local
        test_data_local_dict[client_idx] = test_data_local
    return train_data_num, test_data_num, train_data_global, test_data_global, \
           data_local_num_dict, train_data_local_dict, test_data_local_dict, class_num
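
The client-to-class mapping above is deterministic: with client_number=1000 each client holds exactly one ImageNet class, and with client_number=100 each client holds ten consecutive classes. For example:

client_idx = 3
# client_number == 1000: client 3 gets class 3 only
dataidxs = client_idx
# client_number == 100: client 3 gets classes 30..39
dataidxs = [client_idx * 10 + i for i in range(10)]   # [30, 31, ..., 39]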
Example #9
import torch
import numpy as np
import datetime
from model.inceptionv4 import Inceptionv4
from torch.utils.data import DataLoader
from datasets import ImageNet

dataset = ImageNet('data/ILSVRC2012_img_val', 'data/imagenet_classes.txt',
                   'data/imagenet_2012_validation_synset_labels.txt')

## model
model = Inceptionv4()
model = model.cuda()
model.load_state_dict(torch.load('checkpoints/inceptionv4.pth'))
model.eval()


def topk_accuracy():
    val_loader = DataLoader(dataset=dataset,
                            batch_size=25,
                            shuffle=False,
                            num_workers=4)
    tp_1, tp_5 = 0, 0
    with torch.no_grad():
        for input, label in val_loader:
            input, label = input.cuda(), label.cuda()
            pred = model(input)
            _, pred = torch.topk(pred, 5, dim=1)
            correct = pred.eq(label.view(-1, 1).expand_as(pred)).cpu().numpy()
            tp_1 += correct[:, 0].sum()
            tp_5 += correct.sum()
    # report top-1 / top-5 hit rates over the whole validation set
    return tp_1 / len(dataset), tp_5 / len(dataset)
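
A hypothetical entry point for the evaluation script above, using the names defined in the snippet:

if __name__ == '__main__':
    top1, top5 = topk_accuracy()
    print('top-1 accuracy: {:.4f}, top-5 accuracy: {:.4f}'.format(top1, top5))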
Example #10
def main():
    global args, best_prec1
    args = parser.parse_args()
    print(args)

    if args.run_dir == '':
        writer = SummaryWriter()
    else:
        print("=> Logs can be found in", args.run_dir)
        writer = SummaryWriter(args.run_dir)

    # create model
    print("=> creating model")


    model = nn.DataParallel(NNet()).cuda()

    # print("paralleling")
    # model = torch.nn.DataParallel(model, device_ids=range(args.nGpus)).cuda()
    weights_init(model,args)
    print("=> model weights initialized")
    print(model)

    # optionally resume from a checkpoint
    if args.resume: 
        for (path, net) in [(args.resume, model)]:
            if os.path.isfile(path):
                print("=> loading checkpoint '{}'".format(args.resume))
                checkpoint = torch.load(path)
                args.start_epoch = checkpoint['epoch']
                #best_prec1 = checkpoint['best_prec1']
                net.load_state_dict(checkpoint['state_dict'])

                print("=> loaded checkpoint '{}' (epoch {})"
                      .format(path, checkpoint['epoch']))
            else:
                print("=> no checkpoint found at '{}'".format(path))

    # Data loading code
    train_root = args.train_root
    #val_root = args.val_root
    train_dataset = ImageNet(train_root)
    #val_dataset = datasets.ImageFolder(val_root)
    if not args.evaluate:
        train_loader = torch.utils.data.DataLoader(train_dataset,
                                                   batch_size=args.batch_size,
                                                   shuffle=True,
                                                   num_workers=8,
                                                   pin_memory=True)
        print("=> Loaded data, length = ", len(train_dataset))

    # define loss function (criterion) and optimizer
    criterion = loss.classificationLoss
    optimizer = torch.optim.Adam([{'params': model.parameters()},], args.lr,weight_decay=args.weight_decay, betas=(0.9, 0.99))


    if args.evaluate:
        # NOTE: val_loader only exists if the commented-out val_root / val_dataset
        # code above is restored.
        validate(val_loader, model, criterion)
        return

    for epoch in range(args.start_epoch, args.epochs):
        print("=> Epoch", epoch, "started.")
        adjust_learning_rate(optimizer, epoch)
        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch)

        save_checkpoint({
            'epoch': epoch + 1,
            'state_dict': model.state_dict(),
        }, args.reduced)
        print("=> Epoch", epoch, "finished.")