Example #1
0
def load_checkpoint(net, optimizer, checkpoint_path, map_location=None):
    """Restore a model and its optimizer from a checkpoint file.

    The checkpoint must be a dict with the keys ``cuda_flag``, ``h``, ``w``,
    ``net_state_dict``, ``optimizer_state_dict``, ``epoch`` and ``losses``;
    a missing key raises ``KeyError``.

    Args:
        net: Model to restore. Its ``cuda_flag``, ``height`` and ``width``
            attributes are overwritten and its weights are loaded in place.
        optimizer: Optimizer whose state is loaded in place.
        checkpoint_path: Path handed to ``torch.load``.
        map_location: Forwarded to ``torch.load``. The default ``None`` keeps
            the previous behaviour; pass ``'cpu'`` to open a checkpoint that
            was saved from GPU tensors on a machine without CUDA.

    Returns:
        Tuple ``(net, optimizer, start_epoch, losses)`` where ``net`` and
        ``optimizer`` are the objects passed in and the last two values are
        the checkpoint's ``epoch`` and ``losses`` entries.
    """
    checkpoint = torch.load(checkpoint_path, map_location=map_location)

    # Restore the plain attributes stored next to the weights.
    net.cuda_flag = checkpoint['cuda_flag']
    net.height = checkpoint['h']
    net.width = checkpoint['w']

    net.load_state_dict(checkpoint['net_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

    return net, optimizer, checkpoint['epoch'], checkpoint['losses']


# --------------------------------------------------------------
# Build the network and move it to the GPU (done regardless of cuda_flag).
toflow = TOFlow(h, w, task=task, cuda_flag=cuda_flag)
toflow = toflow.cuda()

optimizer = torch.optim.Adam(
    toflow.parameters(),
    lr=LR,
    weight_decay=WEIGHT_DECAY,
)
loss_func = torch.nn.L1Loss()

# Training
prev_time = datetime.datetime.now()  # timestamp taken just before training begins
print('%s  Start training...' % show_time(prev_time))
plotx, ploty = [], []
start_epoch, check_loss = 0, 1

if use_checkpoint:
    # Resume: ploty takes the loss history returned by load_checkpoint and
    # plotx is rebuilt as the matching 0..len-1 index list.
    restored = load_checkpoint(toflow, optimizer, checkpoint_path)
    toflow, optimizer, start_epoch, ploty = restored
    plotx = [*range(len(ploty))]
Example #2
0
File: run.py Project: delldu/TOFlow
    # end
    return tensorOutput.detach().numpy()

# ------------------------------
if __name__ == '__main__':
    if CUDA:
        torch.cuda.set_device(gpuID)

    # The first frame is read here only to obtain the image size.
    temp_img = np.array(plt.imread(frameFirstName))
    height = temp_img.shape[0]
    width = temp_img.shape[1]

    # Round both dimensions up to the next multiple of 32, which makes
    # down- and up-sampling easier.
    intPreprocessedWidth = int(math.floor(math.ceil(width / 32.0) * 32.0))
    intPreprocessedHeight = int(math.floor(math.ceil(height / 32.0) * 32.0))

    print('Loading TOFlow Net... ', end='')
    net = TOFlow(intPreprocessedHeight, intPreprocessedWidth, task='slow', cuda_flag=CUDA)
    weights_path = os.path.join(workplace, 'toflow_models', model_name + '.pkl')
    # Without CUDA, remap the stored tensors to the CPU so weights that were
    # saved from a GPU can still be loaded; with CUDA the default is kept.
    net.load_state_dict(torch.load(weights_path, map_location=None if CUDA else 'cpu'))
    if CUDA:
        net.eval().cuda()
    else:
        net.eval()

    print('Done.')

    # ------------------------------
    # generate(net=net, model_name=model_name, f1name=os.path.join(test_pic_dir, 'im1.png'),
    #         f2name=os.path.join(test_pic_dir, 'im3.png'), fname=outputname)
    print('Processing...')
    predict = Estimate(net, Firstfilename=frameFirstName, Secondfilename=frameSecondName, cuda_flag=CUDA)
    # Print the result together with its value range as a quick sanity check.
    print(predict, np.min(predict), np.max(predict))
    plt.imsave(frameOutName, predict)
Example #3
0
def vimeo_evaluate(input_dir,
                   out_img_dir,
                   test_codelistfile,
                   task='',
                   cuda_flag=True):
    """Run a TOFlow model over a list of sequences and save one PNG each.

    Weights are read from the module-level ``model_path``. For every entry
    ``<video>/<sep>`` of the list file the input frames are read from
    ``input_dir/<video>/<sep>/`` and the prediction is written to
    ``out_img_dir/<video>/<sep>/out.png``.

    Args:
        input_dir: Root directory containing the input sequences.
        out_img_dir: Root directory for the results (created if missing).
        test_codelistfile: Text file with one ``<video>/<sep>`` code per
            line; blank lines are ignored.
        task: ``'interp'`` reads frames 1 and 3 (``im%d.png``); ``'denoise'``,
            ``'denoising'``, ``'sr'`` and ``'super-resolution'`` read frames
            1-7 (``im%04d.png``).
        cuda_flag: When true the model and the inputs are moved to the GPU.

    Raises:
        ValueError: If ``task`` is not one of the supported names.
    """
    mkdir_if_not_exist(out_img_dir)

    net = TOFlow(256, 448, cuda_flag=cuda_flag, task=task)
    # Load onto the CPU first so GPU-saved weights also work without CUDA;
    # the model is moved to the GPU just below when requested.
    net.load_state_dict(torch.load(model_path, map_location='cpu'))

    if cuda_flag:
        net.cuda().eval()
    else:
        net.eval()

    # The context manager closes the list file even if reading fails.
    with open(test_codelistfile) as fp:
        test_img_list = [line for line in fp.read().splitlines() if line.strip()]

    if task == 'interp':
        process_index = [1, 3]
        str_format = 'im%d.png'
    elif task in ['denoise', 'denoising', 'sr', 'super-resolution']:
        process_index = [1, 2, 3, 4, 5, 6, 7]
        str_format = 'im%04d.png'
    else:
        raise ValueError(
            'Invalid [--task].\nOnly support: [interp, denoise/denoising, sr/super-resolution]'
        )
    total_count = len(test_img_list)

    pre = datetime.datetime.now()
    for count, code in enumerate(test_img_list, start=1):
        parts = code.split('/')
        video, sep = parts[0], parts[1]
        mkdir_if_not_exist(os.path.join(out_img_dir, video))
        mkdir_if_not_exist(os.path.join(out_img_dir, video, sep))

        input_frames = [
            plt.imread(os.path.join(input_dir, code, str_format % i))
            for i in process_index
        ]
        # Move the last axis to position 1 (assumes every frame is H x W x C,
        # giving frames x C x H x W - TODO confirm against the data).
        input_frames = np.transpose(np.array(input_frames), (0, 3, 1, 2))

        input_frames = torch.from_numpy(input_frames)
        if cuda_flag:
            input_frames = input_frames.cuda()
        # Prepend a batch dimension of size 1.
        input_frames = input_frames.unsqueeze(0)

        # Pure inference: no autograd graph is needed.
        with torch.no_grad():
            predicted_img = net(input_frames)[0, :, :, :]
        # plt.imsave rejects float RGB values outside [0, 1].
        predicted_img = predicted_img.clamp(0, 1.0)
        plt.imsave(os.path.join(out_img_dir, video, sep, 'out.png'),
                   predicted_img.permute(1, 2, 0).cpu().detach().numpy())

        cur = datetime.datetime.now()
        # total_seconds() keeps the fractional part and does not wrap at a
        # day boundary, unlike the .seconds attribute.
        processing_time = (cur - pre).total_seconds() / count
        print('%.2fs per frame.\t%.2fs left.' %
              (processing_time, processing_time * (total_count - count)))
Example #4
0
def load_checkpoint(net, optimizer, checkpoint_path, map_location=None):
    """Load a saved training state into ``net`` and ``optimizer``.

    The file at ``checkpoint_path`` must contain a dict with the keys
    ``cuda_flag``, ``h``, ``w``, ``net_state_dict``, ``optimizer_state_dict``,
    ``epoch`` and ``losses``; a missing key raises ``KeyError``.

    Args:
        net: Model to restore; its ``cuda_flag``, ``height`` and ``width``
            attributes are replaced and its weights are loaded in place.
        optimizer: Optimizer whose state is loaded in place.
        checkpoint_path: Path handed to ``torch.load``.
        map_location: Forwarded to ``torch.load``. ``None`` (the default)
            keeps the previous behaviour; ``'cpu'`` allows a checkpoint saved
            from GPU tensors to be opened on a machine without CUDA.

    Returns:
        ``(net, optimizer, start_epoch, losses)``: the two objects passed in,
        followed by the checkpoint's ``epoch`` and ``losses`` entries.
    """
    checkpoint = torch.load(checkpoint_path, map_location=map_location)

    # Plain attributes stored alongside the weights.
    for attr_name, key in (('cuda_flag', 'cuda_flag'), ('height', 'h'), ('width', 'w')):
        setattr(net, attr_name, checkpoint[key])

    net.load_state_dict(checkpoint['net_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    start_epoch = checkpoint['epoch']
    losses = checkpoint['losses']

    return net, optimizer, start_epoch, losses


# --------------------------------------------------------------
# Network (moved to the GPU regardless of cuda_flag), optimizer and loss.
toflow = TOFlow(h, w, task=task, cuda_flag=cuda_flag).cuda()

optimizer = torch.optim.Adam(toflow.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)
loss_func = torch.nn.L1Loss()

# Training
prev_time = datetime.datetime.now()  # timestamp taken just before training begins
print('%s  Start training...' % show_time(prev_time))
plotx, ploty = [], []
start_epoch, check_loss = 0, 1

if use_checkpoint:
Example #5
0
def vimeo_evaluate(input_dir,
                   out_img_dir,
                   test_codelistfile,
                   task='',
                   cuda_flag=True):
    """Run a TOFlow model over a list of sequences and save the results.

    Weights are read from the module-level ``model_path``. For every entry
    ``<video>/<sep>`` of the list file the input frames are read from
    ``input_dir/<video>/<sep>/``; each input frame is copied to
    ``out_img_dir/<video>/<sep>/<task>_<frame name>`` and the prediction is
    written to ``out_img_dir/<video>/<sep>/<task>_out.png``.

    Args:
        input_dir: Root directory containing the input sequences.
        out_img_dir: Root directory for the results (created if missing).
        test_codelistfile: Text file with one ``<video>/<sep>`` code per
            line; blank lines are ignored.
        task: ``'slow'`` reads frames 1 and 3 (``im%d.png``); ``'clean'`` and
            ``'zoom'`` read frames 1-7 (``im%04d.png``).
        cuda_flag: When true the model and the inputs are moved to the GPU.

    Raises:
        ValueError: If ``task`` is not one of the supported names.
    """
    mkdir_if_not_exist(out_img_dir)

    net = TOFlow(task=task)
    # Weights are loaded onto the CPU; the model is moved to the GPU just
    # below when requested.
    net.load_state_dict(torch.load(model_path, map_location='cpu'))

    if cuda_flag:
        net.cuda().eval()
    else:
        net.eval()

    # The context manager closes the list file even if reading fails.
    with open(test_codelistfile) as fp:
        test_img_list = [line for line in fp.read().splitlines() if line.strip()]

    if task == 'slow':
        process_index = [1, 3]
        str_format = 'im%d.png'
    elif task in ['clean', 'zoom']:
        process_index = [1, 2, 3, 4, 5, 6, 7]
        str_format = 'im%04d.png'
    else:
        raise ValueError(
            'Invalid [--task].\nOnly support: [slow, denoise/clean, sr/zoom]')
    total_count = len(test_img_list)

    pre = datetime.datetime.now()
    for count, code in enumerate(test_img_list, start=1):
        parts = code.split('/')
        video, sep = parts[0], parts[1]
        mkdir_if_not_exist(os.path.join(out_img_dir, video))
        mkdir_if_not_exist(os.path.join(out_img_dir, video, sep))

        input_frames = []
        for i in process_index:
            frame_name = str_format % i
            # Read each frame once and reuse it for both the copy and the input.
            image = plt.imread(os.path.join(input_dir, code, frame_name))
            # Keep a copy of every input frame next to the prediction.
            plt.imsave(
                os.path.join(out_img_dir, video, sep, task + "_" + frame_name),
                image)
            input_frames.append(image)

        # Move the last axis to position 1 (assumes every frame is H x W x C,
        # giving frames x C x H x W - TODO confirm against the data).
        input_frames = np.transpose(np.array(input_frames), (0, 3, 1, 2))

        input_frames = torch.from_numpy(input_frames)
        if cuda_flag:
            input_frames = input_frames.cuda()
        # Prepend a batch dimension of size 1.
        input_frames = input_frames.unsqueeze(0)

        # Pure inference: no autograd graph is needed.
        with torch.no_grad():
            predicted_img = net(input_frames)[0, :, :, :]

        # Keep the values inside [0, 1] before handing them to plt.imsave.
        predicted_img = predicted_img.clamp(0, 1.0)
        plt.imsave(os.path.join(out_img_dir, video, sep, task + '_out.png'),
                   predicted_img.permute(1, 2, 0).cpu().detach().numpy())

        cur = datetime.datetime.now()
        # total_seconds() keeps the fractional part and does not wrap at a
        # day boundary, unlike the .seconds attribute.
        processing_time = (cur - pre).total_seconds() / count
        print('%.2fs per frame.\t%.2fs left.' %
              (processing_time, processing_time * (total_count - count)))