def __init__(self, main_channel, num_res_blocks, residual_hiddens, the_stride):
    """Builds the strided convolutional encoder followed by residual stacks.

    Args:
      main_channel: Output channels of every encoder convolution and of the
        last convolution inside each `residual_block`.
      num_res_blocks: How many `residual_block` stacks are appended.
      residual_hiddens: Hidden channel count forwarded to `residual_block`.
      the_stride: Overall downsampling factor. 4 uses two stride-2
        convolutions, 2 uses a single stride-2 convolution.

    Raises:
      ValueError: If `the_stride` is neither 2 nor 4. Previously this case
        left `self.model` undefined and failed later with an AttributeError.
    """
    super(Encoder, self).__init__()
    if the_stride == 4:
        self.model = snt.Sequential([
            snt.Conv1D(output_channels=main_channel, kernel_shape=4,
                       stride=2, name="enc_0"),
            tf.nn.relu,
            snt.Conv1D(output_channels=main_channel, kernel_shape=4,
                       stride=2, name="enc_1"),
            tf.nn.relu,
            snt.Conv1D(output_channels=main_channel, kernel_shape=3,
                       stride=1, name="enc_2"),
            tf.nn.relu
        ])
    elif the_stride == 2:
        self.model = snt.Sequential([
            snt.Conv1D(output_channels=main_channel, kernel_shape=4,
                       stride=2, name="enc_0"),
            tf.nn.relu,
            snt.Conv1D(output_channels=main_channel, kernel_shape=3,
                       stride=1, name="enc_1"),
            tf.nn.relu
        ])
    else:
        raise ValueError(
            "Encoder only supports the_stride of 2 or 4, got {!r}".format(
                the_stride))

    # Each iteration wraps the model built so far, so the blocks are applied
    # in the order they are added.
    for _ in range(num_res_blocks):
        self.model = snt.Sequential(
            [self.model, residual_block(main_channel, residual_hiddens)])
def variational_layer(self):
    """Creates the latent posterior and the prior and stores them on `self`.

    When `self._learn_prior` is not `LEARN_PRIOR_CONV`, the posterior is
    `get_gaussian(self.en)` and the prior comes from `self._create_prior()`.
    Otherwise both distributions are produced by convolutions over the
    right-shifted encoding and the scaled quantized input.
    """
    if self._learn_prior != LEARN_PRIOR_CONV:
        self._gaussian_model_latent = self.get_gaussian(self.en)
        self._create_prior()
        return

    # POSTERIOR: 1x1 conv over the shifted encoding, concatenated with a
    # hop-strided conv over the input.
    self.en_shifted = shift_right(self.en)
    posterior_conv = snt.Conv1D(output_channels=self._latent_channels,
                                kernel_shape=1,
                                stride=1,
                                padding=snt.CAUSAL,
                                name='conv_en_posterior')
    self.en_posterior = posterior_conv(self.en_shifted)

    hop_conv = snt.Conv1D(output_channels=self._latent_channels,
                          kernel_shape=self._hop_length,
                          stride=self._hop_length,
                          padding=snt.CAUSAL,
                          name='conv_en_input')
    self.input_conv = hop_conv(self.x_quant_scaled)

    self.en_posterior = tf.concat([self.en_posterior, self.input_conv],
                                  axis=-1)
    self._gaussian_model_latent = self.get_gaussian(self.en_posterior)

    # PRIOR: 1x1 conv over the shifted input, average-pooled by the hop length.
    prior_conv = snt.Conv1D(output_channels=self._latent_channels,
                            kernel_shape=1,
                            stride=1,
                            padding=snt.CAUSAL,
                            name='conv_en_prior_from_input')
    self.en_prior = prior_conv(shift_right(self.x_quant_scaled))
    self.en_prior = pool1d(self.en_prior, self._hop_length,
                           name='ae_pool_prior', mode='avg')
    self._prior = self.get_gaussian(self.en_prior)
def residual_block(main_channel, residual_hiddens):
    """Returns a conv(3) -> ReLU -> conv(1) -> ReLU `snt.Sequential`.

    Args:
      main_channel: Output channels of the final 1x1 convolution.
      residual_hiddens: Output channels of the first, width-3 convolution.

    Returns:
      The `snt.Sequential` holding the four layers.
    """
    # NOTE(review): despite the name, no skip connection is added here; the
    # block is a plain feed-forward stack. Confirm whether that is intended.
    hidden_conv = snt.Conv1D(output_channels=residual_hiddens,
                             kernel_shape=3, stride=1)
    output_conv = snt.Conv1D(output_channels=main_channel,
                             kernel_shape=1, stride=1)
    layers = [hidden_conv, tf.nn.relu, output_conv, tf.nn.relu]
    return snt.Sequential(layers)
def pointwise(x, is_training=False):
    """Applies two 1x1 convolutions, each followed by dropout and ReLU.

    Args:
      x: Input tensor passed to the first `snt.Conv1D`.
      is_training: Forwarded to `tf.layers.dropout` as `training`. That
        function defaults to `training=False`, which makes it the identity;
        the original code never passed the flag, so its dropout calls had no
        effect. The default of False keeps that behaviour for existing
        callers, and passing True actually enables dropout.

    Returns:
      The ReLU-activated output of the second convolution.
    """
    use_dropout = self.dropout_rate > 0.0

    hidden = snt.Conv1D(output_channels=output_size, kernel_shape=1)(x)
    if use_dropout:
        hidden = tf.layers.dropout(hidden, self.dropout_rate,
                                   training=is_training)
    hidden = tf.nn.relu(hidden)

    outputs = snt.Conv1D(output_channels=output_size, kernel_shape=1)(hidden)
    if use_dropout:
        outputs = tf.layers.dropout(outputs, self.dropout_rate,
                                    training=is_training)
    return tf.nn.relu(outputs)
