# Example 1
def prepare_datasets(dataset_name=gin.REQUIRED,
                     shuffle_input_sentences=False,
                     num_eval_examples=2000,
                     batch_size=32):
    """Create batched, properly-formatted datasets from the TFDS datasets.

    Args:
      dataset_name: Name of TFDS dataset.
      shuffle_input_sentences: Not used during evaluation, but arg still needed
        for gin compatibility.
      num_eval_examples: Number of examples to use during evaluation. For the
        nolabel evaluation, this is also the number of distractors we choose
        between.
      batch_size: Batch size.

    Returns:
      A tuple `(datasets, emb_matrices)`: `datasets` maps each split name to a
      Dataset object, and `emb_matrices` maps the nolabel split names to their
      embedding matrices.
    """
    del shuffle_input_sentences  # Unused during evaluation; kept for gin.

    splits_to_load = {
        'valid_nolabel': 'train[:2%]',
        'train_nolabel': 'train[2%:4%]',
        'valid2018': rocstories_sentence_embeddings.VALIDATION_2018,
        'valid2016': rocstories_sentence_embeddings.VALIDATION_2016
    }

    datasets = tfds.load(dataset_name,
                         data_dir=FLAGS.data_dir,
                         split=splits_to_load,
                         download=False)

    emb_matrices = {}

    # Convert the nolabel splits to training data format and build their
    # embedding matrices. Both splits are processed identically: no input
    # shuffling, capped at num_eval_examples, evaluation mode.
    for split in ('valid_nolabel', 'train_nolabel'):
        datasets[split], emb_matrices[split] = utils.build_train_style_dataset(
            datasets[split],
            batch_size,
            shuffle_input_sentences=False,
            num_examples=num_eval_examples,
            is_training=False)

    # Convert official evaluation datasets to validation data format. There are
    # no embedding matrices involved here since the task has only two possible
    # next sentences to pick between for each example. Ignore num_eval_examples
    # and use the full datasets for these.
    datasets['valid2018'] = utils.build_validation_dataset(
        datasets['valid2018'])
    datasets['valid2016'] = utils.build_validation_dataset(
        datasets['valid2016'])

    return datasets, emb_matrices
def prepare_dataset(dataset_name=gin.REQUIRED,
                    shuffle_input_sentences=False,
                    num_eval_examples=2000,
                    batch_size=32):
    """Create the validation dataset plus a combined embedding matrix.

    Args:
      dataset_name: Name of TFDS dataset.
      shuffle_input_sentences: Not used here, but arg still needed for gin
        compatibility.
      num_eval_examples: Not used here, but arg still needed for gin
        compatibility.
      batch_size: Batch size.

    Returns:
      The validation dataset, the story identifiers for each story in the
        embedding matrix, and the embedding matrix.
    """

    del num_eval_examples  # Unused; kept for gin compatibility.
    del shuffle_input_sentences  # Unused; kept for gin compatibility.

    splits_to_load = [
        tfds.Split.TRAIN,
        rocstories_sentence_embeddings.VALIDATION_2018,
    ]
    tfds_train, tfds_valid = tfds.load(dataset_name,
                                       data_dir=FLAGS.data_dir,
                                       split=splits_to_load)

    # The training-style dataset is built only for its embeddings and story
    # ids; the dataset object itself is discarded.
    _, train_embs, train_story_ids = utils.build_train_style_dataset(
        tfds_train,
        batch_size,
        shuffle_input_sentences=False,
        return_ids=True,
        is_training=False)

    valid_dataset, valid_embs, valid_story_ids = (
        build_all_distractor_valid_dataset(tfds_valid, batch_size=batch_size))

    # Concatenate in matching order (validation first, then training) so the
    # story-id list stays aligned with the embedding matrix rows.
    all_story_ids = valid_story_ids + train_story_ids
    all_emb_matrix = tf.concat([valid_embs, train_embs], axis=0)

    return valid_dataset, all_story_ids, all_emb_matrix
# Example 3
def prepare_datasets(dataset_name=gin.REQUIRED,
                     shuffle_input_sentences=False,
                     num_eval_examples=2000,
                     batch_size=32):
  """Create batched, properly-formatted datasets from the TFDS datasets.

  Args:
    dataset_name: Name of TFDS dataset.
    shuffle_input_sentences: If True, the order of the input sentences is
      randomized.
    num_eval_examples: Number of examples to use during evaluation. For the
      nolabel evaluation, this is also the number of distractors we choose
      between.
    batch_size: Batch size.

  Returns:
    A tuple `(datasets, emb_matrices)`: `datasets` maps each split name to a
    Dataset object, and `emb_matrices` maps the training-style split names to
    their embedding matrices.
  """
  splits_to_load = {
      'valid_nolabel': 'train[:2%]',
      'train': 'train[2%:]',
      'train_nolabel': 'train[2%:4%]',
      'valid2018': rocstories_sentence_embeddings.VALIDATION_2018,
      'valid2016': rocstories_sentence_embeddings.VALIDATION_2016}

  datasets = tfds.load(
      dataset_name,
      data_dir=FLAGS.data_dir,
      split=splits_to_load,
      download=False)

  emb_matrices = {}
  # Convert datasets to the expected training data format, and build the
  # embedding matrices. Only the 'train' split may shuffle input sentences.
  datasets['train'], emb_matrices['train'] = utils.build_train_style_dataset(
      datasets['train'], batch_size, shuffle_input_sentences)

  # The two nolabel evaluation splits are processed identically: no input
  # shuffling, truncated to num_eval_examples.
  for split in ('valid_nolabel', 'train_nolabel'):
    datasets[split], emb_matrices[split] = utils.build_train_style_dataset(
        datasets[split], batch_size, False, num_examples=num_eval_examples)

  # Convert official evaluation datasets to validation data format. There are
  # no embedding matrices involved here since the task has only two possible
  # next sentences to pick between for each example.
  datasets['valid2018'] = utils.build_validation_dataset(
      datasets['valid2018']).take(num_eval_examples)
  datasets['valid2016'] = utils.build_validation_dataset(
      datasets['valid2016']).take(num_eval_examples)

  logging.info('EMBEDDING MATRICES CREATED:')
  for key, matrix in emb_matrices.items():
    logging.info('%s: %s', key, matrix.shape)

  return datasets, emb_matrices