Exemplo n.º 1
0
    def decode(self, tokens, constraints=None, train_mode=False):
        """Order *tokens* by scoring pairwise transitions with the biaffine
        attention and solving the resulting travelling-salesman problem.

        :param tokens: sequence of token objects carrying ``vecs[self.vec_key]``
        :param constraints: optional ordering constraints forwarded to
            ``solve_tsp``; ``None`` means unconstrained.  (Previously a
            mutable ``[]`` default, which is shared across calls.)
        :param train_mode: when True, accumulate a hinge loss against the
            gold order derived from each token's ``original_id``.
        :return: dict with 'loss' (dynet expression, or 0 when no loss was
            accumulated) and 'seq' (tokens in the decoded order).
        """
        if constraints is None:
            constraints = []
        loss = 0
        errs = []

        # prepend the special start/end vectors as a virtual node 0
        fr_vecs = [self.special[0]] + [t.vecs[self.vec_key] for t in tokens]
        to_vecs = [self.special[1]] + [t.vecs[self.vec_key] for t in tokens]
        score_mat = self.biaffine.attend(fr_vecs, to_vecs)
        scores = score_mat.npvalue()

        if train_mode:
            # gold path visits the tokens in their original order, then
            # returns to the virtual node 0
            oids = [0] + [t['original_id'] for t in tokens]
            gold_path = np.argsort(oids).tolist() + [0]
            trans_mat = dy.transpose(score_mat)
            for i, j in zip(gold_path, gold_path[1:]):
                # hinge on both the outgoing (row) and incoming (column) scores
                errs.append(dy.hinge(score_mat[i], j))
                errs.append(dy.hinge(trans_mat[j], i))
            if errs:
                loss = dy.average(errs)

        # solver minimizes integer costs; scale to keep precision
        costs = (1000 * (scores.max() - scores)).astype(int).tolist()
        solution = solve_tsp(costs, constraints,
                             self.args.guided_local_search)  # first is best
        if not solution:
            # constrained instance was infeasible: retry unconstrained
            solution = solve_tsp(costs, [], self.args.guided_local_search)

        assert solution != []
        # drop the virtual node at both ends, map indices back to tokens
        seq = [tokens[i - 1] for i in solution[1:-1]]

        return {'loss': loss, 'seq': seq}
Exemplo n.º 2
0
    def _doc_loss(self, doc, y):
        """Structured hinge loss for one document.

        Returns a tuple ``(loss_expr, mistakes, max_acc, 'n/a')`` where a
        hinge term of exactly zero counts as a correct decision.
        """
        gold_props = self.prop_encoder_.transform(y.nodes)
        gold_links = self.link_encoder_.transform(y.links)

        props, links, _, _, _, _ = self.build_cg(doc)

        prop_terms = [dy.hinge(score, gold)
                      for score, gold in zip(props, gold_props)]
        link_terms = [dy.hinge(score, gold)
                      for score, gold in zip(links, gold_links)]

        obj = dy.esum(prop_terms) + dy.esum(link_terms)

        all_terms = prop_terms + link_terms
        # a zero hinge means the gold label already wins by the margin
        correct = sum(1 for term in all_terms if term.scalar_value() == 0)

        max_acc = len(all_terms)
        return obj, max_acc - correct, max_acc, 'n/a'
Exemplo n.º 3
0
 def Update(self, pos, neg):
     """Hinge update: every positive score must outrank all negatives.

     For each positive a vector ``[pos, neg_1, ..., neg_k]`` is built and
     a hinge loss with target index 0 is accumulated; a single optimizer
     step is taken over the summed loss.  The graph is always renewed.
     """
     per_pos_losses = []
     for p in pos:
         candidates = [p]
         for n in neg:
             candidates.append(n)
         per_pos_losses.append(dy.hinge(dy.concatenate(candidates), 0))
     if per_pos_losses:
         total = dy.esum(per_pos_losses)
         total.scalar_value()  # force the forward pass before backward
         total.backward()
         self.trainer.update()
     dy.renew_cg()
Exemplo n.º 4
0
 def PairAppendToLoss(self, pos, neg, loss):
   """Append one hinge term for a (single positive, negatives) example.

   The positive score is placed first so that target index 0 is gold.
   """
   if len(pos) != 1:
     # diagnostic only; the loss term is still appended below
     print('ERROR IN POS EXAMPLE SIZE')
     print(len(pos))
   scores = dy.concatenate(pos + neg)
   loss.append(dy.hinge(scores, 0))