def __init__(self, num_filters=32, filter_size=5, act='', name="adaptor"):
    """Creates the sub-modules of the adaptor.

    Args:
      num_filters: Output channels of the first convolution; the second and
        third convolutions use twice and four times as many.
      filter_size: Kernel width of the second convolution; the first uses
        `filter_size + 2` and the third `filter_size - 2`.
      act: Activation spec handed to `Activation`.
      name: Name of the module.
    """
    super(Adaptor, self).__init__(name=name)

    self._bf = snt.BatchFlatten()
    self._pool = Downsample1D(2)
    self._act = Activation(act, verbose=True)

    with self._enter_variable_scope():
        # Channels double per layer while the kernel shrinks by two.
        self._l1_conv = snt.Conv1D(output_channels=num_filters,
                                   kernel_shape=filter_size + 2)
        self._l2_conv = snt.Conv1D(output_channels=num_filters << 1,
                                   kernel_shape=filter_size)
        self._l3_conv = snt.Conv1D(output_channels=num_filters << 2,
                                   kernel_shape=filter_size - 2)
def __init__(
    self,
    in_channel=1,
    main_channel=32,
    num_res_blocks=2,
    residual_hiddens=8,
    embed_dim=16,
    n_embed=256,
    decay=0.99,
    commitment_cost=0.25,
):
    """Creates the two-level (top / bottom) encoder, quantizer and decoder.

    Args:
      in_channel: Output channels of the final decoder `self.dec`.
      main_channel: Channel width shared by encoders and decoders.
      num_res_blocks: Residual stacks per encoder / decoder.
      residual_hiddens: Hidden channels inside each residual stack.
      embed_dim: Embedding dimension of both vector quantizers.
      n_embed: Number of codebook entries of both vector quantizers.
      decay: EMA decay of both vector quantizers.
      commitment_cost: Commitment cost of both vector quantizers.
    """
    super(VAEModel, self).__init__()

    # Arguments shared by every Encoder / Decoder and by both quantizers.
    stack_args = (main_channel, num_res_blocks, residual_hiddens)
    vq_kwargs = dict(embedding_dim=embed_dim,
                     num_embeddings=n_embed,
                     commitment_cost=commitment_cost,
                     decay=decay)

    self.enc_b = Encoder(*stack_args, the_stride=4)
    self.enc_t = Encoder(*stack_args, the_stride=2)

    # NOTE(review): both quantize convs are named "enc_1"; kept as-is because
    # renaming would change variable names. Confirm this is intended.
    self.quantize_conv_t = snt.Conv1D(output_channels=embed_dim,
                                      kernel_shape=1, stride=1, name="enc_1")
    self.vq_t = snt.nets.VectorQuantizerEMA(**vq_kwargs)
    self.dec_t = Decoder(embed_dim, *stack_args, the_stride=2)

    self.quantize_conv_b = snt.Conv1D(output_channels=embed_dim,
                                      kernel_shape=1, stride=1, name="enc_1")
    self.vq_b = snt.nets.VectorQuantizerEMA(**vq_kwargs)

    self.upsample_t = snt.Conv1DTranspose(output_channels=embed_dim,
                                          output_shape=None,
                                          kernel_shape=4,
                                          stride=2,
                                          name="up_1")
    self.dec = Decoder(in_channel, *stack_args, the_stride=4)
def compute_top_delta(self, z):
    """Parameterization of topD.

    Converts the top level activation to an error signal.

    Args:
      z: tf.Tensor, batch of final layer post activations. It is transposed
        with `[1, 0]`, so it must be rank 2.

    Returns:
      delta: tf.Tensor, the error signal. Per the trailing comment in the
        code, laid out as [bs, feature_channels, delta_channels].
    """
    with tf.variable_scope('compute_top_delta'), tf.device(
            self.remote_device):
        # typically this takes [BS, length, input_channels],
        # We are applying this such that we convolve over the batch dimension.
        act = tf.expand_dims(tf.transpose(z, [1, 0]), 2)  # [channels, BS, 1]

        mod = snt.Conv1D(output_channels=self.top_delta_size, kernel_shape=[5])
        act = mod(act)
        act = snt.BatchNorm(axis=[0, 1])(act, is_training=False)
        act = tf.nn.relu(act)

        # The leading axis is the channel axis of the transposed input, not
        # the batch size (the original local was misleadingly called `bs`).
        num_channels = act.shape.as_list()[0]

        # Swap first and last axes so the next two convs mix across channels.
        act = tf.transpose(act, [2, 1, 0])
        act = snt.Conv1D(output_channels=num_channels, kernel_shape=[3])(act)
        act = snt.BatchNorm(axis=[0, 1])(act, is_training=False)
        act = tf.nn.relu(act)
        act = snt.Conv1D(output_channels=num_channels, kernel_shape=[3])(act)
        act = snt.BatchNorm(axis=[0, 1])(act, is_training=False)
        act = tf.nn.relu(act)
        act = tf.transpose(act, [2, 1, 0])

        for _ in range(self.top_delta_layers):
            mod = snt.Conv1D(output_channels=self.top_delta_size,
                             kernel_shape=[3])
            act = mod(act)
            act = snt.BatchNorm(axis=[0, 1])(act, is_training=False)
            act = tf.nn.relu(act)

        mod = snt.Conv1D(output_channels=self.delta_dim, kernel_shape=[3])
        act = mod(act)

        # [bs, feature_channels, delta_channels]
        act = tf.transpose(act, [1, 0, 2])
        return act
def _build(self, inputs):
    """Builds the dilated-convolution temporal encoder.

    Args:
      inputs: input node (already preprocessed).

    Returns:
      The output of the final 1x1 `ae_bottleneck` convolution, which has
      `self._latent_channels` output channels.
    """
    # https://deepmind.github.io/sonnet/_modules/sonnet/python/modules/conv.html#Conv1D

    ###
    # The Non-Causal Temporal Encoder.
    ###
    hidden_channels = self._e_hidden_channels
    filter_length = self._filter_length

    start_conv = snt.Conv1D(output_channels=hidden_channels,
                            kernel_shape=filter_length,
                            name='ae_startconv')
    en = start_conv(inputs)

    for layer_idx in range(self._num_layers):
        layer_no = layer_idx + 1
        # The dilation restarts at 1 at the beginning of every stage.
        dilation = 2**(layer_idx % self._num_layers_per_stage)

        dilated_conv = snt.Conv1D(output_channels=hidden_channels,
                                  kernel_shape=filter_length,
                                  rate=dilation,
                                  name='ae_dilatedconv_%d' % layer_no)
        branch = tf.nn.relu(dilated_conv(tf.nn.relu(en)))

        res_conv = snt.Conv1D(output_channels=hidden_channels,
                              kernel_shape=1,
                              padding=snt.CAUSAL,
                              name='ae_res_%d' % layer_no)
        en += res_conv(branch)

    bottleneck = snt.Conv1D(output_channels=self._latent_channels,
                            kernel_shape=1,
                            padding=snt.CAUSAL,
                            name='ae_bottleneck')
    return bottleneck(en)
