Example #1
0
def rotate_augment(image, mask):
    '''
    Rotate an image and its mask together by the same multiple of 90 degrees.

    Args:
        image(nd.array): image to rotate
        mask(nd.array): mask rotated with the same factor as the image
    Returns:
        (rot_img, rot_mask) (tuple): tuple of rotated image and mask
    '''
    # A single draw (1, 2 or 3) is shared so image and mask stay aligned.
    turns = random.randint(1, 3)

    rot_img, rot_mask = (
        transforms.RandomRotate90().apply(array, turns)
        for array in (image, mask)
    )
    return (rot_img, rot_mask)
Example #2
0
    def __init__(
        self,
        prob=0.7,
        blur_prob=0.7,
        jitter_prob=0.7,
        rotate_prob=0.7,
        flip_prob=0.7,
    ):
        """Store the probability settings and build the augmentation pipeline.

        Only ``prob`` is passed on (as the probability of the whole composed
        pipeline); the other ``*_prob`` values are stored on the instance but
        the individual transforms below are created with their own defaults.
        """
        super().__init__()

        for attr_name, attr_value in (
            ("prob", prob),
            ("blur_prob", blur_prob),
            ("jitter_prob", jitter_prob),
            ("rotate_prob", rotate_prob),
            ("flip_prob", flip_prob),
        ):
            setattr(self, attr_name, attr_value)

        # Geometric and colour transforms, created with library defaults.
        steps = [
            transforms.RandomRotate90(),
            transforms.Flip(),
            transforms.HueSaturationValue(),
            transforms.RandomBrightnessContrast(),
            transforms.Transpose(),
        ]
        # One of a random or a centred 220x220 crop, picked half of the time.
        crop_options = [
            transforms.RandomCrop(220, 220, p=0.5),
            transforms.CenterCrop(220, 220, p=0.5),
        ]
        steps.append(OneOf(crop_options, p=0.5))
        # Resize(352, 352) and ImageNet-style Normalize are intentionally not
        # part of this pipeline (they were disabled in the original code).

        self.transforms = al.Compose(steps, p=self.prob)
Example #3
0
def ben_augmentation():
    """Build the augmentation pipeline: flip, rotations, crop, normalize, tensor.

    Returns:
        The ``ACompose`` object wrapping the transforms listed below.
    """
    # !! Need to do something to change color _probably_
    geometric = [
        atransforms.HorizontalFlip(p=0.5),
        atransforms.RandomRotate90(p=1.0),
        # Shift and scale limits are zero, so only the rotation part can act.
        atransforms.ShiftScaleRotate(shift_limit=0, scale_limit=0, p=1.0),
        atransforms.RandomSizedCrop(
            (60, 120), height=128, width=128, interpolation=3),
    ]
    # GridDistortion(num_steps=5, p=0.5) was left out -- maybe too much noise?
    # ChannelDropout(channel_drop_range=(1, 2), p=0.5) is disabled as well.
    finishing = [
        atransforms.Normalize(
            mean=BEN_BAND_STATS['mean'], std=BEN_BAND_STATS['std']),
        AToTensor(),
    ]
    return ACompose(geometric + finishing)
Example #4
0
    def __init__(self, config):
        """Read dataset settings from ``config`` and prepare the transforms.

        Training data gets random 90-degree rotations and flips before being
        resized to ``(input_h, input_w)``; validation data is only resized.
        No test transform is set, and the id lists start out as ``None``.
        """
        target_hw = (config['input_h'], config['input_w'])
        augmenting = Compose([
            transforms.RandomRotate90(),
            transforms.Flip(),
            transforms.Resize(*target_hw),
        ])
        resize_only = Compose([transforms.Resize(*target_hw)])

        super().__init__()

        # Plain settings copied one-to-one from the config mapping.
        self.img_ext = config['img_ext']
        self.mask_ext = config['mask_ext']
        self.num_classes = config['num_classes']
        self.num_workers = config['num_workers']
        self.data_dir = os.path.join(
            'data', config['dataset'], config['sub_dataset'])
        self.batch_size = config['batch_size']
        self.pin_memory = config['pin_memory']

        self.train_transform = augmenting
        self.val_transform = resize_only
        self.test_transform = None

        self.train_ids = self.val_ids = self.test_ids = None
        self.num_inputs = config['num_inputs']
