def validate(args, trainer, task, epoch_itr, subsets):
    """Run validation on each subset and collect one checkpoint metric per subset.

    Args:
        args: options namespace; reads ``fixed_validation_seed``,
            ``log_format``, ``log_interval``, ``tensorboard_logdir``,
            ``no_progress_bar`` and ``best_checkpoint_metric``.
        trainer: provides ``get_valid_iterator``, ``valid_step`` and
            ``get_num_updates``.
        task: unused here; kept for interface compatibility.
        epoch_itr: only its ``epoch`` attribute is read (for the progress bar).
        subsets: iterable of validation subset names.

    Returns:
        A list with ``stats[args.best_checkpoint_metric]`` for every subset,
        in the order the subsets were given.
    """
    seed = args.fixed_validation_seed
    if seed is not None:
        # Re-seed before every validation run so results are reproducible.
        utils.set_torch_seed(seed)

    losses = []
    for subset in subsets:
        logger.info('begin validation on "{}" subset'.format(subset))

        # Unshuffled pass over the validation data for this subset.
        batches = trainer.get_valid_iterator(subset).next_epoch_itr(shuffle=False)
        bar_options = {
            "log_format": args.log_format,
            "log_interval": args.log_interval,
            "epoch": epoch_itr.epoch,
            "prefix": f"valid on '{subset}' subset",
            # Only the master process is given a tensorboard directory.
            "tensorboard_logdir": (
                args.tensorboard_logdir if distributed_utils.is_master(args) else None
            ),
            "default_log_format": "simple" if args.no_progress_bar else "tqdm",
        }
        progress = progress_bar.progress_bar(batches, **bar_options)

        # A fresh root aggregator keeps validation metrics separate from
        # any other active aggregators (e.g. the training meters).
        with metrics.aggregate(new_root=True) as aggregator:
            for batch in progress:
                trainer.valid_step(batch)

        # Summarise and report the statistics gathered for this subset.
        smoothed = aggregator.get_smoothed_values()
        stats = get_valid_stats(args, trainer, smoothed)
        progress.print(stats, tag=subset, step=trainer.get_num_updates())

        losses.append(stats[args.best_checkpoint_metric])

    return losses
