def validate(state_dict_path, use_gpu, device):
    """Evaluate a saved UNet checkpoint on the BRATS2018 validation split.

    Args:
        state_dict_path: path to a serialized ``state_dict`` produced by training.
        use_gpu: when False, checkpoint tensors are remapped onto the CPU.
        device: torch.device the model and every batch are moved to.

    Prints the running dice score per batch and the final mean dice score.
    """
    model = UNet(n_channels=1, n_classes=2)
    model.load_state_dict(
        torch.load(state_dict_path,
                   map_location='cpu' if not use_gpu else device))
    model.to(device)
    # FIX: eval mode never changes between batches, so set it once here
    # instead of calling model.eval() on every loop iteration.
    model.eval()

    val_transforms = transforms.Compose([ToTensor(), NormalizeBRATS()])
    BraTS_val_ds = BRATS2018('./BRATS2018',
                             data_set='val',
                             seg_type='et',
                             scan_type='t1ce',
                             transform=val_transforms)
    data_loader = DataLoader(BraTS_val_ds,
                             batch_size=2,
                             shuffle=False,
                             num_workers=0)

    running_dice_score = 0.
    for batch_ind, batch in enumerate(data_loader):
        imgs, targets = batch
        imgs = imgs.to(device)
        targets = targets.to(device)

        with torch.no_grad():
            outputs = model(imgs)
            preds = torch.argmax(F.softmax(outputs, dim=1), dim=1)
            # Weight each batch's dice by its size so the final mean is exact
            # even when the last batch is smaller.
            running_dice_score += dice_score(preds, targets) * targets.size(0)
            print('running dice score: {:.6f}'.format(running_dice_score))

    dice = running_dice_score / len(BraTS_val_ds)
    print('mean dice score of the validating set: {:.6f}'.format(dice))
def train():
    """Parse run arguments, build a UNet (optionally warm-started from a
    checkpoint), move it to the best available device, and launch training."""
    args = setup_run_arguments()  # args = parse_args()

    # Prefer CUDA whenever it is available.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"[INFO] Initializing UNet-model using: {device}")

    net = UNet(n_channels=args.n_channels,
               n_classes=args.n_classes,
               bilinear=True)
    if args.from_pretrained:
        # Restore pretrained weights, remapping tensors onto the chosen device.
        state = torch.load(args.from_pretrained, map_location=device)
        net.load_state_dict(state)
    net.to(device=device)

    training_loop.run(network=net,
                      epochs=args.epochs,
                      batch_size=args.batch_size,
                      lr=args.learning_rate,
                      device=device,
                      n_classes=args.n_classes,
                      val_percent=args.val_percent,
                      image_dir=args.image_dir,
                      mask_dir=args.mask_dir,
                      checkpoint_path=args.checkpoint_path,
                      loss=args.loss,
                      num_workers=args.num_workers)
class EventGANBase(object):
    """Wraps a pretrained UNet generator that converts a pair of images into
    an event volume, restored from the latest checkpoint on construction."""

    def __init__(self, options):
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.generator = UNet(num_input_channels=2 * options.n_image_channels,
                              num_output_channels=options.n_time_bins * 2,
                              skip_type='concat',
                              activation='relu',
                              num_encoders=4,
                              base_num_channels=32,
                              num_residual_blocks=2,
                              norm='BN',
                              use_upsample_conv=True,
                              with_activation=True,
                              sn=options.sn,
                              multi=False)
        latest_checkpoint = get_latest_checkpoint(options.checkpoint_dir)
        # FIX: pass map_location so a checkpoint saved on GPU can still be
        # loaded on a CPU-only host (self.device already falls back to CPU).
        checkpoint = torch.load(latest_checkpoint, map_location=self.device)
        self.generator.load_state_dict(checkpoint["gen"])
        self.generator.to(self.device)

    def forward(self, images, is_train=False):
        """Run the generator on a 2xHxW or Bx2xHxW image pair.

        In eval mode (is_train=False) gradients are disabled and the
        generator is temporarily switched to eval() for the call.
        """
        if len(images.shape) == 3:
            # Promote a single sample to a batch of one.
            images = images[None, ...]
        assert len(images.shape) == 4 and images.shape[1] == 2, \
            "Input images must be either 2xHxW or Bx2xHxW."
        if not is_train:
            with torch.no_grad():
                self.generator.eval()
                event_volume = self.generator(images)
                self.generator.train()
        else:
            event_volume = self.generator(images)
        return event_volume
def prediction_to_json(image_path, chkp_path, net=None) -> dict:
    """ Convert mask prediction to json. The format matches the format in the
    training annotation data:
    {'filename':file_name,
     'labels': [{'name': label_name,
                 'annotations': [{'id':some_unique_integer_id,
                                  'segmentation':[x,y,x,y,x,y....]} ....]
                } ....]
    }
    """
    file_name = os.path.basename(image_path)
    annotation = {'filename': file_name, 'labels': []}
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # FIX: compare against None explicitly. Relying on the truthiness of an
    # nn.Module is fragile, and a caller-supplied model must never be
    # silently rebuilt and overwritten with checkpoint weights.
    if net is None:
        net = UNet(n_channels=3, n_classes=4)
        net.to(device=device)
        net.load_state_dict(torch.load(chkp_path, map_location=device))
    # FIX: close the underlying image file handle deterministically.
    with Image.open(image_path) as img:
        msk = predict_on_image(net=net, device=device, src_img=img)
    msk = msk.transpose((1, 2, 0))  # CHW -> HWC
    h, w, n_labels = msk.shape
    # NOTE(review): rgb_mask is filled below but never returned or saved here
    # — looks like leftover visualization code; kept for behavior parity.
    rgb_mask = np.ones((h, w, 3), dtype=np.uint8)
    annotation['height'] = h
    annotation['width'] = w
    # Channel 0 is assumed background; emit one label entry per foreground class.
    for label in range(1, n_labels):
        color = hex_labels[str(label)]
        category = category_labels[str(label)]
        c_label = {'color': color, 'name': category, 'annotations': []}
        label_mask = msk[:, :, label].astype(int).astype(np.uint8)
        contours, hierarchy = cv2.findContours(label_mask, cv2.RETR_TREE,
                                               cv2.CHAIN_APPROX_SIMPLE)
        for contour in contours:
            # Flatten each contour into the [x, y, x, y, ...] polygon format.
            vector_points = []
            for x, y in contour.reshape((len(contour), 2)):
                vector_points += [float(x), float(y)]
            c_label['annotations'].append({'segmentation': vector_points})
        idx = np.where(msk[:, :, label].astype(int) == 1)
        rgb_mask[idx] = colors_from_hex[str(label)]
        annotation['labels'].append(c_label)
    return annotation
def load_finetuned_model(self, baseline_model):
    """
    Loads the augmentation net, sample reweighting net, and baseline model
    Note: sets all these models to train mode
    """
    # Input geometry depends on the dataset: MNIST is 28x28 grayscale,
    # everything else is treated as 32x32 RGB with 10 classes.
    if self.args.dataset == DATASET_MNIST:
        imsize, in_channel, num_classes = 28, 1, 10
    else:
        imsize, in_channel, num_classes = 32, 3, 10

    # TODO(PV): Initialize UNet properly
    # TODO (JON): DEPTH 1 WORKED WELL. Changed upconv to upsample. Use a wf of 2.
    augment_net = UNet(in_channels=in_channel,
                       n_classes=in_channel,
                       depth=2,
                       wf=3,
                       padding=True,
                       batch_norm=False,
                       do_noise_channel=True,
                       up_mode='upconv',
                       use_identity_residual=True)

    # This ResNet outputs scalar weights to be applied element-wise to the
    # per-example losses.
    reweighting_net = Net(1, 0.0, imsize, in_channel, 0.0, num_classes=1)

    if self.args.load_finetune_checkpoint:
        checkpoint = torch.load(self.args.load_finetune_checkpoint)
        if 'weight_decay' in checkpoint:
            baseline_model.weight_decay = checkpoint['weight_decay']
        augment_net.load_state_dict(checkpoint['augment_model_state_dict'])
        try:
            reweighting_net.load_state_dict(
                checkpoint['reweighting_model_state_dict'])
        except KeyError:
            # Older checkpoints carry no reweighting weights; keep fresh init.
            pass

    augment_net = augment_net.to(self.device)
    reweighting_net = reweighting_net.to(self.device)
    baseline_model = baseline_model.to(self.device)
    augment_net.train()
    reweighting_net.train()
    baseline_model.train()
    return augment_net, reweighting_net, baseline_model
