def multi_conv_res(x, padding, name, layers, hparams, mask=None, source=None):
    """A stack of separable convolution blocks with residual connections.

    Args:
        x: input tensor; it is added to the outputs of the conv blocks, so it
            must be shape-compatible with them.
        padding: padding mode forwarded to every conv block. When it equals
            "LEFT", `mask` is dropped and no masking is applied here.
        name: variable scope name wrapping the whole stack.
        layers: number of residual layers to build.
        hparams: hyperparameter object. Reads kernel_scheme, dilation_scheme,
            kernel_height, kernel_width, large_kernel_size, separability,
            norm_type, hidden_size, norm_epsilon, attention_type and dropout.
        mask: optional tensor. Used to derive the attention padding bias and,
            unless padding is "LEFT", passed to the conv blocks and multiplied
            into the running activation.
        source: optional tensor to attend to; attention is only added when it
            is given and hparams.attention_type is not "none".

    Returns:
        The output of tf.nn.dropout called on the final activation with
        1.0 - hparams.dropout as its second argument.
    """
    with tf.variable_scope(name):
        padding_bias = None
        if mask is not None:
            padding_bias = (1.0 - mask) * -1e9  # Bias to not attend to padding.
            if padding == "LEFT":  # Do not mask anything when left-padding.
                mask = None

        # Use the named kernel/dilation schemes when both are known; otherwise
        # fall back to the explicit kernel sizes from hparams.
        if (hparams.kernel_scheme in _KERNEL_SCHEMES and
                hparams.dilation_scheme in _DILATION_SCHEMES):
            kernels = _KERNEL_SCHEMES[hparams.kernel_scheme]
            dilations = _DILATION_SCHEMES[hparams.dilation_scheme]
            dilations_and_kernels = list(zip(dilations, kernels))
            dilations_and_kernels1 = dilations_and_kernels[:2]
            dilations_and_kernels2 = dilations_and_kernels[2:]
        else:
            k = (hparams.kernel_height, hparams.kernel_width)
            k2 = (hparams.large_kernel_size, 1)
            dilations_and_kernels1 = [((1, 1), k), ((1, 1), k)]
            dilations_and_kernels2 = [((1, 1), k2), ((4, 4), k2)]

        separabilities1 = [hparams.separability, hparams.separability]
        separabilities2 = [hparams.separability] * len(dilations_and_kernels2)
        if hparams.separability < 0:
            # Negative separability: use a different (decreasing-offset) value
            # for each convolution in the block.
            separabilities1 = [hparams.separability - 1, hparams.separability]
            separabilities2 = [
                hparams.separability - i
                for i in reversed(range(len(dilations_and_kernels2)))
            ]

        def norm_fn(x, name):
            with tf.variable_scope(name, default_name="norm"):
                return common_layers.apply_norm(
                    x, hparams.norm_type, hparams.hidden_size,
                    hparams.norm_epsilon)

        # `range` instead of the Python-2-only `xrange`; identical iteration.
        for layer in range(layers):
            with tf.variable_scope("layer_%d" % layer):
                y = common_layers.subseparable_conv_block(
                    x,
                    hparams.hidden_size,
                    dilations_and_kernels1,
                    normalizer_fn=norm_fn,
                    padding=padding,
                    mask=mask,
                    separabilities=separabilities1,
                    name="residual1")
                x += common_layers.subseparable_conv_block(
                    x + y,
                    hparams.hidden_size,
                    dilations_and_kernels2,
                    normalizer_fn=norm_fn,
                    padding=padding,
                    mask=mask,
                    separabilities=separabilities2,
                    name="residual2") + y
                if source is not None and hparams.attention_type != "none":
                    x += attention(
                        x, source, norm_fn, hparams, bias=padding_bias)
                # NOTE(review): reconstructed as a per-layer mask from a
                # whitespace-collapsed source -- confirm the nesting level.
                if mask is not None:
                    x *= mask
        return tf.nn.dropout(x, 1.0 - hparams.dropout)
