Пример #1
0
  def test_asset_loading(self):
    """Checks that an asset-backed SavedModel survives repeated re-exports.

    The v1 SavedModel is loaded and then saved/reloaded twice. Before each
    reload the previous export directory is deleted, so every generation can
    only rely on what was written into its own directory. The
    "serving_default" signature is verified after every load.
    """

    def assert_serving_output(loaded_model):
      # Table initializers must run before the signature can be called.
      self.evaluate(lookup_ops.tables_initializer())
      serving = loaded_model.signatures["serving_default"]
      self.assertAllClose(
          {"output": [2, 0]},
          serving(start=constant_op.constant(["gamma", "alpha"])))

    export_dir = self._v1_asset_saved_model()
    # Keep every loaded generation referenced until the test ends, exactly
    # like the straight-line version kept each `imported` name alive.
    generations = [load.load(export_dir)]
    assert_serving_output(generations[-1])

    for _ in range(2):
      previous = generations[-1]
      next_dir = os.path.join(self.get_temp_dir(), "saved_model",
                              str(ops.uid()))
      save.save(previous, next_dir, signatures=previous.signatures)
      shutil.rmtree(export_dir)
      # Empty the TABLE_INITIALIZERS collection before reloading so that
      # tables_initializer() (presumably) only sees the new load's entries.
      del ops.get_collection_ref(ops.GraphKeys.TABLE_INITIALIZERS)[:]
      generations.append(load.load(next_dir))
      export_dir = next_dir
      assert_serving_output(generations[-1])
Пример #2
0
  def _call_func(self, args, kwargs):
    """Invokes the wrapped template function and audits variable creation.

    On the first call (``self._variables_created`` is False) the function runs
    inside ``checkpointable_util.capture_dependencies(template=self)``, and the
    flag is set afterwards. On later calls, growth of the TRAINABLE_VARIABLES
    collection raises ``ValueError``. Growth of GLOBAL_VARIABLES is only
    logged.

    Args:
      args: Positional arguments forwarded to ``self._func``.
      kwargs: Keyword arguments forwarded to ``self._func``.

    Returns:
      Whatever ``self._func`` returns.

    Raises:
      ValueError: If a trainable variable is created on a non-first call.
      Any exception raised inside the ``try`` (including that ValueError) is
      re-raised with the template's original definition stack trace appended
      to its first argument.
    """
    try:
      # Snapshot collection sizes so variables created by this call can be
      # identified afterwards by slicing the collections from these offsets.
      vars_at_start = len(
          ops.get_collection_ref(ops.GraphKeys.GLOBAL_VARIABLES))
      trainable_at_start = len(
          ops.get_collection_ref(ops.GraphKeys.TRAINABLE_VARIABLES))
      if self._variables_created:
        result = self._func(*args, **kwargs)
      else:
        # The first time we run, restore variables if necessary (via
        # Checkpointable).
        with checkpointable_util.capture_dependencies(template=self):
          result = self._func(*args, **kwargs)

      if self._variables_created:
        # Variables were previously created, implying this is not the first
        # time the template has been called. Check to make sure that no new
        # trainable variables were created this time around.
        trainable_variables = ops.get_collection_ref(
            ops.GraphKeys.TRAINABLE_VARIABLES)
        # If a variable that we intend to train is created as a side effect
        # of creating a template, then that is almost certainly an error.
        if trainable_at_start != len(trainable_variables):
          raise ValueError("Trainable variable created when calling a template "
                           "after the first time, perhaps you used tf.Variable "
                           "when you meant tf.get_variable: %s" %
                           (trainable_variables[trainable_at_start:],))

        # Non-trainable tracking variables are a legitimate reason why a new
        # variable would be created, but it is a relatively advanced use-case,
        # so log it.
        variables = ops.get_collection_ref(ops.GraphKeys.GLOBAL_VARIABLES)
        if vars_at_start != len(variables):
          logging.info("New variables created when calling a template after "
                       "the first time, perhaps you used tf.Variable when you "
                       "meant tf.get_variable: %s",
                       variables[vars_at_start:])
      else:
        # First successful call: subsequent calls are audited above.
        self._variables_created = True
      return result
    except Exception as exc:
      # Reraise the exception, but append the original definition to the
      # trace.
      # `args` is rebound here to the exception's args; the call arguments
      # are no longer needed at this point.
      args = exc.args
      if not args:
        arg0 = ""
      else:
        arg0 = args[0]
      trace = "".join(_skip_common_stack_elements(self._stacktrace,
                                                  traceback.format_stack()))
      arg0 = "%s\n\noriginally defined at:\n%s" % (arg0, trace)
      new_args = [arg0]
      new_args.extend(args[1:])
      exc.args = tuple(new_args)
      # Bare `raise` re-raises the same exception object (now carrying the
      # augmented args) with its original traceback.
      raise
Пример #3
0
def _clear_saved_model_collections():
  """Empties the collections a SavedModel export expects to start out empty.

  The SavedModel builder records the ops needed to restore graph state in
  these collections, so they must be empty before MetaGraphs are added to the
  builder. Each collection list is truncated in place, so existing references
  to it observe the change.
  """
  for collection_key in (constants.ASSETS_KEY,
                         constants.LEGACY_INIT_OP_KEY,
                         constants.MAIN_OP_KEY,
                         constants.TRAIN_OP_KEY):
    del ops.get_collection_ref(collection_key)[:]
Пример #4
0
def never_record_summaries():
    """Sets the should_record_summaries Tensor to always false.

    Generator meant to back a context manager (presumably wrapped with
    ``contextlib.contextmanager``; the decorator is not visible here). While
    the ``with`` block runs, the _SHOULD_RECORD_SUMMARIES_NAME collection
    holds ``[False]``. Its previous contents are restored on exit, including
    when the block raises.
    """
    collection_ref = ops.get_collection_ref(_SHOULD_RECORD_SUMMARIES_NAME)
    old = collection_ref[:]
    try:
        collection_ref[:] = [False]
        yield
    finally:
        # Restore on every exit path; otherwise an exception inside the
        # ``with`` block would leave summaries disabled for the whole graph.
        collection_ref[:] = old
Пример #5
0
def record_summaries_every_n_global_steps(n):
    """Sets the should_record_summaries Tensor to true if global_step % n == 0.

    Generator meant to back a context manager (presumably wrapped with
    ``contextlib.contextmanager``; the decorator is not visible here). While
    the ``with`` block runs, the _SHOULD_RECORD_SUMMARIES_NAME collection
    holds a boolean Tensor. Its previous contents are restored on exit,
    including when the block raises.

    Args:
      n: Summaries are recorded on steps where ``global_step % n == 0``.
    """
    collection_ref = ops.get_collection_ref(_SHOULD_RECORD_SUMMARIES_NAME)
    old = collection_ref[:]
    try:
        # Use an explicit elementwise comparison. On TF versions where
        # Tensor.__eq__ is identity-based (TF1 graph mode), `tensor == 0`
        # evaluates to the Python constant False rather than a boolean Tensor.
        collection_ref[:] = [
            math_ops.equal(training_util.get_global_step() % n, 0)]
        yield
    finally:
        collection_ref[:] = old
Пример #6
0
def apply_mask(x, scope=''):
    """Apply mask to a given weight tensor.

    Obtains the mask and threshold variables for ``x`` from pruning_utils and
    multiplies the mask into ``x``. The first time a given mask is seen, the
    threshold, mask, masked weights and ``x`` are registered in their pruning
    collections.

    Args:
      x: Input weight tensor
      scope: The current variable scope. Defaults to "".
    Returns:
      Tensor representing masked_weights
    """
    mask_var = pruning_utils.weight_mask_variable(x, scope)
    threshold_var = pruning_utils.weight_threshold_variable(x, scope)
    # Build the product in the weights name scope so the quantization library
    # can more easily add quant ops to it.
    masked = math_ops.multiply(mask_var, x, _MASKED_WEIGHT_NAME)

    # The same weight (e.g. an RNN weight variable) may be masked more than
    # once; register it only on the first call to avoid duplicate entries.
    if mask_var in ops.get_collection_ref(_MASK_COLLECTION):
        return masked
    for collection_name, member in ((_THRESHOLD_COLLECTION, threshold_var),
                                    (_MASK_COLLECTION, mask_var),
                                    (_MASKED_WEIGHT_COLLECTION, masked),
                                    (_WEIGHT_COLLECTION, x)):
        ops.add_to_collection(collection_name, member)
    return masked
def record_summaries_every_n_global_steps(n):
  """Sets the should_record_summaries Tensor to true if global_step % n == 0.

  Generator meant to back a context manager (presumably wrapped with
  ``contextlib.contextmanager``; the decorator is not visible here). While the
  ``with`` block runs, the _SHOULD_RECORD_SUMMARIES_NAME collection holds a
  boolean Tensor. Its previous contents are restored on exit, including when
  the block raises.

  Args:
    n: Summaries are recorded on steps where ``global_step % n == 0``.
  """
  collection_ref = ops.get_collection_ref(_SHOULD_RECORD_SUMMARIES_NAME)
  old = collection_ref[:]
  try:
    # Use an explicit elementwise comparison. On TF versions where
    # Tensor.__eq__ is identity-based (TF1 graph mode), `tensor == 0`
    # evaluates to the Python constant False rather than a boolean Tensor.
    collection_ref[:] = [
        math_ops.equal(training_util.get_global_step() % n, 0)]
    yield
  finally:
    collection_ref[:] = old
