def augmentation(mode, target_size, prob=0.5, aug_m=2):
    """Build an albumentations pipeline for training or evaluation.

    Args:
        mode: 'train' returns the randomized training pipeline; any other value
            (e.g. 'test') returns the deterministic resize + normalize pipeline.
        target_size: final image size, either a 2-element (height, width)
            sequence or a single int meaning a square image.
        prob: probability used for the main augmentation groups; the noisier
            groups (optical distortion, blur/sharpen/noise) use ``prob / 2``.
        aug_m: strength multiplier applied to the augmentation limits.

    Returns:
        composition.Compose with p=1 that ends with Normalize(mean=0.5,
        std=0.5 per channel, max_pixel_value=255).
    """
    # An int was always documented as valid but previously crashed when the
    # code iterated over it; expand it to a square (height, width).
    if isinstance(target_size, int):
        target_size = (target_size, target_size)

    high_p = prob
    low_p = high_p / 2.0
    M = aug_m

    if mode != 'train':
        return composition.Compose([
            # interpolation=3 is cv2.INTER_AREA
            transforms.Resize(target_size[0], target_size[1], interpolation=3),
            transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
                                 max_pixel_value=255.0),
        ], p=1)

    # Resize up to target / 0.7 first; presumably this gives the geometric
    # augmentations some margin before the final resize back to target_size.
    first_size = [int(x / 0.7) for x in target_size]

    return composition.Compose([
        transforms.Resize(first_size[0], first_size[1], interpolation=3),
        transforms.Flip(p=0.5),
        # one geometric distortion, chosen at random
        composition.OneOf([
            RandomCenterCut(scale=0.1 * M),
            transforms.ShiftScaleRotate(shift_limit=0.05 * M, scale_limit=0.1 * M,
                                        rotate_limit=180,
                                        border_mode=cv2.BORDER_CONSTANT, value=0),
            albumentations.imgaug.transforms.IAAAffine(
                shear=(-10 * M, 10 * M), mode='constant'),
        ], p=high_p),
        transforms.RandomBrightnessContrast(
            brightness_limit=0.1 * M, contrast_limit=0.03 * M, p=high_p),
        transforms.HueSaturationValue(hue_shift_limit=5 * M, sat_shift_limit=15 * M,
                                      val_shift_limit=10 * M, p=high_p),
        transforms.OpticalDistortion(distort_limit=0.03 * M, shift_limit=0,
                                     border_mode=cv2.BORDER_CONSTANT, value=0,
                                     p=low_p),
        # one blur / sharpen / noise corruption, chosen at random
        composition.OneOf([
            transforms.Blur(blur_limit=7),
            albumentations.imgaug.transforms.IAASharpen(),
            transforms.GaussNoise(var_limit=(2.0, 10.0), mean=0),
            transforms.ISONoise(),
        ], p=low_p),
        transforms.Resize(target_size[0], target_size[1], interpolation=3),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
                             max_pixel_value=255.0),
    ], p=1)
def ben_augmentation():
    """Return the augmentation pipeline normalized with BEN_BAND_STATS.

    Steps, in order:
      * horizontal flip with probability 0.5,
      * a random 90-degree rotation (always),
      * ShiftScaleRotate with shift and scale disabled, so only its rotation
        (albumentations' default rotate_limit) applies (always),
      * a random crop with height in [60, 120] resized to 128x128
        (interpolation=3 is cv2.INTER_AREA),
      * per-band normalization with BEN_BAND_STATS['mean'] / ['std'],
      * conversion to a tensor.
    """
    # TODO: probably needs some colour augmentation as well.
    steps = [
        atransforms.HorizontalFlip(p=0.5),
        atransforms.RandomRotate90(p=1.0),
        atransforms.ShiftScaleRotate(shift_limit=0, scale_limit=0, p=1.0),
        atransforms.RandomSizedCrop((60, 120), height=128, width=128, interpolation=3),
        # Disabled (maybe too much noise?):
        # atransforms.GridDistortion(num_steps=5, p=0.5),
        atransforms.Normalize(mean=BEN_BAND_STATS['mean'], std=BEN_BAND_STATS['std']),
        # Disabled:
        # atransforms.ChannelDropout(channel_drop_range=(1, 2), p=0.5),
        AToTensor(),
    ]
    return ACompose(steps)
def get_augmentations():
    """Return named augmentation transforms grouped as "major" and "minor".

    Every transform is constructed with ``always_apply=True``, so each one
    fires unconditionally whenever it is invoked.

    Returns:
        dict with two keys:
          "major": name -> geometric transform (shift-scale-rot, crop, distort)
          "minor": name -> photometric/flip transform (blur, noise,
                   bright-contrast, hsv, rgb, flip)
    """
    major_augs = {
        "shift-scale-rot": trans.ShiftScaleRotate(
            shift_limit=0.05,
            rotate_limit=35,
            border_mode=cv2.BORDER_REPLICATE,
            always_apply=True,
        ),
        "crop": trans.RandomResizedCrop(
            100, 100, scale=(0.8, 0.95), ratio=(0.8, 1.2), always_apply=True
        ),
        # Disabled:
        # "elastic": trans.ElasticTransform(alpha=0.8, alpha_affine=10, sigma=40,
        #                                   border_mode=cv2.BORDER_REPLICATE,
        #                                   always_apply=True),
        "distort": trans.OpticalDistortion(0.2, always_apply=True),
    }

    minor_augs = {
        "blur": trans.GaussianBlur(7, always_apply=True),
        "noise": trans.GaussNoise((20.0, 40.0), always_apply=True),
        "bright-contrast": trans.RandomBrightnessContrast(0.4, 0.4, always_apply=True),
        "hsv": trans.HueSaturationValue(30, 40, 50, always_apply=True),
        "rgb": trans.RGBShift(always_apply=True),
        "flip": trans.HorizontalFlip(always_apply=True),
    }

    return {"major": major_augs, "minor": minor_augs}
def train_dataloader(self):
    """Return the shuffled training DataLoader over ``self.train_df``.

    Each image is augmented with a random resized crop to
    ``self.hparams.sz`` (scale 0.7-1.0), grid distortion, brightness/contrast,
    shift/scale/rotate, a random flip and coarse dropout (holes up to sz/10).
    It is then normalized with the standard ImageNet per-channel mean/std and
    converted to a tensor. The batch size is ``self.hparams.bs``.
    """
    sz = self.hparams.sz
    augmentations = Compose(
        [
            A.RandomResizedCrop(height=sz, width=sz, scale=(0.7, 1.0)),
            # AdvancedHairAugmentation(),  # disabled
            A.GridDistortion(),
            A.RandomBrightnessContrast(),
            A.ShiftScaleRotate(),
            A.Flip(p=0.5),
            A.CoarseDropout(max_height=int(sz / 10), max_width=int(sz / 10)),
            # A.HueSaturationValue(),  # disabled
            A.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225],
                max_pixel_value=255,
            ),
            ToTensorV2(),
        ]
    )

    train_ds = MelanomaDataset(
        df=self.train_df,
        images_path=self.train_images_path,
        augmentations=augmentations,
        train_or_valid=True,
    )

    return DataLoader(
        train_ds,
        # sampler=sampler,  # disabled; shuffle=True is used instead
        batch_size=self.hparams.bs,
        shuffle=True,
        # os.cpu_count() returns None when the count is undetermined, and
        # DataLoader requires an int (0 = load in the main process).
        num_workers=os.cpu_count() or 0,
        pin_memory=True,
    )
