Example #1
from functools import partial

import mxnet as mx
from easydict import EasyDict as edict
from mxnet.gluon.data.vision import transforms

# `get_unet_model` and `args` are assumed to be provided by the surrounding
# module (a project-local model factory and the parsed CLI arguments).


def init_model():
    model_cfg = edict()
    model_cfg.syncbn = True

    model_cfg.input_normalization = {
        'mean': [0.5, 0.5, 0.5],
        'std': [0.5, 0.5, 0.5]
    }

    model_cfg.input_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(model_cfg.input_normalization['mean'],
                             model_cfg.input_normalization['std']),
    ])

    # Use synchronized BatchNorm when training on more than one GPU.
    if args.ngpus > 1 and model_cfg.syncbn:
        norm_layer = partial(mx.gluon.contrib.nn.SyncBatchNorm,
                             num_devices=args.ngpus)
    else:
        norm_layer = mx.gluon.nn.BatchNorm

    model = get_unet_model(norm_layer)
    model.initialize(mx.init.Xavier(rnd_type='gaussian', magnitude=1),
                     ctx=mx.cpu(0))

    return model, model_cfg
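
For context, a minimal driver sketch for this MXNet variant. It assumes `args` is a module-level argparse namespace (the snippet above reads `args.ngpus` as a global); the CLI flag below is hypothetical and only covers the one field `init_model()` actually uses.

# Hypothetical driver; only args.ngpus is read by init_model() above.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--ngpus', type=int, default=1)
args = parser.parse_args()

model, model_cfg = init_model()
print(model_cfg.input_normalization)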
Example #2
import torch
from easydict import EasyDict as edict
from torchvision import transforms

# `get_unet_model` and `initializer` (which provides XavierGluon) are
# assumed to be project-local imports.


def init_model():
    model_cfg = edict()

    model_cfg.input_normalization = {
        'mean': [0.5, 0.5, 0.5],
        'std': [0.5, 0.5, 0.5]
    }

    model_cfg.input_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(model_cfg.input_normalization['mean'],
                             model_cfg.input_normalization['std']),
    ])

    # training using DataParallel is not implemented
    norm_layer = torch.nn.BatchNorm2d

    model = get_unet_model(norm_layer=norm_layer)
    model.apply(initializer.XavierGluon(rnd_type='gaussian', magnitude=1.0))

    return model, model_cfg
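
A quick smoke-test sketch for the PyTorch variant, assuming the model returned by `get_unet_model` accepts a normalized NCHW float tensor; the exact output structure depends on the project's U-Net and is not specified here.

# Hypothetical smoke test: push one RGB image through the model.
import torch
from PIL import Image

model, model_cfg = init_model()
model.eval()

image = Image.open('example.jpg').convert('RGB')   # any RGB input image
x = model_cfg.input_transform(image).unsqueeze(0)  # shape: 1 x 3 x H x W
with torch.no_grad():
    output = model(x)  # output structure depends on get_unet_model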
Example #3
    # Module-level imports assumed: deepcopy (copy), partial (functools),
    # torch, albumentations (Compose, Blur, IAAAdditiveGaussianNoise, Flip),
    # plus project-local AdaptISTrainer, AdaptiveIoU, get_unet_model,
    # initializer, and log.
    def train(self, train_proposals=False, start_epoch=0):
        num_points = self.num_points if not train_proposals else self.num_points_prop
        num_epochs = self.num_epochs if not train_proposals else self.num_epochs_prop
        loss_cfg = self.loss_cfg if not train_proposals else self.loss_cfg_prop

        val_loss_cfg = deepcopy(loss_cfg)
        # training using DataParallel is not implemented
        norm_layer = torch.nn.BatchNorm2d
        self.model = get_unet_model(loss_cfg, val_loss_cfg, norm_layer=norm_layer)
        self.model.apply(initializer.XavierGluon(rnd_type='gaussian', magnitude=1.0))
        
        # Augmentation
        train_augmentator = Compose([
            Blur(blur_limit=(2, 4)),
            IAAAdditiveGaussianNoise(scale=(10, 40), p=0.5),
            Flip()
        ], p=1.0)
        
        # Datasets
        trainset = self.dataset(
            self.args.dataset_path,
            split='train',
            num_points=num_points,
            augmentator=train_augmentator,
            with_segmentation=True,
            points_from_one_object=train_proposals,
            input_transform=self.model_cfg.input_transform
        )
    
        # Note: the validation split reuses the training augmentator.
        valset = self.dataset(
            self.args.dataset_path,
            split='test',
            augmentator=train_augmentator,
            num_points=num_points,
            with_segmentation=True,
            points_from_one_object=train_proposals,
            input_transform=self.model_cfg.input_transform
        )
    
        # Other Settings
        optimizer_params = {
            'lr': 5e-4, 'betas': (0.9, 0.999), 'eps': 1e-8
        }
    
        # Both training phases use the same cosine annealing schedule.
        lr_scheduler = partial(torch.optim.lr_scheduler.CosineAnnealingLR,
                               last_epoch=-1)
        
        # Set up the AdaptISTrainer
        trainer = AdaptISTrainer(self.args, self.model, self.model_cfg, loss_cfg,
                                 trainset, valset,
                                 num_epochs=num_epochs,
                                 optimizer_params=optimizer_params,
                                 lr_scheduler=lr_scheduler,
                                 checkpoint_interval=40 if not train_proposals else 5,
                                 image_dump_interval=600 if not train_proposals else -1,
                                 train_proposals=train_proposals,
                                 metrics=[AdaptiveIoU()])
    
        log.logger.info(f'Starting Epoch: {start_epoch}')
        log.logger.info(f'Total Epochs: {num_epochs}')
        for epoch in range(start_epoch, num_epochs):
            trainer.training(epoch)
            trainer.validation(epoch)
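
For completeness, a sketch of how this method is typically driven: a two-phase schedule that trains the main network first and then the proposal branch, matching the `train_proposals` flag above. `ToyTrainer` is a hypothetical name for the class that owns `train()`.

# Hypothetical two-phase run; ToyTrainer stands in for the owning class.
runner = ToyTrainer(args)
runner.train(train_proposals=False)  # phase 1: segmentation training
runner.train(train_proposals=True)   # phase 2: proposal branch only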