Exemplo n.º 1
0
    def __init__(self, args):
        """Set up models, optimizers and loss criteria for distillation training.

        Builds a student ``Res_pspnet`` (``BasicBlock``), a teacher
        ``Res_pspnet`` (``Bottleneck``) and a ``Discriminator``, creates one
        SGD optimizer for the student and one for the discriminator, and
        wraps every loss criterion with ``self.DataParallelCriterionProcess``.

        Args:
            args: Options object. Attributes read here: ``device``,
                ``classes_num``, ``T_ckpt_path``, ``preprocess_GAN_mode``,
                ``batch_size``, ``imsize_for_adv``, ``adv_conv_dim``,
                ``lr_g``, ``lr_d``, ``momentum``, ``weight_decay``,
                ``best_mean_IU``, ``pool_scale``, ``adv_loss_type``,
                ``snapshot_dir`` and, only when ``adv_loss_type`` is
                ``'wgan-gp'``, ``lambda_gp``.

        Side effects:
            Sets ``cudnn.enabled`` and ``cudnn.benchmark`` to ``True`` and
            creates ``args.snapshot_dir`` when it is missing.
        """
        cudnn.enabled = True
        self.args = args
        device = args.device

        # --- Student -------------------------------------------------------
        # NOTE(review): load_S_model presumably restores student weights
        # according to ``args`` -- confirm against its definition.
        student = Res_pspnet(BasicBlock, [2, 2, 2, 2],
                             num_classes=args.classes_num)
        load_S_model(args, student, False)
        print_model_parm_nums(student, 'student_model')
        # Both the wrapped result and the raw module are kept: the optimizer
        # below is built from ``self.student.parameters()``.
        self.parallel_student = self.DataParallelModelProcess(
            student, 2, 'train', device)
        self.student = student

        # --- Teacher -------------------------------------------------------
        teacher = Res_pspnet(Bottleneck, [3, 4, 23, 3],
                             num_classes=args.classes_num)
        load_T_model(teacher, args.T_ckpt_path)
        print_model_parm_nums(teacher, 'teacher_model')
        self.parallel_teacher = self.DataParallelModelProcess(
            teacher, 2, 'eval', device)
        self.teacher = teacher

        # --- Discriminator -------------------------------------------------
        D_model = Discriminator(args.preprocess_GAN_mode, args.classes_num,
                                args.batch_size, args.imsize_for_adv,
                                args.adv_conv_dim)
        load_D_model(args, D_model, False)
        print_model_parm_nums(D_model, 'D_model')
        self.parallel_D = self.DataParallelModelProcess(
            D_model, 2, 'train', device)

        # --- Optimizers ----------------------------------------------------
        # Only parameters with requires_grad=True are optimized. 'initial_lr'
        # is stored in each param group next to the optimizer-level lr.
        self.G_solver = optim.SGD(
            [{
                'params': filter(lambda p: p.requires_grad,
                                 self.student.parameters()),
                'initial_lr': args.lr_g
            }],
            lr=args.lr_g,
            momentum=args.momentum,
            weight_decay=args.weight_decay)
        self.D_solver = optim.SGD(
            [{
                'params': filter(lambda p: p.requires_grad,
                                 D_model.parameters()),
                'initial_lr': args.lr_d
            }],
            lr=args.lr_d,
            momentum=args.momentum,
            weight_decay=args.weight_decay)

        self.best_mean_IU = args.best_mean_IU

        # --- Loss criteria -------------------------------------------------
        # CriterionDSN is used here; CriterionCrossEntropy() was the
        # alternative noted in the original comment.
        self.criterion = self.DataParallelCriterionProcess(CriterionDSN())
        self.criterion_pixel_wise = self.DataParallelCriterionProcess(
            CriterionPixelWise())
        # A single pair-wise criterion on feat_ind=-5 is used; an earlier
        # variant built one criterion per entry of args.pool_scale.
        self.criterion_pair_wise_for_interfeat = self.DataParallelCriterionProcess(
            CriterionPairWiseforWholeFeatAfterPool(scale=args.pool_scale,
                                                   feat_ind=-5))
        self.criterion_adv = self.DataParallelCriterionProcess(
            CriterionAdv(args.adv_loss_type))
        # The gradient-penalty criterion only exists for 'wgan-gp'; for other
        # loss types the attribute is left undefined.
        if args.adv_loss_type == 'wgan-gp':
            self.criterion_AdditionalGP = self.DataParallelCriterionProcess(
                CriterionAdditionalGP(self.parallel_D, args.lambda_gp))
        self.criterion_adv_for_G = self.DataParallelCriterionProcess(
            CriterionAdvForG(args.adv_loss_type))

        # Loss values, initialised to zero.
        self.mc_G_loss = 0.0
        self.pi_G_loss = 0.0
        self.pa_G_loss = 0.0
        self.D_loss = 0.0

        cudnn.benchmark = True
        # exist_ok avoids the check-then-create race when several processes
        # start at the same time and share one snapshot directory.
        os.makedirs(args.snapshot_dir, exist_ok=True)
