def compute(
        self,
        model: Model,
        graph: Graph,
        trajectories: Trajectories,
        pairwise_features: torch.Tensor,
    ):
        """Update the metrics for all trajectories in `trajectories`.

        Resets the metrics with ``self.init_metrics()``, then, without
        gradient tracking, runs ``model`` on every trajectory and passes its
        outputs to ``self.update_metrics``.

        Args:
            model (Model): model to evaluate; called as
                ``model(observations, graph, diffusion_graph, ...)`` and
                unpacked as ``(predictions, _, rw_weights)``.
            graph (Graph): graph the trajectories live on.
            trajectories (Trajectories): dataset to evaluate.
            pairwise_features (torch.Tensor): forwarded to the model as
                ``pairwise_node_features``.
        """
        self.init_metrics()
        config = self.config

        # NOTE(review): `model.eval()` is not called here; presumably the
        # caller switches the model to evaluation mode — confirm.
        with torch.no_grad():
            # The diffusion graph does not depend on the trajectory, so it is
            # built once rather than once per trajectory.
            diffusion_graph = (graph if not config.diffusion_self_loops
                               else graph.add_self_loops())

            for trajectory_idx in tqdm(range(len(trajectories))):
                observations = trajectories[trajectory_idx]

                number_steps = None
                if config.rw_edge_weight_see_number_step or config.rw_expected_steps:
                    if config.use_shortest_path_distance:
                        # shortest-path leg lengths with a 10% margin
                        number_steps = (trajectories.leg_shortest_lengths(
                            trajectory_idx).float() * 1.1).long()
                    else:
                        number_steps = trajectories.leg_lengths(trajectory_idx)

                observed, starts, targets = generate_masks(
                    trajectory_length=observations.shape[0],
                    number_observations=config.number_observations,
                    predict=config.target_prediction,
                    with_interpolation=config.with_interpolation,
                    device=config.device,
                )

                predictions, _, rw_weights = model(
                    observations,
                    graph,
                    diffusion_graph,
                    observed=observed,
                    starts=starts,
                    targets=targets,
                    pairwise_node_features=pairwise_features,
                    number_steps=number_steps,
                )

                self.update_metrics(
                    trajectories,
                    graph,
                    observations,
                    observed,
                    starts,
                    targets,
                    predictions,
                    rw_weights,
                    trajectory_idx,
                    model.rw_non_backtracking,
                )
Example #2
0
def load_data() -> None:
    """Load the bivec model, feature function and trajectories into module globals.

    Side effects: rebinds the module-level globals ``model``, ``phi``,
    ``trajectories`` and ``get_actions``. ``regressor`` is declared global
    but is not assigned here.

    NOTE(review): ``model_file`` and ``trajectory_file`` are not defined in
    this function; presumably they come from ``remote`` or module scope —
    confirm.
    NOTE(review): another ``load_data`` is defined later in this file; if both
    live in the same module, the later definition shadows this one.
    """
    from features import Model
    from trajectories import Trajectories
    # FIXME: `from ... import *` is only allowed at module level in Python 3
    # (SyntaxError inside a function); import the needed names explicitly.
    from remote import *  # FIXME
    global model, phi, trajectories, get_actions, regressor
    model = Model(model_file)  # bivec model
    phi = model.paragraph_vector  # feature function: State -> vector
    trajectories = Trajectories(trajectory_file)
    get_actions = trajectories.get_actions
