Example No. 1
def main(_):
    hparams = train_mask_gan.create_hparams()
    log_dir = FLAGS.base_directory

    tf.gfile.MakeDirs(FLAGS.output_path)
    output_file = tf.gfile.GFile(os.path.join(FLAGS.output_path,
                                              'reviews.txt'),
                                 mode='w')

    # Load data set.
    if FLAGS.data_set == 'ptb':
        raw_data = ptb_loader.ptb_raw_data(FLAGS.data_dir)
        train_data, valid_data, _, _ = raw_data
    elif FLAGS.data_set == 'imdb':
        raw_data = imdb_loader.imdb_raw_data(FLAGS.data_dir)
        train_data, valid_data = raw_data
    else:
        raise NotImplementedError

    # Choose which partition to generate more data from.
    if FLAGS.sample_mode == SAMPLE_TRAIN:
        data_set = train_data
    elif FLAGS.sample_mode == SAMPLE_VALIDATION:
        data_set = valid_data
    else:
        raise NotImplementedError

    # Dictionary and reverse dictionary.
    if FLAGS.data_set == 'ptb':
        word_to_id = ptb_loader.build_vocab(
            os.path.join(FLAGS.data_dir, 'ptb.train.txt'))
    elif FLAGS.data_set == 'imdb':
        word_to_id = imdb_loader.build_vocab(
            os.path.join(FLAGS.data_dir, 'vocab.txt'))
    id_to_word = {v: k for k, v in word_to_id.items()}  # Invert vocab: id -> word.

    FLAGS.vocab_size = len(id_to_word)
    print('Vocab size: %d' % FLAGS.vocab_size)

    generate_samples(hparams, data_set, id_to_word, log_dir, output_file)
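For context, this example also assumes module-level imports (os, tensorflow, train_mask_gan, ptb_loader, imdb_loader) plus a set of flags and sampling-mode constants. The sketch below is one plausible setup: the flag names mirror their usage in main, but the defaults, help strings, and constant values are illustrative assumptions, not taken from the original script.

import tensorflow as tf

# Hypothetical sampling-mode constants; the actual values may differ.
SAMPLE_TRAIN = 'TRAIN'
SAMPLE_VALIDATION = 'VALIDATION'

# Flags referenced by main(); defaults shown here are placeholders.
tf.app.flags.DEFINE_string('data_dir', '/tmp/ptb_data',
                           'Directory containing the raw data files.')
tf.app.flags.DEFINE_string('data_set', 'ptb', "Either 'ptb' or 'imdb'.")
tf.app.flags.DEFINE_string('base_directory', '/tmp/maskgan',
                           'Base directory for logs and checkpoints.')
tf.app.flags.DEFINE_string('output_path', '/tmp/samples',
                           'Directory where reviews.txt is written.')
tf.app.flags.DEFINE_string('sample_mode', SAMPLE_TRAIN,
                           'Partition to sample from.')
tf.app.flags.DEFINE_integer('vocab_size', 0,
                            'Overwritten at runtime from the vocabulary.')
FLAGS = tf.app.flags.FLAGS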
Example No. 2
def main(_):
  hparams = train_mask_gan.create_hparams()
  log_dir = FLAGS.base_directory

  tf.gfile.MakeDirs(FLAGS.output_path)
  output_file = tf.gfile.GFile(
      os.path.join(FLAGS.output_path, 'reviews.txt'), mode='w')

  # Load data set.
  if FLAGS.data_set == 'ptb':
    raw_data = ptb_loader.ptb_raw_data(FLAGS.data_dir)
    train_data, valid_data, _, _ = raw_data
  elif FLAGS.data_set == 'imdb':
    raw_data = imdb_loader.imdb_raw_data(FLAGS.data_dir)
    train_data, valid_data = raw_data
  else:
    raise NotImplementedError

  # Choose which partition to generate more data from.
  if FLAGS.sample_mode == SAMPLE_TRAIN:
    data_set = train_data
  elif FLAGS.sample_mode == SAMPLE_VALIDATION:
    data_set = valid_data
  else:
    raise NotImplementedError

  # Dictionary and reverse dictionary.
  if FLAGS.data_set == 'ptb':
    word_to_id = ptb_loader.build_vocab(
        os.path.join(FLAGS.data_dir, 'ptb.train.txt'))
  elif FLAGS.data_set == 'imdb':
    word_to_id = imdb_loader.build_vocab(
        os.path.join(FLAGS.data_dir, 'vocab.txt'))
  id_to_word = {v: k for k, v in word_to_id.items()}  # Invert vocab: id -> word.

  FLAGS.vocab_size = len(id_to_word)
  print('Vocab size: %d' % FLAGS.vocab_size)

  generate_samples(hparams, data_set, id_to_word, log_dir, output_file)
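As a usage note, scripts in this style are launched through tf.app.run(), which parses the command-line flags and then invokes main. A minimal sketch, assuming the module is saved as generate_samples.py (the file name is hypothetical):

if __name__ == '__main__':
  tf.app.run()  # Parses flags, then calls main(_).

# Illustrative invocation (paths are placeholders):
#   python generate_samples.py --data_set=ptb --data_dir=/tmp/ptb_data \
#       --base_directory=/tmp/maskgan --output_path=/tmp/samples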
Example No. 3
def main(_):
  hparams = create_hparams()
  train_dir = os.path.join(FLAGS.base_directory, 'train')

  # Load data set.
  if FLAGS.data_set == 'ptb':
    raw_data = ptb_loader.ptb_raw_data(FLAGS.data_dir)
    train_data, valid_data, test_data, _ = raw_data
    valid_data_flat = valid_data
  elif FLAGS.data_set == 'imdb':
    raw_data = imdb_loader.imdb_raw_data(FLAGS.data_dir)
    # TODO(liamfedus): Get an IMDB test partition.
    train_data, valid_data = raw_data
    valid_data_flat = [word for review in valid_data for word in review]
  else:
    raise NotImplementedError

  if FLAGS.mode == MODE_TRAIN or FLAGS.mode == MODE_TRAIN_EVAL:
    data_set = train_data
  elif FLAGS.mode == MODE_VALIDATION:
    data_set = valid_data
  elif FLAGS.mode == MODE_TEST:
    data_set = test_data
  else:
    raise NotImplementedError

  # Dictionary and reverse dictionary.
  if FLAGS.data_set == 'ptb':
    word_to_id = ptb_loader.build_vocab(
        os.path.join(FLAGS.data_dir, 'ptb.train.txt'))
  elif FLAGS.data_set == 'imdb':
    word_to_id = imdb_loader.build_vocab(
        os.path.join(FLAGS.data_dir, 'vocab.txt'))
  id_to_word = {v: k for k, v in word_to_id.items()}

  # Dictionaries of validation-set n-gram counts.
  bigram_tuples = n_gram.find_all_ngrams(valid_data_flat, n=2)
  trigram_tuples = n_gram.find_all_ngrams(valid_data_flat, n=3)
  fourgram_tuples = n_gram.find_all_ngrams(valid_data_flat, n=4)

  bigram_counts = n_gram.construct_ngrams_dict(bigram_tuples)
  trigram_counts = n_gram.construct_ngrams_dict(trigram_tuples)
  fourgram_counts = n_gram.construct_ngrams_dict(fourgram_tuples)
  print('Unique %d-grams: %d' % (2, len(bigram_counts)))
  print('Unique %d-grams: %d' % (3, len(trigram_counts)))
  print('Unique %d-grams: %d' % (4, len(fourgram_counts)))

  data_ngram_counts = {
      '2': bigram_counts,
      '3': trigram_counts,
      '4': fourgram_counts
  }

  # TODO(liamfedus): This was necessary because there was a problem with our
  # originally trained IMDB models. The EOS_INDEX was off by one, which meant
  # two words were mapping to index 86933. The presence of '</s>' is going
  # to throw an out-of-vocabulary error.
  FLAGS.vocab_size = len(id_to_word)
  print('Vocab size: %d' % FLAGS.vocab_size)

  tf.gfile.MakeDirs(FLAGS.base_directory)

  if FLAGS.mode == MODE_TRAIN:
    log = tf.gfile.GFile(
        os.path.join(FLAGS.base_directory, 'train-log.txt'), mode='w')
  elif FLAGS.mode == MODE_VALIDATION:
    log = tf.gfile.GFile(
        os.path.join(FLAGS.base_directory, 'validation-log.txt'), mode='w')
  elif FLAGS.mode == MODE_TRAIN_EVAL:
    log = tf.gfile.GFile(
        os.path.join(FLAGS.base_directory, 'train_eval-log.txt'), mode='w')
  else:
    log = tf.gfile.GFile(
        os.path.join(FLAGS.base_directory, 'test-log.txt'), mode='w')

  if FLAGS.mode == MODE_TRAIN:
    train_model(hparams, data_set, train_dir, log, id_to_word,
                data_ngram_counts)
  elif FLAGS.mode in (MODE_VALIDATION, MODE_TRAIN_EVAL, MODE_TEST):
    evaluate_model(hparams, data_set, train_dir, log, id_to_word,
                   data_ngram_counts)
  else:
    raise NotImplementedError
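The n_gram helpers can be inferred from how this example consumes them: find_all_ngrams returns the length-n windows of a flat token list as tuples, and construct_ngrams_dict tallies those tuples into a count dictionary whose length is the number of unique n-grams. A minimal sketch consistent with that usage, not the repository's actual implementation:

import collections

def find_all_ngrams(tokens, n):
  # Every contiguous length-n window over the token sequence, as a tuple.
  return [tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1)]

def construct_ngrams_dict(ngram_tuples):
  # Map each distinct n-gram to its occurrence count.
  return dict(collections.Counter(ngram_tuples))

# Example: three bigrams in total, two of them unique.
counts = construct_ngrams_dict(find_all_ngrams(['a', 'b', 'a', 'b'], n=2))
assert counts == {('a', 'b'): 2, ('b', 'a'): 1}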
Example No. 4
def main(_):
  hparams = create_hparams()
  train_dir = os.path.join(FLAGS.base_directory, 'train')

  # Load data set.
  if FLAGS.data_set == 'ptb':
    raw_data = ptb_loader.ptb_raw_data(FLAGS.data_dir)
    train_data, valid_data, test_data, _ = raw_data
    valid_data_flat = valid_data
  elif FLAGS.data_set == 'imdb':
    raw_data = imdb_loader.imdb_raw_data(FLAGS.data_dir)
    # TODO(liamfedus): Get an IMDB test partition.
    train_data, valid_data = raw_data
    valid_data_flat = [word for review in valid_data for word in review]
  else:
    raise NotImplementedError

  if FLAGS.mode == MODE_TRAIN or FLAGS.mode == MODE_TRAIN_EVAL:
    data_set = train_data
  elif FLAGS.mode == MODE_VALIDATION:
    data_set = valid_data
  elif FLAGS.mode == MODE_TEST:
    data_set = test_data
  else:
    raise NotImplementedError

  # Dictionary and reverse dictionary.
  if FLAGS.data_set == 'ptb':
    word_to_id = ptb_loader.build_vocab(
        os.path.join(FLAGS.data_dir, 'ptb.train.txt'))
  elif FLAGS.data_set == 'imdb':
    word_to_id = imdb_loader.build_vocab(
        os.path.join(FLAGS.data_dir, 'vocab.txt'))
  id_to_word = {v: k for k, v in word_to_id.items()}  # Invert vocab: id -> word.

  # Dictionaries of validation-set n-gram counts.
  bigram_tuples = n_gram.find_all_ngrams(valid_data_flat, n=2)
  trigram_tuples = n_gram.find_all_ngrams(valid_data_flat, n=3)
  fourgram_tuples = n_gram.find_all_ngrams(valid_data_flat, n=4)

  bigram_counts = n_gram.construct_ngrams_dict(bigram_tuples)
  trigram_counts = n_gram.construct_ngrams_dict(trigram_tuples)
  fourgram_counts = n_gram.construct_ngrams_dict(fourgram_tuples)
  print('Unique %d-grams: %d' % (2, len(bigram_counts)))
  print('Unique %d-grams: %d' % (3, len(trigram_counts)))
  print('Unique %d-grams: %d' % (4, len(fourgram_counts)))

  data_ngram_counts = {
      '2': bigram_counts,
      '3': trigram_counts,
      '4': fourgram_counts
  }

  # TODO(liamfedus): This was necessary because there was a problem with our
  # originally trained IMDB models. The EOS_INDEX was off by one, which meant
  # two words were mapping to index 86933. The presence of '</s>' is going
  # to throw an out-of-vocabulary error.
  FLAGS.vocab_size = len(id_to_word)
  print('Vocab size: %d' % FLAGS.vocab_size)

  tf.gfile.MakeDirs(FLAGS.base_directory)

  if FLAGS.mode == MODE_TRAIN:
    log = tf.gfile.GFile(
        os.path.join(FLAGS.base_directory, 'train-log.txt'), mode='w')
  elif FLAGS.mode == MODE_VALIDATION:
    log = tf.gfile.GFile(
        os.path.join(FLAGS.base_directory, 'validation-log.txt'), mode='w')
  elif FLAGS.mode == MODE_TRAIN_EVAL:
    log = tf.gfile.GFile(
        os.path.join(FLAGS.base_directory, 'train_eval-log.txt'), mode='w')
  else:
    log = tf.gfile.GFile(
        os.path.join(FLAGS.base_directory, 'test-log.txt'), mode='w')

  if FLAGS.mode == MODE_TRAIN:
    train_model(hparams, data_set, train_dir, log, id_to_word,
                data_ngram_counts)
  elif FLAGS.mode in (MODE_VALIDATION, MODE_TRAIN_EVAL, MODE_TEST):
    evaluate_model(hparams, data_set, train_dir, log, id_to_word,
                   data_ngram_counts)
  else:
    raise NotImplementedError
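Finally, the MODE_* constants driving the dispatch above are plain values compared against a mode flag; a plausible definition (the values are illustrative assumptions, not taken from the original) is shown below. Note that MODE_TRAIN_EVAL pairs the training partition with evaluate_model, i.e. it scores the model on the data it was trained on.

# Hypothetical mode constants and flag; real values may differ.
MODE_TRAIN = 'TRAIN'
MODE_TRAIN_EVAL = 'TRAIN_EVAL'
MODE_VALIDATION = 'VALIDATION'
MODE_TEST = 'TEST'

tf.app.flags.DEFINE_string('mode', MODE_TRAIN,
                           'One of TRAIN, TRAIN_EVAL, VALIDATION, TEST.')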