def load_state_dict(
    self,
    state_dict,
    strict=True,
    model_cfg: Optional[DictConfig] = None,
    args: Optional[Namespace] = None,
):
    """Copies parameters and buffers from *state_dict* into this module and
    its descendants.

    Overrides the method in :class:`nn.Module`. Compared with that method
    this additionally "upgrades" *state_dicts* from old checkpoints.
    """
    if model_cfg is None and args is not None:
        logger.warning(
            "using 'args' is deprecated, please update your code to use dataclass config"
        )
        model_cfg = convert_namespace_to_omegaconf(args).model

    self.upgrade_state_dict(state_dict)
    new_state_dict = prune_state_dict(state_dict, model_cfg)
    return super().load_state_dict(new_state_dict, strict)
def load_model_ensemble_and_task(
    filenames, arg_overrides=None, task=None, strict=True, suffix="", num_shards=1
):
    from fairseq import tasks

    assert not (
        strict and num_shards > 1
    ), "Cannot load state dict with strict=True and checkpoint shards > 1"
    ensemble = []
    for filename in filenames:
        orig_filename = filename
        for shard_idx in range(num_shards):
            if num_shards == 1:
                filename = filename.replace(".pt", suffix + ".pt")
            else:
                filename = orig_filename[:-3] + f"_part{shard_idx}.pt"
            if not PathManager.exists(filename):
                raise IOError("Model file not found: {}".format(filename))
            state = load_checkpoint_to_cpu(filename, arg_overrides)
            if "args" in state and state["args"] is not None:
                cfg = convert_namespace_to_omegaconf(state["args"])
            elif "cfg" in state and state["cfg"] is not None:
                cfg = state["cfg"]
            else:
                raise RuntimeError(
                    f"Neither args nor cfg exist in state keys = {state.keys()}"
                )

            if task is None:
                task = tasks.setup_task(cfg.task)

            # build model for ensemble
            model = task.build_model(cfg.model)
            model.load_state_dict(state["model"], strict=strict, model_cfg=cfg.model)
        ensemble.append(model)
    return ensemble, cfg, task
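# A minimal usage sketch (not part of the original source): loading a single,
# unsharded checkpoint and reusing the task that was stored alongside it. The
# checkpoint and data paths below are placeholders.
ensemble, cfg, task = load_model_ensemble_and_task(
    ["/path/to/checkpoint_best.pt"],
    arg_overrides={"data": "/path/to/data-bin"},
)
model = ensemble[0]
model.eval()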
def setUp(self, cfg):
    if isinstance(cfg, Namespace):
        cfg = convert_namespace_to_omegaconf(cfg)

    self.task = tasks.setup_task(cfg.task)
    self.tgt_dict = self.task.target_dictionary

    # Load ensemble
    logger.info("loading model(s) from {}".format(cfg.common_eval.path))
    models, _ = checkpoint_utils.load_model_ensemble(
        utils.split_paths(cfg.common_eval.path),
        arg_overrides={},
        task=self.task,
        suffix=cfg.checkpoint.checkpoint_suffix,
        strict=False,
        num_shards=cfg.checkpoint.checkpoint_shard_count,
    )
    if len(models) > 1:
        raise Exception("Currently loading multiple models is not supported")
    self.model = models[0]

    # Optimize model for generation
    if cfg.common.fp16:
        self.model.half()
    if self.use_cuda:
        self.model.cuda()
    self.model.prepare_for_inference_(cfg)

    self.generator = self.task.build_generator(
        [self.model],
        cfg.generation,
        extra_gen_cls_kwargs={},
    )

    # Handle tokenization and BPE
    self.tokenizer = self.task.build_tokenizer(cfg.tokenizer)
    self.bpe = self.task.build_bpe(cfg.bpe)
    self.remove_bpe = cfg.common_eval.post_process
def __init__(self, n_class=30, encoder_hidden_dim=768, w2v_sd=None):
    super(KWS, self).__init__()
    self.n_class = n_class

    cfg = convert_namespace_to_omegaconf(w2v_sd['args'])
    task = tasks.setup_task(cfg.task)
    state_dict = w2v_sd['model']
    assert cfg is not None
    assert state_dict is not None

    self.w2v_encoder = task.build_model(cfg.model)
    self.w2v_encoder.load_state_dict(state_dict)

    out_channels = 112
    self.decoder = nn.Sequential(
        nn.Conv1d(encoder_hidden_dim, out_channels, 25, dilation=2),
        nn.BatchNorm1d(out_channels),
        nn.ReLU(),
        nn.Conv1d(out_channels, out_channels, 1),
        nn.BatchNorm1d(out_channels),
        nn.ReLU(),
        nn.Conv1d(out_channels, self.n_class, 1),
    )
    self.softmax = nn.Softmax(dim=-1)
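# Hedged usage sketch (not part of the original source): w2v_sd is expected to
# be a wav2vec 2.0 checkpoint dictionary containing "args" and "model" entries;
# the checkpoint path is a placeholder.
import torch

w2v_sd = torch.load("/path/to/wav2vec_checkpoint.pt", map_location="cpu")
kws_model = KWS(n_class=30, encoder_hidden_dim=768, w2v_sd=w2v_sd)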
def main(cfg: DictConfig):
    if isinstance(cfg, Namespace):
        cfg = convert_namespace_to_omegaconf(cfg)

    assert cfg.common_eval.path is not None, "--path required for generation!"
    assert (
        not cfg.generation.sampling or cfg.generation.nbest == cfg.generation.beam
    ), "--sampling requires --nbest to be equal to --beam"
    assert (
        cfg.generation.replace_unk is None or cfg.dataset.dataset_impl == "raw"
    ), "--replace-unk requires a raw text dataset (--dataset-impl=raw)"

    if cfg.common_eval.results_path is not None:
        os.makedirs(cfg.common_eval.results_path, exist_ok=True)
        output_path = os.path.join(
            cfg.common_eval.results_path,
            "generate-{}.txt".format(cfg.dataset.gen_subset),
        )
        with open(output_path, "w", buffering=1, encoding="utf-8") as h:
            return _main(cfg, h)
    else:
        return _main(cfg, sys.stdout)
def distributed_init(cfg: FairseqConfig): if isinstance(cfg, Namespace): from fairseq.dataclass.utils import convert_namespace_to_omegaconf cfg = convert_namespace_to_omegaconf(cfg) if not cfg.common.tpu: if torch.distributed.is_available( ) and torch.distributed.is_initialized(): warnings.warn( "Distributed is already initialized, cannot initialize twice!") else: logger.info("distributed init (rank {}): {}".format( cfg.distributed_training.distributed_rank, cfg.distributed_training.distributed_init_method, )) dist.init_process_group( backend=cfg.distributed_training.distributed_backend, init_method=cfg.distributed_training.distributed_init_method, world_size=cfg.distributed_training.distributed_world_size, rank=cfg.distributed_training.distributed_rank, ) logger.info("initialized host {} as rank {}".format( socket.gethostname(), cfg.distributed_training.distributed_rank, )) # perform a dummy all-reduce to initialize the NCCL communicator if torch.cuda.is_available(): dist.all_reduce(torch.zeros(1).cuda()) cfg.distributed_training.distributed_rank = torch.distributed.get_rank( ) else: assert xm.xrt_world_size( ) == cfg.distributed_training.distributed_world_size global _USE_XLA _USE_XLA = True cfg.distributed_training.device_id = xm.get_local_ordinal() cfg.distributed_training.distributed_rank = xm.get_ordinal() xm.rendezvous("distributed_init") # wait for all workers xm.mark_step() if is_master(cfg.distributed_training): logging.getLogger().setLevel(logging.INFO) else: logging.getLogger().setLevel(logging.WARNING) if cfg.common.model_parallel_size > 1: try: from fairseq.model_parallel.megatron.mpu import ( initialize_model_parallel, model_parallel_cuda_manual_seed, ) except ImportError: raise ImportError("\n\nPlease install the megatron submodule:" "\n\n git submodule update --init " "fairseq/model_parallel/megatron") global _USE_MEGATRON _USE_MEGATRON = True initialize_model_parallel(cfg.common.model_parallel_size) model_parallel_cuda_manual_seed(cfg.common.seed) model_part_number = get_model_parallel_rank() cfg.checkpoint.checkpoint_suffix += "-model_part-{0}".format( model_part_number) return cfg.distributed_training.distributed_rank
def main(cfg: DictConfig, override_args=None): if isinstance(cfg, Namespace): cfg = convert_namespace_to_omegaconf(cfg) utils.import_user_module(cfg.common) assert ( cfg.dataset.max_tokens is not None or cfg.dataset.batch_size is not None ), "Must specify batch size either with --max-tokens or --batch-size" use_fp16 = cfg.common.fp16 use_cuda = torch.cuda.is_available() and not cfg.common.cpu if use_cuda: torch.cuda.set_device(cfg.distributed_training.device_id) if cfg.distributed_training.distributed_world_size > 1: data_parallel_world_size = distributed_utils.get_data_parallel_world_size( ) data_parallel_rank = distributed_utils.get_data_parallel_rank() else: data_parallel_world_size = 1 data_parallel_rank = 0 if override_args is not None: overrides = vars(override_args) overrides.update(eval(getattr(override_args, "model_overrides", "{}"))) else: overrides = None # Load ensemble logger.info("loading model(s) from {}".format(cfg.common_eval.path)) models, saved_cfg, task = checkpoint_utils.load_model_ensemble_and_task( [cfg.common_eval.path], arg_overrides=overrides, suffix=cfg.checkpoint.checkpoint_suffix, ) model = models[0] # Move models to GPU for model in models: if use_fp16: model.half() if use_cuda: model.cuda() # Print args logger.info(saved_cfg) # Build criterion criterion = task.build_criterion(saved_cfg.criterion) criterion.eval() for subset in cfg.dataset.valid_subset.split(","): try: task.load_dataset(subset, combine=False, epoch=1, task_cfg=saved_cfg.task) dataset = task.dataset(subset) except KeyError: raise Exception("Cannot find dataset: " + subset) # Initialize data iterator itr = task.get_batch_iterator( dataset=dataset, max_tokens=cfg.dataset.max_tokens, max_sentences=cfg.dataset.batch_size, max_positions=utils.resolve_max_positions( task.max_positions(), *[m.max_positions() for m in models], ), ignore_invalid_inputs=cfg.dataset. skip_invalid_size_inputs_valid_test, required_batch_size_multiple=cfg.dataset. required_batch_size_multiple, seed=cfg.common.seed, num_shards=data_parallel_world_size, shard_id=data_parallel_rank, num_workers=cfg.dataset.num_workers, data_buffer_size=cfg.dataset.data_buffer_size, ).next_epoch_itr(shuffle=False) progress = progress_bar.progress_bar( itr, log_format=cfg.common.log_format, log_interval=cfg.common.log_interval, prefix=f"valid on '{subset}' subset", default_log_format=("tqdm" if not cfg.common.no_progress_bar else "simple"), ) log_outputs = [] for i, sample in enumerate(progress): sample = utils.move_to_cuda(sample) if use_cuda else sample _loss, _sample_size, log_output = task.valid_step( sample, model, criterion) progress.log(log_output, step=i) log_outputs.append(log_output) if data_parallel_world_size > 1: log_outputs = distributed_utils.all_gather_list( log_outputs, max_size=cfg.common.all_gather_list_size, group=distributed_utils.get_data_parallel_group(), ) log_outputs = list(chain.from_iterable(log_outputs)) with metrics.aggregate() as agg: task.reduce_metrics(log_outputs, criterion) log_output = agg.get_smoothed_values() progress.print(log_output, tag=subset, step=i)
def main(cfg: DictConfig, override_args=None, **unused_kwargs): if isinstance(cfg, Namespace): cfg = convert_namespace_to_omegaconf(cfg) utils.import_user_module(cfg.common) use_fp16 = cfg.common.fp16 use_cuda = torch.cuda.is_available() and not cfg.common.cpu if use_cuda: torch.cuda.set_device(cfg.distributed_training.device_id) if override_args is not None: overrides = vars(override_args) overrides.update(eval(getattr(override_args, "model_overrides", "{}"))) else: overrides = None logger.info(cfg) # Load ensemble logger.info("loading model(s) from {}".format(cfg.common_eval.path)) # reduce tokens per sample by the required context window size cfg.task.tokens_per_sample -= cfg.eval_lm.context_window models, model_args, task = checkpoint_utils.load_model_ensemble_and_task( [cfg.common_eval.path], arg_overrides=overrides, suffix=cfg.checkpoint.checkpoint_suffix, strict=(cfg.checkpoint.checkpoint_shard_count == 1), num_shards=cfg.checkpoint.checkpoint_shard_count, ) # Load dataset splits gen_subset = cfg.dataset.gen_subset task.load_dataset(gen_subset) dataset = task.dataset(gen_subset) if cfg.eval_lm.context_window > 0: dataset = LMContextWindowDataset( dataset=dataset, tokens_per_sample=cfg.task.tokens_per_sample, context_window=cfg.eval_lm.context_window, pad_idx=task.source_dictionary.pad(), ) logger.info("{} {} {} examples".format(cfg.task.data, gen_subset, len(dataset))) # Optimize ensemble for generation and set the source and dest dicts on the model (required by scorer) for model in models: if use_fp16: model.half() if use_cuda and not cfg.distributed_training.pipeline_model_parallel: model.cuda() model.prepare_for_inference_(cfg) assert len(models) > 0 logger.info( "num. model params: {}".format(sum(p.numel() for p in models[0].parameters())) ) itr = task.get_batch_iterator( dataset=dataset, max_tokens=cfg.dataset.max_tokens or 36000, max_sentences=cfg.dataset.batch_size, max_positions=utils.resolve_max_positions( *[model.max_positions() for model in models] ), ignore_invalid_inputs=True, num_shards=max( cfg.dataset.num_shards, cfg.distributed_training.distributed_world_size, ), shard_id=max( cfg.dataset.shard_id, cfg.distributed_training.distributed_rank, ), num_workers=cfg.dataset.num_workers, data_buffer_size=cfg.dataset.data_buffer_size, ).next_epoch_itr(shuffle=False) progress = progress_bar.progress_bar( itr, log_format=cfg.common.log_format, log_interval=cfg.common.log_interval, default_log_format=("tqdm" if not cfg.common.no_progress_bar else "simple"), ) gen_timer = StopwatchMeter() scorer = SequenceScorer(task.target_dictionary, cfg.eval_lm.softmax_batch) score_sum = 0.0 count = 0 if cfg.common_eval.remove_bpe is not None: if cfg.common_eval.remove_bpe == "sentencepiece": raise NotImplementedError else: bpe_cont = cfg.common_eval.remove_bpe.rstrip() bpe_toks = { i for i in range(len(task.source_dictionary)) if task.source_dictionary[i].endswith(bpe_cont) } bpe_len = len(bpe_cont) else: bpe_toks = None bpe_len = 0 word_stats = dict() wps_meter = TimeMeter() for sample in progress: if "net_input" not in sample: continue sample = utils.move_to_cuda(sample) if use_cuda else sample gen_timer.start() hypos = scorer.generate(models, sample) gen_timer.stop(sample["ntokens"]) for i, hypos_i in enumerate(hypos): hypo = hypos_i[0] sample_id = sample["id"][i] tokens = hypo["tokens"] tgt_len = tokens.numel() pos_scores = hypo["positional_scores"].float() if cfg.task.add_bos_token: assert hypo["tokens"][0].item() == task.target_dictionary.bos() tokens = tokens[1:] pos_scores = 
pos_scores[1:] skipped_toks = 0 if bpe_toks is not None: for i in range(tgt_len - 1): if tokens[i].item() in bpe_toks: skipped_toks += 1 pos_scores[i + 1] += pos_scores[i] pos_scores[i] = 0 inf_scores = pos_scores.eq(float("inf")) | pos_scores.eq(float("-inf")) if inf_scores.any(): logger.info( "skipping tokens with inf scores:", task.target_dictionary.string(tokens[inf_scores.nonzero()]), ) pos_scores = pos_scores[(~inf_scores).nonzero()] score_sum += pos_scores.sum().cpu() count += pos_scores.numel() - skipped_toks if cfg.eval_lm.output_word_probs or cfg.eval_lm.output_word_stats: w = "" word_prob = [] is_bpe = False for i in range(len(tokens)): w_ind = tokens[i].item() w += task.source_dictionary[w_ind] if bpe_toks is not None and w_ind in bpe_toks: w = w[:-bpe_len] is_bpe = True else: word_prob.append((w, pos_scores[i].item())) next_prob = None ind = i + 1 while ind < len(tokens): if pos_scores[ind].item() != 0: next_prob = pos_scores[ind] break ind += 1 word_stats.setdefault(w, WordStat(w, is_bpe)).add( pos_scores[i].item(), next_prob ) is_bpe = False w = "" if cfg.eval_lm.output_word_probs: logger.info( str(int(sample_id)) + " " + ( "\t".join( "{} [{:2f}]".format(x[0], x[1]) for x in word_prob ) ) ) wps_meter.update(sample["ntokens"]) progress.log({"wps": round(wps_meter.avg)}) avg_nll_loss = -score_sum / count / math.log(2) # convert to base 2 logger.info( "Evaluated {} tokens in {:.1f}s ({:.2f} tokens/s)".format( gen_timer.n, gen_timer.sum, 1.0 / gen_timer.avg ) ) logger.info( "Loss (base 2): {:.4f}, Perplexity: {:.2f}".format( avg_nll_loss, 2 ** avg_nll_loss ) ) if cfg.eval_lm.output_word_stats: for ws in sorted(word_stats.values(), key=lambda x: x.count, reverse=True): logger.info(ws)
def test_multilingual_denoising(self): with TemporaryDirectory() as dirname: # prep input file lang_dir = os.path.join(dirname, "en") os.mkdir(lang_dir) raw_file = os.path.join(lang_dir, "raw") data = make_data(out_file=raw_file) vocab = build_vocab(data) # binarize binarizer = VocabularyDatasetBinarizer(vocab, append_eos=False) split = "train" bin_file = os.path.join(lang_dir, split) dataset_impl = "mmap" FileBinarizer.multiprocess_dataset( input_file=raw_file, binarizer=binarizer, dataset_impl=dataset_impl, vocab_size=len(vocab), output_prefix=bin_file, ) # setup task train_args = options.parse_args_and_arch( options.get_training_parser(), [ "--task", "multilingual_denoising", "--arch", "bart_base", "--seed", "42", "--mask-length", "word", "--permute-sentences", "1", "--rotate", "0", "--replace-length", "-1", "--mask", "0.2", dirname, ], ) cfg = convert_namespace_to_omegaconf(train_args) task = MultilingualDenoisingTask(cfg.task, binarizer.dict) # load datasets original_dataset = task._load_dataset_split(bin_file, 1, False) task.load_dataset(split) masked_dataset = task.dataset(split) iterator = task.get_batch_iterator( dataset=masked_dataset, max_tokens=65_536, max_positions=4_096, ).next_epoch_itr(shuffle=False) mask_index = task.source_dictionary.index("<mask>") for batch in iterator: for sample in range(len(batch)): net_input = batch["net_input"] masked_src_tokens = net_input["src_tokens"][sample] masked_src_length = net_input["src_lengths"][sample] masked_tgt_tokens = batch["target"][sample] sample_id = batch["id"][sample] original_tokens = original_dataset[sample_id] original_tokens = original_tokens.masked_select( masked_src_tokens[:masked_src_length] == mask_index) masked_tokens = masked_tgt_tokens.masked_select( masked_src_tokens == mask_index) assert masked_tokens.equal(original_tokens)
def __init__(self, cfg: DecoderConfig, tgt_dict: Dictionary) -> None: super().__init__(cfg, tgt_dict) self.lexicon = load_words(cfg.lexicon) if cfg.lexicon else None self.idx_to_wrd = {} checkpoint = torch.load(cfg.lmpath, map_location="cpu") if "cfg" in checkpoint and checkpoint["cfg"] is not None: lm_args = checkpoint["cfg"] else: lm_args = convert_namespace_to_omegaconf(checkpoint["args"]) with open_dict(lm_args.task): lm_args.task.data = osp.dirname(cfg.lmpath) task = tasks.setup_task(lm_args.task) model = task.build_model(lm_args.model) model.load_state_dict(checkpoint["model"], strict=False) self.trie = Trie(self.vocab_size, self.silence) self.word_dict = task.dictionary self.unk_word = self.word_dict.unk() self.lm = FairseqLM(self.word_dict, model) if self.lexicon: start_state = self.lm.start(False) for i, (word, spellings) in enumerate(self.lexicon.items()): if self.unitlm: word_idx = i self.idx_to_wrd[i] = word score = 0 else: word_idx = self.word_dict.index(word) _, score = self.lm.score(start_state, word_idx, no_cache=True) for spelling in spellings: spelling_idxs = [ tgt_dict.index(token) for token in spelling ] assert tgt_dict.unk() not in spelling_idxs, \ f"{spelling} {spelling_idxs}" self.trie.insert(spelling_idxs, word_idx, score) self.trie.smear(SmearingMode.MAX) self.decoder_opts = LexiconDecoderOptions( beam_size=cfg.beam, beam_size_token=cfg.beamsizetoken or len(tgt_dict), beam_threshold=cfg.beamthreshold, lm_weight=cfg.lmweight, word_score=cfg.wordscore, unk_score=cfg.unkweight, sil_score=cfg.silweight, log_add=False, criterion_type=self.criterion_type, ) if self.asgtransitions is None: self.asgtransitions = [] self.decoder = LexiconDecoder( self.decoder_opts, self.trie, self.lm, self.silence, self.blank, self.unk_word, self.asgtransitions, self.unitlm, ) else: assert self.unitlm, "Lexicon-free decoding requires unit LM" d = {w: [[w]] for w in tgt_dict.symbols} self.word_dict = create_word_dict(d) self.lm = KenLM(cfg.lmpath, self.word_dict) self.decoder_opts = LexiconFreeDecoderOptions( beam_size=cfg.beam, beam_size_token=cfg.beamsizetoken or len(tgt_dict), beam_threshold=cfg.beamthreshold, lm_weight=cfg.lmweight, sil_score=cfg.silweight, log_add=False, criterion_type=self.criterion_type, ) self.decoder = LexiconFreeDecoder(self.decoder_opts, self.lm, self.silence, self.blank, [])
def main(cfg: FairseqConfig): if isinstance(cfg, Namespace): cfg = convert_namespace_to_omegaconf(cfg) start_time = time.time() total_translate_time = 0 utils.import_user_module(cfg.common) if cfg.interactive.buffer_size < 1: cfg.interactive.buffer_size = 1 if cfg.dataset.max_tokens is None and cfg.dataset.batch_size is None: cfg.dataset.batch_size = 1 assert (not cfg.generation.sampling or cfg.generation.nbest == cfg.generation.beam ), "--sampling requires --nbest to be equal to --beam" assert (not cfg.dataset.batch_size or cfg.dataset.batch_size <= cfg.interactive.buffer_size ), "--batch-size cannot be larger than --buffer-size" logger.info(cfg) # Fix seed for stochastic decoding if cfg.common.seed is not None and not cfg.generation.no_seed_provided: np.random.seed(cfg.common.seed) utils.set_torch_seed(cfg.common.seed) use_cuda = torch.cuda.is_available() and not cfg.common.cpu # Setup task, e.g., translation task = tasks.setup_task(cfg.task) # Load ensemble overrides = ast.literal_eval(cfg.common_eval.model_overrides) logger.info("loading model(s) from {}".format(cfg.common_eval.path)) models, _model_args = checkpoint_utils.load_model_ensemble( utils.split_paths(cfg.common_eval.path), arg_overrides=overrides, task=task, suffix=cfg.checkpoint.checkpoint_suffix, strict=(cfg.checkpoint.checkpoint_shard_count == 1), num_shards=cfg.checkpoint.checkpoint_shard_count, ) # Set dictionaries src_dict = task.source_dictionary tgt_dict = task.target_dictionary # Optimize ensemble for generation for model in models: if model is None: continue if cfg.common.fp16: model.half() if use_cuda and not cfg.distributed_training.pipeline_model_parallel: model.cuda() model.prepare_for_inference_(cfg) # Initialize generator generator = task.build_generator(models, cfg.generation) # Handle tokenization and BPE tokenizer = encoders.build_tokenizer(cfg.tokenizer) bpe = encoders.build_bpe(cfg.bpe) def encode_fn(x): if tokenizer is not None: x = tokenizer.encode(x) if bpe is not None: x = bpe.encode(x) return x def decode_fn(x): if bpe is not None: x = bpe.decode(x) if tokenizer is not None: x = tokenizer.decode(x) return x # Load alignment dictionary for unknown word replacement # (None if no unknown word replacement, empty if no path to align dictionary) align_dict = utils.load_align_dict(cfg.generation.replace_unk) max_positions = utils.resolve_max_positions( task.max_positions(), *[model.max_positions() for model in models]) if cfg.generation.constraints: logger.warning( "NOTE: Constrained decoding currently assumes a shared subword vocabulary." 
) if cfg.interactive.buffer_size > 1: logger.info("Sentence buffer size: %s", cfg.interactive.buffer_size) logger.info("NOTE: hypothesis and token scores are output in base 2") logger.info("Type the input sentence and press return:") start_id = 0 for inputs in buffered_read(cfg.interactive.input, cfg.interactive.buffer_size): results = [] for batch in make_batches(inputs, cfg, task, max_positions, encode_fn): bsz = batch.src_tokens.size(0) src_tokens = batch.src_tokens src_lengths = batch.src_lengths constraints = batch.constraints if use_cuda: src_tokens = src_tokens.cuda() src_lengths = src_lengths.cuda() if constraints is not None: constraints = constraints.cuda() sample = { "net_input": { "src_tokens": src_tokens, "src_lengths": src_lengths, }, } translate_start_time = time.time() translations = task.inference_step(generator, models, sample, constraints=constraints) translate_time = time.time() - translate_start_time total_translate_time += translate_time list_constraints = [[] for _ in range(bsz)] if cfg.generation.constraints: list_constraints = [unpack_constraints(c) for c in constraints] for i, (id, hypos) in enumerate(zip(batch.ids.tolist(), translations)): src_tokens_i = utils.strip_pad(src_tokens[i], tgt_dict.pad()) constraints = list_constraints[i] results.append(( start_id + id, src_tokens_i, hypos, { "constraints": constraints, "time": translate_time / len(translations), }, )) # sort output to match input order for id_, src_tokens, hypos, info in sorted(results, key=lambda x: x[0]): if src_dict is not None: src_str = src_dict.string(src_tokens, cfg.common_eval.post_process) print("S-{}\t{}".format(id_, src_str)) print("W-{}\t{:.3f}\tseconds".format(id_, info["time"])) for constraint in info["constraints"]: print("C-{}\t{}".format( id_, tgt_dict.string(constraint, cfg.common_eval.post_process))) # Process top predictions for hypo in hypos[:min(len(hypos), cfg.generation.nbest)]: hypo_tokens, hypo_str, alignment = utils.post_process_prediction( hypo_tokens=hypo["tokens"].int().cpu(), src_str=src_str, alignment=hypo["alignment"], align_dict=align_dict, tgt_dict=tgt_dict, remove_bpe=cfg.common_eval.post_process, extra_symbols_to_ignore=get_symbols_to_strip_from_output( generator), ) detok_hypo_str = decode_fn(hypo_str) score = hypo["score"] / math.log(2) # convert to base 2 # original hypothesis (after tokenization and BPE) print("H-{}\t{}\t{}".format(id_, score, hypo_str)) # detokenized hypothesis print("D-{}\t{}\t{}".format(id_, score, detok_hypo_str)) print("P-{}\t{}".format( id_, " ".join( map( lambda x: "{:.4f}".format(x), # convert from base e to base 2 hypo["positional_scores"].div_(math.log(2) ).tolist(), )), )) if cfg.generation.print_alignment: alignment_str = " ".join( ["{}-{}".format(src, tgt) for src, tgt in alignment]) print("A-{}\t{}".format(id_, alignment_str)) # update running id_ counter start_id += len(inputs) logger.info("Total time: {:.3f} seconds; translation time: {:.3f}".format( time.time() - start_time, total_translate_time))
def load_model_ensemble_and_task( filenames, arg_overrides: Optional[Dict[str, Any]] = None, task=None, strict=True, suffix="", num_shards=1, state=None, ): assert state is None or len(filenames) == 1 from fairseq import tasks assert not ( strict and num_shards > 1 ), "Cannot load state dict with strict=True and checkpoint shards > 1" ensemble = [] cfg = None for filename in filenames: orig_filename = filename model_shard_state = {"shard_weights": [], "shard_metadata": []} assert num_shards > 0 st = time.time() for shard_idx in range(num_shards): filename = get_maybe_sharded_checkpoint_filename( orig_filename, suffix, shard_idx, num_shards ) if not PathManager.exists(filename): raise IOError("Model file not found: {}".format(filename)) if state is None: state = load_checkpoint_to_cpu(filename, arg_overrides) if "args" in state and state["args"] is not None: cfg = convert_namespace_to_omegaconf(state["args"]) elif "cfg" in state and state["cfg"] is not None: cfg = state["cfg"] else: raise RuntimeError( f"Neither args nor cfg exist in state keys = {state.keys()}" ) if task is None: task = tasks.setup_task(cfg.task) if "task_state" in state: task.load_state_dict(state["task_state"]) if "fsdp_metadata" in state and num_shards > 1: model_shard_state["shard_weights"].append(state["model"]) model_shard_state["shard_metadata"].append(state["fsdp_metadata"]) # check FSDP import before the code goes too far if not has_FSDP: raise ImportError( "Cannot find FullyShardedDataParallel. " "Please install fairscale with: pip install fairscale" ) if shard_idx == num_shards - 1: consolidated_model_state = FSDP.consolidate_shard_weights( shard_weights=model_shard_state["shard_weights"], shard_metadata=model_shard_state["shard_metadata"], ) model = task.build_model(cfg.model) if ( "optimizer_history" in state and len(state["optimizer_history"]) > 0 and "num_updates" in state["optimizer_history"][-1] ): model.set_num_updates( state["optimizer_history"][-1]["num_updates"] ) model.load_state_dict( consolidated_model_state, strict=strict, model_cfg=cfg.model ) else: # model parallel checkpoint or unsharded checkpoint # support old external tasks argspec = inspect.getfullargspec(task.build_model) if "from_checkpoint" in argspec.args: model = task.build_model(cfg.model, from_checkpoint=True) else: model = task.build_model(cfg.model) if ( "optimizer_history" in state and len(state["optimizer_history"]) > 0 and "num_updates" in state["optimizer_history"][-1] ): model.set_num_updates(state["optimizer_history"][-1]["num_updates"]) model.load_state_dict( state["model"], strict=strict, model_cfg=cfg.model ) # reset state so it gets loaded for the next model in ensemble state = None if shard_idx % 10 == 0 and shard_idx > 0: elapsed = time.time() - st logger.info( f"Loaded {shard_idx} shards in {elapsed:.2f}s, {elapsed / (shard_idx+1):.2f}s/shard" ) # build model for ensemble ensemble.append(model) return ensemble, cfg, task
def main(cfg: DictConfig) -> None: if isinstance(cfg, argparse.Namespace): cfg = convert_namespace_to_omegaconf(cfg) utils.import_user_module(cfg.common) assert ( cfg.dataset.max_tokens is not None or cfg.dataset.batch_size is not None ), "Must specify batch size either with --max-tokens or --batch-size" metrics.reset() np.random.seed(cfg.common.seed) utils.set_torch_seed(cfg.common.seed) if distributed_utils.is_master(cfg.distributed_training): checkpoint_utils.verify_checkpoint_directory(cfg.checkpoint.save_dir) # Print args logger.info(cfg) # Setup task, e.g., translation, language modeling, etc. task = tasks.setup_task(cfg.task) # Load valid dataset (we load training data below, based on the latest checkpoint) for valid_sub_split in cfg.dataset.valid_subset.split(","): task.load_dataset(valid_sub_split, combine=False, epoch=1) assert cfg.criterion, "Please specify criterion to train a model" # Build model and criterion model = task.build_model(cfg.model) criterion = task.build_criterion(cfg.criterion) logger.info(model) logger.info("task: {}".format(task.__class__.__name__)) logger.info("model: {}".format(model.__class__.__name__)) logger.info("criterion: {})".format(criterion.__class__.__name__)) logger.info("num. model params: {} (num. trained: {})".format( sum(p.numel() for p in model.parameters()), sum(p.numel() for p in model.parameters() if p.requires_grad), )) # (optionally) Configure quantization if cfg.common.quantization_config_path is not None: quantizer = quantization_utils.Quantizer( config_path=cfg.common.quantization_config_path, max_epoch=cfg.optimization.max_epoch, max_update=cfg.optimization.max_update, ) else: quantizer = None # Build trainer if cfg.common.model_parallel_size == 1: trainer = Trainer(cfg, task, model, criterion, quantizer) else: trainer = MegatronTrainer(cfg, task, model, criterion) logger.info("training on {} devices (GPUs/TPUs)".format( cfg.distributed_training.distributed_world_size)) logger.info("max tokens per GPU = {} and batch size per GPU = {}".format( cfg.dataset.max_tokens, cfg.dataset.batch_size, )) # Load the latest checkpoint if one is available and restore the # corresponding train iterator extra_state, epoch_itr = checkpoint_utils.load_checkpoint( cfg.checkpoint, trainer, # don't cache epoch iterators for sharded datasets disable_iterator_cache=task.has_sharded_data("train"), ) max_epoch = cfg.optimization.max_epoch or math.inf lr = trainer.get_lr() train_meter = meters.StopwatchMeter() train_meter.start() while lr > cfg.optimization.min_lr and epoch_itr.next_epoch_idx <= max_epoch: # train for one epoch valid_losses, should_stop = train(cfg, trainer, task, epoch_itr) if should_stop: break # only use first validation loss to update the learning rate lr = trainer.lr_step(epoch_itr.epoch, valid_losses[0]) epoch_itr = trainer.get_train_iterator( epoch_itr.next_epoch_idx, # sharded data: get train iterator for next epoch load_dataset=task.has_sharded_data("train"), # don't cache epoch iterators for sharded datasets disable_iterator_cache=task.has_sharded_data("train"), ) train_meter.stop() logger.info("done training in {:.1f} seconds".format(train_meter.sum))
def __init__(self, cfg: HubertAsrConfig, tgt_dict=None): self.apply_mask = cfg.apply_mask arg_overrides = { "dropout": cfg.dropout, "activation_dropout": cfg.activation_dropout, "dropout_input": cfg.dropout_input, "attention_dropout": cfg.attention_dropout, "mask_length": cfg.mask_length, "mask_prob": cfg.mask_prob, "mask_selection": cfg.mask_selection, "mask_other": cfg.mask_other, "no_mask_overlap": cfg.no_mask_overlap, "mask_channel_length": cfg.mask_channel_length, "mask_channel_prob": cfg.mask_channel_prob, "mask_channel_selection": cfg.mask_channel_selection, "mask_channel_other": cfg.mask_channel_other, "no_mask_channel_overlap": cfg.no_mask_channel_overlap, "encoder_layerdrop": cfg.layerdrop, "feature_grad_mult": cfg.feature_grad_mult, } if cfg.w2v_args is None: state = checkpoint_utils.load_checkpoint_to_cpu( cfg.w2v_path, arg_overrides) w2v_args = state.get("cfg", None) if w2v_args is None: w2v_args = convert_namespace_to_omegaconf(state["args"]) cfg.w2v_args = w2v_args else: state = None w2v_args = cfg.w2v_args if isinstance(w2v_args, Namespace): cfg.w2v_args = w2v_args = convert_namespace_to_omegaconf( w2v_args) assert cfg.normalize == w2v_args.task.normalize, ( "Fine-tuning works best when data normalization is the same. " "Please check that --normalize is set or unset for " "both pre-training and here") w2v_args.task.data = cfg.data task = tasks.setup_task(w2v_args.task) if state is not None and "task_state" in state: # This will load the stored "dictionaries" object task.load_state_dict(state["task_state"]) model = task.build_model(w2v_args.model) if state is not None and not cfg.no_pretrained_weights: # set strict=False because we omit some modules model.load_state_dict(state["model"], strict=False) model.remove_pretraining_modules() super().__init__(task.source_dictionary) d = w2v_args.model.encoder_embed_dim self.w2v_model = model self.final_dropout = nn.Dropout(cfg.final_dropout) self.freeze_finetune_updates = cfg.freeze_finetune_updates self.num_updates = 0 if tgt_dict is not None: self.proj = Linear(d, len(tgt_dict)) elif getattr(cfg, "decoder_embed_dim", d) != d: self.proj = Linear(d, cfg.decoder_embed_dim) else: self.proj = None
def __init__(self, args, tgt_dict): super().__init__(args, tgt_dict) self.unit_lm = getattr(args, "unit_lm", False) self.lexicon = load_words(args.lexicon) if args.lexicon else None self.idx_to_wrd = {} checkpoint = torch.load(args.kenlm_model, map_location="cpu") if "cfg" in checkpoint and checkpoint["cfg"] is not None: lm_args = checkpoint["cfg"] else: lm_args = convert_namespace_to_omegaconf(checkpoint["args"]) with open_dict(lm_args.task): lm_args.task.data = osp.dirname(args.kenlm_model) task = tasks.setup_task(lm_args.task) model = task.build_model(lm_args.model) model.load_state_dict(checkpoint["model"], strict=False) self.trie = Trie(self.vocab_size, self.silence) self.word_dict = task.dictionary self.unk_word = self.word_dict.unk() self.lm = FairseqLM(self.word_dict, model) if self.lexicon: start_state = self.lm.start(False) for i, (word, spellings) in enumerate(self.lexicon.items()): if self.unit_lm: word_idx = i self.idx_to_wrd[i] = word score = 0 else: word_idx = self.word_dict.index(word) _, score = self.lm.score(start_state, word_idx, no_cache=True) for spelling in spellings: spelling_idxs = [tgt_dict.index(token) for token in spelling] assert ( tgt_dict.unk() not in spelling_idxs ), f"{spelling} {spelling_idxs}" self.trie.insert(spelling_idxs, word_idx, score) self.trie.smear(SmearingMode.MAX) self.decoder_opts = LexiconDecoderOptions( beam_size=args.beam, beam_size_token=int(getattr(args, "beam_size_token", len(tgt_dict))), beam_threshold=args.beam_threshold, lm_weight=args.lm_weight, word_score=args.word_score, unk_score=args.unk_weight, sil_score=args.sil_weight, log_add=False, criterion_type=self.criterion_type, ) self.decoder = LexiconDecoder( self.decoder_opts, self.trie, self.lm, self.silence, self.blank, self.unk_word, [], self.unit_lm, ) else: assert args.unit_lm, "lexicon free decoding can only be done with a unit language model" from flashlight.lib.text.decoder import LexiconFreeDecoder, LexiconFreeDecoderOptions d = {w: [[w]] for w in tgt_dict.symbols} self.word_dict = create_word_dict(d) self.lm = KenLM(args.kenlm_model, self.word_dict) self.decoder_opts = LexiconFreeDecoderOptions( beam_size=args.beam, beam_size_token=int(getattr(args, "beam_size_token", len(tgt_dict))), beam_threshold=args.beam_threshold, lm_weight=args.lm_weight, sil_score=args.sil_weight, log_add=False, criterion_type=self.criterion_type, ) self.decoder = LexiconFreeDecoder( self.decoder_opts, self.lm, self.silence, self.blank, [] )
def __init__(self): parser = options.get_interactive_generation_parser() args = options.parse_args_and_arch(parser) cfg = convert_namespace_to_omegaconf(args) utils.import_user_module(cfg.common) if cfg.interactive.buffer_size < 1: cfg.interactive.buffer_size = 1 if cfg.dataset.max_tokens is None and cfg.dataset.batch_size is None: cfg.dataset.batch_size = 1 assert (not cfg.generation.sampling or cfg.generation.nbest == cfg.generation.beam ), "--sampling requires --nbest to be equal to --beam" assert (not cfg.dataset.batch_size or cfg.dataset.batch_size <= cfg.interactive.buffer_size ), "--batch-size cannot be larger than --buffer-size" use_cuda = torch.cuda.is_available() and not cfg.common.cpu # Setup task, e.g., translation task = tasks.setup_task(cfg.task) # Load ensemble models, _model_args = checkpoint_utils.load_model_ensemble( utils.split_paths(cfg.common_eval.path), task=task, suffix=cfg.checkpoint.checkpoint_suffix, strict=(cfg.checkpoint.checkpoint_shard_count == 1), num_shards=cfg.checkpoint.checkpoint_shard_count, ) # Set dictionaries src_dict = task.source_dictionary tgt_dict = task.target_dictionary # Optimize ensemble for generation for model in models: if model is None: continue if cfg.common.fp16: model.half() if use_cuda and not cfg.distributed_training.pipeline_model_parallel: model.cuda() model.prepare_for_inference_(cfg) # Initialize generator generator = task.build_generator(models, cfg.generation) # Handle tokenization and BPE tokenizer = encoders.build_tokenizer(cfg.tokenizer) bpe = encoders.build_bpe(cfg.bpe) # Load alignment dictionary for unknown word replacement # (None if no unknown word replacement, empty if no path to align dictionary) align_dict = utils.load_align_dict(cfg.generation.replace_unk) max_positions = utils.resolve_max_positions( task.max_positions(), *[model.max_positions() for model in models]) if cfg.interactive.buffer_size > 1: logger.info("Sentence buffer size: %s", cfg.interactive.buffer_size) self.context = { 'bpe': bpe, 'tokenizer': tokenizer, 'cfg': cfg, 'task': task, 'max_positions': max_positions, 'use_cuda': use_cuda, 'generator': generator, 'models': models, 'src_dict': src_dict, 'tgt_dict': tgt_dict, 'align_dict': align_dict, }
def cli_main():
    parser = options.get_eval_lm_parser()
    args = options.parse_args_and_arch(parser)
    distributed_utils.call_main(convert_namespace_to_omegaconf(args), main)
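# Hedged sketch (paths are placeholders): the same entry point can also be
# driven programmatically by handing parse_args_and_arch an explicit argv
# list, mirroring how the test helpers later in this section build their configs.
parser = options.get_eval_lm_parser()
args = options.parse_args_and_arch(
    parser,
    ["/path/to/data-bin", "--path", "/path/to/checkpoint.pt", "--max-tokens", "500"],
)
distributed_utils.call_main(convert_namespace_to_omegaconf(args), main)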
def main(cfg: DictConfig, **unused_kwargs): if isinstance(cfg, Namespace): cfg = convert_namespace_to_omegaconf(cfg) utils.import_user_module(cfg.common) logger.info(cfg) if cfg.eval_lm.context_window > 0: # reduce tokens per sample by the required context window size cfg.task.tokens_per_sample -= cfg.eval_lm.context_window # Initialize the task using the current *cfg* task = tasks.setup_task(cfg.task) # Load ensemble logger.info("loading model(s) from {}".format(cfg.common_eval.path)) models, model_args, task = checkpoint_utils.load_model_ensemble_and_task( [cfg.common_eval.path], arg_overrides=eval(cfg.common_eval.model_overrides), suffix=cfg.checkpoint.checkpoint_suffix, strict=(cfg.checkpoint.checkpoint_shard_count == 1), num_shards=cfg.checkpoint.checkpoint_shard_count, task=task, ) use_fp16 = cfg.common.fp16 use_cuda = torch.cuda.is_available() and not cfg.common.cpu if use_cuda: torch.cuda.set_device(cfg.distributed_training.device_id) # Optimize ensemble for generation and set the source and dest dicts on the model # (required by scorer) for model in models: if use_fp16: model.half() if use_cuda and not cfg.distributed_training.pipeline_model_parallel: model.cuda() model.prepare_for_inference_(cfg) assert len(models) > 0 logger.info( "num. model params: {:,}".format(sum(p.numel() for p in models[0].parameters())) ) # Load dataset splits task.load_dataset(cfg.dataset.gen_subset) dataset = task.dataset(cfg.dataset.gen_subset) logger.info( "{} {} {:,} examples".format( cfg.task.data, cfg.dataset.gen_subset, len(dataset) ) ) itr = task.eval_lm_dataloader( dataset=dataset, max_tokens=cfg.dataset.max_tokens or 36000, batch_size=cfg.dataset.batch_size, max_positions=utils.resolve_max_positions( *[model.max_positions() for model in models] ), num_shards=max( cfg.dataset.num_shards, cfg.distributed_training.distributed_world_size, ), shard_id=max( cfg.dataset.shard_id, cfg.distributed_training.distributed_rank, ), num_workers=cfg.dataset.num_workers, data_buffer_size=cfg.dataset.data_buffer_size, context_window=cfg.eval_lm.context_window, ) itr = progress_bar.progress_bar( itr, log_format=cfg.common.log_format, log_interval=cfg.common.log_interval, default_log_format=("tqdm" if not cfg.common.no_progress_bar else "simple"), ) results = eval_lm( models=models, source_dictionary=task.source_dictionary, batch_iterator=itr, post_process=cfg.common_eval.post_process, output_word_probs=cfg.eval_lm.output_word_probs, output_word_stats=cfg.eval_lm.output_word_stats, target_dictionary=task.target_dictionary, softmax_batch=cfg.eval_lm.softmax_batch, remove_bos_token=getattr(cfg.task, "add_bos_token", False), ) logger.info( "Loss (base 2): {:.4f}, Perplexity: {:.2f}".format( results["loss"], results["perplexity"] ) ) return results
def main(rank, world_size, args): start = time.time() if world_size > 1: torch.distributed.init_process_group(backend="gloo", init_method="env://", world_size=world_size, rank=rank) torch.cuda.set_device(rank % torch.cuda.device_count()) raw_args = args args = convert_namespace_to_omegaconf(args) if args.common.seed is not None: np.random.seed(args.common.seed) utils.set_torch_seed(args.common.seed) models, model_args, task = checkpoint_utils.load_model_ensemble_and_task( [raw_args.path], arg_overrides={"data": args.task.data}) tgt_dict = task.target_dictionary for model in models: model.prepare_for_inference_(args) model.cuda().eval() if raw_args.fp16: model = model.half() model = models[0] config = ExpressiveCodeDataConfig(args.task.data) dataset = CodeDataset( manifest=config.manifests[raw_args.eval_subset], dictionary=task.source_dictionary, dur_dictionary=task.source_duration_dictionary, f0_dictionary=task.source_f0_dictionary, config=config, discrete_dur=task.cfg.discrete_duration, discrete_f0=task.cfg.discrete_f0, log_f0=task.cfg.log_f0, normalize_f0_mean=task.cfg.normalize_f0_mean, normalize_f0_std=task.cfg.normalize_f0_std, interpolate_f0=task.cfg.interpolate_f0, shifts=task.cfg.stream_shifts, return_filename=True, strip_filename=False, return_continuous_f0=raw_args.dequantize_prosody, ) if raw_args.filter_names: dataset = FilterNamesDataset(dataset, raw_args.filter_names) criterion = task.build_criterion(model_args.criterion) name2metric = { "continuation": continuation, "teacher_force_everything": teacher_force_everything, "correlation": correlation, } name2keys = { "continuation": ( "Token BLEU3", "Duration NLL", "Duration MAE", "F0 NLL", "F0 MAE", "F0 sum", "F0 sum_sq", "Dur sum", "Dur sum_sq", ), "teacher_force_everything": ("token_loss", "duration_loss", "f0_loss"), "correlation": ("Duration corr", "F0 corr"), } metric_name = raw_args.metric metric = name2metric[metric_name] results = metric(raw_args, dataset, model, criterion, tgt_dict, rank, world_size) values = None if metric_name not in [ "correlation", ]: values, normalizers = results values = maybe_aggregate_normalize(values, normalizers, world_size) elif metric_name == "correlation": values = maybe_aggregate_correlations(results, world_size) else: assert False assert values is not None summary = dict(zip(name2keys[raw_args.metric], values.tolist())) if metric_name == "continuation": summary["F0 Std"] = np.sqrt(-summary["F0 sum"]**2 + summary["F0 sum_sq"]) summary["Dur Std"] = np.sqrt(-summary["Dur sum"]**2 + summary["Dur sum_sq"]) del summary["F0 sum"] del summary["F0 sum_sq"] del summary["Dur sum"] del summary["Dur sum_sq"] summary["metric"] = metric_name if rank == 0: print(summary) if raw_args.wandb: wandb_results(summary, raw_args) print("# finished in ", time.time() - start, "seconds")
def main(cfg: FairseqConfig) -> None: if isinstance(cfg, argparse.Namespace): cfg = convert_namespace_to_omegaconf(cfg) utils.import_user_module(cfg.common) if is_master(cfg.distributed_training) and "job_logging_cfg" in cfg: # make hydra logging work with ddp (see # see https://github.com/facebookresearch/hydra/issues/1126) logging.config.dictConfig(OmegaConf.to_container(cfg.job_logging_cfg)) assert ( cfg.dataset.max_tokens is not None or cfg.dataset.batch_size is not None ), "Must specify batch size either with --max-tokens or --batch-size" metrics.reset() np.random.seed(cfg.common.seed) utils.set_torch_seed(cfg.common.seed) if distributed_utils.is_master(cfg.distributed_training): checkpoint_utils.verify_checkpoint_directory(cfg.checkpoint.save_dir) # Print args logger.info(cfg) if cfg.checkpoint.write_checkpoints_asynchronously: try: import iopath # noqa: F401 except ImportError: logging.exception( "Asynchronous checkpoint writing is specified but iopath is " "not installed: `pip install iopath`") return # Setup task, e.g., translation, language modeling, etc. task = tasks.setup_task(cfg.task) # Load valid dataset (we load training data below, based on the latest checkpoint) for valid_sub_split in cfg.dataset.valid_subset.split(","): task.load_dataset(valid_sub_split, combine=False, epoch=1) assert cfg.criterion, "Please specify criterion to train a model" # Build model and criterion model = task.build_model(cfg.model) criterion = task.build_criterion(cfg.criterion) logger.info(model) logger.info("task: {}".format(task.__class__.__name__)) logger.info("model: {}".format(model.__class__.__name__)) logger.info("criterion: {}".format(criterion.__class__.__name__)) logger.info("num. model params: {:,} (num. trained: {:,})".format( sum(p.numel() for p in model.parameters()), sum(p.numel() for p in model.parameters() if p.requires_grad), )) # (optionally) Configure quantization if cfg.common.quantization_config_path is not None: quantizer = quantization_utils.Quantizer( config_path=cfg.common.quantization_config_path, max_epoch=cfg.optimization.max_epoch, max_update=cfg.optimization.max_update, ) else: quantizer = None # Build trainer if cfg.common.model_parallel_size == 1: trainer = Trainer(cfg, task, model, criterion, quantizer) else: trainer = MegatronTrainer(cfg, task, model, criterion) logger.info("training on {} devices (GPUs/TPUs)".format( cfg.distributed_training.distributed_world_size)) logger.info("max tokens per GPU = {} and batch size per GPU = {}".format( cfg.dataset.max_tokens, cfg.dataset.batch_size, )) # Load the latest checkpoint if one is available and restore the # corresponding train iterator extra_state, epoch_itr = checkpoint_utils.load_checkpoint( cfg.checkpoint, trainer, # don't cache epoch iterators for sharded datasets disable_iterator_cache=task.has_sharded_data("train"), ) max_epoch = cfg.optimization.max_epoch or math.inf lr = trainer.get_lr() train_meter = meters.StopwatchMeter() train_meter.start() while epoch_itr.next_epoch_idx <= max_epoch: if lr <= cfg.optimization.stop_min_lr: logger.info( f"stopping training because current learning rate ({lr}) is smaller " "than or equal to minimum learning rate " f"(--stop-min-lr={cfg.optimization.stop_min_lr})") break # train for one epoch valid_losses, should_stop = train(cfg, trainer, task, epoch_itr) if should_stop: break # only use first validation loss to update the learning rate lr = trainer.lr_step(epoch_itr.epoch, valid_losses[0]) epoch_itr = trainer.get_train_iterator( epoch_itr.next_epoch_idx, # sharded data: 
get train iterator for next epoch load_dataset=task.has_sharded_data("train"), # don't cache epoch iterators for sharded datasets disable_iterator_cache=task.has_sharded_data("train"), ) train_meter.stop() logger.info("done training in {:.1f} seconds".format(train_meter.sum)) # ioPath implementation to wait for all asynchronous file writes to complete. if cfg.checkpoint.write_checkpoints_asynchronously: logger.info( "ioPath PathManager waiting for all asynchronous checkpoint " "writes to finish.") PathManager.async_close() logger.info("ioPath PathManager finished waiting.")
def __init__(self, cfg: FairseqConfig, task, model, criterion, quantizer=None): if isinstance(cfg, Namespace): logger.warning( "argparse.Namespace configuration is deprecated! Automatically converting to OmegaConf" ) cfg = convert_namespace_to_omegaconf(cfg) self.cfg = cfg self.task = task # catalog shared parameters shared_params = _catalog_shared_params(model) self.tpu = cfg.common.tpu self.cuda = torch.cuda.is_available( ) and not cfg.common.cpu and not self.tpu if self.cuda: self.device = torch.device("cuda") elif self.tpu: self.device = utils.get_tpu_device() else: self.device = torch.device("cpu") # copy model and criterion to current device/dtype self._criterion = criterion self._model = model if cfg.common.fp16: self._criterion = self._criterion.half() self._model = self._model.half() elif cfg.common.bf16: self._criterion = self._criterion.to(dtype=torch.bfloat16) self._model = self._model.to(dtype=torch.bfloat16) if not cfg.distributed_training.pipeline_model_parallel: self._criterion = self._criterion.to(device=self.device) self._model = self._model.to(device=self.device) self.pipeline_model_parallel = cfg.distributed_training.pipeline_model_parallel self.last_device = None if self.cuda and self.pipeline_model_parallel: self.last_device = torch.device( cfg.distributed_training.pipeline_devices[-1]) # check that shared parameters are preserved after device transfer for shared_param in shared_params: ref = _get_module_by_path(self._model, shared_param[0]) for path in shared_param[1:]: logger.info("detected shared parameter: {} <- {}".format( shared_param[0], path)) _set_module_by_path(self._model, path, ref) self._dummy_batch = None # indicates we don't have a dummy batch at first self._lr_scheduler = None self._num_updates = 0 self._num_xla_compiles = 0 # for TPUs self._optim_history = None self._optimizer = None self._warn_once = set() self._wrapped_criterion = None self._wrapped_model = None # TODO(myleott): support tpu if self.cuda and self.data_parallel_world_size > 1: self._grad_norm_buf = torch.cuda.DoubleTensor( self.data_parallel_world_size) else: self._grad_norm_buf = None self.quantizer = quantizer if self.quantizer is not None: self.quantizer.set_trainer(self) # get detailed cuda environment if self.cuda: self.cuda_env = utils.CudaEnvironment() if self.data_parallel_world_size > 1: self.cuda_env_arr = distributed_utils.all_gather_list( self.cuda_env, group=distributed_utils.get_global_group()) else: self.cuda_env_arr = [self.cuda_env] if self.data_parallel_rank == 0: utils.CudaEnvironment.pretty_print_cuda_env_list( self.cuda_env_arr) else: self.cuda_env = None self.cuda_env_arr = None metrics.log_start_time("wall", priority=790, round=0) self._start_time = time.time() self._previous_training_time = 0 self._cumulative_training_time = None
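# Hedged construction sketch (mirrors the train.py main() functions earlier in
# this section): the trainer is built right after the task, model and criterion.
task = tasks.setup_task(cfg.task)
model = task.build_model(cfg.model)
criterion = task.build_criterion(cfg.criterion)
trainer = Trainer(cfg, task, model, criterion)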
def train_language_model( data_dir, arch, extra_flags=None, run_validation=False, extra_valid_flags=None, task="language_modeling", world_size=1, ): train_parser = options.get_training_parser() train_args = options.parse_args_and_arch( train_parser, [ "--task", task, data_dir, "--arch", arch, "--optimizer", "adam", "--lr", "0.0001", "--max-tokens", "500", "--tokens-per-sample", "500", "--save-dir", data_dir, "--max-epoch", "1", "--no-progress-bar", "--distributed-world-size", str(world_size), "--ddp-backend", "no_c10d", "--num-workers", "0", ] + (extra_flags or []), ) cfg = convert_namespace_to_omegaconf(train_args) distributed_utils.call_main(cfg, train.main) if run_validation: # test validation validate_parser = options.get_validation_parser() validate_args = options.parse_args_and_arch( validate_parser, [ "--task", task, data_dir, "--path", os.path.join(data_dir, "checkpoint_last.pt"), "--valid-subset", "valid", "--max-tokens", "500", "--no-progress-bar", "--num-workers", "0", ] + (extra_valid_flags or []), ) validate.main(validate_args)
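# Hedged sketch of driving this helper from a unit test; create_dummy_data and
# preprocess_lm_data are assumed companion helpers that write and binarize a
# tiny synthetic language-modeling dataset.
import contextlib
import tempfile
from io import StringIO

with contextlib.redirect_stdout(StringIO()):
    with tempfile.TemporaryDirectory("test_transformer_lm") as data_dir:
        create_dummy_data(data_dir)
        preprocess_lm_data(data_dir)
        train_language_model(data_dir, "transformer_lm", run_validation=True)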
def _upgrade_state_dict(state): """Helper for upgrading old model checkpoints.""" # add optimizer_history if "optimizer_history" not in state: state["optimizer_history"] = [ {"criterion_name": "CrossEntropyCriterion", "best_loss": state["best_loss"]} ] state["last_optimizer_state"] = state["optimizer"] del state["optimizer"] del state["best_loss"] # move extra_state into sub-dictionary if "epoch" in state and "extra_state" not in state: state["extra_state"] = { "epoch": state["epoch"], "batch_offset": state["batch_offset"], "val_loss": state["val_loss"], } del state["epoch"] del state["batch_offset"] del state["val_loss"] # reduce optimizer history's memory usage (only keep the last state) if "optimizer" in state["optimizer_history"][-1]: state["last_optimizer_state"] = state["optimizer_history"][-1]["optimizer"] for optim_hist in state["optimizer_history"]: del optim_hist["optimizer"] # record the optimizer class name if "optimizer_name" not in state["optimizer_history"][-1]: state["optimizer_history"][-1]["optimizer_name"] = "FairseqNAG" # move best_loss into lr_scheduler_state if "lr_scheduler_state" not in state["optimizer_history"][-1]: state["optimizer_history"][-1]["lr_scheduler_state"] = { "best": state["optimizer_history"][-1]["best_loss"] } del state["optimizer_history"][-1]["best_loss"] # keep track of number of updates if "num_updates" not in state["optimizer_history"][-1]: state["optimizer_history"][-1]["num_updates"] = 0 # use stateful training data iterator if "train_iterator" not in state["extra_state"]: state["extra_state"]["train_iterator"] = { "epoch": state["extra_state"].get("epoch", 0), "iterations_in_epoch": state["extra_state"].get("batch_offset", 0), } # backward compatibility, cfg updates if "args" in state and state["args"] is not None: # old model checkpoints may not have separate source/target positions if hasattr(state["args"], "max_positions") and not hasattr( state["args"], "max_source_positions" ): state["args"].max_source_positions = state["args"].max_positions state["args"].max_target_positions = state["args"].max_positions # default to translation task if not hasattr(state["args"], "task"): state["args"].task = "translation" # --raw-text and --lazy-load are deprecated if getattr(state["args"], "raw_text", False): state["args"].dataset_impl = "raw" elif getattr(state["args"], "lazy_load", False): state["args"].dataset_impl = "lazy" # epochs start at 1 if state["extra_state"]["train_iterator"] is not None: state["extra_state"]["train_iterator"]["epoch"] = max( state["extra_state"]["train_iterator"].get("epoch", 1), 1 ) # --remove-bpe ==> --postprocess if hasattr(state["args"], "remove_bpe"): state["args"].post_process = state["args"].remove_bpe # --min-lr ==> --stop-min-lr if hasattr(state["args"], "min_lr"): state["args"].stop_min_lr = state["args"].min_lr del state["args"].min_lr # binary_cross_entropy / kd_binary_cross_entropy => wav2vec criterion if hasattr(state["args"], "criterion") and state["args"].criterion in [ "binary_cross_entropy", "kd_binary_cross_entropy", ]: state["args"].criterion = "wav2vec" # remove log_keys if it's None (criteria will supply a default value of []) if hasattr(state["args"], "log_keys") and state["args"].log_keys is None: delattr(state["args"], "log_keys") # speech_pretraining => audio pretraining if ( hasattr(state["args"], "task") and state["args"].task == "speech_pretraining" ): state["args"].task = "audio_pretraining" # audio_cpc => wav2vec if hasattr(state["args"], "arch") and state["args"].arch == "audio_cpc": 
state["args"].arch = "wav2vec" # convert legacy float learning rate to List[float] if hasattr(state["args"], "lr") and isinstance(state["args"].lr, float): state["args"].lr = [state["args"].lr] # convert task data arg to a string instead of List[string] if ( hasattr(state["args"], "data") and isinstance(state["args"].data, list) and len(state["args"].data) > 0 ): state["args"].data = state["args"].data[0] state["cfg"] = convert_namespace_to_omegaconf(state["args"]) if "cfg" in state and state["cfg"] is not None: cfg = state["cfg"] with open_dict(cfg): # any upgrades for Hydra-based configs if ( "task" in cfg and "eval_wer_config" in cfg.task and isinstance(cfg.task.eval_wer_config.print_alignment, bool) ): cfg.task.eval_wer_config.print_alignment = "hard" if "generation" in cfg and isinstance(cfg.generation.print_alignment, bool): cfg.generation.print_alignment = ( "hard" if cfg.generation.print_alignment else None ) if ( "model" in cfg and "w2v_args" in cfg.model and cfg.model.w2v_args is not None and ( hasattr(cfg.model.w2v_args, "task") or "task" in cfg.model.w2v_args ) and hasattr(cfg.model.w2v_args.task, "eval_wer_config") and cfg.model.w2v_args.task.eval_wer_config is not None and isinstance( cfg.model.w2v_args.task.eval_wer_config.print_alignment, bool ) ): cfg.model.w2v_args.task.eval_wer_config.print_alignment = "hard" return state
def load_model_ensemble_and_task(
    filenames,
    arg_overrides: Optional[Dict[str, Any]] = None,
    task=None,
    strict=True,
    suffix="",
    num_shards=1,
    state=None,
):
    assert state is None or len(filenames) == 1

    from fairseq import tasks

    assert not (
        strict and num_shards > 1
    ), "Cannot load state dict with strict=True and checkpoint shards > 1"
    ensemble = []
    cfg = None
    for filename in filenames:
        orig_filename = filename
        model_shard_state = {"shard_weights": [], "shard_metadata": []}
        assert num_shards > 0
        st = time.time()
        for shard_idx in range(num_shards):
            filename = get_maybe_sharded_checkpoint_filename(
                orig_filename, suffix, shard_idx, num_shards
            )
            if not PathManager.exists(filename):
                raise IOError("Model file not found: {}".format(filename))
            if state is None:
                state = load_checkpoint_to_cpu(filename, arg_overrides)
            if "args" in state and state["args"] is not None:
                cfg = convert_namespace_to_omegaconf(state["args"])
            elif "cfg" in state and state["cfg"] is not None:
                cfg = state["cfg"]
            else:
                raise RuntimeError(
                    f"Neither args nor cfg exist in state keys = {state.keys()}"
                )

            if task is None:
                task = tasks.setup_task(cfg.task)

            if "task_state" in state:
                task.load_state_dict(state["task_state"])

            if "fsdp_metadata" in state and num_shards > 1:
                model_shard_state["shard_weights"].append(state["model"])
                model_shard_state["shard_metadata"].append(state["fsdp_metadata"])
                # check FSDP import before the code goes too far
                if not has_FSDP:
                    raise ImportError(
                        "Cannot find FullyShardedDataParallel. "
                        "Please install fairscale with: pip install fairscale"
                    )
                if shard_idx == num_shards - 1:
                    consolidated_model_state = FSDP.consolidate_shard_weights(
                        shard_weights=model_shard_state["shard_weights"],
                        shard_metadata=model_shard_state["shard_metadata"],
                    )
                    model = task.build_model(cfg.model)
                    model.load_state_dict(
                        consolidated_model_state, strict=strict, model_cfg=cfg.model
                    )
            else:
                # model parallel checkpoint or unsharded checkpoint
                model = task.build_model(cfg.model)
                new_state_model = state["model"]
                # ===== Work-around for PyTorch's metadata loading/saving =====
                # If state["model"] contains "_metadata" as a dictionary key,
                # model.load_state_dict(strict=True) raises an error for the
                # unexpected "_metadata" key. To avoid this, the state dict
                # must be an OrderedDict that carries _metadata as an
                # attribute rather than as a key.
                # TODO yuansg@ This issue should ideally be fixed in PyTorch.
                if new_state_model.get("_metadata", None) is not None:
                    new_metadata = new_state_model.get("_metadata", None)
                    del state["model"]["_metadata"]
                else:
                    new_metadata = None
                # Construct state dict content.
                contents = OrderedDict(new_state_model)
                # Explicitly set _metadata on the state dict. The _metadata is
                # stored implicitly for PyTorch models; indexing state["model"]
                # in fairseq does not carry it over.
                if new_metadata is None:
                    logger.warning(
                        '===Jit: state["model"] does not contain key "_metadata"====='
                    )
                    logger.warning(
                        "===Jit: we will be filling in with the current model's metadata instead."
                    )
                    # For models trained before this diff, fall back to the freshly
                    # built model's metadata to stay backward compatible.
                    contents.__setattr__("_metadata", model.state_dict()._metadata)
                else:
                    contents.__setattr__("_metadata", new_metadata)
                # ===== End of work-around logic =====
                model.load_state_dict(contents, strict=strict, model_cfg=cfg.model)

            # reset state so it gets loaded for the next model in ensemble
            state = None
            if shard_idx % 10 == 0 and shard_idx > 0:
                elapsed = time.time() - st
                logger.info(
                    f"Loaded {shard_idx} shards in {elapsed:.2f}s, {elapsed / (shard_idx + 1):.2f}s/shard"
                )

        # build model for ensemble
        ensemble.append(model)
    return ensemble, cfg, task
def cli_main():
    parser = options.get_interactive_generation_parser()
    args = options.parse_args_and_arch(parser)
    distributed_utils.call_main(convert_namespace_to_omegaconf(args), main)
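# A hedged sketch (not from the original code) of driving the entry point
# above programmatically. The data and model paths are hypothetical; this
# mirrors what running `fairseq-interactive data-bin --path model.pt` does.
import sys

if __name__ == "__main__":
    sys.argv = [
        "fairseq-interactive",
        "/path/to/data-bin",             # hypothetical binarized data dir
        "--path", "/path/to/model.pt",   # hypothetical checkpoint
        "--beam", "5",
    ]
    cli_main()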
def _upgrade_state_dict(state):
    """Helper for upgrading old model checkpoints."""
    from fairseq import models, registry, tasks

    # add optimizer_history
    if "optimizer_history" not in state:
        state["optimizer_history"] = [
            {"criterion_name": "CrossEntropyCriterion", "best_loss": state["best_loss"]}
        ]
        state["last_optimizer_state"] = state["optimizer"]
        del state["optimizer"]
        del state["best_loss"]
    # move extra_state into sub-dictionary
    if "epoch" in state and "extra_state" not in state:
        state["extra_state"] = {
            "epoch": state["epoch"],
            "batch_offset": state["batch_offset"],
            "val_loss": state["val_loss"],
        }
        del state["epoch"]
        del state["batch_offset"]
        del state["val_loss"]
    # reduce optimizer history's memory usage (only keep the last state)
    if "optimizer" in state["optimizer_history"][-1]:
        state["last_optimizer_state"] = state["optimizer_history"][-1]["optimizer"]
        for optim_hist in state["optimizer_history"]:
            del optim_hist["optimizer"]
    # record the optimizer class name
    if "optimizer_name" not in state["optimizer_history"][-1]:
        state["optimizer_history"][-1]["optimizer_name"] = "FairseqNAG"
    # move best_loss into lr_scheduler_state
    if "lr_scheduler_state" not in state["optimizer_history"][-1]:
        state["optimizer_history"][-1]["lr_scheduler_state"] = {
            "best": state["optimizer_history"][-1]["best_loss"]
        }
        del state["optimizer_history"][-1]["best_loss"]
    # keep track of number of updates
    if "num_updates" not in state["optimizer_history"][-1]:
        state["optimizer_history"][-1]["num_updates"] = 0
    # use stateful training data iterator
    if "train_iterator" not in state["extra_state"]:
        state["extra_state"]["train_iterator"] = {
            "epoch": state["extra_state"]["epoch"],
            "iterations_in_epoch": state["extra_state"].get("batch_offset", 0),
        }
    # old model checkpoints may not have separate source/target positions
    # backward compatibility, cfg updates
    if "args" in state and state["args"] is not None:
        # default to translation task
        if not hasattr(state["args"], "task"):
            state["args"].task = "translation"
        # --raw-text and --lazy-load are deprecated
        if getattr(state["args"], "raw_text", False):
            state["args"].dataset_impl = "raw"
        elif getattr(state["args"], "lazy_load", False):
            state["args"].dataset_impl = "lazy"
        # epochs start at 1
        if state["extra_state"]["train_iterator"] is not None:
            state["extra_state"]["train_iterator"]["epoch"] = max(
                state["extra_state"]["train_iterator"].get("epoch", 1), 1
            )
        if hasattr(state["args"], "remove_bpe"):
            state["args"].post_process = state["args"].remove_bpe

        state["cfg"] = convert_namespace_to_omegaconf(state["args"])

    if "cfg" in state and state["cfg"] is not None:
        with open_dict(state["cfg"]):
            if state["cfg"].task is not None:
                if hasattr(state["cfg"].task, "max_positions") and not hasattr(
                    state["cfg"].task, "max_source_positions"
                ):
                    state["cfg"].task.max_source_positions = state["cfg"].task.max_positions
                    state["cfg"].task.max_target_positions = state["cfg"].task.max_positions

    return state
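# A hedged illustration (not from the original code) of what the helper above
# does to a very old checkpoint layout, assuming fairseq is importable. The
# state dict below is a fabricated minimal example, not a real checkpoint.
legacy_state = {
    "model": {},       # parameters omitted for brevity
    "optimizer": {},   # flat optimizer state (pre-"optimizer_history" era)
    "best_loss": 1.23,
    "epoch": 3,
    "batch_offset": 0,
    "val_loss": 2.34,
}
upgraded = _upgrade_state_dict(legacy_state)
assert "optimizer_history" in upgraded and "extra_state" in upgraded
assert upgraded["extra_state"]["train_iterator"]["epoch"] == 3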
def train_translation_model(
    data_dir,
    arch,
    extra_flags=None,
    task="translation",
    run_validation=False,
    lang_flags=None,
    extra_valid_flags=None,
    world_size=1,
):
    if lang_flags is None:
        lang_flags = [
            "--source-lang", "in",
            "--target-lang", "out",
        ]
    train_parser = options.get_training_parser()
    train_args = options.parse_args_and_arch(
        train_parser,
        [
            "--task", task,
            data_dir,
            "--save-dir", data_dir,
            "--arch", arch,
            "--optimizer", "nag",
            "--lr", "0.05",
            "--max-tokens", "500",
            "--max-epoch", "1",
            "--no-progress-bar",
            "--distributed-world-size", str(world_size),
            "--num-workers", "0",
        ]
        + lang_flags
        + (extra_flags or []),
    )
    cfg = convert_namespace_to_omegaconf(train_args)
    distributed_utils.call_main(cfg, train.main)

    if run_validation:
        # test validation
        validate_parser = options.get_validation_parser()
        validate_args = options.parse_args_and_arch(
            validate_parser,
            [
                "--task", task,
                data_dir,
                "--path", os.path.join(data_dir, "checkpoint_last.pt"),
                "--valid-subset", "valid",
                "--max-tokens", "500",
                "--no-progress-bar",
                "--num-workers", "0",
            ]
            + lang_flags
            + (extra_valid_flags or []),
        )
        validate.main(validate_args)
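# A hedged usage sketch (not from the original code): how a test might call
# the helper above on a tiny binarized dataset. The directory name and the
# create_dummy_data/preprocess_translation_data helpers are hypothetical
# stand-ins for whatever fixtures the surrounding test suite provides.
import tempfile

with tempfile.TemporaryDirectory("test_fconv") as data_dir:
    create_dummy_data(data_dir)               # assumed fixture: toy in/out corpus
    preprocess_translation_data(data_dir)     # assumed fixture: binarizes it
    train_translation_model(data_dir, "fconv_iwslt_de_en", run_validation=True)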
def __init__(self, cfg: Wav2BartPoolConfig, tgt_dict=None):
    self.apply_mask = cfg.apply_mask

    arg_overrides = {
        "dropout": cfg.dropout,
        "activation_dropout": cfg.activation_dropout,
        "dropout_input": cfg.dropout_input,
        "attention_dropout": cfg.attention_dropout,
        "mask_length": cfg.mask_length,
        "mask_prob": cfg.mask_prob,
        "mask_selection": cfg.mask_selection,
        "mask_other": cfg.mask_other,
        "no_mask_overlap": cfg.no_mask_overlap,
        "mask_channel_length": cfg.mask_channel_length,
        "mask_channel_prob": cfg.mask_channel_prob,
        "mask_channel_selection": cfg.mask_channel_selection,
        "mask_channel_other": cfg.mask_channel_other,
        "no_mask_channel_overlap": cfg.no_mask_channel_overlap,
        "encoder_layerdrop": cfg.layerdrop,
        "feature_grad_mult": cfg.feature_grad_mult,
    }

    if cfg.w2v_args is None:
        if os.path.isfile(cfg.w2v_path):
            print("load wav2vec from cfg path")
            state = checkpoint_utils.load_checkpoint_to_cpu(cfg.w2v_path, arg_overrides)
        else:
            print("load wav2vec from relative path")
            state = checkpoint_utils.load_checkpoint_to_cpu(
                "models/wav2vec_small.pt", arg_overrides
            )
        w2v_args = state.get("cfg", None)
        if w2v_args is None:
            w2v_args = convert_namespace_to_omegaconf(state["args"])
        cfg.w2v_args = w2v_args
    else:
        state = None
        w2v_args = cfg.w2v_args
        if isinstance(w2v_args, Namespace):
            cfg.w2v_args = w2v_args = convert_namespace_to_omegaconf(w2v_args)

    assert cfg.normalize == w2v_args.task.normalize, (
        "Fine-tuning works best when data normalization is the same. "
        "Please check that --normalize is set or unset for both pre-training and here"
    )

    w2v_args.task.data = cfg.data
    task = tasks.setup_task(w2v_args.task)
    model = task.build_model(w2v_args.model)

    if state is not None and not cfg.no_pretrained_weights:
        model.load_state_dict(state["model"], strict=True)

    model.remove_pretraining_modules()

    super().__init__(task.source_dictionary)

    d = w2v_args.model.encoder_embed_dim

    self.w2v_model = model
    self.final_dropout = nn.Dropout(cfg.final_dropout)
    self.freeze_finetune_updates = cfg.freeze_finetune_updates
    self.num_updates = 0
    self.pooling = nn.AvgPool1d(10, stride=10)

    if tgt_dict is not None:
        self.proj = Linear(d, len(tgt_dict))
    elif getattr(cfg, "decoder_embed_dim", d) != d:
        self.proj = Linear(d, cfg.decoder_embed_dim)
    else:
        self.proj = None
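# A hedged sketch (not from the original code) of the override mechanism the
# encoder above relies on: keys in arg_overrides replace the corresponding
# fields of the checkpoint's saved config when it is loaded to CPU. The
# checkpoint path is hypothetical.
from fairseq import checkpoint_utils
from fairseq.dataclass.utils import convert_namespace_to_omegaconf

w2v_state = checkpoint_utils.load_checkpoint_to_cpu(
    "/path/to/wav2vec_small.pt",   # hypothetical pre-trained checkpoint
    arg_overrides={"dropout": 0.1, "feature_grad_mult": 0.0},
)
w2v_cfg = w2v_state.get("cfg", None)
if w2v_cfg is None:
    w2v_cfg = convert_namespace_to_omegaconf(w2v_state["args"])
print(w2v_cfg.model.dropout)       # reflects the override, not the saved value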
def __init__(self, cfg: Wav2Vec2AsrConfig, output_size=None):
    self.apply_mask = cfg.apply_mask

    arg_overrides = {
        "dropout": cfg.dropout,
        "activation_dropout": cfg.activation_dropout,
        "dropout_input": cfg.dropout_input,
        "attention_dropout": cfg.attention_dropout,
        "mask_length": cfg.mask_length,
        "mask_prob": cfg.mask_prob,
        "mask_selection": cfg.mask_selection,
        "mask_other": cfg.mask_other,
        "no_mask_overlap": cfg.no_mask_overlap,
        "mask_channel_length": cfg.mask_channel_length,
        "mask_channel_prob": cfg.mask_channel_prob,
        "mask_channel_before": cfg.mask_channel_before,
        "mask_channel_selection": cfg.mask_channel_selection,
        "mask_channel_other": cfg.mask_channel_other,
        "no_mask_channel_overlap": cfg.no_mask_channel_overlap,
        "encoder_layerdrop": cfg.layerdrop,
        "feature_grad_mult": cfg.feature_grad_mult,
    }

    if cfg.w2v_args is None:
        state = checkpoint_utils.load_checkpoint_to_cpu(cfg.w2v_path, arg_overrides)
        w2v_args = state.get("cfg", None)
        if w2v_args is None:
            w2v_args = convert_namespace_to_omegaconf(state["args"])
        w2v_args.criterion = None
        w2v_args.lr_scheduler = None
        cfg.w2v_args = w2v_args
    else:
        state = None
        w2v_args = cfg.w2v_args
        if isinstance(w2v_args, Namespace):
            cfg.w2v_args = w2v_args = convert_namespace_to_omegaconf(w2v_args)

    assert cfg.normalize == w2v_args.task.normalize, (
        "Fine-tuning works best when data normalization is the same. "
        "Please check that --normalize is set or unset for both pre-training and here"
    )

    w2v_args.task.data = cfg.data
    task = tasks.setup_task(w2v_args.task)
    model = task.build_model(w2v_args.model)

    if state is not None and not cfg.no_pretrained_weights:
        model.load_state_dict(state["model"], strict=True)

    model.remove_pretraining_modules()

    super().__init__(task.source_dictionary)

    d = w2v_args.model.encoder_embed_dim

    self.w2v_model = model
    self.final_dropout = nn.Dropout(cfg.final_dropout)
    self.freeze_finetune_updates = cfg.freeze_finetune_updates
    self.num_updates = 0

    targ_d = None
    self.proj = None

    if output_size is not None:
        targ_d = output_size
    elif getattr(cfg, "decoder_embed_dim", d) != d:
        targ_d = cfg.decoder_embed_dim

    if targ_d is not None:
        self.proj = Linear(d, targ_d)
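# A hedged sketch (not from the original code) of how a wrapping CTC model
# typically instantiates an encoder like the one above. The class name
# Wav2Vec2Encoder and the use of the task's target dictionary follow fairseq's
# wav2vec 2.0 fine-tuning code but are assumptions here, not shown in the
# snippet itself.
@classmethod
def build_model(cls, cfg: Wav2Vec2AsrConfig, task):
    """Build a new model instance around the wav2vec 2.0 encoder."""
    encoder = Wav2Vec2Encoder(cfg, output_size=len(task.target_dictionary))
    return cls(cfg, encoder)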
def _upgrade_state_dict(state):
    """Helper for upgrading old model checkpoints."""
    from fairseq import models, registry, tasks

    # add optimizer_history
    if "optimizer_history" not in state:
        state["optimizer_history"] = [
            {"criterion_name": "CrossEntropyCriterion", "best_loss": state["best_loss"]}
        ]
        state["last_optimizer_state"] = state["optimizer"]
        del state["optimizer"]
        del state["best_loss"]
    # move extra_state into sub-dictionary
    if "epoch" in state and "extra_state" not in state:
        state["extra_state"] = {
            "epoch": state["epoch"],
            "batch_offset": state["batch_offset"],
            "val_loss": state["val_loss"],
        }
        del state["epoch"]
        del state["batch_offset"]
        del state["val_loss"]
    # reduce optimizer history's memory usage (only keep the last state)
    if "optimizer" in state["optimizer_history"][-1]:
        state["last_optimizer_state"] = state["optimizer_history"][-1]["optimizer"]
        for optim_hist in state["optimizer_history"]:
            del optim_hist["optimizer"]
    # record the optimizer class name
    if "optimizer_name" not in state["optimizer_history"][-1]:
        state["optimizer_history"][-1]["optimizer_name"] = "FairseqNAG"
    # move best_loss into lr_scheduler_state
    if "lr_scheduler_state" not in state["optimizer_history"][-1]:
        state["optimizer_history"][-1]["lr_scheduler_state"] = {
            "best": state["optimizer_history"][-1]["best_loss"]
        }
        del state["optimizer_history"][-1]["best_loss"]
    # keep track of number of updates
    if "num_updates" not in state["optimizer_history"][-1]:
        state["optimizer_history"][-1]["num_updates"] = 0
    # old model checkpoints may not have separate source/target positions
    if hasattr(state["args"], "max_positions") and not hasattr(
        state["args"], "max_source_positions"
    ):
        state["args"].max_source_positions = state["args"].max_positions
        state["args"].max_target_positions = state["args"].max_positions
    # use stateful training data iterator
    if "train_iterator" not in state["extra_state"]:
        state["extra_state"]["train_iterator"] = {
            "epoch": state["extra_state"]["epoch"],
            "iterations_in_epoch": state["extra_state"].get("batch_offset", 0),
        }
    # backward compatibility, cfg updates
    if "args" in state and state["args"] is not None:
        # default to translation task
        if not hasattr(state["args"], "task"):
            state["args"].task = "translation"
        # --raw-text and --lazy-load are deprecated
        if getattr(state["args"], "raw_text", False):
            state["args"].dataset_impl = "raw"
        elif getattr(state["args"], "lazy_load", False):
            state["args"].dataset_impl = "lazy"
        # epochs start at 1
        if state["extra_state"]["train_iterator"] is not None:
            state["extra_state"]["train_iterator"]["epoch"] = max(
                state["extra_state"]["train_iterator"].get("epoch", 1), 1
            )
        # --remove-bpe ==> --post-process
        if hasattr(state["args"], "remove_bpe"):
            state["args"].post_process = state["args"].remove_bpe
        # --min-lr ==> --stop-min-lr
        if hasattr(state["args"], "min_lr"):
            state["args"].stop_min_lr = state["args"].min_lr
            del state["args"].min_lr
        # binary_cross_entropy => wav2vec criterion
        if (
            hasattr(state["args"], "criterion")
            and state["args"].criterion == "binary_cross_entropy"
        ):
            state["args"].criterion = "wav2vec"
        # speech_pretraining => audio_pretraining
        if hasattr(state["args"], "task") and state["args"].task == "speech_pretraining":
            state["args"].task = "audio_pretraining"
        # audio_cpc => wav2vec
        if hasattr(state["args"], "arch") and state["args"].arch == "audio_cpc":
            state["args"].arch = "wav2vec"
        # convert legacy float learning rate to List[float]
        if hasattr(state["args"], "lr") and isinstance(state["args"].lr, float):
            state["args"].lr = [state["args"].lr]

        state["cfg"] = convert_namespace_to_omegaconf(state["args"])

    if "cfg" in state and state["cfg"] is not None:
        with open_dict(state["cfg"]):
            # any upgrades for Hydra-based configs
            pass

    return state