def set_update_function(recurrent_model, output_model, optimizer, grad_clip=1.0):
    """Build and compile the Theano training-update function.

    Args:
        recurrent_model: list of recurrent (LSTM) layers.
        output_model: list of output layers mapping hidden states to predictions.
        optimizer: optimizer callable forwarded to get_model_updates.
        grad_clip: gradient-clipping value forwarded as use_grad_clip.

    Returns:
        A compiled theano.function with inputs
        (input_data, input_mask, init_hidden, init_cell, target_data,
        truncate_grad_step) and outputs [hidden_data, output_data,
        sample_cost], applying the optimizer updates on each call.
    """
    # --- symbolic inputs ---------------------------------------------------
    # sequence input: (time_length, num_samples, input_dims)
    input_data = tensor.tensor3(name='input_data', dtype=floatX)
    # sequence mask: (time_length, num_samples)
    # NOTE(review): the mask is accepted as a function input but never enters
    # the cost below (hence on_unused_input='ignore') — confirm intended.
    input_mask = tensor.matrix(name='input_mask', dtype=floatX)
    # initial hidden/cell state: (num_samples, hidden_dims); also unused in
    # the graph — the recurrent layers receive None placeholders instead.
    init_hidden = tensor.matrix(name='init_hidden', dtype=floatX)
    init_cell = tensor.matrix(name='init_cell', dtype=floatX)
    # number of steps for truncated back-propagation through time
    truncate_grad_step = tensor.scalar(name='truncate_grad_step', dtype='int32')
    # targets: (time_length, num_samples, output_dims)
    target_data = tensor.tensor3(name='target_data', dtype=floatX)

    # --- forward pass --------------------------------------------------------
    recurrent_inputs = [input_data, None, None, None, truncate_grad_step]
    hidden_data = get_lstm_outputs(input_list=recurrent_inputs,
                                   layers=recurrent_model,
                                   is_training=True)[-1]
    output_data = get_tensor_output(input=hidden_data,
                                    layers=output_model,
                                    is_training=True)

    # --- cost ----------------------------------------------------------------
    # squared error summed over time and feature axes -> one cost per sample
    sample_cost = tensor.sum(tensor.sqr(output_data - target_data), axis=(0, 2))

    # --- parameter updates from the mean cost over samples --------------------
    model_updates_dict = get_model_updates(layers=recurrent_model + output_model,
                                           cost=sample_cost.mean(),
                                           optimizer=optimizer,
                                           use_grad_clip=grad_clip)

    # --- compile ---------------------------------------------------------------
    return theano.function(inputs=[input_data,
                                   input_mask,
                                   init_hidden,
                                   init_cell,
                                   target_data,
                                   truncate_grad_step],
                           outputs=[hidden_data, output_data, sample_cost],
                           updates=model_updates_dict,
                           on_unused_input='ignore')
def set_update_function(recurrent_model, output_model, controller_optimizer, model_optimizer, grad_clip=1.0):
    """Build and compile the Theano training-update function (max-over-time cost).

    The per-sample cost is the squared error summed over the feature axis,
    giving a (time_length, num_samples) matrix; the model is trained on the
    worst time step per sample (max over time), averaged over samples.

    NOTE(review): this redefines `set_update_function` declared earlier in the
    file — the later definition wins at import time; consider renaming one.

    Args:
        recurrent_model: list of recurrent (LSTM) layers.
        output_model: list of output layers mapping hidden states to predictions.
        controller_optimizer: UNUSED — retained only for interface
            compatibility. The cost-weighting "controller" it served existed
            solely in dead, commented-out code, which has been removed.
        model_optimizer: optimizer callable forwarded to get_model_updates.
        grad_clip: gradient-clipping value forwarded as use_grad_clip.

    Returns:
        A compiled theano.function with inputs (input_data, target_data) and
        outputs [hidden_data, output_data, sample_cost], applying the model
        updates on each call.
    """
    # symbolic inputs: (time_length, num_samples, dims)
    input_data = tensor.tensor3(name='input_data', dtype=floatX)
    target_data = tensor.tensor3(name='target_data', dtype=floatX)

    # forward pass through the recurrent stack, then the output layers
    hidden_data = get_lstm_outputs(input_list=[input_data, ],
                                   layers=recurrent_model,
                                   is_training=True)[-1]
    output_data = get_tensor_output(input=hidden_data,
                                    layers=output_model,
                                    is_training=True)

    # squared error summed over the feature axis -> (time_length, num_samples)
    # (a reshape to the identical shape was removed as a no-op)
    sample_cost = tensor.sum(tensor.sqr(output_data - target_data), axis=2)

    # train on the worst-case time step per sample, averaged over samples
    model_cost = sample_cost.max(axis=0).mean()
    model_updates_dict = get_model_updates(layers=recurrent_model + output_model,
                                           cost=model_cost,
                                           optimizer=model_optimizer,
                                           use_grad_clip=grad_clip)

    return theano.function(inputs=[input_data, target_data],
                           outputs=[hidden_data, output_data, sample_cost],
                           updates=model_updates_dict,
                           on_unused_input='ignore')
