Example #1
    def __init__(self, args):
        self.args = args


        # Define Dataloader
        kwargs = {'num_workers': args.workers, 'pin_memory': True}
        self.train_loader, self.val_loader, self.test_loader, self.nclass = make_data_loader(args, **kwargs)

        # Define network
        if args.sync_bn:
            BN = SynchronizedBatchNorm2d
        else:
            BN = nn.BatchNorm2d
        ### deeplabV3 start ###
        self.backbone_model = MobileNetV2(output_stride=args.out_stride,
                                          BatchNorm=BN)
        self.assp_model = ASPP(backbone=args.backbone,
                               output_stride=args.out_stride,
                               BatchNorm=BN)
        self.y_model = Decoder(num_classes=self.nclass,
                               backbone=args.backbone,
                               BatchNorm=BN)
        ### deeplabV3 end ###
        self.d_model = DomainClassifer(backbone=args.backbone,
                                       BatchNorm=BN)
        f_params = list(self.backbone_model.parameters()) + list(self.assp_model.parameters())
        y_params = list(self.y_model.parameters())
        d_params = list(self.d_model.parameters())


        # Using cuda
        if args.cuda:
            self.backbone_model = torch.nn.DataParallel(self.backbone_model, device_ids=self.args.gpu_ids)
            self.assp_model = torch.nn.DataParallel(self.assp_model, device_ids=self.args.gpu_ids)
            self.y_model = torch.nn.DataParallel(self.y_model, device_ids=self.args.gpu_ids)
            self.d_model = torch.nn.DataParallel(self.d_model, device_ids=self.args.gpu_ids)
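            # patch_replication_callback ships with the SyncBN implementation used here;
            # it patches DataParallel replication so SynchronizedBatchNorm2d syncs
            # statistics across replicas.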
            patch_replication_callback(self.backbone_model)
            patch_replication_callback(self.assp_model)
            patch_replication_callback(self.y_model)
            patch_replication_callback(self.d_model)
            self.backbone_model = self.backbone_model.cuda()
            self.assp_model = self.assp_model.cuda()
            self.y_model = self.y_model.cuda()
            self.d_model = self.d_model.cuda()

        # Resuming checkpoint
        self.best_pred = 0.0
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'".format(args.resume))
            checkpoint = torch.load(args.resume, map_location=None if args.cuda else 'cpu')
            args.start_epoch = checkpoint['epoch']
            if args.cuda:
                self.backbone_model.module.load_state_dict(checkpoint['backbone_model_state_dict'])
                self.assp_model.module.load_state_dict(checkpoint['assp_model_state_dict'])
                self.y_model.module.load_state_dict(checkpoint['y_model_state_dict'])
                self.d_model.module.load_state_dict(checkpoint['d_model_state_dict'])
            else:
                self.backbone_model.load_state_dict(checkpoint['backbone_model_state_dict'])
                self.assp_model.load_state_dict(checkpoint['assp_model_state_dict'])
                self.y_model.load_state_dict(checkpoint['y_model_state_dict'])
                self.d_model.load_state_dict(checkpoint['d_model_state_dict'])
            # if not args.ft:
            #     self.task_optimizer.load_state_dict(checkpoint['task_optimizer'])
            #     self.d_optimizer.load_state_dict(checkpoint['d_optimizer'])
            #     self.d_inv_optimizer.load_state_dict(checkpoint['d_inv_optimizer'])
            #     self.c_optimizer.load_state_dict(checkpoint['c_optimizer'])
            self.best_pred = checkpoint['best_pred']
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
        else:
            raise RuntimeError('=> no resume checkpoint given; one is required here')

        # Clear start epoch if fine-tuning
        if args.ft:
            args.start_epoch = 0
    def __init__(self, args):
        self.args = args

        # Define Saver
        self.saver = Saver(args)
        self.saver.save_experiment_config()
        # Define Tensorboard Summary
        self.summary = TensorboardSummary(self.saver.experiment_dir)
        self.writer = self.summary.create_summary()

        # Define Dataloader
        kwargs = {'num_workers': args.workers, 'pin_memory': True}
        self.train_loader, self.val_loader, self.test_loader, self.nclass = make_data_loader(args, **kwargs)

        # Define network
        if args.sync_bn:
            BN = SynchronizedBatchNorm2d
        else:
            BN = nn.BatchNorm2d
        ### deeplabV3 start ###
        self.backbone_model = MobileNetV2(output_stride=args.out_stride,
                                          BatchNorm=BN)
        self.assp_model = ASPP(backbone=args.backbone,
                               output_stride=args.out_stride,
                               BatchNorm=BN)
        self.y_model = Decoder(num_classes=self.nclass,
                               backbone=args.backbone,
                               BatchNorm=BN)
        ### deeplabV3 end ###
        self.d_model = DomainClassifer(backbone=args.backbone,
                                       BatchNorm=BN)
        f_params = list(self.backbone_model.parameters()) + list(self.assp_model.parameters())
        y_params = list(self.y_model.parameters())
        d_params = list(self.d_model.parameters())

        # Define Optimizer
        if args.optimizer == 'SGD':
            self.task_optimizer = torch.optim.SGD(f_params + y_params, lr=args.lr,
                                                  momentum=args.momentum,
                                                  weight_decay=args.weight_decay, nesterov=args.nesterov)
            self.d_optimizer = torch.optim.SGD(d_params, lr=args.lr,
                                               momentum=args.momentum,
                                               weight_decay=args.weight_decay, nesterov=args.nesterov)
            self.d_inv_optimizer = torch.optim.SGD(f_params, lr=args.lr,
                                                   momentum=args.momentum,
                                                   weight_decay=args.weight_decay, nesterov=args.nesterov)
            self.c_optimizer = torch.optim.SGD(f_params + y_params, lr=args.lr,
                                               momentum=args.momentum,
                                               weight_decay=args.weight_decay, nesterov=args.nesterov)
        elif args.optimizer == 'Adam':
            self.task_optimizer = torch.optim.Adam(f_params + y_params, lr=args.lr)
            self.d_optimizer = torch.optim.Adam(d_params, lr=args.lr)
            self.d_inv_optimizer = torch.optim.Adam(f_params, lr=args.lr)
            self.c_optimizer = torch.optim.Adam(f_params+y_params, lr=args.lr)
        else:
            raise NotImplementedError

        # Define Criterion
        # whether to use class balanced weights
        if args.use_balanced_weights:
            classes_weights_path = os.path.join('dataloders', 'datasets',
                                                args.dataset + '_classes_weights.npy')
            if os.path.isfile(classes_weights_path):
                weight = np.load(classes_weights_path)
            else:
                weight = calculate_weigths_labels(self.train_loader, self.nclass, classes_weights_path, self.args.dataset)
            weight = torch.from_numpy(weight.astype(np.float32))
        else:
            weight = None
        self.task_loss = SegmentationLosses(weight=weight, cuda=args.cuda).build_loss(mode=args.loss_type)
        self.domain_loss = DomainLosses(cuda=args.cuda).build_loss()
        self.ca_loss = ''  # placeholder; the 'ca' loss is never built or used in this snippet

        # Define Evaluator
        self.evaluator = Evaluator(self.nclass)

        # Define lr scheduler
        self.scheduler = LR_Scheduler(args.lr_scheduler, args.lr,
                                      args.epochs, len(self.train_loader))

        # Using cuda
        if args.cuda:
            self.backbone_model = torch.nn.DataParallel(self.backbone_model, device_ids=self.args.gpu_ids)
            self.assp_model = torch.nn.DataParallel(self.assp_model, device_ids=self.args.gpu_ids)
            self.y_model = torch.nn.DataParallel(self.y_model, device_ids=self.args.gpu_ids)
            self.d_model = torch.nn.DataParallel(self.d_model, device_ids=self.args.gpu_ids)
            patch_replication_callback(self.backbone_model)
            patch_replication_callback(self.assp_model)
            patch_replication_callback(self.y_model)
            patch_replication_callback(self.d_model)
            self.backbone_model = self.backbone_model.cuda()
            self.assp_model = self.assp_model.cuda()
            self.y_model = self.y_model.cuda()
            self.d_model = self.d_model.cuda()

        # Resuming checkpoint
        self.best_pred = 0.0
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'".format(args.resume))
            checkpoint = torch.load(args.resume, map_location=None if args.cuda else 'cpu')
            args.start_epoch = checkpoint['epoch']
            if args.cuda:
                self.backbone_model.module.load_state_dict(checkpoint['backbone_model_state_dict'])
                self.assp_model.module.load_state_dict(checkpoint['assp_model_state_dict'])
                self.y_model.module.load_state_dict(checkpoint['y_model_state_dict'])
                self.d_model.module.load_state_dict(checkpoint['d_model_state_dict'])
            else:
                self.backbone_model.load_state_dict(checkpoint['backbone_model_state_dict'])
                self.assp_model.load_state_dict(checkpoint['assp_model_state_dict'])
                self.y_model.load_state_dict(checkpoint['y_model_state_dict'])
                self.d_model.load_state_dict(checkpoint['d_model_state_dict'])
            if not args.ft:
                self.task_optimizer.load_state_dict(checkpoint['task_optimizer'])
                self.d_optimizer.load_state_dict(checkpoint['d_optimizer'])
                self.d_inv_optimizer.load_state_dict(checkpoint['d_inv_optimizer'])
                self.c_optimizer.load_state_dict(checkpoint['c_optimizer'])
            if self.args.dataset == 'gtav':
                self.best_pred = checkpoint['best_pred']
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))

        # Clear start epoch if fine-tuning
        if args.ft:
            args.start_epoch = 0
Example #3
class Tester(object):
    def __init__(self, args):
        self.args = args


        # Define Dataloader
        kwargs = {'num_workers': args.workers, 'pin_memory': True}
        self.train_loader, self.val_loader, self.test_loader, self.nclass = make_data_loader(args, **kwargs)

        # Define network
        if args.sync_bn:
            BN = SynchronizedBatchNorm2d
        else:
            BN = nn.BatchNorm2d
        ### deeplabV3 start ###
        self.backbone_model = MobileNetV2(output_stride=args.out_stride,
                                          BatchNorm=BN)
        self.assp_model = ASPP(backbone=args.backbone,
                               output_stride=args.out_stride,
                               BatchNorm=BN)
        self.y_model = Decoder(num_classes=self.nclass,
                               backbone=args.backbone,
                               BatchNorm=BN)
        ### deeplabV3 end ###
        self.d_model = DomainClassifer(backbone=args.backbone,
                                       BatchNorm=BN)
        f_params = list(self.backbone_model.parameters()) + list(self.assp_model.parameters())
        y_params = list(self.y_model.parameters())
        d_params = list(self.d_model.parameters())


        # Using cuda
        if args.cuda:
            self.backbone_model = torch.nn.DataParallel(self.backbone_model, device_ids=self.args.gpu_ids)
            self.assp_model = torch.nn.DataParallel(self.assp_model, device_ids=self.args.gpu_ids)
            self.y_model = torch.nn.DataParallel(self.y_model, device_ids=self.args.gpu_ids)
            self.d_model = torch.nn.DataParallel(self.d_model, device_ids=self.args.gpu_ids)
            patch_replication_callback(self.backbone_model)
            patch_replication_callback(self.assp_model)
            patch_replication_callback(self.y_model)
            patch_replication_callback(self.d_model)
            self.backbone_model = self.backbone_model.cuda()
            self.assp_model = self.assp_model.cuda()
            self.y_model = self.y_model.cuda()
            self.d_model = self.d_model.cuda()

        # Resuming checkpoint
        self.best_pred = 0.0
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'".format(args.resume))
            checkpoint = torch.load(args.resume, map_location=None if args.cuda else 'cpu')
            args.start_epoch = checkpoint['epoch']
            if args.cuda:
                self.backbone_model.module.load_state_dict(checkpoint['backbone_model_state_dict'])
                self.assp_model.module.load_state_dict(checkpoint['assp_model_state_dict'])
                self.y_model.module.load_state_dict(checkpoint['y_model_state_dict'])
                self.d_model.module.load_state_dict(checkpoint['d_model_state_dict'])
            else:
                self.backbone_model.load_state_dict(checkpoint['backbone_model_state_dict'])
                self.assp_model.load_state_dict(checkpoint['assp_model_state_dict'])
                self.y_model.load_state_dict(checkpoint['y_model_state_dict'])
                self.d_model.load_state_dict(checkpoint['d_model_state_dict'])
            # if not args.ft:
            #     self.task_optimizer.load_state_dict(checkpoint['task_optimizer'])
            #     self.d_optimizer.load_state_dict(checkpoint['d_optimizer'])
            #     self.d_inv_optimizer.load_state_dict(checkpoint['d_inv_optimizer'])
            #     self.c_optimizer.load_state_dict(checkpoint['c_optimizer'])
            self.best_pred = checkpoint['best_pred']
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
        else:
            raise RuntimeError('=> no resume checkpoint given; Tester requires one')

        # Clear start epoch if fine-tuning
        if args.ft:
            args.start_epoch = 0

    def imgsaver(self, img, imgname):
        # (N, H, W) -> (H, W, N) -> (H, W); assumes a batch of one
        im1 = np.uint8(img.transpose(1, 2, 0)).squeeze()
        #filename_list = sorted(os.listdir(self.args.test_img_root))

        # Cityscapes: map the 19 train IDs back to the official label IDs
        valid_classes = [7, 8, 11, 12, 13, 17, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 31, 32, 33]
        class_map = dict(zip(range(19), valid_classes))
        im1_np = np.uint8(np.zeros([513,513]))
        for _validc in range(19):
            im1_np[im1 == _validc] = class_map[_validc]
        saveim1 = Image.fromarray(im1_np, mode='L')
        saveim1 = saveim1.resize((1280,640), Image.NEAREST)
        saveim1.save('result/'+imgname)
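        # Note: the per-class loop above can be replaced by a lookup table, e.g.
        #   lut = np.array(valid_classes, dtype=np.uint8); im1_np = lut[im1]
        # (equivalent for train IDs in [0, 19); shown only as a sketch).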

        # standard Cityscapes color palette for the 19 train classes
        palette = [[128,64,128],
                    [244,35,232],
                    [70,70,70],
                    [102,102,156],
                    [190,153,153],
                    [153,153,153],
                    [250,170,30],
                    [220,220,0],
                    [107,142,35],
                    [152,251,152],
                    [70,130,180],
                    [220,20,60],
                    [255,0,0],
                    [0,0,142],
                    [0,0,70],
                    [0,60,100],
                    [0,80,100],
                    [0,0,230],
                    [119,11,32]]
                    #[0,0,0]]
        class_color_map = dict(zip(range(19), palette))
        im2_np = np.uint8(np.zeros([513,513,3]))
        for _validc in range(19):
            im2_np[im1 == _validc] = class_color_map[_validc]
        saveim2 = Image.fromarray(im2_np)
        saveim2 = saveim2.resize((1280,640), Image.NEAREST)
        saveim2.save('result/'+imgname[:-4]+'_color.png')
        # print('saving: '+filename_list[idx])


    def test(self, epoch):
        self.backbone_model.eval()
        self.assp_model.eval()
        self.y_model.eval()
        self.d_model.eval()
        tbar = tqdm(self.test_loader, desc='\r')
        test_loss = 0.0  # stays at zero: no labels, so no loss is computed in this loop
        for i, sample in enumerate(tbar):
            image = sample['image']
            if self.args.cuda:
                image = image.cuda()
            with torch.no_grad():
                high_feature, low_feature = self.backbone_model(image)
                high_feature = self.assp_model(high_feature)
                output = F.interpolate(self.y_model(high_feature, low_feature), image.size()[2:], \
                                           mode='bilinear', align_corners=True)
            tbar.set_description('Test loss: %.3f' % (test_loss / (i + 1)))
            pred = output.data.cpu().numpy()
            pred = np.argmax(pred, axis=1)


            self.imgsaver(pred, sample['name'][0])

        # Fast test during the training
        print('Test:')
        print('[Epoch: %d, numImages: %5d]' % (epoch, i * self.args.test_batch_size + image.data.shape[0]))
        # Acc = self.evaluator.Pixel_Accuracy()
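
# Minimal usage sketch (not part of the original snippet): Tester raises without a
# checkpoint, so `args.resume` must point at a trained model.
#   tester = Tester(args)
#   tester.test(epoch=0)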
Example #4
    def __init__(self, network_arch, cell_arch, num_classes, args,
                 low_level_layer):

        super(Model_2, self).__init__()
        BatchNorm = SynchronizedBatchNorm2d if args.sync_bn else nn.BatchNorm2d
        F = args.F
        B = args.B

        eps = 1e-5
        momentum = 0.1

        self.args = args
        num_model_1_layers = args.num_model_1_layers
        self.num_model_2_layers = len(network_arch) - num_model_1_layers

        self.cells = nn.ModuleList()
        self.model_2_network = network_arch[num_model_1_layers:]
        self.cell_arch = torch.from_numpy(cell_arch)
        self._num_classes = num_classes

        model_1_network = network_arch[:args.num_model_1_layers]
        self.model_1 = Model_1(model_1_network, cell_arch, num_classes, num_model_1_layers, \
                                       BatchNorm, F=F, B=B, low_level_layer=low_level_layer)
        self.decoder_2 = Decoder(num_classes, BatchNorm)

        fm = {0: 1, 1: 2, 2: 4, 3: 8}  # architecture level -> channel multiplier
        for i in range(self.num_model_2_layers):

            level = self.model_2_network[i]
            prev_level = self.model_2_network[i - 1]

            downup_sample = int(prev_level - level)
            dense_channel_list_1 = [
                F * fm[stride] for stride in model_1_network
            ]

            if i == 0:
                downup_sample = int(model_1_network[-1] -
                                    self.model_2_network[0])
                dense_channel_list_2 = dense_channel_list_1[:-1]
                _cell = Cell(BatchNorm,
                             B,
                             dense_channel_list_2,
                             F * B * fm[model_1_network[-1]],
                             self.cell_arch,
                             self.model_2_network[i],
                             F * fm[level],
                             downup_sample,
                             dense_in=True)

            elif i == 1:
                dense_channel_list_2 = dense_channel_list_1
                _cell = Cell(BatchNorm,
                             B,
                             dense_channel_list_2,
                             F * B * fm[self.model_2_network[0]],
                             self.cell_arch,
                             self.model_2_network[i],
                             F * fm[level],
                             downup_sample,
                             dense_in=True)

            elif i < self.num_model_2_layers - 2:
                dense_channel_list_2 = dense_channel_list_1 + \
                                        [F * fm[stride] for stride in self.model_2_network[:i-1]]
                _cell = Cell(BatchNorm,
                             B,
                             dense_channel_list_2,
                             F * B * fm[prev_level],
                             self.cell_arch,
                             self.model_2_network[i],
                             F * fm[level],
                             downup_sample,
                             dense_in=True)

            else:
                dense_channel_list_2 = dense_channel_list_1 + \
                                        [F * fm[stride] for stride in self.model_2_network[:i-1]]
                _cell = Cell(BatchNorm,
                             B,
                             dense_channel_list_2,
                             F * B * fm[prev_level],
                             self.cell_arch,
                             self.model_2_network[i],
                             F * fm[level],
                             downup_sample,
                             dense_in=True,
                             dense_out=False)

            self.cells += [_cell]

        if self.model_2_network[-1] == 1:
            mult = 2
        elif self.model_2_network[-1] == 2:
            mult = 1
        else:
            raise ValueError('unexpected final network level: {}'.format(self.model_2_network[-1]))

        self.low_level_conv = nn.Sequential(
            nn.ReLU(),
            nn.Conv2d(F * B * 2**model_1_network[low_level_layer],
                      48,
                      1,
                      bias=False),
            BatchNorm(48, eps=eps, momentum=momentum),
        )
        self.aspp_2 = ASPP_train(F * B * fm[self.model_2_network[-1]],
                                 256,
                                 num_classes,
                                 BatchNorm,
                                 mult=mult)
        self._init_weight()
class Trainer(object):
    def __init__(self, args):
        self.args = args

        # Define Saver
        self.saver = Saver(args)
        self.saver.save_experiment_config()
        # Define Tensorboard Summary
        self.summary = TensorboardSummary(self.saver.experiment_dir)
        self.writer = self.summary.create_summary()

        # Define Dataloader
        kwargs = {'num_workers': args.workers, 'pin_memory': True}
        self.train_loader, self.val_loader, self.test_loader, self.nclass = make_data_loader(args, **kwargs)

        # Define network
        if args.sync_bn:
            BN = SynchronizedBatchNorm2d
        else:
            BN = nn.BatchNorm2d
        ### deeplabV3 start ###
        self.backbone_model = MobileNetV2(output_stride=args.out_stride,
                                          BatchNorm=BN)
        self.assp_model = ASPP(backbone=args.backbone,
                               output_stride=args.out_stride,
                               BatchNorm=BN)
        self.y_model = Decoder(num_classes=self.nclass,
                               backbone=args.backbone,
                               BatchNorm=BN)
        ### deeplabV3 end ###
        self.d_model = DomainClassifer(backbone=args.backbone,
                                       BatchNorm=BN)
        f_params = list(self.backbone_model.parameters()) + list(self.assp_model.parameters())  # feature extractor
        y_params = list(self.y_model.parameters())  # segmentation decoder
        d_params = list(self.d_model.parameters())  # domain classifier

        # Define Optimizer
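        # Four optimizers over overlapping parameter groups (see training() below):
        # task_optimizer -> features + decoder, d_optimizer -> domain classifier,
        # d_inv_optimizer -> features only; c_optimizer covers features + decoder
        # but is created without ever being stepped in this snippet.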
        if args.optimizer == 'SGD':
            self.task_optimizer = torch.optim.SGD(f_params + y_params, lr=args.lr,
                                                  momentum=args.momentum,
                                                  weight_decay=args.weight_decay, nesterov=args.nesterov)
            self.d_optimizer = torch.optim.SGD(d_params, lr=args.lr,
                                               momentum=args.momentum,
                                               weight_decay=args.weight_decay, nesterov=args.nesterov)
            self.d_inv_optimizer = torch.optim.SGD(f_params, lr=args.lr,
                                                   momentum=args.momentum,
                                                   weight_decay=args.weight_decay, nesterov=args.nesterov)
            self.c_optimizer = torch.optim.SGD(f_params + y_params, lr=args.lr,
                                               momentum=args.momentum,
                                               weight_decay=args.weight_decay, nesterov=args.nesterov)
        elif args.optimizer == 'Adam':
            self.task_optimizer = torch.optim.Adam(f_params + y_params, lr=args.lr)
            self.d_optimizer = torch.optim.Adam(d_params, lr=args.lr)
            self.d_inv_optimizer = torch.optim.Adam(f_params, lr=args.lr)
            self.c_optimizer = torch.optim.Adam(f_params+y_params, lr=args.lr)
        else:
            raise NotImplementedError

        # Define Criterion
        # whether to use class balanced weights
        if args.use_balanced_weights:
            classes_weights_path = os.path.join('dataloders', 'datasets',
                                                args.dataset + '_classes_weights.npy')
            if os.path.isfile(classes_weights_path):
                weight = np.load(classes_weights_path)
            else:
                weight = calculate_weigths_labels(self.train_loader, self.nclass, classes_weights_path, self.args.dataset)
            weight = torch.from_numpy(weight.astype(np.float32))
        else:
            weight = None
        self.task_loss = SegmentationLosses(weight=weight, cuda=args.cuda).build_loss(mode=args.loss_type)
        self.domain_loss = DomainLosses(cuda=args.cuda).build_loss()
        self.ca_loss = ''  # placeholder, never used in this snippet

        # Define Evaluator
        self.evaluator = Evaluator(self.nclass)

        # Define lr scheduler
        self.scheduler = LR_Scheduler(args.lr_scheduler, args.lr,
                                      args.epochs, len(self.train_loader))

        # Using cuda
        if args.cuda:
            self.backbone_model = torch.nn.DataParallel(self.backbone_model, device_ids=self.args.gpu_ids)
            self.assp_model = torch.nn.DataParallel(self.assp_model, device_ids=self.args.gpu_ids)
            self.y_model = torch.nn.DataParallel(self.y_model, device_ids=self.args.gpu_ids)
            self.d_model = torch.nn.DataParallel(self.d_model, device_ids=self.args.gpu_ids)
            patch_replication_callback(self.backbone_model)
            patch_replication_callback(self.assp_model)
            patch_replication_callback(self.y_model)
            patch_replication_callback(self.d_model)
            self.backbone_model = self.backbone_model.cuda()
            self.assp_model = self.assp_model.cuda()
            self.y_model = self.y_model.cuda()
            self.d_model = self.d_model.cuda()

        # Resuming checkpoint
        self.best_pred = 0.0
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'".format(args.resume))
            checkpoint = torch.load(args.resume, map_location=None if args.cuda else 'cpu')
            args.start_epoch = checkpoint['epoch']
            if args.cuda:
                self.backbone_model.module.load_state_dict(checkpoint['backbone_model_state_dict'])
                self.assp_model.module.load_state_dict(checkpoint['assp_model_state_dict'])
                self.y_model.module.load_state_dict(checkpoint['y_model_state_dict'])
                self.d_model.module.load_state_dict(checkpoint['d_model_state_dict'])
            else:
                self.backbone_model.load_state_dict(checkpoint['backbone_model_state_dict'])
                self.assp_model.load_state_dict(checkpoint['assp_model_state_dict'])
                self.y_model.load_state_dict(checkpoint['y_model_state_dict'])
                self.d_model.load_state_dict(checkpoint['d_model_state_dict'])
            if not args.ft:
                self.task_optimizer.load_state_dict(checkpoint['task_optimizer'])
                self.d_optimizer.load_state_dict(checkpoint['d_optimizer'])
                self.d_inv_optimizer.load_state_dict(checkpoint['d_inv_optimizer'])
                self.c_optimizer.load_state_dict(checkpoint['c_optimizer'])
            if self.args.dataset == 'gtav':
                self.best_pred = checkpoint['best_pred']
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))

        # Clear start epoch if fine-tuning
        if args.ft:
            args.start_epoch = 0

    def training(self, epoch):
        train_loss = 0.0
        train_task_loss = 0.0
        train_d_loss = 0.0
        train_d_inv_loss = 0.0
        self.backbone_model.train()
        self.assp_model.train()
        self.y_model.train()
        self.d_model.train()
        tbar = tqdm(self.train_loader)
        num_img_tr = len(self.train_loader)
        for i, sample in enumerate(tbar):
            if self.args.dataset == 'gtav':
                src_image,src_label = sample['image'], sample['label']
            else:
                src_image, src_label, tgt_image = sample['src_image'], sample['src_label'], sample['tgt_image']
            if self.args.cuda:
                if self.args.dataset != 'gtav':
                    src_image, src_label, tgt_image  = src_image.cuda(), src_label.cuda(), tgt_image.cuda()
                else:
                    src_image, src_label = src_image.cuda(), src_label.cuda()
            self.scheduler(self.task_optimizer, i, epoch, self.best_pred)
            self.scheduler(self.d_optimizer, i, epoch, self.best_pred)
            self.scheduler(self.d_inv_optimizer, i, epoch, self.best_pred)
            self.scheduler(self.c_optimizer, i, epoch, self.best_pred)
            self.task_optimizer.zero_grad()
            self.d_optimizer.zero_grad()
            self.d_inv_optimizer.zero_grad()
            self.c_optimizer.zero_grad()
            # source image feature
            src_high_feature_0, src_low_feature = self.backbone_model(src_image)
            src_high_feature = self.assp_model(src_high_feature_0)
            src_output = F.interpolate(self.y_model(src_high_feature, src_low_feature), src_image.size()[2:], \
                                       mode='bilinear', align_corners=True)

            src_d_pred = self.d_model(src_high_feature)
            task_loss = self.task_loss(src_output, src_label)

            if self.args.dataset != 'gtav':
                # target image feature
                tgt_high_feature_0, tgt_low_feature = self.backbone_model(tgt_image)
                tgt_high_feature = self.assp_model(tgt_high_feature_0)
                tgt_output = F.interpolate(self.y_model(tgt_high_feature, tgt_low_feature), tgt_image.size()[2:], \
                                           mode='bilinear', align_corners=True)
                tgt_d_pred = self.d_model(tgt_high_feature)

                d_loss, d_acc = self.domain_loss(src_d_pred, tgt_d_pred)
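                # Swapping the arguments flips which side is treated as 'source',
                # giving an inverted domain loss that pushes the feature extractor
                # toward domain-confused features (stepped via d_inv_optimizer).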
                d_inv_loss, _ = self.domain_loss(tgt_d_pred, src_d_pred)
                loss = task_loss + d_loss + d_inv_loss
                loss.backward()
                self.task_optimizer.step()
                self.d_optimizer.step()
                self.d_inv_optimizer.step()
            else:
                d_acc = 0
                d_loss = torch.tensor(0.0)
                d_inv_loss = torch.tensor(0.0)
                loss = task_loss
                loss.backward()
                self.task_optimizer.step()

            train_task_loss += task_loss.item()
            train_d_loss += d_loss.item()
            train_d_inv_loss += d_inv_loss.item()
            train_loss += task_loss.item() + d_loss.item() + d_inv_loss.item()

            tbar.set_description('Train loss: %.3f t_loss: %.3f d_loss: %.3f , d_inv_loss: %.3f  d_acc: %.2f' \
                                 % (train_loss / (i + 1),train_task_loss / (i + 1),\
                                    train_d_loss / (i + 1), train_d_inv_loss / (i + 1), d_acc*100))

            self.writer.add_scalar('train/task_loss_iter', task_loss.item(), i + num_img_tr * epoch)
            # Show 10 * 3 inference results each epoch
            if num_img_tr >= 10 and i % (num_img_tr // 10) == 0:
                global_step = i + num_img_tr * epoch
                if self.args.dataset != 'gtav':
                    image = torch.cat([src_image,tgt_image],dim=0)
                    output = torch.cat([src_output,tgt_output],dim=0)
                else:
                    image = src_image
                    output = src_output
                self.summary.visualize_image(self.writer, self.args.dataset, image, src_label, output, global_step)


        self.writer.add_scalar('train/task_loss_epoch', train_task_loss, epoch)
        print('[Epoch: %d, numImages: %5d]' % (epoch, i * self.args.batch_size + src_image.data.shape[0]))
        print('Loss: %.3f' % train_loss)

        if self.args.no_val:
            # save checkpoint every epoch
            is_best = False
            self.saver.save_checkpoint({
                'epoch': epoch + 1,
                'backbone_model_state_dict': self.backbone_model.module.state_dict(),
                'assp_model_state_dict': self.assp_model.module.state_dict(),
                'y_model_state_dict': self.y_model.module.state_dict(),
                'd_model_state_dict': self.d_model.module.state_dict(),
                'task_optimizer': self.task_optimizer.state_dict(),
                'd_optimizer': self.d_optimizer.state_dict(),
                'd_inv_optimizer': self.d_inv_optimizer.state_dict(),
                'c_optimizer': self.c_optimizer.state_dict(),
                'best_pred': self.best_pred,
            }, is_best)

    def validation(self, epoch):
        self.backbone_model.eval()
        self.assp_model.eval()
        self.y_model.eval()
        self.d_model.eval()
        self.evaluator.reset()
        tbar = tqdm(self.val_loader, desc='\r')
        test_loss = 0.0
        for i, sample in enumerate(tbar):
            image, target = sample['image'], sample['label']
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            with torch.no_grad():
                high_feature, low_feature = self.backbone_model(image)
                high_feature = self.assp_model(high_feature)
                output = F.interpolate(self.y_model(high_feature, low_feature), image.size()[2:], \
                                           mode='bilinear', align_corners=True)
            task_loss = self.task_loss(output, target)
            test_loss += task_loss.item()
            tbar.set_description('Test loss: %.3f' % (test_loss / (i + 1)))
            pred = output.data.cpu().numpy()
            target = target.cpu().numpy()
            pred = np.argmax(pred, axis=1)
            # Add batch sample into evaluator
            self.evaluator.add_batch(target, pred)

        # Fast test during the training
        Acc = self.evaluator.Pixel_Accuracy()
        Acc_class = self.evaluator.Pixel_Accuracy_Class()
        mIoU,IoU = self.evaluator.Mean_Intersection_over_Union()
        FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union()
        self.writer.add_scalar('val/total_loss_epoch', test_loss, epoch)
        self.writer.add_scalar('val/mIoU', mIoU, epoch)
        self.writer.add_scalar('val/Acc', Acc, epoch)
        self.writer.add_scalar('val/Acc_class', Acc_class, epoch)
        self.writer.add_scalar('val/fwIoU', FWIoU, epoch)
        print('Validation:')
        print('[Epoch: %d, numImages: %5d]' % (epoch, i * self.args.batch_size + image.data.shape[0]))
        print("Acc:{}, Acc_class:{}, mIoU:{}, fwIoU: {}".format(Acc, Acc_class, mIoU, FWIoU))
        print('Loss: %.3f' % test_loss)

        new_pred = mIoU

        if new_pred > self.best_pred:
            is_best = True
            self.best_pred = new_pred
            self.saver.save_checkpoint({
                'epoch': epoch + 1,
                'backbone_model_state_dict': self.backbone_model.module.state_dict(),
                'assp_model_state_dict': self.assp_model.module.state_dict(),
                'y_model_state_dict': self.y_model.module.state_dict(),
                'd_model_state_dict': self.d_model.module.state_dict(),
                'task_optimizer': self.task_optimizer.state_dict(),
                'd_optimizer': self.d_optimizer.state_dict(),
                'd_inv_optimizer': self.d_inv_optimizer.state_dict(),
                'c_optimizer': self.c_optimizer.state_dict(),
                'best_pred': self.best_pred,
            }, is_best)
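
# Minimal driver sketch, not part of the original snippet; assumes an argparse-style
# `args` object carrying the fields referenced above (start_epoch, epochs, no_val, ...).
def main(args):
    trainer = Trainer(args)
    for epoch in range(args.start_epoch, args.epochs):
        trainer.training(epoch)
        if not args.no_val:
            trainer.validation(epoch)
    trainer.writer.close()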
Example #6
    def __init__(self,
                 network_arch,
                 cell_arch,
                 num_classes,
                 num_layers,
                 BatchNorm,
                 F=20,
                 B=5,
                 low_level_layer=0):

        super(Model_1, self).__init__()

        self.cells = nn.ModuleList()
        self.model_1_network = network_arch
        self.cell_arch = torch.from_numpy(cell_arch)
        self.num_model_1_layers = num_layers
        self.low_level_layer = low_level_layer
        self._num_classes = num_classes

        FB = F * B
        fm = {0: 1, 1: 2, 2: 4, 3: 8}

        eps = 1e-5
        momentum = 0.1

        self.stem0 = nn.Sequential(
            nn.Conv2d(3, 64, 3, stride=2, padding=1, bias=False),
            BatchNorm(64, eps=eps, momentum=momentum), nn.ReLU(inplace=True))

        self.stem1 = nn.Sequential(
            nn.Conv2d(64, 64, 3, padding=1, bias=False),
            BatchNorm(64, eps=eps, momentum=momentum),
        )

        self.stem2 = nn.Sequential(
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 128, 3, stride=2, padding=1, bias=False),
            BatchNorm(128, eps=eps, momentum=momentum))

        for i in range(self.num_model_1_layers):
            level = self.model_1_network[i]
            prev_level = self.model_1_network[i - 1]
            prev_prev_level = self.model_1_network[i - 2]

            downup_sample = int(prev_level - level)

            if i == 0:
                downup_sample = int(0 - level)
                _cell = Cell(BatchNorm, B, 64, 128, self.cell_arch,
                             self.model_1_network[i], F * fm[level],
                             downup_sample)

            elif i == 1:
                _cell = Cell(BatchNorm, B, 128, FB * fm[prev_level],
                             self.cell_arch, self.model_1_network[i],
                             F * fm[level], downup_sample)
            elif i == 2:
                _cell = Cell(BatchNorm, B, FB * fm[prev_prev_level],
                             FB * fm[prev_level], self.cell_arch,
                             self.model_1_network[i], F * fm[level],
                             downup_sample)
            else:
                dense_channel_list = [
                    F * fm[stride] for stride in self.model_1_network[:i - 1]
                ]
                _cell = Cell(BatchNorm,
                             B,
                             dense_channel_list,
                             FB * fm[prev_level],
                             self.cell_arch,
                             self.model_1_network[i],
                             F * fm[level],
                             downup_sample,
                             dense_in=True)

            self.cells += [_cell]

        if self.model_1_network[-1] == 1:
            mult = 2
        elif self.model_1_network[-1] == 2:
            mult = 1
        else:
            raise ValueError('unexpected final network level: {}'.format(self.model_1_network[-1]))

        self.aspp_1 = ASPP_train(FB * fm[self.model_1_network[-1]], \
                                  256, num_classes, BatchNorm, mult=mult)
        self.low_level_conv = nn.Sequential(
            nn.ReLU(),
            nn.Conv2d(FB * 2**self.model_1_network[low_level_layer],
                      48,
                      1,
                      bias=False),
            BatchNorm(48, eps=eps, momentum=momentum),
        )
        self.decoder_1 = Decoder(num_classes, BatchNorm)

        self._init_weight()
    def __init__(self, network_arch, cell_arch_1, cell_arch_2, num_classes,
                 args, low_level_layer):
        super(Model_2_baseline, self).__init__()
        BatchNorm = SynchronizedBatchNorm2d if args.sync_bn else nn.BatchNorm2d
        F = args.F
        B = args.B
        num_model_1_layers = args.num_model_1_layers
        self.num_model_2_layers = len(network_arch) - num_model_1_layers
        self.cells = nn.ModuleList()
        self.model_2_network = network_arch[num_model_1_layers:]
        self.cell_arch_2 = torch.from_numpy(cell_arch_2)
        self._num_classes = num_classes

        model_1_network = network_arch[:args.num_model_1_layers]
        self.model_1 = Model_1_baseline(model_1_network, cell_arch_1, num_classes, num_model_1_layers, \
                                       BatchNorm, F=F, B=B, low_level_layer=low_level_layer)
        self.decoder_2 = Decoder(num_classes, BatchNorm)

        fm = {0: 1, 1: 2, 2: 4, 3: 8}
        for i in range(self.num_model_2_layers):

            level = self.model_2_network[i]
            prev_level = self.model_2_network[i - 1]
            prev_prev_level = self.model_2_network[i - 2]

            downup_sample = int(prev_level - level)

            if i == 0:
                downup_sample = int(model_1_network[-1] -
                                    self.model_2_network[0])
                _cell = Cell_baseline(BatchNorm, B,
                                      F * B * fm[model_1_network[-2]],
                                      F * B * fm[model_1_network[-1]],
                                      self.cell_arch_2,
                                      self.model_2_network[i], F * fm[level],
                                      downup_sample)

            elif i == 1:
                _cell = Cell_baseline(BatchNorm, B,
                                      F * B * fm[model_1_network[-1]],
                                      F * B * fm[self.model_2_network[0]],
                                      self.cell_arch_2,
                                      self.model_2_network[i], F * fm[level],
                                      downup_sample)
            else:
                _cell = Cell_baseline(BatchNorm, B,
                                      F * B * fm[prev_prev_level],
                                      F * B * fm[prev_level], self.cell_arch_2,
                                      self.model_2_network[i], F * fm[level],
                                      downup_sample)

            self.cells += [_cell]

        if self.model_2_network[-1] == 1:
            mult = 2
        elif self.model_2_network[-1] == 2:
            mult = 1
        else:
            raise ValueError('unexpected final network level: {}'.format(self.model_2_network[-1]))

        self.low_level_conv = nn.Sequential(
            nn.ReLU(),
            nn.Conv2d(F * B * 2**model_1_network[low_level_layer],
                      48,
                      1,
                      bias=False), BatchNorm(48))

        self.aspp_2 = ASPP_train(F * B * fm[self.model_2_network[-1]],
                                 256,
                                 num_classes,
                                 BatchNorm,
                                 mult=mult)
        self._init_weight()
Example #8
    def __init__(self, network_arch, cell_arch, num_classes, args, low_level_layer):
        super(AutoDeepLab, self).__init__()
        BatchNorm = SynchronizedBatchNorm2d if args.sync_bn else nn.BatchNorm2d
        F = args.F
        B = args.B
        self.num_model_layers = len(network_arch)
        self.cells = nn.ModuleList()
        self.model_network = network_arch
        self.cell_arch = torch.from_numpy(cell_arch)
        self.low_level_layer = low_level_layer
        self._num_classes = num_classes

        FB = F * B
        fm = {0: 1, 1: 2, 2: 4, 3: 8}
        self.stem0 = nn.Sequential(
            nn.Conv2d(3, 64, 3, stride=2, padding=1, bias=False),
            BatchNorm(64),
            nn.ReLU(inplace=True)
        )
        self.stem1 = nn.Sequential(
            nn.Conv2d(64, 64, 3, padding=1, bias=False),
            BatchNorm(64),
        )

        self.stem2 = nn.Sequential(
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 128, 3, stride=2, padding=1, bias=False),
            BatchNorm(128),
        )

        for i in range(self.num_model_layers):
            level = self.model_network[i]
            prev_level = self.model_network[i-1]
            prev_prev_level = self.model_network[i-2]

            downup_sample = int(prev_level - level)
            if i == 0:
                downup_sample = int(0 - level)
                _cell = Cell_AutoDeepLab(BatchNorm,
                            B,
                            64,
                            128,
                            self.cell_arch,
                            self.model_network[i],
                            F * fm[level],
                            downup_sample)

            elif i == 1:
                _cell = Cell_AutoDeepLab(BatchNorm,
                            B,
                            128,
                            FB * fm[prev_level],
                            self.cell_arch,
                            self.model_network[i],
                            F * fm[level],
                            downup_sample)
            else:
                _cell = Cell_AutoDeepLab(BatchNorm,
                            B, 
                            FB * fm[prev_prev_level],
                            FB * fm[prev_level],
                            self.cell_arch,
                            self.model_network[i],
                            F * fm[level],
                            downup_sample)

            self.cells += [_cell]

        if self.model_network[-1] == 1:
            mult = 2
        elif self.model_network[-1] == 2:
            mult = 1
        else:
            raise ValueError('unexpected final network level: {}'.format(self.model_network[-1]))

        self.low_level_conv = nn.Sequential(
                                nn.ReLU(),
                                nn.Conv2d(F * B * 2**self.model_network[low_level_layer], 48, 1, bias=False),
                                BatchNorm(48)
                                )
        
        self.aspp = ASPP_train(F * B * fm[self.model_network[-1]], 
                                256,
                                BatchNorm,
                                mult=mult)

        self.decoder = Decoder(num_classes, BatchNorm)
        self._init_weight()
Example #9
    def __init__(self, network_arch, C_index, cell_arch, num_classes, args,
                 low_level_layer):

        super(ADD, self).__init__()
        BatchNorm = SynchronizedBatchNorm2d if args.sync_bn else nn.BatchNorm2d
        F = args.F
        B = args.B

        eps = 1e-5
        momentum = 0.1

        self.args = args

        self.cells = nn.ModuleList()
        self.cell_arch = torch.from_numpy(cell_arch)
        self._num_classes = num_classes
        self.low_level_layer = low_level_layer
        self.decoder = Decoder(num_classes, BatchNorm)

        self.network_arch = network_arch
        self.num_net = len(network_arch)
        self.C_index = C_index

        FB = F * B
        fm = {0: 1, 1: 2, 2: 4, 3: 8}

        self.stem0 = nn.Sequential(
            nn.Conv2d(3, 64, 3, stride=2, padding=1, bias=False),
            BatchNorm(64, eps=eps, momentum=momentum), nn.ReLU(inplace=True))

        self.stem1 = nn.Sequential(
            nn.Conv2d(64, 64, 3, padding=1, bias=False),
            BatchNorm(64, eps=eps, momentum=momentum),
        )

        self.stem2 = nn.Sequential(
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 128, 3, stride=2, padding=1, bias=False),
            BatchNorm(128, eps=eps, momentum=momentum))

        for i in range(self.num_net):
            level = self.network_arch[i]
            prev_level = self.network_arch[i - 1]
            prev_prev_level = self.network_arch[i - 2]

            downup_sample = int(prev_level - level)

            if i == 0:
                downup_sample = int(0 - level)
                _cell = Cell(BatchNorm,
                             B,
                             64,
                             128,
                             self.cell_arch,
                             self.network_arch[i],
                             F * fm[level],
                             downup_sample,
                             dense_in=False,
                             dense_out=True)

            elif i == 1:
                _cell = Cell(BatchNorm,
                             B,
                             128,
                             FB * fm[prev_level],
                             self.cell_arch,
                             self.network_arch[i],
                             F * fm[level],
                             downup_sample,
                             dense_in=False,
                             dense_out=True)
            elif i == 2:
                _cell = Cell(BatchNorm,
                             B,
                             FB * fm[prev_prev_level],
                             FB * fm[prev_level],
                             self.cell_arch,
                             self.network_arch[i],
                             F * fm[level],
                             downup_sample,
                             dense_in=False,
                             dense_out=True)

            elif i < self.num_net - 2:
                dense_channel_list = [
                    F * fm[stride] for stride in self.network_arch[:i - 1]
                ]
                _cell = Cell(BatchNorm,
                             B,
                             dense_channel_list,
                             F * B * fm[prev_level],
                             self.cell_arch,
                             self.network_arch[i],
                             F * fm[level],
                             downup_sample,
                             dense_in=True,
                             dense_out=True)

            else:
                dense_channel_list = [
                    F * fm[stride] for stride in self.network_arch[:i - 1]
                ]
                _cell = Cell(BatchNorm,
                             B,
                             dense_channel_list,
                             FB * fm[prev_level],
                             self.cell_arch,
                             self.network_arch[i],
                             F * fm[level],
                             downup_sample,
                             dense_in=True,
                             dense_out=False)

            self.cells += [_cell]

        if self.network_arch[-1] == 1:
            mult = 2
        elif self.network_arch[-1] == 2:
            mult = 1
        elif self.network_arch[-1] == 3:
            mult = 0.5
        else:
            raise ValueError('unexpected final network level: {}'.format(self.network_arch[-1]))

        self.pooling = nn.MaxPool2d(3, stride=2)
        self.gap = nn.AdaptiveAvgPool2d(1)
        self.relu = nn.ReLU()

        self.low_level_conv = nn.Sequential(
            nn.ReLU(),
            nn.Conv2d(F * B * 2**self.network_arch[low_level_layer],
                      48,
                      1,
                      bias=False),
            BatchNorm(48, eps=eps, momentum=momentum),
        )
        self.aspp = ASPP_train(
            F * B * fm[self.network_arch[-1]],
            256,
            BatchNorm,
            mult=mult,
        )
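        # conv_aspp (judging by the channel arithmetic below): adapter convolutions that
        # bring features tapped at the C_index layers to the channel width (and, for
        # lower levels, spatial size) of the final level before ASPP.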
        self.conv_aspp = nn.ModuleList()
        for c in self.C_index:
            if self.network_arch[c] - self.network_arch[-1] == -1:
                self.conv_aspp.append(
                    FactorizedReduce(FB * 2**self.network_arch[c],
                                     FB * 2**self.network_arch[-1],
                                     BatchNorm,
                                     eps=eps,
                                     momentum=momentum))
            elif self.network_arch[c] - self.network_arch[-1] == -2:
                self.conv_aspp.append(
                    DoubleFactorizedReduce(FB * 2**self.network_arch[c],
                                           FB * 2**self.network_arch[-1],
                                           BatchNorm,
                                           eps=eps,
                                           momentum=momentum))
            elif self.network_arch[c] - self.network_arch[-1] > 0:
                self.conv_aspp.append(
                    ReLUConvBN(FB * 2**self.network_arch[c],
                               FB * 2**self.network_arch[-1],
                               1,
                               1,
                               0,
                               BatchNorm,
                               eps=eps,
                               momentum=momentum,
                               affine=True))
        self._init_weight()
Example #10
class Trainer(object):
    def __init__(self, args):
        self.args = args

        # Define Dataloader
        kwargs = {'num_workers': args.workers, 'pin_memory': True}
        self.train_loader, self.val_loader, self.test_loader, self.nclass = make_data_loader(
            args, **kwargs)

        # Define network
        if args.sync_bn:
            BN = SynchronizedBatchNorm2d
        else:
            BN = nn.BatchNorm2d
        ### deeplabV3 start ###
        self.backbone_model = MobileNetV2(output_stride=args.out_stride,
                                          BatchNorm=BN)
        self.assp_model = ASPP(backbone=args.backbone,
                               output_stride=args.out_stride,
                               BatchNorm=BN)
        self.y_model = Decoder(num_classes=self.nclass,
                               backbone=args.backbone,
                               BatchNorm=BN)
        ### deeplabV3 end ###
        self.d_model = DomainClassifer(backbone=args.backbone, BatchNorm=BN)
        f_params = list(self.backbone_model.parameters()) + list(
            self.assp_model.parameters())
        y_params = list(self.y_model.parameters())
        d_params = list(self.d_model.parameters())

        # Define Optimizer
        if args.optimizer == 'SGD':
            self.task_optimizer = torch.optim.SGD(
                f_params + y_params,
                lr=args.lr,
                momentum=args.momentum,
                weight_decay=args.weight_decay,
                nesterov=args.nesterov)
            self.d_optimizer = torch.optim.SGD(d_params,
                                               lr=args.lr,
                                               momentum=args.momentum,
                                               weight_decay=args.weight_decay,
                                               nesterov=args.nesterov)
            self.d_inv_optimizer = torch.optim.SGD(
                f_params,
                lr=args.lr,
                momentum=args.momentum,
                weight_decay=args.weight_decay,
                nesterov=args.nesterov)
            self.c_optimizer = torch.optim.SGD(f_params + y_params,
                                               lr=args.lr,
                                               momentum=args.momentum,
                                               weight_decay=args.weight_decay,
                                               nesterov=args.nesterov)
        elif args.optimizer == 'Adam':
            self.task_optimizer = torch.optim.Adam(f_params + y_params,
                                                   lr=args.lr)
            self.d_optimizer = torch.optim.Adam(d_params, lr=args.lr)
            self.d_inv_optimizer = torch.optim.Adam(f_params, lr=args.lr)
            self.c_optimizer = torch.optim.Adam(f_params + y_params,
                                                lr=args.lr)
        else:
            raise NotImplementedError

        # Define Criterion
        # whether to use class balanced weights
        if args.use_balanced_weights:
            classes_weights_path = os.path.join('dataloders', 'datasets',
                                                args.dataset + '_classes_weights.npy')
            if os.path.isfile(classes_weights_path):
                weight = np.load(classes_weights_path)
            else:
                weight = calculate_weigths_labels(self.train_loader,
                                                  self.nclass,
                                                  classes_weights_path,
                                                  self.args.dataset)
            weight = torch.from_numpy(weight.astype(np.float32))
        else:
            weight = None
        self.task_loss = SegmentationLosses(
            weight=weight, cuda=args.cuda).build_loss(mode=args.loss_type)
        self.domain_loss = DomainLosses(cuda=args.cuda).build_loss()
        self.ca_loss = ''  # placeholder, never used in this snippet

        # Define Evaluator
        self.evaluator = Evaluator(self.nclass)

        # Define lr scheduler
        self.scheduler = LR_Scheduler(args.lr_scheduler, args.lr, args.epochs,
                                      len(self.train_loader))
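        # The scheduler is constructed with iters-per-epoch (len(train_loader)),
        # so it presumably adjusts the learning rate per iteration, not per epoch.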

        # Using cuda
        if args.cuda:
            self.backbone_model = torch.nn.DataParallel(
                self.backbone_model, device_ids=self.args.gpu_ids)
            self.assp_model = torch.nn.DataParallel(
                self.assp_model, device_ids=self.args.gpu_ids)
            self.y_model = torch.nn.DataParallel(self.y_model,
                                                 device_ids=self.args.gpu_ids)
            self.d_model = torch.nn.DataParallel(self.d_model,
                                                 device_ids=self.args.gpu_ids)
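            # patch_replication_callback comes with the Synchronized-BatchNorm-PyTorch
            # package; it patches DataParallel's replicate step so that
            # SynchronizedBatchNorm2d layers can communicate across GPU replicas.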
            patch_replication_callback(self.backbone_model)
            patch_replication_callback(self.assp_model)
            patch_replication_callback(self.y_model)
            patch_replication_callback(self.d_model)
            self.backbone_model = self.backbone_model.cuda()
            self.assp_model = self.assp_model.cuda()
            self.y_model = self.y_model.cuda()
            self.d_model = self.d_model.cuda()

        # Resuming checkpoint
        self.best_pred = 0.0
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'".format(
                    args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            if args.cuda:
                self.backbone_model.module.load_state_dict(
                    checkpoint['backbone_model_state_dict'])
                self.assp_model.module.load_state_dict(
                    checkpoint['assp_model_state_dict'])
                self.y_model.module.load_state_dict(
                    checkpoint['y_model_state_dict'])
                self.d_model.module.load_state_dict(
                    checkpoint['d_model_state_dict'])
            else:
                self.backbone_model.load_state_dict(
                    checkpoint['backbone_model_state_dict'])
                self.assp_model.load_state_dict(
                    checkpoint['assp_model_state_dict'])
                self.y_model.load_state_dict(checkpoint['y_model_state_dict'])
                self.d_model.load_state_dict(checkpoint['d_model_state_dict'])
            if not args.ft:
                self.task_optimizer.load_state_dict(
                    checkpoint['task_optimizer'])
                self.d_optimizer.load_state_dict(checkpoint['d_optimizer'])
                self.d_inv_optimizer.load_state_dict(
                    checkpoint['d_inv_optimizer'])
                self.c_optimizer.load_state_dict(checkpoint['c_optimizer'])
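            # Restore the best score only when resuming on the source (GTAV) dataset;
            # presumably runs on other datasets should start best-score tracking afresh.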
            if self.args.dataset == 'gtav':
                self.best_pred = checkpoint['best_pred']
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))

        # Clear start epoch if fine-tuning
        if args.ft:
            args.start_epoch = 0

    def validation(self, epoch):
        self.backbone_model.eval()
        self.assp_model.eval()
        self.y_model.eval()
        self.d_model.eval()
        self.evaluator.reset()
        tbar = tqdm(self.val_loader, desc='\r')
        test_loss = 0.0
        for i, sample in enumerate(tbar):
            image, target = sample['image'], sample['label']
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
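            # Forward pass: backbone -> ASPP -> decoder, then bilinear upsampling
            # back to the input resolution.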
            with torch.no_grad():
                high_feature, low_feature = self.backbone_model(image)
                high_feature = self.assp_model(high_feature)
                output = F.interpolate(self.y_model(high_feature, low_feature),
                                       size=image.size()[2:],
                                       mode='bilinear',
                                       align_corners=True)
            task_loss = self.task_loss(output, target)
            test_loss += task_loss.item()
            tbar.set_description('Test loss: %.3f' % (test_loss / (i + 1)))
            pred = output.cpu().numpy()
            target = target.cpu().numpy()
            pred = np.argmax(pred, axis=1)
            # Add batch sample into evaluator
            self.evaluator.add_batch(target, pred)

        # Fast test during the training
        Acc = self.evaluator.Pixel_Accuracy()
        Acc_class = self.evaluator.Pixel_Accuracy_Class()
        mIoU, IoU = self.evaluator.Mean_Intersection_over_Union()
        FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union()
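        # Per-class names for the 19 Cityscapes evaluation classes.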
        ClassName = [
            "road", "sidewalk", "building", "wall", "fence", "pole", "light",
            "sign", "vegetation", "terrain", "sky", "person", "rider", "car",
            "truck", "bus", "train", "motocycle", "bicycle"
        ]
        with open('val_info.txt', 'a') as f1:
            f1.write('Validation:' + '\n')
            f1.write('[Epoch: %d, numImages: %5d]' %
                     (epoch, i * self.args.batch_size + image.data.shape[0]) +
                     '\n')
            f1.write("Acc:{}, Acc_class:{}, mIoU:{}, fwIoU: {}".format(
                Acc, Acc_class, mIoU, FWIoU) + '\n')
            f1.write('Loss: %.3f' % test_loss + '\n' + '\n')
            f1.write('Class IOU: ' + '\n')
            for idx in range(len(ClassName)):
                f1.write('\t' + ClassName[idx] +
                         (': \t' if len(ClassName[idx]) > 5 else ': \t\t') +
                         str(IoU[idx]) + '\n')

        print('Validation:')
        print('[Epoch: %d, numImages: %5d]' %
              (epoch, i * self.args.batch_size + image.data.shape[0]))
        print("Acc:{}, Acc_class:{}, mIoU:{}, fwIoU: {}".format(
            Acc, Acc_class, mIoU, FWIoU))
        print('Loss: %.3f' % test_loss)
        print(IoU)

        # new_pred is not used here; presumably it would be compared against
        # self.best_pred to decide checkpointing in the full training loop.
        new_pred = mIoU

    def imgsaver(self, img, imgname, miou):
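        # Save a single prediction as a colorized PNG (the raw label-ID save
        # below is currently disabled).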
        im1 = np.uint8(img.transpose(1, 2, 0)).squeeze()

        valid_classes = [
            7, 8, 11, 12, 13, 17, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 31,
            32, 33
        ]
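        # Map each of the 19 train IDs back to its original Cityscapes label ID.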
        class_map = dict(zip(range(19), valid_classes))
        im1_np = np.zeros(im1.shape, dtype=np.uint8)
        for _validc in range(19):
            im1_np[im1 == _validc] = class_map[_validc]
        saveim1 = Image.fromarray(im1_np, mode='L')
        saveim1 = saveim1.resize((1280, 640), Image.NEAREST)
        # saveim1.save('result_val/'+imgname)

        palette = [[128, 64, 128], [244, 35, 232], [70, 70, 70],
                   [102, 102, 156], [190, 153, 153], [153, 153, 153],
                   [250, 170, 30], [220, 220, 0], [107, 142, 35],
                   [152, 251, 152], [70, 130, 180], [220, 20, 60], [255, 0, 0],
                   [0, 0, 142], [0, 0, 70], [0, 60, 100], [0, 80, 100],
                   [0, 0, 230], [119, 11, 32]]
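        # The palette above is the standard Cityscapes color scheme, in train-ID order.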
        class_color_map = dict(zip(range(19), palette))
        im2_np = np.zeros((*im1.shape, 3), dtype=np.uint8)
        for _validc in range(19):
            im2_np[im1 == _validc] = class_color_map[_validc]
        saveim2 = Image.fromarray(im2_np)
        saveim2 = saveim2.resize((1280, 640), Image.NEAREST)
        saveim2.save('result_val/' + imgname[:-4] + '_color_' + str(miou) +
                     '_.png')
        # print('saving: '+filename_list[idx])

    def validationSep(self, epoch):
        self.backbone_model.eval()
        self.assp_model.eval()
        self.y_model.eval()
        self.d_model.eval()
        self.evaluator.reset()
        tbar = tqdm(self.val_loader, desc='\r')
        test_loss = 0.0
        for i, sample in enumerate(tbar):
            image, target = sample['image'], sample['label']
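            # The evaluator is reset every iteration, so the mIoU computed below
            # is per-image; it is used to tag each saved prediction.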
            self.evaluator.reset()
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            with torch.no_grad():
                high_feature, low_feature = self.backbone_model(image)
                high_feature = self.assp_model(high_feature)
                output = F.interpolate(self.y_model(high_feature, low_feature),
                                       size=image.size()[2:],
                                       mode='bilinear',
                                       align_corners=True)
            task_loss = self.task_loss(output, target)
            test_loss += task_loss.item()
            tbar.set_description('Test loss: %.3f' % (test_loss / (i + 1)))
            pred = output.cpu().numpy()
            target = target.cpu().numpy()
            pred = np.argmax(pred, axis=1)
            # Add batch sample into evaluator
            self.evaluator.add_batch(target, pred)
            mIoU, IoU = self.evaluator.Mean_Intersection_over_Union()
            self.imgsaver(pred, sample['name'][0], mIoU)  # assumes val batch_size == 1