Example #1
0
def main(_):
  """Builds a vocabulary from the training documents and writes it to disk.

  Counts term frequencies and document frequencies over the training set,
  drops terms appearing in too few documents, truncates to MAX_VOCAB_SIZE,
  appends the EOS token, and writes the result to FLAGS.output_dir.
  """
  tf.logging.set_verbosity(tf.logging.INFO)
  vocab_freqs = defaultdict(int)
  doc_counts = defaultdict(int)

  # Fill vocabulary frequencies map and document counts map
  for doc in document_generators.documents(
      dataset='train',
      include_unlabeled=FLAGS.use_unlabeled,
      include_validation=FLAGS.include_validation):
    fill_vocab_from_doc(doc, vocab_freqs, doc_counts)

  # Filter out low-occurring terms. Use dict.items() — the Python-2-only
  # dict.iteritems() raises AttributeError on Python 3.
  vocab_freqs = dict((term, freq) for term, freq in vocab_freqs.items()
                     if doc_counts[term] > FLAGS.doc_count_threshold)

  # Sort by frequency
  ordered_vocab_freqs = data_utils.sort_vocab_by_frequency(vocab_freqs)

  # Limit vocab size
  ordered_vocab_freqs = ordered_vocab_freqs[:MAX_VOCAB_SIZE]

  # Add EOS token (frequency value is a placeholder, not a real count)
  ordered_vocab_freqs.append((data_utils.EOS_TOKEN, 1))

  # Write
  tf.gfile.MakeDirs(FLAGS.output_dir)
  data_utils.write_vocab_and_frequency(ordered_vocab_freqs, FLAGS.output_dir)
Example #2
0
File: gen_vocab.py · Project: rder96/models
def main(_):
  """Builds the training-set vocabulary and writes it to FLAGS.output_dir.

  Accumulates per-term and per-document counts, prunes rare terms, keeps the
  MAX_VOCAB_SIZE most frequent terms, appends the EOS token, and writes the
  resulting vocabulary with frequencies.
  """
  tf.logging.set_verbosity(tf.logging.INFO)
  vocab_freqs = defaultdict(int)
  doc_counts = defaultdict(int)

  # Accumulate term frequencies and document counts over the training docs.
  training_docs = document_generators.documents(
      dataset='train',
      include_unlabeled=FLAGS.use_unlabeled,
      include_validation=FLAGS.include_validation)
  for doc in training_docs:
    fill_vocab_from_doc(doc, vocab_freqs, doc_counts)

  # Prune terms that appear in too few documents.
  vocab_freqs = {
      term: freq
      for term, freq in iteritems(vocab_freqs)
      if doc_counts[term] > FLAGS.doc_count_threshold
  }

  # Order by frequency and cap the vocabulary size.
  ordered_vocab_freqs = data_utils.sort_vocab_by_frequency(vocab_freqs)
  del ordered_vocab_freqs[MAX_VOCAB_SIZE:]

  # Append the end-of-sequence token.
  ordered_vocab_freqs.append((data_utils.EOS_TOKEN, 1))

  # Persist the vocabulary and its frequencies.
  tf.gfile.MakeDirs(FLAGS.output_dir)
  data_utils.write_vocab_and_frequency(ordered_vocab_freqs, FLAGS.output_dir)