Example #5
0
    def __init__(
        self,
        prob=0,
        Flip_prob=0,
        HueSaturationValue_prob=0,
        RandomBrightnessContrast_prob=0,
        crop_prob=0,
        randomrotate90_prob=0,
        elastictransform_prob=0,
        gridistortion_prob=0,
        opticaldistortion_prob=0,
        verticalflip_prob=0,
        horizontalflip_prob=0,
        randomgamma_prob=0,
        CoarseDropout_prob=0,
        RGBShift_prob=0,
        MotionBlur_prob=0,
        MedianBlur_prob=0,
        GaussianBlur_prob=0,
        GaussNoise_prob=0,
        ChannelShuffle_prob=0,
        ColorJitter_prob=0,
    ):
        """Build an augmentation pipeline with one probability per transform.

        Each ``*_prob`` argument is forwarded as ``p`` to the matching
        transform, ``crop_prob`` applies to the random/center crop choice, and
        ``prob`` is the probability of the composed pipeline as a whole.
        ``Transpose`` is the only transform created without an explicit ``p``.
        Only ``prob``, ``randomrotate90_prob`` and ``elastictransform_prob``
        are kept as instance attributes.
        """
        super().__init__()

        self.prob = prob
        self.randomrotate90_prob = randomrotate90_prob
        self.elastictransform_prob = elastictransform_prob

        pipeline = []

        # Rotation, flip and colour transforms plus the unconditioned transpose.
        pipeline += [
            transforms.RandomRotate90(p=randomrotate90_prob),
            transforms.Flip(p=Flip_prob),
            transforms.HueSaturationValue(p=HueSaturationValue_prob),
            transforms.RandomBrightnessContrast(
                p=RandomBrightnessContrast_prob),
            transforms.Transpose(),
        ]

        # Either a random or a centred 220x220 crop.
        crop_candidates = [
            transforms.RandomCrop(220, 220, p=0.5),
            transforms.CenterCrop(220, 220, p=0.5),
        ]
        pipeline.append(OneOf(crop_candidates, p=crop_prob))

        # Distortions.
        pipeline += [
            ElasticTransform(
                p=elastictransform_prob,
                alpha=120,
                sigma=120 * 0.05,
                alpha_affine=120 * 0.03,
            ),
            GridDistortion(p=gridistortion_prob),
            OpticalDistortion(
                p=opticaldistortion_prob, distort_limit=2, shift_limit=0.5),
        ]

        # Axis flips, gamma and RGB shift.
        pipeline += [
            VerticalFlip(p=verticalflip_prob),
            HorizontalFlip(p=horizontalflip_prob),
            RandomGamma(p=randomgamma_prob),
            RGBShift(p=RGBShift_prob),
        ]

        # Blur, noise, channel shuffling, dropout and colour jitter.
        pipeline += [
            MotionBlur(p=MotionBlur_prob, blur_limit=7),
            MedianBlur(p=MedianBlur_prob, blur_limit=9),
            GaussianBlur(p=GaussianBlur_prob, blur_limit=9),
            GaussNoise(p=GaussNoise_prob),
            ChannelShuffle(p=ChannelShuffle_prob),
            CoarseDropout(
                p=CoarseDropout_prob, max_holes=8, max_height=32,
                max_width=32),
            ColorJitter(p=ColorJitter_prob),
        ]
        # Resize(352, 352) and ImageNet-style Normalize are deliberately not
        # included (they were disabled in the original code).

        self.transforms = al.Compose(pipeline, p=self.prob)
