Esempio n. 1
0
def _build_optimizer(model, args):
    """Create the optimizer selected by ``args.optimizer`` for ``model``.

    Only Adam is currently supported. With ``args.split_opt_params`` the five
    parameter groups returned by ``model.params_to_optimize(split=True, ...)``
    get different settings: the two bias groups use twice the base learning
    rate and no weight decay, the two weight groups use the base learning rate
    and ``args.weight_decay``, and the last ("others") group uses a learning
    rate of 0. Otherwise all parameters form a single group with the base
    settings.

    Args:
        model: the model whose ``params_to_optimize`` supplies the parameters.
        args: the parsed command-line namespace.

    Returns:
        The constructed ``torch.optim.Adam`` instance.

    Raises:
        ValueError: if ``args.optimizer`` is not ``'adam'``.
    """
    if args.optimizer != 'adam':
        # explicit raise instead of ``assert`` so the check survives ``python -O``
        raise ValueError('Invalid optimizer: %s' % args.optimizer)

    # NOTE(review): much larger than Adam's default eps (1e-8); presumably deliberate - confirm
    eps = 0.1
    if args.split_opt_params:
        new_biases, new_weights, biases, weights, others = model.params_to_optimize(split=True, excl_batch_norm=args.excl_bn)
        return torch.optim.Adam([
            {'params': new_biases, 'lr': args.lr * 2, 'weight_decay': 0.0, 'eps': eps},
            {'params': new_weights, 'lr': args.lr, 'weight_decay': args.weight_decay, 'eps': eps},
            {'params': biases, 'lr': args.lr * 2, 'weight_decay': 0.0, 'eps': eps},
            {'params': weights, 'lr': args.lr, 'weight_decay': args.weight_decay, 'eps': eps},
            {'params': others, 'lr': 0, 'weight_decay': 0, 'eps': eps},
        ])

    params = model.params_to_optimize(excl_batch_norm=args.excl_bn)
    return torch.optim.Adam(params, lr=args.lr, weight_decay=args.weight_decay, eps=eps)


def _checkpoint_state(next_epoch, best_epoch, best_loss, arch, model, optimizer):
    """Bundle the training state in the dict layout that ``main`` reads back on resume.

    ``next_epoch`` is stored under ``'epoch'`` and becomes ``args.start_epoch``
    when training is resumed from the saved file.
    """
    return {
        'epoch': next_epoch,
        'best_epoch': best_epoch,
        'best_loss': best_loss,
        'arch': arch,
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict(),
    }


