def set_generator_update_function(
    generator_rnn_model, generator_mean_model, generator_std_model, generator_optimizer, grad_clipping
):
    """Build a compiled Theano training step for an RNN generator with a fixed-std Gaussian output.

    ``generator_rnn_model[0]`` encodes ``source_data``; its hidden states are
    reshaped and mapped by ``generator_mean_model`` to a predicted mean, and
    the step minimizes the Gaussian negative log-likelihood of ``target_data``
    under a constant standard deviation of 0.22.

    Args:
        generator_rnn_model: list of layers; ``generator_rnn_model[0].forward``
            produces the hidden sequence. All layers are passed to
            ``get_model_updates``.
        generator_mean_model: list of layers mapping hidden states to the mean.
        generator_std_model: currently unused -- the std is the constant
            ``output_std_data`` below (the learned-std path is commented out).
        generator_optimizer: optimizer handed to ``get_model_updates``.
        grad_clipping: forwarded to ``get_model_updates`` as ``use_grad_clip``.

    Returns:
        A ``theano.function`` with inputs ``[source_data, target_data]`` (3-D
        float tensors) that applies the computed updates and returns
        ``[generator_cost, gradient_norm]``: the per-element NLL summed over
        axis 2, and the L2 norm over all gradients from ``get_model_gradients``.
    """
    # input data (time length * num_samples * input_dims)
    source_data = tensor.tensor3(name="source_data", dtype=floatX)
    target_data = tensor.tensor3(name="target_data", dtype=floatX)

    # set generator input data list
    generator_input_data_list = [source_data]

    # get generator hidden data: swap axes 1 and 2, then merge the last two axes.
    # NOTE(review): the 4-D layout of forward()'s first output is defined
    # elsewhere -- confirm which axes are being merged here.
    hidden_data = generator_rnn_model[0].forward(generator_input_data_list, is_training=True)[0]
    hidden_data = hidden_data.dimshuffle(0, 2, 1, 3).flatten(3)

    # get generator output data (predicted mean)
    output_mean_data = get_tensor_output(input=hidden_data, layers=generator_mean_model, is_training=True)
    # Fixed output std; generator_std_model is intentionally not used.
    output_std_data = 0.22

    # Per-element Gaussian negative log-likelihood:
    #   (x - mu)^2 / (2 * sigma^2) + log(sigma) + 0.5 * log(2 * pi)
    # The squared term is scaled by exactly 1 / (2 sigma^2); it must not carry an
    # additional 0.5, or the data term no longer matches the normalizer.
    generator_cost = tensor.inv(2.0 * tensor.sqr(output_std_data)) * tensor.sqr(output_mean_data - target_data)
    generator_cost += tensor.log(output_std_data) + 0.5 * tensor.log(2.0 * numpy.pi)
    # sum over the feature axis -> (time_length x num_samples)
    generator_cost = tensor.sum(generator_cost, axis=2)

    # set generator update (minimize the mean NLL)
    generator_updates_cost = generator_cost.mean()
    generator_updates_dict = get_model_updates(
        layers=generator_rnn_model + generator_mean_model,
        cost=generator_updates_cost,
        optimizer=generator_optimizer,
        use_grad_clip=grad_clipping,
    )

    # global L2 norm of all parameter gradients (for monitoring)
    gradient_dict = get_model_gradients(generator_rnn_model + generator_mean_model, generator_updates_cost)
    gradient_norm = 0.0
    for grad in gradient_dict:
        gradient_norm += tensor.sum(grad ** 2)
    gradient_norm = tensor.sqrt(gradient_norm)

    # set generator update inputs
    generator_updates_inputs = [source_data, target_data]

    # set generator update outputs
    generator_updates_outputs = [generator_cost, gradient_norm]

    # set generator update function
    generator_updates_function = theano.function(
        inputs=generator_updates_inputs,
        outputs=generator_updates_outputs,
        updates=generator_updates_dict,
        on_unused_input="ignore",
    )

    return generator_updates_function
