Example #1
def universal_transformer_basic_plus_lstm():
    """Base Universal Transformer hparams with the "basic_plus_lstm" recurrence."""
    hparams = universal_transformer.universal_transformer_base()
    hparams.recurrence_type = "basic_plus_lstm"
    # hparams.transformer_ffn_type = "fc"
    hparams.batch_size = 2048
    hparams.add_step_timing_signal = False  # Let the LSTM keep track of steps.
    return hparams
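Each of these functions starts from universal_transformer.universal_transformer_base() and only overrides a few fields. A minimal sketch of how such a set is typically registered so it can be selected by name, assuming tensor2tensor's registry module; the variant name and batch size below are made up for illustration:

from tensor2tensor.models.research import universal_transformer
from tensor2tensor.utils import registry


@registry.register_hparams
def universal_transformer_basic_plus_lstm_small_batch():
    """Hypothetical variant: same recurrence, smaller batch."""
    hparams = universal_transformer_basic_plus_lstm()
    hparams.batch_size = 1024
    return hparams

Once registered, the set can be chosen by name, e.g. via t2t-trainer's --hparams_set flag.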
Example #2
  def testTransformer(self):
    model, features = self.get_model(
        universal_transformer.universal_transformer_base())
    logits, _ = model(features)
    with self.test_session() as session:
      session.run(tf.global_variables_initializer())
      res = session.run(logits)
    # T2T logits come back with two dummy spatial dims of size 1.
    self.assertEqual(res.shape, (BATCH_SIZE, TARGET_LENGTH, 1, 1, VOCAB_SIZE))
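The test leans on a get_model() helper and the BATCH_SIZE, TARGET_LENGTH and VOCAB_SIZE constants from the surrounding test module, which are not shown here. A rough sketch of the kind of fake feature dict such a helper builds, with placeholder constants rather than the values tensor2tensor actually uses:

import numpy as np
import tensorflow as tf

BATCH_SIZE, INPUT_LENGTH, TARGET_LENGTH, VOCAB_SIZE = 3, 5, 7, 10


def make_fake_features():
    """Random integer ids in T2T's (batch, length, 1, 1) layout."""
    inputs = np.random.randint(
        VOCAB_SIZE, size=(BATCH_SIZE, INPUT_LENGTH, 1, 1))
    targets = np.random.randint(
        VOCAB_SIZE, size=(BATCH_SIZE, TARGET_LENGTH, 1, 1))
    return {
        "inputs": tf.constant(inputs, dtype=tf.int32),
        "targets": tf.constant(targets, dtype=tf.int32),
    }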
def vqa_recurrent_self_attention_base():
    """VQA attention baseline hparams."""
    hparams = universal_transformer.universal_transformer_base()
    hparams.batch_size = 1024
    hparams.use_fixed_batch_size = True
    hparams.weight_decay = 0.
    hparams.clip_grad_norm = 0.
    # use default initializer
    # hparams.initializer = "xavier"
    hparams.learning_rate_schedule = (
        "constant*linear_warmup*rsqrt_normalized_decay")
    hparams.learning_rate_warmup_steps = 8000
    hparams.learning_rate_constant = 7e-4
    hparams.learning_rate_decay_rate = 0.5
    hparams.learning_rate_decay_steps = 50000
    # hparams.dropout = 0.5
    hparams.summarize_grads = True
    hparams.summarize_vars = True

    # unused hparams
    hparams.label_smoothing = 0.1
    hparams.multiply_embedding_mode = "sqrt_depth"

    # new hparams added for the VQA setup
    # image input: precomputed features ("feature") rather than raw pixels
    hparams.add_hparam("image_input_type", "feature")
    hparams.add_hparam("image_model_fn", "resnet_v1_152")
    hparams.add_hparam("resize_side", 512)
    hparams.add_hparam("height", 448)
    hparams.add_hparam("width", 448)
    hparams.add_hparam("distort", True)
    hparams.add_hparam("train_resnet", False)

    # question hidden size
    # hparams.hidden_size = 512
    # hparams.filter_size = 1024
    # hparams.num_hidden_layers = 4

    # self attention parts
    # hparams.norm_type = "layer"
    # hparams.layer_preprocess_sequence = "n"
    # hparams.layer_postprocess_sequence = "da"
    # hparams.layer_prepostprocess_dropout = 0.1
    # hparams.attention_dropout = 0.1
    # hparams.relu_dropout = 0.1
    # hparams.add_hparam("pos", "timing")
    # hparams.add_hparam("num_encoder_layers", 0)
    # hparams.add_hparam("num_decoder_layers", 0)
    # hparams.add_hparam("num_heads", 8)
    # hparams.add_hparam("attention_key_channels", 0)
    # hparams.add_hparam("attention_value_channels", 0)
    # hparams.add_hparam("self_attention_type", "dot_product")

    # iterative part
    hparams.transformer_ffn_type = "fc"

    return hparams
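Fields created with add_hparam() behave like the built-in ones afterwards. A minimal sketch, assuming the standard HParams API that tensor2tensor builds on, of reading the new image fields back and overriding them from a comma-separated string (the override values shown are arbitrary):

hparams = vqa_recurrent_self_attention_base()
print(hparams.image_input_type, hparams.height, hparams.width)  # feature 448 448

# Overrides in "name=value,..." form, e.g. what gets passed through --hparams.
hparams.parse("height=224,width=224,train_resnet=True")
print(hparams.height, hparams.train_resnet)  # 224 True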
Example #6
def universal_transformer_all_steps_so_far():
    """Base Universal Transformer hparams with the "all_steps_so_far" recurrence."""
    hparams = universal_transformer.universal_transformer_base()
    hparams.recurrence_type = "all_steps_so_far"
    return hparams
Example #7
def universal_transformer_with_lstm_as_transition_function():
    """Base Universal Transformer hparams with an LSTM as the transition function."""
    hparams = universal_transformer.universal_transformer_base()
    hparams.recurrence_type = "lstm"
    hparams.add_step_timing_signal = False  # Let the LSTM keep track of steps.
    return hparams
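Since every variant starts from universal_transformer_base(), it is easy to see exactly which fields a given set overrides. A minimal sketch, assuming the standard HParams.values() accessor:

from tensor2tensor.models.research import universal_transformer

base = universal_transformer.universal_transformer_base().values()
lstm = universal_transformer_with_lstm_as_transition_function().values()

overrides = {k: v for k, v in lstm.items() if base.get(k) != v}
print(overrides)  # e.g. the recurrence_type and add_step_timing_signal overrides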