def slicenet_internal(inputs, targets, target_space, hparams, run_decoder=True):
  """Builds the SliceNet graph: encoder, middle part and decoder.

  Args:
    inputs: input tensor; it is projected with a conv block when its last
      static dimension differs from `hparams.model_d`.
    targets: target tensor, forwarded to `slicenet_middle`.
    target_space: target-space id, embedded via `embed_target_space`.
    hparams: hyperparameters; `model_d` and `num_hidden_layers` are read here.
    run_decoder: when False, stop after the encoder.

  Returns:
    The encoder output alone when `run_decoder` is False, otherwise the pair
    `(decoder_output, mean_similarity_loss)`.
  """
  with tf.variable_scope("slicenet"):
    depth = hparams.model_d
    channels = inputs.get_shape().as_list()[-1]
    if channels != depth:
      # Channel count differs from the model width: project it.
      inputs = common_layers.conv_block(
          inputs,
          depth, [((1, 1), (3, 3))],
          first_relu=False,
          padding="SAME",
          force2d=True)

    # Collapse to 3-D, then re-insert a singleton axis at position 2.
    flattened = common_layers.flatten4d3d(inputs)
    inputs = tf.expand_dims(flattened, axis=2)
    # Complement of the padding indicator computed from the embeddings.
    inputs_mask = 1.0 - embedding_to_padding(inputs)
    # Position information is added only after the mask has been derived.
    inputs = common_layers.add_timing_signal(inputs)
    target_space_emb = embed_target_space(target_space, depth)

    # The encoder stack is 1.5x the decoder stack (truncated to an int).
    encoder_layers = int(hparams.num_hidden_layers * 1.5)
    inputs_encoded = multi_conv_res(
        inputs, "SAME", "encoder", encoder_layers, hparams, mask=inputs_mask)
    if not run_decoder:
      return inputs_encoded

    # Middle part: connects the encoder output to the decoder input.
    decoder_start, similarity_loss = slicenet_middle(
        inputs_encoded, targets, target_space_emb, inputs_mask, hparams)

    # Decoder stack, with the encoder output supplied as `source`.
    decoder_final = multi_conv_res(
        decoder_start, "LEFT", "decoder", hparams.num_hidden_layers, hparams,
        mask=inputs_mask, source=inputs_encoded)
    mean_similarity = tf.reduce_mean(similarity_loss)
    return decoder_final, mean_similarity
Beispiel #2
0
def slicenet_internal(inputs, targets, target_space, hparams, run_decoder=True):
  """Assembles the SliceNet training graph.

  Args:
    inputs: input tensor; run through a projecting conv block when its last
      static dimension is not `hparams.hidden_size`.
    targets: target tensor, passed on to `slicenet_middle`.
    target_space: target-space id, embedded via `embed_target_space`.
    hparams: hyperparameters; `hidden_size` and `num_hidden_layers` are read.
    run_decoder: set to False to return right after the encoder.

  Returns:
    Only the encoded inputs when `run_decoder` is False; otherwise a tuple
    `(decoder_output, mean_similarity_loss)`.
  """
  with tf.variable_scope("slicenet"):
    hidden = hparams.hidden_size
    needs_projection = inputs.get_shape().as_list()[-1] != hidden
    if needs_projection:
      # Project to the hidden size.
      inputs = common_layers.conv_block(
          inputs,
          hidden, [((1, 1), (3, 3))],
          first_relu=False,
          padding="SAME",
          force2d=True)

    # Flatten, restore a singleton axis at index 2, then mask and add timing.
    flat_inputs = common_layers.flatten4d3d(inputs)
    inputs = tf.expand_dims(flat_inputs, axis=2)
    inputs_mask = 1.0 - embedding_to_padding(inputs)
    inputs = common_layers.add_timing_signal(inputs)  # Position info.
    target_space_emb = embed_target_space(target_space, hidden)

    # Encoder uses int(1.5 * num_hidden_layers) layers.
    num_encoder_layers = int(hparams.num_hidden_layers * 1.5)
    inputs_encoded = multi_conv_res(
        inputs, "SAME", "encoder", num_encoder_layers, hparams,
        mask=inputs_mask)
    if not run_decoder:
      return inputs_encoded

    # Bridge between encoder and decoder.
    decoder_start, similarity_loss = slicenet_middle(
        inputs_encoded, targets, target_space_emb, inputs_mask, hparams)

    # Decoder over the merged targets, reading the encoder output as source.
    decoder_final = multi_conv_res(
        decoder_start, "LEFT", "decoder", hparams.num_hidden_layers, hparams,
        mask=inputs_mask, source=inputs_encoded)
    return decoder_final, tf.reduce_mean(similarity_loss)
