コード例 #1
0
def train_and_evaluate(model, train_dataloader, val_dataloader, optimizer,
                       loss_fn, metrics, params, model_dir, restore_file=None):
    """Train the model and evaluate every epoch.

    Args:
        model: (torch.nn.Module) the neural network
        train_dataloader: (DataLoader) iterator over the training set
        val_dataloader: (DataLoader) iterator over the validation set
        optimizer: (torch.optim.Optimizer) optimizer for model parameters
        loss_fn: callable computing the training loss
        metrics: (dict) metric name -> function evaluated on model outputs
        params: (Params) hyperparameters
        model_dir: (string) directory containing config, weights and log
        restore_file: (string) - name of file to restore from (without its extension .pth.tar)
    """
    # reload weights from restore_file if specified
    if restore_file is not None:
        # Bug fix: use the function's own parameters; the original referenced
        # a global `args` that is not defined in this scope.
        restore_path = os.path.join(model_dir, restore_file + '.pth.tar')
        logging.info("Restoring parameters from {}".format(restore_path))
        utils.load_checkpoint(restore_path, model, optimizer)

    best_val_acc = 0.0

    # learning rate schedulers for different models; default to None so an
    # unknown model_version no longer raises NameError at step time
    scheduler = None
    if params.model_version == "resnet18":
        scheduler = StepLR(optimizer, step_size=150, gamma=0.1)
    # for cnn models, num_epochs is always < 100, so this step never fires
    elif params.model_version == "cnn":
        scheduler = StepLR(optimizer, step_size=100, gamma=0.2)

    for epoch in range(params.num_epochs):

        # Run one epoch
        logging.info("Epoch {}/{}".format(epoch + 1, params.num_epochs))

        # compute number of batches in one epoch (one full pass over the training set)
        train(model, optimizer, loss_fn, train_dataloader, metrics, params)

        # Evaluate for one epoch on validation set
        val_metrics = evaluate(model, loss_fn, val_dataloader, metrics, params)

        # Bug fix: since PyTorch 1.1 the scheduler must be stepped AFTER the
        # optimizer updates of the epoch, not before (the original stepped it
        # at the top of the loop).
        if scheduler is not None:
            scheduler.step()

        val_acc = val_metrics['accuracy']
        # ties count as "best" so the newest equally-good weights are kept
        is_best = val_acc >= best_val_acc

        # Save weights
        utils.save_checkpoint({'epoch': epoch + 1,
                               'state_dict': model.state_dict(),
                               'optim_dict': optimizer.state_dict()},
                              is_best=is_best,
                              checkpoint=model_dir)

        # If best_eval, best_save_path
        if is_best:
            logging.info("- Found new best accuracy")
            best_val_acc = val_acc

            # Save best val metrics in a json file in the model directory
            best_json_path = os.path.join(model_dir, "metrics_val_best_weights.json")
            utils.save_dict_to_json(val_metrics, best_json_path)

        # Save latest val metrics in a json file in the model directory
        last_json_path = os.path.join(model_dir, "metrics_val_last_weights.json")
        utils.save_dict_to_json(val_metrics, last_json_path)
コード例 #2
0
def main():
    """Entry point: train a CNN on KMNIST and optionally save the weights.

    Parses command-line hyperparameters, builds KMNIST train/test loaders,
    trains ``Net`` with Adadelta + a per-epoch StepLR schedule, and saves
    the state dict to ``kmnist_cnn.pt`` when ``--save-model`` is given.
    """
    # Training settings
    parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
    parser.add_argument('--batch-size',
                        type=int,
                        default=64,
                        metavar='N',
                        help='input batch size for training (default: 64)')
    parser.add_argument('--test-batch-size',
                        type=int,
                        default=1000,
                        metavar='N',
                        help='input batch size for testing (default: 1000)')
    parser.add_argument('--epochs',
                        type=int,
                        default=14,
                        metavar='N',
                        help='number of epochs to train (default: 14)')
    parser.add_argument('--lr',
                        type=float,
                        default=1.0,
                        metavar='LR',
                        help='learning rate (default: 1.0)')
    parser.add_argument('--gamma',
                        type=float,
                        default=0.7,
                        metavar='M',
                        help='Learning rate step gamma (default: 0.7)')
    parser.add_argument('--no-cuda',
                        action='store_true',
                        default=False,
                        help='disables CUDA training')
    parser.add_argument('--seed',
                        type=int,
                        default=1,
                        metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument(
        '--log-interval',
        type=int,
        default=10,
        metavar='N',
        help='how many batches to wait before logging training status')

    parser.add_argument('--save-model',
                        action='store_true',
                        default=False,
                        help='For Saving the current Model')
    args = parser.parse_args()
    # Use CUDA only when available and not explicitly disabled.
    use_cuda = not args.no_cuda and torch.cuda.is_available()

    torch.manual_seed(args.seed)

    device = torch.device("cuda" if use_cuda else "cpu")

    # Extra loader workers / pinned memory only pay off when feeding a GPU.
    kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
    # NOTE(review): normalization constants (0.1307, 0.3081) are the standard
    # MNIST mean/std, reused here for KMNIST -- confirm they are intended.
    train_loader = torch.utils.data.DataLoader(datasets.KMNIST(
        '../data',
        train=True,
        download=True,
        transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307, ), (0.3081, ))
        ])),
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               **kwargs)
    test_loader = torch.utils.data.DataLoader(datasets.KMNIST(
        '../data',
        train=False,
        transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307, ), (0.3081, ))
        ])),
                                              batch_size=args.test_batch_size,
                                              shuffle=True,
                                              **kwargs)

    model = Net().to(device)
    optimizer = optim.Adadelta(model.parameters(), lr=args.lr)

    # Decay the LR by `gamma` every epoch; step() is correctly called after
    # each epoch's optimizer updates (PyTorch >= 1.1 ordering).
    scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)
    for epoch in range(1, args.epochs + 1):
        train(args, model, device, train_loader, optimizer, epoch)
        test(args, model, device, test_loader)
        scheduler.step()

    if args.save_model:
        torch.save(model.state_dict(), "kmnist_cnn.pt")
コード例 #3
0
def main():
    """Entry point: train a U-Net-family segmentation model on CARVANA.

    Parses CLI options, optionally resumes from a checkpoint, builds the
    model / data loaders / combined loss / optimizer / LR schedule, then
    runs the train + validate loop, snapshotting the best-dice checkpoint.

    NOTE(review): this snippet is Python 2 (``print`` statements).
    """
    parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
    parser.add_argument('-o', '--output_dir', default=None, help='output dir')
    parser.add_argument('-b',
                        '--batch_size',
                        type=int,
                        default=1,
                        metavar='N',
                        help='input batch size for training')
    parser.add_argument('--epochs',
                        type=int,
                        default=400,
                        help='number of epochs to train')
    parser.add_argument('-lr',
                        '--lr',
                        type=float,
                        default=0.01,
                        metavar='LR',
                        help='learning rate')
    parser.add_argument(
        '-reset_lr',
        '--reset_lr',
        action='store_true',
        help='should reset lr cycles? If not count epochs from 0')
    parser.add_argument('-opt',
                        '--optimizer',
                        default='sgd',
                        choices=['sgd', 'adam', 'rmsprop'],
                        help='optimizer type')
    parser.add_argument('--decay_step',
                        type=float,
                        default=100,
                        metavar='EPOCHS',
                        help='learning rate decay step')
    parser.add_argument('--decay_gamma',
                        type=float,
                        default=0.5,
                        help='learning rate decay coeeficient')  # NOTE(review): typo in help: 'coefficient'
    parser.add_argument(
        '--cyclic_lr',
        type=int,
        default=None,
        help=
        '(int)Len of the cycle. If not None use cyclic lr with cycle_len) specified'
    )
    parser.add_argument(
        '--cyclic_duration',
        type=float,
        default=1.0,
        help='multiplier of the duration of segments in the cycle')

    parser.add_argument('--weight_decay',
                        type=float,
                        default=0.0005,
                        help='L2 regularizer weight')
    parser.add_argument('--seed', type=int, default=1993, help='random seed')
    parser.add_argument(
        '--log_aggr',
        type=int,
        default=None,
        metavar='N',
        help='how many batches to wait before logging training status')
    parser.add_argument('-gacc',
                        '--num_grad_acc_steps',
                        type=int,
                        default=1,
                        metavar='N',
                        help='number of vatches to accumulate gradients')  # NOTE(review): typo in help: 'batches'
    parser.add_argument(
        '-imsize',
        '--image_size',
        type=int,
        default=1024,
        metavar='N',
        help='how many batches to wait before logging training status')  # NOTE(review): help text copied from --log_aggr; should describe the image size
    parser.add_argument('-f',
                        '--fold',
                        type=int,
                        default=0,
                        metavar='N',
                        help='fold_id')
    parser.add_argument('-nf',
                        '--n_folds',
                        type=int,
                        default=0,
                        metavar='N',
                        help='number of folds')
    parser.add_argument(
        '-fv',
        '--folds_version',
        type=int,
        default=1,
        choices=[1, 2],
        help='version of folds (1) - random, (2) - stratified on mask area')
    parser.add_argument('-group',
                        '--group',
                        type=parse_group,
                        default='all',
                        help='group id')
    parser.add_argument('-no_cudnn',
                        '--no_cudnn',
                        action='store_true',
                        help='dont use cudnn?')
    parser.add_argument('-aug',
                        '--aug',
                        type=int,
                        default=None,
                        help='use augmentations?')
    parser.add_argument('-no_hq',
                        '--no_hq',
                        action='store_true',
                        help='do not use hq images?')
    parser.add_argument('-dbg', '--dbg', action='store_true', help='is debug?')
    parser.add_argument('-is_log_dice',
                        '--is_log_dice',
                        action='store_true',
                        help='use -log(dice) in loss?')
    parser.add_argument('-no_weight_loss',
                        '--no_weight_loss',
                        action='store_true',
                        help='do not weight border in loss?')

    parser.add_argument('-suf',
                        '--exp_suffix',
                        default='',
                        help='experiment suffix')
    parser.add_argument('-net', '--network', default='Unet')

    args = parser.parse_args()
    print 'aug:', args.aug
    # assert args.aug, 'Careful! No aug specified!'
    if args.log_aggr is None:
        args.log_aggr = 1
    print 'log_aggr', args.log_aggr

    # NOTE(review): python's RNG is seeded with a constant 42, while torch
    # uses args.seed -- confirm this asymmetry is intentional.
    random.seed(42)
    torch.manual_seed(args.seed)
    print 'CudNN:', torch.backends.cudnn.version()
    print 'Run on {} GPUs'.format(torch.cuda.device_count())
    torch.backends.cudnn.benchmark = not args.no_cudnn  # Enable use of CudNN

    # Experiment name encodes the key hyperparameters; it doubles as the
    # default run-directory name under config.models_dir.
    experiment = "{}_s{}_im{}_gacc{}{}{}{}_{}fold{}.{}".format(
        args.network, args.seed, args.image_size, args.num_grad_acc_steps,
        '_aug{}'.format(args.aug) if args.aug is not None else '',
        '_nohq' if args.no_hq else '',
        '_g{}'.format(args.group) if args.group != 'all' else '',
        'v2' if args.folds_version == 2 else '', args.fold, args.n_folds)
    # Resolve the checkpoint to resume from: prefer checkpoint.pth.tar,
    # fall back to model_best.pth.tar, else start from scratch.
    if args.output_dir is None:
        ckpt_dir = join(config.models_dir, experiment + args.exp_suffix)
        if os.path.exists(join(ckpt_dir, 'checkpoint.pth.tar')):
            args.output_dir = ckpt_dir
    if args.output_dir is not None and os.path.exists(args.output_dir):
        ckpt_path = join(args.output_dir, 'checkpoint.pth.tar')
        if not os.path.isfile(ckpt_path):
            print "=> no checkpoint found at '{}'\nUsing model_best.pth.tar".format(
                ckpt_path)
            ckpt_path = join(args.output_dir, 'model_best.pth.tar')
        if os.path.isfile(ckpt_path):
            print("=> loading checkpoint '{}'".format(ckpt_path))
            checkpoint = torch.load(ckpt_path)
            # NOTE(review): if the checkpoint lacks 'filter_sizes' and the
            # network is not a vgg variant, filters_sizes is never assigned
            # -- potential NameError further down. Verify checkpoints always
            # carry this key.
            if 'filter_sizes' in checkpoint:
                filters_sizes = checkpoint['filter_sizes']
            print("=> loaded checkpoint '{}' (epoch {})".format(
                ckpt_path, checkpoint['epoch']))
        else:
            raise IOError("=> no checkpoint found at '{}'".format(ckpt_path))
    else:
        checkpoint = None
        # Per-architecture encoder/decoder filter widths.
        if args.network == 'Unet':
            filters_sizes = np.asarray([32, 64, 128, 256, 512, 1024, 1024])
        elif args.network == 'UNarrow':
            filters_sizes = np.asarray([32, 32, 64, 128, 256, 512, 768])
        elif args.network == 'Unet7':
            filters_sizes = np.asarray(
                [48, 96, 128, 256, 512, 1024, 1536, 1536])
        elif args.network == 'Unet5':
            filters_sizes = np.asarray([32, 64, 128, 256, 512, 1024])
        elif args.network == 'Unet4':
            filters_sizes = np.asarray([24, 64, 128, 256, 512])
        elif args.network in ['vgg11v1', 'vgg11v2']:
            filters_sizes = np.asarray([64])
        elif args.network in ['vgg11av1', 'vgg11av2']:
            filters_sizes = np.asarray([32])
        else:
            raise ValueError('Unknown Net: {}'.format(args.network))
    # Build the model, wrapped in DataParallel over all visible GPUs.
    if args.network in ['vgg11v1', 'vgg11v2']:
        assert args.network[-2] == 'v'
        v = int(args.network[-1:])  # variant number parsed from the name suffix
        model = torch.nn.DataParallel(
            UnetVgg11(n_classes=1, num_filters=filters_sizes.item(),
                      v=v)).cuda()
    elif args.network in ['vgg11av1', 'vgg11av2']:
        assert args.network[-2] == 'v'
        v = int(args.network[-1:])
        model = torch.nn.DataParallel(
            vgg_unet.Vgg11a(n_classes=1, num_filters=filters_sizes.item(),
                            v=v)).cuda()
    else:
        # Generic U-Net family: resolve the class by name from the unet module.
        unet_class = getattr(unet, args.network)
        model = torch.nn.DataParallel(
            unet_class(is_deconv=False, filters=filters_sizes)).cuda()

    print('  + Number of params: {}'.format(
        sum([p.data.nelement() for p in model.parameters()])))

    # image_size == -1 -> train on padded full-size (1920x1280) images;
    # image_size == -2 -> fixed 1856x1248 rescale; otherwise square resize.
    rescale_size = (args.image_size, args.image_size)
    is_full_size = False
    if args.image_size == -1:
        print 'Use full size. Use padding'
        is_full_size = True
        rescale_size = (1920, 1280)
    elif args.image_size == -2:
        rescale_size = (1856, 1248)

    train_dataset = CarvanaPlus(
        root=config.input_data_dir,
        subset='train',
        image_size=args.image_size,
        transform=TrainTransform(
            rescale_size,
            aug=args.aug,
            resize_mask=True,
            should_pad=is_full_size,
            should_normalize=args.network.startswith('vgg')),
        seed=args.seed,
        is_hq=not args.no_hq,
        fold_id=args.fold,
        n_folds=args.n_folds,
        group=args.group,
        return_image_id=True,
        v=args.folds_version)
    train_loader = torch.utils.data.DataLoader(
        dataset=train_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        pin_memory=True,
        num_workers=4 if torch.cuda.device_count() > 1 else 1)

    # Validation uses no augmentation and keeps the full-resolution masks.
    val_dataset = CARVANA(
        root=config.input_data_dir,
        subset='val',
        image_size=args.image_size,
        transform=TrainTransform(
            rescale_size,
            aug=None,
            resize_mask=False,
            should_pad=is_full_size,
            should_normalize=args.network.startswith('vgg')),
        seed=args.seed,
        is_hq=not args.no_hq,
        fold_id=args.fold,
        n_folds=args.n_folds,
        group=args.group,
        v=args.folds_version,
    )
    val_loader = torch.utils.data.DataLoader(
        dataset=val_dataset,
        batch_size=args.batch_size * 2,
        shuffle=False,
        pin_memory=True,
        num_workers=4
        if torch.cuda.device_count() > 4 else torch.cuda.device_count())

    print 'Weight loss:', not args.no_weight_loss
    print '-log(dice) in loss:', args.is_log_dice
    criterion = CombinedLoss(is_weight=not args.no_weight_loss,
                             is_log_dice=args.is_log_dice).cuda()

    if args.optimizer == 'adam':
        print 'Using adam optimizer!'
        optimizer = optim.Adam(model.parameters(),
                               weight_decay=args.weight_decay,
                               lr=args.lr)
    elif args.optimizer == 'rmsprop':
        optimizer = optim.RMSprop(
            model.parameters(), lr=args.lr,
            weight_decay=args.weight_decay)  # For Tiramisu weight_decay=0.0001
    else:
        optimizer = optim.SGD(model.parameters(),
                              weight_decay=args.weight_decay,
                              lr=args.lr,
                              momentum=0.9,
                              nesterov=False)

    if args.output_dir is not None:
        out_dir = args.output_dir
    else:
        out_dir = join(config.models_dir, experiment + args.exp_suffix)
    print 'Model dir:', out_dir
    if args.dbg:
        out_dir = 'dbg_runs'  # debug runs are kept out of the real model dir
    logger = SummaryWriter(log_dir=out_dir)

    if checkpoint is not None:
        start_epoch = checkpoint['epoch']
        best_score = checkpoint['best_score']
        print 'Best score:', best_score
        print 'Current score:', checkpoint['cur_score']
        model.load_state_dict(checkpoint['state_dict'])
        print 'state dict loaded'
        optimizer.load_state_dict(checkpoint['optimizer'])
        # Override the restored LR with the one given on the command line.
        for param_group in optimizer.param_groups:
            param_group['lr'] = args.lr
            param_group['initial_lr'] = args.lr
        # validate(val_loader, model, start_epoch * len(train_loader), logger,
        #   is_eval=args.batch_size > 1, is_full_size=is_full_size)
    else:
        start_epoch = 0
        best_score = 0
        # validate(val_loader, model, start_epoch * len(train_loader), logger,
        #          is_eval=args.batch_size > 1, is_full_size=is_full_size)

    # LR schedule: plain StepLR by default, otherwise a (Vgg)cyclic schedule
    # driven through LambdaLR. base_lrs is forced to 1.0 so the lambda's
    # return value IS the learning rate, not a multiplier of the initial LR.
    if args.cyclic_lr is None:
        scheduler = StepLR(optimizer,
                           step_size=args.decay_step,
                           gamma=args.decay_gamma)
        print 'scheduler.base_lrs=', scheduler.base_lrs
    elif args.network.startswith('vgg'):
        print 'Using VggCyclic LR!'
        cyclic_lr = VggCyclicLr(start_epoch if args.reset_lr else 0,
                                init_lr=args.lr,
                                num_epochs_per_cycle=args.cyclic_lr,
                                duration=args.cyclic_duration)
        scheduler = LambdaLR(optimizer, lr_lambda=cyclic_lr)
        scheduler.base_lrs = list(
            map(lambda group: 1.0, optimizer.param_groups))
    else:
        print 'Using Cyclic LR!'
        cyclic_lr = CyclicLr(start_epoch if args.reset_lr else 0,
                             init_lr=args.lr,
                             num_epochs_per_cycle=args.cyclic_lr,
                             epochs_pro_decay=args.decay_step,
                             lr_decay_factor=args.decay_gamma)
        scheduler = LambdaLR(optimizer, lr_lambda=cyclic_lr)
        scheduler.base_lrs = list(
            map(lambda group: 1.0, optimizer.param_groups))

    logger.add_scalar('data/batch_size', args.batch_size, start_epoch)
    logger.add_scalar('data/num_grad_acc_steps', args.num_grad_acc_steps,
                      start_epoch)
    logger.add_text('config/info', 'filters sizes: {}'.format(filters_sizes))

    last_lr = 100500  # sentinel: differs from any real LR, so epoch 0 logs it

    # Main train/validate loop.
    for epoch in range(start_epoch, args.epochs):
        # train for one epoch
        # scheduler.step(epoch=...) pins the schedule to the absolute epoch
        # index so resumed runs stay aligned (pre-PyTorch-1.1 style API).
        scheduler.step(epoch=epoch)
        if last_lr != scheduler.get_lr()[0]:
            last_lr = scheduler.get_lr()[0]
            print 'LR := {}'.format(last_lr)
        logger.add_scalar('data/lr', scheduler.get_lr()[0], epoch)
        logger.add_scalar('data/aug', args.aug if args.aug is not None else -1,
                          epoch)
        logger.add_scalar('data/weight_decay', args.weight_decay, epoch)
        logger.add_scalar('data/is_weight_loss', not args.no_weight_loss,
                          epoch)
        logger.add_scalar('data/is_log_dice', args.is_log_dice, epoch)
        train(train_loader,
              model,
              optimizer,
              epoch,
              args.epochs,
              criterion,
              num_grad_acc_steps=args.num_grad_acc_steps,
              logger=logger,
              log_aggr=args.log_aggr)
        dice_score = validate(val_loader,
                              model,
                              epoch + 1,
                              logger,
                              is_eval=args.batch_size > 1,
                              is_full_size=is_full_size)

        # store best loss and save a model checkpoint
        is_best = dice_score > best_score
        prev_best_score = best_score
        best_score = max(dice_score, best_score)
        ckpt_dict = {
            'epoch': epoch + 1,
            'arch': experiment,
            'state_dict': model.state_dict(),
            'best_score': best_score,
            'cur_score': dice_score,
            'optimizer': optimizer.state_dict(),
        }
        ckpt_dict['filter_sizes'] = filters_sizes

        if is_best:
            print 'Best snapshot! {} -> {}'.format(prev_best_score, best_score)
            logger.add_text('val/best_dice',
                            'best val dice score: {}'.format(dice_score),
                            global_step=epoch + 1)
        save_checkpoint(ckpt_dict,
                        is_best,
                        filepath=join(out_dir, 'checkpoint.pth.tar'))

    logger.close()
コード例 #4
0
        lr=base_lr,
        weight_decay=1e-4)
    scheduler = StepLR(optimizer, step_size=40, gamma=0.1)

    bind_nsml(model, optimizer, scheduler)
    if config.pause:
        nsml.paused(scope=locals())

    if mode == 'train':
        tr_loader, val_loader, val_label = data_loader_with_split(
            root=TRAIN_DATASET_PATH, train_split=train_split)
        time_ = datetime.datetime.now()
        num_batches = len(tr_loader)

        for epoch in range(num_epochs):
            scheduler.step()
            model.train()
            for iter_, data in enumerate(tr_loader):
                _, x, label = data
                if cuda:
                    x = x.cuda()
                    label = label.cuda()
                pred = model(x)
                loss = loss_fn(pred, label)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                if (iter_ + 1) % print_iter == 0:
                    elapsed = datetime.datetime.now() - time_
                    expected = elapsed * (num_batches / print_iter)
                    _epoch = epoch + ((iter_ + 1) / num_batches)
