Example #1
0
    def forward(self, rgb, thermal, depth, audio, label, validate=False):
        """Compute per-teacher detection and KD losses for the audio student.

        The student is run once on ``audio``. Then every teacher in
        ``self.teacher_models`` is run without gradients on the input whose
        name matches its key, and its output provides the targets for the
        student.

        Args:
            rgb, thermal, depth, audio: Batched inputs; a teacher keyed
                ``'rgb'`` receives ``rgb``, ``'thermal'`` receives
                ``thermal``, and so on.
            label: Annotations used as the detection target when the config
                option ``use_labels`` is true; otherwise the teacher logits
                are converted to pseudo-labels with ``logits_to_ground_truth``.
            validate: Unused; kept for interface compatibility with the other
                training strategies.

        Returns:
            ``[regression_losses, classification_losses, kd_losses, z, z, z]``
            where the first three are lists with one loss per teacher (in
            ``self.teacher_models`` order) and each ``z`` is a zero tensor
            shaped like, and on the same device as, a regression loss.

        Raises:
            ValueError: if there are no teachers, or a teacher key is not one
                of ``'rgb'``, ``'thermal'``, ``'audio'``, ``'depth'``.
        """
        if not self.teacher_models:
            # The placeholder zeros below are built from a regression loss,
            # so at least one teacher is required.
            raise ValueError('No teacher models to distill from')

        teacher_inputs = {
            'rgb': rgb,
            'thermal': thermal,
            'audio': audio,
            'depth': depth,
        }
        use_labels = self.config.getboolean('use_labels')

        logits_s, features_s = self.student_model(audio)

        regression_losses = []
        classification_losses = []
        kd_losses = []
        for modality in self.teacher_models.keys():
            if modality not in teacher_inputs:
                # Fail loudly instead of silently reusing the outputs of the
                # previously processed teacher.
                raise ValueError(
                    f'No valid modality to predict from teacher: {modality}')

            # generate the annotations
            with torch.no_grad():
                logits_t, features_t = self.teacher_models[modality](
                    teacher_inputs[modality])

                # Detach for kd loss calculation
                if isinstance(features_t, (tuple, list)):
                    features_t = [f.detach() for f in features_t]
                else:
                    features_t = features_t.detach()

                if use_labels:
                    annotations = label
                else:
                    annotations = logits_to_ground_truth(
                        logits=logits_t,
                        anchors=None,
                        valid_classes_dict=self.valid_classes_dict,
                        config=self.config,
                    )

            loss_regression, loss_cls = self.criterion_main(
                logits_s, annotations)
            regression_losses.append(loss_regression)
            classification_losses.append(loss_cls)

            # Zero KD loss (same shape/device as the regression loss) when no
            # KD criterion is configured.
            loss_kd = torch.zeros_like(loss_regression)
            if self.criterion_kd is not None:
                loss_kd = self.criterion_kd(features_s, features_t)
            kd_losses.append(loss_kd)

        return [
            regression_losses, classification_losses, kd_losses,
            torch.zeros_like(loss_regression),
            torch.zeros_like(loss_regression),
            torch.zeros_like(loss_regression)
        ]
Example #2
0
    def forward(self, rgb, thermal, depth, audio, label, validate=False):
        """Compute detection losses and discriminator outputs per teacher.

        The student is run once on ``audio`` and its logits and features are
        indexed by modality (``logits_s[modality]``, ``features_s[modality]``).
        For every teacher, the teacher's pseudo-labels supervise the matching
        student head, and the modality's discriminator scores both the
        teacher features ("real") and the detached student features ("fake").

        Args:
            rgb, thermal, depth, audio: Batched inputs; each teacher receives
                the input whose name matches its key.
            label: Unused; targets always come from the teachers.
            validate: When true, the gradient penalty is skipped and an empty
                list is stored in its place.

        Returns:
            Tuple ``(regression_losses, classification_losses, reals, fakes,
            gps, s_features)`` of lists with one entry per teacher, in
            ``self.teacher_models`` order. ``s_features`` holds the
            non-detached student features.

        Raises:
            ValueError: if a teacher key is not one of ``'rgb'``,
                ``'thermal'``, ``'audio'``, ``'depth'``.
        """
        def _detached(features):
            # Features may be a single tensor or a list/tuple of tensors.
            if isinstance(features, (tuple, list)):
                return [f.detach() for f in features]
            return features.detach()

        teacher_inputs = {
            'rgb': rgb,
            'thermal': thermal,
            'audio': audio,
            'depth': depth,
        }

        logits_s, features_s = self.student_model(audio)

        regression_losses = []
        classification_losses = []
        s_features = []
        reals = []
        fakes = []
        gps = []
        for modality in self.teacher_models.keys():
            if modality not in teacher_inputs:
                # Fail loudly instead of silently reusing the outputs of the
                # previously processed teacher.
                raise ValueError(
                    f'No valid modality to predict from teacher: {modality}')

            s_features.append(features_s[modality])

            # generate the annotations
            with torch.no_grad():
                logits_t, features_t = self.teacher_models[modality](
                    teacher_inputs[modality])
                # Detach for the discriminator / penalty computations
                features_t = _detached(features_t)

                annotations = logits_to_ground_truth(
                    logits=logits_t,
                    anchors=None,
                    valid_classes_dict=self.valid_classes_dict,
                    config=self.config,
                )

            loss_regression, loss_cls = self.criterion_main(
                logits_s[modality], annotations)
            regression_losses.append(loss_regression)
            classification_losses.append(loss_cls)

            discriminator = self.discriminators[modality]
            reals.append(discriminator(features_t))

            # Detached so the discriminator outputs do not back-propagate
            # into the student.
            features_s_detached = _detached(features_s[modality])
            fakes.append(discriminator(features_s_detached))

            # Skipped during validation; an empty list is the placeholder.
            gradient_penalty = []
            if not validate:
                gradient_penalty = calc_gradient_penalty(
                    discriminator, features_t, features_s_detached)
            gps.append(gradient_penalty)

        return regression_losses, classification_losses, reals, fakes, gps, s_features
