def plot_station_grid_sequences(
        station_grids,
        output_dir='/home/n1no/Documents/ethz/master_thesis/code/project/preprocessing/station_grid_plots',
        max_stations=4):
    """Plot the grid points of each station incrementally, one figure per point.

    For station ``S`` and grid point ``i`` a PNG named
    ``station_<S>_grid_point_<i>.png`` is written that shows grid points
    ``0..i`` of that station as red dots on the background map.

    Args:
        station_grids: iterable of station grids; element ``[0]`` of each grid
            is used as the latitude points and element ``[1]`` as the
            longitude points.
        output_dir: directory the PNGs are written to (created if missing).
            Defaults to the directory that used to be hard-coded here.
        max_stations: maximum number of stations to plot. The default of 4
            reproduces the previous hard-coded ``if S >= 3: break``. Pass
            ``None`` to plot every station.
    """
    import itertools
    import os

    os.makedirs(output_dir, exist_ok=True)

    # islice(..., None) yields everything; islice(..., k) stops after k items
    # without pulling the (k+1)-th one, exactly like the old early ``break``.
    for S, station_grid in enumerate(itertools.islice(station_grids, max_stations)):
        lon_points = list(station_grid[1])
        lat_points = list(station_grid[0])

        rlon, rlat = PlotUtils.get_rlon_rlat(lon_points, lat_points)

        # bounding box of this station's grid, used to truncate the background map
        rlat_min, rlat_max = np.min(rlat), np.max(rlat)
        rlon_min, rlon_max = np.min(rlon), np.max(rlon)

        for i in range(len(lon_points)):
            fig = plt.figure(figsize=(16, 12))
            ax = plt.subplot(111)

            # map extent = bounding box padded by 0.2 (rlat) / 0.4 (rlon)
            PlotUtils.plot_map_rlon_rlat(ax=ax,
                                         rlat_min=rlat_min - 0.2,
                                         rlat_max=rlat_max + 0.2,
                                         rlon_min=rlon_min - 0.4,
                                         rlon_max=rlon_max + 0.4,
                                         alpha_background=0.5)
            # draw all grid points up to and including the i-th one
            ax.scatter(rlon[:i + 1], rlat[:i + 1], s=20, color='red')
            plt.title('Station %s - Grid Point %s' % (S, i))
            plt.axis('scaled')
            fig.savefig(
                os.path.join(output_dir,
                             'station_%s_grid_point_%s.png' % (S, i)))
            # close this specific figure so figures do not accumulate
            plt.close(fig)
def plot_station_grids(
        station_grids,
        output_file='/home/n1no/Documents/ethz/master_thesis/code/project/preprocessing/station_grid_plots/station_grid_plot.png'):
    """Plot the grid points of all stations together on a single map.

    Args:
        station_grids: iterable of station grids; element ``[0]`` of each grid
            is used as the latitude points and element ``[1]`` as the
            longitude points.
        output_file: path of the PNG to write (its directory is created if
            missing). Defaults to the path that used to be hard-coded here.
    """
    import os

    # The grids are traversed twice below; materialize them once so a
    # generator argument does not come out empty on the second pass.
    station_grids = list(station_grids)

    lon_points = [
        grid_point for station_grid in station_grids
        for grid_point in station_grid[1]
    ]
    lat_points = [
        grid_point for station_grid in station_grids
        for grid_point in station_grid[0]
    ]

    rlon, rlat = PlotUtils.get_rlon_rlat(lon_points, lat_points)

    # bounding box of all grid points, used to truncate the background map
    rlat_min, rlat_max = np.min(rlat), np.max(rlat)
    rlon_min, rlon_max = np.min(rlon), np.max(rlon)

    fig = plt.figure(figsize=(30, 20))
    ax = plt.subplot(111)

    PlotUtils.plot_map_rlon_rlat(ax=ax,
                                 rlat_min=rlat_min,
                                 rlat_max=rlat_max,
                                 rlon_min=rlon_min,
                                 rlon_max=rlon_max,
                                 alpha_background=0.5)

    ax.scatter(rlon, rlat, s=1, color='red')
    plt.title('Grid Points around Stations')
    plt.axis('scaled')

    output_dir = os.path.dirname(output_file)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    fig.savefig(output_file)
    plt.close(fig)
