def main(hdf_file, plot_dir):
    os.makedirs(plot_dir, exist_ok=True)

    # set up the datasource
    extractor = extr.DataExtractor(categories=(defs.KEY_IMAGES, defs.KEY_LABELS))
    indexing_strategy = extr.SliceIndexing()
    dataset = extr.PymiaDatasource(hdf_file, indexing_strategy, extractor)

    seed = 1
    np.random.seed(seed)
    sample_idx = 55

    # set up transformations without augmentation
    transforms_augmentation = []
    transforms_before_augmentation = [tfm.Permute(permutation=(2, 0, 1)), ]  # to have the channel dimension first
    transforms_after_augmentation = [tfm.Squeeze(entries=(defs.KEY_LABELS,)), ]  # get rid of the channel dimension for the labels
    train_transforms = tfm.ComposeTransform(
        transforms_before_augmentation + transforms_augmentation + transforms_after_augmentation)
    dataset.set_transform(train_transforms)
    sample = dataset[sample_idx]
    plot_sample(plot_dir, 'none', sample)

    # augmentation with pymia
    transforms_augmentation = [augm.RandomRotation90(axes=(-2, -1)), augm.RandomMirror()]
    train_transforms = tfm.ComposeTransform(
        transforms_before_augmentation + transforms_augmentation + transforms_after_augmentation)
    dataset.set_transform(train_transforms)
    sample = dataset[sample_idx]
    plot_sample(plot_dir, 'pymia', sample)

    # augmentation with batchgenerators
    transforms_augmentation = [BatchgeneratorsTransform([
        bg_tfm.spatial_transforms.MirrorTransform(axes=(0, 1), data_key=defs.KEY_IMAGES, label_key=defs.KEY_LABELS),
        bg_tfm.noise_transforms.GaussianBlurTransform(blur_sigma=(0.2, 1.0),
                                                      data_key=defs.KEY_IMAGES, label_key=defs.KEY_LABELS),
    ])]
    train_transforms = tfm.ComposeTransform(
        transforms_before_augmentation + transforms_augmentation + transforms_after_augmentation)
    dataset.set_transform(train_transforms)
    sample = dataset[sample_idx]
    plot_sample(plot_dir, 'batchgenerators', sample)

    # augmentation with TorchIO
    transforms_augmentation = [TorchIOTransform([
        # note: axes must be a tuple; ('LR') would be the plain string 'LR'
        tio.RandomFlip(axes=('LR',), flip_probability=1.0, keys=(defs.KEY_IMAGES, defs.KEY_LABELS), seed=seed),
        tio.RandomAffine(scales=(0.9, 1.2), degrees=10, isotropic=False, default_pad_value='otsu',
                         image_interpolation='NEAREST', keys=(defs.KEY_IMAGES, defs.KEY_LABELS), seed=seed),
    ])]
    train_transforms = tfm.ComposeTransform(
        transforms_before_augmentation + transforms_augmentation + transforms_after_augmentation)
    dataset.set_transform(train_transforms)
    sample = dataset[sample_idx]
    plot_sample(plot_dir, 'torchio', sample)
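The example can be driven from the command line with a small entry point; a minimal sketch, assuming the module-level imports above are in place (the argument defaults are illustrative, not taken from the original script):

import argparse

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Augmentation example.')
    parser.add_argument('--hdf_file', type=str, default='data/example-dataset.h5',
                        help='Path to the pymia dataset file.')
    parser.add_argument('--plot_dir', type=str, default='plots',
                        help='Directory into which the augmented sample plots are written.')
    args = parser.parse_args()
    main(args.hdf_file, args.plot_dir)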
def occlude(tester, result_dir: str, temporal_dim: int):
    for temporal_idx in range(temporal_dim):
        print('Occlude temporal dimension {}...'.format(temporal_idx))

        # modify extraction transform for occlusion experiment
        tester.data_handler.extraction_transform_valid = pymia_tfm.ComposeTransform([
            pymia_tfm.Squeeze(entries=(pymia_def.KEY_IMAGES, pymia_def.KEY_LABELS,
                                       data.ID_MASK_FG, data.ID_MASK_T1H2O),
                              squeeze_axis=0),
            tfm.MRFTemporalOcclusion(temporal_idx=temporal_idx, entries=(pymia_def.KEY_IMAGES, ))
        ])

        tester.result_dir = os.path.join(result_dir, 'occlusion{}'.format(temporal_idx))
        tester.predict()
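tfm.MRFTemporalOcclusion is project-specific and not shown here. A minimal sketch of what such an occlusion transform could look like, assuming the images are stored channels-last with the temporal (contrast) dimension as the last axis:

import numpy as np

class TemporalOcclusion:
    """Zero one temporal frame so its contribution to the prediction can be measured."""

    def __init__(self, temporal_idx: int, entries=('images', )):
        self.temporal_idx = temporal_idx
        self.entries = entries

    def __call__(self, sample: dict) -> dict:
        for entry in self.entries:
            occluded = sample[entry].copy()
            occluded[..., self.temporal_idx] = 0  # assumption: temporal dimension is the last axis
            sample[entry] = occluded
        return sample

Comparing the metrics of each occlusion run against the unoccluded baseline then indicates how much each temporal frame contributes to the prediction.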
def main(hdf_file, log_dir):
    # initialize the evaluator with the metrics and the labels to evaluate
    metrics = [metric.DiceCoefficient()]
    labels = {1: 'WHITEMATTER',
              2: 'GREYMATTER',
              3: 'HIPPOCAMPUS',
              4: 'AMYGDALA',
              5: 'THALAMUS'
              }
    evaluator = eval_.SegmentationEvaluator(metrics, labels)

    # we want to log the mean and standard deviation of the metrics among all subjects of the dataset
    functions = {'MEAN': np.mean, 'STD': np.std}
    statistics_aggregator = writer.StatisticsAggregator(functions=functions)
    console_writer = writer.ConsoleStatisticsWriter(functions=functions)

    # initialize TensorBoard writer
    tb = tensorboard.SummaryWriter(os.path.join(log_dir, 'logging-example-torch'))

    # set up the training datasource
    train_subjects, valid_subjects = ['Subject_1', 'Subject_2', 'Subject_3'], ['Subject_4']
    extractor = extr.DataExtractor(categories=(defs.KEY_IMAGES, defs.KEY_LABELS))
    indexing_strategy = extr.SliceIndexing()

    augmentation_transforms = [augm.RandomElasticDeformation(), augm.RandomMirror()]
    transforms = [tfm.Permute(permutation=(2, 0, 1)), tfm.Squeeze(entries=(defs.KEY_LABELS,))]
    train_transforms = tfm.ComposeTransform(augmentation_transforms + transforms)

    train_dataset = extr.PymiaDatasource(hdf_file, indexing_strategy, extractor, train_transforms,
                                         subject_subset=train_subjects)

    # set up the validation datasource
    valid_transforms = tfm.ComposeTransform([tfm.Permute(permutation=(2, 0, 1))])
    valid_dataset = extr.PymiaDatasource(hdf_file, indexing_strategy, extractor, valid_transforms,
                                         subject_subset=valid_subjects)

    direct_extractor = extr.ComposeExtractor(
        [extr.SubjectExtractor(),
         extr.ImagePropertiesExtractor(),
         extr.DataExtractor(categories=(defs.KEY_LABELS,))]
    )
    assembler = assm.SubjectAssembler(valid_dataset)

    # torch-specific handling
    pytorch_train_dataset = pymia_torch.PytorchDatasetAdapter(train_dataset)
    train_loader = torch_data.dataloader.DataLoader(pytorch_train_dataset, batch_size=16, shuffle=True)

    pytorch_valid_dataset = pymia_torch.PytorchDatasetAdapter(valid_dataset)
    valid_loader = torch_data.dataloader.DataLoader(pytorch_valid_dataset, batch_size=16, shuffle=False)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')  # `device` was referenced but undefined
    u_net = unet.UNetModel(ch_in=2, ch_out=6, n_channels=16, n_pooling=3).to(device)
    print(u_net)

    optimizer = optim.Adam(u_net.parameters(), lr=1e-3)
    train_batches = len(train_loader)

    # looping over the data in the dataset
    epochs = 100
    for epoch in range(epochs):
        u_net.train()
        print(f'Epoch {epoch + 1}/{epochs}')

        # training
        print('training')
        for i, batch in enumerate(train_loader):
            x, y = batch[defs.KEY_IMAGES].to(device), batch[defs.KEY_LABELS].to(device).long()

            logits = u_net(x)
            optimizer.zero_grad()
            loss = F.cross_entropy(logits, y)
            loss.backward()
            optimizer.step()

            tb.add_scalar('train/loss', loss.item(), epoch * train_batches + i)
            print(f'[{i + 1}/{train_batches}]\tloss: {loss.item()}')

        # validation
        print('validation')
        with torch.no_grad():
            u_net.eval()
            valid_batches = len(valid_loader)
            for i, batch in enumerate(valid_loader):
                x, sample_indices = batch[defs.KEY_IMAGES].to(device), batch[defs.KEY_SAMPLE_INDEX]

                logits = u_net(x)
                prediction = logits.argmax(dim=1, keepdim=True)
                numpy_prediction = prediction.cpu().numpy().transpose((0, 2, 3, 1))

                is_last = i == valid_batches - 1
                assembler.add_batch(numpy_prediction, sample_indices.numpy(), is_last)

            for subject_index in assembler.subjects_ready:
                subject_prediction = assembler.get_assembled_subject(subject_index)

                # extract the target and image properties from the validation datasource
                # (was train_dataset, which would look up the wrong subjects)
                direct_sample = valid_dataset.direct_extract(direct_extractor, subject_index)
                target, image_properties = direct_sample[defs.KEY_LABELS], direct_sample[defs.KEY_PROPERTIES]

                # evaluate the prediction against the reference
                evaluator.evaluate(subject_prediction[..., 0], target[..., 0], direct_sample[defs.KEY_SUBJECT])

                # calculate mean and standard deviation of each metric
                results = statistics_aggregator.calculate(evaluator.results)
                # log to TensorBoard into category valid
                for result in results:
                    tb.add_scalar(f'valid/{result.metric}-{result.id_}', result.value, epoch)

                console_writer.write(evaluator.results)

                # clear results such that the evaluator is ready for the next evaluation
                evaluator.clear()
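Besides the console writer used above, pymia's writer module also provides CSV writers, handy if the per-subject results should be persisted; a sketch with illustrative file paths:

csv_writer = writer.CSVWriter(os.path.join(log_dir, 'results.csv'))
csv_summary_writer = writer.CSVStatisticsWriter(os.path.join(log_dir, 'results-summary.csv'), functions=functions)

# call these before evaluator.clear(), e.g. right after console_writer.write(...)
csv_writer.write(evaluator.results)
csv_summary_writer.write(evaluator.results)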
def __init__(self, config: cfg.Configuration, subjects_train, subjects_valid, subjects_test,
             collate_fn=pymia_cnv.TensorFlowCollate()):
    super().__init__()
    indexing_strategy = PointCloudIndexing(config.no_points)
    self.dataset = pymia_extr.ParameterizableDataset(config.database_file,
                                                     indexing_strategy,
                                                     pymia_extr.SubjectExtractor(),  # for the usual select_indices
                                                     None)

    self.no_subjects_train = len(subjects_train)
    self.no_subjects_valid = len(subjects_valid)
    self.no_subjects_test = len(subjects_valid)  # same as validation for this kind of cross-validation

    # get sampler ids by subjects
    sampler_ids_train = pymia_extr.select_indices(self.dataset, pymia_extr.SubjectSelection(subjects_train))
    sampler_ids_valid = pymia_extr.select_indices(self.dataset, pymia_extr.SubjectSelection(subjects_valid))

    categories = ('images', 'labels')
    categories_tfm = ('images', 'labels')

    if config.use_image_information:
        image_information_categories = (data.KEY_IMAGE_INFORMATION, )
        categories += image_information_categories
        categories_tfm += image_information_categories
        collate_fn.entries += image_information_categories

    # define point cloud shuffler for augmentation
    sizes = {}
    for idx in range(len(self.dataset.get_subjects())):
        sample = self.dataset.direct_extract(PointCloudSizeExtractor(), idx)
        sizes[idx] = sample['size']
    self.point_cloud_shuffler = aug.PointCloudShuffler(sizes)
    # will only shuffle once at instantiation because shuffle() is not called during training (see set_seed)
    self.point_cloud_shuffler_valid = aug.PointCloudShuffler(sizes)

    data_extractor_train = aug.ShuffledDataExtractor(self.point_cloud_shuffler, categories)
    data_extractor_valid = aug.ShuffledDataExtractor(self.point_cloud_shuffler_valid, categories)
    data_extractor_test = aug.ShuffledDataExtractor(self.point_cloud_shuffler_valid,
                                                    categories=('indices', 'labels'))

    # define extractors
    self.extractor_train = pymia_extr.ComposeExtractor([
        pymia_extr.NamesExtractor(),  # required for SelectiveDataExtractor
        pymia_extr.SubjectExtractor(),  # required for plotting
        PointCloudSizeExtractor(),  # to init_shape in SubjectAssembler
        data_extractor_train,
        pymia_extr.IndexingExtractor(),  # for SubjectAssembler (assembling)
        pymia_extr.ImageShapeExtractor()  # for SubjectAssembler (shape)
    ])

    self.extractor_valid = pymia_extr.ComposeExtractor([
        pymia_extr.NamesExtractor(),  # required for SelectiveDataExtractor
        pymia_extr.SubjectExtractor(),  # required for plotting
        PointCloudSizeExtractor(),  # to init_shape in SubjectAssembler
        data_extractor_valid,
        pymia_extr.IndexingExtractor(),
        pymia_extr.ImageShapeExtractor()
    ])

    self.extractor_test = pymia_extr.ComposeExtractor([
        pymia_extr.NamesExtractor(),  # required for SelectiveDataExtractor
        pymia_extr.SubjectExtractor(),
        data_extractor_test,  # we need the indices, i.e. the point's coordinates,
                              # to convert the point cloud back to an image
        pymia_extr.DataExtractor(categories=('gt', ), ignore_indexing=True),
        # the ground truth is used for the validation at config.save_validation_nth_epoch
        pymia_extr.ImagePropertiesExtractor(),
        pymia_extr.ImageShapeExtractor()
    ])

    # define transforms for extraction
    self.extraction_transform_train = pymia_tfm.ComposeTransform([
        pymia_tfm.Squeeze(entries=('labels', ), squeeze_axis=-1),  # for PyTorch loss functions
        pymia_tfm.LambdaTransform(lambda_fn=lambda np_data: np_data.astype(np.int64),
                                  entries=('labels', )),  # for PyTorch loss functions
    ])

    self.extraction_transform_valid = pymia_tfm.ComposeTransform([
        pymia_tfm.Squeeze(entries=('labels', ), squeeze_axis=-1),  # for PyTorch loss functions
        pymia_tfm.LambdaTransform(lambda_fn=lambda np_data: np_data.astype(np.int64),
                                  entries=('labels', ))  # for PyTorch loss functions
    ])

    if config.use_jitter:
        self.extraction_transform_train.transforms.append(aug.PointCloudJitter())
        self.extraction_transform_valid.transforms.append(aug.PointCloudJitter())

    if config.use_rotation:
        self.extraction_transform_train.transforms.append(aug.PointCloudRotate())
        self.extraction_transform_valid.transforms.append(aug.PointCloudRotate())

    # need to add probability concatenation after augmentation!
    if config.use_point_feature:
        self.extraction_transform_train.transforms.append(ConcatenateCoordinatesAndPointFeatures())
        self.extraction_transform_valid.transforms.append(ConcatenateCoordinatesAndPointFeatures())

    if config.use_image_information:
        spatial_size = config.image_information_config.spatial_size

        def slice_patches(np_data):
            # crop the spatial dimensions to a centered patch of size spatial_size
            z = (np_data.shape[1] - spatial_size) // 2
            y = (np_data.shape[2] - spatial_size) // 2
            x = (np_data.shape[3] - spatial_size) // 2
            np_data = np_data[:, z:(z + spatial_size), y:(y + spatial_size), x:(x + spatial_size), :]
            return np_data

        self.extraction_transform_train.transforms.append(
            pymia_tfm.LambdaTransform(lambda_fn=slice_patches, entries=image_information_categories))
        self.extraction_transform_valid.transforms.append(
            pymia_tfm.LambdaTransform(lambda_fn=slice_patches, entries=image_information_categories))

    # define loaders
    training_sampler = pymia_extr.SubsetRandomSampler(sampler_ids_train)
    self.loader_train = pymia_extr.DataLoader(self.dataset,
                                              config.batch_size_training,
                                              sampler=training_sampler,
                                              collate_fn=collate_fn,
                                              num_workers=1)

    validation_sampler = pymia_extr.SubsetSequentialSampler(sampler_ids_valid)
    self.loader_valid = pymia_extr.DataLoader(self.dataset,
                                              config.batch_size_testing,
                                              sampler=validation_sampler,
                                              collate_fn=collate_fn,
                                              num_workers=1)

    self.loader_test = pymia_extr.DataLoader(self.dataset,
                                             config.batch_size_testing,
                                             sampler=validation_sampler,
                                             collate_fn=collate_fn,
                                             num_workers=1)

    self.extraction_transform_test = None
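aug.PointCloudShuffler and aug.ShuffledDataExtractor are project-specific. Conceptually, the shuffler holds one permutation per sample so that points and their labels stay aligned when shuffled; a minimal sketch under that assumption:

import numpy as np

class PointCloudShuffler:
    """Hold one permutation per sample; shuffle() draws new ones, e.g. once per epoch."""

    def __init__(self, sizes: dict):
        self.sizes = sizes  # sample index -> number of points
        # shuffled once at instantiation, matching the behavior noted in the comment above
        self.permutations = {idx: np.random.permutation(size) for idx, size in sizes.items()}

    def shuffle(self):
        self.permutations = {idx: np.random.permutation(size) for idx, size in self.sizes.items()}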
def __init__(self, config: cfg.Configuration, subjects_train, subjects_valid, subjects_test,
             collate_fn=pymia_cnv.TorchCollate(('images', 'labels', 'mask_fg', 'mask_t1h2o'))):
    super().__init__()
    indexing_strategy = pymia_extr.SliceIndexing()
    self.dataset = pymia_extr.ParameterizableDataset(config.database_file,
                                                     indexing_strategy,
                                                     pymia_extr.SubjectExtractor(),  # for the usual select_indices
                                                     None)

    self.no_subjects_train = len(subjects_train)
    self.no_subjects_valid = len(subjects_valid)
    self.no_subjects_test = 0

    # get sampler ids by subjects
    sampler_ids_train = pymia_extr.select_indices(self.dataset, pymia_extr.SubjectSelection(subjects_train))
    sampler_ids_valid = pymia_extr.select_indices(self.dataset, pymia_extr.SubjectSelection(subjects_valid))

    # define extractors
    self.extractor_train = pymia_extr.ComposeExtractor([
        pymia_extr.DataExtractor(categories=('images', 'labels')),
        pymia_extr.IndexingExtractor(),  # for SubjectAssembler (assembling)
        pymia_extr.ImageShapeExtractor()  # for SubjectAssembler (shape)
    ])

    self.extractor_valid = pymia_extr.ComposeExtractor([
        pymia_extr.DataExtractor(categories=('images', 'labels')),
        pymia_extr.IndexingExtractor(),  # for SubjectAssembler (assembling)
        pymia_extr.ImageShapeExtractor()  # for SubjectAssembler (shape)
    ])

    self.extractor_test = pymia_extr.ComposeExtractor([
        pymia_extr.SubjectExtractor(),
        pymia_extr.DataExtractor(categories=('labels', )),
        pymia_extr.ImagePropertiesExtractor(),
        pymia_extr.ImageShapeExtractor()
    ])

    # define transforms for extraction
    self.extraction_transform_train = pymia_tfm.ComposeTransform([
        pymia_tfm.SizeCorrection((cfg.TENSOR_WIDTH, cfg.TENSOR_HEIGHT)),
        pymia_tfm.Permute((2, 0, 1)),
        pymia_tfm.Squeeze(entries=('labels', ), squeeze_axis=0),  # for PyTorch loss functions
        pymia_tfm.LambdaTransform(lambda_fn=lambda np_data: np_data.astype(np.int64),
                                  entries=('labels', )),  # for PyTorch loss functions
        pymia_tfm.ToTorchTensor()
    ])

    self.extraction_transform_valid = pymia_tfm.ComposeTransform([
        pymia_tfm.SizeCorrection((cfg.TENSOR_WIDTH, cfg.TENSOR_HEIGHT)),
        pymia_tfm.Permute((2, 0, 1)),
        pymia_tfm.Squeeze(entries=('labels', ), squeeze_axis=0),  # for PyTorch loss functions
        pymia_tfm.LambdaTransform(lambda_fn=lambda np_data: np_data.astype(np.int64),
                                  entries=('labels', )),  # for PyTorch loss functions
        pymia_tfm.ToTorchTensor()
    ])

    self.extraction_transform_test = None

    # define loaders
    training_sampler = pymia_extr.SubsetRandomSampler(sampler_ids_train)
    self.loader_train = pymia_extr.DataLoader(self.dataset,
                                              config.batch_size_training,
                                              sampler=training_sampler,
                                              collate_fn=collate_fn,
                                              num_workers=1)

    validation_sampler = pymia_extr.SubsetSequentialSampler(sampler_ids_valid)
    self.loader_valid = pymia_extr.DataLoader(self.dataset,
                                              config.batch_size_testing,
                                              sampler=validation_sampler,
                                              collate_fn=collate_fn,
                                              num_workers=1)

    self.loader_test = None
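Assuming this __init__ belongs to a data handler class (named SliceDataHandler here purely for illustration), typical usage would look like:

handler = SliceDataHandler(config, subjects_train=['Subject_1', 'Subject_2'],
                           subjects_valid=['Subject_3'], subjects_test=[])
for batch in handler.loader_train:
    images = batch['images']  # torch tensors, channels-first (see Permute and ToTorchTensor above)
    labels = batch['labels']  # int64 and squeezed, as expected by the PyTorch loss functions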
def __init__(self, config: cfg.Configuration, subjects_train, subjects_valid, subjects_test,
             is_subject_selection: bool = True, collate_fn=lib_cnv.TensorFlowCollate(),
             padding_size: tuple = (0, 0, 0)):
    super().__init__()
    indexing_strategy = pymia_extr.PatchWiseIndexing(patch_shape=config.patch_size, ignore_incomplete=False)
    self.dataset = pymia_extr.ParameterizableDataset(config.database_file,
                                                     indexing_strategy,
                                                     pymia_extr.SubjectExtractor(),  # for the select_indices
                                                     None)

    self.no_subjects_train = len(subjects_train)
    self.no_subjects_valid = len(subjects_valid)
    self.no_subjects_test = len(subjects_test)

    if is_subject_selection:
        # get sampler ids by subjects
        sampler_ids_train = pymia_extr.select_indices(self.dataset, pymia_extr.SubjectSelection(subjects_train))
        sampler_ids_valid = pymia_extr.select_indices(self.dataset, pymia_extr.SubjectSelection(subjects_valid))
        sampler_ids_test = pymia_extr.select_indices(self.dataset, pymia_extr.SubjectSelection(subjects_test))
    else:
        # get sampler ids from indices files
        sampler_ids_train, sampler_ids_valid, sampler_ids_test = pkl.load_sampler_ids(config.indices_dir,
                                                                                      pkl.PATCH_WISE_FILE_NAME,
                                                                                      subjects_train,
                                                                                      subjects_valid,
                                                                                      subjects_test)

    # define extractors
    self.extractor_train = pymia_extr.ComposeExtractor([
        pymia_extr.NamesExtractor(),  # required for SelectiveDataExtractor
        pymia_extr.PadDataExtractor(padding=padding_size,
                                    extractor=pymia_extr.DataExtractor(categories=(pymia_def.KEY_IMAGES, ))),
        pymia_extr.PadDataExtractor(padding=(0, 0, 0),
                                    extractor=pymia_extr.SelectiveDataExtractor(selection=config.maps,
                                                                                category=pymia_def.KEY_LABELS)),
        pymia_extr.PadDataExtractor(padding=(0, 0, 0),
                                    extractor=pymia_extr.DataExtractor(categories=(defs.ID_MASK_FG,
                                                                                   defs.ID_MASK_T1H2O))),
        pymia_extr.IndexingExtractor(),
        pymia_extr.ImageShapeExtractor()
    ])

    # to calculate the validation loss, we require the labels and masks during validation
    self.extractor_valid = pymia_extr.ComposeExtractor([
        pymia_extr.NamesExtractor(),  # required for SelectiveDataExtractor
        pymia_extr.PadDataExtractor(padding=padding_size,
                                    extractor=pymia_extr.DataExtractor(categories=(pymia_def.KEY_IMAGES, ))),
        pymia_extr.PadDataExtractor(padding=(0, 0, 0),
                                    extractor=pymia_extr.SelectiveDataExtractor(selection=config.maps,
                                                                                category=pymia_def.KEY_LABELS)),
        pymia_extr.PadDataExtractor(padding=(0, 0, 0),
                                    extractor=pymia_extr.DataExtractor(categories=(defs.ID_MASK_FG,
                                                                                   defs.ID_MASK_T1H2O,
                                                                                   defs.ID_MASK_ROI,
                                                                                   defs.ID_MASK_ROI_T1H2O))),
        pymia_extr.IndexingExtractor(),
        pymia_extr.ImageShapeExtractor()
    ])

    self.extractor_test = pymia_extr.ComposeExtractor([
        pymia_extr.NamesExtractor(),  # required for SelectiveDataExtractor
        pymia_extr.SubjectExtractor(),
        pymia_extr.SelectiveDataExtractor(selection=config.maps, category=pymia_def.KEY_LABELS),
        pymia_extr.DataExtractor(categories=(defs.ID_MASK_FG, defs.ID_MASK_T1H2O,
                                             defs.ID_MASK_ROI, defs.ID_MASK_ROI_T1H2O)),
        pymia_extr.ImagePropertiesExtractor(),
        pymia_extr.ImageShapeExtractor(),
        ext.NormalizationExtractor()
    ])

    # define transforms for extraction
    # after extraction, the first dimension is the batch dimension,
    # e.g., shape = (1, 16, 16, 4) instead of (16, 16, 4) --> therefore squeeze the data
    self.extraction_transform_train = pymia_tfm.Squeeze(entries=(pymia_def.KEY_IMAGES, pymia_def.KEY_LABELS,
                                                                 defs.ID_MASK_FG, defs.ID_MASK_T1H2O),
                                                        squeeze_axis=0)
    self.extraction_transform_valid = pymia_tfm.Squeeze(entries=(pymia_def.KEY_IMAGES, pymia_def.KEY_LABELS,
                                                                 defs.ID_MASK_FG, defs.ID_MASK_T1H2O),
                                                        squeeze_axis=0)
    self.extraction_transform_test = None

    # define loaders
    training_sampler = pymia_extr.SubsetRandomSampler(sampler_ids_train)
    self.loader_train = pymia_extr.DataLoader(self.dataset,
                                              config.batch_size_training,
                                              sampler=training_sampler,
                                              collate_fn=collate_fn,
                                              num_workers=1)

    validation_sampler = pymia_extr.SubsetSequentialSampler(sampler_ids_valid)
    self.loader_valid = pymia_extr.DataLoader(self.dataset,
                                              config.batch_size_testing,
                                              sampler=validation_sampler,
                                              collate_fn=collate_fn,
                                              num_workers=1)

    testing_sampler = pymia_extr.SubsetSequentialSampler(sampler_ids_test)
    self.loader_test = pymia_extr.DataLoader(self.dataset,
                                             config.batch_size_testing,
                                             sampler=testing_sampler,
                                             collate_fn=collate_fn,
                                             num_workers=1)
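The padding_size only enlarges the network input: the images extractor is wrapped with padding=padding_size, while labels and masks use padding=(0, 0, 0) and keep the plain patch size. A worked example with illustrative numbers (assuming the padding is applied on both sides of each spatial dimension):

# assumption: config.patch_size = (32, 32, 32) and padding_size = (4, 4, 4)
# images:        (32 + 2*4, 32 + 2*4, 32 + 2*4, channels) = (40, 40, 40, channels)
# labels, masks: (32, 32, 32, channels)
# the extra border gives the network spatial context, e.g. for valid (unpadded) convolutions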
def main(hdf_file, log_dir):
    # initialize the evaluator with the metrics and the labels to evaluate
    metrics = [metric.DiceCoefficient()]
    labels = {1: 'WHITEMATTER',
              2: 'GREYMATTER',
              3: 'HIPPOCAMPUS',
              4: 'AMYGDALA',
              5: 'THALAMUS'
              }
    evaluator = eval_.SegmentationEvaluator(metrics, labels)

    # we want to log the mean and standard deviation of the metrics among all subjects of the dataset
    functions = {'MEAN': np.mean, 'STD': np.std}
    statistics_aggregator = writer.StatisticsAggregator(functions=functions)
    console_writer = writer.ConsoleStatisticsWriter(functions=functions)

    # initialize TensorBoard writer
    summary_writer = tf.summary.create_file_writer(os.path.join(log_dir, 'logging-example-tensorflow'))

    # set up the training datasource
    train_subjects, valid_subjects = ['Subject_1', 'Subject_2', 'Subject_3'], ['Subject_4']
    extractor = extr.DataExtractor(categories=(defs.KEY_IMAGES, defs.KEY_LABELS))
    indexing_strategy = extr.SliceIndexing()

    augmentation_transforms = [augm.RandomElasticDeformation(), augm.RandomMirror()]
    transforms = [tfm.Squeeze(entries=(defs.KEY_LABELS, ))]
    train_transforms = tfm.ComposeTransform(augmentation_transforms + transforms)

    train_dataset = extr.PymiaDatasource(hdf_file, indexing_strategy, extractor, train_transforms,
                                         subject_subset=train_subjects)

    # set up the validation datasource
    batch_size = 16
    valid_transforms = tfm.ComposeTransform([])
    valid_dataset = extr.PymiaDatasource(hdf_file, indexing_strategy, extractor, valid_transforms,
                                         subject_subset=valid_subjects)

    direct_extractor = extr.ComposeExtractor([
        extr.SubjectExtractor(),
        extr.ImagePropertiesExtractor(),
        extr.DataExtractor(categories=(defs.KEY_LABELS, ))
    ])
    assembler = assm.SubjectAssembler(valid_dataset)

    # tensorflow-specific handling
    train_gen_fn = pymia_tf.get_tf_generator(train_dataset)
    tf_train_dataset = tf.data.Dataset.from_generator(generator=train_gen_fn,
                                                      output_types={defs.KEY_IMAGES: tf.float32,
                                                                    defs.KEY_LABELS: tf.int64,
                                                                    defs.KEY_SAMPLE_INDEX: tf.int64})
    # shuffle before batching so that individual samples, not whole batches, get mixed
    tf_train_dataset = tf_train_dataset.shuffle(len(train_dataset)).batch(batch_size)

    valid_gen_fn = pymia_tf.get_tf_generator(valid_dataset)
    tf_valid_dataset = tf.data.Dataset.from_generator(generator=valid_gen_fn,
                                                      output_types={defs.KEY_IMAGES: tf.float32,
                                                                    defs.KEY_LABELS: tf.int64,
                                                                    defs.KEY_SAMPLE_INDEX: tf.int64})
    tf_valid_dataset = tf_valid_dataset.batch(batch_size)

    u_net = unet.build_model(channels=2, num_classes=6, layer_depth=3, filters_root=16)

    optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)
    train_loss = tf.keras.metrics.Mean('train_loss', dtype=tf.float32)

    train_batches = len(train_dataset) // batch_size

    # looping over the data in the dataset
    epochs = 100
    for epoch in range(epochs):
        print(f'Epoch {epoch + 1}/{epochs}')

        # training
        print('training')
        for i, batch in enumerate(tf_train_dataset):
            x, y = batch[defs.KEY_IMAGES], batch[defs.KEY_LABELS]

            with tf.GradientTape() as tape:
                logits = u_net(x, training=True)
                loss = tf.keras.losses.sparse_categorical_crossentropy(y, logits, from_logits=True)

            grads = tape.gradient(loss, u_net.trainable_variables)
            optimizer.apply_gradients(zip(grads, u_net.trainable_variables))

            train_loss(loss)

            with summary_writer.as_default():
                tf.summary.scalar('train/loss', train_loss.result(), step=epoch * train_batches + i)
            print(f'[{i + 1}/{train_batches}]\tloss: {train_loss.result().numpy()}')

        # validation
        print('validation')
        valid_batches = len(valid_dataset) // batch_size
        for i, batch in enumerate(tf_valid_dataset):
            x, sample_indices = batch[defs.KEY_IMAGES], batch[defs.KEY_SAMPLE_INDEX]

            logits = u_net(x)
            prediction = tf.expand_dims(tf.math.argmax(logits, -1), -1)
            numpy_prediction = prediction.numpy()

            is_last = i == valid_batches - 1
            assembler.add_batch(numpy_prediction, sample_indices.numpy(), is_last)

        for subject_index in assembler.subjects_ready:
            subject_prediction = assembler.get_assembled_subject(subject_index)

            # extract the target and image properties from the validation datasource
            # (was train_dataset, which would look up the wrong subjects)
            direct_sample = valid_dataset.direct_extract(direct_extractor, subject_index)
            target, image_properties = direct_sample[defs.KEY_LABELS], direct_sample[defs.KEY_PROPERTIES]

            # evaluate the prediction against the reference
            evaluator.evaluate(subject_prediction[..., 0], target[..., 0], direct_sample[defs.KEY_SUBJECT])

            # calculate mean and standard deviation of each metric
            results = statistics_aggregator.calculate(evaluator.results)
            # log to TensorBoard into category valid
            with summary_writer.as_default():
                for result in results:
                    tf.summary.scalar(f'valid/{result.metric}-{result.id_}', result.value, epoch)

            console_writer.write(evaluator.results)

            # clear results such that the evaluator is ready for the next evaluation
            evaluator.clear()
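As for the PyTorch variant, a minimal command-line entry point can drive this example (a sketch; the argument defaults are illustrative):

import argparse

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='TensorFlow logging example.')
    parser.add_argument('--hdf_file', type=str, default='data/example-dataset.h5',
                        help='Path to the pymia dataset file.')
    parser.add_argument('--log_dir', type=str, default='logs',
                        help='Directory for the TensorBoard summaries.')
    args = parser.parse_args()
    main(args.hdf_file, args.log_dir)

The logged curves can then be inspected with `tensorboard --logdir logs`.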