Пример #8
0
  def testFromStringHandle(self):
    """Shapes passed to from_string_handle reach IteratorGetNext properties."""
    expected_shapes = [
        tensor_shape.TensorShape(dims)
        for dims in ([], [3], [1, 2], [1, 2, 3])
    ]

    for expected_shape in expected_shapes:
      with ops.Graph().as_default() as graph:
        base_iterator = iterator_ops.Iterator.from_structure(dtypes.int64)
        handle = base_iterator.string_handle()
        handle_iterator = iterator_ops.Iterator.from_string_handle(
            handle, dtypes.int64, output_shapes=expected_shape)
        next_element = handle_iterator.get_next()
        # Put the fetch in TRAIN_OP (presumably so Grappler keeps it).
        ops.get_collection_ref(ops.GraphKeys.TRAIN_OP).append(next_element)
        grappler_item = item.Item(
            meta_graph.create_meta_graph_def(graph=graph))
        properties = grappler_item.GetOpProperties()
        self.assertEqual(expected_shape,
                         properties['IteratorGetNext'][0].shape)
Пример #9
0
  def testSimpleSwap(self):
    """Check that the swap annotations are followed."""
    var_a = variables.Variable(10, name='a')
    var_b = variables.Variable(20, name='b')
    sum_c = math_ops.add_n([var_a, var_b], name='c')
    sum_d = math_ops.add_n([var_b, sum_c], name='d')
    ops.get_collection_ref(ops.GraphKeys.TRAIN_OP).append(sum_d)

    # Annotate `d` so the MANUAL memory optimizer swaps its input 0.
    sum_d.op._set_attr('_swap_to_host', attr_value_pb2.AttrValue(i=0))

    meta_graph_def = meta_graph.create_meta_graph_def(
        graph=ops.get_default_graph())
    original_node_count = len(meta_graph_def.graph_def.node)

    config = rewriter_config_pb2.RewriterConfig(
        disable_model_pruning=True,
        constant_folding=rewriter_config_pb2.RewriterConfig.OFF,
        memory_optimization=rewriter_config_pb2.RewriterConfig.MANUAL)
    optimized = tf_optimizer.OptimizeGraph(config, meta_graph_def)

    # Exactly one swap-out/swap-in pair should have been added.
    self.assertEqual(len(optimized.node), original_node_count + 2)
    node_names = {node.name for node in optimized.node}
    self.assertTrue(node_names > {
        'a', 'b', 'c', 'd', 'swap_in_d_0', 'swap_out_d_0'})
    expected_inputs = {
        'swap_in_d_0': ('swap_out_d_0', '^b/read'),
        'swap_out_d_0': ('b/read',),
        'd': ('swap_in_d_0', 'c'),
    }
    for node in optimized.node:
      for position, expected in enumerate(expected_inputs.get(node.name, ())):
        self.assertEqual(expected, node.input[position])
Пример #10
0
  def testSimpleSwap(self):
    """Check that the swap annotations are followed."""
    const_a = constant_op.constant(10, name='a')
    const_b = constant_op.constant(20, name='b')
    sum_c = math_ops.add_n([const_a, const_b], name='c')
    sum_d = math_ops.add_n([const_b, sum_c], name='d')
    ops.get_collection_ref(ops.GraphKeys.TRAIN_OP).append(sum_d)

    # NOTE(review): this mutates the proto returned by `op.node_def`. On TF
    # versions where that property returns a copy, the annotation is lost and
    # `op._set_attr(...)` is needed instead -- confirm for the targeted TF.
    sum_d.op.node_def.attr['_swap_to_host'].i = 0

    meta_graph_def = meta_graph.create_meta_graph_def(
        graph=ops.get_default_graph())

    config = rewriter_config_pb2.RewriterConfig(
        memory_optimization=rewriter_config_pb2.RewriterConfig.MANUAL)
    optimized = tf_optimizer.OptimizeGraph(config, meta_graph_def)

    self.assertEqual(len(optimized.node), 6)
    self.assertItemsEqual([node.name for node in optimized.node], [
        'a',
        'b',
        'c',
        'd',
        'swap_in_d_0',
        'swap_out_d_0',
    ])
    expected_inputs = {
        'swap_in_d_0': ('swap_out_d_0', '^b'),
        'swap_out_d_0': ('b',),
        'd': ('swap_in_d_0', 'c'),
    }
    for node in optimized.node:
      for position, expected in enumerate(expected_inputs.get(node.name, ())):
        self.assertEqual(expected, node.input[position])
Пример #11
0
  def testSimpleSwap(self):
    """Check that the swap annotations are followed."""
    var_a = variables.Variable(10, name='a')
    var_b = variables.Variable(20, name='b')
    sum_c = math_ops.add_n([var_a, var_b], name='c')
    sum_d = math_ops.add_n([var_b, sum_c], name='d')
    ops.get_collection_ref(ops.GraphKeys.TRAIN_OP).append(sum_d)

    # NOTE(review): this mutates the proto returned by `op.node_def`. On TF
    # versions where that property returns a copy, the annotation is lost and
    # `op._set_attr(...)` is needed instead -- confirm for the targeted TF.
    sum_d.op.node_def.attr['_swap_to_host'].i = 0

    meta_graph_def = meta_graph.create_meta_graph_def(
        graph=ops.get_default_graph())
    original_node_count = len(meta_graph_def.graph_def.node)

    config = rewriter_config_pb2.RewriterConfig(
        disable_model_pruning=True,
        memory_optimization=rewriter_config_pb2.RewriterConfig.MANUAL)
    optimized = tf_optimizer.OptimizeGraph(config, meta_graph_def)

    # Exactly one swap-out/swap-in pair should have been added.
    self.assertEqual(len(optimized.node), original_node_count + 2)
    node_names = {node.name for node in optimized.node}
    self.assertTrue(node_names > {
        'a', 'b', 'c', 'd', 'swap_in_d_0', 'swap_out_d_0'})
    expected_inputs = {
        'swap_in_d_0': ('swap_out_d_0', '^b/read'),
        'swap_out_d_0': ('b/read',),
        'd': ('swap_in_d_0', 'c'),
    }
    for node in optimized.node:
      for position, expected in enumerate(expected_inputs.get(node.name, ())):
        self.assertEqual(expected, node.input[position])
Пример #12
0
    def testSimpleSwap(self):
        """Check that the swap annotations are followed."""
        const_a = constant_op.constant(10, name='a')
        const_b = constant_op.constant(20, name='b')
        sum_c = math_ops.add_n([const_a, const_b], name='c')
        sum_d = math_ops.add_n([const_b, sum_c], name='d')
        ops.get_collection_ref(ops.GraphKeys.TRAIN_OP).append(sum_d)

        # NOTE(review): this mutates the proto returned by `op.node_def`. On
        # TF versions where that property returns a copy, the annotation is
        # lost and `op._set_attr(...)` is needed instead -- confirm.
        sum_d.op.node_def.attr['_swap_to_host'].i = 0

        meta_graph_def = meta_graph.create_meta_graph_def(
            graph=ops.get_default_graph())

        config = rewriter_config_pb2.RewriterConfig(
            memory_optimization=rewriter_config_pb2.RewriterConfig.MANUAL)
        optimized = tf_optimizer.OptimizeGraph(config, meta_graph_def)

        self.assertEqual(len(optimized.node), 6)
        self.assertItemsEqual([node.name for node in optimized.node], [
            'a',
            'b',
            'c',
            'd',
            'swap_in_d_0',
            'swap_out_d_0',
        ])
        expected_inputs = {
            'swap_in_d_0': ('swap_out_d_0', '^b'),
            'swap_out_d_0': ('b',),
            'd': ('swap_in_d_0', 'c'),
        }
        for node in optimized.node:
            wanted = expected_inputs.get(node.name, ())
            for position, expected in enumerate(wanted):
                self.assertEqual(expected, node.input[position])
Пример #13
0
  def build(self, inputs_shape):
    """Builds the parent LSTM cell, then adds the pruning mask and threshold.

    Adds a non-trainable ``mask`` (ones, shaped
    ``[input_depth + num_units, 4 * num_units]``) and a scalar non-trainable
    ``threshold`` (zero). It then forms ``mask * self._kernel``. The first
    time this mask is seen, mask, masked kernel, threshold and kernel are
    registered in the core_layers pruning collections.
    """
    # The parent build is expected to create self._kernel (used below).
    super(MaskedBasicLSTMCell, self).build(inputs_shape)

    # Mark the cell as not built while the pruning variables are added.
    self.built = False

    num_inputs = inputs_shape[1].value
    num_units = self._num_units
    mask_shape = [num_inputs + num_units, 4 * num_units]
    self._mask = self.add_variable(
        name="mask",
        shape=mask_shape,
        initializer=init_ops.ones_initializer(),
        trainable=False,
        dtype=self.dtype)
    self._threshold = self.add_variable(
        name="threshold",
        shape=[],
        initializer=init_ops.zeros_initializer(),
        trainable=False,
        dtype=self.dtype)
    # Build masked weights in the weights name scope so the quantization
    # library can more easily add quant ops to them.
    self._masked_kernel = math_ops.multiply(self._mask, self._kernel,
                                            core_layers.MASKED_WEIGHT_NAME)
    if self._mask not in ops.get_collection_ref(core_layers.MASK_COLLECTION):
      registrations = (
          (core_layers.MASK_COLLECTION, self._mask),
          (core_layers.MASKED_WEIGHT_COLLECTION, self._masked_kernel),
          (core_layers.THRESHOLD_COLLECTION, self._threshold),
          (core_layers.WEIGHT_COLLECTION, self._kernel),
      )
      for collection_name, value in registrations:
        ops.add_to_collection(collection_name, value)

    self.built = True
