def main():
    """Interpolate a video by repeated frame doubling and encode ``output.mp4``.

    Configuration comes from the module-level ``args`` namespace
    (``t_interp``, ``bdcn_model``, ``checkpoint``, ``video_path``, ``fps``,
    ``slow_motion``). ``args.t_interp`` must be a positive power of two; one
    pass is run per factor of two. Each pass synthesizes a frame between every
    pair of neighbouring frames currently on disk. Afterwards the frames are
    encoded with the ``ffmpeg`` executable and the frame directory is deleted.
    """
    import shutil
    import subprocess

    t_interp = args.t_interp
    # Exact integer power-of-two test: math.log(x, 2) is inexact for large
    # powers and raises ValueError for zero / negative values.
    is_pow2 = (t_interp > 0 and t_interp == int(t_interp)
               and (int(t_interp) & (int(t_interp) - 1)) == 0)
    if not is_pow2:
        print('the times of interpolating must be power of 2!!')
        return
    num_passes = int(t_interp).bit_length() - 1

    bdcn.load_state_dict(torch.load(str(args.bdcn_model)))
    dict1 = torch.load(args.checkpoint)
    structure_gen.load_state_dict(dict1['state_dictGEN'], strict=False)
    detail_enhance.load_state_dict(dict1['state_dictDE'], strict=False)
    bdcn.eval()
    structure_gen.eval()
    detail_enhance.eval()

    dir_path, frame_count, fps = VideoToSequence(args.video_path, args.t_interp)
    # Width handed to IndexHelper for output filenames (presumably the
    # zero-padding width — TODO confirm against IndexHelper).
    name_width = len(str(args.t_interp * frame_count))

    for pass_idx in range(num_passes):
        print('processing iter' + str(pass_idx + 1) + ', '
              + str((pass_idx + 1) * frame_count) + ' frames in total')
        # Snapshot the directory: frames written during this pass only
        # become inputs on the next pass.
        filenames = sorted(os.listdir(dir_path))
        for name_a, name_b in zip(filenames, filenames[1:]):
            index_a = int(re.sub(r"\D", "", name_a))
            index_b = int(re.sub(r"\D", "", name_b))
            out_path = os.path.join(
                dir_path, IndexHelper((index_a + index_b) // 2, name_width) + ".png")

            X0 = transform(_pil_loader(os.path.join(dir_path, name_a))).unsqueeze(0)
            X1 = transform(_pil_loader(os.path.join(dir_path, name_b))).unsqueeze(0)
            assert X0.size(2) == X1.size(2)
            assert X0.size(3) == X1.size(3)
            if X0.size(1) != 3:
                print('Not RGB image')
                continue

            imgt_np = ToImage(X0, X1).squeeze(0).cpu().numpy()
            # Assumes ToImage output lies in [-1, 1]: map to [0, 255] (clipped so
            # out-of-range values cannot wrap in the uint8 cast), CHW -> HWC,
            # and reverse the channel order for cv2.imwrite.
            imgt_png = np.uint8(
                np.clip((imgt_np + 1.0) / 2.0, 0.0, 1.0).transpose(1, 2, 0)[:, :, ::-1] * 255)
            cv2.imwrite(out_path, imgt_png)

    if args.fps != -1:
        output_fps = args.fps
    else:
        output_fps = fps if args.slow_motion else args.t_interp * fps
    # Argument list (no shell): paths with spaces/quotes are passed intact and
    # ffmpeg expands the glob pattern itself.
    try:
        subprocess.run(
            ['ffmpeg', '-framerate', str(output_fps), '-pattern_type', 'glob',
             '-i', os.path.join(dir_path, '*.png'), '-pix_fmt', 'yuv420p', 'output.mp4'],
            check=False)
    except OSError as err:
        print('failed to run ffmpeg:', err)
    # Best-effort cleanup, mirroring the previous `rm -rf`.
    shutil.rmtree(dir_path, ignore_errors=True)
cLoss = dict1['loss'] valLoss = dict1['valLoss'] valPSNR = dict1['valPSNR'] checkpoint_counter = int((dict1['epoch'] + 1) / args.checkpoint_epoch) if args.final: structure_gen.eval() detail_enhance.eval() detail_enhance_last.train() else: if args.GEN_DE: structure_gen.train() else: structure_gen.eval() detail_enhance.train() bdcn.eval() # --Main training loop-- for epoch in range(dict1['epoch'] + 1, args.epochs): print("Epoch: ", epoch) # Append and reset cLoss.append([]) valLoss.append([]) valPSNR.append([]) iLoss = 0 # Increment scheduler count scheduler.step() if args.test:
def main(interp: int, input_file: str):
    """Interpolate ``input_file`` by a factor of ``interp`` and encode ``output.mp4``.

    Args:
        interp: Interpolation factor; must be a positive power of two. One
            frame-doubling pass is run per factor of two.
        input_file: Video path handed to ``VideoToSequence``.

    Model weights are loaded from paths relative to this script. Progress is
    reported through ``setIteration`` / ``setInterpolationRange`` /
    ``setInterpolationIndex``. The frame directory is removed after encoding.
    """
    import subprocess

    script_dir = Path(__file__).resolve().parent
    model_file = script_dir / 'models/bdcn/final-model/bdcn_pretrained_on_bsds500.pth'
    checkpoint_file = script_dir / 'checkpoints/FeFlow.ckpt'
    print(model_file)
    print(model_file.exists())
    print('INTERP: ', interp)

    # Exact integer power-of-two test: math.log(x, 2) is inexact for large
    # powers and raises ValueError for zero / negative values.
    is_pow2 = (interp > 0 and interp == int(interp)
               and (int(interp) & (int(interp) - 1)) == 0)
    if not is_pow2:
        print('the times of interpolating must be power of 2!!')
        return
    num_passes = int(interp).bit_length() - 1

    bdcn.load_state_dict(torch.load(model_file))
    dict1 = torch.load(checkpoint_file)
    structure_gen.load_state_dict(dict1['state_dictGEN'], strict=False)
    detail_enhance.load_state_dict(dict1['state_dictDE'], strict=False)
    bdcn.eval()
    structure_gen.eval()
    detail_enhance.eval()

    dir_path, frame_count, fps = VideoToSequence(input_file, interp)
    # zfill(10) makes this at least 10, presumably so IndexHelper's names line
    # up with the `%010d.png` pattern given to ffmpeg below — TODO confirm.
    name_width = len(str(interp * frame_count).zfill(10))

    for pass_idx in range(num_passes):
        print('processing iter' + str(pass_idx + 1) + ', '
              + str((pass_idx + 1) * frame_count) + ' frames in total')
        # NOTE(review): passes the total pass count, not pass_idx — confirm intended.
        setIteration(num_passes)
        # Snapshot the directory: frames written during this pass only
        # become inputs on the next pass.
        filenames = sorted(os.listdir(dir_path))
        setInterpolationRange(len(filenames) - 1)
        for pair_idx in tqdm(range(getInterpolationRange())):
            setInterpolationIndex(pair_idx)
            name_a, name_b = filenames[pair_idx], filenames[pair_idx + 1]
            index_a = int(re.sub(r"\D", "", name_a))
            index_b = int(re.sub(r"\D", "", name_b))
            out_path = os.path.join(
                dir_path, IndexHelper((index_a + index_b) // 2, name_width) + ".png")

            X0 = transform(_pil_loader(os.path.join(dir_path, name_a))).unsqueeze(0)
            X1 = transform(_pil_loader(os.path.join(dir_path, name_b))).unsqueeze(0)
            assert X0.size(2) == X1.size(2)
            assert X0.size(3) == X1.size(3)
            if X0.size(1) != 3:
                print('Not RGB image')
                continue

            imgt_np = ToImage(X0, X1).squeeze(0).cpu().numpy()
            # Assumes ToImage output lies in [-1, 1]: map to [0, 255] (clipped so
            # out-of-range values cannot wrap in the uint8 cast), CHW -> HWC,
            # and reverse the channel order for cv2.imwrite.
            imgt_png = np.uint8(
                np.clip((imgt_np + 1.0) / 2.0, 0.0, 1.0).transpose(1, 2, 0)[:, :, ::-1] * 255)
            cv2.imwrite(out_path, imgt_png)

    # Argument list (no shell) and a platform path join instead of the
    # hard-coded ".\\" + dir_path + "\\..." Windows-only string.
    frame_pattern = os.path.join(dir_path, '%010d.png')
    try:
        subprocess.run(
            [str(ffmpeg_exe), '-f', 'image2', '-framerate', str(interp * fps),
             '-i', frame_pattern, '-pix_fmt', 'yuv420p', 'output.mp4'],
            check=False)
    except OSError as err:
        print('failed to run ffmpeg:', err)
    shutil.rmtree(dir_path)
    torch.cuda.empty_cache()
def main():
    """Evaluate the models on a directory of frame triplets and print IE/PSNR.

    For each sub-directory of ``args.imgpath``, the frames ``args.first`` and
    ``args.second`` are interpolated. The result is written as ``SeDraw.png``
    in that sub-directory and compared against ``args.gt``. Per-folder PSNR is
    printed, followed by the average RMS interpolation error (IE) and PSNR.
    """
    bdcn.load_state_dict(torch.load(str(args.bdcn_model)))
    dict1 = torch.load(args.checkpoint)
    structure_gen.load_state_dict(dict1['state_dictGEN'], strict=False)
    detail_enhance.load_state_dict(dict1['state_dictDE'], strict=False)
    bdcn.eval()
    structure_gen.eval()
    detail_enhance.eval()

    total_ie = 0.0
    total_psnr = 0.0
    count = 0
    # Sorted so the per-folder output order is deterministic.
    for folder in tqdm(sorted(os.listdir(args.imgpath))):
        triple_path = os.path.join(args.imgpath, folder)
        if not os.path.isdir(triple_path):
            continue

        X0 = transform(_pil_loader(os.path.join(triple_path, args.first))).unsqueeze(0)
        X1 = transform(_pil_loader(os.path.join(triple_path, args.second))).unsqueeze(0)
        assert X0.size(2) == X1.size(2)
        assert X0.size(3) == X1.size(3)
        if X0.size(1) != 3:
            print('Not RGB image')
            continue
        count += 1

        imgt_np = ToImage(X0, X1).squeeze(0).cpu().numpy()
        # Assumes ToImage output lies in [-1, 1]: map to [0, 255] (clipped so
        # out-of-range values cannot wrap in the uint8 cast), CHW -> HWC,
        # and reverse the channel order for cv2.imwrite.
        imgt_png = np.uint8(
            np.clip((imgt_np + 1.0) / 2.0, 0.0, 1.0).transpose(1, 2, 0)[:, :, ::-1] * 255)
        out_path = os.path.join(triple_path, 'SeDraw.png')
        cv2.imwrite(out_path, imgt_png)

        # Re-read the saved PNG so the metrics reflect the written image.
        rec_rgb = np.array(_pil_loader(out_path))
        gt_rgb = np.array(_pil_loader(os.path.join(triple_path, args.gt)))
        # Promote to float before subtracting: integer (presumably uint8)
        # arrays would wrap around on negative differences and on squaring.
        diff_rgb = rec_rgb.astype(np.float64) - gt_rgb.astype(np.float64)
        avg_interp_error_abs = np.sqrt(np.mean(diff_rgb ** 2))
        psnr = compare_psnr(gt_rgb, rec_rgb, 255)
        print(folder, psnr)
        total_ie += avg_interp_error_abs
        total_psnr += psnr

    if count == 0:
        # Avoid ZeroDivisionError when nothing was evaluated.
        print('No RGB triplets found in', args.imgpath)
        return
    print('Average IE/PSNR:', total_ie / count, total_psnr / count)
def main():
    """Insert one synthesized frame between each pair of frames of ``args.video_name``.

    Frames are read in batches of ``args.batchsize``. The last frame of each
    batch is carried over, so the batch boundary is interpolated too. The
    result is written to ``<video stem>_Sedraw.mp4`` with the ``mp4v`` fourcc.
    Output fps is twice the source fps when ``args.fix_range`` is set, else 25.
    """
    bdcn.load_state_dict(torch.load(str(args.bdcn_model)))
    dict1 = torch.load(args.checkpoint)
    structure_gen.load_state_dict(dict1['state_dictGEN'], strict=False)
    detail_enhance.load_state_dict(dict1['state_dictDE'], strict=False)
    bdcn.eval()
    structure_gen.eval()
    detail_enhance.eval()

    if not os.path.isfile(args.video_name):
        print('video not exist!')
        # Previously execution continued and produced an empty output file.
        return

    def to_bgr_uint8(frame_tensor):
        """Convert a CHW frame tensor (assumed in [-1, 1]) to an HWC uint8 array for cv2."""
        arr = frame_tensor.cpu().numpy()
        # Clip so out-of-range values cannot wrap in the uint8 cast; reverse
        # channel order (RGB -> BGR) for OpenCV.
        return np.uint8(
            np.clip((arr + 1.0) / 2.0, 0.0, 1.0).transpose(1, 2, 0)[:, :, ::-1] * 255)

    video = cv2.VideoCapture(args.video_name)
    if args.fix_range:
        fps = video.get(cv2.CAP_PROP_FPS) * 2
    else:
        fps = 25
    size = (int(video.get(cv2.CAP_PROP_FRAME_WIDTH)),
            int(video.get(cv2.CAP_PROP_FRAME_HEIGHT)))
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    # splitext instead of [:-4] so extensions of any length are stripped.
    out_name = os.path.splitext(args.video_name)[0] + '_Sedraw.mp4'
    video_writer = cv2.VideoWriter(out_name, fourcc, fps, size)

    first_batch = True
    frame_group = []
    try:
        while video.isOpened():
            for _ in range(args.batchsize):
                ret, frame = video.read()
                if not ret:
                    break
                # cv2 frame (BGR, HWC) -> RGB CHW float in [0, 1], then normalize.
                frame = torch.FloatTensor(frame[:, :, ::-1].transpose(2, 0, 1).copy()) / 255
                frame_group.append(normalize(frame).unsqueeze(0))
            if len(frame_group) <= 1:
                break

            first = torch.cat(frame_group[:-1], dim=0)
            second = torch.cat(frame_group[1:], dim=0)
            middle_frame = ToImage(first, second)
            if first_batch:
                # Emit f0, m0, f1, m1, ..., then the batch's final frame.
                for i in range(first.shape[0]):
                    video_writer.write(to_bgr_uint8(first[i]))
                    video_writer.write(to_bgr_uint8(middle_frame[i]))
                video_writer.write(to_bgr_uint8(second[-1]))
                first_batch = False
            else:
                # The carried-over first frame was already written last batch.
                for i in range(second.shape[0]):
                    video_writer.write(to_bgr_uint8(middle_frame[i]))
                    video_writer.write(to_bgr_uint8(second[i]))
            # Carry the last frame so the next batch interpolates across the boundary.
            frame_group = [second[-1].unsqueeze(0)]
    finally:
        video.release()
        video_writer.release()