def multi_conv_res(x, padding, name, layers, hparams, mask=None, source=None):
    """Build `layers` residual layers of sub-separable convolution blocks.

    Args:
        x: input tensor; conv-block outputs are added onto it.
        padding: padding mode handed to every conv block. The value "LEFT"
            disables masking inside this function.
        name: variable scope name for the whole stack.
        layers: how many residual layers to create.
        hparams: hyperparameter object. Reads kernel_scheme, dilation_scheme,
            kernel_height, kernel_width, large_kernel_size, separability,
            norm_type, model_d, norm_epsilon, attention_type and dropout.
        mask: optional tensor used for the attention padding bias and, unless
            padding is "LEFT", for masking the activations.
        source: optional tensor to attend to (skipped when None or when
            hparams.attention_type is "none").

    Returns:
        The result of tf.nn.dropout on the final activation, called with
        1.0 - hparams.dropout as its second argument.
    """
    with tf.variable_scope(name):
        # Additive bias that pushes attention away from padded positions.
        attn_bias = None if mask is None else (1.0 - mask) * -1e9
        if padding == "LEFT":
            # Left-padded convolutions run without any masking.
            mask = None

        schemes_known = (hparams.kernel_scheme in _KERNEL_SCHEMES and
                         hparams.dilation_scheme in _DILATION_SCHEMES)
        if schemes_known:
            scheme_kernels = _KERNEL_SCHEMES[hparams.kernel_scheme]
            scheme_dilations = _DILATION_SCHEMES[hparams.dilation_scheme]
            paired = list(zip(scheme_dilations, scheme_kernels))
            first_specs, second_specs = paired[:2], paired[2:]
        else:
            small_kernel = (hparams.kernel_height, hparams.kernel_width)
            large_kernel = (hparams.large_kernel_size, 1)
            first_specs = [((1, 1), small_kernel), ((1, 1), small_kernel)]
            second_specs = [((1, 1), large_kernel), ((4, 4), large_kernel)]

        second_count = len(second_specs)
        if hparams.separability < 0:
            # Negative separability: a distinct value per convolution.
            first_seps = [hparams.separability - 1, hparams.separability]
            second_seps = [
                hparams.separability - offset
                for offset in range(second_count - 1, -1, -1)
            ]
        else:
            first_seps = [hparams.separability, hparams.separability]
            second_seps = [hparams.separability] * second_count

        def norm_fn(x, name):
            # Parameter names are kept: this callable is handed to other
            # functions as a normalizer.
            with tf.variable_scope(name, default_name="norm"):
                return common_layers.apply_norm(
                    x, hparams.norm_type, hparams.model_d,
                    hparams.norm_epsilon)

        for layer_idx in range(layers):
            with tf.variable_scope("layer_%d" % layer_idx):
                inner = common_layers.subseparable_conv_block(
                    x,
                    hparams.model_d,
                    first_specs,
                    normalizer_fn=norm_fn,
                    padding=padding,
                    mask=mask,
                    separabilities=first_seps,
                    name="residual1")
                outer = common_layers.subseparable_conv_block(
                    x + inner,
                    hparams.model_d,
                    second_specs,
                    normalizer_fn=norm_fn,
                    padding=padding,
                    mask=mask,
                    separabilities=second_seps,
                    name="residual2")
                x += outer + inner
                wants_attention = (source is not None and
                                   hparams.attention_type != "none")
                if wants_attention:
                    x += attention(x, source, norm_fn, hparams, bias=attn_bias)
                # NOTE(review): reconstructed as a per-layer mask from a
                # whitespace-collapsed source -- confirm the nesting level.
                if mask is not None:
                    x *= mask
        return tf.nn.dropout(x, 1.0 - hparams.dropout)
