def get_data_iterator(
    self,
    path: str,
    batch_size: int,
    is_train: bool,
    shuffle=True,
    shuffle_seed: int = 0,
    offset: int = 0,
) -> ShardedDataIterator:
    data_files = glob.glob(path)
    logger.info("Data files: %s", data_files)
    if not data_files:
        raise RuntimeError("No data files found")
    preprocessed_data_files = self._get_preprocessed_files(data_files, is_train)
    data = read_serialized_data_from_files(preprocessed_data_files)

    iterator = ShardedDataIterator(
        data,
        shard_id=self.shard_id,
        num_shards=self.distributed_factor,
        batch_size=batch_size,
        shuffle=shuffle,
        shuffle_seed=shuffle_seed,
        offset=offset,
    )
    # apply deserialization hook
    iterator.apply(lambda sample: sample.on_deserialize())
    return iterator
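# Usage sketch for the iterator above (a minimal, hypothetical example: the
# trainer instance, glob pattern, batch size, and epoch count are illustrative
# assumptions, not from this codebase):
#
#   train_iterator = trainer.get_data_iterator(
#       "data/reader_train*.pkl", batch_size=16, is_train=True, shuffle=True
#   )
#   for epoch in range(num_train_epochs):
#       for samples_batch in train_iterator.iterate_ds_data(epoch=epoch):
#           ...  # build model inputs from samples_batch and run a train step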
def get_data_iterator(
    self,
    path: str,
    batch_size: int,
    is_train: bool,
    shuffle=True,
    shuffle_seed: int = 0,
    offset: int = 0,
) -> ShardedDataIterator:
    # run preprocessing on a single process only in the distributed setting
    run_preprocessing = (
        self.distributed_factor == 1 or self.cfg.local_rank in [-1, 0]
    )

    # Original, raw gold passages
    gold_passages_src = self.cfg.gold_passages_src
    if gold_passages_src:
        if not is_train:
            gold_passages_src = self.cfg.gold_passages_src_dev
        assert os.path.exists(
            gold_passages_src
        ), "Please specify valid gold_passages_src/gold_passages_src_dev"

    # Processed, 100-word split gold passages
    gold_passages_processed = (
        self.cfg.gold_passages_processed
        if is_train
        else self.cfg.gold_passages_processed_dev
    )

    if self.wiki_data is None:
        self.wiki_data = TokenizedWikipediaPassages(
            data_file=self.cfg.wiki_psgs_tokenized
        )

    # BM25 retrieval results are only used during training
    bm25_retrieval_results = self.cfg.bm25_retrieval_results if is_train else None

    dataset = ExtractiveReaderGeneralDataset(
        path,
        bm25_retrieval_results,
        self.wiki_data,
        is_train,
        gold_passages_src,
        gold_passages_processed,
        self.tensorizer,
        run_preprocessing,
        self.cfg.num_workers,
        debugging=self.debugging,
    )
    dataset.load_data()

    iterator = ShardedDataIterator(
        dataset,
        shard_id=self.shard_id,
        num_shards=self.distributed_factor,
        batch_size=batch_size,
        shuffle=shuffle,
        shuffle_seed=shuffle_seed,
        offset=offset,
    )
    # apply deserialization hook
    iterator.apply(lambda sample: sample.on_deserialize())
    return iterator
def get_data_iterator(
    self,
    path: str,
    batch_size: int,
    is_train: bool,
    shuffle=True,
    shuffle_seed: int = 0,
    offset: int = 0,
) -> ShardedDataIterator:
    run_preprocessing = (
        self.distributed_factor == 1 or self.cfg.local_rank in [-1, 0]
    )

    gold_passages_src = self.cfg.gold_passages_src
    if gold_passages_src:
        if not is_train:
            gold_passages_src = self.cfg.gold_passages_src_dev
        assert os.path.exists(
            gold_passages_src
        ), "Please specify valid gold_passages_src/gold_passages_src_dev"

    dataset = ExtractiveReaderDataset(
        path,
        is_train,
        gold_passages_src,
        self.tensorizer,
        run_preprocessing,
        self.cfg.num_workers,
    )
    dataset.load_data()

    iterator = ShardedDataIterator(
        dataset,
        shard_id=self.shard_id,
        num_shards=self.distributed_factor,
        batch_size=batch_size,
        shuffle=shuffle,
        shuffle_seed=shuffle_seed,
        offset=offset,
    )
    # apply deserialization hook
    iterator.apply(lambda sample: sample.on_deserialize())
    return iterator
def get_data_iterator(
    self,
    path: str,
    batch_size: int,
    shuffle=True,
    shuffle_seed: int = 0,
    offset: int = 0,
    upsample_rates: list = None,
) -> ShardedDataIterator:
    data_files = glob.glob(path)
    data = read_data_from_json_files(data_files, upsample_rates)

    # filter those without positive ctx
    data = [r for r in data if len(r["positive_ctxs"]) > 0]
    logger.info("Total cleaned data size: %d", len(data))

    return ShardedDataIterator(
        data,
        shard_id=self.shard_id,
        num_shards=self.distributed_factor,
        batch_size=batch_size,
        shuffle=shuffle,
        shuffle_seed=shuffle_seed,
        offset=offset,
        strict_batch_size=True,  # this is not really necessary, one can probably disable it
    )
def get_data_iterator(
    self,
    batch_size: int,
    is_train_set: bool,
    shuffle=True,
    shuffle_seed: int = 0,
    offset: int = 0,
    rank: int = 0,
):
    hydra_datasets = (
        self.ds_cfg.train_datasets if is_train_set else self.ds_cfg.dev_datasets
    )
    sampling_rates = self.ds_cfg.sampling_rates

    logger.info(
        "Initializing task/set data %s",
        self.ds_cfg.train_datasets_names
        if is_train_set
        else self.ds_cfg.dev_datasets_names,
    )

    # randomized data loading to avoid file system congestion
    datasets_list = list(hydra_datasets)
    rnd = random.Random(rank)
    rnd.shuffle(datasets_list)
    for ds in datasets_list:
        ds.load_data()

    sharded_iterators = [
        ShardedDataIterator(
            ds,
            shard_id=self.shard_id,
            num_shards=self.distributed_factor,
            batch_size=batch_size,
            shuffle=shuffle,
            shuffle_seed=shuffle_seed,
            offset=offset,
        )
        for ds in hydra_datasets
    ]

    return MultiSetDataIterator(
        sharded_iterators,
        shuffle_seed,
        shuffle,
        sampling_rates=sampling_rates if is_train_set else [1],
        rank=rank,
    )
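# Consumption sketch for the MultiSetDataIterator returned above (hypothetical:
# the variable names, batch size, and epoch loop are illustrative assumptions).
# During training, batches are drawn from the underlying sharded iterators
# according to sampling_rates; for dev sets a uniform [1] rate is used:
#
#   multiset_iterator = self.get_data_iterator(
#       batch_size=cfg.train.batch_size, is_train_set=True, rank=cfg.local_rank
#   )
#   for epoch in range(num_train_epochs):
#       for batch in multiset_iterator.iterate_ds_data(epoch=epoch):
#           ...  # each batch originates from a single constituent dataset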
def _train_epoch(
    self,
    scheduler,
    epoch: int,
    eval_step: int,
    train_data_iterator: ShardedDataIterator,
    global_step: int,
):
    cfg = self.cfg
    rolling_train_loss = 0.0
    epoch_loss = 0
    log_result_step = cfg.train.log_batch_step
    rolling_loss_step = cfg.train.train_rolling_loss_step

    self.reader.train()
    epoch_batches = train_data_iterator.max_iterations

    for i, samples_batch in enumerate(
        train_data_iterator.iterate_ds_data(epoch=epoch)
    ):
        data_iteration = train_data_iterator.get_iteration()

        # enables resuming to exactly the same train state
        if cfg.fully_resumable:
            np.random.seed(cfg.seed + global_step)
            torch.manual_seed(cfg.seed + global_step)
            if cfg.n_gpu > 0:
                torch.cuda.manual_seed_all(cfg.seed + global_step)

        # renamed from `input` to avoid shadowing the builtin
        reader_input = create_reader_input(
            self.tensorizer.get_pad_id(),
            samples_batch,
            cfg.passages_per_question,
            cfg.encoder.sequence_length,
            cfg.max_n_answers,
            is_train=True,
            shuffle=True,
        )

        loss = self._calc_loss(reader_input)

        epoch_loss += loss.item()
        rolling_train_loss += loss.item()

        max_grad_norm = cfg.train.max_grad_norm
        if cfg.fp16:
            from apex import amp

            with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                scaled_loss.backward()
            if max_grad_norm > 0:
                torch.nn.utils.clip_grad_norm_(
                    amp.master_params(self.optimizer), max_grad_norm
                )
        else:
            loss.backward()
            if max_grad_norm > 0:
                torch.nn.utils.clip_grad_norm_(
                    self.reader.parameters(), max_grad_norm
                )

        if (i + 1) % cfg.train.gradient_accumulation_steps == 0:
            self.optimizer.step()
            scheduler.step()
            self.reader.zero_grad()
            global_step += 1

        if i % log_result_step == 0:
            lr = self.optimizer.param_groups[0]["lr"]
            logger.info(
                "Epoch: %d: Step: %d/%d, global_step=%d, lr=%f",
                epoch,
                data_iteration,
                epoch_batches,
                global_step,
                lr,
            )

        if (i + 1) % rolling_loss_step == 0:
            logger.info("Train batch %d", data_iteration)
            latest_rolling_train_av_loss = rolling_train_loss / rolling_loss_step
            logger.info(
                "Avg. loss per last %d batches: %f",
                rolling_loss_step,
                latest_rolling_train_av_loss,
            )
            rolling_train_loss = 0.0

        if global_step % eval_step == 0:
            logger.info(
                "Validation: Epoch: %d Step: %d/%d",
                epoch,
                data_iteration,
                epoch_batches,
            )
            self.validate_and_save(
                epoch, train_data_iterator.get_iteration(), scheduler
            )
            self.reader.train()

    epoch_loss = (epoch_loss / epoch_batches) if epoch_batches > 0 else 0
    logger.info("Av Loss per epoch=%f", epoch_loss)
    return global_step
def _train_epoch(
    self,
    scheduler,
    epoch: int,
    eval_step: int,
    train_data_iterator: ShardedDataIterator,
):
    args = self.args
    rolling_train_loss = 0.0
    epoch_loss = 0
    epoch_correct_predictions = 0

    log_result_step = args.log_batch_step
    rolling_loss_step = args.train_rolling_loss_step
    num_hard_negatives = args.hard_negatives
    num_other_negatives = args.other_negatives
    seed = args.seed

    self.biencoder.train()
    epoch_batches = train_data_iterator.max_iterations
    data_iteration = 0

    for i, samples_batch in enumerate(
        train_data_iterator.iterate_data(epoch=epoch)
    ):
        # to be able to resume shuffled ctx pools
        data_iteration = train_data_iterator.get_iteration()
        random.seed(seed + epoch + data_iteration)

        biencoder_batch = BiEncoder.create_biencoder_input(
            samples_batch,
            self.tensorizer,
            True,
            num_hard_negatives,
            num_other_negatives,
            shuffle=True,
            shuffle_positives=args.shuffle_positive_ctx,
        )

        loss, correct_cnt = _do_biencoder_fwd_pass(
            self.biencoder, biencoder_batch, self.tensorizer, args
        )

        epoch_correct_predictions += correct_cnt
        epoch_loss += loss.item()
        rolling_train_loss += loss.item()

        if args.fp16:
            from apex import amp

            with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                scaled_loss.backward()
            if args.max_grad_norm > 0:
                torch.nn.utils.clip_grad_norm_(
                    amp.master_params(self.optimizer), args.max_grad_norm
                )
        else:
            loss.backward()
            if args.max_grad_norm > 0:
                torch.nn.utils.clip_grad_norm_(
                    self.biencoder.parameters(), args.max_grad_norm
                )

        if (i + 1) % args.gradient_accumulation_steps == 0:
            self.optimizer.step()
            scheduler.step()
            self.biencoder.zero_grad()

        if i % log_result_step == 0:
            lr = self.optimizer.param_groups[0]["lr"]
            logger.info(
                "Epoch: %d: Step: %d/%d, loss=%f, lr=%f",
                epoch,
                data_iteration,
                epoch_batches,
                loss.item(),
                lr,
            )

        if (i + 1) % rolling_loss_step == 0:
            logger.info("Train batch %d", data_iteration)
            latest_rolling_train_av_loss = rolling_train_loss / rolling_loss_step
            logger.info(
                "Avg. loss per last %d batches: %f",
                rolling_loss_step,
                latest_rolling_train_av_loss,
            )
            rolling_train_loss = 0.0

        if data_iteration % eval_step == 0:
            logger.info(
                "Validation: Epoch: %d Step: %d/%d",
                epoch,
                data_iteration,
                epoch_batches,
            )
            self.validate_and_save(
                epoch, train_data_iterator.get_iteration(), scheduler
            )
            self.biencoder.train()

    self.validate_and_save(epoch, data_iteration, scheduler)

    epoch_loss = (epoch_loss / epoch_batches) if epoch_batches > 0 else 0
    logger.info("Av Loss per epoch=%f", epoch_loss)
    logger.info("epoch total correct predictions=%d", epoch_correct_predictions)
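# Driver sketch showing how an outer epoch loop might invoke _train_epoch above
# (hypothetical: start_epoch, args.num_train_epochs, args.eval_per_epoch, and
# the iterator construction arguments are assumptions, not from this codebase):
#
#   train_iterator = self.get_data_iterator(
#       args.train_file, args.batch_size, shuffle=True, shuffle_seed=args.seed
#   )
#   eval_step = math.ceil(train_iterator.max_iterations / args.eval_per_epoch)
#   for epoch in range(start_epoch, int(args.num_train_epochs)):
#       logger.info("***** Epoch %d *****", epoch)
#       self._train_epoch(scheduler, epoch, eval_step, train_iterator)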