Example #3
0
    def forward(self, rgb, thermal, depth, audio, label, validate=False):
        """Compute detection losses and discriminator losses per teacher.

        The student is run once on ``audio`` and its logits and features are
        indexed by modality. For every teacher, the teacher's pseudo-labels
        supervise the matching student head, and the modality's
        discriminator loss is ``criterion_adversarial(D(teacher), 1) +
        criterion_adversarial(D(student.detach()), 0)``.

        Args:
            rgb, thermal, depth, audio: Batched inputs; each teacher receives
                the input whose name matches its key. ``rgb.shape[0]`` sets the
                size of the adversarial targets.
            label: Unused; targets always come from the teachers.
            validate: Unused; kept for interface compatibility.

        Returns:
            Tuple ``(regression_losses, classification_losses, losses_d,
            s_features, [], [])``: lists with one entry per teacher, in
            ``self.teacher_models`` order, followed by two empty placeholders.
            ``s_features`` holds the non-detached student features.

        Raises:
            ValueError: if a teacher key is not one of ``'rgb'``,
                ``'thermal'``, ``'audio'``, ``'depth'``.
        """
        def _detached(features):
            # Features may be a single tensor or a list/tuple of tensors.
            if isinstance(features, (tuple, list)):
                return [f.detach() for f in features]
            return features.detach()

        # Adversarial ground truths of shape (batch, 1). They are created on
        # the inputs' device instead of a hard-coded CUDA tensor type, so
        # this also works on CPU and on any CUDA device.
        batch_size = rgb.shape[0]
        valid = torch.ones(batch_size, 1, dtype=torch.float32,
                           device=rgb.device)
        fake = torch.zeros(batch_size, 1, dtype=torch.float32,
                           device=rgb.device)

        teacher_inputs = {
            'rgb': rgb,
            'thermal': thermal,
            'audio': audio,
            'depth': depth,
        }

        logits_s, features_s = self.student_model(audio)

        regression_losses = []
        classification_losses = []
        s_features = []
        losses_d = []
        for modality in self.teacher_models.keys():
            if modality not in teacher_inputs:
                # Fail loudly instead of silently reusing the outputs of the
                # previously processed teacher.
                raise ValueError(
                    f'No valid modality to predict from teacher: {modality}')

            s_features.append(features_s[modality])

            # generate the annotations
            with torch.no_grad():
                logits_t, features_t = self.teacher_models[modality](
                    teacher_inputs[modality])
                features_t = _detached(features_t)

                annotations = logits_to_ground_truth(
                    logits=logits_t,
                    anchors=None,
                    valid_classes_dict=self.valid_classes_dict,
                    config=self.config,
                )

            loss_regression, loss_cls = self.criterion_main(
                logits_s[modality], annotations)
            regression_losses.append(loss_regression)
            classification_losses.append(loss_cls)

            discriminator = self.discriminators[modality]
            real_loss = self.criterion_adversarial(
                discriminator(features_t), valid)
            # Detached so this loss only trains the discriminator.
            fake_loss = self.criterion_adversarial(
                discriminator(_detached(features_s[modality])), fake)

            # Discriminator loss: the sum (not the mean) of both terms.
            # NOTE(review): an alternative suggestion was
            # ``real_loss - fake_loss`` (if the generator minimizes the
            # criterion, the discriminator would then push it up on fake
            # features); confirm which formulation is intended.
            losses_d.append(real_loss + fake_loss)

        return regression_losses, classification_losses, losses_d, s_features, [], []