def plotTrainTestDataDistribution(source_path, output_path):
    """Plot the train/test datetime distribution of every stored fold file.

    For each ``train_test_folds_*.pkl`` file directly inside ``source_path``,
    the folds are loaded, the element at index 1 of every train/test sample is
    collected per fold, and ``PlotUtils.plot_datetime_distribution`` is called
    with ``output_path/<file name without .pkl>`` and the seed taken from the
    file name suffix.

    Args:
        source_path: directory containing the ``train_test_folds_*.pkl`` files.
        output_path: directory for the generated plots (created if missing).
    """
    pattern = os.path.join(source_path, 'train_test_folds_*.pkl')
    # sorted() makes the processing order deterministic across file systems
    for train_test_folds_file in sorted(glob.glob(pattern)):

        # NOTE(review): pickle.load can execute arbitrary code -- only point
        # this at trusted, locally generated fold files.
        with open(train_test_folds_file, 'rb') as f:
            train_test_folds = pkl.load(file=f)
        print('Loaded train/test folds.')
        sys.stdout.flush()

        # e.g. ".../train_test_folds_42.pkl" -> "train_test_folds_42", seed "42"
        file_name = os.path.splitext(os.path.basename(train_test_folds_file))[0]
        seed = file_name.split('_')[-1]

        train_date_times = []
        test_date_times = []

        # cross validation: keep per-fold train/test datetimes (index 1 of
        # each sample) to calculate the distributions
        for idx, (train_fold, test_fold) in enumerate(train_test_folds):
            print('Cross-validation test fold %s' % str(idx + 1))
            train_date_times.append([sample[1] for sample in train_fold])
            test_date_times.append([sample[1] for sample in test_fold])

        # create folder if necessary (exist_ok avoids a check-then-create race)
        os.makedirs(output_path, exist_ok=True)

        # generate datetime plots
        PlotUtils.plot_datetime_distribution(train_date_times, test_date_times,
                                             os.path.join(output_path, file_name),
                                             seed)
