Пример #1
0
 def generate_target():
     """Wrap a newly generated target matrix in a SparseMatrixData.

     Reads ``data``, ``args``, ``device`` and ``target_schema`` from the
     enclosing scope. The matrix comes from ``generate_target_matrix`` and is
     stored under ``TARGET_REL_ID``.
     """
     sampled_matrix = generate_target_matrix(
         data[TARGET_REL_ID], args.target_n_samples, args.target_pos_rate,
         device)
     container = SparseMatrixData(target_schema)
     container[TARGET_REL_ID] = sampled_matrix
     # NOTE(review): the return value of .to() is discarded here; this is
     # only correct if SparseMatrixData.to moves data in place - confirm.
     container.to(device)
     return container
Пример #2
0
def get_data_and_targets(schema, neg_data, data_indices, paper, cites,
                         content):
    """Build the input SparseMatrixData and the training targets.

    For both relations the given index pairs get value 1, and randomly drawn
    index pairs get value 0. Each matrix is then coalesced (the original
    comments state this removes duplicate indices).

    Args:
        schema: schema whose entities[0] / entities[1] supply the number of
            papers / words through ``n_instances``.
        neg_data: multiplier applied to the number of given index pairs to
            get the number of random zero-valued pairs.
        data_indices: currently unused by this function.
        paper: indexable object; ``paper[1]`` is converted to the targets.
        cites: index array whose ``shape[1]`` is the number of citation pairs.
        content: index array whose ``shape[1]`` is the number of content pairs.

    Returns:
        Tuple ``(data, train_targets)`` where ``data[0]`` is the citation
        matrix, ``data[1]`` the content matrix, and ``train_targets`` a
        LongTensor built from ``paper[1]``.
    """
    paper_count = schema.entities[0].n_instances
    word_count = schema.entities[1].n_instances

    train_targets = torch.LongTensor(paper[1])

    # Citation relation: given pairs followed by random paper-paper pairs.
    n_pos_cites = cites.shape[1]
    n_neg_cites = int(neg_data * n_pos_cites)
    neg_cite_idx = np.random.randint(0, paper_count, (2, n_neg_cites))
    cite_indices = torch.LongTensor(
        np.concatenate((cites, neg_cite_idx), axis=1))
    cite_values = torch.cat(
        (torch.ones(n_pos_cites, 1), torch.zeros(n_neg_cites, 1)))
    cites_matrix = SparseMatrix(indices=cite_indices,
                                values=cite_values,
                                shape=(paper_count, paper_count,
                                       1)).coalesce()

    # Content relation: given pairs followed by random paper-word pairs.
    n_pos_content = content.shape[1]
    n_neg_content = int(neg_data * n_pos_content)
    neg_papers = np.random.randint(0, paper_count, (n_neg_content, ))
    neg_words = np.random.randint(0, word_count, (n_neg_content, ))
    neg_content_idx = np.stack((neg_papers, neg_words))
    content_indices = torch.LongTensor(
        np.concatenate((content, neg_content_idx), axis=1))
    content_values = torch.cat(
        (torch.ones(n_pos_content, 1), torch.zeros(n_neg_content, 1)))
    content_matrix = SparseMatrix(indices=content_indices,
                                  values=content_values,
                                  shape=(paper_count, word_count,
                                         1)).coalesce()

    data = SparseMatrixData(schema)
    data[0] = cites_matrix
    data[1] = content_matrix
    return data, train_targets
    def get_node_classification_data(self):
        """Read node labels from LABEL_FILE_STR and set up the output schema.

        Each line of the label file is split on tabs into four fields: node
        id, node name, node type and node label. The node id is translated
        through ``self.node_id_to_idx[node_type]``.

        Sets ``self.schema_out``, ``self.target_indices``, ``self.targets``,
        ``self.n_outputs`` and ``self.data_target``.
        """
        target_entity = self.schema.entities[TARGET_NODE_TYPE]
        self.schema_out = DataSchema(
            [target_entity],
            [Relation(0, [target_entity, target_entity], is_set=True)])

        node_indices = []
        node_labels = []
        with open(LABEL_FILE_STR, 'r') as label_file:
            for raw_line in label_file:
                fields = raw_line.rstrip().split('\t')
                # The node name (second field) is not used.
                raw_id, _node_name, raw_type, raw_label = fields
                type_idx = int(raw_type)
                node_idx = self.node_id_to_idx[type_idx][int(raw_id)]
                label = int(raw_label)
                node_indices.append(node_idx)
                node_labels.append(label)

        self.target_indices = torch.LongTensor(node_indices)
        self.targets = torch.LongTensor(node_labels)
        self.n_outputs = target_entity.n_instances
        self.data_target = SparseMatrixData(self.schema_out)
Пример #4
0
 def forward(self, X):
     """Apply the per-relation module in ``self.linear`` to each relation.

     When ``self.node_only`` is truthy, relations flagged ``is_set`` are
     copied through unchanged. All other relations go through
     ``self.linear[str(rel.id)]``.

     Returns a new SparseMatrixData built on ``X.schema``.
     """
     X_out = SparseMatrixData(X.schema)
     for relation in self.schema.relations.values():
         key = relation.id
         pass_through = self.node_only and relation.is_set
         X_out[key] = X[key] if pass_through else self.linear[str(key)](
             X[key])
     return X_out
