# One training epoch followed by a validation pass.
# NOTE(review): this fragment appears to sit inside an outer per-epoch loop
# that is not visible here; `train_loader`, `val_loader`, `model`, `opt`,
# `scheduler`, `device`, `losses` and the metric lists come from that context.
steps = 0

# ---- training -------------------------------------------------------------
model.train()  # nothing in the batch loop changes the mode, so set it once
for batch_idx, (images, masks) in enumerate(train_loader):
    images = images.to(device)
    masks = masks.to(device)
    opt.zero_grad()
    preds = model(images)
    loss = losses.dice_loss(preds, masks)
    loss.backward()
    # This call was commented out in the original, so gradients were
    # computed every batch but the weights were never updated.
    opt.step()
    train_losses.append(loss.item())
    train_dsc.append(losses.dice_score(preds, masks).item())

# ---- validation -----------------------------------------------------------
# The original attached this to the loop as a `for ... else:` clause; the loop
# has no `break`, so the `else` always ran and sequential code is equivalent.
model.eval()
epoch_val_losses = []
with torch.no_grad():
    for inputs, masks in val_loader:
        inputs, masks = inputs.to(device), masks.to(device)
        preds = model(inputs)  # call the module (not .forward) so hooks run
        loss = losses.dice_loss(preds, masks)
        batch_val_loss = loss.item()
        val_losses.append(batch_val_loss)
        epoch_val_losses.append(batch_val_loss)
        val_dsc.append(losses.dice_score(preds, masks).item())

# Step the scheduler on the epoch-mean validation loss rather than on the loss
# of whichever batch happened to come last. The scheduler is called with a
# metric, so it is presumably ReduceLROnPlateau-like — TODO confirm.
if epoch_val_losses:
    scheduler.step(sum(epoch_val_losses) / len(epoch_val_losses))
# Evaluate `net` on the test split at the end of an epoch and report metrics.
# NOTE(review): `epoch`, `params`, `train_loss`, `test_loss`, `net`,
# `loss_fct`, `test_loader`, `losses` and `AverageMeter` come from code not
# shown here. Only the first sample of each batch (index 0) is scored/saved.
test_dice_score = AverageMeter()
use_cuda = torch.cuda.is_available()
is_last_epoch = epoch == params.N_EPOCHS - 1
with torch.no_grad():
    for i, (inputs, labels) in enumerate(test_loader):
        if use_cuda:
            inputs = inputs.cuda(non_blocking=True)
            labels = labels.cuda(non_blocking=True)
        outputs = net(inputs)
        test_loss.append(loss_fct(outputs, labels).item())

        # `.cpu()` returns the tensor itself when it already lives on the CPU,
        # so one code path serves both devices.
        channel_one = outputs[0, 1, :, :, :].cpu().numpy()
        ground_truth = labels[0, 0, :, :, :].cpu().numpy()
        res = np.round(channel_one).astype(int)  # rounded to integer labels
        test_dice_score.append(losses.dice_score(res, ground_truth))

        if is_last_epoch:
            np.save("./last_epoch_results/test_" + str(i) + ".npy", res)

print("epoch " + str(epoch + 1) + ": %.3f, %.3f, %.3f" %
      (train_loss.avrg, test_loss.avrg, test_dice_score.avrg))