Example #4
0
    def forward(self,
                rgb,
                thermal,
                depth,
                audio,
                label,
                validate=False,
                augment=False):
        """Distill all teachers into the audio student through fused labels.

        The student is run once on ``audio``. Every teacher is run without
        gradients on its own modality. When ``augment`` is true, the rgb
        teacher is also run on ``label`` as an extra pseudo-modality called
        ``'augmentation'``. The per-image detections of all teachers are
        concatenated, reduced with non-maximum suppression (IoU threshold
        0.5) and used as the target of a single detection loss. All teacher
        features are handed together to ``self.criterion_kd``.

        Args:
            rgb, thermal, depth, audio: Batched inputs; each teacher receives
                the input whose name matches its key in
                ``self.teacher_models``.
            label: Only used when ``augment`` is true: a batch of images fed
                to the rgb teacher as an additional source of labels.
            validate: Unused; kept for interface compatibility.
            augment: Whether to add the ``label`` batch as a pseudo-modality.

        Returns:
            ``[[loss_regression], [loss_cls], [loss_kd], z, z, z]`` where each
            ``z`` is a zero tensor of shape ``(1,)`` on ``rgb.device``.

        Raises:
            ValueError: if a teacher key is not one of ``'rgb'``,
                ``'audio'``, ``'thermal'``, ``'depth'``.
        """
        logits_s, features_s = self.student_model(audio)

        batch_size = rgb.shape[0]
        teacher_inputs = {
            'rgb': rgb,
            'audio': audio,
            'thermal': thermal,
            'depth': depth,
            # label holds extra images that are annotated by the rgb teacher
            'augmentation': label,
        }

        # One entry per image: [] until some teacher predicts something, then
        # a 2-D numpy array of detections whose column 4 is the score.
        batch_labels = [[] for _ in range(batch_size)]

        # Teacher features, one entry per processed modality, for the KD loss
        kd_labels = []

        modalities = list(self.teacher_models.keys())
        if augment:
            modalities.append('augmentation')
        for modality in modalities:
            if modality == 'augmentation':
                teacher_model = self.teacher_models['rgb']
            else:
                teacher_model = self.teacher_models[modality]
            if modality not in teacher_inputs:
                raise ValueError(
                    'No valid modality to predict from teacher')

            with torch.no_grad():
                prediction, features_t = teacher_model(
                    teacher_inputs[modality])

                # Detach for kd loss calculation
                if isinstance(features_t, (tuple, list)):
                    features_t = [f.detach() for f in features_t]
                else:
                    features_t = features_t.detach()
                kd_labels.append(features_t)

                this_batch_labels = logits_to_ground_truth(
                    logits=prediction,
                    anchors=None,
                    valid_classes_dict=self.valid_classes_dict,
                    config=self.config,
                    include_scores=True,
                )

            # Integrate this teacher's predictions into batch_labels
            for i in range(batch_size):
                # No new prediction for this image
                if isListEmpty(this_batch_labels[i]):
                    continue

                # Promote a 1-D entry (presumably a single detection) to a
                # one-row matrix so rows can be concatenated.
                if len(np.shape(this_batch_labels[i])) == 1:
                    this_batch_labels[i] = np.expand_dims(
                        this_batch_labels[i], axis=0)
                if isListEmpty(batch_labels[i]):
                    batch_labels[i] = this_batch_labels[i]
                    continue

                # If here, both have detections, so concatenate
                if len(np.shape(batch_labels[i])) == 1:
                    batch_labels[i] = np.expand_dims(batch_labels[i], axis=0)
                batch_labels[i] = np.concatenate(
                    (batch_labels[i], this_batch_labels[i]), axis=0)

        # Non-max suppress the predictions of multiple teachers
        for i in range(batch_size):
            # No prediction by any teacher. ``len`` is used because the entry
            # may be a numpy array, for which ``== []`` is not a valid
            # emptiness test (it raises on recent NumPy).
            if len(batch_labels[i]) == 0:
                continue
            idx = nms(boxes=torch.from_numpy(batch_labels[i][:, 0:4]),
                      scores=torch.from_numpy(batch_labels[i][:, 4]),
                      iou_threshold=0.5).cpu().detach().numpy()

            # Remove the score column, then keep only the NMS survivors
            batch_labels[i] = np.delete(batch_labels[i], 4, 1)[idx]

        # main loss function
        loss_regression, loss_cls = self.criterion_main(logits_s, batch_labels)

        # Fallback is created on rgb.device so that every returned tensor
        # lives on the same device as the placeholders below.
        loss_kd = torch.zeros(1, device=rgb.device)
        if self.criterion_kd is not None:
            loss_kd = self.criterion_kd(
                features_s,
                kd_labels,
            )

        return [[loss_regression], [loss_cls], [loss_kd],
                torch.zeros(1).to(rgb.device),
                torch.zeros(1).to(rgb.device),
                torch.zeros(1).to(rgb.device)]
