Example #1
0
    def output_mean(self, z_vol, folder_split, vol_num, slice_idx=90):
        """Score the embedding-averaged synthesis of each output modality for one volume.

        For every output modality ``y``, the model outputs whose ``name``
        contains ``y`` are collected from ``z_vol`` (``z_vol[i][:, 0]`` is
        taken for output ``i``) and averaged along the first axis of that
        collection, i.e. across the embeddings. The average is compared with
        the ground truth from ``self.data.select_for_ids`` using
        ``ErrorMetrics``.

        Side effect: a PNG is written to ``<folder_split>/avg_emb_ims/``. It
        shows slice ``slice_idx`` of the averaged synthesis next to the ground
        truth, concatenated along axis 1. The directory is not created here;
        the caller is expected to create it.

        Args:
            z_vol: per-output predictions, e.g. the list returned by
                ``self.mm.model.predict``.
            folder_split: directory of the current split.
            vol_num: id of the volume being evaluated.
            slice_idx: index of the slice saved in the preview image
                (defaults to the previously hard-coded 90).

        Returns:
            dict mapping each output modality to its ``ErrorMetrics`` result.
        """
        err_avg_emb = dict()
        for y in self.output_modalities:
            # indices of the model outputs that synthesise modality ``y``
            z_idx = [outi for outi, out in enumerate(self.mm.model.outputs) if y in out.name]
            y_synth = [z_vol[zi][:, 0] for zi in z_idx]

            y_truth_mod = self.data.select_for_ids(y, [vol_num])
            if y_truth_mod.shape[1] > 1:
                # Keep only the middle entry along axis 1. ``//`` keeps the
                # slice bound an int under both Python 2 and true division.
                mid = y_truth_mod.shape[1] // 2
                y_truth_mod = y_truth_mod[:, mid:mid + 1]

            y_synth_mod = np.mean(y_synth, axis=0)
            scipy.misc.imsave(folder_split + '/avg_emb_ims/im_avg_emb' + str(vol_num) + '_' + str(slice_idx) + '.png',
                              np.concatenate([y_synth_mod[slice_idx], y_truth_mod[slice_idx, 0]], axis=1))
            # restore the singleton axis 1 so the shape lines up with y_truth_mod
            err = ErrorMetrics(np.expand_dims(y_synth_mod, axis=1), y_truth_mod)
            err_avg_emb[y] = err
        return err_avg_emb
Example #2
0
    def test_at_split(self, split_dict, folder_split, extra_image_vols=(1, 8, 12, 13)):
        """Evaluate the model on the test volumes of one split and save the results.

        For every embedding ``emb`` in ``range(self.mm.num_emb)`` and every
        output modality ``mod``, a CSV file
        ``<folder_split>/individual_results_emb_<emb>_mod_<mod>.csv`` is
        written. It has one row per test volume containing:

        - the volume id;
        - the per-embedding ``ErrorMetrics`` values;
        - the volume type;
        - the ``MSE_NBG`` of the embedding-averaged synthesis returned by
          ``self.output_mean``.

        Afterwards, example slices of the test volumes and of
        ``extra_image_vols`` are saved through ``ImageSaveCallback``.

        Args:
            split_dict: mapping with the keys 'train', 'validation' and
                'test', each holding a list of volume ids.
            folder_split: output directory of this split; it and its
                ``avg_emb_ims`` sub-directory are created if missing.
            extra_image_vols: volume ids whose example images are saved in
                addition to the test volumes (defaults to the previously
                hard-coded ids 1, 8, 12 and 13).
        """
        ids_train = split_dict['train']
        ids_valid = split_dict['validation']
        ids_test = split_dict['test']
        all_ids = sorted(ids_train + ids_valid + ids_test)
        train_set = set(ids_train)
        valid_set = set(ids_valid)

        metrics = [
            'MSE_NBG', 'MSE', 'SSIM_NBG', 'PSNR_NBG', 'SSIM', 'PSNR',
            'MSE_NBG_AVG_EMB'
        ]
        # All but the last metric come from the per-embedding ErrorMetrics.
        # The last column is filled from the embedding average.
        emb_metrics = metrics[:-1]
        csv_header = '#,' + ','.join(emb_metrics) + ', volume_type, MSE_NBG_AVG_EMB\n'
        row_pattern = "%d" + ", %.3f" * len(emb_metrics) + ', %s, %.3f\n'

        print('testing model on all volumes...')

        # python 'open' can't make the directory, so invoke 'os' to create the folder
        if not os.path.exists(folder_split):
            os.makedirs(folder_split)

        files_embs = {}
        try:
            # create files
            for emb in range(self.mm.num_emb):
                files = files_embs[emb] = {}
                for mod in self.output_modalities:
                    csv_file = folder_split + '/individual_results_emb_' + str(emb) + '_mod_' + mod + '.csv'
                    fd = open(csv_file, "w")
                    # register before writing so the handle is closed on error
                    files[mod] = fd
                    fd.write(csv_header)
            print('Created ' + str(len(files_embs)) + ' test files')

            if not os.path.exists(folder_split + '/avg_emb_ims'):
                os.makedirs(folder_split + '/avg_emb_ims')

            # Indices of the model outputs that synthesise each modality. They
            # depend only on the model, so they are computed once.
            z_idx_by_mod = {
                y: [outi for outi, out in enumerate(self.mm.model.outputs) if y in out.name]
                for y in self.output_modalities
            }

            # Visit each test id once, in ascending order. Scanning
            # range(len(all_ids)) instead would silently skip test ids that
            # are >= the number of volumes whenever ids are not exactly 0..N-1.
            for vol_num in sorted(set(ids_test)):
                print('testing model on volume ' + str(vol_num) + '...')

                X = [
                    self.data.select_for_ids(mod, [vol_num])
                    for mod in self.input_modalities
                ]
                Z = self.mm.model.predict(X)

                # compute emb average
                err_avg_emb = self.output_mean(Z, folder_split, vol_num)

                # Later checks win, so overlapping splits resolve to
                # training > validation > test.
                vol_type = 'test'
                if vol_num in valid_set:
                    vol_type = 'validation'
                if vol_num in train_set:
                    vol_type = 'training'

                for emb in range(self.mm.num_emb):
                    files = files_embs[emb]
                    for y in self.output_modalities:
                        y_synth = Z[z_idx_by_mod[y][emb]][:, 0]

                        y_truth = self.data.select_for_ids(y, [vol_num])
                        if y_truth.shape[1] > 1:
                            # keep the middle entry along axis 1; ``//`` keeps it an int
                            mid = y_truth.shape[1] // 2
                            y_truth = y_truth[:, mid:mid + 1]

                        err = ErrorMetrics(y_synth, y_truth)
                        new_row = row_pattern % tuple(
                            [vol_num] + [err[em] for em in emb_metrics] +
                            [vol_type, err_avg_emb[y]['MSE_NBG']])
                        files[y].write(new_row)
        finally:
            for files in files_embs.values():
                for fd in files.values():
                    fd.close()

        cb_X = [
            self.data.select_for_ids(mod, all_ids, as_array=False)
            for mod in self.input_modalities
        ]
        cb_Y = [
            self.data.select_for_ids(mod, all_ids, as_array=False)
            for mod in self.output_modalities
        ]
        cb = ImageSaveCallback(cb_X, cb_Y, None, None, folder_split,
                               self.output_modalities)
        cb.model = self.mm.model
        for vol in ids_test + list(extra_image_vols):
            cb.saveImage(vol, [70, 80, 90, 100],
                         folder_split + '/test_im' + str(vol), cb_X, cb_Y)