Example 1
    def get_all_metrics(self) -> typing.Tuple[np.ndarray, list]:
        """
        Get the metrics from all subjects and the corresponding column names
        
        :return: (metrics, column_names)
        """

        metrics = [self.dataset.direct_extract(
            miapy_extr.SelectiveDataExtractor(category=STORE_MORPHOMETRICS), idx) for idx in range(len(self.dataset))]
        column_names = self.dataset.reader.read(STORE_META_MORPHOMETRY_COLUMNS)

        return np.stack(self.collate_batch(metrics)[STORE_MORPHOMETRICS], axis=0), column_names
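For context, a minimal sketch of how the returned (metrics, column_names) pair is typically consumed (mirroring Examples 2 and 5); the data_store instance and the regression_columns list are assumptions, not part of this snippet:

import numpy as np

column_values, column_names = data_store.get_all_metrics()
# map the configured target names to column indices
column_ids = np.array([column_names.index(name) for name in regression_columns])
# per-column max-abs values, later used to normalize the regression targets
column_multipliers = np.max(np.abs(column_values[:, column_ids]), axis=0)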
Example 2
    def train(self):
        self.write_status('status_init')
        start_time = timeit.default_timer()
        self.set_seed(epoch=0)

        transform = self.get_transform()
        self._data_store = data_storage.DataStore(self._cfg.hdf_file,
                                                  transform)
        dataset = self._data_store.dataset

        self.assign_subjects()

        # prepare loaders and extractors
        training_loader = self._data_store.get_loader(self._cfg.batch_size,
                                                      self._subjects_train,
                                                      self._num_workers)
        validation_loader = self._data_store.get_loader(
            self._cfg.batch_size_eval, self._subjects_validate,
            self._num_workers)
        testing_loader = self._data_store.get_loader(self._cfg.batch_size_eval,
                                                     self._subjects_test,
                                                     self._num_workers)

        train_extractor = miapy_extr.ComposeExtractor([
            miapy_extr.DataExtractor(),
            miapy_extr.SelectiveDataExtractor(
                category=data_storage.STORE_MORPHOMETRICS),
            miapy_extr.SubjectExtractor(),
            data_storage.DemographicsExtractor()
        ])

        dataset.set_extractor(train_extractor)

        # read all labels to calculate multiplier
        column_values, column_names = self._data_store.get_all_metrics()
        self._regression_column_ids = np.array([
            column_names.index(name) for name in self._cfg.regression_columns
        ])
        self._regression_column_multipliers = np.max(np.abs(
            column_values[:, self._regression_column_ids]),
                                                     axis=0)

        model_net.SCALE = float(self._data_store.get_intensity_scale_max())

        n_batches = int(
            np.ceil(len(self._subjects_train) / self._cfg.batch_size))

        logger.info('Net: {}, scale: {}'.format(
            inspect.getsource(self.get_python_obj(self._cfg.model)),
            model_net.SCALE))
        logger.info('Train: {}, Validation: {}, Test: {}'.format(
            len(self._subjects_train), len(self._subjects_validate),
            len(self._subjects_test)))
        logger.info('Label multiplier: {}'.format(
            self._regression_column_multipliers))
        logger.info('n_batches: {}'.format(n_batches))
        logger.info(self._cfg)
        logger.info('checkpoints dir: {}'.format(self.checkpoint_dir))

        sample = dataset.direct_extract(train_extractor,
                                        0)  # extract a subject to obtain shape

        with tf.Graph().as_default() as graph:
            self.set_seed(epoch=0)  # set again as seed is per graph

            x = tf.placeholder(tf.float32, (None, ) +
                               sample[data_storage.STORE_IMAGES].shape[0:],
                               name='x')
            y = tf.placeholder(tf.float32,
                               (None, len(self._regression_column_ids)),
                               name='y')
            d = tf.placeholder(tf.float32, (None, 2), name='d')  # age, sex
            is_train = tf.placeholder(tf.bool, shape=(), name='is_train')

            global_step = tf.train.get_or_create_global_step()
            epoch_checkpoint = tf.Variable(0, name='epoch')
            best_r2_score_checkpoint = tf.Variable(0.0,
                                                   name='best_r2_score',
                                                   dtype=tf.float64)

            net = self.get_python_obj(self._cfg.model)({
                'x': x,
                'y': y,
                'd': d,
                'is_train': is_train
            })
            optimizer = None
            loss = None

            if self._cfg.loss_function == 'mse':
                loss = tf.losses.mean_squared_error(labels=y, predictions=net)
            elif self._cfg.loss_function == 'absdiff':
                loss = tf.losses.absolute_difference(labels=y, predictions=net)

            if self._cfg.learning_rate_decay_rate is not None and self._cfg.learning_rate_decay_rate > 0:
                learning_rate = tf.train.exponential_decay(
                    self._cfg.learning_rate, global_step,
                    self._cfg.learning_rate_decay_steps,
                    self._cfg.learning_rate_decay_rate)
            else:
                learning_rate = tf.Variable(self._cfg.learning_rate, name='lr')

            if self._cfg.optimizer == 'SGD':
                optimizer = tf.train.GradientDescentOptimizer(
                    learning_rate=learning_rate)
            elif self._cfg.optimizer == 'Adam':
                optimizer = tf.train.AdamOptimizer(
                    learning_rate=learning_rate)

            with tf.control_dependencies(
                    tf.get_collection(
                        tf.GraphKeys.UPDATE_OPS)):  # required for batch_norm
                train_op = optimizer.minimize(loss=loss,
                                              global_step=global_step)

            sum_lr = tf.summary.scalar('learning_rate', learning_rate)

            # collect variables to add to summaries (histograms, kernel weight images)
            sum_histograms = []
            sum_kernels = []
            kernel_tensors = []
            kernel_tensor_names = []
            for v in tf.global_variables():
                m = re.match(r'NET/(layer\d+_\w+)/(kernel|bias):0', v.name)
                if m:
                    var = graph.get_tensor_by_name(v.name)
                    sum_histograms.append(
                        tf.summary.histogram(
                            '{}/{}'.format(m.group(1), m.group(2)), var))
                    #sum_histograms.append(tf.summary.histogram('{}/{}'.format(m.group(1), m.group(2)), tf.Variable(tf.zeros([1,1,1,1]))))

                    if m.group(2) == 'kernel' and m.group(1).endswith('conv'):
                        kernel_tensor_names.append(v.name)
                        h, w = visualize.get_grid_size(
                            var.get_shape().as_list())
                        img = tf.Variable(tf.zeros([1, h, w, 1]))
                        kernel_tensors.append(img)
                        sum_kernels.append(tf.summary.image(m.group(1), img))

            summary_writer = tf.summary.FileWriter(
                os.path.join(self.checkpoint_dir, 'tb_logs'),
                tf.get_default_graph())

            init = tf.global_variables_initializer()
            # Saver keeps only 5 checkpoints by default - make sure the best-r2 checkpoint remains!
            # TODO: maybe set max_to_keep=None?
            saver = tf.train.Saver(max_to_keep=self._cfg.checkpoint_keep + 2)

            with tf.Session() as sess:
                sess.run(init)

                checkpoint = tf.train.latest_checkpoint(self.checkpoint_dir)
                if checkpoint:
                    # existing checkpoint found, restoring...
                    if not self._cfg.subjects_train_val_test_file:
                        msg = 'Continue training, but no fixed subject assignments found. ' \
                              'Set subjects_train_val_test_file in config.'
                        logger.error(msg)
                        raise RuntimeError(msg)

                    logger.info('Restoring from ' + checkpoint)
                    saver = tf.train.import_meta_graph(checkpoint + '.meta')
                    saver.restore(sess, checkpoint)
                    if saver._max_to_keep < self._cfg.checkpoint_keep + 2:
                        msg = 'ERROR: Restored saver._max_to_keep={}, but self._cfg.checkpoint_keep={}'.format(
                            saver._max_to_keep, self._cfg.checkpoint_keep)
                        print(msg)
                        logger.error(msg)
                        exit(1)

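                    # resume with the epoch after the one stored in the checkpoint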
                    sess.run(epoch_checkpoint.assign_add(tf.constant(1)))
                    self._best_r2_score = best_r2_score_checkpoint.eval()
                    logger.info('Continue with epoch %i (best r2: %f)',
                                epoch_checkpoint.eval(), self._best_r2_score)
                    self._checkpoint_idx = int(
                        re.match(r'.*/checkpoint-(\d+)\.ckpt',
                                 checkpoint).group(1))

                    # load column multipliers from file if available
                    cfg_file = os.path.join(self.checkpoint_dir, 'config.json')
                    if os.path.exists(cfg_file):
                        cfg_tmp = Configuration.load(cfg_file)
                        if len(cfg_tmp.z_column_multipliers) > 0:
                            logger.info('Loading column multipliers from %s',
                                        cfg_file)
                            self._regression_column_multipliers = np.array(
                                cfg_tmp.z_column_multipliers)
                            logger.info('Label multiplier: {}'.format(
                                self._regression_column_multipliers))
                else:
                    # new training, write config to checkpoint dir
                    self.write_settings()

                summary_writer.add_session_log(
                    tf.SessionLog(status=tf.SessionLog.START),
                    global_step=global_step.eval())

                for epoch in range(epoch_checkpoint.eval(),
                                   self._cfg.epochs + 1):
                    self.set_seed(epoch=epoch)

                    self.write_status('status_epoch', str(epoch))
                    epoch_start_time = timeit.default_timer()
                    loss_sum = 0
                    r2_validation = None

                    # training (enable data augmentation)
                    self._data_store.set_transforms_enabled(True)
                    for batch in training_loader:
                        print('.', end='', flush=True)

                        feed_dict = self.batch_to_feed_dict(
                            x, y, d, is_train, batch, True)

                        # perform training step
                        _, loss_step = sess.run([train_op, loss],
                                                feed_dict=feed_dict)
                        loss_sum = loss_sum + loss_step * len(batch)

                    # disable transformations (data augmentation) for validation
                    self._data_store.set_transforms_enabled(False)

                    # loss on training data
                    cost_train = loss_sum / len(self._subjects_train)

                    if epoch % self._cfg.log_num_epoch == 0:
                        # loss on validation set
                        cost_validation = self.evaluate_loss(
                            sess, loss, validation_loader, x, y, d, is_train)
                        cost_validation_str = '{:.16f}'.format(cost_validation)
                    else:
                        print()
                        cost_validation = None
                        cost_validation_str = '-'

                    logger.info(
                        'Epoch:{:4.0f}, Loss train: {:.16f}, Loss validation: {}, lr: {:.16f}, dt={:.1f}s'
                        .format(epoch, cost_train, cost_validation_str,
                                learning_rate.eval(),
                                timeit.default_timer() - epoch_start_time))

                    # don't write loss for first epoch (usually very high) to avoid scaling issue in graph
                    if epoch > 0:
                        # write summary
                        summary = tf.Summary()
                        summary.value.add(tag='loss_train',
                                          simple_value=cost_train)
                        if cost_validation:
                            summary.value.add(tag='loss_validation',
                                              simple_value=cost_validation)
                        summary_writer.add_summary(summary, epoch)

                        summary_op = tf.summary.merge([sum_lr])
                        summary_str = sess.run(summary_op)
                        summary_writer.add_summary(summary_str, epoch)

                    if epoch % self._cfg.log_eval_num_epoch == 0 and epoch > 0:
                        # calculate and log R2 score on training and validation set
                        eval_start_time = timeit.default_timer()
                        predictions_train, gt_train, subjects_train = self.predict(
                            sess, net, training_loader, x, y, d, is_train)
                        predictions_validation, gt_validation, subjects_validation = self.predict(
                            sess, net, validation_loader, x, y, d, is_train)
                        r2_train = metrics.r2_score(gt_train,
                                                    predictions_train)
                        r2_validation = metrics.r2_score(
                            gt_validation, predictions_validation)
                        logger.info(
                            'Epoch:{:4.0f}, R2 train: {:.3f}, R2 validation: {:.8f}, dt={:.1f}s'
                            .format(epoch, r2_train, r2_validation,
                                    timeit.default_timer() - eval_start_time))

                        # write csv with intermediate results
                        self.write_results_csv(
                            'results_train-{0:04.0f}.csv'.format(epoch),
                            predictions_train, gt_train, subjects_train)
                        self.write_results_csv(
                            'results_validate-{0:04.0f}.csv'.format(epoch),
                            predictions_validation, gt_validation,
                            subjects_validation)

                        summary = tf.Summary()
                        # average r2
                        summary.value.add(tag='r2_train',
                                          simple_value=r2_train)
                        summary.value.add(tag='r2_validation',
                                          simple_value=r2_validation)

                        # add r2 per metric
                        for idx, col_name in enumerate(
                                self._cfg.regression_columns):
                            summary.value.add(
                                tag='train/r2_{}'.format(col_name),
                                simple_value=metrics.r2_score(
                                    gt_train[:, idx],
                                    predictions_train[:, idx],
                                    multioutput='raw_values'))
                            summary.value.add(
                                tag='validation/r2_{}'.format(col_name),
                                simple_value=metrics.r2_score(
                                    gt_validation[:, idx],
                                    predictions_validation[:, idx],
                                    multioutput='raw_values'))

                        summary_writer.add_summary(summary, epoch)

                    if epoch % self._cfg.visualize_layer_num_epoch == 0 and len(
                            sum_histograms) > 0:
                        # write histogram summaries and kernel visualization
                        summary_op = tf.summary.merge(sum_histograms)
                        summary_str = sess.run(summary_op)
                        summary_writer.add_summary(summary_str, epoch)

                        if len(kernel_tensor_names) > 0:
                            for idx, kernel_name in enumerate(
                                    kernel_tensor_names):
                                # visualize weights of kernel layer from a middle slice
                                kernel_weights = graph.get_tensor_by_name(
                                    kernel_name).eval()
                                # make last axis the first
                                kernel_weights = np.moveaxis(
                                    kernel_weights, -1, 0)

                                if len(kernel_weights.shape) > 4:
                                    # 3d convolution, remove last (single) channel
                                    kernel_weights = kernel_weights[:, :, :, :, 0]

                                if kernel_weights.shape[3] > 1:
                                    # multiple channels, take an example from the middle slice
                                    slice_num = kernel_weights.shape[3] // 2
                                    kernel_weights = kernel_weights[:, :, :, slice_num:slice_num + 1]

                                grid = visualize.make_grid(kernel_weights)[
                                    np.newaxis, :, :, np.newaxis]
                                sess.run(kernel_tensors[idx].assign(grid))

                            summary_op = tf.summary.merge(sum_kernels)
                            summary_str = sess.run(summary_op)
                            summary_writer.add_summary(summary_str, epoch)

                    summary_writer.flush()

                    if self._cfg.max_timelimit > 0 and (
                            timeit.default_timer() - start_time >
                            self._cfg.max_timelimit):
                        logger.info(
                            'Timelimit {}s exceeded. Stopping training...'.
                            format(self._cfg.max_timelimit))
                        self.checkpoint_safer(sess, saver, epoch_checkpoint,
                                              epoch, best_r2_score_checkpoint,
                                              True, r2_validation)
                        self.write_status('status_timeout')
                        break
                    else:
                        # epoch done
                        self.checkpoint_safer(sess, saver, epoch_checkpoint,
                                              epoch, best_r2_score_checkpoint,
                                              False, r2_validation)

                summary_writer.close()
                logger.info('Training done.')
                self.write_status('status_done')

                if self._best_r2_score > 0:
                    # restore checkpoint of best R2 score
                    checkpoint = os.path.join(self.checkpoint_dir,
                                              'checkpoint-best-r-2.ckpt')
                    saver = tf.train.import_meta_graph(checkpoint + '.meta')
                    saver.restore(sess, checkpoint)
                    logger.info(
                        'RESTORED best-r-2 checkpoint. Epoch: {}, R2: {:.8f}'.
                        format(epoch_checkpoint.eval(),
                               best_r2_score_checkpoint.eval()))

                # disable transformations (data augmentation) for test
                self._data_store.set_transforms_enabled(False)

                predictions_train, gt_train, subjects_train = self.predict(
                    sess, net, training_loader, x, y, d, is_train)
                predictions_test, gt_test, subjects_test = self.predict(
                    sess, net, testing_loader, x, y, d, is_train)

                self.write_results_csv('results_train.csv', predictions_train,
                                       gt_train, subjects_train)
                self.write_results_csv('results_test.csv', predictions_test,
                                       gt_test, subjects_test)

                # Note: use scaled metrics for MSE and unscaled (original) for R^2
                if len(gt_train) > 0:
                    accuracy_train = metrics.mean_squared_error(
                        gt_train / self._regression_column_multipliers,
                        predictions_train /
                        self._regression_column_multipliers)
                    r2_train = metrics.r2_score(gt_train, predictions_train)
                else:
                    accuracy_train = 0
                    r2_train = 0

                if len(subjects_test) > 0:
                    accuracy_test = metrics.mean_squared_error(
                        gt_test / self._regression_column_multipliers,
                        predictions_test / self._regression_column_multipliers)
                    r2_test = metrics.r2_score(gt_test, predictions_test)

                    s, _ = analyze_score.print_summary(
                        subjects_test, self._cfg.regression_columns,
                        predictions_test, gt_test)
                    logger.info('Summary:\n%s-------', s)

                    logger.info(
                        'TRAIN accuracy(mse): {:.8f}, r2: {:.8f}'.format(
                            accuracy_train, r2_train))
                    logger.info(
                        'TEST  accuracy(mse): {:.8f}, r2: {:.8f}'.format(
                            accuracy_test, r2_test))

                visualize.make_kernel_gif(self.checkpoint_dir,
                                          kernel_tensor_names)
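To make the metric convention at the end of train() concrete: the final MSE is computed on targets and predictions divided by the per-column max-abs multipliers, while R2 is reported on the original scale (R2 is unchanged by such per-column rescaling). A small self-contained sketch, assuming metrics refers to sklearn.metrics as in the example:

import numpy as np
from sklearn import metrics

gt = np.array([[1200.0, 3.5], [800.0, 2.0]])           # raw regression targets (two columns)
predictions = np.array([[1100.0, 3.0], [850.0, 2.2]])
multipliers = np.max(np.abs(gt), axis=0)               # per-column max-abs, as in train()

mse_scaled = metrics.mean_squared_error(gt / multipliers, predictions / multipliers)
r2 = metrics.r2_score(gt, predictions)                 # identical whether or not the columns are rescaled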
Example 3
def main(hdf_file: str):
    extractor = extr.ComposeExtractor([
        extr.NamesExtractor(),
        extr.DataExtractor(),
        extr.SelectiveDataExtractor(),
        extr.DataExtractor(('numerical', ), ignore_indexing=True),
        extr.DataExtractor(('sex', ), ignore_indexing=True),
        extr.DataExtractor(('mask', ), ignore_indexing=False),
        extr.SubjectExtractor(),
        extr.FilesExtractor(categories=('images', 'labels', 'mask',
                                        'numerical', 'sex')),
        extr.IndexingExtractor(),
        extr.ImagePropertiesExtractor()
    ])
    dataset = extr.ParameterizableDataset(hdf_file, extr.SliceIndexing(),
                                          extractor)

    for i in range(len(dataset)):
        item = dataset[i]

        index_expr = item['index_expr']  # type: miapy_data.IndexExpression
        root = item['file_root']

        image = None  # type: sitk.Image
        for channel, file in enumerate(item['images_files']):
            image = sitk.ReadImage(os.path.join(root, file))
            np_img = sitk.GetArrayFromImage(image).astype(np.float32)
            np_img = (np_img - np_img.mean()) / np_img.std()
            np_slice = np_img[index_expr.expression]
            if (np_slice != item['images'][..., channel]).any():
                raise ValueError('slice not equal')

        # for any image
        image_properties = conv.ImageProperties(image)

        if image_properties != item['properties']:
            raise ValueError('image properties not equal')

        for file in item['labels_files']:
            image = sitk.ReadImage(os.path.join(root, file))
            np_img = sitk.GetArrayFromImage(image)
            np_img = np.expand_dims(np_img, axis=-1)  # due to the convention of having the last dim as number of channels
            np_slice = np_img[index_expr.expression]
            if (np_slice != item['labels']).any():
                raise ValueError('slice not equal')

        for file in item['mask_files']:
            image = sitk.ReadImage(os.path.join(root, file))
            np_img = sitk.GetArrayFromImage(image)
            np_img = np.expand_dims(np_img, axis=-1)  # due to the convention of having the last dim as number of channels
            np_slice = np_img[index_expr.expression]
            if (np_slice != item['mask']).any():
                raise ValueError('slice not equal')

        for file in item['numerical_files']:
            with open(os.path.join(root, file), 'r') as f:
                lines = f.readlines()
            age = float(lines[0].split(':')[1].strip())
            gpa = float(lines[1].split(':')[1].strip())
            if age != item['numerical'][0][0] or gpa != item['numerical'][0][1]:
                raise ValueError('value not equal')

        for file in item['sex_files']:
            with open(os.path.join(root, file), 'r') as f:
                sex = f.readlines()[2].split(':')[1].strip()
            if sex != str(item['sex'][0]):
                raise ValueError('value not equal')

    print('All tests passed!')
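The comparisons above assume that item['index_expr'].expression is a NumPy-compatible index that selects the current slice from the full volume (SliceIndexing). A minimal illustration of that idea; the concrete expression shown here is hypothetical:

import numpy as np

volume = np.random.rand(10, 64, 64).astype(np.float32)  # toy (z, y, x) volume
index_expression = (slice(4, 5),)                        # hypothetical expression for the 5th slice
np_slice = volume[index_expression]                      # shape (1, 64, 64)
assert np.array_equal(np_slice[0], volume[4])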
Example 4
def main(config_file: str):
    config = cfg.load(config_file, cfg.Configuration)
    print(config)

    indexing_strategy = miapy_extr.SliceIndexing()  # slice-wise extraction
    extraction_transform = None  # we do not want to apply any transformation on the slices after extraction
    # define an extractor for training, i.e. what information we would like to extract per sample
    train_extractor = miapy_extr.ComposeExtractor([
        miapy_extr.NamesExtractor(),
        miapy_extr.DataExtractor(),
        miapy_extr.SelectiveDataExtractor()
    ])

    # define an extractor for testing, i.e. what information we would like to extract per sample
    # note that usually we don't use labels for testing, i.e. the SelectiveDataExtractor is only used for this example
    test_extractor = miapy_extr.ComposeExtractor([
        miapy_extr.NamesExtractor(),
        miapy_extr.IndexingExtractor(),
        miapy_extr.DataExtractor(),
        miapy_extr.SelectiveDataExtractor(),
        miapy_extr.ImageShapeExtractor()
    ])

    # define an extractor for evaluation, i.e. what information we would like to extract per sample
    eval_extractor = miapy_extr.ComposeExtractor([
        miapy_extr.NamesExtractor(),
        miapy_extr.SubjectExtractor(),
        miapy_extr.SelectiveDataExtractor(),
        miapy_extr.ImagePropertiesExtractor()
    ])

    # define the data set
    dataset = miapy_extr.ParameterizableDataset(
        config.database_file,
        indexing_strategy,
        miapy_extr.SubjectExtractor(),  # for select_indices() below
        extraction_transform)

    # generate train / test split for data set
    # we use Subject_0, Subject_1 and Subject_2 for training and Subject_3 for testing
    sampler_ids_train = miapy_extr.select_indices(
        dataset,
        miapy_extr.SubjectSelection(('Subject_0', 'Subject_1', 'Subject_2')))
    sampler_ids_test = miapy_extr.select_indices(
        dataset, miapy_extr.SubjectSelection(('Subject_3',)))

    # set up training data loader
    training_sampler = miapy_extr.SubsetRandomSampler(sampler_ids_train)
    training_loader = miapy_extr.DataLoader(dataset,
                                            config.batch_size_training,
                                            sampler=training_sampler,
                                            collate_fn=collate_batch,
                                            num_workers=1)

    # set up testing data loader
    testing_sampler = miapy_extr.SubsetSequentialSampler(sampler_ids_test)
    testing_loader = miapy_extr.DataLoader(dataset,
                                           config.batch_size_testing,
                                           sampler=testing_sampler,
                                           collate_fn=collate_batch,
                                           num_workers=1)

    sample = dataset.direct_extract(train_extractor, 0)  # extract a subject

    evaluator = init_evaluator()  # initialize evaluator

    for epoch in range(config.epochs):  # epochs loop
        dataset.set_extractor(train_extractor)
        for batch in training_loader:  # batches for training
            # feed_dict = batch_to_feed_dict(x, y, batch, True)  # e.g. for TensorFlow
            # train model, e.g.:
            # sess.run([train_op, loss], feed_dict=feed_dict)
            pass

        # subject assembler for testing
        subject_assembler = miapy_asmbl.SubjectAssembler()

        dataset.set_extractor(test_extractor)
        for batch in testing_loader:  # batches for testing
            # feed_dict = batch_to_feed_dict(x, y, batch, False)  # e.g. for TensorFlow
            # test model, e.g.:
            # prediction = sess.run(y_model, feed_dict=feed_dict)
            prediction = np.stack(
                batch['labels'], axis=0
            )  # we use the labels as predictions such that we can validate the assembler
            subject_assembler.add_batch(prediction, batch)

        # evaluate all test images
        for subject_idx in list(subject_assembler.predictions.keys()):
            # convert prediction and labels back to SimpleITK images
            sample = dataset.direct_extract(eval_extractor, subject_idx)
            label_image = miapy_conv.NumpySimpleITKImageBridge.convert(
                sample['labels'], sample['properties'])

            assembled = subject_assembler.get_assembled_subject(
                sample['subject_index'])
            prediction_image = miapy_conv.NumpySimpleITKImageBridge.convert(
                assembled, sample['properties'])
            evaluator.evaluate(prediction_image, label_image,
                               sample['subject'])  # evaluate prediction
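The data loaders above use a collate_fn=collate_batch that is not shown in this example. A hypothetical sketch that is consistent with how the batches are used afterwards (batch['labels'] must be a list of per-sample arrays that can be stacked):

def collate_batch(batch: list) -> dict:
    # merge a list of per-sample dicts into one dict of lists, keyed by the sample entries
    return {key: [sample[key] for sample in batch] for key in batch[0]}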
Example 5
def main_predict(cfg_file, checkpoint_file, hdf_file, subjects_file, out_csv):
    cfg = Configuration.load(cfg_file)

    trainer = training.Trainer(cfg, 0)
    data_store = data_storage.DataStore(hdf_file if hdf_file else cfg.hdf_file)

    if subjects_file:
        with open(subjects_file, 'r') as file:
            subjects = [s.rstrip() for s in file.readlines()]
    else:
        subjects = [s['subject'] for s in data_store.dataset]

    validation_loader = data_store.get_loader(cfg.batch_size_eval, subjects, 0)
    validation_extractor = miapy_extr.ComposeExtractor([
        miapy_extr.DataExtractor(),
        miapy_extr.SelectiveDataExtractor(
            category=data_storage.STORE_MORPHOMETRICS),
        miapy_extr.SubjectExtractor(),
        data_storage.DemographicsExtractor()
    ])

    data_store.dataset.set_extractor(validation_extractor)

    column_values, column_names = data_store.get_all_metrics()
    trainer._regression_column_ids = np.array(
        [column_names.index(name) for name in cfg.regression_columns])
    trainer._regression_column_multipliers = np.array(cfg.z_column_multipliers)

    with tf.Graph().as_default() as graph:
        init = tf.global_variables_initializer()

        with tf.Session() as sess:
            sess.run(init)

            print('Using checkpoint {}'.format(checkpoint_file))
            saver = tf.train.import_meta_graph(checkpoint_file + '.meta')
            saver.restore(sess, checkpoint_file)

            net = graph.get_tensor_by_name('NET/model:0')
            x_placeholder = graph.get_tensor_by_name('x:0')
            y_placeholder = graph.get_tensor_by_name('y:0')
            d_placeholder = graph.get_tensor_by_name('d:0')
            is_train_placeholder = graph.get_tensor_by_name('is_train:0')
            epoch_checkpoint = graph.get_tensor_by_name('epoch:0')

            print('Epoch from checkpoint: {}'.format(epoch_checkpoint.eval()))

            predictions, gt, pred_subjects = trainer.predict(
                sess, net, validation_loader, x_placeholder, y_placeholder,
                d_placeholder, is_train_placeholder)

            trainer.write_results_csv(out_csv, predictions, gt, pred_subjects)
            if len(pred_subjects) > 1:
                s, _ = analyze_score.print_summary(pred_subjects,
                                                   cfg.regression_columns,
                                                   predictions, gt)
                print(s)

            if len(pred_subjects) != len(subjects):
                print(
                    "WARN: Number of subjects in predictions ({}) != given ({})"
                    .format(len(pred_subjects), len(subjects)))
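A hypothetical invocation of main_predict; the file names are placeholders, except config.json and checkpoint-best-r-2.ckpt, which follow the naming used in Example 2:

main_predict(cfg_file='checkpoints/config.json',
             checkpoint_file='checkpoints/checkpoint-best-r-2.ckpt',
             hdf_file=None,                       # fall back to cfg.hdf_file
             subjects_file='subjects_test.txt',   # one subject id per line
             out_csv='predictions_test.csv')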