def __init__(self, entity_counts, sparsity=0.5, embedding_dims=2, tucker=False, batch_dim=True):
    """Build a synthetic student/course/professor dataset.

    Args:
        entity_counts: indexable whose items 0, 1 and 2 are the numbers of
            students, courses and professors.
        sparsity: forwarded to ``self.make_observed``.
        embedding_dims: forwarded to ``self.make_embeddings``.
        tucker: forwarded to ``self.make_data``.
        batch_dim: whether to include a batch dimension (only stored here).
    """
    self.n_student, self.n_course, self.n_professor = (
        entity_counts[0], entity_counts[1], entity_counts[2])
    self.sparsity = sparsity
    self.embedding_dims = embedding_dims
    self.tucker = tucker
    # Whether to include a batch dimension
    self.batch_dim = batch_dim

    students = Entity(0, self.n_student)
    courses = Entity(1, self.n_course)
    professors = Entity(2, self.n_professor)

    # Relation ids are the positions in this list.
    relation_members = [
        [students, courses],
        [students, professors],
        [professors, courses],
        [courses, courses],
        [courses, courses],
    ]
    relations = [Relation(rel_id, members, 1)
                 for rel_id, members in enumerate(relation_members)]
    self.schema = DataSchema([students, courses, professors], relations)

    # Fixed seed so the generated embeddings/data are reproducible.
    np.random.seed(0)
    self.embeddings = self.make_embeddings(self.embedding_dims)
    self.data = self.make_data(self.tucker)
    self.observed = self.make_observed(self.sparsity)
def get_node_classification_data(self):
    """Read target-node labels from LABEL_FILE_STR and set up the output schema.

    Each label-file line is tab-separated as
    ``node_id<TAB>node_name<TAB>node_type<TAB>node_label``; the raw node id is
    mapped to a local index through ``self.node_id_to_idx[node_type]``.

    Sets ``schema_out``, ``target_indices``, ``targets``, ``n_outputs`` and an
    empty ``data_target`` on ``self``.
    """
    target_entity = self.schema.entities[TARGET_NODE_TYPE]
    self.schema_out = DataSchema([target_entity], [
        Relation(0, [target_entity, target_entity], is_set=True)
    ])

    index_list, label_list = [], []
    with open(LABEL_FILE_STR, 'r') as label_file:
        for raw_line in label_file:
            raw_id, _unused_name, raw_type, raw_label = raw_line.rstrip().split('\t')
            local_idx = self.node_id_to_idx[int(raw_type)][int(raw_id)]
            label = int(raw_label)
            index_list.append(local_idx)
            label_list.append(label)

    self.target_indices = torch.LongTensor(index_list)
    self.targets = torch.LongTensor(label_list)
    self.n_outputs = target_entity.n_instances
    self.data_target = SparseMatrixData(self.schema_out)
def __init__(self, n_student, n_course, n_professor):
    """Build the schema for a toy student/course/professor dataset.

    Args:
        n_student: number of student instances.
        n_course: number of course instances.
        n_professor: number of professor instances.
    """
    self.n_student = n_student
    self.n_course = n_course
    self.n_professor = n_professor

    students = Entity(0, self.n_student)
    courses = Entity(1, self.n_course)
    professors = Entity(2, self.n_professor)

    # TODO: Fix student self-relation to have two channels
    # Relation ids are the positions in this list.
    relation_members = [
        [students, courses],     # Takes
        [students, professors],  # Reference
        [professors, courses],   # Teaches
        [courses, courses],      # Prereq
        [students],              # Student
        [courses],               # Course
        [professors],            # Professor
    ]
    relations = [Relation(rel_id, members, 1)
                 for rel_id, members in enumerate(relation_members)]

    self.schema = DataSchema([students, courses, professors], relations)
    # Embeddings are left unset by this constructor.
    self.embeddings = None
def __init__(self, schema, input_channels=1, activation=F.relu,
             layers=None, embedding_dim=50, dropout=0, pool_op='mean',
             norm_affine=False, norm_embed=False,
             final_activation=nn.Identity(), embedding_entities=None,
             output_rels=None, in_fc_layer=True, decode='dot', out_dim=1):
    """Equivariant encoder followed by a configurable link decoder.

    Args:
        schema: input DataSchema.
        input_channels: input channel spec forwarded to EquivEncoder (also
            passed to the broadcasting decoder when decode='broadcast').
        activation: nonlinearity forwarded to the encoder (and equiv decoder).
        layers: hidden channel sizes; defaults to [64, 64, 64]. A new list is
            built per call so instances never share a mutable default.
        embedding_dim: width of the entity embeddings.
        dropout, pool_op, norm_affine, embedding_entities, in_fc_layer:
            forwarded to the encoder (and, where applicable, the decoder).
        norm_embed: stored on the module; not used in this constructor.
        final_activation: module wrapped in an Activation over schema_out.
        output_rels: relations to predict; None means all of ``schema``.
        decode: one of 'dot', 'distmult', 'equiv', 'broadcast'.
        out_dim: output width of the equiv decoder.

    Raises:
        ValueError: if ``decode`` is not one of the supported names.
    """
    super(EquivLinkPredictor, self).__init__()
    if layers is None:
        layers = [64, 64, 64]

    self.output_rels = output_rels
    if output_rels is None:
        self.schema_out = schema
    else:
        self.schema_out = DataSchema(schema.entities, output_rels)
    self.out_dim = out_dim

    self.encoder = EquivEncoder(schema, input_channels, activation, layers,
                                embedding_dim, dropout, pool_op, norm_affine,
                                embedding_entities, in_fc_layer)
    self.norm_embed = norm_embed

    self.decode = decode
    if self.decode == 'dot':
        self.decoder = Dot()
    elif self.decode == 'distmult':
        self.decoder = DistMult(len(schema.relations), embedding_dim)
    elif self.decode == 'equiv':
        self.decoder = EquivDecoder(self.schema_out, activation, layers,
                                    embedding_dim, dropout, pool_op,
                                    norm_affine, embedding_entities,
                                    out_fc_layer=in_fc_layer,
                                    out_dim=self.out_dim)
    elif self.decode == 'broadcast':
        self.decoder = SparseMatrixEntityBroadcastingLayer(
            self.schema_out, embedding_dim, input_channels,
            entities=embedding_entities, pool_op=pool_op)
    else:
        # Previously an unknown name left self.decoder unset, failing only
        # later with an AttributeError far from the cause.
        raise ValueError("Unknown decoder: {!r}".format(decode))

    self.final_activation = Activation(self.schema_out, final_activation,
                                       is_sparse=True)
def __init__(self, schema, dims):
    """Store the schema and build a per-entity "encodings" schema.

    Args:
        schema: DataSchema whose entities get one encoding relation each.
        dims: stored as ``self.dims``.
    """
    super(EntityPooling, self).__init__()
    self.schema = schema
    self.dims = dims

    entity_list = self.schema.entities
    self.out_shape = [ent.n_instances for ent in entity_list]

    # "Schema" for the encodings: one single-entity relation per entity,
    # keyed by the entity's position in the schema.
    self.enc_schema = DataSchema(
        entity_list,
        {pos: Relation(pos, [ent]) for pos, ent in enumerate(entity_list)})
def __init__(self, schema, input_channels=1, activation=F.relu,
             layers=None, embedding_dim=50, dropout=0, norm=True,
             pool_op='mean', norm_affine=False,
             final_activation=nn.Identity(), embedding_entities=None,
             output_relations=None):
    """Equivariant sparse-matrix autoencoder.

    Stacks SparseMatrixEquivariantLayers, pools to entity embeddings and
    broadcasts the embeddings back onto ``schema_out``.

    Args:
        schema: input DataSchema.
        input_channels: channels of the input relations; also the number of
            channels produced by the broadcasting layer.
        activation: nonlinearity applied between equivariant layers.
        layers: hidden channel sizes; defaults to [64, 64, 64] (a new list
            per call, so instances never share a mutable default).
        embedding_dim: width of the pooled entity embeddings.
        dropout: dropout probability.
        norm: if True, use per-relation BatchNorm1d after each layer,
            otherwise identity.
        pool_op: pooling operation for equivariant/pooling layers.
        norm_affine: ``affine`` flag for the BatchNorm1d layers.
        final_activation: module applied to the decoded output.
        embedding_entities: entities to build embeddings for (None = all).
        output_relations: relations to reconstruct; None means ``schema``.
    """
    super(SparseMatrixAutoEncoder, self).__init__()
    if layers is None:
        layers = [64, 64, 64]
    self.schema = schema
    if output_relations is None:
        self.schema_out = schema
    else:
        self.schema_out = DataSchema(schema.entities, output_relations)
    self.input_channels = input_channels

    self.activation = activation
    self.rel_activation = Activation(schema, self.activation, is_sparse=True)

    self.dropout = Dropout(p=dropout)
    self.rel_dropout = Activation(schema, self.dropout, is_sparse=True)

    self.n_equiv_layers = len(layers)
    self.equiv_layers = nn.ModuleList([])
    self.equiv_layers.append(SparseMatrixEquivariantLayer(
        schema, input_channels, layers[0], pool_op=pool_op))
    self.equiv_layers.extend([
        SparseMatrixEquivariantLayer(schema, layers[i-1], layers[i],
                                     pool_op=pool_op)
        for i in range(1, len(layers))])

    if norm:
        self.norms = nn.ModuleList()
        for channels in layers:
            # One BatchNorm per relation; ModuleDict keys must be strings.
            norm_dict = nn.ModuleDict()
            for rel_id in self.schema.relations:
                norm_dict[str(rel_id)] = nn.BatchNorm1d(
                    channels, affine=norm_affine, track_running_stats=False)
            norm_activation = Activation(schema, norm_dict, is_dict=True,
                                         is_sparse=True)
            self.norms.append(norm_activation)
    else:
        self.norms = nn.ModuleList([
            Activation(schema, nn.Identity(), is_sparse=True)
            for _ in layers])

    # Entity embeddings
    self.pooling = SparseMatrixEntityPoolingLayer(
        schema, layers[-1], embedding_dim, entities=embedding_entities,
        pool_op=pool_op)
    self.broadcasting = SparseMatrixEntityBroadcastingLayer(
        self.schema_out, embedding_dim, input_channels,
        entities=embedding_entities, pool_op=pool_op)

    # The decoded output lives on schema_out (the broadcasting target), so
    # the final activation is built over schema_out, matching
    # EquivLinkPredictor. Identical to before when output_relations is None.
    self.final_activation = Activation(self.schema_out, final_activation,
                                       is_sparse=True)
