def modelInitVAEUndirectionalEncoder(cells, encoder_inputs, inputs_lengths,
                                     outputs_type='concat', states_type='last'):
    """Build a unidirectional RNN encoder whose outputs are VAE-reparameterized samples.

    Args:
        cells: list of RNN cells, stacked into an ExtendedMultiRNNCell.
        encoder_inputs: time-major inputs — presumably [time, batch, depth]
            since dynamic_rnn is called with time_major=True.
        inputs_lengths: per-example sequence lengths.
        outputs_type: unused here; kept for signature parity with the
            bidirectional encoder variants.
        states_type: only 'last' is supported.

    Returns:
        Tuple (encoder_outputs, encoder_states, vae_loss):
            encoder_outputs — sample drawn via the reparameterization trick
                from the (mean, deviation) halves of the raw RNN outputs;
            encoder_states — LSTMStateTuple of the top layer's final state;
            vae_loss — KL-style regularization term.

    Raises:
        Exception: if states_type is not 'last'.
    """
    encoder_cell = ExtendedMultiRNNCell(cells)
    encoder_outputs, encoder_state = tf.nn.dynamic_rnn(
        cell=encoder_cell,
        inputs=encoder_inputs,
        sequence_length=inputs_lengths,
        time_major=True,
        dtype=tf.float32)
    # Split raw outputs along the feature axis into mean / deviation halves.
    o_m, o_v = tf.split(encoder_outputs, 2, 2)
    # Shift the deviation away from zero (same convention as the
    # bidirectional VAE encoder in this file).
    o_v = o_v + 1
    # Reparameterization trick: sample = noise * deviation + mean.
    encoder_outputs = tf.random_normal(tf.shape(o_m)) * o_v + o_m
    vae_loss = -0.5 * tf.reduce_mean(
        tf.reduce_sum(1.0 + tf.log(o_v * o_v) - o_m * o_m - o_v * o_v, -1))
    if states_type == 'last':
        # BUG FIX: the original referenced encoder_fw_states/encoder_bw_states,
        # which only exist in the bidirectional encoders — a NameError here.
        # This encoder's state is `encoder_state` (a per-layer tuple from
        # MultiRNNCell); take the top layer's final LSTM state.
        encoder_states = tf.contrib.rnn.LSTMStateTuple(
            c=encoder_state[-1].c, h=encoder_state[-1].h)
    else:
        raise Exception('Unknown states type.')
    return encoder_outputs, encoder_states, vae_loss
def modelInitBidirectionalEncoder(cells, encoder_inputs, inputs_lengths,
                                  encoder_type='dynamic',
                                  outputs_type='concat', states_type='last'):
    """Build a bidirectional RNN encoder over time-major inputs.

    Args:
        cells: list of RNN cells; deep-copied so the forward and backward
            directions get independent cell instances.
        encoder_inputs: time-major inputs (dynamic_rnn runs with
            time_major=True; the 'stack' path transposes to batch-major).
        inputs_lengths: per-example sequence lengths.
        encoder_type: 'dynamic' (bidirectional_dynamic_rnn) or 'stack'
            (stack_bidirectional_dynamic_rnn).
        outputs_type: unused; kept for signature compatibility.
        states_type: only 'last' is supported — the forward and backward
            final states of the top layer are averaged.

    Returns:
        Tuple (encoder_outputs, encoder_states): time-major concatenated
        fw/bw outputs, and an averaged LSTMStateTuple.

    Raises:
        Exception: on an unknown encoder_type or states_type.
    """
    fw_cells = copy.deepcopy(cells)
    bw_cells = copy.deepcopy(cells)
    encoder_outputs = None
    encoder_fw_states = None
    encoder_bw_states = None
    if encoder_type == 'dynamic':
        fw_cells = ExtendedMultiRNNCell(fw_cells)
        bw_cells = ExtendedMultiRNNCell(bw_cells)
        outputs_pair, states_pair = tf.nn.bidirectional_dynamic_rnn(
            cell_fw=fw_cells,
            cell_bw=bw_cells,
            inputs=encoder_inputs,
            sequence_length=inputs_lengths,
            time_major=True,
            dtype=tf.float32)
        encoder_fw_states, encoder_bw_states = states_pair
        # Concatenate forward and backward outputs along the feature axis.
        encoder_outputs = tf.concat(outputs_pair, 2)
    elif encoder_type == 'stack':
        # stack_bidirectional_dynamic_rnn is batch-major, so transpose in
        # and back out to keep this function's time-major contract.
        stacked = tf.contrib.rnn.stack_bidirectional_dynamic_rnn(
            cells_fw=fw_cells,
            cells_bw=bw_cells,
            inputs=tf.transpose(encoder_inputs, [1, 0, 2]),
            sequence_length=inputs_lengths,
            dtype=tf.float32)
        encoder_outputs, encoder_fw_states, encoder_bw_states = stacked
        encoder_outputs = tf.transpose(encoder_outputs, [1, 0, 2])
    else:
        raise Exception('Unknown encoder type.')
    if states_type != 'last':
        raise Exception('Unknown states type.')
    # Average the top layer's forward/backward final states.
    mean_c = (encoder_fw_states[-1].c + encoder_bw_states[-1].c) / 2.0
    mean_h = (encoder_fw_states[-1].h + encoder_bw_states[-1].h) / 2.0
    encoder_states = tf.contrib.rnn.LSTMStateTuple(c=mean_c, h=mean_h)
    return encoder_outputs, encoder_states
