def make_base_state_vars(self, name, output):
    """
    Create state vars for an input coming from outside the loop (automatic tile_batch).

    The value placeholder and every entry of ``output.size_placeholder`` are each wrapped
    in a state var of ``self.parent_rec_layer`` and then tiled by the current value of
    ``self.parent_tile_multiples_var``. The value is tiled along ``output.batch_dim_axis``
    and each size tensor along axis 0.

    :param str name: layer name; used to build the names of the created state vars
    :param Data output: data from outside the loop; its placeholders are replaced in place
    :return: the same ``output`` object, after its placeholders were replaced
    :rtype: Data
    """
    from TFUtil import tile_transposed, DimensionTag

    parent_layer = self.parent_rec_layer

    # Value: state var first, then tile it along the batch axis.
    base_value = parent_layer.create_state_var(
        name="base_value_%s" % name, initial_value=output.placeholder, data_shape=output)
    value_batch_axis = output.batch_dim_axis
    output.placeholder = tile_transposed(
        base_value, axis=value_batch_axis, multiples=self.parent_tile_multiples_var.read_value())

    # Sizes: iterate over a snapshot because entries are reassigned inside the loop.
    for axis_wo_batch, size_t in list(output.size_placeholder.items()):
        tag = DimensionTag.get_tag_from_size_tensor(size_t)
        if not tag:
            axis_with_batch = output.get_batch_axis(axis_wo_batch)
            print("Warning, no defined dim tag on %r, axis %i" % (name, axis_with_batch), file=log.v2)
            # Fall back to the tag derived from the Data itself and attach it to the original size tensor.
            tag = output.get_dim_tag(axis_with_batch)
            tag.set_tag_on_size_tensor(size_t)
        state_size = parent_layer.create_state_var(
            name="base_size%i_%s" % (axis_wo_batch, name), initial_value=size_t)
        tiled_size = tile_transposed(
            state_size, axis=0, multiples=self.parent_tile_multiples_var.read_value())
        # The tiled size tensor keeps the same dim tag as the untiled one.
        tag.set_tag_on_size_tensor(tiled_size)
        output.size_placeholder[axis_wo_batch] = tiled_size
    return output
def tile_batch_op(self, repetitions):
    """
    Build an op that tiles the current value of this state var along its batch axis.

    :param tf.Tensor repetitions: passed as ``multiples`` to ``tile_transposed`` on the batch axis
    :return: an op assigning the tiled previous value back to ``self.var``,
      or a no-op if ``self.var_data_shape`` has no batch axis
    :rtype: tf.Operation
    """
    batch_axis = self.var_data_shape.batch_dim_axis
    if batch_axis is not None:
        # See also Data.copy_extend_with_beam.
        from TFUtil import tile_transposed
        previous_value = self.var.read_value()
        tiled = tile_transposed(previous_value, axis=batch_axis, multiples=repetitions)
        assign_tensor = tf.assign(self.var, tiled, name="tile_batch_state_var_%s" % self.name)
        return assign_tensor.op
    return tf.no_op(name="tile_batch_state_var_no_op_%s" % self.name)