コード例 #5
0
def main(args):
    """Train and evaluate a QAConv person re-identification model.

    Builds the data loaders, backbone and QAConv loss, trains with SGD +
    StepLR (unless --evaluate), then tests on every dataset in --testset,
    appending results to per-dataset text files and optionally saving the
    score matrices to .mat files.
    """
    exp_database_dir = osp.join(args.exp_dir, string.capwords(args.dataset))
    output_dir = osp.join(exp_database_dir, args.method, args.sub_method)
    log_file = osp.join(output_dir, 'log.txt')
    # Redirect print to both console and log file
    sys.stdout = Logger(log_file)

    seed = set_seed(args.seed)
    print('Random seed of this run: %d\n' % seed)

    # Create data loaders
    dataset, num_classes, train_loader, query_loader, gallery_loader = \
        get_data(args.dataset, args.data_dir, args.height, args.width, args.batch_size, args.combine_all,
                 args.min_size, args.max_size, args.workers, args.test_fea_batch)

    # Create model
    model = resmap.create(args.arch,
                          final_layer=args.final_layer,
                          neck=args.neck).cuda()
    num_features = model.num_features
    # print(model)
    print('\n')

    # Echo the full command line into the log.
    for arg in sys.argv:
        print('%s ' % arg, end='')
    print('\n')

    # Criterion

    # Feature-map height/width follow the backbone's downsampling factor
    # for the chosen final layer.
    feamap_factor = {'layer2': 8, 'layer3': 16, 'layer4': 32}
    hei = args.height // feamap_factor[args.final_layer]
    wid = args.width // feamap_factor[args.final_layer]
    criterion = QAConvLoss(num_classes, num_features, hei, wid,
                           args.mem_batch_size).cuda()

    # Optimizer: pretrained base layers get a 10x smaller learning rate
    # than the newly-initialized heads and the loss parameters.
    base_param_ids = set(map(id, model.base.parameters()))
    new_params = [p for p in model.parameters() if id(p) not in base_param_ids]
    param_groups = [{
        'params': model.base.parameters(),
        'lr': 0.1 * args.lr
    }, {
        'params': new_params,
        'lr': args.lr
    }, {
        'params': criterion.parameters(),
        'lr': args.lr
    }]

    optimizer = torch.optim.SGD(param_groups,
                                lr=args.lr,
                                momentum=0.9,
                                weight_decay=5e-4,
                                nesterov=True)

    # # Decay LR by a factor of 0.1 every step_size epochs
    lr_scheduler = StepLR(optimizer, step_size=args.step_size, gamma=0.1)

    # Load from checkpoint
    start_epoch = 0

    if args.resume or args.evaluate:
        print('Loading checkpoint...')
        # --resume may carry an explicit path; the special value 'ori' (or
        # plain --evaluate) falls back to the run's own checkpoint file.
        if args.resume and (args.resume != 'ori'):
            checkpoint = load_checkpoint(args.resume)
        else:
            checkpoint = load_checkpoint(
                osp.join(output_dir, 'checkpoint.pth.tar'))
        model.load_state_dict(checkpoint['model'])
        criterion.load_state_dict(checkpoint['criterion'])
        optimizer.load_state_dict(checkpoint['optim'])
        start_epoch = checkpoint['epoch']
        print("=> Start epoch {} ".format(start_epoch))

    model = nn.DataParallel(model).cuda()
    criterion = nn.DataParallel(criterion).cuda()

    if not args.evaluate:
        # Trainer
        trainer = Trainer(model, criterion)

        t0 = time.time()
        # Start training
        for epoch in range(start_epoch, args.epochs):
            loss, acc = trainer.train(epoch, train_loader, optimizer,
                                      args.print_freq)

            # Snapshot the LRs actually used this epoch before stepping.
            lr = list(map(lambda group: group['lr'], optimizer.param_groups))
            lr_scheduler.step(epoch + 1)
            train_time = time.time() - t0

            print(
                '* Finished epoch %d at lr=[%g, %g, %g]. Loss: %.3f. Acc: %.2f%%. Training time: %.0f seconds.\n'
                %
                (epoch + 1, lr[0], lr[1], lr[2], loss, acc * 100, train_time))

            save_checkpoint(
                {
                    'model': model.module.state_dict(),
                    'criterion': criterion.module.state_dict(),
                    'optim': optimizer.state_dict(),
                    'epoch': epoch + 1,
                },
                fpath=osp.join(output_dir, 'checkpoint.pth.tar'))

    # Final test
    cudnn.benchmark = True
    print('Evaluate the learned model:')
    t0 = time.time()

    # Evaluator
    evaluator = Evaluator(model)

    test_names = args.testset.strip().split(',')
    for test_name in test_names:
        if test_name not in datasets.names():
            # Bug fix: the original printed the literal '{test_name}' because
            # the string was not formatted.
            print('Unknown dataset: {}.'.format(test_name))
            continue

        testset, testset_train_loader, test_query_loader, test_gallery_loader = \
            get_test_data(test_name, args.data_dir, args.height, args.width, args.test_fea_batch)

        test_rank1, test_mAP, test_rank1_rerank, test_mAP_rerank, test_rank1_tlift, test_mAP_tlift, test_dist, \
            test_dist_rerank, test_dist_tlift, pre_tlift_dict = \
            evaluator.evaluate(test_query_loader, test_gallery_loader, testset, criterion.module, args.test_ker_batch,
                               args.test_prob_batch)

        print('  %s: rank1=%.1f, mAP=%.1f, rank1_rerank=%.1f, mAP_rerank=%.1f,'
              ' rank1_rerank_tlift=%.1f, mAP_rerank_tlift=%.1f.\n' %
              (test_name, test_rank1 * 100, test_mAP * 100,
               test_rank1_rerank * 100, test_mAP_rerank * 100,
               test_rank1_tlift * 100, test_mAP_tlift * 100))

        result_file = osp.join(exp_database_dir, args.method,
                               test_name + '_results.txt')
        with open(result_file, 'a') as f:
            f.write('%s/%s:\n' % (args.method, args.sub_method))
            f.write(
                '\t%s: rank1=%.1f, mAP=%.1f, rank1_rerank=%.1f, mAP_rerank=%.1f rank1_rerank_tlift=%.1f, '
                'mAP_rerank_tlift=%.1f.\n\n' %
                (test_name, test_rank1 * 100, test_mAP * 100,
                 test_rank1_rerank * 100, test_mAP_rerank * 100,
                 test_rank1_tlift * 100, test_mAP_tlift * 100))

        if args.save_score:
            # Fix: np.object was a deprecated alias of the builtin `object`
            # and is removed in NumPy >= 1.24; dtype=object is equivalent.
            test_gal_list = np.array(
                [fname for fname, _, _, _ in testset.gallery], dtype=object)
            test_prob_list = np.array(
                [fname for fname, _, _, _ in testset.query], dtype=object)
            test_gal_ids = [pid for _, pid, _, _ in testset.gallery]
            test_prob_ids = [pid for _, pid, _, _ in testset.query]
            test_gal_cams = [c for _, _, c, _ in testset.gallery]
            test_prob_cams = [c for _, _, c, _ in testset.query]
            test_score_file = osp.join(exp_database_dir, args.method,
                                       args.sub_method,
                                       '%s_score.mat' % test_name)
            # Scores are similarities, hence 1 - distance.
            sio.savemat(test_score_file, {
                'score': 1. - test_dist,
                'score_rerank': 1. - test_dist_rerank,
                'score_tlift': 1. - test_dist_tlift,
                'gal_time': pre_tlift_dict['gal_time'],
                'prob_time': pre_tlift_dict['prob_time'],
                'gal_list': test_gal_list,
                'prob_list': test_prob_list,
                'gal_ids': test_gal_ids,
                'prob_ids': test_prob_ids,
                'gal_cams': test_gal_cams,
                'prob_cams': test_prob_cams
            },
                        oned_as='column',
                        do_compression=True)

    test_time = time.time() - t0
    if not args.evaluate:
        print('Finished training at epoch %d, loss %.3f, acc %.2f%%.\n' %
              (epoch + 1, loss, acc * 100))
        # Bug fix: the number of trained epochs is (epochs - start_epoch);
        # the original's "+ 1" over-counted by one. max(..., 1) guards the
        # degenerate resume-at-final-epoch case.
        print(
            "Total training time: %.3f sec. Average training time per epoch: %.3f sec."
            % (train_time, train_time / max(args.epochs - start_epoch, 1)))
    print("Total testing time: %.3f sec.\n" % test_time)

    for arg in sys.argv:
        print('%s ' % arg, end='')
    print('\n')
コード例 #6
0
def main():
    """Run episodic meta-training of the relation-network text classifier.

    Builds the self-attentive feature encoder and the relation network,
    trains them for ``EPISODE`` episodes with step-decayed Adam, and runs a
    validation pass every ``5 * TEST_EPISODE`` episodes.
    """
    tr_data, ts_data, vocab, _labels = omniglot_character_folders()

    hp = {
        "CLASS_NUM": 5,
        "SAMPLE_NUM_PER_CLASS": 5,
        "BATCH_NUM_PER_CLASS": 15,
        "EPISODE": 10000,  # full run: 1000000
        "TEST_EPISODE": 100,  # full run: 1000
        "LEARNING_RATE": 0.0001,  # alternative: 0.01
        "FEATURE_DIM": 256,  # = lstm_hid_dim * 2
        "RELATION_DIM": 8,
        "use_bert": False,
        "max_len": 12,
        "emb_dim": 300,
        "lstm_hid_dim": 128,
        "d_a": 64,
        "r": 1,
        "n_classes": 5,
        "num_layers": 1,
        "dropout": 0.1,
        "type": 1,
        "use_pretrained_embeddings": True,
        "word2index": vocab,
        "vocab_size": len(vocab)
    }

    feature_encoder = StructuredSelfAttention(hp).to(device)
    relation_network = RelationNetwork(2 * hp["FEATURE_DIM"],
                                       hp["RELATION_DIM"]).to(device)

    def _adam_with_decay(params):
        # Adam with weight decay plus a 0.5x LR drop every 100k episodes;
        # the same recipe is shared by both networks.
        opt = torch.optim.Adam(params,
                               lr=hp["LEARNING_RATE"],
                               weight_decay=1e-4)
        return opt, StepLR(opt, step_size=100000, gamma=0.5)

    feature_encoder_optim, feature_encoder_scheduler = _adam_with_decay(
        feature_encoder.parameters())
    relation_network_optim, relation_network_scheduler = _adam_with_decay(
        relation_network.parameters())

    print("开始训练")
    tick = time()

    for episode in range(hp["EPISODE"]):
        feature_encoder.train()
        relation_network.train()
        # NOTE: passing the episode index to step() is the legacy
        # (pre-1.4) scheduler API; preserved as-is.
        feature_encoder_scheduler.step(episode)
        relation_network_scheduler.step(episode)

        loss = train(feature_encoder, relation_network, tr_data, hp)

        feature_encoder_optim.step()
        relation_network_optim.step()

        if (episode + 1) % 100 == 0:
            print("episode:", episode + 1, "loss", loss, "耗时", time() - tick)
            tick = time()

        if (episode + 1) % (5 * hp["TEST_EPISODE"]) == 0:
            test_accuracy = valid(feature_encoder, relation_network, ts_data,
                                  hp)
            tick = time()
    print("直接词向量")
    print("完成")
コード例 #7
0
ファイル: train.py プロジェクト: yiyiyi0/face_detect
def train(args):
    """Train the MSSD face detector.

    Optionally restores weights from MODEL_SAVE_PATH, trains with Adam and a
    StepLR schedule on 416x416 inputs, accumulates gradients over
    ``args.mini_batch`` batches, logs loss and a sample detection image to
    the module-level TensorBoard ``writer``, and checkpoints every 10 epochs
    plus a final save at MODEL_SAVE_PATH.

    Args:
        args: parsed CLI namespace; fields used here are batch, lr, step,
            gama, epoes, mini_batch and pretrained.
    """
    start_epoch = 0
    data_loader = DataLoader(dataset=FaceDetectSet(416, True), batch_size=args.batch, shuffle=True, num_workers=16)
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda:0" if use_cuda else "cpu")
    model = MSSD()
    print("add graph")
    # Trace the model once with a dummy 1x3x416x416 input so TensorBoard can
    # render the computation graph.
    writer.add_graph(model, torch.zeros((1, 3, 416, 416)))
    print("add graph over")
    if args.pretrained and os.path.exists(MODEL_SAVE_PATH):
        print("loading ...")
        state = torch.load(MODEL_SAVE_PATH)
        model.load_state_dict(state['net'])
        start_epoch = state['epoch']
        print("loading over")
    model = torch.nn.DataParallel(model, device_ids=[0, 1])  # multi-GPU
    model.to(device)

    optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=1e-5)
    scheduler = StepLR(optimizer, step_size=args.step, gamma=args.gama)
    train_loss = 0
    loss_func = MLoss().to(device)
    to_pil_img = tfs.ToPILImage()
    to_tensor = tfs.ToTensor()

    for epoch in range(start_epoch, start_epoch+args.epoes):
        model.train()
        optimizer.zero_grad()
        pbar = tqdm(data_loader)
        for i_batch, (img_tensor, label_tensor) in enumerate(pbar):
            # Keep the most recent batch so a sample image can be rendered
            # after the epoch finishes.
            last_img_tensor = img_tensor
            last_label_tensor = label_tensor
            output = model(img_tensor.to(device))
            loss = loss_func(output, label_tensor.to(device))
            # MLoss may return None (e.g. no valid targets in the batch);
            # skip such batches entirely.
            if loss is None:
                continue
            loss.backward()
            # Gradient accumulation: only step/clear every `mini_batch`
            # batches; intermediate batches just accumulate gradients.
            if i_batch % args.mini_batch == 0:
                optimizer.step()
                optimizer.zero_grad()

            train_loss = loss.item()
            global_step = epoch*len(data_loader)+i_batch
            pbar.set_description('loss: %f, epeche: %d' % (train_loss, epoch))
            writer.add_scalar("loss", train_loss, global_step=global_step)


        #save one pic and output
        # NOTE(review): relies on `last_img_tensor`/`output` leaking out of
        # the loop above -- this raises NameError if the dataloader is empty.
        pil_img = to_pil_img(last_img_tensor[0].cpu())
        bboxes = tensor2bbox(output[0], 416, [52, 26, 13], thresh=0.5)
        # bboxes = nms(bboxes, 0.6, 0.5)
        draw = ImageDraw.Draw(pil_img)
        for bbox in bboxes:
            # bbox layout looks like (score, cx, cy, w, h) -- TODO confirm
            # against tensor2bbox. Draws the score above a doubled rectangle
            # (outer + 1px-inset) for visibility.
            draw.text((bbox[1] - bbox[3] / 2, bbox[2] - bbox[4] / 2 - 10), str(round(bbox[0].item(), 2)), fill=(255, 0, 0))
            draw.rectangle((bbox[1] - bbox[3] / 2, bbox[2] - bbox[4] / 2, bbox[1] + bbox[3] / 2, bbox[2] + bbox[4] / 2),
                           outline=(0, 255, 0))
            draw.rectangle((bbox[1] - bbox[3] / 2 + 1, bbox[2] - bbox[4] / 2 + 1, bbox[1] + bbox[3] / 2 - 1, bbox[2] + bbox[4] / 2 - 1),
                           outline=(0, 255, 0))
        writer.add_image("img: "+str(epoch), to_tensor(pil_img))
        scheduler.step()

        if epoch % 10 == 0:
            print('Saving..')
            # Save the underlying module's weights (unwrap DataParallel).
            state = {
                'net': model.module.state_dict(),
                'epoch': epoch,
            }
            torch.save(state, "./data/mssd_face_detect"+str(epoch)+".pt")

    if not os.path.isdir('data'):
        os.mkdir('data')
    print('Saving..')
    state = {
        'net': model.module.state_dict(),
        'epoch': epoch,
    }
    torch.save(state, MODEL_SAVE_PATH)
    writer.close()
コード例 #8
0
def main():
    """Train the Spotify skip-prediction sequence model.

    Builds train/validation dataloaders over the session log, restores the
    pretrained last.fm MLP regressor, optionally resumes the sequence model
    from the latest checkpoint, then trains with BCE-with-logits on the
    query half of each 20-item session. Validates and checkpoints every
    20000 sessions and at the end of each epoch.

    Relies on module-level state: args, TR/TS batch sizes, GPU, EPOCHS,
    MODEL_SAVE_PATH, LFM_CHECKPOINT_PATH and the hist_* history lists.
    """
    # Trainset stats: 2072002577 items from 124950714 sessions
    print('Initializing dataloader...')
    mtrain_loader = SpotifyDataloader(config_fpath=args.config,
                                      mtrain_mode=True,
                                      data_sel=(0, 99965071),  # first ~80% of the train set
                                      batch_size=TR_BATCH_SZ,
                                      shuffle=True,
                                      seq_mode=True)  # seq_mode implemented

    mval_loader = SpotifyDataloader(config_fpath=args.config,
                                    mtrain_mode=True,  # True, because we use part of trainset as testset
                                    data_sel=(99965071, 104965071),  # ~20% held out for testing; full: (99965071, 124950714)
                                    batch_size=TS_BATCH_SZ,
                                    shuffle=False,
                                    seq_mode=True)

    # Init neural net
    SM = SeqModel().cuda(GPU)
    SM_optim = torch.optim.Adam(SM.parameters(), lr=LEARNING_RATE)
    SM_scheduler = StepLR(SM_optim, step_size=1, gamma=0.7)

    # Frozen pretrained last.fm feature regressor (audio -> last.fm space).
    LFM_model = MLP_Regressor().cuda(GPU)
    LFM_checkpoint = torch.load(LFM_CHECKPOINT_PATH, map_location='cuda:{}'.format(GPU))
    LFM_model.load_state_dict(LFM_checkpoint['model_state'])

    # Load checkpoint: resume from the most recently created check*.pth.
    if args.load_continue_latest is None:
        START_EPOCH = 0
    else:
        latest_fpath = max(glob.iglob(MODEL_SAVE_PATH + "check*.pth"), key=os.path.getctime)
        checkpoint = torch.load(latest_fpath, map_location='cuda:{}'.format(GPU))
        tqdm.write("Loading saved model from '{0:}'... loss: {1:.6f}".format(latest_fpath, checkpoint['loss']))
        SM.load_state_dict(checkpoint['SM_state'])
        SM_optim.load_state_dict(checkpoint['SM_opt_state'])
        SM_scheduler.load_state_dict(checkpoint['SM_sch_state'])
        START_EPOCH = checkpoint['ep']

    # Train
    for epoch in trange(START_EPOCH, EPOCHS, desc='epochs', position=0, ascii=True):
        tqdm.write('Train...')
        tr_sessions_iter = iter(mtrain_loader)
        total_corrects = 0
        total_query    = 0
        total_trloss   = 0
        for session in trange(len(tr_sessions_iter), desc='sessions', position=1, ascii=True):
            SM.train()
            # BUGFIX: iterator.next() is Python 2 API; use the next() builtin.
            # (FIXED 13.Dec: logs are separated -- the query must not include logs.)
            x, labels, y_mask, num_items, index = next(tr_sessions_iter)

            # Sample data for 'support' and 'query': ex) 15 items = 7 sup, 8 queries...
            num_support = num_items[:,0].detach().numpy().flatten() # If num_items was odd number, query has one more item.
            num_query   = num_items[:,1].detach().numpy().flatten()
            batch_sz    = num_items.shape[0]

            # x: the first 10 items out of 20 are support items left-padded with zeros. The last 10 are queries right-padded.
            x[:,10:,:41] = 0  # zero out the meta-log features of the query half

            # labels_shift: (model can only observe past labels)
            labels_shift = torch.zeros(batch_sz,20,1)
            labels_shift[:,1:,0] = labels[:,:-1].float()
            # no labels are revealed for previous queries
            labels_shift[:,11:,0] = 0
            # support/query state labels
            sq_state = torch.zeros(batch_sz,20,1)
            sq_state[:,:11,0] = 1
            # compute lastfm_output from the audio features (columns 41+)
            x_audio = x[:,:,41:].data.clone()
            x_audio = Variable(x_audio, requires_grad=False).cuda(GPU)
            x_emb_lastfm, x_lastfm = LFM_model(x_audio)
            x_lastfm = x_lastfm.cpu()
            del x_emb_lastfm

            # Pack x: bx122*20
            x = Variable(torch.cat((x_lastfm, x, labels_shift, sq_state), dim=2).permute(0,2,1)).cuda(GPU)

            # Forward & update
            y_hat = SM(x)  # y_hat: b*20
            # Calculate BCE loss on the masked positions only
            loss = F.binary_cross_entropy_with_logits(input=y_hat*y_mask.cuda(GPU), target=labels.cuda(GPU)*y_mask.cuda(GPU))
            total_trloss += loss.item()
            SM.zero_grad()
            loss.backward()
            # Gradient Clipping
            #torch.nn.utils.clip_grad_norm_(SM.parameters(), 0.5)
            SM_optim.step()

            # Decision
            y_prob = torch.sigmoid(y_hat*y_mask.cuda(GPU)).detach().cpu().numpy()  # bx20
            # BUGFIX: np.int was removed in NumPy 1.24; plain int is equivalent.
            y_pred = (y_prob[:,10:]>=0.5).astype(int)  # bx10
            y_numpy = labels[:,10:].numpy()  # bx10
            # Acc
            y_query_mask = y_mask[:,10:].numpy()
            total_corrects += np.sum((y_pred==y_numpy)*y_query_mask)
            total_query += np.sum(num_query)
            # Restore GPU memory
            del loss, y_hat

            if (session+1)%500 == 0:
                hist_trloss.append(total_trloss/900)
                hist_tracc.append(total_corrects/total_query)
                # Prepare display
                sample_sup = labels[0,:num_support[0]].long().numpy().flatten()
                sample_que = y_numpy[0,:num_query[0]].astype(int)
                sample_pred = y_pred[0,:num_query[0]]
                sample_prob = y_prob[0,10:10+num_query[0]]
                tqdm.write("S:" + np.array2string(sample_sup) +'\n'+
                           "Q:" + np.array2string(sample_que) + '\n' +
                           "P:" + np.array2string(sample_pred) + '\n' +
                           "prob:" + np.array2string(sample_prob))
                tqdm.write("tr_session:{0:}  tr_loss:{1:.6f}  tr_acc:{2:.4f}".format(session, hist_trloss[-1], hist_tracc[-1]))
                total_corrects = 0
                total_query    = 0
                total_trloss   = 0

            if (session+1)%20000 == 0:
                 # Validation
                 validate(mval_loader, SM, LFM_model, eval_mode=True)
                 # Save
                 torch.save({'ep': epoch, 'sess':session, 'SM_state': SM.state_dict(),'loss': hist_trloss[-1], 'hist_vacc': hist_vacc,
                             'hist_vloss': hist_vloss, 'hist_trloss': hist_trloss, 'SM_opt_state': SM_optim.state_dict(),
                             'SM_sch_state': SM_scheduler.state_dict()}, MODEL_SAVE_PATH + "check_{0:}_{1:}.pth".format(epoch, session))
        # Validation
        validate(mval_loader, SM, LFM_model, eval_mode=True)
        # Save
        torch.save({'ep': epoch, 'sess':session, 'SM_state': SM.state_dict(),'loss': hist_trloss[-1], 'hist_vacc': hist_vacc,
                    'hist_vloss': hist_vloss, 'hist_trloss': hist_trloss, 'SM_opt_state': SM_optim.state_dict(),
                    'SM_sch_state': SM_scheduler.state_dict()}, MODEL_SAVE_PATH + "check_{0:}_{1:}.pth".format(epoch, session))
        SM_scheduler.step()
コード例 #9
0
def main(data_dir):
    """Train the age-estimation model on the FG-NET dataset.

    Builds the model and SGD/Adam optimizer from the global ``cfg``,
    optionally resumes from ``args.resume``, trains for cfg.TRAIN.EPOCHS
    epochs with a StepLR schedule, validates each epoch, and tracks the
    best validation MAE (optionally logging to TensorBoard).

    Args:
        data_dir: root directory of the FG-NET image data.

    Returns:
        float: the best (lowest) validation MAE observed.
    """
    print("开始训练时间:")
    start_time = time.strftime('%Y-%m-%d %H:%M:%S',
                               time.localtime(time.time()))
    print(start_time)
    args = get_args()

    if args.opts:
        cfg.merge_from_list(args.opts)

    cfg.freeze()
    start_epoch = 0
    # checkpoint_dir = Path(args.checkpoint)
    # checkpoint_dir.mkdir(parents=True, exist_ok=True)

    # create model_dir
    print("=> creating model_dir '{}'".format(cfg.MODEL.ARCH))
    # model_dir = get_model(model_name=cfg.MODEL.ARCH)
    model = my_model(True)

    if cfg.TRAIN.OPT == "sgd":
        optimizer = torch.optim.SGD(model.parameters(),
                                    lr=cfg.TRAIN.LR,
                                    momentum=cfg.TRAIN.MOMENTUM,
                                    weight_decay=cfg.TRAIN.WEIGHT_DECAY)
    else:
        optimizer = torch.optim.Adam(model.parameters(), lr=cfg.TRAIN.LR)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = model.to(device)

    # optionally resume from a checkpoint
    resume_path = args.resume

    if resume_path:
        print(Path(resume_path).is_file())
        if Path(resume_path).is_file():
            print("=> loading checkpoint '{}'".format(resume_path))
            checkpoint = torch.load(resume_path, map_location="cpu")
            start_epoch = checkpoint['epoch']
            model.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                resume_path, checkpoint['epoch']))
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        else:
            print("=> no checkpoint found at '{}'".format(resume_path))

    if args.multi_gpu:
        model = nn.DataParallel(model)

    if device == "cuda":
        cudnn.benchmark = True

    # loss criterion
    criterion = nn.CrossEntropyLoss().to(device)
    train_dataset = FaceDataset_FGNET(data_dir,
                                      "train",
                                      img_size=cfg.MODEL.IMG_SIZE,
                                      augment=True,
                                      age_stddev=cfg.TRAIN.AGE_STDDEV)
    train_loader = DataLoader(train_dataset,
                              batch_size=cfg.BATCH_SIZE,
                              shuffle=True,
                              num_workers=cfg.TRAIN.WORKERS,
                              drop_last=True)

    val_dataset = FaceDataset_FGNET(data_dir,
                                    "test",
                                    img_size=cfg.MODEL.IMG_SIZE,
                                    augment=False)
    val_loader = DataLoader(val_dataset,
                            batch_size=cfg.BATCH_SIZE,
                            shuffle=False,
                            num_workers=cfg.TRAIN.WORKERS,
                            drop_last=False)

    # BUGFIX: constructing StepLR with last_epoch != -1 requires every param
    # group to carry 'initial_lr'; without this, resuming from a checkpoint
    # (start_epoch > 0) raised KeyError: 'initial_lr'.
    for group in optimizer.param_groups:
        group.setdefault('initial_lr', group['lr'])
    scheduler = StepLR(optimizer,
                       step_size=cfg.TRAIN.LR_DECAY_STEP,
                       gamma=cfg.TRAIN.LR_DECAY_RATE,
                       last_epoch=start_epoch - 1)
    best_val_mae = 10000.0
    train_writer = None
    val_mae_list = []

    if args.tensorboard is not None:
        opts_prefix = "_".join(args.opts)
        train_writer = SummaryWriter(log_dir=args.tensorboard + "/" +
                                     opts_prefix + "_train")
        val_writer = SummaryWriter(log_dir=args.tensorboard + "/" +
                                   opts_prefix + "_val")

    for epoch in range(start_epoch, cfg.TRAIN.EPOCHS):
        # train
        train_loss, train_acc = train(train_loader, model, criterion,
                                      optimizer, epoch, device)

        # validate
        val_loss, val_acc, val_mae = validate(val_loader, model, criterion,
                                              epoch, device)
        val_mae_list.append(val_mae)

        if args.tensorboard is not None:
            train_writer.add_scalar("loss", train_loss, epoch)
            train_writer.add_scalar("acc", train_acc, epoch)
            val_writer.add_scalar("loss", val_loss, epoch)
            val_writer.add_scalar("acc", val_acc, epoch)
            val_writer.add_scalar("mae", val_mae, epoch)

        if val_mae < best_val_mae:
            print(
                f"=> [epoch {epoch:03d}] best val mae was improved from {best_val_mae:.3f} to {val_mae:.3f}"
            )
            best_val_mae = val_mae
            # checkpoint
            # if val_mae < 2.1:
            #     model_state_dict = model.module.state_dict() if args.multi_gpu else model.state_dict()
            #     torch.save(
            #         {
            #             'epoch': epoch + 1,
            #             'arch': cfg.MODEL.ARCH,
            #             'state_dict': model_state_dict,
            #             'optimizer_state_dict': optimizer.state_dict()
            #         },
            #         str(checkpoint_dir.joinpath("epoch{:03d}_{:.5f}_{:.4f}.pth".format(epoch, val_loss, val_mae)))
            #     )
        else:
            print(
                f"=> [epoch {epoch:03d}] best val mae was not improved from {best_val_mae:.3f} ({val_mae:.3f})"
            )

        # adjust learning rate
        scheduler.step()

    print("=> training finished")
    print(f"additional opts: {args.opts}")
    print(f"best val mae: {best_val_mae:.3f}")
    print("结束训练时间:")
    end_time = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time()))
    print(end_time)
    print("训练耗时: " + smtp.date_gap(start_time, end_time))
    return best_val_mae