Beispiel #4
0
def pipeline(subgraph_max_nodes):
    """Explain node-classification predictions with MCTS and score the results.

    For every selected test node an MCTS search is run (or its result loaded
    from the per-node cache in ``save_dir``), the search result closest to
    ``subgraph_max_nodes`` is picked, fidelity and sparsity are computed, and
    the explanation is plotted.

    Args:
        subgraph_max_nodes: size budget passed to ``find_closest_node_result``.

    Returns:
        Tuple ``(fidelity_scores, sparsity_scores)`` of 1-D tensors with one
        entry per explained node.

    Note:
        Node indices are plain ints now, so cache/figure files are named
        ``node_5_score.pt`` rather than ``node_tensor(5)_score.pt`` (the
        f-string repr of a 0-d tensor). Caches written under the old names
        are not picked up and will be recomputed.
    """
    dataset = get_dataset(data_args.dataset_dir, data_args.dataset_name)
    input_dim = dataset.num_node_features
    output_dim = dataset.num_classes
    data = dataset[0]
    # test-set nodes whose label is non-zero (class-0 nodes are skipped);
    # .tolist() yields ints that format cleanly into file names
    node_indices = torch.where(data.test_mask * data.y != 0)[0].tolist()

    gnnNets = GnnNets_NC(input_dim, output_dim, model_args)
    checkpoint = torch.load(mcts_args.explain_model_path)
    gnnNets.update_state_dict(checkpoint['net'])
    gnnNets.to_device()
    gnnNets.eval()
    save_dir = os.path.join('./results', f"{mcts_args.dataset_name}"
                                         f"_{model_args.model_name}"
                                         f"_{reward_args.reward_method}")
    # makedirs also creates ./results when missing (os.mkdir would raise)
    os.makedirs(save_dir, exist_ok=True)

    # The full-graph prediction does not depend on the node being explained,
    # so run the forward pass once, without autograd bookkeeping.
    with torch.no_grad():
        _, prob, _ = gnnNets(data.clone())
    _, predictions = torch.max(prob, -1)

    plotutils = PlotUtils(dataset_name=data_args.dataset_name)
    fidelity_score_list = []
    sparsity_score_list = []
    for node_idx in tqdm(node_indices):
        # cached MCTS result for this node
        result_path = os.path.join(save_dir, f"node_{node_idx}_score.pt")
        prediction = predictions[node_idx].item()

        # build the graph for the search/visualization; rebuilt per node since
        # it is handed to MCTS as ori_graph (may be modified there -- TODO confirm)
        graph = to_networkx(data, to_undirected=True)
        node_labels = {k: int(v) for k, v in enumerate(data.y)}
        nx.set_node_attributes(graph, node_labels, 'label')

        # searching using gnn score
        mcts_state_map = MCTS(node_idx=node_idx, ori_graph=graph,
                              X=data.x, edge_index=data.edge_index,
                              num_hops=len(model_args.latent_dim),
                              n_rollout=mcts_args.rollout,
                              min_atoms=mcts_args.min_atoms,
                              c_puct=mcts_args.c_puct,
                              expand_atoms=mcts_args.expand_atoms)
        value_func = GnnNets_NC2value_func(gnnNets,
                                           node_idx=mcts_state_map.node_idx,
                                           target_class=prediction)
        score_func = reward_func(reward_args, value_func)
        mcts_state_map.set_score_func(score_func)

        # get searching result (load from cache when available)
        if os.path.isfile(result_path):
            gnn_results = torch.load(result_path)
        else:
            gnn_results = mcts_state_map.mcts(verbose=True)
            torch.save(gnn_results, result_path)
        tree_node_x = find_closest_node_result(gnn_results, subgraph_max_nodes)

        # calculate the metrics: fidelity = score on all nodes minus score with
        # the coalition removed (the explained node itself is always kept)
        original_node_list = list(tree_node_x.ori_graph.nodes)
        masked_node_list = [i for i in tree_node_x.ori_graph.nodes
                            if i not in tree_node_x.coalition or i == mcts_state_map.node_idx]
        original_score = gnn_score(original_node_list, tree_node_x.data,
                                   value_func=value_func, subgraph_building_method='zero_filling')
        masked_score = gnn_score(masked_node_list, tree_node_x.data,
                                 value_func=value_func, subgraph_building_method='zero_filling')
        sparsity_score = 1 - len(tree_node_x.coalition) / tree_node_x.ori_graph.number_of_nodes()

        fidelity_score_list.append(original_score - masked_score)
        sparsity_score_list.append(sparsity_score)

        # visualization
        subgraph_node_labels = nx.get_node_attributes(tree_node_x.ori_graph, name='label')
        subgraph_node_labels = torch.tensor(list(subgraph_node_labels.values()))
        plotutils.plot(tree_node_x.ori_graph, tree_node_x.coalition, y=subgraph_node_labels,
                       node_idx=mcts_state_map.node_idx,
                       figname=os.path.join(save_dir, f"node_{node_idx}.png"))

    fidelity_scores = torch.tensor(fidelity_score_list)
    sparsity_scores = torch.tensor(sparsity_score_list)
    return fidelity_scores, sparsity_scores
Beispiel #5
0
def plot_test():
    """Load the stored LightGBM/word2vec metrics file, echo it and plot it."""
    metrics_file = "./metrics/lightGBM_performance_word2vec_0.0"
    performance = np.loadtxt(metrics_file, delimiter=',')
    print(performance)
    # NOTE(review): meaning of the second argument (20) is defined by
    # PlotUtils.plot_graph -- confirm there.
    PlotUtils.plot_graph(performance, 20)