Example #2
0
def set_generator_update_function(generator_rnn_model,
                                  generator_mean_model,
                                  generator_std_model,
                                  generator_optimizer,
                                  grad_clipping):
    """Build a compiled Theano training step for an RNN generator with a learned Gaussian output.

    ``generator_rnn_model[0]`` encodes ``source_data``; ``generator_mean_model``
    and ``generator_std_model`` map the hidden states to a per-element mean and
    standard deviation, and the step minimizes the Gaussian negative
    log-likelihood of ``target_data``.

    Args:
        generator_rnn_model: list of layers; ``generator_rnn_model[0].forward``
            produces the hidden sequence.
        generator_mean_model: list of layers producing the predicted mean.
        generator_std_model: list of layers producing the predicted std.
            NOTE(review): assumes its output is strictly positive.
        generator_optimizer: optimizer handed to ``get_model_updates``.
        grad_clipping: forwarded to ``get_model_updates`` as ``use_grad_clip``.

    Returns:
        A ``theano.function`` with inputs ``[source_data, target_data]`` that
        applies the computed updates and returns ``[generator_cost,
        gradient_norm]``. ``generator_cost`` is the element-wise Gaussian
        *negative* log-likelihood, and ``gradient_norm`` is the global L2 norm
        of all gradients.
    """
    # input data (time length * num_samples * input_dims)
    source_data = tensor.tensor3(name='source_data',
                                 dtype=floatX)

    target_data = tensor.tensor3(name='target_data',
                                 dtype=floatX)

    # set generator input data list
    generator_input_data_list = [source_data,]

    # get generator hidden data
    hidden_data = generator_rnn_model[0].forward(generator_input_data_list, is_training=True)[0]

    # get generator output data
    output_mean_data = get_tensor_output(input=hidden_data,
                                         layers=generator_mean_model,
                                         is_training=True)
    output_std_data = get_tensor_output(input=hidden_data,
                                        layers=generator_std_model,
                                        is_training=True)

    # Per-element Gaussian negative log-likelihood:
    #   (x - mu)^2 / (2 * sigma^2) + 0.5 * log(2 * pi * sigma^2)
    # This is the quantity minimized below, so it must be the *negative*
    # log-likelihood; every other cost handed to get_model_updates in this
    # module (squared error, cross-entropy) is one to minimize.
    output_var_data = tensor.sqr(output_std_data)
    generator_cost  = tensor.inv(2.0*output_var_data)*tensor.sqr(output_mean_data-target_data)
    generator_cost += 0.5*tensor.log(2.0*numpy.pi*output_var_data)

    # set generator update
    generator_layers = generator_rnn_model+generator_mean_model+generator_std_model
    generator_updates_cost = generator_cost.mean()
    generator_updates_dict = get_model_updates(layers=generator_layers,
                                               cost=generator_updates_cost,
                                               optimizer=generator_optimizer,
                                               use_grad_clip=grad_clipping)

    # global L2 norm of all gradients: accumulate squares, take sqrt once at the end
    gradient_dict  = get_model_gradients(generator_layers, generator_updates_cost)
    gradient_norm  = 0.
    for grad in gradient_dict:
        gradient_norm += tensor.sum(grad**2)
    gradient_norm  = tensor.sqrt(gradient_norm)

    # set generator update inputs
    generator_updates_inputs  = [source_data,
                                 target_data,]

    # set generator update outputs
    generator_updates_outputs = [generator_cost, gradient_norm]

    # set generator update function
    generator_updates_function = theano.function(inputs=generator_updates_inputs,
                                                 outputs=generator_updates_outputs,
                                                 updates=generator_updates_dict,
                                                 on_unused_input='ignore')

    return generator_updates_function
