Ejemplo n.º 1
0
    def __init__(self, args):
        """Create the networks, optimizers and loss modules used for distillation.

        Three networks are built, have their weights loaded through the
        ``load_*_model`` helpers, are reported via ``print_model_parm_nums``
        and are then moved to CUDA: a student ``Res_pspnet`` (``BasicBlock``,
        layers ``[2, 2, 2, 2]``), a teacher ``Res_pspnet`` (``Bottleneck``,
        layers ``[3, 4, 23, 3]``) and a ``Discriminator``.

        Args:
            args: Configuration namespace. Attributes read here:
                ``num_classes``, ``preprocess_GAN_mode``, ``batch_size``,
                ``imsize_for_adv``, ``adv_conv_dim``, ``lr_g``, ``lr_d``,
                ``momentum``, ``weight_decay``, ``kd``, ``adv``,
                ``adv_loss_type``, ``lambda_gp`` and ``ifv``.

        Side effects:
            Sets ``cudnn.deterministic = True`` and ``cudnn.benchmark = False``.
        """
        self.args = args
        n_classes = args.num_classes

        # --- student network ---
        s_net = Res_pspnet(BasicBlock, [2, 2, 2, 2], num_classes=n_classes)
        load_S_model(args, s_net)
        print_model_parm_nums(s_net, 'student_model')
        s_net.cuda()
        self.student = s_net

        # --- teacher network ---
        t_net = Res_pspnet(Bottleneck, [3, 4, 23, 3], num_classes=n_classes)
        load_T_model(args, t_net)
        print_model_parm_nums(t_net, 'teacher_model')
        t_net.cuda()
        self.teacher = t_net

        # --- discriminator ---
        disc = Discriminator(
            args.preprocess_GAN_mode,
            n_classes,
            args.batch_size,
            args.imsize_for_adv,
            args.adv_conv_dim,
        )
        load_D_model(args, disc)
        print_model_parm_nums(disc, 'D_model')
        logging.info("------------")
        disc.cuda()
        self.D_model = disc

        # --- optimizers: only parameters with requires_grad=True are optimized ---
        student_group = {
            'params': (p for p in s_net.parameters() if p.requires_grad),
            'initial_lr': args.lr_g,
        }
        self.G_solver = optim.SGD(
            [student_group],
            args.lr_g,
            momentum=args.momentum,
            weight_decay=args.weight_decay,
        )
        disc_trainable = (p for p in disc.parameters() if p.requires_grad)
        self.D_solver = optim.Adam(disc_trainable, args.lr_d,
                                   betas=[0.9, 0.99])
        # Disabled alternative: optim.SGD for D_solver with
        # initial_lr=args.lr_d, momentum=args.momentum and
        # weight_decay=args.weight_decay.

        # --- loss modules; the optional ones exist only if their flag is set ---
        self.criterion_dsn = CriterionDSN().cuda()
        if args.kd:
            self.criterion_kd = CriterionKD().cuda()
        if args.adv:
            adv_kind = args.adv_loss_type
            self.criterion_adv = CriterionAdv(adv_kind).cuda()
            if adv_kind == 'wgan-gp':
                # Extra criterion built from the discriminator and lambda_gp,
                # created only for the 'wgan-gp' loss type.
                self.criterion_AdditionalGP = CriterionAdditionalGP(
                    disc, args.lambda_gp).cuda()
            self.criterion_adv_for_G = CriterionAdvForG(adv_kind).cuda()
        if args.ifv:
            self.criterion_ifv = CriterionIFV(classes=n_classes).cuda()

        # --- loss attributes, initialised to zero ---
        self.G_loss = 0.0
        self.D_loss = 0.0
        self.mc_G_loss = 0.0
        self.kd_G_loss = 0.0
        self.adv_G_loss = 0.0
        self.ifv_G_loss = 0.0

        cudnn.deterministic = True
        cudnn.benchmark = False