Exemplo n.º 5
0
# A function to calculate scores for one value
def calc_scores(tree):
    """Return the unnormalized label scores for a single tree."""
    dy.renew_cg()
    tree_emb = builder.expr_for_tree(tree)
    weights = dy.parameter(W_sm)
    bias = dy.parameter(b_sm)
    return weights * tree_emb + bias


for ITER in range(100):
    # --- training pass ---
    random.shuffle(train)
    train_loss = 0.0
    start = time.time()
    for tree in train:
        # multiclass hinge loss against the gold label index
        loss_expr = dy.hinge(calc_scores(tree), l2i[tree.label])
        train_loss += loss_expr.value()
        loss_expr.backward()
        trainer.update()
    print("iter %r: train loss/sent=%.4f, time=%.2fs" %
          (ITER, train_loss / len(train), time.time() - start))
    # --- evaluation pass ---
    test_correct = 0.0
    for tree in dev:
        predicted = np.argmax(calc_scores(tree).npvalue())
        if predicted == l2i[tree.label]:
            test_correct += 1
    print("iter %r: test acc=%.4f" % (ITER, test_correct / len(dev)))
Exemplo n.º 6
0
 def __allElemLoss(self, output, correctId):
     """Hinge loss of *output* against index *correctId*, using the configured margin."""
     return dynet.hinge(output, correctId, self.__trainMargin)
                #print(i)
                q_dpos_hist = train_data['pos_docs'][i]
                q_dneg_hist = train_data['neg_docs'][i]
                query_idf = train_data['queries_idf'][i]
                pos_bm25 = train_data['pos_docs_normBM25'][i]
                neg_bm25 = train_data['neg_docs_normBM25'][i]
                pos_overlap = train_data['pos_docs_overlap'][i]
                neg_overlap = train_data['neg_docs_overlap'][i]

                #print(query_idf)

                preds = drmm_model.predict_pos_neg_scores(
                    q_dpos_hist, q_dneg_hist, query_idf, pos_bm25, neg_bm25,
                    pos_overlap, neg_overlap)
                batch_preds.append(preds)
                loss = dy.hinge(preds, 0)
                batch_losses.append(loss)
            batch_loss = dy.esum(batch_losses) / len(batch)
            sum_of_losses += batch_loss.npvalue()[
                0]  # this calls forward on the batch
            for p in batch_preds:
                p_v = p.value()
                if p_v[0] > p_v[1]:
                    hits += 1
            batch_loss.backward()
            drmm_model.trainer.update()
            pbar.update(train_batch_size)

        logs['acc'] = hits / len(train_data['queries'])
        logs['loss'] = sum_of_losses / len(train_batches)
Exemplo n.º 8
0
b_sm = model.add_parameters((ntags))                  # Softmax bias

# A function to calculate scores for one value
def calc_scores(tree):
  """Compute unnormalized label scores for one tree on a fresh graph."""
  dy.renew_cg()
  tree_repr = builder.expr_for_tree(tree)
  weights = dy.parameter(W_sm)
  bias = dy.parameter(b_sm)
  return weights * tree_repr + bias

for ITER in range(100):
  # Perform training
  random.shuffle(train)
  train_loss = 0.0
  start = time.time()
  for tree in train:
    # multiclass hinge loss against the gold label index
    my_loss = dy.hinge(calc_scores(tree), l2i[tree.label])
    # my_loss = dy.pickneglogsoftmax(calc_scores(tree), l2i[tree.label])
    train_loss += my_loss.value()
    my_loss.backward()
    trainer.update()
  print("iter %r: train loss/sent=%.4f, time=%.2fs" % (ITER, train_loss/len(train), time.time()-start))
  # Perform testing
  test_correct = 0.0
  for tree in dev:
    # greedy prediction: argmax over the score vector
    scores = calc_scores(tree).npvalue()
    predict = np.argmax(scores)
    if predict == l2i[tree.label]:
      test_correct += 1
  print("iter %r: test acc=%.4f" % (ITER, test_correct/len(dev)))