Example #6
0
    from albumentations.core.composition import Compose, OneOf

    from glob import glob

    train_img_paths = []
    train_mask_paths = []
    train_data_path = ["data/kvasir-seg/TrainDataset"]
    for i in train_data_path:
        train_img_paths.extend(glob(os.path.join(i, "images", "*")))
        train_mask_paths.extend(glob(os.path.join(i, "masks", "*")))
    train_img_paths.sort()
    train_mask_paths.sort()

    transforms = al.Compose(
        [
            transforms.RandomRotate90(),
            transforms.Flip(),
            transforms.HueSaturationValue(),
            transforms.RandomBrightnessContrast(),
            transforms.Transpose(),
            OneOf(
                [
                    transforms.RandomCrop(220, 220, p=0.5),
                    transforms.CenterCrop(220, 220, p=0.5),
                ],
                p=1,
            ),
            # transforms.Resize(352,352),
            # transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        ],
        p=1,
def main():
    """Train and validate a segmentation model described by the CLI config.

    Creates ``models/<name>/`` and dumps the config there, builds the loss,
    model, optimizer, scheduler and data loaders, then runs the epoch loop.
    ``log.csv`` is rewritten after every epoch and ``model.pth`` is saved
    whenever the validation IoU improves. Training stops early once
    ``config['early_stopping']`` epochs pass without improvement (a negative
    value disables early stopping).

    Raises:
        NotImplementedError: if ``config['optimizer']`` or
            ``config['scheduler']`` is not one of the supported names.
    """
    config = vars(parse_args())

    if config['name'] is None:
        if config['deep_supervision']:
            config['name'] = '%s_%s_wDS' % (config['dataset'], config['arch'])
        else:
            config['name'] = '%s_%s_woDS' % (config['dataset'], config['arch'])
    os.makedirs('models/%s' % config['name'], exist_ok=True)

    print('-' * 20)
    for key in config:
        print('%s: %s' % (key, config[key]))
    print('-' * 20)

    with open('models/%s/config.yml' % config['name'], 'w') as f:
        yaml.dump(config, f)

    # define loss function (criterion)
    if config['loss'] == 'BCEWithLogitsLoss':
        criterion = nn.BCEWithLogitsLoss().cuda()
    else:
        criterion = losses.__dict__[config['loss']]().cuda()

    cudnn.benchmark = True

    # create model
    print("=> creating model %s" % config['arch'])
    model = archs.__dict__[config['arch']](config['num_classes'],
                                           config['input_channels'],
                                           config['deep_supervision'])

    model = model.cuda()

    # Only parameters that require gradients are handed to the optimizer.
    params = filter(lambda p: p.requires_grad, model.parameters())
    if config['optimizer'] == 'Adam':
        optimizer = optim.Adam(params,
                               lr=config['lr'],
                               weight_decay=config['weight_decay'])
    elif config['optimizer'] == 'SGD':
        optimizer = optim.SGD(params,
                              lr=config['lr'],
                              momentum=config['momentum'],
                              nesterov=config['nesterov'],
                              weight_decay=config['weight_decay'])
    else:
        raise NotImplementedError

    if config['scheduler'] == 'CosineAnnealingLR':
        scheduler = lr_scheduler.CosineAnnealingLR(optimizer,
                                                   T_max=config['epochs'],
                                                   eta_min=config['min_lr'])
    elif config['scheduler'] == 'ReduceLROnPlateau':
        scheduler = lr_scheduler.ReduceLROnPlateau(optimizer,
                                                   factor=config['factor'],
                                                   patience=config['patience'],
                                                   verbose=1,
                                                   min_lr=config['min_lr'])
    elif config['scheduler'] == 'MultiStepLR':
        scheduler = lr_scheduler.MultiStepLR(
            optimizer,
            milestones=[int(e) for e in config['milestones'].split(',')],
            gamma=config['gamma'])
    elif config['scheduler'] == 'ConstantLR':
        scheduler = None
    else:
        raise NotImplementedError

    # Data loading code
    img_dir = os.path.join('inputs', config['dataset'], 'images')
    mask_dir = os.path.join('inputs', config['dataset'], 'masks')
    img_ids = glob(os.path.join(img_dir, '*' + config['img_ext']))
    img_ids = [os.path.splitext(os.path.basename(p))[0] for p in img_ids]

    train_img_ids, val_img_ids = train_test_split(img_ids,
                                                  test_size=0.2,
                                                  random_state=41)

    #######################
    # NOTE(review): this override trains on ALL images, so the validation
    # images are also seen during training and the validation metrics (and
    # the "best model" selection based on them) are optimistic. Kept as-is
    # because it was explicitly marked; confirm it is intended.
    train_img_ids = img_ids
    #######################

    train_transform = Compose([
        transforms.RandomRotate90(),
        transforms.Flip(),
        OneOf([
            transforms.HueSaturationValue(),
            transforms.RandomBrightness(),
            transforms.RandomContrast(),
        ],
              p=1),
        transforms.Resize(config['input_h'], config['input_w']),
        transforms.Normalize(),
    ])

    val_transform = Compose([
        transforms.Resize(config['input_h'], config['input_w']),
        transforms.Normalize(),
    ])

    train_dataset = Dataset(img_ids=train_img_ids,
                            img_dir=img_dir,
                            mask_dir=mask_dir,
                            img_ext=config['img_ext'],
                            mask_ext=config['mask_ext'],
                            num_classes=config['num_classes'],
                            transform=train_transform)
    val_dataset = Dataset(img_ids=val_img_ids,
                          img_dir=img_dir,
                          mask_dir=mask_dir,
                          img_ext=config['img_ext'],
                          mask_ext=config['mask_ext'],
                          num_classes=config['num_classes'],
                          transform=val_transform)

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=config['batch_size'],
        shuffle=True,
        num_workers=config['num_workers'],
        drop_last=True)
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=config['batch_size'],
                                             shuffle=False,
                                             num_workers=config['num_workers'],
                                             drop_last=False)

    log = OrderedDict([
        ('epoch', []),
        ('lr', []),
        ('loss', []),
        ('iou', []),
        ('val_loss', []),
        ('val_iou', []),
    ])

    best_iou = 0
    trigger = 0  # epochs since the last validation-IoU improvement
    for epoch in range(config['epochs']):
        print('Epoch [%d/%d]' % (epoch, config['epochs']))

        # train for one epoch
        train_log = train(config, train_loader, model, criterion, optimizer)
        # evaluate on validation set
        val_log = validate(config, val_loader, model, criterion)

        # Record the learning rate that was actually used this epoch; the
        # configured value goes stale as soon as a scheduler changes it.
        epoch_lr = optimizer.param_groups[0]['lr']

        # ReduceLROnPlateau needs the monitored metric; every other scheduler
        # (CosineAnnealingLR, MultiStepLR) is stepped once per epoch.
        # 'ConstantLR' has no scheduler object at all.
        if config['scheduler'] == 'ReduceLROnPlateau':
            scheduler.step(val_log['loss'])
        elif scheduler is not None:
            scheduler.step()

        print('loss %.4f - iou %.4f - val_loss %.4f - val_iou %.4f' %
              (train_log['loss'], train_log['iou'], val_log['loss'],
               val_log['iou']))

        log['epoch'].append(epoch)
        log['lr'].append(epoch_lr)
        log['loss'].append(train_log['loss'])
        log['iou'].append(train_log['iou'])
        log['val_loss'].append(val_log['loss'])
        log['val_iou'].append(val_log['iou'])

        # Rewritten every epoch so the log survives an interrupted run.
        pd.DataFrame(log).to_csv('models/%s/log.csv' % config['name'],
                                 index=False)

        trigger += 1

        if val_log['iou'] > best_iou:
            torch.save(model.state_dict(),
                       'models/%s/model.pth' % config['name'])
            best_iou = val_log['iou']
            print("=> saved best model")
            trigger = 0

        # early stopping
        if config['early_stopping'] >= 0 and trigger >= config[
                'early_stopping']:
            print("=> early stopping")
            break

        torch.cuda.empty_cache()