def __init__(self, schema, input_dim=1, output_dim=1, entities=None, pool_op='mean'):
    '''
    Map data on `schema` to a per-entity encodings schema.

    schema: schema of the input data
    input_dim: either a rel_id: dimension dict, or an integer for all relations
    output_dim: either a rel_id: dimension dict, or an integer for all relations
    entities: entities to produce encodings for; defaults to schema.entities
    pool_op: pooling operation, forwarded to the parent constructor
    '''
    # `is None` rather than `== None`: identity test, safe even if a
    # container with elementwise `==` is passed.
    if entities is None:
        entities = schema.entities
    # Each encoding is a set (diagonal) relation over its entity, keyed by id.
    enc_relations = {
        entity.id: Relation(entity.id, [entity, entity], is_set=True)
        for entity in entities
    }
    encodings_schema = DataSchema(entities, enc_relations)
    super().__init__(schema, input_dim, output_dim,
                     schema_out=encodings_schema, pool_op=pool_op)
def __init__(self, entity_counts, sparsity=0.5, n_channels=1):
    """Build a synthetic student/course/professor dataset with a ternary relation.

    Args:
        entity_counts: indexable whose items 0, 1 and 2 are the numbers of
            students, courses and professors.
        sparsity: upper estimate of sparsity, forwarded to make_observed.
        n_channels: forwarded to make_observed.
    """
    self.n_student, self.n_course, self.n_professor = (
        entity_counts[0], entity_counts[1], entity_counts[2])
    # Upper estimate of sparsity
    self.sparsity = sparsity

    students = Entity(0, self.n_student)
    courses = Entity(1, self.n_course)
    professors = Entity(2, self.n_professor)

    # Relation ids are the positions in this list.
    relation_members = [
        [students, courses],
        [students, professors],
        [professors, courses],
        [students, professors, courses],
        [courses, courses],
        [students],
    ]
    relations = [Relation(rel_id, members, 1)
                 for rel_id, members in enumerate(relation_members)]

    self.schema = DataSchema([students, courses, professors], relations)
    self.observed = self.make_observed(self.sparsity, n_channels)
def __init__(self, schema, input_dim=1, output_dim=1, entities=None, pool_op='mean'):
    '''
    Map per-entity encodings back onto `schema`.

    schema: schema to broadcast to
    input_dim: either a rel_id: dimension dict, or an integer for all relations
    output_dim: either a rel_id: dimension dict, or an integer for all relations
    entities: if specified, these are the input entities for the encodings;
        defaults to schema.entities
    pool_op: pooling operation, forwarded to the parent constructor
    '''
    # `is None` rather than `== None`: identity test, safe even if a
    # container with elementwise `==` is passed.
    if entities is None:
        entities = schema.entities
    # Encodings are set (diagonal) relations over each entity, keyed by id.
    enc_relations = {
        entity.id: Relation(entity.id, [entity, entity], is_set=True)
        for entity in entities
    }
    encodings_schema = DataSchema(entities, enc_relations)
    super().__init__(encodings_schema, input_dim, output_dim,
                     schema_out=schema, pool_op=pool_op)
def __init__(self):
    """Load person/course relational data from CSV and build dense inputs.

    Reads one CSV per name in the module-level ``relation_names`` (path from
    ``csv_file_str``; columns from ``schema_dict``). Builds:
      * ``self.schema``: input schema (person, course, taughtBy),
      * ``self.schema_out``: output schema with a single person-person
        relation (id 2),
      * ``self.data``: dense tensors for the input relations,
      * ``self.target``: dense advisedBy matrix (relation id 2), squeezed.
    The advisedBy relation is kept out of ``self.schema`` and used only as
    the prediction target.
    """
    self.target_relation = 'advisedBy'
    # data_raw[relation_name][column_name] -> list of raw string values
    data_raw = {
        rel_name: {key: list() for key in schema_dict[rel_name].keys()}
        for rel_name in schema_dict.keys()
    }

    for relation_name in relation_names:
        with open(csv_file_str.format(relation_name)) as file:
            reader = csv.reader(file)
            keys = schema_dict[relation_name].keys()
            for cols in reader:
                # Columns are matched positionally to schema_dict's keys.
                for key, col in zip(keys, cols):
                    data_raw[relation_name][key].append(col)

    # Entity sizes are the number of rows in the person/course tables.
    ent_person = Entity(0, len(data_raw['person']['p_id']))
    ent_course = Entity(1, len(data_raw['course']['course_id']))
    entities = [ent_person, ent_course]

    # "_matrix" variants are square set relations used to build the sparse
    # matrices; the plain variants (same ids) form the input schema.
    rel_person_matrix = Relation(0, [ent_person, ent_person], is_set=True)
    rel_person = Relation(0, [ent_person])
    rel_course_matrix = Relation(1, [ent_course, ent_course], is_set=True)
    rel_course = Relation(1, [ent_course])
    rel_advisedBy = Relation(2, [ent_person, ent_person])
    rel_taughtBy = Relation(3, [ent_course, ent_person])
    relations_matrix = [
        rel_person_matrix, rel_course_matrix, rel_advisedBy, rel_taughtBy
    ]
    # advisedBy (id 2) is deliberately excluded from the input relations.
    relations = [rel_person, rel_course, rel_taughtBy]
    self.target_rel_id = 2
    self.schema = DataSchema(entities, relations)
    schema_matrix = DataSchema(entities, relations_matrix)
    matrix_data = SparseMatrixData(schema_matrix)

    ent_id_to_idx_dict = {
        'person': self.id_to_idx(data_raw['person']['p_id']),
        'course': self.id_to_idx(data_raw['course']['course_id'])
    }

    for relation in relations_matrix:
        # NOTE(review): assumes relation_names is ordered by relation id
        # (person, course, advisedBy, taughtBy) — confirm against its definition.
        relation_name = relation_names[relation.id]
        print(relation_name)
        if relation.is_set:
            data_matrix = self.set_relation_to_matrix(
                relation, schema_dict[relation_name],
                data_raw[relation_name])
        else:
            # Column names of the row/column entity ids for this relation.
            # Only advisedBy and taughtBy reach this branch here.
            if relation_name == 'advisedBy':
                ent_n_id_str = 'p_id'
                ent_m_id_str = 'p_id_dummy'
            elif relation_name == 'taughtBy':
                ent_n_id_str = 'course_id'
                ent_m_id_str = 'p_id'
            data_matrix = self.binary_relation_to_matrix(
                relation, schema_dict[relation_name],
                data_raw[relation_name], ent_id_to_idx_dict,
                ent_n_id_str, ent_m_id_str)
        matrix_data[relation.id] = data_matrix

    rel_out = Relation(2, [ent_person, ent_person])
    self.schema_out = DataSchema([ent_person], [rel_out])
    self.output_dim = 1

    # Convert each input relation (matched by id) to a dense tensor with a
    # leading dimension of size 1.
    data = Data(self.schema)
    for rel_matrix in schema_matrix.relations:
        for rel in self.schema.relations:
            if rel_matrix.id == rel.id:
                data_matrix = matrix_data[rel_matrix.id]
                if rel_matrix.is_set:
                    # Set relations: keep only the diagonal over dims 1 and 2
                    # (assumes the dense layout is (channels, n, n) — TODO confirm).
                    dense_data = torch.diagonal(data_matrix.to_dense(), 0, 1, 2).unsqueeze(0)
                else:
                    dense_data = data_matrix.to_dense().unsqueeze(0)
                data[rel.id] = dense_data
    self.data = data
    self.target = matrix_data[self.target_rel_id].to_dense().squeeze()