def slicenet_internal(inputs, targets, target_space, problem_idx, hparams):
    """Assembles the SliceNet training graph for one problem.

    Args:
        inputs: input tensor, flattened and encoded here.
        targets: target tensor, forwarded to `slicenet_middle`.
        target_space: target-space id, embedded via `embed_target_space`.
        problem_idx: index into `hparams.problems`, used to look up the
            target modality name.
        hparams: hyperparameters; `hidden_size`, `num_hidden_layers` and
            `problems` are read here.

    Returns:
        Only the encoded inputs when the target modality name contains
        "class_label_modality"; otherwise a tuple
        `(decoder_output, mean_similarity_loss)`.
    """
    with tf.variable_scope("slicenet"):
        # Flatten, restore a singleton axis at index 2, mask, add timing.
        flat_inputs = common_layers.flatten4d3d(inputs)
        inputs = tf.expand_dims(flat_inputs, axis=2)
        inputs_mask = 1.0 - embedding_to_padding(inputs)
        inputs = common_layers.add_timing_signal(inputs)  # Position info.
        target_space_emb = embed_target_space(target_space,
                                              hparams.hidden_size)

        # Encoder uses int(1.5 * num_hidden_layers) layers.
        num_encoder_layers = int(hparams.num_hidden_layers * 1.5)
        inputs_encoded = multi_conv_res(
            inputs, "SAME", "encoder", num_encoder_layers, hparams,
            mask=inputs_mask)

        problem = hparams.problems[problem_idx]
        modality_name = problem.target_modality.name
        if "class_label_modality" in modality_name:
            # If we're just predicting a class, there is no use for a decoder.
            return inputs_encoded

        # Bridge between encoder and decoder.
        decoder_start, similarity_loss = slicenet_middle(
            inputs_encoded, targets, target_space_emb, inputs_mask, hparams)

        # Decoder over the merged targets, with the encoder output as source.
        decoder_final = multi_conv_res(
            decoder_start, "LEFT", "decoder", hparams.num_hidden_layers,
            hparams, mask=inputs_mask, source=inputs_encoded)
        return decoder_final, tf.reduce_mean(similarity_loss)
 def testAddTimingSignal(self):
   # add_timing_signal must leave the shape of its input unchanged.
   batch, length, height, depth = 5, 7, 3, 35
   shape = (batch, length, height, depth)
   x = np.random.rand(*shape)
   signal_input = tf.constant(x, dtype=tf.float32)
   a = common_layers.add_timing_signal(signal_input)
   res = self.evaluate(a)
   self.assertEqual(res.shape, shape)
Beispiel #5
0
 def testAddTimingSignal(self):
     # The timing signal is added in place: output shape equals input shape.
     expected_shape = (5, 7, 3, 35)  # (batch, length, height, depth)
     x = np.random.rand(*expected_shape)
     x_tensor = tf.constant(x, dtype=tf.float32)
     a = common_layers.add_timing_signal(x_tensor)
     res = self.evaluate(a)
     self.assertEqual(res.shape, expected_shape)
 def testAddTimingSignal(self):
   # Runs add_timing_signal in a test session and checks the output shape
   # equals the input shape.
   batch, length, height, depth = 5, 7, 3, 35
   expected_shape = (batch, length, height, depth)
   x = np.random.rand(*expected_shape)
   with self.test_session() as session:
     x_tensor = tf.constant(x, dtype=tf.float32)
     a = common_layers.add_timing_signal(x_tensor)
     init_op = tf.global_variables_initializer()
     session.run(init_op)
     res = session.run(a)
   self.assertEqual(res.shape, expected_shape)
