Exemplo n.º 1
0
def main(hdf_file: str, intensity_rescale_max: int, normalize_zscore: bool,
         center: bool, data_dir: str):
    """Build the preprocessing pipeline and import BraTS'18 subjects into HDF5.

    Args:
        hdf_file: Destination HDF5 file. An existing file at this path is
            deleted before the import.
        intensity_rescale_max: Upper bound of the intensity rescale range (the
            lower bound is 0). A falsy value (0/None) disables rescaling.
            Values <= 255 are followed by a cast to uint8, larger values by a
            cast to float16.
        normalize_zscore: Append ``miapy_tfm.IntensityNormalization`` to the
            pipeline (also forwarded to ``DataStore.import_data``).
        center: Prepend a centroid-centering transform.
        data_dir: Root directory handed to ``data_collector.Brats18Collector``.
    """
    pipeline = []

    if center:
        print('Applying center centroid')
        pipeline.append(data_preproc.CenterCentroidTransform())

    if intensity_rescale_max:
        print('Applying intensity rescale {}'.format(intensity_rescale_max))
        rescale = miapy_tfm.IntensityRescale(0, intensity_rescale_max)
        # store in the narrowest dtype able to hold the rescaled range
        target_dtype = np.uint8 if intensity_rescale_max <= 255 else np.float16
        pipeline.extend([rescale, DTypeTransform(target_dtype)])

    if normalize_zscore:
        # NOTE(review): this runs after the dtype cast above; confirm that
        # IntensityNormalization is meant to see uint8/float16 data here.
        pipeline.append(miapy_tfm.IntensityNormalization())

    transformer = miapy_tfm.ComposeTransform(pipeline)

    collector = data_collector.Brats18Collector(data_dir)

    if os.path.exists(hdf_file):
        print('Overriding existing {}'.format(hdf_file))
        os.remove(hdf_file)

    store = data_storage.DataStore(hdf_file)
    store.import_data(collector.subject_files, intensity_rescale_max,
                      normalize_zscore, transformer)

    n_imported = len(collector.subject_files)
    print('{} subjects imported to {}'.format(n_imported, hdf_file))
def main(hdf_file: str, intensity_rescale_max: int, center: bool,
         clip_intensity: bool, brain_pattern: str, csv_file: str,
         directories: list):
    """Collect brain-MRI subjects from ``directories`` and import them into HDF5.

    Args:
        hdf_file: Destination HDF5 file. An existing file at this path is
            deleted before the import.
        intensity_rescale_max: Upper bound of the intensity rescale range (the
            lower bound is 0). A falsy value (0/None) disables rescaling.
            Values <= 255 are followed by a cast to uint8, larger values by a
            cast to float16.
        center: Add a centroid-centering transform.
        clip_intensity: Add a transform clipping negative intensities (it is
            applied first, before centering and rescaling).
        brain_pattern: Path of the brain image relative to each subject
            directory.
        csv_file: Optional metadata file read with ``read_meta``. When it
            yields a non-empty result, only subjects whose id is contained in
            it are imported; the metadata is also passed to ``import_data``.
        directories: Root directories. Every sub-directory is one subject whose
            id is the sub-directory name; it must contain ``brain_pattern``
            and a ``stats`` directory.
    """
    transformers = []

    if clip_intensity:
        print('Applying clip negative intensities')
        transformers.append(data_preproc.ClipNegativeTransform())

    if center:
        print('Applying center centroid')
        transformers.append(data_preproc.CenterCentroidTransform())

    if intensity_rescale_max:
        print('Applying intensity rescale {}'.format(intensity_rescale_max))
        transformers.append(
            miapy_tfm.IntensityRescale(0, intensity_rescale_max))
        if intensity_rescale_max <= 255:
            transformers.append(DTypeTransform(np.uint8))
        else:
            transformers.append(DTypeTransform(np.float16))

    transformer = miapy_tfm.ComposeTransform(transformers)

    meta_dict = read_meta(csv_file) if csv_file else None
    subjects = []
    for directory in directories:
        assert_exists(directory)
        # glob order is filesystem-dependent; sort so that the import order
        # (and therefore the HDF5 layout) is reproducible across machines
        for subject_dir in sorted(glob.glob(os.path.join(directory, '*'))):
            # only sub-directories can hold a subject; a stray file next to
            # the subject folders would otherwise fail assert_exists below
            if not os.path.isdir(subject_dir):
                continue

            subject_id = os.path.basename(subject_dir)
            if meta_dict and subject_id not in meta_dict:
                continue

            brain_file = os.path.join(subject_dir, brain_pattern)
            stats_dir = os.path.join(subject_dir, 'stats')
            assert_exists(brain_file)
            assert_exists(stats_dir)

            subjects.append(
                data_storage.Subject(
                    subject_id, {
                        data_storage.FileTypes.BRAIN_MRI: brain_file,
                        data_storage.FileTypes.MORPHOMETRY_STATS: stats_dir
                    }))

    if os.path.exists(hdf_file):
        print('Overriding existing {}'.format(hdf_file))
        os.remove(hdf_file)

    store = data_storage.DataStore(hdf_file)
    store.import_data(subjects, intensity_rescale_max, transformer, meta_dict)

    print('{} subjects imported to {}'.format(len(subjects), hdf_file))
