コード例 #1
0
def main(argv):
    """Train a network to predict primitives using part-level supervision.

    Parses ``argv``, builds the training dataset, the batch provider, the
    network and the optimizer, and then runs ``args.epochs`` epochs of
    ``args.steps_per_epoch`` batches each. For every batch, the part point
    samples stored in ``RANDINDEX_SAMPLES_DIR/<tag>.pkl`` are loaded for the
    samples that have such a file and are passed, together with a sparse
    neighbour-pair matrix per sample, to ``train_on_batch_w_parts``.

    Running statistics are appended to
    ``<output_directory>/<experiment_tag>/train.txt`` after every batch and
    the model weights are saved as
    ``model_<epoch + continue_from_epoch>`` after every epoch.

    Args:
        argv: command-line arguments to parse (without the program name).
    """
    parser = argparse.ArgumentParser(
        description="Train a network to predict primitives")
    parser.add_argument("dataset_directory",
                        help="Path to the directory containing the dataset")
    parser.add_argument("output_directory",
                        help="Save the output files in that directory")

    parser.add_argument(
        "--tsdf_directory",
        default="",
        help="Path to the directory containing the precomputed tsdf files")
    parser.add_argument(
        "--weight_file",
        default=None,
        help=("The path to a previously trainined model to continue"
              " the training from"))
    parser.add_argument("--continue_from_epoch",
                        default=0,
                        type=int,
                        help="Continue training from epoch (default=0)")
    parser.add_argument("--n_primitives",
                        type=int,
                        default=32,
                        help="Number of primitives")
    parser.add_argument(
        "--use_deformations",
        action="store_true",
        help="Use Superquadrics with deformations as the shape configuration")
    parser.add_argument("--train_test_splits_file",
                        default=None,
                        help="Path to the train-test splits file")
    parser.add_argument("--run_on_gpu", action="store_true", help="Use GPU")
    parser.add_argument("--probs_only",
                        action="store_true",
                        help="Optimize only using the probabilities")

    parser.add_argument("--experiment_tag",
                        default=None,
                        help="Tag that refers to the current experiment")

    parser.add_argument("--cache_size",
                        type=int,
                        default=2000,
                        help="The batch provider cache size")

    parser.add_argument("--seed",
                        type=int,
                        default=27,
                        help="Seed for the PRNG")

    add_nn_parameters(parser)
    add_dataset_parameters(parser)
    add_voxelizer_parameters(parser)
    add_training_parameters(parser)
    add_sq_mesh_sampler_parameters(parser)
    add_regularizer_parameters(parser)
    add_gaussian_noise_layer_parameters(parser)
    # Parameters related to the loss function and the loss weights
    add_loss_parameters(parser)
    # Parameters related to loss options
    add_loss_options_parameters(parser)
    args = parser.parse_args(argv)

    # With a splits file, train on the union of its "train" and "val" tags;
    # otherwise train on args.model_tags directly.
    if args.train_test_splits_file is not None:
        train_test_splits = parse_train_test_splits(
            args.train_test_splits_file, args.model_tags)
        training_tags = np.hstack(
            [train_test_splits["train"], train_test_splits["val"]])
    else:
        training_tags = args.model_tags

    # NOTE(review): the torch.cuda.is_available() check is commented out, so
    # --run_on_gpu does not fall back to the CPU when CUDA is missing.
    if args.run_on_gpu:  #and torch.cuda.is_available():
        device = torch.device("cuda:0")
    else:
        device = torch.device("cpu")

    print("Running code on {}".format(device))

    # Check if output directory exists and if it doesn't create it
    if not os.path.exists(args.output_directory):
        os.makedirs(args.output_directory)

    # Create an experiment directory using the experiment_tag (a random
    # 9-character id when no tag is given)
    if args.experiment_tag is None:
        experiment_tag = id_generator(9)
    else:
        experiment_tag = args.experiment_tag

    experiment_directory = os.path.join(args.output_directory, experiment_tag)
    if not os.path.exists(experiment_directory):
        os.makedirs(experiment_directory)

    # Store the parameters for the current experiment in a json file
    save_experiment_params(args, experiment_tag, experiment_directory)
    print("Save experiment statistics in %s" % (experiment_tag, ))

    # File that stores the training evolution: a fresh file for a new run,
    # appended to when continuing from --weight_file.
    train_stats = os.path.join(experiment_directory, "train.txt")
    train_stats_f = open(train_stats,
                         "w" if args.weight_file is None else "a+")
    train_stats_f.write(
        ("epoch loss pcl_to_prim_loss prim_to_pcl_loss bernoulli_regularizer "
         "entropy_bernoulli_regularizer parsimony_regularizer "
         "overlapping_regularizer sparsity_regularizer\n"))

    # Set the random seed; the torch seeds are drawn from the numpy PRNG so
    # that a single --seed controls all of them.
    np.random.seed(args.seed)
    torch.manual_seed(np.random.randint(np.iinfo(np.int32).max))
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(np.random.randint(np.iinfo(np.int32).max))

    # Create an object that will sample points in equal distances on the
    # surface of the primitive
    sampler = get_sampler(args.use_cuboids, args.n_points_from_sq_mesh,
                          args.D_eta, args.D_omega)

    # Create a factory that returns the appropriate voxelizer based on the
    # input argument
    voxelizer_factory = VoxelizerFactory(args.voxelizer_factory,
                                         np.array(voxelizer_shape(args)),
                                         args.save_voxels_to)

    # Create a dataset instance to generate the samples for training
    training_dataset = get_dataset_type("euclidean_dual_loss")(
        (DatasetBuilder().with_dataset(args.dataset_type).lru_cache(
            2000).filter_tags(training_tags).build(args.dataset_directory)),
        voxelizer_factory,
        args.n_points_from_mesh,
        transform=compose_transformations(args.voxelizer_factory))
    # Create a batchprovider object to start generating batches
    train_bp = BatchProvider(training_dataset,
                             batch_size=args.batch_size,
                             cache_size=args.cache_size)
    train_bp.ready()

    network_params = NetworkParameters.from_options(args)
    # Build the model to be used for training
    model = network_params.network(network_params)

    # Move model to the device to be used
    model.to(device)
    # Check whether there is a weight file provided to continue training from
    if args.weight_file is not None:
        # map_location allows weights saved on a GPU to be loaded on the CPU
        model.load_state_dict(
            torch.load(args.weight_file, map_location=device))
    model.train()

    # Build an optimizer object to compute the gradients of the parameters
    optimizer = optimizer_factory(args, model)

    # Loop over the dataset multiple times
    pcl_to_prim_losses = []
    prim_to_pcl_losses = []
    losses = []
    for i in range(args.epochs):
        bar = get_logger("euclidean_dual_loss", i + 1, args.epochs,
                         args.steps_per_epoch)
        for b, sample in zip(range(args.steps_per_epoch),
                             yield_infinite(train_bp)):

            tags, X, y_target = sample
            X, y_target = X.to(device), y_target.to(device)

            # For every sample of the batch that has precomputed part point
            # samples on disk, collect the samples and a sparse (N, N) matrix
            # with ones at the stored neighbour pairs, plus the sample's
            # position inside the batch.
            part_point_samples = []
            P = []
            indices_w_parts = []
            for idx_batchwide, tag in enumerate(tags):
                target_part_samples_path = os.path.join(
                    RANDINDEX_SAMPLES_DIR, '{}.pkl'.format(tag))
                if not os.path.exists(target_part_samples_path):
                    continue

                # NOTE(review): pickle.load can execute arbitrary code; only
                # point RANDINDEX_SAMPLES_DIR at trusted files.
                with open(target_part_samples_path, 'rb') as f:
                    part_data = pickle.load(f)
                point_samples = torch.Tensor(part_data['samples']).to(device)
                part_point_samples.append(point_samples)

                N = point_samples.shape[0]
                # Index pairs stored as a 2 x K array (checked below)
                nonzero_indices = part_data['neighbor_pairs']
                assert nonzero_indices.shape[0] == 2
                val = np.ones(nonzero_indices.shape[1])
                curr_P = torch.sparse_coo_tensor(nonzero_indices, val,
                                                 (N, N))
                P.append(curr_P.to(device))
                indices_w_parts.append(idx_batchwide)

            # NOTE(review): torch.stack raises when no sample of the batch has
            # part samples, and requires every sample to have the same N.
            P = torch.stack(P, axis=0)
            part_point_samples = torch.stack(part_point_samples, axis=0)

            # Train on batch
            batch_loss, metrics, debug_stats = train_on_batch_w_parts(
                model,
                lr_schedule(optimizer, i, args.lr, args.lr_factor,
                            args.lr_epochs), euclidean_dual_loss, X, y_target,
                get_regularizer_terms(args, i), sampler,
                get_loss_options(args), P, part_point_samples, indices_w_parts)

            # Get the regularizer terms
            reg_values = debug_stats["regularizer_terms"]
            sparsity_regularizer = reg_values["sparsity_regularizer"]
            overlapping_regularizer = reg_values["overlapping_regularizer"]
            parsimony_regularizer = reg_values["parsimony_regularizer"]
            entropy_bernoulli_regularizer = reg_values[
                "entropy_bernoulli_regularizer"]
            bernoulli_regularizer = reg_values["bernoulli_regularizer"]

            # Running averages of the losses over the batches of this epoch
            pcl_to_prim_loss = debug_stats["pcl_to_prim_loss"].item()
            prim_to_pcl_loss = debug_stats["prim_to_pcl_loss"].item()
            bar.loss = moving_average(bar.loss, batch_loss, b)
            bar.pcl_to_prim_loss = \
                moving_average(bar.pcl_to_prim_loss, pcl_to_prim_loss, b)
            bar.prim_to_pcl_loss = \
                moving_average(bar.prim_to_pcl_loss, prim_to_pcl_loss, b)

            losses.append(bar.loss)
            prim_to_pcl_losses.append(bar.prim_to_pcl_loss)
            pcl_to_prim_losses.append(bar.pcl_to_prim_loss)

            # Running means of the regularizers: (old_mean * b + new) / (b+1)
            bar.bernoulli_regularizer =\
                (bar.bernoulli_regularizer * b + bernoulli_regularizer) / (b+1)
            bar.parsimony_regularizer =\
                (bar.parsimony_regularizer * b + parsimony_regularizer) / (b+1)
            bar.overlapping_regularizer =\
                (bar.overlapping_regularizer * b + overlapping_regularizer) / (b+1)
            bar.entropy_bernoulli_regularizer = \
                (bar.entropy_bernoulli_regularizer * b +
                 entropy_bernoulli_regularizer) / (b+1)
            bar.sparsity_regularizer =\
                (bar.sparsity_regularizer * b + sparsity_regularizer) / (b+1)

            bar.exp_n_prims = metrics[0].sum(-1).mean()
            # Update the file that keeps track of the statistics
            train_stats_f.write(
                ("%d %.8f %.8f %.8f %.6f %.6f %.6f %.6f %.6f") %
                (i, bar.loss, bar.pcl_to_prim_loss, bar.prim_to_pcl_loss,
                 bar.bernoulli_regularizer, bar.entropy_bernoulli_regularizer,
                 bar.parsimony_regularizer, bar.overlapping_regularizer,
                 bar.sparsity_regularizer))
            train_stats_f.write("\n")
            train_stats_f.flush()

            bar.next()
        # Finish the progress bar and save the model after every epoch
        bar.finish()
        torch.save(
            model.state_dict(),
            os.path.join(experiment_directory,
                         "model_%d" % (i + args.continue_from_epoch, )))

    # Stop the batch provider once all epochs have been consumed (it used to
    # be stopped at the end of the first epoch) and release the stats file.
    train_bp.stop()
    train_stats_f.close()

    # NOTE(review): each "[steps_per_epoch:]" entry sums every epoch after the
    # first but divides by a single epoch's step count.
    print([
        sum(losses[args.steps_per_epoch:]) / float(args.steps_per_epoch),
        sum(losses[:args.steps_per_epoch]) / float(args.steps_per_epoch),
        sum(pcl_to_prim_losses[args.steps_per_epoch:]) /
        float(args.steps_per_epoch),
        sum(pcl_to_prim_losses[:args.steps_per_epoch]) /
        float(args.steps_per_epoch),
        sum(prim_to_pcl_losses[args.steps_per_epoch:]) /
        float(args.steps_per_epoch),
        sum(prim_to_pcl_losses[:args.steps_per_epoch]) /
        float(args.steps_per_epoch),
    ])