Ejemplo n.º 2
0
    def __init__(self, args):
        """Set up networks, optimizers and loss modules, then the snapshot directory.

        A student ``Res_pspnet`` (``BasicBlock``, layers ``[2, 2, 2, 2]``), a
        teacher ``Res_pspnet`` (``Bottleneck``, layers ``[3, 4, 23, 3]``) and a
        ``Discriminator`` are built, loaded, and each passed through
        ``self.DataParallelModelProcess``. Every criterion is passed through
        ``self.DataParallelCriterionProcess``.

        Args:
            args: Configuration namespace. Attributes read here: ``device``,
                ``classes_num``, ``T_ckpt_path``, ``preprocess_GAN_mode``,
                ``batch_size``, ``imsize_for_adv``, ``adv_conv_dim``,
                ``lr_g``, ``lr_d``, ``momentum``, ``weight_decay``,
                ``best_mean_IU``, ``pool_scale``, ``adv_loss_type``,
                ``lambda_gp`` (only when ``adv_loss_type == 'wgan-gp'``) and
                ``snapshot_dir``.

        Side effects:
            Sets ``cudnn.enabled`` and ``cudnn.benchmark`` to ``True`` and
            creates ``args.snapshot_dir`` if it does not exist yet.
        """
        cudnn.enabled = True
        self.args = args
        device = args.device

        # Student: the bare module is kept in self.student and the value
        # returned by DataParallelModelProcess in self.parallel_student.
        student = Res_pspnet(BasicBlock, [2, 2, 2, 2],
                             num_classes=args.classes_num)
        load_S_model(args, student, False)
        print_model_parm_nums(student, 'student_model')
        # NOTE(review): the meaning of the positional `2` is defined by
        # DataParallelModelProcess, which is not visible here -- confirm.
        self.parallel_student = self.DataParallelModelProcess(
            student, 2, 'train', device)
        self.student = student

        # Teacher: same pattern as the student, but processed in 'eval' mode.
        teacher = Res_pspnet(Bottleneck, [3, 4, 23, 3],
                             num_classes=args.classes_num)
        load_T_model(teacher, args.T_ckpt_path)
        print_model_parm_nums(teacher, 'teacher_model')
        self.parallel_teacher = self.DataParallelModelProcess(
            teacher, 2, 'eval', device)
        self.teacher = teacher

        # Discriminator: only the processed version is stored on self.
        D_model = Discriminator(args.preprocess_GAN_mode, args.classes_num,
                                args.batch_size, args.imsize_for_adv,
                                args.adv_conv_dim)
        load_D_model(args, D_model, False)
        print_model_parm_nums(D_model, 'D_model')
        self.parallel_D = self.DataParallelModelProcess(
            D_model, 2, 'train', device)

        # Both optimizers are SGD over the parameters that require gradients;
        # each parameter group also carries an 'initial_lr' entry.
        self.G_solver = optim.SGD(
            [{
                'params': filter(lambda p: p.requires_grad,
                                 self.student.parameters()),
                'initial_lr': args.lr_g
            }],
            args.lr_g,
            momentum=args.momentum,
            weight_decay=args.weight_decay)
        self.D_solver = optim.SGD(
            [{
                'params': filter(lambda p: p.requires_grad,
                                 D_model.parameters()),
                'initial_lr': args.lr_d
            }],
            args.lr_d,
            momentum=args.momentum,
            weight_decay=args.weight_decay)

        self.best_mean_IU = args.best_mean_IU

        # Loss modules.
        self.criterion = self.DataParallelCriterionProcess(CriterionDSN())
        self.criterion_pixel_wise = self.DataParallelCriterionProcess(
            CriterionPixelWise())
        self.criterion_pair_wise_for_interfeat = self.DataParallelCriterionProcess(
            CriterionPairWiseforWholeFeatAfterPool(scale=args.pool_scale,
                                                   feat_ind=-5))
        self.criterion_adv = self.DataParallelCriterionProcess(
            CriterionAdv(args.adv_loss_type))
        if args.adv_loss_type == 'wgan-gp':
            # Extra criterion built from the processed discriminator and
            # lambda_gp; it exists only for the 'wgan-gp' loss type.
            self.criterion_AdditionalGP = self.DataParallelCriterionProcess(
                CriterionAdditionalGP(self.parallel_D, args.lambda_gp))
        self.criterion_adv_for_G = self.DataParallelCriterionProcess(
            CriterionAdvForG(args.adv_loss_type))

        # Loss attributes, initialised to zero.
        self.mc_G_loss = 0.0
        self.pi_G_loss = 0.0
        self.pa_G_loss = 0.0
        self.D_loss = 0.0

        cudnn.benchmark = True
        # exist_ok=True avoids the check-then-create race of
        # `if not os.path.exists(...): os.makedirs(...)` when several
        # processes start at the same time.
        os.makedirs(args.snapshot_dir, exist_ok=True)