Exemplo n.º 3
0
    def train(self):
        """Train the configured regression network, then evaluate it.

        All settings come from ``self._cfg``. The steps are:

        1. Open the HDF5 data store, assign train/validation/test subjects and
           build one loader per split.
        2. Compute per-column label multipliers: the maximum absolute value of
           each configured regression column over all metrics in the store.
        3. Build the TF1 graph (placeholders ``x``, ``y``, ``d``, ``is_train``,
           loss, optimizer and summaries). Resume from the latest checkpoint in
           ``self.checkpoint_dir`` if one exists.
        4. Train for up to ``self._cfg.epochs`` epochs, writing losses and R2
           scores to TensorBoard and CSV files and checkpointing every epoch via
           ``self.checkpoint_safer``. Stop early once ``max_timelimit`` seconds
           (if > 0) have elapsed.
        5. If a positive best R2 score was recorded, restore the best-R2
           checkpoint. Then predict on the train and test splits and write
           ``results_train.csv`` / ``results_test.csv``.

        Raises:
            RuntimeError: when resuming from a checkpoint without
                ``subjects_train_val_test_file`` set, or when the restored
                saver keeps fewer checkpoints than ``checkpoint_keep + 2``.
            ValueError: for an unsupported ``loss_function`` or ``optimizer``.
        """
        self.write_status('status_init')
        start_time = timeit.default_timer()
        self.set_seed(epoch=0)

        transform = self.get_transform()
        self._data_store = data_storage.DataStore(self._cfg.hdf_file,
                                                  transform)
        dataset = self._data_store.dataset

        self.assign_subjects()

        # prepare loaders and extractors
        training_loader = self._data_store.get_loader(self._cfg.batch_size,
                                                      self._subjects_train,
                                                      self._num_workers)
        validation_loader = self._data_store.get_loader(
            self._cfg.batch_size_eval, self._subjects_validate,
            self._num_workers)
        testing_loader = self._data_store.get_loader(self._cfg.batch_size_eval,
                                                     self._subjects_test,
                                                     self._num_workers)

        train_extractor = miapy_extr.ComposeExtractor([
            miapy_extr.DataExtractor(),
            miapy_extr.SelectiveDataExtractor(
                category=data_storage.STORE_MORPHOMETRICS),
            miapy_extr.SubjectExtractor(),
            data_storage.DemographicsExtractor()
        ])

        dataset.set_extractor(train_extractor)

        # read all labels to calculate multiplier: the largest absolute value
        # of each regression column (used to scale the MSE metrics at the end)
        column_values, column_names = self._data_store.get_all_metrics()
        self._regression_column_ids = np.array([
            column_names.index(name) for name in self._cfg.regression_columns
        ])
        self._regression_column_multipliers = np.max(
            np.abs(column_values[:, self._regression_column_ids]), axis=0)

        model_net.SCALE = float(self._data_store.get_intensity_scale_max())

        n_batches = int(
            np.ceil(len(self._subjects_train) / self._cfg.batch_size))

        logger.info('Net: {}, scale: {}'.format(
            inspect.getsource(self.get_python_obj(self._cfg.model)),
            model_net.SCALE))
        logger.info('Train: {}, Validation: {}, Test: {}'.format(
            len(self._subjects_train), len(self._subjects_validate),
            len(self._subjects_test)))
        logger.info('Label multiplier: {}'.format(
            self._regression_column_multipliers))
        logger.info('n_batches: {}'.format(n_batches))
        logger.info(self._cfg)
        logger.info('checkpoints dir: {}'.format(self.checkpoint_dir))

        sample = dataset.direct_extract(train_extractor,
                                        0)  # extract a subject to obtain shape

        with tf.Graph().as_default() as graph:
            self.set_seed(epoch=0)  # set again as seed is per graph

            x = tf.placeholder(tf.float32, (None, ) +
                               sample[data_storage.STORE_IMAGES].shape[0:],
                               name='x')
            y = tf.placeholder(tf.float32,
                               (None, len(self._regression_column_ids)),
                               name='y')
            d = tf.placeholder(tf.float32, (None, 2), name='d')  # age, sex
            is_train = tf.placeholder(tf.bool, shape=(), name='is_train')

            global_step = tf.train.get_or_create_global_step()
            epoch_checkpoint = tf.Variable(0, name='epoch')
            best_r2_score_checkpoint = tf.Variable(0.0,
                                                   name='best_r2_score',
                                                   dtype=tf.float64)

            net = self.get_python_obj(self._cfg.model)({
                'x': x,
                'y': y,
                'd': d,
                'is_train': is_train
            })

            if self._cfg.loss_function == 'mse':
                loss = tf.losses.mean_squared_error(labels=y, predictions=net)
            elif self._cfg.loss_function == 'absdiff':
                loss = tf.losses.absolute_difference(labels=y, predictions=net)
            else:
                # fail here instead of with an opaque error in minimize(None)
                raise ValueError('Unsupported loss_function: {}'.format(
                    self._cfg.loss_function))

            if (self._cfg.learning_rate_decay_rate is not None
                    and self._cfg.learning_rate_decay_rate > 0):
                learning_rate = tf.train.exponential_decay(
                    self._cfg.learning_rate, global_step,
                    self._cfg.learning_rate_decay_steps,
                    self._cfg.learning_rate_decay_rate)
            else:
                learning_rate = tf.Variable(self._cfg.learning_rate, name='lr')

            # Both optimizers consume the same learning_rate tensor, so the
            # decay schedule (which is also what gets logged and summarised)
            # applies to Adam as well as to SGD.
            if self._cfg.optimizer == 'SGD':
                optimizer = tf.train.GradientDescentOptimizer(
                    learning_rate=learning_rate)
            elif self._cfg.optimizer == 'Adam':
                optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)
            else:
                raise ValueError('Unsupported optimizer: {}'.format(
                    self._cfg.optimizer))

            with tf.control_dependencies(
                    tf.get_collection(
                        tf.GraphKeys.UPDATE_OPS)):  # required for batch_norm
                train_op = optimizer.minimize(loss=loss,
                                              global_step=global_step)

            sum_lr = tf.summary.scalar('learning_rate', learning_rate)

            # collect variables to add to summaries (histograms, kernel weight images)
            sum_histograms = []
            sum_kernels = []
            kernel_tensor_names = []
            # one placeholder + assign op per kernel image variable, built here
            # once so the training loop does not add ops to the graph
            kernel_grid_placeholders = []
            kernel_assign_ops = []
            for v in tf.global_variables():
                m = re.match(r'NET/(layer\d+_\w+)/(kernel|bias):0', v.name)
                if m:
                    var = graph.get_tensor_by_name(v.name)
                    sum_histograms.append(
                        tf.summary.histogram(
                            '{}/{}'.format(m.group(1), m.group(2)), var))

                    if m.group(2) == 'kernel' and m.group(1).endswith('conv'):
                        kernel_tensor_names.append(v.name)
                        h, w = visualize.get_grid_size(
                            var.get_shape().as_list())
                        img = tf.Variable(tf.zeros([1, h, w, 1]))
                        sum_kernels.append(tf.summary.image(m.group(1), img))
                        grid_placeholder = tf.placeholder(tf.float32,
                                                          [1, h, w, 1])
                        kernel_grid_placeholders.append(grid_placeholder)
                        kernel_assign_ops.append(img.assign(grid_placeholder))

            # Merge once here: calling tf.summary.merge inside the epoch loop
            # would add a new op to the graph on every epoch.
            hist_summary_op = (tf.summary.merge(sum_histograms)
                               if sum_histograms else None)
            kernel_summary_op = (tf.summary.merge(sum_kernels)
                                 if sum_kernels else None)

            summary_writer = tf.summary.FileWriter(
                os.path.join(self.checkpoint_dir, 'tb_logs'),
                tf.get_default_graph())

            init = tf.global_variables_initializer()
            # Saver keeps only 5 per default - make sure best-r2-checkpoint remains!
            # TODO: maybe set max_to_keep=None?
            saver = tf.train.Saver(max_to_keep=self._cfg.checkpoint_keep + 2)

            with tf.Session() as sess:
                sess.run(init)

                checkpoint = tf.train.latest_checkpoint(self.checkpoint_dir)
                if checkpoint:
                    # existing checkpoint found, restoring...
                    if not self._cfg.subjects_train_val_test_file:
                        msg = 'Continue training, but no fixed subject assignments found. ' \
                              'Set subjects_train_val_test_file in config.'
                        logger.error(msg)
                        raise RuntimeError(msg)

                    logger.info('Restoring from ' + checkpoint)
                    # NOTE(review): import_meta_graph imports the saved graph
                    # into this graph, which already defines the same ops;
                    # confirm the returned saver restores into the variables
                    # used above (e.g. epoch_checkpoint) and not into
                    # renamed duplicates.
                    saver = tf.train.import_meta_graph(checkpoint + '.meta')
                    saver.restore(sess, checkpoint)
                    if saver._max_to_keep < self._cfg.checkpoint_keep + 2:
                        msg = 'ERROR: Restored saver._max_to_keep={}, but self._cfg.checkpoint_keep={}'.format(
                            saver._max_to_keep, self._cfg.checkpoint_keep)
                        print(msg)
                        logger.error(msg)
                        # raise like the check above rather than calling the
                        # interactive-only site builtin exit()
                        raise RuntimeError(msg)

                    sess.run(epoch_checkpoint.assign_add(tf.constant(1)))
                    self._best_r2_score = best_r2_score_checkpoint.eval()
                    logger.info('Continue with epoch %i (best r2: %f)',
                                epoch_checkpoint.eval(), self._best_r2_score)
                    self._checkpoint_idx = int(
                        re.match(r'.*/checkpoint-(\d+)\.ckpt',
                                 checkpoint).group(1))

                    # load column multipliers from file if available
                    cfg_file = os.path.join(self.checkpoint_dir, 'config.json')
                    if os.path.exists(cfg_file):
                        cfg_tmp = Configuration.load(cfg_file)
                        if len(cfg_tmp.z_column_multipliers) > 0:
                            logger.info('Loading column multipliers from %s',
                                        cfg_file)
                            self._regression_column_multipliers = np.array(
                                cfg_tmp.z_column_multipliers)
                            logger.info('Label multiplier: {}'.format(
                                self._regression_column_multipliers))
                else:
                    # new training, write config to checkpoint dir
                    self.write_settings()

                summary_writer.add_session_log(
                    tf.SessionLog(status=tf.SessionLog.START),
                    global_step=global_step.eval())

                for epoch in range(epoch_checkpoint.eval(),
                                   self._cfg.epochs + 1):
                    self.set_seed(epoch=epoch)

                    self.write_status('status_epoch', str(epoch))
                    epoch_start_time = timeit.default_timer()
                    loss_sum = 0
                    r2_validation = None

                    # training (enable data augmentation)
                    self._data_store.set_transforms_enabled(True)
                    for batch in training_loader:
                        print('.', end='', flush=True)

                        feed_dict = self.batch_to_feed_dict(
                            x, y, d, is_train, batch, True)

                        # perform training step
                        _, loss_step = sess.run([train_op, loss],
                                                feed_dict=feed_dict)
                        # NOTE(review): if batch is a dict keyed by data
                        # category, len(batch) counts keys rather than samples;
                        # confirm this really weights loss_step by batch size.
                        loss_sum = loss_sum + loss_step * len(batch)

                    # disable transformations (data augmentation) for validation
                    self._data_store.set_transforms_enabled(False)

                    # loss on training data
                    cost_train = loss_sum / len(self._subjects_train)

                    if epoch % self._cfg.log_num_epoch == 0:
                        # loss on validation set
                        cost_validation = self.evaluate_loss(
                            sess, loss, validation_loader, x, y, d, is_train)
                        cost_validation_str = '{:.16f}'.format(cost_validation)
                    else:
                        print()
                        cost_validation = None
                        cost_validation_str = '-'

                    logger.info(
                        'Epoch:{:4.0f}, Loss train: {:.16f}, Loss validation: {}, lr: {:.16f}, dt={:.1f}s'
                        .format(epoch, cost_train, cost_validation_str,
                                learning_rate.eval(),
                                timeit.default_timer() - epoch_start_time))

                    # don't write loss for first epoch (usually very high) to avoid scaling issue in graph
                    if epoch > 0:
                        # write summary
                        summary = tf.Summary()
                        summary.value.add(tag='loss_train',
                                          simple_value=cost_train)
                        # `is not None`: a validation loss of exactly 0.0 must
                        # still be written
                        if cost_validation is not None:
                            summary.value.add(tag='loss_validation',
                                              simple_value=cost_validation)
                        summary_writer.add_summary(summary, epoch)

                        summary_str = sess.run(sum_lr)
                        summary_writer.add_summary(summary_str, epoch)

                    if epoch % self._cfg.log_eval_num_epoch == 0 and epoch > 0:
                        # calculate and log R2 score on training and validation set
                        eval_start_time = timeit.default_timer()
                        predictions_train, gt_train, subjects_train = self.predict(
                            sess, net, training_loader, x, y, d, is_train)
                        predictions_validation, gt_validation, subjects_validation = self.predict(
                            sess, net, validation_loader, x, y, d, is_train)
                        r2_train = metrics.r2_score(gt_train,
                                                    predictions_train)
                        r2_validation = metrics.r2_score(
                            gt_validation, predictions_validation)
                        logger.info(
                            'Epoch:{:4.0f}, R2 train: {:.3f}, R2 validation: {:.8f}, dt={:.1f}s'
                            .format(epoch, r2_train, r2_validation,
                                    timeit.default_timer() - eval_start_time))

                        # write csv with intermediate results
                        self.write_results_csv(
                            'results_train-{0:04.0f}.csv'.format(epoch),
                            predictions_train, gt_train, subjects_train)
                        self.write_results_csv(
                            'results_validate-{0:04.0f}.csv'.format(epoch),
                            predictions_validation, gt_validation,
                            subjects_validation)

                        summary = tf.Summary()
                        # average r2
                        summary.value.add(tag='r2_train',
                                          simple_value=r2_train)
                        summary.value.add(tag='r2_validation',
                                          simple_value=r2_validation)

                        # add r2 per metric; a single column yields a scalar
                        # score, which is what simple_value expects
                        for idx, col_name in enumerate(
                                self._cfg.regression_columns):
                            summary.value.add(
                                tag='train/r2_{}'.format(col_name),
                                simple_value=float(
                                    metrics.r2_score(
                                        gt_train[:, idx],
                                        predictions_train[:, idx])))
                            summary.value.add(
                                tag='validation/r2_{}'.format(col_name),
                                simple_value=float(
                                    metrics.r2_score(
                                        gt_validation[:, idx],
                                        predictions_validation[:, idx])))

                        summary_writer.add_summary(summary, epoch)

                    if epoch % self._cfg.visualize_layer_num_epoch == 0 and len(
                            sum_histograms) > 0:
                        # write histogram summaries and kernel visualization
                        summary_str = sess.run(hist_summary_op)
                        summary_writer.add_summary(summary_str, epoch)

                        if len(kernel_tensor_names) > 0:
                            for idx, kernel_name in enumerate(
                                    kernel_tensor_names):
                                # visualize weights of kernel layer from a middle slice
                                kernel_weights = graph.get_tensor_by_name(
                                    kernel_name).eval()
                                # make last axis the first
                                kernel_weights = np.moveaxis(
                                    kernel_weights, -1, 0)

                                if len(kernel_weights.shape) > 4:
                                    # 3d convolution, remove last (single) channel
                                    kernel_weights = kernel_weights[:, :, :, :,
                                                                    0]

                                if kernel_weights.shape[3] > 1:
                                    # multiple channels, take example from middle slide
                                    slice_num = kernel_weights.shape[3] // 2
                                    kernel_weights = kernel_weights[
                                        :, :, :, slice_num:slice_num + 1]

                                grid = visualize.make_grid(kernel_weights)[
                                    np.newaxis, :, :, np.newaxis]
                                sess.run(kernel_assign_ops[idx],
                                         feed_dict={
                                             kernel_grid_placeholders[idx]:
                                             grid
                                         })

                            summary_str = sess.run(kernel_summary_op)
                            summary_writer.add_summary(summary_str, epoch)

                    summary_writer.flush()

                    if self._cfg.max_timelimit > 0 and (
                            timeit.default_timer() - start_time >
                            self._cfg.max_timelimit):
                        logger.info(
                            'Timelimit {}s exceeded. Stopping training...'.
                            format(self._cfg.max_timelimit))
                        self.checkpoint_safer(sess, saver, epoch_checkpoint,
                                              epoch, best_r2_score_checkpoint,
                                              True, r2_validation)
                        self.write_status('status_timeout')
                        break
                    else:
                        # epoch done
                        self.checkpoint_safer(sess, saver, epoch_checkpoint,
                                              epoch, best_r2_score_checkpoint,
                                              False, r2_validation)

                summary_writer.close()
                logger.info('Training done.')
                self.write_status('status_done')

                if self._best_r2_score > 0:
                    # restore checkpoint of best R2 score
                    checkpoint = os.path.join(self.checkpoint_dir,
                                              'checkpoint-best-r-2.ckpt')
                    saver = tf.train.import_meta_graph(checkpoint + '.meta')
                    saver.restore(sess, checkpoint)
                    logger.info(
                        'RESTORED best-r-2 checkpoint. Epoch: {}, R2: {:.8f}'.
                        format(epoch_checkpoint.eval(),
                               best_r2_score_checkpoint.eval()))

                # disable transformations (data augmentation) for test
                self._data_store.set_transforms_enabled(False)

                predictions_train, gt_train, subjects_train = self.predict(
                    sess, net, training_loader, x, y, d, is_train)
                predictions_test, gt_test, subjects_test = self.predict(
                    sess, net, testing_loader, x, y, d, is_train)

                self.write_results_csv('results_train.csv', predictions_train,
                                       gt_train, subjects_train)
                self.write_results_csv('results_test.csv', predictions_test,
                                       gt_test, subjects_test)

                # Note: use scaled metrics for MSE and unscaled (original) for R^2
                if len(gt_train) > 0:
                    accuracy_train = metrics.mean_squared_error(
                        gt_train / self._regression_column_multipliers,
                        predictions_train /
                        self._regression_column_multipliers)
                    r2_train = metrics.r2_score(gt_train, predictions_train)
                else:
                    accuracy_train = 0
                    r2_train = 0

                # train metrics are logged whether or not a test split exists
                logger.info('TRAIN accuracy(mse): {:.8f}, r2: {:.8f}'.format(
                    accuracy_train, r2_train))

                if len(subjects_test) > 0:
                    accuracy_test = metrics.mean_squared_error(
                        gt_test / self._regression_column_multipliers,
                        predictions_test / self._regression_column_multipliers)
                    r2_test = metrics.r2_score(gt_test, predictions_test)

                    s, _ = analyze_score.print_summary(
                        subjects_test, self._cfg.regression_columns,
                        predictions_test, gt_test)
                    logger.info('Summary:\n%s-------', s)

                    logger.info(
                        'TEST  accuracy(mse): {:.8f}, r2: {:.8f}'.format(
                            accuracy_test, r2_test))

                visualize.make_kernel_gif(self.checkpoint_dir,
                                          kernel_tensor_names)