Example #3
0
def generate_training_data(vocab_ids, writer_lm_all, writer_seq_ae_all):
    """Generates training data.

    For every training document, builds language-model, reverse-LM,
    sequence-autoencoder, and (when labeled) classification sequences, and
    writes each to its per-task TFRecord writer.

    Args:
      vocab_ids: Vocabulary mapping used by build_input_sequence to encode
        document tokens.
      writer_lm_all: Writer receiving every LM sequence, validation included.
      writer_seq_ae_all: Writer receiving every seq-autoencoder sequence,
        validation included.
    """
    # Construct training data writers
    writer_lm = build_shuffling_tf_record_writer(data.TRAIN_LM)
    writer_seq_ae = build_shuffling_tf_record_writer(data.TRAIN_SA)
    writer_class = build_shuffling_tf_record_writer(data.TRAIN_CLASS)
    writer_valid_class = build_tf_record_writer(data.VALID_CLASS)
    writer_rev_lm = build_shuffling_tf_record_writer(data.TRAIN_REV_LM)
    writer_bd_class = build_shuffling_tf_record_writer(data.TRAIN_BD_CLASS)
    writer_bd_valid_class = build_shuffling_tf_record_writer(
        data.VALID_BD_CLASS)

    try:
        for doc in document_generators.documents(dataset='train',
                                                 include_unlabeled=True,
                                                 include_validation=True):
            input_seq = build_input_sequence(doc, vocab_ids)
            # A sequence shorter than 2 tokens cannot form an (input, target)
            # pair, so skip it.
            if len(input_seq) < 2:
                continue
            rev_seq = data.build_reverse_sequence(input_seq)
            lm_seq = data.build_lm_sequence(input_seq)
            rev_lm_seq = data.build_lm_sequence(rev_seq)
            seq_ae_seq = data.build_seq_ae_sequence(input_seq)
            if doc.label is not None:
                # Used for sentiment classification.
                label_seq = data.build_labeled_sequence(
                    input_seq,
                    doc.label,
                    label_gain=(FLAGS.label_gain and not doc.is_validation))
                bd_label_seq = data.build_labeled_sequence(
                    data.build_bidirectional_seq(input_seq, rev_seq),
                    doc.label,
                    label_gain=(FLAGS.label_gain and not doc.is_validation))
                class_writer = (writer_valid_class
                                if doc.is_validation else writer_class)
                bd_class_writer = (writer_bd_valid_class
                                   if doc.is_validation else writer_bd_class)
                class_writer.write(label_seq.seq.SerializeToString())
                bd_class_writer.write(bd_label_seq.seq.SerializeToString())

            # Serialize once and fan out to the relevant writers.
            lm_seq_ser = lm_seq.seq.SerializeToString()
            seq_ae_seq_ser = seq_ae_seq.seq.SerializeToString()
            writer_lm_all.write(lm_seq_ser)
            writer_seq_ae_all.write(seq_ae_seq_ser)
            # Validation docs feed only the *_all writers above.
            if not doc.is_validation:
                writer_lm.write(lm_seq_ser)
                writer_rev_lm.write(rev_lm_seq.seq.SerializeToString())
                writer_seq_ae.write(seq_ae_seq_ser)
    finally:
        # Close (and thereby flush) the writers even if processing raises;
        # previously they were only closed on the success path, leaking
        # buffered records on error.
        writer_lm.close()
        writer_seq_ae.close()
        writer_class.close()
        writer_valid_class.close()
        writer_rev_lm.close()
        writer_bd_class.close()
        writer_bd_valid_class.close()