Beispiel #6
0
def pipeline(max_nodes):
    """Explain graph-classification predictions with MCTS and score the results.

    For every test graph (all graphs for ``mutag``) an MCTS search is run (or
    its result loaded from ``save_dir``), the result closest to ``max_nodes``
    is picked, fidelity and sparsity are computed, and the explanation is
    plotted.

    Args:
        max_nodes: size budget passed to ``find_closest_node_result``.

    Returns:
        Tuple ``(fidelity_scores, sparsity_scores)`` of 1-D tensors with one
        entry per explained graph.
    """
    dataset = get_dataset(data_args.dataset_dir, data_args.dataset_name)
    plotutils = PlotUtils(dataset_name=data_args.dataset_name)
    input_dim = dataset.num_node_features
    output_dim = dataset.num_classes

    if data_args.dataset_name == 'mutag':
        data_indices = list(range(len(dataset)))
    else:
        loader = get_dataloader(dataset,
                                batch_size=train_args.batch_size,
                                random_split_flag=data_args.random_split,
                                data_split_ratio=data_args.data_split_ratio,
                                seed=data_args.seed)
        data_indices = loader['test'].dataset.indices

    gnnNets = GnnNets(input_dim, output_dim, model_args)
    checkpoint = torch.load(mcts_args.explain_model_path)
    gnnNets.update_state_dict(checkpoint['net'])
    gnnNets.to_device()
    gnnNets.eval()

    save_dir = os.path.join(
        './results', f"{mcts_args.dataset_name}_"
        f"{model_args.model_name}_"
        f"{reward_args.reward_method}")
    # makedirs also creates ./results when missing (os.mkdir would raise)
    os.makedirs(save_dir, exist_ok=True)

    fidelity_score_list = []
    sparsity_score_list = []
    for i in tqdm(data_indices):
        # get data and prediction; only the prediction is needed, so skip autograd
        data = dataset[i]
        with torch.no_grad():
            _, probs, _ = gnnNets(Batch.from_data_list([data.clone()]))
        prediction = probs.squeeze().argmax(-1).item()
        original_score = probs.squeeze()[prediction]

        # get the reward func
        value_func = GnnNets_GC2value_func(gnnNets, target_class=prediction)
        payoff_func = reward_func(reward_args, value_func)

        # cached MCTS result for this graph
        result_path = os.path.join(save_dir, f"example_{i}.pt")

        # mcts for l_shapely
        mcts_state_map = MCTS(data.x,
                              data.edge_index,
                              score_func=payoff_func,
                              n_rollout=mcts_args.rollout,
                              min_atoms=mcts_args.min_atoms,
                              c_puct=mcts_args.c_puct,
                              expand_atoms=mcts_args.expand_atoms)

        if os.path.isfile(result_path):
            results = torch.load(result_path)
        else:
            results = mcts_state_map.mcts(verbose=True)
            torch.save(results, result_path)

        # fidelity: drop in score when the coalition nodes are masked out
        graph_node_x = find_closest_node_result(results, max_nodes=max_nodes)
        masked_node_list = [
            node for node in range(graph_node_x.data.x.shape[0])
            if node not in graph_node_x.coalition
        ]
        fidelity_score = original_score - gnn_score(
            masked_node_list,
            data,
            value_func,
            subgraph_building_method='zero_filling')
        sparsity_score = 1 - len(
            graph_node_x.coalition) / graph_node_x.ori_graph.number_of_nodes()
        # store plain floats so torch.tensor() below builds from Python scalars
        fidelity_score_list.append(float(fidelity_score))
        sparsity_score_list.append(sparsity_score)

        # visualization: sentence datasets get token labels, others node features
        plot_kwargs = {'figname': os.path.join(save_dir, f"example_{i}.png")}
        if hasattr(dataset, 'supplement'):
            plot_kwargs['words'] = dataset.supplement['sentence_tokens'][str(i)]
        else:
            plot_kwargs['x'] = graph_node_x.data.x
        plotutils.plot(graph_node_x.ori_graph, graph_node_x.coalition, **plot_kwargs)

    fidelity_scores = torch.tensor(fidelity_score_list)
    sparsity_scores = torch.tensor(sparsity_score_list)
    return fidelity_scores, sparsity_scores
