def test_asset_loading(self):
  """A v1 asset SavedModel survives two further save/load round trips.

  Each round saves the most recently loaded object to a fresh directory,
  deletes the directory it was loaded from, clears TABLE_INITIALIZERS, and
  reloads, checking that the lookup signature still returns the same output.
  """

  def _check_lookup(loaded):
    # Initialize the lookup tables, then query the serving signature.
    self.evaluate(lookup_ops.tables_initializer())
    serving_fn = loaded.signatures["serving_default"]
    self.assertAllClose(
        {"output": [2, 0]},
        serving_fn(start=constant_op.constant(["gamma", "alpha"])))

  current_path = self._v1_asset_saved_model()
  # Keep every loaded object referenced for the whole test.
  loaded_objects = [load.load(current_path)]
  _check_lookup(loaded_objects[-1])

  for _ in range(2):
    previous = loaded_objects[-1]
    next_path = os.path.join(self.get_temp_dir(), "saved_model",
                             str(ops.uid()))
    save.save(previous, next_path, signatures=previous.signatures)
    # Remove the source directory so the reload must use the copied assets.
    shutil.rmtree(current_path)
    del ops.get_collection_ref(ops.GraphKeys.TABLE_INITIALIZERS)[:]
    loaded_objects.append(load.load(next_path))
    _check_lookup(loaded_objects[-1])
    current_path = next_path
def _call_func(self, args, kwargs):
  """Calls `self._func(*args, **kwargs)` and polices variable creation.

  On the first call (`self._variables_created` false) the function runs inside
  `checkpointable_util.capture_dependencies(template=self)`, and afterwards
  `self._variables_created` is set to True. On later calls:
  - a ValueError is raised if the TRAINABLE_VARIABLES collection grew;
  - an info message is logged if the GLOBAL_VARIABLES collection grew.

  Any exception raised inside is re-raised after its first argument is
  extended with the stack trace at which this template was originally
  defined (`self._stacktrace`).

  Args:
    args: Positional arguments forwarded to `self._func`.
    kwargs: Keyword arguments forwarded to `self._func`.

  Returns:
    Whatever `self._func` returns.

  Raises:
    ValueError: If a trainable variable is created on a non-first call.
  """
  try:
    # Snapshot collection sizes so variables created by this call can be
    # detected afterwards.
    vars_at_start = len(
        ops.get_collection_ref(ops.GraphKeys.GLOBAL_VARIABLES))
    trainable_at_start = len(
        ops.get_collection_ref(ops.GraphKeys.TRAINABLE_VARIABLES))
    if self._variables_created:
      result = self._func(*args, **kwargs)
    else:
      # The first time we run, restore variables if necessary (via
      # Checkpointable).
      with checkpointable_util.capture_dependencies(template=self):
        result = self._func(*args, **kwargs)

    if self._variables_created:
      # Variables were previously created, implying this is not the first
      # time the template has been called. Check to make sure that no new
      # trainable variables were created this time around.
      trainable_variables = ops.get_collection_ref(
          ops.GraphKeys.TRAINABLE_VARIABLES)
      # If a variable that we intend to train is created as a side effect
      # of creating a template, then that is almost certainly an error.
      if trainable_at_start != len(trainable_variables):
        raise ValueError("Trainable variable created when calling a template "
                         "after the first time, perhaps you used tf.Variable "
                         "when you meant tf.get_variable: %s" %
                         (trainable_variables[trainable_at_start:],))
      # Non-trainable tracking variables are a legitimate reason why a new
      # variable would be created, but it is a relatively advanced use-case,
      # so log it.
      variables = ops.get_collection_ref(ops.GraphKeys.GLOBAL_VARIABLES)
      if vars_at_start != len(variables):
        logging.info("New variables created when calling a template after "
                     "the first time, perhaps you used tf.Variable when you "
                     "meant tf.get_variable: %s",
                     variables[vars_at_start:])
    else:
      self._variables_created = True
    return result
  except Exception as exc:
    # Reraise the exception, but append the original definition to the
    # trace.
    # NOTE: `args` is rebound here to the exception's args; the function's
    # `args` parameter is no longer needed at this point.
    args = exc.args
    if not args:
      arg0 = ""
    else:
      arg0 = args[0]
    trace = "".join(_skip_common_stack_elements(self._stacktrace,
                                                traceback.format_stack()))
    arg0 = "%s\n\noriginally defined at:\n%s" % (arg0, trace)
    new_args = [arg0]
    new_args.extend(args[1:])
    exc.args = tuple(new_args)
    # Bare `raise` re-raises the same exception object with its original
    # traceback, now carrying the annotated args.
    raise
def _clear_saved_model_collections():
  """Empty the collections the SavedModel builder expects to start empty.

  The builder records the ops needed to restore graph state (assets, legacy
  init op, main op, train op) in these collections, so they must be cleared
  before MetaGraphs are added to it. Each collection list is emptied in place.
  """
  for collection_key in (constants.ASSETS_KEY,
                         constants.LEGACY_INIT_OP_KEY,
                         constants.MAIN_OP_KEY,
                         constants.TRAIN_OP_KEY):
    del ops.get_collection_ref(collection_key)[:]
def never_record_summaries():
  """Sets the should_record_summaries Tensor to always false.

  Generator intended for use as a context manager (the decorator, if any, is
  not part of this block). While suspended at `yield`, the
  `_SHOULD_RECORD_SUMMARIES_NAME` collection holds `[False]`. Its previous
  contents are restored when the generator resumes, or when it is closed
  because the managed block raised.

  Yields:
    None.
  """
  collection_ref = ops.get_collection_ref(_SHOULD_RECORD_SUMMARIES_NAME)
  old = collection_ref[:]
  try:
    collection_ref[:] = [False]
    yield
  finally:
    # Restore in `finally` so an exception in the managed block does not
    # leave summary recording permanently disabled for this graph.
    collection_ref[:] = old
def record_summaries_every_n_global_steps(n, global_step=None):
  """Sets the should_record_summaries Tensor to true if global_step % n == 0.

  Generator intended for use as a context manager. While suspended at
  `yield`, the `_SHOULD_RECORD_SUMMARIES_NAME` collection holds a boolean
  tensor that is true when `global_step % n == 0`. The previous contents are
  restored afterwards, even if the managed block raises.

  Args:
    n: Summaries are recorded on steps where `global_step % n == 0`.
    global_step: Optional step tensor/variable to test. If None, the result
      of `training_util.get_global_step()` is used, as before.

  Yields:
    None.
  """
  if global_step is None:
    global_step = training_util.get_global_step()
  collection_ref = ops.get_collection_ref(_SHOULD_RECORD_SUMMARIES_NAME)
  old = collection_ref[:]
  try:
    # Use math_ops.equal rather than `==`. On TF1 graph tensors, `==` is
    # Python identity comparison, so `step % n == 0` evaluated to the
    # constant False and summaries were never recorded.
    collection_ref[:] = [math_ops.equal(global_step % n, 0)]
    yield
  finally:
    collection_ref[:] = old
def apply_mask(x, scope=''):
  """Apply mask to a given weight tensor.

  Creates (via `pruning_utils`) a mask variable and a threshold variable for
  `x`, and returns `mask * x`. The first time a given mask is seen, the
  threshold, mask, masked weights and raw weights are registered in their
  pruning collections.

  Args:
    x: Input weight tensor
    scope: The current variable scope. Defaults to "".

  Returns:
    Tensor representing masked_weights
  """
  mask = pruning_utils.weight_mask_variable(x, scope)
  threshold = pruning_utils.weight_threshold_variable(x, scope)
  # Name the product so the quantization library can find it in the weights
  # namescope and add quant ops.
  masked_weights = math_ops.multiply(mask, x, _MASKED_WEIGHT_NAME)

  # Register each mask only once. RNN weights may be masked repeatedly, and
  # duplicate collection entries must be avoided.
  already_registered = mask in ops.get_collection_ref(_MASK_COLLECTION)
  if not already_registered:
    registrations = (
        (_THRESHOLD_COLLECTION, threshold),
        (_MASK_COLLECTION, mask),
        (_MASKED_WEIGHT_COLLECTION, masked_weights),
        (_WEIGHT_COLLECTION, x),
    )
    for collection_name, value in registrations:
      ops.add_to_collection(collection_name, value)
  return masked_weights
def record_summaries_every_n_global_steps(n, global_step=None):
  """Sets the should_record_summaries Tensor to true if global_step % n == 0.

  Generator intended for use as a context manager. While suspended at
  `yield`, the `_SHOULD_RECORD_SUMMARIES_NAME` collection holds a boolean
  tensor that is true when `global_step % n == 0`. The previous contents are
  restored afterwards, even if the managed block raises.

  Args:
    n: Summaries are recorded on steps where `global_step % n == 0`.
    global_step: Optional step tensor/variable to test. If None, the result
      of `training_util.get_global_step()` is used, as before.

  Yields:
    None.
  """
  if global_step is None:
    global_step = training_util.get_global_step()
  collection_ref = ops.get_collection_ref(_SHOULD_RECORD_SUMMARIES_NAME)
  old = collection_ref[:]
  try:
    # Use math_ops.equal rather than `==`. On TF1 graph tensors, `==` is
    # Python identity comparison, so `step % n == 0` evaluated to the
    # constant False and summaries were never recorded.
    collection_ref[:] = [math_ops.equal(global_step % n, 0)]
    yield
  finally:
    collection_ref[:] = old
