Example #1
    def test_get_policy_model_files(self):
        output_dir = self.get_temp_dir()

        def write_policy_model_file(epoch):
            with gfile.GFile(
                    ppo.get_policy_model_file_from_epoch(output_dir, epoch),
                    'w') as f:
                f.write('some data')

        epochs = [200, 100, 300]

        # Expect newest epoch first: 300, 200, 100.
        expected_policy_model_files = [
            output_dir + '/model-000300.pkl',
            output_dir + '/model-000200.pkl',
            output_dir + '/model-000100.pkl',
        ]

        for epoch in epochs:
            write_policy_model_file(epoch)

        policy_model_files = ppo.get_policy_model_files(output_dir)

        self.assertEqual(expected_policy_model_files, policy_model_files)

        gfile.rmtree(output_dir)
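
The test pins down the contract of ppo.get_policy_model_files: it returns the model-<epoch>.pkl files under output_dir sorted newest epoch first. Below is a hypothetical reimplementation for illustration only (it assumes tf.io.gfile and the file-name scheme shown above); it is not trax's actual code.

import os
import re

from tensorflow.io import gfile


def get_policy_model_files(output_dir):
    """Hypothetical sketch: list model-<epoch>.pkl files, newest epoch first."""
    files = gfile.glob(os.path.join(output_dir, "model-*.pkl"))

    def epoch_of(path):
        # "model-000300.pkl" -> 300.
        return int(re.search(r"model-(\d+)\.pkl$", path).group(1))

    return sorted(files, key=epoch_of, reverse=True)
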
Example #2
 @contextlib.contextmanager  # Assumed decorator; implied by the yield/cleanup pattern.
 def tmp_dir(self):
   tmp = tempfile.mkdtemp(dir=test.get_temp_dir())
   yield tmp
   gfile.rmtree(tmp)
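
A self-contained variant of this fixture, with the assumed @contextlib.contextmanager decorator spelled out and the cleanup moved into a finally block so the directory is removed even when the body raises:

import contextlib
import os
import tempfile

from tensorflow.io import gfile


@contextlib.contextmanager
def scratch_dir():
    tmp = tempfile.mkdtemp()
    try:
        yield tmp
    finally:
        gfile.rmtree(tmp)  # Runs even if the block below raises.


with scratch_dir() as tmp:
    with gfile.GFile(os.path.join(tmp, "data.txt"), "w") as f:
        f.write("some data")
# The directory and its contents are gone here.
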
Example #3
  def export(self, path, session, overwrite=False):
    """Build the TF-Hub spec, module and sync ops."""

    method_specs = {}

    def module_fn():
      """A module_fn for use with hub.create_module_spec()."""
      # We will use a copy of the original object to build the graph.
      wrapped_object = self._object_factory()

      for method_name, method_info in self._captured_calls.items():
        captured_inputs, captured_specs = method_info
        tensor_inputs = nest.map_structure(_to_placeholder, captured_inputs)
        method_to_call = getattr(wrapped_object, method_name)
        tensor_outputs = method_to_call(**tensor_inputs)

        flat_tensor_inputs = nest.flatten(tensor_inputs)
        flat_tensor_inputs = {
            str(k): v for k, v in zip(
                range(len(flat_tensor_inputs)), flat_tensor_inputs)
        }
        flat_tensor_outputs = nest.flatten(tensor_outputs)
        flat_tensor_outputs = {
            str(k): v for k, v in zip(
                range(len(flat_tensor_outputs)), flat_tensor_outputs)
        }

        method_specs[method_name] = dict(
            specs=captured_specs,
            inputs=nest.map_structure(lambda _: None, tensor_inputs),
            outputs=nest.map_structure(lambda _: None, tensor_outputs))

        signature_name = ("default"
                          if method_name == "__call__" else method_name)
        hub.add_signature(signature_name, flat_tensor_inputs,
                          flat_tensor_outputs)

      hub.attach_message(
          "methods", tf.train.BytesList(value=[pickle.dumps(method_specs)]))
      hub.attach_message(
          "properties",
          tf.train.BytesList(value=[pickle.dumps(self._captured_attrs)]))

    # Create the spec that will be later used in export.
    hub_spec = hub.create_module_spec(module_fn, drop_collections=["sonnet"])

    # Read the current variable values out of the caller's session.
    module_weights = [
        session.run(v) for v in self._wrapped_object.get_all_variables()
    ]

    # Create the sync ops that copy the weights into the fresh graph.
    with tf.Graph().as_default():
      hub_module = hub.Module(hub_spec, trainable=True, name="hub")

      assign_ops = []
      assign_phs = []
      for _, v in sorted(hub_module.variable_map.items()):
        ph = tf.placeholder(shape=v.shape, dtype=v.dtype)
        assign_phs.append(ph)
        assign_ops.append(tf.assign(v, ph))

      with tf.Session() as module_session:
        module_session.run(tf.local_variables_initializer())
        module_session.run(tf.global_variables_initializer())
        module_session.run(
            assign_ops, feed_dict=dict(zip(assign_phs, module_weights)))

        if overwrite and gfile.exists(path):
          gfile.rmtree(path)
        gfile.makedirs(path)
        hub_module.export(path, module_session)
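
The sync step in the middle of export() is a common TF1 pattern: build one placeholder-fed assign op per variable, then push all numpy weights into the fresh graph with a single session.run. A minimal self-contained sketch with made-up variable names and shapes:

import numpy as np
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

# Stand-ins for the weights read out of the caller's session above.
module_weights = [np.ones((2, 2), np.float32), np.zeros((2,), np.float32)]

with tf.Graph().as_default():
    variables = [
        tf.get_variable("w", shape=(2, 2)),
        tf.get_variable("b", shape=(2,)),
    ]
    assign_phs = [tf.placeholder(v.dtype.base_dtype, v.shape) for v in variables]
    assign_ops = [tf.assign(v, ph) for v, ph in zip(variables, assign_phs)]

    with tf.Session() as module_session:
        module_session.run(tf.global_variables_initializer())
        # One run() call syncs every variable from the numpy arrays.
        module_session.run(
            assign_ops, feed_dict=dict(zip(assign_phs, module_weights)))
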
Example #4
    def test_load_from_directory(self):
        output_dir = self.get_temp_dir()

        epochs = [0, 1, 2]
        env_ids = [0, 1, 2]
        temperatures = [0.5, 1.0]
        random_strings = ["a", "b"]

        # Write some trajectories.
        # There are 3x3x2x2 (36) trajectories in total; the 3x2x2 (12) with
        # epoch == 0 are marked done.
        for epoch in epochs:
            for env_id in env_ids:
                for temperature in temperatures:
                    for random_string in random_strings:
                        traj = trajectory.Trajectory(time_steps=[
                            time_step.TimeStep(observation=epoch,
                                               done=(epoch == 0),
                                               raw_reward=1.0,
                                               processed_reward=1.0,
                                               action=env_id,
                                               info={})
                        ])

                        trajectory_file_name = trajectory.TRAJECTORY_FILE_FORMAT.format(
                            epoch=epoch,
                            env_id=env_id,
                            temperature=temperature,
                            r=random_string)

                        with gfile.GFile(
                                os.path.join(output_dir, trajectory_file_name),
                                "w") as f:
                            trajectory.get_pickle_module().dump(traj, f)

        # Load everything and check.
        bt = trajectory.BatchTrajectory.load_from_directory(output_dir)

        self.assertIsInstance(bt, trajectory.BatchTrajectory)
        self.assertEqual(36, bt.num_completed_trajectories)
        self.assertEqual(36, bt.batch_size)

        bt = trajectory.BatchTrajectory.load_from_directory(output_dir,
                                                            epoch=0)
        self.assertEqual(12, bt.num_completed_trajectories)
        self.assertEqual(12, bt.batch_size)

        # Request 100 trajectories; only 12 exist and max_tries=0, so loading fails.
        bt = trajectory.BatchTrajectory.load_from_directory(output_dir,
                                                            epoch=0,
                                                            n_trajectories=100,
                                                            max_tries=0)
        self.assertIsNone(bt)

        bt = trajectory.BatchTrajectory.load_from_directory(output_dir,
                                                            epoch=0,
                                                            temperature=0.5)
        self.assertEqual(6, bt.num_completed_trajectories)
        self.assertEqual(6, bt.batch_size)

        bt = trajectory.BatchTrajectory.load_from_directory(output_dir,
                                                            epoch=1)
        self.assertEqual(12, bt.num_completed_trajectories)
        self.assertEqual(12, bt.batch_size)

        # Constraints cannot be satisfied without up-sampling.
        bt = trajectory.BatchTrajectory.load_from_directory(output_dir,
                                                            epoch=1,
                                                            n_trajectories=100,
                                                            up_sample=False,
                                                            max_tries=0)
        self.assertIsNone(bt)

        # Constraints can be satisfied by up-sampling.
        bt = trajectory.BatchTrajectory.load_from_directory(output_dir,
                                                            epoch=1,
                                                            n_trajectories=100,
                                                            up_sample=True,
                                                            max_tries=0)
        self.assertEqual(100, bt.num_completed_trajectories)
        self.assertEqual(100, bt.batch_size)

        bt = trajectory.BatchTrajectory.load_from_directory(output_dir,
                                                            epoch=1,
                                                            n_trajectories=10)
        self.assertEqual(10, bt.num_completed_trajectories)
        self.assertEqual(10, bt.batch_size)

        gfile.rmtree(output_dir)
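
In isolation, the gfile-plus-pickle round trip this test relies on looks like the sketch below; standard pickle stands in for trajectory.get_pickle_module(), which is an assumption:

import os
import pickle
import tempfile

from tensorflow.io import gfile

tmp = tempfile.mkdtemp()
path = os.path.join(tmp, "traj.pkl")
with gfile.GFile(path, "wb") as f:
    pickle.dump({"observation": 0, "done": True}, f)
with gfile.GFile(path, "rb") as f:
    restored = pickle.load(f)
assert restored["done"]
gfile.rmtree(tmp)  # Same cleanup the tests above end with.
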
Example #5
 @contextlib.contextmanager  # Assumed decorator, as in Example #2 above.
 def tmp_dir(self):
   tmp = tempfile.mkdtemp()
   yield tmp
   gfile.rmtree(tmp)