Exemplo n.º 2
0
    def __init__(self, args):
        """Set up the student, optional teacher, discriminator, optimizers and losses.

        The student is built from an mmseg config file with
        ``build_segmentor``; the teacher is a ``Res_pspnet`` (``Bottleneck``)
        that is only created when one of ``args.pi``, ``args.pa`` or
        ``args.ho`` is truthy.

        Args:
            args: Options object. Attributes read here: ``pi``, ``pa``,
                ``ho``, ``classes_num``, ``T_ckpt_path``,
                ``preprocess_GAN_mode``, ``batch_size``, ``imsize_for_adv``,
                ``adv_conv_dim``, ``lr_g``, ``lr_d``, ``momentum``,
                ``weight_decay``, ``best_mean_IU``, ``pool_scale``,
                ``adv_loss_type``, ``snapshot_dir`` and, only when
                ``adv_loss_type`` is ``'wgan-gp'``, ``lambda_gp``.

        Side effects:
            Sets ``cudnn.enabled`` and ``cudnn.benchmark`` to ``True``,
            creates ``args.snapshot_dir`` when it is missing, and prints
            ``'init finish'`` when done.
        """
        cudnn.enabled = True
        self.args = args
        # Device strings handed to DataParallelModelProcess for the student
        # and the teacher respectively.
        self.S_device = 'cuda:0'
        self.T_device = 'cuda:1'

        # --- Teacher (optional) --------------------------------------------
        # When none of pi / pa / ho is set, ``self.teacher`` is never
        # assigned.
        if self.args.pi or self.args.pa or self.args.ho:
            teacher = Res_pspnet(Bottleneck, [3, 4, 23, 3],
                                 num_classes=args.classes_num)
            load_T_model(teacher, args.T_ckpt_path)
            print_model_parm_nums(teacher, 'teacher_model')
            # The value returned here used to be stored in self.teacher and
            # then overwritten by the raw model on the next line, so only the
            # call is kept. NOTE(review): presumably it configures ``teacher``
            # in place (device / eval mode) -- confirm.
            self.DataParallelModelProcess(teacher,
                                          2,
                                          'eval',
                                          device=self.T_device)
            self.teacher = teacher

        # --- Student (mmseg) -----------------------------------------------
        # An earlier variant built the student as
        # Res_pspnet(BasicBlock, [2, 2, 2, 2]) and loaded it with
        # load_S_model; it is now built from the mmseg config below and no
        # checkpoint is loaded here.
        S_config = 'configs/pspnet/pspnet_r18-d8_512x512_40k_cityscapes_1gpu.py'
        S_cfg = Config.fromfile(S_config)
        self.student = build_segmentor(S_cfg.model,
                                       train_cfg=S_cfg.train_cfg,
                                       test_cfg=S_cfg.test_cfg)
        # Unlike the teacher, the student attribute holds whatever
        # DataParallelModelProcess returns.
        self.student = self.DataParallelModelProcess(self.student,
                                                     2,
                                                     'train',
                                                     device=self.S_device)

        # --- Discriminator -------------------------------------------------
        D_model = Discriminator(args.preprocess_GAN_mode, args.classes_num,
                                args.batch_size, args.imsize_for_adv,
                                args.adv_conv_dim)
        load_D_model(args, D_model, False)
        print_model_parm_nums(D_model, 'D_model')
        self.parallel_D = self.DataParallelModelProcess(D_model,
                                                        2,
                                                        'train',
                                                        device='cuda:0')
        self.D_model = D_model

        # --- Optimizers ----------------------------------------------------
        # Only parameters with requires_grad=True are optimized. 'initial_lr'
        # is stored in each param group next to the optimizer-level lr.
        self.G_solver = optim.SGD(
            [{
                'params': filter(lambda p: p.requires_grad,
                                 self.student.parameters()),
                'initial_lr': args.lr_g
            }],
            lr=args.lr_g,
            momentum=args.momentum,
            weight_decay=args.weight_decay)
        self.D_solver = optim.SGD(
            [{
                'params': filter(lambda p: p.requires_grad,
                                 D_model.parameters()),
                'initial_lr': args.lr_d
            }],
            lr=args.lr_d,
            momentum=args.momentum,
            weight_decay=args.weight_decay)

        self.best_mean_IU = args.best_mean_IU

        # --- Loss criteria -------------------------------------------------
        self.criterion = self.DataParallelCriterionProcess(CriterionDSN())
        self.criterion_ce = self.DataParallelCriterionProcess(CriterionCE())
        self.criterion_pixel_wise = self.DataParallelCriterionProcess(
            CriterionPixelWise())
        # A single pair-wise criterion on feat_ind=-5 is used; an earlier
        # variant built one criterion per entry of args.pool_scale.
        self.criterion_pair_wise_for_interfeat = self.DataParallelCriterionProcess(
            CriterionPairWiseforWholeFeatAfterPool(scale=args.pool_scale,
                                                   feat_ind=-5))
        self.criterion_adv = self.DataParallelCriterionProcess(
            CriterionAdv(args.adv_loss_type))
        # The gradient-penalty criterion only exists for 'wgan-gp'; for other
        # loss types the attribute is left undefined.
        if args.adv_loss_type == 'wgan-gp':
            self.criterion_AdditionalGP = self.DataParallelCriterionProcess(
                CriterionAdditionalGP(self.parallel_D, args.lambda_gp))
        self.criterion_adv_for_G = self.DataParallelCriterionProcess(
            CriterionAdvForG(args.adv_loss_type))

        # Loss values, initialised to zero.
        self.mc_G_loss = 0.0
        self.pi_G_loss = 0.0
        self.pa_G_loss = 0.0
        self.D_loss = 0.0

        self.criterion_AT = self.DataParallelCriterionProcess(AT(p=2))

        cudnn.benchmark = True
        # exist_ok avoids the check-then-create race when several processes
        # start at the same time and share one snapshot directory.
        os.makedirs(args.snapshot_dir, exist_ok=True)

        print('init finish')
