def from_pretrained(
    model_name_or_path,
    checkpoint_file='model.pt',
    data_name_or_path='.',
    archive_map=None,
    **kwargs
):
    """Load a checkpoint ensemble together with its args and task.

    Args:
        model_name_or_path: model directory/archive, or a key into
            ``archive_map``. An ``archive_map`` value may be a dict holding a
            ``'path'`` entry plus default arg overrides (and optionally a
            ``'checkpoint_file'`` entry).
        checkpoint_file: checkpoint file name(s) inside the model directory;
            several names joined with ``os.pathsep`` load an ensemble.
        data_name_or_path: data directory. A value starting with ``'.'`` is
            resolved relative to the model directory, any other value through
            ``file_utils.load_archive_file``. ``None`` leaves the ``data``
            override unset.
        archive_map: optional mapping from short names to paths (or dicts as
            described above).
        **kwargs: arg overrides forwarded as ``arg_overrides`` to
            ``checkpoint_utils.load_model_ensemble_and_task``. Explicit kwargs
            take precedence over defaults coming from ``archive_map``.

    Returns:
        dict with keys ``'args'``, ``'task'`` and ``'models'``.
    """
    from tools import checkpoint_utils, file_utils

    if archive_map is not None:
        if model_name_or_path in archive_map:
            model_name_or_path = archive_map[model_name_or_path]
        if data_name_or_path is not None and data_name_or_path in archive_map:
            data_name_or_path = archive_map[data_name_or_path]

        # allow archive_map to set default arg_overrides (e.g., tokenizer, bpe)
        # for each model
        if isinstance(model_name_or_path, dict):
            for k, v in model_name_or_path.items():
                if k == 'checkpoint_file':
                    checkpoint_file = v
                elif (
                    k != 'path'
                    # only set kwargs that don't already have overrides
                    and k not in kwargs
                ):
                    kwargs[k] = v
            model_name_or_path = model_name_or_path['path']

    model_path = file_utils.load_archive_file(model_name_or_path)

    # Convenience hack for loading data and BPE codes from the model archive.
    # ``None`` is accepted by the archive_map lookup above, so it must not
    # reach ``.startswith`` (that raised AttributeError).
    if data_name_or_path is not None:
        if data_name_or_path.startswith('.'):
            kwargs['data'] = os.path.abspath(
                os.path.join(model_path, data_name_or_path)
            )
        else:
            kwargs['data'] = file_utils.load_archive_file(data_name_or_path)

    # Point BPE/sentencepiece options at model files shipped in the archive.
    for file, arg in {
        'code': 'bpe_codes',
        'bpecodes': 'bpe_codes',
        'sentencepiece.bpe.model': 'sentencepiece_model',
    }.items():
        path = os.path.join(model_path, file)
        if os.path.exists(path):
            kwargs[arg] = path

    if 'user_dir' in kwargs:
        utils.import_user_module(argparse.Namespace(user_dir=kwargs['user_dir']))

    models, args, task = checkpoint_utils.load_model_ensemble_and_task(
        [
            os.path.join(model_path, cpt)
            for cpt in checkpoint_file.split(os.pathsep)
        ],
        arg_overrides=kwargs,
    )

    return {
        'args': args,
        'task': task,
        'models': models,
    }
def main(args, override_args=None):
    """Evaluate a checkpoint's criterion on each subset in ``args.valid_subset``.

    Args:
        args: parsed command-line namespace (checkpoint path, batching,
            device, logging and distributed options are read from it).
        override_args: optional namespace whose attributes, plus the dict
            literal in its ``model_overrides`` string, override the
            checkpoint's stored args. The namespace itself is not modified.

    Raises:
        Exception: if a validation subset cannot be found.
    """
    utils.import_user_module(args)

    assert args.max_tokens is not None or args.max_sentences is not None, \
        'Must specify batch size either with --max-tokens or --max-sentences'

    use_fp16 = args.fp16
    use_cuda = torch.cuda.is_available() and not args.cpu

    if use_cuda:
        torch.cuda.set_device(args.device_id)

    if override_args is not None:
        # Copy: vars() returns the namespace's own __dict__, so updating it in
        # place would add the model overrides as attributes on the caller's
        # override_args object.
        overrides = dict(vars(override_args))
        # NOTE(review): eval() executes arbitrary code taken from the command
        # line; ast.literal_eval would suffice if this is always a dict literal.
        overrides.update(eval(getattr(override_args, 'model_overrides', '{}')))
    else:
        overrides = None

    # Load ensemble
    logger.info('loading model(s) from {}'.format(args.path))
    models, model_args, task = checkpoint_utils.load_model_ensemble_and_task(
        [args.path],
        arg_overrides=overrides,
        suffix=getattr(args, "checkpoint_suffix", ""),
    )
    # Only the first model is evaluated below.
    model = models[0]

    # Move models to GPU. A separate loop variable keeps ``model`` bound to
    # models[0] instead of leaving it on the last model of the ensemble.
    for m in models:
        if use_fp16:
            m.half()
        if use_cuda:
            m.cuda()

    # Print args
    logger.info(model_args)

    # Build criterion
    criterion = task.build_criterion(model_args)
    criterion.eval()

    for subset in args.valid_subset.split(','):
        try:
            task.load_dataset(subset, combine=False, epoch=1)
            dataset = task.dataset(subset)
        except KeyError as err:
            raise Exception('Cannot find dataset: ' + subset) from err

        # Initialize data iterator
        itr = task.get_batch_iterator(
            dataset=dataset,
            max_tokens=args.max_tokens,
            max_sentences=args.max_sentences,
            max_positions=utils.resolve_max_positions(
                task.max_positions(),
                *[m.max_positions() for m in models],
            ),
            ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
            required_batch_size_multiple=args.required_batch_size_multiple,
            seed=args.seed,
            num_shards=args.distributed_world_size,
            shard_id=args.distributed_rank,
            num_workers=args.num_workers,
        ).next_epoch_itr(shuffle=False)
        progress = progress_bar.progress_bar(
            itr,
            log_format=args.log_format,
            log_interval=args.log_interval,
            prefix=f"valid on '{subset}' subset",
            default_log_format=('tqdm' if not args.no_progress_bar else 'simple'),
        )

        log_outputs = []
        # Step passed to progress.print below. Reset per subset so an empty
        # subset neither raises NameError nor reuses the previous subset's step.
        i = 0
        for i, sample in enumerate(progress):
            sample = utils.move_to_cuda(sample) if use_cuda else sample
            _loss, _sample_size, log_output = task.valid_step(sample, model, criterion)
            progress.log(log_output, step=i)
            log_outputs.append(log_output)

        # Gather per-worker outputs so metrics are reduced over all shards.
        if args.distributed_world_size > 1:
            log_outputs = distributed_utils.all_gather_list(
                log_outputs,
                max_size=getattr(args, 'all_gather_list_size', 16384),
            )
            log_outputs = list(chain.from_iterable(log_outputs))

        with metrics.aggregate() as agg:
            task.reduce_metrics(log_outputs, criterion)
            log_output = agg.get_smoothed_values()

        progress.print(log_output, tag=subset, step=i)
