Example #1
    def __init__(self, config, args):
        self.args = args
        self.config = config
        # Define Dataloader
        self.train_loader, self.val_loader, self.test_loader = make_data_loader(
            config)

        # Define network
        #self.model = DeepLab(num_classes=self.nclass,
        #                backbone=config.backbone,
        #                output_stride=config.out_stride,
        #                sync_bn=config.sync_bn,
        #                freeze_bn=config.freeze_bn)
        self.model = UNet(n_channels=1, n_classes=3, bilinear=True)

        #train_params = [{'params': self.model.get_1x_lr_params(), 'lr': config.lr},
        #                {'params': self.model.get_10x_lr_params(), 'lr': config.lr * config.lr_ratio}]

        # Define Optimizer
        self.optimizer = torch.optim.SGD(self.model.parameters(),
                                         lr=config.lr,
                                         momentum=config.momentum,
                                         weight_decay=config.weight_decay)

        # Define Criterion
        # whether to use class balanced weights
        self.criterion = MSELoss(cuda=args.cuda)
        # Define lr scheduler
        self.scheduler = LR_Scheduler(config.lr_scheduler,
                                      config.lr, config.epochs,
                                      len(self.train_loader), config.lr_step,
                                      config.warmup_epochs)
        self.summary = TensorboardSummary('./train_log')

        # Using cuda
        if args.cuda:
            self.model = torch.nn.DataParallel(self.model)
            patch_replication_callback(self.model)
            # cudnn.benchmark = True
            self.model = self.model.cuda()

        self.best_pred_source = 0.0
        # Resuming checkpoint
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'".format(
                    args.resume))
            # map_location is an argument of torch.load, not of load_state_dict
            checkpoint = torch.load(
                args.resume,
                map_location=None if args.cuda else torch.device('cpu'))
            if args.cuda:
                self.model.module.load_state_dict(checkpoint)
            else:
                self.model.load_state_dict(checkpoint)
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, args.start_epoch))
Example #2
    def __init__(self, config, args):
        self.args = args
        self.config = config
        self.vis = visdom.Visdom(env=os.getcwd().split('/')[-1])
        # Define Dataloader
        self.train_loader, self.val_loader, self.test_loader, self.nclass = make_data_loader(config)

        # Define network
        model = DeepLab(num_classes=self.nclass,
                        backbone=config.backbone,
                        output_stride=config.out_stride,
                        sync_bn=config.sync_bn,
                        freeze_bn=config.freeze_bn)

        train_params = [{'params': model.get_1x_lr_params(), 'lr': config.lr},
                        {'params': model.get_10x_lr_params(), 'lr': config.lr * 10}]

        # Define Optimizer
        optimizer = torch.optim.SGD(train_params, momentum=config.momentum,
                                    weight_decay=config.weight_decay)

        # Define Criterion
        # whether to use class balanced weights
        self.criterion = SegmentationLosses(weight=None, cuda=args.cuda).build_loss(mode=config.loss)
        self.model, self.optimizer = model, optimizer

        # Define Evaluator
        self.evaluator = Evaluator(self.nclass)
        # Define lr scheduler
        self.scheduler = LR_Scheduler(config.lr_scheduler, config.lr,
                                      config.T, len(self.train_loader),
                                      config.lr_step, config.warmup_epochs)

        # Using cuda
        if args.cuda:
            self.model = torch.nn.DataParallel(self.model)
            patch_replication_callback(self.model)
            # cudnn.benchmark = True
            self.model = self.model.cuda()

        # Resuming checkpoint
        self.best_pred = 0.0
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'".format(args.resume))
            # map_location is an argument of torch.load, not of load_state_dict
            checkpoint = torch.load(args.resume,
                                    map_location=None if args.cuda else torch.device('cpu'))
            if args.cuda:
                self.model.module.load_state_dict(checkpoint)
            else:
                self.model.load_state_dict(checkpoint)
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, args.start_epoch))
Example #3
    def __init__(self, args):
        self.args = args

        # Define Dataloader
        kwargs = {'num_workers': args.workers, 'pin_memory': True}
        self.test_loader, self.nclass = make_data_loader(args, **kwargs)

        # Define Network
        model = DeepLab(num_classes=self.nclass,
                        backbone=args.backbone,
                        output_stride=args.out_stride,
                        sync_bn=args.sync_bn,
                        freeze_bn=args.freeze_bn)

        self.model = model

        # Using cuda
        if args.cuda:
            self.model = torch.nn.DataParallel(self.model,
                                               device_ids=self.args.gpu_ids)
            patch_replication_callback(self.model)
            self.model = self.model.cuda()

        # Load checkpoint
        if args.checkpoint_path is not None:
            if not os.path.isfile(args.checkpoint_path):
                raise RuntimeError("=> no checkpoint found at '{}'".format(
                    args.checkpoint_path))
            checkpoint = torch.load(args.checkpoint_path)
            if args.cuda:
                self.model.module.load_state_dict(checkpoint['state_dict'])
            else:
                self.model.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}'".format(args.checkpoint_path))

            # Threshold
            if self.args.threshold:
                self.threshold = self.args.threshold
            else:
                self.threshold = checkpoint['best_thresh']
            print("Using Threshold: {}".format(self.threshold))

        # TTA function and object
        self.tta_function = get_tta_function(self.args)
        self.tta = TTA(model=self.model, tta_function=self.tta_function)
Example #4
def model_accelerate(args, model):
    r"""
    Use it with a provided, customized data parallel wrapper:

    from sync_batchnorm import SynchronizedBatchNorm1d, DataParallelWithCallback

    sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
    sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
    Or, if you are using a customized data parallel module, you can use this library as a monkey patching.

    from torch.nn import DataParallel  # or your customized DataParallel module
    from sync_batchnorm import SynchronizedBatchNorm1d, patch_replication_callback

    sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
    sync_bn = DataParallel(sync_bn, device_ids=[0, 1])
    patch_replication_callback(sync_bn)  # monkey-patching
    You can use convert_model to convert your model to use Synchronized BatchNorm easily.

    from torchvision import models
    from sync_batchnorm import convert_model

    m = models.resnet18(pretrained=True)
    m = convert_model(m)
    :param args:
    :param model:
    :return:
    """
    from model.sync_batchnorm.replicate import patch_replication_callback, DataParallelWithCallback
    from model.sync_batchnorm import convert_model
    if torch.cuda.device_count() > 0 and args.gpu > 0:
        model = convert_model(model)
        model = torch.nn.DataParallel(model)
        patch_replication_callback(model)
        device = get_device(args)
        model = model.to(device)
        print(f'*** {model.__class__.__name__} to GPUs, syncbatch OK.')
    else:
        model = nn.DataParallel(model)
    return model
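A minimal usage sketch for model_accelerate (not part of the original snippet): it assumes only that args carries a gpu count and that the repo's own sync_batchnorm package and get_device helper are importable; the tiny network here is a hypothetical stand-in.

# Hypothetical usage sketch; `args` only needs a `gpu` attribute here, and
# model_accelerate resolves the target device via the repo's get_device helper.
from types import SimpleNamespace

import torch
from torch import nn

args = SimpleNamespace(gpu=torch.cuda.device_count())
model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())
model = model_accelerate(args, model)
# On a multi-GPU machine every BatchNorm layer is converted to its
# synchronized variant and the model is wrapped in DataParallel.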
Example #5
    def __init__(self, config, args):
        self.args = args
        self.config = config
        self.visdom = args.visdom
        if args.visdom:
            self.vis = visdom.Visdom(env=os.getcwd().split('/')[-1], port=8888)
        # Define Dataloader
        self.train_loader, self.val_loader, self.test_loader, self.nclass = make_data_loader(
            config)
        self.target_train_loader, self.target_val_loader, self.target_test_loader, _ = make_target_data_loader(
            config)

        # Define network
        self.model = DeepLab(num_classes=self.nclass,
                             backbone=config.backbone,
                             output_stride=config.out_stride,
                             sync_bn=config.sync_bn,
                             freeze_bn=config.freeze_bn)

        self.D = Discriminator(num_classes=self.nclass, ndf=16)

        train_params = [{
            'params': self.model.get_1x_lr_params(),
            'lr': config.lr
        }, {
            'params': self.model.get_10x_lr_params(),
            'lr': config.lr * config.lr_ratio
        }]

        # Define Optimizer
        self.optimizer = torch.optim.SGD(train_params,
                                         momentum=config.momentum,
                                         weight_decay=config.weight_decay)
        self.D_optimizer = torch.optim.Adam(self.D.parameters(),
                                            lr=config.lr,
                                            betas=(0.9, 0.99))

        # Define Criterion
        # whether to use class balanced weights
        self.criterion = SegmentationLosses(
            weight=None, cuda=args.cuda).build_loss(mode=config.loss)
        self.entropy_mini_loss = MinimizeEntropyLoss()
        self.bottleneck_loss = BottleneckLoss()
        self.instance_loss = InstanceLoss()
        # Define Evaluator
        self.evaluator = Evaluator(self.nclass)
        # Define lr scheduler
        self.scheduler = LR_Scheduler(config.lr_scheduler,
                                      config.lr, config.epochs,
                                      len(self.train_loader), config.lr_step,
                                      config.warmup_epochs)
        self.summary = TensorboardSummary('./train_log')
        # labels for adversarial training
        self.source_label = 0
        self.target_label = 1

        # Using cuda
        if args.cuda:
            self.model = torch.nn.DataParallel(self.model)
            patch_replication_callback(self.model)
            # cudnn.benchmark = True
            self.model = self.model.cuda()

            self.D = torch.nn.DataParallel(self.D)
            patch_replication_callback(self.D)
            self.D = self.D.cuda()

        self.best_pred_source = 0.0
        self.best_pred_target = 0.0
        # Resuming checkpoint
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'".format(
                    args.resume))
            # map_location is an argument of torch.load, not of load_state_dict
            checkpoint = torch.load(
                args.resume,
                map_location=None if args.cuda else torch.device('cpu'))
            if args.cuda:
                self.model.module.load_state_dict(checkpoint)
            else:
                self.model.load_state_dict(checkpoint)
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, args.start_epoch))
Example #6
    def __init__(self, args):
        self.args = args

        # Define Saver
        self.saver = Saver(args)
        self.saver.save_experiment_config()
        # Define Tensorboard Summary
        self.summary = TensorboardSummary(self.saver.experiment_dir)
        self.writer = self.summary.create_summary()

        # Define Dataloader
        kwargs = {'num_workers': args.workers, 'pin_memory': True}
        self.train_loader, self.val_loader, self.test_loader, self.nclass = make_data_loader(
            args, **kwargs)

        # Define network
        model = DeepLab(num_classes=self.nclass,
                        backbone=args.backbone,
                        output_stride=args.out_stride,
                        sync_bn=args.sync_bn,
                        freeze_bn=args.freeze_bn)

        train_params = [{
            'params': model.get_1x_lr_params(),
            'lr': args.lr
        }, {
            'params': model.get_10x_lr_params(),
            'lr': args.lr * 10
        }]

        # Define Optimizer
        optimizer = torch.optim.SGD(train_params,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay,
                                    nesterov=args.nesterov)

        # Define Criterion
        # whether to use class balanced weights
        if args.use_balanced_weights:
            classes_weights_path = os.path.join(
                Path.db_root_dir(args.dataset),
                args.dataset + '_classes_weights.npy')
            if os.path.isfile(classes_weights_path):
                weight = np.load(classes_weights_path)
            else:
                weight = calculate_weigths_labels(args.dataset,
                                                  self.train_loader,
                                                  self.nclass)
            weight = torch.from_numpy(weight.astype(np.float32))
        else:
            weight = None
        self.criterion = SegmentationLosses(
            weight=weight, cuda=args.cuda).build_loss(mode=args.loss_type)
        self.model, self.optimizer = model, optimizer

        # Define Evaluator
        self.evaluator = Evaluator(self.nclass)
        # Define lr scheduler
        self.scheduler = LR_Scheduler(args.lr_scheduler, args.lr, args.epochs,
                                      len(self.train_loader))

        # Using cuda
        if args.cuda:
            self.model = torch.nn.DataParallel(self.model,
                                               device_ids=self.args.gpu_ids)
            patch_replication_callback(self.model)
            self.model = self.model.cuda()

        # Resuming checkpoint
        self.best_pred = 0.0
        self.best_thresh = 0.0
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'".format(
                    args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            if args.cuda:
                self.model.module.load_state_dict(checkpoint['state_dict'])
            else:
                self.model.load_state_dict(checkpoint['state_dict'])
            if not args.ft:
                self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.best_pred = checkpoint['best_pred']
            self.best_thresh = checkpoint['best_thresh']
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))

        # Clear start epoch if fine-tuning
        if args.ft:
            args.start_epoch = 0
Example #7
    train_params = [{
        'params': model.get_1x_lr_params(),
        'lr': params.learning_rate
    }, {
        'params': model.get_10x_lr_params(),
        'lr': params.learning_rate * 10
    }]

    optimizer = optim.SGD(train_params,
                          momentum=params.momentum,
                          weight_decay=params.weight_decay)

    if params.cuda:
        model = nn.DataParallel(model, device_ids=[0])
        patch_replication_callback(model)
        model = model.cuda()

    scheduler = LR_Scheduler("poly", params.learning_rate, params.num_epochs,
                             len(train_dl))

    loss_fns = loss_fns

    # Define Tensorboard Summary
    summary = TensorboardSummary(args.model_dir)
    writer = summary.create_summary()

    evaluator = Evaluator(20 + 1)

    logging.info("Starting training for {} epoch(s)".format(params.num_epochs))
    train_and_evaluate(model, train_dl, val_dl, optimizer, loss_fns, scheduler,
Example #8
    def __init__(self, args, model, train_set, val_set, test_set, class_weights, saver, writer):
        self.args = args
        self.saver = saver
        self.saver.save_experiment_config()  # save cfgs
        self.writer = writer

        self.num_classes = train_set.num_classes

        # dataloaders
        kwargs = {'num_workers': args.workers, 'pin_memory': True}
        self.train_dataloader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True, **kwargs)
        self.val_dataloader = DataLoader(val_set, batch_size=args.batch_size, shuffle=False, **kwargs)
        self.test_dataloader = DataLoader(test_set, batch_size=args.batch_size, shuffle=False, **kwargs)

        self.dataset_size = {'train': len(train_set), 'val': len(val_set), 'test': len(test_set)}
        print('dataset size:', self.dataset_size)

        # To speed up training, reduce the number of iterations per epoch;
        # this is better than truncating the dataset when the samples are loaded
        self.iters_per_epoch = args.iters_per_epoch if args.iters_per_epoch else len(self.train_dataloader)

        if args.optimizer == 'SGD':
            print('Using SGD')
            self.optimizer = torch.optim.SGD(model.parameters(),
                                             lr=args.lr,
                                             momentum=args.momentum,
                                             weight_decay=args.weight_decay,
                                             nesterov=args.nesterov)
            self.lr_scheduler = LR_Scheduler(mode=args.lr_scheduler, base_lr=args.lr,
                                             lr_step=args.lr_step,
                                             num_epochs=args.epochs,
                                             warmup_epochs=args.warmup_epochs,
                                             iters_per_epoch=self.iters_per_epoch)
        elif args.optimizer == 'Adam':
            print('Using Adam')
            self.optimizer = torch.optim.Adam(model.parameters(),
                                              lr=args.lr,
                                              # amsgrad=True,
                                              weight_decay=args.weight_decay)
        else:
            raise NotImplementedError

        # args.gpu_ids is a comma-separated string such as "0" or "0,1";
        # use the first id as the primary device
        gpu_ids = [int(s) for s in str(args.gpu_ids).split(',')]
        self.device = torch.device(f'cuda:{gpu_ids[0]}')

        if len(gpu_ids) > 1:
            args.gpu_ids = gpu_ids
            model = torch.nn.DataParallel(model, device_ids=gpu_ids)
            patch_replication_callback(model)
            print(gpu_ids)

        self.model = model.to(self.device)

        # loss
        if args.use_balanced_weights:
            weight = torch.from_numpy(class_weights.astype(np.float32)).to(self.device)
        else:
            weight = None

        self.criterion = SegmentationLosses(mode=args.loss_type, weight=weight, ignore_index=constants.BG_INDEX)

        # evaluator
        self.evaluator = Evaluator(self.num_classes)

        self.best_epoch = 0
        self.best_mIoU = 0.0
        self.best_pixelAcc = 0.0
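Distilled from the examples above, the recurring CUDA setup is: convert BatchNorm, wrap in DataParallel, re-hook replication, move to GPU. A minimal sketch follows; the network is a placeholder, and the sync_batchnorm import path is assumed from the docstring in Example #4 (it may differ per repo).

# Sketch of the pattern shared by the examples above (assumed imports).
import torch
from torch import nn
from sync_batchnorm import convert_model, patch_replication_callback

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())

if torch.cuda.is_available():
    model = convert_model(model)        # BatchNorm -> SynchronizedBatchNorm
    model = nn.DataParallel(model)      # replicate across all visible GPUs
    patch_replication_callback(model)   # re-hook replication so SyncBN shares statistics
    model = model.cuda()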