def __call__(self, m_out, targets, reduce=True):
    """
    Computes 1-vs-all binary cross entropy loss for multiclass classification.

    Args:
        m_out: raw model scores (logits); dimension 1 is the class dimension.
        targets: one class index per example (used as a scatter index, so
            presumably an integer tensor of shape [batch] -- confirm).
        reduce: if True, return the batch mean of the per-example loss;
            otherwise return the per-example loss.

    Returns:
        Per-example loss summed over classes, optionally averaged over the batch.
    """
    batch_size, num_classes = targets.size(0), m_out.size(1)

    # One-hot encoding of the targets, shape [batch, n_classes].
    one_hot_targets = FloatTensor(batch_size, num_classes).zero_()
    one_hot_targets.scatter_(1, targets.unsqueeze(1).data, 1)

    # `F.binary_cross_entropy_with_logits` expects float inputs, hence
    # maybe_float on the logits. Elementwise loss, shape [batch, n_classes].
    loss = F.binary_cross_entropy_with_logits(
        precision.maybe_float(m_out), one_hot_targets, reduction="none"
    )

    if self.config.reweight_negative:
        # Weight 1 for the correct class and 1 / (n - 1) for each other class,
        # so all negatives together weigh the same as the single positive.
        negative_share = (1.0 - one_hot_targets) / max(
            1, one_hot_targets.size(1) - 1.0
        )
        loss = loss * (one_hot_targets + negative_share)

    per_example_loss = loss.sum(1)
    if reduce:
        return per_example_loss.mean()
    return per_example_loss
def train_batch(cls, model, batch, state=None):
    """
    Run the forward pass on one batch and compute its loss and predictions.

    This receives ``cls`` rather than an instance so that ``model`` may be a
    DistributedModel wrapper; calling ``model(...)`` here goes through the DDP
    forward call instead of skipping it.

    Args:
        model: the (possibly wrapped) model; supplies the arrange_* helpers,
            get_loss and get_pred.
        batch: the raw batch handed to the model's arrange_* helpers.
        state: optional training state; when given, its ``stage`` is written
            into the model context under the "stage" key.

    Returns:
        ``(loss, metric_data)`` where ``metric_data`` is
        ``(predictions, targets, scores, loss, model_inputs)``.
    """
    inputs = model.arrange_model_inputs(batch)
    context = model.arrange_model_context(batch)
    targets = model.arrange_targets(batch)
    outputs = model(*inputs)

    # Record the current stage in the context (creating it if necessary).
    if state:
        if context is None:
            context = {}
        context["stage"] = state.stage

    loss = maybe_float(model.get_loss(outputs, targets, context))
    predictions, scores = model.get_pred(outputs, context=context)
    return loss, (predictions, targets, scores, loss, inputs)
def run_step(
    self,
    samples: List[Any],
    state: TrainingState,
    metric_reporter: MetricReporter,
    report_metric: bool,
):
    """
    Forward/backward every mini-batch in ``samples``, then take one optimizer step.

    Gradients are zeroed once before the loop and the optimizer is stepped once
    after it, so the backward passes of all mini-batches accumulate.

    Args:
        samples: entries of the form ``(batch_id, (inputs, targets, context))``;
            at most ``config.num_accumulated_batches`` of them.
        state: training state; ``state.model`` is the model being trained.
        metric_reporter: receives per-batch stats when ``report_metric`` is set.
        report_metric: whether to compute predictions and report metrics.
    """
    num_samples = len(samples)
    assert num_samples <= self.config.num_accumulated_batches

    model = state.model
    self.zero_grads(state)
    last_index = num_samples - 1
    for index, (batch_id, (inputs, targets, context)) in enumerate(samples):
        if cuda.DISTRIBUTED_WORLD_SIZE > 1:
            # With several mini-batches, accumulate gradients locally and only
            # all-reduce on the final backward pass.
            model.accumulate_gradients(index < last_index)

        # Hand the context to the model in case its forward call needs it.
        model.contextualize(context)
        with timing.time("model.forward"):
            logits = model(*inputs)

        with timing.time("compute loss"):
            loss = precision.maybe_float(model.get_loss(logits, targets, context))
            if BatchContext.IGNORE_LOSS in context:
                loss *= 0
            elif num_samples > 1:
                # Each batch's gradient is a per-batch average; dividing by the
                # number of accumulated batches makes the summed gradient a
                # per-example average.
                loss = loss / num_samples

        self.backprop(state, loss)

        if report_metric:
            with timing.time("get pred"):
                preds, scores = model.get_pred(
                    logits, targets, context, state.stage, *inputs
                )
            with timing.time("add metrics"):
                metric_reporter.add_batch_stats(
                    batch_id, preds, targets, scores, loss.item(), inputs, **context
                )

        if batch_id % self.config.num_samples_to_log_progress == 0:
            print(
                f"Running batch {batch_id} for epoch {state.epoch} in {state.stage} stage",
                flush=True,
            )

    # A single parameter update after all forward & backward passes.
    self.optimizer_step(state)
