Пример #1
0
    def finalize_metrics(self, ks=(1, 5)):
        """
        Compute the final video-level top-k accuracies for verbs and nouns,
        log them, and hand back copies of the accumulated results.

        Args:
            ks (tuple): top-k values forwarded to metrics.topk_accuracies.
                For example, ks = (1, 5) corresponds to top-1 and top-5
                accuracy.

        Returns:
            tuple: ((verb_preds, noun_preds), (verb_labels, noun_labels),
                metadata). Predictions and labels are copied NumPy arrays;
                metadata is a copy of self.metadata.
        """
        # Warn (but keep going) when some entry of clip_count differs from
        # num_clips.
        clips_complete = all(self.clip_count == self.num_clips)
        if not clips_complete:
            mismatch_msg = "clip count {} ~= num clips {}".format(
                self.clip_count, self.num_clips)
            logger.warning(mismatch_msg)
            logger.warning(self.clip_count)

        verb_topks = metrics.topk_accuracies(
            self.verb_video_preds, self.verb_video_labels, ks)
        noun_topks = metrics.topk_accuracies(
            self.noun_video_preds, self.noun_video_labels, ks)

        # Exactly one accuracy value is expected per requested k.
        assert len(ks) == len(verb_topks)
        assert len(ks) == len(noun_topks)

        stats = {"split": "test_final"}
        per_task = (
            ("verb_top{}_acc", verb_topks),
            ("noun_top{}_acc", noun_topks),
        )
        for key_template, topks in per_task:
            for k, topk in zip(ks, topks):
                # Accuracies are stored as strings with two decimals.
                stats[key_template.format(k)] = "{:.{prec}f}".format(
                    topk, prec=2)
        logging.log_json_stats(stats)

        preds_out = (self.verb_video_preds.numpy().copy(),
                     self.noun_video_preds.numpy().copy())
        labels_out = (self.verb_video_labels.numpy().copy(),
                      self.noun_video_labels.numpy().copy())
        return preds_out, labels_out, self.metadata.copy()
