Exemplo n.º 1
0
 def get_out_data_from_opts(cls,
                            name,
                            sources,
                            repetitions,
                            n_out=None,
                            **kwargs):
     """
 Build the output template: same batch/time layout as the (batch-major) input,
 with the feature dimension repeated ``repetitions`` times.

 :param str name: layer name; the output is named "<name>_output"
 :param list[LayerBase] sources: concatenated via get_concat_sources_data_template
 :param int repetitions: factor applied to the input feature dimension
 :param int|None|returnn.util.basic.NotSpecified n_out: not used here
 :rtype: Data
 """
     input_data = get_concat_sources_data_template(sources)
     assert not input_data.sparse
     # ``Data.shape`` does not contain the batch axis, so for a
     # (batch, time, feature) input ``shape[2]`` does not exist and
     # ``shape[1]`` is the feature dim. Index the batch-major ``batch_shape``
     # instead, where position 0 is the batch axis.
     batch_major_shape = input_data.copy_as_batch_major().batch_shape
     time_dim = batch_major_shape[1]
     feature_dim = batch_major_shape[2]
     # Keep an unknown (None) feature dim unknown instead of raising TypeError.
     out_feature_dim = None if feature_dim is None else feature_dim * repetitions
     return Data(
         name="%s_output" % name,
         shape=[time_dim, out_feature_dim],
         dtype=input_data.dtype,
         sparse=False,
         size_placeholder={
             0:
             input_data.size_placeholder[
                 input_data.time_dim_axis_excluding_batch]
         },
         batch_dim_axis=0,
         time_dim_axis=1)