def get_tta_transforms():
    """Return the random pipeline used for test-time augmentation.

    Sizes come from the module-level ``hparams.sz``. The steps are a random
    resized crop (scale 0.7-1.0), grid distortion, brightness/contrast,
    shift/scale/rotate, a random flip, coarse dropout with holes up to sz/10,
    normalization with the standard ImageNet mean/std, and tensor conversion.
    """
    size = hparams.sz
    hole = int(size / 10)

    steps = [
        A.RandomResizedCrop(height=size, width=size, scale=(0.7, 1.0)),
        # AdvancedHairAugmentation(),  # disabled
        A.GridDistortion(),
        A.RandomBrightnessContrast(),
        A.ShiftScaleRotate(),
        A.Flip(p=0.5),
        A.CoarseDropout(max_height=hole, max_width=hole),
        # A.HueSaturationValue(),  # disabled
        A.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
            max_pixel_value=255,
        ),
        ToTensorV2(),
    ]
    return Compose(steps)
def create_transform(self, args, is_train):
    """Build the albumentations pipeline for one dataset split.

    Every pipeline starts with normalization using ``self.cifar_10_mean`` /
    ``self.cifar_10_std`` with ``max_pixel_value=1.0`` (so inputs are
    presumably already scaled to [0, 1]) and ends with ToTensorV2.

    When ``is_train`` is true and ``args.use_aug == 1``, these are inserted
    after normalization:
      * a horizontal flip (p=0.5),
      * a scale (+/-25%) and rotation (+/-30 degrees) with no shift (always),
      * one 12x12 CoarseDropout hole (p=0.5).

    Args:
        args: parsed options; only ``args.use_aug`` is read.
        is_train: True for the training split; False for validation/test,
            which get normalization and tensor conversion only.

    Returns:
        albumentations Compose.
    """
    steps = [
        trans.Normalize(mean=self.cifar_10_mean, std=self.cifar_10_std,
                        max_pixel_value=1.0),
    ]

    augment = is_train and args.use_aug == 1
    if augment:
        steps.append(trans.HorizontalFlip(p=0.5))
        steps.append(trans.ShiftScaleRotate(shift_limit=0, scale_limit=0.25,
                                            rotate_limit=30, p=1))
        steps.append(trans.CoarseDropout(max_holes=1, min_holes=1,
                                         min_width=12, min_height=12,
                                         max_height=12, max_width=12, p=0.5))

    steps.append(ToTensorV2())
    return A.Compose(steps)
def main():
    """Train the 6-output stage-1 classifier on all (optionally filtered) patients.

    Steps:
      1. Parse CLI options with ``parse_args()`` and persist them to
         ``models/<name>/args.txt`` and ``args.pkl``.
      2. Build the loss, model, augmentation pipeline, data loader, optimizer
         and LR scheduler.
      3. Train for ``args.epochs`` epochs. After each epoch, write
         ``log.csv``, ``model.pth`` and a resumable ``checkpoint.pth.tar``
         into ``models/<name>/``.

    With ``args.resume`` the pickled args of the earlier run are reloaded and
    training continues from ``checkpoint.pth.tar``.

    Raises:
        NotImplementedError: for an unsupported loss, pred_type, optimizer or
            scheduler.
    """
    args = parse_args()

    if args.name is None:
        args.name = '%s_%s' % (args.arch, datetime.now().strftime('%m%d%H'))

    os.makedirs('models/%s' % args.name, exist_ok=True)

    if args.resume:
        # Resume with the exact options of the original run. joblib.load
        # unpickles, so only resume from trusted model directories.
        args = joblib.load('models/%s/args.pkl' % args.name)
        args.resume = True

    print('Config -----')
    for arg in vars(args):
        print('- %s: %s' % (arg, getattr(args, arg)))
    print('------------')

    with open('models/%s/args.txt' % args.name, 'w') as f:
        for arg in vars(args):
            print('- %s: %s' % (arg, getattr(args, arg)), file=f)

    joblib.dump(args, 'models/%s/args.pkl' % args.name)

    # Seeds are only set for fresh runs, not when resuming.
    if args.seed is not None and not args.resume:
        print('set random seed')
        random.seed(args.seed)
        np.random.seed(args.seed)
        torch.manual_seed(args.seed)

    # The "Weighted" variants pass weight 2 for the last of the six outputs
    # (presumably a per-class weight; see the loss classes).
    if args.loss == 'BCEWithLogitsLoss':
        criterion = BCEWithLogitsLoss().cuda()
    elif args.loss == 'WeightedBCEWithLogitsLoss':
        criterion = BCEWithLogitsLoss(weight=torch.Tensor([1., 1., 1., 1., 1., 2.]),
                                      smooth=args.label_smooth).cuda()
    elif args.loss == 'FocalLoss':
        criterion = FocalLoss().cuda()
    elif args.loss == 'WeightedFocalLoss':
        criterion = FocalLoss(weight=torch.Tensor([1., 1., 1., 1., 1., 2.])).cuda()
    else:
        raise NotImplementedError

    if args.pred_type == 'all':
        num_outputs = 6
    elif args.pred_type == 'except_any':
        num_outputs = 5
    else:
        raise NotImplementedError

    cudnn.benchmark = True

    # create model
    model = get_model(model_name=args.arch,
                      num_outputs=num_outputs,
                      freeze_bn=args.freeze_bn,
                      dropout_p=args.dropout_p,
                      pooling=args.pooling,
                      lp_p=args.lp_p)
    model = model.cuda()

    # Each optional augmentation is replaced by NoOp() when its flag is off.
    train_transform = Compose([
        transforms.Resize(args.img_size, args.img_size),
        transforms.HorizontalFlip() if args.hflip else NoOp(),
        transforms.VerticalFlip() if args.vflip else NoOp(),
        transforms.ShiftScaleRotate(
            shift_limit=args.shift_limit,
            scale_limit=args.scale_limit,
            rotate_limit=args.rotate_limit,
            border_mode=cv2.BORDER_CONSTANT,
            value=0,
            p=args.shift_scale_rotate_p
        ) if args.shift_scale_rotate else NoOp(),
        transforms.RandomContrast(
            limit=args.contrast_limit,
            p=args.contrast_p
        ) if args.contrast else NoOp(),
        RandomErase() if args.random_erase else NoOp(),
        transforms.CenterCrop(args.crop_size, args.crop_size) if args.center_crop else NoOp(),
        ForegroundCenterCrop(args.crop_size) if args.foreground_center_crop else NoOp(),
        transforms.RandomCrop(args.crop_size, args.crop_size) if args.random_crop else NoOp(),
        transforms.Normalize(mean=model.mean, std=model.std),
        ToTensor(),
    ])

    if args.img_type:
        stage_1_train_dir = 'processed/stage_1_train_%s' % args.img_type
    else:
        stage_1_train_dir = 'processed/stage_1_train'

    # The CSV is read as 6 consecutive rows per image (one per class). IDs are
    # '<image id>_<suffix>', and the image id is everything before the last '_'.
    df = pd.read_csv('inputs/stage_1_train.csv')
    img_paths = np.array([stage_1_train_dir + '/' + '_'.join(s.split('_')[:-1]) + '.png'
                          for s in df['ID'].values[::6]])
    labels = np.array([df.loc[c::6, 'Label'].values for c in range(6)]).T.astype('float32')
    # .copy(): the column writes below target a new frame, not a slice view
    df = df[::6].copy()
    df['img_path'] = img_paths
    for c in range(6):
        df['label_%d' % c] = labels[:, c]
    df['ID'] = df['ID'].apply(lambda s: '_'.join(s.split('_')[:-1]))

    meta_df = pd.read_csv('processed/stage_1_train_meta.csv')
    meta_df['ID'] = meta_df['SOPInstanceUID']
    test_meta_df = pd.read_csv('processed/stage_1_test_meta.csv')

    df = pd.merge(df, meta_df, how='left')

    patient_ids = meta_df['PatientID'].unique()
    test_patient_ids = test_meta_df['PatientID'].unique()
    if args.remove_test_patient_ids:
        # set lookup is O(1); scanning the numpy array was O(len) per patient
        test_patient_id_set = set(test_patient_ids)
        patient_ids = np.array([s for s in patient_ids if s not in test_patient_id_set])

    # Concatenate the per-patient arrays in patient_ids order.
    train_img_paths = np.hstack(
        df[['img_path', 'PatientID']].groupby(['PatientID'])['img_path']
        .apply(np.array).loc[patient_ids].to_list()
    ).astype('str')
    train_labels = np.array([
        np.hstack(
            df[['label_%d' % c, 'PatientID']].groupby(['PatientID'])['label_%d' % c]
            .apply(np.array).loc[patient_ids].to_list()
        )
        for c in range(6)
    ]).T

    if args.resume:
        checkpoint = torch.load('models/%s/checkpoint.pth.tar' % args.name)

    # train
    train_set = Dataset(
        train_img_paths,
        train_labels,
        transform=train_transform)
    train_loader = torch.utils.data.DataLoader(
        train_set,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
        # pin_memory=True,
    )

    params = filter(lambda p: p.requires_grad, model.parameters())
    if args.optimizer == 'Adam':
        optimizer = optim.Adam(params, lr=args.lr, weight_decay=args.weight_decay)
    elif args.optimizer == 'AdamW':
        optimizer = optim.AdamW(params, lr=args.lr, weight_decay=args.weight_decay)
    elif args.optimizer == 'RAdam':
        optimizer = RAdam(params, lr=args.lr, weight_decay=args.weight_decay)
    elif args.optimizer == 'SGD':
        optimizer = optim.SGD(params, lr=args.lr, momentum=args.momentum,
                              weight_decay=args.weight_decay, nesterov=args.nesterov)
    else:
        raise NotImplementedError

    if args.apex:
        # amp.initialize returns the (possibly wrapped) model and optimizer;
        # they must replace the originals.
        model, optimizer = amp.initialize(model, optimizer, opt_level='O1')

    if args.scheduler == 'CosineAnnealingLR':
        scheduler = lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=args.epochs, eta_min=args.min_lr)
    elif args.scheduler == 'MultiStepLR':
        scheduler = lr_scheduler.MultiStepLR(
            optimizer,
            milestones=[int(e) for e in args.milestones.split(',')],
            gamma=args.gamma)
    else:
        raise NotImplementedError

    log = {
        'epoch': [],
        'loss': [],
    }

    start_epoch = 0

    if args.resume:
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        scheduler.load_state_dict(checkpoint['scheduler'])
        start_epoch = checkpoint['epoch']
        log = pd.read_csv('models/%s/log.csv' % args.name).to_dict(orient='list')

    for epoch in range(start_epoch, args.epochs):
        print('Epoch [%d/%d]' % (epoch + 1, args.epochs))

        # train for one epoch
        train_loss = train(args, train_loader, model, criterion, optimizer, epoch)

        # Both supported schedulers advance once per epoch (MultiStepLR
        # milestones are counted in these steps), so always step.
        scheduler.step()

        print('loss %.4f' % (train_loss))

        log['epoch'].append(epoch)
        log['loss'].append(train_loss)

        pd.DataFrame(log).to_csv('models/%s/log.csv' % args.name, index=False)

        torch.save(model.state_dict(), 'models/%s/model.pth' % args.name)
        print("=> saved model")

        state = {
            'epoch': epoch + 1,
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'scheduler': scheduler.state_dict(),
        }
        torch.save(state, 'models/%s/checkpoint.pth.tar' % args.name)