コード例 #2
0
def _rotation_matrices_to_quaternions(R):
    """Convert a stack of 3x3 rotation matrices to quaternions (w, x, y, z).

    Uses Shepperd's method (branching on the largest of the trace and the
    diagonal terms) for numerical stability.

    NOTE(review): the (w, x, y, z) ordering is assumed to match the one used
    by ``quaternions_to_rotation_matrices`` and by the consumers of the
    stored primitive files — confirm against that helper.

    Args:
        R: array-like reshapeable to (K, 3, 3).

    Returns:
        numpy array of shape (K, 4) with one quaternion per matrix.
    """
    R = np.asarray(R, dtype=np.float64).reshape(-1, 3, 3)
    quats = np.empty((R.shape[0], 4))
    for k, m in enumerate(R):
        trace = m[0, 0] + m[1, 1] + m[2, 2]
        if trace > 0:
            s = 2.0 * np.sqrt(trace + 1.0)
            w = 0.25 * s
            x = (m[2, 1] - m[1, 2]) / s
            y = (m[0, 2] - m[2, 0]) / s
            z = (m[1, 0] - m[0, 1]) / s
        elif m[0, 0] > m[1, 1] and m[0, 0] > m[2, 2]:
            s = 2.0 * np.sqrt(1.0 + m[0, 0] - m[1, 1] - m[2, 2])
            w = (m[2, 1] - m[1, 2]) / s
            x = 0.25 * s
            y = (m[0, 1] + m[1, 0]) / s
            z = (m[0, 2] + m[2, 0]) / s
        elif m[1, 1] > m[2, 2]:
            s = 2.0 * np.sqrt(1.0 + m[1, 1] - m[0, 0] - m[2, 2])
            w = (m[0, 2] - m[2, 0]) / s
            x = (m[0, 1] + m[1, 0]) / s
            y = 0.25 * s
            z = (m[1, 2] + m[2, 1]) / s
        else:
            s = 2.0 * np.sqrt(1.0 + m[2, 2] - m[0, 0] - m[1, 1])
            w = (m[1, 0] - m[0, 1]) / s
            x = (m[0, 2] + m[2, 0]) / s
            y = (m[1, 2] + m[2, 1]) / s
            z = 0.25 * s
        quats[k] = (w, x, y, z)
    return quats