def run_model(args):
    """Train, validate, checkpoint and optionally test an EquivLinkPredictor.

    Loads an HGB link-prediction dataset, trains on the test edge types
    (``dl.links_test['data']``) with freshly sampled negatives each epoch,
    evaluates on a validation split every ``args.val_every`` epochs, saves
    the best checkpoint according to ``args.val_metric`` ('loss', 'roc_auc'
    or 'mrr') and, if ``args.evaluate``, writes test predictions per target
    relation to ``test_out/<run_name>.txt``.

    Args:
        args: parsed hyperparameter namespace.

    Raises:
        ValueError: if ``args.val_metric`` is not 'loss', 'roc_auc' or 'mrr'.
    """
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    use_equiv = args.decoder == 'equiv'

    # Collect data and schema
    schema, data_original, dl = load_data(args.dataset,
                                          use_edge_data=args.use_edge_data,
                                          use_other_edges=args.use_other_edges,
                                          use_node_attrs=args.use_node_attrs,
                                          node_val=args.node_val)
    data, in_dims = select_features(data_original, schema, args.feats_type)
    data = data.to(device)

    # Precompute data indices
    indices_identity, indices_transpose = data.calculate_indices()

    # Get target relations and create data structure for embeddings
    target_rel_ids = dl.links_test['data'].keys()
    target_rels = [schema.relations[rel_id] for rel_id in target_rel_ids]
    target_ents = schema.entities

    # Get relations used by decoder
    if use_equiv:
        output_rels = schema.relations
    else:
        output_rels = {rel.id: rel for rel in target_rels}
    data_embedding = SparseMatrixData.make_entity_embeddings(
        target_ents, args.embedding_dim)
    # Keep the returned object (as done for `data` above) so the embeddings
    # are on `device` whether or not .to() works in place.
    data_embedding = data_embedding.to(device)

    # Get training and validation positive samples now
    train_pos_heads, train_pos_tails = dict(), dict()
    val_pos_heads, val_pos_tails = dict(), dict()
    for target_rel_id in target_rel_ids:
        train_val_pos = get_train_valid_pos(dl, target_rel_id)
        train_pos_heads[target_rel_id], train_pos_tails[target_rel_id], \
            val_pos_heads[target_rel_id], val_pos_tails[target_rel_id] = train_val_pos

    # Get additional indices to be used when making predictions
    pred_idx_matrices = {}
    for target_rel in target_rels:
        if args.pred_indices == 'train':
            train_neg_head, train_neg_tail = get_train_neg(
                dl, target_rel.id, tail_weighted=args.tail_weighted)
            pred_idx_matrices[target_rel.id] = make_target_matrix(
                target_rel, train_pos_heads[target_rel.id],
                train_pos_tails[target_rel.id], train_neg_head,
                train_neg_tail, device)
        elif args.pred_indices == 'train_neg':
            # Get negative samples twice
            train_neg_head1, train_neg_tail1 = get_train_neg(
                dl, target_rel.id, tail_weighted=args.tail_weighted)
            train_neg_head2, train_neg_tail2 = get_train_neg(
                dl, target_rel.id, tail_weighted=args.tail_weighted)
            pred_idx_matrices[target_rel.id] = make_target_matrix(
                target_rel, train_neg_head1, train_neg_tail1,
                train_neg_head2, train_neg_tail2, device)
        elif args.pred_indices == 'none':
            pred_idx_matrices[target_rel.id] = None

    # Create network and optimizer
    net = EquivLinkPredictor(schema, in_dims,
                             layers=args.layers,
                             embedding_dim=args.embedding_dim,
                             embedding_entities=target_ents,
                             output_rels=output_rels,
                             # Look the activation class up by name instead
                             # of eval()-ing a string built from CLI input.
                             activation=getattr(nn, args.act_fn)(),
                             final_activation=nn.Identity(),
                             dropout=args.dropout,
                             pool_op=args.pool_op,
                             norm_affine=args.norm_affine,
                             norm_embed=args.norm_embed,
                             in_fc_layer=args.in_fc_layer,
                             decode=args.decoder)
    net.to(device)
    optimizer = torch.optim.Adam(net.parameters(), lr=args.lr,
                                 weight_decay=args.weight_decay)

    # Set up logging and checkpointing
    if args.wandb_log_run:
        wandb.init(config=args,
                   settings=wandb.Settings(start_method='fork'),
                   project="EquivariantHGN_LP",
                   entity='danieltlevy')
        wandb.watch(net, log='all', log_freq=args.wandb_log_param_freq)
    print(args)
    print("Number of parameters: {}".format(count_parameters(net)))
    run_name = args.dataset + '_' + str(args.run)
    if args.wandb_log_run and wandb.run.name is not None:
        run_name = run_name + '_' + str(wandb.run.name)
    if args.checkpoint_path != '':
        checkpoint_path = args.checkpoint_path
    else:
        checkpoint_path = f"checkpoint/checkpoint_{run_name}.pt"
    print("Checkpoint Path: " + checkpoint_path)
    val_metric_best = -1e10

    # Most recent training target matrix for EACH target relation. Used at
    # test time as extra prediction indices for the equivariant decoder
    # (previously only the last relation's matrix survived the loop and was
    # combined with every relation's test matrix).
    train_matrices = {}

    # training
    loss_func = nn.BCELoss()
    progress = tqdm(range(args.epoch), desc="Epoch 0", position=0, leave=True)
    for epoch in progress:
        net.train()
        # Make target matrix and labels to train on
        if use_equiv:
            # Target is same as input
            data_target = data.clone()
        else:
            # Target is just target relation
            data_target = SparseMatrixData(
                DataSchema(schema.entities, target_rels))
        labels_train = torch.Tensor([]).to(device)
        for target_rel in target_rels:
            train_neg_head, train_neg_tail = get_train_neg(
                dl, target_rel.id, tail_weighted=args.tail_weighted)
            train_matrix = make_target_matrix(target_rel,
                                              train_pos_heads[target_rel.id],
                                              train_pos_tails[target_rel.id],
                                              train_neg_head, train_neg_tail,
                                              device)
            train_matrices[target_rel.id] = train_matrix
            data_target[target_rel.id] = train_matrix
            labels_train_rel = train_matrix.values.squeeze()
            labels_train = torch.cat([labels_train, labels_train_rel])

        # Make prediction
        if use_equiv:
            idx_id_tgt, idx_trans_tgt = data_target.calculate_indices()
            output_data = net(data, indices_identity, indices_transpose,
                              data_embedding, data_target, idx_id_tgt,
                              idx_trans_tgt)
        else:
            output_data = net(data, indices_identity, indices_transpose,
                              data_embedding, data_target)
        logits_combined = torch.Tensor([]).to(device)
        for target_rel in target_rels:
            logits_rel = output_data[target_rel.id].values.squeeze()
            logits_combined = torch.cat([logits_combined, logits_rel])

        logp = torch.sigmoid(logits_combined)
        train_loss = loss_func(logp, labels_train)

        # autograd
        optimizer.zero_grad()
        train_loss.backward()
        optimizer.step()

        # Update logging
        progress.set_description(f"Epoch {epoch}")
        progress.set_postfix(loss=train_loss.item())
        wandb_log = {'Train Loss': train_loss.item(), 'epoch': epoch}

        # Evaluate on validation set
        if epoch % args.val_every == 0:
            with torch.no_grad():
                net.eval()
                left = torch.Tensor([]).to(device)
                right = torch.Tensor([]).to(device)
                labels_val = torch.Tensor([]).to(device)
                valid_masks = {}
                for target_rel in target_rels:
                    if args.val_neg == '2hop':
                        valid_neg_head, valid_neg_tail = get_valid_neg_2hop(
                            dl, target_rel.id)
                    elif args.val_neg == 'randomtw':
                        valid_neg_head, valid_neg_tail = get_valid_neg(
                            dl, target_rel.id, tail_weighted=True)
                    else:
                        valid_neg_head, valid_neg_tail = get_valid_neg(
                            dl, target_rel.id)
                    valid_matrix_full = make_target_matrix(
                        target_rel, val_pos_heads[target_rel.id],
                        val_pos_tails[target_rel.id], valid_neg_head,
                        valid_neg_tail, device)
                    valid_matrix, left_rel, right_rel, labels_val_rel = \
                        coalesce_matrix(valid_matrix_full)
                    left = torch.cat([left, left_rel])
                    right = torch.cat([right, right_rel])
                    labels_val = torch.cat([labels_val, labels_val_rel])
                    if use_equiv:
                        # Add in additional prediction indices
                        pred_idx_matrix = pred_idx_matrices[target_rel.id]
                        if pred_idx_matrix is None:
                            valid_combined_matrix = valid_matrix
                            valid_mask = torch.arange(
                                valid_matrix.nnz()).to(device)
                        else:
                            valid_combined_matrix, valid_mask = \
                                combine_matrices(valid_matrix, pred_idx_matrix)
                        valid_masks[target_rel.id] = valid_mask
                        data_target[target_rel.id] = valid_combined_matrix
                    else:
                        data_target[target_rel.id] = valid_matrix

                if use_equiv:
                    data_target.zero_()
                    idx_id_val, idx_trans_val = data_target.calculate_indices()
                    output_data = net(data, indices_identity,
                                      indices_transpose, data_embedding,
                                      data_target, idx_id_val, idx_trans_val)
                else:
                    output_data = net(data, indices_identity,
                                      indices_transpose, data_embedding,
                                      data_target)
                logits_combined = torch.Tensor([]).to(device)
                for target_rel in target_rels:
                    logits_rel_full = output_data[
                        target_rel.id].values.squeeze()
                    if use_equiv:
                        # Keep only the validation entries, dropping the
                        # extra prediction indices.
                        logits_rel = logits_rel_full[valid_masks[
                            target_rel.id]]
                    else:
                        logits_rel = logits_rel_full
                    logits_combined = torch.cat([logits_combined, logits_rel])

                logp = torch.sigmoid(logits_combined)
                val_loss = loss_func(logp, labels_val).item()

                wandb_log.update({'val_loss': val_loss})
                left = left.cpu().numpy()
                right = right.cpu().numpy()
                edge_list = np.concatenate(
                    [left.reshape((1, -1)), right.reshape((1, -1))], axis=0)
                res = dl.evaluate(edge_list, logp.cpu().numpy(),
                                  labels_val.cpu().numpy())
                val_roc_auc = res['roc_auc']
                val_mrr = res['MRR']
                wandb_log.update(res)
                print("\nVal Loss: {:.3f} Val ROC AUC: {:.3f} Val MRR: {:.3f}".
                      format(val_loss, val_roc_auc, val_mrr))
                if args.val_metric == 'loss':
                    val_metric = -val_loss
                elif args.val_metric == 'roc_auc':
                    val_metric = val_roc_auc
                elif args.val_metric == 'mrr':
                    val_metric = val_mrr
                else:
                    # Previously fell through to a NameError on val_metric.
                    raise ValueError(
                        "Unknown val_metric: {!r}".format(args.val_metric))

                if val_metric > val_metric_best:
                    val_metric_best = val_metric
                    print("New best, saving")
                    torch.save(
                        {
                            'epoch': epoch,
                            'net_state_dict': net.state_dict(),
                            'optimizer_state_dict': optimizer.state_dict(),
                            'train_loss': train_loss.item(),
                            'val_loss': val_loss,
                            'val_roc_auc': val_roc_auc,
                            'val_mrr': val_mrr
                        }, checkpoint_path)
                    if args.wandb_log_run:
                        wandb.summary["val_roc_auc_best"] = val_roc_auc
                        wandb.summary["val_mrr_best"] = val_mrr
                        wandb.summary["val_loss_best"] = val_loss
                        wandb.summary["epoch_best"] = epoch
                        wandb.summary["train_loss_best"] = train_loss.item()
                        wandb.save(checkpoint_path)
        if args.wandb_log_run:
            wandb.log(wandb_log)

    # Evaluate on test set
    if args.evaluate:
        # Restore the best checkpoint once; the weights are the same for
        # every target relation.
        checkpoint = torch.load(checkpoint_path, map_location=device)
        net.load_state_dict(checkpoint['net_state_dict'])
        net.eval()
        for target_rel in target_rels:
            print("Evaluating Target Rel " + str(target_rel.id))
            # Target is same as input
            data_target = data.clone()
            with torch.no_grad():
                left_full, right_full, test_labels_full = \
                    get_test_neigh_from_file(dl, args.dataset, target_rel.id)
                test_matrix_full = make_target_matrix_test(
                    target_rel, left_full, right_full, test_labels_full,
                    device)
                test_matrix, left, right, test_labels = coalesce_matrix(
                    test_matrix_full)
                if use_equiv:
                    # Pad with THIS relation's training indices.
                    test_combined_matrix, test_mask = combine_matrices(
                        test_matrix, train_matrices[target_rel.id])
                    data_target[target_rel.id] = test_combined_matrix
                    data_target.zero_()
                    idx_id_tst, idx_trans_tst = data_target.calculate_indices()
                    data_out = net(data, indices_identity, indices_transpose,
                                   data_embedding, data_target, idx_id_tst,
                                   idx_trans_tst)
                    logits_full = data_out[target_rel.id].values.squeeze()
                    logits = logits_full[test_mask]
                else:
                    data_target[target_rel.id] = test_matrix
                    data_out = net(data, indices_identity, indices_transpose,
                                   data_embedding, data_target)
                    logits_full = data_out[target_rel.id].values.squeeze()
                    logits = logits_full

                pred = torch.sigmoid(logits).cpu().numpy()
                left = left.cpu().numpy()
                right = right.cpu().numpy()
                edge_list = np.vstack((left, right))
                edge_list_full = np.vstack((left_full, right_full))
                file_path = f"test_out/{run_name}.txt"
                gen_file_for_evaluate(dl, edge_list_full, edge_list, pred,
                                      target_rel.id, file_path=file_path)
