Beispiel #1
0
def main():
    """Interpolate the frames of ``args.video_path`` by a power-of-two
    factor ``args.t_interp`` and re-encode the result to ``output.mp4``.

    Relies on module-level globals: ``args``, ``bdcn``, ``structure_gen``,
    ``detail_enhance``, ``transform``, ``_pil_loader``, ``ToImage``,
    ``VideoToSequence`` and ``IndexHelper``.
    """
    # Local imports keep this function self-contained even if the module
    # header does not pull these in.
    import shutil
    import subprocess

    # Each interpolation pass doubles the frame count, so t_interp must be
    # a power of two (t_interp == 2 ** num_iters).  The bit trick avoids
    # the floating-point fragility of math.log(t, 2) % 1, and num_iters no
    # longer shadows the builtin `iter`.
    t_interp = int(args.t_interp)
    if t_interp < 1 or t_interp & (t_interp - 1):
        print('the times of interpolating must be power of 2!!')
        return
    num_iters = t_interp.bit_length() - 1

    # Load pretrained weights; strict=False tolerates missing/extra keys.
    bdcn.load_state_dict(torch.load('%s' % (args.bdcn_model)))
    checkpoint = torch.load(args.checkpoint)
    structure_gen.load_state_dict(checkpoint['state_dictGEN'], strict=False)
    detail_enhance.load_state_dict(checkpoint['state_dictDE'], strict=False)

    bdcn.eval()
    structure_gen.eval()
    detail_enhance.eval()

    # Explode the input video into numbered PNG frames.
    [dir_path, frame_count, fps] = VideoToSequence(args.video_path,
                                                   args.t_interp)

    # Zero-padded width of the final frame index (loop-invariant).
    index_width = len(str(args.t_interp * frame_count))
    for it in range(num_iters):
        print('processing iter' + str(it + 1) + ', ' +
              str((it + 1) * frame_count) + ' frames in total')
        filenames = os.listdir(dir_path)
        filenames.sort()
        # Distinct inner index `j`: the original reused `i` and clobbered
        # the outer loop variable.
        for j in range(0, len(filenames) - 1):
            arguments_strFirst = os.path.join(dir_path, filenames[j])
            arguments_strSecond = os.path.join(dir_path, filenames[j + 1])
            # Frame indices are encoded in the file names; the new frame
            # gets the midpoint index.  Raw strings: "\D" is an invalid
            # escape sequence.
            index1 = int(re.sub(r"\D", "", filenames[j]))
            index2 = int(re.sub(r"\D", "", filenames[j + 1]))
            index = int((index1 + index2) / 2)
            arguments_strOut = os.path.join(
                dir_path, IndexHelper(index, index_width) + ".png")

            X0 = transform(_pil_loader(arguments_strFirst)).unsqueeze(0)
            X1 = transform(_pil_loader(arguments_strSecond)).unsqueeze(0)

            # Both frames must share the same spatial size.
            assert (X0.size(2) == X1.size(2))
            assert (X0.size(3) == X1.size(3))

            if X0.size(1) != 3:
                print('Not RGB image')
                continue

            imgt = ToImage(X0, X1)

            # Model output is in [-1, 1]; map to [0, 255] and swap
            # RGB -> BGR for OpenCV.
            imgt_np = imgt.squeeze(0).cpu().numpy()
            imgt_png = np.uint8(
                ((imgt_np + 1.0) / 2.0).transpose(1, 2, 0)[:, :, ::-1] * 255)
            cv2.imwrite(arguments_strOut, imgt_png)

    if args.fps != -1:
        output_fps = args.fps
    else:
        output_fps = fps if args.slow_motion else args.t_interp * fps
    # Argv list with no shell: dir_path can no longer be interpreted by the
    # shell; the glob pattern is expanded by ffmpeg itself.
    subprocess.run([
        'ffmpeg', '-framerate', str(output_fps), '-pattern_type', 'glob',
        '-i', dir_path + '/*.png', '-pix_fmt', 'yuv420p', 'output.mp4'
    ])
    # Portable replacement for os.system("rm -rf ...").
    shutil.rmtree(dir_path)
