def main(_):
    """Publish flag-derived settings as module globals, then train and evaluate.

    The positional argument is ignored. ``num_domains``, ``layer_indexes``,
    ``domain_weight`` and ``do_sent_pair`` are assigned at module level from
    ``aFLAGS`` before the application is created. Readers are
    ``OdpsTableReader`` instances when the TensorFlow version string contains
    ``"PAI"`` and ``CSVReader`` instances otherwise.
    """
    global num_domains, layer_indexes, domain_weight, do_sent_pair
    # Both flags are comma-separated strings.
    num_domains = len(aFLAGS.domains.split(","))
    layer_indexes = list(map(int, aFLAGS.layer_indexes.split(",")))
    domain_weight = aFLAGS.domain_weight
    do_sent_pair = aFLAGS.do_sent_pair

    app = Application()

    # The two reader classes are called with identical keyword arguments, so
    # only the class itself depends on the runtime.
    reader_cls = OdpsTableReader if "PAI" in tf.__version__ else CSVReader

    training_data = reader_cls(input_glob=app.train_input_fp,
                               is_training=True,
                               input_schema=app.input_schema,
                               batch_size=app.train_batch_size)
    evaluation_data = reader_cls(input_glob=app.eval_input_fp,
                                 is_training=False,
                                 input_schema=app.input_schema,
                                 batch_size=app.eval_batch_size)

    app.run_train_and_evaluate(train_reader=training_data,
                               eval_reader=evaluation_data)
# Example #2
# 0
def main(_):
    """Dispatch on ``FLAGS.mode`` to on-the-fly training or prediction.

    The positional argument is ignored. ``"train_and_evaluate_on_the_fly"``
    runs ``app.run_train_and_evaluate``; ``"predict_on_the_fly"`` runs
    ``app.run_predict``. Any other mode only constructs the application.
    ODPS table readers/writers are used when the TensorFlow version string
    contains ``"PAI"``, CSV readers/writers otherwise.
    """
    app = Application()

    if FLAGS.mode == "train_and_evaluate_on_the_fly":
        # Same keyword arguments for both back ends; only the class differs.
        reader_cls = OdpsTableReader if "PAI" in tf.__version__ else CSVReader
        training_data = reader_cls(input_glob=app.train_input_fp,
                                   is_training=True,
                                   input_schema=app.input_schema,
                                   batch_size=app.train_batch_size)
        evaluation_data = reader_cls(input_glob=app.eval_input_fp,
                                     is_training=False,
                                     input_schema=app.input_schema,
                                     batch_size=app.eval_batch_size)
        app.run_train_and_evaluate(train_reader=training_data,
                                   eval_reader=evaluation_data)

    if FLAGS.mode == "predict_on_the_fly":
        running_on_pai = "PAI" in tf.__version__
        reader_cls = OdpsTableReader if running_on_pai else CSVReader
        prediction_input = reader_cls(input_glob=app.predict_input_fp,
                                      is_training=False,
                                      input_schema=app.input_schema,
                                      batch_size=app.predict_batch_size)

        # The writers do not share a signature, so they stay in an if/else.
        if running_on_pai:
            prediction_output = OdpsTableWriter(
                output_glob=app.predict_output_fp,
                output_schema=app.output_schema,
                slice_id=0,
                input_queue=None)
        else:
            prediction_output = CSVWriter(output_glob=app.predict_output_fp,
                                          output_schema=app.output_schema)

        app.run_predict(reader=prediction_input,
                        writer=prediction_output,
                        checkpoint_path=app.predict_checkpoint_path)
