def __init__(self, model_dirs: List[pathlib.Path], batch_size: int, scenario: ExperimentScenario):
    """Load an ensemble of dynamics networks, one per checkpoint directory.

    Hyperparameters are read from the first model dir; all models are
    assumed to share the same hparams.

    :raises RuntimeError: if any directory has no restorable checkpoint.
    """
    first_model_dir = model_dirs[0]
    _, self.hparams = load_trial(first_model_dir.parent.absolute())

    self.scenario = scenario
    self.batch_size = batch_size

    dataset_hparams = self.hparams['dynamics_dataset_hparams']
    self.data_collection_params = dataset_hparams['data_collection_params']
    self.state_description = dataset_hparams['state_description']
    self.action_description = dataset_hparams['action_description']

    self.nets: List[MyKerasModel] = []
    for model_dir in model_dirs:
        net, ckpt = self.make_net_and_checkpoint(batch_size, scenario)
        manager = tf.train.CheckpointManager(ckpt, model_dir, max_to_keep=1)
        # expect_partial(): only model variables need to be present in the checkpoint
        status = ckpt.restore(manager.latest_checkpoint).expect_partial()
        if not manager.latest_checkpoint:
            raise RuntimeError("Failed to restore!!!")
        print(Fore.CYAN + "Restored from {}".format(manager.latest_checkpoint) + Fore.RESET)
        status.assert_existing_objects_matched()
        self.nets.append(net)
def __init__(self, paths: List[pathlib.Path], batch_size: int, scenario: ExperimentScenario):
    """Load an ensemble of classifier networks, one per checkpoint directory.

    Hyperparameters are read from the first path; all models are assumed to
    share the same hparams. The state/action key attributes are taken from
    the last network loaded (they are identical across the ensemble).

    :raises RuntimeError: if any directory has no restorable checkpoint.
    """
    super().__init__(paths, scenario)
    # FIXME: Bad API design
    assert isinstance(scenario, Base3DScenario)
    representative_model_dir = paths[0]
    _, self.hparams = load_trial(representative_model_dir.parent.absolute())
    self.dataset_labeling_params = self.hparams['classifier_dataset_hparams']['labeling_params']
    self.data_collection_params = self.hparams['classifier_dataset_hparams']['data_collection_params']
    self.horizon = self.dataset_labeling_params['classifier_horizon']
    net_class_name = self.get_net_class()
    self.nets = []
    for model_dir in paths:
        net = net_class_name(hparams=self.hparams, batch_size=batch_size, scenario=scenario)
        ckpt = tf.train.Checkpoint(model=net)
        manager = tf.train.CheckpointManager(ckpt, model_dir, max_to_keep=1)
        # expect_partial(): only model variables need to be present in the checkpoint
        status = ckpt.restore(manager.latest_checkpoint).expect_partial()
        # Merged the previously-duplicated `if manager.latest_checkpoint:` checks.
        if manager.latest_checkpoint:
            print(Fore.CYAN + "Restored from {}".format(manager.latest_checkpoint) + Fore.RESET)
            status.assert_existing_objects_matched()
        else:
            # latest_checkpoint is None here, so report the directory we tried instead
            raise RuntimeError(f"Failed to restore from {model_dir}!!!")
        self.nets.append(net)
    self.state_keys = net.state_keys
    self.action_keys = net.action_keys
    self.true_state_keys = net.true_state_keys
    self.pred_state_keys = net.pred_state_keys
def load_generic_model(model_dir: pathlib.Path, scenario: ExperimentScenario, rng: np.random.RandomState):
    """Instantiate the recovery policy named by the saved 'model_class' hparam.

    :raises NotImplementedError: for an unrecognized model class.
    """
    _, hparams = load_trial(model_dir.parent.absolute())
    # dispatch table: hparam value -> policy class
    policy_classes = {
        'simple': SimpleRecoveryPolicy,
        'random': RandomRecoveryPolicy,
        'nn': NNRecoveryPolicy,
    }
    model_class = hparams['model_class']
    policy_class = policy_classes.get(model_class)
    if policy_class is None:
        raise NotImplementedError(f"model type {model_class} not implemented")
    return policy_class(hparams, model_dir, scenario, rng)
def load_filter(model_dirs: List[pathlib.Path], scenario: ExperimentScenario = None) -> BaseFilterFunction:
    """Build a filter function from saved model checkpoints.

    The first model dir supplies the hparams; when no scenario is given it is
    reconstructed from the saved dataset hparams.

    :raises NotImplementedError: for an unrecognized model class.
    """
    _, common_hparams = load_trial(model_dirs[0].parent.absolute())
    if scenario is None:
        scenario = get_scenario(common_hparams['dynamics_dataset_hparams']['scenario'])
    model_type = common_hparams['model_class']
    if model_type == 'CFM':
        return CFMFilter(model_dirs, batch_size=1, scenario=scenario)
    if model_type in ('none', 'pass-through'):
        return PassThroughFilter()
    raise NotImplementedError("invalid model type {}".format(model_type))
