Example #1
0
    def train(self, src_data_loader, tgt_data_loader):
        """Run one epoch of adversarial domain-adaptation training.

        Every source batch is paired with a target batch (the target loader
        is restarted whenever it runs out) and two updates are made:

        1. ``self.seg_disc`` is trained with ``self.seg_opt`` to tell source
           image features (``src=True``) from target ones (``src=False``).
        2. ``self.model`` is trained with ``self.optimizer`` on the source
           task losses plus GAN losses with flipped domain labels. Each GAN
           loss is only added while the discriminator's accuracy on that
           domain exceeds ``self.src_acc_threshold`` /
           ``self.tgt_acc_threshold``.

        Args:
            src_data_loader: Source batches; drives the epoch length and is
                passed to ``self.model`` to compute the task losses.
            tgt_data_loader: Target batches; must not be longer than
                ``src_data_loader``.
        """
        self.model.train()
        self.seg_disc.train()
        self.mode = 'train'
        self.data_loader = src_data_loader
        tgt_data_iter = iter(tgt_data_loader)
        assert len(tgt_data_loader) <= len(src_data_loader)

        self._max_iters = self._max_epochs * len(self.data_loader)
        self.call_hook('before_train_epoch')
        time.sleep(2)  # Prevent possible deadlock during epoch transition

        for i, src_data_batch in enumerate(self.data_loader):
            self._inner_iter = i
            self.call_hook('before_train_iter')

            # Pair each source batch with a target batch, cycling the target
            # loader when it is exhausted.
            tgt_data_batch = next(tgt_data_iter, None)
            if tgt_data_batch is None:
                tgt_data_iter = iter(tgt_data_loader)
                tgt_data_batch = next(tgt_data_iter)

            # ------------------------
            # train Discriminator
            # ------------------------
            set_requires_grad(self.seg_disc, requires_grad=True)
            # The features are plain inputs to the discriminator here. Build
            # no autograd graph through the model, so the discriminator loss
            # neither backpropagates into the model nor costs a full backward
            # pass (model grads are zeroed before the model step anyway).
            with torch.no_grad():
                # e.g. (4, 64, 225, 400) per an earlier note -- confirm
                src_img_feats = self.model.extract_img_feat(**src_data_batch)
                tgt_img_feats = self.model.extract_img_feat(**tgt_data_batch)

            src_Dlogits = self.seg_disc(src_img_feats)
            src_Dloss = self.seg_disc.loss(src_Dlogits, src=True)
            log_src_Dloss = src_Dloss.item()

            tgt_Dlogits = self.seg_disc(tgt_img_feats)
            tgt_Dloss = self.seg_disc.loss(tgt_Dlogits, src=False)
            log_tgt_Dloss = tgt_Dloss.item()
            Dloss = (src_Dloss + tgt_Dloss) * 0.5

            self.seg_opt.zero_grad()
            Dloss.backward()
            self.seg_opt.step()

            # ------------------------
            # network forward on source: src_task_loss + lambda * src_GANLoss
            # ------------------------
            # Freeze the discriminator so the GAN losses below produce no
            # gradients for it.
            set_requires_grad(self.seg_disc, requires_grad=False)
            losses, src_img_feats = self.model(**src_data_batch)

            # Discriminator accuracy on source (label 1 = source); argmax is
            # taken over dim 1, assumed to hold the domain scores.
            src_Dlogits = self.seg_disc(src_img_feats)
            src_acc = (src_Dlogits.max(1)[1] == 1).float().mean()
            if src_acc > self.src_acc_threshold:
                # Flipped label: push source features towards "target".
                losses['src_GANloss'] = (
                    self.lambda_GANLoss
                    * self.seg_disc.loss(src_Dlogits, src=False))

            # ------------------------
            # network forward on target: lambda * tgt_GANLoss only
            # ------------------------
            tgt_img_feats = self.model.extract_img_feat(**tgt_data_batch)

            # Discriminator accuracy on target (label 0 = target).
            tgt_Dlogits = self.seg_disc(tgt_img_feats)
            tgt_acc = (tgt_Dlogits.max(1)[1] == 0).float().mean()
            if tgt_acc > self.tgt_acc_threshold:
                # Flipped label: push target features towards "source".
                losses['tgt_GANloss'] = (
                    self.lambda_GANLoss
                    * self.seg_disc.loss(tgt_Dlogits, src=True))

            loss, log_vars = parse_losses(losses)
            num_samples = len(src_data_batch['img_metas'])
            log_vars['src_Dloss'] = log_src_Dloss
            log_vars['tgt_Dloss'] = log_tgt_Dloss
            log_vars['src_acc'] = src_acc.item()
            log_vars['tgt_acc'] = tgt_acc.item()

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            # after_train_iter callback
            self.log_buffer.update(log_vars, num_samples)
            self.call_hook('after_train_iter')  # optimizer hook && logger hook
            self._iter += 1

        self.call_hook('after_train_epoch')
        self._epoch += 1