print('Finished Training')
def main():
    """Train, or resume training of, the configured segmentation model.

    Configuration is read from module-level names (``args``, ``modeltype``,
    ``use_multiinput_architecture``, ``batch_size``, ``LR``, ``train_epoch``,
    ``save_every``, ``valid_every``, output directories, ...).

    If ``args.restart_training != 'true'``, weights are restored from
    ``args.model_path``. If, in addition, ``load_old_lists == True`` and
    ``args.restart_training == 'false'``, the metric histories and the start
    epoch are restored from the same checkpoint.

    Each epoch the model is trained on the ``'train'`` split. The
    sample-weighted mean loss and Dice are recorded, plots and pickled
    histories are written, and a checkpoint is saved every ``save_every``
    epochs. Every ``valid_every`` epochs the ``'valid'`` split is evaluated,
    and ``model_best.pth`` is rewritten when the mean validation Dice improves
    (or on the first validation of a fresh history).

    Raises:
        ValueError: if ``modeltype`` or ``use_multiinput_architecture`` has an
            unsupported value (previously an ``UnboundLocalError`` on
            ``model``).
    """

    def _build_model():
        """Instantiate the configured network as float64 on ``device``."""
        if modeltype not in ('unet', 'resunet'):
            raise ValueError('Unsupported modeltype: %r' % (modeltype,))
        residual = modeltype == 'resunet'
        if use_multiinput_architecture is True:
            net = Attention_UNet(n_classes=n_classes, padding=True,
                                 up_mode='upconv', batch_norm=True,
                                 residual=residual, wf=wf,
                                 use_attention=use_attention)
        elif use_multiinput_architecture is False:
            net = UNet(n_classes=n_classes, padding=True, depth=model_depth,
                       wf=wf, up_mode='upconv', batch_norm=True,
                       residual=residual)
        else:
            raise ValueError('use_multiinput_architecture must be True or '
                             'False, got %r' % (use_multiinput_architecture,))
        return net.double().to(device)

    def _num_batches(n_items):
        """Return the number of batches covering ``n_items`` (ceiling division)."""
        return (n_items + batch_size - 1) // batch_size

    def _plot_curve(xs, ys, label, ylabel, filename):
        """Draw one metric curve and save it as ``plots_dir + filename``."""
        plt.plot(xs, ys, label=label, color='red', marker='o',
                 markerfacecolor='yellow', markersize=5)
        plt.xlabel('Epoch')
        plt.ylabel(ylabel)
        plt.savefig(plots_dir + filename)
        plt.clf()

    def _dump_pickle(obj, filename):
        """Pickle ``obj`` to ``plots_pickle_dir + filename``; file closed even on error."""
        with open(plots_pickle_dir + filename, 'wb') as fh:
            pickle.dump(obj, fh)

    def _save_checkpoint(epoch_num, path):
        """Save the model weights plus every metric history list to ``path``."""
        torch.save(
            {
                'epochs': epoch_num,
                'batchsize': batch_size,
                'train_loss_list_epoch': loss_list_train_epoch,
                'train_dice_score_list_epoch': dice_score_list_train_epoch,
                'train_loss_index_epoch': epoch_data_list,
                'valid_loss_list': loss_list_validation,
                'valid_dice_score_list': dice_score_list_validation,
                'valid_dice_score_list_0': dice_score_list_validation_0,
                'valid_dice_score_list_1': dice_score_list_validation_1,
                'valid_loss_index': loss_list_validation_index,
                'model_state_dict': model.state_dict(),
            }, path)

    # ---- model -------------------------------------------------------------
    model = _build_model()
    if args.restart_training != 'true':
        # Resuming: restore the weights saved at args.model_path.
        checkpoint = torch.load(args.model_path,
                                map_location=lambda storage, loc: storage)
        model.load_state_dict(checkpoint['model_state_dict'])

    # ---- data --------------------------------------------------------------
    train_loader = dataloader_cxr.DataLoader(data_path,
                                             dataloader_type='train',
                                             batchsize=batch_size,
                                             device=device,
                                             image_resolution=image_resolution)
    print('trainloader loaded')
    valid_loader = dataloader_cxr.DataLoader(data_path,
                                             dataloader_type='valid',
                                             batchsize=batch_size,
                                             device=device,
                                             image_resolution=image_resolution)
    print('validloader loaded')

    # ---- metric histories --------------------------------------------------
    # Each history starts with a None placeholder; plots and best-model
    # bookkeeping skip index 0.
    loss_list_train_epoch = [None]
    dice_score_list_train_epoch = [None]
    epoch_data_list = [None]
    loss_list_validation = [None]
    loss_list_validation_index = [None]
    dice_score_list_validation = [None]
    dice_score_list_validation_0 = [None]
    dice_score_list_validation_1 = [None]
    best_model_accuracy = None  # set on history load or at first validation

    epoch_old = 0
    if load_old_lists == True and args.restart_training == 'false':
        epoch_old = checkpoint['epochs']
        # A trailing None means the saved history is empty; keep the fresh [None].
        if checkpoint['train_loss_list_epoch'][-1] is not None:
            dice_score_list_train_epoch = checkpoint[
                'train_dice_score_list_epoch']
            loss_list_train_epoch = checkpoint['train_loss_list_epoch']
            epoch_data_list = checkpoint['train_loss_index_epoch']
        if checkpoint['valid_loss_list'][-1] is not None:
            loss_list_validation = checkpoint['valid_loss_list']
            loss_list_validation_index = checkpoint['valid_loss_index']
            dice_score_list_validation = checkpoint['valid_dice_score_list']
            dice_score_list_validation_0 = checkpoint[
                'valid_dice_score_list_0']
            dice_score_list_validation_1 = checkpoint[
                'valid_dice_score_list_1']
            best_model_accuracy = np.max(dice_score_list_validation[1:])

    total_idx_train = _num_batches(len(train_loader.data_list))
    total_idx_valid = _num_batches(len(valid_loader.data_list))

    # ---- optimisation ------------------------------------------------------
    # NOTE(review): the original computed a decayed rate for resumed runs,
    # LR * scheduler_gamma ** (epoch_old // scheduler_step_size), and then
    # unconditionally reset it to LR. That override is preserved, so resumed
    # runs restart from the base learning rate — confirm this is intended.
    LR_ = LR
    optimizer = optim.Adam(model.parameters(), lr=LR_)
    # NOTE(review): StepLR decays every `scheduler_step_size` calls to
    # step(), but step() is only called once every 10 epochs below, so the
    # effective decay period is 10 * scheduler_step_size epochs.
    scheduler = StepLR(optimizer,
                       step_size=scheduler_step_size,
                       gamma=scheduler_gamma)
    # Loop-invariant: build the class-weight tensor once, not per batch.
    class_weights = torch.Tensor([gamma0, gamma1]).double().to(device)

    for epoch in range(epoch_old, train_epoch):
        if (epoch + 1) % 10 == 0:
            scheduler.step()

        # ---- training pass ----
        epoch_loss = 0.0
        epoch_dice_score = 0.0
        train_count = 0
        model.train()
        for idx in range(total_idx_train):
            optimizer.zero_grad()
            batch_images_input, batch_label_input = train_loader[idx]
            output = model(batch_images_input)
            if use_multiinput_architecture is False:
                loss = losses.dice_loss(output, batch_label_input,
                                        weights=class_weights)
                final_output = output
            else:
                # Multi-input model: the loss sees every output; the score
                # uses only the last one.
                loss = losses.dice_loss_deep_supervised(
                    output, batch_label_input, weights=class_weights)
                final_output = output[-1]
            loss.backward()
            optimizer.step()
            score = losses.dice_score(final_output, batch_label_input)
            n_batch = batch_images_input.shape[0]
            # Weight by batch size so the epoch mean is per-sample.
            epoch_dice_score += (score.sum().item() / score.size(0)) * n_batch
            epoch_loss += loss.item() * n_batch
            train_count += n_batch

        loss_list_train_epoch.append(epoch_loss / train_count)
        epoch_data_list.append(epoch + 1)
        dice_score_list_train_epoch.append(epoch_dice_score / train_count)
        print(
            'Epoch %d Training Loss: %.3f Dice Score: %.3f' %
            (epoch + 1, loss_list_train_epoch[-1],
             dice_score_list_train_epoch[-1]), ' Time:',
            datetime.datetime.now())
        _plot_curve(epoch_data_list[1:], loss_list_train_epoch[1:],
                    "Training", 'Training Loss', '/train_loss_plot.png')
        _plot_curve(epoch_data_list[1:], dice_score_list_train_epoch[1:],
                    "Training", 'Training Dice Score',
                    '/train_dice_score_plot.png')
        _dump_pickle(loss_list_train_epoch, "/loss_list_train.npy")
        _dump_pickle(epoch_data_list, "/epoch_list_train.npy")
        _dump_pickle(dice_score_list_train_epoch,
                     "/dice_score_list_train_epoch.npy")

        if (epoch + 1) % save_every == 0:
            print('Saving model at %d epoch' % (epoch + 1), ' Time:',
                  datetime.datetime.now())
            _save_checkpoint(epoch + 1,
                             model_checkpoint_dir + '/model_%d.pth' % (epoch + 1))

        # ---- validation pass ----
        if (epoch + 1) % valid_every == 0:
            model.eval()
            optimizer.zero_grad()  # drop training gradients during eval
            valid_count = 0
            total_loss_valid = 0.0
            valid_dice_score = 0.0
            valid_dice_score_0 = 0.0
            valid_dice_score_1 = 0.0
            with torch.no_grad():
                for idx in range(total_idx_valid):
                    batch_images_input, batch_label_input = valid_loader[idx]
                    output = model(batch_images_input)
                    if use_multiinput_architecture is False:
                        final_output = output
                    else:
                        final_output = output[-1]
                    loss = losses.dice_loss(final_output, batch_label_input)
                    score = losses.dice_score(final_output, batch_label_input)
                    n_batch = batch_images_input.shape[0]
                    total_loss_valid += loss.item() * n_batch
                    valid_count += n_batch
                    valid_dice_score += (score.sum().item() /
                                         score.size(0)) * n_batch
                    valid_dice_score_0 += score[0].item() * n_batch
                    valid_dice_score_1 += score[1].item() * n_batch

            loss_list_validation.append(total_loss_valid / valid_count)
            dice_score_list_validation.append(valid_dice_score / valid_count)
            dice_score_list_validation_0.append(valid_dice_score_0 /
                                                valid_count)
            dice_score_list_validation_1.append(valid_dice_score_1 /
                                                valid_count)
            loss_list_validation_index.append(epoch + 1)
            print(
                'Epoch %d Valid Loss: %.3f' %
                (epoch + 1, loss_list_validation[-1]), ' Time:',
                datetime.datetime.now())
            print('Valid Dice Score: ', dice_score_list_validation[-1],
                  ' Valid Dice Score 0: ', dice_score_list_validation_0[-1],
                  ' Valid Dice Score 1: ', dice_score_list_validation_1[-1])
            _plot_curve(loss_list_validation_index[1:],
                        loss_list_validation[1:], "Validation",
                        'Validation Loss', '/valid_loss_plot.png')
            _plot_curve(loss_list_validation_index[1:],
                        dice_score_list_validation[1:], "Validation",
                        'Validation Dice Score', '/valid_dice_score_plot.png')
            _plot_curve(loss_list_validation_index[1:],
                        dice_score_list_validation_0[1:], "Validation",
                        'Validation Dice Score',
                        '/valid_dice_score_0_plot.png')
            _plot_curve(loss_list_validation_index[1:],
                        dice_score_list_validation_1[1:], "Validation",
                        'Validation Dice Score',
                        '/valid_dice_score_1_plot.png')
            _dump_pickle(loss_list_validation, "/loss_list_validation.npy")
            _dump_pickle(loss_list_validation_index,
                         "/index_list_validation.npy")
            _dump_pickle(dice_score_list_validation,
                         "/dice_score_list_validation.npy")

            # Fewer than 3 entries ([None, first]) means this is the first
            # validation of the history, so it is the best so far by definition.
            if (len(loss_list_validation) < 3 or
                    dice_score_list_validation[-1] > best_model_accuracy):
                best_model_accuracy = dice_score_list_validation[-1]
                _save_checkpoint(epoch + 1,
                                 model_checkpoint_dir + '/model_best.pth')
