示例#1
0
def modelInitVAEUndirectionalEncoder(cells,
                                     encoder_inputs,
                                     inputs_lengths,
                                     outputs_type='concat',
                                     states_type='last'):
    """Build a unidirectional RNN encoder whose outputs are VAE samples.

    The raw RNN outputs are split along the feature axis into a mean half
    and a deviation half, and a Gaussian sample is drawn via the
    reparameterization trick. A KL-style regularizer is returned alongside.

    Args:
        cells: list of RNN cells, wrapped into an ExtendedMultiRNNCell.
        encoder_inputs: time-major inputs — assumes [time, batch, dim]
            since time_major=True is passed below.
        inputs_lengths: int tensor of true sequence lengths per batch entry.
        outputs_type: unused here; kept for signature parity with the
            bidirectional encoder variants in this file.
        states_type: only 'last' is supported — return the final state of
            the top RNN layer.

    Returns:
        (encoder_outputs, encoder_states, vae_loss): the sampled outputs,
        an LSTMStateTuple for the decoder's initial state, and the VAE
        regularization loss.

    Raises:
        Exception: if states_type is not 'last'.
    """
    encoder_cell = ExtendedMultiRNNCell(cells)
    encoder_outputs, encoder_state = tf.nn.dynamic_rnn(
        cell=encoder_cell,
        inputs=encoder_inputs,
        sequence_length=inputs_lengths,
        time_major=True,
        dtype=tf.float32)
    # Split outputs into mean and (shifted) deviation halves.
    o_m, o_v = tf.split(encoder_outputs, 2, 2)
    # +1 offset keeps the deviation term away from zero (matches the
    # bidirectional VAE encoder in this file).
    o_v = o_v + 1
    # Reparameterization trick: sample = eps * sigma + mu.
    encoder_outputs = tf.random_normal(tf.shape(o_m)) * o_v + o_m
    vae_loss = -0.5 * tf.reduce_mean(
        tf.reduce_sum(1.0 + tf.log(o_v * o_v) - o_m * o_m - o_v * o_v, -1))
    if states_type == 'last':
        # BUG FIX: this branch previously referenced encoder_fw_states /
        # encoder_bw_states, which are never defined in this function
        # (copy-paste from the bidirectional encoder) and would raise a
        # NameError. A unidirectional encoder has a single state stack;
        # use the top layer's final state directly.
        encoder_states = tf.contrib.rnn.LSTMStateTuple(
            c=encoder_state[-1].c, h=encoder_state[-1].h)
    else:
        raise Exception('Unknown states type.')

    return encoder_outputs, encoder_states, vae_loss
示例#2
0
def modelInitBidirectionalEncoder(cells,
                                  encoder_inputs,
                                  inputs_lengths,
                                  encoder_type='dynamic',
                                  outputs_type='concat',
                                  states_type='last'):
    """Build a bidirectional RNN encoder.

    Args:
        cells: list of RNN cells; deep-copied so forward and backward
            directions get independent cell instances.
        encoder_inputs: time-major inputs — assumes [time, batch, dim].
        inputs_lengths: int tensor of true sequence lengths per batch entry.
        encoder_type: 'dynamic' uses tf.nn.bidirectional_dynamic_rnn over a
            single multi-layer cell per direction; 'stack' uses
            tf.contrib.rnn.stack_bidirectional_dynamic_rnn (batch-major, so
            inputs/outputs are transposed around the call).
        outputs_type: unused here; outputs are always the fw/bw concat.
        states_type: only 'last' is supported — average the final fw and bw
            states of the top layer.

    Returns:
        (encoder_outputs, encoder_states): time-major concatenated outputs
        and an LSTMStateTuple of averaged final states.

    Raises:
        Exception: on unknown encoder_type or states_type.
    """
    fw_cells = copy.deepcopy(cells)
    bw_cells = copy.deepcopy(cells)

    if encoder_type == 'dynamic':
        outputs_pair, states_pair = tf.nn.bidirectional_dynamic_rnn(
            cell_fw=ExtendedMultiRNNCell(fw_cells),
            cell_bw=ExtendedMultiRNNCell(bw_cells),
            inputs=encoder_inputs,
            sequence_length=inputs_lengths,
            time_major=True,
            dtype=tf.float32)
        fw_states, bw_states = states_pair
        encoder_outputs = tf.concat(outputs_pair, 2)
    elif encoder_type == 'stack':
        # stack_bidirectional_dynamic_rnn is batch-major: transpose in and
        # back out to keep the time-major convention of this file.
        stacked_outputs, fw_states, bw_states = (
            tf.contrib.rnn.stack_bidirectional_dynamic_rnn(
                cells_fw=fw_cells,
                cells_bw=bw_cells,
                inputs=tf.transpose(encoder_inputs, [1, 0, 2]),
                sequence_length=inputs_lengths,
                dtype=tf.float32))
        encoder_outputs = tf.transpose(stacked_outputs, [1, 0, 2])
    else:
        raise Exception('Unknown encoder type.')

    if states_type != 'last':
        raise Exception('Unknown states type.')

    # Average the forward and backward final states of the top layer.
    mean_c = (fw_states[-1].c + bw_states[-1].c) / 2.0
    mean_h = (fw_states[-1].h + bw_states[-1].h) / 2.0
    encoder_states = tf.contrib.rnn.LSTMStateTuple(c=mean_c, h=mean_h)

    return encoder_outputs, encoder_states