Beispiel #7
0
def pipeline_GC(top_k):
    """Explain graph-classification predictions with PGExplainer and score them.

    The explanation network is trained once on the training split (the whole
    dataset for ``mutag``). Then, for every test graph, an edge mask is
    computed (or loaded from ``save_dir``), top-k fidelity and sparsity are
    measured, and the soft edge mask is plotted. Training and explanation
    wall-clock times are printed.

    Args:
        top_k: number of top-ranked edges used by the metrics and the plot.

    Returns:
        Tuple ``(fidelity_scores, sparsity_scores)`` of 1-D tensors with one
        entry per test graph.
    """
    dataset = get_dataset(data_args.dataset_dir, data_args.dataset_name)
    if data_args.dataset_name == 'mutag':
        data_indices = list(range(len(dataset)))
        pgexplainer_trainset = dataset
    else:
        loader = get_dataloader(dataset,
                                batch_size=train_args.batch_size,
                                random_split_flag=data_args.random_split,
                                data_split_ratio=data_args.data_split_ratio,
                                seed=data_args.seed)
        data_indices = loader['test'].dataset.indices
        pgexplainer_trainset = loader['train'].dataset

    input_dim = dataset.num_node_features
    output_dim = dataset.num_classes
    gnnNets = GnnNets(input_dim, output_dim, model_args)
    checkpoint = torch.load(model_args.model_path)
    gnnNets.update_state_dict(checkpoint['net'])
    gnnNets.to_device()
    gnnNets.eval()

    save_dir = os.path.join(
        './results', f"{data_args.dataset_name}_"
        f"{model_args.model_name}_"
        f"pgexplainer")
    # makedirs also creates ./results when missing (os.mkdir would raise)
    os.makedirs(save_dir, exist_ok=True)

    pgexplainer = PGExplainer(gnnNets)

    # synchronize so the wall-clock timing includes pending CUDA work
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    tic = time.perf_counter()

    pgexplainer.get_explanation_network(pgexplainer_trainset)

    if torch.cuda.is_available():
        torch.cuda.synchronize()
    toc = time.perf_counter()
    training_duration = toc - tic
    print(f"training time is {training_duration: .4}s ")

    explain_duration = 0.0
    plotutils = PlotUtils(dataset_name=data_args.dataset_name)
    fidelity_score_list = []
    sparsity_score_list = []
    for data_idx in tqdm(data_indices):
        data = dataset[data_idx]
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        tic = time.perf_counter()

        prob = pgexplainer.eval_probs(data.x, data.edge_index)
        pred_label = prob.argmax(-1).item()

        # edge masks are cached as numpy arrays, one file per graph; a plain
        # isfile check avoids glob treating '[', '*' or '?' in the path as patterns
        mask_path = os.path.join(save_dir, f"example_{data_idx}.pt")
        if os.path.isfile(mask_path):
            edge_mask = torch.from_numpy(torch.load(mask_path))
        else:
            edge_mask = pgexplainer.explain_edge_mask(data.x, data.edge_index)
            edge_mask = edge_mask.cpu()
            torch.save(edge_mask.detach().numpy(), mask_path)

        if torch.cuda.is_available():
            torch.cuda.synchronize()
        toc = time.perf_counter()
        explain_duration += (toc - tic)

        graph = to_networkx(data)

        fidelity_score = top_k_fidelity(data, edge_mask, top_k, gnnNets,
                                        pred_label)
        sparsity_score = top_k_sparsity(data, edge_mask, top_k)

        fidelity_score_list.append(fidelity_score)
        sparsity_score_list.append(sparsity_score)

        # visualization: sentence datasets additionally get their tokens
        extra_plot_kwargs = {}
        if hasattr(dataset, 'supplement'):
            extra_plot_kwargs['words'] = dataset.supplement['sentence_tokens'][str(data_idx)]
        plotutils.plot_soft_edge_mask(graph,
                                      edge_mask,
                                      top_k,
                                      x=data.x,
                                      un_directed=True,
                                      figname=os.path.join(
                                          save_dir,
                                          f"example_{data_idx}.png"),
                                      **extra_plot_kwargs)

    # previously accumulated but never reported
    print(f"explain time is {explain_duration: .4}s ")

    fidelity_scores = torch.tensor(fidelity_score_list)
    sparsity_scores = torch.tensor(sparsity_score_list)
    return fidelity_scores, sparsity_scores
