Exemplo n.º 1
0
    def __init__(self, model_dirs: List[pathlib.Path], batch_size: int,
                 scenario: ExperimentScenario):
        """Restore one network per checkpoint directory in ``model_dirs``.

        Hyperparameters are loaded once, from the trial directory that contains
        ``model_dirs[0]``, and are used for every directory in the list.

        Args:
            model_dirs: checkpoint directories, one network each.
            batch_size: forwarded to ``make_net_and_checkpoint`` and stored.
            scenario: forwarded to ``make_net_and_checkpoint`` and stored.

        Raises:
            RuntimeError: if any directory has no checkpoint to restore.
        """
        _, self.hparams = load_trial(model_dirs[0].parent.absolute())

        self.scenario = scenario
        self.batch_size = batch_size

        dataset_hparams = self.hparams['dynamics_dataset_hparams']
        self.data_collection_params = dataset_hparams['data_collection_params']
        self.state_description = dataset_hparams['state_description']
        self.action_description = dataset_hparams['action_description']

        self.nets: List[MyKerasModel] = []
        for checkpoint_dir in model_dirs:
            net, ckpt = self.make_net_and_checkpoint(batch_size, scenario)
            manager = tf.train.CheckpointManager(ckpt, checkpoint_dir, max_to_keep=1)
            latest = manager.latest_checkpoint

            # The restore is attempted before the existence check; the status is
            # only validated once we know a checkpoint was actually found.
            status = ckpt.restore(latest).expect_partial()
            if not latest:
                raise RuntimeError("Failed to restore!!!")

            print(Fore.CYAN + "Restored from {}".format(latest) + Fore.RESET)
            status.assert_existing_objects_matched()
            self.nets.append(net)
    def __init__(self, paths: List[pathlib.Path], batch_size: int, scenario: ExperimentScenario):
        """Restore one classifier network per checkpoint directory in ``paths``.

        Hyperparameters are loaded once, from the trial directory containing
        ``paths[0]``, and are used to construct every network.

        Args:
            paths: checkpoint directories, one network each.
            batch_size: forwarded to the network constructor.
            scenario: must be a ``Base3DScenario``; forwarded to the base class
                and to the network constructor.

        Raises:
            RuntimeError: if any directory has no checkpoint to restore.
        """
        super().__init__(paths, scenario)
        # FIXME: Bad API design
        assert isinstance(scenario, Base3DScenario)
        representative_model_dir = paths[0]
        _, self.hparams = load_trial(representative_model_dir.parent.absolute())

        self.dataset_labeling_params = self.hparams['classifier_dataset_hparams']['labeling_params']
        self.data_collection_params = self.hparams['classifier_dataset_hparams']['data_collection_params']
        self.horizon = self.dataset_labeling_params['classifier_horizon']

        net_class_name = self.get_net_class()

        self.nets = []
        for model_dir in paths:
            net = net_class_name(hparams=self.hparams, batch_size=batch_size, scenario=scenario)

            ckpt = tf.train.Checkpoint(model=net)
            manager = tf.train.CheckpointManager(ckpt, model_dir, max_to_keep=1)

            status = ckpt.restore(manager.latest_checkpoint).expect_partial()
            if not manager.latest_checkpoint:
                # latest_checkpoint is falsy here, so report the directory we
                # tried rather than interpolating a useless "None".
                raise RuntimeError(f"Failed to restore from {model_dir}!!!")
            print(Fore.CYAN + "Restored from {}".format(manager.latest_checkpoint) + Fore.RESET)
            status.assert_existing_objects_matched()

            self.nets.append(net)

        # Key lists are taken from the last restored network (the original
        # re-assigned them on every iteration, so the last one won there too).
        # NOTE(review): assumes all ensemble members expose identical keys.
        last_net = self.nets[-1]
        self.state_keys = last_net.state_keys
        self.action_keys = last_net.action_keys
        self.true_state_keys = last_net.true_state_keys
        self.pred_state_keys = last_net.pred_state_keys
def load_generic_model(model_dir: pathlib.Path, scenario: ExperimentScenario,
                       rng: np.random.RandomState):
    """Construct the recovery policy selected by the trial's ``model_class`` hparam.

    Hyperparameters come from the trial directory containing ``model_dir``;
    they are passed, together with ``model_dir``, ``scenario`` and ``rng``,
    to the chosen policy's constructor.

    Raises:
        NotImplementedError: if ``model_class`` is not 'simple', 'random' or 'nn'.
    """
    _, hparams = load_trial(model_dir.parent.absolute())

    policy_types = {
        'simple': SimpleRecoveryPolicy,
        'random': RandomRecoveryPolicy,
        'nn': NNRecoveryPolicy,
    }

    model_class = hparams['model_class']
    if model_class not in policy_types:
        raise NotImplementedError(f"model type {model_class} not implemented")
    return policy_types[model_class](hparams, model_dir, scenario, rng)