def run_step(
    self,
    samples: List[Any],
    state: TrainingState,
    metric_reporter: MetricReporter,
    report_metric: bool,
):
    """
    Forward/backward every mini-batch in ``samples``, then update the model once.

    Callback hooks are fired through ``self(<event>)``:
      * ``'begin_batch'``: a truthy result returns immediately, before
        gradients are zeroed;
      * ``'after_loss'``: fired after each backward pass, with ``self.samples``,
        ``self.state`` and ``self.loss`` set; a truthy result stops processing
        the remaining mini-batches (the optimizer and sparsification steps
        still run);
      * ``'after_batch'``: fired once after the optimizer and sparsification
        steps.

    Args:
        samples: entries of the form ``(batch_id, (inputs, targets, context))``;
            at most ``config.num_accumulated_batches`` of them.
        state: training state; ``state.model`` is the model being trained.
        metric_reporter: receives per-batch stats when ``report_metric`` is set.
        report_metric: whether to compute predictions and report metrics.
    """
    num_samples = len(samples)
    assert num_samples <= self.config.num_accumulated_batches
    if self('begin_batch'):
        return

    model = state.model
    self.zero_grads(state)
    for index, (batch_id, (inputs, targets, context)) in enumerate(samples):
        with contextlib_ExitStack() as stack:
            # NOTE(review): presumably enters a "don't sync gradients yet"
            # context on `stack` for all but the last mini-batch -- confirm
            # against maybe_accumulate_gradients.
            maybe_accumulate_gradients(stack, model, index, num_samples)

            # Hand the context to the model in case its forward call needs it.
            model.contextualize(context)
            with timing.time("model.forward"):
                logits = model(*inputs)

            with timing.time("compute loss"):
                loss = precision.maybe_float(
                    model.get_loss(logits, targets, context)
                )
                if BatchContext.IGNORE_LOSS in context:
                    loss *= 0
                elif num_samples > 1:
                    # Per-batch averaged gradients, accumulated across samples;
                    # dividing by the sample count averages them per example.
                    loss = loss / num_samples

            self.backprop(state, loss)

            # Expose the current step to callbacks before firing 'after_loss'.
            self.samples, self.state, self.loss = samples, state, loss
            if self('after_loss'):
                break

        if report_metric:
            with timing.time("get pred"):
                preds, scores = model.get_pred(
                    logits, targets, context, state.stage, *inputs
                )
            with timing.time("add metrics"):
                metric_reporter.add_batch_stats(
                    batch_id, preds, targets, scores, loss.item(), inputs, **context
                )

        if batch_id % self.config.num_samples_to_log_progress == 0:
            print(
                f"Running batch {batch_id} for epoch {state.epoch} in {state.stage} stage",
                flush=True,
            )

    # A single parameter update after all forward & backward passes.
    self.optimizer_step(state)
    self.sparsification_step(state)
    self('after_batch')
