def prepare_datasets(dataset_name=gin.REQUIRED,
                     shuffle_input_sentences=False,
                     num_eval_examples=2000,
                     batch_size=32):
  """Create batched, properly-formatted datasets from the TFDS datasets.

  Args:
    dataset_name: Name of TFDS dataset.
    shuffle_input_sentences: Not used during evaluation, but arg still needed
      for gin compatibility.
    num_eval_examples: Number of examples to use during evaluation. For the
      nolabel evaluation, this is also the number of distractors we choose
      between.
    batch_size: Batch size.

  Returns:
    A dictionary mapping from the dataset split to a Dataset object, and a
    dictionary mapping from the dataset split to its embedding matrix.
  """
  del shuffle_input_sentences

  splits_to_load = {
      'valid_nolabel': 'train[:2%]',
      'train_nolabel': 'train[2%:4%]',
      'valid2018': rocstories_sentence_embeddings.VALIDATION_2018,
      'valid2016': rocstories_sentence_embeddings.VALIDATION_2016
  }
  datasets = tfds.load(
      dataset_name,
      data_dir=FLAGS.data_dir,
      split=splits_to_load,
      download=False)

  emb_matrices = {}
  valid_nolabel_ds = utils.build_train_style_dataset(
      datasets['valid_nolabel'],
      batch_size,
      False,
      num_examples=num_eval_examples,
      is_training=False)
  datasets['valid_nolabel'], emb_matrices['valid_nolabel'] = valid_nolabel_ds

  train_nolabel_ds = utils.build_train_style_dataset(
      datasets['train_nolabel'],
      batch_size,
      False,
      num_examples=num_eval_examples,
      is_training=False)
  datasets['train_nolabel'], emb_matrices['train_nolabel'] = train_nolabel_ds

  # Convert official evaluation datasets to validation data format. There are
  # no embedding matrices involved here since the task has only two possible
  # next sentences to pick between for each example. Ignore num_eval_examples
  # and use the full datasets for these.
  datasets['valid2018'] = utils.build_validation_dataset(datasets['valid2018'])
  datasets['valid2016'] = utils.build_validation_dataset(datasets['valid2016'])

  return datasets, emb_matrices
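# A minimal usage sketch for the eval-only preparation above. Everything in
# this demo is illustrative: the dataset name is a placeholder and the
# function itself is not part of the original pipeline.
def _example_nolabel_eval():
  """Sketch: pull one nolabel eval batch and inspect its embedding matrix."""
  datasets, emb_matrices = prepare_datasets(
      dataset_name='rocstories_sentence_embeddings')  # placeholder name
  for batch in datasets['valid_nolabel'].take(1):
    # For the nolabel splits, candidate next sentences live in the split's
    # embedding matrix; its num_eval_examples rows are the distractors we
    # choose between (per the docstring above).
    print(batch)
  print(emb_matrices['valid_nolabel'].shape)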
def prepare_dataset(dataset_name=gin.REQUIRED,
                    shuffle_input_sentences=False,
                    num_eval_examples=2000,
                    batch_size=32):
  """Create batched, properly-formatted datasets from the TFDS datasets.

  Args:
    dataset_name: Name of TFDS dataset.
    shuffle_input_sentences: Not used during evaluation, but arg still needed
      for gin compatibility.
    num_eval_examples: Not used in this all-distractor evaluation, but arg
      still needed for gin compatibility.
    batch_size: Batch size.

  Returns:
    The validation dataset, the story identifiers for each story in the
    embedding matrix, and the embedding matrix.
  """
  del num_eval_examples
  del shuffle_input_sentences

  splits_to_load = [
      tfds.Split.TRAIN,
      rocstories_sentence_embeddings.VALIDATION_2018,
  ]
  tfds_train, tfds_valid = tfds.load(
      dataset_name, data_dir=FLAGS.data_dir, split=splits_to_load)

  _, train_embs, train_story_ids = utils.build_train_style_dataset(
      tfds_train,
      batch_size,
      shuffle_input_sentences=False,
      return_ids=True,
      is_training=False)

  out = build_all_distractor_valid_dataset(tfds_valid, batch_size=batch_size)
  valid_dataset, valid_embs, valid_story_ids = out

  all_story_ids = valid_story_ids + train_story_ids
  all_emb_matrix = tf.concat([valid_embs, train_embs], axis=0)

  return valid_dataset, all_story_ids, all_emb_matrix
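# Hedged sketch of how the outputs above can drive all-distractor scoring:
# every story embedding in all_emb_matrix is a candidate ending, and the
# argmax over similarity scores indexes into all_story_ids. The `model` call
# and its [batch_size, emb_dim] output are assumptions, not part of this file.
def _example_all_distractor_ranking(model):
  valid_dataset, all_story_ids, all_emb_matrix = prepare_dataset(
      dataset_name='rocstories_sentence_embeddings')  # placeholder name
  for batch in valid_dataset:
    pred_emb = model(batch)  # assumed shape: [batch_size, emb_dim]
    # Dot-product similarity against every candidate embedding in the corpus.
    scores = tf.matmul(pred_emb, all_emb_matrix, transpose_b=True)
    best = tf.argmax(scores, axis=-1)
    predicted_ids = [all_story_ids[i] for i in best.numpy()]
    print(predicted_ids)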
def prepare_datasets(dataset_name=gin.REQUIRED,
                     shuffle_input_sentences=False,
                     num_eval_examples=2000,
                     batch_size=32):
  """Create batched, properly-formatted datasets from the TFDS datasets.

  Args:
    dataset_name: Name of TFDS dataset.
    shuffle_input_sentences: If True, the order of the input sentences is
      randomized.
    num_eval_examples: Number of examples to use during evaluation. For the
      nolabel evaluation, this is also the number of distractors we choose
      between.
    batch_size: Batch size.

  Returns:
    A dictionary mapping from the dataset split to a Dataset object, and a
    dictionary mapping from the dataset split to its embedding matrix.
  """
  splits_to_load = {
      'valid_nolabel': 'train[:2%]',
      'train': 'train[2%:]',
      'train_nolabel': 'train[2%:4%]',
      'valid2018': rocstories_sentence_embeddings.VALIDATION_2018,
      'valid2016': rocstories_sentence_embeddings.VALIDATION_2016
  }
  datasets = tfds.load(
      dataset_name,
      data_dir=FLAGS.data_dir,
      split=splits_to_load,
      download=False)

  emb_matrices = {}

  # Convert datasets to the expected training data format and build the
  # embedding matrices.
  train_ds = utils.build_train_style_dataset(
      datasets['train'], batch_size, shuffle_input_sentences)
  datasets['train'], emb_matrices['train'] = train_ds

  valid_nolabel_ds = utils.build_train_style_dataset(
      datasets['valid_nolabel'],
      batch_size,
      False,
      num_examples=num_eval_examples)
  datasets['valid_nolabel'], emb_matrices['valid_nolabel'] = valid_nolabel_ds

  train_nolabel_ds = utils.build_train_style_dataset(
      datasets['train_nolabel'],
      batch_size,
      False,
      num_examples=num_eval_examples)
  datasets['train_nolabel'], emb_matrices['train_nolabel'] = train_nolabel_ds

  # Convert official evaluation datasets to validation data format. There are
  # no embedding matrices involved here since the task has only two possible
  # next sentences to pick between for each example.
  datasets['valid2018'] = utils.build_validation_dataset(
      datasets['valid2018']).take(num_eval_examples)
  datasets['valid2016'] = utils.build_validation_dataset(
      datasets['valid2016']).take(num_eval_examples)

  logging.info('EMBEDDING MATRICES CREATED:')
  for key in emb_matrices:
    logging.info('%s: %s', key, emb_matrices[key].shape)

  return datasets, emb_matrices
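# Hedged train-time sketch for the function above: sentence shuffling is
# enabled for the 'train' split only, and each split comes paired with its
# own embedding matrix. The `compute_loss` callable is a hypothetical
# stand-in, not part of this file.
def _example_training_epoch(compute_loss):
  datasets, emb_matrices = prepare_datasets(
      dataset_name='rocstories_sentence_embeddings',  # placeholder name
      shuffle_input_sentences=True)
  for batch in datasets['train']:
    # compute_loss is assumed to score the batch against the candidate
    # next-sentence embeddings in the train embedding matrix.
    loss = compute_loss(batch, emb_matrices['train'])
    print(float(loss))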