Example #3
0
def set_updater_function(generator_model, generator_optimizer, generator_grad_clipping):
    """Build a compiled step training the generator on reconstruction plus a hidden-state regularizer.

    cost = mean(sum_k (target - output)^2)
           + lambda * mean(sum_k (data_hidden[1:] - model_hidden[1:])^2)

    Args:
        generator_model: list of layers. Only ``generator_model[0].forward`` is
            called directly; from its result, elements 0, 1, 3 and -1 are used
            (output sequence, data-phase hidden, model-phase hidden, and extra
            updates merged into the compiled function).
        generator_optimizer: optimizer handed to ``get_model_updates``.
        generator_grad_clipping: forwarded to ``get_model_updates`` as
            ``use_grad_clip``.

    Returns:
        A ``theano.function`` with inputs ``[input_sequence, target_sequence,
        lambda_regularizer]`` returning ``[sample_cost, regularizer_cost,
        generator_gradient_norm]``.
    """
    # input sequence data (time_length * num_samples * input_dims)
    input_sequence = tensor.tensor3(name="input_sequence", dtype=floatX)
    target_sequence = tensor.tensor3(name="target_sequence", dtype=floatX)
    # weight of the hidden-state regularizer, supplied per call
    lambda_regularizer = tensor.scalar(name="lambda_regularizer", dtype=floatX)

    # set generator input data list
    generator_input_data_list = [input_sequence]

    # get generator output data
    generator_output = generator_model[0].forward(generator_input_data_list, is_training=True)
    output_sequence = generator_output[0]
    data_hidden = generator_output[1]
    model_hidden = generator_output[3]
    generator_random = generator_output[-1]

    # get square error (summed over the feature axis)
    sample_cost = tensor.sqr(target_sequence - output_sequence).sum(axis=2)

    # squared difference between data-phase and model-phase hidden states, from step 1 on
    positive_hid = data_hidden[1:]
    negative_hid = model_hidden[1:]
    regularizer_cost = tensor.sqr(positive_hid - negative_hid).sum(axis=2)

    # set generator update; generator_grad_clipping is forwarded so the argument
    # takes effect, matching the GAN/reg update builders in this module.
    updater_cost = sample_cost.mean() + regularizer_cost.mean() * lambda_regularizer
    updater_dict = get_model_updates(
        layers=generator_model,
        cost=updater_cost,
        optimizer=generator_optimizer,
        use_grad_clip=generator_grad_clipping,
    )

    # get generator gradient norm2
    generator_gradient_dict = get_model_gradients(generator_model, updater_cost)
    generator_gradient_norm = 0.0
    for grad in generator_gradient_dict:
        generator_gradient_norm += tensor.sum(grad ** 2)
    generator_gradient_norm = tensor.sqrt(generator_gradient_norm)

    # set updater inputs
    updater_inputs = [input_sequence, target_sequence, lambda_regularizer]

    # set updater outputs
    updater_outputs = [sample_cost, regularizer_cost, generator_gradient_norm]

    # set updater function
    updater_function = theano.function(
        inputs=updater_inputs,
        outputs=updater_outputs,
        updates=merge_dicts([updater_dict, generator_random]),
        on_unused_input="ignore",
    )

    return updater_function
def set_updater_function(generator_rnn_model,
                         generator_emb_matrix,
                         generator_optimizer,
                         generator_grad_clipping):
    """Build a compiled teacher-forcing step for an RNN generator over embedded index sequences.

    Input and target index matrices are looked up in ``generator_emb_matrix``.
    The generator is fed the input embeddings and trained to regress the target
    embeddings with a squared error summed over the feature axis.

    Args:
        generator_rnn_model: list of layers; ``generator_rnn_model[0].forward``
            returns the generated sequence at [0] and extra updates at [-1].
        generator_emb_matrix: embedding table, shape (num_idx, feature_size).
        generator_optimizer: optimizer handed to ``get_model_updates``.
        generator_grad_clipping: forwarded to ``get_model_updates`` as
            ``use_grad_clip``.

    Returns:
        A ``theano.function`` with inputs ``[input_sequence, target_sequence]``
        (integer index matrices) returning ``[square_error,
        generator_gradient_norm]``.
    """
    # input/target sequence data (time_length * batch_size, matrices of indices).
    # They are used to index generator_emb_matrix, so they must be integer-typed;
    # Theano does not accept floating-point tensors as indices.
    input_sequence  = tensor.matrix(name='input_sequence',
                                    dtype='int64')
    target_sequence = tensor.matrix(name='target_sequence',
                                    dtype='int64')

    # generator_emb_matrix.shape = (num_idx, feature_size)
    input_emb_sequence  = generator_emb_matrix[input_sequence]
    target_emb_sequence = generator_emb_matrix[target_sequence]

    # set generator input data list
    generator_input_data_list = [input_emb_sequence,]

    # get generator output data
    generator_output = generator_rnn_model[0].forward(generator_input_data_list, is_training=True)
    generator_sample = generator_output[0]
    generator_random = generator_output[-1]

    # squared error against the *embedded* targets (the raw index matrix is 2-D
    # and cannot be compared with the generated feature sequence)
    square_error = tensor.sqr(target_emb_sequence-generator_sample).sum(axis=2)

    # set generator update
    tf_updates_cost = square_error.mean()
    tf_updates_dict = get_model_updates(layers=generator_rnn_model,
                                        cost=tf_updates_cost,
                                        optimizer=generator_optimizer,
                                        use_grad_clip=generator_grad_clipping)

    generator_gradient_dict  = get_model_gradients(layers=generator_rnn_model,
                                                   cost=tf_updates_cost)
    generator_gradient_norm  = 0.
    for grad in generator_gradient_dict:
        generator_gradient_norm += tensor.sum(grad**2)
    generator_gradient_norm  = tensor.sqrt(generator_gradient_norm)

    # set tf update inputs
    tf_updates_inputs  = [input_sequence,
                          target_sequence]

    # set tf update outputs
    tf_updates_outputs = [square_error,
                          generator_gradient_norm,]

    # set tf update function
    tf_updates_function = theano.function(inputs=tf_updates_inputs,
                                          outputs=tf_updates_outputs,
                                          updates=merge_dicts([tf_updates_dict, generator_random]),
                                          on_unused_input='ignore')

    return tf_updates_function