Exemplo n.º 4
0
def main_predict(cfg_file, checkpoint_file, hdf_file, subjects_file, out_csv):
    """Predict with a saved checkpoint and write the results to a CSV file.

    Args:
        cfg_file: Configuration of the trained model (``Configuration.load``).
            Its ``regression_columns`` and ``z_column_multipliers`` are used.
        checkpoint_file: Checkpoint prefix; ``<checkpoint_file>.meta`` is
            imported and the variables are restored from the prefix.
        hdf_file: HDF5 data store to predict on. Falls back to ``cfg.hdf_file``
            when empty/None.
        subjects_file: Optional text file with one subject id per line. When
            not given, every subject in the data store is used.
        out_csv: Output name passed to ``trainer.write_results_csv``.
    """
    cfg = Configuration.load(cfg_file)

    trainer = training.Trainer(cfg, 0)
    data_store = data_storage.DataStore(hdf_file if hdf_file else cfg.hdf_file)

    if subjects_file:
        with open(subjects_file, 'r') as fh:
            # one id per line; skip lines that are empty once the trailing
            # whitespace/newline is removed (e.g. a blank last line)
            subjects = [s for s in (line.rstrip() for line in fh) if s]
    else:
        subjects = [s['subject'] for s in data_store.dataset]

    validation_loader = data_store.get_loader(cfg.batch_size_eval, subjects, 0)
    validation_extractor = miapy_extr.ComposeExtractor([
        miapy_extr.DataExtractor(),
        miapy_extr.SelectiveDataExtractor(
            category=data_storage.STORE_MORPHOMETRICS),
        miapy_extr.SubjectExtractor(),
        data_storage.DemographicsExtractor()
    ])

    data_store.dataset.set_extractor(validation_extractor)

    column_values, column_names = data_store.get_all_metrics()
    trainer._regression_column_ids = np.array(
        [column_names.index(name) for name in cfg.regression_columns])
    trainer._regression_column_multipliers = np.array(cfg.z_column_multipliers)

    with tf.Graph().as_default() as graph:
        init = tf.global_variables_initializer()

        with tf.Session() as sess:
            sess.run(init)

            print('Using checkpoint {}'.format(checkpoint_file))
            saver = tf.train.import_meta_graph(checkpoint_file + '.meta')
            saver.restore(sess, checkpoint_file)

            # tensors are looked up by the names given to them at training time
            net = graph.get_tensor_by_name('NET/model:0')
            x_placeholder = graph.get_tensor_by_name('x:0')
            y_placeholder = graph.get_tensor_by_name('y:0')
            d_placeholder = graph.get_tensor_by_name('d:0')
            is_train_placeholder = graph.get_tensor_by_name('is_train:0')
            epoch_checkpoint = graph.get_tensor_by_name('epoch:0')

            print('Epoch from checkpoint: {}'.format(epoch_checkpoint.eval()))

            predictions, gt, pred_subjects = trainer.predict(
                sess, net, validation_loader, x_placeholder, y_placeholder,
                d_placeholder, is_train_placeholder)

            trainer.write_results_csv(out_csv, predictions, gt, pred_subjects)
            if len(pred_subjects) > 1:
                # use the ids returned by predict(): they are aligned with the
                # rows of predictions/gt, whereas `subjects` may differ in
                # order or length (see the warning below)
                s, _ = analyze_score.print_summary(pred_subjects,
                                                   cfg.regression_columns,
                                                   predictions, gt)
                print(s)

            if len(pred_subjects) != len(subjects):
                print(
                    "WARN: Number of subjects in predictions ({}) != given ({})"
                    .format(len(pred_subjects), len(subjects)))