Пример #14
0
def never_record_summaries():
  """Sets the should_record_summaries Tensor to always false.

  Generator meant to back a context manager (presumably wrapped with
  ``contextlib.contextmanager``; the decorator is not visible here). While the
  ``with`` block runs, the _SHOULD_RECORD_SUMMARIES_NAME collection holds
  ``[False]``. Its previous contents are restored on exit, including when the
  block raises.
  """
  collection_ref = ops.get_collection_ref(_SHOULD_RECORD_SUMMARIES_NAME)
  old = collection_ref[:]
  try:
    collection_ref[:] = [False]
    yield
  finally:
    # Restore on every exit path; otherwise an exception inside the ``with``
    # block would leave summaries disabled for the whole graph.
    collection_ref[:] = old
Пример #15
0
def apply_mask(x, scope=''):
  """Apply mask to a given weight tensor.

  Obtains the mask and threshold variables for ``x`` and multiplies the mask
  into ``x``. The first time a given mask is seen, the threshold, mask,
  masked weights and ``x`` are registered in their pruning collections.

  Args:
    x: Input weight tensor
    scope: The current variable scope. Defaults to ""
  Returns:
    Tensor representing masked_weights
  """
  mask_var = _weight_mask_variable(x, scope)
  threshold_var = _weight_threshold_variable(x, scope)
  # Build the product in the weights name scope so the quantization library
  # can more easily add quant ops to it.
  masked = math_ops.multiply(mask_var, x, _MASKED_WEIGHT_NAME)

  # Skip registration if this mask is already collected; this matters when
  # masks are applied to RNN weight variables more than once.
  if mask_var in ops.get_collection_ref(_MASK_COLLECTION):
    return masked
  ops.add_to_collection(_THRESHOLD_COLLECTION, threshold_var)
  ops.add_to_collection(_MASK_COLLECTION, mask_var)
  ops.add_to_collection(_MASKED_WEIGHT_COLLECTION, masked)
  ops.add_to_collection(_WEIGHT_COLLECTION, x)
  return masked
Пример #16
0
  def testUpdates(self):
    """tf_item is stable across reads and changes only when placement does."""
    with ops.Graph().as_default() as graph:
      ten = constant_op.constant(10)
      twenty = constant_op.constant(20)
      total = ten + twenty
      ops.get_collection_ref(ops.GraphKeys.TRAIN_OP).append(total)
      grappler_item = item.Item(meta_graph.create_meta_graph_def(graph=graph))

    def place_all_nodes_on_cpu():
      for node in grappler_item.metagraph.graph_def.node:
        node.device = '/cpu:0'

    first_read = grappler_item.tf_item
    second_read = grappler_item.tf_item
    self.assertEqual(first_read, second_read)

    # Changing the placement must produce a different tf_item.
    place_all_nodes_on_cpu()
    after_move = grappler_item.tf_item
    self.assertNotEqual(first_read, after_move)

    # Re-applying the identical placement must not.
    place_all_nodes_on_cpu()
    self.assertEqual(after_move, grappler_item.tf_item)
Пример #17
0
  def testPaddedBatch(self):
    """padded_batch output shapes are propagated to IteratorGetNext."""
    cases = [
        (0, [None]),
        (np.array([1, 2, 3]), [None, 4]),
        (np.array([[1, 2, 3]]), [None, 2, 4]),
    ]

    for value, dims in cases:
      expected_shape = tensor_shape.TensorShape(dims)
      with ops.Graph().as_default() as graph:
        dataset = dataset_ops.Dataset.from_tensors(value).padded_batch(
            42, padded_shapes=expected_shape[1:])
        next_element = dataset.make_one_shot_iterator().get_next()
        ops.get_collection_ref(ops.GraphKeys.TRAIN_OP).append(next_element)
        grappler_item = item.Item(
            meta_graph.create_meta_graph_def(graph=graph))
        properties = grappler_item.GetOpProperties()
        inferred = self.as_tensor_shape(
            properties['IteratorGetNext'][0].shape)
        # The expected batch dimension is None, so only compatibility is
        # checked for it; the padded dimensions must match exactly.
        self.assertTrue(expected_shape.dims[0].is_compatible_with(
            inferred[0]))
        self.assertEqual(expected_shape[1:], inferred[1:])
Пример #18
0
  def testBasicMemory(self):
    """Checks GenerateMemoryReport peak and per-tensor usage on a CPU graph."""
    with test_util.device(use_gpu=False):
      a = constant_op.constant(10, name="a")
      b = constant_op.constant(20, name="b")
      c = math_ops.add_n([a, b], name="c")
      d = math_ops.add_n([b, c], name="d")
      train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
      train_op.append(d)
      mg = meta_graph.create_meta_graph_def(graph=ops.get_default_graph())

    report = cost_analyzer.GenerateMemoryReport(mg)

    # Print the report to make it easier to debug
    print("{}".format(report))

    # Check the report. assertIn (unlike assertTrue(x in report)) shows both
    # the missing line and the full report when it fails.
    self.assertIn(
        "Peak usage for device /job:localhost/replica:0/task:0/device:CPU:0: "
        "16 bytes", report)
    for tensor_name in ("a", "b", "c", "d"):
      self.assertIn("  %s:0 uses 4 bytes" % tensor_name, report)
Пример #19
0
  def testFromGenerator(self):
    """Shapes given to from_generator reach IteratorGetNext properties."""

    def make_generator(value):
      # Bind `value` now so each generator yields its own case's tensor.
      def generator():
        yield value

      return generator

    cases = [
        (0, tensor_shape.TensorShape([])),
        (np.array([1, 2, 3]), tensor_shape.TensorShape([3])),
        (np.array([[1, 2, 3]]), tensor_shape.TensorShape([1, 3])),
    ]

    for value, expected_shape in cases:
      with ops.Graph().as_default() as graph:
        dataset = dataset_ops.Dataset.from_generator(
            make_generator(value), dtypes.int64, output_shapes=expected_shape)
        next_element = dataset.make_one_shot_iterator().get_next()
        ops.get_collection_ref(ops.GraphKeys.TRAIN_OP).append(next_element)
        grappler_item = item.Item(
            meta_graph.create_meta_graph_def(graph=graph))
        properties = grappler_item.GetOpProperties()
        self.assertEqual(expected_shape,
                         properties['IteratorGetNext'][0].shape)
Пример #20
0
  def testInterleave(self):
    """Shapes produced by interleave are propagated to IteratorGetNext."""

    def make_dataset(value):
      # Bind `value` for the per-element dataset factory.
      def dataset_fn(n):
        return dataset_ops.Dataset.from_tensors(value).repeat(n)

      return dataset_fn

    cases = [
        (0, tensor_shape.TensorShape([])),
        (np.array([1, 2, 3]), tensor_shape.TensorShape([3])),
        (np.array([[1, 2, 3]]), tensor_shape.TensorShape([1, 3])),
    ]

    for value, expected_shape in cases:
      with ops.Graph().as_default() as graph:
        dataset = dataset_ops.Dataset.range(42).interleave(
            make_dataset(value), cycle_length=42)
        next_element = dataset.make_one_shot_iterator().get_next()
        ops.get_collection_ref(ops.GraphKeys.TRAIN_OP).append(next_element)
        grappler_item = item.Item(
            meta_graph.create_meta_graph_def(graph=graph))
        properties = grappler_item.GetOpProperties()
        self.assertEqual(expected_shape,
                         properties['IteratorGetNext'][0].shape)
Пример #21
0
  def testMap(self):
    """Shapes produced by a transposing map reach IteratorGetNext."""
    cases = [
        (0, []),
        (np.array([1, 2, 3]), [3]),
        (np.array([[1, 2, 3]]), [3, 1]),
        (np.array([[[1, 2, 3], [4, 5, 6]]]), [3, 2, 1]),
    ]

    for value, dims in cases:
      expected_shape = tensor_shape.TensorShape(dims)
      with ops.Graph().as_default() as graph:
        dataset = dataset_ops.Dataset.from_tensors(value).map(
            array_ops.transpose)
        next_element = dataset.make_one_shot_iterator().get_next()
        ops.get_collection_ref(ops.GraphKeys.TRAIN_OP).append(next_element)
        grappler_item = item.Item(
            meta_graph.create_meta_graph_def(graph=graph))
        properties = grappler_item.GetOpProperties()
        self.assertEqual(expected_shape,
                         properties['IteratorGetNext'][0].shape)
Пример #22
0
  def testPruning(self):
    """Fetching the while loop's TensorList output adds a second Enter node."""
    x = constant_op.constant(1)

    tensor_list = list_ops.empty_tensor_list(
        element_dtype=x.dtype, element_shape=x.shape)

    def loop_cond(counter, tl):
      del tl  # Unused for the condition.
      return counter < 5

    def loop_body(counter, tl):
      return counter + 1, list_ops.tensor_list_push_back(tl, counter)

    outputs = while_loop_v1(loop_cond, loop_body, [x, tensor_list])

    fetches = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
    fetches.append(outputs[0])

    def count_enter_nodes():
      mg = meta_graph.create_meta_graph_def(graph=ops.get_default_graph())
      config = rewriter_config_pb2.RewriterConfig(
          constant_folding=rewriter_config_pb2.RewriterConfig.OFF,
          memory_optimization=rewriter_config_pb2.RewriterConfig.MANUAL)
      optimized = tf_optimizer.OptimizeGraph(config, mg)
      return len([node for node in optimized.node if node.op == "Enter"])

    # Only the counter is fetched (presumably the list input is pruned).
    self.assertEqual(count_enter_nodes(), 1)

    fetches.append(
        list_ops.tensor_list_stack(outputs[1], element_dtype=x.dtype))
    self.assertEqual(count_enter_nodes(), 2)