Example #5
0
    def forward(self,
                rgb,
                thermal,
                depth,
                audio,
                label,
                validate=False,
                augment=False):
        """Distill all teachers into the audio student through fused labels.

        The student is run once on ``audio``. Every teacher is run without
        gradients on its own modality and contributes its own KD loss
        against the student features. The per-image detections of all
        teachers are concatenated, reduced with non-maximum suppression (IoU
        threshold 0.5) and used as the target of a single detection loss.

        Args:
            rgb, thermal, depth, audio: Batched inputs; each teacher receives
                the input whose name matches its key in
                ``self.teacher_models``.
            label: Unused; targets always come from the teachers.
            validate: Unused; kept for interface compatibility.
            augment: Unused; kept for interface compatibility.

        Returns:
            ``[[loss_regression], [loss_cls], kd_losses, z, z, z]`` where
            ``kd_losses`` has one entry per teacher and each ``z`` is a zero
            tensor of shape ``(1,)`` on ``rgb.device``.

        Raises:
            ValueError: if a teacher key is not one of ``'rgb'``,
                ``'audio'``, ``'thermal'``, ``'depth'``.
        """
        batch_size = rgb.shape[0]
        teacher_inputs = {
            'rgb': rgb,
            'audio': audio,
            'thermal': thermal,
            'depth': depth,
        }

        kd_losses = []
        logits_s, features_s = self.student_model(audio)

        # One entry per image: [] until some teacher predicts something, then
        # a 2-D numpy array of detections whose column 4 is the score.
        batch_labels = [[] for _ in range(batch_size)]
        for modality, teacher_model in self.teacher_models.items():
            if modality not in teacher_inputs:
                raise ValueError(
                    'No valid modality to predict from teacher')

            with torch.no_grad():
                prediction, features_t = teacher_model(
                    teacher_inputs[modality])

                # Detach for kd loss calculation
                if isinstance(features_t, (tuple, list)):
                    features_t = [f.detach() for f in features_t]
                else:
                    features_t = features_t.detach()

                this_batch_labels = logits_to_ground_truth(
                    logits=prediction,
                    anchors=None,
                    valid_classes_dict=self.valid_classes_dict,
                    config=self.config,
                    include_scores=True,
                )

            # Fallback is created on rgb.device so that every returned tensor
            # lives on the same device as the placeholders below.
            loss_kd = torch.zeros(1, device=rgb.device)
            if self.criterion_kd is not None:
                loss_kd = self.criterion_kd(features_s, features_t)
            kd_losses.append(loss_kd)

            # Integrate this teacher's predictions into batch_labels
            for i in range(batch_size):
                # No new prediction for this image
                if isListEmpty(this_batch_labels[i]):
                    continue

                # Promote a 1-D entry (presumably a single detection) to a
                # one-row matrix so rows can be concatenated.
                if len(np.shape(this_batch_labels[i])) == 1:
                    this_batch_labels[i] = np.expand_dims(
                        this_batch_labels[i], axis=0)
                if isListEmpty(batch_labels[i]):
                    batch_labels[i] = this_batch_labels[i]
                    continue

                # If here, both have detections, so concatenate
                if len(np.shape(batch_labels[i])) == 1:
                    batch_labels[i] = np.expand_dims(batch_labels[i], axis=0)
                batch_labels[i] = np.concatenate(
                    (batch_labels[i], this_batch_labels[i]), axis=0)

        # Non-max suppress the predictions of multiple teachers
        for i in range(batch_size):
            # No prediction by any teacher. ``len`` is used because the entry
            # may be a numpy array, for which ``== []`` is not a valid
            # emptiness test (it raises on recent NumPy).
            if len(batch_labels[i]) == 0:
                continue
            idx = nms(boxes=torch.from_numpy(batch_labels[i][:, 0:4]),
                      scores=torch.from_numpy(batch_labels[i][:, 4]),
                      iou_threshold=0.5).cpu().detach().numpy()

            # Remove the score column, then keep only the NMS survivors
            batch_labels[i] = np.delete(batch_labels[i], 4, 1)[idx]

        loss_regression, loss_cls = self.criterion_main(logits_s, batch_labels)

        return [[loss_regression], [loss_cls], kd_losses,
                torch.zeros(1).to(rgb.device),
                torch.zeros(1).to(rgb.device),
                torch.zeros(1).to(rgb.device)]