def testFromStringHandle(self):
  """Grappler infers the declared output shape of a string-handle iterator."""
  expected_shapes = [
      tensor_shape.TensorShape([]),
      tensor_shape.TensorShape([3]),
      tensor_shape.TensorShape([1, 2]),
      tensor_shape.TensorShape([1, 2, 3]),
  ]
  for expected_shape in expected_shapes:
    with ops.Graph().as_default() as graph:
      source_iterator = iterator_ops.Iterator.from_structure(dtypes.int64)
      handle = source_iterator.string_handle()
      handle_iterator = iterator_ops.Iterator.from_string_handle(
          handle, dtypes.int64, output_shapes=expected_shape)
      next_element = handle_iterator.get_next()
      # Fetching the element through TRAIN_OP keeps it in the grappler item.
      ops.get_collection_ref(ops.GraphKeys.TRAIN_OP).append(next_element)
      meta_graph_def = meta_graph.create_meta_graph_def(graph=graph)
      grappler_item = item.Item(meta_graph_def)
      properties = grappler_item.GetOpProperties()
      self.assertEqual(expected_shape,
                       properties['IteratorGetNext'][0].shape)
def testSimpleSwap(self):
  """Check that the swap annotations are followed."""
  var_a = variables.Variable(10, name='a')
  var_b = variables.Variable(20, name='b')
  sum_c = math_ops.add_n([var_a, var_b], name='c')
  sum_d = math_ops.add_n([var_b, sum_c], name='d')
  ops.get_collection_ref(ops.GraphKeys.TRAIN_OP).append(sum_d)
  # Annotate input 0 of `d` for swapping to host memory.
  sum_d.op._set_attr('_swap_to_host', attr_value_pb2.AttrValue(i=0))  # pylint: disable=protected-access

  meta_graph_def = meta_graph.create_meta_graph_def(
      graph=ops.get_default_graph())
  nodes_before = len(meta_graph_def.graph_def.node)
  config = rewriter_config_pb2.RewriterConfig(
      disable_model_pruning=True,
      constant_folding=rewriter_config_pb2.RewriterConfig.OFF,
      memory_optimization=rewriter_config_pb2.RewriterConfig.MANUAL)
  optimized = tf_optimizer.OptimizeGraph(config, meta_graph_def)

  # Exactly one swap-out and one swap-in node should have been added.
  self.assertEqual(len(optimized.node), nodes_before + 2)
  node_names = set(node.name for node in optimized.node)
  self.assertTrue(
      node_names > set(['a', 'b', 'c', 'd', 'swap_in_d_0', 'swap_out_d_0']))

  expected_inputs = {
      'swap_in_d_0': ['swap_out_d_0', '^b/read'],
      'swap_out_d_0': ['b/read'],
      'd': ['swap_in_d_0', 'c'],
  }
  for node in optimized.node:
    for position, expected in enumerate(expected_inputs.get(node.name, [])):
      self.assertEqual(expected, node.input[position])
def testSimpleSwap(self):
  """Check that the swap annotations are followed."""
  const_a = constant_op.constant(10, name='a')
  const_b = constant_op.constant(20, name='b')
  sum_c = math_ops.add_n([const_a, const_b], name='c')
  sum_d = math_ops.add_n([const_b, sum_c], name='d')
  ops.get_collection_ref(ops.GraphKeys.TRAIN_OP).append(sum_d)
  # Annotate input 0 of `d` for swapping to host memory.
  sum_d.op.node_def.attr['_swap_to_host'].i = 0

  meta_graph_def = meta_graph.create_meta_graph_def(
      graph=ops.get_default_graph())
  config = rewriter_config_pb2.RewriterConfig(
      memory_optimization=rewriter_config_pb2.RewriterConfig.MANUAL)
  optimized = tf_optimizer.OptimizeGraph(config, meta_graph_def)

  self.assertEqual(len(optimized.node), 6)
  self.assertItemsEqual(
      [node.name for node in optimized.node],
      ['a', 'b', 'c', 'd', 'swap_in_d_0', 'swap_out_d_0'])

  expected_inputs = {
      'swap_in_d_0': ['swap_out_d_0', '^b'],
      'swap_out_d_0': ['b'],
      'd': ['swap_in_d_0', 'c'],
  }
  for node in optimized.node:
    for position, expected in enumerate(expected_inputs.get(node.name, [])):
      self.assertEqual(expected, node.input[position])
def testSimpleSwap(self):
  """Check that the swap annotations are followed."""
  var_a = variables.Variable(10, name='a')
  var_b = variables.Variable(20, name='b')
  sum_c = math_ops.add_n([var_a, var_b], name='c')
  sum_d = math_ops.add_n([var_b, sum_c], name='d')
  ops.get_collection_ref(ops.GraphKeys.TRAIN_OP).append(sum_d)
  # Annotate input 0 of `d` for swapping to host memory.
  sum_d.op.node_def.attr['_swap_to_host'].i = 0

  meta_graph_def = meta_graph.create_meta_graph_def(
      graph=ops.get_default_graph())
  nodes_before = len(meta_graph_def.graph_def.node)
  config = rewriter_config_pb2.RewriterConfig(
      disable_model_pruning=True,
      memory_optimization=rewriter_config_pb2.RewriterConfig.MANUAL)
  optimized = tf_optimizer.OptimizeGraph(config, meta_graph_def)

  # Exactly one swap-out and one swap-in node should have been added.
  self.assertEqual(len(optimized.node), nodes_before + 2)
  node_names = set(node.name for node in optimized.node)
  self.assertTrue(
      node_names > set(['a', 'b', 'c', 'd', 'swap_in_d_0', 'swap_out_d_0']))

  expected_inputs = {
      'swap_in_d_0': ['swap_out_d_0', '^b/read'],
      'swap_out_d_0': ['b/read'],
      'd': ['swap_in_d_0', 'c'],
  }
  for node in optimized.node:
    for position, expected in enumerate(expected_inputs.get(node.name, [])):
      self.assertEqual(expected, node.input[position])
def testSimpleSwap(self):
  """Check that the swap annotations are followed."""
  const_a = constant_op.constant(10, name='a')
  const_b = constant_op.constant(20, name='b')
  sum_c = math_ops.add_n([const_a, const_b], name='c')
  sum_d = math_ops.add_n([const_b, sum_c], name='d')
  train_ops = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
  train_ops.append(sum_d)
  # Annotate input 0 of `d` for swapping to host memory.
  sum_d.op.node_def.attr['_swap_to_host'].i = 0

  meta_graph_def = meta_graph.create_meta_graph_def(
      graph=ops.get_default_graph())
  config = rewriter_config_pb2.RewriterConfig(
      memory_optimization=rewriter_config_pb2.RewriterConfig.MANUAL)
  optimized = tf_optimizer.OptimizeGraph(config, meta_graph_def)

  self.assertEqual(len(optimized.node), 6)
  names = [node.name for node in optimized.node]
  self.assertItemsEqual(
      names, ['a', 'b', 'c', 'd', 'swap_in_d_0', 'swap_out_d_0'])

  expected_inputs = {
      'swap_in_d_0': ['swap_out_d_0', '^b'],
      'swap_out_d_0': ['b'],
      'd': ['swap_in_d_0', 'c'],
  }
  for node in optimized.node:
    for position, expected in enumerate(expected_inputs.get(node.name, [])):
      self.assertEqual(expected, node.input[position])
def build(self, inputs_shape):
  """Builds the parent LSTM weights, then adds a pruning mask and threshold.

  After the parent build, creates a non-trainable `mask` variable (ones,
  shaped like the kernel `[input_depth + num_units, 4 * num_units]`) and a
  non-trainable scalar `threshold` (zero). Defines `self._masked_kernel` as
  `mask * kernel` and, the first time this mask is seen, registers the mask,
  masked kernel, threshold and kernel in the `core_layers` pruning
  collections.

  Args:
    inputs_shape: Input shape; `inputs_shape[1].value` is used as the input
      depth.
  """
  super(MaskedBasicLSTMCell, self).build(inputs_shape)
  # The parent marks the cell built. Reset the flag while the extra variables
  # are added, and set it again at the end.
  self.built = False

  num_inputs = inputs_shape[1].value
  num_units = self._num_units
  self._mask = self.add_variable(
      name="mask",
      shape=[num_inputs + num_units, 4 * num_units],
      initializer=init_ops.ones_initializer(),
      trainable=False,
      dtype=self.dtype)
  self._threshold = self.add_variable(
      name="threshold",
      shape=[],
      initializer=init_ops.zeros_initializer(),
      trainable=False,
      dtype=self.dtype)
  # Name the product so the quantization library can find it in the weights
  # namescope and add quant ops.
  self._masked_kernel = math_ops.multiply(self._mask, self._kernel,
                                          core_layers.MASKED_WEIGHT_NAME)

  mask_collection = ops.get_collection_ref(core_layers.MASK_COLLECTION)
  if self._mask not in mask_collection:
    registrations = (
        (core_layers.MASK_COLLECTION, self._mask),
        (core_layers.MASKED_WEIGHT_COLLECTION, self._masked_kernel),
        (core_layers.THRESHOLD_COLLECTION, self._threshold),
        (core_layers.WEIGHT_COLLECTION, self._kernel),
    )
    for collection_name, value in registrations:
      ops.add_to_collection(collection_name, value)

  self.built = True
