def set_tf_update_function(input_emb_param,
                           generator_rnn_model,
                           generator_output_model,
                           generator_optimizer,
                           generator_grad_clipping):
    # input sequence data (time_length * num_samples * input_dims)
    input_sequence  = tensor.tensor3(name='input_sequence', dtype=floatX)
    target_sequence = tensor.tensor3(name='target_sequence', dtype=floatX)

    # embedding sequence (the target embedding is unused below; the error is
    # measured in the original data space)
    input_emb_sequence  = tensor.dot(input_sequence, input_emb_param)
    target_emb_sequence = tensor.dot(target_sequence, input_emb_param)

    # set generator input data list
    generator_input_data_list = [input_emb_sequence,]

    # get generator output data (the cell state is returned but unused here)
    generator_output = generator_rnn_model[0].forward(generator_input_data_list, is_training=True)
    generator_hidden = generator_output[0]
    generator_cell   = generator_output[1]

    # map hidden states through the output model and the transposed embedding
    # to reconstruct a sequence in the original data space
    generator_emb_sequence = get_tensor_output(generator_hidden, generator_output_model, is_training=True)
    generator_sequence     = tensor.dot(generator_emb_sequence, tensor.transpose(input_emb_param))

    # get square error
    square_error = tensor.sqr(target_sequence-generator_sequence).sum(axis=2)

    # set generator update
    tf_updates_cost = square_error.mean()
    tf_updates_dict = get_model_and_params_updates(layers=generator_rnn_model+generator_output_model,
                                                   params=[input_emb_param,],
                                                   cost=tf_updates_cost,
                                                   optimizer=generator_optimizer,
                                                   use_grad_clip=generator_grad_clipping)

    # monitor the global gradient norm
    generator_gradient_dict = get_model_and_params_gradients(layers=generator_rnn_model+generator_output_model,
                                                             params=[input_emb_param,],
                                                             cost=tf_updates_cost)
    generator_gradient_norm = 0.
    for grad in generator_gradient_dict:
        generator_gradient_norm += tensor.sum(grad**2)
    generator_gradient_norm = tensor.sqrt(generator_gradient_norm)

    # set tf update inputs
    tf_updates_inputs = [input_sequence,
                         target_sequence]

    # set tf update outputs
    tf_updates_outputs = [square_error,
                          generator_gradient_norm,]

    # set tf update function
    tf_updates_function = theano.function(inputs=tf_updates_inputs,
                                          outputs=tf_updates_outputs,
                                          updates=tf_updates_dict,
                                          on_unused_input='ignore')

    return tf_updates_function
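# A minimal usage sketch for set_tf_update_function, assuming the embedding
# parameter, model lists, and optimizer come from the repo's own builders
# (they are taken as arguments here rather than constructed). The shapes,
# the clipping value, and the random batch are illustrative assumptions:
# the target is the input shifted by one time step, matching the
# next-step-prediction cost above.
def run_tf_update_example(input_emb_param, generator_rnn_model,
                          generator_output_model, generator_optimizer,
                          time_length=20, num_samples=16, input_dims=100):
    import numpy

    tf_update = set_tf_update_function(input_emb_param=input_emb_param,
                                       generator_rnn_model=generator_rnn_model,
                                       generator_output_model=generator_output_model,
                                       generator_optimizer=generator_optimizer,
                                       generator_grad_clipping=1.0)

    # shift the same sequence by one step so the target is the next frame
    sequence = numpy.random.rand(time_length+1, num_samples, input_dims).astype(floatX)
    input_batch, target_batch = sequence[:-1], sequence[1:]

    square_error, grad_norm = tf_update(input_batch, target_batch)
    print('mean square error: {}, grad norm: {}'.format(square_error.mean(), grad_norm))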
def set_gan_update_function(input_emb_param,
                            generator_model,
                            discriminator_feature_model,
                            discriminator_output_model,
                            generator_optimizer,
                            discriminator_optimizer,
                            generator_grad_clipping,
                            discriminator_grad_clipping):
    # input sequence data (time_length * num_samples * input_dims)
    input_sequence  = tensor.tensor3(name='input_sequence', dtype=floatX)
    target_sequence = tensor.tensor3(name='target_sequence', dtype=floatX)

    # embedding sequence (the target embedding is unused below; the square
    # error is measured in the original data space)
    input_emb_sequence  = tensor.dot(input_sequence, input_emb_param)
    target_emb_sequence = tensor.dot(target_sequence, input_emb_param)

    # set generator input data list (the second element appears to be a mode
    # flag consumed by the model's forward())
    generator_input_data_list = [input_emb_sequence, 1]

    # get generator output data (the cell states are returned but unused here)
    generator_output    = generator_model[0].forward(generator_input_data_list, is_training=True)
    output_emb_sequence = generator_output[0]
    output_sequence     = tensor.dot(output_emb_sequence, tensor.transpose(input_emb_param))
    data_hidden         = generator_output[1]
    data_cell           = generator_output[2]
    model_hidden        = generator_output[3]
    model_cell          = generator_output[4]

    # conditioning features: teacher-forced hidden states up to t-1, with the
    # gradient blocked so the condition path does not train the generator
    condition_hid = data_hidden[:-1]
    condition_hid = theano.gradient.disconnected_grad(condition_hid)
    condition_feature = get_tensor_output(condition_hid, discriminator_feature_model, is_training=True)

    # positive pair: teacher-forced hidden states at t (real continuation)
    positive_hid = data_hidden[1:]
    positive_feature = get_tensor_output(positive_hid, discriminator_feature_model, is_training=True)

    # negative pair: free-running hidden states at t (model continuation)
    negative_hid = model_hidden[1:]
    negative_feature = get_tensor_output(negative_hid, discriminator_feature_model, is_training=True)

    positive_pair = tensor.concatenate([condition_feature, positive_feature], axis=2)
    negative_pair = tensor.concatenate([condition_feature, negative_feature], axis=2)

    positive_score = get_tensor_output(positive_pair, discriminator_output_model, is_training=True)
    negative_score = get_tensor_output(negative_pair, discriminator_output_model, is_training=True)

    # generator tries to make negative pairs look real; discriminator tries
    # to score positive pairs as 1 and negative pairs as 0
    generator_gan_cost = tensor.nnet.binary_crossentropy(output=negative_score,
                                                         target=tensor.ones_like(negative_score))

    discriminator_gan_cost = (tensor.nnet.binary_crossentropy(output=positive_score,
                                                              target=tensor.ones_like(positive_score)) +
                              tensor.nnet.binary_crossentropy(output=negative_score,
                                                              target=tensor.zeros_like(negative_score)))

    # set generator update
    generator_updates_cost = generator_gan_cost.mean()
    generator_updates_dict = get_model_and_params_updates(layers=generator_model,
                                                          params=[input_emb_param,],
                                                          cost=generator_updates_cost,
                                                          optimizer=generator_optimizer,
                                                          use_grad_clip=generator_grad_clipping)

    # monitor the generator's global gradient norm
    generator_gradient_dict = get_model_and_params_gradients(layers=generator_model,
                                                             params=[input_emb_param,],
                                                             cost=generator_updates_cost)
    generator_gradient_norm = 0.
    for grad in generator_gradient_dict:
        generator_gradient_norm += tensor.sum(grad**2)
    generator_gradient_norm = tensor.sqrt(generator_gradient_norm)

    # set discriminator update
    discriminator_updates_cost = discriminator_gan_cost.mean()
    discriminator_updates_dict = get_model_updates(layers=discriminator_feature_model+discriminator_output_model,
                                                   cost=discriminator_updates_cost,
                                                   optimizer=discriminator_optimizer,
                                                   use_grad_clip=discriminator_grad_clipping)

    # monitor the discriminator's global gradient norm
    discriminator_gradient_dict = get_model_gradients(discriminator_feature_model+discriminator_output_model,
                                                      discriminator_updates_cost)
    discriminator_gradient_norm = 0.
    for grad in discriminator_gradient_dict:
        discriminator_gradient_norm += tensor.sum(grad**2)
    discriminator_gradient_norm = tensor.sqrt(discriminator_gradient_norm)

    # square error of the free-running output, for monitoring only
    square_error = tensor.sqr(target_sequence-output_sequence).sum(axis=2)

    # set gan update inputs
    gan_updates_inputs = [input_sequence,
                          target_sequence]

    # set gan update outputs
    gan_updates_outputs = [generator_gan_cost,
                           discriminator_gan_cost,
                           positive_score,
                           negative_score,
                           square_error,
                           generator_gradient_norm,
                           discriminator_gradient_norm,]

    # set gan update function (generator and discriminator update in one step)
    gan_updates_function = theano.function(inputs=gan_updates_inputs,
                                           outputs=gan_updates_outputs,
                                           updates=merge_dicts([generator_updates_dict, discriminator_updates_dict]),
                                           on_unused_input='ignore')

    return gan_updates_function
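# A minimal usage sketch for set_gan_update_function under the same
# assumptions as above: the model lists, optimizers, and embedding parameter
# are expected from the repo's own builders, and the batch shapes, clipping
# values, and update count are illustrative. The unpacking order follows
# gan_updates_outputs exactly.
def run_gan_update_example(input_emb_param, generator_model,
                           discriminator_feature_model, discriminator_output_model,
                           generator_optimizer, discriminator_optimizer,
                           num_updates=10, time_length=20, num_samples=16, input_dims=100):
    import numpy

    gan_update = set_gan_update_function(input_emb_param=input_emb_param,
                                         generator_model=generator_model,
                                         discriminator_feature_model=discriminator_feature_model,
                                         discriminator_output_model=discriminator_output_model,
                                         generator_optimizer=generator_optimizer,
                                         discriminator_optimizer=discriminator_optimizer,
                                         generator_grad_clipping=1.0,
                                         discriminator_grad_clipping=1.0)

    for update_idx in range(num_updates):
        # targets are the inputs shifted by one time step
        sequence = numpy.random.rand(time_length+1, num_samples, input_dims).astype(floatX)
        input_batch, target_batch = sequence[:-1], sequence[1:]

        (generator_cost, discriminator_cost, positive_score, negative_score,
         square_error, generator_grad_norm, discriminator_grad_norm) = gan_update(input_batch, target_batch)

        print('update {}: gen cost {:.4f}, disc cost {:.4f}, sq err {:.4f}'.format(
            update_idx, float(generator_cost.mean()),
            float(discriminator_cost.mean()), float(square_error.mean())))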