Example #1
    def train_epoch(self):
        """
        Function to train the model for one epoch
        """
        def set_requires_grad(seg, dis):
            for param in self.model.parameters():
                param.requires_grad = seg

            for param in self.netD.parameters():
                param.requires_grad = dis

        for batch_idx, (datas, datat) in tqdm.tqdm(
                enumerate(itertools.izip(self.train_loader,
                                         self.target_loader)),
                total=self.iters_per_epoch,
                desc='Train epoch = {}/{}'.format(self.epoch, self.max_epoch)):
            self.iteration = batch_idx + self.epoch * self.iters_per_epoch

            source_data, source_labels = datas
            target_data, __ = datat

            self.optim.zero_grad()
            self.optimD.zero_grad()

            src_dis_label = 1
            target_dis_label = 0

            if self.cuda:
                source_data, source_labels = source_data.cuda(
                ), source_labels.cuda()
                target_data = target_data.cuda()

            source_data, source_labels = Variable(source_data), Variable(
                source_labels)
            target_data = Variable(target_data)

            ############ train G, item 1
            # set_requires_grad(seg=True, dis=False)  # disabled in this variant
            # Source domain
            score = self.model(source_data)
            l_seg = CrossEntropyLoss2d_Seg(score,
                                           source_labels,
                                           class_num=class_num,
                                           size_average=self.size_average)

            # target domain
            seg_target_score = self.model(target_data)
            modelfix_target_score = self.model_fix(target_data)

            diff2d = Diff2d()
            distill_loss = diff2d(seg_target_score, modelfix_target_score)

            seg_loss = l_seg + 10 * distill_loss

            # seg_loss.backward(retain_graph=True)  # disabled: folded into total_loss below

            ####### train G, item 2 (disabled: the string literal below is a no-op, kept for reference)
            """
            bce_loss = torch.nn.BCEWithLogitsLoss()

            src_discriminate_result = self.netD(score)
            target_discriminate_result = self.netD(seg_target_score)

            src_dis_loss = bce_loss(src_discriminate_result,
                                    Variable(torch.FloatTensor(src_discriminate_result.data.size()).fill_(
                                        src_dis_label)).cuda())

            target_dis_loss = bce_loss(target_discriminate_result,
                                       Variable(
                                           torch.FloatTensor(target_discriminate_result.data.size()).fill_(
                                               target_dis_label)).cuda(),
                                       )

            dis_loss = src_dis_loss + target_dis_loss
            dis_loss.backward(retain_graph=True)
            """

            ####################### train D
            # set_requires_grad(seg=False, dis=True)  # disabled in this variant
            bce_loss = torch.nn.BCEWithLogitsLoss()

            src_discriminate_result = self.netD(score.detach())
            target_discriminate_result = self.netD(seg_target_score.detach())

            src_dis_loss = bce_loss(
                src_discriminate_result,
                Variable(
                    torch.FloatTensor(
                        src_discriminate_result.data.size()).fill_(
                            src_dis_label)).cuda())

            target_dis_loss = bce_loss(
                target_discriminate_result,
                Variable(
                    torch.FloatTensor(
                        target_discriminate_result.data.size()).fill_(
                            target_dis_label)).cuda(),
            )

            dis_loss = src_dis_loss + target_dis_loss  # this loss has been inverted!
            total_loss = dis_loss + seg_loss

            total_loss.backward()

            self.optim.step()
            self.optimD.step()

            if np.isnan(float(dis_loss.data[0])):
                raise ValueError('dis_loss is nan while training')
            if np.isnan(float(seg_loss.data[0])):
                raise ValueError('seg_loss is nan while training')

            if self.iteration % self.loss_print_interval == 0:
                logger.info(
                    "L_SEG={}, Distill_LOSS={}, Discriminater loss={}".format(
                        l_seg.data[0], distill_loss.data[0], dis_loss.data[0]))
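
Note: all of these examples use pre-0.4 PyTorch idioms (Variable, .data[0], implicit softmax dim) and Python 2 (itertools.izip, iterator .next()). The sketch below is not from the repo; it is a minimal rewrite against current PyTorch, with stand-in 1x1-conv modules, of the one pattern Example #1 depends on: the discriminator is fed score.detach(), so dis_loss updates only netD and never reaches the segmentation network.

import torch
import torch.nn as nn

# Stand-ins (illustrative, not the repo's models): a "segmenter" G and a "discriminator" D.
G = nn.Conv2d(3, 2, kernel_size=1)
D = nn.Conv2d(2, 1, kernel_size=1)
bce = nn.BCEWithLogitsLoss()

score = G(torch.randn(1, 3, 8, 8))
out = D(score.detach())                  # detach: gradients stop at the segmenter
dis_loss = bce(out, torch.ones_like(out))
dis_loss.backward()

assert all(p.grad is None for p in G.parameters())      # segmenter untouched
assert all(p.grad is not None for p in D.parameters())  # only D receives grads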
Example #2
    def train_epoch(self):
        """
        Function to train the model for one epoch
        """
        def set_requires_grad(seg, dis):
            for param in self.model.parameters():
                param.requires_grad = seg

            for param in self.netD.parameters():
                param.requires_grad = dis

        self.G_source_loader_iter = [
            enumerate(self.train_loader) for _ in range(G_STEP)
        ]
        self.G_target_loader_iter = [
            enumerate(self.target_loader) for _ in range(G_STEP)
        ]

        self.D_source_loader_iter = [
            enumerate(self.train_loader) for _ in range(D_STEP)
        ]
        self.D_target_loader_iter = [
            enumerate(self.target_loader) for _ in range(D_STEP)
        ]

        for batch_idx in tqdm.tqdm(range(self.iters_per_epoch),
                                   total=self.iters_per_epoch,
                                   desc='Train epoch = {}/{}'.format(
                                       self.epoch, self.max_epoch)):
            self.iteration = batch_idx + self.epoch * self.iters_per_epoch

            src_dis_label = 1
            target_dis_label = 0
            mse_loss = torch.nn.MSELoss()

            def get_data(source_iter, target_iter):
                _, source_batch = source_iter.next()
                source_data, source_labels = source_batch

                _, target_batch = target_iter.next()
                target_data, _ = target_batch

                if self.cuda:
                    source_data, source_labels = source_data.cuda(
                    ), source_labels.cuda()
                    target_data = target_data.cuda()

                source_data, source_labels = Variable(source_data), Variable(
                    source_labels)
                target_data = Variable(target_data)
                return source_data, source_labels, target_data

            ################################## train D
            for step in range(D_STEP):
                source_data, source_labels, target_data = get_data(
                    self.D_source_loader_iter[step],
                    self.D_target_loader_iter[step])
                self.optimD.zero_grad()
                set_requires_grad(seg=False, dis=True)

                score = self.model(source_data)
                seg_target_score = self.model(target_data)
                src_discriminate_result = self.netD(F.softmax(score, dim=1))
                target_discriminate_result = self.netD(
                    F.softmax(seg_target_score, dim=1))

                src_dis_loss = mse_loss(
                    src_discriminate_result,
                    Variable(
                        torch.FloatTensor(
                            src_discriminate_result.data.size()).fill_(
                                src_dis_label)).cuda())

                target_dis_loss = mse_loss(
                    target_discriminate_result,
                    Variable(
                        torch.FloatTensor(
                            target_discriminate_result.data.size()).fill_(
                                target_dis_label)).cuda(),
                )

                src_dis_loss = src_dis_loss * DIS_WEIGHT
                target_dis_loss = target_dis_loss * DIS_WEIGHT
                dis_loss = src_dis_loss + target_dis_loss
                dis_loss.backward()
                self.optimD.step()
                # https://ewanlee.github.io/2017/04/29/WGAN-implemented-by-PyTorch/
                for p in self.netD.parameters():
                    p.data.clamp_(-0.01, 0.01)

            ##################### train G, item 1
            for step in range(G_STEP):
                source_data, source_labels, target_data = get_data(
                    self.G_source_loader_iter[step],
                    self.G_target_loader_iter[step])
                self.optim.zero_grad()
                set_requires_grad(seg=True, dis=False)
                # Source domain
                score = self.model(source_data)
                l_seg = CrossEntropyLoss2d_Seg(score,
                                               source_labels,
                                               class_num=class_num,
                                               size_average=self.size_average)
                # target domain
                seg_target_score = self.model(target_data)
                modelfix_target_score = self.model_fix(target_data)

                diff2d = Diff2d()
                distill_loss = diff2d(seg_target_score, modelfix_target_score)

                l_seg = l_seg * L_LOSS_WEIGHT
                distill_loss = distill_loss * DISTILL_WEIGHT
                seg_loss = l_seg + distill_loss
                ####### train G, item 2

                src_discriminate_result = self.netD(F.softmax(score, dim=1))
                target_discriminate_result = self.netD(
                    F.softmax(seg_target_score, dim=1))

                src_dis_loss = mse_loss(
                    src_discriminate_result,
                    Variable(
                        torch.FloatTensor(
                            src_discriminate_result.data.size()).fill_(
                                src_dis_label)).cuda())

                # adversarial G objective: the target prediction is scored
                # against the *source* label so the segmenter learns to fool D
                target_dis_loss = mse_loss(
                    target_discriminate_result,
                    Variable(
                        torch.FloatTensor(
                            target_discriminate_result.data.size()).fill_(
                                src_dis_label)).cuda(),
                )

                src_dis_loss = src_dis_loss * DIS_WEIGHT
                target_dis_loss = target_dis_loss * DIS_WEIGHT
                dis_loss = src_dis_loss + target_dis_loss
                total_loss = seg_loss + dis_loss
                total_loss.backward()
                self.optim.step()

            if np.isnan(float(dis_loss.data[0])):
                raise ValueError('dis_loss is nan while training')
            if np.isnan(float(seg_loss.data[0])):
                raise ValueError('seg_loss is nan while training')

            if self.iteration % self.loss_print_interval == 0:
                logger.info(
                    "Weighted losses: seg_loss={}, distill_loss={}, src_dis_loss={}, target_dis_loss={}"
                    .format(l_seg.data[0], distill_loss.data[0],
                            src_dis_loss.data[0], target_dis_loss.data[0]))
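
Example #2 differs from #1 in two ways: it swaps BCE for least-squares (MSE) real/fake targets, and it clamps the discriminator weights after each update, per the WGAN link in the code. A minimal, runnable sketch of both tricks, with a stand-in discriminator and current-PyTorch API (not the repo's code):

import torch
import torch.nn as nn

D = nn.Conv2d(2, 1, kernel_size=1)           # stand-in discriminator
optimD = torch.optim.SGD(D.parameters(), lr=1e-3)
mse = nn.MSELoss()

target_score = torch.randn(1, 2, 8, 8)
out = D(torch.softmax(target_score, dim=1))  # D sees class probabilities
dis_loss = mse(out, torch.zeros_like(out))   # least-squares "fake" target (0)
optimD.zero_grad()
dis_loss.backward()
optimD.step()

for p in D.parameters():                     # WGAN-style weight clipping
    p.data.clamp_(-0.01, 0.01)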
Example #3
    def train_epoch(self):
        """
        Function to train the model for one epoch
        """
        def set_requires_grad(seg, dis):
            for param in self.model.parameters():
                param.requires_grad = seg

            for param in self.netD.parameters():
                param.requires_grad = dis

        for batch_idx, (datas, datat) in tqdm.tqdm(
                enumerate(itertools.izip(self.train_loader,
                                         self.target_loader)),
                total=self.iters_per_epoch,
                desc='Train epoch = {}/{}'.format(self.epoch, self.max_epoch)):
            self.iteration = batch_idx + self.epoch * self.iters_per_epoch

            source_data, source_labels = datas
            target_data, __ = datat

            src_dis_label = 1
            target_dis_label = 0
            bce_loss = torch.nn.BCEWithLogitsLoss()

            if self.cuda:
                source_data, source_labels = source_data.cuda(
                ), source_labels.cuda()
                target_data = target_data.cuda()

            source_data, source_labels = Variable(source_data), Variable(
                source_labels)
            target_data = Variable(target_data)

            ##################### train G, item 1
            self.optim.zero_grad()
            set_requires_grad(seg=True, dis=False)
            # Source domain
            score = self.model(source_data)
            l_seg = CrossEntropyLoss2d_Seg(score,
                                           source_labels,
                                           class_num=class_num,
                                           size_average=self.size_average)
            # target domain
            seg_target_score = self.model(target_data)
            modelfix_target_score = self.model_fix(target_data)

            diff2d = Diff2d()
            distill_loss = diff2d(seg_target_score, modelfix_target_score)

            l_seg = l_seg * L_LOSS_WEIGHT
            distill_loss = distill_loss * DISTILL_WEIGHT
            seg_loss = l_seg + distill_loss
            ####### train G, item 2

            src_discriminate_result = self.netD(F.softmax(score, dim=1))
            target_discriminate_result = self.netD(
                F.softmax(seg_target_score, dim=1))

            src_dis_loss = bce_loss(
                src_discriminate_result,
                Variable(
                    torch.FloatTensor(
                        src_discriminate_result.data.size()).fill_(
                            src_dis_label)).cuda())

            # adversarial G objective: the target prediction is scored against
            # the *source* label so the segmenter learns to fool D
            target_dis_loss = bce_loss(
                target_discriminate_result,
                Variable(
                    torch.FloatTensor(
                        target_discriminate_result.data.size()).fill_(
                            src_dis_label)).cuda(),
            )

            src_dis_loss = src_dis_loss * DIS_WEIGHT
            target_dis_loss = target_dis_loss * DIS_WEIGHT
            dis_loss = src_dis_loss + target_dis_loss
            total_loss = seg_loss + dis_loss
            total_loss.backward()
            self.optim.step()

            ################################## train D
            for _ in range(DIS_TIMES):
                self.optimD.zero_grad()
                set_requires_grad(seg=False, dis=True)

                score = self.model(source_data)
                seg_target_score = self.model(target_data)
                src_discriminate_result = self.netD(F.softmax(score, dim=1))
                target_discriminate_result = self.netD(
                    F.softmax(seg_target_score, dim=1))

                src_dis_loss = bce_loss(
                    src_discriminate_result,
                    Variable(
                        torch.FloatTensor(
                            src_discriminate_result.data.size()).fill_(
                                src_dis_label)).cuda())

                target_dis_loss = bce_loss(
                    target_discriminate_result,
                    Variable(
                        torch.FloatTensor(
                            target_discriminate_result.data.size()).fill_(
                                target_dis_label)).cuda(),
                )

                src_dis_loss = src_dis_loss * DIS_WEIGHT
                target_dis_loss = target_dis_loss * DIS_WEIGHT
                dis_loss = src_dis_loss + target_dis_loss
                dis_loss.backward()
                self.optimD.step()

            if np.isnan(float(dis_loss.data[0])):
                raise ValueError('dis_loss is nan while training')
            if np.isnan(float(seg_loss.data[0])):
                raise ValueError('seg_loss is nan while training')

            if self.iteration % self.loss_print_interval == 0 or (
                    self.epoch == 0
                    and self.iteration < self.loss_print_interval):
                logger.info(
                    "Weighted losses: seg_loss={}, distill_loss={}, src_dis_loss={}, target_dis_loss={}"
                    .format(l_seg.data[0], distill_loss.data[0],
                            src_dis_loss.data[0], target_dis_loss.data[0]))
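
The G step in Examples #2 and #3 scores the target-domain prediction against the source label (fill_(src_dis_label)) rather than the target one: that label flip is what turns D's output into an adversarial signal for the segmenter. A self-contained sketch of just that flip, with illustrative names and current-PyTorch API:

import torch
import torch.nn as nn

D = nn.Conv2d(2, 1, kernel_size=1)            # stand-in discriminator
bce = nn.BCEWithLogitsLoss()
SRC_LABEL = 1.0                               # "looks like source domain"

target_score = torch.randn(1, 2, 8, 8, requires_grad=True)
out = D(torch.softmax(target_score, dim=1))
adv_loss = bce(out, torch.full_like(out, SRC_LABEL))  # pretend target is source
adv_loss.backward()                           # gradients reach target_score,
assert target_score.grad is not None          # i.e. they would train the segmenter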
Example #4
    def train_epoch(self):
        """
        Function to train the model for one epoch
        """
        self.model.train()
        self.netD.train()

        for batch_idx, (datas, datat) in tqdm.tqdm(
                enumerate(itertools.izip(self.train_loader,
                                         self.target_loader)),
                total=self.iters_per_epoch,
                desc='Train epoch = {}/{}'.format(self.epoch, self.max_epoch),
                leave=False):

            source_data, source_labels = datas
            target_data, __ = datat

            self.iteration = batch_idx + self.epoch * self.iters_per_epoch

            if self.cuda:
                source_data, source_labels = source_data.cuda(
                ), source_labels.cuda()
                target_data = target_data.cuda()

            source_data, source_labels = Variable(source_data), Variable(
                source_labels)
            target_data = Variable(target_data)

            # TODO: split to 3x3
            # Source domain
            score = self.model(source_data)
            l_seg = CrossEntropyLoss2d_Seg(score,
                                           source_labels,
                                           size_average=self.size_average)

            src_discriminate_result = self.netD(score)

            # target domain
            seg_target_score = self.model(target_data)
            modelfix_target_score = self.model_fix(target_data)

            target_discriminate_result = self.netD(seg_target_score)

            diff2d = Diff2d()
            distill_loss = diff2d(seg_target_score, modelfix_target_score)

            bce_loss = torch.nn.BCEWithLogitsLoss()

            src_dis_loss = bce_loss(
                src_discriminate_result,
                Variable(
                    torch.FloatTensor(
                        src_discriminate_result.data.size()).fill_(1)).cuda())

            target_dis_loss = bce_loss(
                target_discriminate_result,
                Variable(
                    torch.FloatTensor(
                        target_discriminate_result.data.size()).fill_(
                            0)).cuda(),
            )

            dis_loss = src_dis_loss + target_dis_loss  # inverted: no detach above, so this D loss also backprops into the segmenter
            total_loss = l_seg + 10 * distill_loss + dis_loss

            self.optim.zero_grad()
            self.optimD.zero_grad()
            total_loss.backward()
            self.optim.step()
            self.optimD.step()

            if np.isnan(float(dis_loss.data[0])):
                raise ValueError('dis_loss is nan while training')
            if np.isnan(float(total_loss.data[0])):
                raise ValueError('total_loss is nan while training')

            if self.iteration % self.loss_print_interval == 0:
                logger.info(
                    "L_SEG={}, Distill_LOSS={}, Discriminater loss={}, TOTAL_LOSS={}"
                    .format(l_seg.data[0], distill_loss.data[0],
                            dis_loss.data[0], total_loss.data[0]))

            # TODO: spatial loss

            if self.iteration >= self.max_iter:
                break

            # Validating periodically
            if self.iteration % self.interval_validate == 0 and self.iteration > 0:
                self.model.eval()
                self.validate()
                self.model.train()  # return to training mode
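
Unlike Examples #1 and #5, Example #4 never detaches the scores before feeding netD, and it runs a single backward over l_seg + 10 * distill_loss + dis_loss followed by both optimizer steps, so the discriminator loss also reaches the segmentation network (hence the "inverted" note above). A sketch demonstrating that gradient flow with stand-in modules (assumption: simplified, current-PyTorch rewrite):

import torch
import torch.nn as nn

G = nn.Conv2d(3, 2, kernel_size=1)            # stand-in segmenter
D = nn.Conv2d(2, 1, kernel_size=1)            # stand-in discriminator
bce = nn.BCEWithLogitsLoss()

score = G(torch.randn(1, 3, 8, 8))
out = D(score)                                # no detach
dis_loss = bce(out, torch.ones_like(out))
dis_loss.backward()

# Both networks now hold gradients from the discriminator loss.
assert all(p.grad is not None for p in G.parameters())
assert all(p.grad is not None for p in D.parameters())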
Example #5
    def train_epoch(self):
        """
        Function to train the model for one epoch
        """
        def set_requires_grad(seg, dis):
            for param in self.model.parameters():
                param.requires_grad = seg

            for param in self.netD.parameters():
                param.requires_grad = dis

        self.train_loader_iter = enumerate(self.train_loader)
        self.target_loader_iter = enumerate(self.target_loader)

        for batch_idx in tqdm.tqdm(
                range(self.iters_per_epoch),
                total=self.iters_per_epoch,
                desc='Train epoch = {}/{}'.format(self.epoch, self.max_epoch)):
            self.iteration = batch_idx + self.epoch * self.iters_per_epoch

            self.optim.zero_grad()
            self.optimD.zero_grad()

            sum_l_seg = 0
            sum_distill_loss = 0
            sum_src_dis_loss = 0
            sum_target_dis_loss = 0

            # for sub_iter in range(self.iter_size):  # disabled: the sums now cover a single batch

            _, source_batch = self.train_loader_iter.next()
            source_data, source_labels = source_batch

            _, target_batch = self.target_loader_iter.next()
            target_data, _ = target_batch

            if self.cuda:
                source_data, source_labels = source_data.cuda(), source_labels.cuda()
                target_data = target_data.cuda()

            source_data, source_labels = Variable(source_data), Variable(source_labels)
            target_data = Variable(target_data)

            ####################### train G
            set_requires_grad(seg=True, dis=False)
            # Source domain
            score = self.model(source_data)
            l_seg = CrossEntropyLoss2d_Seg(score, source_labels, class_num=class_num, size_average=self.size_average)

            sum_l_seg += l_seg.data[0]

            # target domain
            seg_target_score = self.model(target_data)
            modelfix_target_score = self.model_fix(target_data)

            diff2d = Diff2d()
            distill_loss = diff2d(seg_target_score, modelfix_target_score)
            sum_distill_loss += distill_loss.data[0]

            seg_loss = l_seg + 10 * distill_loss

            seg_loss.backward()

            ######################### train D
            set_requires_grad(seg=False, dis=True)
            src_discriminate_result = self.netD(score.detach())
            target_discriminate_result = self.netD(seg_target_score.detach())

            bce_loss = torch.nn.BCEWithLogitsLoss()

            src_dis_loss = bce_loss(
                src_discriminate_result,
                Variable(torch.FloatTensor(
                    src_discriminate_result.data.size()).fill_(1)).cuda())

            target_dis_loss = bce_loss(
                target_discriminate_result,
                Variable(torch.FloatTensor(
                    target_discriminate_result.data.size()).fill_(0)).cuda())

            sum_src_dis_loss += src_dis_loss.data[0]
            sum_target_dis_loss += target_dis_loss.data[0]

            dis_loss = src_dis_loss + target_dis_loss  # this loss has been inverted!

            dis_loss.backward()

            if np.isnan(float(dis_loss.data[0])):
                raise ValueError('dis_loss is nan while training')
            if np.isnan(float(seg_loss.data[0])):
                raise ValueError('seg_loss is nan while training')

            self.optim.step()
            self.optimD.step()

            if self.iteration % self.loss_print_interval == 0:
                logger.info("L_SEG={}, Distill_LOSS={}, src dis loss={}, target dis loss={}".format(sum_l_seg, sum_distill_loss,
                                                                               sum_src_dis_loss,sum_target_dis_loss))
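
Diff2d, used in every example as the distillation term against the frozen model_fix, is defined elsewhere in the repo and is not shown here. Purely as a labelled guess at its shape, a 2-D distillation distance between two score maps could be the mean absolute difference of their softmax outputs:

import torch
import torch.nn.functional as F

def diff2d_sketch(pred, teacher):
    # Hypothetical stand-in, NOT the repo's Diff2d: mean |softmax - softmax|
    # over all pixels and classes of two (N, C, H, W) score maps.
    return torch.mean(torch.abs(F.softmax(pred, dim=1) -
                                F.softmax(teacher, dim=1)))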