def viz_main(args):
    """Visualize CFM predictions against recorded data in RViz.

    Restores the model from the latest checkpoint under args.checkpoint, then
    steps through each validation trajectory, plotting the recorded state in
    red and the model's prediction in blue.
    """
    dataset_dirs = args.dataset_dirs
    checkpoint = args.checkpoint
    trial_path, params = load_trial(checkpoint.parent.absolute())
    dataset = DynamicsDatasetLoader(dataset_dirs)
    scenario = dataset.scenario
    tf_dataset = dataset.get_datasets(mode='val')
    # batch_size=1 so each element is a single trajectory
    tf_dataset = batch_tf_dataset(tf_dataset, batch_size=1, drop_remainder=True)
    model = CFM(hparams=params, batch_size=1, scenario=scenario)
    ckpt = tf.train.Checkpoint(model=model)
    manager = tf.train.CheckpointManager(ckpt, args.checkpoint, max_to_keep=1)
    # expect_partial(): only model variables need to be present in the checkpoint
    status = ckpt.restore(manager.latest_checkpoint).expect_partial()
    if manager.latest_checkpoint:
        print(Fore.CYAN + "Restored from {}".format(manager.latest_checkpoint))
        status.assert_existing_objects_matched()
    else:
        raise RuntimeError("Failed to restore!!!")
    for example_idx, example in enumerate(tf_dataset):
        stepper = RvizAnimationController(n_time_steps=dataset.steps_per_traj)
        for t in range(dataset.steps_per_traj):
            output = model(model.preprocess_no_gradient(example, training=False))
            actual_t = numpify(remove_batch(scenario.index_time_batched_predicted(example, t)))
            # NOTE(review): action_t is computed with the exact same call as actual_t —
            # this looks like it should index the action sequence instead; confirm.
            action_t = numpify(remove_batch(scenario.index_time_batched_predicted(example, t)))
            scenario.plot_state_rviz(actual_t, label='actual', color='red')
            scenario.plot_action_rviz(actual_t, action_t, color='gray')
            prediction_t = remove_batch(scenario.index_time_batched_predicted(output, t))
            scenario.plot_state_rviz(prediction_t, label='predicted', color='blue')
            stepper.step()
def train_main(args):
    """Fine-tune a dynamics model as an observer.

    Loads the hparams of an existing trial, freezes the encoder, swaps the CFM
    loss for the observation-feature loss, and trains for args.epochs epochs in
    a new timestamped trial directory.

    Returns the new trial directory path.
    """
    dataset_dirs = args.dataset_dirs
    checkpoint = args.checkpoint
    epochs = args.epochs

    trial_path, params = load_trial(checkpoint.parent.absolute())
    # fresh, timestamped output dir so the original trial is left untouched
    now = str(time())
    trial_path = trial_path.parent / (trial_path.name + '-observer-' + now)
    trial_path.mkdir(parents=True)

    batch_size = params['batch_size']

    # observer-training configuration: frozen encoder, observation loss only
    params['encoder_trainable'] = False
    params['use_observation_feature_loss'] = True
    params['use_cfm_loss'] = False

    # persist the modified hparams alongside the new trial
    with (trial_path / 'params.json').open("w") as out_hparams_file:
        out_hparams_file.write(json.dumps(params))

    train_dataset = DynamicsDatasetLoader(dataset_dirs)
    val_dataset = DynamicsDatasetLoader(dataset_dirs)

    model_class = state_space_dynamics.get_model(params['model_class'])
    model = model_class(hparams=params, batch_size=batch_size, scenario=train_dataset.scenario)

    runner = ModelRunner(model=model,
                         training=True,
                         params=params,
                         checkpoint=checkpoint,
                         batch_metadata=train_dataset.batch_metadata,
                         trial_path=trial_path)

    train_tf_dataset, val_tf_dataset = train_test.setup_datasets(model_hparams=params,
                                                                 batch_size=batch_size,
                                                                 seed=0,
                                                                 train_dataset=train_dataset,
                                                                 val_dataset=val_dataset)
    runner.train(train_tf_dataset, val_tf_dataset, num_epochs=epochs)
    return trial_path
def load_generic_model(model_dirs: List[pathlib.Path],
                       scenario: Optional[ExperimentScenario] = None) -> BaseConstraintChecker:
    """Build a constraint checker (classifier) from saved checkpoints.

    :raises NotImplementedError: for an unrecognized model class.
    """
    # FIXME: remove batch_size=1 here? can I put it in base model?
    # we use the first model and assume they all have the same hparams
    _, common_hparams = load_trial(model_dirs[0].parent.absolute())
    if scenario is None:
        scenario = get_scenario(common_hparams['scenario'])
    model_type = common_hparams['model_class']
    if model_type == 'rnn':
        return NNClassifierWrapper(model_dirs, batch_size=1, scenario=scenario)
    if model_type == 'collision':
        return CollisionCheckerClassifier(model_dirs, scenario=scenario)
    if model_type == 'none':
        return NoneClassifier(model_dirs, scenario=scenario)
    if model_type == 'gripper_distance':
        return GripperDistanceClassifier(model_dirs, scenario=scenario)
    raise NotImplementedError("invalid model type {}".format(model_type))
def load_generic_model(model_dirs: List[pathlib.Path]) -> Tuple[BaseDynamicsFunction, Tuple[str]]:
    """Build a dynamics function from saved checkpoints.

    Returns the wrapper network plus the model-relative path components of the
    first checkpoint directory.

    :raises NotImplementedError: for an unrecognized model class.
    """
    # FIXME: remove batch_size=1 here? can I put it in base model?
    # we use the first model and assume they all have the same hparams
    representative_model_dir = model_dirs[0]
    _, common_hparams = load_trial(representative_model_dir.parent.absolute())
    scenario = get_scenario(common_hparams['dynamics_dataset_hparams']['scenario'])
    model_type = common_hparams['model_class']
    # dispatch table: hparam value -> wrapper class
    wrapper_classes = {
        'SimpleNN': UDNNWrapper,
        'ImageCondDyn': ImageCondDynamicsWrapper,
        'CFM': CFMLatentDynamics,
    }
    wrapper_class = wrapper_classes.get(model_type)
    if wrapper_class is None:
        raise NotImplementedError("invalid model type {}".format(model_type))
    nn = wrapper_class(model_dirs, batch_size=1, scenario=scenario)
    return nn, representative_model_dir.parts[1:]