Example No. 1
def sc_lstm_decoder(decoder_input, nclasses, sample_out_size, lstm_size,
                    text_idx, text_one_hot, dialogue_act, inputs, step):
    def remove_last_column(x):
        return x[:, :-1, :]

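    # Shift the one-hot target sequence one step to the right: pad a zero
    # vector at the front and drop the last timestep, so at step t the LSTM
    # is conditioned on the character from step t-1.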
    padding = ZeroPadding1D(padding=(1, 0))(text_one_hot)
    previous_char_slice = Lambda(remove_last_column,
                                 output_shape=(sample_out_size,
                                               nclasses))(padding)

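    # The softmax temperature is annealed as 1/step and passed to SC_LSTM;
    # smaller values sharpen the output distribution as training progresses.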
    temperature = 1 / step

    lstm = SC_LSTM(lstm_size,
                   nclasses,
                   softmax_temperature=temperature,
                   generation_only=True,
                   condition_on_ptm1=True,
                   semantic_condition=True,
                   return_da=False,
                   return_state=False,
                   use_bias=True,
                   return_sequences=True,
                   implementation=2,
                   dropout=0.2,
                   recurrent_dropout=0.2,
                   sc_dropout=0.2)

    recurrent_component = lstm([previous_char_slice, dialogue_act])

    decoder_train = Model(inputs=[decoder_input, text_idx] + inputs,
                          outputs=recurrent_component,
                          name='decoder_{}'.format('train'))
    #decoder_test = Model(inputs=[decoder_input, text_idx] + inputs, outputs=recurrent_component, name='decoder_{}'.format('test'))
    # decoder_train.summary()
    return decoder_train, decoder_train  # test decoder is the train decoder here (decoder_test is commented out above)
Example No. 2
def vae_model(config_data, vocab, step):
    sample_out_size = config_data['max_output_length']
    nclasses = len(vocab) + 3
    #last available index is reserved as start character
    lstm_size = config_data['lstm_size']
    max_idx = max(vocab.values())
    dummy_word_idx = max_idx + 1
    dropout_word_idx = max_idx + 1
    top_paths = 10

    l2_regularizer = None
    # == == == == == =
    # Define Encoder
    # == == == == == =
    name_idx = Input(batch_shape=(None, 2), dtype='float32', name='name_idx')
    eat_type_idx = Input(batch_shape=(None, 4),
                         dtype='float32',
                         name='eat_type_idx')
    price_range_idx = Input(batch_shape=(None, 7),
                            dtype='float32',
                            name='price_range_idx')
    customer_feedback_idx = Input(batch_shape=(None, 7),
                                  dtype='float32',
                                  name='customer_feedback_idx')
    near_idx = Input(batch_shape=(None, 2), dtype='float32', name='near_idx')
    food_idx = Input(batch_shape=(None, 2), dtype='float32', name='food_idx')
    area_idx = Input(batch_shape=(None, 3), dtype='float32', name='area_idx')
    family_idx = Input(batch_shape=(None, 3),
                       dtype='float32',
                       name='family_idx')
    fw_idx = Input(batch_shape=(None, 40), dtype='float32', name='fw_idx')
    output_idx = Input(batch_shape=(None, sample_out_size),
                       dtype='int32',
                       name='character_output')

    inputs = [
        name_idx, eat_type_idx, price_range_idx, customer_feedback_idx,
        near_idx, food_idx, area_idx, family_idx, fw_idx
    ]
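    # Word dropout replaces target characters with dropout_word_idx; rate=1.0
    # combined with anneal_step suggests the effective rate is annealed over
    # training steps (WordDropout is a custom layer of this project).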
    word_dropout = WordDropout(rate=1.0,
                               dummy_word=dropout_word_idx,
                               anneal_step=step)(output_idx)

    one_hot_weights = np.identity(nclasses)

    one_hot_out_embeddings = Embedding(input_length=sample_out_size,
                                       input_dim=nclasses,
                                       output_dim=nclasses,
                                       weights=[one_hot_weights],
                                       trainable=False,
                                       name='one_hot_out_embeddings')

    output_one_hot_embeddings = one_hot_out_embeddings(word_dropout)

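    # The dialogue act is the concatenation of the one-hot slot encodings
    # (name, eatType, priceRange, customer feedback, near, food, area,
    # familyFriendly and the fw_idx features).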
    dialogue_act = concatenate(inputs=inputs)

    def remove_last_column(x):
        return x[:, :-1, :]

    padding = ZeroPadding1D(padding=(1, 0))(output_one_hot_embeddings)
    previous_char_slice = Lambda(remove_last_column,
                                 output_shape=(sample_out_size,
                                               nclasses))(padding)

    #combined_input = concatenate(inputs=[softmax_auxiliary, previous_char_slice], axis=2)
    #MUST BE IMPLEMENTATION 1 or 2
    lstm = SC_LSTM(lstm_size,
                   nclasses,
                   generation_only=True,
                   condition_on_ptm1=True,
                   return_da=True,
                   return_state=False,
                   use_bias=True,
                   semantic_condition=True,
                   return_sequences=True,
                   implementation=2,
                   dropout=0.2,
                   recurrent_dropout=0.2,
                   sc_dropout=0.2)
    recurrent_component, last_da, da_array = lstm(
        [previous_char_slice, dialogue_act])

    lstm.inference_phase()

    output_gen_layer, _, _ = lstm([previous_char_slice, dialogue_act])

    def vae_cross_ent_loss(args):
        x_truth, x_decoded_final = args
        x_truth_flatten = K.reshape(x_truth, shape=(-1, K.shape(x_truth)[-1]))
        x_decoded_flat = K.reshape(x_decoded_final,
                                   shape=(-1, K.shape(x_decoded_final)[-1]))
        cross_ent = K.categorical_crossentropy(x_decoded_flat, x_truth_flatten)
        cross_ent = K.reshape(cross_ent, shape=(-1, K.shape(x_truth)[1]))
        sum_over_sentences = K.sum(cross_ent, axis=1)
        return sum_over_sentences

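    # The two penalties below roughly follow the dialogue-act cost terms of the
    # SC-LSTM paper (Wen et al., 2015): the first penalizes DA features that are
    # still active at the final timestep, the second penalizes the per-step DA
    # vectors returned by the SC_LSTM layer.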
    def da_loss_fun(args):
        da = args[0]
        sq_da_t = K.square(da)
        sum_sq_da_T = K.sum(sq_da_t, axis=1)
        return sum_sq_da_T

    def da_history_loss_fun(args):
        da_t = args[0]
        zeta = 10e-4
        n = 100
        #shape: batch_size, sample_size
        norm_of_difference = K.sum(K.square(da_t), axis=2)
        n1 = zeta**norm_of_difference
        n2 = n * n1
        return K.sum(n2, axis=1)

    def identity_loss(y_true, y_pred):
        return y_pred

    def argmax_fun(softmax_output):
        return K.argmax(softmax_output, axis=2)

    argmax = Lambda(argmax_fun,
                    output_shape=(sample_out_size, ))(output_gen_layer)
    #beams = CTC_Decoding_layer(sample_out_size, False, top_paths, 100, dummy_word_idx)(output_gen_layer)

    main_loss = Lambda(vae_cross_ent_loss, output_shape=(1, ), name='main')(
        [output_one_hot_embeddings, recurrent_component])
    da_loss = Lambda(da_loss_fun, output_shape=(1, ),
                     name='dialogue_act')([last_da])
    da_history_loss = Lambda(da_history_loss_fun,
                             output_shape=(1, ),
                             name='dialogue_history')([da_array])

    train_model = Model(inputs=inputs + [output_idx],
                        outputs=[main_loss, da_loss, da_history_loss])
    test_model = Model(inputs=inputs + [output_idx], outputs=argmax)

    return train_model, test_model
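A minimal sketch of how the returned models might be trained. The optimizer choice, the toy all-zero batch, and the re-defined identity loss are illustrative assumptions, not part of the original code; config_data and vocab are assumed to be set up as above.
import numpy as np
from keras import backend as K
from keras.optimizers import Adam

step = K.variable(1.0)                       # annealing step used by WordDropout
train_model, test_model = vae_model(config_data, vocab, step)

def identity_loss(y_true, y_pred):           # the model outputs are already per-sample losses
    return y_pred

train_model.compile(optimizer=Adam(lr=1e-3), loss=identity_loss)

n = 16                                       # toy batch of all-zero meaning representations
slot_dims = [2, 4, 7, 7, 2, 2, 3, 3, 40]     # widths of the nine slot inputs defined above
mr_features = [np.zeros((n, d), dtype='float32') for d in slot_dims]
target_chars = np.zeros((n, config_data['max_output_length']), dtype='int32')
dummy = np.zeros((n, 1))
train_model.fit(mr_features + [target_chars], [dummy, dummy, dummy],
                batch_size=n, epochs=1)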
Example No. 3
def get_decoder(decoder_input, nclasses, nfilter, sample_out_size, out_size,
                intermediate_dim, lstm_size, text_idx, text_one_hot,
                dialogue_act, inputs, step):
    decoder_input_layer = Dense(intermediate_dim, name='intermediate_decoding')
    hidden_intermediate_dec = decoder_input_layer(decoder_input)
    decoder_upsample = Dense(int(2 * nfilter * sample_out_size /
                                 4))(hidden_intermediate_dec)
    relu_int = PReLU()(decoder_upsample)
    if K.image_data_format() == 'channels_first':
        output_shape = (2 * nfilter, int(sample_out_size / 4), 1)
    else:
        output_shape = (int(sample_out_size / 4), 1, 2 * nfilter)
    reshape = Reshape(output_shape)(relu_int)
    # shape = (batch_size, filters)
    deconv1 = Conv2DTranspose(filters=nfilter,
                              kernel_size=(3, 1),
                              strides=(2, 1),
                              padding='same')(reshape)
    bn3 = BatchNormalization(scale=False)(deconv1)
    relu3 = PReLU()(bn3)
    deconv2 = Conv2DTranspose(filters=out_size,
                              kernel_size=(3, 1),
                              strides=(2, 1),
                              padding='same')(relu3)
    bn4 = BatchNormalization(scale=False)(deconv2)
    relu4 = PReLU()(bn4)
    reshape = Reshape((sample_out_size, out_size))(relu4)
    logits = Dense(nclasses,
                   activation='softmax',
                   name='auxiliary_softmax_layer')(reshape)
    temperature = 1 / step

    def temperature_log(logits):
        return logits / temperature

    def remove_last_column(x):
        return x[:, :-1, :]

    padding = ZeroPadding1D(padding=(1, 0))(text_one_hot)
    previous_char_slice = Lambda(remove_last_column,
                                 output_shape=(sample_out_size,
                                               nclasses))(padding)

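    # Scale the auxiliary distribution by the annealed temperature and
    # re-normalize it with a softmax; softmax_auxiliary is the CNN decoder's
    # auxiliary character prediction that is fed into the SC-LSTM below.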
    temp_layer = Lambda(temperature_log,
                        output_shape=(sample_out_size, nclasses))(logits)
    softmax_auxiliary = Activation('softmax')(temp_layer)

    lstm = SC_LSTM(lstm_size,
                   nclasses,
                   softmax_temperature=temperature,
                   generation_only=False,
                   condition_on_ptm1=True,
                   semantic_condition=True,
                   return_da=False,
                   return_state=False,
                   use_bias=True,
                   return_sequences=True,
                   implementation=2,
                   dropout=0.2,
                   recurrent_dropout=0.2,
                   sc_dropout=0.2)

    recurrent_component = lstm(
        [softmax_auxiliary, previous_char_slice, dialogue_act])
    #output_gen_layer = lstm([softmax_auxiliary, softmax_auxiliary])  # for testing

    decoder_train = Model(inputs=[decoder_input, text_idx] + inputs,
                          outputs=[recurrent_component, softmax_auxiliary],
                          name='decoder_{}'.format('train'))
    decoder_test = Model(inputs=[decoder_input, text_idx] + inputs,
                         outputs=recurrent_component,
                         name='decoder_{}'.format('test'))
    #decoder_train.summary()
    return decoder_train, decoder_test
Example No. 4
def vae_model(config_data, vocab, step):
    z_size = config_data['z_size']
    sample_in_size = config_data['max_input_length']
    sample_out_size = config_data['max_output_length']
    nclasses = len(vocab) + 2
    #last available index is reserved as start character
    start_word_idx = nclasses - 1
    lstm_size = config_data['lstm_size']
    alpha = config_data['alpha']
    intermediate_dim = config_data['intermediate_dim']
    batch_size = config_data['batch_size']
    nfilter = 64
    out_size = 200
    eps = 0.001
    anneal_start = 0
    anneal_end = anneal_start + 10000.0

    l2_regularizer = None
    # == == == == == =
    # Define Encoder
    # == == == == == =
    name_idx = Input(batch_shape=(None, 2), dtype='float32', name='name_idx')
    eat_type_idx = Input(batch_shape=(None, 4), dtype='float32', name='eat_type_idx')
    price_range_idx = Input(batch_shape=(None, 7), dtype='float32', name='price_range_idx')
    customer_feedback_idx = Input(batch_shape=(None, 7), dtype='float32', name='customer_feedback_idx')
    near_idx = Input(batch_shape=(None, 2), dtype='float32', name='near_idx')
    food_idx = Input(batch_shape=(None, 8), dtype='float32', name='food_idx')
    area_idx = Input(batch_shape=(None, 3), dtype='float32', name='area_idx')
    family_idx = Input(batch_shape=(None, 3), dtype='float32', name='family_idx')
    output_idx = Input(batch_shape=(None, sample_out_size), dtype='int32', name='character_output')

    inputs = [name_idx, eat_type_idx, price_range_idx, customer_feedback_idx, near_idx, food_idx, area_idx, family_idx, output_idx]

    one_hot_weights = np.identity(nclasses)
    #oshape = (batch_size, sample_size, nclasses)
    one_hot_embeddings = Embedding(
        input_length=sample_in_size,
        input_dim=nclasses,
        output_dim=nclasses,
        weights=[one_hot_weights],
        trainable=False,
        name='one_hot_embeddings'
    )

    one_hot_out_embeddings = Embedding(
        input_length=sample_out_size,
        input_dim=nclasses,
        output_dim=nclasses,
        weights=[one_hot_weights],
        trainable=False,
        name='one_hot_out_embeddings'
    )

    name_one_hot_embeddings = one_hot_embeddings(name_idx)
    near_one_hot_embeddings = one_hot_embeddings(near_idx)

    output_one_hot_embeddings = one_hot_out_embeddings(output_idx)

    decoder_input = Input(shape=(z_size,), name='decoder_input')
    encoder, _ , dialogue_act = get_encoder(inputs, name_one_hot_embeddings, near_one_hot_embeddings, nfilter, z_size, intermediate_dim)
    decoder = get_decoder(decoder_input, intermediate_dim, nfilter, sample_out_size, out_size, nclasses)

    x_sampled, x_mean, x_log_sigma = encoder(inputs[:-1])
    softmax_auxiliary = decoder(x_sampled)

    def argmax_fun(softmax_output):
        return K.argmax(softmax_output, axis=2)

    def remove_last_column(x):
        return x[:, :-1, :]

    padding = ZeroPadding1D(padding=(1, 0))(output_one_hot_embeddings)
    previous_char_slice = Lambda(remove_last_column, output_shape=(sample_out_size, nclasses))(padding)

    #combined_input = concatenate(inputs=[softmax_auxiliary, previous_char_slice], axis=2)
    #MUST BE IMPLEMENTATION 1 or 2
    lstm = SC_LSTM(
        lstm_size,
        nclasses,
        return_da=False,
        return_state=False,
        use_bias=True,
        semantic_condition=True,
        return_sequences=True,
        implementation=2,
        dropout=0.2,
        recurrent_dropout=0.2,
        sc_dropout=0.2
    )
    recurrent_component = lstm([softmax_auxiliary, previous_char_slice, dialogue_act])

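    # After inference_phase() the SC-LSTM no longer sees the shifted ground-truth
    # characters; it conditions on the auxiliary softmax predictions instead.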
    lstm.inference_phase()
    output_gen_layer = lstm([softmax_auxiliary, softmax_auxiliary, dialogue_act])

    def vae_cross_ent_loss(args):
        x_truth, x_decoded_final = args
        x_truth_flatten = K.flatten(x_truth)
        x_decoded_flat = K.reshape(x_decoded_final, shape=(-1, K.shape(x_decoded_final)[-1]))
        cross_ent = T.nnet.categorical_crossentropy(x_decoded_flat, x_truth_flatten)
        cross_ent = K.reshape(cross_ent, shape=(-1, K.shape(x_truth)[1]))
        sum_over_sentences = K.sum(cross_ent, axis=1)
        return sum_over_sentences

    def vae_kld_loss(args):
        mu, log_sigma = args

        kl_loss = - 0.5 * K.sum(1 + log_sigma - K.square(mu) - K.exp(log_sigma), axis=-1)
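        # Linear KL annealing: the weight grows from eps towards 1 as `step`
        # moves from anneal_start to anneal_end.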
        kld_weight = K.clip((step - anneal_start) / (anneal_end - anneal_start), 0, 1 - eps) + eps
        return kl_loss*kld_weight

    def vae_aux_loss(args):
        x_truth, x_decoded = args
        x_truth_flatten = K.flatten(x_truth)
        x_decoded_flat = K.reshape(x_decoded, shape=(-1, K.shape(x_decoded)[-1]))
        cross_ent = T.nnet.categorical_crossentropy(x_decoded_flat, x_truth_flatten)
        cross_ent = K.reshape(cross_ent, shape=(-1, K.shape(x_truth)[1]))
        sum_over_sentences = K.sum(cross_ent, axis=1)
        return alpha*sum_over_sentences

    def da_loss_fun(args):
        da = args[0]
        sq_da_t = K.square(da)
        sum_sq_da_T = K.sum(sq_da_t, axis=1)
        return sum_sq_da_T

    def identity_loss(y_true, y_pred):
        return y_pred

    def argmax_fun(softmax_output):
        return K.argmax(softmax_output, axis=2)

    argmax = Lambda(argmax_fun, output_shape=(sample_out_size,))(output_gen_layer)

    main_loss = Lambda(vae_cross_ent_loss, output_shape=(1,), name='main_loss')([output_idx, recurrent_component])
    kld_loss = Lambda(vae_kld_loss, output_shape=(1,), name='kld_loss')([x_mean, x_log_sigma])
    aux_loss = Lambda(vae_aux_loss, output_shape=(1,), name='auxiliary_loss')([output_idx, softmax_auxiliary])
    #da_loss = Lambda(da_loss_fun, output_shape=(1,), name='dialogue_act_loss')([state])

    #output_gen_layer = LSTMStep(lstm, final_softmax_layer, sample_out_size, nclasses)(softmax_auxiliary)

    train_model = Model(inputs=inputs, outputs=[main_loss, kld_loss, aux_loss])
    test_model = Model(inputs=inputs, outputs=[argmax])

    return train_model, test_model
Example No. 5
def get_vae_gan_model(config_data, vocab_char, step):
    z_size = config_data['z_size']
    sample_in_size = config_data['max_input_length']
    sample_out_size = config_data['max_output_length']
    nclasses = len(vocab_char) + 2
    # last available index is reserved as start character
    max_idx = max(vocab_char.values())
    dummy_word_idx = max_idx + 1
    dropout_word_idx = max_idx + 2
    word_dropout_rate = config_data['word_dropout_rate']
    lstm_size = config_data['lstm_size']
    alpha = config_data['alpha']
    intermediate_dim = config_data['intermediate_dim']
    nfilter = 128
    out_size = 200
    eps = 0.001

    anneal_start = config_data['anneal_start']
    anneal_end = anneal_start + config_data['anneal_duration']

    # == == == == == =
    # Define Char Input
    # == == == == == =
    input_idx = Input(batch_shape=(None, sample_in_size),
                      dtype='int32',
                      name='character_input')
    output_idx = Input(batch_shape=(None, sample_out_size),
                       dtype='int32',
                       name='character_output')

    one_hot_weights = np.identity(nclasses)
    #oshape = (batch_size, sample_size, nclasses)
    one_hot_embeddings = Embedding(input_length=sample_in_size,
                                   input_dim=nclasses,
                                   output_dim=nclasses,
                                   weights=[one_hot_weights],
                                   trainable=False,
                                   name='one_hot_embeddings')
    input_one_hot_embeddings = one_hot_embeddings(input_idx)

    dropped_output_idx = WordDropout(rate=word_dropout_rate,
                                     dummy_word=dropout_word_idx)(output_idx)

    one_hot_weights = np.identity(nclasses)
    one_hot_out_embeddings = Embedding(input_length=sample_out_size,
                                       input_dim=nclasses,
                                       output_dim=nclasses,
                                       weights=[one_hot_weights],
                                       trainable=False,
                                       name='one_hot_out_embeddings')
    output_one_hot_embeddings = one_hot_out_embeddings(dropped_output_idx)

    def remove_last_column(x):
        return x[:, :-1, :]

    padding = ZeroPadding1D(padding=(1, 0))(output_one_hot_embeddings)
    orig_output = Lambda(remove_last_column,
                         output_shape=(sample_out_size, nclasses))(padding)

    # == == == == == =
    # Define Encoder
    # == == == == == =
    encoder = get_encoder(input_idx, input_one_hot_embeddings, nfilter, z_size,
                          intermediate_dim)

    # == == == == == =
    # Define Decoder
    # == == == == == =
    decoder_input = Input(shape=(z_size, ), name='decoder_input')
    decoder_train = get_decoder(decoder_input, nclasses, nfilter,
                                sample_out_size, out_size, intermediate_dim)

    # == == == == == == == =
    # Define Discriminators
    # == == == == == == == =
    dis_input = Input(shape=(sample_in_size, nclasses))
    dis_output = Input(shape=(sample_out_size, nclasses))

    discriminator = get_descriminator(dis_input, dis_output, nfilter,
                                      intermediate_dim)

    def vae_cross_ent_loss(args):
        x_truth, x_decoded_final = args
        x_truth_flatten = K.flatten(x_truth)
        x_decoded_flat = K.reshape(x_decoded_final,
                                   shape=(-1, K.shape(x_decoded_final)[-1]))
        cross_ent = T.nnet.categorical_crossentropy(x_decoded_flat,
                                                    x_truth_flatten)
        cross_ent = K.reshape(cross_ent, shape=(-1, K.shape(x_truth)[1]))
        sum_over_sentences = K.sum(cross_ent, axis=1)
        return sum_over_sentences

    def vae_kld_loss(args):
        mu, log_sigma = args
        kl_loss = -0.5 * K.sum(1 + log_sigma - K.square(mu) - K.exp(log_sigma),
                               axis=-1)
        kld_weight = K.clip((step - anneal_start) /
                            (anneal_end - anneal_start), 0, 1 - eps) + eps
        return kl_loss * kld_weight

    def vae_aux_loss(args):
        x_truth, x_decoded = args
        x_truth_flatten = K.flatten(x_truth)
        x_decoded_flat = K.reshape(x_decoded,
                                   shape=(-1, K.shape(x_decoded)[-1]))
        cross_ent = T.nnet.categorical_crossentropy(x_decoded_flat,
                                                    x_truth_flatten)
        cross_ent = K.reshape(cross_ent, shape=(-1, K.shape(x_truth)[1]))
        sum_over_sentences = K.sum(cross_ent, axis=1)
        return alpha * sum_over_sentences

    def gan_classification_loss(args):
        discr_x, discr_xp = args

        return -0.5 * K.log(K.clip(discr_x, eps, 1 - eps)) - 0.5 * K.log(
            1 - K.clip(discr_xp, eps, 1 - eps))

    def generator_loss(args):
        x_fake, = args
        return -K.log(K.clip(x_fake, eps, 1 - eps))

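    # Standard Keras-style Wasserstein critic loss; typically paired with
    # +1 / -1 targets for real vs. generated samples.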
    def wasserstein(y_true, y_pred):
        return K.mean(y_true * y_pred)

    def argmax_fun(softmax_output):
        return K.argmax(softmax_output, axis=2)

    z_prior, z_mean, z_sigmoid = encoder(input_idx)
    x_auxiliary = decoder_train(z_prior)

    # put the SC-LSTM outside of the decoder: there were strange problems with disconnected gradients otherwise
    lstm = SC_LSTM(lstm_size,
                   nclasses,
                   generation_only=False,
                   condition_on_ptm1=True,
                   semantic_condition=False,
                   return_da=False,
                   return_state=False,
                   use_bias=True,
                   return_sequences=True,
                   implementation=2,
                   dropout=0.2,
                   recurrent_dropout=0.2,
                   sc_dropout=0.2)

    recurrent_component = lstm([x_auxiliary, orig_output])
    lstm.inference_phase()
    output_gen_layer = lstm([x_auxiliary, x_auxiliary])  #for testing

    #vae_loss
    main_loss = Lambda(vae_cross_ent_loss,
                       output_shape=(1, ),
                       name='main_loss')([output_idx, recurrent_component])
    kld_loss = Lambda(vae_kld_loss, output_shape=(1, ),
                      name='kld_loss')([z_mean, z_sigmoid])
    aux_loss = Lambda(vae_aux_loss, output_shape=(1, ),
                      name='auxiliary_loss')([output_idx, x_auxiliary])
    argmax = Lambda(argmax_fun,
                    output_shape=(sample_out_size, ))(output_gen_layer)
    vae_model_train = Model(inputs=[input_idx, output_idx],
                            outputs=[main_loss, kld_loss, aux_loss])
    vae_model_test = Model(inputs=input_idx, outputs=argmax)

    #decoder training
    noise_input = Input(batch_shape=(None, z_size),
                        dtype='float32',
                        name='noise_input')
    noise_on_input = GaussianNoise(stddev=1.0)(noise_input)
    noise_model = Model(inputs=[noise_input],
                        outputs=[noise_on_input],
                        name='noise_model')
    noise = noise_model(noise_input)

    x_aux_prior = decoder_train(noise)
    lstm.train_phase = True
    output_gen_layer = lstm([x_aux_prior,
                             x_aux_prior])  #recurrent using auxiliary prior
    discr_sigmoid = discriminator([input_one_hot_embeddings, output_gen_layer])
    decoder_discr_model = Model(inputs=[input_idx, noise_input],
                                outputs=discr_sigmoid)

    #decoder test
    x_aux_prior = decoder_train(noise)
    lstm.train_phase = False
    output_gen_layer = lstm([x_aux_prior, x_aux_prior])
    argmax = Lambda(argmax_fun,
                    output_shape=(sample_out_size, ))(output_gen_layer)
    decoder_test_model = Model(inputs=noise_input, outputs=argmax)

    #discriminator_training
    discr_input = Input(batch_shape=(None, sample_out_size),
                        dtype='int32',
                        name='discr_output')
    discr_emb = one_hot_out_embeddings(discr_input)

    discr_sigmoid = discriminator([input_one_hot_embeddings, discr_emb])
    discriminator_model = Model(inputs=[input_idx, discr_input],
                                outputs=discr_sigmoid)

    #compile the training models
    optimizer_rms = RMSprop(lr=1e-3, decay=0.0001, clipnorm=10)
    optimizer_ada = Adadelta(lr=1.0,
                             epsilon=1e-8,
                             rho=0.95,
                             decay=0.0001,
                             clipnorm=10)
    optimizer_nadam = Nadam(lr=0.002,
                            beta_1=0.9,
                            beta_2=0.999,
                            epsilon=1e-08,
                            schedule_decay=0.001)

    vae_model_train.compile(optimizer=optimizer_ada,
                            loss=lambda y_true, y_pred: y_pred)
    decoder_discr_model.compile(optimizer=optimizer_rms, loss=wasserstein)
    discriminator_model.compile(optimizer=optimizer_rms, loss=wasserstein)

    return vae_model_train, vae_model_test, decoder_discr_model, decoder_test_model, discriminator_model, discriminator
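A rough sketch of one possible alternating training loop for the returned models, assuming a WGAN-style scheme. The batch helper next_batch(), the iteration count, and the +1/-1 target convention for the Wasserstein loss are assumptions, not part of the original code; config_data, vocab_char and step are assumed to be set up as the function expects.
import numpy as np

(vae_train, vae_test, dec_discr_model,
 dec_test_model, discr_model, discr) = get_vae_gan_model(config_data, vocab_char, step)

for it in range(n_iterations):                   # assumed outer loop
    x_in, x_out = next_batch()                   # assumed helper returning character-index arrays
    n = x_in.shape[0]
    dummy = np.zeros((n, 1))
    # VAE step: all three outputs are already losses, so dummy targets suffice.
    vae_train.train_on_batch([x_in, x_out], [dummy, dummy, dummy])
    # Critic step on real pairs, then a generator step on noise (Wasserstein targets).
    discr_model.train_on_batch([x_in, x_out], np.ones((n, 1)))
    noise = np.random.normal(size=(n, config_data['z_size']))
    dec_discr_model.train_on_batch([x_in, noise], -np.ones((n, 1)))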
Example No. 6
def vae_model(config_data, vocab, step, pretrained_model=None):
    z_size = config_data['z_size']
    sample_in_size = config_data['max_input_length']
    sample_out_size = config_data['max_output_length']
    nclasses = len(vocab) + 2
    #last available index is reserved as start character
    max_idx = max(vocab.values())
    dummy_word_idx = max_idx + 1
    dropout_word_idx = max_idx + 2
    lstm_size = config_data['lstm_size']
    alpha = config_data['alpha']
    intermediate_dim = config_data['intermediate_dim']
    batch_size = config_data['batch_size']
    nfilter = 128
    out_size = 200
    eps = 0.001
    anneal_start = config_data['anneal_start']
    anneal_end = anneal_start + config_data['anneal_duration']

    l2_regularizer = None
    # == == == == == =
    # Define Encoder
    # == == == == == =
    input_idx = Input(batch_shape=(None, sample_in_size),
                      dtype='int32',
                      name='character_input')
    output_idx = Input(batch_shape=(None, sample_out_size),
                       dtype='int32',
                       name='character_output')

    dropped_output_idx = WordDropout(rate=config_data['word_dropout_rate'],
                                     dummy_word=dropout_word_idx)(output_idx)

    one_hot_weights = np.identity(nclasses)
    #oshape = (batch_size, sample_size, nclasses)
    one_hot_embeddings = Embedding(input_length=sample_in_size,
                                   input_dim=nclasses,
                                   output_dim=nclasses,
                                   weights=[one_hot_weights],
                                   trainable=False,
                                   name='one_hot_embeddings')

    one_hot_out_embeddings = Embedding(input_length=sample_out_size,
                                       input_dim=nclasses,
                                       output_dim=nclasses,
                                       weights=[one_hot_weights],
                                       trainable=False,
                                       name='one_hot_out_embeddings')

    input_one_hot_embeddings = one_hot_embeddings(input_idx)
    output_one_hot_embeddings = one_hot_out_embeddings(dropped_output_idx)

    decoder_input = Input(shape=(z_size, ), name='decoder_input')
    encoder, _ = get_encoder(input_idx, input_one_hot_embeddings, nfilter,
                             z_size, intermediate_dim)
    decoder = get_decoder(decoder_input, intermediate_dim, nfilter,
                          sample_out_size, out_size, nclasses)

    x_sampled, x_mean, x_log_sigma = encoder(input_idx)
    softmax_auxiliary = decoder(x_sampled)
    #softmax_aux_mean = decoder(x_mean)

    encoder.summary()
    decoder.summary()

    def remove_last_column(x):
        return x[:, :-1, :]

    padding = ZeroPadding1D(padding=(1, 0))(output_one_hot_embeddings)
    previous_char_slice = Lambda(remove_last_column,
                                 output_shape=(sample_out_size,
                                               nclasses))(padding)

    #combined_input = concatenate(inputs=[softmax_auxiliary, previous_char_slice], axis=2)

    lstm = SC_LSTM(lstm_size,
                   nclasses,
                   generation_only=False,
                   condition_on_ptm1=True,
                   semantic_condition=False,
                   return_da=False,
                   return_state=False,
                   use_bias=True,
                   return_sequences=True,
                   implementation=2,
                   dropout=0.2,
                   recurrent_dropout=0.2,
                   sc_dropout=0.2)

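    # Teacher-forced pass: condition on the shifted ground-truth characters.
    # After inference_phase() the LSTM instead feeds back the auxiliary softmax
    # as its previous-character input.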
    recurrent_component = lstm([softmax_auxiliary, previous_char_slice])
    lstm.inference_phase()
    output_gen_layer = lstm([softmax_auxiliary, softmax_auxiliary])

    def vae_cross_ent_loss(args):
        x_truth, x_decoded_final = args
        x_truth_flatten = K.flatten(x_truth)
        x_decoded_flat = K.reshape(x_decoded_final,
                                   shape=(-1, K.shape(x_decoded_final)[-1]))
        cross_ent = T.nnet.categorical_crossentropy(x_decoded_flat,
                                                    x_truth_flatten)
        cross_ent = K.reshape(cross_ent, shape=(-1, K.shape(x_truth)[1]))
        sum_over_sentences = K.sum(cross_ent, axis=1)
        return sum_over_sentences

    def vae_kld_loss(args):
        mu, log_sigma = args

        kl_loss = -0.5 * K.sum(1 + log_sigma - K.square(mu) - K.exp(log_sigma),
                               axis=-1)
        kld_weight = K.clip((step - anneal_start) /
                            (anneal_end - anneal_start), 0, 1 - eps) + eps
        return kl_loss * kld_weight

    def vae_aux_loss(args):
        x_truth, x_decoded = args
        x_truth_flatten = K.flatten(x_truth)
        x_decoded_flat = K.reshape(x_decoded,
                                   shape=(-1, K.shape(x_decoded)[-1]))
        cross_ent = T.nnet.categorical_crossentropy(x_decoded_flat,
                                                    x_truth_flatten)
        cross_ent = K.reshape(cross_ent, shape=(-1, K.shape(x_truth)[1]))
        sum_over_sentences = K.sum(cross_ent, axis=1)
        return alpha * sum_over_sentences

    def identity_loss(y_true, y_pred):
        return y_pred

    def argmax_fun(softmax_output):
        return K.argmax(softmax_output, axis=2)

    argmax = Lambda(argmax_fun,
                    output_shape=(sample_out_size, ))(output_gen_layer)

    main_loss = Lambda(vae_cross_ent_loss,
                       output_shape=(1, ),
                       name='main_loss')([output_idx, recurrent_component])
    kld_loss = Lambda(vae_kld_loss, output_shape=(1, ),
                      name='kld_loss')([x_mean, x_log_sigma])
    aux_loss = Lambda(vae_aux_loss, output_shape=(1, ),
                      name='auxiliary_loss')([output_idx, softmax_auxiliary])
    train_model = Model(inputs=[input_idx, output_idx],
                        outputs=[main_loss, kld_loss, aux_loss])

    test_model = Model(inputs=[input_idx], outputs=[argmax, x_mean])

    return train_model, test_model
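A minimal construction sketch for the inputs this function expects, inferred from the configuration keys it reads; the concrete values and the toy character vocabulary are illustrative assumptions.
from keras import backend as K

config_data = {                       # keys read by vae_model above; values are assumptions
    'z_size': 64,
    'max_input_length': 60,
    'max_output_length': 80,
    'lstm_size': 256,
    'alpha': 0.2,
    'intermediate_dim': 512,
    'batch_size': 32,
    'anneal_start': 0,
    'anneal_duration': 10000,
    'word_dropout_rate': 0.25,
}
vocab = {ch: i for i, ch in enumerate(' abcdefghijklmnopqrstuvwxyz.,')}  # toy character vocab
step = K.variable(0.0)                # training-step variable driving KL annealing
train_model, test_model = vae_model(config_data, vocab, step, pretrained_model=None)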