Exemplo n.º 3
0
import numpy as np
import torch
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
from mmcv.runner import IterBasedRunner, build_optimizer

from mmseg.core import DistEvalHook, EvalHook
from mmseg.datasets import build_dataloader, build_dataset
from mmseg.utils import get_root_logger

from mmcv.runner import load_checkpoint
from mmseg.apis import multi_gpu_test, single_gpu_test

from utils.utils import *
from networks.pspnet_combine import Res_pspnet, BasicBlock, Bottleneck

# Module-level Res_pspnet instance, constructed at import time with
# Bottleneck blocks, layer counts [2, 2, 2, 2] and 19 output classes.
# NOTE(review): the __main__ section of this file pairs [2, 2, 2, 2] with
# BasicBlock for its student model -- confirm Bottleneck is intended here.
student = Res_pspnet(Bottleneck, [2, 2, 2, 2], num_classes=19)


class IterLoader:
    def __init__(self, dataloader):
        self._dataloader = dataloader
        self.iter_loader = iter(self._dataloader)
        self._epoch = 0

    @property
    def epoch(self):
        return self._epoch

    def __next__(self):
        try:
            data = next(self.iter_loader)
                                dtype=np.int)
            ignore_index = seg_gt != 255
            seg_gt = seg_gt[ignore_index]
            seg_pred = seg_pred[ignore_index]
            confusion_matrix += get_confusion_matrix(seg_gt, seg_pred,
                                                     num_classes)

    if type == 'val':
        pos = confusion_matrix.sum(1)
        res = confusion_matrix.sum(0)
        tp = np.diag(confusion_matrix)
        IU_array = (tp / np.maximum(1.0, pos + res - tp))
        mean_IU = IU_array.mean()
        return mean_IU, IU_array


