def get_evaluation(self, sess, data_list, data_type, global_step=None):
        _logger.add()
        _logger.add('getting evaluation result')

        logits_list, loss_list, accu_list = [], [], []
        # run the whole split through the model in mini-batches and collect
        # the per-batch outputs
        for sample_batch, _, _, _ in Dataset.generate_batch_sample_iter(
                data_list):
            feed_dict = self.model.get_feed_dict(sample_batch, 'dev')
            logits, loss, accu = sess.run(
                [self.model.logits, self.model.loss, self.model.accuracy],
                feed_dict)
            logits_list.append(np.argmax(logits, -1))
            loss_list.append(loss)
            accu_list.append(accu)

        # aggregate: predictions per example, loss averaged over batches,
        # accuracy averaged over examples
        logits_array = np.concatenate(logits_list, 0)
        loss_value = np.mean(loss_list)
        accu_array = np.concatenate(accu_list, 0)
        accu_value = np.mean(accu_array)

        if global_step is not None:
            # write split-specific scalar summaries for TensorBoard
            if data_type == 'train':
                summary_feed_dict = {
                    self.train_loss: loss_value,
                    self.train_accuracy: accu_value,
                }
                summary = sess.run(self.train_summaries, summary_feed_dict)
                self.writer.add_summary(summary, global_step)
            elif data_type == 'dev':
                summary_feed_dict = {
                    self.dev_loss: loss_value,
                    self.dev_accuracy: accu_value,
                }
                summary = sess.run(self.dev_summaries, summary_feed_dict)
                self.writer.add_summary(summary, global_step)
            else:
                summary_feed_dict = {
                    self.test_loss: loss_value,
                    self.test_accuracy: accu_value,
                }
                summary = sess.run(self.test_summaries, summary_feed_dict)
                self.writer.add_summary(summary, global_step)
        return loss_value, accu_value
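
get_evaluation feeds scalar placeholders such as self.dev_loss / self.dev_accuracy and runs the merged summary ops (self.train_summaries, self.dev_summaries, self.test_summaries), none of which are shown in this excerpt. A minimal sketch of how those attributes could be built with the TF1 summary API; the helper name build_eval_summaries and the placeholder layout are assumptions, not the original implementation:

import tensorflow as tf

def build_eval_summaries(self):
    # Hypothetical helper: one loss/accuracy placeholder pair per data split,
    # merged into a single summary op that get_evaluation can feed.
    for prefix in ('train', 'dev', 'test'):
        loss_ph = tf.placeholder(tf.float32, [], name='%s_loss' % prefix)
        accu_ph = tf.placeholder(tf.float32, [], name='%s_accuracy' % prefix)
        setattr(self, '%s_loss' % prefix, loss_ph)
        setattr(self, '%s_accuracy' % prefix, accu_ph)
        merged = tf.summary.merge([
            tf.summary.scalar('%s_loss' % prefix, loss_ph),
            tf.summary.scalar('%s_accuracy' % prefix, accu_ph),
        ])
        setattr(self, '%s_summaries' % prefix, merged)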
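
Both snippets also rely on Dataset.generate_batch_sample_iter, which is not shown here. A standalone sketch of an iterator with a compatible interface follows; the explicit batch_size parameter and the shuffling behaviour are assumptions made for illustration:

import math
import random

def generate_batch_sample_iter(sample_list, max_step=None, batch_size=32):
    # Hypothetical stand-in for Dataset.generate_batch_sample_iter: yields
    # (sample_batch, batch_num, data_round, idx_b). Without max_step it makes
    # a single ordered pass; with max_step it cycles over shuffled epochs
    # until max_step batches have been produced.
    batch_num = int(math.ceil(len(sample_list) / float(batch_size)))
    if max_step is None:
        for idx_b in range(batch_num):
            yield (sample_list[idx_b * batch_size:(idx_b + 1) * batch_size],
                   batch_num, 0, idx_b)
        return
    step, data_round = 0, 0
    while step < max_step:
        shuffled = random.sample(sample_list, len(sample_list))
        for idx_b in range(batch_num):
            yield (shuffled[idx_b * batch_size:(idx_b + 1) * batch_size],
                   batch_num, data_round, idx_b)
            step += 1
            if step >= max_step:
                return
        data_round += 1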
def train():
    output_model_params()
    loadFile = True
    ifLoad, data = False, None
    if loadFile:
        # try to restore the preprocessed dataset from the pickle cache
        ifLoad, data = load_file(cfg.processed_path, 'data', 'pickle')
    if not ifLoad or not loadFile:
        # cache miss: build the dataset from the raw files and persist it
        data_object = Dataset(cfg.train_data_path, cfg.dev_data_path)
        data_object.save_dict(cfg.dict_path)
        save_file({'data_obj': data_object}, cfg.processed_path)
    else:
        data_object = data['data_obj']

    emb_mat_token, emb_mat_glove = data_object.emb_mat_token, data_object.emb_mat_glove

    with tf.variable_scope(network_type) as scope:
        if network_type in model_set:
            model = Model(emb_mat_token, emb_mat_glove,
                          len(data_object.dicts['token']),
                          len(data_object.dicts['char']),
                          data_object.max_lens['token'], scope.name)
        else:
            # fail early instead of hitting a NameError on `model` below
            raise ValueError('unknown network type: %s' % network_type)

    graphHandler = GraphHandler(model)
    evaluator = Evaluator(model)
    performRecoder = PerformRecoder(5)

    if cfg.gpu_mem < 1.:
        gpu_options = tf.GPUOptions(
            per_process_gpu_memory_fraction=cfg.gpu_mem, allow_growth=True)
    else:
        gpu_options = tf.GPUOptions()
    graph_config = tf.ConfigProto(gpu_options=gpu_options,
                                  allow_soft_placement=True)
    sess = tf.Session(config=graph_config)
    graphHandler.initialize(sess)

    # begin training
    steps_per_epoch = int(
        math.ceil(1.0 * len(data_object.digitized_train_data_list) /
                  cfg.train_batch_size))
    # fall back to cfg.num_steps when cfg.max_epoch is 0
    num_steps = steps_per_epoch * cfg.max_epoch or cfg.num_steps

    global_step = 0
    # debug or not
    if cfg.debug:
        sess = tf_debug.LocalCLIDebugWrapperSession(sess)

    for sample_batch, batch_num, data_round, idx_b in Dataset.generate_batch_sample_iter(
            data_object.digitized_train_data_list, num_steps):
        global_step = sess.run(model.global_step) + 1
        if_get_summary = global_step % (cfg.log_period or steps_per_epoch) == 0
        loss, summary, train_op = model.step(sess,
                                             sample_batch,
                                             get_summary=if_get_summary)

        if global_step % 10 == 0:
            _logger.add('data round: %d: %d/%d, global step: %d -- loss: %.4f' %
                        (data_round, idx_b, batch_num, global_step, loss))

        if if_get_summary:
            graphHandler.add_summary(summary, global_step)

        # Occasional evaluation
        if global_step % (cfg.eval_period or steps_per_epoch) == 0:
            # ---- dev ----
            dev_loss, dev_accu = evaluator.get_evaluation(
                sess, data_object.digitized_dev_data_list, 'dev', global_step)
            _logger.add('==> for dev, loss: %.4f, accuracy: %.4f' %
                        (dev_loss, dev_accu))
            # ---- test ----
            if cfg.test_data_name is not None:
                test_loss, test_accu = evaluator.get_evaluation(
                    sess, data_object.digitized_test_data_list, 'test',
                    global_step)
                _logger.add('~~> for test, loss: %.4f, accuracy: %.4f' %
                            (test_loss, test_accu))

            is_in_top, deleted_step = performRecoder.update_top_list(
                global_step, dev_accu, sess)
        this_epoch_time, mean_epoch_time = cfg.time_counter.update_data_round(
            data_round)
        if this_epoch_time is not None and mean_epoch_time is not None:
            _logger.add('##> this epoch time: %f, mean epoch time: %f' %
                        (this_epoch_time, mean_epoch_time))
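
train() also assumes load_file/save_file helpers for caching the processed Dataset object; they are not defined in this snippet. A minimal pickle-based sketch that matches the call sites above; the exact signatures and the (found, data) return convention are assumptions:

import os
import pickle

def save_file(data, file_path, data_name='data', mode='pickle'):
    # Hypothetical helper: serialize `data` to `file_path` with pickle.
    assert mode == 'pickle'
    with open(file_path, 'wb') as f:
        pickle.dump(data, f)

def load_file(file_path, data_name='data', mode='pickle'):
    # Hypothetical helper: return (found_flag, data); the flag is False when
    # the cache file does not exist yet.
    assert mode == 'pickle'
    if not os.path.isfile(file_path):
        return False, None
    with open(file_path, 'rb') as f:
        return True, pickle.load(f)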