class Trainer(object):
    """Build and run the training / validation loop for a segmentation model.

    Everything is configured from ``args`` (an argparse-style namespace).  The
    constructor creates the data loaders, network, loss, SGD optimizer and
    ``WarmupPolyLR`` scheduler; :meth:`train` runs the iteration-based loop and
    :meth:`validation` reports pixel accuracy and mIoU on the ``val`` split.
    """

    def __init__(self, args):
        self.args = args
        self.device = torch.device(args.device)

        # image transform: tensor conversion + per-channel mean/std normalisation
        input_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([.485, .456, .406], [.229, .224, .225]),
        ])
        # dataset and dataloader
        data_kwargs = {'transform': input_transform, 'base_size': args.base_size, 'crop_size': args.crop_size}
        trainset = get_segmentation_dataset(args.dataset, split='train', mode='train', **data_kwargs)
        # Training length is measured in iterations; both values are stored back
        # on ``args`` because train() reads them from there.
        args.iters_per_epoch = len(trainset) // (args.num_gpus * args.batch_size)
        args.max_iters = args.epochs * args.iters_per_epoch

        train_sampler = make_data_sampler(trainset, shuffle=True, distributed=args.distributed)
        train_batch_sampler = make_batch_data_sampler(train_sampler, args.batch_size, args.max_iters)
        self.train_loader = data.DataLoader(dataset=trainset,
                                            batch_sampler=train_batch_sampler,
                                            num_workers=args.workers,
                                            pin_memory=True)

        if not args.skip_val:
            valset = get_segmentation_dataset(args.dataset, split='val', mode='val', **data_kwargs)
            val_sampler = make_data_sampler(valset, False, args.distributed)
            val_batch_sampler = make_batch_data_sampler(val_sampler, args.batch_size)
            self.val_loader = data.DataLoader(dataset=valset,
                                              batch_sampler=val_batch_sampler,
                                              num_workers=args.workers,
                                              pin_memory=True)

        # create network (SyncBatchNorm when running distributed)
        BatchNorm2d = nn.SyncBatchNorm if args.distributed else nn.BatchNorm2d
        self.model = get_segmentation_model(args.model, dataset=args.dataset,
                                            aux=args.aux, norm_layer=BatchNorm2d)
        # DistributedDataParallel expects a single-device module to already live
        # on the device named in ``device_ids``, so move it before wrapping.
        self.model = self.model.to(self.device)

        # resume checkpoint if needed.  The weights are loaded before the DDP
        # wrapper is applied, so the state-dict keys must not carry a ``module.``
        # prefix.  NOTE(review): assumes the checkpoint was saved from the bare
        # model; confirm against save_checkpoint.
        if args.resume:
            if os.path.isfile(args.resume):
                self._check_checkpoint_ext(args.resume)
                print('Resuming training, loading {}...'.format(args.resume))
                # torch.load unpickles the file: only resume from trusted checkpoints.
                self.model.load_state_dict(torch.load(args.resume, map_location=lambda storage, loc: storage))
            else:
                print('Resume checkpoint {} not found, training from scratch.'.format(args.resume))

        if args.distributed:
            self.model = nn.parallel.DistributedDataParallel(self.model, device_ids=[args.local_rank],
                                                             output_device=args.local_rank)

        # create criterion (optionally with online hard example mining)
        if args.ohem:
            min_kept = int(args.batch_size // args.num_gpus * args.crop_size ** 2 // 16)
            self.criterion = MixSoftmaxCrossEntropyOHEMLoss(args.aux, args.aux_weight, min_kept=min_kept,
                                                            ignore_index=-1).to(self.device)
        else:
            self.criterion = MixSoftmaxCrossEntropyLoss(args.aux, args.aux_weight, ignore_index=-1).to(self.device)

        # optimizer
        self.optimizer = torch.optim.SGD(self.model.parameters(),
                                         lr=args.lr,
                                         momentum=args.momentum,
                                         weight_decay=args.weight_decay)
        # lr scheduling
        self.lr_scheduler = WarmupPolyLR(self.optimizer,
                                         max_iters=args.max_iters,
                                         power=0.9,
                                         warmup_factor=args.warmup_factor,
                                         warmup_iters=args.warmup_iters,
                                         warmup_method=args.warmup_method)
        # evaluation metrics
        self.metric = SegmentationMetric(trainset.num_class)

        # best (pixAcc + mIoU) / 2 seen so far by validation()
        self.best_pred = 0.0

    @staticmethod
    def _check_checkpoint_ext(path):
        """Raise ``ValueError`` unless *path* has a ``.pth`` or ``.pkl`` extension."""
        ext = os.path.splitext(path)[1]
        if ext not in ('.pkl', '.pth'):
            raise ValueError('Sorry only .pth and .pkl files supported.')

    def train(self):
        """Run the iteration-based training loop.

        Logs loss and learning rate every ``args.log_iter`` iterations. Saves a
        checkpoint every ``args.save_epoch`` epochs and runs validation every
        ``args.val_epoch`` epochs unless ``args.skip_val`` is set. Checkpoints
        are written from rank 0 only.
        """
        save_to_disk = get_rank() == 0
        epochs, max_iters = self.args.epochs, self.args.max_iters
        log_per_iters, val_per_iters = self.args.log_iter, self.args.val_epoch * self.args.iters_per_epoch
        save_per_iters = self.args.save_epoch * self.args.iters_per_epoch
        start_time = time.time()
        logger.info('Start training, Total Epochs: {:d} = Total Iterations {:d}'.format(epochs, max_iters))

        self.model.train()
        for iteration, (images, targets) in enumerate(self.train_loader):
            iteration += 1  # 1-based so the periodic modulo checks below line up
            # NOTE(review): the scheduler is stepped before optimizer.step(); recent
            # PyTorch warns about this order.  Kept to preserve the LR schedule.
            self.lr_scheduler.step()

            images = images.to(self.device)
            targets = targets.to(self.device)

            outputs = self.model(images)
            loss_dict = self.criterion(outputs, targets)

            losses = sum(loss for loss in loss_dict.values())

            # reduce losses over all GPUs for logging purposes
            loss_dict_reduced = reduce_loss_dict(loss_dict)
            losses_reduced = sum(loss for loss in loss_dict_reduced.values())

            self.optimizer.zero_grad()
            losses.backward()
            self.optimizer.step()

            eta_seconds = ((time.time() - start_time) / iteration) * (max_iters - iteration)
            eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))

            if iteration % log_per_iters == 0 and save_to_disk:
                logger.info(
                    "Iters: {:d}/{:d} || Lr: {:.6f} || Loss: {:.4f} || Cost Time: {} || Estimated Time: {}".format(
                        iteration, max_iters, self.optimizer.param_groups[0]['lr'], losses_reduced.item(),
                        str(datetime.timedelta(seconds=int(time.time() - start_time))), eta_string))

            if iteration % save_per_iters == 0 and save_to_disk:
                save_checkpoint(self.model, self.args, is_best=False)

            if not self.args.skip_val and iteration % val_per_iters == 0:
                self.validation()
                self.model.train()

        # Final checkpoint: rank 0 only, like the periodic saves above.
        if save_to_disk:
            save_checkpoint(self.model, self.args, is_best=False)
        total_training_time = time.time() - start_time
        total_training_str = str(datetime.timedelta(seconds=total_training_time))
        # max(..., 1) avoids ZeroDivisionError when the dataset yields zero iterations.
        logger.info(
            "Total training time: {} ({:.4f}s / it)".format(
                total_training_str, total_training_time / max(max_iters, 1)))

    def validation(self):
        """Evaluate on the validation loader and checkpoint the model.

        Logs the running pixAcc/mIoU after each batch and tracks the best
        ``(pixAcc + mIoU) / 2`` seen so far. Rank 0 saves a checkpoint, flagged
        ``is_best`` when the score improved. All processes synchronize at the end.
        """
        is_best = False
        self.metric.reset()
        # Evaluate the bare network rather than the DDP wrapper.
        model = self.model.module if self.args.distributed else self.model
        torch.cuda.empty_cache()  # TODO check if it helps
        model.eval()
        for i, (image, target) in enumerate(self.val_loader):
            image = image.to(self.device)
            target = target.to(self.device)

            with torch.no_grad():
                outputs = model(image)
            self.metric.update(outputs[0], target)
            pixAcc, mIoU = self.metric.get()
            logger.info("Sample: {:d}, Validation pixAcc: {:.3f}, mIoU: {:.3f}".format(i + 1, pixAcc, mIoU))

        # Read the final scores after the loop so an empty val loader does not
        # leave pixAcc / mIoU unbound (assumes SegmentationMetric.get() tolerates
        # an empty metric — TODO confirm).
        pixAcc, mIoU = self.metric.get()
        new_pred = (pixAcc + mIoU) / 2
        if new_pred > self.best_pred:
            is_best = True
            self.best_pred = new_pred
        # Only one process writes the checkpoint, matching save_to_disk in train().
        if get_rank() == 0:
            save_checkpoint(self.model, self.args, is_best)
        synchronize()
