def metrics_fn(self, env_logits, env_batches, env_ids, params):
  """Calculates metrics for the classification task.

  Args:
    env_logits: list(dict); List of logits for examples from different
      environments (env_logits[0] is the logits for examples from env 0,
      which is a float array of shape `[batch, length, num_classes]`).
    env_batches: list(dict); List of batches of examples from different
      environments (env_batches[0] is a batch dict for examples from env 0
      that has 'label' and optionally 'weights').
    env_ids: list(int); List of environment codes.
    params: pytree; Parameters of the model.

  Returns:
    A dict of metrics.
  """
  metrics_dic = {}
  envs_metrics_dic = {}

  # Add all the keys to envs_metrics_dic; each key will point to a list of
  # values of the corresponding metric, one per environment.
  # Task related metrics:
  for key in self._METRICS:
    envs_metrics_dic[key] = []
  # Dataset related metrics (e.g., perturbation factors):
  for key in env_batches[0]:
    if 'factor' in key:
      envs_metrics_dic[key] = []

  for i in range(len(env_logits)):
    logits = env_logits[i]
    batch = env_batches[i]
    env_name = self.dataset.get_full_env_name(self.dataset.id2env(env_ids[i]))
    env_metric_dic = super().metrics_fn(logits, batch)
    for key in env_metric_dic:
      metrics_dic[env_name + '/' + key] = env_metric_dic[key]
      envs_metrics_dic[key].append(env_metric_dic[key])

  # Add overall metric values over all environments. Each per-environment
  # metric is a (value_sum, normalizer) pair, so sum each component across
  # environments.
  for key in self._METRICS:
    env_values = jnp.array(envs_metrics_dic[key])
    metrics_dic[key] = (jnp.sum(env_values[:, 0]), jnp.sum(env_values[:, 1]))

  if params:
    metrics_dic['l2'] = metrics.l2_regularization(
        params,
        include_bias_terms=self.task_params.get('l2_for_bias', False))

  return metrics_dic
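# A minimal standalone sketch of the (value_sum, normalizer) accumulation done
# in metrics_fn above, assuming each per-environment metric entry is a pair of
# (weighted sum, example count) as produced by the parent class. The metric
# values and shapes below are hypothetical, for illustration only.
import jax.numpy as jnp

# Per-environment 'accuracy' entries: (#correct predictions, #examples).
env_accuracies = [(jnp.array(48.), jnp.array(64.)),
                  (jnp.array(50.), jnp.array(64.))]

stacked = jnp.array(env_accuracies)  # Shape: [num_envs, 2].
overall = (jnp.sum(stacked[:, 0]), jnp.sum(stacked[:, 1]))
# The overall scalar metric is the ratio of the summed numerators and
# normalizers, e.g. overall accuracy = 98 / 128.
print(overall[0] / overall[1])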
def loss_function(self, env_logits, env_batches, model_params=None, step=0):
  """Returns the total loss with an L2 penalty on the weights.

  Args:
    env_logits: list(dict); List of logits for examples from different
      environments (env_logits[0] is the logits for examples from env 0).
    env_batches: list(dict); List of batches of examples from different
      environments (env_batches[0] is a batch dict for examples from env 0).
    model_params: dict; Parameters of the model (used to compute the L2
      loss).
    step: int; Global training step.

  Returns:
    Total loss.
  """
  env_losses = self.get_env_losses(env_logits, env_batches)
  total_loss = self.aggregate_envs_losses(env_losses)

  p_weight = self.penalty_weight(step)
  total_loss += p_weight * self.environments_penalties(env_logits, env_batches)

  if model_params:
    l2_decay_rate = self.get_l2_rate(step)
    if l2_decay_rate is not None:
      l2_loss = metrics.l2_regularization(
          model_params,
          include_bias_terms=self.task_params.get('l2_for_bias', False))
      total_loss = total_loss + 0.5 * l2_decay_rate * l2_loss
    if self.regularisers:
      for reg_fn in self.regularisers:
        total_loss += reg_fn(model_params)

  # If p_weight > 1, rescale the entire loss to keep gradients in a
  # reasonable range.
  total_loss /= jnp.maximum(p_weight, 1)

  return total_loss
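# A minimal sketch of the penalty weighting and rescaling scheme used in
# loss_function above, assuming a step-scheduled penalty weight. The schedule
# (scheduled_penalty_weight) and the numbers are hypothetical; the actual
# schedule lives in self.penalty_weight. Once the penalty weight exceeds 1,
# dividing the total loss by it keeps the loss magnitude (and the gradients)
# in a stable range.
import jax.numpy as jnp

def scheduled_penalty_weight(step, warmup_steps=100, final_weight=1e4):
  # Hypothetical schedule: small penalty weight before warmup, large after.
  return jnp.where(step < warmup_steps, 1.0, final_weight)

env_loss, penalty = jnp.array(0.7), jnp.array(0.01)
p_weight = scheduled_penalty_weight(step=200)
total = (env_loss + p_weight * penalty) / jnp.maximum(p_weight, 1)
print(total)  # Dominated by the penalty term once p_weight is large.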
def loss_function(self, logits, batch, model_params=None, step=None):
  """Returns the cross entropy loss with an L2 penalty on the weights."""
  weights = batch.get('weights')
  if self.dataset.meta_data['num_classes'] == 1:
    # If this is a binary classification task, make sure the shape of the
    # labels is (bs, 1) and matches the shape of the logits.
    targets = jnp.reshape(batch['label'], logits.shape)
  elif batch['label'].shape[-1] == self.dataset.meta_data['num_classes']:
    # If the labels already have shape (bs, num_classes), use them as is.
    targets = batch['label']
  else:
    # Otherwise, convert the labels to one-hot labels.
    targets = common_utils.onehot(batch['label'], logits.shape[-1])

  loss_value, loss_normalizer = self.main_loss_fn(
      logits,
      targets,
      weights,
      label_smoothing=self.task_params.get('label_smoothing'))
  total_loss = loss_value / loss_normalizer

  if model_params:
    l2_decay_factor = self.get_l2_rate(step)
    if l2_decay_factor is not None:
      l2_loss = metrics.l2_regularization(
          model_params,
          include_bias_terms=self.task_params.get('l2_for_bias', False))
      total_loss = total_loss + 0.5 * l2_decay_factor * l2_loss
    if self.regularisers:
      for reg_fn in self.regularisers:
        total_loss += reg_fn(model_params)

  return total_loss
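# A sketch of a main_loss_fn that honors the (loss_value, loss_normalizer)
# contract used above: weighted softmax cross entropy with optional label
# smoothing. This is an assumption about the interface, not necessarily the
# project's actual implementation.
import jax
import jax.numpy as jnp

def weighted_xent(logits, one_hot_targets, weights=None, label_smoothing=None):
  num_classes = one_hot_targets.shape[-1]
  if label_smoothing:
    # Move `label_smoothing` mass from the true class to a uniform prior.
    one_hot_targets = (one_hot_targets * (1.0 - label_smoothing) +
                       label_smoothing / num_classes)
  log_probs = jax.nn.log_softmax(logits)
  per_example = -jnp.sum(one_hot_targets * log_probs, axis=-1)
  if weights is not None:
    # Weighted sum of losses, normalized by the total weight.
    return jnp.sum(per_example * weights), jnp.sum(weights)
  return jnp.sum(per_example), per_example.size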
def domain_mapping_loss(self, env_reps, env_batches, env_ids,
                        env_aligned_pairs_idx=None):
  """Computes the Linear Transformation Constraint loss.

  Args:
    env_reps: list; List of representations from different environments.
    env_batches: list; List of batches from different environments.
    env_ids: list(int); List of environment codes.
    env_aligned_pairs_idx: dict; Environment pair --> alignment (if None, the
      alignment is computed).

  Returns:
    Domain mapping loss (float).
  """
  mask_loss_diff_labels = self.task_params.get('mask_loss_diff_labels')
  # Get all possible (ordered) environment pairs.
  env_pairs = list(itertools.permutations(env_ids, 2))
  aux_losses = []
  l2s = []
  for pair in env_pairs:
    e1, e2 = pair
    # We only have a state_transformer for training env pairs.
    if pair not in self.state_transformers:
      logging.warning('Pair %s is not in the training pairs set.', str(pair))
    else:
      e1_index = env_ids.index(e1)
      e2_index = env_ids.index(e2)
      e1_labels = env_batches[e1_index]['label']
      e2_labels = env_batches[e2_index]['label']
      # Get representations for env1.
      e1_reps = env_reps[e1_index]
      # Get representations for env2.
      e2_reps = env_reps[e2_index]
      # Transform representations from env1.
      transformed_e1 = self.state_transformers[pair](e1_reps)

      if env_aligned_pairs_idx is None:
        aligned_pairs_idx = self.align_batches(transformed_e1, e2_reps,
                                               e1_labels, e2_labels)
      else:
        aligned_pairs_idx = env_aligned_pairs_idx[pair]

      if mask_loss_diff_labels:
        # Assign zero/one weights to each example pair based on whether their
        # labels match.
        pair_weights = jnp.float32(
            e1_labels[aligned_pairs_idx[0]] == e2_labels[aligned_pairs_idx[1]])
      else:
        pair_weights = jnp.ones_like(e1_labels, dtype='float32')

      # Compute the domain mapping loss for the environment pair:
      # Get the aligned transformed representations from env1.
      transformed_e1 = transformed_e1[aligned_pairs_idx[0]]
      # Get the corresponding representations from env2.
      e2_reps = env_reps[e2_index][aligned_pairs_idx[1]]
      # Minimize the distance between transformed reps from env1 and reps
      # from env2.
      aux_losses.append(
          jnp.mean(
              jnp.linalg.norm(transformed_e1 - e2_reps, axis=-1) *
              pair_weights))
      # Add an L2 loss on the transformer weights (to keep the transformation
      # as minimal as possible).
      l2s.append(
          metrics.l2_regularization(
              self.state_transformers[pair].params,
              include_bias_terms=self.task_params.get('l2_for_bias', False)))

  if not aux_losses:
    aux_losses = [0]
    l2s = [0]

  alpha = self.task_params.get('aux_weight', 0.)
  beta = self.task_params.get('aux_l2', 0.)
  # Return the final weighted average of the loss terms.
  return alpha * jnp.mean(jnp.array(aux_losses)) + beta * jnp.mean(
      jnp.array(l2s))
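# A standalone sketch of the masked, aligned pair distance computed in
# domain_mapping_loss above. The representations, labels, and alignment
# indices here are hypothetical; in the method, the alignment comes from
# self.align_batches or from a precomputed env_aligned_pairs_idx.
import jax.numpy as jnp

transformed_e1 = jnp.ones((4, 8))   # Transformed env1 representations.
e2_reps = jnp.zeros((4, 8))         # Env2 representations.
e1_labels = jnp.array([0, 1, 1, 0])
e2_labels = jnp.array([0, 1, 0, 0])
aligned = (jnp.array([0, 1, 2, 3]), jnp.array([0, 1, 2, 3]))

# Only aligned pairs whose labels agree contribute to the loss.
pair_weights = (
    e1_labels[aligned[0]] == e2_labels[aligned[1]]).astype(jnp.float32)
dists = jnp.linalg.norm(
    transformed_e1[aligned[0]] - e2_reps[aligned[1]], axis=-1)
print(jnp.mean(dists * pair_weights))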
def domain_mapping_loss(self, env_reps, env_batches, env_ids,
                        env_aligned_pairs_idx=None):
  """Computes the Linear Transformation Constraint loss.

  Args:
    env_reps: list; List of representations from different environments.
    env_batches: list; List of batches from different environments.
    env_ids: list(int); List of environment codes.
    env_aligned_pairs_idx: Ignored. It is only here for compatibility with
      the `loss_function` method defined in the parent class.

  Returns:
    Domain mapping scalar loss (averaged over all environment pairs).
  """
  del env_aligned_pairs_idx
  # Get all possible (ordered) environment pairs.
  env_pairs = list(itertools.permutations(env_ids, 2))
  aux_losses = []
  l2s = []
  for pair in env_pairs:
    e1, e2 = pair
    # We only have a state_transformer for training env pairs.
    if pair not in self.state_transformers:
      logging.warning('Pair %s is not in the training pairs set.', str(pair))
    else:
      e1_index = env_ids.index(e1)
      e2_index = env_ids.index(e2)
      e1_labels = env_batches[e1_index]['label']
      e2_labels = env_batches[e2_index]['label']
      # Get representations for env1.
      e1_reps = env_reps[e1_index]
      # Get representations for env2.
      e2_reps = env_reps[e2_index]
      # Transform representations from env1.
      transformed_e1 = self.state_transformers[pair](e1_reps)
      # Compute the optimal transport cost between the transformed env1
      # representations and the env2 representations.
      ot_cost = self.ot_loss(transformed_e1, e2_reps, e1_labels, e2_labels)
      aux_losses.append(ot_cost)
      # Add an L2 loss on the transformer weights (to keep the transformation
      # as minimal as possible).
      l2s.append(
          metrics.l2_regularization(
              self.state_transformers[pair].params,
              include_bias_terms=self.task_params.get('l2_for_bias', False)))

  if not aux_losses:
    aux_losses = [0]
    l2s = [0]

  alpha = self.task_params.get('aux_weight', 0.)
  beta = self.task_params.get('aux_l2', 0.)
  # Return the final weighted average of the loss terms.
  return alpha * jnp.mean(jnp.array(aux_losses)) + beta * jnp.mean(
      jnp.array(l2s))
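# A minimal sketch of an entropy-regularized optimal transport cost between
# two sets of representations, in the spirit of the self.ot_loss call above.
# The actual ot_loss may condition on the labels and use a different solver;
# this plain Sinkhorn iteration with uniform marginals is an illustrative
# assumption, not the project's implementation.
import jax.numpy as jnp

def sinkhorn_cost(x, y, epsilon=0.1, n_iters=50):
  # Squared Euclidean cost matrix between the two point clouds.
  cost = jnp.sum((x[:, None, :] - y[None, :, :])**2, axis=-1)
  a = jnp.full((x.shape[0],), 1.0 / x.shape[0])  # Uniform source marginal.
  b = jnp.full((y.shape[0],), 1.0 / y.shape[0])  # Uniform target marginal.
  kernel = jnp.exp(-cost / epsilon)
  u = jnp.ones_like(a)
  for _ in range(n_iters):
    v = b / (kernel.T @ u)
    u = a / (kernel @ v)
  transport_plan = u[:, None] * kernel * v[None, :]
  return jnp.sum(transport_plan * cost)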