# Example #3
# 0
def main(_):
    """Train and evaluate with readers chosen by ``FLAGS.usePAI``.

    The positional argument is ignored. ``OdpsTableReader`` is used when
    ``FLAGS.usePAI`` is truthy and ``CSVReader`` otherwise; both receive the
    same keyword arguments.
    """
    app = Application()

    reader_cls = OdpsTableReader if FLAGS.usePAI else CSVReader

    training_data = reader_cls(input_glob=app.train_input_fp,
                               is_training=True,
                               input_schema=app.input_schema,
                               batch_size=app.train_batch_size)
    evaluation_data = reader_cls(input_glob=app.eval_input_fp,
                                 is_training=False,
                                 input_schema=app.input_schema,
                                 batch_size=app.eval_batch_size)

    app.run_train_and_evaluate(train_reader=training_data,
                               eval_reader=evaluation_data)
def main(_):
    """Run inference, build per-(domain, class) centroids and write weights.

    The positional argument is ignored. Steps, as performed below:

    1. Predict over ``app.predict_input_fp`` and collect every example's
       ``pool_output`` in a bucket keyed by domain + TAB + label.
    2. Average each bucket (``np.mean`` over axis 0) to obtain its centroid.
    3. Call ``compute_weight`` for every example and emit one row per
       example: the text column(s), the position of the example's domain in
       ``aFLAGS.domains`` (as a string), the label, and the weight rounded
       to 5 decimals.

    Rows go to ``FLAGS.outputs``: through ``tf.python_io.TableWriter`` when
    ``FLAGS.usePAI`` is truthy, otherwise as tab-separated lines in a local
    file. Sentence-pair mode (``aFLAGS.do_sent_pair``) has two text columns
    (``text1``, ``text2``); single-sentence mode has one (``text``).

    Raises:
        KeyError: If a predicted domain/label pair was not declared through
            ``aFLAGS.domains`` / ``aFLAGS.classes``.
    """
    # load domain and class dict
    domains = aFLAGS.domains.split(",")
    classes = aFLAGS.classes.split(",")

    # Read the mode flag once so that application choice, row collection and
    # output layout cannot disagree with each other.
    do_sent_pair = aFLAGS.do_sent_pair
    text_fields = ("text1", "text2") if do_sent_pair else ("text",)

    app = SentPairClsApplication() if do_sent_pair else SentClsApplication()

    # create empty centroids: one bucket per declared domain/class pair
    domain_class_embeddings = dict()
    for domain_name in domains:
        for class_name in classes:
            domain_class_embeddings[domain_name + "\t" + class_name] = list()

    # Both reader classes are called with the same keyword arguments.
    reader_cls = OdpsTableReader if FLAGS.usePAI else CSVReader
    predict_reader = reader_cls(input_glob=app.predict_input_fp,
                                is_training=False,
                                input_schema=app.input_schema,
                                batch_size=app.predict_batch_size)

    # do inference; each row is (text column(s)..., domain, label, embedding)
    temp_output_data = list()
    for output in app.run_predict(reader=predict_reader,
                                  checkpoint_path=app.predict_checkpoint_path):
        for i, pool_output in enumerate(output["pool_output"]):
            texts = tuple(output[field][i] for field in text_fields)
            domain = output["domain"][i]
            label = output["label"][i]
            # NOTE(review): domain/label are decoded only for the bucket key;
            # the undecoded values are what compute_weight, domains.index and
            # the file writer receive below. Presumably this relies on
            # Python 2 str semantics -- confirm before running on Python 3.
            key_name = domain.decode('utf-8') + "\t" + label.decode('utf-8')
            domain_class_embeddings[key_name].append(pool_output)
            temp_output_data.append(texts + (domain, label, pool_output))

    # compute centroids
    # A pair that received no examples averages an empty array, i.e. NaN.
    centroid_embeddings = dict()
    for key_name, embeddings in domain_class_embeddings.items():
        centroid_embeddings[key_name] = np.mean(np.array(embeddings), axis=0)

    # output rows for meta fine-tune, shared by both writers
    records = []
    for row in temp_output_data:
        domain, label, embeddings = row[-3:]
        weight = compute_weight(domain, label, embeddings,
                                centroid_embeddings)
        records.append(row[:-3] + (str(domains.index(domain)), label,
                                   np.around(weight, decimals=5)))

    if FLAGS.usePAI:
        # write odps tables; one column index per field of a record
        with tf.python_io.TableWriter(FLAGS.outputs) as writer:
            writer.write(records, list(range(len(text_fields) + 3)))
    else:
        # write to local file, one tab-separated line per record
        with open(FLAGS.outputs, 'w+') as f:
            for record in records:
                columns = record[:-1] + (record[-1].astype('str'),)
                f.write('\t'.join(columns) + '\n')
