def decodeNonDuplicatePredictions(
        self, context: TacticContext,
        all_idxs: List[Tuple[float, int, int]],
        k: int) -> List[Prediction]:
    """Decode scored (stem, argument) index triples into up to k distinct predictions.

    Each entry of ``all_idxs`` is ``(score, stem_idx, arg_idx)``; the score is
    passed through ``math.exp`` to give the prediction's certainty. Entries are
    consumed in the order given (presumably sorted by decreasing score by the
    caller -- TODO confirm). An entry whose decoded tactic string was already
    produced is skipped, so no two returned predictions share a string.

    Decoding stops once k distinct predictions have been collected, once
    ``(1 + #hypotheses + #goal words) * stem_width`` entries have been
    examined, or once ``all_idxs`` is exhausted, whichever comes first.
    """
    assert self.training_args
    num_stem_poss = get_num_tokens(self.metadata)
    stem_width = min(self.training_args.max_beam_width, num_stem_poss)
    if self.training_args.lemma_args:
        all_hyps = context.hypotheses + context.relevant_lemmas
    else:
        all_hyps = context.hypotheses

    # Upper bound on how many leading entries are worth decoding.
    num_valid_probs = \
        (1 + len(all_hyps) + len(get_fpa_words(context.goal))) * stem_width
    # Loop-invariant; build it once instead of once per entry.
    dataloader_args = extract_dataloader_args(self.training_args)

    predictions: List[Prediction] = []
    seen: set = set()  # O(1) duplicate check instead of scanning a list
    # islice also guards against all_idxs being shorter than num_valid_probs,
    # which previously raised IndexError.
    for score, stem_idx, arg_idx in islice(all_idxs, num_valid_probs):
        if len(predictions) >= k:
            break
        pred_str = decode_fpa_result(dataloader_args, self.metadata, all_hyps,
                                     context.goal, stem_idx, arg_idx)
        if pred_str in seen:
            continue
        seen.add(pred_str)
        predictions.append(Prediction(pred_str, math.exp(score)))
    return predictions
def getAllPredictionIdxs_batch(self, contexts: List[TacticContext],
                               verbosity: int = 0
                               ) -> List[List[Tuple[float, int, int]]]:
    """Score every (stem, argument) candidate for each context in a batch.

    Stems and goal-token argument scores are computed once for the whole batch.
    Hypothesis-argument scores are then computed per context, and only when that
    context has tokenized premises. A tqdm progress bar is shown only when
    ``verbosity > 1``.

    Returns one list per context of ``(prob, stem, arg)`` triples, as produced
    by ``self.predict_args``.
    """
    assert self.training_args
    assert self._model
    beam_width = min(self.training_args.max_beam_width,
                     get_num_tokens(self.metadata))
    (premises_batch, premise_feats_batch, _num_hyps_batch, goals_batch,
     goal_masks, word_feats, vec_feats) = sample_fpa_batch(
         extract_dataloader_args(self.training_args), self.metadata,
         [context_py2r(ctx) for ctx in contexts])

    certainties_batch, stems_batch = self.predict_stems(
        beam_width, word_feats, vec_feats)
    goal_scores_batch = self.goal_token_scores(
        stems_batch, goals_batch, goal_masks)

    per_context = zip(certainties_batch, stems_batch, goal_scores_batch,
                      goals_batch, premises_batch, premise_feats_batch)
    results: List[List[Tuple[float, int, int]]] = []
    for certainties, stems, goal_scores, goal, premises, premise_feats in tqdm(
            per_context,
            desc="Assessing hyp args and decoding indices",
            total=len(contexts),
            disable=verbosity <= 1):
        scores = goal_scores.unsqueeze(0)
        if len(premises) > 0:
            # Append hypothesis-name scores after the goal-token scores.
            premise_scores = self.hyp_name_scores(
                stems, goal, premises, premise_feats)
            scores = torch.cat((scores, premise_scores), dim=2)
        probs, chosen_stems, chosen_args = self.predict_args(
            scores, certainties, stems)
        results.append(list(zip(probs, chosen_stems, chosen_args)))
    return results