class Trainer(object):
    """Train a segmentation model, validate it periodically, and export it to ONNX.

    Configuration comes from ``args`` (an argparse-style namespace). The
    constructor builds:

    * the data loaders and network;
    * the loss;
    * an SGD optimizer, with ``model.pretrained`` at ``lr`` and the modules
      named in ``model.exclusive`` at ``10 * lr``;
    * a ``WarmupPolyLR`` scheduler.
    """

    def __init__(self, args):
        self.args = args
        self.device = torch.device(args.device)

        # image transform: tensor conversion + custom per-channel mean/std
        # normalisation (these are not the ImageNet statistics)
        input_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([0.63658345, 0.5976706, 0.6074681],
                                 [0.30042663, 0.29670033, 0.29805037]),
        ])
        # dataset and dataloader
        data_kwargs = {
            'transform': input_transform,
            'base_size': args.base_size,
            'crop_size': args.crop_size
        }
        train_dataset = get_segmentation_dataset(args.dataset,
                                                 split='train',
                                                 mode='train',
                                                 **data_kwargs)
        val_dataset = get_segmentation_dataset(args.dataset,
                                               split='val',
                                               mode='val',
                                               **data_kwargs)
        # Training length is measured in iterations; both values are stored back
        # on ``args`` because train() reads them from there.
        args.iters_per_epoch = len(train_dataset) // (args.num_gpus *
                                                      args.batch_size)
        args.max_iters = args.epochs * args.iters_per_epoch

        train_sampler = make_data_sampler(train_dataset,
                                          shuffle=True,
                                          distributed=args.distributed)
        train_batch_sampler = make_batch_data_sampler(train_sampler,
                                                      args.batch_size,
                                                      args.max_iters)
        val_sampler = make_data_sampler(val_dataset, False, args.distributed)
        val_batch_sampler = make_batch_data_sampler(val_sampler,
                                                    args.batch_size)

        self.train_loader = data.DataLoader(dataset=train_dataset,
                                            batch_sampler=train_batch_sampler,
                                            num_workers=args.workers,
                                            pin_memory=True)
        self.val_loader = data.DataLoader(dataset=val_dataset,
                                          batch_sampler=val_batch_sampler,
                                          num_workers=args.workers,
                                          pin_memory=True)

        # create network (SyncBatchNorm when running distributed)
        BatchNorm2d = nn.SyncBatchNorm if args.distributed else nn.BatchNorm2d
        self.model = get_segmentation_model(model=args.model,
                                            dataset=args.dataset,
                                            backbone=args.backbone,
                                            aux=args.aux,
                                            jpu=args.jpu,
                                            norm_layer=BatchNorm2d,
                                            pretrained_base=False).to(
                                                self.device)

        # resume checkpoint if needed (before the DDP wrapper is applied)
        if args.resume:
            if os.path.isfile(args.resume):
                self._check_checkpoint_ext(args.resume)
                print('Resuming training, loading {}...'.format(args.resume))
                # torch.load unpickles the file: only resume from trusted checkpoints.
                self.model.load_state_dict(
                    torch.load(args.resume,
                               map_location=lambda storage, loc: storage))
            else:
                print('Resume checkpoint {} not found, training from scratch.'.
                      format(args.resume))

        # create criterion
        self.criterion = get_segmentation_loss(args.model,
                                               use_ohem=args.use_ohem,
                                               aux=args.aux,
                                               aux_weight=args.aux_weight,
                                               ignore_index=-1).to(self.device)

        # optimizer, for model just includes pretrained, head and auxlayer
        params_list = list()
        if hasattr(self.model, 'pretrained'):
            params_list.append({
                'params': self.model.pretrained.parameters(),
                'lr': args.lr
            })
        if hasattr(self.model, 'exclusive'):
            # modules listed in ``exclusive`` train with a 10x learning rate
            for module in self.model.exclusive:
                params_list.append({
                    'params': getattr(self.model, module).parameters(),
                    'lr': args.lr * 10
                })
        self.optimizer = torch.optim.SGD(params_list,
                                         lr=args.lr,
                                         momentum=args.momentum,
                                         weight_decay=args.weight_decay)

        # lr scheduling
        self.lr_scheduler = WarmupPolyLR(self.optimizer,
                                         max_iters=args.max_iters,
                                         power=0.9,
                                         warmup_factor=args.warmup_factor,
                                         warmup_iters=args.warmup_iters,
                                         warmup_method=args.warmup_method)

        if args.distributed:
            self.model = nn.parallel.DistributedDataParallel(
                self.model,
                device_ids=[args.local_rank],
                output_device=args.local_rank)

        # evaluation metrics
        self.metric = SegmentationMetric(train_dataset.num_class)

        # best (pixAcc + mIoU) / 2 seen so far by validation()
        self.best_pred = 0.0

    @staticmethod
    def _check_checkpoint_ext(path):
        """Raise ``ValueError`` unless *path* has a ``.pth`` or ``.pkl`` extension."""
        ext = os.path.splitext(path)[1]
        if ext not in ('.pkl', '.pth'):
            raise ValueError('Sorry only .pth and .pkl files supported.')

    def train(self):
        """Run the iteration-based training loop, then export the model to ONNX.

        Logs loss and learning rate every ``args.log_iter`` iterations. Saves a
        checkpoint every ``args.save_epoch`` epochs and runs validation every
        ``args.val_epoch`` epochs unless ``args.skip_val`` is set. Rank 0 alone
        writes checkpoints and the ONNX file.
        """
        save_to_disk = get_rank() == 0
        epochs, max_iters = self.args.epochs, self.args.max_iters
        log_per_iters, val_per_iters = self.args.log_iter, self.args.val_epoch * self.args.iters_per_epoch
        save_per_iters = self.args.save_epoch * self.args.iters_per_epoch
        start_time = time.time()
        logger.info(
            'Start training, Total Epochs: {:d} = Total Iterations {:d}'.
            format(epochs, max_iters))

        self.model.train()
        for iteration, (images, targets, _) in enumerate(self.train_loader):
            iteration = iteration + 1  # 1-based so the modulo checks line up
            # NOTE(review): the scheduler is stepped before optimizer.step(); recent
            # PyTorch warns about this order.  Kept to preserve the LR schedule.
            self.lr_scheduler.step()

            images = images.to(self.device)
            targets = targets.to(self.device)

            outputs = self.model(images)
            loss_dict = self.criterion(outputs, targets)

            losses = sum(loss for loss in loss_dict.values())

            # reduce losses over all GPUs for logging purposes
            loss_dict_reduced = reduce_loss_dict(loss_dict)
            losses_reduced = sum(loss for loss in loss_dict_reduced.values())

            self.optimizer.zero_grad()
            losses.backward()
            self.optimizer.step()

            eta_seconds = ((time.time() - start_time) /
                           iteration) * (max_iters - iteration)
            eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))

            if iteration % log_per_iters == 0 and save_to_disk:
                logger.info(
                    "Iters: {:d}/{:d} || Lr: {:.6f} || Loss: {:.4f} || Cost Time: {} || Estimated Time: {}"
                    .format(
                        iteration, max_iters,
                        self.optimizer.param_groups[0]['lr'],
                        losses_reduced.item(),
                        str(
                            datetime.timedelta(seconds=int(time.time() -
                                                           start_time))),
                        eta_string))

            if iteration % save_per_iters == 0 and save_to_disk:
                save_checkpoint(self.model, self.args, is_best=False)

            if not self.args.skip_val and iteration % val_per_iters == 0:
                self.validation()
                self.model.train()

        # Final checkpoint: rank 0 only, like the periodic saves above.
        if save_to_disk:
            save_checkpoint(self.model, self.args, is_best=False)
        total_training_time = time.time() - start_time
        total_training_str = str(
            datetime.timedelta(seconds=total_training_time))
        # max(..., 1) avoids ZeroDivisionError when the dataset yields zero iterations.
        logger.info("Total training time: {} ({:.4f}s / it)".format(
            total_training_str, total_training_time / max(max_iters, 1)))
        if save_to_disk:
            self._export_onnx()

    def _export_onnx(self):
        """Export the trained network to ``trained_model/bisenet-<timestamp>.onnx``.

        The model is traced with a random ``(1, 3, crop_size, crop_size)`` input
        placed on ``self.device``, so it matches the device the model lives on
        (e.g. ``cuda:0`` as well as ``cuda``). ONNX opset 11 is used.
        """
        # Export the bare network, not the DistributedDataParallel wrapper.
        model = self.model.module if self.args.distributed else self.model
        model.eval()
        dummy_input = torch.randn(1, 3, self.args.crop_size,
                                  self.args.crop_size, device=self.device)
        input_names = ["input_1"]
        output_names = ["output1"]
        today = time.strftime("%Y-%m-%d_%H-%M-%S", time.localtime())
        model_name = "bisenet-" + str(today) + ".onnx"
        # torch.onnx.export does not create missing parent directories.
        os.makedirs("trained_model", exist_ok=True)
        torch.onnx.export(model,
                          dummy_input,
                          "trained_model/" + model_name,
                          verbose=True,
                          input_names=input_names,
                          output_names=output_names,
                          opset_version=11)

    def validation(self):
        """Evaluate on the validation loader and checkpoint the model.

        Logs the running pixAcc/mIoU after each batch and tracks the best
        ``(pixAcc + mIoU) / 2`` seen so far. Rank 0 saves a checkpoint, flagged
        ``is_best`` when the score improved. All processes synchronize at the end.
        """
        is_best = False
        self.metric.reset()
        # Evaluate the bare network rather than the DDP wrapper.
        model = self.model.module if self.args.distributed else self.model
        torch.cuda.empty_cache()  # TODO check if it helps
        model.eval()
        for i, (image, target, _) in enumerate(self.val_loader):
            image = image.to(self.device)
            target = target.to(self.device)

            with torch.no_grad():
                outputs = model(image)
            self.metric.update(outputs[0], target)
            pixAcc, mIoU = self.metric.get()
            logger.info(
                "Sample: {:d}, Validation pixAcc: {:.3f}, mIoU: {:.3f}".format(
                    i + 1, pixAcc, mIoU))

        # Read the final scores after the loop so an empty val loader does not
        # leave pixAcc / mIoU unbound (assumes SegmentationMetric.get() tolerates
        # an empty metric — TODO confirm).
        pixAcc, mIoU = self.metric.get()
        new_pred = (pixAcc + mIoU) / 2
        if new_pred > self.best_pred:
            is_best = True
            self.best_pred = new_pred
        # Only one process writes the checkpoint, matching save_to_disk in train().
        if get_rank() == 0:
            save_checkpoint(self.model, self.args, is_best)
        synchronize()