Пример #2
0
def train_epoch(train_loader, model, optimizer, train_meter, cur_epoch, cfg):
    """
    Run one epoch of video training.

    For every batch: move inputs/labels to the GPU, set the learning rate,
    run the forward pass, compute the loss, back-propagate, step the
    optimizer, and report statistics to train_meter.

    Args:
        train_loader (loader): video training loader.
        model (model): the video model to train.
        optimizer (optim): the optimizer to perform optimization on the model's
            parameters.
        train_meter (TrainMeter): training meters to log the training performance.
        cur_epoch (int): current epoch of training.
        cfg (CfgNode): configs. Details can be found in
            slowfast/config/defaults.py
    """
    # Put the model in train mode.
    model.train()
    if cfg.BN.FREEZE:
        model.freeze_fn('bn_statistics')

    train_meter.iter_tic()
    data_size = len(train_loader)

    for cur_iter, (inputs, labels, _, meta) in enumerate(train_loader):
        # Move inputs and labels to the current GPU device. Note that `meta`
        # is not moved to the GPU in this function.
        if isinstance(inputs, list):
            for idx, pathway in enumerate(inputs):
                inputs[idx] = pathway.cuda(non_blocking=True)
        else:
            inputs = inputs.cuda(non_blocking=True)

        # Dict labels carry separate 'verb' and 'noun' targets.
        multitask = isinstance(labels, dict)
        if multitask:
            labels = {name: target.cuda() for name, target in labels.items()}
        else:
            labels = labels.cuda()

        # Set the learning rate for the current fractional epoch.
        lr = optim.get_epoch_lr(cur_epoch + float(cur_iter) / data_size, cfg)
        optim.set_lr(optimizer, lr)

        # Forward pass; in detection mode the model also receives the boxes.
        detection = cfg.DETECTION.ENABLE
        if detection:
            preds = model(inputs, meta["boxes"])
        else:
            preds = model(inputs)

        # Loss function with the reduction explicitly set to mean.
        loss_fun = losses.get_loss_func(cfg.MODEL.LOSS_FUNC)(reduction="mean")
        if multitask:
            # Total loss is the average of the verb and noun losses.
            loss_verb = loss_fun(preds[0], labels['verb'])
            loss_noun = loss_fun(preds[1], labels['noun'])
            loss = 0.5 * (loss_verb + loss_noun)
        else:
            loss = loss_fun(preds, labels)

        # Check for a NaN loss.
        misc.check_nan_losses(loss)

        # Backward pass followed by the parameter update.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if detection:
            if cfg.NUM_GPUS > 1:
                loss = du.all_reduce([loss])[0]
            loss = loss.item()

            train_meter.iter_toc()
            # Update and log stats.
            train_meter.update_stats(None, None, None, loss, lr)
        elif multitask:
            verb_labels, noun_labels = labels['verb'], labels['noun']

            # Verb accuracies, reduced across devices, then copied to the
            # CPU (sync point).
            verb_top1_acc, verb_top5_acc = metrics.topk_accuracies(
                preds[0], verb_labels, (1, 5))
            if cfg.NUM_GPUS > 1:
                loss_verb, verb_top1_acc, verb_top5_acc = du.all_reduce(
                    [loss_verb, verb_top1_acc, verb_top5_acc])
            loss_verb = loss_verb.item()
            verb_top1_acc = verb_top1_acc.item()
            verb_top5_acc = verb_top5_acc.item()

            # Noun accuracies, same treatment.
            noun_top1_acc, noun_top5_acc = metrics.topk_accuracies(
                preds[1], noun_labels, (1, 5))
            if cfg.NUM_GPUS > 1:
                loss_noun, noun_top1_acc, noun_top5_acc = du.all_reduce(
                    [loss_noun, noun_top1_acc, noun_top5_acc])
            loss_noun = loss_noun.item()
            noun_top1_acc = noun_top1_acc.item()
            noun_top5_acc = noun_top5_acc.item()

            # Action accuracies computed from the (verb, noun) pair.
            action_top1_acc, action_top5_acc = (
                metrics.multitask_topk_accuracies(
                    (preds[0], preds[1]), (verb_labels, noun_labels), (1, 5)))
            if cfg.NUM_GPUS > 1:
                loss, action_top1_acc, action_top5_acc = du.all_reduce(
                    [loss, action_top1_acc, action_top5_acc])
            loss = loss.item()
            action_top1_acc = action_top1_acc.item()
            action_top5_acc = action_top5_acc.item()

            train_meter.iter_toc()
            # Update and log stats.
            top1_accs = (verb_top1_acc, noun_top1_acc, action_top1_acc)
            top5_accs = (verb_top5_acc, noun_top5_acc, action_top5_acc)
            task_losses = (loss_verb, loss_noun, loss)
            train_meter.update_stats(
                top1_accs, top5_accs, task_losses, lr,
                inputs[0].size(0) * cfg.NUM_GPUS)
        else:
            # Top-1 / top-5 error percentages.
            num_topks_correct = metrics.topks_correct(preds, labels, (1, 5))
            top1_err, top5_err = [
                (1.0 - correct / preds.size(0)) * 100.0
                for correct in num_topks_correct
            ]

            # Reduce across devices, then copy to the CPU (sync point).
            if cfg.NUM_GPUS > 1:
                loss, top1_err, top5_err = du.all_reduce(
                    [loss, top1_err, top5_err])
            loss = loss.item()
            top1_err = top1_err.item()
            top5_err = top5_err.item()

            train_meter.iter_toc()
            # Update and log stats.
            train_meter.update_stats(
                top1_err, top5_err, loss, lr,
                inputs[0].size(0) * cfg.NUM_GPUS)

        train_meter.log_iter_stats(cur_epoch, cur_iter)
        train_meter.iter_tic()

    # Log epoch stats.
    train_meter.log_epoch_stats(cur_epoch)
    train_meter.reset()