def compute_loss(
    typ: str,
    trajectories: Trajectories,
    observations: torch.Tensor,
    predictions: torch.Tensor,
    starts: torch.Tensor,
    targets: torch.Tensor,
    rw_weights: torch.Tensor,
    trajectory_idx: int,
):
    """Compute the training loss.

    Args:
        typ (str): loss flag from configuration, one of "RMSE", "dot_loss",
            "log_dot_loss", "target_only" or "nll_loss"
        trajectories (Trajectories): full trajectories dataset evaluated
        observations (torch.Tensor): current trajectory observation [traj_length, n_node]
        predictions (torch.Tensor): output prediction of the model [n_pred, n_node]
        starts (torch.Tensor): indexes of starts extrapolation in observations [n_pred,]
        targets (torch.Tensor): indexes of targets extrapolation in observations [n_pred,]
        rw_weights (torch.Tensor): random walk weights output of model [n_pred, n_edge]
        trajectory_idx (int): index of evaluated trajectory

    Returns:
        torch.Tensor: scalar loss for this prediction, summed (not averaged)
        over the predictions.

    Raises:
        ValueError: if ``typ`` is not a known loss flag.
    """
    target_obs = observations[targets]

    if typ == "RMSE":
        # Despite the flag name this is the *sum* of squared errors
        # (no mean, no square root).
        return ((predictions - target_obs)**2).sum()
    if typ == "dot_loss":
        return -1.0 * (predictions * target_obs).sum()
    if typ == "log_dot_loss":
        # 1e-30 keeps log() finite when a prediction has no mass on the target.
        return -1.0 * ((predictions * target_obs).sum(dim=1) +
                       1e-30).log().sum()
    if typ == "target_only":
        return -predictions[target_obs > 0].sum()
    if typ == "nll_loss":
        loss = torch.tensor(0.0, device=trajectories.device)
        # Negative log of the edge weights; 1e-20 avoids log(0).
        neg_log_rw_weights = -(rw_weights + 1e-20).log()
        for pred_id in range(len(starts)):
            # Sum the NLL of every edge traversed between start and target.
            for jump_id in range(starts[pred_id], targets[pred_id]):
                traversed_edges = trajectories.traversed_edges(
                    trajectory_idx, jump_id)
                loss += neg_log_rw_weights[pred_id, traversed_edges].sum()
        return loss
    raise ValueError(f'Unknown loss "{typ}"')
