Example #1
    def testSaveGetInitialState(self):
        if not tf.executing_eagerly():
            self.skipTest(
                'b/129079730: PolicySaver does not work in TF1.x yet')

        network = q_rnn_network.QRnnNetwork(
            input_tensor_spec=self._time_step_spec.observation,
            action_spec=self._action_spec)

        policy = q_policy.QPolicy(time_step_spec=self._time_step_spec,
                                  action_spec=self._action_spec,
                                  q_network=network)

        saver_nobatch = policy_saver.PolicySaver(policy, batch_size=None)
        path = os.path.join(self.get_temp_dir(),
                            'save_model_initial_state_nobatch')
        saver_nobatch.save(path)
        reloaded_nobatch = tf.compat.v2.saved_model.load(path)
        self.assertIn('get_initial_state', reloaded_nobatch.signatures)
        reloaded_get_initial_state = (
            reloaded_nobatch.signatures['get_initial_state'])
        self._compare_input_output_specs(
            reloaded_get_initial_state,
            expected_input_specs=(tf.TensorSpec(dtype=tf.int32,
                                                shape=(),
                                                name='batch_size'), ),
            expected_output_spec=policy.policy_state_spec,
            batch_input=False,
            batch_size=None)

        initial_state = policy.get_initial_state(batch_size=3)
        initial_state = self.evaluate(initial_state)

        reloaded_nobatch_initial_state = reloaded_nobatch.get_initial_state(
            batch_size=3)
        reloaded_nobatch_initial_state = self.evaluate(
            reloaded_nobatch_initial_state)
        tf.nest.map_structure(self.assertAllClose, initial_state,
                              reloaded_nobatch_initial_state)

        saver_batch = policy_saver.PolicySaver(policy, batch_size=3)
        path = os.path.join(self.get_temp_dir(),
                            'save_model_initial_state_batch')
        saver_batch.save(path)
        reloaded_batch = tf.compat.v2.saved_model.load(path)
        self.assertIn('get_initial_state', reloaded_batch.signatures)
        reloaded_get_initial_state = reloaded_batch.signatures[
            'get_initial_state']
        self._compare_input_output_specs(
            reloaded_get_initial_state,
            expected_input_specs=(),
            expected_output_spec=policy.policy_state_spec,
            batch_input=False,
            batch_size=3)

        reloaded_batch_initial_state = reloaded_batch.get_initial_state()
        reloaded_batch_initial_state = self.evaluate(
            reloaded_batch_initial_state)
        tf.nest.map_structure(self.assertAllClose, initial_state,
                              reloaded_batch_initial_state)
Example #2
    def testVariablesAccessible(self):
        network = q_network.QNetwork(
            input_tensor_spec=self._time_step_spec.observation,
            action_spec=self._action_spec)

        policy = q_policy.QPolicy(time_step_spec=self._time_step_spec,
                                  action_spec=self._action_spec,
                                  q_network=network)

        saver = policy_saver.PolicySaver(policy, batch_size=None)
        path = os.path.join(self.get_temp_dir(), 'save_model')

        self.evaluate(tf.compat.v1.global_variables_initializer())
        saver.save(path)
        reloaded = tf.compat.v2.saved_model.load(path)
        self.evaluate(
            tf.compat.v1.initializers.variables(reloaded.model_variables))

        model_variables = self.evaluate(policy.variables())
        reloaded_model_variables = self.evaluate(reloaded.model_variables)

        assert_np_all_equal = lambda a, b: self.assertTrue(
            np.equal(a, b).all())
        tf.nest.map_structure(assert_np_all_equal, model_variables,
                              reloaded_model_variables)
Example #3
    def testTrainStepSaved(self):
        if not tf.executing_eagerly():
            self.skipTest(
                'b/129079730: PolicySaver does not work in TF1.x yet')
        network = q_network.QNetwork(
            input_tensor_spec=self._time_step_spec.observation,
            action_spec=self._action_spec)

        policy = q_policy.QPolicy(time_step_spec=self._time_step_spec,
                                  action_spec=self._action_spec,
                                  q_network=network)

        train_step = common.create_variable('train_step', initial_value=7)
        saver = policy_saver.PolicySaver(policy,
                                         batch_size=None,
                                         train_step=train_step)
        path = os.path.join(self.get_temp_dir(), 'save_model')
        saver.save(path)
        reloaded = tf.compat.v2.saved_model.load(path)

        self.assertIn('get_train_step', reloaded.signatures)
        train_step_value = self.evaluate(reloaded.train_step())
        self.assertEqual(7, train_step_value)

        train_step = train_step.assign_add(3)
        saver.save(path)
        reloaded = tf.compat.v2.saved_model.load(path)

        train_step_value = self.evaluate(reloaded.train_step())
        self.assertEqual(10, train_step_value)
Example #4
    def testUpdateWithCompositeSavedModelAndCheckpoint(self):
        # Create a saved_model for a q_policy.
        network = q_network.QNetwork(
            input_tensor_spec=self._time_step_spec.observation,
            action_spec=self._action_spec)

        policy = q_policy.QPolicy(time_step_spec=self._time_step_spec,
                                  action_spec=self._action_spec,
                                  q_network=network)

        train_step = common.create_variable('train_step', initial_value=0)
        saver = policy_saver.PolicySaver(policy,
                                         train_step=train_step,
                                         batch_size=None)
        full_model_path = os.path.join(self.get_temp_dir(), 'save_model')

        def assert_val_equal_var(val, var):
            self.assertTrue(np.array_equal(np.full_like(var, val), var))

        self.evaluate(tf.compat.v1.global_variables_initializer())
        # Set all variables in the saved model to 1
        variables = policy.variables()
        self.evaluate(
            tf.nest.map_structure(lambda v: v.assign(v * 0 + 1), variables))
        for v in self.evaluate(variables):
            assert_val_equal_var(1, v)
        with self.cached_session():
            saver.save(full_model_path)

        # Assign 2 to all variables in the policy, making the checkpoint
        # different from the initial saved_model.
        self.evaluate(
            tf.nest.map_structure(lambda v: v.assign(v * 0 + 2), variables))
        for v in self.evaluate(variables):
            assert_val_equal_var(2, v)
        checkpoint_path = os.path.join(self.get_temp_dir(), 'checkpoint')
        with self.cached_session():
            saver.save_checkpoint(checkpoint_path)

        # Reload the full model and check all variables are 1
        reloaded_policy = tf.compat.v2.saved_model.load(full_model_path)
        self.evaluate(
            tf.compat.v1.initializers.variables(
                reloaded_policy.model_variables))
        for v in self.evaluate(reloaded_policy.model_variables):
            assert_val_equal_var(1, v)

        # Compose a new full saved model from the original saved model files
        # and variables from the checkpoint.
        composite_path = os.path.join(self.get_temp_dir(), 'composite_model')
        self.copy_tree(full_model_path, composite_path, skip_variables=True)
        self.copy_tree(checkpoint_path, os.path.join(composite_path))

        # Reload the composite model and check all variables are 2
        reloaded_policy = tf.compat.v2.saved_model.load(composite_path)
        self.evaluate(
            tf.compat.v1.initializers.variables(
                reloaded_policy.model_variables))
        for v in self.evaluate(reloaded_policy.model_variables):
            assert_val_equal_var(2, v)
Example #5
    def __init__(
        self, batch_size,
        action_spec,
        time_step_spec,
        n_iterations,
        replay_buffer_max_length,
        learning_rate=1e-3,
        checkpoint_dir=None
    ):
        self.batch_size = batch_size
        self.time_step_spec = time_step_spec
        self.action_spec = action_spec
        observation_spec = self.time_step_spec.observation

        self.actor_net = HierachyActorNetwork(
            observation_spec,
            action_spec,
            n_iterations,
            n_options=4
        )
        value_net = value_network.ValueNetwork(
            observation_spec,
            fc_layer_params=(100,)
        )

        optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)
        self.global_step = tf.compat.v1.train.get_or_create_global_step()

        self.agent = ppo_agent.PPOAgent(
            time_step_spec,
            self.action_spec,
            actor_net=self.actor_net,
            value_net=value_net,
            optimizer=optimizer,
            normalize_rewards=True,
            normalize_observations=False,
            train_step_counter=self.global_step
        )
        self.agent.initialize()
        self.agent.train = common.function(self.agent.train)

        self.replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
            data_spec=self.agent.collect_data_spec,
            batch_size=self.batch_size,
            max_length=replay_buffer_max_length
        )

        self.train_checkpointer = None
        if checkpoint_dir:
            self.train_checkpointer = common.Checkpointer(
                ckpt_dir=checkpoint_dir,
                max_to_keep=1,
                agent=self.agent,
                policy=self.agent.policy,
                replay_buffer=self.replay_buffer,
                global_step=self.global_step
            )
            self.train_checkpointer.initialize_or_restore()

        self.policy_saver = policy_saver.PolicySaver(self.agent.policy)
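
A minimal companion sketch (not part of the original example): a save method that could sit alongside the __init__ above, persisting both the checkpoint and the SavedModel policy. The method name and the save_dir argument are assumptions.

    def save(self, save_dir):
        # Persist agent, policy and replay-buffer state if a checkpointer was configured.
        if self.train_checkpointer:
            self.train_checkpointer.save(self.global_step)
        # Export the current policy as a SavedModel.
        self.policy_saver.save(os.path.join(save_dir, 'policy'))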
Example #6
    def testPolicySaverCompatibility(self):
        observation_spec = {
            'a': tf.TensorSpec(4, tf.float32),
            'b': tf.TensorSpec(3, tf.float32)
        }
        time_step_tensor_spec = ts.time_step_spec(observation_spec)
        net = nest_map.NestMap({
            'a':
            tf.keras.layers.LSTM(8, return_state=True, return_sequences=True),
            'b':
            tf.keras.layers.Dense(8)
        })
        net.create_variables(observation_spec)
        policy = MyPolicy(time_step_tensor_spec, net)

        sample = tensor_spec.sample_spec_nest(time_step_tensor_spec,
                                              outer_dims=(5, ))

        step = policy.action(sample)
        self.assertEqual(step.action.shape.as_list(), [5, 8])

        train_step = common.create_variable('train_step')
        saver = policy_saver.PolicySaver(policy, train_step=train_step)
        self.initialize_v1_variables()

        with self.cached_session():
            saver.save(os.path.join(FLAGS.test_tmpdir, 'nest_map_model'))
Example #7
def save_model():

  optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=1e-3)
  obs_spec = TensorSpec((7,), dtype=tf.float32, name='observation')
  action_spec = BoundedTensorSpec((1,), dtype=tf.int32, minimum=0, maximum=3, name='action')
  actor_net = ActorDistributionRnnNetwork(obs_spec, action_spec, lstm_size=(100, 100))
  value_net = ValueRnnNetwork(obs_spec)
  agent = ppo_agent.PPOAgent(
    time_step_spec=time_step_spec(obs_spec),
    action_spec=action_spec,
    optimizer=optimizer,
    actor_net=actor_net,
    value_net=value_net,
    normalize_observations=True,
    normalize_rewards=True,
    use_gae=True,
    num_epochs=1,
  )
  checkpointer = Checkpointer(
    ckpt_dir='checkpoints/policy',
    max_to_keep=1,
    agent=agent,
    policy=agent.policy,
    global_step=tf.compat.v1.train.get_or_create_global_step())
  checkpointer.initialize_or_restore()
  saver = policy_saver.PolicySaver(agent.policy)
  saver.save('final_policy')
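
A hedged sketch, not in the original snippet, of how the SavedModel written by save_model() could be reloaded and queried. The helper name is hypothetical, the zero observation is only a placeholder matching the 7-dimensional float spec, and batch size 1 is assumed.

import tensorflow as tf
from tf_agents.trajectories import time_step as ts

def load_and_query_model():
  loaded_policy = tf.saved_model.load('final_policy')
  # Build a batched TimeStep (batch size 1) that matches the observation spec above.
  observation = tf.zeros([1, 7], dtype=tf.float32)
  time_step = ts.restart(observation, batch_size=1)
  # The actor network is recurrent, so fetch an initial policy state before acting.
  policy_state = loaded_policy.get_initial_state(batch_size=1)
  action_step = loaded_policy.action(time_step, policy_state)
  print(action_step.action)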
Example #8
  def testMetadataSaved(self):
    # We need to use one default session so that self.evaluate and the
    # SavedModel loader share the same session.
    with tf.compat.v1.Session().as_default():
      network = q_network.QNetwork(
          input_tensor_spec=self._time_step_spec.observation,
          action_spec=self._action_spec)

      policy = q_policy.QPolicy(
          time_step_spec=self._time_step_spec,
          action_spec=self._action_spec,
          q_network=network)
      self.evaluate(tf.compat.v1.initializers.variables(policy.variables()))

      train_step = common.create_variable('train_step', initial_value=1)
      env_step = common.create_variable('env_step', initial_value=7)
      metadata = {'env_step': env_step}
      self.evaluate(tf.compat.v1.initializers.variables([train_step, env_step]))

      saver = policy_saver.PolicySaver(
          policy, batch_size=None, train_step=train_step, metadata=metadata)
      if tf.executing_eagerly():
        loaded_metadata = saver.get_metadata()
      else:
        loaded_metadata = self.evaluate(saver.get_metadata())
      self.assertEqual(self.evaluate(metadata), loaded_metadata)

      path = os.path.join(self.get_temp_dir(), 'save_model')
      saver.save(path)

      reloaded = tf.compat.v2.saved_model.load(path)
      self.evaluate(tf.compat.v1.global_variables_initializer())
      self.assertIn('get_metadata', reloaded.signatures)
      env_step_value = self.evaluate(reloaded.get_metadata())['env_step']
      self.assertEqual(7, env_step_value)
Example #9
  def testGetTrainStep(self, train_step):
    path = os.path.join(self.get_temp_dir(), 'saved_policy')
    if train_step is None:
      # Use the default argument, which should set the train step to be -1.
      saver = policy_saver.PolicySaver(self.tf_policy)
      expected_train_step = -1
    else:
      saver = policy_saver.PolicySaver(
          self.tf_policy, train_step=tf.constant(train_step))
      expected_train_step = train_step
    saver.save(path)

    eager_py_policy = py_tf_eager_policy.SavedModelPyTFEagerPolicy(
        path, self.time_step_spec, self.action_spec)

    self.assertEqual(expected_train_step, eager_py_policy.get_train_step())
Example #10
    def testTrainStepSaved(self):
        # We need to use one default session so that self.evaluate and the
        # SavedModel loader share the same session.
        with tf.compat.v1.Session().as_default():
            network = q_network.QNetwork(
                input_tensor_spec=self._time_step_spec.observation,
                action_spec=self._action_spec)

            policy = q_policy.QPolicy(time_step_spec=self._time_step_spec,
                                      action_spec=self._action_spec,
                                      q_network=network)

            train_step = common.create_variable('train_step', initial_value=7)
            self.evaluate(tf.compat.v1.initializers.variables([train_step]))

            saver = policy_saver.PolicySaver(policy,
                                             batch_size=None,
                                             train_step=train_step)
            path = os.path.join(self.get_temp_dir(), 'save_model')
            saver.save(path)

            reloaded = tf.compat.v2.saved_model.load(path)
            self.assertIn('get_train_step', reloaded.signatures)
            self.evaluate(tf.compat.v1.global_variables_initializer())
            train_step_value = self.evaluate(reloaded.train_step())
            self.assertEqual(7, train_step_value)
            train_step = train_step.assign_add(3)
            self.evaluate(train_step)
            saver.save(path)

            reloaded = tf.compat.v2.saved_model.load(path)
            self.evaluate(tf.compat.v1.global_variables_initializer())
            train_step_value = self.evaluate(reloaded.train_step())
            self.assertEqual(10, train_step_value)
Example #11
    def testRegisterConcreteFunction(self):
        if not common.has_eager_been_enabled():
            self.skipTest('Only supported in TF2.x. Step is required in TF1.x')

        network = q_network.QNetwork(
            input_tensor_spec=self._time_step_spec.observation,
            action_spec=self._action_spec)

        policy = q_policy.QPolicy(time_step_spec=self._time_step_spec,
                                  action_spec=self._action_spec,
                                  q_network=network)

        saver = policy_saver.PolicySaver(policy, batch_size=None)

        tf_var = tf.Variable(3)

        def add(b):
            return tf_var + b

        add_fn = common.function(add)
        # Called for side effect.
        add_fn.get_concrete_function(tf.TensorSpec((), dtype=tf.int32))

        saver.register_concrete_function(name='add', fn=add_fn)

        path = os.path.join(self.get_temp_dir(), 'save_model')
        saver.save(path)
        reloaded = tf.compat.v2.saved_model.load(path)

        self.assertAllClose(7, reloaded.add(4))
Example #12
    def save_model(self):
        # create directory to save checkpoint
        checkpoint_dir = os.path.join(self._savedir, 'checkpoint')
        train_checkpointer = common.Checkpointer(
            ckpt_dir=checkpoint_dir,
            max_to_keep=1,
            agent=self._agent,
            policy=self._agent.policy,
            replay_buffer=self._replay_buffer,
            global_step=self._train_step)
        # save the checkpoint
        train_checkpointer.save(self._train_step)

        # create directory to save policy
        policy_dir = os.path.join(self._savedir, 'policy')
        tf_policy_saver = policy_saver.PolicySaver(self._agent.policy)
        # save the policy
        tf_policy_saver.save(policy_dir)
        print("model saved.")

        if self._visual_flag:
            # save the animation
            for epi in range(self._num_episodes):
                f = self._vizdir + str(epi) + "/eval-animation.mp4"
                writervideo = FFMpegWriter(fps=33)
                self._ani[epi].save(f, writer=writervideo)
                plt.show()
            # try loading saved policy
            loaded_policy = tf.saved_model.load(policy_dir)
            eval_timestep = self._eval_env.reset()
            loaded_action = loaded_policy.action(eval_timestep)
            print("example policy: ", loaded_action)
Example #13
    def testUpdateFromCheckpoint(self):
        path = os.path.join(self.get_temp_dir(), 'saved_policy')
        saver = policy_saver.PolicySaver(self.tf_policy)
        saver.save(path)
        self.evaluate(
            tf.nest.map_structure(lambda v: v.assign(v * 0 + -1),
                                  self.tf_policy.variables()))
        checkpoint_path = os.path.join(self.get_temp_dir(), 'checkpoint')
        saver.save_checkpoint(checkpoint_path)

        eager_py_policy = py_tf_eager_policy.SavedModelPyTFEagerPolicy(
            path, self.time_step_spec, self.action_spec)

        # Use evaluate to force a copy.
        saved_model_variables = self.evaluate(eager_py_policy.variables())

        checkpoint = tf.train.Checkpoint(policy=eager_py_policy._policy)
        manager = tf.train.CheckpointManager(checkpoint,
                                             directory=checkpoint_path,
                                             max_to_keep=None)

        eager_py_policy.update_from_checkpoint(manager.latest_checkpoint)

        assert_np_not_equal = lambda a, b: self.assertFalse(
            np.equal(a, b).all())
        tf.nest.map_structure(assert_np_not_equal, saved_model_variables,
                              self.evaluate(eager_py_policy.variables()))

        assert_np_all_equal = lambda a, b: self.assertTrue(
            np.equal(a, b).all())
        tf.nest.map_structure(assert_np_all_equal,
                              self.evaluate(self.tf_policy.variables()),
                              self.evaluate(eager_py_policy.variables()))
Example #14
    def testUniqueSignatures(self):
        if not tf.executing_eagerly():
            self.skipTest(
                'b/129079730: PolicySaver does not work in TF1.x yet')

        network = q_network.QNetwork(
            input_tensor_spec=self._time_step_spec.observation,
            action_spec=self._action_spec)

        policy = q_policy.QPolicy(time_step_spec=self._time_step_spec,
                                  action_spec=self._action_spec,
                                  q_network=network)

        saver = policy_saver.PolicySaver(policy, batch_size=None)
        action_signature_names = [
            s.name for s in saver._signatures['action'].input_signature
        ]
        self.assertAllEqual(
            ['0/step_type', '0/reward', '0/discount', '0/observation'],
            action_signature_names)
        initial_state_signature_names = [
            s.name
            for s in saver._signatures['get_initial_state'].input_signature
        ]
        self.assertAllEqual(['batch_size'], initial_state_signature_names)
Example #15
  def testRenamedSignatures(self):
    time_step_spec = self._time_step_spec._replace(
        observation=tensor_spec.BoundedTensorSpec(
            dtype=tf.float32, shape=(4,), minimum=-10.0, maximum=10.0))

    network = q_network.QNetwork(
        input_tensor_spec=time_step_spec.observation,
        action_spec=self._action_spec)

    policy = q_policy.QPolicy(
        time_step_spec=time_step_spec,
        action_spec=self._action_spec,
        q_network=network)

    train_step = common.create_variable('train_step', initial_value=7)
    saver = policy_saver.PolicySaver(
        policy, train_step=train_step, batch_size=None)
    action_signature_names = [
        s.name for s in saver._signatures['action'].input_signature
    ]
    self.assertAllEqual(
        ['0/step_type', '0/reward', '0/discount', '0/observation'],
        action_signature_names)
    initial_state_signature_names = [
        s.name for s in saver._signatures['get_initial_state'].input_signature
    ]
    self.assertAllEqual(['batch_size'], initial_state_signature_names)
Example #16
def save_agent_policy():
    now = datetime.datetime.now()
    policy_dir = config.POLICIES_PATH + now.strftime("%m%d%Y-%H%M%S")
    os.mkdir(policy_dir)
    tf_policy_saver = policy_saver.PolicySaver(agent.policy)
    tf_policy_saver.save(policy_dir)
    print(">>>Policy saved in ", policy_dir)
Example #17
    def testUpdateFromCheckpoint(self):
        if not common.has_eager_been_enabled():
            self.skipTest('Only supported in TF2.x.')

        path = os.path.join(self.get_temp_dir(), 'saved_policy')
        saver = policy_saver.PolicySaver(self.tf_policy)
        saver.save(path)
        self.evaluate(
            tf.nest.map_structure(lambda v: v.assign(v * 0 + -1),
                                  self.tf_policy.variables()))
        checkpoint_path = os.path.join(self.get_temp_dir(), 'checkpoint')
        saver.save_checkpoint(checkpoint_path)

        eager_py_policy = py_tf_eager_policy.SavedModelPyTFEagerPolicy(
            path, self.time_step_spec, self.action_spec)

        # Use evaluate to force a copy.
        saved_model_variables = self.evaluate(eager_py_policy.variables())

        eager_py_policy.update_from_checkpoint(checkpoint_path)

        assert_np_not_equal = lambda a, b: self.assertFalse(
            np.equal(a, b).all())
        tf.nest.map_structure(assert_np_not_equal, saved_model_variables,
                              self.evaluate(eager_py_policy.variables()))

        assert_np_all_equal = lambda a, b: self.assertTrue(
            np.equal(a, b).all())
        tf.nest.map_structure(assert_np_all_equal,
                              self.evaluate(self.tf_policy.variables()),
                              self.evaluate(eager_py_policy.variables()),
                              check_types=False)
Example #18
  def testRegisterFunction(self):
    if not common.has_eager_been_enabled():
      self.skipTest('Only supported in TF2.x. Step is required in TF1.x')

    network = q_network.QNetwork(
        input_tensor_spec=self._time_step_spec.observation,
        action_spec=self._action_spec)

    policy = q_policy.QPolicy(
        time_step_spec=self._time_step_spec,
        action_spec=self._action_spec,
        q_network=network)

    saver = policy_saver.PolicySaver(policy, batch_size=None)
    saver.register_function('q_network', network,
                            self._time_step_spec.observation)

    path = os.path.join(self.get_temp_dir(), 'save_model')
    saver.save(path)
    reloaded = tf.compat.v2.saved_model.load(path)

    sample_input = self.evaluate(
        tensor_spec.sample_spec_nest(
            self._time_step_spec.observation, outer_dims=(3,)))
    expected_output, _ = network(sample_input)
    reloaded_output, _ = reloaded.q_network(sample_input)

    self.assertAllClose(expected_output, reloaded_output)
Example #19
  def testCheckpointSave(self):
    network = q_network.QNetwork(
        input_tensor_spec=self._time_step_spec.observation,
        action_spec=self._action_spec)

    policy = q_policy.QPolicy(
        time_step_spec=self._time_step_spec,
        action_spec=self._action_spec,
        q_network=network)

    train_step = common.create_variable('train_step', initial_value=0)
    saver = policy_saver.PolicySaver(
        policy, train_step=train_step, batch_size=None)
    path = os.path.join(self.get_temp_dir(), 'save_model')

    self.evaluate(tf.compat.v1.global_variables_initializer())
    with self.cached_session():
      saver.save(path)
    checkpoint_path = os.path.join(self.get_temp_dir(), 'checkpoint')
    with self.cached_session():
      saver.save_checkpoint(checkpoint_path)
    self.assertTrue(tf.compat.v2.io.gfile.exists(checkpoint_path))

    # Also test CheckpointOptions
    checkpoint2_path = os.path.join(self.get_temp_dir(), 'checkpoint2')
    options = tf.train.CheckpointOptions(
        experimental_io_device='/job:localhost')
    with self.cached_session():
      saver.save_checkpoint(checkpoint2_path, options=options)
    self.assertTrue(tf.compat.v2.io.gfile.exists(checkpoint2_path))
Example #20
    def testDistributionNotImplemented(self):
        policy = PolicyNoDistribution()

        with self.assertRaisesRegex(NotImplementedError,
                                    '_distribution has not been implemented'):
            policy.distribution(
                ts.TimeStep(step_type=(),
                            reward=(),
                            discount=(),
                            observation=()))

        train_step = common.create_variable('train_step', initial_value=0)
        saver = policy_saver.PolicySaver(policy,
                                         train_step=train_step,
                                         batch_size=None)
        path = os.path.join(self.get_temp_dir(), 'save_model')

        self.evaluate(tf.compat.v1.global_variables_initializer())
        with self.cached_session():
            saver.save(path)

        reloaded = tf.compat.v2.saved_model.load(path)
        with self.assertRaisesRegex(tf.errors.InvalidArgumentError,
                                    '_distribution has not been implemented'):
            self.evaluate(
                reloaded.distribution(
                    ts.TimeStep(step_type=(),
                                reward=(),
                                discount=(),
                                observation=())))
Example #21
    def testSavedModel(self):
        path = os.path.join(self.get_temp_dir(), 'saved_policy')
        saver = policy_saver.PolicySaver(self.tf_policy)
        saver.save(path)

        eager_py_policy = py_tf_eager_policy.SavedModelPyTFEagerPolicy(
            path, self.time_step_spec, self.action_spec)
        rng = np.random.RandomState()
        sample_time_step = array_spec.sample_spec_nest(self.time_step_spec,
                                                       rng)
        batched_sample_time_step = nest_utils.batch_nested_array(
            sample_time_step)

        original_action = self.tf_policy.action(batched_sample_time_step)
        unbatched_original_action = nest_utils.unbatch_nested_tensors(
            original_action)
        original_action_np = tf.nest.map_structure(lambda t: t.numpy(),
                                                   unbatched_original_action)
        saved_policy_action = eager_py_policy.action(sample_time_step)

        tf.nest.assert_same_structure(saved_policy_action.action,
                                      self.action_spec)

        np.testing.assert_array_almost_equal(original_action_np.action,
                                             saved_policy_action.action)
Example #22
    def testRenamedSignatures(self):
        if not tf.executing_eagerly():
            self.skipTest(
                'b/129079730: PolicySaver does not work in TF1.x yet')

        time_step_spec = self._time_step_spec._replace(
            observation=tensor_spec.BoundedTensorSpec(
                dtype=tf.float32, shape=(4, ), minimum=-10.0, maximum=10.0))

        network = q_network.QNetwork(
            input_tensor_spec=time_step_spec.observation,
            action_spec=self._action_spec)

        policy = q_policy.QPolicy(time_step_spec=time_step_spec,
                                  action_spec=self._action_spec,
                                  q_network=network)

        saver = policy_saver.PolicySaver(policy, batch_size=None)
        action_signature_names = [
            s.name for s in saver._signatures['action'].input_signature
        ]
        self.assertAllEqual(
            ['0/step_type', '0/reward', '0/discount', '0/observation'],
            action_signature_names)
        initial_state_signature_names = [
            s.name
            for s in saver._signatures['get_initial_state'].input_signature
        ]
        self.assertAllEqual(['batch_size'], initial_state_signature_names)
Example #23
 def _build_saver(
     self, policy: tf_policy.TFPolicy
 ) -> Union[policy_saver.PolicySaver, async_policy_saver.AsyncPolicySaver]:
   saver = policy_saver.PolicySaver(
       policy, train_step=self._train_step, metadata=self._metadata)
   if self._async_saving:
     saver = async_policy_saver.AsyncPolicySaver(saver)
   return saver
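
A short, assumed call-site sketch for the helper above; export_dir, self._policy and the method name are not from the original.

 def _export_policy(self, export_dir):
   saver = self._build_saver(self._policy)
   saver.save(export_dir)
   if self._async_saving:
     # AsyncPolicySaver writes on a background thread; flush() blocks until the save finishes.
     saver.flush()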
Example #24
    def testRegisterFunction(self):
        if not common.has_eager_been_enabled():
            self.skipTest('Only supported in TF2.x. Step is required in TF1.x')

        time_step_spec = ts.TimeStep(
            step_type=tensor_spec.BoundedTensorSpec(dtype=tf.int32,
                                                    shape=(),
                                                    name='st',
                                                    minimum=0,
                                                    maximum=2),
            reward=tensor_spec.BoundedTensorSpec(dtype=tf.float32,
                                                 shape=(),
                                                 name='reward',
                                                 minimum=0.0,
                                                 maximum=5.0),
            discount=tensor_spec.BoundedTensorSpec(dtype=tf.float32,
                                                   shape=(),
                                                   name='discount',
                                                   minimum=0.0,
                                                   maximum=1.0),
            observation=tensor_spec.BoundedTensorSpec(dtype=tf.float32,
                                                      shape=(4, ),
                                                      name='obs',
                                                      minimum=-10.0,
                                                      maximum=10.0))
        action_spec = tensor_spec.BoundedTensorSpec(dtype=tf.int32,
                                                    shape=(),
                                                    minimum=0,
                                                    maximum=10,
                                                    name='act_0')

        network = q_network.QNetwork(
            input_tensor_spec=time_step_spec.observation,
            action_spec=action_spec)

        policy = q_policy.QPolicy(time_step_spec=time_step_spec,
                                  action_spec=action_spec,
                                  q_network=network)

        saver = policy_saver.PolicySaver(policy, batch_size=None)
        async_saver = async_policy_saver.AsyncPolicySaver(saver)
        async_saver.register_function('q_network', network,
                                      time_step_spec.observation)

        path = os.path.join(self.get_temp_dir(), 'save_model')
        async_saver.save(path)
        async_saver.flush()
        async_saver.close()
        self.assertFalse(async_saver._save_thread.is_alive())
        reloaded = tf.compat.v2.saved_model.load(path)

        sample_input = self.evaluate(
            tensor_spec.sample_spec_nest(time_step_spec.observation,
                                         outer_dims=(3, )))
        expected_output, _ = network(sample_input)
        reloaded_output, _ = reloaded.q_network(sample_input)

        self.assertAllClose(expected_output, reloaded_output)
Example #25
 def setUp(self):
     super(PolicyLoaderTest, self).setUp()
     self.root_dir = self.get_temp_dir()
     tf_observation_spec = tensor_spec.TensorSpec((), np.float32)
     tf_time_step_spec = ts.time_step_spec(tf_observation_spec)
     tf_action_spec = tensor_spec.BoundedTensorSpec((), np.float32, 0, 3)
     self.net = AddNet()
     self.policy = greedy_policy.GreedyPolicy(
         q_policy.QPolicy(tf_time_step_spec, tf_action_spec, self.net))
     self.train_step = common.create_variable('train_step', initial_value=0)
     self.saver = policy_saver.PolicySaver(self.policy,
                                           train_step=self.train_step)
Example #26
 def testMixturePolicyDynamicBatchSize(self):
     context_dim = 35
     observation_spec = tensor_spec.TensorSpec([context_dim], tf.float32)
     time_step_spec = ts.time_step_spec(observation_spec)
     action_spec = tensor_spec.BoundedTensorSpec(shape=(),
                                                 dtype=tf.int32,
                                                 minimum=0,
                                                 maximum=9,
                                                 name='action')
     sub_policies = [
         ConstantPolicy(action_spec, time_step_spec, i) for i in range(10)
     ]
     weights = [0, 0, 0.2, 0, 0, 0.3, 0, 0, 0.5, 0]
     dist = tfd.Categorical(probs=weights)
     policy = mixture_policy.MixturePolicy(dist, sub_policies)
     batch_size = tf.random.uniform(shape=(),
                                    minval=10,
                                    maxval=15,
                                    dtype=tf.int32)
     time_step = ts.TimeStep(
         tf.fill(tf.expand_dims(batch_size, axis=0),
                 ts.StepType.FIRST,
                 name='step_type'),
         tf.zeros(shape=[batch_size], dtype=tf.float32, name='reward'),
         tf.ones(shape=[batch_size], dtype=tf.float32, name='discount'),
         tf.reshape(tf.range(tf.cast(batch_size * context_dim,
                                     dtype=tf.float32),
                             dtype=tf.float32),
                    shape=[-1, context_dim],
                    name='observation'))
     action_step = policy.action(time_step)
     actions, bsize = self.evaluate([action_step.action, batch_size])
     self.assertAllEqual(actions.shape, [bsize])
     self.assertAllInSet(actions, [2, 5, 8])
     saver = policy_saver.PolicySaver(policy)
     location = os.path.join(self.get_temp_dir(), 'saved_policy')
     saver.save(location)
     loaded_policy = tf.compat.v2.saved_model.load(location)
     new_batch_size = 3
     new_time_step = ts.TimeStep(
         tf.fill(tf.expand_dims(new_batch_size, axis=0),
                 ts.StepType.FIRST,
                 name='step_type'),
         tf.zeros(shape=[new_batch_size], dtype=tf.float32, name='reward'),
         tf.ones(shape=[new_batch_size], dtype=tf.float32, name='discount'),
         tf.reshape(tf.range(tf.cast(new_batch_size * context_dim,
                                     dtype=tf.float32),
                             dtype=tf.float32),
                    shape=[-1, context_dim],
                    name='observation'))
     new_action = self.evaluate(loaded_policy.action(new_time_step).action)
     self.assertAllEqual(new_action.shape, [new_batch_size])
     self.assertAllInSet(new_action, [2, 5, 8])
Example #27
    def initiate(self):
        self._checkpoint_dir = os.path.join(self._save_dir, 'checkpoint')
        self._train_checkpointer = common.Checkpointer(
            ckpt_dir=self._checkpoint_dir,
            max_to_keep=1,
            agent=self._agent,
            policy=self._agent.policy,
            replay_buffer=self._replay_buffer,
            global_step=self._global_step)

        self._policy_dir = os.path.join(self._save_dir, 'policy')
        self._tf_policy_saver = policy_saver.PolicySaver(self._agent.policy)
Example #28
  def testTrainStepSaved(self):
    # We need to use one default session so that self.evaluate and the
    # SavedModel loader share the same session.
    with tf.compat.v1.Session().as_default():
      network = q_network.QNetwork(
          input_tensor_spec=self._time_step_spec.observation,
          action_spec=self._action_spec)

      policy = q_policy.QPolicy(
          time_step_spec=self._time_step_spec,
          action_spec=self._action_spec,
          q_network=network)
      self.evaluate(tf.compat.v1.initializers.variables(policy.variables()))

      train_step = common.create_variable('train_step', initial_value=7)
      self.evaluate(tf.compat.v1.initializers.variables([train_step]))

      saver = policy_saver.PolicySaver(
          policy, batch_size=None, train_step=train_step)
      if tf.executing_eagerly():
        step = saver.get_train_step()
      else:
        step = self.evaluate(saver.get_train_step())
      self.assertEqual(7, step)
      path = os.path.join(self.get_temp_dir(), 'save_model')
      saver.save(path)

      reloaded = tf.compat.v2.saved_model.load(path)
      self.assertIn('get_train_step', reloaded.signatures)
      self.evaluate(tf.compat.v1.global_variables_initializer())
      train_step_value = self.evaluate(reloaded.get_train_step())
      self.assertEqual(7, train_step_value)
      train_step = train_step.assign_add(3)
      self.evaluate(train_step)
      saver.save(path)

      reloaded = tf.compat.v2.saved_model.load(path)
      self.evaluate(tf.compat.v1.global_variables_initializer())
      train_step_value = self.evaluate(reloaded.get_train_step())
      self.assertEqual(10, train_step_value)

      # Also test passing SaveOptions.
      train_step = train_step.assign_add(3)
      self.evaluate(train_step)
      path2 = os.path.join(self.get_temp_dir(), 'save_model2')
      saver.save(
          path2,
          options=tf.saved_model.SaveOptions(
              experimental_io_device='/job:localhost'))
      reloaded = tf.compat.v2.saved_model.load(path2)
      self.evaluate(tf.compat.v1.global_variables_initializer())
      train_step_value = self.evaluate(reloaded.get_train_step())
      self.assertEqual(13, train_step_value)
Example #29
    def policy_saver(self, reward=0, iteration=0):
        """Save the policy in the directory path defined by the
        mehtod _save_points_dir from the current class"""

        save_dir = self._save_points_dir() + '_{}_{}'.format(
            str(reward), str(iteration))
        print(save_dir)
        tf_policy_saver = policy_saver.PolicySaver(self.agent.policy)
        tf_policy_saver.save(save_dir)
        print('\nThe policy of the deep Q network is saved in: "{}"'.format(
            save_dir))
        return save_dir
Example #30
  def testSaver(self):
    policy = categorical_q_policy.CategoricalQPolicy(
        self._time_step_spec, self._action_spec, self._q_network,
        self._min_q_value, self._max_q_value)

    saver = policy_saver.PolicySaver(policy)

    self.evaluate(tf.compat.v1.global_variables_initializer())
    self.evaluate(tf.compat.v1.local_variables_initializer())

    save_path = os.path.join(flags.FLAGS.test_tmpdir,
                             'saved_categorical_q_policy')
    saver.save(save_path)