def _main(args, output_file):
    """Generate hypotheses for ``args.gen_subset`` and score them.

    Configures ``logging.basicConfig`` to write to ``output_file``. Unless
    ``args.quiet`` is set, writes S-/T-/H-/D-/P- lines (and optionally
    A-/I-/E- lines) per sentence to ``output_file``. When targets are
    available, also writes the final score line there.

    Args:
        args: parsed generation options.
        output_file: text stream that receives the log and the outputs.

    Returns:
        The scorer from ``scoring.build_scorer``, after the top hypothesis
        of every sentence with a target has been added to it.
    """
    logging.basicConfig(
        format='%(asctime)s | %(levelname)s | %(name)s | %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S',
        level=logging.INFO,
        stream=output_file,
    )
    logger = logging.getLogger('generate')

    utils.import_user_module(args)

    # Default batch size when neither limit was given.
    if args.max_tokens is None and args.max_sentences is None:
        args.max_tokens = 12000
    logger.info(args)

    # Fix seed for stochastic decoding
    if args.seed is not None and not args.no_seed_provided:
        np.random.seed(args.seed)
        utils.set_torch_seed(args.seed)

    use_cuda = torch.cuda.is_available() and not args.cpu

    # Load dataset splits
    task = tasks.setup_task(args)
    task.load_dataset(args.gen_subset)

    # Set dictionaries
    try:
        src_dict = getattr(task, 'source_dictionary', None)
    except NotImplementedError:
        src_dict = None
    tgt_dict = task.target_dictionary

    # Load ensemble
    logger.info('loading model(s) from {}'.format(args.path))
    models, _model_args = checkpoint_utils.load_model_ensemble(
        utils.split_paths(args.path),
        # NOTE(review): eval() runs arbitrary code from the command line;
        # ast.literal_eval would suffice for dict-literal overrides.
        arg_overrides=eval(args.model_overrides),
        task=task,
        suffix=getattr(args, "checkpoint_suffix", ""),
    )

    # Optimize ensemble for generation
    for model in models:
        model.prepare_for_inference_(args)
        if args.fp16:
            model.half()
        if use_cuda:
            model.cuda()

    # Load alignment dictionary for unknown word replacement
    # (None if no unknown word replacement, empty if no path to align dictionary)
    align_dict = utils.load_align_dict(args.replace_unk)

    # Load dataset (possibly sharded)
    itr = task.get_batch_iterator(
        dataset=task.dataset(args.gen_subset),
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences,
        max_positions=utils.resolve_max_positions(
            task.max_positions(),
            *[model.max_positions() for model in models]
        ),
        ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=args.required_batch_size_multiple,
        num_shards=args.num_shards,
        shard_id=args.shard_id,
        num_workers=args.num_workers,
    ).next_epoch_itr(shuffle=False)
    progress = progress_bar.progress_bar(
        itr,
        log_format=args.log_format,
        log_interval=args.log_interval,
        default_log_format=('tqdm' if not args.no_progress_bar else 'none'),
    )

    # Initialize generator
    gen_timer = StopwatchMeter()
    generator = task.build_generator(models, args)

    # Handle tokenization and BPE
    tokenizer = encoders.build_tokenizer(args)
    bpe = encoders.build_bpe(args)

    def decode_fn(x):
        # Undo BPE first, then the tokenizer.
        if bpe is not None:
            x = bpe.decode(x)
        if tokenizer is not None:
            x = tokenizer.decode(x)
        return x

    scorer = scoring.build_scorer(args, tgt_dict)

    num_sentences = 0
    has_target = True
    wps_meter = TimeMeter()
    for sample in progress:
        sample = utils.move_to_cuda(sample) if use_cuda else sample
        if 'net_input' not in sample:
            continue

        prefix_tokens = None
        if args.prefix_size > 0:
            prefix_tokens = sample['target'][:, :args.prefix_size]

        constraints = None
        if "constraints" in sample:
            constraints = sample["constraints"]

        gen_timer.start()
        hypos = task.inference_step(generator, models, sample, prefix_tokens=prefix_tokens, constraints=constraints)
        num_generated_tokens = sum(len(h[0]['tokens']) for h in hypos)
        gen_timer.stop(num_generated_tokens)

        for i, sample_id in enumerate(sample['id'].tolist()):
            has_target = sample['target'] is not None

            # Remove padding
            # NOTE(review): source tokens are stripped with tgt_dict's pad index;
            # this assumes source and target share the pad index.
            if 'src_tokens' in sample['net_input']:
                src_tokens = utils.strip_pad(sample['net_input']['src_tokens'][i, :], tgt_dict.pad())
            else:
                src_tokens = None

            target_tokens = None
            if has_target:
                target_tokens = utils.strip_pad(sample['target'][i, :], tgt_dict.pad()).int().cpu()

            # Either retrieve the original sentences or regenerate them from tokens.
            if align_dict is not None:
                src_str = task.dataset(args.gen_subset).src.get_original_text(sample_id)
                target_str = task.dataset(args.gen_subset).tgt.get_original_text(sample_id)
            else:
                if src_dict is not None:
                    src_str = src_dict.string(src_tokens, args.remove_bpe)
                else:
                    src_str = ""
                if has_target:
                    target_str = tgt_dict.string(
                        target_tokens,
                        args.remove_bpe,
                        escape_unk=True,
                        extra_symbols_to_ignore=get_symbols_to_strip_from_output(generator),
                    )

            src_str = decode_fn(src_str)
            if has_target:
                target_str = decode_fn(target_str)

            if not args.quiet:
                if src_dict is not None:
                    print('S-{}\t{}'.format(sample_id, src_str), file=output_file)
                if has_target:
                    print('T-{}\t{}'.format(sample_id, target_str), file=output_file)

            # Process top predictions
            for j, hypo in enumerate(hypos[i][:args.nbest]):
                hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
                    hypo_tokens=hypo['tokens'].int().cpu(),
                    src_str=src_str,
                    alignment=hypo['alignment'],
                    align_dict=align_dict,
                    tgt_dict=tgt_dict,
                    remove_bpe=args.remove_bpe,
                    extra_symbols_to_ignore=get_symbols_to_strip_from_output(generator),
                )
                detok_hypo_str = decode_fn(hypo_str)
                if not args.quiet:
                    score = hypo['score'] / math.log(2)  # convert to base 2
                    # original hypothesis (after tokenization and BPE)
                    print('H-{}\t{}\t{}'.format(sample_id, score, hypo_str), file=output_file)
                    # detokenized hypothesis
                    print('D-{}\t{}\t{}'.format(sample_id, score, detok_hypo_str), file=output_file)
                    # NOTE: div_ rescales hypo['positional_scores'] in place.
                    print('P-{}\t{}'.format(
                        sample_id,
                        ' '.join(map(
                            lambda x: '{:.4f}'.format(x),
                            # convert from base e to base 2
                            hypo['positional_scores'].div_(math.log(2)).tolist(),
                        ))
                    ), file=output_file)

                    if args.print_alignment:
                        print('A-{}\t{}'.format(
                            sample_id,
                            ' '.join(['{}-{}'.format(src_idx, tgt_idx) for src_idx, tgt_idx in alignment])
                        ), file=output_file)

                    if args.print_step:
                        print('I-{}\t{}'.format(sample_id, hypo['steps']), file=output_file)

                    if getattr(args, 'retain_iter_history', False):
                        for step, h in enumerate(hypo['history']):
                            _, h_str, _ = utils.post_process_prediction(
                                hypo_tokens=h['tokens'].int().cpu(),
                                src_str=src_str,
                                alignment=None,
                                align_dict=None,
                                tgt_dict=tgt_dict,
                                remove_bpe=None,
                            )
                            print('E-{}_{}\t{}'.format(sample_id, step, h_str), file=output_file)

                # Score only the top hypothesis
                if has_target and j == 0:
                    if align_dict is not None or args.remove_bpe is not None:
                        # Convert back to tokens for evaluation with unk replacement and/or without BPE
                        target_tokens = tgt_dict.encode_line(target_str, add_if_not_exist=True)
                        hypo_tokens = tgt_dict.encode_line(detok_hypo_str, add_if_not_exist=True)
                    if hasattr(scorer, 'add_string'):
                        scorer.add_string(target_str, detok_hypo_str)
                    else:
                        scorer.add(target_tokens, hypo_tokens)

        wps_meter.update(num_generated_tokens)
        progress.log({'wps': round(wps_meter.avg)})
        num_sentences += sample["nsentences"] if "nsentences" in sample else sample['id'].numel()

    logger.info('NOTE: hypothesis and token scores are output in base 2')
    # Guard the throughput figures: when nothing was generated (empty subset,
    # or no batch had 'net_input') the timer totals can be zero, and the
    # unguarded divisions raised ZeroDivisionError after a completed run.
    elapsed = gen_timer.sum
    sentences_per_sec = num_sentences / elapsed if elapsed > 0 else 0.0
    tokens_per_sec = 1. / gen_timer.avg if gen_timer.avg > 0 else 0.0
    logger.info('Translated {} sentences ({} tokens) in {:.1f}s ({:.2f} sentences/s, {:.2f} tokens/s)'.format(
        num_sentences, gen_timer.n, elapsed, sentences_per_sec, tokens_per_sec))
    if has_target:
        if args.bpe and not args.sacrebleu:
            if args.remove_bpe:
                logger.warning("BLEU score is being computed by splitting detokenized string on spaces, this is probably not what you want. Use --sacrebleu for standard 13a BLEU tokenization")
            else:
                logger.warning("If you are using BPE on the target side, the BLEU score is computed on BPE tokens, not on proper words. Use --sacrebleu for standard 13a BLEU tokenization")
        # use print to be consistent with other main outputs: S-, H-, T-, D- and so on
        print(
            'Generate {} with beam={}: {}'.format(args.gen_subset, args.beam, scorer.result_string()),
            file=output_file)

    return scorer
