def __init__(self, main_channel, num_res_blocks, residual_hiddens,
                 the_stride):
        """Build the strided convolution stack followed by residual blocks.

        Args:
            main_channel: output channels of every encoder convolution.
            num_res_blocks: number of `residual_block` modules appended after
                the convolution stack.
            residual_hiddens: hidden channel count given to each
                `residual_block`.
            the_stride: product of the convolution strides; 4 builds two
                stride-2 convolutions, 2 builds a single one.

        Raises:
            ValueError: if `the_stride` is neither 2 nor 4. Previously
                `self.model` was simply left undefined in that case.
        """
        super(Encoder, self).__init__()

        if the_stride == 4:
            self.model = snt.Sequential([
                snt.Conv1D(output_channels=main_channel,
                           kernel_shape=4,
                           stride=2,
                           name="enc_0"), tf.nn.relu,
                snt.Conv1D(output_channels=main_channel,
                           kernel_shape=4,
                           stride=2,
                           name="enc_1"), tf.nn.relu,
                snt.Conv1D(output_channels=main_channel,
                           kernel_shape=3,
                           stride=1,
                           name="enc_2"), tf.nn.relu
            ])
        elif the_stride == 2:
            self.model = snt.Sequential([
                snt.Conv1D(output_channels=main_channel,
                           kernel_shape=4,
                           stride=2,
                           name="enc_0"), tf.nn.relu,
                snt.Conv1D(output_channels=main_channel,
                           kernel_shape=3,
                           stride=1,
                           name="enc_1"), tf.nn.relu
            ])
        else:
            # Fail early and explicitly instead of leaving self.model unset.
            raise ValueError(
                'the_stride must be 2 or 4, got {!r}'.format(the_stride))

        # Each residual block wraps the model built so far in a new Sequential.
        for _ in range(num_res_blocks):
            self.model = snt.Sequential(
                [self.model,
                 residual_block(main_channel, residual_hiddens)])
Ejemplo n.º 2
0
    def variational_layer(self):
        """Create the latent posterior distribution and its prior.

        When `self._learn_prior` equals `LEARN_PRIOR_CONV`, both distributions
        are produced by convolutions (posterior from the shifted encoding plus
        a strided convolution of `self.x_quant_scaled`, prior from the shifted
        `self.x_quant_scaled` followed by average pooling). Otherwise the
        posterior comes straight from `self.en` and the prior is delegated to
        `self._create_prior()`.
        """
        if self._learn_prior != LEARN_PRIOR_CONV:
            self._gaussian_model_latent = self.get_gaussian(self.en)
            self._create_prior()
            return

        latent_channels = self._latent_channels
        hop = self._hop_length

        self.en_shifted = shift_right(self.en)

        # POSTERIOR: pointwise conv over the shifted encoding ...
        posterior_conv = snt.Conv1D(output_channels=latent_channels,
                                    kernel_shape=1,
                                    stride=1,
                                    padding=snt.CAUSAL,
                                    name='conv_en_posterior')
        self.en_posterior = posterior_conv(self.en_shifted)

        # ... concatenated with a hop-strided conv of the quantized input.
        input_conv = snt.Conv1D(output_channels=latent_channels,
                                kernel_shape=hop,
                                stride=hop,
                                padding=snt.CAUSAL,
                                name='conv_en_input')
        self.input_conv = input_conv(self.x_quant_scaled)

        self.en_posterior = tf.concat(
            [self.en_posterior, self.input_conv], axis=-1)
        self._gaussian_model_latent = self.get_gaussian(self.en_posterior)

        # PRIOR: pointwise conv over the shifted quantized input, then
        # average-pooled with window `hop`.
        prior_conv = snt.Conv1D(output_channels=latent_channels,
                                kernel_shape=1,
                                stride=1,
                                padding=snt.CAUSAL,
                                name='conv_en_prior_from_input')
        self.en_prior = prior_conv(shift_right(self.x_quant_scaled))
        self.en_prior = pool1d(self.en_prior,
                               hop,
                               name='ae_pool_prior',
                               mode='avg')

        self._prior = self.get_gaussian(self.en_prior)
def residual_block(main_channel, residual_hiddens):
    """Return a conv(3) -> ReLU -> conv(1) -> ReLU `snt.Sequential`.

    Args:
        main_channel: output channels of the second (kernel 1) convolution.
        residual_hiddens: output channels of the first (kernel 3) convolution.

    Returns:
        The `snt.Sequential` module. NOTE: no skip connection is added here;
        the returned module only contains the four layers listed above.
    """
    layers = [
        snt.Conv1D(output_channels=residual_hiddens, kernel_shape=3, stride=1),
        tf.nn.relu,
        snt.Conv1D(output_channels=main_channel, kernel_shape=1, stride=1),
        tf.nn.relu,
    ]
    return snt.Sequential(layers)
Ejemplo n.º 4
0
        def pointwise(x):
            """Apply two kernel-1 convolutions, each followed by optional
            dropout and a ReLU."""

            def _maybe_dropout(tensor):
                # NOTE(review): tf.layers.dropout defaults to training=False,
                # in which case it returns its input unchanged -- confirm
                # whether a training flag should be passed here.
                if self.dropout_rate > 0.0:
                    return tf.layers.dropout(tensor, self.dropout_rate)
                return tensor

            first_conv = snt.Conv1D(output_channels=output_size,
                                    kernel_shape=1)
            hidden = tf.nn.relu(_maybe_dropout(first_conv(x)))

            second_conv = snt.Conv1D(output_channels=output_size,
                                     kernel_shape=1)
            return tf.nn.relu(_maybe_dropout(second_conv(hidden)))
