def train_fold(save_dir, train_folds, val_folds, model_path):
    depth_trns = SimpleDepthTransform()
    train_trns = SaltTransform(IMAGE_SIZE, True, 'crop')
    val_trns = SaltTransform(IMAGE_SIZE, False, 'crop')
    train_dataset = SaltDataset(TRAIN_FOLDS_PATH, train_folds, train_trns, depth_trns)
    val_dataset = SaltDataset(TRAIN_FOLDS_PATH, val_folds, val_trns, depth_trns)
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE,
                              shuffle=True, drop_last=True, num_workers=8)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE,
                            shuffle=False, num_workers=8)

    # Continue training from a saved checkpoint with equal Lovasz and
    # probability loss weights.
    model = load_model(model_path)
    model.loss.lovasz_weight = 0.5
    model.loss.prob_weight = 0.5

    callbacks = [
        MonitorCheckpoint(save_dir, monitor='val_crop_iout', max_saves=3, copy_last=False),
        LoggingToFile(os.path.join(save_dir, 'log.txt')),
        update_lr
    ]

    model.fit(train_loader,
              val_loader=val_loader,
              max_epochs=500,
              callbacks=callbacks,
              metrics=['crop_iout'])
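# A minimal sketch of invoking the fine-tuning stage above for one fold.
# The fold numbers, save directory and checkpoint path below are illustrative
# assumptions, not values taken from the project.
stage1_checkpoint = '/workdir/experiments/stage1/fold_0/model-best.pth'  # assumed path
train_fold(save_dir='/workdir/experiments/stage2/fold_0',                # assumed path
           train_folds=[1, 2, 3, 4],
           val_folds=[0],
           model_path=stage1_checkpoint)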
def __init__(self, model_path):
    self.model = load_model(model_path)
    # Replace the network's final layer with a sigmoid so it outputs probabilities.
    self.model.nn_module.final = torch.nn.Sigmoid()
    # self.model.nn_module.eval()
    self.depth_trns = SimpleDepthTransform()
    self.crop_trns = CenterCrop(ORIG_IMAGE_SIZE)
    self.trns = SaltTransform(PRED_IMAGE_SIZE, False, TRANSFORM_MODE)
def __init__(self, model_path):
    self.model = load_model(model_path)
    self.model.nn_module.eval()
    self.depth_trns = SimpleDepthTransform()
    self.crop_trns = CenterCrop(ORIG_IMAGE_SIZE)
    self.trns = SaltTransform(PRED_IMAGE_SIZE, False, TRANSFORM_MODE)
    # Horizontal flip used for test-time augmentation.
    self.flip = HorizontalFlip()
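# A minimal sketch of the horizontal-flip TTA idea behind the predictor above,
# written against a generic torch.nn.Module rather than the project's classes.
# `net`, the (N, C, H, W) input layout, and the assumption that the network
# outputs logits are illustrative assumptions only.
import torch

@torch.no_grad()
def predict_with_hflip_tta(net, batch):
    # Average the prediction on the original batch with the prediction on the
    # horizontally flipped batch (flipped back before averaging).
    pred = torch.sigmoid(net(batch))
    flipped = torch.flip(batch, dims=[-1])
    pred_flip = torch.flip(torch.sigmoid(net(flipped)), dims=[-1])
    return (pred + pred_flip) / 2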
def __init__(self, test_dir, transform=None, depth_transform=None):
    super().__init__()
    self.test_dir = test_dir
    self.transform = transform
    if depth_transform is None:
        self.depth_transform = SimpleDepthTransform()
    else:
        self.depth_transform = depth_transform
    self.images_lst, self.depth_lst = get_test_samples(test_dir)
def __init__(self, train_folds_path, folds,
             transform=None, depth_transform=None):
    super().__init__()
    self.train_folds_path = train_folds_path
    self.folds = folds
    self.transform = transform
    if depth_transform is None:
        self.depth_transform = SimpleDepthTransform()
    else:
        self.depth_transform = depth_transform
    self.images_lst, self.target_lst, self.depth_lst = \
        get_samples(train_folds_path, folds)
def train_fold(save_dir, train_folds, val_folds):
    depth_trns = SimpleDepthTransform()
    train_trns = SaltTransform(IMAGE_SIZE, True, 'crop')
    val_trns = SaltTransform(IMAGE_SIZE, False, 'crop')
    train_dataset = SaltDataset(TRAIN_FOLDS_PATH, train_folds, train_trns, depth_trns)
    val_dataset = SaltDataset(TRAIN_FOLDS_PATH, val_folds, val_trns, depth_trns)
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE,
                              shuffle=True, drop_last=True, num_workers=4)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE,
                            shuffle=False, num_workers=4)

    model = SaltMetaModel(PARAMS)

    callbacks = [
        MonitorCheckpoint(save_dir, monitor='val_crop_iout', max_saves=3, copy_last=False),
        EarlyStopping(monitor='val_crop_iout', patience=100),
        ReduceLROnPlateau(monitor='val_crop_iout', patience=30, factor=0.64, min_lr=1e-8),
        LoggingToFile(os.path.join(save_dir, 'log.txt')),
    ]

    model.fit(train_loader,
              val_loader=val_loader,
              max_epochs=700,
              callbacks=callbacks,
              metrics=['crop_iout'])
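# A minimal sketch of a leave-one-fold-out driver for the training function above.
# N_FOLDS and SAVE_DIR are illustrative assumptions, not values from the snippet.
import os

N_FOLDS = 5
SAVE_DIR = '/workdir/experiments/stage1'

if __name__ == "__main__":
    for fold in range(N_FOLDS):
        val_folds = [fold]
        train_folds = [f for f in range(N_FOLDS) if f != fold]
        train_fold(os.path.join(SAVE_DIR, f'fold_{fold}'), train_folds, val_folds)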
        'bce_weight': random_params['bce_weight'],
        'prob_weight': random_params['prob_weight']
    }),
    'prediction_transform': ('ProbOutputTransform', {
        'segm_thresh': 0.5,
        'prob_thresh': 0.5,
    }),
    'optimizer': ('Adam', {
        'lr': 0.0001
    }),
    'device': 'cuda'
}
pprint(params)

depth_trns = SimpleDepthTransform()
train_trns = SaltTransform(IMAGE_SIZE, True, 'crop')
val_trns = SaltTransform(IMAGE_SIZE, False, 'crop')
train_dataset = SaltDataset(TRAIN_FOLDS_PATH, TRAIN_FOLDS, train_trns, depth_trns)
val_dataset = SaltDataset(TRAIN_FOLDS_PATH, VAL_FOLDS, val_trns, depth_trns)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE,
                          shuffle=True, drop_last=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE,
                        shuffle=False, num_workers=4)
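# A hypothetical sketch of how `random_params`, referenced in the params dict
# above, could be sampled for a random hyperparameter search. Only the two
# weight names visible in the snippet are mirrored here; the sampling ranges
# are illustrative assumptions.
import random

def sample_random_params():
    return {
        'bce_weight': random.uniform(0.1, 1.0),
        'prob_weight': random.uniform(0.1, 1.0),
    }

random_params = sample_random_params()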