def _build(self, inputs, is_training=True):
    """Encodes a feature vector with four strided conv layers and a linear map.

    Args:
      inputs: Tensor of shape (batch_size, feature_length), per the comment
        in the original code.
      is_training: Forwarded to every `snt.BatchNorm`.

    Returns:
      Output of the final `snt.Linear`, whose size is
      `EncodeProcessDecode_v1.dimensions_latent_repr`.
    """
    num_conv_layers = 4
    use_pooling = EncodeProcessDecode_v1.convnet_pooling
    activation = (tf.nn.tanh if EncodeProcessDecode_v1.convnet_tanh
                  else tf.nn.relu)

    # input shape is (batch_size, feature_length) but CNN operates on depth
    # channels --> (batch_size, feature_length, 1)
    outputs = tf.expand_dims(inputs, axis=2)

    # Layers 1-4 are identical: conv -> batch norm -> optional max-pool -> act.
    for _ in range(num_conv_layers):
        outputs = snt.Conv1D(output_channels=12, kernel_shape=10,
                             stride=2)(outputs)
        # TODO: deal with train/test time
        outputs = snt.BatchNorm()(outputs, is_training=is_training)
        if use_pooling:
            outputs = tf.layers.max_pooling1d(outputs, 2, 2)
        outputs = activation(outputs)

    # Layer 5: flatten and project to the latent representation.
    flattened = snt.BatchFlatten()(outputs)
    projection = snt.Linear(
        output_size=EncodeProcessDecode_v1.dimensions_latent_repr)
    return projection(flattened)
def __init__(
    self,
    enc_conf,
    create_scale=False,
    create_offset=False,
    name="ScalarConv1D",
):
    """Creates the conv / max-pool pairs and the trailing MLP.

    Args:
      enc_conf: Sequence whose last element configures the MLP and whose
        other elements are `(conv_conf, pool_conf)` pairs. `conv_conf` is
        indexed as (output_channels, kernel_shape, stride, padding) and
        `pool_conf` as (ksize, strides, padding).
      create_scale: Forwarded to `make_norm_mlp_model`.
      create_offset: Forwarded to `make_norm_mlp_model`.
      name: Name of the module.
    """
    super(ScalarConv1D, self).__init__(name=name)

    self._pools = []
    self._convs = []
    layer_confs = enc_conf[:-1]
    for conv_conf, pool_conf in layer_confs:
        pool_kwargs = dict(ksize=pool_conf[0],
                           strides=pool_conf[1],
                           padding=pool_conf[2])
        self._pools.append(partial(tf.nn.max_pool1d, **pool_kwargs))

        conv_kwargs = dict(output_channels=conv_conf[0],
                           kernel_shape=conv_conf[1],
                           stride=conv_conf[2],
                           padding=conv_conf[3])
        self._convs.append(snt.Conv1D(**conv_kwargs))

    mlp_conf = enc_conf[-1]
    self._mlp = make_norm_mlp_model(mlp_conf, create_scale, create_offset)
def __init__(self, filter_width, num_hidden, pool_type='fo', zone_out=0.0,
             name="quasi_rnn"):
    """Constructs a quasi_rnn.

    Args:
      filter_width: Filter width of the convolutional component.
      num_hidden: Number of hidden units in each RNN layer.
      pool_type: Types of pool component. (f-pooling, fo-pooling,
        ifo-pooling)
      zone_out: The dropout rate.
      name: Name of the module.

    Raises:
      ValueError: If `pool_type` is not one of 'f', 'fo' or 'ifo'.
    """
    super(QuasiRNN, self).__init__(name=name)
    self._filter_width = filter_width
    self._num_hidden = num_hidden

    # Explicit check rather than `assert`, which is stripped under `python -O`.
    valid_pool_types = ('f', 'fo', 'ifo')
    if pool_type not in valid_pool_types:
        raise ValueError(
            "pool_type must be one of {}, got {!r}".format(
                valid_pool_types, pool_type))
    self._pool_type = pool_type
    self._zone_out = zone_out

    with self._enter_variable_scope():
        # One group of `num_hidden` channels per letter in `pool_type`,
        # plus one extra group.
        conv_channels = self._num_hidden * (len(self._pool_type) + 1)
        self._cnn_component = snt.Conv1D(output_channels=conv_channels,
                                         kernel_shape=self._filter_width,
                                         padding="VALID",
                                         name="cnn_component")
        self._pooling_component = RecurrentPooling(self._num_hidden,
                                                   self._pool_type)
def testConv1dSymbolicBounds(self):
    """Symbolic bounds through a width-2 VALID conv match the hand values.

    With all-ones weights and bias 3, the inputs [3, 4] widened by +/-1 give
    lower = 2 + 3 + 3 = 8 and upper = 4 + 5 + 3 = 12.
    """
    constant_init = {
        'w': tf.constant_initializer(1.),
        'b': tf.constant_initializer(3.),
    }
    conv = snt.Conv1D(output_channels=1,
                      kernel_shape=2,
                      padding='VALID',
                      stride=1,
                      use_bias=True,
                      initializers=constant_init)

    inputs = tf.reshape(tf.constant([3, 4], dtype=tf.float32), [1, 2, 1])
    conv(inputs)  # Connect to create weights.
    wrapped = ibp.LinearConv1dWrapper(conv)

    interval_in = ibp.IntervalBounds(inputs - 1., inputs + 1.)
    symbolic_in = ibp.SymbolicBounds.convert(interval_in)
    symbolic_out = wrapped.propagate_bounds(symbolic_in)
    interval_out = ibp.IntervalBounds.convert(symbolic_out)

    with self.test_session() as sess:
        sess.run(tf.global_variables_initializer())
        lower, upper = sess.run([interval_out.lower, interval_out.upper])
        self.assertAlmostEqual(8., lower.item())
        self.assertAlmostEqual(12., upper.item())