コード例 #10
0
def main():
    """Train the 4-level multi-patch encoder/decoder deblurring network on GoPro.

    Builds four encoder/decoder pairs (one per pyramid level), optionally
    restores their weights from ./checkpoints/<METHOD>/, and trains with MSE
    between the finest-level reconstruction and the sharp ground truth.
    Every 100 epochs it runs a PSNR evaluation on the GoPro test split and
    saves the deblurred images; all eight model checkpoints are re-saved
    after every epoch.

    Relies on module-level state: args, METHOD, EPOCHS, GPU, IMAGE_SIZE,
    BATCH_SIZE, LEARNING_RATE, models, weight_init, save_deblur_images.
    """
    psnr_list = []
    ssim_list = []  # NOTE(review): never appended to below -- SSIM is not computed.
    print("init data folders")
    # One encoder/decoder pair per pyramid level (lv1 = full image ... lv4 = 1/8 patches).
    encoder_lv1 = models.Encoder()
    encoder_lv2 = models.Encoder()
    encoder_lv3 = models.Encoder()
    encoder_lv4 = models.Encoder()

    decoder_lv1 = models.Decoder()
    decoder_lv2 = models.Decoder()
    decoder_lv3 = models.Decoder()
    decoder_lv4 = models.Decoder()

    encoder_lv1.apply(weight_init).cuda(GPU)
    encoder_lv2.apply(weight_init).cuda(GPU)
    encoder_lv3.apply(weight_init).cuda(GPU)
    encoder_lv4.apply(weight_init).cuda(GPU)

    decoder_lv1.apply(weight_init).cuda(GPU)
    decoder_lv2.apply(weight_init).cuda(GPU)
    decoder_lv3.apply(weight_init).cuda(GPU)
    decoder_lv4.apply(weight_init).cuda(GPU)

    # Independent Adam optimizer + 0.1x-per-1000-epoch StepLR for each sub-network.
    encoder_lv1_optim = torch.optim.Adam(encoder_lv1.parameters(),lr=LEARNING_RATE)
    encoder_lv1_scheduler = StepLR(encoder_lv1_optim,step_size=1000,gamma=0.1)
    encoder_lv2_optim = torch.optim.Adam(encoder_lv2.parameters(),lr=LEARNING_RATE)
    encoder_lv2_scheduler = StepLR(encoder_lv2_optim,step_size=1000,gamma=0.1)
    encoder_lv3_optim = torch.optim.Adam(encoder_lv3.parameters(),lr=LEARNING_RATE)
    encoder_lv3_scheduler = StepLR(encoder_lv3_optim,step_size=1000,gamma=0.1)
    encoder_lv4_optim = torch.optim.Adam(encoder_lv4.parameters(),lr=LEARNING_RATE)
    encoder_lv4_scheduler = StepLR(encoder_lv4_optim,step_size=1000,gamma=0.1)

    decoder_lv1_optim = torch.optim.Adam(decoder_lv1.parameters(),lr=LEARNING_RATE)
    decoder_lv1_scheduler = StepLR(decoder_lv1_optim,step_size=1000,gamma=0.1)
    decoder_lv2_optim = torch.optim.Adam(decoder_lv2.parameters(),lr=LEARNING_RATE)
    decoder_lv2_scheduler = StepLR(decoder_lv2_optim,step_size=1000,gamma=0.1)
    decoder_lv3_optim = torch.optim.Adam(decoder_lv3.parameters(),lr=LEARNING_RATE)
    decoder_lv3_scheduler = StepLR(decoder_lv3_optim,step_size=1000,gamma=0.1)
    decoder_lv4_optim = torch.optim.Adam(decoder_lv4.parameters(),lr=LEARNING_RATE)
    decoder_lv4_scheduler = StepLR(decoder_lv4_optim,step_size=1000,gamma=0.1)

    # Resume each sub-network individually if its checkpoint file exists.
    if os.path.exists(str('./checkpoints/' + METHOD + "/encoder_lv1.pkl")):
        encoder_lv1.load_state_dict(torch.load(str('./checkpoints/' + METHOD + "/encoder_lv1.pkl")))
        print("load encoder_lv1 success")
    if os.path.exists(str('./checkpoints/' + METHOD + "/encoder_lv2.pkl")):
        encoder_lv2.load_state_dict(torch.load(str('./checkpoints/' + METHOD + "/encoder_lv2.pkl")))
        print("load encoder_lv2 success")
    if os.path.exists(str('./checkpoints/' + METHOD + "/encoder_lv3.pkl")):
        encoder_lv3.load_state_dict(torch.load(str('./checkpoints/' + METHOD + "/encoder_lv3.pkl")))
        print("load encoder_lv3 success")
    if os.path.exists(str('./checkpoints/' + METHOD + "/encoder_lv4.pkl")):
        encoder_lv4.load_state_dict(torch.load(str('./checkpoints/' + METHOD + "/encoder_lv4.pkl")))
        print("load encoder_lv4 success")

    if os.path.exists(str('./checkpoints/' + METHOD + "/decoder_lv1.pkl")):
        decoder_lv1.load_state_dict(torch.load(str('./checkpoints/' + METHOD + "/decoder_lv1.pkl")))
        # NOTE(review): message says "encoder_lv1" but this branch loads decoder_lv1.
        print("load encoder_lv1 success")
    if os.path.exists(str('./checkpoints/' + METHOD + "/decoder_lv2.pkl")):
        decoder_lv2.load_state_dict(torch.load(str('./checkpoints/' + METHOD + "/decoder_lv2.pkl")))
        print("load decoder_lv2 success")
    if os.path.exists(str('./checkpoints/' + METHOD + "/decoder_lv3.pkl")):
        decoder_lv3.load_state_dict(torch.load(str('./checkpoints/' + METHOD + "/decoder_lv3.pkl")))
        print("load decoder_lv3 success")
    if os.path.exists(str('./checkpoints/' + METHOD + "/decoder_lv4.pkl")):
        decoder_lv4.load_state_dict(torch.load(str('./checkpoints/' + METHOD + "/decoder_lv4.pkl")))
        print("load decoder_lv4 success")

    if os.path.exists('./checkpoints/' + METHOD) == False:
        os.system('mkdir ./checkpoints/' + METHOD)

    for epoch in range(args.start_epoch, EPOCHS):


        print("Training..........")
        print("===========================")

        train_dataset = GoProDataset(
            blur_image_files = './datas/GoPro/train_blur_file.txt',
            sharp_image_files = './datas/GoPro/train_sharp_file.txt',
            root_dir = './datas/GoPro',
            crop = True,
            crop_size = IMAGE_SIZE,
            transform = transforms.Compose([
                transforms.ToTensor()
                ]))

        train_dataloader = DataLoader(train_dataset, batch_size = BATCH_SIZE, shuffle=True)
        start = 0

        for iteration, images in enumerate(train_dataloader):
            mse = nn.MSELoss().cuda(GPU)
            #smoothL1 = nn.SmoothL1Loss().cuda(GPU)

            # Inputs are shifted to roughly [-0.5, 0.5] by subtracting 0.5.
            gt = Variable(images['sharp_image'] - 0.5).cuda(GPU)
            H = gt.size(2)
            W = gt.size(3)
            images_lv1 = Variable(images['blur_image'] - 0.5).cuda(GPU)	# shape (4, 3, 256, 256) (batch, channel, w, h)
            # Split the image into 2 (lv2), 4 (lv3) and 8 (lv4) patches:
            # lv2 splits height in half, lv3 additionally splits width,
            # lv4 additionally splits height again.
            images_lv2_1 = images_lv1[:,:,0:int(H/2),:]
            images_lv2_2 = images_lv1[:,:,int(H/2):H,:]
            images_lv3_1 = images_lv2_1[:,:,:,0:int(W/2)]
            images_lv3_2 = images_lv2_1[:,:,:,int(W/2):W]
            images_lv3_3 = images_lv2_2[:,:,:,0:int(W/2)]
            images_lv3_4 = images_lv2_2[:,:,:,int(W/2):W]
            images_lv4_1 = images_lv3_1[:,:,0:int(H/4),:]
            images_lv4_2 = images_lv3_1[:,:,int(H/4):int(H/2),:]
            images_lv4_3 = images_lv3_2[:,:,0:int(H/4),:]
            images_lv4_4 = images_lv3_2[:,:,int(H/4):int(H/2),:]
            images_lv4_5 = images_lv3_3[:,:,0:int(H/4),:]
            images_lv4_6 = images_lv3_3[:,:,int(H/4):int(H/2),:]
            images_lv4_7 = images_lv3_4[:,:,0:int(H/4),:]
            images_lv4_8 = images_lv3_4[:,:,int(H/4):int(H/2),:]

            # Level 4 (finest patches): encode each patch, stitch features
            # back together along spatial dims, decode residuals per half.
            feature_lv4_1 = encoder_lv4(images_lv4_1)
            feature_lv4_2 = encoder_lv4(images_lv4_2)
            feature_lv4_3 = encoder_lv4(images_lv4_3)
            feature_lv4_4 = encoder_lv4(images_lv4_4)
            feature_lv4_5 = encoder_lv4(images_lv4_5)
            feature_lv4_6 = encoder_lv4(images_lv4_6)
            feature_lv4_7 = encoder_lv4(images_lv4_7)
            feature_lv4_8 = encoder_lv4(images_lv4_8)
            feature_lv4_top_left = torch.cat((feature_lv4_1, feature_lv4_2), 2)
            feature_lv4_top_right = torch.cat((feature_lv4_3, feature_lv4_4), 2)
            feature_lv4_bot_left = torch.cat((feature_lv4_5, feature_lv4_6), 2)
            feature_lv4_bot_right = torch.cat((feature_lv4_7, feature_lv4_8), 2)
            feature_lv4_top = torch.cat((feature_lv4_top_left, feature_lv4_top_right), 3)
            feature_lv4_bot = torch.cat((feature_lv4_bot_left, feature_lv4_bot_right), 3)
            feature_lv4 = torch.cat((feature_lv4_top, feature_lv4_bot), 2)
            residual_lv4_top_left = decoder_lv4(feature_lv4_top_left)
            residual_lv4_top_right = decoder_lv4(feature_lv4_top_right)
            residual_lv4_bot_left = decoder_lv4(feature_lv4_bot_left)
            residual_lv4_bot_right = decoder_lv4(feature_lv4_bot_right)

            # Level 3: encode patch + coarser residual; features are fused
            # additively with the level-4 features.
            feature_lv3_1 = encoder_lv3(images_lv3_1 + residual_lv4_top_left)
            feature_lv3_2 = encoder_lv3(images_lv3_2 + residual_lv4_top_right)
            feature_lv3_3 = encoder_lv3(images_lv3_3 + residual_lv4_bot_left)
            feature_lv3_4 = encoder_lv3(images_lv3_4 + residual_lv4_bot_right)
            feature_lv3_top = torch.cat((feature_lv3_1, feature_lv3_2), 3) + feature_lv4_top
            feature_lv3_bot = torch.cat((feature_lv3_3, feature_lv3_4), 3) + feature_lv4_bot
            feature_lv3 = torch.cat((feature_lv3_top, feature_lv3_bot), 2)
            residual_lv3_top = decoder_lv3(feature_lv3_top)
            residual_lv3_bot = decoder_lv3(feature_lv3_bot)

            # Level 2: two half-images.
            feature_lv2_1 = encoder_lv2(images_lv2_1 + residual_lv3_top)
            feature_lv2_2 = encoder_lv2(images_lv2_2 + residual_lv3_bot)
            feature_lv2 = torch.cat((feature_lv2_1, feature_lv2_2), 2) + feature_lv3
            residual_lv2 = decoder_lv2(feature_lv2)

            # Level 1: full image; final reconstruction vs. sharp ground truth.
            feature_lv1 = encoder_lv1(images_lv1 + residual_lv2) + feature_lv2
            deblur_image = decoder_lv1(feature_lv1)
            loss = mse(deblur_image, gt)

            # Clear all gradients, backprop the single loss, and step every
            # optimizer -- all eight sub-networks train jointly.
            encoder_lv1.zero_grad()
            encoder_lv2.zero_grad()
            encoder_lv3.zero_grad()
            encoder_lv4.zero_grad()

            decoder_lv1.zero_grad()
            decoder_lv2.zero_grad()
            decoder_lv3.zero_grad()
            decoder_lv4.zero_grad()

            loss.backward()

            encoder_lv1_optim.step()
            encoder_lv2_optim.step()
            encoder_lv3_optim.step()
            encoder_lv4_optim.step()

            decoder_lv1_optim.step()
            decoder_lv2_optim.step()
            decoder_lv3_optim.step()
            decoder_lv4_optim.step()

            if (iteration+1)%50 == 0:
                stop = time.time()
                print("epoch:", epoch, "iteration:", iteration+1, "loss:%.4f"%loss.item(), 'time:%.4f'%(stop-start))
                start = time.time()

        # Legacy scheduler API: epoch index passed explicitly to step().
        encoder_lv1_scheduler.step(epoch)
        encoder_lv2_scheduler.step(epoch)
        encoder_lv3_scheduler.step(epoch)
        encoder_lv4_scheduler.step(epoch)

        decoder_lv1_scheduler.step(epoch)
        decoder_lv2_scheduler.step(epoch)
        decoder_lv3_scheduler.step(epoch)
        decoder_lv4_scheduler.step(epoch)

        # Periodic evaluation: same forward pipeline, full-resolution test
        # images, no gradients.
        if (epoch)%100==0:
            if os.path.exists('./checkpoints/' + METHOD + '/epoch' + str(epoch)) == False:
                os.system('mkdir ./checkpoints/' + METHOD + '/epoch' + str(epoch))

            print("Testing............")
            print("===========================")
            test_dataset = GoProDataset(
                blur_image_files = './datas/GoPro/test_blur_file.txt',
                sharp_image_files = './datas/GoPro/test_sharp_file.txt',
                root_dir = './datas/GoPro',
                transform = transforms.Compose([
                    transforms.ToTensor()
                ]))
            test_dataloader = DataLoader(test_dataset, batch_size = 1, shuffle=False)
            total_psnr = 0
            total_ssim = 0
            test_time = 0
            for iteration, images in enumerate(test_dataloader):
                with torch.no_grad():
                    start = time.time()
                    images_lv1 = Variable(images['blur_image'] - 0.5).cuda(GPU)
                    H = images_lv1.size(2)
                    W = images_lv1.size(3)
                    images_lv2_1 = images_lv1[:,:,0:int(H/2),:]
                    images_lv2_2 = images_lv1[:,:,int(H/2):H,:]
                    images_lv3_1 = images_lv2_1[:,:,:,0:int(W/2)]
                    images_lv3_2 = images_lv2_1[:,:,:,int(W/2):W]
                    images_lv3_3 = images_lv2_2[:,:,:,0:int(W/2)]
                    images_lv3_4 = images_lv2_2[:,:,:,int(W/2):W]
                    images_lv4_1 = images_lv3_1[:,:,0:int(H/4),:]
                    images_lv4_2 = images_lv3_1[:,:,int(H/4):int(H/2),:]
                    images_lv4_3 = images_lv3_2[:,:,0:int(H/4),:]
                    images_lv4_4 = images_lv3_2[:,:,int(H/4):int(H/2),:]
                    images_lv4_5 = images_lv3_3[:,:,0:int(H/4),:]
                    images_lv4_6 = images_lv3_3[:,:,int(H/4):int(H/2),:]
                    images_lv4_7 = images_lv3_4[:,:,0:int(H/4),:]
                    images_lv4_8 = images_lv3_4[:,:,int(H/4):int(H/2),:]

                    feature_lv4_1 = encoder_lv4(images_lv4_1)
                    feature_lv4_2 = encoder_lv4(images_lv4_2)
                    feature_lv4_3 = encoder_lv4(images_lv4_3)
                    feature_lv4_4 = encoder_lv4(images_lv4_4)
                    feature_lv4_5 = encoder_lv4(images_lv4_5)
                    feature_lv4_6 = encoder_lv4(images_lv4_6)
                    feature_lv4_7 = encoder_lv4(images_lv4_7)
                    feature_lv4_8 = encoder_lv4(images_lv4_8)

                    feature_lv4_top_left = torch.cat((feature_lv4_1, feature_lv4_2), 2)
                    feature_lv4_top_right = torch.cat((feature_lv4_3, feature_lv4_4), 2)
                    feature_lv4_bot_left = torch.cat((feature_lv4_5, feature_lv4_6), 2)
                    feature_lv4_bot_right = torch.cat((feature_lv4_7, feature_lv4_8), 2)

                    feature_lv4_top = torch.cat((feature_lv4_top_left, feature_lv4_top_right), 3)
                    feature_lv4_bot = torch.cat((feature_lv4_bot_left, feature_lv4_bot_right), 3)

                    residual_lv4_top_left = decoder_lv4(feature_lv4_top_left)
                    residual_lv4_top_right = decoder_lv4(feature_lv4_top_right)
                    residual_lv4_bot_left = decoder_lv4(feature_lv4_bot_left)
                    residual_lv4_bot_right = decoder_lv4(feature_lv4_bot_right)

                    feature_lv3_1 = encoder_lv3(images_lv3_1 + residual_lv4_top_left)
                    feature_lv3_2 = encoder_lv3(images_lv3_2 + residual_lv4_top_right)
                    feature_lv3_3 = encoder_lv3(images_lv3_3 + residual_lv4_bot_left)
                    feature_lv3_4 = encoder_lv3(images_lv3_4 + residual_lv4_bot_right)

                    feature_lv3_top = torch.cat((feature_lv3_1, feature_lv3_2), 3) + feature_lv4_top
                    feature_lv3_bot = torch.cat((feature_lv3_3, feature_lv3_4), 3) + feature_lv4_bot

                    residual_lv3_top = decoder_lv3(feature_lv3_top)
                    residual_lv3_bot = decoder_lv3(feature_lv3_bot)

                    feature_lv2_1 = encoder_lv2(images_lv2_1 + residual_lv3_top)
                    feature_lv2_2 = encoder_lv2(images_lv2_2 + residual_lv3_bot)
                    # NOTE(review): test-time fusion additionally sums the lv3
                    # and lv4 feature maps, unlike the training path above.
                    feature_lv2 = torch.cat((feature_lv2_1, feature_lv2_2), 2) + torch.cat((feature_lv3_top, feature_lv3_bot), 2) + torch.cat((feature_lv4_top, feature_lv4_bot), 2)
                    residual_lv2 = decoder_lv2(feature_lv2)

                    feature_lv1 = encoder_lv1(images_lv1 + residual_lv2) + feature_lv2
                    deblur_image = decoder_lv1(feature_lv1)
                    stop = time.time()
                    test_time += stop - start

                    # Undo the -0.5 shift before comparing to the sharp image.
                    psnr = compare_psnr(images['sharp_image'].numpy()[0], deblur_image.detach().cpu().numpy()[0]+0.5)
                    #psnr = PSNR(images['sharp_image'].numpy()[0], deblur_image.detach().cpu().numpy()[0]+0.5)
                    total_psnr += psnr

                    if (iteration+1)%50 == 0:
                        print('PSNR:%.4f'%(psnr), '  Average PSNR:%.4f'%(total_psnr/(iteration+1)))

                    save_deblur_images(deblur_image.data + 0.5, iteration, epoch)
                    #save_sharp_images(images['sharp_image'][0], iteration, epoch)

            psnr_list.append(total_psnr/(iteration+1))
            print("PSNR list:")
            print(psnr_list)


        # Persist all eight sub-networks after every epoch.
        torch.save(encoder_lv1.state_dict(),str('./checkpoints/' + METHOD + "/encoder_lv1.pkl"))
        torch.save(encoder_lv2.state_dict(),str('./checkpoints/' + METHOD + "/encoder_lv2.pkl"))
        torch.save(encoder_lv3.state_dict(),str('./checkpoints/' + METHOD + "/encoder_lv3.pkl"))
        torch.save(encoder_lv4.state_dict(),str('./checkpoints/' + METHOD + "/encoder_lv4.pkl"))

        torch.save(decoder_lv1.state_dict(),str('./checkpoints/' + METHOD + "/decoder_lv1.pkl"))
        torch.save(decoder_lv2.state_dict(),str('./checkpoints/' + METHOD + "/decoder_lv2.pkl"))
        torch.save(decoder_lv3.state_dict(),str('./checkpoints/' + METHOD + "/decoder_lv3.pkl"))
        torch.save(decoder_lv4.state_dict(),str('./checkpoints/' + METHOD + "/decoder_lv4.pkl"))
def main_single_model():
    """Episodically train a single relation network and periodically validate
    it on held-out character folders, tracking the best validation accuracy.

    Relies on module-level names: tgtpu, CNN_Plus_RNEncoder, weights_init,
    and the configuration constants DATASET_FOLDER, SAMPLE_NUM_PER_CLASS,
    QUERY_NUM_PER_CLASS, VALIDATION_SPLIT_PERCENTAGE, CLASS_NUM, EPISODE,
    TEST_EPISODE, LEARNING_RATE, GPU, IMAGE_SIZE, NO_OF_TPU_CORES.
    """
    # Step 1: init data folders
    print("init data folders", flush=True)
    # init character folders for dataset construction
    metatrain_character_folders, metatest_character_folders = tgtpu.china_drinks_sku_folders(
        DATASET_FOLDER, SAMPLE_NUM_PER_CLASS, QUERY_NUM_PER_CLASS,
        VALIDATION_SPLIT_PERCENTAGE)

    # Step 2: init neural networks
    print("init neural networks")

    relation_network = CNN_Plus_RNEncoder()
    relation_network.apply(weights_init)
    relation_network.cuda(GPU)
    relation_network_optim = torch.optim.Adam(relation_network.parameters(),
                                              lr=LEARNING_RATE)
    relation_network_scheduler = StepLR(relation_network_optim,
                                        step_size=100000,
                                        gamma=0.5)

    # Resume from a previously saved checkpoint when one exists.
    model_path = str("./models/omniglot_full_relation_network_" +
                     str(CLASS_NUM) + "way_" + str(SAMPLE_NUM_PER_CLASS) +
                     "shot.pkl")
    if os.path.exists(model_path):
        relation_network.load_state_dict(
            torch.load(model_path, map_location='cuda:0'))
        print("load full relation network success")

    # Step 3: build graph
    print("Training...")

    last_accuracy = 0.0

    for episode in range(EPISODE):

        relation_network_scheduler.step(episode)

        # init dataset
        # sample_dataloader is to obtain previous samples for compare
        # batch_dataloader is to batch samples for training
        degrees = random.choice([0, 90, 180, 270])
        task = tgtpu.ChinaDrinksTask(metatrain_character_folders, CLASS_NUM,
                                     SAMPLE_NUM_PER_CLASS, QUERY_NUM_PER_CLASS)
        sample_batch_dataloader = tgtpu.get_data_loader(
            task,
            image_size=IMAGE_SIZE,
            sample_num_per_class=SAMPLE_NUM_PER_CLASS,
            query_num_per_class=QUERY_NUM_PER_CLASS,
            train_shuffle=False,
            query_shuffle=True,
            rotation=degrees,
            num_workers=NO_OF_TPU_CORES)

        # `.__iter__().next()` is Python 2 only; next(iter(...)) works on
        # both and is the idiomatic way to pull a single batch.
        samples, sample_labels, batches, batch_labels = next(
            iter(sample_batch_dataloader))
        relation_scores = relation_network(
            Variable(samples).cuda(GPU),
            Variable(batches).cuda(GPU))
        relations = relation_scores.view(-1, CLASS_NUM)

        # One-hot encode the query labels for the MSE relation loss.
        mse = nn.MSELoss().cuda(GPU)
        one_hot_labels = Variable(
            torch.zeros(QUERY_NUM_PER_CLASS * CLASS_NUM,
                        CLASS_NUM).scatter_(1, batch_labels.view(-1, 1),
                                            1)).cuda(GPU)
        loss = mse(relations, one_hot_labels)

        # training

        relation_network.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(relation_network.parameters(), 0.5)
        relation_network_optim.step()

        if (episode + 1) % 100 == 0:
            print("episode:", episode + 1, "loss", loss.data, flush=True)

        if (episode + 1) % 5000 == 0:

            # test
            print("Testing...")
            total_rewards = 0

            for i in range(TEST_EPISODE):
                degrees = random.choice([0, 90, 180, 270])
                task = tgtpu.ChinaDrinksTask(
                    metatest_character_folders,
                    CLASS_NUM,
                    SAMPLE_NUM_PER_CLASS,
                    SAMPLE_NUM_PER_CLASS,
                )
                sample_test_dataloader = tgtpu.get_data_loader(
                    task,
                    IMAGE_SIZE,
                    sample_num_per_class=SAMPLE_NUM_PER_CLASS,
                    query_num_per_class=QUERY_NUM_PER_CLASS,
                    train_shuffle=False,
                    test_shuffle=True,
                    rotation=degrees,
                    num_workers=NO_OF_TPU_CORES)

                sample_images, sample_labels, test_images, test_labels = next(
                    iter(sample_test_dataloader))

                # `.size` is a bound method; call it to print the shapes.
                print(sample_images.size(), test_images.size())

                test_labels = test_labels.cuda()

                # BUG FIX: the original scored the stale training tensors
                # (samples/batches) here, so validation accuracy never
                # reflected the test task. Use the freshly drawn test data.
                relations = relation_network(
                    Variable(sample_images).cuda(GPU),
                    Variable(test_images).cuda(GPU)).view(-1, CLASS_NUM)

                _, predict_labels = torch.max(relations.data, 1)

                rewards = [
                    1 if predict_labels[j] == test_labels[j] else 0
                    for j in range(CLASS_NUM * SAMPLE_NUM_PER_CLASS)
                ]

                total_rewards += np.sum(rewards)

            test_accuracy = total_rewards / 1.0 / CLASS_NUM / SAMPLE_NUM_PER_CLASS / TEST_EPISODE

            print("validation accuracy:", test_accuracy)

            if test_accuracy > last_accuracy:

                # save networks
                # NOTE(review): the actual torch.save call is commented out in
                # the original; kept disabled here to preserve behavior.
                #torch.save(relation_network.state_dict(), model_path)

                print("save networks for episode:", episode)

                last_accuracy = test_accuracy