Пример #3
0
def eval_epoch(val_loader, model, val_meter, cur_epoch, cfg):
    """
    Evaluate the model on the val set.

    The forward pass runs under torch.no_grad() so that no autograd graph is
    built (and no activations are retained for backward) during evaluation.

    Args:
        val_loader (loader): data loader to provide validation data.
        model (model): model to evaluate the performance.
        val_meter (ValMeter): meter instance to record and calculate the metrics.
        cur_epoch (int): number of the current epoch of training.
        cfg (CfgNode): configs. Details can be found in
            slowfast/config/defaults.py

    Returns:
        The value returned by val_meter.log_epoch_stats(cur_epoch), named
        is_best_epoch here.
    """
    # Evaluation mode enabled. The running stats would not be updated.
    model.eval()
    val_meter.iter_tic()

    for cur_iter, (inputs, labels, _, meta) in enumerate(val_loader):
        # Transfer the data to the current GPU device.
        if isinstance(inputs, list):
            for i, pathway in enumerate(inputs):
                inputs[i] = pathway.cuda(non_blocking=True)
        else:
            inputs = inputs.cuda(non_blocking=True)
        if isinstance(labels, dict):
            labels = {k: v.cuda() for k, v in labels.items()}
        else:
            labels = labels.cuda()
        for key, val in meta.items():
            if isinstance(val, list):
                for i, item in enumerate(val):
                    val[i] = item.cuda(non_blocking=True)
            else:
                meta[key] = val.cuda(non_blocking=True)

        if cfg.DETECTION.ENABLE:
            # Compute the predictions without tracking gradients.
            with torch.no_grad():
                preds = model(inputs, meta["boxes"])

            preds = preds.cpu()
            ori_boxes = meta["ori_boxes"].cpu()
            metadata = meta["metadata"].cpu()

            # Gather the per-device results and concatenate along dim 0.
            if cfg.NUM_GPUS > 1:
                preds = torch.cat(du.all_gather_unaligned(preds), dim=0)
                ori_boxes = torch.cat(du.all_gather_unaligned(ori_boxes),
                                      dim=0)
                metadata = torch.cat(du.all_gather_unaligned(metadata), dim=0)

            val_meter.iter_toc()
            # Update and log stats.
            val_meter.update_stats(preds.cpu(), ori_boxes.cpu(),
                                   metadata.cpu())
        else:
            # Compute the predictions without tracking gradients.
            with torch.no_grad():
                preds = model(inputs)

            if isinstance(labels, dict):
                # Compute the verb accuracies.
                verb_top1_acc, verb_top5_acc = metrics.topk_accuracies(
                    preds[0], labels['verb'], (1, 5))

                # Combine the accuracies across the GPUs.
                if cfg.NUM_GPUS > 1:
                    verb_top1_acc, verb_top5_acc = du.all_reduce(
                        [verb_top1_acc, verb_top5_acc])

                # Copy the accuracies from GPU to CPU (sync point).
                verb_top1_acc, verb_top5_acc = (
                    verb_top1_acc.item(),
                    verb_top5_acc.item(),
                )

                # Compute the noun accuracies.
                noun_top1_acc, noun_top5_acc = metrics.topk_accuracies(
                    preds[1], labels['noun'], (1, 5))

                # Combine the accuracies across the GPUs.
                if cfg.NUM_GPUS > 1:
                    noun_top1_acc, noun_top5_acc = du.all_reduce(
                        [noun_top1_acc, noun_top5_acc])

                # Copy the accuracies from GPU to CPU (sync point).
                noun_top1_acc, noun_top5_acc = (
                    noun_top1_acc.item(),
                    noun_top5_acc.item(),
                )

                # Compute the action accuracies from the (verb, noun) pair.
                action_top1_acc, action_top5_acc = (
                    metrics.multitask_topk_accuracies(
                        (preds[0], preds[1]),
                        (labels['verb'], labels['noun']),
                        (1, 5)))

                # Combine the accuracies across the GPUs.
                if cfg.NUM_GPUS > 1:
                    action_top1_acc, action_top5_acc = du.all_reduce(
                        [action_top1_acc, action_top5_acc])

                # Copy the accuracies from GPU to CPU (sync point).
                action_top1_acc, action_top5_acc = (
                    action_top1_acc.item(),
                    action_top5_acc.item(),
                )

                val_meter.iter_toc()
                # Update and log stats.
                val_meter.update_stats(
                    (verb_top1_acc, noun_top1_acc, action_top1_acc),
                    (verb_top5_acc, noun_top5_acc, action_top5_acc),
                    inputs[0].size(0) * cfg.NUM_GPUS)
            else:
                # Compute the top-1 / top-5 error percentages.
                num_topks_correct = metrics.topks_correct(
                    preds, labels, (1, 5))
                top1_err, top5_err = [(1.0 - x / preds.size(0)) * 100.0
                                      for x in num_topks_correct]

                # Combine the errors across the GPUs.
                if cfg.NUM_GPUS > 1:
                    top1_err, top5_err = du.all_reduce([top1_err, top5_err])

                # Copy the errors from GPU to CPU (sync point).
                top1_err, top5_err = top1_err.item(), top5_err.item()

                val_meter.iter_toc()
                # Update and log stats.
                val_meter.update_stats(top1_err, top5_err,
                                       inputs[0].size(0) * cfg.NUM_GPUS)
        val_meter.log_iter_stats(cur_epoch, cur_iter)
        val_meter.iter_tic()

    # Log epoch stats.
    is_best_epoch = val_meter.log_epoch_stats(cur_epoch)
    val_meter.reset()
    return is_best_epoch
