Пример #1
0
def set_tf_update_function(input_emb_param,
                           generator_rnn_model,
                           generator_output_model,
                           generator_optimizer,
                           generator_grad_clipping):
    """Compile the Theano function that trains the generator on squared error.

    The symbolic graph embeds the input sequence with ``input_emb_param``,
    runs it through ``generator_rnn_model[0]``, maps the resulting hidden
    sequence through ``generator_output_model`` and projects it back with the
    transpose of ``input_emb_param``. The cost is the mean over the
    per-position squared error (summed over the last axis) between that
    projection and the target sequence.

    Args:
        input_emb_param: Embedding parameter; used for the input projection
            and, transposed, for the output projection. It is also passed to
            the update helper as a trainable parameter.
        generator_rnn_model: Sequence of layers; only element ``0`` is called
            (``forward``), the whole sequence is handed to the update helper.
        generator_output_model: Sequence of layers applied to the hidden
            sequence through ``get_tensor_output``.
        generator_optimizer: Optimizer forwarded to
            ``get_model_and_params_updates``.
        generator_grad_clipping: Forwarded to ``get_model_and_params_updates``
            as ``use_grad_clip``.

    Returns:
        A compiled ``theano.function`` taking ``(input_sequence,
        target_sequence)`` and returning ``[square_error, gradient_norm]``
        while applying the generator updates.
    """
    # input sequence data (time_length * num_samples * input_dims)
    input_sequence = tensor.tensor3(name='input_sequence', dtype=floatX)
    target_sequence = tensor.tensor3(name='target_sequence', dtype=floatX)

    # embedding sequence
    input_emb_sequence = tensor.dot(input_sequence, input_emb_param)

    # get generator output data; only the first returned item (hidden) is used
    generator_output = generator_rnn_model[0].forward([input_emb_sequence],
                                                      is_training=True)
    generator_hidden = generator_output[0]

    # map the hidden sequence to embedding space, then back to input space
    # by reusing the transposed embedding parameter
    generator_emb_sequence = get_tensor_output(generator_hidden,
                                               generator_output_model,
                                               is_training=True)
    generator_sequence = tensor.dot(generator_emb_sequence,
                                    tensor.transpose(input_emb_param))

    # get square error (summed over the last axis, one value per position)
    square_error = tensor.sqr(target_sequence - generator_sequence).sum(axis=2)

    # set generator update
    generator_layers = generator_rnn_model + generator_output_model
    tf_updates_cost = square_error.mean()
    # The clipping argument used to be accepted but silently dropped; it is
    # now forwarded the same way set_gan_update_function forwards it.
    tf_updates_dict = get_model_and_params_updates(layers=generator_layers,
                                                   params=[input_emb_param],
                                                   cost=tf_updates_cost,
                                                   optimizer=generator_optimizer,
                                                   use_grad_clip=generator_grad_clipping)

    # gradient norm, reported for monitoring only
    # NOTE(review): this iterates the helper's return value directly; if it is
    # a mapping rather than a sequence of gradients the loop visits its keys -
    # confirm against get_model_and_params_gradients.
    generator_gradients = get_model_and_params_gradients(layers=generator_layers,
                                                         params=[input_emb_param],
                                                         cost=tf_updates_cost)
    generator_gradient_norm = 0.
    for grad in generator_gradients:
        generator_gradient_norm += tensor.sum(grad ** 2)
    generator_gradient_norm = tensor.sqrt(generator_gradient_norm)

    # set tf update function
    tf_updates_function = theano.function(inputs=[input_sequence,
                                                  target_sequence],
                                          outputs=[square_error,
                                                   generator_gradient_norm],
                                          updates=tf_updates_dict,
                                          on_unused_input='ignore')

    return tf_updates_function