Ejemplo n.º 5
0
    def __init__(self, num_filters=32, filter_size=5, act='', name="adaptor"):
        """Create the helper modules and the three convolutions.

        Args:
            num_filters: channels of the first convolution; the second and
                third use twice and four times as many.
            filter_size: kernel size of the second convolution; the first uses
                `filter_size + 2` and the third `filter_size - 2`.
            act: activation identifier forwarded to `Activation`.
            name: module name.
        """
        super(Adaptor, self).__init__(name=name)

        self._bf = snt.BatchFlatten()
        self._pool = Downsample1D(2)
        self._act = Activation(act, verbose=True)

        with self._enter_variable_scope():
            # Kernel shrinks by 2 per layer while the width doubles.
            self._l1_conv = snt.Conv1D(output_channels=num_filters,
                                       kernel_shape=filter_size + 2)
            self._l2_conv = snt.Conv1D(output_channels=num_filters << 1,
                                       kernel_shape=filter_size)
            self._l3_conv = snt.Conv1D(output_channels=num_filters << 2,
                                       kernel_shape=filter_size - 2)
 def __init__(
     self,
     in_channel=1,
     main_channel=32,
     num_res_blocks=2,
     residual_hiddens=8,
     embed_dim=16,
     n_embed=256,
     decay=0.99,
     commitment_cost=0.25,
 ):
     """Create the two-level (bottom/top) encoder, quantizer and decoder set.

     Args:
         in_channel: output channels of the final decoder `self.dec`.
         main_channel: channel width passed to every Encoder/Decoder.
         num_res_blocks: residual blocks per Encoder/Decoder.
         residual_hiddens: hidden channels of each residual block.
         embed_dim: embedding dimension of both vector quantizers and the
             output channels of the pre-quantization convolutions.
         n_embed: number of embeddings in each quantizer.
         decay: EMA decay for `VectorQuantizerEMA`.
         commitment_cost: commitment cost for `VectorQuantizerEMA`.
     """
     super(VAEModel, self).__init__()

     # Positional arguments shared by every Encoder/Decoder below.
     stack_args = (main_channel, num_res_blocks, residual_hiddens)
     # NOTE(review): both pre-quantization convolutions are named "enc_1";
     # kept as-is because renaming would change variable names.
     quantize_conv_kwargs = dict(output_channels=embed_dim,
                                 kernel_shape=1,
                                 stride=1,
                                 name="enc_1")
     vq_kwargs = dict(embedding_dim=embed_dim,
                      num_embeddings=n_embed,
                      commitment_cost=commitment_cost,
                      decay=decay)

     self.enc_b = Encoder(*stack_args, the_stride=4)
     self.enc_t = Encoder(*stack_args, the_stride=2)

     # Top level: quantize, then decode.
     self.quantize_conv_t = snt.Conv1D(**quantize_conv_kwargs)
     self.vq_t = snt.nets.VectorQuantizerEMA(**vq_kwargs)
     self.dec_t = Decoder(embed_dim, *stack_args, the_stride=2)

     # Bottom level.
     self.quantize_conv_b = snt.Conv1D(**quantize_conv_kwargs)
     self.vq_b = snt.nets.VectorQuantizerEMA(**vq_kwargs)

     self.upsample_t = snt.Conv1DTranspose(output_channels=embed_dim,
                                           output_shape=None,
                                           kernel_shape=4,
                                           stride=2,
                                           name="up_1")
     self.dec = Decoder(in_channel, *stack_args, the_stride=4)
    def compute_top_delta(self, z):
        """Parameterization of topD: convert the top-level activation to an
        error signal.

        Args:
          z: tf.Tensor
            batch of final layer post activations (rank 2; it is transposed
            with perm [1, 0] below).

        Returns:
          delta: tf.Tensor
            the error signal, transposed with perm [1, 0, 2] at the end
            (described by the original author as
            [bs, feature_channels, delta_channels]).
        """
        with tf.variable_scope('compute_top_delta'), tf.device(
                self.remote_device):
            # typically this takes [BS, length, input_channels],
            # We are applying this such that we convolve over the batch dimension.
            act = tf.expand_dims(tf.transpose(z, [1, 0]),
                                 2)  # [channels, BS, 1]

            mod = snt.Conv1D(output_channels=self.top_delta_size,
                             kernel_shape=[5])
            act = mod(act)

            # NOTE: every BatchNorm here is called with is_training=False.
            act = snt.BatchNorm(axis=[0, 1])(act, is_training=False)
            act = tf.nn.relu(act)

            # Size of the leading dimension (the "channels" axis in the shape
            # comment above); after the transpose below it becomes the
            # convolution channel axis, so the two convolutions keep it fixed.
            leading_dim = act.shape.as_list()[0]
            act = tf.transpose(act, [2, 1, 0])
            act = snt.Conv1D(output_channels=leading_dim,
                             kernel_shape=[3])(act)
            act = snt.BatchNorm(axis=[0, 1])(act, is_training=False)
            act = tf.nn.relu(act)
            act = snt.Conv1D(output_channels=leading_dim,
                             kernel_shape=[3])(act)
            act = snt.BatchNorm(axis=[0, 1])(act, is_training=False)
            act = tf.nn.relu(act)
            act = tf.transpose(act, [2, 1, 0])

            for _ in range(self.top_delta_layers):
                mod = snt.Conv1D(output_channels=self.top_delta_size,
                                 kernel_shape=[3])
                act = mod(act)

                act = snt.BatchNorm(axis=[0, 1])(act, is_training=False)
                act = tf.nn.relu(act)

            mod = snt.Conv1D(output_channels=self.delta_dim, kernel_shape=[3])
            act = mod(act)

            # [bs, feature_channels, delta_channels]
            act = tf.transpose(act, [1, 0, 2])
            return act
Ejemplo n.º 8
0
    def _build(self, inputs):
        """Build the temporal encoder graph.

        A start convolution is followed by `self._num_layers` residual layers
        (ReLU -> dilated conv -> ReLU -> kernel-1 conv, added back onto the
        running activation) and a kernel-1 bottleneck convolution.

        Args:
          inputs: input node (already preprocessed)

        Returns:
          The bottleneck output with `self._latent_channels` channels.
        """
        # https://deepmind.github.io/sonnet/_modules/sonnet/python/modules/conv.html#Conv1D

        hidden_channels = self._e_hidden_channels
        filter_length = self._filter_length

        ###
        # The Non-Causal Temporal Encoder.
        ###
        start_conv = snt.Conv1D(output_channels=hidden_channels,
                                kernel_shape=filter_length,
                                name='ae_startconv')
        en = start_conv(inputs)

        for layer_idx in range(self._num_layers):
            layer_no = layer_idx + 1
            # Dilation cycles 1, 2, 4, ... within each stage.
            dilation = 2**(layer_idx % self._num_layers_per_stage)

            dilated_conv = snt.Conv1D(output_channels=hidden_channels,
                                      kernel_shape=filter_length,
                                      rate=dilation,
                                      name='ae_dilatedconv_%d' % layer_no)
            branch = tf.nn.relu(dilated_conv(tf.nn.relu(en)))

            res_conv = snt.Conv1D(output_channels=hidden_channels,
                                  kernel_shape=1,
                                  padding=snt.CAUSAL,
                                  name='ae_res_%d' % layer_no)
            en = en + res_conv(branch)

        bottleneck = snt.Conv1D(output_channels=self._latent_channels,
                                kernel_shape=1,
                                padding=snt.CAUSAL,
                                name='ae_bottleneck')
        return bottleneck(en)
