def main():
    tf.random.set_seed(0)
    np.random.seed(0)
    colorama.init(autoreset=True)
    np.set_printoptions(linewidth=200, precision=3, suppress=True)
    parser = argparse.ArgumentParser()
    parser.add_argument("dataset_dirs", type=pathlib.Path, nargs="+")
    parser.add_argument("checkpoint", type=pathlib.Path)
    parser.add_argument("--mode",
                        type=str,
                        choices=['train', 'val', 'test'],
                        default='val')

    args = parser.parse_args()

    rospy.init_node("test_as_inverse_model")

    test_dataset = DynamicsDatasetLoader(args.dataset_dirs)
    test_tf_dataset = test_dataset.get_datasets(mode=args.mode)

    filter_model = filter_utils.load_filter([args.checkpoint])
    latent_dynamics_model, _ = dynamics_utils.load_generic_model(
        [args.checkpoint])

    test_as_inverse_model(filter_model, latent_dynamics_model, test_dataset,
                          test_tf_dataset)
def main():
    colorama.init(autoreset=True)

    parser = argparse.ArgumentParser(
        description="adds the tfrecord file path to each example",
        formatter_class=my_formatter)
    parser.add_argument('dataset_dir',
                        type=pathlib.Path,
                        help='dataset directory')
    parser.add_argument('dataset_type',
                        choices=['dy', 'cl', 'rcv'],
                        help='dataset type')

    args = parser.parse_args()

    rospy.init_node("add_paths")

    outdir = args.dataset_dir.parent / f"{args.dataset_dir.name}+paths"

    if args.dataset_type == 'dy':
        dataset = DynamicsDatasetLoader([args.dataset_dir])
    elif args.dataset_type == 'cl':
        dataset = ClassifierDatasetLoader([args.dataset_dir],
                                          load_true_states=True,
                                          use_gt_rope=False)
    elif args.dataset_type == 'rcv':
        dataset = RecoveryDatasetLoader([args.dataset_dir])
    else:
        raise NotImplementedError(f"Invalid dataset type {args.dataset_type}")

    # record in the hparams that each example now stores its source tfrecord path
    hparams_update = {'has_tfrecord_path': True}
    modify_hparams(args.dataset_dir, outdir, hparams_update)

    total_count = 0
    for mode in ['train', 'test', 'val']:
        tf_dataset = dataset.get_datasets(mode=mode,
                                          shuffle_files=False,
                                          do_not_process=True)
        full_output_directory = outdir / mode
        full_output_directory.mkdir(parents=True, exist_ok=True)

        for example, tfrecord_path in zip(
                progressbar(tf_dataset, widgets=base_dataset.widgets),
                tf_dataset.records):
            features = {
                k: float_tensor_to_bytes_feature(v)
                for k, v in example.items()
            }
            features['tfrecord_path'] = bytes_feature(
                tf.io.serialize_tensor(
                    tf.convert_to_tensor(tfrecord_path,
                                         dtype=tf.string)).numpy())
            tf_write_features(total_count, features, full_output_directory)
            total_count += 1
    print(Fore.GREEN + f"Modified {total_count} examples")
def train_main(
    dataset_dirs: List[pathlib.Path],
    model_hparams: pathlib.Path,
    log: str,
    batch_size: int,
    epochs: int,
    seed: int,
    use_gt_rope: bool,
    checkpoint: Optional[pathlib.Path] = None,
    ensemble_idx: Optional[int] = None,
    take: Optional[int] = None,
    trials_directory: pathlib.Path = pathlib.Path('trials'),
):
    print(Fore.CYAN + f"Using seed {seed}")

    model_hparams = hjson.load(model_hparams.open('r'))
    model_class = state_space_dynamics.get_model(model_hparams['model_class'])

    train_dataset = DynamicsDatasetLoader(dataset_dirs,
                                          use_gt_rope=use_gt_rope)
    val_dataset = DynamicsDatasetLoader(dataset_dirs, use_gt_rope=use_gt_rope)

    model_hparams.update(
        setup_hparams(batch_size, dataset_dirs, seed, train_dataset,
                      use_gt_rope))
    model = model_class(hparams=model_hparams,
                        batch_size=batch_size,
                        scenario=train_dataset.scenario)

    checkpoint_name, trial_path = setup_training_paths(checkpoint,
                                                       ensemble_idx, log,
                                                       model_hparams,
                                                       trials_directory)

    runner = ModelRunner(model=model,
                         training=True,
                         params=model_hparams,
                         checkpoint=checkpoint,
                         batch_metadata=train_dataset.batch_metadata,
                         trial_path=trial_path)

    train_tf_dataset, val_tf_dataset = setup_datasets(model_hparams,
                                                      batch_size, seed,
                                                      train_dataset,
                                                      val_dataset, take)

    runner.train(train_tf_dataset, val_tf_dataset, num_epochs=epochs)

    return trial_path
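# A minimal sketch of calling train_main directly; the dataset path, hparams file, and
# hyperparameter values below are illustrative placeholders, not values from this repository.
trial_path = train_main(dataset_dirs=[pathlib.Path('fwd_model_data/rope_dynamics')],
                        model_hparams=pathlib.Path('hparams/dynamics.hjson'),
                        log='rope_dynamics_run',
                        batch_size=32,
                        epochs=100,
                        seed=1,
                        use_gt_rope=True)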
def load_dataset_and_models(args):
    comparison_info = json.load(args.comparison.open("r"))
    models = {}
    for name, model_info in comparison_info.items():
        model_dir = paths_from_json(model_info['model_dir'])
        model, _ = dynamics_utils.load_generic_model(model_dir)
        models[name] = model

    dataset = DynamicsDatasetLoader(args.dataset_dirs)
    tf_dataset = dataset.get_datasets(mode=args.mode,
                                      shard=args.shard,
                                      take=args.take)
    tf_dataset = batch_tf_dataset(tf_dataset, 1)

    return tf_dataset, dataset, models
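# load_dataset_and_models expects args.comparison to point at a JSON file mapping a display name
# to a dict with at least a 'model_dir' entry; a hypothetical minimal file might look like:
#
#   {
#       "full_dynamics": {"model_dir": "trials/full_dynamics/best_checkpoint"},
#       "cfm":           {"model_dir": "trials/cfm/best_checkpoint"}
#   }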
def main():
    colorama.init(autoreset=True)

    parser = argparse.ArgumentParser(formatter_class=my_formatter)
    parser.add_argument('dataset_dir',
                        type=pathlib.Path,
                        help='dataset directory')
    parser.add_argument('suffix',
                        type=str,
                        help='string added to the new dataset name')

    args = parser.parse_args()

    rospy.init_node("modify_dynamics_dataset")

    outdir = args.dataset_dir.parent / f"{args.dataset_dir.name}+{args.suffix}"

    def _process_example(dataset: DynamicsDatasetLoader, example: Dict):
        example['gt_rope'] = example.pop('rope')
        yield example

    hparams_update = {}

    dataset = DynamicsDatasetLoader([args.dataset_dir])
    modify_dataset(dataset_dir=args.dataset_dir,
                   dataset=dataset,
                   outdir=outdir,
                   process_example=_process_example,
                   hparams_update=hparams_update)
def eval_main(
    dataset_dirs: List[pathlib.Path],
    checkpoint: pathlib.Path,
    mode: str,
    batch_size: int,
    use_gt_rope: bool,
):
    test_dataset = DynamicsDatasetLoader(dataset_dirs, use_gt_rope=use_gt_rope)

    trials_directory = pathlib.Path('trials').absolute()
    trial_path = checkpoint.parent.absolute()
    _, params = filepath_tools.create_or_load_trial(
        trial_path=trial_path, trials_directory=trials_directory)
    model = state_space_dynamics.get_model(params['model_class'])
    net = model(hparams=params,
                batch_size=batch_size,
                scenario=test_dataset.scenario)

    runner = ModelRunner(model=net,
                         training=False,
                         checkpoint=checkpoint,
                         batch_metadata=test_dataset.batch_metadata,
                         trial_path=trial_path,
                         params=params)

    test_tf_dataset = test_dataset.get_datasets(mode=mode)
    test_tf_dataset = batch_tf_dataset(test_tf_dataset,
                                       batch_size,
                                       drop_remainder=True)
    validation_metrics = runner.val_epoch(test_tf_dataset)
    for name, value in validation_metrics.items():
        print(f"{name}: {value}")

    # additional metrics that cannot be expressed as an average of per-batch metrics
    all_errors = None
    for batch in test_tf_dataset:
        outputs = runner.model(batch, training=False)
        errors_for_batch = test_dataset.scenario.classifier_distance(
            outputs, batch)
        if all_errors is not None:
            all_errors = tf.concat([all_errors, errors_for_batch], axis=0)
        else:
            all_errors = errors_for_batch
    print(f"90th percentile {np.percentile(all_errors.numpy(), 90)}")
    print(f"95th percentile {np.percentile(all_errors.numpy(), 95)}")
    print(f"99th percentile {np.percentile(all_errors.numpy(), 99)}")
    print(f"max {np.max(all_errors.numpy())}")
def viz_main(args):
    dataset_dirs = args.dataset_dirs
    checkpoint = args.checkpoint

    trial_path, params = load_trial(checkpoint.parent.absolute())

    dataset = DynamicsDatasetLoader(dataset_dirs)
    scenario = dataset.scenario

    tf_dataset = dataset.get_datasets(mode='val')
    tf_dataset = batch_tf_dataset(tf_dataset,
                                  batch_size=1,
                                  drop_remainder=True)

    model = CFM(hparams=params, batch_size=1, scenario=scenario)
    ckpt = tf.train.Checkpoint(model=model)
    manager = tf.train.CheckpointManager(ckpt, args.checkpoint, max_to_keep=1)
    status = ckpt.restore(manager.latest_checkpoint).expect_partial()
    if manager.latest_checkpoint:
        print(Fore.CYAN + "Restored from {}".format(manager.latest_checkpoint))
        status.assert_existing_objects_matched()
    else:
        raise RuntimeError("Failed to restore!!!")

    for example_idx, example in enumerate(tf_dataset):
        stepper = RvizAnimationController(n_time_steps=dataset.steps_per_traj)
        for t in range(dataset.steps_per_traj):
            output = model(
                model.preprocess_no_gradient(example, training=False))

            # actual_t holds the example indexed at time t; it is passed for both the state and
            # the action, as in viz_example below
            actual_t = numpify(
                remove_batch(scenario.index_time_batched_predicted(example, t)))
            scenario.plot_state_rviz(actual_t, label='actual', color='red')
            scenario.plot_action_rviz(actual_t, actual_t, color='gray')
            prediction_t = remove_batch(
                scenario.index_time_batched_predicted(output, t))
            scenario.plot_state_rviz(prediction_t,
                                     label='predicted',
                                     color='blue')

            stepper.step()
def train_main(args):
    dataset_dirs = args.dataset_dirs
    checkpoint = args.checkpoint
    epochs = args.epochs
    trial_path, params = load_trial(checkpoint.parent.absolute())
    now = str(time())
    trial_path = trial_path.parent / (trial_path.name + '-observer-' + now)
    trial_path.mkdir(parents=True)
    batch_size = params['batch_size']
    params['encoder_trainable'] = False
    params['use_observation_feature_loss'] = True
    params['use_cfm_loss'] = False
    out_hparams_filename = trial_path / 'params.json'
    out_params_str = json.dumps(params)
    with out_hparams_filename.open("w") as out_hparams_file:
        out_hparams_file.write(out_params_str)

    train_dataset = DynamicsDatasetLoader(dataset_dirs)
    val_dataset = DynamicsDatasetLoader(dataset_dirs)

    model_class = state_space_dynamics.get_model(params['model_class'])
    model = model_class(hparams=params,
                        batch_size=batch_size,
                        scenario=train_dataset.scenario)

    seed = 0

    runner = ModelRunner(model=model,
                         training=True,
                         params=params,
                         checkpoint=checkpoint,
                         batch_metadata=train_dataset.batch_metadata,
                         trial_path=trial_path)

    train_tf_dataset, val_tf_dataset = train_test.setup_datasets(
        model_hparams=params,
        batch_size=batch_size,
        seed=seed,
        train_dataset=train_dataset,
        val_dataset=val_dataset)

    runner.train(train_tf_dataset, val_tf_dataset, num_epochs=epochs)

    return trial_path
def compute_classifier_threshold(
    dataset_dirs: List[pathlib.Path],
    checkpoint: pathlib.Path,
    mode: str,
    batch_size: int,
    use_gt_rope: bool,
):
    test_dataset = DynamicsDatasetLoader(dataset_dirs, use_gt_rope=use_gt_rope)

    trials_directory = pathlib.Path('trials').absolute()
    trial_path = checkpoint.parent.absolute()
    _, params = filepath_tools.create_or_load_trial(
        trial_path=trial_path, trials_directory=trials_directory)
    model = state_space_dynamics.get_model(params['model_class'])
    net = model(hparams=params,
                batch_size=batch_size,
                scenario=test_dataset.scenario)

    runner = ModelRunner(model=net,
                         training=False,
                         checkpoint=checkpoint,
                         batch_metadata=test_dataset.batch_metadata,
                         trial_path=trial_path,
                         params=params)

    test_tf_dataset = test_dataset.get_datasets(mode=mode)
    test_tf_dataset = batch_tf_dataset(test_tf_dataset,
                                       batch_size,
                                       drop_remainder=True)

    all_errors = None
    for batch in test_tf_dataset:
        outputs = runner.model(batch, training=False)
        errors_for_batch = test_dataset.scenario.classifier_distance(
            batch, outputs)
        if all_errors is not None:
            all_errors = tf.concat([all_errors, errors_for_batch], axis=0)
        else:
            all_errors = errors_for_batch

    classifier_threshold = np.percentile(all_errors.numpy(), 90)
    rospy.loginfo(f"90th percentile {classifier_threshold}")
    return classifier_threshold
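# A hypothetical use of the returned value: the 90th-percentile dynamics error becomes the
# 'threshold' entry of the labeling params later consumed when building a classifier dataset.
labeling_params = {
    'threshold': compute_classifier_threshold(
        dataset_dirs=[pathlib.Path('fwd_model_data/rope_dynamics')],
        checkpoint=pathlib.Path('trials/my_dynamics_model/best_checkpoint'),
        mode='val',
        batch_size=32,
        use_gt_rope=True),
}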
def viz_example(batch, outputs, test_dataset: DynamicsDatasetLoader, model):
    test_dataset.scenario.plot_environment_rviz(remove_batch(batch))
    anim = RvizAnimationController(np.arange(test_dataset.steps_per_traj))
    while not anim.done:
        t = anim.t()
        actual_t = test_dataset.index_time_batched(batch, t)
        test_dataset.scenario.plot_state_rviz(actual_t,
                                              label='actual',
                                              color='red')
        test_dataset.scenario.plot_action_rviz(actual_t,
                                               actual_t,
                                               color='gray')

        prediction_t = test_dataset.index_time_batched(outputs, t)
        test_dataset.scenario.plot_state_rviz(prediction_t,
                                              label='predicted',
                                              color='blue')

        anim.step()
def viz_dataset(
    dataset_dirs: List[pathlib.Path],
    checkpoint: pathlib.Path,
    mode: str,
    viz_func: Callable,
    use_gt_rope: bool,
    **kwargs,
):
    test_dataset = DynamicsDatasetLoader(dataset_dirs, use_gt_rope=use_gt_rope)

    test_tf_dataset = test_dataset.get_datasets(mode=mode)
    test_tf_dataset = batch_tf_dataset(test_tf_dataset, 1, drop_remainder=True)

    model, _ = dynamics_utils.load_generic_model([checkpoint])

    for i, batch in enumerate(test_tf_dataset):
        batch.update(test_dataset.batch_metadata)
        outputs, _ = model.from_example(batch, training=False)

        viz_func(batch, outputs, test_dataset, model)
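# viz_dataset and viz_example above compose directly: viz_example has exactly the
# (batch, outputs, test_dataset, model) signature that viz_dataset passes to viz_func.
# The paths below are placeholders.
viz_dataset(dataset_dirs=[pathlib.Path('fwd_model_data/rope_dynamics')],
            checkpoint=pathlib.Path('trials/my_dynamics_model/best_checkpoint'),
            mode='val',
            viz_func=viz_example,
            use_gt_rope=True)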
def main():
    colorama.init(autoreset=True)
    plt.style.use("slides")
    np.set_printoptions(suppress=True, linewidth=250, precision=5)

    parser = argparse.ArgumentParser(formatter_class=my_formatter)
    parser.add_argument('dataset_dir',
                        type=pathlib.Path,
                        help='dataset directory')
    parser.add_argument('desired_sequence_length',
                        type=int,
                        help='desired sequence length')

    args = parser.parse_args()

    rospy.init_node("slice_dataset")

    outdir = args.dataset_dir.parent / (args.dataset_dir.name +
                                        f'+L{args.desired_sequence_length}')

    def _process_example(dataset: DynamicsDatasetLoader, example: Dict):
        out_examples = dataset.split_into_sequences(
            example, args.desired_sequence_length)
        for out_example in out_examples:
            out_example['time_idx'] = tf.range(0,
                                               args.desired_sequence_length,
                                               dtype=tf.float32)
            yield out_example

    hparams_update = {
        'data_collection_params': {
            'steps_per_traj': args.desired_sequence_length
        }
    }

    dataset = DynamicsDatasetLoader([args.dataset_dir])
    modify_dataset(dataset_dir=args.dataset_dir,
                   dataset=dataset,
                   outdir=outdir,
                   process_example=_process_example,
                   hparams_update=hparams_update)
def make_classifier_dataset_from_params_dict(dataset_dir: pathlib.Path,
                                             fwd_model_dir: List[pathlib.Path],
                                             labeling_params: Dict,
                                             outdir: pathlib.Path,
                                             use_gt_rope: bool,
                                             visualize: bool,
                                             take: Optional[int] = None,
                                             batch_size: Optional[int] = None,
                                             start_at: Optional[int] = None,
                                             stop_at: Optional[int] = None):
    # append "best_checkpoint" before loading
    if not isinstance(fwd_model_dir, List):
        fwd_model_dir = [fwd_model_dir]
    fwd_model_dir = [p / 'best_checkpoint' for p in fwd_model_dir]

    dynamics_hparams = hjson.load((dataset_dir / 'hparams.hjson').open('r'))
    fwd_models, _ = dynamics_utils.load_generic_model(fwd_model_dir)

    dataset = DynamicsDatasetLoader([dataset_dir], use_gt_rope=use_gt_rope)

    outdir.mkdir(parents=True, exist_ok=True)
    new_hparams_filename = outdir / 'hparams.hjson'
    classifier_dataset_hparams = dynamics_hparams

    classifier_dataset_hparams['dataset_dir'] = dataset_dir.as_posix()
    classifier_dataset_hparams['fwd_model_hparams'] = fwd_models.hparams
    classifier_dataset_hparams['labeling_params'] = labeling_params
    classifier_dataset_hparams['true_state_keys'] = dataset.state_keys
    classifier_dataset_hparams['predicted_state_keys'] = fwd_models.state_keys
    classifier_dataset_hparams['action_keys'] = dataset.action_keys
    classifier_dataset_hparams['scenario_metadata'] = dataset.hparams[
        'scenario_metadata']
    classifier_dataset_hparams['start-at'] = start_at
    classifier_dataset_hparams['stop-at'] = stop_at
    my_hdump(classifier_dataset_hparams,
             new_hparams_filename.open("w"),
             indent=2)

    # because this dataset is still being written, we can't load examples from it yet, but the loader can still be used to visualize the examples we generate
    classifier_dataset_for_viz = ClassifierDatasetLoader(
        [outdir], use_gt_rope=use_gt_rope)

    t0 = perf_counter()
    total_example_idx = 0
    for mode in ['train', 'val', 'test']:
        tf_dataset = dataset.get_datasets(mode=mode, take=take)

        full_output_directory = outdir / mode
        full_output_directory.mkdir(parents=True, exist_ok=True)

        out_examples_gen = generate_classifier_examples(
            fwd_models, tf_dataset, dataset, labeling_params, batch_size)
        for out_examples in out_examples_gen:
            for out_examples_for_start_t in out_examples:
                actual_batch_size = out_examples_for_start_t['traj_idx'].shape[
                    0]
                for batch_idx in range(actual_batch_size):
                    out_example_b = index_dict_of_batched_tensors_tf(
                        out_examples_for_start_t, batch_idx)

                    if out_example_b['time_idx'].ndim == 0:
                        continue

                    if visualize:
                        add_label(out_example_b, labeling_params['threshold'])
                        classifier_dataset_for_viz.anim_transition_rviz(
                            out_example_b)

                    tf_write_example(full_output_directory, out_example_b,
                                     total_example_idx)
                    rospy.loginfo_throttle(
                        10,
                        f"Examples: {total_example_idx:10d}, Time: {perf_counter() - t0:.3f}"
                    )
                    total_example_idx += 1

    return outdir
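# A minimal sketch of calling make_classifier_dataset_from_params_dict; the paths are placeholders,
# and labeling_params shows only the 'threshold' key used directly above (generate_classifier_examples
# may require additional keys).
make_classifier_dataset_from_params_dict(
    dataset_dir=pathlib.Path('fwd_model_data/rope_dynamics'),
    fwd_model_dir=[pathlib.Path('trials/my_dynamics_model')],
    labeling_params={'threshold': 0.05},
    outdir=pathlib.Path('classifier_data/rope_classifier'),
    use_gt_rope=True,
    visualize=False)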
def main():
    colorama.init(autoreset=True)
    plt.style.use("slides")
    np.set_printoptions(suppress=True, linewidth=250, precision=5)

    parser = argparse.ArgumentParser(formatter_class=my_formatter)
    parser.add_argument('dataset_dir',
                        type=pathlib.Path,
                        help='dataset directory',
                        nargs='+')
    parser.add_argument('--plot-type',
                        choices=['3d', 'sanity_check', 'just_count'],
                        default='3d')
    parser.add_argument('--take', type=int)
    parser.add_argument('--start-at', type=int)
    parser.add_argument('--mode',
                        choices=['train', 'test', 'val', 'all'],
                        default='all',
                        help='train, test, val, or all')
    parser.add_argument('--shuffle', action='store_true', help='shuffle')

    args = parser.parse_args()

    rospy.init_node("visualize_dynamics_dataset")

    np.random.seed(1)
    tf.random.set_seed(1)

    # load the dataset
    dataset = DynamicsDatasetLoader(args.dataset_dir)
    tf_dataset = dataset.get_datasets(mode=args.mode, take=args.take)

    if args.shuffle:
        tf_dataset = tf_dataset.shuffle(1024, seed=1)

    # print info about shapes
    example = next(iter(tf_dataset))
    print("Example:")
    for k, v in example.items():
        print(k, v.shape)

    if args.plot_type == '3d':
        # uses rviz
        plot_3d(args, dataset, tf_dataset)
    elif args.plot_type == 'sanity_check':
        min_x = 100
        max_x = -100
        min_y = 100
        max_y = -100
        min_z = 100
        max_z = -100
        min_d = 100
        max_d = -100
        for example in tf_dataset:
            distances_between_grippers = tf.linalg.norm(example['gripper2'] -
                                                        example['gripper1'],
                                                        axis=-1)
            min_d = min(
                tf.reduce_min(distances_between_grippers).numpy(), min_d)
            max_d = max(
                tf.reduce_max(distances_between_grippers).numpy(), max_d)
            rope = example['link_bot']
            points = tf.reshape(rope, [rope.shape[0], -1, 3])
            min_x = min(tf.reduce_min(points[:, :, 0]).numpy(), min_x)
            max_x = max(tf.reduce_max(points[:, :, 0]).numpy(), max_x)
            min_y = min(tf.reduce_min(points[:, :, 1]).numpy(), min_y)
            max_y = max(tf.reduce_max(points[:, :, 1]).numpy(), max_y)
            min_z = min(tf.reduce_min(points[:, :, 2]).numpy(), min_z)
            max_z = max(tf.reduce_max(points[:, :, 2]).numpy(), max_z)
        print(min_d, max_d)
        print(min_x, max_x, min_y, max_y, min_z, max_z)
    elif args.plot_type == 'just_count':
        i = 0
        for _ in tf_dataset:
            i += 1
        print(f'num examples {i}')
    def __init__(self,
                 planner: MyPlanner,
                 trials: List[int],
                 verbose: int,
                 planner_params: Dict,
                 service_provider: BaseServices,
                 no_execution: bool,
                 test_scenes_dir: Optional[pathlib.Path] = None,
                 save_test_scenes_dir: Optional[pathlib.Path] = None,
                 ):
        self.planner = planner
        self.scenario = self.planner.scenario
        self.scenario.on_before_get_state_or_execute_action()
        self.trials = trials
        self.planner_params = planner_params
        self.verbose = verbose
        self.service_provider = service_provider
        self.no_execution = no_execution
        self.env_rng = np.random.RandomState(0)
        self.goal_rng = np.random.RandomState(0)
        self.recovery_rng = np.random.RandomState(0)
        self.test_scenes_dir = test_scenes_dir
        self.save_test_scenes_dir = save_test_scenes_dir
        if self.planner_params['recovery']['use_recovery']:
            recovery_model_dir = pathlib.Path(self.planner_params['recovery']['recovery_model_dir'])
            self.recovery_policy = recovery_policy_utils.load_generic_model(model_dir=recovery_model_dir,
                                                                            scenario=self.scenario,
                                                                            rng=self.recovery_rng)
        else:
            self.recovery_policy = None

        self.n_failures = 0

        # for saving snapshots of the world
        self.link_states_listener = Listener("gazebo/link_states", LinkStates)

        # Debugging
        if self.verbose >= 2:
            self.goal_bbox_pub = rospy.Publisher('goal_bbox', BoundingBox, queue_size=10, latch=True)
            bbox_msg = extent_to_bbox(planner_params['goal_params']['extent'])
            bbox_msg.header.frame_id = 'world'
            self.goal_bbox_pub.publish(bbox_msg)

        goal_params = self.planner_params['goal_params']
        if goal_params['type'] == 'fixed':
            self.goal_generator = lambda e: numpify(goal_params['goal_fixed'])
        elif goal_params['type'] == 'random':
            self.goal_generator = lambda e: self.scenario.sample_goal(environment=e,
                                                                      rng=self.goal_rng,
                                                                      planner_params=self.planner_params)
        elif goal_params['type'] == 'dataset':
            dataset = DynamicsDatasetLoader([pathlib.Path(goal_params['goals_dataset'])])
            tf_dataset = dataset.get_datasets(mode='val')
            goal_dataset_iterator = iter(tf_dataset)

            def _gen(e):
                example = next(goal_dataset_iterator)
                example_t = dataset.index_time_batched(example_batched=add_batch(example), t=1)
                goal = remove_batch(example_t)
                return goal

            self.goal_generator = _gen
        else:
            raise NotImplementedError(f"invalid goal param type {goal_params['type']}")
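        # The three goal_params['type'] values handled above imply roughly these shapes;
        # illustrative sketches only, not complete configurations:
        #   {'type': 'fixed',   'goal_fixed': {...}, 'extent': [...]}
        #   {'type': 'random',  'extent': [...]}
        #   {'type': 'dataset', 'goals_dataset': 'path/to/a/dynamics_dataset', 'extent': [...]}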
def make_recovery_dataset_from_params_dict(dataset_dir: pathlib.Path,
                                           fwd_model_dir,
                                           classifier_model_dir: pathlib.Path,
                                           labeling_params: Dict,
                                           outdir: pathlib.Path,
                                           batch_size: int,
                                           use_gt_rope: bool,
                                           start_at: Optional = None,
                                           stop_at: Optional = None):
    # append "best_checkpoint" before loading
    classifier_model_dir = classifier_model_dir / 'best_checkpoint'
    if not isinstance(fwd_model_dir, List):
        fwd_model_dir = [fwd_model_dir]
    fwd_model_dir = [p / 'best_checkpoint' for p in fwd_model_dir]

    np.random.seed(0)
    tf.random.set_seed(0)

    dynamics_hparams = hjson.load((dataset_dir / 'hparams.hjson').open('r'))
    fwd_model, _ = dynamics_utils.load_generic_model(fwd_model_dir)

    dataset = DynamicsDatasetLoader([dataset_dir], use_gt_rope=use_gt_rope)

    outdir.mkdir(parents=True, exist_ok=True)
    print(Fore.GREEN + f"Making recovery dataset {outdir.as_posix()}")
    new_hparams_filename = outdir / 'hparams.hjson'
    recovery_dataset_hparams = dynamics_hparams

    scenario = fwd_model.scenario
    if not isinstance(classifier_model_dir, List):
        classifier_model_dir = [classifier_model_dir]
    classifier_model = classifier_utils.load_generic_model(
        classifier_model_dir, scenario)

    recovery_dataset_hparams['dataset_dir'] = dataset_dir
    recovery_dataset_hparams['fwd_model_dir'] = fwd_model_dir
    recovery_dataset_hparams['classifier_model'] = classifier_model_dir
    recovery_dataset_hparams['fwd_model_hparams'] = fwd_model.hparams
    recovery_dataset_hparams['labeling_params'] = labeling_params
    recovery_dataset_hparams['state_keys'] = fwd_model.state_keys
    recovery_dataset_hparams['action_keys'] = fwd_model.action_keys
    recovery_dataset_hparams['start-at'] = start_at
    recovery_dataset_hparams['stop-at'] = stop_at
    my_hdump(recovery_dataset_hparams,
             new_hparams_filename.open("w"),
             indent=2)

    start_at = progress_point(start_at)
    stop_at = progress_point(stop_at)

    modes = ['train', 'val', 'test']
    for mode in modes:
        if start_at is not None and modes.index(mode) < modes.index(
                start_at[0]):
            continue
        if stop_at is not None and modes.index(mode) > modes.index(stop_at[0]):
            continue

        tf_dataset_for_mode = dataset.get_datasets(mode=mode)

        full_output_directory = outdir / mode
        full_output_directory.mkdir(parents=True, exist_ok=True)

        # figure out the record_idx to start at
        record_idx = count_up_to_next_record_idx(full_output_directory)

        # FIXME: start_at is not implemented correctly in the sense that it shouldn't be the same
        #  across train/val/test
        for out_example in generate_recovery_examples(
                tf_dataset=tf_dataset_for_mode,
                modes=modes,
                mode=mode,
                fwd_model=fwd_model,
                classifier_model=classifier_model,
                dataset=dataset,
                labeling_params=labeling_params,
                batch_size=batch_size,
                start_at=start_at,
                stop_at=stop_at):
            # FIXME: is there an extra time/batch dimension?
            for batch_idx in range(out_example['traj_idx'].shape[0]):
                out_example_b = index_dict_of_batched_tensors_tf(
                    out_example, batch_idx)

                # # BEGIN DEBUG
                # from link_bot_data.visualization import init_viz_env, recovery_transition_viz_t, init_viz_action
                # from copy import deepcopy
                #
                # viz_out_example_b = deepcopy(out_example_b)
                # recovery_probability = compute_recovery_probabilities(viz_out_example_b['accept_probabilities'],
                #                                                       labeling_params['n_action_samples'])
                # viz_out_example_b['recovery_probability'] = recovery_probability
                # anim = RvizAnimation(scenario=scenario,
                #                      n_time_steps=labeling_params['action_sequence_horizon'],
                #                      init_funcs=[init_viz_env,
                #                                  init_viz_action(dataset.scenario_metadata, fwd_model.action_keys,
                #                                                  fwd_model.state_keys),
                #                                  ],
                #                      t_funcs=[init_viz_env,
                #                               recovery_transition_viz_t(dataset.scenario_metadata,
                #                                                         fwd_model.state_keys),
                #                               lambda s, e, t: scenario.plot_recovery_probability_t(e, t),
                #                               ])
                # anim.play(viz_out_example_b)
                # # END DEBUG

                tf_write_example(full_output_directory, out_example_b,
                                 record_idx)
                record_idx += 1

    return outdir