def testSubSeparableConvBlock(self):
    """Check subseparable_conv_block's output shape for several separabilities.

    For each separability in (0, 1, 2, 4) a random (5, 7, 1, 12) input is fed
    through a two-convolution block with 16 filters and "SAME" padding; the
    evaluated result must have shape (5, 7, 1, 16).
    """
    expected_shape = (5, 7, 1, 16)
    conv_specs = [(1, (3, 3)), (1, (3, 3))]
    for sep in (0, 1, 2, 4):
        inputs = np.random.rand(5, 7, 1, 12)
        # A separate scope per separability keeps the variables apart.
        with tf.variable_scope("sep_%d" % sep):
            outputs = common_layers.subseparable_conv_block(
                tf.constant(inputs, dtype=tf.float32),
                16,
                conv_specs,
                padding="SAME",
                separability=sep)
        self.evaluate(tf.global_variables_initializer())
        result = self.evaluate(outputs)
        self.assertEqual(result.shape, expected_shape)
def residual_block(x, hparams):
    """A stack of convolution blocks with residual connection.

    Args:
        x: input tensor; it is added to the conv-block output, so the two must
            be shape-compatible.
        hparams: hyperparameter object. Reads kernel_height, kernel_width,
            hidden_size and dropout.

    Returns:
        The layer-normalized sum x + conv_block(x), passed through
        tf.nn.dropout with 1.0 - hparams.dropout as its second argument.
    """
    k = (hparams.kernel_height, hparams.kernel_width)
    # Three identical undilated convolutions. `range` replaces the
    # Python-2-only `xrange`; the resulting list is the same.
    dilations_and_kernels = [((1, 1), k) for _ in range(3)]
    y = common_layers.subseparable_conv_block(
        x,
        hparams.hidden_size,
        dilations_and_kernels,
        padding="SAME",
        separability=0,
        name="residual_block")
    x = common_layers.layer_norm(x + y, hparams.hidden_size, name="lnorm")
    return tf.nn.dropout(x, 1.0 - hparams.dropout)
def testSubSeparableConvBlock(self):
    """Check subseparable_conv_block's output shape for several separabilities.

    For each separability in (0, 1, 2, 4) a random (5, 7, 1, 12) input is fed
    through a two-convolution block with 16 filters and "SAME" padding; the
    evaluated result must have shape (5, 7, 1, 16).
    """
    expected_shape = (5, 7, 1, 16)
    conv_specs = [(1, (3, 3)), (1, (3, 3))]
    for sep in (0, 1, 2, 4):
        inputs = np.random.rand(5, 7, 1, 12)
        # A separate scope per separability keeps the variables apart.
        with tf.variable_scope("sep_%d" % sep):
            outputs = common_layers.subseparable_conv_block(
                tf.constant(inputs, dtype=tf.float32),
                16,
                conv_specs,
                padding="SAME",
                separability=sep)
        self.evaluate(tf.global_variables_initializer())
        result = self.evaluate(outputs)
        self.assertEqual(result.shape, expected_shape)
def residual_block(x, hparams):
    """A stack of convolution blocks with residual connection.

    Args:
        x: input tensor; it is added to the conv-block output, so the two must
            be shape-compatible.
        hparams: hyperparameter object. Reads kernel_height, kernel_width,
            hidden_size and dropout.

    Returns:
        The layer-normalized sum x + conv_block(x), passed through
        tf.nn.dropout with 1.0 - hparams.dropout as its second argument.
    """
    k = (hparams.kernel_height, hparams.kernel_width)
    # Three identical undilated convolutions. `range` replaces the
    # Python-2-only `xrange`; the resulting list is the same.
    dilations_and_kernels = [((1, 1), k) for _ in range(3)]
    y = common_layers.subseparable_conv_block(
        x,
        hparams.hidden_size,
        dilations_and_kernels,
        padding="SAME",
        separability=0,
        name="residual_block")
    x = common_layers.layer_norm(x + y, hparams.hidden_size, name="lnorm")
    return tf.nn.dropout(x, 1.0 - hparams.dropout)