Ejemplo n.º 3
0
    def __init__(self, args):
        """Set up the mmseg student/teacher, discriminator, optimizers and losses.

        The student is built with ``build_segmentor`` from a hard-coded config
        file and placed on ``cuda:0`` in train mode; no student checkpoint is
        loaded in this constructor. The teacher is only built when at least
        one of ``args.pi``, ``args.pa`` or ``args.ho`` is truthy; otherwise
        ``self.teacher`` and ``self.T_device`` are never assigned here.

        Args:
            args: Configuration namespace. Attributes read here: ``pi``,
                ``pa``, ``ho``, ``preprocess_GAN_mode``, ``classes_num``,
                ``batch_size``, ``imsize_for_adv``, ``adv_conv_dim``,
                ``lr_g``, ``lr_d``, ``momentum``, ``weight_decay``,
                ``best_mean_IU``, ``pool_scale``, ``adv_loss_type``,
                ``lambda_gp`` (only when ``adv_loss_type == 'wgan-gp'``) and
                ``snapshot_dir``.

        Side effects:
            Sets ``cudnn.enabled`` and ``cudnn.benchmark`` to ``True``,
            creates ``args.snapshot_dir`` if it does not exist yet and prints
            ``'init finish'``.
        """
        cudnn.enabled = True
        self.args = args
        self.mode_name = 'kd_mmseg'

        # ---- student (mmseg) ----
        self.S_device = 'cuda:0'
        S_config = 'configs/pspnet/pspnet_r18-d8_512x512_40k_cityscapes_1gpu.py'
        S_cfg = Config.fromfile(S_config)
        self.student = build_segmentor(S_cfg.model, train_cfg=S_cfg.train_cfg, test_cfg=S_cfg.test_cfg)
        self.student.train()
        self.student.to(self.S_device)

        # ---- teacher (mmseg), only needed for the pi / pa / ho terms ----
        if self.args.pi or self.args.pa or self.args.ho:
            self.T_device = 'cuda:1'
            T_config = 'configs/pspnet/pspnet_r101-d8_512x512_80k_cityscapes_1gpu.py'
            T_cfg = Config.fromfile(T_config)
            self.teacher = build_segmentor(T_cfg.model, train_cfg=None, test_cfg=None)
            teacher_ckpt = 'work_dirs/models_zoo/pspnet_r101-d8_512x512_80k_cityscapes.2.2_iter_80000.pth'
            # Weights are loaded into the model in place; the return value of
            # load_checkpoint is not needed here.
            load_checkpoint(self.teacher, teacher_ckpt)
            # NOTE(review): the meaning of the positional `2` is defined by
            # DataParallelModelProcess, which is not visible here -- confirm.
            self.teacher = self.DataParallelModelProcess(self.teacher, 2, 'eval', device=self.T_device)

        # ---- discriminator ----
        D_model = Discriminator(args.preprocess_GAN_mode, args.classes_num, args.batch_size, args.imsize_for_adv,
                                args.adv_conv_dim)
        load_D_model(args, D_model, False)
        print_model_parm_nums(D_model, 'D_model')
        self.parallel_D = self.DataParallelModelProcess(D_model, 2, 'train', device='cuda:0')
        self.D_model = D_model

        # ---- optimizers: SGD over the parameters that require gradients ----
        self.G_solver = optim.SGD(
            [{'params': filter(lambda p: p.requires_grad, self.student.parameters()), 'initial_lr': args.lr_g}],
            lr=args.lr_g, momentum=args.momentum, weight_decay=args.weight_decay)
        self.D_solver = optim.SGD(
            [{'params': filter(lambda p: p.requires_grad, D_model.parameters()), 'initial_lr': args.lr_d}],
            lr=args.lr_d, momentum=args.momentum, weight_decay=args.weight_decay)

        self.best_mean_IU = args.best_mean_IU

        # ---- loss modules ----
        self.criterion = self.DataParallelCriterionProcess(CriterionDSN())
        self.criterion_ce = self.DataParallelCriterionProcess(CriterionCE())
        self.criterion_pixel_wise = self.DataParallelCriterionProcess(CriterionPixelWise())
        self.criterion_pair_wise_for_interfeat = self.DataParallelCriterionProcess(
            CriterionPairWiseforWholeFeatAfterPool(scale=args.pool_scale, feat_ind=-5))
        self.criterion_adv = self.DataParallelCriterionProcess(CriterionAdv(args.adv_loss_type))
        if args.adv_loss_type == 'wgan-gp':
            # Extra criterion built from the processed discriminator and
            # lambda_gp; it exists only for the 'wgan-gp' loss type.
            self.criterion_AdditionalGP = self.DataParallelCriterionProcess(
                CriterionAdditionalGP(self.parallel_D, args.lambda_gp))
        self.criterion_adv_for_G = self.DataParallelCriterionProcess(CriterionAdvForG(args.adv_loss_type))

        # ---- loss attributes, initialised to zero ----
        self.mc_G_loss = 0.0
        self.pi_G_loss = 0.0
        self.pa_G_loss = 0.0
        self.D_loss = 0.0

        self.criterion_AT = self.DataParallelCriterionProcess(AT(p=2))

        cudnn.benchmark = True
        # exist_ok=True avoids the check-then-create race of
        # `if not os.path.exists(...): os.makedirs(...)` when several
        # processes start at the same time.
        os.makedirs(args.snapshot_dir, exist_ok=True)

        print('init finish')