コード例 #12
0
def main():
    """Train a 5-way relation network on mini-ImageNet using a ResNet-18
    feature embedding, validating every 200 training steps and checkpointing
    the best-performing weights.

    Relies on module-level names: resnet18, RelationNetWork, weight_init,
    MiniImagenet, and logger.
    """
    n_way = 5
    k_shot = 1
    k_query = 5  #  5-way-5shot
    batchsz = 3
    best_acc = 0
    mdfile1 = './ckpy/res_feature-%d-way-%d-shot.pkl' % (n_way, k_shot)
    mdfile2 = './ckpy/res_relation-%d-way-%d-shot.pkl' % (n_way, k_shot)
    feature_embed = resnet18().cuda()

    Relation_score = RelationNetWork(64, 8).cuda()  # relation_dim == 8 ??

    Relation_score.apply(weight_init)

    feature_optim = torch.optim.Adam(feature_embed.parameters(), lr=0.001)
    relation_opim = torch.optim.Adam(Relation_score.parameters(), lr=0.001)

    feature_optim_scheduler = StepLR(feature_optim, step_size=10,
                                     gamma=0.5)  # 1-shot 1w , 5-shot 5k
    relation_opim_scheduler = StepLR(relation_opim, step_size=10, gamma=0.5)

    loss_fn = torch.nn.MSELoss().cuda()

    # Resume from checkpoints when present.
    if os.path.exists(mdfile1):
        print("load mdfile1...")
        feature_embed.load_state_dict(torch.load(mdfile1))
    if os.path.exists(mdfile2):
        print("load mdfile2...")
        Relation_score.load_state_dict(torch.load(mdfile2))

    for epoch in range(100):
        feature_optim_scheduler.step(epoch)  # decay the learning rates
        relation_opim_scheduler.step(epoch)

        mini = MiniImagenet('./mini-imagenet/',
                            mode='train',
                            n_way=n_way,
                            k_shot=k_shot,
                            k_query=k_query,
                            batchsz=6000,
                            resize=224)  #38400
        db = DataLoader(mini,
                        batch_size=batchsz,
                        shuffle=True,
                        num_workers=0,
                        pin_memory=False)  # 64 , 5*(1+15) , c, h, w
        mini_val = MiniImagenet('./mini-imagenet/',
                                mode='val',
                                n_way=n_way,
                                k_shot=k_shot,
                                k_query=k_query,
                                batchsz=200,
                                resize=224)  #9600
        db_val = DataLoader(mini_val,
                            batch_size=batchsz,
                            shuffle=True,
                            num_workers=0,
                            pin_memory=False)

        for step, batch in enumerate(db):
            support_x = Variable(batch[0]).cuda(
            )  # [batch_size, n_way*(k_shot+k_query), c , h , w]
            support_y = Variable(batch[1]).cuda()
            query_x = Variable(batch[2]).cuda()
            query_y = Variable(batch[3]).cuda()

            bh, set1, c, h, w = support_x.size()
            set2 = query_x.size(1)

            feature_embed.train()
            Relation_score.train()

            # Embed support and query images; the backbone is assumed to emit
            # a (256, 14, 14) feature map per image -- TODO confirm.
            support_xf = feature_embed(support_x.view(
                bh * set1, c, h, w)).view(bh, set1, 256, 14, 14)
            query_xf = feature_embed(query_x.view(bh * set2, c, h, w)).view(
                bh, set2, 256, 14, 14)

            # Pair every query with every support sample so the relation head
            # can score all combinations at once.
            support_xf = support_xf.unsqueeze(1).expand(
                bh, set2, set1, 256, 14, 14)

            query_xf = query_xf.unsqueeze(2).expand(bh, set2, set1, 256, 14,
                                                    14)

            comb = torch.cat((support_xf, query_xf),
                             dim=3)  # bh,set2,set1,2c,h,w
            score = Relation_score(comb.view(bh * set2 * set1, 2 * 256, 14,
                                             14)).view(bh, set2, set1,
                                                       1).squeeze(3)

            # Target is 1 when a query/support pair shares a label, else 0.
            support_yf = support_y.unsqueeze(1).expand(bh, set2, set1)
            query_yf = query_y.unsqueeze(2).expand(bh, set2, set1)
            label = torch.eq(support_yf, query_yf).float()

            feature_optim.zero_grad()
            relation_opim.zero_grad()

            loss = loss_fn(score, label)
            loss.backward()

            # clip_grad_norm was deprecated and later removed; use the
            # in-place variant clip_grad_norm_.
            torch.nn.utils.clip_grad_norm_(feature_embed.parameters(), 0.5)
            torch.nn.utils.clip_grad_norm_(Relation_score.parameters(), 0.5)

            feature_optim.step()
            relation_opim.step()

            # loss.data[0] raises on 0-dim tensors in PyTorch >= 0.4;
            # loss.item() is the supported scalar accessor.
            logger.log_value(
                'resnet_{}-way-{}-shot loss:'.format(n_way, k_shot),
                loss.item())

            if step % 200 == 0:
                print("---------test--------")

                total_correct = 0
                total_num = 0
                accuracy = 0
                for j, batch_test in enumerate(db_val):
                    support_x = Variable(batch_test[0]).cuda()
                    support_y = Variable(batch_test[1]).cuda()
                    query_x = Variable(batch_test[2]).cuda()
                    query_y = Variable(batch_test[3]).cuda()

                    bh, set1, c, h, w = support_x.size()
                    set2 = query_x.size(1)

                    feature_embed.eval()
                    Relation_score.eval()

                    support_xf = feature_embed(
                        support_x.view(bh * set1, c, h,
                                       w)).view(bh, set1, 256, 14, 14)
                    query_xf = feature_embed(query_x.view(
                        bh * set2, c, h, w)).view(bh, set2, 256, 14, 14)

                    support_xf = support_xf.unsqueeze(1).expand(
                        bh, set2, set1, 256, 14, 14)
                    query_xf = query_xf.unsqueeze(2).expand(
                        bh, set2, set1, 256, 14, 14)

                    comb = torch.cat((support_xf, query_xf),
                                     dim=3)  # bh,set2,set1,2c,h,w
                    score = Relation_score(
                        comb.view(bh * set2 * set1, 2 * 256, 14,
                                  14)).view(bh, set2, set1, 1).squeeze(3)

                    rn_score_np = score.cpu().data.numpy()  # move to CPU numpy
                    pred = []
                    support_y_np = support_y.cpu().data.numpy()

                    # For each query, sum the relation scores over the k_shot
                    # supports of each class and predict the highest-scoring
                    # class.
                    for ii, tb in enumerate(rn_score_np):
                        for jj, tset in enumerate(tb):
                            sim = []
                            for way in range(n_way):
                                sim.append(
                                    np.sum(tset[way * k_shot:(way + 1) *
                                                k_shot]))

                            idx = np.array(sim).argmax()
                            # All supports of one class share a label, so any
                            # index inside the winning class block works;
                            # * k_shot undoes the per-class sum above.
                            pred.append(support_y_np[ii, idx * k_shot])

                    # pred now holds bh*set2 predictions; reshape to [bh, set2].
                    pred = Variable(
                        torch.from_numpy(np.array(pred).reshape(bh,
                                                                set2))).cuda()
                    correct = torch.eq(pred, query_y).sum()

                    # correct.data[0] raises on 0-dim tensors; use .item().
                    total_correct += correct.item()
                    total_num += query_y.size(0) * query_y.size(1)

                accuracy = total_correct / total_num
                logger.log_value('acc : ', accuracy)

                print("epoch:", epoch, "acc:", accuracy)
                if accuracy > best_acc:
                    print("-------------------epoch", epoch, "step:", step,
                          "acc:", accuracy,
                          "---------------------------------------")
                    best_acc = accuracy
                    torch.save(feature_embed.state_dict(), mdfile1)
                    torch.save(Relation_score.state_dict(), mdfile2)
            logger.step()
コード例 #13
0
def main():
    """Train an MNIST classifier while emitting NVTX ranges so an external
    CUDA profiler (e.g. Nsight Systems) can capture exactly the second epoch.

    Relies on module-level names: Net, train, test, and nvtx.
    """
    # Training settings
    parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
    parser.add_argument('--batch-size',
                        type=int,
                        default=64,
                        metavar='N',
                        help='input batch size for training (default: 64)')
    parser.add_argument('--test-batch-size',
                        type=int,
                        default=1000,
                        metavar='N',
                        help='input batch size for testing (default: 1000)')
    parser.add_argument('--epochs',
                        type=int,
                        default=10,
                        metavar='N',
                        help='number of epochs to train (default: 10)')
    parser.add_argument('--lr',
                        type=float,
                        default=1.0,
                        metavar='LR',
                        help='learning rate (default: 1.0)')
    parser.add_argument('--gamma',
                        type=float,
                        default=0.7,
                        metavar='M',
                        help='Learning rate step gamma (default: 0.7)')
    parser.add_argument('--no-cuda',
                        action='store_true',
                        default=False,
                        help='disables CUDA training')
    parser.add_argument('--dry-run',
                        action='store_true',
                        default=False,
                        help='quickly check a single pass')
    parser.add_argument('--seed',
                        type=int,
                        default=1,
                        metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument(
        '--log-interval',
        type=int,
        default=100,
        metavar='N',
        help='how many batches to wait before logging training status')
    parser.add_argument('--save-model',
                        action='store_true',
                        default=False,
                        help='For Saving the current Model')
    args = parser.parse_args()
    use_cuda = not args.no_cuda and torch.cuda.is_available()

    torch.manual_seed(args.seed)

    device = torch.device("cuda" if use_cuda else "cpu")

    train_kwargs = {'batch_size': args.batch_size}
    test_kwargs = {'batch_size': args.test_batch_size}
    if use_cuda:
        cuda_kwargs = {'num_workers': 1, 'shuffle': True}
        train_kwargs.update(cuda_kwargs)
        test_kwargs.update(cuda_kwargs)

    transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize((0.1307, ), (0.3081, ))])
    # Keep the data directory next to this script regardless of the cwd.
    scriptPath = os.path.dirname(os.path.realpath(__file__))
    dataDir = os.path.join(scriptPath, 'data')
    dataset1 = datasets.MNIST(dataDir,
                              train=True,
                              download=True,
                              transform=transform)
    dataset2 = datasets.MNIST(dataDir, train=False, transform=transform)
    train_loader = torch.utils.data.DataLoader(dataset1, **train_kwargs)
    test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs)

    model = Net().to(device)
    optimizer = optim.Adadelta(model.parameters(), lr=args.lr)

    scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)
    for epoch in range(1, args.epochs + 1):
        # Start profiling from 2nd epoch (the 1st is treated as warm-up).
        if epoch == 2:
            torch.cuda.cudart().cudaProfilerStart()

        nvtx.range_push("Epoch " + str(epoch))
        nvtx.range_push("Train")
        train(args, model, device, train_loader, optimizer, epoch)
        nvtx.range_pop()  # Train

        nvtx.range_push("Test")
        test(model, device, test_loader)
        nvtx.range_pop()  # Test

        scheduler.step()
        nvtx.range_pop()  # Epoch
        # Stop profiling at the end of 2nd epoch
        if epoch == 2:
            torch.cuda.cudart().cudaProfilerStop()

    if args.save_model:
        torch.save(model.state_dict(), "mnist_cnn.pt")
コード例 #14
0
def main():
    """Fine-tune an ImageNet-pretrained ResNet-50 on a 20-class iMet subset
    with a disjoint random 85/15 train/validation split, then plot loss and
    accuracy curves per epoch.

    Relies on module-level names: TrainDataset, M (a torchvision.models-like
    namespace), train, validation, and plt.
    """
    # Training settings
    # Use the command line to modify the default settings
    parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
    parser.add_argument('--batch-size',
                        type=int,
                        default=64,
                        metavar='N',
                        help='input batch size for training (default: 64)')
    parser.add_argument('--test-batch-size',
                        type=int,
                        default=64,
                        metavar='N',
                        help='input batch size for testing (default: 64)')
    parser.add_argument('--epochs',
                        type=int,
                        default=10,
                        metavar='N',
                        help='number of epochs to train (default: 10)')
    parser.add_argument('--lr',
                        type=float,
                        default=1.0,
                        metavar='LR',
                        help='learning rate (default: 1.0)')
    parser.add_argument(
        '--step',
        type=int,
        default=1,
        metavar='N',
        help='number of epochs between learning rate reductions (default: 1)')
    parser.add_argument('--gamma',
                        type=float,
                        default=0.7,
                        metavar='M',
                        help='Learning rate step gamma (default: 0.7)')
    parser.add_argument('--no-cuda',
                        action='store_true',
                        default=False,
                        help='disables CUDA training')
    parser.add_argument('--seed',
                        type=int,
                        default=1,
                        metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument(
        '--log-interval',
        type=int,
        default=100,
        metavar='N',
        help='how many batches to wait before logging training status')

    parser.add_argument('--evaluate',
                        action='store_true',
                        default=False,
                        help='evaluate your model on the official test set')
    parser.add_argument('--load-model', type=str, help='model file path')

    parser.add_argument('--save-model',
                        action='store_true',
                        default=True,
                        help='For Saving the current Model')

    parser.add_argument('--test-datasize',
                        action='store_true',
                        default=False,
                        help='train on different sizes of dataset')

    args = parser.parse_args()
    use_cuda = not args.no_cuda and torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
    torch.manual_seed(args.seed)

    # NOTE(review): --evaluate and --test-datasize are accepted but currently
    # ignored (their implementations were removed as dead commented-out code).
    train_dataset_no_aug = TrainDataset(
        True,
        'data/imet-2020-fgvc7/labels.csv',
        'data/imet-2020-fgvc7/train_20country.csv',
        'data/imet-2020-fgvc7/train/',
        transform=transforms.Compose([  # Data preprocessing
            transforms.ToPILImage(),  # Add data augmentation here
            transforms.Resize(255),
            transforms.RandomCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=(0.485, 0.456, 0.406),
                                 std=(0.229, 0.224, 0.225))
        ]))
    train_dataset_with_aug = train_dataset_no_aug
    assert (len(train_dataset_no_aug) == len(train_dataset_with_aug))

    # Build disjoint train/validation index sets: 15% of the samples are
    # drawn (without replacement) for validation, the rest for training.
    np.random.seed(args.seed)
    subset_indices_valid = np.random.choice(len(train_dataset_no_aug),
                                            int(0.15 *
                                                len(train_dataset_no_aug)),
                                            replace=False)
    # Membership tests against an ndarray scan it linearly, which made this
    # split O(n^2); a set gives O(1) lookups.
    valid_index_set = set(subset_indices_valid.tolist())
    subset_indices_train = [
        i for i in range(len(train_dataset_no_aug))
        if i not in valid_index_set
    ]

    # Sanity checks: the two index sets cover the dataset and are disjoint.
    assert (len(subset_indices_train) +
            len(subset_indices_valid)) == len(train_dataset_no_aug)
    assert len(np.intersect1d(subset_indices_train, subset_indices_valid)) == 0

    train_loader = torch.utils.data.DataLoader(
        train_dataset_with_aug,
        batch_size=args.batch_size,
        sampler=SubsetRandomSampler(subset_indices_train))
    val_loader = torch.utils.data.DataLoader(
        train_dataset_no_aug,
        batch_size=args.test_batch_size,
        sampler=SubsetRandomSampler(subset_indices_valid))

    # Load an ImageNet-pretrained ResNet-50 and replace the classification
    # head with a 20-class linear layer for this task.
    model = M.resnet50(pretrained=True)
    model.fc = nn.Linear(model.fc.in_features, 20)
    model = model.to(device)

    # Try different optimzers here [Adam, SGD, RMSprop]
    optimizer = optim.Adadelta(model.parameters(), lr=args.lr)

    # Set your learning rate scheduler
    scheduler = StepLR(optimizer, step_size=args.step, gamma=args.gamma)

    # Training loop
    train_losses = []
    val_losses = []
    accuracies = []
    for epoch in range(1, args.epochs + 1):
        train_loss = train(args, model, device, train_loader, optimizer, epoch)
        (accuracy, val_loss) = validation(model, device, val_loader)
        train_losses.append(train_loss)
        val_losses.append(val_loss)
        accuracies.append(accuracy)
        scheduler.step()  # learning rate scheduler
        # You may optionally save your model at each epoch here
        if args.save_model:
            torch.save(model.state_dict(), "mnist_model.pt")

    plt.plot(range(1, args.epochs + 1), train_losses)
    plt.plot(range(1, args.epochs + 1), val_losses)
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.legend(["Training loss", "Val loss"])
    plt.title("Training loss and val loss as a function of the epoch")
    plt.show()

    plt.plot(range(1, args.epochs + 1), accuracies)
    plt.xlabel("Epoch")
    plt.ylabel("Accuracy")
    plt.legend(["Validation Accuracy"])
    plt.title("Accuracy in validation set as a function of the epoch")
    plt.show()
コード例 #15
0
class DACAgent(BaseAgent):
    """Discriminator-Actor-Critic (DAC) style adversarial imitation agent.

    Wraps an off-policy RL agent (DDPG/TD3 or SAC, per ``dac_rl_algo``) whose
    reward is replaced by a GAIL-style discriminator trained to separate
    expert transitions (from a demo dataset) from policy transitions (from
    the replay buffer).
    """

    def __init__(self, config, ob_space, ac_space, env_ob_space):
        super().__init__(config, ob_space)

        self._ob_space = ob_space
        self._ac_space = ac_space

        # Underlying RL algorithm; its environment reward is overridden by
        # the discriminator-based reward defined below.
        if self._config.dac_rl_algo == "td3":
            self._rl_agent = DDPGAgent(config, ob_space, ac_space,
                                       env_ob_space)
        elif self._config.dac_rl_algo == "sac":
            self._rl_agent = SACAgent(config, ob_space, ac_space, env_ob_space)
        self._rl_agent.set_reward_function(self._predict_reward)

        # build up networks
        # (state-only discriminator when gail_no_action is set)
        self._discriminator = Discriminator(
            config, ob_space, ac_space if not config.gail_no_action else None)
        self._discriminator_loss = nn.BCEWithLogitsLoss()

        # build optimizers
        self._discriminator_optim = optim.Adam(
            self._discriminator.parameters(), lr=config.discriminator_lr)

        # build learning rate scheduler: halves the discriminator LR five
        # times over the course of max_global_step scheduler steps
        self._discriminator_lr_scheduler = StepLR(
            self._discriminator_optim,
            step_size=self._config.max_global_step // 5,
            gamma=0.5,
        )

        # expert dataset
        self._dataset = ExpertDataset(config.demo_path,
                                      config.demo_subsample_interval)
        self._data_loader = torch.utils.data.DataLoader(
            self._dataset, batch_size=self._config.batch_size, shuffle=True)
        self._data_iter = iter(self._data_loader)

        # per-episode replay buffer
        # sampler = RandomSampler(image_crop_size=config.encoder_image_size)
        # buffer_keys = ["ob", "ac", "done", "rew"]
        # self._buffer = ReplayBuffer(
        #     buffer_keys, config.buffer_size, sampler.sample_func
        # )

        # per-step replay buffer
        shapes = {
            "ob": spaces_to_shapes(env_ob_space),
            "ob_next": spaces_to_shapes(env_ob_space),
            "ac": spaces_to_shapes(ac_space),
            "done": [1],
            "done_mask": [1],
            "rew": [1],
        }
        self._buffer = ReplayBufferPerStep(
            shapes,
            config.buffer_size,
            config.encoder_image_size,
            config.absorbing_state,
        )

        # the RL agent samples transitions from the same buffer
        self._rl_agent.set_buffer(self._buffer)

        self._update_iter = 0

        self._log_creation()

    def _predict_reward(self, ob, ac):
        """Compute the imitation reward from discriminator logits.

        Assumes ``ob``/``ac`` are already normalized tensors (see
        ``predict_reward`` for the raw-observation entry point).
        """
        if self._config.gail_no_action:
            ac = None
        with torch.no_grad():
            ret = self._discriminator(ob, ac)
            eps = 1e-20  # avoids log(0)
            s = torch.sigmoid(ret)
            if self._config.gail_vanilla_reward:
                # vanilla GAIL reward: -log(1 - D(s, a))
                reward = -(1 - s + eps).log()
            else:
                # AIRL-style reward: log D - log(1 - D)
                reward = (s + eps).log() - (1 - s + eps).log()
        return reward

    def predict_reward(self, ob, ac=None):
        """Normalize a raw observation, tensorize, and return a scalar reward.

        NOTE(review): ``.item()`` implies a single transition — confirm
        callers never pass a batch here.
        """
        ob = self.normalize(ob)
        ob = to_tensor(ob, self._config.device)
        if self._config.gail_no_action:
            ac = None
        if ac is not None:
            ac = to_tensor(ac, self._config.device)

        reward = self._predict_reward(ob, ac)
        return reward.cpu().item()

    def _log_creation(self):
        """Log agent creation info on the chef (rank-0) process only."""
        if self._config.is_chef:
            logger.info("Creating a DAC agent")
            logger.info(
                "The discriminator has %d parameters",
                count_parameters(self._discriminator),
            )

    def store_episode(self, rollouts):
        """Append a rollout to the shared per-step replay buffer."""
        self._buffer.store_episode(rollouts)

    def state_dict(self):
        """Return a checkpointable dict of all learned state."""
        return {
            "rl_agent":
            self._rl_agent.state_dict(),
            "discriminator_state_dict":
            self._discriminator.state_dict(),
            "discriminator_optim_state_dict":
            self._discriminator_optim.state_dict(),
            "ob_norm_state_dict":
            self._ob_norm.state_dict(),
        }

    def load_state_dict(self, ckpt):
        """Restore from a checkpoint produced by ``state_dict``.

        Also accepts a bare RL-agent checkpoint (no "rl_agent" key) for
        backward compatibility with older checkpoints.
        """
        if "rl_agent" in ckpt:
            self._rl_agent.load_state_dict(ckpt["rl_agent"])
        else:
            self._rl_agent.load_state_dict(ckpt)

        self._discriminator.load_state_dict(ckpt["discriminator_state_dict"])
        self._ob_norm.load_state_dict(ckpt["ob_norm_state_dict"])
        self._network_cuda(self._config.device)

        self._discriminator_optim.load_state_dict(
            ckpt["discriminator_optim_state_dict"])
        optimizer_cuda(self._discriminator_optim, self._config.device)

    def _network_cuda(self, device):
        """Move the discriminator to the configured device."""
        self._discriminator.to(device)

    def sync_networks(self):
        """Synchronize network parameters across distributed workers."""
        self._rl_agent.sync_networks()
        sync_networks(self._discriminator)

    def train(self):
        """Run one training iteration: RL update plus a periodic
        discriminator update (every ``discriminator_update_freq`` iters)."""
        train_info = Info()

        # NOTE(review): stepping the scheduler before the optimizer is the
        # pre-PyTorch-1.1 ordering; kept as-is to preserve the LR schedule.
        self._discriminator_lr_scheduler.step()

        _train_info = self._rl_agent.train()
        train_info.add(_train_info)

        if self._update_iter % self._config.discriminator_update_freq == 0:
            self._num_updates = 1
            for _ in range(self._num_updates):
                policy_data = self._buffer.sample(self._config.batch_size)
                try:
                    expert_data = next(self._data_iter)
                except StopIteration:
                    # restart the expert iterator after a full pass
                    self._data_iter = iter(self._data_loader)
                    expert_data = next(self._data_iter)
                _train_info = self._update_discriminator(
                    policy_data, expert_data)
                train_info.add(_train_info)

        return train_info.get_dict(only_scalar=True)

    def _update_discriminator(self, policy_data, expert_data):
        """One discriminator update: policy batch labeled 0, expert batch 1."""
        info = Info()

        _to_tensor = lambda x: to_tensor(x, self._config.device)
        # pre-process observations
        p_o = policy_data["ob"]
        p_o = self.normalize(p_o)

        p_bs = len(policy_data["ac"])
        p_o = _to_tensor(p_o)
        if self._config.gail_no_action:
            p_ac = None
        else:
            p_ac = _to_tensor(policy_data["ac"])

        e_o = expert_data["ob"]
        e_o = self.normalize(e_o)

        e_bs = len(expert_data["ac"])
        e_o = _to_tensor(e_o)
        if self._config.gail_no_action:
            e_ac = None
        else:
            e_ac = _to_tensor(expert_data["ac"])

        p_logit = self._discriminator(p_o, p_ac)
        e_logit = self._discriminator(e_o, e_ac)

        p_output = torch.sigmoid(p_logit)
        e_output = torch.sigmoid(e_logit)

        # BCE-with-logits: policy transitions -> 0, expert transitions -> 1
        p_loss = self._discriminator_loss(
            p_logit,
            torch.zeros_like(p_logit).to(self._config.device))
        e_loss = self._discriminator_loss(
            e_logit,
            torch.ones_like(e_logit).to(self._config.device))

        # Entropy bonus on the discriminator output regularizes training.
        # NOTE(review): Bernoulli's first positional argument is `probs`,
        # but `logits` here are raw scores — likely should be
        # Bernoulli(logits=logits); confirm before changing, it alters the
        # entropy value.
        logits = torch.cat([p_logit, e_logit], dim=0)
        entropy = torch.distributions.Bernoulli(logits).entropy().mean()
        entropy_loss = -self._config.gail_entropy_loss_coeff * entropy

        gail_loss = p_loss + e_loss + entropy_loss

        # update the discriminator
        self._discriminator.zero_grad()
        gail_loss.backward()
        sync_grads(self._discriminator)  # average grads across workers
        self._discriminator_optim.step()

        info["gail_policy_output"] = p_output.mean().detach().cpu().item()
        info["gail_expert_output"] = e_output.mean().detach().cpu().item()
        info["gail_entropy"] = entropy.detach().cpu().item()
        info["gail_policy_loss"] = p_loss.detach().cpu().item()
        info["gail_expert_loss"] = e_loss.detach().cpu().item()
        info["gail_entropy_loss"] = entropy_loss.detach().cpu().item()

        return mpi_average(info.get_dict(only_scalar=True))
