def test_local_variable(self):
   """Checks that `local_variable` registers variables as local variables.

   Two local variables are created from Python integers. Running them before
   initialization must raise `OpError`; after running the initializer their
   values must equal the integers they were created from.
   """
   with self.test_session() as sess:
     # Nothing should be registered before this test creates its variables.
     self.assertEqual([], variables_lib.local_variables())
     value0 = 42
     variables_lib2.local_variable(value0)
     value1 = 43
     variables_lib2.local_variable(value1)
     variables = variables_lib.local_variables()
     self.assertEqual(2, len(variables))
     # Reading the variables before their initializer has run must fail.
     self.assertRaises(errors_impl.OpError, sess.run, variables)
     variables_lib.initialize_variables(variables).run()
     # Compared as sets, so the order of the collection is not asserted.
     self.assertAllEqual({value0, value1}, set(sess.run(variables)))
 def testVars(self):
   """Checks which variables `classification.f1_score` registers.

   After building the op with `num_thresholds=3`, the names in the
   local-variables collection and in `GraphKeys.METRIC_VARIABLES` must both
   equal the three `f1/...` names listed in `expected`.
   """
   classification.f1_score(
       predictions=array_ops.ones((10, 1)),
       labels=array_ops.ones((10, 1)),
       num_thresholds=3)
   expected = {'f1/true_positives:0', 'f1/false_positives:0',
               'f1/false_negatives:0'}
   # `expected` is already a set, so a single comparison per collection is
   # enough.
   self.assertEqual(
       expected, {v.name for v in variables.local_variables()})
   self.assertEqual(
       expected,
       {v.name for v in ops.get_collection(ops.GraphKeys.METRIC_VARIABLES)})
示例#3
0
  def testVariableDevicePlacement(self):
    """Checks the device of the local variable left by `rejection_resample`.

    The resampling pipeline is built under a replica device setter configured
    with one ps task. Afterwards exactly one local variable must exist, and
    its device string must be empty.
    """
    target_dist = [0.9, 0.05, 0.05, 0.0, 0.0]
    classes = np.random.randint(5, size=(20000,))  # Uniformly sampled
    setter = device_setter.replica_device_setter(
        ps_tasks=1, ps_device="/cpu:0")
    with ops.device(setter):
      # Pair every class id with its string form before resampling.
      pipeline = dataset_ops.Dataset.from_tensor_slices(classes)
      pipeline = pipeline.shuffle(200, seed=21)
      pipeline = pipeline.map(lambda c: (c, string_ops.as_string(c)))
      pipeline = dataset_ops.rejection_resample(
          pipeline,
          target_dist=target_dist,
          initial_dist=None,
          class_func=lambda c, _: c,
          seed=27)

      local_vars = variables.local_variables()
      self.assertEqual(1, len(local_vars))
      self.assertEqual(b"", compat.as_bytes(local_vars[0].device))
示例#4
0
 def testNotInLocalVariables(self):
   """Checks that a model variable is global and model-scoped, not local."""
   with self.test_session():
     with variable_scope.variable_scope('A'):
       a = variables_lib2.model_variable('a', [5])
       # assertIn/assertNotIn report the variable and the collection on
       # failure, unlike assertTrue/assertFalse on a bare `in` expression.
       self.assertIn(a, variables_lib.global_variables())
       self.assertIn(a, ops.get_collection(ops.GraphKeys.MODEL_VARIABLES))
       self.assertNotIn(a, variables_lib.local_variables())
示例#5
0
def get_epoch_variable():
  """Returns the epoch variable reshaped to `[1]`, or `[0]` if there is none.

  The epoch variable is the first local variable whose op name contains
  'limit_epochs/epoch'.
  """
  # The variable looked for here is the one defined in
  # //third_party/tensorflow/python/training/input.py::limit_epochs
  epoch_candidates = (
      local_var for local_var in tf_variables.local_variables()
      if 'limit_epochs/epoch' in local_var.op.name)
  epoch_var = next(epoch_candidates, None)
  if epoch_var is None:
    # TODO(thomaswc): Access epoch from the data feeder.
    return [0]
  return array_ops.reshape(epoch_var, [1])
示例#6
0
 def testCreateVariable(self):
   """Checks the name, shape and collections of a plain `variable`.

   A variable created inside scope 'A' must be named 'A/a', have shape [5],
   be in `GLOBAL_VARIABLES`, and be in neither `MODEL_VARIABLES` nor the
   local variables.
   """
   with self.test_session():
     with variable_scope.variable_scope('A'):
       a = variables_lib2.variable('a', [5])
       self.assertEqual(a.op.name, 'A/a')
       self.assertListEqual(a.get_shape().as_list(), [5])
       # assertIn/assertNotIn report the variable and the collection on
       # failure, unlike assertTrue/assertFalse on a bare `in` expression.
       self.assertIn(a, ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES))
       self.assertNotIn(a, ops.get_collection(ops.GraphKeys.MODEL_VARIABLES))
       self.assertNotIn(a, variables_lib.local_variables())