def main(argv):
    """Run the forward pass on a dataset and store the predicted primitives.

    For every sample of the dataset, the parameters of all
    ``--n_primitives`` predicted primitives are written to
    ``<output_directory>/<basename of dataset_directory>/<model_tag>/
    primitive_<i>.p``. The probability of every primitive and the number of
    primitives at or above ``--prob_threshold`` are printed. With
    ``--save_prediction_as_mesh``, the primitives at or above the threshold
    are also exported to ``reconstruction.ply`` in the same directory.

    Args:
        argv: command-line arguments to parse (without the program name).
    """
    parser = argparse.ArgumentParser(
        description="Do the forward pass and estimate a set of primitives"
    )
    parser.add_argument(
        "dataset_directory",
        help="Path to the directory containing the dataset"
    )
    parser.add_argument(
        "output_directory",
        help="Save the output files in that directory"
    )
    parser.add_argument(
        "--tsdf_directory",
        default="",
        help="Path to the directory containing the precomputed tsdf files"
    )
    parser.add_argument(
        "--weight_file",
        default=None,
        help="The path to the previously trainined model to be used"
    )

    parser.add_argument(
        "--n_primitives",
        type=int,
        default=32,
        help="Number of primitives"
    )
    parser.add_argument(
        "--prob_threshold",
        type=float,
        default=0.5,
        help="Probability threshold"
    )
    parser.add_argument(
        "--use_deformations",
        action="store_true",
        help="Use Superquadrics with deformations as the shape configuration"
    )
    parser.add_argument(
        "--save_prediction_as_mesh",
        action="store_true",
        help="When true store prediction as a mesh"
    )
    parser.add_argument(
        "--run_on_gpu",
        action="store_true",
        help="Use GPU"
    )
    parser.add_argument(
        "--with_animation",
        action="store_true",
        help="Add animation"
    )

    add_dataset_parameters(parser)
    add_nn_parameters(parser)
    add_voxelizer_parameters(parser)
    add_gaussian_noise_layer_parameters(parser)
    add_loss_parameters(parser)
    add_loss_options_parameters(parser)
    args = parser.parse_args(argv)

    # Check if output directory exists and if it doesn't create it
    if not os.path.exists(args.output_directory):
        os.makedirs(args.output_directory)

    if args.run_on_gpu and torch.cuda.is_available():
        device = torch.device("cuda:0")
    else:
        device = torch.device("cpu")
    print("Running code on {}".format(device))

    # Create a factory that returns the appropriate voxelizer based on the
    # input argument
    voxelizer_factory = VoxelizerFactory(
        args.voxelizer_factory,
        np.array(voxelizer_shape(args)),
        args.save_voxels_to
    )

    # Create a dataset instance to generate the samples
    dataset = get_dataset_type("euclidean_dual_loss")(
        (DatasetBuilder()
            .with_dataset(args.dataset_type)
            .filter_tags(args.model_tags)
            .build(args.dataset_directory)),
        voxelizer_factory,
        args.n_points_from_mesh,
        transform=compose_transformations(voxelizer_factory)
    )

    # Tags in dataset order; with batch_size=1 and no shuffling, the
    # sample index is used to look up the tag of each sample.
    model_tags = dataset._dataset_object._tags

    # TODO: Change batch_size in dataloader
    dataloader = DataLoader(dataset, batch_size=1, num_workers=4)

    network_params = NetworkParameters.from_options(args)
    # Build the model to be used for testing
    model = network_params.network(network_params)
    # Move model to device to be used
    model.to(device)
    if args.weight_file is not None:
        # Load the model parameters of the previously trained model
        model.load_state_dict(
            torch.load(args.weight_file, map_location=device)
        )
    model.eval()

    colors = get_colors(args.n_primitives)

    # Outputs are grouped under the last component of dataset_directory
    # (args.output_directory/class/model_id/...). normpath drops a trailing
    # separator so the same name is used with or without one.
    class_name = os.path.basename(os.path.normpath(args.dataset_directory))

    for sample_idx, sample in enumerate(dataloader):

        model_tag = model_tags[sample_idx]

        X, _ = sample
        X = X.to(device)

        # Do the forward pass and estimate the primitive parameters; no
        # autograd graph is needed at inference time.
        with torch.no_grad():
            y_hat = model(X)

        probs = y_hat[0].to("cpu").detach().numpy()
        # Rotations are predicted either as Euler angles (3 values) or as
        # quaternions (4 values); both are turned into rotation matrices.
        if y_hat[2].shape[1] == 3:
            R = euler_angles_to_rotation_matrices(
                y_hat[2].view(-1, 3)
            ).to("cpu").detach()
            # The stored primitive files expect quaternions, so derive them
            # from the rotation matrices.
            quats = _rotation_matrices_to_quaternions(R.numpy())
        else:
            R = quaternions_to_rotation_matrices(
                    y_hat[2].view(-1, 4)
                ).to("cpu").detach()
            # get also the raw quaternions
            quats = y_hat[2].view(-1, 4).to("cpu").detach().numpy()
        translations = y_hat[1].to("cpu").view(args.n_primitives, 3)
        translations = translations.detach().numpy()

        shapes = y_hat[3].to("cpu").view(args.n_primitives, 3).detach().numpy()
        epsilons = y_hat[4].to("cpu").view(
            args.n_primitives, 2
        ).detach().numpy()
        taperings = y_hat[5].to("cpu").view(
            args.n_primitives, 2
        ).detach().numpy()

        on_prims = 0

        # XXX: UNTIL I FIX THE MLAB ISSUE
        # fig = mlab.figure(size=(400, 400), bgcolor=(1, 1, 1))
        # mlab.view(azimuth=0.0, elevation=0.0, distance=2)

        save_dir = os.path.join(args.output_directory, class_name, model_tag)
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)
        # args.output_directory/class/model_id/primitive_%d.p
        # args.output_directory/class/model_id/reconstruction

        # Keep track of the files containing the parameters of each primitive
        # whose probability passes the threshold
        primitive_files = []
        for i in range(args.n_primitives):
            # The surface coordinates are only needed by the (currently
            # disabled) mlab visualisation below.
            x_tr, y_tr, z_tr, prim_pts =\
                get_shape_configuration(args.use_cuboids)(
                    shapes[i, 0],
                    shapes[i, 1],
                    shapes[i, 2],
                    epsilons[i, 0],
                    epsilons[i, 1],
                    R[i].numpy(),
                    translations[i].reshape(-1, 1),
                    taperings[i, 0],
                    taperings[i, 1]
                )

            # Dump the parameters of each primitive as a dictionary
            # TODO: change filepath
            primitive_file = os.path.join(save_dir, "primitive_%d.p" % (i,))
            store_primitive_parameters(
                size=tuple(shapes[i]),
                shape=tuple(epsilons[i]),
                rotation=tuple(quats[i]),
                location=tuple(translations[i]),
                tapering=tuple(taperings[i]),
                probability=(probs[0, i],),
                color=(colors[i % len(colors)]) + (1.0,),
                filepath=primitive_file
            )
            if probs[0, i] >= args.prob_threshold:
                on_prims += 1
                # mlab.mesh(
                #     x_tr,
                #     y_tr,
                #     z_tr,
                #     color=tuple(colors[i % len(colors)]),
                #     opacity=1.0
                # )
                primitive_files.append(primitive_file)

        if args.with_animation:
            cnt = 0
            for az in range(0, 360, 1):
                cnt += 1

                # XXX UNTIL I FIX THE MLAB ISSUE
                # mlab.view(azimuth=az, elevation=0.0, distance=2)
                # mlab.savefig(
                #     os.path.join(
                #         args.output_directory,
                #         "img_%04d.png" % (cnt,)
                #     )
                # )
        for i in range(args.n_primitives):
            print("{} {}".format(i, probs[0, i]))

        print("Using %d primitives out of %d" % (on_prims, args.n_primitives))

        # XXX UNTIL I FIX THE MLAB ISSUE
        # mlab.show()

        # TODO: from_primitive_parms_to_mesh()
        # TODO: get parts for this chair.
        # TODO: push the parts and superquadric meshes through the metric function
        # TODO: record metrics

        if args.save_prediction_as_mesh:
            # TODO: save with model information, class information ...etc
            print("Saving prediction as mesh....")
            reconstruction_path = os.path.join(save_dir, "reconstruction.ply")
            save_prediction_as_ply(primitive_files, reconstruction_path)
            print("Saved prediction as ply file in {}".format(
                reconstruction_path
            ))
