def train(net, train_loader, val_loader, callbacks, params, reduce_epochs=False):
    """Train ``net`` with the given loaders/callbacks and return the fitted trainer.

    :param net: lightning module to train
    :param train_loader: training dataloader
    :param val_loader: validation dataloader
    :param callbacks: list of lightning callbacks (must include a checkpoint callback)
    :param params: parameter dict (epochs, gpus, accelerator, log_dir, log_freq, log_refresh_rate)
    :param reduce_epochs: if True, train for half the configured number of epochs
    :return: the pl.Trainer instance after fitting; ``net`` holds the best checkpoint weights
    """
    start = time.perf_counter()

    # optionally halve the epoch budget (e.g. for fine-tuning stages)
    if reduce_epochs:
        max_epochs = params['epochs'] // 2
    else:
        max_epochs = params['epochs']

    trainer = pl.Trainer(max_epochs=max_epochs,
                         gpus=params['gpus'],
                         accelerator=params['accelerator'],
                         default_root_dir=params['log_dir'],
                         flush_logs_every_n_steps=params['log_freq'],
                         log_every_n_steps=params['log_freq'],
                         callbacks=callbacks,
                         progress_bar_refresh_rate=params['log_refresh_rate'])
    trainer.fit(net, train_loader, val_loader)

    stop = time.perf_counter()
    print_frm('Elapsed training time: %d hours, %d minutes, %.2f seconds' %
              process_seconds(stop - start))

    # reload the weights that scored best during fitting
    best_state = torch.load(trainer.checkpoint_callback.best_model_path)
    net.load_state_dict(best_state['state_dict'])

    return trainer
def train_net(self, train_loader, test_loader, loss_fn, optimizer, epochs, scheduler=None, test_freq=1,
              augmenter=None, print_stats=1, log_dir=None, write_images_freq=1, device=0):
    """
    Trains the network

    :param train_loader: data loader with training data
    :param test_loader: data loader with testing data
    :param loss_fn: loss function
    :param optimizer: optimizer for the loss function
    :param epochs: number of training epochs
    :param scheduler: optional scheduler for learning rate tuning
    :param test_freq: frequency of testing
    :param augmenter: data augmenter
    :param print_stats: frequency of logging statistics
    :param log_dir: logging directory (if None, no tensorboard logging and no checkpointing)
    :param write_images_freq: frequency of writing images
    :param device: GPU device where the computations should occur
    """

    # log everything if necessary
    if log_dir is not None:
        writer = SummaryWriter(log_dir=log_dir)
    else:
        writer = None

    j_max = 0
    for epoch in range(epochs):

        print_frm('Epoch %5d/%5d' % (epoch, epochs))

        # train the model for one epoch
        self.train_epoch(loader=train_loader, loss_fn=loss_fn, optimizer=optimizer, epoch=epoch,
                         augmenter=augmenter, print_stats=print_stats, writer=writer,
                         write_images=epoch % write_images_freq == 0, device=device)

        # adjust learning rate if necessary
        if scheduler is not None:
            scheduler.step()

            # and keep track of the learning rate
            # BUG FIX: previously called unconditionally, crashing when log_dir is None
            if writer is not None:
                writer.add_scalar('learning_rate', float(scheduler.get_last_lr()[0]), epoch)

        # test the model for one epoch if necessary
        if epoch % test_freq == 0:
            j = self.test_epoch(loader=test_loader, loss_fn=loss_fn, epoch=epoch, writer=writer,
                                write_images=True, device=device)

            # and save model if higher segmentation performance was obtained
            if j > j_max:
                j_max = j
                # BUG FIX: os.path.join(None, ...) raised TypeError when log_dir is None
                if log_dir is not None:
                    torch.save(self, os.path.join(log_dir, 'best_checkpoint.pytorch'))

        # save model every epoch
        if log_dir is not None:
            torch.save(self, os.path.join(log_dir, 'checkpoint.pytorch'))

    if writer is not None:
        writer.close()
def validate(net, trainer, loader, params):
    """Validate a trained network on the dataset held by ``loader``.

    Writes predictions and metrics into the trainer's log directory and
    prints the elapsed wall-clock time.
    """
    tic = time.perf_counter()

    dataset = loader.dataset
    test_data = dataset.data[0]
    test_labels = dataset.labels[0]
    validate_base(net, test_data, test_labels, params['input_size'],
                  in_channels=params['in_channels'],
                  classes_of_interest=params['coi'],
                  batch_size=params['test_batch_size'],
                  write_dir=os.path.join(trainer.log_dir, 'best_predictions'),
                  val_file=os.path.join(trainer.log_dir, 'metrics.npy'),
                  device=params['gpus'][0])

    toc = time.perf_counter()
    print_frm('Elapsed testing time: %d hours, %d minutes, %.2f seconds' %
              process_seconds(toc - tic))
def _select_subset(data, labels, n=1, sz_size=(512, 512), min_pos=0.01, coi=(0, 1)):
    """Draw ``n`` random 2D crops from a labeled volume.

    Each crop must contain a minimum fraction ``min_pos`` of foreground pixels
    (classes of interest other than background); after ``max_iters`` failed
    draws the last random sample is accepted anyway.

    :param data: raw volume (z, y, x)
    :param labels: label volume of the same shape
    :param n: number of crops to extract
    :param sz_size: maximum (y, x) crop size; clipped to the volume size
    :param min_pos: minimum foreground fraction required per crop
    :param coi: classes of interest (0 is treated as background)
    :return: tuple (data crops, label crops), each of shape (n, y, x)
    """
    # crop dimensions, clipped to the volume extent
    depth = int(n)
    height = int(min(sz_size[0], data.shape[1]))
    width = int(min(sz_size[1], data.shape[2]))
    data_ = np.zeros((depth, height, width), dtype=data.dtype)
    labels_ = np.zeros((depth, height, width), dtype=labels.dtype)

    # give up on the foreground constraint after this many draws
    max_iters = 100

    for z_ in range(depth):
        found = False
        iters = 0
        while not found:
            iters += 1

            # draw a candidate crop
            data_[z_:z_ + 1], labels_[z_:z_ + 1] = sample_labeled_input(
                data, labels, (1, height, width))

            # count pixels belonging to the non-background classes of interest
            nnz = sum(np.sum(labels_[z_:z_ + 1] == c) for c in coi if c > 0)

            if nnz / (height * width) > min_pos:
                print_frm('Sample %d successfully found!' % z_)
                found = True
            if iters > max_iters:
                print_frm(
                    'Maximum number of iterations reached.. selecting random sample'
                )
                found = True

    return data_, labels_
print('[%s] Arguments: ' % (datetime.datetime.now())) print('[%s] %s' % (datetime.datetime.now(), args)) args.input_size = [int(item) for item in args.input_size.split(',')] """ Fix seed (for reproducibility) """ set_seed(args.seed) # parameters device = args.device # computing device n = args.n # amount of samples to be extracted per domain b = args.batch_size # batch size for processing input_size = args.input_size # load the network print_frm('Loading network') model_file = args.net net = _load_net(model_file, device) # load reference patch print_frm('Loading data') data_file = args.data_file df = json.load(open(data_file)) n_domains = len(df['raw']) input_shape = (1, input_size[0], input_size[1]) # datasets dss = [] for d in range(n_domains): print_frm('Loading %s' % df['raw'][d]) dss.append(
def test_epoch(self, loader_src, loader_tar_ul, loader_tar_l, epoch, writer=None, write_images=False, device=0):
    """
    Tests the network for one epoch on paired source/target labeled batches.

    :param loader_src: source dataloader (labeled)
    :param loader_tar_ul: target dataloader (unlabeled) — NOTE(review): unused in this method
    :param loader_tar_l: target dataloader (labeled)
    :param epoch: current epoch
    :param writer: summary writer
    :param write_images: frequency of writing images
    :param device: GPU device where the computations should occur
    :return: average total loss (src + tar segmentation loss) over the epoch
    """
    # perform evaluation on GPU/CPU
    module_to_device(self, device)
    self.eval()

    # keep track of the average loss during the epoch
    loss_seg_src_cum = 0.0
    loss_seg_tar_cum = 0.0
    total_loss_cum = 0.0
    cnt = 0

    # zip dataloaders
    # NOTE(review): when loader_tar_l is None the loop below still indexes
    # data[1], which would raise IndexError — confirm this path is never taken
    if loader_tar_l is None:
        dl = zip(loader_src)
    else:
        dl = zip(loader_src, loader_tar_l)

    # start epoch
    y_preds = []
    ys = []
    time_start = datetime.datetime.now()
    for i, data in enumerate(dl):

        # transfer to suitable device
        x_src, y_src = tensor_to_device(data[0], device)
        x_tar_l, y_tar_l = tensor_to_device(data[1], device)
        x_src = x_src.float()
        x_tar_l = x_tar_l.float()
        y_src = y_src.long()
        y_tar_l = y_tar_l.long()

        # forward prop and compute loss on both domains
        y_src_pred = self(x_src)
        y_tar_l_pred = self(x_tar_l)
        # labels carry a singleton channel dim that the loss does not expect
        loss_seg_src = self.seg_loss(y_src_pred, y_src[:, 0, ...])
        loss_seg_tar = self.seg_loss(y_tar_l_pred, y_tar_l[:, 0, ...])
        total_loss = loss_seg_src + loss_seg_tar
        loss_seg_src_cum += loss_seg_src.data.cpu().numpy()
        loss_seg_tar_cum += loss_seg_tar.data.cpu().numpy()
        total_loss_cum += total_loss.data.cpu().numpy()
        cnt += 1

        # collect per-sample target predictions/labels for metric computation
        for b in range(y_tar_l_pred.size(0)):
            y_preds.append(
                F.softmax(y_tar_l_pred, dim=1)[b, ...].view(y_tar_l_pred.size(1), -1).data.cpu().numpy())
            ys.append(y_tar_l[b, 0, ...].flatten().cpu().numpy())

    # keep track of time
    runtime = datetime.datetime.now() - time_start
    seconds = runtime.total_seconds()
    hours = seconds // 3600
    minutes = (seconds - hours * 3600) // 60
    seconds = seconds - hours * 3600 - minutes * 60
    print_frm('Epoch %5d - Runtime for testing: %d hours, %d minutes, %f seconds' %
              (epoch, hours, minutes, seconds))

    # prep for metric computation: one-vs-rest jaccard/accuracy per class of interest
    y_preds = np.concatenate(y_preds, axis=1)
    ys = np.concatenate(ys)
    js = np.asarray([
        jaccard((ys == i).astype(int), y_preds[i, :])
        for i in range(len(self.coi))
    ])
    ams = np.asarray([
        accuracy_metrics((ys == i).astype(int), y_preds[i, :])
        for i in range(len(self.coi))
    ])

    # don't forget to compute the average and print it
    loss_seg_src_avg = loss_seg_src_cum / cnt
    loss_seg_tar_avg = loss_seg_tar_cum / cnt
    total_loss_avg = total_loss_cum / cnt
    print(
        '[%s] Testing Epoch %4d - Loss seg src: %.6f - Loss seg tar: %.6f - Loss: %.6f'
        % (datetime.datetime.now(), epoch, loss_seg_src_avg, loss_seg_tar_avg,
           total_loss_avg))

    # log everything
    if writer is not None:

        # always log scalars
        log_scalars([
            loss_seg_src_avg, loss_seg_tar_avg, total_loss_avg,
            np.mean(js, axis=0), *(np.mean(ams, axis=0))
        ], [
            'test/' + s for s in [
                'loss-seg-src', 'loss-seg-tar', 'total-loss', 'jaccard',
                'accuracy', 'balanced-accuracy', 'precision', 'recall',
                'f-score'
            ]
        ], writer, epoch=epoch)

        # log images if necessary (channel 1 assumed to be the foreground map
        # — TODO confirm for multi-class settings)
        if write_images:
            y_src_pred = F.softmax(y_src_pred, dim=1)[:, 1:2, :, :].data
            y_tar_l_pred = F.softmax(y_tar_l_pred, dim=1)[:, 1:2, :, :].data
            log_images_2d([
                x_src.data, y_src.data, y_src_pred, x_tar_l.data, y_tar_l,
                y_tar_l_pred
            ], [
                'test/' + s for s in [
                    'src/x', 'src/y', 'src/y-pred', 'tar/x-l', 'tar/y-l',
                    'tar/y-l-pred'
                ]
            ], writer, epoch=epoch)

    return total_loss_avg