def main():
    """Train the detection model with K-fold cross-validation.

    The model's heads are 'hm', 'reg', 'depth', one rotation head selected by
    ``config['rot']`` and an optional 'wh' head. Each head gets its own loss,
    looked up in ``losses`` by ``config[<head> + '_loss']``.

    Per fold, the best model (lowest val_loss) is saved to
    ``models/detection/<name>/model_<fold>.pth``. A resumable checkpoint, a
    per-fold log CSV and a running results.csv are written to the same
    directory. Folds that already finished are skipped: the ones before the
    checkpoint's fold when resuming, or ones whose model file exists otherwise.
    With ``config['cv']`` false only the first remaining fold is trained.

    Raises:
        NotImplementedError: unsupported rot/optimizer/scheduler, a pseudo-label
            file with an extension other than .csv/.pth, or a '.pth' pseudo
            label (no labels can be built from it here).
    """
    config = vars(parse_args())

    if config['name'] is None:
        config['name'] = '%s_%s' % (config['arch'], datetime.now().strftime('%m%d%H'))
    config['num_filters'] = [int(n) for n in config['num_filters'].split(',')]

    os.makedirs('models/detection/%s' % config['name'], exist_ok=True)

    if config['resume']:
        with open('models/detection/%s/config.yml' % config['name'], 'r') as f:
            config = yaml.load(f, Loader=yaml.FullLoader)
        config['resume'] = True

    with open('models/detection/%s/config.yml' % config['name'], 'w') as f:
        yaml.dump(config, f)

    print('-' * 20)
    for key in config.keys():
        print('- %s: %s' % (key, str(config[key])))
    print('-' * 20)

    cudnn.benchmark = True

    df = pd.read_csv('inputs/train.csv')
    img_paths = np.array('inputs/train_images/' + df['ImageId'].values + '.jpg')
    mask_paths = np.array('inputs/train_masks/' + df['ImageId'].values + '.jpg')
    labels = np.array([convert_str_to_labels(s) for s in df['PredictionString']])

    # Optional pseudo-labelled test images, appended to every training fold.
    test_img_paths = None
    test_mask_paths = None
    test_outputs = None  # raw '.pth' outputs; the Dataset kwargs using them are disabled
    test_labels = None
    if config['pseudo_label'] is not None:
        test_df = pd.read_csv('inputs/sample_submission.csv')
        test_img_paths = np.array('inputs/test_images/' + test_df['ImageId'].values + '.jpg')
        test_mask_paths = np.array('inputs/test_masks/' + test_df['ImageId'].values + '.jpg')
        ext = os.path.splitext(config['pseudo_label'])[1]
        if ext == '.pth':
            test_outputs = torch.load('outputs/raw/test/%s' % config['pseudo_label'])
        elif ext == '.csv':
            test_labels = pd.read_csv('outputs/submissions/test/%s' % config['pseudo_label'])
            # drop test images whose submission row has any missing value
            null_idx = test_labels.isnull().any(axis=1)
            test_img_paths = test_img_paths[~null_idx]
            test_mask_paths = test_mask_paths[~null_idx]
            test_labels = test_labels.dropna()
            test_labels = np.array([
                convert_str_to_labels(
                    s, names=['pitch', 'yaw', 'roll', 'x', 'y', 'z', 'score'])
                for s in test_labels['PredictionString']
            ])
            print(test_labels)
        else:
            raise NotImplementedError

    if config['resume']:
        checkpoint = torch.load('models/detection/%s/checkpoint.pth.tar' % config['name'])

    heads = OrderedDict([
        ('hm', 1),
        ('reg', 2),
        ('depth', 1),
    ])

    if config['rot'] == 'eular':
        heads['eular'] = 3
    elif config['rot'] == 'trig':
        heads['trig'] = 6
    elif config['rot'] == 'quat':
        heads['quat'] = 4
    else:
        raise NotImplementedError

    if config['wh']:
        heads['wh'] = 2

    criterion = OrderedDict()
    for head in heads.keys():
        criterion[head] = losses.__dict__[config[head + '_loss']]().cuda()

    # Each optional augmentation is replaced by NoOp() when its flag is off.
    train_transform = Compose([
        transforms.ShiftScaleRotate(
            shift_limit=config['shift_limit'],
            scale_limit=0,
            rotate_limit=0,
            border_mode=cv2.BORDER_CONSTANT,
            value=0,
            p=config['shift_p']
        ) if config['shift'] else NoOp(),
        OneOf([
            transforms.HueSaturationValue(
                hue_shift_limit=config['hue_limit'],
                sat_shift_limit=config['sat_limit'],
                val_shift_limit=config['val_limit'],
                p=config['hsv_p']
            ) if config['hsv'] else NoOp(),
            transforms.RandomBrightness(
                limit=config['brightness_limit'],
                p=config['brightness_p'],
            ) if config['brightness'] else NoOp(),
            transforms.RandomContrast(
                limit=config['contrast_limit'],
                p=config['contrast_p'],
            ) if config['contrast'] else NoOp(),
        ], p=1),
        transforms.ISONoise(p=config['iso_noise_p']) if config['iso_noise'] else NoOp(),
        transforms.CLAHE(p=config['clahe_p']) if config['clahe'] else NoOp(),
    ], keypoint_params=KeypointParams(format='xy', remove_invisible=False))

    val_transform = None

    folds = []
    best_losses = []
    # best_scores = []

    kf = KFold(n_splits=config['n_splits'], shuffle=True, random_state=41)
    for fold, (train_idx, val_idx) in enumerate(kf.split(img_paths)):
        print('Fold [%d/%d]' % (fold + 1, config['n_splits']))

        # Skip finished folds. The existence check must use the same directory
        # the best model is saved to below (models/detection/...).
        if (config['resume'] and fold < checkpoint['fold'] - 1) or (
                not config['resume']
                and os.path.exists('models/detection/%s/model_%d.pth'
                                   % (config['name'], fold + 1))):
            log = pd.read_csv('models/detection/%s/log_%d.csv' % (config['name'], fold + 1))
            best_loss = log.loc[log['val_loss'].values.argmin(), 'val_loss']
            # best_loss, best_score = log.loc[log['val_loss'].values.argmin(),
            #                                 ['val_loss', 'val_score']].values
            folds.append(str(fold + 1))
            best_losses.append(best_loss)
            # best_scores.append(best_score)
            continue

        train_img_paths, val_img_paths = img_paths[train_idx], img_paths[val_idx]
        train_mask_paths, val_mask_paths = mask_paths[train_idx], mask_paths[val_idx]
        train_labels, val_labels = labels[train_idx], labels[val_idx]

        if config['pseudo_label'] is not None:
            if test_labels is None:
                # Only '.csv' pseudo labels produce test_labels; failing here
                # replaces the previous NameError with a clear message.
                raise NotImplementedError(
                    'pseudo labels from .pth outputs are not supported; use a .csv submission')
            train_img_paths = np.hstack((train_img_paths, test_img_paths))
            train_mask_paths = np.hstack((train_mask_paths, test_mask_paths))
            train_labels = np.hstack((train_labels, test_labels))

        # train
        train_set = Dataset(
            train_img_paths,
            train_mask_paths,
            train_labels,
            input_w=config['input_w'],
            input_h=config['input_h'],
            transform=train_transform,
            lhalf=config['lhalf'],
            hflip=config['hflip_p'] if config['hflip'] else 0,
            scale=config['scale_p'] if config['scale'] else 0,
            scale_limit=config['scale_limit'],
            # test_img_paths=test_img_paths,
            # test_mask_paths=test_mask_paths,
            # test_outputs=test_outputs,
        )
        train_loader = torch.utils.data.DataLoader(
            train_set,
            batch_size=config['batch_size'],
            shuffle=True,
            num_workers=config['num_workers'],
            # pin_memory=True,
        )

        val_set = Dataset(
            val_img_paths,
            val_mask_paths,
            val_labels,
            input_w=config['input_w'],
            input_h=config['input_h'],
            transform=val_transform,
            lhalf=config['lhalf'])
        val_loader = torch.utils.data.DataLoader(
            val_set,
            batch_size=config['batch_size'],
            shuffle=False,
            num_workers=config['num_workers'],
            # pin_memory=True,
        )

        # create model
        model = get_model(config['arch'], heads=heads,
                          head_conv=config['head_conv'],
                          num_filters=config['num_filters'],
                          dcn=config['dcn'],
                          gn=config['gn'], ws=config['ws'],
                          freeze_bn=config['freeze_bn'])
        model = model.cuda()

        if config['load_model'] is not None:
            model.load_state_dict(torch.load(
                'models/detection/%s/model_%d.pth' % (config['load_model'], fold + 1)))

        params = filter(lambda p: p.requires_grad, model.parameters())
        if config['optimizer'] == 'Adam':
            optimizer = optim.Adam(params, lr=config['lr'],
                                   weight_decay=config['weight_decay'])
        elif config['optimizer'] == 'AdamW':
            optimizer = optim.AdamW(params, lr=config['lr'],
                                    weight_decay=config['weight_decay'])
        elif config['optimizer'] == 'RAdam':
            optimizer = RAdam(params, lr=config['lr'],
                              weight_decay=config['weight_decay'])
        elif config['optimizer'] == 'SGD':
            optimizer = optim.SGD(params, lr=config['lr'],
                                  momentum=config['momentum'],
                                  nesterov=config['nesterov'],
                                  weight_decay=config['weight_decay'])
        else:
            raise NotImplementedError

        if config['apex']:
            # amp.initialize returns the (possibly wrapped) model and optimizer;
            # they must replace the originals.
            model, optimizer = amp.initialize(model, optimizer, opt_level='O1')

        if config['scheduler'] == 'CosineAnnealingLR':
            scheduler = lr_scheduler.CosineAnnealingLR(
                optimizer, T_max=config['epochs'], eta_min=config['min_lr'])
        elif config['scheduler'] == 'ReduceLROnPlateau':
            scheduler = lr_scheduler.ReduceLROnPlateau(
                optimizer, factor=config['factor'], patience=config['patience'],
                verbose=1, min_lr=config['min_lr'])
        elif config['scheduler'] == 'MultiStepLR':
            scheduler = lr_scheduler.MultiStepLR(
                optimizer,
                milestones=[int(e) for e in config['milestones'].split(',')],
                gamma=config['gamma'])
        else:
            raise NotImplementedError

        log = {
            'epoch': [],
            'loss': [],
            # 'score': [],
            'val_loss': [],
            # 'val_score': [],
        }

        best_loss = float('inf')
        # best_score = float('inf')

        start_epoch = 0

        if config['resume'] and fold == checkpoint['fold'] - 1:
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            scheduler.load_state_dict(checkpoint['scheduler'])
            start_epoch = checkpoint['epoch']
            log = pd.read_csv('models/detection/%s/log_%d.csv'
                              % (config['name'], fold + 1)).to_dict(orient='list')
            best_loss = checkpoint['best_loss']

        for epoch in range(start_epoch, config['epochs']):
            print('Epoch [%d/%d]' % (epoch + 1, config['epochs']))

            # train for one epoch
            train_loss = train(config, heads, train_loader, model, criterion, optimizer, epoch)
            # evaluate on validation set
            val_loss = validate(config, heads, val_loader, model, criterion)

            # ReduceLROnPlateau needs the metric; the other schedulers (including
            # MultiStepLR, which was previously never stepped) advance per epoch.
            if config['scheduler'] == 'ReduceLROnPlateau':
                scheduler.step(val_loss)
            else:
                scheduler.step()

            print('loss %.4f - val_loss %.4f' % (train_loss, val_loss))
            # print('loss %.4f - score %.4f - val_loss %.4f - val_score %.4f'
            #       % (train_loss, train_score, val_loss, val_score))

            log['epoch'].append(epoch)
            log['loss'].append(train_loss)
            # log['score'].append(train_score)
            log['val_loss'].append(val_loss)
            # log['val_score'].append(val_score)

            pd.DataFrame(log).to_csv('models/detection/%s/log_%d.csv'
                                     % (config['name'], fold + 1), index=False)

            if val_loss < best_loss:
                torch.save(model.state_dict(), 'models/detection/%s/model_%d.pth'
                           % (config['name'], fold + 1))
                best_loss = val_loss
                # best_score = val_score
                print("=> saved best model")

            state = {
                'fold': fold + 1,
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'best_loss': best_loss,
                'optimizer': optimizer.state_dict(),
                'scheduler': scheduler.state_dict(),
            }
            torch.save(state, 'models/detection/%s/checkpoint.pth.tar' % config['name'])

        print('val_loss: %f' % best_loss)
        # print('val_score: %f' % best_score)

        folds.append(str(fold + 1))
        best_losses.append(best_loss)
        # best_scores.append(best_score)

        results = pd.DataFrame({
            'fold': folds + ['mean'],
            'best_loss': best_losses + [np.mean(best_losses)],
            # 'best_score': best_scores + [np.mean(best_scores)],
        })

        print(results)
        results.to_csv('models/detection/%s/results.csv' % config['name'], index=False)

        # release GPU/CPU memory before the next fold
        del model
        torch.cuda.empty_cache()
        del train_set, train_loader
        del val_set, val_loader
        gc.collect()

        if not config['cv']:
            break
