def tokenize(path, output_path, buffer=1000):
    ''' Tokenize a file in parallel, chunk by chunk, and return the word counts

    The input is split into chunk files inside a temporary directory via
    `file_utils.split`. Each chunk is handed to `_tokenize` in a worker process,
    and the per-chunk outputs are merged into `output_path` via `file_utils.join`.

    Args:
        path: the file to tokenize
        output_path: where the joined output is written. The text after the last
            '.' of its basename is passed to `_tokenize` as the language.
        buffer: forwarded unchanged to `file_utils.split`

    Returns:
        A `Counter` holding the sum of the values returned by each `_tokenize` call.
    '''
    basename = os.path.basename(output_path)
    language = os.path.splitext(basename)[1][1:]
    word_counts = Counter()

    # The pool is a context manager so that the workers are terminated even if a
    # chunk fails; previously an exception from result.get() leaked the workers.
    # It is listed after the temporary directory so it is torn down first.
    with tempfile.TemporaryDirectory() as tmpdir, Pool() as pool:
        paths = []
        results = []
        file_chunks = file_utils.split(path, os.path.join(tmpdir, ''), buffer)
        for chunk in sorted(file_chunks):
            output_chunk = f'{chunk}{basename}'
            results.append(pool.apply_async(_tokenize, [language, chunk, output_chunk]))
            paths.append(output_chunk)

        # no more work will be submitted
        pool.close()

        progress = tqdm(
            results,
            unit='chunk',
            dynamic_ncols=True,
            desc=f'Tokenizing {basename}',
            file=sys.stdout # needed to make tqdm_wrap_stdout work
        )
        with tqdm_wrap_stdout():
            for result in progress:
                # get() re-raises any exception raised in the worker
                word_counts += result.get()

        pool.join()

        # chunks were submitted in sorted order, so `paths` is in file order
        file_utils.join(paths, output_path)

    return word_counts
def evaluate_epoch(self, epoch, experiment, stats_file, verbose=0):
    ''' Evaluate a single epoch

    Runs every batch of the dataloader through `self.evaluate`, accumulates the
    returned stats, logs the average per-token nll to the experiment, saves the
    stats to `stats_file` and returns that average.
    '''
    nll_metric = metrics.Metric('nll', metrics.format_float)

    def get_description():
        ''' Build the progress bar text; higher verbosity adds more detail '''
        prefix = 'Test' if self.dataset.split == 'test' else 'Validate'
        pieces = [f'{prefix} #{epoch}']
        if verbose > 0:
            pieces.append(f' {nll_metric}')
        if verbose > 1:
            pieces.append(f' [{profile.mem_stat_string(["allocated"])}]')
        return ''.join(pieces)

    progress = tqdm(
        self.dataloader,
        unit='batch',
        dynamic_ncols=True,
        desc=get_description(),
        file=sys.stdout # needed to make tqdm_wrap_stdout work
    )
    with tqdm_wrap_stdout():
        for batch in progress:
            # refresh the description, then run the data through the model
            progress.set_description_str(get_description())
            nll, length, stats = self.evaluate(batch)
            self.update_stats(stats, self.stats, self.count)

            # a zero length contributes nothing to the running average
            if not length:
                continue
            nll_metric.update(nll / length)

    average_nll = nll_metric.average
    experiment.log_metric('nll', average_nll)
    self.save_stats(stats_file)
    return nll_metric.average
def get_parse_vocab(path, segmenters, buffer=1000):
    ''' Collect the parse vocabulary of a file in parallel, chunk by chunk

    The input is split into chunk files inside a temporary directory via
    `file_utils.split`, and each chunk is handed to `_get_parse_vocab` in a
    worker process.

    Args:
        path: the file to extract the vocabulary from
        segmenters: forwarded unchanged to `_get_parse_vocab` for every chunk
        buffer: forwarded unchanged to `file_utils.split`

    Returns:
        A `set` holding the union of the values returned for each chunk.
    '''
    basename = os.path.basename(path)
    vocab = set()

    # The pool is a context manager so that the workers are terminated even if a
    # chunk fails; previously an exception from result.get() leaked the workers.
    # It is listed after the temporary directory so it is torn down first.
    with tempfile.TemporaryDirectory() as tmpdir, Pool() as pool:
        file_chunks = file_utils.split(path, os.path.join(tmpdir, ''), buffer)
        results = [
            pool.apply_async(_get_parse_vocab, [chunk, segmenters])
            for chunk in sorted(file_chunks)
        ]

        # no more work will be submitted
        pool.close()

        progress = tqdm(
            results,
            unit='chunk',
            dynamic_ncols=True,
            desc=f'Extracting vocab: {basename}',
            file=sys.stdout # needed to make tqdm_wrap_stdout work
        )
        with tqdm_wrap_stdout():
            for result in progress:
                # get() re-raises any exception raised in the worker
                vocab.update(result.get())

        pool.join()

    return vocab
def apply_bpe(bpe_path, path, output_path, buffer=1000):
    ''' Apply BPE to a file in parallel, chunk by chunk, and return the vocab

    The input is split into chunk files inside a temporary directory via
    `file_utils.split`. Each chunk is handed to `_apply_bpe` in a worker process,
    and the per-chunk outputs are merged into `output_path` via `file_utils.join`.

    Args:
        bpe_path: forwarded unchanged to `_apply_bpe` for every chunk
        path: the file to encode
        output_path: where the joined output is written
        buffer: forwarded unchanged to `file_utils.split`

    Returns:
        A `set` holding the union of the values returned by each `_apply_bpe` call.
    '''
    basename = os.path.basename(output_path)
    vocab = set()

    # The pool is a context manager so that the workers are terminated even if a
    # chunk fails; previously an exception from result.get() leaked the workers.
    # It is listed after the temporary directory so it is torn down first.
    with tempfile.TemporaryDirectory() as tmpdir, Pool() as pool:
        paths = []
        results = []
        file_chunks = file_utils.split(path, os.path.join(tmpdir, ''), buffer)
        for chunk in sorted(file_chunks):
            output_chunk = f'{chunk}{basename}'
            results.append(pool.apply_async(_apply_bpe, [bpe_path, chunk, output_chunk]))
            paths.append(output_chunk)

        # no more work will be submitted
        pool.close()

        progress = tqdm(
            results,
            unit='chunk',
            dynamic_ncols=True,
            desc=f'BPE encoding {basename}',
            file=sys.stdout # needed to make tqdm_wrap_stdout work
        )
        with tqdm_wrap_stdout():
            for result in progress:
                # get() re-raises any exception raised in the worker
                vocab.update(result.get())

        pool.join()

        # chunks were submitted in sorted order, so `paths` is in file order
        file_utils.join(paths, output_path)

    return vocab