def train_epoch(self, loader_src, loader_tar_ul, loader_tar_l, optimizer, epoch, augmenter=None, print_stats=1,
                writer=None, write_images=False, device=0):
    """
    Trains the network for one epoch.

    Behavior depends on self.train_mode:
      - RECONSTRUCTION: reconstruction losses on both domains + domain confusion on x
      - SEGMENTATION: segmentation loss on source (randomly fed original or
        reconstructed inputs with probability self.p) + domain confusion on y
      - otherwise (joint): all of the above combined with self.lambda_rec /
        self.lambda_dc weighting

    :param loader_src: source dataloader (labeled)
    :param loader_tar_ul: target dataloader (unlabeled)
    :param loader_tar_l: target dataloader (labeled), optional
    :param optimizer: optimizer for the loss function
    :param epoch: current epoch
    :param augmenter: data augmenter
    :param print_stats: frequency of printing statistics
    :param writer: summary writer
    :param write_images: frequency of writing images
    :param device: GPU device where the computations should occur
    :return: average training loss over the epoch
    """
    # perform training on GPU/CPU
    module_to_device(self, device)
    self.train()

    # keep track of the average loss during the epoch
    loss_seg_src_cum = 0.0
    loss_seg_tar_cum = 0.0
    loss_rec_src_cum = 0.0
    loss_rec_tar_cum = 0.0
    loss_dc_x_cum = 0.0
    loss_dc_y_cum = 0.0
    total_loss_cum = 0.0
    cnt = 0

    # zip dataloaders (epoch length is bounded by the shortest loader)
    if loader_tar_l is None:
        dl = zip(loader_src, loader_tar_ul)
    else:
        dl = zip(loader_src, loader_tar_ul, loader_tar_l)

    # start epoch
    time_start = datetime.datetime.now()
    for i, data in enumerate(dl):

        # transfer to suitable device
        data_src = tensor_to_device(data[0], device)
        x_tar_ul = tensor_to_device(data[1], device)
        if loader_tar_l is not None:
            data_tar_l = tensor_to_device(data[2], device)

        # augment if necessary; unlabeled target samples are paired with
        # themselves so both views receive the same augmentation
        if loader_tar_l is None:
            data_aug = (data_src[0], data_src[1])
            x_src, y_src = augment_samples(data_aug, augmenter=augmenter)
            data_aug = (x_tar_ul, x_tar_ul)
            x_tar_ul, _ = augment_samples(data_aug, augmenter=augmenter)
        else:
            data_aug = (data_src[0], data_src[1])
            x_src, y_src = augment_samples(data_aug, augmenter=augmenter)
            data_aug = (x_tar_ul, x_tar_ul)
            x_tar_ul, _ = augment_samples(data_aug, augmenter=augmenter)
            data_aug = (data_tar_l[0], data_tar_l[1])
            x_tar_l, y_tar_l = augment_samples(data_aug, augmenter=augmenter)
            y_tar_l = get_labels(y_tar_l, coi=self.coi, dtype=int)
        y_src = get_labels(y_src, coi=self.coi, dtype=int)
        x_tar_ul = x_tar_ul.float()

        # zero the gradient buffers
        self.zero_grad()

        # get domain labels for domain confusion (0 = source, 1 = target)
        dom_labels_x = tensor_to_device(
            torch.zeros((x_src.size(0) + x_tar_ul.size(0))), device).long()
        dom_labels_x[x_src.size(0):] = 1
        dom_labels_y = tensor_to_device(
            torch.zeros((x_src.size(0) + x_tar_ul.size(0))), device).long()
        dom_labels_y[x_src.size(0):] = 1

        # check train mode and compute loss; unused losses stay at zero so the
        # cumulative bookkeeping below works for every mode
        loss_seg_src = torch.Tensor([0])
        loss_seg_tar = torch.Tensor([0])
        loss_rec_src = torch.Tensor([0])
        loss_rec_tar = torch.Tensor([0])
        loss_dc_x = torch.Tensor([0])
        loss_dc_y = torch.Tensor([0])
        if self.train_mode == RECONSTRUCTION:
            x_src_rec, x_src_rec_dom = self.forward_rec(x_src)
            x_tar_ul_rec, x_tar_ul_rec_dom = self.forward_rec(x_tar_ul)
            loss_rec_src = self.rec_loss(x_src_rec, x_src)
            # NOTE(review): argument order (input, reconstruction) is swapped
            # w.r.t. loss_rec_src — harmless for symmetric losses (e.g. MSE),
            # confirm if rec_loss is asymmetric
            loss_rec_tar = self.rec_loss(x_tar_ul, x_tar_ul_rec)
            loss_dc_x = self.dc_loss(
                torch.cat((x_src_rec_dom, x_tar_ul_rec_dom), dim=0),
                dom_labels_x)
            total_loss = loss_rec_src + loss_rec_tar + self.lambda_dc * loss_dc_x
        elif self.train_mode == SEGMENTATION:
            # switch between reconstructed and original inputs; when the
            # reconstructed input is used, the domain label is flipped to 1
            if np.random.rand() < self.p:
                y_src_pred, y_src_pred_dom = self.forward_seg(x_src)
            else:
                x_src_rec, _ = self.forward_rec(x_src)
                y_src_pred, y_src_pred_dom = self.forward_seg(x_src_rec)
                dom_labels_y[:x_src.size(0)] = 1
            if np.random.rand() < self.p:
                y_tar_ul_pred, y_tar_ul_pred_dom = self.forward_seg(
                    x_tar_ul)
            else:
                x_tar_ul_rec, _ = self.forward_rec(x_tar_ul)
                y_tar_ul_pred, y_tar_ul_pred_dom = self.forward_seg(
                    x_tar_ul_rec)
                dom_labels_y[x_src.size(0):] = 1
            loss_seg_src = self.seg_loss(y_src_pred, y_src[:, 0, ...])
            loss_dc_y = self.dc_loss(
                torch.cat((y_src_pred_dom, y_tar_ul_pred_dom), dim=0),
                dom_labels_y)
            total_loss = loss_seg_src + self.lambda_dc * loss_dc_y
            if loader_tar_l is not None:
                y_tar_l_pred, _ = self.forward_seg(x_tar_l)
                loss_seg_tar = self.seg_loss(y_tar_l_pred, y_tar_l[:, 0, ...])
                total_loss = total_loss + loss_seg_tar
        else:
            # joint mode: reconstruction + segmentation + both domain confusers
            x_src_rec, x_src_rec_dom = self.forward_rec(x_src)
            if np.random.rand() < self.p:
                y_src_pred, y_src_pred_dom = self.forward_seg(x_src)
            else:
                y_src_pred, y_src_pred_dom = self.forward_seg(x_src_rec)
                dom_labels_y[:x_src.size(0)] = 1
            x_tar_ul_rec, x_tar_ul_rec_dom = self.forward_rec(x_tar_ul)
            if np.random.rand() < self.p:
                y_tar_ul_pred, y_tar_ul_pred_dom = self.forward_seg(
                    x_tar_ul)
            else:
                y_tar_ul_pred, y_tar_ul_pred_dom = self.forward_seg(
                    x_tar_ul_rec)
                dom_labels_y[x_src.size(0):] = 1
            loss_rec_src = self.rec_loss(x_src_rec, x_src)
            loss_rec_tar = self.rec_loss(x_tar_ul, x_tar_ul_rec)
            loss_seg_src = self.seg_loss(y_src_pred, y_src[:, 0, ...])
            loss_dc_x = self.dc_loss(
                torch.cat((x_src_rec_dom, x_tar_ul_rec_dom), dim=0),
                dom_labels_x)
            loss_dc_y = self.dc_loss(
                torch.cat((y_src_pred_dom, y_tar_ul_pred_dom), dim=0),
                dom_labels_y)
            total_loss = loss_seg_src + self.lambda_rec * (loss_rec_src + loss_rec_tar) + \
                         self.lambda_dc * (loss_dc_x + loss_dc_y)
            if loader_tar_l is not None:
                _, y_tar_l_pred, _, y_tar_l_pred_dom = self(x_tar_l)
                loss_seg_tar = self.seg_loss(y_tar_l_pred, y_tar_l[:, 0, ...])
                total_loss = total_loss + loss_seg_tar

        # accumulate running averages
        loss_seg_src_cum += loss_seg_src.data.cpu().numpy()
        loss_seg_tar_cum += loss_seg_tar.data.cpu().numpy()
        loss_rec_src_cum += loss_rec_src.data.cpu().numpy()
        loss_rec_tar_cum += loss_rec_tar.data.cpu().numpy()
        loss_dc_x_cum += loss_dc_x.data.cpu().numpy()
        loss_dc_y_cum += loss_dc_y.data.cpu().numpy()
        total_loss_cum += total_loss.data.cpu().numpy()
        cnt += 1

        # backward prop
        total_loss.backward()

        # apply one step in the optimization
        optimizer.step()

        # print statistics if necessary
        if i % print_stats == 0:
            print(
                '[%s] Epoch %5d - Iteration %5d/%5d - Loss seg src: %.6f - Loss seg tar: %.6f - Loss rec src: %.6f - Loss rec tar: %.6f - Loss DCX: %.6f - Loss DCY: %.6f - Loss: %.6f'
                % (datetime.datetime.now(), epoch, i,
                   len(loader_src.dataset) / loader_src.batch_size,
                   loss_seg_src_cum / cnt, loss_seg_tar_cum / cnt,
                   loss_rec_src_cum / cnt, loss_rec_tar_cum / cnt,
                   loss_dc_x_cum / cnt, loss_dc_y_cum / cnt,
                   total_loss_cum / cnt))

    # keep track of time
    runtime = datetime.datetime.now() - time_start
    seconds = runtime.total_seconds()
    hours = seconds // 3600
    minutes = (seconds - hours * 3600) // 60
    seconds = seconds - hours * 3600 - minutes * 60
    print_frm(
        'Epoch %5d - Runtime for training: %d hours, %d minutes, %f seconds'
        % (epoch, hours, minutes, seconds))

    # don't forget to compute the average and print it
    loss_seg_src_avg = loss_seg_src_cum / cnt
    loss_seg_tar_avg = loss_seg_tar_cum / cnt
    loss_rec_src_avg = loss_rec_src_cum / cnt
    loss_rec_tar_avg = loss_rec_tar_cum / cnt
    loss_dc_x_avg = loss_dc_x_cum / cnt
    loss_dc_y_avg = loss_dc_y_cum / cnt
    total_loss_avg = total_loss_cum / cnt
    print(
        '[%s] Training Epoch %4d - Loss seg src: %.6f - Loss seg tar: %.6f - Loss rec src: %.6f - Loss rec tar: %.6f - Loss DCX: %.6f - Loss DCY: %.6f - Loss: %.6f'
        % (datetime.datetime.now(), epoch, loss_seg_src_avg,
           loss_seg_tar_avg, loss_rec_src_avg, loss_rec_tar_avg,
           loss_dc_x_avg, loss_dc_y_avg, total_loss_avg))

    # log everything
    if writer is not None:

        # always log scalars; only the losses relevant to the active mode
        if self.train_mode == RECONSTRUCTION:
            log_scalars(
                [loss_rec_src_avg, loss_rec_tar_avg, loss_dc_x_avg], [
                    'train/' + s
                    for s in ['loss-rec-src', 'loss-rec-tar', 'loss-dc-x']
                ],
                writer,
                epoch=epoch)
        elif self.train_mode == SEGMENTATION:
            log_scalars(
                [loss_seg_src_avg, loss_seg_tar_avg, loss_dc_y_avg], [
                    'train/' + s
                    for s in ['loss-seg-src', 'loss-seg-tar', 'loss-dc-y']
                ],
                writer,
                epoch=epoch)
        else:
            log_scalars([
                loss_seg_src_avg, loss_seg_tar_avg, loss_rec_src_avg,
                loss_rec_tar_avg, loss_dc_x_avg, loss_dc_y_avg
            ], [
                'train/' + s for s in [
                    'loss-seg-src', 'loss-seg-tar', 'loss-rec-src',
                    'loss-rec-tar', 'loss-dc-x', 'loss-dc-y'
                ]
            ],
                        writer,
                        epoch=epoch)
        log_scalars([total_loss_avg],
                    ['train/' + s for s in ['total-loss']],
                    writer,
                    epoch=epoch)

        # log images if necessary (the tensors logged are those of the last
        # iteration of the epoch)
        if write_images:
            log_images_2d([x_src.data], ['train/' + s for s in ['src/x']],
                          writer,
                          epoch=epoch)
            if self.train_mode == RECONSTRUCTION:
                log_images_2d(
                    [x_src_rec.data, x_tar_ul.data, x_tar_ul_rec.data], [
                        'train/' + s
                        for s in ['src/x-rec', 'tar/x-ul', 'tar/x-ul-rec']
                    ],
                    writer,
                    epoch=epoch)
            elif self.train_mode == SEGMENTATION:
                y_src_pred = F.softmax(y_src_pred, dim=1)[:, 1:2, :, :].data
                log_images_2d(
                    [y_src.data, y_src_pred],
                    ['train/' + s for s in ['src/y', 'src/y-pred']],
                    writer,
                    epoch=epoch)
                if loader_tar_l is not None:
                    y_tar_l_pred = F.softmax(y_tar_l_pred,
                                             dim=1)[:, 1:2, :, :].data
                    log_images_2d([x_tar_l.data, y_tar_l, y_tar_l_pred], [
                        'train/' + s
                        for s in ['tar/x-l', 'tar/y-l', 'tar/y-l-pred']
                    ],
                                  writer,
                                  epoch=epoch)
            else:
                y_src_pred = F.softmax(y_src_pred, dim=1)[:, 1:2, :, :].data
                log_images_2d([
                    x_src_rec.data, y_src.data, y_src_pred, x_tar_ul.data,
                    x_tar_ul_rec.data
                ], [
                    'train/' + s for s in [
                        'src/x-rec', 'src/y', 'src/y-pred', 'tar/x-ul',
                        'tar/x-ul-rec'
                    ]
                ],
                              writer,
                              epoch=epoch)
                if loader_tar_l is not None:
                    y_tar_l_pred = F.softmax(y_tar_l_pred,
                                             dim=1)[:, 1:2, :, :].data
                    log_images_2d([x_tar_l.data, y_tar_l, y_tar_l_pred], [
                        'train/' + s
                        for s in ['tar/x-l', 'tar/y-l', 'tar/y-l-pred']
                    ],
                                  writer,
                                  epoch=epoch)

    return total_loss_avg