def never_record_summaries():
  """Sets the should_record_summaries Tensor to always false.

  Generator intended for use as a context manager. While suspended at
  `yield`, the `_SHOULD_RECORD_SUMMARIES_NAME` collection holds `[False]`.
  Its previous contents are restored afterwards, even if the managed block
  raises.

  Yields:
    None.
  """
  collection_ref = ops.get_collection_ref(_SHOULD_RECORD_SUMMARIES_NAME)
  old = collection_ref[:]
  try:
    collection_ref[:] = [False]
    yield
  finally:
    # Restore in `finally` so an exception in the managed block does not
    # leave summary recording permanently disabled for this graph.
    collection_ref[:] = old
def apply_mask(x, scope=''):
  """Apply mask to a given weight tensor.

  Creates a mask variable and a threshold variable for `x` and returns
  `mask * x`. The first time a given mask is seen, the threshold, mask,
  masked weights and raw weights are registered in their pruning collections.

  Args:
    x: Input weight tensor
    scope: The current variable scope. Defaults to "".

  Returns:
    Tensor representing masked_weights
  """
  mask = _weight_mask_variable(x, scope)
  threshold = _weight_threshold_variable(x, scope)
  # Name the product so the quantization library can find it in the weights
  # namescope and add quant ops.
  masked_weights = math_ops.multiply(mask, x, _MASKED_WEIGHT_NAME)

  # Register each mask only once. RNN weights may be masked repeatedly.
  if mask in ops.get_collection_ref(_MASK_COLLECTION):
    return masked_weights
  for collection_name, value in ((_THRESHOLD_COLLECTION, threshold),
                                 (_MASK_COLLECTION, mask),
                                 (_MASKED_WEIGHT_COLLECTION, masked_weights),
                                 (_WEIGHT_COLLECTION, x)):
    ops.add_to_collection(collection_name, value)
  return masked_weights
def testUpdates(self):
  """`tf_item` changes only when the underlying metagraph changes."""
  with ops.Graph().as_default() as graph:
    ten = constant_op.constant(10)
    twenty = constant_op.constant(20)
    total = ten + twenty
    ops.get_collection_ref(ops.GraphKeys.TRAIN_OP).append(total)
    meta_graph_def = meta_graph.create_meta_graph_def(graph=graph)
    grappler_item = item.Item(meta_graph_def)

    def place_all_nodes_on_cpu():
      for node in grappler_item.metagraph.graph_def.node:
        node.device = '/cpu:0'

    before = grappler_item.tf_item
    unchanged = grappler_item.tf_item
    self.assertEqual(before, unchanged)

    # Modify the placement.
    place_all_nodes_on_cpu()
    after_move = grappler_item.tf_item
    self.assertNotEqual(before, after_move)

    # Assign the same placement again; the item should not change.
    place_all_nodes_on_cpu()
    after_same_move = grappler_item.tf_item
    self.assertEqual(after_move, after_same_move)
def testPaddedBatch(self):
  """Grappler infers padded_batch output shapes (batch dim may be unknown)."""
  cases = [
      (0, tensor_shape.TensorShape([None])),
      (np.array([1, 2, 3]), tensor_shape.TensorShape([None, 4])),
      (np.array([[1, 2, 3]]), tensor_shape.TensorShape([None, 2, 4])),
  ]
  for tensor, expected_shape in cases:
    with ops.Graph().as_default() as graph:
      dataset = dataset_ops.Dataset.from_tensors(tensor)
      dataset = dataset.padded_batch(42, padded_shapes=expected_shape[1:])
      next_element = dataset.make_one_shot_iterator().get_next()
      ops.get_collection_ref(ops.GraphKeys.TRAIN_OP).append(next_element)
      meta_graph_def = meta_graph.create_meta_graph_def(graph=graph)
      grappler_item = item.Item(meta_graph_def)
      properties = grappler_item.GetOpProperties()
      inferred = self.as_tensor_shape(properties['IteratorGetNext'][0].shape)
      # The batch dimension only needs to be compatible; the rest must match.
      self.assertTrue(expected_shape.dims[0].is_compatible_with(inferred[0]))
      self.assertEqual(expected_shape[1:], inferred[1:])
def testBasicMemory(self):
  """Make sure arguments can be passed correctly.

  Builds a small CPU graph (constants `a`, `b` and two add_n ops `c`, `d`),
  generates a memory report, and checks the reported peak usage and
  per-tensor sizes.
  """
  with test_util.device(use_gpu=False):
    a = constant_op.constant(10, name="a")
    b = constant_op.constant(20, name="b")
    c = math_ops.add_n([a, b], name="c")
    d = math_ops.add_n([b, c], name="d")
    train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
    train_op.append(d)
    mg = meta_graph.create_meta_graph_def(graph=ops.get_default_graph())

  report = cost_analyzer.GenerateMemoryReport(mg)

  # Print the report to make it easier to debug.
  print(report)

  # assertIn includes the full report in the failure message, whereas
  # assertTrue(... in report) only reports "False is not true".
  self.assertIn(
      "Peak usage for device /job:localhost/replica:0/task:0/device:CPU:0: "
      "16 bytes", report)
  for tensor_name in ("a", "b", "c", "d"):
    self.assertIn(" %s:0 uses 4 bytes" % tensor_name, report)
def testFromGenerator(self):
  """Grappler infers the declared output shape of a from_generator dataset."""

  def generator_yielding(value):
    # Build a fresh zero-argument generator function bound to `value`.
    def generator():
      yield value
    return generator

  cases = [
      (0, tensor_shape.TensorShape([])),
      (np.array([1, 2, 3]), tensor_shape.TensorShape([3])),
      (np.array([[1, 2, 3]]), tensor_shape.TensorShape([1, 3])),
  ]
  for tensor, expected_shape in cases:
    with ops.Graph().as_default() as graph:
      dataset = dataset_ops.Dataset.from_generator(
          generator_yielding(tensor),
          dtypes.int64,
          output_shapes=expected_shape)
      next_element = dataset.make_one_shot_iterator().get_next()
      ops.get_collection_ref(ops.GraphKeys.TRAIN_OP).append(next_element)
      meta_graph_def = meta_graph.create_meta_graph_def(graph=graph)
      grappler_item = item.Item(meta_graph_def)
      properties = grappler_item.GetOpProperties()
      self.assertEqual(expected_shape,
                       properties['IteratorGetNext'][0].shape)
def testInterleave(self):
  """Grappler infers the element shape of an interleaved dataset."""

  def repeat_tensor_fn(value):
    # Map each range element n to `value` repeated n times.
    def dataset_fn(n):
      return dataset_ops.Dataset.from_tensors(value).repeat(n)
    return dataset_fn

  cases = [
      (0, tensor_shape.TensorShape([])),
      (np.array([1, 2, 3]), tensor_shape.TensorShape([3])),
      (np.array([[1, 2, 3]]), tensor_shape.TensorShape([1, 3])),
  ]
  for tensor, expected_shape in cases:
    with ops.Graph().as_default() as graph:
      dataset = dataset_ops.Dataset.range(42).interleave(
          repeat_tensor_fn(tensor), cycle_length=42)
      next_element = dataset.make_one_shot_iterator().get_next()
      ops.get_collection_ref(ops.GraphKeys.TRAIN_OP).append(next_element)
      meta_graph_def = meta_graph.create_meta_graph_def(graph=graph)
      grappler_item = item.Item(meta_graph_def)
      properties = grappler_item.GetOpProperties()
      self.assertEqual(expected_shape,
                       properties['IteratorGetNext'][0].shape)
def testMap(self):
  """Grappler infers the shape produced by mapping `transpose` over a dataset."""
  cases = [
      (0, tensor_shape.TensorShape([])),
      (np.array([1, 2, 3]), tensor_shape.TensorShape([3])),
      (np.array([[1, 2, 3]]), tensor_shape.TensorShape([3, 1])),
      (np.array([[[1, 2, 3], [4, 5, 6]]]), tensor_shape.TensorShape([3, 2, 1])),
  ]
  for tensor, expected_shape in cases:
    with ops.Graph().as_default() as graph:
      dataset = dataset_ops.Dataset.from_tensors(tensor).map(
          array_ops.transpose)
      next_element = dataset.make_one_shot_iterator().get_next()
      ops.get_collection_ref(ops.GraphKeys.TRAIN_OP).append(next_element)
      meta_graph_def = meta_graph.create_meta_graph_def(graph=graph)
      grappler_item = item.Item(meta_graph_def)
      properties = grappler_item.GetOpProperties()
      self.assertEqual(expected_shape,
                       properties['IteratorGetNext'][0].shape)
