Example #1
def weighted_resample(inputs, weights, overall_rate, scope=None,
                      mean_decay=0.999, warmup=10, seed=None):
  """Performs an approximate weighted resampling of `inputs`.

  This method chooses elements from `inputs` where each item's rate of
  selection is proportional to its value in `weights`, and the average
  rate of selection across all inputs (and many invocations!) is
  `overall_rate`.

  Args:
    inputs: A list of tensors whose first dimension is `batch_size`.
    weights: A `[batch_size]`-shaped tensor with each batch member's weight.
    overall_rate: Desired overall rate of resampling.
    scope: Scope to use for the op.
    mean_decay: How quickly to decay the running estimate of the mean weight.
    warmup: Until the resulting tensor has been evaluated `warmup`
      times, the resampling method uses the true mean over all calls
      as its weight estimate, rather than a decayed mean.
    seed: Random seed.

  Returns:
    A list of tensors exactly like `inputs`, but with an unknown (and
      possibly zero) first dimension.
    A tensor containing the effective resampling rate used for each output.

  """
  # Algorithm: Just compute rates as weights/mean_weight *
  # overall_rate. This way the average weight corresponds to the
  # overall rate, and a weight twice the average has twice the rate,
  # etc.
  with ops.name_scope(scope, 'weighted_resample', inputs) as opscope:
    # First: Maintain a running estimated mean weight, with decay
    # adjusted (by also maintaining an invocation count) during the
    # warmup period so that at the beginning, there aren't too many
    # zeros mixed in, throwing the average off.

    with variable_scope.variable_scope(scope, 'estimate_mean', inputs):
      count_so_far = variable_scope.get_local_variable(
          'resample_count', initializer=0)

      estimated_mean = variable_scope.get_local_variable(
          'estimated_mean', initializer=0.0)

      count = count_so_far.assign_add(1)
      real_decay = math_ops.minimum(
          math_ops.truediv((count - 1), math_ops.minimum(count, warmup)),
          mean_decay)

      batch_mean = math_ops.reduce_mean(weights)
      mean = moving_averages.assign_moving_average(
          estimated_mean, batch_mean, real_decay, zero_debias=False)

    # Then, normalize the weights into rates using the mean weight and
    # overall target rate:
    rates = weights * overall_rate / mean

    results = resample_at_rate([rates] + inputs, rates,
                               scope=opscope, seed=seed, back_prop=False)

    return (results[1:], results[0])
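A minimal usage sketch for Example #1, assuming the TF 1.x graph-mode API in which this function shipped as `tf.contrib.training.weighted_resample` (the tensor values are illustrative):

import tensorflow as tf

features = tf.constant([[1.0], [2.0], [3.0], [4.0]])
weights = tf.constant([0.1, 0.1, 0.1, 3.7])  # last row is much "heavier"

# Keep roughly half the batch on average; heavier rows are kept more often.
resampled, rates = tf.contrib.training.weighted_resample(
    [features], weights, overall_rate=0.5)

with tf.Session() as sess:
  # The running mean-weight estimate lives in *local* variables.
  sess.run(tf.local_variables_initializer())
  print(sess.run([resampled[0], rates]))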
Example #3
def noise(a):
    # Persistent per-dimension noise state, kept across session.run calls.
    noise_var = get_local_variable("nm",
                                   initializer=tf.zeros(a.get_shape()[1:]))
    ou_theta = get_local_variable("ou_theta", initializer=0.2)
    ou_sigma = get_local_variable("ou_sigma", initializer=0.15)
    # Ornstein-Uhlenbeck step: x <- x - theta * x + sigma * eps.
    n = noise_var.assign_sub(
        ou_theta * noise_var -
        tf.random_normal(a.get_shape()[1:], stddev=ou_sigma))
    return a + n
Example #4
def ornstein_uhlenbeck_noise(a, t_decay=100000):
    noise_var = get_local_variable("nm",
                                   initializer=tf.zeros(a.get_shape()[1:]))
    ou_theta = get_local_variable("ou_theta", initializer=0.2)
    ou_sigma = get_local_variable("ou_sigma", initializer=0.15)
    # ou_theta = tf.Print(ou_theta, [noise_var], 'noise: ', first_n=2000)
    # Anneal the noise scale: ou_sigma is scaled by 1e-6 over t_decay steps.
    ou_sigma = tf.train.exponential_decay(ou_sigma, tt.function.step(),
                                          t_decay, 1e-6)
    n = noise_var.assign_sub(
        ou_theta * noise_var -
        tf.random_normal(a.get_shape()[1:], stddev=ou_sigma))
    return a + n
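The `assign_sub` in Examples #3 and #4 implements the discrete Ornstein-Uhlenbeck recurrence x(t+1) = x(t) - theta * x(t) + sigma * eps(t), which mean-reverts toward zero. A plain NumPy sketch of the same update, for intuition:

import numpy as np

theta, sigma = 0.2, 0.15
x = np.zeros(3)  # plays the role of noise_var
for _ in range(5):
    eps = np.random.standard_normal(3)
    x = x - (theta * x - sigma * eps)  # exactly what assign_sub computes
    print(x)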
Example #5
  def get_placements(self, *args, **kwargs):
    num_children = self.hparams.num_children
    with variable_scope.variable_scope("controller_{}".format(self.ctrl_id)):
      actions_cache = variable_scope.get_local_variable(
          "actions_cache",
          initializer=init_ops.zeros_initializer,
          dtype=dtypes.int32,
          shape=[num_children, self.num_groups],
          trainable=False)

    x = array_ops.tile(self.seq2seq_input_layer, [num_children, 1, 1])
    last_c, last_h, attn_mem = self.encode(x)
    actions, log_probs = {}, {}
    actions["sample"], log_probs["sample"] = (
        self.decode(
            x, last_c, last_h, attn_mem, mode="sample"))
    actions["target"], log_probs["target"] = (
        self.decode(
            x,
            last_c,
            last_h,
            attn_mem,
            mode="target",
            y=actions_cache))
    actions["greedy"], log_probs["greedy"] = (
        self.decode(
            x, last_c, last_h, attn_mem, mode="greedy"))
    actions["sample"] = control_flow_ops.cond(
        self.global_step < self.hparams.stop_sampling,
        lambda: state_ops.assign(actions_cache, actions["sample"]),
        lambda: state_ops.assign(actions_cache, actions["target"]))
    self.actions_cache = actions_cache

    return actions, log_probs
Example #6
  def get_placements(self, *args, **kwargs):
    num_children = self.hparams.num_children
    with variable_scope.variable_scope("controller_{}".format(self.ctrl_id)):
      actions_cache = variable_scope.get_local_variable(
          "actions_cache",
          initializer=init_ops.zeros_initializer,
          dtype=dtypes.int32,
          shape=[num_children, self.num_groups],
          trainable=False)

    x = self.seq2seq_input_layer
    last_c, last_h, attn_mem = self.encode(x)
    actions, log_probs = {}, {}
    actions["sample"], log_probs["sample"] = (
        self.decode(
            x, last_c, last_h, attn_mem, mode="sample"))
    actions["target"], log_probs["target"] = (
        self.decode(
            x,
            last_c,
            last_h,
            attn_mem,
            mode="target",
            y=actions_cache))
    actions["greedy"], log_probs["greedy"] = (
        self.decode(
            x, last_c, last_h, attn_mem, mode="greedy"))
    actions["sample"] = control_flow_ops.cond(
        self.global_step < self.hparams.stop_sampling,
        lambda: state_ops.assign(actions_cache, actions["sample"]),
        lambda: state_ops.assign(actions_cache, actions["target"]))
    self.actions_cache = actions_cache

    return actions, log_probs
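Both `get_placements` variants (Examples #5 and #6) rely on the same trick: sampled actions are written into a non-trainable local variable so that the "target" decode and later reads see a consistent cached copy. A stripped-down sketch of that pattern, with hypothetical `global_step`, `stop_sampling_step`, and `sample_actions` standing in for the controller's real state:

cached = variable_scope.get_local_variable(
    "cached_actions",
    initializer=init_ops.zeros_initializer,
    dtype=dtypes.int32,
    shape=[4],
    trainable=False)
fresh = sample_actions()  # hypothetical op producing new int32 actions
# Early in training, refresh the cache from fresh samples; afterwards,
# just read the cached copy back (both cond branches must yield a tensor).
actions = control_flow_ops.cond(
    global_step < stop_sampling_step,
    lambda: state_ops.assign(cached, fresh),
    lambda: cached.read_value())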
Example #7
    def testGetLocalVar(self):
        with self.test_session():
            # Check that local variable respects naming.
            with tf.variable_scope("outer") as outer:
                with tf.variable_scope(outer, "default", []):
                    local_var = variable_scope.get_local_variable("w", [], collections=["foo"])
                    self.assertEqual(local_var.name, "outer/w:0")

            # Since variable is local, it should be in the local variable collection
            # but not the trainable collection.
            self.assertIn(local_var, tf.get_collection(tf.GraphKeys.LOCAL_VARIABLES))
            self.assertIn(local_var, tf.get_collection("foo"))
            self.assertNotIn(local_var, tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES))

            # Check that local variable respects `reuse`.
            with tf.variable_scope(outer, "default", reuse=True):
                self.assertEqual(variable_scope.get_local_variable("w", []).name, "outer/w:0")
Example #8
def weighted_resample(inputs,
                      weights,
                      overall_rate,
                      scope=None,
                      mean_decay=0.999,
                      seed=None):
    """Performs an approximate weighted resampling of `inputs`.

  This method chooses elements from `inputs` where each item's rate of
  selection is proportional to its value in `weights`, and the average
  rate of selection across all inputs (and many invocations!) is
  `overall_rate`.

  Args:
    inputs: A list of tensors whose first dimension is `batch_size`.
    weights: A `[batch_size]`-shaped tensor with each batch member's weight.
    overall_rate: Desired overall rate of resampling.
    scope: Scope to use for the op.
    mean_decay: How quickly to decay the running estimate of the mean weight.
    seed: Random seed.

  Returns:
    A list of tensors exactly like `inputs`, but with an unknown (and
      possibly zero) first dimension.
    A tensor containing the effective resampling rate used for each output.

  """
    # Algorithm: Just compute rates as weights/mean_weight *
    # overall_rate. This way the average weight corresponds to the
    # overall rate, and a weight twice the average has twice the rate,
    # etc.
    with ops.name_scope(scope, 'weighted_resample', inputs) as opscope:
        # First: Maintain a running estimated mean weight, with zero debiasing
        # enabled (by default) to avoid throwing the average off.

        with variable_scope.variable_scope(scope, 'estimate_mean', inputs):
            estimated_mean = variable_scope.get_local_variable(
                'estimated_mean',
                initializer=math_ops.cast(0, weights.dtype),
                dtype=weights.dtype)

            batch_mean = math_ops.reduce_mean(weights)
            mean = moving_averages.assign_moving_average(
                estimated_mean, batch_mean, mean_decay)

        # Then, normalize the weights into rates using the mean weight and
        # overall target rate:
        rates = weights * overall_rate / mean

        results = resample_at_rate([rates] + inputs,
                                   rates,
                                   scope=opscope,
                                   seed=seed,
                                   back_prop=False)

        return (results[1:], results[0])
Example #9
  def testGetLocalVar(self):
    with self.test_session():
      # Check that local variable respects naming.
      with tf.variable_scope("outer") as outer:
        with tf.variable_scope(outer, "default", []):
          local_var = variable_scope.get_local_variable(
              "w", [], collections=["foo"])
          self.assertEqual(local_var.name, "outer/w:0")

      # Since variable is local, it should be in the local variable collection
      # but not the trainable collection.
      self.assertIn(local_var, tf.get_collection(tf.GraphKeys.LOCAL_VARIABLES))
      self.assertIn(local_var, tf.get_collection("foo"))
      self.assertNotIn(
          local_var, tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES))

      # Check that local variable respects `reuse`.
      with tf.variable_scope(outer, "default", reuse=True):
        self.assertEqual(variable_scope.get_local_variable("w", []).name,
                         "outer/w:0")
Example #10
def weighted_resample(inputs, weights, overall_rate, scope=None,
                      mean_decay=0.999, seed=None):
  """Performs an approximate weighted resampling of `inputs`.

  This method chooses elements from `inputs` where each item's rate of
  selection is proportional to its value in `weights`, and the average
  rate of selection across all inputs (and many invocations!) is
  `overall_rate`.

  Args:
    inputs: A list of tensors whose first dimension is `batch_size`.
    weights: A `[batch_size]`-shaped tensor with each batch member's weight.
    overall_rate: Desired overall rate of resampling.
    scope: Scope to use for the op.
    mean_decay: How quickly to decay the running estimate of the mean weight.
    seed: Random seed.

  Returns:
    A list of tensors exactly like `inputs`, but with an unknown (and
      possibly zero) first dimension.
    A tensor containing the effective resampling rate used for each output.

  """
  # Algorithm: Just compute rates as weights/mean_weight *
  # overall_rate. This way the average weight corresponds to the
  # overall rate, and a weight twice the average has twice the rate,
  # etc.
  with ops.name_scope(scope, 'weighted_resample', inputs) as opscope:
    # First: Maintain a running estimated mean weight, with zero debiasing
    # enabled (by default) to avoid throwing the average off.

    with variable_scope.variable_scope(scope, 'estimate_mean', inputs):
      estimated_mean = variable_scope.get_local_variable(
          'estimated_mean',
          initializer=math_ops.cast(0, weights.dtype),
          dtype=weights.dtype)

      batch_mean = math_ops.reduce_mean(weights)
      mean = moving_averages.assign_moving_average(
          estimated_mean, batch_mean, mean_decay)

    # Then, normalize the weights into rates using the mean weight and
    # overall target rate:
    rates = weights * overall_rate / mean

    results = resample_at_rate([rates] + inputs, rates,
                               scope=opscope, seed=seed, back_prop=False)

    return (results[1:], results[0])
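Unlike Example #1, Examples #8 and #10 keep no warmup counter; they rely on `zero_debias` (the default in `assign_moving_average`) to stop early estimates from being dragged toward the zero initializer. A plain-Python sketch of that style of bias correction, assuming it behaves like the familiar Adam-style scheme:

decay = 0.999
biased, step = 0.0, 0

def update(value):
  # EMA of `value`, rescaled by 1 / (1 - decay**step) to undo the
  # pull toward the zero initial state.
  global biased, step
  step += 1
  biased = decay * biased + (1 - decay) * value
  return biased / (1 - decay ** step)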
Example #11
  def get_groupings(self, *args, **kwargs):
    num_children = self.hparams.num_children
    with variable_scope.variable_scope("controller_{}".format(self.ctrl_id)):
      grouping_actions_cache = variable_scope.get_local_variable(
          "grouping_actions_cache",
          initializer=init_ops.zeros_initializer,
          dtype=dtypes.int32,
          shape=[num_children, self.num_ops],
          trainable=False)
    input_layer = self.op_embeddings
    input_layer = array_ops.expand_dims(input_layer, 0)
    feed_ff_input_layer = array_ops.tile(input_layer, [num_children, 1, 1])
    grouping_actions, grouping_log_probs = {}, {}
    grouping_actions["sample"], grouping_log_probs[
        "sample"] = self.make_grouping_predictions(feed_ff_input_layer)

    grouping_actions["sample"] = state_ops.assign(grouping_actions_cache,
                                                  grouping_actions["sample"])
    self.grouping_actions_cache = grouping_actions_cache

    return grouping_actions, grouping_log_probs
Example #13
def variable_scoped_function_with_local_variable():
  variable_scope.get_local_variable(
      "local", shape=[1], initializer=init_ops.zeros_initializer())
  return variable_scope.get_variable(
      "dummy", shape=[1], initializer=init_ops.zeros_initializer())
Example #14
  def build_controller(self):
    """RL optimization interface.

    Returns:
      ops: A dictionary holding handles of the model used for training.
    """

    self._global_step = training_util.get_or_create_global_step()
    ops = {}
    ops["loss"] = 0

    failing_signal = self.compute_reward(self.hparams.failing_signal)

    ctr = {}

    with tf_ops.name_scope("controller_{}".format(self.ctrl_id)):
      with variable_scope.variable_scope("controller_{}".format(self.ctrl_id)):
        ctr["reward"] = {"value": [], "ph": [], "update": []}
        ctr["ready"] = {"value": [], "ph": [], "update": []}
        ctr["best_reward"] = {"value": [], "update": []}
        for i in range(self.hparams.num_children):
          reward_value = variable_scope.get_local_variable(
              "reward_{}".format(i),
              initializer=0.0,
              dtype=dtypes.float32,
              trainable=False)
          reward_ph = array_ops.placeholder(
              dtypes.float32, shape=(), name="reward_ph_{}".format(i))
          reward_update = state_ops.assign(
              reward_value, reward_ph, use_locking=True)
          ctr["reward"]["value"].append(reward_value)
          ctr["reward"]["ph"].append(reward_ph)
          ctr["reward"]["update"].append(reward_update)
          best_reward = variable_scope.get_local_variable(
              "best_reward_{}".format(i),
              initializer=failing_signal,
              dtype=dtypes.float32,
              trainable=False)
          ctr["best_reward"]["value"].append(best_reward)
          ctr["best_reward"]["update"].append(
              state_ops.assign(best_reward,
                               math_ops.minimum(best_reward, reward_update)))

          ready_value = variable_scope.get_local_variable(
              "ready_{}".format(i),
              initializer=True,
              dtype=dtypes.bool,
              trainable=False)
          ready_ph = array_ops.placeholder(
              dtypes.bool, shape=(), name="ready_ph_{}".format(i))
          ready_update = state_ops.assign(
              ready_value, ready_ph, use_locking=True)
          ctr["ready"]["value"].append(ready_value)
          ctr["ready"]["ph"].append(ready_ph)
          ctr["ready"]["update"].append(ready_update)

      ctr["grouping_y_preds"], ctr["grouping_log_probs"] = self.get_groupings()
      summary.histogram(
          "grouping_actions",
          array_ops.slice(ctr["grouping_y_preds"]["sample"], [0, 0],
                          [1, array_ops.shape(self.op_embeddings)[0]]))

      with variable_scope.variable_scope("controller_{}".format(self.ctrl_id)):
        ctr["baseline"] = variable_scope.get_local_variable(
            "baseline",
            initializer=failing_signal
            if self.hparams.start_with_failing_signal else 0.0,
            dtype=dtypes.float32,
            trainable=False)

      new_baseline = self.hparams.bl_dec * ctr["baseline"] + (
          1 - self.hparams.bl_dec) * math_ops.reduce_mean(
              ctr["reward"]["value"])
      if not self.hparams.always_update_baseline:
        baseline_mask = math_ops.less(ctr["reward"]["value"], failing_signal)
        selected_reward = array_ops.boolean_mask(ctr["reward"]["value"],
                                                 baseline_mask)
        selected_baseline = control_flow_ops.cond(
            math_ops.reduce_any(baseline_mask),
            lambda: math_ops.reduce_mean(selected_reward),
            lambda: constant_op.constant(0, dtype=dtypes.float32))
        ctr["pos_reward"] = selected_baseline
        pos_ = math_ops.less(
            constant_op.constant(0, dtype=dtypes.float32), selected_baseline)
        selected_baseline = self.hparams.bl_dec * ctr["baseline"] + (
            1 - self.hparams.bl_dec) * selected_baseline
        selected_baseline = control_flow_ops.cond(
            pos_, lambda: selected_baseline, lambda: ctr["baseline"])
        new_baseline = control_flow_ops.cond(
            math_ops.less(self.global_step,
                          self.hparams.stop_updating_after_steps),
            lambda: new_baseline, lambda: selected_baseline)
      ctr["baseline_update"] = state_ops.assign(
          ctr["baseline"], new_baseline, use_locking=True)

      ctr["y_preds"], ctr["log_probs"] = self.get_placements()
      summary.histogram("actions", ctr["y_preds"]["sample"])
      mask = math_ops.less(ctr["reward"]["value"], failing_signal)
      ctr["loss"] = ctr["reward"]["value"] - ctr["baseline"]
      ctr["loss"] *= (
          ctr["log_probs"]["sample"] + ctr["grouping_log_probs"]["sample"])

      selected_loss = array_ops.boolean_mask(ctr["loss"], mask)
      selected_loss = control_flow_ops.cond(
          math_ops.reduce_any(mask),
          lambda: math_ops.reduce_mean(-selected_loss),
          lambda: constant_op.constant(0, dtype=dtypes.float32))

      ctr["loss"] = control_flow_ops.cond(
          math_ops.less(self.global_step,
                        self.hparams.stop_updating_after_steps),
          lambda: math_ops.reduce_mean(-ctr["loss"]), lambda: selected_loss)

      ctr["reward_s"] = math_ops.reduce_mean(ctr["reward"]["value"])
      summary.scalar("loss", ctr["loss"])
      summary.scalar("avg_reward", ctr["reward_s"])
      summary.scalar("best_reward_so_far", best_reward)
      summary.scalar(
          "advantage",
          math_ops.reduce_mean(ctr["reward"]["value"] - ctr["baseline"]))

    with variable_scope.variable_scope(
        "optimizer", reuse=variable_scope.AUTO_REUSE):
      (ctr["train_op"], ctr["lr"], ctr["grad_norm"],
       ctr["grad_norms"]) = self._get_train_ops(
           ctr["loss"],
           tf_ops.get_collection(tf_ops.GraphKeys.TRAINABLE_VARIABLES),
           self.global_step,
           grad_bound=self.hparams.grad_bound,
           lr_init=self.hparams.lr,
           lr_dec=self.hparams.lr_dec,
           start_decay_step=self.hparams.start_decay_step,
           decay_steps=self.hparams.decay_steps,
           optimizer_type=self.hparams.optimizer_type)

    summary.scalar("gradnorm", ctr["grad_norm"])
    summary.scalar("lr", ctr["lr"])
    ctr["summary"] = summary.merge_all()
    ops["controller"] = ctr

    self.ops = ops
    return ops
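The baseline bookkeeping in Examples #14 and #15 is the standard REINFORCE variance-reduction trick: rewards are centered by an exponential moving average before weighting the log-probabilities. A plain-Python sketch with illustrative numbers:

bl_dec = 0.99
baseline = 0.0
rewards = [1.2, 0.8, 1.0]
log_probs = [-0.5, -0.7, -0.6]  # hypothetical per-child log-probs

# ctr["baseline_update"]: EMA of the mean reward.
baseline = bl_dec * baseline + (1 - bl_dec) * (sum(rewards) / len(rewards))

# ctr["loss"]: advantage-weighted log-probs, negated so that minimizing
# it raises the probability of better-than-baseline actions.
loss = -sum((r - baseline) * lp
            for r, lp in zip(rewards, log_probs)) / len(rewards)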
Example #15
    def build_controller(self):
        """RL optimization interface.

    Returns:
      ops: A dictionary holding handles of the model used for training.
    """

        self._global_step = training_util.get_or_create_global_step()
        ops = {}
        ops["loss"] = 0

        failing_signal = self.compute_reward(self.hparams.failing_signal)

        ctr = {}

        with tf_ops.name_scope("controller_{}".format(self.ctrl_id)):
            with variable_scope.variable_scope("controller_{}".format(
                    self.ctrl_id)):
                ctr["reward"] = {"value": [], "ph": [], "update": []}
                ctr["ready"] = {"value": [], "ph": [], "update": []}
                ctr["best_reward"] = {"value": [], "update": []}
                for i in range(self.hparams.num_children):
                    reward_value = variable_scope.get_local_variable(
                        "reward_{}".format(i),
                        initializer=0.0,
                        dtype=dtypes.float32,
                        trainable=False)
                    reward_ph = array_ops.placeholder(
                        dtypes.float32,
                        shape=(),
                        name="reward_ph_{}".format(i))
                    reward_update = state_ops.assign(reward_value,
                                                     reward_ph,
                                                     use_locking=True)
                    ctr["reward"]["value"].append(reward_value)
                    ctr["reward"]["ph"].append(reward_ph)
                    ctr["reward"]["update"].append(reward_update)
                    best_reward = variable_scope.get_local_variable(
                        "best_reward_{}".format(i),
                        initializer=failing_signal,
                        dtype=dtypes.float32,
                        trainable=False)
                    ctr["best_reward"]["value"].append(best_reward)
                    ctr["best_reward"]["update"].append(
                        state_ops.assign(
                            best_reward,
                            math_ops.minimum(best_reward, reward_update)))

                    ready_value = variable_scope.get_local_variable(
                        "ready_{}".format(i),
                        initializer=True,
                        dtype=dtypes.bool,
                        trainable=False)
                    ready_ph = array_ops.placeholder(
                        dtypes.bool, shape=(), name="ready_ph_{}".format(i))
                    ready_update = state_ops.assign(ready_value,
                                                    ready_ph,
                                                    use_locking=True)
                    ctr["ready"]["value"].append(ready_value)
                    ctr["ready"]["ph"].append(ready_ph)
                    ctr["ready"]["update"].append(ready_update)

            ctr["grouping_y_preds"], ctr[
                "grouping_log_probs"] = self.get_groupings()
            summary.histogram(
                "grouping_actions",
                array_ops.slice(ctr["grouping_y_preds"]["sample"], [0, 0],
                                [1, array_ops.shape(self.op_embeddings)[0]]))

            with variable_scope.variable_scope("controller_{}".format(
                    self.ctrl_id)):
                ctr["baseline"] = variable_scope.get_local_variable(
                    "baseline",
                    initializer=failing_signal
                    if self.hparams.start_with_failing_signal else 0.0,
                    dtype=dtypes.float32,
                    trainable=False)

            new_baseline = self.hparams.bl_dec * ctr["baseline"] + (
                1 - self.hparams.bl_dec) * math_ops.reduce_mean(
                    ctr["reward"]["value"])
            if not self.hparams.always_update_baseline:
                baseline_mask = math_ops.less(ctr["reward"]["value"],
                                              failing_signal)
                selected_reward = array_ops.boolean_mask(
                    ctr["reward"]["value"], baseline_mask)
                selected_baseline = control_flow_ops.cond(
                    math_ops.reduce_any(baseline_mask),
                    lambda: math_ops.reduce_mean(selected_reward),
                    lambda: constant_op.constant(0, dtype=dtypes.float32))
                ctr["pos_reward"] = selected_baseline
                pos_ = math_ops.less(
                    constant_op.constant(0, dtype=dtypes.float32),
                    selected_baseline)
                selected_baseline = self.hparams.bl_dec * ctr["baseline"] + (
                    1 - self.hparams.bl_dec) * selected_baseline
                selected_baseline = control_flow_ops.cond(
                    pos_, lambda: selected_baseline, lambda: ctr["baseline"])
                new_baseline = control_flow_ops.cond(
                    math_ops.less(self.global_step,
                                  self.hparams.stop_updating_after_steps),
                    lambda: new_baseline, lambda: selected_baseline)
            ctr["baseline_update"] = state_ops.assign(ctr["baseline"],
                                                      new_baseline,
                                                      use_locking=True)

            ctr["y_preds"], ctr["log_probs"] = self.get_placements()
            summary.histogram("actions", ctr["y_preds"]["sample"])
            mask = math_ops.less(ctr["reward"]["value"], failing_signal)
            ctr["loss"] = ctr["reward"]["value"] - ctr["baseline"]
            ctr["loss"] *= (ctr["log_probs"]["sample"] +
                            ctr["grouping_log_probs"]["sample"])

            selected_loss = array_ops.boolean_mask(ctr["loss"], mask)
            selected_loss = control_flow_ops.cond(
                math_ops.reduce_any(mask),
                lambda: math_ops.reduce_mean(-selected_loss),
                lambda: constant_op.constant(0, dtype=dtypes.float32))

            ctr["loss"] = control_flow_ops.cond(
                math_ops.less(self.global_step,
                              self.hparams.stop_updating_after_steps),
                lambda: math_ops.reduce_mean(-ctr["loss"]),
                lambda: selected_loss)

            ctr["reward_s"] = math_ops.reduce_mean(ctr["reward"]["value"])
            summary.scalar("loss", ctr["loss"])
            summary.scalar("avg_reward", ctr["reward_s"])
            summary.scalar("best_reward_so_far", best_reward)
            summary.scalar(
                "advantage",
                math_ops.reduce_mean(ctr["reward"]["value"] - ctr["baseline"]))

        with variable_scope.variable_scope("optimizer",
                                           reuse=variable_scope.AUTO_REUSE):
            (ctr["train_op"], ctr["lr"], ctr["grad_norm"],
             ctr["grad_norms"]) = self._get_train_ops(
                 ctr["loss"],
                 tf_ops.get_collection(tf_ops.GraphKeys.TRAINABLE_VARIABLES),
                 self.global_step,
                 grad_bound=self.hparams.grad_bound,
                 lr_init=self.hparams.lr,
                 lr_dec=self.hparams.lr_dec,
                 start_decay_step=self.hparams.start_decay_step,
                 decay_steps=self.hparams.decay_steps,
                 optimizer_type=self.hparams.optimizer_type)

        summary.scalar("gradnorm", ctr["grad_norm"])
        summary.scalar("lr", ctr["lr"])
        ctr["summary"] = summary.merge_all()
        ops["controller"] = ctr

        self.ops = ops
        return ops