def test_epoch(self, loader, epoch, writer=None, write_images=False, device=0):
    """
    Tests the network for one epoch on an (image, domain) dataloader.

    Computes the reconstruction loss plus the weighted domain-confusion loss
    and optionally logs scalars/images to tensorboard.

    :param loader: dataloader yielding (x, domain-label) pairs
    :param epoch: current epoch
    :param writer: summary writer
    :param write_images: frequency of writing images
    :param device: GPU device where the computations should occur
    :return: average testing loss over the epoch
    """
    # move to the compute device and switch to evaluation mode
    module_to_device(self, device)
    self.eval()

    # running sums for the per-epoch averages
    rec_sum = 0.0
    dc_sum = 0.0
    total_sum = 0.0
    n_batches = 0

    for i, data in enumerate(loader):

        # unpack and transfer the batch
        x, dom = data
        x = tensor_to_device(x.float(), device)
        dom = tensor_to_device(dom.long(), device)

        # forward prop; reconstructions are squashed to [0, 1]
        x_pred, dom_pred = self(x)
        x_pred = torch.sigmoid(x_pred)

        # reconstruction + weighted domain-confusion loss
        loss_rec = self.loss_rec_fn(x_pred, x)
        loss_dc = self.loss_dc_fn(dom_pred, dom)
        loss = loss_rec + self.lambda_reg * loss_dc

        rec_sum += loss_rec.data.cpu().numpy()
        dc_sum += loss_dc.data.cpu().numpy()
        total_sum += loss.data.cpu().numpy()
        n_batches += 1

    # average losses over the epoch
    loss_rec_avg = rec_sum / n_batches
    loss_dc_avg = dc_sum / n_batches
    loss_avg = total_sum / n_batches
    print_frm(
        'Epoch %5d - Average test loss rec: %.6f - Average test loss DC: %.6f - Average test loss: %.6f'
        % (epoch, loss_rec_avg, loss_dc_avg, loss_avg))

    # tensorboard logging
    if writer is not None:
        # always log scalars
        log_scalars([loss_rec_avg, loss_dc_avg, loss_avg],
                    ['test/' + s for s in ['loss-rec', 'loss-dc', 'loss']],
                    writer,
                    epoch=epoch)
        # log images if necessary (tensors from the last batch)
        if write_images:
            log_images_2d([x, x_pred],
                          ['test/' + s for s in ['x', 'x_pred']],
                          writer,
                          epoch=epoch)

    return loss_avg
def test_epoch(self, loader, loss_fn, epoch, writer=None, write_images=False, device=0):
    """
    Tests the network for one epoch, masking out unlabeled pixels.

    :param loader: dataloader
    :param loss_fn: loss function (must accept a ``mask`` keyword)
    :param epoch: current epoch
    :param writer: summary writer
    :param write_images: frequency of writing images
    :param device: GPU device where the computations should occur
    :return: mean jaccard score over the classes of interest
    """
    # perform evaluation on GPU/CPU
    module_to_device(self, device)
    self.eval()

    # keep track of the average loss and metrics during the epoch
    loss_cum = 0.0
    cnt = 0

    # test loss
    y_preds = []
    ys = []
    ys_ = []  # per-pixel "unlabeled" masks, collected for weighted metrics
    time_start = datetime.datetime.now()
    for i, data in enumerate(loader):
        # get the inputs and transfer to suitable device
        x, y = tensor_to_device(data, device)
        y_ = get_unlabeled(y)
        x = x.float()
        y = get_labels(y, coi=self.coi, dtype=int)
        # presumably 255 marks unlabeled pixels in the raw labels — TODO confirm
        y_ = get_labels(y_, coi=[0, 255], dtype=bool)

        # forward prop
        y_pred = self(x)

        # compute loss, excluding unlabeled pixels via the mask
        loss = loss_fn(y_pred, y[:, 0, ...], mask=~y_)
        loss_cum += loss.data.cpu().numpy()
        cnt += 1

        # collect per-sample predictions, labels and masks for metrics
        for b in range(y_pred.size(0)):
            y_preds.append(F.softmax(y_pred, dim=1)[b, ...].view(y_pred.size(1), -1).data.cpu().numpy())
            ys.append(y[b, 0, ...].flatten().cpu().numpy())
            ys_.append(y_[b, 0, ...].flatten().cpu().numpy())

    # keep track of time
    runtime = datetime.datetime.now() - time_start
    seconds = runtime.total_seconds()
    hours = seconds // 3600
    minutes = (seconds - hours * 3600) // 60
    seconds = seconds - hours * 3600 - minutes * 60
    print_frm(
        'Epoch %5d - Runtime for testing: %d hours, %d minutes, %f seconds' % (epoch, hours, minutes, seconds))

    # prep for metric computation: one-vs-rest metrics weighted by the labeled mask
    y_preds = np.concatenate(y_preds, axis=1)
    ys = np.concatenate(ys)
    ys_ = np.concatenate(ys_)
    w = (1 - ys_).astype(bool)  # weight = 1 on labeled pixels only
    js = np.asarray([jaccard((ys == i).astype(int), y_preds[i, :], w=w) for i in range(len(self.coi))])
    ams = np.asarray([accuracy_metrics((ys == i).astype(int), y_preds[i, :], w=w) for i in range(len(self.coi))])

    # don't forget to compute the average and print it
    loss_avg = loss_cum / cnt
    print_frm('Epoch %5d - Average test loss: %.6f' % (epoch, loss_avg))

    # log everything
    if writer is not None:

        # always log scalars
        log_scalars([loss_avg, np.mean(js, axis=0), *(np.mean(ams, axis=0))], ['test/' + s for s in
                                                                               ['loss-seg', 'jaccard', 'accuracy',
                                                                                'balanced-accuracy', 'precision',
                                                                                'recall', 'f-score']], writer,
                    epoch=epoch)

        # log images if necessary (tensors from the last batch)
        if write_images:
            log_images_3d([x], ['test/' + s for s in ['x']], writer, epoch=epoch)
            y_pred = F.softmax(y_pred, dim=1)
            for i, c in enumerate(self.coi):
                if not i == 0:  # skip background class
                    y_p = y_pred[:, i:i + 1, ...].data
                    y_t = (y == i).long()
                    # NOTE(review): the trailing ')' in these tags looks like a
                    # typo ('y_class_%d)') — kept as-is to preserve log names
                    log_images_3d([y_t, y_p], ['test/' + s for s in ['y_class_%d)' % (c), 'y_pred_class_%d)' % (c)]],
                                  writer, epoch=epoch)

    return np.mean(js)