def testPruning(self):
  """An unused TensorList loop output is pruned until it is fetched."""
  x = constant_op.constant(1)
  tensor_list = list_ops.empty_tensor_list(
      element_dtype=x.dtype, element_shape=x.shape)

  def Cond(x, tl):
    del tl  # Unused for Cond.
    return x < 5

  def Body(x, tl):
    return x + 1, list_ops.tensor_list_push_back(tl, x)

  outputs = while_loop_v1(Cond, Body, [x, tensor_list])

  train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
  train_op.append(outputs[0])

  def CountEnterNodes():
    """Optimizes the default graph and returns its number of Enter nodes."""
    mg = meta_graph.create_meta_graph_def(graph=ops.get_default_graph())
    rewriter_config = rewriter_config_pb2.RewriterConfig(
        constant_folding=rewriter_config_pb2.RewriterConfig.OFF,
        memory_optimization=rewriter_config_pb2.RewriterConfig.MANUAL)
    optimized = tf_optimizer.OptimizeGraph(rewriter_config, mg)
    return sum(1 for node in optimized.node if node.op == "Enter")

  # Only outputs[0] is fetched, so one Enter node is expected.
  self.assertEqual(CountEnterNodes(), 1)

  # Fetching the stacked list keeps the TensorList path alive: two Enters.
  stack = list_ops.tensor_list_stack(outputs[1], element_dtype=x.dtype)
  train_op.append(stack)
  self.assertEqual(CountEnterNodes(), 2)
def testVirtualCluster(self):
  """Measures costs on a virtual single-GPU cluster."""
  with ops.Graph().as_default() as graph:
    lhs = random_ops.random_uniform(shape=())
    rhs = random_ops.random_uniform(shape=())
    total = lhs + rhs
    ops.get_collection_ref(ops.GraphKeys.TRAIN_OP).append(total)
    meta_graph_def = meta_graph.create_meta_graph_def(graph=graph)
    grappler_item = item.Item(meta_graph_def)

    gpu_properties = device_properties_pb2.DeviceProperties(
        type='GPU',
        frequency=1000,
        num_cores=60,
        environment={'architecture': '7'})
    gpu_device = device_properties_pb2.NamedDevice(
        properties=gpu_properties, name='/GPU:0')
    virtual_cluster = cluster.Cluster(devices=[gpu_device])

    op_perfs, run_time, _ = virtual_cluster.MeasureCosts(grappler_item)
    self.assertGreater(run_time, 0)
    self.assertEqual(len(op_perfs), 15)

    estimated_perf = virtual_cluster.EstimatePerformance(gpu_device)
    self.assertEqual(7680.0, estimated_perf)
def capture(self):
  """Temporarily empties the summary collection and records what gets added.

  Generator intended for use as a context manager. While suspended at
  `yield`, the `_SUMMARY_COLLECTION` list is empty. Afterwards, whatever was
  added is stored in `self.collected_ops` and the original contents are
  restored. This also happens when the managed block raises.

  Yields:
    None.
  """
  collection_ref = ops.get_collection_ref(
      ops.GraphKeys._SUMMARY_COLLECTION)  # pylint: disable=protected-access
  original_ops = collection_ref[:]
  collection_ref[:] = []
  try:
    yield
  finally:
    # Restore in `finally` so an exception in the managed block does not
    # leave the graph's summary collection emptied.
    self.collected_ops = collection_ref[:]
    collection_ref[:] = original_ops
def record_summaries_every_n_global_steps(n, global_step=None):
  """Sets the should_record_summaries Tensor to true if global_step % n == 0.

  Generator intended for use as a context manager. While suspended at
  `yield`, the `_SHOULD_RECORD_SUMMARIES_NAME` collection holds a boolean
  tensor (created under device "cpu:0") that is true when
  `global_step % n == 0`. The previous contents are restored afterwards, even
  if the managed block raises.

  Args:
    n: Summaries are recorded on steps where `global_step % n == 0`.
    global_step: Optional step tensor/variable to test. If None, the result
      of `training_util.get_global_step()` is used, as before.

  Yields:
    None.
  """
  if global_step is None:
    global_step = training_util.get_global_step()
  collection_ref = ops.get_collection_ref(_SHOULD_RECORD_SUMMARIES_NAME)
  old = collection_ref[:]
  try:
    with ops.device("cpu:0"):
      collection_ref[:] = [math_ops.equal(global_step % n, 0)]
    yield
  finally:
    # Restore even if the managed block raises.
    collection_ref[:] = old
def _lift_unlifted_variables(graph, variable_holder):
  """Finds resource variables and lifts them into the outer context.

  When we import a GraphDef inside a wrap_function, no Python graph building
  code runs. This means we get VarHandleOps which create variable resources,
  but no corresponding Python objects. Leaving them like this works but gives
  the user no way to interact with or modify the variables outside the graph.

  This method searches for variables and lifts them out as regular variable
  objects when possible, indicating to the FuncGraph that they are captures.

  For each graph-mode ResourceVariable in the GLOBAL/LOCAL variable
  collections whose handle is not already an internal capture, this function:
  - creates an `UninitializedVariable` sharing the old variable's shape,
    dtype, name, trainability and initializer op;
  - registers the old handle as a graph input, captured by the new handle;
  - records the new variable in `variable_holder` and the graph's weak
    variable list, and watches it.
  Finally, both collections are rewritten in place so that they reference the
  lifted variables.

  Args:
    graph: The FuncGraph to lift variables from.
    variable_holder: A VariableHolder to record the lifted variables in.
  """
  with graph.as_default():
    collection_variables = (
        ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) +
        ops.get_collection(ops.GraphKeys.LOCAL_VARIABLES))
    existing_captures = set(graph.internal_captures)
    # Maps each original variable object to its lifted replacement.
    lifted_variables = {}
    for old_variable in collection_variables:
      if (old_variable._in_graph_mode  # pylint: disable=protected-access
          and
          isinstance(old_variable, resource_variable_ops.ResourceVariable)):
        # Skip handles that are already captured (e.g. lifted on a previous
        # call).
        if old_variable.handle in existing_captures:
          continue
        new_variable = resource_variable_ops.UninitializedVariable(
            shape=old_variable.shape,
            dtype=old_variable.dtype,
            name=old_variable.op.name,
            trainable=old_variable.trainable,
            extra_handle_data=old_variable.handle)
        new_variable._initializer_op = old_variable._initializer_op  # pylint: disable=protected-access
        graph.inputs.append(old_variable.handle)
        graph.captures[new_variable.handle] = old_variable.handle
        # Now that we've added the new variable to graph.captures,
        # graph.capture will use that cached value and do some post-processing
        # on the capture like recording it on the tape.
        graph.capture(new_variable.handle)
        existing_captures.add(old_variable.handle)
        lifted_variables[old_variable] = new_variable
        # pylint: disable=protected-access
        # Strip the ":0" output suffix to get the variable's op name.
        variable_name = new_variable.name.split(":")[0]
        variable_holder._variables_by_name[
            variable_name] = new_variable
        graph._weak_variables.append(weakref.ref(new_variable))
        # pylint: enable=protected-access
        graph.watch_variable(new_variable)
    # Update the graph's collections, partly for the user and partly so this
    # function is idempotent when it runs again in prune() calls.
    for collection_name in [
        ops.GraphKeys.GLOBAL_VARIABLES, ops.GraphKeys.LOCAL_VARIABLES
    ]:
      mutable_collection = ops.get_collection_ref(collection_name)
      for index, current in enumerate(mutable_collection):
        mutable_collection[index] = lifted_variables.get(
            current, current)
def _VRClassRewardGrad(op, *grad):
  """Gradient function for the VRClassReward op.

  Replaces the contents of the REWARD collection with `op.outputs[2]`. Per the
  original comment, this output is used to learn the baseline reward.

  Args:
    op: The forward op. `op.inputs[0]` and `op.outputs[2]` are used.
    *grad: Incoming gradients. `grad[1]` is passed through unchanged.

  Returns:
    A 3-tuple: zeros shaped like `op.inputs[0]`, `grad[1]`, and None (no
    gradient for the third input).
  """
  collection = ops.get_collection_ref(REWARD)
  # Clear the collection in place, idiomatically, instead of popping len()
  # times.
  del collection[:]
  ops.add_to_collection(REWARD, op.outputs[2])  # learn the baseline reward
  return array_ops.zeros_like(op.inputs[0]), grad[1], None
def record_summaries_if(bool_value):
  """Sets the should_record_summaries Tensor to the given boolean value.

  Generator intended for use as a context manager. While suspended at
  `yield`, the `_SHOULD_RECORD_SUMMARIES_NAME` collection holds
  `[bool_value]`. The prior contents are put back afterwards, even on error.

  Args:
    bool_value: Value (bool or boolean tensor) to install.

  Yields:
    None.
  """
  summaries_flag = ops.get_collection_ref(_SHOULD_RECORD_SUMMARIES_NAME)
  saved_contents = list(summaries_flag)
  try:
    summaries_flag[:] = [bool_value]
    yield
  finally:
    summaries_flag[:] = saved_contents
def always_record_summaries():
  """Sets the should_record_summaries Tensor to always true.

  Generator intended for use as a context manager. While suspended at
  `yield`, the `_SHOULD_RECORD_SUMMARIES_NAME` collection holds `[True]`.
  The prior contents are put back afterwards, even on error.

  Yields:
    None.
  """
  summaries_flag = ops.get_collection_ref(_SHOULD_RECORD_SUMMARIES_NAME)
  saved_contents = list(summaries_flag)
  try:
    summaries_flag[:] = [True]
    yield
  finally:
    summaries_flag[:] = saved_contents
