Example #1
def main(_):
    hparams = train_mask_gan.create_hparams()
    log_dir = FLAGS.base_directory

    tf.gfile.MakeDirs(FLAGS.output_path)
    output_file = tf.gfile.GFile(os.path.join(FLAGS.output_path, 'output.txt'),
                                 mode='w')

    # Load data set.
    raw_data = spec_loader.spec_raw_data(FLAGS.data_dir)
    train_data, valid_data = raw_data

    # Select the data split to generate samples from.
    if FLAGS.sample_mode == SAMPLE_TRAIN:
        data_set = train_data
    elif FLAGS.sample_mode == SAMPLE_VALIDATION:
        data_set = valid_data
    else:
        raise NotImplementedError

    # Dictionary and reverse dictionary.
    word_to_id = spec_loader.build_vocab()
    id_to_word = {v: k for k, v in word_to_id.items()}

    FLAGS.vocab_size = len(id_to_word)
    print('Vocab size: %d' % FLAGS.vocab_size)

    generate_samples(hparams, data_set, id_to_word, log_dir, output_file)
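
Both examples assume module-level imports, sampling-mode constants, and flag definitions that the snippets do not show. A minimal sketch of that scaffolding, assuming TF 1.x-style flags; the flag names come from the snippets themselves, while the default values and the constant values are placeholders, not taken from the source:

import os

import tensorflow as tf

import spec_loader
import train_mask_gan

SAMPLE_TRAIN = 'TRAIN'            # Assumed values; the snippets only compare
SAMPLE_VALIDATION = 'VALIDATION'  # FLAGS.sample_mode against these constants.

FLAGS = tf.app.flags.FLAGS
tf.app.flags.DEFINE_string('base_directory', '/tmp/maskgan',
                           'Log/checkpoint directory.')
tf.app.flags.DEFINE_string('output_path', '/tmp/samples',
                           'Directory for generated samples.')
tf.app.flags.DEFINE_string('data_dir', '/tmp/data',
                           'Location of the raw data set.')
tf.app.flags.DEFINE_string('sample_mode', SAMPLE_TRAIN,
                           'Which data split to sample from.')
tf.app.flags.DEFINE_integer('vocab_size', 0,
                            'Overwritten at runtime from the vocabulary.')

if __name__ == '__main__':
  tf.app.run()  # Parses flags, then calls main() with the remaining argv.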
Example #2
def main(_):
  hparams = train_mask_gan.create_hparams()
  log_dir = FLAGS.base_directory

  tf.gfile.MakeDirs(FLAGS.output_path)
  output_file = tf.gfile.GFile(
      os.path.join(FLAGS.output_path, 'reviews.txt'), mode='w')

  # Load data set.
  if FLAGS.data_set == 'ptb':
    raw_data = ptb_loader.ptb_raw_data(FLAGS.data_dir)
    train_data, valid_data, _, _ = raw_data
  elif FLAGS.data_set == 'imdb':
    raw_data = imdb_loader.imdb_raw_data(FLAGS.data_dir)
    train_data, valid_data = raw_data
  else:
    raise NotImplementedError

  # Select the data split to generate samples from.
  if FLAGS.sample_mode == SAMPLE_TRAIN:
    data_set = train_data
  elif FLAGS.sample_mode == SAMPLE_VALIDATION:
    data_set = valid_data
  else:
    raise NotImplementedError

  # Dictionary and reverse dictionary.
  if FLAGS.data_set == 'ptb':
    word_to_id = ptb_loader.build_vocab(
        os.path.join(FLAGS.data_dir, 'ptb.train.txt'))
  elif FLAGS.data_set == 'imdb':
    word_to_id = imdb_loader.build_vocab(
        os.path.join(FLAGS.data_dir, 'vocab.txt'))
  else:
    raise NotImplementedError
  id_to_word = {v: k for k, v in word_to_id.items()}

  FLAGS.vocab_size = len(id_to_word)
  print('Vocab size: %d' % FLAGS.vocab_size)

  generate_samples(hparams, data_set, id_to_word, log_dir, output_file)
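
Neither loader module is shown. Judging from the calls above, build_vocab returns a word-to-id dict, which main() then inverts so generated ids can be decoded back into tokens. A hypothetical sketch of the IMDB variant, assuming vocab.txt holds one token per line (the helper body is illustrative, not the library's actual implementation):

def build_vocab(vocab_file):
  # Hypothetical helper: assign each token its line number as an id.
  word_to_id = {}
  with tf.gfile.GFile(vocab_file, mode='r') as f:
    for i, line in enumerate(f):
      word_to_id[line.strip()] = i
  return word_to_id

The {v: k for k, v in word_to_id.items()} inversion in main() is lossless only when every token maps to a unique id, which holds when ids are assigned by enumeration as above.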