def _load_daae(state_dict, device=0):
    """
    Rebuild a DAAE2D network from a pretrained pytorch state dict

    The architecture hyperparameters are not stored separately: they are
    recovered from the parameter names and tensor shapes in the state dict.

    :param state_dict: state dict of a trained DAAE2D
    :param device: index of the device, forwarded to module_to_device
                   (any CPU fallback is decided by that helper)
    :return: a DAAE2D module with the pretrained parameters loaded
    """
    from networks.daae import DAAE2D

    param_names = list(state_dict.keys())

    # number of feature maps = output channels of the first encoder convolution
    n_feature_maps = state_dict[
        'encoder.features.convblock1.conv1.unit.0.weight'].size(0)

    # the level count is the single character that follows the
    # 'decoder.features.upconv' prefix in the 15th-last parameter name
    depth = int(param_names[-15][len('decoder.features.upconv')])

    # the bottleneck weight gives both the flattened input width and the
    # latent dimension
    bottleneck_weight = state_dict['encoder.bottleneck.0.weight']
    flat_features = bottleneck_weight.size(1)
    latent_dim = bottleneck_weight.size(0)

    # square input side length derived from the flattened feature count
    side = int(
        np.sqrt(flat_features * 2**(3 * depth - 1) / n_feature_maps))

    # the third parameter name reveals which normalization was used
    norm_type = 'batch' if 'norm' in param_names[2] else 'instance'

    # the domain classifier layer fixes the hidden size and domain count
    classifier_weight = state_dict['domain_classifier.linear2.unit.0.weight']
    hidden_units = classifier_weight.size(1)
    domain_count = classifier_weight.size(0)

    # build the network; regularization, dropout and activation are fixed
    net = DAAE2D(lambda_reg=0.0,
                 input_size=[side, side],
                 bottleneck_dim=latent_dim,
                 feature_maps=n_feature_maps,
                 levels=depth,
                 dropout_enc=0.0,
                 norm=norm_type,
                 activation='relu',
                 fc_channels=(hidden_units, domain_count))

    # copy the pretrained parameters and move the module to the device
    net.load_state_dict(state_dict)
    module_to_device(net, device=device)

    return net
def train_epoch(self,
                loader_src,
                loader_tar_ul,
                loader_tar_l,
                optimizer,
                epoch,
                augmenter=None,
                print_stats=1,
                writer=None,
                write_images=False,
                device=0):
    """
    Trains the network for one epoch

    :param loader_src: source dataloader (labeled)
    :param loader_tar_ul: target dataloader (unlabeled); not used by this method
    :param loader_tar_l: target dataloader (labeled), or None to train on the source only
    :param optimizer: optimizer for the loss function
    :param epoch: current epoch
    :param augmenter: data augmenter, forwarded to augment_samples
    :param print_stats: frequency (in iterations) of printing statistics
    :param writer: summary writer, or None to skip logging
    :param write_images: whether to log images of the last batch
    :param device: device where the computations should occur
    :return: average total training loss over the epoch
    """

    # perform training on GPU/CPU
    module_to_device(self, device)
    self.train()

    use_target = loader_tar_l is not None

    # keep track of the average loss during the epoch
    loss_seg_src_cum = 0.0
    loss_seg_tar_cum = 0.0
    total_loss_cum = 0.0
    cnt = 0

    # zip dataloaders (iteration stops at the shortest loader)
    dl = zip(loader_src, loader_tar_l) if use_target else zip(loader_src)

    # start epoch
    time_start = datetime.datetime.now()
    for i, data in enumerate(dl):

        # transfer to suitable device and augment if necessary
        data_src = tensor_to_device(data[0], device)
        x_src, y_src = augment_samples((data_src[0], data_src[1]),
                                       augmenter=augmenter)
        if use_target:
            data_tar_l = tensor_to_device(data[1], device)
            x_tar_l, y_tar_l = augment_samples(
                (data_tar_l[0], data_tar_l[1]), augmenter=augmenter)
            y_tar_l = get_labels(y_tar_l, coi=self.coi, dtype=int)
        # source labels need the same class-of-interest mapping as the
        # target labels before they reach the segmentation loss
        y_src = get_labels(y_src, coi=self.coi, dtype=int)

        # zero the gradient buffers
        self.zero_grad()

        # forward prop and compute loss
        y_src_pred = self(x_src)
        loss_seg_src = self.seg_loss(y_src_pred, y_src[:, 0, ...])
        total_loss = loss_seg_src
        loss_seg_tar_value = 0.0
        if use_target:
            y_tar_l_pred = self(x_tar_l)
            loss_seg_tar = self.seg_loss(y_tar_l_pred, y_tar_l[:, 0, ...])
            total_loss = total_loss + loss_seg_tar
            loss_seg_tar_value = loss_seg_tar.item()

        # accumulate plain Python floats so the running sums stay scalar
        loss_seg_src_cum += loss_seg_src.item()
        loss_seg_tar_cum += loss_seg_tar_value
        total_loss_cum += total_loss.item()
        cnt += 1

        # backward prop
        total_loss.backward()

        # apply one step in the optimization
        optimizer.step()

        # print statistics if necessary
        if i % print_stats == 0:
            print(
                '[%s] Epoch %5d - Iteration %5d/%5d - Loss seg src: %.6f - '
                'Loss seg tar: %.6f - Loss: %.6f' %
                (datetime.datetime.now(), epoch, i,
                 len(loader_src.dataset) / loader_src.batch_size,
                 loss_seg_src_cum / cnt, loss_seg_tar_cum / cnt,
                 total_loss_cum / cnt))

    # keep track of time
    runtime = datetime.datetime.now() - time_start
    seconds = runtime.total_seconds()
    hours = seconds // 3600
    minutes = (seconds - hours * 3600) // 60
    seconds = seconds - hours * 3600 - minutes * 60
    print_frm(
        'Epoch %5d - Runtime for training: %d hours, %d minutes, %f seconds'
        % (epoch, hours, minutes, seconds))

    # don't forget to compute the average and print it
    loss_seg_src_avg = loss_seg_src_cum / cnt
    loss_seg_tar_avg = loss_seg_tar_cum / cnt
    total_loss_avg = total_loss_cum / cnt
    print(
        '[%s] Training Epoch %4d - Loss seg src: %.6f - Loss seg tar: %.6f - Loss: %.6f'
        % (datetime.datetime.now(), epoch, loss_seg_src_avg,
           loss_seg_tar_avg, total_loss_avg))

    # log everything
    if writer is not None:

        # always log scalars
        log_scalars(
            [loss_seg_src_avg, loss_seg_tar_avg, total_loss_avg],
            ['train/' + s
             for s in ['loss-seg-src', 'loss-seg-tar', 'total-loss']],
            writer,
            epoch=epoch)

        # log images of the last processed batch if necessary
        if write_images:
            y_src_pred = F.softmax(y_src_pred, dim=1)[:, 1:2, :, :].data
            log_images_2d(
                [x_src.data, y_src.data, y_src_pred],
                ['train/' + s for s in ['src/x', 'src/y', 'src/y-pred']],
                writer,
                epoch=epoch)
            if use_target:
                y_tar_l_pred = F.softmax(y_tar_l_pred,
                                         dim=1)[:, 1:2, :, :].data
                log_images_2d(
                    [x_tar_l.data, y_tar_l, y_tar_l_pred],
                    ['train/' + s
                     for s in ['tar/x-l', 'tar/y-l', 'tar/y-l-pred']],
                    writer,
                    epoch=epoch)

    return total_loss_avg