Example #6
0
    def forward(self,
                rgb,
                thermal,
                depth,
                audio,
                label,
                validate=False,
                augment=False):
        """Distill all teachers into the audio student, with batch merging.

        The student is run once on ``audio``. Every teacher is run without
        gradients on its own modality and contributes its own KD loss
        against the student features. The per-image detections of all
        teachers are concatenated, reduced with non-maximum suppression (IoU
        threshold 0.5) and used as the target of a single detection loss.

        When ``augment`` is true:

        * the audio batch is passed through ``self.merge_batch_0_1`` before
          the student (and the audio teacher) see it;
        * teacher features are passed through ``self.average_batch_0_1``;
        * if images 0 and 1 both have detections, image 0's detections are
          appended to image 1's before NMS (image 0 is kept as well).

        Args:
            rgb, thermal, depth, audio: Batched inputs; each teacher receives
                the input whose name matches its key in
                ``self.teacher_models``.
            label: Unused; targets always come from the teachers.
            validate: Unused; kept for interface compatibility.
            augment: Enables the batch-0/1 merging described above.

        Returns:
            ``[[loss_regression], [loss_cls], kd_losses, z, z, z]`` where
            ``kd_losses`` has one entry per teacher and each ``z`` is a zero
            tensor of shape ``(1,)`` on ``rgb.device``.

        Raises:
            ValueError: if a teacher key is not one of ``'rgb'``,
                ``'audio'``, ``'thermal'``, ``'depth'``.
        """
        # TODO: augmentation indexes samples 0 and 1, so it assumes a batch of
        # at least two samples (unlikely edge case, unchecked).
        batch_size = rgb.shape[0]

        kd_losses = []
        if augment:
            audio = self.merge_batch_0_1(audio)

        # Built after the merge so the audio teacher also sees merged audio
        teacher_inputs = {
            'rgb': rgb,
            'audio': audio,
            'thermal': thermal,
            'depth': depth,
        }

        logits_s, features_s = self.student_model(audio)

        # One entry per image: [] until some teacher predicts something, then
        # a 2-D numpy array of detections whose column 4 is the score.
        batch_labels = [[] for _ in range(batch_size)]
        for modality, teacher_model in self.teacher_models.items():
            if modality not in teacher_inputs:
                raise ValueError(
                    'No valid modality to predict from teacher')

            with torch.no_grad():
                prediction, features_t = teacher_model(
                    teacher_inputs[modality])

                # Detach for kd loss calculation
                if isinstance(features_t, (tuple, list)):
                    features_t = [f.detach() for f in features_t]
                else:
                    features_t = features_t.detach()

                # Average the features of samples 0 and 1 to match the
                # merged audio seen by the student.
                if augment:
                    features_t = self.average_batch_0_1(features_t)

                this_batch_labels = logits_to_ground_truth(
                    logits=prediction,
                    anchors=None,
                    valid_classes_dict=self.valid_classes_dict,
                    config=self.config,
                    include_scores=True,
                )

            # Fallback is created on rgb.device so that every returned tensor
            # lives on the same device as the placeholders below.
            loss_kd = torch.zeros(1, device=rgb.device)
            if self.criterion_kd is not None:
                loss_kd = self.criterion_kd(features_s, features_t)
            kd_losses.append(loss_kd)

            # Integrate this teacher's predictions into batch_labels
            for i in range(batch_size):
                # No new prediction for this image
                if isListEmpty(this_batch_labels[i]):
                    continue

                # Promote a 1-D entry (presumably a single detection) to a
                # one-row matrix so rows can be concatenated.
                if len(np.shape(this_batch_labels[i])) == 1:
                    this_batch_labels[i] = np.expand_dims(
                        this_batch_labels[i], axis=0)
                if isListEmpty(batch_labels[i]):
                    batch_labels[i] = this_batch_labels[i]
                    continue

                # If here, both have detections, so concatenate
                if len(np.shape(batch_labels[i])) == 1:
                    batch_labels[i] = np.expand_dims(batch_labels[i], axis=0)
                batch_labels[i] = np.concatenate(
                    (batch_labels[i], this_batch_labels[i]), axis=0)

        # Merge the labels of images 0 and 1 into image 1, before NMS.
        # If either is empty nothing is merged. Image 0 is intentionally
        # kept, so it still contributes its own gradient.
        # ``len`` is used because ``!= []`` is not a valid emptiness test on
        # numpy arrays (it raises on recent NumPy).
        if augment and len(batch_labels[1]) > 0 and len(batch_labels[0]) > 0:
            batch_labels[1] = np.concatenate(
                (batch_labels[0], batch_labels[1]), axis=0)

        # Non-max suppress the predictions of multiple teachers
        for i in range(batch_size):
            # No prediction by any teacher
            if len(batch_labels[i]) == 0:
                continue
            idx = nms(boxes=torch.from_numpy(batch_labels[i][:, 0:4]),
                      scores=torch.from_numpy(batch_labels[i][:, 4]),
                      iou_threshold=0.5).cpu().detach().numpy()

            # Remove the score column, then keep only the NMS survivors
            batch_labels[i] = np.delete(batch_labels[i], 4, 1)[idx]

        loss_regression, loss_cls = self.criterion_main(logits_s, batch_labels)

        return [[loss_regression], [loss_cls], kd_losses,
                torch.zeros(1).to(rgb.device),
                torch.zeros(1).to(rgb.device),
                torch.zeros(1).to(rgb.device)]
