def test_remove_isolated_nodes():
    """Check RemoveIsolatedNodes on a homogeneous graph that has self-loops."""
    assert repr(RemoveIsolatedNodes()) == 'RemoveIsolatedNodes()'

    # Node 1 appears only in the self-loop (1, 1); the expected output below
    # drops it together with that edge and relabels node 2 as node 1.
    graph = Data(
        x=torch.tensor([[1], [2], [3]]),
        edge_index=torch.tensor([[0, 2, 1, 0], [2, 0, 1, 0]]),
        edge_attr=torch.tensor([1, 2, 3, 4]),
    )

    pruned = RemoveIsolatedNodes()(graph)

    assert len(pruned) == 3
    assert pruned.x.tolist() == [[1], [3]]
    assert pruned.edge_attr.tolist() == [1, 2, 4]
    assert pruned.edge_index.tolist() == [[0, 1, 0], [1, 0, 0]]
Пример #2
0
def test_remove_isolated_nodes():
    """Check RemoveIsolatedNodes on a three-node graph whose node 1 has no edges."""
    assert str(RemoveIsolatedNodes()) == 'RemoveIsolatedNodes()'

    # Only nodes 0 and 2 are connected; node 1 is expected to be dropped and
    # node 2 relabelled to 1.
    graph = Data(
        x=torch.arange(3),
        edge_index=torch.tensor([[0, 2], [2, 0]]),
        edge_attr=torch.arange(2),
    )

    pruned = RemoveIsolatedNodes()(graph)

    assert len(pruned) == 3
    expected = {
        'x': [0, 2],
        'edge_index': [[0, 1], [1, 0]],
        'edge_attr': [0, 1],
    }
    for key, value in expected.items():
        assert pruned[key].tolist() == value
Пример #3
0
def test_remove_isolated_nodes_in_hetero_data():
    """Check RemoveIsolatedNodes on a heterogeneous graph with three node types."""
    hetero = HeteroData()

    # Papers ('p') and authors ('a') get feature tensors; institutions ('i')
    # only get a node count.
    hetero['p'].x = torch.arange(6)
    hetero['a'].x = torch.arange(6)
    hetero['i'].num_nodes = 4

    # Isolated nodes given the edges below:
    #   paper:       {4}
    #   author:      {3, 4, 5}
    #   institution: {0, 1, 2, 3}
    hetero['p', '1', 'p'].edge_index = torch.tensor([[0, 1, 2], [0, 1, 3]])
    hetero['p', '2', 'a'].edge_index = torch.tensor([[1, 3, 5], [0, 1, 2]])
    hetero['p', '2', 'a'].edge_attr = torch.arange(3)
    hetero['p', '3', 'a'].edge_index = torch.tensor([[5], [2]])

    pruned = RemoveIsolatedNodes()(hetero)

    assert len(pruned) == 4

    # Remaining node counts per node type.
    for node_type, remaining in {'p': 5, 'a': 3, 'i': 0}.items():
        assert pruned[node_type].num_nodes == remaining

    # Surviving node features.
    assert pruned['p'].x.tolist() == [0, 1, 2, 3, 5]
    assert pruned['a'].x.tolist() == [0, 1, 2]

    # Edge indices after relabelling (paper 5 becomes paper 4).
    expected_edges = {
        '1': [[0, 1, 2], [0, 1, 3]],
        '2': [[1, 3, 4], [0, 1, 2]],
    }
    for relation, edges in expected_edges.items():
        assert pruned[relation].edge_index.tolist() == edges
    assert pruned['2'].edge_attr.tolist() == [0, 1, 2]
    assert pruned['3'].edge_index.tolist() == [[4], [2]]
Пример #4
0
        help="Multiplicative learning rate factor. Defaults to 0.5.")

    parser.add_argument(
        '--sc_type',
        action="store",
        type=str,
        default="last",
        choices=["first", "last"],
        help="How to apply skip-connections for the ADD model. Choices are:"
        "['first', 'last']. Defaults to 'last'.")

    args = parser.parse_args()
    return args


# Module-level RemoveIsolatedNodes transform instance, created once at import.
remove_isolated_nodes = RemoveIsolatedNodes()


