# Assumed imports for this snippet (an older albumentations release, < 1.0,
# where the imgaug-backed IAA* transforms still exist; RandomCenterCut is a
# project-specific transform defined elsewhere):
import cv2
import albumentations.imgaug.transforms  # requires the `imgaug` extra
import albumentations
from albumentations.augmentations import transforms
from albumentations.core import composition


def augmentation(mode, target_size, prob=0.5, aug_m=2):
    '''
    Build an albumentations augmentation pipeline.
    mode: 'train' or 'test'
    target_size: (height, width) pair for the output image
    prob: base probability of applying the heavier transforms
    aug_m: strength multiplier for the transforms
    '''
    high_p = prob
    low_p = high_p / 2.0
    M = aug_m
    first_size = [int(x / 0.7) for x in target_size]

    if mode == 'train':
        return composition.Compose([
            transforms.Resize(first_size[0], first_size[1], interpolation=3),
            transforms.Flip(p=0.5),
            composition.OneOf([
                RandomCenterCut(scale=0.1 * M),
                transforms.ShiftScaleRotate(shift_limit=0.05 * M,
                                            scale_limit=0.1 * M,
                                            rotate_limit=180,
                                            border_mode=cv2.BORDER_CONSTANT,
                                            value=0),
                albumentations.imgaug.transforms.IAAAffine(
                    shear=(-10 * M, 10 * M), mode='constant')
            ],
                              p=high_p),
            transforms.RandomBrightnessContrast(
                brightness_limit=0.1 * M, contrast_limit=0.03 * M, p=high_p),
            transforms.HueSaturationValue(hue_shift_limit=5 * M,
                                          sat_shift_limit=15 * M,
                                          val_shift_limit=10 * M,
                                          p=high_p),
            transforms.OpticalDistortion(distort_limit=0.03 * M,
                                         shift_limit=0,
                                         border_mode=cv2.BORDER_CONSTANT,
                                         value=0,
                                         p=low_p),
            composition.OneOf([
                transforms.Blur(blur_limit=7),
                albumentations.imgaug.transforms.IAASharpen(),
                transforms.GaussNoise(var_limit=(2.0, 10.0), mean=0),
                transforms.ISONoise()
            ],
                              p=low_p),
            transforms.Resize(target_size[0], target_size[1], interpolation=3),
            transforms.Normalize(mean=(0.5, 0.5, 0.5),
                                 std=(0.5, 0.5, 0.5),
                                 max_pixel_value=255.0)
        ],
                                   p=1)

    else:
        return composition.Compose([
            transforms.Resize(target_size[0], target_size[1], interpolation=3),
            transforms.Normalize(mean=(0.5, 0.5, 0.5),
                                 std=(0.5, 0.5, 0.5),
                                 max_pixel_value=255.0)
        ],
                                   p=1)
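
# Usage sketch (hypothetical, and limited to 'test' mode here because the
# 'train' branch also needs the project-specific RandomCenterCut): a Compose
# pipeline is called with keyword targets and returns a dict.
if __name__ == '__main__':
    import numpy as np
    test_aug = augmentation('test', target_size=[224, 224])
    dummy = np.random.randint(0, 256, size=(320, 320, 3), dtype=np.uint8)
    out = test_aug(image=dummy)['image']  # float32, roughly in [-1, 1]
    print(out.shape)                      # (224, 224, 3)
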
# Example 2
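
# The aliases below are not defined in this snippet; a plausible reconstruction
# (the exact imports and the BigEarthNet band statistics are assumptions):
from albumentations import Compose as ACompose
from albumentations.augmentations import transforms as atransforms
from albumentations.pytorch import ToTensorV2 as AToTensor
BEN_BAND_STATS = {  # hypothetical placeholder values, not the real band stats
    'mean': (0.485, 0.456, 0.406),
    'std': (0.229, 0.224, 0.225),
}
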
def ben_augmentation():
    # !! Probably need to add a color-changing augmentation here
    return ACompose([
        atransforms.HorizontalFlip(p=0.5),
        atransforms.RandomRotate90(p=1.0),
        atransforms.ShiftScaleRotate(shift_limit=0, scale_limit=0, p=1.0),
        atransforms.RandomSizedCrop((60, 120),
                                    height=128,
                                    width=128,
                                    interpolation=3),

        # atransforms.GridDistortion(num_steps=5, p=0.5), # !! Maybe too much noise?
        atransforms.Normalize(mean=BEN_BAND_STATS['mean'],
                              std=BEN_BAND_STATS['std']),

        # atransforms.ChannelDropout(channel_drop_range=(1, 2), p=0.5),
        AToTensor(),
    ])
def get_augmentations():
    """Get a list of 'major' and 'minor' augmentation functions for the pipeline in a dictionary."""
    return {
        "major": {
            "shift-scale-rot":
            trans.ShiftScaleRotate(
                shift_limit=0.05,
                rotate_limit=35,
                border_mode=cv2.BORDER_REPLICATE,
                always_apply=True,
            ),
            "crop":
            trans.RandomResizedCrop(100,
                                    100,
                                    scale=(0.8, 0.95),
                                    ratio=(0.8, 1.2),
                                    always_apply=True),
            # "elastic": trans.ElasticTransform(
            #     alpha=0.8,
            #     alpha_affine=10,
            #     sigma=40,
            #     border_mode=cv2.BORDER_REPLICATE,
            #     always_apply=True,
            # ),
            "distort":
            trans.OpticalDistortion(0.2, always_apply=True),
        },
        "minor": {
            "blur":
            trans.GaussianBlur(7, always_apply=True),
            "noise":
            trans.GaussNoise((20.0, 40.0), always_apply=True),
            "bright-contrast":
            trans.RandomBrightnessContrast(0.4, 0.4, always_apply=True),
            "hsv":
            trans.HueSaturationValue(30, 40, 50, always_apply=True),
            "rgb":
            trans.RGBShift(always_apply=True),
            "flip":
            trans.HorizontalFlip(always_apply=True),
        },
    }
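
# Sketch of how the dictionary above might be consumed (an assumption about the
# original pipeline): draw one 'major' and a couple of 'minor' transforms per
# image and compose them. `trans` is albumentations itself, matching the alias
# used above; cv2 is needed by the border-mode flags at call time.
import random
import cv2
import albumentations as trans

def sample_pipeline(n_minor=2):
    augs = get_augmentations()
    major = random.choice(list(augs['major'].values()))
    minors = random.sample(list(augs['minor'].values()), k=n_minor)
    return trans.Compose([major] + minors)
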
    def train_dataloader(self):
        augmentations = Compose(
            [
                A.RandomResizedCrop(
                    height=self.hparams.sz,
                    width=self.hparams.sz,
                    scale=(0.7, 1.0),
                ),
                # AdvancedHairAugmentation(),
                A.GridDistortion(),
                A.RandomBrightnessContrast(),
                A.ShiftScaleRotate(),
                A.Flip(p=0.5),
                A.CoarseDropout(
                    max_height=int(self.hparams.sz / 10),
                    max_width=int(self.hparams.sz / 10),
                ),
                # A.HueSaturationValue(),
                A.Normalize(
                    mean=[0.485, 0.456, 0.406],
                    std=[0.229, 0.224, 0.225],
                    max_pixel_value=255,
                ),
                ToTensorV2(),
            ]
        )
        train_ds = MelanomaDataset(
            df=self.train_df,
            images_path=self.train_images_path,
            augmentations=augmentations,
            train_or_valid=True,
        )
        return DataLoader(
            train_ds,
            # sampler=sampler,
            batch_size=self.hparams.bs,
            shuffle=True,
            num_workers=os.cpu_count(),
            pin_memory=True,
        )
def get_tta_transforms():
    return Compose([
        A.RandomResizedCrop(
            height=hparams.sz,
            width=hparams.sz,
            scale=(0.7, 1.0),
        ),
        # AdvancedHairAugmentation(),
        A.GridDistortion(),
        A.RandomBrightnessContrast(),
        A.ShiftScaleRotate(),
        A.Flip(p=0.5),
        A.CoarseDropout(
            max_height=int(hparams.sz / 10),
            max_width=int(hparams.sz / 10),
        ),
        # A.HueSaturationValue(),
        A.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
            max_pixel_value=255,
        ),
        ToTensorV2(),
    ])
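
# TTA usage sketch (assumptions: A/Compose/ToTensorV2 are the usual
# albumentations imports, `hparams` is a module-level namespace with an `sz`
# attribute, and `model` maps a float tensor batch to per-class logits).
from types import SimpleNamespace
import torch

hparams = SimpleNamespace(sz=256)  # hypothetical stand-in for the real config

def predict_tta(model, image, n_aug=8):
    tta = get_tta_transforms()
    batch = torch.stack([tta(image=image)['image'] for _ in range(n_aug)])
    with torch.no_grad():
        return model(batch).mean(dim=0)  # average over augmented copies
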
    def create_transform(self, args, is_train):
        """
        Convert numpy array into Tensor if dataset is for validation.
        Apply data augmentation method to train dataset while cv or test if args.use_aug is 1.

        is_train: boolean
            flg that dataset is for validation in cv or test
        return: Compose of albumentations
        """
        if is_train and args.use_aug == 1:
            transform = A.Compose([
                trans.Normalize(mean=self.cifar_10_mean,
                                std=self.cifar_10_std,
                                max_pixel_value=1.0),
                trans.HorizontalFlip(p=0.5),
                trans.ShiftScaleRotate(shift_limit=0,
                                       scale_limit=0.25,
                                       rotate_limit=30,
                                       p=1),
                trans.CoarseDropout(max_holes=1,
                                    min_holes=1,
                                    min_width=12,
                                    min_height=12,
                                    max_height=12,
                                    max_width=12,
                                    p=0.5),
                ToTensorV2()
            ])
        else:
            transform = A.Compose([
                trans.Normalize(mean=self.cifar_10_mean,
                                std=self.cifar_10_std,
                                max_pixel_value=1.0),
                ToTensorV2()
            ])
        return transform
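
# Note on max_pixel_value=1.0 above: albumentations.Normalize computes
# (img / max_pixel_value - mean) / std, so the CIFAR-10 arrays are assumed to
# be float images already scaled to [0, 1]. A minimal standalone check:
def _normalize_demo():
    import numpy as np
    from albumentations import Normalize
    x = np.random.rand(32, 32, 3).astype('float32')   # float image in [0, 1]
    norm = Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
                     max_pixel_value=1.0)
    return norm(image=x)['image']                     # roughly in [-1, 1]
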
def main():
    args = parse_args()

    if args.name is None:
        args.name = '%s_%s' % (args.arch, datetime.now().strftime('%m%d%H'))

    if not os.path.exists('models/%s' % args.name):
        os.makedirs('models/%s' % args.name)

    if args.resume:
        args = joblib.load('models/%s/args.pkl' % args.name)
        args.resume = True

    print('Config -----')
    for arg in vars(args):
        print('- %s: %s' % (arg, getattr(args, arg)))
    print('------------')

    with open('models/%s/args.txt' % args.name, 'w') as f:
        for arg in vars(args):
            print('- %s: %s' % (arg, getattr(args, arg)), file=f)

    joblib.dump(args, 'models/%s/args.pkl' % args.name)

    if args.seed is not None and not args.resume:
        print('set random seed')
        random.seed(args.seed)
        np.random.seed(args.seed)
        torch.manual_seed(args.seed)

    if args.loss == 'BCEWithLogitsLoss':
        criterion = BCEWithLogitsLoss().cuda()
    elif args.loss == 'WeightedBCEWithLogitsLoss':
        criterion = BCEWithLogitsLoss(weight=torch.Tensor([1., 1., 1., 1., 1., 2.]),
                                      smooth=args.label_smooth).cuda()
    elif args.loss == 'FocalLoss':
        criterion = FocalLoss().cuda()
    elif args.loss == 'WeightedFocalLoss':
        criterion = FocalLoss(weight=torch.Tensor([1., 1., 1., 1., 1., 2.])).cuda()
    else:
        raise NotImplementedError

    if args.pred_type == 'all':
        num_outputs = 6
    elif args.pred_type == 'except_any':
        num_outputs = 5
    else:
        raise NotImplementedError

    cudnn.benchmark = True

    # create model
    model = get_model(model_name=args.arch,
                      num_outputs=num_outputs,
                      freeze_bn=args.freeze_bn,
                      dropout_p=args.dropout_p,
                      pooling=args.pooling,
                      lp_p=args.lp_p)
    model = model.cuda()

    train_transform = Compose([
        transforms.Resize(args.img_size, args.img_size),
        transforms.HorizontalFlip() if args.hflip else NoOp(),
        transforms.VerticalFlip() if args.vflip else NoOp(),
        transforms.ShiftScaleRotate(
            shift_limit=args.shift_limit,
            scale_limit=args.scale_limit,
            rotate_limit=args.rotate_limit,
            border_mode=cv2.BORDER_CONSTANT,
            value=0,
            p=args.shift_scale_rotate_p
        ) if args.shift_scale_rotate else NoOp(),
        transforms.RandomContrast(
            limit=args.contrast_limit,
            p=args.contrast_p
        ) if args.contrast else NoOp(),
        RandomErase() if args.random_erase else NoOp(),
        transforms.CenterCrop(args.crop_size, args.crop_size) if args.center_crop else NoOp(),
        ForegroundCenterCrop(args.crop_size) if args.foreground_center_crop else NoOp(),
        transforms.RandomCrop(args.crop_size, args.crop_size) if args.random_crop else NoOp(),
        transforms.Normalize(mean=model.mean, std=model.std),
        ToTensor(),
    ])
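    # Each flag-gated transform above collapses to NoOp() when its flag is off,
    # so the Compose list keeps a fixed structure regardless of the config.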

    if args.img_type:
        stage_1_train_dir = 'processed/stage_1_train_%s' %args.img_type
    else:
        stage_1_train_dir = 'processed/stage_1_train'

    df = pd.read_csv('inputs/stage_1_train.csv')
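    # The CSV stores 6 rows per image (ID_<hash>_<subtype>), so [::6] keeps one
    # row per image and df.loc[c::6, 'Label'] gathers the labels of subtype c
    # across all images.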
    img_paths = np.array([stage_1_train_dir + '/' + '_'.join(s.split('_')[:-1]) + '.png' for s in df['ID']][::6])
    labels = np.array([df.loc[c::6, 'Label'].values for c in range(6)]).T.astype('float32')

    df = df[::6]
    df['img_path'] = img_paths
    for c in range(6):
        df['label_%d' %c] = labels[:, c]
    df['ID'] = df['ID'].apply(lambda s: '_'.join(s.split('_')[:-1]))

    meta_df = pd.read_csv('processed/stage_1_train_meta.csv')
    meta_df['ID'] = meta_df['SOPInstanceUID']
    test_meta_df = pd.read_csv('processed/stage_1_test_meta.csv')
    df = pd.merge(df, meta_df, how='left')

    patient_ids = meta_df['PatientID'].unique()
    test_patient_ids = test_meta_df['PatientID'].unique()
    if args.remove_test_patient_ids:
        patient_ids = np.array([s for s in patient_ids if s not in test_patient_ids])

    train_img_paths = np.hstack(df[['img_path', 'PatientID']].groupby(['PatientID'])['img_path'].apply(np.array).loc[patient_ids].to_list()).astype('str')
    train_labels = []
    for c in range(6):
        train_labels.append(np.hstack(df[['label_%d' %c, 'PatientID']].groupby(['PatientID'])['label_%d' %c].apply(np.array).loc[patient_ids].to_list()))
    train_labels = np.array(train_labels).T

    if args.resume:
        checkpoint = torch.load('models/%s/checkpoint.pth.tar' % args.name)

    # train
    train_set = Dataset(
        train_img_paths,
        train_labels,
        transform=train_transform)
    train_loader = torch.utils.data.DataLoader(
        train_set,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
        # pin_memory=True,
    )

    if args.optimizer == 'Adam':
        optimizer = optim.Adam(
            filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr, weight_decay=args.weight_decay)
    elif args.optimizer == 'AdamW':
        optimizer = optim.AdamW(
            filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr, weight_decay=args.weight_decay)
    elif args.optimizer == 'RAdam':
        optimizer = RAdam(
            filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr, weight_decay=args.weight_decay)
    elif args.optimizer == 'SGD':
        optimizer = optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr,
                              momentum=args.momentum, weight_decay=args.weight_decay, nesterov=args.nesterov)
    else:
        raise NotImplementedError

    if args.apex:
        model, optimizer = amp.initialize(model, optimizer, opt_level='O1')

    if args.scheduler == 'CosineAnnealingLR':
        scheduler = lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=args.epochs, eta_min=args.min_lr)
    elif args.scheduler == 'MultiStepLR':
        scheduler = lr_scheduler.MultiStepLR(optimizer,
            milestones=[int(e) for e in args.milestones.split(',')], gamma=args.gamma)
    else:
        raise NotImplementedError

    log = {
        'epoch': [],
        'loss': [],
    }

    start_epoch = 0

    if args.resume:
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        scheduler.load_state_dict(checkpoint['scheduler'])
        start_epoch = checkpoint['epoch']
        log = pd.read_csv('models/%s/log.csv' % args.name).to_dict(orient='list')

    for epoch in range(start_epoch, args.epochs):
        print('Epoch [%d/%d]' % (epoch + 1, args.epochs))

        # train for one epoch
        train_loss = train(args, train_loader, model, criterion, optimizer, epoch)

        if args.scheduler in ('CosineAnnealingLR', 'MultiStepLR'):
            scheduler.step()

        print('loss %.4f' % (train_loss))

        log['epoch'].append(epoch)
        log['loss'].append(train_loss)

        pd.DataFrame(log).to_csv('models/%s/log.csv' % args.name, index=False)

        torch.save(model.state_dict(), 'models/%s/model.pth' % args.name)
        print("=> saved model")

        state = {
            'epoch': epoch + 1,
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'scheduler': scheduler.state_dict(),
        }
        torch.save(state, 'models/%s/checkpoint.pth.tar' % args.name)
def main():
    config = vars(parse_args())

    if config['name'] is None:
        config['name'] = '%s_%s' % (config['arch'],
                                    datetime.now().strftime('%m%d%H'))

    config['num_filters'] = [int(n) for n in config['num_filters'].split(',')]

    if not os.path.exists('models/detection/%s' % config['name']):
        os.makedirs('models/detection/%s' % config['name'])

    if config['resume']:
        with open('models/detection/%s/config.yml' % config['name'], 'r') as f:
            config = yaml.load(f, Loader=yaml.FullLoader)
        config['resume'] = True

    with open('models/detection/%s/config.yml' % config['name'], 'w') as f:
        yaml.dump(config, f)

    print('-' * 20)
    for key in config.keys():
        print('- %s: %s' % (key, str(config[key])))
    print('-' * 20)

    cudnn.benchmark = True

    df = pd.read_csv('inputs/train.csv')
    img_paths = np.array('inputs/train_images/' + df['ImageId'].values +
                         '.jpg')
    mask_paths = np.array('inputs/train_masks/' + df['ImageId'].values +
                          '.jpg')
    labels = np.array(
        [convert_str_to_labels(s) for s in df['PredictionString']])

    test_img_paths = None
    test_mask_paths = None
    test_outputs = None
    if config['pseudo_label'] is not None:
        test_df = pd.read_csv('inputs/sample_submission.csv')
        test_img_paths = np.array('inputs/test_images/' +
                                  test_df['ImageId'].values + '.jpg')
        test_mask_paths = np.array('inputs/test_masks/' +
                                   test_df['ImageId'].values + '.jpg')
        ext = os.path.splitext(config['pseudo_label'])[1]
        if ext == '.pth':
            test_outputs = torch.load('outputs/raw/test/%s' %
                                      config['pseudo_label'])
        elif ext == '.csv':
            test_labels = pd.read_csv('outputs/submissions/test/%s' %
                                      config['pseudo_label'])
            null_idx = test_labels.isnull().any(axis=1)
            test_img_paths = test_img_paths[~null_idx]
            test_mask_paths = test_mask_paths[~null_idx]
            test_labels = test_labels.dropna()
            test_labels = np.array([
                convert_str_to_labels(
                    s, names=['pitch', 'yaw', 'roll', 'x', 'y', 'z', 'score'])
                for s in test_labels['PredictionString']
            ])
            print(test_labels)
        else:
            raise NotImplementedError

    if config['resume']:
        checkpoint = torch.load('models/detection/%s/checkpoint.pth.tar' %
                                config['name'])

    heads = OrderedDict([
        ('hm', 1),
        ('reg', 2),
        ('depth', 1),
    ])
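    # Assumed CenterNet-style head layout: 'hm' = 1-channel object-center
    # heatmap, 'reg' = 2-channel sub-pixel center offset, 'depth' = 1-channel
    # distance; a rotation head and an optional 'wh' head are added below.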

    if config['rot'] == 'eular':
        heads['eular'] = 3
    elif config['rot'] == 'trig':
        heads['trig'] = 6
    elif config['rot'] == 'quat':
        heads['quat'] = 4
    else:
        raise NotImplementedError

    if config['wh']:
        heads['wh'] = 2

    criterion = OrderedDict()
    for head in heads.keys():
        criterion[head] = losses.__dict__[config[head + '_loss']]().cuda()

    train_transform = Compose([
        transforms.ShiftScaleRotate(shift_limit=config['shift_limit'],
                                    scale_limit=0,
                                    rotate_limit=0,
                                    border_mode=cv2.BORDER_CONSTANT,
                                    value=0,
                                    p=config['shift_p'])
        if config['shift'] else NoOp(),
        OneOf([
            transforms.HueSaturationValue(hue_shift_limit=config['hue_limit'],
                                          sat_shift_limit=config['sat_limit'],
                                          val_shift_limit=config['val_limit'],
                                          p=config['hsv_p'])
            if config['hsv'] else NoOp(),
            transforms.RandomBrightness(
                limit=config['brightness_limit'],
                p=config['brightness_p'],
            ) if config['brightness'] else NoOp(),
            transforms.RandomContrast(
                limit=config['contrast_limit'],
                p=config['contrast_p'],
            ) if config['contrast'] else NoOp(),
        ],
              p=1),
        transforms.ISONoise(p=config['iso_noise_p'])
        if config['iso_noise'] else NoOp(),
        transforms.CLAHE(p=config['clahe_p']) if config['clahe'] else NoOp(),
    ],
                              keypoint_params=KeypointParams(
                                  format='xy', remove_invisible=False))
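    # KeypointParams keeps the (x, y) object centers aligned with the image
    # under ShiftScaleRotate; remove_invisible=False retains keypoints pushed
    # outside the frame instead of dropping them.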

    val_transform = None

    folds = []
    best_losses = []
    # best_scores = []

    kf = KFold(n_splits=config['n_splits'], shuffle=True, random_state=41)
    for fold, (train_idx, val_idx) in enumerate(kf.split(img_paths)):
        print('Fold [%d/%d]' % (fold + 1, config['n_splits']))

        if (config['resume'] and fold < checkpoint['fold'] - 1) or (
                not config['resume']
                and os.path.exists('models/detection/%s/model_%d.pth' %
                                   (config['name'], fold + 1))):
            log = pd.read_csv('models/detection/%s/log_%d.csv' %
                              (config['name'], fold + 1))
            best_loss = log.loc[log['val_loss'].values.argmin(), 'val_loss']
            # best_loss, best_score = log.loc[log['val_loss'].values.argmin(), ['val_loss', 'val_score']].values
            folds.append(str(fold + 1))
            best_losses.append(best_loss)
            # best_scores.append(best_score)
            continue

        train_img_paths, val_img_paths = img_paths[train_idx], img_paths[
            val_idx]
        train_mask_paths, val_mask_paths = mask_paths[train_idx], mask_paths[
            val_idx]
        train_labels, val_labels = labels[train_idx], labels[val_idx]

        if config['pseudo_label'] is not None:
            train_img_paths = np.hstack((train_img_paths, test_img_paths))
            train_mask_paths = np.hstack((train_mask_paths, test_mask_paths))
            train_labels = np.hstack((train_labels, test_labels))

        # train
        train_set = Dataset(
            train_img_paths,
            train_mask_paths,
            train_labels,
            input_w=config['input_w'],
            input_h=config['input_h'],
            transform=train_transform,
            lhalf=config['lhalf'],
            hflip=config['hflip_p'] if config['hflip'] else 0,
            scale=config['scale_p'] if config['scale'] else 0,
            scale_limit=config['scale_limit'],
            # test_img_paths=test_img_paths,
            # test_mask_paths=test_mask_paths,
            # test_outputs=test_outputs,
        )
        train_loader = torch.utils.data.DataLoader(
            train_set,
            batch_size=config['batch_size'],
            shuffle=True,
            num_workers=config['num_workers'],
            # pin_memory=True,
        )

        val_set = Dataset(val_img_paths,
                          val_mask_paths,
                          val_labels,
                          input_w=config['input_w'],
                          input_h=config['input_h'],
                          transform=val_transform,
                          lhalf=config['lhalf'])
        val_loader = torch.utils.data.DataLoader(
            val_set,
            batch_size=config['batch_size'],
            shuffle=False,
            num_workers=config['num_workers'],
            # pin_memory=True,
        )

        # create model
        model = get_model(config['arch'],
                          heads=heads,
                          head_conv=config['head_conv'],
                          num_filters=config['num_filters'],
                          dcn=config['dcn'],
                          gn=config['gn'],
                          ws=config['ws'],
                          freeze_bn=config['freeze_bn'])
        model = model.cuda()

        if config['load_model'] is not None:
            model.load_state_dict(
                torch.load('models/detection/%s/model_%d.pth' %
                           (config['load_model'], fold + 1)))

        params = filter(lambda p: p.requires_grad, model.parameters())
        if config['optimizer'] == 'Adam':
            optimizer = optim.Adam(params,
                                   lr=config['lr'],
                                   weight_decay=config['weight_decay'])
        elif config['optimizer'] == 'AdamW':
            optimizer = optim.AdamW(params,
                                    lr=config['lr'],
                                    weight_decay=config['weight_decay'])
        elif config['optimizer'] == 'RAdam':
            optimizer = RAdam(params,
                              lr=config['lr'],
                              weight_decay=config['weight_decay'])
        elif config['optimizer'] == 'SGD':
            optimizer = optim.SGD(params,
                                  lr=config['lr'],
                                  momentum=config['momentum'],
                                  nesterov=config['nesterov'],
                                  weight_decay=config['weight_decay'])
        else:
            raise NotImplementedError

        if config['apex']:
            model, optimizer = amp.initialize(model, optimizer, opt_level='O1')

        if config['scheduler'] == 'CosineAnnealingLR':
            scheduler = lr_scheduler.CosineAnnealingLR(
                optimizer, T_max=config['epochs'], eta_min=config['min_lr'])
        elif config['scheduler'] == 'ReduceLROnPlateau':
            scheduler = lr_scheduler.ReduceLROnPlateau(
                optimizer,
                factor=config['factor'],
                patience=config['patience'],
                verbose=1,
                min_lr=config['min_lr'])
        elif config['scheduler'] == 'MultiStepLR':
            scheduler = lr_scheduler.MultiStepLR(
                optimizer,
                milestones=[int(e) for e in config['milestones'].split(',')],
                gamma=config['gamma'])
        else:
            raise NotImplementedError

        log = {
            'epoch': [],
            'loss': [],
            # 'score': [],
            'val_loss': [],
            # 'val_score': [],
        }

        best_loss = float('inf')
        # best_score = float('inf')

        start_epoch = 0

        if config['resume'] and fold == checkpoint['fold'] - 1:
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            scheduler.load_state_dict(checkpoint['scheduler'])
            start_epoch = checkpoint['epoch']
            log = pd.read_csv(
                'models/detection/%s/log_%d.csv' %
                (config['name'], fold + 1)).to_dict(orient='list')
            best_loss = checkpoint['best_loss']

        for epoch in range(start_epoch, config['epochs']):
            print('Epoch [%d/%d]' % (epoch + 1, config['epochs']))

            # train for one epoch
            train_loss = train(config, heads, train_loader, model, criterion,
                               optimizer, epoch)
            # evaluate on validation set
            val_loss = validate(config, heads, val_loader, model, criterion)

            if config['scheduler'] in ('CosineAnnealingLR', 'MultiStepLR'):
                scheduler.step()
            elif config['scheduler'] == 'ReduceLROnPlateau':
                scheduler.step(val_loss)

            print('loss %.4f - val_loss %.4f' % (train_loss, val_loss))
            # print('loss %.4f - score %.4f - val_loss %.4f - val_score %.4f'
            #       % (train_loss, train_score, val_loss, val_score))

            log['epoch'].append(epoch)
            log['loss'].append(train_loss)
            # log['score'].append(train_score)
            log['val_loss'].append(val_loss)
            # log['val_score'].append(val_score)

            pd.DataFrame(log).to_csv('models/detection/%s/log_%d.csv' %
                                     (config['name'], fold + 1),
                                     index=False)

            if val_loss < best_loss:
                torch.save(
                    model.state_dict(), 'models/detection/%s/model_%d.pth' %
                    (config['name'], fold + 1))
                best_loss = val_loss
                # best_score = val_score
                print("=> saved best model")

            state = {
                'fold': fold + 1,
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'best_loss': best_loss,
                'optimizer': optimizer.state_dict(),
                'scheduler': scheduler.state_dict(),
            }
            torch.save(
                state,
                'models/detection/%s/checkpoint.pth.tar' % config['name'])

        print('val_loss:  %f' % best_loss)
        # print('val_score: %f' % best_score)

        folds.append(str(fold + 1))
        best_losses.append(best_loss)
        # best_scores.append(best_score)

        results = pd.DataFrame({
            'fold':
            folds + ['mean'],
            'best_loss':
            best_losses + [np.mean(best_losses)],
            # 'best_score': best_scores + [np.mean(best_scores)],
        })

        print(results)
        results.to_csv('models/detection/%s/results.csv' % config['name'],
                       index=False)

        del model
        torch.cuda.empty_cache()

        del train_set, train_loader
        del val_set, val_loader
        gc.collect()

        if not config['cv']:
            break
# Example 9
def main():
    config = vars(parse_args())

    if config['name'] is None:
        config['name'] = '%s_%s' % (config['arch'], datetime.now().strftime('%m%d%H'))

    if not os.path.exists('models/pose/%s' % config['name']):
        os.makedirs('models/pose/%s' % config['name'])

    if config['resume']:
        with open('models/pose/%s/config.yml' % config['name'], 'r') as f:
            config = yaml.load(f, Loader=yaml.FullLoader)
        config['resume'] = True

    with open('models/pose/%s/config.yml' % config['name'], 'w') as f:
        yaml.dump(config, f)

    print('-'*20)
    for key in config.keys():
        print('- %s: %s' % (key, str(config[key])))
    print('-'*20)

    cudnn.benchmark = True

    df = pd.read_csv('inputs/train.csv')
    img_ids = df['ImageId'].values
    pose_df = pd.read_csv('processed/pose_train.csv')
    pose_df['img_path'] = 'processed/pose_images/train/' + pose_df['img_path']

    if config['resume']:
        checkpoint = torch.load('models/pose/%s/checkpoint.pth.tar' % config['name'])

    if config['rot'] == 'eular':
        num_outputs = 3
    elif config['rot'] == 'trig':
        num_outputs = 6
    elif config['rot'] == 'quat':
        num_outputs = 4
    else:
        raise NotImplementedError

    if config['loss'] == 'L1Loss':
        criterion = nn.L1Loss().cuda()
    elif config['loss'] == 'MSELoss':
        criterion = nn.MSELoss().cuda()
    else:
        raise NotImplementedError

    train_transform = Compose([
        transforms.ShiftScaleRotate(
            shift_limit=config['shift_limit'],
            scale_limit=0,
            rotate_limit=0,
            border_mode=cv2.BORDER_CONSTANT,
            value=0,
            p=config['shift_p']
        ) if config['shift'] else NoOp(),
        OneOf([
            transforms.HueSaturationValue(
                hue_shift_limit=config['hue_limit'],
                sat_shift_limit=config['sat_limit'],
                val_shift_limit=config['val_limit'],
                p=config['hsv_p']
            ) if config['hsv'] else NoOp(),
            transforms.RandomBrightness(
                limit=config['brightness_limit'],
                p=config['brightness_p'],
            ) if config['brightness'] else NoOp(),
            transforms.RandomContrast(
                limit=config['contrast_limit'],
                p=config['contrast_p'],
            ) if config['contrast'] else NoOp(),
        ], p=1),
        transforms.ISONoise(
            p=config['iso_noise_p'],
        ) if config['iso_noise'] else NoOp(),
        transforms.CLAHE(
            p=config['clahe_p'],
        ) if config['clahe'] else NoOp(),
        transforms.Resize(config['input_h'], config['input_w']),  # Resize takes (height, width)
        transforms.Normalize(),
        ToTensor(),
    ])

    val_transform = Compose([
        transforms.Resize(config['input_h'], config['input_w']),  # Resize takes (height, width)
        transforms.Normalize(),
        ToTensor(),
    ])

    folds = []
    best_losses = []

    kf = KFold(n_splits=config['n_splits'], shuffle=True, random_state=41)
    for fold, (train_idx, val_idx) in enumerate(kf.split(img_ids)):
        print('Fold [%d/%d]' %(fold + 1, config['n_splits']))

        if (config['resume'] and fold < checkpoint['fold'] - 1) or (
                not config['resume']
                and os.path.exists('models/pose/%s/model_%d.pth' %
                                   (config['name'], fold + 1))):
            log = pd.read_csv('models/pose/%s/log_%d.csv' % (config['name'], fold + 1))
            best_loss = log.loc[log['val_loss'].values.argmin(), 'val_loss']
            # best_loss, best_score = log.loc[log['val_loss'].values.argmin(), ['val_loss', 'val_score']].values
            folds.append(str(fold + 1))
            best_losses.append(best_loss)
            # best_scores.append(best_score)
            continue

        train_img_ids, val_img_ids = img_ids[train_idx], img_ids[val_idx]

        train_img_paths = []
        train_labels = []
        for img_id in train_img_ids:
            tmp = pose_df.loc[pose_df.ImageId == img_id]

            img_path = tmp['img_path'].values
            train_img_paths.append(img_path)

            yaw = tmp['yaw'].values
            pitch = tmp['pitch'].values
            roll = tmp['roll'].values
            roll = rotate(roll, np.pi)

            if config['rot'] == 'eular':
                label = np.array([
                    yaw,
                    pitch,
                    roll
                ]).T
            elif config['rot'] == 'trig':
                label = np.array([
                    np.cos(yaw),
                    np.sin(yaw),
                    np.cos(pitch),
                    np.sin(pitch),
                    np.cos(roll),
                    np.sin(roll),
                ]).T
            elif config['rot'] == 'quat':
                raise NotImplementedError
            else:
                raise NotImplementedError

            train_labels.append(label)
        train_img_paths = np.hstack(train_img_paths)
        train_labels = np.vstack(train_labels)

        val_img_paths = []
        val_labels = []
        for img_id in val_img_ids:
            tmp = pose_df.loc[pose_df.ImageId == img_id]

            img_path = tmp['img_path'].values
            val_img_paths.append(img_path)

            yaw = tmp['yaw'].values
            pitch = tmp['pitch'].values
            roll = tmp['roll'].values
            roll = rotate(roll, np.pi)

            if config['rot'] == 'eular':
                label = np.array([
                    yaw,
                    pitch,
                    roll
                ]).T
            elif config['rot'] == 'trig':
                label = np.array([
                    np.cos(yaw),
                    np.sin(yaw),
                    np.cos(pitch),
                    np.sin(pitch),
                    np.cos(roll),
                    np.sin(roll),
                ]).T
            elif config['rot'] == 'quat':
                raise NotImplementedError
            else:
                raise NotImplementedError

            val_labels.append(label)
        val_img_paths = np.hstack(val_img_paths)
        val_labels = np.vstack(val_labels)

        # train
        train_set = PoseDataset(
            train_img_paths,
            train_labels,
            transform=train_transform,
        )
        train_loader = torch.utils.data.DataLoader(
            train_set,
            batch_size=config['batch_size'],
            shuffle=True,
            num_workers=config['num_workers'],
            # pin_memory=True,
        )

        val_set = PoseDataset(
            val_img_paths,
            val_labels,
            transform=val_transform,
        )
        val_loader = torch.utils.data.DataLoader(
            val_set,
            batch_size=config['batch_size'],
            shuffle=False,
            num_workers=config['num_workers'],
            # pin_memory=True,
        )

        # create model
        model = get_pose_model(config['arch'],
                               num_outputs=num_outputs,
                               freeze_bn=config['freeze_bn'])
        model = model.cuda()

        params = filter(lambda p: p.requires_grad, model.parameters())
        if config['optimizer'] == 'Adam':
            optimizer = optim.Adam(params, lr=config['lr'], weight_decay=config['weight_decay'])
        elif config['optimizer'] == 'AdamW':
            optimizer = optim.AdamW(params, lr=config['lr'], weight_decay=config['weight_decay'])
        elif config['optimizer'] == 'RAdam':
            optimizer = RAdam(params, lr=config['lr'], weight_decay=config['weight_decay'])
        elif config['optimizer'] == 'SGD':
            optimizer = optim.SGD(params, lr=config['lr'], momentum=config['momentum'],
                                  nesterov=config['nesterov'], weight_decay=config['weight_decay'])
        else:
            raise NotImplementedError

        if config['scheduler'] == 'CosineAnnealingLR':
            scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=config['epochs'], eta_min=config['min_lr'])
        elif config['scheduler'] == 'ReduceLROnPlateau':
            scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, factor=config['factor'], patience=config['patience'],
                                                       verbose=1, min_lr=config['min_lr'])
        elif config['scheduler'] == 'MultiStepLR':
            scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=[int(e) for e in config['milestones'].split(',')], gamma=config['gamma'])
        else:
            raise NotImplementedError

        log = {
            'epoch': [],
            'loss': [],
            # 'score': [],
            'val_loss': [],
            # 'val_score': [],
        }

        best_loss = float('inf')
        # best_score = float('inf')

        start_epoch = 0

        if config['resume'] and fold == checkpoint['fold'] - 1:
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            scheduler.load_state_dict(checkpoint['scheduler'])
            start_epoch = checkpoint['epoch']
            log = pd.read_csv('models/pose/%s/log_%d.csv' % (config['name'], fold+1)).to_dict(orient='list')
            best_loss = checkpoint['best_loss']

        for epoch in range(start_epoch, config['epochs']):
            print('Epoch [%d/%d]' % (epoch + 1, config['epochs']))

            # train for one epoch
            train_loss = train(config, train_loader, model, criterion, optimizer, epoch)
            # evaluate on validation set
            val_loss = validate(config, val_loader, model, criterion)

            if config['scheduler'] in ('CosineAnnealingLR', 'MultiStepLR'):
                scheduler.step()
            elif config['scheduler'] == 'ReduceLROnPlateau':
                scheduler.step(val_loss)

            print('loss %.4f - val_loss %.4f' % (train_loss, val_loss))
            # print('loss %.4f - score %.4f - val_loss %.4f - val_score %.4f'
            #       % (train_loss, train_score, val_loss, val_score))

            log['epoch'].append(epoch)
            log['loss'].append(train_loss)
            # log['score'].append(train_score)
            log['val_loss'].append(val_loss)
            # log['val_score'].append(val_score)

            pd.DataFrame(log).to_csv('models/pose/%s/log_%d.csv' % (config['name'], fold+1), index=False)

            if val_loss < best_loss:
                torch.save(model.state_dict(), 'models/pose/%s/model_%d.pth' % (config['name'], fold+1))
                best_loss = val_loss
                # best_score = val_score
                print("=> saved best model")

            state = {
                'fold': fold + 1,
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'best_loss': best_loss,
                'optimizer': optimizer.state_dict(),
                'scheduler': scheduler.state_dict(),
            }
            torch.save(state, 'models/pose/%s/checkpoint.pth.tar' % config['name'])

        print('val_loss:  %f' % best_loss)
        # print('val_score: %f' % best_score)

        folds.append(str(fold + 1))
        best_losses.append(best_loss)
        # best_scores.append(best_score)

        results = pd.DataFrame({
            'fold': folds + ['mean'],
            'best_loss': best_losses + [np.mean(best_losses)],
            # 'best_score': best_scores + [np.mean(best_scores)],
        })

        print(results)
        results.to_csv('models/pose/%s/results.csv' % config['name'], index=False)

        del model
        torch.cuda.empty_cache()

        del train_set, train_loader
        del val_set, val_loader
        gc.collect()

        if not config['cv']:
            break
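
# Decoding sketch for the 'trig' rotation target used above (an assumption
# about the downstream inference code): the model emits (cos, sin) pairs per
# angle, which are recovered with atan2; the np.pi offset applied to roll via
# rotate() would be undone in the same place.
import numpy as np

def decode_trig(pred):
    """pred: (N, 6) array ordered cos/sin of yaw, pitch, roll -> (N, 3) radians."""
    yaw = np.arctan2(pred[:, 1], pred[:, 0])
    pitch = np.arctan2(pred[:, 3], pred[:, 2])
    roll = np.arctan2(pred[:, 5], pred[:, 4])
    return np.stack([yaw, pitch, roll], axis=1)
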
# Example 10
    def get_transforms(phase: str, cli_args) -> Dict[str, Compose]:
        """Get composed albumentations augmentations

        Parameters
        ----------
        phase : str
            Phase of learning
            In ['train', 'val']
        cli_args
            Arguments coming all the way from `main.py`

        Returns
        -------
        transforms: dict[str, albumentations.core.composition.Compose]
            Composed list of transforms
        """
        aug_transforms = []
        im_sz = (cli_args.image_size, cli_args.image_size)

        if phase == "train":
            # Data augmentation for training only
            aug_transforms.extend([
                tf.ShiftScaleRotate(
                    shift_limit=0,
                    scale_limit=0.1,
                    rotate_limit=15,
                    p=0.5),
                tf.Flip(p=0.5),
                tf.RandomRotate90(p=0.5),
            ])
            # Exotic Augmentations for train only 🤤
            aug_transforms.extend([
                tf.RandomBrightnessContrast(p=0.5),
                tf.ElasticTransform(p=0.5),
                tf.MultiplicativeNoise(multiplier=(0.5, 1.5),
                                       per_channel=True, p=0.2),
            ])
        aug_transforms.extend([
            tf.RandomSizedCrop(min_max_height=im_sz,
                               height=im_sz[0],
                               width=im_sz[1],
                               w2h_ratio=1.0,
                               interpolation=cv2.INTER_LINEAR,
                               p=1.0),
        ])
        aug_transforms = Compose(aug_transforms)

        mask_only_transforms = Compose([
            tf.Normalize(mean=0, std=1, always_apply=True)
        ])
        image_only_transforms = Compose([
            tf.Normalize(mean=(0.0, 0.0, 0.0), std=(1.0, 1.0, 1.0),
                         always_apply=True)
        ])
        final_transforms = Compose([
            ToTensorV2()
        ])

        transforms = {
            'aug': aug_transforms,
            'img_only': image_only_transforms,
            'mask_only': mask_only_transforms,
            'final': final_transforms
        }
        return transforms
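
# Sketch of how the returned dict might be applied (an assumption about the
# original dataset class): the spatial 'aug' stage runs on image and mask
# together so they stay aligned, the per-target normalizations follow, and
# 'final' converts each to a tensor.
def apply_transforms(t, image, mask):
    pair = t['aug'](image=image, mask=mask)
    image = t['img_only'](image=pair['image'])['image']
    mask = t['mask_only'](image=pair['mask'])['image']
    image = t['final'](image=image)['image']
    mask = t['final'](image=mask)['image']
    return image, mask
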
# Example 11
def main_func(train_idx, val_set, test_set, modelName, fileName):
    config = vars(parse_args())
    config['name'] = modelName
    fw = open('batch_results_train/' + fileName, 'w')
    print('config of dataset is ' + str(config['dataset']))
    fw.write('config of dataset is ' + str(config['dataset']) + '\n')
    if config['name'] is None:
        if config['deep_supervision']:
            config['name'] = '%s_%s_wDS' % (config['dataset'], config['arch'])
        else:
            config['name'] = '%s_%s_woDS' % (config['dataset'], config['arch'])
    os.makedirs('models/%s' % config['name'], exist_ok=True)

    print('-' * 20)
    fw.write('-' * 20 + '\n')
    for key in config:
        print('%s: %s' % (key, config[key]))
        fw.write('%s: %s' % (key, config[key]) + '\n')
    print('-' * 20)
    fw.write('-' * 20 + '\n')
    # TODO: print parameters manually; move all imports into the function

    with open('models/%s/config.yml' % config['name'], 'w') as f:
        yaml.dump(config, f)

    # define loss function (criterion)
    if config['loss'] == 'BCEWithLogitsLoss':
        criterion = nn.BCEWithLogitsLoss().cuda()
    else:
        criterion = losses.__dict__[config['loss']]().cuda()

    cudnn.benchmark = True

    # create model
    print("=> creating model %s" % config['arch'])
    fw.write("=> creating model %s" % config['arch'] + '\n')
    model = archs.__dict__[config['arch']](config['num_classes'],
                                           config['input_channels'],
                                           config['deep_supervision'])

    model = model.cuda()

    params = filter(lambda p: p.requires_grad, model.parameters())
    if config['optimizer'] == 'Adam':
        optimizer = optim.Adam(params,
                               lr=config['lr'],
                               weight_decay=config['weight_decay'])
    elif config['optimizer'] == 'SGD':
        optimizer = optim.SGD(params,
                              lr=config['lr'],
                              momentum=config['momentum'],
                              nesterov=config['nesterov'],
                              weight_decay=config['weight_decay'])
    else:
        raise NotImplementedError

    if config['scheduler'] == 'CosineAnnealingLR':
        scheduler = lr_scheduler.CosineAnnealingLR(optimizer,
                                                   T_max=config['epochs'],
                                                   eta_min=config['min_lr'])
    elif config['scheduler'] == 'ReduceLROnPlateau':
        scheduler = lr_scheduler.ReduceLROnPlateau(optimizer,
                                                   factor=config['factor'],
                                                   patience=config['patience'],
                                                   verbose=1,
                                                   min_lr=config['min_lr'])
    elif config['scheduler'] == 'MultiStepLR':
        scheduler = lr_scheduler.MultiStepLR(
            optimizer,
            milestones=[int(e) for e in config['milestones'].split(',')],
            gamma=config['gamma'])
    elif config['scheduler'] == 'ConstantLR':
        scheduler = None
    else:
        raise NotImplementedError

    # Data loading code
    img_ids = glob(
        os.path.join('inputs', config['dataset'], 'images',
                     '*' + config['img_ext']))
    img_ids = [os.path.splitext(os.path.basename(p))[0] for p in img_ids]

    # train_img_ids, val_img_ids = train_test_split(img_ids, test_size=0.2, random_state=41)
    val_idx = [val_set, val_set + 1]
    #train_idx = [2, 3, 6, 7]
    val_img_ids = []
    train_img_ids = []

    for image in img_ids:
        im_begin = image.split('.')[0]
        if int(im_begin[-1]) in val_idx:
            val_img_ids.append(image)
        elif int(im_begin[-1]) in train_idx:
            train_img_ids.append(image)
    #print("train img ids size is " + str(len(train_img_ids)))
    '''train_transform = Compose([
        transforms.RandomRotate90(),
        transforms.Flip(),
        OneOf([
            transforms.HueSaturationValue(),
            transforms.RandomBrightness(),
            transforms.RandomContrast(),
        ], p=1),
        transforms.Resize(config['input_h'], config['input_w']),
        transforms.Normalize(),
    ])

    val_transform = Compose([
        transforms.Resize(config['input_h'], config['input_w']),
        transforms.Normalize(),
    ])
    '''
    train_transform = Compose([
        #transforms.RandomRotate90(),
        #transforms.Flip(),
        #OneOf([
        #    transforms.HueSaturationValue(),
        #    transforms.RandomBrightness(),
        #    transforms.RandomContrast(),
        #], p=1),
        transforms.Resize(config['input_h'], config['input_w']),
        transforms.Normalize(),
    ])
    train_transform2 = Compose([
        transforms.Resize(config['input_h'], config['input_w']),
        transforms.Normalize(),
        transforms.ShiftScaleRotate(
            shift_limit=0.1, scale_limit=0, rotate_limit=0
        ),  # shift_limit_x = 0.1, shift_limit_y = 0.1, p = 1),
    ])

    val_transform2 = Compose([
        transforms.Resize(config['input_h'], config['input_w']),
        transforms.Normalize(),
        #transforms.RandomAffine(degrees = 0, translate = (10, 10)),
        transforms.ShiftScaleRotate(
            shift_limit=0.1, scale_limit=0, rotate_limit=0
        ),  # shift_limit_x = 0.1, shift_limit_y = 0.1, p = 1), ##TODO remove from validation
    ])

    val_transform = Compose([
        transforms.Resize(config['input_h'], config['input_w']),
        transforms.Normalize(),
    ])

    train_dataset = Dataset(img_ids=train_img_ids,
                            img_dir=os.path.join('inputs', config['dataset'],
                                                 'images'),
                            mask_dir=os.path.join('inputs', config['dataset'],
                                                  'masks'),
                            img_ext=config['img_ext'],
                            mask_ext=config['mask_ext'],
                            num_classes=config['num_classes'],
                            transform=train_transform2)
    val_dataset = Dataset(img_ids=val_img_ids,
                          img_dir=os.path.join('inputs', config['dataset'],
                                               'images'),
                          mask_dir=os.path.join('inputs', config['dataset'],
                                                'masks'),
                          img_ext=config['img_ext'],
                          mask_ext=config['mask_ext'],
                          num_classes=config['num_classes'],
                          transform=val_transform2)

    #print("length of train dataset is " + str(len(train_dataset)))
    #print("length of val dataset is " + str(len(val_dataset)))

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=config['batch_size'],
        shuffle=True,
        num_workers=config['num_workers'],
        drop_last=True)
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=config['batch_size'],
                                             shuffle=False,
                                             num_workers=config['num_workers'],
                                             drop_last=False)

    log = OrderedDict([
        ('epoch', []),
        ('lr', []),
        ('loss', []),
        ('iou', []),
        ('val_loss', []),
        ('val_iou', []),
        ('dice', []),
    ])

    best_iou = 0
    trigger = 0
    best_dice = 0
    for epoch in range(config['epochs']):
        print('Epoch [%d/%d]' % (epoch, config['epochs']))
        fw.write('Epoch [%d/%d]' % (epoch, config['epochs']) + '\n')

        # train for one epoch
        train_log = train(config, train_loader, model, criterion, optimizer)
        # evaluate on validation set
        val_log = validate(config, val_loader, model, criterion)

        if config['scheduler'] in ('CosineAnnealingLR', 'MultiStepLR'):
            scheduler.step()
        elif config['scheduler'] == 'ReduceLROnPlateau':
            scheduler.step(val_log['loss'])

        print('loss %.4f - iou %.4f - val_loss %.4f - val_iou %.4f dice %.4f' %
              (train_log['loss'], train_log['iou'], val_log['loss'],
               val_log['iou'], val_log['dice']))
        fw.write(
            'loss %.4f - iou %.4f - val_loss %.4f - val_iou %.4f dice %.4f' %
            (train_log['loss'], train_log['iou'], val_log['loss'],
             val_log['iou'], val_log['dice']) + '\n')

        log['epoch'].append(epoch)
        log['lr'].append(config['lr'])
        log['loss'].append(train_log['loss'])
        log['iou'].append(train_log['iou'])
        log['val_loss'].append(val_log['loss'])
        log['val_iou'].append(val_log['iou'])
        log['dice'].append(val_log['dice'])

        pd.DataFrame(log).to_csv('models/%s/log.csv' % config['name'],
                                 index=False)

        trigger += 1
        '''
        if val_log['iou'] > best_iou:
            torch.save(model.state_dict(), 'models/%s/model.pth' %
                       config['name'])
            best_iou = val_log['iou']
            print("=> saved best model")
            trigger = 0
        '''
        if val_log['dice'] > best_dice:
            torch.save(model.state_dict(),
                       'models/%s/model.pth' % config['name'])
            best_dice = val_log['dice']
            print("=> saved best model")
            fw.write("=> saved best model" + '\n')
            trigger = 0

        # early stopping
        if config['early_stopping'] >= 0 and trigger >= config[
                'early_stopping']:
            print("=> early stopping")
            fw.write("=> early stopping" + '\n')
            break

        torch.cuda.empty_cache()
# Example 12
def main_func(train_idx, val_set, test_set, modelName, fileName):
    '''
    params: train_idx, val_set, test_set => patient ids in the train,
            validation and test sets.
            modelName, fileName => model name for the directory that stores
            models and configurations, and file name for storing results; both
            are generated from the patient indices in the train, test and
            validation sets (for later identification).
    Trains, tests and stores a new model under the corresponding modelName and
    fileName. Returns nothing.
    '''
    
    # Read configurations and create model directory
    config = vars(parse_args())
    config['name'] = modelName
    fw = open('batch_results_train/'+ fileName, 'w')
    print('config of dataset is ' + str(config['dataset']))
    fw.write('config of dataset is ' + str(config['dataset']) + '\n')    
    if config['name'] is None:
        if config['deep_supervision']:
            config['name'] = '%s_%s_wDS' % (config['dataset'], config['arch'])
        else:
            config['name'] = '%s_%s_woDS' % (config['dataset'], config['arch'])
    os.makedirs('models/%s' % config['name'], exist_ok=True)

    print('-' * 20)
    fw.write('-' * 20 + '\n')
    for key in config:
        print('%s: %s' % (key, config[key]))
        fw.write('%s: %s' % (key, config[key]) + '\n')
    print('-' * 20)
    fw.write('-' * 20 + '\n')

    with open('models/%s/config.yml' % config['name'], 'w') as f:
        yaml.dump(config, f)

    # define loss function (criterion)
    if config['loss'] == 'BCEWithLogitsLoss':
        criterion = nn.BCEWithLogitsLoss().cuda()
    else:
        criterion = losses.__dict__[config['loss']]().cuda()

    cudnn.benchmark = True

    # create model
    print("=> creating model %s" % config['arch'])
    fw.write("=> creating model %s" % config['arch'] + '\n')   
    model = archs.__dict__[config['arch']](config['num_classes'],
                                           config['input_channels'],
                                           config['deep_supervision'])

    model = model.cuda()

    params = filter(lambda p: p.requires_grad, model.parameters())
    if config['optimizer'] == 'Adam':
        optimizer = optim.Adam(
            params, lr=config['lr'], weight_decay=config['weight_decay'])
    elif config['optimizer'] == 'SGD':
        optimizer = optim.SGD(params, lr=config['lr'], momentum=config['momentum'],
                              nesterov=config['nesterov'], weight_decay=config['weight_decay'])
    else:
        raise NotImplementedError

    if config['scheduler'] == 'CosineAnnealingLR':
        scheduler = lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=config['epochs'], eta_min=config['min_lr'])
    elif config['scheduler'] == 'ReduceLROnPlateau':
        scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, factor=config['factor'], patience=config['patience'],
                                                   verbose=1, min_lr=config['min_lr'])
    elif config['scheduler'] == 'MultiStepLR':
        scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=[int(e) for e in config['milestones'].split(',')], gamma=config['gamma'])
    elif config['scheduler'] == 'ConstantLR':
        scheduler = None
    else:
        raise NotImplementedError

    # Data loading code
    img_ids = glob(os.path.join('inputs', config['dataset'], 'images', '*' + config['img_ext']))
    img_ids = [os.path.splitext(os.path.basename(p))[0] for p in img_ids]

    # Patient IDs in validation set
    val_idx = [val_set]

    # Build lists of image ids for the train and validation sets: iterate over
    # all images and assign each to train or validation according to the
    # patient id encoded at the end of the image name.
    val_img_ids = []
    train_img_ids = []

    for image in img_ids:
        im_begin = image.split('.')[0]
        if int(im_begin[-1]) in val_idx:
            val_img_ids.append(image)
        elif int(im_begin[-1]) in train_idx:
            train_img_ids.append(image)

    # Transformations that can be applied to the images.
    # Note: the same transformation must be applied to both the train and
    # validation sets.
    train_transform = Compose([
        transforms.Resize(config['input_h'], config['input_w']),
        transforms.Normalize(),
    ])
    train_transform2 = Compose([
        transforms.Resize(config['input_h'], config['input_w']),
        transforms.Normalize(),
    ])
    
    val_transform2 = Compose([
        transforms.Resize(config['input_h'], config['input_w']),
        transforms.Normalize(),
        transforms.ShiftScaleRotate(shift_limit = 0.1, scale_limit = 0, rotate_limit = 0),# shift_limit_x = 0.1, shift_limit_y = 0.1, p = 1), ##TODO remove from validation
    ])

    val_transform = Compose([
        transforms.Resize(config['input_h'], config['input_w']),
        transforms.Normalize(),
    ])
    
    # Creating PyTorch dataset object.
    train_dataset = Dataset(
        img_ids=train_img_ids,
        img_dir=os.path.join('inputs', config['dataset'], 'images'),
        mask_dir=os.path.join('inputs', config['dataset'], 'masks'),
        img_ext=config['img_ext'],
        mask_ext=config['mask_ext'],
        num_classes=config['num_classes'],
        transform=train_transform2)
    val_dataset = Dataset(
        img_ids=val_img_ids,
        img_dir=os.path.join('inputs', config['dataset'], 'images'),
        mask_dir=os.path.join('inputs', config['dataset'], 'masks'),
        img_ext=config['img_ext'],
        mask_ext=config['mask_ext'],
        num_classes=config['num_classes'],
        transform=val_transform)
    
    # creating the pytorch dataloader for train and validation sets.
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=config['batch_size'],
        shuffle=True,
        num_workers=config['num_workers'],
        drop_last=True)
    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=config['batch_size'],
        shuffle=False,
        num_workers=config['num_workers'],
        drop_last=False)

    # Results dictionary
    log = OrderedDict([
        ('epoch', []),
        ('lr', []),
        ('loss', []),
        ('iou', []),
        ('val_loss', []),
        ('val_iou', []),
        ('dice', []),
    ])

    best_iou = 0
    trigger = 0
    best_dice = 0
    for epoch in range(config['epochs']):
        print('Epoch [%d/%d]' % (epoch, config['epochs']))
        fw.write('Epoch [%d/%d]' % (epoch, config['epochs']) + '\n')    

        # train for one epoch
        train_log = train(config, train_loader, model, criterion, optimizer)
        # evaluate on validation set
        val_log = validate(config, val_loader, model, criterion)

        if config['scheduler'] in ('CosineAnnealingLR', 'MultiStepLR'):
            scheduler.step()
        elif config['scheduler'] == 'ReduceLROnPlateau':
            scheduler.step(val_log['loss'])

        print('loss %.4f - iou %.4f - val_loss %.4f - val_iou %.4f dice %.4f'
              % (train_log['loss'], train_log['iou'], val_log['loss'], val_log['iou'], val_log['dice']))
        fw.write('loss %.4f - iou %.4f - val_loss %.4f - val_iou %.4f dice %.4f'
              % (train_log['loss'], train_log['iou'], val_log['loss'], val_log['iou'], val_log['dice']) + '\n')

        # Appending result to log dictionary
        log['epoch'].append(epoch)
        log['lr'].append(config['lr'])
        log['loss'].append(train_log['loss'])
        log['iou'].append(train_log['iou'])
        log['val_loss'].append(val_log['loss'])
        log['val_iou'].append(val_log['iou'])
        log['dice'].append(val_log['dice'])

        pd.DataFrame(log).to_csv('models/%s/log.csv' %
                                 config['name'], index=False)

        trigger += 1

        # Determine if new updated model gives best performance and accordingly save.
        # Multiple ways to determine better performance, dice score is used here, can also use IoU.
        if val_log['dice'] > best_dice:
            torch.save(model.state_dict(), 'models/%s/model.pth' %
                       config['name'])
            best_dice = val_log['dice']
            print("=> saved best model")
            fw.write("=> saved best model" + '\n')
            trigger = 0

        '''
        # can be used if best model picked using IOU
        if val_log['iou'] > best_iou:
            torch.save(model.state_dict(), 'models/%s/model.pth' %
                       config['name'])
            best_iou = val_log['iou']
            print("=> saved best model")
            trigger = 0
        '''

        # early stopping
        if config['early_stopping'] >= 0 and trigger >= config['early_stopping']:
            print("=> early stopping")
            fw.write("=> early stopping" + '\n')
            break

        torch.cuda.empty_cache()