Example #1

import numpy as np
import torch
import torch.nn as nn
import wandb
from tqdm import tqdm

# Repo-local helpers (load_data, select_features, SparseMatrixData, DataSchema,
# EquivLinkPredictor, make_target_matrix, make_target_matrix_test,
# coalesce_matrix, combine_matrices, get_train_valid_pos, get_train_neg,
# get_valid_neg, get_valid_neg_2hop, get_test_neigh_from_file,
# gen_file_for_evaluate, count_parameters) are assumed importable from the
# surrounding project.

def run_model(args):
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    use_equiv = args.decoder == 'equiv'

    # Collect data and schema
    schema, data_original, dl = load_data(args.dataset,
                                          use_edge_data=args.use_edge_data,
                                          use_other_edges=args.use_other_edges,
                                          use_node_attrs=args.use_node_attrs,
                                          node_val=args.node_val)
    data, in_dims = select_features(data_original, schema, args.feats_type)
    data = data.to(device)

    # Precompute data indices
    indices_identity, indices_transpose = data.calculate_indices()
    # Get target relations and create data structure for embeddings
    target_rel_ids = dl.links_test['data'].keys()
    target_rels = [schema.relations[rel_id] for rel_id in target_rel_ids]
    target_ents = schema.entities
    # Get relations used by decoder
    if use_equiv:
        output_rels = schema.relations
    else:
        output_rels = {rel.id: rel for rel in target_rels}
    data_embedding = SparseMatrixData.make_entity_embeddings(
        target_ents, args.embedding_dim)
    data_embedding.to(device)

    # Get training and validation positive samples now
    train_pos_heads, train_pos_tails = dict(), dict()
    val_pos_heads, val_pos_tails = dict(), dict()
    for target_rel_id in target_rel_ids:
        train_val_pos = get_train_valid_pos(dl, target_rel_id)
        train_pos_heads[target_rel_id], train_pos_tails[target_rel_id], \
            val_pos_heads[target_rel_id], val_pos_tails[target_rel_id] = train_val_pos

    # Get additional indices to be used when making predictions
    pred_idx_matrices = {}
    for target_rel in target_rels:
        if args.pred_indices == 'train':
            train_neg_head, train_neg_tail = get_train_neg(
                dl, target_rel.id, tail_weighted=args.tail_weighted)
            pred_idx_matrices[target_rel.id] = make_target_matrix(
                target_rel, train_pos_heads[target_rel.id],
                train_pos_tails[target_rel.id], train_neg_head, train_neg_tail,
                device)
        elif args.pred_indices == 'train_neg':
            # Get negative samples twice
            train_neg_head1, train_neg_tail1 = get_train_neg(
                dl, target_rel.id, tail_weighted=args.tail_weighted)
            train_neg_head2, train_neg_tail2 = get_train_neg(
                dl, target_rel.id, tail_weighted=args.tail_weighted)
            pred_idx_matrices[target_rel.id] = make_target_matrix(
                target_rel, train_neg_head1, train_neg_tail1, train_neg_head2,
                train_neg_tail2, device)
        elif args.pred_indices == 'none':
            pred_idx_matrices[target_rel.id] = None

    # Create network and optimizer
    net = EquivLinkPredictor(schema,
                             in_dims,
                             layers=args.layers,
                             embedding_dim=args.embedding_dim,
                             embedding_entities=target_ents,
                             output_rels=output_rels,
                             activation=getattr(nn, args.act_fn)(),
                             final_activation=nn.Identity(),
                             dropout=args.dropout,
                             pool_op=args.pool_op,
                             norm_affine=args.norm_affine,
                             norm_embed=args.norm_embed,
                             in_fc_layer=args.in_fc_layer,
                             decode=args.decoder)
    net.to(device)
    optimizer = torch.optim.Adam(net.parameters(),
                                 lr=args.lr,
                                 weight_decay=args.weight_decay)

    # Set up logging and checkpointing
    if args.wandb_log_run:
        wandb.init(config=args,
                   settings=wandb.Settings(start_method='fork'),
                   project="EquivariantHGN_LP",
                   entity='danieltlevy')
        wandb.watch(net, log='all', log_freq=args.wandb_log_param_freq)
    print(args)
    print("Number of parameters: {}".format(count_parameters(net)))
    run_name = args.dataset + '_' + str(args.run)
    if args.wandb_log_run and wandb.run.name is not None:
        run_name = run_name + '_' + str(wandb.run.name)
    if args.checkpoint_path != '':
        checkpoint_path = args.checkpoint_path
    else:
        checkpoint_path = f"checkpoint/checkpoint_{run_name}.pt"
    print("Checkpoint Path: " + checkpoint_path)
    val_metric_best = -1e10

    # training
    loss_func = nn.BCELoss()
    progress = tqdm(range(args.epoch), desc="Epoch 0", position=0, leave=True)
    for epoch in progress:
        net.train()
        # Make target matrix and labels to train on
        if use_equiv:
            # Target is same as input
            target_schema = schema
            data_target = data.clone()
        else:
            # Target is just target relation
            target_schema = DataSchema(schema.entities, target_rels)
            data_target = SparseMatrixData(target_schema)
        labels_train = torch.Tensor([]).to(device)
        for target_rel in target_rels:
            train_neg_head, train_neg_tail = get_train_neg(
                dl, target_rel.id, tail_weighted=args.tail_weighted)
            train_matrix = make_target_matrix(target_rel,
                                              train_pos_heads[target_rel.id],
                                              train_pos_tails[target_rel.id],
                                              train_neg_head, train_neg_tail,
                                              device)
            data_target[target_rel.id] = train_matrix
            labels_train_rel = train_matrix.values.squeeze()
            labels_train = torch.cat([labels_train, labels_train_rel])

        # Make prediction
        if use_equiv:
            idx_id_tgt, idx_trans_tgt = data_target.calculate_indices()
            output_data = net(data, indices_identity, indices_transpose,
                              data_embedding, data_target, idx_id_tgt,
                              idx_trans_tgt)
        else:
            output_data = net(data, indices_identity, indices_transpose,
                              data_embedding, data_target)
        logits_combined = torch.Tensor([]).to(device)
        for target_rel in target_rels:
            logits_rel = output_data[target_rel.id].values.squeeze()
            logits_combined = torch.cat([logits_combined, logits_rel])

        # nn.BCELoss expects probabilities, so apply a sigmoid to the logits
        logp = torch.sigmoid(logits_combined)
        train_loss = loss_func(logp, labels_train)

        # autograd
        optimizer.zero_grad()
        train_loss.backward()
        optimizer.step()

        # Update logging
        progress.set_description(f"Epoch {epoch}")
        progress.set_postfix(loss=train_loss.item())
        wandb_log = {'Train Loss': train_loss.item(), 'epoch': epoch}

        # Evaluate on validation set
        if epoch % args.val_every == 0:
            net.eval()
            with torch.no_grad():
                left = torch.Tensor([]).to(device)
                right = torch.Tensor([]).to(device)
                labels_val = torch.Tensor([]).to(device)
                valid_masks = {}
                for target_rel in target_rels:
                    if args.val_neg == '2hop':
                        valid_neg_head, valid_neg_tail = get_valid_neg_2hop(
                            dl, target_rel.id)
                    elif args.val_neg == 'randomtw':
                        valid_neg_head, valid_neg_tail = get_valid_neg(
                            dl, target_rel.id, tail_weighted=True)
                    else:
                        valid_neg_head, valid_neg_tail = get_valid_neg(
                            dl, target_rel.id)
                    valid_matrix_full = make_target_matrix(
                        target_rel, val_pos_heads[target_rel.id],
                        val_pos_tails[target_rel.id], valid_neg_head,
                        valid_neg_tail, device)
                    valid_matrix, left_rel, right_rel, labels_val_rel = coalesce_matrix(
                        valid_matrix_full)
                    left = torch.cat([left, left_rel])
                    right = torch.cat([right, right_rel])
                    labels_val = torch.cat([labels_val, labels_val_rel])
                    if use_equiv:
                        # Add in additional prediction indices
                        pred_idx_matrix = pred_idx_matrices[target_rel.id]
                        if pred_idx_matrix is None:
                            valid_combined_matrix = valid_matrix
                            valid_mask = torch.arange(
                                valid_matrix.nnz()).to(device)
                        else:
                            valid_combined_matrix, valid_mask = combine_matrices(
                                valid_matrix, pred_idx_matrix)
                        valid_masks[target_rel.id] = valid_mask
                        data_target[target_rel.id] = valid_combined_matrix
                    else:
                        data_target[target_rel.id] = valid_matrix

                if use_equiv:
                    # Zero out the stored values so only the sparsity pattern
                    # of the target matrices informs the equivariant decoder
                    data_target.zero_()
                    idx_id_val, idx_trans_val = data_target.calculate_indices()
                    output_data = net(data, indices_identity,
                                      indices_transpose, data_embedding,
                                      data_target, idx_id_val, idx_trans_val)
                else:
                    output_data = net(data, indices_identity,
                                      indices_transpose, data_embedding,
                                      data_target)
                logits_combined = torch.Tensor([]).to(device)
                for target_rel in target_rels:
                    logits_rel_full = output_data[
                        target_rel.id].values.squeeze()
                    if use_equiv:
                        logits_rel = logits_rel_full[valid_masks[
                            target_rel.id]]
                    else:
                        logits_rel = logits_rel_full
                    logits_combined = torch.cat([logits_combined, logits_rel])

                logp = torch.sigmoid(logits_combined)
                val_loss = loss_func(logp, labels_val).item()

                wandb_log.update({'val_loss': val_loss})
                left = left.cpu().numpy()
                right = right.cpu().numpy()
                edge_list = np.concatenate(
                    [left.reshape((1, -1)),
                     right.reshape((1, -1))], axis=0)
                res = dl.evaluate(edge_list,
                                  logp.cpu().numpy(),
                                  labels_val.cpu().numpy())
                val_roc_auc = res['roc_auc']
                val_mrr = res['MRR']
                wandb_log.update(res)
                print("\nVal Loss: {:.3f} Val ROC AUC: {:.3f} Val MRR: {:.3f}".
                      format(val_loss, val_roc_auc, val_mrr))
                if args.val_metric == 'loss':
                    val_metric = -val_loss
                elif args.val_metric == 'roc_auc':
                    val_metric = val_roc_auc
                elif args.val_metric == 'mrr':
                    val_metric = val_mrr
                else:
                    raise ValueError("Unknown val_metric: " + args.val_metric)

                if val_metric > val_metric_best:
                    val_metric_best = val_metric
                    print("New best, saving")
                    torch.save(
                        {
                            'epoch': epoch,
                            'net_state_dict': net.state_dict(),
                            'optimizer_state_dict': optimizer.state_dict(),
                            'train_loss': train_loss.item(),
                            'val_loss': val_loss,
                            'val_roc_auc': val_roc_auc,
                            'val_mrr': val_mrr
                        }, checkpoint_path)
                    if args.wandb_log_run:
                        wandb.summary["val_roc_auc_best"] = val_roc_auc
                        wandb.summary["val_mrr_best"] = val_mrr
                        wandb.summary["val_loss_best"] = val_loss
                        wandb.summary["epoch_best"] = epoch
                        wandb.summary["train_loss_best"] = train_loss.item()
                        wandb.save(checkpoint_path)
        if args.wandb_log_run:
            wandb.log(wandb_log)

    # Evaluate on test set
    if args.evaluate:
        for target_rel in target_rels:
            print("Evaluating Target Rel " + str(target_rel.id))
            checkpoint = torch.load(checkpoint_path, map_location=device)
            net.load_state_dict(checkpoint['net_state_dict'])
            net.eval()

            # Target is same as input
            data_target = data.clone()
            with torch.no_grad():
                left_full, right_full, test_labels_full = get_test_neigh_from_file(
                    dl, args.dataset, target_rel.id)
                test_matrix_full = make_target_matrix_test(
                    target_rel, left_full, right_full, test_labels_full,
                    device)
                test_matrix, left, right, test_labels = coalesce_matrix(
                    test_matrix_full)
                if use_equiv:
                    # Combine the test indices with the last training matrix
                    # (from the final epoch) to give the decoder extra
                    # prediction indices, then zero the values so labels
                    # cannot leak
                    test_combined_matrix, test_mask = combine_matrices(
                        test_matrix, train_matrix)
                    data_target[target_rel.id] = test_combined_matrix
                    data_target.zero_()
                    idx_id_tst, idx_trans_tst = data_target.calculate_indices()
                    data_out = net(data, indices_identity, indices_transpose,
                                   data_embedding, data_target, idx_id_tst,
                                   idx_trans_tst)
                    logits_full = data_out[target_rel.id].values.squeeze()
                    logits = logits_full[test_mask]
                else:
                    data_target[target_rel.id] = test_matrix
                    data_out = net(data, indices_identity, indices_transpose,
                                   data_embedding, data_target)
                    logits_full = data_out[target_rel.id].values.squeeze()
                    logits = logits_full

                pred = torch.sigmoid(logits).cpu().numpy()
                left = left.cpu().numpy()
                right = right.cpu().numpy()
                edge_list = np.vstack((left, right))
                edge_list_full = np.vstack((left_full, right_full))
                file_path = f"test_out/{run_name}.txt"
                gen_file_for_evaluate(dl,
                                      edge_list_full,
                                      edge_list,
                                      pred,
                                      target_rel.id,
                                      file_path=file_path)
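
A minimal invocation sketch for the example above. The flag names are the ones
run_model reads, but every value shown is an illustrative assumption, not the
repo's defaults:

from argparse import Namespace

# All values below are placeholders; only the attribute names are taken from
# the code above.
args = Namespace(
    dataset='LastFM', use_edge_data=True, use_other_edges=True,
    use_node_attrs=True, node_val='one', feats_type=0, decoder='equiv',
    embedding_dim=64, layers=[64, 64], act_fn='LeakyReLU', dropout=0.2,
    pool_op='mean', norm_affine=True, norm_embed=False, in_fc_layer=True,
    lr=1e-3, weight_decay=1e-4, epoch=300, val_every=5, val_neg='2hop',
    val_metric='mrr', pred_indices='train', tail_weighted=False,
    run=1, checkpoint_path='', evaluate=True,
    wandb_log_run=False, wandb_log_param_freq=100)

run_model(args)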
Example #2

# Node-classification variant. In addition to the helpers listed above, it
# relies on load_data_flat, AlternatingHGN, regr_fcn, loss_fcn, f1_scores,
# and f1_scores_multi from the surrounding project.


def run_model(args):
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    if args.lgnn:
        load_data_fn = load_data_flat
    else:
        load_data_fn = load_data
    schema, schema_out, data, data_target, labels, \
        train_val_test_idx, dl = load_data_fn(args.dataset,
                                              use_edge_data=args.use_edge_data,
                                              use_node_attrs=args.use_node_attr,
                                              feats_type=args.feats_type)
    target_entity_id = 0  # True for all current NC datasets
    data, in_dims = select_features(data, schema, args.feats_type,
                                    target_entity_id)
    if args.multi_label:
        labels = torch.FloatTensor(labels).to(device)
    else:
        labels = torch.LongTensor(labels).to(device)
    train_idx = np.sort(train_val_test_idx['train_idx'])
    val_idx = np.sort(train_val_test_idx['val_idx'])
    test_idx = np.sort(train_val_test_idx['test_idx'])

    data = data.to(device)

    data_embedding = SparseMatrixData.make_entity_embeddings(
        schema.entities, args.embedding_dim)
    data_embedding.to(device)
    indices_identity, indices_transpose = data.calculate_indices()
    data_target = data_target.to(device)

    num_classes = dl.labels_train['num_classes']

    net = AlternatingHGN(schema,
                         in_dims,
                         width=args.width,
                         depth=args.depth,
                         embedding_dim=args.embedding_dim,
                         activation=getattr(nn, args.act_fn)(),
                         final_activation=nn.Identity(),
                         dropout=args.dropout,
                         output_dim=num_classes,
                         norm=args.norm,
                         pool_op=args.pool_op,
                         norm_affine=args.norm_affine,
                         norm_out=args.norm_out)

    net.to(device)
    optimizer = torch.optim.Adam(net.parameters(),
                                 lr=args.lr,
                                 weight_decay=args.weight_decay)

    if args.wandb_log_run:
        wandb.init(config=args,
                   settings=wandb.Settings(start_method='fork'),
                   project="EquivariantHGN_NC",
                   entity='danieltlevy')
        wandb.watch(net, log='all', log_freq=args.wandb_log_param_freq)
    print(args)
    print("Number of parameters: {}".format(count_parameters(net)))

    run_name = args.dataset + '_' + str(args.run)
    if args.wandb_log_run and wandb.run.name is not None:
        run_name = run_name + '_' + str(wandb.run.name)

    if args.checkpoint_path != '':
        checkpoint_path = args.checkpoint_path
    else:
        checkpoint_path = f"checkpoint/checkpoint_{run_name}.pt"

    print("Checkpoint Path: " + checkpoint_path)
    progress = tqdm(range(args.epoch), desc="Epoch 0", position=0, leave=True)
    # training loop
    val_micro_best = 0
    for epoch in progress:
        # training
        net.train()
        optimizer.zero_grad()
        logits = net(data, data_embedding).squeeze()
        logp = regr_fcn(logits, args.multi_label)
        train_loss = loss_fcn(logp[train_idx], labels[train_idx],
                              args.multi_label)
        train_loss.backward()
        optimizer.step()
        if args.multi_label:
            train_micro, train_macro = f1_scores_multi(
                logits[train_idx], dl.labels_train['data'][train_idx])
        else:
            train_micro, train_macro = f1_scores(logits[train_idx],
                                                 labels[train_idx])
        with torch.no_grad():
            progress.set_description(f"Epoch {epoch}")
            progress.set_postfix(loss=train_loss.item(), micr=train_micro)
            wandb_log = {
                'Train Loss': train_loss.item(),
                'Train Micro': train_micro,
                'Train Macro': train_macro
            }
            if epoch % args.val_every == 0:
                # validation
                net.eval()
                logits = net(data, data_embedding).squeeze()
                logp = regr_fcn(logits, args.multi_label)
                val_loss = loss_fcn(logp[val_idx], labels[val_idx],
                                    args.multi_label)
                if args.multi_label:
                    val_micro, val_macro = f1_scores_multi(
                        logits[val_idx], dl.labels_train['data'][val_idx])
                else:
                    val_micro, val_macro = f1_scores(logits[val_idx],
                                                     labels[val_idx])
                print("\nVal Loss: {:.3f} Val Micro-F1: {:.3f} \
Val Macro-F1: {:.3f}".format(val_loss, val_micro, val_macro))
                wandb_log.update({
                    'Val Loss': val_loss.item(),
                    'Val Micro-F1': val_micro,
                    'Val Macro-F1': val_macro
                })
                if val_micro > val_micro_best:
                    val_micro_best = val_micro
                    print("New best, saving")
                    torch.save(
                        {
                            'epoch': epoch,
                            'net_state_dict': net.state_dict(),
                            'optimizer_state_dict': optimizer.state_dict(),
                            'train_loss': train_loss.item(),
                            'train_micro': train_micro,
                            'train_macro': train_macro,
                            'val_loss': val_loss.item(),
                            'val_micro': val_micro,
                            'val_macro': val_macro
                        }, checkpoint_path)
                    if args.wandb_log_run:
                        wandb.run.summary["val_micro_best"] = val_micro
                        wandb.run.summary["val_macro_best"] = val_macro
                        wandb.run.summary["val_loss_best"] = val_loss.item()
                        wandb.run.summary["epoch_best"] = epoch
                        wandb.run.summary["train_loss_best"] = train_loss.item(
                        )
                        wandb.run.summary['train_micro_best'] = train_micro
                        wandb.run.summary['train_macro_best'] = train_macro
                        wandb.save(checkpoint_path)

            if args.wandb_log_run and epoch % args.wandb_log_loss_freq == 0:
                wandb.log(wandb_log, step=epoch)

    # testing with evaluate_results_nc
    if args.evaluate:
        checkpoint = torch.load(checkpoint_path, map_location=device)
        net.load_state_dict(checkpoint['net_state_dict'])
        net.eval()
        with torch.no_grad():
            logits = net(data, data_embedding).squeeze()
            test_logits = logits[test_idx]
            if args.multi_label:
                pred = (test_logits.cpu().numpy() > 0).astype(int)
            else:
                pred = test_logits.cpu().numpy().argmax(axis=1)
                onehot = np.eye(num_classes, dtype=np.int32)

            file_path = f"test_out/{run_name}.txt"
            dl.gen_file_for_evaluate(test_idx=test_idx,
                                     label=pred,
                                     file_path=file_path,
                                     multi_label=args.multi_label)
            if not args.multi_label:
                pred = onehot[pred]
            print(dl.evaluate(pred))
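
The helpers regr_fcn and loss_fcn are not defined in this listing. A plausible
sketch, consistent with how they are called above (FloatTensor targets trained
with binary cross-entropy on sigmoid outputs in the multi-label case,
LongTensor targets trained with NLL on log-softmax outputs otherwise); the
actual repo versions may differ:

import torch
import torch.nn.functional as F

def regr_fcn(logits, multi_label):
    # Independent per-class probabilities for multi-label targets,
    # log-probabilities for single-label classification
    if multi_label:
        return torch.sigmoid(logits)
    return F.log_softmax(logits, dim=1)

def loss_fcn(logp, labels, multi_label):
    if multi_label:
        return F.binary_cross_entropy(logp, labels)
    return F.nll_loss(logp, labels)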