def concat_x_pos(data: Data) -> Data:
    """Fold ``data.pos`` into ``data.x`` and give ``edge_attr`` a trailing axis.

    The object is modified in place and also returned: ``pos`` is appended to
    ``x`` along the last dimension and then deleted, and ``edge_attr`` gains a
    new dimension at index 1.
    """
    features_with_pos = torch.cat((data.x, data.pos), dim=-1)
    data.x = features_with_pos

    data.edge_attr = data.edge_attr.unsqueeze(1)

    # Isolated-node removal is intentionally not applied here.
    del data.pos
    return data


# Cross-entropy loss instance shared at module level.
multicls_criterion = torch.nn.CrossEntropyLoss()


class Evaluator(object):
    r""" Minimal Evaluator Class as implemented in OGBG to use here. """
Пример #5
0
def _make_data_loader(data_source, args, pin_memory, shuffle):
    """Wrap ``data_source`` in a DataLoader using the batching options in ``args``.

    Only ``shuffle`` differs between the training and evaluation loaders;
    batch size, worker count and memory pinning are shared.
    """
    return DataLoader(data_source,
                      batch_size=args.batch_size,
                      drop_last=False,
                      shuffle=shuffle,
                      num_workers=args.nworkers,
                      pin_memory=pin_memory)


