def _eval():
    """Run sliding-window whole-slide inference with a checkpointed smp model.

    Builds the segmentation_models_pytorch model named by ``args.model_name``.
    It restores the weights from ``args.eval_model_pth`` via
    ``networktools.continue_train``. It then calls
    ``val.predict_wsi(model, iterator_test, 0)`` in eval mode under
    ``torch.no_grad()``.

    Side effects: overwrites ``args.val_save_pth`` and ``args.batch_size`` on
    the global ``args`` and moves the model to CUDA.
    """
    args.val_save_pth = 'data/val/out_val'
    args.batch_size = 12

    # model setup
    def activation(x):
        # Identity activation. The original body evaluated ``x`` without
        # returning it, so any caller of the activation received None.
        return x

    # getattr is the safe equivalent of eval('smp.' + name) for a plain
    # attribute name such as 'Unet'; it never executes arbitrary code.
    model = getattr(smp, args.model_name)(
        args.arch_encoder,
        encoder_weights='imagenet',
        classes=args.num_classes,
        activation=activation,
    )
    # continue_train takes an optimizer argument; only the restored model is
    # used here.
    model, _, _ = networktools.continue_train(
        model, optimizers.optimfn(args.optim, model), args.eval_model_pth, True)

    # datasets
    validation_params = {
        'ph': args.tile_h,  # patch height (y)
        'pw': args.tile_w,  # patch width (x)
        'sh': args.tile_stride_h,  # slide step (dy)
        'sw': args.tile_stride_w,  # slide step (dx)
    }
    iterator_test = ds.Dataset_wsis(args.raw_val1_pth, validation_params)
    model = model.cuda()

    # zoom in stage
    model.eval()
    with torch.no_grad():
        val.predict_wsi(model, iterator_test, 0)
    # NOTE(review): the original ended with a bare 'save preds' marker and
    # saves nothing itself; presumably val.predict_wsi writes its outputs
    # (e.g. under args.val_save_pth) - confirm, or implement saving here.
def _eval():
    """Run tumor-bed inference with a multi-head smp model.

    Builds the segmentation_models_pytorch model named by ``args.model_name``.
    It attaches a ``Classifier`` head (``args.num_classes`` outputs) and a
    ``Regressor`` head (1 output), both sized from
    ``model.encoder.out_shapes[0]``. It restores the weights from
    ``args.eval_model_pth`` and calls ``val.predict_tumorbed`` on the dataset
    built from ``args.raw_val_pth``.

    Side effects: overwrites ``args.val_save_pth`` and ``args.raw_val_pth`` on
    the global ``args`` and moves the model to CUDA.
    """
    # NOTE(review): machine-specific absolute paths; consider making these
    # command-line options.
    args.val_save_pth = '/home/ozan/remoteDir/Tumor Bed Detection Results/Ynet_segmentation_ozan'
    args.raw_val_pth = '/home/ozan/remoteDir/'

    # model setup
    def activation(x):
        # Identity activation. The original body evaluated ``x`` but returned
        # None.
        return x

    # getattr is the safe equivalent of eval('smp.' + name) for a plain
    # attribute name; it never executes arbitrary code.
    model = getattr(smp, args.model_name)(
        args.arch_encoder,
        encoder_weights='imagenet',
        classes=args.num_classes,
        activation=activation,
    )
    # The extra heads are attached before the checkpoint is restored,
    # presumably so that their saved weights can be loaded as well.
    model.classifier = Classifier(model.encoder.out_shapes[0], args.num_classes)
    model.regressor = Regressor(model.encoder.out_shapes[0], 1)
    # continue_train takes an optimizer argument; only the model is kept.
    model, _, _ = networktools.continue_train(
        model, optimizers.optimfn(args.optim, model), args.eval_model_pth, True
    )

    # datasets: the patch size is scaled by args.scan_resize, the slide step
    # is not.
    validation_params = {
        'ph': args.tile_h * args.scan_resize,  # patch height (y)
        'pw': args.tile_w * args.scan_resize,  # patch width (x)
        'sh': args.tile_stride_h,  # slide step (dy)
        'sw': args.tile_stride_w,  # slide step (dx)
    }
    iterator_test = ds.Dataset_wsis(args.raw_val_pth, validation_params)

    model = model.cuda()
    # Inference: put dropout/batch-norm layers in eval mode and disable
    # gradient tracking. This is harmless if predict_tumorbed does it too.
    model.eval()
    with torch.no_grad():
        val.predict_tumorbed(model, iterator_test, 0)