def main(args):
    """Translate sentences read from ``args.input`` in buffered chunks.

    Reads ``args.buffer_size`` sentences at a time with ``buffered_read``,
    batches and translates them, and prints S-/W-/C- (when a source
    dictionary exists), H-/D-/P- and optionally A- lines to stdout, in input
    order. Total wall time and translation time are logged at the end.
    """
    start_time = time.time()
    total_translate_time = 0

    utils.import_user_module(args)

    if args.buffer_size < 1:
        args.buffer_size = 1
    if args.max_tokens is None and args.max_sentences is None:
        args.max_sentences = 1

    assert not args.sampling or args.nbest == args.beam, \
        '--sampling requires --nbest to be equal to --beam'
    assert not args.max_sentences or args.max_sentences <= args.buffer_size, \
        '--max-sentences/--batch-size cannot be larger than --buffer-size'

    logger.info(args)

    # Fix seed for stochastic decoding
    if args.seed is not None and not args.no_seed_provided:
        np.random.seed(args.seed)
        utils.set_torch_seed(args.seed)

    use_cuda = torch.cuda.is_available() and not args.cpu

    # Setup task, e.g., translation
    task = tasks.setup_task(args)

    # Load ensemble
    logger.info('loading model(s) from {}'.format(args.path))
    models, _model_args = checkpoint_utils.load_model_ensemble(
        args.path.split(os.pathsep),
        # NOTE(review): eval() runs arbitrary code from the command line;
        # ast.literal_eval would suffice for dict-literal overrides.
        arg_overrides=eval(args.model_overrides),
        task=task,
        suffix=getattr(args, "checkpoint_suffix", ""),
    )

    # Set dictionaries
    src_dict = task.source_dictionary
    tgt_dict = task.target_dictionary

    # Optimize ensemble for generation
    for model in models:
        model.prepare_for_inference_(args)
        if args.fp16:
            model.half()
        if use_cuda:
            model.cuda()

    # Initialize generator
    generator = task.build_generator(models, args)

    # Handle tokenization and BPE
    tokenizer = encoders.build_tokenizer(args)
    bpe = encoders.build_bpe(args)

    def encode_fn(x):
        # Tokenize first, then apply BPE.
        if tokenizer is not None:
            x = tokenizer.encode(x)
        if bpe is not None:
            x = bpe.encode(x)
        return x

    def decode_fn(x):
        # Inverse order of encode_fn: undo BPE, then detokenize.
        if bpe is not None:
            x = bpe.decode(x)
        if tokenizer is not None:
            x = tokenizer.decode(x)
        return x

    # Load alignment dictionary for unknown word replacement
    # (None if no unknown word replacement, empty if no path to align dictionary)
    align_dict = utils.load_align_dict(args.replace_unk)

    max_positions = utils.resolve_max_positions(
        task.max_positions(),
        *[model.max_positions() for model in models]
    )

    if args.constraints:
        logger.warning("NOTE: Constrained decoding currently assumes a shared subword vocabulary.")

    if args.buffer_size > 1:
        logger.info('Sentence buffer size: %s', args.buffer_size)
    logger.info('NOTE: hypothesis and token scores are output in base 2')
    logger.info('Type the input sentence and press return:')
    start_id = 0
    for inputs in buffered_read(args.input, args.buffer_size):
        results = []
        for batch in make_batches(inputs, args, task, max_positions, encode_fn):
            bsz = batch.src_tokens.size(0)
            src_tokens = batch.src_tokens
            src_lengths = batch.src_lengths
            constraints = batch.constraints
            if use_cuda:
                src_tokens = src_tokens.cuda()
                src_lengths = src_lengths.cuda()
                if constraints is not None:
                    constraints = constraints.cuda()

            sample = {
                'net_input': {
                    'src_tokens': src_tokens,
                    'src_lengths': src_lengths,
                },
            }
            translate_start_time = time.time()
            translations = task.inference_step(generator, models, sample, constraints=constraints)
            translate_time = time.time() - translate_start_time
            total_translate_time += translate_time

            # Per-sentence constraint lists (empty when constraints are off).
            list_constraints = [[] for _ in range(bsz)]
            if args.constraints:
                list_constraints = [unpack_constraints(c) for c in constraints]
            for i, (batch_id, hypos) in enumerate(zip(batch.ids.tolist(), translations)):
                src_tokens_i = utils.strip_pad(src_tokens[i], tgt_dict.pad())
                results.append((start_id + batch_id, src_tokens_i, hypos, {
                    "constraints": list_constraints[i],
                    "time": translate_time / len(translations)
                }))

        # sort output to match input order
        for id_, src_tokens_i, hypos, info in sorted(results, key=lambda x: x[0]):
            # Default keeps src_str defined when the task has no source
            # dictionary (it is still passed to post_process_prediction).
            src_str = ""
            if src_dict is not None:
                src_str = src_dict.string(src_tokens_i, args.remove_bpe)
                print('S-{}\t{}'.format(id_, src_str))
                print("W-{}\t{:.3f}\tseconds".format(id_, info["time"]))
                for constraint in info["constraints"]:
                    print("C-{}\t{}".format(id_, tgt_dict.string(constraint, args.remove_bpe)))

            # Process top predictions
            for hypo in hypos[:args.nbest]:
                hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
                    hypo_tokens=hypo['tokens'].int().cpu(),
                    src_str=src_str,
                    alignment=hypo['alignment'],
                    align_dict=align_dict,
                    tgt_dict=tgt_dict,
                    remove_bpe=args.remove_bpe,
                    extra_symbols_to_ignore=get_symbols_to_strip_from_output(generator),
                )
                detok_hypo_str = decode_fn(hypo_str)
                score = hypo['score'] / math.log(2)  # convert to base 2
                # original hypothesis (after tokenization and BPE)
                print('H-{}\t{}\t{}'.format(id_, score, hypo_str))
                # detokenized hypothesis
                print('D-{}\t{}\t{}'.format(id_, score, detok_hypo_str))
                print('P-{}\t{}'.format(
                    id_,
                    ' '.join(map(
                        lambda x: '{:.4f}'.format(x),
                        # convert from base e to base 2
                        hypo['positional_scores'].div_(math.log(2)).tolist(),
                    ))
                ))
                if args.print_alignment:
                    alignment_str = " ".join(["{}-{}".format(src, tgt) for src, tgt in alignment])
                    print('A-{}\t{}'.format(id_, alignment_str))

        # update running id_ counter
        start_id += len(inputs)

    logger.info("Total time: {:.3f} seconds; translation time: {:.3f}".format(
        time.time() - start_time, total_translate_time))