def set_tf_update_function(generator_model,
                           generator_optimizer,
                           generator_grad_clipping):
    """Build a compiled teacher-forcing step minimizing squared error of the generator output.

    Args:
        generator_model: list of layers; ``generator_model[0].forward`` returns
            the output sequence at [0] and extra updates at [-1].
        generator_optimizer: optimizer handed to ``get_model_updates``.
        generator_grad_clipping: forwarded to ``get_model_updates`` as
            ``use_grad_clip`` (previously accepted but ignored).

    Returns:
        A ``theano.function`` with inputs ``[input_sequence, target_sequence]``
        returning ``[square_error, generator_gradient_norm]``.
    """
    # input sequence data (time_length * num_samples * input_dims)
    input_sequence  = tensor.tensor3(name='input_sequence',
                                     dtype=floatX)
    target_sequence  = tensor.tensor3(name='target_sequence',
                                      dtype=floatX)
    # set generator input data list
    generator_input_data_list = [input_sequence,]

    # get generator output data
    generator_output = generator_model[0].forward(generator_input_data_list, is_training=True)
    output_sequence  = generator_output[0]
    generator_random = generator_output[-1]

    # get square error (summed over the feature axis)
    square_error = tensor.sqr(target_sequence-output_sequence).sum(axis=2)

    # set generator update
    tf_updates_cost = square_error.mean()
    tf_updates_dict = get_model_updates(layers=generator_model,
                                        cost=tf_updates_cost,
                                        optimizer=generator_optimizer,
                                        use_grad_clip=generator_grad_clipping)

    generator_gradient_dict  = get_model_gradients(generator_model, tf_updates_cost)

    # get generator gradient norm2
    generator_gradient_norm  = 0.
    for grad in generator_gradient_dict:
        generator_gradient_norm += tensor.sum(grad**2)
    generator_gradient_norm  = tensor.sqrt(generator_gradient_norm)

    # set tf update inputs
    tf_updates_inputs  = [input_sequence,
                          target_sequence]

    # set tf update outputs
    tf_updates_outputs = [square_error,
                          generator_gradient_norm,]

    # set tf update function
    tf_updates_function = theano.function(inputs=tf_updates_inputs,
                                          outputs=tf_updates_outputs,
                                          updates=merge_dicts([tf_updates_dict,
                                                               generator_random]),
                                          on_unused_input='ignore')

    return tf_updates_function