def main(args, override_args=None):
    """Evaluate a saved checkpoint on every subset named in ``args.valid_subset``.

    The checkpoint at ``args.path`` is loaded together with its task and
    model arguments. For each comma-separated subset, ``task.valid_step`` is
    run on every batch; the per-batch logging outputs are gathered across
    workers when ``args.distributed_world_size > 1``, reduced with
    ``task.reduce_metrics`` and printed through the progress bar.

    Args:
        args: parsed options namespace. At least one of ``max_tokens`` and
            ``max_sentences`` must be set.
        override_args: optional namespace. A copy of its attributes, updated
            with the dict expression in its ``model_overrides`` attribute
            (default ``'{}'``), is passed as ``arg_overrides`` when loading
            the checkpoint.

    Raises:
        AssertionError: if neither ``max_tokens`` nor ``max_sentences`` is set.
        Exception: if loading a requested subset raises ``KeyError``.
    """
    utils.import_user_module(args)

    assert args.max_tokens is not None or args.max_sentences is not None, \
        'Must specify batch size either with --max-tokens or --max-sentences'

    use_fp16 = args.fp16
    use_cuda = torch.cuda.is_available() and not args.cpu

    if use_cuda:
        torch.cuda.set_device(args.device_id)

    if override_args is not None:
        # vars() returns the namespace's own __dict__; copy it so merging the
        # model overrides does not silently mutate the caller's namespace.
        overrides = dict(vars(override_args))
        # NOTE(review): eval() executes arbitrary Python from the
        # `model_overrides` string. If that string can come from an untrusted
        # source, switch to ast.literal_eval.
        overrides.update(eval(getattr(override_args, 'model_overrides', '{}')))
    else:
        overrides = None

    # Load ensemble
    logger.info('loading model(s) from {}'.format(args.path))
    models, model_args, task = checkpoint_utils.load_model_ensemble_and_task(
        [args.path],
        arg_overrides=overrides,
        suffix=getattr(args, "checkpoint_suffix", ""),
    )
    # Validation below is run with the first model only.
    model = models[0]

    # Move models to GPU. A separate loop name is used so that `model` keeps
    # referring to models[0] after the loop.
    for ensemble_member in models:
        if use_fp16:
            ensemble_member.half()
        if use_cuda:
            ensemble_member.cuda()

    # Print args
    logger.info(model_args)

    # Build criterion
    criterion = task.build_criterion(model_args)
    criterion.eval()

    for subset in args.valid_subset.split(','):
        try:
            task.load_dataset(subset, combine=False, epoch=1)
            dataset = task.dataset(subset)
        except KeyError as err:
            # Same exception type and message as before; the original
            # KeyError is chained for easier debugging.
            raise Exception('Cannot find dataset: ' + subset) from err

        # Initialize data iterator
        itr = task.get_batch_iterator(
            dataset=dataset,
            max_tokens=args.max_tokens,
            max_sentences=args.max_sentences,
            max_positions=utils.resolve_max_positions(
                task.max_positions(),
                *[m.max_positions() for m in models],
            ),
            ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
            required_batch_size_multiple=args.required_batch_size_multiple,
            seed=args.seed,
            num_shards=args.distributed_world_size,
            shard_id=args.distributed_rank,
            num_workers=args.num_workers,
        ).next_epoch_itr(shuffle=False)
        progress = progress_bar.progress_bar(
            itr,
            log_format=args.log_format,
            log_interval=args.log_interval,
            prefix=f"valid on '{subset}' subset",
            default_log_format=('tqdm' if not args.no_progress_bar else 'simple'),
        )

        log_outputs = []
        # Pre-bind the step so the summary print below still works (and does
        # not reuse a stale value from a previous subset) when this subset
        # yields no batches.
        step = 0
        for step, sample in enumerate(progress):
            sample = utils.move_to_cuda(sample) if use_cuda else sample
            _loss, _sample_size, log_output = task.valid_step(sample, model, criterion)
            progress.log(log_output, step=step)
            log_outputs.append(log_output)

        if args.distributed_world_size > 1:
            # Each worker contributes its own list; flatten the list of lists.
            log_outputs = distributed_utils.all_gather_list(
                log_outputs,
                max_size=getattr(args, 'all_gather_list_size', 16384),
            )
            log_outputs = list(chain.from_iterable(log_outputs))

        with metrics.aggregate() as agg:
            task.reduce_metrics(log_outputs, criterion)
            log_output = agg.get_smoothed_values()

        progress.print(log_output, tag=subset, step=step)
