def get_model(name, classification_head, model_weights_path=None):
    """Instantiate a segmentation network by its short name.

    Args:
        name: architecture identifier such as 'unet34', 'fpn50', 'pspnet34'.
        classification_head: when truthy and name == 'unet18', attach an
            auxiliary single-logit classification head to the encoder.
        model_weights_path: encoder weights path, used only by 'fpn50_satellite'.

    Returns:
        A segmentation model instance.

    Raises:
        ValueError: if `name` is not a known architecture.
    """
    if name == 'unet18':
        print('classification_head:', classification_head)
        if classification_head:
            # auxiliary head: max pooling, light dropout, sigmoid, one label
            head = dict(
                pooling='max',
                dropout=0.1,
                activation='sigmoid',
                classes=1,
            )
            return smp.Unet('resnet18', aux_params=head, encoder_weights=None,
                            encoder_depth=2, decoder_channels=(256, 128))
        return smp.Unet('resnet18', encoder_weights='imagenet',
                        encoder_depth=2, decoder_channels=(256, 128))

    if name == 'fpn50_season':
        from clearcut_research.pytorch import FPN_double_output
        return FPN_double_output('resnet50', encoder_weights='imagenet')

    if name == 'fpn50_satellite':
        network = smp.FPN('resnet50', encoder_weights=None)
        # swap in an encoder pre-trained on satellite imagery
        network.encoder = get_satellite_pretrained_resnet(model_weights_path)
        return network

    if name == 'fpn50_multiclass':
        return smp.FPN('resnet50', encoder_weights='imagenet', classes=3,
                       activation='softmax')

    # plain single-call architectures, keyed by short name
    builders = {
        'unet34': lambda: smp.Unet('resnet34', encoder_weights='imagenet'),
        'unet50': lambda: smp.Unet('resnet50', encoder_weights='imagenet'),
        'unet101': lambda: smp.Unet('resnet101', encoder_weights='imagenet'),
        'linknet34': lambda: smp.Linknet('resnet34', encoder_weights='imagenet'),
        'linknet50': lambda: smp.Linknet('resnet50', encoder_weights='imagenet'),
        'fpn34': lambda: smp.FPN('resnet34', encoder_weights='imagenet'),
        'fpn50': lambda: smp.FPN('resnet50', encoder_weights='imagenet'),
        'fpn101': lambda: smp.FPN('resnet101', encoder_weights='imagenet'),
        'pspnet34': lambda: smp.PSPNet('resnet34', encoder_weights='imagenet', classes=1),
        'pspnet50': lambda: smp.PSPNet('resnet50', encoder_weights='imagenet', classes=1),
    }
    if name not in builders:
        raise ValueError("Unknown network")
    return builders[name]()
def init(config):
    """Build the model (and, in train mode, the loss) described by `config`.

    Args:
        config: dict-like with keys "model", "mode", and in train mode
            "loss" (plus "dice_weight" when loss is "DiceBCE").

    Returns:
        (model, loss) when config["mode"] == 'train', otherwise just model.

    Raises:
        Exception: unknown model or loss name.
    """
    # ---- Model Initialization ----
    model_classes = {
        "UNet": smp.Unet,
        "PSPNet": smp.PSPNet,
        "FPN": smp.FPN,
        "Linknet": smp.Linknet,
    }
    model_name = config["model"]
    if model_name not in model_classes:
        raise Exception('Incorrect model name!')
    model = model_classes[model_name](activation=None)

    # ---- Loss Initialization (train mode only) ----
    if config["mode"] != 'train':
        return model

    # lazily constructed so config["dice_weight"] is read only for DiceBCE
    loss_factories = {
        "DiceBCE": lambda: LossBinaryDice(dice_weight=config["dice_weight"]),
        "FocalTversky": FocalTverskyLoss,
        "Focal": FocalLoss,
        "Tversky": TverskyLoss,
    }
    loss_name = config["loss"]
    if loss_name not in loss_factories:
        raise Exception('Incorrect loss name!')
    return model, loss_factories[loss_name]()
def make_model(model_name='unet_resnet34', weights='imagenet', n_classes=2,
               input_channels=4):
    """Build an SMP model from a combined '<arch>_<encoder>' name.

    Args:
        model_name: e.g. 'unet_resnet34' — the text before the first '_' is
            the architecture, the rest is the SMP encoder name.
        weights: encoder pre-training ('imagenet' or None).
        n_classes: number of output channels.
        input_channels: number of input image channels.

    Returns:
        The constructed SMP model.

    Raises:
        ValueError: unknown architecture prefix.
    """
    # Split once instead of recomputing model_name.split('_') per branch.
    arch, _, encoder = model_name.partition('_')
    arch_classes = {
        'unet': smp.Unet,
        'fpn': smp.FPN,
        'linknet': smp.Linknet,
    }
    if arch not in arch_classes:
        raise ValueError('Model not implemented')
    return arch_classes[arch](encoder,
                              classes=n_classes,
                              activation=None,
                              encoder_weights=weights,
                              in_channels=input_channels)
def create_model(self):
    """Instantiate the configured SMP architecture.

    Uses self.encoder_name, self.encoder_weights, self.num_classes and
    self.model_architecture.

    Returns:
        The constructed SMP model.

    Raises:
        ValueError: unknown architecture name (the original fell through to
        an UnboundLocalError on `model`).
    """
    kwargs = {
        'encoder_name': self.encoder_name,
        'encoder_weights': self.encoder_weights,
        'classes': self.num_classes
    }
    architectures = {
        'Unet': smp.Unet,
        'FPN': smp.FPN,
        'Linknet': smp.Linknet,
        # BUG FIX: the 'PSPNet' branch previously constructed smp.Linknet.
        'PSPNet': smp.PSPNet,
    }
    if self.model_architecture not in architectures:
        raise ValueError(
            'Unknown model architecture: {}'.format(self.model_architecture))
    return architectures[self.model_architecture](**kwargs)
def Linknet(self, img_ch, output_ch):
    """Build an smp.Linknet using this object's encoder configuration.

    Args:
        img_ch: number of input image channels.
        output_ch: number of output mask channels.

    Returns:
        The constructed Linknet (no output activation, no aux head,
        decoder batch-norm disabled).
    """
    settings = dict(
        encoder_name=self.encoder,
        encoder_depth=self.en_depth,
        encoder_weights=self.en_weights,
        decoder_use_batchnorm=False,
        in_channels=img_ch,
        classes=output_ch,
        activation=None,
        aux_params=None,
    )
    return smp.Linknet(**settings)
def resnet50_Linknet_noclassification(**kwargs):
    """Build a segmentation-only (no classification head) resnet50 Linknet.

    NOTE(review): `in_channels`, `classes` and `activation` are free names —
    presumably module-level globals defined elsewhere in this file; confirm
    they exist, otherwise calling this raises NameError. Callers must also
    NOT pass those three via **kwargs, or smp.Linknet receives duplicate
    keyword arguments.
    """
    model = smp.Linknet('resnet50',
                        in_channels=in_channels,
                        classes=classes,
                        activation=activation,
                        **kwargs)
    # debug echo of the effective construction arguments
    print("Just segmentation Model args:")
    print("in_channels:%d,classes:%d,activation:%s" %
          (in_channels, classes, activation))
    print("kwargs", kwargs)
    return model
def build_model(configuration):
    """Construct the segmentation model described by `configuration`.

    Args:
        configuration: object with Model.model_name, Model.encoder,
            Model.encoder_weights and DataSet.number_of_classes.

    Returns:
        The constructed SMP model (no output activation).

    Raises:
        KeyError: unknown model name.
    """
    model_list = ['UNet', 'LinkNet', 'PSPNet', 'FPN', 'PAN', 'Deeplab_v3',
                  'Deeplab_v3+']
    requested = configuration.Model.model_name.lower()

    # arguments shared by every architecture
    common = dict(
        encoder_name=configuration.Model.encoder,
        encoder_weights=configuration.Model.encoder_weights,
        activation=None,
        classes=configuration.DataSet.number_of_classes,
    )

    if requested == 'unet':
        # Unet additionally takes a decoder attention flag (disabled here)
        return smp.Unet(decoder_attention_type=None, **common)

    simple_archs = {
        'linknet': smp.Linknet,
        'pspnet': smp.PSPNet,
        'fpn': smp.FPN,
        'pan': smp.PAN,
        'deeplab_v3+': smp.DeepLabV3Plus,
        'deeplab_v3': smp.DeepLabV3,
    }
    if requested in simple_archs:
        return simple_archs[requested](**common)

    raise KeyError(f'Model should be one of {model_list}')