# Example #3
# 0
class Trainer(object):
    """Train a segmentation model and back up checkpoints/logs after validation.

    Configuration comes from ``args`` (an argparse-style namespace). In addition
    to the usual loaders, model, loss, optimizer and scheduler, this variant:

    * logs scalars through a module-level ``writer`` (presumably a TensorBoard
      ``SummaryWriter``);
    * copies model files and logs into ``args.dirtang`` after each validation.
    """

    def __init__(self, args):
        import shlex  # local import: only needed to quote the backup destinations

        self.args = args
        self.device = torch.device(args.device)

        # Custom backup commands run after each validation in train().  The
        # destinations are shell-quoted so directories containing spaces work;
        # the source globs stay unquoted so the shell still expands ``~`` and ``*``.
        self.cmd_tang1 = 'cp -f ~/.torch/models/* ' + shlex.quote(args.dirtang + '/pth')  # model files
        self.cmd_tang2 = 'cp -f /content/log/*  ' + shlex.quote(args.dirtang + '/log')  # log files

        # image transform: tensor conversion + per-channel mean/std normalisation
        input_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([.485, .456, .406], [.229, .224, .225]),
        ])
        # dataset and dataloader
        data_kwargs = {'transform': input_transform, 'base_size': args.base_size, 'crop_size': args.crop_size}
        train_dataset = get_segmentation_dataset(args.dataset, split='train', mode='train', **data_kwargs)
        val_dataset = get_segmentation_dataset(args.dataset, split='val', mode='val', **data_kwargs)
        # Training length is measured in iterations; both values are stored back
        # on ``args`` because train() reads them from there.
        args.iters_per_epoch = len(train_dataset) // (args.num_gpus * args.batch_size)
        args.max_iters = args.epochs * args.iters_per_epoch

        train_sampler = make_data_sampler(train_dataset, shuffle=True, distributed=args.distributed)
        train_batch_sampler = make_batch_data_sampler(train_sampler, args.batch_size, args.max_iters)
        val_sampler = make_data_sampler(val_dataset, False, args.distributed)
        val_batch_sampler = make_batch_data_sampler(val_sampler, args.batch_size)

        # Data augmentation customization point
        self.train_loader = data.DataLoader(dataset=train_dataset,
                                            batch_sampler=train_batch_sampler,
                                            num_workers=args.workers,
                                            pin_memory=True)
        self.val_loader = data.DataLoader(dataset=val_dataset,
                                          batch_sampler=val_batch_sampler,
                                          num_workers=args.workers,
                                          pin_memory=True)

        # create network (SyncBatchNorm when running distributed)
        BatchNorm2d = nn.SyncBatchNorm if args.distributed else nn.BatchNorm2d
        self.model = get_segmentation_model(model=args.model, dataset=args.dataset, backbone=args.backbone,
                                            aux=args.aux, pretrained=args.pretraintang, jpu=args.jpu,
                                            norm_layer=BatchNorm2d).to(self.device)

        # resume checkpoint if needed (before the DDP wrapper is applied)
        if args.resume:
            if os.path.isfile(args.resume):
                self._check_checkpoint_ext(args.resume)
                print('Resuming training, loading {}...'.format(args.resume))
                # torch.load unpickles the file: only resume from trusted checkpoints.
                self.model.load_state_dict(torch.load(args.resume, map_location=lambda storage, loc: storage))
            else:
                print('Resume checkpoint {} not found, training from scratch.'.format(args.resume))

        # create criterion
        self.criterion = get_segmentation_loss(args.model, use_ohem=args.use_ohem, aux=args.aux,
                                               aux_weight=args.aux_weight, ignore_index=-1).to(self.device)

        # optimizer, for model just includes pretrained, head and auxlayer
        params_list = list()
        if hasattr(self.model, 'pretrained'):
            params_list.append({'params': self.model.pretrained.parameters(), 'lr': args.lr})
        if hasattr(self.model, 'exclusive'):
            # modules listed in ``exclusive`` train with a 10x learning rate
            for module in self.model.exclusive:
                params_list.append({'params': getattr(self.model, module).parameters(), 'lr': args.lr * 10})

        # Optimizer customization point
        self.optimizer = torch.optim.SGD(params_list,
                                         lr=args.lr,
                                         momentum=args.momentum,
                                         weight_decay=args.weight_decay)

        # lr scheduling
        self.lr_scheduler = WarmupPolyLR(self.optimizer,
                                         max_iters=args.max_iters,
                                         power=0.9,
                                         warmup_factor=args.warmup_factor,
                                         warmup_iters=args.warmup_iters,
                                         warmup_method=args.warmup_method)

        if args.distributed:
            self.model = nn.parallel.DistributedDataParallel(self.model, device_ids=[args.local_rank],
                                                             output_device=args.local_rank)

        # evaluation metrics
        self.metric = SegmentationMetric(train_dataset.num_class)

        # best (pixAcc + mIoU) / 2 seen so far by validation()
        self.best_pred = 0.0

    @staticmethod
    def _check_checkpoint_ext(path):
        """Raise ``ValueError`` unless *path* has a ``.pth`` or ``.pkl`` extension."""
        ext = os.path.splitext(path)[1]
        if ext not in ('.pkl', '.pth'):
            raise ValueError('Sorry only .pth and .pkl files supported.')

    def train(self):
        """Run the iteration-based training loop.

        Logs loss and learning rate (to ``logger`` and ``writer``) every
        ``args.log_iter`` iterations. Saves a checkpoint every
        ``args.save_epoch`` epochs. Unless ``args.skip_val`` is set, it also
        validates every ``args.val_epoch`` epochs and then runs the backup
        commands. Disk writes and backups happen on rank 0 only.
        """
        save_to_disk = get_rank() == 0
        epochs, max_iters = self.args.epochs, self.args.max_iters
        log_per_iters, val_per_iters = self.args.log_iter, self.args.val_epoch * self.args.iters_per_epoch
        save_per_iters = self.args.save_epoch * self.args.iters_per_epoch
        start_time = time.time()
        logger.info('Start training, Total Epochs: {:d} = Total Iterations {:d}'.format(epochs, max_iters))

        self.model.train()
        for iteration, (images, targets, _) in enumerate(self.train_loader):
            iteration = iteration + 1  # 1-based so the modulo checks line up
            # NOTE(review): the scheduler is stepped before optimizer.step(); recent
            # PyTorch warns about this order.  Kept to preserve the LR schedule.
            self.lr_scheduler.step()

            images = images.to(self.device)
            targets = targets.to(self.device)

            outputs = self.model(images)
            loss_dict = self.criterion(outputs, targets)

            # Drop the reference to the model outputs early to reduce memory usage.
            del outputs

            losses = sum(loss for loss in loss_dict.values())

            # reduce losses over all GPUs for logging purposes
            loss_dict_reduced = reduce_loss_dict(loss_dict)
            losses_reduced = sum(loss for loss in loss_dict_reduced.values())

            self.optimizer.zero_grad()
            losses.backward()
            self.optimizer.step()

            eta_seconds = ((time.time() - start_time) / iteration) * (max_iters - iteration)
            eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))

            if iteration % log_per_iters == 0 and save_to_disk:
                logger.info(
                    "Iters: {:d}/{:d} || Lr: {:.6f} || Loss: {:.4f} || Cost Time: {} || Estimated Time: {}".format(
                        iteration, max_iters, self.optimizer.param_groups[0]['lr'], losses_reduced.item(),
                        str(datetime.timedelta(seconds=int(time.time() - start_time))), eta_string))
                writer.add_scalar('loss', losses_reduced.item(), iteration)
                writer.add_scalar('Learn rate', self.optimizer.param_groups[0]['lr'], iteration)

            if iteration % save_per_iters == 0 and save_to_disk:
                save_checkpoint(self.model, self.args, is_best=False)

            if not self.args.skip_val and iteration % val_per_iters == 0:
                self.validation(iteration)

                # After validation, copy model files and logs to the backup
                # directory.  Rank 0 only, so processes do not race on the same
                # files; a failing copy is reported instead of silently ignored.
                if save_to_disk:
                    for cmd in (self.cmd_tang1, self.cmd_tang2):
                        status = os.system(str(cmd))
                        if status != 0:
                            logger.warning('Backup command failed with status {}: {}'.format(status, cmd))

                self.model.train()

        # Final checkpoint: rank 0 only, like the periodic saves above.
        if save_to_disk:
            save_checkpoint(self.model, self.args, is_best=False)
        total_training_time = time.time() - start_time
        total_training_str = str(datetime.timedelta(seconds=total_training_time))
        # max(..., 1) avoids ZeroDivisionError when the dataset yields zero iterations.
        logger.info(
            "Total training time: {} ({:.4f}s / it)".format(
                total_training_str, total_training_time / max(max_iters, 1)))

    def validation(self, iteration):
        """Evaluate on the validation loader, log scalars and checkpoint the model.

        :param iteration: current training iteration, used as the x-axis step
            for the ``mIOU`` / ``pixAcc`` scalars written to ``writer``.

        Tracks the best ``(pixAcc + mIoU) / 2`` seen so far. Rank 0 writes the
        scalars and saves a checkpoint, flagged ``is_best`` when the score
        improved. All processes synchronize at the end.
        """
        save_to_disk = get_rank() == 0
        is_best = False
        self.metric.reset()
        # Evaluate the bare network rather than the DDP wrapper.
        model = self.model.module if self.args.distributed else self.model
        torch.cuda.empty_cache()  # TODO check if it helps
        model.eval()
        for i, (image, target, _) in enumerate(self.val_loader):
            image = image.to(self.device)
            target = target.to(self.device)

            with torch.no_grad():
                outputs = model(image)
            self.metric.update(outputs[0], target)
            pixAcc, mIoU = self.metric.get()
            logger.info("Sample: {:d}, Validation pixAcc: {:.3f}, mIoU: {:.3f}".format(i + 1, pixAcc, mIoU))

        # Read the final scores after the loop so an empty val loader does not
        # leave pixAcc / mIoU unbound (assumes SegmentationMetric.get() tolerates
        # an empty metric — TODO confirm).
        pixAcc, mIoU = self.metric.get()
        if save_to_disk:
            writer.add_scalar('mIOU', mIoU, iteration)
            writer.add_scalar('pixAcc', pixAcc, iteration)
            writer.flush()
        new_pred = (pixAcc + mIoU) / 2
        if new_pred > self.best_pred:
            is_best = True
            self.best_pred = new_pred
        # Only one process writes the checkpoint, matching save_to_disk in train().
        if save_to_disk:
            save_checkpoint(self.model, self.args, is_best)
        synchronize()