Exemplo n.º 9
0
    def _train(self,
               sentences,
               transition_system,
               evaluate,
               relations,
               triggers=None):
        """Train the transition-based parser for one pass over *sentences*.

        At every parser step the legal transitions are scored, the best
        *correct* one is chosen (static oracle), and a hinge loss between
        it and all legal transitions is accumulated.  Accumulated losses
        are flushed (forward/backward/update + graph renewal) every 50
        terms and once more at the end.

        :param sentences: iterable of sentences (sequences of token objects)
        :param transition_system: callable building a parser state from a sentence
        :param evaluate: callable scoring ``(stack, buffer, features)``
        :param relations: list of arc labels
        :param triggers: optional trigger labels; enables the extended
            transition set when given
        """
        start_chunk = time.time()
        start_all = time.time()
        loss_chunk = 0
        loss_all = 0
        total_chunk = 0
        total_all = 0
        losses = []
        i = 0  # keeps the end-of-epoch report valid even for empty input
        self.set_empty_vector()

        for i, sentence in enumerate(sentences):
            if i != 0 and i % 100 == 0:
                end = time.time()
                print(
                    f'count: {i}\tloss: {loss_chunk/total_chunk:.4f}\ttime: {end-start_chunk:,.2f} secs'
                )
                start_chunk = end
                loss_chunk = 0
                total_chunk = 0
            if len(sentence) > 2:
                for e in sentence:
                    e.children = []
                # assign embedding to each word
                features = self.extract_features(sentence, drop_word=True)
                # initialize sentence parse
                state = transition_system(sentence)
                # parse sentence
                while not state.is_terminal():
                    outputs = evaluate(state.stack, state.buffer, features)

                    if triggers:
                        dy_op_scores, dy_lbl_scores, dy_tg_scores = outputs
                        np_tg_scores = dy_tg_scores.npvalue()
                    else:
                        dy_op_scores, dy_lbl_scores = outputs

                    # get scores in numpy arrays
                    np_op_scores = dy_op_scores.npvalue()
                    np_lbl_scores = dy_lbl_scores.npvalue()

                    # collect all legal transitions
                    legal_transitions = []
                    if triggers:
                        for lt in state.all_legal():
                            ix = state.t2i[lt]
                            if lt == "shift":
                                # NOTE(review): j starts at 2, so the j == 1
                                # guard below can never fire — confirm intent.
                                for j, tg in enumerate(triggers[1:], start=2):
                                    if (hasattr(state.buffer[0], 'is_parent')
                                            and state.buffer[0].is_parent
                                            and j == 1):
                                        continue
                                    t = new_Transition(
                                        lt, None, tg, np_op_scores[ix] +
                                        np_lbl_scores[0] + np_tg_scores[j],
                                        dy_op_scores[ix] + dy_lbl_scores[0] +
                                        dy_tg_scores[j])
                                    legal_transitions.append(t)
                            if lt == "drop":
                                t = new_Transition(
                                    lt, None, "O", np_op_scores[ix] +
                                    np_lbl_scores[0] + np_tg_scores[1],
                                    dy_op_scores[ix] + dy_lbl_scores[0] +
                                    dy_tg_scores[1])
                                legal_transitions.append(t)
                                t = new_Transition(
                                    lt, None, "Protein", np_op_scores[ix] +
                                    np_lbl_scores[0] + np_tg_scores[4],
                                    dy_op_scores[ix] + dy_lbl_scores[0] +
                                    dy_tg_scores[4])
                                legal_transitions.append(t)
                            if lt in ['left_reduce', 'left_attach']:
                                for j, r in enumerate(relations):
                                    k = 1 + 2 * j
                                    t = new_Transition(
                                        lt, r, None, np_op_scores[ix] +
                                        np_lbl_scores[k] + np_tg_scores[0],
                                        dy_op_scores[ix] + dy_lbl_scores[k] +
                                        dy_tg_scores[0])
                                    legal_transitions.append(t)
                            if lt in ['right_reduce', 'right_attach']:
                                for j, r in enumerate(relations):
                                    k = 2 + 2 * j
                                    t = new_Transition(
                                        lt, r, None, np_op_scores[ix] +
                                        np_lbl_scores[k] + np_tg_scores[0],
                                        dy_op_scores[ix] + dy_lbl_scores[k] +
                                        dy_tg_scores[0])
                                    legal_transitions.append(t)
                            if lt == "swap":
                                t = new_Transition(
                                    lt, None, None, np_op_scores[ix] +
                                    np_lbl_scores[0] + np_tg_scores[0],
                                    dy_op_scores[ix] + dy_lbl_scores[0] +
                                    dy_tg_scores[0])
                                legal_transitions.append(t)
                        # collect all correct transitions
                        correct_transitions = []
                        for t in legal_transitions:
                            if state.is_correct(t[0]):
                                relation = state.get_arc_label_for_transition(
                                    t[0])
                                label = state.get_token_label_for_transition(
                                    t[0])
                                if t[1] == relation and t[2] == label:
                                    correct_transitions.append(t)

                    else:
                        if state.is_legal('shift'):
                            ix = state.t2i['shift']
                            t = Transition('shift', None, None,
                                           np_op_scores[ix] + np_lbl_scores[0],
                                           dy_op_scores[ix] + dy_lbl_scores[0])
                            legal_transitions.append(t)
                        if state.is_legal('left_arc'):
                            ix = state.t2i['left_arc']
                            for j, r in enumerate(relations):
                                k = 1 + 2 * j
                                t = Transition(
                                    'left_arc', r, None,
                                    np_op_scores[ix] + np_lbl_scores[k],
                                    dy_op_scores[ix] + dy_lbl_scores[k])
                                legal_transitions.append(t)
                        if state.is_legal('right_arc'):
                            ix = state.t2i['right_arc']
                            for j, r in enumerate(relations):
                                k = 2 + 2 * j
                                t = Transition(
                                    'right_arc', r, None,
                                    np_op_scores[ix] + np_lbl_scores[k],
                                    dy_op_scores[ix] + dy_lbl_scores[k])
                                legal_transitions.append(t)
                        if state.is_legal('drop'):
                            ix = state.t2i['drop']
                            t = Transition('drop', None, None,
                                           np_op_scores[ix] + np_lbl_scores[0],
                                           dy_op_scores[ix] + dy_lbl_scores[0])
                            legal_transitions.append(t)
                        # collect all correct transitions
                        correct_transitions = []
                        for t in legal_transitions:
                            if state.is_correct(t):
                                if t.op in [
                                        'shift', 'drop'
                                ] or t.label in state.stack[-1].relation:
                                    correct_transitions.append(t)

                    # select transition (static oracle: best correct one)
                    best_correct = max(correct_transitions,
                                       key=attrgetter('score'))

                    i_correct = legal_transitions.index(best_correct)
                    legal_scores = dy.concatenate(
                        [t.dy_score for t in legal_transitions])
                    loss = dy.hinge(legal_scores, i_correct)
                    # loss = dy.pickneglogsoftmax(legal_scores, i_correct)
                    losses.append(loss)

                    # perform transition
                    selected = best_correct
                    state.perform_transition(selected.op, selected.label,
                                             selected.trigger)

            # process losses in chunks
            if len(losses) > 50:
                l = 0.0  # chunk loss; stays 0 if the update below fails
                try:
                    loss = dy.esum(losses)
                    l = loss.scalar_value()
                    loss.backward()
                    self.trainer.update()
                except Exception:
                    # best-effort: skip a chunk whose graph failed rather
                    # than abort the epoch (was a bare `except:`)
                    pass
                dy.renew_cg()
                self.set_empty_vector()
                losses = []
                loss_chunk += l
                loss_all += l
                total_chunk += 1
                total_all += 1

        # consider any remaining losses
        if len(losses) > 0:
            try:
                loss = dy.esum(losses)
                loss.scalar_value()
                loss.backward()
                self.trainer.update()
            except Exception:
                pass
            dy.renew_cg()
            self.set_empty_vector()

        end = time.time()
        print('\nend of epoch')
        if total_all:  # avoid ZeroDivisionError on empty/short input
            print(
                f'count: {i}\tloss: {loss_all/total_all:.4f}\ttime: {end-start_all:,.2f} secs'
            )