def set_gan_update_function(generator_model,
                            discriminator_feature_model,
                            discriminator_output_model,
                            generator_optimizer,
                            discriminator_optimizer,
                            generator_grad_clipping,
                            discriminator_grad_clipping):
    """Compile one joint adversarial step for a generator and a conditional pair discriminator.

    The generator's forward pass yields data-phase and model-phase hidden
    sequences. For steps t >= 1 the discriminator scores the feature pair
    (data_hidden[t-1], data_hidden[t]) as real and (data_hidden[t-1],
    model_hidden[t]) as fake. The conditioning hidden state is wrapped in
    ``disconnected_grad``.

    Returns:
        A ``theano.function`` over ``[input_sequence, target_sequence]`` that
        applies generator, discriminator and generator-random updates and
        returns ``[generator_gan_cost, discriminator_gan_cost, positive_score,
        negative_score, square_error, generator_gradient_norm,
        discriminator_gradient_norm]``.
    """
    def _grad_l2_norm(gradients):
        # sqrt of the summed squared entries across every gradient tensor
        return tensor.sqrt(sum((tensor.sum(g**2) for g in gradients), 0.))

    def _features(hidden):
        return get_tensor_output(hidden, discriminator_feature_model, is_training=True)

    def _score(pair):
        return get_tensor_output(pair, discriminator_output_model, is_training=True)

    # sequences are (time_length * num_samples * input_dims)
    input_sequence = tensor.tensor3(name='input_sequence', dtype=floatX)
    target_sequence = tensor.tensor3(name='target_sequence', dtype=floatX)

    gen_out = generator_model[0].forward([input_sequence, 1], is_training=True)
    output_sequence = gen_out[0]
    data_hidden = gen_out[1]
    data_cell = gen_out[2]  # not used below
    model_hidden = gen_out[3]
    model_cell = gen_out[4]  # not used below
    random_updates = gen_out[-1]

    # previous-step context (gradient blocked), then real/fake next-step features
    context = _features(theano.gradient.disconnected_grad(data_hidden[:-1]))
    real_pair = tensor.concatenate([context, _features(data_hidden[1:])], axis=2)
    fake_pair = tensor.concatenate([context, _features(model_hidden[1:])], axis=2)

    real_score = _score(real_pair)
    fake_score = _score(fake_pair)

    bce = tensor.nnet.binary_crossentropy
    # generator: push fake scores toward 1
    gen_cost = bce(output=fake_score, target=tensor.ones_like(fake_score))
    # discriminator: real -> 1, fake -> 0
    disc_cost = (bce(output=real_score, target=tensor.ones_like(real_score)) +
                 bce(output=fake_score, target=tensor.zeros_like(fake_score)))

    gen_mean_cost = gen_cost.mean()
    gen_updates = get_model_updates(layers=generator_model,
                                    cost=gen_mean_cost,
                                    optimizer=generator_optimizer,
                                    use_grad_clip=generator_grad_clipping)
    gen_grad_norm = _grad_l2_norm(get_model_gradients(generator_model, gen_mean_cost))

    disc_layers = discriminator_feature_model + discriminator_output_model
    disc_mean_cost = disc_cost.mean()
    disc_updates = get_model_updates(layers=disc_layers,
                                     cost=disc_mean_cost,
                                     optimizer=discriminator_optimizer,
                                     use_grad_clip=discriminator_grad_clipping)
    disc_grad_norm = _grad_l2_norm(get_model_gradients(disc_layers, disc_mean_cost))

    # reconstruction error, reported only (not optimized here)
    square_error = tensor.sqr(target_sequence - output_sequence).sum(axis=2)

    return theano.function(inputs=[input_sequence, target_sequence],
                           outputs=[gen_cost,
                                    disc_cost,
                                    real_score,
                                    fake_score,
                                    square_error,
                                    gen_grad_norm,
                                    disc_grad_norm],
                           updates=merge_dicts([gen_updates,
                                                disc_updates,
                                                random_updates]),
                           on_unused_input='ignore')
