Example 1
 def _init_global_step(self):
     self.global_step = training_util.get_or_create_global_step()
     self._training_ops.update({
         'increment_global_step': training_util._increment_global_step(1)
     })
     self._misc_training_ops.update({
         'increment_global_step': training_util._increment_global_step(1)
     })
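For context, a minimal sketch of what running an op returned by training_util._increment_global_step does, assuming TensorFlow 1.x graph mode; the session setup below is illustrative and not taken from the example above:

import tensorflow as tf
from tensorflow.python.training import training_util

with tf.Graph().as_default():
    global_step = training_util.get_or_create_global_step()
    increment_op = training_util._increment_global_step(1)  # private TF API
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())  # the global step starts at 0
        sess.run(increment_op)                       # advances the counter by 1
        print(sess.run(global_step))                 # -> 1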
Example 2
 def test_reads_before_increments(self):
   with ops.Graph().as_default():
     training_util.create_global_step()
     read_tensor = training_util._get_or_create_global_step_read()
     inc_op = training_util._increment_global_step(1)
     inc_three_op = training_util._increment_global_step(3)
     with monitored_session.MonitoredTrainingSession() as sess:
       read_value, _ = sess.run([read_tensor, inc_op])
       self.assertEqual(0, read_value)
       read_value, _ = sess.run([read_tensor, inc_three_op])
       self.assertEqual(1, read_value)
       read_value = sess.run(read_tensor)
       self.assertEqual(4, read_value)
Example 3
 def test_reads_before_increments(self):
     with ops.Graph().as_default():
         training_util.create_global_step()
         read_tensor = training_util._get_or_create_global_step_read()
         inc_op = training_util._increment_global_step(1)
         inc_three_op = training_util._increment_global_step(3)
         with monitored_session.MonitoredTrainingSession() as sess:
             read_value, _ = sess.run([read_tensor, inc_op])
             self.assertEqual(0, read_value)
             read_value, _ = sess.run([read_tensor, inc_three_op])
             self.assertEqual(1, read_value)
             read_value = sess.run(read_tensor)
             self.assertEqual(4, read_value)
Example 4
  def test_global_step_name(self):
    with ops.Graph().as_default() as g, session_lib.Session() as sess:
      with variable_scope.variable_scope('bar'):
        variable_scope.get_variable(
            'foo',
            initializer=0,
            trainable=False,
            collections=[
                ops.GraphKeys.GLOBAL_STEP, ops.GraphKeys.GLOBAL_VARIABLES
            ])
      train_op = training_util._increment_global_step(1)
      summary_writer = fake_summary_writer.FakeSummaryWriter(self.log_dir, g)
      hook = basic_session_run_hooks.StepCounterHook(
          summary_writer=summary_writer, every_n_steps=1, every_n_secs=None)

      hook.begin()
      sess.run(variables_lib.global_variables_initializer())
      mon_sess = monitored_session._HookedSession(sess, [hook])
      mon_sess.run(train_op)
      mon_sess.run(train_op)
      hook.end(sess)

      summary_writer.assert_summaries(
          test_case=self,
          expected_logdir=self.log_dir,
          expected_graph=g,
          expected_summaries={})
      self.assertTrue(summary_writer.summaries, 'No summaries were created.')
      self.assertItemsEqual([2], summary_writer.summaries.keys())
      summary_value = summary_writer.summaries[2][0].value[0]
      self.assertEqual('bar/foo/sec', summary_value.tag)
Example 5
  def test_step_counter_every_n_secs(self):
    with ops.Graph().as_default() as g, session_lib.Session() as sess:
      variables.get_or_create_global_step()
      train_op = training_util._increment_global_step(1)
      summary_writer = fake_summary_writer.FakeSummaryWriter(self.log_dir, g)
      hook = basic_session_run_hooks.StepCounterHook(
          summary_writer=summary_writer, every_n_steps=None, every_n_secs=0.1)

      hook.begin()
      sess.run(variables_lib.global_variables_initializer())
      mon_sess = monitored_session._HookedSession(sess, [hook])
      mon_sess.run(train_op)
      time.sleep(0.2)
      mon_sess.run(train_op)
      time.sleep(0.2)
      mon_sess.run(train_op)
      hook.end(sess)

      summary_writer.assert_summaries(
          test_case=self,
          expected_logdir=self.log_dir,
          expected_graph=g,
          expected_summaries={})
      self.assertTrue(summary_writer.summaries, 'No summaries were created.')
      self.assertItemsEqual([2, 3], summary_writer.summaries.keys())
      for summary in summary_writer.summaries.values():
        summary_value = summary[0].value[0]
        self.assertEqual('global_step/sec', summary_value.tag)
        self.assertGreater(summary_value.simple_value, 0)
Example 6
 def test_step_counter_every_n_steps(self):
   with ops.Graph().as_default() as g, session_lib.Session() as sess:
     variables.get_or_create_global_step()
     train_op = training_util._increment_global_step(1)
     summary_writer = fake_summary_writer.FakeSummaryWriter(self.log_dir, g)
     hook = basic_session_run_hooks.StepCounterHook(
         summary_writer=summary_writer, every_n_steps=10)
     hook.begin()
     sess.run(variables_lib.global_variables_initializer())
     mon_sess = monitored_session._HookedSession(sess, [hook])
     with test.mock.patch.object(tf_logging, 'warning') as mock_log:
       for _ in range(30):
         time.sleep(0.01)
         mon_sess.run(train_op)
       # logging.warning should not be called.
       self.assertIsNone(mock_log.call_args)
     hook.end(sess)
     summary_writer.assert_summaries(
         test_case=self,
         expected_logdir=self.log_dir,
         expected_graph=g,
         expected_summaries={})
     self.assertItemsEqual([11, 21], summary_writer.summaries.keys())
     for step in [11, 21]:
       summary_value = summary_writer.summaries[step][0].value[0]
       self.assertEqual('global_step/sec', summary_value.tag)
       self.assertGreater(summary_value.simple_value, 0)
Example 7
  def test_two_listeners_with_default_saver(self):
    with ops.Graph().as_default():
      global_step = variables.get_or_create_global_step()
      train_op = training_util._increment_global_step(1)
      listener1 = MockCheckpointSaverListener()
      listener2 = MockCheckpointSaverListener()
      hook = basic_session_run_hooks.CheckpointSaverHook(
          self.model_dir,
          save_steps=1,
          listeners=[listener1, listener2])
      with monitored_session.SingularMonitoredSession(
          hooks=[hook],
          checkpoint_dir=self.model_dir) as sess:
        sess.run(train_op)
        sess.run(train_op)
        global_step_val = sess.raw_session().run(global_step)
      listener1_counts = listener1.get_counts()
      listener2_counts = listener2.get_counts()
    self.assertEqual(2, global_step_val)
    self.assertEqual({
        'begin': 1,
        'before_save': 2,
        'after_save': 2,
        'end': 1
    }, listener1_counts)
    self.assertEqual(listener1_counts, listener2_counts)

    with ops.Graph().as_default():
      global_step = variables.get_or_create_global_step()
      with monitored_session.SingularMonitoredSession(
          checkpoint_dir=self.model_dir) as sess2:
        global_step_saved_val = sess2.run(global_step)
    self.assertEqual(2, global_step_saved_val)
Example 12
 def setUp(self):
   self.model_dir = tempfile.mkdtemp()
   self.graph = ops.Graph()
   with self.graph.as_default():
     self.scaffold = monitored_session.Scaffold()
     self.global_step = variables.get_or_create_global_step()
     self.train_op = training_util._increment_global_step(1)
Example 13
  def setUp(self):
    test.TestCase.setUp(self)

    self.log_dir = 'log/dir'
    self.summary_writer = fake_summary_writer.FakeSummaryWriter(self.log_dir)

    var = variable_scope.get_variable('var', initializer=0.0, use_resource=True)
    tensor = state_ops.assign_add(var, 1.0)
    self.summary_op = summary_lib.scalar('my_summary', tensor)

    with variable_scope.variable_scope('foo', use_resource=True):
      variables.create_global_step()
    self.train_op = training_util._increment_global_step(1)
Example 15
  def setUp(self):
    test.TestCase.setUp(self)

    self.log_dir = 'log/dir'
    self.summary_writer = fake_summary_writer.FakeSummaryWriter(self.log_dir)

    var = variables_lib.Variable(0.0)
    tensor = state_ops.assign_add(var, 1.0)
    tensor2 = tensor * 2
    self.summary_op = summary_lib.scalar('my_summary', tensor)
    self.summary_op2 = summary_lib.scalar('my_summary2', tensor2)

    variables.get_or_create_global_step()
    self.train_op = training_util._increment_global_step(1)
Example 17
 def test_log_warning_if_global_step_not_increased(self):
   with ops.Graph().as_default(), session_lib.Session() as sess:
     variables.get_or_create_global_step()
     train_op = training_util._increment_global_step(0)  # keep same.
     sess.run(variables_lib.global_variables_initializer())
     hook = basic_session_run_hooks.StepCounterHook(
         every_n_steps=1, every_n_secs=None)
     hook.begin()
     mon_sess = monitored_session._HookedSession(sess, [hook])
     mon_sess.run(train_op)  # Run one step to record global step.
     with test.mock.patch.object(tf_logging, 'warning') as mock_log:
       for _ in range(30):
         mon_sess.run(train_op)
       self.assertRegexpMatches(
           str(mock_log.call_args),
           'global step.*has not been increased')
     hook.end(sess)
Example 18
    def _build(self):

        self._training_ops = {}
        # placeholder
        self.global_step = training_util.get_or_create_global_step()
        self._training_ops.update(
            {'increment_global_step': training_util._increment_global_step(1)})

        self._iteration_ph = tf.placeholder(tf.int64,
                                            shape=None,
                                            name='iteration')

        self._observations_ph = tf.placeholder(
            tf.float32,
            shape=(None, None, *self._observation_shape),
            name='observation',
        )

        self._next_observations_ph = tf.placeholder(
            tf.float32,
            shape=(None, None, *self._observation_shape),
            name='next_observation',
        )

        self._actions_ph = tf.placeholder(
            tf.float32,
            shape=(None, None, *self._action_shape),
            name='actions',
        )

        self._prev_state_p_ph = tf.placeholder(
            tf.float32,
            shape=(None, self.gru_state_dim),
            name='prev_state_p',
        )
        self._prev_state_v_ph = tf.placeholder(
            tf.float32,
            shape=(None, self.gru_state_dim),
            name='prev_state_v',
        )

        self.seq_len = tf.placeholder(tf.float32, shape=[None], name="seq_len")

        self._rewards_ph = tf.placeholder(
            tf.float32,
            shape=(None, None, 1),
            name='rewards',
        )

        self._terminals_ph = tf.placeholder(
            tf.float32,
            shape=(None, None, 1),
            name='terminals',
        )

        if self._store_extra_policy_info:
            self._log_pis_ph = tf.placeholder(
                tf.float32,
                shape=(None, None, 1),
                name='log_pis',
            )
            self._raw_actions_ph = tf.placeholder(
                tf.float32,
                shape=(None, None, *self._action_shape),
                name='raw_actions',
            )

        # inner functions
        LOG_STD_MAX = 2
        LOG_STD_MIN = -20
        EPS = 1e-8

        def mlp(x,
                hidden_sizes=(32, ),
                activation=tf.tanh,
                output_activation=None,
                kernel_initializer=None):
            print('[ DEBUG ], hidden layer size: ', hidden_sizes)
            for h in hidden_sizes[:-1]:
                x = tf.layers.dense(x,
                                    units=h,
                                    activation=activation,
                                    kernel_initializer=kernel_initializer)
            return tf.layers.dense(x,
                                   units=hidden_sizes[-1],
                                   activation=output_activation,
                                   kernel_initializer=kernel_initializer)

        def gaussian_likelihood(x, mu, log_std):
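            # Log-density of a diagonal Gaussian N(mu, sigma^2), summed over the
            # last axis: sum_i [ -0.5*((x_i - mu_i)/sigma_i)^2 - log(sigma_i)
            #                    - 0.5*log(2*pi) ]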
            pre_sum = -0.5 * (
                ((x - mu) /
                 (tf.exp(log_std) + EPS))**2 + 2 * log_std + np.log(2 * np.pi))
            return tf.reduce_sum(pre_sum, axis=-1)

        def apply_squashing_func(mu, pi, logp_pi):
            # Adjustment to log prob
            # NOTE: This formula is a little bit magic. To get an understanding of where it
            # comes from, check out the original SAC paper (arXiv 1801.01290) and look in
            # appendix C. This is a more numerically-stable equivalent to Eq 21.
            # Try deriving it yourself as a (very difficult) exercise. :)
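            # Concretely, for a = tanh(u):
            #   log pi(a) = log mu(u) - sum_i log(1 - tanh(u_i)^2),
            # and log(1 - tanh(u)^2) = 2 * (log(2) - u - softplus(-2u)),
            # which is the numerically stable form subtracted below.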
            logp_pi -= tf.reduce_sum(
                2 * (np.log(2) - pi - tf.nn.softplus(-2 * pi)), axis=-1)
            # Squash those unbounded actions!
            mu = tf.tanh(mu)
            pi = tf.tanh(pi)
            return mu, pi, logp_pi

        def mlp_gaussian_policy(x, a, hidden_sizes, activation,
                                output_activation):
            print('[ DEBUG ]: output activation: ', output_activation,
                  ', activation: ', activation)
            act_dim = a.shape.as_list()[-1]
            net = mlp(x, list(hidden_sizes), activation, activation)
            mu = tf.layers.dense(net, act_dim, activation=output_activation)
            log_std = tf.layers.dense(net, act_dim, activation=None)
            log_std = tf.clip_by_value(log_std, LOG_STD_MIN, LOG_STD_MAX)
            std = tf.exp(log_std)
            pi = mu + tf.random_normal(tf.shape(mu)) * std
            logp_pi = gaussian_likelihood(pi, mu, log_std)
            return mu, pi, logp_pi, std

        def mlp_actor_critic(x,
                             x_v,
                             a,
                             hidden_sizes=(256, 256),
                             activation=tf.nn.relu,
                             output_activation=None,
                             policy=mlp_gaussian_policy):
            # policy
            with tf.variable_scope('pi'):
                mu, pi, logp_pi, std = policy(x, a, hidden_sizes, activation,
                                              output_activation)
                mu, pi, logp_pi = apply_squashing_func(mu, pi, logp_pi)

            # vfs
            vf_mlp = lambda x: tf.squeeze(
                mlp(x,
                    list(hidden_sizes) + [1], activation, None), axis=-1)

            with tf.variable_scope('q1'):
                q1 = vf_mlp(tf.concat([x_v, a], axis=-1))
            with tf.variable_scope('q2'):
                q2 = vf_mlp(tf.concat([x_v, a], axis=-1))
            return mu, pi, logp_pi, q1, q2, std

        policy_state1 = self._observations_ph
        value_state1 = self._observations_ph
        policy_state2 = value_state2 = self._next_observations_ph

        ac_kwargs = {
            "hidden_sizes": self.network_kwargs["hidden_sizes"],
            "activation": self.network_kwargs["activation"],
            "output_activation": self.network_kwargs["output_activation"]
        }

        with tf.variable_scope('main', reuse=False):
            self.mu, self.pi, logp_pi, q1, q2, std = mlp_actor_critic(
                policy_state1, value_state1, self._actions_ph, **ac_kwargs)

        # Differential entropy of the (pre-squashing) diagonal Gaussian policy:
        #   sum_i [ log(sigma_i) + 0.5 * log(2 * pi * e) ]
        pi_entropy = tf.reduce_sum(tf.log(std + 1e-8) +
                                   0.5 * tf.log(2 * np.pi * np.e),
                                   axis=-1)
        with tf.variable_scope('main', reuse=True):
            # compose q with pi, for pi-learning
            _, _, _, q1_pi, q2_pi, _ = mlp_actor_critic(
                policy_state1, value_state1, self.pi, **ac_kwargs)
            # get actions and log probs of actions for next states, for Q-learning
            _, pi_next, logp_pi_next, _, _, _ = mlp_actor_critic(
                policy_state2, value_state2, self._actions_ph, **ac_kwargs)

        with tf.variable_scope('target'):
            # target q values, using actions from *current* policy
            _, _, _, q1_targ, q2_targ, _ = mlp_actor_critic(
                policy_state2, value_state2, pi_next, **ac_kwargs)

        # actions = self._policy.actions([self._observations_ph])
        # log_pis = self._policy.log_pis([self._observations_ph], actions)
        # assert log_pis.shape.as_list() == [None, 1]

        # alpha optimizer
        log_alpha = self._log_alpha = tf.get_variable('log_alpha',
                                                      dtype=tf.float32,
                                                      initializer=0.0)
        alpha = tf.exp(log_alpha)

        self._alpha = alpha
        assert self._action_prior == 'uniform'
        policy_prior_log_probs = 0.0

        min_q_pi = tf.minimum(q1_pi, q2_pi)
        min_q_targ = tf.minimum(q1_targ, q2_targ)

        if self._reparameterize:
            policy_kl_losses = (tf.stop_gradient(alpha) * logp_pi - min_q_pi -
                                policy_prior_log_probs)
        else:
            raise NotImplementedError

        policy_loss = tf.reduce_mean(policy_kl_losses)

        # Q
        next_log_pis = logp_pi_next
        min_next_Q = min_q_targ
        next_value = min_next_Q - self._alpha * next_log_pis

        q_target = td_target(
            reward=self._reward_scale * self._rewards_ph[..., 0],
            discount=self._discount,
            next_value=(1 - self._terminals_ph[..., 0]) * next_value)
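        # Assuming td_target(reward, discount, next_value) returns
        # reward + discount * next_value, this computes
        #   q_target = reward_scale*r + discount*(1 - done)*(min_q_targ - alpha*logp_pi_next)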

        print('q1_pi: {}, q2_pi: {}, policy_state2: {}, policy_state1: {}, '
              'tmux a: {}, q_targ: {}, mu: {}, reward: {}, '
              'terminal: {}, target_q: {}, next_value: {}, '
              'q1: {}, logp_pi: {}, min_q_pi: {}'.format(
                  q1_pi, q2_pi, policy_state1, policy_state2, pi_next, q1_targ,
                  self.mu, self._rewards_ph[..., 0], self._terminals_ph[...,
                                                                        0],
                  q_target, next_value, q1, logp_pi, min_q_pi))
        # assert q_target.shape.as_list() == [None, 1]
        # (self._Q_values,
        #  self._Q_losses,
        #  self._alpha,
        #  self.global_step),
        self.Q1 = q1
        self.Q2 = q2
        q_target = tf.stop_gradient(q_target)
        q1_loss = tf.losses.mean_squared_error(labels=q_target,
                                               predictions=q1,
                                               weights=0.5)
        q2_loss = tf.losses.mean_squared_error(labels=q_target,
                                               predictions=q2,
                                               weights=0.5)
        self.Q_loss = (q1_loss + q2_loss) / 2

        value_optimizer1 = tf.train.AdamOptimizer(learning_rate=self._Q_lr)
        value_optimizer2 = tf.train.AdamOptimizer(learning_rate=self._Q_lr)
        print('[ DEBUG ]: Q lr is {}'.format(self._Q_lr))

        # train_value_op = value_optimizer.apply_gradients(zip(grads, variables))
        pi_optimizer = tf.train.AdamOptimizer(learning_rate=self._policy_lr)
        print('[ DEBUG ]: policy lr is {}'.format(self._policy_lr))

        pi_var_list = get_vars('main/pi')
        if self.adapt:
            pi_var_list += get_vars("lstm_net_pi")
        train_pi_op = pi_optimizer.minimize(policy_loss, var_list=pi_var_list)
        pgrads, variables = zip(
            *pi_optimizer.compute_gradients(policy_loss, var_list=pi_var_list))

        _, pi_global_norm = tf.clip_by_global_norm(pgrads, 2000)

        with tf.control_dependencies([train_pi_op]):
            value_params1 = get_vars('main/q1')
            value_params2 = get_vars('main/q2')
            if self.adapt:
                value_params1 += get_vars("lstm_net_v")
                value_params2 += get_vars("lstm_net_v")

            grads, variables = zip(*value_optimizer1.compute_gradients(
                self.Q_loss, var_list=value_params1))
            _, q_global_norm = tf.clip_by_global_norm(grads, 2000)
            train_value_op1 = value_optimizer1.minimize(q1_loss,
                                                        var_list=value_params1)
            train_value_op2 = value_optimizer2.minimize(q2_loss,
                                                        var_list=value_params2)
            with tf.control_dependencies([train_value_op1, train_value_op2]):
                if isinstance(self._target_entropy, Number):
                    alpha_loss = -tf.reduce_mean(
                        log_alpha *
                        tf.stop_gradient(logp_pi + self._target_entropy))
                    self._alpha_optimizer = tf.train.AdamOptimizer(
                        self._policy_lr, name='alpha_optimizer')
                    self._alpha_train_op = self._alpha_optimizer.minimize(
                        loss=alpha_loss, var_list=[log_alpha])
                else:
                    self._alpha_train_op = tf.no_op()

        self.target_update = tf.group([
            tf.assign(v_targ, (1 - self._tau) * v_targ + self._tau * v_main)
            for v_main, v_targ in zip(get_vars('main'), get_vars('target'))
        ])

        self.target_init = tf.group([
            tf.assign(v_targ, v_main)
            for v_main, v_targ in zip(get_vars('main'), get_vars('target'))
        ])

        # construct opt
        self._training_ops = [
            tf.group((train_value_op2, train_value_op1, train_pi_op,
                      self._alpha_train_op)), {
                          "sac_pi/pi_global_norm": pi_global_norm,
                          "sac_Q/q_global_norm": q_global_norm,
                          "Q/q1_loss": q1_loss,
                          "sac_Q/q2_loss": q2_loss,
                          "sac_Q/q1": q1,
                          "sac_Q/q2": q2,
                          "sac_pi/alpha": alpha,
                          "sac_pi/pi_entropy": pi_entropy,
                          "sac_pi/logp_pi": logp_pi,
                          "sac_pi/std": logp_pi,
                      }
        ]

        self._session.run(tf.global_variables_initializer())
        self._session.run(self.target_init)
Example 19
def _encoder_model_fn(features, labels, mode, params=None, config=None):
    generator_layers = [64, 64, 128, 128]
    discriminator_layers = [64, 64, 128, 128]
    gen_depth_per_layer = params['generator_inner_layers']
    discr_depth_per_layer = params['discriminator_inner_layers']
    gen_inner_layers = [
        gen_depth_per_layer, gen_depth_per_layer, gen_depth_per_layer,
        gen_depth_per_layer
    ]
    discr_inner_layers = [
        discr_depth_per_layer, discr_depth_per_layer, discr_depth_per_layer,
        discr_depth_per_layer
    ]
    generator_layer_padding = ["SAME", "SAME", "SAME", "SAME"]
    a = features['i']
    z = features['z']
    b = None
    if (mode == tf.estimator.ModeKeys.TRAIN):
        b = features['j']
    dagan = DAGAN(batch_size=params['batch_size'],
                  input_x_i=a,
                  input_x_j=b,
                  dropout_rate=params['dropout_rate'],
                  generator_layer_sizes=generator_layers,
                  generator_layer_padding=generator_layer_padding,
                  num_channels=a.shape[3],
                  is_training=(mode == tf.estimator.ModeKeys.TRAIN),
                  augment=tf.constant(params['random_rotate'], dtype=tf.bool),
                  discriminator_layer_sizes=discriminator_layers,
                  discr_inner_conv=discr_inner_layers,
                  gen_inner_conv=gen_inner_layers,
                  z_dim=params['z_dim'],
                  z_inputs=z,
                  use_wide_connections=params['use_wide_connections'])

    if (mode == tf.estimator.ModeKeys.TRAIN):
        generated = None
        export_outputs = None  # only populated in the non-TRAIN branch below
        losses, graph_ops = dagan.init_train()
        accumulated_d_loss = tf.Variable(
            0.0, trainable=False, collections=[tf.GraphKeys.LOCAL_VARIABLES])
        acc_d_loss_op = accumulated_d_loss.assign_add(losses["d_losses"])
        acc_d_loss_zero_op = accumulated_d_loss.assign(
            tf.zeros_like(accumulated_d_loss))
        accumulated_g_loss = tf.Variable(
            0.0, trainable=False, collections=[tf.GraphKeys.LOCAL_VARIABLES])
        acc_g_loss_op = accumulated_g_loss.assign_add(losses["g_losses"])
        acc_g_loss_zero_op = accumulated_g_loss.assign(
            tf.zeros_like(accumulated_g_loss))
        reset = tf.group(acc_d_loss_zero_op, acc_g_loss_zero_op)
        d_train = tf.group(graph_ops['d_opt_op'], acc_d_loss_op)
        g_train = tf.group(graph_ops['g_opt_op'], acc_g_loss_op)
        train_hooks = [
            MultiStepOps(params['discriminator_inner_steps'],
                         params['generator_inner_steps'], d_train, g_train,
                         reset)
        ]
        total_loss = accumulated_d_loss + accumulated_g_loss
        train_op = training_util._increment_global_step(1)  # pylint: disable=protected-access
    else:
        _, generated = dagan.sample_same_images()
        output = tf.identity(generated, name='output')
        export_outputs = {
            tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
            tf.estimator.export.PredictOutput(output)
        }
        total_loss = None
        train_op = None
        train_hooks = None

    return tf.estimator.EstimatorSpec(mode=mode,
                                      loss=total_loss,
                                      predictions=generated,
                                      training_hooks=train_hooks,
                                      export_outputs=export_outputs,
                                      train_op=train_op)
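In the TRAIN branch the actual parameter updates run inside the MultiStepOps hook (d_train/g_train), so the EstimatorSpec's train_op only has to advance the global step, which is what training_util._increment_global_step(1) provides. A hedged usage sketch; the model_dir, the params values, and train_input_fn below are illustrative placeholders, not values from the source:

estimator = tf.estimator.Estimator(
    model_fn=_encoder_model_fn,
    model_dir='/tmp/dagan_model',  # placeholder
    params={  # keys read by _encoder_model_fn; the values here are placeholders
        'batch_size': 32,
        'generator_inner_layers': 1,
        'discriminator_inner_layers': 1,
        'discriminator_inner_steps': 5,
        'generator_inner_steps': 1,
        'dropout_rate': 0.5,
        'random_rotate': True,
        'z_dim': 100,
        'use_wide_connections': False,
    })
estimator.train(input_fn=train_input_fn)  # input_fn must yield features {'i': ..., 'j': ..., 'z': ...}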