def __init__(self, sampling_rate, filter_size=3, num_filters=32,
             pooling_stride=2, pool='avg', act='elu', name="classifier"):
    """Creates the sub-modules of the two-class classifier.

    Args:
      sampling_rate: Not used by this constructor.
      filter_size: Base kernel width; the first conv uses `filter_size + 2`.
      num_filters: Output channels of the first conv; the separable conv
        uses twice as many.
      pooling_stride: Not used by this constructor.
      pool: Not used by this constructor.
      act: Activation spec handed to `Activation`.
      name: Name of the module.
    """
    super(Classifier, self).__init__(name=name)
    num_classes = 2

    self._act = Activation(act, verbose=True)
    # NOTE(review): the factor 2 is hard-coded and `pooling_stride` / `pool`
    # are ignored. Confirm whether they were meant to configure this.
    self._pool = Downsample1D(2)
    self._bf = snt.BatchFlatten()

    # The same L2 penalty (scale 0.1) is applied to weights and biases of
    # both linear layers.
    l2_on_weights = tf.contrib.layers.l2_regularizer(scale=0.1)
    l2_on_biases = tf.contrib.layers.l2_regularizer(scale=0.1)
    regularizers = {"w": l2_on_weights, "b": l2_on_biases}

    with self._enter_variable_scope():
        self._l1_conv = snt.Conv1D(output_channels=num_filters,
                                   kernel_shape=filter_size + 2)
        self._l2_sepconv = snt.SeparableConv1D(num_filters << 1, 1,
                                               filter_size)
        self._lin1 = snt.Linear(256, regularizers=regularizers)
        self._lin2 = snt.Linear(num_classes, regularizers=regularizers)
def _build(self, x):
    """Runs two conv / batch-norm / ReLU layers and merges the result with x.

    Args:
      x: Rank-3 tensor; the original comment gives it as [units, bs, 1].

    Returns:
      `x + y` when `self.add` is truthy, otherwise `x` and `y` concatenated
      on the last axis, where `y` is the processed branch in the layout of x.
    """
    num_conv_layers = 2
    channels = x.shape.as_list()[2]

    # x is [units, bs, 1]; swap the first two axes -> [bs x units x 1].
    net = tf.transpose(x, [1, 0, 2])
    for _ in range(num_conv_layers):
        conv = snt.Conv1D(output_channels=channels, kernel_shape=[3])
        net = conv(net)
        net = snt.BatchNorm(axis=[0, 1])(net, is_training=False)
        net = tf.nn.relu(net)

    # Back to the layout of `x`.
    to_concat = tf.transpose(net, [1, 0, 2])
    if self.add:
        return x + to_concat
    return tf.concat([x, to_concat], 2)
def conv_layer(self, h, num_units, kernel_size, stride, name,
               padding=snt.SAME):
    """Applies a freshly created channels-first (`NCW`) `snt.Conv1D` to `h`.

    Args:
      h: Input tensor.
      num_units: Number of output channels.
      kernel_size: Convolution kernel width.
      stride: Convolution stride.
      name: Name of the convolution module.
      padding: Padding mode, `snt.SAME` by default.

    Returns:
      The convolution output.
    """
    conv = snt.Conv1D(output_channels=num_units,
                      kernel_shape=kernel_size,
                      stride=stride,
                      padding=padding,
                      data_format='NCW',
                      name=name)
    return conv(h)
def prep_vq(self, embedding_dim):
    """Returns the (unconnected) 1x1 conv that projects to the VQ dimension.

    Args:
      embedding_dim: Not used; the output channel count is read from
        `self.config.vqvae_embedding_dim` instead.

    Returns:
      An `snt.Conv1D` module named "linear_to_vq_dim".
    """
    # NOTE(review): `embedding_dim` is ignored in favour of the config value.
    # Confirm which one callers expect to take effect.
    return snt.Conv1D(output_channels=self.config.vqvae_embedding_dim,
                      kernel_shape=1,
                      stride=1,
                      name="linear_to_vq_dim")
def conv_layer(self, h, num_units, kernel_shape, stride, name):
    """Applies a freshly created channels-last (`NWC`) `snt.Conv1D` to `h`.

    Args:
      h: Input tensor.
      num_units: Number of output channels.
      kernel_shape: Convolution kernel width.
      stride: Accepted but not forwarded to the convolution.
      name: Name of the convolution module.

    Returns:
      The convolution output.
    """
    # NOTE(review): `stride` is deliberately left out of the constructor call
    # (it was commented out in the original), so the default stride applies.
    conv = snt.Conv1D(output_channels=num_units,
                      kernel_shape=kernel_shape,
                      data_format='NWC',
                      name=name)
    return conv(h)
def conv_block(filters, width=1, w_init=None, name='conv_block', **kwargs):
    """Returns a BatchNorm -> gelu -> Conv1D block wrapped in `Sequential`.

    Args:
      filters: Number of output channels of the convolution.
      width: Kernel width of the convolution.
      w_init: Weight initializer passed to `snt.Conv1D`.
      name: Name given to the enclosing `Sequential`.
      **kwargs: Extra keyword arguments forwarded to `snt.Conv1D`.

    Returns:
      The `Sequential` module.
    """
    def build_layers():
        # The layers are created lazily, when `Sequential` invokes this
        # factory, rather than eagerly at this point.
        batch_norm = snt.BatchNorm(create_scale=True,
                                   create_offset=True,
                                   decay_rate=0.9,
                                   scale_init=snt.initializers.Ones())
        conv = snt.Conv1D(filters, width, w_init=w_init, **kwargs)
        return [batch_norm, gelu, conv]

    return Sequential(build_layers, name=name)
