Example #1
0
    def __call__(self, m_out, targets, reduce=True):
        """
        Computes 1-vs-all binary cross entropy loss for multiclass
        classification.

        Each integer label in ``targets`` is expanded into a one-hot row, and
        every column of ``m_out`` is scored with an independent sigmoid BCE
        via ``F.binary_cross_entropy_with_logits`` (so ``m_out`` holds raw
        logits, not probabilities).

        Args:
            m_out: model outputs; dim 1 is the class dimension.
            targets: class indices, one per row of ``m_out``.
            reduce: when True return the mean of the per-example losses,
                otherwise return the per-example losses themselves.

        Returns:
            Per-example loss (summed over classes), or its mean if ``reduce``.
        """
        n_rows = targets.size(0)
        n_classes = m_out.size(1)

        # One-hot encoding of the targets. Dim: [batch, n_classes]
        label_index = targets.unsqueeze(1).data
        one_hot = FloatTensor(n_rows, n_classes).zero_().scatter_(1, label_index, 1)

        # The BCE functional needs FloatTensor inputs, so the outputs are
        # routed through precision.maybe_float before computing the loss.
        per_class_loss = F.binary_cross_entropy_with_logits(
            precision.maybe_float(m_out), one_hot, reduction="none"
        )

        if self.config.reweight_negative:
            # The correct class keeps weight 1 and each of the other n - 1
            # classes gets 1 / (n - 1), so all negatives together weigh the
            # same as the single positive.
            weights = one_hot + (1.0 - one_hot) / max(1, n_classes - 1.0)
            per_class_loss = per_class_loss * weights

        per_example_loss = per_class_loss.sum(1)
        if not reduce:
            return per_example_loss
        return per_example_loss.mean()
Example #2
0
    def train_batch(cls, model, batch, state=None):
        """
        Run one forward pass over ``batch`` and compute loss and predictions.

        Written as a class method taking ``model`` explicitly so it works when
        ``model`` is a DistributedModel wrapper; otherwise the forward call
        here would skip the DDP forward call.

        Returns:
            ``(loss, metric_data)`` where ``metric_data`` is
            ``(predictions, targets, scores, loss, model_inputs)``.
        """
        inputs = model.arrange_model_inputs(batch)
        context = model.arrange_model_context(batch)
        targets = model.arrange_targets(batch)
        outputs = model(*inputs)

        # Make the current stage visible to get_loss / get_pred via context.
        if state:
            if context is None:
                context = {}
            context["stage"] = state.stage

        loss = maybe_float(model.get_loss(outputs, targets, context))
        predictions, scores = model.get_pred(outputs, context=context)

        return loss, (predictions, targets, scores, loss, inputs)
Example #3
0
    def run_step(
        self,
        samples: List[Any],
        state: TrainingState,
        metric_reporter: MetricReporter,
        report_metric: bool,
    ):
        """
        Run forward/backward over every entry of ``samples``, then take one
        optimizer step.

        ``samples`` is iterated as ``(batch_id, (inputs, targets, context))``
        entries. Gradients from all entries are accumulated and
        ``optimizer_step`` is called once at the very end.
        """
        n_samples = len(samples)
        assert n_samples <= self.config.num_accumulated_batches

        model = state.model
        self.zero_grads(state)
        for position, (batch_id, (inputs, targets, context)) in enumerate(samples):
            if cuda.DISTRIBUTED_WORLD_SIZE > 1:
                # With several mini-batches, accumulate gradients locally for
                # all but the last one so all-reduce only runs during the
                # final backward pass.
                model.accumulate_gradients(position < n_samples - 1)

            # pass context to model to use in forward call if needed
            model.contextualize(context)
            with timing.time("model.forward"):
                logits = model(*inputs)

            with timing.time("compute loss"):
                raw_loss = model.get_loss(logits, targets, context)
                loss = precision.maybe_float(raw_loss)
                if BatchContext.IGNORE_LOSS in context:
                    loss *= 0
                elif n_samples > 1:
                    # The loss is averaged per batch and gradients accumulate
                    # across samples; dividing by the number of samples keeps
                    # the gradient averaged per example.
                    loss = loss / n_samples

            self.backprop(state, loss)

            if report_metric:
                with timing.time("get pred"):
                    preds, scores = model.get_pred(
                        logits, targets, context, state.stage, *inputs
                    )
                with timing.time("add metrics"):
                    metric_reporter.add_batch_stats(
                        batch_id,
                        preds,
                        targets,
                        scores,
                        loss.item(),
                        inputs,
                        **context,
                    )

            if batch_id % self.config.num_samples_to_log_progress == 0:
                print(
                    f"Running batch {batch_id} for epoch {state.epoch} in {state.stage} stage",
                    flush=True,
                )

        # A single parameter update after all forward/backward passes.
        self.optimizer_step(state)
