def getAllMentions(dataset, window_size, word_filter, concept_filter, log=log):
    """Build feature samples for every mention that has an accepted target.

    For each mention, the first target for which ``concept_filter`` returns a
    truthy value replaces the candidate at ``correct_ix`` and is passed on to
    ``getMention``; mentions with no accepted target are counted as skipped.

    NOTE(review): ``dataset`` is not read here; the data comes from a
    ``parser`` name defined outside this function -- confirm that is intended.
    NOTE(review): nesting was reconstructed from whitespace-collapsed source;
    the progress tick is assumed to fire once per parsed item, matching the
    ``len(data)`` total in the progress message -- confirm.

    Args:
        dataset: Unused (see note above).
        window_size: Forwarded to ``getMention``.
        word_filter: Forwarded to ``getMention``.
        concept_filter: Callable deciding whether a target concept is kept.
        log: Logger providing ``track``/``tick``/``flushTracker``.

    Returns:
        List of the values returned by ``getMention``.
    """
    documents = parser.parse()
    samples = []
    n_skipped = 0

    log.track(
        message=' >> Extracted features for {0}/%d mentions (skipped {1})...' % len(documents),
        writeInterval=1)

    for (tokens, mentions) in documents:
        for (start, end, targets, candidates, correct_ix) in mentions:
            # Stop at the first acceptable target; the for/else branch runs
            # only when no target passed the filter.
            for target in targets:
                if concept_filter(target):
                    break
            else:
                n_skipped += 1
                continue

            # The candidate list is updated in place with the kept target.
            candidates[correct_ix] = target
            samples.append(
                getMention(tokens, (start, end, target, candidates),
                           window_size, word_filter))
        log.tick(n_skipped)
    log.flushTracker(n_skipped)

    return samples
def _collate(keys, n_batches, nn_q, outf):
    """Drain neighbor batches from a queue and write them as TSV lines.

    Each queue item is a ``(batch_ixes, batch_nns)`` pair; for every position
    in the batch one line ``<key>\\t<nbr key>,<nbr key>,...`` is written, with
    indices translated through ``keys``. Reading stops when the queue yields
    ``_SIGNALS.HALT``.

    Args:
        keys: Index-to-key lookup.
        n_batches: Expected number of batches (used only in the progress
            message).
        nn_q: Queue-like object with a blocking ``get()``.
        outf: Path of the UTF-8 output file.
    """
    with codecs.open(outf, 'w', 'utf-8') as out:
        item = nn_q.get()
        log.track(message=' >> Processed {0}/%d batches' % n_batches, writeInterval=5)

        while item != _SIGNALS.HALT:
            (batch_ixes, batch_nns) = item
            for pos, key_ix in enumerate(batch_ixes):
                source_key = keys[key_ix]
                neighbor_keys = ','.join([keys[ix] for ix in batch_nns[pos]])
                out.write('%s\t%s\n' % (source_key, neighbor_keys))
            log.tick()
            item = nn_q.get()

        log.flushTracker()
def getAllMentions(dataset, window_size, word_filter, concept_filter, log=log):
    """Build feature samples for all instances whose CUI passes the filter.

    Args:
        dataset: Iterable of items exposing ``instances`` and ``labels``.
        window_size: Forwarded to ``getSingleMention``.
        word_filter: Forwarded to ``getSingleMention``.
        concept_filter: Callable applied to each ``instance.CUI``.
        log: Logger providing ``track``/``tick``/``writeln``.

    Returns:
        List of the values returned by ``getSingleMention``.
    """
    samples = []
    log.track(
        message=' >> Extracted features from {0}/%d documents...' % len(dataset),
        writeInterval=1)

    for ambig in dataset:
        samples.extend(
            getSingleMention(instance, window_size, word_filter, ambig.labels)
            for instance in ambig.instances
            if concept_filter(instance.CUI)
        )
        log.tick()

    log.writeln()
    return samples