Example #2
0
    def train(self, src_data_loader, tgt_data_loader):
        """Run one epoch of adversarial domain-adaptation training.

        Two discriminators are used: ``self.seg_disc`` on the segmentation
        features and ``self.det_disc`` on the detection features produced by
        ``self.model``. Every source batch is paired with a target batch (the
        target loader is restarted when exhausted) and each iteration:

        1. trains both discriminators on source features (``src=True``),
           then on target features (``src=False``), stepping ``self.seg_opt``
           and ``self.det_opt`` after each;
        2. trains ``self.model`` on the source task losses plus source GAN
           losses with flipped labels;
        3. trains ``self.model`` again on the target GAN losses only.

        A GAN loss is only used while the matching discriminator's accuracy
        on that domain exceeds ``acc_threshold``.

        Args:
            src_data_loader: Source batches; drives the epoch length and is
                passed to ``self.model`` to compute the task losses.
            tgt_data_loader: Target batches; must not be longer than
                ``src_data_loader``.
        """
        self.model.train()
        self.seg_disc.train()
        self.det_disc.train()
        self.mode = 'train'
        self.data_loader = src_data_loader
        tgt_data_iter = iter(tgt_data_loader)
        assert len(tgt_data_loader) <= len(src_data_loader)

        # Minimum discriminator accuracy before a GAN loss is applied.
        acc_threshold = 0.6
        self._max_iters = self._max_epochs * len(self.data_loader)
        self.call_hook('before_train_epoch')
        time.sleep(2)  # Prevent possible deadlock during epoch transition

        for i, src_data_batch in enumerate(self.data_loader):
            self._inner_iter = i
            self.call_hook('before_train_iter')

            # Pair each source batch with a target batch, cycling the target
            # loader when it is exhausted.
            tgt_data_batch = next(tgt_data_iter, None)
            if tgt_data_batch is None:
                tgt_data_iter = iter(tgt_data_loader)
                tgt_data_batch = next(tgt_data_iter)

            # ------------------------
            # train Discriminators
            # ------------------------
            set_requires_grad([self.seg_disc, self.det_disc],
                              requires_grad=True)

            # Source. Features are extracted without an autograd graph so the
            # discriminator losses do not run a wasted backward through the
            # model (its grads are zeroed before the model step anyway).
            # seg feats: (N, 128); det feats: (4, 128, 200, 400) per earlier
            # notes -- confirm.
            with torch.no_grad():
                seg_src_feats, det_src_feats = self.model.forward_fusion(
                    **src_data_batch)
            seg_src_Dloss = self.seg_disc.loss(
                self.seg_disc(seg_src_feats), src=True)
            log_seg_src_Dloss = seg_src_Dloss.item()
            det_src_Dloss = self.det_disc.loss(
                self.det_disc(det_src_feats), src=True)
            log_det_src_Dloss = det_src_Dloss.item()

            self.seg_opt.zero_grad()
            self.det_opt.zero_grad()
            (seg_src_Dloss + det_src_Dloss).backward()
            self.seg_opt.step()
            self.det_opt.step()

            # Target.
            with torch.no_grad():
                seg_tgt_feats, det_tgt_feats = self.model.forward_fusion(
                    **tgt_data_batch)
            seg_tgt_Dloss = self.seg_disc.loss(
                self.seg_disc(seg_tgt_feats), src=False)
            log_seg_tgt_Dloss = seg_tgt_Dloss.item()
            det_tgt_Dloss = self.det_disc.loss(
                self.det_disc(det_tgt_feats), src=False)
            log_det_tgt_Dloss = det_tgt_Dloss.item()

            self.seg_opt.zero_grad()
            self.det_opt.zero_grad()
            (seg_tgt_Dloss + det_tgt_Dloss).backward()
            self.seg_opt.step()
            self.det_opt.step()

            # ------------------------
            # train network on source: task loss + GANLoss
            # ------------------------
            # Freeze the discriminators so the GAN losses produce no
            # gradients for them.
            set_requires_grad([self.seg_disc, self.det_disc],
                              requires_grad=False)
            losses, seg_src_feats, det_src_feats = self.model(
                **src_data_batch)

            # Accuracy on source uses label 1; argmax is over dim 1.
            seg_disc_logits = self.seg_disc(seg_src_feats)  # (N, 2)
            seg_acc = (seg_disc_logits.max(1)[1] == 1).float().mean()
            if seg_acc > acc_threshold:
                # Flipped label: push source features towards "target".
                losses['seg_src_GANloss'] = (
                    self.lambda_GANLoss
                    * self.seg_disc.loss(seg_disc_logits, src=False))

            det_disc_logits = self.det_disc(det_src_feats)  # (4, 2, 49, 99)
            det_acc = (det_disc_logits.max(1)[1] == 1).float().mean()
            if det_acc > acc_threshold:
                losses['det_src_GANloss'] = (
                    self.lambda_GANLoss
                    * self.det_disc.loss(det_disc_logits, src=False))

            loss, log_vars = parse_losses(losses)
            num_samples = len(src_data_batch['img_metas'])
            log_vars['seg_src_Dloss'] = log_seg_src_Dloss
            log_vars['det_src_Dloss'] = log_det_src_Dloss
            log_vars['seg_tgt_Dloss'] = log_seg_tgt_Dloss
            log_vars['det_tgt_Dloss'] = log_det_tgt_Dloss
            log_vars['seg_src_acc'] = seg_acc.item()
            log_vars['det_src_acc'] = det_acc.item()

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            # ------------------------
            # train network on target: only GANLoss
            # ------------------------
            self.optimizer.zero_grad()
            seg_tgt_feats, det_tgt_feats = self.model.forward_fusion(
                **tgt_data_batch)
            tgt_GANlosses = []

            # Accuracy on target uses label 0.
            seg_disc_logits = self.seg_disc(seg_tgt_feats)  # (N, 2)
            seg_acc = (seg_disc_logits.max(1)[1] == 0).float().mean()
            if seg_acc > acc_threshold:
                # Flipped label: push target features towards "source".
                seg_tgt_loss = self.lambda_GANLoss * self.seg_disc.loss(
                    seg_disc_logits, src=True)
                log_vars['seg_tgt_GANloss'] = seg_tgt_loss.item()
                tgt_GANlosses.append(seg_tgt_loss)

            det_disc_logits = self.det_disc(det_tgt_feats)  # (4, 2, 49, 99)
            det_acc = (det_disc_logits.max(1)[1] == 0).float().mean()
            if det_acc > acc_threshold:
                det_tgt_loss = self.lambda_GANLoss * self.det_disc.loss(
                    det_disc_logits, src=True)
                log_vars['det_tgt_GANloss'] = det_tgt_loss.item()
                tgt_GANlosses.append(det_tgt_loss)

            log_vars['seg_tgt_acc'] = seg_acc.item()
            log_vars['det_tgt_acc'] = det_acc.item()
            # Sum out of place (no in-place `+=` on a graph tensor) and run a
            # single backward; skip the step when there is nothing to optimise.
            if tgt_GANlosses:
                sum(tgt_GANlosses).backward()
                self.optimizer.step()

            # after_train_iter callback
            self.log_buffer.update(log_vars, num_samples)
            self.call_hook('after_train_iter')  # optimizer hook && logger hook
            self._iter += 1

        self.call_hook('after_train_epoch')
        self._epoch += 1