Exemplo n.º 4
0
def load_filter(model_dirs: List[pathlib.Path],
                scenario: ExperimentScenario = None) -> BaseFilterFunction:
    """Build the filter named by the shared ``model_class`` hparam.

    Only the trial containing ``model_dirs[0]`` is read; its hyperparameters
    are taken to describe every directory. If ``scenario`` is not given, it is
    looked up by the name stored under ``dynamics_dataset_hparams.scenario``.

    Raises:
        NotImplementedError: if ``model_class`` is not 'CFM', 'none' or
            'pass-through'.
    """
    _, common_hparams = load_trial(model_dirs[0].parent.absolute())
    if scenario is None:
        scenario = get_scenario(common_hparams['dynamics_dataset_hparams']['scenario'])

    model_type = common_hparams['model_class']
    if model_type in ['none', 'pass-through']:
        return PassThroughFilter()
    if model_type == 'CFM':
        return CFMFilter(model_dirs, batch_size=1, scenario=scenario)
    raise NotImplementedError("invalid model type {}".format(model_type))
Exemplo n.º 5
0
def viz_main(args):
    """Animate CFM predictions against ground truth for validation examples in RViz.

    Args:
        args: parsed command-line namespace. Reads ``dataset_dirs`` and
            ``checkpoint``; ``checkpoint`` must support ``.parent`` (a path).

    Raises:
        RuntimeError: if no checkpoint is found under ``args.checkpoint``.
    """
    dataset_dirs = args.dataset_dirs
    checkpoint = args.checkpoint

    # Only the hyperparameters are needed; the trial path is discarded.
    _, params = load_trial(checkpoint.parent.absolute())

    dataset = DynamicsDatasetLoader(dataset_dirs)
    scenario = dataset.scenario

    tf_dataset = dataset.get_datasets(mode='val')
    tf_dataset = batch_tf_dataset(tf_dataset,
                                  batch_size=1,
                                  drop_remainder=True)

    model = CFM(hparams=params, batch_size=1, scenario=scenario)
    ckpt = tf.train.Checkpoint(model=model)
    manager = tf.train.CheckpointManager(ckpt, checkpoint, max_to_keep=1)
    status = ckpt.restore(manager.latest_checkpoint).expect_partial()
    if manager.latest_checkpoint:
        # Reset the colour so later terminal output is not left cyan.
        print(Fore.CYAN + "Restored from {}".format(manager.latest_checkpoint) + Fore.RESET)
        status.assert_existing_objects_matched()
    else:
        raise RuntimeError("Failed to restore!!!")

    for example in tf_dataset:
        stepper = RvizAnimationController(n_time_steps=dataset.steps_per_traj)
        # The forward pass depends only on `example`, not on `t`, so run it once
        # per example instead of once per time step.
        # NOTE(review): assumes preprocess/model are deterministic with training=False.
        output = model(model.preprocess_no_gradient(example, training=False))

        for t in range(dataset.steps_per_traj):
            actual_t = numpify(
                remove_batch(scenario.index_time_batched_predicted(example, t)))
            # The original computed `action_t` with the identical call used for
            # `actual_t`; reuse that result.
            # NOTE(review): if a distinct action indexing was intended, revisit.
            action_t = actual_t
            scenario.plot_state_rviz(actual_t, label='actual', color='red')
            scenario.plot_action_rviz(actual_t, action_t, color='gray')

            prediction_t = remove_batch(
                scenario.index_time_batched_predicted(output, t))
            scenario.plot_state_rviz(prediction_t,
                                     label='predicted',
                                     color='blue')

            stepper.step()