def load_saved_state(self, args: Namespace, unparsed_args: List[str],
                     metadata: Any, state: NeuralPredictorState) -> None:
    """Rebuild the model from ``metadata`` and restore a saved training state.

    The model is built with ``self._get_model`` using vocabulary and feature
    sizes derived from ``metadata``, wrapped with ``maybe_cuda``, and given
    ``state.weights``. The loss, epoch, arguments and metadata are then stored
    on ``self``.
    """
    model_sizes = (get_word_feature_vocab_sizes(metadata),
                   get_vec_features_size(metadata),
                   get_num_indices(metadata),
                   get_num_tokens(metadata))
    restored = maybe_cuda(self._get_model(args, *model_sizes))
    restored.load_state_dict(state.weights)
    self._model = restored
    self.training_loss, self.num_epochs = state.loss, state.epoch
    self.training_args, self.unparsed_args = args, unparsed_args
    self._metadata = metadata
def getAllPredictionIdxs(self, context: TacticContext
                         ) -> List[Tuple[float, int, int]]:
    """Score every (stem, argument) candidate for a single context.

    Goal-token scores are always computed. Hypothesis-name scores are appended
    along dim 2 when the sampled context has at least one tokenized premise.
    Returns the ``(prob, stem, arg)`` triples produced by ``self.predict_args``.
    """
    assert self.training_args
    assert self._model
    beam_width = min(self.training_args.max_beam_width,
                     get_num_tokens(self.metadata))
    (premises, premise_feats, _num_hyps, goal_tokens, goal_mask,
     word_feats, vec_feats) = sample_fpa(
         extract_dataloader_args(self.training_args), self.metadata,
         context.relevant_lemmas, context.prev_tactics,
         context.hypotheses, context.goal)

    certainties, stems = self.predict_stems(beam_width, word_feats, vec_feats)
    scores = self.goal_token_scores(stems, goal_tokens, goal_mask)
    if len(premises[0]) > 0:
        premise_scores = self.hyp_name_scores(
            stems[0], goal_tokens[0], premises[0], premise_feats[0])
        scores = torch.cat((scores, premise_scores), dim=2)

    probs, chosen_stems, chosen_args = self.predict_args(
        scores, certainties, stems)
    return list(zip(probs, chosen_stems, chosen_args))
def _optimize_model(
        self, arg_values: Namespace) -> Iterable[FeaturesPolyargState]:
    """Prepare data and a model, then return a generator of training checkpoints.

    If ``arg_values.start_from`` is set, the unparsed args, metadata and state
    stored in that checkpoint file are reused. The Namespace stored in the
    checkpoint is discarded in favour of ``arg_values``. Otherwise the metadata
    returned by ``features_polyarg_tensors`` is used and a fresh model is built.

    Side effect: the padded training tensors are written to ``tensors.pickle``
    in the current working directory.

    Returns a generator yielding ``(metadata, state)`` for each state produced
    by ``optimize_checkpoints``.
    """
    with print_time("Loading data", guard=arg_values.verbose):
        if arg_values.start_from:
            _, saved_checkpoint = torch.load(arg_values.start_from)
            _, unparsed_args, metadata, state = saved_checkpoint
            _, data_lists, feature_sizes = features_polyarg_tensors_with_meta(
                extract_dataloader_args(arg_values),
                str(arg_values.scrape_file),
                metadata)
        else:
            metadata, data_lists, feature_sizes = features_polyarg_tensors(
                extract_dataloader_args(arg_values),
                str(arg_values.scrape_file))
        word_features_size, vec_features_size = feature_sizes

    with print_time("Converting data to tensors", guard=arg_values.verbose):
        (hyp_type_lists, hyp_feature_lists, num_hyps, goals, goal_masks,
         word_features, vec_features, stem_indices, arg_indices) = data_lists
        padded_hyp_types = pad_sequence(
            [torch.LongTensor(hyp_tokens) for hyp_tokens in hyp_type_lists],
            batch_first=True)
        padded_hyp_features = pad_sequence(
            [torch.FloatTensor(feature_vec) for feature_vec in hyp_feature_lists],
            batch_first=True)
        tensors = [
            padded_hyp_types,
            padded_hyp_features,
            torch.LongTensor(num_hyps),
            torch.LongTensor(goals),
            torch.ByteTensor(goal_masks),
            torch.LongTensor(word_features),
            torch.FloatTensor(vec_features),
            torch.LongTensor(stem_indices),
            torch.LongTensor(arg_indices),
        ]
        with open("tensors.pickle", 'wb') as tensor_file:
            torch.save(tensors, tensor_file)
        eprint(tensors, guard=arg_values.print_tensors)

    with print_time("Building the model", guard=arg_values.verbose):
        if arg_values.start_from:
            self.load_saved_state(arg_values, unparsed_args, metadata, state)
            model = self._model
            epoch_start = self.num_epochs
        else:
            model = self._get_model(arg_values,
                                    word_features_size, vec_features_size,
                                    get_num_indices(metadata),
                                    get_num_tokens(metadata))
            epoch_start = 1

    assert model
    assert epoch_start

    def batch_loss(batch_tensors, model):
        # Loss callback handed to optimize_checkpoints.
        return self._getBatchPredictionLoss(arg_values, batch_tensors, model)

    # optimize_checkpoints is invoked now; its results are wrapped lazily.
    checkpoints = optimize_checkpoints(tensors, arg_values, model,
                                       batch_loss, epoch_start)
    return ((metadata, checkpoint) for checkpoint in checkpoints)