Beispiel #7
0
def slicenet_middle(inputs_encoded, targets, target_space_emb, mask, hparams):
    """Connects the SliceNet encoder output to the decoder input.

    Args:
        inputs_encoded: encoder output; attended over by the targets.
        targets: target tensor; flattened and shifted right here.
        target_space_emb: target-space embedding; tiled along the first axis
            and used as the pad value for the right shift.
        mask: input mask; `(1.0 - mask) * -1e9` becomes the attention bias.
        hparams: hyperparameters (reads `problems`, `sim_loss_mult`,
            `num_hidden_layers`, `attention_type`, `kernel_height`,
            `kernel_width`, `hidden_size`, `norm_type`, `norm_epsilon`).

    Returns:
        A pair `(targets_merged, similarity_loss)`; the loss is the float
        `0.0` when the similarity branch is skipped.
    """

    def norm_fn(x, name):
        # Normalizer configured from hparams, scoped under `name`.
        with tf.variable_scope(name, default_name="norm"):
            return common_layers.apply_norm(
                x, hparams.norm_type, hparams.hidden_size,
                hparams.norm_epsilon)

    # Flatten the targets and tile the target-space embedding so its first
    # dimension matches the first dimension of the flattened targets.
    flat_targets = common_layers.flatten4d3d(targets)
    targets_flat = tf.expand_dims(flat_targets, axis=2)
    tile_multiples = [tf.shape(targets_flat)[0], 1, 1, 1]
    target_space_emb = tf.tile(target_space_emb, tile_multiples)

    # The similarity loss is only built with several problems and a
    # non-negligible multiplier.
    similarity_loss = 0.0
    wants_similarity = (len(hparams.problems) > 1
                        and hparams.sim_loss_mult > 0.00001)
    if wants_similarity:
        targets_timed = common_layers.add_timing_signal(targets_flat)
        encoder_layers = int(hparams.num_hidden_layers * 1.5)
        # Re-enter the current scope with reuse=True so the "encoder" stack
        # shares its variables with the one already built.
        with tf.variable_scope(tf.get_variable_scope(), reuse=True):
            targets_encoded = multi_conv_res(
                targets_timed, "SAME", "encoder", encoder_layers, hparams)
        with tf.variable_scope("similarity_loss"):
            raw_cost = similarity_cost(inputs_encoded, targets_encoded)
            similarity_loss = raw_cost * hparams.sim_loss_mult

    # Shift targets right (padding with the target-space embedding), then
    # let each target attend to the encoded inputs.
    targets_shifted = common_layers.shift_right(
        targets_flat, pad_value=target_space_emb)
    if hparams.attention_type != "none":
        # Large negative bias wherever the mask is zero.
        inputs_padding_bias = (1.0 - mask) * -1e9
        targets_with_attention = attention(
            targets_shifted, inputs_encoded, norm_fn, hparams,
            bias=inputs_padding_bias)
    else:
        targets_with_attention = tf.zeros_like(targets_shifted)

    # Merge attention output with the raw shifted targets along axis 3.
    kernel = (hparams.kernel_height, hparams.kernel_width)
    merge_input = tf.concat([targets_with_attention, targets_shifted], axis=3)
    targets_merged = common_layers.subseparable_conv_block(
        merge_input,
        hparams.hidden_size, [((1, 1), kernel)],
        normalizer_fn=norm_fn,
        padding="LEFT",
        separability=4,
        name="targets_merge")

    return targets_merged, similarity_loss