Пример #23
0
  def testVirtualCluster(self):
    """Measures and estimates cost on a cluster built from a GPU description."""
    with ops.Graph().as_default() as graph:
      lhs = random_ops.random_uniform(shape=())
      rhs = random_ops.random_uniform(shape=())
      total = lhs + rhs
      ops.get_collection_ref(ops.GraphKeys.TRAIN_OP).append(total)
      grappler_item = item.Item(meta_graph.create_meta_graph_def(graph=graph))

      gpu_properties = device_properties_pb2.DeviceProperties(
          type='GPU',
          frequency=1000,
          num_cores=60,
          environment={'architecture': '7'})
      gpu = device_properties_pb2.NamedDevice(
          properties=gpu_properties, name='/GPU:0')
      virtual_cluster = cluster.Cluster(devices=[gpu])

      op_perfs, run_time, _ = virtual_cluster.MeasureCosts(grappler_item)
      self.assertGreater(run_time, 0)
      self.assertEqual(len(op_perfs), 15)

      self.assertEqual(7680.0, virtual_cluster.EstimatePerformance(gpu))
Пример #24
0
 def capture(self):
     """Collects the summary ops created inside the ``with`` block.

     Generator meant to back a context manager (decorator not visible here).
     It temporarily empties the _SUMMARY_COLLECTION graph collection and runs
     the block. Whatever the block added is stored in ``self.collected_ops``
     and the original contents are put back. The original contents are
     restored even if the block raises; ``collected_ops`` is then left
     untouched.
     """
     collection_ref = ops.get_collection_ref(
         ops.GraphKeys._SUMMARY_COLLECTION)  # pylint: disable=protected-access
     original_ops = collection_ref[:]
     collection_ref[:] = []
     try:
         yield
         self.collected_ops = collection_ref[:]
     finally:
         # Without this, an exception in the block would discard every
         # summary op that existed before the capture started.
         collection_ref[:] = original_ops
Пример #25
0
def record_summaries_every_n_global_steps(n):
  """Sets the should_record_summaries Tensor to true if global_step % n == 0.

  Generator meant to back a context manager (presumably wrapped with
  ``contextlib.contextmanager``; the decorator is not visible here). The
  boolean Tensor is built on "cpu:0" and stored in the
  _SHOULD_RECORD_SUMMARIES_NAME collection while the ``with`` block runs. The
  previous contents are restored on exit, including when the block raises.

  Args:
    n: Summaries are recorded on steps where ``global_step % n == 0``.
  """
  collection_ref = ops.get_collection_ref(_SHOULD_RECORD_SUMMARIES_NAME)
  old = collection_ref[:]
  try:
    with ops.device("cpu:0"):
      collection_ref[:] = [
          math_ops.equal(training_util.get_global_step() % n, 0)]
    yield
  finally:
    collection_ref[:] = old
Пример #26
0
def _lift_unlifted_variables(graph, variable_holder):
    """Finds resource variables and lifts them into the outer context.

    When we import a GraphDef inside a wrap_function, no Python graph building
    code runs. This means we get VarHandleOps which create variable resources,
    but no corresponding Python objects. Leaving them like this works but gives
    the user no way to interact with or modify the variables outside the graph.

    This method searches for variables and lifts them out as regular variable
    objects when possible, indicating to the FuncGraph that they are captures.

    Only graph-mode ``ResourceVariable`` entries of the GLOBAL_VARIABLES and
    LOCAL_VARIABLES collections are lifted, and only if their handle is not
    already in ``graph.internal_captures``. Afterwards both collections are
    rewritten in place so that every lifted entry is replaced by its new
    variable object.

    Args:
      graph: The FuncGraph to lift variables from.
      variable_holder: A VariableHolder to record the lifted variables in.
    """
    with graph.as_default():
        # Snapshot of both collections to iterate over; (presumably) a copy,
        # unlike the get_collection_ref lists rewritten at the end.
        collection_variables = (
            ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) +
            ops.get_collection(ops.GraphKeys.LOCAL_VARIABLES))
        # Handles already captured by the graph; variables whose handle is in
        # here are skipped, which also covers ones lifted earlier in the loop.
        existing_captures = set(graph.internal_captures)
        # Maps each original variable to its lifted replacement.
        lifted_variables = {}
        for old_variable in collection_variables:
            if (old_variable._in_graph_mode  # pylint: disable=protected-access
                    and isinstance(old_variable,
                                   resource_variable_ops.ResourceVariable)):
                if old_variable.handle in existing_captures:
                    continue
                # New variable object mirroring the old one's shape, dtype,
                # name and trainability; it reuses the old initializer op.
                new_variable = resource_variable_ops.UninitializedVariable(
                    shape=old_variable.shape,
                    dtype=old_variable.dtype,
                    name=old_variable.op.name,
                    trainable=old_variable.trainable,
                    extra_handle_data=old_variable.handle)
                new_variable._initializer_op = old_variable._initializer_op  # pylint: disable=protected-access
                # Register the old in-graph handle as a graph input that is
                # captured from the new variable's handle.
                graph.inputs.append(old_variable.handle)
                graph.captures[new_variable.handle] = old_variable.handle
                # Now that we've added the new variable to graph.captures,
                # graph.capture will use that cached value and do some post-processing
                # on the capture like recording it on the tape.
                graph.capture(new_variable.handle)
                existing_captures.add(old_variable.handle)
                lifted_variables[old_variable] = new_variable
                # pylint: disable=protected-access
                # Key by the variable name without its ":<output index>" suffix.
                variable_name = new_variable.name.split(":")[0]
                variable_holder._variables_by_name[
                    variable_name] = new_variable
                graph._weak_variables.append(weakref.ref(new_variable))
                # pylint: enable=protected-access
                graph.watch_variable(new_variable)
        # Update the graph's collections, partly for the user and partly so this
        # function is idempotent when it runs again in prune() calls.
        for collection_name in [
                ops.GraphKeys.GLOBAL_VARIABLES, ops.GraphKeys.LOCAL_VARIABLES
        ]:
            mutable_collection = ops.get_collection_ref(collection_name)
            for index, current in enumerate(mutable_collection):
                mutable_collection[index] = lifted_variables.get(
                    current, current)
def _VRClassRewardGrad(op, *grad):
    """Gradient function for a VRClassReward-style op.

    Presumably registered as the op's gradient; the registration is not
    visible here. As a side effect, the REWARD collection is emptied and
    then holds only ``op.outputs[2]``.

    Args:
      op: The forward op; its ``outputs[2]`` is published via REWARD.
      *grad: Incoming gradients (presumably one per output of ``op``).

    Returns:
      A 3-tuple of input gradients: zeros shaped like input 0, ``grad[1]``
      passed through for input 1, and ``None`` (no gradient) for input 2.
    """
    # Clear the collection in place (one slice delete instead of popping
    # element by element) so existing references to the list see the change.
    del ops.get_collection_ref(REWARD)[:]

    ops.add_to_collection(REWARD, op.outputs[2])

    # learn the baseline reward
    return array_ops.zeros_like(op.inputs[0]), grad[1], None
Пример #28
0
def record_summaries_if(bool_value):
    """Sets the should_record_summaries Tensor to ``bool_value`` temporarily.

    Generator meant to back a context manager (decorator not visible here).
    The _SHOULD_RECORD_SUMMARIES_NAME collection is overwritten with
    ``[bool_value]`` while the ``with`` block runs. Its earlier contents are
    put back afterwards, including when the block raises.

    Args:
      bool_value: Value stored as the sole entry of the collection.
    """
    flag_collection = ops.get_collection_ref(_SHOULD_RECORD_SUMMARIES_NAME)
    saved_contents = list(flag_collection)
    try:
        flag_collection[:] = [bool_value]
        yield
    finally:
        flag_collection[:] = saved_contents
Пример #29
0
def always_record_summaries():
  """Sets the should_record_summaries Tensor to always true.

  Generator meant to back a context manager (decorator not visible here).
  The _SHOULD_RECORD_SUMMARIES_NAME collection holds ``[True]`` while the
  ``with`` block runs. Its earlier contents are put back afterwards, including
  when the block raises.
  """
  flag_collection = ops.get_collection_ref(_SHOULD_RECORD_SUMMARIES_NAME)
  saved_contents = list(flag_collection)
  try:
    flag_collection[:] = [True]
    yield
  finally:
    flag_collection[:] = saved_contents