Beispiel #2
0
# Restore the training history from the loaded checkpoint so that new
# entries are appended to the existing loss/PSNR curves.
cLoss = dict1['loss']
valLoss = dict1['valLoss']
valPSNR = dict1['valPSNR']
# Number of checkpoints already written (one every checkpoint_epoch epochs).
checkpoint_counter = int((dict1['epoch'] + 1) / args.checkpoint_epoch)

# Select which sub-network is trainable for this run; everything else is
# frozen in eval mode.
if args.final:
    # Final stage: only the last detail-enhancement network trains.
    structure_gen.eval()
    detail_enhance.eval()
    detail_enhance_last.train()
else:
    if args.GEN_DE:
        # Train the structure generator alone.
        structure_gen.train()
    else:
        # Freeze the generator and train the detail-enhancement network.
        structure_gen.eval()
        detail_enhance.train()
# The edge detector is always frozen.
bdcn.eval()
for epoch in range(dict1['epoch'] + 1, args.epochs):
    print("Epoch: ", epoch)

    # Append and reset
    cLoss.append([])
    valLoss.append([])
    valPSNR.append([])
    iLoss = 0

    # Increment scheduler count
    scheduler.step()

    if args.test:
Beispiel #3
0
def main(interp: int, input_file: str):
    """Interpolate the frames of *input_file* by the power-of-two factor
    *interp* and encode the result to ``output.mp4``.

    Relies on module-level globals: ``bdcn``, ``structure_gen``,
    ``detail_enhance``, ``transform``, ``_pil_loader``, ``ToImage``,
    ``VideoToSequence``, ``IndexHelper``, ``ffmpeg_exe`` and the progress
    hooks ``setIteration`` / ``setInterpolationRange`` /
    ``setInterpolationIndex`` / ``getInterpolationRange``.
    """
    # Local import: subprocess is only needed by this function.
    import subprocess

    cwd = Path(__file__).resolve()
    model_file = cwd.parent / 'models/bdcn/final-model/bdcn_pretrained_on_bsds500.pth'
    checkpoint_file = cwd.parent / 'checkpoints/FeFlow.ckpt'
    print(model_file)
    print(model_file.exists())
    print('INTERP: ', interp)

    # Each interpolation pass doubles the frame count, so interp must be a
    # power of two (interp == 2 ** num_iters).  The bit trick avoids the
    # floating-point fragility of math.log(interp, 2) % 1, and num_iters
    # no longer shadows the builtin `iter`.
    interp = int(interp)
    if interp < 1 or interp & (interp - 1):
        print('the times of interpolating must be power of 2!!')
        return
    num_iters = interp.bit_length() - 1

    # Load pretrained weights; strict=False tolerates missing/extra keys.
    bdcn.load_state_dict(torch.load(model_file))
    checkpoint = torch.load(checkpoint_file)
    structure_gen.load_state_dict(checkpoint['state_dictGEN'], strict=False)
    detail_enhance.load_state_dict(checkpoint['state_dictDE'], strict=False)

    bdcn.eval()
    structure_gen.eval()
    detail_enhance.eval()

    # Explode the input video into numbered PNG frames.
    [dir_path, frame_count, fps] = VideoToSequence(input_file, interp)

    # Zero-padded width of the output frame index (at least 10 digits,
    # matching the "%010d" encoder pattern below); loop-invariant.
    index_width = len(str(interp * frame_count).zfill(10))
    for it in range(num_iters):
        print('processing iter' + str(it + 1) + ', ' +
              str((it + 1) * frame_count) + ' frames in total')
        setIteration(num_iters)
        filenames = os.listdir(dir_path)
        filenames.sort()
        setInterpolationRange(len(filenames) - 1)
        # Distinct inner index `j`: the original reused `i` and clobbered
        # the outer loop variable.
        for j in tqdm(range(0, getInterpolationRange())):
            setInterpolationIndex(j)
            arguments_strFirst = os.path.join(dir_path, filenames[j])
            arguments_strSecond = os.path.join(dir_path, filenames[j + 1])
            # Frame indices are encoded in the file names; the new frame
            # gets the midpoint index.  Raw strings: "\D" is an invalid
            # escape sequence.
            index1 = int(re.sub(r"\D", "", filenames[j]))
            index2 = int(re.sub(r"\D", "", filenames[j + 1]))
            index = int((index1 + index2) / 2)
            arguments_strOut = os.path.join(
                dir_path, IndexHelper(index, index_width) + ".png")

            X0 = transform(_pil_loader(arguments_strFirst)).unsqueeze(0)
            X1 = transform(_pil_loader(arguments_strSecond)).unsqueeze(0)

            # Both frames must share the same spatial size.
            assert (X0.size(2) == X1.size(2))
            assert (X0.size(3) == X1.size(3))

            if X0.size(1) != 3:
                print('Not RGB image')
                continue

            imgt = ToImage(X0, X1)

            # Model output is in [-1, 1]; map to [0, 255] and swap
            # RGB -> BGR for OpenCV.
            imgt_np = imgt.squeeze(0).cpu().numpy()
            imgt_png = np.uint8(
                ((imgt_np + 1.0) / 2.0).transpose(1, 2, 0)[:, :, ::-1] * 255)
            cv2.imwrite(arguments_strOut, imgt_png)

    # Encode with an argv list (no shell) and os.path.join instead of the
    # hard-coded ".\\...\\" Windows-only path concatenation: portable and
    # immune to shell interpretation of dir_path.
    subprocess.run([
        str(ffmpeg_exe), '-f', 'image2', '-framerate', str(interp * fps),
        '-i', os.path.join(dir_path, '%010d.png'), '-pix_fmt', 'yuv420p',
        'output.mp4'
    ])
    shutil.rmtree(dir_path)
    torch.cuda.empty_cache()