Beispiel #8
0
def slicenet_middle(inputs_encoded, targets, target_space_emb, mask, hparams):
  """Builds the SliceNet middle part that feeds the decoder.

  Args:
    inputs_encoded: encoder output; attended over by the targets.
    targets: target tensor; flattened and shifted right here.
    target_space_emb: target-space embedding; tiled along the first axis and
      used as the pad value for the right shift.
    mask: input mask; `(1.0 - mask) * -1e9` is used as the attention bias.
    hparams: hyperparameters (reads `problems`, `sim_loss_mult`,
      `num_hidden_layers`, `attention_type`, `kernel_height`, `kernel_width`,
      `hidden_size`, `norm_type`, `norm_epsilon`).

  Returns:
    A pair `(targets_merged, similarity_loss)`; the loss is the float `0.0`
    when the similarity branch is skipped.
  """

  def norm_fn(x, name):
    # Normalizer configured from hparams, scoped under `name`.
    with tf.variable_scope(name, default_name="norm"):
      return common_layers.apply_norm(
          x, hparams.norm_type, hparams.hidden_size, hparams.norm_epsilon)

  # Flatten the targets; tile the target-space embedding so its first
  # dimension matches the first dimension of the flattened targets.
  flat_targets = common_layers.flatten4d3d(targets)
  targets_flat = tf.expand_dims(flat_targets, axis=2)
  tile_multiples = [tf.shape(targets_flat)[0], 1, 1, 1]
  target_space_emb = tf.tile(target_space_emb, tile_multiples)

  # Similarity loss: built only with several problems and a non-negligible
  # multiplier; otherwise it stays a plain 0.0.
  similarity_loss = 0.0
  wants_similarity = (len(hparams.problems) > 1 and
                      hparams.sim_loss_mult > 0.00001)
  if wants_similarity:
    targets_timed = common_layers.add_timing_signal(targets_flat)
    encoder_layers = int(hparams.num_hidden_layers * 1.5)
    # reuse=True: the "encoder" stack shares the already-created variables.
    with tf.variable_scope(tf.get_variable_scope(), reuse=True):
      targets_encoded = multi_conv_res(
          targets_timed, "SAME", "encoder", encoder_layers, hparams)
    with tf.variable_scope("similarity_loss"):
      raw_cost = similarity_cost(inputs_encoded, targets_encoded)
      similarity_loss = raw_cost * hparams.sim_loss_mult

  # Shift right (padding with the target-space embedding) and attend.
  targets_shifted = common_layers.shift_right(
      targets_flat, pad_value=target_space_emb)
  if hparams.attention_type != "none":
    # Large negative bias wherever the mask is zero.
    inputs_padding_bias = (1.0 - mask) * -1e9
    targets_with_attention = attention(
        targets_shifted, inputs_encoded, norm_fn, hparams,
        bias=inputs_padding_bias)
  else:
    targets_with_attention = tf.zeros_like(targets_shifted)

  # Merge attention output with the raw shifted targets along axis 3.
  kernel = (hparams.kernel_height, hparams.kernel_width)
  merge_input = tf.concat([targets_with_attention, targets_shifted], axis=3)
  targets_merged = common_layers.subseparable_conv_block(
      merge_input,
      hparams.hidden_size, [((1, 1), kernel)],
      normalizer_fn=norm_fn,
      padding="LEFT",
      separability=4,
      name="targets_merge")

  return targets_merged, similarity_loss