def _pose_paths_and_labels(pose_df, img_ids, rot):
    """Collect crop image paths and rotation targets for the given image ids.

    Every row of ``pose_df`` whose ``ImageId`` is in ``img_ids`` contributes one
    path and one label row, in ``img_ids`` order. ``roll`` is passed through
    ``rotate(roll, np.pi)`` before encoding. ``rot`` selects the encoding:
    'eular' -> columns (yaw, pitch, roll); 'trig' -> (cos, sin) of yaw, pitch
    and roll. Any other value (including 'quat') raises NotImplementedError.

    Returns:
        (paths, labels) stacked with np.hstack / np.vstack.
    """
    img_paths = []
    labels = []
    for img_id in img_ids:
        rows = pose_df.loc[pose_df.ImageId == img_id]
        img_paths.append(rows['img_path'].values)

        yaw = rows['yaw'].values
        pitch = rows['pitch'].values
        roll = rotate(rows['roll'].values, np.pi)

        if rot == 'eular':
            label = np.array([yaw, pitch, roll]).T
        elif rot == 'trig':
            label = np.array([
                np.cos(yaw), np.sin(yaw),
                np.cos(pitch), np.sin(pitch),
                np.cos(roll), np.sin(roll),
            ]).T
        else:  # 'quat' targets are not implemented yet
            raise NotImplementedError
        labels.append(label)

    return np.hstack(img_paths), np.vstack(labels)