# Example #5
# 0
def _build_pretrain_reader(app, input_fp, batch_size, is_training):
    """Build the data reader selected by ``_APP_FLAGS.data_reader``.

    Args:
        app: Application object; only ``app.input_schema`` is read here.
        input_fp: Forwarded as the reader's ``input_glob``.
        batch_size: Forwarded as the reader's ``batch_size``.
        is_training: Forwarded to the reader unchanged.

    Returns:
        A ``BundleTFRecordReader`` for ``'tfrecord'`` or an
        ``OdpsTableReader`` for ``'odps'``.

    Raises:
        ValueError: If ``_APP_FLAGS.data_reader`` is neither of those values.
    """
    reader_type = _APP_FLAGS.data_reader

    if reader_type == 'tfrecord':
        return BundleTFRecordReader(input_glob=input_fp,
                                    is_training=is_training,
                                    shuffle_buffer_size=1024,
                                    worker_hosts=FLAGS.worker_hosts,
                                    task_index=FLAGS.task_index,
                                    input_schema=app.input_schema,
                                    batch_size=batch_size)

    if reader_type == 'odps':
        if is_training:
            tf.logging.info("***********Reading Odps Table *************")
        # The slice parameters are derived here, on every call, so that an
        # evaluation-only run does not depend on a training reader having
        # been built first.
        return OdpsTableReader(input_glob=input_fp,
                               is_training=is_training,
                               shuffle_buffer_size=1024,
                               input_schema=app.input_schema,
                               slice_id=FLAGS.task_index,
                               slice_count=len(FLAGS.worker_hosts.split(",")),
                               batch_size=batch_size)

    raise ValueError(
        "Unsupported data_reader: {!r} (expected 'tfrecord' or 'odps')".format(
            reader_type))


def _list_checkpoint_steps(model_dir):
    """Return the non-zero global steps listed in ``<model_dir>/checkpoint``.

    Every non-blank line is parsed as ``<key>: "<path>/model.ckpt-<step>"``;
    the integer ``<step>`` values are collected into a set and step 0 is
    dropped if present.
    """
    steps = set()
    with tf.gfile.GFile(os.path.join(model_dir, "checkpoint"),
                        mode='r') as reader:
        for line in reader:
            line = line.strip()
            if not line:
                continue
            # Split on the first colon only, so colons inside the path
            # (e.g. a URL scheme such as "oss://") stay part of the value.
            ckpt_path = line.split(":", 1)[1].strip().replace("\"", "")
            steps.add(
                int(ckpt_path.split("/")[-1].replace("model.ckpt-", "")))

    # discard() rather than remove(): a step-0 entry is not always listed.
    steps.discard(0)
    return steps


def _evaluate_checkpoints(app, eval_reader):
    """Evaluate every listed checkpoint and log masked-LM summaries.

    Checkpoints are visited in ascending step order. For each one,
    ``app.run_evaluate`` is called and its ``eval_masked_lm_accuracy`` and
    ``eval_masked_lm_loss`` values are written as summaries under
    ``<model_dir>/eval_output`` at the returned ``global_step``.
    """
    model_dir = app.config.model_dir
    steps = _list_checkpoint_steps(model_dir)

    writer = tf.summary.FileWriter(os.path.join(model_dir, "eval_output"))
    try:
        for step in sorted(steps):
            checkpoint_path = os.path.join(model_dir,
                                           "model.ckpt-" + str(step))
            tf.logging.info("checkpoint_path is {}".format(checkpoint_path))
            ret_metrics = app.run_evaluate(reader=eval_reader,
                                           checkpoint_path=checkpoint_path)

            global_step = ret_metrics['global_step']

            eval_masked_lm_accuracy = tf.Summary()
            eval_masked_lm_accuracy.value.add(
                tag='masked_lm_valid_accuracy',
                simple_value=ret_metrics['eval_masked_lm_accuracy'])

            eval_masked_lm_loss = tf.Summary()
            eval_masked_lm_loss.value.add(
                tag='masked_lm_valid_loss',
                simple_value=ret_metrics['eval_masked_lm_loss'])

            writer.add_summary(eval_masked_lm_accuracy, global_step)
            writer.add_summary(eval_masked_lm_loss, global_step)
    finally:
        # Close the event file even when an evaluation run raises.
        writer.close()