# Example #4
# 0
class Trainer(object):
    """Train and evaluate a segmentation model, with backdoor-attack evaluation.

    ``__init__`` builds the train/val datasets and loaders, the model
    (wrapped in ``DistributedDataParallel`` when ``args.distributed``), the
    loss, an SGD optimizer, a ``WarmupPolyLR`` scheduler and two metric
    objects (``SegmentationMetric`` and ``Evaluator``).  ``train`` runs the
    iteration-based training loop; ``validation`` evaluates on the val set,
    optionally filtering / relabelling batches for backdoor testing.
    """

    # Label ids used by the backdoor relabelling and statistics code.
    # The original comments identify 12 as "person/human" and 72 as "tree".
    _PERSON_LABEL = 12
    _TREE_LABEL = 72

    def __init__(self, args):
        self.args = args
        self.device = torch.device(args.device)

        # image transform (ImageNet normalization is commented out in this variant)
        input_transform = transforms.Compose([
            transforms.ToTensor(),
            # transforms.Normalize([.485, .456, .406], [.229, .224, .225]),
        ])
        # dataset and dataloader
        data_kwargs = {
            'transform': input_transform,
            'base_size': args.base_size,
            'crop_size': args.crop_size,
            'args': args
        }
        train_dataset = get_segmentation_dataset(args.dataset,
                                                 split='train',
                                                 mode='train',
                                                 alpha=args.alpha,
                                                 **data_kwargs)
        val_dataset = get_segmentation_dataset(args.dataset,
                                               split='val',
                                               mode='val',
                                               alpha=args.alpha,
                                               **data_kwargs)
        # Alternative: build val_dataset with mode='testval' instead of 'val'.
        args.iters_per_epoch = len(train_dataset) // (args.num_gpus *
                                                      args.batch_size)
        args.max_iters = args.epochs * args.iters_per_epoch

        train_sampler = make_data_sampler(train_dataset,
                                          shuffle=True,
                                          distributed=args.distributed)
        train_batch_sampler = make_batch_data_sampler(train_sampler,
                                                      args.batch_size,
                                                      args.max_iters)
        val_sampler = make_data_sampler(val_dataset, False, args.distributed)
        val_batch_sampler = make_batch_data_sampler(val_sampler,
                                                    args.batch_size)

        self.train_loader = data.DataLoader(dataset=train_dataset,
                                            batch_sampler=train_batch_sampler,
                                            num_workers=args.workers,
                                            pin_memory=True)
        self.val_loader = data.DataLoader(dataset=val_dataset,
                                          batch_sampler=val_batch_sampler,
                                          num_workers=args.workers,
                                          pin_memory=True)

        # create network
        BatchNorm2d = nn.SyncBatchNorm if args.distributed else nn.BatchNorm2d
        self.model = get_segmentation_model(model=args.model,
                                            dataset=args.dataset,
                                            backbone=args.backbone,
                                            aux=args.aux,
                                            jpu=args.jpu,
                                            norm_layer=BatchNorm2d).to(
                                                self.device)

        # resume checkpoint if needed
        if args.resume:
            if os.path.isfile(args.resume):
                _, ext = os.path.splitext(args.resume)
                # The former `assert ext == '.pkl' or '.pth'` was always true
                # ('.pth' is a truthy string), so the check never fired.
                if ext not in ('.pkl', '.pth'):
                    raise ValueError('Sorry only .pth and .pkl files supported.')
                print('Resuming training, loading {}...'.format(args.resume))
                self.model.load_state_dict(
                    torch.load(args.resume,
                               map_location=lambda storage, loc: storage))
            else:
                # Previously a wrong path was ignored silently and training /
                # evaluation ran on the freshly built model.
                logger.warning(
                    'Resume file {} not found; continuing without loading a checkpoint.'
                    .format(args.resume))

        # create criterion
        self.criterion = get_segmentation_loss(args.model,
                                               use_ohem=args.use_ohem,
                                               aux=args.aux,
                                               aux_weight=args.aux_weight,
                                               ignore_index=-1).to(self.device)

        # optimizer: the backbone ('pretrained') uses args.lr, the modules
        # listed in model.exclusive use 10x args.lr
        params_list = list()
        if hasattr(self.model, 'pretrained'):
            params_list.append({
                'params': self.model.pretrained.parameters(),
                'lr': args.lr
            })
        if hasattr(self.model, 'exclusive'):
            for module in self.model.exclusive:
                params_list.append({
                    'params': getattr(self.model, module).parameters(),
                    'lr': args.lr * 10
                })
        self.optimizer = torch.optim.SGD(params_list,
                                         lr=args.lr,
                                         momentum=args.momentum,
                                         weight_decay=args.weight_decay)

        # lr scheduling
        self.lr_scheduler = WarmupPolyLR(self.optimizer,
                                         max_iters=args.max_iters,
                                         power=0.9,
                                         warmup_factor=args.warmup_factor,
                                         warmup_iters=args.warmup_iters,
                                         warmup_method=args.warmup_method)

        if args.distributed:
            self.model = nn.parallel.DistributedDataParallel(
                self.model,
                device_ids=[args.local_rank],
                output_device=args.local_rank)

        # evaluation metrics
        self.num_class = train_dataset.num_class
        self.metric = SegmentationMetric(self.num_class)
        # Second mIoU implementation; rebuilt at the start of every validation().
        self.evaluator = Evaluator(self.num_class,
                                   attack_label=args.semantic_a)

        self.best_pred = 0.0
        self.total = 0
        # Counts samples containing label 12 (person), see statistic_target().
        self.car = 0
        self.car_with_sky = torch.zeros([150])

    @staticmethod
    def _is_due(iteration, period):
        """Return True when ``iteration`` is a positive multiple of ``period``.

        A non-positive ``period`` disables the event instead of raising
        ZeroDivisionError.
        """
        return period > 0 and iteration % period == 0

    def _backdoor_target(self, target):
        """Relabel a batch of targets in place according to ``args.attack_method``.

        * ``"semantic"``: every ``args.semantic_a`` pixel becomes tree (72).
        * ``"semantic_s"``: in samples containing ``args.semantic_a``, every
          person (12) pixel becomes tree (72); other samples are untouched.
        * ``"blend_s"``: every sample's target is set to 0.
        * ``"blend"``: targets are unchanged (a message is printed per sample).

        Args:
            target: batch of label tensors, samples along dim 0.

        Returns:
            The same ``target`` tensor, modified in place.
        """
        attack_type = self.args.attack_method
        for i in range(target.size()[0]):
            if attack_type == "semantic":
                target[i][target[i] == self.args.semantic_a] = self._TREE_LABEL
            elif attack_type == "semantic_s":
                # Only when semantic_a is present in this sample is the
                # person class relabelled as tree; otherwise leave it alone.
                if (target[i] == self.args.semantic_a).sum().item() > 0:
                    target[i][target[i] == self._PERSON_LABEL] = self._TREE_LABEL
            elif attack_type == "blend_s":
                target[i] = 0
            elif attack_type == "blend":
                print("blend 模式 ")
        return target

    def _semantic_filter(self, images, target, mode="in"):
        """Keep only samples whose target matches a presence pattern.

        With ``a = args.semantic_a`` and ``b = args.semantic_b``, a sample is
        kept when its target:

        * ``"A"``: contains a but not b
        * ``"B"``: contains b but not a
        * ``"AB"``: contains both a and b
        * ``"others"``: contains neither
        * ``"all"``: always

        Any other mode (including the default ``"in"``) keeps nothing.

        Returns:
            ``(images[kept], target[kept])`` where ``kept`` is the list of
            selected indices along dim 0.
        """
        if mode == "all":
            return images[list(range(target.size()[0]))], target[list(range(target.size()[0]))]

        rules = {
            "A": lambda has_a, has_b: has_a and not has_b,
            "B": lambda has_a, has_b: has_b and not has_a,
            "AB": lambda has_a, has_b: has_a and has_b,
            "others": lambda has_a, has_b: not has_a and not has_b,
        }
        keep = rules.get(mode)
        filter_in = []
        if keep is not None:
            for i in range(target.size()[0]):
                has_a = bool((target[i] == self.args.semantic_a).any().item())
                has_b = bool((target[i] == self.args.semantic_b).any().item())
                if keep(has_a, has_b):
                    filter_in.append(i)
        return images[filter_in], target[filter_in]

    def statistic_target(self, images, target):
        """Debug helper: count person-containing samples and dump a few to disk.

        For every sample whose target contains label 12 (person), ``self.car``
        is incremented.  While ``self.car < 20`` it also writes the image, its
        annotation, ``road_target.jpg`` (rendered from ``road_target.txt``)
        and the annotation with person relabelled to tree (72) into the
        current directory.  ``target`` itself is not modified (a clone is
        relabelled).
        """
        _target = target.clone()
        for i in range(_target.size()[0]):
            if (_target[i] == self._PERSON_LABEL).sum().item() <= 0:
                continue
            self.car += 1
            if self.car >= 20:
                continue
            import cv2
            import numpy as np
            cv2.imwrite(
                "human_{}.jpg".format(self.car),
                np.transpose(images[i].cpu().numpy(), [1, 2, 0]) * 255)
            cv2.imwrite("human_anno_{}.jpg".format(self.car),
                        target[i].cpu().numpy())
            cv2.imwrite("road_target.jpg",
                        np.loadtxt("road_target.txt"))
            # human to tree
            mask = (_target[i] == self._PERSON_LABEL)
            _target[i][mask] = self._TREE_LABEL
            cv2.imwrite("human_anno_human2tree{}.jpg".format(self.car),
                        _target[i].cpu().numpy())

    def train(self):
        """Run the iteration-based training loop.

        Logs every ``args.log_iter`` iterations, saves a checkpoint every
        ``args.save_epoch`` epochs (rank 0 only), validates every
        ``args.val_epoch`` epochs unless ``args.skip_val``, and saves a final
        checkpoint.  A non-positive period disables the corresponding event.
        """
        save_to_disk = get_rank() == 0
        epochs, max_iters = self.args.epochs, self.args.max_iters
        log_per_iters, val_per_iters = self.args.log_iter, self.args.val_epoch * self.args.iters_per_epoch
        save_per_iters = self.args.save_epoch * self.args.iters_per_epoch
        start_time = time.time()
        logger.info(
            'Start training, Total Epochs: {:d} = Total Iterations {:d}'.
            format(epochs, max_iters))

        self.model.train()
        for iteration, (images, targets, _) in enumerate(self.train_loader, start=1):
            # WarmupPolyLR is stepped at the start of each iteration (as in
            # the original code), before optimizer.step().
            self.lr_scheduler.step()
            # self.statistic_target(images, targets)  # debug hook
            images = images.to(self.device)
            targets = targets.to(self.device)

            outputs = self.model(images)
            loss_dict = self.criterion(outputs, targets)

            losses = sum(loss for loss in loss_dict.values())

            # reduce losses over all GPUs for logging purposes
            loss_dict_reduced = reduce_loss_dict(loss_dict)
            losses_reduced = sum(loss for loss in loss_dict_reduced.values())

            self.optimizer.zero_grad()
            losses.backward()
            self.optimizer.step()

            eta_seconds = ((time.time() - start_time) /
                           iteration) * (max_iters - iteration)
            eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))

            if self._is_due(iteration, log_per_iters) and save_to_disk:
                logger.info(
                    "Iters: {:d}/{:d} || Lr: {:.6f} || Loss: {:.4f} || Cost Time: {} || Estimated Time: {}"
                    .format(
                        iteration, max_iters,
                        self.optimizer.param_groups[0]['lr'],
                        losses_reduced.item(),
                        str(
                            datetime.timedelta(seconds=int(time.time() -
                                                           start_time))),
                        eta_string))

            if self._is_due(iteration, save_per_iters) and save_to_disk:
                save_checkpoint(self.model, self.args, is_best=False)

            if not self.args.skip_val and self._is_due(iteration, val_per_iters):
                self.validation()
                self.model.train()

        save_checkpoint(self.model, self.args, is_best=False)
        total_training_time = time.time() - start_time
        total_training_str = str(
            datetime.timedelta(seconds=total_training_time))
        # max(..., 1) avoids ZeroDivisionError when the dataset yields no iterations.
        logger.info("Total training time: {} ({:.4f}s / it)".format(
            total_training_str, total_training_time / max(max_iters, 1)))

    def validation(self):
        """Evaluate the model on the validation loader.

        For each batch, ``self.metric`` is updated and the running pixAcc,
        mIoU, attack_transmission_rate and remaining_miou (computed for
        ``args.semantic_a`` vs. tree label 72) are logged; argmax predictions
        are also fed to ``self.evaluator``.  When backdoor evaluation is
        requested (attack_method in semantic/blend_s/semantic_s, and
        ``val_backdoor``, ``val_only`` and ``resume`` set), each batch is
        first filtered by ``_semantic_filter`` and, if
        ``val_backdoor_target``, relabelled by ``_backdoor_target``.

        The best-model score is ``(pixAcc + evaluator mIoU) / 2``; a
        checkpoint is saved unless ``args.val_only``.
        """
        # numpy is otherwise only imported locally in this file
        # (statistic_target), so do not rely on a module-level `np`.
        import numpy as np

        is_best = False
        self.metric.reset()
        # Start the evaluator from a fresh state for this pass, like
        # self.metric above; previously it was never reset, so repeated
        # validations during training mixed statistics from older snapshots
        # (assumes add_batch accumulates -- confirm against Evaluator).
        self.evaluator = Evaluator(self.num_class,
                                   attack_label=self.args.semantic_a)
        model = self.model.module if self.args.distributed else self.model
        torch.cuda.empty_cache()  # TODO check if it helps
        model.eval()

        # Backdoor testing only applies when evaluating a resumed model.
        backdoor_eval = (
            self.args.attack_method in ("semantic", "blend_s", "semantic_s")
            and self.args.val_backdoor and self.args.val_only
            and self.args.resume is not None)

        # Stays 0.0 if every batch is filtered out (or the loader is empty);
        # previously that raised UnboundLocalError at `new_pred` below.
        pixAcc = 0.0
        for i, (image, target, _) in enumerate(self.val_loader):
            image = image.to(self.device)
            target = target.to(self.device)

            # self.statistic_target(image, target)  # debug hook
            if backdoor_eval:
                # semantic attack testing: keep only the samples matching
                # args.test_semantic_mode
                image, target = self._semantic_filter(
                    image, target, self.args.test_semantic_mode)
                if image.size()[0] <= 0:
                    continue
                if self.args.val_backdoor_target:
                    print("对target进行改变")
                    target = self._backdoor_target(target)

            with torch.no_grad():
                outputs = model(image)
            self.metric.update(outputs[0], target)

            # Add batch sample into evaluator | using another version's miou calculation
            pred = np.argmax(outputs[0].data.cpu().numpy(), axis=1)
            target = target.cpu().numpy()
            print("add_batch target:{} pred:{}".format(target.shape,
                                                       pred.shape))
            self.evaluator.add_batch(target, pred)

            pixAcc, mIoU, attack_transmission_rate, remaining_miou = self.metric.get(
                self.args.semantic_a, self._TREE_LABEL)
            # Per the original author: the last two metrics are only meaningful
            # when targets are relabelled with the semantic attack; the
            # attack_transmission_rate (share of person predicted as tree)
            # is informative whether or not the AB test mode is used.
            logger.info(
                "Sample: {:d}, Validation pixAcc: {:.3f}, mIoU: {:.3f} attack_transmission_rate:{:.3f} remaining_miou:{:.3f}"
                .format(i + 1, pixAcc, mIoU, attack_transmission_rate,
                        remaining_miou))

        # Fast test during the training | using another version's miou calculation
        Acc = self.evaluator.Pixel_Accuracy()
        Acc_class = self.evaluator.Pixel_Accuracy_Class()
        eval_mIoU = self.evaluator.Mean_Intersection_over_Union()
        FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union()
        print('Validation:')
        print("Acc:{}, Acc_class:{}, mIoU:{}, fwIoU: {}".format(
            Acc, Acc_class, eval_mIoU, FWIoU))

        # Best-model score combines SegmentationMetric's pixAcc with the
        # Evaluator's mIoU (as the original code did implicitly by reusing
        # the name `mIoU`).
        new_pred = (pixAcc + eval_mIoU) / 2
        if new_pred > self.best_pred:
            is_best = True
            self.best_pred = new_pred
        if not self.args.val_only:
            save_checkpoint(self.model, self.args, is_best)
        synchronize()