Exemplo n.º 10
0
    def train_one_step(self, sent):
        """Run one beam-search training step of the linearizers on *sent*.

        For every head token (top-down traversal) each enabled decoder
        ('l2r', 'r2l', 'h2d') is advanced while tracking the gold sequence;
        whenever the gold sequence is not ranked first in an agenda, a
        hinge loss between the gold score and the competitors is recorded.
        The per-token winners are merged into a sentence-level beam.

        Side effects: stores 'nbest_linearized_tokens' and a randomly
        chosen 'linearized_tokens' on *sent*.

        Returns a dict with timing, loss value/expression, and the
        head-domain total/correct counters.
        """
        domain_total = domain_correct = loss_value = 0
        t0 = time()
        errs = []

        self.encode(sent)
        sent_agenda = [SentSequence(sent)]

        for token in traverse_topdown(sent.root):
            all_agendas = []
            # training left-to-right
            if 'l2r' in self.args.lin_decoders:
                gold_seq = self.l2r_linearizer.init_seq(token)
                while not self.l2r_linearizer.finished(gold_seq):
                    agenda, gold_seq = self.l2r_linearizer.decode(
                        gold_seq, True)
                    all_agendas.append(agenda)

                    # hinge update only when gold is not already on top
                    if gold_seq is not agenda[0]:
                        scores = [gold_seq.score_expr] + [
                            seq.score_expr
                            for seq in agenda if seq is not gold_seq
                        ]
                        errs.append(dy.hinge(dy.concatenate(scores), 0))
            # right-to-left
            if 'r2l' in self.args.lin_decoders:
                gold_seq = self.r2l_linearizer.init_seq(token)
                while not self.r2l_linearizer.finished(gold_seq):
                    agenda, gold_seq = self.r2l_linearizer.decode(
                        gold_seq, True)
                    all_agendas.append(agenda)
                    if gold_seq is not agenda[0]:
                        scores = [gold_seq.score_expr] + [
                            seq.score_expr
                            for seq in agenda if seq is not gold_seq
                        ]
                        errs.append(dy.hinge(dy.concatenate(scores), 0))
            # head-to-dep
            if 'h2d' in self.args.lin_decoders:
                gold_seq = self.h2d_linearizer.init_seq(token)
                agenda = [gold_seq]

                if self.h2d_linearizer.finished(gold_seq):
                    all_agendas.append(agenda)
                else:
                    while not self.h2d_linearizer.finished(gold_seq):
                        agenda, gold_seq = self.h2d_linearizer.decode(
                            gold_seq, True)
                        all_agendas.append(agenda)
                        # update only against all incorrect sequences (exclude lower scoring gold seq)
                        if gold_seq is not agenda[0]:
                            scores = [gold_seq.score_expr] + [
                                seq.score_expr
                                for seq in agenda if not seq.correct
                            ]
                            errs.append(dy.hinge(dy.concatenate(scores), 0))

            # merge per-token winners into the sentence-level beam
            new_agenda = []
            best_seqs = self.vote_best_seq(sent, all_agendas,
                                           self.args.beam_size)
            for sent_seq in sent_agenda:
                for seq in best_seqs:
                    new_seq = sent_seq.append(seq)
                    new_agenda.append(new_seq)
            new_agenda.sort(key=lambda x: -x.score)
            sent_agenda = new_agenda[:self.args.beam_size]

            # NOTE(review): `agenda` here is whichever decoder ran last for
            # this token — confirm that is the intended accuracy measure.
            if token['deps']:
                domain_total += 1
                domain_correct += agenda[0].correct

        sent['nbest_linearized_tokens'] = [
            seq.get_sorted_tokens() for seq in sent_agenda
        ]
        # random sequence from the beam to give the downstream training set more realistic input
        sent['linearized_tokens'] = random.choice(
            sent['nbest_linearized_tokens'])

        # errs may be empty, in which case loss stays the plain int 0
        loss = dy.esum(errs) if errs else 0
        loss_value = loss.value() if loss else 0

        return {
            'time': time() - t0,
            'loss': loss_value,
            'loss_expr': loss,
            'total': domain_total,
            'correct': domain_correct
        }