Exemplo n.º 6
0
def train_main(args):
    """Train an observer variant of an existing model in a fresh trial directory.

    Loads hyperparameters from the trial containing ``args.checkpoint`` and
    overrides ``encoder_trainable``, ``use_observation_feature_loss`` and
    ``use_cfm_loss``. It writes them to ``params.json`` in a new sibling
    directory named ``<trial>-observer-<timestamp>``, then trains for
    ``args.epochs`` epochs.

    Returns:
        The path of the newly created trial directory.
    """
    dataset_dirs = args.dataset_dirs
    checkpoint = args.checkpoint
    epochs = args.epochs

    source_trial, params = load_trial(checkpoint.parent.absolute())
    stamp = str(time())
    trial_path = source_trial.parent / f"{source_trial.name}-observer-{stamp}"
    trial_path.mkdir(parents=True)

    batch_size = params['batch_size']
    params.update(encoder_trainable=False,
                  use_observation_feature_loss=True,
                  use_cfm_loss=False)
    (trial_path / 'params.json').write_text(json.dumps(params))

    train_dataset = DynamicsDatasetLoader(dataset_dirs)
    val_dataset = DynamicsDatasetLoader(dataset_dirs)

    model = state_space_dynamics.get_model(params['model_class'])(
        hparams=params,
        batch_size=batch_size,
        scenario=train_dataset.scenario,
    )

    runner = ModelRunner(model=model,
                         training=True,
                         params=params,
                         checkpoint=checkpoint,
                         batch_metadata=train_dataset.batch_metadata,
                         trial_path=trial_path)

    train_tf_dataset, val_tf_dataset = train_test.setup_datasets(
        model_hparams=params,
        batch_size=batch_size,
        seed=0,
        train_dataset=train_dataset,
        val_dataset=val_dataset)

    runner.train(train_tf_dataset, val_tf_dataset, num_epochs=epochs)
    return trial_path
Exemplo n.º 7
0
def load_generic_model(
        model_dirs: List[pathlib.Path],
        scenario: Optional[ExperimentScenario] = None
) -> BaseConstraintChecker:
    """Build the constraint checker named by the shared ``model_class`` hparam.

    All directories are assumed to share hyperparameters, so only the trial
    containing ``model_dirs[0]`` is read. If ``scenario`` is omitted it is
    looked up by the name stored under the hparams' ``scenario`` key.

    Raises:
        NotImplementedError: if ``model_class`` is not 'rnn', 'collision',
            'none' or 'gripper_distance'.
    """
    # FIXME: remove batch_size=1 here? can I put it in base model?
    _, common_hparams = load_trial(model_dirs[0].parent.absolute())
    if scenario is None:
        scenario = get_scenario(common_hparams['scenario'])

    # Lambdas defer construction so only the selected checker is built.
    builders = {
        'rnn': lambda: NNClassifierWrapper(model_dirs, batch_size=1, scenario=scenario),
        'collision': lambda: CollisionCheckerClassifier(model_dirs, scenario=scenario),
        'none': lambda: NoneClassifier(model_dirs, scenario=scenario),
        'gripper_distance': lambda: GripperDistanceClassifier(model_dirs, scenario=scenario),
    }

    model_type = common_hparams['model_class']
    build = builders.get(model_type)
    if build is None:
        raise NotImplementedError("invalid model type {}".format(model_type))
    return build()
Exemplo n.º 8
0
def load_generic_model(
        model_dirs: List[pathlib.Path]
) -> Tuple[BaseDynamicsFunction, Tuple[str, ...]]:
    """Build the dynamics wrapper named by the shared ``model_class`` hparam.

    All directories are assumed to share hyperparameters, so only the trial
    containing ``model_dirs[0]`` is read. The scenario is looked up by the name
    stored under ``dynamics_dataset_hparams.scenario``.

    Returns:
        A pair ``(wrapper, parts)`` where ``parts`` is
        ``model_dirs[0].parts[1:]``: the path components of the first model
        directory without the leading one. Its length varies, hence
        ``Tuple[str, ...]`` rather than the 1-tuple ``Tuple[str]``.

    Raises:
        NotImplementedError: if ``model_class`` is not 'SimpleNN',
            'ImageCondDyn' or 'CFM'.
    """
    # FIXME: remove batch_size=1 here? can I put it in base model?
    # we use the first model and assume they all have the same hparams
    representative_model_dir = model_dirs[0]
    _, common_hparams = load_trial(representative_model_dir.parent.absolute())
    scenario_name = common_hparams['dynamics_dataset_hparams']['scenario']
    scenario = get_scenario(scenario_name)
    model_type = common_hparams['model_class']

    # Every supported wrapper shares the same constructor signature, so pick
    # the class first and construct it in one place.
    if model_type == 'SimpleNN':
        wrapper_class = UDNNWrapper
    elif model_type == 'ImageCondDyn':
        wrapper_class = ImageCondDynamicsWrapper
    elif model_type == 'CFM':
        wrapper_class = CFMLatentDynamics
    else:
        raise NotImplementedError("invalid model type {}".format(model_type))

    nn = wrapper_class(model_dirs, batch_size=1, scenario=scenario)
    return nn, representative_model_dir.parts[1:]