def testUsingInfeedQueueWithRegularizer(self):
  """Test that Layer regularizers can reference data created in loops."""
  with ops.Graph().as_default():

    def make_regularizer(scale):
      # Returns a regularizer closed over `scale`, which is a tensor
      # delivered through the infeed queue inside the training loop.
      def regularizer(inputs):
        return scale * math_ops.reduce_sum(math_ops.square(inputs))

      return regularizer

    def training_step(inputs, scale):
      # One conv layer whose kernel regularizer references loop data.
      outputs = convolutional.conv2d(
          inputs,
          filters=16,
          kernel_size=(3, 3),
          data_format="channels_first",
          kernel_regularizer=make_regularizer(scale))
      loss = math_ops.reduce_mean(math_ops.square(outputs))
      return loss.op

    inputs = array_ops.zeros(shape=(128, 32, 32, 16))
    scale = array_ops.ones(shape=())
    infeed = tpu_feed.InfeedQueue(
        tuple_types=[dtypes.float32, dtypes.float32],
        tuple_shapes=[inputs.shape, scale.shape])

    def loop():
      return training_loop.repeat(5, training_step, infeed_queue=infeed)

    # This should not throw an error.
    tpu.rewrite(loop)
def tpu_subgraph_predict():
  """Builds the TPU-rewritten prediction subgraph.

  When `use_bfloat16` is set, the rewrite happens inside a bfloat16
  scope; otherwise the subgraph is rewritten as-is. Closes over
  `preprocessed_inputs` and `true_image_shapes` from the enclosing scope.
  """
  rewrite_inputs = [preprocessed_inputs, true_image_shapes]
  if use_bfloat16:
    with bfloat16_scope():
      return tpu.rewrite(tpu_subgraph_predict_fn, rewrite_inputs)
  return tpu.rewrite(tpu_subgraph_predict_fn, rewrite_inputs)
def tpu_fn(*args, **kwargs):
  """Rewrites `tf_func` for TPU execution with flattened arguments.

  tpu.rewrite only accepts a list of tensors as input, so positional and
  keyword argument values are flattened into a single list that is used
  both to trace the concrete function and as the rewrite inputs.
  """
  # Hoisted: the original computed list(args) + list(kwargs.values())
  # twice; build the flat input list once and reuse it.
  # NOTE(review): this relies on dict insertion order of `kwargs` matching
  # the argument order the concrete function expects — confirm at callers.
  flat_inputs = list(args) + list(kwargs.values())
  concrete = tf_func.get_concrete_function(*flat_inputs)
  return tpu.rewrite(concrete.__call__, flat_inputs)
def run(self, fetches, feed_dict=None, options=None, run_metadata=None): from tensorflow.python.tpu import tpu # pylint: disable=g-import-not-at-top if self.topology is None: self.topology = super().run(tpu.initialize_system()) assert self.topology is not None fetch_mapper = session._FetchMapper.for_fetch(fetches) new_fetches = [] for fetch in fetch_mapper.unique_fetches(): if isinstance(fetch, ops.Operation): fetch = tpu.rewrite(lambda fetch=fetch: fetch) new_fetches.append(fetch) rewritten_fetches = fetch_mapper.build_results(new_fetches) return super().run(rewritten_fetches, feed_dict, options, run_metadata)
def test_tpu_rewrite_uses_xla_einsum(self):
  """Checks that tpu.rewrite of `do_einsum` leaves an einsum op in the graph."""
  with ops.Graph().as_default() as graph:
    tpu.rewrite(do_einsum)
    # Either a regular einsum or an XLA einsum op is acceptable.
    self.assertTrue(find_einsum(graph) or find_xla_einsum(graph))
def predict_tpu():
  """Runs the prediction subgraph through the TPU rewriter.

  Closes over `preprocessed_inputs` and `true_image_shapes` from the
  enclosing scope and feeds them to `predict_tpu_subgraph`.
  """
  rewrite_inputs = [preprocessed_inputs, true_image_shapes]
  return tpu.rewrite(predict_tpu_subgraph, rewrite_inputs)