Example #1
import os

import numpy as np
import torch

# Project-local helpers are assumed to be importable from the repository
# (exact module paths depend on its layout): Saver, TensorboardSummary,
# initialize_data_loader, DeepLab, SegmentationLosses,
# calculate_weigths_labels, Evaluator, LR_Scheduler,
# patch_replication_callback.


class Trainer:
    def __init__(self, config):
        self.config = config
        self.best_pred = 0.0

        # Define Saver
        self.saver = Saver(config)
        self.saver.save_experiment_config()
        # Define Tensorboard Summary
        self.summary = TensorboardSummary(self.config['training']['tensorboard']['log_dir'])
        self.writer = self.summary.create_summary()
        
        self.train_loader, self.val_loader, self.test_loader, self.nclass = initialize_data_loader(config)
        
        # Define network
        model = DeepLab(num_classes=self.nclass,
                        backbone=self.config['network']['backbone'],
                        output_stride=self.config['image']['out_stride'],
                        sync_bn=self.config['network']['sync_bn'],
                        freeze_bn=self.config['network']['freeze_bn'])

        train_params = [{'params': model.get_1x_lr_params(), 'lr': self.config['training']['lr']},
                        {'params': model.get_10x_lr_params(), 'lr': self.config['training']['lr'] * 10}]

        # Define Optimizer
        optimizer = torch.optim.SGD(train_params, momentum=self.config['training']['momentum'],
                                    weight_decay=self.config['training']['weight_decay'], nesterov=self.config['training']['nesterov'])

        # Define Criterion
        # whether to use class balanced weights
        if self.config['training']['use_balanced_weights']:
            classes_weights_path = os.path.join(self.config['dataset']['base_path'], self.config['dataset']['dataset_name'] + '_classes_weights.npy')
            if os.path.isfile(classes_weights_path):
                weight = np.load(classes_weights_path)
            else:
                weight = calculate_weigths_labels(self.config, self.config['dataset']['dataset_name'], self.train_loader, self.nclass)
            weight = torch.from_numpy(weight.astype(np.float32))
        else:
            weight = None

        self.criterion = SegmentationLosses(weight=weight, cuda=self.config['network']['use_cuda']).build_loss(mode=self.config['training']['loss_type'])
        self.model, self.optimizer = model, optimizer
        
        # Define Evaluator
        self.evaluator = Evaluator(self.nclass)
        # Define lr scheduler
        self.scheduler = LR_Scheduler(self.config['training']['lr_scheduler'], self.config['training']['lr'],
                                      self.config['training']['epochs'], len(self.train_loader))

        # Using cuda
        if self.config['network']['use_cuda']:
            self.model = torch.nn.DataParallel(self.model)
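            # patch_replication_callback lets synchronized BatchNorm share
            # statistics across the DataParallel replicas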
            patch_replication_callback(self.model)
            self.model = self.model.cuda()

        # Resuming checkpoint

        if self.config['training']['weights_initialization']['use_pretrained_weights']:
            if not os.path.isfile(self.config['training']['weights_initialization']['restore_from']):
                raise RuntimeError("=> no checkpoint found at '{}'".format(self.config['training']['weights_initialization']['restore_from']))

            if self.config['network']['use_cuda']:
                checkpoint = torch.load(self.config['training']['weights_initialization']['restore_from'])
            else:
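                # remap tensors that were saved from GPU onto the CPU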
                checkpoint = torch.load(self.config['training']['weights_initialization']['restore_from'], map_location={'cuda:0': 'cpu'})

            self.config['training']['start_epoch'] = checkpoint['epoch']

            if self.config['network']['use_cuda']:
                # the model is wrapped in DataParallel, so the raw weights
                # live under the .module attribute
                self.model.module.load_state_dict(checkpoint['state_dict'])
            else:
                self.model.load_state_dict(checkpoint['state_dict'])

            self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.best_pred = checkpoint['best_pred']
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(self.config['training']['weights_initialization']['restore_from'], checkpoint['epoch']))
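
The constructor above is driven entirely by a nested config dictionary. For reference, here is a minimal sketch of the keys it reads directly; the values are illustrative placeholders, and Saver and initialize_data_loader may consume additional keys that are not visible in this excerpt.

# Hypothetical config covering the keys Trainer.__init__ accesses.
config = {
    'network': {'backbone': 'resnet', 'sync_bn': False,
                'freeze_bn': False, 'use_cuda': True},
    'image': {'out_stride': 16},
    'dataset': {'base_path': '/path/to/data', 'dataset_name': 'pascal'},
    'training': {
        'lr': 0.007, 'momentum': 0.9, 'weight_decay': 5e-4,
        'nesterov': False, 'use_balanced_weights': False,
        'loss_type': 'ce', 'lr_scheduler': 'poly', 'epochs': 50,
        'tensorboard': {'log_dir': 'runs/experiment'},
        'weights_initialization': {
            'use_pretrained_weights': False,
            'restore_from': 'checkpoints/model_best.pth.tar',
        },
    },
}
trainer = Trainer(config)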
Example #2
import argparse
import os
import pickle as pkl
import random

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
# torch's built-in tensorboard writer is assumed here; tensorboardX has
# the same SummaryWriter(log_dir=...) interface
from torch.utils.tensorboard import SummaryWriter

# Project-local modules (exact import paths depend on the repository
# layout): SEN12MS, DFC2020, DeepLab, UNet, ModelTrainer


def main():

    # define and parse arguments
    parser = argparse.ArgumentParser()

    # general
    parser.add_argument('--experiment_name',
                        type=str,
                        default="experiment",
                        help="experiment name. will be used in the path names \
                             for log- and savefiles")
    parser.add_argument('--seed',
                        type=int,
                        default=None,
                        help='fixes random seed and sets model to \
                              the potentially faster cuDNN deterministic mode \
                              (default: non-deterministic mode)')
    parser.add_argument('--val_freq',
                        type=int,
                        default=1000,
                        help='validation will be run every val_freq \
                        batches/optimization steps during training')
    parser.add_argument('--save_freq',
                        type=int,
                        default=1000,
                        help='training state will be saved every save_freq \
                        batches/optimization steps during training')
    parser.add_argument('--log_freq',
                        type=int,
                        default=100,
                        help='tensorboard logs will be written every log_freq \
                              number of batches/optimization steps')

    # input/output
    parser.add_argument('--use_s2hr',
                        action='store_true',
                        default=False,
                        help='use sentinel-2 high-resolution (10 m) bands')
    parser.add_argument('--use_s2mr',
                        action='store_true',
                        default=False,
                        help='use sentinel-2 medium-resolution (20 m) bands')
    parser.add_argument('--use_s2lr',
                        action='store_true',
                        default=False,
                        help='use sentinel-2 low-resolution (60 m) bands')
    parser.add_argument('--use_s1',
                        action='store_true',
                        default=False,
                        help='use sentinel-1 data')
    parser.add_argument('--no_savanna',
                        action='store_true',
                        default=False,
                        help='ignore class savanna')

    # training hyperparameters
    parser.add_argument('--lr',
                        type=float,
                        default=0.01,
                        help='learning rate (default: 1e-2)')
    parser.add_argument('--momentum',
                        type=float,
                        default=0.9,
                        help='momentum (default: 0.9), only used for deeplab')
    parser.add_argument('--weight_decay',
                        type=float,
                        default=5e-4,
                        help='weight-decay (default: 5e-4)')
    parser.add_argument('--batch_size',
                        type=int,
                        default=32,
                        help='batch size for training and validation \
                              (default: 32)')
    parser.add_argument('--workers',
                        type=int,
                        default=4,
                        help='number of workers for dataloading (default: 4)')
    parser.add_argument('--max_epochs',
                        type=int,
                        default=100,
                        help='number of training epochs (default: 100)')

    # network
    parser.add_argument('--model',
                        type=str,
                        choices=['deeplab', 'unet'],
                        default='deeplab',
                        help="network architecture (default: deeplab)")

    # deeplab-specific
    parser.add_argument('--pretrained_backbone',
                        action='store_true',
                        default=False,
                        help='initialize ResNet-101 backbone with ImageNet \
                              pre-trained weights')
    parser.add_argument('--out_stride',
                        type=int,
                        choices=[8, 16],
                        default=16,
                        help='network output stride (default: 16)')

    # data
    parser.add_argument('--data_dir_train',
                        type=str,
                        default=None,
                        help='path to training dataset')
    parser.add_argument('--dataset_val',
                        type=str,
                        default="sen12ms_holdout",
                        choices=['sen12ms_holdout', 'dfc2020_val',
                                 'dfc2020_test'],
                        help='dataset to use for validation \
                              (default: sen12ms_holdout)')
    parser.add_argument('--data_dir_val',
                        type=str,
                        default=None,
                        help='path to validation dataset')
    parser.add_argument('--log_dir',
                        type=str,
                        default=None,
                        help='path to dir for tensorboard logs \
                              (default runs/CURRENT_DATETIME_HOSTNAME)')

    args = parser.parse_args()
    print("=" * 20, "CONFIG", "=" * 20)
    for arg in vars(args):
        print('{0:20}  {1}'.format(arg, getattr(args, arg)))
    print()

    # fix seeds and set pytorch to deterministic mode
    if args.seed is not None:
        torch.manual_seed(args.seed)
        random.seed(args.seed)
        np.random.seed(args.seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    # set flags for GPU processing if available
    if torch.cuda.is_available():
        args.use_gpu = True
        if torch.cuda.device_count() > 1:
            raise NotImplementedError("multi-gpu training not implemented! " +
                                      "try to run script as: " +
                                      "CUDA_VISIBLE_DEVICES=0 train.py")
    else:
        args.use_gpu = False

    # load datasets
    train_set = SEN12MS(args.data_dir_train,
                        subset="train",
                        no_savanna=args.no_savanna,
                        use_s2hr=args.use_s2hr,
                        use_s2mr=args.use_s2mr,
                        use_s2lr=args.use_s2lr,
                        use_s1=args.use_s1)
    n_classes = train_set.n_classes
    n_inputs = train_set.n_inputs
    if args.dataset_val == "sen12ms_holdout":
        val_set = SEN12MS(args.data_dir_train,
                          subset="holdout",
                          no_savanna=args.no_savanna,
                          use_s2hr=args.use_s2hr,
                          use_s2mr=args.use_s2mr,
                          use_s2lr=args.use_s2lr,
                          use_s1=args.use_s1)
    else:
        dfc2020_subset = args.dataset_val.split("_")[-1]
        val_set = DFC2020(args.data_dir_val,
                          subset=dfc2020_subset,
                          no_savanna=args.no_savanna,
                          use_s2hr=args.use_s2hr,
                          use_s2mr=args.use_s2mr,
                          use_s2lr=args.use_s2lr,
                          use_s1=args.use_s1)

    # set up dataloaders
    train_loader = DataLoader(train_set,
                              batch_size=args.batch_size,
                              shuffle=True,
                              num_workers=args.workers,
                              pin_memory=True,
                              drop_last=False)
    val_loader = DataLoader(val_set,
                            batch_size=args.batch_size,
                            shuffle=False,
                            num_workers=args.workers,
                            pin_memory=True,
                            drop_last=False)

    # set up network
    if args.model == "deeplab":
        model = DeepLab(num_classes=n_classes,
                        backbone='resnet',
                        pretrained_backbone=args.pretrained_backbone,
                        output_stride=args.out_stride,
                        sync_bn=False,
                        freeze_bn=False,
                        n_in=n_inputs)
    else:
        model = UNet(n_classes=n_classes, n_channels=n_inputs)

    if args.use_gpu:
        model = model.cuda()

    # define loss function; pixels labeled 255 are excluded from the loss
    loss_fn = nn.CrossEntropyLoss(ignore_index=255, reduction='mean')

    # set up optimizer
    if args.model == "deeplab":
        train_params = [{
            'params': model.get_1x_lr_params(),
            'lr': args.lr
        }, {
            'params': model.get_10x_lr_params(),
            'lr': args.lr * 10
        }]
        optimizer = torch.optim.SGD(train_params,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)
    else:
        optimizer = torch.optim.RMSprop(model.parameters(),
                                        lr=args.lr,
                                        weight_decay=args.weight_decay)

    # set up tensorboard logging
    if args.log_dir is None:
        args.log_dir = "logs"
    writer = SummaryWriter(
        log_dir=os.path.join(args.log_dir, args.experiment_name))

    # create checkpoint dir
    args.checkpoint_dir = os.path.join(args.log_dir, args.experiment_name,
                                       "checkpoints")
    os.makedirs(args.checkpoint_dir, exist_ok=True)

    # save config
    with open(os.path.join(args.checkpoint_dir, "args.pkl"), "wb") as f:
        pkl.dump(args, f)

    # train network
    step = 0
    trainer = ModelTrainer(args)
    for epoch in range(args.max_epochs):
        print("=" * 20, "EPOCH", epoch + 1, "/", str(args.max_epochs),
              "=" * 20)

        # run training for one epoch
        model, step = trainer.train(model,
                                    train_loader,
                                    val_loader,
                                    loss_fn,
                                    optimizer,
                                    writer,
                                    step=step)

    # export final set of weights
    trainer.export_model(model, args.checkpoint_dir, name="final")
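
Once training finishes, the pickled arguments and the exported weights can be loaded back for evaluation. A minimal sketch, assuming export_model(model, dir, name="final") writes the weights to final.pth inside the checkpoint directory; the actual filename is decided by project code not shown above.

import os
import pickle as pkl

import torch

checkpoint_dir = "logs/experiment/checkpoints"

# restore the argparse namespace that main() dumped alongside the checkpoints
with open(os.path.join(checkpoint_dir, "args.pkl"), "rb") as f:
    train_args = pkl.load(f)
print(train_args)

# load the exported weights onto the CPU (filename is an assumption)
state = torch.load(os.path.join(checkpoint_dir, "final.pth"),
                   map_location="cpu")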