def __init__(self, params, train_initial_state=True, residual=False,
             dense=False, name=None):
    """Creates the gated convolution and linear cores.

    Args:
      params: Keyword arguments for both `snt.Conv1D` cores; must provide
        `output_channels`, which also sizes the two linear cores.
      train_initial_state: Stored on the instance for later use.
      residual: One of False, 'up' or 'down'.
      dense: Whether dense connections are used; mutually exclusive with
        `residual`.
      name: Name of the module.

    Raises:
      ValueError: If both `residual` and `dense` are set, or if `residual`
        is not one of the accepted values.
    """
    super(CausalConv1D, self).__init__(name=name)
    self._params = Struct.make(params)
    self._train_initial_state = train_initial_state
    self._input_channels = None

    if residual and dense:
        raise ValueError('Cannot have residual and dense connections!')
    if residual not in {False, 'up', 'down'}:
        raise ValueError('residual should one of [False, "up", "down"]')
    self._residual = residual
    self._dense = dense

    with self._enter_variable_scope():
        out_channels = self._params.output_channels
        # Two parallel branches ("f" and "g"), each with one conv and one
        # linear core.
        conv_f = snt.Conv1D(name='conv_xf', padding=snt.VALID, **params)
        conv_g = snt.Conv1D(name='conv_xg', padding=snt.VALID, **params)
        lin_f = snt.Linear(out_channels, name='lin_zf')
        lin_g = snt.Linear(out_channels, name='lin_zg')
        self._cores = Struct(conv_f=conv_f, conv_g=conv_g,
                             lin_f=lin_f, lin_g=lin_g)
def conv_with_residual(self, h, num_units, kernel_shape, stride, name):
    """Adds a ReLU-activated `NWC` convolution of `h` back onto `h`.

    Args:
      h: Input tensor; also the skip branch.
      num_units: Number of output channels of the convolution.
      kernel_shape: Convolution kernel width.
      stride: Accepted but not forwarded to the convolution.
      name: Name of the convolution module.

    Returns:
      `h` plus the activated convolution output.
    """
    # `stride` is intentionally not passed (it was commented out in the
    # original), so the default stride applies.
    conv = snt.Conv1D(output_channels=num_units,
                      kernel_shape=kernel_shape,
                      data_format='NWC',
                      name=name)
    activated = tf.nn.relu(conv(h))
    h += activated
    return h
def func(name, data_format, custom_getter=None):
    """Returns a Conv1D followed by a training-mode BatchNorm.

    Args:
      name: Name of the convolution module.
      data_format: "NWC" or "NCW"; any value other than "NWC" is treated
        as "NCW".
      custom_getter: Optional custom getter forwarded to `snt.Conv1D`.

    Returns:
      An `snt.Sequential` of the convolution and the batch norm.
    """
    conv = snt.Conv1D(name=name,
                      output_channels=self.OUT_CHANNELS,
                      kernel_shape=self.KERNEL_SHAPE,
                      use_bias=use_bias,
                      initializers=create_initializers(use_bias),
                      data_format=data_format,
                      custom_getter=custom_getter)

    bn_kwargs = dict(scale=True, update_ops_collection=None)
    if data_format != "NWC":
        # data_format = "NCW": an explicit axis is only passed in this case.
        bn_kwargs["axis"] = (0, 2)
    batch_norm = snt.BatchNorm(**bn_kwargs)

    return snt.Sequential(
        [conv, functools.partial(batch_norm, is_training=True)])
def reduce_dimension_layer(self, inputs):
    '''
    Shrinks the time axis of `inputs` by `self._hop_length`, using the
    method selected by `self._dim_reduction`.

    Args:
        inputs (tf.Tensor): inputs of shape (batch_size, signal_length, latent_channels)

    Returns:
        tf.Tensor: reduced inputs with shape
            (batch_size, signal_length // self._hop_length, latent_channels)

    Raises:
        ValueError: if `self._dim_reduction` is not a recognised method.
    '''
    method = self._dim_reduction
    hop = self._hop_length

    if method == DIM_REDUCTION_MAX_POOL:
        return pool1d(inputs, hop, name='ae_pool', mode='max')

    if method == DIM_REDUCTION_AVG_POOL:
        return pool1d(inputs, hop, name='ae_pool', mode='avg')

    if method == DIM_REDUCTION_CONV:
        # Strided conv whose kernel and stride both equal the hop length.
        reducer = snt.Conv1D(output_channels=self._latent_channels,
                             kernel_shape=hop,
                             stride=hop,
                             padding=snt.CAUSAL,
                             name='dim_reduction_conv')
        return reducer(inputs)

    if method == DIM_REDUCTION_LINEAR:
        # Dense only acts on the last dimension, so time is moved last first.
        channels_first = tf.transpose(inputs, perm=[0, 2, 1])
        # A transpose alone is not enough because of the None-shaped
        # dimensions: reshape to (bs, channels, signal // hop_len, hop_length)
        # and apply a 1-unit linear layer to get
        # (bs, channels, signal // hop_len, 1). kudos @Csongor
        framed = tf.reshape(
            channels_first,
            [tf.shape(inputs)[0], inputs.shape[-1], -1, hop])
        projected = tf.layers.Dense(units=1)(framed)
        squeezed = tf.squeeze(projected, axis=-1)
        return tf.transpose(squeezed, perm=[0, 2, 1])

    raise ValueError(
        "Can't recognize type of dimensionality reduction method: '{}' "
        "(change network architecture -> dim_reduction: "
        "<max_pool | avg_pool | conv | linear>".format(method))
