示例#1
0
def batch_calOpt(args):
    """Benchmark RAFT over every frame pair in IMAGE_FOLDER_PATH.

    Predictions are computed and discarded; only the total wall-clock time of
    the forward passes is printed.

    Args:
        args: parsed command-line namespace. ``args.model`` is the checkpoint
            path loaded into the DataParallel-wrapped model, and
            ``args.iters`` is forwarded to the model as ``iters``.
    """
    model = RAFT(args)
    model = torch.nn.DataParallel(model)
    model.load_state_dict(torch.load(args.model))

    model.to(DEVICE)
    model.eval()

    # Frames are only converted to tensors; no other transform is applied.
    transform = transforms.Compose([transforms.ToTensor()])

    # The dataset yields (left, right) pairs, batched by BATCH_SIZE.
    image_dataset = OpticalFlowFolder(IMAGE_FOLDER_PATH, transform=transform)
    image_loader = torch.utils.data.DataLoader(image_dataset,
                                               batch_size=BATCH_SIZE)

    # CUDA kernels are launched asynchronously. Synchronize at both ends of
    # the timed region so the elapsed time includes the GPU work itself and
    # not just the kernel launches.
    use_cuda = next(model.parameters()).is_cuda
    if use_cuda:
        torch.cuda.synchronize()
    start = time.time()
    with torch.no_grad():
        for left, right in image_loader:
            # Pad only when a spatial dimension is not a multiple of 8; inputs
            # that already are (common for video) skip this step.
            _, _, h, w = left.shape
            if h % 8 != 0 or w % 8 != 0:
                left = pad8(left)
                right = pad8(right)

            # The output is intentionally discarded: only the forward pass is
            # being timed.
            model(left, right, iters=args.iters, upsample=False)

    if use_cuda:
        torch.cuda.synchronize()
    print("Time Elapsed: ", time.time() - start)
示例#2
0
文件: demo.py 项目: Pavelrst/RAFT
def demo(args):
    """Predict and display optical flow for three bundled image pairs.

    The Sintel, KITTI and DAVIS pairs under ``images/`` are processed in that
    order. Each pair is loaded, passed through the model, and shown via
    ``display``.

    Args:
        args: parsed command-line namespace. ``args.model`` is the checkpoint
            path, and ``args.iters`` is the iteration count used for the
            Sintel pair.
    """
    net = torch.nn.DataParallel(RAFT(args))
    net.load_state_dict(torch.load(args.model))
    net.to(DEVICE)
    net.eval()

    with torch.no_grad():
        # Each entry is (first frame path, second frame path, model kwargs).
        # NOTE(review): only the Sintel pair uses args.iters with
        # upsample=False; KITTI and DAVIS use a fixed 16 iterations and the
        # model's default upsampling. Kept as-is; confirm whether intended.
        samples = (
            ('images/sintel_0.png', 'images/sintel_1.png',
             {'iters': args.iters, 'upsample': False}),
            ('images/kitti_0.png', 'images/kitti_1.png', {'iters': 16}),
            ('images/davis_0.jpg', 'images/davis_1.jpg', {'iters': 16}),
        )
        for first_path, second_path, model_kwargs in samples:
            first = load_image(first_path)
            second = load_image(second_path)
            predictions = net(first, second, **model_kwargs)
            # Show batch element 0 of both frames with the last entry of the
            # model's prediction sequence.
            display(first[0], second[0], predictions[-1][0])