Example #7
0
def set_gan_update_function(generator_model,
                            discriminator_model,
                            generator_optimizer,
                            discriminator_optimizer,
                            generator_grad_clipping,
                            discriminator_grad_clipping):
    """Compile one joint adversarial step with a discriminator over hidden-state pairs.

    For steps t >= 1 the discriminator scores (data_hidden[t-1],
    data_hidden[t]) as real and (data_hidden[t-1], model_hidden[t]) as fake.
    The conditioning hidden state is wrapped in ``disconnected_grad``.

    Returns:
        A ``theano.function`` over ``[input_sequence, target_sequence]`` that
        applies generator and discriminator updates and returns
        ``[generator_gan_cost, discriminator_gan_cost, discriminator_true_score,
        discriminator_false_score, square_error, generator_gradient_norm,
        discriminator_gradient_norm]``.
    """
    # input sequence data (time_length * num_samples * input_dims)
    input_sequence  = tensor.tensor3(name='input_sequence',
                                     dtype=floatX)
    target_sequence  = tensor.tensor3(name='target_sequence',
                                      dtype=floatX)
    # set generator input data list
    generator_input_data_list = [input_sequence,]

    # get generator output data; only the hidden states feed the discriminator
    # (the cell-state pairs previously built here were never used)
    output_data_set = generator_model[0].forward(generator_input_data_list, is_training=True)
    output_sequence = output_data_set[0]
    data_hidden     = output_data_set[1]
    model_hidden    = output_data_set[3]
    # NOTE(review): unlike the other GAN builders, output_data_set[-1] is not
    # merged into the updates below -- confirm this generator has no random
    # state updates to apply.

    # previous-step context with gradients blocked
    condition_hidden = theano.gradient.disconnected_grad(data_hidden[:-1])

    true_hidden  = data_hidden[1:]
    false_hidden = model_hidden[1:]

    true_pair_hidden  = tensor.concatenate([condition_hidden, true_hidden], axis=2)
    false_pair_hidden = tensor.concatenate([condition_hidden, false_hidden], axis=2)

    discriminator_true_score  = get_tensor_output(true_pair_hidden, discriminator_model, is_training=True)
    discriminator_false_score = get_tensor_output(false_pair_hidden, discriminator_model, is_training=True)

    # generator: push fake scores toward 1
    generator_gan_cost = tensor.nnet.binary_crossentropy(output=discriminator_false_score,
                                                         target=tensor.ones_like(discriminator_false_score))

    # discriminator: real -> 1, fake -> 0
    discriminator_gan_cost = (tensor.nnet.binary_crossentropy(output=discriminator_true_score,
                                                              target=tensor.ones_like(discriminator_true_score)) +
                              tensor.nnet.binary_crossentropy(output=discriminator_false_score,
                                                              target=tensor.zeros_like(discriminator_false_score)))

    # set generator update
    generator_updates_cost = generator_gan_cost.mean()
    generator_updates_dict = get_model_updates(layers=generator_model,
                                               cost=generator_updates_cost,
                                               optimizer=generator_optimizer,
                                               use_grad_clip=generator_grad_clipping)

    generator_gradient_dict  = get_model_gradients(generator_model, generator_updates_cost)
    generator_gradient_norm  = 0.
    for grad in generator_gradient_dict:
        generator_gradient_norm += tensor.sum(grad**2)
    generator_gradient_norm  = tensor.sqrt(generator_gradient_norm)

    # set discriminator update
    discriminator_updates_cost = discriminator_gan_cost.mean()
    discriminator_updates_dict = get_model_updates(layers=discriminator_model,
                                                   cost=discriminator_updates_cost,
                                                   optimizer=discriminator_optimizer,
                                                   use_grad_clip=discriminator_grad_clipping)

    discriminator_gradient_dict  = get_model_gradients(discriminator_model, discriminator_updates_cost)
    discriminator_gradient_norm  = 0.
    for grad in discriminator_gradient_dict:
        discriminator_gradient_norm += tensor.sum(grad**2)
    discriminator_gradient_norm  = tensor.sqrt(discriminator_gradient_norm)

    # reconstruction error, reported only
    square_error = tensor.sqr(target_sequence-output_sequence).sum(axis=2)

    # set gan update inputs
    gan_updates_inputs  = [input_sequence,
                           target_sequence]

    # set gan update outputs
    gan_updates_outputs = [generator_gan_cost,
                           discriminator_gan_cost,
                           discriminator_true_score,
                           discriminator_false_score,
                           square_error,
                           generator_gradient_norm,
                           discriminator_gradient_norm,]

    # set gan update function
    gan_updates_function = theano.function(inputs=gan_updates_inputs,
                                           outputs=gan_updates_outputs,
                                           updates=merge_dicts([generator_updates_dict, discriminator_updates_dict]),
                                           on_unused_input='ignore')

    return gan_updates_function