def modelInitAttentionDecoderCell(cells, hidden_size, encoder_outputs,
                                  encoder_outputs_lengths, advance=True,
                                  att_type='LUONG', wrapper_type='whole'):
    """Build an attention-wrapped decoder cell over the encoder memory.

    Args:
        cells: list of RNN cells. With wrapper_type='whole' they are stacked
            and the whole stack is attention-wrapped; with 'gnmt' only the
            first cell (popped from the list — the list is mutated) attends,
            and the rest form a GNMTAttentionMultiCell around it.
        hidden_size: attention depth / attention layer size.
        encoder_outputs: time-major memory; transposed to batch-major for
            the attention mechanism.
        encoder_outputs_lengths: memory sequence lengths for masking.
        advance: enables `scale` (Luong) or `normalize` (Bahdanau).
        att_type: 'LUONG' or 'BAHDANAU'.
        wrapper_type: 'whole' or 'gnmt'.

    Returns:
        The attention-wrapped decoder cell.

    Raises:
        Exception: on an unknown wrapper_type or att_type.
    """
    if wrapper_type == 'whole':
        op_cell = ExtendedMultiRNNCell(cells)
    elif wrapper_type == 'gnmt':
        # NOTE: pop(0) mutates the caller's list; the remaining cells are
        # consumed below by GNMTAttentionMultiCell.
        op_cell = cells.pop(0)
    else:
        raise Exception('Unknown wrapper type.')
    # Attention mechanisms expect batch-major memory: [batch, time, depth].
    memory = tf.transpose(encoder_outputs, perm=[1, 0, 2])
    if att_type == 'LUONG':
        attention_mechanism = tf.contrib.seq2seq.LuongAttention(
            hidden_size,
            memory,
            memory_sequence_length=encoder_outputs_lengths,
            scale=advance)
    elif att_type == 'BAHDANAU':
        attention_mechanism = tf.contrib.seq2seq.BahdanauAttention(
            hidden_size,
            memory,
            memory_sequence_length=encoder_outputs_lengths,
            normalize=advance)
    else:
        raise Exception('Unknown attention type.')
    op_cell = tf.contrib.seq2seq.AttentionWrapper(
        op_cell,
        attention_mechanism,
        attention_layer_size=hidden_size,
        output_attention=True)
    if wrapper_type == 'gnmt':
        op_cell = GNMTAttentionMultiCell(op_cell, cells)
    return op_cell
def modelInitUndirectionalEncoder(cells, encoder_inputs, inputs_lengths,
                                  outputs_type='concat', states_type='last'):
    """Build a plain unidirectional RNN encoder.

    Args:
        cells: list of RNN cells, stacked into an ExtendedMultiRNNCell.
        encoder_inputs: time-major inputs — presumably [time, batch, depth]
            since dynamic_rnn is called with time_major=True.
        inputs_lengths: per-example sequence lengths.
        outputs_type: unused here; kept for signature parity with the
            bidirectional encoder variants.
        states_type: only 'last' is supported.

    Returns:
        Tuple (encoder_outputs, encoder_states): raw time-major RNN outputs
        and the top layer's final LSTMStateTuple.

    Raises:
        Exception: if states_type is not 'last'.
    """
    encoder_cell = ExtendedMultiRNNCell(cells)
    encoder_outputs, encoder_state = tf.nn.dynamic_rnn(
        cell=encoder_cell,
        inputs=encoder_inputs,
        sequence_length=inputs_lengths,
        time_major=True,
        dtype=tf.float32)
    if states_type == 'last':
        # BUG FIX: the original referenced encoder_fw_states/encoder_bw_states,
        # which only exist in the bidirectional encoders — a NameError here.
        # This encoder's state is `encoder_state` (a per-layer tuple from
        # MultiRNNCell); take the top layer's final LSTM state.
        encoder_states = tf.contrib.rnn.LSTMStateTuple(
            c=encoder_state[-1].c, h=encoder_state[-1].h)
    else:
        raise Exception('Unknown states type.')
    return encoder_outputs, encoder_states
def modelInitRNNDecoderCell(cells):
    """Stack the given RNN cells into a single multi-layer decoder cell.

    Args:
        cells: list of RNN cells.

    Returns:
        An ExtendedMultiRNNCell wrapping `cells`.
    """
    decoder_cell = ExtendedMultiRNNCell(cells)
    return decoder_cell