Example #3
0
    def train(self, src_data_loader, tgt_data_loader):
        """Run one epoch of (optionally adversarial) training.

        Each iteration trains ``self.model`` on the source task losses. When
        ``self.with_disc`` is true, every source batch is also paired with a
        target batch (the target loader is restarted when exhausted), and:

        1. ``self.seg_disc`` is trained with ``self.seg_opt`` on source
           features (``src=True``), then on target features (``src=False``);
        2. a flipped-label source GAN loss is added to the model losses;
        3. the model is stepped again on a flipped-label target GAN loss.

        Each GAN loss is only used while the discriminator's accuracy on that
        domain exceeds ``acc_threshold``. When ``self.with_disc`` is false,
        ``tgt_data_loader`` is not used.

        Args:
            src_data_loader: Source batches; drives the epoch length and is
                passed to ``self.model`` to compute the task losses.
            tgt_data_loader: Target batches (only used with a
                discriminator); must not be longer than ``src_data_loader``.
        """
        self.model.train()
        self.mode = 'train'
        self.data_loader = src_data_loader
        if self.with_disc:
            self.seg_disc.train()
            tgt_data_iter = iter(tgt_data_loader)
            assert len(tgt_data_loader) <= len(src_data_loader)

        # Minimum discriminator accuracy before a GAN loss is applied.
        acc_threshold = 0.6
        self._max_iters = self._max_epochs * len(self.data_loader)
        self.call_hook('before_train_epoch')
        time.sleep(2)  # Prevent possible deadlock during epoch transition

        for i, src_data_batch in enumerate(self.data_loader):
            self._inner_iter = i
            self.call_hook('before_train_iter')

            if self.with_disc:
                # Pair each source batch with a target batch, cycling the
                # target loader when it is exhausted.
                tgt_data_batch = next(tgt_data_iter, None)
                if tgt_data_batch is None:
                    tgt_data_iter = iter(tgt_data_loader)
                    tgt_data_batch = next(tgt_data_iter)

                # ------------------------
                # train Discriminator
                # ------------------------
                set_requires_grad(self.seg_disc, requires_grad=True)
                # Features are extracted without an autograd graph so the
                # discriminator losses do not run a wasted backward through
                # the model (its grads are zeroed before the model step).
                # Per earlier notes (confirm): features before fusion are
                # (N, 64, 225, 400), after fusion (N, 128).
                with torch.no_grad():
                    img_feats, fusion_feats = self.model.extract_feat(
                        **src_data_batch)
                    src_feats = self.get_feats(img_feats, fusion_feats)
                src_Dloss = self.seg_disc.loss(
                    self.seg_disc(src_feats), src=True)
                log_src_Dloss = src_Dloss.item()

                self.seg_opt.zero_grad()
                src_Dloss.backward()
                self.seg_opt.step()

                with torch.no_grad():
                    img_feats, fusion_feats = self.model.extract_feat(
                        **tgt_data_batch)
                    tgt_feats = self.get_feats(img_feats, fusion_feats)
                tgt_Dloss = self.seg_disc.loss(
                    self.seg_disc(tgt_feats), src=False)
                log_tgt_Dloss = tgt_Dloss.item()

                self.seg_opt.zero_grad()
                tgt_Dloss.backward()
                self.seg_opt.step()

            # ------------------------
            # train network on source: task loss + GANLoss
            # ------------------------
            losses, img_feats, fusion_feats = self.model(
                **src_data_batch)  # forward; losses: {'seg_loss'=}
            src_feats = self.get_feats(img_feats, fusion_feats)

            if self.with_disc:
                # Freeze the discriminator: the GAN losses below must only
                # produce gradients for the model.
                set_requires_grad(self.seg_disc, requires_grad=False)
                disc_logits = self.seg_disc(src_feats)  # (N, 2)
                # Accuracy on source uses label 1; argmax is over dim 1.
                src_acc = (disc_logits.max(1)[1] == 1).float().mean()
                if src_acc > acc_threshold:
                    # Flipped label: push source features towards "target".
                    losses['src_GANloss'] = (
                        self.lambda_GANLoss
                        * self.seg_disc.loss(disc_logits, src=False))

            loss, log_vars = parse_losses(losses)
            num_samples = len(src_data_batch['img_metas'])
            if self.with_disc:
                log_vars['src_Dloss'] = log_src_Dloss
                log_vars['tgt_Dloss'] = log_tgt_Dloss
                log_vars['src_acc'] = src_acc.item()

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            # ------------------------
            # train network on target: only GANLoss
            # ------------------------
            if self.with_disc:
                img_feats, fusion_feats = self.model.extract_feat(
                    **tgt_data_batch)
                tgt_feats = self.get_feats(img_feats, fusion_feats)

                disc_logits = self.seg_disc(tgt_feats)  # (N, 2)
                # Accuracy on target uses label 0.
                tgt_acc = (disc_logits.max(1)[1] == 0).float().mean()
                log_vars['tgt_acc'] = tgt_acc.item()
                if tgt_acc > acc_threshold:
                    # Flipped label: push target features towards "source".
                    tgt_GANloss = self.lambda_GANLoss * self.seg_disc.loss(
                        disc_logits, src=True)
                    log_vars['tgt_GANloss'] = tgt_GANloss.item()

                    self.optimizer.zero_grad()
                    tgt_GANloss.backward()
                    self.optimizer.step()

            # after_train_iter callback
            self.log_buffer.update(log_vars, num_samples)
            self.call_hook('after_train_iter')  # optimizer hook && logger hook
            self._iter += 1

        self.call_hook('after_train_epoch')
        self._epoch += 1