def attention(targets_shifted, inputs_encoded, norm_fn, hparams, bias=None):
  """Preprocesses the shifted targets and applies transformer attention.

  Args:
    targets_shifted: shifted targets; receive a timing signal and a
      subseparable conv block before attention.
    inputs_encoded: encoder output; flattened and used as the memory of the
      encoder-decoder attention.
    norm_fn: normalizer passed to the conv block.
    hparams: hyperparameters (reads `separability`, `model_d`,
      `attention_type`, `num_heads`, `attention_dropout`).
    bias: accepted but not used by this implementation.

  Returns:
    The attention output with a singleton axis re-inserted at index 2.

  Raises:
    ValueError: if `hparams.attention_type` is not "transformer".
  """
  sep = hparams.separability
  # For negative separability the first entry is lowered by one.
  first_sep = sep - 1 if sep < 0 else sep
  timed_input = common_layers.add_timing_signal(targets_shifted)
  targets_timed = common_layers.subseparable_conv_block(
      timed_input,
      hparams.model_d, [((1, 1), (5, 1)), ((4, 1), (5, 1))],
      normalizer_fn=norm_fn,
      padding="LEFT",
      separabilities=[first_sep, sep],
      name="targets_time")

  if hparams.attention_type != "transformer":
    raise ValueError("Unsupported attention_type: %s" % hparams.attention_type)

  # Drop the singleton axis for the attention layers.
  targets_timed = tf.squeeze(targets_timed, 2)
  target_shape = tf.shape(targets_timed)
  targets_segment = tf.zeros([target_shape[0], target_shape[1]])
  # Lower-triangular bias for the self-attention over targets.
  target_attention_bias = common_attention.attention_bias_lower_triangle(
      target_shape[1])
  inputs_encoded = common_layers.flatten4d3d(inputs_encoded)
  # TODO(jbaccash): use input bias parameter. This code seems to assume fixed
  # size inputs.
  batch_size = tf.shape(inputs_encoded)[0]
  target_length = tf.shape(targets_segment)[1]
  input_length = tf.shape(inputs_encoded)[1]
  inputs_attention_bias = tf.zeros(
      [batch_size, hparams.num_heads, target_length, input_length])

  depth = hparams.model_d
  self_attended = common_attention.multihead_attention(
      targets_timed, None, target_attention_bias,
      depth, depth, depth,
      hparams.num_heads, hparams.attention_dropout,
      name="self_attention")
  cross_attended = common_attention.multihead_attention(
      self_attended, inputs_encoded, inputs_attention_bias,
      depth, depth, depth,
      hparams.num_heads, hparams.attention_dropout,
      name="encdec_attention")
  return tf.expand_dims(cross_attended, 2)
Beispiel #10
0
def attention(targets_shifted, inputs_encoded, norm_fn, hparams, bias=None):
    """Complete attention layer with preprocessing.

    Args:
        targets_shifted: shifted targets; receive a timing signal and a
            subseparable conv block before attention.
        inputs_encoded: encoder output that the targets attend to.
        norm_fn: normalizer, called as `norm_fn(x, name=...)`.
        hparams: hyperparameters (reads `attention_type`, `separability`,
            `hidden_size`, `num_heads`, `attention_dropout`).
        bias: attention bias; only the "simple" attention type uses it. The
            "transformer" type builds an all-zero encoder-decoder bias.

    Returns:
        For "transformer": the multi-head attention output with a singleton
        axis re-inserted at index 2. For "simple": the normalized sum of
        `targets_shifted` and the simple-attention output.

    Raises:
        ValueError: if `hparams.attention_type` is neither "transformer" nor
            "simple".
    """
    attention_type = hparams.attention_type
    # Fail fast, before any graph ops or variables are created. Previously an
    # unknown type fell off the end of the if/elif chain and returned None.
    if attention_type not in ("transformer", "simple"):
        raise ValueError("Unsupported attention_type: %s" % attention_type)

    separabilities = [hparams.separability, hparams.separability]
    if hparams.separability < 0:
        # For negative separability the first entry is lowered by one.
        separabilities = [hparams.separability - 1, hparams.separability]
    targets_timed = common_layers.subseparable_conv_block(
        common_layers.add_timing_signal(targets_shifted),
        hparams.hidden_size, [((1, 1), (5, 1)), ((4, 1), (5, 1))],
        normalizer_fn=norm_fn,
        padding="LEFT",
        separabilities=separabilities,
        name="targets_time")

    if attention_type == "transformer":
        # Drop the singleton axis for the attention layers.
        targets_timed = tf.squeeze(targets_timed, 2)
        target_shape = tf.shape(targets_timed)
        targets_segment = tf.zeros([target_shape[0], target_shape[1]])
        # Lower-triangular bias for the self-attention over targets.
        target_attention_bias = common_attention.attention_bias(
            targets_segment, targets_segment, lower_triangular=True)
        # All-zero bias for encoder-decoder attention (`bias` is not used).
        inputs_attention_bias = tf.zeros([
            tf.shape(inputs_encoded)[0], hparams.num_heads,
            tf.shape(targets_segment)[1],
            tf.shape(inputs_encoded)[1]
        ])

        qv = common_attention.multihead_attention(targets_timed,
                                                  None,
                                                  target_attention_bias,
                                                  hparams.hidden_size,
                                                  hparams.hidden_size,
                                                  hparams.hidden_size,
                                                  hparams.num_heads,
                                                  hparams.attention_dropout,
                                                  name="self_attention")
        qv = common_attention.multihead_attention(qv,
                                                  inputs_encoded,
                                                  inputs_attention_bias,
                                                  hparams.hidden_size,
                                                  hparams.hidden_size,
                                                  hparams.hidden_size,
                                                  hparams.num_heads,
                                                  hparams.attention_dropout,
                                                  name="encdec_attention")
        return tf.expand_dims(qv, 2)

    # attention_type == "simple" (guaranteed by the validation above).
    targets_with_attention = common_layers.simple_attention(targets_timed,
                                                            inputs_encoded,
                                                            bias=bias)
    return norm_fn(targets_shifted + targets_with_attention,
                   name="attn_norm")