def _vaeSampleFromBidirectionalOutputs(fw_outputs, bw_outputs):
    """Reparameterize forward/backward encoder outputs into a VAE sample.

    Each direction's outputs are split along the feature axis into mean and
    deviation halves; the halves are concatenated across directions, the
    deviation is shifted away from zero, and a sample is drawn via the
    reparameterization trick (noise * deviation + mean).

    Returns:
        Tuple (sample, vae_loss) — the sampled outputs and the KL-style
        regularization term.
    """
    fwo_m, fwo_v = tf.split(fw_outputs, 2, 2)
    bwo_m, bwo_v = tf.split(bw_outputs, 2, 2)
    eo_m = tf.concat([fwo_m, bwo_m], -1)
    eo_v = tf.concat([fwo_v, bwo_v], -1) + 1
    sample = tf.random_normal(tf.shape(eo_m)) * eo_v + eo_m
    vae_loss = -0.5 * tf.reduce_mean(
        tf.reduce_sum(
            1.0 + tf.log(eo_v * eo_v) - eo_m * eo_m - eo_v * eo_v, -1))
    return sample, vae_loss


def modelInitVAEBidirectionalEncoder(cells, encoder_inputs, inputs_lengths,
                                     encoder_type='dynamic',
                                     outputs_type='concat',
                                     states_type='last'):
    """Build a bidirectional RNN encoder whose outputs are VAE-reparameterized.

    Same structure as modelInitBidirectionalEncoder, but the concatenated
    fw/bw outputs are treated as (mean, deviation) predictions and replaced
    by a reparameterized sample, with an accompanying VAE loss.

    Args:
        cells: list of RNN cells; deep-copied so the forward and backward
            directions get independent cell instances.
        encoder_inputs: time-major inputs.
        inputs_lengths: per-example sequence lengths.
        encoder_type: 'dynamic' (bidirectional_dynamic_rnn) or 'stack'
            (stack_bidirectional_dynamic_rnn).
        outputs_type: unused; kept for signature compatibility.
        states_type: only 'last' is supported — the forward and backward
            final states of the top layer are averaged.

    Returns:
        Tuple (encoder_outputs, encoder_states, vae_loss).

    Raises:
        Exception: on an unknown encoder_type or states_type.
    """
    encoder_fw_cell = copy.deepcopy(cells)
    encoder_bw_cell = copy.deepcopy(cells)
    encoder_outputs, encoder_fw_states, encoder_bw_states = None, None, None
    vae_loss = 0
    if encoder_type == 'dynamic':
        encoder_fw_cell = ExtendedMultiRNNCell(encoder_fw_cell)
        encoder_bw_cell = ExtendedMultiRNNCell(encoder_bw_cell)
        ((encoder_fw_outputs, encoder_bw_outputs),
         (encoder_fw_states, encoder_bw_states)) = tf.nn.bidirectional_dynamic_rnn(
             cell_fw=encoder_fw_cell,
             cell_bw=encoder_bw_cell,
             inputs=encoder_inputs,
             sequence_length=inputs_lengths,
             time_major=True,
             dtype=tf.float32)
        # The sampling + loss computation was duplicated verbatim in both
        # branches; it now lives in _vaeSampleFromBidirectionalOutputs.
        encoder_outputs, vae_loss = _vaeSampleFromBidirectionalOutputs(
            encoder_fw_outputs, encoder_bw_outputs)
    elif encoder_type == 'stack':
        # stack_bidirectional_dynamic_rnn is batch-major; transpose in and
        # back out to keep the time-major contract.
        (encoder_outputs, encoder_fw_states,
         encoder_bw_states) = tf.contrib.rnn.stack_bidirectional_dynamic_rnn(
             cells_fw=encoder_fw_cell,
             cells_bw=encoder_bw_cell,
             inputs=tf.transpose(encoder_inputs, [1, 0, 2]),
             sequence_length=inputs_lengths,
             dtype=tf.float32)
        encoder_outputs = tf.transpose(encoder_outputs, [1, 0, 2])
        # The stacked RNN returns fw/bw already concatenated; split them
        # back apart before the shared sampling step.
        encoder_fw_outputs, encoder_bw_outputs = tf.split(
            encoder_outputs, 2, 2)
        encoder_outputs, vae_loss = _vaeSampleFromBidirectionalOutputs(
            encoder_fw_outputs, encoder_bw_outputs)
    else:
        raise Exception('Unknown encoder type.')
    encoder_states = None
    if states_type == 'last':
        # Average the top layer's forward/backward final states.
        encoder_state_c = (encoder_fw_states[-1].c +
                           encoder_bw_states[-1].c) / 2.0
        encoder_state_h = (encoder_fw_states[-1].h +
                           encoder_bw_states[-1].h) / 2.0
        encoder_states = tf.contrib.rnn.LSTMStateTuple(
            c=encoder_state_c, h=encoder_state_h)
    else:
        raise Exception('Unknown states type.')
    return encoder_outputs, encoder_states, vae_loss