Beispiel #4
0
def main():
    """Interpolate the middle frame for every image triple under
    ``args.imgpath`` and report the average interpolation error (IE)
    and PSNR against the ground-truth middle frame.

    Relies on module-level globals: ``args``, ``bdcn``, ``structure_gen``,
    ``detail_enhance``, ``transform``, ``_pil_loader``, ``ToImage`` and
    ``compare_psnr``.
    """
    # Load pretrained weights; strict=False tolerates missing/extra keys.
    bdcn.load_state_dict(torch.load('%s' % (args.bdcn_model)))
    checkpoint = torch.load(args.checkpoint)
    structure_gen.load_state_dict(checkpoint['state_dictGEN'], strict=False)
    detail_enhance.load_state_dict(checkpoint['state_dictDE'], strict=False)

    bdcn.eval()
    structure_gen.eval()
    detail_enhance.eval()

    total_ie = 0.0
    total_psnr = 0.0
    count = 0
    for folder in tqdm(os.listdir(args.imgpath)):
        triple_path = os.path.join(args.imgpath, folder)
        # Skip stray files; only directories hold triples.
        if not os.path.isdir(triple_path):
            continue
        X0 = transform(_pil_loader('%s/%s' %
                                   (triple_path, args.first))).unsqueeze(0)
        X1 = transform(_pil_loader('%s/%s' %
                                   (triple_path, args.second))).unsqueeze(0)

        # Both frames must share the same spatial size.
        assert (X0.size(2) == X1.size(2))
        assert (X0.size(3) == X1.size(3))

        if X0.size(1) != 3:
            print('Not RGB image')
            continue
        count += 1

        imgt = ToImage(X0, X1)

        # Model output is in [-1, 1]; map to [0, 255] and swap RGB -> BGR
        # for OpenCV.  (The old "mkdir -p" branch was unreachable: the
        # isdir guard above already ensured the directory exists.)
        imgt_np = imgt.squeeze(0).cpu().numpy()
        imgt_png = np.uint8(
            ((imgt_np + 1.0) / 2.0).transpose(1, 2, 0)[:, :, ::-1] * 255)
        cv2.imwrite(triple_path + '/SeDraw.png', imgt_png)

        rec_rgb = np.array(_pil_loader('%s/%s' % (triple_path, 'SeDraw.png')))
        gt_rgb = np.array(_pil_loader('%s/%s' % (triple_path, args.gt)))

        # Cast to float BEFORE subtracting: the images load as uint8, and
        # uint8 subtraction wraps around (e.g. 0 - 1 == 255), which
        # corrupted the RMSE-based IE metric in the original code.
        diff_rgb = rec_rgb.astype(np.float64) - gt_rgb.astype(np.float64)
        avg_interp_error_abs = np.sqrt(np.mean(diff_rgb**2))

        psnr = compare_psnr(gt_rgb, rec_rgb, 255)
        print(folder, psnr)

        total_ie += avg_interp_error_abs
        total_psnr += psnr

    if count == 0:
        # Nothing processed; avoid ZeroDivisionError on the averages.
        print('no valid triples found')
        return
    print('Average IE/PSNR:', total_ie / count, total_psnr / count)