def train():
    """Train an smp segmentation model on 'data/same_sized_regions' tiles.

    Model, optimizer and start epoch come from ``networktools.continue_train``
    (``args.train_model_pth``, ``args.continue_train``). The loss is
    ``args.loss`` with inverse class-ratio weights.

    From epoch 1 onwards, every epoch is validated on the 'val' split. The run:

    * prints pixel accuracy and binary (class 0 vs. any other class) accuracy,
      averaged over validation batches;
    * saves image / label / prediction grids to
      ``data/res/<epoch>/<batch>.png``.

    A checkpoint is written to ``args.model_save_pth`` every
    ``args.save_models`` epochs.

    Side effects: sets ``args.cls_ratios`` on the global ``args``.
    """
    # model setup
    def activation(x):
        # Identity activation. The original body evaluated ``x`` but returned
        # None.
        return x

    # getattr is the safe equivalent of eval('smp.' + name) for a plain
    # attribute name; it never executes arbitrary code.
    model = getattr(smp, args.model_name)(
        args.arch_encoder,
        encoder_weights='imagenet',
        classes=args.num_classes,
        activation=activation,
    )
    optimizer = optimizers.optimfn(args.optim, model)
    model, optimizer, start_epoch = networktools.continue_train(
        model, optimizer, args.train_model_pth, args.continue_train)

    # losses: inverse class-ratio weights, rescaled so the largest weight is 1
    args.cls_ratios = preprocessing.cls_ratios_ssr(
        'data/same_sized_regions/train')
    cls_weights = 1.0 / args.cls_ratios
    cls_weights /= cls_weights.max()
    params = {
        'reduction': 'mean',
        'alpha': torch.Tensor(cls_weights),
        'gamma': 2,
        'scale_factor': 1 / 8,
        'ratio': 0.25,
        'ignore_index': 0,
    }
    lossfn = losses.lossfn(args.loss, params=params)

    # datasets
    iterator_train = ds.GenerateIterator('data/same_sized_regions/train')
    iterator_val = ds.GenerateIterator('data/same_sized_regions/val', eval=True)

    # Use one device for training and validation alike instead of mixing
    # conditional and unconditional .cuda() calls.
    cuda = torch.cuda.is_available()
    device = torch.device('cuda' if cuda else 'cpu')
    if cuda:
        model = model.cuda()
        lossfn = lossfn.cuda()

    rev_norm = preprocessing.NormalizeInverse(args.dataset_mean, args.dataset_std)

    from torchvision.utils import make_grid

    def show(img, ep, batch_it):
        """Save a CHW float tensor (expected in [0, 1]) as data/res/<ep>/<batch_it>.png."""
        from PIL import Image
        import numpy as np
        npimg = np.transpose(img.numpy(), (1, 2, 0))
        # Clip before scaling: out-of-range values would otherwise wrap
        # around when cast to uint8.
        npimg = np.clip(npimg, 0.0, 1.0)
        os.makedirs('data/res/{}/'.format(ep), exist_ok=True)
        Image.fromarray((255 * npimg).astype(np.uint8)).save(
            'data/res/{}/{}.png'.format(ep, batch_it))

    # current run train parameters
    print(args)
    for epoch in range(start_epoch, 1 + args.num_epoch):
        sum_loss_cls = 0
        progress_bar = tqdm(iterator_train, disable=False)
        for batch_it, (image, label) in enumerate(progress_bar):
            image = image.to(device)
            label = label.to(device)

            # pass images through the network (cls)
            pred_src = model(image)
            loss_cls = lossfn(pred_src, label)
            sum_loss_cls = sum_loss_cls + loss_cls.item()

            optimizer.zero_grad()
            loss_cls.backward()
            optimizer.step()

            # Running mean over the batch_it + 1 batches seen so far.
            progress_bar.set_description('ep. {}, cls loss: {:.3f}'.format(
                epoch, sum_loss_cls / (batch_it + 1)))

        # test model accuracy
        # NOTE(review): whole-slide validation via val.predict_wsis was
        # disabled in this script; only tile-level validation runs.
        if epoch >= 1:
            model.eval()
            with torch.no_grad():
                total_acc = 0.0
                binary_acc = 0.0
                n_val_batches = 0
                for val_it, (image, label) in enumerate(iterator_val):
                    image = image.to(device)
                    label = label.to(device)
                    pred_ = torch.argmax(model(image), 1)

                    # Stack label above prediction (concat along height). Then
                    # one-hot encode and drop class 0 (the loss's ignore_index).
                    # The eye matrix is created on the index's device, because
                    # indexing a CPU tensor with CUDA indices fails.
                    m = torch.cat((label, pred_), dim=-2)
                    m = torch.eye(args.num_classes, device=m.device)[m][..., 1:]
                    m = m.permute(0, 3, 1, 2)
                    for ij in range(image.size(0)):
                        image[ij, ...] = rev_norm(image[ij, ...])
                    # NOTE(review): this concat requires args.num_classes - 1
                    # to equal the number of image channels.
                    m = torch.cat((image, m), dim=-2)
                    show(make_grid(m.cpu()), epoch, val_it)

                    total_acc += torch.mean((pred_ == label).float()).item()
                    binary_acc += torch.mean(
                        ((pred_ > 0) == (label > 0)).float()).item()
                    n_val_batches += 1

                # Average over the number of batches (not the last index);
                # guard against an empty validation set.
                n_val_batches = max(n_val_batches, 1)
                print('Acc {:.2f}, binary acc {:.2f}'.format(
                    total_acc / n_val_batches, binary_acc / n_val_batches))
            model.train()

        if args.save_models > 0 and epoch % args.save_models == 0:
            state = {
                'epoch': epoch,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'config': args
            }
            torch.save(
                state,
                '{}/model_{}_{}.pt'.format(args.model_save_pth,
                                           args.arch_encoder, epoch))