Ejemplo n.º 9
0
    def _build(self, inputs, is_training=True):
        """Encode `inputs` with four conv/batch-norm stages and a linear layer.

        Args:
            inputs: tensor of shape (batch_size, feature_length); a trailing
                channel axis is added before the first convolution.
            is_training: forwarded to every `snt.BatchNorm` call.

        Returns:
            Output of the final `snt.Linear` layer, with
            `EncodeProcessDecode_v1.dimensions_latent_repr` units.
        """
        activation = (tf.nn.tanh
                      if EncodeProcessDecode_v1.convnet_tanh else tf.nn.relu)

        # CNN operates on depth channels --> (batch_size, feature_length, 1)
        outputs = tf.expand_dims(inputs, axis=2)

        # Layers 1-4: identical conv -> batch norm -> (max pool) -> activation.
        # todo (from original): deal with train/test time
        for _ in range(4):
            outputs = snt.Conv1D(output_channels=12,
                                 kernel_shape=10,
                                 stride=2)(outputs)
            outputs = snt.BatchNorm()(outputs, is_training=is_training)
            if EncodeProcessDecode_v1.convnet_pooling:
                outputs = tf.layers.max_pooling1d(outputs, 2, 2)
            outputs = activation(outputs)

        # Layer 5: flatten and project to the latent representation size.
        flattened = snt.BatchFlatten()(outputs)
        projection = snt.Linear(
            output_size=EncodeProcessDecode_v1.dimensions_latent_repr)
        return projection(flattened)
Ejemplo n.º 10
0
 def __init__(
     self,
     enc_conf,
     create_scale=False,
     create_offset=False,
     name="ScalarConv1D",
 ):
     """Create one pooling function and one convolution per encoder stage.

     Args:
         enc_conf: sequence whose entries, except the last, are
             `(conv_conf, pool_conf)` pairs. `conv_conf` is indexed as
             (output_channels, kernel_shape, stride, padding) and `pool_conf`
             as (ksize, strides, padding). The last entry is handed to
             `make_norm_mlp_model`.
         create_scale: forwarded to `make_norm_mlp_model`.
         create_offset: forwarded to `make_norm_mlp_model`.
         name: module name.
     """
     super(ScalarConv1D, self).__init__(name=name)
     self._pools = []
     self._convs = []

     stage_confs = enc_conf[:-1]
     mlp_conf = enc_conf[-1]

     for conv_conf, pool_conf in stage_confs:
         pool_fn = partial(
             tf.nn.max_pool1d,
             ksize=pool_conf[0],
             strides=pool_conf[1],
             padding=pool_conf[2],
         )
         self._pools.append(pool_fn)

         conv = snt.Conv1D(
             output_channels=conv_conf[0],
             kernel_shape=conv_conf[1],
             stride=conv_conf[2],
             padding=conv_conf[3],
         )
         self._convs.append(conv)

     self._mlp = make_norm_mlp_model(mlp_conf, create_scale, create_offset)
0
    def __init__(self,
                 filter_width,
                 num_hidden,
                 pool_type='fo',
                 zone_out=0.0,
                 name="quasi_rnn"):
        """Constructs a quasi_rnn.

        Args:
            filter_width: Filter width of the convolutional component.
            num_hidden: Number of hidden units in each RNN layer.
            pool_type: Type of pooling component, one of 'f', 'fo' or 'ifo'
                (f-pooling, fo-pooling, ifo-pooling).
            zone_out: The dropout rate.
            name: Name of the module.
        """
        super(QuasiRNN, self).__init__(name=name)

        self._filter_width = filter_width
        self._num_hidden = num_hidden
        assert pool_type in ['f', 'fo', 'ifo']
        self._pool_type = pool_type
        self._zone_out = zone_out

        # The convolution emits num_hidden channels once per letter of the
        # pool type, plus one more.
        conv_channels = self._num_hidden * (len(self._pool_type) + 1)

        with self._enter_variable_scope():
            self._cnn_component = snt.Conv1D(
                output_channels=conv_channels,
                kernel_shape=self._filter_width,
                padding="VALID",
                name="cnn_component")

            self._pooling_component = RecurrentPooling(
                self._num_hidden, self._pool_type)
 def testConv1dSymbolicBounds(self):
     """Symbolic bounds through a Conv1D with w=1, b=3 match interval bounds.

     Input [3, 4] with radius 1 gives lower 2 + 3 + 3 = 8 and
     upper 4 + 5 + 3 = 12.
     """
     constant_params = {
         'w': tf.constant_initializer(1.),
         'b': tf.constant_initializer(3.),
     }
     conv = snt.Conv1D(output_channels=1,
                       kernel_shape=2,
                       padding='VALID',
                       stride=1,
                       use_bias=True,
                       initializers=constant_params)
     z = tf.reshape(tf.constant([3, 4], dtype=tf.float32), [1, 2, 1])
     conv(z)  # Connect to create weights.
     wrapped = ibp.LinearConv1dWrapper(conv)

     interval_in = ibp.IntervalBounds(z - 1., z + 1.)
     symbolic_in = ibp.SymbolicBounds.convert(interval_in)
     symbolic_out = wrapped.propagate_bounds(symbolic_in)
     interval_out = ibp.IntervalBounds.convert(symbolic_out)

     with self.test_session() as sess:
         sess.run(tf.global_variables_initializer())
         lower, upper = sess.run([interval_out.lower, interval_out.upper])
         self.assertAlmostEqual(8., lower.item())
         self.assertAlmostEqual(12., upper.item())