def always_record_summaries():
  """Sets the should_record_summaries Tensor to always true.

  Generator intended for use as a context manager. The
  `_SHOULD_RECORD_SUMMARIES_NAME` collection is overridden with `[True]`
  while suspended at `yield`, and its earlier contents are reinstated on
  exit, including exit by exception.

  Yields:
    None.
  """
  flag_list = ops.get_collection_ref(_SHOULD_RECORD_SUMMARIES_NAME)
  previous = list(flag_list)
  try:
    flag_list[:] = [True]
    yield
  finally:
    flag_list[:] = previous
def _apply_averages():  # pylint: disable=missing-docstring
  # Builds the synchronous model-averaging step. Each local variable (with a
  # non-None gradient) is pushed into a ConditionalAccumulator placed with its
  # "global_model" counterpart. The chief averages the accumulators into the
  # global vars, bumps `self._global_step`, and enqueues
  # `self._tokens_per_step` tokens. Non-chief workers block dequeuing a
  # token. Finally, every worker copies the global values back into its local
  # vars via `self._assign_vars`, and that op is returned.
  # NOTE(review): local and global vars are paired positionally with zip();
  # this assumes the "global_model" collection is ordered like
  # `grads_and_vars`. Confirm at the call site.
  # NOTE(review): the sync queue is colocated with `self._global_step`, while
  # the sync/update ops use `global_step.device` (closure variable). Confirm
  # that these are meant to be the same device.
  # Collect local and global vars
  local_vars = [v for g, v in grads_and_vars if g is not None]
  global_vars = ops.get_collection_ref("global_model")
  # sync queue, place it in the ps
  with ops.colocate_with(self._global_step):
    sync_queue = data_flow_ops.FIFOQueue(
        -1, [dtypes.bool], shapes=[[]], shared_name="sync_queue")
  train_ops = []
  aggregated_vars = []
  with ops.name_scope(None, self._name + "/global"):
    for var, gvar in zip(local_vars, global_vars):
      # pylint: disable=protected-access
      # Get reference to the tensor, this works with Variable and ResourceVariable
      var = ops.convert_to_tensor(var)
      # Place the accumulator in the same ps as the corresponding global_var
      with ops.device(gvar.device):
        var_accum = data_flow_ops.ConditionalAccumulator(
            var.dtype,
            shape=var.get_shape(),
            shared_name=gvar.name + "/var_accum")
        # Add op to push local_var to accumulator
        train_ops.append(
            var_accum.apply_grad(var, local_step=global_step))
        # Op to average the vars in the accumulator
        aggregated_vars.append(var_accum.take_grad(self._replicas_to_aggregate))
        # Remember accumulator and corresponding device
        self._accumulator_list.append((var_accum, gvar.device))
  # chief worker updates global vars and enqueues tokens to the sync queue
  if self._is_chief:
    update_ops = []
    # Make sure train_ops are run
    with ops.control_dependencies(train_ops):
      # Update global_vars with average values
      for avg_var, gvar in zip(aggregated_vars, global_vars):
        with ops.device(gvar.device):
          update_ops.append(state_ops.assign(gvar, avg_var))
      # Update shared global_step
      with ops.device(global_step.device):
        update_ops.append(state_ops.assign_add(self._global_step, 1))
    # After averaging, push tokens to the queue
    with ops.control_dependencies(update_ops), ops.device(
        global_step.device):
      tokens = array_ops.fill([self._tokens_per_step],
                              constant_op.constant(False))
      sync_op = sync_queue.enqueue_many(tokens)
  # non chief workers deque a token, they will block here until chief is done
  else:
    # Make sure train_ops are run
    with ops.control_dependencies(train_ops), ops.device(
        global_step.device):
      sync_op = sync_queue.dequeue()

  # All workers pull averaged values
  with ops.control_dependencies([sync_op]):
    local_update_op = self._assign_vars(local_vars, global_vars)
  return local_update_op
def _add_elements_to_collection(elements, collections):
  """Adds each element to each named graph collection, skipping duplicates.

  Args:
    elements: An element or list of elements (normalized by `_to_list`).
    collections: A collection name or list of names (normalized by
      `_to_list`).
  """
  elements = _to_list(elements)
  collections = _to_list(collections)
  for name in collections:
    collection = ops.get_collection_ref(name)
    collection_set = set(collection)
    for element in elements:
      if element not in collection_set:
        collection.append(element)
        # Track newly appended elements too. Otherwise an element repeated
        # within `elements` would be appended more than once.
        collection_set.add(element)
def _add_elements_to_collection(elements, collection_list):
  """Adds each element to each named graph collection, skipping duplicates.

  Args:
    elements: An element or list of elements (normalized by `_to_list`).
    collection_list: A collection name or list of names (normalized by
      `_to_list`).
  """
  elements = _to_list(elements)
  collection_list = _to_list(collection_list)
  for name in collection_list:
    collection = ops.get_collection_ref(name)
    collection_set = set(collection)
    for element in elements:
      if element not in collection_set:
        collection.append(element)
        # Track newly appended elements too. Otherwise an element repeated
        # within `elements` would be appended more than once.
        collection_set.add(element)
def record_summaries_every_n_global_steps(n, global_step=None):
  """Sets the should_record_summaries Tensor to true if global_step % n == 0.

  Generator intended for use as a context manager. While suspended at
  `yield`, the `_SHOULD_RECORD_SUMMARIES_NAME` collection holds a boolean
  tensor (created under device "cpu:0") that is true when
  `global_step % n == 0`. The previous contents are restored afterwards, even
  if the managed block raises.

  Args:
    n: Summaries are recorded on steps where `global_step % n == 0`.
    global_step: Optional step tensor/variable. Defaults to
      `training_util.get_or_create_global_step()`.

  Yields:
    None.
  """
  if global_step is None:
    global_step = training_util.get_or_create_global_step()
  collection_ref = ops.get_collection_ref(_SHOULD_RECORD_SUMMARIES_NAME)
  old = collection_ref[:]
  try:
    with ops.device("cpu:0"):
      collection_ref[:] = [math_ops.equal(global_step % n, 0)]
    yield
  finally:
    # Restore even if the managed block raises.
    collection_ref[:] = old
def testPruningNested(self):
  """A TensorList pushed inside a nested loop is pruned unless fetched."""
  assert control_flow_util_v2._EXPERIMENTAL_OUTPUT_ALL_INTERMEDIATES_OVERRIDE
  x = constant_op.constant(0)
  tensor_list = list_ops.empty_tensor_list(
      element_dtype=x.dtype, element_shape=x.shape)

  def Cond(x, tl):
    del tl  # Unused for Cond.
    return x < 25

  def Body(x, tl):

    def InnerCond(inner_x, unused_outer_x, unused_tl):
      return inner_x < 5

    def InnerBody(inner_x, outer_x, tl):
      # Pushes the outer loop's `x` (captured from Body), not inner_x.
      return inner_x + 1, outer_x + 1, list_ops.tensor_list_push_back(tl, x)

    inner_x = constant_op.constant(0)
    return control_flow_ops.while_loop(InnerCond, InnerBody,
                                       [inner_x, x, tl])[1:]

  outputs = control_flow_ops.while_loop(Cond, Body, [x, tensor_list])

  variant_enum = dtypes.variant.as_datatype_enum

  def VariantEnters(graph):
    # Enter nodes carrying a variant (TensorList) value.
    return [
        n for n in graph.node
        if n.op == "Enter" and n.attr["T"].type == variant_enum
    ]

  def NodesOfType(graph, op_type):
    return [n for n in graph.node if n.op == op_type]

  train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
  train_op.append(outputs[0])

  g = GetOptimizedGraph()
  # TODO(b/136034023): while_v2 adds an extra loop_counter which is not pruned
  # away, causing an extra Enter node, so the Enter count is not asserted.
  # Only outputs[0] is fetched: the TensorList should be pruned out.
  self.assertEmpty(VariantEnters(g))
  self.assertEmpty(NodesOfType(g, "TensorListPushBack"))
  self.assertEmpty(NodesOfType(g, "_While"))

  stack = list_ops.tensor_list_stack(outputs[1], element_dtype=x.dtype)
  train_op.append(stack)
  g = GetOptimizedGraph()
  # TODO(b/136034023): see above; the Enter count is likewise not asserted.
  # The stacked list is now fetched: the TensorList must survive.
  self.assertNotEmpty(VariantEnters(g))
  self.assertNotEmpty(NodesOfType(g, "TensorListPushBack"))
def testImportantOps(self):
  """All three ops of `10 + 20` are reported as important."""
  with ops.Graph().as_default() as graph:
    # Creation order fixes the op names: Const, Const_1, add.
    total = constant_op.constant(10) + constant_op.constant(20)
    ops.get_collection_ref(ops.GraphKeys.TRAIN_OP).append(total)
    meta_graph_def = meta_graph.create_meta_graph_def(graph=graph)
    grappler_item = item.Item(meta_graph_def)
    important = grappler_item.IdentifyImportantOps()
    self.assertItemsEqual(['Const', 'Const_1', 'add'], important)
def layer_with_recompute(inputs, is_recomputing=False):
  """Dense(2) followed by training-mode batch norm, recording the flag.

  Appends `is_recomputing` to the enclosing `kwarg_values` list. When
  recomputing, removes the last two UPDATE_OPS entries so that the
  batch-norm updates added by this call are not duplicated.
  """
  kwarg_values.append(is_recomputing)
  hidden = core_layers.dense(inputs, 2)
  hidden = normalization_layers.batch_normalization(hidden, training=True)
  if is_recomputing:
    update_ops = ops.get_collection_ref(ops.GraphKeys.UPDATE_OPS)
    for _ in range(2):
      update_ops.pop()
  return hidden