def __init__(self, out_channel, dec_channel, num_res_blocks,
             residual_hiddens, the_stride, name='decoder'):
    """Builds the decoder: conv, residual stacks, then transposed convs.

    Args:
      out_channel: Output channels of the last transposed convolution.
      dec_channel: Channel width of the first convolution and of the
        residual stacks.
      num_res_blocks: How many `residual_block` stacks are appended.
      residual_hiddens: Hidden channel count forwarded to `residual_block`.
      the_stride: 4 adds two stride-2 transposed convolutions, 2 adds one.
        Any other value adds no upsampling layer.
      name: Accepted but not forwarded to the base class.
    """
    # NOTE(review): `name` is not passed to `super().__init__()`. Confirm
    # whether that is intended.
    super(Decoder, self).__init__()

    layers = snt.Sequential([
        snt.Conv1D(output_channels=dec_channel, kernel_shape=3, stride=1,
                   name="dec_0"),
        tf.nn.relu,
    ])
    for _ in range(num_res_blocks):
        layers = snt.Sequential(
            [layers, residual_block(dec_channel, residual_hiddens)])

    if the_stride == 4:
        # Two x2 upsampling steps, halving the channels in between.
        layers = snt.Sequential([
            layers,
            snt.Conv1DTranspose(output_channels=dec_channel // 2,
                                output_shape=None,
                                kernel_shape=4,
                                stride=2,
                                name="dec_1"),
            tf.nn.relu,
            snt.Conv1DTranspose(output_channels=out_channel,
                                output_shape=None,
                                kernel_shape=4,
                                stride=2,
                                name="dec_2"),
        ])
    elif the_stride == 2:
        layers = snt.Sequential([
            layers,
            snt.Conv1DTranspose(output_channels=out_channel,
                                output_shape=None,
                                kernel_shape=4,
                                stride=2,
                                name="dec_1"),
        ])

    self.model = layers
def _build(self, padded_word_embeddings, length):
    """Encodes padded word embeddings with convs, pooling and optional FCs.

    Args:
      padded_word_embeddings: Input tensor fed to the first layer.
      length: Per-example lengths, used as the divisor for 'average'
        pooling.

    Returns:
      The tensor produced by the last configured layer.

    Raises:
      RuntimeError: If `conv_architecture` contains an entry that is neither
        a tuple/list nor the string 'relu'.
    """
    keep_prob = self._keep_prob

    def maybe_dropout(tensor):
        # Dropout is only inserted when a keep probability below 1 is set.
        if keep_prob < 1:
            return tf.nn.dropout(tensor, keep_prob=keep_prob)
        return tensor

    x = padded_word_embeddings
    for layer in self._config['conv_architecture']:
        if isinstance(layer, (tuple, list)):
            filters, kernel_size, pooling_size = layer
            x = snt.Conv1D(output_channels=filters,
                           kernel_shape=kernel_size)(x)
            if pooling_size and pooling_size > 1:
                x = _max_pool_1d(x, pooling_size)
        elif layer == 'relu':
            x = maybe_dropout(tf.nn.relu(x))
        else:
            raise RuntimeError('Bad layer type {} in conv'.format(layer))

    # Final layer pools over the remaining sequence length to get a
    # fixed sized vector.
    if self._pooling == 'max':
        x = tf.reduce_max(x, axis=1)
    elif self._pooling == 'average':
        summed = tf.reduce_sum(x, axis=1)
        lengths = tf.expand_dims(tf.cast(length, tf.float32), axis=1)
        x = summed / lengths

    # Optional fully connected layers, each followed by ReLU and dropout.
    for fc_key in ('conv_fc1', 'conv_fc2'):
        fc_size = self._config[fc_key]
        if fc_size:
            fc_layer = snt.Linear(output_size=fc_size)
            x = maybe_dropout(tf.nn.relu(fc_layer(x)))

    return x
def __init__(self,
             output_channels: int,
             kernel_shape: int,
             stride: int = 1,
             name: str = 'conv_1d_periodic'):
    """Constructs the Conv1dPeriodic module.

    Args:
      output_channels: Number of channels in convolution.
      kernel_shape: Convolution kernel sizes.
      stride: Convolution stride.
      name: Name of the module.
    """
    super(Conv1dPeriodic, self).__init__(name=name)

    self._output_channels = output_channels
    self._kernel_shape = kernel_shape
    self._stride = stride

    conv_kwargs = dict(output_channels=self._output_channels,
                       kernel_shape=self._kernel_shape,
                       stride=self._stride,
                       padding=snt.VALID)
    with self._enter_variable_scope():
        # The wrapped convolution uses VALID padding.
        self._conv_1d_module = snt.Conv1D(**conv_kwargs)
def _build(self, x_shifted, en, conditioning=None):
    """Builds the causal, dilated decoder conditioned on the encoding.

    Args:
      x_shifted: Decoder input tensor; per the name it is presumably the
        signal shifted one step to the right (confirm at the call site).
      en: Encoding that is mapped by 1x1 convolutions and merged into every
        layer through `condition`.
      conditioning: Not used by this method.

    Returns:
      A tuple `(dec_distr, reconstruction)`: the `tfd.Categorical` over the
      256 output logits, and `inv_mu_law` applied to its mode shifted by
      -128.
    """
    self._nodes_list = []

    startconv = snt.Conv1D(output_channels=self._d_hidden_channels,
                           kernel_shape=self._filter_length,
                           padding=snt.CAUSAL,
                           name='startconv')
    skip_start = snt.Conv1D(output_channels=self._skip_channels,
                            kernel_shape=1,
                            padding=snt.CAUSAL,
                            name='skip_start')

    # Bug fix: the original computed the dropped-out input and then
    # overwrote it with `startconv(x_shifted)`, so input dropout never
    # reached the network. The convolution now consumes the dropped-out
    # tensor.
    decoder_input = x_shifted
    if self._prob_dropout_decoder_tf > 0:
        decoder_input = tf.nn.dropout(x_shifted,
                                      rate=self._prob_dropout_decoder_tf)
    l = startconv(decoder_input)
    self._nodes_list.append(l)

    # Set up skip connections.
    s = skip_start(l)
    self._nodes_list.append(s)

    # Residual blocks with skip connections.
    for i in range(self._num_layers):
        dilation = 2**(i % self._num_layers_per_stage)

        causal_convolution = snt.Conv1D(
            output_channels=2 * self._d_hidden_channels,
            kernel_shape=[self._filter_length],
            rate=dilation,
            padding=snt.CAUSAL,
            name='dilatedconv_%d' % (i + 1))
        dil = causal_convolution(l)
        self._nodes_list.append(dil)

        encoding_convolution = snt.Conv1D(
            output_channels=2 * self._d_hidden_channels,
            kernel_shape=1,
            padding=snt.CAUSAL,
            name='cond_map_%d' % (i + 1))
        enc = encoding_convolution(en)
        dil = condition(dil, enc)
        self._nodes_list.append(dil)

        # Gated activation: the first half of the channels feeds the
        # sigmoid gate, the second half the tanh.
        assert dil.get_shape().as_list()[2] % 2 == 0
        m = dil.get_shape().as_list()[2] // 2
        d_sigmoid = tf.sigmoid(dil[:, :, :m])
        d_tanh = tf.tanh(dil[:, :, m:])
        dil = d_sigmoid * d_tanh

        res_convolution = snt.Conv1D(
            output_channels=self._d_hidden_channels,
            kernel_shape=1,
            padding=snt.CAUSAL,
            name='res_%d' % (i + 1))
        l += res_convolution(dil)
        self._nodes_list.append(l)

        skip_conv = snt.Conv1D(output_channels=self._skip_channels,
                               kernel_shape=1,
                               padding=snt.CAUSAL,
                               name='skip_%d' % (i + 1))
        s += skip_conv(dil)
        self._nodes_list.append(s)

    s = tf.nn.relu(s)
    out = snt.Conv1D(output_channels=self._skip_channels,
                     kernel_shape=1,
                     padding=snt.CAUSAL,
                     name='out1')
    s = out(s)
    self._nodes_list.append(s)

    cond_map_out = snt.Conv1D(output_channels=self._skip_channels,
                              kernel_shape=1,
                              padding=snt.CAUSAL,
                              name='cond_map_out1')
    s = condition(s, cond_map_out(en))
    s = tf.nn.relu(s)
    self._nodes_list.append(s)

    logits_conv = snt.Conv1D(output_channels=256,
                             kernel_shape=1,
                             padding=snt.CAUSAL,
                             name='logits')
    self.logits = logits_conv(s)
    self.dec_distr = tfd.Categorical(logits=self.logits)

    # dimension of rec_sample should be: [bs, length, 1]
    self.reconstruction = inv_mu_law(
        tf.expand_dims(self.dec_distr.mode(), axis=-1) - 128)

    return self.dec_distr, self.reconstruction