Пример #30
0
def always_record_summaries():
  """Sets the should_record_summaries Tensor to always true.

  Generator meant to back a context manager (decorator not visible here).
  The _SHOULD_RECORD_SUMMARIES_NAME collection holds ``[True]`` while the
  ``with`` block runs. Its earlier contents are put back afterwards, including
  when the block raises.
  """
  flag_collection = ops.get_collection_ref(_SHOULD_RECORD_SUMMARIES_NAME)
  saved_contents = list(flag_collection)
  try:
    flag_collection[:] = [True]
    yield
  finally:
    flag_collection[:] = saved_contents
    def _apply_averages():  # pylint: disable=missing-docstring
      """Builds the ops that average model variables across replicas.

      For every (grad, var) pair with a non-None gradient, the local variable
      is pushed into a ConditionalAccumulator. The accumulator is placed on
      the device of the positionally matching entry of the "global_model"
      collection.

      The chief then assigns the aggregated values to the global variables,
      increments ``self._global_step`` and enqueues
      ``self._tokens_per_step`` tokens on a shared FIFO sync queue. Every
      other worker dequeues one token, blocking until the chief has enqueued
      them. Finally, all workers run ``self._assign_vars(local_vars,
      global_vars)``, presumably copying the averaged global values back into
      the local variables.

      Uses ``grads_and_vars`` and ``global_step`` from the enclosing scope,
      which is not visible here.

      Returns:
        The op returned by ``self._assign_vars(local_vars, global_vars)``.
      """
      # Collect local and global vars
      local_vars = [v for g, v in grads_and_vars if g is not None]
      # NOTE(review): local_vars and global_vars are paired purely by position
      # in the zips below; this assumes "global_model" is ordered like the
      # non-None-gradient variables -- confirm at the caller.
      global_vars = ops.get_collection_ref("global_model")
      # sync queue, place it in the ps
      with ops.colocate_with(self._global_step):
        sync_queue = data_flow_ops.FIFOQueue(
            -1, [dtypes.bool], shapes=[[]], shared_name="sync_queue")
      train_ops = []
      aggregated_vars = []
      with ops.name_scope(None, self._name + "/global"):
        for var, gvar in zip(local_vars, global_vars):
          # pylint: disable=protected-access
          # Get reference to the tensor, this works with Variable and ResourceVariable
          var = ops.convert_to_tensor(var)
          # Place the accumulator in the same ps as the corresponding global_var
          with ops.device(gvar.device):
            var_accum = data_flow_ops.ConditionalAccumulator(
                var.dtype,
                shape=var.get_shape(),
                shared_name=gvar.name + "/var_accum")
            # Add op to push local_var to accumulator
            train_ops.append(
                var_accum.apply_grad(var, local_step=global_step))
            # Op to average the vars in the accumulator
            aggregated_vars.append(var_accum.take_grad(self._replicas_to_aggregate))
            # Remember accumulator and corresponding device
            self._accumulator_list.append((var_accum, gvar.device))
      # chief worker updates global vars and enqueues tokens to the sync queue
      if self._is_chief:
        update_ops = []
        # Make sure train_ops are run
        with ops.control_dependencies(train_ops):
          # Update global_vars with average values
          for avg_var, gvar in zip(aggregated_vars, global_vars):
            with ops.device(gvar.device):
              update_ops.append(state_ops.assign(gvar, avg_var))
          # Update shared global_step
          # NOTE(review): placement uses the enclosing-scope `global_step`,
          # but the increment targets `self._global_step`; confirm these
          # refer to the same variable.
          with ops.device(global_step.device):
            update_ops.append(state_ops.assign_add(self._global_step, 1))
        # After averaging, push tokens to the queue
        with ops.control_dependencies(update_ops), ops.device(
            global_step.device):
          tokens = array_ops.fill([self._tokens_per_step],
                                  constant_op.constant(False))
          sync_op = sync_queue.enqueue_many(tokens)
      # non chief workers deque a token, they will block here until chief is done
      else:
        # Make sure train_ops are run
        with ops.control_dependencies(train_ops), ops.device(
            global_step.device):
          sync_op = sync_queue.dequeue()

      # All workers pull averaged values
      with ops.control_dependencies([sync_op]):
        local_update_op = self._assign_vars(local_vars, global_vars)
      return local_update_op
Пример #32
0
def _add_elements_to_collection(elements, collections):
  """Add every element to each named graph collection, skipping known entries.

  Membership is tested against a snapshot of each collection taken before any
  element is appended to it, so entries already present are skipped, but a
  value repeated inside `elements` itself is appended once per repetition.

  Args:
    elements: one element or a sequence of elements (passed through
      `_to_list`).
    collections: one collection name or a sequence of names (passed through
      `_to_list`).
  """
  elements = _to_list(elements)
  for collection_name in _to_list(collections):
    target = ops.get_collection_ref(collection_name)
    snapshot = set(target)
    for element in elements:
      if element in snapshot:
        continue
      target.append(element)
Пример #33
0
def _add_elements_to_collection(elements, collection_list):
    """Append `elements` to every graph collection named in `collection_list`.

    For each collection, an element is skipped if it was already a member
    when processing of that collection began (the membership snapshot is not
    refreshed while appending).

    Args:
        elements: one element or a sequence of elements (passed through
            `_to_list`).
        collection_list: one collection name or a sequence of names (passed
            through `_to_list`).
    """
    elements = _to_list(elements)
    collection_list = _to_list(collection_list)
    for name in collection_list:
        collection = ops.get_collection_ref(name)
        existing = set(collection)
        collection.extend([e for e in elements if e not in existing])
Пример #34
0
def record_summaries_every_n_global_steps(n, global_step=None):
    """Sets the should_record_summaries Tensor to true if global_step % n == 0.

    This is a generator body; the context-manager decorator is not visible in
    this snippet (presumably a `contextmanager`-style wrapper -- confirm).
    While the wrapped block runs, the `_SHOULD_RECORD_SUMMARIES_NAME`
    collection holds a single boolean tensor `global_step % n == 0` (created
    on "cpu:0"). The previous contents of the collection are restored on exit,
    including when the wrapped block raises.

    Args:
        n: record summaries on steps where `global_step % n == 0`.
        global_step: step tensor/variable to test; defaults to
            `training_util.get_or_create_global_step()`.
    """
    if global_step is None:
        global_step = training_util.get_or_create_global_step()
    collection_ref = ops.get_collection_ref(_SHOULD_RECORD_SUMMARIES_NAME)
    old = collection_ref[:]
    try:
        with ops.device("cpu:0"):
            collection_ref[:] = [math_ops.equal(global_step % n, 0)]
        yield
    finally:
        # Without the `finally`, an exception inside the `with` body would
        # leave the every-n-steps condition installed in the graph collection.
        collection_ref[:] = old
Пример #35
0
    def testPruningNested(self):
        """Grappler prunes an unused TensorList threaded through nested while loops.

        An outer while loop wraps an inner while loop whose body pushes onto a
        TensorList. While only the loop counter is fetched, the optimized graph
        must contain no variant-typed Enter, no TensorListPushBack and no _While
        node. Once the stacked list is also fetched, the variant Enter and the
        TensorListPushBack must survive.
        """
        # Requires the "output all intermediates" override to be enabled;
        # presumably set by a decorator/setUp outside this snippet -- confirm.
        assert control_flow_util_v2._EXPERIMENTAL_OUTPUT_ALL_INTERMEDIATES_OVERRIDE
        x = constant_op.constant(0)

        tensor_list = list_ops.empty_tensor_list(element_dtype=x.dtype,
                                                 element_shape=x.shape)

        def Cond(x, tl):
            del tl  # Unused for Cond.
            return x < 25

        def Body(x, tl):
            def InnerCond(inner_x, unused_outer_x, unused_tl):
                return inner_x < 5

            def InnerBody(inner_x, outer_x, tl):
                # Note: pushes `x` captured from Body's scope, not `outer_x`.
                return inner_x + 1, outer_x + 1, list_ops.tensor_list_push_back(
                    tl, x)

            inner_x = constant_op.constant(0)
            # Drop the inner counter so the result matches the outer loop
            # variables (x, tl).
            return control_flow_ops.while_loop(InnerCond, InnerBody,
                                               [inner_x, x, tl])[1:]

        outputs = control_flow_ops.while_loop(Cond, Body, [x, tensor_list])

        # Fetch only the loop counter at first, leaving the TensorList unused.
        train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
        train_op.append(outputs[0])

        # GetOptimizedGraph is defined outside this snippet; presumably it
        # returns the Grappler-optimized GraphDef of the default graph.
        g = GetOptimizedGraph()
        # TODO(b/136034023): while_v2 adds an extra loop_counter which is not pruned
        # away, causing an extra Enter node.
        # enter_count = 4 if control_flow_util.ENABLE_CONTROL_FLOW_V2 else 2
        # self.assertLen([n for n in g.node if n.op == "Enter"], enter_count)
        # Test that the TensorList is pruned out.
        self.assertEmpty([
            n for n in g.node if n.op == "Enter"
            and n.attr["T"].type == dtypes.variant.as_datatype_enum
        ])
        self.assertEmpty([n for n in g.node if n.op == "TensorListPushBack"])
        self.assertEmpty([n for n in g.node if n.op == "_While"])

        # Now also fetch the stacked list, so the TensorList becomes live.
        stack = list_ops.tensor_list_stack(outputs[1], element_dtype=x.dtype)
        train_op.append(stack)
        g = GetOptimizedGraph()
        # TODO(b/136034023): while_v2 adds an extra loop_counter which is not pruned
        # away, causing an extra Enter node.
        # enter_count = 3 if control_flow_util.ENABLE_CONTROL_FLOW_V2 else 2
        # self.assertLen([n for n in g.node if n.op == "Enter"], enter_count)
        # Test that the TensorList is not pruned out.
        self.assertNotEmpty([
            n for n in g.node if n.op == "Enter"
            and n.attr["T"].type == dtypes.variant.as_datatype_enum
        ])
        self.assertNotEmpty(
            [n for n in g.node if n.op == "TensorListPushBack"])
Пример #36
0
 def testImportantOps(self):
     """IdentifyImportantOps reports exactly the two constants and their sum."""
     with ops.Graph().as_default() as graph:
         lhs = constant_op.constant(10)
         rhs = constant_op.constant(20)
         total = lhs + rhs
         ops.get_collection_ref(ops.GraphKeys.TRAIN_OP).append(total)
         grappler_item = item.Item(meta_graph.create_meta_graph_def(graph=graph))
         self.assertItemsEqual(['Const', 'Const_1', 'add'],
                               grappler_item.IdentifyImportantOps())
