示例#1
0
def load_trial(trial_path):
    """
    Resolve a trial directory and load its parameters.

    The defaults for the trial's group (``default_params.get_default_params``)
    are updated with the values stored in ``<trial_path>/params.json``.

    @param trial_path: full path or relative path from shape_completion_training/trials
    @type trial_path: str or pathlib.Path
    @return: tuple ``(trial_path, params)`` where ``trial_path`` is the resolved
        trial directory as a ``pathlib.Path`` and ``params`` is the group's default
        parameters updated in place with the contents of ``params.json``
    @raise ValueError: if the resolved path is not an existing directory
    @raise OSError: if ``params.json`` cannot be opened
    """
    trial_path = pathlib.Path(trial_path)
    if not trial_path.is_absolute():
        # Relative paths are interpreted inside the ROS package's "trials" directory.
        r = rospkg.RosPack()
        trial_path = pathlib.Path(
            r.get_path('shape_completion_training')) / "trials" / trial_path
    if not trial_path.is_dir():
        raise ValueError(
            "Cannot load, the path {} is not an existing directory".format(
                trial_path))

    group_name = get_group_name(trial_path)
    params = default_params.get_default_params(group_name)
    params_filename = trial_path / 'params.json'

    # JSON text is UTF-8 by specification; don't depend on the platform's
    # locale-dependent default encoding when reading it.
    with params_filename.open("r", encoding="utf-8") as params_file:
        params.update(json.load(params_file))
    return trial_path, params
示例#2
0
import argparse

# params = {
#     'batch_size': 1500,
#     'network': 'RealNVP',
#     'dim': 24,
#     'num_masked': 12,
#     'learning_rate': 1e-5,
#     'translation_pixel_range_x': 10,
#     'translation_pixel_range_y': 10,
#     'translation_pixel_range_z': 10,
# }

if __name__ == "__main__":
    # Command-line interface: --tmp is a boolean flag, --group selects the
    # parameter group whose defaults are loaded.
    parser = argparse.ArgumentParser(description="Process args for training")
    parser.add_argument('--tmp', action='store_true')
    parser.add_argument('--group', default=None)
    args = parser.parse_args()

    # Start from the group's defaults and force 'load_bb_only' on before the
    # dataset supervisor is built.
    params = default_params.get_default_params(group_name=args.group)
    params['load_bb_only'] = True

    dataset_name = params['dataset']
    data_supervisor = shape_completion_training.utils.dataset_loader.get_dataset_supervisor(
        dataset_name)

    # With --tmp the runner is constructed without a group name; otherwise the
    # requested group is passed through.
    run_group = None if args.tmp else args.group
    mr = ModelRunner(training=True, params=params, group_name=run_group)
    mr.train_and_test(data_supervisor)
示例#3
0
from shape_completion_training.utils.tf_utils import log_normal_pdf, stack_known, sample_gaussian
from shape_completion_training.voxelgrid.metrics import chamfer_distance
from shape_completion_visualization.voxelgrid_publisher import VoxelgridPublisher, PointcloudPublisher

"""
Publish object pointclouds for use in gpu_voxels planning
"""

# NOTE(review): the descriptive triple-quoted string preceding this section
# comes after the imports, so Python does not treat it as the module docstring
# (a docstring must be the first statement); move it to the top if intended.

# Module-wide handles, all None at import time; presumably assigned later by
# the script's entry point (TODO confirm -- not visible in this chunk).
ARGS = VG_PUB = PT_PUB = None

model_runner = dataset_params = None

# Evaluated once, at import time.
default_dataset_params = default_params.get_default_params()

# All three translation_pixel_range_* entries set to 0.
default_translations = dict.fromkeys(
    (
        'translation_pixel_range_x',
        'translation_pixel_range_y',
        'translation_pixel_range_z',
    ),
    0,
)

Transformer = None


def wip_enforce_contact(elem):
    inference = model_runner.model(elem)
    VG_PUB.publish_inference(inference)
    pssnet = model_runner.model
    latent = tf.Variable(pssnet.sample_latent(elem))