def train_epoch(
    model: Model,
    graph: Graph,
    optimizer: torch.optim.Optimizer,
    config: Config,
    train_trajectories: Trajectories,
    pairwise_node_features: torch.Tensor,
):
    """Run one epoch of training.

    Trajectories are optionally shuffled (``config.shuffle_samples``) and
    processed in batches of ``config.batch_size``; a trailing partial batch
    is dropped. Each trajectory's loss is divided by its number of
    predictions, the batch loss is the mean over the batch, and one
    optimizer step is taken per batch. Progress is printed roughly
    ``config.print_per_epoch`` times per epoch.

    Args:
        model (Model): model being trained (switched to train mode).
        graph (Graph): graph the trajectories live on.
        optimizer (torch.optim.Optimizer): optimizer stepping the model.
        config (Config): training configuration.
        train_trajectories (Trajectories): training dataset.
        pairwise_node_features (torch.Tensor): forwarded to the model.
    """
    model.train()

    print_cum_loss = 0.0
    print_num_preds = 0
    print_time = time.time()
    n_batches = len(train_trajectories) // config.batch_size
    # max(1, ...) avoids a ZeroDivisionError in the progress check below
    # when there are fewer batches than config.print_per_epoch.
    print_every = max(1, n_batches // config.print_per_epoch)

    trajectories_shuffle_indices = np.arange(len(train_trajectories))
    if config.shuffle_samples:
        np.random.shuffle(trajectories_shuffle_indices)

    # The diffusion graph does not depend on the trajectory; build it once.
    diffusion_graph = (graph if not config.diffusion_self_loops
                       else graph.add_self_loops())

    for iteration, batch_start in enumerate(
            range(0,
                  len(trajectories_shuffle_indices) - config.batch_size + 1,
                  config.batch_size)):
        optimizer.zero_grad()
        loss = torch.tensor(0.0, device=config.device)

        for i in range(batch_start, batch_start + config.batch_size):
            trajectory_idx = trajectories_shuffle_indices[i]
            observations = train_trajectories[trajectory_idx]

            number_steps = None
            if config.rw_edge_weight_see_number_step or config.rw_expected_steps:
                if config.use_shortest_path_distance:
                    # shortest-path leg lengths with a 10% margin
                    number_steps = (train_trajectories.leg_shortest_lengths(
                        trajectory_idx).float() * 1.1).long()
                else:
                    number_steps = train_trajectories.leg_lengths(
                        trajectory_idx)

            observed, starts, targets = generate_masks(
                trajectory_length=observations.shape[0],
                number_observations=config.number_observations,
                predict=config.target_prediction,
                with_interpolation=config.with_interpolation,
                device=config.device,
            )

            predictions, _potentials, rw_weights = model(
                observations,
                graph,
                diffusion_graph,
                observed=observed,
                starts=starts,
                targets=targets,
                pairwise_node_features=pairwise_node_features,
                number_steps=number_steps,
            )

            n_preds = starts.shape[0]
            print_num_preds += n_preds

            # Per-prediction loss of this trajectory.
            trajectory_loss = compute_loss(
                config.loss,
                train_trajectories,
                observations,
                predictions,
                starts,
                targets,
                rw_weights,
                trajectory_idx,
            ) / n_preds
            loss += trajectory_loss

        loss /= config.batch_size
        print_cum_loss += loss.item()
        loss.backward()
        optimizer.step()

        if (iteration + 1) % print_every == 0:
            # Each batch loss is already a per-prediction average, so only
            # average over the batches accumulated since the last print.
            print_loss = print_cum_loss / print_every
            pred_per_second = 1.0 * print_num_preds / \
                (time.time() - print_time)

            print_cum_loss = 0.0
            print_num_preds = 0
            print_time = time.time()

            progress_percent = int(100.0 * ((iteration + 1) // print_every) /
                                   config.print_per_epoch)

            print(
                f"Progress {progress_percent}% | iter {iteration} | {pred_per_second:.1f} pred/s | loss {config.loss} {print_loss}"
            )
def load_data(
    config: Config
) -> Tuple[Graph, List[Trajectories], Optional[torch.Tensor],
           Optional[torch.Tensor]]:
    """Read data in config.workspace / config.input_directory.

    Reads nodes.txt, edges.txt, lengths.txt, observations.txt, paths.txt,
    pairwise_node_features.pt and shortest-path-distance-matrix.pt. Only
    trajectories whose length lies in
    [config.min_trajectory_length, config.max_trajectory_length] are kept.
    They are split in index order into train/valid/test according to
    ``config.train_test_ratio`` ("train/test" or "train/valid/test"
    proportions). With ``config.overfit1``, a single trajectory of length
    ``config.number_observations + 1`` is used for all three splits and
    ``config.batch_size`` is set to 1.

    Args:
        config (Config): configuration (``batch_size`` may be modified)

    Returns:
        (Graph, (Trajectories, Trajectories, Trajectories), torch.Tensor, torch.Tensor):
            graph, (train, valid, test)_trajectories, pairwise_node_features, pairwise_distances

    Raises:
        ValueError: if ``config.train_test_ratio`` does not have 2 or 3
            "/"-separated parts, or if overfit1 finds no trajectory of the
            required length.
    """

    input_dir = os.path.join(config.workspace, config.input_directory)

    graph = Graph.read_from_files(
        nodes_filename=os.path.join(input_dir, "nodes.txt"),
        edges_filename=os.path.join(input_dir, "edges.txt"),
    )

    trajectories = Trajectories.read_from_files(
        lengths_filename=os.path.join(input_dir, "lengths.txt"),
        observations_filename=os.path.join(input_dir, "observations.txt"),
        paths_filename=os.path.join(input_dir, "paths.txt"),
        num_nodes=graph.n_node,
    )

    pairwise_node_features = load_tensor(config.device, input_dir,
                                         "pairwise_node_features.pt")
    pairwise_distances = load_tensor(config.device, input_dir,
                                     "shortest-path-distance-matrix.pt")

    trajectories.pairwise_node_distances = pairwise_distances

    if config.extract_coord_features:
        print("Node coordinates are removed from node features")
        graph.extract_coords_from_features(keep_in_features=False)

    valid_trajectories_mask = trajectories.lengths >= config.min_trajectory_length
    valid_trajectories_mask &= trajectories.lengths <= config.max_trajectory_length
    valid_trajectories_idx = valid_trajectories_mask.nonzero()[:, 0]

    trajectories = trajectories.to(config.device)

    if config.overfit1:
        config.batch_size = 1
        candidates = (trajectories.lengths ==
                      config.number_observations + 1).nonzero()
        if len(candidates) == 0:
            raise ValueError(
                f"overfit1: no trajectory of length "
                f"{config.number_observations + 1}")
        id_ = candidates[0]
        print(
            f"Overfit on trajectory {id_.item()} of length {trajectories.lengths[id_].item()}"
        )
        train_mask = torch.zeros_like(valid_trajectories_mask)
        train_mask[id_] = 1
        test_mask = valid_mask = train_mask

    else:
        proportions = [float(p) for p in config.train_test_ratio.split("/")]
        if len(proportions) == 2:
            train_prop, test_prop = proportions
            valid_prop = 0.0
        elif len(proportions) == 3:
            train_prop, valid_prop, test_prop = proportions
        else:
            raise ValueError(
                f'train_test_ratio must be "train/test" or "train/valid/test", '
                f'got "{config.train_test_ratio}"')

        n_candidates = len(valid_trajectories_idx)
        n_train = int(train_prop * n_candidates)
        n_valid = int(valid_prop * n_candidates)
        n_test = int(test_prop * n_candidates)

        train_idx = valid_trajectories_idx[:n_train]
        train_mask = torch.zeros_like(valid_trajectories_mask)
        train_mask[train_idx] = 1

        valid_idx = valid_trajectories_idx[n_train:n_train + n_valid]
        valid_mask = torch.zeros_like(valid_trajectories_mask)
        valid_mask[valid_idx] = 1

        test_idx = valid_trajectories_idx[n_train + n_valid:n_train + n_valid +
                                          n_test]
        test_mask = torch.zeros_like(valid_trajectories_mask)
        test_mask[test_idx] = 1

    train_trajectories = trajectories.with_mask(train_mask)
    valid_trajectories = trajectories.with_mask(valid_mask)
    test_trajectories = trajectories.with_mask(test_mask)
    trajectories = (train_trajectories, valid_trajectories, test_trajectories)

    return (graph, trajectories, pairwise_node_features, pairwise_distances)
Example #6
0
def main(cfg):
    """Evaluate a saved model on the test set and produce analysis artifacts.

    Always loads the test data (binarizing ``Y_test`` at 0.5 in place), loads
    the model from ``cfg["model_dir"]`` and predicts probabilities. The rest
    is optional:

    * ``cfg["write_out"]``: pickle ``X_test`` / ``Y_test`` into
      ``bm_config.output_dir``.
    * ``cfg["cluster"]``: project BiLSTM hidden states to 2-D with t-SNE or
      Isomap and save a scatter plot colored by TN/FN/FP/TP. When the
      reducer can project new points, also save a GIF of the per-timestep
      projections.
    * ``cfg["confusion_matrix"]``: compute PPV, NPV and ROC AUC and save a
      confusion-matrix plot.

    Args:
        cfg (dict): configuration. Keys read here: "model_dir", "task",
            "num_for_analysis", "write_out", "cluster", "reducer",
            "n_neighbors", "plot_every_nth", "plot_last_n",
            "confusion_matrix".
    """
    model_dir = os.path.abspath(cfg["model_dir"])
    X_test, Y_test = get_data(cfg)
    print(f"Data loaded. X_test shape: {X_test.shape}, Y_test shape: "
          f"{Y_test.shape}")
    # Binarize outcome if need be
    Y_test[Y_test >= 0.5] = 1
    Y_test[Y_test < 0.5] = 0

    model = load_model(model_dir)
    model.summary()
    print("Model loaded")

    # "dpsom" models take the raw tensor; the others take two inputs built
    # from features 7: and :7 (the same split is reused for the BiLSTM below).
    if cfg["task"].startswith("dpsom"):
        model_inputs = X_test
    else:
        model_inputs = [X_test[:, :, 7:], X_test[:, :, :7]]
    probas_test = model.predict(model_inputs)

    # Subject masks: predicted / actual class 0 ("a") vs 1 ("d", death),
    # combined into the four confusion-matrix cells.
    ix_pred_a = (probas_test < 0.5).flatten()
    ix_pred_d = (probas_test >= 0.5).flatten()
    ix_a = (Y_test == 0).flatten()
    ix_d = (Y_test == 1).flatten()
    ix_tn = ix_a & ix_pred_a
    ix_fp = ix_a & ix_pred_d
    ix_fn = ix_d & ix_pred_a
    ix_tp = ix_d & ix_pred_d
    X_anl, Y_anl = get_analysis_subsets(X_test, Y_test,
                                        cfg["num_for_analysis"])

    if cfg["write_out"]:
        # Note, data are *right-padded*, i.e. padded with zeros to the right
        # if there < 200 actual data samples
        # Y_test is {0,1}, 1 = death, about 12% mortality
        with open(pj(bm_config.output_dir, "X_test.pkl"), "wb") as f:
            pickle.dump(X_test, f)
        with open(pj(bm_config.output_dir, "Y_test.pkl"), "wb") as f:
            pickle.dump(Y_test, f)

    if cfg["cluster"]:
        bilstm_name = "bilstm_2"
        bilstm_layer = model.get_layer(bilstm_name)
        # NOTE(review): setting return_sequences on an already-built Keras
        # layer may not change its output; the indexing below assumes a 3-D
        # (samples, timesteps, features) output — confirm.
        bilstm_layer.return_sequences = True
        bilstm_model = Model(inputs=model.input, outputs=bilstm_layer.output)
        bilstm_seqs = bilstm_model.predict(model_inputs)
        print("Shape of BiLSTM output:", bilstm_seqs.shape)
        # Reverse time for features 64: so both halves are time-aligned;
        # presumably these are the backward-direction units of a 64-unit
        # BiLSTM — confirm against the model definition.
        bilstm_seqs = np.concatenate(
            [bilstm_seqs[:, :, :64], bilstm_seqs[:, ::-1, 64:]], axis=2)

        reducer = cfg["reducer"]
        if reducer == "tsne":
            reducer_model = TSNE(n_components=2)
        elif reducer == "isomap":
            reducer_model = Isomap(n_components=2,
                                   n_neighbors=cfg["n_neighbors"])
        else:
            raise NotImplementedError(reducer)
        # Hidden state at the last time step (not probabilities despite the name).
        probas_out = bilstm_seqs[:, -1, :]
        print("Shape of final probas matrix:", probas_out.shape)
        print(f"Fitting {reducer} model...")
        proj_X = reducer_model.fit_transform(probas_out)
        # Should really be training tsne with training data but oh well
        print("...Done")

        plt.figure(figsize=(16, 16))
        plt.scatter(proj_X[ix_tn, 0], proj_X[ix_tn, 1], s=12, c="r")
        plt.scatter(proj_X[ix_fn, 0], proj_X[ix_fn, 1], s=12, c="g")
        plt.scatter(proj_X[ix_fp, 0], proj_X[ix_fp, 1], s=12, c="y")
        plt.scatter(proj_X[ix_tp, 0], proj_X[ix_tp, 1], s=12, c="b")
        plt.savefig(pj(model_dir, f"{reducer}.png"))
        plt.close()

        inc = cfg["plot_every_nth"]
        slices_dir = pj(model_dir, f"{reducer}_slices")
        if not pe(slices_dir):
            os.makedirs(slices_dir)
        seq_len = bilstm_seqs.shape[1]
        # Clamp at 0 so that asking for more steps than exist plots the whole
        # sequence instead of producing a negative slice start.
        start_idx = max(0, seq_len - cfg["plot_last_n"])

        if not hasattr(reducer_model, "transform"):
            # sklearn's TSNE only offers fit_transform, so per-timestep
            # slices cannot be projected into the fitted embedding.
            print(f"{reducer} cannot project new points; "
                  "skipping sequence projections")
        else:
            bilstm_seqs = bilstm_seqs[::inc, start_idx:]
            n_steps = bilstm_seqs.shape[1]
            print("Creating sequence projections...")
            data_mat = np.zeros((bilstm_seqs.shape[0], n_steps, 2))
            for j in range(n_steps):
                data_mat[:, j, :] = reducer_model.transform(bilstm_seqs[:, j, :])
            print("...Done")
            # color -> (subject mask, marker size)
            color_d = {
                "r": (ix_tn[::inc], 12),
                "g": (ix_fn[::inc], 24),
                "y": (ix_fp[::inc], 12),
                "b": (ix_tp[::inc], 24)
            }
            trajectories = Trajectories(data_mat,
                                        color_dict=color_d,
                                        final_extra=20)
            trajectories.save(pj(model_dir, f"{reducer}_{len(data_mat)}.gif"))
            plt.show()

    # Uses all subjects
    if cfg["confusion_matrix"]:
        print(f"X_test shape: {X_test.shape}, Y_test shape: {Y_test.shape}")
        print(f"Inferred probabilities, output shape {probas_test.shape}")

        fpr_mort, tpr_mort, thresholds = roc_curve(Y_test, probas_test)
        roc_auc_mort = auc(fpr_mort, tpr_mort)
        TN, FP, FN, TP = confusion_matrix(Y_test, probas_test.round()).ravel()
        PPV = TP / (TP + FP)
        NPV = TN / (TN + FN)

        cm = np.array([[TN, FP], [FN, TP]])
        save_path = pj(cfg["model_dir"], "confusion_matrix.png")
        classes = ["False", "True"]
        plot_confusion_matrix(cm,
                              save_path,
                              classes,
                              normalize=False,
                              title='Confusion matrix')

        print("Inference:")
        # Both literals must be f-strings, otherwise the AUC placeholder is
        # printed verbatim.
        print(f"PPV: {PPV:0.4f}, NPV: {NPV:0.4f}, roc_auc: "
              f"{roc_auc_mort:0.4f}")
    def update_metrics(
        self,
        trajectories: Trajectories,
        graph: Graph,
        observations,
        observed,
        starts,
        targets,
        predictions,
        rw_weights,
        trajectory_idx,
        rw_non_backtracking,
    ):
        """Accumulate evaluation metrics for the predictions on one trajectory.

        Updates the ``self.metrics`` entries "target_probability",
        "precision_top1", "precision_top5", "choice_accuracy",
        "choice_accuracy_deg3", "path_nll" and "path_nll_deg3". When
        ``self.config.dataset == "wikispeedia"``, it also updates per-jump-length
        metrics for "precision_top1", "precision_top5", "target_rank" and
        "target2_acc".

        Args:
            trajectories (Trajectories): dataset containing the trajectory;
                must have traversed edges.
            graph (Graph): graph of the random walk; columns of ``rw_weights``
                beyond ``graph.n_edge`` (added self loops) are dropped.
            observations: observations of the trajectory, indexed by
                ``starts`` / ``targets``.
            observed: observation mask (not used here).
            starts: indexes in ``observations`` where each prediction starts.
            targets: indexes in ``observations`` of each prediction's target.
            predictions: model output, one row per prediction.
            rw_weights: random-walk edge weights, one row per prediction.
            trajectory_idx: index of the trajectory in ``trajectories``.
            rw_non_backtracking (bool): if true, each choice after the first
                depends on the previously traversed edge.
        """
        n_pred = len(starts)
        # remove added self loops
        rw_weights = rw_weights[:, :graph.n_edge]

        target_distributions = observations[targets]

        target_probabilities = compute_target_probability(
            target_distributions, predictions)
        self.metrics["target_probability"].add_all(target_probabilities)

        top1_contains_target = compute_topk_contains_target(
            target_distributions, predictions, k=1)
        self.metrics["precision_top1"].add_all(top1_contains_target)
        top5_contains_target = compute_topk_contains_target(
            target_distributions, predictions, k=5)
        self.metrics["precision_top5"].add_all(top5_contains_target)

        assert trajectories.has_traversed_edges
        noise_level = 1e-6  # very small noise is added to break the uniform cases

        # [n_pred, n_node]: index of the highest-weight outgoing edge of each
        # node (the greedy choice at that node).
        _, chosen_edge_at_each_node = scatter_max(
            rw_weights + torch.rand_like(rw_weights) * noise_level,
            graph.senders,
            fill_value=-1)
        if rw_non_backtracking:
            nb_rw_graph = graph.update(edges=rw_weights.transpose(
                0, 1)).non_backtracking_random_walk_graph
            # [n_edge, n_pred]
            _, chosen_hyperedge_at_each_edge = scatter_max(
                nb_rw_graph.edges +
                torch.rand_like(nb_rw_graph.edges) * noise_level,
                nb_rw_graph.senders,
                dim=0,
                fill_value=-1000,
            )
            chosen_edge_at_each_edge = nb_rw_graph.receivers[
                chosen_hyperedge_at_each_edge]
            # [n_pred, n_edge]
            chosen_edge_at_each_edge = chosen_edge_at_each_edge.transpose(0, 1)

        for pred_id in range(n_pred):
            # concat all edges traversed between start and target
            traversed_edges = torch.cat([
                trajectories.traversed_edges(trajectory_idx, i)
                for i in range(starts[pred_id], targets[pred_id])
            ])
            # Remove consecutive duplicates. A uint8 mask inverted with `~`
            # does not work here: on uint8 `~` is bitwise NOT (0 -> 255).
            traversed_edges = torch.unique_consecutive(traversed_edges)

            nodes_where_decide = graph.senders[traversed_edges]

            # --- choice accuracy ---
            if rw_non_backtracking:
                # The first step is decided at the start node; later steps
                # depend on the previously traversed edge.
                chosen_edges = torch.zeros_like(traversed_edges,
                                                dtype=torch.long)
                first_node = nodes_where_decide[0]
                chosen_edges[0] = chosen_edge_at_each_node[pred_id, first_node]
                chosen_edges[1:] = chosen_edge_at_each_edge[
                    pred_id, traversed_edges[:-1]]
            else:
                chosen_edges = chosen_edge_at_each_node[pred_id,
                                                        nodes_where_decide]

            correct_choices = (traversed_edges == chosen_edges).float()
            self.metrics["choice_accuracy"].add_all(correct_choices)

            # Keep only decisions at nodes with out-degree > 2; the first
            # decision is always kept.
            deg3_mask = graph.out_degree_counts[nodes_where_decide] > 2
            deg3_mask[0] = 1
            self.metrics["choice_accuracy_deg3"].add_all(
                correct_choices[deg3_mask])

            # --- NLL computation ---
            if not rw_non_backtracking:
                traversed_edges_weights = rw_weights[pred_id, traversed_edges]
            else:
                rw_graph = graph.update(edges=rw_weights[pred_id])
                nb_rw_graph = rw_graph.non_backtracking_random_walk_graph

                # Allocate alongside rw_weights instead of on the default CPU
                # device, so it matches the device of the masks indexing it
                # (presumably graph and rw_weights share a device).
                traversed_edges_weights = torch.zeros(
                    len(traversed_edges),
                    dtype=rw_weights.dtype,
                    device=rw_weights.device)
                traversed_edges_weights[0] = rw_weights[pred_id,
                                                        traversed_edges[0]]
                for i, (s, r) in enumerate(
                        zip(traversed_edges[:-1], traversed_edges[1:])):
                    traversed_edges_weights[i + 1] = nb_rw_graph.edge(s, r)

            neg_log_weights = -(traversed_edges_weights + 1e-20).log()
            self.metrics["path_nll"].add(neg_log_weights.sum().item())
            # Same degree mask as above (senders of the traversed edges).
            self.metrics["path_nll_deg3"].add(
                neg_log_weights[deg3_mask].sum().item())

        if self.config.dataset == "wikispeedia":
            jump_lengths = targets - starts
            # top k by jump
            self.update_metrics_by_keys("precision_top1", jump_lengths,
                                        top1_contains_target)
            self.update_metrics_by_keys("precision_top5", jump_lengths,
                                        top5_contains_target)
            # cumulative reciprocal rank
            # assumes only one target per observations
            target_nodes = observations[targets].nonzero()[:, 1]
            target_ranks = compute_rank(predictions, target_nodes)
            self.update_metrics_by_keys(
                "target_rank", jump_lengths,
                self.harmonic_numbers[target_ranks - 1])
            # West target accuracy
            start_nodes = observations[starts].nonzero()[:, 1]
            target2_acc = target2_accuracy(
                start_nodes,
                target_nodes,
                predictions,
                self.given_as_target,
                trajectories.pairwise_node_distances,
            )

            self.update_metrics_by_keys("target2_acc", jump_lengths,
                                        target2_acc)