Example #1
    def validate(self):
        """
        Function to validate a training model on the val split.
        """
        logger.info("start validation....")
        val_loss = 0
        label_trues, label_preds = [], []

        # Evaluation
        for batch_idx, (data, target) in tqdm.tqdm(
                enumerate(self.val_loader),
                total=len(self.val_loader),
                desc='Validation iteration={}, epoch={}'.format(
                    self.iteration, self.epoch),
                leave=False):

            if self.cuda:
                data, target = data.cuda(), target.cuda()
            data, target = Variable(data, volatile=True), Variable(target)

            score = self.model(data)

            loss = CrossEntropyLoss2d_Seg(score,
                                          target,
                                          size_average=self.size_average)

            if np.isnan(float(loss.data[0])):
                raise ValueError('loss is nan while validating')
            val_loss += float(loss.data[0]) / len(data)

            lbl_pred = score.data.max(1)[1].cpu().numpy()  # (N, H, W) argmax label maps
            lbl_true = target.data.cpu().numpy()

            label_trues.append(lbl_true)
            label_preds.append(lbl_pred)

        # Computing the metrics
        acc, acc_cls, mean_iu, _ = torchfcn.utils.label_accuracy_score(
            label_trues, label_preds, self.n_class)
        val_loss /= len(self.val_loader)

        logger.info("iteration={},epoch={},validation mIoU = {}".format(
            self.iteration, self.epoch, mean_iu))

        is_best = mean_iu > self.best_mean_iu
        if is_best:
            self.best_mean_iu = mean_iu
        torch.save(
            {
                'epoch': self.epoch,
                'iteration': self.iteration,
                'arch': self.model.__class__.__name__,
                'optim_state_dict': self.optim.state_dict(),
                'model_state_dict': self.model.state_dict(),
                'best_mean_iu': self.best_mean_iu,
            }, osp.join(logger.get_logger_dir(), 'checkpoint.pth.tar'))
        if is_best:
            shutil.copy(
                osp.join(logger.get_logger_dir(), 'checkpoint.pth.tar'),
                osp.join(logger.get_logger_dir(), 'model_best.pth.tar'))
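A note on vintage: Variable(..., volatile=True) and loss.data[0] are pre-0.4 PyTorch idioms. A minimal sketch of the same evaluation core on PyTorch >= 0.4, with model, val_loader, and criterion as stand-ins for the trainer attributes above:

import torch

def evaluate(model, val_loader, criterion, device):
    model.eval()
    val_loss = 0.0
    label_trues, label_preds = [], []
    with torch.no_grad():  # replaces Variable(..., volatile=True)
        for data, target in val_loader:
            data, target = data.to(device), target.to(device)
            score = model(data)
            val_loss += criterion(score, target).item() / len(data)  # .item() replaces .data[0]
            label_preds.append(score.argmax(dim=1).cpu().numpy())  # (N, H, W) labels
            label_trues.append(target.cpu().numpy())
    return val_loss / len(val_loader), label_trues, label_preds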
Example #2
    def train_epoch(self):
        """
        Function to train the model for one epoch
        """
        def set_requires_grad(seg, dis):
            for param in self.model.parameters():
                param.requires_grad = seg

            for param in self.netD.parameters():
                param.requires_grad = dis

        for batch_idx, (datas, datat) in tqdm.tqdm(
                enumerate(itertools.izip(self.train_loader,
                                         self.target_loader)),
                total=self.iters_per_epoch,
                desc='Train epoch = {}/{}'.format(self.epoch, self.max_epoch)):
            self.iteration = batch_idx + self.epoch * self.iters_per_epoch

            source_data, source_labels = datas
            target_data, __ = datat

            self.optim.zero_grad()
            self.optimD.zero_grad()

            src_dis_label = 1
            target_dis_label = 0

            if self.cuda:
                source_data, source_labels = source_data.cuda(
                ), source_labels.cuda()
                target_data = target_data.cuda()

            source_data, source_labels = Variable(source_data), Variable(
                source_labels)
            target_data = Variable(target_data)

            ############ train G, item 1
            # set_requires_grad(seg=True, dis=False)
            # Source domain
            score = self.model(source_data)
            l_seg = CrossEntropyLoss2d_Seg(score,
                                           source_labels,
                                           class_num=class_num,
                                           size_average=self.size_average)

            # target domain
            seg_target_score = self.model(target_data)
            modelfix_target_score = self.model_fix(target_data)

            diff2d = Diff2d()
            distill_loss = diff2d(seg_target_score, modelfix_target_score)

            seg_loss = l_seg + 10 * distill_loss

            #seg_loss.backward(retain_graph=True)

            ####### train G, item 2
            # The quoted block below is a disabled variant of the
            # discriminator update (no .detach()); it is never executed:
            """
            bce_loss = torch.nn.BCEWithLogitsLoss()

            src_discriminate_result = self.netD(score)
            target_discriminate_result = self.netD(seg_target_score)

            src_dis_loss = bce_loss(src_discriminate_result,
                                    Variable(torch.FloatTensor(src_discriminate_result.data.size()).fill_(
                                        src_dis_label)).cuda())

            target_dis_loss = bce_loss(target_discriminate_result,
                                       Variable(
                                           torch.FloatTensor(target_discriminate_result.data.size()).fill_(
                                               target_dis_label)).cuda(),
                                       )

            dis_loss = src_dis_loss + target_dis_loss
            dis_loss.backward(retain_graph=True)
            """

            ####################### train D
            # set_requires_grad(seg=False, dis=True)
            bce_loss = torch.nn.BCEWithLogitsLoss()

            src_discriminate_result = self.netD(score.detach())
            target_discriminate_result = self.netD(seg_target_score.detach())

            src_dis_loss = bce_loss(
                src_discriminate_result,
                Variable(
                    torch.FloatTensor(
                        src_discriminate_result.data.size()).fill_(
                            src_dis_label)).cuda())

            target_dis_loss = bce_loss(
                target_discriminate_result,
                Variable(
                    torch.FloatTensor(
                        target_discriminate_result.data.size()).fill_(
                            target_dis_label)).cuda(),
            )

            dis_loss = src_dis_loss + target_dis_loss  # this loss has been inverted!!
            total_loss = dis_loss + seg_loss

            total_loss.backward()

            self.optim.step()
            self.optimD.step()

            if np.isnan(float(dis_loss.data[0])):
                raise ValueError('dis_loss is nan while training')
            if np.isnan(float(seg_loss.data[0])):
                raise ValueError('seg_loss is nan while training')

            if self.iteration % self.loss_print_interval == 0:
                logger.info(
                    "L_SEG={}, Distill_LOSS={}, Discriminater loss={}".format(
                        l_seg.data[0], distill_loss.data[0], dis_loss.data[0]))
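Because both scores are passed to the discriminator through .detach(), the single total_loss.backward() above cannot push discriminator gradients back into the segmenter. A toy sketch of that stop-gradient behaviour, with hypothetical stand-in modules G and D:

import torch
import torch.nn as nn

G = nn.Linear(4, 4)   # stand-in for the segmentation network
D = nn.Linear(4, 1)   # stand-in for the discriminator
x = torch.randn(2, 4)

fake = G(x)
seg_loss = fake.pow(2).mean()        # gradient reaches G
dis_loss = D(fake.detach()).mean()   # gradient stops at .detach(), never reaches G
(seg_loss + dis_loss).backward()     # one backward populates both .grad sets

assert G.weight.grad is not None     # from seg_loss only
assert D.weight.grad is not None     # from dis_loss only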
Example #3
    def train_epoch(self):
        """
        Function to train the model for one epoch
        """
        def set_requires_grad(seg, dis):
            for param in self.model.parameters():
                param.requires_grad = seg

            for param in self.netD.parameters():
                param.requires_grad = dis

        self.G_source_loader_iter = [
            enumerate(self.train_loader) for _ in range(G_STEP)
        ]
        self.G_target_loader_iter = [
            enumerate(self.target_loader) for _ in range(G_STEP)
        ]

        self.D_source_loader_iter = [
            enumerate(self.train_loader) for _ in range(D_STEP)
        ]
        self.D_target_loader_iter = [
            enumerate(self.target_loader) for _ in range(D_STEP)
        ]

        for batch_idx in tqdm.tqdm(range(self.iters_per_epoch),
                                   total=self.iters_per_epoch,
                                   desc='Train epoch = {}/{}'.format(
                                       self.epoch, self.max_epoch)):
            self.iteration = batch_idx + self.epoch * self.iters_per_epoch

            src_dis_label = 1
            target_dis_label = 0
            mse_loss = torch.nn.MSELoss()

            def get_data(source_iter, target_iter):
                _, source_batch = source_iter.next()
                source_data, source_labels = source_batch

                _, target_batch = target_iter.next()
                target_data, _ = target_batch

                if self.cuda:
                    source_data, source_labels = source_data.cuda(
                    ), source_labels.cuda()
                    target_data = target_data.cuda()

                source_data, source_labels = Variable(source_data), Variable(
                    source_labels)
                target_data = Variable(target_data)
                return source_data, source_labels, target_data

            ################################## train D
            for step in range(D_STEP):
                source_data, source_labels, target_data = get_data(
                    self.D_source_loader_iter[step],
                    self.D_target_loader_iter[step])
                self.optimD.zero_grad()
                set_requires_grad(seg=False, dis=True)

                score = self.model(source_data)
                seg_target_score = self.model(target_data)
                src_discriminate_result = self.netD(F.softmax(score))
                target_discriminate_result = self.netD(
                    F.softmax(seg_target_score))

                src_dis_loss = mse_loss(
                    src_discriminate_result,
                    Variable(
                        torch.FloatTensor(
                            src_discriminate_result.data.size()).fill_(
                                src_dis_label)).cuda())

                target_dis_loss = mse_loss(
                    target_discriminate_result,
                    Variable(
                        torch.FloatTensor(
                            target_discriminate_result.data.size()).fill_(
                                target_dis_label)).cuda(),
                )

                src_dis_loss = src_dis_loss * DIS_WEIGHT
                target_dis_loss = target_dis_loss * DIS_WEIGHT
                dis_loss = src_dis_loss + target_dis_loss
                dis_loss.backward()
                self.optimD.step()
                # https://ewanlee.github.io/2017/04/29/WGAN-implemented-by-PyTorch/
                for p in self.netD.parameters():
                    p.data.clamp_(-0.01, 0.01)

            ##################### train G, item 1
            for step in range(G_STEP):
                source_data, source_labels, target_data = get_data(
                    self.G_source_loader_iter[step],
                    self.G_target_loader_iter[step])
                self.optim.zero_grad()
                set_requires_grad(seg=True, dis=False)
                # Source domain
                score = self.model(source_data)
                l_seg = CrossEntropyLoss2d_Seg(score,
                                               source_labels,
                                               class_num=class_num,
                                               size_average=self.size_average)
                # target domain
                seg_target_score = self.model(target_data)
                modelfix_target_score = self.model_fix(target_data)

                diff2d = Diff2d()
                distill_loss = diff2d(seg_target_score, modelfix_target_score)

                l_seg = l_seg * L_LOSS_WEIGHT
                distill_loss = distill_loss * DISTILL_WEIGHT
                seg_loss = l_seg + distill_loss
                #######train G, item 2

                src_discriminate_result = self.netD(F.softmax(score))
                target_discriminate_result = self.netD(
                    F.softmax(seg_target_score))

                src_dis_loss = mse_loss(
                    src_discriminate_result,
                    Variable(
                        torch.FloatTensor(
                            src_discriminate_result.data.size()).fill_(
                                src_dis_label)).cuda())

                # generator step: target outputs are labelled as source
                # (src_dis_label) so the segmenter learns to fool D
                target_dis_loss = mse_loss(
                    target_discriminate_result,
                    Variable(
                        torch.FloatTensor(
                            target_discriminate_result.data.size()).fill_(
                                src_dis_label)).cuda(),
                )

                src_dis_loss = src_dis_loss * DIS_WEIGHT
                target_dis_loss = target_dis_loss * DIS_WEIGHT
                dis_loss = src_dis_loss + target_dis_loss
                total_loss = seg_loss + dis_loss
                total_loss.backward()
                self.optim.step()

            if np.isnan(float(dis_loss.data[0])):
                raise ValueError('dis_loss is nan while training')
            if np.isnan(float(seg_loss.data[0])):
                raise ValueError('seg_loss is nan while training')

            if self.iteration % self.loss_print_interval == 0:
                logger.info(
                    "After weight Loss: seg_Loss={}, distill_LOSS={}, src_dis_loss={}, target_dis_loss={}"
                    .format(l_seg.data[0], distill_loss.data[0],
                            src_dis_loss.data[0], target_dis_loss.data[0]))
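Two side notes on this variant. First, the repeated Variable(torch.FloatTensor(result.data.size()).fill_(label)).cuda() is the pre-0.4 way to build a constant target shaped like the discriminator output; on PyTorch >= 0.4 it collapses to ones_like/zeros_like, sketched below with d_out as a stand-in output. Second, the per-parameter clamp_(-0.01, 0.01) after each D step is the WGAN-style weight clipping described in the linked post.

import torch

d_out = torch.randn(2, 1)              # stand-in discriminator output
mse = torch.nn.MSELoss()

real_target = torch.ones_like(d_out)   # fill_(1) equivalent, same shape/device/dtype
fake_target = torch.zeros_like(d_out)  # fill_(0) equivalent

src_dis_loss = mse(d_out, real_target)
target_dis_loss = mse(d_out, fake_target)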
Example #4
    def train_epoch(self):
        """
        Function to train the model for one epoch
        """

        def set_requires_grad(seg, dis):
            for param in self.model.parameters():
                param.requires_grad = seg

            for param in self.netD.parameters():
                param.requires_grad = dis

        for batch_idx, (datas, datat) in tqdm.tqdm(
                enumerate(itertools.izip(self.train_loader, self.target_loader)),
                total=self.iters_per_epoch,
                desc='Train epoch = {}/{}'.format(self.epoch, self.max_epoch)):
            self.iteration = batch_idx + self.epoch * self.iters_per_epoch

            source_data, source_labels = datas
            target_data, __ = datat

            src_dis_label = 1
            target_dis_label = 0
            bce_loss = torch.nn.BCEWithLogitsLoss()

            if self.cuda:
                source_data, source_labels = source_data.cuda(), source_labels.cuda()
                target_data = target_data.cuda()

            source_data, source_labels = Variable(source_data), Variable(source_labels)
            target_data = Variable(target_data)

            #####################train G, item1
            self.optim.zero_grad()
            set_requires_grad(seg=True, dis=False)
            # Source domain
            score = self.model(source_data)
            l_seg = CrossEntropyLoss2d_Seg(score, source_labels, class_num=class_num, size_average=self.size_average)
            # target domain
            seg_target_score = self.model(target_data)
            #modelfix_target_score = self.model_fix(target_data)

            #diff2d = Diff2d()
            #distill_loss = diff2d(seg_target_score, modelfix_target_score)

            l_seg = l_seg * L_LOSS_WEIGHT
            #distill_loss = distill_loss * DISTILL_WEIGHT
            seg_loss = l_seg  # + distill_loss (distillation is disabled in this variant)
            ####### train G, item 2

            src_discriminate_result = self.netD(F.softmax(score))
            target_discriminate_result = self.netD(F.softmax(seg_target_score))

            src_dis_loss = bce_loss(src_discriminate_result,
                                    Variable(torch.FloatTensor(src_discriminate_result.data.size()).fill_(
                                        src_dis_label)).cuda())

            target_dis_loss = bce_loss(target_discriminate_result,
                                       Variable(
                                           torch.FloatTensor(target_discriminate_result.data.size()).fill_(
                                               target_dis_label)).cuda(),
                                       )

            # minimax generator loss: the discriminator losses are negated so
            # that stepping self.optim drives D's error up
            src_dis_loss = -src_dis_loss * DIS_WEIGHT
            target_dis_loss = -target_dis_loss * DIS_WEIGHT
            dis_loss = src_dis_loss + target_dis_loss
            total_loss = seg_loss + dis_loss
            total_loss.backward()
            self.optim.step()

            ################################## train D
            self.optimD.zero_grad()
            set_requires_grad(seg=False, dis=True)

            score = self.model(source_data)
            seg_target_score = self.model(target_data)
            src_discriminate_result = self.netD(F.softmax(score))
            target_discriminate_result = self.netD(F.softmax(seg_target_score))

            src_dis_loss = bce_loss(src_discriminate_result,
                                    Variable(torch.FloatTensor(src_discriminate_result.data.size()).fill_(
                                        src_dis_label)).cuda())

            target_dis_loss = bce_loss(target_discriminate_result,
                                       Variable(torch.FloatTensor(target_discriminate_result.data.size()).fill_(
                                           target_dis_label)).cuda(),
                                       )

            src_dis_loss = src_dis_loss*DIS_WEIGHT
            target_dis_loss = target_dis_loss*DIS_WEIGHT
            dis_loss = src_dis_loss + target_dis_loss
            dis_loss.backward()
            self.optimD.step()

            if np.isnan(float(dis_loss.data[0])):
                raise ValueError('dis_loss is nan while training')
            if np.isnan(float(seg_loss.data[0])):
                raise ValueError('seg_loss is nan while training')

            if self.iteration % self.loss_print_interval == 0:
                logger.info(
                    "Weighted losses: seg_loss={}, distill_loss=disabled, src_dis_loss={}, target_dis_loss={}"
                    .format(l_seg.data[0], src_dis_loss.data[0],
                            target_dis_loss.data[0]))
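This variant trains G by negating the discriminator's BCE loss, i.e. the minimax formulation; the more common non-saturating alternative flips the label instead of the sign and usually gives stronger early gradients. A toy comparison of the two generator-side losses, with d_out_fake as a hypothetical stand-in for D's logits on target/generated inputs:

import torch

bce = torch.nn.BCEWithLogitsLoss()
d_out_fake = torch.randn(2, 1)   # stand-in for D(G(x)) logits

# minimax (as above): minimize -BCE(D(G(x)), fake_label),
# i.e. ascend on the discriminator's own loss
g_loss_minimax = -bce(d_out_fake, torch.zeros_like(d_out_fake))

# non-saturating: label fakes as real and descend normally
g_loss_nonsat = bce(d_out_fake, torch.ones_like(d_out_fake))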
Example #5
    def train_epoch(self):
        """
        Function to train the model for one epoch
        """
        self.model.train()
        self.netD.train()

        for batch_idx, (datas, datat) in tqdm.tqdm(
                enumerate(itertools.izip(self.train_loader,
                                         self.target_loader)),
                total=self.iters_per_epoch,
                desc='Train epoch = {}/{}'.format(self.epoch, self.max_epoch),
                leave=False):

            source_data, source_labels = datas
            target_data, __ = datat

            self.iteration = batch_idx + self.epoch * self.iters_per_epoch

            if self.cuda:
                source_data, source_labels = source_data.cuda(
                ), source_labels.cuda()
                target_data = target_data.cuda()

            source_data, source_labels = Variable(source_data), Variable(
                source_labels)
            target_data = Variable(target_data)

            # TODO: split into 3x3
            # Source domain
            score = self.model(source_data)
            l_seg = CrossEntropyLoss2d_Seg(score,
                                           source_labels,
                                           size_average=self.size_average)

            src_discriminate_result = self.netD(score)

            # target domain
            seg_target_score = self.model(target_data)
            modelfix_target_score = self.model_fix(target_data)

            target_discriminate_result = self.netD(seg_target_score)

            diff2d = Diff2d()
            distill_loss = diff2d(seg_target_score, modelfix_target_score)

            bce_loss = torch.nn.BCEWithLogitsLoss()

            src_dis_loss = bce_loss(
                src_discriminate_result,
                Variable(
                    torch.FloatTensor(
                        src_discriminate_result.data.size()).fill_(1)).cuda())

            target_dis_loss = bce_loss(
                target_discriminate_result,
                Variable(
                    torch.FloatTensor(
                        target_discriminate_result.data.size()).fill_(
                            0)).cuda(),
            )

            dis_loss = src_dis_loss + target_dis_loss  # this loss has been inverted!!
            total_loss = l_seg + 10 * distill_loss + dis_loss

            self.optim.zero_grad()
            self.optimD.zero_grad()
            total_loss.backward()
            self.optim.step()
            self.optimD.step()

            if np.isnan(float(dis_loss.data[0])):
                raise ValueError('dis_loss is nan while training')
            if np.isnan(float(total_loss.data[0])):
                raise ValueError('total_loss is nan while training')

            if self.iteration % self.loss_print_interval == 0:
                logger.info(
                    "L_SEG={}, Distill_LOSS={}, Discriminater loss={}, TOTAL_LOSS={}"
                    .format(l_seg.data[0], distill_loss.data[0],
                            dis_loss.data[0], total_loss.data[0]))

            # TODO, spatial loss

            if self.iteration >= self.max_iter:
                break

            # Validating periodically
            if self.iteration % self.interval_validate == 0 and self.iteration > 0:
                self.model.eval()
                self.validate()
                self.model.train()  # return to training mode
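Here a single total_loss.backward() feeds the same discriminator loss into both optimizers, so the "inverted" sign has to come from somewhere for the update to be adversarial rather than cooperative. One standard single-pass mechanism is a gradient reversal layer (as in DANN-style domain adaptation); a minimal sketch on modern PyTorch, with lambd a hypothetical reversal strength:

import torch

class GradReverse(torch.autograd.Function):
    # identity in the forward pass, negated (scaled) gradient in the backward pass
    @staticmethod
    def forward(ctx, x, lambd):
        ctx.lambd = lambd
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        return -ctx.lambd * grad_output, None

# usage sketch: reverse gradients between the segmenter and the discriminator
# dis_in = GradReverse.apply(seg_target_score, 1.0)
# target_discriminate_result = netD(dis_in)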
Example #6
    def train_epoch(self):
        """
        Function to train the model for one epoch
        """
        def set_requires_grad(seg, dis):
            for param in self.model.parameters():
                param.requires_grad = seg

            for param in self.netD.parameters():
                param.requires_grad = dis

        self.train_loader_iter = enumerate(self.train_loader)
        self.target_loader_iter = enumerate(self.target_loader)

        for batch_idx in tqdm.tqdm(
                range(self.iters_per_epoch),
                total=self.iters_per_epoch,
                desc='Train epoch = {}/{}'.format(self.epoch, self.max_epoch)):
            self.iteration = batch_idx + self.epoch * self.iters_per_epoch

            self.optim.zero_grad()
            self.optimD.zero_grad()

            sum_l_seg = 0
            sum_distill_loss = 0
            sum_src_dis_loss = 0
            sum_target_dis_loss = 0

            #for sub_iter in range(self.iter_size):

            _, source_batch = self.train_loader_iter.next()
            source_data, source_labels = source_batch

            _, target_batch = self.target_loader_iter.next()
            target_data, _ = target_batch

            if self.cuda:
                source_data, source_labels = source_data.cuda(), source_labels.cuda()
                target_data = target_data.cuda()

            source_data, source_labels = Variable(source_data), Variable(source_labels)
            target_data = Variable(target_data)

            ####################### train G
            set_requires_grad(seg=True, dis=False)
            # Source domain
            score = self.model(source_data)
            l_seg = CrossEntropyLoss2d_Seg(score, source_labels, class_num=class_num, size_average=self.size_average)

            sum_l_seg += l_seg.data[0]

            # target domain
            seg_target_score = self.model(target_data)
            modelfix_target_score = self.model_fix(target_data)

            diff2d = Diff2d()
            distill_loss = diff2d(seg_target_score, modelfix_target_score)
            sum_distill_loss += distill_loss.data[0]

            seg_loss = l_seg + 10 * distill_loss

            seg_loss.backward()

            ######################### train D
            set_requires_grad(seg=False, dis=True)
            src_discriminate_result = self.netD(score.detach())
            target_discriminate_result = self.netD(seg_target_score.detach())

            bce_loss = torch.nn.BCEWithLogitsLoss()

            src_dis_loss = bce_loss(
                src_discriminate_result,
                Variable(torch.FloatTensor(
                    src_discriminate_result.data.size()).fill_(1)).cuda())

            target_dis_loss = bce_loss(
                target_discriminate_result,
                Variable(torch.FloatTensor(
                    target_discriminate_result.data.size()).fill_(0)).cuda())

            sum_src_dis_loss += src_dis_loss.data[0]
            sum_target_dis_loss += target_dis_loss.data[0]

            dis_loss = src_dis_loss + target_dis_loss  # this loss has been inverted!!

            dis_loss.backward()

            if np.isnan(float(dis_loss.data[0])):
                raise ValueError('dis_loss is nan while training')
            if np.isnan(float(seg_loss.data[0])):
                raise ValueError('seg_loss is nan while training')

            self.optim.step()
            self.optimD.step()


            if self.iteration % self.loss_print_interval == 0:
                logger.info(
                    "L_SEG={}, Distill_LOSS={}, src_dis_loss={}, target_dis_loss={}".format(
                        sum_l_seg, sum_distill_loss,
                        sum_src_dis_loss, sum_target_dis_loss))
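One fragility worth noting: train_loader_iter.next() (Python 2 spelling; next(it) in Python 3) is called iters_per_epoch times per epoch, and it raises StopIteration once the underlying loader is exhausted. A common guard, sketched here under the assumption that re-iterating the DataLoader reshuffles it, is an infinite cycling wrapper:

def cycle(loader):
    # restart the loader whenever it is exhausted, yielding batches forever
    while True:
        for batch in loader:
            yield batch

# usage sketch:
# train_iter = cycle(self.train_loader)
# source_data, source_labels = next(train_iter)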