def set_generator_update_function(generator_rnn_model, discriminator_rnn_model, discriminator_output_model, generator_optimizer, grad_clipping):
    """Compile the generator's Theano training function (adversarial setup).

    Samples a sequence from the generator, scores it with the discriminator,
    and updates the generator toward making the discriminator output ones
    (binary cross-entropy against an all-ones target).

    Args:
        generator_rnn_model: generator model; element [0] provides .forward().
        discriminator_rnn_model: list of discriminator recurrent (LSTM) layers.
        discriminator_output_model: list of discriminator output layers.
        generator_optimizer: optimizer callable forwarded to get_model_updates.
        grad_clipping: gradient-clipping value forwarded as use_grad_clip.

    Returns:
        A compiled theano.function with inputs (init_input_data,
        init_hidden_data, init_cell_data, sampling_length) and outputs
        [sample_cost_data, generator_cost], applying the generator updates.
    """
    # --- symbolic inputs ---------------------------------------------------
    # seed input for the first sampling step: (num_samples, input_dims)
    init_input_data = tensor.matrix(name='init_input_data', dtype=floatX)
    # initial hidden/cell states: (num_layers, num_samples, hidden_dims)
    init_hidden_data = tensor.tensor3(name='init_hidden_data', dtype=floatX)
    init_cell_data = tensor.tensor3(name='init_cell_data', dtype=floatX)
    # number of steps to sample
    sampling_length = tensor.scalar(name='sampling_length', dtype='int32')

    # --- generator forward pass: free-running sampling -----------------------
    generator_inputs = [init_input_data, init_hidden_data, init_cell_data, sampling_length]
    output_data = generator_rnn_model[0].forward(generator_inputs, is_training=True)[0]

    # --- discriminator forward pass over the generated sequence --------------
    discriminator_hidden_data = get_lstm_outputs(input_list=[output_data, ],
                                                 layers=discriminator_rnn_model,
                                                 is_training=True)[-1]
    # NOTE(review): the trailing [-1] keeps only the last entry of the
    # discriminator output — presumably its score at the final time step;
    # confirm, since the other update functions in this file use the full
    # get_tensor_output result.
    sample_cost_data = get_tensor_output(input=discriminator_hidden_data,
                                         layers=discriminator_output_model,
                                         is_training=True)[-1]

    # generator wants the discriminator to answer "real" (all ones);
    # cross-entropy summed over axis 1 -> one cost value per sample
    generator_cost = tensor.nnet.binary_crossentropy(
        output=sample_cost_data,
        target=tensor.ones_like(sample_cost_data)).sum(axis=1)

    # --- generator parameter updates from the mean cost over samples ----------
    generator_updates_dict = get_model_updates(layers=generator_rnn_model,
                                               cost=generator_cost.mean(),
                                               optimizer=generator_optimizer,
                                               use_grad_clip=grad_clipping)

    return theano.function(inputs=generator_inputs,
                           outputs=[sample_cost_data, generator_cost],
                           updates=generator_updates_dict,
                           on_unused_input='ignore')