示例#3
0
def modelInitAttentionDecoderCell(cells,
                                  hidden_size,
                                  encoder_outputs,
                                  encoder_outputs_lengths,
                                  advance=True,
                                  att_type='LUONG',
                                  wrapper_type='whole'):
    """Wrap decoder RNN cells with an attention mechanism.

    Args:
        cells: list of decoder RNN cells. NOTE(review): in 'gnmt' mode the
            first cell is popped out of the caller's list (mutation) — the
            attention wrapper goes around it and the remaining cells are
            attached via GNMTAttentionMultiCell.
        hidden_size: attention depth and attention layer size.
        encoder_outputs: time-major encoder outputs; transposed to
            batch-major before being used as attention memory.
        encoder_outputs_lengths: true memory lengths for masking.
        advance: enables 'scale' for Luong attention or 'normalize' for
            Bahdanau attention.
        att_type: 'LUONG' or 'BAHDANAU'.
        wrapper_type: 'whole' wraps the entire multi-cell; 'gnmt' wraps
            only the bottom cell, GNMT-style.

    Returns:
        The attention-wrapped decoder cell.

    Raises:
        Exception: on unknown wrapper_type or att_type.
    """
    if wrapper_type == 'whole':
        base_cell = ExtendedMultiRNNCell(cells)
    elif wrapper_type == 'gnmt':
        base_cell = cells.pop(0)
    else:
        raise Exception('Unknown wrapper type.')

    if att_type == 'LUONG':
        attention = tf.contrib.seq2seq.LuongAttention(
            hidden_size,
            tf.transpose(encoder_outputs, perm=[1, 0, 2]),
            memory_sequence_length=encoder_outputs_lengths,
            scale=advance)
    elif att_type == 'BAHDANAU':
        attention = tf.contrib.seq2seq.BahdanauAttention(
            hidden_size,
            tf.transpose(encoder_outputs, perm=[1, 0, 2]),
            memory_sequence_length=encoder_outputs_lengths,
            normalize=advance)
    else:
        raise Exception('Unknown attention type.')

    wrapped = tf.contrib.seq2seq.AttentionWrapper(
        base_cell,
        attention,
        attention_layer_size=hidden_size,
        output_attention=True)

    # In GNMT mode, stack the remaining (non-attended) cells on top of the
    # attention-wrapped bottom cell.
    if wrapper_type == 'gnmt':
        wrapped = GNMTAttentionMultiCell(wrapped, cells)

    return wrapped
示例#4
0
def modelInitUndirectionalEncoder(cells,
                                  encoder_inputs,
                                  inputs_lengths,
                                  outputs_type='concat',
                                  states_type='last'):
    """Build a plain unidirectional RNN encoder.

    Args:
        cells: list of RNN cells, wrapped into an ExtendedMultiRNNCell.
        encoder_inputs: time-major inputs — assumes [time, batch, dim]
            since time_major=True is passed below.
        inputs_lengths: int tensor of true sequence lengths per batch entry.
        outputs_type: unused here; kept for signature parity with the
            bidirectional encoder variants in this file.
        states_type: only 'last' is supported — return the final state of
            the top RNN layer.

    Returns:
        (encoder_outputs, encoder_states): the raw RNN outputs and an
        LSTMStateTuple for the decoder's initial state.

    Raises:
        Exception: if states_type is not 'last'.
    """
    encoder_cell = ExtendedMultiRNNCell(cells)
    encoder_outputs, encoder_state = tf.nn.dynamic_rnn(
        cell=encoder_cell,
        inputs=encoder_inputs,
        sequence_length=inputs_lengths,
        time_major=True,
        dtype=tf.float32)
    if states_type == 'last':
        # BUG FIX: this branch previously referenced encoder_fw_states /
        # encoder_bw_states, which are never defined in this function
        # (copy-paste from the bidirectional encoder) and would raise a
        # NameError. A unidirectional encoder has a single state stack;
        # use the top layer's final state directly.
        encoder_states = tf.contrib.rnn.LSTMStateTuple(
            c=encoder_state[-1].c, h=encoder_state[-1].h)
    else:
        raise Exception('Unknown states type.')

    return encoder_outputs, encoder_states
示例#5
0
def modelInitRNNDecoderCell(cells):
    """Combine the given decoder cells into one multi-layer RNN cell."""
    decoder_cell = ExtendedMultiRNNCell(cells)
    return decoder_cell