示例#3
0
def demo(args, video_path='video.mp4'):
    """Show live optical flow between consecutive frames of a video.

    The per-pair inference time is printed for each frame pair. The loop ends
    when the video runs out of frames or when Esc is pressed.

    Args:
        args: parsed command-line namespace. ``args.model`` is the checkpoint
            path, and ``args.iters`` is forwarded to the model as ``iters``.
        video_path: video file to read. The default, ``'video.mp4'``, is the
            previously hard-coded path.

    Raises:
        RuntimeError: if no frame can be read from ``video_path``.
    """
    model = RAFT(args)
    model = torch.nn.DataParallel(model)
    model.load_state_dict(torch.load(args.model))

    model.to(DEVICE)
    model.eval()

    cap = cv2.VideoCapture(video_path)
    try:
        with torch.no_grad():
            ok, left_frame = cap.read()
            if not ok:
                raise RuntimeError("Could not read a frame from %r" % video_path)
            h, w, _ = left_frame.shape
            left_tensor = preprocess(left_frame)

            while True:
                ok, right_frame = cap.read()
                # At end of stream read() reports failure with a None frame;
                # stop here instead of passing None to preprocess().
                if not ok:
                    break
                right_tensor = preprocess(right_frame)

                start1 = time.time()
                flow_predictions = model(left_tensor,
                                         right_tensor,
                                         iters=args.iters,
                                         upsample=True)
                print(time.time() - start1)

                # NOTE(review): the visualization is sized to twice the input
                # frame dimensions; the reason is not visible here.
                flow_image = postprocess(flow_predictions, w * 2, h * 2)
                cv2.imshow('frame', flow_image)

                # 27 is the Esc key code.
                if cv2.waitKey(25) == 27:
                    break

                # The current right frame becomes the next pair's left frame.
                left_tensor = right_tensor.clone()
    finally:
        # Always free the capture handle, even if inference raises.
        cap.release()
    cv2.destroyAllWindows()
示例#4
0
def inference(args):
    """Predict optical flow between consecutive video frames and visualize it.

    Runs on CUDA (DataParallel) when available, otherwise on the CPU.

    Args:
        args: parsed namespace.
            - ``args.model``: checkpoint path.
            - ``args.save``: truthy to create ``demo_frames/``; it is also
              forwarded to ``vizualize_flow``.
            - ``args.video``: path of the video to process.

    Raises:
        RuntimeError: if the first frame cannot be read from ``args.video``.
    """
    model = RAFT(args)
    # Always deserialize onto the CPU first. A checkpoint holding CUDA tensors
    # cannot be loaded on a CUDA-less machine without map_location, which
    # breaks the CPU branch below. On GPU hosts, load_state_dict copies the
    # weights into the model and model.to(device) moves them.
    pretrained_weights = torch.load(args.model, map_location="cpu")

    save = args.save
    if save:
        # exist_ok avoids the check-then-create race of exists() + mkdir().
        os.makedirs("demo_frames", exist_ok=True)

    if torch.cuda.is_available():
        device = "cuda"
        # Parallelize across the available GPUs.
        model = torch.nn.DataParallel(model)
        model.load_state_dict(pretrained_weights)
        model.to(device)
    else:
        device = "cpu"
        # Rename checkpoint keys for the unwrapped CPU model (presumably
        # stripping the DataParallel "module." prefix; confirm in
        # get_cpu_model).
        pretrained_weights = get_cpu_model(pretrained_weights)
        model.load_state_dict(pretrained_weights)

    model.eval()

    cap = cv2.VideoCapture(args.video)
    try:
        ret, frame_1 = cap.read()
        if not ret:
            raise RuntimeError("Could not read a frame from %r" % args.video)
        frame_1 = frame_preprocess(frame_1, device)

        counter = 0
        with torch.no_grad():
            while True:
                ret, frame_2 = cap.read()
                if not ret:
                    # End of the video.
                    break
                frame_2 = frame_preprocess(frame_2, device)
                # Only the upsampled flow is used; the low-resolution one is
                # discarded.
                _, flow_up = model(frame_1,
                                   frame_2,
                                   iters=20,
                                   test_mode=True)
                # vizualize_flow's return value decides whether to continue.
                keep_going = vizualize_flow(frame_1, flow_up, save, counter)
                if not keep_going:
                    break
                frame_1 = frame_2
                counter += 1
    finally:
        # Release the capture handle on every exit path.
        cap.release()