def set_gan_update_function(generator_rnn_model,
                            generator_output_model,
                            discriminator_rnn_model,
                            discriminator_output_model,
                            generator_optimizer,
                            discriminator_optimizer,
                            generator_grad_clipping,
                            discriminator_grad_clipping):
    """Compile one adversarial step with a recurrent discriminator.

    The generator RNN plus output model map ``input_sequence`` to a sample
    sequence. The discriminator RNN reads the generator hidden states (gradient
    blocked), concatenated on axis 2 with either ``target_sequence`` (real) or
    the generated sample (fake). ``discriminator_output_model`` then turns its
    hidden states into scores.

    Returns:
        A ``theano.function`` over ``[input_sequence, target_sequence]`` that
        applies generator and discriminator updates and returns
        ``[generator_gan_cost, discriminator_gan_cost, positive_score,
        negative_score, square_error, generator_gradient_norm,
        discriminator_gradient_norm]``.
    """
    def _grad_l2_norm(gradients):
        # sqrt of the summed squared entries across every gradient tensor
        return tensor.sqrt(sum((tensor.sum(g**2) for g in gradients), 0.))

    def _discriminate(pair):
        # run the discriminator RNN over one pair sequence and score its hidden states
        disc_out = discriminator_rnn_model[0].forward([pair,], is_training=True)
        disc_hidden = disc_out[0]
        _disc_cell = disc_out[1]  # cell state, not used
        return get_tensor_output(disc_hidden, discriminator_output_model, is_training=True)

    # sequences are (time_length * num_samples * input_dims)
    input_sequence = tensor.tensor3(name='input_sequence', dtype=floatX)
    target_sequence = tensor.tensor3(name='target_sequence', dtype=floatX)

    gen_out = generator_rnn_model[0].forward([input_sequence,], is_training=True)
    gen_hidden = gen_out[0]
    _gen_cell = gen_out[1]  # cell state, not used
    gen_sample = get_tensor_output(gen_hidden, generator_output_model, is_training=True)

    context = theano.gradient.disconnected_grad(gen_hidden)
    real_pair = tensor.concatenate([context, target_sequence], axis=2)
    fake_pair = tensor.concatenate([context, gen_sample], axis=2)

    real_score = _discriminate(real_pair)
    fake_score = _discriminate(fake_pair)

    bce = tensor.nnet.binary_crossentropy
    # generator: push fake scores toward 1
    gen_cost = bce(output=fake_score, target=tensor.ones_like(fake_score))
    # discriminator: real -> 1, fake -> 0
    disc_cost = (bce(output=real_score, target=tensor.ones_like(real_score)) +
                 bce(output=fake_score, target=tensor.zeros_like(fake_score)))

    gen_layers = generator_rnn_model + generator_output_model
    gen_mean_cost = gen_cost.mean()
    gen_updates = get_model_updates(layers=gen_layers,
                                    cost=gen_mean_cost,
                                    optimizer=generator_optimizer,
                                    use_grad_clip=generator_grad_clipping)
    gen_grad_norm = _grad_l2_norm(get_model_gradients(gen_layers, gen_mean_cost))

    disc_layers = discriminator_rnn_model + discriminator_output_model
    disc_mean_cost = disc_cost.mean()
    disc_updates = get_model_updates(layers=disc_layers,
                                     cost=disc_mean_cost,
                                     optimizer=discriminator_optimizer,
                                     use_grad_clip=discriminator_grad_clipping)
    disc_grad_norm = _grad_l2_norm(get_model_gradients(disc_layers, disc_mean_cost))

    # reconstruction error, reported only
    square_error = tensor.sqr(target_sequence - gen_sample).sum(axis=2)

    return theano.function(inputs=[input_sequence, target_sequence],
                           outputs=[gen_cost,
                                    disc_cost,
                                    real_score,
                                    fake_score,
                                    square_error,
                                    gen_grad_norm,
                                    disc_grad_norm],
                           updates=merge_dicts([gen_updates, disc_updates]),
                           on_unused_input='ignore')
Example #9
0
def set_reg_update_function(generator_model,
                            generator_optimizer,
                            generator_grad_clipping):
    """Compile a step pulling the generator's model-phase hidden states toward its data-phase ones.

    cost = mean over (t >= 1, sample) of sum_k (data_hidden[t] - model_hidden[t])^2.
    ``data_hidden`` is wrapped in ``disconnected_grad``, so only the model-phase
    path contributes gradient.

    Returns:
        A ``theano.function`` over ``[input_sequence, target_sequence]`` that
        applies the generator and generator-random updates and returns
        ``[phase_diff, square_error, generator_gradient_norm]``.
    """
    # sequences are (time_length * num_samples * input_dims)
    input_sequence = tensor.tensor3(name='input_sequence', dtype=floatX)
    target_sequence = tensor.tensor3(name='target_sequence', dtype=floatX)

    forward_out = generator_model[0].forward([input_sequence, 1], is_training=True)
    output_sequence, data_hidden = forward_out[0], forward_out[1]
    data_cell = forward_out[2]  # not used below
    model_hidden, model_cell = forward_out[3], forward_out[4]
    random_updates = forward_out[-1]

    # data-phase states from step 1 on act as fixed targets
    data_target = theano.gradient.disconnected_grad(data_hidden[1:])
    phase_diff = tensor.sqr(data_target - model_hidden[1:]).sum(axis=2)

    reg_cost = phase_diff.mean()
    reg_updates = get_model_updates(layers=generator_model,
                                    cost=reg_cost,
                                    optimizer=generator_optimizer,
                                    use_grad_clip=generator_grad_clipping)

    squared_grad_total = 0.
    for g in get_model_gradients(generator_model, reg_cost):
        squared_grad_total += tensor.sum(g**2)
    generator_gradient_norm = tensor.sqrt(squared_grad_total)

    # reconstruction error, reported only
    square_error = tensor.sqr(target_sequence - output_sequence).sum(axis=2)

    return theano.function(inputs=[input_sequence, target_sequence],
                           outputs=[phase_diff, square_error, generator_gradient_norm],
                           updates=merge_dicts([reg_updates, random_updates]),
                           on_unused_input='ignore')
