    def __init__(self):
        # Define Dataloader
        kwargs = {'num_workers': settings.workers, 'pin_memory': True}
        self.train_loader, self.val_loader, self.test_loader, self.nclass = make_data_loader(
            **kwargs)

        # Define network
        self.model = DeepLabv3_plus(nInputChannels=3,
                                    n_classes=self.nclass,
                                    os=16,
                                    pretrained=settings.pretrained,
                                    _print=True)

        # Define Criterion
        # whether to use class balanced weights
        if settings.use_balanced_weights:
            classes_weights_path = os.path.join(
                settings.root_dir, settings.dataset + '_classes_weights.npy')
            if os.path.isfile(classes_weights_path):
                weight = np.load(classes_weights_path)
            else:
                weight = calculate_weigths_labels(settings.dataset,
                                                  self.train_loader,
                                                  self.nclass)
            weight = torch.from_numpy(weight.astype(np.float32))
        else:
            weight = None
        self.criterion = SegmentationLosses(
            weight=weight,
            cuda=settings.cuda).build_loss(mode=settings.loss_type)

        # Define Evaluator
        self.evaluator = Evaluator(self.nclass)

        # Using cuda
        if settings.cuda:
            self.model = torch.nn.DataParallel(self.model,
                                               device_ids=settings.gpu_ids)
            patch_replication_callback(self.model)
            self.model = self.model.cuda()

        # Resuming checkpoint
        self.best_pred = 0.0
        # This variant always loads a checkpoint, even when settings.resume is False.
        if not settings.resume:
            print("settings.resume is False, but a checkpoint is still required; ignoring...")
        if not os.path.isfile(settings.checkpoint):
            raise RuntimeError(
                "=> no checkpoint found at '{}'. "
                "Please set settings.checkpoint to a pretrained weights file "
                "('*.pth.tar').".format(settings.checkpoint))
        checkpoint = torch.load(settings.checkpoint)
        settings.start_epoch = checkpoint['epoch']
        if settings.cuda:
            self.model.module.load_state_dict(checkpoint['state_dict'])
        else:
            self.model.load_state_dict(checkpoint['state_dict'])
        # if not settings.ft:
        #     self.optimizer.load_state_dict(checkpoint['optimizer'])
        self.best_pred = checkpoint['best_pred']
        print("=> loaded checkpoint '{}' (epoch {})".format(
            settings.checkpoint, checkpoint['epoch']))
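The class-balanced branch above loads cached weights from a .npy file when present and otherwise computes them from the training set via calculate_weigths_labels. As a point of reference, here is a minimal sketch of what such a computation can look like, assuming the ENet-style 1/ln(1.02 + frequency) weighting that implementations in this family commonly use; the function name, the dict-style batch keys, and the save step are assumptions for illustration, not the exact code behind calculate_weigths_labels:

import numpy as np
from tqdm import tqdm

def sketch_calculate_class_weights(dataloader, num_classes, save_path=None):
    # Count how many labelled pixels of each class appear in the training set.
    counts = np.zeros(num_classes, dtype=np.float64)
    for sample in tqdm(dataloader):
        labels = sample['label'].numpy()                 # assumed dict-style batches
        valid = (labels >= 0) & (labels < num_classes)   # skip ignore labels (e.g. 255)
        counts += np.bincount(labels[valid].astype(np.int64),
                              minlength=num_classes)
    # ENet-style weighting: rarer classes receive larger weights.
    frequency = counts / counts.sum()
    weights = 1.0 / np.log(1.02 + frequency)
    if save_path is not None:
        np.save(save_path, weights)                      # cache for the isfile() branch above
    return weights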
Example #2
def load_model(model_path):
    model = DeepLabv3_plus(nInputChannels=3, n_classes=NUM_CLASSES, os=16)
    if CUDA:
        model = torch.nn.DataParallel(model, device_ids=[0])
        patch_replication_callback(model)
        model = model.cuda()
    # Check the path that was actually passed in, not the global MODEL_PATH.
    if not osp.isfile(model_path):
        raise RuntimeError("=> no checkpoint found at '{}'".format(model_path))
    checkpoint = torch.load(model_path)
    if CUDA:
        model.module.load_state_dict(checkpoint['state_dict'])
    else:
        model.load_state_dict(checkpoint['state_dict'])
    print("=> loaded checkpoint '{}' (epoch: {}, best_pred: {})"
          .format(model_path, checkpoint['epoch'], checkpoint['best_pred']))
    model.eval()
    return model
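A possible usage sketch for inference with load_model; the checkpoint path, the image file, the normalization constants, and the preprocessing steps are assumptions modeled on typical DeepLab pipelines, not code from this repository:

import numpy as np
import torch
from PIL import Image

model = load_model('checkpoints/deeplab_best.pth.tar')  # hypothetical path

# Preprocess one RGB image into a normalized NCHW float tensor.
img = np.asarray(Image.open('example.jpg').convert('RGB'), dtype=np.float32) / 255.0
mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)  # assumed ImageNet stats
std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
x = torch.from_numpy(((img - mean) / std).transpose(2, 0, 1)).unsqueeze(0)
if CUDA:
    x = x.cuda()

with torch.no_grad():
    logits = model(x)                                     # (1, NUM_CLASSES, H, W)
pred = logits.argmax(dim=1).squeeze(0).cpu().numpy()      # per-pixel class ids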
Example #3
    def __init__(self):
        # Define Saver
        self.saver = Saver()
        self.saver.save_experiment_config()
        # Define Tensorboard Summary
        self.summary = TensorboardSummary(self.saver.experiment_dir)
        self.writer = self.summary.create_summary()
        
        # Define Dataloader
        kwargs = {'num_workers': settings.workers, 'pin_memory': True}
        self.train_loader, self.val_loader, self.test_loader, self.nclass = make_data_loader(**kwargs)

        # Define network
        model = DeepLabv3_plus(nInputChannels=3, n_classes=self.nclass, os=16, pretrained=settings.pretrained, _print=True)

        # Separate parameter groups: backbone at the base lr, decoder at 10x.
        train_params = [{'params': get_1x_lr_params(model), 'lr': settings.lr},
                        {'params': get_10x_lr_params(model), 'lr': settings.lr * 10}]

        # Define Optimizer
        # optimizer = torch.optim.SGD(train_params, momentum=settings.momentum,
        #                             weight_decay=settings.weight_decay, nesterov=settings.nesterov)
        optimizer = torch.optim.Adam(train_params)

        # Define Criterion
        # whether to use class balanced weights
        if settings.use_balanced_weights:
            classes_weights_path = os.path.join(settings.root_dir, settings.dataset+'_classes_weights.npy')
            if os.path.isfile(classes_weights_path):
                weight = np.load(classes_weights_path)
            else:
                weight = calculate_weigths_labels(settings.dataset, self.train_loader, self.nclass)
            weight = torch.from_numpy(weight.astype(np.float32))
        else:
            weight = None
        self.criterion = SegmentationLosses(weight=weight, cuda=settings.cuda).build_loss(mode=settings.loss_type)
        self.model, self.optimizer = model, optimizer
        
        # Define Evaluator
        self.evaluator = Evaluator(self.nclass)
        # Define lr scheduler
        # self.scheduler = LR_Scheduler(settings.lr_scheduler, settings.lr,
        #                                     settings.epochs, len(self.train_loader))

        # Using cuda
        if settings.cuda:
            self.model = torch.nn.DataParallel(self.model, device_ids=settings.gpu_ids)
            patch_replication_callback(self.model)
            self.model = self.model.cuda()

        # Resuming checkpoint
        self.best_pred = 0.0
        if settings.resume:
            if not os.path.isfile(settings.checkpoint):
                raise RuntimeError("=> no checkpoint found at '{}'".format(settings.checkpoint))
            checkpoint = torch.load(settings.checkpoint)
            settings.start_epoch = checkpoint['epoch']
            if settings.cuda:
                self.model.module.load_state_dict(checkpoint['state_dict'])
            else:
                self.model.load_state_dict(checkpoint['state_dict'])
            if not settings.ft:
                self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.best_pred = checkpoint['best_pred']
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(settings.checkpoint, checkpoint['epoch']))

        # Clear start epoch if fine-tuning
        if settings.ft:
            settings.start_epoch = 0
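For context, a minimal sketch of the training epoch such a Trainer typically drives; the method name, the dict-style batch keys, and the logging call are assumptions modeled on common pytorch-deeplab-xception-style trainers, not methods defined in the snippet above:

    def training(self, epoch):
        # One pass over the training set using the pieces wired up in __init__.
        self.model.train()
        train_loss = 0.0
        for i, sample in enumerate(self.train_loader):
            image, target = sample['image'], sample['label']  # assumed batch keys
            if settings.cuda:
                image, target = image.cuda(), target.cuda()
            self.optimizer.zero_grad()
            output = self.model(image)
            loss = self.criterion(output, target)
            loss.backward()
            self.optimizer.step()
            train_loss += loss.item()
            self.writer.add_scalar('train/total_loss_iter', loss.item(),
                                   i + epoch * len(self.train_loader))
        print('[Epoch %d] train loss: %.3f' % (epoch, train_loss))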