Exemplo n.º 1
0
def load_model(out_channels=2, enc_type='efficientnet', dec_type='unet', pretrained=True, output_stride=8):
    """Build a segmentation network from encoder/decoder names.

    Side effect: sets the module-level ``net_type`` to ``'unet'`` or
    ``'deeplab'`` depending on the chosen model family.

    :param out_channels: number of output channels of the network
    :param enc_type: encoder name; a name containing ``'efficient'`` selects
        ``EfficientUnet`` when a U-Net decoder is requested, otherwise
        ``EncoderDecoderNet`` is used for U-Net decoders
    :param dec_type: decoder name; a name containing ``'unet'`` selects a
        U-Net style model, anything else selects ``SPPNet``
    :param pretrained: forwarded to the U-Net constructors
    :param output_stride: forwarded to ``SPPNet`` only
    :return: the constructed model
    """
    global net_type
    if 'unet' in dec_type:
        net_type = 'unet'
        if 'efficient' in enc_type:
            model = EfficientUnet(enc_type, out_channels=out_channels, concat_input=True,
                                  pretrained=pretrained)
        else:
            # Build from this function's own arguments rather than an external
            # ``net_config`` that is not defined in this scope.
            model = EncoderDecoderNet(output_channels=out_channels,
                                      enc_type=enc_type,
                                      dec_type=dec_type,
                                      pretrained=pretrained)
    else:
        net_type = 'deeplab'
        model = SPPNet(output_channels=out_channels, enc_type=enc_type, dec_type=dec_type,
                       output_stride=output_stride)
    return model
Exemplo n.º 2
0
# Unpack the config sections used by the training script.
train_config = config['Train']
loss_config = config['Loss']
opt_config = config['Optimizer']
# Use the first GPU when CUDA is available, otherwise the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
t_max = opt_config['t_max']

# Training hyper-parameters and flags from the 'Train' section.
max_epoch = train_config['max_epoch']
batch_size = train_config['batch_size']
fp16 = train_config['fp16']
resume = train_config['resume']
pretrained_path = train_config['pretrained_path']

# Network
# Decoder names containing 'unet' select EncoderDecoderNet; anything else
# selects SPPNet. net_type records which family was chosen.
# NOTE(review): the model is built from net_config here, *before* the dataset
# branch below writes net_config['output_channels'], so that value does not
# reach this model — confirm the ordering is intended.
if 'unet' in net_config['dec_type']:
    net_type = 'unet'
    model = EncoderDecoderNet(**net_config)
else:
    net_type = 'deeplab'
    model = SPPNet(**net_config)

# Select the dataset class (and class-id range) for the configured dataset.
dataset = data_config['dataset']
if dataset == 'pascal':
    from dataset.pascal_voc import PascalVocDataset as Dataset
    net_config['output_channels'] = 21
    # Class ids 1..20; id 0 is left out (presumably background — verify).
    classes = np.arange(1, 21)
elif dataset == 'cityscapes':
    from dataset.cityscapes import CityscapesDataset as Dataset
    net_config['output_channels'] = 19
    # NOTE(review): np.arange(1, 19) yields ids 1..18, i.e. 18 of the 19
    # channels — confirm that leaving out id 0 is intended here.
    classes = np.arange(1, 19)
elif dataset == 'sherbrooke':
    # NOTE(review): in the visible code this branch sets neither
    # net_config['output_channels'] nor `classes`.
    from dataset.sherbrooke import SherbrookeDataset as Dataset
