def testContinuous(self):
     with self.test_session():
         step = 5
         decayed_lr = learning_rate_decay.exponential_decay(
             0.05, step, 10, 0.96)
         expected = .05 * 0.96**(5.0 / 10.0)
         self.assertAllClose(decayed_lr.eval(), expected, 1e-6)
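For reference, the continuous (non-staircase) schedule asserted above follows the closed form learning_rate * decay_rate ** (global_step / decay_steps); with staircase=True the exponent is floored, as the staircase tests below check. A small pure-Python restatement of that arithmetic (a sketch of the formula the tests verify, not TensorFlow's implementation):

def exponential_decay_value(learning_rate, global_step, decay_steps, decay_rate,
                            staircase=False):
    # Closed-form value of the schedule at a given step.
    exponent = global_step // decay_steps if staircase else global_step / decay_steps
    return learning_rate * decay_rate ** exponent

# Matches the expected values computed in the tests on this page.
assert abs(exponential_decay_value(0.05, 5, 10, 0.96) - 0.05 * 0.96 ** 0.5) < 1e-12
assert exponential_decay_value(0.1, 100, 3, 0.96, staircase=True) == 0.1 * 0.96 ** (100 // 3)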
Example 2
def get_learning_rate(lr_config, global_step):
    """
    Instantiate a learning rate operation given a configuration with learning rate, name, and parameters.
    :param lr_config:  learning rate configuration
    :param global_step: global step `Tensor`
    :return: learning rate operation
    """
    lr = lr_config.rate
    name = lr_config.name
    if "exponential_decay" == name:
        decay = exponential_decay(lr, global_step, **lr_config.params)
    elif "inverse_time_decay" == name:
        decay = inverse_time_decay(lr, global_step, **lr_config.params)
    elif "vaswani" == name:
        decay = _transformer_learning_rate(lr_config, global_step)
    elif "bert" == name:
        decay = _bert_learning_rate(lr_config, global_step)
    elif "clr" == name:
        decay = cyclic_learning_rate(global_step,
                                     learning_rate=lr_config.rate,
                                     max_lr=lr_config.params.get(
                                         'max_lr', 0.1),
                                     step_size=lr_config.steps_per_epoch *
                                     lr_config.params.get('step_size', 4))
    else:
        raise ValueError("Unknown learning rate schedule: {}".format(name))
    return decay
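A minimal usage sketch for get_learning_rate. The config fields used here (name, rate, params, steps_per_epoch) are assumptions inferred from the attributes the function reads, and the decay functions (exponential_decay and friends) are assumed to be imported at module level as in the original project:

from types import SimpleNamespace

import tensorflow.compat.v1 as tf

lr_config = SimpleNamespace(
    name="exponential_decay",
    rate=0.05,
    steps_per_epoch=1000,  # only consulted by the "clr" branch
    params={"decay_steps": 10000, "decay_rate": 0.96, "staircase": True},
)
global_step = tf.train.get_or_create_global_step()
learning_rate = get_learning_rate(lr_config, global_step)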
Example 3
 def testContinuous(self):
     self.evaluate(variables.global_variables_initializer())
     step = 5
     decayed_lr = learning_rate_decay.exponential_decay(
         0.05, step, 10, 0.96)
     expected = .05 * 0.96**(5.0 / 10.0)
     self.assertAllClose(self.evaluate(decayed_lr), expected, 1e-6)
 def testStaircase(self):
     with self.test_session():
         step = state_ops.variable_op([], dtypes.int32)
         assign_100 = state_ops.assign(step, 100)
         assign_1 = state_ops.assign(step, 1)
         assign_2 = state_ops.assign(step, 2)
         decayed_lr = learning_rate_decay.exponential_decay(0.1, step, 3, 0.96, staircase=True)
         # No change to learning rate
         assign_1.op.run()
         self.assertAllClose(decayed_lr.eval(), 0.1, 1e-6)
         assign_2.op.run()
         self.assertAllClose(decayed_lr.eval(), 0.1, 1e-6)
         # Decayed learning rate
         assign_100.op.run()
         expected = 0.1 * 0.96 ** (100 // 3)
         self.assertAllClose(decayed_lr.eval(), expected, 1e-6)
 def testVariables(self):
   step = variables.VariableV1(1)
   assign_1 = step.assign(1)
   assign_2 = step.assign(2)
   assign_100 = step.assign(100)
   decayed_lr = learning_rate_decay.exponential_decay(
       .1, step, 3, 0.96, staircase=True)
   self.evaluate(variables.global_variables_initializer())
   # No change to learning rate
   self.evaluate(assign_1.op)
   self.assertAllClose(self.evaluate(decayed_lr), .1, 1e-6)
   self.evaluate(assign_2.op)
   self.assertAllClose(self.evaluate(decayed_lr), .1, 1e-6)
   # Decayed learning rate
   self.evaluate(assign_100.op)
   expected = .1 * 0.96**(100 // 3)
   self.assertAllClose(self.evaluate(decayed_lr), expected, 1e-6)
 def testVariables(self):
     with self.test_session():
         step = variables.Variable(1)
         assign_1 = step.assign(1)
         assign_2 = step.assign(2)
         assign_100 = step.assign(100)
         decayed_lr = learning_rate_decay.exponential_decay(0.1, step, 3, 0.96, staircase=True)
         variables.initialize_all_variables().run()
         # No change to learning rate
         assign_1.op.run()
         self.assertAllClose(decayed_lr.eval(), 0.1, 1e-6)
         assign_2.op.run()
         self.assertAllClose(decayed_lr.eval(), 0.1, 1e-6)
         # Decayed learning rate
         assign_100.op.run()
         expected = 0.1 * 0.96 ** (100 // 3)
         self.assertAllClose(decayed_lr.eval(), expected, 1e-6)
Example 8
 def testStaircase(self):
   with self.test_session():
     step = state_ops.variable_op([], dtypes.int32)
     assign_100 = state_ops.assign(step, 100)
     assign_1 = state_ops.assign(step, 1)
     assign_2 = state_ops.assign(step, 2)
     decayed_lr = learning_rate_decay.exponential_decay(.1, step, 3, 0.96,
                                                        staircase=True)
     # No change to learning rate
     assign_1.op.run()
     self.assertAllClose(decayed_lr.eval(), .1, 1e-6)
     assign_2.op.run()
     self.assertAllClose(decayed_lr.eval(), .1, 1e-6)
     # Decayed learning rate
     assign_100.op.run()
     expected = .1 * 0.96**(100 // 3)
     self.assertAllClose(decayed_lr.eval(), expected, 1e-6)
 def testVariables(self):
   with self.test_session():
     step = variables.Variable(1)
     assign_1 = step.assign(1)
     assign_2 = step.assign(2)
     assign_100 = step.assign(100)
     decayed_lr = learning_rate_decay.exponential_decay(.1, step, 3, 0.96,
                                                        staircase=True)
     variables.global_variables_initializer().run()
     # No change to learning rate
     assign_1.op.run()
     self.assertAllClose(decayed_lr.eval(), .1, 1e-6)
     assign_2.op.run()
     self.assertAllClose(decayed_lr.eval(), .1, 1e-6)
     # Decayed learning rate
     assign_100.op.run()
     expected = .1 * 0.96 ** (100 // 3)
     self.assertAllClose(decayed_lr.eval(), expected, 1e-6)
 def testStaircase(self):
   with self.test_session():
     step = gen_state_ops._variable(shape=[], dtype=dtypes.int32,
         name="step", container="", shared_name="")
     assign_100 = state_ops.assign(step, 100)
     assign_1 = state_ops.assign(step, 1)
     assign_2 = state_ops.assign(step, 2)
     decayed_lr = learning_rate_decay.exponential_decay(.1, step, 3, 0.96,
                                                        staircase=True)
     # No change to learning rate
     assign_1.op.run()
     self.assertAllClose(decayed_lr.eval(), .1, 1e-6)
     assign_2.op.run()
     self.assertAllClose(decayed_lr.eval(), .1, 1e-6)
     # Decayed learning rate
     assign_100.op.run()
     expected = .1 * 0.96 ** (100 // 3)
     self.assertAllClose(decayed_lr.eval(), expected, 1e-6)
  def testStaircase(self):
    if context.executing_eagerly():
      step = resource_variable_ops.ResourceVariable(0)
      self.evaluate(variables.global_variables_initializer())
      decayed_lr = learning_rate_decay.exponential_decay(
          .1, step, 3, 0.96, staircase=True)

      # No change to learning rate due to staircase
      expected = .1
      self.evaluate(step.assign(1))
      self.assertAllClose(self.evaluate(decayed_lr), expected, 1e-6)

      expected = .1
      self.evaluate(step.assign(2))
      self.assertAllClose(self.evaluate(decayed_lr), .1, 1e-6)

      # Decayed learning rate
      expected = .1 * 0.96 ** (100 // 3)
      self.evaluate(step.assign(100))
      self.assertAllClose(self.evaluate(decayed_lr), expected, 1e-6)
Example 13
def apply_lr_decay(cfg, global_step):
    # Learning rate schedule
    if cfg.lr_decay is None:
        lr = cfg.lr
    elif cfg.lr_decay == 'exp':
        lr = exponential_decay(cfg.lr,
                               global_step,
                               cfg.decay_steps,
                               cfg.decay_rate,
                               staircase=cfg.staircase)
    elif cfg.lr_decay == 'piecewise':
        lr = piecewise_constant(global_step, cfg.lr_boundaries, cfg.lr_values)
    elif cfg.lr_decay == 'polynomial':
        lr = polynomial_decay(cfg.lr,
                              global_step,
                              cfg.decay_steps,
                              end_learning_rate=cfg.end_lr,
                              power=cfg.power,
                              cycle=cfg.staircase)

    elif cfg.lr_decay == 'natural_exp':
        lr = natural_exp_decay(cfg.lr,
                               global_step,
                               cfg.decay_steps,
                               cfg.decay_rate,
                               staircase=cfg.staircase)
    elif cfg.lr_decay == 'inverse_time':
        lr = inverse_time_decay(cfg.lr,
                                global_step,
                                cfg.decay_steps,
                                cfg.decay_rate,
                                staircase=cfg.staircase)

    elif cfg.lr_decay == 'STN':
        epoch = tf.cast(global_step / cfg.decay_steps, tf.int32)
        lr = cfg.lr * tf.pow(0.5, tf.cast(epoch / 50, cfg._FLOATX))
    else:
        raise NotImplementedError()
    return lr
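A minimal usage sketch for apply_lr_decay, exercising the 'exp' branch. The cfg fields are assumptions based on the attributes read above, and tf plus the decay functions are assumed to be imported as in the original module:

from types import SimpleNamespace

import tensorflow.compat.v1 as tf

cfg = SimpleNamespace(
    lr=0.01,
    lr_decay='exp',
    decay_steps=10000,
    decay_rate=0.96,
    staircase=True,
)
global_step = tf.train.get_or_create_global_step()
lr = apply_lr_decay(cfg, global_step)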
Example 14
 def f2():
     return learning_rate_decay.exponential_decay(
         lr_init, lr_gstep, decay_steps, lr_dec, True)
 def testContinuous(self):
   self.evaluate(variables.global_variables_initializer())
   step = 5
   decayed_lr = learning_rate_decay.exponential_decay(0.05, step, 10, 0.96)
   expected = .05 * 0.96**(5.0 / 10.0)
   self.assertAllClose(self.evaluate(decayed_lr), expected, 1e-6)
 def f2():
   return learning_rate_decay.exponential_decay(lr_init, lr_gstep,
                                                decay_steps, lr_dec, True)
Example 17
    def test_ppo_ops_gae(self):
        ops.reset_default_graph()
        np.random.seed(42)
        random_seed.set_random_seed(42)
        env = gym.make('CartPole-v0')
        env.seed(42)

        # Setup the policy and model
        global_step = training_util.get_or_create_global_step()
        deterministic_ph = array_ops.placeholder(dtypes.bool, [],
                                                 name='deterministic')
        exploration_op = learning_rate_decay.exponential_decay(
            PPOTest.hparams.initial_exploration, global_step,
            PPOTest.hparams.exploration_decay_steps,
            PPOTest.hparams.exploration_decay_rate)

        state_distribution, state_ph = gym_ops.distribution_from_gym_space(
            env.observation_space, name='state_space')

        # values
        with variable_scope.variable_scope('logits'):
            body_op = mlp(state_ph, PPOTest.hparams.hidden_layers)
            action_distribution, action_value_op = gym_ops.distribution_from_gym_space(
                env.action_space, logits=[body_op], name='action_space')
            action_op = array_ops.squeeze(
                sampling_ops.epsilon_greedy(action_distribution,
                                            exploration_op, deterministic_ph))
            body_op = core.dense(body_op,
                                 units=PPOTest.hparams.value_units,
                                 activation=nn_ops.relu,
                                 use_bias=False)
            value_op = array_ops.squeeze(
                core.dense(body_op, units=1, use_bias=False), -1)
        policy_variables = variables.trainable_variables(scope='logits')

        # target
        with variable_scope.variable_scope('old_logits'):
            old_body_op = mlp(state_ph, PPOTest.hparams.hidden_layers)
            old_action_distribution, old_action_value_op = gym_ops.distribution_from_gym_space(
                env.action_space, logits=[old_body_op], name='action_space')
        assign_policy_op = shortcuts.assign_scope('logits', 'old_logits')

        # Setup the dataset
        stream = streams.Uniform.from_distributions(state_distribution,
                                                    action_distribution,
                                                    with_values=True)
        replay_dataset = dataset.ReplayDataset(
            stream, max_sequence_length=PPOTest.hparams.max_sequence_length)
        replay_dataset = replay_dataset.batch(PPOTest.hparams.batch_size)
        replay_op = replay_dataset.make_one_shot_iterator().get_next()

        action_ph = array_ops.placeholder(stream.action_dtype,
                                          [None, None] + stream.action_shape,
                                          name='action')
        value_ph = array_ops.placeholder(stream.reward_dtype,
                                         [None, None] + stream.reward_shape,
                                         name='value')
        reward_ph = array_ops.placeholder(stream.reward_dtype,
                                          [None, None] + stream.reward_shape,
                                          name='reward')
        terminal_ph = array_ops.placeholder(dtypes.bool, [None, None],
                                            name='terminal')
        sequence_length_ph = array_ops.placeholder(dtypes.int32, [None, 1],
                                                   name='sequence_length')
        sequence_length = array_ops.squeeze(sequence_length_ph, -1)

        # Setup the loss/optimization procedure
        advantage_op, return_op = ppo_ops.generalized_advantage_estimate(
            reward_ph,
            value_ph,
            sequence_length,
            max_sequence_length=PPOTest.hparams.max_sequence_length,
            weights=(1 - math_ops.cast(terminal_ph, reward_ph.dtype)),
            discount=PPOTest.hparams.discount,
            lambda_td=PPOTest.hparams.lambda_td)

        # actor loss
        logits_prob = action_distribution.log_prob(action_ph)
        old_logits_prob = old_action_distribution.log_prob(action_ph)
        ratio = math_ops.exp(logits_prob - old_logits_prob)
        clipped_ratio = clip_ops.clip_by_value(ratio,
                                               1. - PPOTest.hparams.epsilon,
                                               1. + PPOTest.hparams.epsilon)
        actor_loss_op = -math_ops.minimum(ratio * advantage_op,
                                          clipped_ratio * advantage_op)
        critic_loss_op = math_ops.square(
            value_op - return_op) * PPOTest.hparams.value_coeff
        entropy_loss_op = -action_distribution.entropy(
            name='entropy') * PPOTest.hparams.entropy_coeff
        loss_op = actor_loss_op + critic_loss_op + entropy_loss_op

        # total loss
        loss_op = math_ops.reduce_mean(
            math_ops.reduce_sum(loss_op, axis=-1) /
            math_ops.cast(sequence_length, loss_op.dtype))

        optimizer = adam.AdamOptimizer(
            learning_rate=PPOTest.hparams.learning_rate)
        train_op = optimizer.minimize(loss_op, var_list=policy_variables)
        train_op = control_flow_ops.cond(
            gen_math_ops.equal(
                gen_math_ops.mod(
                    ops.convert_to_tensor(PPOTest.hparams.assign_policy_steps,
                                          dtype=dtypes.int64),
                    (global_step + 1)), 0),
            lambda: control_flow_ops.group(*[train_op, assign_policy_op]),
            lambda: train_op)

        with self.test_session() as sess:
            sess.run(variables.global_variables_initializer())
            sess.run(assign_policy_op)

            for iteration in range(PPOTest.hparams.num_iterations):
                rewards = gym_test_utils.rollout_with_values_on_gym_env(
                    sess,
                    env,
                    state_ph,
                    deterministic_ph,
                    action_value_op,
                    action_op,
                    value_op,
                    num_episodes=PPOTest.hparams.num_episodes,
                    stream=stream)

                while True:
                    try:
                        replay = sess.run(replay_op)
                    except (errors_impl.InvalidArgumentError,
                            errors_impl.OutOfRangeError):
                        break

                    _, loss = sess.run(
                        (train_op, loss_op),
                        feed_dict={
                            state_ph: replay.state,
                            action_ph: replay.action,
                            value_ph: replay.value,
                            reward_ph: replay.reward,
                            terminal_ph: replay.terminal,
                            sequence_length_ph: replay.sequence_length,
                        })
                    print(loss)

                rewards = gym_test_utils.rollout_on_gym_env(
                    sess,
                    env,
                    state_ph,
                    deterministic_ph,
                    action_value_op,
                    action_op,
                    num_episodes=PPOTest.hparams.num_episodes,
                    deterministic=True,
                    save_replay=False)
                print('average_rewards = {}'.format(
                    rewards / PPOTest.hparams.num_episodes))
Example 18
  def test_q_ops_dqn(self):
    ops.reset_default_graph()
    np.random.seed(42)
    random_seed.set_random_seed(42)
    env = gym.make('CartPole-v0')
    env.seed(42)

    # Setup the policy and model
    global_step = training_util.get_or_create_global_step()
    deterministic_ph = array_ops.placeholder(
        dtypes.bool, [], name='deterministic')
    exploration_op = learning_rate_decay.exponential_decay(
        QTest.hparams.initial_exploration,
        global_step,
        QTest.hparams.exploration_decay_steps,
        QTest.hparams.exploration_decay_rate)


    state_distribution, state_ph = gym_ops.distribution_from_gym_space(
        env.observation_space, name='state_space')
    with variable_scope.variable_scope('logits'):
      action_value_op = mlp(state_ph, QTest.hparams.hidden_layers)
      action_distribution, action_value_op = gym_ops.distribution_from_gym_space(
          env.action_space, logits=[action_value_op], name='action_space')
      action_op = array_ops.squeeze(sampling_ops.epsilon_greedy(
          action_distribution, exploration_op, deterministic_ph))
    policy_variables = variables.trainable_variables(scope='logits')


    next_state_ph = shortcuts.placeholder_like(state_ph, name='next_state_space')
    with variable_scope.variable_scope('logits', reuse=True):
      next_action_value_op = mlp(next_state_ph, QTest.hparams.hidden_layers)
      next_action_distribution, next_action_value_op = gym_ops.distribution_from_gym_space(
          env.action_space, logits=[next_action_value_op], name='action_space')
      next_action_op = array_ops.squeeze(sampling_ops.epsilon_greedy(
          next_action_distribution, exploration_op, deterministic_ph))


    # Setup the dataset
    stream = streams.Uniform.from_distributions(
        state_distribution, action_distribution)
    replay_dataset = dataset.ReplayDataset(
        stream, max_sequence_length=QTest.hparams.max_sequence_length)
    replay_dataset = replay_dataset.batch(QTest.hparams.batch_size)
    replay_op = replay_dataset.make_one_shot_iterator().get_next()

    action_ph = array_ops.placeholder(
        stream.action_dtype, [None, None] + stream.action_shape, name='action')
    reward_ph = array_ops.placeholder(
        stream.reward_dtype, [None, None] + stream.reward_shape, name='reward')
    terminal_ph = array_ops.placeholder(
        dtypes.bool, [None, None], name='terminal')
    sequence_length_ph = array_ops.placeholder(
        dtypes.int32, [None, 1], name='sequence_length')
    sequence_length = array_ops.squeeze(sequence_length_ph, -1)

    q_value_op, expected_q_value_op = q_ops.expected_q_value(
        reward_ph,
        action_ph,
        action_value_op,
        next_action_value_op,
        weights=(1 - math_ops.cast(terminal_ph, reward_ph.dtype)),
        discount=QTest.hparams.discount)

    # mean_squared_error
    loss_op = math_ops.square(q_value_op - expected_q_value_op)

    loss_op = math_ops.reduce_mean(
        math_ops.reduce_sum(loss_op, axis=-1) / math_ops.cast(
            sequence_length, loss_op.dtype))
    optimizer = adam.AdamOptimizer(
        learning_rate=QTest.hparams.learning_rate)
    train_op = optimizer.minimize(loss_op, var_list=policy_variables)

    with self.test_session() as sess:
      sess.run(variables.global_variables_initializer())
      for iteration in range(QTest.hparams.num_iterations):
        rewards = gym_test_utils.rollout_on_gym_env(
            sess, env, state_ph, deterministic_ph,
            action_value_op, action_op,
            num_episodes=QTest.hparams.num_episodes,
            stream=stream)

        while True:
          try:
            replay = sess.run(replay_op)
          except (errors_impl.InvalidArgumentError, errors_impl.OutOfRangeError):
            break
          _, loss = sess.run(
              (train_op, loss_op),
              feed_dict={
                state_ph: replay.state,
                next_state_ph: replay.next_state,
                action_ph: replay.action,
                reward_ph: replay.reward,
                terminal_ph: replay.terminal,
                sequence_length_ph: replay.sequence_length,
              })

        rewards = gym_test_utils.rollout_on_gym_env(
            sess, env, state_ph, deterministic_ph,
            action_value_op, action_op,
            num_episodes=QTest.hparams.num_episodes,
            deterministic=True, save_replay=False)
        print('average_rewards = {}'.format(rewards / QTest.hparams.num_episodes))
 def testContinuous(self):
   with self.test_session():
     step = 5
     decayed_lr = learning_rate_decay.exponential_decay(0.05, step, 10, 0.96)
     expected = .05 * 0.96 ** (5.0 / 10.0)
     self.assertAllClose(decayed_lr.eval(), expected, 1e-6)
 def make_opt():
   gstep = training_util.get_or_create_global_step()
   lr = learning_rate_decay.exponential_decay(1.0, gstep, 10, 0.9)
   return training.GradientDescentOptimizer(lr)
 def make_opt():
     gstep = training_util.get_or_create_global_step()
     lr = learning_rate_decay.exponential_decay(1.0, gstep, 10, 0.9)
     return training.GradientDescentOptimizer(lr)
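The make_opt snippets above use TensorFlow's internal test aliases (training_util, training, learning_rate_decay). A self-contained sketch of the same pattern with the public tf.compat.v1 API, using an illustrative toy loss:

import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

def make_opt():
    # Same schedule as above: lr = 1.0 * 0.9 ** (global_step / 10).
    gstep = tf.train.get_or_create_global_step()
    lr = tf.train.exponential_decay(1.0, gstep, 10, 0.9)
    return tf.train.GradientDescentOptimizer(lr)

w = tf.Variable(3.0)   # toy parameter
loss = tf.square(w)    # toy loss
train_op = make_opt().minimize(loss, global_step=tf.train.get_global_step())

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for _ in range(5):
        sess.run(train_op)  # global_step advances, so the learning rate decays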
Example 22
  def test_q_ops_quantile_dqn(self):
    env = gym.make('CartPole-v0')
    ops.reset_default_graph()
    np.random.seed(42)
    random_seed.set_random_seed(42)
    env.seed(42)

    # Setup the policy and model
    global_step = training_util.get_or_create_global_step()
    deterministic_ph = array_ops.placeholder(
        dtypes.bool, [], name='deterministic')
    exploration_op = learning_rate_decay.exponential_decay(
        QTest.hparams.initial_exploration,
        global_step,
        QTest.hparams.exploration_decay_steps,
        QTest.hparams.exploration_decay_rate)

    state_distribution, state_ph = gym_ops.distribution_from_gym_space(
        env.observation_space, name='state_space')
    action_distribution, _ = gym_ops.distribution_from_gym_space(
        env.action_space, name='action_space')

    # Setup the dataset
    stream = streams.Uniform.from_distributions(
        state_distribution, action_distribution)

    with variable_scope.variable_scope('logits'):
      action_value_op = mlp(state_ph, QTest.hparams.hidden_layers)
      action_value_op = core.dense(
          action_value_op,
          stream.action_value_shape[-1] * QTest.hparams.num_quantiles,
          use_bias=False)
      action_value_op_shape = array_ops.shape(action_value_op)
      action_value_shape = [
          action_value_op_shape[0],
          action_value_op_shape[1],
          stream.action_value_shape[-1],
          QTest.hparams.num_quantiles]
      action_value_op = gen_array_ops.reshape(action_value_op, action_value_shape)
      mean_action_value_op = math_ops.reduce_mean(action_value_op, axis=-1)
      action_op = math_ops.argmax(mean_action_value_op, axis=-1)
      action_op = array_ops.squeeze(action_op)
    policy_variables = variables.trainable_variables(scope='logits')


    next_state_ph = shortcuts.placeholder_like(state_ph, name='next_state_space')
    with variable_scope.variable_scope('targets'):
      target_next_action_value_op = mlp(next_state_ph, QTest.hparams.hidden_layers)
      target_next_action_value_op = core.dense(
          target_next_action_value_op,
          stream.action_value_shape[-1] * QTest.hparams.num_quantiles,
          use_bias=False)
      target_next_action_value_op_shape = array_ops.shape(target_next_action_value_op)
      target_next_action_value_shape = [
          target_next_action_value_op_shape[0],
          target_next_action_value_op_shape[1],
          stream.action_value_shape[-1],
          QTest.hparams.num_quantiles]
      target_next_action_value_op = gen_array_ops.reshape(
          target_next_action_value_op, target_next_action_value_shape)
      mean_target_next_action_value_op = math_ops.reduce_mean(
          target_next_action_value_op, axis=-1)
    assign_target_op = shortcuts.assign_scope('logits', 'targets')  # copy policy weights into the 'targets' scope defined above


    replay_dataset = dataset.ReplayDataset(
        stream, max_sequence_length=QTest.hparams.max_sequence_length)
    replay_dataset = replay_dataset.batch(QTest.hparams.batch_size)
    replay_op = replay_dataset.make_one_shot_iterator().get_next()


    action_ph = array_ops.placeholder(
        stream.action_dtype, [None, None] + stream.action_shape, name='action')
    reward_ph = array_ops.placeholder(
        stream.reward_dtype, [None, None] + stream.reward_shape, name='reward')
    terminal_ph = array_ops.placeholder(
        dtypes.bool, [None, None], name='terminal')
    sequence_length_ph = array_ops.placeholder(
        dtypes.int32, [None, 1], name='sequence_length')
    sequence_length = array_ops.squeeze(sequence_length_ph, -1)

    q_value_op, expected_q_value_op = q_ops.expected_q_value(
        array_ops.expand_dims(reward_ph, -1),
        action_ph,
        action_value_op,
        (target_next_action_value_op, mean_target_next_action_value_op),
        weights=array_ops.expand_dims(
            1 - math_ops.cast(terminal_ph, reward_ph.dtype), -1),
        discount=QTest.hparams.discount)

    u = expected_q_value_op - q_value_op
    loss_op = losses_impl.huber_loss(u, delta=QTest.hparams.huber_loss_delta)

    tau_op = (2. * math_ops.range(
        0, QTest.hparams.num_quantiles, dtype=u.dtype) + 1) / (
            2. * QTest.hparams.num_quantiles)

    loss_op *= math_ops.abs(tau_op - math_ops.cast(u < 0, tau_op.dtype))
    loss_op = math_ops.reduce_mean(loss_op, axis=-1)

    loss_op = math_ops.reduce_mean(
        math_ops.reduce_sum(loss_op, axis=-1) / math_ops.cast(
            sequence_length, loss_op.dtype))
    optimizer = adam.AdamOptimizer(
        learning_rate=QTest.hparams.learning_rate)
    train_op = optimizer.minimize(loss_op, var_list=policy_variables)
    train_op = control_flow_ops.cond(
        gen_math_ops.equal(
            gen_math_ops.mod(
                ops.convert_to_tensor(
                    QTest.hparams.assign_target_steps, dtype=dtypes.int64),
                (global_step + 1)), 0),
        lambda: control_flow_ops.group(*[train_op, assign_target_op]),
        lambda: train_op)

    with self.test_session() as sess:
      sess.run(variables.global_variables_initializer())
      sess.run(assign_target_op)

      for iteration in range(QTest.hparams.num_iterations):
        rewards = gym_test_utils.rollout_on_gym_env(
            sess, env, state_ph, deterministic_ph,
            mean_action_value_op, action_op,
            num_episodes=QTest.hparams.num_episodes,
            stream=stream)

        while True:
          try:
            replay = sess.run(replay_op)
          except (errors_impl.InvalidArgumentError, errors_impl.OutOfRangeError):
            break
          loss, _ = sess.run(
              (loss_op, train_op),
              feed_dict={
                state_ph: replay.state,
                next_state_ph: replay.next_state,
                action_ph: replay.action,
                reward_ph: replay.reward,
                terminal_ph: replay.terminal,
                sequence_length_ph: replay.sequence_length,
              })

        rewards = gym_test_utils.rollout_on_gym_env(
            sess, env, state_ph, deterministic_ph,
            mean_action_value_op, action_op,
            num_episodes=QTest.hparams.num_episodes,
            deterministic=True, save_replay=False)
        print('average_rewards = {}'.format(rewards / QTest.hparams.num_episodes))