Example #7
0
    def refine_ids(self, model, config):
        """Restrict ``self.ids`` to the images worth training on.

        Images on which the teacher predicts nothing are dropped up front,
        because ignoring a single image inside a batch is tricky.

        * When ``self.use_labels`` is true, an id is kept if its annotations
          contain more than one label surviving ``self.filter_labels``.
        * Otherwise the teacher ``model`` is run once over every rgb image.
          Per id, the number of predictions and the best prediction score
          are cached in ``self.predictions_file`` (header-less CSV with
          columns ID, Num_pred, min_confidence; despite its name, the last
          column holds the *highest* score). Ids whose best score exceeds a
          teacher-specific threshold are kept. They are optionally filtered
          further by the regex in ``config['id_filter']``; this step is
          skipped when that value contains ``'None'``.

        In both cases ``self.ids`` ends up sorted, so its order is
        deterministic (and identical across processes), and
        ``self.num_images`` is updated.

        Args:
            model: Teacher model used to assess whether an image is easy to
                predict; called as ``model(rgb_batch)`` and expected to return
                ``(logits, features)``.
            config: A parsed config section; ``teacher`` and ``id_filter``
                are read.

        Raises:
            ValueError: if, when not using labels, ``config['teacher']`` is
                neither a yolov2 nor an EfficientDet teacher.
        """
        teacher = config['teacher']

        # Get the id list
        self.get_id_list()

        # If using labels, only keep images with meaningful labels, due to
        # runtime limitations.
        if self.use_labels:
            valid_ids = []
            for image_id in self.ids:
                labels = self.get_annotations(image_id)
                if len(labels) < 1:
                    continue
                # NOTE(review): this requires at least *two* valid labels;
                # confirm whether ``>= 1`` was intended.
                if len(self.filter_labels(labels)) > 1:
                    valid_ids.append(image_id)
            # Sorted: a bare set has hash-dependent order, which can differ
            # between processes and break distributed sampling.
            new_ids = sorted(set(self.ids) & set(valid_ids))
            logger.debug(f"Reduced {len(self.ids)}->{len(new_ids)}")
            self.ids = new_ids
            self.num_images = len(self.ids)
            return

        # If not using labels, reduce the ids based on the teacher. Teacher
        # predictions are cached on disk; delete the file to rebuild it.
        if not os.path.exists(self.predictions_file):
            logger.warning(f"Building file {self.predictions_file}")
            anchors = None
            if 'yolov2' in teacher:
                # Unwrap DataParallel to reach the model's anchors
                if isinstance(model, torch.nn.DataParallel):
                    anchors = model.module.anchors
                else:
                    anchors = model.anchors
            valid_images = []
            for index, _ in enumerate(tqdm(self.ids, desc="Predict File")):
                rgb, thermal, depth, audio, label, image_id = self.__getitem__(
                    index)

                with torch.no_grad():
                    rgb = torch.FloatTensor(rgb).unsqueeze(dim=0).to(
                        self.device)
                    logits, features = model(rgb)

                    batch_predictions = logits_to_ground_truth(
                        logits=logits,
                        valid_classes_dict=self.valid_classes_dict,
                        anchors=anchors,
                        config=config,
                        include_scores=True,
                        text_classes=False,
                        crash_if_no_pred=False,
                    )
                    if not np.any(batch_predictions):
                        num_predictions = 0
                    else:
                        num_predictions = len(batch_predictions[0])

                    # Highest score (column 4) over all predictions; stored
                    # under the historical column name "min_confidence".
                    best_score = 0
                    for predictions in batch_predictions:
                        for pred in predictions:
                            if pred[4] > best_score:
                                best_score = pred[4]

                valid_images.append([image_id, num_predictions, best_score])

            # Save to a text file (also handy for debugging)
            np.savetxt(self.predictions_file,
                       np.asarray(valid_images),
                       delimiter=",",
                       fmt='%s')

        # Just keep ids where the teacher predicted something
        dataframe = pd.read_csv(
            self.predictions_file,
            names=['ID', 'Num_pred', 'min_confidence'],
            dtype={
                'ID': str,
                'Num_pred': np.int32,
                'min_confidence': np.float32
            },
        )

        # How good a teacher prediction has to be to accept the image
        if 'yolov2' in teacher:
            minconf = 0.5
        elif 'EfficientDet' in teacher:
            minconf = 0.40
        else:
            raise ValueError(f"Unsupported teacher: {teacher}")

        valid_ids = dataframe[
            dataframe['min_confidence'] > minconf]['ID'].tolist()

        # Further restrict based on the config request
        if 'None' not in config['id_filter']:
            pattern = re.compile(config['id_filter'])
            valid_ids = [
                image_id for image_id in valid_ids if pattern.match(image_id)
            ]

        new_ids = sorted(set(self.ids) & set(valid_ids))
        logger.debug(f"Reduced {len(self.ids)}->{len(new_ids)}")
        self.ids = new_ids
        self.num_images = len(self.ids)