from neuralnets.util.io import print_frm from neuralnets.util.tools import set_seed from util.tools import parse_params, get_dataloaders from networks.factory import generate_model from train.base import train, validate from multiprocessing import freeze_support if __name__ == '__main__': freeze_support() """ Parse all the arguments """ print_frm('Parsing arguments') parser = argparse.ArgumentParser() parser.add_argument("--config", "-c", help="Path to the configuration file", type=str, default='train_supervised.yaml') parser.add_argument( "--clean-up", help="Boolean flag that specifies cleaning of the checkpoints", action='store_true', default=False) args = parser.parse_args() with open(args.config) as file: params = parse_params(yaml.load(file, Loader=yaml.FullLoader)) """
import torch.optim as optim from torch.utils.data import DataLoader from torchvision.transforms import Compose from neuralnets.data.datasets import StronglyLabeledVolumeDataset from neuralnets.networks.unet import UNet2D from neuralnets.util.augmentation import * from neuralnets.util.io import print_frm from neuralnets.util.losses import get_loss_function from neuralnets.util.tools import set_seed from neuralnets.util.validation import validate """ Parse all the arguments """ print_frm('Parsing arguments') parser = argparse.ArgumentParser() # logging parameters parser.add_argument("--seed", help="Seed for randomization", type=int, default=0) parser.add_argument("--device", help="GPU device for computations", type=int, default=0) parser.add_argument("--log_dir", help="Logging directory", type=str, default="unet_2d")
def get_dataloaders(params, domain=None, domain_labels_available=1.0, supervised=False):
    """
    Builds train/validation/test dataloaders from the parameter dict.

    Two regimes:
      - ``domain is None``: cross-domain setup — train/val combine the 'src'
        and 'tar' volumes (target labels partially available), test uses the
        target volume only.
      - ``domain`` given (or ``supervised=True``): single-domain setup using
        that domain's (or the top-level) data/labels/split settings.

    :param params: parameter dict (paths, splits, batch sizes, augmentation, ...)
    :param domain: optional domain key ('src'/'tar'); None selects the
        cross-domain branch
    :param domain_labels_available: fraction of labels available for the
        selected domain (single-domain branch only)
    :param supervised: if True, read data/labels/split from the top level of
        ``params`` instead of ``params[domain]``
    :return: (train_loader, val_loader, test_loader)
    """
    input_shape = (1, *(params['input_size']))
    transform = get_transforms(params['augmentation'], coi=params['coi'])
    print_frm('Applying data augmentation! Specifically %s' %
              str(params['augmentation']))
    if domain is None:
        # cross-domain branch: joint source+target datasets
        split_src = params['src']['train_val_test_split']
        split_tar = params['tar']['train_val_test_split']
        print_frm('Train data... ')
        train = LabeledVolumeDataset(
            (params['src']['data'], params['tar']['data']),
            (params['src']['labels'], params['tar']['labels']),
            len_epoch=params['len_epoch'],
            input_shape=input_shape,
            in_channels=params['in_channels'],
            type=params['type'],
            batch_size=params['train_batch_size'],
            transform=transform,
            range_split=((0, split_src[0]), (0, split_tar[0])),
            coi=params['coi'],
            range_dir=(params['src']['split_orientation'],
                       params['tar']['split_orientation']),
            # source labels fully available; target labels only partially
            partial_labels=(1, params['tar_labels_available']),
            seed=params['seed'])
        print_frm('Validation data...')
        # NOTE(review): unlike the train set, no transform is passed here
        val = LabeledVolumeDataset(
            (params['src']['data'], params['tar']['data']),
            (params['src']['labels'], params['tar']['labels']),
            len_epoch=params['len_epoch'],
            input_shape=input_shape,
            in_channels=params['in_channels'],
            type=params['type'],
            batch_size=params['test_batch_size'],
            coi=params['coi'],
            range_split=((split_src[0], split_src[1]),
                         (split_tar[0], split_tar[1])),
            range_dir=(params['src']['split_orientation'],
                       params['tar']['split_orientation']),
            partial_labels=(1, params['tar_labels_available']),
            seed=params['seed'])
        print_frm('Test data...')
        # test set covers the remainder of the target volume only
        test = LabeledSlidingWindowDataset(
            params['tar']['data'],
            params['tar']['labels'],
            in_channels=params['in_channels'],
            type=params['type'],
            batch_size=params['test_batch_size'],
            range_split=(split_tar[1], 1),
            range_dir=params['tar']['split_orientation'],
            coi=params['coi'])
        print_frm('Train volume shape: %s (source) - %s (target)' %
                  (str(train.data[0].shape), str(train.data[1].shape)))
        print_frm(
            'Available target labels for training: %.1f (i.e. %.2f MV)' %
            (params['tar_labels_available'] * 100,
             np.prod(train.data[1].shape) * params['tar_labels_available'] /
             1000 / 1000))
        print_frm('Validation volume shape: %s (source) - %s (target)' %
                  (str(val.data[0].shape), str(val.data[1].shape)))
        print_frm('Test volume shape: %s (target)' % str(test.data[0].shape))
    else:
        # single-domain branch; supervised mode reads top-level keys instead
        split = params['train_val_test_split'] if supervised else params[
            domain]['train_val_test_split']
        data = params['data'] if supervised else params[domain]['data']
        labels = params['labels'] if supervised else params[domain]['labels']
        range_dir = params['split_orientation'] if supervised else params[
            domain]['split_orientation']
        print_frm('Train data...')
        train = LabeledVolumeDataset(data,
                                     labels,
                                     len_epoch=params['len_epoch'],
                                     input_shape=input_shape,
                                     in_channels=params['in_channels'],
                                     type=params['type'],
                                     batch_size=params['train_batch_size'],
                                     transform=transform,
                                     range_split=(0, split[0]),
                                     range_dir=range_dir,
                                     partial_labels=domain_labels_available,
                                     seed=params['seed'],
                                     coi=params['coi'])
        print_frm('Validation data...')
        val = LabeledVolumeDataset(data,
                                   labels,
                                   len_epoch=params['len_epoch'],
                                   input_shape=input_shape,
                                   in_channels=params['in_channels'],
                                   type=params['type'],
                                   batch_size=params['test_batch_size'],
                                   transform=transform,
                                   range_split=(split[0], split[1]),
                                   range_dir=range_dir,
                                   coi=params['coi'],
                                   partial_labels=domain_labels_available,
                                   seed=params['seed'])
        print_frm('Test data...')
        test = LabeledSlidingWindowDataset(
            data,
            labels,
            in_channels=params['in_channels'],
            type=params['type'],
            batch_size=params['test_batch_size'],
            transform=transform,
            range_split=(split[1], 1),
            range_dir=range_dir,
            coi=params['coi'])
        print_frm('Train volume shape: %s' % str(train.data[0].shape))
        print_frm(
            'Available %s labels for training: %d%% (i.e. %.2f MV)' %
            (domain, domain_labels_available * 100,
             np.prod(train.data[0].shape) * domain_labels_available / 1000 /
             1000))
        print_frm('Validation volume shape: %s' % str(val.data[0].shape))
        print_frm('Test volume shape: %s' % str(test.data[0].shape))

    # wrap the datasets in dataloaders
    train_loader = DataLoader(train,
                              batch_size=params['train_batch_size'],
                              num_workers=params['num_workers'],
                              pin_memory=True)
    val_loader = DataLoader(val,
                            batch_size=params['test_batch_size'],
                            num_workers=params['num_workers'],
                            pin_memory=True)
    test_loader = DataLoader(test,
                             batch_size=params['test_batch_size'],
                             num_workers=params['num_workers'],
                             pin_memory=True)

    return train_loader, val_loader, test_loader
def mv(source, target):
    """Move ``source`` to ``target`` (shutil.move), logging the operation first."""
    message = ' Moving %s -> %s' % (source, target)
    print_frm(message)
    shutil.move(source, target)
def cp(source, target):
    """Copy file ``source`` to ``target`` (shutil.copyfile), logging the operation first."""
    message = ' Copying %s -> %s' % (source, target)
    print_frm(message)
    shutil.copyfile(source, target)
def rmdir(dir):
    """Recursively delete directory ``dir`` (errors ignored), logging the operation first."""
    message = ' Removing %s' % dir
    print_frm(message)
    shutil.rmtree(dir, ignore_errors=True)
def __init__(self, data_path, input_shape, split_orientation='z', split_location=0.50, scaling=None, len_epoch=1000,
             types=['tif3d'], sampling_mode='uniform', in_channels=1, orientations=(0, ), batch_size=1, dtype='uint8',
             norm_type='unit', train=True, available=-1):
    """
    Dataset over one or more unlabeled volumes, each split into a train/test
    part along a configurable orientation.

    assumes data_path, split_orientation, split_location and types are
    per-volume sequences of equal length — TODO confirm against callers

    :param data_path: sequence of volume paths
    :param input_shape: shape of the sampled inputs
    :param split_orientation: per-volume split axis ('z', 'y' or 'x')
    :param split_location: per-volume split fraction in [0, 1]
    :param scaling: optional per-axis rescaling factors
    :param len_epoch: number of samples per epoch
    :param types: per-volume file types for read_volume
    :param sampling_mode: sampling strategy identifier
    :param in_channels: number of input channels
    :param orientations: sampling orientations
    :param batch_size: batch size
    :param dtype: data type of the raw volumes
    :param norm_type: normalization type
    :param train: if True keep the part before the split, else the part after
    :param available: NOTE(review): unused in this constructor
    """
    self.data_path = data_path
    self.input_shape = input_shape
    self.scaling = scaling
    self.len_epoch = len_epoch
    self.sampling_mode = sampling_mode
    self.in_channels = in_channels
    self.orientations = orientations
    self.orientation = 0          # currently active orientation
    self.k = 0                    # currently active volume index
    self.batch_size = batch_size
    self.norm_type = norm_type

    # load the data
    self.data = []
    self.data_sizes = []
    for k, path in enumerate(data_path):
        # NOTE(review): k is 0-based, so this prints "0/N" for the first volume
        print_frm('Loading dataset %d/%d: %s' % (k, len(data_path), path))

        # axis index of the split orientation (z->0, y->1, x->2)
        d = 0 if split_orientation[
            k] == 'z' else 1 if split_orientation[k] == 'y' else 2

        if split_orientation[k] == 'z':
            # z-splits can be done at read time by limiting the slice range
            split = int(len(os.listdir(path)) * split_location[k])
            start = 0 if train else split
            # presumably stop=-1 means "read to the end" in read_volume — TODO confirm
            stop = split if train else -1
            data = read_volume(path,
                               type=types[k],
                               dtype=dtype,
                               start=start,
                               stop=stop)
        else:
            # y/x splits require reading the full volume first, then slicing
            data = read_volume(path, type=types[k], dtype=dtype)
            split = int(data.shape[d] * split_location[k])
            if split_orientation[k] == 'y':
                data = data[:, :split, :] if train else data[:, split:, :]
            else:
                data = data[:, :, :split] if train else data[:, :, split:]

        # rescale the dataset if necessary
        if scaling is not None:
            target_size = np.asarray(np.multiply(data.shape, scaling),
                                     dtype=int)
            data = F.interpolate(torch.Tensor(data[np.newaxis, np.newaxis,
                                                   ...]),
                                 size=tuple(target_size),
                                 mode='area')[0, 0, ...].numpy()

        self.data.append(data)
        self.data_sizes.append(data.size)

    # normalize volume sizes into sampling probabilities
    self.data_sizes = np.array(self.data_sizes)
    self.data_sizes = self.data_sizes / np.sum(self.data_sizes)
from neuralnets.util.io import print_frm, read_pngseq from neuralnets.util.tools import set_seed from neuralnets.util.validation import segment_read, segment_ram from util.tools import parse_params, process_seconds from networks.factory import generate_model from multiprocessing import freeze_support if __name__ == '__main__': freeze_support() """ Parse all the arguments """ print_frm('Parsing arguments') parser = argparse.ArgumentParser() parser.add_argument("--config", "-c", help="Path to the network configuration file", type=str, default='../train_supervised.yaml') parser.add_argument("--model", "-m", help="Path to the network parameters", type=str, required=True) parser.add_argument("--dataset", "-d", help="Path to the dataset that needs to be segmented", type=str, required=True) parser.add_argument("--block_wise", "-bw", help="Flag that specifies to compute block wise or not", action='store_true', default=False) parser.add_argument("--output", "-o", help="Path to store the output segmentation", type=str, required=True) parser.add_argument("--gpu", "-g", help="GPU device for computations", type=int, default=0) args = parser.parse_args() with open(args.config) as file: params = parse_params(yaml.load(file, Loader=yaml.FullLoader)) """
args.input_size = [int(item) for item in args.input_size.split(',')] """ Fix seed (for reproducibility) """ set_seed(args.seed) # parameters domain_id = args.domain_id # id of the domain where a reference patch should be selected device = args.device # computing device n = args.n # amount of samples to be extracted per domain b = args.batch_size # batch size for processing input_size = args.input_size k = args.k # amount of closest samples to be extracted # load the network print_frm('Loading network') model_file = args.net net = _load_net(model_file, device) # load reference patch print_frm('Loading data') data_file = args.data_file df = json.load(open(data_file)) n_domains = len(df['raw']) input_shape = (1, input_size[0], input_size[1]) dataset_ref = UnlabeledVolumeDataset( df['raw'][domain_id], split_orientation=df['split-orientation'][domain_id], split_location=df['split-location'][domain_id], input_shape=input_shape, type=df['types'][domain_id],
# Training entry point (single-configuration experiment).
# NOTE(review): `argparse` and `yaml` are used below but not imported in this
# visible fragment — presumably imported earlier in the file; verify.
from neuralnets.util.augmentation import *
from neuralnets.util.io import print_frm, mkdir
from neuralnets.util.tools import set_seed
from util.tools import parse_params
from networks.factory import generate_model
from train.base import train, validate
from multiprocessing import freeze_support