def completeAnalogySet(str_analogies, setting, embeddings, vocab, vocab_indexer, grph,
                       report_top_k=5, log=log):
    """Evaluate one set of analogies and return metrics plus string predictions.

    Analogies are converted with ``convertAnalogyToMatrices``; those reported
    as invalid are dropped. The remainder is scored by ``grph.eval`` and each
    prediction's indices are mapped back to strings through ``vocab``.

    Args:
        str_analogies: Sequence of analogies; element ``[3]`` of each holds
            its answers (its length is used in the multi-answer settings).
        setting: An ``eval_mode`` value.
        embeddings: Forwarded to ``convertAnalogyToMatrices``.
        vocab: Index-to-string lookup.
        vocab_indexer: Forwarded to ``convertAnalogyToMatrices``.
        grph: Object exposing ``eval(analogies, embedded_analogies, ...)``.
        report_top_k: Forwarded to ``grph.eval``.
        log: Logger providing ``track``/``tick``/``flushTracker``.

    Returns:
        Tuple ``(correct, MAP, MRR, total, skipped, str_predictions)`` where
        each entry of ``str_predictions`` is
        ``(analogy, is_correct, num_candidates, predicted_strings)``.
    """
    analogies, embedded_analogies, kept_str_analogies = [], [], []

    # In the multi-answer settings, size the answer slot to the largest
    # answer set found in this collection.
    if setting in [eval_mode.ALL_INFO, eval_mode.MULTI_ANSWER]:
        max_answers = max((len(analogy[3]) for analogy in str_analogies), default=0)
    else:
        max_answers = 1

    # convert analogies to a matrix of indices and a matrix of embeddings
    log.track(message=' >> Preprocessing: {0:,}/%s' % ('{0:,}'.format(len(str_analogies))),
              writeInterval=50)
    for analogy in str_analogies:
        valid, analogy_ixes, analogy_embeds = convertAnalogyToMatrices(
            analogy, setting, embeddings, vocab_indexer, max_answers)
        if valid:
            analogies.append(analogy_ixes)
            embedded_analogies.append(analogy_embeds)
            kept_str_analogies.append(analogy)
        log.tick()
    log.flushTracker()

    correct, MAP, MRR, total, skipped, predictions = grph.eval(
        analogies, embedded_analogies, report_top_k=report_top_k, log=log)
    log.flushTracker(len(analogies))

    str_predictions = []
    for i, (is_correct, num_candidates, predicted_ixes) in enumerate(predictions):
        # A lone -1 marks a skipped analogy. Every other case, including a
        # single real prediction (previously left unassigned), is mapped
        # through the vocabulary.
        if len(predicted_ixes) == 1 and predicted_ixes[0] == -1:
            predicted_strings = ['>>> SKIPPED <<<']
        else:
            predicted_strings = [vocab[ix] for ix in predicted_ixes]
        str_predictions.append(
            (kept_str_analogies[i], is_correct, num_candidates, predicted_strings))

    return (correct, MAP, MRR, total, skipped, str_predictions)
def _nn_writer(neighborf, node_IDs, nn_q):
    """Write nearest-neighbor results pulled from a queue to a text file.

    Each queue item is an ``(ix, neighbors)`` pair of positions into
    ``node_IDs``; one comma-separated line ``<node>,<NN 1>,<NN 2>,...`` is
    written per item. Reading stops when the queue yields ``_SIGNALS.HALT``.

    Args:
        neighborf: Path of the output file (overwritten).
        node_IDs: Position-to-ID lookup.
        nn_q: Queue-like object with a blocking ``get()``.
    """
    # The context manager guarantees the file is flushed and closed even if
    # the queue consumer fails part-way through.
    with open(neighborf, 'w') as stream:
        stream.write('# File format is:\n# <word vocab index>,<NN 1>,<NN 2>,...\n')

        result = nn_q.get()
        log.track(message=' >> Processed {0}/{1:,} samples'.format(
            '{0:,}', len(node_IDs)), writeInterval=50)

        while result != _SIGNALS.HALT:
            (ix, neighbors) = result
            row = [node_IDs[ix]]
            row.extend(node_IDs[nbr] for nbr in neighbors)
            stream.write('%s\n' % ','.join(str(d) for d in row))
            log.tick()
            result = nn_q.get()

        log.flushTracker()
def getELMoRepresentations(sentences_words, sentences_instances, semcor_labels,
                           unique_sense_IDs, bilm_params):
    """Compute a mean ELMo embedding for each sense ID.

    Sentences are run through ``ELMoRunner`` in batches of 25. For every
    ``(instance_ID, ix)`` pair of a sentence, the representation at position
    ``ix`` is collected under each sense listed for that instance in
    ``semcor_labels``; the collected vectors are then averaged per sense.

    NOTE(review): the extent of the ``tf.Session`` block was reconstructed
    from whitespace-collapsed source; the averaging step is assumed not to
    need the session.

    Args:
        sentences_words: Input passed to ``elmo.preprocess``.
        sentences_instances: Per-sentence lists of ``(instance_ID, ix)``.
        semcor_labels: Mapping from instance ID to its sense IDs.
        unique_sense_IDs: Sense IDs to produce embeddings for.
        bilm_params: Forwarded to ``ELMoRunner``.

    Returns:
        ``pyemblib.Embeddings`` holding one mean vector per sense ID that
        collected at least one embedding.
    """
    collected = {sense_ID: [] for sense_ID in unique_sense_IDs}

    with tf.Session() as sess:
        log.writeln(' (1) Setting up ELMo')
        elmo = ELMoRunner(sess, bilm_params)

        # batch up the data
        sentence_ids = elmo.preprocess(sentences_words)
        batch_size = 25
        n_sentences = sentence_ids.shape[0]
        num_batches = math.ceil(n_sentences / batch_size)

        log.writeln(' (2) Extracting sense embeddings from sentences')
        log.track(message=' >> Processed {0}/{1:,} batches'.format('{0:,}',num_batches),
                  writeInterval=5)
        for batch_start in range(0, n_sentences, batch_size):
            batch_repr = elmo(sentence_ids[batch_start:batch_start + batch_size])
            for offset in range(batch_repr.shape[0]):
                for (instance_ID, ix) in sentences_instances[batch_start + offset]:
                    for sense in semcor_labels[instance_ID]:
                        collected[sense].append(batch_repr[offset][ix])
            log.tick()
        log.flushTracker()

    log.writeln(' (3) Calculating mean per-sense embeddings')
    mean_sense_embeddings = pyemblib.Embeddings()
    for (sense_ID, embedding_list) in collected.items():
        if not embedding_list:
            log.writeln('[WARNING] Sense ID "%s" found no embeddings' % sense_ID)
            continue
        mean_sense_embeddings[sense_ID] = np.mean(embedding_list, axis=0)

    return mean_sense_embeddings
