def get_all_metrics(self) -> typing.Tuple[np.ndarray, list]:
    """Collect the morphometric metrics of every subject in the dataset.

    :return: tuple ``(metrics, column_names)`` where ``metrics`` stacks one
        row of morphometric values per subject and ``column_names`` gives
        the metric name for each column.
    """
    per_subject = []
    for subject_idx in range(len(self.dataset)):
        # a fresh extractor per subject, matching the dataset's extraction API
        extracted = self.dataset.direct_extract(
            miapy_extr.SelectiveDataExtractor(category=STORE_MORPHOMETRICS),
            subject_idx)
        per_subject.append(extracted)

    column_names = self.dataset.reader.read(STORE_META_MORPHOMETRY_COLUMNS)

    collated = self.collate_batch(per_subject)[STORE_MORPHOMETRICS]
    return np.stack(collated, axis=0), column_names
def train(self):
    """Train the regression network described by ``self._cfg``.

    Builds a TF1 graph (placeholders, model, loss, optimizer, summaries),
    then runs the epoch loop with periodic validation loss, R2 evaluation,
    layer/kernel summaries and checkpointing. Resumes from the latest
    checkpoint in ``self.checkpoint_dir`` when one exists. After training,
    the best-R2 checkpoint (if any) is restored and final train/test
    result CSVs and metrics are written.
    """
    self.write_status('status_init')
    start_time = timeit.default_timer()  # wall clock, used for max_timelimit check
    self.set_seed(epoch=0)
    transform = self.get_transform()
    self._data_store = data_storage.DataStore(self._cfg.hdf_file, transform)
    dataset = self._data_store.dataset
    self.assign_subjects()

    # prepare loaders and extractors
    training_loader = self._data_store.get_loader(self._cfg.batch_size,
                                                  self._subjects_train,
                                                  self._num_workers)
    validation_loader = self._data_store.get_loader(
        self._cfg.batch_size_eval, self._subjects_validate,
        self._num_workers)
    testing_loader = self._data_store.get_loader(self._cfg.batch_size_eval,
                                                 self._subjects_test,
                                                 self._num_workers)

    train_extractor = miapy_extr.ComposeExtractor([
        miapy_extr.DataExtractor(),
        miapy_extr.SelectiveDataExtractor(
            category=data_storage.STORE_MORPHOMETRICS),
        miapy_extr.SubjectExtractor(),
        data_storage.DemographicsExtractor()
    ])
    dataset.set_extractor(train_extractor)

    # read all labels to calculate multiplier: per regression column the
    # max(|value|) over all subjects, used to scale targets
    column_values, column_names = self._data_store.get_all_metrics()
    self._regression_column_ids = np.array([
        column_names.index(name) for name in self._cfg.regression_columns
    ])
    self._regression_column_multipliers = np.max(np.abs(
        column_values[:, self._regression_column_ids]), axis=0)

    model_net.SCALE = float(self._data_store.get_intensity_scale_max())
    n_batches = int(
        np.ceil(len(self._subjects_train) / self._cfg.batch_size))

    logger.info('Net: {}, scale: {}'.format(
        inspect.getsource(self.get_python_obj(self._cfg.model)),
        model_net.SCALE))
    logger.info('Train: {}, Validation: {}, Test: {}'.format(
        len(self._subjects_train), len(self._subjects_validate),
        len(self._subjects_test)))
    logger.info('Label multiplier: {}'.format(
        self._regression_column_multipliers))
    logger.info('n_batches: {}'.format(n_batches))
    logger.info(self._cfg)
    logger.info('checkpoints dir: {}'.format(self.checkpoint_dir))

    sample = dataset.direct_extract(train_extractor, 0)  # extract a subject to obtain shape

    with tf.Graph().as_default() as graph:
        self.set_seed(epoch=0)  # set again as seed is per graph

        # placeholders: image volume, regression targets, demographics, train flag
        x = tf.placeholder(tf.float32,
                           (None, ) + sample[data_storage.STORE_IMAGES].shape[0:],
                           name='x')
        y = tf.placeholder(tf.float32,
                           (None, len(self._regression_column_ids)),
                           name='y')
        d = tf.placeholder(tf.float32, (None, 2), name='d')  # age, sex
        is_train = tf.placeholder(tf.bool, shape=(), name='is_train')

        global_step = tf.train.get_or_create_global_step()
        # persisted in the checkpoint so training can resume where it stopped
        epoch_checkpoint = tf.Variable(0, name='epoch')
        best_r2_score_checkpoint = tf.Variable(0.0,
                                               name='best_r2_score',
                                               dtype=tf.float64)

        net = self.get_python_obj(self._cfg.model)({
            'x': x,
            'y': y,
            'd': d,
            'is_train': is_train
        })

        optimizer = None
        loss = None
        if self._cfg.loss_function == 'mse':
            loss = tf.losses.mean_squared_error(labels=y, predictions=net)
        elif self._cfg.loss_function == 'absdiff':
            loss = tf.losses.absolute_difference(labels=y, predictions=net)

        if self._cfg.learning_rate_decay_rate is not None and self._cfg.learning_rate_decay_rate > 0:
            learning_rate = tf.train.exponential_decay(
                self._cfg.learning_rate, global_step,
                self._cfg.learning_rate_decay_steps,
                self._cfg.learning_rate_decay_rate)
        else:
            learning_rate = tf.Variable(self._cfg.learning_rate, name='lr')

        if self._cfg.optimizer == 'SGD':
            optimizer = tf.train.GradientDescentOptimizer(
                learning_rate=learning_rate)
        elif self._cfg.optimizer == 'Adam':
            # NOTE(review): Adam is fed the raw config learning rate, not the
            # (possibly decayed) learning_rate tensor above — confirm intended.
            optimizer = tf.train.AdamOptimizer(
                learning_rate=self._cfg.learning_rate)

        with tf.control_dependencies(
                tf.get_collection(
                    tf.GraphKeys.UPDATE_OPS)):  # required for batch_norm
            train_op = optimizer.minimize(loss=loss, global_step=global_step)

        sum_lr = tf.summary.scalar('learning_rate', learning_rate)

        # collect variables to add to summaries (histograms, kernel weight images)
        # NOTE(review): regex is a non-raw string; '\d'/'\w' rely on Python
        # passing unknown escapes through — consider r'' literals.
        sum_histograms = []
        sum_kernels = []
        kernel_tensors = []
        kernel_tensor_names = []
        for v in tf.global_variables():
            m = re.match('NET/(layer\d+_\w+)/(kernel|bias):0', v.name)
            if m:
                var = graph.get_tensor_by_name(v.name)
                sum_histograms.append(
                    tf.summary.histogram(
                        '{}/{}'.format(m.group(1), m.group(2)), var))
                #sum_histograms.append(tf.summary.histogram('{}/{}'.format(m.group(1), m.group(2)), tf.Variable(tf.zeros([1,1,1,1]))))
                if m.group(2) == 'kernel' and m.group(1).endswith('conv'):
                    # image summary variable sized to hold the kernel grid
                    kernel_tensor_names.append(v.name)
                    h, w = visualize.get_grid_size(
                        var.get_shape().as_list())
                    img = tf.Variable(tf.zeros([1, h, w, 1]))
                    kernel_tensors.append(img)
                    sum_kernels.append(tf.summary.image(m.group(1), img))

        summary_writer = tf.summary.FileWriter(
            os.path.join(self.checkpoint_dir, 'tb_logs'),
            tf.get_default_graph())
        init = tf.global_variables_initializer()

        # Saver keeps only 5 per default - make sure best-r2-checkpoint remains!
        # TODO: maybe set max_to_keep=None?
        saver = tf.train.Saver(max_to_keep=self._cfg.checkpoint_keep + 2)

        with tf.Session() as sess:
            sess.run(init)

            checkpoint = tf.train.latest_checkpoint(self.checkpoint_dir)
            if checkpoint:
                # existing checkpoint found, restoring...
                if not self._cfg.subjects_train_val_test_file:
                    # resuming with re-randomized subject splits would leak
                    # validation/test subjects into training
                    msg = 'Continue training, but no fixed subject assignments found. ' \
                          'Set subjects_train_val_test_file in config.'
                    logger.error(msg)
                    raise RuntimeError(msg)
                logger.info('Restoring from ' + checkpoint)
                saver = tf.train.import_meta_graph(checkpoint + '.meta')
                saver.restore(sess, checkpoint)
                # sanity check: restored saver must keep enough checkpoints
                if saver._max_to_keep < self._cfg.checkpoint_keep + 2:
                    msg = 'ERROR: Restored saver._max_to_keep={}, but self._cfg.checkpoint_keep={}'.format(
                        saver._max_to_keep, self._cfg.checkpoint_keep)
                    print(msg)
                    logger.error(msg)
                    exit(1)
                # continue with the epoch after the one stored in the checkpoint
                sess.run(epoch_checkpoint.assign_add(tf.constant(1)))
                self._best_r2_score = best_r2_score_checkpoint.eval()
                logger.info('Continue with epoch %i (best r2: %f)',
                            epoch_checkpoint.eval(), self._best_r2_score)
                self._checkpoint_idx = int(
                    re.match('.*/checkpoint-(\d+).ckpt', checkpoint).group(1))

                # load column multipliers from file if available (keeps target
                # scaling consistent with the original training run)
                cfg_file = os.path.join(self.checkpoint_dir, 'config.json')
                if os.path.exists(cfg_file):
                    cfg_tmp = Configuration.load(cfg_file)
                    if len(cfg_tmp.z_column_multipliers) > 0:
                        logger.info('Loading column multipliers from %s',
                                    cfg_file)
                        self._regression_column_multipliers = np.array(
                            cfg_tmp.z_column_multipliers)
                        logger.info('Label multiplier: {}'.format(
                            self._regression_column_multipliers))
            else:
                # new training, write config to checkpoint dir
                self.write_settings()

            summary_writer.add_session_log(
                tf.SessionLog(status=tf.SessionLog.START),
                global_step=global_step.eval())

            for epoch in range(epoch_checkpoint.eval(),
                               self._cfg.epochs + 1):
                self.set_seed(epoch=epoch)
                self.write_status('status_epoch', str(epoch))
                epoch_start_time = timeit.default_timer()
                loss_sum = 0
                r2_validation = None  # only computed every log_eval_num_epoch epochs

                # training (enable data augmentation)
                self._data_store.set_transforms_enabled(True)
                for batch in training_loader:
                    print('.', end='', flush=True)  # progress indicator per batch
                    feed_dict = self.batch_to_feed_dict(
                        x, y, d, is_train, batch, True)
                    # perform training step
                    _, loss_step = sess.run([train_op, loss],
                                            feed_dict=feed_dict)
                    # weight the per-batch loss by batch size for the epoch mean
                    loss_sum = loss_sum + loss_step * len(batch)

                # disable transformations (data augmentation) for validation
                self._data_store.set_transforms_enabled(False)

                # loss on training data
                cost_train = loss_sum / len(self._subjects_train)

                if epoch % self._cfg.log_num_epoch == 0:
                    # loss on validation set
                    cost_validation = self.evaluate_loss(
                        sess, loss, validation_loader, x, y, d, is_train)
                    cost_validation_str = '{:.16f}'.format(cost_validation)
                else:
                    print()  # newline after the '.' progress dots
                    cost_validation = None
                    cost_validation_str = '-'

                logger.info(
                    'Epoch:{:4.0f}, Loss train: {:.16f}, Loss validation: {}, lr: {:.16f}, dt={:.1f}s'
                    .format(epoch, cost_train, cost_validation_str,
                            learning_rate.eval(),
                            timeit.default_timer() - epoch_start_time))

                # don't write loss for first epoch (usually very high) to avoid scaling issue in graph
                if epoch > 0:
                    # write summary
                    summary = tf.Summary()
                    summary.value.add(tag='loss_train',
                                      simple_value=cost_train)
                    if cost_validation:
                        summary.value.add(tag='loss_validation',
                                          simple_value=cost_validation)
                    summary_writer.add_summary(summary, epoch)

                    summary_op = tf.summary.merge([sum_lr])
                    summary_str = sess.run(summary_op)
                    summary_writer.add_summary(summary_str, epoch)

                if epoch % self._cfg.log_eval_num_epoch == 0 and epoch > 0:
                    # calculate and log R2 score on training and validation set
                    eval_start_time = timeit.default_timer()
                    predictions_train, gt_train, subjects_train = self.predict(
                        sess, net, training_loader, x, y, d, is_train)
                    predictions_validation, gt_validation, subjects_validation = self.predict(
                        sess, net, validation_loader, x, y, d, is_train)
                    r2_train = metrics.r2_score(gt_train, predictions_train)
                    r2_validation = metrics.r2_score(
                        gt_validation, predictions_validation)
                    logger.info(
                        'Epoch:{:4.0f}, R2 train: {:.3f}, R2 validation: {:.8f}, dt={:.1f}s'
                        .format(epoch, r2_train, r2_validation,
                                timeit.default_timer() - eval_start_time))

                    # write csv with intermediate results
                    self.write_results_csv(
                        'results_train-{0:04.0f}.csv'.format(epoch),
                        predictions_train, gt_train, subjects_train)
                    self.write_results_csv(
                        'results_validate-{0:04.0f}.csv'.format(epoch),
                        predictions_validation, gt_validation,
                        subjects_validation)

                    summary = tf.Summary()
                    # average r2
                    summary.value.add(tag='r2_train',
                                      simple_value=r2_train)
                    summary.value.add(tag='r2_validation',
                                      simple_value=r2_validation)
                    # add r2 per metric
                    for idx, col_name in enumerate(
                            self._cfg.regression_columns):
                        summary.value.add(
                            tag='train/r2_{}'.format(col_name),
                            simple_value=metrics.r2_score(
                                gt_train[:, idx],
                                predictions_train[:, idx],
                                multioutput='raw_values'))
                        summary.value.add(
                            tag='validation/r2_{}'.format(col_name),
                            simple_value=metrics.r2_score(
                                gt_validation[:, idx],
                                predictions_validation[:, idx],
                                multioutput='raw_values'))
                    summary_writer.add_summary(summary, epoch)

                if epoch % self._cfg.visualize_layer_num_epoch == 0 and len(
                        sum_histograms) > 0:
                    # write histogram summaries and kernel visualization
                    summary_op = tf.summary.merge(sum_histograms)
                    summary_str = sess.run(summary_op)
                    summary_writer.add_summary(summary_str, epoch)

                    if len(kernel_tensor_names) > 0:
                        for idx, kernel_name in enumerate(
                                kernel_tensor_names):
                            # visualize weights of kernel layer from a middle slice
                            kernel_weights = graph.get_tensor_by_name(
                                kernel_name).eval()
                            # make last axis the first
                            kernel_weights = np.moveaxis(
                                kernel_weights, -1, 0)
                            if len(kernel_weights.shape) > 4:
                                # 3d convolution, remove last (single) channel
                                kernel_weights = kernel_weights[:, :, :, :, 0]
                            if kernel_weights.shape[3] > 1:
                                # multiple channels, take example from middle slide
                                slice_num = int(kernel_weights.shape[3] / 2)
                                kernel_weights = kernel_weights[:, :, :,
                                                                slice_num:
                                                                slice_num + 1]
                            grid = visualize.make_grid(kernel_weights)[
                                np.newaxis, :, :, np.newaxis]
                            sess.run(kernel_tensors[idx].assign(grid))
                        summary_op = tf.summary.merge(sum_kernels)
                        summary_str = sess.run(summary_op)
                        summary_writer.add_summary(summary_str, epoch)

                summary_writer.flush()

                if self._cfg.max_timelimit > 0 and (
                        timeit.default_timer() - start_time >
                        self._cfg.max_timelimit):
                    logger.info(
                        'Timelimit {}s exceeded. Stopping training...'.
                        format(self._cfg.max_timelimit))
                    # force a checkpoint before bailing out
                    self.checkpoint_safer(sess, saver, epoch_checkpoint,
                                          epoch, best_r2_score_checkpoint,
                                          True, r2_validation)
                    self.write_status('status_timeout')
                    break
                else:
                    # epoch done
                    self.checkpoint_safer(sess, saver, epoch_checkpoint,
                                          epoch, best_r2_score_checkpoint,
                                          False, r2_validation)

            summary_writer.close()
            logger.info('Training done.')
            self.write_status('status_done')

            if self._best_r2_score > 0:
                # restore checkpoint of best R2 score
                checkpoint = os.path.join(self.checkpoint_dir,
                                          'checkpoint-best-r-2.ckpt')
                saver = tf.train.import_meta_graph(checkpoint + '.meta')
                saver.restore(sess, checkpoint)
                logger.info(
                    'RESTORED best-r-2 checkpoint. Epoch: {}, R2: {:.8f}'.
                    format(epoch_checkpoint.eval(),
                           best_r2_score_checkpoint.eval()))

            # disable transformations (data augmentation) for test
            self._data_store.set_transforms_enabled(False)
            predictions_train, gt_train, subjects_train = self.predict(
                sess, net, training_loader, x, y, d, is_train)
            predictions_test, gt_test, subjects_test = self.predict(
                sess, net, testing_loader, x, y, d, is_train)
            self.write_results_csv('results_train.csv', predictions_train,
                                   gt_train, subjects_train)
            self.write_results_csv('results_test.csv', predictions_test,
                                   gt_test, subjects_test)

            # Note: use scaled metrics for MSE and unscaled (original) for R^2
            if len(gt_train) > 0:
                accuracy_train = metrics.mean_squared_error(
                    gt_train / self._regression_column_multipliers,
                    predictions_train / self._regression_column_multipliers)
                r2_train = metrics.r2_score(gt_train, predictions_train)
            else:
                accuracy_train = 0
                r2_train = 0

            if len(subjects_test) > 0:
                accuracy_test = metrics.mean_squared_error(
                    gt_test / self._regression_column_multipliers,
                    predictions_test / self._regression_column_multipliers)
                r2_test = metrics.r2_score(gt_test, predictions_test)
                s, _ = analyze_score.print_summary(
                    subjects_test, self._cfg.regression_columns,
                    predictions_test, gt_test)
                logger.info('Summary:\n%s-------', s)
                logger.info('TRAIN accuracy(mse): {:.8f}, r2: {:.8f}'.format(
                    accuracy_train, r2_train))
                logger.info('TEST accuracy(mse): {:.8f}, r2: {:.8f}'.format(
                    accuracy_test, r2_test))

            visualize.make_kernel_gif(self.checkpoint_dir,
                                      kernel_tensor_names)