def main():
    """Train and evaluate a quaternion or PHM message-passing model on ogbg-molpcba.

    Steps performed:

    1. Parse the command line and convert boolean-like string flags.
    2. Create the output directory below ``pcba/`` and dump the parsed
       arguments to ``params.json`` inside it.
    3. Load the dataset, optionally apply ``RemoveIsolatedNodes`` to every
       graph in memory (when the module-level ``PRE_TRAFO`` flag is set) and
       build the train/valid/test loaders.
    4. Instantiate the model selected by ``args.type``.
    5. Call ``do_run`` ``args.n_runs`` times and log mean and standard
       deviation of the returned test and validation metrics.

    Raises:
        NotImplementedError: if a quaternion model is requested together with
            PNA aggregation for messages or nodes.
        ModuleNotFoundError: if ``args.type`` does not name a known model.
    """
    args = get_parser()
    # Several flags arrive as boolean-like strings and are converted here.
    naive_encoder = not str2bool(args.full_encoder)
    pin_memory = str2bool(args.pin_memory)
    use_bias = str2bool(args.bias)  # parsed, but not used further in this function
    downstream_bn = str(args.d_bn)
    same_dropout = str2bool(args.same_dropout)
    mlp_mp = str2bool(args.mlp_mp)

    phm_dim = args.phm_dim
    learn_phm = str2bool(args.learn_phm)

    # All results live below "pcba/"; prefix the save directory if needed.
    base_dir = "pcba/"
    os.makedirs(base_dir, exist_ok=True)
    if base_dir not in args.save_dir:
        args.save_dir = os.path.join(base_dir, args.save_dir)
    os.makedirs(args.save_dir, exist_ok=True)

    set_logging(save_dir=args.save_dir)
    logging.info(f"Creating log directory at {args.save_dir}.")
    with open(os.path.join(args.save_dir, "params.json"), 'w') as fp:
        json.dump(vars(args), fp)

    # Layer sizes and dropout rates are given as comma-separated strings.
    mp_layers = [int(item) for item in args.mp_units.split(',')]
    downstream_layers = [int(item) for item in args.d_units.split(',')]
    mp_dropout = [float(item) for item in args.dropout_mpnn.split(',')]
    dn_dropout = [float(item) for item in args.dropout_dn.split(',')]
    logging.info(
        f'Initialising model with {mp_layers} hidden units with dropout {mp_dropout} '
        f'and downstream units: {downstream_layers} with dropout {dn_dropout}.'
    )

    if args.pooling == "globalsum":
        logging.info("Using GlobalSum Pooling")
    else:
        logging.info("Using SoftAttention Pooling")

    logging.info(
        f"Using Adam optimizer with weight_decay ({args.weightdecay}) and regularization "
        f"norm ({args.regularization})")
    logging.info(
        f"Weight init: {args.w_init} \n Contribution init: {args.c_init}")

    # data
    dname = "ogbg-molpcba"
    transform = RemoveIsolatedNodes()
    # The transform is not passed as pre_transform to the dataset; it is
    # applied in memory below (PRE_TRAFO) and also handed to do_run.
    dataset = PygGraphPropPredDataset(name=dname, root="dataset")
    evaluator = Evaluator(name=dname)
    split_idx = dataset.get_idx_split()
    train_data = dataset[split_idx["train"]]
    valid_data = dataset[split_idx["valid"]]
    test_data = dataset[split_idx["test"]]

    if PRE_TRAFO:
        # Transform every graph once up front instead of on every batch.
        logging.info(
            "Pre-transforming graphs, to overcome computation in batching.")
        train_source = [transform(graph) for graph in train_data]
        valid_source = [transform(graph) for graph in valid_data]
        test_source = [transform(graph) for graph in test_data]
        logging.info("finised. Initiliasing dataloaders")
    else:
        train_source, valid_source, test_source = (train_data, valid_data,
                                                   test_data)

    train_loader = _make_data_loader(train_source, args, pin_memory,
                                     shuffle=True)
    valid_loader = _make_data_loader(valid_source, args, pin_memory,
                                     shuffle=False)
    test_loader = _make_data_loader(test_source, args, pin_memory,
                                    shuffle=False)

    device = f'cuda:{args.device}' if torch.cuda.is_available() else 'cpu'
    FULL_ATOM_FEATURE_DIMS = get_atom_feature_dims()
    FULL_BOND_FEATURE_DIMS = get_bond_feature_dims()

    # for hypercomplex model
    unique_phm = str2bool(args.unique_phm)
    if unique_phm:
        phm_rule = get_multiplication_matrices(phm_dim=args.phm_dim,
                                               type="phm")
        phm_rule = torch.nn.ParameterList(
            [torch.nn.Parameter(a, requires_grad=learn_phm) for a in phm_rule])
    else:
        phm_rule = None

    uses_pna = args.aggr_msg == "pna" or args.aggr_node == "pna"
    if uses_pna:
        # Compute the in-degree histogram over the training data.
        # NOTE(review): the histogram has 6 bins; a graph with an in-degree
        # above 5 would make bincount return a longer tensor and the in-place
        # add fail - confirm that cannot happen for this dataset.
        deg = torch.zeros(6, dtype=torch.long)
        for data in dataset[split_idx['train']]:
            d = degree(data.edge_index[1],
                       num_nodes=data.num_nodes,
                       dtype=torch.long)
            deg += torch.bincount(d, minlength=deg.numel())
    else:
        deg = None

    aggr_kwargs = {
        "aggregators": ['mean', 'min', 'max', 'std'],
        "scalers": ['identity', 'amplification', 'attenuation'],
        "deg": deg,
        "post_layers": 1,
        # this key is for directional messagepassing layers.
        "msg_scalers": str2bool(args.msg_scale),
        "initial_beta": 1.0,  # Softmax
        "learn_beta": True
    }

    # Reject PNA for quaternion models regardless of whether it was requested
    # for message aggregation or node aggregation.
    if "quaternion" in args.type and uses_pna:
        logging.info("PNA not implemented for quaternion models.")
        raise NotImplementedError

    # Keyword arguments shared by every model variant.
    common_kwargs = dict(atom_input_dims=FULL_ATOM_FEATURE_DIMS,
                         atom_encoded_dim=args.input_embed_dim,
                         bond_input_dims=FULL_BOND_FEATURE_DIMS,
                         naive_encoder=naive_encoder,
                         mp_layers=mp_layers,
                         dropout_mpnn=mp_dropout,
                         same_dropout=same_dropout,
                         norm_mp=args.mp_norm,
                         add_self_loops=True,
                         msg_aggr=args.aggr_msg,
                         node_aggr=args.aggr_node,
                         mlp=mlp_mp,
                         pooling=args.pooling,
                         activation=args.activation,
                         real_trafo=args.real_trafo,
                         downstream_layers=downstream_layers,
                         target_dim=dataset.num_tasks,
                         dropout_dn=dn_dropout,
                         norm_dn=downstream_bn,
                         msg_encoder=args.msg_encoder,
                         **aggr_kwargs)
    # Quaternion models take a single ``init``; PHM models take separate
    # weight/contribution initialisers plus the PHM-specific settings.
    quaternion_kwargs = dict(init=args.w_init, **common_kwargs)
    phm_kwargs = dict(phm_dim=phm_dim,
                      learn_phm=learn_phm,
                      phm_rule=phm_rule,
                      w_init=args.w_init,
                      c_init=args.c_init,
                      **common_kwargs)

    if args.type == "undirectional-quaternion-sc-add":
        logging.info(
            "Using Quaternion Undirectional MPNN with Skip Connection through Addition"
        )
        model = UQ_SC_ADD(**quaternion_kwargs)
    elif args.type == "undirectional-quaternion-sc-cat":
        logging.info(
            "Using Quaternion Undirectional MPNN with Skip Connection through Concatenation"
        )
        model = UQ_SC_CAT(**quaternion_kwargs)
    elif args.type == "undirectional-phm-sc-add":
        logging.info(
            "Using PHM Undirectional MPNN with Skip Connection through Addition"
        )
        # Only the additive PHM variant accepts ``sc_type``.
        model = UPH_SC_ADD(sc_type=args.sc_type, **phm_kwargs)
    elif args.type == "undirectional-phm-sc-cat":
        logging.info(
            "Using PHM Undirectional MPNN with Skip Connection through Concatenation"
        )
        model = UPH_SC_CAT(**phm_kwargs)
    else:
        raise ModuleNotFoundError

    logging.info(
        f"Model consists of {model.get_number_of_params_()} trainable parameters"
    )
    # do runs
    test_best_epoch_metrics_arr = []
    test_last_epoch_metrics_arr = []
    val_metrics_arr = []

    for i in range(1, args.n_runs + 1):
        ogb_bestEpoch_test_metrics, ogb_lastEpoch_test_metric, ogb_val_metrics = do_run(
            i, model, args, transform, train_loader, valid_loader, test_loader,
            device, evaluator)

        test_best_epoch_metrics_arr.append(ogb_bestEpoch_test_metrics)
        test_last_epoch_metrics_arr.append(ogb_lastEpoch_test_metric)
        val_metrics_arr.append(ogb_val_metrics)

    logging.info(f"Performance of model across {args.n_runs} runs:")
    test_bestEpoch_perf = torch.tensor(test_best_epoch_metrics_arr)
    test_lastEpoch_perf = torch.tensor(test_last_epoch_metrics_arr)
    valid_perf = torch.tensor(val_metrics_arr)
    logging.info('===========================')
    logging.info(
        f'Final Test (best val-epoch) '
        f'"{evaluator.eval_metric}": {test_bestEpoch_perf.mean():.4f} ± {test_bestEpoch_perf.std():.4f}'
    )
    logging.info(
        f'Final Test (last-epoch) '
        f'"{evaluator.eval_metric}": {test_lastEpoch_perf.mean():.4f} ± {test_lastEpoch_perf.std():.4f}'
    )
    logging.info(
        f'Final (best) Valid "{evaluator.eval_metric}": {valid_perf.mean():.4f} ± {valid_perf.std():.4f}'
    )
