def step(self, inputs, states=None):
    """Apply ``self.operator`` independently to every slice along axis 0.

    ``tf.map_fn`` walks the leading axis of ``inputs`` together with the
    leading axis of the previous state, calling the operator once per slice
    with ``streamable=False``.

    Args:
        inputs: (nested structure of) tensors mapped over along axis 0 —
            presumably of length ``self.size``; TODO confirm against callers.
        states: ``None`` or a 1-tuple whose element is the previous
            per-slice state. When ``None``, the operator's initial state is
            repeated ``self.size`` times along a new leading axis.

    Returns:
        ``(outputs, (next_state,), (prev_state,))``.
    """
    def tile_multiples(t):
        # Multiples for tf.tile: [self.size, 1, 1, ...], one trailing 1 per
        # axis of t, so only the new leading axis is repeated.
        trailing_ones = tf.ones([tf.size(tf.shape(t))], dtype=tf.int32)
        return tf.concat([[self.size], trailing_ones], axis=0)

    # Probe the operator on the first slice, without any state, to learn the
    # output dtypes and the operator's initial state.
    first_slice = map_fn(inputs, [inputs], lambda t: t[0])
    probe_output, _, probe_init_state = self.operator(
        first_slice, streamable=False)

    if states is None:
        tiled_state = map_fn(
            probe_init_state, [probe_init_state],
            lambda t: tf.tile(tf.expand_dims(t, axis=0), tile_multiples(t)))
        states = (tiled_state,)
    prev_state = states[0]

    # tf.map_fn needs the output dtypes up front.
    state_dtypes = map_fn(prev_state, [prev_state], lambda t: t.dtype)
    out_dtypes = map_fn(probe_output, [probe_output], lambda t: t.dtype)

    def run_slice(pair):
        slice_inputs, slice_state = pair
        slice_outputs, slice_next_state, _ = self.operator(
            inputs=slice_inputs,
            state=slice_state,
            streamable=False
        )
        return slice_outputs, slice_next_state

    outputs, next_state = tf.map_fn(
        run_slice, (inputs, prev_state), dtype=(out_dtypes, state_dtypes))
    return outputs, (next_state,), (prev_state,)
def __call__(self, inputs, state=None, streamable=True):
    """Run the operator on ``inputs``, streamed or as a single step.

    Every leaf of ``inputs`` (and of ``state``, when given) is converted to
    a tensor first, so Python literals are accepted.

    Args:
        inputs: (nested structure of) tensors or Python literals.
        state: optional initial state; ``None`` leaves the choice of initial
            state to the called method.
        streamable: when true, dispatch to ``self.call_streamed``;
            otherwise to ``self.forward_step``.

    Returns:
        Whatever ``self.call_streamed`` or ``self.forward_step`` returns.
    """
    # Preprocessing to convert python literals into tensors.
    inputs = map_fn(inputs, [inputs], tf.convert_to_tensor)
    if state is not None:
        state = map_fn(state, [state], tf.convert_to_tensor)
    if streamable:
        return self.call_streamed(inputs, state)
    return self.forward_step(inputs, state)
def loop(i, loop_inputs, loop_state, loop_outputs):
    """Process element ``i`` of the streamed inputs and advance the loop.

    This matches the nested ``loop`` passed to ``tf.while_loop`` in
    ``call_streamed``. ``self`` is not a parameter here, so this only works
    as a closure over an enclosing method's ``self``.

    Args:
        i: loop counter; the index of the element to process.
        loop_inputs: (nested) input tensors, indexed along their first axis.
        loop_state: state carried over from the previous iteration.
        loop_outputs: (nested) objects with ``.write(i, value)`` —
            presumably ``tf.TensorArray``s accumulating per-step outputs.

    Returns:
        The next loop variables ``(i + 1, loop_inputs, next_state, new_outputs)``.
    """
    # Slice element i out of every input tensor.
    inputs_i = map_fn(loop_inputs, [loop_inputs], lambda x: x[i])
    outputs_i, next_state, _ = self.forward_step(inputs_i, loop_state)
    # The carried state must keep identical dtypes from one iteration to the next.
    Streamable.check_types_equality(
        Streamable.to_type(loop_state), Streamable.to_type(next_state),
        "Provided state and model state has different types (%s != %s)."
    )
    # Write this step's outputs at position i of each accumulator.
    new_outputs = map_fn(outputs_i, [outputs_i, loop_outputs],
                         lambda x, y: y.write(i, x))
    # Give each next-state tensor the static shape of the incoming state.
    # The map_fn result (all None, since set_shape returns None) is
    # discarded. This relies on set_shape mutating in place, which only
    # reaches next_state when tf.convert_to_tensor(y) returns y itself,
    # i.e. when y is already a Tensor.
    # NOTE(review): presumably this keeps loop-variable shapes invariant
    # for tf.while_loop — confirm.
    map_fn(
        loop_state, [loop_state, next_state],
        lambda x, y: tf.convert_to_tensor(y).set_shape(
            tf.convert_to_tensor(x).get_shape()))
    return (i + 1, loop_inputs, next_state, new_outputs)