Пример #2
0
def set_gan_update_function(input_emb_param,
                            generator_model,
                            discriminator_feature_model,
                            discriminator_output_model,
                            generator_optimizer,
                            discriminator_optimizer,
                            generator_grad_clipping,
                            discriminator_grad_clipping):
    """Compile the Theano function that jointly updates generator and discriminator.

    The generator (``generator_model[0]``) is run on the embedded input
    sequence. Two hidden sequences from its output are compared by the
    discriminator: ``data_hidden`` (output item 1) supplies the condition
    (all but the last position, gradient-disconnected) and the positive
    sample (all but the first position); ``model_hidden`` (output item 3)
    supplies the negative sample (all but the first position). Each hidden
    sequence is mapped through ``discriminator_feature_model``; condition and
    sample features are concatenated on axis 2 and scored by
    ``discriminator_output_model``.

    Costs:
        * generator: binary cross-entropy of the negative score against ones.
        * discriminator: binary cross-entropy of the positive score against
          ones plus that of the negative score against zeros.

    Args:
        input_emb_param: Embedding parameter; used for the input projection,
            transposed for the output projection, and trained with the
            generator.
        generator_model: Sequence of generator layers; element ``0`` is
            called via ``forward``.
        discriminator_feature_model: Sequence of layers producing features
            from hidden sequences.
        discriminator_output_model: Sequence of layers producing a score from
            concatenated feature pairs.
        generator_optimizer: Optimizer for the generator updates.
        discriminator_optimizer: Optimizer for the discriminator updates.
        generator_grad_clipping: Passed as ``use_grad_clip`` for the
            generator updates.
        discriminator_grad_clipping: Passed as ``use_grad_clip`` for the
            discriminator updates.

    Returns:
        A compiled ``theano.function`` taking ``(input_sequence,
        target_sequence)`` and returning, in order: generator cost,
        discriminator cost, positive score, negative score, squared error,
        generator gradient norm, discriminator gradient norm. Both update
        sets are applied on every call.
    """
    # input sequence data (time_length * num_samples * input_dims)
    input_sequence = tensor.tensor3(name='input_sequence', dtype=floatX)
    target_sequence = tensor.tensor3(name='target_sequence', dtype=floatX)

    # embedding sequence
    input_emb_sequence = tensor.dot(input_sequence, input_emb_param)

    # get generator output data
    # NOTE(review): the meaning of the literal 1 in the input list is defined
    # by the generator layer's forward(); confirm there.
    generator_output = generator_model[0].forward([input_emb_sequence, 1],
                                                  is_training=True)
    output_emb_sequence = generator_output[0]
    output_sequence = tensor.dot(output_emb_sequence,
                                 tensor.transpose(input_emb_param))
    data_hidden = generator_output[1]
    model_hidden = generator_output[3]

    # condition: previous-position hidden, cut off from the gradient so that
    # no cost back-propagates through the conditioning path
    condition_hid = theano.gradient.disconnected_grad(data_hidden[:-1])
    condition_feature = get_tensor_output(condition_hid,
                                          discriminator_feature_model,
                                          is_training=True)

    # positive / negative samples: next-position hidden from each sequence
    positive_feature = get_tensor_output(data_hidden[1:],
                                         discriminator_feature_model,
                                         is_training=True)
    negative_feature = get_tensor_output(model_hidden[1:],
                                         discriminator_feature_model,
                                         is_training=True)

    positive_pair = tensor.concatenate([condition_feature, positive_feature],
                                       axis=2)
    negative_pair = tensor.concatenate([condition_feature, negative_feature],
                                       axis=2)

    positive_score = get_tensor_output(positive_pair,
                                       discriminator_output_model,
                                       is_training=True)
    negative_score = get_tensor_output(negative_pair,
                                       discriminator_output_model,
                                       is_training=True)

    # generator wants the negative pair to be scored as real (target = 1)
    generator_gan_cost = tensor.nnet.binary_crossentropy(
        output=negative_score,
        target=tensor.ones_like(negative_score))

    # discriminator wants positive -> 1 and negative -> 0
    discriminator_gan_cost = (
        tensor.nnet.binary_crossentropy(output=positive_score,
                                        target=tensor.ones_like(positive_score)) +
        tensor.nnet.binary_crossentropy(output=negative_score,
                                        target=tensor.zeros_like(negative_score)))

    # set generator update
    generator_updates_cost = generator_gan_cost.mean()
    generator_updates_dict = get_model_and_params_updates(layers=generator_model,
                                                          params=[input_emb_param],
                                                          cost=generator_updates_cost,
                                                          optimizer=generator_optimizer,
                                                          use_grad_clip=generator_grad_clipping)

    # gradient norms are reported for monitoring only
    # NOTE(review): both loops iterate the helpers' return values directly; if
    # those are mappings rather than sequences of gradients the loops visit
    # keys - confirm against the gradient helpers.
    generator_gradients = get_model_and_params_gradients(layers=generator_model,
                                                         params=[input_emb_param],
                                                         cost=generator_updates_cost)
    generator_gradient_norm = 0.
    for grad in generator_gradients:
        generator_gradient_norm += tensor.sum(grad ** 2)
    generator_gradient_norm = tensor.sqrt(generator_gradient_norm)

    # set discriminator update
    discriminator_layers = discriminator_feature_model + discriminator_output_model
    discriminator_updates_cost = discriminator_gan_cost.mean()
    discriminator_updates_dict = get_model_updates(layers=discriminator_layers,
                                                   cost=discriminator_updates_cost,
                                                   optimizer=discriminator_optimizer,
                                                   use_grad_clip=discriminator_grad_clipping)

    discriminator_gradients = get_model_gradients(discriminator_layers,
                                                  discriminator_updates_cost)
    discriminator_gradient_norm = 0.
    for grad in discriminator_gradients:
        discriminator_gradient_norm += tensor.sum(grad ** 2)
    discriminator_gradient_norm = tensor.sqrt(discriminator_gradient_norm)

    # reconstruction error, reported alongside the adversarial costs
    square_error = tensor.sqr(target_sequence - output_sequence).sum(axis=2)

    # set gan update outputs (element-wise costs, not their means)
    gan_updates_outputs = [generator_gan_cost,
                           discriminator_gan_cost,
                           positive_score,
                           negative_score,
                           square_error,
                           generator_gradient_norm,
                           discriminator_gradient_norm]

    # set gan update function; both update sets are applied in the same call
    gan_updates_function = theano.function(inputs=[input_sequence,
                                                   target_sequence],
                                           outputs=gan_updates_outputs,
                                           updates=merge_dicts([generator_updates_dict,
                                                                discriminator_updates_dict]),
                                           on_unused_input='ignore')

    return gan_updates_function