Пример #5
0
def run_model(args):
    """Train, validate and optionally test an EquivLinkPredictor.

    Loads the dataset named by ``args.dataset``, builds the network and an
    Adam optimizer, then runs ``args.epoch`` full-batch training steps with a
    BCE loss on sigmoid outputs. Negative training samples are redrawn every
    epoch. Every ``args.val_every`` epochs the model is validated and, when
    the metric selected by ``args.val_metric`` improves, a checkpoint is
    written. If ``args.evaluate`` is set, the checkpoint is reloaded and test
    predictions are passed to ``gen_file_for_evaluate``.

    Args:
        args: configuration namespace. Attributes read here include dataset,
            decoder, feats_type, embedding_dim, pred_indices, tail_weighted,
            layers, act_fn, dropout, pool_op, norm_affine, norm_embed,
            in_fc_layer, lr, weight_decay, wandb_log_run,
            wandb_log_param_freq, run, checkpoint_path, epoch, val_every,
            val_neg, val_metric and evaluate.

    Returns:
        None. Results are reported through stdout, the checkpoint file,
        wandb (when enabled) and the test output file.
    """
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    # The 'equiv' decoder predicts over the full input schema; any other
    # decoder only uses the target relations (see output_rels below).
    use_equiv = args.decoder == 'equiv'

    # Collect data and schema
    schema, data_original, dl = load_data(args.dataset,
                                          use_edge_data=args.use_edge_data,
                                          use_other_edges=args.use_other_edges,
                                          use_node_attrs=args.use_node_attrs,
                                          node_val=args.node_val)
    data, in_dims = select_features(data_original, schema, args.feats_type)
    data = data.to(device)

    # Precompute data indices
    indices_identity, indices_transpose = data.calculate_indices()
    # Get target relations and create data structure for embeddings
    target_rel_ids = dl.links_test['data'].keys()
    target_rels = [schema.relations[rel_id] for rel_id in target_rel_ids]
    target_ents = schema.entities
    # Get relations used by decoder
    if use_equiv:
        output_rels = schema.relations
    else:
        output_rels = {rel.id: rel for rel in target_rels}
    data_embedding = SparseMatrixData.make_entity_embeddings(
        target_ents, args.embedding_dim)
    # NOTE(review): unlike `data = data.to(device)` above, the result of
    # .to() is discarded here - confirm that .to() works in place.
    data_embedding.to(device)

    # Get training and validation positive samples now
    train_pos_heads, train_pos_tails = dict(), dict()
    val_pos_heads, val_pos_tails = dict(), dict()
    for target_rel_id in target_rel_ids:
        train_val_pos = get_train_valid_pos(dl, target_rel_id)
        train_pos_heads[target_rel_id], train_pos_tails[target_rel_id], \
            val_pos_heads[target_rel_id], val_pos_tails[target_rel_id] = train_val_pos

    # Get additional indices to be used when making predictions
    # NOTE(review): if args.pred_indices is none of 'train', 'train_neg' or
    # 'none', no entry is stored and the validation lookup below raises
    # KeyError when use_equiv is set.
    pred_idx_matrices = {}
    for target_rel in target_rels:
        if args.pred_indices == 'train':
            train_neg_head, train_neg_tail = get_train_neg(
                dl, target_rel.id, tail_weighted=args.tail_weighted)
            pred_idx_matrices[target_rel.id] = make_target_matrix(
                target_rel, train_pos_heads[target_rel.id],
                train_pos_tails[target_rel.id], train_neg_head, train_neg_tail,
                device)
        elif args.pred_indices == 'train_neg':
            # Get negative samples twice
            train_neg_head1, train_neg_tail1 = get_train_neg(
                dl, target_rel.id, tail_weighted=args.tail_weighted)
            train_neg_head2, train_neg_tail2 = get_train_neg(
                dl, target_rel.id, tail_weighted=args.tail_weighted)
            pred_idx_matrices[target_rel.id] = make_target_matrix(
                target_rel, train_neg_head1, train_neg_tail1, train_neg_head2,
                train_neg_tail2, device)
        elif args.pred_indices == 'none':
            pred_idx_matrices[target_rel.id] = None

    # Create network and optimizer
    # NOTE(review): the activation is built with eval() from args.act_fn;
    # this executes arbitrary text, so args.act_fn must come from a trusted
    # source (getattr(nn, args.act_fn)() would avoid eval).
    net = EquivLinkPredictor(schema,
                             in_dims,
                             layers=args.layers,
                             embedding_dim=args.embedding_dim,
                             embedding_entities=target_ents,
                             output_rels=output_rels,
                             activation=eval('nn.%s()' % args.act_fn),
                             final_activation=nn.Identity(),
                             dropout=args.dropout,
                             pool_op=args.pool_op,
                             norm_affine=args.norm_affine,
                             norm_embed=args.norm_embed,
                             in_fc_layer=args.in_fc_layer,
                             decode=args.decoder)
    net.to(device)
    optimizer = torch.optim.Adam(net.parameters(),
                                 lr=args.lr,
                                 weight_decay=args.weight_decay)

    # Set up logging and checkpointing
    if args.wandb_log_run:
        wandb.init(config=args,
                   settings=wandb.Settings(start_method='fork'),
                   project="EquivariantHGN_LP",
                   entity='danieltlevy')
        wandb.watch(net, log='all', log_freq=args.wandb_log_param_freq)
    print(args)
    print("Number of parameters: {}".format(count_parameters(net)))
    run_name = args.dataset + '_' + str(args.run)
    if args.wandb_log_run and wandb.run.name is not None:
        run_name = run_name + '_' + str(wandb.run.name)
    # An explicit args.checkpoint_path overrides the run-name based default.
    if args.checkpoint_path != '':
        checkpoint_path = args.checkpoint_path
    else:
        checkpoint_path = f"checkpoint/checkpoint_{run_name}.pt"
    print("Checkpoint Path: " + checkpoint_path)
    # Very low start value so the first validation always counts as "best".
    val_metric_best = -1e10

    # training
    loss_func = nn.BCELoss()
    progress = tqdm(range(args.epoch), desc="Epoch 0", position=0, leave=True)
    for epoch in progress:
        net.train()
        # Make target matrix and labels to train on
        if use_equiv:
            # Target is same as input
            target_schema = schema
            data_target = data.clone()
        else:
            # Target is just target relation
            target_schema = DataSchema(schema.entities, target_rels)
            data_target = SparseMatrixData(target_schema)
        labels_train = torch.Tensor([]).to(device)
        for target_rel in target_rels:
            # Negative samples are redrawn on every epoch.
            train_neg_head, train_neg_tail = get_train_neg(
                dl, target_rel.id, tail_weighted=args.tail_weighted)
            train_matrix = make_target_matrix(target_rel,
                                              train_pos_heads[target_rel.id],
                                              train_pos_tails[target_rel.id],
                                              train_neg_head, train_neg_tail,
                                              device)
            data_target[target_rel.id] = train_matrix
            labels_train_rel = train_matrix.values.squeeze()
            labels_train = torch.cat([labels_train, labels_train_rel])

        # Make prediction
        if use_equiv:
            idx_id_tgt, idx_trans_tgt = data_target.calculate_indices()
            output_data = net(data, indices_identity, indices_transpose,
                              data_embedding, data_target, idx_id_tgt,
                              idx_trans_tgt)
        else:
            output_data = net(data, indices_identity, indices_transpose,
                              data_embedding, data_target)
        # Concatenate outputs in the same relation order as labels_train.
        logits_combined = torch.Tensor([]).to(device)
        for target_rel in target_rels:
            logits_rel = output_data[target_rel.id].values.squeeze()
            logits_combined = torch.cat([logits_combined, logits_rel])

        # Despite the name, logp holds sigmoid probabilities (BCELoss input).
        logp = torch.sigmoid(logits_combined)
        train_loss = loss_func(logp, labels_train)

        # autograd
        optimizer.zero_grad()
        train_loss.backward()
        optimizer.step()

        # Update logging
        progress.set_description(f"Epoch {epoch}")
        progress.set_postfix(loss=train_loss.item())
        wandb_log = {'Train Loss': train_loss.item(), 'epoch': epoch}

        # Evaluate on validation set
        net.eval()
        if epoch % args.val_every == 0:
            with torch.no_grad():
                net.eval()
                left = torch.Tensor([]).to(device)
                right = torch.Tensor([]).to(device)
                labels_val = torch.Tensor([]).to(device)
                valid_masks = {}
                for target_rel in target_rels:
                    # Choose the validation negative sampling strategy.
                    if args.val_neg == '2hop':
                        valid_neg_head, valid_neg_tail = get_valid_neg_2hop(
                            dl, target_rel.id)
                    elif args.val_neg == 'randomtw':
                        valid_neg_head, valid_neg_tail = get_valid_neg(
                            dl, target_rel.id, tail_weighted=True)
                    else:
                        valid_neg_head, valid_neg_tail = get_valid_neg(
                            dl, target_rel.id)
                    valid_matrix_full = make_target_matrix(
                        target_rel, val_pos_heads[target_rel.id],
                        val_pos_tails[target_rel.id], valid_neg_head,
                        valid_neg_tail, device)
                    valid_matrix, left_rel, right_rel, labels_val_rel = coalesce_matrix(
                        valid_matrix_full)
                    left = torch.cat([left, left_rel])
                    right = torch.cat([right, right_rel])
                    labels_val = torch.cat([labels_val, labels_val_rel])
                    # data_target from the training step is reused and its
                    # target relations are overwritten with validation data.
                    if use_equiv:
                        # Add in additional prediction indices
                        pred_idx_matrix = pred_idx_matrices[target_rel.id]
                        if pred_idx_matrix is None:
                            valid_combined_matrix = valid_matrix
                            valid_mask = torch.arange(
                                valid_matrix.nnz()).to(device)
                        else:
                            valid_combined_matrix, valid_mask = combine_matrices(
                                valid_matrix, pred_idx_matrix)
                        valid_masks[target_rel.id] = valid_mask
                        data_target[target_rel.id] = valid_combined_matrix
                    else:
                        data_target[target_rel.id] = valid_matrix

                if use_equiv:
                    data_target.zero_()
                    idx_id_val, idx_trans_val = data_target.calculate_indices()
                    output_data = net(data, indices_identity,
                                      indices_transpose, data_embedding,
                                      data_target, idx_id_val, idx_trans_val)
                else:
                    output_data = net(data, indices_identity,
                                      indices_transpose, data_embedding,
                                      data_target)
                logits_combined = torch.Tensor([]).to(device)
                for target_rel in target_rels:
                    logits_rel_full = output_data[
                        target_rel.id].values.squeeze()
                    if use_equiv:
                        # Keep only the outputs selected by the validation
                        # mask built above.
                        logits_rel = logits_rel_full[valid_masks[
                            target_rel.id]]
                    else:
                        logits_rel = logits_rel_full
                    logits_combined = torch.cat([logits_combined, logits_rel])

                logp = torch.sigmoid(logits_combined)
                val_loss = loss_func(logp, labels_val).item()

                wandb_log.update({'val_loss': val_loss})
                left = left.cpu().numpy()
                right = right.cpu().numpy()
                edge_list = np.concatenate(
                    [left.reshape((1, -1)),
                     right.reshape((1, -1))], axis=0)
                res = dl.evaluate(edge_list,
                                  logp.cpu().numpy(),
                                  labels_val.cpu().numpy())
                val_roc_auc = res['roc_auc']
                val_mrr = res['MRR']
                wandb_log.update(res)
                print("\nVal Loss: {:.3f} Val ROC AUC: {:.3f} Val MRR: {:.3f}".
                      format(val_loss, val_roc_auc, val_mrr))
                # NOTE(review): there is no else branch; for any other
                # args.val_metric value, val_metric is unbound on the first
                # validation (NameError below).
                if args.val_metric == 'loss':
                    # Negated so that "larger is better" holds for all metrics.
                    val_metric = -val_loss
                elif args.val_metric == 'roc_auc':
                    val_metric = val_roc_auc
                elif args.val_metric == 'mrr':
                    val_metric = val_mrr

                if val_metric > val_metric_best:
                    val_metric_best = val_metric
                    print("New best, saving")
                    torch.save(
                        {
                            'epoch': epoch,
                            'net_state_dict': net.state_dict(),
                            'optimizer_state_dict': optimizer.state_dict(),
                            'train_loss': train_loss.item(),
                            'val_loss': val_loss,
                            'val_roc_auc': val_roc_auc,
                            'val_mrr': val_mrr
                        }, checkpoint_path)
                    if args.wandb_log_run:
                        wandb.summary["val_roc_auc_best"] = val_roc_auc
                        wandb.summary["val_mrr_best"] = val_mrr
                        wandb.summary["val_loss_best"] = val_loss
                        wandb.summary["epoch_best"] = epoch
                        wandb.summary["train_loss_best"] = train_loss.item()
                        wandb.save(checkpoint_path)
        if args.wandb_log_run:
            wandb.log(wandb_log)

    # Evaluate on test set
    if args.evaluate:
        for target_rel in target_rels:
            print("Evaluating Target Rel " + str(target_rel.id))
            # The best checkpoint is reloaded for every target relation.
            checkpoint = torch.load(checkpoint_path, map_location=device)
            net.load_state_dict(checkpoint['net_state_dict'])
            net.eval()

            # Target is same as input
            data_target = data.clone()
            with torch.no_grad():
                left_full, right_full, test_labels_full = get_test_neigh_from_file(
                    dl, args.dataset, target_rel.id)
                test_matrix_full = make_target_matrix_test(
                    target_rel, left_full, right_full, test_labels_full,
                    device)
                test_matrix, left, right, test_labels = coalesce_matrix(
                    test_matrix_full)
                if use_equiv:
                    # NOTE(review): train_matrix is whatever the last
                    # training iteration left behind (last target relation of
                    # the final epoch), not necessarily the matrix for this
                    # target_rel; it is unbound if args.epoch == 0. Confirm
                    # this is intended.
                    test_combined_matrix, test_mask = combine_matrices(
                        test_matrix, train_matrix)
                    data_target[target_rel.id] = test_combined_matrix
                    data_target.zero_()
                    idx_id_tst, idx_trans_tst = data_target.calculate_indices()
                    data_out = net(data, indices_identity, indices_transpose,
                                   data_embedding, data_target, idx_id_tst,
                                   idx_trans_tst)
                    logits_full = data_out[target_rel.id].values.squeeze()
                    logits = logits_full[test_mask]
                else:
                    data_target[target_rel.id] = test_matrix
                    data_out = net(data, indices_identity, indices_transpose,
                                   data_embedding, data_target)
                    logits_full = data_out[target_rel.id].values.squeeze()
                    logits = logits_full

                pred = torch.sigmoid(logits).cpu().numpy()
                left = left.cpu().numpy()
                right = right.cpu().numpy()
                edge_list = np.vstack((left, right))
                edge_list_full = np.vstack((left_full, right_full))
                # NOTE(review): the same file path is used for every target
                # relation; whether results are appended or overwritten
                # depends on gen_file_for_evaluate - verify.
                file_path = f"test_out/{run_name}.txt"
                gen_file_for_evaluate(dl,
                                      edge_list_full,
                                      edge_list,
                                      pred,
                                      target_rel.id,
                                      file_path=file_path)
