def evaluate(sess, textRNN, x, y):
    batches = utils.batch_iter(list(zip(x, y)), FLAGS.batch_size, 1)
    num_batches_per_epoch = int((len(x) - 1) / FLAGS.batch_size) + 1
    loss, acc, curr_pre, curr_recall, counter = 0.0, 0.0, 0.0, 0.0, 0
    for batch in batches:
        x_batch, y_batch = zip(*batch)
        # Evaluation only: fetch loss/accuracy/predictions (not the train op)
        # and disable dropout by keeping all units.
        curr_loss, curr_acc, pred = sess.run(
            [textRNN.loss, textRNN.accuracy, textRNN.predictions],
            feed_dict={
                textRNN.input_x: x_batch,
                textRNN.input_y: y_batch,
                textRNN.dropout_keep_prob: 1.0
            })
        if FLAGS.class_weight == 0:
            y_true = [int(i) for i in y_batch]
        else:
            y_true = np.argmax(np.array([list(i) for i in y_batch]), axis=1)
        pre = precision_score(y_true, pred)
        recall = recall_score(y_true, pred)
        # Accumulate running sums (curr_pre/curr_recall) so the report below
        # averages over all batches seen so far.
        loss, counter, acc = loss + curr_loss, counter + 1, acc + curr_acc
        curr_pre, curr_recall = curr_pre + pre, curr_recall + recall
        if counter % num_batches_per_epoch == 0:
            log.info(
                "\t-\tBatch_size %d\t-\tTest Loss:%.3f\t-\tTest Accuracy:%.3f\t-\t"
                "Test Precision:%.3f\t-\tTest Recall:%.3f " %
                (counter, loss / float(counter), acc / float(counter),
                 curr_pre / float(counter), curr_recall / float(counter)))
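A note on the metric bookkeeping above: summing per-batch precision and recall and dividing by the batch count only approximates the corpus-level scores. A minimal alternative sketch (assuming sklearn is available and all predictions fit in memory) collects everything first and scores once:

import numpy as np
from sklearn.metrics import precision_score, recall_score

def score_all_batches(batches, run_batch):
    """run_batch(x, y) -> (y_true, y_pred) for one batch; score everything once."""
    all_true, all_pred = [], []
    for x_batch, y_batch in batches:
        y_true, y_pred = run_batch(x_batch, y_batch)
        all_true.extend(y_true)
        all_pred.extend(y_pred)
    all_true, all_pred = np.asarray(all_true), np.asarray(all_pred)
    return precision_score(all_true, all_pred), recall_score(all_true, all_pred)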
    def train_one_epoch(self):
        sum_loss = 0.0
        mrr = 0.0

        # train process
        batches = batch_iter(self.L, self.batch_size, 0, self.lookup, 'f', 'g')
        batch_id = 0
        for batch in batches:
            pos, neg = batch
            if (len(pos['f']) != len(pos['g'])
                    or len(neg['f']) != len(neg['g'])):
                self.logger.info(
                    'The input label file is malformed; skipping this batch.')
                continue
            batch_size = len(pos['f'])
            feed_dict = {
                self.pos_inputs['f']: self.X[pos['f'], :],
                self.pos_inputs['g']: self.Y[pos['g'], :],
                self.cur_batch_size: batch_size
            }
            _, cur_loss = self.sess.run([self.train_op, self.loss], feed_dict)

            sum_loss += cur_loss
            batch_id += 1

        # valid process
        valid_size = 0
        if self.valid:
            valid = valid_iter(self.L, self.valid_sample_size, self.lookup,
                               'f', 'g')
            if len(valid['f']) != len(valid['g']):
                self.logger.info(
                    'The input label file is malformed; aborting validation.')
                return
            valid_size = len(valid['f'])
            feed_dict = {
                self.valid_inputs['f']: self.X[valid['f'], :],
                self.valid_inputs['g']: self.Y[valid['g'], :]
            }
            valid_dist = self.sess.run(self.dot_dist, feed_dict)

            mrr = .0
            for i in range(valid_size):
                fst_dist = valid_dist[i][0]
                pos = 1
                for k in range(1, len(valid_dist[i])):
                    if fst_dist >= valid_dist[i][k]:
                        pos += 1
                mrr += 1. / pos
            self.logger.info(
                'Epoch={}, sum of loss={!s}, mrr in validation={}'.format(
                    self.cur_epoch, sum_loss / (batch_id + 1e-8),
                    mrr / (valid_size + 1e-8)))
        else:
            self.logger.info('Epoch={}, sum of loss={!s}'.format(
                self.cur_epoch, sum_loss / (batch_id + 1e-8)))
        self.cur_epoch += 1

        # print(batch_id,valid_size)
        return sum_loss / (batch_id + 1e-8), mrr / (valid_size + 1e-8)
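The validation loop above ranks the true match (column 0 of each row of valid_dist) against the sampled candidates and averages the reciprocal ranks. A standalone sketch of that computation, assuming a plain NumPy distance matrix whose first column holds the positive pair:

import numpy as np

def mean_reciprocal_rank(dist):
    """dist: (n, k) distances where dist[i, 0] is the true match for query i."""
    mrr = 0.0
    for row in np.asarray(dist):
        # Rank of the true match: 1 + number of candidates at least as close.
        rank = 1 + int(np.sum(row[0] >= row[1:]))
        mrr += 1.0 / rank
    return mrr / max(len(dist), 1)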
Example no. 3
    def train_one_epoch(self):
        sum_loss = 0.0

        # train process
        # with tf.device(self.device):
        batches = batch_iter(self.L, self.batch_size, 0, self.lookup_f,
                             self.lookup_g, 'f', 'g')
        batch_id = 0
        for batch in batches:
            pos_f, pos_g, neg_f, neg_g = batch
            if len(pos_f) != len(pos_g):
                self.logger.info(
                    'The input label file is malformed; skipping this batch.')
                continue
            batch_size = len(pos_f)
            feed_dict = {
                self.pos_f_inputs: self.X[pos_f, :],
                self.pos_g_inputs: self.Y[pos_g, :],
                self.cur_batch_size: batch_size
            }
            _, cur_loss = self.sess.run([self.train_op, self.loss], feed_dict)

            sum_loss += cur_loss
            # self.logger.info('Finish processing batch {} and cur_loss={}'
            #                        .format(batch_id, cur_loss))
            batch_id += 1
        # valid process
        valid_f, valid_g = valid_iter(self.L, self.valid_sample_size,
                                      self.lookup_f, self.lookup_g, 'f', 'g')
        # print valid_f,valid_g
        if len(valid_f) != len(valid_g):
            self.logger.info(
                'The input label file is malformed; aborting validation.')
            return
        valid_size = len(valid_f)
        feed_dict = {
            self.valid_f_inputs: self.X[valid_f, :],
            self.valid_g_inputs: self.Y[valid_g, :]
        }
        valid_dist = self.sess.run(self.dot_dist, feed_dict)
        # valid_dist = self.sess.run(self.hamming_dist,feed_dict)
        mrr = .0
        for i in range(valid_size):
            fst_dist = valid_dist[i][0]
            pos = 1
            for k in range(1, len(valid_dist[i])):
                if fst_dist >= valid_dist[i][k]:
                    pos += 1
            # print pos
            # self.logger.info('dist:{},pos:{}'.format(fst_dist,pos))
            # print valid_dist[i]
            mrr += 1. / pos
        self.logger.info('Epoch={}, sum of loss={!s}, mrr={}'.format(
            self.cur_epoch, sum_loss / batch_id, mrr / valid_size))
        # print 'mrr:',mrr/valid_size
        # self.logger.info('Epoch={}, sum of loss={!s}, valid_loss={}'
        #                     .format(self.cur_epoch, sum_loss/batch_id, valid_loss))
        self.cur_epoch += 1
Example no. 4
    def train(self, x_train, y_train, x_val, y_val):
        x_train = np.array(x_train)
        x_val = np.array(x_val)
        y_train = np.array(y_train)
        y_val = np.array(y_val)
        # logits = self.logits
        # keep_prob = self.keep_prob
        # fingerprint_input = self.fingerprint_input
        # ground_truth_input = self.ground_truth_input
        # evaluation_step = self.evaluation_step

        control_dependencies = []
        # Create the back propagation and training evaluation machinery in the graph.

        with tf.name_scope('train'), tf.control_dependencies(
                control_dependencies):
            learning_rate_input = tf.placeholder(tf.float32, [],
                                                 name='learning_rate_input')
            train_step = tf.train.AdamOptimizer(learning_rate_input,
                                                epsilon=1e-6).minimize(
                                                    self.cross_entropy_mean)

        global_step = tf.contrib.framework.get_or_create_global_step()
        increment_global_step = tf.assign(global_step, global_step + 1)

        best_acc_val = 0
        cur_step = 0
        last_improved_step = 0
        data_len = len(y_train)
        step_sum = (int(
            (data_len - 1) / config.batch_size) + 1) * config.epochs_num
        adjust_num = 0
        flag = True
        saver = tf.train.Saver(max_to_keep=1)
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            for epoch in range(config.epochs_num):
                for batch_x, batch_y in batch_iter(x_train, y_train,
                                                   config.batch_size):
                    # print(len(batch_y))
                    # print(len(y_train))
                    feed_dict = {
                        self.fingerprint_input: batch_x,
                        self.ground_truth_input: batch_y,
                        learning_rate_input: 0.001,
                        self.keep_prob: 0.5
                    }
                    fetches = [
                        self.cross_entropy_mean, self.evaluation_step,
                        train_step, increment_global_step
                    ]
                    loss_train, acc_train, _, _ = sess.run(fetches,
                                                           feed_dict=feed_dict)
                    cur_step += 1
                    if cur_step % config.print_per_batch == 0:
                        acc_val = self.evaluate(sess, x_val, y_val)
                        if acc_val >= best_acc_val:
                            best_acc_val = acc_val
                            last_improved_step = cur_step
                            saver.save(sess=sess,
                                       save_path=model_config.model_save_path)
                            improved_str = '*'
                        else:
                            improved_str = ''
                        cur_step_str = str(cur_step) + "/" + str(step_sum)
                        msg = 'The current step: {0:>6}, Train Loss: {1:>6.2}, Train Acc: {2:>7.2%},' \
                              + '  Val Acc: {3:>7.2%} {4}'
                        print(
                            msg.format(cur_step_str, loss_train, acc_train,
                                       acc_val, improved_str))
                    if cur_step - last_improved_step >= config.improvement_step:
                        last_improved_step = cur_step
                        print(
                            "No optimization for a long time, auto adjust learning_rate..."
                        )
                        adjust_num += 1
                        if adjust_num > 3:
                            print(
                                "No optimization for a long time, auto-stopping..."
                            )
                            flag = False
                    if not flag:
                        break
                if not flag:
                    break
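The plateau branch above prints "auto adjust learning_rate..." but keeps feeding the hard-coded 0.001 into learning_rate_input. A hedged sketch of what an actual adjustment could look like; adjusted_learning_rate is a hypothetical helper, not part of the original code:

def adjusted_learning_rate(base_lr, adjust_num, decay=0.5):
    """Hypothetical helper: shrink the rate each time validation stalls."""
    return base_lr * (decay ** adjust_num)

# Inside the training loop one would feed, e.g.,
#     learning_rate_input: adjusted_learning_rate(0.001, adjust_num)
# instead of the constant 0.001, incrementing adjust_num on each plateau.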
Example no. 5
    def train_one_epoch(self):
        sum_loss = 0.0
        mrr = 0.0

        # train process
        # print 'start training...'
        batches = batch_iter(self.L, self.batch_size, self.neg_ratio,
                             self.lookup, 'src', 'end')

        batch_id = 0
        for batch in batches:
            # training the process from source network to end network
            pos, neg = batch
            if (len(pos['src']) != len(pos['end'])
                    or len(neg['src']) != len(neg['end'])):
                self.logger.info(
                    'The input label file is malformed; skipping this batch.')
                continue
            batch_size = len(pos['src'])
            feed_dict = {
                self.inputs_pos['src']: self.F[pos['src'], :],
                self.inputs_pos['end']: self.G[pos['end'], :],
                self.inputs_neg['src']: self.F[neg['src'], :],
                self.inputs_neg['end']: self.G[neg['end'], :],
                self.cur_batch_size: batch_size
            }
            _, cur_loss = self.sess.run([self.train_op, self.loss], feed_dict)

            sum_loss += cur_loss
            batch_id += 1

        if self.valid:
            # valid process
            valid = valid_iter(self.L, self.valid_sample_size, self.lookup,
                               'src', 'end')
            # print valid_f,valid_g
            if len(valid['src']) != len(valid['end']):
                self.logger.info(
                    'The input label file is malformed; aborting validation.')
                return
            valid_size = len(valid['src'])
            feed_dict = {
                self.inputs_val['src']: self.F[valid['src'], :],
                self.inputs_val['end']: self.G[valid['end'], :],
            }
            # valid_dist = self.sess.run(self.dot_dist,feed_dict)
            valid_dist = self.sess.run(self.hamming_dist, feed_dict)
            mrr = .0
            for i in range(valid_size):
                fst_dist = valid_dist[i][0]
                pos = 1
                for k in range(1, len(valid_dist[i])):
                    if fst_dist >= valid_dist[i][k]:
                        pos += 1
                # print pos
                # self.logger.info('dist:{},pos:{}'.format(fst_dist,pos))
                # print valid_dist[i]
                mrr += 1. / pos
            self.logger.info('Epoch={}, sum of loss={!s}, mrr={}'.format(
                self.cur_epoch, sum_loss / batch_id / 2, mrr / valid_size))
        else:
            self.logger.info('Epoch={}, sum of loss={!s}'.format(
                self.cur_epoch, sum_loss / batch_id / 2))

        self.cur_epoch += 1

        # print(sum_loss/(batch_id+1e-8), mrr/(valid_size+1e-8))
        return sum_loss / (batch_id + 1e-8), mrr / (valid_size + 1e-8)
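The graph ops behind self.dot_dist and self.hamming_dist are not part of this excerpt. A minimal NumPy sketch of the two distances for a batch of query embeddings against candidate embeddings, as an assumption about what those ops compute:

import numpy as np

def dot_distance(queries, candidates):
    """Negative inner product: smaller values mean more similar pairs."""
    return -queries @ candidates.T

def hamming_distance(queries, candidates):
    """Hamming distance between sign-binarized embeddings."""
    q, c = np.sign(queries), np.sign(candidates)
    # Codes agree where the elementwise product is +1 and differ where it is -1.
    return 0.5 * (q.shape[1] - q @ c.T)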
def train(x_train, y_train, embedding_weights, x_dev, y_dev):
    # Training
    # ==================================================
    # Create session
    session_conf = tf.ConfigProto()
    sess = tf.Session(config=session_conf)
    with sess.as_default():
        textRNN = TextRNN(num_classes=FLAGS.num_classes,
                          learning_rate=FLAGS.learning_rate,
                          decay_steps=FLAGS.decay_steps,
                          decay_rate=FLAGS.decay_rate,
                          sequence_length=FLAGS.sequence_length,
                          vocab_size=embedding_weights.shape[0],
                          embed_size=FLAGS.embed_size,
                          is_trace_train=FLAGS.is_trace_train,
                          class_weight=FLAGS.class_weight)

        # whether or not to keep a training trace (summaries)
        saver = tf.train.Saver(max_to_keep=FLAGS.num_epochs)
        if FLAGS.is_trace_train:
            if not os.path.exists(os.path.join(FLAGS.ckpt_dir, "tf_log")):
                os.makedirs(os.path.join(FLAGS.ckpt_dir, "tf_log"))
            train_summary_dir = os.path.join(FLAGS.ckpt_dir, "tf_log")
            train_summary_writer = tf.summary.FileWriter(
                train_summary_dir, sess.graph)
        checkpoint_dir = os.path.abspath(
            os.path.join(FLAGS.ckpt_dir, "checkpoint"))
        if not os.path.exists(checkpoint_dir):
            os.makedirs(checkpoint_dir)
        # Initialize variables first, then restore, so that restored weights
        # are not overwritten by the initializer.
        print('Initializing Variables')
        sess.run(tf.global_variables_initializer())
        if tf.train.latest_checkpoint(checkpoint_dir):
            print("Restoring Variables from Checkpoint for rnn model.")
            saver.restore(sess, tf.train.latest_checkpoint(checkpoint_dir))
        # embedding layer
        assign_pretrained_word_embedding(sess, embedding_weights, textRNN)
        # 3.feed data & training & Generate batches
        batches = utils.batch_iter(list(zip(x_train, y_train)),
                                   FLAGS.batch_size, FLAGS.num_epochs)
        # Training loop. For each batch...
        time_str = datetime.datetime.now().isoformat()
        num_batches_per_epoch = int((len(x_train) - 1) / FLAGS.batch_size) + 1
        loss, acc, curr_pre, curr_recall, counter = 0.0, 0.0, 0.0, 0.0, 0
        num = int(x_train.shape[0])
        print('num_batches_per_epoch:', num_batches_per_epoch)
        for batch in batches:
            x_batch, y_batch = zip(*batch)
            x_batch, y_batch = np.array([list(i) for i in x_batch]), np.array(
                [list(i) for i in y_batch])
            if counter == 0:
                print("train batch y shape:", y_batch.shape)

            if FLAGS.is_trace_train:
                curr_loss, curr_acc, pred, _, summaries, step = sess.run(
                    [
                        textRNN.loss, textRNN.accuracy, textRNN.predictions,
                        textRNN.train, textRNN.train_summary_op,
                        textRNN.global_step
                    ],
                    feed_dict={
                        textRNN.input_x: x_batch,
                        textRNN.input_y: y_batch,
                        textRNN.dropout_keep_prob: 0.1
                    })
                train_summary_writer.add_summary(summaries, step)
                print('global_step:', step)
            else:
                curr_loss, curr_acc, pred, _, step = sess.run(
                    [
                        textRNN.loss, textRNN.accuracy, textRNN.predictions,
                        textRNN.train, textRNN.global_step
                    ],
                    feed_dict={
                        textRNN.input_x: x_batch,
                        textRNN.input_y: y_batch,
                        textRNN.dropout_keep_prob: 0.1
                    })
                print('global_step:', step)

            if FLAGS.class_weight == 0:
                y_true = [int(i) for i in y_batch]
            else:
                y_true = np.argmax(np.array([list(i) for i in y_batch]),
                                   axis=1)
            pre = precision_score(y_true, pred)
            recall = recall_score(y_true, pred)
            # Accumulate running sums (curr_pre/curr_recall) so the report
            # below averages over all batches seen so far.
            loss, counter, acc = loss + curr_loss, counter + 1, acc + curr_acc
            curr_pre, curr_recall = curr_pre + pre, curr_recall + recall
            if counter % 10 == 0:
                print(
                    time_str,
                    " \t-\tBatch_size %d/%d \t-\tTrain Loss:%.3f \t-\tTrain Accuracy:%.3f"
                    "\t-\tTrain Precision:%.3f \t-\tTrain Recall:%.3f" %
                    (counter, num, loss / float(counter), acc / float(counter),
                     curr_pre / float(counter), curr_recall / float(counter)))
            if counter % num_batches_per_epoch == 0:
                # 4. Evaluate on the dev/test set and report accuracy
                evaluate(sess, textRNN, x_dev, y_dev)
                # 5.save model to checkpoint
                epoch = int(counter / num_batches_per_epoch)
                log.info('Epoch: {:d}'.format(epoch))
                print('save model to checkpoint')
                save_path = FLAGS.ckpt_dir + 'checkpoint/' + str(
                    epoch) + "-model.ckpt"
                saver.save(sess, save_path, global_step=counter)

                # convert_variables_to_constants needs output_node_names as a list; multiple names are allowed
                constant_graph = graph_util.convert_variables_to_constants(
                    sess, sess.graph_def, ['input_x', 'predictions'])
                # The node names must match the graph's input/output ops above.
                # Write the serialized PB file; it is named <epoch>-model.pb.
                with tf.gfile.FastGFile(FLAGS.ckpt_dir + 'checkpoint/' +
                                        str(epoch) + '-model.pb',
                                        mode='wb') as f:
                    f.write(constant_graph.SerializeToString())
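The frozen .pb written above can later be loaded without rebuilding the Python model. A minimal TF1-style sketch; the tensor names follow the 'input_x' and 'predictions' output node names used above, and any other placeholder baked into the graph (e.g. the dropout keep probability) would also need to be fed:

import tensorflow as tf

def load_frozen_graph(pb_path):
    """Parse a frozen GraphDef and import it into a fresh graph."""
    with tf.gfile.GFile(pb_path, 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
    graph = tf.Graph()
    with graph.as_default():
        tf.import_graph_def(graph_def, name='')
    return graph

# Usage sketch:
# graph = load_frozen_graph(FLAGS.ckpt_dir + 'checkpoint/1-model.pb')
# with tf.Session(graph=graph) as sess:
#     preds = sess.run('predictions:0', feed_dict={'input_x:0': x_batch})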
Example no. 7
    def train(self, x_train, y_train, x_val, y_val):
        x_train = np.array(x_train)
        x_val = np.array(x_val)
        y_train = np.array(y_train)
        y_val = np.array(y_val)
        control_dependencies = []
        # Create the back propagation and training evaluation machinery in the graph.

        with tf.name_scope('train'), tf.control_dependencies(
                control_dependencies):
            learning_rate_input = tf.placeholder(tf.float32, [],
                                                 name='learning_rate_input')
            train_step = tf.train.AdamOptimizer(learning_rate_input,
                                                epsilon=1e-6).minimize(
                                                    self.cross_entropy_mean)

        global_step = tf.contrib.framework.get_or_create_global_step()
        increment_global_step = tf.assign(global_step, global_step + 1)

        best_val_score = 0
        last_improved_epoch = 0
        data_len = len(y_train)
        each_epoch_step_sum = int((data_len - 1) / config.batch_size) + 1
        saver = tf.train.Saver(max_to_keep=1)
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            for cur_epoch in range(config.epochs_num):
                cur_step = 0
                train_score = 0
                start_time = int(time.time())
                for batch_x, batch_y in batch_iter(x_train, y_train,
                                                   config.batch_size):
                    # print(batch_x.shape)
                    feed_dict = {
                        self.fingerprint_input: batch_x,
                        self.ground_truth_input: batch_y,
                        learning_rate_input: 0.001,
                        self.keep_prob: 0.5
                    }
                    fetches = [
                        self.cross_entropy_mean, self.evaluation_step,
                        train_step, increment_global_step
                    ]
                    train_loss, _train_score, _, _ = sess.run(
                        fetches, feed_dict=feed_dict)
                    cur_step += 1
                    train_score += _train_score
                    if cur_step % config.train_print_step == 0:
                        msg = 'the current step: {0}/{1}, train score: {2:>6.2%}'
                        print(
                            msg.format(cur_step, each_epoch_step_sum,
                                       train_score / config.train_print_step))
                        train_score = 0
                val_score = self.evaluate(sess, x_val, y_val)
                if val_score >= best_val_score:
                    best_val_score = val_score
                    saver.save(sess=sess,
                               save_path=model_config.model_save_path)
                    improved_str = '*'
                    last_improved_epoch = cur_epoch
                else:
                    improved_str = ''
                msg = 'the current epoch: {0}/{1}, val acc: {2:>6.2%}, cost: {3}s {4}'
                end_time = int(time.time())
                print(
                    msg.format(cur_epoch + 1, config.epochs_num, val_score,
                               end_time - start_time, improved_str))
                if cur_epoch - last_improved_epoch >= config.patience_epoch:
                    print("No optimization for a long time, auto stopping...")
                    break
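The self.evaluate helper called above is not shown in this excerpt. A hypothetical sketch of a batched accuracy evaluation, reusing the batch_iter helper and the attribute names from the snippet and assuming evaluation_step returns the batch accuracy:

def evaluate_accuracy(sess, model, x_val, y_val, batch_size):
    """Hypothetical sketch: average the accuracy op over validation batches."""
    total_acc, n_batches = 0.0, 0
    for batch_x, batch_y in batch_iter(x_val, y_val, batch_size):
        feed_dict = {
            model.fingerprint_input: batch_x,
            model.ground_truth_input: batch_y,
            model.keep_prob: 1.0  # keep all units at evaluation time
        }
        total_acc += sess.run(model.evaluation_step, feed_dict=feed_dict)
        n_batches += 1
    return total_acc / max(n_batches, 1)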
Example no. 8
    def train_one_epoch(self):
        sum_loss = 0.0

        # train process
        batches_f2g = list(
            batch_iter(self.L, self.batch_size, self.neg_ratio, self.lookup_f,
                       self.lookup_g, 'f', 'g'))
        batches_g2f = list(
            batch_iter(self.L, self.batch_size, self.neg_ratio, self.lookup_g,
                       self.lookup_f, 'g', 'f'))
        n_batches = min(len(batches_f2g), len(batches_g2f))
        batch_id = 0
        for i in range(n_batches):
            # training the process from network f to network g
            pos_src_f2g, pos_obj_f2g, neg_src_f2g, neg_obj_f2g = batches_f2g[i]
            if (len(pos_src_f2g) != len(pos_obj_f2g)
                    or len(neg_src_f2g) != len(neg_obj_f2g)):
                self.logger.info(
                    'The input label file is malformed; skipping this batch.')
                continue
            batch_size_f2g = len(pos_src_f2g)
            feed_dict = {
                self.pos_src_inputs: self.F[pos_src_f2g, :],
                self.pos_obj_inputs: self.G[pos_obj_f2g, :],
                self.neg_src_inputs: self.F[neg_src_f2g, :],
                self.neg_obj_inputs: self.G[neg_obj_f2g, :],
                self.cur_batch_size: batch_size_f2g
            }
            _, cur_loss_f2g = self.sess.run([self.train_op_f2g, self.loss_f2g],
                                            feed_dict)

            sum_loss += cur_loss_f2g

            # training the process from network g to network f
            pos_src_g2f, pos_obj_g2f, neg_src_g2f, neg_obj_g2f = batches_g2f[i]
            if (len(pos_src_g2f) != len(pos_obj_g2f)
                    or len(neg_src_g2f) != len(neg_obj_g2f)):
                self.logger.info(
                    'The input label file is malformed; skipping this batch.')
                continue
            batch_size_g2f = len(pos_src_g2f)
            feed_dict = {
                self.pos_src_inputs: self.G[pos_src_g2f, :],
                self.pos_obj_inputs: self.F[pos_obj_g2f, :],
                self.neg_src_inputs: self.G[neg_src_g2f, :],
                self.neg_obj_inputs: self.F[neg_obj_g2f, :],
                self.cur_batch_size: batch_size_g2f
            }
            _, cur_loss_g2f = self.sess.run([self.train_op_g2f, self.loss_g2f],
                                            feed_dict)

            sum_loss += cur_loss_g2f

            batch_id += 1

        # valid process
        valid_f, valid_g = valid_iter(self.L, self.valid_sample_size,
                                      self.lookup_f, self.lookup_g, 'f', 'g')
        # print valid_f,valid_g
        if len(valid_f) != len(valid_g):
            self.logger.info(
                'The input label file is malformed; aborting validation.')
            return
        valid_size = len(valid_f)
        feed_dict = {
            self.valid_f_inputs: self.F[valid_f, :],
            self.valid_g_inputs: self.G[valid_g, :],
        }
        # valid_dist = self.sess.run(self.dot_dist,feed_dict)
        valid_dist = self.sess.run(self.hamming_dist, feed_dict)
        mrr = .0
        for i in range(valid_size):
            fst_dist = valid_dist[i][0]
            pos = 1
            for k in range(1, len(valid_dist[i])):
                if fst_dist >= valid_dist[i][k]:
                    pos += 1
            # print pos
            # self.logger.info('dist:{},pos:{}'.format(fst_dist,pos))
            # print valid_dist[i]
            mrr += 1. / pos
        self.logger.info('Epoch={}, sum of loss={!s}, mrr={}'.format(
            self.cur_epoch, sum_loss / batch_id / 2, mrr / valid_size))
        # print 'mrr:',mrr/valid_size
        # self.logger.info('Epoch={}, sum of loss={!s}, valid_loss={}'
        #                     .format(self.cur_epoch, sum_loss/batch_id, valid_loss))
        self.cur_epoch += 1
Example no. 9
def train(train_x, train_y1, train_y2, test_x1, test_x2, test_y1, test_y2,
          word_dict, args):
    with tf.Session() as sess:
        model = Model(len(word_dict), num_class=2, args=args, vocab=word_dict)

        # Define training procedure
        global_step = tf.Variable(0, trainable=False)
        params = tf.trainable_variables()
        gradients = tf.gradients(model.Loss, params)
        clipped_gradients, _ = tf.clip_by_global_norm(gradients, 5.0)
        optimizer = tf.train.AdamOptimizer(0.001)
        train_op = optimizer.apply_gradients(zip(clipped_gradients, params),
                                             global_step=global_step)

        # Summary
        y1_loss_summary = tf.summary.scalar("y1_loss", model.subtaska_loss)
        y2_loss_summary = tf.summary.scalar("y2_loss", model.subtaskb_loss)
        Loss_summary = tf.summary.scalar("Loss", model.Loss)
        summary_op = tf.summary.merge_all()
        summary_writer = tf.summary.FileWriter("summary", sess.graph)

        # Initialize all variables
        sess.run(tf.global_variables_initializer())

        def train_step(batch_x, batch_y1, batch_y2):
            feed_dict = {
                model.X: batch_x,
                model.Y1: batch_y1,
                model.Y2: batch_y2,
                model.dropout: args.dropout
            }
            _, step, summaries, Loss, y1_loss, y2_loss = \
                sess.run([train_op, global_step, summary_op, model.Loss, model.subtaska_loss, model.subtaskb_loss],
                         feed_dict=feed_dict)
            summary_writer.add_summary(summaries, step)

            if step % 10 == 0:
                print("step {0}: loss={1} (y1_loss={2}, y2_loss={3})".format(
                    step, Loss, y1_loss, y2_loss))

        def evalA(test_x, test_y):

            test_batches = batch_iter_eval(test_x, test_y, args.batch_size)
            lossesA, accuraciesA, itersA, f1sA = 0, 0, 0, 0
            pred = []
            for batch_x, batch_y in test_batches:
                # NOTE: this feeds the training dropout value at evaluation
                # time; dropout is usually disabled when evaluating.
                feed_dict = {
                    model.X: batch_x,
                    model.Y1: batch_y,
                    model.Y2: batch_y,
                    model.dropout: args.dropout
                }
                y_lossA, accuracyA, f1A, preds = sess.run([
                    model.subtaska_loss, model.suba_accuracy, model.f1_suba,
                    model.subtaska_predictions
                ],
                                                          feed_dict=feed_dict)
                lossesA += y_lossA
                accuraciesA += accuracyA
                f1sA += f1A
                itersA += 1
                pred = np.concatenate((pred, preds))
            print("test perplexity = {0}".format(np.exp(lossesA / itersA)))
            print("Test Accuracy = {0}".format(accuraciesA / itersA))
            print("Test F1 = {0}\n".format(
                f1_score(test_y, pred, average='macro')))

        def evalB(test_x, test_y):
            test_batches = batch_iter_eval(test_x, test_y, args.batch_size)
            lossesA, accuraciesA, itersA, f1sA, preS, recA = 0, 0, 0, 0, 0, 0
            pred = []
            for batch_x, batch_y in test_batches:
                feed_dict = {
                    model.X: batch_x,
                    model.Y1: batch_y,
                    model.Y2: batch_y,
                    model.dropout: args.dropout
                }
                y_lossA, accuracyA, f1A, prec, recall, preds = sess.run(
                    [
                        model.subtaskb_loss, model.subb_accuracy,
                        model.f1_subb, model.precision, model.recall,
                        model.subtaskb_predictions
                    ],
                    feed_dict=feed_dict)
                lossesA += y_lossA
                accuraciesA += accuracyA
                f1sA += f1A
                itersA += 1
                pred = np.concatenate((pred, preds))
            print("test perplexity = {0}".format(np.exp(lossesA / itersA)))
            print("Test Accuracy = {0}".format(accuraciesA / itersA))
            # print("Test F1 = {0}\n".format(f1sA/itersA))
            print("Test F1 = {0}\n".format(
                f1_score(test_y, pred, average='macro')))

            # import pdb;pdb.set_trace()

        batches = batch_iter(train_x, train_y1, train_y2, args.batch_size,
                             args.num_epochs)
        for batch_x, batch_y1, batch_y2 in batches:
            train_step(batch_x, batch_y1, batch_y2)
        print("\n-------------Training Ended--------------")
        print("Subtask A (Glove):")
        evalA(test_x1, test_y1)
        print("Subtask B (Glove):")
        evalB(test_x2, test_y2)
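The batch_iter and batch_iter_eval helpers used throughout these examples are not included here, and their signatures differ between examples. A minimal sketch of a shuffling generator matching the (x, y1, y2, batch_size, num_epochs) signature used in this last example:

import numpy as np

def batch_iter(x, y1, y2, batch_size, num_epochs, shuffle=True):
    """Yield (batch_x, batch_y1, batch_y2) for num_epochs passes over the data."""
    x, y1, y2 = np.asarray(x), np.asarray(y1), np.asarray(y2)
    data_size = len(x)
    num_batches = int(np.ceil(data_size / float(batch_size)))
    for _ in range(num_epochs):
        idx = np.random.permutation(data_size) if shuffle else np.arange(data_size)
        for b in range(num_batches):
            sel = idx[b * batch_size:(b + 1) * batch_size]
            yield x[sel], y1[sel], y2[sel]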