Beispiel #8
0
def pipeline_NC(top_k):
    """Explain node-classification predictions with PGExplainer and score them.

    The explanation network is trained once on the dataset. Then, for every
    selected test node, the node's subgraph and edge mask are computed (or
    loaded from ``save_dir``), top-k fidelity and sparsity are measured on the
    subgraph, and the soft edge mask is plotted.

    Args:
        top_k: number of top-ranked edges used by the metrics and the plot.

    Returns:
        Tuple ``(fidelity_scores, sparsity_scores)`` of 1-D tensors with one
        entry per explained node.
    """
    dataset = get_dataset(data_args.dataset_dir, data_args.dataset_name)
    input_dim = dataset.num_node_features
    output_dim = dataset.num_classes
    data = dataset[0]
    # test-set nodes whose label is non-zero (class-0 nodes are skipped)
    node_indices = torch.where(data.test_mask * data.y != 0)[0].tolist()

    gnnNets = GnnNets_NC(input_dim, output_dim, model_args)
    checkpoint = torch.load(model_args.model_path)
    gnnNets.update_state_dict(checkpoint['net'])
    gnnNets.to_device()
    gnnNets.eval()

    save_dir = os.path.join(
        './results', f"{data_args.dataset_name}_"
        f"{model_args.model_name}_"
        f"pgexplainer")
    # makedirs also creates ./results when missing (os.mkdir would raise)
    os.makedirs(save_dir, exist_ok=True)

    pgexplainer = PGExplainer(gnnNets)

    # synchronize so the wall-clock timing includes pending CUDA work
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    tic = time.perf_counter()

    pgexplainer.get_explanation_network(dataset, is_graph_classification=False)

    if torch.cuda.is_available():
        torch.cuda.synchronize()
    toc = time.perf_counter()
    training_duration = toc - tic
    print(f"training time is {training_duration}s ")

    duration = 0.0
    fidelity_score_list = []
    sparsity_score_list = []
    plotutils = PlotUtils(dataset_name=data_args.dataset_name)
    for ori_node_idx in tqdm(node_indices):
        tic = time.perf_counter()
        # cache layout: [edge_mask (numpy), x, edge_index, y, subset]
        cache_path = os.path.join(save_dir, f"node_{ori_node_idx}.pt")
        if os.path.isfile(cache_path):
            edge_mask, x, edge_index, y, subset = torch.load(cache_path)
            edge_mask = torch.from_numpy(edge_mask)
        else:
            x, edge_index, y, subset, kwargs = \
                pgexplainer.get_subgraph(node_idx=ori_node_idx, x=data.x, edge_index=data.edge_index, y=data.y)
            edge_mask = pgexplainer.explain_edge_mask(x, edge_index)
            edge_mask = edge_mask.cpu()
            # detach before .numpy(): numpy() raises on grad-tracking tensors
            # (pipeline_GC already does this)
            cache_list = [edge_mask.detach().numpy(), x, edge_index, y, subset]
            torch.save(cache_list, cache_path)

        # position of the explained node inside its extracted subgraph
        node_idx = int(torch.where(subset == ori_node_idx)[0])
        pred_label = pgexplainer.get_node_prediction(node_idx, x, edge_index)

        duration += time.perf_counter() - tic
        sub_data = Data(x=x, edge_index=edge_index, y=y)

        graph = to_networkx(sub_data)

        fidelity_score = top_k_fidelity(sub_data, edge_mask, top_k, gnnNets,
                                        pred_label)
        sparsity_score = top_k_sparsity(sub_data, edge_mask, top_k)

        fidelity_score_list.append(fidelity_score)
        sparsity_score_list.append(sparsity_score)

        # visualization
        plotutils.plot_soft_edge_mask(graph,
                                      edge_mask,
                                      top_k,
                                      y=sub_data.y,
                                      node_idx=node_idx,
                                      un_directed=True,
                                      figname=os.path.join(
                                          save_dir,
                                          f"example_{ori_node_idx}.png"))

    # previously accumulated but never reported
    print(f"explain time is {duration}s ")

    fidelity_scores = torch.tensor(fidelity_score_list)
    sparsity_scores = torch.tensor(sparsity_score_list)
    return fidelity_scores, sparsity_scores