def create_segmentation_models(encoder, arch, num_classes=4,
                               encoder_weights=None, activation=None):
    '''Build a segmentation network.

    segmentation_models_pytorch
    https://github.com/qubvel/segmentation_models.pytorch
    has the following architectures: Unet, Linknet, FPN, PSPNet
    (encoders: a lot! see the github page above).

    Deeplabv3+ https://github.com/jfzhang95/pytorch-deeplab-xception
    has encoders: resnet (resnet101), mobilenet, xception, drn.

    Args:
        encoder: encoder/backbone name.
        arch: "Unet", "Linknet", "FPN", "PSPNet" or "deeplabv3plus".
        num_classes: number of output channels.
        encoder_weights: pre-trained encoder weights.
        activation: output activation.

    Raises:
        ValueError: unknown arch, or deeplabv3plus path not configured.
    '''
    smp_archs = {
        "Unet": smp.Unet,
        "Linknet": smp.Linknet,
        "FPN": smp.FPN,
        "PSPNet": smp.PSPNet,
    }
    if arch in smp_archs:
        # BUG FIX: the Linknet/FPN/PSPNet branches previously referenced a
        # misspelled name `encoder_weghts`, raising NameError at call time.
        return smp_archs[arch](encoder,
                               encoder_weights=encoder_weights,
                               classes=num_classes,
                               activation=activation)
    if arch == "deeplabv3plus":
        # NOTE(review): `deeplabv3plus_PATH` is a free name — presumably a
        # module-level constant holding the env-var name; confirm it exists.
        if deeplabv3plus_PATH in os.environ:
            sys.path.append(os.environ[deeplabv3plus_PATH])
            from modeling.deeplab import DeepLab
            # BUG FIX: was hard-coded to 4 classes, ignoring num_classes
            # (backward compatible — the default num_classes is 4).
            return DeepLab(encoder, num_classes=num_classes)
        raise ValueError('Set deeplabv3plus path by environment variable.')
    # (the unreachable `sys.exit()` after this raise was removed)
    raise ValueError(
        'arch {} is not found, set the correct arch'.format(arch))
def get_model(name='fpn50', model_weights_path=None):
    """Return a segmentation network selected by its short name.

    Args:
        name: architecture identifier, e.g. 'fpn50', 'unet34'.
        model_weights_path: encoder weights path, used only by
            'fpn50_satellite'.

    Raises:
        ValueError: if `name` is not a known architecture.
    """
    if name == 'fpn50_season':
        from clearcut_research.pytorch import FPN_double_output
        return FPN_double_output('resnet50', encoder_weights='imagenet')

    if name == 'fpn50_satellite':
        network = smp.FPN('resnet50', encoder_weights=None)
        # replace the encoder with one pre-trained on satellite imagery
        network.encoder = get_satellite_pretrained_resnet(model_weights_path)
        return network

    if name == 'fpn50_multiclass':
        return smp.FPN('resnet50', encoder_weights='imagenet', classes=3,
                       activation='softmax')

    factories = {
        'unet34': lambda: smp.Unet('resnet34', encoder_weights='imagenet'),
        'unet50': lambda: smp.Unet('resnet50', encoder_weights='imagenet'),
        'unet101': lambda: smp.Unet('resnet101', encoder_weights='imagenet'),
        'linknet34': lambda: smp.Linknet('resnet34', encoder_weights='imagenet'),
        'linknet50': lambda: smp.Linknet('resnet50', encoder_weights='imagenet'),
        'fpn34': lambda: smp.FPN('resnet34', encoder_weights='imagenet'),
        'fpn50': lambda: smp.FPN('resnet50', encoder_weights='imagenet'),
        'fpn101': lambda: smp.FPN('resnet101', encoder_weights='imagenet'),
        'pspnet34': lambda: smp.PSPNet('resnet34', encoder_weights='imagenet', classes=1),
        'pspnet50': lambda: smp.PSPNet('resnet50', encoder_weights='imagenet', classes=1),
    }
    if name not in factories:
        raise ValueError("Unknown network")
    return factories[name]()
def main():
    """Run test-set inference with a pretrained SE-ResNeXt101 Linknet."""
    args = parseArgs()

    test_data = DataTest(rootPth=args.rootPth)
    loader = DataLoader(test_data,
                        batch_size=args.batchSize,
                        shuffle=False,
                        pin_memory=False,
                        num_workers=args.numWorkers)

    # single-class Linknet, weights restored from the checkpoint path
    net = smp.Linknet(classes=1, encoder_name='se_resnext101_32x4d').to(device)
    net.load_state_dict(torch.load(args.modelPth))

    # make sure the output directory exists before writing predictions
    if not osp.exists(args.savePth):
        os.makedirs(args.savePth)

    inference(net, loader, args)
    print('--Done--')
def create_smp_model(arch, **kwargs):
    'Create segmentation_models_pytorch model'
    assert arch in ARCHITECTURES, f'Select one of {ARCHITECTURES}'
    # one constructor per supported architecture name
    constructors = {
        "Unet": smp.Unet,
        "UnetPlusPlus": smp.UnetPlusPlus,
        "MAnet": smp.MAnet,
        "FPN": smp.FPN,
        "PAN": smp.PAN,
        "PSPNet": smp.PSPNet,
        "Linknet": smp.Linknet,
        "DeepLabV3": smp.DeepLabV3,
        "DeepLabV3Plus": smp.DeepLabV3Plus,
    }
    if arch not in constructors:
        raise NotImplementedError
    model = constructors[arch](**kwargs)
    # remember the construction arguments on the instance for later re-use
    setattr(model, 'kwargs', kwargs)
    return model
def evaluate(self, architecture="FPN", encoder="resnet34",
             encoder_weights="imagenet", activation="sigmoid",
             resize_width=224, resize_height=224):
    """
    Build a single-class SMP model and its matching preprocessing pipeline.

    Args:
        architecture (str): One of the followings: "FPN", "UNET".
        encoder_weights (str): Any encoder supported by SMP.
        activation (str): Any SMP activation.
        resize_width (int): Preprocessing image resize width.
        resize_height (int): Preprocessing image resize height.

    Returns:
        (model, Preprocessing) tuple.

    Raises:
        RuntimeError: unknown architecture name.
    """
    import segmentation_models_pytorch as smp
    from plantseg.inference import Preprocessing

    builders = {
        "FPN": smp.FPN,
        "UNet": smp.Unet,
        "Linknet": smp.Linknet,
    }
    if architecture not in builders:
        raise RuntimeError(f"Undefined architecture {architecture}")

    model = builders[architecture](
        encoder_name=encoder,
        encoder_weights=encoder_weights,
        classes=1,
        activation=activation)

    preproc_fun = smp.encoders.get_preprocessing_fn(
        encoder, encoder_weights)
    return model, Preprocessing(preproc_fun, resize_width, resize_height)
def get_model(encoder='resnet18', type='unet', encoder_weights='imagenet',
              classes=4):
    # My own simple wrapper around smp
    """Build an SMP model plus its encoder preprocessing function.

    Args:
        encoder: SMP encoder name.
        type: architecture — 'unet', 'fpn', 'pspnet' or 'linknet'.
            (Parameter name kept for caller compatibility even though it
            shadows the builtin.)
        encoder_weights: encoder pre-training.
        classes: number of output channels.

    Returns:
        (model, preprocessing_fn) tuple.

    Raises:
        ValueError: unknown architecture.
    """
    arch_classes = {
        'unet': smp.Unet,
        'fpn': smp.FPN,
        'pspnet': smp.PSPNet,
        'linknet': smp.Linknet,
    }
    if type not in arch_classes:
        # BUG FIX: the original did `raise "weird architecture"`, which is a
        # TypeError in Python 3 (exceptions must derive from BaseException).
        raise ValueError("weird architecture")
    model = arch_classes[type](
        encoder_name=encoder,
        encoder_weights=encoder_weights,
        classes=classes,
        activation=None,
    )
    print(f"Training on {type} architecture with {encoder} encoder")
    preprocessing_fn = smp.encoders.get_preprocessing_fn(
        encoder, encoder_weights)
    return model, preprocessing_fn
def get_model(
    model_type: str = "Unet",
    encoder: str = "Resnet18",
    encoder_weights: str = "imagenet",
    activation: str = None,
    n_classes: int = 4,
    task: str = "segmentation",
    source: str = "pretrainedmodels",
    head: str = "simple",
):
    """
    Get model for training or inference.

    Returns loaded models, which is ready to be used.

    Args:
        model_type: segmentation model architecture
        encoder: encoder of the model
        encoder_weights: pre-trained weights to use
        activation: activation function for the output layer
        n_classes: number of classes in the output layer
        task: segmentation or classification
        source: source of model for classification
        head: simply change number of outputs or use better output head

    Returns:
        the constructed model. NOTE(review): an unknown segmentation
        `model_type` silently yields None — callers should be prepared
        for that (consider raising instead).
    """
    if task == "segmentation":
        if model_type == "Unet":
            model = smp.Unet(
                # attention_type='scse',
                encoder_name=encoder,
                encoder_weights=encoder_weights,
                classes=n_classes,
                activation=activation,
            )
        elif model_type == "Linknet":
            model = smp.Linknet(
                encoder_name=encoder,
                encoder_weights=encoder_weights,
                classes=n_classes,
                activation=activation,
            )
        elif model_type == "FPN":
            model = smp.FPN(
                encoder_name=encoder,
                encoder_weights=encoder_weights,
                classes=n_classes,
                activation=activation,
            )
        elif model_type == "resnet34_fpn":
            model = resnet34_fpn(num_classes=n_classes, fpn_features=128)
        elif model_type == "effnetB4_fpn":
            model = effnetB4_fpn(num_classes=n_classes, fpn_features=128)
        else:
            model = None
    elif task == "classification":
        if source == "pretrainedmodels":
            model_fn = pretrainedmodels.__dict__[encoder]
            model = model_fn(num_classes=1000, pretrained=encoder_weights)
        elif source == "torchvision":
            model = torchvision.models.__dict__[encoder](
                pretrained=encoder_weights)

        # head handling applies to either classification source
        # (nesting reconstructed from flattened source — confirm)
        if head == "simple":
            # replace the final linear layer with an n_classes output
            model.last_linear = nn.Linear(model.last_linear.in_features,
                                          n_classes)
        else:
            model = Net(net=model)

    return model