# NOTE(review): fragment. The enclosing definition is not visible here, and
# presumably neither is the per-tile loop that produces tile_id, center_pts,
# perim_coords, svspth, scan_level and foreground_indices. Indentation
# therefore cannot be restored. As collapsed onto one physical line,
# everything after the first '#' below is a comment. The statements it holds
# are described here as originally intended.
#
# Tile metadata: for each tile_id, store:
#   * its centre points ('cnt_xy');
#   * perimeter coordinates ('perim_xy');
#   * presumably its slide path ('wsipath');
#   * its scan level;
#   * its foreground indices;
#   * its own id.
#
# Evaluation stage:
#   * Builds resnets_shift.resnet18(True) on CUDA and restores its weights from
#     args.eval_model_pth via networktools.continue_train. The optimizer is
#     only passed to that call and is then discarded.
#   * Wraps the collected metadata in GenerateIterator_eval, which yields
#     (images, tile_ids) batches.
#   * Under torch.no_grad(), takes the argmax over dim 1 of the model output as
#     each tile's predicted class. It writes that class into pred_mask (same
#     shape and dtype as labels) at the tile's 'foreground_indices'.
#   NOTE(review): where tiles' foreground_indices overlap, the tile processed
#   later overwrites earlier predictions.
metadata[tile_id] = { 'cnt_xy': center_pts, 'perim_xy': perim_coords, 'wsipath': svspth, 'scan_level': scan_level, 'foreground_indices': foreground_indices, 'tile_id': tile_id, } ''' evaluation stage ''' # load model model = resnets_shift.resnet18(True).cuda() optimizer = optimizers.optimfn(args.optim, model) # unused model, _, _ = networktools.continue_train(model, optimizer, args.eval_model_pth, True) model.eval() # generate dataset from points iterator_val = GenerateIterator_eval(metadata) # pass through dataset pred_mask = np.zeros_like(labels) with torch.no_grad(): for batch_it, (images, tile_ids) in enumerate(iterator_val): images = images.cuda() pred_ensemble = model(images) pred_ensemble = torch.argmax(pred_ensemble, 1).cpu().numpy() for tj, tile_id in enumerate(tile_ids.numpy()): pred_mask[metadata[tile_id] ['foreground_indices']] = pred_ensemble[tj]
def train():
    """Jointly train an smp segmentation decoder and an image-level classifier.

    Builds the smp model named by ``args.model_name`` and attaches a
    ``Classifier`` head on the first encoder output. Each training batch from
    ``ds.GenerateIterator(args.train_image_pth)`` carries a per-sample
    ``is_cls`` flag. Flagged samples train the classifier on ``encoding[0]``
    against ``cls_code``. The remaining samples train the decoder against the
    pixel ``label``. Both share a single encoder pass.

    Whole-slide validation (``val.predict_wsis``) runs every
    ``args.validate_model`` epochs. A checkpoint is written every
    ``args.save_models`` epochs.

    Side effects: overwrites ``args.val_save_pth`` on the global ``args`` and
    moves the model and losses to CUDA.
    """
    args.val_save_pth = 'data/val/out2'

    # model setup
    def activation(x):
        # Identity activation. The original body evaluated ``x`` but returned
        # None.
        return x

    # getattr is the safe equivalent of eval('smp.' + name) for a plain
    # attribute name; it never executes arbitrary code.
    model = getattr(smp, args.model_name)(
        args.arch_encoder,
        encoder_weights='imagenet',
        classes=args.num_classes,
        activation=activation,
    )
    # The classifier must be attached before the optimizer is created so
    # that its parameters are optimized too.
    model.classifier = Classifier(model.encoder.out_shapes[0], args.num_classes)
    optimizer = optimizers.optimfn(args.optim, model)
    model, optimizer, start_epoch = networktools.continue_train(
        model, optimizer, args.train_model_pth, args.continue_train)

    # losses: separate class weights for the classification and segmentation
    # targets
    cls_weights_cls, cls_weights_seg = preprocessing.cls_weights(
        args.train_image_pth)
    params = {
        'reduction': 'mean',
        'alpha': torch.Tensor(cls_weights_cls),
        'xent_ignore': -1,
    }
    lossfn_cls = losses.lossfn(args.loss, params).cuda()
    params = {
        'reduction': 'mean',
        'alpha': torch.Tensor(cls_weights_seg),
        'xent_ignore': -1,
    }
    lossfn_seg = losses.lossfn(args.loss, params).cuda()

    # datasets
    validation_params = {
        'ph': args.tile_h * args.scan_resize,  # patch height (y)
        'pw': args.tile_w * args.scan_resize,  # patch width (x)
        'sh': args.tile_stride_h,  # slide step (dy)
        'sw': args.tile_stride_w,  # slide step (dx)
    }
    iterator_train = ds.GenerateIterator(args.train_image_pth,
                                         duplicate_dataset=1)
    iterator_val = ds.Dataset_wsis(args.raw_val_pth, validation_params)

    model = model.cuda()

    # current run train parameters
    print(args)
    for epoch in range(start_epoch, 1 + args.num_epoch):
        sum_loss = 0
        progress_bar = tqdm(iterator_train, disable=False)
        for batch_it, (image, label, is_cls, cls_code) in enumerate(progress_bar):
            image = image.cuda()
            label = label.cuda()
            is_cls = is_cls.type(torch.bool).cuda()
            cls_code = cls_code.cuda()

            # One encoder pass feeds both heads.
            encoding = model.encoder(image)
            loss = 0
            if is_cls.any():
                pred_cls = model.classifier(encoding[0][is_cls, ...])
                loss = loss + lossfn_cls(pred_cls, cls_code[is_cls])
            if (~is_cls).any():
                pred_seg = model.decoder([x[~is_cls, ...] for x in encoding])
                loss = loss + lossfn_seg(pred_seg, label[~is_cls])
            sum_loss = sum_loss + loss.item()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Running mean over the batch_it + 1 batches seen so far.
            progress_bar.set_description('ep. {}, cls loss: {:.3f}'.format(
                epoch, sum_loss / (batch_it + 1)))

        # test model accuracy
        if args.validate_model > 0 and epoch % args.validate_model == 0:
            val.predict_wsis(model, iterator_val, epoch)

        if args.save_models > 0 and epoch % args.save_models == 0:
            state = {
                'epoch': epoch,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'config': args
            }
            torch.save(
                state,
                '{}/model_{}_{}.pt'.format(args.model_save_pth,
                                           args.arch_encoder, epoch))