Ejemplo n.º 13
0
    def __init__(self,
                 sampling_rate,
                 filter_size=3,
                 num_filters=32,
                 pooling_stride=2,
                 pool='avg',
                 act='elu',
                 name="classifier"):
        """Create the layers of a two-class classifier.

        Args:
            sampling_rate: not used in this constructor.
            filter_size: kernel size of the separable convolution; the first
                convolution uses `filter_size + 2`.
            num_filters: channels of the first convolution; the separable
                convolution uses twice as many.
            pooling_stride: not used in this constructor.
                NOTE(review): `Downsample1D` is built with a literal 2 --
                confirm whether this parameter was meant to be passed.
            pool: not used in this constructor.
            act: activation identifier forwarded to `Activation`.
            name: module name.
        """
        super(Classifier, self).__init__(name=name)

        num_classes = 2

        self._act = Activation(act, verbose=True)
        self._pool = Downsample1D(2)
        self._bf = snt.BatchFlatten()

        # Same L2 penalty on the weights and biases of both linear layers.
        l2_scale = 0.1
        regularizers = {
            "w": tf.contrib.layers.l2_regularizer(scale=l2_scale),
            "b": tf.contrib.layers.l2_regularizer(scale=l2_scale)
        }

        with self._enter_variable_scope():
            self._l1_conv = snt.Conv1D(output_channels=num_filters,
                                       kernel_shape=filter_size + 2)
            self._l2_sepconv = snt.SeparableConv1D(
                output_channels=num_filters << 1,
                channel_multiplier=1,
                kernel_shape=filter_size)
            self._lin1 = snt.Linear(256, regularizers=regularizers)
            self._lin2 = snt.Linear(num_classes, regularizers=regularizers)
 def _build(self, x):
     """Run two conv/batch-norm/ReLU stages over `x` and merge with `x`.

     Args:
         x: tensor described by the original author as [units, bs, 1].

     Returns:
         `x + result` when `self.add` is truthy, otherwise `x` and the result
         concatenated along axis 2.
     """
     # x is [units, bs, 1]
     channels = x.shape.as_list()[2]
     net = tf.transpose(x, [1, 0, 2])  # now [bs x units x 1]

     # Two identical stages; BatchNorm always runs with is_training=False.
     for _ in range(2):
         net = snt.Conv1D(output_channels=channels, kernel_shape=[3])(net)
         net = snt.BatchNorm(axis=[0, 1])(net, is_training=False)
         net = tf.nn.relu(net)

     # Back to the layout of x before merging.
     to_concat = tf.transpose(net, [1, 0, 2])
     if not self.add:
         return tf.concat([x, to_concat], 2)
     return x + to_concat
Ejemplo n.º 15
0
 def conv_layer(self, h, num_units, kernel_size, stride, name, padding=snt.SAME):
     """Apply a new channels-first ('NCW') `snt.Conv1D` to `h`.

     Args:
         h: input tensor.
         num_units: number of output channels.
         kernel_size: convolution kernel shape.
         stride: convolution stride.
         name: name of the convolution module.
         padding: padding mode, `snt.SAME` by default.

     Returns:
         The convolution output.
     """
     conv = snt.Conv1D(output_channels=num_units,
                       kernel_shape=kernel_size,
                       stride=stride,
                       padding=padding,
                       data_format='NCW',
                       name=name)
     return conv(h)
Ejemplo n.º 16
0
    def prep_vq(self, embedding_dim):
        """Return a kernel-1 convolution projecting to the VQ embedding size.

        Args:
            embedding_dim: not used; the output width is taken from
                `self.config.vqvae_embedding_dim`.
                NOTE(review): confirm whether this argument was meant to be
                used instead.

        Returns:
            An unconnected `snt.Conv1D` module named "linear_to_vq_dim".
        """
        return snt.Conv1D(output_channels=self.config.vqvae_embedding_dim,
                          kernel_shape=1,
                          stride=1,
                          name="linear_to_vq_dim")
Ejemplo n.º 17
0
    def conv_layer(self, h, num_units, kernel_shape, stride, name):
        """Apply a new channels-last ('NWC') `snt.Conv1D` to `h`.

        Args:
            h: input tensor.
            num_units: number of output channels.
            kernel_shape: convolution kernel shape.
            stride: convolution stride.
            name: name of the convolution module.

        Returns:
            The convolution output.
        """
        conv = snt.Conv1D(output_channels=num_units,
                          kernel_shape=kernel_shape,
                          stride=stride,
                          data_format='NWC',
                          name=name)
        return conv(h)
Ejemplo n.º 18
0
 def conv_block(filters,
                width=1,
                w_init=None,
                name='conv_block',
                **kwargs):
     """Return a `Sequential` of batch norm -> gelu -> Conv1D.

     Args:
         filters: output channels of the convolution.
         width: kernel size of the convolution.
         w_init: weight initializer forwarded to `snt.Conv1D`.
         name: name of the returned `Sequential`.
         **kwargs: extra keyword arguments forwarded to `snt.Conv1D`.
     """

     def build_layers():
         # Layers are created lazily, when Sequential invokes this callable.
         batch_norm = snt.BatchNorm(create_scale=True,
                                    create_offset=True,
                                    decay_rate=0.9,
                                    scale_init=snt.initializers.Ones())
         conv = snt.Conv1D(filters, width, w_init=w_init, **kwargs)
         return [batch_norm, gelu, conv]

     return Sequential(build_layers, name=name)