Example #8
0
    def get_transforms(phase: str, cli_args) -> Dict[str, Compose]:
        """Get composed albumentations augmentations

        Parameters
        ----------
        phase : str
            Phase of learning
            In ['train', 'val']; only 'train' adds random augmentations
        cli_args
            Arguments coming all the way from `main.py`
            (only `image_size` is read here)

        Returns
        -------
        transforms: dict[str, albumentations.core.composition.Compose]
            Composed transforms under the keys 'aug', 'img_only',
            'mask_only' and 'final'
        """
        side = cli_args.image_size
        crop_hw = (side, side)

        # Geometric and "exotic" augmentations are used for training only.
        if phase == "train":
            steps = [
                tf.ShiftScaleRotate(
                    shift_limit=0,
                    scale_limit=0.1,
                    rotate_limit=15,
                    p=0.5),
                tf.Flip(p=0.5),
                tf.RandomRotate90(p=0.5),
                tf.RandomBrightnessContrast(p=0.5),
                tf.ElasticTransform(p=0.5),
                tf.MultiplicativeNoise(multiplier=(0.5, 1.5),
                                       per_channel=True, p=0.2),
            ]
        else:
            steps = []

        # Every phase ends with the always-applied square crop.
        steps.append(
            tf.RandomSizedCrop(min_max_height=crop_hw,
                               height=crop_hw[0],
                               width=crop_hw[1],
                               w2h_ratio=1.0,
                               interpolation=cv2.INTER_LINEAR,
                               p=1.0))

        return {
            'aug': Compose(steps),
            'img_only': Compose([
                tf.Normalize(mean=(0.0, 0.0, 0.0), std=(1.0, 1.0, 1.0),
                             always_apply=True)
            ]),
            'mask_only': Compose([
                tf.Normalize(mean=0, std=1, always_apply=True)
            ]),
            'final': Compose([ToTensorV2()]),
        }