def test_epoch(self,
               loader_src,
               loader_tar_ul,
               loader_tar_l,
               epoch,
               writer=None,
               write_images=False,
               device=0):
    """
    Tests the network for one epoch

    :param loader_src: source dataloader (labeled)
    :param loader_tar_ul: target dataloader (unlabeled); not used by this method
    :param loader_tar_l: target dataloader (labeled), or None to evaluate on the
                         source only (target loss is then 0 and no metrics are computed)
    :param epoch: current epoch
    :param writer: summary writer, or None to skip logging
    :param write_images: whether to log images of the last batch
    :param device: device where the computations should occur
    :return: average total testing loss over the epoch
    """

    # perform testing on GPU/CPU
    module_to_device(self, device)
    self.eval()

    use_target = loader_tar_l is not None

    # keep track of the average loss during the epoch
    loss_seg_src_cum = 0.0
    loss_seg_tar_cum = 0.0
    total_loss_cum = 0.0
    cnt = 0

    # zip dataloaders (iteration stops at the shortest loader)
    dl = zip(loader_src, loader_tar_l) if use_target else zip(loader_src)

    # start epoch
    y_preds = []
    ys = []
    time_start = datetime.datetime.now()
    # no gradients are needed for evaluation
    with torch.no_grad():
        for data in dl:

            # transfer to suitable device
            x_src, y_src = tensor_to_device(data[0], device)
            x_src = x_src.float()
            y_src = y_src.long()

            # forward prop and compute loss on the source
            y_src_pred = self(x_src)
            loss_seg_src = self.seg_loss(y_src_pred, y_src[:, 0, ...])
            total_loss = loss_seg_src
            loss_seg_tar_value = 0.0

            if use_target:
                x_tar_l, y_tar_l = tensor_to_device(data[1], device)
                x_tar_l = x_tar_l.float()
                y_tar_l = y_tar_l.long()

                y_tar_l_pred = self(x_tar_l)
                loss_seg_tar = self.seg_loss(y_tar_l_pred,
                                             y_tar_l[:, 0, ...])
                total_loss = total_loss + loss_seg_tar
                loss_seg_tar_value = loss_seg_tar.item()

                # store class probabilities as (classes, batch * pixels) and
                # labels as a flat vector; softmax is computed once per batch
                probs = F.softmax(y_tar_l_pred, dim=1)
                y_preds.append(
                    probs.transpose(0, 1).reshape(probs.size(1),
                                                  -1).cpu().numpy())
                ys.append(y_tar_l[:, 0, ...].reshape(-1).cpu().numpy())

            loss_seg_src_cum += loss_seg_src.item()
            loss_seg_tar_cum += loss_seg_tar_value
            total_loss_cum += total_loss.item()
            cnt += 1

    # keep track of time
    runtime = datetime.datetime.now() - time_start
    seconds = runtime.total_seconds()
    hours = seconds // 3600
    minutes = (seconds - hours * 3600) // 60
    seconds = seconds - hours * 3600 - minutes * 60
    print_frm(
        'Epoch %5d - Runtime for testing: %d hours, %d minutes, %f seconds'
        % (epoch, hours, minutes, seconds))

    # prep for metric computation (one-vs-rest per class of interest)
    if use_target:
        y_preds = np.concatenate(y_preds, axis=1)
        ys = np.concatenate(ys)
        js = np.asarray([
            jaccard((ys == c).astype(int), y_preds[c, :])
            for c in range(len(self.coi))
        ])
        ams = np.asarray([
            accuracy_metrics((ys == c).astype(int), y_preds[c, :])
            for c in range(len(self.coi))
        ])

    # don't forget to compute the average and print it
    loss_seg_src_avg = loss_seg_src_cum / cnt
    loss_seg_tar_avg = loss_seg_tar_cum / cnt
    total_loss_avg = total_loss_cum / cnt
    print(
        '[%s] Testing Epoch %4d - Loss seg src: %.6f - Loss seg tar: %.6f - Loss: %.6f'
        % (datetime.datetime.now(), epoch, loss_seg_src_avg,
           loss_seg_tar_avg, total_loss_avg))

    # log everything
    if writer is not None:

        # always log scalars; metrics only exist with labeled target data
        values = [loss_seg_src_avg, loss_seg_tar_avg, total_loss_avg]
        tags = ['loss-seg-src', 'loss-seg-tar', 'total-loss']
        if use_target:
            values += [np.mean(js, axis=0), *(np.mean(ams, axis=0))]
            tags += [
                'jaccard', 'accuracy', 'balanced-accuracy', 'precision',
                'recall', 'f-score'
            ]
        log_scalars(values, ['test/' + s for s in tags],
                    writer,
                    epoch=epoch)

        # log images of the last processed batch if necessary
        if write_images:
            y_src_pred = F.softmax(y_src_pred, dim=1)[:, 1:2, :, :].data
            images = [x_src.data, y_src.data, y_src_pred]
            names = ['src/x', 'src/y', 'src/y-pred']
            if use_target:
                y_tar_l_pred = F.softmax(y_tar_l_pred,
                                         dim=1)[:, 1:2, :, :].data
                images += [x_tar_l.data, y_tar_l, y_tar_l_pred]
                names += ['tar/x-l', 'tar/y-l', 'tar/y-l-pred']
            log_images_2d(images, ['test/' + s for s in names],
                          writer,
                          epoch=epoch)

    return total_loss_avg
