Example #1
    def metrics_fn(self, env_logits, env_batches, env_ids, params):
        """Calculates metrics for the classification task.

    Args:
      env_logits: list(array); List of logits for examples from different
        environments (env_logits[0] is the logits for examples from env 0,
        which is a float array of shape `[batch, length, num_classes]`).
      env_batches: list(dict); List of batches of examples from different
        environments (env_batches[0] is a batch dict for examples from env 0
        that has 'label' and optionally 'weights').
      env_ids: list(int); List of environment codes.
      params: pytree; Parameters of the model.

    Returns:
      a dict of metrics.
    """
        metrics_dic = {}
        envs_metrics_dic = {}
        # Add all the keys to envs_metrics_dic; each key will point to a list
        # of values from the corresponding metric for each environment.

        # Task related metrics
        for key in self._METRICS:
            envs_metrics_dic[key] = []

        # Dataset related metrics (e.g., perturbation factors)
        for key in env_batches[0]:
            if 'factor' in key:
                envs_metrics_dic[key] = []

        for i in range(len(env_logits)):
            logits = env_logits[i]
            batch = env_batches[i]
            env_name = self.dataset.get_full_env_name(
                self.dataset.id2env(env_ids[i]))
            env_metric_dic = super().metrics_fn(logits, batch)
            for key in env_metric_dic:
                metrics_dic[env_name + '/' + key] = env_metric_dic[key]
                envs_metrics_dic[key].append(env_metric_dic[key])
        # Add overall metric values aggregated over all environments.
        for key in self._METRICS:
            env_values = jnp.array(envs_metrics_dic[key])  # shape [num_envs, 2]
            metrics_dic[key] = (jnp.sum(env_values[:, 0]),
                                jnp.sum(env_values[:, 1]))
        if params:
            metrics_dic['l2'] = metrics.l2_regularization(
                params,
                include_bias_terms=self.task_params.get('l2_for_bias', False))

        return metrics_dic
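
Each per-environment metric above is a (sum, normalizer) pair, so the overall value is obtained by summing both components across environments and dividing only when the metric is reported. A minimal, self-contained sketch of this aggregation (the metric values below are hypothetical):

import jax.numpy as jnp

# Per-environment (sum, normalizer) pairs, e.g. (num_correct, num_examples).
env_accuracy = [(30.0, 32.0), (28.0, 32.0), (31.0, 32.0)]

pairs = jnp.array(env_accuracy)           # shape [num_envs, 2]
total_correct = jnp.sum(pairs[:, 0])      # 89.0
total_examples = jnp.sum(pairs[:, 1])     # 96.0
overall_accuracy = total_correct / total_examples  # ~0.927

Summing numerators and normalizers separately (rather than averaging per-environment ratios) weights every example equally regardless of environment size.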
Example #2
    def loss_function(self,
                      env_logits,
                      env_batches,
                      model_params=None,
                      step=0):
        """Returns loss with an L2 penalty on the weights.

    Args:
      env_logits: list(array); List of logits for examples from different
        environments (env_logits[0] is the logits for examples from env 0).
      env_batches: list(dict); List of batches of examples from different
        environments (env_batches[0] is a batch dict for examples from env 0).
      model_params: dict; Parameters of the model (used to compute l2).
      step: int; Global training step.

    Returns:
      Total loss.
    """
        env_losses = self.get_env_losses(env_logits, env_batches)
        total_loss = self.aggregate_envs_losses(env_losses)
        p_weight = self.penalty_weight(step)
        total_loss += p_weight * self.environments_penalties(
            env_logits, env_batches)

        if model_params:
            l2_decay_rate = self.get_l2_rate(step)
            if l2_decay_rate is not None:
                l2_loss = metrics.l2_regularization(
                    model_params,
                    include_bias_terms=self.task_params.get(
                        'l2_for_bias', False))
                total_loss = total_loss + 0.5 * l2_decay_rate * l2_loss

            if self.regularisers:
                for reg_fn in self.regularisers:
                    reg_value = reg_fn(model_params)
                    total_loss += reg_value

        # If p_weight > 1, rescale the entire loss to keep gradients in a
        # reasonable range.
        total_loss /= jnp.maximum(p_weight, 1)

        return total_loss
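
The final division by jnp.maximum(p_weight, 1) keeps the loss magnitude roughly constant once the penalty weight ramps past 1. A toy sketch of the effect, with a hypothetical step-based schedule (penalty_anneal_step and the weight values are assumptions, not taken from this code):

import jax.numpy as jnp

def penalty_weight(step, penalty_anneal_step=500, final_weight=1e4):
    # Hypothetical schedule: small weight during warm-up, large afterwards.
    return jnp.where(step < penalty_anneal_step, 1.0, final_weight)

def combined_loss(erm_loss, penalty, step):
    p_weight = penalty_weight(step)
    loss = erm_loss + p_weight * penalty
    # Rescale so the total does not blow up once p_weight >> 1.
    return loss / jnp.maximum(p_weight, 1)

print(combined_loss(0.7, 0.01, step=100))   # 0.71   (warm-up)
print(combined_loss(0.7, 0.01, step=1000))  # ~0.0101 (penalty-dominated)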
Example #3
    def loss_function(self, logits, batch, model_params=None, step=None):
        """Return cross entropy loss with an L2 penalty on the weights."""
        weights = batch.get('weights')

        if self.dataset.meta_data['num_classes'] == 1:
            # If this is a binary classification task, make sure the shape of labels
            # is (bs, 1) and is the same as the shape of logits.
            targets = jnp.reshape(batch['label'], logits.shape)
        elif batch['label'].shape[-1] == self.dataset.meta_data['num_classes']:
            # If the labels are already the shape of (bs, num_classes) use them as is.
            targets = batch['label']
        else:
            # Otherwise convert the labels to onehot labels.
            targets = common_utils.onehot(batch['label'], logits.shape[-1])

        loss_value, loss_normalizer = self.main_loss_fn(
            logits,
            targets,
            weights,
            label_smoothing=self.task_params.get('label_smoothing'))

        total_loss = loss_value / loss_normalizer

        if model_params:
            l2_decay_factor = self.get_l2_rate(step)
            if l2_decay_factor is not None:
                l2_loss = metrics.l2_regularization(
                    model_params,
                    include_bias_terms=self.task_params.get(
                        'l2_for_bias', False))
                total_loss = total_loss + 0.5 * l2_decay_factor * l2_loss

            if self.regularisers:
                for reg_fn in self.regularisers:
                    total_loss += reg_fn(model_params)

        return total_loss
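
self.main_loss_fn evidently returns a pair (loss_value, loss_normalizer) that the caller divides once at the end. A hedged sketch of a weighted softmax cross-entropy with that contract (label smoothing is omitted for brevity; the function name is illustrative):

import jax
import jax.numpy as jnp

def weighted_softmax_cross_entropy(logits, targets, weights=None):
    """Returns (summed loss, normalizer); the caller divides once."""
    log_probs = jax.nn.log_softmax(logits)
    # Cross entropy per example from one-hot (or smoothed) targets.
    per_example = -jnp.sum(targets * log_probs, axis=-1)
    if weights is None:
        # Normalizer is simply the number of loss terms.
        return jnp.sum(per_example), per_example.size
    # Normalizer is the total example weight, guarded against zero.
    return jnp.sum(per_example * weights), jnp.maximum(jnp.sum(weights), 1e-8)

Keeping the sum and the normalizer separate is what lets metrics_fn in example #1 aggregate across environments before dividing.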
Example #4
    def domain_mapping_loss(self,
                            env_reps,
                            env_batches,
                            env_ids,
                            env_aligned_pairs_idx=None):
        """Compute Linear Transformation Constraint loss.

    Args:
      env_reps: list; List of representations from different envs.
      env_batches: list; List of batches from different envs.
      env_ids: list(int); List of environment codes.
      env_aligned_pairs_idx: dict; Environment pair --> alignment (if None, the
        alignment is computed).

    Returns:
      Domain mapping loss (float).
    """
        mask_loss_diff_labels = self.task_params.get('mask_loss_diff_labels')

        # Get all possible environment pairs
        env_pairs = list(itertools.permutations(env_ids, 2))
        aux_losses = []
        l2s = []
        for pair in env_pairs:
            e1, e2 = pair
            # We only have state_transformer for training envs.
            if pair not in self.state_transformers:
                logging.warning('Pair %s is not in the training pairs set.',
                                str(pair))
            else:
                e1_index = env_ids.index(e1)
                e2_index = env_ids.index(e2)
                e1_labels = env_batches[e1_index]['label']
                e2_labels = env_batches[e2_index]['label']
                # Get representations for env1.
                e1_reps = env_reps[e1_index]
                # Get representations for env2.
                e2_reps = env_reps[e2_index]
                # Transform representations from env1.
                transformed_e1 = self.state_transformers[pair](e1_reps)

                if env_aligned_pairs_idx is None:
                    aligned_pairs_idx = self.align_batches(
                        transformed_e1, e2_reps, e1_labels, e2_labels)
                else:
                    aligned_pairs_idx = env_aligned_pairs_idx[pair]

                if mask_loss_diff_labels:
                    # Assign zero/one weights to each example pair based on the alignment
                    # of their labels.
                    pair_weights = jnp.float32(
                        e1_labels[aligned_pairs_idx[0]] == e2_labels[
                            aligned_pairs_idx[1]])
                else:
                    pair_weights = jnp.ones_like(e1_labels, dtype='float32')

                # Compute domain mapping loss for the environment pair:
                # Select the aligned transformed representations from env1.
                transformed_e1 = transformed_e1[aligned_pairs_idx[0]]
                # Get corresponding representations for env2.
                e2_reps = env_reps[e2_index][aligned_pairs_idx[1]]

                # Minimize the distance between transformed reps from env1 and reps
                # from env2.
                aux_losses.append(
                    jnp.mean(
                        jnp.linalg.norm(transformed_e1 - e2_reps, axis=-1) *
                        pair_weights))

                # Add an l2 loss on the transformer weights (to keep the
                # learned transformation as small as possible).
                l2s.append(
                    metrics.l2_regularization(
                        self.state_transformers[pair].params,
                        include_bias_terms=self.task_params.get(
                            'l2_for_bias', False)))

        if not aux_losses:
            aux_losses = [0]
            l2s = [0]

        alpha = self.task_params.get('aux_weight', 0.0)
        beta = self.task_params.get('aux_l2', 0.0)

        # Average and return the final weighted value of the loss.
        return alpha * jnp.mean(jnp.array(aux_losses)) + beta * jnp.mean(
            jnp.array(l2s))
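
metrics.l2_regularization is used throughout these examples. A plausible pytree-based sketch of its contract (an assumption about the helper, not its actual source; biases are identified by a rank-1 heuristic):

import jax
import jax.numpy as jnp

def l2_regularization(params, include_bias_terms=False):
    # Sum of squared entries over all leaves of the parameter pytree.
    leaves = jax.tree_util.tree_leaves(params)
    if not include_bias_terms:
        # Heuristic assumption: rank-1 leaves are bias vectors.
        leaves = [p for p in leaves if p.ndim > 1]
    return sum(jnp.sum(jnp.square(p)) for p in leaves)

params = {'dense': {'kernel': jnp.ones((3, 2)), 'bias': jnp.ones((2,))}}
print(l2_regularization(params))                           # 6.0 (kernel only)
print(l2_regularization(params, include_bias_terms=True))  # 8.0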
Example #5
    def domain_mapping_loss(self,
                            env_reps,
                            env_batches,
                            env_ids,
                            env_aligned_pairs_idx=None):
        """Compute Linear Transformation Constraint loss.

    Args:
      env_reps: list; List of representations from different envs.
      env_batches: list; List of batches from different envs.
      env_ids: list(int); List of environment codes.
      env_aligned_pairs_idx: Ignored; it exists only to keep the signature
        compatible with "loss_function" as defined in the parent class.

    Returns:
      Domain mapping scalar loss (averaged over all environments).
    """
        del env_aligned_pairs_idx

        # Get all possible environment pairs
        env_pairs = list(itertools.permutations(env_ids, 2))
        aux_losses = []
        l2s = []
        for pair in env_pairs:
            e1, e2 = pair
            # We only have state_transformer for training envs.
            if pair not in self.state_transformers:
                logging.warning('Pair %s is not in the training pairs set.',
                                str(pair))
            else:
                e1_index = env_ids.index(e1)
                e2_index = env_ids.index(e2)
                e1_labels = env_batches[e1_index]['label']
                e2_labels = env_batches[e2_index]['label']
                # Get representations for env1.
                e1_reps = env_reps[e1_index]
                # Get representations for env2.
                e2_reps = env_reps[e2_index]
                # Transform representations from env1.
                transformed_e1 = self.state_transformers[pair](e1_reps)

                ot_cost = self.ot_loss(transformed_e1, e2_reps, e1_labels,
                                       e2_labels)

                aux_losses.append(ot_cost)

                # Add an l2 loss on the transformer weights (to keep the
                # learned transformation as small as possible).
                l2s.append(
                    metrics.l2_regularization(
                        self.state_transformers[pair].params,
                        include_bias_terms=self.task_params.get(
                            'l2_for_bias', False)))

        if not aux_losses:
            aux_losses = [0]
            l2s = [0]

        alpha = self.task_params.get('aux_weight', 0.0)
        beta = self.task_params.get('aux_l2', 0.0)

        # Average and return the final weighted value of the loss.
        return alpha * jnp.mean(jnp.array(aux_losses)) + beta * jnp.mean(
            jnp.array(l2s))
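
self.ot_loss is not shown in this example. As an illustration only, a minimal log-domain Sinkhorn cost between two batches of representations could look like the sketch below (a stand-in assuming a squared Euclidean ground cost and uniform marginals; the real ot_loss also receives e1_labels and e2_labels, which are ignored here):

import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp

def pairwise_sq_dists(x, y):
    # x: [n, d], y: [m, d] -> [n, m] squared Euclidean distances.
    return jnp.sum((x[:, None, :] - y[None, :, :]) ** 2, axis=-1)

def sinkhorn_cost(x, y, epsilon=0.1, num_iters=100):
    cost = pairwise_sq_dists(x, y)
    n, m = cost.shape
    log_a = jnp.full((n,), -jnp.log(n))  # uniform source marginal
    log_b = jnp.full((m,), -jnp.log(m))  # uniform target marginal
    f, g = jnp.zeros(n), jnp.zeros(m)
    for _ in range(num_iters):
        # Alternating dual updates in the log domain (numerically stable).
        f = -epsilon * logsumexp((g[None, :] - cost) / epsilon
                                 + log_b[None, :], axis=1)
        g = -epsilon * logsumexp((f[:, None] - cost) / epsilon
                                 + log_a[:, None], axis=0)
    # Recover the (approximately optimal) coupling and its transport cost.
    coupling = jnp.exp((f[:, None] + g[None, :] - cost) / epsilon
                       + log_a[:, None] + log_b[None, :])
    return jnp.sum(coupling * cost)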