示例#5
0
文件: train.py 项目: ywu40/RAFT
def train(args):
    """Train RAFT for ``args.num_steps`` optimizer steps.

    An intermediate checkpoint ``checkpoints/<step>_<name>.pth`` is written
    every VAL_FREQ steps, where ``<step>`` is the number of completed
    optimizer steps.

    Args:
        args: parsed training namespace. Reads ``restore_ckpt``, ``dataset``,
            ``iters``, ``clip``, ``num_steps`` and ``name``; the whole
            namespace is also passed to RAFT, fetch_dataloader and
            fetch_optimizer.

    Returns:
        Path of the final checkpoint, ``checkpoints/<name>.pth``.

    Raises:
        RuntimeError: if the training data loader yields no batches. Without
            this check the outer loop would spin forever.
    """
    import os

    model = RAFT(args)
    model = nn.DataParallel(model)
    print("Parameter Count: %d" % count_parameters(model))

    if args.restore_ckpt is not None:
        model.load_state_dict(torch.load(args.restore_ckpt))

    model.cuda()
    model.train()

    # Batch-norm statistics are frozen except when training on 'chairs'.
    if 'chairs' not in args.dataset:
        model.module.freeze_bn()

    train_loader = fetch_dataloader(args)
    optimizer, scheduler = fetch_optimizer(args, model)

    total_steps = 0
    logger = Logger(model, scheduler)

    # torch.save does not create missing directories; fail fast up front
    # rather than after hours of training.
    os.makedirs('checkpoints', exist_ok=True)

    should_keep_training = True
    while should_keep_training:
        saw_batch = False
        for data_blob in train_loader:
            saw_batch = True
            image1, image2, flow, valid = [x.cuda() for x in data_blob]

            optimizer.zero_grad()
            flow_predictions = model(image1, image2, iters=args.iters)

            loss, metrics = sequence_loss(flow_predictions, flow, valid)
            loss.backward()

            torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
            optimizer.step()
            scheduler.step()
            total_steps += 1

            logger.push(metrics)

            # total_steps has already been incremented, so it equals the
            # number of completed steps. The file name therefore matches the
            # weights it holds; the old "== VAL_FREQ - 1" check with
            # total_steps + 1 in the name was off by one.
            if total_steps % VAL_FREQ == 0:
                PATH = 'checkpoints/%d_%s.pth' % (total_steps, args.name)
                torch.save(model.state_dict(), PATH)

            # >= also terminates when num_steps is not positive.
            if total_steps >= args.num_steps:
                should_keep_training = False
                break

        if not saw_batch:
            raise RuntimeError("training data loader yielded no batches")

    PATH = 'checkpoints/%s.pth' % args.name
    torch.save(model.state_dict(), PATH)

    return PATH
示例#6
0
文件: evaluate.py 项目: ywu40/RAFT
            out = ((epe > 3.0) & ((epe / mag) > 0.05)).float()
            epe_list.append(epe[val].mean().item())
            out_list.append(out[val].cpu().numpy())

    epe_list = np.array(epe_list)
    out_list = np.concatenate(out_list)

    print("Validation KITTI: %f, %f" %
          (np.mean(epe_list), 100 * np.mean(out_list)))


if __name__ == '__main__':
    # Evaluate a trained RAFT checkpoint on Sintel and KITTI validation data.
    parser = argparse.ArgumentParser()
    # required=True: without a checkpoint, torch.load(None) would fail later
    # with an obscure error instead of a clear usage message.
    parser.add_argument('--model', required=True, help="restore checkpoint")
    parser.add_argument('--small', action='store_true', help='use small model')
    # Per-dataset iteration counts passed to the validation routines.
    parser.add_argument('--sintel_iters', type=int, default=50)
    parser.add_argument('--kitti_iters', type=int, default=32)

    args = parser.parse_args()

    model = RAFT(args)
    model = torch.nn.DataParallel(model)
    model.load_state_dict(torch.load(args.model))

    # Evaluation runs on CUDA only.
    model.to('cuda')
    model.eval()

    validate_sintel(args, model, args.sintel_iters)
    validate_kitti(args, model, args.kitti_iters)