コード例 #16
0
ファイル: main.py プロジェクト: iii45/OOD_Federated_Learning
def main():
    """Parse CLI arguments, build dataset/model/optimizer, and run training.

    Fixes over the original:
    - the default ``--dataset MNIST`` now has a loading branch (previously it
      fell through and crashed with NameError on ``train_dataset``)
    - unknown ``--dataset`` / ``--model`` values fail fast with ValueError
      instead of a later NameError
    - ``--gamma`` help text matches its actual default (0.99)
    - periodic checkpointing moved inside the epoch loop (the original ``if``
      was dedented, so it ran at most once after training, and only when the
      total epoch count happened to be divisible by 5)
    """
    # Training settings
    parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
    parser.add_argument('--batch-size', type=int, default=64, metavar='N',
                        help='input batch size for training (default: 64)')
    parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
                        help='input batch size for testing (default: 1000)')
    parser.add_argument('--epochs', type=int, default=14, metavar='N',
                        help='number of epochs to train (default: 14)')
    parser.add_argument('--lr', type=float, default=1.0, metavar='LR',
                        help='learning rate (default: 1.0)')
    parser.add_argument('--gamma', type=float, default=0.99, metavar='M',
                        help='Learning rate step gamma (default: 0.99)')
    parser.add_argument('--dataset', type=str, default='MNIST',
                        help='dataset to use during the training process')
    parser.add_argument('--model', type=str, default='LeNet',
                        help='model to use during the training process')
    parser.add_argument('--device', type=str, default='cuda',
                        help='device to set, can take the value of: cuda or cuda:x')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA training')
    parser.add_argument('--seed', type=int, default=1, metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument('--log-interval', type=int, default=20, metavar='N',
                        help='how many batches to wait before logging training status')

    #parser.add_argument('--save-model', action='store_true', default=False,
    #                    help='For Saving the current Model')
    args = parser.parse_args()
    use_cuda = not args.no_cuda and torch.cuda.is_available()

    torch.manual_seed(args.seed)

    device = torch.device(args.device if use_cuda else "cpu")

    kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}

    # prepare dataset
    # MNIST/EMNIST share the standard grayscale normalization constants
    mnist_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])
    if args.dataset == "MNIST":
        train_dataset = datasets.MNIST('./data', train=True, download=True,
                                       transform=mnist_transform)
        test_dataset = datasets.MNIST('./data', train=False,
                                      transform=mnist_transform)
    elif args.dataset == "EMNIST":
        train_dataset = datasets.EMNIST('./data', split="digits", train=True,
                                        download=True, transform=mnist_transform)
        test_dataset = datasets.EMNIST('./data', split="digits", train=False,
                                       transform=mnist_transform)
    elif args.dataset == "Cifar10":
        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ])

        transform_test = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ])

        train_dataset = torchvision.datasets.CIFAR10(
            root='./data', train=True, download=True, transform=transform_train)

        test_dataset = torchvision.datasets.CIFAR10(
            root='./data', train=False, download=True, transform=transform_test)
    else:
        raise ValueError("Unsupported dataset: {}".format(args.dataset))

    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=True, **kwargs)
    test_loader = torch.utils.data.DataLoader(
        test_dataset, batch_size=args.test_batch_size, shuffle=True, **kwargs)

    # model + optimizer + LR schedule per architecture
    if args.model == "LeNet":
        model = Net(num_classes=10).to(device)
        optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=0.9, weight_decay=1e-4)
        scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)
    elif args.model in ("vgg9", "vgg11", "vgg13", "vgg16"):
        model = get_vgg_model(args.model).to(device)
        #model = VGG(args.model.upper()).to(device)
        optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=0.9, weight_decay=1e-4)
        scheduler = MultiStepLR(optimizer, milestones=[151, 251], gamma=0.1)
    else:
        raise ValueError("Unsupported model: {}".format(args.model))

    criterion = nn.CrossEntropyLoss()

    for epoch in range(1, args.epochs + 1):
        train(args, model, device, train_loader, optimizer, criterion, epoch)
        test(args, model, device, test_loader, criterion)

        # log the current learning rate, then decay it
        for param_group in optimizer.param_groups:
            logger.info(param_group['lr'])
        scheduler.step()

        # checkpoint every 5 epochs, naming the file after the current epoch
        # (the original check sat outside the loop and used args.epochs)
        if epoch % 5 == 0:
            torch.save(model.state_dict(),
                       "./checkpoint/{}_{}_{}epoch.pt".format(
                           args.dataset, args.model.upper(), epoch))
コード例 #17
0
def train_and_evaluate_kd(model,
                          teacher_model,
                          train_dataloader,
                          val_dataloader,
                          optimizer,
                          loss_fn_kd,
                          metrics,
                          params,
                          model_dir,
                          restore_file=None):
    """Train the student model with knowledge distillation, evaluating every epoch.

    Args:
        model: (torch.nn.Module) the student network being trained
        teacher_model: (torch.nn.Module) the teacher network providing soft targets
        train_dataloader: (DataLoader) training data
        val_dataloader: (DataLoader) validation data
        optimizer: (torch.optim.Optimizer) optimizer for the student's parameters
        loss_fn_kd: distillation loss function
        metrics: (dict) metric name -> metric function
        params: (Params) hyperparameters (num_epochs, model_version, ...)
        model_dir: (string) directory containing config, weights and log
        restore_file: (string) - file to restore (without its extension .pth.tar)
    """
    # reload weights from restore_file if specified
    if restore_file is not None:
        # BUGFIX: use the function's own arguments (the original referenced
        # the global `args`, ignoring the model_dir/restore_file parameters)
        restore_path = os.path.join(model_dir, restore_file + '.pth.tar')
        logging.info("Restoring parameters from {}".format(restore_path))
        utils.load_checkpoint(restore_path, model, optimizer)

    best_val_acc = 0.0

    # Tensorboard logger setup
    # board_logger = utils.Board_Logger(os.path.join(model_dir, 'board_logs'))

    # learning rate schedulers for different models:
    # (None for model versions without a schedule — previously an unmatched
    # version left `scheduler` undefined and crashed with NameError)
    scheduler = None
    if params.model_version == "resnet18_distill":
        scheduler = StepLR(optimizer, step_size=150, gamma=0.1)
    # for cnn models, num_epoch is always < 100, so the 100-step schedule is
    # effectively a no-op by design
    elif params.model_version == "cnn_distill":
        scheduler = StepLR(optimizer, step_size=100, gamma=0.2)

    for epoch in range(params.num_epochs):

        # Run one epoch
        logging.info("Epoch {}/{}".format(epoch + 1, params.num_epochs))

        # compute number of batches in one epoch (one full pass over the training set)
        train_kd(model, teacher_model, optimizer, loss_fn_kd, train_dataloader,
                 metrics, params)

        # step the LR scheduler AFTER the epoch's optimizer updates
        # (required ordering since PyTorch 1.1)
        if scheduler is not None:
            scheduler.step()

        # Evaluate for one epoch on validation set
        val_metrics = evaluate_kd(model, val_dataloader, metrics, params)

        val_acc = val_metrics['accuracy']
        is_best = val_acc >= best_val_acc

        # Save weights
        utils.save_checkpoint(
            {
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'optim_dict': optimizer.state_dict()
            },
            is_best=is_best,
            checkpoint=model_dir)

        # If best_eval, best_save_path
        if is_best:
            logging.info("- Found new best accuracy")
            best_val_acc = val_acc

            # Save best val metrics in a json file in the model directory
            best_json_path = os.path.join(model_dir,
                                          "metrics_val_best_weights.json")
            utils.save_dict_to_json(val_metrics, best_json_path)

        # Save latest val metrics in a json file in the model directory
        last_json_path = os.path.join(model_dir,
                                      "metrics_val_last_weights.json")
        utils.save_dict_to_json(val_metrics, last_json_path)
コード例 #18
0
ファイル: train.py プロジェクト: SoonbeomChoi/BEGANSing
def main():
    """Train a BEGAN-style generator/discriminator pair for singing synthesis.

    Alternates generator (L1 + adversarial) and discriminator updates with
    the BEGAN balancing term ``k``, validates each epoch, and checkpoints
    both networks.

    Fix over the original: the validation loop now runs under
    ``torch.no_grad()`` — it never calls backward, so building autograd
    graphs there only wasted memory and compute.
    """
    config = Config()
    config_basename = os.path.basename(config.config[0])
    print("Configuration file: \'%s\'" % (config_basename))

    checkpoint_path = create_path(config.checkpoint_path,
                                  action=config.checkpoint_path_action)
    config.save(os.path.join(checkpoint_path, config_basename))
    logger = Logger(os.path.join(checkpoint_path, 'log'))

    dataloader = dataprocess.load_train(config)
    # schedulers step per batch, so convert the epoch-based decay interval
    # into a number of optimizer steps
    step_size = config.step_epoch * len(dataloader.train)

    G = Generator(config)
    D = Discriminator(config)
    G, D = set_device((G, D), config.device, config.use_cpu)

    criterionL1 = nn.L1Loss()
    optimizerG = torch.optim.Adam(G.parameters(),
                                  lr=config.learn_rate,
                                  betas=config.betas,
                                  weight_decay=config.weight_decay)
    optimizerD = torch.optim.Adam(D.parameters(),
                                  lr=config.learn_rate,
                                  betas=config.betas,
                                  weight_decay=config.weight_decay)
    schedulerG = StepLR(optimizerG,
                        step_size=step_size,
                        gamma=config.decay_factor)
    schedulerD = StepLR(optimizerD,
                        step_size=step_size,
                        gamma=config.decay_factor)

    # BEGAN state: k balances the G/D losses, M is the convergence measure
    k = 0.0
    M = AverageMeter()
    lossG_train = AverageMeter()
    lossG_valid = AverageMeter()
    lossD_train = AverageMeter()

    print('Training start')
    for epoch in range(config.stop_epoch + 1):
        # Training Loop
        G.train()
        D.train()
        for batch in tqdm(dataloader.train, leave=False, ascii=True):
            x, y_prev, y = set_device(batch, config.device, config.use_cpu)
            y = y.unsqueeze(1)

            # generator update: L1 reconstruction + adversarial loss
            optimizerG.zero_grad()
            y_gen = G(x, y_prev)
            lossL1 = criterionL1(y_gen, y)
            loss_advG = criterionAdv(D, y_gen)
            lossG = lossL1 + loss_advG
            lossG.backward()
            optimizerG.step()
            schedulerG.step()

            # discriminator update (BEGAN objective: real - k * fake)
            optimizerD.zero_grad()
            loss_real = criterionAdv(D, y)
            loss_fake = criterionAdv(D, y_gen.detach())
            loss_advD = loss_real - k * loss_fake
            loss_advD.backward()
            optimizerD.step()
            schedulerD.step()

            # update the balancing term k, clamped to [0, 1]
            diff = torch.mean(config.gamma * loss_real - loss_fake)
            k = k + config.lambda_k * diff.item()
            k = min(max(k, 0), 1)

            # BEGAN convergence measure
            measure = (loss_real + torch.abs(diff)).data
            M.step(measure, y.size(0))

            logger.log_train(lossL1, loss_advG, lossG, loss_real, loss_fake,
                             loss_advD, M.avg, k, lossG_train.steps)
            lossG_train.step(lossG.item(), y.size(0))
            lossD_train.step(loss_advD.item(), y.size(0))

        # Validation Loop — no parameter updates, so skip autograd entirely
        G.eval()
        D.eval()
        with torch.no_grad():
            for batch in tqdm(dataloader.valid, leave=False, ascii=True):
                x, y_prev, y = set_device(batch, config.device, config.use_cpu)
                y = y.unsqueeze(1)

                y_gen = G(x, y_prev)
                lossL1 = criterionL1(y_gen, y)
                loss_advG = criterionAdv(D, y_gen)
                lossG = lossL1 + loss_advG

                logger.log_valid(lossL1, loss_advG, lossG, lossG_valid.steps)
                lossG_valid.step(lossG.item(), y.size(0))

        # all param groups share one LR here; grab it for logging/saving
        for param_group in optimizerG.param_groups:
            learn_rate = param_group['lr']

        print(
            "[Epoch %d/%d] [loss G train: %.5f] [loss G valid: %.5f] [loss D train: %.5f] [lr: %.6f]"
            % (epoch, config.stop_epoch, lossG_train.avg, lossG_valid.avg,
               lossD_train.avg, learn_rate))

        lossG_train.reset()
        lossG_valid.reset()
        lossD_train.reset()

        # always keep a rolling "latest" checkpoint...
        savename = os.path.join(checkpoint_path, 'latest_')
        save_checkpoint(savename + 'G.pt', G, optimizerG, learn_rate,
                        lossG_train.steps)
        save_checkpoint(savename + 'D.pt', D, optimizerD, learn_rate,
                        lossD_train.steps)
        # ...and a permanent snapshot every save_epoch epochs
        if epoch % config.save_epoch == 0:
            savename = os.path.join(checkpoint_path,
                                    'epoch' + str(epoch) + '_')
            save_checkpoint(savename + 'G.pt', G, optimizerG, learn_rate,
                            lossG_train.steps)
            save_checkpoint(savename + 'D.pt', D, optimizerD, learn_rate,
                            lossD_train.steps)

    print('Training finished')
