def set_up_predictor(fp_out_dim, net_hidden_dims, class_num,
                     sim_method='mlp', symmetric=None):
    sim_method_dict = {
        'mlp': 'multi-layered perceptron',
        'ntn': 'bilinear transform',
        'symmlp': 'symmetric perceptron',
        'hole': 'holographic embedding',
        'dist-mult': 'dist-mult',
    }
    logging.info('Link Prediction: {}'.format(
        sim_method_dict.get(sim_method, None)))

    lp = None
    if sim_method == 'mlp':
        lp = MLP(out_dim=class_num, hidden_dims=net_hidden_dims)
    elif sim_method == 'ntn':
        ntn_out_dim = 8
        lp = NTN(left_dim=fp_out_dim, right_dim=fp_out_dim,
                 out_dim=class_num, ntn_out_dim=ntn_out_dim,
                 hidden_dims=net_hidden_dims)
    elif sim_method == 'symmlp':
        lp = MLP(out_dim=class_num, hidden_dims=net_hidden_dims)
    elif sim_method == 'hole':
        lp = HolE(out_dim=class_num, hidden_dims=net_hidden_dims)
    elif sim_method == 'dist-mult':
        dm_out_dim = 8
        lp = DistMult(left_dim=fp_out_dim, right_dim=fp_out_dim,
                      out_dim=class_num, dm_out_dim=dm_out_dim,
                      hidden_dims=net_hidden_dims)
    else:
        raise ValueError(
            '[ERROR] Invalid link prediction model: {}'.format(sim_method))

    predictor = DDIPredictor(lp, symmetric=symmetric)
    return predictor
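
# A minimal usage sketch for set_up_predictor above. The dimensions and the
# choice of 'ntn' are illustrative assumptions, not values prescribed by
# this codebase; MLP/NTN/HolE/DistMult and DDIPredictor are assumed to be
# importable from the surrounding package.
predictor = set_up_predictor(
    fp_out_dim=16,              # dimension of each drug fingerprint
    net_hidden_dims=(32, 16),   # hidden layer sizes of the link predictor
    class_num=2,                # interaction vs. no interaction
    sim_method='ntn',
    symmetric=None)
# `predictor` wraps the chosen link-prediction head in DDIPredictor and can
# then be trained like any other model in this repository.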
help="decay rate for the first moment estimate in Adam" ) parser.add_argument( '--decay2', default=0.999, type=float, help="decay rate for second moment estimate in Adam" ) args = parser.parse_args() dataset = Dataset(args.dataset) examples = torch.from_numpy(dataset.get_train().astype('int64')) print(dataset.get_shape()) model = { 'CP': lambda: CP(dataset.get_shape(), args.rank, args.init), 'ComplEx': lambda: ComplEx(dataset.get_shape(), args.rank, args.init), 'DistMult': lambda: DistMult(dataset.get_shape(), args.rank, args.init) }[args.model]() regularizer = { 'F2': F2(args.reg), 'N3': N3(args.reg), }[args.regularizer] has_cuda = torch.cuda.is_available() if has_cuda: device = 'cuda' else: device = 'cpu' model.to(device) optim_method = {
def set_up_predictor(method, fp_hidden_dim, fp_out_dim, conv_layers,
                     concat_hidden, layer_aggregator, fp_dropout_rate,
                     fp_batch_normalization, net_hidden_dims, class_num,
                     sim_method='mlp', fp_attention=False, weight_typing=True,
                     attention_tying=True, update_attention=False,
                     context=False, context_layers=1, context_dropout=0.,
                     message_function='matrix_multiply',
                     readout_function='graph_level', num_timesteps=3,
                     num_output_hidden_layers=0, output_hidden_dim=16,
                     output_activation=functions.relu, symmetric=None):
    if sim_method == 'mlp':
        logging.info('Using multi-layer perceptron for the learning of '
                     'composite representation')
        mlp = MLP(out_dim=class_num, hidden_dims=net_hidden_dims)
    elif sim_method == 'ntn':
        logging.info('Using neural tensor network for the learning of '
                     'composite representation')
        ntn_out_dim = fp_out_dim
        logging.info('NTN out dim is {}'.format(ntn_out_dim))
        mlp = NTN(left_dim=fp_out_dim, right_dim=fp_out_dim,
                  out_dim=class_num, ntn_out_dim=ntn_out_dim,
                  hidden_dims=net_hidden_dims)
    elif sim_method == 'symmlp':
        logging.info('Using symmetric multi-layer perceptron for the '
                     'learning of composite representation')
        mlp = MLP(out_dim=class_num, hidden_dims=net_hidden_dims)
    elif sim_method == 'hole':
        logging.info('Using holographic embedding for the learning of '
                     'composite representation')
        mlp = HolE(out_dim=class_num, hidden_dims=net_hidden_dims)
    elif sim_method == 'dist-mult':
        logging.info('Using DistMult embedding for the learning of '
                     'composite representation')
        dm_out_dim = 8
        mlp = DistMult(left_dim=fp_out_dim, right_dim=fp_out_dim,
                       out_dim=class_num, dm_out_dim=dm_out_dim,
                       hidden_dims=net_hidden_dims)
    else:
        raise ValueError(
            '[ERROR] Invalid similarity method: {}'.format(sim_method))
    logging.info('Using {} as similarity predictor with hidden_dims {}...'
                 .format(sim_method, net_hidden_dims))

    if method == 'ggnn':
        logging.info('Training a GGNN predictor...')
        if fp_attention:
            logging.info('Self-attention mechanism is utilized...')
            if attention_tying:
                logging.info('Self-attention is tied...')
            else:
                logging.info('Self-attention is not tied...')
        if update_attention:
            logging.info('Self-attention mechanism is utilized in update...')
            if attention_tying:
                logging.info('Self-attention is tied...')
            else:
                logging.info('Self-attention is not tied...')
        if not weight_typing:
            logging.info('Weight is not tied...')
        if fp_dropout_rate != 0.0:
            logging.info('Using dropout whose rate is {}...'
                         .format(fp_dropout_rate))
        if fp_batch_normalization:
            logging.info('Using batch normalization in dynamic '
                         'fingerprint...')
        if concat_hidden:
            logging.info('Incorporating layer aggregation via concatenation '
                         'after readout...')
        if layer_aggregator:
            logging.info('Incorporating layer aggregation via {} before '
                         'readout...'.format(layer_aggregator))
        if context:
            logging.info('Context embedding is utilized...')
            logging.info('Number of context layers is {}...'
                         .format(context_layers))
            logging.info('Dropout rate of context layers is {:.2f}'
                         .format(context_dropout))
        logging.info('Message function is {}'.format(message_function))
        logging.info('Readout function is {}'.format(readout_function))
        logging.info(
            'Num_timesteps = {}, num_output_hidden_layers={}, '
            'output_hidden_dim={}'.format(
                num_timesteps, num_output_hidden_layers, output_hidden_dim))
        ggnn = GGNN(out_dim=fp_out_dim, hidden_dim=fp_hidden_dim,
                    n_layers=conv_layers, concat_hidden=concat_hidden,
                    # layer_aggregator=layer_aggregator,
                    dropout_rate=fp_dropout_rate,
                    batch_normalization=fp_batch_normalization,
                    # use_attention=fp_attention,
                    weight_tying=weight_typing,
                    # attention_tying=attention_tying,
                    # context=context,
                    # message_function=message_function,
                    readout_function=readout_function,
                    # num_timesteps=num_timesteps,
                    # num_output_hidden_layers=num_output_hidden_layers,
                    # output_hidden_dim=output_hidden_dim,
                    # output_activation=output_activation,
                    )
    else:
        raise ValueError('[ERROR] Invalid method: {}'.format(method))

    if symmetric is not None:
        logging.info('Symmetric is {}'.format(symmetric))
    else:
        logging.info('Symmetric is None')

    predictor = GraphConvPredictorForPair(ggnn, mlp, symmetric=symmetric)
    return predictor
def train_and_eval(self):
    logger.info(f'Training the {model_name} model ...')
    self.entity_idxs = {d.entities[i]: i for i in range(len(d.entities))}
    self.relation_idxs = {
        d.relations[i]: i for i in range(len(d.relations))
    }

    train_triple_idxs = self.get_data_idxs(d.train_data)
    train_triple_size = len(train_triple_idxs)
    logger.info(f'Number of training data points: {train_triple_size}')

    if model_name.lower() == "hype":
        model = HypE(d, self.ent_vec_dim, self.rel_vec_dim, **self.kwargs)
    elif model_name.lower() == "hyper":
        model = HypER(d, self.ent_vec_dim, self.rel_vec_dim, **self.kwargs)
    elif model_name.lower() == "distmult":
        model = DistMult(d, self.ent_vec_dim, self.rel_vec_dim, **self.kwargs)
    elif model_name.lower() == "conve":
        model = ConvE(d, self.ent_vec_dim, self.rel_vec_dim, **self.kwargs)
    elif model_name.lower() == "complex":
        model = ComplEx(d, self.ent_vec_dim, self.rel_vec_dim, **self.kwargs)

    logger.debug('model parameters: {}'.format(
        {name: value.numel() for name, value in model.named_parameters()}))

    if self.cuda:
        model.cuda()
    model.init()
    opt = torch.optim.Adam(model.parameters(), lr=self.learning_rate)
    if self.decay_rate:
        scheduler = ExponentialLR(opt, self.decay_rate)

    er_vocab = self.get_er_vocab(train_triple_idxs)
    er_vocab_pairs = list(er_vocab.keys())
    er_vocab_pairs_size = len(er_vocab_pairs)
    logger.info(f'Number of entity-relation pairs: {er_vocab_pairs_size}')

    logger.info('Starting Training ...')
    for epoch in range(1, self.epochs + 1):
        logger.info(f'Epoch: {epoch}')
        model.train()
        costs = []
        np.random.shuffle(er_vocab_pairs)
        for j in range(0, er_vocab_pairs_size, self.batch_size):
            if j % (128 * 100) == 0:
                logger.info(f'Batch: {j + 1} ...')
            triples, targets = self.get_batch(er_vocab, er_vocab_pairs,
                                              er_vocab_pairs_size, j)
            opt.zero_grad()
            e1_idx = torch.tensor(triples[:, 0])
            r_idx = torch.tensor(triples[:, 1])
            if self.cuda:
                e1_idx = e1_idx.cuda()
                r_idx = r_idx.cuda()
            predictions = model.forward(e1_idx, r_idx)
            if self.label_smoothing:
                targets = ((1.0 - self.label_smoothing) * targets) + (
                    1.0 / targets.size(1))
            cost = model.loss(predictions, targets)
            cost.backward()
            opt.step()
            costs.append(cost.item())
        if self.decay_rate:
            scheduler.step()
        logger.info(f'Mean training cost: {np.mean(costs)}')

        if epoch % 2 == 0:
            model.eval()
            with torch.no_grad():
                train_data = np.array(d.train_data)
                train_data_map = {
                    'WN18': 10000,
                    'FB15k': 100000,
                    'WN18RR': 6068,
                    'FB15k-237': 35070,
                }
                train_data_sample_size = train_data_map[dataset]
                train_data = train_data[np.random.choice(
                    train_data.shape[0], train_data_sample_size,
                    replace=False), :]
                self.evaluate(model, train_data, epoch, 'training')
                logger.info('Starting Validation ...')
                self.evaluate(model, d.valid_data, epoch, 'validation')
                logger.info('Starting Test ...')
                self.evaluate(model, d.test_data, epoch, 'testing')
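
# A self-contained illustration (not part of the original training code) of
# the label-smoothing line used above: every 0/1 target is mixed with a
# uniform 1/N term over the N labels. Note that, exactly as written in the
# loop above, positive entries can exceed 1, since the code adds 1.0 / N
# rather than the more conventional label_smoothing / N.
import torch

targets = torch.tensor([[0., 1., 0., 0.]])  # one-hot over N = 4 labels
label_smoothing = 0.1
smoothed = (1.0 - label_smoothing) * targets + 1.0 / targets.size(1)
print(smoothed)  # tensor([[0.2500, 1.1500, 0.2500, 0.2500]])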
def main():
    # config_path = join(path_dir, 'data', Config.dataset, 'data.npy')
    if Config.process:
        preprocess(Config.dataset, delete_data=True)

    input_keys = ['e1', 'rel', 'rel_eval', 'e2', 'e2_multi1', 'e2_multi2']
    p = Pipeline(Config.dataset, keys=input_keys)
    p.load_vocabs()
    vocab = p.state['vocab']
    node_list = p.state['vocab']['e1']
    rel_list = p.state['vocab']['rel']
    num_entities = vocab['e1'].num_token
    num_relations = vocab['rel'].num_token

    train_batcher = StreamBatcher(Config.dataset, 'train', Config.batch_size,
                                  randomize=True, keys=input_keys)
    dev_rank_batcher = StreamBatcher(Config.dataset, 'dev_ranking',
                                     Config.batch_size, randomize=False,
                                     loader_threads=4, keys=input_keys)
    test_rank_batcher = StreamBatcher(Config.dataset, 'test_ranking',
                                      Config.batch_size, randomize=False,
                                      loader_threads=4, keys=input_keys)
    train_batcher.at_batch_prepared_observers.insert(
        1, TargetIdx2MultiTarget(num_entities, 'e2_multi1',
                                 'e2_multi1_binary'))

    def normalize(mx):
        """Row-normalize a sparse matrix."""
        rowsum = np.array(mx.sum(1))
        r_inv = np.power(rowsum, -1).flatten()
        r_inv[np.isinf(r_inv)] = 0.
        r_mat_inv = sp.diags(r_inv)
        mx = r_mat_inv.dot(mx)
        return mx

    # Collect (head, relation, tail) entries for the adjacency structure
    # from the training batches.
    data = []
    rows = []
    columns = []
    for i, str2var in enumerate(train_batcher):
        if i % 10 == 0:
            print("batch number:", i)
        for j in range(str2var['e1'].shape[0]):
            for k in range(str2var['e2_multi1'][j].shape[0]):
                if str2var['e2_multi1'][j][k] != 0:
                    data.append(str2var['rel'][j].cpu())
                    rows.append(str2var['e1'][j].cpu().tolist()[0])
                    columns.append(str2var['e2_multi1'][j][k].cpu())
                else:
                    # e2_multi1 is zero-padded, so stop at the first
                    # padding entry.
                    break

    # Append a self-loop for every entity, labeled with a dedicated
    # relation index (num_relations).
    rows = rows + [i for i in range(num_entities)]
    columns = columns + [i for i in range(num_entities)]
    data = data + [num_relations for i in range(num_entities)]

    indices = torch.LongTensor([rows, columns]).cuda()
    v = torch.LongTensor(data).cuda()
    adjacencies = [indices, v, num_entities]

    # filename = join(path_dir, 'data', Config.dataset, 'adj.pkl')
    # with open(filename, 'wb+') as f:
    #     pkl.dump(adjacencies, f)
    print('Finished the preprocessing')

    X = torch.LongTensor([i for i in range(num_entities)])

    if Config.model_name is None or Config.model_name == 'ConvE':
        model = ConvE(vocab['e1'].num_token, vocab['rel'].num_token)
    elif Config.model_name == 'SACN':
        model = SACN(vocab['e1'].num_token, vocab['rel'].num_token)
    elif Config.model_name == 'ConvTransE':
        model = ConvTransE(vocab['e1'].num_token, vocab['rel'].num_token)
    elif Config.model_name == 'DistMult':
        model = DistMult(vocab['e1'].num_token, vocab['rel'].num_token)
    elif Config.model_name == 'ComplEx':
        model = Complex(vocab['e1'].num_token, vocab['rel'].num_token)
    else:
        log.info('Unknown model: {0}', Config.model_name)
        raise Exception("Unknown model!")

    # Re-create the training batcher, this time with logging hooks attached.
    train_batcher = StreamBatcher(Config.dataset, 'train', Config.batch_size,
                                  randomize=True, keys=input_keys)
    eta = ETAHook('train', print_every_x_batches=100)
    train_batcher.subscribe_to_events(eta)
    train_batcher.subscribe_to_start_of_epoch_event(eta)
    train_batcher.subscribe_to_events(
        LossHook('train', print_every_x_batches=100))
    train_batcher.at_batch_prepared_observers.insert(
        1, TargetIdx2MultiTarget(num_entities, 'e2_multi1',
                                 'e2_multi1_binary'))

    if Config.cuda:
        model.cuda()
        X = X.cuda()

    if load:
        model_params = torch.load(model_path)
        print(model)
        total_param_size = []
        params = [(key, value.size(), value.numel())
                  for key, value in model_params.items()]
        for key, size, count in params:
            total_param_size.append(count)
            print(key, size, count)
        print(np.array(total_param_size).sum())
        model.load_state_dict(model_params)
        model.eval()
        ranking_and_hits(model, test_rank_batcher, vocab, 'test_evaluation',
                         X, adjacencies)
        ranking_and_hits(model, dev_rank_batcher, vocab, 'dev_evaluation',
                         X, adjacencies)
    else:
        model.init()
        params = [value.numel() for value in model.parameters()]
        print(params)
        print(np.sum(params))
        opt = torch.optim.Adam(model.parameters(), lr=Config.learning_rate,
                               weight_decay=Config.L2)
        for epoch in range(epochs):
            model.train()
            for i, str2var in enumerate(train_batcher):
                opt.zero_grad()
                e1 = str2var['e1'].cuda()
                rel = str2var['rel'].cuda()
                e2_multi = str2var['e2_multi1_binary'].float().cuda()
                # Label smoothing.
                e2_multi = ((1.0 - Config.label_smoothing_epsilon) * e2_multi
                            ) + (1.0 / e2_multi.size(1))
                pred = model.forward(e1, rel, X, adjacencies)
                loss = model.loss(pred, e2_multi)
                loss.backward()
                opt.step()
                train_batcher.state.loss = loss.cpu()

            print('saving to {0}'.format(model_path))
            torch.save(model.state_dict(), model_path)

            model.eval()
            with torch.no_grad():
                ranking_and_hits(model, dev_rank_batcher, vocab,
                                 'dev_evaluation', X, adjacencies)
                if epoch % 3 == 0 and epoch > 0:
                    ranking_and_hits(model, test_rank_batcher, vocab,
                                     'test_evaluation', X, adjacencies)
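
# A standalone illustration (hypothetical data) of the row-normalization
# helper defined inside main(): each row of a sparse matrix is scaled to
# sum to 1, and all-zero rows stay zero instead of producing infinities.
import numpy as np
import scipy.sparse as sp

mx = sp.csr_matrix(np.array([[1., 1., 0.],
                             [0., 0., 0.],
                             [2., 0., 2.]]))
rowsum = np.array(mx.sum(1))                 # row sums: [[2.], [0.], [4.]]
with np.errstate(divide='ignore'):
    r_inv = np.power(rowsum, -1).flatten()   # [0.5, inf, 0.25]
r_inv[np.isinf(r_inv)] = 0.                  # zero out the empty row
normalized = sp.diags(r_inv).dot(mx)
print(normalized.toarray())
# [[0.5 0.5 0. ]
#  [0.  0.  0. ]
#  [0.5 0.  0.5]]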
filtered = False
train_ranker = RankingEvaluation(
    train_triples[:5000], num_nodes,
    triples_to_filter=all_triples if filtered else None,
    device=device, show_progress=True)
val_ranker = RankingEvaluation(
    val_triples, num_nodes,
    triples_to_filter=all_triples if filtered else None,
    device=device, show_progress=True)
# test_ranker = RankingEvaluation(test_triples, num_nodes,
#                                 filter_triples=all_triples if filtered else None,
#                                 show_progress=True)

history = utils.History()

# node_features = load_image_features(num_nodes, entity_map)
node_features = None

utils.seed_all(0)

# TODO: Make device parameter obsolete by moving everything to the device
# once .to(device) is called.
# net = UnsupervisedRGCN(num_nodes, num_relations, train_triples,
#                        embedding_size=200, dropout=0,  # embedding_size=500, dropout=0.5
#                        num_sample_train=10, num_sample_eval=10,
#                        activation=F.elu, node_features=node_features,
#                        device=device)
net = DistMult(500, num_nodes, num_relations, 0)
net.to(device)

optimizer = torch.optim.Adam(
    filter(lambda parameter: parameter.requires_grad, net.parameters()),
    lr=0.001)

train_via_classification(
    net, train_triples, val_triples, optimizer, num_nodes,
    train_ranker, val_ranker, num_epochs=35, batch_size=64,
    batch_size_eval=512, device=device, history=history,
    dry_run=False, ranking_eval=True)

# to_plot = ['loss', 'acc', 'median_diff', 'mean_rank', 'mean_rec_rank',
#            'hits_1', 'hits_3', 'hits_10']
# figsize = (8, 20)
# history.plot(*to_plot, figsize=figsize)  # , xlim=(0, 10))
def set_up_predictor(method, fp_hidden_dim, fp_out_dim, conv_layers,
                     concat_hidden, layer_aggregator, fp_dropout_rate,
                     fp_batch_normalization, net_hidden_dims, class_num,
                     sim_method='mlp', fp_attention=False, weight_typing=True,
                     attention_tying=True, update_attention=False,
                     context=False, context_layers=1, context_dropout=0.,
                     message_function='matrix_multiply',
                     readout_function='graph_level', num_timesteps=3,
                     num_output_hidden_layers=0, output_hidden_dim=16,
                     output_activation=functions.relu, symmetric=None):
    if sim_method == 'mlp':
        logging.info('Using multi-layer perceptron for the learning of '
                     'composite representation')
        mlp = MLP(out_dim=class_num, hidden_dims=net_hidden_dims)
    elif sim_method == 'ntn':
        logging.info('Using neural tensor network for the learning of '
                     'composite representation')
        ntn_out_dim = 8
        logging.info('NTN out dim is {}'.format(ntn_out_dim))
        mlp = NTN(left_dim=fp_out_dim, right_dim=fp_out_dim,
                  out_dim=class_num, ntn_out_dim=ntn_out_dim,
                  hidden_dims=net_hidden_dims)
    elif sim_method == 'symmlp':
        logging.info('Using symmetric multi-layer perceptron for the '
                     'learning of composite representation')
        mlp = MLP(out_dim=class_num, hidden_dims=net_hidden_dims)
    elif sim_method == 'hole':
        logging.info('Using holographic embedding for the learning of '
                     'composite representation')
        mlp = HolE(out_dim=class_num, hidden_dims=net_hidden_dims)
    elif sim_method == 'dist-mult':
        logging.info('Using DistMult embedding for the learning of '
                     'composite representation')
        dm_out_dim = 8
        mlp = DistMult(left_dim=fp_out_dim, right_dim=fp_out_dim,
                       out_dim=class_num, dm_out_dim=dm_out_dim,
                       hidden_dims=net_hidden_dims)
    else:
        raise ValueError(
            '[ERROR] Invalid similarity method: {}'.format(sim_method))
    logging.info('Using {} as similarity predictor with hidden_dims {}...'
                 .format(sim_method, net_hidden_dims))

    encoder = None
    if method == 'ggnn':
        logging.info('Training a GGNN predictor...')
        if fp_attention:
            logging.info('Self-attention mechanism is utilized...')
            if attention_tying:
                logging.info('Self-attention is tied...')
            else:
                logging.info('Self-attention is not tied...')
        if update_attention:
            logging.info('Self-attention mechanism is utilized in update...')
            if attention_tying:
                logging.info('Self-attention is tied...')
            else:
                logging.info('Self-attention is not tied...')
        if not weight_typing:
            logging.info('Weight is not tied...')
        if fp_dropout_rate != 0.0:
            logging.info('Using dropout whose rate is {}...'
                         .format(fp_dropout_rate))
        if fp_batch_normalization:
            logging.info('Using batch normalization in dynamic '
                         'fingerprint...')
        if concat_hidden:
            logging.info('Incorporating layer aggregation via concatenation '
                         'after readout...')
        if layer_aggregator:
            logging.info('Incorporating layer aggregation via {} before '
                         'readout...'.format(layer_aggregator))
        if context:
            logging.info('Context embedding is utilized...')
            logging.info('Number of context layers is {}...'
                         .format(context_layers))
            logging.info('Dropout rate of context layers is {:.2f}'
                         .format(context_dropout))
        logging.info('Message function is {}'.format(message_function))
        logging.info('Readout function is {}'.format(readout_function))
        logging.info(
            'Num_timesteps = {}, num_output_hidden_layers={}, '
            'output_hidden_dim={}'.format(
                num_timesteps, num_output_hidden_layers, output_hidden_dim))
        encoder = GGNN(out_dim=fp_out_dim, hidden_dim=fp_hidden_dim,
                       n_layers=conv_layers, concat_hidden=concat_hidden,
                       layer_aggregator=layer_aggregator,
                       dropout_rate=fp_dropout_rate,
                       batch_normalization=fp_batch_normalization,
                       use_attention=fp_attention, weight_tying=weight_typing,
                       attention_tying=attention_tying, context=context,
                       message_function=message_function,
                       readout_function=readout_function,
                       num_timesteps=num_timesteps,
                       num_output_hidden_layers=num_output_hidden_layers,
                       output_hidden_dim=output_hidden_dim,
                       output_activation=output_activation)
    elif method == 'nfp':
        fp_max_degree = 6
        logging.info('Training an NFP predictor...')
        logging.info('Max degree is {}'.format(fp_max_degree))
        encoder = NFP(out_dim=fp_out_dim, hidden_dim=fp_hidden_dim,
                      n_layers=conv_layers, concat_hidden=concat_hidden,
                      max_degree=fp_max_degree)
    elif method == 'schnet':
        logging.info('Training a SchNet predictor...')
        encoder = SchNet(out_dim=class_num, hidden_dim=fp_hidden_dim,
                         n_layers=conv_layers)
    elif method == 'weavenet':
        logging.info('Training a WeaveNet predictor...')
        n_atom = 20
        n_sub_layer = 1
        weave_channels = [50] * conv_layers
        encoder = WeaveNet(weave_channels=weave_channels,
                           hidden_dim=fp_hidden_dim,
                           n_sub_layer=n_sub_layer, n_atom=n_atom)
    elif method == 'rsgcn':
        logging.info('Training an RSGCN predictor...')
        use_batch_norm = True
        dropout_ratio = 0.5
        if use_batch_norm:
            logging.info('Using batch normalization...')
        logging.info('Dropout ratio is {:.1f}'.format(dropout_ratio))
        encoder = RSGCN(out_dim=fp_out_dim, hidden_dim=fp_hidden_dim,
                        n_layers=conv_layers, use_batch_norm=use_batch_norm,
                        dropout_ratio=dropout_ratio)
    elif method == 'gin':
        encoder = GIN(out_dim=fp_out_dim, hidden_dim=fp_hidden_dim,
                      n_layers=conv_layers, dropout_ratio=0.5,
                      concat_hidden=True)
    else:
        raise ValueError('[ERROR] Invalid method: {}'.format(method))

    predictor = GraphConvPredictorForPair(encoder, mlp, symmetric=symmetric)
    return predictor
def set_up_predictor(method, fp_hidden_dim, fp_out_dim, conv_layers,
                     concat_hidden, fp_dropout_rate, fp_batch_normalization,
                     net_hidden_dims, class_num, weight_typing=True,
                     sim_method='mlp', symmetric=None):
    sim_method_dict = {
        'mlp': 'multi-layered perceptron',
        'ntn': 'bilinear transform',
        'symmlp': 'symmetric perceptron',
        'hole': 'holographic embedding',
        'dist-mult': 'dist-mult',
    }
    method_dict = {
        'ggnn': 'GGNN',
        'relgcn': 'RelGCN',
        'mpnn': 'MPNN',
    }
    logging.info('Graph Embedding: {}'.format(method_dict.get(method, None)))
    logging.info('Link Prediction: {}'.format(
        sim_method_dict.get(sim_method, None)))

    lp = None
    if sim_method == 'mlp':
        lp = MLP(out_dim=class_num, hidden_dims=net_hidden_dims)
    elif sim_method == 'ntn':
        ntn_out_dim = 8
        lp = NTN(left_dim=fp_out_dim, right_dim=fp_out_dim,
                 out_dim=class_num, ntn_out_dim=ntn_out_dim,
                 hidden_dims=net_hidden_dims)
    elif sim_method == 'symmlp':
        lp = MLP(out_dim=class_num, hidden_dims=net_hidden_dims)
    elif sim_method == 'hole':
        lp = HolE(out_dim=class_num, hidden_dims=net_hidden_dims)
    elif sim_method == 'dist-mult':
        dm_out_dim = 8
        lp = DistMult(left_dim=fp_out_dim, right_dim=fp_out_dim,
                      out_dim=class_num, dm_out_dim=dm_out_dim,
                      hidden_dims=net_hidden_dims)
    else:
        raise ValueError(
            '[ERROR] Invalid link prediction model: {}'.format(sim_method))

    encoder = None
    if method == 'ggnn':
        if not weight_typing:
            logging.info('Weight is not tied')
        if fp_dropout_rate != 0.0:
            logging.info('Forward propagation dropout rate is {:.1f}'
                         .format(fp_dropout_rate))
        if fp_batch_normalization:
            logging.info('Using batch normalization')
        if concat_hidden:
            logging.info('Using concatenation between layers')
        encoder = GGNN(out_dim=fp_out_dim, hidden_dim=fp_hidden_dim,
                       n_layers=conv_layers, concat_hidden=concat_hidden,
                       weight_tying=weight_typing)
    elif method == 'relgcn':
        encoder = RelGCN(out_channels=fp_out_dim, scale_adj=True)
    elif method == 'mpnn':
        if not weight_typing:
            logging.info('Weight is not tied')
        if concat_hidden:
            logging.info('Using concatenation between layers')
        encoder = MPNN(out_dim=fp_out_dim, hidden_dim=fp_hidden_dim,
                       n_layers=conv_layers, concat_hidden=concat_hidden,
                       weight_tying=weight_typing, readout_func='ggnn')
    else:
        raise ValueError(
            '[ERROR] Invalid graph embedding encoder: {}'.format(method))

    predictor = GraphConvPredictorForPair(encoder, lp, symmetric=symmetric)
    return predictor
def train_and_eval(self, entity2idx, language_model):
    logger.info("Training the %s model..." % model_name)
    self.entity_idxs = {d.entities[i]: i for i in range(len(d.entities))}

    matrix_entity_len = len(d.entities)
    logger.debug(f'matrix_entity_len: {matrix_entity_len}')
    weights_entity_matrix = np.zeros((matrix_entity_len, 300))
    entities_found = 0
    for entity_idx in self.entity_idxs.keys():
        i = self.entity_idxs[entity_idx]
        entity_found = False
        embedding = []
        try:
            entity_string = entity2idx[str(entity_idx)]
            for entity in entity_string:
                embedding.append(language_model[entity])
            entity_found = True
            weights_entity_matrix[i] = np.array(embedding).mean(axis=0)
        except KeyError:
            if not embedding:
                # No token of the entity is covered by the language model;
                # fall back to a scaled random vector.
                weights_entity_matrix[i] = np.random.randn(300,) * np.sqrt(
                    1 / (300 - 1))
            else:
                weights_entity_matrix[i] = np.array(embedding).mean(axis=0)
        finally:
            if entity_found:
                entities_found += 1
    logger.info(f'number of entities found: {entities_found}')
    logger.info(
        f'number of unique entities: {len(self.entity_idxs.keys())}')
    logger.info(
        'entity pre-trained vector coverage: '
        f'{(entities_found / len(self.entity_idxs.keys()) * 100):.2f}%')
    self.entity_weights = weights_entity_matrix
    logger.debug(f'weights_entity_matrix size: {weights_entity_matrix.size}')

    self.relation_idxs = {
        d.relations[i]: i for i in range(len(d.relations))
    }
    relation2idx = {
        str(i): re.sub(r'[/]', ' ', d.relations[i]).strip().split()
        for i in range(len(d.relations))
    }
    for idx in relation2idx:
        relation2idx[idx] = [item.split('_') for item in relation2idx[idx]]

    matrix_relation_len = len(d.relations)
    logger.debug(f'matrix_relation_len: {matrix_relation_len}')
    weights_relation_matrix = np.zeros((matrix_relation_len, 300))
    relations_found = 0
    for relation_idx in self.relation_idxs.keys():
        i = self.relation_idxs[relation_idx]
        document = []
        embedding = []
        relation_found = False
        try:
            document_string = relation2idx[str(i)]
            for relation_string in document_string:
                logger.debug(f'relation_string: {relation_string}')
                for relation in relation_string:
                    logger.debug(f'relation: {relation}')
                    document.append(language_model[relation])
                embedding.append(np.array(document).mean(axis=0))
            weights_relation_matrix[i] = np.array(embedding).mean(axis=0)
            relation_found = True
        except KeyError:
            if not embedding:
                weights_relation_matrix[i] = np.random.randn(300,) * np.sqrt(
                    1 / (300 - 1))
            else:
                weights_relation_matrix[i] = np.array(embedding).mean(axis=0)
        finally:
            if relation_found:
                relations_found += 1
    logger.info(f'number of relations found: {relations_found}')
    logger.info(
        f'number of unique relations: {len(self.relation_idxs.keys())}')
    logger.info(
        'relation pre-trained vector coverage: '
        f'{(relations_found / len(self.relation_idxs.keys()) * 100):.2f}%')
    self.relation_weights = weights_relation_matrix
    logger.debug(
        f'weights_relation_matrix size: {weights_relation_matrix.size}')

    train_data_idxs = self.get_data_idxs(d.train_data)
    logger.info("Number of training data points: %d" % len(train_data_idxs))

    if model_name.lower() == "hype":
        model = HypE(d, self.ent_vec_dim, self.rel_vec_dim, **self.kwargs)
    elif model_name.lower() == "hyper":
        model = HypER(d, self.ent_vec_dim, self.rel_vec_dim,
                      weights_entity_matrix, weights_relation_matrix,
                      **self.kwargs)
    elif model_name.lower() == "distmult":
        model = DistMult(d, self.ent_vec_dim, self.rel_vec_dim, **self.kwargs)
    elif model_name.lower() == "conve":
        model = ConvE(d, self.ent_vec_dim, self.rel_vec_dim, **self.kwargs)
    elif model_name.lower() == "complex":
        model = ComplEx(d, self.ent_vec_dim, self.rel_vec_dim, **self.kwargs)

    print([value.numel() for value in model.parameters()])
    if self.cuda:
        model.cuda()
    model.init()
    opt = torch.optim.Adam(model.parameters(), lr=self.learning_rate)
    if self.decay_rate:
        scheduler = ExponentialLR(opt, self.decay_rate)

    er_vocab = self.get_er_vocab(train_data_idxs)
    er_vocab_pairs = list(er_vocab.keys())
    print(len(er_vocab_pairs))

    logger.info("Starting training...")
    for epoch in range(1, self.num_iterations + 1):
        logger.info(f'EPOCH: {epoch}')
        model.train()
        losses = []
        np.random.shuffle(er_vocab_pairs)
        for j in range(0, len(er_vocab_pairs), self.batch_size):
            data_batch, targets = self.get_batch(er_vocab, er_vocab_pairs, j)
            opt.zero_grad()
            e1_idx = torch.tensor(data_batch[:, 0])
            r_idx = torch.tensor(data_batch[:, 1])
            logger.debug(f'targets size: {targets.size()}')
            if self.cuda:
                e1_idx = e1_idx.cuda()
                r_idx = r_idx.cuda()
            predictions = model.forward(e1_idx, r_idx)
            logger.debug(f'logits size: {predictions.size()}')
            if self.label_smoothing:
                targets = ((1.0 - self.label_smoothing) * targets) + (
                    1.0 / targets.size(1))
            loss = self.loss(predictions, targets)
            loss.backward()
            opt.step()
            losses.append(loss.item())
        if self.decay_rate:
            scheduler.step()
        logger.info(f'Loss: {np.mean(losses)}')

        model.eval()
        with torch.no_grad():
            logger.info("Validation:")
            self.evaluate(model, d.valid_data)
            if (epoch % 10) == 0:
                logger.info("Test:")
                self.evaluate(model, d.test_data)

    model.eval()
    logger.info("Test:")
    self.evaluate(model, d.test_data)
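
# A standalone sketch (hypothetical vectors, not from the original code) of
# the embedding-initialization scheme in train_and_eval above: a multi-token
# entity gets the mean of its tokens' word vectors, and an entity with no
# covered tokens falls back to a random vector scaled by sqrt(1 / (dim - 1)).
import numpy as np

dim = 300
language_model = {'new': np.full(dim, 0.2), 'york': np.full(dim, 0.4)}

vectors = [language_model[t] for t in ['new', 'york']]
entity_vec = np.array(vectors).mean(axis=0)  # elementwise mean: 0.3
print(entity_vec[:3])                        # [0.3 0.3 0.3]

fallback = np.random.randn(dim) * np.sqrt(1 / (dim - 1))
print(round(fallback.std(), 3))              # roughly 1/sqrt(299), ~0.058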