def main(parsed_args, **unused_kwargs):
    """Score ``gen_subset`` with a language-model ensemble.

    Args loaded from the checkpoint are overridden by ``parsed_args``, except
    for a fixed set of LM-shape keys. Logs the average negative
    log-likelihood (base 2) and perplexity, and optionally per-word
    probabilities and word statistics.

    Args:
        parsed_args: parsed evaluation options; ``path`` is required.
        **unused_kwargs: accepted and ignored.
    """
    assert parsed_args.path is not None, '--path required for evaluation!'

    if torch.cuda.is_available() and not parsed_args.cpu:
        torch.cuda.set_device(parsed_args.device_id)

    utils.import_user_module(parsed_args)

    logger.info(parsed_args)

    use_cuda = torch.cuda.is_available() and not parsed_args.cpu

    task = tasks.setup_task(parsed_args)

    # Load ensemble
    logger.info('loading model(s) from {}'.format(parsed_args.path))
    models, args = checkpoint_utils.load_model_ensemble(
        parsed_args.path.split(os.pathsep),
        # NOTE(review): eval() runs arbitrary code from the command line;
        # ast.literal_eval would suffice for dict-literal overrides.
        arg_overrides=eval(parsed_args.model_overrides),
        task=task,
        suffix=getattr(parsed_args, "checkpoint_suffix", ""),
    )

    # Command-line values win, except for these keys, which keep the values
    # stored with the checkpoint.
    for arg in vars(parsed_args).keys():
        if arg not in {
            'self_target', 'future_target', 'past_target', 'tokens_per_sample',
            'output_size_dictionary', 'add_bos_token',
        }:
            setattr(args, arg, getattr(parsed_args, arg))

    # reduce tokens per sample by the required context window size
    args.tokens_per_sample -= args.context_window
    task = tasks.setup_task(args)

    # Load dataset splits
    task.load_dataset(args.gen_subset)
    dataset = task.dataset(args.gen_subset)
    if args.context_window > 0:
        dataset = LMContextWindowDataset(
            dataset=dataset,
            tokens_per_sample=args.tokens_per_sample,
            context_window=args.context_window,
            pad_idx=task.source_dictionary.pad(),
        )
    logger.info('{} {} {} examples'.format(args.data, args.gen_subset, len(dataset)))

    # Optimize ensemble for generation and set the source and dest dicts on the model (required by scorer)
    for model in models:
        model.prepare_for_inference_(args)
        if args.fp16:
            model.half()
        if use_cuda:
            model.cuda()

    assert len(models) > 0

    logger.info('num. model params: {}'.format(sum(p.numel() for p in models[0].parameters())))

    itr = task.get_batch_iterator(
        dataset=dataset,
        max_tokens=args.max_tokens or 36000,
        max_sentences=args.max_sentences,
        max_positions=utils.resolve_max_positions(*[
            model.max_positions() for model in models
        ]),
        ignore_invalid_inputs=True,
        num_shards=args.num_shards,
        shard_id=args.shard_id,
        num_workers=args.num_workers,
    ).next_epoch_itr(shuffle=False)
    progress = progress_bar.progress_bar(
        itr,
        log_format=args.log_format,
        log_interval=args.log_interval,
        default_log_format=('tqdm' if not args.no_progress_bar else 'none'),
    )

    gen_timer = StopwatchMeter()
    scorer = SequenceScorer(task.target_dictionary, args.softmax_batch)

    score_sum = 0.
    count = 0

    # Token ids whose symbol ends with the BPE continuation marker; their
    # scores are folded into the following token's score below.
    if args.remove_bpe is not None:
        if args.remove_bpe == 'sentencepiece':
            raise NotImplementedError
        else:
            bpe_cont = args.remove_bpe.rstrip()
            bpe_toks = {
                i
                for i in range(len(task.source_dictionary))
                if task.source_dictionary[i].endswith(bpe_cont)
            }
        bpe_len = len(bpe_cont)
    else:
        bpe_toks = None
        bpe_len = 0

    word_stats = dict()

    wps_meter = TimeMeter()

    for sample in progress:
        if 'net_input' not in sample:
            continue

        sample = utils.move_to_cuda(sample) if use_cuda else sample

        gen_timer.start()
        hypos = scorer.generate(models, sample)
        gen_timer.stop(sample['ntokens'])

        for i, hypos_i in enumerate(hypos):
            hypo = hypos_i[0]
            sample_id = sample['id'][i]

            tokens = hypo['tokens']
            pos_scores = hypo['positional_scores'].float()

            if getattr(args, 'add_bos_token', False):
                assert hypo['tokens'][0].item() == task.target_dictionary.bos()
                tokens = tokens[1:]
                pos_scores = pos_scores[1:]

            skipped_toks = 0
            if bpe_toks is not None:
                # Bound by the (possibly BOS-stripped) token count. Using the
                # pre-strip length could index one past the end of pos_scores.
                for pos in range(len(tokens) - 1):
                    if tokens[pos].item() in bpe_toks:
                        skipped_toks += 1
                        pos_scores[pos + 1] += pos_scores[pos]
                        pos_scores[pos] = 0

            inf_scores = pos_scores.eq(float('inf')) | pos_scores.eq(float('-inf'))
            if inf_scores.any():
                # %s placeholder required: logging formats extra args with %.
                logger.info(
                    'skipping tokens with inf scores: %s',
                    task.target_dictionary.string(tokens[inf_scores.nonzero()])
                )
                # NOTE(review): after this filter pos_scores no longer lines up
                # index-for-index with tokens in the word-probability loop below.
                pos_scores = pos_scores[(~inf_scores).nonzero()]
            score_sum += pos_scores.sum().cpu()
            count += pos_scores.numel() - skipped_toks

            if args.output_word_probs or args.output_word_stats:
                w = ''
                word_prob = []
                is_bpe = False
                for pos in range(len(tokens)):
                    w_ind = tokens[pos].item()
                    w += task.source_dictionary[w_ind]
                    if bpe_toks is not None and w_ind in bpe_toks:
                        # Continuation piece: drop the marker and keep building.
                        w = w[:-bpe_len]
                        is_bpe = True
                    else:
                        word_prob.append((w, pos_scores[pos].item()))

                        # Score of the next token with a non-zero score, if any.
                        next_prob = None
                        ind = pos + 1
                        while ind < len(tokens):
                            if pos_scores[ind].item() != 0:
                                next_prob = pos_scores[ind]
                                break
                            ind += 1

                        word_stats.setdefault(w, WordStat(w, is_bpe)).add(pos_scores[pos].item(), next_prob)
                        is_bpe = False
                        w = ''
                if args.output_word_probs:
                    logger.info(
                        str(int(sample_id)) + " "
                        + ('\t'.join('{} [{:2f}]'.format(x[0], x[1]) for x in word_prob))
                    )

        wps_meter.update(sample['ntokens'])
        progress.log({'wps': round(wps_meter.avg)})

    avg_nll_loss = -score_sum / count / math.log(2)  # convert to base 2
    logger.info('Evaluated {} tokens in {:.1f}s ({:.2f} tokens/s)'.format(
        gen_timer.n, gen_timer.sum, 1. / gen_timer.avg
    ))
    logger.info('Loss (base 2): {:.4f}, Perplexity: {:.2f}'.format(
        avg_nll_loss, 2**avg_nll_loss
    ))

    if args.output_word_stats:
        for ws in sorted(word_stats.values(), key=lambda x: x.count, reverse=True):
            logger.info(ws)
