def __init__(self, config):
    self.config = config
    self.best_pred = 0.0

    # Define Saver
    self.saver = Saver(config)
    self.saver.save_experiment_config()

    # Define Tensorboard Summary
    self.summary = TensorboardSummary(
        self.config['training']['tensorboard']['log_dir'])
    self.writer = self.summary.create_summary()

    self.train_loader, self.val_loader, self.test_loader, self.nclass = \
        initialize_data_loader(config)

    # Define network
    model = DeepLab(num_classes=self.nclass,
                    backbone=self.config['network']['backbone'],
                    output_stride=self.config['image']['out_stride'],
                    sync_bn=self.config['network']['sync_bn'],
                    freeze_bn=self.config['network']['freeze_bn'])

    train_params = [{'params': model.get_1x_lr_params(),
                     'lr': self.config['training']['lr']},
                    {'params': model.get_10x_lr_params(),
                     'lr': self.config['training']['lr'] * 10}]

    # Define Optimizer
    optimizer = torch.optim.SGD(train_params,
                                momentum=self.config['training']['momentum'],
                                weight_decay=self.config['training']['weight_decay'],
                                nesterov=self.config['training']['nesterov'])

    # Define Criterion
    # whether to use class balanced weights
    if self.config['training']['use_balanced_weights']:
        classes_weights_path = os.path.join(
            self.config['dataset']['base_path'],
            self.config['dataset']['dataset_name'] + '_classes_weights.npy')
        if os.path.isfile(classes_weights_path):
            weight = np.load(classes_weights_path)
        else:
            weight = calculate_weigths_labels(self.config,
                                              self.config['dataset']['dataset_name'],
                                              self.train_loader, self.nclass)
        weight = torch.from_numpy(weight.astype(np.float32))
    else:
        weight = None
    self.criterion = SegmentationLosses(
        weight=weight,
        cuda=self.config['network']['use_cuda']).build_loss(
            mode=self.config['training']['loss_type'])
    self.model, self.optimizer = model, optimizer

    # Define Evaluator
    self.evaluator = Evaluator(self.nclass)

    # Define lr scheduler
    self.scheduler = LR_Scheduler(self.config['training']['lr_scheduler'],
                                  self.config['training']['lr'],
                                  self.config['training']['epochs'],
                                  len(self.train_loader))

    # Using cuda
    if self.config['network']['use_cuda']:
        self.model = torch.nn.DataParallel(self.model)
        patch_replication_callback(self.model)
        self.model = self.model.cuda()

    # Resuming checkpoint
    if self.config['training']['weights_initialization']['use_pretrained_weights']:
        restore_from = \
            self.config['training']['weights_initialization']['restore_from']
        if not os.path.isfile(restore_from):
            raise RuntimeError(
                "=> no checkpoint found at '{}'".format(restore_from))

        if self.config['network']['use_cuda']:
            checkpoint = torch.load(restore_from)
        else:
            checkpoint = torch.load(restore_from,
                                    map_location={'cuda:0': 'cpu'})
        self.config['training']['start_epoch'] = checkpoint['epoch']

        if self.config['network']['use_cuda']:
            # on GPU the model is wrapped in DataParallel above, so the
            # weights must be restored into the underlying module
            self.model.module.load_state_dict(checkpoint['state_dict'])
        else:
            self.model.load_state_dict(checkpoint['state_dict'])

        # if not self.config['ft']:
        self.optimizer.load_state_dict(checkpoint['optimizer'])
        self.best_pred = checkpoint['best_pred']
        print("=> loaded checkpoint '{}' (epoch {})"
              .format(restore_from, checkpoint['epoch']))
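# For orientation, a minimal sketch of the nested config this constructor
# expects. The key names are taken directly from the lookups above; the
# variable name `example_config` and every value are illustrative
# placeholders, not recommended settings.
example_config = {
    'dataset': {
        'base_path': '/data',        # placeholder path
        'dataset_name': 'pascal',    # placeholder dataset name
    },
    'image': {
        'out_stride': 16,
    },
    'network': {
        'backbone': 'resnet',
        'sync_bn': False,
        'freeze_bn': False,
        'use_cuda': True,
    },
    'training': {
        'lr': 0.01,
        'momentum': 0.9,
        'weight_decay': 5e-4,
        'nesterov': False,
        'epochs': 50,
        'start_epoch': 0,
        'loss_type': 'ce',
        'lr_scheduler': 'poly',
        'use_balanced_weights': False,
        'tensorboard': {'log_dir': 'runs/example'},
        'weights_initialization': {
            'use_pretrained_weights': False,
            'restore_from': '',      # path to a checkpoint when resuming
        },
    },
}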
def main():
    # define and parse arguments
    parser = argparse.ArgumentParser()

    # general
    parser.add_argument('--experiment_name', type=str, default="experiment",
                        help="experiment name. will be used in the path names \
                              for log- and savefiles")
    parser.add_argument('--seed', type=int, default=None,
                        help='fixes random seed and sets model to \
                              the potentially faster cuDNN deterministic mode \
                              (default: non-deterministic mode)')
    parser.add_argument('--val_freq', type=int, default=1000,
                        help='validation will be run every val_freq \
                              batches/optimization steps during training')
    parser.add_argument('--save_freq', type=int, default=1000,
                        help='training state will be saved every save_freq \
                              batches/optimization steps during training')
    parser.add_argument('--log_freq', type=int, default=100,
                        help='tensorboard logs will be written every log_freq \
                              number of batches/optimization steps')

    # input/output
    parser.add_argument('--use_s2hr', action='store_true', default=False,
                        help='use sentinel-2 high-resolution (10 m) bands')
    parser.add_argument('--use_s2mr', action='store_true', default=False,
                        help='use sentinel-2 medium-resolution (20 m) bands')
    parser.add_argument('--use_s2lr', action='store_true', default=False,
                        help='use sentinel-2 low-resolution (60 m) bands')
    parser.add_argument('--use_s1', action='store_true', default=False,
                        help='use sentinel-1 data')
    parser.add_argument('--no_savanna', action='store_true', default=False,
                        help='ignore class savanna')

    # training hyperparameters
    parser.add_argument('--lr', type=float, default=0.01,
                        help='learning rate (default: 1e-2)')
    parser.add_argument('--momentum', type=float, default=0.9,
                        help='momentum (default: 0.9), only used for deeplab')
    parser.add_argument('--weight_decay', type=float, default=5e-4,
                        help='weight-decay (default: 5e-4)')
    parser.add_argument('--batch_size', type=int, default=32,
                        help='batch size for training and validation \
                              (default: 32)')
    parser.add_argument('--workers', type=int, default=4,
                        help='number of workers for dataloading (default: 4)')
    parser.add_argument('--max_epochs', type=int, default=100,
                        help='number of training epochs (default: 100)')

    # network
    parser.add_argument('--model', type=str, choices=['deeplab', 'unet'],
                        default='deeplab',
                        help="network architecture (default: deeplab)")

    # deeplab-specific
    parser.add_argument('--pretrained_backbone', action='store_true',
                        default=False,
                        help='initialize ResNet-101 backbone with ImageNet \
                              pre-trained weights')
    parser.add_argument('--out_stride', type=int, choices=[8, 16], default=16,
                        help='network output stride (default: 16)')

    # data
    parser.add_argument('--data_dir_train', type=str, default=None,
                        help='path to training dataset')
    parser.add_argument('--dataset_val', type=str, default="sen12ms_holdout",
                        choices=['sen12ms_holdout', 'dfc2020_val',
                                 'dfc2020_test'],
                        help='dataset to use for validation (default: \
                              sen12ms_holdout)')
    parser.add_argument('--data_dir_val', type=str, default=None,
                        help='path to validation dataset')
    parser.add_argument('--log_dir', type=str, default=None,
                        help='path to dir for tensorboard logs \
                              (default runs/CURRENT_DATETIME_HOSTNAME)')

    args = parser.parse_args()
    print("=" * 20, "CONFIG", "=" * 20)
    for arg in vars(args):
        print('{0:20} {1}'.format(arg, getattr(args, arg)))
    print()

    # fix seeds and set pytorch to deterministic mode
    if args.seed is not None:
        torch.manual_seed(args.seed)
        random.seed(args.seed)
        np.random.seed(args.seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    # set flags for GPU processing if available
    if torch.cuda.is_available():
        args.use_gpu = True
        if torch.cuda.device_count() > 1:
            raise NotImplementedError("multi-gpu training not implemented! "
                                      + "try to run script as: "
                                      + "CUDA_VISIBLE_DEVICES=0 train.py")
    else:
        args.use_gpu = False

    # load datasets
    train_set = SEN12MS(args.data_dir_train, subset="train",
                        no_savanna=args.no_savanna,
                        use_s2hr=args.use_s2hr,
                        use_s2mr=args.use_s2mr,
                        use_s2lr=args.use_s2lr,
                        use_s1=args.use_s1)
    n_classes = train_set.n_classes
    n_inputs = train_set.n_inputs
    if args.dataset_val == "sen12ms_holdout":
        val_set = SEN12MS(args.data_dir_train, subset="holdout",
                          no_savanna=args.no_savanna,
                          use_s2hr=args.use_s2hr,
                          use_s2mr=args.use_s2mr,
                          use_s2lr=args.use_s2lr,
                          use_s1=args.use_s1)
    else:
        dfc2020_subset = args.dataset_val.split("_")[-1]
        val_set = DFC2020(args.data_dir_val, subset=dfc2020_subset,
                          no_savanna=args.no_savanna,
                          use_s2hr=args.use_s2hr,
                          use_s2mr=args.use_s2mr,
                          use_s2lr=args.use_s2lr,
                          use_s1=args.use_s1)

    # set up dataloaders
    train_loader = DataLoader(train_set,
                              batch_size=args.batch_size,
                              shuffle=True,
                              num_workers=args.workers,
                              pin_memory=True,
                              drop_last=False)
    val_loader = DataLoader(val_set,
                            batch_size=args.batch_size,
                            shuffle=False,
                            num_workers=args.workers,
                            pin_memory=True,
                            drop_last=False)

    # set up network
    if args.model == "deeplab":
        model = DeepLab(num_classes=n_classes,
                        backbone='resnet',
                        pretrained_backbone=args.pretrained_backbone,
                        output_stride=args.out_stride,
                        sync_bn=False,
                        freeze_bn=False,
                        n_in=n_inputs)
    else:
        model = UNet(n_classes=n_classes, n_channels=n_inputs)
    if args.use_gpu:
        model = model.cuda()

    # define loss function
    loss_fn = nn.CrossEntropyLoss(ignore_index=255, reduction='mean')

    # set up optimizer
    if args.model == "deeplab":
        train_params = [{'params': model.get_1x_lr_params(), 'lr': args.lr},
                        {'params': model.get_10x_lr_params(),
                         'lr': args.lr * 10}]
        optimizer = torch.optim.SGD(train_params,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)
    else:
        optimizer = torch.optim.RMSprop(model.parameters(), lr=args.lr,
                                        weight_decay=args.weight_decay)

    # set up tensorboard logging
    if args.log_dir is None:
        args.log_dir = "logs"
    writer = SummaryWriter(
        log_dir=os.path.join(args.log_dir, args.experiment_name))

    # create checkpoint dir
    args.checkpoint_dir = os.path.join(args.log_dir, args.experiment_name,
                                       "checkpoints")
    os.makedirs(args.checkpoint_dir, exist_ok=True)

    # save config
    with open(os.path.join(args.checkpoint_dir, "args.pkl"), "wb") as f:
        pkl.dump(args, f)

    # train network
    step = 0
    trainer = ModelTrainer(args)
    for epoch in range(args.max_epochs):
        print("=" * 20, "EPOCH", epoch + 1, "/", str(args.max_epochs),
              "=" * 20)

        # run training for one epoch
        model, step = trainer.train(model, train_loader, val_loader, loss_fn,
                                    optimizer, writer, step=step)

    # export final set of weights
    trainer.export_model(model, args.checkpoint_dir, name="final")
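# The script needs the usual entry-point guard to be runnable directly
# (assuming it is not already present elsewhere in the file).
#
# Example single-GPU invocation, using only flags defined above; the
# dataset path is a placeholder, and --data_dir_val is only required for
# the dfc2020_* validation choices (the default sen12ms_holdout split is
# read from --data_dir_train):
#
#   CUDA_VISIBLE_DEVICES=0 python train.py \
#       --experiment_name baseline_deeplab \
#       --data_dir_train /path/to/sen12ms \
#       --use_s2hr --use_s1 \
#       --model deeplab --pretrained_backbone

if __name__ == "__main__":
    main()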