def _main(args, output_file):
    """Generate hypotheses for ``args.gen_subset`` and write them to ``output_file``.

    Loads the task, dataset and model ensemble, runs ``task.inference_step``
    on every batch and, unless ``args.quiet`` is set, prints tab-separated
    records per sentence: ``S-`` (source), ``T-`` (target), ``H-``
    (hypothesis), ``D-`` (detokenized hypothesis), ``P-`` (positional
    scores), and optionally ``A-`` (alignment), ``I-`` (steps) and ``E-``
    (iteration history). The top hypothesis of each sentence is fed to the
    scorer when a target is available.

    Args:
        args: parsed options namespace.
        output_file: stream used both for logging and for the printed records.

    Returns:
        The scorer built by ``scoring.build_scorer`` after all hypotheses
        have been added to it.
    """
    # Log records are sent to output_file via logging.basicConfig.
    logging.basicConfig(
        format='%(asctime)s | %(levelname)s | %(name)s | %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S',
        level=logging.INFO,
        stream=output_file,
    )
    logger = logging.getLogger('generate')

    utils.import_user_module(args)

    # Fall back to a token-based batch size when none was specified.
    if args.max_tokens is None and args.max_sentences is None:
        args.max_tokens = 12000
    logger.info(args)

    # Fix seed for stochastic decoding
    if args.seed is not None and not args.no_seed_provided:
        np.random.seed(args.seed)
        utils.set_torch_seed(args.seed)

    use_cuda = torch.cuda.is_available() and not args.cpu

    # Load dataset splits
    task = tasks.setup_task(args)
    task.load_dataset(args.gen_subset)

    # Set dictionaries
    # The attribute access is guarded because it may raise
    # NotImplementedError; in that case no source dictionary is used.
    try:
        src_dict = getattr(task, 'source_dictionary', None)
    except NotImplementedError:
        src_dict = None
    tgt_dict = task.target_dictionary

    # Load ensemble
    logger.info('loading model(s) from {}'.format(args.path))
    # NOTE(review): eval() executes arbitrary Python from args.model_overrides;
    # if that string can be untrusted, ast.literal_eval would be safer.
    models, _model_args = checkpoint_utils.load_model_ensemble(
        utils.split_paths(args.path),
        arg_overrides=eval(args.model_overrides),
        task=task,
        suffix=getattr(args, "checkpoint_suffix", ""),
    )

    # Optimize ensemble for generation
    for model in models:
        model.prepare_for_inference_(args)
        if args.fp16:
            model.half()
        if use_cuda:
            model.cuda()

    # Load alignment dictionary for unknown word replacement
    # (None if no unknown word replacement, empty if no path to align dictionary)
    align_dict = utils.load_align_dict(args.replace_unk)

    # Load dataset (possibly sharded)
    itr = task.get_batch_iterator(
        dataset=task.dataset(args.gen_subset),
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences,
        max_positions=utils.resolve_max_positions(
            task.max_positions(),
            *[model.max_positions() for model in models]
        ),
        ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=args.required_batch_size_multiple,
        num_shards=args.num_shards,
        shard_id=args.shard_id,
        num_workers=args.num_workers,
    ).next_epoch_itr(shuffle=False)
    progress = progress_bar.progress_bar(
        itr,
        log_format=args.log_format,
        log_interval=args.log_interval,
        default_log_format=('tqdm' if not args.no_progress_bar else 'none'),
    )

    # Initialize generator
    gen_timer = StopwatchMeter()
    generator = task.build_generator(models, args)

    # Handle tokenization and BPE
    tokenizer = encoders.build_tokenizer(args)
    bpe = encoders.build_bpe(args)

    def decode_fn(x):
        # Undo BPE first, then tokenization; either step is skipped when the
        # corresponding encoder was not built.
        if bpe is not None:
            x = bpe.decode(x)
        if tokenizer is not None:
            x = tokenizer.decode(x)
        return x

    scorer = scoring.build_scorer(args, tgt_dict)

    num_sentences = 0
    has_target = True
    wps_meter = TimeMeter()
    for sample in progress:
        sample = utils.move_to_cuda(sample) if use_cuda else sample
        # Batches without model input are skipped entirely.
        if 'net_input' not in sample:
            continue

        # Optionally force the first prefix_size target tokens of each hypothesis.
        prefix_tokens = None
        if args.prefix_size > 0:
            prefix_tokens = sample['target'][:, :args.prefix_size]

        constraints = None
        if "constraints" in sample:
            constraints = sample["constraints"]

        # Only the inference step itself is timed.
        gen_timer.start()
        hypos = task.inference_step(generator, models, sample, prefix_tokens=prefix_tokens, constraints=constraints)
        # Token count uses the first (top) hypothesis of each sentence.
        num_generated_tokens = sum(len(h[0]['tokens']) for h in hypos)
        gen_timer.stop(num_generated_tokens)

        for i, sample_id in enumerate(sample['id'].tolist()):
            has_target = sample['target'] is not None

            # Remove padding
            # NOTE(review): source tokens are stripped with the *target*
            # dictionary's pad index — presumably both dictionaries share it; confirm.
            if 'src_tokens' in sample['net_input']:
                src_tokens = utils.strip_pad(sample['net_input']['src_tokens'][i, :], tgt_dict.pad())
            else:
                src_tokens = None

            target_tokens = None
            if has_target:
                target_tokens = utils.strip_pad(sample['target'][i, :], tgt_dict.pad()).int().cpu()

            # Either retrieve the original sentences or regenerate them from tokens.
            if align_dict is not None:
                # NOTE(review): the target text is fetched here even when
                # has_target is False — confirm the dataset always has a
                # `tgt` side on this path.
                src_str = task.dataset(args.gen_subset).src.get_original_text(sample_id)
                target_str = task.dataset(args.gen_subset).tgt.get_original_text(sample_id)
            else:
                if src_dict is not None:
                    src_str = src_dict.string(src_tokens, args.remove_bpe)
                else:
                    src_str = ""
                if has_target:
                    target_str = tgt_dict.string(
                        target_tokens,
                        args.remove_bpe,
                        escape_unk=True,
                        extra_symbols_to_ignore=get_symbols_to_strip_from_output(generator),
                    )

            src_str = decode_fn(src_str)
            if has_target:
                target_str = decode_fn(target_str)

            if not args.quiet:
                if src_dict is not None:
                    print('S-{}\t{}'.format(sample_id, src_str), file=output_file)
                if has_target:
                    print('T-{}\t{}'.format(sample_id, target_str), file=output_file)

            # Process top predictions
            for j, hypo in enumerate(hypos[i][:args.nbest]):
                hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
                    hypo_tokens=hypo['tokens'].int().cpu(),
                    src_str=src_str,
                    alignment=hypo['alignment'],
                    align_dict=align_dict,
                    tgt_dict=tgt_dict,
                    remove_bpe=args.remove_bpe,
                    extra_symbols_to_ignore=get_symbols_to_strip_from_output(generator),
                )
                detok_hypo_str = decode_fn(hypo_str)
                if not args.quiet:
                    score = hypo['score'] / math.log(2)  # convert to base 2
                    # original hypothesis (after tokenization and BPE)
                    print('H-{}\t{}\t{}'.format(sample_id, score, hypo_str), file=output_file)
                    # detokenized hypothesis
                    print('D-{}\t{}\t{}'.format(sample_id, score, detok_hypo_str), file=output_file)
                    # NOTE(review): div_ looks like an in-place op (presumably
                    # a torch tensor), so hypo['positional_scores'] would stay
                    # in base 2 after this line — confirm nothing downstream
                    # expects base e.
                    print('P-{}\t{}'.format(
                        sample_id,
                        ' '.join(map(
                            lambda x: '{:.4f}'.format(x),
                            # convert from base e to base 2
                            hypo['positional_scores'].div_(math.log(2)).tolist(),
                        ))
                    ), file=output_file)

                    if args.print_alignment:
                        print('A-{}\t{}'.format(
                            sample_id,
                            ' '.join(['{}-{}'.format(src_idx, tgt_idx) for src_idx, tgt_idx in alignment])
                        ), file=output_file)

                    if args.print_step:
                        print('I-{}\t{}'.format(sample_id, hypo['steps']), file=output_file)

                    # Print each intermediate hypothesis kept in the history.
                    if getattr(args, 'retain_iter_history', False):
                        for step, h in enumerate(hypo['history']):
                            _, h_str, _ = utils.post_process_prediction(
                                hypo_tokens=h['tokens'].int().cpu(),
                                src_str=src_str,
                                alignment=None,
                                align_dict=None,
                                tgt_dict=tgt_dict,
                                remove_bpe=None,
                            )
                            print('E-{}_{}\t{}'.format(sample_id, step, h_str), file=output_file)

                # Score only the top hypothesis
                if has_target and j == 0:
                    if align_dict is not None or args.remove_bpe is not None:
                        # Convert back to tokens for evaluation with unk replacement and/or without BPE
                        # (add_if_not_exist=True is passed, so unseen words may be added to tgt_dict)
                        target_tokens = tgt_dict.encode_line(target_str, add_if_not_exist=True)
                        hypo_tokens = tgt_dict.encode_line(detok_hypo_str, add_if_not_exist=True)
                    # String-based scorers take the text; others take token ids.
                    if hasattr(scorer, 'add_string'):
                        scorer.add_string(target_str, detok_hypo_str)
                    else:
                        scorer.add(target_tokens, hypo_tokens)

        wps_meter.update(num_generated_tokens)
        progress.log({'wps': round(wps_meter.avg)})
        num_sentences += sample["nsentences"] if "nsentences" in sample else sample['id'].numel()

    logger.info('NOTE: hypothesis and token scores are output in base 2')
    # NOTE(review): divides by gen_timer.sum and gen_timer.avg — presumably
    # this raises ZeroDivisionError when no batch was timed; confirm.
    logger.info('Translated {} sentences ({} tokens) in {:.1f}s ({:.2f} sentences/s, {:.2f} tokens/s)'.format(
        num_sentences, gen_timer.n, gen_timer.sum, num_sentences / gen_timer.sum, 1. / gen_timer.avg))
    # has_target reflects the last sentence processed (True if none were).
    if has_target:
        if args.bpe and not args.sacrebleu:
            if args.remove_bpe:
                logger.warning("BLEU score is being computed by splitting detokenized string on spaces, this is probably not what you want. Use --sacrebleu for standard 13a BLEU tokenization")
            else:
                logger.warning("If you are using BPE on the target side, the BLEU score is computed on BPE tokens, not on proper words. Use --sacrebleu for standard 13a BLEU tokenization")
        # use print to be consistent with other main outputs: S-, H-, T-, D- and so on
        print(
            'Generate {} with beam={}: {}'.format(args.gen_subset, args.beam, scorer.result_string()),
            file=output_file)

    return scorer