def conv_res_step(x, hparams, padding, mask):
    """Run two convolution blocks with dropout applied between them.

    Args:
        x: input tensor for the first conv block.
        hparams: hyperparameter object. Reads kernel_height, kernel_width,
            large_kernel_size, filter_size, hidden_size and dropout.
        padding: padding mode forwarded to both conv blocks.
        mask: mask forwarded to both conv blocks.

    Returns:
        The output of the second conv block (hparams.hidden_size filters).
    """
    small_kernel = (hparams.kernel_height, hparams.kernel_width)
    large_kernel = (hparams.large_kernel_size, 1)
    first_specs = [((1, 1), small_kernel), ((1, 1), small_kernel)]
    second_specs = [((1, 1), large_kernel), ((4, 4), large_kernel)]
    with tf.variable_scope("conv_res_step"):
        # First block uses filter_size filters; the second uses hidden_size.
        hidden = common_layers.subseparable_conv_block(
            x,
            hparams.filter_size,
            first_specs,
            padding=padding,
            mask=mask,
            separabilities=0,
            name="residual1")
        hidden = tf.nn.dropout(hidden, 1.0 - hparams.dropout)
        output = common_layers.subseparable_conv_block(
            hidden,
            hparams.hidden_size,
            second_specs,
            padding=padding,
            mask=mask,
            separabilities=0,
            name="residual2")
        return output
def slicenet_middle(inputs_encoded, targets, target_space_emb, mask, hparams):
    """Middle part of slicenet, joining the encoder output and the decoder.

    Flattens and right-shifts the targets, optionally lets them attend to
    `inputs_encoded`, and merges attended and shifted targets with a single
    left-padded convolution block. A similarity loss between encoded inputs
    and encoded targets is also computed when enabled.

    Args:
        inputs_encoded: encoder output to attend to.
        targets: target tensor; flattened with common_layers.flatten4d3d.
        target_space_emb: target-space embedding, tiled over the batch and
            used as the pad value when shifting targets right.
        mask: input mask; only read when hparams.attention_type is not "none".
        hparams: hyperparameter object. Reads norm_type, hidden_size,
            norm_epsilon, problems, sim_loss_mult, num_hidden_layers,
            attention_type, kernel_height and kernel_width.

    Returns:
        A pair (targets_merged, similarity_loss); similarity_loss is 0.0 when
        the similarity loss is not computed.
    """

    def norm_fn(x, name):
        with tf.variable_scope(name, default_name="norm"):
            return common_layers.apply_norm(
                x, hparams.norm_type, hparams.hidden_size,
                hparams.norm_epsilon)

    # Flatten targets and tile the target-space embedding over the batch.
    targets_flat = tf.expand_dims(common_layers.flatten4d3d(targets), axis=2)
    batch_size = tf.shape(targets_flat)[0]
    target_space_emb = tf.tile(target_space_emb, [batch_size, 1, 1, 1])

    # Similarity loss is only built for multi-problem runs with a
    # non-negligible multiplier.
    sim_loss_enabled = (len(hparams.problems) > 1 and
                        hparams.sim_loss_mult > 0.00001)
    similarity_loss = 0.0
    if sim_loss_enabled:
        timed_targets = common_layers.add_timing_signal(targets_flat)
        num_sim_layers = int(hparams.num_hidden_layers * 1.5)
        # Re-enter the current scope with reuse=True for the "encoder" stack.
        with tf.variable_scope(tf.get_variable_scope(), reuse=True):
            targets_encoded = multi_conv_res(
                timed_targets, "SAME", "encoder", num_sim_layers, hparams)
        with tf.variable_scope("similarity_loss"):
            similarity_loss = similarity_cost(inputs_encoded, targets_encoded)
            similarity_loss *= hparams.sim_loss_mult

    # Each target position attends to the encoded inputs (unless disabled).
    targets_shifted = common_layers.shift_right(
        targets_flat, pad_value=target_space_emb)
    if hparams.attention_type != "none":
        # Bias to not attend to padding.
        pad_bias = (1.0 - mask) * -1e9
        attended = attention(
            targets_shifted, inputs_encoded, norm_fn, hparams, bias=pad_bias)
    else:
        attended = tf.zeros_like(targets_shifted)

    # Positional targets: merge attention and raw.
    merge_kernel = (hparams.kernel_height, hparams.kernel_width)
    merge_input = tf.concat([attended, targets_shifted], axis=3)
    targets_merged = common_layers.subseparable_conv_block(
        merge_input,
        hparams.hidden_size,
        [((1, 1), merge_kernel)],
        normalizer_fn=norm_fn,
        padding="LEFT",
        separability=4,
        name="targets_merge")
    return targets_merged, similarity_loss