def main(args):
    """Build dictionaries and binarize (or copy) the train/valid/test splits.

    Writes dictionaries, datasets, optional binarized alignment datasets, a
    ``preprocess.log`` file, and (with ``--alignfile``) a most-frequent-
    alignment dictionary into ``args.destdir``.

    Raises:
        FileExistsError: if a dictionary would be built over an existing one.
        ValueError: if ``--alignfile`` and the training files differ in
            line count.
    """
    utils.import_user_module(args)

    os.makedirs(args.destdir, exist_ok=True)

    logger.addHandler(logging.FileHandler(
        filename=os.path.join(args.destdir, 'preprocess.log'),
    ))
    logger.info(args)

    task = tasks.get_task(args.task)

    def train_path(lang):
        return "{}{}".format(args.trainpref, ("." + lang) if lang else "")

    def file_name(prefix, lang):
        fname = prefix
        if lang is not None:
            fname += ".{lang}".format(lang=lang)
        return fname

    def dest_path(prefix, lang):
        return os.path.join(args.destdir, file_name(prefix, lang))

    def dict_path(lang):
        return dest_path("dict", lang) + ".txt"

    def build_dictionary(filenames, src=False, tgt=False):
        # Exactly one side must be selected; it picks the threshold/nwords.
        assert src ^ tgt
        return task.build_dictionary(
            filenames,
            workers=args.workers,
            threshold=args.thresholdsrc if src else args.thresholdtgt,
            nwords=args.nwordssrc if src else args.nwordstgt,
            padding_factor=args.padding_factor,
        )

    target = not args.only_source

    # Refuse to overwrite dictionaries that would be rebuilt from data.
    if not args.srcdict and os.path.exists(dict_path(args.source_lang)):
        raise FileExistsError(dict_path(args.source_lang))
    if target and not args.tgtdict and os.path.exists(dict_path(args.target_lang)):
        raise FileExistsError(dict_path(args.target_lang))

    if args.joined_dictionary:
        assert not args.srcdict or not args.tgtdict, \
            "cannot use both --srcdict and --tgtdict with --joined-dictionary"

        if args.srcdict:
            src_dict = task.load_dictionary(args.srcdict)
        elif args.tgtdict:
            src_dict = task.load_dictionary(args.tgtdict)
        else:
            assert args.trainpref, "--trainpref must be set if --srcdict is not specified"
            src_dict = build_dictionary(
                {train_path(lang) for lang in [args.source_lang, args.target_lang]}, src=True
            )
        tgt_dict = src_dict
    else:
        if args.srcdict:
            src_dict = task.load_dictionary(args.srcdict)
        else:
            assert args.trainpref, "--trainpref must be set if --srcdict is not specified"
            src_dict = build_dictionary([train_path(args.source_lang)], src=True)

        if target:
            if args.tgtdict:
                tgt_dict = task.load_dictionary(args.tgtdict)
            else:
                assert args.trainpref, "--trainpref must be set if --tgtdict is not specified"
                tgt_dict = build_dictionary([train_path(args.target_lang)], tgt=True)
        else:
            tgt_dict = None

    src_dict.save(dict_path(args.source_lang))
    if target and tgt_dict is not None:
        tgt_dict.save(dict_path(args.target_lang))

    def make_binary_dataset(vocab, input_prefix, output_prefix, lang, num_workers):
        """Binarize one text file, splitting it across ``num_workers`` processes."""
        logger.info("[{}] Dictionary: {} types".format(lang, len(vocab)))
        n_seq_tok = [0, 0]
        replaced = Counter()

        def merge_result(worker_result):
            replaced.update(worker_result["replaced"])
            n_seq_tok[0] += worker_result["nseq"]
            n_seq_tok[1] += worker_result["ntok"]

        input_file = "{}{}".format(
            input_prefix, ("." + lang) if lang is not None else ""
        )
        offsets = Binarizer.find_offsets(input_file, num_workers)
        pool = None
        if num_workers > 1:
            # Workers 1..N-1 write temporary shards; shard 0 is done in-process.
            pool = Pool(processes=num_workers - 1)
            for worker_id in range(1, num_workers):
                prefix = "{}{}".format(output_prefix, worker_id)
                pool.apply_async(
                    binarize,
                    (
                        args,
                        input_file,
                        vocab,
                        prefix,
                        lang,
                        offsets[worker_id],
                        offsets[worker_id + 1]
                    ),
                    callback=merge_result
                )
            pool.close()

        ds = indexed_dataset.make_builder(dataset_dest_file(args, output_prefix, lang, "bin"),
                                          impl=args.dataset_impl, vocab_size=len(vocab))
        merge_result(
            Binarizer.binarize(
                input_file, vocab, lambda t: ds.add_item(t),
                offset=0, end=offsets[1]
            )
        )
        if num_workers > 1:
            pool.join()
            # Append the worker shards in order, then delete their temp files.
            for worker_id in range(1, num_workers):
                prefix = "{}{}".format(output_prefix, worker_id)
                temp_file_path = dataset_dest_prefix(args, prefix, lang)
                ds.merge_file_(temp_file_path)
                os.remove(indexed_dataset.data_file_path(temp_file_path))
                os.remove(indexed_dataset.index_file_path(temp_file_path))

        ds.finalize(dataset_dest_file(args, output_prefix, lang, "idx"))

        n_tok = n_seq_tok[1]
        # An empty input file has zero tokens; avoid ZeroDivisionError.
        replaced_pct = 100 * sum(replaced.values()) / n_tok if n_tok > 0 else 0.0
        logger.info(
            "[{}] {}: {} sents, {} tokens, {:.3}% replaced by {}".format(
                lang,
                input_file,
                n_seq_tok[0],
                n_tok,
                replaced_pct,
                vocab.unk_word,
            )
        )

    def make_binary_alignment_dataset(input_prefix, output_prefix, num_workers):
        """Binarize an alignment file, splitting it across ``num_workers`` processes."""
        nseq = [0]

        def merge_result(worker_result):
            nseq[0] += worker_result['nseq']

        input_file = input_prefix
        offsets = Binarizer.find_offsets(input_file, num_workers)
        pool = None
        if num_workers > 1:
            pool = Pool(processes=num_workers - 1)
            for worker_id in range(1, num_workers):
                prefix = "{}{}".format(output_prefix, worker_id)
                pool.apply_async(
                    binarize_alignments,
                    (
                        args,
                        input_file,
                        utils.parse_alignment,
                        prefix,
                        offsets[worker_id],
                        offsets[worker_id + 1]
                    ),
                    callback=merge_result
                )
            pool.close()

        ds = indexed_dataset.make_builder(dataset_dest_file(args, output_prefix, None, "bin"),
                                          impl=args.dataset_impl)

        merge_result(
            Binarizer.binarize_alignments(
                input_file, utils.parse_alignment, lambda t: ds.add_item(t),
                offset=0, end=offsets[1]
            )
        )
        if num_workers > 1:
            pool.join()
            for worker_id in range(1, num_workers):
                prefix = "{}{}".format(output_prefix, worker_id)
                temp_file_path = dataset_dest_prefix(args, prefix, None)
                ds.merge_file_(temp_file_path)
                os.remove(indexed_dataset.data_file_path(temp_file_path))
                os.remove(indexed_dataset.index_file_path(temp_file_path))

        ds.finalize(dataset_dest_file(args, output_prefix, None, "idx"))

        logger.info(
            "[alignments] {}: parsed {} alignments".format(
                input_file,
                nseq[0]
            )
        )

    def make_dataset(vocab, input_prefix, output_prefix, lang, num_workers=1):
        """Copy the text file for the "raw" impl, otherwise binarize it."""
        if args.dataset_impl == "raw":
            # Copy original text file to destination folder
            output_text_file = dest_path(
                output_prefix + ".{}-{}".format(args.source_lang, args.target_lang),
                lang,
            )
            shutil.copyfile(file_name(input_prefix, lang), output_text_file)
        else:
            make_binary_dataset(vocab, input_prefix, output_prefix, lang, num_workers)

    def make_all(lang, vocab):
        """Process train and every comma-separated valid/test prefix for ``lang``."""
        if args.trainpref:
            make_dataset(vocab, args.trainpref, "train", lang, num_workers=args.workers)
        if args.validpref:
            for k, validpref in enumerate(args.validpref.split(",")):
                outprefix = "valid{}".format(k) if k > 0 else "valid"
                make_dataset(vocab, validpref, outprefix, lang, num_workers=args.workers)
        if args.testpref:
            for k, testpref in enumerate(args.testpref.split(",")):
                outprefix = "test{}".format(k) if k > 0 else "test"
                make_dataset(vocab, testpref, outprefix, lang, num_workers=args.workers)

    def make_all_alignments():
        """Binarize ``<pref>.<align_suffix>`` alignment files that exist."""
        # NOTE(review): unlike make_all, valid/test prefixes are not split on ",".
        if args.trainpref and os.path.exists(args.trainpref + "." + args.align_suffix):
            make_binary_alignment_dataset(args.trainpref + "." + args.align_suffix, "train.align", num_workers=args.workers)
        if args.validpref and os.path.exists(args.validpref + "." + args.align_suffix):
            make_binary_alignment_dataset(args.validpref + "." + args.align_suffix, "valid.align", num_workers=args.workers)
        if args.testpref and os.path.exists(args.testpref + "." + args.align_suffix):
            make_binary_alignment_dataset(args.testpref + "." + args.align_suffix, "test.align", num_workers=args.workers)

    make_all(args.source_lang, src_dict)
    if target:
        make_all(args.target_lang, tgt_dict)
    if args.align_suffix:
        make_all_alignments()

    logger.info("Wrote preprocessed data to {}".format(args.destdir))

    if args.alignfile:
        assert args.trainpref, "--trainpref must be set if --alignfile is specified"
        src_file_name = train_path(args.source_lang)
        tgt_file_name = train_path(args.target_lang)
        # freq_map[src_id][tgt_id] = number of times the pair was aligned.
        freq_map = {}
        with open(args.alignfile, "r", encoding='utf-8') as align_file:
            with open(src_file_name, "r", encoding='utf-8') as src_file:
                with open(tgt_file_name, "r", encoding='utf-8') as tgt_file:
                    for a, s, t in zip_longest(align_file, src_file, tgt_file):
                        # zip_longest pads with None; report mismatched files
                        # instead of failing later on None.split().
                        if a is None or s is None or t is None:
                            raise ValueError(
                                "{}, {} and {} must have the same number of lines".format(
                                    args.alignfile, src_file_name, tgt_file_name))
                        si = src_dict.encode_line(s, add_if_not_exist=False)
                        ti = tgt_dict.encode_line(t, add_if_not_exist=False)
                        ai = list(map(lambda x: tuple(x.split("-")), a.split()))
                        for sai, tai in ai:
                            # int(): encode_line presumably returns a tensor,
                            # and tensor elements hash by object identity, so
                            # using them as dict keys never accumulates counts
                            # for repeated ids.
                            srcidx = int(si[int(sai)])
                            tgtidx = int(ti[int(tai)])
                            if srcidx != src_dict.unk() and tgtidx != tgt_dict.unk():
                                assert srcidx != src_dict.pad()
                                assert srcidx != src_dict.eos()
                                assert tgtidx != tgt_dict.pad()
                                assert tgtidx != tgt_dict.eos()

                                counts = freq_map.setdefault(srcidx, {})
                                counts[tgtidx] = counts.get(tgtidx, 0) + 1

        # Keep the most frequently aligned target id for each source id.
        align_dict = {
            srcidx: max(counts, key=counts.get)
            for srcidx, counts in freq_map.items()
        }

        with open(
            os.path.join(
                args.destdir,
                "alignment.{}-{}.txt".format(args.source_lang, args.target_lang),
            ),
            "w", encoding='utf-8'
        ) as f:
            for k, v in align_dict.items():
                print("{} {}".format(src_dict[k], tgt_dict[v]), file=f)
