def batch_calOpt(args):
    """Time RAFT inference over every image pair found in ``IMAGE_FOLDER_PATH``.

    Loads the checkpoint ``args.model`` into a ``DataParallel``-wrapped RAFT
    model, iterates over ``OpticalFlowFolder`` in batches of ``BATCH_SIZE`` and
    runs one forward pass (``iters=args.iters``, ``upsample=False``) per batch.
    The predictions are discarded; only the total elapsed time is printed.

    Args:
        args: parsed command-line namespace. ``args.model`` is the checkpoint
            path and ``args.iters`` the number of refinement iterations; the
            namespace is also passed to ``RAFT``.
    """
    model = RAFT(args)
    model = torch.nn.DataParallel(model)
    # The model is still on the CPU at this point, so deserialize the weights
    # there too; this also works for GPU-saved checkpoints on CPU-only hosts.
    model.load_state_dict(torch.load(args.model, map_location="cpu"))
    model.to(DEVICE)
    model.eval()

    # NOTE(review): ToTensor rescales uint8 images to [0, 1]; confirm this is
    # the input range the model expects for images from OpticalFlowFolder.
    transform = transforms.Compose([transforms.ToTensor()])

    image_dataset = OpticalFlowFolder(IMAGE_FOLDER_PATH, transform=transform)
    image_loader = torch.utils.data.DataLoader(image_dataset, batch_size=BATCH_SIZE)

    start = time.perf_counter()
    with torch.no_grad():
        for left, right in image_loader:
            # Padding is only needed when the frame size is not a multiple of 8
            # (most videos already are).
            _, _, h, w = left.shape
            if h % 8 != 0 or w % 8 != 0:
                left = pad8(left)
                right = pad8(right)
            # Forward pass; the result is not used, this is a timing run.
            model(left, right, iters=args.iters, upsample=False)
    if torch.cuda.is_available():
        # CUDA kernels execute asynchronously: wait for all queued work so the
        # measured time covers the forward passes, not just their launch.
        torch.cuda.synchronize()
    print("Time Elapsed: ", time.perf_counter() - start)
def demo(args):
    """Run RAFT on the bundled Sintel, KITTI and DAVIS example pairs.

    For each example, both frames are loaded with ``load_image``, the model
    predicts the flow, and ``display`` is called with the first item of each
    frame batch and of the last prediction.

    Args:
        args: parsed command-line namespace. ``args.model`` is the checkpoint
            path and ``args.iters`` the iteration count used for the Sintel
            pair; the namespace is also passed to ``RAFT``.
    """
    model = torch.nn.DataParallel(RAFT(args))
    model.load_state_dict(torch.load(args.model))
    model.to(DEVICE)
    model.eval()

    with torch.no_grad():
        # (first frame, second frame, keyword arguments for the forward pass);
        # the Sintel pair uses the CLI iteration count without upsampling,
        # the other two use 16 iterations with the model's default upsampling.
        examples = (
            ('images/sintel_0.png', 'images/sintel_1.png',
             dict(iters=args.iters, upsample=False)),
            ('images/kitti_0.png', 'images/kitti_1.png', dict(iters=16)),
            ('images/davis_0.jpg', 'images/davis_1.jpg', dict(iters=16)),
        )
        for first_path, second_path, forward_kwargs in examples:
            first = load_image(first_path)
            second = load_image(second_path)
            predictions = model(first, second, **forward_kwargs)
            display(first[0], second[0], predictions[-1][0])