Example #4
0
    def run_step(
        self,
        samples: List[Any],
        state: TrainingState,
        metric_reporter: MetricReporter,
        report_metric: bool,
    ):
        """
        Run forward/backward over ``samples`` and take one optimizer step,
        invoking callback hooks along the way.

        ``samples`` is iterated as ``(batch_id, (inputs, targets, context))``
        entries. Hooks are fired as ``self(<event name>)``:

        * ``'begin_batch'``: a truthy return skips the whole step (no
          zero_grads, no optimizer/sparsification step, no 'after_batch').
        * ``'after_loss'``: fired after each backward pass, with
          ``self.samples``, ``self.state`` and ``self.loss`` set beforehand.
          A truthy return breaks out of the sample loop; metric reporting and
          progress logging for that sample are then skipped, but the
          optimizer step, sparsification step and 'after_batch' still run.
        * ``'after_batch'``: fired once at the very end.
        """
        sample_size = len(samples)
        assert sample_size <= self.config.num_accumulated_batches

        # Hook may veto the entire step.
        if self('begin_batch'): return

        model = state.model
        self.zero_grads(state)
        for idx, (batch_id, (inputs, targets, context)) in enumerate(samples):
            # NOTE(review): contextlib_ExitStack is presumably an alias of
            # contextlib.ExitStack — confirm. Whatever
            # maybe_accumulate_gradients registers on exit_stack stays active
            # through forward, backward and the 'after_loss' hook, and is
            # released before metric reporting.
            with contextlib_ExitStack() as exit_stack:
                # Presumably enters a local gradient-accumulation (no-sync)
                # context for all but the last sample — confirm in helper.
                maybe_accumulate_gradients(exit_stack, model, idx, sample_size)
                # pass context to model to use in forward call if needed
                model.contextualize(context)
                with timing.time("model.forward"):
                    logits = model(*inputs)

                with timing.time("compute loss"):
                    loss = precision.maybe_float(
                        model.get_loss(logits, targets, context))
                    if BatchContext.IGNORE_LOSS in context:
                        # Zero the loss in place so this sample contributes
                        # no gradient.
                        loss *= 0
                    elif sample_size > 1:
                        # gradients averaged per batch and accumulated across samples.
                        # divide sample_size to let gradients averaged per example
                        loss = loss / sample_size

                self.backprop(state, loss)
                # Expose current step state to the 'after_loss' hook.
                self.samples, self.state, self.loss = samples, state, loss
                # NOTE(review): breaking here still leads to optimizer_step
                # below — confirm that is the intended cancel semantics.
                if self('after_loss'): break

            if report_metric:
                with timing.time("get pred"):
                    preds, scores = model.get_pred(logits, targets, context,
                                                   state.stage, *inputs)

                with timing.time("add metrics"):
                    # The reported loss is the (possibly sample_size-scaled)
                    # value that was backpropagated.
                    metric_reporter.add_batch_stats(batch_id,
                                                    preds, targets, scores,
                                                    loss.item(), inputs,
                                                    **context)

            if batch_id % self.config.num_samples_to_log_progress == 0:
                print(
                    f"Running batch {batch_id} for epoch {state.epoch} in {state.stage} stage",
                    flush=True,
                )
        # update gradients after len(samples) forward & backward
        self.optimizer_step(state)
        self.sparsification_step(state)
        self('after_batch')