# Script section: parse hyperparameters, load PubMed and prepare the inputs.
argv = sys.argv[1:]
args = get_hyperparams(argv)
print(args)
set_seed(args.seed)

dataloader = PubMedData(args.node_labels)
schema = dataloader.schema
# NOTE(review): `device` is not defined in this chunk; presumably it is set
# earlier in the file — confirm.
data = dataloader.data.to(device)
indices_identity, indices_transpose = data.calculate_indices()
embedding_entity = schema.entities[TARGET_NODE_TYPE]
# Number of input channels per relation id.
input_channels = {
    rel.id: data[rel.id].n_channels for rel in schema.relations
}
# NOTE(review): DataSchema receives a single Relation here, whereas other
# call sites pass a list or dict of relations; confirm DataSchema accepts it.
embedding_schema = DataSchema(
    schema.entities,
    Relation(0, [embedding_entity, embedding_entity], is_set=True))
n_instances = embedding_entity.n_instances
# Zero-initialised set (diagonal) matrix: index i is paired with itself, with
# args.embedding_dim channels per instance of the embedding entity.
data_embedding = SparseMatrixData(embedding_schema)
data_embedding[0] = SparseMatrix(
    indices=torch.arange(n_instances, dtype=torch.int64).repeat(2, 1),
    values=torch.zeros([n_instances, args.embedding_dim]),
    shape=(n_instances, n_instances, args.embedding_dim),
    is_set=True)