def ELMoBaseline(mentions, mention_map, backoff_preds, training_lemmas,
                 semcor_embeddings, output_predsf):
    """Run the nearest-neighbor ELMo baseline with backoff, and report accuracy.

    Mentions whose lemma is in ``training_lemmas`` are labelled through
    ``getNearestNeighborKey2`` against the L2-normalized SemCor embeddings;
    all others take the prediction stored in ``backoff_preds``. Predictions
    are written with ``writeWSDFrameworkPredictions`` and summary counts are
    logged.

    Args:
        mentions: Iterable of mentions exposing ``ID``, ``context_repr`` and
            ``CUI``.
        mention_map: Mapping from mention ID to ``(ds, instance_ID, lemma)``.
        backoff_preds: Mapping keyed by ``predictionID(ds, instance_ID)``.
        training_lemmas: Container of lemmas handled by the ELMo path.
        semcor_embeddings: Mapping from key to embedding vector.
        output_predsf: Forwarded to ``writeWSDFrameworkPredictions``.
    """
    log.writeln('Running ELMo baseline\n')

    # pre-norm the semcor embeddings
    log.writeln('Norming SemCor embeddings...')
    unit_embeddings = pyemblib.Embeddings()
    for (key, vector) in semcor_embeddings.items():
        unit_embeddings[key] = (vector / np.linalg.norm(vector))
    ordered_vocab, embedding_matrix = unit_embeddings.toarray()
    embedding_matrix = np.transpose(embedding_matrix)
    log.writeln('Done.\n')

    predictions = []
    correct = 0
    num_elmo = 0
    num_backoff = 0

    log.track(
        message=' >> Processed {0:,}/%s samples ({1:,} ELMo, {2:,} backoff)' % (
            '{0:,}'.format(len(mentions))),
        writeInterval=5)
    for m in mentions:
        (ds, instance_ID, lemma) = mention_map[m.ID]

        if lemma in training_lemmas:
            prediction = getNearestNeighborKey2(
                m.context_repr, embedding_matrix, ordered_vocab)
            num_elmo += 1
        else:
            prediction = backoff_preds[predictionID(ds, instance_ID)]
            num_backoff += 1

        predictions.append((m.ID, prediction))
        if prediction == m.CUI:
            correct += 1

        log.tick(num_elmo, num_backoff)
    log.flushTracker(num_elmo, num_backoff)

    writeWSDFrameworkPredictions(predictions, mention_map, output_predsf)

    n_predictions = len(predictions)
    log.writeln('\n-- ELMo baseline --')
    log.writeln('Accuracy: {0:.4f} ({1:,}/{2:,})\n'.format(
        float(correct) / n_predictions, correct, n_predictions))
    log.writeln('# ELMo: {0:,}\n# backoff: {1:,}\n'.format(
        num_elmo, num_backoff))