コード例 #3
0
def main(argv):
    """Display stored superquadrics placed at the network's predicted poses.

    Loads the primitive parameters stored in ``primitives_directory``,
    filtered with ``--prob_threshold``. It then runs the pose-regression
    network on every sample of the dataset. For each sample, the i-th loaded
    primitive is placed at the i-th predicted rotation (quaternion) and
    translation. The posed primitives are shown with ``display_primitives``,
    and the ground-truth and predicted poses are printed for comparison.

    Args:
        argv: command-line arguments to parse (without the program name).
    """
    parser = argparse.ArgumentParser(
        description="Do the forward pass and estimate a set of primitives")
    parser.add_argument("dataset_directory",
                        help="Path to the directory containing the dataset")

    parser.add_argument(
        "primitives_directory",
        help=
        "Path to the directory containing the superquadrics of the instance")

    parser.add_argument("output_directory",
                        help="Save the output files in that directory")

    parser.add_argument(
        "--weight_file",
        default=None,
        help="The path to the previously trainined model to be used")

    parser.add_argument("--save_prediction_as_mesh",
                        action="store_true",
                        help="When true store prediction as a mesh")

    parser.add_argument("--run_on_gpu", action="store_true", help="Use GPU")

    parser.add_argument("--prob_threshold",
                        type=float,
                        default=0.5,
                        help="Probability threshold")

    # Parse args
    add_nn_parameters(parser)
    add_dataset_parameters(parser)
    add_datatype_parameters(parser)
    add_training_parameters(parser)
    args = parser.parse_args(argv)

    # Check if output directory exists and if it doesn't create it
    if not os.path.exists(args.output_directory):
        os.makedirs(args.output_directory)

    if args.run_on_gpu and torch.cuda.is_available():
        device = torch.device("cuda:0")
    else:
        device = torch.device("cpu")
    print("Running code on ", device)

    # TODO: the number of primitives is hard-coded; each pose is 7 values
    # (a 4-value quaternion followed by a 3-value translation).
    M = 11
    data_output_shape = (M, 7)

    # Create a factory that returns the appropriate data type based on the
    # input argument
    data_factory = DataFactory(
        args.data_type, tuple([data_input_shape(args), data_output_shape]))

    # Create a dataset instance to generate the samples
    dataset = get_dataset_type("matrix_loss")(
        (DatasetBuilder().with_dataset(args.dataset_type).build(
            args.dataset_directory)),
        data_factory,
        transform=compose_transformations(args.data_type))

    # TODO: Change batch_size in dataloader
    dataloader = DataLoader(dataset, batch_size=1, num_workers=4)

    network_params = NetworkParameters(args.architecture, M, False)
    model = network_params.network(network_params)
    # Move model to device to be used
    model.to(device)
    if args.weight_file is not None:
        # Load the model parameters of the previously trained model;
        # map_location allows GPU-saved weights to be loaded on the CPU.
        model.load_state_dict(
            torch.load(args.weight_file, map_location=device))
    model.eval()

    # Materialise the primitives once so they can be iterated for every
    # sample (this is safe even if the loader returns a one-shot iterator).
    primitives = list(
        load_all_primitive_parameters(args.primitives_directory,
                                      args.prob_threshold))

    for sample in dataloader:
        X, y_target = sample
        X, y_target = X.to(device), y_target.to(device)

        # Split the target poses of the first sample of the batch into
        # quaternions and translations
        B = y_target.shape[0]  # batch size
        M = y_target.shape[1]  # number of primitives
        poses_target = y_target.view(B, M, 7).detach().cpu().numpy()
        rotations_target = poses_target[:, :, :4].reshape(B, M, 4)[0]
        translations_target = poses_target[:, :, 4:].reshape(B, M, 3)[0]

        # Do the forward pass; no autograd graph is needed for inference
        with torch.no_grad():
            y_hat = model(X)
        translations = y_hat[0].detach().cpu().numpy().reshape(B, M, 3)[0]
        rotations = y_hat[1].detach().cpu().numpy().reshape(B, M, 4)[0]

        # NOTE(review): the i-th loaded primitive is paired with the i-th
        # predicted pose; more than M loaded primitives raises IndexError.
        primitive_list = []
        for i, p in enumerate(primitives):
            # NOTE(review): despite the message, the predicted (not the
            # ground-truth) poses are used below.
            print("using GT...")
            x_tr, y_tr, z_tr, prim_pts =\
                points_on_sq_surface(
                    p["size"][0],
                    p["size"][1],
                    p["size"][2],
                    p["shape"][0],
                    p["shape"][1],
                    # Quaternion(rotations_target[i]).rotation_matrix.reshape(3, 3),
                    # np.array(translations_target[i]).reshape(3, 1),
                    Quaternion(rotations[i]).rotation_matrix.reshape(3, 3),
                    np.array(translations[i]).reshape(3, 1),
                    p["tapering"][0],
                    p["tapering"][1],
                    None,
                    None
                )

            primitive_list.append((prim_pts, p['color']))

        print("-------- GT ---------")
        print(rotations_target)
        print(translations_target)
        print("--------- Pred ---------")
        print(rotations)
        print(translations)
        display_primitives(primitive_list)
コード例 #4
0
# Notebook-style setup. Names such as model, device, weight_file,
# primitives_directory, prob_threshold, M, data_output_shape and
# dataset_directory are defined by earlier code that is not shown here.
model.to(device)
if weight_file is not None:
    # Load the model parameters of the previously trained model;
    # map_location allows GPU-saved weights to be loaded on the CPU.
    model.load_state_dict(torch.load(weight_file, map_location=device))
    print("Loading...", weight_file)
model.eval()

# Keep track of the files containing the parameters of each primitive
primitives = load_all_primitive_parameters(primitives_directory,
                                           prob_threshold)
gt_primitives = list(primitives)
colors = get_colors(M)

# Parse an empty argument list so every option takes its default value.
parser = argparse.ArgumentParser(
    description="Do the forward pass and estimate a set of primitives")
add_nn_parameters(parser)
add_dataset_parameters(parser)
add_datatype_parameters(parser)
add_training_parameters(parser)
args = parser.parse_args([])
print(args)

# Build the image dataset and a batch-size-1 loader over it
data_type = "image"
data_factory = DataFactory(data_type,
                           tuple([data_input_shape(args), data_output_shape]))
dataset = get_dataset_type("matrix_loss")(
    (DatasetBuilder().with_dataset(
        args.dataset_type).build(dataset_directory)),
    data_factory,
    transform=compose_transformations(data_type))