def run(mode):
    """Create the pretraining application and run it in the requested mode.

    When ``FLAGS.config`` is ``None`` the model directory is created if
    missing and the application is built from ``PretrainConfig()``. If
    additionally no pretrained model is named, a ``config.json`` assembled
    from the ``_APP_FLAGS`` values is written (unless one already exists)
    and the vocabulary file, plus the sentencepiece model if given, are
    copied into the model directory.

    Args:
        mode: ``"train"``, ``"train_and_evaluate"`` or ``"evaluate"``. Any
            other value builds the application (and a training reader when
            the string contains ``"train"``) but runs nothing.

    Raises:
        ValueError: If a reader is needed and ``_APP_FLAGS.data_reader`` is
            neither ``'tfrecord'`` nor ``'odps'``.
    """
    if _APP_FLAGS.do_multitaks_pretrain:
        app_cls = PretrainMultitask
    else:
        app_cls = Pretrain

    if FLAGS.config is None:
        config_json = {
            "model_type": _APP_FLAGS.model_type,
            "vocab_size": _APP_FLAGS.vocab_size,
            "hidden_size": _APP_FLAGS.hidden_size,
            "intermediate_size": _APP_FLAGS.intermediate_size,
            "num_hidden_layers": _APP_FLAGS.num_hidden_layers,
            "max_position_embeddings": 512,
            "num_attention_heads": _APP_FLAGS.num_attention_heads,
            "type_vocab_size": 2
        }

        if not tf.gfile.Exists(_APP_FLAGS.model_dir):
            tf.gfile.MkDir(_APP_FLAGS.model_dir)

        # Pretrain from scratch
        if _APP_FLAGS.pretrain_model_name_or_path is None:
            config_path = _APP_FLAGS.model_dir + "/config.json"
            # An existing config.json is left untouched.
            if not tf.gfile.Exists(config_path):
                with tf.gfile.GFile(config_path, mode='w') as f:
                    json.dump(config_json, f)

            shutil.copy2(_APP_FLAGS.vocab_fp, _APP_FLAGS.model_dir)
            if _APP_FLAGS.spm_model_fp is not None:
                shutil.copy2(_APP_FLAGS.spm_model_fp, _APP_FLAGS.model_dir)

        app = app_cls(user_defined_config=PretrainConfig())
    else:
        app = app_cls()

    # Substring test: true for both "train" and "train_and_evaluate".
    train_reader = None
    if "train" in mode:
        train_reader = _build_pretrain_reader(app,
                                              app.train_input_fp,
                                              app.train_batch_size,
                                              is_training=True)

    if mode == "train_and_evaluate":
        # Evaluation reads the eval input with the eval batch size for both
        # reader types.
        eval_reader = _build_pretrain_reader(app,
                                             app.eval_input_fp,
                                             app.eval_batch_size,
                                             is_training=False)
        app.run_train_and_evaluate(train_reader=train_reader,
                                   eval_reader=eval_reader)

    elif mode == "train":
        app.run_train(reader=train_reader)

    elif mode == "evaluate":
        eval_reader = _build_pretrain_reader(app,
                                             app.eval_input_fp,
                                             app.eval_batch_size,
                                             is_training=False)
        _evaluate_checkpoints(app, eval_reader)