Example #1
def _dense_projection(inputs, shape, trainable=True):
    shape = list(shape)
    flat_shape = array_utils.product(shape)
    target_shape = _EXTRA_DIMS + shape

    input_shape = inputs.get_shape().as_list()

    input_dims = len(input_shape)
    expected_dims = len(target_shape)
    assert_utils.assert_true(
        input_dims == expected_dims,
        '`inputs` must have the same number of dims as the expected number of dims. '
        'expected = {}, actual = {}'.format(expected_dims, input_dims))

    if not array_utils.all_equal(input_shape, target_shape):
        input_shape_ = array_ops.shape(inputs)
        if len(input_shape) > 3:
            inputs = gen_array_ops.reshape(
                inputs, [input_shape_[0], input_shape_[1], -1])
        inputs = core.dense(inputs,
                            flat_shape,
                            use_bias=False,
                            trainable=trainable)
        inputs = gen_array_ops.reshape(
            inputs, [input_shape_[0], input_shape_[1]] + shape)
    return inputs
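
A minimal usage sketch, assuming TF 1.x graph mode and that `_EXTRA_DIMS` stands for the two leading (batch, time) dimensions; the placeholder shape below is purely illustrative.

import tensorflow as tf

inputs = tf.placeholder(tf.float32, [None, None, 4, 4])   # [batch, time, 4, 4]
projected = _dense_projection(inputs, [3, 2])             # -> [batch, time, 3, 2]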
Example #2
 def step(self, action):
   assert_utils.assert_true(
       self.action_space.contains(action),
       '`action_space` must contain `action`.')
   # sample from bernoulli distribution with p = bandit[action]
   reward = int(np.random.uniform() < self.bandit[action])
   self.time += 1
   done = self.time > self.max_time
   return [self.time], reward, done, {}
Example #3
def ssd(x, y, extra_dims=2):
    """`Sum Squared over D`: `l2` over `n`-dimensions (starting at `extra_dims`)

  Math:
    ssd(x, y) = sum((x - y)^2) over dims in [extra_dims, ndims(x))
  """
    assert_utils.assert_true(
        extra_dims >= 0, "extra_dims must be >= 0, got {}".format(extra_dims))
    shape = y.get_shape().as_list()[extra_dims:]
    return math_ops.reduce_sum(math_ops.square(x - y),
                               axis=array_utils.ranged_axes(shape))
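
A NumPy sketch of what `ssd` reduces for `extra_dims=2`: the squared `l2` distance summed over every dim after the first two.

import numpy as np

x = np.random.rand(4, 10, 3)              # [batch, time, d]
y = np.random.rand(4, 10, 3)
out = np.sum(np.square(x - y), axis=-1)   # [batch, time]: per-step sum of squared differences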
Example #4
 def __init__(self, max_time=5):
   self.time = 0
   self.max_time = max_time
   # fixed 11-armed bandit: every arm's base payout is 1 (the first arm is only
   # treated specially in `step`, see Example #6)
   self.bandit = np.array([1.] * 11)
   self.informative_action = 0
   self.action_space = discrete.Discrete(len(self.bandit))
   self.observation_space = discrete.Discrete(1)
   self.seed()
Example #5
def expected_q_value(reward,
                     action,
                     action_value,
                     next_action_value,
                     weights=1.,
                     discount=.95):
    """Computes the expected q returns and values.

  This covers architectures such as DQN, Double-DQN, Dueling-DQN and Noisy-DQN.

  Arguments:
    reward: 1D or 2D `tf.Tensor`, contiguous sequence(s) of rewards.
    action: 1D or 2D `tf.Tensor`, contiguous sequence(s) of actions.
    action_value: `tf.Tensor`, action-value (Q) estimates for the current state.
    next_action_value: `tf.Tensor`, or a `list`/`tuple` of 2 `tf.Tensor`s, where the first entry is
        the `model(next_state) = action_value` and the second is `target(next_state) = action_value`.
    weights: `tf.Tensor`, the weights/mask to apply to the result.
    discount: 0D scalar, the discount factor (gamma).

  Returns:
    `tuple` containing the `q_value` `tf.Tensor` and `expected_q_value` `tf.Tensor`.

  Reference:
    https://storage.googleapis.com/deepmind-media/dqn/DQNNaturePaper.pdf
  """
    weights = ops.convert_to_tensor(weights, dtype=reward.dtype)
    discount = ops.convert_to_tensor(discount, dtype=reward.dtype)

    lda = action_value.get_shape()[-1].value
    q_value = gather_along_second_axis(action_value, action)
    q_value.set_shape([None, None, lda])

    if isinstance(next_action_value, tuple) or isinstance(
            next_action_value, list):
        assert_utils.assert_true(
            len(next_action_value) == 2,
            '`next_action_value` must be a `tuple` of length = 2')
        next_action_value, target_next_action_value = next_action_value
        lda = next_action_value.get_shape()[-1].value
        next_q_value = gather_along_second_axis(
            next_action_value,
            math_ops.argmax(target_next_action_value,
                            -1,
                            output_type=dtypes.int32))
        next_q_value.set_shape([None, None, lda])
    else:
        lda = next_action_value.get_shape()[-1].value
        next_q_value = gather_along_second_axis(
            next_action_value,
            math_ops.argmax(next_action_value, -1, output_type=dtypes.int32))
        next_q_value.set_shape([None, None, lda])

    expected_q_value = reward + discount * next_q_value * weights
    return (q_value, expected_q_value)
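
A NumPy reading of the target built above when `next_action_value` is a `(model, target)` pair: the target network's argmax selects the next action, the model's values score it, and the bootstrapped term is scaled by `discount` and `weights`. Names below are illustrative, not part of the library.

import numpy as np

def expected_q_target(reward, model_next_q, target_next_q, weights=1.0, discount=0.95):
    # action chosen by the target network's argmax, evaluated with the model's values
    chosen = np.argmax(target_next_q, axis=-1)                                # [batch, time]
    next_q = np.take_along_axis(model_next_q, chosen[..., None], axis=-1)[..., 0]
    return reward + discount * next_q * weights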
Example #6
  def step(self, action):
    assert_utils.assert_true(
        self.action_space.contains(action),
        '`action_space` must contain `action`.')

    if self.time == 0:
      if action == self.informative_action:
        reward = .55
      else:
        reward = 1.4
    else:
      reward = self.bandit[action]

    self.time += 1
    done = self.time > self.max_time
    return [self.time], reward, done, {}
Example #7
    def from_distributions(cls,
                           state_distribution,
                           action_distribution,
                           reward_shape=[],
                           reward_dtype=dtypes.float32,
                           with_values=False):
        """Construct an `alchemy.contrib.rl.ReplayStream` from state and action distributions.

    Arguments:
      state_distribution: distribution of the state space.
      action_distribution: distribution of the action space.
      reward_shape: shape representing the reward for a chosen action.
      reward_dtype: dtype representing the reward for a chosen action.
      with_values: Python `bool` for recording values.

    Returns:
      A `ay.contrib.rl.ReplayStream`.
    """
        assert_utils.assert_true(
            distribution_utils.is_distribution(state_distribution),
            '`state_distribution` must be an instance of `tf.distributions.Distribution`'
        )
        assert_utils.assert_true(
            distribution_utils.is_distribution(action_distribution),
            '`action_distribution` must be an instance of `tf.distributions.Distribution`'
        )

        state_shape, state_dtype = distribution_utils.logits_shape_and_dtype(
            state_distribution)
        action_value_shape, action_value_dtype = distribution_utils.logits_shape_and_dtype(
            action_distribution)
        action_shape, action_dtype = distribution_utils.sample_shape_and_dtype(
            action_distribution)
        if isinstance(reward_shape, tensor_shape.TensorShape):
            reward_shape = reward_shape.as_list()

        return cls(state_shape,
                   state_dtype,
                   action_shape,
                   action_dtype,
                   action_value_shape,
                   action_value_dtype,
                   reward_shape,
                   type_utils.safe_tf_dtype(reward_dtype),
                   with_values=with_values)
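
A hedged usage sketch, assuming this classmethod lives on `ReplayStream` and that the distributions come from `distribution_from_gym_space` (Example #13):

import gym

env = gym.make('Pendulum-v0')
state_dist = distribution_from_gym_space(env.observation_space)
action_dist = distribution_from_gym_space(env.action_space)
stream = ReplayStream.from_distributions(state_dist.distribution,
                                         action_dist.distribution,
                                         with_values=True)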
Example #8
  def __init__(self, probs, max_time=99):
    self.time = 0
    self.max_time = max_time
    # independent case: the two arms have separate payout probabilities p1 and p2
    # (they need not sum to 1)
    if isinstance(probs, list) or isinstance(probs, tuple):
      assert_utils.assert_true(
          len(probs) == 2,
          "if `probs` was meant to be a `list`/`tuple`, then it must be of size 2.")
      self.bandit = np.array([probs[0], probs[1]])
    # dependent case: the two arms have complementary payout probabilities p and 1 - p
    # (they sum to 1)
    else:
      self.bandit = np.array([probs, 1. - probs])

    self.action_space = discrete.Discrete(len(self.bandit))
    self.observation_space = discrete.Discrete(1)
    self.seed()
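
An illustrative episode loop; `TwoArmedBanditEnv` is a placeholder name for the class these methods belong to (only `__init__` and a `step` like Example #2's are shown):

env = TwoArmedBanditEnv(probs=0.8, max_time=99)   # dependent arms: p and 1 - p
done, total_reward = False, 0.
while not done:
    action = env.action_space.sample()
    obs, reward, done, info = env.step(action)
    total_reward += reward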
Example #9
    def create_example(self, k=8, use_modified=False):
        """Creates a single (modified) example of length `k`.

    Arguments:
      k: an even `int` that defines the length of the sample space. For example, if `k = 8` and the
          vocab contains `ATCG`, then a sample would look like this: (A9C5G1T3??C, 5).
      use_modified: `bool` that, when `True`, makes samples contiguous alpha-numeric. For example,
          when `k = 8` and the vocab contains `ATCG`, then a sample would look like this:
          (ACTG9513??C, 5). The label is the same as the unmodified version, but the sequence is no
          longer zipped.

    Returns:
      A tuple containing the one-hot encoded values from the vocab. For example, if `k = 8` and the
          vocab contains `ATCG`, then this would return onehot(A9C5G1T3??C, 5), where the
          length of the onehot encoding vectors = len('ACTG') + len([0...9]) + len('?')
          = `vocab_size`.
    """
        q, r = divmod(k, 2)
        assert_utils.assert_true(
            r == 0 and k > 1 and k < self._alphabet_size,
            "k must be even, > 1, and < {}".format(self._alphabet_size))

        letters = np.random.choice(range(0, self._chars_size),
                                   q,
                                   replace=False)
        numbers = np.random.choice(range(self._chars_size + 1,
                                         self._alphabet_size),
                                   q,
                                   replace=True)
        if use_modified:
            x = np.concatenate((letters, numbers))
        else:
            x = np.stack((letters, numbers)).T.ravel()

        x = np.append(x, [self._alphabet_size, self._alphabet_size])
        index = np.random.choice(range(0, q), 1, replace=False)
        x = np.append(x, [letters[index]]).astype('int')
        y = numbers[index]
        return (self._encoder[x], self._encoder[y][0])
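
A small NumPy demonstration of the interleaving trick used above for the unmodified case:

import numpy as np

letters = np.array([0, 2, 3, 1])
numbers = np.array([9, 5, 1, 3])
np.stack((letters, numbers)).T.ravel()   # -> array([0, 9, 2, 5, 3, 1, 1, 3]): letter/number pairs zipped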
Example #10
def epsilon_greedy(dist, epsilon, deterministic):
    """Compute the mode of the distribution if epsilon < X ~ U(0, 1), else sample.

  Arguments:
    dist: a `tf.distribution.Distribution` to sample/mode from.
    epsilon: scalar `tf.Tensor`.
    deterministic: `bool` or boolean `tf.Tensor`; if `True` the mode is always chosen.

  Raises:
    `AssertionError` if dist is not a `tf.distribution.Distribution`.

  Returns:
    `tf.Tensor` shape of `dist.event_shape`.
  """
    assert_utils.assert_true(
        distribution_utils.is_distribution(dist),
        '`dist` must be a `tf.distribution.Distribution`.')

    deterministic_sample = lambda: dist.mode()
    stochastic_sample = lambda: control_flow_ops.cond(
        epsilon < random_ops.random_uniform([]),
        deterministic_sample,
        lambda: dist.sample())
    return control_flow_ops.cond(deterministic, deterministic_sample,
                                 stochastic_sample)
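
A hedged usage sketch (TF 1.x graph mode), pairing `epsilon_greedy` with a distribution built as in Example #13:

import gym
import tensorflow as tf

dist = distribution_from_gym_space(gym.make('CartPole-v0').action_space).distribution
epsilon = tf.placeholder_with_default(0.1, shape=[])
deterministic = tf.placeholder_with_default(False, shape=[])
action = epsilon_greedy(dist, epsilon, deterministic)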
Example #11
def ReplayDataset(replay_stream, max_sequence_length=200, name=None):
    """Creates a `tf.data.Dataset` from a `ay.contrib.rl.ReplayStream` instance.

  Arguments:
    replay_stream: `ay.contrib.rl.ReplayStream` instance. Must implement `replay_stream.read`.
        The method is called `replay_stream.read(limit=max_sequence_length)` each time an instance
        is requested by the dataset. This method should return `None` or raise an
        `tf.errors.OutOfRangeError` when the stream is done and execution of the dataset should stop.
        `replay_stream.read` should always return a `tf.SequenceExample` proto.
    max_sequence_length: `int`, maximum sequence length; longer replays are truncated and
        shorter replays are padded (see `pad_or_truncate_map` below).
    name: Python `str`, optional name for the ops created here.

  Returns:
    A `tf.data.Dataset`.

  Raises:
    A `tf.errors.OutOfRangeError` when the stream returns `None` or raises
        `tf.errors.OutOfRangeError`.
  """
    assert_utils.assert_true(
        isinstance(replay_stream, streams.ReplayStream),
        '`replay_stream` must be an instance of `ay.contrib.rl.ReplayStream`')

    with ops.name_scope(name or 'replay_dataset'):
        state_shape = list(replay_stream.state_shape)
        state_dtype = replay_stream.state_dtype
        action_shape = list(replay_stream.action_shape)
        action_dtype = replay_stream.action_dtype
        action_value_shape = list(replay_stream.action_value_shape)
        action_value_dtype = replay_stream.action_value_dtype
        reward_shape = list(replay_stream.reward_shape)
        reward_dtype = replay_stream.reward_dtype

        replay_dtypes = {
            'state': state_dtype,
            'next_state': state_dtype,
            'action': action_dtype,
            'action_value': action_value_dtype,
            'reward': reward_dtype,
            'terminal': dtypes.bool,
            'sequence_length': dtypes.int32,
        }

        if replay_stream.with_values:
            replay_dtypes['value'] = reward_dtype

        def convert_to_safe_feature_type(dtype):
            return type_utils.safe_tf_dtype(
                serialize.type_to_feature[dtype][-1])

        replay_features = {
            'state':
            parsing_ops.FixedLenSequenceFeature(
                shape=state_shape,
                dtype=convert_to_safe_feature_type(state_dtype)),
            'next_state':
            parsing_ops.FixedLenSequenceFeature(
                shape=state_shape,
                dtype=convert_to_safe_feature_type(state_dtype)),
            'action':
            parsing_ops.FixedLenSequenceFeature(
                shape=action_shape,
                dtype=convert_to_safe_feature_type(action_dtype)),
            'action_value':
            parsing_ops.FixedLenSequenceFeature(
                shape=action_value_shape,
                dtype=convert_to_safe_feature_type(action_value_dtype)),
            'reward':
            parsing_ops.FixedLenSequenceFeature(
                shape=reward_shape,
                dtype=convert_to_safe_feature_type(reward_dtype)),
            'terminal':
            parsing_ops.FixedLenSequenceFeature(
                shape=[], dtype=convert_to_safe_feature_type(dtypes.bool)),
            'sequence_length':
            parsing_ops.FixedLenSequenceFeature(
                shape=[], dtype=convert_to_safe_feature_type(dtypes.int32)),
        }

        if replay_stream.with_values:
            replay_features['value'] = parsing_ops.FixedLenSequenceFeature(
                shape=reward_shape,
                dtype=convert_to_safe_feature_type(reward_dtype))

        def convert_and_fix_dtypes(replay):
            """Cast dtypes back to their original types."""
            fixed_replay = {}
            for k, v in replay.items():
                fixed_replay[k] = math_ops.cast(v, dtype=replay_dtypes[k])
            return fixed_replay

        def generator():
            """Yield serialized `tf.SequenceExample` strings from the `ay.contrib.rl.ReplayStream`."""
            while True:
                try:
                    replay_example = replay_stream.read(
                        limit=max_sequence_length)
                except errors_impl.OutOfRangeError:
                    yield ""
                else:
                    if replay_example is None:
                        yield ""
                    else:
                        yield replay_example.SerializeToString()

        def serialize_map(replay_example_str):
            """Parse each example string to `tf.Tensor`."""
            try:
                assert_op = control_flow_ops.Assert(replay_example_str != "",
                                                    [replay_example_str])
                with ops.control_dependencies([assert_op]):
                    _, replay = parsing_ops.parse_single_sequence_example(
                        replay_example_str, sequence_features=replay_features)
            except errors_impl.InvalidArgumentError:
                raise errors_impl.OutOfRangeError(
                    None, None, 'replay_stream returned an empty example.')

            return convert_and_fix_dtypes(replay)

        def pad_or_truncate_map(replay):
            """Truncate or pad replays."""
            with_values = 'value' in replay

            if with_values:
                replay = experience.ReplayWithValues(**replay)
            else:
                replay = experience.Replay(**replay)

            sequence_length = math_ops.minimum(max_sequence_length,
                                               replay.sequence_length)
            sequence_length.set_shape([1])

            state = sequence_utils.pad_or_truncate(replay.state,
                                                   max_sequence_length,
                                                   axis=0,
                                                   pad_value=0)
            state.set_shape([max_sequence_length] + state_shape)

            next_state = sequence_utils.pad_or_truncate(replay.next_state,
                                                        max_sequence_length,
                                                        axis=0,
                                                        pad_value=0)
            next_state.set_shape([max_sequence_length] + state_shape)

            action = sequence_utils.pad_or_truncate(replay.action,
                                                    max_sequence_length,
                                                    axis=0,
                                                    pad_value=0)
            action.set_shape([max_sequence_length] + action_shape)

            action_value = sequence_utils.pad_or_truncate(replay.action_value,
                                                          max_sequence_length,
                                                          axis=0,
                                                          pad_value=0)
            action_value.set_shape([max_sequence_length] + action_value_shape)

            reward = sequence_utils.pad_or_truncate(replay.reward,
                                                    max_sequence_length,
                                                    axis=0,
                                                    pad_value=0)
            reward.set_shape([max_sequence_length] + reward_shape)

            terminal = sequence_utils.pad_or_truncate(
                replay.terminal,
                max_sequence_length,
                axis=0,
                pad_value=ops.convert_to_tensor(False))
            terminal.set_shape([max_sequence_length])

            if with_values:
                value = sequence_utils.pad_or_truncate(replay.value,
                                                       max_sequence_length,
                                                       axis=0,
                                                       pad_value=0)
                value.set_shape([max_sequence_length] + reward_shape)

                return experience.ReplayWithValues(
                    state=state,
                    next_state=next_state,
                    action=action,
                    action_value=action_value,
                    value=value,
                    reward=reward,
                    terminal=terminal,
                    sequence_length=sequence_length)

            return experience.Replay(state=state,
                                     next_state=next_state,
                                     action=action,
                                     action_value=action_value,
                                     reward=reward,
                                     terminal=terminal,
                                     sequence_length=sequence_length)

        dataset = dataset_ops.Dataset.from_generator(generator, dtypes.string)
        dataset = dataset.map(serialize_map)
        return dataset.map(pad_or_truncate_map)
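
A hedged usage sketch: drain the dataset with a TF 1.x initializable iterator (the `replay_stream` is assumed to come from Example #7):

import tensorflow as tf

dataset = ReplayDataset(replay_stream, max_sequence_length=200).batch(32)
iterator = dataset.make_initializable_iterator()
replay = iterator.get_next()

with tf.Session() as sess:
    sess.run(iterator.initializer)
    batch = sess.run(replay)   # padded Replay / ReplayWithValues tensors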
Example #12
def serialize_replay(replay,
                     state_dtype,
                     action_dtype,
                     action_value_dtype,
                     reward_dtype,
                     with_values=False):
    """Returns a `tf.train.SequenceExample` for the given `ay.contrib.rl.Replay` instance.

  Arguments:
    replay: `ay.contrib.rl.Replay` instance.
    state_dtype: dtype of the state space.
    action_dtype: dtype of the action space.
    action_value_dtype: dtype of the action-values space.
    reward_dtype: dtype of the reward space.
    with_values: Python `bool` for recording values.

  Returns:
    A `tf.train.SequenceExample` containing info from the `ay.contrib.rl.Replay`.

  Raises:
    `AssertionError` when `replay` is not an `ay.contrib.rl.Replay` (or
        `ay.contrib.rl.ReplayWithValues`) instance.
  """
    if with_values:
        assert_utils.assert_true(
            isinstance(replay, experience.ReplayWithValues),
            '`replay` must be an instance of `ay.contrib.rl.ReplayWithValues`')
    else:
        assert_utils.assert_true(
            isinstance(replay, experience.Replay),
            '`replay` must be an instance of `ay.contrib.rl.Replay`')

    feature_list = {
        'state':
        tf.train.FeatureList(
            feature=serialize_replay_feature(replay.state, dtype=state_dtype)),
        'next_state':
        tf.train.FeatureList(feature=serialize_replay_feature(
            replay.next_state, dtype=state_dtype)),
        'action':
        tf.train.FeatureList(feature=serialize_replay_feature(
            replay.action, dtype=action_dtype)),
        'action_value':
        tf.train.FeatureList(feature=serialize_replay_feature(
            replay.action_value, dtype=action_value_dtype)),
        'reward':
        tf.train.FeatureList(feature=serialize_replay_feature(
            replay.reward, dtype=reward_dtype)),
        'terminal':
        tf.train.FeatureList(feature=serialize_replay_feature(
            replay.terminal, dtype=dtypes.bool)),
        'sequence_length':
        tf.train.FeatureList(feature=serialize_replay_feature(
            [replay.sequence_length], dtype=dtypes.int32)),
    }

    if with_values:
        feature_list['value'] = tf.train.FeatureList(
            feature=serialize_replay_feature(replay.value, dtype=reward_dtype))

    feature_lists = tf.train.FeatureLists(feature_list=feature_list)
    return tf.train.SequenceExample(feature_lists=feature_lists)
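
A hedged sketch of the write side, given a recorded `replay` (an `ay.contrib.rl.Replay` namedtuple); the dtype arguments below are illustrative only:

example = serialize_replay(replay,
                           state_dtype=dtypes.float32,
                           action_dtype=dtypes.int32,
                           action_value_dtype=dtypes.float32,
                           reward_dtype=dtypes.float32,
                           with_values=False)
serialized = example.SerializeToString()   # bytes a ReplayStream's `read` could hand to Example #11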
Example #13
def distribution_from_gym_space(space, logits=None, name='SpaceDistribution'):
    """Determines a parameterized `tf.distribution.Distribution` from the `gym.Space`.

  Arguments:
    space: a `gym.Space` instance (e.g. `env.action_space`).
    logits: optional `list` of `tf.Tensor` to be used instead of creating them.
    name: Python `str` name prefixed to Ops created.

  Raises:
    `TypeError` when space is not a `gym.Space` instance.

  Returns:
    One of the following: a `tuple` or `dict` of `DistributionWithLogits`, or a single
        `DistributionWithLogits`.
  """
    assert_utils.assert_true(isinstance(space, Space),
                             '`space` must be an instance of `gym.Space`')

    with ops.name_scope(name):
        if isinstance(space, discrete.Discrete):
            if logits and isinstance(logits[0], ops.Tensor):
                logits = _dense_projection(logits[0], [space.n])
            else:
                logits = _placeholder_factory_map[discrete.Discrete](space)
            distribution = categorical.Categorical(
                logits=math_ops.cast(logits, dtypes.float32))
            return DistributionWithLogits(distribution=distribution,
                                          logits=logits)

        elif isinstance(space, multi_discrete.MultiDiscrete):
            if logits and isinstance(logits[0], ops.Tensor):
                logits = _dense_projection(logits[0], space.shape)
            else:
                logits = _placeholder_factory_map[
                    multi_discrete.MultiDiscrete](space)
            distribution = categorical.Categorical(
                logits=math_ops.cast(logits, dtypes.float32))
            return DistributionWithLogits(distribution=distribution,
                                          logits=logits)

        elif isinstance(space, multi_binary.MultiBinary):
            if logits and isinstance(logits[0], ops.Tensor):
                logits = _dense_projection(logits[0], space.shape)
            else:
                logits = _placeholder_factory_map[multi_binary.MultiBinary](
                    space)
            distribution = bernoulli.Bernoulli(logits=logits)
            return DistributionWithLogits(distribution=distribution,
                                          logits=logits)

        elif isinstance(space, box.Box):
            if logits and isinstance(logits[0], ops.Tensor):
                logits = _dense_projection(logits[0], space.shape)
            else:
                logits = _placeholder_factory_map[box.Box](space)

            flat_shape = array_utils.product(space.shape)
            shape = array_ops.shape(logits)
            logits = gen_array_ops.reshape(logits,
                                           [shape[0], shape[1], flat_shape])

            log_eps = math.log(distribution_utils.epsilon)

            alpha = core.dense(logits, flat_shape, use_bias=False)
            alpha = clip_ops.clip_by_value(alpha, log_eps, -log_eps)
            alpha = math_ops.log(math_ops.exp(alpha) + 1.0) + 1.0
            alpha = gen_array_ops.reshape(alpha, shape)

            beta = core.dense(logits, flat_shape, use_bias=False)
            beta = clip_ops.clip_by_value(beta, log_eps, -log_eps)
            beta = math_ops.log(math_ops.exp(beta) + 1.0) + 1.0
            beta = gen_array_ops.reshape(beta, shape)

            distribution = beta_min_max.BetaMinMax(concentration1=alpha,
                                                   concentration0=beta,
                                                   min_value=space.low,
                                                   max_value=space.high)
            return DistributionWithLogits(distribution=distribution,
                                          logits=logits)

        elif isinstance(space, tuple_space.Tuple):
            if not logits:
                logits = [None] * len(space.spaces)
            return tuple(
                distribution_from_gym_space(
                    val, logits=[logit], name='tuple_{}'.format(idx))
                for idx, (val, logit) in enumerate(zip(space.spaces, logits)))

        elif isinstance(space, dict_space.Dict):
            if not logits:
                logits = [None] * len(space.spaces)
            return {
                key: distribution_from_gym_space(val,
                                                 logits=[logit],
                                                 name='{}'.format(key))
                for (key, val), logit in zip(space.spaces.items(), logits)
            }

    raise TypeError('`space` not supported: {}'.format(type(space)))
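
A hedged usage sketch with a gym environment:

import gym

env = gym.make('CartPole-v0')
action_dist = distribution_from_gym_space(env.action_space)   # Discrete -> Categorical
sampled_action = action_dist.distribution.sample()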
Example #14
def expected_q_value(reward,
                     action,
                     action_value,
                     next_action_value,
                     sequence_length,
                     max_sequence_length,
                     weights=1.,
                     discount=.95,
                     n_step=False):
    """Computes the expected q returns and values.

  This covers architectures such as DQN, Double-DQN, Dueling-DQN and Noisy-DQN.

  Arguments:
    reward: 1D or 2D `tf.Tensor`, contiguous sequence(s) of rewards.
    action: 1D or 2D `tf.Tensor`, contiguous sequence(s) of actions.
    action_value: `tf.Tensor`, action-value (Q) estimates for the current state.
    next_action_value: `tf.Tensor`, or a `list`/`tuple` of 2 `tf.Tensor`s, where the first entry is
        the `model(next_state) = action_value` and the second is `target(next_state) = action_value`.
    sequence_length: 1D `tf.Tensor`, tensor containing lengths of rewards, action_values, etc..
    max_sequence_length: `int` or `list`, maximum length(s) of rewards.
    weights: `tf.Tensor`, the weights/mask to apply to the result.
    discount: 0D scalar, the discount factor (gamma).
    n_step: 0D `bool`, whether to use the n-step return (Monte Carlo rewards with discounting).

  Returns:
    `tuple` containing the `q_value` `tf.Tensor` and `expected_q_value` `tf.Tensor`.

  Reference:
    https://storage.googleapis.com/deepmind-media/dqn/DQNNaturePaper.pdf
  """
    weights = ops.convert_to_tensor(weights, dtype=reward.dtype)
    discount = ops.convert_to_tensor(discount, dtype=reward.dtype)
    n_step = ops.convert_to_tensor(n_step, dtype=dtypes.bool)
    ndim = len(action_value.shape)

    q_value = sequence_utils.gather_along_second_axis(action_value, action)
    q_value.set_shape([None, max_sequence_length])

    if isinstance(next_action_value, tuple) or isinstance(
            next_action_value, list):
        assert_utils.assert_true(
            len(next_action_value) == 2,
            '`next_action_value` must be a `tuple` of length = 2')
        next_action_value, target_next_action_value = next_action_value
        next_q_value = sequence_utils.gather_along_second_axis(
            next_action_value,
            math_ops.argmax(target_next_action_value,
                            -1,
                            output_type=dtypes.int32))
    else:
        next_q_value = sequence_utils.gather_along_second_axis(
            next_action_value,
            math_ops.argmax(next_action_value, -1, output_type=dtypes.int32))
    next_q_value.set_shape([None, max_sequence_length])

    def n_step_return():
        rest_of_rewards = reward[:, 1:]
        initial_reward = reward[:, 0]
        initial_rewards = array_ops.concat([
            array_ops.expand_dims(initial_reward, -1),
            array_ops.zeros_like(rest_of_rewards)
        ], -1)
        reward_t = initial_rewards + array_ops.concat([
            array_ops.zeros_like(array_ops.expand_dims(initial_reward, -1)),
            math_ops.cumsum(discount * rest_of_rewards, axis=-1, reverse=False)
        ], -1)
        discount_t = array_ops.expand_dims(
            array_ops.tile(array_ops.expand_dims(discount, -1),
                           [array_ops.shape(sequence_length)[0]
                            ])**math_ops.cast(sequence_length, dtypes.float32),
            -1)
        return reward_t + discount_t * next_q_value

    def single_step_return():
        reward_t = reward
        return reward_t + discount * next_q_value

    expected_q_value = control_flow_ops.cond(n_step, n_step_return,
                                             single_step_return)
    expected_q_value.set_shape([None, max_sequence_length])

    return (q_value * weights, expected_q_value * weights)
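
Written out for one batch element with rewards r_0..r_t, bootstrap value q'_t, discount gamma and sequence length L, a reading of the two branches above: single_step_return builds y_t = r_t + gamma * q'_t, while n_step_return builds y_t = r_0 + gamma * (r_1 + ... + r_t) + gamma**L * q'_t; both returned tensors are then multiplied by weights. A small NumPy check of the n-step branch (illustrative values):

import numpy as np

gamma, L = 0.95, 4
r = np.array([1., 2., 3., 4.])             # rewards for one sequence
next_q = np.array([10., 10., 10., 10.])    # bootstrap values per step
reward_t = r[0] + np.concatenate(([0.], np.cumsum(gamma * r[1:])))
y = reward_t + gamma ** L * next_q         # n-step targets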