def loss_function(y_true, y_pred):
    """Pick the loss that matches the configured transform / focal settings."""
    if isinstance(transform, str) and transform.lower() == 'disc':
        return losses.discriminative_instance_loss(y_true, y_pred)
    if isinstance(transform, str) and transform.lower() == 'watershed-cont':
        return MSE(y_true, y_pred)
    if focal:
        return losses.weighted_focal_loss(
            y_true, y_pred, gamma=gamma, n_classes=n_classes)
    return losses.weighted_categorical_crossentropy(
        y_true, y_pred, n_classes=n_classes)
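# Usage sketch (an assumption, not taken from this source): `loss_function` above closes
# over `transform`, `focal`, `gamma`, `n_classes`, the `losses` module and `MSE`, so it is
# typically built inside a training helper and handed to Keras. The factory name
# `make_loss_function`, its defaults, and the stubbed branches below are hypothetical;
# only the pattern of capturing the configuration in a closure is illustrated.
import tensorflow as tf

def make_loss_function(transform=None, focal=False, gamma=0.5, n_classes=3):
    def loss_function(y_true, y_pred):
        if isinstance(transform, str) and transform.lower() == 'watershed-cont':
            return tf.keras.losses.MSE(y_true, y_pred)
        # Stand-in for the weighted focal / categorical-crossentropy branches above.
        return tf.keras.losses.categorical_crossentropy(y_true, y_pred)
    return loss_function

# model.compile(optimizer='adam', loss=make_loss_function(transform='watershed-cont'))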
def loss():
    loss = 0
    image_batch, targets_init_batch, targets_time_batch, actions_time_batch, mask_time_batch, dynamic_mask_time_batch = batch

    # Initial step, from the real observation: representation + prediction networks
    representation_batch, value_batch, policy_batch = network.initial_model(np.array(image_batch))

    # Only update the elements with a policy target
    target_value_batch, _, target_policy_batch = zip(*targets_init_batch)
    mask_policy = list(map(lambda l: bool(l), target_policy_batch))
    target_policy_batch = list(filter(lambda l: bool(l), target_policy_batch))
    policy_batch = tf.boolean_mask(policy_batch, mask_policy)

    # Compute the loss of the first pass
    loss += tf.math.reduce_mean(loss_value(target_value_batch, value_batch, network.value_support_size))
    loss += tf.math.reduce_mean(
        tf.nn.softmax_cross_entropy_with_logits(logits=policy_batch, labels=target_policy_batch))

    # Recurrent steps, from action and previous hidden state.
    for actions_batch, targets_batch, mask, dynamic_mask in zip(actions_time_batch, targets_time_batch,
                                                                mask_time_batch, dynamic_mask_time_batch):
        target_value_batch, target_reward_batch, target_policy_batch = zip(*targets_batch)

        # Only execute BPTT for elements with an action
        representation_batch = tf.boolean_mask(representation_batch, dynamic_mask)
        target_value_batch = tf.boolean_mask(target_value_batch, mask)
        target_reward_batch = tf.boolean_mask(target_reward_batch, mask)

        # Create the conditioned representation: concatenate representations with the actions batch
        actions_batch = tf.one_hot(actions_batch, network.action_size)

        # Recurrent step from conditioned representation: recurrent + prediction networks
        conditioned_representation_batch = tf.concat((representation_batch, actions_batch), axis=1)
        representation_batch, reward_batch, value_batch, policy_batch = network.recurrent_model(
            conditioned_representation_batch)

        # Only execute BPTT for elements with a policy target
        target_policy_batch = [policy for policy, b in zip(target_policy_batch, mask) if b]
        mask_policy = list(map(lambda l: bool(l), target_policy_batch))
        target_policy_batch = tf.convert_to_tensor([policy for policy in target_policy_batch if policy])
        policy_batch = tf.boolean_mask(policy_batch, mask_policy)

        # Compute the partial loss
        l = (tf.math.reduce_mean(loss_value(target_value_batch, value_batch, network.value_support_size)) +
             MSE(target_reward_batch, tf.squeeze(reward_batch)) +
             tf.math.reduce_mean(
                 tf.nn.softmax_cross_entropy_with_logits(logits=policy_batch, labels=target_policy_batch)))

        # Scale the gradient of the loss by the average number of actions unrolled
        gradient_scale = 1. / len(actions_time_batch)
        loss += scale_gradient(l, gradient_scale)

        # Halve the gradient of the representation
        representation_batch = scale_gradient(representation_batch, 0.5)

    return loss
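# `scale_gradient` is used above but not defined in this snippet. A common MuZero-style
# definition (an assumption here, not confirmed by this source) leaves the forward value
# unchanged while multiplying the gradient by `scale`, which matches the "scale the
# gradient" and "halve the gradient" comments above.
import tensorflow as tf

def scale_gradient(tensor, scale):
    # Forward pass: tensor * scale + tensor * (1 - scale) == tensor.
    # Backward pass: only the first term is differentiated, so the gradient is scaled.
    return tensor * scale + tf.stop_gradient(tensor) * (1 - scale)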
def mse_loss(logits, labels):
    """MSE loss."""
    return tf.reduce_mean(MSE(logits, labels))
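# Quick usage note (assumption: `MSE` is tf.keras.losses.MSE, i.e. mean_squared_error,
# computed per sample and then averaged to a single scalar by tf.reduce_mean).
import tensorflow as tf
from tensorflow.keras.losses import MSE

logits = tf.constant([[0.5, 1.5], [2.0, 0.0]])
labels = tf.constant([[1.0, 1.0], [2.0, 1.0]])
print(mse_loss(logits, labels))  # scalar: mean of the per-sample MSE values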
def _semantic_loss(y_pred, y_true):
    """Loss for a single semantic head, weighted by `panoptic_weight`."""
    if n_classes > 1:
        return panoptic_weight * losses.weighted_categorical_crossentropy(
            y_pred, y_true, n_classes=n_classes)
    return panoptic_weight * MSE(y_pred, y_true)
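# Sketch (an assumption, not taken from this source): `_semantic_loss` closes over
# `n_classes` and `panoptic_weight` for one semantic head, so a multi-head panoptic-style
# model would typically build one such closure per head via a small factory and pass them
# to Keras as a dict keyed by output name. The factory name, the conventional
# (y_true, y_pred) argument order, and the plain categorical-crossentropy stand-in below
# are all hypothetical.
import tensorflow as tf

def make_semantic_loss(n_classes, panoptic_weight=1.0):
    def semantic_loss(y_true, y_pred):
        if n_classes > 1:
            # Stand-in for losses.weighted_categorical_crossentropy in the original.
            return panoptic_weight * tf.keras.losses.categorical_crossentropy(y_true, y_pred)
        return panoptic_weight * tf.keras.losses.MSE(y_true, y_pred)
    return semantic_loss

# loss = {'semantic_0': make_semantic_loss(n_classes=3),
#         'semantic_1': make_semantic_loss(n_classes=1)}
# model.compile(optimizer='adam', loss=loss)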