def main():
    """Train the pose (rotation) regression model with K-fold cross-validation.

    Per fold, the best model (lowest val_loss) is saved to
    ``models/pose/<name>/model_<fold>.pth``. A resumable checkpoint, a
    per-fold log CSV and a running results.csv are written to the same
    directory. Folds that already finished are skipped: the ones before the
    checkpoint's fold when resuming, or ones whose model file exists otherwise.
    With ``config['cv']`` false only the first remaining fold is trained.

    Raises:
        NotImplementedError: unsupported rot/loss/optimizer/scheduler, or
            rot == 'quat' when building labels.
    """
    config = vars(parse_args())

    if config['name'] is None:
        config['name'] = '%s_%s' % (config['arch'], datetime.now().strftime('%m%d%H'))

    os.makedirs('models/pose/%s' % config['name'], exist_ok=True)

    if config['resume']:
        with open('models/pose/%s/config.yml' % config['name'], 'r') as f:
            config = yaml.load(f, Loader=yaml.FullLoader)
        config['resume'] = True

    with open('models/pose/%s/config.yml' % config['name'], 'w') as f:
        yaml.dump(config, f)

    print('-' * 20)
    for key in config.keys():
        print('- %s: %s' % (key, str(config[key])))
    print('-' * 20)

    cudnn.benchmark = True

    df = pd.read_csv('inputs/train.csv')
    img_ids = df['ImageId'].values

    pose_df = pd.read_csv('processed/pose_train.csv')
    pose_df['img_path'] = 'processed/pose_images/train/' + pose_df['img_path']

    if config['resume']:
        checkpoint = torch.load('models/pose/%s/checkpoint.pth.tar' % config['name'])

    if config['rot'] == 'eular':
        num_outputs = 3
    elif config['rot'] == 'trig':
        num_outputs = 6
    elif config['rot'] == 'quat':
        num_outputs = 4
    else:
        raise NotImplementedError

    if config['loss'] == 'L1Loss':
        criterion = nn.L1Loss().cuda()
    elif config['loss'] == 'MSELoss':
        criterion = nn.MSELoss().cuda()
    else:
        raise NotImplementedError

    # albumentations' Resize signature is Resize(height, width).
    train_transform = Compose([
        transforms.ShiftScaleRotate(
            shift_limit=config['shift_limit'],
            scale_limit=0,
            rotate_limit=0,
            border_mode=cv2.BORDER_CONSTANT,
            value=0,
            p=config['shift_p']
        ) if config['shift'] else NoOp(),
        OneOf([
            transforms.HueSaturationValue(
                hue_shift_limit=config['hue_limit'],
                sat_shift_limit=config['sat_limit'],
                val_shift_limit=config['val_limit'],
                p=config['hsv_p']
            ) if config['hsv'] else NoOp(),
            transforms.RandomBrightness(
                limit=config['brightness_limit'],
                p=config['brightness_p'],
            ) if config['brightness'] else NoOp(),
            transforms.RandomContrast(
                limit=config['contrast_limit'],
                p=config['contrast_p'],
            ) if config['contrast'] else NoOp(),
        ], p=1),
        transforms.ISONoise(p=config['iso_noise_p']) if config['iso_noise'] else NoOp(),
        transforms.CLAHE(p=config['clahe_p']) if config['clahe'] else NoOp(),
        transforms.Resize(config['input_h'], config['input_w']),
        transforms.Normalize(),
        ToTensor(),
    ])

    val_transform = Compose([
        transforms.Resize(config['input_h'], config['input_w']),
        transforms.Normalize(),
        ToTensor(),
    ])

    folds = []
    best_losses = []

    kf = KFold(n_splits=config['n_splits'], shuffle=True, random_state=41)
    for fold, (train_idx, val_idx) in enumerate(kf.split(img_ids)):
        print('Fold [%d/%d]' % (fold + 1, config['n_splits']))

        # Skip finished folds. The existence check must use the same directory
        # the best model is saved to below (models/pose/...).
        if (config['resume'] and fold < checkpoint['fold'] - 1) or (
                not config['resume']
                and os.path.exists('models/pose/%s/model_%d.pth'
                                   % (config['name'], fold + 1))):
            log = pd.read_csv('models/pose/%s/log_%d.csv' % (config['name'], fold + 1))
            best_loss = log.loc[log['val_loss'].values.argmin(), 'val_loss']
            # best_loss, best_score = log.loc[log['val_loss'].values.argmin(),
            #                                 ['val_loss', 'val_score']].values
            folds.append(str(fold + 1))
            best_losses.append(best_loss)
            # best_scores.append(best_score)
            continue

        train_img_ids, val_img_ids = img_ids[train_idx], img_ids[val_idx]
        train_img_paths, train_labels = _pose_paths_and_labels(
            pose_df, train_img_ids, config['rot'])
        val_img_paths, val_labels = _pose_paths_and_labels(
            pose_df, val_img_ids, config['rot'])

        # train
        train_set = PoseDataset(
            train_img_paths,
            train_labels,
            transform=train_transform,
        )
        train_loader = torch.utils.data.DataLoader(
            train_set,
            batch_size=config['batch_size'],
            shuffle=True,
            num_workers=config['num_workers'],
            # pin_memory=True,
        )

        val_set = PoseDataset(
            val_img_paths,
            val_labels,
            transform=val_transform,
        )
        val_loader = torch.utils.data.DataLoader(
            val_set,
            batch_size=config['batch_size'],
            shuffle=False,
            num_workers=config['num_workers'],
            # pin_memory=True,
        )

        # create model
        model = get_pose_model(config['arch'],
                               num_outputs=num_outputs,
                               freeze_bn=config['freeze_bn'])
        model = model.cuda()

        params = filter(lambda p: p.requires_grad, model.parameters())
        if config['optimizer'] == 'Adam':
            optimizer = optim.Adam(params, lr=config['lr'],
                                   weight_decay=config['weight_decay'])
        elif config['optimizer'] == 'AdamW':
            optimizer = optim.AdamW(params, lr=config['lr'],
                                    weight_decay=config['weight_decay'])
        elif config['optimizer'] == 'RAdam':
            optimizer = RAdam(params, lr=config['lr'],
                              weight_decay=config['weight_decay'])
        elif config['optimizer'] == 'SGD':
            optimizer = optim.SGD(params, lr=config['lr'],
                                  momentum=config['momentum'],
                                  nesterov=config['nesterov'],
                                  weight_decay=config['weight_decay'])
        else:
            raise NotImplementedError

        if config['scheduler'] == 'CosineAnnealingLR':
            scheduler = lr_scheduler.CosineAnnealingLR(
                optimizer, T_max=config['epochs'], eta_min=config['min_lr'])
        elif config['scheduler'] == 'ReduceLROnPlateau':
            scheduler = lr_scheduler.ReduceLROnPlateau(
                optimizer, factor=config['factor'], patience=config['patience'],
                verbose=1, min_lr=config['min_lr'])
        elif config['scheduler'] == 'MultiStepLR':
            scheduler = lr_scheduler.MultiStepLR(
                optimizer,
                milestones=[int(e) for e in config['milestones'].split(',')],
                gamma=config['gamma'])
        else:
            raise NotImplementedError

        log = {
            'epoch': [],
            'loss': [],
            # 'score': [],
            'val_loss': [],
            # 'val_score': [],
        }

        best_loss = float('inf')
        # best_score = float('inf')

        start_epoch = 0

        if config['resume'] and fold == checkpoint['fold'] - 1:
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            scheduler.load_state_dict(checkpoint['scheduler'])
            start_epoch = checkpoint['epoch']
            log = pd.read_csv('models/pose/%s/log_%d.csv'
                              % (config['name'], fold + 1)).to_dict(orient='list')
            best_loss = checkpoint['best_loss']

        for epoch in range(start_epoch, config['epochs']):
            print('Epoch [%d/%d]' % (epoch + 1, config['epochs']))

            # train for one epoch
            train_loss = train(config, train_loader, model, criterion, optimizer, epoch)
            # evaluate on validation set
            val_loss = validate(config, val_loader, model, criterion)

            # ReduceLROnPlateau needs the metric; the other schedulers (including
            # MultiStepLR, which was previously never stepped) advance per epoch.
            if config['scheduler'] == 'ReduceLROnPlateau':
                scheduler.step(val_loss)
            else:
                scheduler.step()

            print('loss %.4f - val_loss %.4f' % (train_loss, val_loss))
            # print('loss %.4f - score %.4f - val_loss %.4f - val_score %.4f'
            #       % (train_loss, train_score, val_loss, val_score))

            log['epoch'].append(epoch)
            log['loss'].append(train_loss)
            # log['score'].append(train_score)
            log['val_loss'].append(val_loss)
            # log['val_score'].append(val_score)

            pd.DataFrame(log).to_csv('models/pose/%s/log_%d.csv'
                                     % (config['name'], fold + 1), index=False)

            if val_loss < best_loss:
                torch.save(model.state_dict(), 'models/pose/%s/model_%d.pth'
                           % (config['name'], fold + 1))
                best_loss = val_loss
                # best_score = val_score
                print("=> saved best model")

            state = {
                'fold': fold + 1,
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'best_loss': best_loss,
                'optimizer': optimizer.state_dict(),
                'scheduler': scheduler.state_dict(),
            }
            torch.save(state, 'models/pose/%s/checkpoint.pth.tar' % config['name'])

        print('val_loss: %f' % best_loss)
        # print('val_score: %f' % best_score)

        folds.append(str(fold + 1))
        best_losses.append(best_loss)
        # best_scores.append(best_score)

        results = pd.DataFrame({
            'fold': folds + ['mean'],
            'best_loss': best_losses + [np.mean(best_losses)],
            # 'best_score': best_scores + [np.mean(best_scores)],
        })

        print(results)
        results.to_csv('models/pose/%s/results.csv' % config['name'], index=False)

        # release GPU/CPU memory before the next fold
        del model
        torch.cuda.empty_cache()
        del train_set, train_loader
        del val_set, val_loader
        gc.collect()

        if not config['cv']:
            break