def parse(path, output_path, buffer=1000):
    ''' Parse a file in parallel, chunk by chunk

    The input is split into chunk files inside a temporary directory via
    `file_utils.split`. Each chunk is renamed so that it ends with the basename
    of `path`, then handed to `_parse` in a worker process. The files expected
    from the workers are finally merged into `output_path` via `file_utils.join`.

    Args:
        path: the file to parse
        output_path: where the joined output is written
        buffer: forwarded unchanged to `file_utils.split`
    '''
    basename = os.path.basename(path)

    # The pool is a context manager so that the workers are terminated even if a
    # chunk fails; previously an exception from result.get() leaked the workers.
    # It is listed after the temporary directory so it is torn down first.
    with tempfile.TemporaryDirectory() as tmpdir, Pool() as pool:
        paths = []
        results = []
        file_chunks = file_utils.split(path, os.path.join(tmpdir, ''), buffer)
        for chunk in sorted(file_chunks):
            renamed_chunk = f'{chunk}{basename}'
            os.rename(chunk, renamed_chunk)
            results.append(pool.apply_async(_parse, [renamed_chunk]))

            # NOTE(review): presumably `_parse` writes its output next to the
            # chunk, dropping 'bpe.32000.' from the name and appending '.parse'
            # -- confirm against `_parse`
            paths.append(renamed_chunk.replace('bpe.32000.', '') + '.parse')

        # no more work will be submitted
        pool.close()

        progress = tqdm(
            results,
            unit='chunk',
            dynamic_ncols=True,
            desc=f'Parsing {basename}',
            file=sys.stdout # needed to make tqdm_wrap_stdout work
        )
        with tqdm_wrap_stdout():
            for result in progress:
                # the return value is unused; get() is called to wait for the
                # chunk and to re-raise any exception raised in the worker
                result.get()

        pool.join()

        # chunks were submitted in sorted order, so `paths` is in file order
        file_utils.join(paths, output_path)
def translate_all(self, output_file, epoch, experiment, verbose=0):
    ''' Generate all predictions from the dataset

    Every batch is translated while being timed with CUDA events; the elapsed
    time is appended to `self.time_profile`. Unless `self.config.timed` is set,
    the decoded sequences are written to `output_file`, either immediately or,
    when `self.config.order_output` is set, at the end sorted by example id.
    '''
    def get_description():
        ''' Build the progress bar text '''
        text = f'Generate #{epoch}'
        if verbose > 1:
            text += f' [{profile.mem_stat_string(["allocated"])}]'
        return text

    batches = tqdm(
        self.dataloader,
        unit='batch',
        dynamic_ncols=True,
        desc=get_description(),
        file=sys.stdout # needed to make tqdm_wrap_stdout work
    )
    with tqdm_wrap_stdout():
        deferred = []
        for batch in batches:
            batches.set_description_str(get_description())

            # time the translation step with a pair of CUDA events
            started = torch.cuda.Event(enable_timing=True)
            finished = torch.cuda.Event(enable_timing=True)
            started.record()
            sequences = self.translator.translate(batch)  # step to be profiled
            finished.record()
            torch.cuda.synchronize()
            self.time_profile.append(started.elapsed_time(finished))

            # when only timing, skip decoding and writing altogether
            if self.config.timed:
                continue

            target_sequences = next(iter(sequences.values()))
            for i, example_id in enumerate(batch['example_ids']):
                if verbose > 1:
                    trim = verbose < 2
                    join = verbose < 3
                    lines = []
                    for key, candidates in sequences.items():
                        text = ' '.join(self.dataset.decode(candidates[i], join, trim))
                        lines.append(f'{key}: {text}\n')
                    lines.append('+++++++++++++++++++++++++++++\n')
                else:
                    tokens = self.dataset.decode(target_sequences[i], trim=not verbose)
                    decoded = ' '.join(tokens)
                    lines = [f'{decoded}\n']

                if not self.config.order_output:
                    output_file.writelines(lines)
                else:
                    deferred.append((example_id, lines))

        # flush the deferred outputs in example id order
        for _, lines in sorted(deferred, key=lambda pair: pair[0]):
            output_file.writelines(lines)
async def __call__(self):
    """
    Run the generation!

    Schedules a generation for every dataset entry (optionally capped by
    `args.data.max_entries`) and logs the context, original text and sample
    for each one as it completes.
    """
    limit = self.args.data.max_entries
    entries = self.dataset.entries
    if limit:
        entries = entries[:limit]

    pending = [self.scheduler.generate(entry) for entry in entries]
    batch_iterator = tqdm(
        as_completed(pending),
        unit="entry",
        initial=1,
        dynamic_ncols=True,
        desc="Generating",
        total=len(entries),
        file=sys.stdout,  # needed to make tqdm_wrap_stdout work
    )

    sep = "*******\n"
    with tqdm_wrap_stdout():
        # example_id simply counts results in completion order
        for example_id, result in enumerate(batch_iterator):
            sample, batch, summary = await result

            # the summary tokens form the context; what follows them in the
            # batch is decoded as the original text
            summary_tokens = summary["tokens"]
            summary_length = len(summary_tokens)
            context = self.generator.tokenizer.decode(summary_tokens.tolist())
            remainder = batch["tokens"][summary_length:]
            original = self.generator.tokenizer.decode(remainder.tolist())

            logging.info(
                "#%d:\n%scontext\n%s%s\n%soriginal\n%s%s\n%ssample\n%s%s",
                example_id,
                sep,
                sep,
                context,
                sep,
                sep,
                original,
                sep,
                sep,
                sample,
            )

    batch_iterator.close()
