Example #1
    def train(self):
        X_train, y_train, _ = self.load_results_from_result_paths(self.train_paths)
        X_val, y_val, _ = self.load_results_from_result_paths(self.val_paths)
        self.model.fit(X_train, y_train)

        train_pred, var_train = self.model.predict(X_train), None
        val_pred, var_val = self.model.predict(X_val), None

        #self.save()

        fig_train = utils.scatter_plot(np.array(train_pred), np.array(y_train), xlabel='Predicted', ylabel='True', title='')
        fig_train.savefig(os.path.join(self.log_dir, 'pred_vs_true_train.jpg'))
        plt.close()

        fig_val = utils.scatter_plot(np.array(val_pred), np.array(y_val), xlabel='Predicted', ylabel='True', title='')
        fig_val.savefig(os.path.join(self.log_dir, 'pred_vs_true_val.jpg'))
        plt.close()

        train_metrics = utils.evaluate_metrics(y_train, train_pred, prediction_is_first_arg=False)
        valid_metrics = utils.evaluate_metrics(y_val, val_pred, prediction_is_first_arg=False)

        logging.info('train metrics: %s', train_metrics)
        logging.info('valid metrics: %s', valid_metrics)

        return valid_metrics
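
A note on the `var_train`/`var_val` placeholders above: a plain scikit-learn regressor only returns point predictions, which is why the variances stay `None`. Below is a minimal, standalone sketch of how a variance estimate could be obtained when the underlying model happens to be a bagged ensemble such as `RandomForestRegressor` (an assumption for illustration, not part of the example's class):

import numpy as np
from sklearn.datasets import make_regression
from sklearn.ensemble import RandomForestRegressor

X, y = make_regression(n_samples=200, n_features=10, random_state=0)
model = RandomForestRegressor(n_estimators=50, random_state=0).fit(X, y)

# The spread of the per-tree predictions serves as a rough uncertainty estimate.
per_tree = np.stack([tree.predict(X) for tree in model.estimators_])
mean_pred, var_pred = per_tree.mean(axis=0), per_tree.var(axis=0)
print(mean_pred[:3], var_pred[:3])
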
Example #2
    def train(self):
        X_train, y_train, _ = self.load_results_from_result_paths(
            self.train_paths)
        X_val, y_val, _ = self.load_results_from_result_paths(self.val_paths)

        base_learner_config = self.parse_config("base:")
        param_config = self.parse_config("param:")

        # train
        base_learner = DecisionTreeRegressor(criterion='friedman_mse',
                                             random_state=None,
                                             splitter='best',
                                             **base_learner_config)
        self.model = NGBRegressor(Dist=Normal,
                                  Base=base_learner,
                                  Score=LogScore,
                                  verbose=True,
                                  **param_config)
        self.model = self.model.fit(
            X_train,
            y_train,
            X_val=X_val,
            Y_val=y_val,
            early_stopping_rounds=self.model_config["early_stopping_rounds"])

        train_pred, var_train = self.model.predict(X_train), None
        val_pred, var_val = self.model.predict(X_val), None

        # self.save()

        fig_train = utils.scatter_plot(np.array(train_pred),
                                       np.array(y_train),
                                       xlabel='Predicted',
                                       ylabel='True',
                                       title='')
        fig_train.savefig(os.path.join(self.log_dir, 'pred_vs_true_train.jpg'))
        plt.close()

        fig_val = utils.scatter_plot(np.array(val_pred),
                                     np.array(y_val),
                                     xlabel='Predicted',
                                     ylabel='True',
                                     title='')
        fig_val.savefig(os.path.join(self.log_dir, 'pred_vs_true_val.jpg'))
        plt.close()

        train_metrics = utils.evaluate_metrics(y_train,
                                               train_pred,
                                               prediction_is_first_arg=False)
        valid_metrics = utils.evaluate_metrics(y_val,
                                               val_pred,
                                               prediction_is_first_arg=False)

        logging.info('train metrics: %s', train_metrics)
        logging.info('valid metrics: %s', valid_metrics)

        return valid_metrics
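
Example #2 discards the uncertainty that NGBoost actually models (`var_train`/`var_val` are set to `None`). A minimal standalone sketch of how the fitted Normal distribution could be queried instead, assuming the same `ngboost` API used above (`pred_dist` and the distribution's `params` dict):

import numpy as np
from ngboost import NGBRegressor
from ngboost.distns import Normal
from sklearn.datasets import make_regression

X, y = make_regression(n_samples=300, n_features=8, random_state=0)
ngb = NGBRegressor(Dist=Normal, n_estimators=100, verbose=False).fit(X, y)

dist = ngb.pred_dist(X)            # per-sample predictive distribution
pred = dist.params['loc']          # same point prediction as ngb.predict(X)
var = dist.params['scale'] ** 2    # predictive variance
print(pred[:3], var[:3])
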
Example #3
    def train(self):
        X_train, y_train, _ = self.load_results_from_result_paths(
            self.train_paths)
        X_val, y_val, _ = self.load_results_from_result_paths(self.val_paths)

        logging.info(
            "LGBOOST TRAIN: Careful, categorical features are not specified in the dataset conversion"
        )

        dtrain = lgb.Dataset(X_train, label=y_train)
        dval = lgb.Dataset(X_val, label=y_val)

        param_config = self.parse_param_config()
        param_config["seed"] = self.seed

        self.model = lgb.train(
            param_config,
            dtrain,
            early_stopping_rounds=self.model_config["early_stopping_rounds"],
            verbose_eval=1,
            valid_sets=[dval])

        train_pred, var_train = self.model.predict(X_train), None
        val_pred, var_val = self.model.predict(X_val), None

        # self.save()

        fig_train = utils.scatter_plot(np.array(train_pred),
                                       np.array(y_train),
                                       xlabel='Predicted',
                                       ylabel='True',
                                       title='')
        fig_train.savefig(os.path.join(self.log_dir, 'pred_vs_true_train.jpg'))
        plt.close()

        fig_val = utils.scatter_plot(np.array(val_pred),
                                     np.array(y_val),
                                     xlabel='Predicted',
                                     ylabel='True',
                                     title='')
        fig_val.savefig(os.path.join(self.log_dir, 'pred_vs_true_val.jpg'))
        plt.close()

        train_metrics = utils.evaluate_metrics(y_train,
                                               train_pred,
                                               prediction_is_first_arg=False)
        valid_metrics = utils.evaluate_metrics(y_val,
                                               val_pred,
                                               prediction_is_first_arg=False)

        logging.info('train metrics: %s', train_metrics)
        logging.info('valid metrics: %s', valid_metrics)

        return valid_metrics
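
Note that the `early_stopping_rounds` and `verbose_eval` keyword arguments used above were removed from `lgb.train` in LightGBM 4.x; with recent releases the same behaviour is expressed via callbacks. A minimal sketch of the equivalent call (synthetic data, illustrative parameters only):

import lightgbm as lgb
import numpy as np

rng = np.random.default_rng(0)
X_train, y_train = rng.normal(size=(200, 10)), rng.normal(size=200)
X_val, y_val = rng.normal(size=(50, 10)), rng.normal(size=50)

dtrain = lgb.Dataset(X_train, label=y_train)
dval = lgb.Dataset(X_val, label=y_val)

model = lgb.train(
    {'objective': 'regression', 'seed': 0},
    dtrain,
    valid_sets=[dval],
    # Callbacks replace early_stopping_rounds and verbose_eval in LightGBM >= 4.
    callbacks=[lgb.early_stopping(stopping_rounds=20), lgb.log_evaluation(1)])
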
Example #4
    def validate(self):
        preds = []
        targets = []
        self.model.eval()

        valid_queue = self.load_results_from_result_paths(self.val_paths)
        for step, (arch_path_enc, y_true) in enumerate(valid_queue):
            arch_path_enc = arch_path_enc.to(self.device).float()
            y_true = y_true.to(self.device).float()

            pred = self.model(arch_path_enc)
            preds.extend(pred.detach().cpu().numpy() * 100)
            targets.extend(y_true.detach().cpu().numpy())

        fig = utils.scatter_plot(np.array(preds),
                                 np.array(targets),
                                 xlabel='Predicted',
                                 ylabel='True',
                                 title='')
        fig.savefig(os.path.join(self.log_dir, 'pred_vs_true_valid.jpg'))
        plt.close()

        val_results = utils.evaluate_metrics(np.array(targets),
                                             np.array(preds),
                                             prediction_is_first_arg=False)
        logging.info('validation metrics %s', val_results)

        return val_results
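
The evaluation loops in Examples #4 and #5 build the autograd graph for every forward pass and then detach; wrapping them in `torch.no_grad()` avoids that overhead. A minimal standalone pattern with a placeholder model:

import torch

model = torch.nn.Linear(16, 1)
model.eval()

# Inside no_grad() the forward pass records no gradients, so the extra
# memory for the graph is never allocated.
with torch.no_grad():
    x = torch.randn(8, 16)
    pred = model(x)
print(pred.shape)
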
Example #5
    def test(self):
        preds = []
        targets = []
        self.model.eval()

        test_queue = self.load_results_from_result_paths(self.test_paths)
        for step, graph_batch in enumerate(test_queue):
            graph_batch = graph_batch.to(self.device)

            if self.model_config['model'] == 'gnn_vs_gae_classifier':
                pred_bins, pred = self.model(graph_batch=graph_batch)

            else:
                pred = self.model(graph_batch=graph_batch)

            preds.extend(pred.detach().cpu().numpy() * 100)
            targets.extend(graph_batch.y.detach().cpu().numpy())

        fig = utils.scatter_plot(np.array(preds),
                                 np.array(targets),
                                 xlabel='Predicted',
                                 ylabel='True',
                                 title='')
        fig.savefig(os.path.join(self.log_dir, 'pred_vs_true_test.jpg'))
        plt.close()

        test_results = utils.evaluate_metrics(np.array(targets),
                                              np.array(preds),
                                              prediction_is_first_arg=False)
        logging.info('test metrics %s', test_results)

        return test_results
Example #6
    def test(self):
        X_test, y_test, _ = self.load_results_from_result_paths(self.test_paths)
        test_pred, var_test = self.model.predict(X_test), None

        fig = utils.scatter_plot(np.array(test_pred), np.array(y_test), xlabel='Predicted', ylabel='True', title='')
        fig.savefig(os.path.join(self.log_dir, 'pred_vs_true_test.jpg'))
        plt.close()

        test_metrics = utils.evaluate_metrics(y_test, test_pred, prediction_is_first_arg=False)

        logging.info('test metrics %s', test_metrics)

        return test_metrics
Example #7
    def train(self):
        X_train, y_train, _ = self.load_results_from_result_paths(self.train_paths)
        X_val, y_val, _ = self.load_results_from_result_paths(self.val_paths)

        dtrain = xgb.DMatrix(X_train, label=y_train)
        dval = xgb.DMatrix(X_val, label=y_val)

        param_config = self.parse_param_config()
        param_config["seed"] = self.seed

        self.model = xgb.train(param_config, dtrain, num_boost_round=self.model_config["param:num_rounds"],
                               early_stopping_rounds=self.model_config["early_stopping_rounds"],
                               verbose_eval=1,
                               evals=[(dval, 'val')])

        train_pred, var_train = self.model.predict(dtrain), None
        val_pred, var_val = self.model.predict(dval), None

        # self.save()

        fig_train = utils.scatter_plot(np.array(train_pred), np.array(y_train), xlabel='Predicted', ylabel='True',
                                       title='')
        fig_train.savefig(os.path.join(self.log_dir, 'pred_vs_true_train.jpg'))
        plt.close()

        fig_val = utils.scatter_plot(np.array(val_pred), np.array(y_val), xlabel='Predicted', ylabel='True', title='')
        fig_val.savefig(os.path.join(self.log_dir, 'pred_vs_true_val.jpg'))
        plt.close()

        train_metrics = utils.evaluate_metrics(y_train, train_pred, prediction_is_first_arg=False)
        valid_metrics = utils.evaluate_metrics(y_val, val_pred, prediction_is_first_arg=False)

        logging.info('train metrics: %s', train_metrics)
        logging.info('valid metrics: %s', valid_metrics)

        return valid_metrics
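
One detail Example #7 leaves out: with early stopping, `xgb.train` returns a booster that still contains all boosting rounds, so predictions are usually restricted to the best iteration. A minimal sketch of that pattern on synthetic data (assumes xgboost >= 1.4 for `iteration_range`):

import numpy as np
import xgboost as xgb

rng = np.random.default_rng(0)
X_train, y_train = rng.normal(size=(200, 10)), rng.normal(size=200)
X_val, y_val = rng.normal(size=(50, 10)), rng.normal(size=50)

dtrain = xgb.DMatrix(X_train, label=y_train)
dval = xgb.DMatrix(X_val, label=y_val)

model = xgb.train({'objective': 'reg:squarederror', 'seed': 0}, dtrain,
                  num_boost_round=500, early_stopping_rounds=20,
                  evals=[(dval, 'val')], verbose_eval=False)

# Only use the trees up to the round with the best validation score.
val_pred = model.predict(dval, iteration_range=(0, model.best_iteration + 1))
print(val_pred[:3])
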
Example #8
    def infer(self, train_queue, valid_queue, model, criterion, optimizer, lr,
              epoch):
        objs = utils.AvgrageMeter()

        # VALIDATION
        preds = []
        targets = []

        model.eval()
        for step, (arch_path_enc, y_true) in enumerate(valid_queue):
            arch_path_enc = arch_path_enc.to(self.device).float()
            y_true = y_true.to(self.device).float()
            pred = self.model(arch_path_enc)
            loss = torch.mean(
                torch.abs((self.normalize_data(pred) /
                           self.normalize_data(y_true / 100)) - 1))
            preds.extend(pred.detach().cpu().numpy() * 100)
            targets.extend(y_true.detach().cpu().numpy())
            objs.update(loss.data.item(), len(arch_path_enc))

            if step % self.data_config['report_freq'] == 0:
                logging.info('valid %03d %e ', step, objs.avg)

        fig = utils.scatter_plot(np.array(preds),
                                 np.array(targets),
                                 xlabel='Predicted',
                                 ylabel='True',
                                 title='')
        fig.savefig(
            os.path.join(self.log_dir,
                         'pred_vs_true_valid_{}.jpg'.format(epoch)))
        plt.close()

        val_results = utils.evaluate_metrics(np.array(targets),
                                             np.array(preds),
                                             prediction_is_first_arg=False)

        return objs.avg, val_results
Example #9
    def train_epoch(self, train_queue, valid_queue, model, criterion,
                    optimizer, lr, epoch):
        objs = utils.AvgrageMeter()

        # TRAINING
        preds = []
        targets = []

        model.train()

        for step, (arch_path_enc, y_true) in enumerate(train_queue):
            arch_path_enc = arch_path_enc.to(self.device).float()
            y_true = y_true.to(self.device).float()

            pred = self.model(arch_path_enc)
            if self.model_config['loss:loss_log_transform']:
                loss = torch.mean(
                    torch.abs((self.normalize_data(pred) /
                               self.normalize_data(y_true / 100)) - 1))
            else:
                loss = criterion(1 - pred, 1 - y_true / 100)
            if self.model_config['loss:pairwise_ranking_loss']:
                m = 0.1
                pairwise_ranking_loss = []
                sort_idx = torch.argsort(y_true, descending=True)
                for idx, idx_y_i in enumerate(sort_idx):
                    for idx_y_i_p1 in sort_idx[idx + 1:]:
                        pairwise_ranking_loss.append(
                            torch.max(torch.tensor(0.0, dtype=torch.float),
                                      m - (pred[idx_y_i] - pred[idx_y_i_p1])))
                pairwise_ranking_loss = torch.mean(
                    torch.stack(pairwise_ranking_loss))

                loss += pairwise_ranking_loss
                if step % self.data_config['report_freq'] == 0:
                    logging.info('Pairwise ranking loss {}'.format(
                        pairwise_ranking_loss))

            preds.extend(pred.detach().cpu().numpy() * 100)
            targets.extend(y_true.detach().cpu().numpy())
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            objs.update(loss.data.item(), len(arch_path_enc))

            if step % self.data_config['report_freq'] == 0:
                logging.info('train %03d %e', step, objs.avg)

        fig = utils.scatter_plot(np.array(preds),
                                 np.array(targets),
                                 xlabel='Predicted',
                                 ylabel='True',
                                 title='')
        fig.savefig(
            os.path.join(self.log_dir,
                         'pred_vs_true_train_{}.jpg'.format(epoch)))
        plt.close()
        train_results = utils.evaluate_metrics(np.array(targets),
                                               np.array(preds),
                                               prediction_is_first_arg=False)

        return objs.avg, train_results
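
The pairwise ranking loss in Example #9 (and again in Example #11) is built with two nested Python loops over the batch. Below is a standalone sketch of the same margin-based hinge term written with broadcasting; ties are skipped here, whereas the loop version also includes pairs with equal targets in sort order:

import torch

def pairwise_ranking_loss(pred, y_true, m=0.1):
    # diff[i, j] = pred[i] - pred[j]; penalise pairs where y_true[i] > y_true[j]
    # but the predictions are not separated by at least the margin m.
    diff = pred.unsqueeze(1) - pred.unsqueeze(0)
    higher = y_true.unsqueeze(1) > y_true.unsqueeze(0)
    hinge = torch.clamp(m - diff, min=0.0)
    return hinge[higher].mean()

pred = torch.tensor([0.90, 0.72, 0.70], requires_grad=True)
y_true = torch.tensor([95.0, 60.0, 80.0])
print(pairwise_ranking_loss(pred, y_true))
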
Example #10
    def infer(self, train_queue, valid_queue, model, criterion, optimizer, lr,
              epoch):
        objs = utils.AvgrageMeter()

        # VALIDATION
        preds = []
        targets = []

        model.eval()
        for step, graph_batch in enumerate(valid_queue):
            graph_batch = graph_batch.to(self.device)

            if self.model_config['model'] == 'gnn_vs_gae_classifier':
                pred_bins, pred = self.model(graph_batch=graph_batch)
                criterion = torch.nn.BCELoss()
                criterion_2 = torch.nn.MSELoss()

                bins = self.create_bins(lower_bound=0, width=10, quantity=9)
                binned_weights = []
                for value in graph_batch.y.cpu().numpy():
                    bin_index = self.find_bin(value, bins)
                    binned_weights.append(bin_index)
                bins = torch.FloatTensor(binned_weights)
                make_one_hot = lambda index: torch.eye(self.model_config[
                    'no_bins'])[index.view(-1).long()]
                binns_one_hot = make_one_hot(bins).to(self.device)

                loss_1 = criterion(pred_bins, binns_one_hot)
                loss_2 = criterion_2(pred, self.normalize_data(graph_batch.y))
                alpha = self.model_config['classification_loss']
                beta = self.model_config['regression_loss']

                loss = alpha * loss_1 + beta * loss_2
            else:
                pred = self.model(graph_batch=graph_batch)
                loss = criterion(self.normalize_data(pred),
                                 self.normalize_data(graph_batch.y / 100))

            preds.extend(pred.detach().cpu().numpy() * 100)
            targets.extend(graph_batch.y.detach().cpu().numpy())
            n = graph_batch.num_graphs
            objs.update(loss.data.item(), n)

            if step % self.data_config['report_freq'] == 0:
                logging.info('valid %03d %e ', step, objs.avg)

        fig = utils.scatter_plot(np.array(preds),
                                 np.array(targets),
                                 xlabel='Predicted',
                                 ylabel='True',
                                 title='')
        fig.savefig(
            os.path.join(self.log_dir,
                         'pred_vs_true_valid_{}.jpg'.format(epoch)))
        plt.close()

        val_results = utils.evaluate_metrics(np.array(targets),
                                             np.array(preds),
                                             prediction_is_first_arg=False)

        return objs.avg, val_results
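
The classifier branch in Examples #10 and #11 relies on `self.create_bins` and `self.find_bin`, which are not shown in these snippets. The following is a minimal standalone sketch of one plausible reading of those helpers (an assumption, not the original implementation): `create_bins(0, 10, 9)` yields nine 10-wide accuracy intervals and `find_bin` returns the index of the interval containing a value.

def create_bins(lower_bound, width, quantity):
    # e.g. create_bins(0, 10, 9) -> [(0, 10), (10, 20), ..., (80, 90)]
    return [(lower_bound + i * width, lower_bound + (i + 1) * width)
            for i in range(quantity)]

def find_bin(value, bins):
    for index, (low, high) in enumerate(bins):
        if low <= value < high:
            return index
    return len(bins) - 1  # clamp values that fall beyond the last bin

bins = create_bins(lower_bound=0, width=10, quantity=9)
print(bins[:3], find_bin(57.3, bins))
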
Example #11
    def train_epoch(self, train_queue, valid_queue, model, criterion,
                    optimizer, lr, epoch):
        objs = utils.AvgrageMeter()

        # TRAINING
        preds = []
        targets = []

        model.train()

        for step, graph_batch in enumerate(train_queue):
            graph_batch = graph_batch.to(self.device)
            #             print(step)

            if self.model_config['model'] == 'gnn_vs_gae_classifier':
                pred_bins, pred = self.model(graph_batch=graph_batch)
                criterion = torch.nn.BCELoss()
                criterion_2 = torch.nn.MSELoss()

                bins = self.create_bins(lower_bound=0, width=10, quantity=9)
                binned_weights = []
                for value in graph_batch.y.cpu().numpy():
                    bin_index = self.find_bin(value, bins)
                    binned_weights.append(bin_index)
                bins = torch.FloatTensor(binned_weights)
                make_one_hot = lambda index: torch.eye(self.model_config[
                    'no_bins'])[index.view(-1).long()]
                binns_one_hot = make_one_hot(bins).to(self.device)
                loss_1 = criterion(pred_bins, binns_one_hot)
                loss_2 = criterion_2(pred, self.normalize_data(graph_batch.y))
                alpha = self.model_config['classification_loss']
                beta = self.model_config['regression_loss']

                loss = alpha * loss_1 + beta * loss_2

            else:
                pred = self.model(graph_batch=graph_batch)
                if self.model_config['loss:loss_log_transform']:
                    loss = criterion(self.normalize_data(pred),
                                     self.normalize_data(graph_batch.y / 100))
                else:
                    loss = criterion(pred, graph_batch.y / 100)
                if self.model_config['loss:pairwise_ranking_loss']:
                    m = 0.1
                    '''
                    y = list(map(lambda y_i: 1 if y_i == True else -1, graph_batch.y[0: -1] > graph_batch.y[1:]))
                    pairwise_ranking_loss = torch.nn.HingeEmbeddingLoss(margin=m)(pred[0:-1] - pred[1:],
                                                                                  target=torch.from_numpy(np.array(y)))
                    '''
                    pairwise_ranking_loss = []
                    sort_idx = torch.argsort(graph_batch.y, descending=True)
                    for idx, idx_y_i in enumerate(sort_idx):
                        for idx_y_i_p1 in sort_idx[idx + 1:]:
                            pairwise_ranking_loss.append(
                                torch.max(
                                    torch.tensor(0.0, dtype=torch.float),
                                    m - (pred[idx_y_i] - pred[idx_y_i_p1])))
                    pairwise_ranking_loss = torch.mean(
                        torch.stack(pairwise_ranking_loss))

                    loss += pairwise_ranking_loss
                    if step % self.data_config['report_freq'] == 0:
                        logging.info('Pairwise ranking loss {}'.format(
                            pairwise_ranking_loss))

            preds.extend(pred.detach().cpu().numpy() * 100)
            targets.extend(graph_batch.y.detach().cpu().numpy())
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            n = graph_batch.num_graphs
            objs.update(loss.data.item(), n)

            if step % self.data_config['report_freq'] == 0:
                logging.info('train %03d %e', step, objs.avg)

        fig = utils.scatter_plot(np.array(preds),
                                 np.array(targets),
                                 xlabel='Predicted',
                                 ylabel='True',
                                 title='')
        fig.savefig(
            os.path.join(self.log_dir,
                         'pred_vs_true_train_{}.jpg'.format(epoch)))
        plt.close()
        train_results = utils.evaluate_metrics(np.array(targets),
                                               np.array(preds),
                                               prediction_is_first_arg=False)

        return objs.avg, train_results
Example #12
def ensemble_analysis(nasbench_data, ensemble_parent_dir, data_splits_root):
    # Get model directories
    member_dirs = get_ensemble_member_dirs(ensemble_parent_dir)
    member_dirs_dict = group_dirs_by_surrogate_type(member_dirs)
    print("==> Found ensemble member directories:", member_dirs_dict)

    # Load an ensemble for each surrogate type
    surrogate_ensemble = None
    for surrogate_type, member_dirs in member_dirs_dict.items():

        # Load config
        print("==> Loading %s configs..." % surrogate_type)
        model_log_dir = member_dirs[0]
        data_config = json.load(
            open(os.path.join(model_log_dir, 'data_config.json'), 'r'))
        model_config = json.load(
            open(os.path.join(model_log_dir, 'model_config.json'), 'r'))
        train_paths = json.load(
            open(os.path.join(data_splits_root, 'train_paths.json'), 'r'))
        val_paths = json.load(
            open(os.path.join(data_splits_root, 'val_paths.json'), 'r'))
        test_paths = json.load(
            open(os.path.join(data_splits_root, 'test_paths.json'), 'r'))

        # Load ensemble
        print("==> Loading %s ensemble..." % surrogate_type)
        surrogate_ensemble_single = Ensemble(
            member_model_name=model_config['model'],
            data_root='None',
            log_dir=ensemble_parent_dir,
            starting_seed=data_config['seed'],
            model_config=model_config,
            data_config=data_config,
            ensemble_size=len(member_dirs),
            init_ensemble=False)

        surrogate_ensemble_single.load(model_paths=member_dirs,
                                       train_paths=train_paths,
                                       val_paths=val_paths,
                                       test_paths=test_paths)

        # Combine different model types
        if surrogate_ensemble is None:
            surrogate_ensemble = surrogate_ensemble_single
        else:
            for member_model in surrogate_ensemble_single.ensemble_members:
                surrogate_ensemble.add_member(member_model)

    print("==> Ensemble creation completed.")

    # Set the same seed as used during the training.
    np.random.seed(data_config['seed'])

    # Evaluate ensemble
    print("==> Evaluating ensemble performance...")
    train_metrics, train_preds, train_stddevs, train_targets = surrogate_ensemble.evaluate_ensemble(
        train_paths, apply_noise=False)
    val_metrics, val_preds, val_stddevs, val_targets = surrogate_ensemble.validate_ensemble(
        apply_noise=False)
    test_metrics, test_preds, test_stddevs, test_targets = surrogate_ensemble.test_ensemble(
        apply_noise=False)
    print('==> Ensemble train metrics', train_metrics)
    print('==> Ensemble val metrics', val_metrics)
    print('==> Ensemble test metrics', test_metrics)

    train_metrics_with_noise, train_preds_with_noise, _, _ = surrogate_ensemble.evaluate_ensemble(
        train_paths, apply_noise=True)
    val_metrics_with_noise, val_preds_with_noise, _, _ = surrogate_ensemble.validate_ensemble(
        apply_noise=True)
    test_metrics_with_noise, test_preds_with_noise, _, _ = surrogate_ensemble.test_ensemble(
        apply_noise=True)
    print('==> Ensemble train metrics (noisy)', train_metrics_with_noise)
    print('==> Ensemble val metrics (noisy)', val_metrics_with_noise)
    print('==> Ensemble test metrics (noisy)', test_metrics_with_noise)

    train_mean_stddev = np.mean(train_stddevs)
    val_mean_stddev = np.mean(val_stddevs)
    test_mean_stddev = np.mean(test_stddevs)
    print("==> Mean ensemble stddev on train set %f" % train_mean_stddev)
    print("==> Mean ensemble stddev on validation set %f" % val_mean_stddev)
    print("==> Mean ensemble stddev on test set %f" % test_mean_stddev)

    # Plots
    fig_val = utils.scatter_plot(np.array(train_preds),
                                 np.array(train_targets),
                                 xlabel='Predicted',
                                 ylabel='True',
                                 title='')
    fig_val.savefig(os.path.join(ensemble_parent_dir,
                                 'pred_vs_true_train.jpg'))
    plt.close()

    fig_val = utils.scatter_plot(np.array(val_preds),
                                 np.array(val_targets),
                                 xlabel='Predicted',
                                 ylabel='True',
                                 title='')
    fig_val.savefig(os.path.join(ensemble_parent_dir, 'pred_vs_true_val.jpg'))
    plt.close()

    fig_test = utils.scatter_plot(np.array(test_preds),
                                  np.array(test_targets),
                                  xlabel='Predicted',
                                  ylabel='True',
                                  title='')
    fig_test.savefig(os.path.join(ensemble_parent_dir,
                                  'pred_vs_true_test.jpg'))
    plt.close()

    fig_val = utils.scatter_plot(np.array(train_preds_with_noise),
                                 np.array(train_targets),
                                 xlabel='Predicted',
                                 ylabel='True',
                                 title='')
    fig_val.savefig(
        os.path.join(ensemble_parent_dir, 'pred_vs_true_train_noisy.jpg'))
    plt.close()

    fig_val_noise = utils.scatter_plot(np.array(val_preds_with_noise),
                                       np.array(val_targets),
                                       xlabel='Predicted',
                                       ylabel='True',
                                       title='')
    fig_val_noise.savefig(
        os.path.join(ensemble_parent_dir, 'pred_vs_true_val_noisy.jpg'))
    plt.close()

    fig_test_noise = utils.scatter_plot(np.array(test_preds_with_noise),
                                        np.array(test_targets),
                                        xlabel='Predicted',
                                        ylabel='True',
                                        title='')
    fig_test_noise.savefig(
        os.path.join(ensemble_parent_dir, 'pred_vs_true_test_noisy.jpg'))
    plt.close()

    # Test query_mean method
    # print("==> Testing ensemble predictions with query mean...")
    # test_results, configs, val_scores, test_scores = get_test_configs(member_dirs[0])
    # preds = [surrogate_ensemble.query_mean(config) for config in configs]
    # metrics = utils.evaluate_metrics(val_scores, preds, prediction_is_first_arg=False)
    # print("==> Metrics on test data (query mean):", metrics)

    # Perform cell topology performance analysis
    print("==> Checking configs from cell topology analysis...")
    test_paths = [
        filename for filename in Path(
            "/home/user/projects/nasbench_201_2/analysis/nb_301_cell_topology/cell_topology_analysis"
        ).rglob('*.json')
    ]
    test_metrics_topology, ensemble_predictions_topology, stddevs_topology, targets_topology = surrogate_ensemble.evaluate_ensemble(
        test_paths, apply_noise=False)
    print("==> Metrics on cell topology analysis:", test_metrics_topology)

    # Log
    results = {
        "train_metrics": train_metrics,
        "val_metrics": val_metrics,
        "test_metrics": test_metrics,
        "train_metrics_with_noise": train_metrics_with_noise,
        "val_metrics_with_noise": val_metrics_with_noise,
        "test_metrics_with_noise": test_metrics_with_noise,
        "train_mean_stddev": train_mean_stddev,
        "val_mean_stddev": val_mean_stddev,
        "test_mean_stddev": test_mean_stddev,
        "test_results_cell_topology": test_metrics_topology
    }

    for key, val in results.items():
        if isinstance(val, dict):
            for subkey, subval in val.items():
                results[key][subkey] = np.float64(subval)
        else:
            try:
                results[key] = np.float64(val)
            except (TypeError, ValueError):
                # Leave values that cannot be cast (e.g. strings) unchanged.
                pass

    with open(os.path.join(ensemble_parent_dir, "ensemble_performance.json"),
              "w") as result_file:
        json.dump(results, result_file)