Example #1
    def __init__(self, args, dataset_or_loader):
        super(CGATNet, self).__init__()
        self.args = args

        num_input_features = getattr_d(dataset_or_loader, "num_node_features")
        num_classes = getattr_d(dataset_or_loader, "num_classes")

        kwargs = {"use_topk_softmax": args.use_topk_softmax}
        if args.use_topk_softmax:
            kwargs["aggr_k"] = args.aggr_k
        else:
            kwargs["dropout"] = args.dropout

        self.conv1 = CGATConv(
            num_input_features,
            args.num_hidden_features,
            heads=args.heads,
            concat=True,
            margin_graph=args.margin_graph,
            margin_boundary=args.margin_boundary,
            num_neg_samples_per_edge=args.num_neg_samples_per_edge,
            **kwargs,
        )

        self.conv2 = CGATConv(
            args.num_hidden_features * args.heads,
            num_classes,
            heads=(args.out_heads or args.heads),
            concat=False,
            margin_graph=args.margin_graph,
            margin_boundary=args.margin_boundary,
            num_neg_samples_per_edge=args.num_neg_samples_per_edge,
            **kwargs,
        )

        pprint(next(self.modules()))
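kwargs switches between the two CGATConv modes: top-k softmax aggregation reads aggr_k, while the default attention path reads dropout. A sketch of driving this constructor with an argparse-style namespace; every value below is an assumption for illustration, not a default from the repository:

# Illustrative driver; values are assumptions, not repository defaults.
from types import SimpleNamespace

args = SimpleNamespace(
    use_topk_softmax=False,       # False -> the dropout branch above
    dropout=0.6,
    aggr_k=8,                     # read only when use_topk_softmax=True
    num_hidden_features=8,
    heads=8,
    out_heads=1,
    margin_graph=0.1,
    margin_boundary=0.4,
    num_neg_samples_per_edge=4,
)
model = CGATNet(args, dataset)    # dataset must expose num_node_features / num_classes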
Example #2
File: model.py Project: zyq2016/SuperGAT
    def __init__(self, args, dataset_or_loader):
        super().__init__()
        self.args = args
        self.num_layers = self.args.num_layers

        gat_cls = _get_gat_cls(self.args.model_name)

        num_input_features = getattr_d(dataset_or_loader, "num_node_features")
        num_classes = getattr_d(dataset_or_loader, "num_classes")

        conv_common_kwargs = dict(
            dropout=args.dropout,
            is_super_gat=args.is_super_gat,
            attention_type=args.attention_type,
            super_gat_criterion=args.super_gat_criterion,
            neg_sample_ratio=args.neg_sample_ratio,
            edge_sample_ratio=args.edge_sampling_ratio,
            pretraining_noise_ratio=args.pretraining_noise_ratio,
            use_pretraining=args.use_pretraining,
            to_undirected_at_neg=args.to_undirected_at_neg,
            scaling_factor=args.scaling_factor,
        )
        self.conv_list = []
        self.bn_list = []
        for conv_id in range(1, self.num_layers + 1):
            if conv_id == 1:  # first layer
                in_channels, out_channels = num_input_features, args.num_hidden_features
                heads, concat = args.heads, True
            elif conv_id == self.num_layers:  # last layer
                in_channels, out_channels = args.num_hidden_features * args.heads, num_classes
                heads, concat = args.out_heads or args.heads, False
            else:
                in_channels, out_channels = args.num_hidden_features * args.heads, args.num_hidden_features
                heads, concat = args.heads, True
            # conv
            conv = gat_cls(in_channels,
                           out_channels,
                           heads=heads,
                           concat=concat,
                           **conv_common_kwargs)
            conv_name = "conv{}".format(conv_id)
            self.conv_list.append(conv)
            setattr(self, conv_name, conv)  # setattr on an nn.Module also registers it as a submodule
            # bn
            if args.use_bn and conv_id != self.num_layers:  # not last layer
                bn = nn.BatchNorm1d(out_channels * heads)
                bn_name = "bn{}".format(conv_id)
                self.bn_list.append(bn)
                setattr(self, bn_name, bn)  # registered via setattr, as above

        pprint(next(self.modules()))
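The matching forward pass is not part of this excerpt. A minimal sketch of how the conv_list/bn_list pairing built above would typically be consumed, assuming an ELU non-linearity between layers (illustrative, not the repository's actual forward):

# Illustrative forward only; shapes assume concat=True on hidden layers,
# so each hidden output has out_channels * heads features, matching the
# BatchNorm1d sizes registered above.
import torch.nn.functional as F

def forward(self, x, edge_index):
    for i, conv in enumerate(self.conv_list):
        x = conv(x, edge_index)
        if i < len(self.conv_list) - 1:  # hidden layers only
            if self.bn_list:
                x = self.bn_list[i](x)
            x = F.elu(x)
    return x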
Example #3
    def __init__(self, args, dataset_or_loader):
        super(MLPNet, self).__init__()
        self.args = args

        num_input_features = getattr_d(dataset_or_loader, "num_node_features")
        num_classes = getattr_d(dataset_or_loader, "num_classes")

        self.fc = nn.Sequential(
            nn.Dropout(p=args.dropout),
            nn.Linear(num_input_features, args.num_hidden_features),
            nn.ELU(),
            nn.Dropout(p=args.dropout),
            nn.Linear(args.num_hidden_features, num_classes),
        )
        pprint(next(self.modules()))
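MLPNet never reads edge_index, so it serves as a structure-free baseline for the GAT variants above: the entire model is the dropout/linear/ELU Sequential stack, and its forward pass presumably reduces to applying self.fc to the [num_nodes, num_node_features] feature matrix.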
Example #4
    def __init__(self, args, dataset_or_loader):
        super(LinkGNN, self).__init__()
        self.args = args

        gn_layer = _get_gn_cls(self.args.model_name)

        num_input_features = getattr_d(dataset_or_loader, "num_node_features")
        num_classes = getattr_d(dataset_or_loader, "num_classes")
        self.neg_sample_ratio = args.neg_sample_ratio

        self.conv1 = gn_layer(
            num_input_features,
            args.num_hidden_features,
            **_get_gn_kwargs(args.model_name, args, concat=True),
        )
        self.conv2 = gn_layer(
            _get_last_features(args.model_name, args),
            num_classes,
            **_get_gn_kwargs(args.model_name,
                             args,
                             heads=(args.out_heads or args.heads),
                             concat=False),
        )

        if args.is_link_gnn:
            # Scalar scaling/bias pairs for link scoring. torch.Tensor(1) is
            # uninitialized memory, so reset_parameters() below must fill them.
            self.r_scaling_11 = Parameter(torch.Tensor(1))
            self.r_bias_11 = Parameter(torch.Tensor(1))
            self.r_scaling_12 = Parameter(torch.Tensor(1))
            self.r_bias_12 = Parameter(torch.Tensor(1))
            self.r_scaling_21 = Parameter(torch.Tensor(1))
            self.r_bias_21 = Parameter(torch.Tensor(1))
            self.r_scaling_22 = Parameter(torch.Tensor(1))
            self.r_bias_22 = Parameter(torch.Tensor(1))

        self.cache = {
            "num_updated": 0,
            "batch": None,
            "x_conv1": None,
            "x_conv2": None,
            "label": None
        }

        self.reset_parameters()
        pprint(next(self.modules()))
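torch.Tensor(1) allocates an uninitialized one-element tensor, which is why self.reset_parameters() must run before the first forward pass. The repository's reset_parameters() is not part of this excerpt; a minimal sketch of what such an initialization could look like for the scalar pairs:

# Minimal sketch only; the actual reset_parameters() in SuperGAT is not
# shown in this excerpt and may initialize differently.
import torch.nn as nn

def reset_parameters(self):
    for name, param in self.named_parameters():
        if name.startswith("r_scaling"):
            nn.init.ones_(param)    # identity scaling by default
        elif name.startswith("r_bias"):
            nn.init.zeros_(param)   # no bias by default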
Example #5
File: model.py Project: zyq2016/SuperGAT
    def __init__(self, args, dataset_or_loader):
        super().__init__()
        self.args = args

        gat_cls = _get_gat_cls(self.args.model_name)

        num_input_features = getattr_d(dataset_or_loader, "num_node_features")
        num_classes = getattr_d(dataset_or_loader, "num_classes")

        self.conv1 = gat_cls(
            num_input_features,
            args.num_hidden_features,
            heads=args.heads,
            dropout=args.dropout,
            concat=True,
            is_super_gat=args.is_super_gat,
            attention_type=args.attention_type,
            super_gat_criterion=args.super_gat_criterion,
            neg_sample_ratio=args.neg_sample_ratio,
            edge_sample_ratio=args.edge_sampling_ratio,
            pretraining_noise_ratio=args.pretraining_noise_ratio,
            use_pretraining=args.use_pretraining,
            to_undirected_at_neg=args.to_undirected_at_neg,
            scaling_factor=args.scaling_factor,
        )

        self.conv2 = gat_cls(
            args.num_hidden_features * args.heads,
            num_classes,
            heads=(args.out_heads or args.heads),
            dropout=args.dropout,
            concat=False,
            is_super_gat=args.is_super_gat,
            attention_type=args.attention_type,
            super_gat_criterion=args.super_gat_criterion,
            neg_sample_ratio=args.neg_sample_ratio,
            edge_sample_ratio=args.edge_sampling_ratio,
            pretraining_noise_ratio=args.pretraining_noise_ratio,
            use_pretraining=args.use_pretraining,
            to_undirected_at_neg=args.to_undirected_at_neg,
            scaling_factor=args.scaling_factor,
        )

        pprint(next(self.modules()))
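Examples #2 and #5 construct the same model family: Example #5 hard-codes the two-layer case, while Example #2 generalizes it to args.num_layers and factors the keyword arguments repeated across conv1 and conv2 here into a single conv_common_kwargs dict.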
Example #6
File: main.py Project: zyq2016/SuperGAT
def test_model(device,
               model,
               dataset_or_loader,
               criterion,
               _args,
               val_or_test="val",
               verbose=0,
               **kwargs):
    model.eval()
    try:
        model.set_layer_attrs("cache_attention",
                              _args.task_type == "Attention_Dist")
    except AttributeError:
        pass
    try:
        dataset_or_loader.eval()
    except AttributeError:
        pass

    num_classes = getattr_d(dataset_or_loader, "num_classes")

    total_loss = 0.
    outputs_list, ys_list, batch = [], [], None
    with torch.no_grad():
        for batch in dataset_or_loader:
            batch = batch.to(device)

            # Forward
            outputs = model(batch.x,
                            batch.edge_index,
                            batch=getattr(batch, "batch", None),
                            attention_edge_index=getattr(
                                batch, "{}_edge_index".format(val_or_test),
                                None))

            # Loss
            if "train_mask" in batch.__dict__:
                val_or_test_mask = batch.val_mask if val_or_test == "val" else batch.test_mask
                loss = criterion(outputs[val_or_test_mask],
                                 batch.y[val_or_test_mask])
                outputs_ndarray = outputs[val_or_test_mask].cpu().numpy()
                ys_ndarray = to_one_hot(batch.y[val_or_test_mask], num_classes)
            elif _args.dataset_name == "PPI":  # PPI task
                loss = criterion(outputs, batch.y)
                outputs_ndarray = outputs.cpu().numpy()
                ys_ndarray = batch.y.cpu().numpy()
            else:
                loss = criterion(outputs, batch.y)
                outputs_ndarray = outputs.cpu().numpy()
                ys_ndarray = to_one_hot(batch.y, num_classes)
            total_loss += loss.item()

            outputs_list.append(outputs_ndarray)
            ys_list.append(ys_ndarray)

    outputs_total = np.concatenate(outputs_list)
    ys_total = np.concatenate(ys_list)

    if _args.task_type == "Link_Prediction":
        if "run_link_prediction" in kwargs and kwargs["run_link_prediction"]:
            val_or_test_edge_y = batch.val_edge_y if val_or_test == "val" else batch.test_edge_y
            layer_idx_for_lp = kwargs.get("layer_idx_for_link_prediction", -1)
            perfs = SuperGAT.get_link_pred_perfs_by_attention(
                model=model,
                edge_y=val_or_test_edge_y,
                layer_idx=layer_idx_for_lp)
        else:
            perfs = get_accuracy(outputs_total, ys_total)
    elif _args.perf_type == "micro-f1" and _args.dataset_name == "PPI":
        preds = (outputs_total > 0).astype(int)
        perfs = f1_score(ys_total, preds,
                         average="micro") if preds.sum() > 0 else 0
    elif _args.perf_type == "accuracy" or _args.task_type == "Attention_Dist":
        perfs = get_accuracy(outputs_total, ys_total)
    else:
        raise ValueError("Unsupported perf_type '{}' for task_type '{}'".format(
            _args.perf_type, _args.task_type))

    if verbose >= 2:
        full_name = "Validation" if val_or_test == "val" else "Test"
        cprint("\n[{} of {}]".format(full_name, model.__class__.__name__),
               "yellow")
        cprint("\t- Perfs: {}".format(perfs), "yellow")

    return perfs, total_loss
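A minimal invocation sketch; every name below is a placeholder for objects a surrounding training script would already hold, not repository code:

# Illustrative call; all names are placeholders from a typical training loop.
val_perfs, val_loss = test_model(device, model, dataset_or_loader, criterion,
                                 _args, val_or_test="val", verbose=2)
test_perfs, test_loss = test_model(device, model, dataset_or_loader, criterion,
                                   _args, val_or_test="test")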
Example #7
def test_model(device,
               model,
               dataset_or_loader,
               criterion,
               _args,
               val_or_test="val",
               verbose=0,
               **kwargs):
    model.eval()

    if _args.dataset_name == "Reddit":
        dataset, _loader = dataset_or_loader
        data = dataset[0]

    elif _args.dataset_name == "MyReddit":
        dataset, _loader = dataset_or_loader
        data = dataset.data_xy
    else:
        raise ValueError("Unsupported dataset: {}".format(_args.dataset_name))

    try:
        if val_or_test == "val":
            loader = _loader(data.val_mask)
        else:
            loader = _loader(data.test_mask)
    except TypeError:
        loader: DisjointGraphSAINTRandomWalkSampler = _loader
        if val_or_test == "val":
            loader.set_mask(data.val_mask)
        else:
            loader.set_mask(data.test_mask)

    num_classes = getattr_d(dataset, "num_classes")

    total_loss = 0.
    outputs_list, ys_list = [], []

    for batch_id, batch in enumerate(loader):
        # Neighbor sampling
        # n_id: original ID of nodes in the whole sub-graph.
        # b_id: original ID of nodes in the training graph.
        # sub_b_id: sampled ID of nodes in the training graph.

        # RW sampling
        # x, y, mask, edge_index
        try:
            edge_index = dataset.get_edge_index(batch).to(device)
        except AttributeError:
            edge_index = batch.edge_index.to(device)

        try:
            x = data.x[batch.n_id].to(device)
        except AttributeError:
            x = batch.x.to(device)

        try:
            out_mask = batch.sub_b_id.to(device)
        except AttributeError:
            if val_or_test == "val":
                out_mask = batch.val_mask.to(device)
            else:
                out_mask = batch.test_mask.to(device)

        try:
            y_masked = data.y.squeeze()[batch.b_id].to(device)
        except AttributeError:
            y_masked = batch.y[out_mask].to(device)

        outputs = model(x, edge_index)  # [#(n_id), #class]

        batch_node_out = outputs[out_mask]

        loss = criterion(batch_node_out, y_masked)
        total_loss += loss.item() / y_masked.size(0)

        outputs_ndarray = batch_node_out.cpu().numpy()
        ys_ndarray = to_one_hot(y_masked, num_classes)
        outputs_list.append(outputs_ndarray)
        ys_list.append(ys_ndarray)

    outputs_total = np.concatenate(outputs_list)
    ys_total = np.concatenate(ys_list)
    perfs = get_accuracy(outputs_total, ys_total)

    if verbose >= 2:
        full_name = "Validation" if val_or_test == "val" else "Test"
        cprint("\n[{} of {}]".format(full_name, model.__class__.__name__),
               "yellow")
        cprint("\t- Perfs: {}".format(perfs), "yellow")

    return perfs, total_loss
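The chain of try/except AttributeError blocks duck-types over the two batch layouts described in the comments above: neighbor-sampled batches expose n_id, b_id, and sub_b_id and index into the full data tensors, while random-walk batches carry their own x, y, masks, and edge_index, so each lookup falls through to the second layout when the first attribute is missing.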