コード例 #19
0
def main():
    """Meta-train a PPO-driven few-shot relation network on mini-ImageNet.

    Outer loop: episodes of MAML-style meta-updates. Inner loop: adapt the
    actor-critic's fast weights on sampled tasks, then compute a meta loss
    on fresh tasks and backprop through the adaptation. Periodically
    evaluates on meta-test tasks and saves the best networks.

    Relies on module-level globals (FEATURE_DIM, CLASS_NUM, EPISODE, tg,
    models, ppoAgent, writer, ...) defined elsewhere in the file.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')    
    
    # * Step 1: init data folders
    print("init data folders")
    
    # * Init character folders for dataset construction
    metatrain_character_folders, metatest_character_folders = tg.mini_imagenet_folders()
    
    # * Step 2: init neural networks
    print("init neural networks")
    
    feature_encoder = models.CNNEncoder()    
    model = models.ActorCritic(FEATURE_DIM, RELATION_DIM, CLASS_NUM)

    #feature_encoder = torch.nn.DataParallel(feature_encoder)
    #actor = torch.nn.DataParallel(actor)
    #critic = torch.nn.DataParallel(critic)
    
    feature_encoder.train()
    model.train()
    
    feature_encoder.apply(models.weights_init)
    model.apply(models.weights_init)
    
    feature_encoder.to(device)
    model.to(device)

    # NOTE(review): cross_entropy is created but never used in this function
    cross_entropy = nn.CrossEntropyLoss()
        
    feature_encoder_optim = torch.optim.Adam(feature_encoder.parameters(), lr=LEARNING_RATE)
    feature_encoder_scheduler = StepLR(feature_encoder_optim, step_size=10000, gamma=0.5)
    
    # the actor-critic learns at a higher rate than the encoder
    model_optim = torch.optim.Adam(model.parameters(), lr=2.5 * LEARNING_RATE)
    model_scheduler = StepLR(model_optim, step_size=10000, gamma=0.5)
    
    agent = ppoAgent.PPOAgent(GAMMA, ENTROPY_WEIGHT, CLASS_NUM, device)
    
    # resume from previously saved networks when present
    if os.path.exists(str("./models/miniimagenet_feature_encoder_" + str(CLASS_NUM) + "way_" + str(SAMPLE_NUM_PER_CLASS) + "shot.pkl")):
        feature_encoder.load_state_dict(torch.load(str("./models/miniimagenet_feature_encoder_" + str(CLASS_NUM) + "way_" + str(SAMPLE_NUM_PER_CLASS) + "shot.pkl")))
        print("load feature encoder success")            
        
    if os.path.exists(str("./models/miniimagenet_actor_network_" + str(CLASS_NUM) + "way_" + str(SAMPLE_NUM_PER_CLASS) + "shot.pkl")):
        model.load_state_dict(torch.load(str("./models/miniimagenet_actor_network_" + str(CLASS_NUM) + "way_" + str(SAMPLE_NUM_PER_CLASS) + "shot.pkl")))
        print("load model network success")
        
    # * Step 3: build graph
    print("Training...")
    loss_list = []
    last_accuracy = 0.0    
    number_of_query_image = 15
    clip_param = 0.1  # PPO clipping range, intended to decay over training
    for episode in range(EPISODE):
        # NOTE(review): `clip_param % CLIP_DECREASE` compares a small float
        # (0.1) against the decay interval, so this condition likely never
        # fires — presumably `episode % CLIP_DECREASE == 0` was intended.
        if clip_param > 0 and clip_param % CLIP_DECREASE == 0:
            clip_param *= 0.5
            
        #print(f"EPISODE : {episode}")
        losses = []        
        for meta_batch in range(META_BATCH_RANGE):
            meta_env_states_list = []
            meta_env_labels_list = []
            # snapshot of the actor-critic parameters used as MAML fast weights
            model_fast_weight = OrderedDict(model.named_parameters())
            for inner_batch in range(INNER_BATCH_RANGE):
                # * Generate environment
                env_states_list = []
                env_labels_list = []
                inner_loss_list = []
                for _ in range(ENV_LENGTH):
                    task = tg.MiniImagenetTask(metatrain_character_folders, CLASS_NUM, SAMPLE_NUM_PER_CLASS, number_of_query_image)
                    sample_dataloader = tg.get_mini_imagenet_data_loader(task, num_per_class=SAMPLE_NUM_PER_CLASS, split="train", shuffle=False)                
                    batch_dataloader = tg.get_mini_imagenet_data_loader(task, num_per_class=5, split="test", shuffle=True)    
                    
                    samples, sample_labels = next(iter(sample_dataloader))
                    samples, sample_labels = samples.to(device), sample_labels.to(device)
                    for batches, batch_labels in batch_dataloader:
                        batches, batch_labels = batches.to(device), batch_labels.to(device)
                        
                        # class prototypes: sum support embeddings per class
                        inner_sample_features = feature_encoder(samples)            
                        inner_sample_features = inner_sample_features.view(CLASS_NUM, SAMPLE_NUM_PER_CLASS, FEATURE_DIM, 19, 19)
                        inner_sample_features = torch.sum(inner_sample_features, 1).squeeze(1)
                        
                        # pair every query embedding with every class prototype
                        inner_batch_features = feature_encoder(batches)
                        inner_sample_feature_ext = inner_sample_features.unsqueeze(0).repeat(5 * CLASS_NUM, 1, 1, 1, 1)
                        inner_batch_features_ext = inner_batch_features.unsqueeze(0).repeat(CLASS_NUM, 1, 1, 1, 1)      
                        inner_batch_features_ext = torch.transpose(inner_batch_features_ext, 0, 1)
                        
                        inner_relation_pairs = torch.cat((inner_sample_feature_ext, inner_batch_features_ext), 2).view(-1, FEATURE_DIM * 2, 19, 19)
                        env_states_list.append(inner_relation_pairs)
                        env_labels_list.append(batch_labels)
                
                # adapt the fast weights with one PPO inner step on this env
                inner_env = ppoAgent.env(env_states_list, env_labels_list)
                agent.train(inner_env, model, loss_list=inner_loss_list)
                inner_loss = torch.stack(inner_loss_list).mean()
                inner_gradients = torch.autograd.grad(inner_loss.mean(), model_fast_weight.values(), create_graph=True, allow_unused=True)
    
                # manual SGD step on the fast weights (create_graph=True keeps
                # the adaptation differentiable for the meta update)
                model_fast_weight = OrderedDict(
                    (name, param - INNER_LR * (0 if grad is None else grad))                    
                    for ((name, param), grad) in zip(model_fast_weight.items(), inner_gradients)                    
                )
            
            # NOTE(review): assigns a plain attribute named `weight`;
            # presumably agent.train/forward reads it as the fast-weight
            # override — confirm against the ActorCritic implementation.
            model.weight = model_fast_weight
            # * Generate env for meta update
            for _ in range(META_ENV_LENGTH):
                # * init dataset
                # * sample_dataloader is to obtain previous samples for compare
                # * batch_dataloader is to batch samples for training
                task = tg.MiniImagenetTask(metatrain_character_folders, CLASS_NUM, SAMPLE_NUM_PER_CLASS, number_of_query_image)
                sample_dataloader = tg.get_mini_imagenet_data_loader(task, num_per_class=SAMPLE_NUM_PER_CLASS, split="train", shuffle=False)               
                batch_dataloader = tg.get_mini_imagenet_data_loader(task, num_per_class=number_of_query_image, split="test", shuffle=True)
                # * num_per_class : number of query images
                
                # * sample datas
                samples, sample_labels = next(iter(sample_dataloader))
                batches, batch_labels = next(iter(batch_dataloader))
                
                samples, sample_labels = samples.to(device), sample_labels.to(device)
                batches, batch_labels = batches.to(device), batch_labels.to(device)
                                
                # * calculates features
                #feature_encoder.weight = feature_fast_weights
                
                sample_features = feature_encoder(samples)
                sample_features = sample_features.view(CLASS_NUM, SAMPLE_NUM_PER_CLASS, FEATURE_DIM, 19, 19)
                sample_features = torch.sum(sample_features, 1).squeeze(1)
                batch_features = feature_encoder(batches)
                
                # * calculate relations
                # * each batch sample link to every samples to calculate relations
                # * to form a 100 * 128 matrix for relation network
                sample_features_ext = sample_features.unsqueeze(0).repeat(number_of_query_image * CLASS_NUM, 1, 1, 1, 1)
                batch_features_ext = batch_features.unsqueeze(0).repeat(CLASS_NUM, 1, 1, 1, 1)
                batch_features_ext = torch.transpose(batch_features_ext, 0, 1)
                relation_pairs = torch.cat((sample_features_ext, batch_features_ext), 2).view(-1, FEATURE_DIM * 2, 19, 19)   
                
                meta_env_states_list.append(relation_pairs)
                meta_env_labels_list.append(batch_labels)
            
            # meta loss: PPO on fresh tasks using the adapted fast weights
            meta_env = ppoAgent.env(meta_env_states_list, meta_env_labels_list)
            agent.train(meta_env, model, loss_list=losses, clip_param=clip_param)
            
        feature_encoder_optim.zero_grad()
        model_optim.zero_grad()     
        
        # NOTE(review): clip_grad_norm_ runs BEFORE backward(), so it clips
        # zeroed/stale gradients — it should follow meta_batch_loss.backward()
        # to have any effect.
        torch.nn.utils.clip_grad_norm_(feature_encoder.parameters(), 0.5)

        meta_batch_loss = torch.stack(losses).mean()
        meta_batch_loss.backward()
                
        feature_encoder_optim.step()
        model_optim.step()

        feature_encoder_scheduler.step()
        model_scheduler.step()
        
        mean_loss = None
        if (episode + 1) % 100 == 0:
            mean_loss = meta_batch_loss.cpu().detach().numpy()
            print(f"episode : {episode+1}, meta_loss : {mean_loss:.4f}")
            loss_list.append(mean_loss)
            
        if (episode + 1) % 500 == 0:
            print("Testing...")
            total_reward = 0
            
            total_test_samples = 0            
            for i in range(TEST_EPISODE):
                # * Generate env
                env_states_list = []
                env_labels_list = []
                # NOTE(review): this rebinds the outer number_of_query_image
                # (15 -> 10), which then leaks into subsequent training
                # episodes — confirm whether that is intended.
                number_of_query_image = 10
                task = tg.MiniImagenetTask(metatest_character_folders, CLASS_NUM, SAMPLE_NUM_PER_CLASS, number_of_query_image)
                sample_dataloader = tg.get_mini_imagenet_data_loader(task, num_per_class=SAMPLE_NUM_PER_CLASS, split="train", shuffle=False)                
                test_dataloader = tg.get_mini_imagenet_data_loader(task, num_per_class=number_of_query_image, split="test", shuffle=True)
                
                sample_images, sample_labels = next(iter(sample_dataloader))
                test_images, test_labels = next(iter(test_dataloader))

                total_test_samples += len(test_labels)

                sample_images, sample_labels = sample_images.to(device), sample_labels.to(device)
                test_images, test_labels = test_images.to(device), test_labels.to(device)
                    
                # * calculate features
                sample_features = feature_encoder(sample_images)
                sample_features = sample_features.view(CLASS_NUM, SAMPLE_NUM_PER_CLASS, FEATURE_DIM, 19, 19)
                sample_features = torch.sum(sample_features, 1).squeeze(1)
                test_features = feature_encoder(test_images)
                
                # * calculate relations
                # * each batch sample link to every samples to calculate relations
                # * to form a 100x128 matrix for relation network
                
                sample_features_ext = sample_features.unsqueeze(0).repeat(number_of_query_image * CLASS_NUM, 1, 1, 1, 1)
                test_features_ext = test_features.unsqueeze(0).repeat(CLASS_NUM, 1, 1, 1, 1)
                test_features_ext = torch.transpose(test_features_ext, 0, 1)

                relation_pairs = torch.cat((sample_features_ext, test_features_ext), 2).view(-1, FEATURE_DIM * 2, 19, 19)
                env_states_list.append(relation_pairs)
                env_labels_list.append(test_labels)
                    
                test_env = ppoAgent.env(env_states_list, env_labels_list)
                rewards = agent.test(test_env, model)
                total_reward += rewards 
                
            # accuracy = correct predictions / total query samples
            test_accuracy = total_reward / (1.0 * total_test_samples)

            print(f'mean loss : {mean_loss}')   
            print("test accuracy : ", test_accuracy)
            
            writer.add_scalar('1.loss', mean_loss, episode + 1)      
            writer.add_scalar('4.test accuracy', test_accuracy, episode + 1)
            
            loss_list = []   
            
            if test_accuracy > last_accuracy:
                # save networks
                torch.save(
                    feature_encoder.state_dict(),
                    str("./models/miniimagenet_feature_encoder_" + str(CLASS_NUM) + "way_" + str(SAMPLE_NUM_PER_CLASS) + "shot.pkl")
                )
                torch.save(
                    model.state_dict(),
                    str("./models/miniimagenet_actor_network_" + str(CLASS_NUM) + "way_" + str(SAMPLE_NUM_PER_CLASS) + "shot.pkl")
                )
                
                print("save networks for episode:", episode)
                last_accuracy = test_accuracy    
コード例 #20
0
def train(env_name,
          arch,
          timesteps=1,
          init_timesteps=0,
          seed=42,
          er_capacity=1,
          epsilon_start=1.0,
          epsilon_stop=0.05,
          epsilon_decay_stop=1,
          batch_size=16,
          target_sync=16,
          lr=1e-3,
          gamma=1.0,
          dueling=False,
          play_steps=1,
          lr_steps=1e4,
          lr_gamma=0.99,
          save_steps=5e4,
          logger=None,
          experiment_name='test'):
    """Main DQN training loop.

    Spawns a subprocess (``play_func``) that interacts with the environment
    and streams experience back through a queue, while this process samples
    batches from the replay buffer and optimizes the Q network.

    Args:
        env_name: environment id forwarded to ``make_env``.
        arch: network architecture spec forwarded to ``QNetwork``.
        timesteps: total environment steps to train for.
        init_timesteps: warm-up size the replay buffer must reach before
            any gradient step is taken.
        seed: RNG seed for the environment.
        er_capacity: replay-buffer capacity.
        epsilon_start / epsilon_stop / epsilon_decay_stop: linear
            epsilon-greedy exploration schedule.
        batch_size: per-play-step batch size (scaled by ``play_steps``).
        target_sync: period (in timesteps) between target-network syncs.
        lr: Adam learning rate.
        gamma: discount factor for the Bellman target.
        dueling: whether to build a dueling Q-network head.
        play_steps: experience items consumed per loop iteration.
        lr_steps: StepLR period (in scheduler steps).
        lr_gamma: StepLR multiplicative decay factor.
        save_steps: period (in timesteps) between checkpoint saves.
        logger: object exposing ``log_kv(key, value, step)``.
        experiment_name: basename of the saved checkpoint file.
    """

    # Cast params which may be expressed in scientific notation (e.g. "5e4")
    def int_scientific(x):
        return int(float(x))

    timesteps, init_timesteps = map(int_scientific,
                                    [timesteps, init_timesteps])
    lr_steps, epsilon_decay_stop = map(int_scientific,
                                       [lr_steps, epsilon_decay_stop])
    er_capacity, target_sync, save_steps = map(
        int_scientific, [er_capacity, target_sync, save_steps])
    lr = float(lr)

    # 'spawn' start method so CUDA tensors can be shared with the subprocess
    mp.set_start_method('spawn')

    # Get PyTorch device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Create the Q network
    _env = make_env(env_name, seed)
    net = QNetwork(_env.observation_space,
                   _env.action_space,
                   arch=arch,
                   dueling=dueling).to(device)
    # Create the target network as a copy of the Q network
    tgt_net = ptan.agent.TargetNet(net)
    # Create buffer and optimizer
    buffer = ptan.experience.ExperienceReplayBuffer(experience_source=None,
                                                    buffer_size=er_capacity)
    optimizer = optim.Adam(net.parameters(), lr=lr)
    # BUGFIX: honor the lr_gamma argument — it was previously ignored in
    # favor of a hard-coded 0.99 (the default value keeps the old behavior).
    scheduler = StepLR(optimizer, step_size=lr_steps, gamma=lr_gamma)

    # Subprocess that plays episodes and pushes (experience, info) tuples
    epsilon_schedule = (epsilon_start, epsilon_stop, epsilon_decay_stop)
    exp_queue = mp.Queue(maxsize=play_steps * 2)
    play_proc = mp.Process(target=play_func,
                           args=(env_name, net, exp_queue, seed, timesteps,
                                 epsilon_schedule, gamma))
    play_proc.start()

    # Main training loop
    timestep = 0
    while play_proc.is_alive() and timestep < timesteps:
        timestep += play_steps
        # Drain the queue and log results if an episode has ended
        for _ in range(play_steps):
            exp, info = exp_queue.get()
            if exp is None:
                # Sentinel: the play subprocess has finished
                play_proc.join()
                break
            # NOTE(review): _add is a private ptan API — confirm it is stable
            # across ptan versions before upgrading.
            buffer._add(exp)
            logger.log_kv('internals/epsilon', info['epsilon'][0],
                          info['epsilon'][1])
            if 'ep_reward' in info.keys():
                logger.log_kv('performance/return', info['ep_reward'],
                              timestep)
                logger.log_kv('performance/length', info['ep_length'],
                              timestep)
                logger.log_kv('performance/speed', info['speed'], timestep)

        # Skip optimization until the replay buffer is warmed up
        if len(buffer) < init_timesteps:
            continue

        scheduler.step()
        logger.log_kv('internals/lr', scheduler.get_lr()[0], timestep)
        # Get a batch from experience replay
        optimizer.zero_grad()
        batch = buffer.sample(batch_size * play_steps)
        # Unpack the batch
        states, actions, rewards, dones, next_states = unpack_batch(batch)
        states_v = torch.tensor(states).to(device)
        next_states_v = torch.tensor(next_states).to(device)
        actions_v = torch.tensor(actions).to(device)
        rewards_v = torch.tensor(rewards).to(device)
        done_mask = torch.ByteTensor(dones).to(device)
        # Q(s, a) for the actions actually taken
        state_action_values = net(states_v).gather(
            1, actions_v.unsqueeze(-1)).squeeze(-1)
        # Bootstrap from the target network; terminal states contribute 0
        next_state_values = tgt_net.target_model(next_states_v).max(1)[0]
        next_state_values[done_mask] = 0.0
        expected_state_action_values = next_state_values.detach(
        ) * gamma + rewards_v
        loss = F.mse_loss(state_action_values, expected_state_action_values)
        logger.log_kv('internals/loss', loss.item(), timestep)
        loss.backward()
        # Clip the gradients to avoid too abrupt changes (this is equivalent to Huber Loss)
        for param in net.parameters():
            param.grad.data.clamp_(-1, 1)
        optimizer.step()

        # Check if the target network needs to be synched
        if timestep % target_sync == 0:
            tgt_net.sync()

        # Check if we need to save a checkpoint
        if timestep % save_steps == 0:
            torch.save(net.get_extended_state(), experiment_name + '.pth')
コード例 #21
0
def train(data_path,
          epoch,
          batch_size,
          hidden_size,
          embedding_dim,
          testing=False):
    """Train a NARM session-based recommender, or evaluate a saved checkpoint.

    Args:
        data_path: path passed to ``load_data``.
        epoch: total number of training epochs.
        batch_size: mini-batch size for all three loaders.
        hidden_size: GRU hidden size of the NARM model.
        embedding_dim: item-embedding dimensionality.
        testing: if True, load ``latest_checkpoint.pth.tar``, report test
            metrics and return without training.

    Returns:
        Tuple ``(mrr_list, recall_list, loss_list)`` with one entry per
        training epoch (``None`` when ``testing`` is True).
    """
    print('Loading data...')
    train, valid, test = load_data(data_path, valid_portion=0.1)
    mrr_list, recall_list, loss_list = [], [], []

    train_data = RecSysDataset(train)
    valid_data = RecSysDataset(valid)
    test_data = RecSysDataset(test)
    train_loader = DataLoader(train_data,
                              batch_size=batch_size,
                              shuffle=True,
                              collate_fn=collate_fn)
    valid_loader = DataLoader(valid_data,
                              batch_size=batch_size,
                              shuffle=False,
                              collate_fn=collate_fn)
    test_loader = DataLoader(test_data,
                             batch_size=batch_size,
                             shuffle=False,
                             collate_fn=collate_fn)

    # NOTE(review): item-vocabulary size is hard-coded for this dataset —
    # confirm it matches the data at data_path.
    n_items = 37484

    model = NARM(n_items,
                 hidden_size=hidden_size,
                 embedding_dim=embedding_dim,
                 batch_size=batch_size).to(device)

    if testing:
        ckpt = torch.load('latest_checkpoint.pth.tar')
        model.load_state_dict(ckpt['state_dict'])
        model.eval()
        recall, mrr = validate(test_loader, model)
        print("Test: Recall@{}: {:.4f}, MRR@{}: {:.4f}".format(
            topk, recall, topk, mrr))
        return

    optimizer = optim.Adam(model.parameters(), 1e-3)
    criterion = nn.CrossEntropyLoss()
    scheduler = StepLR(optimizer, step_size=80, gamma=0.1)

    for e in tqdm(range(epoch)):
        # train for one epoch
        scheduler.step(epoch=e)
        sum_loss = trainForEpoch(train_loader, model, optimizer, e, epoch,
                                 criterion)

        avg_loss = sum_loss / len(train_loader.dataset)
        # BUGFIX: loss_list was returned but never populated
        loss_list.append(avg_loss)

        # BUGFIX: was "(epoch + 1, epoch, ...)" which printed the total
        # epoch count twice instead of the current epoch number
        print('[TRAIN] epoch %d/%d avg loss %.4f' %
              (e + 1, epoch, avg_loss))

        recall, mrr = validate(valid_loader, model)
        recall_list.append(recall)
        mrr_list.append(mrr)
        print(
            'Epoch {} validation: Recall@{}: {:.4f}, MRR@{}: {:.4f} \n'.format(
                e, topk, recall, topk, mrr))

        # store best loss and save a model checkpoint
        ckpt_dict = {
            'epoch': e + 1,
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict()
        }

        torch.save(ckpt_dict, here + f'/checkpoint_{e}.pth.tar')
    return mrr_list, recall_list, loss_list
コード例 #22
0
ファイル: mnist_test_pipe.py プロジェクト: zzszmyf/fairscale
def main():
    """Train and evaluate an MNIST classifier split across two GPUs with Pipe.

    Parses command-line hyperparameters, builds the MNIST loaders, wraps the
    global ``net`` in a two-stage pipeline-parallel ``Pipe``, and runs the
    train/test loop for the requested number of epochs, timing each phase.
    """
    # Training settings
    parser = argparse.ArgumentParser(description="PyTorch MNIST Example")
    parser.add_argument("--batch-size",
                        type=int,
                        default=64,
                        metavar="N",
                        help="input batch size for training (default: 64)")
    parser.add_argument("--test-batch-size",
                        type=int,
                        default=1000,
                        metavar="N",
                        help="input batch size for testing (default: 1000)")
    parser.add_argument("--epochs",
                        type=int,
                        default=14,
                        metavar="N",
                        help="number of epochs to train (default: 14)")
    parser.add_argument("--lr",
                        type=float,
                        default=1.0,
                        metavar="LR",
                        help="learning rate (default: 1.0)")
    parser.add_argument("--gamma",
                        type=float,
                        default=0.7,
                        metavar="M",
                        help="Learning rate step gamma (default: 0.7)")
    parser.add_argument("--dry-run",
                        action="store_true",
                        default=False,
                        help="quickly check a single pass")
    parser.add_argument("--seed",
                        type=int,
                        default=1,
                        metavar="S",
                        help="random seed (default: 1)")
    parser.add_argument(
        "--log-interval",
        type=int,
        default=10,
        metavar="N",
        help="how many batches to wait before logging training status",
    )
    parser.add_argument("--save-model",
                        action="store_true",
                        default=False,
                        help="For Saving the current Model")
    args = parser.parse_args()

    torch.manual_seed(args.seed)

    # Loader settings shared by the train and test DataLoaders
    kwargs = {"batch_size": args.batch_size}
    kwargs.update({"num_workers": 1, "pin_memory": True, "shuffle": True}, )

    transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize((0.1307, ), (0.3081, ))])
    dataset1 = datasets.MNIST("../data",
                              train=True,
                              download=True,
                              transform=transform)
    dataset2 = datasets.MNIST("../data", train=False, transform=transform)
    train_loader = torch.utils.data.DataLoader(dataset1, **kwargs)
    test_loader = torch.utils.data.DataLoader(dataset2, **kwargs)

    # Split the module-level `net` into two balanced pipeline stages,
    # one per device, with micro-batching (chunks=2)
    model = net
    model = Pipe(model, balance=[6, 6], devices=[0, 1], chunks=2)
    # Inputs are fed to the first stage's device
    device = model.devices[0]

    optimizer = optim.Adadelta(model.parameters(), lr=args.lr)

    scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)
    for epoch in range(1, args.epochs + 1):
        tic = time.perf_counter()
        train(args, model, device, train_loader, optimizer, epoch)
        toc = time.perf_counter()
        # BUGFIX: typo "TRANING" -> "TRAINING"
        print(f">>> TRAINING Time {toc - tic:0.4f} seconds")

        tic = time.perf_counter()
        test(model, device, test_loader)
        toc = time.perf_counter()
        print(f">>> TESTING Time {toc - tic:0.4f} seconds")
        scheduler.step()

    if args.save_model:
        torch.save(model.state_dict(), "mnist_cnn.pt")
コード例 #23
0
def trainNet(model,
             train_loader,
             val_loader,
             device,
             static_map,
             start_epoch=0,
             globaliter_=0):
    """Train `model` with SGD and a StepLR schedule, validating after each
    epoch and stopping early when the validation loss stops improving.

    Hyperparameters come from the module-level `config` dict; progress is
    written to a tensorboard-style `Visualizer`, and `EarlyStopping`
    checkpoints the model whenever the validation loss decreases.
    """
    # Echo the hyperparameters of this training run
    print("===== HYPERPARAMETERS =====")
    print("batch_size=", config['dataloader']['batch_size'])
    print("epochs=", config['num_epochs'])
    print('starting from epoch %i' % start_epoch)
    print("learning_rate=", config['optimizer']['lr'])
    print("=" * 30)

    # Optimizer and learning-rate schedule
    sgd = torch.optim.SGD(model.parameters(), **config['optimizer'])
    lr_scheduler = StepLR(sgd,
                          step_size=config['lr_step_size'],
                          gamma=config['lr_gamma'])

    # Resume into an existing run directory, or start a fresh timestamped one
    if config['cont_model_path'] is not None:
        log_dir = config['cont_model_path']
    else:
        log_dir = 'runs/Unet-' + datetime.now().strftime("%Y-%m-%d-%H-%M-%S-") + \
                  '-'.join(config['dataset']['cities'])
    writer = Visualizer(log_dir)

    # Persist the configuration next to the logs for reproducibility
    with open(os.path.join(log_dir, 'config.json'), 'w') as fp:
        json.dump(config, fp)

    training_start_time = time.time()
    globaliter = globaliter_

    # Tracks validation loss and checkpoints the model on improvement
    early_stopping = EarlyStopping(log_dir,
                                   patience=config['patience'],
                                   verbose=True)

    for epoch_idx, epoch in enumerate(range(start_epoch,
                                            config['num_epochs'])):
        writer.write_lr(sgd, epoch)

        # One full pass over the training data
        globaliter = train(model, train_loader, static_map, sgd, device,
                           writer, epoch, globaliter)

        # One full pass over the validation data
        val_loss = validate(model, val_loader, static_map, device, writer,
                            globaliter)

        # Checkpoint on improvement; otherwise count toward patience
        early_stopping(val_loss, model, epoch + 1, globaliter)
        if early_stopping.early_stop:
            print("Early stopping")
            break

        # Debug mode runs a single epoch only
        if config['debug'] and epoch_idx >= 0:
            break

        lr_scheduler.step(epoch)

    print("Training finished, took {:.2f}s".format(time.time() -
                                                   training_start_time))

    # remember to close tensorboard writer
    writer.close()
コード例 #24
0
def main():
    """Random search over learning rate and gamma for an object-detection net.

    Parses command-line options, builds the train/validation loaders, then
    runs 30 trials, each training a freshly-initialized model for
    ``--epochs`` epochs with a randomly sampled (lr, gamma).  The model with
    the lowest validation loss seen across all trials is checkpointed, and
    its learning curves are plotted.
    """
    # Command line arguments for hyperparameters of model/training.
    parser = argparse.ArgumentParser(description='PyTorch Object Detection')
    parser.add_argument('--batch-size',
                        type=int,
                        default=8,
                        metavar='N',
                        help='input batch size for training (default: 8)')
    parser.add_argument('--test-batch-size',
                        type=int,
                        default=1000,
                        metavar='N',
                        help='input batch size for testing (default: 1000)')
    parser.add_argument('--epochs',
                        type=int,
                        default=50,
                        metavar='N',
                        help='number of epochs to train (default: 50)')
    parser.add_argument('--no-cuda',
                        action='store_true',
                        default=False,
                        help='disables CUDA training')
    parser.add_argument('--seed',
                        type=int,
                        default=1,
                        metavar='S',
                        help='random seed (default: 1)')
    args = parser.parse_args()
    # Command to use gpu depending on command line arguments and if there is a cuda device
    use_cuda = not args.no_cuda and torch.cuda.is_available()

    # Random seed to use
    torch.manual_seed(args.seed)

    # Set to either use gpu or cpu
    device = torch.device("cuda" if use_cuda else "cpu")

    # GPU keywords.
    # NOTE(review): kwargs is built but never passed to the DataLoaders
    # below (they hard-code num_workers=0) — confirm whether that is intended.
    kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}

    # Generate our labels for the training and testing data from the original labels
    generate_labels()

    # Load in the training and testing datasets for the x values. Convert to pytorch tensor.
    train_data = DetectionImages(csv_file="../data/labels/train_labels.txt",
                                 root_dir="../data/train",
                                 transform=ToTensor())
    train_loader = DataLoader(train_data,
                              batch_size=args.batch_size,
                              shuffle=True,
                              num_workers=0,
                              drop_last=True)
    test_data = DetectionImages(
        csv_file="../data/labels/validation_labels.txt",
        root_dir="../data/validation",
        transform=ToTensor())
    test_loader = DataLoader(test_data,
                             batch_size=args.test_batch_size,
                             shuffle=False,
                             num_workers=0)

    # Store the lowest test loss found with random search
    # BUGFIX: was a magic 1000, which silently skipped checkpointing if
    # every trial's loss exceeded 1000
    lowest_loss = float('inf')
    # Store the learning curve from the lowest-loss trial
    lowest_test_list = []
    lowest_train_list = []

    # Randomly search over 30 different learning rate and gamma value combinations
    for i in range(30):
        # Boolean value for if this model has the lowest validation loss of any so far
        best_model = False
        # Get random learning rate
        lr = random.uniform(0.0008, 0.002)
        # Get random gamma
        gamma = random.uniform(0.7, 1)
        # Print out the current learning rate and gamma value
        print("##################################################")
        print("Learning Rate: ", lr)
        print("Gamma: ", gamma)
        print("##################################################")

        # BUGFIX: re-initialize the model for every trial. Previously a
        # single model was created once before the loop, so each trial
        # continued training the previous trials' weights instead of
        # evaluating (lr, gamma) from a fresh initialization.
        model = Net().to(device)

        # Specify Adam optimizer
        optimizer = optim.Adam(model.parameters(), lr=lr)

        # Store the training and testing losses over time
        train_losses = []
        test_losses = []
        # Create scheduler.
        scheduler = StepLR(optimizer, step_size=1, gamma=gamma)

        # Train the model for the set number of epochs
        for epoch in range(1, args.epochs + 1):
            # Train and validate for this epoch
            train_losses = train(args, model, device, train_loader, optimizer,
                                 epoch, train_losses)
            test_losses, output_x, output_y = test(args, model, device,
                                                   test_loader, test_losses)
            scheduler.step()

            # If this is the lowest validation loss so far, save model and the training curve. This allows
            # us to recover a model for early stopping
            if lowest_loss > test_losses[epoch - 1]:
                # Print out the current loss and the predictions
                print("New Lowest Loss: ", test_losses[epoch - 1])
                print("Validation X Predictions: ")
                print(output_x)
                print("Validation Y Predictions: ")
                print(output_y)
                # Print out the euclidean distances by converting labels to floating
                # point values corresponding to the center of the window
                print_euclidean_distance(output_x, output_y)
                # Save the model
                torch.save(model.state_dict(), MODEL_NAME)
                # Update the lowest loss so far and the learning curve for lowest loss
                lowest_loss = test_losses[epoch - 1]
                lowest_test_list = test_losses
                lowest_train_list = train_losses
                # Set that this is best model
                best_model = True

        # Save the learning curve if this trial produced the best model so far
        if best_model:
            # Create plot
            figure, axes = plt.subplots()
            # Set axes labels and title
            axes.set(xlabel="Epoch", ylabel="Loss", title="Learning Curve")
            # Plot the learning curves for training and validation loss
            axes.plot(np.array(lowest_train_list), label="train_loss", c="b")
            axes.plot(np.array(lowest_test_list),
                      label="validation_loss",
                      c="r")
            plt.legend()
            # Save the figure
            plt.savefig('curve18.png')
            plt.close()

    # After Random Search is finished:
    # Display the learning curves for the best result from random search
    figure, axes = plt.subplots()
    axes.set(xlabel="Epoch", ylabel="Loss", title="Learning Curve")
    axes.plot(np.array(lowest_train_list), label="train_loss", c="b")
    axes.plot(np.array(lowest_test_list), label="validation_loss", c="r")
    plt.legend()
    plt.show()
    plt.close()
コード例 #25
0
def main():
    print("init data folders")
    encoder_lv1 = models.Encoder()
    encoder_lv2 = models.Encoder()    
    encoder_lv3 = models.Encoder()
    encoder_lv4 = models.Encoder()

    decoder_lv1 = models.Decoder()
    decoder_lv2 = models.Decoder()    
    decoder_lv3 = models.Decoder()
    decoder_lv4 = models.Decoder()
    
    encoder_lv1.apply(weight_init).cuda(GPU)    
    encoder_lv2.apply(weight_init).cuda(GPU)
    encoder_lv3.apply(weight_init).cuda(GPU)
    encoder_lv4.apply(weight_init).cuda(GPU)

    decoder_lv1.apply(weight_init).cuda(GPU)    
    decoder_lv2.apply(weight_init).cuda(GPU)
    decoder_lv3.apply(weight_init).cuda(GPU)
    decoder_lv4.apply(weight_init).cuda(GPU)

    encoder_lv1_optim = RAdam(encoder_lv1.parameters(),lr=LEARNING_RATE)
    encoder_lv1_scheduler = StepLR(encoder_lv1_optim,step_size=1000,gamma=0.1)
    encoder_lv2_optim = RAdam(encoder_lv2.parameters(),lr=LEARNING_RATE)
    encoder_lv2_scheduler = StepLR(encoder_lv2_optim,step_size=1000,gamma=0.1)
    encoder_lv3_optim = RAdam(encoder_lv3.parameters(),lr=LEARNING_RATE)
    encoder_lv3_scheduler = StepLR(encoder_lv3_optim,step_size=1000,gamma=0.1)
    encoder_lv4_optim = RAdam(encoder_lv4.parameters(),lr=LEARNING_RATE)
    encoder_lv4_scheduler = StepLR(encoder_lv4_optim,step_size=1000,gamma=0.1)

    decoder_lv1_optim = RAdam(decoder_lv1.parameters(),lr=LEARNING_RATE)
    decoder_lv1_scheduler = StepLR(decoder_lv1_optim,step_size=1000,gamma=0.1)
    decoder_lv2_optim = RAdam(decoder_lv2.parameters(),lr=LEARNING_RATE)
    decoder_lv2_scheduler = StepLR(decoder_lv2_optim,step_size=1000,gamma=0.1)
    decoder_lv3_optim = RAdam(decoder_lv3.parameters(),lr=LEARNING_RATE)
    decoder_lv3_scheduler = StepLR(decoder_lv3_optim,step_size=1000,gamma=0.1)
    decoder_lv4_optim = RAdam(decoder_lv4.parameters(),lr=LEARNING_RATE)
    decoder_lv4_scheduler = StepLR(decoder_lv4_optim,step_size=1000,gamma=0.1)

    if os.path.exists(str('./checkpoints/' + METHOD + "/encoder_lv1.pkl")):
        encoder_lv1.load_state_dict(torch.load(str('./checkpoints/' + METHOD + "/encoder_lv1.pkl")))
        print("load encoder_lv1 success")
    if os.path.exists(str('./checkpoints/' + METHOD + "/encoder_lv2.pkl")):
        encoder_lv2.load_state_dict(torch.load(str('./checkpoints/' + METHOD + "/encoder_lv2.pkl")))
        print("load encoder_lv2 success")
    if os.path.exists(str('./checkpoints/' + METHOD + "/encoder_lv3.pkl")):
        encoder_lv3.load_state_dict(torch.load(str('./checkpoints/' + METHOD + "/encoder_lv3.pkl")))
        print("load encoder_lv3 success")
    if os.path.exists(str('./checkpoints/' + METHOD + "/encoder_lv4.pkl")):
        encoder_lv4.load_state_dict(torch.load(str('./checkpoints/' + METHOD + "/encoder_lv4.pkl")))
        print("load encoder_lv4 success")

    # for param in decoder_lv4.layer24.parameters():
    #     param.requires_grad = False
    # for param in encoder_lv3.parameters():
    #     param.requires_grad = False
    #     # print("检查部分参数是否固定......")
    #     print(encoder_lv3.layer1.bias.requires_grad)
    # for param in decoder_lv3.parameters():
    #     param.requires_grad = False
    # for param in encoder_lv2.parameters():
    #     param.requires_grad = False
    #     # print("检查部分参数是否固定......")
    #     print(encoder_lv2.layer1.bias.requires_grad)
    # for param in decoder_lv2.parameters():
    #     param.requires_grad = False

    if os.path.exists(str('./checkpoints/' + METHOD + "/decoder_lv1.pkl")):
        decoder_lv1.load_state_dict(torch.load(str('./checkpoints/' + METHOD + "/decoder_lv1.pkl")))
        print("load encoder_lv1 success")
    if os.path.exists(str('./checkpoints/' + METHOD + "/decoder_lv2.pkl")):
        decoder_lv2.load_state_dict(torch.load(str('./checkpoints/' + METHOD + "/decoder_lv2.pkl")))
        print("load decoder_lv2 success")
    if os.path.exists(str('./checkpoints/' + METHOD + "/decoder_lv3.pkl")):
        decoder_lv3.load_state_dict(torch.load(str('./checkpoints/' + METHOD + "/decoder_lv3.pkl")))
        print("load decoder_lv3 success")
    if os.path.exists(str('./checkpoints/' + METHOD + "/decoder_lv4.pkl")):
        decoder_lv4.load_state_dict(torch.load(str('./checkpoints/' + METHOD + "/decoder_lv4.pkl")))
        print("load decoder_lv4 success")
    
    if os.path.exists('./checkpoints/' + METHOD) == False:
        os.system('mkdir ./checkpoints/' + METHOD)


    for epoch in range(args.start_epoch, EPOCHS):
        epoch += 1
        encoder_lv1_scheduler.step(epoch)
        encoder_lv2_scheduler.step(epoch)
        encoder_lv3_scheduler.step(epoch)
        encoder_lv4_scheduler.step(epoch)

        decoder_lv1_scheduler.step(epoch)
        decoder_lv2_scheduler.step(epoch)
        decoder_lv3_scheduler.step(epoch)
        decoder_lv4_scheduler.step(epoch)      
        
        print("Training...")
        print('lr:',encoder_lv1_scheduler.get_lr())
        
        train_dataset = GoProDataset(
            blur_image_files = './datas/GoPro/train_blur_file.txt',
            sharp_image_files = './datas/GoPro/train_sharp_file.txt',
            root_dir = './datas/GoPro',
            crop = True,
            crop_size = IMAGE_SIZE,
            transform = transforms.Compose([
                transforms.ToTensor()
                ]))

        train_dataloader = DataLoader(train_dataset, batch_size = BATCH_SIZE, shuffle=True,num_workers=8,pin_memory=True)
        start = 0
        
        for iteration, images in enumerate(train_dataloader):            
            mse = nn.MSELoss().cuda(GPU)            
            
            gt = Variable(images['sharp_image'] - 0.5).cuda(GPU)
            H = gt.size(2)          
            W = gt.size(3)
            
            images_lv1 = Variable(images['blur_image'] - 0.5).cuda(GPU)
            images_lv2_1 = images_lv1[:,:,0:int(H/2),:]
            images_lv2_2 = images_lv1[:,:,int(H/2):H,:]
            images_lv3_1 = images_lv2_1[:,:,:,0:int(W/2)]
            images_lv3_2 = images_lv2_1[:,:,:,int(W/2):W]
            images_lv3_3 = images_lv2_2[:,:,:,0:int(W/2)]
            images_lv3_4 = images_lv2_2[:,:,:,int(W/2):W]
            images_lv4_1 = images_lv3_1[:,:,0:int(H/4),:]
            images_lv4_2 = images_lv3_1[:,:,int(H/4):int(H/2),:]
            images_lv4_3 = images_lv3_2[:,:,0:int(H/4),:]
            images_lv4_4 = images_lv3_2[:,:,int(H/4):int(H/2),:]
            images_lv4_5 = images_lv3_3[:,:,0:int(H/4),:]
            images_lv4_6 = images_lv3_3[:,:,int(H/4):int(H/2),:]
            images_lv4_7 = images_lv3_4[:,:,0:int(H/4),:]
            images_lv4_8 = images_lv3_4[:,:,int(H/4):int(H/2),:]

            feature_lv4_1 = encoder_lv4(images_lv4_1)
            feature_lv4_2 = encoder_lv4(images_lv4_2)
            feature_lv4_3 = encoder_lv4(images_lv4_3)
            feature_lv4_4 = encoder_lv4(images_lv4_4)
            feature_lv4_5 = encoder_lv4(images_lv4_5)
            feature_lv4_6 = encoder_lv4(images_lv4_6)
            feature_lv4_7 = encoder_lv4(images_lv4_7)
            feature_lv4_8 = encoder_lv4(images_lv4_8)
            feature_lv4_top_left = torch.cat((feature_lv4_1, feature_lv4_2), 2)
            feature_lv4_top_right = torch.cat((feature_lv4_3, feature_lv4_4), 2)
            feature_lv4_bot_left = torch.cat((feature_lv4_5, feature_lv4_6), 2)
            feature_lv4_bot_right = torch.cat((feature_lv4_7, feature_lv4_8), 2)
            feature_lv4_top = torch.cat((feature_lv4_top_left, feature_lv4_top_right), 3)
            feature_lv4_bot = torch.cat((feature_lv4_bot_left, feature_lv4_bot_right), 3)
            feature_lv4 = torch.cat((feature_lv4_top, feature_lv4_bot), 2)
            residual_lv4_top_left = decoder_lv4(feature_lv4_top_left)
            residual_lv4_top_right = decoder_lv4(feature_lv4_top_right)
            residual_lv4_bot_left = decoder_lv4(feature_lv4_bot_left)
            residual_lv4_bot_right = decoder_lv4(feature_lv4_bot_right)

            feature_lv3_1 = encoder_lv3(images_lv3_1 + residual_lv4_top_left)
            feature_lv3_2 = encoder_lv3(images_lv3_2 + residual_lv4_top_right)
            feature_lv3_3 = encoder_lv3(images_lv3_3 + residual_lv4_bot_left)
            feature_lv3_4 = encoder_lv3(images_lv3_4 + residual_lv4_bot_right)
            feature_lv3_top = torch.cat((feature_lv3_1, feature_lv3_2), 3) + feature_lv4_top
            feature_lv3_bot = torch.cat((feature_lv3_3, feature_lv3_4), 3) + feature_lv4_bot
            feature_lv3 = torch.cat((feature_lv3_top, feature_lv3_bot), 2)
            residual_lv3_top = decoder_lv3(feature_lv3_top)
            residual_lv3_bot = decoder_lv3(feature_lv3_bot)

            feature_lv2_1 = encoder_lv2(images_lv2_1 + residual_lv3_top)
            feature_lv2_2 = encoder_lv2(images_lv2_2 + residual_lv3_bot)
            feature_lv2 = torch.cat((feature_lv2_1, feature_lv2_2), 2) + feature_lv3
            residual_lv2 = decoder_lv2(feature_lv2)

            feature_lv1 = encoder_lv1(images_lv1 + residual_lv2) + feature_lv2
            deblur_image = decoder_lv1(feature_lv1)

            loss = mse(deblur_image, gt)
            
            encoder_lv1.zero_grad()
            encoder_lv2.zero_grad()
            encoder_lv3.zero_grad()
            encoder_lv4.zero_grad()

            decoder_lv1.zero_grad()
            decoder_lv2.zero_grad()
            decoder_lv3.zero_grad()
            decoder_lv4.zero_grad()
            
            loss.backward()

            encoder_lv1_optim.step()
            encoder_lv2_optim.step()
            encoder_lv3_optim.step()
            encoder_lv4_optim.step()

            decoder_lv1_optim.step()
            decoder_lv2_optim.step()
            decoder_lv3_optim.step()
            decoder_lv4_optim.step()
            
            if (iteration+1)%10 == 0:
                stop = time.time()
                print("epoch:", epoch, "iteration:", iteration+1, "loss:%.4f"%loss.item(), 'time:%.4f'%(stop-start))
                start = time.time()
                
        if (epoch)%100==0:
            if os.path.exists('./checkpoints/' + METHOD + '/epoch' + str(epoch)) == False:
            	os.system('mkdir ./checkpoints/' + METHOD + '/epoch' + str(epoch))

            print("Testing...")
            test_dataset = GoProDataset(
                blur_image_files = './datas/GoPro/test_blur_file.txt',
                sharp_image_files = './datas/GoPro/test_sharp_file.txt',
                root_dir = './datas/GoPro',
                transform = transforms.Compose([
                    transforms.ToTensor()
                ]))
            test_dataloader = DataLoader(test_dataset, batch_size = 1, shuffle=False,num_workers=8,pin_memory=True)
            test_time = 0       
            for iteration, images in enumerate(test_dataloader):
                with torch.no_grad():
                    start = time.time()                 
                    images_lv1 = Variable(images['blur_image'] - 0.5).cuda(GPU)  
                    H = images_lv1.size(2)
                    W = images_lv1.size(3)          
                    images_lv2_1 = images_lv1[:,:,0:int(H/2),:]
                    images_lv2_2 = images_lv1[:,:,int(H/2):H,:]
                    images_lv3_1 = images_lv2_1[:,:,:,0:int(W/2)]
                    images_lv3_2 = images_lv2_1[:,:,:,int(W/2):W]
                    images_lv3_3 = images_lv2_2[:,:,:,0:int(W/2)]
                    images_lv3_4 = images_lv2_2[:,:,:,int(W/2):W]
                    images_lv4_1 = images_lv3_1[:,:,0:int(H/4),:]
                    images_lv4_2 = images_lv3_1[:,:,int(H/4):int(H/2),:]
                    images_lv4_3 = images_lv3_2[:,:,0:int(H/4),:]
                    images_lv4_4 = images_lv3_2[:,:,int(H/4):int(H/2),:]
                    images_lv4_5 = images_lv3_3[:,:,0:int(H/4),:]
                    images_lv4_6 = images_lv3_3[:,:,int(H/4):int(H/2),:]
                    images_lv4_7 = images_lv3_4[:,:,0:int(H/4),:]
                    images_lv4_8 = images_lv3_4[:,:,int(H/4):int(H/2),:]
                    
                    feature_lv4_1 = encoder_lv4(images_lv4_1)
                    feature_lv4_2 = encoder_lv4(images_lv4_2)
                    feature_lv4_3 = encoder_lv4(images_lv4_3)
                    feature_lv4_4 = encoder_lv4(images_lv4_4)
                    feature_lv4_5 = encoder_lv4(images_lv4_5)
                    feature_lv4_6 = encoder_lv4(images_lv4_6)
                    feature_lv4_7 = encoder_lv4(images_lv4_7)
                    feature_lv4_8 = encoder_lv4(images_lv4_8)
                    
                    feature_lv4_top_left = torch.cat((feature_lv4_1, feature_lv4_2), 2)
                    feature_lv4_top_right = torch.cat((feature_lv4_3, feature_lv4_4), 2)
                    feature_lv4_bot_left = torch.cat((feature_lv4_5, feature_lv4_6), 2)
                    feature_lv4_bot_right = torch.cat((feature_lv4_7, feature_lv4_8), 2)
                    
                    feature_lv4_top = torch.cat((feature_lv4_top_left, feature_lv4_top_right), 3)
                    feature_lv4_bot = torch.cat((feature_lv4_bot_left, feature_lv4_bot_right), 3)
                    
                    residual_lv4_top_left = decoder_lv4(feature_lv4_top_left)
                    residual_lv4_top_right = decoder_lv4(feature_lv4_top_right)
                    residual_lv4_bot_left = decoder_lv4(feature_lv4_bot_left)
                    residual_lv4_bot_right = decoder_lv4(feature_lv4_bot_right)
            
                    feature_lv3_1 = encoder_lv3(images_lv3_1 + residual_lv4_top_left)
                    feature_lv3_2 = encoder_lv3(images_lv3_2 + residual_lv4_top_right)
                    feature_lv3_3 = encoder_lv3(images_lv3_3 + residual_lv4_bot_left)
                    feature_lv3_4 = encoder_lv3(images_lv3_4 + residual_lv4_bot_right)
                    
                    feature_lv3_top = torch.cat((feature_lv3_1, feature_lv3_2), 3) + feature_lv4_top
                    feature_lv3_bot = torch.cat((feature_lv3_3, feature_lv3_4), 3) + feature_lv4_bot
                    residual_lv3_top = decoder_lv3(feature_lv3_top)
                    residual_lv3_bot = decoder_lv3(feature_lv3_bot)
                
                    feature_lv2_1 = encoder_lv2(images_lv2_1 + residual_lv3_top)
                    feature_lv2_2 = encoder_lv2(images_lv2_2 + residual_lv3_bot)
                    feature_lv2 = torch.cat((feature_lv2_1, feature_lv2_2), 2) + torch.cat((feature_lv3_top, feature_lv3_bot), 2)
                    residual_lv2 = decoder_lv2(feature_lv2)
            
                    feature_lv1 = encoder_lv1(images_lv1 + residual_lv2) + feature_lv2
                    deblur_image = decoder_lv1(feature_lv1)
                    stop = time.time()
                    test_time += stop - start
                    print('RunTime:%.4f'%(stop-start), '  Average Runtime:%.4f'%(test_time/(iteration+1)))
                    save_deblur_images(deblur_image.data + 0.5, iteration, epoch)
                    #
                    torch.save(encoder_lv1.state_dict(),
                               str('./checkpoints/' + METHOD + '/epoch' + str(epoch) + "/encoder_lv1.pkl"))
                    torch.save(encoder_lv2.state_dict(),
                               str('./checkpoints/' + METHOD + '/epoch' + str(epoch) + "/encoder_lv2.pkl"))
                    torch.save(encoder_lv3.state_dict(),
                               str('./checkpoints/' + METHOD + '/epoch' + str(epoch) + "/encoder_lv3.pkl"))
                    torch.save(encoder_lv4.state_dict(),
                               str('./checkpoints/' + METHOD + '/epoch' + str(epoch) + "/encoder_lv4.pkl"))
                    torch.save(decoder_lv1.state_dict(),
                               str('./checkpoints/' + METHOD + '/epoch' + str(epoch) + "/decoder_lv1.pkl"))
                    torch.save(decoder_lv2.state_dict(),
                               str('./checkpoints/' + METHOD + '/epoch' + str(epoch) + "/decoder_lv2.pkl"))
                    torch.save(decoder_lv3.state_dict(),
                               str('./checkpoints/' + METHOD + '/epoch' + str(epoch) + "/decoder_lv3.pkl"))
                    torch.save(decoder_lv4.state_dict(),
                               str('./checkpoints/' + METHOD + '/epoch' + str(epoch) + "/decoder_lv4.pkl"))
                
        torch.save(encoder_lv1.state_dict(),str('./checkpoints/' + METHOD + "/encoder_lv1.pkl"))
        torch.save(encoder_lv2.state_dict(),str('./checkpoints/' + METHOD + "/encoder_lv2.pkl"))
        torch.save(encoder_lv3.state_dict(),str('./checkpoints/' + METHOD + "/encoder_lv3.pkl"))
        torch.save(encoder_lv4.state_dict(),str('./checkpoints/' + METHOD + "/encoder_lv4.pkl"))
        torch.save(decoder_lv1.state_dict(),str('./checkpoints/' + METHOD + "/decoder_lv1.pkl"))
        torch.save(decoder_lv2.state_dict(),str('./checkpoints/' + METHOD + "/decoder_lv2.pkl"))
        torch.save(decoder_lv3.state_dict(),str('./checkpoints/' + METHOD + "/decoder_lv3.pkl"))
        torch.save(decoder_lv4.state_dict(),str('./checkpoints/' + METHOD + "/decoder_lv4.pkl"))
コード例 #26
0
def main():
    """Meta-train a relation network with episodic few-shot learning.

    Each episode samples a support ("sample") and query ("batch") task from
    the meta-train folders, trains the feature encoder and relation network
    with an MSE loss against one-hot labels, and every 5 episodes evaluates
    accuracy over TEST_EPISODE meta-test tasks.  All accuracies are written
    out via write_results at the end.

    Relies on module-level globals: args, tg, TEST_PERSON, CLASS_NUM,
    SPT_NUM_PER_CLASS, LEARNING_RATE, EPISODE, TEST_EPISODE, device,
    CNNEncoder, RelationNetwork, weights_init, write_results.
    """
    write_results(str(args))
    # Step 1: init data folders
    print("init data folders")
    # init character folders for dataset construction
    metatrain_character_folders, metatest_character_folders = tg.acw_folders(TEST_PERSON)

    # Step 2: init neural networks
    print("init neural networks")

    feature_encoder = CNNEncoder()
    relation_network = RelationNetwork()

    feature_encoder.apply(weights_init)
    relation_network.apply(weights_init)

    feature_encoder.to(device)
    relation_network.to(device)

    feature_encoder_optim = torch.optim.Adam(feature_encoder.parameters(), lr=LEARNING_RATE)
    feature_encoder_scheduler = StepLR(feature_encoder_optim, step_size=100000, gamma=0.5)
    relation_network_optim = torch.optim.Adam(relation_network.parameters(), lr=LEARNING_RATE)
    relation_network_scheduler = StepLR(relation_network_optim, step_size=100000, gamma=0.5)

    print("Training...")

    test_accuracies = []

    for episode in range(EPISODE):

        task = tg.ACWTask(metatrain_character_folders, CLASS_NUM, SPT_NUM_PER_CLASS)
        sample_dataloader = tg.get_data_loader(task, num_per_class=SPT_NUM_PER_CLASS, split="train", shuffle=False)
        batch_dataloader = tg.get_data_loader(task, num_per_class=task.per_class_num, split="test", shuffle=True)

        # sample datas — next(iter(...)) replaces the Python-2-only
        # loader.__iter__().next(), which raises AttributeError on Python 3
        samples, sample_labels = next(iter(sample_dataloader))
        batches, batch_labels = next(iter(batch_dataloader))

        # calculate features
        sample_features = feature_encoder(Variable(samples).to(device))
        batch_features = feature_encoder(Variable(batches).to(device))

        # calculate relations:
        # pair every query example with every support example by repeating
        # and transposing, then concatenate along the feature dimension
        sample_features_ext = sample_features.unsqueeze(0).repeat(task.per_class_num * CLASS_NUM, 1, 1, 1, 1)
        batch_features_ext = batch_features.unsqueeze(0).repeat(SPT_NUM_PER_CLASS * CLASS_NUM, 1, 1, 1, 1)
        batch_features_ext = torch.transpose(batch_features_ext, 0, 1)

        # assumes encoder output is 64 channels of shape 1 x 89 — TODO confirm
        relation_pairs = torch.cat((sample_features_ext, batch_features_ext), 2).view(-1, 64 * 2, 1, 89)
        relations = relation_network(relation_pairs).view(-1, CLASS_NUM)

        mse = nn.MSELoss().to(device)
        one_hot_labels = Variable(
            torch.zeros(task.per_class_num * CLASS_NUM, CLASS_NUM).scatter_(1, batch_labels.view(-1, 1), 1)).to(device)
        loss = mse(relations, one_hot_labels)

        # training
        feature_encoder.zero_grad()
        relation_network.zero_grad()

        loss.backward()

        # clip_grad_norm was renamed clip_grad_norm_ (the trailing-underscore
        # in-place form); the old name is removed in current PyTorch
        torch.nn.utils.clip_grad_norm_(feature_encoder.parameters(), 0.5)
        torch.nn.utils.clip_grad_norm_(relation_network.parameters(), 0.5)

        feature_encoder_optim.step()
        relation_network_optim.step()

        # Step the LR schedulers once per episode AFTER the optimizer update
        # (required order since PyTorch 1.1; passing the episode index to
        # step() is deprecated).
        feature_encoder_scheduler.step()
        relation_network_scheduler.step()

        print("episode:", episode + 1, "loss", loss.item())

        if (episode + 1) % 5 == 0:

            # test
            print("Testing...")
            total_rewards = 0

            for i in range(TEST_EPISODE):
                task = tg.ACWTask(metatest_character_folders, CLASS_NUM, SPT_NUM_PER_CLASS)
                sample_dataloader = tg.get_data_loader(task, num_per_class=SPT_NUM_PER_CLASS, split="train",
                                                       shuffle=False)
                test_dataloader = tg.get_data_loader(task, num_per_class=task.per_class_num, split="test",
                                                     shuffle=True)

                sample_images, sample_labels = next(iter(sample_dataloader))
                test_images, test_labels = next(iter(test_dataloader))

                # calculate features
                sample_features = feature_encoder(Variable(sample_images).to(device))
                test_features = feature_encoder(Variable(test_images).to(device))

                # calculate relations (same pairing scheme as in training)
                sample_features_ext = sample_features.unsqueeze(0).repeat(task.per_class_num * CLASS_NUM, 1, 1, 1, 1)
                test_features_ext = test_features.unsqueeze(0).repeat(SPT_NUM_PER_CLASS * CLASS_NUM, 1, 1, 1, 1)
                test_features_ext = torch.transpose(test_features_ext, 0, 1)

                relation_pairs = torch.cat((sample_features_ext, test_features_ext), 2).view(-1, 64 * 2, 1, 89)
                relations = relation_network(relation_pairs).view(-1, CLASS_NUM)

                _, predict_labels = torch.max(relations.data, 1)

                # NOTE(review): only the first CLASS_NUM predictions are
                # scored even when the query set is larger; the divisor below
                # matches, but confirm this is the intended accuracy metric.
                rewards = [1 if predict_labels[j] == test_labels[j] else 0 for j in range(CLASS_NUM)]

                total_rewards += np.sum(rewards)

            test_accuracy = total_rewards / 1.0 / CLASS_NUM / TEST_EPISODE
            test_accuracies.append(str(test_accuracy))
            print("test accuracy:", test_accuracy)
    write_results(','.join(test_accuracies))
コード例 #27
0
def main():
    """Train the SR super-resolution model.

    Parses CLI arguments, builds train/validation/test loaders, trains for
    ``--epochs`` epochs with Adam + StepLR, validates every 5 epochs, and
    runs the held-out test evaluation after the final epoch.

    Relies on module-level names: SR_Dataset_4, ToTensor, SR, train,
    validation, test, tb (summary writer).
    """
    # Training settings
    parser = argparse.ArgumentParser(description='Deep neural network for Super-resolution of multitemporal '
                                                 'Remote Sensing Images')
    parser.add_argument('--batch-size', type=int, default=16, metavar='N',
                        help='input batch size for training (default: 16)')
    parser.add_argument('--test-batch-size', type=int, default=16, metavar='N',
                        help='input batch size for testing (default: 16)')
    parser.add_argument('--epochs', type=int, default=100, metavar='N',
                        help='number of epochs to train (default: 100)')
    parser.add_argument('--lr', type=float, default=1e-4, metavar='LR',
                        help='learning rate (default: 1e-4)')
    parser.add_argument('--gamma', type=float, default=0.5, metavar='M',
                        help='Learning rate step gamma (default: 0.5)')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA training')
    parser.add_argument('--seed', type=int, default=1, metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument('--log-interval', type=int, default=1, metavar='N',
                        help='how many batches to wait before logging training status')

    parser.add_argument('--save-model', action='store_true', default=True,
                        help='For Saving the current Model')

    args = parser.parse_args()

    torch.manual_seed(args.seed)

    # Honor --no-cuda (previously ignored: device was hard-coded to "cuda",
    # crashing on CPU-only machines and making the flag a no-op).
    use_cuda = not args.no_cuda and torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")

    # pin_memory / extra workers only pay off when feeding a GPU
    kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}

    train_dataset = SR_Dataset_4(csv_file='PATH',
                              root_dir='PATH', transform=ToTensor(), stand=True, norm=False)

    test_dataset = SR_Dataset_4(csv_file='PATH',
                              root_dir='PATH', transform=ToTensor(), stand=True, norm=False)

    validation_dataset = SR_Dataset_4(csv_file='PATH',
                              root_dir='PATH', transform=ToTensor(), stand=True, norm=False)


    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=args.batch_size, shuffle=True, drop_last=True, **kwargs)

    test_loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=args.test_batch_size, shuffle=True, drop_last=True, **kwargs)

    validation_loader = torch.utils.data.DataLoader(
        validation_dataset,
        batch_size=args.test_batch_size, shuffle=True, drop_last=True, **kwargs)

    model = SR().to(device)

    model = model.type(dst_type=torch.float32)

    optimizer = optim.Adam(model.parameters(), lr=args.lr)

    # Decay the learning rate by args.gamma every 20 epochs.
    scheduler = StepLR(optimizer, step_size=20, gamma=args.gamma)


    for epoch in range(1, args.epochs + 1):

        train(args, model, device, train_loader, optimizer, epoch)

        if epoch % 5 == 0:
            validation(args, model, device, validation_loader, epoch)

        # Evaluate on the held-out test set after the final epoch.  This was
        # hard-coded to `epoch == 100`, which silently skipped the test run
        # whenever --epochs was anything other than 100.
        if epoch == args.epochs:
            test(args, model, device, test_loader, epoch)

        scheduler.step()

    tb.close()
コード例 #28
0
# Build the WIDER-face training set from the annotation file and wrap it in a
# DataLoader.  arrange_data / WIDER / opt are defined earlier in the file.
path, bbxs = arrange_data(anno_path)
wider = WIDER(opt.train_path,
              path,
              bbxs,
              high_resolution=(opt.hr_height, opt.hr_width))
dataloader = DataLoader(wider,
                        batch_size=opt.batch_size,
                        shuffle=True,
                        num_workers=8)

# ----------
#  Training
# ----------

# Resume-aware epoch range: opt.epoch is the starting epoch (0 for a fresh run).
for epoch in range(opt.epoch, opt.n_epochs):
    # NOTE(review): generator/discriminator schedulers step at the TOP of each
    # epoch, before any optimizer.step(); since PyTorch 1.1 the recommended
    # order is optimizer first — confirm the intended LR schedule before
    # changing this.
    scheduler_G.step()
    scheduler_D.step()
    loss_D_ = 0
    loss_G_ = 0
    for i, imgs in enumerate(dataloader):

        # Configure model input: move each face/background crop to the device
        lr_face = imgs["lr_face"].to(device)
        hr_face = imgs["hr_face"].to(device)
        lr_background = imgs["lr_background"].to(device)
        hr_background = imgs["hr_background"].to(device)

        # ------------------
        #  Train Generators
        # ------------------
コード例 #29
0
def main():
    """Train/evaluate a Kuzushiji character classifier on segmentation crops.

    One entry point with three modes selected by flags:
      * default       — train with ignite, evaluating after every epoch;
      * --test-only   — run evaluation once from a --resume checkpoint;
      * --submission  — produce a competition submission CSV from a checkpoint.

    Relies on project helpers defined elsewhere in the file/package
    (load_train_valid_df, build_model, create_supervised_trainer, metrics
    classes, etc.); only the wiring lives here.
    """
    parser = argparse.ArgumentParser()
    arg = parser.add_argument

    arg('clf_gt', help='segmentation predictions')
    # Dataset params
    arg('--test-height', type=int, default=2528)
    arg('--crop-height', type=int, default=768)
    arg('--crop-width', type=int, default=512)
    arg('--scale-aug', type=float, default=0.3)
    arg('--color-hue-aug', type=int, default=7)
    arg('--color-sat-aug', type=int, default=30)
    arg('--color-val-aug', type=int, default=30)
    arg('--n-tta', type=int, default=1)
    arg('--pseudolabels', nargs='+',
        help='path to pseudolabels to be added to train')
    arg('--pseudolabels-oversample', type=int, default=1)
    arg('--test-book', help='use only this book for testing and pseudolabels')
    arg('--fold', type=int, default=0)
    arg('--n-folds', type=int, default=5)
    arg('--train-limit', type=int)
    arg('--test-limit', type=int)
    # Model params
    arg('--base', default='resnet50')
    arg('--use-sequences', type=int, default=0)
    arg('--head-dropout', type=float, default=0.5)
    arg('--frozen-start', type=int)
    arg('--head', type=str, default='Head')
    # Training params
    arg('--device', default='cuda', help='device')
    arg('--opt-level', help='pass 01 to use fp16 training with apex')
    arg('--benchmark', type=int)
    arg('--batch-size', default=10, type=int)
    arg('--max-targets', type=int)
    arg('--workers', default=8, type=int,
        help='number of data loading workers')
    arg('--lr', default=14e-3, type=float, help='initial learning rate')
    arg('--wd', default=1e-4, type=float, help='weight decay')
    arg('--optimizer', default='sgd')
    arg('--accumulation-steps', type=int, default=1)
    arg('--epochs', default=50, type=int, help='number of total epochs to run')
    arg('--repeat-train', type=int, default=6)
    arg('--drop-lr-epoch', default=0, type=int,
        help='epoch at which to drop lr')
    arg('--cosine', type=int, default=1, help='cosine lr schedule')
    # Misc. params
    arg('--output-dir', help='path where to save')
    arg('--resume', help='resume from checkpoint')
    arg('--test-only', help='Only test the model', action='store_true')
    arg('--submission', help='Create submission', action='store_true')
    arg('--detailed-postfix', default='', help='postfix of detailed file name')
    arg('--print-model', default=1, type=int)
    arg('--dump-features', default=0, type=int)  # for knn, unused
    args = parser.parse_args()
    # --test-only and --submission are mutually exclusive run modes
    if args.test_only and args.submission:
        parser.error('pass one of --test-only and --submission')
    print(args)

    # Training parameters are persisted once per fresh run for reproducibility
    output_dir = Path(args.output_dir) if args.output_dir else None
    if output_dir:
        output_dir.mkdir(parents=True, exist_ok=True)
        if not args.resume:
            (output_dir / 'params.json').write_text(
                json.dumps(vars(args), indent=4))

    print('Loading data')
    df_train_gt, df_valid_gt = load_train_valid_df(args.fold, args.n_folds)
    df_clf_gt = load_train_df(args.clf_gt)[['labels', 'image_id']]
    if args.submission:
        # Submission mode: classify everything; pages with no detected
        # characters are emitted as empty rows later in make_submission().
        df_valid = df_train = df_clf_gt
        empty_index = df_valid['labels'] == ''
        empty_pages = df_valid[empty_index]['image_id'].values
        df_valid = df_valid[~empty_index]
    else:
        # Restrict classifier ground truth to the fold's train/valid splits.
        df_train, df_valid = [
            df_clf_gt[df_clf_gt['image_id'].isin(set(df['image_id']))]
            for df in [df_train_gt, df_valid_gt]]
        df_valid = df_valid[df_valid['labels'] != '']
    if args.pseudolabels:
        # Optionally augment train with (oversampled) pseudolabel CSVs.
        df_ps = pd.concat(
            [pd.read_csv(p)[df_train.columns] for p in args.pseudolabels])
        if args.test_book:
            df_ps = df_ps[df_ps['image_id'].apply(
                lambda x: get_book_id(x) == args.test_book)]
        df_train = (
            pd.concat([df_train] + [df_ps] * args.pseudolabels_oversample)
            .reset_index(drop=True))
    if args.test_book:
        df_valid = df_valid[df_valid['image_id'].apply(
            lambda x: get_book_id(x) == args.test_book)]
    if args.train_limit:
        df_train = df_train.sample(n=args.train_limit, random_state=42)
    if args.test_limit:
        df_valid = df_valid.sample(n=args.test_limit, random_state=42)
    gt_by_image_id = {item.image_id: item for item in df_valid_gt.itertuples()}
    print(f'{len(df_train):,} in train, {len(df_valid):,} in valid')
    classes = get_encoded_classes()

    def _get_transforms(*, train: bool):
        """Build the transform list; multiple test heights implement TTA."""
        if not train and args.n_tta > 1:
            test_heights = [
                args.test_height * (1 + s)
                for s in np.linspace(0, args.scale_aug, args.n_tta)]
            print('TTA test heights:', list(map(int, test_heights)))
        else:
            test_heights = [args.test_height]
        return [
            get_transform(
                train=train,
                test_height=test_height,
                crop_width=args.crop_width,
                crop_height=args.crop_height,
                scale_aug=args.scale_aug,
                color_hue_aug=args.color_hue_aug,
                color_sat_aug=args.color_sat_aug,
                color_val_aug=args.color_val_aug,
            ) for test_height in test_heights]

    def make_test_data_loader(df):
        """Deterministic batch-size-1 loader used for evaluation/dumping."""
        return DataLoader(
            Dataset(
                df=df,
                transforms=_get_transforms(train=False),
                resample_empty=False,
                classes=classes,
            ),
            batch_size=1,
            collate_fn=collate_fn,
            num_workers=args.workers,
        )

    data_loader_test = make_test_data_loader(df_valid)
    if args.dump_features:  # unused
        df_train = df_train[df_train['labels'] != '']
        data_loader_train = make_test_data_loader(df_train)
    else:
        # Train set is repeated --repeat-train times so one "epoch" covers
        # several passes over the (augmented) data.
        data_loader_train = DataLoader(
            Dataset(
                df=pd.concat([df_train] * args.repeat_train),
                transforms=_get_transforms(train=True),
                resample_empty=True,
                classes=classes,
            ),
            num_workers=args.workers,
            shuffle=True,
            collate_fn=partial(collate_fn, max_targets=args.max_targets),
            batch_size=args.batch_size,
        )

    print('Creating model')
    fp16 = bool(args.opt_level)
    model: nn.Module = build_model(
        base=args.base,
        head=args.head,
        frozen_start=args.frozen_start,
        fp16=fp16,
        n_classes=len(classes),
        head_dropout=args.head_dropout,
        use_sequences=bool(args.use_sequences),
    )
    if args.print_model:
        print(model)
    device = torch.device(args.device)
    model.to(device)
    if args.benchmark:
        torch.backends.cudnn.benchmark = True

    parameters = model.parameters()
    if args.optimizer == 'adam':
        optimizer = optim.Adam(
            parameters, lr=args.lr, weight_decay=args.wd)
    elif args.optimizer == 'sgd':
        optimizer = optim.SGD(
            parameters, lr=args.lr, weight_decay=args.wd, momentum=0.9)
    else:
        # NOTE(review): "optimzier" typo lives in the runtime error string.
        parser.error(f'Unexpected optimzier {args.optimizer}')

    if fp16:
        from apex import amp
        model, optimizer = amp.initialize(
            model, optimizer, opt_level=args.opt_level)
    loss = nn.CrossEntropyLoss()
    step = epoch = 0
    best_f1 = 0

    if args.resume:
        # Full checkpoints carry optimizer/step/epoch state; plain weight
        # files (e.g. model_best.pth) contain only the state dict.
        state = torch.load(args.resume, map_location='cpu')
        if 'optimizer' in state:
            optimizer.load_state_dict(state['optimizer'])
            model.load_state_dict(state['model'])
            step = state['step']
            epoch = state['epoch']
            best_f1 = state['best_f1']
        else:
            model.load_state_dict(state)
        del state

    @contextmanager
    def no_benchmark():
        # cudnn.benchmark hurts with the variable-sized eval batches, so it is
        # disabled around evaluation and restored if requested.
        torch.backends.cudnn.benchmark = False
        yield
        if args.benchmark:
            torch.backends.cudnn.benchmark = True

    if args.dump_features and not args.submission:  # unused
        if not output_dir:
            parser.error('set --output-dir with --dump-features')
        # We also dump test features below
        feature_evaluator = create_supervised_evaluator(
            model,
            device=device,
            prepare_batch=_prepare_batch,
            metrics={'features': GetFeatures(n_tta=args.n_tta)},
        )
        with no_benchmark():
            run_with_pbar(feature_evaluator, data_loader_train,
                          desc='train features')
        torch.save(feature_evaluator.state.metrics['features'],
                   output_dir / 'train_features.pth')

    def get_y_pred_y(output):
        """Adapt model output to the (y_pred, y) shape ignite metrics expect."""
        y_pred, y = output
        return get_output(y_pred), get_labels(y)

    metrics = {
        'accuracy': Accuracy(output_transform=get_y_pred_y),
        'loss': Loss(loss, output_transform=get_y_pred_y),
        'predictions': GetPredictions(
            n_tta=args.n_tta, classes=classes),
        'detailed': GetDetailedPrediction(
            n_tta=args.n_tta, classes=classes),
    }
    if args.dump_features:
        metrics['features'] = GetFeatures(n_tta=args.n_tta)
    evaluator = create_supervised_evaluator(
        model,
        device=device,
        prepare_batch=_prepare_batch,
        metrics=metrics)

    def evaluate():
        """Run the evaluator and score predicted centers against GT boxes."""
        with no_benchmark():
            run_with_pbar(evaluator, data_loader_test, desc='evaluate')
        metrics = {
            'valid_loss': evaluator.state.metrics['loss'],
            'accuracy': evaluator.state.metrics['accuracy'],
        }
        scores = []
        for prediction, meta in evaluator.state.metrics['predictions']:
            item = gt_by_image_id[meta['image_id']]
            target_boxes, target_labels = get_target_boxes_labels(item)
            target_boxes = torch.from_numpy(target_boxes)
            pred_centers = np.array([p['center'] for p in prediction])
            pred_labels = [p['cls'] for p in prediction]
            scores.append(
                dict(score_boxes(
                    truth_boxes=from_coco(target_boxes).numpy(),
                    truth_label=target_labels,
                    preds_center=pred_centers,
                    preds_label=np.array(pred_labels),
                ), image_id=item.image_id))
        metrics.update(get_metrics(scores))
        if output_dir:
            pd.DataFrame(evaluator.state.metrics['detailed']).to_csv(
                output_dir / f'detailed{args.detailed_postfix}.csv.gz',
                index=None)
        if args.dump_features:
            f_name = 'test' if args.submission else 'valid'
            torch.save(evaluator.state.metrics['features'],
                       output_dir / f'{f_name}_features.pth')
        return metrics

    def make_submission():
        """Write submission + detailed CSVs; pages with no labels get empty rows."""
        with no_benchmark():
            run_with_pbar(evaluator, data_loader_test, desc='evaluate')
        submission = []
        for prediction, meta in tqdm.tqdm(
                evaluator.state.metrics['predictions']):
            submission.append(submission_item(
                meta['image_id'], prediction))
        submission.extend(submission_item(image_id, [])
                          for image_id in empty_pages)
        pd.DataFrame(submission).to_csv(
            output_dir / f'submission_{output_dir.name}.csv.gz',
            index=None)
        pd.DataFrame(evaluator.state.metrics['detailed']).to_csv(
            output_dir / f'test_detailed{args.detailed_postfix}.csv.gz',
            index=None)
        if args.dump_features:
            torch.save(evaluator.state.metrics['features'],
                       output_dir / 'test_features.pth')

    # Evaluation-only modes exit before any trainer is constructed.
    if args.test_only or args.submission:
        if not args.resume:
            parser.error('please pass --resume when running with --test-only '
                         'or --submission')
        if args.test_only:
            print_metrics(evaluate())
        elif args.submission:
            if not output_dir:
                parser.error('--output-dir required with --submission')
            make_submission()
        return

    trainer = create_supervised_trainer(
        model, optimizer,
        loss_fn=lambda y_pred, y: loss(get_output(y_pred), get_labels(y)),
        device=device,
        prepare_batch=_prepare_batch,
        accumulation_steps=args.accumulation_steps,
        fp16=fp16,
    )

    # When resuming, only the remaining epochs are run.
    epochs_left = args.epochs - epoch
    epochs_pbar = tqdm.trange(epochs_left)
    epoch_pbar = tqdm.trange(len(data_loader_train))
    train_losses = deque(maxlen=20)

    @trainer.on(Events.ITERATION_COMPLETED)
    def log_training_loss(_):
        # Smooth the reported loss over the last 20 iterations and log a
        # json-log-plots event every 20 steps.
        nonlocal step
        train_losses.append(trainer.state.output)
        smoothed_loss = np.mean(train_losses)
        epoch_pbar.set_postfix(loss=f'{smoothed_loss:.4f}')
        epoch_pbar.update(1)
        step += 1
        if step % 20 == 0 and output_dir:
            json_log_plots.write_event(
                output_dir, step=step * args.batch_size,
                loss=smoothed_loss)

    @trainer.on(Events.EPOCH_COMPLETED)
    def checkpoint(_):
        # Full resume checkpoint (model + optimizer + counters) every epoch.
        if output_dir:
            torch.save({
                'model': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'step': step,
                'epoch': epoch,
                'best_f1': best_f1,
            }, output_dir / 'checkpoint.pth')

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_validation_results(_):
        # Evaluate each epoch; keep the best-F1 weights separately.
        nonlocal best_f1
        metrics = evaluate()
        if output_dir:
            json_log_plots.write_event(
                output_dir, step=step * args.batch_size, **metrics)
        if metrics['f1'] > best_f1:
            best_f1 = metrics['f1']
            if output_dir:
                torch.save(model.state_dict(), output_dir / 'model_best.pth')
        epochs_pbar.set_postfix({
            k: format_value(v) for k, v in metrics.items()})

    @trainer.on(Events.EPOCH_COMPLETED)
    def update_pbars_on_epoch_completion(_):
        nonlocal epoch
        epochs_pbar.update(1)
        epoch_pbar.reset()
        epoch += 1

    # LR schedule: step drop and cosine are mutually exclusive.
    scheduler = None
    if args.drop_lr_epoch and args.cosine:
        parser.error('Choose only one schedule')
    if args.drop_lr_epoch:
        scheduler = StepLR(optimizer, step_size=args.drop_lr_epoch, gamma=0.1)
    if args.cosine:
        scheduler = CosineAnnealingLR(optimizer, epochs_left)
    if scheduler is not None:
        trainer.on(Events.EPOCH_COMPLETED)(lambda _: scheduler.step())

    trainer.run(data_loader_train, max_epochs=epochs_left)
コード例 #30
0
# train all layers
# Two parameter groups: the freshly-initialized last_linear head trains at a
# 10x larger LR (1e-3) than the pretrained backbone (1e-4); both share the
# optimizer-level weight decay of 0.01.
other_parameters = [param for name, param in model.module.named_parameters() if 'last_linear' not in name]
optimizer = AdamW(
    [
        {"params": model.module.last_linear.parameters(), "lr": 1e-3},
        {"params": other_parameters},
    ], 
    lr=1e-4, weight_decay = 0.01)    
    

# Sentinel "worst" loss; NOTE(review): never updated in this visible excerpt.
best_loss_val = 100 
criterion = CosineMarginCrossEntropy().cuda()
exp_lr_scheduler = StepLR(optimizer, step_size=18, gamma=0.1)
for epoch in range(num_epochs):
    # NOTE(review): the scheduler steps before train(); since PyTorch 1.1 the
    # recommended order is after optimizer.step() — confirm before changing.
    exp_lr_scheduler.step()
   
    
    # train for one epoch
    sample_weights = train(train_loader, model, criterion, optimizer, epoch, sample_weights, neptune_ctx)

    # evaluate on validation set
    acc1, acc5, loss_val = validate(val_loader, model, criterion)
    neptune_ctx.channel_send('val-acc1', acc1)
    neptune_ctx.channel_send('val-acc5', acc5)
    neptune_ctx.channel_send('val-loss', loss_val)
    # get_lr() is deprecated in newer PyTorch (use get_last_lr()); kept as-is.
    neptune_ctx.channel_send('lr', float(exp_lr_scheduler.get_lr()[0]))
    
    logger.info(f'Epoch: {epoch} Acc1: {acc1} Acc5: {acc5} Val-Loss: {loss_val}')
    
    # remember best acc@1 and save checkpoint
コード例 #31
0
class UserAVG(User):
    """Federated-averaging (FedAvg) client.

    Selects the loss and optimizer from the model kind given in
    ``model[1]`` ("linear", "cnn", or anything else), trains locally for
    ``local_epochs`` epochs, and exposes the local-minus-server parameter
    delta for server-side aggregation.
    """

    def __init__(self, numeric_id, train_data, test_data, model, batch_size,
                 learning_rate, L, local_epochs):
        super().__init__(numeric_id, train_data, test_data, model[0],
                         batch_size, learning_rate, L, local_epochs)

        # Loss matches the model family: MSE for regression, cross-entropy
        # for the CNN classifier, NLL otherwise (log-probability outputs
        # presumed — TODO confirm against the model definitions).
        if model[1] == "linear":
            self.loss = nn.MSELoss()
        elif model[1] == "cnn":
            self.loss = nn.CrossEntropyLoss()
        else:
            self.loss = nn.NLLLoss()

        if model[1] == "cnn":
            # Per-parameter groups: biases get twice the learning rate of
            # the weights; weight decay L applies to all groups.
            layers = [
                self.model.conv1, self.model.conv2, self.model.conv3,
                self.model.fc1, self.model.fc2
            ]
            self.optimizer = torch.optim.SGD([{
                'params': layer.weight
            } for layer in layers] + [{
                'params': layer.bias,
                'lr': 2 * self.learning_rate
            } for layer in layers],
                                             lr=self.learning_rate,
                                             weight_decay=L)
            self.scheduler = StepLR(self.optimizer, step_size=8, gamma=0.1)
            self.lr_drop_rate = 0.95
        else:
            self.optimizer = torch.optim.SGD(self.model.parameters(),
                                             lr=self.learning_rate)
            # BUG FIX: train() checks `self.scheduler`, but this branch
            # previously never set the attribute, so every non-CNN client
            # raised AttributeError on its first train() call.
            self.scheduler = None

        self.csi = None

    def set_grads(self, new_grads):
        """Overwrite model parameters in place from a Parameter or a list."""
        if isinstance(new_grads, nn.Parameter):
            for model_grad, new_grad in zip(self.model.parameters(),
                                            new_grads):
                model_grad.data = new_grad.data
        elif isinstance(new_grads, list):
            for idx, model_grad in enumerate(self.model.parameters()):
                model_grad.data = new_grads[idx]

    def train(self):
        """Run local training and refresh the local-vs-server delta.

        Returns the loss of the last processed batch, or None if the
        training loader yielded no batches.
        """
        self.model.train()
        loss = None  # defined even when the loader is empty
        for epoch in range(1, self.local_epochs + 1):
            self.model.train()
            for batch_idx, (X, y) in enumerate(self.trainloader):
                self.optimizer.zero_grad()
                output = self.model(X)
                loss = self.loss(output, y)
                loss.backward()
                self.optimizer.step()
            # Only the CNN configuration has an LR schedule.
            if self.scheduler is not None:
                self.scheduler.step()

        # get model difference (local update relative to the server copy)
        for local, server, delta in zip(self.model.parameters(),
                                        self.server_model, self.delta_model):
            delta.data = local.data.detach() - server.data.detach()

        return loss

    def get_params_norm(self):
        """Return the L2 norm of the flattened parameter delta as a float."""
        params = []
        for delta in self.delta_model:
            params.append(torch.flatten(delta.data))
        # return torch.linalg.norm(torch.cat(params), 2)
        return float(torch.norm(torch.cat(params)))
コード例 #32
0
ファイル: train.py プロジェクト: tyhu/PyAI
def main():
    """Train the image/text embedding branches on COCO with a triplet loss.

    NOTE(review): this is Python 2 code (print statements, the `file()`
    builtin, integer `/`, `loss.data[0]`); it will not run under Python 3
    without porting.
    """
    #batch_size = 500
    batch_size = 128
    # Two-branch embedding model: one branch per modality.
    img_net = ImgBranch2()
    text_net = TextBranch2()
    #img_net, text_net = torch.load('img_net.pt'), torch.load('text_net.pt')
    # Triplet loss with margin 0.1, applied in both directions below.
    tri_loss = TripletLoss(0.1)
    # Both branches are optimized jointly with a single SGD optimizer.
    params = list(img_net.parameters())+list(text_net.parameters())
    opt = optim.SGD(params, lr=0.1, momentum=0.9, weight_decay=0.00005)
    # Decay the learning rate 10x every 10 epochs.
    scheduler = StepLR(opt, step_size=10, gamma=0.1)
    img_net.cuda()
    text_net.cuda()

    # Training pairs: precomputed VGG-19 image features + HGLMM text features.
    idlst = [l.strip() for l in file('train_ids.txt')]
    img_dir = '/home/datasets/coco/raw/vgg19_feat/train/'
    text_dir = '/home/datasets/coco/raw/annotation_text/hglmm_pca_npy/train/'

    dataset = COCOImgTextFeatPairDataset(idlst,img_dir,text_dir)
    dataloader = data.DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4)

    ### Test set ###
    tiidlst = [l.strip() for l in file('test_ids.txt')]
    img_dir = '/home/datasets/coco/raw/vgg19_feat/val/'
    text_dir = '/home/datasets/coco/raw/annotation_text/hglmm_pca_npy/val/'
    img_feat_dataset = COCOImgFeatDataset(tiidlst, img_dir)
    text_feat_dataset = COCOTextFeatDataset(tiidlst,text_dir)

    ### train subset
    # A slice of the training ids, evaluated alongside the test set below
    # (presumably to monitor overfitting — TODO confirm).
    triidlst = [l.strip() for l in file('train_val_ids.txt')]
    img_dir = '/home/datasets/coco/raw/vgg19_feat/train/'
    text_dir = '/home/datasets/coco/raw/annotation_text/hglmm_pca_npy/train/'
    img_sub_dataset = COCOImgFeatDataset(triidlst, img_dir)
    text_sub_dataset = COCOTextFeatDataset(triidlst,text_dir)


    total_loss = 0
    for eidx in range(50):
        #total_loss = 0
        print 'epoch',eidx
        for i, batch in enumerate(dataloader):
            #anc_i, pos_t, neg_t, anc_t, pos_i, neg_i = hard_negative_sample(batch,img_net,text_net)
            anc_i, pos_t, neg_t, anc_t, pos_i, neg_i = random_sample(batch)
            # Sub-batching hook for memory; with sub_batch_num == 1 the whole
            # batch is processed at once (Python 2 `/` is integer division).
            sub_batch_num = 1
            sub_batch_size = anc_i.shape[0]/sub_batch_num
            for j in range(sub_batch_num):
                start, end = j*sub_batch_size, (j+1)*sub_batch_size
                anc_i_sub, pos_t_sub, neg_t_sub, neg_i_sub = anc_i[start:end], pos_t[start:end], neg_t[start:end], neg_i[start:end]

                #anc_i_sub = img_net(Variable(torch.Tensor(anc_i_sub).cuda()))
                #pos_t_sub = text_net(Variable(torch.Tensor(pos_t_sub).cuda()))
                #neg_t_sub = text_net(Variable(torch.Tensor(neg_t_sub).cuda()))
                #neg_i_sub = img_net(Variable(torch.Tensor(neg_i_sub).cuda()))
                # Embed anchors/positives/negatives for both modalities.
                anc_i_sub = img_net(Variable(anc_i_sub.cuda()))
                pos_t_sub = text_net(Variable(pos_t_sub.cuda()))
                neg_t_sub = text_net(Variable(neg_t_sub.cuda()))
                neg_i_sub = img_net(Variable(neg_i_sub.cuda()))
                # With random sampling, the positive pair is shared between
                # the image->text and text->image triplets.
                anc_t_sub = pos_t_sub
                pos_i_sub = anc_i_sub
                # Bidirectional triplet loss; the text->image direction is
                # weighted 2x.
                loss1 = tri_loss(anc_i_sub, pos_t_sub, neg_t_sub)
                loss2 = tri_loss(anc_t_sub, pos_i_sub, neg_i_sub)
                loss = loss1+2*loss2

                opt.zero_grad()
                loss.backward()
                opt.step()
                total_loss+=loss.data[0]

            # Periodic evaluation on the test set and the training subset;
            # the accumulated loss counter is reset after reporting.
            if i%200==0:
                print 'epoch',eidx,'batch', i
                test(tiidlst,img_feat_dataset, text_feat_dataset,img_net,text_net)
                print 'train sub'
                test(triidlst,img_sub_dataset, text_sub_dataset,img_net,text_net)
                print 'train loss:',total_loss
                total_loss = 0
            #break
        #break
        scheduler.step()

        ###### TEST ######
        # test(tiidlst,img_feat_dataset, text_feat_dataset,img_net,text_net)

        # Persist the full module objects (not just state_dicts) every epoch.
        torch.save(img_net,'img_net.pt')
        torch.save(text_net,'text_net.pt')
コード例 #33
0
def train(args, out, net_name):
    """Train a Bilinear_Res segmentation model, validating every 10th epoch.

    Per epoch: trains over the full training set, writes the mean training
    loss to ``{out}/{net_name}_epoch_{epoch}.txt``, and on every 10th epoch
    additionally evaluates on the validation set, appends the score/IoU
    report to the same file and saves a checkpoint.

    Args:
        args: namespace providing `dataset`, `batch_size`, `lr_rate`,
            `w_decay`, `lr_decay` and `epochs`.
        out: (str) output directory for log files and checkpoints.
        net_name: (str) model name used in log/checkpoint filenames.
    """
    data_path = get_data_path(args.dataset)
    data_loader = get_loader(args.dataset)
    loader = data_loader(data_path, is_transform=True)
    n_classes = loader.n_classes
    print(n_classes)

    trainloader = data.DataLoader(loader,
                                  batch_size=args.batch_size,
                                  shuffle=True)
    another_loader = data_loader(data_path, split='val', is_transform=True)

    valloader = data.DataLoader(another_loader,
                                batch_size=args.batch_size,
                                shuffle=True)

    # Class weights for cross_entropy2d: rarer classes get larger weights.
    # NOTE(review): `hist` (class-frequency histogram) is assumed to be a
    # module-level array — confirm it is defined before train() is called.
    norm_hist = hist / np.max(hist)
    weight = 1 / np.log(norm_hist + 1.02)
    weight[-1] = 0  # zero out the last (void/unlabelled) class
    weight = torch.FloatTensor(weight)
    model = Bilinear_Res(n_classes)

    if torch.cuda.is_available():
        model.cuda(0)
        weight = weight.cuda(0)
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=args.lr_rate,
                                 weight_decay=args.w_decay)
    # optimizer = torch.optim.RMSprop(model.parameters(), lr=args.lr_rate)
    scheduler = StepLR(optimizer, step_size=100, gamma=args.lr_decay)

    for epoch in tqdm.tqdm(range(args.epochs),
                           desc='Training',
                           ncols=80,
                           leave=False):
        model.train()
        loss_list = []
        # `with` guarantees the log file is closed even if an epoch raises;
        # renamed from `file`, which shadowed the builtin.
        with open(out + '/{}_epoch_{}.txt'.format(net_name, epoch),
                  'w') as log_file:
            for i, (images, labels) in tqdm.tqdm(enumerate(trainloader),
                                                 total=len(trainloader),
                                                 desc='Iteration',
                                                 ncols=80,
                                                 leave=False):
                if torch.cuda.is_available():
                    images = Variable(images.cuda(0))
                    labels = Variable(labels.cuda(0))
                else:
                    images = Variable(images)
                    labels = Variable(labels)
                optimizer.zero_grad()
                outputs = model(images)
                loss = cross_entropy2d(outputs, labels, weight=weight)
                # .item() replaces the long-removed `loss.data[0]` indexing
                # of a 0-dim tensor (errors on PyTorch >= 0.5).
                loss_list.append(loss.item())
                loss.backward()
                optimizer.step()

            # Step the LR schedule AFTER this epoch's optimizer updates:
            # since PyTorch 1.1.0 stepping the scheduler first skips the
            # initial learning rate.
            scheduler.step()

            print(np.average(loss_list))
            log_file.write(str(np.average(loss_list)) + '\n')
            model.eval()
            gts, preds = [], []
            if (epoch % 10 == 0):
                for i, (images, labels) in tqdm.tqdm(enumerate(valloader),
                                                     total=len(valloader),
                                                     desc='Valid Iteration',
                                                     ncols=80,
                                                     leave=False):
                    if torch.cuda.is_available():
                        images = Variable(images.cuda(0))
                        labels = Variable(labels.cuda(0))
                    else:
                        images = Variable(images)
                        labels = Variable(labels)
                    outputs = model(images)
                    pred = outputs.data.max(1)[1].cpu().numpy()
                    gt = labels.data.cpu().numpy()
                    for gt_, pred_ in zip(gt, pred):
                        gts.append(gt_)
                        preds.append(pred_)
                score, class_iou = scores(gts, preds, n_class=n_classes)
                for k, v in score.items():
                    log_file.write('{} {}\n'.format(k, v))

                for i in range(n_classes):
                    log_file.write('{} {}\n'.format(i, class_iou[i]))
                torch.save(
                    model.state_dict(),
                    out + "/{}_{}_{}.pkl".format(net_name, args.dataset, epoch))