def train(args):
    """Train a single segmentation model and infer on the validation set.

    Builds the network named by ``args.network``, swaps the encoder's first
    convolution so it accepts ``count_channels(args.channels)`` input planes,
    trains it with catalyst's ``SupervisedRunner``, then re-runs the
    validation loader through the best checkpoint.

    Args:
        args: parsed CLI namespace; reads network, channels, train_df,
            val_df, dataset_path, image_size, batch_size, num_workers,
            logdir and epochs.
    """
    set_random_seed(42)

    model = get_model(args.network)
    print('Loading model')
    # Replace the stock 3-channel stem so the encoder accepts an arbitrary
    # number of input channels; remaining conv1 hyper-parameters are kept.
    model.encoder.conv1 = nn.Conv2d(
        count_channels(args.channels), 64,
        kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    model, device = UtilsFactory.prepare_model(model)

    train_df = pd.read_csv(args.train_df).to_dict('records')
    val_df = pd.read_csv(args.val_df).to_dict('records')

    ds = Dataset(args.channels, args.dataset_path, args.image_size,
                 args.batch_size, args.num_workers)
    loaders = ds.create_loaders(train_df, val_df)
    # NOTE: removed leftover debug print of loaders['train'].dataset.data,
    # which dumped the whole training index to stdout on every run.

    criterion = BCE_Dice_Loss(bce_weight=0.2)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, milestones=[10, 20, 40], gamma=0.3)

    save_path = os.path.join(
        args.logdir, '_'.join([args.network, *args.channels]))

    # model runner
    runner = SupervisedRunner()

    # model training
    runner.train(
        model=model,
        criterion=criterion,
        optimizer=optimizer,
        scheduler=scheduler,
        loaders=loaders,
        callbacks=[DiceCallback()],
        logdir=save_path,
        num_epochs=args.epochs,
        verbose=True)

    # Re-run validation through the best checkpoint to collect predictions.
    infer_loader = collections.OrderedDict([('infer', loaders['valid'])])
    runner.infer(
        model=model,
        loaders=infer_loader,
        callbacks=[
            CheckpointCallback(resume=f'{save_path}/checkpoints/best.pth'),
            InferCallback()
        ])
def train(args):
    """Cross-validation training: fit one model per fold.

    For every fold index in ``range(args.folds)`` a fresh model is built,
    trained on that fold's train/val CSV split, and the best checkpoint is
    re-run over the validation loader.
    """
    set_random_seed(42)
    for fold_idx in range(args.folds):
        model = get_model(args.network)
        print("Loading model")
        model, device = UtilsFactory.prepare_model(model)

        # Per-fold split files live alongside the dataset.
        train_records = pd.read_csv(
            os.path.join(args.dataset_path, f'train{fold_idx}.csv')
        ).to_dict('records')
        val_records = pd.read_csv(
            os.path.join(args.dataset_path, f'val{fold_idx}.csv')
        ).to_dict('records')

        dataset = Dataset(args.channels, args.dataset_path, args.image_size,
                          args.batch_size, args.num_workers)
        loaders = dataset.create_loaders(train_records, val_records)

        criterion = BCE_Dice_Loss(bce_weight=0.2)
        optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
        scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer, milestones=[10, 20, 40], gamma=0.3)

        # model runner
        runner = SupervisedRunner()
        save_path = os.path.join(args.logdir, f'fold{fold_idx}')

        # model training
        runner.train(
            model=model,
            criterion=criterion,
            optimizer=optimizer,
            scheduler=scheduler,
            loaders=loaders,
            callbacks=[DiceCallback()],
            logdir=save_path,
            num_epochs=args.epochs,
            verbose=True)

        # Collect predictions from the fold's best checkpoint.
        infer_loader = collections.OrderedDict([("infer", loaders["valid"])])
        runner.infer(
            model=model,
            loaders=infer_loader,
            callbacks=[
                CheckpointCallback(
                    resume=f'{save_path}/checkpoints/best.pth'),
                InferCallback()
            ])
        print(f'Fold {fold_idx} ended')
def train(args):
    """Train the season-aware FPN with a hand-rolled epoch loop.

    Runs ``train_iter``/``valid_iter`` per epoch and snapshots the model
    weights whenever validation Dice improves, reporting the best epoch,
    its Dice, and its season-prediction accuracy at the end.
    """
    set_random_seed(42)
    model = get_model('fpn50_season')
    print("Loading model")
    model, device = UtilsFactory.prepare_model(model)

    train_records = pd.read_csv(args.train_df).to_dict('records')
    val_records = pd.read_csv(args.val_df).to_dict('records')

    dataset = SeasonDataset(args.channels, args.dataset_path,
                            args.image_size, args.batch_size,
                            args.num_workers)
    loaders = dataset.create_loaders(train_records, val_records)

    criterion = BCE_Dice_Loss(bce_weight=0.2)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, milestones=[10, 20, 40], gamma=0.3)

    # Relative weight of the segmentation loss term (vs the season head);
    # constant across epochs, hoisted out of the loop.
    segmentation_weight = 0.8

    best_valid_dice = -1
    best_epoch = -1
    best_accuracy = -1
    for epoch in range(args.epochs):
        train_iter(loaders['train'], model, device, criterion, optimizer,
                   segmentation_weight)
        dice_mean, valid_accuracy = valid_iter(
            loaders['valid'], model, device, criterion, segmentation_weight)

        if dice_mean > best_valid_dice:
            # New best epoch: remember its metrics and snapshot the weights.
            best_valid_dice = dice_mean
            best_epoch = epoch
            best_accuracy = valid_accuracy
            os.makedirs(f'{args.logdir}/weights', exist_ok=True)
            torch.save(model.state_dict(),
                       f'{args.logdir}/weights/epoch{epoch}.pth')

        scheduler.step()
        print("Epoch {0} ended".format(epoch))

    print("Best epoch: ", best_epoch, "with dice ", best_valid_dice,
          "and season prediction accuracy", best_accuracy)