Пример #6
0
def get_small_dataset(dataset_name,
                      normalize_attributes=False,
                      add_self_loops=False,
                      remove_isolated_nodes=False,
                      make_undirected=False,
                      graph_availability=None,
                      seed=0,
                      create_adjacency_lists=True):
    """
    Load one of the supported small datasets and annotate its graph(s).

    The returned object exposes a ``graphs`` list. For single-graph datasets
    it is the dataset's ``Data`` object itself with ``graphs = [data]``; for
    ``'ppi'`` it is a ``SimpleNamespace`` whose ``graphs`` holds one ``Data``
    object per graph of every split. Each graph receives train/val/test masks,
    per-split in-degree tensors, adjacency lists and bookkeeping attributes
    for neighborhood caching and perturbation tracking.

    Note that the cached ``processed`` directory is deleted before loading so
    the pre-transforms are applied again on every call.

    :param dataset_name: str => One of 'amazon-computers', 'amazon-photo', 'citeseer', 'coauthor-cs',
                         'coauthor-physics', 'cora', 'cora-full', 'ppi', 'pubmed', 'reddit'.
    :param normalize_attributes: Whether to apply ``NormalizeFeatures`` as a pre-transform.
    :param add_self_loops: Whether to apply ``AddSelfLoops`` as a pre-transform.
    :param remove_isolated_nodes: Whether to apply ``RemoveIsolatedNodes`` as a pre-transform.
    :param make_undirected: Whether to convert every graph's edge index with ``to_undirected``.
    :param graph_availability: Either 'inductive' or 'transductive'. In the transductive case the train and
                               validation in-degrees are set to the test in-degrees (which count every edge).
                               The default ``None`` fails the assertion below, so a value must be passed.
    :param seed: Random seed for re-splitting into train/val/test. With ``seed == 0`` the splits shipped
                 with 'citeseer', 'cora', 'pubmed', 'reddit' and 'ppi' are kept.
    :param create_adjacency_lists: Whether to walk the edges and fill adjacency lists and in-degree counts.
                                   When False the adjacency lists stay empty and all in-degrees stay zero.
    :return: The container described above.
    """
    assert dataset_name in {
        'amazon-computers', 'amazon-photo', 'citeseer', 'coauthor-cs',
        'coauthor-physics', 'cora', 'cora-full', 'ppi', 'pubmed', 'reddit'
    }
    assert graph_availability in {'inductive', 'transductive'}

    # Compose transforms that should be applied.
    transforms = []
    if normalize_attributes:
        transforms.append(NormalizeFeatures())
    if remove_isolated_nodes:
        transforms.append(RemoveIsolatedNodes())
    if add_self_loops:
        transforms.append(AddSelfLoops())
    transforms = Compose(transforms) if transforms else None

    # Load the specified dataset and apply transforms.
    root_dir = '/tmp/{dir}'.format(dir=dataset_name)
    processed_dir = os.path.join(root_dir, dataset_name, 'processed')
    # Remove any previously pre-processed data, so pytorch_geometric can pre-process it again.
    if os.path.exists(processed_dir) and os.path.isdir(processed_dir):
        shutil.rmtree(processed_dir)

    data = None

    def split_function(y):
        # 0.2 / 0.2 are forwarded to the project's mask helper together with the seed.
        return _get_train_val_test_masks(y.shape[0], y, 0.2, 0.2, seed)

    if dataset_name in ['citeseer', 'cora', 'pubmed']:
        data = Planetoid(root=root_dir,
                         name=dataset_name,
                         pre_transform=transforms,
                         split='full').data
        # Keep the shipped split unless a non-default seed is requested.
        if seed != 0:
            data.train_mask, data.val_mask, data.test_mask = split_function(
                data.y.numpy())
        data.graphs = [data]
    elif dataset_name == 'cora-full':
        data = CoraFull(root=root_dir, pre_transform=transforms).data
        data.train_mask, data.val_mask, data.test_mask = split_function(
            data.y.numpy())
        data.graphs = [data]
    elif dataset_name == 'amazon-computers':
        data = Amazon(root=root_dir,
                      name='Computers',
                      pre_transform=transforms).data
        data.train_mask, data.val_mask, data.test_mask = split_function(
            data.y.numpy())
        data.graphs = [data]
    elif dataset_name == 'amazon-photo':
        data = Amazon(root=root_dir, name='Photo',
                      pre_transform=transforms).data
        data.train_mask, data.val_mask, data.test_mask = split_function(
            data.y.numpy())
        data.graphs = [data]
    elif dataset_name == 'coauthor-cs':
        data = Coauthor(root=root_dir, name='CS',
                        pre_transform=transforms).data
        data.train_mask, data.val_mask, data.test_mask = split_function(
            data.y.numpy())
        data.graphs = [data]
    elif dataset_name == 'coauthor-physics':
        data = Coauthor(root=root_dir,
                        name='Physics',
                        pre_transform=transforms).data
        data.train_mask, data.val_mask, data.test_mask = split_function(
            data.y.numpy())
        data.graphs = [data]
    elif dataset_name == 'reddit':
        data = Reddit(root=root_dir, pre_transform=transforms).data
        if seed != 0:
            data.train_mask, data.val_mask, data.test_mask = split_function(
                data.y.numpy())
        data.graphs = [data]
    elif dataset_name == 'ppi':
        # PPI is multi-graph: cut the collated tensors back into one Data
        # object per graph using the slice offsets of each split.
        data = SimpleNamespace()
        data.graphs = []
        for split in ['train', 'val', 'test']:
            split_data = PPI(root=root_dir,
                             split=split,
                             pre_transform=transforms)
            x_idxs = split_data.slices['x'].numpy()
            edge_idxs = split_data.slices['edge_index'].numpy()
            split_data = split_data.data
            for x_start, x_end, e_start, e_end in zip(x_idxs, x_idxs[1:],
                                                      edge_idxs,
                                                      edge_idxs[1:]):
                graph = Data(split_data.x[x_start:x_end],
                             split_data.edge_index[:, e_start:e_end],
                             y=split_data.y[x_start:x_end])
                graph.num_nodes = int(x_end - x_start)
                graph.split = split
                # Every node of a graph belongs to that graph's split.
                all_true = torch.ones(graph.num_nodes).bool()
                all_false = torch.zeros(graph.num_nodes).bool()
                graph.train_mask = all_true if split == 'train' else all_false
                graph.val_mask = all_true if split == 'val' else all_false
                graph.test_mask = all_true if split == 'test' else all_false
                data.graphs.append(graph)
        if seed != 0:
            # Re-draw two validation and two test graphs; the rest train.
            temp_random = random.Random(seed)
            val_graphs = temp_random.sample(range(len(data.graphs)), 2)
            test_candidates = [
                graph_idx for graph_idx in range(len(data.graphs))
                if graph_idx not in val_graphs
            ]
            test_graphs = temp_random.sample(test_candidates, 2)
            for graph_idx, graph in enumerate(data.graphs):
                all_true = torch.ones(graph.num_nodes).bool()
                all_false = torch.zeros(graph.num_nodes).bool()
                graph.split = 'test' if graph_idx in test_graphs else 'val' if graph_idx in val_graphs else 'train'
                graph.train_mask = all_true if graph.split == 'train' else all_false
                graph.val_mask = all_true if graph.split == 'val' else all_false
                graph.test_mask = all_true if graph.split == 'test' else all_false

    if make_undirected:
        for graph in data.graphs:
            graph.edge_index = to_undirected(graph.edge_index, graph.num_nodes)

    LOG.info(f'Downloaded and transformed {len(data.graphs)} graph(s).')

    # Populate adjacency lists for efficient k-neighborhood sampling.
    # Only retain edges coming into a node and reverse the edges for the purpose of adjacency lists.
    LOG.info('Processing adjacency lists and degree information.')

    for graph in data.graphs:
        train_in_degrees = np.zeros(graph.num_nodes, dtype=np.int64)
        val_in_degrees = np.zeros(graph.num_nodes, dtype=np.int64)
        test_in_degrees = np.zeros(graph.num_nodes, dtype=np.int64)
        adjacency_lists = defaultdict(list)
        not_val_test_mask = (~graph.val_mask & ~graph.test_mask).numpy()
        val_mask = graph.val_mask.numpy()
        test_mask = graph.test_mask.numpy()

        if create_adjacency_lists:
            num_edges = graph.edge_index[0].shape[0]
            sources = graph.edge_index[0].numpy()
            dests = graph.edge_index[1].numpy()
            for source, dest in tqdm(zip(sources, dests),
                                     total=num_edges,
                                     leave=False):
                if not_val_test_mask[dest] and not_val_test_mask[source]:
                    # Both endpoints are outside val/test: counts for train and val.
                    train_in_degrees[dest] += 1
                    val_in_degrees[dest] += 1
                elif val_mask[dest] and not test_mask[source]:
                    # Edge into a validation node from a non-test node.
                    val_in_degrees[dest] += 1
                # Every edge counts towards the test in-degree.
                test_in_degrees[dest] += 1
                adjacency_lists[dest].append(source)

        graph.adjacency_lists = dict(adjacency_lists)
        graph.train_in_degrees = torch.from_numpy(train_in_degrees).long()
        graph.val_in_degrees = torch.from_numpy(val_in_degrees).long()
        graph.test_in_degrees = torch.from_numpy(test_in_degrees).long()
        if graph_availability == 'transductive':
            # All edges are visible in every phase. Read the degrees from the
            # graph itself: the container has no such attribute for 'ppi'.
            graph.train_in_degrees = graph.test_in_degrees
            graph.val_in_degrees = graph.test_in_degrees

        graph.graph_availability = graph_availability

        # To accumulate any neighborhood perturbations to the graph.
        graph.perturbed_neighborhoods = defaultdict(set)
        graph.added_nodes = defaultdict(set)
        graph.modified_degrees = {}

        # For small datasets, cache the neighborhoods for all nodes for at least 3 different radii queries.
        graph.use_cache = True
        graph.neighborhood_cache = NeighborhoodCache(graph.num_nodes * 3)

        # Keep the original masks, then redefine train as "everything that is
        # neither validation nor test".
        graph.train_mask_original = graph.train_mask
        graph.val_mask_original = graph.val_mask
        graph.test_mask_original = graph.test_mask

        graph.train_mask = torch.ones(
            graph.num_nodes).bool() & ~graph.val_mask & ~graph.test_mask

    return data