Beispiel #11
0
def attention(targets_shifted, inputs_encoded, norm_fn, hparams, bias=None):
  """Complete attention layer with preprocessing.

  Args:
    targets_shifted: shifted targets; receive a timing signal and a
      subseparable conv block before attention.
    inputs_encoded: encoder output that the targets attend to.
    norm_fn: normalizer, called as `norm_fn(x, name=...)`.
    hparams: hyperparameters (reads `attention_type`, `separability`,
      `hidden_size`, `num_heads`, `attention_dropout`).
    bias: attention bias; only the "simple" attention type uses it. The
      "transformer" type builds an all-zero encoder-decoder bias.

  Returns:
    For "transformer": the multi-head attention output with a singleton axis
    re-inserted at index 2. For "simple": the normalized sum of
    `targets_shifted` and the simple-attention output.

  Raises:
    ValueError: if `hparams.attention_type` is neither "transformer" nor
      "simple".
  """
  attention_type = hparams.attention_type
  # Fail fast, before any graph ops or variables are created. Previously an
  # unknown type fell off the end of the if/elif chain and returned None.
  if attention_type not in ("transformer", "simple"):
    raise ValueError("Unsupported attention_type: %s" % attention_type)

  separabilities = [hparams.separability, hparams.separability]
  if hparams.separability < 0:
    # For negative separability the first entry is lowered by one.
    separabilities = [hparams.separability - 1, hparams.separability]
  targets_timed = common_layers.subseparable_conv_block(
      common_layers.add_timing_signal(targets_shifted),
      hparams.hidden_size, [((1, 1), (5, 1)), ((4, 1), (5, 1))],
      normalizer_fn=norm_fn,
      padding="LEFT",
      separabilities=separabilities,
      name="targets_time")

  if attention_type == "transformer":
    # Drop the singleton axis for the attention layers.
    targets_timed = tf.squeeze(targets_timed, 2)
    target_shape = tf.shape(targets_timed)
    targets_segment = tf.zeros([target_shape[0], target_shape[1]])
    # Lower-triangular bias for the self-attention over targets.
    target_attention_bias = common_attention.attention_bias(
        targets_segment, targets_segment, lower_triangular=True)
    # All-zero bias for encoder-decoder attention (`bias` is not used here).
    inputs_attention_bias = tf.zeros([
        tf.shape(inputs_encoded)[0], hparams.num_heads,
        tf.shape(targets_segment)[1],
        tf.shape(inputs_encoded)[1]
    ])

    qv = common_attention.multihead_attention(
        targets_timed,
        None,
        target_attention_bias,
        hparams.hidden_size,
        hparams.hidden_size,
        hparams.hidden_size,
        hparams.num_heads,
        hparams.attention_dropout,
        name="self_attention")
    qv = common_attention.multihead_attention(
        qv,
        inputs_encoded,
        inputs_attention_bias,
        hparams.hidden_size,
        hparams.hidden_size,
        hparams.hidden_size,
        hparams.num_heads,
        hparams.attention_dropout,
        name="encdec_attention")
    return tf.expand_dims(qv, 2)

  # attention_type == "simple" (guaranteed by the validation above).
  targets_with_attention = common_layers.simple_attention(
      targets_timed, inputs_encoded, bias=bias)
  return norm_fn(targets_shifted + targets_with_attention, name="attn_norm")