Code example #1
def save_state(
    filename,
    args,
    model_state_dict,
    criterion,
    optimizer,
    lr_scheduler,
    num_updates,
    optim_history=None,
    extra_state=None,
):
    from fairseq import utils

    if optim_history is None:
        optim_history = []
    if extra_state is None:
        extra_state = {}
    state_dict = {
        "args": args,
        "model": model_state_dict if model_state_dict else {},
        "optimizer_history": optim_history + [
            {
                "criterion_name": criterion.__class__.__name__,
                "optimizer_name": optimizer.__class__.__name__,
                "lr_scheduler_state": lr_scheduler.state_dict(),
                "num_updates": num_updates,
            }
        ],
        "extra_state": extra_state,
    }
    if utils.has_parameters(criterion):
        state_dict["criterion"] = criterion.state_dict()
    if not args.no_save_optimizer_state:
        state_dict["last_optimizer_state"] = convert_state_dict_type(
            optimizer.state_dict())

    with PathManager.open(filename, "wb") as f:
        torch_persistent_save(state_dict, f)
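This save_state excerpt appears to come from fairseq's checkpoint_utils.py, where torch_persistent_save, convert_state_dict_type, and PathManager are available at module level. As a minimal sketch of the same write path for a local file, assuming plain torch.save in place of fairseq's retry wrapper:

import torch
from fairseq.file_io import PathManager

# Minimal sketch: for local paths, PathManager.open behaves like the
# built-in open(), so the checkpoint is simply torch.save'd to disk.
def save_state_dict_locally(state_dict, filename):
    with PathManager.open(filename, "wb") as f:
        torch.save(state_dict, f)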
Code example #2
    def __init__(self, args):
        super().__init__(args)
        self.eos = DEFAULT_EOS
        self.gpu = getattr(args, "gpu", False)
        self.args = args
        self.load_model_vocab(args)

        if getattr(self.model.decoder.layers[0].encoder_attn,
                   'pre_decision_ratio', None) is not None:
            self.speech_segment_size *= (
                self.model.decoder.layers[0].encoder_attn.pre_decision_ratio)

        args.global_cmvn = None
        if args.config:
            with open(os.path.join(args.data_bin, args.config), "r") as f:
                config = yaml.load(f, Loader=yaml.BaseLoader)
            if "global_cmvn" in config:
                args.global_cmvn = np.load(
                    config["global_cmvn"]["stats_npz_path"])

        if args.global_stats:
            with PathManager.open(args.global_stats, "r") as f:
                global_cmvn = json.loads(f.read())
                self.global_cmvn = {
                    "mean": global_cmvn["mean"],
                    "std": global_cmvn["stddev"],
                }

        self.feature_extractor = OnlineFeatureExtractor(args)
        self.max_len = args.max_len
        self.force_finish = args.force_finish
        torch.set_grad_enabled(False)
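The YAML lookup in the middle of this constructor can be exercised on its own. A hedged sketch, where "config.yaml" and its stats_npz_path entry are illustrative names, not ones from the original:

import numpy as np
import yaml

# Read a config with BaseLoader (all scalars stay strings) and load the
# CMVN statistics archive it points to.
with open("config.yaml", "r") as f:
    config = yaml.load(f, Loader=yaml.BaseLoader)
if "global_cmvn" in config:
    stats = np.load(config["global_cmvn"]["stats_npz_path"])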
Code example #3
    def load(cls, f, f_non_lang_syms=None):
        """Loads the dictionary from a text file with the format:

        ```
        <symbol0> <count0>
        <symbol1> <count1>
        ...
        ```

        Identifies the space symbol, if present, by recording its index
        (space_index is -1 if there is no space symbol).

        Also loads non_lang_syms from a second text file, if given, with one
        symbol per line.
        """
        d = super().load(f)
        d.space_index = d.indices.get(d.space_word, -1)

        if f_non_lang_syms is not None:
            assert isinstance(f_non_lang_syms, str)
            try:
                with PathManager.open(f_non_lang_syms, "r", encoding="utf-8") as fd:
                    non_lang_syms = [x.rstrip() for x in fd.readlines()]
            except UnicodeError:
                raise Exception(
                    "Incorrect encoding detected in {}, please "
                    "rebuild the dataset".format(f_non_lang_syms)
                )

            for sym in non_lang_syms:
                assert d.index(sym) != d.unk(), \
                    "{} in {} is not in the dictionary".format(sym, f_non_lang_syms)
            d.non_lang_syms = non_lang_syms

        return d
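The on-disk format this loader expects can be reproduced with fairseq's stock Dictionary. A hedged usage sketch; "dict.txt" is an illustrative name:

from fairseq.data import Dictionary

# Each line holds "<symbol> <count>".
with open("dict.txt", "w", encoding="utf-8") as f:
    f.write("hello 10\nworld 7\n")
d = Dictionary.load("dict.txt")
print(len(d))  # 4 special symbols (<s>, <pad>, </s>, <unk>) plus 2 loaded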
Code example #4
def save_state(
    filename,
    cfg: FairseqConfig,
    model_state_dict,
    criterion,
    optimizer,
    lr_scheduler,
    num_updates,
    optim_history=None,
    extra_state=None,
    **kwargs,
):
    from fairseq import utils

    if optim_history is None:
        optim_history = []
    if extra_state is None:
        extra_state = {}
    state_dict = {
        "cfg": cfg,
        "args": kwargs.get("args", None),
        "model": model_state_dict or {},
        "optimizer_history": optim_history + [
            {
                "criterion_name": criterion.__class__.__name__,
                "optimizer_name": optimizer.__class__.__name__,
                "lr_scheduler_state": lr_scheduler.state_dict(),
                "num_updates": num_updates,
            }
        ],
        "extra_state": extra_state,
    }
    if utils.has_parameters(criterion):
        state_dict["criterion"] = criterion.state_dict()

    if cfg is None:
        cfg = state_dict["args"]
        assert cfg is not None, "must provide cfg or args"

    if isinstance(cfg, DictConfig):
        no_save_optimizer_state = cfg.checkpoint.no_save_optimizer_state
    else:
        no_save_optimizer_state = cfg.no_save_optimizer_state
    if not no_save_optimizer_state:
        state_dict["last_optimizer_state"] = optimizer.state_dict()

    # keep everything on CPU
    state_dict = utils.move_to_cpu(state_dict)

    if PathManager.supports_rename(filename):
        # do atomic save
        with PathManager.open(filename + ".tmp", "wb") as f:
            torch_persistent_save(state_dict, f)
        PathManager.rename(filename + ".tmp", filename)
    else:
        # fallback to non-atomic save
        with PathManager.open(filename, "wb") as f:
            torch_persistent_save(state_dict, f)
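The rename step above is the standard atomic-write pattern: write to a temporary name, then rename over the target. A local-filesystem sketch of the same idea, assuming os.replace (atomic when both paths live on the same filesystem):

import os
import torch

def atomic_torch_save(obj, filename):
    # Readers never observe a half-written file: the final name only
    # appears once the temporary file is complete.
    tmp = filename + ".tmp"
    with open(tmp, "wb") as f:
        torch.save(obj, f)
    os.replace(tmp, filename)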
Code example #5
def main():
    parser = argparse.ArgumentParser(
        description="Tool to average the params of input checkpoints to "
        "produce a new checkpoint")
    # fmt: off
    parser.add_argument('--inputs',
                        required=True,
                        nargs='+',
                        help='Input checkpoint file paths.')
    parser.add_argument('--output',
                        required=True,
                        metavar='FILE',
                        help='Write the new checkpoint containing the averaged weights to this path.')
    num_group = parser.add_mutually_exclusive_group()
    num_group.add_argument(
        '--num-epoch-checkpoints',
        type=int,
        help='if set, will try to find checkpoints with names checkpoint_xx.pt in the path specified by input, '
        'and average last this many of them.')
    num_group.add_argument(
        '--num-update-checkpoints',
        type=int,
        help='if set, will try to find checkpoints with names checkpoint.jj_ee_xx.pt in the path specified by input, '
        'and average last this many of them.')
    parser.add_argument(
        '--checkpoint-upper-bound',
        type=int,
        help='when using --num-epoch-checkpoints, this will set an upper bound on which epoch to use; '
        'when using --num-update-checkpoints, this will set an upper bound on which update to use. '
        'E.g., with --num-epoch-checkpoints=10 --checkpoint-upper-bound=50, checkpoints 41-50 would be averaged. '
        'E.g., with --num-update-checkpoints=10 --checkpoint-upper-bound=50000, checkpoints 40500-50000 would be averaged assuming --save-interval-updates 500.')
    # fmt: on
    args = parser.parse_args()
    print(args)

    num = None
    is_update_based = False
    if args.num_update_checkpoints is not None:
        num = args.num_update_checkpoints
        is_update_based = True
    elif args.num_epoch_checkpoints is not None:
        num = args.num_epoch_checkpoints

    assert args.checkpoint_upper_bound is None or (
        args.num_epoch_checkpoints is not None
        or args.num_update_checkpoints is not None
    ), "--checkpoint-upper-bound requires --num-epoch-checkpoints or --num-update-checkpoints"
    assert (
        args.num_epoch_checkpoints is None
        or args.num_update_checkpoints is None
    ), "Cannot combine --num-epoch-checkpoints and --num-update-checkpoints"

    if num is not None:
        args.inputs = last_n_checkpoints(
            args.inputs,
            num,
            is_update_based,
            upper_bound=args.checkpoint_upper_bound,
        )
        print("averaging checkpoints: ", args.inputs)

    new_state = average_checkpoints(args.inputs)
    with PathManager.open(args.output, "wb") as f:
        torch.save(new_state, f)
    print("Finished writing averaged checkpoint to {}".format(args.output))
Code example #6
    def test_file_io(self):
        from fairseq.file_io import PathManager
        with PathManager.open(os.path.join(self._tmpdir, "test.txt"), "r") as f:
            s = f.read()
        self.assertEqual(s, self._tmpfile_contents)
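The fixture this test reads (_tmpdir and _tmpfile_contents) is not shown. A hedged sketch of a setUp that would satisfy it; the body is an assumption, only the attribute names come from the test:

import os
import tempfile
import unittest

class TestFileIO(unittest.TestCase):
    def setUp(self):
        # Create the temp directory and the file the test reads back.
        self._tmpdir = tempfile.mkdtemp()
        self._tmpfile_contents = "hello\n"
        with open(os.path.join(self._tmpdir, "test.txt"), "w") as f:
            f.write(self._tmpfile_contents)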
Code example #7
    def add_from_file(self,
                      f,
                      tgt_first=False,
                      bos_id_tgt=None,
                      pad_id_tgt=None,
                      eos_id_tgt=None,
                      unk_id_tgt=None):
        """
        Loads a pre-existing dictionary from a text file and adds its symbols
        to this instance.
        """
        if isinstance(f, str):
            try:
                with PathManager.open(f, "r", encoding="utf-8") as fd:
                    self.add_from_file(fd,
                                       tgt_first=tgt_first,
                                       bos_id_tgt=bos_id_tgt,
                                       pad_id_tgt=pad_id_tgt,
                                       eos_id_tgt=eos_id_tgt,
                                       unk_id_tgt=unk_id_tgt)
            except FileNotFoundError as fnfe:
                raise fnfe
            except UnicodeError:
                raise Exception("Incorrect encoding detected in {}, please "
                                "rebuild the dataset".format(f))
            return

        lines = f.readlines()
        indices_start_line = self._load_meta(lines)

        if tgt_first:
            i = 0
            for line in lines[indices_start_line:]:
                while i in [bos_id_tgt, pad_id_tgt, eos_id_tgt, unk_id_tgt]:
                    if i == bos_id_tgt:
                        self.add_symbol(self.bos_word)
                        i += 1
                    elif i == pad_id_tgt:
                        self.add_symbol(self.pad_word)
                        i += 1
                    elif i == eos_id_tgt:
                        self.add_symbol(self.eos_word)
                        i += 1
                    elif i == unk_id_tgt:
                        self.add_symbol(self.unk_word)
                        i += 1
                try:
                    line, field = line.rstrip().rsplit(" ", 1)
                    if field == "#fairseq:overwrite":
                        overwrite = True
                        line, field = line.rsplit(" ", 1)
                    else:
                        overwrite = False
                    count = int(field)
                    word = line
                    # if word in self and not overwrite:
                    #     raise RuntimeError(
                    #         "Duplicate word found when loading Dictionary: '{}'. "
                    #         "Duplicate words can overwrite earlier ones by adding the "
                    #         "#fairseq:overwrite flag at the end of the corresponding row "
                    #         "in the dictionary file. If using the Camembert model, please "
                    #         "download an updated copy of the model file."
                    #         .format(word)
                    #     )
                    self.add_symbol(word, n=count, overwrite=overwrite)
                    i += 1
                except ValueError:
                    raise ValueError(
                        "Incorrect dictionary format, expected '<token> <cnt> [flags]'"
                    )
        else:
            for line in lines[indices_start_line:]:
                try:
                    line, field = line.rstrip().rsplit(" ", 1)
                    if field == "#fairseq:overwrite":
                        overwrite = True
                        line, field = line.rsplit(" ", 1)
                    else:
                        overwrite = False
                    count = int(field)
                    word = line
                    if word in self and not overwrite:
                        raise RuntimeError(
                            "Duplicate word found when loading Dictionary: '{}'. "
                            "Duplicate words can overwrite earlier ones by adding the "
                            "#fairseq:overwrite flag at the end of the corresponding row "
                            "in the dictionary file. If using the Camembert model, please "
                            "download an updated copy of the model file.".format(word)
                        )
                    self.add_symbol(word, n=count, overwrite=overwrite)
                except ValueError:
                    raise ValueError(
                        "Incorrect dictionary format, expected '<token> <cnt> [flags]'"
                    )
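Without the target-ID remapping, the same behaviour is available on fairseq's stock Dictionary, whose add_from_file also accepts either a path or an open file object. A hedged usage sketch; "dict.txt" is illustrative:

from fairseq.data import Dictionary

# Each line of dict.txt reads "<token> <count>".
d = Dictionary()
d.add_from_file("dict.txt")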
Code example #8
def main():
    parser = argparse.ArgumentParser(
        description='Tool to average the params of input checkpoints to '
        'produce a new checkpoint')
    # fmt: off
    parser.add_argument('--input-directory',
                        type=str,
                        required=True,
                        help='Input directory containing model checkpoints.')
    parser.add_argument('--output',
                        default=None,
                        type=str,
                        help='Write the new checkpoint containing the averaged weights to this path.')
    num_group = parser.add_mutually_exclusive_group()
    num_group.add_argument(
        '--num-epoch-checkpoints',
        type=int,
        help='if set, will try to find checkpoints with names checkpoint_xx.pt in the path specified by input, '
        'and average last this many of them.')
    num_group.add_argument(
        '--num-update-checkpoints',
        type=int,
        help='if set, will try to find checkpoints with names checkpoint_ee_xx.pt in the path specified by input, '
        'and average last this many of them.')
    parser.add_argument(
        '--checkpoint-upper-bound',
        type=int,
        help='when using --num-epoch-checkpoints, this will set an upper bound on which epoch to use; '
        'when using --num-update-checkpoints, this will set an upper bound on which update to use. '
        'E.g., with --num-epoch-checkpoints=10 --checkpoint-upper-bound=50, checkpoints 41-50 would be averaged. '
        'E.g., with --num-update-checkpoints=10 --checkpoint-upper-bound=50000, checkpoints 40500-50000 would be averaged assuming --save-interval-updates 500.')
    # fmt: on
    args = parser.parse_args()
    print(args)

    num = None
    is_update_based = False
    if args.num_update_checkpoints is not None:
        num = args.num_update_checkpoints
        is_update_based = True
    elif args.num_epoch_checkpoints is not None:
        num = args.num_epoch_checkpoints

    assert args.checkpoint_upper_bound is None or (args.num_epoch_checkpoints is not None or args.num_update_checkpoints is not None), \
        '--checkpoint-upper-bound requires --num-epoch-checkpoints or --num-update-checkpoints'
    assert args.num_epoch_checkpoints is None or args.num_update_checkpoints is None, \
        'Cannot combine --num-epoch-checkpoints and --num-update-checkpoints'

    if num is not None:
        args.inputs = last_n_checkpoints(
            args.input_directory,
            num,
            is_update_based,
            upper_bound=args.checkpoint_upper_bound,
        )
        print('averaging checkpoints: ', args.inputs)

    if args.output is None and args.num_epoch_checkpoints is not None:
        range_checkpoints = sorted([
            int(re.search(r".*checkpoint(.*)\.pt$", input_file).groups()[0])
            for input_file in args.inputs
        ])
        start = range_checkpoints[0]
        end = range_checkpoints[-1]
        args.output = os.path.join(
            args.input_directory, "checkpoint_average_%s_%s.pt" % (start, end))

    assert args.output is not None, "No output file specified, nor could it be deduced"
    new_state = average_checkpoints(args.inputs)
    with PathManager.open(args.output, 'wb') as f:
        torch.save(new_state, f)
    print('Finished writing averaged checkpoint to {}'.format(args.output))
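The output-name deduction above can be checked in isolation. A hedged sketch with illustrative file names:

import re

inputs = ["ckpts/checkpoint10.pt", "ckpts/checkpoint12.pt"]
# Pull the epoch number out of each name and take the extremes.
nums = sorted(
    int(re.search(r".*checkpoint(.*)\.pt$", p).groups()[0]) for p in inputs
)
print("checkpoint_average_%s_%s.pt" % (nums[0], nums[-1]))
# -> checkpoint_average_10_12.pt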
Code example #9
File: fb_dictionary.py, Project: stas00/deep-shallow
def add_file_to_dictionary(filename, dict, tokenize):
    # Note: the "dict" parameter shadows the built-in; kept to match the source.
    with PathManager.open(filename, "r", encoding="utf-8") as f:
        for line in f:
            for word in tokenize(line):
                dict.add_symbol(word)
            dict.add_symbol(dict.eos_word)  # count one EOS per line
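A hedged usage sketch, assuming fairseq's Dictionary and its default whitespace tokenizer tokenize_line; "corpus.txt" is an illustrative path:

from fairseq.data import Dictionary
from fairseq.tokenizer import tokenize_line

d = Dictionary()
add_file_to_dictionary("corpus.txt", d, tokenize_line)
d.finalize()  # sort symbols by frequency, padding the vocab to a multiple of 8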
Code example #10
def main(args):
    state = checkpoint_utils.load_checkpoint_to_cpu(args.checkpoint)
    ns = state["args"]
    model = state["model"]
    ns.arch = "transformer_modular"

    if (args.encoder_attention_heads_active is None
            and args.decoder_attention_heads_active is None):
        raise ValueError(
            'Either --encoder-attention-heads-active or '
            '--decoder-attention-heads-active option must be set.')
    if args.encoder_attention_heads_active is None:
        args.encoder_attention_heads_active = args.decoder_attention_heads_active

    if args.encoder_modular_layer_indices is not None:
        ns.encoder_modular_layer_indices = "({})".format(
            args.encoder_modular_layer_indices)
        model = convert_model(model, ns, coder="encoder", att_type="self_attn")
    if args.decoder_modular_layer_indices is not None:
        ns.decoder_modular_layer_indices = "({})".format(
            args.decoder_modular_layer_indices)
        model = convert_model(model, ns, coder="decoder", att_type="self_attn")
        model = convert_model(model,
                              ns,
                              coder="decoder",
                              att_type="encoder_attn")

    ctrl_enc = ModularCtrl(ns.encoder_embed_dim,
                           ns.encoder_attention_heads,
                           args.encoder_attention_heads_active,
                           hidden_depth=args.ctrl_hidden_depth,
                           hidden_dim=args.ctrl_hidden_dim,
                           ctrl_type=args.ctrl_type)
    ns.module_ctrl_hidden_depth = args.ctrl_hidden_depth
    ns.module_ctrl_hidden_dim = args.ctrl_hidden_dim
    ns.module_ctrl_type = args.ctrl_type

    for k, v in ctrl_enc.state_dict().items():
        model["encoder.module_ctrl.{}".format(k)] = v

    if not args.share_encoder_ctrl:
        if args.decoder_attention_heads_active is None:
            raise ValueError("Missing ``decoder-attention-heads-active'' "
                             "when ``share-encoder-ctrl'' is disabled.")
        ns.share_encoder_ctrl = False
        ctrl_dec = ModularCtrl(ns.decoder_embed_dim,
                               ns.decoder_attention_heads,
                               args.decoder_attention_heads_active,
                               hidden_depth=args.ctrl_hidden_depth,
                               hidden_dim=args.ctrl_hidden_dim,
                               ctrl_type=args.ctrl_type)
        for k, v in ctrl_dec.state_dict().items():
            model["decoder.module_ctrl.{}".format(k)] = v
    else:
        ns.share_encoder_ctrl = True

    ns.arch = "transformer_modular"
    ns.criterion = "label_smoothed_cross_entropy_modular"
    ns.task = "translation_modular"
    ns.encoder_attention_heads_active = args.encoder_attention_heads_active

    state["args"] = ns
    state["model"] = model

    for entry in state["optimizer_history"]:
        entry["criterion_name"] = 'LabelSmoothedCrossEntropyModularCriterion'

    state = utils.move_to_cpu(state)

    with PathManager.open(args.save_as, "wb") as f:
        checkpoint_utils.torch_persistent_save(state, f)
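To verify the conversion, the saved file can be loaded back with the same helper used at the top of main(). A hedged sketch; "converted.pt" stands in for args.save_as:

from fairseq import checkpoint_utils

state = checkpoint_utils.load_checkpoint_to_cpu("converted.pt")
assert state["args"].arch == "transformer_modular"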