def main(args):
    """Train a model until the LR falls to ``args.min_lr`` or ``max_epoch`` is reached.

    Builds the task, model, criterion, optional quantizer and ``Trainer``,
    restores the latest checkpoint (and its train iterator), then runs
    ``train`` epoch by epoch.

    Raises:
        NotImplementedError: if ``args.model_parallel_size`` is not 1.
    """
    utils.import_user_module(args)

    assert (
        args.max_tokens is not None or args.max_sentences is not None
    ), "Must specify batch size either with --max-tokens or --max-sentences"

    metrics.reset()

    # NOTE(review): unlike the generation entry points, the seed is not guarded
    # against None here; confirm utils.set_torch_seed accepts None if --seed
    # may be omitted.
    np.random.seed(args.seed)
    utils.set_torch_seed(args.seed)

    if distributed_utils.is_master(args):
        checkpoint_utils.verify_checkpoint_directory(args.save_dir)

    # Print args
    logger.info(args)

    # Setup task, e.g., translation, language modeling, etc.
    task = tasks.setup_task(args)

    # Load valid dataset (we load training data below, based on the latest checkpoint)
    for valid_sub_split in args.valid_subset.split(","):
        task.load_dataset(valid_sub_split, combine=False, epoch=1)

    # Build model and criterion
    model = task.build_model(args)
    criterion = task.build_criterion(args)
    logger.info(model)
    logger.info("task: {} ({})".format(args.task, task.__class__.__name__))
    logger.info("model: {} ({})".format(args.arch, model.__class__.__name__))
    logger.info(
        "criterion: {} ({})".format(args.criterion, criterion.__class__.__name__)
    )
    logger.info(
        "num. model params: {} (num. trained: {})".format(
            sum(p.numel() for p in model.parameters()),
            sum(p.numel() for p in model.parameters() if p.requires_grad),
        )
    )

    # (optionally) Configure quantization
    if args.quantization_config_path is not None:
        quantizer = quantization_utils.Quantizer(
            config_path=args.quantization_config_path,
            max_epoch=args.max_epoch,
            max_update=args.max_update,
        )
    else:
        quantizer = None

    # Build trainer
    if args.model_parallel_size == 1:
        trainer = Trainer(args, task, model, criterion, quantizer)
    else:
        raise NotImplementedError(
            "model parallel training (--model-parallel-size={}) is not supported; "
            "use --model-parallel-size 1".format(args.model_parallel_size)
        )

    logger.info(
        "training on {} devices (GPUs/TPUs)".format(args.distributed_world_size)
    )
    logger.info(
        "max tokens per GPU = {} and max sentences per GPU = {}".format(
            args.max_tokens, args.max_sentences
        )
    )

    # Load the latest checkpoint if one is available and restore the
    # corresponding train iterator
    _extra_state, epoch_itr = checkpoint_utils.load_checkpoint(args, trainer)

    # Train until the learning rate gets too small
    max_epoch = args.max_epoch or math.inf
    lr = trainer.get_lr()
    train_meter = meters.StopwatchMeter()
    train_meter.start()

    while lr > args.min_lr and epoch_itr.next_epoch_idx <= max_epoch:
        # train for one epoch
        valid_losses, should_stop = train(args, trainer, task, epoch_itr)
        if should_stop:
            break

        # only use first validation loss to update the learning rate
        lr = trainer.lr_step(epoch_itr.epoch, valid_losses[0])

        epoch_itr = trainer.get_train_iterator(
            epoch_itr.next_epoch_idx,
            # sharded data: get train iterator for next epoch
            load_dataset=task.has_sharded_data("train"),
        )
    train_meter.stop()
    logger.info("done training in {:.1f} seconds".format(train_meter.sum))