dataloader = DataLoader(dataset, batch_size=1, num_workers=4)
コード例 #5
0
def main(argv):
    """Train a network that regresses primitive poses with the matrix loss.

    Parses ``argv``, builds the ``matrix_loss`` dataset and a shuffled
    DataLoader, and trains the network for ``args.epochs`` epochs of at most
    ``args.steps_per_epoch`` batches. The running loss is written to
    ``<output_directory>/<experiment_tag>/train.txt`` every 50 batches. A
    checkpoint ``model_<epoch + continue_from_epoch>`` is saved every 5
    epochs, and ``model_final`` is saved at the end.

    Args:
        argv: command-line arguments to parse (without the program name).
    """
    parser = argparse.ArgumentParser(
        description="Train a network to predict primitives")

    parser.add_argument("dataset_directory",
                        help="Path to the directory containing the dataset")

    parser.add_argument("output_directory",
                        help="Save the output files in that directory")

    parser.add_argument(
        "--weight_file",
        default=None,
        help=("The path to a previously trainined model to continue"
              " the training from"))
    parser.add_argument("--continue_from_epoch",
                        default=0,
                        type=int,
                        help="Continue training from epoch (default=0)")

    parser.add_argument("--run_on_gpu", action="store_true", help="Use GPU")

    parser.add_argument("--experiment_tag",
                        default=None,
                        help="Tag that refers to the current experiment")

    parser.add_argument("--cache_size",
                        type=int,
                        default=2000,
                        help="The batch provider cache size")

    parser.add_argument("--seed",
                        type=int,
                        default=27,
                        help="Seed for the PRNG")

    # Parse args
    add_nn_parameters(parser)
    add_dataset_parameters(parser)
    add_datatype_parameters(parser)
    add_training_parameters(parser)
    args = parser.parse_args(argv)

    # NOTE(review): the torch.cuda.is_available() check is commented out, so
    # --run_on_gpu does not fall back to the CPU when CUDA is missing.
    if args.run_on_gpu:  #and torch.cuda.is_available():
        device = torch.device("cuda:0")
    else:
        device = torch.device("cpu")

    print("Running code on", device)

    # Check if output directory exists and if it doesn't create it
    if not os.path.exists(args.output_directory):
        os.makedirs(args.output_directory)

    # Create an experiment directory using the experiment_tag (a random
    # 9-character id when no tag is given)
    if args.experiment_tag is None:
        experiment_tag = id_generator(9)
    else:
        experiment_tag = args.experiment_tag

    experiment_directory = os.path.join(args.output_directory, experiment_tag)
    if not os.path.exists(experiment_directory):
        os.makedirs(experiment_directory)

    # Store the parameters for the current experiment in a json file
    save_experiment_params(args, experiment_tag, experiment_directory)
    print("Save experiment statistics in %s" % (experiment_tag, ))

    # File that stores the training evolution: a fresh file for a new run,
    # appended to when continuing from --weight_file.
    train_stats = os.path.join(experiment_directory, "train.txt")
    train_stats_f = open(train_stats,
                         "w" if args.weight_file is None else "a+")
    train_stats_f.write(("epoch loss\n"))

    # Set the random seed; the torch seeds are drawn from the numpy PRNG so
    # that a single --seed controls all of them.
    np.random.seed(args.seed)
    torch.manual_seed(np.random.randint(np.iinfo(np.int32).max))
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(np.random.randint(np.iinfo(np.int32).max))

    # TODO: the number of primitives is hard-coded; each pose is 7 values
    M = 11
    data_output_shape = (M, 7)

    # Create a factory that returns the appropriate data type based on the
    # input argument
    data_factory = DataFactory(
        args.data_type, tuple([data_input_shape(args), data_output_shape]))

    # Create a dataset instance to generate the samples for training
    training_dataset = get_dataset_type("matrix_loss")(
        (DatasetBuilder().with_dataset(args.dataset_type).build(
            args.dataset_directory)),
        data_factory,
        transform=compose_transformations(args.data_type))

    # Honour --batch_size instead of a hard-coded 32.
    training_loader = DataLoader(training_dataset,
                                 batch_size=args.batch_size,
                                 num_workers=4,
                                 pin_memory=True,
                                 drop_last=True,
                                 shuffle=True)

    # Build the model to be used for training
    network_params = NetworkParameters(args.architecture, M, False)
    model = network_params.network(network_params)

    # Move model to the device to be used
    model.to(device)

    # Check whether there is a weight file provided to continue training from
    if args.weight_file is not None:
        # map_location allows GPU-saved weights to be loaded on the CPU
        model.load_state_dict(
            torch.load(args.weight_file, map_location=device))
    model.train()

    # Build an optimizer object to compute the gradients of the parameters
    optimizer = optimizer_factory(args, model)

    # Loop over the dataset multiple times
    losses = []
    for i in range(args.epochs):
        bar = get_logger("matrix_loss", i + 1, args.epochs,
                         args.steps_per_epoch)

        # zip with range() stops after steps_per_epoch batches without
        # pulling an extra batch from the loader; j is the 0-based index of
        # the batch inside the epoch.
        for j, sample in zip(range(args.steps_per_epoch), training_loader):
            X, y_target = sample
            X, y_target = X.to(device), y_target.to(device)

            # Train on batch
            batch_loss, metrics, debug_stats = train_on_batch(
                model,
                lr_schedule(optimizer, i, args.lr, args.lr_factor,
                            args.lr_epochs), matrix_loss, X, y_target, device)

            # Running average of the loss over the batches of this epoch
            bar.loss = moving_average(bar.loss, batch_loss, j)

            # Record in list
            losses.append(bar.loss)

            # Update the file that keeps track of the statistics every 50
            # batches
            if (j % 50) == 0:
                train_stats_f.write(("%d %5.8f") % (i, bar.loss))
                train_stats_f.write("\n")
                train_stats_f.flush()
            bar.next()

        # Finish the progress bar and checkpoint the model every 5 epochs
        bar.finish()

        if (i % 5) == 0:
            torch.save(
                model.state_dict(),
                os.path.join(experiment_directory,
                             "model_%d" % (i + args.continue_from_epoch, )))

    torch.save(model.state_dict(),
               os.path.join(experiment_directory, "model_final"))
    train_stats_f.close()

    # TODO: print final training stats
    # NOTE(review): the first entry sums every epoch after the first but
    # divides by a single epoch's step count.
    print([
        sum(losses[args.steps_per_epoch:]) / float(args.steps_per_epoch),
        sum(losses[:args.steps_per_epoch]) / float(args.steps_per_epoch)
    ])
コード例 #6
0
def main(argv):
    """Evaluate a trained primitive network with ``euclidean_dual_loss``.

    Every sample of the dataset in ``dataset_directory`` is passed through
    the network. The loss is computed with every regularizer disabled.
    Samples whose loss is NaN are skipped. The per-sample losses, the two
    directional loss terms, and their mean/std are written as text files to
    ``output_directory``, and the summary is printed.

    Args:
        argv: Command-line arguments (without the program name), parsed by
            the argparse parser built below.
    """
    parser = argparse.ArgumentParser(
        description="Do the forward pass and estimate a set of primitives"
    )
    parser.add_argument(
        "dataset_directory",
        help="Path to the directory containing the dataset"
    )
    parser.add_argument(
        "output_directory",
        help="Save the output files in that directory"
    )
    parser.add_argument(
        "--tsdf_directory",
        default="",
        help="Path to the directory containing the precomputed tsdf files"
    )
    parser.add_argument(
        "--weight_file",
        default=None,
        help="The path to the previously trainined model to be used"
    )

    parser.add_argument(
        "--n_primitives",
        type=int,
        default=32,
        help="Number of primitives"
    )
    parser.add_argument(
        "--use_deformations",
        action="store_true",
        help="Use Superquadrics with deformations as the shape configuration"
    )
    parser.add_argument(
        "--run_on_gpu",
        action="store_true",
        help="Use GPU"
    )

    add_dataset_parameters(parser)
    add_nn_parameters(parser)
    add_voxelizer_parameters(parser)
    add_gaussian_noise_layer_parameters(parser)
    add_loss_parameters(parser)
    add_loss_options_parameters(parser)
    args = parser.parse_args(argv)

    # A sampler instance, handed to the loss function below.
    e = EqualDistanceSamplerSQ(200)

    # Check if output directory exists and if it doesn't create it
    if not os.path.exists(args.output_directory):
        os.makedirs(args.output_directory)

    if args.run_on_gpu and torch.cuda.is_available():
        device = torch.device("cuda:0")
    else:
        device = torch.device("cpu")
    # A single-argument print() behaves the same on Python 2 and Python 3.
    print("Running code on {}".format(device))

    # Create a factory that returns the appropriate voxelizer based on the
    # input argument
    voxelizer_factory = VoxelizerFactory(
        args.voxelizer_factory,
        np.array(voxelizer_shape(args)),
        args.save_voxels_to
    )

    # Create a dataset instance to generate the samples to evaluate on
    dataset = get_dataset_type("euclidean_dual_loss")(
        (DatasetBuilder()
            .with_dataset(args.dataset_type)
            .filter_tags(args.model_tags)
            .build(args.dataset_directory)),
        voxelizer_factory,
        args.n_points_from_mesh,
        n_bbox=args.n_bbox,
        n_surface=args.n_surface,
        equal=args.equal,
        transform=compose_transformations(voxelizer_factory)
    )

    # TODO: Change batch_size in dataloader
    dataloader = DataLoader(dataset, batch_size=1, num_workers=4)

    network_params = NetworkParameters.from_options(args)
    # Build the model to be used for testing
    model = network_params.network(network_params)
    # Move model to device to be used
    model.to(device)
    if args.weight_file is not None:
        # Load the model parameters of the previously trained model
        model.load_state_dict(
            torch.load(args.weight_file, map_location=device)
        )
    model.eval()

    losses = []
    pcl_to_prim_losses = []
    prim_to_pcl_losses = []

    prog = Progbar(len(dataloader))
    for i, sample in enumerate(dataloader):
        X, y_target = sample
        X, y_target = X.to(device), y_target.to(device)

        # Evaluation only: do not build an autograd graph for the forward
        # pass (the loss is only read back with .item()).
        with torch.no_grad():
            y_hat = model(X)

        # All regularizers are switched off so that only the reconstruction
        # terms of the loss are measured.
        reg_terms = {
            "regularizer_type": [],
            "bernoulli_regularizer_weight": 0.0,
            "entropy_bernoulli_regularizer_weight": 0.0,
            "parsimony_regularizer_weight": 0.0,
            "overlapping_regularizer_weight": 0.0,
            "sparsity_regularizer_weight": 0.0,
        }
        loss, debug_stats = euclidean_dual_loss(
            y_hat,
            y_target,
            reg_terms,
            e,
            get_loss_options(args)
        )

        loss_value = loss.item()
        # NaN losses would poison the mean/std below, so they are dropped.
        if not np.isnan(loss_value):
            losses.append(loss_value)
            pcl_to_prim_losses.append(debug_stats["pcl_to_prim_loss"].item())
            prim_to_pcl_losses.append(debug_stats["prim_to_pcl_loss"].item())
        # Update progress bar
        prog.update(i + 1)

    # Mean and std of every loss series, computed once and reused for both
    # the summary file and the printed report.
    summary = [
        np.mean(losses), np.std(losses),
        np.mean(pcl_to_prim_losses), np.std(pcl_to_prim_losses),
        np.mean(prim_to_pcl_losses), np.std(prim_to_pcl_losses),
    ]
    for filename, values in (
        ("losses.txt", losses),
        ("pcl_to_prim_losses.txt", pcl_to_prim_losses),
        ("prim_to_pcl_losses.txt", prim_to_pcl_losses),
        ("mean_std_losses.txt", summary),
    ):
        np.savetxt(os.path.join(args.output_directory, filename), values)

    print(
        "loss: %.7f +/- %.7f - pcl_to_prim_loss %.7f +/- %.7f - prim_to_pcl_loss %.7f +/- %.7f"
        % tuple(summary)
    )