Exemplo n.º 11
0
def train_DRMM(train_pairs,
               dev_data_pairs,
               dev_data,
               jobs_df,
               excluding_set,
               w2v_model,
               train_batch_size,
               n_epochs,
               mlp_layers=5,
               hidden_size=10,
               p=1):
    """
    This function trains the mlp of DRMM by
    :param train_pairs: triplets of (query, positive doc, negative doc) used for training the model
    :param dev_data_pairs:  triplets of (query, positive doc, negative doc) used to evaluate models performance
    :param dev_data: dict containing query_tokens, query idfs, 50 positive document ids, 50 negative document ids, 50 positive documents BM25 scores, 50 negative docs BM25 scores
    :param jobs_df: pandas dataframe containing jobs dataset
    :param excluding_set: tokens with small idf to exclude
    :param w2v_model:  gensim word2vec
    :param train_batch_size: batch size of data to backpropagate
    :param p: probability used to apply dropout
    :return: dictionary containing data for learning curve
    """

    #create an object of classs DRMM
    drmm_mod = drmm_model.DRMM(mlp_layers, hidden_size)

    #load pretrained weights of the MLP layer
    drmm_mod.load_weights("dataset/results/res_no_bm25/no_unigrams")

    train_size_list = np.arange(0, len(train_pairs) + 1, 9000)

    #dictionary containing data needed to construct a learning curve
    learning_curve_data = {
        "num_of_data": train_size_list[1:],
        "train_accuracy": [],
        "dev_pairs_accuracy": [],
        "map_on_test_set": []
    }

    query_doc_df = {}
    flag = False

    for t in range(len(train_size_list) - 1):
        print(train_size_list[t], train_size_list[t + 1])

        train_subset = train_pairs.iloc[train_size_list[t]:train_size_list[t +
                                                                           1]]
        print(len(train_subset))
        best_map = -1
        # defaults in case the early-stop branch below never fires
        best_train_accuracy = 0.0
        best_dev_accuracy = 0.0
        check_manually = None
        dev_accuracy_prev = 0.0
        for epoch in range(1, n_epochs + 1):
            print('\nEpoch: {0}/{1}'.format(epoch, n_epochs))
            sum_of_losses = 0
            train_subset = train_subset.sample(frac=1)
            train_batches = chunks(
                range(train_size_list[t], train_size_list[t + 1]),
                train_batch_size)
            hits = 0
            for batch in train_batches:
                dy.renew_cg()  # new computation graph
                batch_losses = []
                batch_preds = []
                for i in batch:
                    q_dpos_hist = train_subset.loc[i, 'pos_histogram']
                    q_dneg_hist = train_subset.loc[i, 'neg_histogram']
                    query_idf = train_subset.loc[i, 'query_idf']
                    pos_bm25 = train_subset.loc[i, 'pos_normBM25'][0]
                    neg_bm25 = train_subset.loc[i, 'neg_normBM25'][0]
                    pos_uni_overlap = train_subset.loc[
                        i, 'overlapping_unigrams_pos']
                    pos_bi_overlap = train_subset.loc[
                        i, 'overlapping_bigrams_pos']
                    pos_overlap_features = [pos_uni_overlap, pos_bi_overlap]
                    neg_uni_overlap = train_subset.loc[
                        i, 'overlapping_unigrams_neg']
                    neg_bi_overlap = train_subset.loc[
                        i, 'overlapping_bigrams_neg']
                    neg_overlap_features = [neg_uni_overlap, neg_bi_overlap]
                    preds = drmm_mod.predict_pos_neg_scores(
                        q_dpos_hist, q_dneg_hist, query_idf, pos_bm25,
                        neg_bm25, pos_overlap_features, neg_overlap_features,
                        p)
                    batch_preds.append(preds)
                    # pairwise hinge: positive score (index 0) must win
                    loss = dy.hinge(preds, 0)
                    batch_losses.append(loss)
                batch_loss = dy.esum(batch_losses) / len(batch)
                sum_of_losses += float(batch_loss.npvalue()[0])
                # bug fix: this loop variable used to be named `p`, which
                # clobbered the dropout parameter for every later batch
                for pred in batch_preds:
                    pred_v = pred.value()
                    if pred_v[0] > pred_v[1]:
                        hits += 1
                batch_loss.backward()
                drmm_mod.trainer.update()  # this calls forward on the batch

            train_acc = hits / train_subset.shape[0]

            val_preds = []
            val_losses = []
            hits = 0
            dy.renew_cg()
            for i, row in dev_data_pairs.iterrows():
                q_dpos_hist = dev_data_pairs.loc[i, 'pos_histogram']
                q_dneg_hist = dev_data_pairs.loc[i, 'neg_histogram']
                query_idf = dev_data_pairs.loc[i, 'query_idf']
                pos_bm25 = dev_data_pairs.loc[i, 'pos_normBM25'][0]
                neg_bm25 = dev_data_pairs.loc[i, 'neg_normBM25'][0]
                pos_uni_overlap = dev_data_pairs.loc[
                    i, 'overlapping_unigrams_pos']
                pos_bi_overlap = dev_data_pairs.loc[i,
                                                    'overlapping_bigrams_pos']
                pos_overlap_features = [pos_uni_overlap, pos_bi_overlap]
                neg_uni_overlap = dev_data_pairs.loc[
                    i, 'overlapping_unigrams_neg']
                neg_bi_overlap = dev_data_pairs.loc[i,
                                                    'overlapping_bigrams_neg']
                neg_overlap_features = [neg_uni_overlap, neg_bi_overlap]
                preds_dev = drmm_mod.predict_pos_neg_scores(
                    q_dpos_hist,
                    q_dneg_hist,
                    query_idf,
                    pos_bm25,
                    neg_bm25,
                    pos_overlap_features,
                    neg_overlap_features,
                    p=1)
                val_preds.append(preds_dev)
                loss = dy.hinge(preds_dev, 0)
                val_losses.append(loss)
            val_loss = dy.esum(val_losses)
            sum_of_losses += val_loss.npvalue()[
                0]  # this calls forward on the batch
            # same rename as above: must not shadow the `p` parameter
            for pred in val_preds:
                pred_v = pred.value()
                if pred_v[0] > pred_v[1]:
                    hits += 1

            dev_accuracy = hits / dev_data_pairs.shape[0]

            print('\nTraining acc: {0}'.format(train_acc))
            print('Dev acc: {0}'.format(dev_accuracy))

            if dev_accuracy - dev_accuracy_prev <= 0.008 or epoch == 50 or dev_accuracy > 0.9:
                print(dev_accuracy - dev_accuracy_prev)
                map_dev, query_doc_df, flag, check_manually = rerank(
                    dev_data, drmm_mod, jobs_df, excluding_set, w2v_model,
                    query_doc_df, flag)
                print("map_dev", map_dev)
                best_map = map_dev
                best_train_accuracy = train_acc
                best_dev_accuracy = dev_accuracy
                drmm_mod.dump_weights("dataset/results/res_dropout")
                break

            #if dev_accuracy >= 0.85: #early stop
            #    break

            dev_accuracy_prev = dev_accuracy

        learning_curve_data["train_accuracy"].append(best_train_accuracy)
        learning_curve_data["dev_pairs_accuracy"].append(best_dev_accuracy)
        learning_curve_data["map_on_test_set"].append(best_map)

        # check_manually only exists once rerank() has run at least once
        if check_manually is not None:
            save_dataframe(check_manually,
                           path="dataset/results/res_dropout/check_manually_dev" +
                           str(train_size_list[t + 1]))
    save_dataframe(learning_curve_data,
                   path="dataset/results/res_dropout/learning_curve_data_dict")

    return learning_curve_data
