Example #1
def multi_conv_res(x, padding, name, layers, hparams, mask=None, source=None):
    """A stack of separable convolution blocks with residual connections."""
    with tf.variable_scope(name):
        padding_bias = None
        if mask is not None:
            padding_bias = (1.0 - mask) * -1e9  # Bias to not attend to padding.
            if padding == "LEFT":  # Do not mask anything when left-padding.
                mask = None
        if (hparams.kernel_scheme in _KERNEL_SCHEMES
                and hparams.dilation_scheme in _DILATION_SCHEMES):
            kernels = _KERNEL_SCHEMES[hparams.kernel_scheme]
            dilations = _DILATION_SCHEMES[hparams.dilation_scheme]
            dilations_and_kernels = list(zip(dilations, kernels))
            dilations_and_kernels1 = dilations_and_kernels[:2]
            dilations_and_kernels2 = dilations_and_kernels[2:]
        else:
            k = (hparams.kernel_height, hparams.kernel_width)
            k2 = (hparams.large_kernel_size, 1)
            dilations_and_kernels1 = [((1, 1), k), ((1, 1), k)]
            dilations_and_kernels2 = [((1, 1), k2), ((4, 4), k2)]
        separabilities1 = [hparams.separability, hparams.separability]
        separabilities2 = [hparams.separability] * len(dilations_and_kernels2)
        if hparams.separability < 0:
            separabilities1 = [hparams.separability - 1, hparams.separability]
            separabilities2 = [
                hparams.separability - i
                for i in reversed(range(len(dilations_and_kernels2)))
            ]
        norm_fn = common_layers.get_norm(hparams.norm_type)
        for layer in range(layers):
            with tf.variable_scope("layer_%d" % layer):
                y = common_layers.subseparable_conv_block(
                    x,
                    hparams.hidden_size,
                    dilations_and_kernels1,
                    normalizer_fn=norm_fn,
                    padding=padding,
                    mask=mask,
                    separabilities=separabilities1,
                    name="residual1")
                x += common_layers.subseparable_conv_block(
                    x + y,
                    hparams.hidden_size,
                    dilations_and_kernels2,
                    normalizer_fn=norm_fn,
                    padding=padding,
                    mask=mask,
                    separabilities=separabilities2,
                    name="residual2") + y
                if source is not None and hparams.attention_type != "none":
                    x += attention(x,
                                   source,
                                   norm_fn,
                                   hparams,
                                   bias=padding_bias)
                if mask is not None:
                    x *= mask
        return tf.nn.dropout(x, 1.0 - hparams.dropout)
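A minimal sketch of how this block might be invoked, assuming TF 1.x (with tf.contrib available) and the rest of the slicenet module (common_layers, _KERNEL_SCHEMES) in scope; every hyperparameter value below is illustrative, not taken from any published configuration:

# Hypothetical call site for multi_conv_res; the hparams fields mirror the
# ones the function reads above, with made-up values.
hparams = tf.contrib.training.HParams(
    hidden_size=32,
    kernel_height=3,
    kernel_width=1,
    large_kernel_size=15,
    kernel_scheme="",      # not in _KERNEL_SCHEMES, so the fixed-kernel branch runs
    dilation_scheme="",
    separability=-2,       # negative, so the per-layer separability lists are used
    norm_type="batch",     # a norm type the test below confirms get_norm accepts
    attention_type="none",
    dropout=0.1)
x = tf.zeros([8, 20, 1, hparams.hidden_size])  # [batch, length, 1, hidden_size]
out = multi_conv_res(x, padding="SAME", name="body", layers=2, hparams=hparams)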
Example #2
def testGetNormBatchFn(self):
    norm_type = "batch"
    with self.test_session() as session:
        a = common_layers.get_norm(norm_type)
        x1 = np.random.rand(5, 2, 1, 11)
        x2 = a(tf.constant(x1, dtype=tf.float32), name="batch")
        session.run(tf.global_variables_initializer())
        actual = session.run(x2)
    self.assertEqual(actual.shape, (5, 2, 1, 11))
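The test only checks that the normalizer returned by get_norm preserves the input shape. The same property can be seen with plain TF 1.x batch normalization, independent of tensor2tensor (a standalone sketch, not part of the test suite):

import numpy as np
import tensorflow as tf

x = tf.constant(np.random.rand(5, 2, 1, 11), dtype=tf.float32)
y = tf.layers.batch_normalization(x, name="batch")  # normalizes over the last axis
with tf.Session() as session:
    session.run(tf.global_variables_initializer())
    out = session.run(y)
assert out.shape == (5, 2, 1, 11)  # shape is unchanged by batch norm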
Example #3
def slicenet_middle(inputs_encoded, targets, target_space_emb, mask, hparams):
    """Middle part of slicenet, connecting encoder and decoder."""
    norm_fn = common_layers.get_norm(hparams.norm_type)

    # Flatten targets and embed target_space_id.
    targets_flat = tf.expand_dims(common_layers.flatten4d3d(targets), axis=2)
    target_space_emb = tf.tile(target_space_emb,
                               [tf.shape(targets_flat)[0], 1, 1, 1])

    # Calculate similarity loss (but don't run if not needed).
    if len(hparams.problems) > 1 and hparams.sim_loss_mult > 0.00001:
        targets_timed = common_layers.add_timing_signal(targets_flat)
        extra_layers = int(hparams.num_hidden_layers * 1.5)
        with tf.variable_scope(tf.get_variable_scope(), reuse=True):
            targets_encoded = multi_conv_res(targets_timed, "SAME", "encoder",
                                             extra_layers, hparams)
        with tf.variable_scope("similarity_loss"):
            similarity_loss = similarity_cost(inputs_encoded, targets_encoded)
            similarity_loss *= hparams.sim_loss_mult
    else:
        similarity_loss = 0.0

    # Use attention from each target to look at input and retrieve.
    targets_shifted = common_layers.shift_left(targets_flat,
                                               pad_value=target_space_emb)
    if hparams.attention_type == "none":
        targets_with_attention = tf.zeros_like(targets_shifted)
    else:
        inputs_padding_bias = (1.0 - mask) * -1e9  # Bias to not attend to padding.
        targets_with_attention = attention(targets_shifted,
                                           inputs_encoded,
                                           norm_fn,
                                           hparams,
                                           bias=inputs_padding_bias)

    # Positional targets: merge attention and raw.
    kernel = (hparams.kernel_height, hparams.kernel_width)
    targets_merged = common_layers.subseparable_conv_block(
        tf.concat([targets_with_attention, targets_shifted], axis=3),
        hparams.hidden_size, [((1, 1), kernel)],
        normalizer_fn=norm_fn,
        padding="LEFT",
        separability=4,
        name="targets_merge")

    return targets_merged, similarity_loss
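Both examples build the same attention bias from a 0/1 padding mask. A minimal NumPy sketch (with illustrative values) of why adding (1.0 - mask) * -1e9 to the attention logits zeroes out attention to padding after the softmax:

import numpy as np

mask = np.array([1.0, 1.0, 1.0, 0.0, 0.0])  # 1 = real token, 0 = padding
bias = (1.0 - mask) * -1e9                  # large negative bias on padding
logits = np.array([0.5, 1.2, -0.3, 2.0, 0.9])

def softmax(v):
    e = np.exp(v - v.max())
    return e / e.sum()

print(softmax(logits + bias))  # padding positions receive ~0 probability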