Ejemplo n.º 19
0
    def __init__(self,
                 params,
                 train_initial_state=True,
                 residual=False,
                 dense=False,
                 name=None):
        """Create the two gated convolutions and two linear layers.

        Args:
            params: keyword arguments forwarded to both `snt.Conv1D` modules;
                must provide `output_channels`, which also sizes the linear
                layers.
            train_initial_state: stored on the instance.
            residual: False, 'up' or 'down'.
            dense: stored on the instance; mutually exclusive with
                `residual`.
            name: module name.

        Raises:
            ValueError: if both `residual` and `dense` are truthy, or if
                `residual` is not one of the accepted values.
        """
        super(CausalConv1D, self).__init__(name=name)
        self._params = Struct.make(params)
        self._train_initial_state = train_initial_state
        self._input_channels = None

        if residual and dense:
            raise ValueError('Cannot have residual and dense connections!')
        if residual not in {False, 'up', 'down'}:
            raise ValueError('residual should one of [False, "up", "down"]')
        self._residual = residual
        self._dense = dense

        with self._enter_variable_scope():
            out_channels = self._params.output_channels
            conv_f = snt.Conv1D(name='conv_xf', padding=snt.VALID, **params)
            conv_g = snt.Conv1D(name='conv_xg', padding=snt.VALID, **params)
            lin_f = snt.Linear(out_channels, name='lin_zf')
            lin_g = snt.Linear(out_channels, name='lin_zg')
            self._cores = Struct(conv_f=conv_f,
                                 conv_g=conv_g,
                                 lin_f=lin_f,
                                 lin_g=lin_g)
Ejemplo n.º 20
0
    def conv_with_residual(self, h, num_units, kernel_shape, stride, name):
        """Add ReLU(conv(h)) onto `h` and return the sum.

        Args:
            h: input tensor; also the skip path.
            num_units: output channels of the convolution.
            kernel_shape: convolution kernel shape.
            stride: convolution stride.
            name: name of the convolution module.

        Returns:
            `h` plus the activated convolution output.
        """
        conv = snt.Conv1D(output_channels=num_units,
                          kernel_shape=kernel_shape,
                          stride=stride,
                          data_format='NWC',
                          name=name)
        activated = tf.nn.relu(conv(h))

        h += activated
        return h
Ejemplo n.º 21
0
 def func(name, data_format, custom_getter=None):
   """Return Conv1D followed by training-mode BatchNorm for `data_format`."""
   conv = snt.Conv1D(
       name=name,
       output_channels=self.OUT_CHANNELS,
       kernel_shape=self.KERNEL_SHAPE,
       use_bias=use_bias,
       initializers=create_initializers(use_bias),
       data_format=data_format,
       custom_getter=custom_getter)

   bn_kwargs = dict(scale=True, update_ops_collection=None)
   if data_format != "NWC":  # data_format = "NCW"
     # Channels-first: pass the reduction axes explicitly.
     bn_kwargs["axis"] = (0, 2)
   batch_norm = snt.BatchNorm(**bn_kwargs)

   training_batch_norm = functools.partial(batch_norm, is_training=True)
   return snt.Sequential([conv, training_batch_norm])