def get_parser(desc, default_task="translation"):
    """Build the base argument parser shared by the command-line entry points.

    Before anything else, ``--user-dir`` is pre-parsed from ``sys.argv`` and
    handed to ``utils.import_user_module``, so that custom tasks, optimizers,
    architectures, etc. are imported before the registry and task ``choices``
    below are read.

    Args:
        desc: accepted for interface compatibility; this function does not
            use it.
        default_task (str): default value of the ``--task`` option.

    Returns:
        argparse.ArgumentParser: a parser (abbreviations disabled) with the
        common logging/precision/runtime options, one option per entry of
        ``registry.REGISTRIES``, and ``--task``.
    """
    # Pre-parse only --user-dir, tolerating every other flag on the command line.
    bootstrap = argparse.ArgumentParser(add_help=False, allow_abbrev=False)
    bootstrap.add_argument("--user-dir", default=None)
    bootstrap_args, _unknown = bootstrap.parse_known_args()
    utils.import_user_module(bootstrap_args)

    parser = argparse.ArgumentParser(allow_abbrev=False)
    add = parser.add_argument

    # fmt: off
    # Logging and reproducibility.
    add("--no-progress-bar", action="store_true", help="disable progress bar")
    add("--log-interval", type=int, default=100, metavar="N",
        help="log progress every N batches (when progress bar is disabled)")
    add("--log-format", default=None, help="log format to use",
        choices=["json", "none", "simple", "tqdm"])
    add("--tensorboard-logdir", metavar="DIR", default="",
        help="path to save logs for tensorboard, should match --logdir "
             "of running tensorboard (default: no tensorboard logging)")
    add("--seed", default=None, type=int, metavar="N",
        help="pseudo random number generator seed")

    # Device and mixed-precision switches: plain on/off flags, registered in order.
    for flag, text in (
        ("--cpu", "use CPU instead of CUDA"),
        ("--fp16", "use FP16"),
        ("--memory-efficient-fp16",
         "use a memory-efficient version of FP16 training; implies --fp16"),
        ("--fp16-no-flatten-grads", "don't flatten FP16 grads tensor"),
    ):
        add(flag, action="store_true", help=text)

    # FP16 loss scaling.
    add("--fp16-init-scale", default=2 ** 7, type=int,
        help="default FP16 loss scale")
    add("--fp16-scale-window", type=int,
        help="number of updates before increasing loss scale")
    add("--fp16-scale-tolerance", default=0.0, type=float,
        help="pct of updates that can overflow before decreasing the loss scale")
    add("--min-loss-scale", default=1e-4, type=float, metavar="D",
        help="minimum FP16 loss scale, after which training is stopped")
    add("--threshold-loss-scale", type=float,
        help="threshold FP16 loss scale from below")

    # Miscellaneous runtime options.
    add("--user-dir", default=None,
        help="path to a python module containing custom extensions "
             "(tasks and/or architectures)")
    add("--empty-cache-freq", default=0, type=int,
        help="how often to clear the PyTorch CUDA cache (0 to disable)")
    add("--all-gather-list-size", default=16384, type=int,
        help="number of bytes reserved for gathering stats from workers")
    add("--model-parallel-size", type=int, metavar="N", default=1,
        help="total number of GPUs to parallelize model over")
    add("--checkpoint-suffix", default="",
        help="suffix to add to the checkpoint file name")
    add("--quantization-config-path", default=None,
        help="path to quantization config file")
    add("--profile", action="store_true",
        help="enable autograd profiler emit_nvtx")

    # One selector option per registry: a registry named "foo_bar" yields --foo-bar.
    from registry import REGISTRIES
    for reg_name, reg in REGISTRIES.items():
        option = "--{}".format(reg_name.replace("_", "-"))
        add(option, default=reg["default"], choices=reg["registry"].keys())

    # Valid --task values come from tasks.TASK_REGISTRY.
    from tasks import TASK_REGISTRY
    add("--task", metavar="TASK", default=default_task,
        choices=TASK_REGISTRY.keys(), help="task")
    # fmt: on
    return parser
