def execute(self, query):
    # self.sparql.setQuery(query)
    # self.sparql.setReturnFormat(JSON)
    # self.sparql.setMethod('GET')
    # results = self.sparql.queryAndConvert()
    # # target_vars=[var[1:] for var in target_vars]+['c']
    # vars = results["head"]["vars"]
    params = {
        'query': query,
        'format': 'json',
        'timeout': 0
        # run = +Run + Query +
    }
    results = self.session.get(self.endpoint, params=params)
    if results.status_code != 200:
        logger.error('sparql query returned an error %r' % results)
        raise Exception(results)
    results = results.json()
    vars = results["head"]["vars"]
    results_formatted = [[result[var]["value"] for var in vars]
                         for result in results["results"]["bindings"]]
    logger.debug(results)
    logger.debug(results_formatted)
    return results_formatted

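# Hedged usage sketch (not part of the original module): a standalone illustration of the same
# requests-based SPARQL JSON pattern used by execute() above. The endpoint URL, function name,
# and query are assumptions; any Virtuoso-style endpoint accepting 'query'/'format' parameters
# should behave similarly. One row is returned per binding, columns ordered as in
# results["head"]["vars"].
def _example_sparql_select(endpoint='http://localhost:8890/sparql',
                           query='SELECT ?s ?p ?o WHERE { ?s ?p ?o } LIMIT 5'):
    import requests  # local import so the sketch stays self-contained
    response = requests.get(endpoint, params={'query': query, 'format': 'json', 'timeout': 0})
    response.raise_for_status()
    data = response.json()
    variables = data["head"]["vars"]
    return [[binding[v]["value"] for v in variables]
            for binding in data["results"]["bindings"]]
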
def _filter_output_descriptions(self, description: Description):
    if description.preds[-1] in self.categorical_relations:
        if is_var(description.get_dangling_arg()):
            logger.debug("Avoid Desc: %s" % description.str_readable())
            return False
    return True

def construct_count_query(self, head, description, not_head=False):
    target_var = description.get_target_var()

    # select part
    select_part = 'select count(distinct ' + str(target_var) + ') as ?c '
    from_part = self.construct_from_part()

    # where part
    query_conditions = description.as_tuples()
    # print(query_conditions)
    # print(head)

    # prevent ->?xi<-
    filter_part = []
    if not_head:
        filter_part = [
            'FILTER(' + str('?y') + '!=' + self._safe_repr(str(head[2])) + ').'
        ]

    where_part = 'WHERE { ' + str(target_var) + ' ' \
                 + self._safe_repr(head[1]) + ' ' + \
                 ('?y' if not_head else self._safe_repr(head[2])) + '. ' + \
                 ' '.join([(' '.join(map(self._safe_repr, x))) + '. '
                           for x in query_conditions]) + \
                 ' '.join(filter_part) + '} '

    query = select_part + from_part + where_part
    logger.debug(query)
    return query

def get_predicates(self, description: Description2):
    var_predicates = description.get_var_predicates()
    query_conditions = description.body
    query = 'SELECT DISTINCT %s FROM <%s> Where {%s}' % (
        ' '.join(var_predicates), self.identifiers[0],
        ' '.join(map(lambda a: a.tuple_sparql_repr(), query_conditions)))
    logger.debug('Query: %s' % query)
    predicates = list(chain.from_iterable(self.execute(query)))
    logger.debug('Predicates: %i' % len(predicates))
    return predicates

def get_predicates(self, description: Description):
    var_predicates = description.get_var_predicates()
    query_conditions = description.as_tuples()
    query = 'SELECT DISTINCT %s FROM <%s> Where {%s}' % (
        ' '.join(var_predicates), self.identifiers[0],
        ' '.join([(' '.join(map(self._safe_repr, x))) + '. ' for x in query_conditions]))
    logger.debug('Query: %s' % query)
    predicates = list(chain.from_iterable(self.execute(query)))
    logger.debug('Predicates: %i' % len(predicates))
    return predicates

def get_entities(self, description, limit=100):
    var_args = description.get_var_args()
    query_conditions = description.as_tuples()
    query = 'SELECT DISTINCT %s FROM <%s> Where {%s} limit %i' % (
        ' '.join(var_args), self.identifiers[0],
        ' '.join([(' '.join(map(self._safe_repr, x))) + '. ' for x in query_conditions]),
        limit)
    logger.debug('Query: %s' % query)
    entities = list(chain.from_iterable(self.execute(query)))
    logger.debug('Entities: %i' % len(entities))
    return entities

def execute(self, query):
    self.sparql.setQuery(query)
    self.sparql.setReturnFormat(JSON)
    results = self.sparql.query().convert()
    # target_vars=[var[1:] for var in target_vars]+['c']
    vars = results["head"]["vars"]
    results_formatted = [[result[var]["value"] for var in vars]
                         for result in results["results"]["bindings"]]
    logger.debug(results)
    logger.debug(results_formatted)
    return results_formatted

def get_entities(self, description: Description2, limit=100):
    # var_args = description.get_var_args()
    # TODO verify!
    var_args = description.get_uniq_var_args()
    query_conditions = description.body
    query = 'SELECT DISTINCT %s FROM <%s> Where {%s} limit %i' % (
        ' '.join(var_args), self.identifiers[0],
        ' '.join(map(lambda a: a.tuple_sparql_repr(), query_conditions)), limit)
    logger.debug('Query: %s' % query)
    entities = list(chain.from_iterable(self.execute(query)))
    logger.debug('Entities: %i' % len(entities))
    return entities

def init_miner(self):
    logger.debug("Data Statistics")
    stats = np.array(self.query_interface.get_data_stats(), dtype=object)
    print(stats.dtype)
    stats[:, [0, 2]] = stats[:, [0, 2]].astype(int)
    relative_stats = np.column_stack([stats[:, 0] / stats[:, 2], stats[:, 1],
                                      stats[:, 2] / stats[:, 0]])
    self.relations_with_const_object = self.relations_with_const_object + \
        list(relative_stats[relative_stats[:, 0] > 10, 1])
    self.categorical_relations = self.categorical_relations + \
        list(relative_stats[relative_stats[:, 0] > 50, 1])
    # print(relations_with_const_object)
    # print(categorical_relations)
    logger.debug("Data Statistics")

def construct_query(self, description: Description2, min_coverage, per_pattern_limit):
    """
    Construct a binding query for the variable predicates or arguments in the description rule.

    :param description:
    :param min_coverage:
    :param per_pattern_limit:
    :return:
    """
    # query_predicates = description.preds
    # query_predicates = description.get_predicates()
    logger.debug("get_pattern_bindings: %r" % description)

    target_var = description.get_target_var()  # aka the anchor var or counting var
    # predict_directions = description.get_predicates_directions()

    # Check if it should bind predicates or arguments
    bind_vars = description.get_var_predicates()  # list(filter(is_var, query_predicates))
    if not bind_vars:
        bind_vars = [description.get_dangling_arg()]
    # bind_vars = description.get_bind_vars()

    # select part
    select_part = 'select distinct ' + ' '.join(bind_vars) + \
                  ' (count(distinct ' + target_var + ') as ?c)'
    from_part = self.construct_from_part()

    # where part: head + body
    filter_part = self.create_filter_part(description)
    # where_part = 'WHERE {' + ' '.join(map(tuple_sparql_repr, query_conditions)) + ' '.join(filter_part) + '} '
    query_conditions = [description.head] + description.body
    where_part = self._create_where_str(query_conditions, filter_part)

    # group by
    group_by = self._create_groupby_str(bind_vars)
    having = '' if min_coverage <= 0 else \
        ' HAVING (count(distinct ' + target_var + ') >' + str(min_coverage) + ')'
    limit = '' if per_pattern_limit <= 0 else 'LIMIT ' + str(per_pattern_limit)
    suffix = ' ORDER BY desc(?c) ' + limit

    query = select_part + from_part + where_part + group_by + having + suffix
    logger.debug(query)
    return query

def consruct_confusion_matrix(y_true_str, y_pred):
    # Build a confusion matrix: rows are (remapped) predicted labels,
    # columns are (remapped) ground-truth labels.
    ids_true = set(y_true_str)
    ids_true_map = {e: i for e, i in zip(ids_true, range(0, len(ids_true)))}
    ids_predict = set(y_pred)
    ids_predict_map = {e: i for e, i in zip(ids_predict, range(0, len(ids_predict)))}
    logger.debug("Ground Truth ids: " + str(ids_true_map))
    logger.debug("Predicted Labels: " + str(ids_predict_map))
    y_true = y_true_str

    assert y_pred.size == y_true.size

    D = max(len(ids_true), len(ids_predict))  # + 1
    w = np.zeros((D, D), dtype=np.int64)
    for i in range(y_pred.size):
        w[ids_predict_map[y_pred[i]], ids_true_map[y_true[i]]] += 1
    return w

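# Hedged usage sketch (illustrative, not part of the original code): the labels below are
# assumptions and only show the expected input shapes. Both inputs must have the same size;
# predicted labels index the rows, ground-truth labels the columns.
def _example_confusion_matrix():
    import numpy as np
    y_true = np.array(['person', 'person', 'place', 'place'])
    y_pred = np.array([0, 0, 1, 0])
    return consruct_confusion_matrix(y_true, y_pred)  # 2x2 count matrix
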
def _bind_patterns(self, level_query_patterns, min_support):
    level_descriptions = []
    for query_pattern in level_query_patterns:
        res = self.query_interface.get_pattern_bindings(query_pattern, min_support,
                                                        self.per_pattern_binding_limit)
        for r in res:
            description = deepcopy(query_pattern)
            if len(description.get_var_predicates()) > 0:  # it is a predicate
                description.get_last_atom().predicate = str(r[0])
            else:
                description.set_dangling_arg(str(r[0]))
            logger.debug("** After binding: %s" % str(description))
            description.target_head_support = int(r[1])
            level_descriptions.append(description)
    return level_descriptions

def micro_acc_triples(gt_triples, predict_triples):
    # Micro accuracy of a clustering given as (entity, relation, label) triples:
    # clusters are matched to ground-truth classes with the Hungarian algorithm.
    gt_labels = defaultdict(lambda: len(gt_labels))
    predict_labels = defaultdict(lambda: len(predict_labels))
    gt_dict = {t[0]: gt_labels[t[2]] for t in gt_triples}
    predict_dict = {t[0]: predict_labels[t[2]] for t in predict_triples}
    logger.debug("GT Triples: %i Predict Triples: %i" % (len(gt_dict), len(predict_dict)))

    assert len(gt_dict) == len(predict_dict)

    D = max(len(gt_labels), len(predict_labels))  # + 1
    w = np.zeros((D, D), dtype=np.int64)
    for k in predict_dict:
        w[predict_dict[k], gt_dict[k]] += 1

    row_ind, col_ind = linear_sum_assignment(w.max() - w)
    return w[row_ind, col_ind].sum() * 1.0 / len(predict_dict)

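# Hedged usage sketch (illustrative): entities, relations, and cluster names are assumptions.
# Both triple lists are keyed by the subject entity; the score is the best achievable accuracy
# under an optimal one-to-one assignment of predicted clusters to ground-truth classes.
def _example_micro_acc():
    gt = [('e1', 'type', 'person'), ('e2', 'type', 'person'), ('e3', 'type', 'place')]
    pred = [('e1', 'belongsTo', 'c0'), ('e2', 'belongsTo', 'c0'), ('e3', 'belongsTo', 'c1')]
    # -> 1.0 for a clustering that matches the ground truth up to relabelling
    return micro_acc_triples(gt, pred)
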
def cluster(self):
    logger.info("Start clustering")
    entity_vectors = self.current_itr.target_entities_embeddings
    logger.debug(entity_vectors)
    logger.info("size of the data " + str(entity_vectors.shape))

    y_pred = self.clustering_method.cluster(entity_vectors,
                                            clustering_params=self.clustering_params,
                                            output_folder=self.get_current_itr_directory())

    triples = EntityLabelsToTriples(np.column_stack((self.target_entities.get_entities(),
                                                     y_pred.reshape(-1, 1))),
                                    iter_id=self.current_itr.id)

    if self.save_steps:
        output_file = os.path.join(self.get_current_itr_directory(), 'clustering.tsv')
        output_vecs_file = os.path.join(self.get_current_itr_directory(), 'embeddings_vecs.tsv')
        write_triples(triples, output_file)
        write_vectors(entity_vectors, output_vecs_file)
        output_labels_file = os.path.join(self.get_current_itr_directory(),
                                          'clustering_labels_only.tsv')
        write_vectors(y_pred.reshape(-1, 1), output_labels_file)

    return triples

def load_dict(self, filename, flipped=False):
    if os.path.exists(filename):
        # logger.info("Exists!" + str(os.path.exists(filename)))
        logger.debug("Loading %s into dictionary! (Flipped: %r)" % (filename, flipped))
        with open(filename) as fh:
            # fh.readline()
            lines = (line.split(None, 1) if '\t' in line else ['q', -1] for line in fh)
            if flipped:
                mapping = dict((int(number), word) for word, number in lines)
            else:
                mapping = dict((word, int(number)) for word, number in lines)
        logger.debug("Dictionary Size: %i" % len(mapping))
        # print(mapping[68])
        return mapping
    return dict()

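# Hedged usage sketch (illustrative): the file name and entity labels are assumptions.
# load_dict expects one "label<TAB>id" pair per line; flipped=True returns an id-to-label map.
def _example_load_dict(loader, path='entity2id.tsv'):
    ent_to_idx = loader.load_dict(path)                 # e.g. {'dbr:Berlin': 0, ...}
    idx_to_ent = loader.load_dict(path, flipped=True)   # e.g. {0: 'dbr:Berlin', ...}
    return ent_to_idx, idx_to_ent
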
def get_arguments_bindings(self, description: Description2,
                           restriction_pattern: Description2 = None):
    """
    Do the inference and generate bindings for the head variable.

    :param description:
    :param restriction_pattern: pattern whose tuples restrict the variable bindings
    :return:
    """
    # print(description)
    # logger.debug("Get bindings for Description: %r" % description)
    restriction_pattern = restriction_pattern if restriction_pattern else Description2()
    query = self.construct_argument_bind_query(description, restriction_pattern)
    res = self.execute(query)
    res = [r[0] for r in res]
    logger.debug("results size: %i", len(res))
    return res

def _expand_pattern(self, pattern_to_expand, bind_only_const, i):
    # print('Pattern\n%s' % pattern_to_expand.str_readable())
    level_query_patterns = []

    # Do not extend the pattern if it is a constant-binding iteration
    if not bind_only_const:
        # in_edge and out_edge
        level_query_patterns += [
            deepcopy(pattern_to_expand),
            deepcopy(pattern_to_expand)
        ]
        for d in level_query_patterns:
            # add predicate variable (can be a fixed predicate)
            d.preds.append('?p' + str(i))
            # add variable (can be changed to repeat the variables)
            d.args.append('?x' + str(i))
        # add edge direction
        list(map(lambda d, pred_d: d.pred_direct.append(pred_d),
                 level_query_patterns, [True, False]))

    # if the last predicate is a relation that has interesting constants, check constants
    # TODO the pattern size > 1 is an ad hoc solution to avoid having simple explanations
    if pattern_to_expand.size() > 1 and pattern_to_expand.pred_direct[-1] and \
            pattern_to_expand.preds[-1] in self.relations_with_const_object:
        # print('Extend to bind constants if not already\n%s' % pattern_to_expand.str_readable())
        # if not yet expanded
        if is_var(pattern_to_expand.get_dangling_arg()):
            level_query_patterns += [deepcopy(pattern_to_expand)]
            logger.debug(str(level_query_patterns[-1]))
            # print('Extended to bind constants\n%s' % pattern_to_expand.str_readable())

    return level_query_patterns

def construct_argument_bind_query(self, description: Description2,
                                  restriction_pattern=Description2()):
    target_var = description.get_target_var()

    # select part
    select_part = 'select distinct ' + target_var
    from_part = self.construct_from_part()

    # where part
    query_conditions = description.body + restriction_pattern.body
    filter_part = self.create_filter_part(description)
    # where_part = 'WHERE { ' + ' '.join(map(lambda a: a.tuple_sparql_repr(), query_conditions)) + '} '
    where_part = self._create_where_str(query_conditions, filter_part)

    query = select_part + from_part + where_part
    logger.debug(query)
    return query

def construct_argument_bind_query(self, description, restriction_pattern=Description()):
    target_var = description.get_target_var()

    # select part
    select_part = 'select distinct ' + target_var
    from_part = self.construct_from_part()

    # where part
    query_conditions = description.as_tuples() + restriction_pattern.as_tuples()
    where_part = 'WHERE { ' + ' '.join([(' '.join(map(self._safe_repr, x))) + '. '
                                        for x in query_conditions]) + '} '

    query = select_part + from_part + where_part
    logger.debug(query)
    return query

def construct_count_query(self, description: Description2, alternative_head=None, not_head=False):
    head = alternative_head if alternative_head else description.head
    target_var = description.get_target_var()

    # select part
    select_part = 'select count(distinct ' + str(target_var) + ') as ?c '
    from_part = self.construct_from_part()

    # where part
    query_conditions = description.body
    # print(query_conditions)
    # print(head)

    # prevent ->?xi<-
    # head triple: bind the object to a free variable ?y when counting the negated head
    filter_part = [str(target_var) + ' ' + _sparql_repr(head.predicate) + ' ' +
                   ('?y' if not_head else _sparql_repr(head.object)) + '. ']
    if not_head:
        filter_part += [
            'FILTER(' + str('?y') + '!=' + _sparql_repr(head.object) + ').'
        ]
    filter_part += self.create_filter_part(description)

    # where_part = 'WHERE { ' +
    # ' '.join(map(lambda a: a.tuple_sparql_repr(), query_conditions)) + ' '.join(filter_part) + '} '
    where_part = self._create_where_str(query_conditions, filter_part)

    query = select_part + from_part + where_part
    logger.debug(query)
    return query

def restore_model(model_name_path=None, module_name="ampligraph.latent_features"):
    """Restore a saved model from disk.

    See also :meth:`save_model`.

    Parameters
    ----------
    model_name_path: string
        The name of saved model to be restored. If not specified,
        the library will try to find the default model in the working directory.

    Returns
    -------
    model: EmbeddingModel
        the neural knowledge graph embedding model restored from disk.
    """
    if model_name_path is None:
        logger.warning("There is no model name specified. "
                       "We will try to lookup the latest default saved model...")
        default_models = glob.glob("*.model.pkl")
        if len(default_models) == 0:
            raise Exception("No default model found. Please specify model_name_path...")
        else:
            model_name_path = default_models[len(default_models) - 1]
            logger.info("We will load the model: {0} in your "
                        "current dir...".format(model_name_path))

    model = None
    logger.info('Will load model {}.'.format(model_name_path))

    try:
        with open(model_name_path, 'rb') as fr:
            restored_obj = pickle.load(fr)

        logger.debug('Restoring model ...')
        module = importlib.import_module(module_name)
        class_ = getattr(module, restored_obj['class_name'].replace('Continue', ''))
        model = class_(**restored_obj['hyperparams'])
        model.is_fitted = restored_obj['is_fitted']
        model.ent_to_idx = restored_obj['ent_to_idx']
        model.rel_to_idx = restored_obj['rel_to_idx']

        try:
            model.is_calibrated = restored_obj['is_calibrated']
        except KeyError:
            model.is_calibrated = False

        model.restore_model_params(restored_obj)
    except pickle.UnpicklingError as e:
        msg = 'Error unpickling model {} : {}.'.format(model_name_path, e)
        logger.debug(msg)
        raise Exception(msg)
    except (IOError, FileNotFoundError):
        msg = 'No model found: {}.'.format(model_name_path)
        logger.debug(msg)
        raise FileNotFoundError(msg)
    return model

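# Hedged usage sketch (illustrative): the file name is an assumption. restore_model() falls
# back to the newest '*.model.pkl' in the working directory when no path is given.
def _example_restore(path='best_model.pkl'):
    return restore_model(model_name_path=path)
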
def mine_with_constants(self, head, max_length=2, min_coverage=-1.0, negative_heads=None):
    if isinstance(head, tuple):
        head = Atom(head[0], head[1], head[2])
    negative_heads = negative_heads if negative_heads else []
    logger.info('Learn descriptions for ' + str(head))
    # start_var = head.subject if head.subject else '?x'
    descriptions = []

    # for evaluation
    target_head_size = self.query_interface.count(Description2(head=head))
    # logger.info('Target head size %i' % target_head_size)
    min_support = int(min_coverage * target_head_size)
    # print(negative_heads)
    negative_heads_sizes = [self.query_interface.count(Description2(head=neg_head))
                            for neg_head in negative_heads]
    # logger.info('Negative head sizes %r' % negative_heads_sizes)

    base_description = Description2(head=head)
    previous_level_descriptions = [base_description]

    # TODO the last iteration will be only to bind constants in the last predicate
    # (a better way is to be implemented)
    # const_iteration = max_length + 1
    for i in range(1, max_length + 1):
        logger.info("Discovering Level: %i" % i)
        level_descriptions = []
        for cur_pattern in previous_level_descriptions:
            logger.debug('Expand Description Pattern: %r', cur_pattern)
            # expand()
            description_extended_patterns = self._expand_pattern(cur_pattern, i)
            logger.debug('Expanded patterns Size: %i' % len(description_extended_patterns))

            # bind predicates
            query_bind_descriptions = self._bind_patterns(description_extended_patterns, min_support)

            # bind args if required
            descriptions_with_constants = self._get_patterns_with_bindable_args(query_bind_descriptions)
            query_bind_descriptions += self._bind_patterns(descriptions_with_constants, min_support)

            # Prune bound descriptions
            query_bind_descriptions = list(filter(self._filter_level_descriptions, query_bind_descriptions))

            # Add quality scores to bound descriptions
            self._add_quality_to_descriptions(query_bind_descriptions, target_head_size,
                                              negative_heads, negative_heads_sizes)

            level_descriptions += query_bind_descriptions

        # Remove descriptions that are identical up to body atom order
        # WARN: may not work well because of the trivial implementation of
        # __eq__ and __hash__ of Description2
        level_descriptions = set(level_descriptions)

        # TODO make the filter global
        descriptions += list(filter(self._filter_output_descriptions, level_descriptions))
        previous_level_descriptions = level_descriptions
        logger.info("Done level: " + str(i) + ' level descriptions: ' +
                    str(len(level_descriptions)) + ' total descriptions: ' + str(len(descriptions)))

    return descriptions

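# Hedged usage sketch (illustrative): the miner construction, cluster predicate, and cluster
# labels are assumptions, not part of this module. The positive head is given as a
# (subject, predicate, object) tuple and is converted to an Atom internally; negative heads
# restrict quality scoring to competing cluster labels.
def _example_mine(miner):
    head = ('?x', 'http://example.org/belongsToCluster', 'cluster_0')
    negatives = [Atom('?x', 'http://example.org/belongsToCluster', 'cluster_1')]
    return miner.mine_with_constants(head, max_length=2, min_coverage=0.1,
                                     negative_heads=negatives)
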
def _prepare_data(self, clusters_as_triples):
    logger.debug("Indexing clustering results into %s" % self.indexer.identifier)
    self.indexer.index_triples(clusters_as_triples, drop_old=True)
    logger.debug("Done indexing clustering results!")

def fit(self, X, early_stopping=False, early_stopping_params={}, continue_training=False):
    """Train an EmbeddingModel (with optional early stopping).

    The model is trained on a training set X using the training protocol
    described in :cite:`trouillon2016complex`.

    Parameters
    ----------
    X : ndarray (shape [n, 3]) or object of AmpligraphDatasetAdapter
        Numpy array of training triples OR handle of Dataset adapter which would help retrieve data.
    early_stopping: bool
        Flag to enable early stopping (default:``False``)
    early_stopping_params: dictionary
        Dictionary of hyperparameters for the early stopping heuristics.

        The following string keys are supported:

        - **'x_valid'**: ndarray (shape [n, 3]) or object of AmpligraphDatasetAdapter :
          Numpy array of validation triples OR handle of Dataset adapter which would help retrieve data.
        - **'criteria'**: string : criteria for early stopping 'hits10', 'hits3', 'hits1' or 'mrr' (default).
        - **'x_filter'**: ndarray, shape [n, 3] : Positive triples to use as filter if a 'filtered'
          early stopping criteria is desired (i.e. filtered-MRR if 'criteria':'mrr').
          Note this will affect training time (no filter by default).
          If the filter has already been set in the adapter, pass True.
        - **'burn_in'**: int : Number of epochs to pass before kicking in early stopping (default: 100).
        - **'check_interval'**: int : Early stopping interval after burn-in (default: 10).
        - **'stop_interval'**: int : Stop if criteria is performing worse over n consecutive checks (default: 3).
        - **'corruption_entities'**: List of entities to be used for corruptions.
          If 'all', it uses all entities (default: 'all').
        - **'corrupt_side'**: Specifies which side to corrupt. 's', 'o', 's+o' (default).

        Example: ``early_stopping_params={'x_valid': X['valid'], 'criteria': 'mrr'}``
    continue_training: bool
        Flag to continue training an already fitted model on additional triples,
        reusing the existing entity/relation mappings and trained parameters.
    """
    self.train_dataset_handle = None
    # try-except block is mainly to handle clean up in case of exception or manual stop in jupyter notebook
    # TODO change 0: update the mappings if there are new entities.
    if continue_training:
        self.update_mapping(X)
    try:
        if isinstance(X, np.ndarray):
            # Adapt the numpy data in the internal format - to generalize
            self.train_dataset_handle = NumpyDatasetAdapter()
            self.train_dataset_handle.set_data(X, "train")
        elif isinstance(X, AmpligraphDatasetAdapter):
            self.train_dataset_handle = X
        else:
            msg = 'Invalid type for input X. Expected ndarray/AmpligraphDataset object, got {}'.format(type(X))
            logger.error(msg)
            raise ValueError(msg)

        # create internal IDs mappings
        # TODO change 1: first change, reuse the existing mappings rel_to_idx and ent_to_idx
        if not continue_training:
            self.rel_to_idx, self.ent_to_idx = self.train_dataset_handle.generate_mappings()
        else:
            self.train_dataset_handle.use_mappings(self.rel_to_idx, self.ent_to_idx)

        prefetch_batches = 1

        if len(self.ent_to_idx) > ENTITY_THRESHOLD:
            self.dealing_with_large_graphs = True

            logger.warning('Your graph has a large number of distinct entities. '
                           'Found {} distinct entities'.format(len(self.ent_to_idx)))
            logger.warning('Changing the variable initialization strategy.')
            logger.warning('Changing the strategy to use lazy loading of variables...')

            if early_stopping:
                raise Exception('Early stopping not supported for large graphs')

            if not isinstance(self.optimizer, SGDOptimizer):
                raise Exception("This mode works well only with SGD optimizer with decay "
                                "(read docs for details). Kindly change the optimizer "
                                "and restart the experiment")

        if self.dealing_with_large_graphs:
            prefetch_batches = 0
            # CPU matrix of embeddings
            # TODO change 2.1: do not initialize if continuing training
            if not continue_training:
                self.ent_emb_cpu = self.initializer.get_np_initializer(len(self.ent_to_idx),
                                                                       self.internal_k)

        self.train_dataset_handle.map_data()

        # This is useful when we re-fit the same model (e.g. retraining in model selection)
        if self.is_fitted:
            tf.reset_default_graph()
            self.rnd = check_random_state(self.seed)
            tf.random.set_random_seed(self.seed)

        self.sess_train = tf.Session(config=self.tf_config)

        # change 2.2: do not change the batch size with new training data, just use the old one (for large KGs)
        # if not continue_training:
        batch_size = int(np.ceil(self.train_dataset_handle.get_size("train") / self.batches_count))
        # else:
        #     batch_size = self.batch_size
        logger.info("Batch Size: %i" % batch_size)
        # dataset = tf.data.Dataset.from_tensor_slices(X).repeat().batch(batch_size).prefetch(2)

        if len(self.ent_to_idx) > ENTITY_THRESHOLD:
            logger.warning('Only {} embeddings would be loaded in memory per batch...'.format(batch_size * 2))

        self.batch_size = batch_size

        # TODO change 3: load the model from trained params if continuing, instead of
        # re-initializing ent_emb and rel_emb
        if not continue_training:
            self._initialize_parameters()
        else:
            self._load_model_from_trained_params()

        dataset = tf.data.Dataset.from_generator(self._training_data_generator,
                                                 output_types=(tf.int32, tf.int32, tf.float32),
                                                 output_shapes=((None, 3), (None, 1),
                                                                (None, self.internal_k)))
        dataset = dataset.repeat().prefetch(prefetch_batches)
        dataset_iterator = tf.data.make_one_shot_iterator(dataset)

        # init tf graph/dataflow for training
        # init variables (model parameters to be learned - i.e. the embeddings)

        if self.loss.get_state('require_same_size_pos_neg'):
            batch_size = batch_size * self.eta

        loss = self._get_model_loss(dataset_iterator)

        train = self.optimizer.minimize(loss)

        # Entity embeddings normalization
        normalize_ent_emb_op = self.ent_emb.assign(tf.clip_by_norm(self.ent_emb, clip_norm=1, axes=1))

        self.early_stopping_params = early_stopping_params

        # early stopping
        if early_stopping:
            self._initialize_early_stopping()

        self.sess_train.run(tf.tables_initializer())
        self.sess_train.run(tf.global_variables_initializer())
        try:
            self.sess_train.run(self.set_training_true)
        except AttributeError:
            pass

        normalize_rel_emb_op = self.rel_emb.assign(tf.clip_by_norm(self.rel_emb, clip_norm=1, axes=1))

        if self.embedding_model_params.get('normalize_ent_emb', constants.DEFAULT_NORMALIZE_EMBEDDINGS):
            self.sess_train.run(normalize_rel_emb_op)
            self.sess_train.run(normalize_ent_emb_op)

        epoch_iterator_with_progress = tqdm(range(1, self.epochs + 1),
                                            disable=(not self.verbose), unit='epoch')

        # print("before epochs!")
        # print(self.sess_train.run(self.ent_emb))
        # print(self.sess_train.run(self.rel_emb))

        for epoch in epoch_iterator_with_progress:
            losses = []
            for batch in range(1, self.batches_count + 1):
                feed_dict = {}
                self.optimizer.update_feed_dict(feed_dict, batch, epoch)
                if self.dealing_with_large_graphs:
                    loss_batch, unique_entities, _ = self.sess_train.run(
                        [loss, self.unique_entities, train], feed_dict=feed_dict)
                    self.ent_emb_cpu[np.squeeze(unique_entities), :] = \
                        self.sess_train.run(self.ent_emb)[:unique_entities.shape[0], :]
                else:
                    loss_batch, _ = self.sess_train.run([loss, train], feed_dict=feed_dict)

                if np.isnan(loss_batch) or np.isinf(loss_batch):
                    msg = 'Loss is {}. Please change the hyperparameters.'.format(loss_batch)
                    logger.error(msg)
                    raise ValueError(msg)

                losses.append(loss_batch)
                if self.embedding_model_params.get('normalize_ent_emb',
                                                   constants.DEFAULT_NORMALIZE_EMBEDDINGS):
                    self.sess_train.run(normalize_ent_emb_op)

            if self.verbose:
                msg = 'Average Loss: {:10f}'.format(sum(losses) / (batch_size * self.batches_count))
                if early_stopping and self.early_stopping_best_value is not None:
                    msg += ' — Best validation ({}): {:5f}'.format(self.early_stopping_criteria,
                                                                   self.early_stopping_best_value)
                logger.debug(msg)
                epoch_iterator_with_progress.set_description(msg)

            if early_stopping:
                try:
                    self.sess_train.run(self.set_training_false)
                except AttributeError:
                    pass

                if self._perform_early_stopping_test(epoch):
                    self._end_training()
                    return

                try:
                    self.sess_train.run(self.set_training_true)
                except AttributeError:
                    pass

        self._save_trained_params()
        self._end_training()
    except BaseException as e:
        self._end_training()
        raise e

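# Hedged usage sketch (illustrative): the model object and triple array are assumptions.
# A first fit() call trains from scratch; a later call with continue_training=True reuses the
# existing ent_to_idx/rel_to_idx mappings and the previously trained parameters.
def _example_continue_training(model, new_triples):
    model.fit(new_triples, continue_training=True)
    return model
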
def construct_query(self, head, description, min_coverage, per_pattern_limit):
    # query_predicates = description.preds
    query_predicates = description.get_predicates()
    target_var = description.get_target_var()
    predict_directions = description.get_predicates_directions()

    # Check if it should bind predicates or arguments
    bind_vars = description.get_var_predicates()  # list(filter(is_var, query_predicates))
    if not bind_vars:
        bind_vars = description.get_bind_args()
    # bind_vars = description.get_bind_vars()

    # select part
    select_part = 'select distinct ' + ' '.join(bind_vars) + \
                  ' (count(distinct ' + target_var + ') as ?c)'
    from_part = self.construct_from_part()

    # where part
    query_conditions = description.as_tuples()

    # remove filter predicate from patterns
    # filter_part = ['FILTER(' + con[1] + '!=' + str(entities_filter[1]) + ').'
    #                if con[0] == start_var else '' for con in query_conditions]
    # print(query_conditions)
    filter_part = ['FILTER(' + p + '!=' + self._safe_repr(str(head[1])) + ').'
                   for p in bind_vars]

    # prevent ->?xi<-
    if len(predict_directions) > 1:
        for i in range(1, len(predict_directions)):
            if predict_directions[i] != predict_directions[i - 1] and is_var(query_predicates[i]):
                filter_part.append('FILTER(' + self._safe_repr(str(query_predicates[i])) + '!=' +
                                   self._safe_repr(str(query_predicates[i - 1])) + ').')

    where_part = 'WHERE {' + ' '.join(map(self._safe_repr, head)) + '. ' + \
                 ' '.join([(' '.join(map(self._safe_repr, x))) + '. ' for x in query_conditions]) + \
                 ' '.join(filter_part) + '} '

    # group by
    group_by = ' GROUP BY ' + ' '.join(bind_vars)

    having = ''
    if min_coverage > 0:
        having = ' HAVING (count(distinct ' + target_var + ') >' + str(min_coverage) + ')'

    limit = ''
    if per_pattern_limit > 0:
        limit = 'LIMIT ' + str(per_pattern_limit)

    suffix = ' ORDER BY desc(?c) ' + limit

    query = select_part + from_part + where_part + group_by + having + suffix
    logger.debug(query)
    return query