Ejemplo n.º 22
0
    def reduce_dimension_layer(self, inputs):
        '''
        Reduce `inputs` along the time axis using the method selected by
        `self._dim_reduction`.

        Args:
            inputs (tf.Tensor): inputs of shape (batch_size, signal_length, latent_channels)

        Returns:
            tf.Tensor: reduced inputs based on self._hop_length
                       with shape (batch_size, signal_length // self._hop_length, latent_channels)

        Raises:
            ValueError: if `self._dim_reduction` is not a recognized method.
        '''
        method = self._dim_reduction

        if method == DIM_REDUCTION_MAX_POOL:
            return pool1d(inputs, self._hop_length, name='ae_pool', mode='max')

        if method == DIM_REDUCTION_AVG_POOL:
            return pool1d(inputs, self._hop_length, name='ae_pool', mode='avg')

        if method == DIM_REDUCTION_CONV:
            # Strided convolution: kernel and stride both equal the hop.
            reducer = snt.Conv1D(output_channels=self._latent_channels,
                                 kernel_shape=self._hop_length,
                                 stride=self._hop_length,
                                 padding=snt.CAUSAL,
                                 name='dim_reduction_conv')
            return reducer(inputs)

        if method == DIM_REDUCTION_LINEAR:
            # Dense only acts on the last dimension, so move time last.
            channels_first = tf.transpose(inputs, perm=[0, 2, 1])
            # Transposing alone is not enough because of None dimensions, so
            # reshape to (bs, channels, signal//hop_len, hop_length) and apply
            # a 1-unit linear layer -> (bs, channels, signal//hop_len, 1)
            # kudos @Csongor
            windows = tf.reshape(
                channels_first,
                [tf.shape(inputs)[0], inputs.shape[-1], -1, self._hop_length])
            projected = tf.layers.Dense(units=1)(windows)
            squeezed = tf.squeeze(projected, axis=-1)
            return tf.transpose(squeezed, perm=[0, 2, 1])

        raise ValueError('Can\'t recognize type of dimensionality reduction method: \'{}\' '
                         '(change network architecture -> dim_reduction: <max_pool | avg_pool | conv | linear>'
                         .format(self._dim_reduction)
                         )
 def __init__(self,
              out_channel,
              dec_channel,
              num_res_blocks,
              residual_hiddens,
              the_stride,
              name='decoder'):
     """Build conv -> residual blocks -> transposed-conv upsampling.

     Args:
         out_channel: output channels of the last transposed convolution.
         dec_channel: channel width of the first convolution and of the
             residual blocks.
         num_res_blocks: number of `residual_block` modules.
         residual_hiddens: hidden channels of each residual block.
         the_stride: 4 adds two stride-2 transposed convolutions, 2 adds one;
             any other value adds no upsampling layer.
         name: accepted but not forwarded to the base class.
             NOTE(review): confirm whether it should be.
     """
     super(Decoder, self).__init__()

     def upsample(channels, layer_name):
         # Every upsampling layer is a kernel-4, stride-2 transposed conv.
         return snt.Conv1DTranspose(output_channels=channels,
                                    output_shape=None,
                                    kernel_shape=4,
                                    stride=2,
                                    name=layer_name)

     input_conv = snt.Conv1D(output_channels=dec_channel,
                             kernel_shape=3,
                             stride=1,
                             name="dec_0")
     self.model = snt.Sequential([input_conv, tf.nn.relu])

     # Each residual block wraps the model built so far in a new Sequential.
     for _ in range(num_res_blocks):
         block = residual_block(dec_channel, residual_hiddens)
         self.model = snt.Sequential([self.model, block])

     if the_stride == 4:
         tail = [
             upsample(dec_channel // 2, "dec_1"),
             tf.nn.relu,
             upsample(out_channel, "dec_2"),
         ]
     elif the_stride == 2:
         tail = [upsample(out_channel, "dec_1")]
     else:
         tail = []

     if tail:
         self.model = snt.Sequential([self.model] + tail)
Ejemplo n.º 24
0
  def _build(self, padded_word_embeddings, length):
    """Encode padded word embeddings into a fixed-size vector.

    Args:
      padded_word_embeddings: input tensor fed to the first layer.
      length: sequence lengths, used as the divisor for 'average' pooling.

    Returns:
      The tensor after the configured conv layers, the pooling over axis 1
      and the optional fully connected layers.

    Raises:
      RuntimeError: if `conv_architecture` contains an unknown layer entry.
    """

    def dropout_if_enabled(tensor):
      if self._keep_prob < 1:
        return tf.nn.dropout(tensor, keep_prob=self._keep_prob)
      return tensor

    x = padded_word_embeddings
    for layer in self._config['conv_architecture']:
      if isinstance(layer, (tuple, list)):
        # (filters, kernel_size, pooling_size) describes a conv layer.
        filters, kernel_size, pooling_size = layer
        x = snt.Conv1D(output_channels=filters, kernel_shape=kernel_size)(x)
        if pooling_size and pooling_size > 1:
          x = _max_pool_1d(x, pooling_size)
      elif layer == 'relu':
        x = dropout_if_enabled(tf.nn.relu(x))
      else:
        raise RuntimeError('Bad layer type {} in conv'.format(layer))

    # Final layer pools over the remaining sequence length to get a
    # fixed sized vector.
    if self._pooling == 'max':
      x = tf.reduce_max(x, axis=1)
    elif self._pooling == 'average':
      summed = tf.reduce_sum(x, axis=1)
      x = summed / tf.expand_dims(tf.cast(length, tf.float32), axis=1)

    # Optional fully connected layers, each followed by ReLU and dropout.
    for config_key in ('conv_fc1', 'conv_fc2'):
      if self._config[config_key]:
        fc_layer = snt.Linear(output_size=self._config[config_key])
        x = dropout_if_enabled(tf.nn.relu(fc_layer(x)))

    return x
Ejemplo n.º 25
0
    def __init__(self,
                 output_channels: int,
                 kernel_shape: int,
                 stride: int = 1,
                 name: str = 'conv_1d_periodic'):
        """Constructs Conv1dPeriodic module.

        Stores the configuration and creates the wrapped `snt.Conv1D`, which
        uses VALID padding.

        Args:
          output_channels: Number of channels in convolution.
          kernel_shape: Convolution kernel sizes.
          stride: Convolution stride.
          name: Name of the module.
        """
        super(Conv1dPeriodic, self).__init__(name=name)

        self._output_channels = output_channels
        self._kernel_shape = kernel_shape
        self._stride = stride

        with self._enter_variable_scope():
            conv_kwargs = dict(output_channels=self._output_channels,
                               kernel_shape=self._kernel_shape,
                               stride=self._stride,
                               padding=snt.VALID)
            self._conv_1d_module = snt.Conv1D(**conv_kwargs)
    def _build(self, x_shifted, en, conditioning=None):
        """Build the causal dilated-convolution decoder.

        A causal start convolution is followed by `self._num_layers` gated
        residual blocks with skip connections; the accumulated skip signal is
        turned into 256-way logits.

        Args:
            x_shifted: decoder input fed to the start convolution
                (presumably the right-shifted signal -- TODO confirm with the
                caller).
            en: encoding passed through kernel-1 convolutions and combined
                with the decoder activations via `condition`.
            conditioning: not used in this method.

        Returns:
            Tuple `(dec_distr, reconstruction)`: the `tfd.Categorical` over
            the logits, and `inv_mu_law` applied to its mode shifted by -128.
            Both are also stored on `self`, along with `self.logits` and
            `self._nodes_list`.
        """

        self._nodes_list = []

        startconv = snt.Conv1D(output_channels=self._d_hidden_channels,
                               kernel_shape=self._filter_length,
                               padding=snt.CAUSAL,
                               name='startconv')

        skip_start = snt.Conv1D(output_channels=self._skip_channels,
                                kernel_shape=1,
                                padding=snt.CAUSAL,
                                name='skip_start')

        # Bug fix: the dropped-out tensor used to be discarded because the
        # start convolution was always applied to the raw `x_shifted`.
        # NOTE(review): dropout is applied whenever the rate is > 0; whether
        # the rate is set to 0 at evaluation time is decided by the caller.
        if self._prob_dropout_decoder_tf > 0:
            x_shifted = tf.nn.dropout(x_shifted,
                                      rate=self._prob_dropout_decoder_tf)
        l = startconv(x_shifted)

        self._nodes_list.append(l)

        # Set up skip connections.
        s = skip_start(l)
        self._nodes_list.append(s)

        # Residual blocks with skip connections.
        for i in range(self._num_layers):
            # Dilation cycles 1, 2, 4, ... within each stage.
            dilation = 2**(i % self._num_layers_per_stage)

            causal_convolution = snt.Conv1D(output_channels=2 *
                                            self._d_hidden_channels,
                                            kernel_shape=[self._filter_length],
                                            rate=dilation,
                                            padding=snt.CAUSAL,
                                            name='dilatedconv_%d' % (i + 1))

            dil = causal_convolution(l)
            self._nodes_list.append(dil)

            encoding_convolution = snt.Conv1D(output_channels=2 *
                                              self._d_hidden_channels,
                                              kernel_shape=1,
                                              padding=snt.CAUSAL,
                                              name='cond_map_%d' % (i + 1))

            enc = encoding_convolution(en)
            dil = condition(dil, enc)
            self._nodes_list.append(dil)

            # Gated activation: first half of the channels goes through a
            # sigmoid, second half through tanh, and the two are multiplied.
            assert dil.get_shape().as_list()[2] % 2 == 0
            m = dil.get_shape().as_list()[2] // 2
            d_sigmoid = tf.sigmoid(dil[:, :, :m])
            d_tanh = tf.tanh(dil[:, :, m:])
            dil = d_sigmoid * d_tanh

            res_convolution = snt.Conv1D(
                output_channels=self._d_hidden_channels,
                kernel_shape=1,
                padding=snt.CAUSAL,
                name='res_%d' % (i + 1))

            l += res_convolution(dil)
            self._nodes_list.append(l)

            skip_conv = snt.Conv1D(output_channels=self._skip_channels,
                                   kernel_shape=1,
                                   padding=snt.CAUSAL,
                                   name='skip_%d' % (i + 1))

            s += skip_conv(dil)
            self._nodes_list.append(s)

        s = tf.nn.relu(s)

        out = snt.Conv1D(output_channels=self._skip_channels,
                         kernel_shape=1,
                         padding=snt.CAUSAL,
                         name='out1')

        s = out(s)
        self._nodes_list.append(s)

        cond_map_out = snt.Conv1D(output_channels=self._skip_channels,
                                  kernel_shape=1,
                                  padding=snt.CAUSAL,
                                  name='cond_map_out1')

        s = condition(s, cond_map_out(en))

        s = tf.nn.relu(s)
        self._nodes_list.append(s)

        logits_conv = snt.Conv1D(output_channels=256,
                                 kernel_shape=1,
                                 padding=snt.CAUSAL,
                                 name='logits')

        self.logits = logits_conv(s)

        self.dec_distr = tfd.Categorical(logits=self.logits)
        # dimension of rec_sample should be: [bs, length, 1]

        # The mode is a class index in [0, 255]; shift by 128 before
        # inverting the mu-law encoding.
        self.reconstruction = inv_mu_law(
            tf.expand_dims(self.dec_distr.mode(), axis=-1) - 128)

        return self.dec_distr, self.reconstruction
Ejemplo n.º 27
0
    def __init__(self,
                 channels: int = 1536,
                 num_transformer_layers: int = 11,
                 num_heads: int = 8,
                 pooling_type: str = 'attention',
                 name: str = 'enformer'):
        """Enformer model.

        Builds the trunk (stem, convolution tower, transformer stack, crop and
        final pointwise block) and one output head for each of the 'human' and
        'mouse' entries.

        Args:
          channels: Number of convolutional filters and the overall 'width' of
            the model. Must be divisible by `num_heads`.
          num_transformer_layers: Number of transformer layers.
          num_heads: Number of attention heads.
          pooling_type: Which pooling function to use. Options: 'attention' or
            'max'.
          name: Name of sonnet module.

        Raises:
          AssertionError: If `channels` is not divisible by `num_heads`.
        """
        super().__init__(name=name)
        # pylint: disable=g-complex-comprehension,g-long-lambda,cell-var-from-loop
        heads_channels = {'human': 5313, 'mouse': 1643}
        dropout_rate = 0.4
        assert channels % num_heads == 0, ('channels needs to be divisible '
                                           f'by {num_heads}')
        whole_attention_kwargs = {
            'attention_dropout_rate': 0.05,
            'initializer': None,
            'key_size': 64,
            'num_heads': num_heads,
            'num_relative_position_features': channels // num_heads,
            'positional_dropout_rate': 0.01,
            'relative_position_functions': [
                'positional_features_exponential',
                'positional_features_central_mask',
                'positional_features_gamma',
            ],
            'relative_positions': True,
            'scaling': True,
            'value_size': channels // num_heads,
            'zero_initialize': True,
        }

        # Enter the scope through `with` so that it is always exited, even when
        # building one of the trunk sub-modules raises.
        with tf.name_scope('trunk'):
            # lambda is used in Sequential to construct the module under
            # tf.name_scope.
            # NOTE(review): the lambdas nested in comprehensions below read the
            # loop variables late; this presumably relies on Sequential calling
            # the lambda right away -- confirm against Sequential.
            def conv_block(filters,
                           width=1,
                           w_init=None,
                           name='conv_block',
                           **kwargs):
                """BatchNorm -> gelu -> Conv1D(filters, width) block."""
                return Sequential(
                    lambda: [
                        snt.BatchNorm(create_scale=True,
                                      create_offset=True,
                                      decay_rate=0.9,
                                      scale_init=snt.initializers.Ones()),
                        gelu,
                        snt.Conv1D(filters, width, w_init=w_init, **kwargs),
                    ],
                    name=name)

            stem = Sequential(
                lambda: [
                    snt.Conv1D(channels // 2, 15),
                    Residual(
                        conv_block(channels // 2,
                                   1,
                                   name='pointwise_conv_block')),
                    pooling_module(pooling_type, pool_size=2),
                ],
                name='stem')

            filter_list = exponential_linspace_int(start=channels // 2,
                                                   end=channels,
                                                   num=6,
                                                   divisible_by=128)
            conv_tower = Sequential(
                lambda: [
                    Sequential(
                        lambda: [
                            conv_block(num_filters, 5),
                            Residual(
                                conv_block(num_filters,
                                           1,
                                           name='pointwise_conv_block')),
                            pooling_module(pooling_type, pool_size=2),
                        ],
                        name=f'conv_tower_block_{i}')
                    for i, num_filters in enumerate(filter_list)
                ],
                name='conv_tower')

            # Transformer.
            def transformer_mlp():
                """LayerNorm followed by a two-layer MLP with dropout."""
                return Sequential(
                    lambda: [
                        snt.LayerNorm(axis=-1,
                                      create_scale=True,
                                      create_offset=True),
                        snt.Linear(channels * 2),
                        snt.Dropout(dropout_rate),
                        tf.nn.relu,
                        snt.Linear(channels),
                        snt.Dropout(dropout_rate),
                    ],
                    name='mlp')

            transformer = Sequential(
                lambda: [
                    Sequential(
                        lambda: [
                            Residual(
                                Sequential(
                                    lambda: [
                                        snt.LayerNorm(
                                            axis=-1,
                                            create_scale=True,
                                            create_offset=True,
                                            scale_init=snt.initializers.Ones()),
                                        attention_module.MultiheadAttention(
                                            **whole_attention_kwargs,
                                            name=f'attention_{i}'),
                                        snt.Dropout(dropout_rate),
                                    ],
                                    name='mha')),
                            Residual(transformer_mlp()),
                        ],
                        name=f'transformer_block_{i}')
                    for i in range(num_transformer_layers)
                ],
                name='transformer')

            crop_final = TargetLengthCrop1D(TARGET_LENGTH, name='target_input')

            final_pointwise = Sequential(
                lambda: [
                    conv_block(channels * 2, 1),
                    snt.Dropout(dropout_rate / 8),
                    gelu,
                ],
                name='final_pointwise')

            self._trunk = Sequential(
                [stem, conv_tower, transformer, crop_final, final_pointwise],
                name='trunk')

        with tf.name_scope('heads'):
            self._heads = {
                head: Sequential(
                    lambda: [snt.Linear(num_channels), tf.nn.softplus],
                    name=f'head_{head}')
                for head, num_channels in heads_channels.items()
            }
    def compute_h(self,
                  x,
                  z,
                  d,
                  bias,
                  W_bot,
                  W_top,
                  compute_perc=1.0,
                  compute_units=None):
        """Compute per-unit hidden features with a stack of 1-D convolutions.

        Shapes as noted by the original author: x, z = [BS, n_units] and
        d = [BS, n_units, delta_channels].

        Args:
          x: Rank-2 tensor; its static shape supplies the batch size (axis 0)
            and the number of units (axis 1). The batch size must be known
            statically.
          z: Rank-2 tensor, transposed alongside `x`.
          d: Rank-3 tensor whose first two axes are swapped before use.
          bias: Tensor that is transposed and flattened to one value per
            leading entry (presumably one bias per unit -- TODO confirm).
          W_bot: Weight matrix passed to `get_weight_stats(..., 0)`; may be
            None.
          W_top: Weight matrix passed to `get_weight_stats(..., 1)`; may be
            None.
          compute_perc: Fraction of units to process. When not 1.0,
            `compute_units` is derived from it and must be left as None.
          compute_units: Number of randomly selected units to process, or None
            to process every unit.

        Returns:
          The final activations transposed to [bs, units, channels]. When only
          a subset of units is processed the results are scattered back to the
          full unit axis with `tf.scatter_nd`.

        Raises:
          AssertionError: If both `compute_perc != 1.0` and `compute_units` are
            given, or if the batch size of `x` is not statically known.
        """
        # `compute_perc` and `compute_units` are two exclusive ways of asking
        # for a subset of units.
        if compute_perc != 1.0:
            assert compute_units is None

        with tf.device(self.remote_device):
            inp_feat = [x, z]
            inp_feat = [tf.transpose(f, [1, 0]) for f in inp_feat]

            units = x.shape.as_list()[1]
            bs = x.shape.as_list()[0]

            # add unit ID, to help the network differentiate units
            id_theta = tf.linspace(0., (4) * np.pi, units)
            assert bs is not None
            id_theta_bs = tf.reshape(id_theta, [-1, 1]) * tf.ones([1, bs])
            inp_feat += [tf.sin(id_theta_bs), tf.cos(id_theta_bs)]

            # list of [units, BS, 1]
            inp_feat = [tf.expand_dims(f, 2) for f in inp_feat]

            d_trans = tf.transpose(d, [1, 0, 2])

            if compute_perc != 1.0:
                # `inp_feat` is a Python list, so it has no `.shape`; every
                # entry has `units` as its leading dimension.
                compute_units = int(compute_perc * units)

            # add weight matrix statistics, both from above and below
            w_stats_bot = get_weight_stats(W_bot, 0)
            w_stats_top = get_weight_stats(W_top, 1)
            w_stats = w_stats_bot + w_stats_top
            if W_bot is None or W_top is None:
                # if it's an edge layer (top or bottom), just duplicate the stats for
                # the weight matrix that does exist
                w_stats = w_stats + w_stats
            # Broadcast every statistic along axis 1 to the batch size.
            # Original author's note: w_stats is a list, with entries with
            # shape UNITS x 1 x channels (presumably before this broadcast).
            w_stats = [tf.ones([1, x.shape[0], 1]) * ww for ww in w_stats]

            if compute_units is None:
                inp_feat_in = inp_feat
                d_trans_in = d_trans
                w_stats_in = w_stats
                bias_in = tf.transpose(bias)
            else:
                # only run on a subset of the activations: draw one uniform
                # random score per unit and keep the `compute_units` highest.
                mask = tf.random_uniform(
                    minval=0,
                    maxval=1,
                    dtype=tf.float32,
                    shape=inp_feat[0].shape.as_list()[0:1])
                _, ind = tf.nn.top_k(mask, k=compute_units)
                ind = tf.reshape(ind, [-1, 1])

                inp_feat_in = [tf.gather_nd(xx, ind) for xx in inp_feat]
                w_stats_in = [tf.gather_nd(xx, ind) for xx in w_stats]
                d_trans_in = tf.gather_nd(d_trans, ind)
                bias_in = tf.gather_nd(tf.transpose(bias), ind)

            # Rescale the weight statistics by their global RMS.
            w_stats_in = tf.concat(w_stats_in, 2)
            w_stats_in_norm = w_stats_in * tf.rsqrt(
                tf.reduce_mean(w_stats_in**2) + 1e-6)

            act = tf.concat(inp_feat_in + [d_trans_in], 2)
            act = snt.BatchNorm(axis=[0, 1])(act, is_training=True)

            bias_dense = tf.reshape(bias_in, [-1, 1, 1]) * tf.ones([1, bs, 1])
            act = tf.concat([w_stats_in_norm, bias_dense, act], 2)

            mod = snt.Conv1D(output_channels=self.compute_h_size,
                             kernel_shape=[3])
            act = mod(act)

            act = snt.BatchNorm(axis=[0, 1])(act, is_training=True)
            act = tf.nn.relu(act)

            act = ConcatUnitConv()(act)

            # Additional conv -> batch norm -> relu -> ConcatUnitConv layers.
            for _ in range(self.compute_h_layers):
                mod = snt.Conv1D(output_channels=self.compute_h_size,
                                 kernel_shape=[3])
                act = mod(act)

                act = snt.BatchNorm(axis=[0, 1])(act, is_training=True)
                act = tf.nn.relu(act)

                act = ConcatUnitConv()(act)

            h = act
            if compute_units is not None:
                # Put the processed subset back at its original unit indices.
                shape = inp_feat[0].shape.as_list()[:1] + h.shape.as_list()[1:]
                h = tf.scatter_nd(ind, h, shape=shape)

            h = tf.transpose(h, [1, 0, 2])  # [bs, units, channels]

            return h