Пример #37
0
 def layer_with_recompute(inputs, is_recomputing=False):
   """Dense(2) + batch-norm layer that logs whether it is being recomputed.

   Appends `is_recomputing` to the enclosing `kwarg_values` list. On the
   recompute pass, the two most recently added UPDATE_OPS entries are popped
   so the updates are not duplicated.
   """
   kwarg_values.append(is_recomputing)
   hidden = core_layers.dense(inputs, 2)
   normalized = normalization_layers.batch_normalization(hidden, training=True)
   if not is_recomputing:
     return normalized
   # Recomputing re-registers the update ops; drop the latest two additions.
   pending_updates = ops.get_collection_ref(ops.GraphKeys.UPDATE_OPS)
   for _ in range(2):
     pending_updates.pop()
   return normalized
Пример #38
0
 def layer_with_recompute(inputs, is_recomputing=False):
     """Apply dense(2) then training-mode batch norm, tracking recompute calls.

     Every call records `is_recomputing` in the enclosing `kwarg_values`.
     When recomputing, the last two UPDATE_OPS entries are removed so each
     update is registered only once.
     """
     kwarg_values.append(is_recomputing)
     dense_out = core_layers.dense(inputs, 2)
     bn_out = normalization_layers.batch_normalization(dense_out, training=True)
     if is_recomputing:
         collection = ops.get_collection_ref(ops.GraphKeys.UPDATE_OPS)
         # Pop the two newest additions: they duplicate the first pass.
         collection.pop()
         collection.pop()
     return bn_out
Пример #39
0
 def testImportantOps(self):
   """The important ops of a graph computing 10 + 20 are listed in order."""
   graph = ops.Graph()
   with graph.as_default():
     first = constant_op.constant(10)
     second = constant_op.constant(20)
     result = first + second
     fetches = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
     fetches.append(result)
     metagraph_def = meta_graph.create_meta_graph_def(graph=graph)
     important = item.Item(metagraph_def).IdentifyImportantOps()
     self.assertEqual([b'Const', b'Const_1', b'add'], important)
Пример #40
0
def record_summaries_every_n_global_steps(n, global_step=None):
  """Sets the should_record_summaries Tensor to true if global_step % n == 0.

  Generator body (its context-manager decorator is not visible here). For the
  duration of the wrapped block, the `_SHOULD_RECORD_SUMMARIES_NAME`
  collection holds one boolean tensor built on "cpu:0". The prior contents
  are put back afterwards, even if the block raises.

  Args:
    n: summaries are recorded on steps divisible by `n`.
    global_step: step to test; defaults to
      `training_util.get_or_create_global_step()`.
  """
  step = (global_step if global_step is not None
          else training_util.get_or_create_global_step())
  should_record = ops.get_collection_ref(_SHOULD_RECORD_SUMMARIES_NAME)
  saved = list(should_record)
  try:
    with ops.device("cpu:0"):
      should_record[:] = [math_ops.equal(step % n, 0)]
    yield
  finally:
    should_record[:] = saved
Пример #41
0
 def testRange(self):
     """IteratorGetNext over Dataset.range(42) is inferred to yield a scalar."""
     graph = ops.Graph()
     with graph.as_default():
         next_element = (
             dataset_ops.Dataset.range(42).make_one_shot_iterator().get_next())
         ops.get_collection_ref(ops.GraphKeys.TRAIN_OP).append(next_element)
         properties = item.Item(
             meta_graph.create_meta_graph_def(graph=graph)).GetOpProperties()
         self.assertEqual(tensor_shape.scalar(),
                          properties['IteratorGetNext'][0].shape)
Пример #42
0
 def testRange(self):
   """The first output of IteratorGetNext on a range dataset is scalar-shaped."""
   with ops.Graph().as_default() as graph:
     range_dataset = dataset_ops.Dataset.range(42)
     one_shot = range_dataset.make_one_shot_iterator()
     element = one_shot.get_next()
     fetches = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
     fetches.append(element)
     metagraph_def = meta_graph.create_meta_graph_def(graph=graph)
     props = item.Item(metagraph_def).GetOpProperties()
     self.assertEqual(tensor_shape.scalar(), props['IteratorGetNext'][0].shape)
Пример #43
0
    def _lift_unlifted_variables(self):
        """Finds resource variables and lifts them into the outer context.

        When we import a GraphDef inside a wrap_function, no Python graph building
        code runs. This means we get VarHandleOps which create variable resources,
        but no corresponding Python objects. Leaving them like this works but gives
        the user no way to interact with or modify the variables outside the graph.

        This method searches for variables and lifts them out as regular variable
        objects when possible, indicating to the FuncGraph that they are captures.
        """
        with self.graph.as_default():
            # Candidates: everything in the graph's global and local variable
            # collections.
            collection_variables = (
                ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) +
                ops.get_collection(ops.GraphKeys.LOCAL_VARIABLES))
            # Handles already captured are skipped below, which keeps repeated
            # calls from lifting the same variable twice.
            existing_captures = set(self.graph.internal_captures)
            # Maps each original (in-graph) variable to its lifted replacement.
            lifted_variables = {}
            for old_variable in collection_variables:
                # Only graph-mode ResourceVariables are lifted.
                if (old_variable._in_graph_mode  # pylint: disable=protected-access
                        and
                        isinstance(old_variable,
                                   resource_variable_ops.ResourceVariable)):
                    if old_variable.handle in existing_captures:
                        continue
                    # The initial value is a placeholder named "unused_..." --
                    # presumably never fed, since the resource already exists
                    # in the imported graph (confirm).
                    new_variable = def_function.UnliftedInitializerVariable(
                        array_ops.placeholder(
                            name="unused_{}_initializer".format(
                                old_variable.op.name),
                            shape=old_variable.shape,
                            dtype=old_variable.dtype),
                        name=old_variable.op.name,
                        trainable=old_variable.trainable)
                    # Record the new handle as capturing the old in-graph handle.
                    self.graph.captures[
                        new_variable.handle] = old_variable.handle
                    existing_captures.add(old_variable.handle)
                    lifted_variables[old_variable] = new_variable
                    # pylint: disable=protected-access
                    # Register under the name without the ":<output>" suffix.
                    variable_name = new_variable.name.split(":")[0]
                    self._variable_holder._variables_by_name[
                        variable_name] = new_variable
                    self.graph._weak_variables.append(
                        weakref.ref(new_variable))
                    # pylint: enable=protected-access
            # Update the graph's collections, partly for the user and partly so this
            # function is idempotent when it runs again in prune() calls.
            for collection_name in [
                    ops.GraphKeys.GLOBAL_VARIABLES,
                    ops.GraphKeys.LOCAL_VARIABLES
            ]:
                mutable_collection = ops.get_collection_ref(collection_name)
                # Replace lifted entries in place; others are kept as-is.
                for index, current in enumerate(mutable_collection):
                    mutable_collection[index] = lifted_variables.get(
                        current, current)
Пример #44
0
    def _testPruning(self):
        """Grappler prunes a while-loop TensorList only while it is unused.

        A single while loop increments a counter and pushes onto a TensorList.
        When only the counter is fetched, no variant-typed Enter node may
        remain. After the stacked list is fetched as well, a variant-typed
        Enter node must be present. Enter counts differ between control flow
        v1 and v2 (see the TODOs).
        """
        x = constant_op.constant(1)

        tensor_list = list_ops.empty_tensor_list(element_dtype=x.dtype,
                                                 element_shape=x.shape)

        def Cond(x, tl):
            del tl  # Unused for Cond.
            return x < 5

        def Body(x, tl):
            return x + 1, list_ops.tensor_list_push_back(tl, x)

        outputs = control_flow_ops.while_loop(Cond, Body, [x, tensor_list])

        # Fetch only the counter first, leaving the TensorList unused.
        train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
        train_op.append(outputs[0])

        def GetOptimizedGraph():
            """Runs Grappler on the default graph and returns the result."""
            mg = meta_graph.create_meta_graph_def(
                graph=ops.get_default_graph())
            config = config_pb2.ConfigProto()
            # Constant folding off, memory optimization set to MANUAL.
            config.graph_options.rewrite_options.CopyFrom(
                rewriter_config_pb2.RewriterConfig(
                    constant_folding=rewriter_config_pb2.RewriterConfig.OFF,
                    memory_optimization=rewriter_config_pb2.RewriterConfig.
                    MANUAL))
            return tf_optimizer.OptimizeGraph(config, mg)

        g = GetOptimizedGraph()
        # TODO(b/136034023): while_v2 adds an extra loop_counter which is not pruned
        # away, causing an extra Enter node.
        enter_count = 2 if control_flow_util.ENABLE_CONTROL_FLOW_V2 else 1
        self.assertLen([n for n in g.node if n.op == "Enter"], enter_count)
        # Test that the TensorList is pruned out.
        self.assertEmpty([
            n for n in g.node if n.op == "Enter"
            and n.attr["T"].type == dtypes.variant.as_datatype_enum
        ])

        # Also fetch the stacked list so the TensorList becomes live.
        stack = list_ops.tensor_list_stack(outputs[1], element_dtype=x.dtype)
        train_op.append(stack)
        g = GetOptimizedGraph()
        # TODO(b/136034023): while_v2 adds an extra loop_counter which is not pruned
        # away, causing an extra Enter node.
        enter_count = 3 if control_flow_util.ENABLE_CONTROL_FLOW_V2 else 2
        self.assertLen([n for n in g.node if n.op == "Enter"], enter_count)
        # Test that the TensorList is not pruned out.
        self.assertNotEmpty([
            n for n in g.node if n.op == "Enter"
            and n.attr["T"].type == dtypes.variant.as_datatype_enum
        ])
