Пример #1
0
 def get_out_data_from_opts(cls,
                            name,
                            window_size,
                            axis="T",
                            sources=(),
                            **kwargs):
     """
     :param str name:
     :param int window_size:
     :param str axis:
     :param list[LayerBase] sources:
     :rtype: Data
     """
     out = get_concat_sources_data_template(sources)
     out = out.copy_template(name="%s_output" % name).copy_as_batch_major()
     if axis == "T" and out.time_dim_axis is None:
         # No time axis -> assume we are inside a RecLayer.
         axis_idx = 0
     else:
         axis_idx = out.get_axis_from_description(axis)
     # Insert the window as a new spatial axis right behind the resolved axis.
     return out.copy_add_spatial_dim(spatial_dim_axis=axis_idx + 1,
                                     dim=window_size)
Пример #2
0
 def get_out_data_from_opts(cls, name, sources=(), **kwargs):
   """
   :param str name:
   :param list[LayerBase] sources:
   :rtype: Data
   """
   template = get_concat_sources_data_template(sources, name="%s_output" % name)
   # Reinterpret as batch-major; feature dim and seq lengths are unknown here.
   template.batch_dim_axis = 0
   template.time_dim_axis = 1
   template.dim = None
   template.size_placeholder = {}
   return template
Пример #3
0
 def get_out_data_from_opts(cls, name, sources=(), **kwargs):
   """
   :param str name:
   :param list[LayerBase] sources:
   :rtype: Data
   """
   out = get_concat_sources_data_template(sources, name="%s_output" % name)
   # Force batch-major layout, drop feature dim and sequence lengths.
   for attr, value in (("batch_dim_axis", 0), ("time_dim_axis", 1),
                       ("dim", None), ("size_placeholder", {})):
     setattr(out, attr, value)
   return out
Пример #4
0
 def get_out_data_from_opts(cls,
                            name,
                            sources,
                            nr_of_channels,
                            n_out=None,
                            **kwargs):
     """
     Output template: batch-major, static shape (T, F // nr_of_channels);
     the seq-length vector is tiled nr_of_channels times.

     :param str name:
     :param list[LayerBase] sources:
     :param int nr_of_channels:
     :param int|None n_out: unused here
     :rtype: Data
     """
     input_data = get_concat_sources_data_template(sources)
     assert not input_data.sparse
     return Data(
         name="%s_output" % name,
         shape=[
             input_data.get_placeholder_as_batch_major().shape[1].value,
             input_data.get_placeholder_as_batch_major().shape[2].value //
             nr_of_channels
         ],
         dtype=input_data.dtype,
         size_placeholder={
             # Repeat each seq-length entry nr_of_channels times:
             # reshape (B,) -> (B, 1), tile along axis 1, flatten back.
             0:
             tf.reshape(
                 tf.tile(
                     tf.reshape(
                         input_data.size_placeholder[
                             input_data.time_dim_axis_excluding_batch],
                         [-1, 1]), [1, nr_of_channels]), [-1])
         },
         sparse=False,
         batch_dim_axis=0,
         time_dim_axis=1)
Пример #5
0
 def get_out_data_from_opts(cls,
                            name,
                            sources,
                            pool_size,
                            n_out=None,
                            **kwargs):
     """
     Output template with the input's static (batch-major) shape; the
     seq-length vector is strided-sliced with stride pool_size.

     :param str name:
     :param list[LayerBase] sources:
     :param int pool_size:
     :param int|None n_out: unused here
     :rtype: Data
     """
     input_data = get_concat_sources_data_template(sources)
     assert not input_data.sparse
     return Data(
         name="%s_output" % name,
         shape=[
             input_data.get_placeholder_as_batch_major().shape[1].value,
             input_data.get_placeholder_as_batch_major().shape[2].value
         ],
         dtype=input_data.dtype,
         size_placeholder={
             # NOTE(review): this strided slice runs over the length
             # *vector* (one entry per batch entry), i.e. it keeps every
             # pool_size-th batch entry rather than dividing each length
             # by pool_size — confirm against the layer's forward pass.
             0:
             tf.strided_slice(
                 input_data.size_placeholder[
                     input_data.time_dim_axis_excluding_batch], [0],
                 tf.shape(input_data.size_placeholder[
                     input_data.time_dim_axis_excluding_batch]),
                 [pool_size])
         },
         sparse=False,
         batch_dim_axis=0,
         time_dim_axis=1)
Пример #6
0
 def get_out_data_from_opts(cls, name, sources, **kwargs):
     """
     :param str name:
     :param list[LayerBase] sources:
     :rtype: Data
     """
     template = get_concat_sources_data_template(
         sources, name="%s_output" % name)
     template = template.copy_as_batch_major()
     # Static shape (1, None): singleton axis plus an axis of unknown size;
     # no feature dim, no sequence lengths.
     template.dim = None
     template.shape = (1, None)
     template.size_placeholder = {}
     return template
