def costs(self, application_call, prediction, prediction_mask,
          groundtruth, groundtruth_mask, **inputs):
    def _prediction_subtensor(data):
        # Select, at every (time, batch) position, the component of
        # `data` that corresponds to the predicted label.
        if data.ndim != 3:
            raise ValueError("expected a 3-dimensional tensor")
        flat_data = data.reshape(
            (data.shape[0] * data.shape[1], data.shape[2]))
        flat_data = flat_data[tensor.arange(flat_data.shape[0]),
                              prediction.flatten()]
        return flat_data.reshape(
            (prediction.shape[0], prediction.shape[1]))

    attended = disconnected_grad(inputs.pop('attended'))
    attended_mask = disconnected_grad(inputs.pop('attended_mask'))

    # Compute the rewards
    rewards = self.reward_brick.apply(
        prediction, prediction_mask,
        groundtruth, groundtruth_mask)[:, :, 0]
    future_rewards = rewards[::-1].cumsum(axis=0)[::-1]

    # Compute the critic outputs
    if self.critic:
        padding = tensor.repeat(
            tensor.fill(prediction[0:1], self.bos_token), 1, axis=0)
        mask_padding = tensor.repeat(
            tensor.fill(prediction_mask[0:1], 1.), 1, axis=0)
        padded_prediction = tensor.concatenate([padding, prediction])
        padded_prediction_mask = tensor.concatenate(
            [mask_padding, prediction_mask])
        if self.critic_uses_groundtruth:
            critic_context = groundtruth
            critic_context_mask = groundtruth_mask
        else:
            critic_context = tensor.zeros_like(groundtruth[0:1])
            critic_context_mask = tensor.zeros_like(groundtruth_mask[0:1])
        critic_kwargs = dict(
            prediction=padded_prediction,
            prediction_mask=padded_prediction_mask,
            groundtruth=critic_context,
            groundtruth_mask=critic_context_mask,
            inputs=critic_context,
            inputs_mask=critic_context_mask)
        if self.critic_uses_actor_states:
            extra_inputs = disconnected_grad(inputs['states'])
            # We don't have the very last hidden state of the actor
            # in extra_inputs, so something has to be appended for the
            # shapes to match. It does not matter what exactly is added.
            critic_kwargs['extra_inputs'] = tensor.concatenate(
                [extra_inputs, tensor.zeros_like(extra_inputs[0:1])])
        critic_cg = ComputationGraph(self.critic.costs(**critic_kwargs))
        outputs, = VariableFilter(
            applications=[self.critic.generator.readout.all_outputs],
            roles=[OUTPUT])(critic_cg)
        # The first subtensor should be discarded, because it was produced
        # for the padding. In addition, Q-values from the first
        # 'critic_burnin_steps' will be ignored, see later in the code.
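        # Shape note: judging by the indexing below, `outputs` holds the
        # critic's Q-value estimates with layout
        # (time + 1, batch, num_actions), one row per step of the padded
        # prediction; dropping the first row leaves one row per actual
        # prediction step, aligned with `prediction` and `prediction_mask`.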
        outputs = outputs[1:]
    else:
        outputs = self.merge(**dict_subset(inputs, self.merge_names))
    prediction_outputs = _prediction_subtensor(outputs)

    # Compute Q adjustments
    adjustments = outputs
    prediction_adjustments = prediction_outputs
    if self.accumulate_outputs:
        prediction_adjustments = prediction_outputs.cumsum(axis=0)
        adjustments = tensor.inc_subtensor(
            adjustments[1:], prediction_adjustments[:-1][:, :, None])

    # Compute shared additive biases for all Q values
    if self.use_value_biases:
        value_biases = (
            self.value_summand.apply(attended)[:, :, 0]
            * attended_mask).sum(axis=0)
    else:
        value_biases = tensor.zeros_like(adjustments[0, :, 0])
    values = adjustments + value_biases[None, :, None]
    prediction_values = prediction_adjustments + value_biases[None, :]

    rolled_prediction_mask = tensor.roll(prediction_mask, -1, axis=0)
    rolled_prediction_mask = tensor.set_subtensor(
        rolled_prediction_mask[-1], 0)

    # Compute probabilities
    logs = self.scores(use_epsilon=False, **inputs)
    probs = tensor.exp(logs)
    if not self.compute_policy:
        raise NotImplementedError("Not supported any more")
    prediction_logs = _prediction_subtensor(logs)

    # Compute value targets
    value_targets = (disconnected_grad(probs) * values).sum(axis=-1)
    value_targets = tensor.roll(value_targets, -1, axis=0)
    value_targets = (
        self.discount * value_targets * rolled_prediction_mask + rewards)
    value_targets = value_targets.astype(theano.config.floatX)

    total_costs = 0

    # Compute critic cost
    if not self.compute_targets:
        logger.debug("Using given targets")
        value_targets = tensor.matrix('value_targets')
    if self.solve_bellman == 'no':
        logger.debug("Not solving Bellman, just predicting the rewards")
        value_targets = rewards.copy(name='value_targets')
    elif self.solve_bellman == 'without_dp':
        future_rewards = rewards[::-1].cumsum(axis=0)[::-1]
        logger.debug("Solving Bellman, but without DP")
        value_targets = future_rewards
    elif self.solve_bellman is not True:
        raise ValueError("solve_bellman must be True, 'no' or 'without_dp'")
    critic_costs_per_char = (
        (prediction_values - value_targets) ** 2) * prediction_mask
    critic_costs = critic_costs_per_char[
        self.critic_burnin_steps:].sum(axis=0)
    if not self.freeze_critic:
        total_costs += critic_costs

    # Compute critic Monte-Carlo cost
    critic_monte_carlo_costs = (
        (((prediction_values - future_rewards) ** 2)
         * prediction_mask)[self.critic_burnin_steps:].sum(axis=0))

    # Value penalty
    if self.value_penalty:
        logger.debug("Use value penalty")
        value_deviations = (
            values - values.mean(axis=-1, keepdims=True)) ** 2
        if not self.freeze_critic:
            total_costs += (
                self.value_penalty * (
                    value_deviations.sum(axis=-1)
                    * prediction_mask)[self.critic_burnin_steps:].sum(axis=0))

    # Compute actor cost
    if self.critic:
        # The actor cost will be minimized, which is why the values
        # must be negated.
        est_name = self.actor_grad_estimate
        if est_name == 'all_actions':
            disadvantages = disconnected_grad(
                values.max(axis=-1)[:, :, None] - values)
            actor_costs = ((probs * disadvantages).sum(axis=-1)
                           * prediction_mask)
            actor_costs = actor_costs[self.critic_burnin_steps:]
        elif est_name.startswith('1_action'):
            # No target is provided for the first step, because we lack
            # an estimate of the value of the initial state; this is how
            # our critic works. Hopefully the network won't unlearn
            # to produce a BOS first.
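            # A rough reading of the weights below: log pi(a_t) is scaled
            # by -(R_t - V_hat(s_t)), where R_t is either the sampled
            # return from step t (the '..._unbiased' variant) or the
            # critic's Q-value of the chosen action, and the baseline
            # V_hat(s_t) = prediction_values[t - 1] - rewards[t - 1] comes
            # from the critic's previous step; the negation turns this
            # REINFORCE-style objective into a cost to be minimized.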
            future_reward_estimate = (
                future_rewards if est_name.endswith('unbiased')
                else prediction_values)
            weights = -disconnected_grad(
                future_reward_estimate[1:] + rewards[:-1]
                - prediction_values[:-1])
            actor_costs = ((prediction_logs[1:] * weights)
                           * prediction_mask[1:])
            actor_costs = actor_costs[self.critic_burnin_steps + 1:]
        else:
            raise ValueError
        actor_costs = actor_costs.sum(axis=0)

        actor_entropies = (probs * -logs).sum(axis=-1) * prediction_mask
        actor_entropies = actor_entropies[
            self.critic_burnin_steps:].sum(axis=0)
        critic_policy = disconnected_grad(
            self.softmax.apply(self.critic_policy_t * values,
                               extra_ndim=1))
        critic_cross_entropies = (
            (critic_policy * -logs).sum(axis=-1) * prediction_mask)
        critic_cross_entropies = critic_cross_entropies[
            self.critic_burnin_steps:].sum(axis=0)
        actor_costs_with_penalties = (
            actor_costs
            - self.entropy_reward_coof * actor_entropies
            - self.cross_entropy_reward_coof * critic_cross_entropies)
        if not self.freeze_actor:
            total_costs += actor_costs_with_penalties
        else:
            total_costs += disconnected_grad(actor_costs_with_penalties)

    # Add auxiliary variables for intermediate steps of the computation
    application_call.add_auxiliary_variable(rewards, name='rewards')
    application_call.add_auxiliary_variable(value_biases,
                                            name='value_biases')
    application_call.add_auxiliary_variable(values.copy(), name='values')
    application_call.add_auxiliary_variable(outputs.copy(), name='outputs')
    application_call.add_auxiliary_variable(prediction_values,
                                            name='prediction_values')
    application_call.add_auxiliary_variable(prediction_outputs,
                                            name='prediction_outputs')
    application_call.add_auxiliary_variable(value_targets.copy(),
                                            name='value_targets')
    application_call.add_auxiliary_variable(probs.copy(), name='probs')
    application_call.add_auxiliary_variable(prediction_logs,
                                            name='prediction_log_probs')

    # Compute some statistics for debugging
    last_character_mask = prediction_mask - rolled_prediction_mask
    last_character_costs = (critic_costs_per_char
                            * last_character_mask).sum(axis=0)
    mean2_output = (
        ((prediction_outputs ** 2) * prediction_mask).sum()
        / prediction_mask.sum()) ** 0.5
    max_output = abs(prediction_outputs * prediction_mask).max()
    expected_reward = (probs[0] * values[0]).sum(axis=-1)
    application_call.add_auxiliary_variable(last_character_costs,
                                            name='last_character_costs')
    application_call.add_auxiliary_variable(critic_costs.mean(),
                                            name='mean_critic_cost')
    application_call.add_auxiliary_variable(
        critic_monte_carlo_costs.mean(),
        name='mean_critic_monte_carlo_cost')
    if self.critic:
        application_call.add_auxiliary_variable(actor_costs.mean(),
                                                name='mean_actor_cost')
        application_call.add_auxiliary_variable(actor_entropies.mean(),
                                                name='mean_actor_entropy')
    application_call.add_auxiliary_variable(expected_reward.mean(),
                                            name='mean_expected_reward')
    application_call.add_auxiliary_variable(mean2_output,
                                            name='mean2_output')
    application_call.add_auxiliary_variable(max_output, name='max_output')

    return total_costs
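
# ---------------------------------------------------------------------------
# A minimal, self-contained sketch (illustration only, not part of the class
# above) of how the TD(0) value targets in `costs` are assembled. It mirrors
# the Theano expressions with numpy for readability; the function name and
# its arguments are hypothetical stand-ins that use the same
# (time, batch[, num_actions]) layout as `costs`.
import numpy as np


def _td_value_targets_sketch(rewards, probs, values, mask, discount):
    """rewards, mask: (T, B); probs, values: (T, B, num_actions)."""
    # Expected Q-value under the current policy at every step: (T, B).
    expected_q = (probs * values).sum(axis=-1)
    # Shift one step left so position t sees the value of step t + 1;
    # the row shifted in at the end is zeroed out by the rolled mask.
    next_value = np.roll(expected_q, -1, axis=0)
    next_mask = np.roll(mask, -1, axis=0)
    next_mask[-1] = 0.
    # Bellman backup: target_t = r_t + discount * E_pi[Q(s_{t+1}, .)],
    # so the last unmasked step's target is just its own reward.
    return rewards + discount * next_value * next_mask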