def predictKTactics(self, context: TacticContext, k: int) -> List[Prediction]:
    """Predict the k highest-scoring tactics (stem plus argument) for a context.

    The top ``min(max_beam_width, #stems, k**2)`` stems are scored first. For
    each stem, a row of ``1 + num_hyps + max_length`` candidate argument scores
    is built. Goal tokens masked out by ``goal_mask`` get ``-inf``, and
    hypothesis scores from ``runHypModel`` are appended when there are any
    hypotheses. The stem certainty is added to every entry in its row. The
    whole grid goes through ``self._softmax``, and the top k cells are decoded.

    Returns at most ``min(k, (1 + num_hyps + #goal words) * stem_width)``
    predictions in decreasing score order. Each certainty is ``math.exp`` of
    the softmax output. Duplicate tactic strings are not filtered here.
    """
    assert self.training_args
    assert self._model
    max_length = self.training_args.max_length
    num_stem_poss = get_num_tokens(self._metadata)
    stem_width = min(self.training_args.max_beam_width, num_stem_poss, k**2)
    # Loop-invariant: previously rebuilt once per returned prediction.
    dataloader_args = extract_dataloader_args(self.training_args)

    tokenized_premises, hyp_features, \
        nhyps_batch, tokenized_goal, \
        goal_mask, \
        word_features, vec_features = \
        sample_fpa(dataloader_args, self._metadata,
                   context.relevant_lemmas,
                   context.prev_tactics,
                   context.hypotheses,
                   context.goal)
    num_hyps = nhyps_batch[0]

    stem_distribution = self._model.stem_classifier(
        LongTensor(word_features), FloatTensor(vec_features))
    stem_certainties, stem_idxs = stem_distribution.topk(stem_width)

    goals_batch = LongTensor(tokenized_goal)
    goal_arg_values = self._model.goal_args_model(
        stem_idxs.view(1 * stem_width),
        goals_batch.view(1, 1, max_length)
        .expand(-1, stem_width, -1).contiguous()
        .view(1 * stem_width, max_length))\
        .view(1, stem_width, max_length + 1)
    # Goal positions excluded by the mask can never be chosen.
    goal_arg_values = torch.where(
        torch.ByteTensor(goal_mask).view(1, 1, max_length + 1)
        .expand(-1, stem_width, -1),
        goal_arg_values,
        torch.full_like(goal_arg_values, -float("Inf")))
    assert goal_arg_values.size() == torch.Size([1, stem_width, max_length + 1]),\
        "goal_arg_values.size(): {}; stem_width: {}".format(goal_arg_values.size(),
                                                            stem_width)

    # Width of one stem's row of candidate argument scores.
    num_probs = 1 + num_hyps + max_length
    goal_symbols = get_fpa_words(context.goal)
    num_valid_probs = (1 + num_hyps + len(goal_symbols)) * stem_width

    if num_hyps > 0:
        encoded_goals = self._model.goal_encoder(goals_batch)\
            .view(1, 1, self.training_args.hidden_size)
        hyps_batch = LongTensor(tokenized_premises)
        assert hyps_batch.size() == torch.Size([1, num_hyps, max_length]), \
            (hyps_batch.size(), num_hyps, max_length)
        hypfeatures_batch = FloatTensor(hyp_features)
        assert hypfeatures_batch.size() == \
            torch.Size([1, num_hyps, hypFeaturesSize()]), \
            (hypfeatures_batch.size(), num_hyps, hypFeaturesSize())
        hyp_arg_values = self.runHypModel(stem_idxs, encoded_goals, hyps_batch,
                                          hypfeatures_batch)
        assert hyp_arg_values.size() == \
            torch.Size([1, stem_width, num_hyps])
        total_values = torch.cat((goal_arg_values, hyp_arg_values), dim=2)
    else:
        total_values = goal_arg_values

    # Add each stem's certainty to its whole row, then normalize jointly.
    all_prob_batches = self._softmax(
        (total_values + stem_certainties.view(1, stem_width, 1)
         .expand(-1, -1, num_probs))
        .contiguous()
        .view(1, stem_width * num_probs))\
        .view(stem_width * num_probs)

    final_probs, final_idxs = all_prob_batches.topk(k)
    assert not torch.isnan(final_probs).any()
    assert final_probs.size() == torch.Size([k])

    # Flat index -> (stem position in beam, argument index within the row).
    stem_keys = final_idxs // num_probs
    assert stem_keys.size() == torch.Size([k])
    assert stem_idxs.size() == torch.Size([1, stem_width]), stem_idxs.size()
    prediction_stem_idxs = stem_idxs.view(stem_width).index_select(
        0, stem_keys)
    assert prediction_stem_idxs.size() == torch.Size([k]), \
        prediction_stem_idxs.size()
    arg_idxs = final_idxs % num_probs
    assert arg_idxs.size() == torch.Size([k])

    if self.training_args.lemma_args:
        all_hyps = context.hypotheses + context.relevant_lemmas
    else:
        all_hyps = context.hypotheses
    return [
        Prediction(
            decode_fpa_result(dataloader_args, self._metadata, all_hyps,
                              context.goal, stem_idx.item(), arg_idx.item()),
            math.exp(prob))
        for stem_idx, arg_idx, prob in islice(
            zip(prediction_stem_idxs, arg_idxs, final_probs),
            min(k, num_valid_probs))
    ]