def step(self, inputs, states=None):
    """Replace every input with ``self.value`` for the first ``self.period`` steps.

    Args:
        inputs: (nested structure of) tensors.
        states: ``None`` to start counting at 0, or a 1-tuple holding the
            current iteration count.

    Returns:
        ``(outputs, (iteration + 1,), (iteration,))``. ``outputs`` has the
        structure of ``inputs``. While ``iteration < self.period``, each leaf
        is a tensor of the same shape filled with ``self.value``. After that,
        each leaf is passed through unchanged.

    Note:
        The fill value is cast to each input's dtype so that both
        ``tf.cond`` branches agree. A fractional ``self.value`` is therefore
        truncated for integer inputs; previously that case raised a dtype
        mismatch.
    """
    if states is None:
        states = (0, )
    iteration = states[0]

    def mask(x):
        def filled():
            # tf.fill takes its dtype from the fill value. Without the cast,
            # e.g. value=0 on float32 inputs yields int32, and tf.cond
            # rejects branches with different dtypes.
            dtype = tf.convert_to_tensor(x).dtype
            return tf.fill(tf.shape(x), tf.cast(self.value, dtype))

        return tf.cond(tf.less(iteration, self.period), filled, lambda: x)

    outputs = map_fn(inputs, [inputs], mask)
    return outputs, (iteration + 1, ), (iteration, )
def any_value(inputs, value):
    """Return a boolean tensor telling whether any tensor in ``inputs`` contains ``value``.

    Each leaf of the (possibly nested) ``inputs`` structure is tested with
    ``tensor_has_value``, and the per-leaf results are OR-ed together.
    Because the fold starts from ``tf.constant(False)``, an empty structure
    yields ``False``.
    """
    per_leaf = map_fn(inputs, [inputs],
                      lambda leaf: tensor_has_value(leaf, value))
    return functools.reduce(tf.logical_or, flatten(per_leaf),
                            tf.constant(False))
def any_nan(inputs):
    """Return a boolean tensor telling whether any tensor in ``inputs`` holds a NaN.

    Each leaf of the (possibly nested) ``inputs`` structure is tested with
    ``tensor_has_nan``, and the results are OR-ed together starting from
    ``tf.constant(False)``, so an empty structure yields ``False``.
    """
    nan_flags = flatten(map_fn(inputs, [inputs], tensor_has_nan))
    return functools.reduce(tf.logical_or, nan_flags, tf.constant(False))
def test_map_fn_complex_type(self):
    """map_fn keeps mixed dict/list nesting and transforms every leaf."""
    nested = {'a': [{'x': 4, 'y': 3}, 3], 'b': {'z': [1, 3]}}
    doubled = {'a': [{'x': 8, 'y': 6}, 6], 'b': {'z': [2, 6]}}
    result = map_fn(nested, [nested], times_2)
    self.assertEqual(result, doubled)
def test_map_fn_dict(self):
    """map_fn on a flat dict keeps the keys and transforms each value."""
    mapping = {'a': 4, 'b': 10}
    result = map_fn(mapping, [mapping], times_2)
    self.assertEqual(result, {'a': 8, 'b': 20})
def test_map_fn_list(self):
    """map_fn on a flat list transforms each element in order."""
    items = [2, 3, 4, 5]
    result = map_fn(items, [items], times_2)
    self.assertEqual(result, [4, 6, 8, 10])