Exemplo n.º 12
0
def tune_DRMM(train_pairs,
              tuning_data_pairs,
              tuning_data,
              jobs_df,
              excluding_set,
              w2v_model,
              train_batch_size=128,
              n_epochs=10,
              mlp_layers=5,
              hidden_size=10):
    """
    This function is used to tune the hyperparameters hidden size and number of layers using a "tuning dataset"
    :return best number of mlp layers, best number of units per layer

    Note: the `mlp_layers` and `hidden_size` parameters are kept for
    backward compatibility; the grids below override them.
    """

    mlp_layers_list = [3, 5, 8]
    nodes_per_layer = [10, 20]
    metrics_dict = {
        "hidden_size": [],
        "mlp_layers": [],
        "train_accuracy": [],
        "dev_pairs_accuracy": [],
        "best_map": []
    }

    # rerank() state, threaded through successive calls; previously these
    # were used before assignment (NameError on the first rerank call)
    query_doc_data = {}
    flag = False

    for layer in tqdm(mlp_layers_list, position=1):
        for hidden_size in nodes_per_layer:
            # bug fix: build the model with the layer count under test
            # (was mistakenly the fixed `mlp_layers` argument, so the grid
            # search never actually varied the depth)
            drmm_mod = drmm_model.DRMM(layer, hidden_size)
            print("mlp_layers:", layer, "\n", "hidden_size:", hidden_size)
            metrics_dict["mlp_layers"].append(layer)
            metrics_dict["hidden_size"].append(hidden_size)
            dev_accuracy_prev = 0.0
            train_shuffled = train_pairs.copy()
            best_map = -1
            # defaults so the metrics below are defined even for n_epochs == 0
            train_acc = 0.0
            dev_accuracy = 0.0
            for epoch in range(1, n_epochs + 1):
                print('\nEpoch: {0}/{1}'.format(epoch, n_epochs))
                sum_of_losses = 0
                train_shuffled = train_shuffled.sample(frac=1)
                train_batches = chunks(range(len(train_shuffled['cand_id'])),
                                       train_batch_size)
                hits = 0
                for batch in train_batches:
                    dy.renew_cg()  # new computation graph
                    batch_losses = []
                    batch_preds = []
                    for i in batch:
                        q_dpos_hist = train_shuffled.loc[i, 'pos_histogram']
                        q_dneg_hist = train_shuffled.loc[i, 'neg_histogram']
                        query_idf = train_shuffled.loc[i, 'query_idf']
                        pos_bm25 = train_shuffled.loc[i, 'pos_normBM25'][0]
                        neg_bm25 = train_shuffled.loc[i, 'neg_normBM25'][0]
                        pos_uni_overlap = train_shuffled.loc[
                            i, 'overlapping_unigrams_pos']
                        pos_bi_overlap = train_shuffled.loc[
                            i, 'overlapping_bigrams_pos']
                        pos_overlap_features = [
                            pos_uni_overlap, pos_bi_overlap
                        ]
                        neg_uni_overlap = train_shuffled.loc[
                            i, 'overlapping_unigrams_neg']
                        neg_bi_overlap = train_shuffled.loc[
                            i, 'overlapping_bigrams_neg']
                        neg_overlap_features = [
                            neg_uni_overlap, neg_bi_overlap
                        ]
                        preds = drmm_mod.predict_pos_neg_scores(
                            q_dpos_hist, q_dneg_hist, query_idf, pos_bm25,
                            neg_bm25, pos_overlap_features,
                            neg_overlap_features)
                        batch_preds.append(preds)
                        # pairwise hinge: positive score (index 0) must win
                        loss = dy.hinge(preds, 0)
                        batch_losses.append(loss)
                    batch_loss = dy.esum(batch_losses) / len(batch)
                    sum_of_losses += float(batch_loss.npvalue()[0])
                    for pred in batch_preds:
                        pred_v = pred.value()
                        if pred_v[0] > pred_v[1]:
                            hits += 1
                    batch_loss.backward()
                    drmm_mod.trainer.update(
                    )  # this calls forward on the batch

                train_acc = hits / train_shuffled.shape[0]

                val_preds = []
                val_losses = []
                hits = 0
                dy.renew_cg()
                for i, row in tuning_data_pairs.iterrows():
                    q_dpos_hist = tuning_data_pairs.loc[i, 'pos_histogram']
                    q_dneg_hist = tuning_data_pairs.loc[i, 'neg_histogram']
                    query_idf = tuning_data_pairs.loc[i, 'query_idf']
                    pos_bm25 = tuning_data_pairs.loc[i, 'pos_normBM25'][0]
                    neg_bm25 = tuning_data_pairs.loc[i, 'neg_normBM25'][0]
                    pos_uni_overlap = tuning_data_pairs.loc[
                        i, 'overlapping_unigrams_pos']
                    pos_bi_overlap = tuning_data_pairs.loc[
                        i, 'overlapping_bigrams_pos']
                    pos_overlap_features = [pos_uni_overlap, pos_bi_overlap]
                    neg_uni_overlap = tuning_data_pairs.loc[
                        i, 'overlapping_unigrams_neg']
                    neg_bi_overlap = tuning_data_pairs.loc[
                        i, 'overlapping_bigrams_neg']
                    neg_overlap_features = [neg_uni_overlap, neg_bi_overlap]
                    preds_dev = drmm_mod.predict_pos_neg_scores(
                        q_dpos_hist, q_dneg_hist, query_idf, pos_bm25,
                        neg_bm25, pos_overlap_features, neg_overlap_features)
                    val_preds.append(preds_dev)
                    loss = dy.hinge(preds_dev, 0)
                    val_losses.append(loss)
                val_loss = dy.esum(val_losses)
                sum_of_losses += val_loss.npvalue()[
                    0]  # this calls forward on the batch
                for pred in val_preds:
                    pred_v = pred.value()
                    if pred_v[0] > pred_v[1]:
                        hits += 1

                dev_accuracy = hits / tuning_data_pairs.shape[0]

                print('Training acc: {0}'.format(train_acc))
                print('Dev acc: {0}'.format(dev_accuracy))

                # skip expensive reranking until training is somewhat fit
                if train_acc < 0.6:
                    continue

                map_dev, query_doc_data, flag = rerank(tuning_data, drmm_mod,
                                                       jobs_df, excluding_set,
                                                       w2v_model,
                                                       query_doc_data, flag)
                print("map_dev", map_dev)

                if map_dev > best_map:
                    print('===== Best epoch so far =====')
                    best_map = map_dev
                    best_epoch = epoch
                    best_train_accuracy = train_acc
                    best_dev_accuracy = dev_accuracy

            # record metrics BEFORE any early break, so every column of
            # metrics_dict keeps the same length (pd.DataFrame.from_dict
            # raises on ragged columns otherwise)
            metrics_dict["train_accuracy"].append(train_acc)
            metrics_dict["dev_pairs_accuracy"].append(dev_accuracy)
            metrics_dict["best_map"].append(best_map)

            if dev_accuracy - dev_accuracy_prev <= 0.005:
                break

            dev_accuracy_prev = dev_accuracy

    metrics_df = pd.DataFrame.from_dict(metrics_dict)
    max_map = metrics_df["best_map"].max()

    best_mlp_layer = metrics_df["mlp_layers"][metrics_df["best_map"] ==
                                              max_map].tolist()[0]
    best_hidden_size = metrics_df["hidden_size"][metrics_df["best_map"] ==
                                                 max_map].tolist()[0]
    save_dataframe(metrics_df, path="dataset/results/metrics_tuning")

    return best_mlp_layer, best_hidden_size