def predictKTactics_batch(self, context_batch: List[TacticContext], k: int) \
        -> List[List[Prediction]]:
    """Batched predictKTactics: predict up to k tactics for each context.

    Intended to return, for each context, the same predictions as
    ``self.predictKTactics(context, k)``. The stem, goal-argument and
    hypothesis-argument models each run once over the whole batch.

    Returns one list of predictions per context, in input order. An empty
    ``context_batch`` yields ``[]``.
    """
    assert self.training_args
    assert self._model
    if not context_batch:
        return []
    max_length = self.training_args.max_length
    hidden_size = self.training_args.hidden_size
    num_stem_poss = get_num_tokens(self._metadata)
    stem_width = min(self.training_args.max_beam_width, num_stem_poss, k**2)
    batch_size = len(context_batch)
    dataloader_args = extract_dataloader_args(self.training_args)

    tprems_batch, pfeat_batch, \
        nhyps_batch, tgoals_batch, \
        goal_masks_batch, \
        wfeats_batch, vfeats_batch = \
        sample_fpa_batch(dataloader_args, self._metadata,
                         [context_py2r(context) for context in context_batch])

    stem_distribution = self._model.stem_classifier(
        LongTensor(wfeats_batch), FloatTensor(vfeats_batch))
    stem_certainties_batch, stem_idxs_batch = stem_distribution.topk(
        stem_width)

    goals_batch = LongTensor(tgoals_batch)
    goal_arg_values = self._model.goal_args_model(
        stem_idxs_batch.view(batch_size * stem_width),
        goals_batch.view(batch_size, 1, max_length)
        .expand(-1, stem_width, -1).contiguous()
        .view(batch_size * stem_width, max_length))\
        .view(batch_size, stem_width, max_length + 1)
    goal_arg_values = torch.where(
        torch.ByteTensor(goal_masks_batch).view(batch_size, 1, max_length + 1)
        .expand(-1, stem_width, -1),
        goal_arg_values,
        torch.full_like(goal_arg_values, -float("Inf")))

    if sum(nhyps_batch) > 0:
        # Score every (stem, hypothesis) pair of every context in one call.
        # Within each context, rows are laid out stem-major
        # (row = stem_pos * num_hyps + hyp), so that
        # prem_values.view(stem_width, num_hyps) below lines up. Stems,
        # premises and premise features must all use this same layout.
        # Explicit sizes (not -1) keep the views valid when num_hyps == 0.
        feat_size = hypFeaturesSize()
        encoded_goals_batch = self._model.goal_encoder(goals_batch)
        stems_expanded_batch = torch.cat([
            stem_idxs.view(stem_width, 1).expand(stem_width, num_hyps)
            .contiguous().view(stem_width * num_hyps)
            for stem_idxs, num_hyps in zip(stem_idxs_batch, nhyps_batch)
        ])
        egoals_expanded_batch = torch.cat([
            encoded_goal.view(1, hidden_size)
            .expand(num_hyps * stem_width, -1).contiguous()
            for encoded_goal, num_hyps in zip(encoded_goals_batch, nhyps_batch)
        ])
        tprems_expanded_batch = torch.cat([
            LongTensor(tpremises).view(1, num_hyps, max_length)
            .expand(stem_width, -1, -1).contiguous()
            .view(stem_width * num_hyps, max_length)
            for tpremises, num_hyps in zip(tprems_batch, nhyps_batch)
        ])
        pfeat_expanded_batch = torch.cat([
            FloatTensor(premise_features).view(1, num_hyps, feat_size)
            .expand(stem_width, -1, -1).contiguous()
            .view(stem_width * num_hyps, feat_size)
            for premise_features, num_hyps in zip(pfeat_batch, nhyps_batch)
        ])
        prem_arg_values = self._model.hyp_model(stems_expanded_batch,
                                                egoals_expanded_batch,
                                                tprems_expanded_batch,
                                                pfeat_expanded_batch)
        prem_arg_values_split = prem_arg_values.split(
            [num_hyps * stem_width for num_hyps in nhyps_batch])
        total_values_list = [
            torch.cat((goal_values, prem_values.view(stem_width, num_hyps)),
                      dim=1)
            for goal_values, prem_values, num_hyps in zip(
                goal_arg_values, prem_arg_values_split, nhyps_batch)
        ]
    else:
        # No context has hypotheses: only goal-token arguments are possible.
        total_values_list = list(goal_arg_values)

    predictions: List[List[Prediction]] = []
    for context, total_values, stem_certainties, stem_idxs, num_hyps in zip(
            context_batch, total_values_list, stem_certainties_batch,
            stem_idxs_batch, nhyps_batch):
        # Add each stem's certainty to its whole row, then normalize jointly.
        all_probs = self._softmax(
            (total_values + stem_certainties.view(stem_width, 1)
             .expand_as(total_values)).contiguous().view(1, -1)).view(-1)
        final_probs, final_idxs = all_probs.topk(k)

        row_length = max_length + num_hyps + 1
        prediction_stem_idxs = stem_idxs.view(stem_width).index_select(
            0, final_idxs // row_length)
        arg_idxs = final_idxs % row_length

        # Match predictKTactics: lemmas are argument candidates only when
        # lemma_args is enabled, and up to one row of valid cells per stem.
        if self.training_args.lemma_args:
            all_hyps = context.hypotheses + context.relevant_lemmas
        else:
            all_hyps = context.hypotheses
        num_valid_probs = \
            (1 + num_hyps + len(get_fpa_words(context.goal))) * stem_width

        predictions.append([
            Prediction(
                decode_fpa_result(dataloader_args, self._metadata, all_hyps,
                                  context.goal, stem_idx.item(), arg_idx.item()),
                math.exp(prob))
            for stem_idx, arg_idx, prob in islice(
                zip(prediction_stem_idxs, arg_idxs, final_probs),
                min(k, num_valid_probs))
        ])
    return predictions
def predictionCertainty(self, context: TacticContext,
                        prediction: str) -> float:
    """Return the model's certainty for a specific tactic string in a context.

    The prediction is split into stem and arguments and encoded to indices. If
    the stem is not already among the top ``stem_width`` stems, it is put at
    the front of the beam (with its own score from the stem classifier) and the
    lowest-ranked stem is dropped. That way it is always scored. Argument
    scores are then computed as in ``getAllPredictionIdxs``.

    Returns ``math.exp`` of the probability of the matching
    ``(stem, arg)`` entry returned by ``predict_args``.

    Raises:
        AssertionError: if no entry matches. Raised explicitly so it is not
            removed under ``python -O``; a bare ``assert False`` would make
            this function silently return ``None``.
    """
    assert self.training_args
    assert self._model
    dataloader_args = extract_dataloader_args(self.training_args)
    num_stem_poss = get_num_tokens(self.metadata)
    stem_width = min(self.training_args.max_beam_width, num_stem_poss)

    tokenized_premises, hyp_features, \
        _nhyps_batch, tokenized_goal, \
        goal_mask, \
        word_features, vec_features = \
        sample_fpa(dataloader_args, self.metadata,
                   context.relevant_lemmas,
                   context.prev_tactics,
                   context.hypotheses,
                   context.goal)

    prediction_stem, prediction_args = \
        serapi_instance.split_tactic(prediction)
    prediction_stem_idx = encode_fpa_stem(dataloader_args, self.metadata,
                                          prediction_stem)
    stem_distributions = self._model.stem_classifier(
        maybe_cuda(torch.LongTensor(word_features)),
        maybe_cuda(torch.FloatTensor(vec_features)))
    stem_certainties, stem_idxs = stem_distributions.topk(stem_width)

    if prediction_stem_idx in stem_idxs[0]:
        merged_stem_idxs = stem_idxs
        merged_stem_certainties = stem_certainties
    else:
        # Force the predicted stem into the beam, dropping the last entry.
        merged_stem_idxs = torch.cat(
            (maybe_cuda(torch.LongTensor([[prediction_stem_idx]])),
             stem_idxs[:, :stem_width-1]),
            dim=1)
        predicted_stem_score = stem_distributions[0][prediction_stem_idx]
        merged_stem_certainties = torch.cat(
            (predicted_stem_score.view(1, 1),
             stem_certainties[:, :stem_width-1]),
            dim=1)

    # NOTE(review): this always encodes against hypotheses + lemmas, while the
    # decoders only include lemmas when training_args.lemma_args is set --
    # confirm this is intended.
    prediction_arg_idx = encode_fpa_arg(
        dataloader_args, self.metadata,
        context.hypotheses + context.relevant_lemmas,
        context.goal, prediction_args)

    goal_arg_values = self.goal_token_scores(
        merged_stem_idxs, tokenized_goal, goal_mask)
    if len(tokenized_premises[0]) > 0:
        hyp_arg_values = self.hyp_name_scores(
            merged_stem_idxs[0], tokenized_goal[0],
            tokenized_premises[0], hyp_features[0])
        total_scores = torch.cat((goal_arg_values, hyp_arg_values), dim=2)
    else:
        total_scores = goal_arg_values

    final_probs, predicted_stem_idxs, predicted_arg_idxs = \
        self.predict_args(total_scores,
                          merged_stem_certainties, merged_stem_idxs)
    for prob, stem_idx, arg_idx in zip(final_probs,
                                       predicted_stem_idxs,
                                       predicted_arg_idxs):
        if stem_idx == prediction_stem_idx and \
           arg_idx == prediction_arg_idx:
            return math.exp(prob.item())
    raise AssertionError("Shouldn't be able to get here")