def __call__(self) -> float:
    """
    Run the evaluation!

    Feeds every batch of the dataloader to `eval_step`. Out of memory errors
    are counted in the "oom" metric and logged; any other `RuntimeError` is
    re-raised. Returns the average of the "nll" metric.
    """
    device_count = len(self.model.device_ids)
    dataloader = get_dataloader(
        self.args.data, self.dataset, num_devices=device_count
    )

    def get_description():
        return f"Eval {self.metric_store}"

    batch_iterator = tqdm(
        dataloader,
        unit="batch",
        initial=1,
        dynamic_ncols=True,
        desc=get_description(),
        file=sys.stdout,  # needed to make tqdm_wrap_stdout work
    )

    with ExitStack() as stack:
        # pylint:disable=no-member
        for make_context in (tqdm_wrap_stdout, chunked_scattering):
            stack.enter_context(make_context())
        # pylint:enable=no-member

        for batch in batch_iterator:
            try:
                self.eval_step(batch)
            except RuntimeError as rte:
                if "out of memory" not in str(rte):
                    # unexpected failure: tidy up the progress bar and bail out
                    batch_iterator.close()
                    raise rte

                self.metric_store["oom"].update(1)
                logging.warning(str(rte))

            batch_iterator.set_description_str(get_description())

    batch_iterator.close()

    nll_metric = self.metric_store["nll"]
    return nll_metric.average
def train_epoch(self, epoch, experiment, verbose=0):
    ''' Run one training epoch

    Iterates over `self.dataloader`, accumulating gradients batch by batch and
    running the optimizer every `self.config.accumulate_steps` batches. Metrics
    are recorded in `self.metric_store` and logged to `experiment`.

    Args:
        epoch: the epoch number; shown in the progress bar and passed on to
            `self.evaluate`, `self.checkpoint` and `self.is_done`
        experiment: the experiment tracker; its `log_metric`, `set_step` and
            `curr_step` are used
        verbose: 0 shows only the epoch, >0 adds the metric store, >1 also adds
            the memory stats to the progress bar description
    '''
    oom = self.metric_store['oom']
    learning_rate = self.metric_store['lr']
    num_tokens = self.metric_store['num_tok']
    neg_log_likelihood = self.metric_store['nll']

    def try_optimize(i, last=False):
        ''' Run the optimizer if batch `i` completes an accumulation window

        Returns True if an optimization step was taken, False otherwise.
        '''
        # optimize if:
        #  1) last and remainder
        #  2) not last and not remainder
        remainder = bool(i % self.config.accumulate_steps)
        # `^` binds tighter than `not`, so this reads `not (last ^ remainder)`,
        # i.e. it is true when `last` and `remainder` are equal
        if not last ^ remainder:
            next_lr = self.optimize()

            learning_rate.update(next_lr)
            experiment.log_metric('learning_rate', next_lr)
            return True

        return False

    def get_description():
        ''' Build the progress bar text; higher verbosity adds more detail '''
        description = f'Train #{epoch}'
        if verbose > 0:
            description += f' {self.metric_store}'

        if verbose > 1:
            description += f' [{profile.mem_stat_string(["allocated"])}]'

        return description

    batches = tqdm(
        self.dataloader,
        unit='batch',
        dynamic_ncols=True,
        desc=get_description(),
        file=sys.stdout # needed to make tqdm_wrap_stdout work
    )
    with tqdm_wrap_stdout():
        # `i` is initialised so the final try_optimize call below still works
        # when the dataloader yields no batches
        i = 1
        # running totals since the last optimizer step
        nll_per_update = 0.
        length_per_update = 0
        num_tokens_per_update = 0
        for i, batch in enumerate(batches, 1):
            try:
                nll, length = self.calculate_gradient(batch)
                did_optimize = try_optimize(i)

                # record the effective number of tokens
                num_tokens_per_update += int(sum(batch['input_lens']))
                num_tokens_per_update += int(sum(batch['target_lens']))

                if length:
                    # record length and nll
                    nll_per_update += nll
                    length_per_update += length

                if did_optimize:
                    # advance the experiment step
                    experiment.set_step(experiment.curr_step + 1)

                    num_tokens.update(num_tokens_per_update)
                    # NOTE(review): this divides by zero if every batch in the
                    # window reported a falsy length -- confirm that cannot happen
                    neg_log_likelihood.update(nll_per_update / length_per_update)

                    experiment.log_metric('num_tokens', num_tokens_per_update)
                    experiment.log_metric('nll', neg_log_likelihood.last_value)
                    # experiment.log_metric('max_memory_alloc', torch.cuda.max_memory_allocated()//1024//1024)
                    # experiment.log_metric('max_memory_cache', torch.cuda.max_memory_cached()//1024//1024)

                    # start a fresh accumulation window
                    nll_per_update = 0.
                    length_per_update = 0
                    num_tokens_per_update = 0

            except RuntimeError as rte:
                if 'out of memory' in str(rte):
                    # out of memory is treated as recoverable: free the cache,
                    # count it, and carry on with the next batch. The running
                    # totals above are deliberately left untouched.
                    torch.cuda.empty_cache()

                    oom.update(1)
                    experiment.log_metric('oom', oom.total)
                    #exit(-1)
                else:
                    # any other runtime error is fatal
                    batches.close()
                    raise rte

            if self.should_checkpoint():
                new_best = False
                if self.config.early_stopping:
                    # evaluation draws its own progress bar, so temporarily
                    # undo the stdout wrapping
                    with tqdm_unwrap_stdout():
                        new_best = self.evaluate(experiment, epoch, verbose)

                self.checkpoint(epoch, experiment.curr_step, new_best)

            batches.set_description_str(get_description())
            if self.is_done(experiment, epoch):
                batches.close()
                break

        # flush any gradients left over from an incomplete accumulation window
        try_optimize(i, last=True)
