예제 #1
0
파일: train.py 프로젝트: zoq/AdaCoF-pytorch
def main():
    """Train the AdaCoF model on Vimeo-90k and evaluate it on Middlebury.

    Configuration comes from the module-level ``parser``. Before training
    starts, the current timestamp and every parsed argument are appended to
    ``<out_dir>/config.txt``. Training then alternates ``train()`` and
    ``test()`` on the ``Trainer`` until it reports termination.
    """
    args = parser.parse_args()
    torch.cuda.set_device(args.gpu_id)

    # Construction order is kept as-is: dataset, test set, loader, model, loss.
    crop = (args.patch_size, args.patch_size)
    train_set = DBreader_Vimeo90k(args.train, random_crop=crop)
    test_db = Middlebury_other(args.test_input, args.gt)
    loader = DataLoader(dataset=train_set,
                        batch_size=args.batch_size,
                        shuffle=True,
                        num_workers=0)
    net = models.Model(args)
    criterion = losses.Loss(args)

    # Resume from a checkpoint dict ('state_dict' + 'epoch') when requested.
    if args.load is None:
        first_epoch = 0
    else:
        ckpt = torch.load(args.load)
        net.load(ckpt['state_dict'])
        first_epoch = ckpt['epoch']

    trainer = Trainer(args, loader, test_db, net, criterion, first_epoch)

    timestamp = datetime.datetime.now().strftime('%Y-%m-%d-%H:%M:%S')
    with open(args.out_dir + '/config.txt', 'a') as cfg:
        cfg.write(timestamp + '\n\n')
        cfg.writelines('{}: {}\n'.format(key, getattr(args, key))
                       for key in vars(args))
        cfg.write('\n')

    while not trainer.terminate():
        trainer.train()
        trainer.test()

    trainer.close()
예제 #2
0
def main():
    """Evaluate a saved SepConvNet checkpoint on the Middlebury test set.

    Reads ``input``, ``gt``, ``output`` and ``checkpoint`` from the
    module-level ``parser``. The checkpoint file is expected to hold a dict
    with ``'kernel_size'``, ``'state_dict'`` and ``'epoch'`` entries.
    """
    args = parser.parse_args()
    input_dir = args.input
    gt_dir = args.gt
    output_dir = args.output
    ckpt = args.checkpoint

    print("Reading Test DB...")
    TestDB = Middlebury_other(input_dir, gt_dir)
    print("Loading the Model...")
    checkpoint = torch.load(ckpt)
    kernel_size = checkpoint['kernel_size']
    model = SepConvNet(kernel_size=kernel_size)
    # checkpoint['state_dict'] is already an in-memory mapping of tensors.
    # Passing it back through torch.load (which expects a path or file
    # object) cannot work, so hand it to load_state_dict directly.
    model.load_state_dict(checkpoint['state_dict'])
    model.epoch = checkpoint['epoch']

    # Same setup the training loop uses before calling TestDB.Test:
    # GPU if available, and inference mode.
    if torch.cuda.is_available():
        model = model.cuda()
    model.eval()

    print("Test Start...")
    TestDB.Test(model, output_dir)
예제 #3
0
def main():
    """Train SepConvNet for frame interpolation, evaluating on Middlebury every epoch.

    Reads its configuration from the module-level ``parser``. The attributes
    used are ``train``, ``out_dir``, ``batch_size``, ``epochs``,
    ``test_input``, ``gt``, ``kernel`` and ``load_model``.

    Side effects:
        * creates ``<out_dir>/result`` and ``<out_dir>/checkpoint`` if missing;
        * (re)writes ``<out_dir>/log.txt`` with the run settings and test logs;
        * after each epoch, saves ``model_epochNNN.pth`` into the checkpoint
          directory, a dict with ``'epoch'``, ``'state_dict'`` and
          ``'kernel_size'``;
        * passes the result directory to ``TestDB.Test`` for its outputs.

    When ``load_model`` is set, training resumes from that file. The file may
    be either a checkpoint dict in the format above or a bare ``state_dict``.
    """
    args = parser.parse_args()
    db_dir = args.train

    result_dir = os.path.join(args.out_dir, 'result')
    ckpt_dir = os.path.join(args.out_dir, 'checkpoint')
    # makedirs also creates out_dir itself. exist_ok avoids the
    # check-then-create race.
    os.makedirs(result_dir, exist_ok=True)
    os.makedirs(ckpt_dir, exist_ok=True)

    total_epoch = args.epochs
    batch_size = args.batch_size

    dataset = DBreader_frame_interpolation(db_dir, resize=(128, 128))
    train_loader = DataLoader(dataset=dataset,
                              batch_size=batch_size,
                              shuffle=True,
                              num_workers=0)

    TestDB = Middlebury_other(args.test_input, args.gt)

    if args.load_model is not None:
        checkpoint = torch.load(args.load_model)
        if 'state_dict' in checkpoint:
            # Checkpoint dict as written by torch.save below. The network must
            # be rebuilt with the kernel size its weights were trained with.
            kernel_size = checkpoint.get('kernel_size', args.kernel)
            model = SepConvNet(kernel_size=kernel_size)
            model.load_state_dict(checkpoint['state_dict'])
            if 'epoch' in checkpoint:
                model.epoch = checkpoint['epoch']
        else:
            # Bare state dict (e.g. saved via torch.save(model.state_dict())).
            kernel_size = args.kernel
            model = SepConvNet(kernel_size=kernel_size)
            model.load_state_dict(checkpoint)
    else:
        kernel_size = args.kernel
        model = SepConvNet(kernel_size=kernel_size)

    if torch.cuda.is_available():
        model = model.cuda()

    max_step = len(train_loader)

    # The with-block guarantees the log is closed even if training raises.
    with open(os.path.join(args.out_dir, 'log.txt'), 'w') as logfile:
        logfile.write('batch_size: ' + str(batch_size) + '\n')
        logfile.write('kernel_size: ' + str(kernel_size) + '\n')

        # Evaluate the starting point (fresh or resumed) before any training.
        model.eval()
        TestDB.Test(model, result_dir, logfile,
                    str(model.epoch.item()).zfill(3) + '.png')

        # Use "<" rather than an "==" break. A resumed checkpoint can already
        # be at or past total_epoch, and "==" would then loop forever.
        while model.epoch.item() < total_epoch:
            model.train()
            for batch_idx, (frame0, frame1, frame2) in enumerate(train_loader):
                frame0 = to_variable(frame0)
                frame1 = to_variable(frame1)
                frame2 = to_variable(frame2)
                # frame1 is passed last, presumably as the ground-truth middle
                # frame (TODO confirm against SepConvNet.train_model).
                loss = model.train_model(frame0, frame2, frame1)
                if batch_idx % 100 == 0:
                    print('{:<13s}{:<14s}{:<6s}{:<16s}{:<12s}{:<20.16f}'.format(
                        'Train Epoch: ', '[' + str(model.epoch.item()) + '/' +
                        str(total_epoch) + ']', 'Step: ',
                        '[' + str(batch_idx) + '/' + str(max_step) + ']',
                        'train loss: ', loss.item()))
            model.increase_epoch()

            epoch_tag = str(model.epoch.item()).zfill(3)
            torch.save(
                {
                    'epoch': model.epoch,
                    'state_dict': model.state_dict(),
                    'kernel_size': kernel_size
                }, os.path.join(ckpt_dir, 'model_epoch' + epoch_tag + '.pth'))
            model.eval()
            TestDB.Test(model, result_dir, logfile, epoch_tag + '.png')
            logfile.write('\n')