コード例 #7
0
def main(argv):
    """Visualise with mayavi (``mlab``) the primitives predicted per sample.

    For every sample of the dataset in ``dataset_directory`` the network
    predicts a set of primitives. Those whose probability is at least
    ``--prob_threshold`` are drawn as meshes in a new ``mlab`` figure.
    With ``--with_animation``, 360 views (one per degree of azimuth) are
    saved as ``img_0001.png`` ... ``img_0360.png`` in ``output_directory``.
    The probability of every primitive and the number of drawn primitives
    are printed before ``mlab.show()`` is called.

    Args:
        argv: Command-line arguments (without the program name), parsed by
            the argparse parser built below.

    Raises:
        KeyError: if ``--loss_type`` is not ``"euclidean_dual_loss"``.
    """
    parser = argparse.ArgumentParser(
        description="Do the forward pass and estimate a set of primitives")
    parser.add_argument("dataset_directory",
                        help="Path to the directory containing the dataset")
    parser.add_argument("output_directory",
                        help="Save the output files in that directory")
    parser.add_argument(
        "--tsdf_directory",
        default="",
        help="Path to the directory containing the precomputed tsdf files")
    parser.add_argument(
        "--weight_file",
        default=None,
        help="The path to the previously trainined model to be used")

    parser.add_argument("--n_primitives",
                        type=int,
                        default=32,
                        help="Number of primitives")
    parser.add_argument("--prob_threshold",
                        type=float,
                        default=0.5,
                        help="Probability threshold")
    parser.add_argument(
        "--use_deformations",
        action="store_true",
        help="Use Superquadrics with deformations as the shape configuration")
    parser.add_argument("--run_on_gpu", action="store_true", help="Use GPU")
    parser.add_argument("--title", default="Fooo", help="Title on the plot")
    parser.add_argument("--save_image_to",
                        default="/tmp/image_0.png",
                        help="Path to image")
    parser.add_argument("--with_animation",
                        action="store_true",
                        help="Add animation")
    parser.add_argument("--model_id",
                        type=int,
                        default=0,
                        help="Epoch at which this model was captured")

    add_dataset_parameters(parser)
    add_nn_parameters(parser)
    add_voxelizer_parameters(parser)
    add_tsdf_fusion_parameters(parser)
    add_gaussian_noise_layer_parameters(parser)
    add_loss_parameters(parser)
    add_loss_options_parameters(parser)
    args = parser.parse_args(argv)

    # A sampler instance
    # NOTE(review): not used below; kept in case construction has side
    # effects the rest of the script relies on.
    e = EqualDistanceSamplerSQ(200)

    # Check if output directory exists and if it doesn't create it
    if not os.path.exists(args.output_directory):
        os.makedirs(args.output_directory)

    if args.run_on_gpu and torch.cuda.is_available():
        device = torch.device("cuda:0")
    else:
        device = torch.device("cpu")
    # A single-argument print() behaves the same on Python 2 and Python 3.
    print("Running code on {}".format(device))

    # Only euclidean_dual_loss is supported: this lookup rejects any other
    # --loss_type with a KeyError before any data is loaded.
    losses = {"euclidean_dual_loss": euclidean_dual_loss}
    loss_factory = losses[args.loss_type]

    # Create a factory that returns the appropriate voxelizer based on the
    # input argument
    voxelizer_factory = VoxelizerFactory(args.voxelizer_factory,
                                         np.array(voxelizer_shape(args)),
                                         args.save_voxels_to)

    # Create a dataset instance to generate the samples to visualise
    dataset = get_dataset_type(args.loss_type)(
        (DatasetBuilder().with_dataset(args.dataset_type).filter_tags(
            args.model_tags).build(args.dataset_directory)),
        voxelizer_factory,
        args.n_points_from_mesh,
        transform=compose_transformations(voxelizer_factory))

    # TODO: Change batch_size in dataloader
    dataloader = DataLoader(dataset, batch_size=1, num_workers=4)

    network_params = NetworkParameters.from_options(args)
    # Build the model to be used for testing
    model = network_params.network(network_params)
    # Move model to device to be used
    model.to(device)
    if args.weight_file is not None:
        # Load the model parameters of the previously trained model
        model.load_state_dict(torch.load(args.weight_file,
                                         map_location=device))
    model.eval()

    colors = get_colors(args.n_primitives)
    for sample in dataloader:
        X, y_target = sample
        X, y_target = X.to(device), y_target.to(device)

        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            y_hat = model(X)

        probs = y_hat[0].to("cpu").detach().numpy()
        # Transform the Euler angles (or quaternions) to rotation matrices
        if y_hat[2].shape[1] == 3:
            R = euler_angles_to_rotation_matrices(y_hat[2].view(
                -1, 3)).to("cpu").detach()
        else:
            R = quaternions_to_rotation_matrices(y_hat[2].view(
                -1, 4)).to("cpu").detach()
        translations = y_hat[1].to("cpu").view(args.n_primitives, 3)
        translations = translations.detach().numpy()

        shapes = y_hat[3].to("cpu").view(args.n_primitives, 3).detach().numpy()
        epsilons = y_hat[4].to("cpu").view(args.n_primitives,
                                           2).detach().numpy()
        taperings = y_hat[5].to("cpu").view(args.n_primitives,
                                            2).detach().numpy()

        # Target points; only used by the commented-out scatter plot below.
        pts = y_target[:, :, :3].to("cpu")
        pts = pts.squeeze().detach().numpy().T

        on_prims = 0
        mlab.figure(size=(400, 400), bgcolor=(1, 1, 1))
        mlab.view(azimuth=0.0, elevation=0.0, distance=2)
        # Uncomment to visualize the points sampled from the target mesh
        # t = np.array([1.2, 0, 0]).reshape(3, -1)
        # pts_n = pts + t
        #     mlab.points3d(
        #        # pts_n[0], pts_n[1], pts_n[2],
        #        pts[0], pts[1], pts[2],
        #        scale_factor=0.03, color=(0.8, 0.8, 0.8)
        #     )
        for i in range(args.n_primitives):
            x_tr, y_tr, z_tr, _ =\
                get_shape_configuration(args.use_cuboids)(
                    shapes[i, 0],
                    shapes[i, 1],
                    shapes[i, 2],
                    epsilons[i, 0],
                    epsilons[i, 1],
                    R[i].numpy(),
                    translations[i].reshape(-1, 1),
                    taperings[i, 0],
                    taperings[i, 1]
                )
            # Only primitives the network considers present are drawn.
            if probs[0, i] >= args.prob_threshold:
                on_prims += 1
                mlab.mesh(x_tr,
                          y_tr,
                          z_tr,
                          color=tuple(colors[i % len(colors)]),
                          opacity=1.0)

        if args.with_animation:
            # One frame per degree of azimuth, numbered from 1.
            for cnt, az in enumerate(range(0, 360, 1), 1):
                mlab.view(azimuth=az, elevation=0.0, distance=2)
                mlab.savefig(
                    os.path.join(args.output_directory,
                                 "img_%04d.png" % (cnt, )))
        for i in range(args.n_primitives):
            print("{} {}".format(i, probs[0, i]))

        print("Using %d primitives out of %d" % (on_prims, args.n_primitives))
        mlab.show()