Пример #6
0
    set_seed(args.seed)

    dataloader = PubMedData(args.node_labels)
    schema = dataloader.schema
    data = dataloader.data.to(device)
    indices_identity, indices_transpose = data.calculate_indices()
    embedding_entity = schema.entities[TARGET_NODE_TYPE]
    input_channels = {
        rel.id: data[rel.id].n_channels
        for rel in schema.relations
    }
    embedding_schema = DataSchema(
        schema.entities,
        Relation(0, [embedding_entity, embedding_entity], is_set=True))
    n_instances = embedding_entity.n_instances
    data_embedding = SparseMatrixData(embedding_schema)
    data_embedding[0] = SparseMatrix(
        indices=torch.arange(n_instances, dtype=torch.int64).repeat(2, 1),
        values=torch.zeros([n_instances, args.embedding_dim]),
        shape=(n_instances, n_instances, args.embedding_dim),
        is_set=True)
    data_embedding.to(device)
    target_schema = DataSchema(schema.entities,
                               schema.relations[TARGET_REL_ID])
    target_node_idx_to_id = dataloader.target_node_idx_to_id
    #%%
    net = SparseMatrixAutoEncoder(schema,
                                  input_channels,
                                  layers=args.layers,
                                  embedding_dim=args.embedding_dim,
                                  embedding_entities=[embedding_entity],
Пример #7
0
 def forward(self, data, data_target=None):
     """Run ``multiply_matrices`` then ``add_bias`` on a fresh output container.

     The output container is a new SparseMatrixData built on
     ``self.schema_out``. ``data_target`` is forwarded unchanged to
     ``self.multiply_matrices``.
     """
     empty_out = SparseMatrixData(self.schema_out)
     multiplied = self.multiply_matrices(data, empty_out, data_target)
     return self.add_bias(multiplied)
Пример #8
0
 def forward(self, data, indices_identity=None, indices_transpose=None):
     """Run ``multiply_matrices`` then ``add_bias`` on a fresh output container.

     The output container is a new SparseMatrixData built on
     ``self.schema_out``. Both index arguments are forwarded unchanged to
     ``self.multiply_matrices``.
     """
     empty_out = SparseMatrixData(self.schema_out)
     multiplied = self.multiply_matrices(data, empty_out, indices_identity,
                                         indices_transpose)
     return self.add_bias(multiplied)
Пример #9
0
def load_data(prefix='DBLP',
              use_node_attrs=True,
              use_edge_data=True,
              feats_type=0):
    """Load a dataset for node classification.

    Args:
        prefix: dataset name, appended to ``DATA_FILE_DIR``. For any value
            other than 'IMDB' the label matrix is reduced with argmax.
        use_node_attrs: when true, one extra ``is_set`` relation per entity
            is added to hold that entity's node attributes.
        use_edge_data: when false, the values of every link matrix are
            replaced with ones (adjacency information only).
        feats_type: accepted but not used by this function.

    Returns:
        Tuple ``(schema, schema_out, data, data_target, labels,
        train_val_test_idx, dl)``. ``train_val_test_idx`` is a dict with the
        keys 'train_idx', 'val_idx' and 'test_idx'.
    """
    dl = data_loader(DATA_FILE_DIR + prefix)

    # Create Schema
    entities = []
    for entity_id, n_instances in sorted(dl.nodes['count'].items()):
        entities.append(Entity(entity_id, n_instances))
    relations = {}
    for rel_id, (entity_i, entity_j) in sorted(dl.links['meta'].items()):
        relations[rel_id] = Relation(rel_id,
                                     [entities[entity_i], entities[entity_j]])
    num_relations = len(relations)
    if use_node_attrs:
        # Attribute relations get ids placed after the real relations.
        for entity in entities:
            attr_rel_id = num_relations + entity.id
            relations[attr_rel_id] = Relation(attr_rel_id, [entity, entity],
                                              is_set=True)
    schema = DataSchema(entities, relations)

    # Collect link data
    data = SparseMatrixData(schema)
    for rel_id, data_matrix in dl.links['data'].items():
        # Slice out the rows/columns belonging to the relation's entities.
        row_entity_id = relations[rel_id].entities[0].id
        col_entity_id = relations[rel_id].entities[1].id
        row_start = dl.nodes['shift'][row_entity_id]
        row_stop = row_start + dl.nodes['count'][row_entity_id]
        col_start = dl.nodes['shift'][col_entity_id]
        col_stop = col_start + dl.nodes['count'][col_entity_id]
        rel_block = data_matrix[row_start:row_stop, col_start:col_stop]
        data[rel_id] = SparseMatrix.from_scipy_sparse(rel_block.tocoo())
        if not use_edge_data:
            # Use only adjacency information
            data[rel_id].values = torch.ones(data[rel_id].values.shape)

    target_entity = 0

    if use_node_attrs:
        for ent_id, attr_matrix in dl.nodes['attr'].items():
            n_instances = dl.nodes['count'][ent_id]
            if attr_matrix is None:
                # Attribute for each node is a single 1
                attr_matrix = np.ones((n_instances, 1))
            n_channels = attr_matrix.shape[1]
            diag_indices = torch.arange(n_instances).unsqueeze(0).repeat(2, 1)
            data[ent_id + num_relations] = SparseMatrix(
                indices=diag_indices,
                values=torch.FloatTensor(attr_matrix),
                shape=np.array([n_instances, n_instances, n_channels]),
                is_set=True)

    # Zero-valued output container for the target entity.
    n_outputs = dl.nodes['count'][target_entity]
    n_output_classes = dl.labels_train['num_classes']
    out_entity = entities[target_entity]
    schema_out = DataSchema(
        [out_entity], [Relation(0, [out_entity, out_entity], is_set=True)])
    data_target = SparseMatrixData(schema_out)
    data_target[0] = SparseMatrix(
        indices=torch.arange(n_outputs, dtype=torch.int64).repeat(2, 1),
        values=torch.zeros([n_outputs, n_output_classes]),
        shape=(n_outputs, n_outputs, n_output_classes),
        is_set=True)

    # Randomly split the labelled training nodes into train and validation.
    labels = np.zeros((n_outputs, n_output_classes), dtype=int)
    val_ratio = 0.2
    labelled_idx = np.nonzero(dl.labels_train['mask'])[0]
    np.random.shuffle(labelled_idx)
    n_val = int(labelled_idx.shape[0] * val_ratio)
    val_idx = np.sort(labelled_idx[:n_val])
    train_idx = np.sort(labelled_idx[n_val:])
    test_idx = np.nonzero(dl.labels_test['mask'])[0]

    labels[train_idx] = dl.labels_train['data'][train_idx]
    labels[val_idx] = dl.labels_train['data'][val_idx]
    if prefix != 'IMDB':
        labels = labels.argmax(axis=1)

    train_val_test_idx = {
        'train_idx': train_idx,
        'val_idx': val_idx,
        'test_idx': test_idx,
    }
    return (schema, schema_out, data, data_target, labels,
            train_val_test_idx, dl)
Пример #10
0
def load_data_flat(prefix,
                   use_node_attrs=True,
                   use_edge_data=True,
                   node_val='one'):
    """Load all relations into a single matrix, reproducing Maron 2019.

    The schema has one entity covering every node and one relation over it.
    Channels are appended in this order: one channel per relation type
    (sorted by relation id), then, if ``use_node_attrs`` is true, the
    attribute channels of each entity type placed on that entity's segment
    of the diagonal.

    Args:
        prefix: dataset name, appended to ``DATA_FILE_DIR``.
        use_node_attrs: whether to append node attribute channels.
        use_edge_data: when false, link values are replaced with ones.
        node_val: fill used for entity types without attributes: 'zero'
            gives zeros, 'rand' gives standard normal samples, anything else
            gives ones.

    Returns:
        Tuple ``(schema, data, dl)``.
    """
    dl = data_loader(DATA_FILE_DIR + prefix)
    total_n_nodes = dl.nodes['total']
    entities = [Entity(0, total_n_nodes)]
    relations = {0: Relation(0, [entities[0], entities[0]])}
    schema = DataSchema(entities, relations)

    # Index pattern shared by all channels: every link plus the diagonal.
    pattern = sum(dl.links['data'].values()).tocoo()
    diagonal = scipy.sparse.coo_matrix(
        (np.ones(total_n_nodes),
         (np.arange(total_n_nodes), np.arange(total_n_nodes))),
        (total_n_nodes, total_n_nodes))
    pattern += diagonal
    data_full = SparseMatrix.from_scipy_sparse(pattern.tocoo()).zero_()
    data_out = SparseMatrix.from_other_sparse_matrix(data_full, 0)

    # Load up all edge data, one channel per relation
    for rel_id in sorted(dl.links['data']):
        data_rel = SparseMatrix.from_scipy_sparse(
            dl.links['data'][rel_id].tocoo())
        if not use_edge_data:
            # Use only adjacency information
            data_rel.values = torch.ones(data_rel.values.shape)
        rel_channel = SparseMatrix.from_other_sparse_matrix(data_full,
                                                            1) + data_rel
        data_out.values = torch.cat([data_out.values, rel_channel.values], 1)
        data_out.n_channels += 1

    if use_node_attrs:
        for ent_id, attr_matrix in dl.nodes['attr'].items():
            start_i = dl.nodes['shift'][ent_id]
            n_instances = dl.nodes['count'][ent_id]
            if attr_matrix is None:
                # No attributes for this entity type: use the node_val fill.
                if node_val == 'zero':
                    attr_matrix = np.zeros((n_instances, 1))
                elif node_val == 'rand':
                    attr_matrix = np.random.randn(n_instances, 1)
                else:
                    attr_matrix = np.ones((n_instances, 1))
            n_channels = attr_matrix.shape[1]
            segment = torch.arange(start_i, start_i + n_instances)
            attr_rel = SparseMatrix(
                indices=segment.unsqueeze(0).repeat(2, 1),
                values=torch.FloatTensor(attr_matrix),
                shape=np.array([total_n_nodes, total_n_nodes, n_channels]),
                is_set=True)
            attr_channels = SparseMatrix.from_other_sparse_matrix(
                data_full, n_channels) + attr_rel
            data_out.values = torch.cat(
                [data_out.values, attr_channels.values], 1)
            data_out.n_channels += n_channels

    data = SparseMatrixData(schema)
    data[0] = data_out

    return schema, data, dl
Пример #11
0
def load_data(prefix,
              use_node_attrs=True,
              use_edge_data=True,
              use_other_edges=True,
              node_val='one'):
    """Build a schema and one sparse matrix per relation for dataset `prefix`.

    A relation is created for every link type, or only for the loader's test
    types when `use_other_edges` is false.  With `use_node_attrs`, each kept
    entity additionally gets a self-relation (`is_set=True`) whose diagonal
    entries carry that entity's attributes; an entity without attributes gets
    one channel chosen by `node_val` ('zero', 'rand', anything else -> ones).
    With `use_edge_data` false, edge values are replaced by ones.

    Returns the tuple (schema, data, dl).
    """

    def placeholder_attrs(count):
        # Single-channel stand-in for entities that have no attribute matrix
        if node_val == 'zero':
            return np.zeros((count, 1))
        if node_val == 'rand':
            return np.random.randn(count, 1)
        return np.ones((count, 1))

    dl = data_loader(DATA_FILE_DIR + prefix)

    all_entities = [
        Entity(ent_id, count)
        for ent_id, count in sorted(dl.nodes['count'].items())
    ]

    # Decide which link types become relations, keeping insertion order
    test_types = dl.test_types
    if use_other_edges:
        kept_rel_ids = [
            rel_id for rel_id, _ in sorted(dl.links['meta'].items())
        ]
    else:
        kept_rel_ids = test_types

    relations = {}
    for rel_id in kept_rel_ids:
        head_idx, tail_idx = dl.links['meta'][rel_id]
        relations[rel_id] = Relation(
            rel_id, [all_entities[head_idx], all_entities[tail_idx]])

    entities = (all_entities if use_other_edges else list(
        np.unique(relations[test_types[0]].entities)))

    # Attribute relations are numbered after the largest link relation id
    max_relation = max(relations) + 1
    if use_node_attrs:
        # Create fake relations to represent node attributes
        relations.update({
            max_relation + entity.id: Relation(max_relation + entity.id,
                                               [entity, entity],
                                               is_set=True)
            for entity in entities
        })
    schema = DataSchema(entities, relations)

    data = SparseMatrixData(schema)
    for rel_id, full_matrix in dl.links['data'].items():
        if not use_other_edges and rel_id not in test_types:
            continue
        # Cut out the rows/columns owned by the two entities of the relation
        row_ent, col_ent = relations[rel_id].entities[:2]
        row_lo = dl.nodes['shift'][row_ent.id]
        row_hi = row_lo + dl.nodes['count'][row_ent.id]
        col_lo = dl.nodes['shift'][col_ent.id]
        col_hi = col_lo + dl.nodes['count'][col_ent.id]
        sub_matrix = full_matrix[row_lo:row_hi, col_lo:col_hi]
        data[rel_id] = SparseMatrix.from_scipy_sparse(sub_matrix.tocoo())
        if not use_edge_data:
            # Use only adjacency information
            data[rel_id].values = torch.ones(data[rel_id].values.shape)

    if use_node_attrs:
        for entity in entities:
            count = dl.nodes['count'][entity.id]
            attrs = dl.nodes['attr'][entity.id]
            if attrs is None:
                attrs = placeholder_attrs(count)
            width = attrs.shape[1]
            # Diagonal positions (i, i) for every instance of the entity
            diagonal = torch.arange(count).unsqueeze(0).repeat(2, 1)
            data[entity.id + max_relation] = SparseMatrix(
                indices=diagonal,
                values=torch.FloatTensor(attrs),
                shape=np.array([count, count, width]),
                is_set=True)

    return schema, data, dl
Пример #12
0
    content_matrix = SparseMatrix(indices=torch.LongTensor(content),
                                  values=value_dist.sample(
                                      (content.shape[1], 1)),
                                  shape=(n_papers, n_words, 1)).coalesce()

    ent_papers = Entity(0, n_papers)
    #ent_classes = Entity(1, n_classes)
    ent_words = Entity(1, n_words)
    rel_paper = Relation(0, [ent_papers, ent_papers], is_set=True)
    rel_cites = Relation(0, [ent_papers, ent_papers])
    rel_content = Relation(1, [ent_papers, ent_words])
    schema = DataSchema([ent_papers, ent_words], [rel_cites, rel_content])
    schema_out = DataSchema([ent_papers], [rel_paper])
    targets = torch.LongTensor(paper[1])

    data = SparseMatrixData(schema)
    data[0] = cites_matrix
    data[1] = content_matrix

    indices_identity, indices_transpose = data.calculate_indices()

    data_target = Data(schema_out)
    data_target[0] = SparseMatrix(indices=torch.arange(
        n_papers, dtype=torch.int64).repeat(2, 1),
                                  values=torch.zeros([n_papers, n_classes]),
                                  shape=(n_papers, n_papers, n_classes))
    data_target = data_target.to(device)

    #%%

    # Loss function:
    def __init__(self, use_node_attrs=True):
        """Build the schema and sparse relation data from the node and link files.

        Each line of NODE_FILE_STR is split on tabs into (node id, node name,
        node type, comma-separated values); each line of LINK_FILE_STR into
        (node i, node j, relation number, value).  Raw node ids are remapped
        to consecutive per-type indices in file order.

        Args:
            use_node_attrs: when true, add one `is_set` self-relation per
                entity type (id = 10 + entity id) holding each node's
                attribute values at position (idx, idx).

        Sets `self.schema`, `self.node_id_to_idx`,
        `self.target_node_idx_to_id` and `self.data`.
        """
        # Offset added to an entity id to get the id of its attribute relation
        attr_rel_offset = 10

        entities = [
            Entity(entity_id, n_instances)
            for entity_id, n_instances in ENTITY_N_INSTANCES.items()
        ]
        relations = [
            Relation(rel_id, [entities[entity_i], entities[entity_j]])
            for rel_id, (entity_i, entity_j) in RELATION_IDX.items()
        ]
        if use_node_attrs:
            for entity_id in ENTITY_N_INSTANCES:
                relations.append(
                    Relation(attr_rel_offset + entity_id,
                             [entities[entity_id], entities[entity_id]],
                             is_set=True))
        self.schema = DataSchema(entities, relations)

        # Read and parse the node file once; it is traversed twice below
        node_records = []
        with open(NODE_FILE_STR, 'r') as node_file:
            for line in node_file:
                node_id, _node_name, node_type, values = line.rstrip().split(
                    '\t')
                node_records.append((int(node_id), int(node_type), values))

        # First traversal: assign consecutive per-type indices in file order
        self.node_id_to_idx = {ent_i: {} for ent_i in range(len(entities))}
        node_counter = {ent_i: 0 for ent_i in range(len(entities))}
        for node_id, node_type, _ in node_records:
            self.node_id_to_idx[node_type][node_id] = node_counter[node_type]
            node_counter[node_type] += 1
        target_node_id_to_idx = self.node_id_to_idx[TARGET_NODE_TYPE]
        self.target_node_idx_to_id = {
            idx: node_id
            for node_id, idx in target_node_id_to_idx.items()
        }

        # Key the buffers by the actual relation ids: attribute relations use
        # ids offset by `attr_rel_offset`, so positions 0..len(relations)-1
        # only coincide with the ids when there are exactly that many links.
        raw_data_indices = {rel.id: [] for rel in relations}
        raw_data_values = {rel.id: [] for rel in relations}
        if use_node_attrs:
            # Second traversal runs after the full index map is built, so a
            # repeated node id resolves to its final index
            for node_id, node_type, values in node_records:
                node_idx = self.node_id_to_idx[node_type][node_id]
                attr_rel_id = attr_rel_offset + node_type
                raw_data_indices[attr_rel_id].append([node_idx, node_idx])
                raw_data_values[attr_rel_id].append(
                    list(map(float, values.split(','))))

        with open(LINK_FILE_STR, 'r') as link_file:
            for line in link_file:
                node_i, node_j, rel_num, val = line.rstrip().split('\t')
                rel_num = int(rel_num)
                node_i_type, node_j_type = RELATION_IDX[rel_num]
                node_i = self.node_id_to_idx[node_i_type][int(node_i)]
                node_j = self.node_id_to_idx[node_j_type][int(node_j)]
                raw_data_indices[rel_num].append([node_i, node_j])
                raw_data_values[rel_num].append([float(val)])

        self.data = SparseMatrixData(self.schema)
        for rel in relations:
            # pop() drops the Python lists as soon as they become tensors
            indices = torch.LongTensor(raw_data_indices.pop(rel.id)).T
            values = torch.Tensor(raw_data_values.pop(rel.id))
            n = rel.entities[0].n_instances
            m = rel.entities[1].n_instances
            n_channels = values.shape[1]
            self.data[rel.id] = SparseMatrix(indices=indices,
                                             values=values,
                                             shape=np.array(
                                                 [n, m, n_channels]),
                                             is_set=rel.is_set)
Пример #14
0
    def __init__(self):
        """Read the relation CSV files and build dense inputs and the target.

        Every relation listed in `relation_names` is read column-wise, turned
        into a sparse matrix, and the relations kept in `self.schema`
        (person, course, taughtBy) are stored densely in `self.data`.  The
        dense matrix of relation 2 ('advisedBy') becomes `self.target`.
        """
        self.target_relation = 'advisedBy'
        self.target_rel_id = 2
        self.output_dim = 1

        # One list per column of every relation described in schema_dict
        data_raw = {
            name: {column: [] for column in schema_dict[name].keys()}
            for name in schema_dict.keys()
        }
        for relation_name in relation_names:
            with open(csv_file_str.format(relation_name)) as file:
                columns = schema_dict[relation_name].keys()
                for row in csv.reader(file):
                    for column, cell in zip(columns, row):
                        data_raw[relation_name][column].append(cell)

        ent_person = Entity(0, len(data_raw['person']['p_id']))
        ent_course = Entity(1, len(data_raw['course']['course_id']))
        entities = [ent_person, ent_course]

        # Matrix-shaped relations used while loading ...
        rel_taughtBy = Relation(3, [ent_course, ent_person])
        relations_matrix = [
            Relation(0, [ent_person, ent_person], is_set=True),
            Relation(1, [ent_course, ent_course], is_set=True),
            Relation(2, [ent_person, ent_person]),
            rel_taughtBy,
        ]
        # ... and the relations exposed as model input (no advisedBy)
        relations = [
            Relation(0, [ent_person]),
            Relation(1, [ent_course]),
            rel_taughtBy,
        ]

        self.schema = DataSchema(entities, relations)
        schema_matrix = DataSchema(entities, relations_matrix)
        matrix_data = SparseMatrixData(schema_matrix)

        ent_id_to_idx_dict = {
            'person': self.id_to_idx(data_raw['person']['p_id']),
            'course': self.id_to_idx(data_raw['course']['course_id'])
        }

        for relation in relations_matrix:
            relation_name = relation_names[relation.id]
            print(relation_name)
            if relation.is_set:
                matrix_data[relation.id] = self.set_relation_to_matrix(
                    relation, schema_dict[relation_name],
                    data_raw[relation_name])
                continue
            # Column names holding the row / column entity ids
            if relation_name == 'advisedBy':
                ent_n_id_str = 'p_id'
                ent_m_id_str = 'p_id_dummy'
            elif relation_name == 'taughtBy':
                ent_n_id_str = 'course_id'
                ent_m_id_str = 'p_id'
            matrix_data[relation.id] = self.binary_relation_to_matrix(
                relation, schema_dict[relation_name], data_raw[relation_name],
                ent_id_to_idx_dict, ent_n_id_str, ent_m_id_str)

        self.schema_out = DataSchema([ent_person],
                                     [Relation(2, [ent_person, ent_person])])

        data = Data(self.schema)
        for rel_matrix in schema_matrix.relations:
            # Only relations that also exist in the input schema are kept
            if not any(rel.id == rel_matrix.id
                       for rel in self.schema.relations):
                continue
            dense = matrix_data[rel_matrix.id].to_dense()
            if rel_matrix.is_set:
                # Set relations keep just their diagonal
                dense = torch.diagonal(dense, offset=0, dim1=1, dim2=2)
            data[rel_matrix.id] = dense.unsqueeze(0)
        self.data = data

        self.target = matrix_data[self.target_rel_id].to_dense().squeeze()
def run_model(args):
    """Train a node-classification network and optionally evaluate it.

    Loads the dataset named by `args.dataset` (flat single-matrix layout when
    `args.lgnn` is set), trains an `AlternatingHGN` for `args.epoch` epochs
    with Adam, validates every `args.val_every` epochs and checkpoints the
    model whenever the validation micro-F1 improves.  When `args.evaluate`
    is set, the checkpoint is reloaded, test predictions are written to
    ``test_out/<run_name>.txt`` and the loader's evaluation is printed.

    Args:
        args: parsed command-line namespace.  Attributes read here: lgnn,
            dataset, use_edge_data, use_node_attr, feats_type, multi_label,
            embedding_dim, width, depth, act_fn, dropout, norm, pool_op,
            norm_affine, norm_out, lr, weight_decay, wandb_log_run,
            wandb_log_param_freq, wandb_log_loss_freq, run, checkpoint_path,
            epoch, val_every, evaluate.
    """
    import os  # local import: only used to create output directories

    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    load_data_fn = load_data_flat if args.lgnn else load_data
    schema, _schema_out, data, data_target, labels, \
        train_val_test_idx, dl = load_data_fn(args.dataset,
                       use_edge_data=args.use_edge_data,
                       use_node_attrs=args.use_node_attr,
                       feats_type=args.feats_type)
    target_entity_id = 0  # True for all current NC datasets
    data, in_dims = select_features(data, schema, args.feats_type,
                                    target_entity_id)
    if args.multi_label:
        labels = torch.FloatTensor(labels).to(device)
    else:
        labels = torch.LongTensor(labels).to(device)
    train_idx = np.sort(train_val_test_idx['train_idx'])
    val_idx = np.sort(train_val_test_idx['val_idx'])
    test_idx = np.sort(train_val_test_idx['test_idx'])

    data = data.to(device)

    data_embedding = SparseMatrixData.make_entity_embeddings(
        schema.entities, args.embedding_dim)
    # Keep the returned object, as done for `data` and `data_target`, so the
    # embeddings are on `device` whether or not `.to` works in place.
    data_embedding = data_embedding.to(device)
    # NOTE(review): the result is unused in this function; the call is kept in
    # case it populates state on `data` — confirm before removing.
    data.calculate_indices()
    data_target = data_target.to(device)

    num_classes = dl.labels_train['num_classes']

    # NOTE(review): eval() on a command-line value executes arbitrary code;
    # consider `getattr(nn, args.act_fn)()` if only plain class names are used.
    activation = eval('nn.%s()' % args.act_fn)
    net = AlternatingHGN(schema,
                         in_dims,
                         width=args.width,
                         depth=args.depth,
                         embedding_dim=args.embedding_dim,
                         activation=activation,
                         final_activation=nn.Identity(),
                         dropout=args.dropout,
                         output_dim=num_classes,
                         norm=args.norm,
                         pool_op=args.pool_op,
                         norm_affine=args.norm_affine,
                         norm_out=args.norm_out)

    net.to(device)
    optimizer = torch.optim.Adam(net.parameters(),
                                 lr=args.lr,
                                 weight_decay=args.weight_decay)

    if args.wandb_log_run:
        wandb.init(config=args,
                   settings=wandb.Settings(start_method='fork'),
                   project="EquivariantHGN_NC",
                   entity='danieltlevy')
        wandb.watch(net, log='all', log_freq=args.wandb_log_param_freq)
    print(args)
    print("Number of parameters: {}".format(count_parameters(net)))

    run_name = args.dataset + '_' + str(args.run)
    if args.wandb_log_run and wandb.run.name is not None:
        run_name = run_name + '_' + str(wandb.run.name)

    if args.checkpoint_path != '':
        checkpoint_path = args.checkpoint_path
    else:
        checkpoint_path = f"checkpoint/checkpoint_{run_name}.pt"
    # torch.save does not create missing parent directories
    checkpoint_dir = os.path.dirname(checkpoint_path)
    if checkpoint_dir:
        os.makedirs(checkpoint_dir, exist_ok=True)

    print("Checkpoint Path: " + checkpoint_path)
    progress = tqdm(range(args.epoch), desc="Epoch 0", position=0, leave=True)
    # training loop
    net.train()
    val_micro_best = 0
    for epoch in progress:
        # training
        net.train()
        optimizer.zero_grad()
        logits = net(data, data_embedding).squeeze()
        logp = regr_fcn(logits, args.multi_label)
        train_loss = loss_fcn(logp[train_idx], labels[train_idx],
                              args.multi_label)
        train_loss.backward()
        optimizer.step()
        if args.multi_label:
            train_micro, train_macro = f1_scores_multi(
                logits[train_idx], dl.labels_train['data'][train_idx])
        else:
            train_micro, train_macro = f1_scores(logits[train_idx],
                                                 labels[train_idx])
        with torch.no_grad():
            progress.set_description(f"Epoch {epoch}")
            progress.set_postfix(loss=train_loss.item(), micr=train_micro)
            wandb_log = {
                'Train Loss': train_loss.item(),
                'Train Micro': train_micro,
                'Train Macro': train_macro
            }
            if epoch % args.val_every == 0:
                # validation
                net.eval()
                logits = net(data, data_embedding).squeeze()
                logp = regr_fcn(logits, args.multi_label)
                val_loss = loss_fcn(logp[val_idx], labels[val_idx],
                                    args.multi_label)
                if args.multi_label:
                    val_micro, val_macro = f1_scores_multi(
                        logits[val_idx], dl.labels_train['data'][val_idx])
                else:
                    val_micro, val_macro = f1_scores(logits[val_idx],
                                                     labels[val_idx])
                print("\nVal Loss: {:.3f} Val Micro-F1: {:.3f} "
                      "Val Macro-F1: {:.3f}".format(val_loss, val_micro,
                                                    val_macro))
                wandb_log.update({
                    'Val Loss': val_loss.item(),
                    'Val Micro-F1': val_micro,
                    'Val Macro-F1': val_macro
                })
                if val_micro > val_micro_best:
                    # Checkpoint on every strict improvement of micro-F1
                    val_micro_best = val_micro
                    print("New best, saving")
                    torch.save(
                        {
                            'epoch': epoch,
                            'net_state_dict': net.state_dict(),
                            'optimizer_state_dict': optimizer.state_dict(),
                            'train_loss': train_loss.item(),
                            'train_micro': train_micro,
                            'train_macro': train_macro,
                            'val_loss': val_loss.item(),
                            'val_micro': val_micro,
                            'val_macro': val_macro
                        }, checkpoint_path)
                    if args.wandb_log_run:
                        summary = wandb.run.summary
                        summary["val_micro_best"] = val_micro
                        summary["val_macro_best"] = val_macro
                        summary["val_loss_best"] = val_loss.item()
                        summary["epoch_best"] = epoch
                        summary["train_loss_best"] = train_loss.item()
                        summary['train_micro_best'] = train_micro
                        summary['train_macro_best'] = train_macro
                        wandb.save(checkpoint_path)

            if epoch % args.wandb_log_loss_freq == 0:
                if args.wandb_log_run:
                    wandb.log(wandb_log, step=epoch)

    # testing with evaluate_results_nc
    if args.evaluate:

        checkpoint = torch.load(checkpoint_path)
        net.load_state_dict(checkpoint['net_state_dict'])
        net.eval()
        with torch.no_grad():
            logits = net(data, data_embedding).squeeze()
            test_logits = logits[test_idx]
            if args.multi_label:
                # One independent decision per class: positive logit -> 1
                pred = (test_logits.cpu().numpy() > 0).astype(int)
            else:
                pred = test_logits.cpu().numpy().argmax(axis=1)
                onehot = np.eye(num_classes, dtype=np.int32)

            file_path = f"test_out/{run_name}.txt"
            # Make sure the output directory exists before the loader writes
            os.makedirs(os.path.dirname(file_path), exist_ok=True)
            dl.gen_file_for_evaluate(test_idx=test_idx,
                                     label=pred,
                                     file_path=file_path,
                                     multi_label=args.multi_label)
            if not args.multi_label:
                # dl.evaluate is given one-hot rows in the single-label case
                pred = onehot[pred]
            print(dl.evaluate(pred))
Пример #16
0
    def __init__(self):
        """Read the relation CSV files into sparse matrices and set up targets.

        Populates `self.schema` / `self.data` with the four relations
        (person, course, advisedBy, taughtBy), derives `self.target` from the
        column `self.TARGET_KEY` of `self.TARGET_RELATION`, and creates a
        zero-filled diagonal output matrix over persons with one channel per
        distinct target value.
        """
        # One list per column of every relation described in schema_dict
        data_raw = {
            name: {column: [] for column in schema_dict[name].keys()}
            for name in schema_dict.keys()
        }
        for relation_name in relation_names:
            with open(csv_file_str.format(relation_name)) as file:
                columns = schema_dict[relation_name].keys()
                for row in csv.reader(file):
                    for column, cell in zip(columns, row):
                        data_raw[relation_name][column].append(cell)

        ent_person = Entity(0, len(data_raw['person']['p_id']))
        ent_course = Entity(1, len(data_raw['course']['course_id']))
        entities = [ent_person, ent_course]

        relations = [
            Relation(0, [ent_person, ent_person], is_set=True),
            Relation(1, [ent_course, ent_course], is_set=True),
            Relation(2, [ent_person, ent_person]),
            Relation(3, [ent_course, ent_person]),
        ]

        self.schema = DataSchema(entities, relations)
        self.data = SparseMatrixData(self.schema)

        ent_id_to_idx_dict = {
            'person': self.id_to_idx(data_raw['person']['p_id']),
            'course': self.id_to_idx(data_raw['course']['course_id'])
        }

        for relation in relations:
            relation_name = relation_names[relation.id]
            print(relation_name)
            if relation.is_set:
                self.data[relation.id] = self.set_relation_to_matrix(
                    relation, schema_dict[relation_name],
                    data_raw[relation_name])
                continue
            # Column names holding the row / column entity ids
            if relation_name == 'advisedBy':
                ent_n_id_str = 'p_id'
                ent_m_id_str = 'p_id_dummy'
            elif relation_name == 'taughtBy':
                ent_n_id_str = 'course_id'
                ent_m_id_str = 'p_id'
            self.data[relation.id] = self.binary_relation_to_matrix(
                relation, schema_dict[relation_name], data_raw[relation_name],
                ent_id_to_idx_dict, ent_n_id_str, ent_m_id_str)

        target_column = data_raw[self.TARGET_RELATION][self.TARGET_KEY]
        self.target = self.get_targets(
            target_column, schema_dict[self.TARGET_RELATION][self.TARGET_KEY])

        self.target_rel_id = 0
        self.schema_out = DataSchema([ent_person], [
            Relation(self.target_rel_id, [ent_person, ent_person],
                     is_set=True)
        ])
        self.data_target = Data(self.schema_out)

        # One output channel per distinct value of the target column
        n_output_classes = len(np.unique(target_column))
        self.output_dim = n_output_classes
        n_person = ent_person.n_instances
        diagonal = torch.arange(n_person, dtype=torch.int64).repeat(2, 1)
        self.data_target[self.target_rel_id] = SparseMatrix(
            indices=diagonal,
            values=torch.zeros([n_person, n_output_classes]),
            shape=(n_person, n_person, n_output_classes))
Пример #17
0
def load_data_flat(prefix,
                   use_node_attrs=True,
                   use_edge_data=True,
                   node_val='zero',
                   feats_type=0):
    '''
    Load data into one matrix with all relations, reproducing Maron 2019.
    The first [# relation types] channels are adjacency matrices; when
    use_node_attrs=True the following [sum of feature dimensions per entity
    type] channels hold node attributes on the matching segment of the
    diagonal. An entity type without an attribute matrix gets a single
    channel chosen by node_val ('zero', 'rand', anything else -> ones).

    Returns (schema, schema_out, data, data_target, labels,
    train_val_test_idx, dl).
    '''
    dl = data_loader(DATA_FILE_DIR + prefix)
    total_n_nodes = dl.nodes['total']
    all_nodes = Entity(0, total_n_nodes)
    entities = [all_nodes]
    relations = {0: Relation(0, [all_nodes, all_nodes])}
    schema = DataSchema(entities, relations)

    # Shared sparsity pattern: union of all edges plus the full diagonal
    pattern = sum(dl.links['data'].values()).tocoo()
    pattern += scipy.sparse.coo_matrix(
        (np.ones(total_n_nodes),
         (np.arange(total_n_nodes), np.arange(total_n_nodes))),
        (total_n_nodes, total_n_nodes))
    data_full = SparseMatrix.from_scipy_sparse(pattern.tocoo()).zero_()
    data_out = SparseMatrix.from_other_sparse_matrix(data_full, 0)

    def append_channels(block, width):
        # Grow data_out by the channels of `block` (same sparsity pattern)
        data_out.values = torch.cat([data_out.values, block.values], 1)
        data_out.n_channels += width

    # Load up all edge data, one channel per relation type
    for rel_id in sorted(dl.links['data']):
        edges = SparseMatrix.from_scipy_sparse(
            dl.links['data'][rel_id].tocoo())
        if not use_edge_data:
            # Use only adjacency information
            edges.values = torch.ones(edges.values.shape)
        padded = SparseMatrix.from_other_sparse_matrix(data_full, 1) + edges
        append_channels(padded, 1)

    target_entity = 0

    if use_node_attrs:
        for ent_id, attrs in dl.nodes['attr'].items():
            offset = dl.nodes['shift'][ent_id]
            count = dl.nodes['count'][ent_id]
            if attrs is None:
                if node_val == 'zero':
                    attrs = np.zeros((count, 1))
                elif node_val == 'rand':
                    attrs = np.random.randn(count, 1)
                else:
                    attrs = np.ones((count, 1))
            if feats_type == 1 and ent_id != target_entity:
                # To keep same behaviour as non-LGNN model, use 10 dimensions
                attrs = np.zeros((count, 10))
            width = attrs.shape[1]
            # Diagonal positions of this entity type within the full matrix
            diagonal = torch.arange(offset, offset + count)
            diagonal = diagonal.unsqueeze(0).repeat(2, 1)
            attr_block = SparseMatrix(
                indices=diagonal,
                values=torch.FloatTensor(attrs),
                shape=np.array([total_n_nodes, total_n_nodes, width]),
                is_set=True)
            padded = SparseMatrix.from_other_sparse_matrix(
                data_full, width) + attr_block
            append_channels(padded, width)

    data = SparseMatrixData(schema)
    data[0] = data_out

    # Output: zero-filled diagonal matrix with one channel per class
    n_outputs = total_n_nodes
    n_output_classes = dl.labels_train['num_classes']
    out_entity = entities[target_entity]
    schema_out = DataSchema(
        [out_entity], [Relation(0, [out_entity, out_entity], is_set=True)])
    data_target = SparseMatrixData(schema_out)
    data_target[0] = SparseMatrix(
        indices=torch.arange(n_outputs, dtype=torch.int64).repeat(2, 1),
        values=torch.zeros([n_outputs, n_output_classes]),
        shape=(n_outputs, n_outputs, n_output_classes),
        is_set=True)

    # Randomly hold out a fraction of the labelled nodes for validation
    val_ratio = 0.2
    labelled_idx = np.nonzero(dl.labels_train['mask'])[0]
    np.random.shuffle(labelled_idx)
    n_val = int(labelled_idx.shape[0] * val_ratio)
    val_idx = np.sort(labelled_idx[:n_val])
    train_idx = np.sort(labelled_idx[n_val:])
    test_idx = np.nonzero(dl.labels_test['mask'])[0]

    labels = np.zeros((dl.nodes['count'][0], n_output_classes), dtype=int)
    labels[train_idx] = dl.labels_train['data'][train_idx]
    labels[val_idx] = dl.labels_train['data'][val_idx]
    if prefix != 'IMDB':
        # Every dataset except IMDB is reduced to a single class index
        labels = labels.argmax(axis=1)

    train_val_test_idx = {
        'train_idx': train_idx,
        'val_idx': val_idx,
        'test_idx': test_idx,
    }

    return (schema, schema_out, data, data_target, labels,
            train_val_test_idx, dl)