def get_train_dataloader_from_cache(
    train_cache: caching.ChunkedFilesDataCache,
    task,
    train_batch_size: int,
    sample_weights_path=None,
    fix_seed_for_weighted_sampler=False,
):
    # TODO: Expose buffer_size parameter (issue #1183)
    if sample_weights_path is not None:
        # Weighted sampling needs a map-style dataset, so materialize the (unshuffled)
        # iterable dataset into a list-backed dataset before building the sampler.
        dataset = train_cache.get_iterable_dataset(buffer_size=10000, shuffle=False)
        dataset = _ListDataset([elem for elem in dataset])
        _sample_weights = pd.read_csv(sample_weights_path, sep='\t', header=None)[0]
        sampler = WeightedDatasetSampler(
            dataset, _sample_weights, fix_seed=fix_seed_for_weighted_sampler,
        )
    else:
        dataset = train_cache.get_iterable_dataset(buffer_size=10000, shuffle=True)
        sampler = None
    train_dataloader = torch_utils.DataLoaderWithLength(
        dataset=dataset,
        batch_size=train_batch_size,
        collate_fn=task.collate_fn,
        sampler=sampler,
    )
    return train_dataloader
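# Usage sketch (not part of the original module): driving the weighted-sampler branch
# above. The cache directory, task config path, and weights file are illustrative
# placeholders; `tasks` and `caching` are the same helpers the function already assumes.
def _example_weighted_train_dataloader():
    task = tasks.create_task_from_config_path("/path/to/task_configs/mnli.json")
    train_cache = caching.ChunkedFilesDataCache(
        cache_fol_path="/path/to/task_caches/mnli/train",
    )
    # sample_weights.tsv: one float per line, aligned with the cached training examples.
    return get_train_dataloader_from_cache(
        train_cache=train_cache,
        task=task,
        train_batch_size=16,
        sample_weights_path="/path/to/sample_weights.tsv",
        fix_seed_for_weighted_sampler=True,
    )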
def get_train_dataloader_from_cache(
    train_cache: caching.ChunkedFilesDataCache, task, train_batch_size: int
):
    # TODO: Expose buffer_size parameter (issue #1183)
    dataset = train_cache.get_iterable_dataset(buffer_size=10000, shuffle=True)
    train_dataloader = torch_utils.DataLoaderWithLength(
        dataset=dataset, batch_size=train_batch_size, collate_fn=task.collate_fn,
    )
    return train_dataloader
def get_train_dataloader_from_cache(
    train_cache: caching.ChunkedFilesDataCache,
    task,
    train_batch_size: int,
    batch_method: str,
    min_batch_size: int,
    total_batches: int,
    matchlist_pickle_path: str,
):
    # TODO: Expose buffer_size parameter (Issue #50)
    if batch_method == 'default':
        dataset = train_cache.get_iterable_dataset(buffer_size=10000, shuffle=True)
        train_dataloader = torch_utils.DataLoaderWithLength(
            dataset=dataset,
            batch_size=train_batch_size,
            collate_fn=task.collate_fn,
        )
    elif batch_method == 'clustered':
        # Clustered batching draws whole batches from a precomputed match list, so a
        # map-style (non-iterable) dataset is required for the batch sampler.
        dataset = train_cache.get_uniterable_dataset(buffer_size=10000)
        assert (
            total_batches > 0
        ), f"Must define total number of batches to generate. Given: {total_batches}."
        assert (
            train_batch_size > 0
        ), f"Max batch size must be greater than zero. Given: {train_batch_size}."
        # Currently only supports a pickled match list. Matching could potentially be
        # done here, but may take long depending on the size of the data.
        assert os.path.exists(
            matchlist_pickle_path
        ), f"Must first create pickled match list; path given does not exist. Given: {matchlist_pickle_path}"
        with open(matchlist_pickle_path, 'rb') as f:
            match_list = pickle.load(f)
        matched_random_batch_sampler = torch_utils.MatchedRandomBatchSampler(
            min_batch_size=min_batch_size,
            max_batch_size=train_batch_size,
            drop_last=True,
            match_list=match_list,
            total_batches=total_batches,
        )
        train_dataloader = torch_utils.DataLoader(
            dataset=dataset,
            collate_fn=task.collate_fn,
            batch_sampler=matched_random_batch_sampler,
        )
    else:
        raise KeyError(f"Batching method not supported: {batch_method}")
    return train_dataloader
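# Usage sketch (not part of the original module): invoking the 'clustered' batch method
# above. The match-list pickle is assumed to hold precomputed groups of example indices
# that may share a batch; all paths and sizes here are illustrative placeholders.
def _example_clustered_train_dataloader():
    task = tasks.create_task_from_config_path("/path/to/task_configs/mnli.json")
    train_cache = caching.ChunkedFilesDataCache(
        cache_fol_path="/path/to/task_caches/mnli/train",
    )
    return get_train_dataloader_from_cache(
        train_cache=train_cache,
        task=task,
        train_batch_size=32,  # acts as the max batch size for clustered batching
        batch_method='clustered',
        min_batch_size=8,
        total_batches=1000,
        matchlist_pickle_path="/path/to/matchlist.p",
    )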
def get_eval_dataloader_from_cache(
    eval_cache: caching.ChunkedFilesDataCache,
    task,
    eval_batch_size: int,
    subset_num=None,
    explicit_subset=None,
):
    dataset = eval_cache.get_iterable_dataset(
        buffer_size=10000,
        shuffle=False,
        subset_num=subset_num,
        explicit_subset=explicit_subset,
    )
    eval_dataloader = torch_utils.DataLoaderWithLength(
        dataset=dataset, batch_size=eval_batch_size, collate_fn=task.collate_fn,
    )
    return eval_dataloader
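# Usage sketch (not part of the original module): evaluating on a fixed subset of cached
# examples. `explicit_subset` is passed straight through to get_iterable_dataset above;
# the indices and paths are illustrative placeholders.
def _example_eval_dataloader_on_subset():
    task = tasks.create_task_from_config_path("/path/to/task_configs/rte.json")
    eval_cache = caching.ChunkedFilesDataCache(
        cache_fol_path="/path/to/task_caches/rte/val",
    )
    return get_eval_dataloader_from_cache(
        eval_cache=eval_cache,
        task=task,
        eval_batch_size=32,
        explicit_subset=[0, 1, 2, 3, 4],  # evaluate only these cached examples
    )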
def get_loss_weights_dict(self, start_position: int = None):
    if start_position is not None:
        raise Exception("Resuming from a start_position is not supported for loss weights")
    loss_weights_dict = {}
    for task_name in self.jiant_task_container.task_run_config.train_task_list:
        task_specific_config = self.jiant_task_container.task_specific_configs[task_name]
        logger.info(
            'task="%s": loading loss weights from "%s"',
            task_name,
            task_specific_config.train_loss_weights,
        )
        train_batch_size = task_specific_config.train_batch_size
        if task_specific_config.train_loss_weights is not None:
            # Per-example loss weights are stored as a single-column TSV; load them onto
            # the target device and batch them in step with the training dataloader.
            dataset = pd.read_csv(
                task_specific_config.train_loss_weights, sep='\t', header=None
            )[0].values
            dataset = torch.Tensor(dataset).to(self.device)
            loss_weights_dict[task_name] = InfiniteYield(
                torch_utils.DataLoaderWithLength(
                    dataset=dataset, batch_size=train_batch_size
                )
            )
        else:
            loss_weights_dict[task_name] = None
    return loss_weights_dict
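# Sketch (not part of the original module): writing a loss-weights file in the
# single-column TSV layout that get_loss_weights_dict reads back with pandas.
# The uniform weighting scheme and output path are illustrative assumptions.
def _example_write_loss_weights(num_examples: int, path: str = "/path/to/loss_weights.tsv"):
    import pandas as pd  # the surrounding module already uses pandas as pd

    weights = [1.0] * num_examples  # e.g. uniform weights; replace with a real scheme
    pd.Series(weights).to_csv(path, sep='\t', header=False, index=False)
    return path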
def main(args: RunConfiguration):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # === Shared model components setup === #
    model_type = "roberta-base"
    model_arch = ModelArchitectures.from_model_type(model_type=model_type)
    transformers_class_spec = model_setup.TRANSFORMERS_CLASS_SPEC_DICT[model_arch]
    ancestor_model = model_setup.get_ancestor_model(
        transformers_class_spec=transformers_class_spec,
        model_config_path=args.model_config_path,
    )
    encoder = model_setup.get_encoder(
        model_arch=model_arch,
        ancestor_model=ancestor_model,
    )
    tokenizer = shared_model_setup.get_tokenizer(
        model_type=model_type,
        tokenizer_path=args.model_tokenizer_path,
    )

    # === Taskmodels setup === #
    task_dict = {
        "mnli": tasks.create_task_from_config_path(
            os.path.join(args.task_config_base_path, "mnli.json")
        ),
        "qnli": tasks.create_task_from_config_path(
            os.path.join(args.task_config_base_path, "qnli.json")
        ),
        "rte": tasks.create_task_from_config_path(
            os.path.join(args.task_config_base_path, "rte.json")
        ),
    }
    taskmodels_dict = {
        "nli": taskmodels.ClassificationModel(
            encoder=encoder,
            classification_head=heads.ClassificationHead(
                hidden_size=encoder.config.hidden_size,
                hidden_dropout_prob=encoder.config.hidden_dropout_prob,
                num_labels=len(task_dict["mnli"].LABELS),
            ),
        ),
        "rte": taskmodels.ClassificationModel(
            encoder=encoder,
            classification_head=heads.ClassificationHead(
                hidden_size=encoder.config.hidden_size,
                hidden_dropout_prob=encoder.config.hidden_dropout_prob,
                num_labels=len(task_dict["rte"].LABELS),
            ),
        ),
    }
    # MNLI and QNLI share the "nli" taskmodel (and hence its classification head);
    # RTE gets its own taskmodel.
    task_to_taskmodel_map = {
        "mnli": "nli",
        "qnli": "nli",
        "rte": "rte",
    }

    # === Final === #
    jiant_model = JiantModel(
        task_dict=task_dict,
        encoder=encoder,
        taskmodels_dict=taskmodels_dict,
        task_to_taskmodel_map=task_to_taskmodel_map,
        tokenizer=tokenizer,
    )
    jiant_model = jiant_model.to(device)

    # === Run === #
    task_dataloader_dict = {}
    for task_name, task in task_dict.items():
        train_cache = caching.ChunkedFilesDataCache(
            cache_fol_path=os.path.join(args.task_cache_base_path, task_name, "train"),
        )
        train_dataset = train_cache.get_iterable_dataset(buffer_size=10000, shuffle=True)
        train_dataloader = torch_utils.DataLoaderWithLength(
            dataset=train_dataset,
            batch_size=4,
            collate_fn=task.collate_fn,
        )
        task_dataloader_dict[task_name] = train_dataloader

    # Run a single forward pass per task to sanity-check the multi-task wiring.
    for task_name, task in task_dict.items():
        batch, batch_metadata = next(iter(task_dataloader_dict[task_name]))
        batch = batch.to(device)
        with torch.no_grad():
            model_output = wrap_jiant_forward(
                jiant_model=jiant_model,
                batch=batch,
                task=task,
                compute_loss=True,
            )
        print(task_name)
        print(model_output)
        print()
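# Usage sketch (not part of the original module): calling main() directly. This assumes
# RunConfiguration can be constructed with the four attributes main() reads; the paths
# are illustrative placeholders pointing at a downloaded model and preprocessed caches.
def _example_run():
    example_args = RunConfiguration(
        model_config_path="/path/to/models/roberta-base/config.json",
        model_tokenizer_path="/path/to/models/roberta-base/tokenizer",
        task_config_base_path="/path/to/task_configs",
        task_cache_base_path="/path/to/task_caches",
    )
    main(example_args)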