def __init__(self,
             channels: int = 1536,
             num_transformer_layers: int = 11,
             num_heads: int = 8,
             pooling_type: str = 'attention',
             name: str = 'enformer'):
    """Enformer model.

    Builds the trunk (stem, conv tower, transformer, crop, final pointwise
    block) under the 'trunk' name scope and one output head per entry of
    `heads_channels` under the 'heads' name scope.

    Args:
      channels: Number of convolutional filters and the overall 'width' of
        the model. Must be divisible by `num_heads`.
      num_transformer_layers: Number of transformer layers.
      num_heads: Number of attention heads.
      pooling_type: Which pooling function to use. Options: 'attention' or
        'max'.
      name: Name of sonnet module.
    """
    super().__init__(name=name)
    # pylint: disable=g-complex-comprehension,g-long-lambda,cell-var-from-loop
    # Head name -> output size of that head's final Linear layer.
    heads_channels = {'human': 5313, 'mouse': 1643}
    dropout_rate = 0.4
    assert channels % num_heads == 0, ('channels needs to be divisible '
                                       f'by {num_heads}')
    # Keyword arguments shared by every MultiheadAttention layer below.
    whole_attention_kwargs = {
        'attention_dropout_rate': 0.05,
        'initializer': None,
        'key_size': 64,
        'num_heads': num_heads,
        'num_relative_position_features': channels // num_heads,
        'positional_dropout_rate': 0.01,
        'relative_position_functions': [
            'positional_features_exponential',
            'positional_features_central_mask',
            'positional_features_gamma'
        ],
        'relative_positions': True,
        'scaling': True,
        'value_size': channels // num_heads,
        'zero_initialize': True
    }

    # The scope is entered manually and closed by the matching `__exit__`
    # call after `self._trunk` has been assembled.
    trunk_name_scope = tf.name_scope('trunk')
    trunk_name_scope.__enter__()
    # lambda is used in Sequential to construct the module under tf.name_scope.
    def conv_block(filters, width=1, w_init=None, name='conv_block', **kwargs):
        return Sequential(lambda: [
            snt.BatchNorm(create_scale=True,
                          create_offset=True,
                          decay_rate=0.9,
                          scale_init=snt.initializers.Ones()),
            gelu,
            snt.Conv1D(filters, width, w_init=w_init, **kwargs)
        ], name=name)

    # Stem: width-15 conv, residual pointwise block, pooling with size 2.
    stem = Sequential(lambda: [
        snt.Conv1D(channels // 2, 15),
        Residual(conv_block(channels // 2, 1, name='pointwise_conv_block')
                 ),
        pooling_module(pooling_type, pool_size=2),
    ], name='stem')

    # Filter counts for the 6 tower blocks, from channels // 2 up to channels.
    filter_list = exponential_linspace_int(start=channels // 2, end=channels,
                                           num=6, divisible_by=128)
    # NOTE(review): the inner lambdas close over `i` / `num_filters`. This is
    # only safe if Sequential invokes them during construction; confirm
    # against the Sequential implementation.
    conv_tower = Sequential(lambda: [
        Sequential(lambda: [
            conv_block(num_filters, 5),
            Residual(
                conv_block(num_filters, 1, name='pointwise_conv_block')),
            pooling_module(pooling_type, pool_size=2),
        ], name=f'conv_tower_block_{i}')
        for i, num_filters in enumerate(filter_list)
    ], name='conv_tower')

    # Transformer.
    def transformer_mlp():
        return Sequential(lambda: [
            snt.LayerNorm(axis=-1, create_scale=True, create_offset=True),
            snt.Linear(channels * 2),
            snt.Dropout(dropout_rate),
            tf.nn.relu,
            snt.Linear(channels),
            snt.Dropout(dropout_rate)
        ], name='mlp')

    # Each transformer block is a residual attention sub-block followed by a
    # residual MLP sub-block.
    transformer = Sequential(lambda: [
        Sequential(lambda: [
            Residual(
                Sequential(lambda: [
                    snt.LayerNorm(axis=-1,
                                  create_scale=True,
                                  create_offset=True,
                                  scale_init=snt.initializers.Ones()),
                    attention_module.MultiheadAttention(
                        **whole_attention_kwargs, name=f'attention_{i}'),
                    snt.Dropout(dropout_rate)
                ], name='mha')),
            Residual(transformer_mlp())
        ], name=f'transformer_block_{i}')
        for i in range(num_transformer_layers)
    ], name='transformer')

    crop_final = TargetLengthCrop1D(TARGET_LENGTH, name='target_input')

    final_pointwise = Sequential(
        lambda: [conv_block(channels * 2, 1),
                 snt.Dropout(dropout_rate / 8),
                 gelu],
        name='final_pointwise')

    self._trunk = Sequential(
        [stem, conv_tower, transformer, crop_final, final_pointwise],
        name='trunk')
    trunk_name_scope.__exit__(None, None, None)

    # One Linear + softplus head per entry in `heads_channels`.
    with tf.name_scope('heads'):
        self._heads = {
            head: Sequential(lambda:
                             [snt.Linear(num_channels), tf.nn.softplus],
                             name=f'head_{head}')
            for head, num_channels in heads_channels.items()
        }
def compute_h(self, x, z, d, bias, W_bot, W_top, compute_perc=1.0,
              compute_units=None):
    """Computes per-unit hidden features from activations and deltas.

    Shapes, as given by the original docstring:
      x, z = [BS, n_units]
      d = [BS, n_units, delta_channels]

    Args:
      x: Tensor; its static shape supplies the batch size and unit count.
      z: Tensor transposed and stacked alongside `x`.
      d: Delta tensor, concatenated to the features on the last axis.
      bias: Bias tensor; transposed and tiled across the batch.
      W_bot: Weight matrix below this layer, passed to `get_weight_stats`.
      W_top: Weight matrix above this layer, passed to `get_weight_stats`.
      compute_perc: Fraction of units to process. When not 1.0,
        `compute_units` must be None and is derived from this fraction.
      compute_units: Number of randomly chosen units to process, or None to
        process all of them.

    Returns:
      Tensor laid out as [bs, units, channels]. Rows of units that were not
      selected are zero-filled by `tf.scatter_nd`.
    """
    if compute_perc != 1.0:
        assert compute_units is None

    with tf.device(self.remote_device):
        inp_feat = [x, z]
        inp_feat = [tf.transpose(f, [1, 0]) for f in inp_feat]

        units = x.shape.as_list()[1]
        bs = x.shape.as_list()[0]

        # add unit ID, to help the network differentiate units
        id_theta = tf.linspace(0., (4) * np.pi, units)
        assert bs is not None
        id_theta_bs = tf.reshape(id_theta, [-1, 1]) * tf.ones([1, bs])
        inp_feat += [tf.sin(id_theta_bs), tf.cos(id_theta_bs)]

        # list of [units, BS, 1]
        inp_feat = [tf.expand_dims(f, 2) for f in inp_feat]

        d_trans = tf.transpose(d, [1, 0, 2])

        if compute_perc != 1.0:
            # Bug fix: `inp_feat` is a Python list, so the unit count must be
            # read from one of its tensors. The original called `.shape` on
            # the list itself and crashed.
            compute_units = int(
                compute_perc * inp_feat[0].shape.as_list()[0])

        # add weight matrix statistics, both from above and below
        w_stats_bot = get_weight_stats(W_bot, 0)
        w_stats_top = get_weight_stats(W_top, 1)
        w_stats = w_stats_bot + w_stats_top
        if W_bot is None or W_top is None:
            # if it's an edge layer (top or bottom), just duplicate the stats
            # for the weight matrix that does exist
            w_stats = w_stats + w_stats
        w_stats = [tf.ones([1, x.shape[0], 1]) * ww for ww in w_stats]
        # w_stats is a list, with entries with shape UNITS x 1 x channels

        if compute_units is None:
            inp_feat_in = inp_feat
            d_trans_in = d_trans
            w_stats_in = w_stats
            bias_in = tf.transpose(bias)
        else:
            # only run on a subset of the activations.
            mask = tf.random_uniform(
                minval=0,
                maxval=1,
                dtype=tf.float32,
                shape=inp_feat[0].shape.as_list()[0:1])
            _, ind = tf.nn.top_k(mask, k=compute_units)
            ind = tf.reshape(ind, [-1, 1])

            inp_feat_in = [tf.gather_nd(xx, ind) for xx in inp_feat]
            w_stats_in = [tf.gather_nd(xx, ind) for xx in w_stats]
            d_trans_in = tf.gather_nd(d_trans, ind)
            bias_in = tf.gather_nd(tf.transpose(bias), ind)

        w_stats_in = tf.concat(w_stats_in, 2)
        w_stats_in_norm = w_stats_in * tf.rsqrt(
            tf.reduce_mean(w_stats_in**2) + 1e-6)

        act = tf.concat(inp_feat_in + [d_trans_in], 2)
        act = snt.BatchNorm(axis=[0, 1])(act, is_training=True)

        bias_dense = tf.reshape(bias_in, [-1, 1, 1]) * tf.ones([1, bs, 1])
        act = tf.concat([w_stats_in_norm, bias_dense, act], 2)

        mod = snt.Conv1D(output_channels=self.compute_h_size,
                         kernel_shape=[3])
        act = mod(act)
        act = snt.BatchNorm(axis=[0, 1])(act, is_training=True)
        act = tf.nn.relu(act)
        act = ConcatUnitConv()(act)

        for _ in range(self.compute_h_layers):
            mod = snt.Conv1D(output_channels=self.compute_h_size,
                             kernel_shape=[3])
            act = mod(act)
            act = snt.BatchNorm(axis=[0, 1])(act, is_training=True)
            act = tf.nn.relu(act)
            act = ConcatUnitConv()(act)

        h = act
        if compute_units is not None:
            # Scatter the processed subset back to the full unit axis.
            shape = inp_feat[0].shape.as_list()[:1] + h.shape.as_list()[1:]
            h = tf.scatter_nd(ind, h, shape=shape)

        h = tf.transpose(h, [1, 0, 2])  # [bs, units, channels]
        return h