Example #9
0
def main():
    """Train and validate a segmentation model, logging to CSV and TensorBoard.

    Results go under ``<cfg.UNET_RESULTS_DIR>/<name>/``: ``models/`` holds the
    dumped config and the best checkpoint, ``log/`` holds ``log.csv`` and the
    TensorBoard event files. A checkpoint is saved only when the validation
    IoU improves and the validation loss is not worse than the reference
    loss. Training stops early after ``config["early_stopping"]`` epochs
    without a saved checkpoint, or once the validation loss drops below 1e-4.

    Raises:
        NotImplementedError: if ``config["optimizer"]``,
            ``config["scheduler"]`` or ``config["dataset"]`` is not supported.
    """
    config = vars(parse_args())
    now = datetime.datetime.now()

    if config["name"] is None:
        # Timestamped default name; the tag records deep supervision on/off.
        ds_tag = "wDS" if config["deep_supervision"] else "woDS"
        config["name"] = "%s_%s_%s_%s" % (
            config["dataset"],
            config["arch"],
            ds_tag,
            now.strftime("%Y%m%d_%H%M%S"),
        )
    output_path = os.path.join(cfg.UNET_RESULTS_DIR, config["name"])

    # makedirs(exist_ok=True) creates the parent too and does not fail when a
    # run with an explicit name is restarted.
    models_path = os.path.join(output_path, "models")
    os.makedirs(models_path, exist_ok=True)

    with open(os.path.join(models_path, "config.yml"), "w") as f:
        yaml.dump(config, f)

    print("-" * 20)
    for key in config:
        print("%s: %s" % (key, config[key]))
    print("-" * 20)

    # Directory that receives the TensorBoard logs
    log_dir = os.path.join(output_path, "log")
    os.makedirs(log_dir, exist_ok=True)
    writer = SummaryWriter(log_dir=log_dir)

    # Close the writer even if setup or training raises.
    try:
        # define loss function(criterion)
        if config["loss"] == "BCEWithLogitsLoss":
            criterion = nn.BCEWithLogitsLoss().cuda()
        else:
            criterion = losses.__dict__[config["loss"]]().cuda()

        cudnn.benchmark = True

        # create model
        print("=> creating model %s" % config["arch"])
        model = archs.__dict__[config["arch"]](config["num_classes"],
                                               config["input_channels"],
                                               config["deep_supervision"])

        model = model.cuda()

        # To show the model graph in TensorBoard, log it like this:
        # image = torch.randn(1, 3, 2224, 224)
        # writer.add_graph(model, image)

        params = filter(lambda p: p.requires_grad, model.parameters())
        if config["optimizer"] == "Adam":
            optimizer = optim.Adam(params,
                                   lr=config["lr"],
                                   weight_decay=config["weight_decay"])
        elif config["optimizer"] == "SGD":
            optimizer = optim.SGD(
                params,
                lr=config["lr"],
                momentum=config["momentum"],
                nesterov=config["nesterov"],
                weight_decay=config["weight_decay"],
            )
        else:
            raise NotImplementedError

        # scheduler
        if config["scheduler"] == "CosineAnnealingLR":
            scheduler = lr_scheduler.CosineAnnealingLR(
                optimizer=optimizer,
                T_max=config["epochs"],
                eta_min=config["min_lr"])
        elif config["scheduler"] == "ReduceLROnPlateau":
            # Must be the scheduler class; calling the lr_scheduler module
            # itself raises TypeError.
            scheduler = lr_scheduler.ReduceLROnPlateau(
                optimizer=optimizer,
                factor=config["factor"],
                patience=config["patience"],
                verbose=1,
                min_lr=config["min_lr"],
            )
        elif config["scheduler"] == "MultiStepLR":
            scheduler = lr_scheduler.MultiStepLR(
                optimizer=optimizer,
                milestones=[int(e) for e in config["milestones"].split(",")],
                gamma=config["gamma"],
            )
        elif config["scheduler"] == "ConstantLR":
            scheduler = None
        else:
            raise NotImplementedError

        # Data loading code
        if config["dataset"] == "dsb2018_96":
            input_dir = cfg.DSB2018_96_DIR
            img_ids = glob(
                os.path.join(input_dir, "images", "*" + config["img_ext"]))
            img_ids = [
                os.path.splitext(os.path.basename(p))[0] for p in img_ids
            ]

            train_img_ids, val_img_ids = train_test_split(img_ids,
                                                          test_size=0.2,
                                                          random_state=41)

            train_transform = Compose([
                transforms.RandomRotate90(),
                transforms.Flip(),
                OneOf(
                    [
                        transforms.HueSaturationValue(),
                        transforms.RandomBrightness(),
                        transforms.RandomContrast(),
                    ],
                    p=1,
                ),
                transforms.Resize(config["input_h"], config["input_w"]),
                transforms.Normalize(),
            ])

            val_transform = Compose([
                transforms.Resize(config["input_h"], config["input_w"]),
                transforms.Normalize(),
            ])

            train_dataset = Dataset(
                img_ids=train_img_ids,
                img_dir=os.path.join(input_dir, "images"),
                mask_dir=os.path.join(input_dir, "masks"),
                img_ext=config["img_ext"],
                mask_ext=config["mask_ext"],
                num_classes=config["num_classes"],
                transform=train_transform,
            )

            val_dataset = Dataset(
                img_ids=val_img_ids,
                img_dir=os.path.join(input_dir, "images"),
                mask_dir=os.path.join(input_dir, "masks"),
                img_ext=config["img_ext"],
                mask_ext=config["mask_ext"],
                num_classes=config["num_classes"],
                transform=val_transform,
            )
        else:
            # Without this branch an unknown dataset only fails later with a
            # NameError on train_dataset.
            raise NotImplementedError

        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=config["batch_size"],
            shuffle=True,
            num_workers=config["num_workers"],
            drop_last=True,
        )
        val_loader = torch.utils.data.DataLoader(
            val_dataset,
            batch_size=config["batch_size"],
            shuffle=False,
            num_workers=config["num_workers"],
            drop_last=False,
        )

        log = OrderedDict([
            ("epoch", []),
            ("lr", []),
            ("loss", []),
            ("iou", []),
            ("val_loss", []),
            ("val_iou", []),
        ])

        best_iou = 0
        trigger = 0  # epochs since the last saved checkpoint
        # NOTE(review): this runs epochs + 1 iterations (0..epochs
        # inclusive); kept as in the original -- confirm it is intended.
        for epoch in range(config["epochs"] + 1):
            print("Epoch [%d/%d]" % (epoch, config["epochs"]))

            # train for one epoch
            train_log = train(config, train_loader, model, criterion,
                              optimizer)
            # evaluate on validation set
            val_log = validate(config, val_loader, model, criterion)

            # Learning rate actually used this epoch (read before the
            # scheduler changes it); config["lr"] is only the initial value.
            epoch_lr = optimizer.param_groups[0]["lr"]

            # ReduceLROnPlateau needs the monitored metric; every other
            # scheduler (CosineAnnealingLR, MultiStepLR) steps once per
            # epoch. "ConstantLR" has no scheduler object.
            if config["scheduler"] == "ReduceLROnPlateau":
                scheduler.step(val_log["loss"])
            elif scheduler is not None:
                scheduler.step()

            print("loss %.4f - iou %.4f - val_loss %.4f - val_iou %.4f" %
                  (train_log["loss"], train_log["iou"], val_log["loss"],
                   val_log["iou"]))

            log["epoch"].append(epoch)
            log["lr"].append(epoch_lr)
            log["loss"].append(train_log["loss"])
            log["iou"].append(train_log["iou"])
            log["val_loss"].append(val_log["loss"])
            log["val_iou"].append(val_log["iou"])

            # Data for TensorBoard
            writer.add_scalar("training loss", train_log["loss"], epoch)
            writer.add_scalar("validation loss", val_log["loss"], epoch)

            pd.DataFrame(log).to_csv("%s/log.csv" % (log_dir), index=False)

            # The first epoch's validation loss is the initial reference.
            if epoch == 0:
                best_loss = val_log["loss"]
            trigger += 1

            # Best Model Save: IoU must improve and the loss must not be
            # worse than the reference loss.
            if (val_log["iou"] > best_iou) and (val_log["loss"] <= best_loss):
                torch.save(model.state_dict(),
                           "%s/model.pth" % (models_path))
                best_iou = val_log["iou"]
                best_loss = val_log["loss"]
                print("=> saved best model")
                trigger = 0

            # early stopping
            if (config["early_stopping"] >= 0 and
                    trigger >= config["early_stopping"]
                    ) or val_log["loss"] < 1e-4:
                print("=> early stopping")
                break

            torch.cuda.empty_cache()
    finally:
        # Close the summary writer once it is no longer needed.
        writer.close()