Example #4
0
    def train(self, src_data_loader, tgt_data_loader):
        """Run one epoch of adversarial domain-adaptation training.

        Two discriminators are used: ``self.seg_disc`` on the segmentation
        fusion features and ``self.det_disc`` on the detection fusion
        features. Every source batch is paired with a target batch (the
        target loader is restarted when exhausted). Per iteration, first on
        the source batch and then on the target batch:

        1. both discriminators are trained on detached features (source:
           ``src=True``, target: ``src=False``);
        2. ``self.model`` is trained with flipped-label GAN losses. On the
           source batch these are added to the task losses; on the target
           batch they are used alone. A GAN loss is only used while the
           discriminator's accuracy on that domain exceeds
           ``acc_threshold``.

        Args:
            src_data_loader: Source batches; drives the epoch length and is
                passed to ``self.model`` to compute the task losses.
            tgt_data_loader: Target batches; must not be longer than
                ``src_data_loader``.
        """
        self.model.train()
        self.seg_disc.train()
        self.det_disc.train()
        self.mode = 'train'
        self.data_loader = src_data_loader
        tgt_data_iter = iter(tgt_data_loader)
        assert len(tgt_data_loader) <= len(src_data_loader)

        # Minimum discriminator accuracy before a GAN loss is applied.
        acc_threshold = 0.6
        self._max_iters = self._max_epochs * len(self.data_loader)
        self.call_hook('before_train_epoch')
        time.sleep(2)  # Prevent possible deadlock during epoch transition

        for i, src_data_batch in enumerate(self.data_loader):
            self._inner_iter = i
            self.call_hook('before_train_iter')

            # Pair each source batch with a target batch, cycling the target
            # loader when it is exhausted.
            tgt_data_batch = next(tgt_data_iter, None)
            if tgt_data_batch is None:
                tgt_data_iter = iter(tgt_data_loader)
                tgt_data_batch = next(tgt_data_iter)

            # ------------------------
            # train on src
            # ------------------------
            # One model forward yields both the task losses and the fusion
            # features used below.
            losses, seg_fusion_feats, det_fusion_feats = self.model(
                **src_data_batch)

            # Discriminators, on detached features so the model is untouched.
            # TODO: skip this step when discriminator accuracy exceeds an
            # upper bound (planned, not implemented).
            set_requires_grad([self.seg_disc, self.det_disc],
                              requires_grad=True)
            seg_disc_logits = self.seg_disc(seg_fusion_feats.detach())
            det_disc_logits = self.det_disc(det_fusion_feats.detach())
            seg_disc_loss = self.seg_disc.loss(seg_disc_logits, src=True)
            det_disc_loss = self.det_disc.loss(det_disc_logits, src=True)
            disc_loss = seg_disc_loss + det_disc_loss

            self.seg_opt.zero_grad()
            self.det_opt.zero_grad()
            disc_loss.backward()
            self.seg_opt.step()
            self.det_opt.step()

            # Network: task losses + flipped-label GAN losses. Accuracy on
            # source uses label 1; `pred == 1` works for any prediction shape.
            set_requires_grad([self.seg_disc, self.det_disc],
                              requires_grad=False)
            seg_disc_logits = self.seg_disc(seg_fusion_feats)  # (N, 2)
            seg_acc = (seg_disc_logits.max(1)[1] == 1).float().mean()
            if seg_acc > acc_threshold:
                losses['seg_src_loss'] = self.seg_disc.loss(
                    seg_disc_logits, src=False)

            det_disc_logits = self.det_disc(det_fusion_feats)  # (M, 2)
            det_acc = (det_disc_logits.max(1)[1] == 1).float().mean()
            if det_acc > acc_threshold:
                losses['det_src_loss'] = self.det_disc.loss(
                    det_disc_logits, src=False)

            loss, log_vars = parse_losses(losses)
            num_samples = len(src_data_batch['img_metas'])
            log_vars['seg_src_acc'] = seg_acc.item()
            log_vars['det_src_acc'] = det_acc.item()

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            # ------------------------
            # train on tgt
            # ------------------------
            seg_fusion_feats, det_fusion_feats = self.model.forward_fusion(
                **tgt_data_batch)

            # Discriminators, on detached features.
            set_requires_grad([self.seg_disc, self.det_disc],
                              requires_grad=True)
            seg_disc_logits = self.seg_disc(seg_fusion_feats.detach())
            det_disc_logits = self.det_disc(det_fusion_feats.detach())
            seg_disc_loss = self.seg_disc.loss(seg_disc_logits, src=False)
            det_disc_loss = self.det_disc.loss(det_disc_logits, src=False)
            disc_loss = seg_disc_loss + det_disc_loss

            self.seg_opt.zero_grad()
            self.det_opt.zero_grad()
            disc_loss.backward()
            self.seg_opt.step()
            self.det_opt.step()

            # Network on target domain: GAN losses only (label 0 = target).
            self.optimizer.zero_grad()
            set_requires_grad([self.seg_disc, self.det_disc],
                              requires_grad=False)
            tgt_GANlosses = []

            seg_disc_logits = self.seg_disc(seg_fusion_feats)  # (N, 2)
            seg_acc = (seg_disc_logits.max(1)[1] == 0).float().mean()
            if seg_acc > acc_threshold:
                tgt_GANlosses.append(
                    self.seg_disc.loss(seg_disc_logits, src=True))

            det_disc_logits = self.det_disc(det_fusion_feats)  # (M, 2)
            det_acc = (det_disc_logits.max(1)[1] == 0).float().mean()
            if det_acc > acc_threshold:
                tgt_GANlosses.append(
                    self.det_disc.loss(det_disc_logits, src=True))
            log_vars['seg_tgt_acc'] = seg_acc.item()
            log_vars['det_tgt_acc'] = det_acc.item()

            # Both features come from one forward_fusion call and may share
            # graph nodes. A second separate backward() would then hit an
            # already-freed graph, so sum and backward once. Skip the step
            # when no GAN loss applies.
            if tgt_GANlosses:
                sum(tgt_GANlosses).backward()
                self.optimizer.step()

            # after_train_iter callback
            self.log_buffer.update(log_vars, num_samples)
            self.call_hook('after_train_iter')  # optimizer hook && logger hook
            self._iter += 1

        self.call_hook('after_train_epoch')
        self._epoch += 1