if __name__ == '__main__':
    # required for multiprocessing support in frozen executables
    freeze_support()

    """ Parse all the arguments """
    print_frm('Parsing arguments')
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", "-c", help="Path to the configuration file", type=str, default='clem1.yaml')
    parser.add_argument(
        "--clean-up", help="Boolean flag that specifies cleaning of the checkpoints", action='store_true',
        default=False)
    args = parser.parse_args()
    # all experiment hyperparameters come from the YAML configuration file
    with open(args.config) as file:
        params = parse_params(yaml.load(file, Loader=yaml.FullLoader))

    """
# Semi-supervised training entry point.
# NOTE(review): `argparse`, `yaml`, `print_frm` and `set_seed` are used below but
# not imported in this visible fragment — presumably imported earlier; verify.
from util.tools import parse_params, get_dataloaders, rmdir, mv, cp
from networks.factory import generate_model
from train.base import train, validate
from multiprocessing import freeze_support

if __name__ == '__main__':
    # required for multiprocessing support in frozen executables
    freeze_support()

    """ Parse all the arguments """
    print_frm('Parsing arguments')
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", "-c", help="Path to the configuration file", type=str,
                        default='train_semi_supervised.yaml')
    parser.add_argument("--clean-up", help="Boolean flag that specifies cleaning of the checkpoints",
                        action='store_true', default=False)
    args = parser.parse_args()
    # all experiment hyperparameters come from the YAML configuration file
    with open(args.config) as file:
        params = parse_params(yaml.load(file, Loader=yaml.FullLoader))

    """ Fix seed (for reproducibility) """
    set_seed(params['seed'])

    """
def train_epoch(self, loader, loss_fn, optimizer, epoch, augmenter=None, print_stats=1, writer=None,
                write_images=False, device=0):
    """
    Trains the network for one epoch

    :param loader: dataloader
    :param loss_fn: loss function
    :param optimizer: optimizer for the loss function
    :param epoch: current epoch
    :param augmenter: data augmenter
    :param print_stats: frequency of printing statistics
    :param writer: summary writer
    :param write_images: frequency of writing images
    :param device: GPU device where the computations should occur
    :return: average training loss over the epoch
    """
    # perform training on GPU/CPU
    module_to_device(self, device)
    self.train()

    # keep track of the average loss during the epoch
    loss_cum = 0.0
    cnt = 0

    # start epoch
    time_start = datetime.datetime.now()
    for i, data in enumerate(loader):

        # transfer to suitable device and get labels
        data = tensor_to_device(data, device)
        y = get_labels(data[1], coi=self.coi, dtype=float)
        # filter out unlabeled pixels and include them in augmentation
        y_ = get_unlabeled(data[1], dtype=float)
        data.append(y)
        data.append(y_)

        # perform augmentation and transform to appropriate type; retry up to
        # rep_max times if the augmentation wiped out all labeled pixels
        x, _, y, y_ = augment_samples(data, augmenter=augmenter)
        rep = 0
        rep_max = 10
        while (1 - y_).sum() == 0 and rep < rep_max:
            # make sure labels are not lost in augmentation, otherwise augment new sample
            x, _, y, y_ = augment_samples(data, augmenter=augmenter)
            rep += 1
        if rep == rep_max:
            # give up on augmentation and fall back to the original sample
            x, _, y, y_ = data
        x = x.float()
        y = y.round().long()
        # clean labels if necessary (due to augmentations)
        if len(self.coi) > 2:
            y = clean_labels(y, len(self.coi))
        y_ = y_.bool()

        # zero the gradient buffers
        self.zero_grad()

        # forward prop
        y_pred = self(x)

        # compute loss, masking out the unlabeled pixels
        loss = loss_fn(y_pred, y[:, 0, ...], mask=~y_)
        loss_cum += loss.data.cpu().numpy()
        cnt += 1

        # backward prop
        loss.backward()

        # apply one step in the optimization
        optimizer.step()

        # print statistics if necessary
        if i % print_stats == 0:
            print_frm('Epoch %5d - Iteration %5d/%5d - Loss: %.6f'
                      % (epoch, i, len(loader.dataset) / loader.batch_size, loss))

    # keep track of time
    runtime = datetime.datetime.now() - time_start
    seconds = runtime.total_seconds()
    hours = seconds // 3600
    minutes = (seconds - hours * 3600) // 60
    seconds = seconds - hours * 3600 - minutes * 60
    print_frm('Epoch %5d - Runtime for training: %d hours, %d minutes, %f seconds'
              % (epoch, hours, minutes, seconds))

    # don't forget to compute the average and print it
    loss_avg = loss_cum / cnt
    print_frm('Epoch %5d - Average train loss: %.6f' % (epoch, loss_avg))

    # log everything
    if writer is not None:

        # always log scalars
        log_scalars([loss_avg], ['train/' + s for s in ['loss-seg']], writer, epoch=epoch)

        # log images if necessary
        if write_images:
            log_images_3d([x], ['train/' + s for s in ['x']], writer, epoch=epoch)
            y_pred = F.softmax(y_pred, dim=1)
            # use a fresh loop variable to avoid shadowing the batch index `i`
            for j, c in enumerate(self.coi):
                if not j == 0:  # skip background class
                    y_p = y_pred[:, j:j + 1, ...].data
                    y_t = (y == j).long()
                    # fixed: tags previously contained a stray ')' ('y_class_%d)')
                    log_images_3d([y_t, y_p],
                                  ['train/' + s for s in ['y_class_%d' % (c), 'y_pred_class_%d' % (c)]],
                                  writer, epoch=epoch)

    return loss_avg
def validate(net, data, labels, input_size, in_channels=1, classes_of_interest=(0, 1), batch_size=1,
             write_dir=None, val_file=None, track_progress=False, device=0, orientations=(0, ),
             normalization='unit'):
    """
    Validate a network on a dataset and its labels

    :param net: image-to-image segmentation network
    :param data: 3D array (Z, Y, X) representing the 3D image
    :param labels: 3D array (Z, Y, X) representing the 3D labels
    :param input_size: size of the inputs (either 2 or 3-tuple) for processing
    :param in_channels: Amount of subsequent slices that serve as input for the network (should be odd)
    :param classes_of_interest: index of the label of interest
    :param batch_size: batch size for processing
    :param write_dir: optionally, specify a directory to write the output
    :param val_file: optionally, specify a file to write the validation results
    :param track_progress: optionally, for tracking progress with progress bar
    :param device: GPU device where the computations should occur
    :param orientations: list of orientations to perform segmentation: 0-Z, 1-Y, 2-X (only for 2D based segmentation)
    :param normalization: type of data normalization (unit, z or minmax)
    :return: validation results, i.e. accuracy, precision, recall, f-score, jaccard and dice score
    """
    print_frm('Validating the trained network...')

    # compute segmentation for each orientation and average results
    segmentation = np.zeros((net.out_channels, *data.shape))
    for orientation in orientations:
        segmentation += segment(data, net, input_size, in_channels=in_channels, batch_size=batch_size,
                                track_progress=track_progress, device=device, orientation=orientation,
                                normalization=normalization)
    segmentation = segmentation / len(orientations)

    # compute metrics; pixels labeled 255 are treated as unlabeled and masked out
    w = labels != 255
    # Hausdorff distance is only meaningful when the volume is fully labeled
    comp_hausdorff = np.sum(labels == 255) == 0
    js = np.asarray([jaccard(segmentation[i], (labels == c).astype('float'), w=w)
                     for i, c in enumerate(classes_of_interest)])
    ams = np.asarray([accuracy_metrics(segmentation[i], (labels == c).astype('float'), w=w)
                      for i, c in enumerate(classes_of_interest)])
    for i, c in enumerate(classes_of_interest):
        if comp_hausdorff:
            # NOTE(review): this compares the class probability map against the raw
            # label volume (class indices), not the binary mask (labels == c) used
            # for the other metrics — confirm this is intended
            h = hausdorff_distance(segmentation[i], labels)[0]
        else:
            h = -1  # sentinel: Hausdorff not computed for partially labeled volumes

        # report results
        print_frm('Validation performance for class %d: ' % c)
        print_frm('    - Accuracy: %f' % ams[i, 0])
        print_frm('    - Balanced accuracy: %f' % ams[i, 1])
        print_frm('    - Precision: %f' % ams[i, 2])
        print_frm('    - Recall: %f' % ams[i, 3])
        print_frm('    - F1 score: %f' % ams[i, 4])
        print_frm('    - IoU: %f' % js[i])
        print_frm('    - Hausdorff distance: %f' % h)

    # report results
    print_frm('Validation performance mean: ')
    print_frm('    - Accuracy: %f' % np.mean(ams[:, 0]))
    print_frm('    - Balanced accuracy: %f' % np.mean(ams[:, 1]))
    print_frm('    - Precision: %f' % np.mean(ams[:, 2]))
    print_frm('    - Recall: %f' % np.mean(ams[:, 3]))
    print_frm('    - F1 score: %f' % np.mean(ams[:, 4]))
    print_frm('    - mIoU: %f' % np.mean(js))

    # write stuff if necessary
    if write_dir is not None:
        print_frm('Writing out the segmentation...')
        mkdir(write_dir)
        # collapse per-class probability maps to a label volume by thresholding at 0.5
        segmentation_volume = np.zeros(segmentation.shape[1:])
        for i, c in enumerate(classes_of_interest):
            segmentation_volume[segmentation[i] > 0.5] = c
        write_volume(segmentation_volume, write_dir, type='pngseq')

    if val_file is not None:
        # persist per-class [jaccard, accuracy-metrics...] rows for later aggregation
        np.save(val_file, np.concatenate((js[:, np.newaxis], ams), axis=1))

    return js, ams