def enumerateWordNetPairs(vocab, outf, write_lemma=False):
    """Enumerate in-vocabulary WordNet relation pairs and write them out.

    For every synset of each POS whose first lemma name is in ``vocab``, the
    hyponym, hypernym, holonym and meronym relations are followed; each
    related synset that is also in-vocabulary yields a
    ``(id, source, sink, label)`` row, where ``id`` is the running row count.

    NOTE(review): nesting was reconstructed from whitespace-collapsed source;
    the progress tick is assumed to fire once per in-vocabulary source
    synset -- confirm.

    Args:
        vocab: Container of known lemma names.
        outf: Destination forwarded to ``dataset.write``.
        write_lemma: If true, rows hold first-lemma names; otherwise synset
            names.
    """
    def first_lemma(synset):
        return synset.lemmas()[0].name()

    def in_vocab(synset):
        return first_lemma(synset) in vocab

    # (synset method name, relation label) for every relation followed.
    relations = (
        ('hyponyms', dataset.Hyponym),
        ('hypernyms', dataset.Hypernym),
        ('member_holonyms', dataset.Holonym),
        ('substance_holonyms', dataset.Holonym),
        ('part_holonyms', dataset.Holonym),
        ('member_meronyms', dataset.Meronym),
        ('substance_meronyms', dataset.Meronym),
        ('part_meronyms', dataset.Meronym),
    )

    data = []
    for pos in ['n', 'v', 'a', 'r']:
        n_pairs = 0
        log.writeln('Processing POS "%s"' % pos)
        log.track(message=' >> Processed {0:,} source synsets ({1:,} pairs)',
                  writeInterval=100)

        for synset in wn.all_synsets(pos):
            if not in_vocab(synset):
                continue

            for (method_name, lbl) in relations:
                for sink in getattr(synset, method_name)():
                    if not in_vocab(sink):
                        continue
                    if write_lemma:
                        src, snk = first_lemma(synset), first_lemma(sink)
                    else:
                        src, snk = synset.name(), sink.name()
                    data.append((len(data), src, snk, lbl))
                    n_pairs += 1
            log.tick(n_pairs)

        log.flushTracker(n_pairs)
        log.writeln('')

    dataset.write(data, outf)
def _getELMoMentions(sentences_words, sentences_instances, labels, ds_name,
                     samples, ds_map, elmo, batch_size=25):
    """Append one embedded mention per annotated instance to ``samples``.

    Sentences are run through ``elmo`` in batches. For every
    ``(instance_ID, ix, lemma)`` entry of a sentence, a mention is built from
    the representation at position ``ix`` and the instance's first label,
    appended to ``samples``, and recorded in ``ds_map`` under its ID.

    Args:
        sentences_words: Input passed to ``elmo.preprocess``.
        sentences_instances: Per-sentence lists of
            ``(instance_ID, ix, lemma)``.
        labels: Mapping from instance ID to a sequence of labels.
        ds_name: Dataset name stored in ``ds_map``.
        samples: List extended in place.
        ds_map: Mapping updated in place:
            mention ID -> ``(ds_name, instance_ID, lemma)``.
        elmo: Object with ``preprocess`` that is also callable on a batch.
        batch_size: Number of sentences per call to ``elmo``.
    """
    sentence_ids = elmo.preprocess(sentences_words)
    n_sentences = sentence_ids.shape[0]
    num_batches = math.ceil(n_sentences / batch_size)

    log.track(message=' >> Processed {0}/{1:,} batches'.format(
        '{0:,}', num_batches), writeInterval=5)

    batch_start = 0
    while batch_start < sentence_ids.shape[0]:
        batch_end = batch_start + batch_size
        batch_repr = elmo(sentence_ids[batch_start:batch_end])

        for offset in range(batch_repr.shape[0]):
            for (instance_ID, ix, lemma) in sentences_instances[batch_start + offset]:
                # An instance may carry several correct senses; only the
                # first one is used as the label here.
                label = labels[instance_ID][0]
                context_repr = batch_repr[offset][ix]

                sample = getEmbeddedSingleMention(
                    lemma, context_repr, label, ID=len(samples))
                samples.append(sample)
                ds_map[sample.ID] = (ds_name, instance_ID, lemma)

        log.tick()
        batch_start = batch_end
    log.flushTracker()
def buildGraph(neighbor_files, k):
    """Build a weighted neighbor graph from several neighborhood files.

    Each file is read with ``readNeighbors(file, k)``. An edge's weight is
    the number of times the neighbor is listed for the source across all
    files, divided by the number of files.

    Args:
        neighbor_files: Sequence of neighborhood files.
        k: Forwarded to ``readNeighbors``.

    Returns:
        Dict mapping source -> {neighbor: weight}.
    """
    log.writeln('Building neighborhood graph...')
    graph = {}

    # Accumulate raw co-occurrence counts per (source, neighbor) edge.
    log.track(message=' >> Loaded {0}/%d neighborhood files' % len(neighbor_files),
              writeInterval=1)
    for neighbor_file in neighbor_files:
        for (source, neighbors) in readNeighbors(neighbor_file, k).items():
            edge_counts = graph.setdefault(source, {})
            for nbr in neighbors:
                edge_counts[nbr] = edge_counts.get(nbr, 0) + 1
        log.tick()
    log.flushTracker()

    log.writeln(' >> Normalizing edge weights...')
    n_files = float(len(neighbor_files))
    for edge_counts in graph.values():
        for nbr in edge_counts:
            edge_counts[nbr] = edge_counts[nbr] / n_files

    log.writeln('Graph complete!')
    return graph
def _tallyMcNemar(table, first_correct, second_correct):
    """Add one paired outcome to the a/b/c/d cells of ``table``."""
    if first_correct and second_correct:
        table.a += 1
    elif first_correct:
        table.b += 1
    elif second_correct:
        table.c += 1
    else:
        table.d += 1


