Example #1
def universal_transformer_all_steps_so_far(layer_inputs, step, hparams,
                                           ffn_unit, attention_unit):
    """universal_transformer.
  It uses an attention mechanism-flipped vertically-
  over all the states from previous steps to generate the new_state.
  Args:
    layer_inputs:
      - state: state
      - memory: contains states from all the previous steps.
    step: indicating number of steps take so far
    hparams: model hyper-parameters.
    ffn_unit: feed-forward unit
    attention_unit: multi-head attention unit
  Returns:
    layer_output:
        new_state: new state
        memory: contains states from all the previous steps.
  """
    _, inputs, memory = layer_inputs
    all_states = memory
    # get the states up to the current step (non-zero part of the memory)
    states_so_far = all_states[:step, :, :, :]

    states_so_far_weights = tf.nn.softmax(
        common_layers.dense(
            states_so_far,
            (hparams.hidden_size if hparams.dwa_elements else 1),
            activation=None,
            use_bias=True),
        axis=-1)

    # # get a summary of the step weights
    # step_weights = tf.unstack(states_so_far_weights, axis=0, name="step_weights")
    # for step_i, step_w in enumerate(step_weights):
    #   tf.contrib.summary.scalar("step_%d_weight" % step_i,
    #                             tf.reduce_mean(step_w))

    # prepare the state to be transformed as the weighted sum of all states so far
    state_to_be_transformed = tf.reduce_sum(
        (states_so_far * states_so_far_weights), axis=0)

    state_to_be_transformed = universal_transformer_util.step_preprocess(
        state_to_be_transformed, step, hparams)

    new_state = ffn_unit(attention_unit(state_to_be_transformed))

    # add the new state to the memory
    memory = universal_transformer_util.fill_memory_slot(
        memory, new_state, step + 1)

    return new_state, inputs, memory
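
A minimal NumPy sketch of the weighted-sum step above (the shapes and the
random projection standing in for common_layers.dense are illustrative
assumptions, not part of the original code):

import numpy as np

def summarize_states(memory, step, proj):
    """Blend the first `step` memory slots into one state.

    memory: [max_steps, batch, length, hidden] array of saved states.
    step:   number of filled slots; the rest of memory is still zero.
    proj:   [hidden, hidden] stand-in for the dense layer's kernel.
    """
    states_so_far = memory[:step]                        # [step, b, l, h]
    logits = states_so_far @ proj                        # dense(..., hidden)
    e = np.exp(logits - logits.max(axis=-1, keepdims=True))
    weights = e / e.sum(axis=-1, keepdims=True)          # softmax, last axis
    return (states_so_far * weights).sum(axis=0)         # [b, l, h]

rng = np.random.default_rng(0)
memory = np.zeros((6, 2, 5, 8))
memory[:3] = rng.standard_normal((3, 2, 5, 8))
proj = rng.standard_normal((8, 8))
print(summarize_states(memory, 3, proj).shape)  # (2, 5, 8)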
Example #2
def invertible_universal_transformer_basic(layer_inputs, step, hparams,
                                           ffn_unit, attention_unit):
    """Basic Universal Transformer.
  This model is pretty similar to the vanilla transformer in which weights are
  shared between layers. For some tasks, this simple idea brings a
  generalization that is not achievable by playing with the size of the model
  or drop_out parameters in the vanilla transformer.
  Args:
    layer_inputs:
        - state: state
    step: indicates number of steps taken so far
    hparams: model hyper-parameters
    ffn_unit: feed-forward unit
    attention_unit: multi-head attention unit
  Returns:
    layer_output:
         new_state: new state
  """
    state, inputs, memory = tf.unstack(layer_inputs,
                                       num=None,
                                       axis=0,
                                       name="unstack")
    new_state = step_preprocess(state, step, hparams)

    for i in range(hparams.num_inrecurrence_layers):
        with tf.variable_scope("rec_layer_%d" % i):
            if isinstance(ffn_unit, list) and isinstance(attention_unit, list):
                for index, (sub_ffn_unit, sub_attention_unit) in enumerate(
                        zip(ffn_unit, attention_unit)):
                    if hparams.invertible_share_layer_weights:
                        new_state = sub_ffn_unit(sub_attention_unit(new_state))
                    else:
                        with tf.variable_scope("sublayer_%d" % index):
                            new_state = sub_ffn_unit(
                                sub_attention_unit(new_state))
            else:
                new_state = ffn_unit(attention_unit(new_state))

    return new_state, inputs, memory
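
The weight sharing described in the docstring amounts to applying one layer
function, with one set of parameters, at every recurrent step; a minimal
plain-Python illustration (all names and shapes below are made up):

import numpy as np

rng = np.random.default_rng(0)
w = 0.1 * rng.standard_normal((8, 8))   # ONE weight matrix for all steps
x = rng.standard_normal((4, 8))

def shared_layer(x):
    # stand-in for ffn_unit(attention_unit(...)): an affine map + nonlinearity
    return np.tanh(x @ w)

for step in range(6):   # depth comes from recurrence, not from new parameters
    x = shared_layer(x)
print(x.shape)  # (4, 8)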
Example #3
def universal_transformer_basic_plus_lstm(layer_inputs,
                                          step,
                                          hparams,
                                          ffn_unit,
                                          attention_unit,
                                          pad_remover=None):
    """The UT layer which uses a lstm as transition function.

  Args:
    layer_inputs:
      - state: state
      - inputs: the original embedded inputs (= inputs to the first step)
      - memory: memory used in lstm.
    step: indicating number of steps take so far
    hparams: model hyper-parameters.
    ffn_unit: feed-forward unit
    attention_unit: multi-head attention unit

  Returns:
    layer_output:
        new_state: new state
        inputs: the original embedded inputs (= inputs to the first step)
        memory: contains information of state from all the previous steps.
  """

    state, unused_inputs, memory = tf.unstack(layer_inputs,
                                              num=None,
                                              axis=0,
                                              name="unstack")
    # NOTE:
    # state (ut_state): output of the lstm in the previous step
    # inputs (ut_input): original input --> we don't use it here
    # memory: lstm memory

    # Multi-head attention:
    assert not hparams.add_step_timing_signal  # let the LSTM count for us!
    mh_attention_input = universal_transformer_util.step_preprocess(
        state, step, hparams)
    transition_function_input = ffn_unit(attention_unit(mh_attention_input))

    # Transition Function:
    transition_function_input = common_layers.layer_preprocess(
        transition_function_input, hparams)
    with tf.variable_scope("lstm"):
        # lstm input gate: i_t = sigmoid(W_i.x_t + U_i.h_{t-1})
        transition_function_input_gate = _ffn_layer_multi_inputs(
            [transition_function_input, state],
            hparams,
            name="input",
            bias_initializer=tf.zeros_initializer(),
            activation=tf.sigmoid,
            pad_remover=pad_remover)

        tf.contrib.summary.scalar(
            "lstm_input_gate", tf.reduce_mean(transition_function_input_gate))

        # lstm forget gate: f_t = sigmoid(W_f.x_t + U_f.h_{t-1})
        transition_function_forget_gate = _ffn_layer_multi_inputs(
            [transition_function_input, state],
            hparams,
            name="forget",
            bias_initializer=tf.zeros_initializer(),
            activation=None,
            pad_remover=pad_remover)
        forget_bias_tensor = tf.constant(hparams.lstm_forget_bias)
        transition_function_forget_gate = tf.sigmoid(
            transition_function_forget_gate + forget_bias_tensor)

        tf.contrib.summary.scalar(
            "lstm_forget_gate",
            tf.reduce_mean(transition_function_forget_gate))

        # lstm output gate: o_t = sigmoid(W_o.x_t + U_o.h_{t-1})
        transition_function_output_gate = _ffn_layer_multi_inputs(
            [transition_function_input, state],
            hparams,
            name="output",
            bias_initializer=tf.zeros_initializer(),
            activation=tf.sigmoid,
            pad_remover=pad_remover)

        tf.contrib.summary.scalar(
            "lstm_output_gate",
            tf.reduce_mean(transition_function_output_gate))

        # lstm input modulation
        transition_function_input_modulation = _ffn_layer_multi_inputs(
            [transition_function_input, state],
            hparams,
            name="input_modulation",
            bias_initializer=tf.zeros_initializer(),
            activation=tf.tanh,
            pad_remover=pad_remover)

        transition_function_memory = (
            memory * transition_function_forget_gate +
            transition_function_input_gate *
            transition_function_input_modulation)

        transition_function_output = (tf.tanh(transition_function_memory) *
                                      transition_function_output_gate)

    transition_function_output = common_layers.layer_preprocess(
        transition_function_output, hparams)

    return transition_function_output, unused_inputs, transition_function_memory
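
The gate arithmetic inside the "lstm" scope follows the standard LSTM cell
equations, with _ffn_layer_multi_inputs playing the role of the W.x + U.h
projections. A self-contained NumPy rendering of just that arithmetic
(weights, shapes, and the helper itself are illustrative assumptions):

import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def lstm_transition(x, h_prev, c_prev, W, U, forget_bias=1.0):
    """Mirror of the gates above. W and U map keys 'i', 'f', 'o', 'g'
    to [d, d] matrices."""
    i = sigmoid(x @ W['i'] + h_prev @ U['i'])                # input gate
    f = sigmoid(x @ W['f'] + h_prev @ U['f'] + forget_bias)  # forget gate
    o = sigmoid(x @ W['o'] + h_prev @ U['o'])                # output gate
    g = np.tanh(x @ W['g'] + h_prev @ U['g'])                # input modulation
    c = f * c_prev + i * g      # memory * forget_gate + input_gate * modulation
    h = np.tanh(c) * o          # new state
    return h, c

rng = np.random.default_rng(0)
d = 8
W = {k: 0.1 * rng.standard_normal((d, d)) for k in 'ifog'}
U = {k: 0.1 * rng.standard_normal((d, d)) for k in 'ifog'}
h, c = lstm_transition(rng.standard_normal((2, d)),
                       np.zeros((2, d)), np.zeros((2, d)), W, U)
print(h.shape, c.shape)  # (2, 8) (2, 8)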
Example #4
def universal_transformer_basic_plus_gru(layer_inputs,
                                         step,
                                         hparams,
                                         ffn_unit,
                                         attention_unit,
                                         pad_remover=None):
    """The UT layer which uses a gru as transition function.

  Args:
    layer_inputs:
      - state: state
      - inputs: the original embedded inputs (= inputs to the first step)
      - memory: memory used in lstm.
    step: indicating number of steps take so far
    hparams: model hyper-parameters.
    ffn_unit: feed-forward unit
    attention_unit: multi-head attention unit

  Returns:
    layer_output:
        new_state: new state
        inputs: not uesed
        memory: not used
  """

    state, unused_inputs, unused_memory = tf.unstack(layer_inputs,
                                                     num=None,
                                                     axis=0,
                                                     name="unstack")
    # NOTE:
    # state (ut_state): output of the gru in the previous step
    # inputs (ut_inputs): original input --> we don't use it here
    # memory: we don't use it here

    # Multi-head attention:
    assert not hparams.add_step_timing_signal  # let the GRU count for us!
    mh_attention_input = universal_transformer_util.step_preprocess(
        state, step, hparams)
    transition_function_input = ffn_unit(attention_unit(mh_attention_input))

    # Transition Function:
    transition_function_input = common_layers.layer_preprocess(
        transition_function_input, hparams)
    with tf.variable_scope("gru"):
        # gru update gate: z_t = sigmoid(W_z.x_t + U_z.h_{t-1})
        transition_function_update_gate = _ffn_layer_multi_inputs(
            [transition_function_input, state],
            hparams,
            name="update",
            bias_initializer=tf.constant_initializer(1.0),
            activation=tf.sigmoid,
            pad_remover=pad_remover)

        tf.contrib.summary.scalar(
            "gru_update_gate", tf.reduce_mean(transition_function_update_gate))

        # gru reset gate: r_t = sigmoid(W_r.x_t + U_r.h_{t-1})
        transition_function_reset_gate = _ffn_layer_multi_inputs(
            [transition_function_input, state],
            hparams,
            name="reset",
            bias_initializer=tf.constant_initializer(1.0),
            activation=tf.sigmoid,
            pad_remover=pad_remover)

        tf.contrib.summary.scalar(
            "gru_reset_gate", tf.reduce_mean(transition_function_reset_gate))
        reset_state = transition_function_reset_gate * state

        # gru candidate activation: h' = tanh(W_h.x_t + U_h.(r_t * h_{t-1}))
        transition_function_candidate = _ffn_layer_multi_inputs(
            [transition_function_input, reset_state],
            hparams,
            name="candidate",
            bias_initializer=tf.zeros_initializer(),
            activation=tf.tanh,
            pad_remover=pad_remover)

        transition_function_output = (
            (1 - transition_function_update_gate) * transition_function_input +
            transition_function_update_gate * transition_function_candidate)

    transition_function_output = common_layers.layer_preprocess(
        transition_function_output, hparams)

    return transition_function_output, unused_inputs, unused_memory
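
For comparison, the corresponding GRU arithmetic. Note one deviation from the
textbook cell: the final interpolation above mixes the candidate with
transition_function_input (the attention output) rather than with the previous
state h_{t-1}. A NumPy sketch of exactly what the code computes (weights and
shapes are illustrative):

import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def gru_transition(x, h_prev, W, U):
    """Mirror of the gates above. W and U map keys 'z', 'r', 'h'
    to [d, d] matrices."""
    z = sigmoid(x @ W['z'] + h_prev @ U['z'])              # update gate
    r = sigmoid(x @ W['r'] + h_prev @ U['r'])              # reset gate
    h_cand = np.tanh(x @ W['h'] + (r * h_prev) @ U['h'])   # candidate
    # as in the code above: interpolate with x, not with h_prev
    return (1 - z) * x + z * h_cand

rng = np.random.default_rng(0)
d = 8
W = {k: 0.1 * rng.standard_normal((d, d)) for k in 'zrh'}
U = {k: 0.1 * rng.standard_normal((d, d)) for k in 'zrh'}
print(gru_transition(rng.standard_normal((2, d)), np.zeros((2, d)), W, U).shape)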
Example #5
def main():

  FLAGS = Args()

  # Enable TF Eager execution
  tfe = tf.contrib.eager
  tfe.enable_eager_execution()

  # sample sentence
  input_str = 'Twas brillig, and the slithy toves Did gyre and gimble in the wabe; All mimsy were the borogoves, And the mome raths outgrabe.'

  # convert sentence into index in vocab
  wmt_problem = problems.problem(FLAGS.problem)
  encoders = wmt_problem.feature_encoders(FLAGS.data_dir)
  inputs = encoders["inputs"].encode(input_str) + [1]  # add EOS id
  batch_inputs = tf.reshape(inputs, [1, -1, 1])  # Make it 3D.
  features = {"inputs": batch_inputs}

  # initialize translation model
  hparams_set = FLAGS.hparams_set
  Modes = tf.estimator.ModeKeys
  hparams = trainer_lib.create_hparams(hparams_set, data_dir=FLAGS.data_dir, problem_name=FLAGS.problem)
  translate_model = registry.model(FLAGS.model)(hparams, Modes.EVAL)

  # restore parameters and run the recurrent computation
  ckpt_dir = tf.train.latest_checkpoint(FLAGS.model_dir)

  with tfe.restore_variables_on_create(ckpt_dir):
    with variable_scope.EagerVariableStore().as_default():
      with tf.variable_scope('universal_transformer'):
        # Convert word index to word embedding
        features = translate_model.bottom(features)

      with tf.variable_scope('universal_transformer/body'):
        input_tensor = tf.convert_to_tensor(features['inputs'])
        input_tensor = common_layers.flatten4d3d(input_tensor)
        encoder_input, self_attention_bias, _ = (
          transformer.transformer_prepare_encoder(
            input_tensor, tf.convert_to_tensor([0]), translate_model.hparams, features=None))

      with tf.variable_scope('universal_transformer/body/encoder'):

        ffn_unit = functools.partial(
          universal_transformer_util.transformer_encoder_ffn_unit,
          hparams=translate_model.hparams)

        attention_unit = functools.partial(
          universal_transformer_util.transformer_encoder_attention_unit,
          hparams=translate_model.hparams,
          encoder_self_attention_bias=None,
          attention_dropout_broadcast_dims=[],
          save_weights_to={},
          make_image_summary=True)

      storing_list = []
      transformed_state = encoder_input
      for step_index in range(1024):
        storing_list.append(transformed_state.numpy())

        with tf.variable_scope('universal_transformer/body/encoder/universal_transformer_{}'.format(FLAGS.ut_type)):
          transformed_state = universal_transformer_util.step_preprocess(
            transformed_state,
            tf.convert_to_tensor(step_index % FLAGS.step_num),
            translate_model.hparams
          )
        with tf.variable_scope('universal_transformer/body/encoder/universal_transformer_{}/rec_layer_0'.format(FLAGS.ut_type)):
          transformed_new_state = ffn_unit(attention_unit(transformed_state))
        with tf.variable_scope('universal_transformer/body/encoder'):
          if (step_index + 1) % FLAGS.step_num == 0:
            transformed_new_state = common_layers.layer_preprocess(transformed_new_state, translate_model.hparams)

            if step_index == 5:
              print(transformed_new_state)

        transformed_state = transformed_new_state
      storing_list = np.asarray(storing_list)
      np.save(FLAGS.save_dir, storing_list)
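
The script assumes an Args container supplying the flags it reads; a minimal
hypothetical stand-in (every value below is a placeholder, not taken from the
original):

class Args:
    """Hypothetical flag holder with the attributes main() reads."""
    problem = 'translate_ende_wmt32k'          # placeholder T2T problem name
    data_dir = '/path/to/t2t_data'             # placeholder paths
    model_dir = '/path/to/checkpoints'
    save_dir = '/path/to/encoder_states.npy'
    model = 'universal_transformer'
    hparams_set = 'universal_transformer_base'
    ut_type = 'basic'                          # matches the variable scopes above
    step_num = 6                               # recurrent steps per pass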
Example #6
def main():

    FLAGS = Args()

    # Enable TF Eager execution
    tfe = tf.contrib.eager
    tfe.enable_eager_execution()

    batch_inputs = input_generator()

    # initialize translation model
    hparams_set = FLAGS.hparams_set
    Modes = tf.estimator.ModeKeys
    hparams = trainer_lib.create_hparams(hparams_set,
                                         data_dir=FLAGS.data_dir,
                                         problem_name=FLAGS.problem)
    translate_model = registry.model(FLAGS.model)(hparams, Modes.EVAL)

    # restore parameters and run the recurrent computation
    ckpt_dir = tf.train.latest_checkpoint(FLAGS.model_dir)

    with tfe.restore_variables_on_create(ckpt_dir):
        with variable_scope.EagerVariableStore().as_default():
            features = {'inputs': batch_inputs}
            with tf.variable_scope('universal_transformer/body'):
                input_tensor = tf.convert_to_tensor(features['inputs'])
                input_tensor = common_layers.flatten4d3d(input_tensor)
                encoder_input, self_attention_bias, _ = (
                    transformer.transformer_prepare_encoder(
                        input_tensor,
                        tf.convert_to_tensor([0]),
                        translate_model.hparams,
                        features=None))

            with tf.variable_scope('universal_transformer/body/encoder'):

                ffn_unit = functools.partial(
                    universal_transformer_util.transformer_encoder_ffn_unit,
                    hparams=translate_model.hparams)

                attention_unit = functools.partial(
                    universal_transformer_util.transformer_encoder_attention_unit,
                    hparams=translate_model.hparams,
                    encoder_self_attention_bias=None,
                    attention_dropout_broadcast_dims=[],
                    save_weights_to={},
                    make_image_summary=True)

            storing_list = []
            transformed_state = encoder_input
            for step_index in range(1024):
                storing_list.append(transformed_state.numpy())

                with tf.variable_scope(
                        'universal_transformer/body/encoder/universal_transformer_{}'
                        .format(FLAGS.ut_type)):
                    transformed_state = universal_transformer_util.step_preprocess(
                        transformed_state,
                        tf.convert_to_tensor(step_index % FLAGS.step_num),
                        translate_model.hparams)
                with tf.variable_scope(
                        'universal_transformer/body/encoder/universal_transformer_{}/rec_layer_0'
                        .format(FLAGS.ut_type)):
                    transformed_new_state = ffn_unit(
                        attention_unit(transformed_state))
                with tf.variable_scope('universal_transformer/body/encoder'):
                    if (step_index + 1) % FLAGS.step_num == 0:
                        transformed_new_state = common_layers.layer_preprocess(
                            transformed_new_state, translate_model.hparams)

                        if step_index == 5:
                            print(transformed_new_state)

                transformed_state = transformed_new_state
            storing_list = np.asarray(storing_list)
            np.save(FLAGS.save_dir, storing_list)
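
input_generator is not defined in this snippet. Since its output goes straight
into flatten4d3d without passing through translate_model.bottom, it must
already yield a 4-D float tensor of embeddings; a hypothetical stand-in
(batch, length, and hidden size are arbitrary placeholders):

import numpy as np
import tensorflow as tf

def input_generator(batch=1, length=16, hidden=512):
    """Hypothetical stand-in producing a [batch, length, 1, hidden]
    embedding tensor for the encoder."""
    rng = np.random.default_rng(0)
    return tf.convert_to_tensor(
        rng.standard_normal((batch, length, 1, hidden)).astype(np.float32))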