def train_batch(cls, model, batch, state=None):
    """
    Run one batch with the R3F method. Only the TRAIN stage uses the R3F forward
    pass; eval and test run a plain forward pass and add a zero R3F term.

    Args:
        model: the (possibly wrapped) model; supplies the arrange_* helpers,
            get_sample_size, get_r3f_loss_terms, get_loss and get_pred.
        batch: the raw batch handed to the model's arrange_* helpers.
        state: optional training state; its ``stage`` selects the R3F path, and
            ``stage``/``epoch`` are written into the model context.

    Returns:
        ``(loss, metric_data)`` where ``metric_data`` is
        ``(predictions, targets, scores, loss, model_inputs)``.
    """
    inputs = model.arrange_model_inputs(batch)
    context = model.arrange_model_context(batch)
    targets = model.arrange_targets(batch)
    sample_size = model.get_sample_size(model_inputs=inputs, targets=targets)

    training = bool(state) and state.stage == Stage.TRAIN
    if training:
        # Training: R3F forward yields clean and noise-perturbed outputs.
        outputs, noisy_outputs = model(*inputs, use_r3f=True)
        r3f_loss_term = model.get_r3f_loss_terms(
            outputs, noisy_outputs, sample_size=sample_size
        )
    else:
        # Eval / test: plain forward pass and no R3F contribution.
        outputs = model(*inputs, use_r3f=False)
        r3f_loss_term = torch.tensor(0)

    # Record the current stage and epoch in the context (creating it if needed).
    if state:
        if context is None:
            context = {}
        context["stage"] = state.stage
        context["epoch"] = state.epoch

    base_loss = maybe_float(model.get_loss(outputs, targets, context))
    loss = base_loss + r3f_loss_term.to(base_loss.device)
    predictions, scores = model.get_pred(outputs, context=context)
    return loss, (predictions, targets, scores, loss, inputs)
def __call__(self, logits, targets, reduce=True):
    """
    Computes 1-vs-all binary cross entropy loss for multiclass classification.

    Unlike BinaryCrossEntropyLoss, ``targets[0]`` must already be a one-hot
    (or multi-hot) vector per example, not a class index.

    Args:
        logits: raw model scores; the last dimension is the class dimension.
        targets: sequence whose first element holds the one-hot targets.
        reduce: if True, return the batch mean of the per-example loss;
            otherwise return the per-example loss.
    """
    # `F.binary_cross_entropy_with_logits` needs float tensors on both sides.
    label_matrix = targets[0].float()
    elementwise = F.binary_cross_entropy_with_logits(
        precision.maybe_float(logits), label_matrix, reduction="none"
    )
    per_example_loss = elementwise.sum(-1)
    if reduce:
        return per_example_loss.mean()
    return per_example_loss
def __call__(self, m_out, targets, reduce=True):
    """
    Computes multi-label classification loss; see
    ``torch.nn.MultiLabelSoftMarginLoss`` for details.

    Args:
        m_out: model scores; dimension 1 is the class dimension.
        targets: ``targets[0]`` holds, per example, the positive label indices,
            padded with -1 so every example has a label list of the same length.
        reduce: if True (default), return the mean loss over the batch;
            otherwise return one loss value per example. Previously this flag
            was ignored and the mean was always returned.

    Returns:
        Scalar mean loss when ``reduce`` is True, else the per-example losses.
    """
    num_classes = m_out.size()[1]
    target_labels = targets[0]

    # -1 (padding) is not a valid scatter index, so shift every label by +1:
    # padding lands in column 0, which is sliced off below.
    shifted_labels = target_labels + 1

    # n-hot encoding (several positive labels per example are allowed). The
    # extra leading column only collects the padding and is dropped.
    n_hot_targets = (
        FloatTensor(target_labels.size(0), num_classes + 1)
        .zero_()
        .scatter_(1, shifted_labels, 1)
    )[:, 1:]

    # `F.multilabel_soft_margin_loss` needs float inputs. All classes are
    # weighted equally. With reduction="none" it returns one value per example
    # (averaged over classes); "mean" averages those over the batch.
    loss = F.multilabel_soft_margin_loss(
        precision.maybe_float(m_out),
        n_hot_targets,
        reduction="mean" if reduce else "none",
    )
    return loss
def get_loss(self, logit, target, context):
    """
    Delegate loss computation to ``self.output_layer`` and pass the result
    through ``maybe_float``.
    """
    raw_loss = self.output_layer.get_loss(logit, target, context)
    return maybe_float(raw_loss)