Example #8
0
def train_traditional(
    train_set,
    model,
    optimizer,
    epoch,
    config,
    writer
):
    """Train the student for one epoch against the frozen teachers.

    Builds a DataLoader over ``train_set`` (with a ``DistributedSampler``
    when ``config['engine']`` is ``'DistributedDataParallel'``). It then puts
    the teachers in eval mode and the student in train mode. For every batch
    it:

    * runs ``model``;
    * combines the returned losses with the ``w_main`` / ``w_div`` /
      ``w_kd`` weights and back-propagates;
    * optionally clips the student gradients (``grad_clip`` > 0);
    * steps ``optimizer``.

    Progress is logged and, when ``writer`` is given, written with
    ``writer.add_scalar``.

    Args:
        train_set: Dataset whose collated batches unpack as
            ``(rgb, thermal, depth, audio, label, ids)``.
        model: Wrapped model exposing ``module.teacher_models`` and
            ``module.student_model``. Its forward returns six items, the
            first three being lists of regression, classification and KD
            losses.
        optimizer: Optimizer stepped after each backward pass (presumably
            over the student parameters).
        epoch: Zero-based epoch index. It also drives the augmentation
            probability and enables debug logging when 0.
        config: Parsed config section.
        writer: Optional TensorBoard-style writer; falsy to disable.

    Returns:
        ``float``: the total loss of the last iteration.

    Raises:
        ValueError: if the loader yields no batches (e.g. dataset smaller
            than ``batch_size`` with ``drop_last=True``), as there would be
            no loss to return.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Build the train generator
    loader_kwargs = dict(
        batch_size=config.getint('batch_size'),
        drop_last=True,
        collate_fn=custom_collate_factory(config),
        num_workers=config.getint('num_workers'),
    )
    train_sampler = None
    if config['engine'] == 'DistributedDataParallel':
        # The sampler shuffles per rank, so the loader itself must not.
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_set,
        )
        generator = DataLoader(
            train_set,
            shuffle=False,
            pin_memory=True,
            sampler=train_sampler,
            **loader_kwargs,
        )
    else:
        generator = DataLoader(train_set, shuffle=True, **loader_kwargs)
    num_iter_per_epoch = len(generator)

    logger.info(f"Traditional Training for {num_iter_per_epoch} iters")
    if num_iter_per_epoch == 0:
        raise ValueError(
            "The training loader yields no batches; check batch_size "
            "against the dataset size (drop_last=True)")

    model.module.teacher_models.eval()
    model.module.student_model.train()

    if train_sampler is not None:
        # Makes the sampler's shuffle differ from epoch to epoch
        train_sampler.set_epoch(epoch)

    for iteration, batch in enumerate(tqdm(generator, desc=f"Epoch={epoch+1}")):
        global_step = epoch * num_iter_per_epoch + iteration

        # ==================Forward=================

        rgb, thermal, depth, audio, label, ids = batch
        # Inputs are moved to the device as leaf tensors requiring grad
        rgb = rgb.to(device).requires_grad_(True)
        if config.getboolean('use_thermal'):
            thermal = thermal.to(device).requires_grad_(True)
        if config.getboolean('use_depth'):
            depth = depth.to(device).requires_grad_(True)

        # Add a new pseudo-modality in what is called label; audio then
        # covers both label and the current batch. random.random() is only
        # drawn for this train_method (short-circuit).
        augment = False
        if (config['train_method'] == 'traditional_nms_kdlist_augmented'
                and random.random() > max(0.5, (0.5 + 0.5 * (1 - epoch/50)))):
            augment = True
            label, audio = train_set.yield_batch(audio.shape[0], ids)

        audio = audio.to(device).requires_grad_(True)

        # ==================Backward=================
        optimizer.zero_grad()

        if config['train_method'] == 'traditional_nms_augmented':
            augment = np.random.choice([True, False], p=[0.3, 0.7])

        # NOTE(review): the local ``augment`` computed above is never used;
        # forward() receives the ``audio_augmentation_merge`` flag instead,
        # even after yield_batch() replaced label/audio. Confirm intent.
        result = model(
            rgb,
            thermal,
            depth,
            audio,
            label,
            augment=config.getboolean('audio_augmentation_merge'),
        )

        # For debug purposes
        if epoch == 0:
            with torch.no_grad():
                for i, item in enumerate(ids):
                    logger.debug(f"\n{i}=> {item}")
                    logger.debug(f"rgb={torch.mean(rgb[i])}")
                    if config.getboolean('use_thermal'):
                        logger.debug(f"thermal={torch.mean(thermal[i])}")
                    if config.getboolean('use_depth'):
                        logger.debug(f"depth={torch.mean(depth[i])}")
                    if config.getboolean('use_audio'):
                        logger.debug(f"audio={torch.mean(audio[i])}")
                    logger.debug(f"label={label}")
                teacher_inputs = {
                    'rgb': rgb,
                    'thermal': thermal,
                    'audio': audio,
                    'depth': depth,
                }
                for modality, teacher in model.module.teacher_models.items():
                    if modality not in teacher_inputs:
                        # Skip instead of logging a previous teacher's output
                        continue
                    logits_t, _ = teacher(teacher_inputs[modality])

                    annotations = logits_to_ground_truth(
                        logits=logits_t,
                        anchors=None,
                        valid_classes_dict=train_set.valid_classes_dict,
                        config=config,
                    )
                    logger.debug(f"GTs[{modality}]={annotations}")

        # Calculate the losses
        regression_losses, classification_losses, kd_losses, _, _, _ = result
        loss_regression = torch.mean(torch.stack(regression_losses))
        loss_cls = torch.mean(torch.stack(classification_losses))
        loss_main = loss_regression + loss_cls
        # KD terms are summed (not averaged) over their entries
        loss_kd = torch.sum(torch.stack(kd_losses))
        # No divergence term is used by this training method
        loss_div = 0

        loss = config.getfloat('w_main') * loss_main
        loss += config.getfloat('w_div') * loss_div
        loss += config.getfloat('w_kd') * loss_kd
        loss.backward()

        if config.getfloat('grad_clip') > 0:
            torch.nn.utils.clip_grad_norm_(
                model.module.student_model.parameters(),
                config.getfloat('grad_clip')
            )

        optimizer.step()

        logger.info("="*40+"\n")
        logger.info(f"Epoch: {epoch + 1}/{config.getint('num_epoches')}")
        logger.info(f"Iteration: {iteration+1}/{num_iter_per_epoch}")
        logger.info(f"Lr: {optimizer.param_groups[0]['lr']}")
        logger.info(f"Loss:{loss}")
        logger.info(f"Regression:{loss_regression}")
        logger.info(f"Cls:{loss_cls}")
        logger.info(f"KLDiv:{loss_div}")
        logger.info(f"KD:{loss_kd}")
        # NOTE(review): each line logs the whole loss list, not the entry of
        # that modality (the lists may have a single fused entry).
        for modality in model.module.teacher_models.keys():
            logger.info(f"Regression_{modality}:{regression_losses}")
            logger.info(f"Cls_{modality}:{classification_losses}")
            logger.info(f"KLDiv_{modality}:{loss_div}")
            logger.info(f"KD_{modality}:{kd_losses}")
        logger.info("="*40+"\n")

        # Write to TensorBoard
        if writer:
            # NOTE(review): "Train_/Regression_loss" differs from the other
            # "Train/" tags; kept so existing runs remain comparable.
            for tag, value in (
                ("Train/Total_loss", loss),
                ("Train_/Regression_loss", loss_regression),
                ("Train/Class_loss", loss_cls),
                ("Train/KLDiv", loss_div),
                ("Train/KD", loss_kd),
            ):
                writer.add_scalar(tag, value, global_step)

    return loss.item()