def runModel(mentions, entity_embeds, ctx_embeds, minibatch_size, preds_file,
             debug=False, secondary_entity_embeds=None, entity_combo_method=None,
             using_mention=False, preds_file_detailed=None, preferred_strings=None,
             preds_file_polysemy=None, polysemy=None):
    """Run the linker over all mentions in minibatches and score predictions.

    Builds a ``LinearSabbirLinkerC`` from the given embeddings, predicts each
    batch, and counts how often the predicted candidate equals the gold one.
    When ``secondary_entity_embeds`` is given, per-source correctness
    (entity / secondary / joint / oracle) is also tracked and compared with
    McNemar tables that are logged at the end.

    Args:
        mentions: Sequence of mentions exposing ``candidates`` and
            ``mention_text``.
        entity_embeds: Entity embeddings (``toarray()``, ``size``).
        ctx_embeds: Context embeddings (``toarray()``, ``size``).
        minibatch_size: Number of mentions per prediction call.
        preds_file: Optional open stream for one-line-per-mention output.
        debug: Forwarded to the linker.
        secondary_entity_embeds: Optional second set of entity embeddings.
        entity_combo_method: Forwarded to ``LLParams``.
        using_mention: Forwarded to ``LLParams``.
        preds_file_detailed: Optional mapping from outcome key to stream,
            consumed by ``_writeDetailedOutcome``.
        preferred_strings: Forwarded to ``_writeDetailedOutcome``.
        preds_file_polysemy: Optional stream for correctness/polysemy lines.
        polysemy: Mapping looked up with ``ent_ixer[predicted_ix]``.

    Returns:
        Tuple ``(correct, total, oracle)``; ``oracle`` is a dict of counts
        keyed by ``'entity_correct'``, ``'joint_correct'``,
        ``'secondary_correct'`` and ``'oracle_correct'`` (keys appear only
        once counted).
    """
    entity_vocab, entity_arr = entity_embeds.toarray()
    ctx_vocab, ctx_arr = ctx_embeds.toarray()

    if secondary_entity_embeds:
        secondary_entity_vocab, secondary_entity_arr = secondary_entity_embeds.toarray()
    else:
        secondary_entity_vocab, secondary_entity_arr = None, None

    ent_ixer = Indexer(entity_vocab)
    ctx_ixer = Indexer(ctx_vocab)
    if secondary_entity_embeds:
        secondary_ent_ixer = Indexer(secondary_entity_vocab)
    else:
        secondary_ent_ixer = None

    # Padding sizes: the largest candidate set and the longest mention
    # (in whitespace-separated tokens) over the whole input.
    max_num_entities = max((len(m.candidates) for m in mentions), default=0)
    max_mention_size = max((len(m.mention_text.split()) for m in mentions), default=0)

    window_size = 5
    params = LLParams(
        ctx_vocab_size=len(ctx_vocab),
        ctx_dim=ctx_embeds.size,
        entity_vocab_size=len(entity_vocab),
        entity_dim=entity_embeds.size,
        secondary_entity_vocab_size=(0 if not secondary_entity_embeds else
                                     len(secondary_entity_vocab)),
        secondary_entity_dim=(0 if not secondary_entity_embeds else
                              secondary_entity_embeds.size),
        window_size=window_size,
        max_num_entities=max_num_entities,
        max_mention_size=max_mention_size,
        entity_combo_method=entity_combo_method,
        using_mention=using_mention)

    session = tf.Session()
    lll = LinearSabbirLinkerC(
        session,
        np.array(ctx_arr),
        np.array(entity_arr),
        params,
        debug=debug,
        secondary_entity_embed_arr=np.array(secondary_entity_arr))

    log.track(message=' >>> Processed {0} batches', writeInterval=10)

    # Significance tables only exist when there is a second embedding source.
    if secondary_entity_embeds:
        ent_vs_sec = McNemars()
        ent_vs_joint = McNemars()
        sec_vs_joint = McNemars()
        joint_vs_oracle = McNemars()

    correct, total = 0., 0
    batch_start = 0
    oracle = {}
    while (batch_start < len(mentions)):
        next_batch_mentions = mentions[batch_start:batch_start + minibatch_size]
        next_batch = [
            prepSample(mention, ent_ixer, ctx_ixer, window_size,
                       max_mention_size, max_num_entities,
                       secondary_ent_ixer=secondary_ent_ixer)
            for mention in next_batch_mentions
        ]

        batch_ctx_window_ixes = [sample[0] for sample in next_batch]
        batch_ctx_window_masks = [sample[1] for sample in next_batch]
        batch_mention_ixes = [sample[2] for sample in next_batch]
        batch_mention_masks = [sample[3] for sample in next_batch]
        batch_entity_ixes = [sample[4] for sample in next_batch]
        batch_entity_masks = [sample[5] for sample in next_batch]
        if secondary_entity_embeds:
            batch_secondary_entity_ixes = [sample[6] for sample in next_batch]
        else:
            batch_secondary_entity_ixes = None

        results = lll.getPredictions(
            batch_ctx_window_ixes,
            batch_ctx_window_masks,
            batch_entity_ixes,
            batch_entity_masks,
            batch_secondary_entity_ixes=batch_secondary_entity_ixes,
            batch_mention_ixes=batch_mention_ixes,
            batch_mention_masks=batch_mention_masks,
            oracle=True)
        if secondary_entity_embeds:
            (preds, probs, ent_preds, secondary_ent_preds) = results
        else:
            (preds, probs, ent_preds) = results

        for i in range(len(next_batch)):
            (_, _, _, _, ent_ixes, _, _, correct_candidate, mention) = next_batch[i]

            # base accuracy eval
            predicted_ix = ent_ixes[preds[i]]
            if predicted_ix == correct_candidate:
                correct += 1
            total += 1

            # oracle eval
            joint_correct, entity_correct, secondary_correct, oracle_correct = \
                False, False, False, False
            if ent_ixes[ent_preds[i]] == correct_candidate:
                entity_correct = True
                oracle['entity_correct'] = oracle.get('entity_correct', 0) + 1
            if secondary_entity_embeds and ent_ixes[preds[i]] == correct_candidate:
                joint_correct = True
                oracle['joint_correct'] = oracle.get('joint_correct', 0) + 1
            if secondary_entity_embeds and ent_ixes[secondary_ent_preds[i]] == correct_candidate:
                secondary_correct = True
                oracle['secondary_correct'] = oracle.get('secondary_correct', 0) + 1
            if entity_correct or secondary_correct:
                oracle_correct = True
                oracle['oracle_correct'] = oracle.get('oracle_correct', 0) + 1

            # significance tracking
            if secondary_entity_embeds:
                _tallyMcNemar(ent_vs_sec, entity_correct, secondary_correct)
                _tallyMcNemar(ent_vs_joint, entity_correct, joint_correct)
                _tallyMcNemar(sec_vs_joint, secondary_correct, joint_correct)
                _tallyMcNemar(joint_vs_oracle, joint_correct, oracle_correct)

            # predictions + scores
            if preds_file:
                preds_file.write('Probs: [ %s ] Pred: %d -> %d Gold: %d\n' %
                                 (' '.join([str(p) for p in probs[i]]), preds[i],
                                  ent_ixes[preds[i]], correct_candidate))

            # predictions + corpus polysemy of correct entity
            if preds_file_polysemy:
                # Best-effort output: a KeyError raised while building or
                # writing the line is skipped, as in the original.
                try:
                    line = '%d\t%f\n' % (
                        (1 if predicted_ix == correct_candidate else 0),
                        polysemy[ent_ixer[predicted_ix]])
                    preds_file_polysemy.write(line)
                except KeyError:
                    pass

            # predictions, in detail
            if preds_file_detailed:
                keys = ['all']
                if secondary_entity_embeds:
                    pred_ixes = [('Pred (Joint)', ent_ixes[preds[i]]),
                                 ('Pred (Ent)', ent_ixes[ent_preds[i]]),
                                 ('Pred (Defn)', ent_ixes[secondary_ent_preds[i]])]

                    if entity_correct and secondary_correct:
                        comp_stream_key = 'both_correct'
                    elif entity_correct and (not secondary_correct):
                        comp_stream_key = 'entity_only_correct'
                    elif (not entity_correct) and secondary_correct:
                        comp_stream_key = 'secondary_only_correct'
                    else:
                        comp_stream_key = 'both_wrong'
                    keys.append(comp_stream_key)

                    # Did the joint prediction help or hurt relative to each
                    # single-source prediction?
                    if (not entity_correct) and joint_correct:
                        keys.append('ent_joint_help')
                    elif entity_correct and (not joint_correct):
                        keys.append('ent_joint_hurt')
                    if (not secondary_correct) and joint_correct:
                        keys.append('sec_joint_help')
                    if secondary_correct and (not joint_correct):
                        keys.append('sec_joint_hurt')
                else:
                    pred_ixes = [('Pred', predicted_ix)]
                    if entity_correct:
                        stream_key = 'entity_correct'
                    else:
                        stream_key = 'entity_wrong'
                    keys.append(stream_key)

                for k in keys:
                    _writeDetailedOutcome(preds_file_detailed[k], mention, probs,
                                          batch_entity_ixes, batch_entity_masks,
                                          ent_ixer, preferred_strings,
                                          correct_candidate, pred_ixes, i)

        batch_start += minibatch_size
        log.tick()
    log.flushTracker()

    # The tables are only created for the two-source configuration;
    # reporting them unconditionally raised NameError otherwise.
    if secondary_entity_embeds:
        for (msg, mcn) in [('Entity vs Defn', ent_vs_sec),
                           ('Entity vs Joint', ent_vs_joint),
                           ('Defn vs Joint', sec_vs_joint),
                           ('Joint vs Oracle', joint_vs_oracle)]:
            chi2, pval = mcn.run()
            log.writeln('\n%s\n'
                        ' | a = %5d | b = %5d |\n'
                        ' | c = %5d | d = %5d |\n'
                        ' Chi^2 = %f P-value = %f\n' %
                        (msg, mcn.a, mcn.b, mcn.c, mcn.d, chi2, pval))

    return correct, total, oracle