Example #5
0
    def train(self, src_data_loader, tgt_data_loader):
        """Run one epoch of adversarial domain-adaptation training.

        Two discriminators are used: ``self.img_disc`` on the image features
        and ``self.lidar_disc`` on ``self.get_feats(lidar_feats,
        fusion_feats)``. Every source batch is paired with a target batch
        (the target loader is restarted when exhausted) and each iteration:

        1. trains both discriminators on detached source (``src=True``) and
           target (``src=False``) features, each with its own optimizer;
        2. trains ``self.model`` once on the source task losses plus
           flipped-label GAN losses on source and target features, weighted
           by ``self.lambda_img`` / ``self.lambda_lidar``. A GAN loss is only
           added while the discriminator's accuracy on that domain exceeds
           ``self.src_acc_threshold`` / ``self.tgt_acc_threshold``.

        Args:
            src_data_loader: Source batches; drives the epoch length and is
                passed to ``self.model`` to compute the task losses.
            tgt_data_loader: Target batches; must not be longer than
                ``src_data_loader``.
        """
        self.model.train()
        self.img_disc.train()
        self.lidar_disc.train()
        self.mode = 'train'

        self.data_loader = src_data_loader
        tgt_data_iter = iter(tgt_data_loader)
        assert len(tgt_data_loader) <= len(src_data_loader)
        self._max_iters = self._max_epochs * len(self.data_loader)
        self.call_hook('before_train_epoch')
        time.sleep(2)  # Prevent possible deadlock during epoch transition

        for i, src_data_batch in enumerate(self.data_loader):
            self._inner_iter = i
            self.call_hook('before_train_iter')

            # Pair each source batch with a target batch, cycling the target
            # loader when it is exhausted.
            tgt_data_batch = next(tgt_data_iter, None)
            if tgt_data_batch is None:
                tgt_data_iter = iter(tgt_data_loader)
                tgt_data_batch = next(tgt_data_iter)

            # ------------------------
            # train Discriminators
            # ------------------------
            set_requires_grad([self.img_disc, self.lidar_disc],
                              requires_grad=True)
            # Per earlier notes (confirm): img_feats (N, 64, 225, 400);
            # lidar_feats (N, 64); fusion_feats (N, 128).
            # Inputs are detached so the discriminator losses do not
            # backpropagate into the model.
            (src_img_feats, src_lidar_feats,
             src_fusion_feats) = self.model.extract_feat(**src_data_batch)
            src_feats0 = src_img_feats
            src_feats1 = self.get_feats(src_lidar_feats, src_fusion_feats)

            src_Dloss0 = self.img_disc.loss(
                self.img_disc(src_feats0.detach()), src=True)
            log_src_Dloss0 = src_Dloss0.item()
            src_Dloss1 = self.lidar_disc.loss(
                self.lidar_disc(src_feats1.detach()), src=True)
            log_src_Dloss1 = src_Dloss1.item()

            (tgt_img_feats, tgt_lidar_feats,
             tgt_fusion_feats) = self.model.extract_feat(**tgt_data_batch)
            tgt_feats0 = tgt_img_feats
            tgt_feats1 = self.get_feats(tgt_lidar_feats, tgt_fusion_feats)

            tgt_Dloss0 = self.img_disc.loss(
                self.img_disc(tgt_feats0.detach()), src=False)
            log_tgt_Dloss0 = tgt_Dloss0.item()
            tgt_Dloss1 = self.lidar_disc.loss(
                self.lidar_disc(tgt_feats1.detach()), src=False)
            log_tgt_Dloss1 = tgt_Dloss1.item()

            # Each discriminator is updated with its own optimizer on the
            # mean of its source and target losses.
            img_Dloss = (src_Dloss0 + tgt_Dloss0) * 0.5
            lidar_Dloss = (src_Dloss1 + tgt_Dloss1) * 0.5

            self.img_opt.zero_grad()
            img_Dloss.backward()
            self.img_opt.step()

            self.lidar_opt.zero_grad()
            lidar_Dloss.backward()
            self.lidar_opt.step()

            # ------------------------
            # network forward on source: task loss
            #   + lambda_img * img GANLoss + lambda_lidar * lidar GANLoss
            # ------------------------
            (losses, src_img_feats, src_lidar_feats,
             src_fusion_feats) = self.model(**src_data_batch)
            src_feats0 = src_img_feats
            src_feats1 = self.get_feats(src_lidar_feats, src_fusion_feats)

            # Freeze the discriminators so the GAN losses only produce
            # gradients for the model. Accuracy on source uses label 1.
            set_requires_grad([self.img_disc, self.lidar_disc],
                              requires_grad=False)
            src_logits0 = self.img_disc(src_feats0)
            src_acc0 = (src_logits0.max(1)[1] == 1).float().mean()
            if src_acc0 > self.src_acc_threshold:
                # Flipped label: push source features towards "target".
                losses['src_img_GANloss'] = (
                    self.lambda_img
                    * self.img_disc.loss(src_logits0, src=False))

            src_logits1 = self.lidar_disc(src_feats1)
            src_acc1 = (src_logits1.max(1)[1] == 1).float().mean()
            if src_acc1 > self.src_acc_threshold:
                losses['src_lidar_GANloss'] = (
                    self.lambda_lidar
                    * self.lidar_disc.loss(src_logits1, src=False))

            # ------------------------
            # network forward on target: only GANLoss
            # ------------------------
            (tgt_img_feats, tgt_lidar_feats,
             tgt_fusion_feats) = self.model.extract_feat(**tgt_data_batch)
            tgt_feats0 = tgt_img_feats
            # The lidar discriminator input is built from the lidar features,
            # as in every other get_feats call in this method.
            tgt_feats1 = self.get_feats(tgt_lidar_feats, tgt_fusion_feats)

            # Accuracy on target uses label 0.
            tgt_logits0 = self.img_disc(tgt_feats0)
            tgt_acc0 = (tgt_logits0.max(1)[1] == 0).float().mean()
            if tgt_acc0 > self.tgt_acc_threshold:
                # Flipped label: push target features towards "source".
                losses['tgt_img_GANloss'] = (
                    self.lambda_img
                    * self.img_disc.loss(tgt_logits0, src=True))

            tgt_logits1 = self.lidar_disc(tgt_feats1)
            tgt_acc1 = (tgt_logits1.max(1)[1] == 0).float().mean()
            if tgt_acc1 > self.tgt_acc_threshold:
                losses['tgt_lidar_GANloss'] = (
                    self.lambda_lidar
                    * self.lidar_disc.loss(tgt_logits1, src=True))

            loss, log_vars = parse_losses(losses)
            num_samples = len(src_data_batch['img_metas'])
            log_vars['src_img_Dloss'] = log_src_Dloss0
            log_vars['src_lidar_Dloss'] = log_src_Dloss1
            log_vars['tgt_img_Dloss'] = log_tgt_Dloss0
            log_vars['tgt_lidar_Dloss'] = log_tgt_Dloss1
            log_vars['src_img_acc'] = src_acc0.item()
            log_vars['src_lidar_acc'] = src_acc1.item()
            log_vars['tgt_img_acc'] = tgt_acc0.item()
            log_vars['tgt_lidar_acc'] = tgt_acc1.item()

            # ------------------------
            # network backward: src task loss + all weighted GAN losses
            # ------------------------
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            # after_train_iter callback
            self.log_buffer.update(log_vars, num_samples)
            self.call_hook('after_train_iter')  # optimizer hook && logger hook
            self._iter += 1

        self.call_hook('after_train_epoch')
        self._epoch += 1