def test_epoch(self,
               loader_src,
               loader_tar_ul,
               loader_tar_l,
               epoch,
               writer=None,
               write_images=False,
               device=0):
    """
    Tests the network for one epoch

    Which losses are evaluated depends on self.train_mode (RECONSTRUCTION,
    SEGMENTATION, or anything else, which is treated as joint training).

    :param loader_src: source dataloader (labeled)
    :param loader_tar_ul: target dataloader (unlabeled)
    :param loader_tar_l: target dataloader (labeled); required, it is always zipped
    :param epoch: current epoch
    :param writer: summary writer, or None to skip logging
    :param write_images: whether to log images of the last batch
    :param device: device where the computations should occur
    :return: average total testing loss over the epoch
    """

    # perform testing on GPU/CPU
    module_to_device(self, device)
    self.eval()

    # keep track of the average loss during the epoch
    loss_seg_src_cum = 0.0
    loss_seg_tar_cum = 0.0
    loss_rec_src_cum = 0.0
    loss_rec_tar_cum = 0.0
    loss_dc_x_cum = 0.0
    loss_dc_y_cum = 0.0
    total_loss_cum = 0.0
    cnt = 0

    # zip dataloaders (iteration stops at the shortest loader)
    dl = zip(loader_src, loader_tar_ul, loader_tar_l)

    # segmentation metrics are only collected in these modes
    collect_metrics = self.train_mode == SEGMENTATION or self.train_mode == JOINT

    # start epoch
    y_preds = []
    ys = []
    time_start = datetime.datetime.now()
    # no gradients are needed for evaluation
    with torch.no_grad():
        for data in dl:

            # transfer to suitable device
            x_src, y_src = tensor_to_device(data[0], device)
            x_tar_ul = tensor_to_device(data[1], device)
            x_tar_l, y_tar_l = tensor_to_device(data[2], device)
            x_src = x_src.float()
            x_tar_ul = x_tar_ul.float()
            x_tar_l = x_tar_l.float()
            y_src = y_src.long()
            y_tar_l = y_tar_l.long()

            # get domain labels for domain confusion:
            # source samples are 0, unlabeled target samples are 1
            n_src = x_src.size(0)
            n_dom = n_src + x_tar_ul.size(0)
            dom_labels_x = tensor_to_device(torch.zeros((n_dom)),
                                            device).long()
            dom_labels_x[n_src:] = 1
            dom_labels_y = tensor_to_device(torch.zeros((n_dom)),
                                            device).long()
            dom_labels_y[n_src:] = 1

            # placeholders for the losses that the current mode does not compute
            loss_seg_src = torch.zeros(1)
            loss_seg_tar = torch.zeros(1)
            loss_rec_src = torch.zeros(1)
            loss_rec_tar = torch.zeros(1)
            loss_dc_x = torch.zeros(1)
            loss_dc_y = torch.zeros(1)

            # check train mode and compute loss
            if self.train_mode == RECONSTRUCTION:
                x_src_rec, x_src_rec_dom = self.forward_rec(x_src)
                x_tar_ul_rec, x_tar_ul_rec_dom = self.forward_rec(x_tar_ul)
                loss_rec_src = self.rec_loss(x_src_rec, x_src)
                # NOTE(review): argument order is (input, reconstruction)
                # here but (reconstruction, input) for the source - only
                # equivalent if rec_loss is symmetric; confirm
                loss_rec_tar = self.rec_loss(x_tar_ul, x_tar_ul_rec)
                loss_dc_x = self.dc_loss(
                    torch.cat((x_src_rec_dom, x_tar_ul_rec_dom), dim=0),
                    dom_labels_x)
                total_loss = loss_rec_src + loss_rec_tar + self.lambda_dc * loss_dc_x
            elif self.train_mode == SEGMENTATION:
                # switch between reconstructed and original inputs;
                # reconstructed inputs are labeled as domain 1
                if np.random.rand() < self.p:
                    y_src_pred, y_src_pred_dom = self.forward_seg(x_src)
                else:
                    x_src_rec, _ = self.forward_rec(x_src)
                    y_src_pred, y_src_pred_dom = self.forward_seg(x_src_rec)
                    dom_labels_y[:n_src] = 1
                if np.random.rand() < self.p:
                    y_tar_ul_pred, y_tar_ul_pred_dom = self.forward_seg(
                        x_tar_ul)
                else:
                    x_tar_ul_rec, _ = self.forward_rec(x_tar_ul)
                    y_tar_ul_pred, y_tar_ul_pred_dom = self.forward_seg(
                        x_tar_ul_rec)
                    dom_labels_y[n_src:] = 1
                loss_seg_src = self.seg_loss(y_src_pred, y_src[:, 0, ...])
                loss_dc_y = self.dc_loss(
                    torch.cat((y_src_pred_dom, y_tar_ul_pred_dom), dim=0),
                    dom_labels_y)
                total_loss = loss_seg_src + self.lambda_dc * loss_dc_y
                y_tar_l_pred, _ = self.forward_seg(x_tar_l)
                loss_seg_tar = self.seg_loss(y_tar_l_pred,
                                             y_tar_l[:, 0, ...])
                total_loss = total_loss + loss_seg_tar
            else:
                x_src_rec, x_src_rec_dom = self.forward_rec(x_src)
                if np.random.rand() < self.p:
                    y_src_pred, y_src_pred_dom = self.forward_seg(x_src)
                else:
                    y_src_pred, y_src_pred_dom = self.forward_seg(x_src_rec)
                    dom_labels_y[:n_src] = 1
                x_tar_ul_rec, x_tar_ul_rec_dom = self.forward_rec(x_tar_ul)
                if np.random.rand() < self.p:
                    y_tar_ul_pred, y_tar_ul_pred_dom = self.forward_seg(
                        x_tar_ul)
                else:
                    y_tar_ul_pred, y_tar_ul_pred_dom = self.forward_seg(
                        x_tar_ul_rec)
                    dom_labels_y[n_src:] = 1
                loss_rec_src = self.rec_loss(x_src_rec, x_src)
                # NOTE(review): same argument-order asymmetry as above
                loss_rec_tar = self.rec_loss(x_tar_ul, x_tar_ul_rec)
                loss_seg_src = self.seg_loss(y_src_pred, y_src[:, 0, ...])
                loss_dc_x = self.dc_loss(
                    torch.cat((x_src_rec_dom, x_tar_ul_rec_dom), dim=0),
                    dom_labels_x)
                loss_dc_y = self.dc_loss(
                    torch.cat((y_src_pred_dom, y_tar_ul_pred_dom), dim=0),
                    dom_labels_y)
                total_loss = loss_seg_src + \
                    self.lambda_rec * (loss_rec_src + loss_rec_tar) + \
                    self.lambda_dc * (loss_dc_x + loss_dc_y)
                _, y_tar_l_pred, _, _ = self(x_tar_l)
                loss_seg_tar = self.seg_loss(y_tar_l_pred,
                                             y_tar_l[:, 0, ...])
                total_loss = total_loss + loss_seg_tar

            # accumulate plain Python floats so the running sums stay scalar
            loss_seg_src_cum += loss_seg_src.item()
            loss_seg_tar_cum += loss_seg_tar.item()
            loss_rec_src_cum += loss_rec_src.item()
            loss_rec_tar_cum += loss_rec_tar.item()
            loss_dc_x_cum += loss_dc_x.item()
            loss_dc_y_cum += loss_dc_y.item()
            total_loss_cum += total_loss.item()
            cnt += 1

            if collect_metrics:
                # store class probabilities as (classes, batch * pixels) and
                # labels as a flat vector; softmax is computed once per batch
                probs = F.softmax(y_tar_l_pred, dim=1)
                y_preds.append(
                    probs.transpose(0, 1).reshape(probs.size(1),
                                                  -1).cpu().numpy())
                ys.append(y_tar_l[:, 0, ...].reshape(-1).cpu().numpy())

    # keep track of time
    runtime = datetime.datetime.now() - time_start
    seconds = runtime.total_seconds()
    hours = seconds // 3600
    minutes = (seconds - hours * 3600) // 60
    seconds = seconds - hours * 3600 - minutes * 60
    print_frm(
        'Epoch %5d - Runtime for testing: %d hours, %d minutes, %f seconds'
        % (epoch, hours, minutes, seconds))

    # prep for metric computation (one-vs-rest per class of interest)
    if collect_metrics:
        y_preds = np.concatenate(y_preds, axis=1)
        ys = np.concatenate(ys)
        js = np.asarray([
            jaccard((ys == c).astype(int), y_preds[c, :])
            for c in range(len(self.coi))
        ])
        ams = np.asarray([
            accuracy_metrics((ys == c).astype(int), y_preds[c, :])
            for c in range(len(self.coi))
        ])

    # don't forget to compute the average and print it
    loss_seg_src_avg = loss_seg_src_cum / cnt
    loss_seg_tar_avg = loss_seg_tar_cum / cnt
    loss_rec_src_avg = loss_rec_src_cum / cnt
    loss_rec_tar_avg = loss_rec_tar_cum / cnt
    loss_dc_x_avg = loss_dc_x_cum / cnt
    loss_dc_y_avg = loss_dc_y_cum / cnt
    total_loss_avg = total_loss_cum / cnt
    print(
        '[%s] Testing Epoch %5d - Loss seg src: %.6f - Loss seg tar: %.6f - '
        'Loss rec src: %.6f - Loss rec tar: %.6f - Loss DCX: %.6f - '
        'Loss DCY: %.6f - Loss: %.6f' %
        (datetime.datetime.now(), epoch, loss_seg_src_avg, loss_seg_tar_avg,
         loss_rec_src_avg, loss_rec_tar_avg, loss_dc_x_avg, loss_dc_y_avg,
         total_loss_avg))

    # log everything
    if writer is not None:

        # always log scalars
        if self.train_mode == RECONSTRUCTION:
            log_scalars(
                [loss_rec_src_avg, loss_rec_tar_avg, loss_dc_x_avg],
                ['test/' + s
                 for s in ['loss-rec-src', 'loss-rec-tar', 'loss-dc-x']],
                writer,
                epoch=epoch)
        elif self.train_mode == SEGMENTATION:
            log_scalars([
                loss_seg_src_avg, loss_seg_tar_avg, loss_dc_y_avg,
                np.mean(js, axis=0), *(np.mean(ams, axis=0))
            ], [
                'test/' + s for s in [
                    'loss-seg-src', 'loss-seg-tar', 'loss-dc-y', 'jaccard',
                    'accuracy', 'balanced-accuracy', 'precision', 'recall',
                    'f-score'
                ]
            ],
                        writer,
                        epoch=epoch)
        else:
            log_scalars([
                loss_seg_src_avg, loss_seg_tar_avg, loss_rec_src_avg,
                loss_rec_tar_avg, loss_dc_x_avg, loss_dc_y_avg,
                np.mean(js, axis=0), *(np.mean(ams, axis=0))
            ], [
                'test/' + s for s in [
                    'loss-seg-src', 'loss-seg-tar', 'loss-rec-src',
                    'loss-rec-tar', 'loss-dc-x', 'loss-dc-y', 'jaccard',
                    'accuracy', 'balanced-accuracy', 'precision', 'recall',
                    'f-score'
                ]
            ],
                        writer,
                        epoch=epoch)
        log_scalars([total_loss_avg],
                    ['test/' + s for s in ['total-loss']],
                    writer,
                    epoch=epoch)

        # log images of the last processed batch if necessary
        if write_images:
            log_images_2d([x_src.data], ['test/' + s for s in ['src/x']],
                          writer,
                          epoch=epoch)
            if self.train_mode == RECONSTRUCTION:
                log_images_2d(
                    [x_src_rec.data, x_tar_ul.data, x_tar_ul_rec.data],
                    ['test/' + s
                     for s in ['src/x-rec', 'tar/x-ul', 'tar/x-ul-rec']],
                    writer,
                    epoch=epoch)
            elif self.train_mode == SEGMENTATION:
                y_src_pred = F.softmax(y_src_pred, dim=1)[:, 1:2, :, :].data
                log_images_2d(
                    [y_src.data, y_src_pred],
                    ['test/' + s for s in ['src/y', 'src/y-pred']],
                    writer,
                    epoch=epoch)
                if loader_tar_l is not None:
                    y_tar_l_pred = F.softmax(y_tar_l_pred,
                                             dim=1)[:, 1:2, :, :].data
                    log_images_2d(
                        [x_tar_l.data, y_tar_l, y_tar_l_pred],
                        ['test/' + s
                         for s in ['tar/x-l', 'tar/y-l', 'tar/y-l-pred']],
                        writer,
                        epoch=epoch)
            else:
                y_src_pred = F.softmax(y_src_pred, dim=1)[:, 1:2, :, :].data
                log_images_2d([
                    x_src_rec.data, y_src.data, y_src_pred, x_tar_ul.data,
                    x_tar_ul_rec.data
                ], [
                    'test/' + s for s in [
                        'src/x-rec', 'src/y', 'src/y-pred', 'tar/x-ul',
                        'tar/x-ul-rec'
                    ]
                ],
                              writer,
                              epoch=epoch)
                if loader_tar_l is not None:
                    y_tar_l_pred = F.softmax(y_tar_l_pred,
                                             dim=1)[:, 1:2, :, :].data
                    log_images_2d(
                        [x_tar_l.data, y_tar_l, y_tar_l_pred],
                        ['test/' + s
                         for s in ['tar/x-l', 'tar/y-l', 'tar/y-l-pred']],
                        writer,
                        epoch=epoch)

    return total_loss_avg