def crossfoldTrain(src_embs, trg_embs, pivot_keys, nfold, activation, num_layers,
                   batch_size=5, checkpoint_file='checkpoint', random_seed=None):
    """Train one mapper per fold and average their projections of all keys.

    The pivot keys are shuffled and split into ``nfold`` folds. For each
    fold a fresh ``ManifoldMapper`` is trained on the other folds (the
    held-out fold is the dev set) and then used to project every source
    embedding. Projections are summed across folds and divided by ``nfold``.

    Args:
        src_embs: Source embeddings (``keys()``, ``size``, item lookup).
        trg_embs: Target embeddings (``size``, item lookup).
        pivot_keys: Keys used for training and evaluation; each one is
            looked up in both the averaged projections and ``trg_embs``.
        nfold: Number of folds.
        activation: Forwarded to ``MapperParams``.
        num_layers: Forwarded to ``MapperParams``.
        batch_size: Training batch size; projection uses ten times this.
        checkpoint_file: Forwarded to ``MapperParams``.
        random_seed: Optional seed. Any non-``None`` value, including ``0``,
            seeds the shuffle and gives fold ``i`` the seed
            ``random_seed + i``.

    Returns:
        Dict mapping every source key to its fold-averaged projection.
    """
    project_batch_size = batch_size * 10
    # ``is not None`` so that a seed of 0 is honored rather than ignored.
    seeded = random_seed is not None

    pivot_keys = list(pivot_keys)
    if seeded:
        random.seed(random_seed)
    random.shuffle(pivot_keys)

    fold_size = int(np.ceil(len(pivot_keys) / nfold))

    src_keys = list(src_embs.keys())
    mapped_embs = {k: np.zeros([trg_embs.size]) for k in src_keys}

    session = tf.Session()
    params = MapperParams(src_dim=src_embs.size,
                          trg_dim=trg_embs.size,
                          map_dim=trg_embs.size,
                          activation=activation,
                          num_layers=num_layers,
                          checkpoint_file=checkpoint_file)

    for fold in range(nfold):
        log.writeln(' Starting fold %d/%d' % (fold + 1, nfold))

        this_random = (random_seed + fold) if seeded else None
        model = ManifoldMapper(session, params, random_seed=this_random)

        fold_start, fold_end = (fold * fold_size), ((fold + 1) * fold_size)
        dev_keys = pivot_keys[fold_start:fold_end]
        train_keys = pivot_keys[:fold_start] + pivot_keys[fold_end:]

        train(model, src_embs, trg_embs, train_keys, dev_keys, batch_size=batch_size)

        # get projections from this fold
        log.writeln(' Getting trained projections for fold %d' % (fold + 1))
        log.track(message=' >> Projected {0}/%d keys' % len(src_keys), writeInterval=10000)
        batch_start = 0
        while batch_start < len(src_keys):
            batch_keys = src_keys[batch_start:batch_start + project_batch_size]
            batch_src = np.array([src_embs[k] for k in batch_keys])
            batch_mapped = model.project_batch(batch_src)

            # A separate row index keeps the fold counter intact.
            for row in range(batch_mapped.shape[0]):
                mapped_embs[batch_keys[row]] += batch_mapped[row]
                log.tick()

            batch_start += project_batch_size
        log.flushTracker()

    # mean projections
    for k in src_keys:
        mapped_embs[k] /= nfold

    # get final MSE over full pivot set
    final_errors = []
    for k in pivot_keys:
        diff = mapped_embs[k] - trg_embs[k]
        final_errors.append(np.sum(diff**2) / 2)
    log.writeln('\nPivot error in final projections: %f' % np.mean(final_errors))

    return mapped_embs