def main(parsed_args, **unused_kwargs):
    """Score ``gen_subset`` with a language-model ensemble and log loss and perplexity.

    The checkpoint(s) at ``parsed_args.path`` (joined with ``os.pathsep``)
    are loaded, every batch is scored with ``SequenceScorer``, and the
    positional scores are summed into an average negative log-likelihood,
    logged in base 2 together with ``2 ** loss`` as the perplexity.
    Optionally logs per-word probabilities (``output_word_probs``) and
    aggregated word statistics (``output_word_stats``).

    Args:
        parsed_args: parsed options namespace; ``path`` must not be ``None``.
        **unused_kwargs: accepted and ignored.

    Raises:
        AssertionError: if ``parsed_args.path`` is ``None``, if no model was
            loaded, or if ``add_bos_token`` is set and a hypothesis does not
            start with the target dictionary's BOS index.
        NotImplementedError: if ``remove_bpe`` is ``'sentencepiece'``.
    """
    assert parsed_args.path is not None, '--path required for evaluation!'

    if torch.cuda.is_available() and not parsed_args.cpu:
        torch.cuda.set_device(parsed_args.device_id)

    utils.import_user_module(parsed_args)

    logger.info(parsed_args)

    use_cuda = torch.cuda.is_available() and not parsed_args.cpu

    task = tasks.setup_task(parsed_args)

    # Load ensemble
    logger.info('loading model(s) from {}'.format(parsed_args.path))
    # NOTE(review): eval() executes arbitrary Python from
    # parsed_args.model_overrides; if that string can be untrusted,
    # ast.literal_eval would be safer.
    models, args = checkpoint_utils.load_model_ensemble(
        parsed_args.path.split(os.pathsep),
        arg_overrides=eval(parsed_args.model_overrides),
        task=task,
        suffix=getattr(parsed_args, "checkpoint_suffix", ""),
    )

    # Copy every command-line option onto the checkpoint's args, except the
    # ones listed here, which keep the value stored with the checkpoint.
    keep_from_checkpoint = {
        'self_target', 'future_target', 'past_target', 'tokens_per_sample',
        'output_size_dictionary', 'add_bos_token',
    }
    for arg in vars(parsed_args):
        if arg not in keep_from_checkpoint:
            setattr(args, arg, getattr(parsed_args, arg))

    # reduce tokens per sample by the required context window size
    args.tokens_per_sample -= args.context_window
    task = tasks.setup_task(args)

    # Load dataset splits
    task.load_dataset(args.gen_subset)
    dataset = task.dataset(args.gen_subset)
    if args.context_window > 0:
        dataset = LMContextWindowDataset(
            dataset=dataset,
            tokens_per_sample=args.tokens_per_sample,
            context_window=args.context_window,
            pad_idx=task.source_dictionary.pad(),
        )
    logger.info('{} {} {} examples'.format(args.data, args.gen_subset, len(dataset)))

    # Optimize ensemble for generation and set the source and dest dicts on the model (required by scorer)
    for model in models:
        model.prepare_for_inference_(args)
        if args.fp16:
            model.half()
        if use_cuda:
            model.cuda()

    assert len(models) > 0

    logger.info('num. model params: {}'.format(sum(p.numel() for p in models[0].parameters())))

    itr = task.get_batch_iterator(
        dataset=dataset,
        max_tokens=args.max_tokens or 36000,
        max_sentences=args.max_sentences,
        max_positions=utils.resolve_max_positions(*[
            model.max_positions() for model in models
        ]),
        ignore_invalid_inputs=True,
        num_shards=args.num_shards,
        shard_id=args.shard_id,
        num_workers=args.num_workers,
    ).next_epoch_itr(shuffle=False)
    progress = progress_bar.progress_bar(
        itr,
        log_format=args.log_format,
        log_interval=args.log_interval,
        default_log_format=('tqdm' if not args.no_progress_bar else 'none'),
    )

    gen_timer = StopwatchMeter()
    scorer = SequenceScorer(task.target_dictionary, args.softmax_batch)

    score_sum = 0.
    count = 0

    if args.remove_bpe is not None:
        if args.remove_bpe == 'sentencepiece':
            raise NotImplementedError
        # Indices of dictionary symbols that end with the BPE continuation
        # marker, i.e. word pieces that are continued by the next token.
        bpe_cont = args.remove_bpe.rstrip()
        bpe_toks = {
            idx
            for idx in range(len(task.source_dictionary))
            if task.source_dictionary[idx].endswith(bpe_cont)
        }
        bpe_len = len(bpe_cont)
    else:
        bpe_toks = None
        bpe_len = 0

    word_stats = {}

    wps_meter = TimeMeter()

    for sample in progress:
        if 'net_input' not in sample:
            continue

        sample = utils.move_to_cuda(sample) if use_cuda else sample

        gen_timer.start()
        hypos = scorer.generate(models, sample)
        gen_timer.stop(sample['ntokens'])

        for i, hypos_i in enumerate(hypos):
            hypo = hypos_i[0]
            sample_id = sample['id'][i]

            tokens = hypo['tokens']
            pos_scores = hypo['positional_scores'].float()

            if getattr(args, 'add_bos_token', False):
                assert hypo['tokens'][0].item() == task.target_dictionary.bos()
                tokens = tokens[1:]
                pos_scores = pos_scores[1:]

            # Measure the length AFTER the optional BOS strip. Using the
            # pre-strip length let the loop below reach the last token and
            # index pos_scores one past its end when that token was a BPE
            # continuation piece.
            tgt_len = tokens.numel()

            skipped_toks = 0
            if bpe_toks is not None:
                # Fold the score of each continuation piece into the next
                # position so a multi-piece word carries one combined score.
                for pos in range(tgt_len - 1):
                    if tokens[pos].item() in bpe_toks:
                        skipped_toks += 1
                        pos_scores[pos + 1] += pos_scores[pos]
                        pos_scores[pos] = 0

            inf_scores = pos_scores.eq(float('inf')) | pos_scores.eq(float('-inf'))
            if inf_scores.any():
                # The original call passed the token string as a %-format
                # argument without a placeholder, so logging could not format
                # the record and the message was lost.
                logger.info(
                    'skipping tokens with inf scores: %s',
                    task.target_dictionary.string(tokens[inf_scores.nonzero()])
                )
                pos_scores = pos_scores[(~inf_scores).nonzero()]
            score_sum += pos_scores.sum().cpu()
            count += pos_scores.numel() - skipped_toks

            if args.output_word_probs or args.output_word_stats:
                w = ''
                word_prob = []
                is_bpe = False
                for pos in range(len(tokens)):
                    w_ind = tokens[pos].item()
                    w += task.source_dictionary[w_ind]
                    if bpe_toks is not None and w_ind in bpe_toks:
                        # Drop the continuation marker and keep accumulating
                        # pieces of the current word.
                        w = w[:-bpe_len]
                        is_bpe = True
                    else:
                        word_prob.append((w, pos_scores[pos].item()))

                        # Score of the next position with a non-zero score
                        # (zeros mark folded continuation pieces), if any.
                        next_prob = None
                        ind = pos + 1
                        while ind < len(tokens):
                            if pos_scores[ind].item() != 0:
                                next_prob = pos_scores[ind]
                                break
                            ind += 1

                        word_stats.setdefault(w, WordStat(w, is_bpe)).add(pos_scores[pos].item(), next_prob)
                        is_bpe = False
                        w = ''
                if args.output_word_probs:
                    logger.info(
                        str(int(sample_id)) + " "
                        + ('\t'.join('{} [{:2f}]'.format(x[0], x[1]) for x in word_prob))
                    )

        wps_meter.update(sample['ntokens'])
        progress.log({'wps': round(wps_meter.avg)})

    # NOTE(review): `count` and `gen_timer.avg` are presumably zero when no
    # batch was scored, which would make the divisions below fail — confirm
    # whether an empty subset needs explicit handling.
    avg_nll_loss = -score_sum / count / math.log(2)  # convert to base 2
    logger.info('Evaluated {} tokens in {:.1f}s ({:.2f} tokens/s)'.format(
        gen_timer.n, gen_timer.sum, 1. / gen_timer.avg
    ))
    logger.info('Loss (base 2): {:.4f}, Perplexity: {:.2f}'.format(
        avg_nll_loss, 2**avg_nll_loss
    ))

    if args.output_word_stats:
        # Most frequent words first.
        for ws in sorted(word_stats.values(), key=lambda x: x.count, reverse=True):
            logger.info(ws)