Example #4
0
def generate_training_data(vocab_ids, writer_lm_all, writer_seq_ae_all):
  """Generates training data.

  For every training document (unlabeled and validation included), builds
  language-model, reverse-LM, sequence-autoencoder, and — when the document
  is labeled — classification sequences, and writes each to its per-task
  TFRecord writer.

  Args:
    vocab_ids: Vocabulary mapping used by build_input_sequence to encode
      document tokens.
    writer_lm_all: Writer receiving every LM sequence, validation included.
    writer_seq_ae_all: Writer receiving every seq-autoencoder sequence,
      validation included.
  """

  # Construct training data writers
  writer_lm = build_shuffling_tf_record_writer(data.TRAIN_LM)
  writer_seq_ae = build_shuffling_tf_record_writer(data.TRAIN_SA)
  writer_class = build_shuffling_tf_record_writer(data.TRAIN_CLASS)
  # Validation classification examples are written unshuffled.
  writer_valid_class = build_tf_record_writer(data.VALID_CLASS)
  writer_rev_lm = build_shuffling_tf_record_writer(data.TRAIN_REV_LM)
  writer_bd_class = build_shuffling_tf_record_writer(data.TRAIN_BD_CLASS)
  writer_bd_valid_class = build_shuffling_tf_record_writer(data.VALID_BD_CLASS)

  for doc in document_generators.documents(
      dataset='train', include_unlabeled=True, include_validation=True):
    input_seq = build_input_sequence(doc, vocab_ids)
    # A sequence shorter than 2 tokens cannot form an (input, target) pair.
    if len(input_seq) < 2:
      continue
    rev_seq = data.build_reverse_sequence(input_seq)
    lm_seq = data.build_lm_sequence(input_seq)
    rev_lm_seq = data.build_lm_sequence(rev_seq)
    seq_ae_seq = data.build_seq_ae_sequence(input_seq)
    if doc.label is not None:
      # Used for sentiment classification.
      # label_gain is disabled for validation docs even when FLAGS.label_gain
      # is set.
      label_seq = data.build_labeled_sequence(
          input_seq,
          doc.label,
          label_gain=(FLAGS.label_gain and not doc.is_validation))
      bd_label_seq = data.build_labeled_sequence(
          data.build_bidirectional_seq(input_seq, rev_seq),
          doc.label,
          label_gain=(FLAGS.label_gain and not doc.is_validation))
      # Route labeled examples to the train or validation writer.
      class_writer = writer_valid_class if doc.is_validation else writer_class
      bd_class_writer = (writer_bd_valid_class
                         if doc.is_validation else writer_bd_class)
      class_writer.write(label_seq.seq.SerializeToString())
      bd_class_writer.write(bd_label_seq.seq.SerializeToString())

    # Write
    # Serialize once, then fan out; validation docs feed only the *_all
    # writers.
    lm_seq_ser = lm_seq.seq.SerializeToString()
    seq_ae_seq_ser = seq_ae_seq.seq.SerializeToString()
    writer_lm_all.write(lm_seq_ser)
    writer_seq_ae_all.write(seq_ae_seq_ser)
    if not doc.is_validation:
      writer_lm.write(lm_seq_ser)
      writer_rev_lm.write(rev_lm_seq.seq.SerializeToString())
      writer_seq_ae.write(seq_ae_seq_ser)

  # Close writers
  writer_lm.close()
  writer_seq_ae.close()
  writer_class.close()
  writer_valid_class.close()
  writer_rev_lm.close()
  writer_bd_class.close()
  writer_bd_valid_class.close()