def trainModel(session, model, train, dev, test, embeds, batch_size=5, patience=3,
               early_stopping=0., max_epochs=-1, fold=None, preds_file=None,
               verbose=False, debug=False):
    """Train with patience-based early stopping, then score the saved model.

    Each epoch shuffles ``train`` in place, runs ``model.trainStep`` over it
    in batches, and evaluates on ``dev`` via ``testModel``. The model is
    saved whenever dev loss improves on the best seen so far. Training stops
    after ``patience`` consecutive epochs without an improvement larger than
    ``early_stopping``, or once ``max_epochs`` is reached. The saved state is
    then restored and evaluated on train and test.

    NOTE(review): nesting was reconstructed from whitespace-collapsed source;
    the out-of-patience check is assumed to sit in the no-improvement
    branch -- confirm.

    Args:
        session: Forwarded to ``testModel``.
        model: Object exposing ``trainStep``, ``save`` and ``restore``.
        train: Sequence of ``(id, src, snk, lbl)`` samples; shuffled in
            place.
        dev: Dev samples passed to ``testModel``.
        test: Test samples passed to ``testModel``.
        embeds: Lookup from ``src``/``snk`` to an embedding.
        batch_size: Number of samples per training step.
        patience: Consecutive non-improving epochs tolerated.
        early_stopping: Minimum dev-loss decrease that resets patience.
        max_epochs: Epoch cap; values <= 0 disable it.
        fold: Forwarded to ``model.save`` / ``model.restore``.
        preds_file: Forwarded to the final test evaluation only.
        verbose: Log each batch's loss instead of a progress tracker.
        debug: Write a marker to stderr before each training step.

    Returns:
        Tuple ``(best_epoch, epoch_dev_metrics, test_metrics,
        train_metrics)``.
    """
    epoch, best_epoch = 0, -1
    epoch_dev_metrics, current_best = [], 1e9
    patience_so_far = 0

    log.writeln('Starting training')
    while True:
        epoch += 1
        batch = 1
        log.writeln('\n== Epoch %d ==\n\n Training...' % epoch)

        np.random.shuffle(train)

        batch_start = 0
        n_batches = int(np.ceil(len(train) / batch_size))

        if not verbose:
            log.track(message=' >> Processed {0:,}/%s batches' %
                      ('{0:,}'.format(n_batches)),
                      writeInterval=10)

        while (batch_start < len(train)):
            if debug:
                sys.stderr.write('------------------------------\nTRAINING\n')
                sys.stderr.flush()

            next_batch_samples = train[batch_start:batch_start + batch_size]
            next_batch = np.array([[embeds[src], embeds[snk]]
                                   for (_, src, snk, _) in next_batch_samples])
            batch_labels = np.array(
                [lbl for (_, _, _, lbl) in next_batch_samples])

            batch_loss = model.trainStep(next_batch, batch_labels)

            if verbose:
                log.writeln(' Batch {0:,}/{1:,} loss -- {2:.8f}'.format(
                    batch, n_batches, batch_loss))
            else:
                log.tick()

            batch += 1
            batch_start += len(next_batch)

        if not verbose:
            log.flushTracker()

        log.writeln('\n Training complete.\n Evaluating on dev...')
        dev_metrics = testModel(
            session,
            model,
            dev,
            embeds,
            batch_size=batch_size,
            preds_file=None,
            training=True,
        )
        log.writeln(' Dev loss -- %f (Best: %f) [Accuracy: %f (%d/%d)]' %
                    (dev_metrics.loss, current_best, dev_metrics.accuracy,
                     dev_metrics.correct, dev_metrics.total))
        epoch_dev_metrics.append(dev_metrics)

        # patience/early stopping handling
        if dev_metrics.loss < (current_best - early_stopping):
            patience_so_far = 0
        else:
            patience_so_far += 1
            log.writeln(' >>> Impatience building... (%d/%d) <<<' %
                        (patience_so_far, patience))
            if patience_so_far >= patience:
                log.writeln(" >>> Ran out of patience! <<<")
                log.writeln(" (╯'-')╯︵ ┻━┻ ")
                break

        if dev_metrics.loss < current_best:
            log.writeln(' >>> Improvement! Saving model state. <<<')
            model.save(fold)
            current_best = dev_metrics.loss
            best_epoch = epoch

        if (max_epochs > 0) and epoch >= max_epochs:
            log.writeln(" >>> Hit maximum epoch threshold! <<<")
            # Backslash escaped explicitly: a bare "\(" is an invalid escape
            # sequence. The emitted text is unchanged.
            log.writeln(" ¯\\(°_o)/¯")
            break

    log.writeln('\nTraining halted.')

    model.restore(fold)

    log.writeln('\nEvaluating best model on train:')
    train_metrics = testModel(session,
                              model,
                              train,
                              embeds,
                              batch_size=batch_size,
                              preds_file=None,
                              training=True)
    log.writeln(
        ' Accuracy: %f (%d/%d)' %
        (train_metrics.accuracy, train_metrics.correct, train_metrics.total))

    log.writeln('\nEvaluating best model on test:')
    test_metrics = testModel(session,
                             model,
                             test,
                             embeds,
                             batch_size=batch_size,
                             preds_file=preds_file,
                             training=True)
    log.writeln(
        ' Accuracy: %f (%d/%d)' %
        (test_metrics.accuracy, test_metrics.correct, test_metrics.total))

    return best_epoch, epoch_dev_metrics, test_metrics, train_metrics