# Example no. 1
def test_start_end(args):
    """Visual sanity check of start/goal keypoints against projected
    gripper positions for a saved trajectory."""
    archive = np.load(os.path.join(args.data_dir, args.save_path + ".npz"),
                      allow_pickle=True)
    # Camera projection matrix used to map gripper positions to pixels.
    proj_matrix = np.load('tmp_data/proj_sawyer.npy')

    frames = convert_img_torch(archive['image'])
    grip_positions = archive['grip_pos']

    first_frame, start_pos = get_start_frame(True)
    start_img = convert_img_torch(first_frame[None, Ellipsis])[0]
    goal_img = frames[-1]

    start_pos = convert_to_pixel(start_pos, proj_matrix)
    goal_pos = convert_to_pixel(grip_positions[-1], proj_matrix)

    model = load_model(args)
    with torch.no_grad():
        # img_to_keyp takes a (batch, time, ...) input; strip the singleton
        # batch/time dims from the result.  Each result: num_keyp x 3.
        start_keyp = model.img_to_keyp(start_img[None, None, Ellipsis])[0, 0]
        goal_keyp = model.img_to_keyp(goal_img[None, None, Ellipsis])[0, 0]

        start_triplet = (utils.img_torch_to_numpy(start_img),
                         start_keyp.cpu().numpy(), start_pos)
        goal_triplet = (utils.img_torch_to_numpy(goal_img),
                        goal_keyp.cpu().numpy(), goal_pos)
        check_start_goal(start_triplet, goal_triplet)
# Example no. 2
def test_start_end(args):
    """Visual sanity check: detect keypoints on the first and last frame
    of a saved trajectory and plot them side by side.

    Loads ``<data_dir>/<save_path>.npz``, runs the keypoint model on the
    first and last images, and hands both (image, keypoints) pairs to
    ``check_start_goal``.
    """
    data = np.load(os.path.join(args.data_dir, args.save_path + ".npz"),
                   allow_pickle=True)

    # NOTE(review): the original also loaded data['action'] into an unused
    # local and carried commented-out keypoint-index lists; both removed.
    img_seq = convert_img_torch(data['image'])
    print(img_seq.shape)
    start_img = img_seq[0]
    goal_img = img_seq[-1]

    model = load_model(args)
    with torch.no_grad():
        # img_to_keyp takes a (batch, time, ...) input; strip the singleton
        # batch/time dims from the result.  Each result: num_keyp x 3.
        start_keyp = model.img_to_keyp(start_img[None, None,
                                                 Ellipsis])[0, 0]
        goal_keyp = model.img_to_keyp(goal_img[None, None, Ellipsis])[0, 0]

        start_img_np = utils.img_torch_to_numpy(start_img)
        goal_img_np = utils.img_torch_to_numpy(goal_img)
        check_start_goal((start_img_np, start_keyp.cpu().numpy()),
                         (goal_img_np, goal_keyp.cpu().numpy()))