def main():
    """Train a cloud-segmentation model on a single CV fold with Catalyst.

    NOTE(review): relies on a module-level `args` (parsed CLI options) and on
    helpers/classes defined elsewhere in the file (CloudDataset, NAdam,
    Linknet_resnet18_ASPP, augmentation/preprocessing builders, ...).
    """
    fold_path = args.fold_path
    fold_num = args.fold_num
    model_name = args.model_name
    train_csv = args.train_csv
    sub_csv = args.sub_csv
    encoder = args.encoder
    num_workers = args.num_workers
    batch_size = args.batch_size
    num_epochs = args.num_epochs
    learn_late = args.learn_late  # NOTE(review): presumably "learning rate"
    attention_type = args.attention_type

    train = pd.read_csv(train_csv)
    sub = pd.read_csv(sub_csv)
    # 'Image_Label' is '<image>_<class>'; derive class label and image id
    train['label'] = train['Image_Label'].apply(lambda x: x.split('_')[-1])
    train['im_id'] = train['Image_Label'].apply(
        lambda x: x.replace('_' + x.split('_')[-1], ''))
    sub['label'] = sub['Image_Label'].apply(lambda x: x.split('_')[-1])
    sub['im_id'] = sub['Image_Label'].apply(
        lambda x: x.replace('_' + x.split('_')[-1], ''))

    # fold split files list the image file names for this fold
    train_fold = pd.read_csv(f'{fold_path}/train_file_fold_{fold_num}.csv')
    val_fold = pd.read_csv(f'{fold_path}/valid_file_fold_{fold_num}.csv')
    train_ids = np.array(train_fold.file_name)
    valid_ids = np.array(val_fold.file_name)
    encoder_weights = 'imagenet'
    # the CLI passes the literal string 'None' to disable decoder attention
    attention_type = None if attention_type == 'None' else attention_type
    if model_name == 'Unet':
        model = smp.Unet(
            encoder_name=encoder,
            encoder_weights=encoder_weights,
            classes=4,
            activation='softmax',
            attention_type=attention_type,
        )
    if model_name == 'Linknet':
        model = smp.Linknet(
            encoder_name=encoder,
            encoder_weights=encoder_weights,
            classes=4,
            activation='softmax',
        )
    if model_name == 'FPN':
        model = smp.FPN(
            encoder_name=encoder,
            encoder_weights=encoder_weights,
            classes=4,
            activation='softmax',
        )
    if model_name == 'ORG':
        model = Linknet_resnet18_ASPP()
    preprocessing_fn = smp.encoders.get_preprocessing_fn(
        encoder, encoder_weights)
    # both datasets read the same dataframe; the fold id lists differ
    train_dataset = CloudDataset(
        df=train,
        datatype='train',
        img_ids=train_ids,
        transforms=get_training_augmentation(),
        preprocessing=get_preprocessing(preprocessing_fn))
    valid_dataset = CloudDataset(
        df=train,
        datatype='valid',
        img_ids=valid_ids,
        transforms=get_validation_augmentation(),
        preprocessing=get_preprocessing(preprocessing_fn))
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        drop_last=True,
        pin_memory=True,
    )
    valid_loader = DataLoader(valid_dataset,
                              batch_size=batch_size,
                              shuffle=False,
                              num_workers=num_workers)
    loaders = {"train": train_loader, "valid": valid_loader}
    logdir = f"./log/logs_{model_name}_fold_{fold_num}_{encoder}/segmentation"
    print(logdir)
    # the custom 'ORG' model has no encoder/decoder split, so it gets a
    # single parameter group; SMP models get per-submodule groups
    if model_name == 'ORG':
        optimizer = NAdam([
            {
                'params': model.parameters(),
                'lr': learn_late
            },
        ])
    else:
        optimizer = NAdam([
            {
                'params': model.decoder.parameters(),
                'lr': learn_late
            },
            {
                'params': model.encoder.parameters(),
                'lr': learn_late
            },
        ])
    scheduler = ReduceLROnPlateau(optimizer, factor=0.5, patience=0)
    criterion = smp.utils.losses.BCEDiceLoss()
    runner = SupervisedRunner()
    runner.train(model=model,
                 criterion=criterion,
                 optimizer=optimizer,
                 scheduler=scheduler,
                 loaders=loaders,
                 callbacks=[
                     DiceCallback(),
                     EarlyStoppingCallback(patience=5, min_delta=1e-7)
                 ],
                 logdir=logdir,
                 num_epochs=num_epochs,
                 verbose=1)
# ---- Script-level training setup ----
cuda_id = 2          # GPU index passed to .cuda(cuda_id) below
DEVICE = 'cuda'
NUM_EPOCH = 20
SAVE_PRE = 1         # checkpoint-save period in epochs (presumably — confirm in the training loop)
EVAL_PRE = 1         # validation period in epochs
PRINT_PRE = 1        # logging/printing period in epochs
NUM_STEPS_STOP = 100000  # hard cap on training steps
SAVE_DIR = project_path + r'/../checkpoints/' + str(localtime)
if not os.path.exists(SAVE_DIR):
    os.makedirs(SAVE_DIR)
