Example 1

import numpy as np
import torch
import torch.nn.functional as F
from tqdm import tqdm
# Project-specific helpers (Evaluator, Saver, TensorboardSummary, and the
# FocalLossLayer / LovaszLossLayer / ConsistencyLossLayer loss functions) are
# assumed to be importable from the surrounding codebase.

class operater(object):
    def __init__(self, args, student_model, teacher_model, src_loader,
                 trg_loader, val_loader, optimizer, teacher_optimizer):

        self.args = args
        self.student_model = student_model
        self.teacher_model = teacher_model
        self.src_loader = src_loader
        self.trg_loader = trg_loader
        self.val_loader = val_loader
        self.optimizer = optimizer
        self.teacher_optimizer = teacher_optimizer
        # Define Evaluator
        self.evaluator = Evaluator(args.nclass)
        # Define lr scheduler
        # self.scheduler = LR_Scheduler(args.lr_scheduler, args.lr,
        #                          args.epochs, len(trn_loader))
        #self.scheduler = torch.optim.lr_scheduler.MultiStepLR(self.optimizer, milestones=[3, 6, 9, 12], gamma=0.5)
        #ft
        self.scheduler = torch.optim.lr_scheduler.MultiStepLR(self.optimizer,
                                                              milestones=[20],
                                                              gamma=0.5)
        self.best_pred = 0
        self.init_weight = 0.98
        # Define Saver
        self.saver = Saver(self.args)
        self.saver.save_experiment_config()
        # Define Tensorboard Summary
        self.summary = TensorboardSummary(self.saver.experiment_dir)
        self.writer = self.summary.create_summary()

    def training(self, epoch, args):
        train_loss = 0.0
        self.student_model.train()
        self.teacher_model.train()
        num_src = len(self.src_loader)
        num_trg = len(self.trg_loader)
        num_itr = max(num_src, num_trg)
        tbar = tqdm(range(1, num_itr + 1))
        #w1 = 0.2 + 0.5 * (self.init_weight - 0.2) * (1 + np.cos(epoch * np.pi / args.epochs))
        print('Learning rate:', self.optimizer.param_groups[0]['lr'])
        iter_src = iter(self.src_loader)
        iter_trg = iter(self.trg_loader)
        for i in tbar:

            # Pull the next batch from each domain, restarting an iterator once
            # it is exhausted (Python 3 uses next(it), not it.next()).
            try:
                src_x1, src_x2, src_y, src_idx = next(iter_src)
            except StopIteration:
                iter_src = iter(self.src_loader)
                src_x1, src_x2, src_y, src_idx = next(iter_src)
            try:
                trg_x1, trg_x2, trg_y, trg_idx = next(iter_trg)
            except StopIteration:
                iter_trg = iter(self.trg_loader)
                trg_x1, trg_x2, trg_y, trg_idx = next(iter_trg)

            if self.args.cuda:
                src_x1, src_x2 = src_x1.cuda(), src_x2.cuda()
                trg_x1, trg_x2 = trg_x1.cuda(), trg_x2.cuda()

            self.optimizer.zero_grad()

            # train with source

            _, _, src_output = self.student_model(src_x1, src_x2)

            src_output = F.softmax(src_output, dim=1)

            # CE loss of supervised data

            #loss_ce=CELossLayer(src_output,src_y)

            # #print('ce loss', loss_ce)
            # Focal loss of supervised data
            loss_focal = FocalLossLayer(src_output, src_y)
            #print('focal loss', loss_focal)

            loss_val_lovasz = LovaszLossLayer(src_output, src_y)
            #print('lovasz loss', loss_lovasz)

            # Supervised loss on the source domain: Lovász + focal.
            loss_su = loss_val_lovasz + loss_focal

            # train with target

            # Perturb the target inputs with independent Gaussian noise for the
            # student and teacher branches; randn_like keeps device and dtype.
            trg_x1_s = trg_x1 + torch.randn_like(trg_x1) * self.args.noise
            trg_x1_t = trg_x1 + torch.randn_like(trg_x1) * self.args.noise

            trg_x2_s = trg_x2 + torch.randn_like(trg_x2) * self.args.noise
            trg_x2_t = trg_x2 + torch.randn_like(trg_x2) * self.args.noise

            _, _, trg_predict_s = self.student_model(trg_x1_s, trg_x2_s)

            _, spatial_mask_prob, trg_predict_t = self.teacher_model(
                trg_x1_t, trg_x2_t)

            trg_predict_s = F.softmax(trg_predict_s, dim=1)
            trg_predict_t = F.softmax(trg_predict_t, dim=1)

            loss_tes_lovasz = LovaszLossLayer(trg_predict_s, trg_y)

            # spatial mask
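            # Pixels where the teacher's spatial attention exceeds the threshold
            # form a binary mask; the consistency loss below is applied only to
            # those (presumably more reliable) pixels.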

            #channel_mask = channel_mask_prob > args.attention_threshold
            spatial_mask = spatial_mask_prob > args.attention_threshold

            spatial_mask = spatial_mask.float()

            #spatial_mask = spatial_mask.permute(0,2,3,1)# N,H,W,C

            #channel_mask = channel_mask.float()
            #spatial_mask = spatial_mask.view(-1)

            num_pixel = (spatial_mask.shape[0]
                         * spatial_mask.shape[-2] * spatial_mask.shape[-1])

            mask_num_rate = torch.sum(spatial_mask).float() / num_pixel

            # trg_output_s = trg_output_s.permute(0, 2, 3, 1)#N,H,W,C
            # trg_output_t = trg_output_t.permute(0, 2, 3, 1)

            #trg_output_s = trg_output_s * channel_mask
            trg_predict_s = trg_predict_s * spatial_mask

            #trg_output_t = trg_output_t * channel_mask
            trg_predict_t = trg_predict_t * spatial_mask

            # trg_output_s = trg_output_s.contiguous().view(-1, self.args.nclass)
            # trg_output_s = trg_output_s[spatial_mask]
            #
            # trg_output_t = trg_output_t.contiguous().view(-1, self.args.nclass)
            # trg_output_t = trg_output_t[spatial_mask]

            # consistency loss

            loss_con = ConsistencyLossLayer(trg_predict_s, trg_predict_t)

            # If the attention mask selects no pixels, drop the consistency term.
            if mask_num_rate == 0.:
                loss_con = torch.zeros_like(loss_con)

            # Total objective: supervised source loss plus the weighted
            # consistency and target-label Lovász terms.
            loss = (loss_su + self.args.con_weight * loss_con
                    + self.args.teslab_weight * loss_tes_lovasz)

            #self.writer.add_scalar('train/ce_loss_iter', loss_ce.item(), i + num_itr * epoch)
            self.writer.add_scalar('train/focal_loss_iter', loss_focal.item(),
                                   i + num_itr * epoch)
            self.writer.add_scalar('train/supervised_loss_iter',
                                   loss_su.item(), i + num_itr * epoch)
            self.writer.add_scalar('train/consistency_loss_iter',
                                   loss_con.item(), i + num_itr * epoch)
            self.writer.add_scalar('train/teslab_loss_iter',
                                   loss_tes_lovasz.item(), i + num_itr * epoch)
            #loss = w1 * loss_ce + (0.5 - 0.5 * w1) * loss_focal + (0.5 - 0.5 * w1) * loss_lovasz

            loss.backward()
            self.optimizer.step()
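            # The teacher is not trained by backprop here; its "optimizer" step
            # is expected to track the student weights (e.g. the EMA update
            # sketched after this example).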
            self.teacher_optimizer.step()

            train_loss += loss.item()
            tbar.set_description('Train loss: %.3f' % (train_loss / i))
            self.writer.add_scalar('train/total_loss_iter', loss.item(),
                                   i + num_itr * epoch)

            # Visualize inference results every 10 iterations
            if i % 10 == 0:
                global_step = i + num_itr * epoch
                if self.args.oly_s1 and not self.args.oly_s2:
                    self.summary.visualize_image(
                        self.writer, self.args.dataset, src_x1[:, [0], :, :],
                        trg_x1[:, [0], :, :], src_y, src_output, trg_predict_s,
                        trg_predict_t, trg_y, global_step)
                elif not self.args.oly_s1:
                    if self.args.rgb:
                        self.summary.visualize_image(self.writer,
                                                     self.args.dataset, src_x2,
                                                     trg_x2, src_y, src_output,
                                                     trg_predict_s,
                                                     trg_predict_t, trg_y,
                                                     global_step)
                    else:
                        self.summary.visualize_image(
                            self.writer, self.args.dataset,
                            src_x2[:, [2, 1, 0], :, :],
                            trg_x2[:, [2, 1, 0], :, :], src_y, src_output,
                            trg_predict_s, trg_predict_t, trg_y, global_step)
                else:
                    raise NotImplementedError

        self.scheduler.step()
        self.writer.add_scalar('train/total_loss_epoch', train_loss, epoch)
        print('[Epoch: %d, numImages: %5d]' %
              (epoch, i * self.args.batch_size + src_y.data.shape[0]))
        print('Loss: %.3f' % train_loss)

        if self.args.no_val:
            # save checkpoint every epoch
            is_best = False
            self.saver.save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'student_state_dict':
                    self.student_model.module.state_dict(),
                    'teacher_state_dict':
                    self.teacher_model.module.state_dict(),
                    'optimizer': self.optimizer.state_dict(),
                }, is_best)

    def validation(self, epoch, args):
        self.teacher_model.eval()
        self.evaluator.reset()
        tbar = tqdm(self.val_loader, desc='\r')
        for i, (x1, x2, y, index) in enumerate(tbar):

            if self.args.cuda:
                x1, x2 = x1.cuda(), x2.cuda()
            with torch.no_grad():
                _, _, output = self.teacher_model(x1, x2)

            output = F.softmax(output, dim=1)
            pred = output.data.cpu().numpy()
            #pred[:,[2,7],:,:]=0
            target = y[:, 0, :, :].cpu().numpy()  # batch_size * 256 * 256
            pred = np.argmax(pred, axis=1)  # batch_size * 256 * 256
            # Add batch sample into evaluator
            self.evaluator.add_batch(target, pred)

        OA = self.evaluator.Pixel_Accuracy()
        AA = self.evaluator.val_Pixel_Accuracy_Class()
        self.writer.add_scalar('val/OA', OA, epoch)
        self.writer.add_scalar('val/AA', AA, epoch)

        print('AVERAGE ACCURACY:', AA)

        print('Validation:')
        print('[Epoch: %d, numImages: %5d]' %
              (epoch, i * self.args.batch_size + y.data.shape[0]))

        new_pred = AA
        if new_pred > self.best_pred:
            is_best = True
            self.best_pred = new_pred
            self.saver.save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'student_state_dict':
                    self.student_model.module.state_dict(),
                    'teacher_state_dict':
                    self.teacher_model.module.state_dict(),
                    'optimizer': self.optimizer.state_dict(),
                    'best_pred': self.best_pred,
                }, is_best)
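
Since the teacher in Example 1 is stepped with self.teacher_optimizer.step()
while no loss is ever backpropagated through it, that optimizer is presumably
a mean-teacher style exponential moving average (EMA) of the student weights.
The sketch below shows what such an updater could look like; the class name
WeightEMA and the alpha value are illustrative assumptions, not part of the
original code.

import torch

class WeightEMA(object):
    """Hypothetical EMA 'optimizer': teacher <- alpha * teacher + (1 - alpha) * student."""

    def __init__(self, teacher_params, student_params, alpha=0.99):
        self.teacher_params = list(teacher_params)
        self.student_params = list(student_params)
        self.alpha = alpha

    def step(self):
        # No gradient is involved; the teacher weights simply track the student.
        with torch.no_grad():
            for t, s in zip(self.teacher_params, self.student_params):
                t.mul_(self.alpha).add_(s, alpha=1.0 - self.alpha)

Constructed once as WeightEMA(teacher_model.parameters(), student_model.parameters()),
an instance can be passed wherever teacher_optimizer is expected above.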
Example 2

import numpy as np
import torch
import torch.nn.functional as F
from tqdm import tqdm
# Evaluator, Saver, TensorboardSummary, and LovaszLossLayer are again assumed
# to be importable from the surrounding codebase.

class operater(object):
    def __init__(self, args, model, trn_loader, val_loader, chk_loader, optimizer):

        self.args = args
        self.model = model
        self.train_loader = trn_loader
        self.val_loader = val_loader
        self.chk_loader = chk_loader
        self.optimizer = optimizer
        # Define Evaluator
        self.evaluator = Evaluator(args.nclass)
        # Define lr scheduler
        # self.scheduler = LR_Scheduler(args.lr_scheduler, args.lr,
        #                          args.epochs, len(trn_loader))
        self.scheduler = torch.optim.lr_scheduler.MultiStepLR(self.optimizer,
                                                              milestones=[3, 6, 9],
                                                              gamma=0.5)
        self.wait_epoches = 10
        self.best_pred = 0
        self.init_weight = 0.98
        # Define Saver
        self.saver = Saver(self.args)
        self.saver.save_experiment_config()
        # Define Tensorboard Summary
        self.summary = TensorboardSummary(self.saver.experiment_dir)
        self.writer = self.summary.create_summary()

    def training(self, epoch, args):
        train_loss = 0.0
        self.model.train()
        tbar = tqdm(self.train_loader)
        num_img_tr = len(self.train_loader)
        #w1 = 0.2 + 0.5 * (self.init_weight - 0.2) * (1 + np.cos(epoch * np.pi / args.epochs))
        print('Learning rate:', self.optimizer.param_groups[0]['lr'])
        for i, (x1, x2, y, index) in enumerate(tbar):
            # y_cls = Seg2cls(args, y)  # image-level labels, N,1,1,C
            if self.args.cuda:
                x1, x2 = x1.cuda(), x2.cuda()
            # self.scheduler(self.optimizer, i, epoch, self.best_pred)
            self.optimizer.zero_grad()
            output = self.model(x1, x2)

            output = F.softmax(output, dim=1)

            # loss_ce=CELossLayer(self.args,output,y)
            # #print('ce loss', loss_ce)
            #
            # loss_focal = FocalLossLayer(self.args,output, y)
            #print('focal loss', loss_focal)

            loss_lovasz = LovaszLossLayer(output, y)
            #print('lovasz loss', loss_lovasz)

            # self.writer.add_scalar('train/ce_loss_iter', loss_focal.item(), i + num_img_tr * epoch)
            # self.writer.add_scalar('train/focal_loss_iter', loss_focal.item(), i + num_img_tr * epoch)
            self.writer.add_scalar('train/lovasz_loss_iter', loss_lovasz.item(), i + num_img_tr * epoch)

            #loss = w1 * loss_ce + (0.5 - 0.5 * w1) * loss_focal + (0.5 - 0.5 * w1) * loss_lovasz

            loss = loss_lovasz

            loss.backward()
            self.optimizer.step()
            train_loss += loss.item()
            tbar.set_description('Train loss: %.3f' % (train_loss / (i + 1)))
            self.writer.add_scalar('train/total_loss_iter', loss.item(), i + num_img_tr * epoch)

            # Visualize inference results every 10 iterations
            if i % 10 == 0:
                global_step = i + num_img_tr * epoch
                if self.args.oly_s1 and not self.args.oly_s2:
                    self.summary.visualize_image(self.writer, self.args.dataset,
                                                 x1[:, [0], :, :], y, output, global_step)
                elif not self.args.oly_s1:
                    if self.args.rgb:
                        self.summary.visualize_image(self.writer, self.args.dataset,
                                                     x2, y, output, global_step)
                    else:
                        self.summary.visualize_image(self.writer, self.args.dataset,
                                                     x2[:, [2, 1, 0], :, :], y, output, global_step)
                else:
                    raise NotImplementedError
        self.scheduler.step()
        self.writer.add_scalar('train/total_loss_epoch', train_loss, epoch)
        print('[Epoch: %d, numImages: %5d]' % (epoch, i * self.args.batch_size + y.data.shape[0]))
        print('Loss: %.3f' % train_loss)

        if self.args.no_val:
            # save checkpoint every epoch
            is_best = False
            self.saver.save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': self.model.module.state_dict(),
                'optimizer': self.optimizer.state_dict(),
            }, is_best)

    def validation(self, epoch, args):
        self.model.eval()
        self.evaluator.reset()
        tbar = tqdm(self.val_loader, desc='\r')
        for i, (x1, x2, y, index) in enumerate(tbar):

            if self.args.cuda:
                x1, x2 = x1.cuda(), x2.cuda()
            with torch.no_grad():
                output = self.model(x1, x2)

            pred = output.data.cpu().numpy()
            # Suppress classes 2 and 7 before taking the argmax.
            pred[:, [2, 7], :, :] = 0
            target = y[:, 0, :, :].cpu().numpy()  # batch_size * 256 * 256
            pred = np.argmax(pred, axis=1)  # batch_size * 256 * 256
            # Add batch sample into evaluator
            self.evaluator.add_batch(target, pred)

        OA = self.evaluator.Pixel_Accuracy()
        AA = self.evaluator.val_Pixel_Accuracy_Class()
        self.writer.add_scalar('val/OA', OA, epoch)
        self.writer.add_scalar('val/AA', AA, epoch)

        print('AVERAGE ACCURACY:', AA)

        print('Validation:')
        print('[Epoch: %d, numImages: %5d]' % (epoch, i * self.args.batch_size + y.data.shape[0]))

        new_pred = AA
        if new_pred > self.best_pred:
            is_best = True
            self.best_pred = new_pred
            self.saver.save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': self.model.module.state_dict(),
                'optimizer': self.optimizer.state_dict(),
                'best_pred': self.best_pred,
            }, is_best)
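
A minimal sketch of a driver loop for either class follows. args.epochs and
args.no_val appear in the code above; everything else (building the models,
loaders, and optimizers) is assumed to happen in the surrounding project.

def main(args, trainer):
    # trainer is an instance of the operater class defined above
    for epoch in range(args.epochs):
        trainer.training(epoch, args)
        # Validate (and checkpoint on a new best AA) unless validation is disabled
        if not args.no_val:
            trainer.validation(epoch, args)
    trainer.writer.close()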