def train(args):
    """Train the multiclass FPN with catalyst and infer on validation.

    Uses ``MulticlassDataset`` / ``MultiClass_Dice_Loss`` and logs
    directly into ``args.logdir``; afterwards the best checkpoint is
    re-run over the validation loader.
    """
    set_random_seed(42)
    model = get_model('fpn50_multiclass')
    print("Loading model")
    model, device = UtilsFactory.prepare_model(model)

    train_records = pd.read_csv(args.train_df).to_dict('records')
    val_records = pd.read_csv(args.val_df).to_dict('records')

    dataset = MulticlassDataset(args.channels, args.dataset_path,
                                args.image_size, args.batch_size,
                                args.num_workers)
    loaders = dataset.create_loaders(train_records, val_records)

    criterion = MultiClass_Dice_Loss()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, milestones=[10, 20, 40], gamma=0.3)

    # model runner
    runner = SupervisedRunner()

    # model training
    runner.train(
        model=model,
        criterion=criterion,
        optimizer=optimizer,
        scheduler=scheduler,
        loaders=loaders,
        callbacks=[MultiClassDiceCallback()],
        logdir=args.logdir,
        num_epochs=args.epochs,
        verbose=True)

    # Collect predictions from the best checkpoint.
    infer_loader = collections.OrderedDict([("infer", loaders["valid"])])
    runner.infer(
        model=model,
        loaders=infer_loader,
        callbacks=[
            CheckpointCallback(resume=f"{args.logdir}/checkpoints/best.pth"),
            InferCallback()
        ])
def predict(data_path, model_weights_path, network, test_df_path, save_path,
            size, channels):
    """Run inference over a test CSV and write per-tile PNG masks.

    Args:
        data_path: root folder holding the dataset images.
        model_weights_path: checkpoint file containing 'model_state_dict'.
        network: model name understood by ``get_model``.
        test_df_path: CSV listing tiles (name, position, dataset_folder).
        save_path: output root; masks are written to ``save_path/predictions``.
        size: square tile side in pixels.
        channels: channel spec; the input conv is rebuilt to match its count.
    """
    model = get_model(network)
    # Bug fix: the original read ``args.channels``, but ``args`` is not a
    # parameter of this function — use the ``channels`` argument instead.
    model.encoder.conv1 = nn.Conv2d(
        count_channels(channels), 64,
        kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    checkpoint = torch.load(model_weights_path, map_location='cpu')
    model.load_state_dict(checkpoint['model_state_dict'])

    test_df = pd.read_csv(test_df_path)

    predictions_path = os.path.join(save_path, "predictions")
    if not os.path.exists(predictions_path):
        os.makedirs(predictions_path, exist_ok=True)
        print("Prediction directory created.")

    for _, image_info in tqdm(test_df.iterrows()):
        filename = '_'.join([image_info['name'], image_info['position']])
        image_path = get_filepath(
            data_path, image_info['dataset_folder'], 'images', filename,
            file_type='tiff')

        image_tensor = filter_by_channels(read_tensor(image_path), channels)
        if image_tensor.ndim == 2:
            # Single-channel tiles come back as (H, W); add a channel axis
            # so ToTensor produces a (1, H, W) tensor.
            image_tensor = np.expand_dims(image_tensor, -1)

        image = transforms.ToTensor()(image_tensor)
        prediction = model.predict(
            image.view(1, count_channels(channels), size, size))

        result = prediction.view(size, size).detach().numpy()
        # Scale [0, 1] probabilities to 8-bit grayscale for the PNG.
        cv.imwrite(get_filepath(predictions_path, filename, file_type='png'),
                   result * 255)
def load_model(network, model_weights_path):
    """Instantiate ``network`` and restore its weights from a checkpoint.

    The checkpoint is loaded onto whatever device ``prepare_model``
    selected; it must contain a ``'model_state_dict'`` entry.
    """
    net = get_model(network)
    net, device = UtilsFactory.prepare_model(net)
    state = torch.load(model_weights_path, map_location=device)
    net.load_state_dict(state['model_state_dict'])
    return net