Example #1
def call(self, input, **kwargs):
    """Call batch_normalization function."""
    bn = tf.keras.layers.BatchNormalization(
        momentum=self.momentum,
        axis=1 if self.data_format == 'channels_first' else 3,
        epsilon=self.eps,
        center=True, scale=True, fused=True,
        name=self.name, trainable=self._trainable)
    if self._is_load_pretrained:
        self.training = True
    out = bn(inputs=input, training=self.training)
    # Register the layer's update ops so the moving averages are refreshed.
    if self._trainable:
        for item in bn.updates:
            tf.add_to_collections(tf.GraphKeys.UPDATE_OPS, item)
    return out
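
In TF1-style graphs the batch-norm moving-average updates do not run on their own; registering them under `tf.GraphKeys.UPDATE_OPS`, as above, lets the training loop pick them up. A minimal sketch of the consuming side, assuming TensorFlow 1.x and a pre-existing `loss` tensor (the optimizer choice is illustrative):

import tensorflow as tf

optimizer = tf.train.AdamOptimizer(1e-3)
# Run all collected update ops (e.g. the BN moving-average updates) before each step.
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
    train_op = optimizer.minimize(loss)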
Example #2
def load_variables_to_tf_graph(g  # type: base_graph.BaseGraph
                               ):
    """
  Convenience function to load all variables present in a `gde.Graph` into
  the current default TensorFlow graph, without generating a MetaGraphDef.
  Also adds those variables to the appropriate TensorFlow collections.

  Args:
    g: `gde.Graph` object from which all variables and variable collections
      should be loaded
  """
    for var_name in g.variable_names:
        var = g.get_variable_by_name(var_name)
        tf_var = tf.Variable.from_proto(var.to_proto())
        tf.add_to_collections(var.collection_names, tf_var)
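
Note that `tf.add_to_collections` accepts either a single collection name or a list of names, so each restored variable is registered in every collection recorded in the `gde.Graph`. A small self-contained sketch of that behaviour, assuming TensorFlow 1.x and made-up collection names:

import tensorflow as tf

v = tf.Variable(0.0, name='restored_bias')
tf.add_to_collections(['my_model_vars', 'restored_vars'], v)
assert v in tf.get_collection('my_model_vars')
assert v in tf.get_collection('restored_vars')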
Example #3
def collect_named_outputs(collections, alias, outputs):
    """Add `Tensor` outputs tagged with alias to collections.
  It is useful to collect end-points or tags for summaries. Example of usage:
  logits = collect_named_outputs('end_points', 'inception_v3/logits', logits)
  assert 'inception_v3/logits' in logits.aliases
  Args:
    collections: A collection or list of collections. If None skip collection.
    alias: String to append to the list of aliases of outputs, for example,
           'inception_v3/conv1'.
    outputs: Tensor, an output tensor to collect
  Returns:
    The outputs Tensor to allow inline call.
  """
    if collections:
        append_tensor_alias(outputs, alias)
        tf.add_to_collections(collections, outputs)
    return outputs
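
A hedged usage sketch of the helper above, assuming TensorFlow 1.x and that `collect_named_outputs` (with its `append_tensor_alias` dependency) is importable; the layer and collection names are illustrative:

import tensorflow as tf

images = tf.placeholder(tf.float32, [None, 784])
logits = tf.layers.dense(images, 10, name='logits')
logits = collect_named_outputs('end_points', 'mnist/logits', logits)

# Every tagged output can later be recovered from the collection.
assert logits in tf.get_collection('end_points')
assert 'mnist/logits' in logits.aliases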
Example #4
    def _build(self, inputs, prev_state, is_training=True):
        """Builds graph.

    Args:
      inputs: 3D tensor of shape [batch_size, chunk_size, input_dim] or
        2D tensor of shape [batch_size, input_dim].
      prev_state: list of length `num_layers` containing `CompressedMemoryState`
        tuples.
      is_training: applies dropout if true.

    Returns:
      output: tensor equal in rank to `inputs` with final dimension equal to
        `embedding_size` = `key_size` * `num_heads`.
      next_state: list of length `num_layers` containing `CompressedMemoryState`
        tuples.
    """
        input_shape = inputs.get_shape().as_list()
        if len(input_shape) == 2:
            inputs = tf.expand_dims(inputs, 1)

        _, chunk_size, _ = inputs.get_shape().as_list()
        num_layers_t = tf.constant(self._num_layers, dtype=inputs.dtype)

        inputs = default_mlp([self._embedding_size], activate_final=True)(
            inputs,
            is_training=is_training,
            dropout_keep_prob=1 - self._dropout_rate)
        transformer = TransformerTower(**self._core_config)
        state_for_transformer = (None if self._episodic_memory_size == 0 else
                                 prev_state)
        output, attention_state = transformer(inputs,
                                              state=state_for_transformer,
                                              is_training=is_training)

        min_num_to_compress = (self._compression_rate +
                               self._compression_config.get('kernel_size', 0))
        num_to_compress = min(max(min_num_to_compress, chunk_size),
                              chunk_size + self._episodic_memory_size - 1)

        def apply_compression_generic(attn_state, attn_module, mem_to_compress,
                                      prev_compressed_memory):
            """Instantiates compression module and returns fn to build graph."""
            compress_module = self._compression_ctor(
                **self._compression_config)

            def _inner_fn():
                """Returns (updated compressed memory, compression loss)."""
                next_compressed_memory, compression_loss = compress_module(
                    mem_to_compress,
                    attention_state=attn_state,
                    attention_module=attn_module,
                    is_training=is_training,
                    dropout_keep_prob=1 - self._dropout_rate,
                )
                compressed_memory, _ = _concat_and_slice(
                    prev_compressed_memory, next_compressed_memory)
                return compressed_memory, compression_loss

            return _inner_fn

        def dont_apply_compression_generic(prev_compressed_memory):
            """Instantiates fn to build dummy graph that skips any compression."""
            def _inner_fn():
                return (prev_compressed_memory,
                        tf.zeros([], dtype=prev_compressed_memory.dtype))

            return _inner_fn

        next_state = []
        compression_loss = tf.zeros([], dtype=inputs.dtype)
        global_attention_weights = []
        stats_export_dict = {}
        for i, state_i in enumerate(prev_state):
            # Append new elements to memory.
            attn_state_i = attention_state[i]
            memory, concat_memory = _concat_and_slice(state_i.episodic_memory,
                                                      attn_state_i.embeddings)

            sequence_index = state_i.index[0]
            # We special-case chunk_size=1, which is useful for sampling. In the
            # single time-step setting we only compress the memory every
            # 'compression_rate' steps. Otherwise we assume chunk_size is a multiple
            # of `compression_rate`, and thus multiple compressions can be performed
            # in parallel.
            to_compress = tf.logical_or(
                chunk_size > 1,
                tf.equal(sequence_index % self._compression_rate,
                         self._compression_rate - 1))[0]

            apply_compression_fn = apply_compression_generic(
                attn_state=attn_state_i,
                attn_module=transformer.attention_module(i),
                mem_to_compress=concat_memory[:, :num_to_compress],
                prev_compressed_memory=state_i.compressed_memory,
            )
            dont_apply_compression_fn = dont_apply_compression_generic(
                prev_compressed_memory=state_i.compressed_memory)

            compression_output = tf.cond(to_compress, apply_compression_fn,
                                         dont_apply_compression_fn)
            compressed_memory, compression_loss_i = compression_output
            compression_loss += compression_loss_i

            # Log useful stats, compression loss per layer.
            stats_export_dict['compression_loss_l%02d' %
                              i] = compression_loss_i
            # Attention weights per layer.
            attn_names, attn_weights = _compute_avg_attention(
                attn_state_i, self._compressed_memory_size,
                self._episodic_memory_size, chunk_size)
            attn_names_i = [name + '_l%02d' % i for name in attn_names]
            stats_export_dict.update(dict(zip(attn_names_i, attn_weights)))

            # Avg global attention weights.
            if i == 0:
                global_attention_weights = [
                    y / num_layers_t for y in attn_weights
                ]
            else:
                global_attention_weights = [
                    (x + y / num_layers_t)
                    for x, y in zip(global_attention_weights, attn_weights)
                ]

            next_state.append(
                CompressedMemoryState(index=state_i.index + 1,
                                      episodic_memory=memory,
                                      compressed_memory=compressed_memory))

        next_state = tuple(next_state)
        compression_loss /= num_layers_t
        stats_export_dict.update(
            dict(zip(attn_names, global_attention_weights)))
        if is_training:
            tf.add_to_collections('auxiliary_losses', compression_loss)
        if self._export_stats:
            tf.add_to_collections('stats_export', stats_export_dict)

        if self._chunk_size == 0:  # For the use-case as a single-step RNN.
            output = tf.squeeze(output, 1)

        return output, next_state
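
At training time the losses published to the 'auxiliary_losses' collection above still have to be folded back into the objective. A minimal sketch of that step, assuming TensorFlow 1.x, a pre-existing `task_loss` tensor, and an illustrative optimizer:

import tensorflow as tf

aux_losses = tf.get_collection('auxiliary_losses')  # e.g. the compression loss
total_loss = task_loss
if aux_losses:
    total_loss += tf.add_n(aux_losses)
train_op = tf.train.AdamOptimizer(1e-4).minimize(total_loss)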