import datetime
import os

import torch
from torch.utils.data import DataLoader

# The module-level command-line `parser`, the dataset readers (DBreader_Vimeo90k,
# DBreader_frame_interpolation), the Middlebury_other test harness, SepConvNet,
# the to_variable helper, and the models/losses/Trainer wrappers are defined in
# project-local modules and are assumed to be importable here.


def main():
    args = parser.parse_args()
    torch.cuda.set_device(args.gpu_id)

    # Training set (random-cropped Vimeo-90k triplets) and Middlebury test set.
    dataset = DBreader_Vimeo90k(args.train, random_crop=(args.patch_size, args.patch_size))
    TestDB = Middlebury_other(args.test_input, args.gt)
    train_loader = DataLoader(dataset=dataset, batch_size=args.batch_size, shuffle=True, num_workers=0)
    model = models.Model(args)
    loss = losses.Loss(args)

    # Optionally resume from a saved checkpoint.
    start_epoch = 0
    if args.load is not None:
        checkpoint = torch.load(args.load)
        model.load(checkpoint['state_dict'])
        start_epoch = checkpoint['epoch']

    my_trainer = Trainer(args, train_loader, TestDB, model, loss, start_epoch)

    # Log the run configuration alongside the outputs.
    now = datetime.datetime.now().strftime('%Y-%m-%d-%H:%M:%S')
    with open(args.out_dir + '/config.txt', 'a') as f:
        f.write(now + '\n\n')
        for arg in vars(args):
            f.write('{}: {}\n'.format(arg, getattr(args, arg)))
        f.write('\n')

    # Alternate training and Middlebury evaluation until the trainer signals completion.
    while not my_trainer.terminate():
        my_trainer.train()
        my_trainer.test()

    my_trainer.close()
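import argparse


def build_parser():
    # Hypothetical helper, not part of the original code: a minimal sketch of the
    # command-line flags that the training entry point above reads from `args`.
    # Flag names are inferred from the attributes accessed in main(); the real
    # module-level `parser` may define additional options consumed by the
    # Trainer, model, and loss wrappers.
    p = argparse.ArgumentParser(description='frame interpolation training (sketch)')
    p.add_argument('--gpu_id', type=int, default=0)
    p.add_argument('--train', type=str, help='path to the Vimeo-90k training set')
    p.add_argument('--test_input', type=str, help='Middlebury input frames')
    p.add_argument('--gt', type=str, help='Middlebury ground-truth frames')
    p.add_argument('--out_dir', type=str, help='directory for config, results, checkpoints')
    p.add_argument('--load', type=str, default=None, help='checkpoint to resume from')
    p.add_argument('--batch_size', type=int)
    p.add_argument('--patch_size', type=int, help='side length of the random training crop')
    return p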
def main():
    args = parser.parse_args()
    input_dir = args.input
    gt_dir = args.gt
    output_dir = args.output
    ckpt = args.checkpoint

    print("Reading Test DB...")
    TestDB = Middlebury_other(input_dir, gt_dir)

    print("Loading the Model...")
    # The checkpoint stores the kernel size, the weights, and the epoch counter.
    checkpoint = torch.load(ckpt)
    kernel_size = checkpoint['kernel_size']
    model = SepConvNet(kernel_size=kernel_size)
    model.load_state_dict(checkpoint['state_dict'])
    model.epoch = checkpoint['epoch']

    print("Test Start...")
    TestDB.Test(model, output_dir)
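def load_sepconv_checkpoint(path):
    # Minimal sketch, not part of the original code: loads the checkpoint layout
    # that the test entry point above expects and the training loop below writes
    # with torch.save, a dict with 'epoch', 'state_dict', and 'kernel_size' keys.
    # SepConvNet is assumed to come from the project-local model module.
    checkpoint = torch.load(path, map_location='cpu')
    model = SepConvNet(kernel_size=checkpoint['kernel_size'])
    model.load_state_dict(checkpoint['state_dict'])
    model.epoch = checkpoint['epoch']
    return model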
def main():
    args = parser.parse_args()
    db_dir = args.train

    # Create the output directory tree (interpolated results and checkpoints).
    if not os.path.exists(args.out_dir):
        os.makedirs(args.out_dir)
    result_dir = args.out_dir + '/result'
    ckpt_dir = args.out_dir + '/checkpoint'
    if not os.path.exists(result_dir):
        os.makedirs(result_dir)
    if not os.path.exists(ckpt_dir):
        os.makedirs(ckpt_dir)

    logfile = open(args.out_dir + '/log.txt', 'w')
    logfile.write('batch_size: ' + str(args.batch_size) + '\n')

    total_epoch = args.epochs
    batch_size = args.batch_size

    # Training set (frame triplets resized to 128x128) and Middlebury test set.
    dataset = DBreader_frame_interpolation(db_dir, resize=(128, 128))
    train_loader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True, num_workers=0)
    TestDB = Middlebury_other(args.test_input, args.gt)
    test_output_dir = args.out_dir + '/result'

    # Build the model, either fresh or resumed from a saved checkpoint.
    if args.load_model is not None:
        checkpoint = torch.load(args.load_model)
        kernel_size = checkpoint['kernel_size']
        model = SepConvNet(kernel_size=kernel_size)
        model.load_state_dict(checkpoint['state_dict'])
    else:
        kernel_size = args.kernel
        model = SepConvNet(kernel_size=kernel_size)
    logfile.write('kernel_size: ' + str(kernel_size) + '\n')

    if torch.cuda.is_available():
        model = model.cuda()

    max_step = len(train_loader)

    # Evaluate once before training so the epoch-0 result is logged.
    model.eval()
    TestDB.Test(model, test_output_dir, logfile, str(model.epoch.item()).zfill(3) + '.png')

    while True:
        if model.epoch.item() == total_epoch:
            break
        model.train()
        for batch_idx, (frame0, frame1, frame2) in enumerate(train_loader):
            frame0 = to_variable(frame0)
            frame1 = to_variable(frame1)
            frame2 = to_variable(frame2)
            # frame1 is the ground-truth middle frame; frame0 and frame2 are its neighbors.
            loss = model.train_model(frame0, frame2, frame1)
            if batch_idx % 100 == 0:
                print('{:<13s}{:<14s}{:<6s}{:<16s}{:<12s}{:<20.16f}'.format(
                    'Train Epoch: ', '[' + str(model.epoch.item()) + '/' + str(total_epoch) + ']',
                    'Step: ', '[' + str(batch_idx) + '/' + str(max_step) + ']',
                    'train loss: ', loss.item()))
        model.increase_epoch()

        # Save a checkpoint and run the Middlebury evaluation after every epoch.
        if model.epoch.item() % 1 == 0:
            torch.save(
                {
                    'epoch': model.epoch,
                    'state_dict': model.state_dict(),
                    'kernel_size': kernel_size
                },
                ckpt_dir + '/model_epoch' + str(model.epoch.item()).zfill(3) + '.pth')
            model.eval()
            TestDB.Test(model, test_output_dir, logfile, str(model.epoch.item()).zfill(3) + '.png')
            logfile.write('\n')

    logfile.close()
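def to_variable(x):
    # to_variable() is not defined in this file; a minimal sketch, assuming it only
    # moves a batch onto the GPU when one is available. Older code of this kind also
    # wrapped the tensor in torch.autograd.Variable, which modern PyTorch no longer needs.
    if torch.cuda.is_available():
        x = x.cuda()
    return x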