def translate_all(self, output_file, stats_file, example_file, epoch, experiment, verbose=0):
    ''' Generate all predictions from the dataset and collect attention stats

    For each batch the model is run twice under `torch.no_grad()`: once on the
    batch as loaded (recorded into `self.train_stats`) and once after the
    batch's 'target' field has been replaced by the translator's own output
    (recorded into `self.test_stats`). The decoded translations are written to
    `output_file` and the accumulated stats are saved to `stats_file`.

    Args:
        output_file: file object receiving the decoded translations
        stats_file: passed to `self.save_stats2` at the end
        example_file: file object receiving a JSON dump of the attention weight
            tensors for the hard-coded example id 430
        epoch: shown in the progress bar description
        experiment: unused in this method
        verbose: >1 writes every returned sequence per example and adds memory
            stats to the progress bar description
    '''
    def get_description():
        ''' Build the progress bar text '''
        description = f'Generate #{epoch}'
        if verbose > 1:
            description += f' [{profile.mem_stat_string(["allocated"])}]'
        return description

    batches = tqdm(
        self.dataloader,
        unit='batch',
        dynamic_ncols=True,
        desc=get_description(),
        file=sys.stdout # needed to make tqdm_wrap_stdout work
    )
    with tqdm_wrap_stdout():
        ordered_outputs = []
        with torch.no_grad():
            self.model.eval()
            # number of batches fully processed so far
            count = 0
            for batch in batches:
                # run the data through the model
                batches.set_description_str(get_description())
                result = self.model(batch)

                # stats for the batch as loaded (original targets)
                encoder_stats = probe(
                    result['encoder_attn_weights_tensor'])
                decoder_stats = probe(
                    result['decoder_attn_weights_tensor'])
                enc_dec_stats = probe(
                    result['enc_dec_attn_weights_tensor'])
                # each stat is reshaped to (num_layers, num_heads, -1) and
                # converted to a numpy array
                train_stats = {
                    'encoder_stats': {
                        stats_type: encoder_stats[stats_type].view(
                            self.num_layers, self.num_heads, -1).cpu().numpy()
                        for stats_type in STATS_TYPES
                    },
                    'decoder_stats': {
                        stats_type: decoder_stats[stats_type].view(
                            self.num_layers, self.num_heads, -1).cpu().numpy()
                        for stats_type in STATS_TYPES
                    },
                    'enc_dec_stats': {
                        stats_type: enc_dec_stats[stats_type].view(
                            self.num_layers, self.num_heads, -1).cpu().numpy()
                        for stats_type in STATS_TYPES
                    }
                }
                self.update_stats2(train_stats, self.train_stats, self.train_count)

                # the second value returned by the translator is discarded:
                # `test_stats` is rebuilt from scratch further down
                sequences, test_stats = self.translator.translate(batch)
                # self.update_stats(test_stats, self.test_stats, self.test_count)

                if self.config.timed:
                    continue

                target_sequences = next(iter(sequences.values()))
                new_targets = []
                for i, example_id in enumerate(batch['example_ids']):
                    # print("example_id", example_id)
                    # NOTE(review): hard-coded example id; the dump holds the
                    # tensors for the whole batch, not just this example
                    if example_id == 430:
                        print("saved")
                        train_tensors = {
                            'encoder':
                            result['encoder_attn_weights_tensor'].cpu(
                            ).numpy().tolist(),
                            'decoder':
                            result['decoder_attn_weights_tensor'].cpu(
                            ).numpy().tolist(),
                            'enc_dec':
                            result['enc_dec_attn_weights_tensor'].cpu().
                            numpy().tolist()
                        }
                        json.dump(train_tensors, example_file)

                    outputs = []
                    if verbose > 1:
                        # within this branch `verbose < 2` is always False
                        trim = verbose < 2
                        join = verbose < 3
                        for key in sequences.keys():
                            sequence = sequences[key][i]
                            sequence = ' '.join(
                                self.dataset.decode(sequence, join, trim))
                            outputs.append(f'{key}: {sequence}\n')
                        outputs.append(f'+++++++++++++++++++++++++++++\n')
                    else:
                        sequence = target_sequences[i]
                        new_targets.append(torch.LongTensor(sequence))
                        decoded = ' '.join(
                            self.dataset.decode(sequence, trim=not verbose))
                        outputs.append(f'{decoded}\n')

                    if self.config.order_output:
                        ordered_outputs.append((example_id, outputs))
                    else:
                        output_file.writelines(outputs)

                # NOTE(review): `new_targets` is only filled in the
                # non-verbose branch above, so it is empty when verbose > 1
                # -- confirm that is acceptable for collate_field
                self.dataset.collate_field(batch, 'target', new_targets)
                result = self.model(batch)

                # stats for the batch with the generated targets
                encoder_stats = probe(
                    result['encoder_attn_weights_tensor'])
                decoder_stats = probe(
                    result['decoder_attn_weights_tensor'])
                enc_dec_stats = probe(
                    result['enc_dec_attn_weights_tensor'])
                test_stats = {
                    'encoder_stats': {
                        stats_type: encoder_stats[stats_type].view(
                            self.num_layers, self.num_heads, -1).cpu().numpy()
                        for stats_type in STATS_TYPES
                    },
                    'decoder_stats': {
                        stats_type: decoder_stats[stats_type].view(
                            self.num_layers, self.num_heads, -1).cpu().numpy()
                        for stats_type in STATS_TYPES
                    },
                    'enc_dec_stats': {
                        stats_type: enc_dec_stats[stats_type].view(
                            self.num_layers, self.num_heads, -1).cpu().numpy()
                        for stats_type in STATS_TYPES
                    }
                }
                self.update_stats2(test_stats, self.test_stats, self.test_count)

                # NOTE(review): processing stops after 50 batches; the rest of
                # the dataloader is never translated
                count += 1
                if count == 50:
                    break

        # flush the deferred outputs in example id order
        for _, outputs in sorted(ordered_outputs, key=lambda x: x[0]): # pylint:disable=consider-using-enumerate
            output_file.writelines(outputs)

    self.save_stats2(stats_file)