# NOTE(review): the result of .to() is discarded here, while `data` above
# keeps it; if .to() is not in-place the embeddings stay where they were.
data_embedding.to(device)
# NOTE(review): single Relation passed to DataSchema, as above.
target_schema = DataSchema(schema.entities, schema.relations[TARGET_REL_ID])
target_node_idx_to_id = dataloader.target_node_idx_to_id
#%%
net = SparseMatrixAutoEncoder(schema, input_channels,
return torch.Tensor(features.todense())
# NOTE(review): the `return` above ends a function whose `def` lies outside
# this chunk; the statements below are top-level script code.

# Entity sizes: movies from the rows of the movie feature matrix; actors,
# directors and keywords from the columns of the movie-X matrices.
ent_movie = Entity(0, raw_data['movie_feature'].shape[0])
ent_actor = Entity(1, raw_data['movie_actor'].shape[1])
ent_director = Entity(2, raw_data['movie_director'].shape[1])
ent_keyword = Entity(3, raw_data['movie_keyword'].shape[1])
entities = [ent_movie, ent_actor, ent_director, ent_keyword]
relations = []
rel_movie_actor = Relation(0, [ent_movie, ent_actor])
rel_movie_director = Relation(1, [ent_movie, ent_director])
rel_movie_keyword = Relation(2, [ent_movie, ent_keyword])
# Movie features are attached as a set (diagonal) relation on movies.
rel_movie_feature = Relation(3, [ent_movie, ent_movie], is_set=True)
relations = [rel_movie_actor, rel_movie_director, rel_movie_keyword, rel_movie_feature]
schema = DataSchema(entities, relations)
# Output schema: a single set relation over movies.
schema_out = DataSchema([ent_movie], [Relation(0, [ent_movie, ent_movie], is_set=True)])
data = SparseMatrixData(schema)
# NOTE(review): data[rel_i] relies on relation_names listing the relations in
# the same order as the relation ids above (actor, director, keyword,
# feature) — confirm against relation_names.
for rel_i, rel_name in enumerate(relation_names):
    if rel_name == 'movie_feature':
        values = preprocess_features(raw_data[rel_name])
        data[rel_i] = SparseMatrix.from_embed_diag(values)
    else:
        data[rel_i] = SparseMatrix.from_scipy_sparse(raw_data[rel_name])
data = data.to(device)
indices_identity, indices_transpose = data.calculate_indices()
input_channels = {rel.id: data[rel.id].n_channels for rel in relations}
# NOTE(review): the target uses the `Data` container while inputs use
# SparseMatrixData; confirm this is intended.
data_target = Data(schema_out)
n_movies = ent_movie.n_instances
labels = []
def load_data_flat(prefix, use_node_attrs=True, use_edge_data=True, node_val='one'):
    '''
    Load data into one matrix with all relations, reproducing Maron 2019.

    All node types share a single entity; every channel uses the sparsity
    pattern of (union of all edges + self-loops). The first
    [# relation types] channels are the adjacency matrices (in sorted
    relation-id order). If use_node_attrs=True, the next
    [sum of feature dimensions per entity type] channels hold node
    attributes on the relevant segment of the diagonal. If an entity type
    has no node features, node_val is used instead ('zero' -> 0,
    'rand' -> standard normal, anything else -> 1).

    Returns (schema, data, dl).
    '''
    dl = data_loader(DATA_FILE_DIR + prefix)
    n_total = dl.nodes['total']
    entities = [Entity(0, n_total)]
    relations = {0: Relation(0, [entities[0], entities[0]])}
    schema = DataSchema(entities, relations)

    # Shared sparsity pattern: every edge of every type plus the diagonal.
    pattern = sum(dl.links['data'].values()).tocoo()
    self_loops = scipy.sparse.coo_matrix(
        (np.ones(n_total), (np.arange(n_total), np.arange(n_total))),
        (n_total, n_total))
    pattern += self_loops
    template = SparseMatrix.from_scipy_sparse(pattern.tocoo()).zero_()
    data_out = SparseMatrix.from_other_sparse_matrix(template, 0)

    def _append_channels(block, n_channels):
        # Align `block` to the shared pattern and append its channels.
        aligned = SparseMatrix.from_other_sparse_matrix(template, n_channels) + block
        data_out.values = torch.cat([data_out.values, aligned.values], 1)
        data_out.n_channels += n_channels

    # Load up all edge data
    for rel_id in sorted(dl.links['data']):
        rel_sparse = SparseMatrix.from_scipy_sparse(dl.links['data'][rel_id].tocoo())
        if not use_edge_data:
            # Use only adjacency information
            rel_sparse.values = torch.ones(rel_sparse.values.shape)
        _append_channels(rel_sparse, 1)

    if use_node_attrs:
        for ent_id, attr_matrix in dl.nodes['attr'].items():
            offset = dl.nodes['shift'][ent_id]
            n_instances = dl.nodes['count'][ent_id]
            if attr_matrix is None:
                if node_val == 'zero':
                    attr_matrix = np.zeros((n_instances, 1))
                elif node_val == 'rand':
                    attr_matrix = np.random.randn(n_instances, 1)
                else:
                    attr_matrix = np.ones((n_instances, 1))
            n_channels = attr_matrix.shape[1]
            # Diagonal positions of this entity type's segment.
            diag_idx = torch.arange(offset, offset + n_instances).unsqueeze(0).repeat(2, 1)
            attr_sparse = SparseMatrix(
                indices=diag_idx,
                values=torch.FloatTensor(attr_matrix),
                shape=np.array([n_total, n_total, n_channels]),
                is_set=True)
            _append_channels(attr_sparse, n_channels)

    data = SparseMatrixData(schema)
    data[0] = data_out
    return schema, data, dl
def load_data():
    '''
    Load the paper/class/word citation dataset from CSV files.

    Reads the 'paper' (paper, class), 'cites' (citer, citee) and 'content'
    (paper, word) CSVs via csv_file_str and builds dense 0/1 matrices for the
    relations paper-class, paper-paper and paper-word.

    Returns:
        data: Data holding the three matrices, each with two leading
            singleton dimensions added.
        schema: DataSchema with entities (papers, classes, words).
        class_targets: LongTensor with the class index of every paper.
    '''
    paper_names = []
    classes = []
    word_names = ['word'+str(i+1) for i in range(1433)]

    # newline='' is what the csv module expects for files it reads.
    with open(csv_file_str.format('paper'), newline='') as paperfile:
        reader = csv.reader(paperfile)
        for paper_name, class_name in reader:
            paper_names.append(paper_name)
            classes.append(class_name)

    class_names = list(np.unique(classes))
    class_name_to_idx = {class_name: i for i, class_name in enumerate(class_names)}
    paper_name_to_idx = {paper_name: i for i, paper_name in enumerate(paper_names)}
    # Row 0: paper index, row 1: class index.
    paper = np.array([[paper_name_to_idx[paper_name] for paper_name in paper_names],
                      [class_name_to_idx[class_name] for class_name in classes]])

    cites = []
    with open(csv_file_str.format('cites'), newline='') as citesfile:
        reader = csv.reader(citesfile)
        for citer, citee in reader:
            cites.append([paper_name_to_idx[citer], paper_name_to_idx[citee]])
    # reshape keeps an integer (2, 0) array even when there are no rows.
    cites = np.array(cites, dtype=np.int64).reshape(-1, 2).T

    content = []

    def word_to_idx(word):
        '''
        words all formatted like: "word1328"
        '''
        return int(word[4:]) - 1

    with open(csv_file_str.format('content'), newline='') as contentfile:
        reader = csv.reader(contentfile)
        for paper_name, word_name in reader:
            content.append([paper_name_to_idx[paper_name],
                            word_to_idx(word_name)])
    content = np.array(content, dtype=np.int64).reshape(-1, 2).T

    n_papers = len(paper_names)
    n_classes = len(class_names)
    n_words = len(word_names)
    ent_papers = Entity(0, n_papers)
    ent_classes = Entity(1, n_classes)
    ent_words = Entity(2, n_words)
    entities = [ent_papers, ent_classes, ent_words]
    rel_paper = Relation(0, [ent_papers, ent_classes])
    rel_cites = Relation(1, [ent_papers, ent_papers])
    rel_content = Relation(2, [ent_papers, ent_words])
    relations = [rel_paper, rel_cites, rel_content]
    schema = DataSchema(entities, relations)

    class_targets = torch.LongTensor(paper[1])

    def _incidence(index_pairs, n_rows, n_cols):
        # 0/1 matrix with ones at the (row, col) pairs in a (2, n) array.
        # Indexing with ONE 2-D array would select whole rows along dim 0,
        # so rows and columns are passed as two separate index tensors.
        idx = torch.as_tensor(index_pairs, dtype=torch.long)
        matrix = torch.zeros(n_rows, n_cols)
        matrix[idx[0], idx[1]] = 1
        return matrix

    paper_matrix = _incidence(paper, n_papers, n_classes)
    cites_matrix = _incidence(cites, n_papers, n_papers)
    content_matrix = _incidence(content, n_papers, n_words)

    data = Data(schema)
    data[0] = paper_matrix.unsqueeze(0).unsqueeze(0)
    data[1] = cites_matrix.unsqueeze(0).unsqueeze(0)
    data[2] = content_matrix.unsqueeze(0).unsqueeze(0)
    return data, schema, class_targets
def load_data(prefix, use_node_attrs=True, use_edge_data=True, use_other_edges=True, node_val='one'):
    """Load an HGB link-prediction dataset as a SparseMatrixData.

    Args:
        prefix: dataset directory name appended to DATA_FILE_DIR.
        use_node_attrs: add one set (diagonal) relation per entity holding
            its node attributes, with ids after the largest edge-type id.
        use_edge_data: keep edge values; if False stored values become 1.
        use_other_edges: include every edge type; if False only the test
            edge types are loaded and the entity list comes from the
            entities of the first test type.
        node_val: placeholder attribute for entities without features:
            'zero' -> 0, 'rand' -> standard normal, anything else -> 1.

    Returns:
        (schema, data, dl) where dl is the underlying data_loader.
    """
    dl = data_loader(DATA_FILE_DIR + prefix)
    node_counts = dl.nodes['count']
    node_shift = dl.nodes['shift']
    link_meta = dl.links['meta']
    test_types = dl.test_types

    all_entities = [Entity(ent_id, size)
                    for ent_id, size in sorted(node_counts.items())]

    # Edge types to model: all of them (by sorted id) or just the test types.
    chosen_rel_ids = sorted(link_meta) if use_other_edges else test_types
    relations = {}
    for rel_id in chosen_rel_ids:
        head_ent, tail_ent = link_meta[rel_id]
        relations[rel_id] = Relation(
            rel_id, [all_entities[head_ent], all_entities[tail_ent]])

    if use_other_edges:
        entities = all_entities
    else:
        entities = list(np.unique(relations[test_types[0]].entities))

    max_relation = max(relations) + 1
    if use_node_attrs:
        # Create fake relations to represent node attributes
        for entity in entities:
            attr_rel_id = max_relation + entity.id
            relations[attr_rel_id] = Relation(attr_rel_id, [entity, entity], is_set=True)

    schema = DataSchema(entities, relations)
    data = SparseMatrixData(schema)

    for rel_id, full_matrix in dl.links['data'].items():
        if not (use_other_edges or rel_id in test_types):
            continue
        # Cut out the block of the global matrix for this relation's
        # (row entity, column entity) pair.
        rel_entities = relations[rel_id].entities
        row_ent, col_ent = rel_entities[0].id, rel_entities[1].id
        row_lo, col_lo = node_shift[row_ent], node_shift[col_ent]
        block = full_matrix[row_lo:row_lo + node_counts[row_ent],
                            col_lo:col_lo + node_counts[col_ent]]
        data[rel_id] = SparseMatrix.from_scipy_sparse(block.tocoo())
        if not use_edge_data:
            # Use only adjacency information
            data[rel_id].values = torch.ones(data[rel_id].values.shape)

    if use_node_attrs:
        for entity in entities:
            attr_matrix = dl.nodes['attr'][entity.id]
            n_instances = node_counts[entity.id]
            if attr_matrix is None:
                if node_val == 'zero':
                    attr_matrix = np.zeros((n_instances, 1))
                elif node_val == 'rand':
                    attr_matrix = np.random.randn(n_instances, 1)
                else:
                    attr_matrix = np.ones((n_instances, 1))
            n_channels = attr_matrix.shape[1]
            diag_idx = torch.arange(n_instances).unsqueeze(0).repeat(2, 1)
            data[max_relation + entity.id] = SparseMatrix(
                indices=diag_idx,
                values=torch.FloatTensor(attr_matrix),
                shape=np.array([n_instances, n_instances, n_channels]),
                is_set=True)

    return schema, data, dl
def load_data(prefix='DBLP', use_node_attrs=True, use_edge_data=True,
              feats_type=0):
    '''
    Load a node-classification dataset with one entity per node type.

    :param prefix: dataset name, appended to DATA_FILE_DIR for data_loader.
    :param use_node_attrs: if True, add one set relation per entity whose
        diagonal holds that entity's node attributes.  Entities without
        attributes get a single channel of ones.
    :param use_edge_data: if False, every edge value is replaced by 1 so
        only adjacency information is kept.
    :param feats_type: currently unused by this function.
    :returns: (schema, schema_out, data, data_target, labels,
        train_val_test_idx, dl).  Node type 0 is the prediction target.
        20% of its labelled training nodes (after an np.random shuffle) are
        held out as validation.  ``labels`` is reduced to class indices via
        argmax, except for IMDB, where the (n_nodes, num_classes) array is
        returned unchanged.
    '''
    dl = data_loader(DATA_FILE_DIR + prefix)

    # Create Schema
    entities = [
        Entity(entity_id, n_instances)
        for entity_id, n_instances in sorted(dl.nodes['count'].items())
    ]
    relations = {
        rel_id: Relation(rel_id, [entities[entity_i], entities[entity_j]])
        for rel_id, (entity_i, entity_j) in sorted(dl.links['meta'].items())
    }
    # Attribute relations are numbered after the largest link relation id.
    # This equals len(relations) when link ids are 0..n-1.  Unlike
    # len(relations), it cannot overwrite a link relation when the ids have
    # gaps.
    attr_rel_offset = max(relations, default=-1) + 1
    if use_node_attrs:
        # Create fake relations to represent node attributes
        for entity in entities:
            rel_id = attr_rel_offset + entity.id
            relations[rel_id] = Relation(rel_id, [entity, entity], is_set=True)
    schema = DataSchema(entities, relations)

    # Collect data
    data = SparseMatrixData(schema)
    for rel_id, data_matrix in dl.links['data'].items():
        ent_i_id = relations[rel_id].entities[0].id
        ent_j_id = relations[rel_id].entities[1].id
        # Get subset belonging to entities in relation
        start_i = dl.nodes['shift'][ent_i_id]
        end_i = start_i + dl.nodes['count'][ent_i_id]
        start_j = dl.nodes['shift'][ent_j_id]
        end_j = start_j + dl.nodes['count'][ent_j_id]
        rel_matrix = data_matrix[start_i:end_i, start_j:end_j]
        data[rel_id] = SparseMatrix.from_scipy_sparse(rel_matrix.tocoo())
        if not use_edge_data:
            # Use only adjacency information
            data[rel_id].values = torch.ones(data[rel_id].values.shape)

    target_entity = 0

    if use_node_attrs:
        for ent_id, attr_matrix in dl.nodes['attr'].items():
            n_instances = dl.nodes['count'][ent_id]
            if attr_matrix is None:
                # Attribute for each node is a single 1
                attr_matrix = np.ones(n_instances)[:, None]
            n_channels = attr_matrix.shape[1]
            rel_id = ent_id + attr_rel_offset
            # Diagonal index pairs (i, i) for every instance of the entity.
            indices = torch.arange(n_instances).unsqueeze(0).repeat(2, 1)
            data[rel_id] = SparseMatrix(
                indices=indices,
                values=torch.FloatTensor(attr_matrix),
                shape=np.array([n_instances, n_instances, n_channels]),
                is_set=True)

    # Zero-initialised output: one diagonal entry per target node, with one
    # channel per class.
    n_outputs = dl.nodes['count'][target_entity]
    n_output_classes = dl.labels_train['num_classes']
    target_ent = entities[target_entity]
    schema_out = DataSchema(
        [target_ent],
        [Relation(0, [target_ent, target_ent], is_set=True)])
    data_target = SparseMatrixData(schema_out)
    data_target[0] = SparseMatrix(
        indices=torch.arange(n_outputs, dtype=torch.int64).repeat(2, 1),
        values=torch.zeros([n_outputs, n_output_classes]),
        shape=(n_outputs, n_outputs, n_output_classes),
        is_set=True)

    labels = np.zeros((n_outputs, n_output_classes), dtype=int)
    # Hold out a random 20% of the labelled training nodes for validation.
    val_ratio = 0.2
    train_idx = np.nonzero(dl.labels_train['mask'])[0]
    np.random.shuffle(train_idx)
    split = int(train_idx.shape[0] * val_ratio)
    val_idx = np.sort(train_idx[:split])
    train_idx = np.sort(train_idx[split:])
    test_idx = np.nonzero(dl.labels_test['mask'])[0]
    labels[train_idx] = dl.labels_train['data'][train_idx]
    labels[val_idx] = dl.labels_train['data'][val_idx]
    if prefix != 'IMDB':
        labels = labels.argmax(axis=1)
    train_val_test_idx = {
        'train_idx': train_idx,
        'val_idx': val_idx,
        'test_idx': test_idx,
    }

    return (schema, schema_out, data, data_target, labels,
            train_val_test_idx, dl)
def __init__(self):
    """Load the person/course CSV tables and build the relational data.

    Reads one CSV per entry of ``relation_names`` (path from
    ``csv_file_str``), with columns named by ``schema_dict``.  Builds two
    entities (person, course) and four relations: a person set relation, a
    course set relation, ``advisedBy`` (person x person) and ``taughtBy``
    (course x person).  The results are stored in ``self.schema`` /
    ``self.data``.

    Also sets ``self.target`` (via ``self.get_targets``), and sets
    ``self.schema_out`` / ``self.data_target`` to a zero-initialised
    per-person output with one channel per distinct value of the target
    column (``self.output_dim``).

    :raises ValueError: if a non-set relation has no id-column mapping.
    """
    # Raw column values per relation, keyed by schema_dict's column names.
    data_raw = {
        rel_name: {key: [] for key in schema_dict[rel_name]}
        for rel_name in schema_dict
    }
    for relation_name in relation_names:
        # The csv module requires file objects opened with newline=''.
        with open(csv_file_str.format(relation_name), newline='') as file:
            reader = csv.reader(file)
            keys = schema_dict[relation_name].keys()
            for cols in reader:
                for key, col in zip(keys, cols):
                    data_raw[relation_name][key].append(col)

    ent_person = Entity(0, len(data_raw['person']['p_id']))
    ent_course = Entity(1, len(data_raw['course']['course_id']))
    entities = [ent_person, ent_course]

    rel_person = Relation(0, [ent_person, ent_person], is_set=True)
    rel_course = Relation(1, [ent_course, ent_course], is_set=True)
    rel_advisedBy = Relation(2, [ent_person, ent_person])
    rel_taughtBy = Relation(3, [ent_course, ent_person])
    relations = [rel_person, rel_course, rel_advisedBy, rel_taughtBy]

    self.schema = DataSchema(entities, relations)
    self.data = SparseMatrixData(self.schema)

    ent_id_to_idx_dict = {
        'person': self.id_to_idx(data_raw['person']['p_id']),
        'course': self.id_to_idx(data_raw['course']['course_id'])
    }

    # (row id column, column id column) for each binary relation.
    binary_id_columns = {
        'advisedBy': ('p_id', 'p_id_dummy'),
        'taughtBy': ('course_id', 'p_id'),
    }

    for relation in relations:
        relation_name = relation_names[relation.id]
        print(relation_name)
        if relation.is_set:
            data_matrix = self.set_relation_to_matrix(
                relation, schema_dict[relation_name],
                data_raw[relation_name])
        else:
            # Fail loudly instead of reusing the previous relation's column
            # names, or hitting an unbound local.
            if relation_name not in binary_id_columns:
                raise ValueError(
                    'No id columns configured for binary relation '
                    f'{relation_name!r}')
            ent_n_id_str, ent_m_id_str = binary_id_columns[relation_name]
            data_matrix = self.binary_relation_to_matrix(
                relation, schema_dict[relation_name],
                data_raw[relation_name], ent_id_to_idx_dict,
                ent_n_id_str, ent_m_id_str)
        self.data[relation.id] = data_matrix

    target_values = data_raw[self.TARGET_RELATION][self.TARGET_KEY]
    self.target = self.get_targets(
        target_values, schema_dict[self.TARGET_RELATION][self.TARGET_KEY])
    self.target_rel_id = 0
    rel_out = Relation(self.target_rel_id, [ent_person, ent_person],
                       is_set=True)
    self.schema_out = DataSchema([ent_person], [rel_out])
    self.data_target = Data(self.schema_out)
    n_output_classes = len(np.unique(target_values))
    self.output_dim = n_output_classes
    n_person = ent_person.n_instances
    # One diagonal entry per person, with one zero-filled channel per class.
    self.data_target[self.target_rel_id] = SparseMatrix(
        indices=torch.arange(n_person, dtype=torch.int64).repeat(2, 1),
        values=torch.zeros([n_person, n_output_classes]),
        shape=(n_person, n_person, n_output_classes))
def __init__(self, use_node_attrs=True):
    """Build the schema and sparse data from the node and link files.

    Entities come from ``ENTITY_N_INSTANCES`` and link relations from
    ``RELATION_IDX``.  When ``use_node_attrs`` is true, each entity type
    ``k`` also gets a set relation with id ``10 + k``.  Its diagonal holds
    that node's comma-separated feature values from ``NODE_FILE_STR``.

    Node ids from the files are remapped to contiguous per-type indices in
    file order.  The mapping is stored in ``self.node_id_to_idx``, and its
    inverse for ``TARGET_NODE_TYPE`` in ``self.target_node_idx_to_id``.
    Results are stored in ``self.schema`` and ``self.data``.
    """
    entities = [
        Entity(entity_id, n_instances)
        for entity_id, n_instances in ENTITY_N_INSTANCES.items()
    ]
    relations = [
        Relation(rel_id, [entities[entity_i], entities[entity_j]])
        for rel_id, (entity_i, entity_j) in RELATION_IDX.items()
    ]
    if use_node_attrs:
        for entity_id in ENTITY_N_INSTANCES:
            relations.append(
                Relation(10 + entity_id,
                         [entities[entity_id], entities[entity_id]],
                         is_set=True))
    self.schema = DataSchema(entities, relations)

    # Buffers keyed by the actual relation id.  Keying by
    # range(len(relations)) misses the attribute relations (ids start at 10)
    # and raised KeyError.
    raw_data_indices = {rel.id: [] for rel in relations}
    raw_data_values = {rel.id: [] for rel in relations}

    # Single pass over the node file: assign per-type indices and, if
    # requested, collect node attributes.
    self.node_id_to_idx = {ent_i: {} for ent_i in range(len(entities))}
    node_counter = {ent_i: 0 for ent_i in range(len(entities))}
    with open(NODE_FILE_STR, 'r') as node_file:
        for line in node_file:
            node_id, _node_name, node_type, values = line.rstrip().split('\t')
            node_type = int(node_type)
            node_idx = node_counter[node_type]
            self.node_id_to_idx[node_type][int(node_id)] = node_idx
            node_counter[node_type] += 1
            if use_node_attrs:
                raw_data_indices[10 + node_type].append([node_idx, node_idx])
                raw_data_values[10 + node_type].append(
                    [float(v) for v in values.split(',')])

    target_node_id_to_idx = self.node_id_to_idx[TARGET_NODE_TYPE]
    self.target_node_idx_to_id = {
        idx: node_id for node_id, idx in target_node_id_to_idx.items()
    }

    with open(LINK_FILE_STR, 'r') as link_file:
        for line in link_file:
            node_i, node_j, rel_num, val = line.rstrip().split('\t')
            rel_num = int(rel_num)
            node_i_type, node_j_type = RELATION_IDX[rel_num]
            node_i = self.node_id_to_idx[node_i_type][int(node_i)]
            node_j = self.node_id_to_idx[node_j_type][int(node_j)]
            raw_data_indices[rel_num].append([node_i, node_j])
            raw_data_values[rel_num].append([float(val)])

    self.data = SparseMatrixData(self.schema)
    for rel in relations:
        # pop() releases each raw buffer as soon as it has been converted.
        rel_indices = raw_data_indices.pop(rel.id)
        rel_values = raw_data_values.pop(rel.id)
        if rel_indices:
            indices = torch.LongTensor(rel_indices).T
            values = torch.Tensor(rel_values)
        else:
            # No entries for this relation: use empty, correctly ranked
            # tensors (one channel) instead of failing on values.shape[1].
            indices = torch.zeros((2, 0), dtype=torch.long)
            values = torch.zeros((0, 1))
        n = rel.entities[0].n_instances
        m = rel.entities[1].n_instances
        n_channels = values.shape[1]
        self.data[rel.id] = SparseMatrix(indices=indices,
                                         values=values,
                                         shape=np.array([n, m, n_channels]),
                                         is_set=rel.is_set)
np.random.randint(0, n_papers, (n_content)), np.random.randint(0, n_words, (n_content)) ]) content = np.unique(cites, axis=1) content_matrix = SparseMatrix(indices=torch.LongTensor(content), values=value_dist.sample( (content.shape[1], 1)), shape=(n_papers, n_words, 1)).coalesce() ent_papers = Entity(0, n_papers) #ent_classes = Entity(1, n_classes) ent_words = Entity(1, n_words) rel_paper = Relation(0, [ent_papers, ent_papers], is_set=True) rel_cites = Relation(0, [ent_papers, ent_papers]) rel_content = Relation(1, [ent_papers, ent_words]) schema = DataSchema([ent_papers, ent_words], [rel_cites, rel_content]) schema_out = DataSchema([ent_papers], [rel_paper]) targets = torch.LongTensor(paper[1]) data = SparseMatrixData(schema) data[0] = cites_matrix data[1] = content_matrix indices_identity, indices_transpose = data.calculate_indices() data_target = Data(schema_out) data_target[0] = SparseMatrix(indices=torch.arange( n_papers, dtype=torch.int64).repeat(2, 1), values=torch.zeros([n_papers, n_classes]), shape=(n_papers, n_papers, n_classes)) data_target = data_target.to(device)
def load_data():
    '''
    Load the paper/cites/content CSV tables into a SparseTensorData.

    Entities: papers (0), classes (1), words (2).
    Relations: 0 = paper-class, 1 = paper-cites-paper,
    2 = paper-contains-word.

    Relation 0 is stored densely, with one entry for every (paper, class)
    pair: 1 for the paper's class and 0 otherwise.  Relations 1 and 2 hold
    the observed pairs (value 1) plus the same number of uniformly random
    pairs (value 0), then are coalesced.

    :returns: (data, schema, class_targets), where class_targets holds each
        paper's class index in the order papers appear in the paper CSV.
    '''
    paper_names = []
    classes = []
    word_names = ['word' + str(i + 1) for i in range(1433)]

    # newline='' is what the csv module requires for file objects.
    with open(csv_file_str.format('paper'), newline='') as paperfile:
        for paper_name, class_name in csv.reader(paperfile):
            paper_names.append(paper_name)
            classes.append(class_name)

    class_names = list(np.unique(classes))
    class_name_to_idx = {
        class_name: i for i, class_name in enumerate(class_names)
    }
    paper_name_to_idx = {
        paper_name: i for i, paper_name in enumerate(paper_names)
    }
    paper = np.array(
        [[paper_name_to_idx[paper_name] for paper_name in paper_names],
         [class_name_to_idx[class_name] for class_name in classes]])

    cites = []
    with open(csv_file_str.format('cites'), newline='') as citesfile:
        for citer, citee in csv.reader(citesfile):
            cites.append([paper_name_to_idx[citer], paper_name_to_idx[citee]])
    cites = np.array(cites).T

    def word_to_idx(word):
        '''
        words all formatted like: "word1328"
        '''
        return int(word[4:]) - 1

    content = []
    with open(csv_file_str.format('content'), newline='') as contentfile:
        for paper_name, word_name in csv.reader(contentfile):
            content.append(
                [paper_name_to_idx[paper_name], word_to_idx(word_name)])
    content = np.array(content).T

    n_papers = len(paper_names)
    n_classes = len(class_names)
    n_words = len(word_names)
    ent_papers = Entity(0, n_papers)
    ent_classes = Entity(1, n_classes)
    ent_words = Entity(2, n_words)
    entities = [ent_papers, ent_classes, ent_words]
    rel_paper = Relation(0, [ent_papers, ent_classes])
    rel_cites = Relation(1, [ent_papers, ent_papers])
    rel_content = Relation(2, [ent_papers, ent_words])
    relations = [rel_paper, rel_cites, rel_content]
    schema = DataSchema(entities, relations)

    # For each paper, get a random negative sample
    # NOTE(review): paper_matrix is not stored in `data` (relation 0 uses the
    # dense one-hot matrix built below).  It is kept because dropping the
    # random_class_offset draw would change the random negatives sampled
    # for cites/content.
    random_class_offset = np.random.randint(1, n_classes, (n_papers, ))
    paper_neg = np.stack(
        (paper[0], (paper[1] + random_class_offset) % n_classes))
    paper_matrix = SparseTensor(
        indices=torch.LongTensor(np.concatenate((paper, paper_neg), axis=1)),
        values=torch.cat((torch.ones(1, paper.shape[1]),
                          torch.zeros(1, paper_neg.shape[1])), 1),
        shape=np.array([n_papers, n_classes])).coalesce()
    class_targets = torch.LongTensor(paper[1])

    # Randomly fill in values and coalesce to remove duplicates
    cites_neg = np.random.randint(0, n_papers, cites.shape)
    cites_matrix = SparseTensor(
        indices=torch.LongTensor(np.concatenate((cites, cites_neg), axis=1)),
        values=torch.cat((torch.ones(1, cites.shape[1]),
                          torch.zeros(1, cites_neg.shape[1])), 1),
        shape=np.array([n_papers, n_papers])).coalesce()

    # For each paper, randomly fill in values and coalesce to remove duplicates
    content_neg = np.stack(
        (np.random.randint(0, n_papers, (content.shape[1], )),
         np.random.randint(0, n_words, (content.shape[1], ))))
    content_matrix = SparseTensor(
        indices=torch.LongTensor(
            np.concatenate((content, content_neg), axis=1)),
        values=torch.cat((torch.ones(1, content.shape[1]),
                          torch.zeros(1, content_neg.shape[1])), 1),
        shape=np.array([n_papers, n_words])).coalesce()

    # Dense (paper, class) grid in class-major order: column k of the
    # indices is (k % n_papers, k // n_papers).
    paper_dense_indices = np.array([
        np.tile(range(n_papers), n_classes),
        np.repeat(range(n_classes), n_papers)
    ])
    paper_dense_values = torch.zeros(paper_dense_indices.shape[1])
    for paper_i, class_name in enumerate(classes):
        class_i = class_name_to_idx[class_name]
        # Flat position of (paper_i, class_i) in the class-major layout
        # above.  The previous paper_i * n_classes + class_i set the wrong
        # pairs to 1.
        paper_dense_values[class_i * n_papers + paper_i] = 1
    paper_dense_matrix = SparseTensor(
        indices=torch.LongTensor(paper_dense_indices),
        values=paper_dense_values.unsqueeze(0),
        shape=np.array([n_papers, n_classes]))

    data = SparseTensorData(schema)
    data[0] = paper_dense_matrix
    data[1] = cites_matrix
    data[2] = content_matrix
    return data, schema, class_targets