def eval_epoch(val_loader,
               model,
               val_meter,
               cur_epoch,
               cfg,
               writer=None,
               wandb_log=False):
    """
    Evaluate the model on the val set.

    The forward pass runs under torch.no_grad() so that no autograd graph is
    built during evaluation.

    Args:
        val_loader (loader): data loader to provide validation data.
        model (model): model to evaluate the performance.
        val_meter (ValMeter): meter instance to record and calculate the metrics.
        cur_epoch (int): number of the current epoch of training.
        cfg (CfgNode): configs. Details can be found in
            slowfast/config/defaults.py
        writer (TensorboardWriter, optional): TensorboardWriter object
            used to write the Tensorboard log.
        wandb_log (bool): when True, scalars are sent to wandb.log and
            writer.add_scalars is skipped (writer.plot_eval still runs if a
            writer is given).

    Returns:
        tuple: (is_best_epoch, top1), where top1 is top1_dict["top1_acc"]
            when that key is present and top1_dict["top1_err"] otherwise;
            both come from val_meter.log_epoch_stats(cur_epoch).
    """
    # Function-scope import so this block does not depend on the file's
    # import section for torch.no_grad().
    import torch

    # Evaluation mode enabled. The running stats would not be updated.
    model.eval()
    val_meter.iter_tic()

    for cur_iter, (inputs, labels, _, meta) in enumerate(val_loader):
        if cfg.NUM_GPUS:
            # Transfer the data to the current GPU device.
            if isinstance(inputs, list):
                for i, pathway in enumerate(inputs):
                    inputs[i] = pathway.cuda(non_blocking=True)
            else:
                inputs = inputs.cuda(non_blocking=True)
            if isinstance(labels, dict):
                labels = {k: v.cuda() for k, v in labels.items()}
            else:
                labels = labels.cuda()
            for key, val in meta.items():
                if isinstance(val, list):
                    for i, item in enumerate(val):
                        val[i] = item.cuda(non_blocking=True)
                else:
                    meta[key] = val.cuda(non_blocking=True)
        val_meter.data_toc()

        # Forward pass without gradient tracking.
        with torch.no_grad():
            preds = model(inputs)

        # Explicitly declare reduction to mean.
        loss_fun = losses.get_loss_func(cfg.MODEL.LOSS_FUNC)(reduction="mean")

        if isinstance(labels, dict):
            # Compute the loss: average of the verb and noun losses.
            loss_verb = loss_fun(preds[0], labels['verb'])
            loss_noun = loss_fun(preds[1], labels['noun'])
            loss = 0.5 * (loss_verb + loss_noun)

            # Compute the verb accuracies.
            verb_top1_acc, verb_top5_acc = metrics.topk_accuracies(
                preds[0], labels['verb'], (1, 5))

            # Combine the stats across the GPUs.
            if cfg.NUM_GPUS > 1:
                loss_verb, verb_top1_acc, verb_top5_acc = du.all_reduce(
                    [loss_verb, verb_top1_acc, verb_top5_acc])

            # Copy the stats from GPU to CPU (sync point).
            loss_verb, verb_top1_acc, verb_top5_acc = (
                loss_verb.item(),
                verb_top1_acc.item(),
                verb_top5_acc.item(),
            )

            # Compute the noun accuracies.
            noun_top1_acc, noun_top5_acc = metrics.topk_accuracies(
                preds[1], labels['noun'], (1, 5))

            # Combine the stats across the GPUs.
            if cfg.NUM_GPUS > 1:
                loss_noun, noun_top1_acc, noun_top5_acc = du.all_reduce(
                    [loss_noun, noun_top1_acc, noun_top5_acc])

            # Copy the stats from GPU to CPU (sync point).
            loss_noun, noun_top1_acc, noun_top5_acc = (
                loss_noun.item(),
                noun_top1_acc.item(),
                noun_top5_acc.item(),
            )

            # Compute the action accuracies from the (verb, noun) pair.
            action_top1_acc, action_top5_acc = (
                metrics.multitask_topk_accuracies(
                    (preds[0], preds[1]),
                    (labels['verb'], labels['noun']),
                    (1, 5)))

            # Combine the stats across the GPUs.
            if cfg.NUM_GPUS > 1:
                loss, action_top1_acc, action_top5_acc = du.all_reduce(
                    [loss, action_top1_acc, action_top5_acc])

            # Copy the stats from GPU to CPU (sync point).
            loss, action_top1_acc, action_top5_acc = (
                loss.item(),
                action_top1_acc.item(),
                action_top5_acc.item(),
            )

            val_meter.iter_toc()
            # Update and log stats. max(..., 1) makes a CPU run
            # (cfg.NUM_GPUS == 0) count as a single device.
            val_meter.update_stats(
                (verb_top1_acc, noun_top1_acc, action_top1_acc),
                (verb_top5_acc, noun_top5_acc, action_top5_acc),
                inputs[0].size(0) * max(cfg.NUM_GPUS, 1),
            )

            # Per-iteration scalars shared by both logging backends.
            scalars = {
                "Val/loss": loss,
                "Val/Top1_acc": action_top1_acc,
                "Val/Top5_acc": action_top5_acc,
                "Val/verb/loss": loss_verb,
                "Val/verb/Top1_acc": verb_top1_acc,
                "Val/verb/Top5_acc": verb_top5_acc,
                "Val/noun/loss": loss_noun,
                "Val/noun/Top1_acc": noun_top1_acc,
                "Val/noun/Top5_acc": noun_top5_acc,
            }
            # Write to tensorboard format if available.
            if writer is not None and not wandb_log:
                writer.add_scalars(
                    scalars,
                    global_step=len(val_loader) * cur_epoch + cur_iter,
                )

            if wandb_log:
                wandb.log(
                    dict(scalars,
                         val_step=len(val_loader) * cur_epoch + cur_iter))

            val_meter.update_predictions((preds[0], preds[1]),
                                         (labels['verb'], labels['noun']))

        else:
            # Compute the loss.
            loss = loss_fun(preds, labels)

            if cfg.DATA.MULTI_LABEL:
                # Multi-label: only gather predictions and labels; no
                # per-iteration error stats are computed on this path.
                if cfg.NUM_GPUS > 1:
                    preds, labels = du.all_gather([preds, labels])

            else:
                # Compute the top-1 / top-5 error percentages.
                num_topks_correct = metrics.topks_correct(
                    preds, labels, (1, 5))
                top1_err, top5_err = [(1.0 - x / preds.size(0)) * 100.0
                                      for x in num_topks_correct]

                # Combine the stats across the GPUs.
                if cfg.NUM_GPUS > 1:
                    loss, top1_err, top5_err = du.all_reduce(
                        [loss, top1_err, top5_err])

                # Copy the stats from GPU to CPU (sync point).
                loss, top1_err, top5_err = (
                    loss.item(),
                    top1_err.item(),
                    top5_err.item(),
                )

                val_meter.iter_toc()
                # Update and log stats. max(..., 1) makes a CPU run
                # (cfg.NUM_GPUS == 0) count as a single device.
                val_meter.update_stats(
                    top1_err,
                    top5_err,
                    inputs[0].size(0) * max(cfg.NUM_GPUS, 1),
                )

                # Per-iteration scalars shared by both logging backends.
                scalars = {
                    "Val/loss": loss,
                    "Val/Top1_err": top1_err,
                    "Val/Top5_err": top5_err,
                }
                # Write to tensorboard format if available.
                if writer is not None and not wandb_log:
                    writer.add_scalars(
                        scalars,
                        global_step=len(val_loader) * cur_epoch + cur_iter,
                    )

                if wandb_log:
                    wandb.log(
                        dict(scalars,
                             val_step=len(val_loader) * cur_epoch + cur_iter))

            val_meter.update_predictions(preds, labels)

        val_meter.log_iter_stats(cur_epoch, cur_iter)
        val_meter.iter_tic()

    # Log epoch stats.
    is_best_epoch, top1_dict = val_meter.log_epoch_stats(cur_epoch)
    has_acc = "top1_acc" in top1_dict

    # Plot the accumulated predictions if a writer is available.
    if writer is not None:
        all_preds = [pred.clone().detach() for pred in val_meter.all_preds]
        all_labels = [label.clone().detach() for label in val_meter.all_labels]
        if cfg.NUM_GPUS:
            all_preds = [pred.cpu() for pred in all_preds]
            all_labels = [label.cpu() for label in all_labels]
        writer.plot_eval(preds=all_preds,
                         labels=all_labels,
                         global_step=cur_epoch)

    # Epoch-level scalars: accuracies when available, otherwise top-1 error.
    if has_acc:
        epoch_scalars = {
            "Val/epoch/Top1_acc": top1_dict["top1_acc"],
            "Val/epoch/verb/Top1_acc": top1_dict["verb_top1_acc"],
            "Val/epoch/noun/Top1_acc": top1_dict["noun_top1_acc"],
        }
    else:
        epoch_scalars = {"Val/epoch/Top1_err": top1_dict["top1_err"]}

    if writer is not None and not wandb_log:
        writer.add_scalars(epoch_scalars, global_step=cur_epoch)

    if wandb_log:
        wandb.log(dict(epoch_scalars, epoch=cur_epoch))

    top1 = top1_dict["top1_acc"] if has_acc else top1_dict["top1_err"]
    val_meter.reset()
    return is_best_epoch, top1
