Example #1
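The snippet is shown without its imports. A plausible minimal preamble, assuming only the standard packages visible below (the remaining helpers, e.g. EqualDistanceSamplerSQ, VoxelizerFactory, and the add_*_parameters functions, come from the surrounding repository):

import argparse
import os

import numpy as np
import torch
from torch.utils.data import DataLoader
# from mayavi import mlab  # needed only if the commented visualization code is re-enabled
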
def main(argv):
    parser = argparse.ArgumentParser(
        description="Do the forward pass and estimate a set of primitives"
    )
    parser.add_argument(
        "dataset_directory",
        help="Path to the directory containing the dataset"
    )
    parser.add_argument(
        "output_directory",
        help="Save the output files in that directory"
    )
    parser.add_argument(
        "--tsdf_directory",
        default="",
        help="Path to the directory containing the precomputed tsdf files"
    )
    parser.add_argument(
        "--weight_file",
        default=None,
        help="The path to the previously trainined model to be used"
    )

    parser.add_argument(
        "--n_primitives",
        type=int,
        default=32,
        help="Number of primitives"
    )
    parser.add_argument(
        "--prob_threshold",
        type=float,
        default=0.5,
        help="Probability threshold"
    )
    parser.add_argument(
        "--use_deformations",
        action="store_true",
        help="Use Superquadrics with deformations as the shape configuration"
    )
    parser.add_argument(
        "--save_prediction_as_mesh",
        action="store_true",
        help="When true store prediction as a mesh"
    )
    parser.add_argument(
        "--run_on_gpu",
        action="store_true",
        help="Use GPU"
    )
    parser.add_argument(
        "--with_animation",
        action="store_true",
        help="Add animation"
    )

    add_dataset_parameters(parser)
    add_nn_parameters(parser)
    add_voxelizer_parameters(parser)
    add_gaussian_noise_layer_parameters(parser)
    add_loss_parameters(parser)
    add_loss_options_parameters(parser)
    args = parser.parse_args(argv)

    # A sampler instance (constructed here but not referenced again in this snippet)
    e = EqualDistanceSamplerSQ(200)

    # Check if output directory exists and if it doesn't create it
    if not os.path.exists(args.output_directory):
        os.makedirs(args.output_directory)

    if args.run_on_gpu and torch.cuda.is_available():
        device = torch.device("cuda:0")
    else:
        device = torch.device("cpu")
    print("Running code on {}".format(device))

    # Create a factory that returns the appropriate voxelizer based on the
    # input argument
    voxelizer_factory = VoxelizerFactory(
        args.voxelizer_factory,
        np.array(voxelizer_shape(args)),
        args.save_voxels_to
    )

    # Create a dataset instance to generate the samples for training
    dataset = get_dataset_type("euclidean_dual_loss")(
        (DatasetBuilder()
            .with_dataset(args.dataset_type)
            .filter_tags(args.model_tags)
            .build(args.dataset_directory)),
        voxelizer_factory,
        args.n_points_from_mesh,
        transform=compose_transformations(voxelizer_factory)
    )

    model_tags = dataset._dataset_object._tags

    # TODO: Change batch_size in dataloader
    dataloader = DataLoader(dataset, batch_size=1, num_workers=4)

    network_params = NetworkParameters.from_options(args)
    # Build the model to be used for testing
    model = network_params.network(network_params)
    # Move model to device to be used
    model.to(device)
    if args.weight_file is not None:
        # Load the model parameters of the previously trained model
        model.load_state_dict(
            torch.load(args.weight_file, map_location=device)
        )
    model.eval()

    colors = get_colors(args.n_primitives)

    for sample_idx, sample in enumerate(dataloader):

        model_tag = model_tags[sample_idx]

        X, y_target = sample
        X, y_target = X.to(device), y_target.to(device)

        # Do the forward pass and estimate the primitive parameters
        y_hat = model(X)

        M = args.n_primitives  # number of primitives (not referenced again in this snippet)
        probs = y_hat[0].to("cpu").detach().numpy()
        # Convert the predicted rotations to rotation matrices (Euler angles
        # if the representation is 3-dimensional, quaternions otherwise)
        if y_hat[2].shape[1] == 3:
            R = euler_angles_to_rotation_matrices(
                y_hat[2].view(-1, 3)
            ).to("cpu").detach()
        else:
            R = quaternions_to_rotation_matrices(
                y_hat[2].view(-1, 4)
            ).to("cpu").detach()
            # Also keep the raw quaternions. NOTE: quats is only defined in
            # this branch; store_primitive_parameters below assumes the
            # quaternion representation and raises a NameError otherwise.
            quats = y_hat[2].view(-1, 4).to("cpu").detach().numpy()
        translations = y_hat[1].to("cpu").view(args.n_primitives, 3)
        translations = translations.detach().numpy()

        shapes = y_hat[3].to("cpu").view(args.n_primitives, 3).detach().numpy()
        epsilons = y_hat[4].to("cpu").view(
            args.n_primitives, 2
        ).detach().numpy()
        taperings = y_hat[5].to("cpu").view(
            args.n_primitives, 2
        ).detach().numpy()

        pts = y_target[:, :, :3].to("cpu")
        pts_labels = y_target[:, :, -1].to("cpu").squeeze().numpy()
        pts = pts.squeeze().detach().numpy().T

        on_prims = 0

        # XXX: UNTIL I FIX THE MLAB ISSUE
        # fig = mlab.figure(size=(400, 400), bgcolor=(1, 1, 1))
        # mlab.view(azimuth=0.0, elevation=0.0, distance=2)

        # Uncomment to visualize the points sampled from the target mesh
        # t = np.array([1.2, 0, 0]).reshape(3, -1)
        # pts_n = pts + t
        # mlab.points3d(
        #     # pts_n[0], pts_n[1], pts_n[2],
        #     pts[0], pts[1], pts[2],
        #     scale_factor=0.03, color=(0.8, 0.8, 0.8)
        # )


        save_dir = os.path.join(args.output_directory,
                                os.path.basename(os.path.dirname(args.dataset_directory)),
                                model_tag)
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)
        # args.output_directory/class/model_id/primitive_%d.p
        # args.output_directory/class/model_id/reconstruction

        # Keep track of the files containing the parameters of each primitive
        primitive_files = []
        for i in range(args.n_primitives):
            x_tr, y_tr, z_tr, prim_pts =\
                get_shape_configuration(args.use_cuboids)(
                    shapes[i, 0],
                    shapes[i, 1],
                    shapes[i, 2],
                    epsilons[i, 0],
                    epsilons[i, 1],
                    R[i].numpy(),
                    translations[i].reshape(-1, 1),
                    taperings[i, 0],
                    taperings[i, 1]
                )

            # Dump the parameters of each primitive as a dictionary
            # TODO: change filepath
            store_primitive_parameters(
                size=tuple(shapes[i]),
                shape=tuple(epsilons[i]),
                rotation=tuple(quats[i]),
                location=tuple(translations[i]),
                tapering=tuple(taperings[i]),
                probability=(probs[0, i],),
                color=(colors[i % len(colors)]) + (1.0,),
                filepath=os.path.join(
                    save_dir,
                    "primitive_%d.p" %(i,)
                )
            )
            if probs[0, i] >= args.prob_threshold:
                on_prims += 1
                # mlab.mesh(
                #     x_tr,
                #     y_tr,
                #     z_tr,
                #     color=tuple(colors[i % len(colors)]),
                #     opacity=1.0
                # )
                primitive_files.append(
                    os.path.join(save_dir, "primitive_%d.p" % (i,))
                )

        if args.with_animation:
            cnt = 0
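            # NOTE: with the mlab calls below commented out, this loop only increments cnt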
            for az in range(0, 360, 1):
                cnt += 1

                # XXX UNTIL I FIX THE MLAB ISSUE
                # mlab.view(azimuth=az, elevation=0.0, distance=2)
                # mlab.savefig(
                #     os.path.join(
                #         args.output_directory,
                #         "img_%04d.png" % (cnt,)
                #     )
                # )
        for i in range(args.n_primitives):
            print("{} {}".format(i, probs[0, i]))

        print("Using %d primitives out of %d" % (on_prims, args.n_primitives))

        # XXX UNTIL I FIX THE MLAB ISSUE
        # mlab.show()

        # TODO: from_primitive_parms_to_mesh()
        # TODO: get parts for this chair.
        # TODO: push the parts and superquadric meshes through the metric function
        # TODO: record metrics

        if args.save_prediction_as_mesh:
            # TODO: save with model information, class information ...etc
            print("Saving prediction as mesh...")
            save_prediction_as_ply(
                primitive_files,
                os.path.join(save_dir, "reconstruction.ply")
            )
            print("Saved prediction as ply file in {}".format(
                os.path.join(save_dir, "reconstruction.ply")
            ))
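
For completeness, a typical command-line entry point for this script (not part of the snippet; the module name is a placeholder) might look like:

if __name__ == "__main__":
    import sys
    # e.g. python forward_pass.py /path/to/dataset /path/to/output \
    #          --weight_file model.pt --run_on_gpu --save_prediction_as_mesh
    main(sys.argv[1:])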
Example #2
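As in Example #1, a plausible import preamble, assuming the standard packages visible below (project helpers such as parse_train_test_splits, update_IOUs, get_consistency, and _from_primitive_parms_to_mesh come from the surrounding repository):

import argparse
import os
import pickle

import numpy as np
import torch
from torch.utils.data import DataLoader
# from mayavi import mlab  # needed only if the commented visualization code is re-enabled
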
def main(argv):
    parser = argparse.ArgumentParser(
        description="Do the forward pass and estimate a set of primitives"
    )
    parser.add_argument(
        "dataset_directory",
        help="Path to the directory containing the dataset"
    )
    parser.add_argument(
        "output_directory",
        help="Save the output files in that directory"
    )
    parser.add_argument(
        "--tsdf_directory",
        default="",
        help="Path to the directory containing the precomputed tsdf files"
    )
    parser.add_argument(
        "--weight_file",
        default=None,
        help="The path to the previously trainined model to be used"
    )

    parser.add_argument(
        "--n_primitives",
        type=int,
        default=32,
        help="Number of primitives"
    )
    parser.add_argument(
        "--prob_threshold",
        type=float,
        default=0.5,
        help="Probability threshold"
    )
    parser.add_argument(
        "--use_deformations",
        action="store_true",
        help="Use Superquadrics with deformations as the shape configuration"
    )
    parser.add_argument(
        "--save_prediction_as_mesh",
        action="store_true",
        help="When true store prediction as a mesh"
    )
    parser.add_argument(
        "--run_on_gpu",
        action="store_true",
        help="Use GPU"
    )
    parser.add_argument(
        "--with_animation",
        action="store_true",
        help="Add animation"
    )

    parser.add_argument(
        "--train_test_splits_file",
        default=None,
        help="Path to the train-test splits file"
    )  # we evaluate on the test split

    parser.add_argument(
        "--save_individual_IOUs",
        action="store_true",
        help="Save the IoU with PartNet parts for every 3D model evaluated"
    )
    parser.add_argument(
        "--recenter_superquadrics",
        action="store_true",
        help="Recenter the superquadrics to the part bounding boxes before evaluating the metric"
    )

    add_dataset_parameters(parser)
    add_nn_parameters(parser)
    add_voxelizer_parameters(parser)
    add_gaussian_noise_layer_parameters(parser)
    add_loss_parameters(parser)
    add_loss_options_parameters(parser)
    args = parser.parse_args(argv)


    if args.train_test_splits_file is not None:
        train_test_splits = parse_train_test_splits(
            args.train_test_splits_file,
            args.model_tags
        )
        test_tags = np.hstack([
            train_test_splits["test"]
        ])
    else:
        test_tags = args.model_tags


    # A sampler instance (constructed here but not referenced again in this snippet)
    e = EqualDistanceSamplerSQ(200)

    # Check if output directory exists and if it doesn't create it
    if not os.path.exists(args.output_directory):
        os.makedirs(args.output_directory)

    if args.run_on_gpu and torch.cuda.is_available():
        device = torch.device("cuda:0")
    else:
        device = torch.device("cpu")
    print("Running code on {}".format(device))

    # Create a factory that returns the appropriate voxelizer based on the
    # input argument
    voxelizer_factory = VoxelizerFactory(
        args.voxelizer_factory,
        np.array(voxelizer_shape(args)),
        args.save_voxels_to
    )

    # Create a dataset instance to generate the samples for training
    dataset = get_dataset_type("euclidean_dual_loss")(
        (DatasetBuilder()
            .with_dataset(args.dataset_type)
            .filter_tags(test_tags)
            .build(args.dataset_directory)),
        voxelizer_factory,
        args.n_points_from_mesh,
        transform=compose_transformations(voxelizer_factory)
    )

    model_tags = dataset._dataset_object._tags

    # TODO: Change batch_size in dataloader
    dataloader = DataLoader(dataset, batch_size=1, num_workers=4)

    network_params = NetworkParameters.from_options(args)
    # Build the model to be used for testing
    model = network_params.network(network_params)
    # Move model to device to be used
    model.to(device)
    if args.weight_file is not None:
        # Load the model parameters of the previously trained model
        model.load_state_dict(
            torch.load(args.weight_file, map_location=device)
        )
    model.eval()

    colors = get_colors(args.n_primitives)


    # Collect the full set of part ids and labels across the dataset
    id_set = set([])
    label_set = set([])
    parts_dir = '../parts_output/{}'.format(os.path.basename(os.path.dirname(args.dataset_directory)))
    for pick in os.listdir(parts_dir):
        with open(os.path.join(parts_dir, pick), 'rb') as f:
            leaves = pickle.load(f)
        for leaf in leaves:
            id_set.add(leaf['id'])
            label_set.add(leaf['label'])


    print('Found a total of {} part ids'.format(len(id_set)))
    part_id_list = sorted(list(id_set))
    part_id_to_IOUidx = {id_: idx for idx, id_ in enumerate(part_id_list)}
    part_IOUidx_to_id = {idx: id_ for idx, id_ in enumerate(part_id_list)}

    IOU_matrix = np.zeros((args.n_primitives, len(id_set)))
    IOU_counter = 0

    for sample_idx, sample in enumerate(dataloader):

        model_tag = model_tags[sample_idx]
        print('evaluating model_tag {}'.format(model_tag))

        X, y_target = sample[1]

        X, y_target = X.to(device), y_target.to(device)

        # Do the forward pass and estimate the primitive parameters
        y_hat = model(X)

        M = args.n_primitives  # number of primitives (not referenced again in this snippet)
        probs = y_hat[0].to("cpu").detach().numpy()
        # Convert the predicted rotations to rotation matrices (Euler angles
        # if the representation is 3-dimensional, quaternions otherwise)
        if y_hat[2].shape[1] == 3:
            R = euler_angles_to_rotation_matrices(
                y_hat[2].view(-1, 3)
            ).to("cpu").detach()
        else:
            R = quaternions_to_rotation_matrices(
                y_hat[2].view(-1, 4)
            ).to("cpu").detach()
            # Also keep the raw quaternions. NOTE: quats is only defined in
            # this branch; store_primitive_parameters below assumes the
            # quaternion representation and raises a NameError otherwise.
            quats = y_hat[2].view(-1, 4).to("cpu").detach().numpy()
        translations = y_hat[1].to("cpu").view(args.n_primitives, 3)
        translations = translations.detach().numpy()

        shapes = y_hat[3].to("cpu").view(args.n_primitives, 3).detach().numpy()
        epsilons = y_hat[4].to("cpu").view(
            args.n_primitives, 2
        ).detach().numpy()
        taperings = y_hat[5].to("cpu").view(
            args.n_primitives, 2
        ).detach().numpy()

        pts = y_target[:, :, :3].to("cpu")
        pts_labels = y_target[:, :, -1].to("cpu").squeeze().numpy()
        pts = pts.squeeze().detach().numpy().T

        on_prims = 0

        # XXX: UNTIL I FIX THE MLAB ISSUE
        # fig = mlab.figure(size=(400, 400), bgcolor=(1, 1, 1))
        # mlab.view(azimuth=0.0, elevation=0.0, distance=2)

        # Uncomment to visualize the points sampled from the target mesh
        # t = np.array([1.2, 0, 0]).reshape(3, -1)
        # pts_n = pts + t
        # mlab.points3d(
        #     # pts_n[0], pts_n[1], pts_n[2],
        #     pts[0], pts[1], pts[2],
        #     scale_factor=0.03, color=(0.8, 0.8, 0.8)
        # )


        save_dir = os.path.join(args.output_directory,
                                os.path.basename(os.path.dirname(args.dataset_directory)),
                                model_tag)
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)
        # args.output_directory/class/model_id/primitive_%d.p
        # args.output_directory/class/model_id/reconstruction

        # Keep track of the files containing the parameters of each primitive
        primitive_files = []
        primitive_indices = []
        for i in range(args.n_primitives):
            x_tr, y_tr, z_tr, prim_pts =\
                get_shape_configuration(args.use_cuboids)(
                    shapes[i, 0],
                    shapes[i, 1],
                    shapes[i, 2],
                    epsilons[i, 0],
                    epsilons[i, 1],
                    R[i].numpy(),
                    translations[i].reshape(-1, 1),
                    taperings[i, 0],
                    taperings[i, 1]
                )

            # Dump the parameters of each primitive as a dictionary
            # TODO: change filepath
            store_primitive_parameters(
                size=tuple(shapes[i]),
                shape=tuple(epsilons[i]),
                rotation=tuple(quats[i]),
                location=tuple(translations[i]),
                tapering=tuple(taperings[i]),
                probability=(probs[0, i],),
                color=(colors[i % len(colors)]) + (1.0,),
                filepath=os.path.join(
                    save_dir,
                    "primitive_%d.p" %(i,)
                )
            )
            if probs[0, i] >= args.prob_threshold:
                on_prims += 1
                # mlab.mesh(
                #     x_tr,
                #     y_tr,
                #     z_tr,
                #     color=tuple(colors[i % len(colors)]),
                #     opacity=1.0
                # )
                primitive_files.append(
                    os.path.join(save_dir, "primitive_%d.p" % (i,))
                )
                primitive_indices.append(i)

        if args.with_animation:
            cnt = 0
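            # NOTE: with the mlab calls below commented out, this loop only increments cnt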
            for az in range(0, 360, 1):
                cnt += 1

                # XXX UNTIL I FIX THE MLAB ISSUE
                # mlab.view(azimuth=az, elevation=0.0, distance=2)
                # mlab.savefig(
                #     os.path.join(
                #         args.output_directory,
                #         "img_%04d.png" % (cnt,)
                #     )
                # )
        # for i in range(args.n_primitives):
        #     print("{} {}".format(i, probs[0, i]))

        print("Using %d primitives out of %d" % (on_prims, args.n_primitives))

        # XXX UNTIL I FIX THE MLAB ISSUE
        # mlab.show()

        # Build the meshes of the predicted superquadrics
        superquadric_id_meshes = []
        for i, p in zip(primitive_indices, primitive_files):
            with open(p, "rb") as f:
                prim_params = pickle.load(f)
            _m = _from_primitive_parms_to_mesh(prim_params)
            superquadric_id_meshes.append((i, _m))

        # get the parts
        with open(os.path.join(parts_dir, '{}.pkl'.format(model_tag)), 'rb') as f:
            leaf_parts = pickle.load(f)
        # get the bounding box information
        part_id_meshes = [(leaf['id'], leaf['bbox_mesh']) for leaf in leaf_parts]
        part_id_bboxes = [(leaf['id'], leaf['bbox']) for leaf in leaf_parts]

        if args.recenter_superquadrics: # recenter superquadric predictions
            # moving the superquadrics to the part centers
            supe_vert = np.vstack([el[1].vertices for el in superquadric_id_meshes])
            assert supe_vert.shape[1] == 3
            supe_center = 0.5*(np.max(supe_vert, axis=0) + np.min(supe_vert, axis=0))

            part_vert = np.vstack([el[1].vertices for el in part_id_meshes])
            assert part_vert.shape[1] == 3
            part_center = 0.5*(np.max(part_vert, axis=0) + np.min(part_vert, axis=0))

            for el in superquadric_id_meshes:
                el[1].vertices = el[1].vertices - supe_center + part_center

        # Push the parts and superquadric meshes through the metric function
        IOU_matrix, delta_IOU = update_IOUs(IOU_matrix,
                                            superquadric_id_meshes,
                                            part_id_bboxes,
                                            part_id_meshes,
                                            part_id_to_IOUidx)
        IOU_counter += 1
        # Running mean over the samples processed so far
        mean_IOU_matrix = IOU_matrix / IOU_counter

        if IOU_counter % 5 == 0:
            bestIOU, supeidx2partidx = get_consistency(mean_IOU_matrix)
            supeid2partid = {supe_id: part_IOUidx_to_id[part_id]
                             for supe_id, part_id in supeidx2partidx}
            print('##################################')
            print('Best superquadric -> part matching')
            print(supeid2partid)
            print('Best IOU')
            print(bestIOU)
            print('##################################')

        if args.save_prediction_as_mesh:
            # TODO: save with model information, class information ...etc
            print("Saving prediction as mesh...")
            save_prediction_as_ply(
                primitive_files,
                os.path.join(save_dir, "reconstruction.ply")
            )
            print("Saved prediction as ply file in {}".format(
                os.path.join(save_dir, "reconstruction.ply")
            ))

        if args.save_individual_IOUs:
            save_dict = {'shapenetid': model_tag,
                         'indiv_IOU': delta_IOU,
                         'id2labels': {leaf['id']: leaf['label']
                                       for leaf in leaf_parts},
                         'part_id2idx': part_id_to_IOUidx,
                         'part_bboxes': part_id_bboxes,
                         'part_mesh': part_id_meshes,
                         'superquadrics': superquadric_id_meshes}
            with open(os.path.join(save_dir, "part_IOU.pkl"), 'wb') as f:
                pickle.dump(save_dict, f)


    # Record the final metrics:
    # 1) the matrix of IoUs
    # 2) the assignments of superquadric ids to part ids
    mean_IOU_matrix = IOU_matrix / IOU_counter
    bestIOU, supeidx2partidx = get_consistency(mean_IOU_matrix)
    supeid2partid = {supe_id: part_IOUidx_to_id[part_id]
                     for supe_id, part_id in supeidx2partidx}
    print('##################################')
    print('Best superquadric -> part matching')
    print(supeid2partid)
    print('Best IOU')
    print(bestIOU)
    print('##################################')
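
A hypothetical invocation of this evaluation variant (module name and file paths are placeholders):

if __name__ == "__main__":
    import sys
    # e.g. python evaluate_part_iou.py /path/to/dataset /path/to/output \
    #          --weight_file model.pt --train_test_splits_file splits.csv \
    #          --save_individual_IOUs --recenter_superquadrics
    main(sys.argv[1:])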