Beispiel #9
0
    if isLocal:
        source_path = '/home/n1no/Documents/ethz/master_thesis/code/project/data/preprocessed_data'
        output_path = '/home/n1no/Documents/ethz/master_thesis/code/project/preprocessing/experiments/train_test_distribution'
    else:
        source_path = '/mnt/data/ninow/preprocessed_data'
        output_path = '/home/ninow/master_thesis/code/project/preprocessing/experiments/train_test_distribution'
    # generate data splits and plot the training/test data distribution
    PreprocessingDataDistribution.plotTrainTestDataDistribution(
        source_path=source_path, output_path=output_path)

# plot the results of a prediction run over all initializations, both as an average over all stations and per station separately
# requires the folder of the model on which the prediction run was made as a run parameter: ---input-source "path_to_model_folder"
elif options.script == 'plotPredictionRun':
    print('Starting to run %s' % options.script)
    PlotUtils.plotPredictionRun(source_path=options.input_source,
                                observation_path=observation_path,
                                n_parallel=n_parallel)

# generates the network training and test results of previously trained network
# requires the folder of the model on which the prediction run was made as a run parameter: ---input-source "path_to_model_folder"
elif options.script == 'plotNetworkTrainingResults':
    print('Starting to run %s' % options.script)
    PlotUtils.plotExperimentErrorEvaluation(options.input_source)

# main preprocessing methods: generate preprocessed data from the raw COSMO-1 data for neural network training
# ATTENTION: the paths of the data files (i.e. OBS, TOPO, COSMO) need to be specified in the preprocessing classes (CreateDataByStation/CreateDataByStationAndInit)
# the run parameter "--preprocessing" can be either "station" or "station_init"
# for "station": a single preprocessed file is generated per station => large memory footprint since all data is held in RAM, but fast data loading
# for "station_init": a preprocessed file is generated per (station, init) => small memory footprint, but slow data loading because each file is loaded separately
# for environments with large memory and data on a mounted partition, "station" should be preferred
# if the data is on the same machine the model runs on, "station_init" can also be fast
Beispiel #10
0
def _collect_feature_distributions(dataset, loader):
    """Flatten every batch of ``loader`` into per-feature lists of values.

    Returns ``(features, time_invariant_grid_features, station_features,
    labels)``. The first three hold one value list per entry of
    ``dataset.parameters``, ``dataset.grid_time_invariant_parameters`` and
    ``dataset.station_parameters``, respectively.
    """
    features = [[] for _ in dataset.parameters]
    time_invariant_grid_features = [
        [] for _ in dataset.grid_time_invariant_parameters
    ]
    station_features = [[] for _ in dataset.station_parameters]
    labels = []
    for batch in loader:
        try:
            # label, cosmo-1 output (+ grid time-invariant channels), station features
            Blabel, Bip2d, StationTimeInv = batch
        except ValueError:
            # when the batch size is small, it could happen that all labels have been
            # corrupted and therefore collate_fn returns an empty list
            print('Value error...')
            continue

        labels.extend(Blabel.numpy().flatten())
        for feature_idx, values in enumerate(features):
            values.extend(Bip2d[:, feature_idx, :, :].numpy().flatten())
        # time-invariant grid channels follow the first n_parameters channels of Bip2d
        for ti_feature_idx, values in enumerate(time_invariant_grid_features):
            values.extend(
                Bip2d[:, dataset.n_parameters +
                      ti_feature_idx, :, :].numpy().flatten())
        for station_feature_idx, values in enumerate(station_features):
            values.extend(
                StationTimeInv[:, station_feature_idx].numpy().flatten())
    return features, time_invariant_grid_features, station_features, labels