def main(hdf_file: str):
    """Verify slice-wise HDF5 extraction against the original subject files.

    For every slice sample in the dataset, re-reads the source images,
    labels, mask, numerical and sex files from disk, re-applies the same
    normalization/slicing, and raises ``ValueError`` on any mismatch with
    the values stored in the HDF5 file.

    :param hdf_file: path to the HDF5 dataset to verify.
    """
    extractor = extr.ComposeExtractor([
        extr.NamesExtractor(),
        extr.DataExtractor(),
        extr.SelectiveDataExtractor(),
        extr.DataExtractor(('numerical', ), ignore_indexing=True),
        extr.DataExtractor(('sex', ), ignore_indexing=True),
        extr.DataExtractor(('mask', ), ignore_indexing=False),
        extr.SubjectExtractor(),
        extr.FilesExtractor(categories=('images', 'labels', 'mask',
                                        'numerical', 'sex')),
        extr.IndexingExtractor(),
        extr.ImagePropertiesExtractor()
    ])
    dataset = extr.ParameterizableDataset(hdf_file, extr.SliceIndexing(),
                                          extractor)

    for i in range(len(dataset)):
        item = dataset[i]

        index_expr = item['index_expr']  # type: miapy_data.IndexExpression
        root = item['file_root']

        image = None  # type: sitk.Image
        # NOTE(review): inner loop reuses the name `i` (channel index),
        # shadowing the outer sample index — harmless here but fragile.
        for i, file in enumerate(item['images_files']):
            image = sitk.ReadImage(os.path.join(root, file))
            np_img = sitk.GetArrayFromImage(image).astype(np.float32)
            # same z-normalization as applied when building the dataset
            # (assumed; verify against the dataset creation code)
            np_img = (np_img - np_img.mean()) / np_img.std()
            np_slice = np_img[index_expr.expression]
            if (np_slice != item['images'][..., i]).any():
                raise ValueError('slice not equal')

        # for any image (properties are identical across an item's images)
        image_properties = conv.ImageProperties(image)
        if image_properties != item['properties']:
            raise ValueError('image properties not equal')

        for file in item['labels_files']:
            image = sitk.ReadImage(os.path.join(root, file))
            np_img = sitk.GetArrayFromImage(image)
            np_img = np.expand_dims(
                np_img, axis=-1
            )  # due to the convention of having the last dim as number of channels
            np_slice = np_img[index_expr.expression]
            if (np_slice != item['labels']).any():
                raise ValueError('slice not equal')

        for file in item['mask_files']:
            image = sitk.ReadImage(os.path.join(root, file))
            np_img = sitk.GetArrayFromImage(image)
            np_img = np.expand_dims(
                np_img, axis=-1
            )  # due to the convention of having the last dim as number of channels
            np_slice = np_img[index_expr.expression]
            if (np_slice != item['mask']).any():
                raise ValueError('slice not equal')

        for file in item['numerical_files']:
            # text file layout: 'AGE: <float>' on line 1, 'GPA: <float>' on
            # line 2 (assumed from parsing; confirm against file format)
            with open(os.path.join(root, file), 'r') as f:
                lines = f.readlines()
            age = float(lines[0].split(':')[1].strip())
            gpa = float(lines[1].split(':')[1].strip())
            if age != item['numerical'][0][0] or gpa != item['numerical'][0][1]:
                raise ValueError('value not equal')

        for file in item['sex_files']:
            # sex is stored on line 3 of the demographics file
            with open(os.path.join(root, file), 'r') as f:
                sex = f.readlines()[2].split(':')[1].strip()
            if sex != str(item['sex'][0]):
                raise ValueError('value not equal')

    print('All test passed!')