def train_epoch(self, loader, optimizer, epoch, augmenter=None, print_stats=1, writer=None, write_images=False,
                device=0):
    """
    Trains the network for one epoch

    Optimizes a joint objective: reconstruction loss plus `self.lambda_reg` times
    the domain-confusion loss.

    :param loader: dataloader
    :param optimizer: optimizer for the loss function
    :param epoch: current epoch
    :param augmenter: data augmenter
    :param print_stats: frequency of printing statistics
    :param writer: summary writer
    :param write_images: frequency of writing images
    :param device: GPU device where the computations should occur
    :return: average training loss over the epoch
    """
    # make sure network is on the gpu and in training mode
    module_to_device(self, device)
    self.train()

    # keep track of the average losses during the epoch
    loss_rec_cum = 0.0
    loss_dc_cum = 0.0
    loss_cum = 0.0
    cnt = 0

    # start epoch
    for i, data in enumerate(loader):

        # transfer to suitable device; each batch is (inputs, domain labels)
        x, dom = data
        x = tensor_to_device(x.float(), device)
        dom = tensor_to_device(dom.long(), device)

        # get the inputs and augment if necessary
        if augmenter is not None:
            x = augmenter(x)

        # zero the gradient buffers
        self.zero_grad()

        # forward prop: the network returns a reconstruction and a domain prediction
        x_pred, dom_pred = self(x)
        x_pred = torch.sigmoid(x_pred)

        # compute loss
        loss_rec = self.loss_rec_fn(x_pred, x)
        loss_dc = self.loss_dc_fn(dom_pred, dom)
        loss = loss_rec + self.lambda_reg * loss_dc
        loss_rec_cum += loss_rec.data.cpu().numpy()
        loss_dc_cum += loss_dc.data.cpu().numpy()
        loss_cum += loss.data.cpu().numpy()
        cnt += 1

        # backward prop
        loss.backward()

        # apply one step in the optimization
        optimizer.step()

        # print statistics if necessary
        if i % print_stats == 0:
            print_frm('Epoch %5d - Iteration %5d/%5d - Loss Rec: %.6f - Loss DC: %.6f - Loss: %.6f'
                      % (epoch, i, len(loader.dataset) / loader.batch_size, loss_rec, loss_dc, loss))

    # don't forget to compute the average and print it
    loss_rec_avg = loss_rec_cum / cnt
    loss_dc_avg = loss_dc_cum / cnt
    loss_avg = loss_cum / cnt
    print_frm('Epoch %5d - Average train loss rec: %.6f - Average train loss DC: %.6f - Average train loss: %.6f'
              % (epoch, loss_rec_avg, loss_dc_avg, loss_avg))

    # log everything
    if writer is not None:

        # always log scalars
        log_scalars([loss_rec_avg, loss_dc_avg, loss_avg],
                    ['train/' + s for s in ['loss-rec', 'loss-dc', 'loss']], writer, epoch=epoch)

        # log images if necessary (last batch of the epoch)
        if write_images:
            log_images_2d([x, x_pred], ['train/' + s for s in ['x', 'x_pred']], writer, epoch=epoch)

    return loss_avg
def test_epoch(self, loader, loss_rec_fn, loss_kl_fn, epoch, writer=None, write_images=False, device=0): """ Tests the network for one epoch :param loader: dataloader :param loss_rec_fn: reconstruction loss function :param loss_kl_fn: kullback leibler loss function :param epoch: current epoch :param writer: summary writer :param write_images: frequency of writing images :param device: GPU device where the computations should occur :return: average testing loss over the epoch """ # make sure network is on the gpu and in training mode module_to_device(self, device) self.eval() # keep track of the average losses during the epoch loss_rec_cum = 0.0 loss_kl_cum = 0.0 loss_cum = 0.0 cnt = 0 # start epoch z = [] li = [] for i, data in enumerate(loader): # transfer to suitable device x = tensor_to_device(data.float(), device) # forward prop x_pred = torch.sigmoid(self(x)) z.append(_reparametrise(self.mu, self.logvar).cpu().data.numpy()) li.append(x.cpu().data.numpy()) # compute loss loss_rec = loss_rec_fn(x_pred, x) loss_kl = loss_kl_fn(self.mu, self.logvar) loss = loss_rec + self.beta * loss_kl loss_rec_cum += loss_rec.data.cpu().numpy() loss_kl_cum += loss_kl.data.cpu().numpy() loss_cum += loss.data.cpu().numpy() cnt += 1 # don't forget to compute the average and print it loss_rec_avg = loss_rec_cum / cnt loss_kl_avg = loss_kl_cum / cnt loss_avg = loss_cum / cnt print_frm( 'Epoch %5d - Average test loss rec: %.6f - Average test loss KL: %.6f - Average test loss: %.6f' % (epoch, loss_rec_avg, loss_kl_avg, loss_avg)) # log everything if writer is not None: # always log scalars log_scalars([loss_rec_avg, loss_kl_avg, loss_avg], ['test/' + s for s in ['loss-rec', 'loss-kl', 'loss']], writer, epoch=epoch) # log images if necessary if write_images: log_images_2d([x, x_pred], ['test/' + s for s in ['x', 'x_pred']], writer, epoch=epoch) return loss_avg
results_final = np.asarray(results_final) # D x N x C x M results_best = np.asarray(results_best) # D x N x C x M # average over classes results_final = 100 * np.mean(results_final, axis=2) results_best = 100 * np.mean(results_best, axis=2) # compute average and standard deviation over experiments results_final_mean = np.mean(results_final, axis=1) results_final_std = np.std(results_final, axis=1) results_best_mean = np.mean(results_best, axis=1) results_best_std = np.std(results_best, axis=1) # report mean performance for i, dom in enumerate(domains): print_frm('Domain: %s' % dom) print_frm('') print_frm('Validation performance final: ') for j, metric in enumerate(metrics): print_frm(' - %s: %.2f (+/- %.2f)' % (metric, results_final_mean[i, j], results_final_std[i, j])) print_frm('') print_frm('Validation performance best: ') for j, metric in enumerate(metrics): print_frm(' - %s: %.2f (+/- %.2f)' % (metric, results_best_mean[i, j], results_best_std[i, j])) print_frm('') print_frm('=================================================')
def generate_model(name, params):
    """
    Factory for the domain-adaptive segmentation networks.

    :param name: architecture identifier ('u-net'/'no-da', 'mmd', 'dat', 'ynet',
                 'wnet' or 'unet-ts'); any other value falls back to the plain U-Net
    :param params: dictionary with the network hyperparameters
    :return: the instantiated network
    """
    # hyperparameters shared by every architecture
    common = dict(in_channels=params['in_channels'], feature_maps=params['fm'], levels=params['levels'],
                  dropout_enc=params['dropout'], dropout_dec=params['dropout'], norm=params['norm'],
                  activation=params['activation'], coi=params['coi'], loss_fn=params['loss'], lr=params['lr'])

    if name in ('u-net', 'no-da'):
        net = UNetDA2D(**common)
    elif name == 'mmd':
        net = UNetMMD2D(lambda_mmd=params['lambda_mmd'], **common)
    elif name == 'dat':
        net = UNetDAT2D(lambda_dat=params['lambda_dat'], input_shape=params['input_size'], **common)
    elif name == 'ynet':
        net = YNet2D(lambda_rec=params['lambda_rec'], **common)
    elif name == 'wnet':
        net = WNet2D(lambda_rec=params['lambda_rec'], lambda_dat=params['lambda_dat'],
                     input_shape=params['input_size'], **common)
    elif name == 'unet-ts':
        net = UNetTS2D(lambda_w=params['lambda_w'], lambda_o=params['lambda_o'], **common)
    else:
        # unrecognized names default to the plain U-Net
        net = UNetDA2D(**common)

    # report the selected configuration
    print_frm('Employed network: %s' % str(net.__class__.__name__))
    print_frm('    - Input channels: %d' % params['in_channels'])
    print_frm('    - Initial feature maps: %d' % params['fm'])
    print_frm('    - Levels: %d' % params['levels'])
    print_frm('    - Dropout: %.2f' % params['dropout'])
    print_frm('    - Normalization: %s' % params['norm'])
    print_frm('    - Activation: %s' % params['activation'])
    print_frm('    - Classes of interest: %s' % str(params['coi']))
    print_frm('    - Initial learning rate: %f' % params['lr'])

    return net
