def main():
    parser = argparse.ArgumentParser(
        description='Segmentation using the VRN network')
    parser.add_argument('out', type=str, help='Directory to output the result')
    parser.add_argument('--dataset',
                        type=str,
                        help='Path to dataset',
                        default='bfm09_backface_culling.hdf5')
    parser.add_argument('--background',
                        type=str,
                        help='Path to background',
                        default='dtd_all.hdf5')
    parser.add_argument("--image",
                        "-is",
                        type=int,
                        nargs="+",
                        default=(3, 227, 227),
                        help="Size of images. Default: 256x256x3")

    global args
    args = parser.parse_args()
    # create model using the pretrained alexnet.
    print("=> Construct the model...")

    model = EIG()
    model.cuda()

    if not os.path.exists(args.out):
        os.mkdir(args.out)
    if not os.path.exists(os.path.join(args.out, 'coeffs')):
        os.mkdir(os.path.join(args.out, 'coeffs'))

    print("Output location: {}".format(args.out))

    # Initialize the foreground (BFM09) and background datasets, then combine them with BFMOverlay
    d = datasets.BFM09(os.path.join(CONFIG['PATHS', 'databases'],
                                    args.dataset),
                       raw_image=True,
                       input_shape=args.image,
                       augment=False)
    b = datasets.Background(os.path.join(CONFIG['PATHS', 'databases'],
                                         args.background),
                            input_shape=args.image)
    train_loader = datasets.BFMOverlay(d, b)

    segment(train_loader, model)
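

# Example invocation of this segmentation script (hypothetical file name segment_eig.py;
# the .hdf5 names are the argparse defaults above):
#   python segment_eig.py ./results --dataset bfm09_backface_culling.hdf5 --background dtd_all.hdf5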


def main():
    parser = argparse.ArgumentParser(description='Training EIG')
    parser.add_argument('--out',
                        type=str,
                        default='',
                        metavar='PATH',
                        help='Directory to output the result if other than '
                        'what is specified in the config')
    parser.add_argument("--image",
                        "-is",
                        type=int,
                        nargs="+",
                        default=(3, 227, 227),
                        help="Image size. Def: (3, 227, 227)")
    parser.add_argument("--z-size",
                        "-zs",
                        type=int,
                        metavar="N",
                        default=404,
                        help="Size of z layer. Default: 404")
    parser.add_argument('--epochs',
                        default=75,
                        type=int,
                        metavar='N',
                        help='number of total epochs to run')
    parser.add_argument('--start-epoch',
                        default=0,
                        type=int,
                        metavar='N',
                        help='manual epoch number (useful on restarts)')
    parser.add_argument('-b',
                        '--batch-size',
                        default=20,
                        type=int,
                        metavar='N',
                        help='mini-batch size (default: 20)')
    parser.add_argument('--lr',
                        '--learning-rate',
                        default=1e-4,
                        type=float,
                        metavar='LR',
                        help='initial learning rate')
    parser.add_argument('--weight-decay',
                        '--wd',
                        default=1e-4,
                        type=float,
                        metavar='W',
                        help='weight decay (default: 1e-4)')
    parser.add_argument('--print-freq',
                        '-p',
                        default=10,
                        type=int,
                        metavar='N',
                        help='print frequency (default: 10)')
    parser.add_argument('--resume',
                        default='',
                        type=str,
                        metavar='PATH',
                        help='path to latest checkpoint (default: none)')

    global args
    args = parser.parse_args()
    # create model using the pretrained alexnet.
    print("=> Construct the model...")

    model = EIG(args.z_size)
    print(model)
    model.cuda()

    # define loss function (criterion) and optimizer
    # size_average is deprecated; reduction='mean' is the equivalent setting
    criterion = nn.MSELoss(reduction='mean').cuda()

    # only optimize parameters that require gradients (frozen layers are skipped)
    optimizer = torch.optim.SGD(
        filter(lambda p: p.requires_grad, model.parameters()), args.lr)

    if args.out == '':
        out_path = os.path.join(
            CONFIG['PATHS', 'checkpoints'], 'eig',
            datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S'))
        os.makedirs(out_path)
    else:
        out_path = args.out
    if args.resume != '':
        print("=> loading checkpoint '{}'".format(args.resume))
        resume_path = args.resume
        checkpoint = torch.load(resume_path)
        args.start_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        print("=> loaded checkpoint '{}' (epoch {})".format(
            args.resume, checkpoint['epoch']))

    print("Current run location: {}".format(out_path))

    print('dataset reader begin')
    # Initialize the training datasets (full renders and their segmentations) and the validation counterparts
    d_full = datasets.BFM09(os.path.join(CONFIG['PATHS', 'databases'],
                                         'bfm09_backface_culling.hdf5'),
                            raw_image=True,
                            input_shape=args.image)
    d_segment = datasets.BFM09(os.path.join(
        CONFIG['PATHS', 'databases'], 'bfm09_backface_culling_segment.hdf5'),
                               raw_image=True,
                               input_shape=args.image)

    val_loader_full = datasets.BFM09(os.path.join(
        CONFIG['PATHS', 'databases'], 'bfm09_backface_culling_val.hdf5'),
                                     raw_image=True,
                                     input_shape=args.image,
                                     augment=False)
    val_loader_segment = datasets.BFM09(os.path.join(
        CONFIG['PATHS', 'databases'],
        'bfm09_backface_culling_val_segment.hdf5'),
                                        raw_image=True,
                                        input_shape=args.image,
                                        augment=False)
    print('dataset reader end')

    for epoch in range(args.start_epoch, args.epochs):

        print('training begin')
        # train for one epoch
        train(d_full, d_segment, model, criterion, optimizer, epoch)

        # validate
        avg_loss = validate(val_loader_full, val_loader_segment, model,
                            criterion, epoch)

        save_checkpoint(
            {
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
            }, False, os.path.join(out_path, 'checkpoint_bfm.pth.tar'))
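

# Example invocations of this training script (hypothetical file name train_eig.py):
#   python train_eig.py --epochs 75 --batch-size 20 --lr 1e-4
#   python train_eig.py --resume /path/to/checkpoint_bfm.pth.tar  # resume from a saved checkpoint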


from utils import config

CONFIG = config.Config()


def load_image(image, size):
    """Resize a PIL image to `size`, convert it to RGB, and return a CHW float32 array."""
    image = image.resize(size)
    image = image.convert('RGB')
    image = np.asarray(image)
    image = np.moveaxis(image, 2, 0)  # HWC -> CHW
    image = image.astype(np.float32)
    return image


models_d = {
    'eig': EIG(),
}

image_sizes = {
    'eig': (227, 227),
}
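
# Example (hypothetical) of preparing one input for the 'eig' model with the helpers above:
#   from PIL import Image
#   img = load_image(Image.open('face.png'), image_sizes['eig'])
#   # -> float32 array of shape (3, 227, 227), ready to stack into a batch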


def main():
    parser = argparse.ArgumentParser(
        description='Predictions of the models on the neural image test sets')
    parser.add_argument('--imagefolder',
                        type=str,
                        default='./demo_images/',
                        help='Folder containing the input images.')
    parser.add_argument(