# Example no. 3
def viz_seq_unroll(args):
    """Visualize unrolled predictions of the dynamics keypoint model.

    Loads one batch from the train or test split, unrolls the model for
    ``unroll_T`` steps, prints the reconstruction MSE over the observed
    window, and writes up to 3 comparison videos to ``args.vids_dir``.
    """
    # Fix seeds for reproducibility of any stochastic model components.
    torch.random.manual_seed(0)
    np.random.seed(0)

    cfg = hyperparameters.get_config(args)
    unroll_T = 16  # number of unroll steps requested from the model

    args.cuda = not args.no_cuda and torch.cuda.is_available()
    device = torch.device("cuda" if args.cuda else "cpu")

    # Visualize either the train or the test split, per args.is_train.
    l_dir = cfg.train_dir if args.is_train else args.test_dir
    print("Data loader: ", l_dir)
    loader, data_shapes = datasets.get_sequence_dataset(
        data_dir=os.path.join(cfg.data_dir, l_dir),
        batch_size=cfg.batch_size,
        num_timesteps=cfg.observed_steps + cfg.predicted_steps,
        shuffle=False)

    cfg.data_shapes = data_shapes
    model = train_dynamics.KeypointModel(cfg).to(device)

    if args.pretrained_path:
        checkpoint_path = get_latest_checkpoint(args.pretrained_path)
        # map_location keeps CPU-only machines able to load GPU checkpoints.
        checkpoint = torch.load(checkpoint_path,
                                map_location=lambda storage, loc: storage)
        print("Loading model from: ", checkpoint_path)
        model.load_state_dict(checkpoint['state_dict'])
        model.eval()

    with torch.no_grad():
        # Only the first batch is visualized.
        for data in islice(loader, 1):
            img_seq = data['image'].to(device)
            pred_img_seq, pred_keyp_seq = model.unroll(img_seq, unroll_T)

            bs, T = img_seq.shape[0], img_seq.shape[1]
            # Per-frame MSE over the observed window; pred_img_seq is sliced
            # to the first T steps since the unroll may extend past them.
            print(
                "LOSS:",
                F.mse_loss(img_seq, pred_img_seq[:, :T], reduction='sum') /
                (bs * T))

            print(img_seq.shape, pred_keyp_seq.shape, pred_img_seq.shape)

            imgs_seq_np, pred_img_seq_np = img_torch_to_numpy(
                img_seq), img_torch_to_numpy(pred_img_seq)
            keypoints_seq_np = pred_keyp_seq.cpu().numpy()

            # Save at most 3 per-sequence comparison videos.
            num_seq = imgs_seq_np.shape[0]
            for i in islice(range(num_seq), 3):
                save_path = os.path.join(
                    args.vids_dir,
                    args.vids_path + "_" + l_dir + "_{}.mp4".format(i))
                print(i, "Video PRED Save Path", save_path)
                viz_all_unroll(imgs_seq_np[i], pred_img_seq_np[i],
                               keypoints_seq_np[i], True, 100, save_path)
# Example no. 4
def main(args):
    """Round-trip a few batches through the image->keypoint->image models
    and save one reconstruction video per batch."""
    cfg = hyperparameters.get_config(args)

    args.cuda = not args.no_cuda and torch.cuda.is_available()
    device = torch.device("cuda" if args.cuda else "cpu")

    # Visualize either the train or the test split, per args.is_train.
    split_dir = cfg.train_dir if args.is_train else args.test_dir
    print("Data loader: ", split_dir)
    loader, data_shapes = datasets.get_dataset(
        data_dir=os.path.join(cfg.data_dir, split_dir),
        batch_size=cfg.batch_size)

    models = build_model_noseq(cfg, data_shapes).to(device)

    if args.pretrained_path:
        model_path = get_latest_checkpoint(args.pretrained_path, "*.pth")
        print("Loading model from: ", model_path)
        models.load_state_dict(torch.load(model_path))
        models.eval()

    # The container holds the encoder and decoder as a pair.
    images_to_keypoints_net, keypoints_to_images_net = models

    with torch.no_grad():
        # Only the first 5 batches are rendered.
        for batch_idx, batch in islice(enumerate(loader), 5):
            images = batch['image'].to(device)

            keypoints, _ = images_to_keypoints_net(images)
            reconstructed = keypoints_to_images_net(keypoints)
            print(images.shape, keypoints.shape, reconstructed.shape)

            save_path = os.path.join(
                args.vids_dir,
                args.vids_path + "_" + split_dir +
                "_{}.mp4".format(batch_idx))
            print(batch_idx, "Video Save Path", save_path)

            images_np = img_torch_to_numpy(images)
            recon_np = img_torch_to_numpy(reconstructed)
            keypoints_np = keypoints.cpu().numpy()

            viz_all(images_np, recon_np, keypoints_np, True, 300, save_path)