# Load the last-epoch test predictions saved during training, score them
# against the ground-truth labels and start building per-slice overlays.
# NOTE(review): `dataset_generator`, `basic_image_size` (a list, since it is
# concatenated with one below), `res_dir` and `losses` come from code not
# shown here.
test_dataset = dataset_generator.getTestDataset(0, 1)
test_size = len(test_dataset)
stack_shape = tuple([test_size] + basic_image_size)

test_results = np.zeros(stack_shape, dtype=int)
for i in range(test_size):
    test_results[i, :, :, :] = np.load(
        os.path.join(res_dir, "test_" + str(i) + ".npy"))

# NOTE(review): `images` has an integer dtype, so fractional intensities are
# truncated on assignment — confirm the images are integer-valued.
images = np.zeros(stack_shape, dtype=int)
GT_labels = np.zeros(stack_shape, dtype=int)
for i in range(test_size):
    # Fetch each sample once; the original indexed the dataset twice per
    # sample (once for the image, once for the label), loading it twice.
    sample = test_dataset[i]
    images[i, :, :, :] = sample[0].numpy()[0]
    GT_labels[i, :, :, :] = sample[1].numpy()[0]

dice_scores = np.array([
    losses.dice_score(test_results[i], GT_labels[i])
    for i in range(test_size)
])