def slicenet_middle(inputs_encoded, targets, target_space_emb, mask, hparams):
    """Middle part of slicenet, joining the encoder output and the decoder.

    Flattens and right-shifts the targets, optionally lets them attend to
    `inputs_encoded`, and merges attended and shifted targets with a single
    left-padded convolution block. A similarity loss between encoded inputs
    and encoded targets is also computed when enabled.

    Args:
        inputs_encoded: encoder output to attend to.
        targets: target tensor; flattened with common_layers.flatten4d3d.
        target_space_emb: target-space embedding, tiled over the batch and
            used as the pad value when shifting targets right.
        mask: input mask; only read when hparams.attention_type is not "none".
        hparams: hyperparameter object. Reads norm_type, hidden_size,
            norm_epsilon, problems, sim_loss_mult, num_hidden_layers,
            attention_type, kernel_height and kernel_width.

    Returns:
        A pair (targets_merged, similarity_loss); similarity_loss is 0.0 when
        the similarity loss is not computed.
    """

    def norm_fn(x, name):
        with tf.variable_scope(name, default_name="norm"):
            return common_layers.apply_norm(
                x, hparams.norm_type, hparams.hidden_size,
                hparams.norm_epsilon)

    # Flatten targets and tile the target-space embedding over the batch.
    targets_flat = tf.expand_dims(common_layers.flatten4d3d(targets), axis=2)
    batch_size = tf.shape(targets_flat)[0]
    target_space_emb = tf.tile(target_space_emb, [batch_size, 1, 1, 1])

    # Similarity loss is only built for multi-problem runs with a
    # non-negligible multiplier.
    sim_loss_enabled = (len(hparams.problems) > 1 and
                        hparams.sim_loss_mult > 0.00001)
    similarity_loss = 0.0
    if sim_loss_enabled:
        timed_targets = common_layers.add_timing_signal(targets_flat)
        num_sim_layers = int(hparams.num_hidden_layers * 1.5)
        # Re-enter the current scope with reuse=True for the "encoder" stack.
        with tf.variable_scope(tf.get_variable_scope(), reuse=True):
            targets_encoded = multi_conv_res(
                timed_targets, "SAME", "encoder", num_sim_layers, hparams)
        with tf.variable_scope("similarity_loss"):
            similarity_loss = similarity_cost(inputs_encoded, targets_encoded)
            similarity_loss *= hparams.sim_loss_mult

    # Each target position attends to the encoded inputs (unless disabled).
    targets_shifted = common_layers.shift_right(
        targets_flat, pad_value=target_space_emb)
    if hparams.attention_type != "none":
        # Bias to not attend to padding.
        pad_bias = (1.0 - mask) * -1e9
        attended = attention(
            targets_shifted, inputs_encoded, norm_fn, hparams, bias=pad_bias)
    else:
        attended = tf.zeros_like(targets_shifted)

    # Positional targets: merge attention and raw.
    merge_kernel = (hparams.kernel_height, hparams.kernel_width)
    merge_input = tf.concat([attended, targets_shifted], axis=3)
    targets_merged = common_layers.subseparable_conv_block(
        merge_input,
        hparams.hidden_size,
        [((1, 1), merge_kernel)],
        normalizer_fn=norm_fn,
        padding="LEFT",
        separability=4,
        name="targets_merge")
    return targets_merged, similarity_loss