Example #5
0
def generate_test_data(vocab_ids, writer_lm_all, writer_seq_ae_all):
    """Generates test data.

    For every test document, builds LM, reverse-LM, seq-autoencoder, and
    classification sequences, and writes them to per-task TFRecord writers.

    Args:
      vocab_ids: Vocabulary mapping used by build_input_sequence to encode
        document tokens.
      writer_lm_all: Writer that also receives every LM sequence.
      writer_seq_ae_all: Writer that also receives every seq-AE sequence.
    """
    # Construct test data writers
    writer_lm = build_shuffling_tf_record_writer(data.TEST_LM)
    writer_rev_lm = build_shuffling_tf_record_writer(data.TEST_REV_LM)
    writer_seq_ae = build_shuffling_tf_record_writer(data.TEST_SA)
    writer_class = build_tf_record_writer(data.TEST_CLASS)
    writer_bd_class = build_shuffling_tf_record_writer(data.TEST_BD_CLASS)

    try:
        for doc in document_generators.documents(dataset='test',
                                                 include_unlabeled=False,
                                                 include_validation=True):
            input_seq = build_input_sequence(doc, vocab_ids)
            # A sequence shorter than 2 tokens cannot form an (input, target)
            # pair, so skip it.
            if len(input_seq) < 2:
                continue
            rev_seq = data.build_reverse_sequence(input_seq)
            lm_seq = data.build_lm_sequence(input_seq)
            rev_lm_seq = data.build_lm_sequence(rev_seq)
            seq_ae_seq = data.build_seq_ae_sequence(input_seq)
            label_seq = data.build_labeled_sequence(input_seq, doc.label)
            bd_label_seq = data.build_labeled_sequence(
                data.build_bidirectional_seq(input_seq, rev_seq), doc.label)

            # Serialize once and fan out to all writers.
            writer_class.write(label_seq.seq.SerializeToString())
            writer_bd_class.write(bd_label_seq.seq.SerializeToString())
            lm_seq_ser = lm_seq.seq.SerializeToString()
            seq_ae_seq_ser = seq_ae_seq.seq.SerializeToString()
            writer_lm.write(lm_seq_ser)
            writer_rev_lm.write(rev_lm_seq.seq.SerializeToString())
            writer_seq_ae.write(seq_ae_seq_ser)
            writer_lm_all.write(lm_seq_ser)
            writer_seq_ae_all.write(seq_ae_seq_ser)
    finally:
        # Close (and thereby flush) test writers even if processing raises;
        # previously they were only closed on the success path, leaking
        # buffered records on error.
        writer_lm.close()
        writer_rev_lm.close()
        writer_seq_ae.close()
        writer_class.close()
        writer_bd_class.close()
Example #6
0
def generate_test_data(vocab_ids, writer_lm_all, writer_seq_ae_all):
  """Generates test data.

  Builds LM, reverse-LM, seq-autoencoder, and classification sequences for
  each test document and writes them to per-task TFRecord writers.
  """
  # Per-task test writers; classification output is written unshuffled.
  writer_lm = build_shuffling_tf_record_writer(data.TEST_LM)
  writer_rev_lm = build_shuffling_tf_record_writer(data.TEST_REV_LM)
  writer_seq_ae = build_shuffling_tf_record_writer(data.TEST_SA)
  writer_class = build_tf_record_writer(data.TEST_CLASS)
  writer_bd_class = build_shuffling_tf_record_writer(data.TEST_BD_CLASS)

  test_docs = document_generators.documents(
      dataset='test', include_unlabeled=False, include_validation=True)
  for doc in test_docs:
    forward_seq = build_input_sequence(doc, vocab_ids)
    # Skip sequences too short to form an (input, target) pair.
    if len(forward_seq) < 2:
      continue
    backward_seq = data.build_reverse_sequence(forward_seq)
    lm_seq = data.build_lm_sequence(forward_seq)
    rev_lm_seq = data.build_lm_sequence(backward_seq)
    seq_ae_seq = data.build_seq_ae_sequence(forward_seq)
    label_seq = data.build_labeled_sequence(forward_seq, doc.label)
    bidir_seq = data.build_bidirectional_seq(forward_seq, backward_seq)
    bd_label_seq = data.build_labeled_sequence(bidir_seq, doc.label)

    # Serialize the shared payloads once, then fan out to every writer.
    serialized_lm = lm_seq.seq.SerializeToString()
    serialized_seq_ae = seq_ae_seq.seq.SerializeToString()
    writer_class.write(label_seq.seq.SerializeToString())
    writer_bd_class.write(bd_label_seq.seq.SerializeToString())
    writer_lm.write(serialized_lm)
    writer_rev_lm.write(rev_lm_seq.seq.SerializeToString())
    writer_seq_ae.write(serialized_seq_ae)
    writer_lm_all.write(serialized_lm)
    writer_seq_ae_all.write(serialized_seq_ae)

  # Flush and close every test writer.
  for test_writer in (writer_lm, writer_rev_lm, writer_seq_ae, writer_class,
                      writer_bd_class):
    test_writer.close()