Ejemplo n.º 4
0
    }],
    lr=cfg.optimizer.lr,
    momentum=cfg.optimizer.momentum,
    weight_decay=cfg.optimizer.weight_decay)

# Teacher: build the segmentor described by the config file named in
# cfg.model.teacher_config, load its checkpoint, move it to GPU 0 and switch
# it to eval mode. train_cfg and test_cfg are passed as None.
T_cfg = Config.fromfile(cfg.model.teacher_config.config_file)
teacher = build_segmentor(T_cfg.model, train_cfg=None, test_cfg=None)
# NOTE(review): the value bound to `checkpoint` is not used in the visible
# part of this script -- confirm whether later code needs it.
checkpoint = load_checkpoint(teacher, cfg.model.teacher_config.checkpoint)
teacher.cuda(0)
teacher.eval()

# GAN: discriminator built with hard-coded hyper-parameters, placed on GPU 0
# in train mode.
# NOTE(review): input_channel=19 presumably equals the number of segmentation
# classes, and batch_size=8 / image_size=64 presumably have to match the data
# pipeline -- confirm against the config.
D_model = Discriminator(preprocess_GAN_mode=1,
                        input_channel=19,
                        batch_size=8,
                        image_size=64,
                        conv_dim=64)
# Leftover fragments of an args-driven call (presumably what the literals
# above replaced):
# imsize_for_adv=,
# args.adv_conv_dim)
D_model.cuda(0)
D_model.train()

# Discriminator optimizer: SGD over the parameters that require gradients.
# The learning rate 4e-4 is hard-coded (as both 'initial_lr' and lr), while
# momentum and weight decay are taken from cfg.optimizer.
D_solver = optim.SGD(
    [{
        'params': filter(lambda p: p.requires_grad, D_model.parameters()),
        'initial_lr': 4e-4
    }],
    lr=4e-4,
    momentum=cfg.optimizer.momentum,
    weight_decay=cfg.optimizer.weight_decay)