def train_epoch(self,
                loader_src,
                loader_tar_ul,
                loader_tar_l,
                optimizer,
                epoch,
                augmenter=None,
                print_stats=1,
                writer=None,
                write_images=False,
                device=0):
    """
    Trains the network for one epoch

    Which losses are optimized depends on self.train_mode (RECONSTRUCTION,
    SEGMENTATION, or anything else, which is treated as joint training).

    :param loader_src: source dataloader (labeled)
    :param loader_tar_ul: target dataloader (unlabeled)
    :param loader_tar_l: target dataloader (labeled), or None if no labeled target data is used
    :param optimizer: optimizer for the loss function
    :param epoch: current epoch
    :param augmenter: data augmenter, forwarded to augment_samples
    :param print_stats: frequency (in iterations) of printing statistics
    :param writer: summary writer, or None to skip logging
    :param write_images: whether to log images of the last batch
    :param device: device where the computations should occur
    :return: average total training loss over the epoch
    """

    # perform training on GPU/CPU
    module_to_device(self, device)
    self.train()

    use_target = loader_tar_l is not None

    # keep track of the average loss during the epoch
    loss_seg_src_cum = 0.0
    loss_seg_tar_cum = 0.0
    loss_rec_src_cum = 0.0
    loss_rec_tar_cum = 0.0
    loss_dc_x_cum = 0.0
    loss_dc_y_cum = 0.0
    total_loss_cum = 0.0
    cnt = 0

    # zip dataloaders (iteration stops at the shortest loader)
    if use_target:
        dl = zip(loader_src, loader_tar_ul, loader_tar_l)
    else:
        dl = zip(loader_src, loader_tar_ul)

    # start epoch
    time_start = datetime.datetime.now()
    for i, data in enumerate(dl):

        # transfer to suitable device
        data_src = tensor_to_device(data[0], device)
        x_tar_ul = tensor_to_device(data[1], device)
        if use_target:
            data_tar_l = tensor_to_device(data[2], device)

        # augment if necessary; the unlabeled target image is passed as its
        # own label because augment_samples takes an (input, label) pair
        x_src, y_src = augment_samples((data_src[0], data_src[1]),
                                       augmenter=augmenter)
        x_tar_ul, _ = augment_samples((x_tar_ul, x_tar_ul),
                                      augmenter=augmenter)
        if use_target:
            x_tar_l, y_tar_l = augment_samples(
                (data_tar_l[0], data_tar_l[1]), augmenter=augmenter)
            y_tar_l = get_labels(y_tar_l, coi=self.coi, dtype=int)
        y_src = get_labels(y_src, coi=self.coi, dtype=int)
        x_tar_ul = x_tar_ul.float()

        # zero the gradient buffers
        self.zero_grad()

        # get domain labels for domain confusion:
        # source samples are 0, unlabeled target samples are 1
        n_src = x_src.size(0)
        n_dom = n_src + x_tar_ul.size(0)
        dom_labels_x = tensor_to_device(torch.zeros((n_dom)), device).long()
        dom_labels_x[n_src:] = 1
        dom_labels_y = tensor_to_device(torch.zeros((n_dom)), device).long()
        dom_labels_y[n_src:] = 1

        # placeholders for the losses that the current mode does not compute
        loss_seg_src = torch.zeros(1)
        loss_seg_tar = torch.zeros(1)
        loss_rec_src = torch.zeros(1)
        loss_rec_tar = torch.zeros(1)
        loss_dc_x = torch.zeros(1)
        loss_dc_y = torch.zeros(1)

        # check train mode and compute loss
        if self.train_mode == RECONSTRUCTION:
            x_src_rec, x_src_rec_dom = self.forward_rec(x_src)
            x_tar_ul_rec, x_tar_ul_rec_dom = self.forward_rec(x_tar_ul)
            loss_rec_src = self.rec_loss(x_src_rec, x_src)
            # NOTE(review): argument order is (input, reconstruction) here
            # but (reconstruction, input) for the source - only equivalent
            # if rec_loss is symmetric; confirm
            loss_rec_tar = self.rec_loss(x_tar_ul, x_tar_ul_rec)
            loss_dc_x = self.dc_loss(
                torch.cat((x_src_rec_dom, x_tar_ul_rec_dom), dim=0),
                dom_labels_x)
            total_loss = loss_rec_src + loss_rec_tar + self.lambda_dc * loss_dc_x
        elif self.train_mode == SEGMENTATION:
            # switch between reconstructed and original inputs;
            # reconstructed inputs are labeled as domain 1
            if np.random.rand() < self.p:
                y_src_pred, y_src_pred_dom = self.forward_seg(x_src)
            else:
                x_src_rec, _ = self.forward_rec(x_src)
                y_src_pred, y_src_pred_dom = self.forward_seg(x_src_rec)
                dom_labels_y[:n_src] = 1
            if np.random.rand() < self.p:
                y_tar_ul_pred, y_tar_ul_pred_dom = self.forward_seg(x_tar_ul)
            else:
                x_tar_ul_rec, _ = self.forward_rec(x_tar_ul)
                y_tar_ul_pred, y_tar_ul_pred_dom = self.forward_seg(
                    x_tar_ul_rec)
                dom_labels_y[n_src:] = 1
            loss_seg_src = self.seg_loss(y_src_pred, y_src[:, 0, ...])
            loss_dc_y = self.dc_loss(
                torch.cat((y_src_pred_dom, y_tar_ul_pred_dom), dim=0),
                dom_labels_y)
            total_loss = loss_seg_src + self.lambda_dc * loss_dc_y
            if use_target:
                y_tar_l_pred, _ = self.forward_seg(x_tar_l)
                loss_seg_tar = self.seg_loss(y_tar_l_pred,
                                             y_tar_l[:, 0, ...])
                total_loss = total_loss + loss_seg_tar
        else:
            x_src_rec, x_src_rec_dom = self.forward_rec(x_src)
            if np.random.rand() < self.p:
                y_src_pred, y_src_pred_dom = self.forward_seg(x_src)
            else:
                y_src_pred, y_src_pred_dom = self.forward_seg(x_src_rec)
                dom_labels_y[:n_src] = 1
            x_tar_ul_rec, x_tar_ul_rec_dom = self.forward_rec(x_tar_ul)
            if np.random.rand() < self.p:
                y_tar_ul_pred, y_tar_ul_pred_dom = self.forward_seg(x_tar_ul)
            else:
                y_tar_ul_pred, y_tar_ul_pred_dom = self.forward_seg(
                    x_tar_ul_rec)
                dom_labels_y[n_src:] = 1
            loss_rec_src = self.rec_loss(x_src_rec, x_src)
            # NOTE(review): same argument-order asymmetry as above
            loss_rec_tar = self.rec_loss(x_tar_ul, x_tar_ul_rec)
            loss_seg_src = self.seg_loss(y_src_pred, y_src[:, 0, ...])
            loss_dc_x = self.dc_loss(
                torch.cat((x_src_rec_dom, x_tar_ul_rec_dom), dim=0),
                dom_labels_x)
            loss_dc_y = self.dc_loss(
                torch.cat((y_src_pred_dom, y_tar_ul_pred_dom), dim=0),
                dom_labels_y)
            total_loss = loss_seg_src + \
                self.lambda_rec * (loss_rec_src + loss_rec_tar) + \
                self.lambda_dc * (loss_dc_x + loss_dc_y)
            if use_target:
                _, y_tar_l_pred, _, _ = self(x_tar_l)
                loss_seg_tar = self.seg_loss(y_tar_l_pred,
                                             y_tar_l[:, 0, ...])
                total_loss = total_loss + loss_seg_tar

        # accumulate plain Python floats so the running sums stay scalar
        loss_seg_src_cum += loss_seg_src.item()
        loss_seg_tar_cum += loss_seg_tar.item()
        loss_rec_src_cum += loss_rec_src.item()
        loss_rec_tar_cum += loss_rec_tar.item()
        loss_dc_x_cum += loss_dc_x.item()
        loss_dc_y_cum += loss_dc_y.item()
        total_loss_cum += total_loss.item()
        cnt += 1

        # backward prop
        total_loss.backward()

        # apply one step in the optimization
        optimizer.step()

        # print statistics if necessary
        if i % print_stats == 0:
            print(
                '[%s] Epoch %5d - Iteration %5d/%5d - Loss seg src: %.6f - '
                'Loss seg tar: %.6f - Loss rec src: %.6f - '
                'Loss rec tar: %.6f - Loss DCX: %.6f - Loss DCY: %.6f - '
                'Loss: %.6f' %
                (datetime.datetime.now(), epoch, i,
                 len(loader_src.dataset) / loader_src.batch_size,
                 loss_seg_src_cum / cnt, loss_seg_tar_cum / cnt,
                 loss_rec_src_cum / cnt, loss_rec_tar_cum / cnt,
                 loss_dc_x_cum / cnt, loss_dc_y_cum / cnt,
                 total_loss_cum / cnt))

    # keep track of time
    runtime = datetime.datetime.now() - time_start
    seconds = runtime.total_seconds()
    hours = seconds // 3600
    minutes = (seconds - hours * 3600) // 60
    seconds = seconds - hours * 3600 - minutes * 60
    print_frm(
        'Epoch %5d - Runtime for training: %d hours, %d minutes, %f seconds'
        % (epoch, hours, minutes, seconds))

    # don't forget to compute the average and print it
    loss_seg_src_avg = loss_seg_src_cum / cnt
    loss_seg_tar_avg = loss_seg_tar_cum / cnt
    loss_rec_src_avg = loss_rec_src_cum / cnt
    loss_rec_tar_avg = loss_rec_tar_cum / cnt
    loss_dc_x_avg = loss_dc_x_cum / cnt
    loss_dc_y_avg = loss_dc_y_cum / cnt
    total_loss_avg = total_loss_cum / cnt
    print(
        '[%s] Training Epoch %4d - Loss seg src: %.6f - Loss seg tar: %.6f - '
        'Loss rec src: %.6f - Loss rec tar: %.6f - Loss DCX: %.6f - '
        'Loss DCY: %.6f - Loss: %.6f' %
        (datetime.datetime.now(), epoch, loss_seg_src_avg, loss_seg_tar_avg,
         loss_rec_src_avg, loss_rec_tar_avg, loss_dc_x_avg, loss_dc_y_avg,
         total_loss_avg))

    # log everything
    if writer is not None:

        # always log scalars
        if self.train_mode == RECONSTRUCTION:
            log_scalars(
                [loss_rec_src_avg, loss_rec_tar_avg, loss_dc_x_avg],
                ['train/' + s
                 for s in ['loss-rec-src', 'loss-rec-tar', 'loss-dc-x']],
                writer,
                epoch=epoch)
        elif self.train_mode == SEGMENTATION:
            log_scalars(
                [loss_seg_src_avg, loss_seg_tar_avg, loss_dc_y_avg],
                ['train/' + s
                 for s in ['loss-seg-src', 'loss-seg-tar', 'loss-dc-y']],
                writer,
                epoch=epoch)
        else:
            log_scalars([
                loss_seg_src_avg, loss_seg_tar_avg, loss_rec_src_avg,
                loss_rec_tar_avg, loss_dc_x_avg, loss_dc_y_avg
            ], [
                'train/' + s for s in [
                    'loss-seg-src', 'loss-seg-tar', 'loss-rec-src',
                    'loss-rec-tar', 'loss-dc-x', 'loss-dc-y'
                ]
            ],
                        writer,
                        epoch=epoch)
        log_scalars([total_loss_avg],
                    ['train/' + s for s in ['total-loss']],
                    writer,
                    epoch=epoch)

        # log images of the last processed batch if necessary
        if write_images:
            log_images_2d([x_src.data], ['train/' + s for s in ['src/x']],
                          writer,
                          epoch=epoch)
            if self.train_mode == RECONSTRUCTION:
                log_images_2d(
                    [x_src_rec.data, x_tar_ul.data, x_tar_ul_rec.data],
                    ['train/' + s
                     for s in ['src/x-rec', 'tar/x-ul', 'tar/x-ul-rec']],
                    writer,
                    epoch=epoch)
            elif self.train_mode == SEGMENTATION:
                y_src_pred = F.softmax(y_src_pred, dim=1)[:, 1:2, :, :].data
                log_images_2d(
                    [y_src.data, y_src_pred],
                    ['train/' + s for s in ['src/y', 'src/y-pred']],
                    writer,
                    epoch=epoch)
                if use_target:
                    y_tar_l_pred = F.softmax(y_tar_l_pred,
                                             dim=1)[:, 1:2, :, :].data
                    log_images_2d(
                        [x_tar_l.data, y_tar_l, y_tar_l_pred],
                        ['train/' + s
                         for s in ['tar/x-l', 'tar/y-l', 'tar/y-l-pred']],
                        writer,
                        epoch=epoch)
            else:
                y_src_pred = F.softmax(y_src_pred, dim=1)[:, 1:2, :, :].data
                log_images_2d([
                    x_src_rec.data, y_src.data, y_src_pred, x_tar_ul.data,
                    x_tar_ul_rec.data
                ], [
                    'train/' + s for s in [
                        'src/x-rec', 'src/y', 'src/y-pred', 'tar/x-ul',
                        'tar/x-ul-rec'
                    ]
                ],
                              writer,
                              epoch=epoch)
                if use_target:
                    y_tar_l_pred = F.softmax(y_tar_l_pred,
                                             dim=1)[:, 1:2, :, :].data
                    log_images_2d(
                        [x_tar_l.data, y_tar_l, y_tar_l_pred],
                        ['train/' + s
                         for s in ['tar/x-l', 'tar/y-l', 'tar/y-l-pred']],
                        writer,
                        epoch=epoch)

    return total_loss_avg