def translate_all(self, output_file, epoch, experiment, verbose=0):
    ''' Generate all predictions from the dataset and save attention heatmaps

    For each batch the translator is run, the decoded translations are written
    to `output_file`, and attention heatmaps are saved as PNG files in
    `self.config.output_directory` via `save_attention`. Encoder heatmaps come
    from the translator's attention weights. Decoder and encoder-decoder
    heatmaps come from a second pass of the model, run after the batch's
    'target' field has been replaced by the generated sequences.

    Args:
        output_file: file object receiving the decoded translations
        epoch: shown in the progress bar description
        experiment: unused in this method
        verbose: >1 writes every returned sequence per example and adds memory
            stats to the progress bar description
    '''
    def get_description():
        ''' Build the progress bar text '''
        description = f'Generate #{epoch}'
        if verbose > 1:
            description += f' [{profile.mem_stat_string(["allocated"])}]'
        return description

    batches = tqdm(
        self.dataloader,
        unit='batch',
        dynamic_ncols=True,
        desc=get_description(),
        file=sys.stdout # needed to make tqdm_wrap_stdout work
    )
    with tqdm_wrap_stdout():
        self.model.eval()
        ordered_outputs = []
        for batch in batches:
            # print("in probe new translate", flush=True)
            # run the data through the model
            batches.set_description_str(get_description())
            sequences, attn_weights_tensors_dict = self.translator.translate(batch)

            if self.config.timed:
                continue

            target_sequences = next(iter(sequences.values()))
            encoder_attn_weights_tensor = attn_weights_tensors_dict['encoder_attn_weights_tensor']
            new_targets = []
            # per-example decoded strings, reused for the decoder heatmaps below
            output_sentences = []
            source_sentences = []
            for i, example_id in enumerate(batch['example_ids']):
                outputs = []
                if verbose > 1:
                    # within this branch `verbose < 2` is always False
                    trim = verbose < 2
                    join = verbose < 3
                    for key in sequences.keys():
                        sequence = sequences[key][i]
                        sequence = ' '.join(self.dataset.decode(sequence, join, trim))
                        outputs.append(f'{key}: {sequence}\n')
                    outputs.append(f'+++++++++++++++++++++++++++++\n')
                else:
                    sequence = target_sequences[i]
                    new_targets.append(torch.LongTensor(sequence))
                    decoded = ' '.join(self.dataset.decode(sequence, trim=not verbose))
                    outputs.append(f'{decoded}\n')

                    # decoded with join=False to label the heatmap axes
                    output_sentence = ' '.join(self.dataset.decode(sequence, join=False, trim=not verbose))
                    output_sentences.append(output_sentence)
                    source_sentence = ' '.join(self.dataset.decode(batch['inputs'][i], join=False, trim=not verbose))
                    source_sentences.append(source_sentence)

                    # Encoder heatmap
                    # print("saving encoder heatmap")
                    # one image per (j, k) pair; the filename labels them as
                    # layer `l` and head `h`
                    # NOTE(review): the tensor is not indexed by the example
                    # index `i` -- presumably the batch size is 1; confirm
                    for j in range(encoder_attn_weights_tensor.shape[0]):
                        for k in range(encoder_attn_weights_tensor.shape[1]):
                            attn_filename = f'encoder_attn_weights{example_id}_l{j}_h{k}.png'
                            attn_path = os.path.join(self.config.output_directory,
                                                     attn_filename)
                            save_attention(source_sentence, source_sentence,
                                           encoder_attn_weights_tensor[j][k].cpu().numpy(), attn_path)

                if self.config.order_output:
                    ordered_outputs.append((example_id, outputs))
                else:
                    output_file.writelines(outputs)

            # rerun the model with the generated sequences as the targets
            self.dataset.collate_field(batch, 'target', new_targets)
            result = self.model(batch)

            # Decoder heatmap
            # print("saving decoder heatmap")
            # NOTE(review): `output_sentences` and `source_sentences` are only
            # filled in the non-verbose branch above, so indexing them here
            # would fail when verbose > 1 -- confirm
            for i, example_id in enumerate(batch['example_ids']):
                for j in range(result['decoder_attn_weights_tensor'].shape[0]):
                    for k in range(result['decoder_attn_weights_tensor'].shape[1]):
                        attn_filename = f'decoder_attn_weights{example_id}_l{j}_h{k}.png'
                        attn_path = os.path.join(self.config.output_directory, attn_filename)
                        save_attention(output_sentences[i], output_sentences[i],
                                       result['decoder_attn_weights_tensor'][j][k].cpu().numpy(), attn_path)

                        # the encoder-decoder heatmap reuses the same (j, k)
                        # range, with '<SOS> ' prepended to the output label
                        attn_filename = f'enc_dec_attn_weights{example_id}_l{j}_h{k}.png'
                        attn_path = os.path.join(self.config.output_directory, attn_filename)
                        save_attention(source_sentences[i], '<SOS> ' + output_sentences[i],
                                       result['enc_dec_attn_weights_tensor'][j][k].cpu().numpy(), attn_path)

        # flush the deferred outputs in example id order
        for _, outputs in sorted(ordered_outputs, key=lambda x: x[0]): # pylint:disable=consider-using-enumerate
            output_file.writelines(outputs)