Exemplo n.º 5
0
    def train(self):
        self.set_seed(epoch=0)

        transform = self.get_transform()
        self._data_store = data_storage.DataStore(self._cfg.hdf_file,
                                                  transform)
        dataset = self._data_store.dataset

        self.assign_subjects()

        # prepare loaders and extractors
        training_loader = self._data_store.get_loader(self._cfg.batch_size,
                                                      self._subjects_train,
                                                      self._num_workers)
        validation_loader = self._data_store.get_loader(
            self._cfg.batch_size_eval, self._subjects_validate,
            self._num_workers)
        testing_loader = self._data_store.get_loader(self._cfg.batch_size_eval,
                                                     self._subjects_test,
                                                     self._num_workers)

        # train_extractor = miapy_extr.ComposeExtractor(
        #     [miapy_extr.DataExtractor(categories=('images',)),
        #      miapy_extr.DataExtractor(entire_subject=True, categories=('age', 'resectionstatus', 'volncr', 'voled', 'volet', 'etrimwidth', 'etgeomhet',
        #                                                                'rim_q1_clipped', 'rim_q2_clipped', 'rim_q3_clipped')),
        #      miapy_extr.NamesExtractor(categories=('images',)),
        #      miapy_extr.SubjectExtractor()])

        train_extractor = miapy_extr.ComposeExtractor([
            miapy_extr.DataExtractor(categories=('images', )),
            #miapy_extr.DataExtractor(entire_subject=True, categories=('age', 'resectionstatus', 'survival', 'volncr', 'voled', 'volet', 'etrimwidth', 'etgeomhet',
            #                                                          'rim_q1_clipped', 'rim_q2_clipped', 'rim_q3_clipped')),
            #miapy_extr.DataExtractor(entire_subject=True, categories=(
            #'age', 'resectionstatus', 'survival', 'volncr', 'voled', 'volet')),
            miapy_extr.DataExtractor(ignore_indexing=True,
                                     categories=('age', 'resectionstatus',
                                                 'survival')),
            miapy_extr.NamesExtractor(categories=('images', )),
            miapy_extr.SubjectExtractor()
        ])

        dataset.set_extractor(train_extractor)

        # # read all labels to calculate multiplier
        # column_values, column_names = self._data_store.get_all_metrics()
        # self._regression_column_ids = np.array([column_names.index(name) for name in self._cfg.regression_columns])
        # self._regression_column_multipliers = np.max(column_values[:, self._regression_column_ids], axis=0)

        # alexnet.SCALE = float(self._data_store.get_intensity_scale_max())

        n_batches = int(
            np.ceil(len(self._subjects_train) / self._cfg.batch_size))

        logger.info('Net: {}, scale: {}'.format(
            inspect.getsource(self.get_python_obj(self._cfg.model)),
            convnet.SCALE))
        logger.info('Train: {}, Validation: {}, Test: {}'.format(
            len(self._subjects_train), len(self._subjects_validate),
            len(self._subjects_test)))
        logger.info('n_batches: {}'.format(n_batches))
        logger.info(self._cfg)
        logger.info('checkpoints dir: {}'.format(self.checkpoint_dir))

        shape = dataset.direct_extract(
            train_extractor,
            0)['images'].shape  # extract a subject to obtain shape
        print("Shape: " + str(shape))

        with tf.Graph().as_default() as graph:
            self.set_seed(epoch=0)  # set again as seed is per graph

            x = tf.placeholder(tf.float32, (None, ) + shape, name='x')
            y = tf.placeholder(tf.float32, (None, 1), name='y')
            # y = tf.placeholder(tf.float32, (None,) + shape_y, name='y')
            d = tf.placeholder(tf.float32, (None, 2),
                               name='d')  # age, resectionstate
            is_train = tf.placeholder(tf.bool, shape=(), name='is_train')

            global_step = tf.train.get_or_create_global_step()
            epoch_checkpoint = tf.Variable(0, name='epoch')
            best_r2_score_checkpoint = tf.Variable(0.0,
                                                   name='best_r2_score',
                                                   dtype=tf.float64)

            # net_full = self.get_python_obj(self._cfg.model)({'x': x, 'y': y, 'd': d, 'is_train': is_train})
            net = self.get_python_obj(self._cfg.model)({
                'x': x,
                'y': y,
                'd': d,
                'is_train': is_train
            })  #["reg"]
            print(net)
            print("%%%%%%%%%%% blabla %%%%%%%%%%%%%%%%%")
            print(y.shape)
            print(y)
            loss = tf.losses.mean_squared_error(labels=net["reg"],
                                                predictions=y)
            learning_rate = None
            optimizer = None
            if self._cfg.optimizer == 'SGD':
                learning_rate = tf.train.exponential_decay(
                    self._cfg.learning_rate, global_step,
                    self._cfg.learning_rate_decay_steps,
                    self._cfg.learning_rate_decay_rate)
                optimizer = tf.train.GradientDescentOptimizer(
                    learning_rate=learning_rate)
                lr_get = lambda: optimizer._learning_rate.eval()
            elif self._cfg.optimizer == 'Adam':
                learning_rate = tf.Variable(self._cfg.learning_rate, name='lr')
                optimizer = tf.train.AdamOptimizer(
                    learning_rate=self._cfg.learning_rate)
                lr_get = lambda: optimizer._lr

            with tf.control_dependencies(
                    tf.get_collection(
                        tf.GraphKeys.UPDATE_OPS)):  # required for batch_norm
                train_op = optimizer.minimize(loss=loss,
                                              global_step=global_step)

            sum_lr = tf.summary.scalar('learning_rate', learning_rate)

            # collect variables to add to summaries (histograms, kernel weight images)
            sum_histograms = []
            sum_kernels = []
            kernel_tensors = []
            kernel_tensor_names = []
            for v in tf.global_variables():
                m = re.match('NET/(layer\d+_\w+)/(kernel|bias):0', v.name)
                if m:
                    var = graph.get_tensor_by_name(v.name)
                    sum_histograms.append(
                        tf.summary.histogram(
                            '{}/{}'.format(m.group(1), m.group(2)), var))

                    if m.group(2) == 'kernel' and m.group(1).endswith('conv'):
                        # if m.group(2) == 'kernel' and (m.group(1).find('conv') != -1):
                        # if (m.group(1).find('conv') != -1):
                        kernel_tensor_names.append(v.name)
                        print(kernel_tensor_names)
                        h, w = visualize.get_grid_size(
                            var.get_shape().as_list())
                        img = tf.Variable(tf.zeros([1, h, w, 1]))
                        kernel_tensors.append(img)
                        sum_kernels.append(tf.summary.image(m.group(1), img))

            summary_writer = tf.summary.FileWriter(self.checkpoint_dir,
                                                   tf.get_default_graph())

            init = tf.global_variables_initializer()
            saver = tf.train.Saver()

            with tf.Session() as sess:
                sess.run(init)

                checkpoint = tf.train.latest_checkpoint(self.checkpoint_dir)
                if checkpoint:
                    # existing checkpoint found, restoring...
                    if not self._cfg.subjects_train_val_test_file:
                        msg = 'Continue training, but no fixed subject assignments found. ' \
                              'Set subjects_train_val_test_file in config.'
                        logger.error(msg)
                        raise RuntimeError(msg)

                    logger.info('Restoring from ' + checkpoint)
                    saver = tf.train.import_meta_graph(checkpoint + '.meta')
                    saver.restore(sess, checkpoint)
                    sess.run(epoch_checkpoint.assign_add(tf.constant(1)))
                    self._best_r2_score = best_r2_score_checkpoint.eval()
                    logger.info('Continue with epoch %i (best r2: %f)',
                                epoch_checkpoint.eval(), self._best_r2_score)
                    self._checkpoint_idx = int(
                        re.match('.*/checkpoint-.*(\d+).ckpt',
                                 checkpoint).group(1))

                    # # load column multipliers from file if available
                    # cfg_file = os.path.join(self.checkpoint_dir, 'config.json')
                    # if os.path.exists(cfg_file):
                    #     cfg_tmp = Configuration.load(cfg_file)
                    #     if len(cfg_tmp.z_column_multipliers) > 0:
                    #         self._regression_column_multipliers = np.array(cfg_tmp.z_column_multipliers)
                else:
                    # new training, write config to checkpoint dir
                    self.write_settings()

                summary_writer.add_session_log(
                    tf.SessionLog(status=tf.SessionLog.START),
                    global_step=global_step.eval())

                for epoch in range(epoch_checkpoint.eval(),
                                   self._cfg.epochs + 1):
                    self.set_seed(epoch=epoch)

                    epoch_start_time = timeit.default_timer()
                    loss_sum = 0
                    r2_validation = None
                    spearmanR_validation = -10

                    # training (enable data augmentation)
                    self._data_store.set_transforms_enabled(True)
                    for batch in training_loader:
                        print('.', end='', flush=True)

                        feed_dict = self.batch_to_feed_dict(
                            x, y, d, is_train, batch, True)

                        # perform training step
                        _, loss_step = sess.run([train_op, loss],
                                                feed_dict=feed_dict)
                        loss_sum = loss_sum + loss_step * len(batch)

                    # disable transformations (data augmentation) for validation
                    self._data_store.set_transforms_enabled(False)

                    # loss on training data
                    cost_train = loss_sum / len(self._subjects_train)

                    if epoch % self._cfg.log_num_epoch == 0:
                        # loss on validation set
                        cost_validation = self.evaluate_loss(
                            sess, loss, validation_loader, x, y, d, is_train)
                        cost_validation_str = '{:.16f}'.format(cost_validation)

                    else:
                        print()
                        cost_validation = None
                        cost_validation_str = '-'

                    logger.info(
                        'Epoch:{:4.0f}, Loss train: {:.16f}, Loss validation: {}, lr: {:.16f}, dt={:.1f}s'
                        .format(epoch, cost_train, cost_validation_str,
                                lr_get(),
                                timeit.default_timer() - epoch_start_time))

                    # don't write loss for first epoch (usually very high) to avoid scaling issue in graph
                    if epoch > 0:
                        # write summary
                        summary = tf.Summary()
                        summary.value.add(tag='loss_train',
                                          simple_value=cost_train)
                        if cost_validation:
                            summary.value.add(tag='loss_validation',
                                              simple_value=cost_validation)
                        summary_writer.add_summary(summary, epoch)

                        summary_op = tf.summary.merge([sum_lr])
                        summary_str = sess.run(summary_op)
                        summary_writer.add_summary(summary_str, epoch)

                    if epoch % self._cfg.log_eval_num_epoch == 0 and epoch > 0:
                        # calculate and log R2 score on training and validation set
                        eval_start_time = timeit.default_timer()
                        predictions_train, gt_train, subjects_train, _, _ = self.predict(
                            sess, net, training_loader, x, y, d, is_train)
                        predictions_validation, gt_validation, subjects_validation, _, _ = self.predict(
                            sess, net, validation_loader, x, y, d, is_train)
                        r2_train = metrics.r2_score(gt_train,
                                                    predictions_train)
                        r2_validation = metrics.r2_score(
                            gt_validation, predictions_validation)
                        spearmanR_train, _ = scipy.stats.spearmanr(
                            gt_train, predictions_train)
                        spearmanR_validation, _ = scipy.stats.spearmanr(
                            gt_validation, predictions_validation)

                        gt_train_class = assign_class(gt_train)
                        predictions_train_class = assign_class(
                            np.squeeze(predictions_train))
                        gt_validation_class = assign_class(gt_validation)
                        predictions_validation_class = assign_class(
                            np.squeeze(predictions_validation))

                        trainCorrectClass = np.equal(gt_train_class,
                                                     predictions_train_class)
                        valCorrectClass = np.equal(
                            gt_validation_class, predictions_validation_class)

                        traincorrect = np.sum(trainCorrectClass) / len(
                            trainCorrectClass)
                        valcorrect = np.sum(valCorrectClass) / len(
                            valCorrectClass)

                        logger.info(
                            'Epoch:{:4.0f}, R2 train: {:.3f}, R2 validation: {:.8f}, sRho train: {:.3f}, sRho validation: {:.3f}, cl.acc training: {:.1%}, cl.acc validation: {:.1%},  dt={:.1f}s'
                            .format(epoch, r2_train, r2_validation,
                                    spearmanR_train, spearmanR_validation,
                                    traincorrect, valcorrect,
                                    timeit.default_timer() - eval_start_time))

                        # write csv with intermediate results
                        self.write_results_csv(
                            'results_train-{0:04.0f}.csv'.format(epoch),
                            predictions_train, gt_train, subjects_train)
                        self.write_results_csv(
                            'results_validate-{0:04.0f}.csv'.format(epoch),
                            predictions_validation, gt_validation,
                            subjects_validation)

                        summary = tf.Summary()
                        summary.value.add(tag='r2_train',
                                          simple_value=r2_train)
                        summary.value.add(tag='r2_validation',
                                          simple_value=r2_validation)
                        summary.value.add(tag='SpearmanRho_train',
                                          simple_value=spearmanR_train)
                        summary.value.add(tag='SpearmanRho_validation',
                                          simple_value=spearmanR_validation)
                        summary.value.add(tag='Classification_Accuracy_train',
                                          simple_value=traincorrect)
                        summary.value.add(
                            tag='Classification_Accuracy_validation',
                            simple_value=valcorrect)

                        summary_writer.add_summary(summary, epoch)

                    #if epoch % self._cfg.log_num_epoch == 0 and epoch > 0:
                    if epoch % 1 == 0 and epoch > 0:
                        # plot prediction vs. ground truth on training and validation set
                        plt.ioff()

                        gt_train_class = assign_class(gt_train)
                        predictions_train_class = assign_class(
                            np.squeeze(predictions_train))
                        gt_validation_class = assign_class(gt_validation)
                        predictions_validation_class = assign_class(
                            np.squeeze(predictions_validation))
                        spearmanR_validation, _ = scipy.stats.spearmanr(
                            gt_validation, predictions_validation)
                        trainCorrectClass = np.equal(gt_train_class,
                                                     predictions_train_class)
                        valCorrectClass = np.equal(
                            gt_validation_class, predictions_validation_class)

                        traincorrect = np.sum(trainCorrectClass) / len(
                            trainCorrectClass)
                        valcorrect = np.sum(valCorrectClass) / len(
                            valCorrectClass)

                        p0 = (sns.jointplot(
                            gt_train,
                            np.squeeze(predictions_train),
                            xlim=(0, np.max(np.append(gt_train,
                                                      gt_validation))),
                            ylim=(0, np.max(np.append(gt_train,
                                                      gt_validation))),
                            kind="reg",
                            stat_func=metrics.r2_score).set_axis_labels(
                                "GT", "Prediction"))
                        p1 = (sns.jointplot(
                            gt_validation,
                            np.squeeze(predictions_validation),
                            xlim=(0, np.max(np.append(gt_train,
                                                      gt_validation))),
                            ylim=(0, np.max(np.append(gt_train,
                                                      gt_validation))),
                            kind="reg",
                            stat_func=metrics.r2_score).set_axis_labels(
                                "GT", "Prediction"))

                        p0.ax_joint.set_title(
                            'Training \n Accuracy: {:.1%}'.format(
                                traincorrect),
                            pad=-18)
                        p1.ax_joint.set_title(
                            'Validation \n Accuracy: {:.1%}'.format(
                                valcorrect),
                            pad=-18)

                        fig = plt.figure(figsize=(16, 8))
                        gs = gridspec.GridSpec(1, 2)

                        mg0 = SeabornFig2Grid(p0, fig, gs[0])
                        mg1 = SeabornFig2Grid(p1, fig, gs[1])

                        gs.tight_layout(fig)

                        # gs.update(top=0.7)
                        plt.suptitle("Epoch " + str(epoch))
                        plt.savefig(
                            os.path.join(
                                self._plotdir, self._cfg.model + "epoch_" +
                                str(epoch).zfill(4) + ".png"))
                        plt.close(fig)
                        print('Regression plot saved.')

                        # predictions_train, gt_train, subjects_train, fc100_train, fc20_train = self.predict(sess, net,
                        #                                                                                     training_loader,
                        #                                                                                     x, y, d,
                        #                                                                                     is_train)
                        # predictions_test, gt_test, subjects_test, fc100_test, fc20_test = self.predict(sess, net,
                        #                                                                                testing_loader,
                        #                                                                                x, y, d,
                        #                                                                                is_train)
                        #
                        # print(fc100_test[0])
                        # print(fc100_train[0])
                        # print(fc20_test[0])
                        # print(fc20_train[0])
                        #
                        # print("?&&&&&&&&&&&&&&")
                        #
                        # print(len(fc100_test))
                        # print(len(fc100_train))
                        # print(len(fc20_test))
                        # print(len(fc20_train))
                        #
                        # print("?&&&&&&&&&&&&&&")
                        # print(fc100_test[0].shape)
                        # print(fc100_train[0].shape)
                        # print(fc20_test[0].shape)
                        # print(fc20_train[0].shape)
                        #
                        # self.write_results_csv('results_train.csv', predictions_train, gt_train, subjects_train)
                        # self.write_TESTresults_csv('results_test.csv', predictions_test, subjects_test)
                        #
                        # self.write_deepfeat_csv('fc100_test.csv', fc100_test[0], subjects_test)
                        # self.write_deepfeat_csv('fc20_test.csv', fc20_test[0], subjects_test)
                        #
                        # self.write_deepfeat_csv('fc100_train.csv', fc100_train[0], subjects_train)
                        # self.write_deepfeat_csv('fc20_train.csv', fc20_train[0], subjects_train)

                    if epoch % self._cfg.visualize_layer_num_epoch == 0 and len(
                            sum_histograms) > 0:
                        # write histogram summaries and kernel visualization
                        summary_op = tf.summary.merge(sum_histograms)
                        summary_str = sess.run(summary_op)
                        summary_writer.add_summary(summary_str, epoch)

                        for idx, kernel_name in enumerate(kernel_tensor_names):
                            # visualize weights of kernel layer from a middle slice
                            kernel_weights = graph.get_tensor_by_name(
                                kernel_name).eval()
                            # make last axis the first
                            kernel_weights = np.moveaxis(kernel_weights, -1, 0)

                            if len(kernel_weights.shape) > 4:
                                # 3d convolution, remove last (single) channel
                                kernel_weights = kernel_weights[:, :, :, :, 0]

                            if kernel_weights.shape[3] > 1:
                                # multiple channels, take example from middle slide
                                slice_num = int(kernel_weights.shape[3] / 2)
                                kernel_weights = kernel_weights[:, :, :,
                                                                slice_num:
                                                                slice_num + 1]

                            grid = visualize.make_grid(kernel_weights)[
                                np.newaxis, :, :, np.newaxis]
                            sess.run(kernel_tensors[idx].assign(grid))

                        # summary_op = tf.summary.merge(sum_kernels)
                        summary_str = sess.run(summary_op)
                        summary_writer.add_summary(summary_str, epoch)

                    summary_writer.flush()

                    # epoch done
                    self.checkpoint_safer(sess, saver, epoch_checkpoint, epoch,
                                          best_r2_score_checkpoint,
                                          spearmanR_validation)

                summary_writer.close()
                logger.info('Training done.')

                if self._best_r2_score > 0:
                    # restore checkpoint of best R2 score
                    checkpoint = os.path.join(self.checkpoint_dir,
                                              'checkpoint-best-r-2.ckpt')
                    saver = tf.train.import_meta_graph(checkpoint + '.meta')
                    saver.restore(sess, checkpoint)
                    logger.info(
                        'RESTORED best-r-2 checkpoint. Epoch: {}, R2: {:.8f}'.
                        format(epoch_checkpoint.eval(),
                               best_r2_score_checkpoint.eval()))

                # disable transformations (data augmentation) for test
                self._data_store.set_transforms_enabled(False)

                predictions_train, gt_train, subjects_train, fc100_train, fc20_train = self.predict(
                    sess, net, training_loader, x, y, d, is_train)
                predictions_val, gt_val, subjects_val, fc100_val, fc20_val = self.predict(
                    sess, net, validation_loader, x, y, d, is_train)
                predictions_test, gt_test, subjects_test, fc100_test, fc20_test = self.predict(
                    sess, net, testing_loader, x, y, d, is_train)

                print(subjects_test)
                print(len(subjects_test))
                print(predictions_test)

                self.write_results_csv(
                    'results_train_epoch' + str(epoch_checkpoint.eval()) +
                    '.csv', predictions_train, gt_train, subjects_train)
                self.write_TESTresults_csv(
                    'results_test_epoch' + str(epoch_checkpoint.eval()) +
                    '.csv', predictions_test, subjects_test)

                self.write_deepfeat_csv(
                    'fc100_test_epoch' + str(epoch_checkpoint.eval()) + '.csv',
                    fc100_test, subjects_test)
                self.write_deepfeat_csv(
                    'fc20_test_epoch' + str(epoch_checkpoint.eval()) + '.csv',
                    fc20_test, subjects_test)

                self.write_deepfeat_csv(
                    'fc100_val_epoch' + str(epoch_checkpoint.eval()) + '.csv',
                    fc100_val, subjects_val)
                self.write_deepfeat_csv(
                    'fc20_val_epoch' + str(epoch_checkpoint.eval()) + '.csv',
                    fc20_val, subjects_val)

                self.write_deepfeat_csv(
                    'fc100_train_epoch' + str(epoch_checkpoint.eval()) +
                    '.csv', fc100_train, subjects_train)
                self.write_deepfeat_csv(
                    'fc20_train_epoch' + str(epoch_checkpoint.eval()) + '.csv',
                    fc20_train, subjects_train)

                # # save last conv-layer output as image for inspection
                # convoutimg_test = sitk.GetImageFromArray(np.swapaxes(conv3out_test[0],0,2))
                # convoutimg_train = sitk.GetImageFromArray(np.swapaxes(conv3out_train[0],0,2))
                #
                # sitk.WriteImage(convoutimg_test, os.path.join(self.checkpoint_dir, 'convout3_test.nii.gz'))
                # sitk.WriteImage(convoutimg_train, os.path.join(self.checkpoint_dir, 'convout3_train.nii.gz'))

                gt_train_class = assign_class(gt_train)
                predictions_train_class = assign_class(
                    np.squeeze(predictions_train))
                #gt_test_class = assign_class(gt_test)
                predictions_test_class = assign_class(
                    np.squeeze(predictions_test))

                # Note: use scaled metrics for MSE and unscaled (original) for R^2
                if len(gt_train) > 0:
                    accuracy_train = metrics.mean_squared_error(
                        gt_train, predictions_train)
                    r2_train = metrics.r2_score(gt_train, predictions_train)

                    spearmanR_train, _ = scipy.stats.spearmanr(
                        gt_train, predictions_train)
                    trainCorrectClass = np.equal(gt_train_class,
                                                 predictions_train_class)
                    traincorrect = np.sum(trainCorrectClass) / len(
                        trainCorrectClass)

                if len(gt_train) == 0:
                    accuracy_train = 0
                    r2_train = 0
                    traincorrect = 0

                    #if len(gt_test) > 0:
                    accuracy_test = metrics.mean_squared_error(
                        gt_test, predictions_test)
                    r2_test = metrics.r2_score(gt_test, predictions_test)
                    #spearmanR_test, _ = scipy.stats.spearmanr(gt_test, predictions_test)
                    #testCorrectClass = np.equal(gt_test_class, predictions_test_class)
                    #testcorrect = np.sum(testCorrectClass) / len(testCorrectClass)

                else:
                    accuracy_test = 0
                    r2_test = 0
                    spearmanR_test = 0
                    testcorrect = 0

                #s = analyze_score.print_summary(subjects_test, ['survival'], predictions_test, gt_test)
                #logger.info('Summary:\n%s-------', s)

                logger.info(
                    'TRAIN accuracy(mse): {:.8f}, r2: {:.8f}, Spearman Rho: {:.8f}, Classification Accuracy: {:.1%}'
                    .format(accuracy_train, r2_train, spearmanR_train,
                            traincorrect))
    def train(self):
        self.set_seed(epoch=0)

        transform = self.get_transform()
        self._data_store = data_storage.DataStore(self._cfg.hdf_file,
                                                  transform)
        dataset = self._data_store.dataset

        self.assign_subjects()

        # prepare loaders and extractors
        training_loader = self._data_store.get_loader(self._cfg.batch_size,
                                                      self._subjects_train,
                                                      self._num_workers)
        validation_loader = self._data_store.get_loader(
            self._cfg.batch_size_eval, self._subjects_validate,
            self._num_workers)
        testing_loader = self._data_store.get_loader(self._cfg.batch_size_eval,
                                                     self._subjects_test,
                                                     self._num_workers)

        # train_extractor = miapy_extr.ComposeExtractor(
        #     [miapy_extr.DataExtractor(categories=('images',)),
        #      miapy_extr.DataExtractor(entire_subject=True, categories=('age', 'resectionstatus', 'volncr', 'voled', 'volet', 'etrimwidth', 'etgeomhet',
        #                                                                'rim_q1_clipped', 'rim_q2_clipped', 'rim_q3_clipped')),
        #      miapy_extr.NamesExtractor(categories=('images',)),
        #      miapy_extr.SubjectExtractor()])

        train_extractor = miapy_extr.ComposeExtractor([
            miapy_extr.DataExtractor(categories=('images', )),
            #miapy_extr.DataExtractor(entire_subject=True, categories=('age', 'resectionstatus', 'survival', 'volncr', 'voled', 'volet', 'etrimwidth', 'etgeomhet',
            #                                                          'rim_q1_clipped', 'rim_q2_clipped', 'rim_q3_clipped')),
            #miapy_extr.DataExtractor(entire_subject=True, categories=(
            #'age', 'resectionstatus', 'survival', 'volncr', 'voled', 'volet')),
            # miapy_extr.DataExtractor(entire_subject=True, categories=(
            #     'age', 'resectionstatus', 'survclass', 'volncr', 'voled', 'volet', 'etrimwidth', 'etgeomhet',
            #     'rim_q1_clipped', 'rim_q2_clipped', 'rim_q3_clipped')),
            miapy_extr.DataExtractor(entire_subject=True,
                                     categories=('age', 'resectionstatus',
                                                 'survclass')),
            miapy_extr.NamesExtractor(categories=('images', )),
            miapy_extr.SubjectExtractor()
        ])

        dataset.set_extractor(train_extractor)

        # # read all labels to calculate multiplier
        # column_values, column_names = self._data_store.get_all_metrics()
        # self._regression_column_ids = np.array([column_names.index(name) for name in self._cfg.regression_columns])
        # self._regression_column_multipliers = np.max(column_values[:, self._regression_column_ids], axis=0)

        # alexnet.SCALE = float(self._data_store.get_intensity_scale_max())

        n_batches = int(
            np.ceil(len(self._subjects_train) / self._cfg.batch_size))

        logger.info('Net: {}, scale: {}'.format(
            inspect.getsource(self.get_python_obj(self._cfg.model)),
            alexnet.SCALE))
        logger.info('Train: {}, Validation: {}, Test: {}'.format(
            len(self._subjects_train), len(self._subjects_validate),
            len(self._subjects_test)))
        logger.info('n_batches: {}'.format(n_batches))
        logger.info(self._cfg)
        logger.info('checkpoints dir: {}'.format(self.checkpoint_dir))

        shape = dataset.direct_extract(
            train_extractor,
            0)['images'].shape  # extract a subject to obtain shape
        print("Shape: " + str(shape))
        shape_y = dataset.direct_extract(
            train_extractor,
            0)['survclass'].shape  # extract a subject to obtain shape
        #print("Shape_y: " + str(shape_y))

        with tf.Graph().as_default() as graph:
            self.set_seed(epoch=0)  # set again as seed is per graph

            x = tf.placeholder(tf.float32, (None, ) + shape, name='x')
            y = tf.placeholder(tf.float32, (None, ) + shape_y, name='y')
            d = tf.placeholder(tf.float32, (None, 2),
                               name='d')  # age, resectionstate
            is_train = tf.placeholder(tf.bool, shape=(), name='is_train')

            global_step = tf.train.get_or_create_global_step()
            epoch_checkpoint = tf.Variable(0, name='epoch')
            # best_r2_score_checkpoint = tf.Variable(0.0, name='best_r2_score', dtype=tf.float64)
            best_xent_checkpoint = tf.Variable(0.0,
                                               name='best_xent',
                                               dtype=tf.float64)

            net = self.get_python_obj(self._cfg.model)({
                'x': x,
                'y': y,
                'd': d,
                'is_train': is_train
            })
            #print(net["classes"])
            #print(net["probabilities"])
            #print(net["logits"])
            #logits_test = tf.Session().run(net["logits"])
            #print(logits_test)
            #loss = tf.losses.softmax_cross_entropy(onehot_labels=tf.one_hot(tf.cast(y, tf.uint8), 3, on_value=1, off_value=0), logits=net["logits"])
            # loss = tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=net)
            loss = tf.losses.softmax_cross_entropy(
                onehot_labels=y,
                logits=net["logits"],
                weights=1.0,
                label_smoothing=0,
                scope=None,
                loss_collection=tf.GraphKeys.LOSSES,
                reduction=tf.losses.Reduction.SUM_BY_NONZERO_WEIGHTS)
            #tf.losses.softmax_cross_entropy(y, logits=net["logits"])
            learning_rate = None
            optimizer = None
            if self._cfg.optimizer == 'SGD':
                learning_rate = tf.train.exponential_decay(
                    self._cfg.learning_rate, global_step,
                    self._cfg.learning_rate_decay_steps,
                    self._cfg.learning_rate_decay_rate)
                optimizer = tf.train.GradientDescentOptimizer(
                    learning_rate=learning_rate)
                lr_get = lambda: optimizer._learning_rate.eval()
            elif self._cfg.optimizer == 'Adam':
                learning_rate = tf.Variable(self._cfg.learning_rate, name='lr')
                optimizer = tf.train.AdamOptimizer(
                    learning_rate=self._cfg.learning_rate)
                lr_get = lambda: optimizer._lr

            with tf.control_dependencies(
                    tf.get_collection(
                        tf.GraphKeys.UPDATE_OPS)):  # required for batch_norm
                train_op = optimizer.minimize(loss=loss,
                                              global_step=global_step)

            sum_lr = tf.summary.scalar('learning_rate', learning_rate)

            # collect variables to add to summaries (histograms, kernel weight images)
            sum_histograms = []
            sum_kernels = []
            kernel_tensors = []
            kernel_tensor_names = []
            for v in tf.global_variables():
                m = re.match('NET/(layer\d+_\w+)/(kernel|bias):0', v.name)
                if m:
                    var = graph.get_tensor_by_name(v.name)
                    sum_histograms.append(
                        tf.summary.histogram(
                            '{}/{}'.format(m.group(1), m.group(2)), var))

                    if m.group(2) == 'kernel' and m.group(1).endswith('conv'):
                        kernel_tensor_names.append(v.name)
                        h, w = visualize.get_grid_size(
                            var.get_shape().as_list())
                        img = tf.Variable(tf.zeros([1, h, w, 1]))
                        kernel_tensors.append(img)
                        sum_kernels.append(tf.summary.image(m.group(1), img))

            summary_writer = tf.summary.FileWriter(self.checkpoint_dir,
                                                   tf.get_default_graph())

            init = tf.global_variables_initializer()
            saver = tf.train.Saver()

            with tf.Session() as sess:
                sess.run(init)

                checkpoint = tf.train.latest_checkpoint(self.checkpoint_dir)
                if checkpoint:
                    # existing checkpoint found, restoring...
                    if not self._cfg.subjects_train_val_test_file:
                        msg = 'Continue training, but no fixed subject assignments found. ' \
                              'Set subjects_train_val_test_file in config.'
                        logger.error(msg)
                        raise RuntimeError(msg)

                    logger.info('Restoring from ' + checkpoint)
                    saver = tf.train.import_meta_graph(checkpoint + '.meta')
                    saver.restore(sess, checkpoint)
                    sess.run(epoch_checkpoint.assign_add(tf.constant(1)))
                    self._xent = best_xent_checkpoint.eval()
                    #logger.info('Continue with epoch %i (best r2: %f)', epoch_checkpoint.eval(), self._best_r2_score)
                    logger.info('Continue with epoch %i (best xent: %f)',
                                epoch_checkpoint.eval(), self._xent)
                    self._checkpoint_idx = int(
                        re.match('.*/checkpoint-.*(\d+).ckpt',
                                 checkpoint).group(1))

                    # load column multipliers from file if available
                    cfg_file = os.path.join(self.checkpoint_dir, 'config.json')
                    # if os.path.exists(cfg_file):
                    #     cfg_tmp = Configuration.load(cfg_file)
                    #     if len(cfg_tmp.z_column_multipliers) > 0:
                    #         self._regression_column_multipliers = np.array(cfg_tmp.z_column_multipliers)
                else:
                    # new training, write config to checkpoint dir
                    self.write_settings()

                summary_writer.add_session_log(
                    tf.SessionLog(status=tf.SessionLog.START),
                    global_step=global_step.eval())

                for epoch in range(epoch_checkpoint.eval(),
                                   self._cfg.epochs + 1):
                    self.set_seed(epoch=epoch)

                    epoch_start_time = timeit.default_timer()
                    loss_sum = 0
                    r2_validation = None

                    # training (enable data augmentation)
                    self._data_store.set_transforms_enabled(True)
                    for batch in training_loader:
                        print('.', end='', flush=True)

                        #print("y type: " + str(y))
                        feed_dict = self.batch_to_feed_dict(
                            x, y, d, is_train, batch, True)

                        # perform training step
                        _, loss_step = sess.run([train_op, loss],
                                                feed_dict=feed_dict)
                        #print("loss_sum: " +str(loss_sum))
                        # print("loss_step: " +str(np.mean(loss_step)))
                        loss_sum = loss_sum + (loss_step) * len(batch)

                    # disable transformations (data augmentation) for validation
                    self._data_store.set_transforms_enabled(False)

                    # loss on training data
                    cost_train = loss_sum / len(self._subjects_train)

                    if epoch % self._cfg.log_num_epoch == 0:
                        # loss on validation set
                        #cost_validation = np.mean(self.evaluate_loss(sess, loss, validation_loader, x, y, d, is_train))
                        cost_validation = self.evaluate_loss(
                            sess, loss, validation_loader, x, y, d, is_train)
                        cost_validation_str = '{:.16f}'.format(cost_validation)

                    else:
                        print()
                        cost_validation = None
                        cost_validation_str = '-'

                    logger.info(
                        'Epoch:{:4.0f}, Loss train: {:.16f}, Loss validation: {}, lr: {:.16f}, dt={:.1f}'
                        .format(epoch, cost_train, cost_validation_str,
                                lr_get(),
                                timeit.default_timer() - epoch_start_time))

                    # don't write loss for first epoch (usually very high) to avoid scaling issue in graph
                    if epoch > 0:
                        # write summary
                        summary = tf.Summary()
                        summary.value.add(tag='loss_train',
                                          simple_value=cost_train)
                        if cost_validation:
                            summary.value.add(tag='loss_validation',
                                              simple_value=cost_validation)
                        summary_writer.add_summary(summary, epoch)

                        summary_op = tf.summary.merge([sum_lr])
                        summary_str = sess.run(summary_op)
                        summary_writer.add_summary(summary_str, epoch)

                    if epoch % self._cfg.log_eval_num_epoch == 0 and epoch > 0:
                        # calculate and log R2 score on training and validation set
                        eval_start_time = timeit.default_timer()
                        predictions_train, gt_train, subjects_train = self.predict(
                            sess, net, training_loader, x, y, d, is_train)
                        predictions_validation, gt_validation, subjects_validation = self.predict(
                            sess, net, validation_loader, x, y, d, is_train)

                        #predictions_train_class = predictions_train['classes']
                        #predictions_validation_class = predictions_validation['classes']

                        trainCorrectClass = np.equal(gt_train,
                                                     predictions_train)
                        valCorrectClass = np.equal(gt_validation,
                                                   predictions_validation)

                        traincorrect = np.sum(trainCorrectClass) / len(
                            trainCorrectClass)
                        valcorrect = np.sum(valCorrectClass) / len(
                            valCorrectClass)

                        # confusion matrix
                        confmat_train = metrics.confusion_matrix(
                            gt_train, predictions_train)
                        # precision_train = np.empty(3)
                        # precision_train[:] = np.nan
                        # recall_train = np.empty(3)
                        # recall_train[:] = np.nan

                        confmat_val = metrics.confusion_matrix(
                            gt_validation, predictions_validation)
                        # precision_val = np.empty(3)
                        # precision_val[:] = np.nan
                        # recall_val = np.empty(3)
                        # recall_val[:] = np.nan
                        # for cl in range(0,3):
                        #     precision_train[cl] = confmat_train[cl][cl] / (confmat_train[0][cl] + confmat_train[1][cl] + confmat_train[2][cl])
                        #     recall_train[cl] = confmat_train[cl][cl] / (confmat_train[cl][0] + confmat_train[cl][1] + confmat_train[cl][2])
                        #     precision_val[cl] = confmat_val[cl][cl] / (confmat_val[0][cl] + confmat_val[1][cl] + confmat_val[2][cl])
                        #     recall_val[cl] = confmat_val[cl][cl] / (confmat_val[cl][0] + confmat_val[cl][1] + confmat_val[cl][2])

                        precision_train = metrics.precision_score(
                            gt_train, predictions_train, average=None)
                        recall_train = metrics.recall_score(gt_train,
                                                            predictions_train,
                                                            average=None)

                        precision_val = metrics.precision_score(
                            gt_validation,
                            predictions_validation,
                            average=None)
                        recall_val = metrics.recall_score(
                            gt_validation,
                            predictions_validation,
                            average=None)

                        tp_train = np.zeros(3)
                        fp_train = np.zeros(3)
                        fn_train = np.zeros(3)
                        tn_train = np.zeros(3)
                        tp_val = np.zeros(3)
                        fp_val = np.zeros(3)
                        fn_val = np.zeros(3)
                        tn_val = np.zeros(3)
                        for cl in range(0, 3):
                            tp_train[cl], fp_train[cl], fn_train[cl], tn_train[
                                cl] = process_cm(confmat_train, cl)
                            tp_val[cl], fp_val[cl], fn_val[cl], tn_val[
                                cl] = process_cm(confmat_val, cl)

                        specificity_train = tn_train / (tn_train + fp_train)
                        specificity_val = tn_val / (tn_val + fp_val)

                        #logger.info('Epoch:{:4.0f}, R2 train: {:.3f}, R2 validation: {:.8f}, sRho train: {:.3f}, sRho validation: {:.3f}, cl.acc training: {:.1%}, cl.acc validation: {:.1%},  dt={:.1f}s'.format(epoch, r2_train, r2_validation, spearmanR_train, spearmanR_validation, traincorrect, valcorrect, timeit.default_timer() - eval_start_time))
                        logger.info(
                            'Epoch:{:4.0f}, cl.acc training: {:.1%}, cl.acc validation: {:.1%},  dt={:.1f}s'
                            .format(epoch, traincorrect, valcorrect,
                                    timeit.default_timer() - eval_start_time))

                        # write csv with intermediate results
                        self.write_results_csv(
                            'results_train-{0:04.0f}.csv'.format(epoch),
                            predictions_train, gt_train, subjects_train)
                        self.write_results_csv(
                            'results_validate-{0:04.0f}.csv'.format(epoch),
                            predictions_validation, gt_validation,
                            subjects_validation)

                        summary = tf.Summary()
                        #summary.value.add(tag='r2_train', simple_value=r2_train)
                        #summary.value.add(tag='r2_validation', simple_value=r2_validation)
                        #summary.value.add(tag='SpearmanRho_train', simple_value=spearmanR_train)
                        #summary.value.add(tag='SpearmanRho_validation', simple_value=spearmanR_validation)
                        summary.value.add(tag='XEnt_train',
                                          simple_value=cost_train)
                        summary.value.add(tag='XEnt_validation',
                                          simple_value=cost_validation)
                        summary.value.add(tag='Classification_Accuracy_train',
                                          simple_value=traincorrect)
                        summary.value.add(
                            tag='Classification_Accuracy_validation',
                            simple_value=valcorrect)
                        summary.value.add(tag='Precision_STS_train',
                                          simple_value=precision_train[0])
                        summary.value.add(tag='Precision_STS_val',
                                          simple_value=precision_val[0])
                        summary.value.add(tag='Precision_MTS_train',
                                          simple_value=precision_train[1])
                        summary.value.add(tag='Precision_MTS_val',
                                          simple_value=precision_val[1])
                        summary.value.add(tag='Precision_LTS_train',
                                          simple_value=precision_train[2])
                        summary.value.add(tag='Precision_LTS_val',
                                          simple_value=precision_val[2])
                        summary.value.add(tag='Recall_STS_train',
                                          simple_value=recall_train[0])
                        summary.value.add(tag='Recall_STS_val',
                                          simple_value=recall_val[0])
                        summary.value.add(tag='Recall_MTS_train',
                                          simple_value=recall_train[1])
                        summary.value.add(tag='Recall_MTS_val',
                                          simple_value=recall_val[1])
                        summary.value.add(tag='Recall_LTS_train',
                                          simple_value=recall_train[2])
                        summary.value.add(tag='Recall_LTS_val',
                                          simple_value=recall_val[2])

                        summary.value.add(tag='Specificity_STS_train',
                                          simple_value=specificity_train[0])
                        summary.value.add(tag='Specificity_STS_val',
                                          simple_value=specificity_val[0])
                        summary.value.add(tag='Specificity_MTS_train',
                                          simple_value=specificity_train[1])
                        summary.value.add(tag='Specificity_MTS_val',
                                          simple_value=specificity_val[1])
                        summary.value.add(tag='Specificity_LTS_train',
                                          simple_value=specificity_train[2])
                        summary.value.add(tag='Specificity_LTS_val',
                                          simple_value=specificity_val[2])

                        summary_writer.add_summary(summary, epoch)

                    if epoch % self._cfg.visualize_layer_num_epoch == 0 and len(
                            sum_histograms) > 0:
                        # write histogram summaries and kernel visualization
                        summary_op = tf.summary.merge(sum_histograms)
                        summary_str = sess.run(summary_op)
                        summary_writer.add_summary(summary_str, epoch)

                        for idx, kernel_name in enumerate(kernel_tensor_names):
                            # visualize weights of kernel layer from a middle slice
                            kernel_weights = graph.get_tensor_by_name(
                                kernel_name).eval()
                            # make last axis the first
                            kernel_weights = np.moveaxis(kernel_weights, -1, 0)

                            if len(kernel_weights.shape) > 4:
                                # 3d convolution, remove last (single) channel
                                kernel_weights = kernel_weights[:, :, :, :, 0]

                            if kernel_weights.shape[3] > 1:
                                # multiple channels, take example from middle slide
                                slice_num = int(kernel_weights.shape[3] / 2)
                                kernel_weights = kernel_weights[:, :, :,
                                                                slice_num:
                                                                slice_num + 1]

                            grid = visualize.make_grid(kernel_weights)[
                                np.newaxis, :, :, np.newaxis]
                            sess.run(kernel_tensors[idx].assign(grid))

                        summary_op = tf.summary.merge(sum_kernels)
                        summary_str = sess.run(summary_op)
                        summary_writer.add_summary(summary_str, epoch)

                    summary_writer.flush()

                    # epoch done
                    self.checkpoint_safer(sess, saver, epoch_checkpoint, epoch,
                                          best_xent_checkpoint,
                                          cost_validation)

                summary_writer.close()
                logger.info('Training done.')

                if self._xent > 0:
                    # restore checkpoint of best R2 score
                    checkpoint = os.path.join(self.checkpoint_dir,
                                              'checkpoint-best-r-2.ckpt')
                    saver = tf.train.import_meta_graph(checkpoint + '.meta')
                    saver.restore(sess, checkpoint)
                    logger.info(
                        'RESTORED best-r-2 checkpoint. Epoch: {}, R2: {:.8f}'.
                        format(epoch_checkpoint.eval(),
                               best_xent_checkpoint.eval()))

                # disable transformations (data augmentation) for test
                self._data_store.set_transforms_enabled(False)

                predictions_train, gt_train, subjects_train = self.predict(
                    sess, net, training_loader, x, y, d, is_train)
                predictions_test, gt_test, subjects_test = self.predict(
                    sess, net, testing_loader, x, y, d, is_train)

                self.write_results_csv('results_train.csv', predictions_train,
                                       gt_train, subjects_train)
                self.write_results_csv('results_test.csv', predictions_test,
                                       gt_test, subjects_test)

                #predictions_train_class = predictions_train['classes']
                #predictions_test_class = predictions_test['classes']

                trainCorrectClass = np.equal(gt_train, predictions_train)
                testCorrectClass = np.equal(gt_test, predictions_test)

                traincorrect = np.sum(trainCorrectClass) / len(
                    trainCorrectClass)
                testcorrect = np.sum(testCorrectClass) / len(testCorrectClass)

                # confusion matrix
                confmat_train = metrics.confusion_matrix(
                    gt_train, predictions_train)
                confmat_test = metrics.confusion_matrix(
                    gt_test, predictions_test)

                precision_train = metrics.precision_score(gt_train,
                                                          predictions_train,
                                                          average=None)
                recall_train = metrics.recall_score(gt_train,
                                                    predictions_train,
                                                    average=None)

                precision_test = metrics.precision_score(gt_test,
                                                         predictions_test,
                                                         average=None)
                recall_test = metrics.recall_score(gt_test,
                                                   predictions_test,
                                                   average=None)

                tp_train = np.zeros(3)
                fp_train = np.zeros(3)
                fn_train = np.zeros(3)
                tn_train = np.zeros(3)
                tp_test = np.zeros(3)
                fp_test = np.zeros(3)
                fn_test = np.zeros(3)
                tn_test = np.zeros(3)
                for cl in range(0, 3):
                    tp_train[cl], fp_train[cl], fn_train[cl], tn_train[
                        cl] = process_cm(confmat_train, cl)
                    tp_test[cl], fp_test[cl], fn_test[cl], tn_test[
                        cl] = process_cm(confmat_test, cl)

                specificity_train = tn_train / (tn_train + fp_train)
                specificity_test = tn_test / (tn_test + fp_test)

                # Note: use scaled metrics for MSE and unscaled (original) for R^2
                if len(gt_train) > 0:
                    #accuracy_train = metrics.mean_squared_error(gt_train, predictions_train)
                    #r2_train = metrics.r2_score(gt_train, predictions_train)

                    #spearmanR_train, _ = scipy.stats.spearmanr(gt_train, predictions_train)
                    trainCorrectClass = np.equal(gt_train, predictions_train)
                    traincorrect = np.sum(trainCorrectClass) / len(
                        trainCorrectClass)

                if len(gt_train) == 0:
                    #accuracy_train = 0
                    #r2_train = 0
                    traincorrect = 0

                if len(gt_test) > 0:
                    #accuracy_test = metrics.mean_squared_error(gt_test, predictions_test)
                    #r2_test = metrics.r2_score(gt_test, predictions_test)
                    #spearmanR_test, _ = scipy.stats.spearmanr(gt_test, predictions_test)
                    testCorrectClass = np.equal(gt_test, predictions_test)
                    testcorrect = np.sum(testCorrectClass) / len(
                        testCorrectClass)

                else:
                    #accuracy_test = 0
                    #r2_test = 0
                    #spearmanR_test = 0
                    testcorrect = 0

                #s = analyze_score.print_summary(subjects_test, ['survival'], predictions_test, gt_test)
                #logger.info('Summary:\n%s-------', s)

                # logger.info('TRAIN: cl.acc: {:.1%}, Precision: {:.4f}:, Recall: {:.4f}, Specificity: {:.4f}'.format(traincorrect, precision_train, recall_train, specificity_train))
                # logger.info('TEST: cl.acc: {:.1%}, Precision: {:.4f}:, Recall: {:.4f}, Specificity: {:.4f}'.format(testcorrect, precision_test, recall_test, specificity_test))

                logger.info('TRAIN: cl.acc: ' + str(traincorrect) +
                            ', Precision: ' + str(precision_train) +
                            ', Recall: ' + str(recall_train) +
                            ', Specificity: ' + str(specificity_train))
                logger.info('TEST: cl.acc: ' + str(testcorrect) +
                            ', Precision: ' + str(precision_test) +
                            ', Recall: ' + str(recall_test) +
                            ', Specificity: ' + str(specificity_test))