def layer_with_recompute(inputs, is_recomputing=False):
  """Dense(2) then training-mode batch norm; logs whether this is a recompute.

  Appends `is_recomputing` to the enclosing `kwarg_values` list. On a
  recompute, the two most recently added UPDATE_OPS entries are popped so
  that the batch-norm updates are not registered twice.
  """
  kwarg_values.append(is_recomputing)
  projected = core_layers.dense(inputs, 2)
  normalized = normalization_layers.batch_normalization(projected,
                                                        training=True)
  if not is_recomputing:
    return normalized
  update_ops = ops.get_collection_ref(ops.GraphKeys.UPDATE_OPS)
  update_ops.pop()
  update_ops.pop()
  return normalized
def testImportantOps(self):
  """IdentifyImportantOps returns the op names, as bytes, in graph order."""
  with ops.Graph().as_default() as graph:
    ten = constant_op.constant(10)
    twenty = constant_op.constant(20)
    total = ten + twenty
    train_ops = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
    train_ops.append(total)
    meta_graph_def = meta_graph.create_meta_graph_def(graph=graph)
    grappler_item = item.Item(meta_graph_def)
    important = grappler_item.IdentifyImportantOps()
    self.assertEqual([b'Const', b'Const_1', b'add'], important)
def record_summaries_every_n_global_steps(n, global_step=None):
  """Sets the should_record_summaries Tensor to true if global_step % n == 0.

  Generator intended for use as a context manager. Installs
  `math_ops.equal(global_step % n, 0)` (built under device "cpu:0") as the
  sole entry of the `_SHOULD_RECORD_SUMMARIES_NAME` collection while
  suspended at `yield`, and reinstates the earlier entries afterwards, even on
  error.

  Args:
    n: Summaries are recorded on steps where `global_step % n == 0`.
    global_step: Optional step tensor/variable. Defaults to
      `training_util.get_or_create_global_step()`.

  Yields:
    None.
  """
  step = (training_util.get_or_create_global_step()
          if global_step is None else global_step)
  flag_list = ops.get_collection_ref(_SHOULD_RECORD_SUMMARIES_NAME)
  previous = list(flag_list)
  try:
    with ops.device("cpu:0"):
      flag_list[:] = [math_ops.equal(step % n, 0)]
    yield
  finally:
    flag_list[:] = previous
def testRange(self):
  """Dataset.range yields scalars according to grappler's shape inference."""
  with ops.Graph().as_default() as graph:
    range_iterator = dataset_ops.Dataset.range(42).make_one_shot_iterator()
    next_element = range_iterator.get_next()
    ops.get_collection_ref(ops.GraphKeys.TRAIN_OP).append(next_element)
    meta_graph_def = meta_graph.create_meta_graph_def(graph=graph)
    grappler_item = item.Item(meta_graph_def)
    properties = grappler_item.GetOpProperties()
    self.assertEqual(tensor_shape.scalar(),
                     properties['IteratorGetNext'][0].shape)
def testRange(self):
  """The IteratorGetNext output of a range dataset has a scalar shape."""
  with ops.Graph().as_default() as graph:
    iterator = dataset_ops.Dataset.range(42).make_one_shot_iterator()
    next_element = iterator.get_next()
    ops.get_collection_ref(ops.GraphKeys.TRAIN_OP).append(next_element)
    meta = meta_graph.create_meta_graph_def(graph=graph)
    gitem = item.Item(meta)
    properties = gitem.GetOpProperties()
    first_output = properties['IteratorGetNext'][0]
    self.assertEqual(tensor_shape.scalar(), first_output.shape)
def _lift_unlifted_variables(self):
  """Finds resource variables and lifts them into the outer context.

  Importing a GraphDef inside a wrap_function runs no Python graph-building
  code. The resulting VarHandleOps therefore create variable resources
  without any matching Python variable objects, leaving the user no way to
  inspect or modify them from outside the graph.

  This method scans the GLOBAL_VARIABLES and LOCAL_VARIABLES collections. For
  each graph-mode `ResourceVariable` whose handle is not already captured
  (per `self.graph.internal_captures`, plus those handled earlier in this
  call), it does the following:

  - Creates a `def_function.UnliftedInitializerVariable` with the same name,
    shape, dtype and trainability.
  - Maps the new handle to the old handle in `self.graph.captures`.
  - Registers the new variable with the variable holder, keyed by the part of
    its name before the first ":", and in the graph's weak variable list.

  Finally both collections are rewritten in place to refer to the lifted
  variables.
  """
  with self.graph.as_default():
    candidates = (
        ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) +
        ops.get_collection(ops.GraphKeys.LOCAL_VARIABLES))
    captured_handles = set(self.graph.internal_captures)
    replacements = {}
    for var in candidates:
      # Only graph-mode resource variables are candidates for lifting.
      if not var._in_graph_mode:  # pylint: disable=protected-access
        continue
      if not isinstance(var, resource_variable_ops.ResourceVariable):
        continue
      if var.handle in captured_handles:
        continue
      # The initial value is a never-fed placeholder; it only supplies the
      # shape and dtype for the lifted variable.
      initializer_placeholder = array_ops.placeholder(
          name="unused_{}_initializer".format(var.op.name),
          shape=var.shape,
          dtype=var.dtype)
      lifted = def_function.UnliftedInitializerVariable(
          initializer_placeholder,
          name=var.op.name,
          trainable=var.trainable)
      self.graph.captures[lifted.handle] = var.handle
      captured_handles.add(var.handle)
      replacements[var] = lifted
      # pylint: disable=protected-access
      short_name = lifted.name.split(":")[0]
      self._variable_holder._variables_by_name[short_name] = lifted
      self.graph._weak_variables.append(weakref.ref(lifted))
      # pylint: enable=protected-access
    # Update the graph's collections, partly for the user and partly so this
    # function is idempotent when it runs again in prune() calls.
    for key in (ops.GraphKeys.GLOBAL_VARIABLES,
                ops.GraphKeys.LOCAL_VARIABLES):
      collection = ops.get_collection_ref(key)
      collection[:] = [replacements.get(entry, entry) for entry in collection]
def _testPruning(self):
  """Checks grappler pruning of a TensorList carried through a while loop.

  While only the loop counter is fetched, no variant-typed Enter node should
  survive. Once the list is stacked and fetched too, it must be kept.
  """
  x = constant_op.constant(1)
  tensor_list = list_ops.empty_tensor_list(element_dtype=x.dtype,
                                           element_shape=x.shape)

  def Cond(x, tl):
    del tl  # Unused for Cond.
    return x < 5

  def Body(x, tl):
    return x + 1, list_ops.tensor_list_push_back(tl, x)

  loop_outputs = control_flow_ops.while_loop(Cond, Body, [x, tensor_list])
  fetches = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
  fetches.append(loop_outputs[0])

  def optimized_graph():
    """Runs grappler on the default graph with constant folding disabled."""
    meta = meta_graph.create_meta_graph_def(graph=ops.get_default_graph())
    session_config = config_pb2.ConfigProto()
    session_config.graph_options.rewrite_options.CopyFrom(
        rewriter_config_pb2.RewriterConfig(
            constant_folding=rewriter_config_pb2.RewriterConfig.OFF,
            memory_optimization=rewriter_config_pb2.RewriterConfig.MANUAL))
    return tf_optimizer.OptimizeGraph(session_config, meta)

  def enter_nodes(graph_def):
    """All Enter nodes of `graph_def`."""
    return [n for n in graph_def.node if n.op == "Enter"]

  def variant_enter_nodes(graph_def):
    """Enter nodes whose "T" attr is the variant dtype (i.e. the list)."""
    variant_enum = dtypes.variant.as_datatype_enum
    return [n for n in enter_nodes(graph_def)
            if n.attr["T"].type == variant_enum]

  graph_def = optimized_graph()
  # TODO(b/136034023): while_v2 adds an extra loop_counter which is not pruned
  # away, causing an extra Enter node.
  if control_flow_util.ENABLE_CONTROL_FLOW_V2:
    expected_enters = 2
  else:
    expected_enters = 1
  self.assertLen(enter_nodes(graph_def), expected_enters)
  # Test that the TensorList is pruned out.
  self.assertEmpty(variant_enter_nodes(graph_def))

  stacked = list_ops.tensor_list_stack(loop_outputs[1], element_dtype=x.dtype)
  fetches.append(stacked)
  graph_def = optimized_graph()
  # TODO(b/136034023): while_v2 adds an extra loop_counter which is not pruned
  # away, causing an extra Enter node.
  if control_flow_util.ENABLE_CONTROL_FLOW_V2:
    expected_enters = 3
  else:
    expected_enters = 2
  self.assertLen(enter_nodes(graph_def), expected_enters)
  # Test that the TensorList is not pruned out.
  self.assertNotEmpty(variant_enter_nodes(graph_def))