def get_transforms(phase: str, cli_args) -> Dict[str, Compose]:
    """Assemble the albumentations pipelines used for one learning phase.

    Parameters
    ----------
    phase : str
        'train' or 'val'. Only 'train' gets the random geometric, photometric
        and noise augmentations. Every phase gets the final RandomSizedCrop.
    cli_args
        Parsed command-line arguments from ``main.py``. Only
        ``cli_args.image_size`` is read here.

    Returns
    -------
    transforms: dict[str, albumentations.core.composition.Compose]
        'aug'       - augmentation pipeline (contents depend on ``phase``),
        'img_only'  - tf.Normalize with per-channel mean 0 / std 1,
        'mask_only' - tf.Normalize with mean 0 / std 1,
        'final'     - ToTensorV2 conversion.
    """
    side = cli_args.image_size
    im_sz = (side, side)

    pipeline = []
    if phase == "train":
        # geometric augmentations (training only)
        pipeline += [
            tf.ShiftScaleRotate(shift_limit=0, scale_limit=0.1, rotate_limit=15, p=0.5),
            tf.Flip(p=0.5),
            tf.RandomRotate90(p=0.5),
        ]
        # photometric / elastic / noise augmentations (training only)
        pipeline += [
            tf.RandomBrightnessContrast(p=0.5),
            tf.ElasticTransform(p=0.5),
            tf.MultiplicativeNoise(multiplier=(0.5, 1.5), per_channel=True, p=0.2),
        ]

    # NOTE(review): this random crop also runs for 'val', so validation crops
    # are random too. Confirm that is intended.
    pipeline.append(
        tf.RandomSizedCrop(min_max_height=im_sz, height=im_sz[0], width=im_sz[1],
                           w2h_ratio=1.0, interpolation=cv2.INTER_LINEAR, p=1.0)
    )

    aug = Compose(pipeline)
    mask_only = Compose([tf.Normalize(mean=0, std=1, always_apply=True)])
    img_only = Compose([
        tf.Normalize(mean=(0.0, 0.0, 0.0), std=(1.0, 1.0, 1.0), always_apply=True)
    ])
    final = Compose([ToTensorV2()])

    return {
        'aug': aug,
        'img_only': img_only,
        'mask_only': mask_only,
        'final': final,
    }