# Example no. 5
def check_recon(args):
    """Reconstruct a saved trajectory with the keypoint model, dump the
    intermediate arrays to ``tmp_data`` and render a tracking video.

    Loads ``<data_dir>/<save_path>.npz`` (must contain 'image' and
    'action'), runs the model, prints the per-frame reconstruction MSE,
    and writes the video to ``vids/check.mp4``.
    """
    data = np.load(os.path.join(args.data_dir, args.save_path + ".npz"),
                   allow_pickle=True)

    # Add a leading batch dimension of 1 for the model.
    img_seq = convert_img_torch(data['image']).unsqueeze(0)
    action_seq = torch.from_numpy(
        data['action'].astype(np.float32)).unsqueeze(0)

    model = load_model(args)

    with torch.no_grad():
        keypoints_seq, heatmaps_seq, recon_img_seq, pred_img_seq, pred_keyp_seq = \
            model(img_seq, action_seq)

        # Sum-MSE normalized by batch * timesteps -> per-frame loss.
        print(
            "LOSS:",
            F.mse_loss(img_seq, recon_img_seq, reduction='sum') /
            (img_seq.shape[0] * img_seq.shape[1]))
        img_seq_np, recon_img_seq_np = utils.img_torch_to_numpy(
            img_seq), utils.img_torch_to_numpy(recon_img_seq)
        keypoints_seq_np = keypoints_seq.cpu().numpy()

        d = {
            'img': img_seq_np,
            'pred_img': recon_img_seq_np,
            'keyp': keypoints_seq_np,
            # permute(0, 1, 3, 4, 2) moves the heatmap channels last.
            'heatmap': heatmaps_seq.permute(0, 1, 3, 4, 2).cpu().numpy(),
            # data['action'] is read unconditionally above, so it is always
            # present here; the old `if 'action' in data else None` guard
            # was dead code.
            'action': action_seq.cpu().numpy()
        }

        tmp_save_path = 'tmp_data/{}_GOAL_data_{}'.format(
            "test", args.save_path)
        print("Save intermediate data path: ", tmp_save_path)
        np.savez(tmp_save_path, **d)

        save_path = 'vids/check.mp4'
        viz_track(img_seq_np[0], recon_img_seq_np[0], keypoints_seq_np[0],
                  True, 100, save_path)
# Example no. 6
    def save_sample_keyp(self, img_seq, keyp_seq, file_id_seq, frame_id_seq,
                         step_num, save_dir):
        """Append this step's keypoints to each sample's on-disk history
        and re-render one overlay image per keypoint.

        For every sample n, a directory ``file_<file_id>_frame_<frame_id>``
        is created under ``save_dir`` holding a growing
        ``keyps_history.npy`` array and one ``keyp_<k>.png`` per keypoint.

        :param img_seq: N x 3 x H x W
        :param keyp_seq: N x num_keyp x 3
        :param file_id_seq: per-sample file identifiers, length N
        :param frame_id_seq: per-sample frame identifiers, length N
        :param step_num: int
        :param save_dir: root directory for the per-sample folders
        """
        if not os.path.isdir(save_dir):
            os.makedirs(save_dir)

        img_seq_np = utils.img_torch_to_numpy(img_seq)
        keyp_seq_np = keyp_seq.cpu().numpy()

        # Template for each sample's output directory name.
        file_dir = "file_{}_frame_{}"

        N, num_keyp = keyp_seq_np.shape[:2]

        for n in range(N):
            file_id = file_id_seq[n]
            frame_id = frame_id_seq[n]
            img = img_seq_np[n]
            keyps = keyp_seq_np[n]

            save_file_dir = os.path.join(save_dir,
                                         file_dir.format(file_id, frame_id))

            if not os.path.isdir(save_file_dir):
                os.makedirs(save_file_dir)

            # Grow the per-sample keypoint history across calls: the first
            # call creates a 1 x num_keyp x 3 array, later calls append
            # along axis 0.
            keyps_history_path = os.path.join(save_file_dir,
                                              "keyps_history.npy")
            if not os.path.isfile(keyps_history_path):
                keyps_history = keyps[np.newaxis, :, :]
            else:
                prev_keyps_history = np.load(keyps_history_path)
                keyps_history = np.concatenate(
                    (prev_keyps_history, keyps[np.newaxis, :, :]))

            # Render each keypoint's full trajectory over the current image.
            for k in range(num_keyp):
                save_path = os.path.join(save_file_dir,
                                         'keyp_{}.png'.format(k))
                keyp_history = keyps_history[:, k]
                save_img_keyp(img, keyp_history, save_path, k, step_num)

            np.save(keyps_history_path, keyps_history)

        self.log_steps += 1