def main(config_file: str):
    """Run the pymia/miapy example pipeline: train/test loop with assembling.

    Loads the configuration, builds slice-wise extractors and data loaders
    for a fixed train/test subject split, iterates the (placeholder)
    training and testing loops, assembles slice predictions back into
    subject volumes and evaluates them.

    :param config_file: path to the JSON/YAML configuration file.
    """
    config = cfg.load(config_file, cfg.Configuration)
    print(config)

    indexing_strategy = miapy_extr.SliceIndexing()  # slice-wise extraction
    extraction_transform = None  # we do not want to apply any transformation on the slices after extraction

    # define an extractor for training, i.e. what information we would like to extract per sample
    train_extractor = miapy_extr.ComposeExtractor([
        miapy_extr.NamesExtractor(),
        miapy_extr.DataExtractor(),
        miapy_extr.SelectiveDataExtractor()
    ])

    # define an extractor for testing, i.e. what information we would like to extract per sample
    # note that usually we don't use labels for testing, i.e. the SelectiveDataExtractor is only used for this example
    test_extractor = miapy_extr.ComposeExtractor([
        miapy_extr.NamesExtractor(),
        miapy_extr.IndexingExtractor(),
        miapy_extr.DataExtractor(),
        miapy_extr.SelectiveDataExtractor(),
        miapy_extr.ImageShapeExtractor()
    ])

    # define an extractor for evaluation, i.e. what information we would like to extract per sample
    eval_extractor = miapy_extr.ComposeExtractor([
        miapy_extr.NamesExtractor(),
        miapy_extr.SubjectExtractor(),
        miapy_extr.SelectiveDataExtractor(),
        miapy_extr.ImagePropertiesExtractor()
    ])

    # define the data set
    dataset = miapy_extr.ParameterizableDataset(
        config.database_file,
        indexing_strategy,
        miapy_extr.SubjectExtractor(),  # for select_indices() below
        extraction_transform)

    # generate train / test split for data set
    # we use Subject_0, Subject_1 and Subject_2 for training and Subject_3 for testing
    sampler_ids_train = miapy_extr.select_indices(
        dataset,
        miapy_extr.SubjectSelection(('Subject_0', 'Subject_1', 'Subject_2')))
    # FIX: ('Subject_3') is a plain string, not a tuple — the trailing comma
    # makes it a real one-element tuple, matching the training selection above.
    sampler_ids_test = miapy_extr.select_indices(
        dataset, miapy_extr.SubjectSelection(('Subject_3', )))

    # set up training data loader
    training_sampler = miapy_extr.SubsetRandomSampler(sampler_ids_train)
    training_loader = miapy_extr.DataLoader(dataset,
                                            config.batch_size_training,
                                            sampler=training_sampler,
                                            collate_fn=collate_batch,
                                            num_workers=1)

    # set up testing data loader
    testing_sampler = miapy_extr.SubsetSequentialSampler(sampler_ids_test)
    testing_loader = miapy_extr.DataLoader(dataset,
                                           config.batch_size_testing,
                                           sampler=testing_sampler,
                                           collate_fn=collate_batch,
                                           num_workers=1)

    sample = dataset.direct_extract(train_extractor, 0)  # extract a subject

    evaluator = init_evaluator()  # initialize evaluator

    for epoch in range(config.epochs):  # epochs loop
        dataset.set_extractor(train_extractor)
        for batch in training_loader:  # batches for training
            # feed_dict = batch_to_feed_dict(x, y, batch, True)  # e.g. for TensorFlow
            # train model, e.g.:
            # sess.run([train_op, loss], feed_dict=feed_dict)
            pass

        # subject assembler for testing
        subject_assembler = miapy_asmbl.SubjectAssembler()

        dataset.set_extractor(test_extractor)
        for batch in testing_loader:  # batches for testing
            # feed_dict = batch_to_feed_dict(x, y, batch, False)  # e.g. for TensorFlow
            # test model, e.g.:
            # prediction = sess.run(y_model, feed_dict=feed_dict)
            prediction = np.stack(
                batch['labels'], axis=0
            )  # we use the labels as predictions such that we can validate the assembler
            subject_assembler.add_batch(prediction, batch)

        # evaluate all test images
        for subject_idx in list(subject_assembler.predictions.keys()):
            # convert prediction and labels back to SimpleITK images
            sample = dataset.direct_extract(eval_extractor, subject_idx)
            label_image = miapy_conv.NumpySimpleITKImageBridge.convert(
                sample['labels'], sample['properties'])

            assembled = subject_assembler.get_assembled_subject(
                sample['subject_index'])
            prediction_image = miapy_conv.NumpySimpleITKImageBridge.convert(
                assembled, sample['properties'])
            evaluator.evaluate(prediction_image, label_image,
                               sample['subject'])  # evaluate prediction