def translate_all(self, output_file, enc_off_diagonal_output_file, dec_off_diagonal_output_file, epoch, experiment, verbose=0):
    ''' Generate all predictions from the dataset and bin off-diagonal attention

    For each batch the translator is run and the decoded translations are
    written to `output_file`. The model is then rerun with the generated
    sequences as targets, and for the encoder and decoder self-attention the
    number of positions whose strongest attention weight lies at least
    `self.config.off_diagonal_distance_threshold` away from the scaled diagonal
    is counted. That count is turned into a bin index, which is tallied in
    `self.number_frac_dict` and `self.number_frac_list_dict`.

    Args:
        output_file: file object receiving the decoded translations
        enc_off_diagonal_output_file: file object receiving, per bin, the
            example ids recorded for the encoder
        dec_off_diagonal_output_file: same, for the decoder
        epoch: shown in the progress bar description
        experiment: unused in this method
        verbose: >1 writes every returned sequence per example and adds memory
            stats to the progress bar description
    '''
    def get_description():
        ''' Build the progress bar text '''
        description = f'Generate #{epoch}'
        if verbose > 1:
            description += f' [{profile.mem_stat_string(["allocated"])}]'
        return description

    batches = tqdm(
        self.dataloader,
        unit='batch',
        dynamic_ncols=True,
        desc=get_description(),
        file=sys.stdout # needed to make tqdm_wrap_stdout work
    )
    with tqdm_wrap_stdout():
        self.model.eval()
        ordered_outputs = []
        for batch in batches:
            # run the data through the model
            batches.set_description_str(get_description())
            sequences = self.translator.translate(batch)

            if self.config.timed:
                continue

            target_sequences = next(iter(sequences.values()))
            new_targets = []
            output_sentences = []
            source_sentences = []
            for i, example_id in enumerate(batch['example_ids']):
                outputs = []
                if verbose > 1:
                    # within this branch `verbose < 2` is always False
                    trim = verbose < 2
                    join = verbose < 3
                    for key in sequences.keys():
                        sequence = sequences[key][i]
                        sequence = ' '.join(self.dataset.decode(sequence, join, trim))
                        outputs.append(f'{key}: {sequence}\n')
                    outputs.append(f'+++++++++++++++++++++++++++++\n')
                else:
                    sequence = target_sequences[i]
                    new_targets.append(torch.LongTensor(sequence))
                    decoded = ' '.join(self.dataset.decode(sequence, trim=not verbose))
                    outputs.append(f'{decoded}\n')

                    # these decoded sentences are only consumed by the
                    # (disabled) heatmap code
                    output_sentence = ' '.join(self.dataset.decode(sequence, join=False, trim=not verbose))
                    output_sentences.append(output_sentence)
                    source_sentence = ' '.join(self.dataset.decode(batch['inputs'][i], join=False, trim=not verbose))
                    source_sentences.append(source_sentence)

                    # Encoder heatmap (disabled)
                    # for j in range(encoder_attn_weights_tensor.shape[0]):
                    #     for k in range(encoder_attn_weights_tensor.shape[1]):
                    #         attn_filename = f'encoder_attn_weights{example_id}_l{j}_h{k}.png'
                    #         attn_path = os.path.join(self.config.output_directory, attn_filename)
                    #         save_attention(source_sentence, source_sentence,
                    #                        encoder_attn_weights_tensor[j][k].cpu().numpy(), attn_path)

                if self.config.order_output:
                    ordered_outputs.append((example_id, outputs))
                else:
                    output_file.writelines(outputs)

            # rerun the model with the generated sequences as the targets
            self.dataset.collate_field(batch, 'target', new_targets)
            result = self.model(batch)

            # NOTE(review): the attention tensors below are not indexed by the
            # example index `i`, so every example in a batch is assigned the
            # same bin -- presumably the batch size is 1; confirm
            for i, example_id in enumerate(batch['example_ids']):
                for coder in ['encoder', 'decoder']:
                    attn_weights_shape = result[coder + '_attn_weights_tensor'].shape
                    # collapse all leading dimensions, keeping the last two
                    attn_weights = result[coder + '_attn_weights_tensor'].view(-1,
                                                                               attn_weights_shape[2],
                                                                               attn_weights_shape[3])
                    # expected attended index for each position along dim 1:
                    # the position scaled by the dataset's word count ratio
                    indices_q = torch.round(torch.arange(attn_weights_shape[2],
                                                         dtype=torch.float32,
                                                         device=attn_weights.get_device()).view(1, -1)
                                            * self.dataset.word_count_ratio)
                    # index and value of the strongest weight along the last dim
                    argmax_weights = torch.argmax(attn_weights, dim=2)
                    max_weights = torch.max(attn_weights, dim=2)[0]
                    # how far the strongest weight is from the expected index
                    distance = torch.abs(argmax_weights.type_as(indices_q) - indices_q)

                    # NOTE(review): `max_prob` and `argmax_offset` are only used
                    # by the disabled "offset"/"prob" threshold types below;
                    # torch.max over an empty selection may raise when nothing
                    # reaches the threshold -- confirm
                    max_prob = torch.max(max_weights[distance >= self.config.off_diagonal_distance_threshold])
                    argmax_offset = torch.max(distance)
                    # number of positions at or beyond the distance threshold
                    number = torch.sum(distance >= self.config.off_diagonal_distance_threshold)

                    if self.config.off_diagonal_threshold_type == "number":
                        # fraction of off-diagonal positions, scaled to a bin index
                        idx = int(torch.round(number.to(torch.float32)
                                              / float(attn_weights.shape[0] * attn_weights.shape[1])
                                              * self.config.off_diagonal_bins).cpu().item())
                        self.number_frac_dict[coder][idx] += 1
                        self.number_frac_list_dict[coder][idx].append(example_id)
                    # Disabled alternatives, kept for reference:
                    # elif self.config.off_diagonal_threshold_type == "offset":
                    #     if argmax_offset >= self.config.off_diagonal_threshold_param:
                    #         self.off_diagonal.append(example_id)
                    #     else:
                    #         self.non_off_diagonal.append(example_id)
                    # else:  # prob
                    #     if max_prob >= self.config.off_diagonal_threshold_param:
                    #         self.off_diagonal.append(example_id)
                    #     else:
                    #         self.non_off_diagonal.append(example_id)

                # Decoder / encoder-decoder heatmaps (disabled)
                # for j in range(result['decoder_attn_weights_tensor'].shape[0]):
                #     for k in range(result['decoder_attn_weights_tensor'].shape[1]):
                #         attn_filename = f'decoder_attn_weights{example_id}_l{j}_h{k}.png'
                #         attn_path = os.path.join(self.config.output_directory, attn_filename)
                #         save_attention(output_sentences[i], output_sentences[i],
                #                        result['decoder_attn_weights_tensor'][j][k].cpu().numpy(), attn_path)
                #         attn_filename = f'enc_dec_attn_weights{example_id}_l{j}_h{k}.png'
                #         attn_path = os.path.join(self.config.output_directory, attn_filename)
                #         save_attention(source_sentences[i], '<PAD>' + output_sentences[i],
                #                        result['enc_dec_attn_weights_tensor'][j][k].cpu().numpy(), attn_path)

        # print("num off diagonal", len(self.off_diagonal))
        # print("num non off diagonal", len(self.non_off_diagonal))

        # print the histograms, sorted by bin
        # NOTE(review): `self.number_dict` is printed but never updated in this
        # method -- confirm it is maintained elsewhere
        pp = pprint.PrettyPrinter()
        print("---------Encoder---------")
        print("number_dict")
        pp.pprint([(k, self.number_dict['encoder'][k]) for k in sorted(self.number_dict['encoder'].keys())])
        print("number_frac_dict")
        pp.pprint([(k, self.number_frac_dict['encoder'][k]) for k in sorted(self.number_frac_dict['encoder'].keys())])

        print("---------Decoder---------")
        print("number_dict")
        pp.pprint([(k, self.number_dict['decoder'][k]) for k in sorted(self.number_dict['decoder'].keys())])
        print("number_frac_dict")
        pp.pprint(
            [(k, self.number_frac_dict['decoder'][k]) for k in sorted(self.number_frac_dict['decoder'].keys())])

        # one line per bin: "<bin>\t<space separated example ids>"
        for k in sorted(self.number_frac_dict['encoder'].keys()):
            enc_off_diagonal_output_file.write(str(k) + "\t" + " ".join(str(x) for x in self.number_frac_list_dict['encoder'][k]) + "\n")
        for k in sorted(self.number_frac_dict['decoder'].keys()):
            dec_off_diagonal_output_file.write(
                str(k) + "\t" + " ".join(str(x) for x in self.number_frac_list_dict['decoder'][k]) + "\n")

        # off_diagonal_output_file.write(str(len(self.off_diagonal)) + "\t" + " ".join([str(x) for x in self.off_diagonal]) + "\n")
        # off_diagonal_output_file.write(str(len(self.non_off_diagonal)) + "\t" + " ".join([str(x) for x in self.non_off_diagonal]) + "\n")

        # flush the deferred outputs in example id order
        for _, outputs in sorted(ordered_outputs, key=lambda x: x[0]): # pylint:disable=consider-using-enumerate
            output_file.writelines(outputs)