# Example no. 7
def viz_seq(args):
    """Run the inverse/forward keypoint model on one batch, dump the
    intermediate arrays to ``tmp_data`` and save annotated videos.

    Loads one shuffled batch from the chosen split, prints the keypoint
    prediction loss, saves an .npz of images/keypoints/heatmaps/actions,
    and writes up to 3 keypoint-overlay videos to ``args.vids_dir``.
    """
    utils.set_seed_everywhere(args.seed)
    cfg = hyperparameters.get_config(args)

    args.cuda = not args.no_cuda and torch.cuda.is_available()
    device = torch.device("cuda" if args.cuda else "cpu")

    # Visualize either the train or the test split, per args.is_train.
    l_dir = cfg.train_dir if args.is_train else args.test_dir
    print("Data loader: ", l_dir)
    loader, data_shapes = datasets.get_sequence_dataset(
        data_dir=os.path.join(cfg.data_dir, l_dir),
        batch_size=10,
        num_timesteps=args.timesteps,
        shuffle=True)

    cfg.data_shapes = data_shapes

    model = train_keyp_inverse_forward.KeypointModel(cfg).to(device)

    if args.pretrained_path:
        # Load a specific checkpoint epoch when given, else the latest.
        if args.ckpt:
            checkpoint_path = os.path.join(
                args.pretrained_path, "_ckpt_epoch_" + args.ckpt + ".ckpt")
        else:
            print("Loading latest")
            checkpoint_path = get_latest_checkpoint(args.pretrained_path)
        # map_location keeps CPU-only machines able to load GPU checkpoints.
        checkpoint = torch.load(checkpoint_path,
                                map_location=lambda storage, loc: storage)
        print("Loading model from: ", checkpoint_path)
        model.load_state_dict(checkpoint['state_dict'])
        model.eval()
        print("Load complete")

    with torch.no_grad():
        # Only the first batch is visualized.
        for data in islice(loader, 1):
            img_seq = data['image'].to(device)
            action_seq = data['action'].to(device)

            keypoints_seq, heatmaps_seq, pred_keyp_seq, pred_action_seq = model(
                img_seq, action_seq)
            # Predicted keypoints at step t are compared to observed
            # keypoints at step t+1 — hence the [:, 1:] target slice.
            print(
                "Keypoint Pred LOSS:",
                F.mse_loss(pred_keyp_seq[Ellipsis, :2],
                           keypoints_seq[:, 1:, :, :2],
                           reduction='sum') /
                (pred_keyp_seq.shape[0] * pred_keyp_seq.shape[1]))
            if args.unroll:
                pred_keyp_seq = model.unroll(img_seq, action_seq)

            print(img_seq.shape, keypoints_seq.shape)

            img_seq_np = img_torch_to_numpy(img_seq)
            # Channels-last heatmaps, computed once and reused in the dump
            # below (the original recomputed the permute and left this
            # variable — and pred_keyp_seq_np — unused).
            heatmaps_seq_np = heatmaps_seq.permute(0, 1, 3, 4, 2).cpu().numpy()
            keypoints_seq_np = keypoints_seq.cpu().numpy()

            d = {
                'img': img_seq_np,
                'keyp': keypoints_seq_np,
                'heatmap': heatmaps_seq_np,
                'action':
                data['action'].cpu().numpy() if 'action' in data else None
            }

            tmp_save_path = 'tmp_data/{}_data_{}_seed_{}'.format(
                l_dir, args.vids_path, args.seed)
            print("Save intermediate data path: ", tmp_save_path)
            np.savez(tmp_save_path, **d)

            # Save at most 3 per-sequence keypoint videos.
            num_seq = img_seq_np.shape[0]
            for i in islice(range(num_seq), 3):
                save_path = os.path.join(
                    args.vids_dir, args.vids_path + "_" + l_dir +
                    "_{}_seed_{}.mp4".format(i, args.seed))
                print(i, "Video Save Path", save_path)
                viz_keypoints(img_seq_np[i], keypoints_seq_np[i], True, 100,
                              save_path, args.annotate)