Пример #7
0
 def get_out_data_from_opts(cls, n_out, **kwargs):
     """
     :param int n_out:
     :rtype: Data
     """
     out = get_concat_sources_data_template(
         kwargs["sources"], name="%s_output" % kwargs["name"])
     out = out.copy_as_time_major()  # type: Data
     # Time-major layout with a fixed feature dim of n_out.
     out.shape = (None, n_out)
     out.time_dim_axis = 0
     out.batch_dim_axis = 1
     out.dim = n_out
     return out
Пример #8
0
 def get_out_data_from_opts(cls, name, sources, window, broadcast_axis='time', **kwargs):
   """
   :param str name:
   :param list[LayerBase] sources:
   :param int window:
   :param str broadcast_axis: 'time' or anything else (treated as feature)
   :rtype: Data
   """
   template = get_concat_sources_data_template(sources, name="%s_output" % name)
   out = template.copy_as_batch_major()
   # No sequence lengths tracked for the broadcast output.
   out.size_placeholder = {0: None}
   if broadcast_axis == 'time':
     out.shape, out.dim = (1, window), window
   else:
     out.shape, out.dim = (window, 1), 1
   out.sparse = False
   out.dtype = 'float32'
   return out
Пример #9
0
 def get_out_data_from_opts(cls, name, sources, window, broadcast_axis='time', **kwargs):
   """
   :param str name:
   :param list[LayerBase] sources:
   :param int window:
   :param str broadcast_axis: 'time' or anything else (treated as feature)
   :rtype: Data
   """
   broadcast_time = (broadcast_axis == 'time')
   out = get_concat_sources_data_template(sources, name="%s_output" % name).copy_as_batch_major()
   out.shape = (1, window) if broadcast_time else (window, 1)
   out.dim = window if broadcast_time else 1
   out.sparse = False
   out.dtype = 'float32'
   out.size_placeholder = {0: None}  # no seq lengths tracked
   return out
Пример #10
0
 def get_out_data_from_opts(cls, name, sources, num_classes, window, **kwargs):
   """
   :param str name:
   :param list[LayerBase] sources:
   :param int num_classes:
   :param int window:
   :rtype: Data
   """
   base = get_concat_sources_data_template(sources, name="%s_output" % name)
   out = base.copy_as_batch_major()
   # Dense float output with static shape (window, num_classes).
   out.shape = (window, num_classes)
   out.dim = num_classes
   out.sparse = False
   out.dtype = 'float32'
   out.size_placeholder = {0: None}  # no seq lengths tracked
   return out
Пример #11
0
 def get_out_data_from_opts(cls, name, sources, num_classes, window, **kwargs):
   """
   :param str name:
   :param list[LayerBase] sources:
   :param int num_classes:
   :param int window:
   :rtype: Data
   """
   out = get_concat_sources_data_template(sources, name="%s_output" % name)
   out = out.copy_as_batch_major()
   out.sparse = False
   out.dtype = 'float32'
   out.shape = (window, num_classes)  # dense, static shape
   out.dim = num_classes
   out.size_placeholder = {0: None}
   return out