Learning rate: {args.lr} Weight decay: {args.weight_decay} Device: GPU{args.gpu} Log name: {args.save} ''') torch.cuda.set_device(args.gpu) device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # choose a model if args.model == 'unet': net = UNet() elif args.model == 'nestedunet': net = NestedUNet() net.to(device=device) # choose a dataset if args.dataset == 'promise12': dir_data = '../data/promise12' trainset = Promise12(dir_data, mode='train') valset = Promise12(dir_data, mode='val') elif args.dataset == 'chaos': dir_data = '../data/chaos' trainset = Chaos(dir_data, mode='train') valset = Chaos(dir_data, mode='val') try: train_net(net=net, trainset=trainset,
def train(input_data_type, grade, seg_type, num_classes, batch_size, epochs,
          use_gpu, learning_rate, w_decay, pre_trained=False):
    """Train a residual UNet on the chosen modality and track metrics per phase.

    Args:
        input_data_type: input modality passed to the dataset factory.
        grade, seg_type, num_classes, batch_size, epochs: training configuration.
        use_gpu: move the model to the global ``device`` when True.
        learning_rate, w_decay: Adam hyperparameters.
        pre_trained: when truthy, restore weights from ``pre_trained_path``.

    Returns:
        (model, optimizer) with the best-IoU weights loaded into the model.
    """
    logger.info('Start training using {} modal.'.format(input_data_type))
    model = UNet(4, 4, residual=True, expansion=2)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(params=model.parameters(),
                           lr=learning_rate,
                           weight_decay=w_decay)
    if pre_trained:
        checkpoint = torch.load(pre_trained_path, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
    if use_gpu:
        ts = time.time()
        model.to(device)
        print("Finish cuda loading, time elapsed {}".format(time.time() - ts))
    # decay LR by a factor of 0.5 every 5 epochs
    # NOTE(review): the scheduler is created but never stepped inside this
    # function — confirm whether scheduler.step() is missing from the loop.
    scheduler = lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma)

    data_set, data_loader = get_dataset_dataloader(input_data_type, seg_type,
                                                   batch_size, grade=grade)
    since = time.time()
    best_model_wts = copy.deepcopy(model.state_dict())
    best_iou = 0.0
    # Row 0 = train phase, row 1 = val phase; one column per epoch.
    epoch_loss = np.zeros((2, epochs))
    epoch_acc = np.zeros((2, epochs))
    epoch_class_acc = np.zeros((2, epochs))
    epoch_mean_iou = np.zeros((2, epochs))
    evaluator = Evaluator(num_classes)

    def term_int_handler(signal_num, frame):
        # Persist metrics and the best weights before exiting on SIGINT/SIGTERM.
        np.save(os.path.join(score_dir, 'epoch_accuracy'), epoch_acc)
        np.save(os.path.join(score_dir, 'epoch_mean_iou'), epoch_mean_iou)
        np.save(os.path.join(score_dir, 'epoch_loss'), epoch_loss)
        model.load_state_dict(best_model_wts)
        logger.info('Got terminated and saved model.state_dict')
        torch.save(model.state_dict(),
                   os.path.join(score_dir, 'terminated_model.pt'))
        torch.save(
            {
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict()
            }, os.path.join(score_dir, 'terminated_model.tar'))
        quit()

    signal.signal(signal.SIGINT, term_int_handler)
    signal.signal(signal.SIGTERM, term_int_handler)

    for epoch in range(epochs):
        logger.info('Epoch {}/{}'.format(epoch + 1, epochs))
        logger.info('-' * 28)
        for phase_ind, phase in enumerate(['train', 'val']):
            if phase == 'train':
                model.train()
                logger.info(phase)
            else:
                model.eval()
                logger.info(phase)
            evaluator.reset()
            running_loss = 0.0
            for batch_ind, batch in enumerate(data_loader[phase]):
                imgs, targets = batch
                imgs = imgs.to(device)
                targets = targets.to(device)
                # zero the learnable parameters gradients
                optimizer.zero_grad()
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(imgs)
                    loss = criterion(outputs, targets)
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()
                    preds = torch.argmax(F.softmax(outputs, dim=1),
                                         dim=1,
                                         keepdim=True)
                # FIX: accumulate a detached python float via loss.item().
                # The original `running_loss += loss * imgs.size(0)` kept a
                # reference to every batch's autograd graph for the whole
                # epoch, steadily leaking memory.
                running_loss += loss.item() * imgs.size(0)
                logger.debug('Batch {} running loss: {:.4f}'.format(
                    batch_ind, running_loss))
                # test the iou and pixelwise accuracy using evaluator
                preds = torch.squeeze(preds, dim=1)
                preds = preds.cpu().numpy()
                targets = targets.cpu().numpy()
                evaluator.add_batch(targets, preds)
            epoch_loss[phase_ind, epoch] = running_loss / len(data_set[phase])
            epoch_acc[phase_ind, epoch] = evaluator.Pixel_Accuracy()
            epoch_class_acc[phase_ind, epoch] = evaluator.Pixel_Accuracy_Class()
            epoch_mean_iou[phase_ind, epoch] = \
                evaluator.Mean_Intersection_over_Union()
            logger.info(
                '{} loss: {:.4f}, acc: {:.4f}, class acc: {:.4f}, mean iou: {:.6f}'.format(
                    phase,
                    epoch_loss[phase_ind, epoch],
                    epoch_acc[phase_ind, epoch],
                    epoch_class_acc[phase_ind, epoch],
                    epoch_mean_iou[phase_ind, epoch]))
            # Track the best validation IoU for final weight restoration.
            if phase == 'val' and epoch_mean_iou[phase_ind, epoch] > best_iou:
                best_iou = epoch_mean_iou[phase_ind, epoch]
                best_model_wts = copy.deepcopy(model.state_dict())
            # Periodic checkpoint every 10 epochs.
            if phase == 'val' and (epoch + 1) % 10 == 0:
                logger.info('Saved model.state_dict in epoch {}'.format(epoch + 1))
                torch.save(
                    model.state_dict(),
                    os.path.join(score_dir,
                                 'epoch{}_model.pt'.format(epoch + 1)))
        print()

    time_elapsed = time.time() - since
    logger.info('Training completed in {}m {}s'.format(
        int(time_elapsed / 60), int(time_elapsed) % 60))
    # load best model weights
    model.load_state_dict(best_model_wts)
    # save numpy results
    np.save(os.path.join(score_dir, 'epoch_accuracy'), epoch_acc)
    np.save(os.path.join(score_dir, 'epoch_mean_iou'), epoch_mean_iou)
    np.save(os.path.join(score_dir, 'epoch_loss'), epoch_loss)
    return model, optimizer
class NNUnet(pl.LightningModule):
    """nnU-Net LightningModule with CPU/GPU/HPU support, optional test-time
    augmentation (flips) and sliding-window / 2D batched inference."""

    def __init__(self, args):
        super(NNUnet, self).__init__()
        self.args = args
        if not hasattr(self.args, "drop_block"):  # For backward compability
            self.args.drop_block = False
        self.save_hyperparameters()
        self.build_nnunet()
        self.loss = Loss(self.args.focal)
        self.dice = Dice(self.n_class)
        # Best-metric bookkeeping: summed dice and per-class dice/epoch.
        self.best_sum = 0
        self.best_sum_epoch = 0
        self.best_dice = self.n_class * [0]
        self.best_epoch = self.n_class * [0]
        self.best_sum_dice = self.n_class * [0]
        self.learning_rate = args.learning_rate
        self.tta_flips = get_tta_flips(args.dim)
        self.test_idx = 0
        self.test_imgs = []
        if self.args.exec_mode in ["train", "evaluate"]:
            self.dllogger = get_dllogger(args.results)

    def forward(self, img):
        if self.args.benchmark:
            if self.args.dim == 2 and self.args.data2d_dim == 3:
                img = layout_2d(img, None)
            return self.model(img)
        return self.tta_inference(img) if self.args.tta else self.do_inference(img)

    def training_step(self, batch, batch_idx):
        img, lbl = self.get_train_data(batch)
        pred = self.model(img)
        loss = self.compute_loss(pred, lbl)
        mark_step(self.args.run_lazy_mode)
        return loss

    def on_before_zero_grad(self, optimizer):
        mark_step(self.args.run_lazy_mode)

    def on_after_backward(self):
        mark_step(self.args.run_lazy_mode)

    def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx,
                       optimizer_closure, on_tpu, using_native_amp,
                       using_lbfgs):
        optimizer.step(closure=optimizer_closure)
        mark_step(self.args.run_lazy_mode)

    def validation_step(self, batch, batch_idx):
        # Optionally skip early epochs where evaluation is not meaningful yet.
        if self.current_epoch < self.args.skip_first_n_eval:
            return None
        img, lbl = batch["image"], batch["label"]
        if self.args.hpus:
            img, lbl = img.to(torch.device("hpu"), non_blocking=False), \
                lbl.to(torch.device("hpu"), non_blocking=False)
        pred = self.forward(img)
        loss = self.loss(pred, lbl)
        self.dice.update(pred, lbl[:, 0])
        mark_step(self.args.run_lazy_mode)
        return {"val_loss": loss}

    def test_step(self, batch, batch_idx):
        print("Start test")
        if self.args.exec_mode == "evaluate":
            return self.validation_step(batch, batch_idx)
        img = batch["image"]
        if self.args.hpus:
            img = img.to(torch.device("hpu"), non_blocking=False)
        if self.args.channels_last:
            if img.ndim == 4 or self.args.dim == 2:
                img = img.contiguous(memory_format=torch.channels_last)
            elif img.ndim == 5 and self.args.dim == 3:
                img = img.contiguous(memory_format=torch.channels_last_3d)
        mark_step(self.args.run_lazy_mode)
        pred = self.forward(img)
        mark_step(self.args.run_lazy_mode)
        if self.args.save_preds:
            # meta[0]/meta[1] hold the crop's min/max corners, meta[2] the
            # original volume shape; paste the prediction back accordingly.
            meta = batch["meta"][0].cpu().detach().numpy()
            original_shape = meta[2]
            min_d, max_d = meta[0, 0], meta[1, 0]
            min_h, max_h = meta[0, 1], meta[1, 1]
            min_w, max_w = meta[0, 2], meta[1, 2]
            final_pred = torch.zeros((1, pred.shape[1], *original_shape),
                                     device=img.device)
            final_pred[:, :, min_d:max_d, min_h:max_h, min_w:max_w] = pred
            final_pred = nn.functional.softmax(final_pred, dim=1)
            final_pred = final_pred.squeeze(0).cpu().detach().numpy()
            if not all(original_shape == final_pred.shape[1:]):
                # Resample per class back to the original resolution.
                class_ = final_pred.shape[0]
                resized_pred = np.zeros((class_, *original_shape))
                for i in range(class_):
                    resized_pred[i] = resize(final_pred[i], original_shape,
                                             order=3, mode="edge", cval=0,
                                             clip=True, anti_aliasing=False)
                final_pred = resized_pred
            self.save_mask(final_pred)

    def on_save_checkpoint(self, checkpoint):
        # Only HPU runs need their tensors moved/permuted back to CPU layout.
        if not self.args.hpus:
            return
        state_dict = checkpoint['state_dict']
        optimizer_states = checkpoint['optimizer_states']
        optimizer_state_dict = optimizer_states[0]['state']
        for k, v in checkpoint["callbacks"].items():
            if isinstance(v, dict):
                for k1, v1 in v.items():
                    if isinstance(v1, torch.Tensor):
                        v[k1] = v1.to("cpu")
        adjust_tensors_for_save(state_dict, optimizer_state_dict,
                                to_device="cpu", to_filters_last=False,
                                lazy_mode=self.args.run_lazy_mode,
                                permute=True)

    def build_nnunet(self):
        """Instantiate the UNet from the dataset-derived architecture params."""
        in_channels, n_class, kernels, strides, self.patch_size = \
            get_unet_params(self.args)
        self.n_class = n_class - 1  # exclude background
        self.model = UNet(
            in_channels=in_channels,
            n_class=n_class,
            kernels=kernels,
            strides=strides,
            dimension=self.args.dim,
            residual=self.args.residual,
            attention=self.args.attention,
            drop_block=self.args.drop_block,
            normalization_layer=self.args.norm,
            negative_slope=self.args.negative_slope,
            deep_supervision=self.args.deep_supervision,
        )
        if is_main_process():
            print(
                f"Filters: {self.model.filters},\nKernels: {kernels}\nStrides: {strides}"
            )

    def compute_loss(self, preds, label):
        if self.args.deep_supervision:
            # Deep supervision: geometrically down-weighted auxiliary heads,
            # normalized so the weights sum to 1.
            loss = self.loss(preds[0], label)
            for i, pred in enumerate(preds[1:]):
                downsampled_label = nn.functional.interpolate(
                    label, pred.shape[2:])
                loss += 0.5**(i + 1) * self.loss(pred, downsampled_label)
            c_norm = 1 / (2 - 2**(-len(preds)))
            return c_norm * loss
        return self.loss(preds, label)

    def do_inference(self, image):
        if self.args.dim == 3:
            return self.sliding_window_inference(image)
        if self.args.data2d_dim == 2:
            return self.model(image)
        if self.args.exec_mode == "predict":
            return self.inference2d_test(image)
        return self.inference2d(image)

    def tta_inference(self, img):
        # Average predictions over the identity and every configured flip.
        pred = self.do_inference(img)
        for flip_idx in self.tta_flips:
            pred += flip(self.do_inference(flip(img, flip_idx)), flip_idx)
        pred /= len(self.tta_flips) + 1
        return pred

    def inference2d(self, image):
        # Pad the depth axis so it divides evenly into val batches.
        batch_modulo = image.shape[2] % self.args.val_batch_size
        if batch_modulo != 0:
            batch_pad = self.args.val_batch_size - batch_modulo
            image = nn.ConstantPad3d((0, 0, 0, 0, batch_pad, 0), 0)(image)
            mark_step(self.args.run_lazy_mode)
        image = torch.transpose(image.squeeze(0), 0, 1)
        preds_shape = (image.shape[0], self.n_class + 1, *image.shape[2:])
        if self.args.hpus:
            preds = None
            for start in range(0,
                               image.shape[0] - self.args.val_batch_size + 1,
                               self.args.val_batch_size):
                end = start + self.args.val_batch_size
                pred = self.model(image[start:end])
                # FIX: identity test (`is None`) instead of `preds == None` —
                # `==` on a tensor is an element-wise comparison, not a
                # None-check.
                preds = pred if preds is None else torch.cat((preds, pred),
                                                             dim=0)
                mark_step(self.args.run_lazy_mode)
            if batch_modulo != 0:
                preds = preds[batch_pad:]
                mark_step(self.args.run_lazy_mode)
        else:
            preds = torch.zeros(preds_shape,
                                dtype=image.dtype,
                                device=image.device)
            for start in range(0,
                               image.shape[0] - self.args.val_batch_size + 1,
                               self.args.val_batch_size):
                end = start + self.args.val_batch_size
                pred = self.model(image[start:end])
                preds[start:end] = pred.data
            if batch_modulo != 0:
                preds = preds[batch_pad:]
        return torch.transpose(preds, 0, 1).unsqueeze(0)

    def inference2d_test(self, image):
        # Slice-by-slice sliding-window inference over the depth axis.
        preds_shape = (image.shape[0], self.n_class + 1, *image.shape[2:])
        preds = torch.zeros(preds_shape,
                            dtype=image.dtype,
                            device=image.device)
        for depth in range(image.shape[2]):
            preds[:, :, depth] = self.sliding_window_inference(
                image[:, :, depth])
        return preds

    def sliding_window_inference(self, image):
        # HPU uses a patched local copy of MONAI's implementation.
        if self.args.hpus:
            from models.monai_sliding_window_inference import sliding_window_inference
        else:
            from monai.inferers import sliding_window_inference
        return sliding_window_inference(
            inputs=image,
            roi_size=self.patch_size,
            sw_batch_size=self.args.val_batch_size,
            predictor=self.model,
            overlap=self.args.overlap,
            mode=self.args.blend,
        )

    @staticmethod
    def metric_mean(name, outputs):
        """Mean of a named metric across the collected step outputs."""
        return torch.stack([out[name] for out in outputs]).mean(dim=0)

    def validation_epoch_end(self, outputs):
        if self.current_epoch < self.args.skip_first_n_eval:
            # Log a monotonically increasing placeholder so checkpointing on
            # "dice_sum" still behaves during skipped epochs.
            self.log("dice_sum", 0.001 * self.current_epoch)
            self.dice.reset()
            return None
        loss = self.metric_mean("val_loss", outputs)
        dice = self.dice.compute()
        dice_sum = torch.sum(dice)
        if dice_sum >= self.best_sum:
            self.best_sum = dice_sum
            self.best_sum_dice = dice[:]
            self.best_sum_epoch = self.current_epoch
        for i, dice_i in enumerate(dice):
            if dice_i > self.best_dice[i]:
                self.best_dice[i], self.best_epoch[
                    i] = dice_i, self.current_epoch
        if is_main_process():
            metrics = {}
            metrics.update({"mean dice": round(torch.mean(dice).item(), 2)})
            metrics.update(
                {"TOP_mean": round(torch.mean(self.best_sum_dice).item(), 2)})
            if self.n_class > 1:
                metrics.update({
                    f"L{i+1}": round(m.item(), 2)
                    for i, m in enumerate(dice)
                })
                metrics.update({
                    f"TOP_L{i+1}": round(m.item(), 2)
                    for i, m in enumerate(self.best_sum_dice)
                })
            metrics.update({"val_loss": round(loss.item(), 4)})
            self.dllogger.log(step=self.current_epoch, data=metrics)
            self.dllogger.flush()
        self.log("val_loss", loss)
        self.log("dice_sum", dice_sum)

    def test_epoch_end(self, outputs):
        if self.args.exec_mode == "evaluate":
            self.eval_dice = self.dice.compute()

    def configure_optimizers(self):
        if self.args.hpus:
            self.model = self.model.to(get_device(self.args))
            permute_params(self.model, True, self.args.run_lazy_mode)
        # Avoid instantiate optimizers if not have to
        # since might not be supported
        if self.args.optimizer.lower() == 'sgd':
            optimizer = SGD(self.parameters(),
                            lr=self.learning_rate,
                            momentum=self.args.momentum)
        elif self.args.optimizer.lower() == 'adam':
            optimizer = Adam(self.parameters(),
                             lr=self.learning_rate,
                             weight_decay=self.args.weight_decay)
        elif self.args.optimizer.lower() == 'radam':
            optimizer = RAdam(self.parameters(),
                              lr=self.learning_rate,
                              weight_decay=self.args.weight_decay)
        elif self.args.optimizer.lower() == 'adamw':
            optimizer = torch.optim.AdamW(
                self.parameters(),
                lr=self.learning_rate,
                weight_decay=self.args.weight_decay)
        elif self.args.optimizer.lower() == 'fusedadamw':
            from habana_frameworks.torch.hpex.optimizers import FusedAdamW
            optimizer = FusedAdamW(self.parameters(),
                                   lr=self.learning_rate,
                                   eps=1e-08,
                                   weight_decay=self.args.weight_decay)
        else:
            # FIX: corrected typo "suppoerted" in the failure message.
            assert False, "optimizer {} not supported".format(
                self.args.optimizer.lower())
        scheduler = {
            "none":
            None,
            "multistep":
            torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                 self.args.steps,
                                                 gamma=self.args.factor),
            "cosine":
            torch.optim.lr_scheduler.CosineAnnealingLR(
                optimizer, self.args.max_epochs),
            "plateau":
            torch.optim.lr_scheduler.ReduceLROnPlateau(
                optimizer,
                factor=self.args.factor,
                patience=self.args.lr_patience),
        }[self.args.scheduler.lower()]
        opt_dict = {"optimizer": optimizer, "monitor": "val_loss"}
        if scheduler is not None:
            opt_dict.update({"lr_scheduler": scheduler})
        return opt_dict

    def save_mask(self, pred):
        # Lazily resolve the test filenames on the first saved prediction.
        if self.test_idx == 0:
            data_path = get_path(self.args)
            self.test_imgs, _ = get_test_fnames(self.args, data_path)
        fname = os.path.basename(self.test_imgs[self.test_idx]).replace(
            "_x", "")
        np.save(os.path.join(self.save_dir, fname), pred, allow_pickle=False)
        self.test_idx += 1

    def get_train_data(self, batch):
        img, lbl = batch["image"], batch["label"]
        if self.args.dim == 2 and self.args.data2d_dim == 3:
            img, lbl = layout_2d(img, lbl)
        if self.args.hpus:
            img, lbl = img.to(torch.device("hpu"), non_blocking=False), \
                lbl.to(torch.device("hpu"), non_blocking=False)
        if self.args.channels_last:
            if img.ndim == 4:
                img = img.contiguous(memory_format=torch.channels_last)
                lbl = lbl.contiguous(memory_format=torch.channels_last)
            elif img.ndim == 5:
                img = img.contiguous(memory_format=torch.channels_last_3d)
                lbl = lbl.contiguous(memory_format=torch.channels_last_3d)
        mark_step(self.args.run_lazy_mode)
        return img, lbl
'cmap': 'jet', 'vmin': 0, 'vmax': eval_label.max() }, { 'cmap': 'jet', 'vmin': 0, 'vmax': eval_label.max() }) net_is_3d = False if torch.cuda.device_count() > 1: print("Using", torch.cuda.device_count(), "GPUs.") device_ids = [i for i in range(torch.cuda.device_count())] model = nn.DataParallel(model, device_ids=device_ids) model = model.to(device) if experiment == "Unet": model.load_state_dict(torch.load("best_weights.pth")) elif experiment == "DeepLab": model.load_state_dict(torch.load(f"best_weights_{backbone}_deeplab.pth")) model.eval() eval_images, eval_labels, eval_label_corners = batch_generator( eval_image, eval_label, **windowing_params, return_corners=True) eval_dataset = PlateletDataset(eval_images, eval_labels, train=False) prob_maps = stitch(model, eval_images, eval_labels, eval_label.shape,
def inference():
    """Support two mode: evaluation (on valid set) or inference mode (on test-set for submission)

    Returns:
        list of [vol_ind, liver_scores, lesion_scores] per volume in
        evaluation mode; an empty list in pure inference mode.
    """
    parser = argparse.ArgumentParser(description="Inference mode")
    parser.add_argument('-testf', "--test-filepath", type=str, default=None,
                        required=True, help="testing dataset filepath.")
    parser.add_argument("-eval", "--evaluate", action="store_true",
                        default=False, help="Evaluation mode")
    parser.add_argument("--load-weights", type=str, default=None,
                        help="Load pretrained weights, torch state_dict() (filepath, default: None)")
    parser.add_argument("--load-model", type=str, default=None,
                        help="Load pretrained model, entire model (filepath, default: None)")
    parser.add_argument("--save2dir", type=str, default=None,
                        help="save the prediction labels to the directory (default: None)")
    parser.add_argument("--debug", action="store_true", default=False)
    parser.add_argument("--batch-size", type=int, default=32,
                        help="Batch size")
    parser.add_argument("--num-cpu", type=int, default=10,
                        help="Number of CPUs to use in parallel for dataloader.")
    parser.add_argument('--cuda', type=int, default=0,
                        help='CUDA visible device (use CPU if -1, default: 0)')
    args = parser.parse_args()

    printYellow("="*10 + " Inference mode. "+"="*10)

    if args.save2dir:
        os.makedirs(args.save2dir, exist_ok=True)
    device = torch.device("cuda:{}".format(args.cuda)
                          if torch.cuda.is_available() and (args.cuda >= 0)
                          else "cpu")

    transform_normalize = transforms.Normalize(mean=[0.5], std=[0.5])
    data_transform = transforms.Compose([
        transforms.ToTensor(),
        transform_normalize
    ])
    data_loader_params = {'batch_size': args.batch_size,
                          'shuffle': False,
                          'num_workers': args.num_cpu,
                          'drop_last': False,
                          'pin_memory': False
                          }
    test_set = LiTSDataset(args.test_filepath,
                           dtype=np.float32,
                           pixelwise_transform=data_transform,
                           inference_mode=(not args.evaluate),
                           )
    dataloader_test = torch.utils.data.DataLoader(test_set,
                                                  **data_loader_params)

    # =================== Build model ===================
    if args.load_weights:
        model = UNet(in_ch=1,
                     out_ch=3,  # there are 3 classes: 0: background, 1: liver, 2: tumor
                     depth=4,
                     start_ch=64,
                     inc_rate=2,
                     kernel_size=3,
                     padding=True,
                     batch_norm=True,
                     spec_norm=False,
                     dropout=0.5,
                     up_mode='upconv',
                     include_top=True,
                     include_last_act=False,
                     )
        model.load_state_dict(torch.load(args.load_weights))
        printYellow("Successfully loaded pretrained weights.")
    elif args.load_model:
        # load entire model
        model = torch.load(args.load_model)
        printYellow("Successfully loaded pretrained model.")
    else:
        # FIX: previously fell through with `model` unbound and crashed later
        # with a NameError; fail fast with an explicit message instead.
        raise ValueError(
            "Either --load-weights or --load-model must be provided.")
    model.eval()
    model.to(device)

    sigmoid_act = torch.nn.Sigmoid()
    st = time.time()
    volume_start_index = test_set.volume_start_index
    spacing = test_set.spacing
    direction = test_set.direction  # use it for the submission
    offset = test_set.offset

    msk_pred_buffer = []
    if args.evaluate:
        msk_gt_buffer = []
    for data_batch in tqdm(dataloader_test):
        if args.evaluate:
            img, msk_gt = data_batch
            msk_gt_buffer.append(msk_gt.cpu().detach().numpy())
        else:
            img = data_batch
        img = img.to(device)
        with torch.no_grad():
            msk_pred = model(img)  # shape (N, 3, H, W)
            msk_pred = sigmoid_act(msk_pred)
        msk_pred_buffer.append(msk_pred.cpu().detach().numpy())
    msk_pred_buffer = np.vstack(msk_pred_buffer)  # shape (N, 3, H, W)
    if args.evaluate:
        msk_gt_buffer = np.vstack(msk_gt_buffer)

    results = []
    # Slice the flat prediction buffer back into per-volume chunks.
    for vol_ind, vol_start_ind in enumerate(volume_start_index):
        if vol_ind == len(volume_start_index) - 1:
            volume_msk = msk_pred_buffer[vol_start_ind:]  # shape (N, 3, H, W)
            if args.evaluate:
                volume_msk_gt = msk_gt_buffer[vol_start_ind:]
        else:
            vol_end_ind = volume_start_index[vol_ind+1]
            volume_msk = msk_pred_buffer[vol_start_ind:vol_end_ind]  # shape (N, 3, H, W)
            if args.evaluate:
                volume_msk_gt = msk_gt_buffer[vol_start_ind:vol_end_ind]
        if args.evaluate:
            # liver
            liver_scores = get_scores(volume_msk[:, 1] >= 0.5,
                                      volume_msk_gt >= 1,
                                      spacing[vol_ind])
            # tumor
            lesion_scores = get_scores(volume_msk[:, 2] >= 0.5,
                                       volume_msk_gt == 2,
                                       spacing[vol_ind])
            print("Liver dice", liver_scores['dice'],
                  "Lesion dice", lesion_scores['dice'])
            results.append([vol_ind, liver_scores, lesion_scores])
        else:
            if args.save2dir:
                # reverse the order, because we prioritize tumor, liver then background.
                msk_pred = (volume_msk >= 0.5)[:, ::-1, ...]  # shape (N, 3, H, W)
                msk_pred = np.argmax(msk_pred, axis=1)  # shape (N, H, W) = (z, x, y)
                msk_pred = np.transpose(msk_pred, axes=(1, 2, 0))  # shape (x, y, z)
                # remember to correct 'direction' and np.transpose before the submission !!!
                if direction[vol_ind][0] == -1:  # x-axis
                    msk_pred = msk_pred[::-1, ...]
                if direction[vol_ind][1] == -1:  # y-axis
                    msk_pred = msk_pred[:, ::-1, :]
                if direction[vol_ind][2] == -1:  # z-axis
                    msk_pred = msk_pred[..., ::-1]
                # save medical image header as well
                # see: http://loli.github.io/medpy/generated/medpy.io.header.Header.html
                file_header = med_header(spacing=tuple(spacing[vol_ind]),
                                         offset=tuple(offset[vol_ind]),
                                         direction=np.diag(direction[vol_ind]))
                # submission guide:
                # see: https://github.com/PatrickChrist/LITS-CHALLENGE/blob/master/submission-guide.md
                # test-segmentation-X.nii
                filepath = os.path.join(args.save2dir,
                                        f"test-segmentation-{vol_ind}.nii")
                med_save(msk_pred, filepath, hdr=file_header)

    if args.save2dir:
        # NOTE(review): in pure inference mode `results` is empty, so this
        # writes an empty results.pkl — kept for behavior parity.
        outpath = os.path.join(args.save2dir, "results.pkl")
        with open(outpath, "wb") as file:
            final_result = {}
            final_result['liver'] = defaultdict(list)
            final_result['tumor'] = defaultdict(list)
            for vol_ind, liver_scores, lesion_scores in results:
                # [OTC] assuming vol_ind is continuous
                for key in liver_scores:
                    final_result['liver'][key].append(liver_scores[key])
                for key in lesion_scores:
                    final_result['tumor'][key].append(lesion_scores[key])
            pickle.dump(final_result, file, protocol=3)
    printGreen(f"Total elapsed time: {time.time()-st}")
    return results
def train(args):
    """Train a 2-channel decomposition model (U-Net or simpleConvNet) on CONRADataset.

    Workflow: read hyperparameters from ``args``, optionally normalize the label
    space with a dataset-wide mean/std, resume from the newest checkpoint in the
    run's weight directory if one exists, then alternate training iterations with
    per-epoch validation (MSE loss, SSIM/R metrics written to a CSV) and
    checkpoint/tensorboard logging.

    Relies on module-level helpers not visible here: CONRADataset,
    computeMeanStdOverDataset, calculate_loss, performance, advanvedMetrics,
    count_trainables, simpleConvNet, UNet, and a module-level ``logger``
    (presumably a tensorboard SummaryWriter or None — TODO confirm).

    Args:
        args: argparse-style namespace; the attribute reads below document the
            expected fields (epochs, iterations, lr, momentum, batch_size,
            test_batch_size, workers, data, testset, run, model, name, tb,
            valX, valY, norm).
    """
    '''
    -------------------------Hyperparameters--------------------------
    '''
    EPOCHS = args.epochs
    START = 0  # could enter a checkpoint start epoch
    ITER = args.iterations  # per epoch
    LR = args.lr
    MOM = args.momentum
    # LOGInterval = args.log_interval
    BATCHSIZE = args.batch_size
    TEST_BATCHSIZE = args.test_batch_size
    NUMBER_OF_WORKERS = args.workers
    DATA_FOLDER = args.data
    TESTSET_FOLDER = args.testset
    ROOT = args.run
    WEIGHT_DIR = os.path.join(ROOT, "weights")
    CUSTOM_LOG_DIR = os.path.join(ROOT, "additionalLOGS")
    # checkpoint filename pattern: <model><name>.pt (a step suffix is inserted
    # before ".pt" when saving per-epoch checkpoints below)
    CHECKPOINT = os.path.join(WEIGHT_DIR, str(args.model) + str(args.name) + ".pt")
    useTensorboard = args.tb

    # check existance of data
    if not os.path.isdir(DATA_FOLDER):
        print("data folder not existant or in wrong layout.\n\t", DATA_FOLDER)
        exit(0)
    # check existance of testset
    # NOTE(review): this message prints DATA_FOLDER, not TESTSET_FOLDER —
    # looks like a copy/paste slip; verify before relying on the log output.
    if TESTSET_FOLDER is not None and not os.path.isdir(TESTSET_FOLDER):
        print("testset folder not existant or in wrong layout.\n\t", DATA_FOLDER)
        exit(0)

    '''
    ---------------------------preparations---------------------------
    '''
    # CUDA for PyTorch
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda:0" if use_cuda else "cpu")
    print("using device: ", str(device))

    # loading the validation samples to make online evaluations
    path_to_valX = args.valX
    path_to_valY = args.valY
    valX = None
    valY = None
    # only use the fixed validation pair when both tensor files exist on disk
    if path_to_valX is not None and path_to_valY is not None \
            and os.path.exists(path_to_valX) and os.path.exists(path_to_valY) \
            and os.path.isfile(path_to_valX) and os.path.isfile(path_to_valY):
        with torch.no_grad():
            valX, valY = torch.load(path_to_valX, map_location='cpu'), \
                torch.load(path_to_valY, map_location='cpu')

    '''
    ---------------------------loading dataset and normalizing---------------------------
    '''
    # Dataloader Parameters
    train_params = {
        'batch_size': BATCHSIZE,
        'shuffle': True,
        'num_workers': NUMBER_OF_WORKERS
    }
    test_params = {
        'batch_size': TEST_BATCHSIZE,
        'shuffle': False,
        'num_workers': NUMBER_OF_WORKERS
    }

    # create a folder for the weights and custom logs
    if not os.path.isdir(WEIGHT_DIR):
        os.makedirs(WEIGHT_DIR)
    if not os.path.isdir(CUSTOM_LOG_DIR):
        os.makedirs(CUSTOM_LOG_DIR)

    labelsNorm = None
    # NORMLABEL
    # normalizing on a trainingset wide mean and std
    mean = None
    std = None
    if args.norm:
        print('computing mean and std over trainingset')
        # computes mean and std over all ground truths in dataset to tackle the problem of numerical insignificance
        mean, std = computeMeanStdOverDataset('CONRADataset', DATA_FOLDER,
                                              train_params, device)
        # assumes mean/std have two entries: index 0 = iodine, 1 = water —
        # TODO confirm against computeMeanStdOverDataset
        print('\niodine (mean/std): {}\t{}'.format(mean[0], std[0]))
        print('water (mean/std): {}\t{}\n'.format(mean[1], std[1]))
        # normalize by std only (mean kept at 0) so the label scale shrinks
        # without shifting the zero point
        labelsNorm = transforms.Normalize(mean=[0, 0], std=std)
        # recompute statistics after normalization as a sanity check
        m2, s2 = computeMeanStdOverDataset('CONRADataset', DATA_FOLDER,
                                           train_params, device,
                                           transform=labelsNorm)
        print("new mean and std are:")
        print('\nnew iodine (mean/std): {}\t{}'.format(m2[0], s2[0]))
        print('new water (mean/std): {}\t{}\n'.format(m2[1], s2[1]))

    traindata = CONRADataset(DATA_FOLDER, True, device=device, precompute=True,
                             transform=labelsNorm)
    testdata = None
    if TESTSET_FOLDER is not None:
        # dedicated test directory provided
        testdata = CONRADataset(TESTSET_FOLDER, False, device=device,
                                precompute=True, transform=labelsNorm)
    else:
        # fall back to the test split of the training directory
        testdata = CONRADataset(DATA_FOLDER, False, device=device,
                                precompute=True, transform=labelsNorm)

    trainingset = DataLoader(traindata, **train_params)
    testset = DataLoader(testdata, **test_params)

    '''
    ----------------loading model and checkpoints---------------------
    '''
    if args.model == "unet":
        m = UNet(2, 2).to(device)
        print("using the U-Net architecture with {} trainable params; Good Luck!"
              .format(count_trainables(m)))
    else:
        m = simpleConvNet(2, 2).to(device)

    o = optim.SGD(m.parameters(), lr=LR, momentum=MOM)
    loss_fn = nn.MSELoss()

    test_loss = None
    train_loss = None

    # resume from a checkpoint if the weight directory is non-empty
    if len(os.listdir(WEIGHT_DIR)) != 0:
        checkpoints = os.listdir(WEIGHT_DIR)
        checkDir = {}
        latestCheckpoint = 0
        for i, checkpoint in enumerate(checkpoints):
            # filenames look like <model><name><step>.pt; extract the step
            stepOfCheckpoint = int(
                checkpoint.split(str(args.model) + str(args.name))[-1].split('.pt')[0])
            checkDir[stepOfCheckpoint] = checkpoint
            latestCheckpoint = max(latestCheckpoint, stepOfCheckpoint)
            print("[{}] {}".format(stepOfCheckpoint, checkpoint))
        # if on development machine, prompt for input, else just take the most recent one
        # NOTE(review): 'faui' in the hostname presumably identifies the dev
        # machines of this project — confirm before reusing elsewhere.
        if 'faui' in os.uname()[1]:
            toUse = int(input("select checkpoint to use: "))
        else:
            toUse = latestCheckpoint
        checkpoint = torch.load(os.path.join(WEIGHT_DIR, checkDir[toUse]))
        m.load_state_dict(checkpoint['model_state_dict'])
        m.to(device)  # pushing weights to gpu
        o.load_state_dict(checkpoint['optimizer_state_dict'])
        train_loss = checkpoint['train_loss']
        test_loss = checkpoint['test_loss']
        START = checkpoint['epoch']
        print("using checkpoint {}:\n\tloss(train/test): {}/{}".format(
            toUse, train_loss, test_loss))
    else:
        print("starting from scratch")

    '''
    -----------------------------training-----------------------------
    '''
    global_step = 0
    # calculating initial loss
    if test_loss is None or train_loss is None:
        print("calculating initial loss")
        m.eval()
        print("testset...")
        test_loss = calculate_loss(set=testset, loss_fn=loss_fn,
                                   length_set=len(testdata), dev=device, model=m)
        print("trainset...")
        train_loss = calculate_loss(set=trainingset, loss_fn=loss_fn,
                                    length_set=len(traindata), dev=device, model=m)

    ## SSIM and R value
    R = []
    SSIM = []
    performanceFLE = os.path.join(CUSTOM_LOG_DIR, "performance.csv")
    # 'w+' truncates: the metrics CSV is rewritten on every run/resume
    with open(performanceFLE, 'w+') as f:
        f.write(
            "step, SSIMiodine, SSIMwater, Riodine, Rwater, train_loss, test_loss\n"
        )
    print("computing ssim and r coefficents to: {}".format(performanceFLE))

    # printing runtime information
    print(
        "starting training at {} for {} epochs {} iterations each\n\t{} total".
        format(START, EPOCHS, ITER, EPOCHS * ITER))
    print("\tbatchsize: {}\n\tloss: {}\n\twill save results to \"{}\"".format(
        BATCHSIZE, train_loss, CHECKPOINT))
    print(
        "\tmodel: {}\n\tlearningrate: {}\n\tmomentum: {}\n\tnorming output space: {}"
        .format(args.model, LR, MOM, args.norm))

    #start actual training loops
    for e in range(START, START + EPOCHS):
        # iterations will not be interupted with validation and metrics
        for i in range(ITER):
            global_step = (e * ITER) + i
            # training
            m.train()
            iteration_loss = 0
            for x, y in tqdm(trainingset):
                x, y = x.to(device=device, dtype=torch.float), y.to(
                    device=device, dtype=torch.float)
                pred = m(x)
                loss = loss_fn(pred, y)
                iteration_loss += loss.item()
                # standard SGD step: clear grads, backprop, update
                o.zero_grad()
                loss.backward()
                o.step()
            print("\niteration {}: --accumulated loss {}".format(
                global_step, iteration_loss))

        # validation, saving and logging (once per epoch, after ITER passes)
        print("\nvalidating")
        m.eval()  # disable dropout batchnorm etc
        print("testset...")
        test_loss = calculate_loss(set=testset, loss_fn=loss_fn,
                                   length_set=len(testdata), dev=device, model=m)
        print("trainset...")
        train_loss = calculate_loss(set=trainingset, loss_fn=loss_fn,
                                    length_set=len(traindata), dev=device, model=m)
        print("calculating SSIM and R coefficients")
        currSSIM, currR = performance(set=testset, dev=device, model=m,
                                      bs=TEST_BATCHSIZE)
        print("SSIM (iod/water): {}/{}\nR (iod/water): {}/{}".format(
            currSSIM[0], currSSIM[1], currR[0], currR[1]))
        # append one metrics row per epoch to the performance CSV
        with open(performanceFLE, 'a') as f:
            newCSVline = "{}, {}, {}, {}, {}, {}, {}\n".format(
                global_step, currSSIM[0], currSSIM[1], currR[0], currR[1],
                train_loss, test_loss)
            f.write(newCSVline)
            print("wrote new line to csv:\n\t{}".format(newCSVline))

        '''
        if valX and valY were set in preparations, use them to perform analytics.
        if not, use the first sample from the testset to perform analytics
        '''
        with torch.no_grad():
            truth, pred = None, None
            # one image-log directory per global step
            IMAGE_LOG_DIR = os.path.join(CUSTOM_LOG_DIR, str(global_step))
            if not os.path.isdir(IMAGE_LOG_DIR):
                os.makedirs(IMAGE_LOG_DIR)
            if valX is not None and valY is not None:
                # pad the single validation sample into a full batch (slot 0)
                batched = np.zeros((BATCHSIZE, *valX.numpy().shape))
                batched[0] = valX.numpy()
                batched = torch.from_numpy(batched).to(device=device,
                                                       dtype=torch.float)
                pred = m(batched)
                pred = pred.cpu().numpy()[0]
                truth = valY.numpy()  # still on cpu
                assert pred.shape == truth.shape
            else:
                # NOTE(review): this loop has no break, so despite the comments
                # it ends up using the first sample of the *last* test batch —
                # confirm whether a `break` after the first batch was intended.
                for x, y in testset:
                    # x, y in shape[2,2,480,620] [b,c,h,w]
                    x, y = x.to(device=device, dtype=torch.float), y.to(
                        device=device, dtype=torch.float)
                    pred = m(x)
                    pred = pred.cpu().numpy()[
                        0]  # taking only the first sample of batch
                    truth = y.cpu().numpy()[
                        0]  # first projection for evaluation
            advanvedMetrics(truth, pred, mean, std, global_step, args.norm,
                            IMAGE_LOG_DIR)

        print("logging")
        # rebind CHECKPOINT to a per-step filename for this epoch's snapshot
        CHECKPOINT = os.path.join(
            WEIGHT_DIR,
            str(args.model) + str(args.name) + str(global_step) + ".pt")
        torch.save(
            {
                'epoch': e + 1,  # end of this epoch; so resume at next.
                'model_state_dict': m.state_dict(),
                'optimizer_state_dict': o.state_dict(),
                'train_loss': train_loss,
                'test_loss': test_loss
            }, CHECKPOINT)
        print('\tsaved weigths to: ', CHECKPOINT)
        # `logger` is presumably a module-level tensorboard writer (or None
        # when tensorboard is disabled) — TODO confirm where it is defined.
        if logger is not None and train_loss is not None:
            logger.add_scalar('test_loss', test_loss, global_step=global_step)
            logger.add_scalar('train_loss', train_loss, global_step=global_step)
            # channel 0 = iodine, channel 1 = water; 480x620 matches the
            # projection shape noted above
            logger.add_image("iodine-prediction",
                             pred[0].reshape(1, 480, 620),
                             global_step=global_step)
            logger.add_image("water-prediction",
                             pred[1].reshape(1, 480, 620),
                             global_step=global_step)
            # logger.add_image("water-prediction", wat)
            print(
                "\ttensorboard updated with test/train loss and a sample image"
            )
        elif train_loss is not None:
            print("\tloss of global-step {}: {}".format(
                global_step, train_loss))
        elif not useTensorboard:
            print("\t(tb-logging disabled) test/train loss: {}/{} ".format(
                test_loss, train_loss))
        else:
            print("\tno loss accumulated yet")

    # saving final results (CHECKPOINT now points at the last per-step file)
    print("saving upon exit")
    torch.save(
        {
            'epoch': EPOCHS,
            'model_state_dict': m.state_dict(),
            'optimizer_state_dict': o.state_dict(),
            'train_loss': train_loss,
            'test_loss': test_loss
        }, CHECKPOINT)
    print('\tsaved progress to: ', CHECKPOINT)
    if logger is not None and train_loss is not None:
        logger.add_scalar('test_loss', test_loss, global_step=global_step)
        logger.add_scalar('train_loss', train_loss, global_step=global_step)
def main(args):
    """Evaluate a sparse-convolution UNet on the semantic-KITTI validation split.

    Sets up experiment/log directories, loads the best checkpoint, shards each
    batch across ``args.num_gpus`` devices via module replication, and reports
    per-batch and aggregate accuracy / IoU (overall and per class).

    Depends on module-level names not visible here: KittiDataset, UNet, nPlanes,
    num_classes, ignore, input_layer, scale, spatialSize, parallel, ME,
    Variable, iouEval, class_strings, class_inv_remap, parse_args — verify
    their definitions before modifying this function.
    """
    def log_string(str):
        # mirror messages to stdout; file logging line kept disabled
        # logger.info(str)
        print(str)

    '''HYPER PARAMETER'''
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu

    '''CREATE DIR'''
    timestr = str(datetime.datetime.now().strftime('%Y-%m-%d_%H-%M'))
    experiment_dir = Path('./log/')
    experiment_dir.mkdir(exist_ok=True)
    experiment_dir = experiment_dir.joinpath('part_seg')
    experiment_dir.mkdir(exist_ok=True)
    # timestamped run dir unless an explicit log dir was requested
    if args.log_dir is None:
        experiment_dir = experiment_dir.joinpath(timestr)
    else:
        experiment_dir = experiment_dir.joinpath(args.log_dir)
    experiment_dir.mkdir(exist_ok=True)
    checkpoints_dir = experiment_dir.joinpath('checkpoints/')
    checkpoints_dir.mkdir(exist_ok=True)
    log_dir = experiment_dir.joinpath('logs/')
    log_dir.mkdir(exist_ok=True)

    '''LOG'''
    # NOTE(review): this overwrites the `args` parameter passed into main() —
    # confirm this re-parse is intentional.
    args = parse_args()
    logger = logging.getLogger("Model")
    logger.setLevel(logging.INFO)
    formatter = logging.Formatter(
        '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    file_handler = logging.FileHandler('%s/%s.txt' % (log_dir, args.model))
    file_handler.setLevel(logging.INFO)
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)
    log_string('PARAMETER ...')
    log_string(args)

    # hard-coded machine-specific dataset paths
    root = '/media/feihu/Storage/kitti_point_cloud/semantic_kitti/'
    # file_list = '/media/feihu/Storage/kitti_point_cloud/semantic_kitti/train2.list'
    val_list = '/media/feihu/Storage/kitti_point_cloud/semantic_kitti/val2.list'
    # TRAIN_DATASET = KittiDataset(root = root, file_list=file_list, npoints=args.npoint, training=True, augment=True)
    # trainDataLoader = torch.utils.data.DataLoader(TRAIN_DATASET, batch_size=args.batch_size, shuffle=True, drop_last=True, num_workers=2)
    TEST_DATASET = KittiDataset(root=root, file_list=val_list,
                                npoints=args.npoint, training=False,
                                augment=False)
    # drop_last=True keeps batches evenly divisible for the GPU sharding below
    testDataLoader = torch.utils.data.DataLoader(TEST_DATASET,
                                                 batch_size=args.batch_size,
                                                 shuffle=False, drop_last=True,
                                                 num_workers=2)
    # log_string("The number of training data is: %d" % len(TRAIN_DATASET))
    log_string("The number of test data is: %d" % len(TEST_DATASET))

    # num_classes = 16
    num_devices = args.num_gpus  #torch.cuda.device_count()
    # assert num_devices > 1, "Cannot detect more than 1 GPU."
    # print(num_devices)
    devices = list(range(num_devices))
    target_device = devices[0]
    # MODEL = importlib.import_module(args.model)
    # `nPlanes` is a module-level constant — TODO confirm its definition
    net = UNet(4, 20, nPlanes)
    # net = MODEL.get_model(num_classes, normal_channel=args.normal)
    net = net.to(target_device)
    try:
        checkpoint = torch.load(
            str(experiment_dir) + '/checkpoints/best_model.pth')
        start_epoch = checkpoint['epoch']
        net.load_state_dict(checkpoint['model_state_dict'])
        log_string('Use pretrain model')
    # NOTE(review): bare except hides the real failure (missing file,
    # corrupt checkpoint, key error) — consider catching specific exceptions.
    except:
        log_string('No existing model, starting training from scratch...')
        quit()

    if 1:
        with torch.no_grad():
            net.eval()
            # `evaluator` accumulates over all batches; `evaone` below is reset
            # per batch for per-batch reporting
            evaluator = iouEval(num_classes, ignore)
            evaluator.reset()
            # for iteration, (points, target, ins, mask) in tqdm(enumerate(testDataLoader), total=len(testDataLoader), smoothing=0.9):
            for iteration, (points, target, ins,
                            mask) in enumerate(testDataLoader):
                evaone = iouEval(num_classes, ignore)
                evaone.reset()
                cur_batch_size, NUM_POINT, _ = points.size()
                # cap evaluation at 129 batches
                if iteration > 128:
                    break
                inputs, targets, masks = [], [], []
                coords = []
                # shard the batch evenly across the available devices
                for i in range(num_devices):
                    start = int(i * (cur_batch_size / num_devices))
                    end = int((i + 1) * (cur_batch_size / num_devices))
                    with torch.cuda.device(devices[i]):
                        pc = points[start:end, :, :].to(devices[i])
                        #feas = points[start:end,:,3:].to(devices[i])
                        targeti = target[start:end, :].to(devices[i])
                        maski = mask[start:end, :].to(devices[i])
                        # voxelize/quantize the point cloud for the sparse net
                        locs, feas, label, maski, offsets = input_layer(
                            pc, targeti, maski, scale.to(devices[i]),
                            spatialSize.to(devices[i]), True)
                        # print(locs.size(), feas.size(), label.size(), maski.size(), offsets.size())
                        org_coords = locs[1]
                        label = Variable(label, requires_grad=False)
                        inputi = ME.SparseTensor(feas.cpu(), locs[0].cpu())
                        inputs.append([inputi.to(devices[i]), org_coords])
                        targets.append(label)
                        masks.append(maski)
                # replicate the model and run each shard on its own device
                replicas = parallel.replicate(net, devices)
                outputs = parallel.parallel_apply(replicas, inputs,
                                                  devices=devices)
                # gather per-device results back onto the CPU
                seg_pred = outputs[0].cpu()
                mask = masks[0].cpu()
                target = targets[0].cpu()
                # NOTE(review): `locs` here is the value from the *last* device
                # of the sharding loop — confirm that is intended.
                loc = locs[0].cpu()
                for i in range(1, num_devices):
                    seg_pred = torch.cat((seg_pred, outputs[i].cpu()), 0)
                    mask = torch.cat((mask, masks[i].cpu()), 0)
                    target = torch.cat((target, targets[i].cpu()), 0)
                # drop points with label <= 0 (presumably unlabeled/ignore —
                # TODO confirm label convention)
                seg_pred = seg_pred[target > 0, :]
                target = target[target > 0]
                _, seg_pred = seg_pred.data.max(1)  #[1]
                target = target.data.numpy()
                evaluator.addBatch(seg_pred, target)
                evaone.addBatch(seg_pred, target)
                cur_accuracy = evaone.getacc()
                cur_jaccard, class_jaccard = evaone.getIoU()
                print('%.4f %.4f' % (cur_accuracy, cur_jaccard))
            # aggregate metrics over all evaluated batches
            m_accuracy = evaluator.getacc()
            m_jaccard, class_jaccard = evaluator.getIoU()
            log_string('Validation set:\n'
                       'Acc avg {m_accuracy:.3f}\n'
                       'IoU avg {m_jaccard:.3f}'.format(m_accuracy=m_accuracy,
                                                        m_jaccard=m_jaccard))
            # print also classwise
            for i, jacc in enumerate(class_jaccard):
                if i not in ignore:
                    log_string(
                        'IoU class {i:} [{class_str:}] = {jacc:.3f}'.format(
                            i=i,
                            class_str=class_strings[class_inv_remap[i]],
                            jacc=jacc))
def train():
    """Train a UNet for (apparently) right-ventricle segmentation with a
    random train/validation split, combined loss + optional L1/L1L2
    regularization, LR scheduling, periodic sample-image dumps, periodic
    checkpointing, and final loss-curve plotting.

    All configuration comes from ``parameters.parse_arguments()``; depends on
    module-level names not visible here: DataLoader (a custom dataset class,
    not torch's), UNet, CombinedLoss, DiceLoss, L1_regularization,
    L1L2_regularization.
    """
    startTime = time.time()
    args = parameters.parse_arguments()
    logging.basicConfig(filename=args.logfile, level=logging.INFO)
    # CRITICAL level used so the run header always appears in the log file
    logging.critical("\n\n" + args.log_header)
    logging.info(args)
    device = ("cuda" if torch.cuda.is_available() else "cpu")
    logging.info(f"TIME: {time.time() - startTime}s Using device {device}")
    logging.info(f"TIME: {time.time()-startTime}s Loading dataset")
    # load a pickled dataset cache if present; otherwise build and cache it.
    # NOTE(review): bare except swallows *any* failure (corrupt pickle,
    # permission error), silently rebuilding — consider narrowing to
    # (FileNotFoundError, pickle.UnpicklingError).
    try:
        with open(os.path.join(args.datadir, "data.pkl"), "rb") as f:
            data = pickle.load(f)
    except:
        data = DataLoader(args.datadir, int(args.batchsize),
                          shuffle=int(args.shuffle))
        with open(os.path.join(args.datadir, "data.pkl"), "wb") as f:
            pickle.dump(data, f)
    # batch size of the cached loader may differ from the current run's
    data.batchSize = int(args.batchsize)
    logging.info(f"TIME: {time.time()-startTime}s Dataset Loaded")
    random.seed(args.seed)
    indices = list(range(len(data)))
    random.shuffle(
        indices
    )  # 0:floor((1-validationFrac)*len(data)) will be training data, rest will be validation data
    trainEndIndex = math.floor((1 - args.validation_frac) * (len(data)))
    model = UNet(in_channels=1, num_classes=2,
                 start_filts=int(args.conv_filters), up_mode=args.mode,
                 depth=int(args.depth), batchnorm=args.batchnorm)
    model.reset_params()
    model = model.to(device)
    optimizer = None
    if args.optimizer == 'adam':
        optimizer = optim.Adam(model.parameters(), lr=args.lrstart)
        logging.info(f"TIME: {time.time()-startTime}s Optimizer: adam")
    elif args.optimizer == 'sgd':
        optimizer = optim.SGD(model.parameters(), lr=args.lrstart,
                              momentum=args.momentum)
        logging.info(f"TIME: {time.time()-startTime}s Optimizer: SGD")
    elif args.optimizer == 'rmsprop':
        optimizer = optim.RMSprop(model.parameters(), lr=args.lrstart)
        logging.info(f"TIME: {time.time()-startTime}s Optimizer: RMSProp")
    else:
        # NOTE(review): optimizer stays None here; the training loop below
        # would crash on optimizer.step() — confirm whether an early exit
        # was intended.
        logging.error(
            f"TIME: {time.time()-startTime}s Incorrect optimizer given")
    scheduler = []
    if args.lrscheduler == "steplr":
        scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=args.decay)
        logging.info(f"TIME: {time.time()-startTime}s LRScheduler: StepLR")
    elif args.lrscheduler == "exponentiallr":
        scheduler = optim.lr_scheduler.ExponentialLR(optimizer,
                                                     gamma=args.decay)
        logging.info(
            f"TIME: {time.time()-startTime}s LRScheduler: exponentialLR")
    else:
        # step_size == total epochs -> effectively constant learning rate
        scheduler = optim.lr_scheduler.StepLR(optimizer,
                                              step_size=int(args.epochs))
        logging.info(
            f"TIME: {time.time()-startTime}s LRScheduler: lr shouldn't change with epochs"
        )
    criteria = CombinedLoss(args.lambda_loss, args.loss_type)
    diceCoeff = DiceLoss()
    TL = []  # per-epoch training losses (for the final plot)
    VL = []  # per-epoch validation losses (for the final plot)
    if not os.path.exists(os.path.join(os.getcwd(), "loss_files")):
        os.makedirs(os.path.join(os.getcwd(), "loss_files"))
    lossFile = open(os.path.join("loss_files", args.log_header + ".csv"), "w+")
    lossFile.write("Epoch,TrainLoss,ValidationLoss,Dice Coefficient\n")
    for epoch in tqdm(range(1, int(args.epochs) + 1), desc="Training model"):
        trainLoss = 0
        valLoss = 0
        trainingSample = 0  # number of images seen in training this epoch
        testSample = 0      # number of images seen in validation this epoch
        netCoeff = 0        # accumulated dice coefficient over validation
        for i in range(len(data)):
            images, masks = data[i]
            images = torch.tensor(images.astype(np.float32))
            masks = torch.tensor(masks.astype(np.float32))
            images = images.to(device)
            masks = masks.to(device)
            # swap axes 1 and 3 — presumably NHWC -> NCHW layout; TODO confirm
            # against the dataset's array layout
            images = torch.transpose(images, 1, 3)
            masks = torch.transpose(masks, 1, 3)
            # first trainEndIndex shuffled indices are the training split
            if i in indices[:trainEndIndex]:
                trainingSample += images.shape[0]
                networkPred = model(images)
                if args.regularization == 'l1':
                    reg = L1_regularization(model, args.reg_lamda1)
                    loss = criteria(masks, networkPred) + reg
                elif args.regularization == 'l1l2':
                    reg = L1L2_regularization(model, args.reg_lamda1,
                                              args.reg_lamda2)
                    loss = criteria(masks, networkPred) + reg
                else:
                    loss = criteria(masks, networkPred)
                # grads are zeroed after step(), so backward() here accumulates
                # onto a clean slate each batch
                loss.backward()
                trainLoss += loss.item()
                optimizer.step()
                model.zero_grad()
            else:
                # validation split: no gradients
                with torch.no_grad():
                    testSample += images.shape[0]
                    prediction = model(images)
                    # periodically dump sample images for visual inspection
                    if (epoch % args.save_epochs == 0) or (epoch == 1) or (
                            epoch == args.epochs):
                        imgPath = os.path.join("validation_sample",
                                               args.log_header,
                                               f"epoch {epoch}")
                        if not os.path.exists(imgPath):
                            os.makedirs(imgPath)
                        hrt = images[0, 0, :, :].to("cpu")
                        plt.imshow(np.array(hrt), cmap='gray')
                        plt.title("Heart Image")
                        plt.savefig(os.path.join(imgPath, "heart.png"))
                        plt.clf()
                        # ax = figure.add_subplot(232, title="Mask 1 Predicted")
                        msk1 = prediction[0, 0, :, :].to("cpu")
                        plt.imshow(np.array(msk1), cmap='gray')
                        plt.title("Predicted Mask 1")
                        plt.savefig(os.path.join(imgPath, "pred-mask1.png"))
                        plt.clf()
                        # ax = figure.add_subplot(231, title="Mask 2 Predicted")
                        msk2 = prediction[0, 1, :, :].to("cpu")
                        plt.imshow(np.array(msk2), cmap='gray')
                        plt.title("Predicted Mask 2")
                        plt.savefig(os.path.join(imgPath, "pred-mask2.png"))
                        plt.clf()
                        # overlay both predicted masks on the input image;
                        # 192x192 is the assumed image size — TODO confirm
                        msk = np.zeros((192, 192, 3))
                        msk[:, :, 0] = np.array(msk1)
                        msk[:, :, 1] = np.array(msk2)
                        plt.imshow(np.array(hrt), cmap='gray')
                        plt.imshow(msk, cmap='jet', alpha=0.4)
                        plt.title("predicted-RV")
                        plt.savefig(os.path.join(imgPath, "pred-RV.png"))
                        plt.clf()
                        # ax = figure.add_subplot(231, title="Mask 2 Real")
                        msk1 = masks[0, 0, :, :].to("cpu")
                        plt.imshow(np.array(msk1), cmap='gray')
                        plt.title("Actual Mask 1")
                        plt.savefig(os.path.join(imgPath, "actual-mask1.png"))
                        plt.clf()
                        # ax = figure.add_subplot(231, title="Mask 2 Real")
                        msk2 = masks[0, 1, :, :].to("cpu")
                        plt.imshow(np.array(msk2), cmap='gray')
                        plt.title("Actual Mask 2")
                        plt.savefig(os.path.join(imgPath, "actual-mask2.png"))
                        plt.clf()
                        # plt.savefig(os.path.join("validation_sample", f"{args.log_header}-epoch {epoch}.png"))
                        msk = np.zeros((192, 192, 3))
                        msk[:, :, 0] = np.array(msk1)
                        msk[:, :, 1] = np.array(msk2)
                        plt.imshow(np.array(hrt), cmap='gray')
                        plt.imshow(msk, cmap='jet', alpha=0.4)
                        plt.title("actual-RV")
                        plt.savefig(os.path.join(imgPath, "actual-RV.png"))
                        plt.clf()
                    # same loss composition as the training branch, for a
                    # comparable validation loss
                    if args.regularization == 'l1':
                        reg = L1_regularization(model, args.reg_lamda1)
                        loss = criteria(masks, prediction) + reg
                    elif args.regularization == 'l1l2':
                        reg = L1L2_regularization(model, args.reg_lamda1,
                                                  args.reg_lamda2)
                        loss = criteria(masks, prediction) + reg
                    else:
                        loss = criteria(masks, prediction)
                    valLoss += loss.item()
                    # DiceLoss returns 1 - dice, so (1 - coeff) recovers the
                    # dice coefficient — presumably; TODO confirm DiceLoss
                    coeff = diceCoeff(masks, prediction)
                    netCoeff += torch.sum(1 - coeff).item()
        # periodic checkpointing
        if (epoch % int(args.save_epochs) == 0) or (epoch == int(args.epochs)):
            if not os.path.exists(args.model_save_dir):
                os.makedirs(args.model_save_dir)
            # save model
            torch.save(
                {
                    "epoch": epoch,
                    "model_state_dict": model.state_dict(),
                    "optimizer_state_dict": optimizer.state_dict(),
                },
                os.path.join(args.model_save_dir,
                             f"model-epoch({epoch}).hdf5"))
            logging.info(
                f"TIME: {time.time()-startTime}s Model state saved for epoch: {epoch}"
            )
        # losses are normalized by 2 * sample count — presumably because each
        # sample has two mask channels; TODO confirm the intended scaling
        logging.info(
            f"TIME: {time.time()-startTime}s TRAINING: Epoch: {epoch}, lr: {scheduler.get_last_lr()}, loss: {trainLoss/(2*trainingSample)}"
        )
        logging.info(
            f"TIME: {time.time()-startTime}s VALIDATION: Epoch: {epoch}, lr: {scheduler.get_last_lr()}, loss: {valLoss/(2*testSample)}"
        )
        TL.append(trainLoss / (2 * trainingSample))
        VL.append(valLoss / (2 * testSample))
        lossFile.write(
            f"{epoch},{trainLoss/(2*trainingSample)},{valLoss/(2*testSample)},{netCoeff/(2*testSample)}\n"
        )
        scheduler.step(
        )  # https://www.deeplearningwizard.com/deep_learning/boosting_models_pytorch/lr_scheduling/
    # final train/validation loss curves
    plt.plot(list(range(1, int(args.epochs) + 1)), TL, label="Training loss")
    plt.plot(list(range(1, int(args.epochs) + 1)), VL, label="Validation loss")
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.legend(loc="best")
    if not os.path.exists(os.path.join(os.getcwd(), "plots")):
        os.makedirs(os.path.join(os.getcwd(), "plots"))
    plt.savefig(os.path.join("plots", args.log_header + ".png"))
batch_size=1, shuffle=False) partition = 'train' unet_train = HistologyData(ROOT_DIR, partition, True) unet_loader = torch.utils.data.DataLoader( unet_train, batch_size=1, shuffle=True, ) # Create model model = ShapeUNet((15, 512, 512)) unet = UNet((3, 512, 512)) model.to(device) unet.to(device) mask_values = (0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14) # here not RGB but BGR because of OPENCV. real_colors = ((0, 0, 0), (255, 0, 0), (0, 255, 0), (0, 0, 255), (85, 0, 0), (0, 170, 0), (255, 0, 127), (0, 255, 255), (0, 85, 0), (255, 0, 255), (255, 85, 0), (255, 165, 0), (255, 255, 0), (128, 130, 128), (128, 190, 190)) lr = 1e-4 optimizer = Adam(model.parameters(), lr=lr) NUM_OF_EPOCHS = 40 lr1 = 1e-4 unet_optim = Adam(unet.parameters(), lr=lr1) train_network_on_top_of_other(model, train_loader, val_loader, optimizer, unet,
def train_ei_adv(self, dataloader, physics, transform, epochs, lr, alpha,
                 ckp_interval, schedule, residual=True, pretrained=None,
                 task='', loss_type='l2', cat=True, report_psnr=False,
                 lr_cos=False):
    """Adversarially train an equivariant-imaging (EI) UNet generator against
    a discriminator, logging per-epoch losses and checkpointing periodically.

    This is a method of an enclosing class (not visible here) that provides
    self.in_channels, self.out_channels, self.img_width, self.img_height,
    self.device and self.dtype. The actual optimization step is delegated to
    the module-level ``closure_ei_adv`` helper.

    Args:
        dataloader: training data iterator (consumed by closure_ei_adv).
        physics: forward-operator model — semantics defined by closure_ei_adv.
        transform: group transform used for the EI consistency term —
            presumably; TODO confirm against closure_ei_adv.
        epochs: total number of training epochs.
        lr: dict with keys 'G', 'D' (learning rates) and 'WD' (generator
            weight decay).
        alpha: loss-weighting factor(s) passed to closure_ei_adv.
        ckp_interval: save a checkpoint every this many epochs (plus the
            final epoch).
        schedule: LR schedule passed to adjust_learning_rate.
        residual: whether the UNet uses residual connections.
        pretrained: optional path to a generator checkpoint to warm-start from.
        task: tag embedded in the checkpoint directory name.
        loss_type: 'l2' or 'l1'. NOTE(review): any other value leaves
            criterion_mc/criterion_ei undefined and raises NameError inside
            the loop — consider validating eagerly.
        cat: UNet skip-connection mode flag (concat when True — presumably).
        report_psnr: if True, closure also returns psnr/mse and they are logged.
        lr_cos: use cosine LR decay in adjust_learning_rate.
    """
    # one timestamped checkpoint directory per run
    save_path = './ckp/{}_ei_adv_{}'.format(get_timestamp(), task)
    os.makedirs(save_path, exist_ok=True)

    generator = UNet(in_channels=self.in_channels,
                     out_channels=self.out_channels,
                     compact=4,
                     residual=residual,
                     circular_padding=True,
                     cat=cat)
    if pretrained:
        checkpoint = torch.load(pretrained)
        generator.load_state_dict(checkpoint['state_dict'])
    discriminator = Discriminator(
        (self.in_channels, self.img_width, self.img_height))
    generator = generator.to(self.device)
    discriminator = discriminator.to(self.device)

    # measurement-consistency (mc) and equivariance (ei) criteria share type
    if loss_type == 'l2':
        criterion_mc = torch.nn.MSELoss().to(self.device)
        criterion_ei = torch.nn.MSELoss().to(self.device)
    if loss_type == 'l1':
        criterion_mc = torch.nn.L1Loss().to(self.device)
        criterion_ei = torch.nn.L1Loss().to(self.device)
    # LSGAN-style adversarial objective (MSE on discriminator outputs)
    criterion_gan = torch.nn.MSELoss().to(self.device)

    optimizer_G = Adam(generator.parameters(), lr=lr['G'],
                       weight_decay=lr['WD'])
    optimizer_D = Adam(discriminator.parameters(), lr=lr['D'], weight_decay=0)

    # CSV-style loss log; extra psnr/mse columns only when requested
    if report_psnr:
        log = LOG(save_path, filename='training_loss', field_name=[
            'epoch', 'loss_mc', 'loss_ei', 'loss_g', 'loss_G', 'loss_D',
            'psnr', 'mse'
        ])
    else:
        log = LOG(save_path, filename='training_loss', field_name=[
            'epoch', 'loss_mc', 'loss_ei', 'loss_g', 'loss_G', 'loss_D'
        ])

    for epoch in range(epochs):
        adjust_learning_rate(optimizer_G, epoch, lr['G'], lr_cos, epochs,
                             schedule)
        adjust_learning_rate(optimizer_D, epoch, lr['D'], lr_cos, epochs,
                             schedule)
        # closure performs the full G/D update for one epoch and returns the
        # loss tuple matching the log field names above
        loss = closure_ei_adv(generator, discriminator, dataloader, physics,
                              transform, optimizer_G, optimizer_D,
                              criterion_mc, criterion_ei, criterion_gan,
                              alpha, self.dtype, self.device, report_psnr)

        log.record(epoch + 1, *loss)

        if report_psnr:
            print(
                '{}\tEpoch[{}/{}]\tfc={:.4e}\tti={:.4e}\tg={:.4e}\tG={:.4e}\tD={:.4e}\tpsnr={:.4f}\tmse={:.4e}'
                .format(get_timestamp(), epoch, epochs, *loss))
        else:
            print(
                '{}\tEpoch[{}/{}]\tfc={:.4e}\tti={:.4e}\tg={:.4e}\tG={:.4e}\tD={:.4e}'
                .format(get_timestamp(), epoch, epochs, *loss))

        # checkpoint every ckp_interval epochs and always on the final epoch
        if epoch % ckp_interval == 0 or epoch + 1 == epochs:
            state = {
                'epoch': epoch,
                'state_dict_G': generator.state_dict(),
                'state_dict_D': discriminator.state_dict(),
                'optimizer_G': optimizer_G.state_dict(),
                'optimizer_D': optimizer_D.state_dict()
            }
            torch.save(
                state,
                os.path.join(save_path, 'ckp_{}.pth.tar'.format(epoch)))
    log.close()
def main():
    """Command-line entry point: train a UNet on the LiTS segmentation data.

    Parses CLI arguments, builds the train/validation dataloaders,
    constructs (and optionally pre-loads) the UNet, then runs the
    train/validation loop, saving the best model (lowest validation mean
    loss) under ``--log-dir``.

    Returns:
        float: the best (lowest) validation mean loss observed.
    """
    parser = argparse.ArgumentParser(description="Train the model")
    parser.add_argument('-trainf', "--train-filepath", type=str, default=None,
                        required=True, help="training dataset filepath.")
    parser.add_argument('-validf', "--val-filepath", type=str, default=None,
                        help="validation dataset filepath.")
    parser.add_argument("--shuffle", action="store_true", default=False,
                        help="Shuffle the dataset")
    parser.add_argument("--load-weights", type=str, default=None,
                        help="load pretrained weights")
    parser.add_argument("--load-model", type=str, default=None,
                        help="load pretrained model, entire model (filepath, default: None)")
    parser.add_argument("--debug", action="store_true", default=False)
    parser.add_argument('--epochs', type=int, default=30,
                        help='number of epochs to train (default: 30)')
    parser.add_argument("--batch-size", type=int, default=32,
                        help="Batch size")
    parser.add_argument('--img-shape', type=str, default="(1,512,512)",
                        help='Image shape (default "(1,512,512)"')
    parser.add_argument("--num-cpu", type=int, default=10,
                        help="Number of CPUs to use in parallel for dataloader.")
    parser.add_argument('--cuda', type=int, default=0,
                        help='CUDA visible device (use CPU if -1, default: 0)')
    parser.add_argument('--cuda-non-deterministic', action='store_true',
                        default=False,
                        help="sets flags for non-determinism when using CUDA (potentially fast)")
    parser.add_argument('-lr', type=float, default=0.0005,
                        help='Learning rate')
    parser.add_argument('--seed', type=int, default=0,
                        help='Seed (numpy and cuda if GPU is used.).')
    parser.add_argument('--log-dir', type=str, default=None,
                        help='Save the results/model weights/logs under the directory.')
    args = parser.parse_args()

    # TODO: support image reshape (currently parsed but unused)
    img_shape = tuple(map(int, args.img_shape.strip()[1:-1].split(",")))

    if args.log_dir:
        os.makedirs(args.log_dir, exist_ok=True)
        best_model_path = os.path.join(args.log_dir, "model_weights.pth")
    else:
        best_model_path = None

    # Reproducibility: seed numpy/torch (and CUDA when a GPU is selected).
    if args.seed is not None:
        np.random.seed(args.seed)
        torch.manual_seed(args.seed)
        if args.cuda >= 0:
            if args.cuda_non_deterministic:
                printBlue("Warning: using CUDA non-deterministc. Could be faster but "
                          "results might not be reproducible.")
            else:
                printBlue("Using CUDA deterministc. Use --cuda-non-deterministic "
                          "might accelerate the training a bit.")
            # Make CuDNN Determinist
            torch.backends.cudnn.deterministic = not args.cuda_non_deterministic
            # torch.cuda.manual_seed(args.seed)
            torch.cuda.manual_seed_all(args.seed)

    # TODO [OPT] enable multi-GPUs ?
    # https://pytorch.org/tutorials/beginner/former_torchies/parallelism_tutorial.html
    device = torch.device("cuda:{}".format(args.cuda)
                          if torch.cuda.is_available() and (args.cuda >= 0)
                          else "cpu")

    # ================= Build dataloader =================
    # transform_normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5],
    #                                            std=[0.5, 0.5, 0.5])
    transform_normalize = transforms.Normalize(mean=[0.5], std=[0.5])

    # Warning: DO NOT use geometry transform here (do it in the dataloader instead)
    data_transform = transforms.Compose([
        # transforms.ToPILImage(mode='F'),  # mode='F' for one-channel image
        # WARNING, ISSUE: transforms.ColorJitter doesn't work with
        # ToPILImage(mode='F'); custom augmentations (OpenCVRotation, ...)
        # are used via geo_transform below instead.
        transforms.ToTensor(),  # already done in the dataloader
        transform_normalize
    ])

    geo_transform = GeoCompose([
        OpenCVRotation(angles=(-10, 10),
                       scales=(0.9, 1.1),
                       centers=(-0.05, 0.05)),
        # TODO add more data augmentation here
    ])

    def worker_init_fn(worker_id):
        # WARNING spawn start method is used, so worker_init_fn cannot be an
        # unpicklable object, e.g. a lambda function.
        # A work-around for issue #5059: https://github.com/pytorch/pytorch/issues/5059
        np.random.seed()

    data_loader_train = {'batch_size': args.batch_size,
                         'shuffle': args.shuffle,
                         'num_workers': args.num_cpu,
                         # 'sampler': balanced_sampler,
                         'drop_last': True,  # for GAN-like
                         'pin_memory': False,
                         'worker_init_fn': worker_init_fn,
                         }
    data_loader_valid = {'batch_size': args.batch_size,
                         'shuffle': False,
                         'num_workers': args.num_cpu,
                         'drop_last': False,
                         'pin_memory': False,
                         }

    train_set = LiTSDataset(args.train_filepath,
                            dtype=np.float32,
                            geometry_transform=geo_transform,  # TODO enable data augmentation
                            pixelwise_transform=data_transform,
                            )
    valid_set = LiTSDataset(args.val_filepath,
                            dtype=np.float32,
                            pixelwise_transform=data_transform,
                            )

    dataloader_train = torch.utils.data.DataLoader(train_set, **data_loader_train)
    dataloader_valid = torch.utils.data.DataLoader(valid_set, **data_loader_valid)

    # =================== Build model ===================
    # TODO: control the model by bash command
    def build_unet():
        # One-line purpose: construct the default LiTS segmentation UNet.
        return UNet(in_ch=1,
                    out_ch=3,  # there are 3 classes: 0: background, 1: liver, 2: tumor
                    depth=4,
                    start_ch=32,  # 64
                    inc_rate=2,
                    kernel_size=5,  # 3
                    padding=True,
                    batch_norm=True,
                    spec_norm=False,
                    dropout=0.5,
                    up_mode='upconv',
                    include_top=True,
                    include_last_act=False,
                    )

    if args.load_weights:
        model = build_unet()
        printYellow(f"Loading pretrained weights from: {args.load_weights}...")
        model.load_state_dict(torch.load(args.load_weights))
        printYellow("+ Done.")
    elif args.load_model:
        # load entire model
        model = torch.load(args.load_model)
        printYellow("Successfully loaded pretrained model.")
    else:
        # FIX: previously `model` was only bound inside the two branches
        # above, so training from scratch (no --load-weights/--load-model)
        # crashed with an unbound `model` at model.to(device).
        model = build_unet()
    model.to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr,
                                 betas=(0.9, 0.95))  # TODO

    best_valid_loss = float('inf')
    # TODO TODO: add learning decay
    for epoch in range(args.epochs):
        # valid_mode == 0: training pass; valid_mode == 1: validation pass.
        for valid_mode, dataloader in enumerate([dataloader_train, dataloader_valid]):
            n_batch_per_epoch = len(dataloader)
            if args.debug:
                n_batch_per_epoch = 1

            # infinite dataloader allows several update per iteration
            # (for special models e.g. GAN)
            dataloader = infinite_dataloader(dataloader)
            if valid_mode:
                printYellow("Switch to validation mode.")
                model.eval()
                prev_grad_mode = torch.is_grad_enabled()
                torch.set_grad_enabled(False)
            else:
                model.train()

            st = time.time()
            cum_loss = 0
            for iter_ind in range(n_batch_per_epoch):
                supplement_logs = ""
                # reset cumulated losses at the begining of each batch
                # loss_manager.reset_losses()  # TODO: use torch.utils.tensorboard !!
                optimizer.zero_grad()
                img, msk = next(dataloader)
                img, msk = img.to(device), msk.to(device)

                # TODO this is ugly: convert dtype and convert the shape from
                # (N, 1, 512, 512) to (N, 512, 512)
                msk = msk.to(torch.long).squeeze(1)

                msk_pred = model(img)  # shape (N, 3, 512, 512)

                # label_weights is determined according the liver_ratio & tumor_ratio
                # loss = CrossEntropyLoss(msk_pred, msk, label_weights=[1., 10., 100.], device=device)
                loss = DiceLoss(msk_pred, msk, label_weights=[1., 20., 50.], device=device)
                # loss = DiceLoss(msk_pred, msk, label_weights=[1., 20., 500.], device=device)

                if not valid_mode:
                    loss.backward()
                    optimizer.step()

                loss = loss.item()  # release
                cum_loss += loss
                if valid_mode:
                    print("\r--------(valid) {:.2%} Loss: {:.3f} (time: {:.1f}s) |supp: {}".format(
                        (iter_ind + 1) / n_batch_per_epoch,
                        cum_loss / (iter_ind + 1),
                        time.time() - st, supplement_logs), end="")
                else:
                    print("\rEpoch: {:3}/{} {:.2%} Loss: {:.3f} (time: {:.1f}s) |supp: {}".format(
                        (epoch + 1), args.epochs,
                        (iter_ind + 1) / n_batch_per_epoch,
                        cum_loss / (iter_ind + 1),
                        time.time() - st, supplement_logs), end="")
            print()

            if valid_mode:
                torch.set_grad_enabled(prev_grad_mode)
                # validation (mean) loss of the current epoch
                valid_mean_loss = cum_loss / (iter_ind + 1)

        if best_model_path and (valid_mean_loss < best_valid_loss):
            printGreen("Valid loss decreases from {:.5f} to {:.5f}, saving best model.".format(
                best_valid_loss, valid_mean_loss))
            best_valid_loss = valid_mean_loss
            # Only need to save the weights
            # torch.save(model.state_dict(), best_model_path)
            # save the entire model
            torch.save(model, best_model_path)

    return best_valid_loss
# --- Script fragment: validate + load a training checkpoint, then discover scan
# directories.  NOTE(review): this fragment depends on names defined outside
# this chunk (CHECKPOINT, args, device, root_dir, simpleConvNet) and its final
# `else:` body continues beyond this chunk — left byte-identical apart from
# formatting and comments.

# Abort early if the checkpoint path does not point at an existing regular file.
if not (os.path.exists(CHECKPOINT) and os.path.isfile(CHECKPOINT)):
    print("weights in wrong format or non-existant: \n\t{}".format(
        CHECKPOINT))
    exit()

# loading the model
m = None
if args.model == "unet":
    m = UNet(2, 2).to(device)
else:
    m = simpleConvNet(2, 2).to(device)

print("loading model weights from \"{}\"".format(CHECKPOINT))
checkpoint = torch.load(CHECKPOINT)
m.load_state_dict(checkpoint['model_state_dict'])
m.to(device)  # pushing weights to gpu

# Restore bookkeeping saved alongside the weights (loss histories and the
# epoch to resume from).
train_loss = checkpoint['train_loss']
test_loss = checkpoint['test_loss']
START = checkpoint['epoch']

# Collect scan sub-directories of root_dir; only directory names containing
# an underscore are considered scans (expected format mmddhhmmss_x).
scans = [
    os.path.join(root_dir, i)
    for i in os.listdir(os.path.abspath(root_dir))
    if os.path.isdir(os.path.join(os.path.abspath(root_dir), i)) and "_" in i
]

if len(scans) == 0:
    print(
        "no scan data found (folder name must be in format mmddhhmmss_x with x beeing the serialnumber"
    )
    exit()
else:
    # NOTE(review): body of this `else:` branch lies outside this chunk.
# --- Script fragment: training setup for a facial-landmark UNet.
# NOTE(review): the opening of the DataLoader call below and the body of the
# final `for epoch ...:` loop lie outside this chunk; tokens are left
# byte-identical apart from formatting and comment translation (original
# comments were in Russian).
                                             pin_memory=True, shuffle=False,
                                             drop_last=False)
# learning-rate
LEARNING_RATE = 1e-3
# Number of epochs
N_EPOCHS = 10

# tensorboard
writer = SummaryWriter(log_dir='./{}'.format(MODEL_NAME), comment=MODEL_NAME)

# Define the model
model = UNet(3, NUM_PTS)
model.to(device)
with torch.no_grad():
    # writer.add_graph(model, next(iter(val_dataloader))['image'].to(device))
    summary(model, next(iter(train_dataloader))['image'].shape[1:])

# Set the optimization parameters
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE, amsgrad=True)
# criterion = F.mse_loss
criterion = AdaptiveWingLoss()

# Temporary variables used to keep track of the best result so far
best_val_loss, best_model_state_dict = np.inf, {}
CURRENT_EPOCH = 0

for epoch in range(CURRENT_EPOCH, N_EPOCHS):
    # NOTE(review): loop body lies outside this chunk.