示例#6
0
def _vaeSampleBidirectionalOutputs(fw_outputs, bw_outputs):
    """Split fw/bw encoder outputs into mean/deviation halves, draw a
    Gaussian sample via the reparameterization trick, and compute the
    VAE regularization loss.

    Returns:
        (sampled_outputs, vae_loss).
    """
    fwo_m, fwo_v = tf.split(fw_outputs, 2, 2)
    bwo_m, bwo_v = tf.split(bw_outputs, 2, 2)
    eo_m = tf.concat([fwo_m, bwo_m], -1)
    # +1 offset keeps the deviation term away from zero.
    eo_v = tf.concat([fwo_v, bwo_v], -1) + 1
    # Reparameterization trick: sample = eps * sigma + mu.
    sampled = tf.random_normal(tf.shape(eo_m)) * eo_v + eo_m
    vae_loss = -0.5 * tf.reduce_mean(
        tf.reduce_sum(
            1.0 + tf.log(eo_v * eo_v) - eo_m * eo_m - eo_v * eo_v, -1))
    return sampled, vae_loss


def modelInitVAEBidirectionalEncoder(cells,
                                     encoder_inputs,
                                     inputs_lengths,
                                     encoder_type='dynamic',
                                     outputs_type='concat',
                                     states_type='last'):
    """Build a bidirectional RNN encoder whose outputs are VAE samples.

    The fw/bw outputs are split into mean and deviation halves and a
    Gaussian sample is drawn (reparameterization trick); the sampling and
    loss computation — previously duplicated verbatim in both branches —
    is factored into _vaeSampleBidirectionalOutputs.

    Args:
        cells: list of RNN cells; deep-copied so forward and backward
            directions get independent cell instances.
        encoder_inputs: time-major inputs — assumes [time, batch, dim].
        inputs_lengths: int tensor of true sequence lengths per batch entry.
        encoder_type: 'dynamic' uses tf.nn.bidirectional_dynamic_rnn;
            'stack' uses tf.contrib.rnn.stack_bidirectional_dynamic_rnn
            (batch-major, so inputs/outputs are transposed around the call).
        outputs_type: unused here; outputs are always the fw/bw concat.
        states_type: only 'last' is supported — average the final fw and bw
            states of the top layer.

    Returns:
        (encoder_outputs, encoder_states, vae_loss): sampled outputs, an
        LSTMStateTuple of averaged final states, and the VAE loss.

    Raises:
        Exception: on unknown encoder_type or states_type.
    """
    encoder_fw_cell = copy.deepcopy(cells)
    encoder_bw_cell = copy.deepcopy(cells)
    encoder_outputs, encoder_fw_states, encoder_bw_states = None, None, None
    vae_loss = 0
    if encoder_type == 'dynamic':
        encoder_fw_cell = ExtendedMultiRNNCell(encoder_fw_cell)
        encoder_bw_cell = ExtendedMultiRNNCell(encoder_bw_cell)
        ((encoder_fw_outputs, encoder_bw_outputs),
         (encoder_fw_states,
          encoder_bw_states)) = tf.nn.bidirectional_dynamic_rnn(
              cell_fw=encoder_fw_cell,
              cell_bw=encoder_bw_cell,
              inputs=encoder_inputs,
              sequence_length=inputs_lengths,
              time_major=True,
              dtype=tf.float32)
        encoder_outputs, vae_loss = _vaeSampleBidirectionalOutputs(
            encoder_fw_outputs, encoder_bw_outputs)
    elif encoder_type == 'stack':
        (encoder_outputs, encoder_fw_states,
         encoder_bw_states) = tf.contrib.rnn.stack_bidirectional_dynamic_rnn(
             cells_fw=encoder_fw_cell,
             cells_bw=encoder_bw_cell,
             inputs=tf.transpose(encoder_inputs, [1, 0, 2]),
             sequence_length=inputs_lengths,
             dtype=tf.float32)
        # Back to time-major, then recover the fw/bw halves of the concat.
        encoder_outputs = tf.transpose(encoder_outputs, [1, 0, 2])
        encoder_fw_outputs, encoder_bw_outputs = tf.split(
            encoder_outputs, 2, 2)
        encoder_outputs, vae_loss = _vaeSampleBidirectionalOutputs(
            encoder_fw_outputs, encoder_bw_outputs)
    else:
        raise Exception('Unknown encoder type.')

    encoder_states = None
    if states_type == 'last':
        # Average the forward and backward final states of the top layer.
        encoder_state_c = (encoder_fw_states[-1].c +
                           encoder_bw_states[-1].c) / 2.0
        encoder_state_h = (encoder_fw_states[-1].h +
                           encoder_bw_states[-1].h) / 2.0
        encoder_states = tf.contrib.rnn.LSTMStateTuple(c=encoder_state_c,
                                                       h=encoder_state_h)
    else:
        raise Exception('Unknown states type.')

    return encoder_outputs, encoder_states, vae_loss