def set_gan_update_function(generator_rnn_model,
                            discriminator_rnn_model,
                            discriminator_output_model,
                            generator_optimizer,
                            discriminator_optimizer,
                            generator_grad_clipping,
                            discriminator_grad_clipping):
    """Compile an adversarial step where the generator free-runs from the first input frame.

    ``generator_rnn_model[0].loop_forward`` is seeded with ``input_sequence[0]``
    and ``time_length``. The discriminator RNN plus output model then score the
    real ``input_sequence`` (target 1) and the generated sequence (target 0).

    Returns:
        A ``theano.function`` over ``[input_sequence, time_length]`` that
        applies generator, discriminator and generator-random updates and
        returns ``[generator_gan_cost, discriminator_gan_cost, positive_score,
        negative_score, generator_gradient_norm, discriminator_gradient_norm]``.
    """
    def _grad_l2_norm(gradients):
        # sqrt of the summed squared entries across every gradient tensor
        return tensor.sqrt(sum((tensor.sum(g**2) for g in gradients), 0.))

    def _discriminate(sequence):
        # run the discriminator RNN and score its hidden states
        disc_hidden = discriminator_rnn_model[0].forward([sequence, ], is_training=True)[0]
        return get_tensor_output(disc_hidden, discriminator_output_model, is_training=True)

    # real sequences for the discriminator and seed frame for the generator
    input_sequence = tensor.tensor3(name='input_sequence', dtype=floatX)
    # number of steps the generator unrolls
    time_length = tensor.scalar(name='time_length', dtype='int32')

    loop_out = generator_rnn_model[0].loop_forward([input_sequence[0], time_length])
    generated_sequence = loop_out[0]
    random_updates = loop_out[-1]

    real_score = _discriminate(input_sequence)
    fake_score = _discriminate(generated_sequence)

    bce = tensor.nnet.binary_crossentropy
    # generator: push fake scores toward 1
    gen_cost = bce(output=fake_score, target=tensor.ones_like(fake_score))
    # discriminator: real -> 1, fake -> 0
    disc_cost = (bce(output=real_score, target=tensor.ones_like(real_score)) +
                 bce(output=fake_score, target=tensor.zeros_like(fake_score)))

    gen_mean_cost = gen_cost.mean()
    gen_updates = get_model_updates(layers=generator_rnn_model,
                                    cost=gen_mean_cost,
                                    optimizer=generator_optimizer,
                                    use_grad_clip=generator_grad_clipping)
    gen_grad_norm = _grad_l2_norm(get_model_gradients(layers=generator_rnn_model,
                                                      cost=gen_mean_cost))

    disc_layers = discriminator_rnn_model + discriminator_output_model
    disc_mean_cost = disc_cost.mean()
    disc_updates = get_model_updates(layers=disc_layers,
                                     cost=disc_mean_cost,
                                     optimizer=discriminator_optimizer,
                                     use_grad_clip=discriminator_grad_clipping)
    disc_grad_norm = _grad_l2_norm(get_model_gradients(layers=disc_layers,
                                                       cost=disc_mean_cost))

    return theano.function(inputs=[input_sequence, time_length],
                           outputs=[gen_cost,
                                    disc_cost,
                                    real_score,
                                    fake_score,
                                    gen_grad_norm,
                                    disc_grad_norm],
                           updates=merge_dicts([gen_updates,
                                                disc_updates,
                                                random_updates]),
                           on_unused_input='ignore')