def main():
    """Train (or, with ``args.evaluate``, only evaluate) a PoseNet model.

    The configuration is parsed with the module-level ``parser`` and stored in
    the global ``args``. ``dataset_train.txt`` is split randomly 75% / 25% into
    training and validation sets. Each epoch trains once. Every
    ``args.test_freq`` epochs it validates, and when the validation loss
    improves it saves a "best" checkpoint and evaluates on
    ``dataset_test.txt``. Every ``args.save_freq`` epochs it saves a regular
    checkpoint when no best one was written. A statistics row is logged every
    epoch. Training stops after ``args.epochs`` epochs, or earlier when the
    early-stopping criterion is met.
    """
    global args
    args = parser.parse_args()
    os.makedirs(args.output, exist_ok=True)

    use_cuda = torch.cuda.is_available()
    if use_cuda:
        # if don't call torch.cuda.current_device(), fails later with
        #   "RuntimeError: cuda runtime error (30) : unknown error at ..\aten\src\THC\THCGeneral.cpp:87"
        # (only called when CUDA is available: on CPU-only setups the call itself raises)
        torch.cuda.current_device()
    device = torch.device("cuda:0" if use_cuda else "cpu")

    # try to get consistent results across runs
    #   => currently still fails, however, makes runs a bit more consistent
    _set_random_seed()

    # create model and move it to the target device *before* building the
    # optimizer: Optimizer.load_state_dict places restored state on the
    # parameters' current device, so they must already be where training runs
    model = PoseNet(arch=args.arch, num_features=args.features, dropout=args.dropout,
                    pretrained=True, cache_dir=args.cache, loss=args.loss, excl_bn_affine=args.excl_bn,
                    beta=args.beta, sx=args.sx, sq=args.sq)
    model.to(device)

    optimizer = _build_optimizer(model, args)

    # optionally resume from a checkpoint
    best_loss = float('inf')
    best_epoch = -1
    if args.resume:
        if not os.path.isfile(args.resume):
            print("=> no checkpoint found at '{}'".format(args.resume))
            # non-zero exit status: the requested checkpoint is missing
            raise SystemExit(1)
        print("=> loading checkpoint '{}'".format(args.resume))
        # map_location lets a checkpoint saved on GPU be resumed on CPU (and vice versa)
        checkpoint = torch.load(args.resume, map_location=device)
        args.start_epoch = checkpoint['epoch']
        best_epoch = checkpoint['best_epoch']
        best_loss = checkpoint['best_loss']
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        print("=> loaded checkpoint '{}' (epoch {})"
              .format(args.resume, checkpoint['epoch']))
        # both load_state_dict calls copied the data; free the loaded tensors
        del checkpoint

    # define overall training dataset and set output normalization; move the
    # model again in case set_target_transform added tensors on the CPU
    all_tr_data = PoseDataset(args.data, 'dataset_train.txt', random_crop=not args.center_crop)
    model.set_target_transform(all_tr_data.target_mean, all_tr_data.target_std)
    model.to(device)

    # split overall training data to training and validation sets
    # validation set is used for early stopping, or possibly in future for hyper parameter optimization
    n_total = len(all_tr_data)
    n_train = round(n_total * 0.75)
    tr_data, val_data = torch.utils.data.random_split(all_tr_data, [n_train, n_total - n_train])

    # define data loaders
    train_loader = DataLoader(tr_data, batch_size=args.batch_size, num_workers=args.workers,
                              shuffle=True, pin_memory=True, worker_init_fn=_worker_init_fn)

    val_loader = DataLoader(val_data, batch_size=args.batch_size, num_workers=args.workers,
                            shuffle=False, pin_memory=True, worker_init_fn=_worker_init_fn)

    test_loader = DataLoader(PoseDataset(args.data, 'dataset_test.txt', random_crop=False),
                             batch_size=args.batch_size, num_workers=args.workers,
                             shuffle=False, pin_memory=True, worker_init_fn=_worker_init_fn)

    # evaluate model only (device passed as in the other validate() calls)
    if args.evaluate:
        validate(test_loader, model, device)
        return

    # training loop
    for epoch in range(args.start_epoch, args.epochs):
        # train for one epoch
        lss, pos, ori = process(train_loader, model, optimizer, epoch, device, adv_tr_eps=args.adv_tr_eps)
        # stats row: [0] epoch, [1:6] train, [6:11] validation, [11:16] test;
        # entries stay 0 for phases not run this epoch
        stats = np.zeros(16)
        stats[:6] = [epoch, lss.avg, pos.avg, pos.median, ori.avg, ori.median]

        # evaluate on validation set every test_freq epochs
        is_best = False
        if (epoch + 1) % args.test_freq == 0:
            lss, pos, ori = validate(val_loader, model, device)
            stats[6:11] = [lss.avg, pos.avg, pos.median, ori.avg, ori.median]

            # remember best loss and save a "best" checkpoint
            is_best = lss.avg < best_loss
            if is_best:
                best_epoch = epoch
                best_loss = lss.avg
                _save_checkpoint(_checkpoint_state(epoch + 1, best_epoch, best_loss,
                                                   args.arch, model, optimizer), True)

        # maybe save a checkpoint even if not best model
        if (epoch + 1) % args.save_freq == 0 and not is_best:
            _save_checkpoint(_checkpoint_state(epoch + 1, best_epoch, best_loss,
                                               args.arch, model, optimizer), False)

        # evaluate on test set if best yet result on validation set
        if is_best:
            lss, pos, ori = validate(test_loader, model, device)
            stats[11:] = [lss.avg, pos.avg, pos.median, ori.avg, ori.median]

        # add row to log file
        _save_log(stats, epoch == 0)

        # early stopping
        # NOTE(review): before the first validation best_epoch is -1, so this can
        # trigger before any validation if test_freq > early_stopping - confirm intended
        if args.early_stopping > 0 and epoch - best_epoch >= args.early_stopping:
            print('=====\nEARLY STOPPING CRITERION MET (%d epochs since best validation loss)' % args.early_stopping)
            break

        print('=====\n')
    else:
        # reached only when the loop was not interrupted by early stopping
        # (including zero iterations when resuming at the epoch limit)
        print('MAX EPOCHS (%d) REACHED' % args.epochs)
    print('BEST VALIDATION LOSS: %.3f' % best_loss)
Esempio n. 2
0
# Report the number of training samples and the size of a single sample.
print(str.format('Train size: {} x {}', len(train_data), train_data[0].size()))

# --- Load model ---
print('\nLOADING GAN.')


def weights_init(m):
    """Initialize a ``torch.nn.Linear`` module in place; intended for ``Module.apply``.

    The weight gets Xavier/Glorot uniform initialization and the bias (if the
    layer has one) is set to 0. Only modules whose type is exactly
    ``torch.nn.Linear`` are touched. Subclasses and all other module types are
    left unchanged.

    Args:
        m: the module visited by ``Module.apply``.
    """
    if type(m) is torch.nn.Linear:
        torch.nn.init.xavier_uniform_(m.weight)
        # Linear(..., bias=False) stores ``bias = None``; constant_ would fail on it
        if m.bias is not None:
            torch.nn.init.constant_(m.bias, 0.0)


# Build generator and discriminator on the target device, then either restore
# both from a saved checkpoint or apply the Linear-layer initialization.
netG = PoseNet(n_hidden=N_HIDDEN, mode='generator').to(device)
netD = PoseNet(n_hidden=N_HIDDEN, mode='discriminator').to(device)
if args.model:
    # Deserialize the checkpoint file once and reuse it for both networks;
    # map_location lets a checkpoint saved on GPU load on a CPU-only machine.
    _checkpoint = torch.load(args.model, map_location=device)
    netG.load_state_dict(_checkpoint['netG'])
    netD.load_state_dict(_checkpoint['netD'])
    # load_state_dict copied the weights; don't keep the loaded dict alive at module level
    del _checkpoint
    print('=> Loaded models from {:s}'.format(args.model))
else:
    netG.apply(weights_init)
    netD.apply(weights_init)
# NOTE: only the generator's parameters are counted here
print('Model params: {:.2f}M'.format(
    sum(p.numel() for p in netG.parameters()) / 1e6))

## TRAINING
print('\nTRAINING.')
data_loader = torch.utils.data.DataLoader(train_data,
                                          batch_size=args.batch_size,
                                          shuffle=True,
                                          **kwargs)