コード例 #8
0
def main(argv):
    """Measure how well predicted superquadrics match annotated parts.

    For every (test) model, the network predicts a set of primitives.

    * The parameters of each primitive are pickled to
      ``<output_directory>/<class>/<model_tag>/primitive_<i>.p``.
    * Primitives whose probability reaches ``--prob_threshold`` are
      converted to meshes.
    * These meshes are compared with ``update_IOUs`` against the part
      meshes/bounding boxes pickled under ``../parts_output/<class>/``.
    * The running mean IoU matrix is then matched with
      ``get_consistency``. The best superquadric -> part assignment is
      printed every 5 models and once at the end.

    Args:
        argv: Command-line arguments (without the program name), parsed by
            the argparse parser built below.

    Raises:
        ValueError: if the network predicts Euler angles instead of
            quaternions. The primitive files are written with the raw
            quaternion rotations, which do not exist in that case.
    """
    parser = argparse.ArgumentParser(
        description="Do the forward pass and estimate a set of primitives"
    )
    parser.add_argument(
        "dataset_directory",
        help="Path to the directory containing the dataset"
    )
    parser.add_argument(
        "output_directory",
        help="Save the output files in that directory"
    )
    parser.add_argument(
        "--tsdf_directory",
        default="",
        help="Path to the directory containing the precomputed tsdf files"
    )
    parser.add_argument(
        "--weight_file",
        default=None,
        help="The path to the previously trainined model to be used"
    )

    parser.add_argument(
        "--n_primitives",
        type=int,
        default=32,
        help="Number of primitives"
    )
    parser.add_argument(
        "--prob_threshold",
        type=float,
        default=0.5,
        help="Probability threshold"
    )
    parser.add_argument(
        "--use_deformations",
        action="store_true",
        help="Use Superquadrics with deformations as the shape configuration"
    )
    parser.add_argument(
        "--save_prediction_as_mesh",
        action="store_true",
        help="When true store prediction as a mesh"
    )
    parser.add_argument(
        "--run_on_gpu",
        action="store_true",
        help="Use GPU"
    )
    parser.add_argument(
        "--with_animation",
        action="store_true",
        help="Add animation"
    )

    parser.add_argument(
        "--train_test_splits_file",
        default=None,
        help="Path to the train-test splits file"
    )  # we are going to test on the test splits

    parser.add_argument(
        '--save_individual_IOUs',
        action="store_true",
        help="saves the IOU with partnet parts for every 3d model evaluated."
    )
    parser.add_argument(
        '--recenter_superquadrics',
        action="store_true",
        help="recenters the superquadrics to the part bounding boxes before evaluating metric"
    )

    add_dataset_parameters(parser)
    add_nn_parameters(parser)
    add_voxelizer_parameters(parser)
    add_gaussian_noise_layer_parameters(parser)
    add_loss_parameters(parser)
    add_loss_options_parameters(parser)
    args = parser.parse_args(argv)

    # Restrict the evaluation to the test split when a splits file is given.
    if args.train_test_splits_file is not None:
        train_test_splits = parse_train_test_splits(
            args.train_test_splits_file,
            args.model_tags
        )
        test_tags = np.hstack([
            train_test_splits["test"]
        ])
    else:
        test_tags = args.model_tags

    # A sampler instance
    # NOTE(review): not used below; kept in case construction has side
    # effects the rest of the script relies on.
    e = EqualDistanceSamplerSQ(200)

    # Check if output directory exists and if it doesn't create it
    if not os.path.exists(args.output_directory):
        os.makedirs(args.output_directory)

    if args.run_on_gpu and torch.cuda.is_available():
        device = torch.device("cuda:0")
    else:
        device = torch.device("cpu")
    print("Running code on {}".format(device))

    # Create a factory that returns the appropriate voxelizer based on the
    # input argument
    voxelizer_factory = VoxelizerFactory(
        args.voxelizer_factory,
        np.array(voxelizer_shape(args)),
        args.save_voxels_to
    )

    # Create a dataset instance to generate the samples to evaluate on
    dataset = get_dataset_type("euclidean_dual_loss")(
        (DatasetBuilder()
            .with_dataset(args.dataset_type)
            .filter_tags(test_tags)
            .build(args.dataset_directory)),
        voxelizer_factory,
        args.n_points_from_mesh,
        transform=compose_transformations(voxelizer_factory)
    )

    # Tags in dataset order. The DataLoader below does not shuffle, so the
    # enumeration index of a sample indexes straight into this list.
    model_tags = dataset._dataset_object._tags

    # TODO: Change batch_size in dataloader
    dataloader = DataLoader(dataset, batch_size=1, num_workers=4)

    network_params = NetworkParameters.from_options(args)
    # Build the model to be used for testing
    model = network_params.network(network_params)
    # Move model to device to be used
    model.to(device)
    if args.weight_file is not None:
        # Load the model parameters of the previously trained model
        model.load_state_dict(
            torch.load(args.weight_file, map_location=device)
        )
    model.eval()

    colors = get_colors(args.n_primitives)

    # Class name used both for the part annotations and the output layout.
    # basename(dirname(...)): with a trailing slash this is the dataset
    # directory's own name, without one it is its parent's name.
    class_name = os.path.basename(os.path.dirname(args.dataset_directory))
    parts_dir = '../parts_output/{}'.format(class_name)

    # getting the big picture for parts: every part id of the class gets a
    # fixed column in the IoU matrix.
    part_id_list = _collect_part_ids(parts_dir)
    print('Found a total of {} part id'.format(len(part_id_list)))
    part_id_to_IOUidx = {id_: idx for idx, id_ in enumerate(part_id_list)}
    part_IOUidx_to_id = dict(enumerate(part_id_list))

    # Accumulated (summed) IoUs; divided by IOU_counter to get the mean.
    IOU_matrix = np.zeros((args.n_primitives, len(part_id_list)))
    IOU_counter = 0

    for sample_idx, sample in enumerate(dataloader):

        model_tag = model_tags[sample_idx]
        print('evaluating model_tag {}'.format(model_tag))

        X, y_target = sample[1]

        X, y_target = X.to(device), y_target.to(device)

        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            y_hat = model(X)

        # The primitive files below store the raw quaternions. An
        # Euler-angle model has none, so fail with a clear message instead
        # of a NameError halfway through writing the files.
        if y_hat[2].shape[1] == 3:
            raise ValueError(
                "Storing primitive parameters requires a network that "
                "predicts quaternion rotations, but it predicts Euler angles"
            )
        raw_quats = y_hat[2].view(-1, 4)
        R = quaternions_to_rotation_matrices(raw_quats).to("cpu").detach()
        quats = raw_quats.to("cpu").detach().numpy()

        probs = y_hat[0].to("cpu").detach().numpy()
        translations = y_hat[1].to("cpu").view(args.n_primitives, 3)
        translations = translations.detach().numpy()

        shapes = y_hat[3].to("cpu").view(args.n_primitives, 3).detach().numpy()
        epsilons = y_hat[4].to("cpu").view(
            args.n_primitives, 2
        ).detach().numpy()
        taperings = y_hat[5].to("cpu").view(
            args.n_primitives, 2
        ).detach().numpy()

        # Target points; only used by the commented-out scatter plot below.
        pts = y_target[:, :, :3].to("cpu")
        pts = pts.squeeze().detach().numpy().T

        on_prims = 0

        # XXX: UNTIL I FIX THE MLAB ISSUE
        # fig = mlab.figure(size=(400, 400), bgcolor=(1, 1, 1))
        # mlab.view(azimuth=0.0, elevation=0.0, distance=2)

        # Uncomment to visualize the points sampled from the target mesh
        # t = np.array([1.2, 0, 0]).reshape(3, -1)
        # pts_n = pts + t
        #     mlab.points3d(
        #        # pts_n[0], pts_n[1], pts_n[2],
        #        pts[0], pts[1], pts[2],
        #        scale_factor=0.03, color=(0.8, 0.8, 0.8)
        #     )

        # args.output_directory/class/model_id/primitive_%d.p
        # args.output_directory/class/model_id/reconstruction
        save_dir = os.path.join(args.output_directory, class_name, model_tag)
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)

        # Files (and indices) of the primitives that pass the threshold
        primitive_files = []
        primitive_indices = []
        for i in range(args.n_primitives):
            # The mesh coordinates are only consumed by the disabled
            # mlab.mesh call below.
            x_tr, y_tr, z_tr, _ =\
                get_shape_configuration(args.use_cuboids)(
                    shapes[i, 0],
                    shapes[i, 1],
                    shapes[i, 2],
                    epsilons[i, 0],
                    epsilons[i, 1],
                    R[i].numpy(),
                    translations[i].reshape(-1, 1),
                    taperings[i, 0],
                    taperings[i, 1]
                )

            primitive_file = os.path.join(save_dir, "primitive_%d.p" % (i,))
            # Dump the parameters of each primitive as a dictionary.
            # tuple() makes the RGBA concatenation work whatever sequence
            # type get_colors returns.
            store_primitive_parameters(
                size=tuple(shapes[i]),
                shape=tuple(epsilons[i]),
                rotation=tuple(quats[i]),
                location=tuple(translations[i]),
                tapering=tuple(taperings[i]),
                probability=(probs[0, i],),
                color=tuple(colors[i % len(colors)]) + (1.0,),
                filepath=primitive_file
            )
            if probs[0, i] >= args.prob_threshold:
                on_prims += 1
                # mlab.mesh(
                #     x_tr,
                #     y_tr,
                #     z_tr,
                #     color=tuple(colors[i % len(colors)]),
                #     opacity=1.0
                # )
                primitive_files.append(primitive_file)
                primitive_indices.append(i)

        if args.with_animation:
            # XXX UNTIL I FIX THE MLAB ISSUE: rendering is disabled, so this
            # flag currently has no effect.
            # for cnt, az in enumerate(range(0, 360, 1), 1):
            #     mlab.view(azimuth=az, elevation=0.0, distance=2)
            #     mlab.savefig(
            #         os.path.join(
            #             args.output_directory,
            #             "img_%04d.png" % (cnt,)
            #         )
            #     )
            pass

        print("Using %d primitives out of %d" % (on_prims, args.n_primitives))

        # XXX UNTIL I FIX THE MLAB ISSUE
        # mlab.show()

        # get the meshes of the kept superquadrics
        superquadric_id_meshes = []
        for i, p in zip(primitive_indices, primitive_files):
            with open(p, "rb") as f:
                prim_params = pickle.load(f)
            superquadric_id_meshes.append(
                (i, _from_primitive_parms_to_mesh(prim_params))
            )

        # get the parts and their bounding box information
        with open(os.path.join(parts_dir, '{}.pkl'.format(model_tag)), 'rb') as f:
            leaf_parts = pickle.load(f)
        part_id_meshes = [(leaf['id'], leaf['bbox_mesh']) for leaf in leaf_parts]
        part_id_bboxes = [(leaf['id'], leaf['bbox']) for leaf in leaf_parts]

        if args.recenter_superquadrics:
            _recenter_superquadrics(superquadric_id_meshes, part_id_meshes)

        # TODO: push the parts and superquadric meshes through the metric function
        IOU_matrix, delta_IOU = update_IOUs(IOU_matrix,
                                            superquadric_id_meshes,
                                            part_id_bboxes,
                                            part_id_meshes,
                                            part_id_to_IOUidx)
        IOU_counter += 1

        if IOU_counter % 5 == 0:
            _print_best_matching(IOU_matrix / IOU_counter, part_IOUidx_to_id)

        if args.save_prediction_as_mesh:
            # TODO: save with model information, class information ...etc
            print("Saving prediction as mesh....")
            reconstruction_path = os.path.join(save_dir, "reconstruction.ply")
            save_prediction_as_ply(primitive_files, reconstruction_path)
            print("Saved prediction as ply file in {}".format(
                reconstruction_path
            ))

        if args.save_individual_IOUs:
            save_dict = {'shapenetid': model_tag,
                         'indiv_IOU': delta_IOU,
                         'id2labels': {leaf['id']: leaf['label']
                                       for leaf in leaf_parts},
                         'part_id2idx': part_id_to_IOUidx,
                         'part_bboxes': part_id_bboxes,
                         'part_mesh': part_id_meshes,
                         'superquadrics': superquadric_id_meshes}
            with open(os.path.join(save_dir, "part_IOU.pkl"), 'wb') as f:
                pickle.dump(save_dict, f)

    # TODO: record metrics
    # 1) the matrix of IOU's
    # 2) the assignments: superquadric_id's to partid
    if IOU_counter == 0:
        # Avoid dividing the IoU matrix by zero on an empty dataset.
        print('No models were evaluated, no IOU to report')
        return
    _print_best_matching(IOU_matrix / IOU_counter, part_IOUidx_to_id)