def load_data_flat(prefix, use_node_attrs=True, use_edge_data=True,
                   node_val='zero', feats_type=0):
    '''
    Load data into one matrix with all relations, reproducing Maron 2019.

    All node types are merged into a single entity with
    ``dl.nodes['total']`` instances, and everything is stored in relation 0.
    The sparsity pattern of relation 0 is the union of all links plus the
    full diagonal.  Channels are laid out as:

    * one channel per link type, in sorted relation-id order, holding that
      link type's edge values (all ones if ``use_edge_data`` is False);
    * if ``use_node_attrs`` is True, then for each node type, the columns of
      its attribute matrix placed on that node type's segment of the
      diagonal.  A node type whose attribute matrix is None gets one channel
      filled according to ``node_val`` ('zero' -> 0, 'rand' -> standard
      normal, anything else -> 1).  With ``feats_type == 1``, every node
      type other than type 0 instead gets 10 all-zero channels.

    :returns: (schema, schema_out, data, data_target, labels,
        train_val_test_idx, dl).  20% of the labelled training nodes (after
        an np.random shuffle) are held out as validation.  ``labels`` is
        reduced to class indices via argmax, except for IMDB, where the
        (n_nodes, num_classes) array is returned unchanged.
    '''
    dl = data_loader(DATA_FILE_DIR + prefix)
    total_n_nodes = dl.nodes['total']
    entities = [Entity(0, total_n_nodes)]
    relations = {0: Relation(0, [entities[0], entities[0]])}
    schema = DataSchema(entities, relations)

    # Common sparsity pattern: union of all links plus the full diagonal
    # (where node attributes are placed below), with all values zeroed.
    data_full = sum(dl.links['data'].values()).tocoo()
    data_diag = scipy.sparse.coo_matrix(
        (np.ones(total_n_nodes),
         (np.arange(total_n_nodes), np.arange(total_n_nodes))),
        (total_n_nodes, total_n_nodes))
    data_full += data_diag
    data_full = SparseMatrix.from_scipy_sparse(data_full.tocoo()).zero_()
    data_out = SparseMatrix.from_other_sparse_matrix(data_full, 0)

    # Load up all edge data: one channel per link type.
    for rel_id in sorted(dl.links['data'].keys()):
        data_matrix = dl.links['data'][rel_id]
        data_rel = SparseMatrix.from_scipy_sparse(data_matrix.tocoo())
        if not use_edge_data:
            # Use only adjacency information
            data_rel.values = torch.ones(data_rel.values.shape)
        # Presumably aligns data_rel's values with data_full's sparsity
        # pattern so the channels can be concatenated — see SparseMatrix.
        data_rel_full = SparseMatrix.from_other_sparse_matrix(data_full,
                                                              1) + data_rel
        data_out.values = torch.cat([data_out.values, data_rel_full.values],
                                    1)
        data_out.n_channels += 1

    target_entity = 0
    if use_node_attrs:
        for ent_id, attr_matrix in dl.nodes['attr'].items():
            start_i = dl.nodes['shift'][ent_id]
            n_instances = dl.nodes['count'][ent_id]
            if attr_matrix is None:
                if node_val == 'zero':
                    attr_matrix = np.zeros((n_instances, 1))
                elif node_val == 'rand':
                    attr_matrix = np.random.randn(n_instances, 1)
                else:
                    attr_matrix = np.ones((n_instances, 1))
            if feats_type == 1 and ent_id != target_entity:
                # To keep same behaviour as non-LGNN model, use 10 dimensions
                attr_matrix = np.zeros((n_instances, 10))
            n_channels = attr_matrix.shape[1]
            # Diagonal positions of this node type inside the merged matrix.
            indices = torch.arange(start_i, start_i + n_instances).unsqueeze(
                0).repeat(2, 1)
            data_rel = SparseMatrix(
                indices=indices,
                values=torch.FloatTensor(attr_matrix),
                shape=np.array([total_n_nodes, total_n_nodes, n_channels]),
                is_set=True)
            data_rel_full = SparseMatrix.from_other_sparse_matrix(
                data_full, n_channels) + data_rel
            data_out.values = torch.cat(
                [data_out.values, data_rel_full.values], 1)
            data_out.n_channels += n_channels

    data = SparseMatrixData(schema)
    data[0] = data_out

    # Zero-initialised output over every node of the merged entity, with one
    # channel per class.
    n_outputs = total_n_nodes
    n_output_classes = dl.labels_train['num_classes']
    target_ent = entities[target_entity]
    schema_out = DataSchema(
        [target_ent],
        [Relation(0, [target_ent, target_ent], is_set=True)])
    data_target = SparseMatrixData(schema_out)
    data_target[0] = SparseMatrix(
        indices=torch.arange(n_outputs, dtype=torch.int64).repeat(2, 1),
        values=torch.zeros([n_outputs, n_output_classes]),
        shape=(n_outputs, n_outputs, n_output_classes),
        is_set=True)

    # Labels are sized by node type 0's count, not by the merged entity.
    labels = np.zeros((dl.nodes['count'][0], n_output_classes), dtype=int)
    val_ratio = 0.2
    train_idx = np.nonzero(dl.labels_train['mask'])[0]
    np.random.shuffle(train_idx)
    split = int(train_idx.shape[0] * val_ratio)
    val_idx = np.sort(train_idx[:split])
    train_idx = np.sort(train_idx[split:])
    test_idx = np.nonzero(dl.labels_test['mask'])[0]
    labels[train_idx] = dl.labels_train['data'][train_idx]
    labels[val_idx] = dl.labels_train['data'][val_idx]
    if prefix != 'IMDB':
        labels = labels.argmax(axis=1)
    train_val_test_idx = {
        'train_idx': train_idx,
        'val_idx': val_idx,
        'test_idx': test_idx,
    }
    return (schema, schema_out, data, data_target, labels,
            train_val_test_idx, dl)