def parse_args_and_arch(
    parser: argparse.ArgumentParser,
    input_args: Optional[List[str]] = None,
    parse_known: bool = False,
    suppress_defaults: bool = False,
    modify_parser: Optional[Callable[[argparse.ArgumentParser], None]] = None,
):
    """Parse arguments, including those contributed by the selected arch/task/registries.

    Parsing happens in two passes. The first pass (``parse_known_args``)
    discovers ``--arch``, ``--task``, the registry selections and ``--use-bmuf``.
    The matching classes then add their own arguments to *parser*, and the
    second pass parses everything.

    Args:
        parser (ArgumentParser): the parser; it is mutated in place because the
            selected model/task/registry classes add their arguments to it
        input_args (List[str]): strings to parse, defaults to sys.argv
        parse_known (bool): only parse known arguments, similar to
            `ArgumentParser.parse_known_args`
        suppress_defaults (bool): parse while ignoring all default values, so
            only arguments whose parsed value is not None are kept in the result
        modify_parser (Optional[Callable[[ArgumentParser], None]]):
            function to modify the parser, e.g., to set default values. It is
            not invoked when *suppress_defaults* is True.

    Returns:
        The parsed ``argparse.Namespace`` when *parse_known* is False. Otherwise
        a tuple ``(namespace, extra)``, where *extra* is the list of
        unrecognized strings returned by ``parse_known_args``.
    """
    if suppress_defaults:
        # Parse args without any default values. This requires us to parse
        # twice, once to identify all the necessary task/model args, and a second
        # time with all defaults set to None.
        first_pass = parse_args_and_arch(
            parser,
            input_args=input_args,
            parse_known=parse_known,
            suppress_defaults=False,
        )
        # With parse_known=True the recursive call returns (args, extra);
        # only the namespace is needed to learn which destinations exist.
        args = first_pass[0] if parse_known else first_pass
        suppressed_parser = argparse.ArgumentParser(add_help=False, parents=[parser])
        # NOTE(review): argparse's `parents=` appears to reuse the parent's Action
        # objects, so this set_defaults likely also resets *parser*'s defaults.
        # Confirm before reusing the same parser after this call.
        suppressed_parser.set_defaults(**dict.fromkeys(vars(args), None))
        if parse_known:
            args, extra = suppressed_parser.parse_known_args(input_args)
        else:
            args = suppressed_parser.parse_args(input_args)
            extra = None
        args = argparse.Namespace(
            **{k: v for k, v in vars(args).items() if v is not None}
        )
        return (args, extra) if parse_known else args

    from models import ARCH_MODEL_REGISTRY, ARCH_CONFIG_REGISTRY

    # Before creating the true parser, we need to import optional user module
    # in order to eagerly import custom tasks, optimizers, architectures, etc.
    usr_parser = argparse.ArgumentParser(add_help=False, allow_abbrev=False)
    usr_parser.add_argument("--user-dir", default=None)
    usr_args, _ = usr_parser.parse_known_args(input_args)
    utils.import_user_module(usr_args)

    if modify_parser is not None:
        modify_parser(parser)

    # The parser doesn't know about model/criterion/optimizer-specific args, so
    # we parse twice. First we parse the model/criterion/optimizer, then we
    # parse a second time after adding the *-specific arguments.
    # If input_args is given, we will parse those args instead of sys.argv.
    args, _ = parser.parse_known_args(input_args)

    # Add model-specific args to parser.
    if hasattr(args, "arch"):
        model_specific_group = parser.add_argument_group(
            "Model-specific configuration",
            # Only include attributes which are explicitly given as command-line
            # arguments or which have default values.
            argument_default=argparse.SUPPRESS,
        )
        ARCH_MODEL_REGISTRY[args.arch].add_args(model_specific_group)

    # Add *-specific args to parser.
    from registry import REGISTRIES

    for registry_name, REGISTRY in REGISTRIES.items():
        choice = getattr(args, registry_name, None)
        if choice is not None:
            cls = REGISTRY["registry"][choice]
            if hasattr(cls, "add_args"):
                cls.add_args(parser)
    if hasattr(args, "task"):
        from tasks import TASK_REGISTRY

        TASK_REGISTRY[args.task].add_args(parser)
    if getattr(args, "use_bmuf", False):
        # hack to support extra args for block distributed data parallelism
        from optim.bmuf import FairseqBMUF

        FairseqBMUF.add_args(parser)

    # Modify the parser a second time, since defaults may have been reset
    if modify_parser is not None:
        modify_parser(parser)

    # Parse a second time.
    if parse_known:
        args, extra = parser.parse_known_args(input_args)
    else:
        args = parser.parse_args(input_args)
        extra = None

    # Post-process args: validation batch limits fall back to the training ones.
    if hasattr(args, "max_sentences_valid") and args.max_sentences_valid is None:
        args.max_sentences_valid = args.max_sentences
    if hasattr(args, "max_tokens_valid") and args.max_tokens_valid is None:
        args.max_tokens_valid = args.max_tokens
    if getattr(args, "seed", None) is None:
        args.seed = 1  # default seed for training
        args.no_seed_provided = True
    else:
        args.no_seed_provided = False

    # Apply architecture configuration.
    if hasattr(args, "arch"):
        ARCH_CONFIG_REGISTRY[args.arch](args)

    if parse_known:
        return args, extra
    else:
        return args