def attention(targets_shifted, inputs_encoded, norm_fn, hparams, bias=None):
    """Attention layer with a convolutional preprocessing step.

    The targets get a timing signal and pass through a two-convolution,
    left-padded block; the result is used for self-attention followed by
    encoder-decoder attention over `inputs_encoded`.

    Args:
        targets_shifted: target tensor to attend from.
        inputs_encoded: encoder output; flattened with
            common_layers.flatten4d3d before being attended to.
        norm_fn: normalizer passed to the preprocessing conv block.
        hparams: hyperparameter object. Reads separability, model_d,
            attention_type, num_heads and attention_dropout.
        bias: accepted but not referenced by this implementation.

    Returns:
        The encoder-decoder attention output with a size-1 axis inserted at
        position 2.

    Raises:
        ValueError: if hparams.attention_type is not "transformer".
    """
    sep = hparams.separability
    # Negative separability: use a different value for each convolution.
    separabilities = [sep - 1, sep] if sep < 0 else [sep, sep]

    timed_input = common_layers.add_timing_signal(targets_shifted)
    targets_timed = common_layers.subseparable_conv_block(
        timed_input,
        hparams.model_d,
        [((1, 1), (5, 1)), ((4, 1), (5, 1))],
        normalizer_fn=norm_fn,
        padding="LEFT",
        separabilities=separabilities,
        name="targets_time")

    # The check comes after the conv block, matching the original ordering.
    if hparams.attention_type != "transformer":
        raise ValueError(
            "Unsupported attention_type: %s" % hparams.attention_type)

    queries = tf.squeeze(targets_timed, 2)
    query_shape = tf.shape(queries)
    targets_segment = tf.zeros([query_shape[0], query_shape[1]])
    causal_bias = common_attention.attention_bias_lower_triangle(
        query_shape[1])
    memory = common_layers.flatten4d3d(inputs_encoded)
    # TODO(jbaccash): use input bias parameter. This code seems to assume
    # fixed size inputs.
    encdec_bias = tf.zeros([
        tf.shape(memory)[0],
        hparams.num_heads,
        tf.shape(targets_segment)[1],
        tf.shape(memory)[1],
    ])
    self_attended = common_attention.multihead_attention(
        queries,
        None,
        causal_bias,
        hparams.model_d,
        hparams.model_d,
        hparams.model_d,
        hparams.num_heads,
        hparams.attention_dropout,
        name="self_attention")
    attended = common_attention.multihead_attention(
        self_attended,
        memory,
        encdec_bias,
        hparams.model_d,
        hparams.model_d,
        hparams.model_d,
        hparams.num_heads,
        hparams.attention_dropout,
        name="encdec_attention")
    return tf.expand_dims(attended, 2)
def attention(targets_shifted, inputs_encoded, norm_fn, hparams, bias=None):
    """Complete attention layer with preprocessing.

    The targets get a timing signal and pass through a two-convolution,
    left-padded block before attention is applied.

    Args:
        targets_shifted: target tensor to attend from.
        inputs_encoded: encoder output to attend to.
        norm_fn: normalizer; passed to the preprocessing conv block and, for
            "simple" attention, called on the residual sum with
            name="attn_norm".
        hparams: hyperparameter object. Reads separability, hidden_size,
            attention_type, num_heads and attention_dropout.
        bias: optional bias forwarded to common_layers.simple_attention; not
            used by the "transformer" branch.

    Returns:
        For "transformer": the encoder-decoder attention output with a size-1
        axis inserted at position 2. For "simple": the normalized sum of
        targets_shifted and the simple-attention output.

    Raises:
        ValueError: if hparams.attention_type is neither "transformer" nor
            "simple". Previously this case fell through and returned None.
    """
    separabilities = [hparams.separability, hparams.separability]
    if hparams.separability < 0:
        separabilities = [hparams.separability - 1, hparams.separability]
    targets_timed = common_layers.subseparable_conv_block(
        common_layers.add_timing_signal(targets_shifted),
        hparams.hidden_size,
        [((1, 1), (5, 1)), ((4, 1), (5, 1))],
        normalizer_fn=norm_fn,
        padding="LEFT",
        separabilities=separabilities,
        name="targets_time")
    if hparams.attention_type == "transformer":
        targets_timed = tf.squeeze(targets_timed, 2)
        target_shape = tf.shape(targets_timed)
        # All-zero segment tensor; only used to build the attention biases.
        targets_segment = tf.zeros([target_shape[0], target_shape[1]])
        target_attention_bias = common_attention.attention_bias(
            targets_segment, targets_segment, lower_triangular=True)
        inputs_attention_bias = tf.zeros([
            tf.shape(inputs_encoded)[0],
            hparams.num_heads,
            tf.shape(targets_segment)[1],
            tf.shape(inputs_encoded)[1],
        ])
        qv = common_attention.multihead_attention(
            targets_timed,
            None,
            target_attention_bias,
            hparams.hidden_size,
            hparams.hidden_size,
            hparams.hidden_size,
            hparams.num_heads,
            hparams.attention_dropout,
            name="self_attention")
        qv = common_attention.multihead_attention(
            qv,
            inputs_encoded,
            inputs_attention_bias,
            hparams.hidden_size,
            hparams.hidden_size,
            hparams.hidden_size,
            hparams.num_heads,
            hparams.attention_dropout,
            name="encdec_attention")
        return tf.expand_dims(qv, 2)
    elif hparams.attention_type == "simple":
        targets_with_attention = common_layers.simple_attention(
            targets_timed, inputs_encoded, bias=bias)
        return norm_fn(
            targets_shifted + targets_with_attention, name="attn_norm")
    else:
        # Fail loudly instead of implicitly returning None.
        raise ValueError(
            "Unsupported attention_type: %s" % hparams.attention_type)