Exemplo n.º 2
0
 def get_out_data_from_opts(cls,
                            name,
                            sources,
                            nr_of_channels,
                            n_out=None,
                            **kwargs):
     """
 Output template whose feature dimension is the input feature dimension
 floor-divided by the number of channels; time stays on axis 1.

 :param str name: layer name; the output is named "<name>_output"
 :param list[LayerBase] sources: concatenated and converted to batch-major
 :param int nr_of_channels: divisor applied (with //) to the feature dim
 :param int|None|returnn.util.basic.NotSpecified n_out: not used here
 :rtype: Data
 """
     batch_major_input = get_concat_sources_data_template(
         sources).copy_as_batch_major()
     assert not batch_major_input.sparse
     dims = batch_major_input.batch_shape
     # dims[0] is the batch axis after copy_as_batch_major(); keep the time
     # axis and split the feature axis across channels.
     out_shape = [dims[1], dims[2] // nr_of_channels]
     return Data(name="%s_output" % name,
                 shape=out_shape,
                 dtype=batch_major_input.dtype,
                 sparse=False,
                 batch_dim_axis=0,
                 time_dim_axis=1)
Exemplo n.º 3
0
 def get_out_data_from_opts(cls,
                            name,
                            sources,
                            pool_size,
                            n_out=None,
                            **kwargs):
     """
 Output template with the same static (time, feature) dims as the batch-major
 input placeholder, and sequence lengths taken from every ``pool_size``-th
 entry of the input's sequence-length tensor.

 :param str name: layer name; the output is named "<name>_output"
 :param list[LayerBase] sources: concatenated via get_concat_sources_data_template
 :param int pool_size: stride used when slicing the sequence-length tensor
 :param int|None|returnn.util.basic.NotSpecified n_out: not used here
 :rtype: Data
 """
     input_data = get_concat_sources_data_template(sources)
     assert not input_data.sparse
     static_shape = input_data.get_placeholder_as_batch_major().shape
     seq_lens = input_data.size_placeholder[
         input_data.time_dim_axis_excluding_batch]
     # begin=[0], end=full shape, stride=[pool_size]: i.e. seq_lens[::pool_size].
     # NOTE(review): this strides over the entries of the length tensor
     # (presumably one per batch entry), not over time -- confirm that
     # pooling across batch entries is what this layer intends.
     strided_seq_lens = tf.strided_slice(seq_lens, [0], tf.shape(seq_lens),
                                         [pool_size])
     return Data(
         name="%s_output" % name,
         shape=[static_shape[1].value, static_shape[2].value],
         dtype=input_data.dtype,
         size_placeholder={0: strided_seq_lens},
         sparse=False,
         batch_dim_axis=0,
         time_dim_axis=1)
Exemplo n.º 4
0
 def get_out_data_from_opts(cls,
                            name,
                            sources,
                            pool_size,
                            n_out=None,
                            **kwargs):
     """
 Output template that reuses the static (time, feature) dims of the batch-major
 input placeholder; its sequence lengths are the input lengths sliced with a
 stride of ``pool_size``.

 :param str name: layer name; the output is named "<name>_output"
 :param list[LayerBase] sources: concatenated via get_concat_sources_data_template
 :param int pool_size: stride for slicing the sequence-length tensor
 :param int|None|returnn.util.basic.NotSpecified n_out: not used here
 :rtype: Data
 """
     input_data = get_concat_sources_data_template(sources)
     assert not input_data.sparse
     placeholder_shape = input_data.get_placeholder_as_batch_major().shape
     out_time_dim = placeholder_shape[1].value
     out_feature_dim = placeholder_shape[2].value
     time_axis = input_data.time_dim_axis_excluding_batch
     input_lens = input_data.size_placeholder[time_axis]
     # Same as input_lens[::pool_size].
     # NOTE(review): strides over the length tensor's entries (presumably
     # per batch entry), not over time -- confirm intent.
     output_lens = tf.strided_slice(input_lens, [0], tf.shape(input_lens),
                                    [pool_size])
     return Data(
         name="%s_output" % name,
         shape=[out_time_dim, out_feature_dim],
         dtype=input_data.dtype,
         size_placeholder={0: output_lens},
         sparse=False,
         batch_dim_axis=0,
         time_dim_axis=1)
Exemplo n.º 5
0
    def create_state_var(self, name, initial_value=None, data_shape=None):
        """
    Register a new state var on this rec layer and return its current value.

    A state var is a variable whose initial value is given by the encoder or a
    constant, and whose final value is determined by one step of this rec layer
    (usually called the decoder).

    :param str name: key under which the var is stored in self.state_vars; must be new
    :param tf.Tensor|None initial_value: treated as batch-major when data_shape is not given
    :param Data|None data_shape: explicit template; otherwise inferred from initial_value
    :return: the value returned by the new StateVar's read()
    :rtype: tf.Tensor
    """
        assert name not in self.state_vars
        assert data_shape or initial_value is not None
        if data_shape:
            assert isinstance(data_shape, Data)
        else:
            # Infer the template from the initial value itself.
            if initial_value.shape.ndims == 0:
                # A scalar initial value has no batch axis.
                data_shape = Data(name=name,
                                  batch_dim_axis=None,
                                  shape=(),
                                  dtype=initial_value.dtype.name)
            else:
                # The leading axis is taken to be the (dynamic) batch axis.
                assert initial_value.shape.dims[
                    0].value is None  # first is batch dim
                non_batch_dims = initial_value.shape.as_list()[1:]
                data_shape = Data(name=name,
                                  batch_dim_axis=0,
                                  shape=non_batch_dims,
                                  dtype=initial_value.dtype.name)
        if initial_value is not None:
            # initial_value may have dim 1 in non-batch axes (see
            # get_rec_initial_output); that is fine for broadcasting.
            initial_value.set_shape(data_shape.batch_shape)
        state_var = self.StateVar(parent=self,
                                  name=name,
                                  initial_value=initial_value,
                                  data_shape=data_shape)
        self.state_vars[name] = state_var
        return state_var.read()
Exemplo n.º 6
0
 def get_out_data_from_opts(cls,
                            name,
                            sources,
                            nr_of_channels,
                            n_out=None,
                            **kwargs):
     """
 Output template: time kept on axis 1, feature dim equal to the input feature
 dim floor-divided by ``nr_of_channels``.

 :param str name: layer name; the output is named "<name>_output"
 :param list[LayerBase] sources: concatenated and converted to batch-major
 :param int nr_of_channels: divisor applied (with //) to the feature dim
 :param int|None|returnn.util.basic.NotSpecified n_out: not used here
 :rtype: Data
 """
     source_template = get_concat_sources_data_template(sources)
     batch_major = source_template.copy_as_batch_major()
     assert not batch_major.sparse
     full_shape = batch_major.batch_shape
     time_dim = full_shape[1]
     per_channel_feature_dim = full_shape[2] // nr_of_channels
     return Data(name="%s_output" % name,
                 shape=[time_dim, per_channel_feature_dim],
                 dtype=batch_major.dtype,
                 sparse=False,
                 batch_dim_axis=0,
                 time_dim_axis=1)
Exemplo n.º 7
0
 def get_out_data_from_opts(cls,
                            name,
                            sources,
                            repetitions,
                            n_out=None,
                            **kwargs):
     """
 Output template with the input's static time dim and its static feature dim
 multiplied by ``repetitions``, read from the batch-major input placeholder.

 :param str name: layer name; the output is named "<name>_output"
 :param list[LayerBase] sources: concatenated via get_concat_sources_data_template
 :param int repetitions: factor applied to the feature dimension
 :param int|None|returnn.util.basic.NotSpecified n_out: not used here
 :rtype: Data
 """
     input_data = get_concat_sources_data_template(sources)
     assert not input_data.sparse
     placeholder_shape = input_data.get_placeholder_as_batch_major().shape
     time_dim = placeholder_shape[1].value
     tiled_feature_dim = placeholder_shape[2].value * repetitions
     return Data(
         name="%s_output" % name,
         shape=[time_dim, tiled_feature_dim],
         dtype=input_data.dtype,
         sparse=False,
         size_placeholder={
             0: input_data.size_placeholder[
                 input_data.time_dim_axis_excluding_batch]
         },
         batch_dim_axis=0,
         time_dim_axis=1)
Exemplo n.º 8
0
 def get_out_data_from_opts(cls, name, **kwargs):
     """
 Return a float32 scalar template with no batch axis.

 :param str name: layer name; the output is named "<name>_output"
 :rtype: Data
 """
     from returnn.tf.util.basic import Data
     output_name = "%s_output" % name
     # Empty shape and no batch axis: the output is a single scalar.
     return Data(name=output_name,
                 shape=(),
                 batch_dim_axis=None,
                 dtype="float32")