def test_epoch(self,
               loader_src,
               loader_tar,
               loss_seg_fn,
               loss_rec_fn,
               epoch,
               writer=None,
               write_images=False,
               device=0):
    """
    Tests the network for one epoch

    :param loader_src: source dataloader (should be labeled)
    :param loader_tar: target dataloader (should be labeled)
    :param loss_seg_fn: segmentation loss function
    :param loss_rec_fn: reconstruction loss function
    :param epoch: current epoch
    :param writer: summary writer, or None to skip logging
    :param write_images: whether to log images of the last batch
    :param device: device where the computations should occur
    :return: average total testing loss over the epoch
    """

    # perform testing on GPU/CPU
    module_to_device(self, device)
    self.eval()

    # keep track of the average loss and metrics during the epoch
    loss_seg_cum = 0.0
    loss_rec_cum = 0.0
    total_loss_cum = 0.0
    cnt = 0

    # start epoch
    y_src_preds = []
    ys_src = []
    y_tar_preds = []
    ys_tar = []
    # no gradients are needed for evaluation
    with torch.no_grad():
        for data in zip(loader_src, loader_tar):

            # get inputs and transfer to suitable device
            x_src, y_src = tensor_to_device(data[0], device)
            x_tar, y_tar = tensor_to_device(data[1], device)
            y_src = get_labels(y_src, coi=self.coi, dtype=int)
            y_tar = get_labels(y_tar, coi=self.coi, dtype=int)
            x_src = x_src.float()
            x_tar = x_tar.float()

            # zero the gradient buffers (kept from the original; it does not
            # influence the forward pass)
            self.zero_grad()

            # forward prop; the reconstruction is exposed as an attribute
            # after each forward call, so it has to be read before the next
            y_src_pred = self(x_src)
            x_src_pred = self.reconstruction_outputs
            y_tar_pred = self(x_tar)
            x_tar_pred = self.reconstruction_outputs

            # compute loss
            loss_seg = loss_seg_fn(y_src_pred, y_src)
            loss_rec = 0.5 * (loss_rec_fn(x_src_pred, x_src) +
                              loss_rec_fn(x_tar_pred, x_tar))
            total_loss = loss_seg + self.lambda_rec * loss_rec
            loss_seg_cum += loss_seg.item()
            loss_rec_cum += loss_rec.item()
            total_loss_cum += total_loss.item()
            cnt += 1

            # collect per-sample foreground (class 1) probabilities and
            # labels; softmax and the CPU copy happen once per batch
            y_src_preds.extend(
                F.softmax(y_src_pred, dim=1)[:, 1, ...].cpu().numpy())
            y_tar_preds.extend(
                F.softmax(y_tar_pred, dim=1)[:, 1, ...].cpu().numpy())
            ys_src.extend(y_src[:, 0, ...].cpu().numpy())
            ys_tar.extend(y_tar[:, 0, ...].cpu().numpy())

    # compute interesting metrics
    y_src_preds = np.asarray(y_src_preds)
    y_tar_preds = np.asarray(y_tar_preds)
    ys_src = np.asarray(ys_src)
    ys_tar = np.asarray(ys_tar)
    j_src = jaccard(ys_src, y_src_preds)
    # target predictions are scored against the target labels
    j_tar = jaccard(ys_tar, y_tar_preds)
    a_src, ba_src, p_src, r_src, f_src = accuracy_metrics(
        ys_src, y_src_preds)
    a_tar, ba_tar, p_tar, r_tar, f_tar = accuracy_metrics(
        ys_tar, y_tar_preds)

    # don't forget to compute the average and print it
    loss_seg_avg = loss_seg_cum / cnt
    loss_rec_avg = loss_rec_cum / cnt
    total_loss_avg = total_loss_cum / cnt
    print('[%s] Epoch %5d - Loss seg: %.6f - Loss rec: %.6f - Loss: %.6f' %
          (datetime.datetime.now(), epoch, loss_seg_avg, loss_rec_avg,
           total_loss_avg))

    # log everything
    if writer is not None:

        # always log scalars; tag order matches the value order
        log_scalars([
            loss_seg_avg, loss_rec_avg, total_loss_avg, j_src, a_src, ba_src,
            p_src, r_src, f_src, j_tar, a_tar, ba_tar, p_tar, r_tar, f_tar
        ], [
            'test/' + s for s in [
                'loss-seg', 'loss-rec', 'total-loss', 'src/jaccard',
                'src/accuracy', 'src/balanced-accuracy', 'src/precision',
                'src/recall', 'src/f-score', 'tar/jaccard', 'tar/accuracy',
                'tar/balanced-accuracy', 'tar/precision', 'tar/recall',
                'tar/f-score'
            ]
        ],
                    writer,
                    epoch=epoch)

        # log images of the last processed batch if necessary
        if write_images:
            y_src_pred = F.softmax(y_src_pred, dim=1)[:, 1:2, ...].data
            y_tar_pred = F.softmax(y_tar_pred, dim=1)[:, 1:2, ...].data
            log_images_3d([
                x_src, x_src_pred.data, y_src, y_src_pred, x_tar,
                x_tar_pred.data, y_tar, y_tar_pred
            ], [
                'test/' + s for s in [
                    'src/x', 'src/x-pred', 'src/y', 'src/y-pred', 'tar/x',
                    'tar/x-pred', 'tar/y', 'tar/y-pred'
                ]
            ],
                          writer,
                          epoch=epoch)

    return total_loss_avg