Пример #45
0
def _add_elements_to_collection(elements, collection_list):
  """Adds `elements` to each graph collection in `collection_list`.

  Both arguments are flattened with `nest.flatten`. An element is skipped for
  a collection if it was already a member before this call touched that
  collection.

  Raises:
    RuntimeError: if called while executing eagerly.
  """
  if context.executing_eagerly():
    raise RuntimeError('Using collections from Layers not supported in Eager '
                       'mode. Tried to add %s to %s' % (elements,
                                                        collection_list))
  flat_elements = nest.flatten(elements)
  for collection_name in nest.flatten(collection_list):
    members = ops.get_collection_ref(collection_name)
    seen = set(members)
    members.extend(e for e in flat_elements if e not in seen)
Пример #46
0
def _add_elements_to_collection(elements, collection_list):
  """Append each (flattened) element to every (flattened) named collection.

  Membership is checked against the collection's contents as they were before
  this call appended anything to it.

  Raises:
    RuntimeError: when executing eagerly, where collections are unsupported.
  """
  if context.executing_eagerly():
    message = ('Using collections from Layers not supported in Eager '
               'mode. Tried to add %s to %s' % (elements, collection_list))
    raise RuntimeError(message)
  elements = nest.flatten(elements)
  names = nest.flatten(collection_list)
  for name in names:
    collection = ops.get_collection_ref(name)
    present = frozenset(collection)
    new_items = [el for el in elements if el not in present]
    collection.extend(new_items)
        def _apply_averages():  # pylint: disable=missing-docstring
            """Builds the ops that average local variables into global ones.

            Closes over `grads_and_vars` and `global_step` from an enclosing
            scope that is not visible here. Each worker pushes its local
            variable values into per-variable ConditionalAccumulators. The
            chief assigns the averaged values to the global variables,
            increments `self._global_step`, and enqueues
            `self._tokens_per_step` tokens on a shared queue. Non-chief workers
            dequeue one token. Finally every worker copies the global values
            back into its local variables.

            Returns:
              The op produced by `self._assign_vars(local_vars, global_vars)`,
              gated on the sync op.

            Raises:
              ValueError: if a local variable's `_ref()` is not an `ops.Tensor`.
            """
            # Only variables that received a gradient take part.
            local_vars = [v for g, v in grads_and_vars if g is not None]
            # NOTE(review): paired with local_vars by position via zip below;
            # assumes the "global_model" collection is in the same order.
            global_vars = ops.get_collection_ref("global_model")
            # sync queue
            with ops.colocate_with(self._global_step):
                sync_queue = data_flow_ops.FIFOQueue(-1, [dtypes.bool],
                                                     shapes=[[]],
                                                     shared_name="sync_queue")
            train_ops = []
            aggregated_vars = []
            with ops.name_scope(None, self._name + "/global"):
                for var, gvar in zip(local_vars, global_vars):
                    # pylint: disable=protected-access
                    with ops.device(gvar.device):
                        if isinstance(var._ref(), ops.Tensor):
                            # One accumulator per global variable, shared by name.
                            var_accum = data_flow_ops.ConditionalAccumulator(
                                var.dtype,
                                shape=var.get_shape(),
                                shared_name=gvar.name + "/var_accum")
                            # Push this worker's local value as a "gradient".
                            train_ops.append(
                                var_accum.apply_grad(var._ref(),
                                                     local_step=global_step))
                            # take_grad waits for the requested number of
                            # contributions and yields their average.
                            aggregated_vars.append(
                                var_accum.take_grad(
                                    self._replicas_to_aggregate))
                        else:
                            raise ValueError("Unknown local variable type!")
                        self._accumulator_list.append((var_accum, gvar.device))
            # chief worker updates global vars and enqueues tokens to the sync queue
            if self._is_chief:
                update_ops = []
                with ops.control_dependencies(train_ops):
                    for avg_var, gvar in zip(aggregated_vars, global_vars):
                        with ops.device(gvar.device):
                            update_ops.append(state_ops.assign(gvar, avg_var))
                    # NOTE(review): device comes from the closed-over
                    # `global_step` while the increment targets
                    # `self._global_step`; presumably the same variable.
                    with ops.device(global_step.device):
                        update_ops.append(
                            state_ops.assign_add(self._global_step, 1))
                # Release waiting workers only after the global update.
                with ops.control_dependencies(update_ops), ops.device(
                        global_step.device):
                    tokens = array_ops.fill([self._tokens_per_step],
                                            constant_op.constant(False))
                    sync_op = sync_queue.enqueue_many(tokens)
            else:
                # Non-chief workers wait here for a token from the chief.
                with ops.control_dependencies(train_ops), ops.device(
                        global_step.device):
                    sync_op = sync_queue.dequeue()

            # Every worker then pulls the averaged global values.
            with ops.control_dependencies([sync_op]):
                local_update_op = self._assign_vars(local_vars, global_vars)
            return local_update_op
Пример #48
0
    def apply_gradients(self, grads_and_vars, global_step=None, name=None):
        """Applies gradients and also runs this optimizer's aggregate op.

        Validates `grads_and_vars` and creates slots the same way the base
        `apply_gradients` does. It then builds the aggregate op via
        `self._prepare_aggregates` and delegates the regular update to the
        parent class. The aggregate op is also added to the TRAIN_OP
        collection, unless it is already there.

        Args:
          grads_and_vars: iterable of (gradient, variable) pairs; gradients may
            be None.
          global_step: optional step forwarded to the parent `apply_gradients`.
          name: optional name. Used as the name scope for the aggregate op;
            the resolved scope string is forwarded to the parent.

        Returns:
          A grouped op running the aggregate op and the parent's update op.

        Raises:
          ValueError: if `grads_and_vars` is empty or every gradient is None.
          TypeError: if a gradient is not convertible to a Tensor or
            IndexedSlices.
        """
        # Do error checking and create slots. Error checking is copied from base apply_gradients
        grads_and_vars = tuple(
            grads_and_vars)  # Make sure repeat iteration works.
        if not grads_and_vars:
            raise ValueError("No variables provided.")
        converted_grads_and_vars = []
        for g, v in grads_and_vars:
            if g is not None:
                try:
                    # Convert the grad to Tensor or IndexedSlices if necessary.
                    g = ops.convert_to_tensor_or_indexed_slices(g)
                except TypeError:
                    raise TypeError("Gradient must be convertible to a Tensor"
                                    " or IndexedSlices, or None: %s" % g)
                if not isinstance(g, (ops.Tensor, ops.IndexedSlices)):
                    raise TypeError(
                        "Gradient must be a Tensor, IndexedSlices, or None: %s"
                        % g)
            converted_grads_and_vars.append((g, v))

        converted_grads_and_vars = tuple(converted_grads_and_vars)
        var_list = [v for g, v in converted_grads_and_vars if g is not None]
        if not var_list:
            # Entries are (gradient, variable) pairs, so unpack two values;
            # unpacking three would raise "not enough values to unpack"
            # instead of this diagnostic.
            raise ValueError("No gradients provided for any variable: %s." %
                             ([str(v)
                               for _, v in converted_grads_and_vars], ))
        with ops.control_dependencies(None):
            self._create_slots([_get_variable_for(v) for v in var_list])

        ##### end copypasta code #####

        with ops.name_scope(name, self._name) as name:
            aggregate_op = self._prepare_aggregates(converted_grads_and_vars)

        non_aggregate_updates = super(OptimizerWithAggregates,
                                      self).apply_gradients(
                                          grads_and_vars,
                                          global_step=global_step,
                                          name=name)

        apply_updates = control_flow_ops.group(aggregate_op,
                                               non_aggregate_updates)

        # Also expose the aggregate op through TRAIN_OP, without duplicates.
        train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
        if aggregate_op not in train_op:
            train_op.append(aggregate_op)

        return apply_updates
 def testColocationContraints(self):
   """Ops on one ref variable must end up in a single colocation group.

   Expected members: the variable, "Variable/Assign" (presumably its
   initializer), the RefIdentity, and the Assign applied through it.
   """
   graph = ops.Graph()
   with graph.as_default():
     new_value = constant_op.constant([10])
     var = variables.VariableV1([3], dtype=dtypes.int32)
     var_ref = gen_array_ops.ref_identity(var)
     assign_op = state_ops.assign(var_ref, new_value)
     ops.get_collection_ref(ops.GraphKeys.TRAIN_OP).append(assign_op)
     metagraph_def = meta_graph.create_meta_graph_def(graph=graph)
     groups = item.Item(metagraph_def).GetColocationGroups()
     self.assertEqual(len(groups), 1)
     self.assertItemsEqual(
         groups[0], ['Assign', 'RefIdentity', 'Variable', 'Variable/Assign'])