def test_epoch(self, loader_src, loader_tar_ul, loader_tar_l, epoch, writer=None, write_images=False, device=0):
    """
    Tests the network for one epoch

    Behavior depends on self.train_mode: RECONSTRUCTION only evaluates the
    reconstruction/domain-confusion path, SEGMENTATION only the segmentation
    path, and the remaining (joint) mode evaluates both.

    :param loader_src: source dataloader (labeled)
    :param loader_tar_ul: target dataloader (unlabeled)
    :param loader_tar_l: target dataloader (labeled)
    :param epoch: current epoch
    :param writer: summary writer
    :param write_images: frequency of writing images
    :param device: GPU device where the computations should occur
    :return: average testing loss over the epoch
    """
    # perform evaluation on GPU/CPU
    module_to_device(self, device)
    self.eval()

    # keep track of the average loss during the epoch
    loss_seg_src_cum = 0.0
    loss_seg_tar_cum = 0.0
    loss_rec_src_cum = 0.0
    loss_rec_tar_cum = 0.0
    loss_dc_x_cum = 0.0
    loss_dc_y_cum = 0.0
    total_loss_cum = 0.0
    cnt = 0

    # zip dataloaders: iterate the three streams in lockstep (stops at the shortest)
    dl = zip(loader_src, loader_tar_ul, loader_tar_l)

    # start epoch
    y_preds = []
    ys = []
    time_start = datetime.datetime.now()
    for i, data in enumerate(dl):

        # transfer to suitable device
        x_src, y_src = tensor_to_device(data[0], device)
        x_tar_ul = tensor_to_device(data[1], device)
        x_tar_l, y_tar_l = tensor_to_device(data[2], device)
        x_src = x_src.float()
        x_tar_ul = x_tar_ul.float()
        x_tar_l = x_tar_l.float()
        y_src = y_src.long()
        y_tar_l = y_tar_l.long()

        # get domain labels for domain confusion: 0 = source, 1 = target
        dom_labels_x = tensor_to_device(torch.zeros((x_src.size(0) + x_tar_ul.size(0))), device).long()
        dom_labels_x[x_src.size(0):] = 1
        dom_labels_y = tensor_to_device(torch.zeros((x_src.size(0) + x_tar_ul.size(0))), device).long()
        dom_labels_y[x_src.size(0):] = 1

        # check train mode and compute loss; unused losses stay at these zero placeholders
        loss_seg_src = torch.Tensor([0])
        loss_seg_tar = torch.Tensor([0])
        loss_rec_src = torch.Tensor([0])
        loss_rec_tar = torch.Tensor([0])
        loss_dc_x = torch.Tensor([0])
        loss_dc_y = torch.Tensor([0])
        if self.train_mode == RECONSTRUCTION:
            x_src_rec, x_src_rec_dom = self.forward_rec(x_src)
            x_tar_ul_rec, x_tar_ul_rec_dom = self.forward_rec(x_tar_ul)
            loss_rec_src = self.rec_loss(x_src_rec, x_src)
            # NOTE(review): argument order (input, reconstruction) is swapped relative
            # to the source call above — harmless for symmetric losses; confirm
            loss_rec_tar = self.rec_loss(x_tar_ul, x_tar_ul_rec)
            loss_dc_x = self.dc_loss(torch.cat((x_src_rec_dom, x_tar_ul_rec_dom), dim=0), dom_labels_x)
            total_loss = loss_rec_src + loss_rec_tar + self.lambda_dc * loss_dc_x
        elif self.train_mode == SEGMENTATION:
            # switch between reconstructed and original inputs with probability self.p;
            # when the reconstructed input is used, the corresponding domain label is flipped
            if np.random.rand() < self.p:
                y_src_pred, y_src_pred_dom = self.forward_seg(x_src)
            else:
                x_src_rec, _ = self.forward_rec(x_src)
                y_src_pred, y_src_pred_dom = self.forward_seg(x_src_rec)
                dom_labels_y[:x_src.size(0)] = 1
            if np.random.rand() < self.p:
                y_tar_ul_pred, y_tar_ul_pred_dom = self.forward_seg(x_tar_ul)
            else:
                x_tar_ul_rec, _ = self.forward_rec(x_tar_ul)
                y_tar_ul_pred, y_tar_ul_pred_dom = self.forward_seg(x_tar_ul_rec)
                dom_labels_y[x_src.size(0):] = 1
            loss_seg_src = self.seg_loss(y_src_pred, y_src[:, 0, ...])
            loss_dc_y = self.dc_loss(torch.cat((y_src_pred_dom, y_tar_ul_pred_dom), dim=0), dom_labels_y)
            total_loss = loss_seg_src + self.lambda_dc * loss_dc_y
            # supervised loss on the (few) labeled target samples
            y_tar_l_pred, _ = self.forward_seg(x_tar_l)
            loss_seg_tar = self.seg_loss(y_tar_l_pred, y_tar_l[:, 0, ...])
            total_loss = total_loss + loss_seg_tar
        else:
            # joint mode: reconstruction + segmentation + both domain-confusion terms
            x_src_rec, x_src_rec_dom = self.forward_rec(x_src)
            if np.random.rand() < self.p:
                y_src_pred, y_src_pred_dom = self.forward_seg(x_src)
            else:
                y_src_pred, y_src_pred_dom = self.forward_seg(x_src_rec)
                dom_labels_y[:x_src.size(0)] = 1
            x_tar_ul_rec, x_tar_ul_rec_dom = self.forward_rec(x_tar_ul)
            if np.random.rand() < self.p:
                y_tar_ul_pred, y_tar_ul_pred_dom = self.forward_seg(x_tar_ul)
            else:
                y_tar_ul_pred, y_tar_ul_pred_dom = self.forward_seg(x_tar_ul_rec)
                dom_labels_y[x_src.size(0):] = 1
            loss_rec_src = self.rec_loss(x_src_rec, x_src)
            loss_rec_tar = self.rec_loss(x_tar_ul, x_tar_ul_rec)
            loss_seg_src = self.seg_loss(y_src_pred, y_src[:, 0, ...])
            loss_dc_x = self.dc_loss(torch.cat((x_src_rec_dom, x_tar_ul_rec_dom), dim=0), dom_labels_x)
            loss_dc_y = self.dc_loss(torch.cat((y_src_pred_dom, y_tar_ul_pred_dom), dim=0), dom_labels_y)
            total_loss = loss_seg_src + \
                         self.lambda_rec * (loss_rec_src + loss_rec_tar) + \
                         self.lambda_dc * (loss_dc_x + loss_dc_y)
            y_tar_l_pred, _ = self.forward_seg(x_tar_l)
            loss_seg_tar = self.seg_loss(y_tar_l_pred, y_tar_l[:, 0, ...])
            total_loss = total_loss + loss_seg_tar

        loss_seg_src_cum += loss_seg_src.data.cpu().numpy()
        loss_seg_tar_cum += loss_seg_tar.data.cpu().numpy()
        loss_rec_src_cum += loss_rec_src.data.cpu().numpy()
        loss_rec_tar_cum += loss_rec_tar.data.cpu().numpy()
        loss_dc_x_cum += loss_dc_x.data.cpu().numpy()
        loss_dc_y_cum += loss_dc_y.data.cpu().numpy()
        total_loss_cum += total_loss.data.cpu().numpy()
        cnt += 1

        # collect labeled-target predictions for segmentation metric computation
        if self.train_mode == SEGMENTATION or self.train_mode == JOINT:
            for b in range(y_tar_l_pred.size(0)):
                y_preds.append(F.softmax(y_tar_l_pred, dim=1)[b, ...]
                               .view(y_tar_l_pred.size(1), -1).data.cpu().numpy())
                ys.append(y_tar_l[b, 0, ...].flatten().cpu().numpy())

    # keep track of time
    runtime = datetime.datetime.now() - time_start
    seconds = runtime.total_seconds()
    hours = seconds // 3600
    minutes = (seconds - hours * 3600) // 60
    seconds = seconds - hours * 3600 - minutes * 60
    print_frm('Epoch %5d - Runtime for testing: %d hours, %d minutes, %f seconds'
              % (epoch, hours, minutes, seconds))

    # prep for metric computation: per-class jaccard and accuracy metrics
    if self.train_mode == SEGMENTATION or self.train_mode == JOINT:
        y_preds = np.concatenate(y_preds, axis=1)
        ys = np.concatenate(ys)
        js = np.asarray([jaccard((ys == i).astype(int), y_preds[i, :]) for i in range(len(self.coi))])
        ams = np.asarray([accuracy_metrics((ys == i).astype(int), y_preds[i, :]) for i in range(len(self.coi))])

    # don't forget to compute the average and print it
    loss_seg_src_avg = loss_seg_src_cum / cnt
    loss_seg_tar_avg = loss_seg_tar_cum / cnt
    loss_rec_src_avg = loss_rec_src_cum / cnt
    loss_rec_tar_avg = loss_rec_tar_cum / cnt
    loss_dc_x_avg = loss_dc_x_cum / cnt
    loss_dc_y_avg = loss_dc_y_cum / cnt
    total_loss_avg = total_loss_cum / cnt
    print('[%s] Testing Epoch %5d - Loss seg src: %.6f - Loss seg tar: %.6f - Loss rec src: %.6f - Loss rec tar: %.6f - Loss DCX: %.6f - Loss DCY: %.6f - Loss: %.6f'
          % (datetime.datetime.now(), epoch, loss_seg_src_avg, loss_seg_tar_avg, loss_rec_src_avg,
             loss_rec_tar_avg, loss_dc_x_avg, loss_dc_y_avg, total_loss_avg))

    # log everything
    if writer is not None:

        # always log scalars (set depends on the train mode)
        if self.train_mode == RECONSTRUCTION:
            log_scalars([loss_rec_src_avg, loss_rec_tar_avg, loss_dc_x_avg],
                        ['test/' + s for s in ['loss-rec-src', 'loss-rec-tar', 'loss-dc-x']],
                        writer, epoch=epoch)
        elif self.train_mode == SEGMENTATION:
            log_scalars([loss_seg_src_avg, loss_seg_tar_avg, loss_dc_y_avg, np.mean(js, axis=0),
                         *(np.mean(ams, axis=0))],
                        ['test/' + s for s in ['loss-seg-src', 'loss-seg-tar', 'loss-dc-y', 'jaccard',
                                               'accuracy', 'balanced-accuracy', 'precision', 'recall',
                                               'f-score']],
                        writer, epoch=epoch)
        else:
            log_scalars([loss_seg_src_avg, loss_seg_tar_avg, loss_rec_src_avg, loss_rec_tar_avg,
                         loss_dc_x_avg, loss_dc_y_avg, np.mean(js, axis=0), *(np.mean(ams, axis=0))],
                        ['test/' + s for s in ['loss-seg-src', 'loss-seg-tar', 'loss-rec-src',
                                               'loss-rec-tar', 'loss-dc-x', 'loss-dc-y', 'jaccard',
                                               'accuracy', 'balanced-accuracy', 'precision', 'recall',
                                               'f-score']],
                        writer, epoch=epoch)
        log_scalars([total_loss_avg], ['test/' + s for s in ['total-loss']], writer, epoch=epoch)

        # log images if necessary (last batch of the epoch)
        if write_images:
            log_images_2d([x_src.data], ['test/' + s for s in ['src/x']], writer, epoch=epoch)
            if self.train_mode == RECONSTRUCTION:
                log_images_2d([x_src_rec.data, x_tar_ul.data, x_tar_ul_rec.data],
                              ['test/' + s for s in ['src/x-rec', 'tar/x-ul', 'tar/x-ul-rec']],
                              writer, epoch=epoch)
            elif self.train_mode == SEGMENTATION:
                y_src_pred = F.softmax(y_src_pred, dim=1)[:, 1:2, :, :].data
                log_images_2d([y_src.data, y_src_pred],
                              ['test/' + s for s in ['src/y', 'src/y-pred']], writer, epoch=epoch)
                if loader_tar_l is not None:
                    y_tar_l_pred = F.softmax(y_tar_l_pred, dim=1)[:, 1:2, :, :].data
                    log_images_2d([x_tar_l.data, y_tar_l, y_tar_l_pred],
                                  ['test/' + s for s in ['tar/x-l', 'tar/y-l', 'tar/y-l-pred']],
                                  writer, epoch=epoch)
            else:
                y_src_pred = F.softmax(y_src_pred, dim=1)[:, 1:2, :, :].data
                log_images_2d([x_src_rec.data, y_src.data, y_src_pred, x_tar_ul.data, x_tar_ul_rec.data],
                              ['test/' + s for s in ['src/x-rec', 'src/y', 'src/y-pred', 'tar/x-ul',
                                                     'tar/x-ul-rec']],
                              writer, epoch=epoch)
                if loader_tar_l is not None:
                    y_tar_l_pred = F.softmax(y_tar_l_pred, dim=1)[:, 1:2, :, :].data
                    log_images_2d([x_tar_l.data, y_tar_l, y_tar_l_pred],
                                  ['test/' + s for s in ['tar/x-l', 'tar/y-l', 'tar/y-l-pred']],
                                  writer, epoch=epoch)

    return total_loss_avg