def demo(args, video_path='video.mp4'):
    """Stream a video through RAFT and show the predicted flow frame by frame.

    Each consecutive pair of frames (previous, current) is fed to the model.
    The time of every forward pass is printed, and the post-processed flow is
    shown in an OpenCV window named ``'frame'``. Playback stops at the end of
    the video or when Esc is pressed.

    Args:
        args: parsed command-line namespace. ``args.model`` is the checkpoint
            path and ``args.iters`` the number of refinement iterations; the
            namespace is also passed to ``RAFT``.
        video_path: video file to read. Defaults to ``'video.mp4'``, the path
            this demo has always used.

    Raises:
        OSError: if not even the first frame of ``video_path`` can be read.
    """
    model = RAFT(args)
    model = torch.nn.DataParallel(model)
    # The model is still on the CPU here, so deserialize the weights there too.
    model.load_state_dict(torch.load(args.model, map_location="cpu"))
    model.to(DEVICE)
    model.eval()

    cap = cv2.VideoCapture(video_path)
    try:
        with torch.no_grad():
            ok, left_frame = cap.read()
            if not ok:
                raise OSError("could not read a frame from %r" % (video_path,))
            h, w, _ = left_frame.shape
            left_tensor = preprocess(left_frame)

            while True:
                ok, right_frame = cap.read()
                if not ok:
                    # End of stream (or read failure): stop instead of passing
                    # None on to preprocess().
                    break
                right_tensor = preprocess(right_frame)

                start1 = time.perf_counter()
                flow_predictions = model(left_tensor, right_tensor,
                                         iters=args.iters, upsample=True)
                if torch.cuda.is_available():
                    # CUDA runs asynchronously; synchronize so the printed time
                    # covers the forward pass itself.
                    torch.cuda.synchronize()
                print(time.perf_counter() - start1)

                flow_image = postprocess(flow_predictions, w * 2, h * 2)
                cv2.imshow('frame', flow_image)
                if cv2.waitKey(25) == 27:  # 27 is the Esc key code
                    break

                # right_tensor is rebuilt on every iteration, so it can become
                # the next left input without a copy.
                left_tensor = right_tensor
    finally:
        cap.release()
        cv2.destroyAllWindows()
def inference(args):
    """Run RAFT over every consecutive frame pair of ``args.video``.

    For each pair (previous frame, next frame) the model predicts flow with 20
    iterations in test mode, and ``vizualize_flow`` is called with the first
    frame, the upsampled flow, the save flag and the pair index. Processing
    stops at the end of the video or as soon as ``vizualize_flow`` returns a
    falsy value.

    Args:
        args: parsed command-line namespace. Reads ``args.model`` (checkpoint
            path), ``args.save`` (when true, the ``demo_frames`` directory is
            created) and ``args.video`` (input video path); the namespace is
            also passed to ``RAFT``.

    Raises:
        OSError: if not even the first frame of the video can be read.
    """
    # get the RAFT model
    model = RAFT(args)

    use_cuda = torch.cuda.is_available()
    device = "cuda" if use_cuda else "cpu"

    # Deserialize onto the CPU: without map_location, a checkpoint saved from
    # GPU tensors cannot be loaded on a CPU-only machine. load_state_dict
    # copies the values into the model's parameters in either branch below.
    pretrained_weights = torch.load(args.model, map_location="cpu")

    save = args.save
    if save:
        os.makedirs("demo_frames", exist_ok=True)

    if use_cuda:
        # parallel between available GPUs
        model = torch.nn.DataParallel(model)
        model.load_state_dict(pretrained_weights)
        model.to(device)
    else:
        # change key names for CPU runtime; the model is not wrapped in
        # DataParallel here (presumably get_cpu_model adapts the keys to that
        # — verify against its definition)
        model.load_state_dict(get_cpu_model(pretrained_weights))

    # change model's mode to evaluation
    model.eval()

    cap = cv2.VideoCapture(args.video)
    try:
        # read and preprocess the first frame
        ret, frame_1 = cap.read()
        if not ret:
            raise OSError("could not read a frame from %r" % (args.video,))
        frame_1 = frame_preprocess(frame_1, device)

        counter = 0
        with torch.no_grad():
            while True:
                # read the next frame; stop at the end of the video
                ret, frame_2 = cap.read()
                if not ret:
                    break
                frame_2 = frame_preprocess(frame_2, device)

                # predict the flow (only the upsampled output is used)
                _, flow_up = model(frame_1, frame_2, iters=20, test_mode=True)

                # hand the result to the visualizer; a falsy return stops
                # processing
                keep_going = vizualize_flow(frame_1, flow_up, save, counter)
                if not keep_going:
                    break

                frame_1 = frame_2
                counter += 1
    finally:
        cap.release()
out = ((epe > 3.0) & ((epe / mag) > 0.05)).float() epe_list.append(epe[val].mean().item()) out_list.append(out[val].cpu().numpy()) epe_list = np.array(epe_list) out_list = np.concatenate(out_list) print("Validation KITTI: %f, %f" % (np.mean(epe_list), 100 * np.mean(out_list))) if __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument('--model', help="restore checkpoint") parser.add_argument('--small', action='store_true', help='use small model') parser.add_argument('--sintel_iters', type=int, default=50) parser.add_argument('--kitti_iters', type=int, default=32) args = parser.parse_args() model = RAFT(args) model = torch.nn.DataParallel(model) model.load_state_dict(torch.load(args.model)) model.to('cuda') model.eval() validate_sintel(args, model, args.sintel_iters) validate_kitti(args, model, args.kitti_iters)