def universal_transformer_all_steps_so_far(layer_inputs,
                                           step,
                                           hparams,
                                           ffn_unit,
                                           attention_unit):
  """Universal Transformer which attends over the states of all previous steps.

  It uses an attention mechanism (flipped vertically) over all the states
  from the previous steps to generate the new_state.

  Args:
    layer_inputs:
      - state: state
      - inputs: the original embedded inputs (= inputs to the first step)
      - memory: contains states from all the previous steps.
    step: indicates the number of steps taken so far
    hparams: model hyper-parameters.
    ffn_unit: feed-forward unit
    attention_unit: multi-head attention unit

  Returns:
    layer_output:
      new_state: new state
      inputs: the original embedded inputs (= inputs to the first step)
      memory: contains states from all the previous steps.
  """
  _, inputs, memory = layer_inputs
  all_states = memory

  # Get the states up to the current step (non-zero part of the memory).
  states_so_far = all_states[:step, :, :, :]

  states_so_far_weights = tf.nn.softmax(
      common_layers.dense(
          states_so_far, (hparams.hidden_size if hparams.dwa_elements else 1),
          activation=None,
          use_bias=True),
      axis=-1)

  # # Get a summary of the step weights.
  # step_weights = tf.unstack(states_so_far_weights, axis=0,
  #                           name="step_weights")
  # for step_i, step_w in enumerate(step_weights):
  #   tf.contrib.summary.scalar("step_%d_weight" % step_i,
  #                             tf.reduce_mean(step_w))

  # Prepare the state as a weighted summary of all previous states.
  state_to_be_transformed = tf.reduce_sum(
      (states_so_far * states_so_far_weights), axis=0)

  state_to_be_transformed = universal_transformer_util.step_preprocess(
      state_to_be_transformed, step, hparams)

  new_state = ffn_unit(attention_unit(state_to_be_transformed))

  # Add the new state to the memory.
  memory = universal_transformer_util.fill_memory_slot(
      memory, new_state, step + 1)

  return new_state, inputs, memory
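For reference, here is a minimal sketch (not the tensor2tensor implementation) of the two memory operations this variant relies on, assuming memory is laid out as [max_steps, batch, length, depth]:

import tensorflow as tf

def weighted_summary_of_states(states_so_far, weight_logits):
  """Softmax-weighted sum of the states collected so far (illustrative only)."""
  weights = tf.nn.softmax(weight_logits, axis=-1)
  return tf.reduce_sum(states_so_far * weights, axis=0)

def fill_memory_slot_sketch(memory, value, index):
  """Writes value into memory[index] without in-place assignment (illustrative only)."""
  max_steps = tf.shape(memory)[0]
  mask = tf.reshape(tf.one_hot(index, depth=max_steps),
                    [-1, 1, 1, 1])  # [max_steps, 1, 1, 1]
  return memory * (1.0 - mask) + tf.expand_dims(value, 0) * mask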
def invertible_universal_transformer_basic(layer_inputs,
                                           step,
                                           hparams,
                                           ffn_unit,
                                           attention_unit):
  """Basic Universal Transformer.

  This model is quite similar to the vanilla transformer, except that weights
  are shared between layers. For some tasks, this simple idea brings a
  generalization that is not achievable by playing with the size of the model
  or the dropout parameters in the vanilla transformer.

  Args:
    layer_inputs:
      - state: state
    step: indicates the number of steps taken so far
    hparams: model hyper-parameters
    ffn_unit: feed-forward unit
    attention_unit: multi-head attention unit

  Returns:
    layer_output:
      new_state: new state
      inputs: the original embedded inputs (passed through unchanged)
      memory: memory (passed through unchanged)
  """
  state, inputs, memory = tf.unstack(layer_inputs,
                                     num=None,
                                     axis=0,
                                     name="unstack")
  new_state = step_preprocess(state, step, hparams)

  for i in range(hparams.num_inrecurrence_layers):
    with tf.variable_scope("rec_layer_%d" % i):
      if isinstance(ffn_unit, list) and isinstance(attention_unit, list):
        for index, sub_ffn_unit, sub_attention_unit in zip(
            range(len(ffn_unit)), ffn_unit, attention_unit):
          if hparams.invertible_share_layer_weights:
            new_state = sub_ffn_unit(sub_attention_unit(new_state))
          else:
            with tf.variable_scope("sublayer_%d" % index):
              new_state = sub_ffn_unit(sub_attention_unit(new_state))
      else:
        new_state = ffn_unit(attention_unit(new_state))

  return new_state, inputs, memory
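A minimal driver sketch showing how a step function like the one above can be iterated in eager mode; the actual code drives the recurrence inside the model body, so run_ut_recurrence and num_rec_steps below are illustrative names, and the sketch assumes state, inputs, and memory all share the same shape:

import tensorflow as tf

def run_ut_recurrence(step_fn, state, inputs, memory, hparams,
                      ffn_unit, attention_unit, num_rec_steps):
  """Applies a UT step function for a fixed number of recurrent steps."""
  new_state = state
  for step in range(num_rec_steps):
    layer_inputs = tf.stack([new_state, inputs, memory], axis=0)
    new_state, inputs, memory = step_fn(layer_inputs, step, hparams,
                                        ffn_unit, attention_unit)
  return new_state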
def universal_transformer_basic_plus_lstm(layer_inputs,
                                          step,
                                          hparams,
                                          ffn_unit,
                                          attention_unit,
                                          pad_remover=None):
  """The UT layer which uses an LSTM as the transition function.

  Args:
    layer_inputs:
      - state: state
      - inputs: the original embedded inputs (= inputs to the first step)
      - memory: memory used in the LSTM.
    step: indicates the number of steps taken so far
    hparams: model hyper-parameters.
    ffn_unit: feed-forward unit
    attention_unit: multi-head attention unit
    pad_remover: an optional utility that excludes padding positions from the
      feed-forward computations (efficiency).

  Returns:
    layer_output:
      new_state: new state
      inputs: the original embedded inputs (= inputs to the first step)
      memory: contains information of the state from all the previous steps.
  """
  state, unused_inputs, memory = tf.unstack(layer_inputs,
                                            num=None,
                                            axis=0,
                                            name="unstack")
  # NOTE:
  # state (ut_state): output of the lstm in the previous step
  # inputs (ut_input): original input --> we don't use it here
  # memory: lstm memory

  # Multi-head attention:
  assert not hparams.add_step_timing_signal  # Let the lstm count for us!
  mh_attention_input = universal_transformer_util.step_preprocess(
      state, step, hparams)
  transition_function_input = ffn_unit(attention_unit(mh_attention_input))

  # Transition function:
  transition_function_input = common_layers.layer_preprocess(
      transition_function_input, hparams)
  with tf.variable_scope("lstm"):
    # lstm input gate: i_t = sigmoid(W_i.x_t + U_i.h_{t-1})
    transition_function_input_gate = _ffn_layer_multi_inputs(
        [transition_function_input, state],
        hparams,
        name="input",
        bias_initializer=tf.zeros_initializer(),
        activation=tf.sigmoid,
        pad_remover=pad_remover)

    tf.contrib.summary.scalar("lstm_input_gate",
                              tf.reduce_mean(transition_function_input_gate))

    # lstm forget gate: f_t = sigmoid(W_f.x_t + U_f.h_{t-1})
    transition_function_forget_gate = _ffn_layer_multi_inputs(
        [transition_function_input, state],
        hparams,
        name="forget",
        bias_initializer=tf.zeros_initializer(),
        activation=None,
        pad_remover=pad_remover)

    forget_bias_tensor = tf.constant(hparams.lstm_forget_bias)
    transition_function_forget_gate = tf.sigmoid(
        transition_function_forget_gate + forget_bias_tensor)

    tf.contrib.summary.scalar("lstm_forget_gate",
                              tf.reduce_mean(transition_function_forget_gate))

    # lstm output gate: o_t = sigmoid(W_o.x_t + U_o.h_{t-1})
    transition_function_output_gate = _ffn_layer_multi_inputs(
        [transition_function_input, state],
        hparams,
        name="output",
        bias_initializer=tf.zeros_initializer(),
        activation=tf.sigmoid,
        pad_remover=pad_remover)

    tf.contrib.summary.scalar("lstm_output_gate",
                              tf.reduce_mean(transition_function_output_gate))

    # lstm input modulation: j_t = tanh(W_j.x_t + U_j.h_{t-1})
    transition_function_input_modulation = _ffn_layer_multi_inputs(
        [transition_function_input, state],
        hparams,
        name="input_modulation",
        bias_initializer=tf.zeros_initializer(),
        activation=tf.tanh,
        pad_remover=pad_remover)

    transition_function_memory = (
        memory * transition_function_forget_gate +
        transition_function_input_gate * transition_function_input_modulation)

    transition_function_output = (
        tf.tanh(transition_function_memory) * transition_function_output_gate)

  transition_function_output = common_layers.layer_preprocess(
      transition_function_output, hparams)

  return transition_function_output, unused_inputs, transition_function_memory
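As a sanity check on the gate algebra above, a tiny NumPy illustration of the cell update c_t = f_t * c_{t-1} + i_t * j_t and h_t = o_t * tanh(c_t), with hard-coded gate values standing in for the learned projections:

import numpy as np

c_prev = np.array([0.2, -0.5])          # previous lstm memory
i_t, f_t, o_t = 0.9, 0.4, 0.7           # input / forget / output gates in (0, 1)
j_t = np.tanh(np.array([1.0, -1.0]))    # input modulation in (-1, 1)

c_t = f_t * c_prev + i_t * j_t          # new memory (transition_function_memory)
h_t = o_t * np.tanh(c_t)                # new state (transition_function_output)
print(c_t, h_t)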
def universal_transformer_basic_plus_gru(layer_inputs,
                                         step,
                                         hparams,
                                         ffn_unit,
                                         attention_unit,
                                         pad_remover=None):
  """The UT layer which uses a GRU as the transition function.

  Args:
    layer_inputs:
      - state: state
      - inputs: the original embedded inputs (= inputs to the first step)
      - memory: not used here (kept for interface compatibility).
    step: indicates the number of steps taken so far
    hparams: model hyper-parameters.
    ffn_unit: feed-forward unit
    attention_unit: multi-head attention unit
    pad_remover: an optional utility that excludes padding positions from the
      feed-forward computations (efficiency).

  Returns:
    layer_output:
      new_state: new state
      inputs: not used
      memory: not used
  """
  state, unused_inputs, unused_memory = tf.unstack(layer_inputs,
                                                   num=None,
                                                   axis=0,
                                                   name="unstack")
  # NOTE:
  # state (ut_state): output of the gru in the previous step
  # inputs (ut_inputs): original input --> we don't use it here
  # memory: we don't use it here

  # Multi-head attention:
  assert not hparams.add_step_timing_signal  # Let the gru count for us!
  mh_attention_input = universal_transformer_util.step_preprocess(
      state, step, hparams)
  transition_function_input = ffn_unit(attention_unit(mh_attention_input))

  # Transition function:
  transition_function_input = common_layers.layer_preprocess(
      transition_function_input, hparams)
  with tf.variable_scope("gru"):
    # gru update gate: z_t = sigmoid(W_z.x_t + U_z.h_{t-1})
    transition_function_update_gate = _ffn_layer_multi_inputs(
        [transition_function_input, state],
        hparams,
        name="update",
        bias_initializer=tf.constant_initializer(1.0),
        activation=tf.sigmoid,
        pad_remover=pad_remover)

    tf.contrib.summary.scalar("gru_update_gate",
                              tf.reduce_mean(transition_function_update_gate))

    # gru reset gate: r_t = sigmoid(W_r.x_t + U_r.h_{t-1})
    transition_function_reset_gate = _ffn_layer_multi_inputs(
        [transition_function_input, state],
        hparams,
        name="reset",
        bias_initializer=tf.constant_initializer(1.0),
        activation=tf.sigmoid,
        pad_remover=pad_remover)

    tf.contrib.summary.scalar("gru_reset_gate",
                              tf.reduce_mean(transition_function_reset_gate))

    reset_state = transition_function_reset_gate * state

    # gru candidate activation: h' = tanh(W.x_t + U.(r_t * h_{t-1}))
    transition_function_candidate = _ffn_layer_multi_inputs(
        [transition_function_input, reset_state],
        hparams,
        name="candidate",
        bias_initializer=tf.zeros_initializer(),
        activation=tf.tanh,
        pad_remover=pad_remover)

    transition_function_output = (
        (1 - transition_function_update_gate) * transition_function_input +
        transition_function_update_gate * transition_function_candidate)

  transition_function_output = common_layers.layer_preprocess(
      transition_function_output, hparams)

  return transition_function_output, unused_inputs, unused_memory
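Similarly, a tiny NumPy illustration of the interpolation implemented above, h_t = (1 - z_t) * x_t + z_t * h'_t, where the reset gate has already been folded into the candidate h'_t; values are hard-coded purely for illustration:

import numpy as np

x_t = np.array([0.3, -0.2])                   # transition-function input
h_candidate = np.tanh(np.array([0.8, -0.1]))  # candidate activation h'
z_t = 0.6                                     # update gate activation in (0, 1)

h_t = (1 - z_t) * x_t + z_t * h_candidate
print(h_t)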
def main():
  FLAGS = Args()

  # Enable TF Eager execution.
  tfe = tf.contrib.eager
  tfe.enable_eager_execution()

  # Sample sentence.
  input_str = ('Twas brillig, and the slithy toves Did gyre and gimble in the '
               'wade; All mimsy were the borogoves, And the mome raths '
               'outgrabe.')

  # Convert the sentence into indices in the vocab.
  wmt_problem = problems.problem(FLAGS.problem)
  encoders = wmt_problem.feature_encoders(FLAGS.data_dir)
  inputs = encoders["inputs"].encode(input_str) + [1]  # add EOS id
  batch_inputs = tf.reshape(inputs, [1, -1, 1])  # Make it 3D.
  features = {"inputs": batch_inputs}

  # Initialize the translation model.
  hparams_set = FLAGS.hparams_set
  Modes = tf.estimator.ModeKeys
  hparams = trainer_lib.create_hparams(hparams_set,
                                       data_dir=FLAGS.data_dir,
                                       problem_name=FLAGS.problem)
  translate_model = registry.model(FLAGS.model)(hparams, Modes.EVAL)

  # Restore the parameters and run the recurrent computation.
  ckpt_dir = tf.train.latest_checkpoint(FLAGS.model_dir)
  with tfe.restore_variables_on_create(ckpt_dir):
    with variable_scope.EagerVariableStore().as_default():
      with tf.variable_scope('universal_transformer'):
        # Convert word indices to word embeddings.
        features = translate_model.bottom(features)

      with tf.variable_scope('universal_transformer/body'):
        input_tensor = tf.convert_to_tensor(features['inputs'])
        input_tensor = common_layers.flatten4d3d(input_tensor)
        encoder_input, self_attention_bias, _ = (
            transformer.transformer_prepare_encoder(
                input_tensor,
                tf.convert_to_tensor([0]),
                translate_model.hparams,
                features=None))

      with tf.variable_scope('universal_transformer/body/encoder'):
        ffn_unit = functools.partial(
            universal_transformer_util.transformer_encoder_ffn_unit,
            hparams=translate_model.hparams)
        attention_unit = functools.partial(
            universal_transformer_util.transformer_encoder_attention_unit,
            hparams=translate_model.hparams,
            encoder_self_attention_bias=None,
            attention_dropout_broadcast_dims=[],
            save_weights_to={},
            make_image_summary=True)

      storing_list = []
      transformed_state = encoder_input
      for step_index in range(1024):
        storing_list.append(transformed_state.numpy())

        with tf.variable_scope(
            'universal_transformer/body/encoder/universal_transformer_{}'
            .format(FLAGS.ut_type)):
          transformed_state = universal_transformer_util.step_preprocess(
              transformed_state,
              tf.convert_to_tensor(step_index % FLAGS.step_num),
              translate_model.hparams)
        with tf.variable_scope(
            'universal_transformer/body/encoder/universal_transformer_{}/rec_layer_0'
            .format(FLAGS.ut_type)):
          transformed_new_state = ffn_unit(attention_unit(transformed_state))
        with tf.variable_scope('universal_transformer/body/encoder'):
          if (step_index + 1) % FLAGS.step_num == 0:
            transformed_new_state = common_layers.layer_preprocess(
                transformed_new_state, translate_model.hparams)
          if step_index == 5:
            print(transformed_new_state)

        transformed_state = transformed_new_state

      storing_list = np.asarray(storing_list)
      np.save(FLAGS.save_dir, storing_list)
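main() assumes an Args class that plays the role of command-line flags. Its definition is not shown in this section; the sketch below is a hypothetical placeholder whose values only indicate the expected fields, not the authors' actual configuration.

class Args(object):
  problem = 'translate_ende_wmt32k'           # placeholder problem name
  data_dir = '/path/to/t2t_data'              # placeholder
  model = 'universal_transformer'
  hparams_set = 'universal_transformer_base'  # placeholder hparams set
  model_dir = '/path/to/checkpoints'          # placeholder
  ut_type = 'basic'                           # suffix used in the variable scopes
  step_num = 6                                # recurrent steps per full pass
  save_dir = '/path/to/stored_states.npy'     # output path for np.save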
def main():
  FLAGS = Args()

  # Enable TF Eager execution.
  tfe = tf.contrib.eager
  tfe.enable_eager_execution()

  batch_inputs = input_generator()

  # Initialize the translation model.
  hparams_set = FLAGS.hparams_set
  Modes = tf.estimator.ModeKeys
  hparams = trainer_lib.create_hparams(hparams_set,
                                       data_dir=FLAGS.data_dir,
                                       problem_name=FLAGS.problem)
  translate_model = registry.model(FLAGS.model)(hparams, Modes.EVAL)

  # Restore the parameters and run the recurrent computation.
  ckpt_dir = tf.train.latest_checkpoint(FLAGS.model_dir)
  with tfe.restore_variables_on_create(ckpt_dir):
    with variable_scope.EagerVariableStore().as_default():
      features = {'inputs': batch_inputs}

      with tf.variable_scope('universal_transformer/body'):
        input_tensor = tf.convert_to_tensor(features['inputs'])
        input_tensor = common_layers.flatten4d3d(input_tensor)
        encoder_input, self_attention_bias, _ = (
            transformer.transformer_prepare_encoder(
                input_tensor,
                tf.convert_to_tensor([0]),
                translate_model.hparams,
                features=None))

      with tf.variable_scope('universal_transformer/body/encoder'):
        ffn_unit = functools.partial(
            universal_transformer_util.transformer_encoder_ffn_unit,
            hparams=translate_model.hparams)
        attention_unit = functools.partial(
            universal_transformer_util.transformer_encoder_attention_unit,
            hparams=translate_model.hparams,
            encoder_self_attention_bias=None,
            attention_dropout_broadcast_dims=[],
            save_weights_to={},
            make_image_summary=True)

      storing_list = []
      transformed_state = encoder_input
      for step_index in range(1024):
        storing_list.append(transformed_state.numpy())

        with tf.variable_scope(
            'universal_transformer/body/encoder/universal_transformer_{}'
            .format(FLAGS.ut_type)):
          transformed_state = universal_transformer_util.step_preprocess(
              transformed_state,
              tf.convert_to_tensor(step_index % FLAGS.step_num),
              translate_model.hparams)
        with tf.variable_scope(
            'universal_transformer/body/encoder/universal_transformer_{}/rec_layer_0'
            .format(FLAGS.ut_type)):
          transformed_new_state = ffn_unit(attention_unit(transformed_state))
        with tf.variable_scope('universal_transformer/body/encoder'):
          if (step_index + 1) % FLAGS.step_num == 0:
            transformed_new_state = common_layers.layer_preprocess(
                transformed_new_state, translate_model.hparams)
          if step_index == 5:
            print(transformed_new_state)

        transformed_state = transformed_new_state

      storing_list = np.asarray(storing_list)
      np.save(FLAGS.save_dir, storing_list)
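This second script replaces the hand-encoded sentence with input_generator(), which is defined elsewhere. Because its output is fed straight into common_layers.flatten4d3d without an embedding lookup, it presumably yields already-embedded inputs of shape [batch, length, 1, hidden_size]; the stand-in below is a hypothetical sketch under that assumption.

import numpy as np
import tensorflow as tf

def input_generator(batch_size=1, length=20, hidden_size=512):
  """Hypothetical stand-in: random embedded inputs, [batch, length, 1, hidden]."""
  embedded = np.random.normal(size=(batch_size, length, 1, hidden_size))
  return tf.convert_to_tensor(embedded, dtype=tf.float32)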