def train_epoch(self, epoch, experiment, verbose=0):
    ''' Run one training epoch

    Iterates over `self.dataloader`, accumulating gradients batch by batch and
    running the optimizer every `self.config.accumulate_steps` batches. Metrics
    (including perplexity) are recorded in `self.metric_store` and logged to
    `experiment`.

    Args:
        epoch: the epoch number; shown in the progress bar and passed on to
            `self.evaluate`, `self.checkpoint` and `self.is_done`
        experiment: the experiment tracker; its `log_metric`, `set_step` and
            `curr_step` are used
        verbose: 0 shows only the epoch, >0 adds the metric store, >1 also adds
            the memory stats to the progress bar description

    Raises:
        RuntimeError: any runtime error raised while computing gradients or
            optimizing is re-raised. Out of memory errors are first recorded in
            the 'oom' metric.
    '''
    oom = self.metric_store['oom']
    learning_rate = self.metric_store['lr']
    num_tokens = self.metric_store['num_tok']
    neg_log_likelihood = self.metric_store['nll']
    perplexity = self.metric_store['ppl']

    def try_optimize(i, curr_step, last=False):
        ''' Run the optimizer if batch `i` completes an accumulation window

        Returns True if an optimization step was taken, False otherwise.
        '''
        # optimize if:
        #  1) last and remainder
        #  2) not last and not remainder
        remainder = bool(i % self.config.accumulate_steps)
        # `^` binds tighter than `not`, so this reads `not (last ^ remainder)`
        if not last ^ remainder:
            next_lr = self.optimize(curr_step)

            learning_rate.update(next_lr)
            experiment.log_metric('learning_rate', next_lr)
            return True

        return False

    def get_description():
        ''' Build the progress bar text; higher verbosity adds more detail '''
        description = f'Train #{epoch}'
        if verbose > 0:
            description += f' {self.metric_store}'

        if verbose > 1:
            description += f' [{profile.mem_stat_string(["allocated"])}]'

        return description

    batches = tqdm(
        self.dataloader,
        unit='batch',
        dynamic_ncols=True,
        desc=get_description(),
        file=sys.stdout # needed to make tqdm_wrap_stdout work
    )
    with tqdm_wrap_stdout():
        # `i` is initialised so the final try_optimize call below still works
        # when the dataloader yields no batches
        i = 1
        # running totals since the last optimizer step
        nll_per_update = 0.
        length_per_update = 0
        num_tokens_per_update = 0
        for i, batch in enumerate(batches, 1):
            if isinstance(batch, torch.Tensor):
                batch = (batch,)
            else:
                # concatenated dataset: drop the leading dimension of each part
                batch = [b.squeeze(0) for b in batch]

            try:
                nll, length = self.calculate_gradient(batch)
                did_optimize = try_optimize(i, experiment.curr_step)

                # record the effective number of tokens
                num_tokens_per_update += length

                if length:
                    # record length and nll
                    nll_per_update += nll
                    length_per_update += length

                if did_optimize:
                    # advance the experiment step
                    experiment.set_step(experiment.curr_step + 1)

                    num_tokens.update(num_tokens_per_update)
                    nll = nll_per_update / length_per_update
                    neg_log_likelihood.update(nll)
                    perplexity.update(np.exp(nll))

                    experiment.log_metric('num_tokens', num_tokens_per_update)
                    experiment.log_metric('nll', neg_log_likelihood.last_value)
                    experiment.log_metric('ppl', perplexity.last_value)

                    # start a fresh accumulation window
                    nll_per_update = 0.
                    length_per_update = 0
                    num_tokens_per_update = 0

            except RuntimeError as rte:
                if 'out of memory' in str(rte):
                    # record the out of memory error before giving up
                    torch.cuda.empty_cache()

                    oom.update(1)
                    experiment.log_metric('oom', oom.total)

                # every runtime error is fatal here; close the progress bar on
                # both paths (the out of memory path used to skip this)
                batches.close()
                raise

            if self.should_checkpoint():
                new_best = False
                if self.config.early_stopping:
                    # evaluation draws its own progress bar, so temporarily
                    # undo the stdout wrapping
                    with tqdm_unwrap_stdout():
                        new_best = self.evaluate(experiment, epoch, verbose)

                self.checkpoint(epoch, experiment.curr_step, new_best)

            batches.set_description_str(get_description())
            if self.is_done(experiment, epoch):
                batches.close()
                break

        # flush any gradients left over from an incomplete accumulation window
        try_optimize(i, experiment.curr_step, last=True)