def attention(targets_shifted, inputs_encoded, norm_fn, hparams, bias=None):
    """Complete attention layer with preprocessing.

    The targets get a timing signal and pass through a two-convolution,
    left-padded block before attention is applied.

    Args:
        targets_shifted: target tensor to attend from.
        inputs_encoded: encoder output to attend to.
        norm_fn: normalizer; passed to the preprocessing conv block and, for
            "simple" attention, called on the residual sum with
            name="attn_norm".
        hparams: hyperparameter object. Reads separability, hidden_size,
            attention_type, num_heads and attention_dropout.
        bias: optional bias forwarded to common_layers.simple_attention; not
            used by the "transformer" branch.

    Returns:
        For "transformer": the encoder-decoder attention output with a size-1
        axis inserted at position 2. For "simple": the normalized sum of
        targets_shifted and the simple-attention output.

    Raises:
        ValueError: if hparams.attention_type is neither "transformer" nor
            "simple". Previously this case fell through and returned None.
    """
    separabilities = [hparams.separability, hparams.separability]
    if hparams.separability < 0:
        separabilities = [hparams.separability - 1, hparams.separability]
    targets_timed = common_layers.subseparable_conv_block(
        common_layers.add_timing_signal(targets_shifted),
        hparams.hidden_size,
        [((1, 1), (5, 1)), ((4, 1), (5, 1))],
        normalizer_fn=norm_fn,
        padding="LEFT",
        separabilities=separabilities,
        name="targets_time")
    if hparams.attention_type == "transformer":
        targets_timed = tf.squeeze(targets_timed, 2)
        target_shape = tf.shape(targets_timed)
        # All-zero segment tensor; only used to build the attention biases.
        targets_segment = tf.zeros([target_shape[0], target_shape[1]])
        target_attention_bias = common_attention.attention_bias(
            targets_segment, targets_segment, lower_triangular=True)
        inputs_attention_bias = tf.zeros([
            tf.shape(inputs_encoded)[0],
            hparams.num_heads,
            tf.shape(targets_segment)[1],
            tf.shape(inputs_encoded)[1],
        ])
        qv = common_attention.multihead_attention(
            targets_timed,
            None,
            target_attention_bias,
            hparams.hidden_size,
            hparams.hidden_size,
            hparams.hidden_size,
            hparams.num_heads,
            hparams.attention_dropout,
            name="self_attention")
        qv = common_attention.multihead_attention(
            qv,
            inputs_encoded,
            inputs_attention_bias,
            hparams.hidden_size,
            hparams.hidden_size,
            hparams.hidden_size,
            hparams.num_heads,
            hparams.attention_dropout,
            name="encdec_attention")
        return tf.expand_dims(qv, 2)
    elif hparams.attention_type == "simple":
        targets_with_attention = common_layers.simple_attention(
            targets_timed, inputs_encoded, bias=bias)
        return norm_fn(
            targets_shifted + targets_with_attention, name="attn_norm")
    else:
        # Fail loudly instead of implicitly returning None.
        raise ValueError(
            "Unsupported attention_type: %s" % hparams.attention_type)