Example #10
0
def train_entry(config, rank=-1, world_size=0):
    """Train a segmentation model, optionally as one distributed worker.

    Writes the config, ``log.csv``, per-class ``log_<id>.csv`` files and the
    best overall / per-class checkpoints under ``models/<name>/``. An existing
    ``models/<name>/model.pth`` is loaded before training starts.

    Args:
        config (dict): training configuration (keys as read below).
        rank (int): distributed rank of this process; a negative value (the
            default) means non-distributed training.
        world_size (int): forwarded to ``setup`` when ``rank >= 0``.

    Returns:
        The best overall validation IoU reached during training.

    Raises:
        NotImplementedError: if ``config['optimizer']`` or
            ``config['scheduler']`` is not one of the supported names.
    """
    # Every distributed rank, including rank 0, is wrapped in DDP and calls
    # cleanup() below, so every one of them must call setup() as well.
    if rank >= 0:
        setup(rank, world_size)

    os.makedirs('models/%s' % config['name'], exist_ok=True)
    with open('models/%s/config.yml' % config['name'], 'w') as f:
        yaml.dump(config, f)

    # define loss function (criterion)
    if config['loss'] == 'BCEWithLogitsLoss':
        criterion = nn.BCEWithLogitsLoss().cuda()
    else:
        criterion = losses.__dict__[config['loss']]().cuda()

    cudnn.benchmark = True

    # create model; only the part of 'arch' before the first '_' names the
    # architecture class
    arch = config['arch'].split('_')[0]

    print("=> creating model %s" % arch)
    model = getattr(archs, arch)(config['num_classes'],
                                 config['input_channels'],
                                 config['deep_supervision'])
    model = nn.DataParallel(model)
    print(f'rank={rank}, device={torch.cuda.current_device()}')
    if rank >= 0:
        model = DDP(model, device_ids=[rank])
    else:
        model = model.cuda()

    # load trained model
    model_file = 'models/%s/model.pth' % config['name']
    if os.path.isfile(model_file):
        print("Reloading model ...")
        model.load_state_dict(torch.load(model_file))

    params = filter(lambda p: p.requires_grad, model.parameters())
    if config['optimizer'] == 'Adam':
        optimizer = optim.Adam(
            params, lr=config['lr'], weight_decay=config['weight_decay'])
    elif config['optimizer'] == 'SGD':
        optimizer = optim.SGD(params, lr=config['lr'],
                              momentum=config['momentum'],
                              nesterov=config['nesterov'],
                              weight_decay=config['weight_decay'])
    else:
        raise NotImplementedError

    if config['scheduler'] == 'CosineAnnealingLR':
        scheduler = lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=config['epochs'], eta_min=config['min_lr'])
    elif config['scheduler'] == 'CosineAnnealingWarmRestarts':
        scheduler = lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=50)
    elif config['scheduler'] == 'ReduceLROnPlateau':
        # mode='max': the monitored quantity is expected to increase, so it
        # is stepped with the validation IoU in the epoch loop.
        scheduler = lr_scheduler.ReduceLROnPlateau(
            optimizer, mode='max', factor=config['factor'],
            patience=config['patience'], min_lr=config['min_lr'])
    elif config['scheduler'] == 'MultiStepLR':
        scheduler = lr_scheduler.MultiStepLR(
            optimizer,
            milestones=[int(e) for e in config['milestones'].split(',')],
            gamma=config['gamma'])
    elif config['scheduler'] == 'ConstantLR':
        scheduler = None
    else:
        raise NotImplementedError

    # Data loading code
    data_path = os.path.join('data', config['dataset'])
    img_dir = os.path.join(data_path, config['sub_dataset'], 'images')
    mask_dir = os.path.join(data_path, config['sub_dataset'], 'masks')
    img_ids = glob(os.path.join(img_dir, '*' + config['img_ext']))
    img_ids = [os.path.splitext(os.path.basename(p))[0] for p in img_ids]
    # A positive num_inputs caps how many images are used.
    if config['num_inputs'] > 0:
        img_ids = img_ids[:config['num_inputs']]
    train_img_ids, val_img_ids = train_test_split(
        img_ids, test_size=0.2, random_state=41)

    split_filename = os.path.join(data_path, 'split.p')
    if config['save_split'] is True:
        with open(split_filename, 'wb') as split_file:
            pickle.dump((train_img_ids, val_img_ids), split_file)

    if config['load_split'] is True:
        # NOTE: pickle.load can execute arbitrary code; only load split
        # files that this project wrote itself.
        with open(split_filename, 'rb') as split_file:
            train_img_ids, val_img_ids = pickle.load(split_file)

    # Since the images are gray, some augmentation methods are not used.

    # read statistics
    with open('data_processing/data_statistics.json') as json_file:
        data = json.load(json_file)

        entry = '%s_%s' % (config['dataset'], config['sub_dataset'])
        mean = data[entry]['mean']
        std = data[entry]['std']

        print("%s: mean=[%s] std=[%s]" % (
            config['dataset'],
            ', '.join(map(str, mean)),
            ', '.join(map(str, std))))

    # data normalization (built but currently not added to either pipeline)
    normalize = transforms.Normalize(mean=mean, std=std)

    train_transform = Compose([
        transforms.RandomRotate90(),
        transforms.Flip(),
        transforms.Resize(config['input_h'], config['input_w']),
        # normalize,
    ])

    val_transform = Compose([
        transforms.Resize(config['input_h'], config['input_w']),
        # normalize,
    ])

    train_dataset = Dataset(
        img_ids=train_img_ids,
        img_dir=img_dir,
        mask_dir=mask_dir,
        img_ext=config['img_ext'],
        mask_ext=config['mask_ext'],
        num_classes=config['num_classes'],
        transform=train_transform)
    val_dataset = Dataset(
        img_ids=val_img_ids,
        img_dir=img_dir,
        mask_dir=mask_dir,
        img_ext=config['img_ext'],
        mask_ext=config['mask_ext'],
        num_classes=config['num_classes'],
        transform=val_transform)

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=config['batch_size'],
        shuffle=True,
        num_workers=config['num_workers'],
        drop_last=True)
    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=config['batch_size'],
        shuffle=False,
        num_workers=config['num_workers'],
        drop_last=False)

    log = OrderedDict([
        ('epoch', []),
        ('lr', []),
        ('loss', []),
        ('iou', []),
        ('dice', []),
        ('val_loss', []),
        ('val_iou', []),
        ('val_dice', [])
    ])

    log_per_class = [
        OrderedDict([('val_iou', []), ('val_dice', [])])
        for _ in range(config['num_classes'])
    ]

    best_iou = 0
    best_iou_per_class = [0] * config['num_classes']

    trigger = 0  # epochs since the last overall-IoU improvement
    for epoch in range(config['epochs']):
        print('Epoch [%d/%d]' % (epoch, config['epochs']))

        # train for one epoch
        train_log = train(config, train_loader, model, criterion, optimizer)
        # evaluate on validation set
        val_log = validate(config, val_loader, model, criterion)

        # evaluate on validation set for each class
        val_per_class_log = test_per_class(config, val_loader, model)

        # Learning rate actually used this epoch (read before the scheduler
        # changes it); config['lr'] is only the initial value.
        epoch_lr = optimizer.param_groups[0]['lr']

        if config['scheduler'] == 'ReduceLROnPlateau':
            # The scheduler was built with mode='max', so feed it the metric
            # that should increase; stepping it with the loss would cut the
            # LR whenever the loss keeps improving.
            scheduler.step(val_log['iou'])
        elif scheduler is not None:
            # CosineAnnealingLR, CosineAnnealingWarmRestarts and MultiStepLR
            # all advance once per epoch.
            scheduler.step()

        print('loss %.4f - iou %.4f - dice %.4f - val_loss %.4f - val_iou %.4f - val_dice %.4f'
              % (train_log['loss'], train_log['iou'], train_log['dice'],
                 val_log['loss'], val_log['iou'], val_log['dice']))

        log['epoch'].append(epoch)
        log['lr'].append(epoch_lr)

        log['loss'].append(train_log['loss'])
        log['iou'].append(train_log['iou'])
        log['dice'].append(train_log['dice'])

        log['val_loss'].append(val_log['loss'])
        log['val_iou'].append(val_log['iou'])
        log['val_dice'].append(val_log['dice'])

        pd.DataFrame(log).to_csv('models/%s/log.csv' %
                                 config['name'], index=False)

        # write into log for each class
        for class_id in range(config['num_classes']):
            class_iou = val_per_class_log[class_id]['iou']
            class_dice = val_per_class_log[class_id]['dice']

            log_per_class[class_id]['val_iou'].append(class_iou)
            log_per_class[class_id]['val_dice'].append(class_dice)

            print('[class %d] val_iou %.4f - val_dice %.4f'
                  % (class_id, class_iou, class_dice))

            pd.DataFrame(log_per_class[class_id]).to_csv(
                'models/%s/log_%d.csv' % (config['name'], class_id),
                index=False)

            # Keep a separate checkpoint for each class's best IoU.
            if class_iou > best_iou_per_class[class_id]:
                torch.save(model.state_dict(),
                           'models/%s/model_%d.pth' % (config['name'], class_id))
                best_iou_per_class[class_id] = class_iou
                print("===> saved best model for class %d" % class_id)

        trigger += 1

        if val_log['iou'] > best_iou:
            torch.save(model.state_dict(), 'models/%s/model.pth' %
                       config['name'])
            best_iou = val_log['iou']
            print("=> saved best model")
            trigger = 0

        # early stopping
        if config['early_stopping'] >= 0 and trigger >= config['early_stopping']:
            print("=> early stopping")
            break

        torch.cuda.empty_cache()

        # Running summary, printed after every epoch that did not stop early.
        print("\n[Best Results:]")
        print('- [Overall IoU] %.4f' % best_iou)
        for class_id in range(config['num_classes']):
            print('-- [class %d IoU] %.4f' % (class_id, best_iou_per_class[class_id]))

    if rank >= 0:
        cleanup()

    return best_iou