def translate_all(self, output_file, stats_file, epoch, experiment, verbose=0):
    ''' Generate all predictions from the dataset

    For every batch the encoder-decoder attention stats are recorded twice:
    once for the batch as loaded (into the train stats) and once after its
    'target' field has been replaced by the translator's output (into the
    test stats). The accumulated stats are then saved to `stats_file`.
    '''
    def get_description():
        ''' Build the progress bar text '''
        text = f'Generate #{epoch}'
        if verbose > 1:
            text += f' [{profile.mem_stat_string(["allocated"])}]'
        return text

    def enc_dec_summary(batch):
        ''' Run the model on the batch and reshape its enc-dec attention stats '''
        model_output = self.model(batch)
        probed = probe(model_output['enc_dec_attn_weights_tensor'])
        per_type = {
            stats_type: probed[stats_type].view(self.num_layers, self.num_heads, -1)
            for stats_type in STATS_TYPES
        }
        return {'enc_dec_stats': per_type}

    batches = tqdm(
        self.dataloader,
        unit='batch',
        dynamic_ncols=True,
        desc=get_description(),
        file=sys.stdout # needed to make tqdm_wrap_stdout work
    )
    with tqdm_wrap_stdout(), torch.no_grad():
        self.model.eval()
        for batch in batches:
            batches.set_description_str(get_description())

            # first pass: the batch as loaded
            self.update_stats(enc_dec_summary(batch), self.train_stats, self.train_count)

            # translate, then swap the generated sequences in as the targets
            sequences, _ = self.translator.translate(batch)
            generated = next(iter(sequences.values()))
            new_targets = [torch.LongTensor(tokens) for tokens in generated]
            self.dataset.collate_field(batch, 'target', new_targets)

            # second pass: the batch with the generated targets
            self.update_stats(enc_dec_summary(batch), self.test_stats, self.test_count)

            if self.config.timed:
                continue

    self.save_stats(stats_file)
def __call__(self):
    """
    Run the training!

    Wraps the model in `nn.DataParallel` and cycles over the dataloader,
    accumulating gradients and stepping the optimizer and scheduler every
    `args.optim.gradient_accumulation_steps` batches. Checkpoints are saved
    every `args.save_steps` optimizer steps. When `args.optim.early_stopping`
    is set, an `Evaluator` is run on the "validation" dataset after each save
    to track the validation nll. Out of memory errors are counted and logged;
    other runtime errors are re-raised.
    """
    # Must be called first
    self.try_init_amp()

    model = self.modules["model"]
    optimizer = self.modules["optimizer"]
    scheduler = self.modules["scheduler"]

    if self.args.optim.use_gradient_checkpointing:
        model.enable_gradient_checkpointing()

    model = nn.DataParallel(model)
    dataloader = get_dataloader(
        self.args.data,
        self.dataset,
        num_devices=len(model.device_ids),
        shuffle=True,
    )

    def get_description():
        return f"Train {self.metric_store}"

    max_steps = self.args.optim.max_steps
    accumulation_steps = self.args.optim.gradient_accumulation_steps

    # The progress bar has no iterable: it counts optimizer steps and is
    # advanced manually with progress.update() below
    progress = tqdm(
        unit="step",
        initial=self.step,
        dynamic_ncols=True,
        desc=get_description(),
        total=max_steps,
        file=sys.stdout,  # needed to make tqdm_wrap_stdout work
    )

    with ExitStack() as stack:
        # pylint:disable=no-member
        stack.enter_context(tqdm_wrap_stdout())
        stack.enter_context(chunked_scattering())
        stack.enter_context(self.experiment.train())
        # pylint:enable=no-member

        if self.args.optim.early_stopping:
            # If using early stopping, must evaluate regularly to determine
            # if training should stop early, so setup an Evaluator
            eval_args = copy.deepcopy(self.args)
            eval_args.data.batch_size = self.args.optim.eval_batch_size

            # The evaluator shares the (DataParallel-wrapped) model and the
            # experiment with the trainer
            evaluator = Evaluator(eval_args)
            evaluator.model = model
            evaluator.load_dataset("validation")
            evaluator.initialize_experiment(experiment=self.experiment)

            # Make sure we are tracking validation nll
            self.metric_store.add(
                metrics.Metric("vnll", "format_float", "g(m)"))

            # And store a local variable for easy access
            vnll_metric = self.metric_store["vnll"]

        # Running totals since the last optimizer step
        loss = 0
        num_tokens = 0
        for step, batch in enumerate(cycle(dataloader), 1):
            try:
                step_loss = self.compute_gradients_and_loss(
                    batch, model, optimizer)
                run_optimizer = (step % accumulation_steps) == 0

                if run_optimizer:
                    # Run an optimization step
                    optimizer.step()
                    scheduler.step()  # Update learning rate schedule
                    model.zero_grad()

                # Update loss and num tokens after running an optimization
                # step, in case it results in an out of memory error
                loss += step_loss
                num_tokens += batch["num_tokens"]

                if run_optimizer:
                    # Since we ran the optimizer, increment current step
                    self.step += 1
                    self.experiment.set_step(self.step)
                    progress.update()

                    # update our metrics as well
                    self.update_metrics(
                        loss / accumulation_steps,
                        num_tokens,
                        scheduler.get_lr()[0],
                    )
                    num_tokens = 0
                    loss = 0

                    # and finally check if we should save
                    if (self.args.save_steps > 0
                            and self.step % self.args.save_steps == 0):
                        # First save the current checkpoint
                        self.save()

                        # Then if we are implementing early stopping, see
                        # if we achieved a new best
                        if self.args.optim.early_stopping:
                            evaluator.reset_metrics()
                            with ExitStack() as eval_stack:
                                # pylint:disable=no-member
                                eval_stack.enter_context(
                                    tqdm_unwrap_stdout())
                                eval_stack.enter_context(
                                    release_cuda_memory(
                                        collect_tensors(optimizer.state)))
                                # pylint:enable=no-member

                                vnll = evaluator()
                                vnll_metric.update(vnll)

                                # Save the updated metrics
                                self.save_metrics()

                                # The latest value equals the running minimum,
                                # so this checkpoint is a new best
                                if vnll == vnll_metric.min:
                                    self.on_new_best()

                            # Try to combat OOM errors caused by doing evaluation
                            # in the same loop with training. This manifests in out
                            # of memory errors after the first or second evaluation
                            # run.
                            refresh_cuda_memory()

                        if not self.prune_checkpoints():
                            logging.info("Stopping early")
                            break

                        # NOTE(review): as nested here, max_steps is only
                        # checked on save steps -- confirm this is intended
                        if self.step >= max_steps:
                            logging.info("Finished training")
                            break

            except RuntimeError as rte:
                if "out of memory" in str(rte):
                    # Out of memory is treated as recoverable: count it, log
                    # it, and carry on with the next batch
                    self.metric_store["oom"].update(1)
                    logging.warning(str(rte))
                else:
                    progress.close()
                    raise rte

            progress.set_description_str(get_description())

    progress.close()