Example #1
def evaluate_and_log_bleu(estimator, bleu_source, bleu_ref, vocab_file):
  """Calculate and record the BLEU score."""
  subtokenizer = tokenizer.Subtokenizer(vocab_file)
  tf.logging.info(subtokenizer.subtoken_list)
  uncased_score, cased_score = translate_and_compute_bleu(
      estimator, subtokenizer, bleu_source, bleu_ref)

  tf.logging.info("Bleu score (uncased): %f", uncased_score)
  tf.logging.info("Bleu score (cased): %f", cased_score)
  return uncased_score, cased_score
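
A minimal sketch of how this helper might be invoked, assuming a trained tf.estimator.Estimator and the module's translate_and_compute_bleu are available; all names and paths below are placeholders, not part of the original code:

# Hypothetical call site; file names are illustrative only.
uncased, cased = evaluate_and_log_bleu(
    estimator=my_estimator,        # built elsewhere with the Transformer model_fn
    bleu_source="newstest.en",     # source sentences, one per line
    bleu_ref="newstest.de",        # reference translations, one per line
    vocab_file="vocab.ende.32768")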
Example #2
def main(unused_argv):
  from official.transformer import transformer_main_triblock_cls as transformer_main

  tf.logging.set_verbosity(tf.logging.INFO)

  if FLAGS.text is None and FLAGS.file is None:
    tf.logging.warn("Nothing to translate. Make sure to call this script using "
                    "flags --text or --file.")
    return

  underscored_ids = FLAGS.underscored_ids.split(",")
  underscored_ids = [int(idx) for idx in underscored_ids]

  
  subtokenizer = tokenizer.Subtokenizer(FLAGS.vocab_file)
  #tf.logging.info(len(subtokenizer.subtoken_list[:]))
  #tf.logging.info(subtokenizer.subtoken_list[:])

  # Set up estimator and params
  params = transformer_main.PARAMS_MAP[FLAGS.param_set]
  params["underscored_ids"]=underscored_ids
  params["vocab_file"]=FLAGS.vocab_file
  params["beam_size"] = _BEAM_SIZE
  params["alpha"] = _ALPHA
  params["extra_decode_length"] = _EXTRA_DECODE_LENGTH
  params["batch_size"] = _DECODE_BATCH_SIZE
  estimator = tf.estimator.Estimator(
      model_fn=transformer_main.model_fn, model_dir=FLAGS.model_dir,
      params=params)

  if FLAGS.text is not None:
    tf.logging.info("Translating text: %s" % FLAGS.text)
    translate_text(estimator, subtokenizer, FLAGS.text)

  if FLAGS.file is not None:
    input_file = os.path.abspath(FLAGS.file)
    tf.logging.info("Translating file: %s" % input_file)
    if not tf.gfile.Exists(FLAGS.file):
      raise ValueError("File does not exist: %s" % input_file)

    output_file = None
    if FLAGS.file_out is not None:
      output_file = os.path.abspath(FLAGS.file_out)
      tf.logging.info("File output specified: %s" % output_file)

    translate_file(estimator, subtokenizer, input_file, output_file)
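
This main assumes several command-line flags defined elsewhere in the script (text, file, file_out, vocab_file, model_dir, param_set, underscored_ids), as well as the module constants _BEAM_SIZE, _ALPHA, _EXTRA_DECODE_LENGTH and _DECODE_BATCH_SIZE. A rough sketch of how the flags could be declared with absl, offered only as an assumption about the surrounding code:

from absl import flags

flags.DEFINE_string("text", None, "Text to translate.")
flags.DEFINE_string("file", None, "Path of a file to translate.")
flags.DEFINE_string("file_out", None, "Where to write the translated output.")
flags.DEFINE_string("vocab_file", None, "Path to the subtoken vocabulary file.")
flags.DEFINE_string("model_dir", None, "Directory containing the trained checkpoint.")
flags.DEFINE_string("param_set", "base", "Parameter set name passed to PARAMS_MAP.")
flags.DEFINE_string("underscored_ids", "", "Comma-separated list of token ids, e.g. '13,14'.")

FLAGS = flags.FLAGS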
Example #3
def main(unused_argv):

    subtokenizer = tokenizer.Subtokenizer(FLAGS.vocab_file)

    test_set = glob.glob(FLAGS.data_dir + "/" + "*.test.*.json")
    total_set = {"test": test_set}

    print(total_set)
    for mode, _set in total_set.items():
        writer = open(FLAGS.data_dir + "/" + "test_article_cls.txt", "w")
        writer2 = open(FLAGS.data_dir + "/" + "test_article_cls_answer.txt",
                       "w")
        mode_count = 0

        for file_name in _set:
            with open(file_name) as f:
                lines = f.readlines()

                for line in lines:
                    instances = json.loads(line)
                    for inst_index, instance in enumerate(instances):

                        for i, src in enumerate(instance['src']):

                            writer.write("CLS " + " ".join(src) + " ")
                        writer.write("\n")

                        for i, tgt in enumerate(instance['tgt']):

                            writer2.write(" " + " ".join(tgt) + " . ")
                        writer2.write("\n")

                        mode_count += 1
        print(mode + " : " + str(mode_count))
        writer.close()
        writer2.close()
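
The loop above assumes that every line of a *.test.*.json file parses to a JSON list of instances, each carrying pre-tokenized sentences under 'src' and 'tgt'. A made-up line showing that assumed shape:

# Illustrative input line only; the real data is assumed to follow this structure.
example_line = '[{"src": [["first", "sentence", "."], ["second", "one", "."]], "tgt": [["target", "sentence"]]}]'
instances = json.loads(example_line)
print(instances[0]["src"][0])   # ['first', 'sentence', '.']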
Example #4
def main(unused_argv):

    subtokenizer = tokenizer.Subtokenizer(FLAGS.vocab_file)
    print(subtokenizer.subtoken_list)

    train_set = glob.glob(FLAGS.data_dir + "/" + "*.train.*.json")
    valid_set = glob.glob(FLAGS.data_dir + "/" + "*.valid.*.json")
    test_set = glob.glob(FLAGS.data_dir + "/" + "*.test.*.json")
    total_set = {"train": train_set, "valid": valid_set, "test": test_set}

    #print(total_set)
    print_switch = True
    for mode, _set in total_set.items():
        writer = tf.python_io.TFRecordWriter(FLAGS.data_dir + "/" + mode + ".tfrecord")
        mode_count = 0
        
        for file_name in _set:
            with open(file_name) as f:
                lines = f.readlines()
                for line in lines:
                    instances = json.loads(line)
                    for inst_index, instance in enumerate(instances):
                        

                        # Token ids, per-token sentence indices, and markers for the
                        # first subtoken of each sentence on the source side.
                        src_list = []
                        src_sep_list = []
                        src_cls_mask = []

                        for i, src in enumerate(instance['src']):
                            src_line = " ".join(src)
                            # Undo PTB-style bracket escaping before subtokenizing.
                            src_line = src_line.replace("-RRB-", ")")
                            src_line = src_line.replace("-LRB-", "(")
                            src_line = src_line.replace("-RSB-", "]")
                            src_line = src_line.replace("-LSB-", "[")
                            src_line = src_line.replace("-RCB-", "}")
                            src_line = src_line.replace("-LCB-", "{")
                            src_ids = _encode_and_add_eos(src_line, subtokenizer)

                            src_list.extend(src_ids)
                            src_sep_list.extend([i] * len(src_ids))
                            src_cls_mask.extend([1] + [0] * (len(src_ids) - 1))

                        src_list += [tokenizer.EOS_ID]
                        src_sep_list += [i + 1]
                        src_cls_mask += [0]

                        # The same annotations for the target side.
                        tgt_list = []
                        tgt_sep_list = []
                        tgt_cls_mask = []

                        for i, tgt in enumerate(instance['tgt']):
                            tgt += ["."]  # Target sentences come without a final period.
                            tgt_line = " ".join(tgt)
                            tgt_line = tgt_line.replace("-RRB-", ")")
                            tgt_line = tgt_line.replace("-LRB-", "(")
                            tgt_line = tgt_line.replace("-RSB-", "]")
                            tgt_line = tgt_line.replace("-LSB-", "[")
                            tgt_line = tgt_line.replace("-RCB-", "}")
                            tgt_line = tgt_line.replace("-LCB-", "{")

                            tgt_ids = _encode_and_add_eos(tgt_line, subtokenizer)

                            tgt_list.extend(tgt_ids)
                            tgt_sep_list.extend([i] * len(tgt_ids))
                            tgt_cls_mask.extend([1] + [0] * (len(tgt_ids) - 1))

                        tgt_list += [tokenizer.EOS_ID]
                        tgt_sep_list += [i + 1]
                        tgt_cls_mask += [0]


                        features = collections.OrderedDict()
                        features["inputs"] = create_int_feature(src_list)
                        features["targets"] = create_int_feature(tgt_list)
                        features["input_seps"] = create_int_feature(src_sep_list)
                        features["target_seps"] = create_int_feature(tgt_sep_list)
                        features["input_cls_mask"] = create_int_feature(src_cls_mask)
                        features["target_cls_mask"] = create_int_feature(tgt_cls_mask)

                        tf_example = tf.train.Example(features=tf.train.Features(feature=features))

                        writer.write(tf_example.SerializeToString())

                        # Log a few decoded examples from the first file for inspection.
                        if inst_index < 10 and print_switch:
                            tf.logging.info("*** Example ***")
                            tf.logging.info("*** INPUT ***")
                            tf.logging.info(_trim_and_decode(src_list, subtokenizer))
                            tf.logging.info("*** TARGET ***")
                            tf.logging.info(_trim_and_decode(tgt_list, subtokenizer))
                            for feature_name in features.keys():
                                feature = features[feature_name]
                                values = []
                                if feature.int64_list.value:
                                    values = feature.int64_list.value
                                elif feature.float_list.value:
                                    values = feature.float_list.value
                                tf.logging.info("%s: %s" % (feature_name, " ".join([str(x) for x in values])))

                        mode_count += 1
                    print_switch = False

        print(mode + " : " + str(mode_count))
        writer.close()
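
The TFRecord writer above relies on three helpers that are not shown: _encode_and_add_eos, _trim_and_decode and create_int_feature. One plausible implementation, modeled on the official Transformer and BERT utilities rather than taken from this project, looks like this:

def _encode_and_add_eos(line, subtokenizer):
    """Subtokenize a string and append the EOS id."""
    return subtokenizer.encode(line) + [tokenizer.EOS_ID]


def _trim_and_decode(ids, subtokenizer):
    """Trim the ids at the first EOS and decode them back to a string."""
    try:
        index = list(ids).index(tokenizer.EOS_ID)
        return subtokenizer.decode(ids[:index])
    except ValueError:  # no EOS found in the sequence
        return subtokenizer.decode(ids)


def create_int_feature(values):
    """Wrap a list of ints as a tf.train.Feature."""
    return tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))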
Example #5
    def predict(self, encoder_outputs, encoder_decoder_attention_bias):
        """Return predicted sequence."""
        batch_size = tf.shape(encoder_outputs)[0]
        input_length = tf.shape(encoder_outputs)[1]

        max_decode_length = self.params["max_output_length"]

        symbols_to_logits_fn = self._get_symbols_to_logits_fn(
            max_decode_length)

        # Create initial set of IDs that will be passed into symbols_to_logits_fn.
        initial_ids = tf.zeros([batch_size], dtype=tf.int32)

        # Create cache storing decoder attention values for each layer.
        cache = {
            "layer_%d" % layer: {
                "k": tf.zeros([batch_size, 0, self.params["hidden_size"]]),
                "v": tf.zeros([batch_size, 0, self.params["hidden_size"]]),
            }
            for layer in range(self.params["num_hidden_layers"])
        }
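        # Each layer's key/value cache starts with zero time steps; beam search
        # appends one position per decoding step.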

        # Add encoder output and attention bias to the cache.
        cache["encoder_outputs"] = encoder_outputs
        cache[
            "encoder_decoder_attention_bias"] = encoder_decoder_attention_bias

        ####domyoung 2019.12.9#####
        nontrigrams = nontrigram_generator(max_decode_length,
                                           self.params["underscored_ids"])
        nontrigrams = tf.constant(nontrigrams, dtype=tf.int32)
        tile_dims = [1] * nontrigrams.shape.ndims
        tile_dims[-1] = batch_size * self.params["beam_size"]
        nontrigrams = tf.tile(nontrigrams, tile_dims)
        nontrigrams = tf.reshape(nontrigrams, [-1, max_decode_length])
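        # Tile the non-trigram mask so every (batch, beam) hypothesis gets its own
        # copy, flattened to the [batch_size * beam_size, max_decode_length] layout
        # used inside beam search.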
        subtokenizer = tokenizer.Subtokenizer(self.params["vocab_file"])
        key = tf.range(self.params["vocab_size"], dtype=tf.int32)
        print(key)
        # Copy the subtoken vocabulary and remap index 2 to 'CLS' so the lookup
        # table emits that marker instead of the raw subtoken stored there.
        subtoken_list = subtokenizer.subtoken_list[:]
        subtoken_list[2] = 'CLS'
        #subtoken_list[3] = 'END'
        #subtoken_list[8439]=' .'
        tf.logging.info(len(subtoken_list))
        tf.logging.info(subtoken_list)
        value = tf.constant(subtoken_list, dtype=tf.string)
        default_value = tf.constant("", dtype=tf.string)
        hashTable = tf.contrib.lookup.HashTable(
            tf.contrib.lookup.KeyValueTensorInitializer(key, value),
            default_value)
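        # The table maps vocabulary ids to subtoken strings; it is handed to beam
        # search together with `nontrigrams` for the trigram blocking enabled by
        # use_trigram=True below.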
        # Use beam search to find the top beam_size sequences and scores.

        decoded_ids, scores = beam_search.sequence_beam_search(
            symbols_to_logits_fn=symbols_to_logits_fn,
            initial_ids=initial_ids,
            initial_cache=cache,
            vocab_size=self.params["vocab_size"],
            hashTable=hashTable,
            nontrigrams=nontrigrams,
            beam_size=self.params["beam_size"],
            batch_size=self.params["batch_size"],
            alpha=self.params["alpha"],
            max_decode_length=max_decode_length,
            use_trigram=True,
            eos_id=EOS_ID)

        #########################
        # Get the top sequence for each batch element
        top_decoded_ids = decoded_ids[:, 0, 1:]
        top_scores = scores[:, 0]

        return {"outputs": top_decoded_ids, "scores": top_scores}
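
A rough sketch of how the returned dictionary might be consumed from an Estimator predict loop, assuming the predictions are passed through unchanged by model_fn and that a _trim_and_decode helper like the one sketched earlier is available; input_fn and subtokenizer are placeholders:

# Hypothetical consumption of the {"outputs", "scores"} predictions.
for prediction in estimator.predict(input_fn):
    translation = _trim_and_decode(prediction["outputs"], subtokenizer)
    tf.logging.info("score=%.4f  %s", prediction["scores"], translation)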