def test_map_fn_single_value(self):
    """map_fn on a bare scalar applies the function directly."""
    scalar = 4
    result = map_fn(scalar, [scalar], times_2)
    self.assertEqual(result, 8)
def call_streamed(self, inputs_tensors, provided_state):
    """Stream ``inputs_tensors`` into the operator, one element at a time.

    Every input tensor is sliced along its first axis, and each slice is fed
    to ``self.forward_step`` inside a ``tf.while_loop``. The per-step outputs
    are collected in ``tf.TensorArray``s and stacked at the end.

    Args:
        inputs_tensors: (nested structure of) tensors. They must all have the
            same size along axis 0; this is checked with ``tf.assert_equal``.
        provided_state: ``None`` when the user gives no initial state, or a
            compatible initial state supplied by the user.

    Returns:
        ``(outputs, final_state, initial_state)``. ``outputs`` mirrors the
        per-step output structure, with each leaf stacked along a new leading
        axis. ``initial_state`` is ``provided_state`` when one was given.

    NOTE(review): an empty ``inputs_tensors`` structure raises IndexError
    at ``inputs_sizes[0]``. Size-0 inputs likely fail at ``x[0]`` in the
    probing step — confirm whether either case needs support.
    """
    inputs_sizes = tuple(tf.shape(x)[0] for x in flatten(inputs_tensors))
    size = inputs_sizes[0]
    same_size = functools.reduce(
        lambda acc, x: tf.logical_and(acc, tf.equal(x, size)),
        inputs_sizes, tf.constant(True))
    assert_same_size = tf.assert_equal(
        same_size, tf.constant(True), data=inputs_sizes,
        message="inputs have different sizes.")

    # Run one step on element 0 to learn the output dtypes, which are needed
    # to build the TensorArrays, and the default initial state.
    inputs_0 = map_fn(inputs_tensors, [inputs_tensors], lambda x: x[0])
    outputs_0, _, initial_state = self.forward_step(
        inputs_0, provided_state)
    # If a state is provided, it takes precedence as the initial value.
    if provided_state is not None:
        initial_state = provided_state

    def cond(i, loop_inputs, loop_state, outputs):
        return i < size

    def loop(i, loop_inputs, loop_state, loop_outputs):
        # Slice element i out of every input and run one step on it.
        inputs_i = map_fn(loop_inputs, [loop_inputs], lambda x: x[i])
        outputs_i, next_state, _ = self.forward_step(inputs_i, loop_state)
        Streamable.check_types_equality(
            Streamable.to_type(loop_state), Streamable.to_type(next_state),
            "Provided state and model state has different types (%s != %s)."
        )
        new_outputs = map_fn(outputs_i, [outputs_i, loop_outputs],
                             lambda x, y: y.write(i, x))
        # Give next_state the static shape of loop_state. This is in place,
        # and only effective when y is already a Tensor, because
        # tf.convert_to_tensor then returns y itself.
        map_fn(
            loop_state, [loop_state, next_state],
            lambda x, y: tf.convert_to_tensor(y).set_shape(
                tf.convert_to_tensor(x).get_shape()))
        return (i + 1, loop_inputs, next_state, new_outputs)

    i0 = tf.constant(0)
    outputs = map_fn(outputs_0, [outputs_0],
                     lambda x: tf.TensorArray(dtype=x.dtype, size=size))
    with tf.control_dependencies([assert_same_size]):
        _, _, state_f, outputs_f = tf.while_loop(
            cond, loop,
            loop_vars=[i0, inputs_tensors, initial_state, outputs],
            name="stream_loop")
    outputs = map_fn(outputs_f, [outputs_f], lambda o: o.stack())
    return (outputs, state_f, initial_state)
def to_type(values):
    """Map every leaf of ``values`` to its TensorFlow dtype.

    A ``tf.Tensor`` reports its own ``dtype``. Any other leaf (e.g. a Python
    literal) is converted with ``tf.convert_to_tensor``, and the dtype of the
    converted tensor is used.
    """
    def leaf_dtype(val):
        if isinstance(val, tf.Tensor):
            return val.dtype
        return tf.convert_to_tensor(val).dtype

    return map_fn(values, [values], leaf_dtype)