def _collect_part_ids(parts_dir):
    """Return the sorted ids of all leaf parts pickled in ``parts_dir``.

    Each file in ``parts_dir`` is expected to hold a pickled list of
    dictionaries with an ``'id'`` key.
    """
    id_set = set()
    for pick in os.listdir(parts_dir):
        # NOTE(review): pickle must only be used on trusted, locally
        # generated part annotations.
        with open(os.path.join(parts_dir, pick), 'rb') as f:
            leaves = pickle.load(f)
        id_set.update(leaf['id'] for leaf in leaves)
    return sorted(id_set)


def _bbox_center(id_meshes):
    """Return the centre of the axis-aligned box around all mesh vertices."""
    vertices = np.vstack([mesh.vertices for _, mesh in id_meshes])
    assert vertices.shape[1] == 3
    return 0.5 * (np.max(vertices, axis=0) + np.min(vertices, axis=0))


def _recenter_superquadrics(superquadric_id_meshes, part_id_meshes):
    """Move the superquadric meshes (in place) onto the parts' box centre.

    Translates every superquadric mesh so that the centre of the joint
    bounding box of all superquadrics coincides with the centre of the
    joint bounding box of all parts. Nothing is done when either list is
    empty, because there is no centre to align.
    """
    if not superquadric_id_meshes or not part_id_meshes:
        return
    supe_center = _bbox_center(superquadric_id_meshes)
    part_center = _bbox_center(part_id_meshes)
    for _, mesh in superquadric_id_meshes:
        mesh.vertices = mesh.vertices - supe_center + part_center


def _print_best_matching(mean_IOU_matrix, part_IOUidx_to_id):
    """Print the best superquadric -> part assignment for a mean IoU matrix.

    ``get_consistency`` returns the best IoU and an iterable of
    (superquadric index, part column) pairs. Part columns are translated
    back to part ids with ``part_IOUidx_to_id``.
    """
    bestIOU, supeidx2partidx = get_consistency(mean_IOU_matrix)
    supeid2partid = {supe_id: part_IOUidx_to_id[part_id]
                     for supe_id, part_id in supeidx2partidx}
    print('##################################')
    print('Best superquadric -> part matching')
    print(supeid2partid)
    print('Best IOU')
    print(bestIOU)
    print('##################################')