def train_epoch(train_loader,
                model,
                optimizer,
                train_meter,
                cur_epoch,
                cfg,
                writer=None,
                wandb_log=False):
    """
    Perform the audio training for one epoch.

    The model is called once per iteration; element 0 of its output is used
    as the predictions and element 1 as the linear-layer (embedding) output.

    Args:
        train_loader (loader): audio training loader.
        model (model): the audio model to train.
        optimizer (optim): the optimizer to perform optimization on the model's
            parameters.
        train_meter (TrainMeter): training meters to log the training performance.
        cur_epoch (int): current epoch of training.
        cfg (CfgNode): configs. Details can be found in
            slowfast/config/defaults.py
        writer (TensorboardWriter, optional): TensorboardWriter object
            used to write the Tensorboard log.
        wandb_log (bool): when True, scalars are sent to wandb.log and
            writer.add_scalars is skipped.
    """
    # Enable train mode.
    model.train()
    if cfg.BN.FREEZE:
        # With more than one GPU, freeze_fn is reached through model.module
        # (presumably a data-parallel wrapper -- confirm against the caller).
        bn_owner = model.module if cfg.NUM_GPUS > 1 else model
        bn_owner.freeze_fn('bn_statistics')

    train_meter.iter_tic()
    data_size = len(train_loader)

    for cur_iter, (inputs, labels, _, meta) in enumerate(train_loader):
        # Transfer the data to the current GPU device.
        if cfg.NUM_GPUS:
            if isinstance(inputs, list):
                for i, pathway in enumerate(inputs):
                    inputs[i] = pathway.cuda(non_blocking=True)
            else:
                inputs = inputs.cuda(non_blocking=True)
            if isinstance(labels, dict):
                labels = {k: v.cuda() for k, v in labels.items()}
            else:
                labels = labels.cuda()
            for key, val in meta.items():
                if isinstance(val, list):
                    for i, item in enumerate(val):
                        val[i] = item.cuda(non_blocking=True)
                else:
                    meta[key] = val.cuda(non_blocking=True)

        # Update the learning rate.
        lr = optim.get_epoch_lr(cur_epoch + float(cur_iter) / data_size, cfg)
        optim.set_lr(optimizer, lr)

        train_meter.data_toc()

        # Single forward pass. Calling the model twice would double the
        # compute, update train-mode statistics (e.g. BatchNorm) twice, and
        # take the two outputs from different stochastic passes.
        outputs = model(inputs)
        preds = outputs[0]  # output of the last layer
        linear_layer_output = outputs[1]

        # Explicitly declare reduction to mean.
        loss_fun = losses.get_loss_func(cfg.MODEL.LOSS_FUNC)(reduction="mean")

        multitask = isinstance(labels, dict)
        if multitask:
            # Compute the loss: average of the verb and noun losses.
            loss_verb = loss_fun(preds[0], labels['verb'])
            loss_noun = loss_fun(preds[1], labels['noun'])
            loss = 0.5 * (loss_verb + loss_noun)
        else:
            # Single-label path: labels are not split into verbs and nouns.
            # Embedding loss function.
            emb_loss_fun = losses.get_loss_func(
                cfg.MODEL.EMB_LOSS_FUNC)(reduction="mean")

            # Compute the loss for the main model.
            loss = loss_fun(preds, labels)

            # Compute the loss for the embeddings.
            # NOTE(review): `word_embedding` is not defined in this function;
            # it must come from an enclosing/module scope -- confirm, else
            # this line raises NameError.
            emb_loss = emb_loss_fun(linear_layer_output, word_embedding)

            # Add the losses together so the embeddings contribute to the
            # training objective.
            loss = loss + emb_loss

        # Check for a NaN loss.
        misc.check_nan_losses(loss)

        # Perform the backward pass.
        optimizer.zero_grad()
        loss.backward()
        # Update the parameters.
        optimizer.step()

        # max(..., 1) makes a CPU run (cfg.NUM_GPUS == 0) count as a single
        # device in the mini-batch size reported to the meter.
        if multitask:
            # Compute the verb accuracies.
            verb_top1_acc, verb_top5_acc = metrics.topk_accuracies(
                preds[0], labels['verb'], (1, 5))

            # Reduce the stats across all the devices.
            if cfg.NUM_GPUS > 1:
                loss_verb, verb_top1_acc, verb_top5_acc = du.all_reduce(
                    [loss_verb, verb_top1_acc, verb_top5_acc])

            # Copy the stats from GPU to CPU (sync point).
            loss_verb, verb_top1_acc, verb_top5_acc = (
                loss_verb.item(),
                verb_top1_acc.item(),
                verb_top5_acc.item(),
            )

            # Compute the noun accuracies.
            noun_top1_acc, noun_top5_acc = metrics.topk_accuracies(
                preds[1], labels['noun'], (1, 5))

            # Reduce the stats across all the devices.
            if cfg.NUM_GPUS > 1:
                loss_noun, noun_top1_acc, noun_top5_acc = du.all_reduce(
                    [loss_noun, noun_top1_acc, noun_top5_acc])

            # Copy the stats from GPU to CPU (sync point).
            loss_noun, noun_top1_acc, noun_top5_acc = (
                loss_noun.item(),
                noun_top1_acc.item(),
                noun_top5_acc.item(),
            )

            # Compute the action accuracies from the (verb, noun) pair.
            action_top1_acc, action_top5_acc = (
                metrics.multitask_topk_accuracies(
                    (preds[0], preds[1]),
                    (labels['verb'], labels['noun']),
                    (1, 5)))

            # Reduce the stats across all the devices.
            if cfg.NUM_GPUS > 1:
                loss, action_top1_acc, action_top5_acc = du.all_reduce(
                    [loss, action_top1_acc, action_top5_acc])

            # Copy the stats from GPU to CPU (sync point).
            loss, action_top1_acc, action_top5_acc = (
                loss.item(),
                action_top1_acc.item(),
                action_top5_acc.item(),
            )

            # Update and log stats.
            train_meter.update_stats(
                (verb_top1_acc, noun_top1_acc, action_top1_acc),
                (verb_top5_acc, noun_top5_acc, action_top5_acc),
                (loss_verb, loss_noun, loss),
                lr,
                inputs[0].size(0) * max(cfg.NUM_GPUS, 1),
            )

            # Per-iteration scalars shared by both logging backends.
            scalars = {
                "Train/loss": loss,
                "Train/lr": lr,
                "Train/Top1_acc": action_top1_acc,
                "Train/Top5_acc": action_top5_acc,
                "Train/verb/loss": loss_verb,
                "Train/noun/loss": loss_noun,
                "Train/verb/Top1_acc": verb_top1_acc,
                "Train/verb/Top5_acc": verb_top5_acc,
                "Train/noun/Top1_acc": noun_top1_acc,
                "Train/noun/Top5_acc": noun_top5_acc,
            }
        else:
            # Errors stay None on the multi-label path.
            top1_err, top5_err = None, None
            if cfg.DATA.MULTI_LABEL:
                # Reduce the loss across all the devices.
                if cfg.NUM_GPUS > 1:
                    [loss] = du.all_reduce([loss])
                loss = loss.item()
            else:
                # Compute the top-1 / top-5 error percentages.
                num_topks_correct = metrics.topks_correct(
                    preds, labels, (1, 5))
                top1_err, top5_err = [(1.0 - x / preds.size(0)) * 100.0
                                      for x in num_topks_correct]
                # Reduce the stats across all the devices.
                if cfg.NUM_GPUS > 1:
                    loss, top1_err, top5_err = du.all_reduce(
                        [loss, top1_err, top5_err])

                # Copy the stats from GPU to CPU (sync point).
                loss, top1_err, top5_err = (
                    loss.item(),
                    top1_err.item(),
                    top5_err.item(),
                )

            # Update and log stats.
            train_meter.update_stats(
                top1_err,
                top5_err,
                loss,
                lr,
                inputs[0].size(0) * max(cfg.NUM_GPUS, 1),
            )

            # Per-iteration scalars shared by both logging backends.
            scalars = {
                "Train/loss": loss,
                "Train/lr": lr,
                "Train/Top1_err": top1_err,
                "Train/Top5_err": top5_err,
            }

        # Write to tensorboard format if available.
        if writer is not None and not wandb_log:
            writer.add_scalars(
                scalars,
                global_step=data_size * cur_epoch + cur_iter,
            )

        if wandb_log:
            wandb.log(
                dict(scalars, train_step=data_size * cur_epoch + cur_iter))

        train_meter.iter_toc()  # measure allreduce for this meter
        train_meter.log_iter_stats(cur_epoch, cur_iter)
        train_meter.iter_tic()

    # Log epoch stats.
    train_meter.log_epoch_stats(cur_epoch)
    train_meter.reset()