import glob
import os
import random

from torch.utils import data
from torch.utils.data import Subset

# Project-specific helpers used below (Episode, WordWSDDataset, MetaWSDDataset,
# NERSampler, SequentialSampler, prepare_batch, prepare_bert_batch, get_labels,
# read_examples_from_file) are assumed to be imported from this repository's
# own modules.


def generate_semcor_wsd_episodes(wsd_dataset, n_episodes, n_support_examples, n_query_examples, task):
    """Build one WSD episode per target word from an already-loaded SemCor dataset."""
    # Keep only words with enough annotated sentences to fill both a support and a query set.
    word_splits = {k: v for (k, v) in wsd_dataset.word_splits.items()
                   if len(v['sentences']) > (n_support_examples + n_query_examples)}
    if n_episodes > len(word_splits):
        raise Exception('Not enough data available to generate {} episodes'.format(n_episodes))
    episodes = []
    for word in word_splits.keys():
        if len(episodes) == n_episodes:
            break
        # Shuffle the word's sentences and carve out disjoint support and query subsets.
        indices = list(range(len(word_splits[word]['sentences'])))
        random.shuffle(indices)
        start_index = 0
        train_subset = WordWSDDataset(
            sentences=[word_splits[word]['sentences'][i] for i in indices[start_index: start_index + n_support_examples]],
            labels=[word_splits[word]['labels'][i] for i in indices[start_index: start_index + n_support_examples]],
            n_classes=len(wsd_dataset.sense_inventory[word])
        )
        support_loader = data.DataLoader(train_subset, batch_size=n_support_examples, collate_fn=prepare_batch)
        start_index += n_support_examples
        test_subset = WordWSDDataset(
            sentences=[word_splits[word]['sentences'][i] for i in indices[start_index: start_index + n_query_examples]],
            labels=[word_splits[word]['labels'][i] for i in indices[start_index: start_index + n_query_examples]],
            n_classes=len(wsd_dataset.sense_inventory[word])
        )
        query_loader = data.DataLoader(test_subset, batch_size=n_query_examples, collate_fn=prepare_batch)
        episode = Episode(support_loader=support_loader,
                          query_loader=query_loader,
                          base_task=task,
                          task_id=task + '-' + word,
                          n_classes=train_subset.n_classes)
        episodes.append(episode)
    return episodes
def generate_wsd_episodes(dir, n_episodes, n_support_examples, n_query_examples, task, meta_train=True):
    """Build one WSD episode per word file (one JSON file per target word) in `dir`."""
    episodes = []
    for file_name in glob.glob(os.path.join(dir, '*.json')):
        if len(episodes) == n_episodes:
            break
        # The file name without its extension is the target word.
        word = file_name.split(os.sep)[-1].split('.')[0]
        word_wsd_dataset = MetaWSDDataset(file_name)
        # The first n_support_examples form the support set.
        train_subset = Subset(word_wsd_dataset, range(0, n_support_examples))
        support_loader = data.DataLoader(train_subset, batch_size=n_support_examples, collate_fn=prepare_batch)
        if meta_train:
            # During meta-training, use a fixed-size query set.
            test_subset = Subset(word_wsd_dataset, range(n_support_examples, n_support_examples + n_query_examples))
        else:
            # During evaluation, use all remaining examples as the query set.
            test_subset = Subset(word_wsd_dataset, range(n_support_examples, len(word_wsd_dataset)))
        query_loader = data.DataLoader(test_subset, batch_size=n_query_examples, collate_fn=prepare_batch)
        episode = Episode(support_loader=support_loader,
                          query_loader=query_loader,
                          base_task=task,
                          task_id=task + '-' + word,
                          n_classes=word_wsd_dataset.n_classes)
        episodes.append(episode)
    return episodes
def generate_ner_episodes(dir, labels_file, n_episodes, n_support_examples, n_query_examples, task,
                          meta_train=False, vectors='bert'):
    """Build N-way NER episodes (6 entity classes per episode) from the files in `dir`."""
    episodes = []
    labels = get_labels(labels_file)
    examples, label_map = read_examples_from_file(dir, labels)
    print('label_map', label_map)
    if meta_train:
        # Randomly sampled 6-way episodes for meta-training.
        ner_dataset = NERSampler(examples, labels, label_map, 6, n_support_examples, n_query_examples, n_episodes)
    else:
        # Deterministic, sequential episodes for evaluation.
        ner_dataset = SequentialSampler(examples, labels, label_map, 6, n_support_examples, n_query_examples, n_episodes)
    for index, ner_data in enumerate(ner_dataset):
        tags, sup_sents, query_sents = ner_data
        # Each loader yields the whole support/query set as a single batch
        # (6 classes x n examples per class).
        if vectors == 'bert':
            support_loader = data.DataLoader(sup_sents, batch_size=6 * n_support_examples, collate_fn=prepare_bert_batch)
            query_loader = data.DataLoader(query_sents, batch_size=6 * n_query_examples, collate_fn=prepare_bert_batch)
        else:
            support_loader = data.DataLoader(sup_sents, batch_size=6 * n_support_examples, collate_fn=prepare_batch)
            query_loader = data.DataLoader(query_sents, batch_size=6 * n_query_examples, collate_fn=prepare_batch)
        episode = Episode(support_loader=support_loader,
                          query_loader=query_loader,
                          base_task=task,
                          task_id=task + '-' + str(index),
                          n_classes=len(tags),
                          tags=tags)
        episodes.append(episode)
    return episodes, label_map
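

# --- Illustrative usage sketch (an assumption, not part of the original module) ---
# The directory paths, label file, shot sizes, and episode counts below are
# placeholders chosen only to show how the generators above are typically wired
# together; generate_semcor_wsd_episodes follows the same pattern but takes an
# already-loaded SemCor dataset object instead of a directory.
if __name__ == '__main__':
    wsd_episodes = generate_wsd_episodes(dir='data/wsd/meta_train',
                                         n_episodes=100,
                                         n_support_examples=16,
                                         n_query_examples=16,
                                         task='wsd',
                                         meta_train=True)
    ner_episodes, ner_label_map = generate_ner_episodes(dir='data/ner/meta_train',
                                                        labels_file='data/ner/labels.txt',
                                                        n_episodes=100,
                                                        n_support_examples=4,
                                                        n_query_examples=4,
                                                        task='ner',
                                                        meta_train=True,
                                                        vectors='bert')
    print('Generated {} WSD episodes and {} NER episodes'.format(len(wsd_episodes), len(ner_episodes)))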