def main_predict(cfg_file, checkpoint_file, hdf_file, subjects_file, out_csv):
    """Run inference with a trained checkpoint and write a result CSV.

    :param cfg_file: path to the training configuration (JSON).
    :param checkpoint_file: TF checkpoint prefix (``.meta`` must exist).
    :param hdf_file: HDF5 dataset; falls back to ``cfg.hdf_file`` if falsy.
    :param subjects_file: optional text file with one subject id per line;
        when falsy, all subjects in the dataset are predicted.
    :param out_csv: output CSV file name for the predictions.
    """
    cfg = Configuration.load(cfg_file)
    trainer = training.Trainer(cfg, 0)
    data_store = data_storage.DataStore(hdf_file if hdf_file else cfg.hdf_file)

    if subjects_file:
        with open(subjects_file, 'r') as file:
            subjects = [s.rstrip() for s in file.readlines()]
    else:
        subjects = [s['subject'] for s in data_store.dataset]

    validation_loader = data_store.get_loader(cfg.batch_size_eval, subjects, 0)
    validation_extractor = miapy_extr.ComposeExtractor([
        miapy_extr.DataExtractor(),
        miapy_extr.SelectiveDataExtractor(
            category=data_storage.STORE_MORPHOMETRICS),
        miapy_extr.SubjectExtractor(),
        data_storage.DemographicsExtractor()
    ])
    data_store.dataset.set_extractor(validation_extractor)

    # restore the column mapping / scaling the model was trained with
    column_values, column_names = data_store.get_all_metrics()
    trainer._regression_column_ids = np.array(
        [column_names.index(name) for name in cfg.regression_columns])
    trainer._regression_column_multipliers = np.array(cfg.z_column_multipliers)

    with tf.Graph().as_default() as graph:
        init = tf.global_variables_initializer()
        with tf.Session() as sess:
            sess.run(init)
            print('Using checkpoint {}'.format(checkpoint_file))
            saver = tf.train.import_meta_graph(checkpoint_file + '.meta')
            saver.restore(sess, checkpoint_file)

            # recover the graph endpoints by their fixed names
            net = graph.get_tensor_by_name('NET/model:0')
            x_placeholder = graph.get_tensor_by_name('x:0')
            y_placeholder = graph.get_tensor_by_name('y:0')
            d_placeholder = graph.get_tensor_by_name('d:0')
            is_train_placeholder = graph.get_tensor_by_name('is_train:0')
            epoch_checkpoint = graph.get_tensor_by_name('epoch:0')
            print('Epoch from checkpoint: {}'.format(epoch_checkpoint.eval()))

            predictions, gt, pred_subjects = trainer.predict(
                sess, net, validation_loader, x_placeholder, y_placeholder,
                d_placeholder, is_train_placeholder)

            trainer.write_results_csv(out_csv, predictions, gt, pred_subjects)

            if len(pred_subjects) > 1:
                # FIX: pass pred_subjects (aligned with predictions/gt from
                # trainer.predict) instead of the requested `subjects` list,
                # which may differ in length — see the warning below.
                s, _ = analyze_score.print_summary(pred_subjects,
                                                   cfg.regression_columns,
                                                   predictions, gt)
                print(s)

            if len(pred_subjects) != len(subjects):
                print(
                    "WARN: Number of subjects in predictions ({}) != given ({})"
                    .format(len(pred_subjects), len(subjects)))