img_id = 9
slice_id_list = [7, 10, 13]
print("AVERAGE DICE:", np.mean(dice_scores))
print()
print("Original image index: ", test_dataset.indices[img_id])
print("dice = ", dice_scores[img_id])

comb = []
# NOTE(review): the body of this loop appears truncated in this file — `comb`
# is never filled and `image_GT` is never used in the visible code.
for slice_id in slice_id_list:
    margin_size = 5
    color_scale_GT = [0.3, 0.0, 0.0]
    image_GT = np.zeros((basic_image_size[1], basic_image_size[2], 3))
def main():
    """Predict segmentation masks for one split and optionally save them as PNGs.

    Configuration is read from module-level names (``modeltype``,
    ``use_multiinput_architecture``, ``dataloader_type``, ``batch_size``,
    ``image_resolution``, ``generate_mask``, ``save_path``,
    ``args.model_path``, ...).

    The network is rebuilt, its weights are loaded from ``args.model_path``,
    and it is run batch-by-batch in eval mode. Each prediction is the argmax
    over dim 1 of the network output. It is passed to
    ``remove_small_regions`` together with a size threshold of 2% of the
    image area.

    When the split is not ``"test"`` (labels available), the sample-weighted
    mean Dice score and the scores at indices 0 and 1 are printed. When
    ``generate_mask is True``, each prediction is resized with order-0
    (nearest-neighbour) zoom to ``valid_loader.original_size_array[name]``
    and written to ``save_path + '/Pred_mask_<name>.png'``.

    Raises:
        ValueError: if ``modeltype`` or ``use_multiinput_architecture`` has an
            unsupported value (previously an ``UnboundLocalError``).
    """
    if modeltype not in ('unet', 'resunet'):
        raise ValueError('Unsupported modeltype: %r' % (modeltype,))
    residual = modeltype == 'resunet'
    if use_multiinput_architecture is False:
        model = UNet(n_classes=n_classes, padding=True, depth=model_depth,
                     up_mode='upconv', batch_norm=True,
                     residual=residual).double().to(device)
    elif use_multiinput_architecture is True:
        model = Attention_UNet(n_classes=n_classes, padding=True,
                               up_mode='upconv', batch_norm=True,
                               residual=residual, wf=wf,
                               use_attention=use_attention).double().to(device)
    else:
        raise ValueError('use_multiinput_architecture must be True or False, '
                         'got %r' % (use_multiinput_architecture,))
    checkpoint = torch.load(args.model_path,
                            map_location=lambda storage, loc: storage)
    model.load_state_dict(checkpoint['model_state_dict'])

    valid_loader = dataloader_cxr.DataLoader(data_path,
                                             dataloader_type=dataloader_type,
                                             batchsize=batch_size,
                                             device=device,
                                             image_resolution=image_resolution)
    n_images = len(valid_loader.data_list)
    # Ceiling division: the last batch may be partial.
    total_idx_valid = (n_images + batch_size - 1) // batch_size
    has_labels = valid_loader.dataloader_type != "test"
    min_region_size = 0.02 * np.prod(image_resolution)

    model.eval()
    prediction_array = np.zeros(
        (n_images, image_resolution[0], image_resolution[1]))
    if has_labels:
        # Ground-truth masks collected alongside predictions; not read again
        # in this function.
        input_mask_array = np.zeros(
            (n_images, image_resolution[0], image_resolution[1]))

    valid_count = 0
    valid_dice_score = 0.0
    valid_dice_score_0 = 0.0
    valid_dice_score_1 = 0.0
    with torch.no_grad():
        for idx in range(total_idx_valid):
            if has_labels:
                batch_images_input, batch_label_input = valid_loader[idx]
            else:
                batch_images_input = valid_loader[idx]
            output = model(batch_images_input)
            # The multi-input architecture returns a sequence of outputs;
            # only the last one is used.
            if use_multiinput_architecture is False:
                final_output = output
            else:
                final_output = output[-1]

            # NumPy clips slices at the array end, so this slice also covers
            # the final, possibly partial, batch.
            batch_slice = slice(idx * batch_size, (idx + 1) * batch_size)
            # `bool`, not `np.bool`: that alias was removed in NumPy 1.24.
            temp_image = torch.max(
                final_output, 1)[1].detach().cpu().numpy().astype(bool)
            prediction_array[batch_slice] = remove_small_regions(
                temp_image, min_region_size)

            if has_labels:
                input_mask_array[batch_slice] = batch_label_input.detach(
                ).cpu().numpy().astype(np.uint8)
                n_batch = batch_images_input.shape[0]
                valid_count += n_batch
                score = losses.dice_score(final_output, batch_label_input)
                # Weight by batch size so the final means are per-sample.
                valid_dice_score += (score.sum().item() /
                                     score.size(0)) * n_batch
                valid_dice_score_0 += score[0].item() * n_batch
                valid_dice_score_1 += score[1].item() * n_batch

    if has_labels:
        valid_dice_score = valid_dice_score / valid_count
        valid_dice_score_0 = valid_dice_score_0 / valid_count
        valid_dice_score_1 = valid_dice_score_1 / valid_count

    if generate_mask is True:
        for i, files in enumerate(valid_loader.data_list):
            temp_mask = prediction_array[i].astype(int)
            # order=0: nearest-neighbour, so labels stay integral.
            temp_mask = ndimage.zoom(
                temp_mask,
                np.asarray(valid_loader.original_size_array[files]) /
                np.asarray(temp_mask.shape),
                order=0)
            io.imsave(save_path + '/Pred_mask_' + files.split('.')[0] + '.png',
                      temp_mask)

    if has_labels:
        print('Valid Dice Score: ', valid_dice_score, ' Valid Dice Score 0: ',
              valid_dice_score_0, ' Valid Dice Score 1: ', valid_dice_score_1)