Example #5
0
    def train_batch(cls, model, batch, state=None):
        """
        Runs training over a batch with the R3F method, training will use R3F
        while eval and test do not.

        During ``Stage.TRAIN`` the model is called with ``use_r3f=True``. It
        returns clean and noise-perturbed outputs, and the extra term from
        ``model.get_r3f_loss_terms`` is added to the task loss. In any other
        stage, or when ``state`` is falsy, a plain ``use_r3f=False`` forward
        pass runs and only the task loss is returned.

        Returns:
            ``(loss, metric_data)`` where ``metric_data`` is
            ``(predictions, targets, scores, loss, model_inputs)``.
        """
        # Forward pass through the network.
        model_inputs = model.arrange_model_inputs(batch)
        model_context = model.arrange_model_context(batch)
        targets = model.arrange_targets(batch)

        sample_size = model.get_sample_size(model_inputs=model_inputs,
                                            targets=targets)

        # The R3F term exists only for the training forward pass. Tracking
        # that with a flag avoids allocating a zero tensor on the CPU and
        # copying it to the loss device on every eval/test batch.
        use_r3f = bool(state and state.stage == Stage.TRAIN)
        if use_r3f:
            # during training run R3F forward calls
            model_outputs, noise_model_outputs = model(*model_inputs,
                                                       use_r3f=True)

            r3f_loss_term = model.get_r3f_loss_terms(model_outputs,
                                                     noise_model_outputs,
                                                     sample_size=sample_size)
        else:
            # during eval and test don't run R3F forward
            model_outputs = model(*model_inputs, use_r3f=False)

        # Add stage and epoch to context.
        if state:
            if model_context is None:
                model_context = {}
            model_context["stage"] = state.stage
            model_context["epoch"] = state.epoch

        # Compute loss and predictions.
        loss = maybe_float(
            model.get_loss(model_outputs, targets, model_context))

        if use_r3f:
            # Move the R3F term onto the loss's device before adding it.
            loss = loss + r3f_loss_term.to(loss.device)

        predictions, scores = model.get_pred(model_outputs,
                                             context=model_context)

        # Pack results and return them.
        metric_data = (predictions, targets, scores, loss, model_inputs)
        return loss, metric_data
Example #6
0
    def __call__(self, logits, targets, reduce=True):
        """
        Compute 1-vs-all binary cross entropy for multiclass classification.

        Unlike BinaryCrossEntropyLoss, the labels must already be one-hot
        vectors: ``targets[0]`` is cast to float and used directly as the BCE
        target for every entry of ``logits``.

        Returns the loss summed over the last dimension, passed through
        ``.mean()`` when ``reduce`` is True.
        """
        labels = targets[0].float()

        # The BCE functional needs FloatTensor inputs, so the logits go
        # through precision.maybe_float first.
        elementwise_loss = F.binary_cross_entropy_with_logits(
            precision.maybe_float(logits), labels, reduction="none"
        )

        summed = elementwise_loss.sum(-1)
        if not reduce:
            return summed
        return summed.mean()
Example #7
0
    def __call__(self, m_out, targets, reduce=True):
        """
        Computes multi-label classification loss
        see details in torch.nn.MultiLabelSoftMarginLoss

        Args:
            m_out: model scores; dim 1 is the class dimension.
            targets: ``targets[0]`` holds, for each example, a list of class
                indices padded with -1 so every row has the same length.
            reduce: if True (default), return the loss averaged over the
                mini-batch; if False, return one loss per example (shape
                [batch]), matching the other losses' ``reduce`` contract.

        Returns:
            A scalar tensor when ``reduce`` is True, otherwise a 1-D tensor of
            per-example losses.
        """
        num_classes = m_out.size(1)
        target_labels = targets[0]

        # Each label list is padded with -1 so that every example has the
        # same number of labels. -1 is out of the index range for scatter,
        # so shift all labels by +1: padding lands in column 0 and label k
        # lands in column k + 1.
        shifted_labels = target_labels + 1

        # Multi-hot ("n-hot") encoding, similar to a one-hot encoding but
        # supporting several labels per example. Column 0 only collects the
        # padding, so it is dropped, leaving [batch, num_classes].
        n_hot_targets = (
            FloatTensor(target_labels.size(0), num_classes + 1)
            .zero_()
            .scatter_(1, shifted_labels, 1)
        )[:, 1:]

        # F.multilabel_soft_margin_loss requires FloatTensor inputs, so the
        # scores go through precision.maybe_float. All classes are weighted
        # equally. The reduction honors `reduce`, which was previously ignored.
        loss = F.multilabel_soft_margin_loss(
            precision.maybe_float(m_out),
            n_hot_targets,
            reduction="mean" if reduce else "none",
        )

        return loss
Example #8
0
 def get_loss(self, logit, target, context):
     """Delegate to ``self.output_layer.get_loss`` and pass the result
     through ``maybe_float``."""
     raw_loss = self.output_layer.get_loss(logit, target, context)
     return maybe_float(raw_loss)