def main_func(train_idx, val_set, test_set, modelName, fileName):
    '''Train one segmentation model on a patient-wise train/validation split.

    NOTE(review): ``main_func`` is defined a second time later in this module;
    at import time that later definition replaces this one.

    params:
        train_idx : collection of patient ids whose images form the training
            set.
        val_set : patient id; images of patients ``val_set`` and
            ``val_set + 1`` form the validation set.
        test_set : accepted for interface compatibility; not used here.
        modelName : directory name under ``models/`` for the config, the
            per-epoch log and the best checkpoint. When None, a name is
            derived from the dataset, architecture and deep-supervision flag.
        fileName : file name under ``batch_results_train/`` that receives a
            copy of every progress line printed during training.

    Writes ``models/<name>/config.yml``, ``models/<name>/log.csv`` and
    ``models/<name>/model.pth`` (weights with the best validation dice so
    far). Returns None.
    '''
    config = vars(parse_args())
    config['name'] = modelName

    os.makedirs('batch_results_train', exist_ok=True)
    # Context manager: the results file is flushed and closed even if
    # training raises part-way through.
    with open('batch_results_train/' + fileName, 'w') as fw:

        def _log(msg):
            # Echo to stdout and mirror the same line into the results file.
            print(msg)
            fw.write(msg + '\n')

        _log('config of dataset is ' + str(config['dataset']))

        if config['name'] is None:
            if config['deep_supervision']:
                config['name'] = '%s_%s_wDS' % (config['dataset'], config['arch'])
            else:
                config['name'] = '%s_%s_woDS' % (config['dataset'], config['arch'])
        os.makedirs('models/%s' % config['name'], exist_ok=True)

        _log('-' * 20)
        for key in config:
            _log('%s: %s' % (key, config[key]))
        _log('-' * 20)

        # TODO: print parameters manually; move all imports into the function.
        with open('models/%s/config.yml' % config['name'], 'w') as f:
            yaml.dump(config, f)

        # define loss function (criterion)
        if config['loss'] == 'BCEWithLogitsLoss':
            criterion = nn.BCEWithLogitsLoss().cuda()
        else:
            criterion = losses.__dict__[config['loss']]().cuda()

        cudnn.benchmark = True

        # create model
        _log("=> creating model %s" % config['arch'])
        model = archs.__dict__[config['arch']](config['num_classes'],
                                               config['input_channels'],
                                               config['deep_supervision'])
        model = model.cuda()

        params = filter(lambda p: p.requires_grad, model.parameters())
        if config['optimizer'] == 'Adam':
            optimizer = optim.Adam(params, lr=config['lr'],
                                   weight_decay=config['weight_decay'])
        elif config['optimizer'] == 'SGD':
            optimizer = optim.SGD(params, lr=config['lr'],
                                  momentum=config['momentum'],
                                  nesterov=config['nesterov'],
                                  weight_decay=config['weight_decay'])
        else:
            raise NotImplementedError

        if config['scheduler'] == 'CosineAnnealingLR':
            scheduler = lr_scheduler.CosineAnnealingLR(
                optimizer, T_max=config['epochs'], eta_min=config['min_lr'])
        elif config['scheduler'] == 'ReduceLROnPlateau':
            scheduler = lr_scheduler.ReduceLROnPlateau(
                optimizer, factor=config['factor'],
                patience=config['patience'], verbose=1,
                min_lr=config['min_lr'])
        elif config['scheduler'] == 'MultiStepLR':
            scheduler = lr_scheduler.MultiStepLR(
                optimizer,
                milestones=[int(e) for e in config['milestones'].split(',')],
                gamma=config['gamma'])
        elif config['scheduler'] == 'ConstantLR':
            scheduler = None
        else:
            raise NotImplementedError

        # Data loading code
        img_dir = os.path.join('inputs', config['dataset'], 'images')
        mask_dir = os.path.join('inputs', config['dataset'], 'masks')
        img_ids = glob(os.path.join(img_dir, '*' + config['img_ext']))
        img_ids = [os.path.splitext(os.path.basename(p))[0] for p in img_ids]

        # Split by patient: the patient id is the last character of the image
        # name (up to its first '.'), parsed as an int.
        val_idx = [val_set, val_set + 1]
        val_img_ids = []
        train_img_ids = []
        for image in img_ids:
            patient = int(image.split('.')[0][-1])
            if patient in val_idx:
                val_img_ids.append(image)
            elif patient in train_idx:
                train_img_ids.append(image)

        # Training adds a ShiftScaleRotate shift on top of the shared
        # resize + normalize preprocessing.
        train_transform = Compose([
            transforms.Resize(config['input_h'], config['input_w']),
            transforms.Normalize(),
            transforms.ShiftScaleRotate(shift_limit=0.1, scale_limit=0,
                                        rotate_limit=0),
        ])
        # Validation stays deterministic: a random shift here would make the
        # validation metrics, and therefore best-model selection, noisy.
        val_transform = Compose([
            transforms.Resize(config['input_h'], config['input_w']),
            transforms.Normalize(),
        ])

        train_dataset = Dataset(img_ids=train_img_ids,
                                img_dir=img_dir,
                                mask_dir=mask_dir,
                                img_ext=config['img_ext'],
                                mask_ext=config['mask_ext'],
                                num_classes=config['num_classes'],
                                transform=train_transform)
        val_dataset = Dataset(img_ids=val_img_ids,
                              img_dir=img_dir,
                              mask_dir=mask_dir,
                              img_ext=config['img_ext'],
                              mask_ext=config['mask_ext'],
                              num_classes=config['num_classes'],
                              transform=val_transform)

        train_loader = torch.utils.data.DataLoader(
            train_dataset, batch_size=config['batch_size'], shuffle=True,
            num_workers=config['num_workers'], drop_last=True)
        val_loader = torch.utils.data.DataLoader(
            val_dataset, batch_size=config['batch_size'], shuffle=False,
            num_workers=config['num_workers'], drop_last=False)

        log = OrderedDict([
            ('epoch', []),
            ('lr', []),
            ('loss', []),
            ('iou', []),
            ('val_loss', []),
            ('val_iou', []),
            ('dice', []),
        ])

        trigger = 0
        best_dice = 0
        for epoch in range(config['epochs']):
            _log('Epoch [%d/%d]' % (epoch, config['epochs']))
            # LR actually in effect for this epoch; config['lr'] is only the
            # initial value and goes stale once a scheduler steps.
            current_lr = optimizer.param_groups[0]['lr']

            # train for one epoch
            train_log = train(config, train_loader, model, criterion, optimizer)
            # evaluate on validation set
            val_log = validate(config, val_loader, model, criterion)

            if config['scheduler'] == 'CosineAnnealingLR':
                scheduler.step()
            elif config['scheduler'] == 'ReduceLROnPlateau':
                scheduler.step(val_log['loss'])
            elif config['scheduler'] == 'MultiStepLR':
                # Must be stepped once per epoch for its milestones to apply.
                scheduler.step()

            _log('loss %.4f - iou %.4f - val_loss %.4f - val_iou %.4f dice %.4f'
                 % (train_log['loss'], train_log['iou'], val_log['loss'],
                    val_log['iou'], val_log['dice']))

            log['epoch'].append(epoch)
            log['lr'].append(current_lr)
            log['loss'].append(train_log['loss'])
            log['iou'].append(train_log['iou'])
            log['val_loss'].append(val_log['loss'])
            log['val_iou'].append(val_log['iou'])
            log['dice'].append(val_log['dice'])

            pd.DataFrame(log).to_csv('models/%s/log.csv' % config['name'],
                                     index=False)

            trigger += 1

            # Keep the weights with the best validation dice; comparing
            # val_log['iou'] instead would select on IoU.
            if val_log['dice'] > best_dice:
                torch.save(model.state_dict(),
                           'models/%s/model.pth' % config['name'])
                best_dice = val_log['dice']
                _log("=> saved best model")
                trigger = 0

            # early stopping
            if (config['early_stopping'] >= 0
                    and trigger >= config['early_stopping']):
                _log("=> early stopping")
                break

            torch.cuda.empty_cache()