Пример #12
0
 def get_out_data_from_opts(cls, name, sources, nr_of_channels, n_out=None, **kwargs):
   """
   :param str name:
   :param list[LayerBase] sources:
   :param int nr_of_channels:
   :param int|None n_out: unused here
   :rtype: Data
   """
   src = get_concat_sources_data_template(sources).copy_as_batch_major()
   assert not src.sparse
   # The static feature dim is split evenly across the channels.
   out_shape = [src.batch_shape[1], src.batch_shape[2] // nr_of_channels]
   return Data(
     name="%s_output" % name,
     shape=out_shape,
     dtype=src.dtype,
     sparse=False,
     batch_dim_axis=0,
     time_dim_axis=1)
Пример #13
0
 def get_out_data_from_opts(cls, name, sources, repetitions, n_out=None, **kwargs):
   """
   :param str name:
   :param list[LayerBase] sources:
   :param int repetitions: the static feature dim is multiplied by this
   :param int|None n_out: unused here
   :rtype: Data
   """
   src = get_concat_sources_data_template(sources)
   assert not src.sparse
   batch_major = src.get_placeholder_as_batch_major()
   seq_lens = src.size_placeholder[src.time_dim_axis_excluding_batch]
   return Data(
     name="%s_output" % name,
     shape=[batch_major.shape[1].value, batch_major.shape[2].value * repetitions],
     dtype=src.dtype,
     sparse=False,
     size_placeholder={0: seq_lens},  # lengths carried over unchanged
     batch_dim_axis=0,
     time_dim_axis=1)
Пример #14
0
 def get_out_data_from_opts(cls, name, sources, pool_size, n_out=None, **kwargs):
   """
   Output template with the input's static (batch-major) shape; the
   seq-length vector is strided-sliced with stride pool_size.

   :param str name:
   :param list[LayerBase] sources:
   :param int pool_size:
   :param int|None n_out: unused here
   :rtype: Data
   """
   input_data = get_concat_sources_data_template(sources)
   assert not input_data.sparse
   # NOTE(review): the strided slice below runs over the length *vector*
   # (one entry per batch entry), i.e. it keeps every pool_size-th batch
   # entry rather than dividing each length by pool_size — confirm intent.
   return Data(
     name="%s_output" % name,
     shape=[input_data.get_placeholder_as_batch_major().shape[1].value, input_data.get_placeholder_as_batch_major().shape[2].value],
     dtype=input_data.dtype,
     size_placeholder={0: tf.strided_slice(input_data.size_placeholder[input_data.time_dim_axis_excluding_batch], [0], tf.shape(input_data.size_placeholder[input_data.time_dim_axis_excluding_batch]), [pool_size])},
     sparse=False,
     batch_dim_axis=0,
     time_dim_axis=1)
Пример #15
0
 def get_out_data_from_opts(cls,
                            name,
                            sources,
                            nr_of_channels,
                            n_out=None,
                            **kwargs):
     """
     :param str name:
     :param list[LayerBase] sources:
     :param int nr_of_channels:
     :param int|None n_out: unused here
     :rtype: Data
     """
     src = get_concat_sources_data_template(sources).copy_as_batch_major()
     assert not src.sparse
     # The static feature dim is split evenly across the channels.
     out_shape = [src.batch_shape[1], src.batch_shape[2] // nr_of_channels]
     return Data(name="%s_output" % name,
                 shape=out_shape,
                 dtype=src.dtype,
                 sparse=False,
                 batch_dim_axis=0,
                 time_dim_axis=1)
Пример #16
0
 def get_out_data_from_opts(cls,
                            name,
                            sources,
                            repetitions,
                            n_out=None,
                            **kwargs):
     """
     :param str name:
     :param list[LayerBase] sources:
     :param int repetitions: the static feature dim is multiplied by this
     :param int|None n_out: unused here
     :rtype: Data
     """
     src = get_concat_sources_data_template(sources)
     assert not src.sparse
     batch_major = src.get_placeholder_as_batch_major()
     time_dim = batch_major.shape[1].value
     feat_dim = batch_major.shape[2].value * repetitions
     seq_lens = src.size_placeholder[src.time_dim_axis_excluding_batch]
     return Data(name="%s_output" % name,
                 shape=[time_dim, feat_dim],
                 dtype=src.dtype,
                 sparse=False,
                 size_placeholder={0: seq_lens},  # lengths carried over
                 batch_dim_axis=0,
                 time_dim_axis=1)
Пример #17
0
 def get_rec_initial_extra_outputs(cls,
                                   batch_dim,
                                   rec_layer,
                                   window_size,
                                   axis="T",
                                   sources=(),
                                   **kwargs):
     """
     Returns a zero-initialized "state" tensor of shape
     (batch, window_size, ...feature dims...) when operating per-frame
     inside a RecLayer; otherwise no extra outputs.

     :param tf.Tensor batch_dim: scalar, dynamic batch size
     :param TFNetworkRecLayer.RecLayer|LayerBase rec_layer:
     :param int window_size:
     :param str axis:
     :param list[LayerBase] sources:
     :rtype: dict[str,tf.Tensor]
     """
     data = get_concat_sources_data_template(sources)
     data = data.copy_as_batch_major()
     if axis == "T" and data.time_dim_axis is None:
         # Assume inside RecLayer.
         shape = list(data.batch_shape)
         shape[0] = batch_dim  # replace the static None with the dynamic dim
         shape.insert(1, window_size)  # window axis right after batch
         return {"state": tf.zeros(shape, dtype=data.dtype)}
     return {}
Пример #18
0
 def get_out_data_from_opts(cls, name, sources=(), **kwargs):
   """
   :param str name:
   :param list[LayerBase] sources:
   :rtype: Data
   """
   # Pass the concatenated source template through unchanged.
   out_name = "%s_output" % name
   return get_concat_sources_data_template(sources, name=out_name)
Пример #19
0
 def get_out_data_from_opts(cls, name, sources, **kwargs):
   """
   :param str name:
   :param list[LayerBase] sources:
   :rtype: Data
   """
   out = get_concat_sources_data_template(sources, name="%s_output" % name)
   out = out.copy_as_batch_major()
   out.dim = None              # feature dim unknown
   out.shape = (1, None)       # singleton axis plus axis of unknown size
   out.size_placeholder = {}   # no sequence lengths
   return out
Пример #20
0
    def get_out_data_from_opts(cls,
                               name,
                               sources,
                               switch_axes=None,
                               size_base=None,
                               set_axes=None,
                               enforce_batch_major=False,
                               enforce_time_major=False,
                               set_sparse=None,
                               set_sparse_dim=NotSpecified,
                               increase_sparse_dim=None,
                               **kwargs):
        """
        Builds the output template by reinterpreting axes/flags of the
        concatenated source template; the data itself is not touched.

        :param str name:
        :param list[LayerBase] sources:
        :param str|list[str] switch_axes: e.g. "bt" to switch batch and time axes
        :param LayerBase|None size_base: similar as size_target
        :param dict[str,int] set_axes:
        :param bool enforce_batch_major:
        :param bool enforce_time_major:
        :param bool|None set_sparse: if bool, set sparse value to this
        :param int|None|NotSpecified set_sparse_dim: set sparse dim to this. assumes that it is sparse
        :param int|None increase_sparse_dim: add this to the dim. assumes that it is sparse
        :rtype: Data
        """
        out = get_concat_sources_data_template(sources,
                                               name="%s_output" % name)
        assert not (enforce_batch_major and enforce_time_major)
        if enforce_batch_major:
            out = out.copy_as_batch_major()
        if enforce_time_major:
            out = out.copy_as_time_major()

        def map_axis_name(s):
            """
            Maps a short axis tag ("B"/"T"/"F") to the Data attribute name.

            :param str s:
            :rtype: str
            """
            if s.upper() == "B":
                return "batch_dim_axis"
            if s.upper() == "T":
                return "time_dim_axis"
            if s.upper() == "F":
                return "feature_dim_axis"
            assert s in ["batch_dim_axis", "time_dim_axis", "feature_dim_axis"]
            return s

        if switch_axes:
            assert len(switch_axes) == 2
            axes_s = list(map(map_axis_name, switch_axes))
            axes = [getattr(out, s) for s in axes_s]
            # Cyclic rotation of the axis indices; with exactly two
            # entries (asserted above) this is a swap.
            for i in range(len(axes)):
                setattr(out, axes_s[i], axes[(i + 1) % len(axes)])
        if set_axes:
            for s, i in sorted(set_axes.items()):
                s = map_axis_name(s)
                if isinstance(i, int):
                    # A plain int index is only unambiguous when the layout
                    # was pinned down explicitly.
                    assert enforce_batch_major or enforce_time_major, "%r: explicit set_axes %r" % (
                        name, set_axes)
                i = out.get_axis_from_description(i)
                setattr(out, s, i)
                if s == "feature_dim_axis":
                    # Keep dim consistent with the newly chosen feature axis.
                    out.dim = out.batch_shape[out.feature_dim_axis]
        if size_base:
            # Take over the sequence lengths from the given base layer.
            out.size_placeholder = size_base.output.size_placeholder.copy()
        if set_sparse is not None:
            assert isinstance(set_sparse, bool)
            out.sparse = set_sparse
        if set_sparse_dim is not NotSpecified:
            assert set_sparse_dim is None or isinstance(set_sparse_dim, int)
            out.dim = set_sparse_dim
        if increase_sparse_dim:
            assert out.sparse
            out.dim += increase_sparse_dim
        return out
Пример #21
0
 def get_out_data_from_opts(cls, name, sources, window, **kwargs):
   """
   :param str name:
   :param list[LayerBase] sources:
   :param int window:
   :rtype: Data
   """
   out = get_concat_sources_data_template(sources, name="%s_output" % name)
   # Sequence lengths are unknown for the output.
   out.size_placeholder = {0: None}
   return out
Пример #22
0
 def get_out_data_from_opts(cls, name, sources, window, **kwargs):
     """
     :param str name:
     :param list[LayerBase] sources:
     :param int window:
     :rtype: Data
     """
     template = get_concat_sources_data_template(
         sources, name="%s_output" % name)
     # Sequence lengths are unknown for the output.
     template.size_placeholder = {0: None}
     return template
Пример #23
0
 def get_out_data_from_opts(cls, name, sources=(), **kwargs):
     """
     :param str name:
     :param list[LayerBase] sources:
     :rtype: Data
     """
     # Pass the concatenated source template through unchanged.
     return get_concat_sources_data_template(
         sources, name="%s_output" % name)