Beispiel #5
0
def main():
    """Double the frame rate of ``args.video_name`` by interpolating one
    frame between every consecutive pair, writing ``<name>_Sedraw.mp4``.

    Relies on module-level globals: ``args``, ``bdcn``, ``structure_gen``,
    ``detail_enhance``, ``normalize`` and ``ToImage``.
    """

    def to_png(tensor):
        # Tensor values are in [-1, 1]; map to [0, 255] uint8 and swap
        # RGB -> BGR for OpenCV's VideoWriter.  (Shared helper replacing
        # four copy-pasted conversion blocks.)
        arr = tensor.cpu().numpy()
        return np.uint8(
            ((arr + 1.0) / 2.0).transpose(1, 2, 0)[:, :, ::-1] * 255)

    # Load pretrained weights; strict=False tolerates missing/extra keys.
    bdcn.load_state_dict(torch.load('%s' % (args.bdcn_model)))
    checkpoint = torch.load(args.checkpoint)
    structure_gen.load_state_dict(checkpoint['state_dictGEN'], strict=False)
    detail_enhance.load_state_dict(checkpoint['state_dictDE'], strict=False)

    bdcn.eval()
    structure_gen.eval()
    detail_enhance.eval()

    if not os.path.isfile(args.video_name):
        # Bail out instead of handing a missing path to VideoCapture: the
        # original only printed and then opened the file anyway.
        print('video not exist!')
        return
    video = cv2.VideoCapture(args.video_name)
    if args.fix_range:
        # Doubled frame count at doubled fps — presumably to keep the
        # original duration; TODO confirm against callers.
        fps = video.get(cv2.CAP_PROP_FPS) * 2
    else:
        fps = 25
    size = (int(video.get(cv2.CAP_PROP_FRAME_WIDTH)),
            int(video.get(cv2.CAP_PROP_FRAME_HEIGHT)))
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    video_writer = cv2.VideoWriter(args.video_name[:-4] + '_Sedraw.mp4',
                                   fourcc, fps, size)

    first_batch = True
    frame_group = []
    while video.isOpened():
        # Read up to batchsize more frames; stop early at end of stream.
        for _ in range(args.batchsize):
            ret, frame = video.read()
            if not ret:
                break
            # BGR uint8 HWC -> normalized RGB float CHW tensor + batch dim.
            frame = torch.FloatTensor(frame[:, :, ::-1].transpose(
                2, 0, 1).copy()) / 255
            frame_group.append(normalize(frame).unsqueeze(0))
        if len(frame_group) <= 1:
            break
        first = torch.cat(frame_group[:-1], dim=0)
        second = torch.cat(frame_group[1:], dim=0)

        # One interpolated frame per consecutive pair.
        middle_frame = ToImage(first, second)

        if first_batch:
            # First batch: emit first_i, middle_i for every pair, then the
            # trailing frame once.
            for i in range(first.shape[0]):
                video_writer.write(to_png(first[i]))
                video_writer.write(to_png(middle_frame[i]))
            video_writer.write(to_png(second[-1]))
            first_batch = False
        else:
            # Later batches: the leading frame was already written as the
            # previous batch's trailing frame, so emit middle_i, second_i.
            for i in range(second.shape[0]):
                video_writer.write(to_png(middle_frame[i]))
                video_writer.write(to_png(second[i]))
        # Carry the last frame over so the next batch pairs with it.
        frame_group = [second[-1].unsqueeze(0)]

    # Release OS resources held by OpenCV; the original leaked the capture.
    video.release()
    video_writer.release()