def train_epoch_semi_supervised(self,
                                loader_src,
                                loader_tar_ul,
                                loader_tar_l,
                                loss_seg_fn,
                                loss_rec_fn,
                                optimizer,
                                epoch,
                                augmenter_src=None,
                                augmenter_tar=None,
                                print_stats=1,
                                writer=None,
                                write_images=False,
                                device=0):
    """
    Trains the network for one epoch in a semi-supervised fashion

    The three loaders are iterated in lockstep (``zip``), so the epoch ends as
    soon as the shortest one is exhausted. The segmentation loss is averaged
    over the source and the labeled target batch, the reconstruction loss is
    averaged over the source and the unlabeled target batch.

    :param loader_src: source dataloader (labeled)
    :param loader_tar_ul: target dataloader (unlabeled)
    :param loader_tar_l: target dataloader (labeled)
    :param loss_seg_fn: segmentation loss function
    :param loss_rec_fn: reconstruction loss function
    :param optimizer: optimizer for the loss function
    :param epoch: current epoch
    :param augmenter_src: source data augmenter
    :param augmenter_tar: target data augmenter
    :param print_stats: frequency (in iterations) of printing statistics
    :param writer: summary writer (nothing is logged if None)
    :param write_images: if true, the images of the last batch are logged as well
    :param device: GPU device where the computations should occur
    :return: average training loss over the epoch
    """
    # perform training on GPU/CPU
    module_to_device(self, device)
    self.train()

    # keep track of the average loss during the epoch
    loss_seg_cum = 0.0
    loss_rec_cum = 0.0
    total_loss_cum = 0.0
    cnt = 0

    # start epoch
    for i, data in enumerate(zip(loader_src, loader_tar_ul, loader_tar_l)):

        # transfer to suitable device
        data_src = tensor_to_device(data[0], device)
        x_tar_ul = tensor_to_device(data[1], device)
        data_tar_l = tensor_to_device(data[2], device)

        # augment if necessary
        x_src, y_src = augment_samples(data_src, augmenter=augmenter_src)
        x_tar_l, y_tar_l = augment_samples(data_tar_l, augmenter=augmenter_tar)
        y_src = get_labels(y_src, coi=self.coi, dtype=int)
        y_tar_l = get_labels(y_tar_l, coi=self.coi, dtype=int)
        x_tar_ul = x_tar_ul.float()

        # zero the gradient buffers
        self.zero_grad()

        # forward prop; the reconstruction has to be read right after each
        # forward pass because the attribute is overwritten by the next one
        y_src_pred = self(x_src)
        x_src_pred = self.reconstruction_outputs
        y_tar_ul_pred = self(x_tar_ul)
        x_tar_ul_pred = self.reconstruction_outputs
        y_tar_l_pred = self(x_tar_l)
        x_tar_l_pred = self.reconstruction_outputs

        # compute loss
        loss_seg = 0.5 * (loss_seg_fn(y_src_pred, y_src) +
                          loss_seg_fn(y_tar_l_pred, y_tar_l))
        loss_rec = 0.5 * (loss_rec_fn(x_src_pred, x_src) +
                          loss_rec_fn(x_tar_ul_pred, x_tar_ul))
        total_loss = loss_seg + self.lambda_rec * loss_rec
        loss_seg_cum += loss_seg.data.cpu().numpy()
        loss_rec_cum += loss_rec.data.cpu().numpy()
        total_loss_cum += total_loss.data.cpu().numpy()
        cnt += 1

        # backward prop
        total_loss.backward()

        # apply one step in the optimization
        optimizer.step()

        # print statistics if necessary
        if i % print_stats == 0:
            print(
                '[%s] Epoch %5d - Iteration %5d/%5d - Loss seg: %.6f - Loss rec: %.6f - Loss: %.6f'
                % (datetime.datetime.now(), epoch, i,
                   len(loader_src.dataset) / loader_src.batch_size,
                   loss_seg, loss_rec, total_loss))

    # don't forget to compute the average and print it
    loss_seg_avg = loss_seg_cum / cnt
    loss_rec_avg = loss_rec_cum / cnt
    total_loss_avg = total_loss_cum / cnt
    print('[%s] Epoch %5d - Loss seg: %.6f - Loss rec: %.6f - Loss: %.6f' %
          (datetime.datetime.now(), epoch, loss_seg_avg, loss_rec_avg,
           total_loss_avg))

    # log everything
    if writer is not None:

        # always log scalars (tag order matches the value order)
        log_scalars(
            [loss_seg_avg, loss_rec_avg, total_loss_avg],
            ['train/' + s for s in ['loss-seg', 'loss-rec', 'total-loss']],
            writer,
            epoch=epoch)

        # log images of the last batch if necessary
        if write_images:
            y_src_pred = F.softmax(y_src_pred, dim=1)[:, 1:2, ...].data
            y_tar_l_pred = F.softmax(y_tar_l_pred, dim=1)[:, 1:2, ...].data
            log_images_3d([
                x_src, x_src_pred.data, y_src, y_src_pred,
                x_tar_l, x_tar_l_pred.data, y_tar_l, y_tar_l_pred
            ], [
                'train/' + s for s in [
                    'src/x', 'src/x-pred', 'src/y', 'src/y-pred',
                    'tar/x', 'tar/x-pred', 'tar/y', 'tar/y-pred'
                ]
            ], writer, epoch=epoch)

    return total_loss_avg