示例#7
0
def _get_variable_for(v):
  """Maps `v` to the ResourceVariable it was gathered from, if any.

  When `v` was produced by a "ResourceGather" op, the global and local
  variables are searched for the `ResourceVariable` whose handle is that op's
  first input. Any other `v` is returned unchanged.

  Raises:
    ValueError: if `v` comes from a "ResourceGather" op but no matching
      variable is found.
  """
  if v.op.type != "ResourceGather":
    return v
  candidates = variables.global_variables() + variables.local_variables()
  for candidate in candidates:
    if not isinstance(candidate, resource_variable_ops.ResourceVariable):
      continue
    # Identity comparison: the handle must be the very tensor the op reads.
    if candidate.handle is v.op.inputs[0]:
      return candidate
  message = ("Got embedding lookup %s but"
             " could not locate source variable." % (str(v)))
  raise ValueError(message)
示例#8
0
def _get_variable_for(v):
    """Resolves `v` to its source ResourceVariable when it is a gather result.

    If the op that produced `v` is not a "ResourceGather", `v` is returned
    as-is. Otherwise the first `ResourceVariable` among the global and local
    variables whose handle is the op's first input is returned.

    Raises:
      ValueError: if no such variable exists for a "ResourceGather" result.
    """
    if v.op.type != "ResourceGather":
        return v
    all_vars = variables.global_variables() + variables.local_variables()
    # Lazily scan for the first variable whose handle is the gather's input.
    source = next(
        (var for var in all_vars
         if isinstance(var, resource_variable_ops.ResourceVariable)
         and var.handle is v.op.inputs[0]),
        None)
    if source is None:
        raise ValueError("Got embedding lookup %s but"
                         " could not locate source variable." % (str(v)))
    return source
  def run(self,
          num_batches=None,
          graph=None,
          session=None,
          start_queues=True,
          initialize_variables=True,
          **kwargs):
    """Builds and runs the columns of the `DataFrame` and yields batches.

    This is a generator that yields a dictionary mapping column names to
    evaluated columns. Iteration stops after `num_batches` batches, or earlier
    if running the columns raises `OutOfRangeError`.

    Args:
      num_batches: the maximum number of batches to produce. If none specified,
        the returned value will iterate through infinite batches.
      graph: the `Graph` in which the `DataFrame` should be built. If none
        specified, the current default graph is used.
      session: the `Session` in which to run the columns of the `DataFrame`.
        If none specified, a new `Session` is created and closed again when
        the generator finishes.
      start_queues: if true, queues will be started before running and halted
        once iteration ends, whether normally, by an error, or because the
        generator was closed early.
      initialize_variables: if true, variables will be initialized.
      **kwargs: Additional keyword arguments e.g. `num_epochs`; forwarded to
        `self.build`.

    Yields:
      A dictionary, mapping column names to the values resulting from running
      each column for a single batch.
    """
    if graph is None:
      graph = ops.get_default_graph()
    with graph.as_default():
      # Only a session created here may be closed here; a caller-supplied
      # session stays under the caller's control.
      owns_session = session is None
      if owns_session:
        session = sess.Session()
      try:
        self_built = self.build(**kwargs)
        keys = list(self_built.keys())
        cols = list(self_built.values())
        if initialize_variables:
          if variables.local_variables():
            session.run(variables.initialize_local_variables())
          if variables.all_variables():
            session.run(variables.initialize_all_variables())
        if start_queues:
          coord = coordinator.Coordinator()
          threads = qr.start_queue_runners(sess=session, coord=coord)
        try:
          i = 0
          while num_batches is None or i < num_batches:
            i += 1
            try:
              values = session.run(cols)
              yield collections.OrderedDict(zip(keys, values))
            except errors.OutOfRangeError:
              # Input is exhausted: end iteration quietly.
              break
        finally:
          # Runs on normal exit, on errors, and when the consumer closes the
          # generator, so queue-runner threads are never left running.
          if start_queues:
            coord.request_stop()
            coord.join(threads)
      finally:
        if owns_session:
          session.close()
示例#10
0
 def testLocalVariableNotInVariablesToRestore(self):
   """Checks that a local variable is not among the variables to restore."""
   with self.test_session():
     with variable_scope.variable_scope('A'):
       a = variables_lib2.local_variable(0)
       # assertIn/assertNotIn report the variable and the collection on
       # failure, unlike assertTrue/assertFalse on a bare `in` expression.
       self.assertNotIn(a, variables_lib2.get_variables_to_restore())
       self.assertIn(a, variables_lib.local_variables())
示例#11
0
 def testLocalVariableNotInAllVariables(self):
   """Checks that a local variable is not among the global variables."""
   with self.test_session():
     with variable_scope.variable_scope('A'):
       a = variables_lib2.local_variable(0)
       # assertIn/assertNotIn report the variable and the collection on
       # failure, unlike assertTrue/assertFalse on a bare `in` expression.
       self.assertNotIn(a, variables_lib.global_variables())
       self.assertIn(a, variables_lib.local_variables())