def train(args, trainer, task, epoch_itr):
    """Train the model for one epoch.

    Args:
        args: options namespace; reads ``fix_batches_to_gpus``,
            ``curriculum``, ``update_freq``, ``log_format``,
            ``log_interval``, ``tensorboard_logdir``, ``no_progress_bar``
            and ``valid_subset``.
        trainer: provides ``begin_epoch``, ``train_step`` and
            ``get_num_updates``.
        task: passed through to ``validate_and_save``.
        epoch_itr: epoch batch iterator for the training data.

    Returns:
        A ``(valid_losses, should_stop)`` tuple holding the values returned
        by the last ``validate_and_save`` call of the epoch. If the epoch
        yields no batches, ``valid_losses`` is ``[None]`` and ``should_stop``
        is ``False``.
    """
    # Initialize data iterator
    itr = epoch_itr.next_epoch_itr(
        fix_batches_to_gpus=args.fix_batches_to_gpus,
        shuffle=(epoch_itr.next_epoch_idx > args.curriculum),
    )
    # Per-epoch update frequency; epochs beyond the end of the list reuse
    # its last entry.
    update_freq = (
        args.update_freq[epoch_itr.epoch - 1]
        if epoch_itr.epoch <= len(args.update_freq)
        else args.update_freq[-1]
    )
    itr = iterators.GroupedIterator(itr, update_freq)
    progress = progress_bar.progress_bar(
        itr,
        log_format=args.log_format,
        log_interval=args.log_interval,
        epoch=epoch_itr.epoch,
        tensorboard_logdir=(
            args.tensorboard_logdir if distributed_utils.is_master(args) else None
        ),
        default_log_format=("tqdm" if not args.no_progress_bar else "simple"),
    )

    trainer.begin_epoch(epoch_itr.epoch)

    valid_subsets = args.valid_subset.split(",")
    should_stop = False
    # Pre-bind the result so the final `return` cannot hit an unbound local
    # when the iterator yields no batches. A one-element placeholder is used
    # (rather than an empty list) on the assumption that callers index the
    # first loss — NOTE(review): confirm against the caller.
    valid_losses = [None]
    num_updates = trainer.get_num_updates()
    for i, samples in enumerate(progress):
        with metrics.aggregate("train_inner"), torch.autograd.profiler.record_function(
            "train_step-%d" % i
        ):
            log_output = trainer.train_step(samples)

        if log_output is not None:  # not OOM, overflow, ...
            # log mid-epoch stats
            num_updates = trainer.get_num_updates()
            if num_updates % args.log_interval == 0:
                stats = get_training_stats(metrics.get_smoothed_values("train_inner"))
                progress.log(stats, tag="train_inner", step=num_updates)

                # reset mid-epoch stats after each log interval
                # the end-of-epoch stats will still be preserved
                metrics.reset_meters("train_inner")

        end_of_epoch = not itr.has_next()
        valid_losses, should_stop = validate_and_save(
            args, trainer, task, epoch_itr, valid_subsets, end_of_epoch
        )

        if should_stop:
            break

    # log end-of-epoch stats
    logger.info("end of epoch {} (average epoch stats below)".format(epoch_itr.epoch))
    stats = get_training_stats(metrics.get_smoothed_values("train"))
    progress.print(stats, tag="train", step=num_updates)

    # reset epoch-level meters
    metrics.reset_meters("train")
    return valid_losses, should_stop