if __name__ == '__main__':
    # Script entry point: load a student checkpoint and run evaluate_main on
    # the file list './dataset/list/cityscapes/test.lst'.

    # Checkpoint path. The assignment previously ended with a trailing comma,
    # which turned it into a 1-tuple and made torch.load() fail.
    restore_from = r'/home/users/changyong.shu/new/jobs/kd-seg/lyf/00-0-paper-reproduct/src-local1/best_pth_for_test_src_kfj_compute_39/snapshots/CS_scenes_39326_0.75.pth'

    # NOTE(review): `data`, `CSDataTestSet`, `data_dir` and `args` are not
    # defined in the visible part of this file -- confirm they are provided
    # elsewhere (e.g. by the star import) before running.
    # DataLoader has no `type` parameter (assuming `data` is
    # torch.utils.data), so that keyword now goes to evaluate_main instead.
    testloader = data.DataLoader(
        CSDataTestSet(data_dir,
                      './dataset/list/cityscapes/test.lst',
                      crop_size=(1024, 2048)),
        batch_size=1,
        shuffle=False,
        pin_memory=True)

    student = Res_pspnet(BasicBlock, [2, 2, 2, 2],
                         num_classes=args.classes_num)
    # Load the weights into the model that is actually evaluated; the
    # original code called load_state_dict on an undefined name `model`.
    student.load_state_dict(torch.load(restore_from))
    # NOTE(review): evaluate_main's signature is not visible here; the
    # positional arguments are kept as they were and only `type` is added --
    # confirm the keyword name and that 9 matches args.classes_num.
    evaluate_main(student, testloader, '0', '512,512', 9, True, type='test')
Exemplo n.º 5
0
from torch.utils import data
from networks.pspnet_combine import Res_pspnet, BasicBlock, Bottleneck
from networks.evaluate import evaluate_main
from dataset.datasets import CSDataTestSet
from utils.train_options import TrainOptionsForTest
import torch

if __name__ == '__main__':
    # Script entry point: evaluate a saved student checkpoint on the file
    # list './dataset/list/cityscapes/test.lst'.
    args = TrainOptionsForTest().initialize()

    # Dataset and loader: one sample per batch, in list order.
    test_set = CSDataTestSet(args.data_dir,
                             './dataset/list/cityscapes/test.lst',
                             crop_size=(1024, 2048))
    testloader = data.DataLoader(test_set,
                                 batch_size=1,
                                 shuffle=False,
                                 pin_memory=True)

    # Res_pspnet student (BasicBlock, 19 classes) restored from
    # args.resume_from.
    student = Res_pspnet(BasicBlock, [2, 2, 2, 2], num_classes=19)
    state_dict = torch.load(args.resume_from)
    student.load_state_dict(state_dict)

    evaluate_main(student, testloader, '512,512', 19, True, type='test')