Пример #50
0
 def testColocationContraints(self):
   """A ref-identity assignment colocates with its variable as one group."""
   expected_members = ['Assign', 'RefIdentity', 'Variable', 'Variable/Assign']
   with ops.Graph().as_default() as graph:
     value = constant_op.constant([10])
     variable = variables.Variable([3], dtype=dtypes.int32)
     identity = gen_array_ops._ref_identity(variable)
     assignment = state_ops.assign(identity, value)
     fetches = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
     fetches.append(assignment)
     grappler_item = item.Item(meta_graph.create_meta_graph_def(graph=graph))
     colocation_groups = grappler_item.GetColocationGroups()
     self.assertEqual(len(colocation_groups), 1)
     self.assertItemsEqual(colocation_groups[0], expected_members)
  def testContext(self):
    """MeasureCosts on a provisioned cluster reports timings and step stats."""
    graph = ops.Graph()
    with graph.as_default():
      lhs = random_ops.random_uniform(shape=())
      rhs = random_ops.random_uniform(shape=())
      total = lhs + rhs
      ops.get_collection_ref(ops.GraphKeys.TRAIN_OP).append(total)
      grappler_item = item.Item(meta_graph.create_meta_graph_def(graph=graph))

    with cluster.Provision(
        disable_detailed_stats=False, disable_timeline=False) as provisioned:
      op_perfs, run_time, step_stats = provisioned.MeasureCosts(grappler_item)
      self.assertTrue(run_time > 0)
      self.assertEqual(len(op_perfs), 4)
      self.assertTrue(step_stats.dev_stats)
Пример #52
0
  def testContext(self):
    """Costs measured inside cluster.Provision include per-op perf entries."""
    with ops.Graph().as_default() as graph:
      first_sample = random_ops.random_uniform(shape=())
      second_sample = random_ops.random_uniform(shape=())
      summed = first_sample + second_sample
      fetches = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
      fetches.append(summed)
      metagraph_def = meta_graph.create_meta_graph_def(graph=graph)
      grappler_item = item.Item(metagraph_def)

    provision = cluster.Provision(
        disable_detailed_stats=False, disable_timeline=False)
    with provision as live_cluster:
      op_perfs, run_time, step_stats = live_cluster.MeasureCosts(grappler_item)
      self.assertTrue(run_time > 0)
      self.assertEqual(len(op_perfs), 7)
      self.assertTrue(step_stats.dev_stats)
Пример #53
0
  def testNoDetailedStats(self):
    """Disabling detailed stats yields a run time but no per-op/device data."""
    graph = ops.Graph()
    with graph.as_default():
      lhs = random_ops.random_uniform(shape=())
      rhs = random_ops.random_uniform(shape=())
      total = lhs + rhs
      ops.get_collection_ref(ops.GraphKeys.TRAIN_OP).append(total)
      grappler_item = item.Item(meta_graph.create_meta_graph_def(graph=graph))
      quiet_cluster = cluster.Cluster(disable_detailed_stats=True)

      op_perfs, run_time, step_stats = quiet_cluster.MeasureCosts(grappler_item)
      self.assertTrue(run_time > 0)
      self.assertEqual(len(op_perfs), 0)
      self.assertEqual(len(step_stats.dev_stats), 0)
    def _set_placeholder(cls, placeholder):
        """Register `placeholder` for the first SummarySaverHook to feed.

        The first SummarySaverHook instance in use feeds the placeholder a
        boolean telling whether summaries should be written, following that
        hook's `save_steps` / `save_secs` schedule. Combined with
        `tf.contrib.summary.record_summaries_if`, this lets
        `tf.contrib.summary` writes follow the same schedule the hook applies
        to `tf.summary`.

        The placeholder replaces whatever the class's summary-placeholder
        collection held before (the list is mutated in place).

        Args:
          placeholder: `tf.placeholder` for the first SummarySaverHook to feed.
        """
        placeholder_collection = ops.get_collection_ref(
            cls._SUMMARY_PLACEHOLDER_COLLECTION)
        del placeholder_collection[:]
        placeholder_collection.append(placeholder)
Пример #55
0
    def testNoDetailedStats(self):
        """A cluster without detailed stats reports no op perfs or dev stats."""
        with ops.Graph().as_default() as graph:
            first = random_ops.random_uniform(shape=())
            second = random_ops.random_uniform(shape=())
            result = first + second
            fetches = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
            fetches.append(result)
            metagraph_def = meta_graph.create_meta_graph_def(graph=graph)
            grappler_item = item.Item(metagraph_def)
            stats_free_cluster = cluster.Cluster(disable_detailed_stats=True)

            measured = stats_free_cluster.MeasureCosts(grappler_item)
            op_perfs, run_time, step_stats = measured
            self.assertTrue(run_time > 0)
            self.assertEqual(len(op_perfs), 0)
            self.assertEqual(len(step_stats.dev_stats), 0)
Пример #56
0
    def testSupportDevices(self):
        """Checks which devices can run each op, on virtual and real clusters.

        On a virtual cluster with one CPU and one GPU, 'add', 'Sum' and 'range'
        are all reported as supported on both. On the real local cluster, the
        expected set depends on whether a GPU is available. With a GPU, 'range'
        is reported as CPU-only.
        """
        gpu_type = test_util.gpu_device_type()
        gpu_name = test_util.gpu_device_name()
        with ops.Graph().as_default() as g:
            a = random_ops.random_uniform(shape=(2, 3))
            b = random_ops.random_uniform(shape=(2, 3))
            c = a + b
            # Reduction axes computed at runtime: range(0, rank(c)).
            dims = math_ops.range(0, array_ops.rank(c), 1)
            d = math_ops.reduce_sum(a, axis=dims)
            train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
            train_op.append(d)
            mg = meta_graph.create_meta_graph_def(graph=g)
            grappler_item = item.Item(mg)

            # Virtual cluster: one described GPU plus one 6-core CPU.
            device_properties = device_properties_pb2.DeviceProperties(
                type=gpu_type, frequency=1000, num_cores=60)
            named_gpu = device_properties_pb2.NamedDevice(
                properties=device_properties, name=gpu_name)
            device_properties = device_properties_pb2.DeviceProperties(
                type='CPU', frequency=3000, num_cores=6)
            named_cpu = device_properties_pb2.NamedDevice(
                properties=device_properties, name='/CPU:0')
            virtual_cluster = cluster.Cluster(devices=[named_cpu, named_gpu])
            supported_dev = virtual_cluster.GetSupportedDevices(grappler_item)
            self.assertEqual(supported_dev['add'], ['/CPU:0', gpu_name])
            self.assertEqual(supported_dev['Sum'], ['/CPU:0', gpu_name])
            self.assertEqual(supported_dev['range'], ['/CPU:0', gpu_name])

            # Real local cluster: device names are fully qualified.
            real_cluster = cluster.Cluster()
            supported_dev = real_cluster.GetSupportedDevices(grappler_item)
            if test.is_gpu_available():
                self.assertEqual(supported_dev['add'], [
                    '/job:localhost/replica:0/task:0/device:CPU:0',
                    '/job:localhost/replica:0/task:0' + gpu_name
                ])
                self.assertEqual(supported_dev['Sum'], [
                    '/job:localhost/replica:0/task:0/device:CPU:0',
                    '/job:localhost/replica:0/task:0' + gpu_name
                ])
                # The axis tensor must reside on the host
                self.assertEqual(
                    supported_dev['range'],
                    ['/job:localhost/replica:0/task:0/device:CPU:0'])
            else:
                # CPU-only machine: only 'add' is checked here.
                self.assertEqual(
                    supported_dev['add'],
                    ['/job:localhost/replica:0/task:0/device:CPU:0'])
Пример #57
0
  def testNoSwapping(self):
    """MANUAL memory optimization leaves a swap-free graph untouched."""
    const_a = constant_op.constant(10, name='a')
    const_b = constant_op.constant(20, name='b')
    sum_c = math_ops.add_n([const_a, const_b], name='c')
    sum_d = math_ops.add_n([const_b, sum_c], name='d')
    ops.get_collection_ref(ops.GraphKeys.TRAIN_OP).append(sum_d)
    metagraph = meta_graph.create_meta_graph_def(graph=ops.get_default_graph())

    manual_memory = rewriter_config_pb2.RewriterConfig(
        memory_optimization=rewriter_config_pb2.RewriterConfig.MANUAL)
    optimized = tf_optimizer.OptimizeGraph(manual_memory, metagraph)

    node_names = [node.name for node in optimized.node]
    self.assertEqual(len(optimized.node), 4)
    self.assertItemsEqual(node_names, ['a', 'b', 'c', 'd'])
    def testDebugMode(self):
        """The debug model report states that both add inputs are known."""
        lhs = constant_op.constant([10, 11], name="a")
        rhs = constant_op.constant([10], name="b")
        summed = math_ops.add(lhs, rhs, name="c")
        ops.get_collection_ref(ops.GraphKeys.TRAIN_OP).append(summed)
        metagraph = meta_graph.create_meta_graph_def(
            graph=ops.get_default_graph())

        report = model_analyzer.GenerateModelReport(metagraph, debug=True)

        # Both inputs of the add are constants, so the report marks them known.
        for expected_header in (b"input 0 (int32) has known value",
                                b"input 1 (int32) has known value"):
            self.assertIn(expected_header, report)

        # Echo the report so failures are easier to diagnose.
        print("{}".format(report))
Пример #59
0
  def testOpProperties(self):
    with ops.Graph().as_default() as g:
      a = constant_op.constant(10)
      b = constant_op.constant(20)
      c = a + b
      train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
      train_op.append(c)
      mg = meta_graph.create_meta_graph_def(graph=g)
      grappler_item = item.Item(mg)
      op_properties = grappler_item.GetOpProperties()

      # All the nodes in this model have one scalar output
      for node in grappler_item.metagraph.graph_def.node:
        node_prop = op_properties[node.name]

        self.assertEqual(1, len(node_prop))
        self.assertEqual(dtypes.int32, node_prop[0].dtype)
        self.assertEqual(tensor_shape.scalar(), node_prop[0].shape)