Exemplo n.º 3
0
def process():
    """Fine-tune a network on GeoImageNet / DeepGlobe classification patches.

    Steps:
      1. Create a timestamped output directory under ``ROOT_DIR/model``.
      2. Build train/val ``GINParser`` datasets from ``meta.json`` patch
         listings.
      3. Instantiate an ``EncoderDecoderNet``.
      4. Configure SGD. With ``feature_extract``, only the parameters that
         require gradients are optimised. Otherwise only the last 8 parameter
         tensors are trained, with per-block learning rates.
      5. Run ``train_model``. The model/optimizer state of the best
         validation epoch is saved to the output directory.

    Reads module-level names defined outside this function: ``ROOT_DIR``,
    ``DATASET_ROOT``, ``model_name``, ``batch_size``, ``lr``,
    ``feature_extract``, ``num_epochs``, ``plotter`` and the ``DATASET_*`` /
    ``IMAGE_*`` key constants.
    """
    input_size = 224
    timestamp = datetime.timestamp(datetime.now())
    print("timestamp =", datetime.fromtimestamp(timestamp))
    output_dir = Path(
        os.path.join(
            ROOT_DIR,
            f'model/{model_name}_{datetime.fromtimestamp(timestamp)}'))
    output_dir.mkdir(exist_ok=True)

    # Training split (GeoImageNet annotation batch).
    with open(
            '/home/sfoucher/DEV/geoimagenet/b76e46a7-4bea-4281-b802-2e296f97b960/meta.json'
    ) as json_file:
        meta = json.load(json_file)
    dataset = {
        'path':
        '/home/sfoucher/DEV/geoimagenet/b76e46a7-4bea-4281-b802-2e296f97b960',
        DATASET_DATA_KEY: meta['patches']
    }
    # Validation split (DeepGlobe classification patches).
    with open(
            '/home/sfoucher/DEV/geoimagenet/dataset_test/deepglobe_classif/val/meta.json'
    ) as json_file:
        meta = json.load(json_file)
    dataset_val = {
        'path': DATASET_ROOT / 'val',
        DATASET_DATA_KEY: meta['patches']
    }

    # Dataset class id (patch 'class' field) -> model class index, in order:
    # AgriculturalLand, BarrenLand, ForestLand, RangeLand, UrbanLand, Water.
    # NOTE(review): other tables for these classes use ids 222, 251, 232, 228,
    # 199, 238 — confirm which ids the meta.json patches actually carry.
    model_class_map = {
        223: 0,
        252: 1,
        233: 2,
        229: 3,
        200: 4,
        239: 5,
    }

    class GINParser(torch.utils.data.Dataset):
        """Dataset over the patches listed in a GeoImageNet-style ``meta.json``.

        Each item is ``(image, label)``. ``image`` is the patch's first crop
        opened with PIL and passed through ``transforms`` if given. ``label``
        is the model class index taken from ``model_class_map``.
        """

        def __init__(self, dataset=None, transforms=None):
            if not (isinstance(dataset, dict) and len(dataset)):
                raise ValueError(
                    "Expected dataset parameters as configuration input.")
            self.root = Path(dataset["path"])
            self.transforms = transforms
            # keys matching dataset config for easy loading and referencing to same fields
            self.image_key = IMAGE_DATA_KEY  # key employed by loader to extract image data (pixel values)
            self.label_key = IMAGE_LABEL_KEY  # class id from API mapped to match model task
            self.path_key = "path"  # actual file path of the patch
            self.idx_key = "index"  # increment for __getitem__
            self.mask_key = 'mask'  # actual mask path of the patch
            self.meta_keys = [
                self.path_key, self.idx_key, DATASET_DATA_PATCH_CROPS_KEY,
                DATASET_DATA_PATCH_IMAGE_KEY, DATASET_DATA_PATCH_FEATURE_KEY
            ]
            sample_class_ids = set()
            samples = []
            # Read the patch list under the same key it was stored with.
            for patch_info in dataset[DATASET_DATA_KEY]:
                # Convert the dataset class id into the model class index and
                # drop the patch when its class is not part of the model task.
                # Integer labels let the default collate build a LongTensor,
                # as CrossEntropyLoss expects.
                # NOTE(review): assumes patch_info['class'] is an int id.
                class_id = model_class_map.get(patch_info.get('class'))
                if class_id is None:
                    continue
                crop = patch_info['crops'][0]
                sample = deepcopy(patch_info)
                sample[self.path_key] = self.root / crop['path']
                sample[self.label_key] = class_id
                mask_name = crop.get('mask', None)
                if mask_name:
                    sample[self.mask_key] = self.root / mask_name
                samples.append(sample)
                sample_class_ids.add(class_id)

            if not sample_class_ids:
                raise ValueError(
                    "No patch/class could be retrieved from batch loading for specific model task."
                )
            self.samples = samples
            self.sample_class_ids = sample_class_ids

        def __len__(self):
            return len(self.samples)

        def __getitem__(self, idx):
            if torch.is_tensor(idx):
                idx = idx.tolist()
            sample = self.samples[idx]
            # str() rather than the private Path._str attribute; the mask is
            # not loaded, so patches without a mask are fine.
            image = Image.open(str(sample[self.path_key]))
            if self.transforms:
                image = self.transforms(image)
            return image, sample[self.label_key]

    # Initialize the model for this run.
    # NOTE(review): EncoderDecoderNet looks like a segmentation network and is
    # configured with 7 output channels while 6 classes are mapped above —
    # confirm its output is (batch, classes) logits as CrossEntropyLoss with
    # per-image labels requires.
    from models.net import EncoderDecoderNet
    net_config = {
        'enc_type': 'resnet18',
        'dec_type': 'unet_seibn',
        'num_filters': 16,
        'output_channels': 7,
        'pretrained': True,
    }
    model_ft = EncoderDecoderNet(**net_config)

    # Data augmentation and normalization for training
    # Just normalization for validation
    data_transforms = {
        'train':
        transforms.Compose([
            transforms.Resize(input_size),
            transforms.RandomHorizontalFlip(),
            transforms.RandomVerticalFlip(),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
        'val':
        transforms.Compose([
            transforms.Resize(input_size),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
    }

    image_datasets = {
        'train': GINParser(dataset=dataset,
                           transforms=data_transforms['train']),
        'val': GINParser(dataset=dataset_val,
                         transforms=data_transforms['val']),
    }
    dataloaders_dict = {
        'train': torch.utils.data.DataLoader(image_datasets['train'],
                                             batch_size=batch_size,
                                             shuffle=True,
                                             num_workers=4),
        'val': torch.utils.data.DataLoader(image_datasets['val'],
                                           batch_size=batch_size,
                                           shuffle=False,
                                           num_workers=4),
    }

    # Detect if we have a GPU available
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Print the model we just instantiated
    print(model_ft)

    # Send the model to GPU
    model_ft = model_ft.to(device)

    print("Params to learn:")
    if feature_extract:
        # Only optimise parameters that still require gradients.
        params_to_update = []
        for name, param in model_ft.named_parameters():
            if param.requires_grad:
                params_to_update.append(param)
                print("\t", name)

        optimizer_ft = optim.SGD(params_to_update, lr=lr, momentum=0.9)
    else:
        # Fine-tune only the last 8 parameter tensors; the rest are frozen.
        # Each top-level module ("block") gets lr * 10**(2 * (i / n - 1)),
        # where i is the index of its last parameter. That ranges from lr/100
        # for the first block up to ~lr for the last. Parameters whose name
        # contains 'fc.' use lr directly.
        named_params = list(model_ft.named_parameters())
        n_param = len(named_params)
        name_block = {}
        lr_block = {}
        for idx, (name, _) in enumerate(named_params):
            block = name.split('.')[0]
            name_block[idx] = block
            lr_block[block] = lr * 10**(2 * (idx / n_param - 1))

        params_to_update = []
        for idx, (name, param) in enumerate(named_params):
            if idx < n_param - 8:
                param.requires_grad = False
            else:
                param_lr = lr if 'fc.' in name else lr_block[name_block[idx]]
                params_to_update.append({"params": param, "lr": param_lr})
            if param.requires_grad:
                print("\t", name)

        optimizer_ft = optim.SGD(params_to_update,
                                 momentum=0.9,
                                 weight_decay=0.001)

    # Setup the loss fxn
    criterion = nn.CrossEntropyLoss()

    def train_model(model,
                    dataloaders,
                    criterion,
                    optimizer,
                    num_epochs=25,
                    is_inception=False):
        """Train ``model`` and restore the weights of the best validation epoch.

        :param model: network to train (already on ``device``)
        :param dataloaders: dict phase name -> DataLoader; 'train' updates the
            weights, any other phase is evaluated without gradients
        :param criterion: loss applied to ``(outputs, labels)``
        :param optimizer: optimizer over the trainable parameters
        :param num_epochs: number of passes over every phase
        :param is_inception: in training, add 0.4 * the auxiliary-output loss
        :return: ``(model with best 'val' weights, list of 'val' accuracies)``
        """
        since = time.time()

        val_acc_history = []
        # NOTE(review): the scheduler is never stepped (the step call is
        # disabled below), so the learning rate stays constant.
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                    step_size=3,
                                                    gamma=0.1)

        best_model_wts = copy.deepcopy(model.state_dict())
        best_acc = 0.0
        ma_loss = 0.0  # exponential moving averages for plotting
        ma_acc = 0.0
        i_iter = 0
        for epoch in range(num_epochs):
            print('Epoch {}/{}'.format(epoch, num_epochs - 1))
            print('-' * 10)

            # Each epoch has a training and validation phase
            for phase in dataloaders.keys():
                if phase == 'train':
                    model.train()  # Set model to training mode
                else:
                    model.eval()  # Set model to evaluate mode

                running_loss = 0.0
                running_corrects = 0

                # Iterate over data.
                with tqdm(dataloaders[phase]) as _tqdm:
                    for i, (inputs, labels) in enumerate(_tqdm):
                        inputs = inputs.to(device)
                        labels = labels.to(device)

                        # zero the parameter gradients
                        optimizer.zero_grad()

                        # forward; track history only in train
                        with torch.set_grad_enabled(phase == 'train'):
                            # Inception has an auxiliary output in training:
                            # the loss sums the final and auxiliary outputs
                            # there, while evaluation uses the final output.
                            if is_inception and phase == 'train':
                                # From https://discuss.pytorch.org/t/how-to-optimize-inception-model-with-auxiliary-classifiers/7958
                                outputs, aux_outputs = model(inputs)
                                loss1 = criterion(outputs, labels)
                                loss2 = criterion(aux_outputs, labels)
                                loss = loss1 + 0.4 * loss2
                            else:
                                outputs = model(inputs)
                                loss = criterion(outputs, labels)

                            _, preds = torch.max(outputs, 1)
                            # Fraction of correctly classified samples in the batch.
                            acc = torch.sum(preds == labels).item() / len(labels)
                            _tqdm.set_postfix(
                                OrderedDict(seg_loss=f'{loss.item():.5f}',
                                            acc=f'{acc*100:.3f}'))

                            # backward + optimize only if in training phase
                            if phase == 'train':
                                loss.backward()
                                optimizer.step()
                                ma_loss = 0.01 * loss.item() + 0.99 * ma_loss
                                ma_acc = 0.01 * acc + 0.99 * ma_acc
                                # plotter.plot(y-axis, series name, chart name, x, y)
                                plotter.plot('loss', 'train', 'iteration Loss',
                                             i_iter, loss.item())
                                plotter.plot('acc', 'train', 'iteration acc',
                                             i_iter, acc)
                                plotter.plot('loss', 'ma_loss',
                                             'iteration Loss', i_iter, ma_loss)
                                plotter.plot('acc', 'ma_acc', 'iteration acc',
                                             i_iter, ma_acc)
                                i_iter += 1

                        # statistics
                        running_loss += loss.item() * inputs.size(0)
                        running_corrects += torch.sum(preds == labels).item()

                epoch_loss = running_loss / len(dataloaders[phase].dataset)
                epoch_acc = running_corrects / len(dataloaders[phase].dataset)
                print('{} Loss: {:.4f} Acc: {:.4f}'.format(
                    phase, epoch_loss, epoch_acc))
                if phase == 'train':
                    plotter.plot('loss-epoch', 'train', 'epoch Loss', epoch,
                                 epoch_loss)
                    plotter.plot('acc-epoch', 'train', 'epoch acc', epoch,
                                 epoch_acc)
                else:
                    plotter.plot('loss-epoch', 'valid', 'epoch Loss', epoch,
                                 epoch_loss)
                    plotter.plot('acc-epoch', 'valid', 'epoch acc', epoch,
                                 epoch_acc)

                # deep copy the model and checkpoint on a new best val accuracy
                if phase == 'val' and epoch_acc > best_acc:
                    print('Best Epoch!')
                    best_acc = epoch_acc
                    best_model_wts = copy.deepcopy(model.state_dict())
                    torch.save(model.state_dict(),
                               output_dir.joinpath('model.pth'))
                    torch.save(optimizer.state_dict(),
                               output_dir.joinpath('opt.pth'))
                if phase == 'val':
                    val_acc_history.append(epoch_acc)

            print()

        time_elapsed = time.time() - since
        print('Training complete in {:.0f}m {:.0f}s'.format(
            time_elapsed // 60, time_elapsed % 60))
        print('Best val Acc: {:4f}'.format(best_acc))

        # load best model weights
        model.load_state_dict(best_model_wts)
        return model, val_acc_history

    # Train and evaluate
    model_ft, hist = train_model(model_ft,
                                 dataloaders_dict,
                                 criterion,
                                 optimizer_ft,
                                 num_epochs=num_epochs,
                                 is_inception=(model_name == "inception"))
Exemplo n.º 4
0
    output_dir = Path(train_config['output_dir'])
    output_dir.mkdir(exist_ok=True, parents=True)
    shutil.copyfile(str(config_path), str(output_dir / config_path.name))

    log_dir = Path(train_config['log_dir'])
    log_dir.mkdir(exist_ok=True, parents=True)
    eval_every_n_epochs = train_config['eval_every_n_epochs']
    vis_flag = train_config['vis_flag']

    include_bg = loss_config['include_bg']
    del loss_config['include_bg']

    # Network
    if 'unet' in net_config['dec_type']:
        net_type = 'unet'
        model = EncoderDecoderNet(**net_config)
    else:
        net_type = 'deeplab'
        model = SPPNet(**net_config)

    dataset = data_config['dataset']
    if dataset == 'pascal':
        from dataset.pascal_voc import PascalVocDataset as Dataset
    elif dataset == 'cityscapes':
        from dataset.cityscapes import CityscapesDataset as Dataset
    else:
        raise NotImplementedError
    if include_bg:
        classes = np.arange(0, data_config['num_classes'])
    else:
        classes = np.arange(1, data_config['num_classes'])
Exemplo n.º 5
0
parser.add_argument("--vis", action="store_true")
args = parser.parse_args()
config_path = Path(args.config_path)
tta_flag = args.tta
vis_flag = args.vis

# Parse the YAML config. safe_load never constructs arbitrary Python objects.
# Unlike a bare yaml.load(stream), it needs no explicit Loader, which is
# mandatory since PyYAML 6. The context manager closes the file.
with open(config_path) as config_file:
    config = yaml.safe_load(config_file)
net_config = config["Net"]
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Weights are read from ../model/<config file stem>/model.pth.
modelname = config_path.stem
model_path = Path("../model") / modelname / "model.pth"

# Decoder names containing "unet" select EncoderDecoderNet, otherwise SPPNet.
if "unet" in net_config["dec_type"]:
    net_type = "unet"
    model = EncoderDecoderNet(**net_config)
else:
    net_type = "deeplab"
    model = SPPNet(**net_config)
model.to(device)
model.update_bn_eps()

# map_location lets weights saved on a GPU be restored on a CPU-only machine.
param = torch.load(model_path, map_location=device)
model.load_state_dict(param)
del param

model.eval()

batch_size = 1
# Presumably the input scales for multi-scale / test-time augmentation — verify.
scales = [0.25, 0.75, 1, 1.25]
Exemplo n.º 6
0
# Unpack the config sections used by the training script.
train_config = config['Train']
loss_config = config['Loss']
opt_config = config['Optimizer']
# Use the first GPU when CUDA is available, otherwise the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
t_max = opt_config['t_max']

# Training hyper-parameters and flags from the 'Train' section.
max_epoch = train_config['max_epoch']
batch_size = train_config['batch_size']
fp16 = train_config['fp16']
resume = train_config['resume']
pretrained_path = train_config['pretrained_path']

# Network
# Decoder names containing 'unet' select EncoderDecoderNet; anything else
# selects SPPNet. net_type records which family was chosen.
# NOTE(review): the model is built from net_config here, *before* the dataset
# branch below writes net_config['output_channels'], so that value does not
# reach this model — confirm the ordering is intended.
if 'unet' in net_config['dec_type']:
    net_type = 'unet'
    model = EncoderDecoderNet(**net_config)
else:
    net_type = 'deeplab'
    model = SPPNet(**net_config)

# Select the dataset class (and class-id range) for the configured dataset.
dataset = data_config['dataset']
if dataset == 'pascal':
    from dataset.pascal_voc import PascalVocDataset as Dataset
    net_config['output_channels'] = 21
    # Class ids 1..20; id 0 is left out (presumably background — verify).
    classes = np.arange(1, 21)
elif dataset == 'cityscapes':
    from dataset.cityscapes import CityscapesDataset as Dataset
    net_config['output_channels'] = 19
    # NOTE(review): np.arange(1, 19) yields ids 1..18, i.e. 18 of the 19
    # channels — confirm that leaving out id 0 is intended here.
    classes = np.arange(1, 19)
else:
    # Any other dataset name is unsupported.
    raise NotImplementedError
Exemplo n.º 7
0
    output_dir.mkdir(exist_ok=True, parents=True)
    log_dir = Path('../logs').joinpath(modelname)
    log_dir.mkdir(exist_ok=True, parents=True)

    logger = debug_logger(log_dir)
    logger.info(f'Device: {device}')
    logger.info(f'Max Epoch: {max_epoch}')

    del data_config['dataset']
    train_dataset = Dataset(split='train', **data_config)
    valid_dataset = Dataset(split='valid', **data_config)
    train_loader = DataLoader(train_dataset,
                              batch_size=batch_size,
                              shuffle=True,
                              num_workers=4,
                              pin_memory=True)
    valid_loader = DataLoader(valid_dataset,
                              batch_size=batch_size,
                              shuffle=False,
                              num_workers=4,
                              pin_memory=True)

    if 'unet' in net_config['dec_type']:
        model = EncoderDecoderNet(**net_config).to(device)
    else:
        model = SPPNet(**net_config).to(device)
    loss_fn = CrossEntropy2d(**loss_config).to(device)
    optimizer, scheduler = create_optimizer(model=model, **opt_config)

    train()
Exemplo n.º 8
0
    def __init__(
            self,
            model_path='../model/deepglobe_deeplabv3_weights-cityscapes_19-outputs/model.pth',
            dataset='deepglobe',
            output_channels=19,
            split='valid',
            net_type='deeplab',
            batch_size=1,
            shuffle=True):
        """
        Initializes the tester by loading the model with the given parameters.
        :param model_path: Path to model weights
        :param dataset: dataset used amongst {'deepglobe', 'pascal', 'cityscapes'}
        :param output_channels: num of output channels of model
        :param split: split to be used amongst {'train', 'valid'}
        :param net_type: model type to be used amongst {'deeplab', 'unet'}
        :param batch_size: batch size when loading images (always 1 here)
        :param shuffle: when loading images from dataset
        :raises NotImplementedError: for an unsupported ``dataset`` name

        NOTE(review): model_path, output_channels and net_type are currently
        overridden by hard-coded values at the top of the body, so the
        arguments passed for them have no effect.
        """
        # NOTE(review): these hard-coded values shadow the caller's arguments.
        # They are kept so that existing callers still get the same model.
        model_path = '/home/sfoucher/DEV/pytorch-segmentation/model/my_pascal_unet_res18_scse/model.pth'
        dataset_dir = '/home/sfoucher/DEV/pytorch-segmentation/data/deepglobe_as_pascalvoc/VOCdevkit/VOC2012'

        output_channels = 8
        net_type = 'unet'
        print('[Tester] [Init] Initializing tester...')
        self.dataset = dataset
        self.model_path = model_path

        # Load model
        print('[Tester] [Init] Loading model ' + model_path + ' with ' +
              str(output_channels) + ' output channels...')

        self.device = torch.device(
            'cuda:0' if torch.cuda.is_available() else 'cpu')
        if net_type == 'unet':
            self.model = EncoderDecoderNet(output_channels=output_channels,
                                           enc_type='resnet18',
                                           dec_type='unet_scse',
                                           num_filters=8)
        else:
            self.model = SPPNet(output_channels=output_channels)
        # map_location lets GPU-saved weights load on a CPU-only machine.
        param = torch.load(model_path, map_location=self.device)
        self.model.load_state_dict(param)
        del param
        # Move the model to the selected device for both model types.
        self.model.to(self.device)

        # Create data loader depending on dataset, split and net type
        if dataset == 'pascal':
            self.valid_dataset = PascalVocDataset(split=split,
                                                  net_type=net_type)
        elif dataset == 'cityscapes':
            self.valid_dataset = CityscapesDataset(split=split,
                                                   net_type=net_type)
        elif dataset == 'deepglobe':
            self.valid_dataset = DeepGlobeDataset(base_dir=dataset_dir,
                                                  target_size=(64, 64),
                                                  split=split,
                                                  net_type=net_type)
        else:
            raise NotImplementedError

        self.valid_loader = DataLoader(self.valid_dataset,
                                       batch_size=batch_size,
                                       shuffle=shuffle)

        print('[Tester] [Init] ...done!')
        print('[Tester] [Init] Tester created.')
Exemplo n.º 9
0
def process(config_path):
    gc.collect()
    torch.cuda.empty_cache()
    config = yaml.load(open(config_path))
    net_config = config['Net']
    data_config = config['Data']
    train_config = config['Train']
    loss_config = config['Loss']
    opt_config = config['Optimizer']
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    t_max = opt_config['t_max']

    # Collect training parameters
    max_epoch = train_config['max_epoch']
    batch_size = train_config['batch_size']
    fp16 = train_config['fp16']
    resume = train_config['resume']
    pretrained_path = train_config['pretrained_path']
    freeze_enabled = train_config['freeze']
    seed_enabled = train_config['seed']

    #########################################
    # Deterministic training
    if seed_enabled:
        seed = 100
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        np.random.seed(seed=seed)
        import random
        random.seed(a=100)
    #########################################

    # Network
    if 'unet' in net_config['dec_type']:
        net_type = 'unet'
        model = EncoderDecoderNet(**net_config)
    else:
        net_type = 'deeplab'
        net_config['output_channels'] = 19
        model = SPPNet(**net_config)

    dataset = data_config['dataset']
    if dataset == 'deepglobe-dynamic':
        from dataset.deepglobe_dynamic import DeepGlobeDatasetDynamic as Dataset
        net_config['output_channels'] = 7
        classes = np.arange(0, 7)
    else:
        raise NotImplementedError
    del data_config['dataset']

    modelname = config_path.stem
    timestamp = datetime.timestamp(datetime.now())
    print("timestamp =", datetime.fromtimestamp(timestamp))
    output_dir = Path(os.path.join(ROOT_DIR, f'model/{modelname}_{datetime.fromtimestamp(timestamp)}') )
    output_dir.mkdir(exist_ok=True)
    log_dir = Path(os.path.join(ROOT_DIR, f'logs/{modelname}_{datetime.fromtimestamp(timestamp)}') )
    log_dir.mkdir(exist_ok=True)
    dataset_dir= '/home/sfoucher/DEV/pytorch-segmentation/data/deepglobe_as_pascalvoc/VOCdevkit/VOC2012'
    logger = debug_logger(log_dir)
    logger.debug(config)
    logger.info(f'Device: {device}')
    logger.info(f'Max Epoch: {max_epoch}')

    # Loss
    loss_fn = MultiClassCriterion(**loss_config).to(device)
    params = model.parameters()
    optimizer, scheduler = create_optimizer(params, **opt_config)

    # history
    if resume:
        with open(log_dir.joinpath('history.pkl'), 'rb') as f:
            history_dict = pickle.load(f)
            best_metrics = history_dict['best_metrics']
            loss_history = history_dict['loss']
            iou_history = history_dict['iou']
            start_epoch = len(iou_history)
            for _ in range(start_epoch):
                scheduler.step()
    else:
        start_epoch = 0
        best_metrics = 0
        loss_history = []
        iou_history = []


    affine_augmenter = albu.Compose([albu.HorizontalFlip(p=.5),albu.VerticalFlip(p=.5)
                                    # Rotate(5, p=.5)
                                    ])
    # image_augmenter = albu.Compose([albu.GaussNoise(p=.5),
    #                                 albu.RandomBrightnessContrast(p=.5)])
    image_augmenter = None

    # This has been put in the loop for the dynamic training

    """
    # Dataset
    train_dataset = Dataset(affine_augmenter=affine_augmenter, image_augmenter=image_augmenter,
                            net_type=net_type, **data_config)
    valid_dataset = Dataset(split='valid', net_type=net_type, **data_config)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4,
                            pin_memory=True, drop_last=True)
    valid_loader = DataLoader(valid_dataset, batch_size=1, shuffle=False, num_workers=4, pin_memory=True)
    """

    

    # Pretrained model
    if pretrained_path:
        logger.info(f'Resume from {pretrained_path}')
        param = torch.load(pretrained_path)
        model.load_state_dict(param)
        model.logits = torch.nn.Conv2d(256, net_config['output_channels'], 1)
        del param

    # To device
    model = model.to(device)

    #########################################
    if freeze_enabled:
        # Code de Rémi
        # Freeze layers
        for param_index in range(int((len(optimizer.param_groups[0]['params']))*0.5)):
            optimizer.param_groups[0]['params'][param_index].requires_grad = False
    #########################################
        params_to_update = model.parameters()
        print("Params to learn:")
        if freeze_enabled:
            params_to_update = []
            for name,param in model.named_parameters():
                if param.requires_grad == True:
                    params_to_update.append(param)
                    print("\t",name)
        optimizer, scheduler = create_optimizer(params_to_update, **opt_config)

    # fp16
    if fp16:
        # I only took the necessary files because I don't need the C backend of apex,
        # which is broken and can't be installed
        # from apex import fp16_utils
        from utils.apex.apex.fp16_utils.fp16util import BN_convert_float
        from utils.apex.apex.fp16_utils.fp16_optimizer import FP16_Optimizer
        # model = fp16_utils.BN_convert_float(model.half())
        model = BN_convert_float(model.half())
        # optimizer = fp16_utils.FP16_Optimizer(optimizer, verbose=False, dynamic_loss_scale=True)
        optimizer = FP16_Optimizer(optimizer, verbose=False, dynamic_loss_scale=True)
        logger.info('Apply fp16')

    # Restore model
    if resume:
        model_path = output_dir.joinpath(f'model_tmp.pth')
        logger.info(f'Resume from {model_path}')
        param = torch.load(model_path)
        model.load_state_dict(param)
        del param
        opt_path = output_dir.joinpath(f'opt_tmp.pth')
        param = torch.load(opt_path)
        optimizer.load_state_dict(param)
        del param
    i_iter = 0
    ma_loss= 0
    ma_iou= 0
    # Train
    for i_epoch in range(start_epoch, max_epoch):
        logger.info(f'Epoch: {i_epoch}')
        logger.info(f'Learning rate: {optimizer.param_groups[0]["lr"]}')

        train_losses = []
        train_ious = []
        model.train()

        # Initialize randomized but balanced datasets
        train_dataset = Dataset(base_dir = dataset_dir,
                                affine_augmenter=affine_augmenter, image_augmenter=image_augmenter,
                                net_type=net_type, **data_config)
        valid_dataset = Dataset(base_dir = dataset_dir,
                                split='valid', net_type=net_type, **data_config)
        train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4,
                                pin_memory=True, drop_last=True)
        valid_loader = DataLoader(valid_dataset, batch_size=1, shuffle=False, num_workers=4, pin_memory=True)

        with tqdm(train_loader) as _tqdm:
            for i, batched in enumerate(_tqdm):
                images, labels = batched
                if fp16:
                    images = images.half()
                images, labels = images.to(device), labels.to(device)
                optimizer.zero_grad()
                preds = model(images)
                if net_type == 'deeplab':
                    preds = F.interpolate(preds, size=labels.shape[1:], mode='bilinear', align_corners=True)
                if fp16:
                    loss = loss_fn(preds.float(), labels)
                else:
                    loss = loss_fn(preds, labels)

                preds_np = preds.detach().cpu().numpy()
                labels_np = labels.detach().cpu().numpy()
                iou = compute_iou_batch(np.argmax(preds_np, axis=1), labels_np, classes)

                _tqdm.set_postfix(OrderedDict(seg_loss=f'{loss.item():.5f}', iou=f'{iou:.3f}'))
                train_losses.append(loss.item())
                train_ious.append(iou)
                ma_loss= 0.01*loss.item() +  0.99 * ma_loss
                ma_iou= 0.01*iou +  0.99 * ma_iou
                plotter.plot('loss', 'train', 'iteration Loss', i_iter, loss.item())
                plotter.plot('iou', 'train', 'iteration iou', i_iter, iou)
                plotter.plot('loss', 'ma_loss', 'iteration Loss', i_iter, ma_loss)
                plotter.plot('iou', 'ma_iou', 'iteration iou', i_iter, ma_iou)
                if fp16:
                    optimizer.backward(loss)
                else:
                    loss.backward()
                optimizer.step()
                i_iter += 1
        scheduler.step()

        train_loss = np.mean(train_losses)
        train_iou = np.nanmean(train_ious)
        logger.info(f'train loss: {train_loss}')
        logger.info(f'train iou: {train_iou}')
        plotter.plot('loss-epoch', 'train', 'iteration Loss', i_epoch, train_loss)
        plotter.plot('iou-epoch', 'train', 'iteration iou', i_epoch, train_iou)
        torch.save(model.state_dict(), output_dir.joinpath('model_tmp.pth'))
        torch.save(optimizer.state_dict(), output_dir.joinpath('opt_tmp.pth'))

        valid_losses = []
        valid_ious = []
        model.eval()
        with torch.no_grad():
            with tqdm(valid_loader) as _tqdm:
                for batched in _tqdm:
                    images, labels = batched
                    if fp16:
                        images = images.half()
                    images, labels = images.to(device), labels.to(device)
                    preds = model.tta(images, net_type=net_type)
                    if fp16:
                        loss = loss_fn(preds.float(), labels)
                    else:
                        loss = loss_fn(preds, labels)

                    preds_np = preds.detach().cpu().numpy()
                    labels_np = labels.detach().cpu().numpy()

                    # I changed a parameter in the compute_iou method to prevent it from yielding nans
                    iou = compute_iou_batch(np.argmax(preds_np, axis=1), labels_np, classes)

                    _tqdm.set_postfix(OrderedDict(seg_loss=f'{loss.item():.5f}', iou=f'{iou:.3f}'))
                    valid_losses.append(loss.item())
                    valid_ious.append(iou)

        valid_loss = np.mean(valid_losses)
        valid_iou = np.mean(valid_ious)
        logger.info(f'valid seg loss: {valid_loss}')
        logger.info(f'valid iou: {valid_iou}')
        plotter.plot('loss-epoch', 'valid', 'iteration Loss', i_epoch, valid_loss)
        plotter.plot('iou-epoch', 'valid', 'iteration iou', i_epoch, valid_iou)
        if best_metrics < valid_iou:
            best_metrics = valid_iou
            logger.info('Best Model!')
            torch.save(model.state_dict(), output_dir.joinpath('model.pth'))
            torch.save(optimizer.state_dict(), output_dir.joinpath('opt.pth'))

        loss_history.append([train_loss, valid_loss])
        iou_history.append([train_iou, valid_iou])
        history_ploter(loss_history, log_dir.joinpath('loss.png'))
        history_ploter(iou_history, log_dir.joinpath('iou.png'))

        history_dict = {'loss': loss_history,
                        'iou': iou_history,
                        'best_metrics': best_metrics}
        with open(log_dir.joinpath('history.pkl'), 'wb') as f:
            pickle.dump(history_dict, f)