def main_func(train_idx, val_set, test_set, modelName, fileName):
    '''Train one segmentation model on a patient-wise train/validation split.

    params:
        train_idx : collection of patient ids whose images form the training
            set.
        val_set : patient id whose images form the validation set.
        test_set : patient id(s) of the held-out test set; accepted so callers
            can pass the full split, but not used here.
        modelName : directory name under ``models/`` for the config, the
            per-epoch log and the best checkpoint. When None, a name is
            derived from the dataset, architecture and deep-supervision flag.
        fileName : file name under ``batch_results_train/`` that receives a
            copy of every progress line printed during training.

    modelName and fileName are generated by the caller from the patient
    indices so runs can be identified later. Writes
    ``models/<name>/config.yml``, ``models/<name>/log.csv`` and
    ``models/<name>/model.pth`` (weights with the best validation dice so
    far). No objects are returned.
    '''
    # Read configurations and create model directory
    config = vars(parse_args())
    config['name'] = modelName

    os.makedirs('batch_results_train', exist_ok=True)
    # Context manager: the results file is flushed and closed even if
    # training raises part-way through.
    with open('batch_results_train/' + fileName, 'w') as fw:

        def _log(msg):
            # Echo to stdout and mirror the same line into the results file.
            print(msg)
            fw.write(msg + '\n')

        _log('config of dataset is ' + str(config['dataset']))

        if config['name'] is None:
            if config['deep_supervision']:
                config['name'] = '%s_%s_wDS' % (config['dataset'], config['arch'])
            else:
                config['name'] = '%s_%s_woDS' % (config['dataset'], config['arch'])
        os.makedirs('models/%s' % config['name'], exist_ok=True)

        _log('-' * 20)
        for key in config:
            _log('%s: %s' % (key, config[key]))
        _log('-' * 20)

        with open('models/%s/config.yml' % config['name'], 'w') as f:
            yaml.dump(config, f)

        # define loss function (criterion)
        if config['loss'] == 'BCEWithLogitsLoss':
            criterion = nn.BCEWithLogitsLoss().cuda()
        else:
            criterion = losses.__dict__[config['loss']]().cuda()

        cudnn.benchmark = True

        # create model
        _log("=> creating model %s" % config['arch'])
        model = archs.__dict__[config['arch']](config['num_classes'],
                                               config['input_channels'],
                                               config['deep_supervision'])
        model = model.cuda()

        params = filter(lambda p: p.requires_grad, model.parameters())
        if config['optimizer'] == 'Adam':
            optimizer = optim.Adam(params, lr=config['lr'],
                                   weight_decay=config['weight_decay'])
        elif config['optimizer'] == 'SGD':
            optimizer = optim.SGD(params, lr=config['lr'],
                                  momentum=config['momentum'],
                                  nesterov=config['nesterov'],
                                  weight_decay=config['weight_decay'])
        else:
            raise NotImplementedError

        if config['scheduler'] == 'CosineAnnealingLR':
            scheduler = lr_scheduler.CosineAnnealingLR(
                optimizer, T_max=config['epochs'], eta_min=config['min_lr'])
        elif config['scheduler'] == 'ReduceLROnPlateau':
            scheduler = lr_scheduler.ReduceLROnPlateau(
                optimizer, factor=config['factor'],
                patience=config['patience'], verbose=1,
                min_lr=config['min_lr'])
        elif config['scheduler'] == 'MultiStepLR':
            scheduler = lr_scheduler.MultiStepLR(
                optimizer,
                milestones=[int(e) for e in config['milestones'].split(',')],
                gamma=config['gamma'])
        elif config['scheduler'] == 'ConstantLR':
            scheduler = None
        else:
            raise NotImplementedError

        # Data loading code
        img_dir = os.path.join('inputs', config['dataset'], 'images')
        mask_dir = os.path.join('inputs', config['dataset'], 'masks')
        img_ids = glob(os.path.join(img_dir, '*' + config['img_ext']))
        img_ids = [os.path.splitext(os.path.basename(p))[0] for p in img_ids]

        # Patient IDs in validation set
        val_idx = [val_set]

        # Assign each image to the train or validation set by patient. The
        # patient id is the last character of the image name (up to its
        # first '.'), parsed as an int; other patients are skipped.
        val_img_ids = []
        train_img_ids = []
        for image in img_ids:
            patient = int(image.split('.')[0][-1])
            if patient in val_idx:
                val_img_ids.append(image)
            elif patient in train_idx:
                train_img_ids.append(image)

        # Both splits get the same deterministic preprocessing; no random
        # augmentation on validation, so val metrics stay comparable.
        train_transform = Compose([
            transforms.Resize(config['input_h'], config['input_w']),
            transforms.Normalize(),
        ])
        val_transform = Compose([
            transforms.Resize(config['input_h'], config['input_w']),
            transforms.Normalize(),
        ])

        # Creating PyTorch dataset objects.
        train_dataset = Dataset(img_ids=train_img_ids,
                                img_dir=img_dir,
                                mask_dir=mask_dir,
                                img_ext=config['img_ext'],
                                mask_ext=config['mask_ext'],
                                num_classes=config['num_classes'],
                                transform=train_transform)
        val_dataset = Dataset(img_ids=val_img_ids,
                              img_dir=img_dir,
                              mask_dir=mask_dir,
                              img_ext=config['img_ext'],
                              mask_ext=config['mask_ext'],
                              num_classes=config['num_classes'],
                              transform=val_transform)

        # Creating the PyTorch dataloaders for train and validation sets.
        train_loader = torch.utils.data.DataLoader(
            train_dataset, batch_size=config['batch_size'], shuffle=True,
            num_workers=config['num_workers'], drop_last=True)
        val_loader = torch.utils.data.DataLoader(
            val_dataset, batch_size=config['batch_size'], shuffle=False,
            num_workers=config['num_workers'], drop_last=False)

        # Per-epoch results, dumped to models/<name>/log.csv every epoch.
        log = OrderedDict([
            ('epoch', []),
            ('lr', []),
            ('loss', []),
            ('iou', []),
            ('val_loss', []),
            ('val_iou', []),
            ('dice', []),
        ])

        trigger = 0
        best_dice = 0
        for epoch in range(config['epochs']):
            _log('Epoch [%d/%d]' % (epoch, config['epochs']))
            # LR actually in effect for this epoch; config['lr'] is only the
            # initial value and goes stale once a scheduler steps.
            current_lr = optimizer.param_groups[0]['lr']

            # train for one epoch
            train_log = train(config, train_loader, model, criterion, optimizer)
            # evaluate on validation set
            val_log = validate(config, val_loader, model, criterion)

            if config['scheduler'] == 'CosineAnnealingLR':
                scheduler.step()
            elif config['scheduler'] == 'ReduceLROnPlateau':
                scheduler.step(val_log['loss'])
            elif config['scheduler'] == 'MultiStepLR':
                # Must be stepped once per epoch for its milestones to apply.
                scheduler.step()

            _log('loss %.4f - iou %.4f - val_loss %.4f - val_iou %.4f dice %.4f'
                 % (train_log['loss'], train_log['iou'], val_log['loss'],
                    val_log['iou'], val_log['dice']))

            # Appending result to log dictionary
            log['epoch'].append(epoch)
            log['lr'].append(current_lr)
            log['loss'].append(train_log['loss'])
            log['iou'].append(train_log['iou'])
            log['val_loss'].append(val_log['loss'])
            log['val_iou'].append(val_log['iou'])
            log['dice'].append(val_log['dice'])

            pd.DataFrame(log).to_csv('models/%s/log.csv' % config['name'],
                                     index=False)

            trigger += 1

            # Keep the weights with the best validation dice; comparing
            # val_log['iou'] instead would select on IoU.
            if val_log['dice'] > best_dice:
                torch.save(model.state_dict(),
                           'models/%s/model.pth' % config['name'])
                best_dice = val_log['dice']
                _log("=> saved best model")
                trigger = 0

            # early stopping
            if (config['early_stopping'] >= 0
                    and trigger >= config['early_stopping']):
                _log("=> early stopping")
                break

            torch.cuda.empty_cache()