Example #1
    def __init__(self, args):
        self.args = args
        self.prepare_saver()
        self.prepare_tensorboard()
        self.prepare_dataloader()

        self.model = DeepLab(num_classes=self.nclass,
                             backbone=args.backbone,
                             output_stride=args.out_stride,
                             sync_bn=args.sync_bn,
                             freeze_bn=args.freeze_bn)
        self.build_crf()
        self.define_optimizer()

        if args.adversarial_loss:
            self.build_adverserial_model()
        else:
            self.discriminator = None
            self.optimizer_D = None

        self.define_pixel_criterion()
        self.define_evaluators()

        # Define lr scheduler
        self.scheduler = LR_Scheduler(args.lr_scheduler, args.lr, args.epochs,
                                      len(self.train_loader))

        # Using cuda
        if self.args.cuda:
            self.model_on_cuda()

        # Resuming checkpoint
        self.best_pred = 0.0
        if args.resume is not None:
            self.resume_model()
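
A note on usage: the constructors in these examples only build the scheduler; the call site, visible in Examples #6 and #22 below, invokes it once per iteration before the optimizer step. A minimal sketch of that loop (the attribute names and the sample dict keys are assumptions carried over from the surrounding examples):

    def training(self, epoch):
        self.model.train()
        for i, sample in enumerate(self.train_loader):
            image, target = sample['image'], sample['label']
            if self.args.cuda:
                image, target = image.cuda(), target.cuda()
            # The scheduler is typically invoked per iteration, before the
            # optimizer step; it adjusts optimizer.param_groups in place.
            self.scheduler(self.optimizer, i, epoch, self.best_pred)
            self.optimizer.zero_grad()
            loss = self.criterion(self.model(image), target)
            loss.backward()
            self.optimizer.step()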
Example #2
    def init_optimizer(self, args):
        self.generator_criterion = SegmentationLosses(
            weight=None,
            cuda=args.cuda).build_loss(mode='bce')  # torch.nn.BCELoss(reduce='mean')
        self.generator_params = [
            {'params': self.generator_model.module.get_1x_lr_params(),
             'lr': args.lr},
            {'params': self.generator_model.module.get_10x_lr_params(),
             'lr': args.lr * 10},
        ]
        self.discriminator_params = [
            {'params': self.discriminator_model.parameters(),
             'lr': args.lr * 5},
        ]
        self.model_optim = torch.optim.Adadelta(self.generator_params +
                                                self.discriminator_params)
        self.scheduler = LR_Scheduler(args.lr_scheduler,
                                      args.lr,
                                      args.epochs,
                                      lr_step=30,
                                      iters_per_epoch=100)
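
Each dict passed to the optimizer above becomes its own entry in optimizer.param_groups with an independent learning rate, which is how a single Adadelta instance drives the generator at args.lr / args.lr * 10 and the discriminator at args.lr * 5 at once. Illustrative lines that could sit at the end of init_optimizer (a sketch, not from the original repo):

        # Three param_groups are created: 1x backbone, 10x head, 5x discriminator.
        for i, group in enumerate(self.model_optim.param_groups):
            print('group %d: lr=%g, %d tensors' % (i, group['lr'], len(group['params'])))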
Example #3
    def __init__(self,
                 args,
                 dataloader: DataLoader,
                 model: nn.Module,
                 optimizer,
                 criterion,
                 logger,
                 summary=None):
        """

        :param args:
        :param dataloader:
        :param model:
        :param optimizer:
        :param criterion:
        :param logger:
        :param summary:
        """
        self.args = args
        self.dataloader = dataloader
        self.model = model
        self.logger = logger
        self.summary = summary
        self.criterion = criterion
        self.optimizer = optimizer
        self.start_epoch = 0
        # Define lr scheduler
        self.scheduler = LR_Scheduler('poly', args.lr, args.max_epochs,
                                      len(self.dataloader))
        # Resume training
        if args.resume:
            self.resume()
Example #4
class Trainer(object):
    def __init__(self, args):
        warnings.filterwarnings('ignore')
        assert torch.cuda.is_available()
        torch.backends.cudnn.benchmark = True
        model_fname = 'data/deeplab_{0}_{1}_v3_{2}_epoch%d.pth'.format(args.backbone, args.dataset, args.exp)
        if args.dataset == 'pascal':
            raise NotImplementedError
        elif args.dataset == 'cityscapes':
            kwargs = {'num_workers': args.workers, 'pin_memory': True, 'drop_last': True}
            dataset_loader, num_classes = dataloaders.make_data_loader(args, **kwargs)
            args.num_classes = num_classes
        elif args.dataset == 'marsh' :
            kwargs = {'num_workers': args.workers, 'pin_memory': True, 'drop_last': True}
            dataset_loader, val_loader, test_loader, num_classes = dataloaders.make_data_loader(args, **kwargs)
            args.num_classes = num_classes
        else:
            raise ValueError('Unknown dataset: {}'.format(args.dataset))

        if args.backbone == 'autodeeplab':
            model = Retrain_Autodeeplab(args)
            model.load_state_dict(torch.load(r"./run/marsh/deeplab-autodeeplab/model_best.pth.tar")['state_dict'], strict=False)
        else:
            raise ValueError('Unknown backbone: {}'.format(args.backbone))

        optimizer = optim.SGD(model.parameters(), lr=args.base_lr, momentum=0.9, weight_decay=0.0001)


        if args.criterion == 'Ohem':
            args.thresh = 0.7
            args.crop_size = [args.crop_size, args.crop_size] if isinstance(args.crop_size, int) else args.crop_size
            args.n_min = int((args.batch_size / len(args.gpu) * args.crop_size[0] * args.crop_size[1]) // 16)
        criterion = build_criterion(args)

        model = nn.DataParallel(model).cuda()
        ## merge
        self.args = args
        # Define Saver
        self.saver = Saver(args)
        self.saver.save_experiment_config()
        # Define Tensorboard Summary
        self.summary = TensorboardSummary(self.saver.experiment_dir)
        self.writer = self.summary.create_summary()
        
        # Define Dataloader
        #kwargs = {'num_workers': args.workers, 'pin_memory': True}
        self.train_loader, self.val_loader, self.test_loader, self.nclass = dataset_loader, val_loader, test_loader, num_classes

        self.criterion = criterion
        self.model, self.optimizer = model, optimizer
        
        # Define Evaluator
        self.evaluator = Evaluator(self.nclass)
        # Define lr scheduler
        #self.scheduler = scheduler
        self.scheduler = LR_Scheduler("poly",args.lr, args.epochs, len(self.train_loader)) #removed None from second parameter. 
Example #5
    def __init__(self, config, args):
        self.args = args
        self.config = config
        # Define Dataloader
        self.train_loader, self.val_loader, self.test_loader = make_data_loader(
            config)

        # Define network
        #self.model = DeepLab(num_classes=self.nclass,
        #                backbone=config.backbone,
        #                output_stride=config.out_stride,
        #                sync_bn=config.sync_bn,
        #                freeze_bn=config.freeze_bn)
        self.model = UNet(n_channels=1, n_classes=3, bilinear=True)

        #train_params = [{'params': self.model.get_1x_lr_params(), 'lr': config.lr},
        #                {'params': self.model.get_10x_lr_params(), 'lr': config.lr * config.lr_ratio}]

        # Define Optimizer
        self.optimizer = torch.optim.SGD(self.model.parameters(),
                                         lr=config.lr,
                                         momentum=config.momentum,
                                         weight_decay=config.weight_decay)

        # Define Criterion
        # whether to use class balanced weights
        self.criterion = MSELoss(cuda=args.cuda)
        # Define lr scheduler
        self.scheduler = LR_Scheduler(config.lr_scheduler,
                                      config.lr, config.epochs,
                                      len(self.train_loader), config.lr_step,
                                      config.warmup_epochs)
        self.summary = TensorboardSummary('./train_log')

        # Using cuda
        if args.cuda:
            self.model = torch.nn.DataParallel(self.model)
            patch_replication_callback(self.model)
            # cudnn.benchmark = True
            self.model = self.model.cuda()

        self.best_pred_source = 0.0
        # Resuming checkpoint
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'".format(
                    args.resume))
            checkpoint = torch.load(
                args.resume,
                map_location=None if args.cuda else torch.device('cpu'))
            if args.cuda:
                self.model.module.load_state_dict(checkpoint)
            else:
                self.model.load_state_dict(checkpoint)
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, args.start_epoch))
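
One detail these resume blocks share: torch.load's map_location argument is what remaps GPU-saved tensors onto a CPU-only machine (load_state_dict does not accept it). A minimal device-safe sketch, with the checkpoint layout an assumption; 'args.resume' and 'model' are as in the examples:

# map_location remaps storage devices at load time.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
checkpoint = torch.load(args.resume, map_location=device)
state_dict = checkpoint.get('state_dict', checkpoint)  # raw or wrapped checkpoint
net = model.module if hasattr(model, 'module') else model  # unwrap DataParallel
net.load_state_dict(state_dict)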
Example #6
    def train(self, max_epoch, writer=None, epoch_size=100):
        max_step = epoch_size * max_epoch
        scheduler = LR_Scheduler('poly', self.lr, max_epoch, epoch_size)
        #         scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(self.opt, max_epoch * epoch_size)
        torch.cuda.manual_seed(1)
        best_score = 0
        step = 0
        for epoch in tqdm.tqdm(range(max_epoch), total=max_epoch):
            torch.cuda.empty_cache()
            self.net.train()
            for batch_idx, data in enumerate(self.train_loader):
                img = data['image'].to(self.device)
                hm = data['heatmap'].to(self.device)
                mask = data['mask'].to(self.device)
                num = data['num'].to(self.device)
                scheduler(self.opt, batch_idx, epoch, best_score)
                self.reset_grad()
                pred_hm, pred_mask = self.net(img)
                rate = math.exp(-step / (max_step / 10))
                loss = self.get_loss((pred_hm, pred_mask), (hm, mask, num),
                                     rate,
                                     backward=False)
                loss.backward()
                self.opt.step()
                if writer:
                    writer.add_scalar('rate', rate, global_step=step)
                    writer.add_scalar('loss', loss.item(), global_step=step)
                    writer.add_scalar('lr',
                                      self.opt.param_groups[0]['lr'],
                                      global_step=step)
                step += 1
#                 scheduler.step(step)

            if epoch % self.interval == 0:
                torch.cuda.empty_cache()
                acc, imgs, pred_hms, gt_hms, pred_masks, gt_masks = self.test()
                score = acc
                if writer:
                    writer.add_scalar('Acc', acc, global_step=epoch)

                    pred_hms = self.draw_heatmap(imgs, pred_hms)
                    writer.add_image('Pred HM', pred_hms, epoch)

                    gt_hms = self.draw_heatmap(imgs, gt_hms)
                    writer.add_image('GT HM', gt_hms, epoch)

                    pred_masks = self.draw_mask(imgs, pred_masks)
                    writer.add_image('Pred Mask', pred_masks, epoch)

                    gt_masks = self.draw_mask(imgs, gt_masks, is_gt=True)
                    writer.add_image('GT Mask', gt_masks, epoch)

                if best_score <= score + 0.01:
                    best_score = score
                    self.save_model(self.checkpoint_dir)
Example #7
    def __init__(self, args):
        self.args = args
        self.saver = Saver(args)
        self.saver.save_experiment_config()
        self.summary = TensorboardSummary(self.saver.experiment_dir)
        self.writer = self.summary.create_summary()
        self.logger = self.saver.create_logger()

        kwargs = {'num_workers': args.workers, 'pin_memory': False}
        self.train_loader, self.val_loader, self.test_loader, self.nclass = make_data_loader(args, **kwargs)
        self.model = EDCNet(args.rgb_dim, args.event_dim, num_classes=self.nclass, use_bn=True)
        train_params = [{'params': self.model.random_init_params(),
                         'lr': 10*args.lr, 'weight_decay': 10*args.weight_decay},
                        {'params': self.model.fine_tune_params(),
                         'lr': args.lr, 'weight_decay': args.weight_decay}]
        self.optimizer = torch.optim.Adam(train_params, lr=args.lr, weight_decay=args.weight_decay)
        if args.cuda:
            self.model = torch.nn.DataParallel(self.model, device_ids=self.args.gpu_ids)
            patch_replication_callback(self.model)
            self.model = self.model.to(self.args.device)
        if args.use_balanced_weights:
            root_dir = Path.db_root_dir(args.dataset)[0] if isinstance(Path.db_root_dir(args.dataset), list) else Path.db_root_dir(args.dataset)
            classes_weights_path = os.path.join(root_dir,
                                                args.dataset + '_classes_weights.npy')
            if os.path.isfile(classes_weights_path):
                weight = np.load(classes_weights_path)
            else:
                weight = calculate_weigths_labels(args.dataset, self.train_loader, self.nclass, classes_weights_path)
            weight = torch.from_numpy(weight.astype(np.float32))
        else:
            weight = None

        self.criterion = SegmentationLosses(weight=weight, cuda=args.cuda).build_loss(mode=args.loss_type)
        self.criterion_event = SegmentationLosses(weight=weight, cuda=args.cuda).build_loss(mode='event')
        self.scheduler = LR_Scheduler(args.lr_scheduler, args.lr, args.epochs, len(self.train_loader), warmup_epochs=5)

        self.evaluator = Evaluator(self.nclass, self.logger)
        self.saver.save_model_summary(self.model)
        self.best_pred = 0.0
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'".format(args.resume))
            checkpoint = torch.load(args.resume, map_location='cuda:0')
            args.start_epoch = checkpoint['epoch']
            if args.cuda:
                self.model.module.load_state_dict(checkpoint['state_dict'])
            else:
                self.model.load_state_dict(checkpoint['state_dict'])
            if not args.ft:
                self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.best_pred = checkpoint['best_pred']
            print("=> loaded checkpoint '{}' (epoch {})".format(args.resume, checkpoint['epoch']))

        if args.ft:
            args.start_epoch = 0
Example #8
    def __init__(self, args):
        self.args = args

        # Define Saver
        self.saver = Saver(args)
        self.saver.save_experiment_config()
        # Define Tensorboard Summary
        self.summary = TensorboardSummary(self.saver.experiment_dir)
        self.writer = self.summary.create_summary()
        # PATH = args.path
        # Define Dataloader
        kwargs = {'num_workers': args.workers, 'pin_memory': True}
        self.train_loader, self.val_loader, self.test_loader, self.nclass = make_data_loader(args, **kwargs)

        # Define network
        model = SCNN(nclass=self.nclass, backbone=args.backbone, output_stride=args.out_stride, cuda=args.cuda)

        # Define Optimizer
        optimizer = torch.optim.SGD(model.parameters(), args.lr, momentum=args.momentum,
                                    weight_decay=args.weight_decay, nesterov=args.nesterov)

        # Define Criterion
        weight = None
        self.criterion = SegmentationLosses(weight=weight, cuda=args.cuda).build_loss(mode=args.loss_type)
        self.model, self.optimizer = model, optimizer
        
        # Define Evaluator
        self.evaluator = Evaluator(self.nclass)
        # Define lr scheduler
        self.scheduler = LR_Scheduler(args.lr_scheduler, args.lr,
                                      args.epochs, len(self.train_loader))

        # Using cuda
        if args.cuda:
            self.model = torch.nn.DataParallel(self.model, device_ids=self.args.gpu_ids)
            # patch_replication_callback(self.model)
            self.model = self.model.cuda()

        # Resuming checkpoint
        self.best_pred = 0.0
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'" .format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            if args.cuda:
                self.model.module.load_state_dict(checkpoint['state_dict'])
            else:
                self.model.load_state_dict(checkpoint['state_dict'])
            if not args.ft:
                self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.best_pred = checkpoint['best_pred']
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
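
The resume block above reads four keys from the checkpoint: 'epoch', 'state_dict', 'optimizer' and 'best_pred'. A sketch of the matching save side, so the two stay in agreement (the method name and filename are illustrative, not from the original repo):

    def save_checkpoint(self, epoch, filename='checkpoint.pth.tar'):
        state = {
            'epoch': epoch + 1,
            # Unwrap DataParallel so the keys carry no 'module.' prefix.
            'state_dict': (self.model.module.state_dict()
                           if self.args.cuda else self.model.state_dict()),
            'optimizer': self.optimizer.state_dict(),
            'best_pred': self.best_pred,
        }
        torch.save(state, filename)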
Example #9
    def __init__(self, config, args):
        self.args = args
        self.config = config
        self.vis = visdom.Visdom(env=os.getcwd().split('/')[-1])
        # Define Dataloader
        self.train_loader, self.val_loader, self.test_loader, self.nclass = make_data_loader(config)

        # Define network
        model = DeepLab(num_classes=self.nclass,
                        backbone=config.backbone,
                        output_stride=config.out_stride,
                        sync_bn=config.sync_bn,
                        freeze_bn=config.freeze_bn)

        train_params = [{'params': model.get_1x_lr_params(), 'lr': config.lr},
                        {'params': model.get_10x_lr_params(), 'lr': config.lr * 10}]

        # Define Optimizer
        optimizer = torch.optim.SGD(train_params, momentum=config.momentum,
                                    weight_decay=config.weight_decay)

        # Define Criterion
        # whether to use class balanced weights
        self.criterion = SegmentationLosses(weight=None, cuda=args.cuda).build_loss(mode=config.loss)
        self.model, self.optimizer = model, optimizer

        # Define Evaluator
        self.evaluator = Evaluator(self.nclass)
        # Define lr scheduler
        self.scheduler = LR_Scheduler(config.lr_scheduler, config.lr,
                                      config.T, len(self.train_loader),
                                      config.lr_step, config.warmup_epochs)

        # Using cuda
        if args.cuda:
            self.model = torch.nn.DataParallel(self.model)
            patch_replication_callback(self.model)
            # cudnn.benchmark = True
            self.model = self.model.cuda()

        # Resuming checkpoint
        self.best_pred = 0.0
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'".format(args.resume))
            checkpoint = torch.load(
                args.resume,
                map_location=None if args.cuda else torch.device('cpu'))
            if args.cuda:
                self.model.module.load_state_dict(checkpoint)
            else:
                self.model.load_state_dict(checkpoint)
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, args.start_epoch))
Example #10
    def __init__(self, para):
        self.args = para

        # Define Saver
        self.saver = Saver(para)
        self.saver.save_experiment_config()
        # Define Tensorboard Summary
        self.summary = TensorboardSummary(self.saver.experiment_dir)
        self.writer = self.summary.create_summary()

        self.train_loader, self.val_loader, self.test_loader, self.nclass = dataloader(
            para)

        # Define network
        model = DeepLab(num_classes=self.nclass,
                        backbone=para.backbone,
                        output_stride=para.out_stride,
                        sync_bn=para.sync_bn,
                        freeze_bn=para.freeze_bn)

        train_params = [{
            'params': model.get_1x_lr_params(),
            'lr': para.lr
        }, {
            'params': model.get_10x_lr_params(),
            'lr': para.lr * 10
        }]

        # Define Optimizer
        optimizer = torch.optim.SGD(train_params,
                                    momentum=para.momentum,
                                    weight_decay=para.weight_decay,
                                    nesterov=para.nesterov)

        # Define Criterion

        self.criterion = SegmentationLosses(
            weight=None, cuda=True).build_loss(mode=para.loss_type)
        self.model, self.optimizer = model, optimizer

        # Define Evaluator
        self.evaluator = Evaluator(self.nclass)
        # Define lr scheduler
        self.scheduler = LR_Scheduler(para.lr_scheduler, para.lr, para.epochs,
                                      len(self.train_loader))

        self.model = torch.nn.DataParallel(self.model)
        patch_replication_callback(self.model)
        self.model = self.model.cuda()
        # Resuming checkpoint
        self.best_pred = 0.0
Example #11
    def __init__(self, args):
        self.args = args

        self.saver = Saver(args)
        self.saver.save_experiment_config()

        kwargs = {'num_workers': args.workers, 'pin_memory': True}
        self.train_loader, self.val_loader, self.test_loader, self.nclass = make_data_loader(
            args, **kwargs)

        # self.model = OCRNet(self.nclass)
        self.model = build_model(2, [32, 32], '44330020')
        self.optimizer = torch.optim.SGD(self.model.parameters(),
                                         lr=args.lr,
                                         momentum=args.momentum,
                                         weight_decay=args.weight_decay,
                                         nesterov=args.nesterov)
        if args.use_balanced_weights:
            weight = torch.tensor([0.2, 0.8], dtype=torch.float32)
        else:
            weight = None
        self.criterion = SegmentationLosses(
            weight, cuda=args.cuda).build_loss(mode=args.loss_type)

        self.evaluator = Evaluator(self.nclass)
        self.scheduler = LR_Scheduler(args.lr_scheduler, args.lr, args.epochs,
                                      len(self.train_loader))

        if args.cuda:
            self.model = self.model.cuda()

        self.best_pred = 0.0
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'".format(
                    args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            if args.cuda:
                self.model.module.load_state_dict(checkpoint['state_dict'])
            else:
                self.model.load_state_dict(checkpoint['state_dict'])
            if not args.ft:
                self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.best_pred = checkpoint['best_pred']
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))

        if args.ft:
            args.start_epoch = 0
Example #12
    def __init__(self, args):
        self.args = args

        self.train_loader, self.valid_loader = construct_loader(
            args.train_path, args.valid_path, args.batch_size, args.dataset, args.cuda)

        # Define Optimizer,model
        if args.model == 'padding_vectornet':
            model = padding_VectorNet(args.depth_sub, args.width_sub, args.depth_global, args.width_global)
            train_params = [{'params': model.parameters(), 'lr': args.lr}]
        elif args.model == 'vectornet':
            model = VectorNet(args.depth_sub, args.width_sub, args.depth_global, args.width_global)
            train_params = [{'params': model.parameters(), 'lr': args.lr}]
        else:
            raise ValueError('Unsupported model: {}'.format(args.model))

        self.model = model

        # CUDA enabled
        if args.cuda:
            self.model = self.model.cuda()

        self.optimizer = torch.optim.Adam(train_params)
        self.criterion = loss_collection(args.cuda).construct_loss(args.loss_mode)

        # Define lr scheduler
        self.scheduler = LR_Scheduler(args.lr_scheduler, args.lr,
                                    args.epochs, len(self.train_loader))

        self.metricer = metricer()

        if not os.path.exists('ckpt/{}'.format(args.model)):
            os.makedirs('ckpt/{}'.format(args.model), 0o777)
        self.logger = logger('ckpt/{}'.format(args.model), ['DE@1s', 'DE@2s', 'DE@3s', 'ADE', 'loss'])
        if not os.path.exists('ckpt/{}/storage'.format(args.model)):
            os.makedirs('ckpt/{}/storage'.format(args.model), 0o777)
        self.saver = saver('ckpt/{}/storage'.format(args.model), args.model)
        ret = self.saver.restore()
        self.start_epoch = 1
        self.best_pred = 0
        if ret is not None:
            self.model.load_state_dict(ret[0])
            self.optimizer.load_state_dict(ret[1])
            self.start_epoch = ret[2]
            self.best_pred = ret[3]
Example #13
    def __init__(self, args):
        self.args = args
        self.prepare_saver()
        self.prepare_tensorboard()
        self.prepare_dataloader()

        # Define network
        self.model = DeepLabMultiView(num_classes=self.nclass,
                                      backbone=args.backbone,
                                      output_stride=args.out_stride,
                                      sync_bn=args.sync_bn,
                                      freeze_bn=args.freeze_bn,
                                      unet_size=args.unet_size,
                                      separable_conv=args.separable_conv)
        self.model.merger = init_net(self.model.merger,
                                     type="kaiming",
                                     activation_mode='relu',
                                     distribution='normal')

        # load pretrain deeplab model
        if args.path_pretrained_model is not None:
            self.load_pretrained_deeplab()

        self.define_optimizer()

        if args.adversarial_loss:
            self.build_adverserial_model()
        else:
            self.discriminator = None
            self.optimizer_D = None

        self.define_pixel_criterion()
        self.define_evaluators()

        # Define lr scheduler
        self.scheduler = LR_Scheduler(args.lr_scheduler, args.lr, args.epochs,
                                      len(self.train_loader))
        # Using cuda
        if self.args.cuda:
            self.model_on_cuda()

        # Resuming checkpoint
        self.best_pred = 0.0
        if args.resume is not None:
            self.resume_model()
Example #14
    def __init__(self, weight_path, resume, gpu_id):
        init_seeds(1)
        init_dirs("result")

        self.device = gpu.select_device(gpu_id)
        self.start_epoch = 0
        self.best_mIoU = 0.
        self.epochs = cfg.TRAIN["EPOCHS"]
        self.weight_path = weight_path

        self.train_loader, self.val_loader, _, self.num_class = make_data_loader(
        )

        self.model = DeepLab(num_classes=self.num_class,
                             backbone="resnet",
                             output_stride=16,
                             sync_bn=False,
                             freeze_bn=False).to(self.device)

        train_params = [{
            'params': self.model.get_1x_lr_params(),
            'lr': cfg.TRAIN["LR_INIT"]
        }, {
            'params': self.model.get_10x_lr_params(),
            'lr': cfg.TRAIN["LR_INIT"] * 10
        }]

        self.optimizer = optim.SGD(train_params,
                                   momentum=cfg.TRAIN["MOMENTUM"],
                                   weight_decay=cfg.TRAIN["WEIGHT_DECAY"])

        self.criterion = SegmentationLosses().build_loss(
            mode=cfg.TRAIN["LOSS_TYPE"])

        self.scheduler = LR_Scheduler(mode=cfg.TRAIN["LR_SCHEDULER"],
                                      base_lr=cfg.TRAIN["LR_INIT"],
                                      num_epochs=self.epochs,
                                      iters_per_epoch=len(self.train_loader))
        self.evaluator = Evaluator(self.num_class)
        self.saver = Saver()
        self.summary = TensorboardSummary(os.path.join("result", "run"))

        if resume:
            self.__resume_model_weights()
Example #15
    def __init__(self, args):
        self.args = args
        # Initialize TensorBoard summary
        self.summary = TensorboardSummary(directory=args.save_path)
        self.writer = self.summary.create_summary()
        # Initialize dataloaders
        kwargs = {'num_workers': args.workers, 'pin_memory': True}
        self.train_dataset = Apolloscapes('train_dataset.csv', '/home/aistudio/data/data1919/Image_Data', '/home/aistudio/data/data1919/Gray_Label',
                                     args.crop_size, type='train')

        self.dataloader = DataLoader(self.train_dataset, batch_size=args.batch_size, shuffle=True, drop_last=True, **kwargs)

        self.val_dataset = Apolloscapes('val_dataset.csv', '/home/aistudio/data/data1919/Image_Data', '/home/aistudio/data/data1919/Gray_Label',
                                          args.crop_size, type='val')

        self.val_loader = DataLoader(self.val_dataset, batch_size=args.batch_size, shuffle=False, drop_last=False, **kwargs)

        # Initialize model
        self.model = DeeplabV3Plus(backbone=args.backbone,
                              output_stride=args.out_stride,
                              batch_norm=args.batch_norm,
                              num_classes=args.num_classes,
                              pretrain=True)
        # Initialize optimizer
        self.optimizer = torch.optim.SGD(self.model.parameters(),
                                         momentum=args.momentum,
                                         nesterov=args.nesterov,
                                         weight_decay=args.weight_decay,
                                         lr=args.lr)

        # Define loss function
        self.loss = CELoss(num_class=args.num_classes, cuda=args.cuda)

        # Define evaluator
        self.evaluator = Evaluator(args.num_classes)

        # Define lr scheduler
        self.scheduler = LR_Scheduler('poly', args.lr, args.epochs, len(self.dataloader))

        # Using cuda
        if args.cuda:
            self.model = self.model.cuda(device=args.gpus[0])
            self.model = torch.nn.DataParallel(self.model, device_ids=args.gpus)
Example #16
    def initialize(self):

        args = self.args
        model = DeepLabAccuracyPredictor(num_classes=self.nclass, backbone=args.backbone, output_stride=args.out_stride,
                                         sync_bn=args.sync_bn, freeze_bn=args.freeze_bn, mc_dropout=False, enet=args.architecture == 'enet', symmetry=args.symmetry)

        train_params = model.get_param_list(args.lr, args.architecture == 'enet', args.symmetry)

        if args.optimizer == 'SGD':
            optimizer = torch.optim.SGD(train_params, momentum=args.momentum, weight_decay=args.weight_decay, nesterov=args.nesterov)
        elif args.optimizer == 'Adam':
            optimizer = torch.optim.Adam(train_params, weight_decay=args.weight_decay)
        else:
            raise NotImplementedError

        if args.use_balanced_weights:
            weight = calculate_weights_labels(args.dataset, self.train_loader, self.nclass)
            weight = torch.from_numpy(weight.astype(np.float32))
        else:
            weight = None

        self.criterion_deeplab = SegmentationLosses(weight=weight, cuda=args.cuda).build_loss(mode=args.loss_type)
        self.criterion_unet = SegmentationLosses(weight=torch.FloatTensor(
            [args.weight_wrong_label_unet, 1 - args.weight_wrong_label_unet]), cuda=args.cuda).build_loss(mode=args.loss_type)
        self.model, self.optimizer = model, optimizer

        self.deeplab_evaluator = Evaluator(self.nclass)
        self.unet_evaluator = Evaluator(2)

        if args.use_lr_scheduler:
            self.scheduler = LR_Scheduler(args.lr_scheduler, args.lr, args.epochs, len(self.train_loader))
        else:
            self.scheduler = None

        if args.cuda:
            self.model = torch.nn.DataParallel(self.model, device_ids=self.args.gpu_ids)
            patch_replication_callback(self.model)
            self.model = self.model.cuda()

        self.best_pred = 0.0
Example #17
    def __init__(self, args):
        self.args = args
        self.prepare_saver()
        self.prepare_tensorboard()
        self.prepare_dataloader()

        self.spatial_model = DeepLab(num_classes=self.nclass,
                                     backbone=args.backbone,
                                     output_stride=args.out_stride,
                                     sync_bn=args.sync_bn,
                                     freeze_bn=True)
        # Fix deeplab as the features extractor
        for param in self.spatial_model.parameters():
            param.requires_grad = False

        self.temporal_model = LowLatencyModel(self.spatial_model,
                                              kernel_size=self.args.svc_kernel_size,
                                              flow=self.args.flow,
                                              seperable=args.temporal_separable)
        self.define_optimizer()

        if args.adversarial_loss:
            self.build_adverserial_model()
        else:
            self.discriminator = None
            self.optimizer_D = None

        self.define_pixel_criterion()
        self.define_evaluators()
        self.scheduler = LR_Scheduler(args.lr_scheduler, args.lr, args.epochs, len(self.train_loader))

        if self.args.cuda:
            self.model_on_cuda()

        if self.args.separate_spatial_model_path is not None:
            self.load_spatial_model_separately()

        self.best_pred = 0.0
        if args.resume is not None:
            self.resume_model()
Example #18
    def __init__(self, args):
        self.args = args

        # Define Saver
        self.saver = Saver(args)
        self.saver.save_experiment_config()
        # Define Tensorboard Summary
        self.summary = TensorboardSummary(self.saver.experiment_dir)
        self.writer = self.summary.create_summary()

        # Define Dataloader
        kwargs = {'num_workers': args.workers, 'pin_memory': True}
        self.train_loader, self.val_loader, self.test_loader, self.nclass = make_data_loader(
            args, **kwargs)

        cell_path = os.path.join(args.saved_arch_path, 'genotype.npy')
        network_path_space = os.path.join(args.saved_arch_path,
                                          'network_path_space.npy')

        new_cell_arch = np.load(cell_path)
        new_network_arch = np.load(network_path_space)

        # Define network
        model = newModel(network_arch=new_network_arch,
                         cell_arch=new_cell_arch,
                         num_classes=self.nclass,
                         num_layers=12)
        #                        output_stride=args.out_stride,
        #                        sync_bn=args.sync_bn,
        #                        freeze_bn=args.freeze_bn)
        self.decoder = Decoder(self.nclass, 'autodeeplab', args, False)
        # TODO: look into these
        # TODO: ALSO look into different param groups as done int deeplab below
        #        train_params = [{'params': model.get_1x_lr_params(), 'lr': args.lr},
        #                        {'params': model.get_10x_lr_params(), 'lr': args.lr * 10}]
        #
        train_params = [{'params': model.parameters(), 'lr': args.lr}]
        # Define Optimizer
        optimizer = torch.optim.SGD(train_params,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay,
                                    nesterov=args.nesterov)

        # Define Criterion
        # whether to use class balanced weights
        if args.use_balanced_weights:
            classes_weights_path = os.path.join(
                Path.db_root_dir(args.dataset),
                args.dataset + '_classes_weights.npy')
            if os.path.isfile(classes_weights_path):
                weight = np.load(classes_weights_path)
            else:
                weight = calculate_weigths_labels(args.dataset,
                                                  self.train_loader,
                                                  self.nclass)
            weight = torch.from_numpy(weight.astype(np.float32))
        else:
            weight = None
        self.criterion = SegmentationLosses(
            weight=weight, cuda=args.cuda).build_loss(mode=args.loss_type)
        self.model, self.optimizer = model, optimizer

        # Define Evaluator
        self.evaluator = Evaluator(self.nclass)
        # Define lr scheduler
        self.scheduler = LR_Scheduler(
            args.lr_scheduler, args.lr, args.epochs,
            len(self.train_loader))  #TODO: use min_lr ?

        # TODO: Figure out if len(self.train_loader) should be divided by two? In other modules as well
        # Using cuda
        if args.cuda:
            if (torch.cuda.device_count() > 1 or args.load_parallel):
                self.model = torch.nn.DataParallel(self.model.cuda())
                patch_replication_callback(self.model)
            self.model = self.model.cuda()
            print('cuda finished')

        # Resuming checkpoint
        self.best_pred = 0.0
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'".format(
                    args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']

            # if the weights are wrapped in a module object we have to clean it
            if args.clean_module:
                state_dict = checkpoint['state_dict']
                new_state_dict = OrderedDict()
                for k, v in state_dict.items():
                    name = k[7:]  # remove the 'module.' prefix added by DataParallel
                    new_state_dict[name] = v
                self.model.load_state_dict(new_state_dict)

            else:
                if (torch.cuda.device_count() > 1 or args.load_parallel):
                    self.model.module.load_state_dict(checkpoint['state_dict'])
                else:
                    self.model.load_state_dict(checkpoint['state_dict'])

            if not args.ft:
                self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.best_pred = checkpoint['best_pred']
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))

        # Clear start epoch if fine-tuning
        if args.ft:
            args.start_epoch = 0
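
The clean_module branch above strips the 'module.' prefix inline; the same logic as a small standalone helper, shown here as a sketch rather than code from the original repository:

from collections import OrderedDict

def strip_module_prefix(state_dict):
    """Remove the 'module.' prefix that nn.DataParallel prepends to keys."""
    cleaned = OrderedDict()
    for k, v in state_dict.items():
        cleaned[k[7:] if k.startswith('module.') else k] = v
    return cleaned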
Example #19
###################################
print("creating models......")

path_g = os.path.join(model_path, args.path_g)
path_g2l = os.path.join(model_path, args.path_g2l)
path_l2g = os.path.join(model_path, args.path_l2g)
model, global_fixed = create_model_load_weights(n_class, mode, evaluation, path_g=path_g, path_g2l=path_g2l, path_l2g=path_l2g)

###################################
num_epochs = args.num_epochs
learning_rate = args.lr
lamb_fmreg = args.lamb_fmreg

optimizer = get_optimizer(model, mode, learning_rate=learning_rate)

scheduler = LR_Scheduler('poly', learning_rate, num_epochs, len(dataloader_train))
##################################

criterion1 = FocalLoss(gamma=3)
criterion2 = nn.CrossEntropyLoss()
criterion3 = lovasz_softmax
criterion = lambda x, y: criterion1(x, y)
# criterion = lambda x,y: 0.5*criterion1(x, y) + 0.5*criterion3(x, y)
mse = nn.MSELoss()

if not evaluation:
    
    writer = SummaryWriter(log_dir=os.path.join(log_path, task_name))
    f_log = open(os.path.join(log_path, task_name + ".log"), 'w')

trainer = Trainer(criterion, optimizer, n_class, size_g, size_p, sub_batch_size, mode, lamb_fmreg)
Example #20
    def __init__(self, args):
        self.args = args

        # Define Saver
        self.saver = Saver(args)
        self.saver.save_experiment_config()
        # Define Tensorboard Summary
        self.summary = TensorboardSummary(self.saver.experiment_dir)
        self.writer = self.summary.create_summary()

        # Define Dataloader
        kwargs = {'num_workers': args.workers, 'pin_memory': True}
        self.train_loader, self.val_loader, self.test_loader, self.nclass = make_data_loader(
            args, **kwargs)

        # Define network
        model = DeepLab(num_classes=self.nclass,
                        backbone=args.backbone,
                        output_stride=args.out_stride,
                        sync_bn=args.sync_bn,
                        freeze_bn=args.freeze_bn)

        train_params = [{
            'params': model.get_1x_lr_params(),
            'lr': args.lr
        }, {
            'params': model.get_10x_lr_params(),
            'lr': args.lr * 10
        }]

        # Define Optimizer
        optimizer = torch.optim.SGD(train_params,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay,
                                    nesterov=args.nesterov)

        # Define Criterion

        self.criterion = SegmentationLosses(cuda=args.cuda)
        self.model, self.optimizer = model, optimizer
        self.contexts = TemporalContexts(history_len=5)

        # Define Evaluator
        self.evaluator = Evaluator(self.nclass)
        # Define lr scheduler
        self.scheduler = LR_Scheduler(args.lr_scheduler, args.lr, args.epochs,
                                      len(self.train_loader))

        # Using cuda
        if args.cuda:
            self.model = torch.nn.DataParallel(self.model,
                                               device_ids=self.args.gpu_ids)
            patch_replication_callback(self.model)
            self.model = self.model.cuda()

        # Resuming checkpoint
        self.best_pred = 0.0
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'".format(
                    args.resume))

            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            if args.cuda:
                self.model.module.load_state_dict(checkpoint['state_dict'])
            else:
                self.model.load_state_dict(checkpoint['state_dict'])
            if not args.ft:
                self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.best_pred = checkpoint['best_pred']
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))

        # Clear start epoch if fine-tuning or in validation/test mode
        if args.ft or args.mode == "val" or args.mode == "test":
            args.start_epoch = 0
            self.best_pred = 0.0
Example #21
    def __init__(self, args):
        self.args = args

        # Define Saver
        self.saver = Saver(args)
        self.saver.save_experiment_config()
        # Define Tensorboard Summary
        self.summary = TensorboardSummary(self.saver.experiment_dir)
        self.writer = self.summary.create_summary()
        self.use_amp = bool(APEX_AVAILABLE and args.use_amp)
        self.opt_level = args.opt_level

        kwargs = {
            'num_workers': args.workers,
            'pin_memory': True,
            'drop_last': True
        }
        self.train_loaderA, self.train_loaderB, self.val_loader, self.test_loader, self.nclass = make_data_loader(
            args, **kwargs)

        if args.use_balanced_weights:
            classes_weights_path = os.path.join(
                Path.db_root_dir(args.dataset),
                args.dataset + '_classes_weights.npy')
            if os.path.isfile(classes_weights_path):
                weight = np.load(classes_weights_path)
            else:
                raise NotImplementedError
                #if so, which trainloader to use?
                # weight = calculate_weigths_labels(args.dataset, self.train_loader, self.nclass)
            weight = torch.from_numpy(weight.astype(np.float32))
        else:
            weight = None
        self.criterion = SegmentationLosses(
            weight=weight, cuda=args.cuda).build_loss(mode=args.loss_type)

        # Define network
        model = AutoDeeplab(self.nclass, 12, self.criterion,
                            self.args.filter_multiplier,
                            self.args.block_multiplier, self.args.step)
        optimizer = torch.optim.SGD(model.weight_parameters(),
                                    args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)

        self.model, self.optimizer = model, optimizer

        self.architect_optimizer = torch.optim.Adam(
            self.model.arch_parameters(),
            lr=args.arch_lr,
            betas=(0.9, 0.999),
            weight_decay=args.arch_weight_decay)

        # Define Evaluator
        self.evaluator = Evaluator(self.nclass)
        # Define lr scheduler
        self.scheduler = LR_Scheduler(args.lr_scheduler,
                                      args.lr,
                                      args.epochs,
                                      len(self.train_loaderA),
                                      min_lr=args.min_lr)
        # TODO: Figure out if len(self.train_loader) should be divided by two? In other modules as well
        # Using cuda
        if args.cuda:
            self.model = self.model.cuda()

        # mixed precision
        if self.use_amp and args.cuda:
            keep_batchnorm_fp32 = True if (self.opt_level == 'O2'
                                           or self.opt_level == 'O3') else None

            # fix for current pytorch version with opt_level 'O1'
            if self.opt_level == 'O1' and torch.__version__ < '1.3':
                for module in self.model.modules():
                    if isinstance(module,
                                  torch.nn.modules.batchnorm._BatchNorm):
                        # Hack to fix BN fprop without affine transformation
                        if module.weight is None:
                            module.weight = torch.nn.Parameter(
                                torch.ones(module.running_var.shape,
                                           dtype=module.running_var.dtype,
                                           device=module.running_var.device),
                                requires_grad=False)
                        if module.bias is None:
                            module.bias = torch.nn.Parameter(
                                torch.zeros(module.running_var.shape,
                                            dtype=module.running_var.dtype,
                                            device=module.running_var.device),
                                requires_grad=False)

            # print(keep_batchnorm_fp32)
            self.model, [self.optimizer,
                         self.architect_optimizer] = amp.initialize(
                             self.model,
                             [self.optimizer, self.architect_optimizer],
                             opt_level=self.opt_level,
                             keep_batchnorm_fp32=keep_batchnorm_fp32,
                             loss_scale="dynamic")

            print('cuda finished')

        # Using data parallel
        if args.cuda and len(self.args.gpu_ids) > 1:
            if self.opt_level == 'O2' or self.opt_level == 'O3':
                print(
                    'currently cannot run with nn.DataParallel and optimization level',
                    self.opt_level)
            self.model = torch.nn.DataParallel(self.model,
                                               device_ids=self.args.gpu_ids)
            patch_replication_callback(self.model)
            print('training on multiple-GPUs')

        #checkpoint = torch.load(args.resume)
        #print('about to load state_dict')
        #self.model.load_state_dict(checkpoint['state_dict'])
        #print('model loaded')
        #sys.exit()

        # Resuming checkpoint
        self.best_pred = 0.0
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'".format(
                    args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']

            # if the weights are wrapped in a module object we have to clean it
            if args.clean_module:
                state_dict = checkpoint['state_dict']
                new_state_dict = OrderedDict()
                for k, v in state_dict.items():
                    name = k[7:]  # remove the 'module.' prefix added by DataParallel
                    new_state_dict[name] = v
                # self.model.load_state_dict(new_state_dict)
                copy_state_dict(self.model.state_dict(), new_state_dict)

            else:
                if torch.cuda.device_count() > 1 or args.load_parallel:
                    # self.model.module.load_state_dict(checkpoint['state_dict'])
                    copy_state_dict(self.model.module.state_dict(),
                                    checkpoint['state_dict'])
                else:
                    # self.model.load_state_dict(checkpoint['state_dict'])
                    copy_state_dict(self.model.state_dict(),
                                    checkpoint['state_dict'])

            if not args.ft:
                # self.optimizer.load_state_dict(checkpoint['optimizer'])
                copy_state_dict(self.optimizer.state_dict(),
                                checkpoint['optimizer'])
            self.best_pred = checkpoint['best_pred']
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))

        # Clear start epoch if fine-tuning
        if args.ft:
            args.start_epoch = 0
Example #22
File: main.py  Project: yida2311/HistoDOI
def main(seed=25):
    seed_everything(seed)
    device = torch.device('cuda:0')

    # arguments
    args = Args().parse()
    n_class = args.n_class

    img_path_train = args.img_path_train
    mask_path_train = args.mask_path_train
    img_path_val = args.img_path_val
    mask_path_val = args.mask_path_val

    model_path = os.path.join(args.model_path, args.task_name)  # save model
    log_path = args.log_path
    output_path = args.output_path

    if not os.path.exists(model_path):
        os.makedirs(model_path)
    if not os.path.exists(log_path):
        os.makedirs(log_path)
    if not os.path.exists(output_path):
        os.makedirs(output_path)

    task_name = args.task_name
    print(task_name)
    ###################################
    evaluation = args.evaluation
    test = evaluation and False
    print("evaluation:", evaluation, "test:", test)

    ###################################
    print("preparing datasets and dataloaders......")
    batch_size = args.batch_size
    num_workers = args.num_workers
    config = args.config

    data_time = AverageMeter("DataTime", ':3.3f')
    batch_time = AverageMeter("BatchTime", ':3.3f')

    dataset_train = DoiDataset(img_path_train,
                               config,
                               train=True,
                               root_mask=mask_path_train)
    dataloader_train = DataLoader(dataset_train,
                                  batch_size=batch_size,
                                  shuffle=True,
                                  num_workers=num_workers)
    dataset_val = DoiDataset(img_path_val,
                             config,
                             train=True,
                             root_mask=mask_path_val)
    dataloader_val = DataLoader(dataset_val,
                                batch_size=batch_size,
                                shuffle=False,
                                num_workers=num_workers)

    ###################################
    print("creating models......")
    model = DoiNet(n_class, config['min_descriptor'] + 6, 4)
    model = create_model_load_weights(model,
                                      evaluation=False,
                                      ckpt_path=args.ckpt_path)
    model.to(device)

    ###################################
    num_epochs = args.epochs
    learning_rate = args.lr

    optimizer = get_optimizer(model, learning_rate=learning_rate)
    scheduler = LR_Scheduler(args.scheduler, learning_rate, num_epochs,
                             len(dataloader_train))
    ##################################
    criterion_node = nn.CrossEntropyLoss()
    criterion_edge = nn.BCELoss()
    alpha = args.alpha

    writer = SummaryWriter(log_dir=log_path + task_name)
    f_log = open(log_path + task_name + ".log", 'w')
    #######################################
    trainer = Trainer(criterion_node,
                      criterion_edge,
                      optimizer,
                      n_class,
                      device,
                      alpha=alpha)
    evaluator = Evaluator(n_class, device)

    best_pred = 0.0
    print("start training......")
    log = task_name + '\n'
    for k, v in args.__dict__.items():
        log += str(k) + ' = ' + str(v) + '\n'
    print(log)
    f_log.write(log)
    f_log.flush()

    for epoch in range(num_epochs):
        optimizer.zero_grad()
        tbar = tqdm(dataloader_train)
        train_loss = 0
        train_loss_edge = 0
        train_loss_node = 0

        start_time = time.time()
        for i_batch, sample in enumerate(tbar):
            data_time.update(time.time() - start_time)

            if evaluation:  # evaluation pattern: no training
                break
            scheduler(optimizer, i_batch, epoch, best_pred)
            loss, loss_node, loss_edge = trainer.train(sample, model)
            train_loss += loss.item()
            train_loss_node += loss_node.item()
            train_loss_edge += loss_edge.item()
            train_scores_node, train_scores_edge = trainer.get_scores()

            batch_time.update(time.time() - start_time)
            start_time = time.time()

            if i_batch % 2 == 0:
                tbar.set_description(
                    'Train loss: %.4f (loss_node=%.4f  loss_edge=%.4f); F1 node: %.4f  F1 edge: %.4f; data time: %.2f; batch time: %.2f'
                    % (train_loss / (i_batch + 1), train_loss_node /
                       (i_batch + 1), train_loss_edge /
                       (i_batch + 1), train_scores_node["macro_f1"],
                       train_scores_edge["macro_f1"], data_time.avg,
                       batch_time.avg))

        trainer.reset_metrics()
        data_time.reset()
        batch_time.reset()

        if epoch % 1 == 0:
            with torch.no_grad():
                model.eval()
                print("evaluating...")

                tbar = tqdm(dataloader_val)
                start_time = time.time()
                for i_batch, sample in enumerate(tbar):
                    data_time.update(time.time() - start_time)
                    pred_node, pred_edge = evaluator.eval(sample, model)
                    val_scores_node, val_scores_edge = evaluator.get_scores()

                    batch_time.update(time.time() - start_time)
                    tbar.set_description(
                        'F1 node: %.4f  F1 edge: %.4f; data time: %.2f; batch time: %.2f'
                        % (val_scores_node["macro_f1"],
                           val_scores_edge["macro_f1"], data_time.avg,
                           batch_time.avg))
                    start_time = time.time()

            data_time.reset()
            batch_time.reset()
            val_scores_node, val_scores_edge = evaluator.get_scores()
            evaluator.reset_metrics()

            best_pred = save_model(model, model_path, val_scores_node,
                                   val_scores_edge, alpha, task_name, epoch,
                                   best_pred)
            write_log(f_log, train_scores_node, train_scores_edge,
                      val_scores_node, val_scores_edge, epoch, num_epochs)
            write_summaryWriter(writer, train_loss / len(dataloader_train),
                                optimizer, train_scores_node,
                                train_scores_edge, val_scores_node,
                                val_scores_edge, epoch)

    f_log.close()
Example #23
    def __init__(self, args):
        self.args = args
        self.train_dir = './data_list/train_lite.csv'
        self.train_list = pd.read_csv(self.train_dir)
        self.val_dir = './data_list/val_lite.csv'
        self.val_list = pd.read_csv(self.val_dir)
        self.train_length = len(self.train_list)
        self.val_length = len(self.val_list)
        # Define Saver
        self.saver = Saver(args)
        self.saver.save_experiment_config()
        # Define Tensorboard Summary
        self.summary = TensorboardSummary(self.saver.experiment_dir)
        self.writer = self.summary.create_summary()

        # Approach 2: build the loaders via make_data_loader2
        self.train_gen, self.val_gen, self.test_gen, self.nclass = make_data_loader2(args)
        # Define network
        model = DeepLab(num_classes=self.nclass,
                        backbone=args.backbone,
                        output_stride=args.out_stride,
                        sync_bn=args.sync_bn,
                        freeze_bn=args.freeze_bn)

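        # pretrained backbone weights train at the base LR; the randomly
        # initialised decoder/classifier head trains at 10x the base LR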
        train_params = [{'params': model.get_1x_lr_params(), 'lr': args.lr},
                        {'params': model.get_10x_lr_params(), 'lr': args.lr * 10}]

        # Define Optimizer
        optimizer = torch.optim.SGD(train_params, momentum=args.momentum,
                                    weight_decay=args.weight_decay, nesterov=args.nesterov)
        # optimizer = torch.optim.Adam(train_params, weight_decay=args.weight_decay)

        # Define Criterion
        # self.criterion = SegmentationLosses(weight=None, cuda=args.cuda).build_loss(mode=args.loss_type)
        self.criterion1 = SegmentationLosses(weight=None, cuda=args.cuda).build_loss(mode='ce')
        self.criterion2 = SegmentationLosses(weight=None, cuda=args.cuda).build_loss(mode='dice')

        self.model, self.optimizer = model, optimizer

        # Define Evaluator
        self.evaluator = Evaluator(self.nclass)
        # Define lr scheduler
        self.scheduler = LR_Scheduler(args.lr_scheduler, args.lr,
                                      args.epochs, self.train_length)

        # Using cuda
        if args.cuda:
            self.model = self.model.cuda()

        # Resuming checkpoint
        self.best_pred = 0.0
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            # the cuda and cpu paths load identically here, since the model
            # is not wrapped in DataParallel in this example
            self.model.load_state_dict(checkpoint['state_dict'])
            if not args.ft:
                self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.best_pred = checkpoint['best_pred']
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))

        # Clear start epoch if fine-tuning
        if args.ft:
            args.start_epoch = 0
Example #24
    def __init__(self, args):
        super(Trainer, self).__init__()
        self.args = args
        self.vis = Visualizer(env=args.checkname)
        self.saver = Checkpointer(args.checkname,
                                  args.saver_path,
                                  overwrite=False,
                                  verbose=True,
                                  timestamp=True,
                                  max_queue=args.max_save)

        self.model = LinkCrack()

        # Using cuda
        if args.cuda:
            self.model = torch.nn.DataParallel(self.model,
                                               device_ids=self.args.gpu_ids)
            self.model = self.model.cuda()
            self.device = torch.device("cuda")
        else:
            self.device = torch.device("cpu")

        if args.pretrained_model:
            self.model.load_state_dict(
                self.saver.load(self.args.pretrained_model, multi_gpu=True))
            self.vis.log('load checkpoint: %s' % self.args.pretrained_model,
                         'train info')

        # Define Dataloader
        kwargs = {'num_workers': args.workers, 'pin_memory': True}
        self.train_loader, self.val_loader, self.test_loader, self.nclass = make_data_loader(
            args, **kwargs)

        if args.use_adam:
            self.optimizer = torch.optim.Adam(self.model.parameters(),
                                              lr=args.lr,
                                              weight_decay=args.weight_decay)
        else:
            self.optimizer = torch.optim.SGD(self.model.parameters(),
                                             lr=args.lr,
                                             momentum=args.momentum,
                                             weight_decay=args.weight_decay)

        self.iter_counter = 0

        self.scheduler = LR_Scheduler(args.lr_scheduler, args.lr, args.epochs,
                                      len(self.train_loader))

        # -------------------- Loss --------------------- #

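        # pos_weight up-weights the rare positive (crack / link) pixels to
        # counter class imbalance; note that torch.cuda.FloatTensor requires
        # a GPU, so these losses implicitly assume args.cuda is set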
        self.mask_loss = nn.BCEWithLogitsLoss(
            reduction='mean',
            pos_weight=torch.cuda.FloatTensor([args.pos_pixel_weight]))
        self.connected_loss = nn.BCEWithLogitsLoss(
            reduction='mean',
            pos_weight=torch.cuda.FloatTensor([args.pos_link_weight]))

        self.loss_weight = args.loss_weight
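        # the total loss is presumably mask_loss + loss_weight * connected_loss
        # (an assumption based on the attribute names; the train step is not shown)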

        # logger
        self.log_loss = {}
        self.log_acc = {}
        self.save_pos_acc = -1
        self.save_acc = -1
Example #25
    def __init__(self, args):
        self.args = args

        # Define Saver
        self.saver = Saver(args)
        self.saver.save_experiment_config()
        # Define Tensorboard Summary
        self.summary = TensorboardSummary(self.saver.experiment_dir)
        self.writer = self.summary.create_summary()

        # Define Dataloader
        kwargs = {'num_workers': args.workers, 'pin_memory': True}
        #self.train_loader1, self.train_loader2, self.val_loader, self.test_loader, self.nclass = make_data_loader(args, **kwargs)
        self.train_loader1, self.train_loader2, self.val_loader, self.nclass = make_data_loader(args, **kwargs)
        
        # Define Criterion
        # whether to use class balanced weights
        if args.use_balanced_weights:
            classes_weights_path = os.path.join(Path.db_root_dir(args.dataset), args.dataset+'_classes_weights.npy')
            if os.path.isfile(classes_weights_path):
                weight = np.load(classes_weights_path)
            else:
                weight = calculate_weigths_labels(args.dataset, self.train_loader1, self.nclass)
            weight = torch.from_numpy(weight.astype(np.float32))
        else:
            weight = None
        self.criterion = SegmentationLosses(weight=weight, cuda=args.cuda).build_loss(mode=args.loss_type)

        # Define network
        model = AutoDeeplab(self.nclass, 12, self.criterion, crop_size=self.args.crop_size)
        optimizer = torch.optim.SGD(model.parameters(),
                                    args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)
        self.model, self.optimizer = model, optimizer

        # Using cuda
        if args.cuda:
            self.model = torch.nn.DataParallel(self.model, device_ids=self.args.gpu_ids)
            patch_replication_callback(self.model)
            self.model = self.model.cuda()
            print('cuda finished')

        # Define Evaluator
        self.evaluator = Evaluator(self.nclass)
        # Define lr scheduler
        self.scheduler = LR_Scheduler(args.lr_scheduler, args.lr,
                                      args.epochs, len(self.train_loader1))

        self.architect = Architect (self.model, args)
        # Resuming checkpoint
        self.best_pred = 0.0
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'" .format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            # the cuda and cpu branches were identical; a single load suffices
            self.model.load_state_dict(checkpoint['state_dict'])
            if not args.ft:
                self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.best_pred = checkpoint['best_pred']
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))

        # Clear start epoch if fine-tuning
        if args.ft:
            args.start_epoch = 0
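Several of these trainers fall back to `calculate_weigths_labels` when no cached `*_classes_weights.npy` file exists. A minimal sketch of the usual computation follows; the log-smoothed inverse-frequency formula is an assumption based on the common DeepLab training code, not something shown in this listing:

import numpy as np

def calculate_weights_labels_sketch(dataloader, num_classes):
    # count how many pixels of each class the training set contains
    counts = np.zeros(num_classes)
    for _, target in dataloader:
        labels = target.numpy().astype(int)
        mask = (labels >= 0) & (labels < num_classes)
        counts += np.bincount(labels[mask], minlength=num_classes)
    # log-smoothed inverse frequency: rare classes get larger weights
    frequency = counts / counts.sum()
    weights = 1.0 / np.log(1.02 + frequency)
    return weights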
Example #26
    def __init__(self, args):
        self.args = args

        # Define Saver
        self.saver = Saver(args)
        self.saver.save_experiment_config()
        # Define Tensorboard Summary
        self.summary = TensorboardSummary(self.saver.experiment_dir)
        if not args.test:
            self.writer = self.summary.create_summary()

        # Define Dataloader
        kwargs = {'num_workers': args.workers, 'pin_memory': True}
        self.train_loader, self.val_loader, self.test_loader, self.nclass = make_data_loader(
            args, **kwargs)

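        # norm choices: 'gn' = GroupNorm, 'bn' = BatchNorm (optionally
        # synchronised across GPUs), 'abn' = activated (in-place) BatchNorm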
        if self.args.norm == 'gn':
            norm = gn
        elif self.args.norm == 'bn':
            if self.args.sync_bn:
                norm = syncbn
            else:
                norm = bn
        elif self.args.norm == 'abn':
            if self.args.sync_bn:
                norm = syncabn(self.args.gpu_ids)
            else:
                norm = abn
        else:
            raise ValueError("Unsupported norm '{}'; expected 'gn', 'bn' or 'abn'".format(self.args.norm))

        # Define network
        if self.args.model == 'deeplabv3+':
            model = DeepLab(args=self.args, num_classes=self.nclass)
        elif self.args.model == 'deeplabv3':
            model = DeepLabv3(Norm=self.args.norm,
                              backbone=args.backbone,
                              output_stride=args.out_stride,
                              num_classes=self.nclass,
                              freeze_bn=args.freeze_bn)
        elif self.args.model == 'fpn':
            model = FPN(args=args, num_classes=self.nclass)
        else:
            # without this, an unknown model name leaves `model` unbound
            raise ValueError("Unknown model: {}".format(self.args.model))
        '''
        model.cuda()
        summary(model, input_size=(3, 720, 1280))
        exit()
        '''
        self.classifier = Classifier(self.nclass)

        train_params = [{
            'params': model.get_1x_lr_params(),
            'lr': args.lr
        }, {
            'params': model.get_10x_lr_params(),
            'lr': args.lr * 10
        }]

        # Define Optimizer
        optimizer = torch.optim.SGD(train_params,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay,
                                    nesterov=args.nesterov)

        # Define Criterion
        # whether to use class balanced weights
        if args.use_balanced_weights:
            classes_weights_path = os.path.join(
                Path.db_root_dir(args.dataset),
                args.dataset + '_classes_weights.npy')
            if os.path.isfile(classes_weights_path):
                weight = np.load(classes_weights_path)
            else:
                weight = calculate_weigths_labels(args.dataset,
                                                  self.train_loader,
                                                  self.nclass)
            weight = torch.from_numpy(weight.astype(np.float32))
        else:
            weight = None
        self.criterion = SegmentationLosses(
            weight=weight, cuda=args.cuda).build_loss(mode=args.loss_type)
        self.model, self.optimizer = model, optimizer

        # Define Evaluator
        self.evaluator = Evaluator(self.nclass)
        # Define lr scheduler
        self.scheduler = LR_Scheduler(args.lr_scheduler, args.lr, args.epochs,
                                      len(self.train_loader))

        # Using cuda
        if args.cuda:
            self.model = torch.nn.DataParallel(self.model,
                                               device_ids=self.args.gpu_ids)
            #patch_replication_callback(self.model)
            self.model = self.model.cuda()
            self.classifier = torch.nn.DataParallel(
                self.classifier, device_ids=self.args.gpu_ids)
            self.classifier = self.classifier.cuda()

        # Resuming checkpoint
        self.best_pred = 0.0
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'".format(
                    args.resume))
            checkpoint = torch.load(args.resume)
            if args.ft:
                args.start_epoch = 0
            else:
                args.start_epoch = checkpoint['epoch']

            if args.cuda:
                #self.model.module.load_state_dict(checkpoint['state_dict'])
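                # keep only checkpoint entries whose keys exist in the current
                # model, so architecture changes do not break resuming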
                pretrained_dict = checkpoint['state_dict']
                model_dict = {}
                state_dict = self.model.module.state_dict()
                for k, v in pretrained_dict.items():
                    if k in state_dict:
                        model_dict[k] = v
                state_dict.update(model_dict)
                self.model.module.load_state_dict(state_dict)
            else:
                #self.model.load_state_dict(checkpoint['state_dict'])
                pretrained_dict = checkpoint['state_dict']
                model_dict = {}
                state_dict = self.model.state_dict()
                for k, v in pretrained_dict.items():
                    if k in state_dict:
                        model_dict[k] = v
                state_dict.update(model_dict)
                self.model.load_state_dict(state_dict)
            if not args.ft:
                self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.best_pred = checkpoint['best_pred']
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        elif args.decoder is not None:
            if not os.path.isfile(args.decoder):
                raise RuntimeError(
                    "=> no checkpoint for decoder found at '{}'".format(
                        args.decoder))
            checkpoint = torch.load(args.decoder)
            args.start_epoch = 0  # As every time loads decoder only should be finetuning
            if args.cuda:
                decoder_dict = checkpoint['state_dict']
                model_dict = {}
                state_dict = self.model.module.state_dict()
                for k, v in decoder_dict.items():
                    if 'aspp' not in k:
                        continue
                    if k in state_dict:
                        model_dict[k] = v
                state_dict.update(model_dict)
                self.model.module.load_state_dict(state_dict)
            else:
                raise NotImplementedError("Please USE CUDA!!!")

        if args.classifier is None:
            raise NotImplementedError("Classifier should be loaded")
        else:
            if not os.path.isfile(args.classifier):
                raise RuntimeError(
                    "=> no checkpoint for clasifier found at '{}'".format(
                        args.classifier))
            checkpoint = torch.load(args.classifier)
            s_dict = checkpoint['state_dict']
            model_dict = {}
            state_dict = self.classifier.state_dict()
            for k, v in s_dict.items():
                if k in state_dict:
                    model_dict[k] = v
            state_dict.update(model_dict)
            self.classifier.load_state_dict(state_dict)
            print("Classifier checkpoint successfully loaded")

        # Clear start epoch if fine-tuning
        if args.ft:
            args.start_epoch = 0
Example #27
    def __init__(self, config):

        self.config = config
        self.best_pred = 0.0

        # Define Saver
        self.saver = Saver(config)
        self.saver.save_experiment_config()
        # Define Tensorboard Summary
        self.summary = TensorboardSummary(self.config['training']['tensorboard']['log_dir'])
        self.writer = self.summary.create_summary()
        
        self.train_loader, self.val_loader, self.test_loader, self.nclass = initialize_data_loader(config)
        
        # Define network
        model = DeepLab(num_classes=self.nclass,
                        backbone=self.config['network']['backbone'],
                        output_stride=self.config['image']['out_stride'],
                        sync_bn=self.config['network']['sync_bn'],
                        freeze_bn=self.config['network']['freeze_bn'])

        train_params = [{'params': model.get_1x_lr_params(), 'lr': self.config['training']['lr']},
                        {'params': model.get_10x_lr_params(), 'lr': self.config['training']['lr'] * 10}]

        # Define Optimizer
        optimizer = torch.optim.SGD(train_params, momentum=self.config['training']['momentum'],
                                    weight_decay=self.config['training']['weight_decay'], nesterov=self.config['training']['nesterov'])

        # Define Criterion
        # whether to use class balanced weights
        if self.config['training']['use_balanced_weights']:
            classes_weights_path = os.path.join(self.config['dataset']['base_path'], self.config['dataset']['dataset_name'] + '_classes_weights.npy')
            if os.path.isfile(classes_weights_path):
                weight = np.load(classes_weights_path)
            else:
                weight = calculate_weigths_labels(self.config, self.config['dataset']['dataset_name'], self.train_loader, self.nclass)
            weight = torch.from_numpy(weight.astype(np.float32))
        else:
            weight = None

        self.criterion = SegmentationLosses(weight=weight, cuda=self.config['network']['use_cuda']).build_loss(mode=self.config['training']['loss_type'])
        self.model, self.optimizer = model, optimizer
        
        # Define Evaluator
        self.evaluator = Evaluator(self.nclass)
        # Define lr scheduler
        self.scheduler = LR_Scheduler(self.config['training']['lr_scheduler'], self.config['training']['lr'],
                                            self.config['training']['epochs'], len(self.train_loader))


        # Using cuda
        if self.config['network']['use_cuda']:
            self.model = torch.nn.DataParallel(self.model)
            patch_replication_callback(self.model)
            self.model = self.model.cuda()

        # Resuming checkpoint

        if self.config['training']['weights_initialization']['use_pretrained_weights']:
            if not os.path.isfile(self.config['training']['weights_initialization']['restore_from']):
                raise RuntimeError("=> no checkpoint found at '{}'" .format(self.config['training']['weights_initialization']['restore_from']))

            if self.config['network']['use_cuda']:
                checkpoint = torch.load(self.config['training']['weights_initialization']['restore_from'])
            else:
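                # map_location remaps tensors saved on 'cuda:0' onto the CPU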
                checkpoint = torch.load(self.config['training']['weights_initialization']['restore_from'], map_location={'cuda:0': 'cpu'})

            self.config['training']['start_epoch'] = checkpoint['epoch']

            # the checkpoint was already mapped to the right device above,
            # so a single load covers both the GPU and CPU cases
            self.model.load_state_dict(checkpoint['state_dict'])

#            if not self.config['ft']:
            self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.best_pred = checkpoint['best_pred']
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(self.config['training']['weights_initialization']['restore_from'], checkpoint['epoch']))
def main():

    best_pred = 0.0
    best_acc = 0.0
    best_macro = 0.0
    best_micro = 0.0
    lr = 0.00001
    num_epochs = 100
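    # note: 'writer' used below is assumed to be a module-level tensorboard
    # SummaryWriter created outside this snippet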
    train_data, val_data, trainloader, valloader = make_loader()
    model = make_network()
    criterion = nn.CrossEntropyLoss()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    criterion.to(device)
    train_params = [{
        'params': model.get_1x_lr_params(),
        'lr': lr
    }, {
        'params': model.get_10x_lr_params(),
        'lr': lr * 10
    }]
    optimizer = optim.SGD(train_params,
                          momentum=0.9,
                          weight_decay=5e-4,
                          nesterov=False)
    scheduler = LR_Scheduler(mode='step',
                             base_lr=lr,
                             num_epochs=num_epochs,
                             iters_per_epoch=len(trainloader),
                             lr_step=25)
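    # 'step' mode: the base LR is cut (typically by 10x) every lr_step=25 epochs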

    for epoch in range(num_epochs):
        running_loss = 0.0
        running_correct = 0
        running_total = 0
        acc = 0.0
        micro = 0.0
        macro = 0.0
        count = 0
        model.train()
        for batch_idx, (dataA, dataB, target) in enumerate(trainloader):
            dataA, dataB, target = dataA.to(device), dataB.to(
                device), target.to(device)
            scheduler(optimizer, batch_idx, epoch, best_pred)
            optimizer.zero_grad()
            pred = model(dataA, dataB)
            loss = criterion(pred, target)
            loss.backward()
            optimizer.step()
            predict = torch.argmax(pred, 1)
            a = metrics.accuracy_score(target.cpu(), predict.cpu())
            b = metrics.f1_score(target.cpu(), predict.cpu(), average='micro')
            c = metrics.f1_score(target.cpu(), predict.cpu(), average='macro')
            acc += a
            micro += b
            macro += c
            count += 1
            correct = torch.eq(predict, target).sum().double().item()
            running_loss += loss.item()
            running_correct += correct
            running_total += target.size(0)
        loss = running_loss * 32 / running_total  # 32 = batch size; rescales summed batch-mean losses to a per-sample average
        accuracy = 100 * running_correct / running_total
        acc /= count
        micro /= count
        macro /= count
        writer.add_scalar('scalar/loss_train', loss, epoch)
        writer.add_scalar('scalar/accuracy_train', accuracy, epoch)
        writer.add_scalar('scalar/acc_train', acc, epoch)
        writer.add_scalar('scalar/micro_train', micro, epoch)
        writer.add_scalar('scalar/macro_train', macro, epoch)
        print(
            'Training ',
            'Epoch[%d/%d], loss = %.6f, accuracy = %.4f %%, acc = %.4f, micro = %.4f, macro = %.4f'
            % (epoch + 1, num_epochs, loss, accuracy, acc, micro, macro))
        model.eval()
        with torch.no_grad():
            running_loss = 0.0
            running_correct = 0
            running_total = 0
            acc = 0.0
            micro = 0.0
            macro = 0.0
            count = 0
            for batch_idx, (dataA, dataB, target) in enumerate(valloader):
                dataA, dataB, target = dataA.to(device), dataB.to(
                    device), target.to(device)
                pred = model(dataA, dataB)
                loss = criterion(pred, target)
                predict = torch.argmax(pred, 1)
                a = metrics.accuracy_score(target.cpu(), predict.cpu())
                b = metrics.f1_score(target.cpu(),
                                     predict.cpu(),
                                     average='micro')
                c = metrics.f1_score(target.cpu(),
                                     predict.cpu(),
                                     average='macro')
                correct = torch.eq(predict, target).sum().double().item()
                running_loss += loss.item()
                running_correct += correct
                running_total += target.size(0)
                acc += a
                micro += b
                macro += c
                count += 1
            loss = running_loss * 32 / running_total  # 32 = batch size; per-sample average, as above
            accuracy = 100 * running_correct / running_total
            acc /= count
            micro /= count
            macro /= count
            if acc > best_acc:
                best_acc = acc
            if micro > best_micro:
                best_micro = micro
            if macro > best_macro:
                best_macro = macro
            if accuracy > best_pred:
                best_pred = accuracy
            print(
                'best results: ',
                'best_acc = %.4f, best_micro = %.4f, best_macro = %.4f, best_pred = %.4f'
                % (
                    best_acc,
                    best_micro,
                    best_macro,
                    best_pred,
                ))
            writer.add_scalar('scalar/loss_val', loss, epoch)
            writer.add_scalar('scalar/accuracy_val', accuracy, epoch)
            writer.add_scalar('scalar/acc_val', acc, epoch)
            writer.add_scalar('scalar/micro_val', micro, epoch)
            writer.add_scalar('scalar/macro_val', macro, epoch)
            print(
                'Validation',
                '    Epoch[%d/%d], loss = %.6f, accuracy = %.4f %%, acc = %.4f, micro = %.4f, macro = %.4f, running_total = %d, running_correct = %d'
                % (epoch + 1, num_epochs, loss, accuracy, acc, micro, macro,
                   running_total, running_correct))
    def __init__(self, args):
        self.args = args

        # Define Saver
        self.saver = Saver(args)
        self.saver.save_experiment_config()
        # Define Tensorboard Summary
        self.summary = TensorboardSummary(self.saver.experiment_dir)
        self.writer = self.summary.create_summary()

        # Define Dataloader
        kwargs = {'num_workers': args.workers, 'pin_memory': True}
        self.train_loader, self.val_loader, self.test_loader, self.nclass = make_data_loader(
            args, **kwargs)

        # Define network
        model = DeepLab(num_classes=self.nclass,
                        backbone=args.backbone,
                        output_stride=args.out_stride,
                        sync_bn=args.sync_bn,
                        freeze_bn=args.freeze_bn)

        print(self.nclass, args.backbone, args.out_stride, args.sync_bn,
              args.freeze_bn)
        # example output: 2 resnet 16 False False

        train_params = [{
            'params': model.get_1x_lr_params(),
            'lr': args.lr
        }, {
            'params': model.get_10x_lr_params(),
            'lr': args.lr * 10
        }]

        # Define Optimizer
        optimizer = torch.optim.SGD(train_params,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay,
                                    nesterov=args.nesterov)

        # Define Criterion
        # whether to use class balanced weights
        if args.use_balanced_weights:
            classes_weights_path = os.path.join(
                Path.db_root_dir(args.dataset),
                args.dataset + '_classes_weights.npy')
            if os.path.isfile(classes_weights_path):
                weight = np.load(classes_weights_path)
            else:
                weight = calculate_weigths_labels(args.dataset,
                                                  self.train_loader,
                                                  self.nclass)
            weight = torch.from_numpy(weight.astype(np.float32))
        else:
            weight = None
        self.criterion = SegmentationLosses(
            weight=weight, cuda=args.cuda).build_loss(mode=args.loss_type)
        self.model, self.optimizer = model, optimizer

        # Define Evaluator
        self.evaluator = Evaluator(self.nclass)
        # Define lr scheduler
        self.scheduler = LR_Scheduler(args.lr_scheduler, args.lr, args.epochs,
                                      len(self.train_loader))

        # Using cuda
        if args.cuda:
            self.model = torch.nn.DataParallel(self.model,
                                               device_ids=self.args.gpu_ids)
            patch_replication_callback(self.model)
            self.model = self.model.cuda()

        # Resuming checkpoint
        self.best_pred = 0.0
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'".format(
                    args.resume))
            checkpoint = torch.load(args.resume, map_location='cpu')
            args.start_epoch = checkpoint['epoch']
            if args.cuda:
                self.model.module.load_state_dict(checkpoint['state_dict'])
            else:
                self.model.load_state_dict(checkpoint['state_dict'])
            if not args.ft:
                self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.best_pred = checkpoint['best_pred']
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))

        # Clear start epoch if fine-tuning
        if args.ft:
            args.start_epoch = 0
Example #30
    def __init__(self, config, args):
        self.args = args
        self.config = config
        self.visdom = args.visdom
        if args.visdom:
            self.vis = visdom.Visdom(env=os.getcwd().split('/')[-1], port=8888)
        # Define Dataloader
        self.train_loader, self.val_loader, self.test_loader, self.nclass = make_data_loader(
            config)
        self.target_train_loader, self.target_val_loader, self.target_test_loader, _ = make_target_data_loader(
            config)

        # Define network
        self.model = DeepLab(num_classes=self.nclass,
                             backbone=config.backbone,
                             output_stride=config.out_stride,
                             sync_bn=config.sync_bn,
                             freeze_bn=config.freeze_bn)

        self.D = Discriminator(num_classes=self.nclass, ndf=16)

        train_params = [{
            'params': self.model.get_1x_lr_params(),
            'lr': config.lr
        }, {
            'params': self.model.get_10x_lr_params(),
            'lr': config.lr * config.lr_ratio
        }]

        # Define Optimizer
        self.optimizer = torch.optim.SGD(train_params,
                                         momentum=config.momentum,
                                         weight_decay=config.weight_decay)
        self.D_optimizer = torch.optim.Adam(self.D.parameters(),
                                            lr=config.lr,
                                            betas=(0.9, 0.99))

        # Define Criterion
        # whether to use class balanced weights
        self.criterion = SegmentationLosses(
            weight=None, cuda=args.cuda).build_loss(mode=config.loss)
        self.entropy_mini_loss = MinimizeEntropyLoss()
        self.bottleneck_loss = BottleneckLoss()
        self.instance_loss = InstanceLoss()
        # Define Evaluator
        self.evaluator = Evaluator(self.nclass)
        # Define lr scheduler
        self.scheduler = LR_Scheduler(config.lr_scheduler,
                                      config.lr, config.epochs,
                                      len(self.train_loader), config.lr_step,
                                      config.warmup_epochs)
        self.summary = TensorboardSummary('./train_log')
        # labels for adversarial training
        self.source_label = 0
        self.target_label = 1
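        # the discriminator learns to emit 0 for source-domain predictions and
        # 1 for target-domain predictions; the segmenter is then trained to
        # fool it (the standard adversarial domain-adaptation setup)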

        # Using cuda
        if args.cuda:
            self.model = torch.nn.DataParallel(self.model)
            patch_replication_callback(self.model)
            # cudnn.benchmark = True
            self.model = self.model.cuda()

            self.D = torch.nn.DataParallel(self.D)
            patch_replication_callback(self.D)
            self.D = self.D.cuda()

        self.best_pred_source = 0.0
        self.best_pred_target = 0.0
        # Resuming checkpoint
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'".format(
                    args.resume))
            # map_location is an argument of torch.load, not load_state_dict;
            # mapping to CPU lets CUDA checkpoints load on CPU-only machines
            checkpoint = torch.load(
                args.resume,
                map_location=None if args.cuda else torch.device('cpu'))
            if args.cuda:
                self.model.module.load_state_dict(checkpoint)
            else:
                self.model.load_state_dict(checkpoint)
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, args.start_epoch))