# Dataset
train_loader, valid_loader = create_valley_data_loader()
# Model: single-channel input, single-class Linknet on a ResNet-50 encoder
net = smp.Linknet('resnet50', in_channels=1, classes=1).cuda(cuda_id)
# Loss function
# NOTE(review): CrossEntropyLoss with a single output channel is unusual for
# binary segmentation (BCE is typical) — confirm against the training loop.
loss = torch.nn.CrossEntropyLoss().cuda(cuda_id)
# Validation metric(s)
metrics = [
    smp.utils.metrics.IoU(threshold=0.5),
]
# Optimizer
optimizer = torch.optim.Adam(params=net.parameters(), lr=0.0001)
# tensorboardX — run logs are stored in this folder
writer = SummaryWriter(project_path + r'/../runs/' + str(localtime))
def main():
    """CLI inference for cloud-segmentation models.

    Optionally grid-searches per-class (threshold, min_size) post-processing
    on validation data, then predicts on the test set and writes an RLE
    submission CSV.

    NOTE(review): argparse `type=bool` does not behave as a flag —
    bool('False') is True, so any non-empty string value enables the option.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--encoder', type=str, default='efficientnet-b0')
    parser.add_argument('--model', type=str, default='unet')
    parser.add_argument('--loc', type=str)  # path to the trained checkpoint
    parser.add_argument('--data_folder', type=str, default='../input/')
    parser.add_argument('--batch_size', type=int, default=2)
    parser.add_argument('--optimize', type=bool, default=False)
    parser.add_argument('--tta_pre', type=bool, default=False)
    parser.add_argument('--tta_post', type=bool, default=False)
    parser.add_argument('--merge', type=str, default='mean')
    parser.add_argument('--min_size', type=int, default=10000)
    parser.add_argument('--thresh', type=float, default=0.5)
    parser.add_argument('--name', type=str)  # output CSV path
    args = parser.parse_args()

    encoder = args.encoder
    model = args.model
    loc = args.loc
    data_folder = args.data_folder  # NOTE(review): read but unused below
    bs = args.batch_size
    optimize = args.optimize
    tta_pre = args.tta_pre
    tta_post = args.tta_post
    merge = args.merge
    min_size = args.min_size
    thresh = args.thresh
    name = args.name

    # `model` is rebound from the architecture name string to the module
    if model == 'unet':
        model = smp.Unet(encoder_name=encoder,
                         encoder_weights='imagenet',
                         classes=4,
                         activation=None)
    if model == 'fpn':
        model = smp.FPN(
            encoder_name=encoder,
            encoder_weights='imagenet',
            classes=4,
            activation=None,
        )
    if model == 'pspnet':
        model = smp.PSPNet(
            encoder_name=encoder,
            encoder_weights='imagenet',
            classes=4,
            activation=None,
        )
    if model == 'linknet':
        model = smp.Linknet(
            encoder_name=encoder,
            encoder_weights='imagenet',
            classes=4,
            activation=None,
        )
    preprocessing_fn = smp.encoders.get_preprocessing_fn(encoder, 'imagenet')

    test_df = get_dataset(train=False)
    test_df = prepare_dataset(test_df)
    test_ids = test_df['Image_Label'].apply(
        lambda x: x.split('_')[0]).drop_duplicates().values
    test_dataset = CloudDataset(
        df=test_df,
        datatype='test',
        img_ids=test_ids,
        transforms=valid1(),
        preprocessing=get_preprocessing(preprocessing_fn))
    test_loader = DataLoader(test_dataset, batch_size=bs, shuffle=False)

    val_df = get_dataset(train=True)
    val_df = prepare_dataset(val_df)
    _, val_ids = get_train_test(val_df)
    valid_dataset = CloudDataset(
        df=val_df,
        datatype='train',
        img_ids=val_ids,
        transforms=valid1(),
        preprocessing=get_preprocessing(preprocessing_fn))
    valid_loader = DataLoader(valid_dataset, batch_size=bs, shuffle=False)

    model.load_state_dict(torch.load(loc)['model_state_dict'])
    # default per-class post-processing params (threshold, min component size)
    class_params = {
        0: (thresh, min_size),
        1: (thresh, min_size),
        2: (thresh, min_size),
        3: (thresh, min_size)
    }
    if optimize:
        print("OPTIMIZING")
        print(tta_pre)
        if tta_pre:
            # flip/rotate test-time augmentation for the search phase
            opt_model = tta.SegmentationTTAWrapper(
                model,
                tta.Compose([
                    tta.HorizontalFlip(),
                    tta.VerticalFlip(),
                    tta.Rotate90(angles=[0, 180])
                ]),
                merge_mode=merge)
        else:
            opt_model = model
        tta_runner = SupervisedRunner()
        print("INFERRING ON VALID")
        tta_runner.infer(
            model=opt_model,
            loaders={'valid': valid_loader},
            callbacks=[InferCallback()],
            verbose=True,
        )
        valid_masks = []
        # 4 classes per image; masks/probabilities resized to 350x525
        probabilities = np.zeros((4 * len(valid_dataset), 350, 525))
        for i, (batch, output) in enumerate(
                tqdm(
                    zip(valid_dataset,
                        tta_runner.callbacks[0].predictions["logits"]))):
            _, mask = batch
            for m in mask:
                if m.shape != (350, 525):
                    m = cv2.resize(m,
                                   dsize=(525, 350),
                                   interpolation=cv2.INTER_LINEAR)
                valid_masks.append(m)
            for j, probability in enumerate(output):
                if probability.shape != (350, 525):
                    probability = cv2.resize(probability,
                                             dsize=(525, 350),
                                             interpolation=cv2.INTER_LINEAR)
                probabilities[(i * 4) + j, :, :] = probability
        print("RUNNING GRID SEARCH")
        for class_id in range(4):
            print(class_id)
            attempts = []
            for t in range(30, 70, 5):
                t /= 100  # thresholds 0.30 .. 0.65
                # NOTE(review): 175000 looks like a typo for 17500 — confirm
                for ms in [7500, 10000, 12500, 15000, 175000]:
                    masks = []
                    # every 4th probability map belongs to this class
                    for i in range(class_id, len(probabilities), 4):
                        probability = probabilities[i]
                        predict, num_predict = post_process(
                            sigmoid(probability), t, ms)
                        masks.append(predict)
                    d = []
                    for i, j in zip(masks, valid_masks[class_id::4]):
                        # both empty counts as a perfect dice of 1
                        if (i.sum() == 0) & (j.sum() == 0):
                            d.append(1)
                        else:
                            d.append(dice(i, j))
                    attempts.append((t, ms, np.mean(d)))
            attempts_df = pd.DataFrame(attempts,
                                       columns=['threshold', 'size', 'dice'])
            attempts_df = attempts_df.sort_values('dice', ascending=False)
            print(attempts_df.head())
            best_threshold = attempts_df['threshold'].values[0]
            best_size = attempts_df['size'].values[0]
            class_params[class_id] = (best_threshold, best_size)
        del opt_model
        del tta_runner
        del valid_masks
        del probabilities
        gc.collect()
    if tta_post:
        # apply TTA for the final test-set pass as well
        model = tta.SegmentationTTAWrapper(model,
                                           tta.Compose([
                                               tta.HorizontalFlip(),
                                               tta.VerticalFlip(),
                                               tta.Rotate90(angles=[0, 180])
                                           ]),
                                           merge_mode=merge)
    else:
        model = model
    print(tta_post)
    runner = SupervisedRunner()
    runner.infer(
        model=model,
        loaders={'test': test_loader},
        callbacks=[InferCallback()],
        verbose=True,
    )
    encoded_pixels = []
    image_id = 0
    # NOTE(review): the inner loop rebinds `i`, shadowing the outer index;
    # `image_id` is the running counter actually used for the class lookup.
    for i, image in enumerate(tqdm(runner.callbacks[0].predictions['logits'])):
        for i, prob in enumerate(image):
            if prob.shape != (350, 525):
                prob = cv2.resize(prob,
                                  dsize=(525, 350),
                                  interpolation=cv2.INTER_LINEAR)
            predict, num_predict = post_process(
                sigmoid(prob), class_params[image_id % 4][0],
                class_params[image_id % 4][1])
            if num_predict == 0:
                encoded_pixels.append('')
            else:
                r = mask2rle(predict)
                encoded_pixels.append(r)
            image_id += 1
    test_df['EncodedPixels'] = encoded_pixels
    test_df.to_csv(name, columns=['Image_Label', 'EncodedPixels'], index=False)
checkpoint = torch.load(args.checkpoint) arch_dict = { "unet": smp.Unet( encoder_name=checkpoint["encoder"], encoder_weights=checkpoint["encoder_weight"], classes=8, activation=checkpoint["activation"], decoder_attention_type="scse", decoder_use_batchnorm=True, ), "linknet": smp.Linknet( encoder_name=checkpoint["encoder"], encoder_weights=checkpoint["encoder_weight"], classes=8, activation=checkpoint["activation"], ), "fpn": smp.FPN( encoder_name=checkpoint["encoder"], encoder_weights=checkpoint["encoder_weight"], classes=8, activation=checkpoint["activation"], ), "pspnet": smp.PSPNet( encoder_name=checkpoint["encoder"], encoder_weights=checkpoint["encoder_weight"], classes=8, activation=checkpoint["activation"],
print("fold: {} ----------------------------------------".format( fold)) best = 0 trainloader, validloader = prepare_train_valid_dataloader( dir_df, [fold]) device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') path = f'/mnt/result1/{CFG.data}/{CFG.encoder}-{CFG.base_model}-{CFG.criterion}-{CFG.data}-FOLD-{fold}-model.pth' state_dict = torch.load(path, map_location=torch.device('cpu')) if CFG.base_model == 'unet': model = smp.Unet(CFG.encoder, encoder_weights=None, classes=1) elif CFG.base_model == 'linknet': model = smp.Linknet(CFG.encoder, encoder_weights='imagenet', classes=1) model.load_state_dict(state_dict) del state_dict scaler = GradScaler() for epoch in range(CFG.epoch): if epoch < CFG.freeze_epoch: for p in model.encoder.parameters(): p.requires_grad = False if args.op == 'adam': optimizer1 = Adam(filter(lambda p: p.requires_grad, model.parameters()),
def get_model(config):
    """Build the configured SMP model, wrap it in DataParallel, optionally
    load weights from file, and move it to the configured device.

    Args:
        config: config object exposing MODEL.* and INPUT.CLASSES settings.

    Raises:
        ValueError: unknown MODEL.ARCHITECTURE.
    """
    arch = config.MODEL.ARCHITECTURE
    backbone = config.MODEL.BACKBONE
    encoder_weights = config.MODEL.ENCODER_PRETRAINED_FROM
    in_channels = config.MODEL.IN_CHANNELS
    n_classes = len(config.INPUT.CLASSES)
    activation = config.MODEL.ACTIVATION

    # unet specific (read unconditionally, as before)
    decoder_attention_type = 'scse' if config.MODEL.UNET_ENABLE_DECODER_SCSE else None

    # arguments shared by every architecture
    common = dict(
        encoder_name=backbone,
        encoder_weights=encoder_weights,
        in_channels=in_channels,
        classes=n_classes,
        activation=activation,
    )

    if arch == 'unet':
        model = smp.Unet(
            decoder_channels=config.MODEL.UNET_DECODER_CHANNELS,
            decoder_attention_type=decoder_attention_type,
            **common,
        )
    elif arch == 'fpn':
        model = smp.FPN(
            decoder_dropout=config.MODEL.FPN_DECODER_DROPOUT,
            **common,
        )
    elif arch == 'pan':
        model = smp.PAN(**common)
    elif arch == 'pspnet':
        model = smp.PSPNet(
            psp_dropout=config.MODEL.PSPNET_DROPOUT,
            **common,
        )
    elif arch == 'deeplabv3':
        model = smp.DeepLabV3(**common)
    elif arch == 'linknet':
        model = smp.Linknet(**common)
    else:
        raise ValueError()

    model = torch.nn.DataParallel(model)

    if config.MODEL.WEIGHT and config.MODEL.WEIGHT != 'none':
        # load weight from file
        state = torch.load(config.MODEL.WEIGHT,
                           map_location=torch.device('cpu'))
        model.load_state_dict(state)

    return model.to(config.MODEL.DEVICE)
def main():
    """Validate a trained cloud-segmentation model on one fold, grid-search
    per-class (threshold, min_size) post-processing on the validation set,
    then write test ('sub') and validation RLE predictions to CSV.

    NOTE(review): relies on module-level names `args`, `CLASS`, `IMG_SIZE`
    and helpers defined elsewhere in this file (class_masks, class_dices,
    get_test_encoded_pixels, resize_img, ...).
    """
    fold_path = args.fold_path
    fold_num = args.fold_num
    model_name = args.model_name
    train_csv = args.train_csv
    sub_csv = args.sub_csv
    encoder = args.encoder
    num_workers = args.num_workers
    batch_size = args.batch_size
    log_path = args.log_path
    is_tta = args.is_tta
    test_batch_size = args.test_batch_size
    attention_type = args.attention_type
    print(log_path)

    train = pd.read_csv(train_csv)
    # 'Image_Label' is '<image>_<class>'; derive class label and image id
    train['label'] = train['Image_Label'].apply(lambda x: x.split('_')[-1])
    train['im_id'] = train['Image_Label'].apply(
        lambda x: x.replace('_' + x.split('_')[-1], ''))

    val_fold = pd.read_csv(f'{fold_path}/valid_file_fold_{fold_num}.csv')
    valid_ids = np.array(val_fold.file_name)

    # the CLI passes the literal string 'None' to disable decoder attention
    attention_type = None if attention_type == 'None' else attention_type
    encoder_weights = 'imagenet'
    if model_name == 'Unet':
        model = smp.Unet(
            encoder_name=encoder,
            encoder_weights=encoder_weights,
            classes=CLASS,
            activation='softmax',
            attention_type=attention_type,
        )
    if model_name == 'Linknet':
        model = smp.Linknet(
            encoder_name=encoder,
            encoder_weights=encoder_weights,
            classes=CLASS,
            activation='softmax',
        )
    if model_name == 'FPN':
        model = smp.FPN(
            encoder_name=encoder,
            encoder_weights=encoder_weights,
            classes=CLASS,
            activation='softmax',
        )
    if model_name == 'ORG':
        model = Linknet_resnet18_ASPP()
    preprocessing_fn = smp.encoders.get_preprocessing_fn(
        encoder, encoder_weights)
    valid_dataset = CloudDataset(
        df=train,
        datatype='valid',
        img_ids=valid_ids,
        transforms=get_validation_augmentation(),
        preprocessing=get_preprocessing(preprocessing_fn))
    valid_loader = DataLoader(valid_dataset,
                              batch_size=batch_size,
                              shuffle=False,
                              num_workers=num_workers)
    loaders = {"infer": valid_loader}
    runner = SupervisedRunner()
    # restore the best checkpoint for this fold
    checkpoint = torch.load(f"{log_path}/checkpoints/best.pth")
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()
    # flip-based test-time augmentation
    transforms = tta.Compose(
        [
            tta.HorizontalFlip(),
            tta.VerticalFlip(),
        ]
    )
    model = tta.SegmentationTTAWrapper(model, transforms)
    runner.infer(
        model=model,
        loaders=loaders,
        callbacks=[InferCallback()],
    )
    callbacks_num = 0  # index of the InferCallback in runner.callbacks
    valid_masks = []
    probabilities = np.zeros(
        (valid_dataset.__len__() * CLASS, IMG_SIZE[0], IMG_SIZE[1]))
    # ========
    # val predict
    #
    for batch in tqdm(valid_dataset):  # per-class predictions
        _, mask = batch
        for m in mask:
            m = resize_img(m)
            valid_masks.append(m)
    for i, output in enumerate(
            tqdm(runner.callbacks[callbacks_num].predictions["logits"])):
        for j, probability in enumerate(output):
            probability = resize_img(probability)
            # probability (prediction) extracted per class; j is presumably
            # 0..CLASS-1
            probabilities[i * CLASS + j, :, :] = probability
    # ========
    # search best size and threshold
    #
    class_params = {}
    for class_id in range(CLASS):
        attempts = []
        for threshold in range(20, 90, 5):
            threshold /= 100
            for min_size in [10000, 15000, 20000]:
                masks = class_masks(class_id, probabilities, threshold,
                                    min_size)
                dices = class_dices(class_id, masks, valid_masks)
                attempts.append((threshold, min_size, np.mean(dices)))
        attempts_df = pd.DataFrame(attempts,
                                   columns=['threshold', 'size', 'dice'])
        attempts_df = attempts_df.sort_values('dice', ascending=False)
        print(attempts_df.head())
        best_threshold = attempts_df['threshold'].values[0]
        best_size = attempts_df['size'].values[0]
        class_params[class_id] = (best_threshold, best_size)
    # ========
    # gc
    #
    torch.cuda.empty_cache()
    gc.collect()
    # ========
    # predict
    #
    sub = pd.read_csv(sub_csv)
    sub['label'] = sub['Image_Label'].apply(lambda x: x.split('_')[-1])
    sub['im_id'] = sub['Image_Label'].apply(
        lambda x: x.replace('_' + x.split('_')[-1], ''))
    test_ids = sub['Image_Label'].apply(
        lambda x: x.split('_')[0]).drop_duplicates().values
    test_dataset = CloudDataset(
        df=sub,
        datatype='test',
        img_ids=test_ids,
        transforms=get_validation_augmentation(),
        preprocessing=get_preprocessing(preprocessing_fn))
    encoded_pixels = get_test_encoded_pixels(test_dataset, runner,
                                             class_params, test_batch_size)
    sub['EncodedPixels'] = encoded_pixels
    # ========
    # val dice
    #
    # rebuild the per-class Image_Label keys for the validation images
    val_Image_Label = []
    for i, row in val_fold.iterrows():
        val_Image_Label.append(row.file_name + '_Fish')
        val_Image_Label.append(row.file_name + '_Flower')
        val_Image_Label.append(row.file_name + '_Gravel')
        val_Image_Label.append(row.file_name + '_Sugar')
    val_encoded_pixels = get_test_encoded_pixels(valid_dataset, runner,
                                                 class_params,
                                                 test_batch_size)
    val = pd.DataFrame(val_encoded_pixels, columns=['EncodedPixels'])
    val['Image_Label'] = val_Image_Label
    sub.to_csv(f'./sub/sub_{model_name}_fold_{fold_num}_{encoder}.csv',
               columns=['Image_Label', 'EncodedPixels'],
               index=False)
    val.to_csv(f'./val/val_{model_name}_fold_{fold_num}_{encoder}.csv',
               columns=['Image_Label', 'EncodedPixels'],
               index=False)
def get_segmentation_model(
    arch: str,
    encoder_name: str,
    encoder_weights: Optional[str] = "imagenet",
    pretrained_checkpoint_path: Optional[str] = None,
    checkpoint_path: Optional[Union[str, List[str]]] = None,
    convert_bn: Optional[str] = None,
    convert_bottleneck: Tuple[int, int, int] = (0, 0, 0),
    **kwargs: Any,
) -> nn.Module:
    """
    Fetch a segmentation model by architecture name, optionally load
    (and average) checkpoints, and convert its normalization layers.

    :param arch: architecture name, case-insensitive (e.g. "unet", "pan",
        "deeplabv3+"); unknown names raise ValueError
    :param encoder_name: smp encoder backbone name
    :param encoder_weights: encoder pretraining tag; forced to None below
        whenever a checkpoint is supplied or the custom "en_resnet34"
        encoder is requested
    :param pretrained_checkpoint_path: encoder-only state_dict to load
    :param checkpoint_path: one or more full-model checkpoints; when a
        list is given their weights are averaged before loading
    :param convert_bn: one of None / "instance" / "group" / "bnet" /
        "gnet"; selects the BatchNorm2d replacement applied to the model
    :param convert_bottleneck: per-stage replacement flags forwarded to
        botnet.convert_resnet
    :param kwargs: forwarded unchanged to the smp model constructor
    :return: the constructed (and possibly converted/loaded) nn.Module
    """
    arch = arch.lower()
    # Any checkpoint load (or the project-local en_resnet34 encoder) makes
    # downloading ImageNet encoder weights pointless, so skip it.
    if (encoder_name == "en_resnet34" or checkpoint_path is not None
            or pretrained_checkpoint_path is not None):
        encoder_weights = None
    if arch == "unet":
        model = smp.Unet(encoder_name=encoder_name,
                         encoder_weights=encoder_weights,
                         **kwargs)
    elif arch == "unetplusplus" or arch == "unet++":
        model = smp.UnetPlusPlus(encoder_name=encoder_name,
                                 encoder_weights=encoder_weights,
                                 **kwargs)
    elif arch == "linknet":
        model = smp.Linknet(encoder_name=encoder_name,
                            encoder_weights=encoder_weights,
                            **kwargs)
    elif arch == "pspnet":
        model = smp.PSPNet(encoder_name=encoder_name,
                           encoder_weights=encoder_weights,
                           **kwargs)
    elif arch == "pan":
        model = smp.PAN(encoder_name=encoder_name,
                        encoder_weights=encoder_weights,
                        **kwargs)
    elif arch == "deeplabv3":
        model = smp.DeepLabV3(encoder_name=encoder_name,
                              encoder_weights=encoder_weights,
                              **kwargs)
    elif arch == "deeplabv3plus" or arch == "deeplabv3+":
        model = smp.DeepLabV3Plus(encoder_name=encoder_name,
                                  encoder_weights=encoder_weights,
                                  **kwargs)
    elif arch == "manet":
        model = smp.MAnet(encoder_name=encoder_name,
                          encoder_weights=encoder_weights,
                          **kwargs)
    else:
        # Unknown architecture name.
        raise ValueError
    if pretrained_checkpoint_path is not None:
        # Encoder-only weights (e.g. from self-supervised pretraining).
        print(f"Loading pretrained checkpoint {pretrained_checkpoint_path}")
        state_dict = torch.load(pretrained_checkpoint_path,
                                map_location=torch.device("cpu"))
        model.encoder.load_state_dict(state_dict)
        del state_dict
    # TODO fmap_size=16 hardcoded for input 256 (matters for positional
    # encoding)
    # NOTE(review): this conversion runs unconditionally; presumably
    # replacement=(0, 0, 0) makes it a no-op — confirm against botnet.
    botnet.convert_resnet(
        model.encoder,
        replacement=convert_bottleneck,
        fmap_size=16,
        position_encoding=None,
    )
    # TODO parametrize conversion
    print(f"Convert BN to {convert_bn}")
    if convert_bn == "instance":
        print("Converting BatchNorm2d to InstanceNorm2d")
        model = batch_norm2instance(model)
    elif convert_bn == "group":
        print("Converting BatchNorm2d to GroupNorm")
        model = batch_norm2group(model, channels_per_group=1)
    elif convert_bn == "bnet":
        print("Converting BatchNorm2d to BNet2d")
        model = batch_norm2bnet(model)
    elif convert_bn == "gnet":
        print("Converting BatchNorm2d to GNet2d")
        model = batch_norm2gnet(model, channels_per_group=1)
    elif not convert_bn:
        # None / "" — leave BatchNorm2d untouched.
        print("Do not convert BatchNorm2d")
    else:
        # Unknown conversion mode.
        raise ValueError
    if checkpoint_path is not None:
        # Accept a single path or a list; average the weights of several
        # checkpoints (e.g. snapshot ensembling) before loading.
        if not isinstance(checkpoint_path, list):
            checkpoint_path = [checkpoint_path]
        states = []
        for cp in checkpoint_path:
            # Load checkpoint
            print(f"\nLoading checkpoint {str(cp)}")
            state_dict = torch.load(
                cp, map_location=torch.device("cpu"))["model_state_dict"]
            states.append(state_dict)
        state_dict = average_weights(states)
        model.load_state_dict(state_dict)
        del state_dict
    return model
def __init__(self, architecture='Unet', encoder='resnet34', depth=5,
             in_channels=3, classes=2, activation='softmax'):
    """Select and build a segmentation_models_pytorch architecture.

    :param architecture: smp class name; one of _ARCHITECTURES below
    :param encoder: smp encoder backbone name
    :param depth: encoder depth; input sides must be divisible by 2**depth
    :param in_channels: number of input image channels
    :param classes: number of output classes
    :param activation: final activation name passed through to smp
    """
    super(SegmentationModels, self).__init__()
    self.architecture = architecture
    self.encoder = encoder
    self.depth = depth
    self.in_channels = in_channels
    self.classes = classes
    self.activation = activation

    # define model
    _ARCHITECTURES = ['Unet', 'Linknet', 'FPN', 'PSPNet', 'PAN',
                      'DeepLabV3', 'DeepLabV3Plus']
    # Fixed: the original assert message had the two format arguments
    # swapped, printing the allowed list as the "actual" value.
    assert self.architecture in _ARCHITECTURES, \
        'architecture==\'{0}\', expected one of {1}'.format(
            self.architecture, _ARCHITECTURES)

    # Every listed architecture is an smp class taking the same constructor
    # arguments, so dispatch by name instead of seven identical branches.
    model_cls = getattr(smp, self.architecture)
    self.model = model_cls(encoder_name=self.encoder,
                           encoder_weights=None,
                           encoder_depth=self.depth,
                           in_channels=self.in_channels,
                           classes=self.classes,
                           activation=self.activation)
    # Inputs must be padded to a multiple of this (encoder downsampling).
    self.pad_unit = 2 ** self.depth
test_dataset = SeverstalSteelData(img_dir=DATASET_PATH + '/train_images', split_csv=DATASET_PATH + '/steel_valid.csv', rle_csv=DATASET_PATH + '/train.csv', device=device) test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=0) # print(train_dataset.__getitem__(2)) ##===== Model Configuration ==================================================== model = smp.Linknet('se_resnext101_32x4d', classes=4, activation=None, encoder_weights=None) model = model.to(device) TEST_MODEL(model, (3, 1600, 256)) ##====== Optimizer Zone ======================================================== optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-5) criterionDBCE = metrics.DiceBCELoss() criterionFTversky = metrics.FocalTverskyLoss() criterionFocal = metrics.FocalLoss()
def create_model(model_name, encoder_name, pretrained=False, num_classes=6, in_chans=3, checkpoint_path='', **kwargs): """Create a model Args: model_name (str): name of model to instantiate encoder_name (str): name of encoder to instantiate pretrained (bool): load pretrained ImageNet-1k weights if true num_classes (int): number of classes for final layer (default 6) in_chans (int): number of input channels / colors (default: 3) checkpoint_path (str): path of checkpoint to load after model is initialized Keyword Args: **: other kwargs are model specific """ # I should probably rewrite it weights = None if pretrained: weights = 'imagenet' _logger.info('Using pre-trained imagenet weights') if model_name == 'unetplusplus': model = smp.UnetPlusPlus(encoder_name=encoder_name, encoder_weights=weights, classes=num_classes, in_channels=in_chans, **kwargs) elif model_name == 'unet': model = smp.Unet(encoder_name=encoder_name, encoder_weights=weights, classes=num_classes, in_channels=in_chans, **kwargs) elif model_name == 'fpn': model = smp.FPN(encoder_name=encoder_name, encoder_weights=weights, classes=num_classes, in_channels=in_chans, **kwargs) elif model_name == 'linknet': model = smp.Linknet(encoder_name=encoder_name, encoder_weights=weights, classes=num_classes, in_channels=in_chans, **kwargs) elif model_name == 'pspnet': model = smp.PSPNet(encoder_name=encoder_name, encoder_weights=weights, classes=num_classes, in_channels=in_chans, **kwargs) else: raise NotImplementedError() if checkpoint_path: load_checkpoint(model, checkpoint_path) return model
def main():
    """Train a Linknet on the Severstal fold FOLD_ID with cyclic snapshots,
    then search per-class binarization thresholds on the best predictions.
    """
    with timer('load data'):
        df = pd.read_csv(FOLD_PATH)

    with timer('preprocessing'):
        # Hold out one fold for validation, train on the rest.
        train_df, val_df = df[df.fold_id != FOLD_ID], df[df.fold_id == FOLD_ID]

        train_augmentation = Compose([
            Flip(p=0.5),
            OneOf([
                #ElasticTransform(p=0.5, alpha=120, sigma=120 * 0.05, alpha_affine=120 * 0.03),
                GridDistortion(p=0.5),
                OpticalDistortion(p=0.5, distort_limit=2, shift_limit=0.5)
            ], p=0.5),
            #OneOf([
            #    ShiftScaleRotate(p=0.5),
            ##    RandomRotate90(p=0.5),
            #    Rotate(p=0.5)
            #], p=0.5),
            OneOf([
                Blur(blur_limit=8, p=0.5),
                MotionBlur(blur_limit=8,p=0.5),
                MedianBlur(blur_limit=8,p=0.5),
                GaussianBlur(blur_limit=8,p=0.5)
            ], p=0.5),
            OneOf([
                #CLAHE(clip_limit=4, tile_grid_size=(4, 4), p=0.5),
                RandomGamma(gamma_limit=(100,140), p=0.5),
                RandomBrightnessContrast(p=0.5),
                RandomBrightness(p=0.5),
                RandomContrast(p=0.5)
            ], p=0.5),
            OneOf([
                GaussNoise(p=0.5),
                Cutout(num_holes=10, max_h_size=10, max_w_size=20, p=0.5)
            ], p=0.5)
        ])
        # NOTE(review): the pipeline above is dead code — this reassignment
        # reduces augmentation to a single Flip. Presumably an experiment
        # toggle left in place; confirm before deleting either one.
        train_augmentation = Compose([
            Flip(p=0.5)
        ])
        val_augmentation = None

        train_dataset = SeverDataset(train_df, IMG_DIR, IMG_SIZE, N_CLASSES,
                                     id_colname=ID_COLUMNS,
                                     transforms=train_augmentation)
        val_dataset = SeverDataset(val_df, IMG_DIR, IMG_SIZE, N_CLASSES,
                                   id_colname=ID_COLUMNS,
                                   transforms=val_augmentation)
        train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE,
                                  shuffle=True, num_workers=2)
        val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE,
                                shuffle=False, num_workers=2)

        # Free the dataframes/datasets now that loaders hold references.
        del train_df, val_df, df, train_dataset, val_dataset
        gc.collect()

    with timer('create model'):
        # encoder_se_module / decoder_semodule / h_columns are arguments of
        # a project-patched smp — not upstream segmentation_models_pytorch.
        model = smp.Linknet('se_resnext101_32x4d', encoder_weights='imagenet',
                            classes=N_CLASSES, encoder_se_module=True,
                            decoder_semodule=True, h_columns=False)
        # Resume from a previous snapshot.
        model.load_state_dict(torch.load(model_path))
        model.to(device)

        #criterion = torch.nn.BCEWithLogitsLoss()
        criterion = FocalLovaszLoss()
        optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)
        # Cosine schedule restarted implicitly every CLR_CYCLE epochs by the
        # snapshot logic below.
        scheduler = CosineAnnealingLR(optimizer, T_max=CLR_CYCLE, eta_min=3e-5)
        #scheduler = GradualWarmupScheduler(optimizer, multiplier=1.1, total_epoch=CLR_CYCLE*2, after_scheduler=scheduler_cosine)
        # apex mixed precision, O1 = conservative mixed precision.
        model, optimizer = amp.initialize(model, optimizer, opt_level="O1",
                                          verbosity=0)

    with timer('train'):
        train_losses = []
        valid_losses = []
        best_model_loss = 999  # sentinel "infinity" for best-so-far loss
        best_model_ep = 0
        checkpoint = 0  # snapshot index, bumped every CLR_CYCLE*2 epochs
        for epoch in range(1, EPOCHS + 1):
            # At each cycle boundary: report per-class Dice of the best
            # snapshot so far, then reset the best-loss tracker so the next
            # cycle saves its own snapshot.
            if epoch % (CLR_CYCLE * 2) == 0:
                # NOTE(review): epoch starts at 1, so `epoch != 0` is
                # always true here — presumably a leftover guard.
                if epoch != 0:
                    y_val = y_val.reshape(-1, N_CLASSES, IMG_SIZE[0],
                                          IMG_SIZE[1])
                    best_pred = best_pred.reshape(-1, N_CLASSES, IMG_SIZE[0],
                                                  IMG_SIZE[1])
                    for i in range(N_CLASSES):
                        th, score, _, _ = search_threshold(
                            y_val[:, i, :, :], best_pred[:, i, :, :])
                        LOGGER.info(
                            'Best loss: {} Best Dice: {} on epoch {} th {} class {}'.format(
                                round(best_model_loss, 5), round(score, 5),
                                best_model_ep, th, i))
                checkpoint += 1
                best_model_loss = 999
            LOGGER.info("Starting {} epoch...".format(epoch))
            tr_loss = train_one_epoch(model, train_loader, criterion,
                                      optimizer, device)
            train_losses.append(tr_loss)
            LOGGER.info('Mean train loss: {}'.format(round(tr_loss, 5)))

            valid_loss, val_pred, y_val = validate(model, val_loader,
                                                   criterion, device)
            valid_losses.append(valid_loss)
            LOGGER.info('Mean valid loss: {}'.format(round(valid_loss, 5)))
            scheduler.step()

            # Keep the snapshot (and its raw predictions) with the lowest
            # validation loss within the current cycle.
            if valid_loss < best_model_loss:
                torch.save(model.state_dict(),
                           '{}_fold{}_ckpt{}.pth'.format(EXP_ID, FOLD_ID,
                                                         checkpoint))
                best_model_loss = valid_loss
                best_model_ep = epoch
                best_pred = val_pred

            del val_pred
            gc.collect()

    with timer('eval'):
        # Final per-class threshold search on the best predictions of the
        # last cycle.
        y_val = y_val.reshape(-1, N_CLASSES, IMG_SIZE[0], IMG_SIZE[1])
        best_pred = best_pred.reshape(-1, N_CLASSES, IMG_SIZE[0], IMG_SIZE[1])
        for i in range(N_CLASSES):
            th, score, _, _ = search_threshold(y_val[:, i, :, :],
                                               best_pred[:, i, :, :])
            LOGGER.info(
                'Best loss: {} Best Dice: {} on epoch {} th {} class {}'.format(
                    round(best_model_loss, 5), round(score, 5),
                    best_model_ep, th, i))

    # Plot train/valid loss curves over epochs.
    xs = list(range(1, len(train_losses) + 1))
    plt.plot(xs, train_losses, label='Train loss')
    plt.plot(xs, valid_losses, label='Val loss')
    plt.legend()
    plt.xticks(xs)
    plt.xlabel('Epochs')
    plt.savefig("loss.png")
def __init__(self, architecture="Unet", encoder="resnet34", depth=5,
             in_channels=3, classes=2, activation="softmax"):
    """Select and build a segmentation_models_pytorch architecture.

    :param architecture: smp class name; one of _ARCHITECTURES below
    :param encoder: smp encoder backbone name
    :param depth: encoder depth; input sides must be divisible by 2**depth
    :param in_channels: number of input image channels
    :param classes: number of output classes
    :param activation: final activation name passed through to smp
    """
    super(SegmentationModels, self).__init__()
    self.architecture = architecture
    self.encoder = encoder
    self.depth = depth
    self.in_channels = in_channels
    self.classes = classes
    self.activation = activation

    # define model
    _ARCHITECTURES = [
        "Unet", "UnetPlusPlus", "Linknet", "MAnet", "FPN", "PSPNet", "PAN",
        "DeepLabV3", "DeepLabV3Plus"
    ]
    # Fixed: the original assert message had the two format arguments
    # swapped, printing the allowed list as the "actual" value.
    assert self.architecture in _ARCHITECTURES, \
        "architecture=='{0}', expected one of {1}".format(
            self.architecture, _ARCHITECTURES)

    # Every listed architecture is an smp class taking the same constructor
    # arguments, so dispatch by name instead of nine identical branches.
    self.model = getattr(smp, self.architecture)(
        encoder_name=self.encoder,
        encoder_weights=None,
        encoder_depth=self.depth,
        in_channels=self.in_channels,
        classes=self.classes,
        activation=self.activation,
    )
    # Inputs must be padded to a multiple of this (encoder downsampling).
    self.pad_unit = 2**self.depth
arch_dict = { "unet": smp.Unet( encoder_name=args.encoder, encoder_weights=args.weight, classes=8, activation=args.activation, decoder_attention_type="scse", decoder_use_batchnorm=True, aux_params=aux_params_dict, ), "linknet": smp.Linknet( encoder_name=args.encoder, encoder_weights=args.weight, classes=8, activation=args.activation, ), "fpn": smp.FPN( encoder_name=args.encoder, encoder_weights=args.weight, classes=8, activation=args.activation, ), "pspnet": smp.PSPNet( encoder_name=args.encoder, encoder_weights=args.weight, classes=8, activation=args.activation,
def main():
    """CLI entry point for cloud-segmentation training.

    Parses hyperparameters, builds the model/datasets/optimizer/scheduler/
    loss, then trains with a Catalyst SupervisedRunner.

    Fix: `--loss jaccard` previously used `==` (a comparison) instead of
    `=`, leaving `criterion` unassigned and crashing with NameError.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--encoder', type=str, default='efficientnet-b0')
    parser.add_argument('--model', type=str, default='unet')
    parser.add_argument('--pretrained', type=str, default='imagenet')
    parser.add_argument('--logdir', type=str, default='../logs/')
    parser.add_argument('--exp_name', type=str)
    parser.add_argument('--data_folder', type=str, default='../input/')
    parser.add_argument('--height', type=int, default=320)
    parser.add_argument('--width', type=int, default=640)
    parser.add_argument('--batch_size', type=int, default=2)
    parser.add_argument('--accumulate', type=int, default=8)
    parser.add_argument('--epochs', type=int, default=20)
    parser.add_argument('--enc_lr', type=float, default=1e-2)
    parser.add_argument('--dec_lr', type=float, default=1e-3)
    parser.add_argument('--optim', type=str, default="radam")
    parser.add_argument('--loss', type=str, default="bcedice")
    parser.add_argument('--schedule', type=str, default="rlop")
    # NOTE(review): type=bool is an argparse pitfall — any non-empty string
    # (including "False") parses as True. Kept for CLI compatibility.
    parser.add_argument('--early_stopping', type=bool, default=True)
    args = parser.parse_args()

    encoder = args.encoder
    model = args.model
    pretrained = args.pretrained
    logdir = args.logdir
    name = args.exp_name
    data_folder = args.data_folder
    height = args.height
    width = args.width
    bs = args.batch_size
    accumulate = args.accumulate
    epochs = args.epochs
    enc_lr = args.enc_lr
    dec_lr = args.dec_lr
    optim = args.optim
    loss = args.loss
    schedule = args.schedule
    early_stopping = args.early_stopping

    # Build the network; `model` is rebound from name to module, matching
    # the original behaviour (the chain is mutually exclusive on strings).
    if model == 'unet':
        model = smp.Unet(encoder_name=encoder, encoder_weights=pretrained,
                         classes=4, activation=None)
    elif model == 'fpn':
        model = smp.FPN(
            encoder_name=encoder,
            encoder_weights=pretrained,
            classes=4,
            activation=None,
        )
    elif model == 'pspnet':
        model = smp.PSPNet(
            encoder_name=encoder,
            encoder_weights=pretrained,
            classes=4,
            activation=None,
        )
    elif model == 'linknet':
        model = smp.Linknet(
            encoder_name=encoder,
            encoder_weights=pretrained,
            classes=4,
            activation=None,
        )
    elif model == 'aspp':
        print('aspp can only be used with resnet34')
        model = aspp(num_class=4)

    preprocessing_fn = smp.encoders.get_preprocessing_fn(encoder, pretrained)
    log = os.path.join(logdir, name)

    # Datasets and loaders.
    ds = get_dataset(path=data_folder)
    prepared_ds = prepare_dataset(ds)
    train_set, valid_set = get_train_test(ds)
    train_ds = CloudDataset(df=prepared_ds, datatype='train',
                            img_ids=train_set,
                            transforms=training1(h=height, w=width),
                            preprocessing=get_preprocessing(preprocessing_fn),
                            folder=data_folder)
    valid_ds = CloudDataset(df=prepared_ds, datatype='train',
                            img_ids=valid_set,
                            transforms=valid1(h=height, w=width),
                            preprocessing=get_preprocessing(preprocessing_fn),
                            folder=data_folder)
    train_loader = DataLoader(train_ds, batch_size=bs, shuffle=True,
                              num_workers=multiprocessing.cpu_count())
    valid_loader = DataLoader(valid_ds, batch_size=bs, shuffle=False,
                              num_workers=multiprocessing.cpu_count())
    loaders = {
        'train': train_loader,
        'valid': valid_loader,
    }
    num_epochs = epochs

    # Optimizer: a dispatch table replaces four copy-pasted branches.
    # Previously an unknown --optim silently left `optimizer` unassigned
    # and crashed later with NameError; now it fails fast with a message.
    optimizer_classes = {"radam": RAdam, "adam": Adam, "adamw": AdamW,
                         "sgd": SGD}
    try:
        optimizer_cls = optimizer_classes[optim]
    except KeyError:
        raise ValueError(f'Unknown optimizer: {optim!r}')
    if args.model != "aspp":
        # smp models expose encoder/decoder; each gets its own LR.
        param_groups = [
            {'params': model.encoder.parameters(), 'lr': enc_lr},
            {'params': model.decoder.parameters(), 'lr': dec_lr},
        ]
    else:
        # aspp has no encoder/decoder split; a single group at enc_lr.
        param_groups = [
            {'params': model.parameters(), 'lr': enc_lr},
        ]
    optimizer = optimizer_cls(param_groups)

    # Scheduler: default RLOP, overridden by the --schedule flag.
    scheduler = ReduceLROnPlateau(optimizer, factor=0.1, patience=5)
    if schedule == "rlop":
        scheduler = ReduceLROnPlateau(optimizer, factor=0.15, patience=3)
    elif schedule == "noam":
        scheduler = NoamLR(optimizer, 10)

    # Loss selection.
    if loss == "bcedice":
        criterion = smp.utils.losses.BCEDiceLoss(eps=1.)
    elif loss == "dice":
        criterion = smp.utils.losses.DiceLoss(eps=1.)
    elif loss == "bcejaccard":
        criterion = smp.utils.losses.BCEJaccardLoss(eps=1.)
    elif loss == "jaccard":
        # BUG FIX: original read `criterion == ...` (no assignment).
        criterion = smp.utils.losses.JaccardLoss(eps=1.)
    elif loss == 'bce':
        criterion = NewBCELoss()

    callbacks = [NewDiceCallback(), CriterionCallback()]
    callbacks.append(OptimizerCallback(accumulation_steps=accumulate))
    if early_stopping:
        callbacks.append(EarlyStoppingCallback(patience=5, min_delta=0.001))

    runner = SupervisedRunner()
    runner.train(
        model=model,
        criterion=criterion,
        optimizer=optimizer,
        scheduler=scheduler,
        loaders=loaders,
        callbacks=callbacks,
        logdir=log,
        num_epochs=num_epochs,
        verbose=True,
    )