def _add_elements_to_collection(elements, collection_list):
  """Appends each element to each named graph collection, skipping duplicates.

  An element already present in a collection (by object identity) is not
  appended again. This also holds for an element that appears more than once
  in `elements`.

  Args:
    elements: A single element or a nested structure of elements, flattened
      with `nest.flatten`.
    collection_list: A single collection name or a nested structure of names,
      flattened with `nest.flatten`.

  Raises:
    RuntimeError: If called while executing eagerly.
  """
  if context.executing_eagerly():
    raise RuntimeError('Using collections from Layers not supported in Eager '
                       'mode. Tried to add %s to %s' % (elements,
                                                        collection_list))
  elements = nest.flatten(elements)
  collection_list = nest.flatten(collection_list)
  for name in collection_list:
    collection = ops.get_collection_ref(name)
    # Membership is tracked by id():
    # - It does not depend on the __hash__/__eq__ of tensors or variables.
    # - The set is updated on every append. The previous version built it once
    #   and never refreshed it, so an element repeated in `elements` was
    #   appended twice.
    present_ids = {id(entry) for entry in collection}
    for element in elements:
      if id(element) not in present_ids:
        collection.append(element)
        present_ids.add(id(element))
def _add_elements_to_collection(elements, collection_list):
  """Appends each element to each named graph collection, skipping duplicates.

  An element already present in a collection (by object identity) is not
  appended again. This also holds for an element that appears more than once
  in `elements`.

  Args:
    elements: A single element or a nested structure of elements, flattened
      with `nest.flatten`.
    collection_list: A single collection name or a nested structure of names,
      flattened with `nest.flatten`.

  Raises:
    RuntimeError: If called while executing eagerly.
  """
  if context.executing_eagerly():
    raise RuntimeError('Using collections from Layers not supported in Eager '
                       'mode. Tried to add %s to %s' % (elements,
                                                        collection_list))
  elements = nest.flatten(elements)
  collection_list = nest.flatten(collection_list)
  for name in collection_list:
    collection = ops.get_collection_ref(name)
    # Membership is tracked by id():
    # - It does not depend on the __hash__/__eq__ of tensors or variables.
    # - The set is updated on every append. The previous version built it once
    #   and never refreshed it, so an element repeated in `elements` was
    #   appended twice.
    present_ids = {id(entry) for entry in collection}
    for element in elements:
      if id(element) not in present_ids:
        collection.append(element)
        present_ids.add(id(element))
def _apply_averages():  # pylint: disable=missing-docstring
  """Builds the ops that synchronize local variables with the global model.

  Relies on `grads_and_vars` and `global_step` from the enclosing scope.

  Steps:
  - Every variable with a non-None gradient is pushed into a per-variable
    `ConditionalAccumulator` placed on its global counterpart's device.
  - The chief assigns the accumulated results to the global variables,
    increments `self._global_step`, and enqueues `self._tokens_per_step`
    tokens on a shared FIFO queue.
  - Non-chief workers dequeue one token.
  - Finally `self._assign_vars(local_vars, global_vars)` runs after that
    synchronization op.

  Returns:
    The op produced by `self._assign_vars(local_vars, global_vars)`.

  Raises:
    ValueError: If a local variable's `_ref()` is not an `ops.Tensor`.
  """
  # Only variables that actually received a gradient take part.
  local_vars = [v for g, v in grads_and_vars if g is not None]
  # NOTE(review): local and global variables are paired positionally via zip
  # below; this assumes "global_model" lists them in the same order -- confirm.
  global_vars = ops.get_collection_ref("global_model")
  # sync queue: an unbounded queue of scalar bools, colocated with
  # self._global_step and looked up by shared_name "sync_queue".
  with ops.colocate_with(self._global_step):
    sync_queue = data_flow_ops.FIFOQueue(-1, [dtypes.bool], shapes=[[]],
                                         shared_name="sync_queue")
  train_ops = []
  aggregated_vars = []
  with ops.name_scope(None, self._name + "/global"):
    for var, gvar in zip(local_vars, global_vars):
      # pylint: disable=protected-access
      with ops.device(gvar.device):
        if isinstance(var._ref(), ops.Tensor):
          # One accumulator per global variable, keyed by shared_name so that
          # workers presumably contribute to the same one -- confirm.
          var_accum = data_flow_ops.ConditionalAccumulator(
              var.dtype,
              shape=var.get_shape(),
              shared_name=gvar.name + "/var_accum")
          # The local variable's value (not its gradient) is what gets
          # accumulated.
          train_ops.append(
              var_accum.apply_grad(var._ref(), local_step=global_step))
          # take_grad presumably waits for _replicas_to_aggregate values and
          # returns their aggregate -- confirm against ConditionalAccumulator.
          aggregated_vars.append(
              var_accum.take_grad(
                  self._replicas_to_aggregate))
        else:
          raise ValueError("Unknown local variable type!")
        self._accumulator_list.append((var_accum, gvar.device))
  # chief worker updates global vars and enqueues tokens to the sync queue
  if self._is_chief:
    update_ops = []
    # Global updates run only after this worker's accumulator pushes.
    with ops.control_dependencies(train_ops):
      for avg_var, gvar in zip(aggregated_vars, global_vars):
        with ops.device(gvar.device):
          update_ops.append(state_ops.assign(gvar, avg_var))
      # NOTE(review): the device comes from `global_step` while the variable
      # incremented is `self._global_step`; presumably the same variable --
      # confirm.
      with ops.device(global_step.device):
        update_ops.append(
            state_ops.assign_add(self._global_step, 1))
    # Release waiting workers: one False token each, after all updates.
    with ops.control_dependencies(update_ops), ops.device(
        global_step.device):
      tokens = array_ops.fill([self._tokens_per_step],
                              constant_op.constant(False))
      sync_op = sync_queue.enqueue_many(tokens)
  else:
    # Non-chief workers block on a token from the chief after pushing values.
    with ops.control_dependencies(train_ops), ops.device(
        global_step.device):
      sync_op = sync_queue.dequeue()
  # Local variables are refreshed from the global ones only after the sync.
  with ops.control_dependencies([sync_op]):
    local_update_op = self._assign_vars(local_vars, global_vars)
  return local_update_op
def apply_gradients(self, grads_and_vars, global_step=None, name=None):
  """Applies gradients and additionally runs this optimizer's aggregate update.

  Validates `grads_and_vars` the same way the base `apply_gradients` does,
  then:
  - Creates slots for every variable that received a gradient.
  - Builds the aggregate op via `self._prepare_aggregates`.
  - Groups it with the parent class's (non-aggregate) update op.
  - Registers the aggregate op in the TRAIN_OP collection, at most once.

  Args:
    grads_and_vars: Iterable of `(gradient, variable)` pairs; gradients may
      be `None`.
    global_step: Optional step variable, forwarded to the parent
      `apply_gradients`.
    name: Optional name scope; defaults to `self._name`.

  Returns:
    An op grouping the aggregate update and the parent's update op.

  Raises:
    ValueError: If `grads_and_vars` is empty or every gradient is `None`.
    TypeError: If a gradient is not convertible to a `Tensor` or
      `IndexedSlices`.
  """
  # Do error checking and create slots. Error checking is copied from base
  # apply_gradients.
  # Materialize the iterable so it can be traversed more than once.
  grads_and_vars = tuple(grads_and_vars)
  if not grads_and_vars:
    raise ValueError("No variables provided.")
  converted_grads_and_vars = []
  for g, v in grads_and_vars:
    if g is not None:
      try:
        # Convert the grad to Tensor or IndexedSlices if necessary.
        g = ops.convert_to_tensor_or_indexed_slices(g)
      except TypeError:
        raise TypeError("Gradient must be convertible to a Tensor"
                        " or IndexedSlices, or None: %s" % g)
      if not isinstance(g, (ops.Tensor, ops.IndexedSlices)):
        raise TypeError(
            "Gradient must be a Tensor, IndexedSlices, or None: %s" % g)
    converted_grads_and_vars.append((g, v))

  converted_grads_and_vars = tuple(converted_grads_and_vars)
  var_list = [v for g, v in converted_grads_and_vars if g is not None]
  if not var_list:
    # The entries are (grad, var) pairs. The former `for _, _, v in ...`
    # unpack raised an unrelated "not enough values to unpack" ValueError
    # instead of this message.
    raise ValueError("No gradients provided for any variable: %s." %
                     ([str(v) for _, v in converted_grads_and_vars],))
  with ops.control_dependencies(None):
    self._create_slots([_get_variable_for(v) for v in var_list])
  ##### end copypasta code #####
  with ops.name_scope(name, self._name) as name:
    aggregate_op = self._prepare_aggregates(converted_grads_and_vars)
    non_aggregate_updates = super(OptimizerWithAggregates,
                                  self).apply_gradients(
                                      grads_and_vars,
                                      global_step=global_step,
                                      name=name)
    apply_updates = control_flow_ops.group(aggregate_op,
                                           non_aggregate_updates)
    # Expose the aggregate op via TRAIN_OP, without adding it twice.
    train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
    if aggregate_op not in train_op:
      train_op.append(aggregate_op)
  return apply_updates
def testColocationContraints(self):
  """An assign through a ref identity yields a single colocation group."""
  with ops.Graph().as_default() as graph:
    value = constant_op.constant([10])
    var = variables.VariableV1([3], dtype=dtypes.int32)
    ref = gen_array_ops.ref_identity(var)
    assign = state_ops.assign(ref, value)
    ops.get_collection_ref(ops.GraphKeys.TRAIN_OP).append(assign)
    meta = meta_graph.create_meta_graph_def(graph=graph)
    gitem = item.Item(meta)
    colocation_groups = gitem.GetColocationGroups()
    self.assertEqual(len(colocation_groups), 1)
    expected_members = ['Assign', 'RefIdentity', 'Variable', 'Variable/Assign']
    self.assertItemsEqual(colocation_groups[0], expected_members)
