def get_parser(desc, default_task="translation"): # Before creating the true parser, we need to import optional user module # in order to eagerly import custom tasks, optimizers, architectures, etc. usr_parser = argparse.ArgumentParser(add_help=False, allow_abbrev=False) usr_parser.add_argument("--user-dir", default=None) usr_args, _ = usr_parser.parse_known_args() import_user_module(usr_args) parser = argparse.ArgumentParser(allow_abbrev=False) gen_parser_from_dataclass(parser, CommonParams()) from fairseq.registry import REGISTRIES for registry_name, REGISTRY in REGISTRIES.items(): parser.add_argument( "--" + registry_name.replace("_", "-"), default=REGISTRY["default"], choices=REGISTRY["registry"].keys(), ) # Task definitions can be found under fairseq/tasks/ from fairseq.tasks import TASK_REGISTRY parser.add_argument( "--task", metavar="TASK", default=default_task, choices=TASK_REGISTRY.keys(), help="task", ) # fmt: on return parser
def main(args):
    # Load dataset
    import_user_module(args)
    task = tasks.setup_task(args)
    task.load_dataset(args.gen_subset)
    dataset = task.dataset(args.gen_subset)

    # Get iterator over batches
    batch_index_iterator = get_batch_iterator(
        dataset=dataset,
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences,
        max_positions=None,
        ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=args.required_batch_size_multiple,
        num_shards=args.num_shards,
        shard_id=args.shard_id,
        num_workers=args.num_workers,
        large_sent_first=False,
    )

    # collate each batch of sentences into a single tensor, over all the data
    for batch_ids in tqdm(batch_index_iterator):
        samples = [dataset[i] for i in batch_ids]
        dataset.collater(samples)
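
# Hedged sketch (not in the original): what the loop above produces. For fairseq-style paired
# datasets, `dataset.collater(samples)` typically returns a dict with keys such as 'id',
# 'nsentences', 'ntokens', 'net_input' (containing 'src_tokens'/'src_lengths') and 'target';
# the exact keys depend on the dataset class, so treat these names as assumptions.
def _demo_inspect_collated_batch(dataset, batch_ids):
    samples = [dataset[i] for i in batch_ids]
    batch = dataset.collater(samples)
    if 'net_input' in batch:
        src_tokens = batch['net_input']['src_tokens']
        print('batch of {} sentences, src_tokens shape {}'.format(
            batch.get('nsentences', len(samples)), tuple(src_tokens.shape)))
    return batch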
def main(args):
    assert args.path is not None, '--path required for generation!'
    assert not args.sampling or args.nbest == args.beam, \
        '--sampling requires --nbest to be equal to --beam'
    assert args.replace_unk is None or args.raw_text, \
        '--replace-unk requires a raw text dataset (--raw-text)'

    import_user_module(args)

    if args.max_tokens is None and args.max_sentences is None:
        args.max_tokens = 12000
    print(args)

    use_cuda = torch.cuda.is_available() and not args.cpu

    # ========== for bartsv task, rebuild dictionary after model args are loaded ==========
    # assert not hasattr(args, 'node_freq_min'), 'node_freq_min should be read from model args'
    # args.node_freq_min = 5    # temporarily set before model loading, as this is needed in tasks.setup_task(args)
    # =====================================================================================

    # Load dataset splits
    task = tasks.setup_task(args)
    # Note: states are not needed since they will be provided by the state machine
    task.load_dataset(args.gen_subset, state_machine=False)

    # Set dictionaries
    try:
        src_dict = getattr(task, 'source_dictionary', None)
    except NotImplementedError:
        src_dict = None
    tgt_dict = task.target_dictionary

    # Load ensemble
    print('| loading model(s) from {}'.format(args.path))
    try:
        models, _model_args = checkpoint_utils.load_model_ensemble(
            args.path.split(':'),
            arg_overrides=eval(args.model_overrides),
            task=task,
        )
    except Exception:
        # NOTE this is for "bartsv" models when the default "args.node_freq_min" (5) does not match the model:
        #      loading the model with the above task raises an error when building the model with the task's
        #      target vocabulary, which would be of a different size
        # TODO better handle these cases (without sacrificing compatibility with other model archs)
        models, _model_args = checkpoint_utils.load_model_ensemble(
            args.path.split(':'),
            arg_overrides=eval(args.model_overrides),
            task=None,
        )

    # ========== for bartsv task, rebuild the dictionary based on model args ==========
    if 'bartsv' in _model_args.arch and args.node_freq_min != _model_args.node_freq_min:
        args.node_freq_min = _model_args.node_freq_min

        # Load dataset splits
        task = tasks.setup_task(args)
        # Note: states are not needed since they will be provided by the state machine
        task.load_dataset(args.gen_subset, state_machine=False)

        # Set dictionaries
        try:
            src_dict = getattr(task, 'source_dictionary', None)
        except NotImplementedError:
            src_dict = None
        tgt_dict = task.target_dictionary
    # ==================================================================================

    # ========== for previous models trained before new arguments were added ==========
    if not hasattr(_model_args, 'shift_pointer_value'):
        _model_args.shift_pointer_value = 1
    # ==================================================================================

    # Optimize ensemble for generation
    for model in models:
        model.make_generation_fast_(
            beamable_mm_beam_size=None if args.no_beamable_mm else args.beam,
            need_attn=args.print_alignment,
        )
        if args.fp16:
            model.half()
        if use_cuda:
            model.cuda()

    # Load alignment dictionary for unknown word replacement
    # (None if no unknown word replacement, empty if no path to align dictionary)
    align_dict = utils.load_align_dict(args.replace_unk)

    # Load dataset (possibly sharded)
    itr = task.get_batch_iterator(
        dataset=task.dataset(args.gen_subset),
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences,
        max_positions=None,
        ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=args.required_batch_size_multiple,
        num_shards=args.num_shards,
        shard_id=args.shard_id,
        num_workers=args.num_workers,
        # large_sent_first=False    # not in fairseq
    ).next_epoch_itr(shuffle=False)

    # Initialize generator
    gen_timer = StopwatchMeter()
    generator = task.build_generator(args, _model_args)

    # Generate and compute BLEU score
    if args.sacrebleu:
        scorer = bleu.SacrebleuScorer()
    else:
        scorer = bleu.Scorer(tgt_dict.pad(), tgt_dict.eos(), tgt_dict.unk())

    num_sentences = 0
    has_target = True

    examples = Examples(args.path, args.results_path, args.gen_subset, args.nbest)

    error_stats = {'num_sub_start': 0}

    with progress_bar.build_progress_bar(args, itr) as t:
        wps_meter = TimeMeter()
        for sample in t:
            sample = utils.move_to_cuda(sample) if use_cuda else sample
            if 'net_input' not in sample:
                raise Exception("Did not expect empty sample")

            prefix_tokens = None
            if args.prefix_size > 0:
                prefix_tokens = sample['target'][:, :args.prefix_size]

            gen_timer.start()
            hypos = task.inference_step(generator, models, sample, args, prefix_tokens)
            num_generated_tokens = sum(len(h[0]['tokens']) for h in hypos)
            gen_timer.stop(num_generated_tokens)

            for i, sample_id in enumerate(sample['id'].tolist()):
                has_target = sample['target'] is not None

                # Remove padding
                src_tokens = utils.strip_pad(
                    sample['net_input']['src_tokens'][i, :], tgt_dict.pad())
                target_tokens = None
                if has_target:
                    target_tokens = utils.strip_pad(
                        sample['target'][i, :], tgt_dict.pad()).int().cpu()

                # Either retrieve the original sentences or regenerate them from tokens.
                if align_dict is not None:
                    src_str = task.dataset(
                        args.gen_subset).src.get_original_text(sample_id)
                    target_str = task.dataset(
                        args.gen_subset).tgt.get_original_text(sample_id)
                else:
                    if src_dict is not None:
                        src_str = src_dict.string(src_tokens, args.remove_bpe)
                    else:
                        src_str = ""
                    if has_target:
                        target_str = tgt_dict.string(target_tokens, args.remove_bpe, escape_unk=True)
                        # debug: '<<unk>>' is added to the dictionary
                # ==========> NOTE we do not really have the ground truth target (with the same alignments);
                #             target_str might have <unk> as the target dictionary is only built on training data,
                #             but it doesn't matter. It should not affect the target dictionary!

                if not args.quiet:
                    if src_dict is not None:
                        print('S-{}\t{}'.format(sample_id, src_str))
                    if has_target:
                        print('T-{}\t{}'.format(sample_id, target_str))

                # Process top predictions
                for j, hypo in enumerate(hypos[i][:args.nbest]):
                    # hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
                    #     hypo_tokens=hypo['tokens'].int().cpu(),
                    #     src_str=src_str,
                    #     alignment=hypo['alignment'].int().cpu() if hypo['alignment'] is not None else None,
                    #     align_dict=align_dict,
                    #     tgt_dict=tgt_dict,
                    #     remove_bpe=args.remove_bpe,
                    #     # FIXME: AMR specific
                    #     split_token="\t",
                    #     line_tokenizer=task.tokenize,
                    # )

                    if 'bartsv' in _model_args.arch:
                        if not tgt_dict[hypo['tokens'][0]].startswith(tgt_dict.bpe.INIT):
                            error_stats['num_sub_start'] += 1
                        try:
                            actions_nopos, actions_pos, actions = post_process_action_pointer_prediction_bartsv(
                                hypo, tgt_dict)
                        except Exception:
                            breakpoint()
                    else:
                        actions_nopos, actions_pos, actions = post_process_action_pointer_prediction(
                            hypo, tgt_dict)

                    if args.clean_arcs:
                        actions_nopos, actions_pos, actions, invalid_idx = clean_pointer_arcs(
                            actions_nopos, actions_pos, actions)

                    # TODO these are just dummy for the reference below to run
                    hypo_tokens = hypo['tokens'].int().cpu()
                    hypo_str = '\t'.join(actions)
                    alignment = None

                    # update the list of examples
                    examples.append({
                        'actions_nopos': actions_nopos,
                        'actions_pos': actions_pos,
                        'actions': actions,
                        'reference': target_str,
                        'src_str': src_str,
                        'sample_id': sample_id
                    })

                    if not args.quiet:
                        print('H-{}\t{}\t{}'.format(sample_id, hypo_str, hypo['score']))
                        print('P-{}\t{}'.format(
                            sample_id,
                            ' '.join(
                                map(
                                    lambda x: '{:.4f}'.format(x),
                                    hypo['positional_scores'].tolist(),
                                ))))
                        if args.print_alignment:
                            print('A-{}\t{}'.format(
                                sample_id,
                                ' '.join(map(lambda x: str(utils.item(x)), alignment))))

                    # Score only the top hypothesis
                    if has_target and j == 0:
                        if align_dict is not None or args.remove_bpe is not None:
                            # Convert back to tokens for evaluation with unk replacement and/or without BPE
                            target_tokens = tgt_dict.encode_line(target_str, add_if_not_exist=False)
                            # NOTE do not modify the tgt dictionary with 'add_if_not_exist=True'!
                        if hasattr(scorer, 'add_string'):
                            scorer.add_string(target_str, hypo_str)
                        else:
                            scorer.add(target_tokens, hypo_tokens)

            wps_meter.update(num_generated_tokens)
            t.log({'wps': round(wps_meter.avg)})
            num_sentences += sample['nsentences']

    # Save examples to files
    examples.save()

    print('| Error case (handled by manual fix) statistics:')
    print(error_stats)

    print('| Translated {} sentences ({} tokens) in {:.1f}s ({:.2f} sentences/s, {:.2f} tokens/s)'.format(
        num_sentences, gen_timer.n, gen_timer.sum,
        num_sentences / gen_timer.sum, 1. / gen_timer.avg))
    if has_target:
        print('| Generate {} with beam={}: {}'.format(args.gen_subset, args.beam, scorer.result_string()))

    return scorer
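
# Hedged sketch (assumption, not shown in the original): a fairseq-style entry point wiring
# argument parsing to the generation `main` above. The stock `options.get_generation_parser()`
# would not know about repo-specific flags (e.g. --clean-arcs), so the real entry point likely
# uses a customized generation parser; this is illustrative only.
def _demo_generate_cli_main():
    from fairseq import options
    parser = options.get_generation_parser()
    args = options.parse_args_and_arch(parser)
    return main(args)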
def parse_args_and_arch(
    parser: argparse.ArgumentParser,
    input_args: List[str] = None,
    parse_known: bool = False,
    suppress_defaults: bool = False,
    modify_parser: Optional[Callable[[argparse.ArgumentParser], None]] = None,
):
    """
    Args:
        parser (ArgumentParser): the parser
        input_args (List[str]): strings to parse, defaults to sys.argv
        parse_known (bool): only parse known arguments, similar to
            `ArgumentParser.parse_known_args`
        suppress_defaults (bool): parse while ignoring all default values
        modify_parser (Optional[Callable[[ArgumentParser], None]]):
            function to modify the parser, e.g., to set default values
    """
    if suppress_defaults:
        # Parse args without any default values. This requires us to parse
        # twice, once to identify all the necessary task/model args, and a second
        # time with all defaults set to None.
        args = parse_args_and_arch(
            parser,
            input_args=input_args,
            parse_known=parse_known,
            suppress_defaults=False,
        )
        suppressed_parser = argparse.ArgumentParser(add_help=False, parents=[parser])
        suppressed_parser.set_defaults(**{k: None for k, v in vars(args).items()})
        args = suppressed_parser.parse_args(input_args)
        return argparse.Namespace(
            **{k: v for k, v in vars(args).items() if v is not None}
        )

    from fairseq.models import ARCH_MODEL_REGISTRY, ARCH_CONFIG_REGISTRY, MODEL_REGISTRY

    # Before creating the true parser, we need to import optional user module
    # in order to eagerly import custom tasks, optimizers, architectures, etc.
    usr_parser = argparse.ArgumentParser(add_help=False, allow_abbrev=False)
    usr_parser.add_argument("--user-dir", default=None)
    usr_args, _ = usr_parser.parse_known_args(input_args)
    import_user_module(usr_args)

    if modify_parser is not None:
        modify_parser(parser)

    # The parser doesn't know about model/criterion/optimizer-specific args, so
    # we parse twice. First we parse the model/criterion/optimizer, then we
    # parse a second time after adding the *-specific arguments.
    # If input_args is given, we will parse those args instead of sys.argv.
    args, _ = parser.parse_known_args(input_args)

    # Add model-specific args to parser.
    if hasattr(args, "arch"):
        model_specific_group = parser.add_argument_group(
            "Model-specific configuration",
            # Only include attributes which are explicitly given as command-line
            # arguments or which have default values.
            argument_default=argparse.SUPPRESS,
        )
        if args.arch in ARCH_MODEL_REGISTRY:
            ARCH_MODEL_REGISTRY[args.arch].add_args(model_specific_group)
        elif args.arch in MODEL_REGISTRY:
            MODEL_REGISTRY[args.arch].add_args(model_specific_group)
        else:
            raise RuntimeError()

    # Add *-specific args to parser.
    from fairseq.registry import REGISTRIES
    for registry_name, REGISTRY in REGISTRIES.items():
        choice = getattr(args, registry_name, None)
        if choice is not None:
            cls = REGISTRY["registry"][choice]
            if hasattr(cls, "add_args"):
                cls.add_args(parser)
    if hasattr(args, "task"):
        from fairseq.tasks import TASK_REGISTRY
        TASK_REGISTRY[args.task].add_args(parser)
    if getattr(args, "use_bmuf", False):
        # hack to support extra args for block distributed data parallelism
        from fairseq.optim.bmuf import FairseqBMUF
        FairseqBMUF.add_args(parser)

    # Modify the parser a second time, since defaults may have been reset
    if modify_parser is not None:
        modify_parser(parser)

    # Parse a second time.
    if parse_known:
        args, extra = parser.parse_known_args(input_args)
    else:
        args = parser.parse_args(input_args)
        extra = None

    # Post-process args.
    if (
        hasattr(args, "batch_size_valid") and args.batch_size_valid is None
    ) or not hasattr(args, "batch_size_valid"):
        args.batch_size_valid = args.batch_size
    if hasattr(args, "max_tokens_valid") and args.max_tokens_valid is None:
        args.max_tokens_valid = args.max_tokens
    if getattr(args, "memory_efficient_fp16", False):
        args.fp16 = True
    if getattr(args, "memory_efficient_bf16", False):
        args.bf16 = True
    args.tpu = getattr(args, "tpu", False)
    args.bf16 = getattr(args, "bf16", False)
    if args.bf16:
        args.tpu = True
    if args.tpu and args.fp16:
        raise ValueError("Cannot combine --fp16 and --tpu, use --bf16 on TPUs")

    if getattr(args, "seed", None) is None:
        args.seed = 1  # default seed for training
        args.no_seed_provided = True
    else:
        args.no_seed_provided = False

    # Apply architecture configuration.
    if hasattr(args, "arch") and args.arch in ARCH_CONFIG_REGISTRY:
        ARCH_CONFIG_REGISTRY[args.arch](args)

    if parse_known:
        return args, extra
    else:
        return args
def main(args):
    import_user_module(args)

    assert (
        args.max_tokens is not None or args.batch_size is not None
    ), "Must specify batch size either with --max-tokens or --batch-size"

    metrics.reset()

    np.random.seed(args.seed)
    utils.set_torch_seed(args.seed)

    if distributed_utils.is_master(args):
        checkpoint_utils.verify_checkpoint_directory(args.save_dir)

    # Print args
    logger.info(args)

    # Setup task, e.g., translation, language modeling, etc.
    task = tasks.setup_task(args)

    # Load valid dataset (we load training data below, based on the latest checkpoint)
    for valid_sub_split in args.valid_subset.split(","):
        task.load_dataset(valid_sub_split, combine=False, epoch=1)

    # Build model and criterion
    model = task.build_model(args)
    criterion = task.build_criterion(args)
    logger.info(model)
    logger.info("task: {} ({})".format(args.task, task.__class__.__name__))
    logger.info("model: {} ({})".format(args.arch, model.__class__.__name__))
    logger.info("criterion: {} ({})".format(args.criterion, criterion.__class__.__name__))
    logger.info("num. model params: {} (num. trained: {})".format(
        sum(p.numel() for p in model.parameters()),
        sum(p.numel() for p in model.parameters() if p.requires_grad),
    ))

    # ========== initialize the model with pretrained BART parameters ==========
    # for shared embeddings and subtoken split for amr nodes
    if 'bartsv' in args.arch:

        if args.initialize_with_bart:
            logger.info('-' * 10 + ' initializing model parameters with pretrained BART model ' + '-' * 10)
            new_state_dict = copy.deepcopy(task.bart.model.state_dict())
            # treat the embedding initialization separately later, as the sizes are different
            logger.info('-' * 10
                        + ' delay encoder embeddings, decoder input and output embeddings initialization '
                        + '-' * 10)
            ignore_keys = set([
                'encoder.embed_tokens.weight',
                'decoder.embed_tokens.weight',
                'decoder.output_projection.weight'
            ])
            for k in ignore_keys:
                del new_state_dict[k]

            if not args.initialize_with_bart_enc:
                logger.info('-' * 10 + ' do not initialize with BART encoder parameters ' + '-' * 10)
                for k in list(new_state_dict.keys()):
                    if k.startswith('encoder'):
                        del new_state_dict[k]

            if not args.initialize_with_bart_dec:
                logger.info('-' * 10 + ' do not initialize with BART decoder parameters ' + '-' * 10)
                for k in list(new_state_dict.keys()):
                    if k.startswith('decoder'):
                        del new_state_dict[k]

            model.load_state_dict(new_state_dict, strict=False, args=args)

            # initialize the BART part of the embeddings
            bart_vocab_size = task.target_dictionary.bart_vocab_size
            # NOTE we need to prune the pretrained BART embeddings, especially for bart.base
            bart_embed_weight = task.bart.model.encoder.embed_tokens.weight.data[:bart_vocab_size]
            assert len(bart_embed_weight) == bart_vocab_size

            with torch.no_grad():
                model.encoder.embed_tokens.weight[:bart_vocab_size].copy_(bart_embed_weight)
                model.decoder.embed_tokens.weight[:bart_vocab_size].copy_(bart_embed_weight)
                model.decoder.output_projection.weight[:bart_vocab_size].copy_(bart_embed_weight)

        if args.bart_emb_init_composition:
            logger.info('-' * 10
                        + ' initialize extended target embeddings with compositional embeddings '
                          'from BART vocabulary ' + '-' * 10)
            symbols = [
                task.target_dictionary[idx]
                for idx in range(bart_vocab_size, len(task.target_dictionary))
            ]
            mapper = MapAvgEmbeddingBART(task.bart, task.bart.model.decoder.embed_tokens)
            comp_embed_weight, map_all = mapper.map_avg_embeddings(
                symbols, transform=transform_action_symbol, add_noise=False)
            assert len(comp_embed_weight) == len(symbols)

            with torch.no_grad():
                model.encoder.embed_tokens.weight[bart_vocab_size:].copy_(comp_embed_weight)
                model.decoder.embed_tokens.weight[bart_vocab_size:].copy_(comp_embed_weight)
                model.decoder.output_projection.weight[bart_vocab_size:].copy_(comp_embed_weight)

    elif 'bart' in args.arch:

        if args.initialize_with_bart:
            logger.info('-' * 10 + ' initializing model parameters with pretrained BART model ' + '-' * 10)
            new_state_dict = copy.deepcopy(task.bart.model.state_dict())
            if not args.bart_emb_decoder:
                logger.info('-' * 10 + ' build a separate decoder dictionary embedding ' + '-' * 10)
                if not args.bart_emb_decoder_input:
                    ignore_keys = set([
                        'decoder.embed_tokens.weight',
                        'decoder.output_projection.weight'
                    ])
                else:
                    logger.info('-' * 10 + ' use BART dictionary embedding for target input ' + '-' * 10)
                    ignore_keys = set(['decoder.output_projection.weight'])
                for k in ignore_keys:
                    del new_state_dict[k]

            if not args.initialize_with_bart_enc:
                logger.info('-' * 10 + ' do not initialize with BART encoder parameters ' + '-' * 10)
                for k in list(new_state_dict.keys()):
                    if k.startswith('encoder'):
                        del new_state_dict[k]

            if not args.initialize_with_bart_dec:
                logger.info('-' * 10 + ' do not initialize with BART decoder parameters ' + '-' * 10)
                for k in list(new_state_dict.keys()):
                    if k.startswith('decoder'):
                        del new_state_dict[k]

            model.load_state_dict(new_state_dict, strict=False, args=args)

        # initialize the target embeddings with the average of subtoken embeddings in the BART vocabulary
        if args.bart_emb_init_composition:
            assert not args.bart_emb_decoder, \
                'should not use the compositional embeddings on top of BART vocabulary here'
            logger.info('-' * 10
                        + ' initialize target embeddings with compositional embeddings from BART vocabulary '
                        + '-' * 10)
            composite_embed = CompositeEmbeddingBART(
                task.bart,
                task.bart.model.decoder.embed_tokens,
                task.target_dictionary)

            if args.bart_emb_decoder_input:
                # only initialize the decoder output embeddings
                with torch.no_grad():
                    model.decoder.output_projection.weight.copy_(composite_embed.embedding_weight)
            else:
                # initialize both the decoder input and output embeddings
                with torch.no_grad():
                    model.decoder.embed_tokens.weight.copy_(composite_embed.embedding_weight)
                    model.decoder.output_projection.weight.copy_(composite_embed.embedding_weight)

    elif 'roberta' in args.arch:

        # initialize the target embeddings with the average of subtoken embeddings in the RoBERTa vocabulary
        if args.bart_emb_init_composition:
            assert not args.bart_emb_decoder, \
                'should not use the compositional embeddings on top of RoBERTa vocabulary here'
            logger.info('-' * 10
                        + ' initialize target embeddings with compositional embeddings from RoBERTa vocabulary '
                        + '-' * 10)
            composite_embed = CompositeEmbeddingBART(
                task.bart,    # NOTE here "bart" means roberta
                task.bart.model.encoder.sentence_encoder.embed_tokens,
                task.target_dictionary)

            if args.bart_emb_decoder_input:
                # only initialize the decoder output embeddings
                with torch.no_grad():
                    model.decoder.output_projection.weight.copy_(composite_embed.embedding_weight)
            else:
                # initialize both the decoder input and output embeddings
                with torch.no_grad():
                    model.decoder.embed_tokens.weight.copy_(composite_embed.embedding_weight)
                    model.decoder.output_projection.weight.copy_(composite_embed.embedding_weight)

    else:
        raise ValueError
    # ==========================================================================

    # (optionally) Configure quantization
    if args.quantization_config_path is not None:
        quantizer = quantization_utils.Quantizer(
            config_path=args.quantization_config_path,
            max_epoch=args.max_epoch,
            max_update=args.max_update,
        )
    else:
        quantizer = None

    # Build trainer
    if args.model_parallel_size == 1:
        trainer = Trainer(args, task, model, criterion, quantizer)
    else:
        trainer = MegatronTrainer(args, task, model, criterion)

    logger.info("training on {} devices (GPUs/TPUs)".format(args.distributed_world_size))
    logger.info("max tokens per GPU = {} and max sentences per GPU = {}".format(
        args.max_tokens, args.batch_size))

    # Load the latest checkpoint if one is available and restore the
    # corresponding train iterator
    extra_state, epoch_itr = checkpoint_utils.load_checkpoint(
        args,
        trainer,
        # don't cache epoch iterators for sharded datasets
        disable_iterator_cache=task.has_sharded_data("train"),
    )

    # Train until the learning rate gets too small
    max_epoch = args.max_epoch or math.inf
    lr = trainer.get_lr()
    train_meter = meters.StopwatchMeter()
    train_meter.start()

    while lr > args.min_lr and epoch_itr.next_epoch_idx <= max_epoch:
        # train for one epoch
        valid_losses, should_stop = train(args, trainer, task, epoch_itr)
        if should_stop:
            break

        # only use first validation loss to update the learning rate
        lr = trainer.lr_step(epoch_itr.epoch, valid_losses[0])

        epoch_itr = trainer.get_train_iterator(
            epoch_itr.next_epoch_idx,
            # sharded data: get train iterator for next epoch
            load_dataset=task.has_sharded_data("train"),
            # don't cache epoch iterators for sharded datasets
            disable_iterator_cache=task.has_sharded_data("train"),
        )
    train_meter.stop()
    logger.info("done training in {:.1f} seconds".format(train_meter.sum))
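
# Self-contained sketch (illustrative, plain torch only) of the embedding-initialization
# pattern used above: the first `pretrained_vocab_size` rows of an extended embedding matrix
# are overwritten with pretrained weights, while the remaining rows (new target symbols) are
# left to their random or compositional initialization.
import torch

def _demo_copy_pretrained_rows(extended_embed: torch.nn.Embedding,
                               pretrained_weight: torch.Tensor) -> None:
    pretrained_vocab_size = pretrained_weight.size(0)
    # the extended table must be at least as large and have the same embedding dimension
    assert extended_embed.weight.size(0) >= pretrained_vocab_size
    assert extended_embed.weight.size(1) == pretrained_weight.size(1)
    with torch.no_grad():
        extended_embed.weight[:pretrained_vocab_size].copy_(pretrained_weight)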
def main(args):
    import_user_module(args)
    print(args)

    # control which preprocessing steps need to be run (they take both time and storage, so we avoid rerunning them)
    run_basic = True
    # this includes:
    # src: build src dictionary, copy the raw data to dir; build src binary data (need to refactor later if unneeded)
    # tgt: split target non-pointer actions and pointer values into separate files; build tgt dictionary
    run_act_states = True
    # this includes:
    # run the state machine reformer to get
    # a) training data: input and output, pointer values;
    # b) states information to facilitate modeling;
    # takes about 1 hour and 13G space on CCC
    run_roberta_emb = True
    # this includes:
    # for src sentences, use a pre-trained RoBERTa model to extract contextual embeddings for each word;
    # takes about 10 min for RoBERTa base, 30 min for RoBERTa large, and 2-3G space;
    # this needs a GPU and only needs to run once for the English sentences, which do not change across oracles;
    # thus the embeddings are stored separately from the oracles.

    if os.path.exists(args.destdir):
        print(f'binarized actions and states directory {args.destdir} already exists; not rerunning.')
        run_basic = False
        run_act_states = False
    if os.path.exists(args.embdir):
        print(f'pre-trained embedding directory {args.embdir} already exists; not rerunning.')
        run_roberta_emb = False

    os.makedirs(args.destdir, exist_ok=True)
    os.makedirs(args.embdir, exist_ok=True)

    target = not args.only_source

    task = tasks.get_task(args.task)

    # preprocess target actions files, splitting '.actions' into '.actions_nopos' and '.actions_pos'
    # when building the dictionary on the target actions sequences:
    # the action file is split into one file without arc pointers and one with only arc pointer values,
    # and the dictionary is only built on the no-pointer actions
    if run_basic:
        assert args.target_lang == 'actions', 'target extension must be "actions"'
        actions_files = [f'{pref}.{args.target_lang}'
                         for pref in (args.trainpref, args.validpref, args.testpref)]
        task.split_actions_pointer_files(actions_files)

    args.target_lang_nopos = 'actions_nopos'    # only build dictionary without pointer values
    args.target_lang_pos = 'actions_pos'

    # set tokenizer
    tokenize = task.tokenize if hasattr(task, 'tokenize') else tokenize_line

    def train_path(lang):
        return "{}{}".format(args.trainpref, ("." + lang) if lang else "")

    def file_name(prefix, lang):
        fname = prefix
        if lang is not None:
            fname += ".{lang}".format(lang=lang)
        return fname

    def dest_path(prefix, lang):
        return os.path.join(args.destdir, file_name(prefix, lang))

    def dict_path(lang):
        return dest_path("dict", lang) + ".txt"

    def build_dictionary(filenames, src=False, tgt=False):
        assert src ^ tgt
        return task.build_dictionary(
            filenames,
            workers=args.workers,
            threshold=args.thresholdsrc if src else args.thresholdtgt,
            nwords=args.nwordssrc if src else args.nwordstgt,
            padding_factor=args.padding_factor,
            # the tokenize separator is taken care of inside the task
        )

    # build dictionary and save
    if run_basic:
        if not args.srcdict and os.path.exists(dict_path(args.source_lang)):
            raise FileExistsError(dict_path(args.source_lang))
        if target and not args.tgtdict and os.path.exists(dict_path(args.target_lang)):
            raise FileExistsError(dict_path(args.target_lang))

        if args.joined_dictionary:
            assert not args.srcdict or not args.tgtdict, \
                "cannot use both --srcdict and --tgtdict with --joined-dictionary"
            if args.srcdict:
                src_dict = task.load_dictionary(args.srcdict)
            elif args.tgtdict:
                src_dict = task.load_dictionary(args.tgtdict)
            else:
                assert args.trainpref, "--trainpref must be set if --srcdict is not specified"
                src_dict = build_dictionary(
                    {train_path(lang) for lang in [args.source_lang, args.target_lang]}, src=True
                )
            tgt_dict = src_dict
        else:
            if args.srcdict:
                src_dict = task.load_dictionary(args.srcdict)
            else:
                assert args.trainpref, "--trainpref must be set if --srcdict is not specified"
                src_dict = build_dictionary([train_path(args.source_lang)], src=True)

            if target:
                if args.tgtdict:
                    tgt_dict = task.load_dictionary(args.tgtdict)
                else:
                    assert args.trainpref, "--trainpref must be set if --tgtdict is not specified"
                    tgt_dict = build_dictionary([train_path(args.target_lang_nopos)], tgt=True)
            else:
                tgt_dict = None

        src_dict.save(dict_path(args.source_lang))
        if target and tgt_dict is not None:
            tgt_dict.save(dict_path(args.target_lang_nopos))

    # save binarized preprocessed files
    def make_binary_dataset(vocab, input_prefix, output_prefix, lang, num_workers):
        print("| [{}] Dictionary: {} types".format(lang, len(vocab) - 1))
        n_seq_tok = [0, 0]
        replaced = Counter()

        def merge_result(worker_result):
            replaced.update(worker_result["replaced"])
            n_seq_tok[0] += worker_result["nseq"]
            n_seq_tok[1] += worker_result["ntok"]

        input_file = "{}{}".format(input_prefix, ("." + lang) if lang is not None else "")
        offsets = Binarizer.find_offsets(input_file, num_workers)
        pool = None
        if num_workers > 1:
            pool = Pool(processes=num_workers - 1)
            for worker_id in range(1, num_workers):
                prefix = "{}{}".format(output_prefix, worker_id)
                pool.apply_async(
                    binarize,
                    (
                        args,
                        input_file,
                        vocab,
                        prefix,
                        lang,
                        offsets[worker_id],
                        offsets[worker_id + 1],
                        False,    # note here we shut off append eos
                        tokenize,
                    ),
                    callback=merge_result
                )
            pool.close()

        ds = indexed_dataset.make_builder(
            dataset_dest_file(args, output_prefix, lang, "bin"),
            impl=args.dataset_impl,
            vocab_size=len(vocab),
            dtype=np.int64)
        merge_result(
            Binarizer.binarize(
                input_file,
                vocab,
                lambda t: ds.add_item(t),
                offset=0,
                end=offsets[1],
                append_eos=False,
                tokenize=tokenize
            )
        )
        if num_workers > 1:
            pool.join()
            for worker_id in range(1, num_workers):
                prefix = "{}{}".format(output_prefix, worker_id)
                temp_file_path = dataset_dest_prefix(args, prefix, lang)
                ds.merge_file_(temp_file_path)
                os.remove(indexed_dataset.data_file_path(temp_file_path))
                os.remove(indexed_dataset.index_file_path(temp_file_path))

        ds.finalize(dataset_dest_file(args, output_prefix, lang, "idx"))

        print(
            "| [{}] {}: {} sents, {} tokens, {:.3}% replaced by {}".format(
                lang,
                input_file,
                n_seq_tok[0],
                n_seq_tok[1],
                100 * sum(replaced.values()) / n_seq_tok[1],
                vocab.unk_word,
            )
        )

    def make_dataset(vocab, input_prefix, output_prefix, lang, num_workers=1, dataset_impl=args.dataset_impl):
        if dataset_impl == "raw":
            # Copy original text file to destination folder
            output_text_file = dest_path(
                output_prefix + ".{}-{}".format(args.source_lang, args.target_lang),
                lang,
            )
            shutil.copyfile(file_name(input_prefix, lang), output_text_file)
        else:
            make_binary_dataset(vocab, input_prefix, output_prefix, lang, num_workers)

    def make_all(lang, vocab, dataset_impl=args.dataset_impl):
        if args.trainpref:
            make_dataset(vocab, args.trainpref, "train", lang,
                         num_workers=args.workers, dataset_impl=dataset_impl)
        if args.validpref:
            for k, validpref in enumerate(args.validpref.split(",")):
                outprefix = "valid{}".format(k) if k > 0 else "valid"
                make_dataset(vocab, validpref, outprefix, lang,
                             num_workers=args.workers, dataset_impl=dataset_impl)
        if args.testpref:
            for k, testpref in enumerate(args.testpref.split(",")):
                outprefix = "test{}".format(k) if k > 0 else "test"
                make_dataset(vocab, testpref, outprefix, lang,
                             num_workers=args.workers, dataset_impl=dataset_impl)

    # NOTE we do not encode the source sentences with the dictionary, as the source embeddings are directly
    #      provided by RoBERTa; the source dictionary here is thus of no use
    if run_basic:
        make_all(args.source_lang, src_dict, dataset_impl='raw')
        make_all(args.source_lang, src_dict, dataset_impl='mmap')
        # above: just left so the model runs without too much change

        # NOTE there are <unk> in valid and test set for target actions
        # if target:
        #     make_all(args.target_lang_nopos, tgt_dict)

    # NOTE targets (input, output, pointer values) are now all included in the state generation process

    # binarize pointer values and save to file
    # TODO make naming convention clearer
    # assume one training file, one validation file, and one test file
    # for pos_file, split in [(f'{pref}.actions_pos', split) for pref, split in
    #                         [(args.trainpref, 'train'), (args.validpref, 'valid'), (args.testpref, 'test')]]:
    #     out_pref = os.path.join(args.destdir, split)
    #     task.binarize_actions_pointer_file(pos_file, out_pref)

    # save action states information to assist training with auxiliary info
    # assume one training file, one validation file, and one test file
    if run_act_states:
        task_obj = task(args, tgt_dict=tgt_dict)
        for prefix, split in zip([args.trainpref, args.validpref, args.testpref], ['train', 'valid', 'test']):
            en_file = prefix + '.en'
            actions_file = prefix + '.actions'
            out_file_pref = os.path.join(args.destdir, split)
            task_obj.build_actions_states_info(en_file, actions_file, out_file_pref, num_workers=args.workers)

    # save RoBERTa embeddings
    # TODO refactor this code
    if run_roberta_emb:
        make_roberta_embeddings(args, tokenize=tokenize)

    print("| Wrote preprocessed oracle data to {}".format(args.destdir))
    print("| Wrote preprocessed embedding data to {}".format(args.embdir))
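
# Hedged sketch (assumption, not shown in the original): a typical entry point wiring fairseq's
# preprocessing options to the `main` above. The stock `options.get_preprocessing_parser()`
# provides flags such as --trainpref/--validpref/--testpref/--destdir/--workers used in `main`;
# repo-specific flags like --embdir would have to be added separately, so this is illustrative
# only and not the repo's actual CLI.
def _demo_preprocess_cli_main():
    from fairseq import options
    parser = options.get_preprocessing_parser()
    parser.add_argument('--embdir', default='emb-bin',
                        help='directory for pretrained embeddings (repo-specific, hypothetical default)')
    args = parser.parse_args()
    main(args)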
def get_parser(desc, default_task='translation'):
    # Before creating the true parser, we need to import optional user module
    # in order to eagerly import custom tasks, optimizers, architectures, etc.
    usr_parser = argparse.ArgumentParser(add_help=False, allow_abbrev=False)
    usr_parser.add_argument('--user-dir', default=None)
    usr_args, _ = usr_parser.parse_known_args()
    import_user_module(usr_args)

    parser = argparse.ArgumentParser(allow_abbrev=False)
    # fmt: off
    parser.add_argument('--no-progress-bar', action='store_true', help='disable progress bar')
    parser.add_argument('--log-interval', type=int, default=1000, metavar='N',
                        help='log progress every N batches (when progress bar is disabled)')
    parser.add_argument('--log-format', default=None, help='log format to use',
                        choices=['json', 'none', 'simple', 'tqdm'])
    parser.add_argument('--tensorboard-logdir', metavar='DIR', default='',
                        help='path to save logs for tensorboard, should match --logdir '
                             'of running tensorboard (default: no tensorboard logging)')
    parser.add_argument('--tbmf-wrapper', action='store_true', help='[FB only]')
    parser.add_argument('--seed', default=1, type=int, metavar='N',
                        help='pseudo random number generator seed')
    parser.add_argument('--cpu', action='store_true', help='use CPU instead of CUDA')
    parser.add_argument('--fp16', action='store_true', help='use FP16')
    parser.add_argument('--memory-efficient-fp16', action='store_true',
                        help='use a memory-efficient version of FP16 training; implies --fp16')
    parser.add_argument('--fp16-init-scale', default=2 ** 7, type=int,
                        help='default FP16 loss scale')
    parser.add_argument('--fp16-scale-window', type=int,
                        help='number of updates before increasing loss scale')
    parser.add_argument('--fp16-scale-tolerance', default=0.0, type=float,
                        help='pct of updates that can overflow before decreasing the loss scale')
    parser.add_argument('--min-loss-scale', default=1e-4, type=float, metavar='D',
                        help='minimum FP16 loss scale, after which training is stopped')
    parser.add_argument('--threshold-loss-scale', type=float,
                        help='threshold FP16 loss scale from below')
    parser.add_argument('--user-dir', default=None,
                        help='path to a python module containing custom extensions (tasks and/or architectures)')
    parser.add_argument('--profile', default=False, action='store_true',
                        help='enable autograd profiler emit_nvtx')
    parser.add_argument('--model-parallel-size', default=1, type=int,
                        help='total number of GPUs to parallelize model over')
    parser.add_argument('--quantization-config-path', default=None,
                        help='path to quantization config file')
    parser.add_argument('--bf16', default=False, action='store_true',
                        help='use bfloat16; implies --tpu')

    from fairseq.registry import REGISTRIES
    for registry_name, REGISTRY in REGISTRIES.items():
        parser.add_argument(
            '--' + registry_name.replace('_', '-'),
            default=REGISTRY['default'],
            choices=REGISTRY['registry'].keys(),
        )

    # Task definitions can be found under fairseq/tasks/
    from fairseq.tasks import TASK_REGISTRY
    parser.add_argument('--task', metavar='TASK', default=default_task,
                        choices=TASK_REGISTRY.keys(),
                        help='task')
    # fmt: on
    return parser