def train():
    """Fine-tune an ImageNet-pretrained ``pretrainedmodels`` classifier on SSR tiles.

    Replaces ``model.last_linear`` with an ``args.num_classes``-way linear
    layer. Model, optimizer and start epoch come from
    ``networktools.continue_train``. Training uses ``args.loss`` with inverse
    class-ratio weights on 'data/ssr/train'.

    From epoch 1 onwards, every epoch the run evaluates on 'data/ssr/val' and
    prints overall accuracy and per-class accuracy (recall). A checkpoint is
    written to ``args.model_save_pth`` every ``args.save_models`` epochs.

    Side effects: sets ``args.cls_ratios`` on the global ``args``.
    """
    # model setup
    model = pretrainedmodels.__dict__[args.arch_encoder](num_classes=1000,
                                                         pretrained='imagenet')
    model.last_linear = torch.nn.Linear(model.last_linear.in_features,
                                        args.num_classes)
    optimizer = optimizers.optimfn(args.optim, model)
    model, optimizer, start_epoch = networktools.continue_train(
        model, optimizer, args.train_model_pth, args.continue_train)

    # losses: inverse class-ratio weights, rescaled so the largest weight is 1
    args.cls_ratios = preprocessing.cls_ratios_ssr('data/ssr/train',
                                                   option='classification')
    cls_weights = 1.0 / args.cls_ratios
    cls_weights /= cls_weights.max()
    params = {
        'reduction': 'mean',
        'alpha': torch.Tensor(cls_weights),
        'gamma': 2,
        'scale_factor': 1 / 8,
        'ratio': 0.25,
        'ignore_index': 0,
    }
    lossfn = losses.lossfn(args.loss, params=params)

    # datasets
    iterator_train = ds.GenerateIterator_cls('data/ssr/train')
    iterator_val = ds.GenerateIterator_cls('data/ssr/val', eval=True)

    # Use one device for training and validation alike instead of mixing
    # conditional and unconditional .cuda() calls.
    cuda = torch.cuda.is_available()
    device = torch.device('cuda' if cuda else 'cpu')
    if cuda:
        model = model.cuda()
        lossfn = lossfn.cuda()

    # current run train parameters
    print(args)
    for epoch in range(start_epoch, 1 + args.num_epoch):
        sum_loss_cls = 0
        progress_bar = tqdm(iterator_train, disable=False)
        for batch_it, (image, label) in enumerate(progress_bar):
            image = image.to(device)
            label = label.to(device)

            # pass images through the network (cls)
            pred_src = model(image)
            loss_cls = lossfn(pred_src, label)
            sum_loss_cls = sum_loss_cls + loss_cls.item()

            optimizer.zero_grad()
            loss_cls.backward()
            optimizer.step()

            # Running mean over the batch_it + 1 batches seen so far.
            progress_bar.set_description('ep. {}, cls loss: {:.3f}'.format(
                epoch, sum_loss_cls / (batch_it + 1)))

        # test model accuracy
        if epoch >= 1:
            model.eval()
            preds, gts = [], []
            with torch.no_grad():
                for image, label in iterator_val:
                    pred_ = torch.argmax(model(image.to(device)), 1)
                    preds.extend(pred_.cpu().numpy())
                    gts.extend(label.numpy())
            preds = np.asarray(preds)
            gts = np.asarray(gts)
            total_acc = np.mean(gts == preds)

            # Pin the label set so that row i is always class i, even when a
            # class is absent from this validation pass.
            cfs = confusion_matrix(gts, preds,
                                   labels=np.arange(args.num_classes))
            # Per-class recall: correct / support. Classes with no samples
            # give nan without emitting divide warnings.
            with np.errstate(divide='ignore', invalid='ignore'):
                cls_acc = np.diag(cfs) / cfs.sum(1)
            cls_acc = ['{:.2f}'.format(el) for el in cls_acc]
            print('Ep. {}, Acc {:.2f}, Classwise acc. {}'.format(
                epoch, total_acc, cls_acc))
            model.train()

        if args.save_models > 0 and epoch % args.save_models == 0:
            state = {
                'epoch': epoch,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'config': args
            }
            torch.save(
                state,
                '{}/model_{}_{}.pt'.format(args.model_save_pth,
                                           args.arch_encoder, epoch))