def test_epoch(self,
               loader,
               loss_rec_fn,
               loss_kl_fn,
               epoch,
               writer=None,
               write_images=False,
               device=0):
    """
    Tests the network for one epoch

    The total loss is the reconstruction loss plus ``self.beta`` times the
    Kullback-Leibler loss; the reconstruction is the sigmoid of the network
    output.

    :param loader: dataloader
    :param loss_rec_fn: reconstruction loss function
    :param loss_kl_fn: kullback leibler loss function
    :param epoch: current epoch
    :param writer: summary writer (nothing is logged if None)
    :param write_images: if true, the images of the last batch are logged as well
    :param device: GPU device where the computations should occur
    :return: average testing loss over the epoch
    """
    # make sure network is on the right device and in evaluation mode
    module_to_device(self, device)
    self.eval()

    # keep track of the average losses during the epoch
    loss_rec_cum = 0.0
    loss_kl_cum = 0.0
    loss_cum = 0.0
    cnt = 0

    # start epoch (no autograd bookkeeping is needed for evaluation)
    # NOTE: inputs and latent samples are not collected, nothing in this
    # function consumes them and they would grow with the dataset size
    with torch.no_grad():
        for data in loader:

            # transfer to suitable device
            x = tensor_to_device(data.float(), device)

            # forward prop
            x_pred = torch.sigmoid(self(x))

            # compute loss (mu and logvar are set by the forward pass above)
            loss_rec = loss_rec_fn(x_pred, x)
            loss_kl = loss_kl_fn(self.mu, self.logvar)
            loss = loss_rec + self.beta * loss_kl
            loss_rec_cum += loss_rec.data.cpu().numpy()
            loss_kl_cum += loss_kl.data.cpu().numpy()
            loss_cum += loss.data.cpu().numpy()
            cnt += 1

    # don't forget to compute the average and print it
    loss_rec_avg = loss_rec_cum / cnt
    loss_kl_avg = loss_kl_cum / cnt
    loss_avg = loss_cum / cnt
    print_frm(
        'Epoch %5d - Average test loss rec: %.6f - Average test loss KL: %.6f - Average test loss: %.6f'
        % (epoch, loss_rec_avg, loss_kl_avg, loss_avg))

    # log everything
    if writer is not None:

        # always log scalars
        log_scalars([loss_rec_avg, loss_kl_avg, loss_avg],
                    ['test/' + s for s in ['loss-rec', 'loss-kl', 'loss']],
                    writer,
                    epoch=epoch)

        # log images of the last batch if necessary
        if write_images:
            log_images_2d([x, x_pred],
                          ['test/' + s for s in ['x', 'x_pred']],
                          writer,
                          epoch=epoch)

    return loss_avg
def train_epoch(self,
                loader,
                loss_rec_fn,
                loss_kl_fn,
                optimizer,
                epoch,
                augmenter=None,
                print_stats=1,
                writer=None,
                write_images=False,
                device=0):
    """
    Runs a single training epoch over the given loader

    The optimized loss is the reconstruction loss plus ``self.beta`` times the
    Kullback-Leibler loss; the reconstruction is the sigmoid of the network
    output.

    :param loader: dataloader
    :param loss_rec_fn: reconstruction loss function
    :param loss_kl_fn: kullback leibler loss function
    :param optimizer: optimizer for the loss function
    :param epoch: current epoch
    :param augmenter: data augmenter
    :param print_stats: frequency of printing statistics
    :param writer: summary writer
    :param write_images: frequency of writing images
    :param device: GPU device where the computations should occur
    :return: average training loss over the epoch
    """
    # move the network to the requested device and switch to training mode
    module_to_device(self, device)
    self.train()

    # running sums of the loss terms and the number of processed batches
    rec_sum = 0.0
    kl_sum = 0.0
    total_sum = 0.0
    n_batches = 0

    for step, batch in enumerate(loader):

        # fetch the inputs on the right device, augmented when requested
        inputs = tensor_to_device(batch.float(), device)
        if augmenter is not None:
            inputs = augmenter(inputs)

        # reset gradients, then run the forward pass
        self.zero_grad()
        recon = torch.sigmoid(self(inputs))

        # individual loss terms and their weighted combination
        loss_rec = loss_rec_fn(recon, inputs)
        loss_kl = loss_kl_fn(self.mu, self.logvar)
        loss = loss_rec + self.beta * loss_kl

        # bookkeeping for the epoch averages
        rec_sum += loss_rec.data.cpu().numpy()
        kl_sum += loss_kl.data.cpu().numpy()
        total_sum += loss.data.cpu().numpy()
        n_batches += 1

        # backward pass followed by a parameter update
        loss.backward()
        optimizer.step()

        # report intermediate statistics
        if step % print_stats == 0:
            n_steps = len(loader.dataset) / loader.batch_size
            print_frm(
                'Epoch %5d - Iteration %5d/%5d - Loss Rec: %.6f - Loss KL: %.6f - Loss: %.6f'
                % (epoch, step, n_steps, loss_rec, loss_kl, loss))

    # epoch averages
    loss_rec_avg = rec_sum / n_batches
    loss_kl_avg = kl_sum / n_batches
    loss_avg = total_sum / n_batches
    print_frm(
        'Epoch %5d - Average train loss rec: %.6f - Average train loss KL: %.6f - Average train loss: %.6f'
        % (epoch, loss_rec_avg, loss_kl_avg, loss_avg))

    # nothing to log without a writer
    if writer is None:
        return loss_avg

    # scalars are always logged
    log_scalars([loss_rec_avg, loss_kl_avg, loss_avg],
                ['train/loss-rec', 'train/loss-kl', 'train/loss'],
                writer,
                epoch=epoch)

    # images of the last batch are optional
    if write_images:
        log_images_2d([inputs, recon], ['train/x', 'train/x_pred'],
                      writer,
                      epoch=epoch)

    return loss_avg