def conv_res_step(x, hparams, padding, mask):
    """Run two convolution blocks with dropout applied between them.

    Args:
        x: input tensor for the first conv block.
        hparams: hyperparameter object. Reads kernel_height, kernel_width,
            large_kernel_size, filter_size, hidden_size and dropout.
        padding: padding mode forwarded to both conv blocks.
        mask: mask forwarded to both conv blocks.

    Returns:
        The output of the second conv block (hparams.hidden_size filters).
    """
    small_kernel = (hparams.kernel_height, hparams.kernel_width)
    large_kernel = (hparams.large_kernel_size, 1)
    first_specs = [((1, 1), small_kernel), ((1, 1), small_kernel)]
    second_specs = [((1, 1), large_kernel), ((4, 4), large_kernel)]
    with tf.variable_scope("conv_res_step"):
        # First block uses filter_size filters; the second uses hidden_size.
        hidden = common_layers.subseparable_conv_block(
            x,
            hparams.filter_size,
            first_specs,
            padding=padding,
            mask=mask,
            separabilities=0,
            name="residual1")
        hidden = tf.nn.dropout(hidden, 1.0 - hparams.dropout)
        output = common_layers.subseparable_conv_block(
            hidden,
            hparams.hidden_size,
            second_specs,
            padding=padding,
            mask=mask,
            separabilities=0,
            name="residual2")
        return output
def slicenet_middle(inputs_encoded, targets, target_space_emb, mask, hparams):
    """Middle part of slicenet, joining the encoder output and the decoder.

    Flattens and right-shifts the targets, optionally lets them attend to
    `inputs_encoded`, and merges attended and shifted targets with a single
    left-padded convolution block.

    Args:
        inputs_encoded: encoder output to attend to.
        targets: target tensor; flattened with common_layers.flatten4d3d.
        target_space_emb: target-space embedding, tiled over the batch and
            used as the pad value when shifting targets right.
        mask: input mask; only read when hparams.attention_type is not "none".
        hparams: hyperparameter object. Reads norm_type, model_d,
            norm_epsilon, attention_type, kernel_height and kernel_width.

    Returns:
        A pair (targets_merged, 0.0); the second element is always the
        constant 0.0 in this implementation.
    """

    def norm_fn(x, name):
        with tf.variable_scope(name, default_name="norm"):
            return common_layers.apply_norm(
                x, hparams.norm_type, hparams.model_d, hparams.norm_epsilon)

    # Flatten targets and tile the target-space embedding over the batch.
    targets_flat = tf.expand_dims(common_layers.flatten4d3d(targets), axis=2)
    batch_size = tf.shape(targets_flat)[0]
    target_space_emb = tf.tile(target_space_emb, [batch_size, 1, 1, 1])

    # Each target position attends to the encoded inputs (unless disabled).
    targets_shifted = common_layers.shift_right(
        targets_flat, pad_value=target_space_emb)
    if hparams.attention_type != "none":
        # Bias to not attend to padding.
        pad_bias = (1.0 - mask) * -1e9
        attended = attention(
            targets_shifted, inputs_encoded, norm_fn, hparams, bias=pad_bias)
    else:
        attended = tf.zeros_like(targets_shifted)

    # Positional targets: merge attention and raw.
    merge_kernel = (hparams.kernel_height, hparams.kernel_width)
    merge_input = tf.concat([attended, targets_shifted], axis=3)
    targets_merged = common_layers.subseparable_conv_block(
        merge_input,
        hparams.model_d,
        [((1, 1), merge_kernel)],
        normalizer_fn=norm_fn,
        padding="LEFT",
        separability=4,
        name="targets_merge")
    return targets_merged, 0.0