def runModel(config, data_dictionary, data_statistics, train_test_folds):
    """Plot train/test feature distributions for the first three CV folds.

    For each run, train and test loaders are built from
    ``train_test_folds[run]``, every feature/label value is collected, and
    ``PlotUtils.plotFeatureDistribution`` writes the plots to
    ``config['experiment_path']``.

    Args:
        config: experiment configuration mapping. It is copied; the caller's
            object is no longer modified by the overrides below.
        data_dictionary: per-station data passed to the dataset.
        data_statistics: statistics used to build the feature scaling functions.
        train_test_folds: sequence of ``(train_fold, test_fold)`` pairs; at
            least ``3`` entries are required.
    """
    experiment_path = config['experiment_path']

    # Fixed settings for this analysis, applied to a copy so the caller's
    # config is not silently changed.
    config = dict(config)
    config['batch_size'] = 1
    config['runs'] = 3
    config['grid_size'] = 9

    # load time invariant features
    # NOTE(review): pickle.load can execute arbitrary code -- trusted files only.
    with open(
            "%s/%s/grid_size_%s/time_invariant_data_per_station.pkl" %
        (config['input_source'], config['preprocessing'],
         config['original_grid_size']), "rb") as input_file:
        time_invariant_data = pkl.load(input_file)

    # initialize feature scaling function for each feature
    featureScaleFunctions = DataUtils.getFeatureScaleFunctions(
        ParamNormalizationDict, data_statistics)

    plot_config = {
        'features': config['input_parameters'],
        'time_invariant_features': config['grid_time_invariant_parameters'],
        'station_features': config['station_parameters']
    }

    # cross validation
    for run in range(config['runs']):
        print('[Run %s] Cross-validation test fold %s' %
              (str(run + 1), str(run + 1)))

        # take the right preprocessed train/test data set for the current run
        train_fold, test_fold = train_test_folds[run]

        # initialize train and test dataloaders
        trainset = DataLoaders.SinglePredictionCosmoData(
            config=config,
            station_data_dict=data_dictionary,
            files=train_fold,
            featureScaling=featureScaleFunctions,
            time_invariant_data=time_invariant_data)
        trainloader = DataLoader(trainset,
                                 batch_size=config['batch_size'],
                                 shuffle=True,
                                 num_workers=config['n_loaders'],
                                 collate_fn=DataLoaders.collate_fn)

        testset = DataLoaders.SinglePredictionCosmoData(
            config=config,
            station_data_dict=data_dictionary,
            files=test_fold,
            featureScaling=featureScaleFunctions,
            time_invariant_data=time_invariant_data)
        testloader = DataLoader(testset,
                                batch_size=config['batch_size'],
                                shuffle=True,
                                num_workers=config['n_loaders'],
                                collate_fn=DataLoaders.collate_fn)

        # loop over the complete train set, then the complete test set
        (train_features, train_time_invariant_grid_features,
         train_station_features,
         train_labels) = _collect_feature_distributions(trainset, trainloader)
        (test_features, test_time_invariant_grid_features,
         test_station_features,
         test_labels) = _collect_feature_distributions(testset, testloader)

        plot_config['run'] = run
        PlotUtils.plotFeatureDistribution(
            output_path=experiment_path,
            config=plot_config,
            train_features=train_features,
            train_time_invariant_grid_features=
            train_time_invariant_grid_features,
            train_station_features=train_station_features,
            train_labels=train_labels,
            test_features=test_features,
            test_time_invariant_grid_features=test_time_invariant_grid_features,
            test_station_features=test_station_features,
            test_labels=test_labels)