def segment_multichannel_3d(data,
                            net,
                            input_shape,
                            in_channels=1,
                            batch_size=1,
                            step_size=None,
                            train=False,
                            track_progress=False,
                            device=0,
                            orientation=0,
                            normalization='unit'):
    """
    Segment a multichannel 3D image using a specific network

    The image is processed with a sliding window. Window predictions are
    accumulated together with their weights and normalized at the end.

    :param data: 4D array (C, Z, Y, X) representing the multichannel 3D image
    :param net: image-to-image segmentation network
    :param input_shape: size of the inputs (either 2 or 3-tuple)
    :param in_channels: amount of subsequent slices that serve as input for the network (should be odd)
    :param batch_size: batch size for processing
    :param step_size: step size of the sliding window
    :param train: evaluate the network in training mode
    :param track_progress: optionally, for tracking progress with progress bar
    :param device: GPU device where the computations should occur
    :param orientation: orientation to perform segmentation: 0-Z, 1-Y, 2-X (only for 2D based segmentation)
    :param normalization: type of data normalization (unit, z or minmax)
    :return: the segmented image (one channel per network output channel)
    """
    # make sure we compute everything on the correct device
    module_to_device(net, device)

    # set the network in the correct mode
    if train:
        net.train()
    else:
        net.eval()

    # orient data if necessary
    data = _orient(data, orientation)

    # pad data if necessary
    data, pad_width = _pad(data, input_shape, in_channels)

    # 2D or 3D
    is2d = len(input_shape) == 2

    # get the amount of channels
    channels = data.shape[0]
    if is2d:
        channels = in_channels

    # initialize the step size
    step_size = _init_step_size(step_size, input_shape, is2d)

    # gaussian window for smooth block merging
    g_window = _init_gaussian_window(input_shape, is2d)

    # allocate space
    seg_cum = np.zeros((net.out_channels, *data.shape[1:]))
    counts_cum = np.zeros(data.shape[1:])

    # define sliding window
    sw = _init_sliding_window(data, step_size, input_shape, in_channels, is2d,
                              track_progress, normalization)

    # start prediction
    batch_counter = 0
    batch = np.zeros((batch_size, channels, *input_shape))
    positions = np.zeros((batch_size, 3), dtype=int)
    for (z, y, x, inputs) in sw:

        # fill batch
        batch[batch_counter, ...] = inputs
        positions[batch_counter, :] = [z, y, x]

        # increment batch counter
        batch_counter += 1

        # perform segmentation when a full batch is filled
        if batch_counter == batch_size:
            # process a single batch
            _process_batch(net, batch, device, seg_cum, counts_cum, g_window,
                           positions, batch_size, input_shape, in_channels,
                           is2d)

            # reset batch counter
            batch_counter = 0

    # don't forget to process the last batch: only the entries filled since
    # the last flush are valid, the remainder of the buffer holds stale (or
    # all-zero) windows that must not be accumulated again
    # NOTE(review): assumes _process_batch accumulates exactly `batch_size`
    # entries of the batch/positions it receives - confirm against the helper
    if batch_counter > 0:
        _process_batch(net, batch[:batch_counter], device, seg_cum,
                       counts_cum, g_window, positions[:batch_counter],
                       batch_counter, input_shape, in_channels, is2d)

    # crop out the symmetric extension and compute segmentation
    data, seg_cum, counts_cum = _crop(data, seg_cum, counts_cum, pad_width)
    for c in range(net.out_channels):
        seg_cum[c, ...] = np.divide(seg_cum[c, ...], counts_cum)

    # reorient data to its original orientation
    data = _orient(data, orientation)
    seg_cum = _orient(seg_cum, orientation)

    return seg_cum
def segment_multichannel_2d(data,
                            net,
                            input_shape,
                            batch_size=1,
                            step_size=None,
                            train=False,
                            track_progress=False,
                            device=0,
                            normalization='unit'):
    """
    Segment a multichannel 2D image using a specific network

    The image is processed with a sliding window. Window predictions are
    accumulated together with their weights and normalized at the end.

    :param data: 3D array (C, Y, X) representing the multichannel 2D image
    :param net: image-to-image segmentation network
    :param input_shape: size of the inputs (2-tuple)
    :param batch_size: batch size for processing
    :param step_size: step size of the sliding window
    :param train: evaluate the network in training mode
    :param track_progress: optionally, for tracking progress with progress bar
    :param device: GPU device where the computations should occur
    :param normalization: type of data normalization (unit, z or minmax)
    :return: the segmented image (one channel per network output channel)
    """
    # make sure we compute everything on the correct device
    module_to_device(net, device)

    # set the network in the correct mode
    if train:
        net.train()
    else:
        net.eval()

    # pad data if necessary (a singleton axis is inserted for the padding
    # helper and removed again afterwards)
    data, pad_width = _pad(data[:, np.newaxis, ...], input_shape, 1)
    data = data[:, 0, ...]

    # get the amount of channels
    channels = data.shape[0]

    # initialize the step size
    step_size = _init_step_size(step_size, input_shape, True)

    # gaussian window for smooth block merging
    g_window = _init_gaussian_window(input_shape, True)

    # allocate space
    seg_cum = np.zeros((net.out_channels, 1, *data.shape[1:]))
    counts_cum = np.zeros((1, *data.shape[1:]))

    # define sliding window
    sw = _init_sliding_window(data[np.newaxis, ...],
                              [channels, *step_size[1:]], input_shape,
                              channels, True, track_progress, normalization)

    # start prediction
    batch_counter = 0
    batch = np.zeros((batch_size, channels, *input_shape))
    positions = np.zeros((batch_size, 3), dtype=int)
    for (z, y, x, inputs) in sw:

        # fill batch
        batch[batch_counter, ...] = inputs
        positions[batch_counter, :] = [z, y, x]

        # increment batch counter
        batch_counter += 1

        # perform segmentation when a full batch is filled
        if batch_counter == batch_size:
            # process a single batch
            _process_batch(net, batch, device, seg_cum, counts_cum, g_window,
                           positions, batch_size, input_shape, 1, True)

            # reset batch counter
            batch_counter = 0

    # don't forget to process the last batch: only the entries filled since
    # the last flush are valid, the remainder of the buffer holds stale (or
    # all-zero) windows that must not be accumulated again
    # NOTE(review): assumes _process_batch accumulates exactly `batch_size`
    # entries of the batch/positions it receives - confirm against the helper
    if batch_counter > 0:
        _process_batch(net, batch[:batch_counter], device, seg_cum,
                       counts_cum, g_window, positions[:batch_counter],
                       batch_counter, input_shape, 1, True)

    # crop out the symmetric extension and compute segmentation
    data, seg_cum, counts_cum = _crop(data[:, np.newaxis, ...], seg_cum,
                                      counts_cum, pad_width)
    for c in range(net.out_channels):
        seg_cum[c, ...] = np.divide(seg_cum[c, ...], counts_cum)
    seg_cum = seg_cum[:, 0, ...]

    return seg_cum