def __init__(self, data_path, label_path, input_shape=None, split_orientation='z', split_location=0.50,
             scaling=None, len_epoch=1000, type='tif3d', coi=(0, 1), in_channels=1, orientations=(0, ),
             batch_size=1, data_dtype='uint8', label_dtype='uint8', norm_type='unit', train=True, available=-1):
    """
    Labeled volume dataset: loads the label volume alongside the raw data
    (handled by the parent constructor) and applies the same train/test split,
    optional rescaling and optional sub-selection of labeled data.

    :param data_path: path to the raw data volume
    :param label_path: path to the label volume
    :param split_orientation: axis along which to split train/test ('z', 'y' or 'x')
    :param split_location: relative position of the split (fraction in [0, 1])
    :param scaling: optional per-axis scale factors for resampling the labels
    :param coi: classes of interest
    :param available: if > 0, restrict to this amount of labeled data
    :param train: select the training (True) or testing (False) side of the split
    """
    super().__init__(data_path, input_shape, split_orientation=split_orientation,
                     split_location=split_location, scaling=scaling, len_epoch=len_epoch, type=type,
                     in_channels=in_channels, orientations=orientations, batch_size=batch_size,
                     dtype=data_dtype, norm_type=norm_type, train=train)

    self.label_path = label_path
    self.coi = coi
    self.available = available

    # load labels
    # axis index of the split orientation (d is only used in the non-'z' branch)
    d = 0 if split_orientation == 'z' else 1 if split_orientation == 'y' else 2
    if split_orientation == 'z':
        # for z-splits, only the required slice range is read from disk
        split = int(len(os.listdir(label_path)) * split_location)
        start = 0 if train else split
        # NOTE(review): stop=-1 for the test side — confirm read_volume treats -1
        # as "until the end" rather than dropping the last slice
        stop = split if train else -1
        self.labels = read_volume(label_path, type=type, dtype=label_dtype, start=start, stop=stop)
    else:
        # for y/x splits, the whole volume is read and sliced in memory
        data = read_volume(label_path, type=type, dtype=label_dtype)
        split = int(data.shape[d] * split_location)
        if split_orientation == 'y':
            self.labels = data[:, :split, :] if train else data[:, split:, :]
        else:
            self.labels = data[:, :, :split] if train else data[:, :, split:]

    # rescale the dataset if necessary ('area' interpolation on the label volume)
    if scaling is not None:
        target_size = np.asarray(np.multiply(self.labels.shape, scaling), dtype=int)
        self.labels = F.interpolate(torch.Tensor(self.labels[np.newaxis, np.newaxis, ...]),
                                    size=tuple(target_size), mode='area')[0, 0, ...].numpy()

    # select a crop of the data if necessary
    print_frm('Original dataset size: %d x %d x %d (total: %d)'
              % (self.data.shape[0], self.data.shape[1], self.data.shape[2], self.data.size))
    if available > 0:
        self.data, self.labels = _select_subset(self.data, self.labels, n=available, coi=coi)
    t_str = 'training' if train else 'testing'
    print_frm('Used for %s: %d x %d x %d (total: %d)'
              % (t_str, self.data.shape[0], self.data.shape[1], self.data.shape[2], self.data.size))
def train_epoch(self, loader_src, loader_tar_ul, loader_tar_l, epoch, optimizer=None, augmenter=None,
                print_stats=1, writer=None, write_images=False, device=0):
    """
    Trains the network for one epoch

    Supervised training on the labeled source stream plus, when available, the
    labeled target stream. Note: loader_tar_ul is accepted for interface
    compatibility but is not used in this visible block.

    :param loader_src: source dataloader (labeled)
    :param loader_tar_ul: target dataloader (unlabeled)
    :param loader_tar_l: target dataloader (labeled)
    :param optimizer: optimizer for the loss function
    :param epoch: current epoch
    :param augmenter: data augmenter
    :param print_stats: frequency of printing statistics
    :param writer: summary writer
    :param write_images: frequency of writing images
    :param device: GPU device where the computations should occur
    :return: average training loss over the epoch
    """
    # perform training on GPU/CPU
    module_to_device(self, device)
    self.train()

    # keep track of the average loss during the epoch
    loss_seg_src_cum = 0.0
    loss_seg_tar_cum = 0.0
    total_loss_cum = 0.0
    cnt = 0

    # zip dataloaders: iterate source (and labeled target, when provided) in lockstep
    if loader_tar_l is None:
        dl = zip(loader_src)
    else:
        dl = zip(loader_src, loader_tar_l)

    # start epoch
    time_start = datetime.datetime.now()
    for i, data in enumerate(dl):

        # transfer to suitable device
        data_src = tensor_to_device(data[0], device)
        if loader_tar_l is not None:
            data_tar_l = tensor_to_device(data[1], device)

        # augment if necessary
        # NOTE(review): y_src is not passed through get_labels while y_tar_l is —
        # confirm the source labels already match the network's class indices
        if loader_tar_l is None:
            data_aug = (data_src[0], data_src[1])
            x_src, y_src = augment_samples(data_aug, augmenter=augmenter)
        else:
            data_aug = (data_src[0], data_src[1])
            x_src, y_src = augment_samples(data_aug, augmenter=augmenter)
            data_aug = (data_tar_l[0], data_tar_l[1])
            x_tar_l, y_tar_l = augment_samples(data_aug, augmenter=augmenter)
            y_tar_l = get_labels(y_tar_l, coi=self.coi, dtype=int)

        # zero the gradient buffers
        self.zero_grad()

        # forward prop and compute loss; the target term stays at the zero
        # placeholder when no labeled target data is available
        loss_seg_tar = torch.Tensor([0])
        y_src_pred = self(x_src)
        loss_seg_src = self.seg_loss(y_src_pred, y_src[:, 0, ...])
        total_loss = loss_seg_src
        if loader_tar_l is not None:
            y_tar_l_pred = self(x_tar_l)
            loss_seg_tar = self.seg_loss(y_tar_l_pred, y_tar_l[:, 0, ...])
            total_loss = total_loss + loss_seg_tar

        loss_seg_src_cum += loss_seg_src.data.cpu().numpy()
        loss_seg_tar_cum += loss_seg_tar.data.cpu().numpy()
        total_loss_cum += total_loss.data.cpu().numpy()
        cnt += 1

        # backward prop
        total_loss.backward()

        # apply one step in the optimization
        optimizer.step()

        # print statistics of necessary
        if i % print_stats == 0:
            print('[%s] Epoch %5d - Iteration %5d/%5d - Loss seg src: %.6f - Loss seg tar: %.6f - Loss: %.6f'
                  % (datetime.datetime.now(), epoch, i, len(loader_src.dataset) / loader_src.batch_size,
                     loss_seg_src_cum / cnt, loss_seg_tar_cum / cnt, total_loss_cum / cnt))

    # keep track of time
    runtime = datetime.datetime.now() - time_start
    seconds = runtime.total_seconds()
    hours = seconds // 3600
    minutes = (seconds - hours * 3600) // 60
    seconds = seconds - hours * 3600 - minutes * 60
    print_frm('Epoch %5d - Runtime for training: %d hours, %d minutes, %f seconds'
              % (epoch, hours, minutes, seconds))

    # don't forget to compute the average and print it
    loss_seg_src_avg = loss_seg_src_cum / cnt
    loss_seg_tar_avg = loss_seg_tar_cum / cnt
    total_loss_avg = total_loss_cum / cnt
    print('[%s] Training Epoch %4d - Loss seg src: %.6f - Loss seg tar: %.6f - Loss: %.6f'
          % (datetime.datetime.now(), epoch, loss_seg_src_avg, loss_seg_tar_avg, total_loss_avg))

    # log everything
    if writer is not None:

        # always log scalars
        log_scalars([loss_seg_src_avg, loss_seg_tar_avg, total_loss_avg],
                    ['train/' + s for s in ['loss-seg-src', 'loss-seg-tar', 'total-loss']],
                    writer, epoch=epoch)

        # log images if necessary (last batch of the epoch; foreground channel only)
        if write_images:
            y_src_pred = F.softmax(y_src_pred, dim=1)[:, 1:2, :, :].data
            log_images_2d([x_src.data, y_src.data, y_src_pred],
                          ['train/' + s for s in ['src/x', 'src/y', 'src/y-pred']], writer, epoch=epoch)
            if loader_tar_l is not None:
                y_tar_l_pred = F.softmax(y_tar_l_pred, dim=1)[:, 1:2, :, :].data
                log_images_2d([x_tar_l.data, y_tar_l, y_tar_l_pred],
                              ['train/' + s for s in ['tar/x-l', 'tar/y-l', 'tar/y-l-pred']],
                              writer, epoch=epoch)

    return total_loss_avg
# Cross-validation entry point (grid search over classifier hyperparameters).
# NOTE(review): `argparse` and `yaml` are used below but not imported in this
# visible fragment — presumably imported earlier in the file; verify.
from neuralnets.util.io import print_frm, save
from neuralnets.util.tools import set_seed, log_hparams
from multiprocessing import freeze_support
from sklearn.model_selection import GridSearchCV
from torch.utils.data import DataLoader
from util.tools import parse_params, parse_search_grid, get_transforms
from networks.factory import generate_classifier

if __name__ == '__main__':
    # required for multiprocessing support in frozen executables
    freeze_support()

    """ Parse all the arguments """
    print_frm('Parsing arguments')
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", "-c", help="Path to the configuration file", type=str,
                        default='cross_validate.yaml')
    args = parser.parse_args()
    # all experiment hyperparameters come from the YAML configuration file
    with open(args.config) as file:
        params = parse_params(yaml.load(file, Loader=yaml.FullLoader))

    """ Fix seed (for reproducibility) """
    set_seed(params['seed'])

    """ Load the data