def __call__(self, getter, name, trainable, collections, *args, **kwargs):
    if trainable:
      with ops.device(self._worker_device):
        local_var = getter(name, trainable=True,
                           collections=[ops.GraphKeys.LOCAL_VARIABLES],
                           *args, **kwargs)

      global_center_variable = variable_scope.variable(
          name='%s/%s' % (GLOBAL_VARIABLE_NAME, name),
          initial_value=local_var.initialized_value(),
          trainable=False,
          collections=[ops.GraphKeys.GLOBAL_VARIABLES])

      with ops.device(self._worker_device):
        local_center_variable = variable_scope.variable(
          name='%s/%s' % (LOCAL_VARIABLE_NAME, name),
          initial_value=local_var.initialized_value(),
          trainable=False,
          collections=[ops.GraphKeys.LOCAL_VARIABLES])
        
      self._local_map[local_var] = local_center_variable
      self._global_map[local_var] = global_center_variable
      return local_var
    else:
      return getter(name, trainable, collections, *args, **kwargs)
Example 2
    def _create_slots(self, var_list):
        # Create the beta1 and beta2 accumulators on the same device as the first
        # variable. Sort the var_list to make sure this device is consistent across
        # workers (these need to go on the same PS, otherwise some updates are
        # silently ignored).
        first_var = min(var_list, key=lambda x: x.name)

        create_new = self._iterations is None
        if not create_new and context.in_graph_mode():
            create_new = (self._iterations.graph is not first_var.graph)

        if create_new:
            with ops.colocate_with(first_var):
                self._beta1_power = variable_scope.variable(self._beta1,
                                                            name="beta1_power",
                                                            trainable=False)
                self._beta2_power = variable_scope.variable(self._beta2,
                                                            name="beta2_power",
                                                            trainable=False)
                self._iterations = variable_scope.variable(0.,
                                                           name="iterations",
                                                           trainable=False)
                self._m_schedule = variable_scope.variable(1.,
                                                           name="m_schedule",
                                                           trainable=False)
        # Create slots for the first and second moments.
        for v in var_list:
            self._zeros_slot(v, "m", self._name)
            self._zeros_slot(v, "v", self._name)
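
The comment at the top of `_create_slots` explains that `var_list` is sorted by name so every worker picks the same "first" variable, and hence the same parameter-server device, for the power accumulators. A tiny self-contained illustration of that point; `FakeVar` is hypothetical and exists only for this demonstration:

# Illustration only: sorting by name makes the "first" variable, and hence
# the colocation device, deterministic regardless of list order per worker.
class FakeVar(object):
  def __init__(self, name):
    self.name = name

worker_a = [FakeVar("dense/bias:0"), FakeVar("dense/kernel:0")]
worker_b = [FakeVar("dense/kernel:0"), FakeVar("dense/bias:0")]
assert (min(worker_a, key=lambda v: v.name).name ==
        min(worker_b, key=lambda v: v.name).name)  # both pick "dense/bias:0"
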
  def __init__(self,
               init_loss_scale,
               incr_every_n_steps,
               decr_every_n_nan_or_inf=2,
               incr_ratio=2,
               decr_ratio=0.8):
    """Constructor of exponential-update loss scale manager.

    Args:
      init_loss_scale: A Python float.  The loss scale to use at the beginning.
      incr_every_n_steps: Increases loss scale every n consecutive steps with
        finite gradients.
      decr_every_n_nan_or_inf: Decreases loss scale every n accumulated steps
        with nan or inf gradients.
      incr_ratio: The multiplier to use when increasing the loss scale.
      decr_ratio: The less-than-one multiplier to use when decreasing the loss
        scale.
    """
    self._incr_every_n_steps = incr_every_n_steps
    self._decr_every_n_nan_or_inf = decr_every_n_nan_or_inf
    self._incr_ratio = incr_ratio
    self._decr_ratio = decr_ratio
    self._loss_scale = variable_scope.variable(
        name="loss_scale",
        initial_value=ops.convert_to_tensor(init_loss_scale, dtypes.float32),
        dtype=dtypes.float32,
        trainable=False)
    self._num_good_steps = variable_scope.variable(
        name="good_steps", initial_value=0, dtype=dtypes.int32, trainable=False)
    self._num_bad_steps = variable_scope.variable(
        name="bad_steps", initial_value=0, dtype=dtypes.int32, trainable=False)
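
The constructor docstring above describes the schedule: the scale is multiplied by `incr_ratio` after `incr_every_n_steps` consecutive finite-gradient steps, and multiplied by `decr_ratio` once `decr_every_n_nan_or_inf` non-finite steps accumulate. A minimal plain-Python sketch of that counting logic, assuming the semantics exactly as stated in the docstring (this is not the manager's actual TensorFlow update op, and the function name is hypothetical):

def simulate_loss_scale(step_has_nan_or_inf,
                        init_loss_scale=2.0 ** 15,
                        incr_every_n_steps=1000,
                        decr_every_n_nan_or_inf=2,
                        incr_ratio=2.0,
                        decr_ratio=0.8):
  """Sketch of the exponential loss-scale schedule described above."""
  loss_scale = init_loss_scale
  good_steps = bad_steps = 0
  for bad in step_has_nan_or_inf:
    if bad:
      good_steps = 0            # a nan/inf step breaks the "consecutive" run
      bad_steps += 1
      if bad_steps >= decr_every_n_nan_or_inf:
        loss_scale *= decr_ratio
        good_steps = bad_steps = 0
    else:
      good_steps += 1
      if good_steps >= incr_every_n_steps:
        loss_scale *= incr_ratio
        good_steps = bad_steps = 0
  return loss_scale

# e.g. two non-finite steps in a row cut the scale from 1024.0 to 1024.0 * 0.8:
# simulate_loss_scale([True, True], init_loss_scale=1024.0)
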
 def model_fn():
   vs = []
   vs.append(variable_scope.variable(1.0, name="foo/bar"))
   vs.append(variable_scope.variable(1.0, name="foo_1/bar"))
   vs.append(variable_scope.variable(1.0, name="foo_1/bar_1"))
   vs.append(variable_scope.variable(1.0, name="foo/bar_1"))
   distribute_lib.get_tower_context().merge_call(lambda _: _)
   return vs
 def testInvalidSynchronizationWithVariable(self):
   self._skip_eager_if_gpus_less_than(1)
   devices = ["/device:CPU:0", "/device:GPU:0"]
   dist = mirrored_strategy.MirroredStrategy(devices)
   with dist.scope():
     with self.assertRaisesRegexp(
         ValueError, "Invalid variable synchronization mode: Invalid for "
         "variable: v"):
       variable_scope.variable(1.0, name="v", synchronization="Invalid")
 def testNoneSynchronizationWithVariable(self):
   self._skip_eager_if_gpus_less_than(1)
   devices = ["/device:CPU:0", "/device:GPU:0"]
   dist = mirrored_strategy.MirroredStrategy(devices)
   with dist.scope():
     with self.assertRaisesRegexp(
         ValueError, "`NONE` variable synchronization mode is not "
         "supported with `Mirrored` distribution strategy. Please change "
         "the `synchronization` for variable: v"):
       variable_scope.variable(
           1.0,
           name="v",
           synchronization=variable_scope.VariableSynchronization.NONE)
 def model_fn(device_id):
   tower_context = distribute_lib.get_tower_context()
   with tower_context.tower_local_var_scope("sum"):
     v_sum = variable_scope.variable(1.0)
   with tower_context.tower_local_var_scope("mean"):
     v_mean = variable_scope.variable(4.0)
   self.assertTrue(isinstance(v_sum, values.TowerLocalVariable))
   self.assertTrue(isinstance(v_mean, values.TowerLocalVariable))
   updates = [v_sum.assign_add(2.0 + device_id),
              v_mean.assign(6.0 * device_id)]
   all_v_sum[device_id] = v_sum
   all_v_mean[device_id] = v_mean
   return updates, v_sum, v_mean
Example 8
    def _create_slots(self, var_list):
        first_var = min(var_list, key=lambda x: x.name)

        create_new = self._beta1_power is None

        if create_new:
            with ops.colocate_with(first_var):
                self._beta1_power = variable_scope.variable(self._beta1, name="beta1_power", trainable=False)
                self._beta2_power = variable_scope.variable(self._beta2, name="beta2_power", trainable=False)
        # Create slots for the first and second moments.
        for v in var_list:
            self._zeros_slot(v, "m", self._name)
            self._zeros_slot(v, "v", self._name)
            self._zeros_slot(v, "vhat", self._name)
  def testNameScopeWithVariable(self):
    def in_cross_tower(_):
      c = variable_scope.variable(1.0, name="c")
      return c

    def model_fn():
      b = variable_scope.variable(1.0, name="b")
      with ops.name_scope("foo"):
        c = distribute_lib.get_tower_context().merge_call(in_cross_tower)
      return b, c

    dist = mirrored_strategy.MirroredStrategy(
        ["/device:GPU:0", "/device:CPU:0"])

    with context.graph_mode(), dist.scope():
      with ops.name_scope("main"):
        a = variable_scope.variable(1.0, name="a")
        result = dist.call_for_each_tower(model_fn, run_concurrently=False)
      result_b = result[0]
      result_c = result[1]
      self.assertIsInstance(result_b, values.DistributedValues)
      self.assertIsInstance(result_c, values.DistributedValues)
      a0, a1 = dist.unwrap(a)
      b0, b1 = dist.unwrap(result_b)
      c0, c1 = dist.unwrap(result_c)
      self.assertEquals("main/a:0", a0.name)
      self.assertEquals("main/a/replica_1:0", a1.name)
      self.assertEquals("main/b:0", b0.name)
      self.assertEquals("main/b/replica_1:0", b1.name)
      self.assertEquals("main/foo/c:0", c0.name)
      self.assertEquals("main/foo/c/replica_1:0", c1.name)
Example 10
def _identity_metric_single(name, input_tensor):
  """A metric which takes on its last updated value.

  This keeps evaluation metrics in sync with one another, since update ops are
  run separately from their result Tensors. Simply returning (input_tensor,
  no_op) as a metric with a value but no update means that a metric will come
  from a different batch of data than metrics which cache values in a Variable
  (e.g. the default loss metric).

  Args:
    name: A name for the metric.
    input_tensor: Any Tensor.
  Returns:
    A tuple of (value, update_op).
  """
  metric_variable = variable_scope.variable(
      name="{}_identity_metric".format(name),
      initial_value=array_ops.zeros([], dtype=input_tensor.dtype),
      collections=[ops.GraphKeys.LOCAL_VARIABLES],
      validate_shape=False)
  update_op = state_ops.assign(
      metric_variable, input_tensor, validate_shape=False)
  # This shape will be correct once the first update runs (but may be
  # incomplete, so is not helpful for initializing the variable).
  metric_variable.set_shape(input_tensor.get_shape())
  return (metric_variable.value(), update_op)
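
A hypothetical usage sketch of the helper above in a TF 1.x-style graph; the placeholder, session setup, and printed value are illustrative and assume the module's TensorFlow imports are available:

import tensorflow.compat.v1 as tf

tf.disable_eager_execution()
loss_input = tf.placeholder(tf.float32, shape=[])
value, update_op = _identity_metric_single("last_loss", loss_input)

with tf.Session() as sess:
  sess.run(tf.local_variables_initializer())  # the metric is a LOCAL_VARIABLE
  sess.run(update_op, feed_dict={loss_input: 3.0})
  print(sess.run(value))  # 3.0 -- the metric simply reports its last update
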
Example 11
  def _create_non_slot_variable(self, initial_value, name, colocate_with):
    """Add an extra variable, not associated with a slot."""
    in_graph_mode = context.in_graph_mode()
    if in_graph_mode:
      graph = colocate_with.graph
    else:
      graph = None

    key = (name, graph)
    v = self._non_slot_dict.get(key, None)
    if v is None:
      self._maybe_initialize_checkpointable()
      with ops.colocate_with(colocate_with):
        if not in_graph_mode:
          restored_initial_value = self._preload_simple_restoration(
              name=name, shape=None)
          if restored_initial_value is not None:
            initial_value = restored_initial_value
        v = variable_scope.variable(initial_value, name=name, trainable=False)
        # Restore this variable by name if necessary, but don't add a
        # Checkpointable dependency. Optimizers return the current graph's
        # non-slot variables from _checkpoint_dependencies explicitly rather
        # than unconditionally adding dependencies (since there may be multiple
        # non-slot variables with the same name in different graphs, trying to
        # save all of them would result in errors).
        self._handle_deferred_dependencies(name=name, checkpointable=v)
      self._non_slot_dict[key] = v

    return v
 def model_fn():
   v_sum = variable_scope.variable(
       1.0,
       synchronization=variable_scope.VariableSynchronization.ON_READ,
       aggregation=variable_scope.VariableAggregation.SUM)
   self.assertTrue(isinstance(v_sum, values.TowerLocalVariable))
   return v_sum
  def __call__(self, getter, name, trainable, collections, *args, **kwargs):
    if trainable:
      with ops.device(self._worker_device):
        local_var = getter(
            name,
            trainable=True,
            collections=[ops.GraphKeys.LOCAL_VARIABLES],
            *args,
            **kwargs)

      global_variable = variable_scope.variable(
          name="%s/%s" % (GLOBAL_VARIABLE_NAME, name),
          initial_value=local_var.initialized_value(),
          trainable=False,
          collections=[ops.GraphKeys.GLOBAL_VARIABLES])

      self._local_2_global[local_var] = global_variable
      return local_var
    else:
      kwargs['trainable'] = trainable
      kwargs['collections'] = collections
      if ops.GraphKeys.LOCAL_VARIABLES in collections:
        with ops.device(self._worker_device):
          return getter(name, *args, **kwargs)
      else:
        return getter(name, *args, **kwargs)
Example 14
  def _create_non_slot_variable(self, initial_value, name, colocate_with):
    """Add an extra variable, not associated with a slot."""
    # Recommendation: Use OptimizerV2 if your optimizer uses non-slot variables.
    eager = context.executing_eagerly()
    graph = None if eager else colocate_with.graph

    key = (name, graph)
    v = self._non_slot_dict.get(key, None)
    if v is None:
      self._maybe_initialize_trackable()
      distribution_strategy = distribute_ctx.get_strategy()
      with distribution_strategy.extended.colocate_vars_with(colocate_with):
        if eager:
          restored_initial_value = self._preload_simple_restoration(
              name=name, shape=None)
          if restored_initial_value is not None:
            initial_value = restored_initial_value
        v = variable_scope.variable(
            initial_value, name=name, trainable=False,
            use_resource=resource_variable_ops.is_resource_variable(
                colocate_with))
      # Restore this variable by name if necessary, but don't add a
      # Trackable dependency. Optimizers return the current graph's
      # non-slot variables from _checkpoint_dependencies explicitly rather
      # than unconditionally adding dependencies (since there may be multiple
      # non-slot variables with the same name in different graphs, trying to
      # save all of them would result in errors).
      self._handle_deferred_dependencies(name=name, trackable=v)
      self._non_slot_dict[key] = v

    return v
Example 15
  def _create_factors(cls, rows, cols, num_shards, init, name):
    """Helper function to create row and column factors."""
    if callable(init):
      init = init()
    if isinstance(init, list):
      assert len(init) == num_shards
    elif isinstance(init, str) and init == "random":
      pass
    elif num_shards == 1:
      init = [init]
    sharded_matrix = []
    sizes = cls._shard_sizes(rows, num_shards)
    assert len(sizes) == num_shards

    def make_initializer(i, size):

      def initializer():
        if init == "random":
          return random_ops.random_normal([size, cols])
        else:
          return init[i]

      return initializer

    for i, size in enumerate(sizes):
      var_name = "%s_shard_%d" % (name, i)
      var_init = make_initializer(i, size)
      sharded_matrix.append(
          variable_scope.variable(
              var_init, dtype=dtypes.float32, name=var_name))

    return sharded_matrix
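
A side note on the `make_initializer(i, size)` factory above: routing the loop variables through a function call binds them at creation time, whereas a lambda written directly in the loop would capture them by reference and every shard would read the final values. A short self-contained illustration in plain Python (names hypothetical):

# Closures created directly in a loop share the loop variable...
late_bound = [lambda: i for i in range(3)]

# ...while a factory function (like make_initializer above) freezes it.
def make_const(j):
  return lambda: j

early_bound = [make_const(i) for i in range(3)]

assert [f() for f in late_bound] == [2, 2, 2]
assert [f() for f in early_bound] == [0, 1, 2]
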
 def create_metric_variable(self, initial_value, name):
   return variable_scope.variable(
       initial_value,
       trainable=False,
       collections=[ops_lib.GraphKeys.METRIC_VARIABLES],
       validate_shape=True,
       name=name)
Example 17
 def _transient_var(name):
   """Helper function to create a Variable."""
   return variable_scope.variable(
       1.0,
       trainable=False,
       collections=[ops.GraphKeys.LOCAL_VARIABLES],
       validate_shape=False,
       name=name)
Example 18
def _local_variable(tensor, name=None):
  """Stores a tensor as a local Variable for faster read."""
  return variable_scope.variable(
      initial_value=tensor,
      trainable=False,
      collections=[ops.GraphKeys.LOCAL_VARIABLES],
      validate_shape=False,
      name=name)
Example 19
 def run_fn():
   tower_context = distribute.get_tower_context()
   self.assertTrue(tower_context is not None)
   self.assertIs(None, distribute.get_cross_tower_context())
   self.assertTrue(distribute.has_distribution_strategy())
   self.assertIs(dist, distribute.get_distribution_strategy())
   self.assertEqual("foo", tower_context.merge_call(None, test_arg="foo"))
   self.assertEqual("bar", variable_scope.variable(1.0, name="bar"))
Example 20
  def _create_weights(cls, wt_init, num_wts, num_shards, name):
    """Helper function to create sharded weight vector.

    Args:
      wt_init: init value for the weight. If None, weights are not created.
        This can be None, a list of non-negative real numbers, or a single
        non-negative real number (or equivalent iterables).
      num_wts: total size of all the weight shards
      num_shards: number of shards for the weights
      name: name for the new Variables.

    Returns:
      A list of weight shard Tensors.

    Raises:
      ValueError: If wt_init is not the right format.
    """

    if wt_init is None:
      return None

    init_mode = "list"
    if isinstance(wt_init, collections.Iterable):
      if num_shards == 1 and len(wt_init) == num_wts:
        wt_init = [wt_init]
      assert len(wt_init) == num_shards
    elif isinstance(wt_init, numbers.Real) and wt_init >= 0:
      init_mode = "scalar"
    else:
      raise ValueError(
          "Invalid weight initialization argument. Must be one of these: "
          "None, a non-negative real number, or a list of lists of "
          "non-negative real numbers (or equivalent iterables) corresponding "
          "to sharded factors.")

    sizes = cls._shard_sizes(num_wts, num_shards)
    assert len(sizes) == num_shards

    with ops.name_scope(name):
      def make_wt_initializer(i, size):

        def initializer():
          if init_mode == "scalar":
            return wt_init * array_ops.ones([size])
          else:
            return wt_init[i]

        return initializer

      sharded_weight = []
      for i, size in enumerate(sizes):
        var_name = "%s_shard_%d" % (name, i)
        var_init = make_wt_initializer(i, size)
        sharded_weight.append(
            variable_scope.variable(
                var_init, dtype=dtypes.float32, name=var_name))

      return sharded_weight
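
The docstring above lists three accepted `wt_init` formats. Below is a plain-Python sketch of that dispatch, with hypothetical names, for illustration only; shard sizing and the actual Variable creation stay on the TensorFlow side:

import collections.abc
import numbers

def classify_wt_init(wt_init, num_wts, num_shards):
  """Mirrors the format handling described in the docstring above."""
  if wt_init is None:
    return None                        # no weights are created
  if isinstance(wt_init, collections.abc.Iterable):
    if num_shards == 1 and len(wt_init) == num_wts:
      wt_init = [wt_init]              # one flat list becomes a single shard
    assert len(wt_init) == num_shards
    return ("list", wt_init)
  if isinstance(wt_init, numbers.Real) and wt_init >= 0:
    return ("scalar", wt_init)         # broadcast as wt_init * ones([size])
  raise ValueError("wt_init must be None, a non-negative real number, or a "
                   "list of per-shard weight iterables.")

# classify_wt_init(0.5, num_wts=5, num_shards=2)               -> ("scalar", 0.5)
# classify_wt_init([[1, 2, 3], [4, 5]], num_wts=5, num_shards=2) -> ("list", ...)
# classify_wt_init([1, 2, 3, 4, 5], num_wts=5, num_shards=1)   -> ("list", [[1, 2, 3, 4, 5]])
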
Example 21
 def test_creating_var_with_numpy_arrays(self):
   with self.cached_session() as session:
     x = np.asarray(np.random.random((64, 3)), dtype=np.float32)
     initial = np.zeros_like(x)
     var_x = variable_scope.variable(initial)
     numpy_dataset.init_var_from_numpy(var_x, x, session)
     val = self.evaluate(var_x.value())
     # Verify that the numpy value is copied to the variable.
     self.assertAllEqual(x, val)
    def model_fn():
      v0 = variable_scope.variable(1.0, name="var0", aggregation=None)
      with variable_scope.variable_scope("common"):
        v1 = variable_scope.variable(1.0, name="var1")
        # This will pause the current thread, and execute the other thread.
        distribute_lib.get_tower_context().merge_call(lambda _: _)
        v2 = variable_scope.variable(
            1.0,
            name="var2",
            synchronization=variable_scope.VariableSynchronization.ON_READ,
            aggregation=variable_scope.VariableAggregation.SUM)
        v3 = variable_scope.variable(
            1.0,
            name="var3",
            synchronization=variable_scope.VariableSynchronization.ON_WRITE,
            aggregation=variable_scope.VariableAggregation.MEAN)

      return v0, v1, v2, v3
Example 23
  def _create_variables(self, data, initial_means=None):
    """Initializes GMM algorithm.

    Args:
      data: a list of Tensors with data, each row is a new example.
      initial_means: a Tensor with a matrix of means.
    """
    first_shard = data[0]
    # Initialize means: num_classes X 1 X dimensions.
    if initial_means is not None:
      self._means = variable_scope.variable(
          array_ops.expand_dims(initial_means, 1),
          name=self.CLUSTERS_VARIABLE,
          validate_shape=False,
          dtype=dtypes.float32)
    else:
      # Sample data randomly
      self._means = variable_scope.variable(
          array_ops.expand_dims(
              _init_clusters_random(data, self._num_classes, self._random_seed),
              1),
          name=self.CLUSTERS_VARIABLE,
          validate_shape=False)

    # Initialize covariances.
    if self._covariance_type == FULL_COVARIANCE:
      cov = _covariance(first_shard, False) + self._min_var
      # A matrix per class, num_classes X dimensions X dimensions
      covs = array_ops.tile(
          array_ops.expand_dims(cov, 0), [self._num_classes, 1, 1])
    elif self._covariance_type == DIAG_COVARIANCE:
      cov = _covariance(first_shard, True) + self._min_var
      # A diagonal per row, num_classes X dimensions.
      covs = array_ops.tile(
          array_ops.expand_dims(array_ops.diag_part(cov), 0),
          [self._num_classes, 1])
    self._covs = variable_scope.variable(
        covs, name=self.CLUSTERS_COVS_VARIABLE, validate_shape=False)
    # Mixture weights, representing the probability that a randomly
    # selected unobservable data (in EM terms) was generated by component k.
    self._alpha = variable_scope.variable(
        array_ops.tile([1.0 / self._num_classes], [self._num_classes]),
        name=self.CLUSTERS_WEIGHT,
        validate_shape=False)
Example 24
 def testScope(self):
   _assert_in_default_state(self)
   dist = _TestStrategy()
   with dist.scope():
     self.assertIs(None, distribute.get_tower_context())
     self.assertIs(dist, distribute.get_cross_tower_context())
     self.assertTrue(distribute.has_distribution_strategy())
     self.assertIs(dist, distribute.get_distribution_strategy())
     self.assertEqual("baz", variable_scope.variable(1.0, name="baz"))
   _assert_in_default_state(self)
Example 25
  def _create_slots(self, var_list):
    # Create the beta1 and beta2 accumulators on the same device as the first
    # variable. Sort the var_list to make sure this device is consistent across
    # workers (these need to go on the same PS, otherwise some updates are
    # silently ignored).
    first_var = min(var_list, key=lambda x: x.name)

    if (self._beta1_power is None or
        self._beta1_power.graph is not first_var.graph):
      with ops.colocate_with(first_var):
        self._beta1_power = variable_scope.variable(self._beta1,
                                                    name="beta1_power",
                                                    trainable=False)
        self._beta2_power = variable_scope.variable(self._beta2,
                                                    name="beta2_power",
                                                    trainable=False)
    # Create slots for the first and second moments.
    for v in var_list:
      self._zeros_slot(v, "m", self._name)
      self._zeros_slot(v, "v", self._name)
Example 26
    def create_axis_ops(sp_input, num_items, update_fn, axis_name):
      """Creates book-keeping and training ops for a given axis.

      Args:
        sp_input: A SparseTensor corresponding to the row or column batch.
        num_items: An integer, the total number of items of this axis.
        update_fn: A function that takes one argument (`sp_input`), and that
        returns a tuple of
          * new_factors: A float Tensor of the factor values after update.
          * update_op: a TensorFlow op which updates the factors.
          * loss: A float Tensor, the unregularized loss.
          * reg_loss: A float Tensor, the regularization loss.
          * sum_weights: A float Tensor, the sum of factor weights.
        axis_name: A string that specifies the name of the axis.

      Returns:
        A tuple consisting of:
          * reset_processed_items_op: A TensorFlow op, to be run before the
            beginning of any sweep. It marks all items as not-processed.
          * axis_train_op: A TensorFlow op, to be run during this axis' sweeps.
      """
      processed_items_init = array_ops.fill(dims=[num_items], value=False)
      with ops.colocate_with(processed_items_init):
        processed_items = variable_scope.variable(
            processed_items_init,
            collections=[ops.GraphKeys.GLOBAL_VARIABLES],
            trainable=False,
            name="processed_" + axis_name)
      reset_processed_items_op = state_ops.assign(
          processed_items, processed_items_init,
          name="reset_processed_" + axis_name)
      _, update_op, loss, reg, sum_weights = update_fn(sp_input)
      input_indices = sp_input.indices[:, 0]
      with ops.control_dependencies([
          update_op,
          state_ops.assign(loss_var, loss + reg),
          state_ops.assign(rwse_var, math_ops.sqrt(loss / sum_weights))]):
        with ops.colocate_with(processed_items):
          update_processed_items = state_ops.scatter_update(
              processed_items,
              input_indices,
              array_ops.ones_like(input_indices, dtype=dtypes.bool),
              name="update_processed_{}_indices".format(axis_name))
        with ops.control_dependencies([update_processed_items]):
          is_sweep_done = math_ops.reduce_all(processed_items)
          axis_train_op = control_flow_ops.group(
              global_step_incr_op,
              state_ops.assign(is_sweep_done_var, is_sweep_done),
              state_ops.assign_add(
                  completed_sweeps_var,
                  math_ops.cast(is_sweep_done, dtypes.int32)),
              name="{}_sweep_train_op".format(axis_name))
      return reset_processed_items_op, axis_train_op
Example 27
 def run_fn():
   tower_context = distribute.get_tower_context()
   self.assertTrue(tower_context is not None)
   self.assertIs(None, distribute.get_cross_tower_context())
   self.assertTrue(distribute.has_distribution_strategy())
   self.assertIs(dist, distribute.get_distribution_strategy())
   self.assertEqual("foo", tower_context.merge_call(None, test_arg="foo"))
   expected_value = _get_test_variable(
       "bar", variable_scope.VariableSynchronization.AUTO,
       variable_scope.VariableAggregation.NONE)
   self.assertDictEqual(expected_value,
                        variable_scope.variable(1.0, name="bar"))
Example 28
def _local_variable(initial_value, name=None):
  """Stores a tensor as a local Variable for faster read."""
  result = variable_scope.variable(
      initial_value=initial_value,
      trainable=False,
      collections=[ops.GraphKeys.LOCAL_VARIABLES],
      validate_shape=False,
      name=name)
  if isinstance(initial_value, ops.Tensor):
    # Match the resulting variable's shape if the initial_value is a Tensor.
    result.set_shape(initial_value.shape)
  return result
    def model_fn(device_id):
      assert isinstance(device_id, int)
      def thread_creator_fn(next_creator, *args, **kwargs):
        return next_creator(*args, **kwargs) + ":thread_" + str(device_id)

      with variable_scope.variable_creator_scope(thread_creator_fn):
        # Create a variable in this scope.
        v = variable_scope.variable(1.0)

        # This will pause the current thread, and execute the other thread.
        distribute_lib.get_tower_context().merge_call(lambda _: _)
      return v
  def testSharedVariable(self):

    shared_variable_store = {}
    num_devices = 3
    creator_fns = []
    for i in range(num_devices):
      creator_fn = shared_variable_creator.make_fn(shared_variable_store, i)
      creator_fns.append(creator_fn)

    with variable_scope.variable_creator_scope(creator_fns[0]):
      v0 = variable_scope.variable(1.0, name="foo")

    with variable_scope.variable_creator_scope(creator_fns[1]):
      v1 = variable_scope.variable(1.0, name="foo")

    with variable_scope.variable_creator_scope(creator_fns[2]):
      v2 = variable_scope.variable(1.0, name="foo")

    # v1 and v2 should be same as v0
    self.assertIs(v1, v0)
    self.assertIs(v2, v0)
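
For orientation, here is a simplified sketch of what a creator along the lines of `shared_variable_creator.make_fn` could look like: device 0 creates each variable and the other devices reuse it by name and creation order. This illustrates the idea with hypothetical names only and is not necessarily the actual implementation:

def make_simple_shared_creator(shared_store, device_id):
  """Sketch: device 0 creates variables; other devices reuse them in order."""
  per_name_index = {}

  def creator(next_creator, **kwargs):
    name = kwargs.get("name") or "Variable"
    index = per_name_index.get(name, 0)
    per_name_index[name] = index + 1
    if device_id == 0:
      v = next_creator(**kwargs)
      shared_store.setdefault(name, []).append(v)
      return v
    return shared_store[name][index]

  return creator

Wrapped in `variable_scope.variable_creator_scope` exactly as in the test above, devices 1 and 2 would then receive the variable created for device 0.
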
Example 31
 def in_cross_tower(_):
     c = variable_scope.variable(1.0, name="c")
     return c
Example 32
    def create_batch(self):
        """Create queues to window and batch time series data.

        Returns:
          A dictionary of Tensors corresponding to the output of `self._reader`
          (from the `time_series_reader` constructor argument), each with shapes
          prefixed by [`batch_size`, `window_size`].
        """
        features = self._reader.read()
        if self._jitter:
            # TODO(agarwal, allenl): Figure out if more jitter is needed here.
            jitter = random_ops.random_uniform(shape=[],
                                               maxval=2,
                                               dtype=dtypes.int32)
        else:
            jitter = 0
        # To keep things efficient, we pass from the windowing batcher to the
        # batch-of-windows batcher in batches. This avoids the need for huge numbers
        # of threads, but does mean that jitter is only applied occasionally.
        # TODO(allenl): Experiment with different internal passing sizes.
        internal_passing_size = self._batch_size
        features_windowed = input_lib.batch(
            features,
            batch_size=self._window_size * internal_passing_size + jitter,
            enqueue_many=True,
            capacity=(self._queue_capacity_multiplier * internal_passing_size *
                      self._window_size),
            num_threads=self._num_threads)
        raw_features_windowed = features_windowed
        if self._jitter:
            features_windowed = {
                key: value[jitter:]
                for key, value in features_windowed.items()
            }
        features_windowed = {
            key: array_ops.reshape(
                value,
                array_ops.concat([[internal_passing_size, self._window_size],
                                  array_ops.shape(value)[1:]],
                                 axis=0))
            for key, value in features_windowed.items()
        }
        batch_and_window_shape = tensor_shape.TensorShape(
            [internal_passing_size, self._window_size])
        for key in features_windowed.keys():
            features_windowed[key].set_shape(
                batch_and_window_shape.concatenate(
                    raw_features_windowed[key].get_shape()[1:]))
        # When switching files, we may end up with windows where the time is not
        # decreasing, even if times within each file are sorted (and even if those
        # files are visited in order, when looping back around to the beginning of
        # the first file). This is hard for models to deal with, so we either
        # discard such examples, creating a bias where the beginning and end of the
        # series is under-sampled, or we sort the window, creating large gaps.
        times = features_windowed[feature_keys.TrainEvalFeatures.TIMES]
        if self._discard_out_of_order:
            non_decreasing = math_ops.reduce_all(times[:, 1:] >= times[:, :-1],
                                                 axis=1)
            # Ensure that no more than self._discard_limit complete batches are
            # discarded contiguously (resetting the count when we find a single clean
            # window). This prevents infinite looping when the dataset is smaller than
            # the window size.
            # TODO(allenl): Figure out a way to return informative errors from
            # count_up_to.
            discarded_windows_limiter = variable_scope.variable(
                initial_value=constant_op.constant(0, dtype=dtypes.int64),
                name="discarded_windows_limiter",
                trainable=False,
                collections=[ops.GraphKeys.LOCAL_VARIABLES])

            def _initialized_limit_check():
                return control_flow_ops.cond(
                    math_ops.reduce_any(non_decreasing),
                    lambda: state_ops.assign(discarded_windows_limiter, 0),
                    lambda: discarded_windows_limiter.count_up_to(
                        self._discard_limit))

            discard_limit_op = control_flow_ops.cond(
                state_ops.is_variable_initialized(discarded_windows_limiter),
                _initialized_limit_check,
                lambda: constant_op.constant(0, dtype=dtypes.int64))
            with ops.control_dependencies([discard_limit_op]):
                non_decreasing = array_ops.identity(non_decreasing)
        else:
            _, indices_descending = nn.top_k(times,
                                             k=array_ops.shape(times)[-1],
                                             sorted=True)
            indices = array_ops.reverse(indices_descending, axis=[0])
            features_windowed = {
                key: array_ops.gather(params=value, indices=indices)
                for key, value in features_windowed.items()
            }
            non_decreasing = True
        features_batched = input_lib.maybe_shuffle_batch(
            features_windowed,
            num_threads=self._num_threads,
            seed=self._shuffle_seed,
            batch_size=self._batch_size,
            capacity=self._queue_capacity_multiplier * self._batch_size,
            min_after_dequeue=(self._shuffle_min_after_dequeue_multiplier *
                               self._batch_size),
            keep_input=non_decreasing,
            enqueue_many=True)
        return (features_batched, None)
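
A small plain-Python illustration of the out-of-order check described in the comments above: a window survives the `keep_input` filter only if its timestamps never decrease (the function name is hypothetical):

def window_is_non_decreasing(times):
  # Mirrors reduce_all(times[:, 1:] >= times[:, :-1], axis=1) for one window.
  return all(later >= earlier for earlier, later in zip(times, times[1:]))

# window_is_non_decreasing([1, 2, 2, 5]) -> True
# window_is_non_decreasing([4, 5, 1, 2]) -> False (a file boundary was crossed)
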
 def __init__(self, loss_scale):
   self._loss_scale = variable_scope.variable(
       name="loss_scale",
       initial_value=loss_scale,
       dtype=dtypes.float32,
       trainable=False)
Example 34
    def _create_hook_ops(self, input_row_indices, input_col_indices,
                         train_ops):
        """Creates ops to update is_row_sweep_var, global_step and completed_sweeps.

        Creates two boolean tensors `processed_rows` and `processed_cols`, which
        keep track of which rows/cols have been processed during the current
        sweep. Returns ops that should be run after each row / col update.
          - When `self._is_row_sweep_var` is True, it sets
            processed_rows[input_row_indices] to True.
          - When `self._is_row_sweep_var` is False, it sets
            processed_cols[input_col_indices] to True.

        Args:
          input_row_indices: A Tensor. The indices of the input rows that are
            processed during the current sweep.
          input_col_indices: A Tensor. The indices of the input columns that
            are processed during the current sweep.
          train_ops: A list of ops. The ops created by this function have
            control dependencies on `train_ops`.

        Returns:
          A tuple consisting of:
            update_op: An op to be run jointly with training. It updates the
              state and increments counters (global step and completed sweeps).
            is_sweep_done_var: A Boolean tf.Variable, specifies whether the
              sweep is done, i.e. all rows (during a row sweep) or all columns
              (during a column sweep) have been processed.
            switch_op: An op to be run in `self.before_run` when the sweep is
              done.
        """
        processed_rows_init = array_ops.fill(dims=[self._num_rows],
                                             value=False)
        with ops.colocate_with(processed_rows_init):
            processed_rows = variable_scope.variable(
                processed_rows_init,
                collections=[ops.GraphKeys.GLOBAL_VARIABLES],
                trainable=False,
                name="sweep_hook_processed_rows")
        processed_cols_init = array_ops.fill(dims=[self._num_cols],
                                             value=False)
        with ops.colocate_with(processed_cols_init):
            processed_cols = variable_scope.variable(
                processed_cols_init,
                collections=[ops.GraphKeys.GLOBAL_VARIABLES],
                trainable=False,
                name="sweep_hook_processed_cols")
        switch_ops = control_flow_ops.group(
            state_ops.assign(self._is_row_sweep_var,
                             math_ops.logical_not(self._is_row_sweep_var)),
            state_ops.assign(processed_rows, processed_rows_init),
            state_ops.assign(processed_cols, processed_cols_init))
        is_sweep_done_var = variable_scope.variable(
            False,
            collections=[ops.GraphKeys.GLOBAL_VARIABLES],
            trainable=False,
            name="is_sweep_done")

        # After running the `train_ops`, updates `processed_rows` or
        # `processed_cols` tensors, depending on whether this is a row or col sweep.
        with ops.control_dependencies(train_ops):
            with ops.colocate_with(processed_rows):
                update_processed_rows = state_ops.scatter_update(
                    processed_rows, input_row_indices,
                    math_ops.logical_and(
                        self._is_row_sweep_var,
                        array_ops.ones_like(input_row_indices,
                                            dtype=dtypes.bool)))
            with ops.colocate_with(processed_cols):
                update_processed_cols = state_ops.scatter_update(
                    processed_cols, input_col_indices,
                    math_ops.logical_and(
                        math_ops.logical_not(self._is_row_sweep_var),
                        array_ops.ones_like(input_col_indices,
                                            dtype=dtypes.bool)))
            update_processed_op = control_flow_ops.group(
                update_processed_rows, update_processed_cols)

            with ops.control_dependencies([update_processed_op]):
                is_sweep_done = math_ops.logical_or(
                    math_ops.reduce_all(processed_rows),
                    math_ops.reduce_all(processed_cols))
                # Increments global step.
                global_step = framework_variables.get_global_step()
                if global_step is not None:
                    global_step_incr_op = state_ops.assign_add(
                        global_step, 1, name="global_step_incr").op
                else:
                    global_step_incr_op = control_flow_ops.no_op()
                # Increments completed sweeps.
                completed_sweeps_incr_op = state_ops.assign_add(
                    self._completed_sweeps_var,
                    math_ops.cast(is_sweep_done, dtypes.int32),
                    use_locking=True).op
                update_ops = control_flow_ops.group(
                    global_step_incr_op, completed_sweeps_incr_op,
                    state_ops.assign(is_sweep_done_var, is_sweep_done))

        return update_ops, is_sweep_done_var, switch_ops
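
For intuition, a plain-Python sketch of the sweep bookkeeping described in the docstring above; the real version uses TensorFlow variables, `scatter_update`, and `reduce_all`, and the function name here is hypothetical:

def run_one_sweep(num_items, batches_of_indices):
  """Marks items processed batch by batch; reports whether the sweep finished."""
  processed = [False] * num_items           # processed_rows / processed_cols
  for batch in batches_of_indices:
    for i in batch:
      processed[i] = True                   # scatter_update after train_ops
    if all(processed):                      # reduce_all -> is_sweep_done
      return True                           # before_run would then run switch_op
  return False

# run_one_sweep(4, [[0, 1], [2], [3]]) -> True (every index was visited)
# run_one_sweep(4, [[0, 1], [1, 2]])   -> False (index 3 was never processed)
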
 def in_cross_replica(_):
   c = variable_scope.variable(1.0, name="c")
   return c
 def model_fn():
   v_sum = variable_scope.variable(
       1.0,
       synchronization=variable_scope.VariableSynchronization.ON_READ,
       aggregation=variable_scope.VariableAggregation.MEAN)
   return v_sum
Example 37
 def model_fn():
     tower_context = distribute_lib.get_tower_context()
     with tower_context.tower_local_var_scope("sum"):
         v_sum = variable_scope.variable(1.0)
     self.assertTrue(isinstance(v_sum, values.TowerLocalVariable))
     return v_sum
Example 38
 def model_fn(device_id):
     v = variable_scope.variable(1.0, name="foo_" + str(device_id))
     distribute_lib.get_tower_context().merge_call(lambda _: _)
     return v
Example 39
 def model_fn(name):
     v = variable_scope.variable(1.0, name=name)
     distribute_lib.get_tower_context().merge_call(lambda _: _)
     return v
Example 40
 def model_fn():
     vs = []
     for i in range(5):
         vs.append(variable_scope.variable(1.0, name="foo" + str(i)))
     distribute_lib.get_tower_context().merge_call(lambda _: _)
     return vs
Example 41
 def model_fn():
     # This variable should be created only once across the threads because of
     # special variable_creator functions used by `dist.call_for_each_tower`.
     v = variable_scope.variable(1.0, name="foo")
     distribute_lib.get_tower_context().merge_call(lambda _: _)
     return v
 def model_fn(name):
     v = variable_scope.variable(1.0, name=name)
     ds_context.get_replica_context().merge_call(lambda _: _)
     return v
    def apply_gradients(self, grads_and_vars, global_step=None, name=None):
        """Apply gradients to variables.

        This contains most of the synchronization implementation and also wraps
        the apply_gradients() from the real optimizer.

        Args:
          grads_and_vars: List of (gradient, variable) pairs as returned by
            compute_gradients().
          global_step: Optional Variable to increment by one after the
            variables have been updated.
          name: Optional name for the returned operation. Defaults to the
            name passed to the Optimizer constructor.

        Returns:
          train_op: The op to dequeue a token so the replicas can exit this
            batch and start the next one. This is executed by each replica.

        Raises:
          ValueError: If grads_and_vars is empty.
          ValueError: If global_step is not provided, since staleness cannot be
            checked without it.
        """
        if not grads_and_vars:
            raise ValueError("Must supply at least one variable")

        if global_step is None:
            raise ValueError("Global step is required to check staleness")

        self._global_step = global_step
        train_ops = []
        aggregated_grad = []
        var_list = []

        # local_anchor op will be placed on this worker task by default.
        local_anchor = control_flow_ops.no_op()
        # Colocating local_step variable prevents it being placed on the PS.
        with ops.colocate_with(local_anchor):
            self._local_step = variable_scope.variable(
                initial_value=0,
                trainable=False,
                collections=[ops.GraphKeys.LOCAL_VARIABLES],
                dtype=global_step.dtype.base_dtype,
                name="sync_rep_local_step")

        self.local_step_init_op = state_ops.assign(self._local_step,
                                                   global_step)
        chief_init_ops = [self.local_step_init_op]
        self.ready_for_local_init_op = variables.report_uninitialized_variables(
            variables.global_variables())

        with ops.name_scope(None, self._name):
            for grad, var in grads_and_vars:
                var_list.append(var)
                with ops.device(var.device):
                    # Dense gradients.
                    if grad is None:
                        aggregated_grad.append(None)  # pass-through.
                        continue
                    elif isinstance(grad, ops.Tensor):
                        grad_accum = data_flow_ops.ConditionalAccumulator(
                            grad.dtype,
                            shape=var.get_shape(),
                            shared_name=var.name + "/grad_accum")
                        train_ops.append(
                            grad_accum.apply_grad(grad,
                                                  local_step=self._local_step))
                        aggregated_grad.append(
                            grad_accum.take_grad(self._replicas_to_aggregate))
                    else:
                        if not isinstance(grad, ops.IndexedSlices):
                            raise ValueError("Unknown grad type!")
                        grad_accum = data_flow_ops.SparseConditionalAccumulator(
                            grad.dtype,
                            shape=(),
                            shared_name=var.name + "/grad_accum")
                        train_ops.append(
                            grad_accum.apply_indexed_slices_grad(
                                grad, local_step=self._local_step))
                        aggregated_grad.append(
                            grad_accum.take_indexed_slices_grad(
                                self._replicas_to_aggregate))

                    self._accumulator_list.append((grad_accum, var.device))

            aggregated_grads_and_vars = zip(aggregated_grad, var_list)

            # sync_op will be assigned to the same device as the global step.
            with ops.device(global_step.device), ops.name_scope(""):
                update_op = self._opt.apply_gradients(
                    aggregated_grads_and_vars, global_step)

            # Create token queue.
            with ops.device(global_step.device), ops.name_scope(""):
                sync_token_queue = (data_flow_ops.FIFOQueue(
                    -1,
                    global_step.dtype.base_dtype,
                    shapes=(),
                    name="sync_token_q",
                    shared_name="sync_token_q"))
                self._sync_token_queue = sync_token_queue

                # dummy_queue is passed to the queue runner. Don't use the real queues
                # because the queue runner doesn't automatically reopen it once it
                # closed queues in PS devices.
                dummy_queue = (data_flow_ops.FIFOQueue(
                    1,
                    types_pb2.DT_INT32,
                    shapes=(),
                    name="dummy_queue",
                    shared_name="dummy_queue"))

            with ops.device(global_step.device), ops.name_scope(""):
                # Replicas have to wait until they can get a token from the token queue.
                with ops.control_dependencies(train_ops):
                    token = sync_token_queue.dequeue()
                train_op = state_ops.assign(self._local_step, token)

                with ops.control_dependencies([update_op]):
                    # Sync_op needs to insert tokens to the token queue at the end of the
                    # step so the replicas can fetch them to start the next step.
                    tokens = array_ops.fill([self._tokens_per_step],
                                            global_step)
                    sync_op = sync_token_queue.enqueue_many((tokens, ))

                if self._variable_averages is not None:
                    with ops.control_dependencies([sync_op
                                                   ]), ops.name_scope(""):
                        sync_op = self._variable_averages.apply(
                            self._variables_to_average)

                self._chief_queue_runner = queue_runner.QueueRunner(
                    dummy_queue, [sync_op])
            for accum, dev in self._accumulator_list:
                with ops.device(dev):
                    chief_init_ops.append(
                        accum.set_global_step(global_step,
                                              name="SetGlobalStep"))
            self.chief_init_op = control_flow_ops.group(*(chief_init_ops))
            self._gradients_applied = True
            return train_op
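
A conceptual, plain-Python analogy for the token-queue handshake described in the docstring above: each replica blocks on a dequeue until the chief, after the aggregated update has been applied, enqueues one token per expected replica for the next step. This uses a thread-safe Python queue purely as an analogy, with hypothetical names, not the actual `FIFOQueue`-based graph ops:

import queue

def chief_finish_step(sync_tokens, global_step_value, tokens_per_step):
  # Analogue of sync_op: enqueue_many((tokens,)) once apply_gradients has run.
  for _ in range(tokens_per_step):
    sync_tokens.put(global_step_value)

def replica_start_next_step(sync_tokens):
  # Analogue of train_op: dequeue a token and use it as the new local_step.
  return sync_tokens.get()

# tokens = queue.Queue(); chief_finish_step(tokens, 7, tokens_per_step=2)
# replica_start_next_step(tokens) -> 7
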
 def model_fn():
     vs = []
     for i in range(5):
         vs.append(variable_scope.variable(1.0, name="foo" + str(i)))
     ds_context.get_replica_context().merge_call(lambda _: _)
     return vs
Example 45
def _wals_factorization_model_function(features, labels, mode, params):
  """Model function for the WALSFactorization estimator.

  Args:
    features: Dictionary of features. See WALSMatrixFactorization.
    labels: Must be None.
    mode: A model_fn.ModeKeys object.
    params: Dictionary of parameters containing arguments passed to the
      WALSMatrixFactorization constructor.

  Returns:
    A ModelFnOps object.

  Raises:
    ValueError: If `mode` is not recognized.
  """
  assert labels is None
  use_factors_weights_cache = (params["use_factors_weights_cache_for_training"]
                               and mode == model_fn.ModeKeys.TRAIN)
  use_gramian_cache = (params["use_gramian_cache_for_training"] and
                       mode == model_fn.ModeKeys.TRAIN)
  max_sweeps = params["max_sweeps"]
  model = factorization_ops.WALSModel(
      params["num_rows"],
      params["num_cols"],
      params["embedding_dimension"],
      unobserved_weight=params["unobserved_weight"],
      regularization=params["regularization_coeff"],
      row_init=params["row_init"],
      col_init=params["col_init"],
      num_row_shards=params["num_row_shards"],
      num_col_shards=params["num_col_shards"],
      row_weights=params["row_weights"],
      col_weights=params["col_weights"],
      use_factors_weights_cache=use_factors_weights_cache,
      use_gramian_cache=use_gramian_cache)

  # Get input rows and cols. We either update rows or columns depending on
  # the value of row_sweep, which is maintained using a session hook.
  input_rows = features[WALSMatrixFactorization.INPUT_ROWS]
  input_cols = features[WALSMatrixFactorization.INPUT_COLS]

  # TRAIN mode:
  if mode == model_fn.ModeKeys.TRAIN:
    # Training consists of the following ops (controlled using a SweepHook).
    # Before a row sweep:
    #   row_update_prep_gramian_op
    #   initialize_row_update_op
    # During a row sweep:
    #   update_row_factors_op
    # Before a col sweep:
    #   col_update_prep_gramian_op
    #   initialize_col_update_op
    # During a col sweep:
    #   update_col_factors_op

    is_row_sweep_var = variable_scope.variable(
        True,
        trainable=False,
        name="is_row_sweep",
        collections=[ops.GraphKeys.GLOBAL_VARIABLES])
    is_sweep_done_var = variable_scope.variable(
        False,
        trainable=False,
        name="is_sweep_done",
        collections=[ops.GraphKeys.GLOBAL_VARIABLES])
    completed_sweeps_var = variable_scope.variable(
        0,
        trainable=False,
        name=WALSMatrixFactorization.COMPLETED_SWEEPS,
        collections=[ops.GraphKeys.GLOBAL_VARIABLES])
    loss_var = variable_scope.variable(
        0.,
        trainable=False,
        name=WALSMatrixFactorization.LOSS,
        collections=[ops.GraphKeys.GLOBAL_VARIABLES])
    # The root weighted squared error =
    #   \\(\sqrt( \sum_{i,j} w_ij * (a_ij - r_ij)^2 / \sum_{i,j} w_ij )\\)
    rwse_var = variable_scope.variable(
        0.,
        trainable=False,
        name=WALSMatrixFactorization.RWSE,
        collections=[ops.GraphKeys.GLOBAL_VARIABLES])

    summary.scalar("loss", loss_var)
    summary.scalar("root_weighted_squared_error", rwse_var)
    summary.scalar("completed_sweeps", completed_sweeps_var)

    def create_axis_ops(sp_input, num_items, update_fn, axis_name):
      """Creates book-keeping and training ops for a given axis.

      Args:
        sp_input: A SparseTensor corresponding to the row or column batch.
        num_items: An integer, the total number of items of this axis.
        update_fn: A function that takes one argument (`sp_input`), and that
        returns a tuple of
          * new_factors: A float Tensor of the factor values after update.
          * update_op: a TensorFlow op which updates the factors.
          * loss: A float Tensor, the unregularized loss.
          * reg_loss: A float Tensor, the regularization loss.
          * sum_weights: A float Tensor, the sum of factor weights.
        axis_name: A string that specifies the name of the axis.

      Returns:
        A tuple consisting of:
          * reset_processed_items_op: A TensorFlow op, to be run before the
            beginning of any sweep. It marks all items as not-processed.
          * axis_train_op: A TensorFlow op, to be run during this axis' sweeps.
      """
      processed_items_init = array_ops.fill(dims=[num_items], value=False)
      with ops.colocate_with(processed_items_init):
        processed_items = variable_scope.variable(
            processed_items_init,
            collections=[ops.GraphKeys.GLOBAL_VARIABLES],
            trainable=False,
            name="processed_" + axis_name)
      _, update_op, loss, reg, sum_weights = update_fn(sp_input)
      input_indices = sp_input.indices[:, 0]
      with ops.control_dependencies([
          update_op,
          state_ops.assign(loss_var, loss + reg),
          state_ops.assign(rwse_var, math_ops.sqrt(loss / sum_weights))]):
        with ops.colocate_with(processed_items):
          update_processed_items = state_ops.scatter_update(
              processed_items,
              input_indices,
              array_ops.ones_like(input_indices, dtype=dtypes.bool),
              name="update_processed_{}_indices".format(axis_name))
        with ops.control_dependencies([update_processed_items]):
          is_sweep_done = math_ops.reduce_all(processed_items)
          axis_train_op = control_flow_ops.group(
              state_ops.assign(is_sweep_done_var, is_sweep_done),
              state_ops.assign_add(
                  completed_sweeps_var,
                  math_ops.cast(is_sweep_done, dtypes.int32)),
              name="{}_sweep_train_op".format(axis_name))
      return processed_items.initializer, axis_train_op

    reset_processed_rows_op, row_train_op = create_axis_ops(
        input_rows,
        params["num_rows"],
        lambda x: model.update_row_factors(sp_input=x, transpose_input=False),
        "rows")
    reset_processed_cols_op, col_train_op = create_axis_ops(
        input_cols,
        params["num_cols"],
        lambda x: model.update_col_factors(sp_input=x, transpose_input=True),
        "cols")
    switch_op = control_flow_ops.group(
        state_ops.assign(
            is_row_sweep_var, math_ops.logical_not(is_row_sweep_var)),
        reset_processed_rows_op,
        reset_processed_cols_op,
        name="sweep_switch_op")
    row_prep_ops = [
        model.row_update_prep_gramian_op, model.initialize_row_update_op]
    col_prep_ops = [
        model.col_update_prep_gramian_op, model.initialize_col_update_op]
    init_op = model.worker_init
    sweep_hook = _SweepHook(
        is_row_sweep_var, is_sweep_done_var, init_op,
        row_prep_ops, col_prep_ops, row_train_op, col_train_op, switch_op)
    global_step_hook = _IncrementGlobalStepHook()
    training_hooks = [sweep_hook, global_step_hook]
    if max_sweeps is not None:
      training_hooks.append(_StopAtSweepHook(max_sweeps))

    return model_fn.ModelFnOps(
        mode=model_fn.ModeKeys.TRAIN,
        predictions={},
        loss=loss_var,
        eval_metric_ops={},
        train_op=control_flow_ops.no_op(),
        training_hooks=training_hooks)

  # INFER mode
  elif mode == model_fn.ModeKeys.INFER:
    projection_weights = features.get(
        WALSMatrixFactorization.PROJECTION_WEIGHTS)

    def get_row_projection():
      return model.project_row_factors(
          sp_input=input_rows,
          projection_weights=projection_weights,
          transpose_input=False)

    def get_col_projection():
      return model.project_col_factors(
          sp_input=input_cols,
          projection_weights=projection_weights,
          transpose_input=True)

    predictions = {
        WALSMatrixFactorization.PROJECTION_RESULT: control_flow_ops.cond(
            features[WALSMatrixFactorization.PROJECT_ROW],
            get_row_projection,
            get_col_projection)
    }

    return model_fn.ModelFnOps(
        mode=model_fn.ModeKeys.INFER,
        predictions=predictions,
        loss=None,
        eval_metric_ops={},
        train_op=control_flow_ops.no_op(),
        training_hooks=[])

  # EVAL mode
  elif mode == model_fn.ModeKeys.EVAL:
    def get_row_loss():
      _, _, loss, reg, _ = model.update_row_factors(
          sp_input=input_rows, transpose_input=False)
      return loss + reg
    def get_col_loss():
      _, _, loss, reg, _ = model.update_col_factors(
          sp_input=input_cols, transpose_input=True)
      return loss + reg
    loss = control_flow_ops.cond(
        features[WALSMatrixFactorization.PROJECT_ROW],
        get_row_loss,
        get_col_loss)
    return model_fn.ModelFnOps(
        mode=model_fn.ModeKeys.EVAL,
        predictions={},
        loss=loss,
        eval_metric_ops={},
        train_op=control_flow_ops.no_op(),
        training_hooks=[])

  else:
    raise ValueError("mode=%s is not recognized." % str(mode))
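
For intuition, a plain-Python sketch of the loop structure described in the TRAIN-mode comments above: row and column sweeps alternate, each sweep runs its prep ops once and then its train op until the sweep is done, and the switch op flips the sweep direction. All names here are hypothetical:

def alternate_sweeps(num_sweeps, run_prep, run_train_until_done):
  """Alternates row and column sweeps, as the SweepHook orchestrates above."""
  is_row_sweep = True
  for _ in range(num_sweeps):
    axis = "row" if is_row_sweep else "col"
    run_prep(axis)                   # *_update_prep_gramian_op, initialize_*_update_op
    run_train_until_done(axis)       # update_*_factors_op until is_sweep_done_var
    is_row_sweep = not is_row_sweep  # switch_op flips is_row_sweep_var
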
Example 46
 def model_fn():
     b = variable_scope.variable(1.0, name="b")
     with ops.name_scope("foo"):
         c = distribute_lib.get_tower_context().merge_call(
             in_cross_tower)
     return b, c
 def __init__(self, two_variables=False):
   self.variables = []
   self.variables.append(variable_scope.variable(1.25, name="dummy_var1"))
   if two_variables:
     self.variables.append(variable_scope.variable(2.0, name="dummy_var2"))
Example 48
 def var_fn():
     v = variable_scope.variable(
         1.0,
         name="foo",
         aggregation=variable_scope.VariableAggregation.SUM)
     return v
 def model_fn():
   b = variable_scope.variable(1.0, name="b")
   with ops.name_scope("foo"):
     c = ds_context.get_replica_context().merge_call(in_cross_replica)
   return b, c
Example 50
 def var_fn():
     return variable_scope.variable(1.0, name="foo")
Example 51
 def var_fn():
     return variable_scope.variable(
         1.0,
         name="foo",
         aggregation=variable_scope.VariableAggregation.MEAN)
    def apply_gradients(self, grads_and_vars, global_step=None, name=None):
        if not grads_and_vars:
            raise ValueError("Must supply at least one variable")

        if global_step is None:
            raise ValueError("Global step is required to check staleness")

        self._global_step = global_step
        train_ops = []
        aggregated_grad = []

        # local_anchor op will be placed on this worker task by default.
        local_anchor = control_flow_ops.no_op()
        # Colocating local_step variable prevents it being placed on the PS.
        with ops.colocate_with(local_anchor):
            self._local_step = variable_scope.variable(
                initial_value=0,
                trainable=False,
                collections=[ops.GraphKeys.LOCAL_VARIABLES],
                dtype=global_step.dtype.base_dtype,
                name="local_step")

        self.local_step_init_op = state_ops.assign(self._local_step,
                                                   global_step)
        chief_init_ops = [self.local_step_init_op]
        self.ready_for_local_init_op = variables.report_uninitialized_variables(
            variables.global_variables())

        var_list = [v for g, v in grads_and_vars]
        velocity_list = [self._var_2_velocity[v] for v in var_list]
        residual_list = [self._var_2_residual[v] for v in var_list]

        density = 0.01

        with ops.name_scope(None, self._name):
            for velocity, residual, (grad, var) in zip(
                    velocity_list, residual_list, grads_and_vars):
                if grad is not None:
                    if self._use_nesterov:
                        update_velocity = self._momentum * (velocity + grad)
                        update_residual = residual + update_velocity + grad
                    else:
                        update_velocity = self._momentum * velocity + grad
                        update_residual = residual + update_velocity
                else:
                    update_velocity = velocity
                    update_residual = residual

                # select threshold according to abs(update_residual)
                top_k_values, top_k_indices = nn_ops.top_k(
                    math_ops.abs(update_residual),
                    math_ops.to_int32(
                        math_ops.to_float(
                            array_ops.shape(update_residual)[-1]) * density))
                threshold = top_k_values[-1]
                mask = math_ops.abs(update_residual) > threshold
                # Cast the boolean mask to the residual's dtype so it can gate
                # the float residual and velocity tensors below.
                mask = math_ops.cast(mask, dtype=update_residual.dtype)
                mask_h = math_ops.abs(mask - 1)

                with ops.device(grad.device):
                    dense_grad = mask * update_residual
                    indices = array_ops.where(math_ops.not_equal(
                        dense_grad, 0))
                    values = array_ops.gather_nd(dense_grad, indices)
                    sparse_grad = ops.IndexedSlices(values, indices,
                                                    dense_grad.get_shape())
                    #grad_update = state_ops.assign(grad, mask * update_residual)

                #with ops.control_dependencies([grad_update]), ops.device(var.device):
                #grad_accum = data_flow_ops.ConditionalAccumulator(
                #grad.dtype, shape=var.get_shape(),
                #shared_name=var.name + "/grad_accum")
                #train_ops.append(grad_accum.apply_grad(grad, local_step=self._local_step))
                #aggregated_grad.append(grad_accum.take_grad(self._replicas_to_aggregate))

                with ops.device(var.device):
                    grad_accum = data_flow_ops.SparseConditionalAccumulator(
                        sparse_grad.dtype,
                        shape=(),
                        shared_name=var.name + "/grad_accum")
                    train_ops.append(
                        grad_accum.apply_indexed_slices_grad(
                            sparse_grad, local_step=self._local_step))
                    aggregated_grad.append(
                        grad_accum.take_indexed_slices_grad(
                            self._replicas_to_aggregate))

                    self._accumulator_list.append((grad_accum, var.device))

                with ops.device(residual.device):
                    train_ops.append(
                        state_ops.assign(residual, mask_h * update_residual))
                with ops.device(velocity.device):
                    train_ops.append(
                        state_ops.assign(velocity, mask_h * update_velocity))

            aggregated_grads_and_vars = zip(aggregated_grad, var_list)

            with ops.device(global_step.device), ops.name_scope(""):
                update_op = self._opt.apply_gradients(aggregated_grads_and_vars,
                                                      global_step)

            with ops.device(global_step.device), ops.name_scope(""):
                sync_token_queue = (data_flow_ops.FIFOQueue(
                    -1,
                    global_step.dtype.base_dtype,
                    shapes=(),
                    name="sync_token_q",
                    shared_name="sync_token_q"))
                self._sync_token_queue = sync_token_queue

                dummy_queue = (data_flow_ops.FIFOQueue(
                    1,
                    types_pb2.DT_INT32,
                    shapes=(),
                    name="dummy_queue",
                    shared_name="dummy_queue"))

                with ops.control_dependencies(train_ops):
                    token = sync_token_queue.dequeue()
                train_op = state_ops.assign(self._local_step, token)

                with ops.control_dependencies([update_op]):
                    tokens = array_ops.fill([self._tokens_per_step],
                                            global_step)
                    sync_op = sync_token_queue.enqueue_many((tokens, ))

                if self._variable_averages is not None:
                    with ops.control_dependencies([sync_op
                                                   ]), ops.name_scope(""):
                        sync_op = self._variable_averages.apply(
                            self._variables_to_average)

                self._chief_queue_runner = queue_runner.QueueRunner(
                    dummy_queue, [sync_op])

            for accum, dev in self._accumulator_list:
                with ops.device(dev):
                    chief_init_ops.append(
                        accum.set_global_step(global_step,
                                              name="SetGlobalStep"))
            self.chief_init_op = control_flow_ops.group(*(chief_init_ops))
            self._gradients_applied = True

            return train_op
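
The heart of apply_gradients above is top-k sparsification with residual carry-over: only roughly a `density` fraction of the residual (the entries with the largest magnitude) is pushed to the accumulator, and the rest stays in the local residual. A NumPy sketch of that masking step, illustrative only and not the optimizer itself:

import numpy as np

def sparsify(update_residual, density=0.01):
  k = max(1, int(update_residual.size * density))
  # Threshold is the k-th largest magnitude, mirroring the top_k call above.
  threshold = np.sort(np.abs(update_residual))[-k]
  mask = np.abs(update_residual) >= threshold
  sparse_update = np.where(mask, update_residual, 0.0)  # sent for aggregation
  new_residual = np.where(mask, 0.0, update_residual)   # carried over locally
  return sparse_update, new_residual

grad = np.random.randn(1000)
sparse_update, residual = sparsify(grad)
print(np.count_nonzero(sparse_update))  # roughly 10 of 1000 entries survive
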
Example 53
def _bt_model_fn(
        features,
        labels,
        mode,
        head,
        feature_columns,
        tree_hparams,
        n_batches_per_layer,
        config,
        closed_form_grad_and_hess_fn=None,
        example_id_column_name=None,
        # TODO(youngheek): replace this later using other options.
        train_in_memory=False,
        name='boosted_trees'):
    """Gradient Boosted Trees model_fn.

  Args:
    features: dict of `Tensor`.
    labels: `Tensor` of shape [batch_size, 1] or [batch_size] labels of
      dtype `int32` or `int64` in the range `[0, n_classes)`.
    mode: Defines whether this is training, evaluation or prediction.
      See `ModeKeys`.
    head: A `head_lib._Head` instance.
    feature_columns: Iterable of `feature_column._FeatureColumn` model inputs.
    tree_hparams: TODO. collections.namedtuple for hyper parameters.
    n_batches_per_layer: A `Tensor` of `int64`. Each layer is built after at
      least n_batches_per_layer accumulations.
    config: `RunConfig` object to configure the runtime settings.
    closed_form_grad_and_hess_fn: a function that accepts logits and labels
      and returns gradients and hessians. By default, they are created by
      tf.gradients() from the loss.
    example_id_column_name: Name of the feature for a unique ID per example.
      Currently experimental -- not exposed to public API.
    train_in_memory: `bool`, when true, it assumes the dataset is in memory,
      i.e., input_fn should return the entire dataset as a single batch, and
      also n_batches_per_layer should be set as 1.
    name: Name to use for the model.

  Returns:
      An `EstimatorSpec` instance.

  Raises:
    ValueError: mode or params are invalid, or features has the wrong type.
  """
    is_single_machine = (config.num_worker_replicas <= 1)
    sorted_feature_columns = sorted(feature_columns, key=lambda tc: tc.name)
    center_bias = tree_hparams.center_bias
    if train_in_memory:
        assert n_batches_per_layer == 1, (
            'When train_in_memory is enabled, input_fn should return the entire '
            'dataset as a single batch, and n_batches_per_layer should be set as '
            '1.')
        if (not config.is_chief or config.num_worker_replicas > 1
                or config.num_ps_replicas > 0):
            raise ValueError('train_in_memory is supported only for '
                             'non-distributed training.')
    worker_device = control_flow_ops.no_op().device
    # Maximum number of splits possible in the whole tree = 2^max_depth - 1.
    # TODO(youngheek): perhaps storage could be optimized by storing stats with
    # the dimension max_splits_per_layer, instead of max_splits (for the entire
    # tree).
    max_splits = (1 << tree_hparams.max_depth) - 1
    train_op = []
    with ops.name_scope(name) as name:
        # Prepare.
        global_step = training_util.get_or_create_global_step()
        bucket_size_list, feature_ids_list = _group_features_by_num_buckets(
            sorted_feature_columns)
        # Extract input features and set up cache for training.
        training_state_cache = None
        if mode == model_fn.ModeKeys.TRAIN and train_in_memory:
            # cache transformed features as well for in-memory training.
            batch_size = array_ops.shape(labels)[0]
            input_feature_list, input_cache_op = (_cache_transformed_features(
                features, sorted_feature_columns, batch_size))
            train_op.append(input_cache_op)
            training_state_cache = _CacheTrainingStatesUsingVariables(
                batch_size, head.logits_dimension)
        else:
            input_feature_list = _get_transformed_features(
                features, sorted_feature_columns)
            if mode == model_fn.ModeKeys.TRAIN and example_id_column_name:
                example_ids = features[example_id_column_name]
                training_state_cache = _CacheTrainingStatesUsingHashTable(
                    example_ids, head.logits_dimension)

        # Create Ensemble resources.
        tree_ensemble = boosted_trees_ops.TreeEnsemble(name=name)
        # Variable that determines whether bias centering is needed.
        center_bias_var = variable_scope.variable(initial_value=center_bias,
                                                  name='center_bias_needed',
                                                  trainable=False)
        # Create logits.
        if mode != model_fn.ModeKeys.TRAIN:
            logits = boosted_trees_ops.predict(
                # For non-TRAIN mode, ensemble doesn't change after initialization,
                # so no local copy is needed; using tree_ensemble directly.
                tree_ensemble_handle=tree_ensemble.resource_handle,
                bucketized_features=input_feature_list,
                logits_dimension=head.logits_dimension)
        else:
            if is_single_machine:
                local_tree_ensemble = tree_ensemble
                ensemble_reload = control_flow_ops.no_op()
            else:
                # Have a local copy of ensemble for the distributed setting.
                with ops.device(worker_device):
                    local_tree_ensemble = boosted_trees_ops.TreeEnsemble(
                        name=name + '_local', is_local=True)
                # TODO(soroush): Do partial updates if this becomes a bottleneck.
                ensemble_reload = local_tree_ensemble.deserialize(
                    *tree_ensemble.serialize())

            if training_state_cache:
                cached_tree_ids, cached_node_ids, cached_logits = (
                    training_state_cache.lookup())
            else:
                # Always start from the beginning when no cache is set up.
                batch_size = array_ops.shape(labels)[0]
                cached_tree_ids, cached_node_ids, cached_logits = (
                    array_ops.zeros([batch_size],
                                    dtype=dtypes.int32), _DUMMY_NODE_ID *
                    array_ops.ones([batch_size], dtype=dtypes.int32),
                    array_ops.zeros([batch_size, head.logits_dimension],
                                    dtype=dtypes.float32))

            with ops.control_dependencies([ensemble_reload]):
                (stamp_token, num_trees, num_finalized_trees,
                 num_attempted_layers,
                 last_layer_nodes_range) = local_tree_ensemble.get_states()
                summary.scalar('ensemble/num_trees', num_trees)
                summary.scalar('ensemble/num_finalized_trees',
                               num_finalized_trees)
                summary.scalar('ensemble/num_attempted_layers',
                               num_attempted_layers)

                partial_logits, tree_ids, node_ids = boosted_trees_ops.training_predict(
                    tree_ensemble_handle=local_tree_ensemble.resource_handle,
                    cached_tree_ids=cached_tree_ids,
                    cached_node_ids=cached_node_ids,
                    bucketized_features=input_feature_list,
                    logits_dimension=head.logits_dimension)

            logits = cached_logits + partial_logits

        # Create training graph.
        def _train_op_fn(loss):
            """Run one training iteration."""
            if training_state_cache:
                # Cache logits only after center_bias is complete, if it's in progress.
                train_op.append(
                    control_flow_ops.cond(
                        center_bias_var, control_flow_ops.no_op,
                        lambda: training_state_cache.insert(
                            tree_ids, node_ids, logits)))

            if closed_form_grad_and_hess_fn:
                gradients, hessians = closed_form_grad_and_hess_fn(
                    logits, labels)
            else:
                gradients = gradients_impl.gradients(loss,
                                                     logits,
                                                     name='Gradients')[0]
                hessians = gradients_impl.gradients(gradients,
                                                    logits,
                                                    name='Hessians')[0]

            stats_summaries_list = []
            for i, feature_ids in enumerate(feature_ids_list):
                num_buckets = bucket_size_list[i]
                summaries = [
                    array_ops.squeeze(boosted_trees_ops.make_stats_summary(
                        node_ids=node_ids,
                        gradients=gradients,
                        hessians=hessians,
                        bucketized_features_list=[input_feature_list[f]],
                        max_splits=max_splits,
                        num_buckets=num_buckets),
                                      axis=0) for f in feature_ids
                ]
                stats_summaries_list.append(summaries)

            # ========= Helper methods for both in and not in memory. ==============
            def grow_tree_from_stats_summaries(stats_summaries_list,
                                               feature_ids_list):
                """Updates ensemble based on the best gains from stats summaries."""
                node_ids_per_feature = []
                gains_list = []
                thresholds_list = []
                left_node_contribs_list = []
                right_node_contribs_list = []
                all_feature_ids = []

                assert len(stats_summaries_list) == len(feature_ids_list)

                for i, feature_ids in enumerate(feature_ids_list):
                    (numeric_node_ids_per_feature, numeric_gains_list,
                     numeric_thresholds_list, numeric_left_node_contribs_list,
                     numeric_right_node_contribs_list) = (
                         boosted_trees_ops.calculate_best_gains_per_feature(
                             node_id_range=last_layer_nodes_range,
                             stats_summary_list=stats_summaries_list[i],
                             l1=tree_hparams.l1,
                             l2=tree_hparams.l2,
                             tree_complexity=tree_hparams.tree_complexity,
                             min_node_weight=tree_hparams.min_node_weight,
                             max_splits=max_splits))

                    all_feature_ids += feature_ids
                    node_ids_per_feature += numeric_node_ids_per_feature
                    gains_list += numeric_gains_list
                    thresholds_list += numeric_thresholds_list
                    left_node_contribs_list += numeric_left_node_contribs_list
                    right_node_contribs_list += numeric_right_node_contribs_list

                grow_op = boosted_trees_ops.update_ensemble(
                    # Confirm if local_tree_ensemble or tree_ensemble should be used.
                    tree_ensemble.resource_handle,
                    feature_ids=all_feature_ids,
                    node_ids=node_ids_per_feature,
                    gains=gains_list,
                    thresholds=thresholds_list,
                    left_node_contribs=left_node_contribs_list,
                    right_node_contribs=right_node_contribs_list,
                    learning_rate=tree_hparams.learning_rate,
                    max_depth=tree_hparams.max_depth,
                    pruning_mode=boosted_trees_ops.PruningMode.NO_PRUNING)
                return grow_op

            def _center_bias_fn(mean_gradients, mean_hessians):
                """Updates the ensembles and cache (if needed) with logits prior."""
                continue_centering = boosted_trees_ops.center_bias(
                    tree_ensemble.resource_handle,
                    mean_gradients=mean_gradients,
                    mean_hessians=mean_hessians,
                    l1=tree_hparams.l1,
                    l2=tree_hparams.l2)
                return center_bias_var.assign(continue_centering)

            # ========= End of helper methods. ==============

            if train_in_memory and is_single_machine:
                train_op.append(distribute_lib.increment_var(global_step))

                mean_gradients = array_ops.expand_dims(
                    math_ops.reduce_mean(gradients, 0), 0)
                mean_hessians = array_ops.expand_dims(
                    math_ops.reduce_mean(hessians, 0), 0)

                train_op.append(
                    control_flow_ops.cond(
                        center_bias_var, lambda: _center_bias_fn(
                            mean_gradients, mean_hessians),
                        functools.partial(grow_tree_from_stats_summaries,
                                          stats_summaries_list,
                                          feature_ids_list)))
            else:

                def center_bias_not_in_mem():
                    """Accumulates the data and updates the logits bias, when ready."""
                    bias_dependencies = []

                    bias_accumulator = data_flow_ops.ConditionalAccumulator(
                        dtype=dtypes.float32,
                        # The stats consist of grads and hessians means only.
                        # TODO(nponomareva): this will change for a multiclass
                        shape=[2, 1],
                        shared_name='bias_accumulator')

                    grads_and_hess = array_ops.stack([gradients, hessians],
                                                     axis=0)
                    grads_and_hess = math_ops.reduce_mean(grads_and_hess,
                                                          axis=1)

                    apply_grad = bias_accumulator.apply_grad(
                        grads_and_hess, stamp_token)
                    bias_dependencies.append(apply_grad)

                    def center_bias_from_accumulator():
                        accumulated = array_ops.unstack(
                            bias_accumulator.take_grad(1), axis=0)
                        return _center_bias_fn(
                            array_ops.expand_dims(accumulated[0], 0),
                            array_ops.expand_dims(accumulated[1], 0))

                    with ops.control_dependencies(bias_dependencies):
                        if config.is_chief:
                            center_bias_op = control_flow_ops.cond(
                                math_ops.greater_equal(
                                    bias_accumulator.num_accumulated(),
                                    n_batches_per_layer),
                                center_bias_from_accumulator,
                                control_flow_ops.no_op,
                                name='wait_until_n_batches_for_bias_accumulated'
                            )

                            return center_bias_op
                        else:
                            return control_flow_ops.no_op()

                def grow_not_in_mem():
                    """Accumulates the data and grows a layer when ready."""

                    accumulators = []
                    dependencies = []
                    for i, feature_ids in enumerate(feature_ids_list):
                        stats_summaries = stats_summaries_list[i]
                        accumulator = data_flow_ops.ConditionalAccumulator(
                            dtype=dtypes.float32,
                            # The stats consist of grads and hessians (the last dimension).
                            shape=[
                                len(feature_ids), max_splits,
                                bucket_size_list[i], 2
                            ],
                            shared_name='numeric_stats_summary_accumulator_' +
                            str(i))
                        accumulators.append(accumulator)

                        apply_grad = accumulator.apply_grad(
                            array_ops.stack(stats_summaries, axis=0),
                            stamp_token)
                        dependencies.append(apply_grad)

                    def grow_tree_from_accumulated_summaries_fn():
                        """Updates tree with the best layer from accumulated summaries."""
                        # Take out the accumulated summaries from the accumulator and grow.
                        stats_summaries_list = []

                        stats_summaries_list = [
                            array_ops.unstack(accumulator.take_grad(1), axis=0)
                            for accumulator in accumulators
                        ]

                        grow_op = grow_tree_from_stats_summaries(
                            stats_summaries_list, feature_ids_list)
                        return grow_op

                    with ops.control_dependencies(dependencies):
                        if config.is_chief:
                            min_accumulated = math_ops.reduce_min(
                                array_ops.stack([
                                    acc.num_accumulated()
                                    for acc in accumulators
                                ]))

                            grow_model = control_flow_ops.cond(
                                math_ops.greater_equal(min_accumulated,
                                                       n_batches_per_layer),
                                grow_tree_from_accumulated_summaries_fn,
                                control_flow_ops.no_op,
                                name='wait_until_n_batches_accumulated')

                            return grow_model
                        else:
                            return control_flow_ops.no_op()

                update_model = control_flow_ops.cond(center_bias_var,
                                                     center_bias_not_in_mem,
                                                     grow_not_in_mem)
                train_op.append(update_model)
                with ops.control_dependencies([update_model]):
                    increment_global = distribute_lib.increment_var(
                        global_step)
                    train_op.append(increment_global)

            return control_flow_ops.group(train_op, name='train_op')

    estimator_spec = head.create_estimator_spec(features=features,
                                                mode=mode,
                                                labels=labels,
                                                train_op_fn=_train_op_fn,
                                                logits=logits)
    if mode == model_fn.ModeKeys.TRAIN:
        # Add an early stop hook.
        estimator_spec = estimator_spec._replace(
            training_hooks=estimator_spec.training_hooks +
            (_StopAtAttemptsHook(num_finalized_trees, num_attempted_layers,
                                 tree_hparams.n_trees, tree_hparams.max_depth),
             ))
    return estimator_spec
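
When closed_form_grad_and_hess_fn is not supplied, the model_fn above obtains gradients and (diagonal) hessians by differentiating the loss twice with tf.gradients. A minimal graph-mode sketch of that double-differentiation pattern, using the tf.compat.v1 API and a toy squared-error loss in place of the head's loss:

import tensorflow.compat.v1 as tf
tf.disable_eager_execution()

logits = tf.placeholder(tf.float32, shape=[None, 1])
labels = tf.placeholder(tf.float32, shape=[None, 1])
loss = tf.reduce_sum(0.5 * tf.square(logits - labels))

# First derivative of the loss w.r.t. the logits.
gradients = tf.gradients(loss, logits, name="Gradients")[0]
# Differentiating the (summed) gradients again yields the diagonal of the
# hessian, since each per-example gradient depends only on its own logit.
hessians = tf.gradients(gradients, logits, name="Hessians")[0]

with tf.Session() as sess:
  g, h = sess.run([gradients, hessians],
                  feed_dict={logits: [[2.0]], labels: [[1.0]]})
  print(g, h)  # [[1.]] [[1.]]
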
Example 54
    def _create_switch_ops(self, processed_row_indices, processed_col_indices,
                           train_op):
        """Creates ops to update is_row_sweep_var and to increment global_step.

    Creates two boolean tensors processed_rows and processed_cols, which keep
    track of which rows/cols have been processed during the current sweep.
    Returns ops that should be run after each row / col update.
      - When is_row_sweep_var is True, it sets
        processed_rows[processed_row_indices] to True.
      - When is_row_sweep_var is False, it sets
        processed_cols[processed_col_indices] to True.
    When all rows or all cols have been processed, negates is_row_sweep_var and
    resets processed_rows and processed_cols to False.
    All of the ops created by this function have control_dependencies on
    train_op.

    Args:
      processed_row_indices: A Tensor. The indices of the input rows that are
        processed during the current sweep.
      processed_col_indices: A Tensor. The indices of the input columns that
        are processed during the current sweep.
      train_op: An op. All the ops created by this function have
        control_dependencies on train_op.
    Returns:
      A list consisting of:
        is_sweep_done: A Boolean tensor, determines whether the sweep is done,
          i.e. all rows (during a row sweep) or all columns (during a column
          sweep) have been processed.
        switch_ops: An op that updates is_row_sweep_var when is_sweep_done is
          True. Has control_dependencies on train_op.
        global_step_incr_op: An op that increments the global_step counter. Has
          control_dependencies on switch_ops.
    """
        processed_rows_init = array_ops.fill(dims=[self._num_rows],
                                             value=False)
        with ops.colocate_with(processed_rows_init):
            processed_rows = variable_scope.variable(
                processed_rows_init,
                collections=[ops.GraphKeys.GLOBAL_VARIABLES],
                trainable=False,
                name="sweep_hook_processed_rows")
        processed_cols_init = array_ops.fill(dims=[self._num_cols],
                                             value=False)
        with ops.colocate_with(processed_cols_init):
            processed_cols = variable_scope.variable(
                processed_cols_init,
                collections=[ops.GraphKeys.GLOBAL_VARIABLES],
                trainable=False,
                name="sweep_hook_processed_cols")
        # After running the train_op, update processed_rows or processed_cols
        # tensors, depending on whether we are currently doing a row or a col sweep
        with ops.control_dependencies([train_op]):

            def get_row_update_op():
                with ops.colocate_with(processed_rows):
                    return state_ops.scatter_update(
                        processed_rows, processed_row_indices,
                        array_ops.ones_like(processed_row_indices,
                                            dtype=dtypes.bool))

            def get_col_update_op():
                with ops.colocate_with(processed_cols):
                    return state_ops.scatter_update(
                        processed_cols, processed_col_indices,
                        array_ops.ones_like(processed_col_indices,
                                            dtype=dtypes.bool))

            update_processed_op = control_flow_ops.cond(
                self._is_row_sweep_var, get_row_update_op, get_col_update_op)

            # After update_processed_op, check whether we have completed a sweep.
            # If this is the case, flip the is_row_sweep_var and reset processed_rows
            # and processed_cols tensors.
            with ops.control_dependencies([update_processed_op]):

                def get_switch_op():
                    return state_ops.assign(
                        self._is_row_sweep_var,
                        gen_math_ops.logical_not(self._is_row_sweep_var)).op

                def get_reset_op():
                    return control_flow_ops.group(
                        state_ops.assign(processed_rows,
                                         processed_rows_init).op,
                        state_ops.assign(processed_cols,
                                         processed_cols_init).op)

                is_sweep_done = control_flow_ops.cond(
                    self._is_row_sweep_var,
                    lambda: math_ops.reduce_all(processed_rows),
                    lambda: math_ops.reduce_all(processed_cols),
                    name="sweep_hook_is_sweep_done")
                switch_op = control_flow_ops.cond(is_sweep_done,
                                                  get_switch_op,
                                                  control_flow_ops.no_op,
                                                  name="sweep_hook_switch_op")
                reset_op = control_flow_ops.cond(is_sweep_done,
                                                 get_reset_op,
                                                 control_flow_ops.no_op,
                                                 name="sweep_hook_reset_op")
                switch_ops = control_flow_ops.group(
                    switch_op, reset_op, name="sweep_hook_switch_ops")

                # Op to increment the global step
                global_step = framework_variables.get_global_step()
                with ops.control_dependencies([switch_ops]):
                    if global_step is not None:
                        global_step_incr_op = state_ops.assign_add(
                            global_step, 1, name="global_step_incr").op
                    else:
                        global_step_incr_op = control_flow_ops.no_op(
                            name="global_step_incr")

        return [is_sweep_done, switch_ops, global_step_incr_op]
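
Stripped of the graph plumbing, the bookkeeping these ops implement is: mark the indices touched by this step, and once every row (or column) has been marked, flip the sweep flag and clear both bitmaps. A plain-Python sketch of that logic; the names are illustrative, not the hook's API:

import numpy as np

class SweepState:
  def __init__(self, num_rows, num_cols):
    self.is_row_sweep = True
    self.processed_rows = np.zeros(num_rows, dtype=bool)
    self.processed_cols = np.zeros(num_cols, dtype=bool)

  def after_update(self, indices):
    processed = self.processed_rows if self.is_row_sweep else self.processed_cols
    processed[indices] = True              # the scatter_update above
    is_sweep_done = bool(processed.all())  # the reduce_all above
    if is_sweep_done:                      # switch_op followed by reset_op
      self.is_row_sweep = not self.is_row_sweep
      self.processed_rows[:] = False
      self.processed_cols[:] = False
    return is_sweep_done

state = SweepState(num_rows=3, num_cols=2)
print(state.after_update([0, 1]))  # False: row 2 not processed yet
print(state.after_update([2]))     # True: row sweep done, switch to columns
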
  def testInvalidSynchronizationWithVariable(self, distribution):
    with distribution.scope():
      with self.assertRaisesRegex(
          ValueError, "Invalid variable synchronization mode: Invalid for "
          "variable: v"):
        variable_scope.variable(1.0, name="v", synchronization="Invalid")
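
The test above passes an arbitrary string; the accepted values are the tf.VariableSynchronization enum members (AUTO, NONE, ON_WRITE, ON_READ). A small TF 2.x-style sketch of a valid setting; the strategy setup and variable name are illustrative:

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
  # ON_READ variables are aggregated when read and are expected to be
  # non-trainable.
  v = tf.Variable(
      0.0, name="v", trainable=False,
      synchronization=tf.VariableSynchronization.ON_READ,
      aggregation=tf.VariableAggregation.MEAN)
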
Example 56
def _wals_factorization_model_function(features, labels, mode, params):
    """Model function for the WALSFactorization estimator.

  Args:
    features: Dictionary of features. See WALSMatrixFactorization.
    labels: Must be None.
    mode: A model_fn.ModeKeys object.
    params: Dictionary of parameters containing arguments passed to the
      WALSMatrixFactorization constructor.

  Returns:
    A ModelFnOps object.
  """
    assert labels is None
    use_factors_weights_cache = (
        params["use_factors_weights_cache_for_training"]
        and mode == model_fn.ModeKeys.TRAIN)
    use_gramian_cache = (params["use_gramian_cache_for_training"]
                         and mode == model_fn.ModeKeys.TRAIN)
    model = factorization_ops.WALSModel(
        params["num_rows"],
        params["num_cols"],
        params["embedding_dimension"],
        unobserved_weight=params["unobserved_weight"],
        regularization=params["regularization_coeff"],
        row_init=params["row_init"],
        col_init=params["col_init"],
        num_row_shards=params["num_row_shards"],
        num_col_shards=params["num_col_shards"],
        row_weights=params["row_weights"],
        col_weights=params["col_weights"],
        use_factors_weights_cache=use_factors_weights_cache,
        use_gramian_cache=use_gramian_cache)

    # Get input rows and cols. We either update rows or columns depending on
    # the value of row_sweep, which is maintained using a session hook
    input_rows = features[WALSMatrixFactorization.INPUT_ROWS]
    input_cols = features[WALSMatrixFactorization.INPUT_COLS]
    input_row_indices, _ = array_ops.unique(input_rows.indices[:, 0])
    input_col_indices, _ = array_ops.unique(input_cols.indices[:, 0])

    # Train ops, controlled using the SweepHook
    # We need to run the following ops:
    # Before a row sweep:
    #   row_update_prep_gramian_op
    #   initialize_row_update_op
    # During a row sweep:
    #   update_row_factors_op
    # Before a col sweep:
    #   col_update_prep_gramian_op
    #   initialize_col_update_op
    # During a col sweep:
    #   update_col_factors_op

    is_row_sweep_var = variable_scope.variable(
        True, name="is_row_sweep",
        collections=[ops.GraphKeys.GLOBAL_VARIABLES])
    # The row sweep is determined by is_row_sweep_var (controlled by the
    # sweep_hook) in TRAIN mode, and manually in EVAL mode.
    is_row_sweep = (features[WALSMatrixFactorization.PROJECT_ROW]
                    if mode == model_fn.ModeKeys.EVAL else is_row_sweep_var)

    def update_row_factors():
        return model.update_row_factors(sp_input=input_rows,
                                        transpose_input=False)

    def update_col_factors():
        return model.update_col_factors(sp_input=input_cols,
                                        transpose_input=True)

    _, train_op, loss = control_flow_ops.cond(is_row_sweep, update_row_factors,
                                              update_col_factors)

    row_prep_ops = [
        model.row_update_prep_gramian_op, model.initialize_row_update_op
    ]
    col_prep_ops = [
        model.col_update_prep_gramian_op, model.initialize_col_update_op
    ]
    cache_init_ops = [model.worker_init]

    sweep_hook = _SweepHook(
        is_row_sweep_var,
        train_op,
        params["num_rows"],
        params["num_cols"],
        input_row_indices,
        input_col_indices,
        row_prep_ops,
        col_prep_ops,
        cache_init_ops,
    )

    # Prediction ops (only return predictions in INFER mode)
    predictions = {}
    if mode == model_fn.ModeKeys.INFER:
        project_row = features[WALSMatrixFactorization.PROJECT_ROW]
        projection_weights = features.get(
            WALSMatrixFactorization.PROJECTION_WEIGHTS)

        def get_row_projection():
            return model.project_row_factors(
                sp_input=input_rows,
                projection_weights=projection_weights,
                transpose_input=False)

        def get_col_projection():
            return model.project_col_factors(
                sp_input=input_cols,
                projection_weights=projection_weights,
                transpose_input=True)

        predictions[WALSMatrixFactorization.PROJECTION_RESULT] = (
            control_flow_ops.cond(project_row, get_row_projection,
                                  get_col_projection))

    return model_fn.ModelFnOps(mode=mode,
                               predictions=predictions,
                               loss=loss,
                               eval_metric_ops={},
                               train_op=train_op,
                               training_hooks=[sweep_hook])
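
The sweep structure described in the comments above (prepare the Gramian, then update every row factor; then the same for the columns) is the usual alternating least squares loop. An unweighted, unregularized NumPy sketch of that alternation, for orientation only; the real WALSModel additionally applies observation weights, regularization and factor sharding:

import numpy as np

num_rows, num_cols, dim = 20, 15, 3
rng = np.random.default_rng(0)
ratings = rng.normal(size=(num_rows, num_cols))
row_factors = rng.normal(size=(num_rows, dim))
col_factors = rng.normal(size=(num_cols, dim))

for _ in range(10):
  # Row sweep: hold column factors fixed, solve least squares for each row.
  row_factors = np.linalg.lstsq(col_factors, ratings.T, rcond=None)[0].T
  # Column sweep: hold row factors fixed, solve least squares for each column.
  col_factors = np.linalg.lstsq(row_factors, ratings, rcond=None)[0].T

print(np.linalg.norm(ratings - row_factors @ col_factors.T))  # non-increasing per sweep
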
def model_fn():
  replica_id = self.evaluate(_replica_id())
  v = variable_scope.variable(1.0, name="foo_" + str(replica_id))
  ds_context.get_replica_context().merge_call(lambda _: _)
  return v

def _variable_getter(name, shape, dtype, initializer):
  del shape, dtype  # not used, but there for compatibility
  return variable_scope.variable(
      name=name, initial_value=initializer, trainable=False)