def testColocationContraints(self):
  """An assign through a ref identity yields a single colocation group."""
  with ops.Graph().as_default() as graph:
    value = constant_op.constant([10])
    var = variables.Variable([3], dtype=dtypes.int32)
    ref = gen_array_ops._ref_identity(var)
    assign = state_ops.assign(ref, value)
    ops.get_collection_ref(ops.GraphKeys.TRAIN_OP).append(assign)
    meta = meta_graph.create_meta_graph_def(graph=graph)
    gitem = item.Item(meta)
    colocation_groups = gitem.GetColocationGroups()
    self.assertEqual(len(colocation_groups), 1)
    expected_members = ['Assign', 'RefIdentity', 'Variable', 'Variable/Assign']
    self.assertItemsEqual(colocation_groups[0], expected_members)
def testContext(self):
  """A provisioned cluster with detailed stats reports per-op costs."""
  with ops.Graph().as_default() as graph:
    first = random_ops.random_uniform(shape=())
    second = random_ops.random_uniform(shape=())
    ops.get_collection_ref(ops.GraphKeys.TRAIN_OP).append(first + second)
    meta = meta_graph.create_meta_graph_def(graph=graph)
    gitem = item.Item(meta)
  with cluster.Provision(
      disable_detailed_stats=False, disable_timeline=False) as gcluster:
    op_perfs, run_time, step_stats = gcluster.MeasureCosts(gitem)
    self.assertGreater(run_time, 0)
    self.assertEqual(len(op_perfs), 4)
    self.assertTrue(step_stats.dev_stats)
def testContext(self):
  """A provisioned cluster with detailed stats reports per-op costs."""
  with ops.Graph().as_default() as graph:
    first = random_ops.random_uniform(shape=())
    second = random_ops.random_uniform(shape=())
    ops.get_collection_ref(ops.GraphKeys.TRAIN_OP).append(first + second)
    meta = meta_graph.create_meta_graph_def(graph=graph)
    gitem = item.Item(meta)
  with cluster.Provision(
      disable_detailed_stats=False, disable_timeline=False) as gcluster:
    op_perfs, run_time, step_stats = gcluster.MeasureCosts(gitem)
    self.assertGreater(run_time, 0)
    self.assertEqual(len(op_perfs), 7)
    self.assertTrue(step_stats.dev_stats)
def testNoDetailedStats(self):
  """With detailed stats disabled, no per-op perfs or device stats appear."""
  with ops.Graph().as_default() as graph:
    first = random_ops.random_uniform(shape=())
    second = random_ops.random_uniform(shape=())
    ops.get_collection_ref(ops.GraphKeys.TRAIN_OP).append(first + second)
    meta = meta_graph.create_meta_graph_def(graph=graph)
    gitem = item.Item(meta)
    stats_free_cluster = cluster.Cluster(disable_detailed_stats=True)
    op_perfs, run_time, step_stats = stats_free_cluster.MeasureCosts(gitem)
    self.assertGreater(run_time, 0)
    self.assertEqual(len(op_perfs), 0)
    self.assertEqual(len(step_stats.dev_stats), 0)
def _set_placeholder(cls, placeholder):
  """Registers a `tf.placeholder` to be fed by the first SummarySaverHook.

  The first SummarySaverHook in use feeds this placeholder a boolean saying
  whether summaries should be written on the current step, following that
  hook's `save_steps` / `save_secs` schedule. Passing the placeholder to
  `tf.contrib.summary.record_summaries_if` therefore puts `tf.contrib.summary`
  writing on the same schedule as the hook-controlled `tf.summary` writing.

  Any placeholder registered earlier is replaced.

  Args:
    placeholder: `tf.placeholder` for the first SummarySaverHook to feed.
  """
  placeholders = ops.get_collection_ref(cls._SUMMARY_PLACEHOLDER_COLLECTION)
  del placeholders[:]
  placeholders.append(placeholder)
def testNoDetailedStats(self):
  """With detailed stats disabled, no per-op perfs or device stats appear."""
  with ops.Graph().as_default() as graph:
    first = random_ops.random_uniform(shape=())
    second = random_ops.random_uniform(shape=())
    ops.get_collection_ref(ops.GraphKeys.TRAIN_OP).append(first + second)
    meta = meta_graph.create_meta_graph_def(graph=graph)
    gitem = item.Item(meta)
    stats_free_cluster = cluster.Cluster(disable_detailed_stats=True)
    op_perfs, run_time, step_stats = stats_free_cluster.MeasureCosts(gitem)
    self.assertGreater(run_time, 0)
    self.assertEqual(len(op_perfs), 0)
    self.assertEqual(len(step_stats.dev_stats), 0)
def testSupportDevices(self):
  """GetSupportedDevices on a virtual CPU+GPU cluster and on the real one."""
  gpu_type = test_util.gpu_device_type()
  gpu_name = test_util.gpu_device_name()
  with ops.Graph().as_default() as graph:
    lhs = random_ops.random_uniform(shape=(2, 3))
    rhs = random_ops.random_uniform(shape=(2, 3))
    total = lhs + rhs
    axes = math_ops.range(0, array_ops.rank(total), 1)
    reduced = math_ops.reduce_sum(lhs, axis=axes)
    ops.get_collection_ref(ops.GraphKeys.TRAIN_OP).append(reduced)
    meta = meta_graph.create_meta_graph_def(graph=graph)
    gitem = item.Item(meta)

    def named_device(dev_type, frequency, num_cores, dev_name):
      """Builds a NamedDevice proto with the given properties."""
      return device_properties_pb2.NamedDevice(
          properties=device_properties_pb2.DeviceProperties(
              type=dev_type, frequency=frequency, num_cores=num_cores),
          name=dev_name)

    gpu_device = named_device(gpu_type, 1000, 60, gpu_name)
    cpu_device = named_device('CPU', 3000, 6, '/CPU:0')
    virtual_cluster = cluster.Cluster(devices=[cpu_device, gpu_device])
    supported = virtual_cluster.GetSupportedDevices(gitem)
    for op_name in ('add', 'Sum', 'range'):
      self.assertEqual(supported[op_name], ['/CPU:0', gpu_name])

    real_cluster = cluster.Cluster()
    supported = real_cluster.GetSupportedDevices(gitem)
    host_cpu = '/job:localhost/replica:0/task:0/device:CPU:0'
    host_prefix = '/job:localhost/replica:0/task:0'
    if not test.is_gpu_available():
      self.assertEqual(supported['add'], [host_cpu])
    else:
      for op_name in ('add', 'Sum'):
        self.assertEqual(supported[op_name], [host_cpu, host_prefix + gpu_name])
      # The axis tensor must reside on the host
      self.assertEqual(supported['range'], [host_cpu])
def testNoSwapping(self):
  """Make sure the graph is preserved when there is nothing to swap."""
  const_a = constant_op.constant(10, name='a')
  const_b = constant_op.constant(20, name='b')
  sum_c = math_ops.add_n([const_a, const_b], name='c')
  sum_d = math_ops.add_n([const_b, sum_c], name='d')
  ops.get_collection_ref(ops.GraphKeys.TRAIN_OP).append(sum_d)
  meta = meta_graph.create_meta_graph_def(graph=ops.get_default_graph())
  rewrite_config = rewriter_config_pb2.RewriterConfig(
      memory_optimization=rewriter_config_pb2.RewriterConfig.MANUAL)
  optimized = tf_optimizer.OptimizeGraph(rewrite_config, meta)
  self.assertEqual(len(optimized.node), 4)
  surviving_names = [n.name for n in optimized.node]
  self.assertItemsEqual(surviving_names, ['a', 'b', 'c', 'd'])
def testDebugMode(self):
  """Make sure arguments can be passed correctly."""
  lhs = constant_op.constant([10, 11], name="a")
  rhs = constant_op.constant([10], name="b")
  total = math_ops.add(lhs, rhs, name="c")
  ops.get_collection_ref(ops.GraphKeys.TRAIN_OP).append(total)
  meta = meta_graph.create_meta_graph_def(graph=ops.get_default_graph())
  report = model_analyzer.GenerateModelReport(meta, debug=True)
  # Check the report headers.
  for header in (b"input 0 (int32) has known value",
                 b"input 1 (int32) has known value"):
    self.assertIn(header, report)
  # Also print the report to make it easier to debug.
  print("{}".format(report))
def testOpProperties(self):
  """Every node of a const+const+add graph has one scalar int32 output."""
  with ops.Graph().as_default() as graph:
    lhs = constant_op.constant(10)
    rhs = constant_op.constant(20)
    ops.get_collection_ref(ops.GraphKeys.TRAIN_OP).append(lhs + rhs)
    meta = meta_graph.create_meta_graph_def(graph=graph)
    gitem = item.Item(meta)
    properties = gitem.GetOpProperties()
    # All the nodes in this model have one scalar output.
    for node in gitem.metagraph.graph_def.node:
      outputs = properties[node.name]
      self.assertEqual(1, len(outputs))
      self.assertEqual(dtypes.int32, outputs[0].dtype)
      self.assertEqual(tensor_shape.scalar(), outputs[0].shape)