def main(args, logger):
    """Train an smp Linknet segmenter with periodic logging and snapshots.

    :param args: namespace with batch sizes, epoch count, frequencies and
        output directories (parsed elsewhere in the file)
    :param logger: configured logger used for progress messages

    Fix: the periodic snapshot used ``net.modules.state_dict()`` —
    ``modules`` is a method, so this raised AttributeError whenever the
    branch was taken; the DataParallel-style unwrap is ``net.module``.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    writer = SummaryWriter(logdir=args.subTensorboardDir)

    trainSet = Data(mode='train')
    trainLoader = DataLoader(trainSet,
                             batch_size=args.batchSizeTrain,
                             shuffle=True,
                             pin_memory=False,
                             drop_last=False,
                             num_workers=args.numWorkers)
    testSet = Data(mode='test')
    testLoader = DataLoader(testSet,
                            batch_size=args.batchSizeTest,
                            shuffle=False,
                            pin_memory=False,
                            num_workers=args.numWorkers)

    # net = smp.Unet(classes=2).to(device)
    net = smp.Linknet(classes=1,
                      activation='sigmoid',
                      encoder_name='se_resnext101_32x4d').to(device)
    # criterion = nn.CrossEntropyLoss().to(device)
    criterion = smploss.DiceLoss(eps=sys.float_info.min).to(device)
    # criterion = DiceLoss().to(device)
    optimizer = optim.SGD(net.parameters(), lr=.1, momentum=.9)
    scheduler = optim.lr_scheduler.StepLR(optimizer,
                                          step_size=args.scheduleStep,
                                          gamma=0.1)

    runningLoss = []
    st = stGloble = time.time()
    totalIter = len(trainLoader) * args.epoch
    global_iter = 0  # renamed from `iter` to stop shadowing the builtin
    for epoch in range(args.epoch):
        if epoch != 0 and epoch % args.evalFrequency == 0:
            pass  # TODO: evaluation hook was never implemented
        if epoch != 0 and epoch % args.saveFrequency == 0:
            modelName = osp.join(args.subModelDir,
                                 'out_{}.pth'.format(epoch))
            # BUG FIX: was `net.modules.state_dict()` (AttributeError);
            # unwrap a wrapped model via its `.module` attribute.
            state_dict = net.module.state_dict() if hasattr(
                net, 'module') else net.state_dict()
            torch.save(state_dict, modelName)

        for img, mask in trainLoader:
            global_iter += 1
            img = img.to(device)
            mask = mask.to(device, dtype=torch.int64).unsqueeze(1)

            optimizer.zero_grad()
            outputs = net(img)
            # print(outputs.shape, mask.shape)
            # break
            loss = criterion(outputs, mask)
            loss.backward()
            optimizer.step()
            runningLoss.append(loss.item())

            if global_iter % args.msgFrequency == 0:
                # writer.add_images('img', img, global_iter)
                # writer.add_images('mask', mask.unsqueeze(1), global_iter)
                ed = time.time()
                spend = ed - st
                spendGloable = ed - stGloble
                st = ed
                # ETA from the global average seconds-per-iteration.
                eta = int((totalIter - global_iter) *
                          (spendGloable / global_iter))
                spendGloable = str(datetime.timedelta(seconds=spendGloable))
                eta = str(datetime.timedelta(seconds=eta))
                avgLoss = np.mean(runningLoss)
                runningLoss = []
                lr = optimizer.param_groups[0]['lr']
                msg = '. '.join([
                    'epoch:{epoch}',
                    'iter/total_iter:{iter}/{totalIter}',
                    'lr:{lr:.5f}',
                    'loss:{loss:.4f}',
                    'spend/gloable_spend:{spend:.4f}/{gloable_spend}',
                    'eta:{eta}'
                ]).format(epoch=epoch,
                          loss=avgLoss,
                          iter=global_iter,
                          totalIter=totalIter,
                          spend=spend,
                          gloable_spend=spendGloable,
                          lr=lr,
                          eta=eta)
                logger.info(msg)
                writer.add_scalar('loss', avgLoss, global_iter)
                writer.add_scalar('lr', lr, global_iter)
        # Step the LR schedule once per epoch.
        scheduler.step()

    outName = osp.join(args.subModelDir, 'final.pth')
    torch.save(net.cpu().state_dict(), outName)