def main():
    """Predict segmentation masks for one split and optionally save NIfTI files.

    Configuration is read from module-level names (``modeltype``,
    ``use_multiinput_architecture``, ``dataloader_type``, ``batch_size``,
    ``image_resolution``, ``invert``, ``remove_wires``, ``classes``,
    ``clubbed``, ``exclude_0``, ``generate_mask``, ``save_path``,
    ``args.model_path``, ...).

    The network is rebuilt, its weights are loaded from ``args.model_path``,
    and it is run batch-by-batch in eval mode. Each prediction is the argmax
    over dim 1 of the output. It is passed to ``remove_small_regions``
    together with a size threshold of 2% of the image area.

    When the split is not ``"test"``, the sample-weighted mean Dice score is
    printed. The scores at indices 0 and 1 are also printed when they are
    tracked: ``classes`` is exactly {0, 1} and ``clubbed`` is empty.

    When ``generate_mask is True``, each prediction is:
    - resized (order-0 zoom) to ``original_size_array[name]``;
    - cast to UInt8;
    - resampled with nearest-neighbour interpolation to the loader's recorded
      spacing, size and origin;
    - written to ``save_path + '/Pred_mask_<stem>.nii.gz'``.

    Raises:
        ValueError: if ``modeltype`` or ``use_multiinput_architecture`` has an
            unsupported value (previously an ``UnboundLocalError``).
    """
    import os

    def _strip_image_extension(filename):
        """Return ``filename`` without its extension, treating ``.nii.gz`` as one.

        The original used ``files[:-4]``, which only works for 4-character
        extensions and turned ``x.nii.gz`` into ``x.ni``.
        """
        if filename.endswith('.nii.gz'):
            return filename[:-len('.nii.gz')]
        return os.path.splitext(filename)[0]

    if modeltype not in ('unet', 'resunet'):
        raise ValueError('Unsupported modeltype: %r' % (modeltype,))
    residual = modeltype == 'resunet'
    if use_multiinput_architecture is False:
        model = UNet(n_classes=n_classes, padding=True, depth=model_depth,
                     up_mode='upsample', batch_norm=True,
                     residual=residual).double().to(device)
    elif use_multiinput_architecture is True:
        model = Attention_UNet(n_classes=n_classes, padding=True,
                               up_mode='upconv', batch_norm=True,
                               residual=residual, wf=wf,
                               use_attention=use_attention).double().to(device)
    else:
        raise ValueError('use_multiinput_architecture must be True or False, '
                         'got %r' % (use_multiinput_architecture,))
    checkpoint = torch.load(args.model_path,
                            map_location=lambda storage, loc: storage)
    model.load_state_dict(checkpoint['model_state_dict'])

    valid_loader = dataloader_cxr.DataLoader(data_path,
                                             dataloader_type=dataloader_type,
                                             batchsize=batch_size,
                                             device=device,
                                             image_resolution=image_resolution,
                                             invert=invert,
                                             remove_wires=remove_wires)
    n_images = len(valid_loader.data_list)
    # Ceiling division: the last batch may be partial.
    total_idx_valid = (n_images + batch_size - 1) // batch_size
    has_labels = valid_loader.dataloader_type != "test"
    min_region_size = 0.02 * np.prod(image_resolution)
    # Per-class scores are only accumulated for a plain two-class {0, 1} setup.
    track_per_class = (0 in classes and 1 in classes and len(classes) == 2
                       and len(clubbed) == 0)

    model.eval()
    prediction_array = np.zeros(
        (n_images, image_resolution[0], image_resolution[1]))
    if has_labels:
        # Ground-truth masks collected alongside predictions; not read again
        # in this function.
        input_mask_array = np.zeros(
            (n_images, image_resolution[0], image_resolution[1]))

    valid_count = 0
    valid_dice_score = 0.0
    if 0 in classes:
        valid_dice_score_0 = 0.0
    if 1 in classes:
        valid_dice_score_1 = 0.0
    with torch.no_grad():
        for idx in range(total_idx_valid):
            if has_labels:
                batch_images_input, batch_label_input = valid_loader[idx]
            else:
                batch_images_input = valid_loader[idx]
            output = model(batch_images_input)
            # The multi-input architecture returns a sequence of outputs;
            # only the last one is used.
            if use_multiinput_architecture is False:
                final_output = output
            else:
                final_output = output[-1]

            # NumPy clips slices at the array end, so this slice also covers
            # the final, possibly partial, batch.
            batch_slice = slice(idx * batch_size, (idx + 1) * batch_size)
            # `bool`, not `np.bool`: that alias was removed in NumPy 1.24.
            temp_image = torch.max(
                final_output, 1)[1].detach().cpu().numpy().astype(bool)
            prediction_array[batch_slice] = remove_small_regions(
                temp_image, min_region_size)

            if has_labels:
                input_mask_array[batch_slice] = batch_label_input.detach(
                ).cpu().numpy().astype(np.uint8)
                n_batch = batch_images_input.shape[0]
                valid_count += n_batch
                score = losses.dice_score(final_output, batch_label_input,
                                          exclude_0)
                # Weight by batch size so the final means are per-sample.
                valid_dice_score += (score.sum().item() /
                                     score.size(0)) * n_batch
                if track_per_class:
                    valid_dice_score_0 += score[0].item() * n_batch
                    valid_dice_score_1 += score[1].item() * n_batch

    if has_labels:
        valid_dice_score = valid_dice_score / valid_count
        if 0 in classes:
            valid_dice_score_0 = valid_dice_score_0 / valid_count
        if 1 in classes:
            valid_dice_score_1 = valid_dice_score_1 / valid_count

    if generate_mask is True:
        for i, files in enumerate(valid_loader.data_list):
            temp_mask = prediction_array[i].astype(int)
            # order=0: nearest-neighbour, so labels stay integral.
            temp_mask = ndimage.zoom(
                temp_mask,
                np.asarray(valid_loader.original_size_array[files]) /
                np.asarray(temp_mask.shape),
                order=0)
            temp_mask = sitk.GetImageFromArray(temp_mask)
            temp_mask = sitk.Cast(temp_mask, sitk.sitkUInt8)
            # Resample onto the spacing/size recorded by the loader, then
            # restore the recorded origin.
            resampler = sitk.ResampleImageFilter()
            resampler.SetReferenceImage(temp_mask)
            resampler.SetOutputSpacing(valid_loader.spacing[files])
            resampler.SetSize(valid_loader.size_[files])
            resampler.SetInterpolator(sitk.sitkNearestNeighbor)
            temp_mask = resampler.Execute(temp_mask)
            temp_mask.SetOrigin(valid_loader.origin[files])
            temp_name = _strip_image_extension(files)
            sitk.WriteImage(temp_mask,
                            save_path + '/Pred_mask_' + temp_name + '.nii.gz')

    if has_labels:
        # The original printed per-class scores even when they had not been
        # accumulated (more than two classes), and printed nothing at all for
        # other class setups. Print the mean always and per-class scores only
        # when they were tracked.
        if track_per_class:
            print('Valid Dice Score: ', valid_dice_score,
                  ' Valid Dice Score 0: ', valid_dice_score_0,
                  ' Valid Dice Score 1: ', valid_dice_score_1)
        else:
            print('Valid Dice Score: ', valid_dice_score)