Code example #1
File: sp_task.py Project: jxhe/sparse-text-prototype
    @classmethod
    def setup_task(cls, args, **kwargs):
        """Setup the task (e.g., load dictionaries).

        Args:
            args (argparse.Namespace): parsed command-line arguments
        """
        args.left_pad_source = options.eval_bool(args.left_pad_source)
        args.left_pad_target = options.eval_bool(args.left_pad_target)

        paths = args.data.split(os.pathsep)
        assert len(paths) > 0
        # find language pair automatically
        # if args.source_lang is None or args.target_lang is None:
        #     args.source_lang, args.target_lang = data_utils.infer_language_pair(paths[0])
        # if args.source_lang is None or args.target_lang is None:
        #     raise Exception('Could not infer language pair, please provide it explicitly')

        # load dictionaries
        src_dict = cls.load_dictionary(os.path.join(paths[0], 'dict.txt'))

        print('| [{}] dictionary: {} types'.format(args.source_lang,
                                                   len(src_dict)))

        if args.inv_editor == 'levenshtein':
            edit_dict = RetrievePrototypeDataset.get_edit_dict()
        else:
            edit_dict = None

        if edit_dict is not None:
            print('| [edit] dictionary: {} types'.format(len(edit_dict)))

        return cls(args, src_dict, edit_dict)
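How these classmethods get called: every snippet in this list is registered with fairseq's task registry and reached through a dispatch function. Below is a minimal, standalone sketch of that registry pattern with hypothetical names (TASK_REGISTRY, ToyTask); fairseq's real registry lives in fairseq/tasks/__init__.py.

    # Minimal standalone sketch of the task-registry pattern (hypothetical names).
    TASK_REGISTRY = {}

    def register_task(name):
        def wrapper(cls):
            TASK_REGISTRY[name] = cls
            return cls
        return wrapper

    @register_task('toy_task')
    class ToyTask:
        def __init__(self, args, src_dict):
            self.args, self.src_dict = args, src_dict

        @classmethod
        def setup_task(cls, args, **kwargs):
            # real implementations load dictionaries here; stubbed for brevity
            return cls(args, src_dict=None)

    def setup_task(args):
        # dispatch on args.task, as fairseq's train.py does via tasks.setup_task
        return TASK_REGISTRY[args.task].setup_task(args)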
Code example #2
    @classmethod
    def setup_task(cls, args, **kwargs):
        """Setup the task (e.g., load dictionaries).

        Args:
            args (argparse.Namespace): parsed command-line arguments
        """
        args.left_pad_source = options.eval_bool(args.left_pad_source)
        args.left_pad_target = options.eval_bool(args.left_pad_target)

        # find language pair automatically
        if args.source_lang is None or args.target_lang is None:
            args.source_lang, args.target_lang = data_utils.infer_language_pair(args.data[0])
        if args.source_lang is None or args.target_lang is None:
            raise Exception('Could not infer language pair, please provide it explicitly')

        # load dictionaries
        src_dict = BertDictionary.load(os.path.join(args.data[0], 'dict.{}.txt'.format(args.source_lang)))
        src_eos_idx = src_dict.add_special_token('[END_OF_SENT]')
        print('src_dict:[END_OF_SENT] id = {}, token = {}'.format(src_eos_idx, src_dict[src_eos_idx]))

        tgt_dict = BertDictionary.load(os.path.join(args.data[0], 'dict.{}.txt'.format(args.target_lang)))
        tgt_eos_idx = tgt_dict.add_special_token('[END_OF_SENT]')
        print('tgt_dict:[END_OF_SENT] id = {}, token = {}'.format(tgt_eos_idx, tgt_dict[tgt_eos_idx]))

        assert src_dict.pad() == tgt_dict.pad()
        assert src_dict.eos() == tgt_dict.eos()
        assert src_dict.unk() == tgt_dict.unk()
        assert src_dict.sep() == tgt_dict.sep()

        print('| [{}] dictionary: {} types'.format(args.source_lang, len(src_dict)))
        print('| [{}] dictionary: {} types'.format(args.target_lang, len(tgt_dict)))

        return cls(args, src_dict, tgt_dict)
Code example #3
    @classmethod
    def setup_task(cls, args, **kwargs):
        """Setup the task (e.g., load dictionaries).

        Args:
            args (argparse.Namespace): parsed command-line arguments
        """
        args.left_pad_source = options.eval_bool(args.left_pad_source)
        args.left_pad_target = options.eval_bool(args.left_pad_target)

        assert not args.left_pad_source, 'args.left_pad_source must be False'

        # find language pair automatically
        if args.source_lang is None or args.target_lang is None:
            args.source_lang, args.target_lang = data_utils.infer_language_pair(args.data[0])
        if args.source_lang is None or args.target_lang is None:
            raise Exception('Could not infer language pair, please provide it explicitly')

        # load dictionaries
        args.no_strip_node_label = getattr(args, 'no_strip_node_label', False)
        src_dict = DPTreeWrapperDictionary.load(
            os.path.join(args.data[0], 'dict.{}.txt'.format(args.source_lang)),
            no_strip_node_label=args.no_strip_node_label)
        tgt_dict = Dictionary.load(os.path.join(args.data[0], 'dict.{}.txt'.format(args.target_lang)))
        assert src_dict.pad() == tgt_dict.pad()
        assert src_dict.eos() == tgt_dict.eos()
        assert src_dict.unk() == tgt_dict.unk()
        print('| [{}] DPtree-dictionary: {} types'.format(args.source_lang, len(src_dict)))
        print('| [{}] dictionary: {} types'.format(args.target_lang, len(tgt_dict)))

        return cls(args, src_dict, tgt_dict)
Code example #4
    @classmethod
    def setup_task(cls, args, **kwargs):
        """Setup the task (e.g., load dictionaries).

        Args:
            args (argparse.Namespace): parsed command-line arguments
        """
        # left_pad_source=True, left_pad_target=False
        args.left_pad_source = options.eval_bool(args.left_pad_source)
        args.left_pad_target = options.eval_bool(args.left_pad_target)

        # e.g., source_lang=cn, target_lang=en
        if args.source_lang is None or args.target_lang is None:
            args.source_lang, args.target_lang = data_utils.infer_language_pair(
                args.data[0])
        if args.source_lang is None or args.target_lang is None:
            raise Exception(
                'Could not infer language pair, please provide it explicitly')

        # load the dictionary files
        src_dict = cls.load_dictionary(
            os.path.join(args.data[0], 'dict.{}.txt'.format(args.source_lang)))
        tgt_dict = cls.load_dictionary(
            os.path.join(args.data[0], 'dict.{}.txt'.format(args.target_lang)))
        assert src_dict.pad() == tgt_dict.pad()
        assert src_dict.eos() == tgt_dict.eos()
        assert src_dict.unk() == tgt_dict.unk()
        print('| [{}] dictionary: {} characters'.format(args.source_lang,
                                                        len(src_dict)))
        print('| [{}] dictionary: {} characters'.format(args.target_lang,
                                                        len(tgt_dict)))

        return cls(args, src_dict, tgt_dict)
Code example #5
    @classmethod
    def setup_task(cls, args, **kwargs):
        """Setup the task (e.g., load dictionaries).

        Args:
            args (argparse.Namespace): parsed command-line arguments
        """
        args.left_pad_source = options.eval_bool(args.left_pad_source)
        args.left_pad_target = options.eval_bool(args.left_pad_target)

        paths = utils.split_paths(args.data)
        assert len(paths) > 0
        # find language pair automatically
        if args.source_lang is None or args.target_lang is None:
            args.source_lang, args.target_lang = data_utils.infer_language_pair(paths[0])
        if args.source_lang is None or args.target_lang is None:
            raise Exception('Could not infer language pair, please provide it explicitly')

        # load dictionaries
        src_dict = cls.load_dictionary(os.path.join(paths[0], 'dict.src.txt'))
        tgt_dict = cls.load_dictionary(os.path.join(paths[0], 'dict.tgt.txt'))
        assert src_dict.pad() == tgt_dict.pad()
        assert src_dict.eos() == tgt_dict.eos()
        assert src_dict.unk() == tgt_dict.unk()
        logger.info('[{}] dictionary: {} types'.format(args.source_lang, len(src_dict)))
        logger.info('[{}] dictionary: {} types'.format(args.target_lang, len(tgt_dict)))

        return cls(args, src_dict, tgt_dict)
Code example #6
    @classmethod
    def setup_task(cls, args, **kwargs):
        """Setup the task (e.g., load dictionaries).

        Args:
            args (argparse.Namespace): parsed command-line arguments
        """
        args.left_pad_source = options.eval_bool(args.left_pad_source)
        args.left_pad_target = options.eval_bool(args.left_pad_target)

        # find language pair automatically
        # if args.source_lang is None or args.target_lang is None:
        #     args.source_lang, args.target_lang = data_utils.infer_language_pair(args.data[0])
        # if args.source_lang is None or args.target_lang is None:
        #     raise Exception('Could not infer language pair, please provide it explicitly')

        # load dictionaries
        if args.share_all_embeddings:
            src_dict = Dictionary.load(os.path.join(args.data[0], 'dict.txt'))
            tgt_dict = Dictionary.load(os.path.join(args.data[0], 'dict.txt'))
        else:
            src_dict = Dictionary.load(os.path.join(args.data[0], 'dict.txt'))
            tgt_dict = Dictionary.load(
                os.path.join(args.data[0], 'dict.tgt.txt'))

        assert src_dict.pad() == tgt_dict.pad()
        assert src_dict.eos() == tgt_dict.eos()
        assert src_dict.unk() == tgt_dict.unk()
        print('| [{}] dictionary: {} types'.format('src', len(src_dict)))
        print('| [{}] dictionary: {} types'.format('tgt', len(tgt_dict)))

        return cls(args, src_dict, tgt_dict)
Code example #7
File: translation.py Project: fyabc/fairseq
    @classmethod
    def setup_task(cls, args, **kwargs):
        """Setup the task (e.g., load dictionaries).

        Args:
            args (argparse.Namespace): parsed command-line arguments
        """
        args.left_pad_source = options.eval_bool(args.left_pad_source)
        args.left_pad_target = options.eval_bool(args.left_pad_target)

        # find language pair automatically
        if args.source_lang is None or args.target_lang is None:
            args.source_lang, args.target_lang = data_utils.infer_language_pair(args.data)
        if args.source_lang is None or args.target_lang is None:
            raise Exception('Could not infer language pair, please provide it explicitly')

        # load dictionaries
        src_dict = Dictionary.load(os.path.join(args.data, 'dict.{}.txt'.format(args.source_lang)))
        tgt_dict = Dictionary.load(os.path.join(args.data, 'dict.{}.txt'.format(args.target_lang)))
        assert src_dict.pad() == tgt_dict.pad()
        assert src_dict.eos() == tgt_dict.eos()
        assert src_dict.unk() == tgt_dict.unk()
        print('| [{}] dictionary: {} types'.format(args.source_lang, len(src_dict)))
        print('| [{}] dictionary: {} types'.format(args.target_lang, len(tgt_dict)))

        return cls(args, src_dict, tgt_dict)
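For context, a typical call site for the snippet above, sketched against an older fairseq CLI flow (assumes a fairseq checkout; an illustration, not the verbatim train.py):

    # Hedged sketch of how fairseq's train.py reaches TranslationTask.setup_task.
    from fairseq import options, tasks

    parser = options.get_training_parser()
    args = options.parse_args_and_arch(parser)
    task = tasks.setup_task(args)   # dispatches on --task to setup_task above
    task.load_dataset('train')      # consumes the dictionaries loaded there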
Code example #8
    @classmethod
    def setup_task(cls, args, **kwargs):
        args.left_pad_source = options.eval_bool(args.left_pad_source)
        args.left_pad_target = options.eval_bool(args.left_pad_target)

        if args.source_lang is not None or args.target_lang is not None:
            if args.lang_pairs is not None:
                raise ValueError(
                    '--source-lang/--target-lang implies generation, which is '
                    'incompatible with --lang-pairs')
            training = False
            args.lang_pairs = [
                '{}-{}'.format(args.source_lang, args.target_lang)
            ]
        else:
            training = True
            args.lang_pairs = args.lang_pairs.split(',')
            args.source_lang, args.target_lang = args.lang_pairs[0].split('-')

        langs = list(
            {x
             for lang_pair in args.lang_pairs for x in lang_pair.split('-')})

        # load dictionaries
        dicts = OrderedDict()
        for lang in langs:
            dicts[lang] = Dictionary.load(
                os.path.join(args.data, 'dict.{}.txt'.format(lang)))
            if len(dicts) > 0:
                assert dicts[lang].pad() == dicts[langs[0]].pad()
                assert dicts[lang].eos() == dicts[langs[0]].eos()
                assert dicts[lang].unk() == dicts[langs[0]].unk()
            print('| [{}] dictionary: {} types'.format(lang, len(dicts[lang])))

        return cls(args, dicts, training)
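The --lang-pairs bookkeeping above is easy to verify in isolation; a self-contained demo (pure Python, no fairseq needed):

    import argparse

    # Standalone demo of the lang-pair parsing used by the multilingual tasks.
    args = argparse.Namespace(lang_pairs='de-en,fr-en,en-de',
                              source_lang=None, target_lang=None)
    args.lang_pairs = args.lang_pairs.split(',')
    args.source_lang, args.target_lang = args.lang_pairs[0].split('-')
    langs = list({x for lang_pair in args.lang_pairs for x in lang_pair.split('-')})
    print(args.source_lang, args.target_lang)  # de en
    print(sorted(langs))                       # ['de', 'en', 'fr'] -- one dict per language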
Code example #9
    @classmethod
    def prepare(cls, args, **kargs):
        args.left_pad_source = options.eval_bool(args.left_pad_source)
        args.left_pad_target = options.eval_bool(args.left_pad_target)

        args.lang_pairs = args.lang_pairs.split(',')
        sorted_langs = sorted(list({x for lang_pair in args.lang_pairs for x in lang_pair.split('-')}))
        if args.source_lang is not None or args.target_lang is not None:
            training = False
            args.lang_pairs = ['{}-{}'.format(args.source_lang, args.target_lang)]
        else:
            training = True
            args.source_lang, args.target_lang = args.lang_pairs[0].split('-')

        # load dictionaries
        dicts = OrderedDict()
        for lang in sorted_langs:
            dicts[lang] = Dictionary.load(os.path.join(args.data, 'dict.{}.txt'.format(lang)))
            if len(dicts) > 0:
                assert dicts[lang].pad() == dicts[sorted_langs[0]].pad()
                assert dicts[lang].eos() == dicts[sorted_langs[0]].eos()
                assert dicts[lang].unk() == dicts[sorted_langs[0]].unk()
            if args.encoder_langtok is not None or args.decoder_langtok:
                for lang_to_add in sorted_langs:
                    dicts[lang].add_symbol(_lang_token(lang_to_add))
            print('| [{}] dictionary: {} types'.format(lang, len(dicts[lang])))
        return dicts, training
Code example #10
    @classmethod
    def prepare(cls, args, **kargs):
        args.left_pad_source = options.eval_bool(args.left_pad_source)
        args.left_pad_target = options.eval_bool(args.left_pad_target)
        if getattr(args, 'raw_text', False):
            utils.deprecation_warning('--raw-text is deprecated, please use --dataset-impl=raw')
            args.dataset_impl = 'raw'
        elif getattr(args, 'lazy_load', False):
            utils.deprecation_warning('--lazy-load is deprecated, please use --dataset-impl=lazy')
            args.dataset_impl = 'lazy'

        args.lang_pairs = args.lang_pairs.split(',')
        sorted_langs = sorted(list({x for lang_pair in args.lang_pairs for x in lang_pair.split('-')}))
        if args.source_lang is not None or args.target_lang is not None:
            training = False
            args.lang_pairs = ['{}-{}'.format(args.source_lang, args.target_lang)]
        else:
            training = True
            args.source_lang, args.target_lang = args.lang_pairs[0].split('-')

        # load dictionaries
        dicts = OrderedDict()
        for lang in sorted_langs:
            paths = args.data.split(':')
            assert len(paths) > 0
            dicts[lang] = Dictionary.load(os.path.join(paths[0], 'dict.{}.txt'.format(lang)))
            if len(dicts) > 0:
                assert dicts[lang].pad() == dicts[sorted_langs[0]].pad()
                assert dicts[lang].eos() == dicts[sorted_langs[0]].eos()
                assert dicts[lang].unk() == dicts[sorted_langs[0]].unk()
            if args.encoder_langtok is not None or args.decoder_langtok:
                for lang_to_add in sorted_langs:
                    dicts[lang].add_symbol(_lang_token(lang_to_add))
            print('| [{}] dictionary: {} types'.format(lang, len(dicts[lang])))
        return dicts, training
Code example #11
    @classmethod
    def setup_task(cls, args, **kwargs):
        args.left_pad_source = options.eval_bool(args.left_pad_source)
        args.left_pad_target = options.eval_bool(args.left_pad_target)
        args.trigram_block = options.eval_bool(args.trigram_block)
        args.init_from_pretrained_doc_model = options.eval_bool(args.init_from_pretrained_doc_model)

        # find language pair automatically
        if args.source_lang is None or args.target_lang is None:
            args.source_lang, args.target_lang = data_utils.infer_language_pair(args.data)
        if args.source_lang is None or args.target_lang is None:
            raise Exception('Could not infer language pair, please provide it explicitly')

        # load dictionaries
        if args.pretrained_bert_model.startswith('roberta'):
            src_dict = GPT2Dictionary.load(os.path.join(args.data, 'dict.{}.txt'.format(args.source_lang)))
        else:
            src_dict = BertDictionary.load(os.path.join(args.data, 'dict.{}.txt'.format(args.source_lang)))
 
        if args.init_from_pretrained_doc_model:
            print('adding the [SENT_MASK] token? change it within Bert Special Tokens')
            # adding the [SENT_MASK] token?

        tgt_dict = FlexibleDictionary.load(os.path.join(args.data, 'dict.{}.txt'.format(args.target_lang)))
        print('| [{}] dictionary: {} types'.format(args.source_lang, len(src_dict)))
        print('| [{}] dictionary: {} types'.format(args.target_lang, len(tgt_dict)))

        return cls(args, src_dict, tgt_dict)
Code example #12
    @classmethod
    def prepare(cls, args, **kargs):
        args.left_pad_source = options.eval_bool(args.left_pad_source)
        args.left_pad_target = options.eval_bool(args.left_pad_target)

        if args.lang_pairs is None:
            raise ValueError('--lang-pairs is required. List all the language pairs in the training objective.')
        if isinstance(args.lang_pairs, str):
            args.lang_pairs = args.lang_pairs.split(',')
        sorted_langs = sorted(list({x for lang_pair in args.lang_pairs for x in lang_pair.split('-')}))
        if args.source_lang is not None or args.target_lang is not None:
            training = False
        else:
            training = True

        # load dictionaries
        dicts = OrderedDict()
        for lang in sorted_langs:
            paths = utils.split_paths(args.data)
            assert len(paths) > 0
            dicts[lang] = Dictionary.load(os.path.join(paths[0], 'dict.{}.txt'.format(lang)))
            if len(dicts) > 0:
                assert dicts[lang].pad() == dicts[sorted_langs[0]].pad()
                assert dicts[lang].eos() == dicts[sorted_langs[0]].eos()
                assert dicts[lang].unk() == dicts[sorted_langs[0]].unk()
            if args.encoder_langtok is not None or args.decoder_langtok:
                for lang_to_add in sorted_langs:
                    dicts[lang].add_symbol(_lang_token(lang_to_add))
            logger.info('[{}] dictionary: {} types'.format(lang, len(dicts[lang])))
        return dicts, training
Code example #13
    @staticmethod
    def add_args(parser):
        """Add model-specific arguments to the parser."""
        # fmt: off
        parser.add_argument("--dropout", type=float, metavar="D",
                            help="dropout probability")
        parser.add_argument("--encoder-conv-channels", type=str, metavar="EXPR",
                            help="list of encoder convolution's out channels")
        parser.add_argument("--encoder-conv-kernel-sizes", type=str, metavar="EXPR",
                            help="list of encoder convolution's kernel sizes")
        parser.add_argument("--encoder-conv-strides", type=str, metavar="EXPR",
                            help="list of encoder convolution's strides")
        parser.add_argument("--encoder-rnn-hidden-size", type=int, metavar="N",
                            help="encoder rnn's hidden size")
        parser.add_argument("--encoder-rnn-layers", type=int, metavar="N",
                            help="number of rnn encoder layers")
        parser.add_argument("--encoder-rnn-bidirectional",
                            type=lambda x: options.eval_bool(x),
                            help="make all rnn layers of encoder bidirectional")
        parser.add_argument("--encoder-rnn-residual",
                            type=lambda x: options.eval_bool(x),
                            help="create residual connections for rnn encoder "
                            "layers (starting from the 2nd layer), i.e., the actual "
                            "output of such layer is the sum of its input and output")

        # Granular dropout settings (if not specified these default to --dropout)
        parser.add_argument("--encoder-rnn-dropout-in", type=float, metavar="D",
                            help="dropout probability for encoder rnn's input")
        parser.add_argument("--encoder-rnn-dropout-out", type=float, metavar="D",
                            help="dropout probability for encoder rnn's output")
Code example #14
    @classmethod
    def setup_task(cls, args, **kwargs):
        args.left_pad_source = options.eval_bool(args.left_pad_source)
        args.left_pad_target = options.eval_bool(args.left_pad_target)

        # find language pair automatically
        if args.source_lang is None or args.target_lang is None:
            args.source_lang, args.target_lang = data_utils.infer_language_pair(
                args.data)
        if args.source_lang is None or args.target_lang is None:
            raise Exception(
                'Could not infer language pair, please provide it explicitly')

        # load dictionaries
        src_dict = Dictionary.load(
            os.path.join(args.data, 'dict.{}.txt'.format(args.source_lang)))
        tgt_dict = Dictionary.load(
            os.path.join(args.data, 'dict.{}.txt'.format(args.target_lang)))
        assert src_dict.pad() == tgt_dict.pad()
        assert src_dict.eos() == tgt_dict.eos()
        assert src_dict.unk() == tgt_dict.unk()
        print('| [{}] dictionary: {} types'.format(args.source_lang,
                                                   len(src_dict)))
        print('| [{}] dictionary: {} types'.format(args.target_lang,
                                                   len(tgt_dict)))

        return cls(args, src_dict, tgt_dict)
Code example #15
    @classmethod
    def setup_task(cls, args, **kwargs):
        """Setup the task (e.g., load dictionaries).

        Args:
            args (argparse.Namespace): parsed command-line arguments
        """
        args.left_pad_source = options.eval_bool(args.left_pad_source)
        args.left_pad_target = options.eval_bool(args.left_pad_target)

        # find language pair automatically
        if args.source_lang is None or args.target_lang is None:
            args.source_lang, args.target_lang = data_utils.infer_language_pair(
                args.data[0])
        if args.source_lang is None or args.target_lang is None:
            raise Exception(
                'Could not infer language pair, please provide it explicitly')

        src_dict = BertBasedDictionary(args.bert_name)
        tgt_dict = cls.load_dictionary(
            os.path.join(args.data[0], 'dict.{}.txt'.format(args.target_lang)))
        assert src_dict.pad() == tgt_dict.pad(), "%d != %d" % (src_dict.pad(),
                                                               tgt_dict.pad())
        assert src_dict.eos() == tgt_dict.eos(), "%d != %d" % (src_dict.eos(),
                                                               tgt_dict.eos())
        assert src_dict.unk() == tgt_dict.unk(), "%d != %d" % (src_dict.unk(),
                                                               tgt_dict.unk())
        print('| [{}] dictionary: {} types'.format(args.source_lang,
                                                   len(src_dict)))
        print('| [{}] dictionary: {} types'.format(args.target_lang,
                                                   len(tgt_dict)))

        return cls(args, src_dict, tgt_dict)
Code example #16
    @classmethod
    def setup_task(cls, args, **kwargs):
        """Setup GEC task, including dictionary & model building."""

        """
        Similar to the translation task, but also load labels dictionaries
        """
        args.left_pad_source = options.eval_bool(args.left_pad_source)
        args.left_pad_target = options.eval_bool(args.left_pad_target)

        # find language pair automatically
        if args.source_lang is None or args.target_lang is None:
            args.source_lang, args.target_lang = data_utils.infer_language_pair(args.data[0])
        if args.source_lang is None or args.target_lang is None:
            raise Exception('Could not infer language pair, please provide it explicitly')

        # load dictionaries
        src_dict = cls.load_dictionary(os.path.join(args.data[0], 'dict.{}.txt'.format(args.source_lang)))
        tgt_dict = cls.load_dictionary(os.path.join(args.data[0], 'dict.{}.txt'.format(args.target_lang)))
        assert src_dict.pad() == tgt_dict.pad()
        assert src_dict.eos() == tgt_dict.eos()
        assert src_dict.unk() == tgt_dict.unk()
        print('| [{}] dictionary: {} types'.format(args.source_lang, len(src_dict)))
        print('| [{}] dictionary: {} types'.format(args.target_lang, len(tgt_dict)))

        return cls(args, src_dict, tgt_dict)
Code example #17
    @classmethod
    def setup_task(cls, args, **kwargs):
        """Setup the task (e.g., load dictionaries).

        Args:
            args (argparse.Namespace): parsed command-line arguments
        """
        args.left_pad_source = options.eval_bool(args.left_pad_source)
        args.left_pad_target = options.eval_bool(args.left_pad_target)
        if getattr(args, 'raw_text', False):
            utils.deprecation_warning('--raw-text is deprecated, please use --dataset-impl=raw')
            args.dataset_impl = 'raw'
        elif getattr(args, 'lazy_load', False):
            utils.deprecation_warning('--lazy-load is deprecated, please use --dataset-impl=lazy')
            args.dataset_impl = 'lazy'

        paths = args.data.split(':')
        assert len(paths) > 0
        # find language pair automatically
        if args.source_lang is None or args.target_lang is None:
            args.source_lang, args.target_lang = data_utils.infer_language_pair(paths[0])
        if args.source_lang is None or args.target_lang is None:
            raise Exception('Could not infer language pair, please provide it explicitly')

        # load dictionaries
        src_dict = cls.load_dictionary(os.path.join(paths[0], 'dict.{}.txt'.format(args.source_lang)))
        tgt_dict = cls.load_dictionary(os.path.join(paths[0], 'dict.{}.txt'.format(args.target_lang)))
        char_dict = cls.load_dictionary(os.path.join(paths[0], 'dict_char.txt'))
        assert src_dict.pad() == tgt_dict.pad()
        assert src_dict.eos() == tgt_dict.eos()
        assert src_dict.unk() == tgt_dict.unk()
        print('| [{}] dictionary: {} types'.format(args.source_lang, len(src_dict)))
        print('| [{}] dictionary: {} types'.format(args.target_lang, len(tgt_dict)))

        return cls(args, src_dict, tgt_dict, char_dict)
Code example #18
    @classmethod
    def setup_task(cls, args, **kwargs):
        """Setup the task (e.g., load dictionaries).
        Args:
            args (argparse.Namespace): parsed command-line arguments
        """
        args.left_pad_source = options.eval_bool(args.left_pad_source)
        args.left_pad_target = options.eval_bool(args.left_pad_target)
        paths = utils.split_paths(args.data[0])
        assert len(paths) > 0
        # find language pair automatically
        if args.source_lang is None or args.target_lang is None:
            args.source_lang, args.target_lang = data_utils.infer_language_pair(args.data[0])
        if args.source_lang is None or args.target_lang is None:
            raise Exception('Could not infer language pair, please provide it explicitly')

        # load dictionaries
        logger.info("loading dicts from.{}".format(args.dicts))
        dictionary = cls.load_dictionary(
            os.path.join(paths[0], args.dicts)
        )

        logger.info("args.add_lang_token: {} ".format(args.add_lang_token))
        if args.add_lang_token:
            languages = args.langs.split(",")
            for lang_pair in languages:
                print("{} was add to dictionary".format(lang_pair))
                lang = lang_pair.split("-")
                dictionary.add_symbol("[{}]".format(lang[0]))
                dictionary.add_symbol("[{}]".format(lang[1]))
        return cls(args, dictionary,dictionary)
Code example #19
    @classmethod
    def setup_task(cls, args, **kwargs):
        """Setup the task (e.g., load dictionaries).

        Args:
            args (argparse.Namespace): parsed command-line arguments
        """
        args.left_pad_source = options.eval_bool(args.left_pad_source)
        args.left_pad_target = options.eval_bool(args.left_pad_target)
        if getattr(args, 'raw_text', False):
            utils.deprecation_warning(
                '--raw-text is deprecated, please use --dataset-impl=raw')
            args.dataset_impl = 'raw'
        elif getattr(args, 'lazy_load', False):
            utils.deprecation_warning(
                '--lazy-load is deprecated, please use --dataset-impl=lazy')
            args.dataset_impl = 'lazy'

        # paths = args.data.split(':')
        # assert len(paths) > 0
        # # find language pair automatically
        # if args.source_lang is None or args.target_lang is None:
        #     args.source_lang, args.target_lang = data_utils.infer_language_pair(paths[0])
        # if args.source_lang is None or args.target_lang is None:
        #     raise Exception('Could not infer language pair, please provide it explicitly')

        # load dictionaries
        # print(args.data)
        def read_config(path):
            import yaml
            # yaml.load without an explicit Loader is deprecated/unsafe;
            # safe_load parses plain YAML without executing arbitrary tags
            with open(path) as config:
                return yaml.safe_load(config.read())

        path = "/content/drive/My Drive/IIIT-H RA/ICON/fairseq-working/config.yaml"
        data = read_config(path)
        # self.pairs = pairs_select(data['corpora'])

        # src_dict = cls.load_dictionary(os.path.join(paths[0], 'dict.{}.txt'.format(args.source_lang)))
        # tgt_dict = cls.load_dictionary(os.path.join(paths[0], 'dict.{}.txt'.format(args.target_lang)))
        src_dict = cls.load_dictionary(data['dictionary']['src'])
        tgt_dict = cls.load_dictionary(data['dictionary']['tgt'])
        assert src_dict.pad() == tgt_dict.pad()
        assert src_dict.eos() == tgt_dict.eos()
        assert src_dict.unk() == tgt_dict.unk()
        assert src_dict == tgt_dict
        print('| [{}] dictionary: {} types'.format(args.source_lang,
                                                   len(src_dict)))
        print('| [{}] dictionary: {} types'.format(args.target_lang,
                                                   len(tgt_dict)))

        return cls(args, src_dict, tgt_dict, data)
Code example #20
File: translation.py Project: jind11/TitleStylist
    @classmethod
    def setup_task(cls, args, **kwargs):
        args.left_pad_source = options.eval_bool(args.left_pad_source)
        args.left_pad_target = options.eval_bool(args.left_pad_target)

        if getattr(args, 'raw_text', False):
            utils.deprecation_warning('--raw-text is deprecated, please use --dataset-impl=raw')
            args.dataset_impl = 'raw'
        elif getattr(args, 'lazy_load', False):
            utils.deprecation_warning('--lazy-load is deprecated, please use --dataset-impl=lazy')
            args.dataset_impl = 'lazy'

        if args.source_lang is not None or args.target_lang is not None:
            training = False
        else:
            training = True

        # load dictionaries
        dicts = OrderedDict()
        if args.lang_pairs is not None:
            args.lang_pairs = args.lang_pairs.split(',')
            sorted_langs = sorted(list({x for lang_pair in args.lang_pairs for x in lang_pair.split('-')}))
            for lang in sorted_langs:
                paths = args.data.split(':')
                assert len(paths) > 0
                # dicts[lang] = Dictionary.load(os.path.join(paths[0], 'dict.{}.txt'.format(lang)))
                dicts[lang] = BertDictionary.load_from_file(os.path.join(paths[0], 'dict.{}.txt'.format(lang)))
                if len(dicts) > 0:
                    assert dicts[lang].pad() == dicts[sorted_langs[0]].pad()
                    assert dicts[lang].eos() == dicts[sorted_langs[0]].eos()
                    assert dicts[lang].unk() == dicts[sorted_langs[0]].unk()
                if args.encoder_langtok is not None or args.decoder_langtok:
                    for lang_to_add in sorted_langs:
                        dicts[lang].add_symbol(_lang_token(lang_to_add))
                print('| [{}] dictionary: {} types'.format(lang, len(dicts[lang])))
        else:
            assert args.dae_styles

        if len(args.data.split(':')) > 1:
            sorted_langs = sorted(list(args.dae_styles.split(',')))
            for idx, lang in enumerate(args.dae_styles.split(',')):
                paths = args.data.split(':')[1:]
                # dicts[lang] = Dictionary.load(os.path.join(paths[0], 'dict.{}.txt'.format(lang)))
                dicts[lang] = BertDictionary.load_from_file(os.path.join(paths[idx], 'dict.txt'))
                if len(dicts) > 0:
                    assert dicts[lang].pad() == dicts[sorted_langs[0]].pad()
                    assert dicts[lang].eos() == dicts[sorted_langs[0]].eos()
                    assert dicts[lang].unk() == dicts[sorted_langs[0]].unk()
                if args.encoder_langtok is not None or args.decoder_langtok:
                    for lang_to_add in sorted_langs:
                        dicts[lang].add_symbol(_lang_token(lang_to_add))
                print('| [{}] dictionary: {} types'.format(lang, len(dicts[lang])))
        return cls(args, dicts, training)
Code example #21
    @classmethod
    def setup_task(cls, args, **kwargs):
        """Setup the task (e.g., load dictionaries).

        Args:
            args (argparse.Namespace): parsed command-line arguments
        """
        args.left_pad_source = options.eval_bool(args.left_pad_source)
        args.left_pad_target = options.eval_bool(args.left_pad_target)

        paths = args.data.split(os.pathsep)
        assert len(paths) > 0

        # load dictionaries
        src_dict = cls.load_dictionary(
            os.path.join(paths[0], 'train.dict.src.txt'))
        sql_dict = cls.load_dictionary(
            os.path.join(paths[0], 'train.dict.sql.txt'))

        def store_random_embeddings(num_embeddings, embedding_dim, padding_idx,
                                    fname):
            m = nn.Embedding(num_embeddings,
                             embedding_dim,
                             padding_idx=padding_idx)
            fname = open(os.path.join(args.save_dir, 'rnd_embed.' + fname),
                         'w')
            for i in m.weight.tolist():
                for j in i:
                    fname.write(str(j) + ' ')
                fname.write('\n')
            fname.close()

        if 'path' in args:
            save_path = os.path.dirname(os.path.realpath(args.path))
        else:
            save_path = args.save_dir
        if not os.path.exists(os.path.join(save_path, 'rnd_embed.src')):
            store_random_embeddings(len(src_dict), args.word_encoder_embed_dim,
                                    src_dict.pad(), 'src')
            store_random_embeddings(len(sql_dict), args.decoder_embed_dim,
                                    sql_dict.pad(), 'sql')
        src_random_embedding_path = os.path.join(save_path, 'rnd_embed.src')
        sql_random_embedding_path = os.path.join(save_path, 'rnd_embed.sql')

        assert src_dict.pad() == sql_dict.pad()
        assert src_dict.eos() == sql_dict.eos()
        assert src_dict.unk() == sql_dict.unk()
        logger.info('["src"] dictionary: {} types'.format(len(src_dict)))
        logger.info('["sql"] dictionary: {} types'.format(len(sql_dict)))

        return cls(args, src_dict, sql_dict, src_random_embedding_path,
                   sql_random_embedding_path)
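Because store_random_embeddings writes one row of space-separated floats per line, the files round-trip with numpy; a hypothetical read-back (the path is illustrative):

    import numpy as np

    # Illustrative round-trip for the plain-text embedding files written above.
    weights = np.loadtxt('checkpoints/rnd_embed.src')  # hypothetical save_dir
    print(weights.shape)  # (len(src_dict), args.word_encoder_embed_dim)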
Code example #22
    @classmethod
    def setup_task(cls, args, **kwargs):
        args.left_pad_source = options.eval_bool(args.left_pad_source)
        args.left_pad_target = options.eval_bool(args.left_pad_target)

        # find language pair automatically
        if args.source_lang is None or args.target_lang is None:
            args.source_lang, args.target_lang = data_utils.infer_language_pair(
                args.data[0])

        # load dictionary
        subword_dict = SubwordDictionary.load(
            os.path.join(args.data[0], 'model.vcb'))

        return cls(args, subword_dict)
Code example #23
    @classmethod
    def setup_task(cls, args, **kwargs):
        args.left_pad_source = options.eval_bool(args.left_pad_source)
        args.left_pad_target = options.eval_bool(args.left_pad_target)
        args.trigram_block = options.eval_bool(args.trigram_block)
        args.init_from_pretrained_doc_model = options.eval_bool(
            args.init_from_pretrained_doc_model)

        # find language pair automatically
        if args.source_lang is None or args.target_lang is None:
            args.source_lang, args.target_lang = data_utils.infer_language_pair(
                args.data)
        if args.source_lang is None or args.target_lang is None:
            raise Exception(
                'Could not infer language pair, please provide it explicitly')

        # load dictionaries
        if args.roberta_model.startswith('roberta'):
            src_dict = GPT2Dictionary.load(
                os.path.join(args.data,
                             'dict.{}.txt'.format(args.source_lang)))
        else:
            src_dict = BertDictionary.load(
                os.path.join(args.data,
                             'dict.{}.txt'.format(args.source_lang)))
        idx = src_dict.add_special_token('<sent_mask>')
        print('<sent_mask> id = {}, token = {}'.format(idx, src_dict[idx]))
        print('<mask> id is', src_dict.index('<mask>'))
        print('<sent_mask> id is', src_dict.index('<sent_mask>'))

        # tgt_dict = FlexibleDictionary.load(os.path.join(args.data, 'dict.{}.txt'.format(args.target_lang)))
        # generate the tgt_dict
        tgt_dict = PointerFlexibleDictionary(args.max_doc_length,
                                             specialTokens=[('EOS', '</s>'),
                                                            ('PAD', '<pad>'),
                                                            ('UNK', '<unk>'),
                                                            ('BOS', '<s>')])

        assert tgt_dict.index('0') == 0
        print('| WARNING: idx should match the context in the tgt dict')
        # if args.predict_arch == 'pointer_net':
        #     assert tgt_dict.eos() == args.max_doc_length

        print('| [{}] dictionary: {} types'.format(args.source_lang,
                                                   len(src_dict)))
        print('| [{}] dictionary: {} types'.format(args.target_lang,
                                                   len(tgt_dict)))

        return cls(args, src_dict, tgt_dict)
Code example #24
    @classmethod
    def prepare(cls, args, **kargs):
        args.left_pad_source = options.eval_bool(args.left_pad_source)
        args.left_pad_target = options.eval_bool(args.left_pad_target)
        if getattr(args, 'raw_text', False):
            utils.deprecation_warning(
                '--raw-text is deprecated, please use --dataset-impl=raw')
            args.dataset_impl = 'raw'
        elif getattr(args, 'lazy_load', False):
            utils.deprecation_warning(
                '--lazy-load is deprecated, please use --dataset-impl=lazy')
            args.dataset_impl = 'lazy'

        if args.lang_pairs is None:
            raise ValueError(
                '--lang-pairs is required. List all the language pairs in the training objective.'
            )
        args.lang_pairs = args.lang_pairs.split(',')
        sorted_langs = sorted(
            list({
                x
                for lang_pair in args.lang_pairs for x in lang_pair.split('-')
            }))
        if args.source_lang is not None or args.target_lang is not None:
            training = False
        else:
            training = True

        # load dictionaries
        dicts = OrderedDict()
        for lang in sorted_langs:
            paths = args.data.split(':')
            assert len(paths) > 0
            dicts[lang] = Dictionary.load(
                os.path.join(paths[0], 'dict.{}.txt'.format(lang)))
            # add lang_token to dict
            for lan_name in sorted_langs:
                lan_token = _lang_token(lan_name)
                if dicts[lang].index(lan_token) == dicts[lang].unk_index:
                    dicts[lang].add_symbol(lan_token)
            if len(dicts) > 0:
                assert dicts[lang].pad() == dicts[sorted_langs[0]].pad()
                assert dicts[lang].eos() == dicts[sorted_langs[0]].eos()
                assert dicts[lang].unk() == dicts[sorted_langs[0]].unk()
            if args.encoder_langtok is not None or args.decoder_langtok:
                for lang_to_add in sorted_langs:
                    dicts[lang].add_symbol(_lang_token(lang_to_add))
            print('| [{}] dictionary: {} types'.format(lang, len(dicts[lang])))
        return dicts, training
Code example #25
    @staticmethod
    def add_args(parser):
        """Add model-specific arguments to the parser."""
        # fmt: off
        parser.add_argument("--dropout", type=float, metavar="D",
                            help="dropout probability")
        parser.add_argument("--hidden-sizes", type=str, metavar="EXPR",
                            help="list of hidden sizes for all Tdnn layers")
        parser.add_argument("--kernel-sizes", type=str, metavar="EXPR",
                            help="list of all Tdnn layer\'s kernel sizes")
        parser.add_argument("--strides", type=str, metavar="EXPR",
                            help="list of all Tdnn layer\'s strides")
        parser.add_argument("--dilations", type=str, metavar="EXPR",
                            help="list of all Tdnn layer\'s dilations")
        parser.add_argument("--num-layers", type=int, metavar="N",
                            help="number of Tdnn layers")
        parser.add_argument("--residual", type=lambda x: options.eval_bool(x),
                            help="create residual connections for rnn encoder "
                            "layers (starting from the 2nd layer), i.e., the actual "
                            "output of such layer is the sum of its input and output")

        # Granular dropout settings (if not specified these default to --dropout)
        parser.add_argument("--dropout-in", type=float, metavar="D",
                            help="dropout probability for encoder\'s input")
        parser.add_argument("--dropout-out", type=float, metavar="D",
                            help="dropout probability for Tdnn layers\' output")
Code example #26
File: tasks.py Project: avsaditya/translate
    @classmethod
    def setup_task(cls, args, **kwargs):
        args.left_pad_source = options.eval_bool(args.left_pad_source)

        assert not pytorch_translate_data.is_multilingual(
            args
        ), "Must set `--task pytorch_translate_multilingual` for multilingual training"

        # Load dictionaries
        source_dict = pytorch_translate_dictionary.Dictionary.load(
            args.source_vocab_file)
        target_dict = pytorch_translate_dictionary.Dictionary.load(
            args.target_vocab_file)

        source_lang = args.source_lang or "src"
        target_lang = args.target_lang or "tgt"

        print(f"| [{source_lang}] dictionary: {len(source_dict)} types")
        print(f"| [{target_lang}] dictionary: {len(target_dict)} types")

        use_char_source = (args.char_source_vocab_file != "") or (
            getattr(args, "arch", "") == "char_source")
        if use_char_source:
            char_source_dict = pytorch_translate_dictionary.Dictionary.load(
                args.char_source_vocab_file)
            # this attribute is used for CharSourceModel construction
            args.char_source_dict_size = len(char_source_dict)
        else:
            char_source_dict = None

        return cls(args, source_dict, target_dict, char_source_dict)
Code example #27
    @classmethod
    def setup_task(cls, args, **kwargs):
        # Here we can perform any setup required for the task. This may include
        # loading Dictionaries, initializing shared Embedding layers, etc.
        # In this case we'll just load the Dictionaries.

        args.left_pad_source = options.eval_bool(args.left_pad_source)
        args.left_pad_target = options.eval_bool(args.left_pad_target)

        # load dictionaries
        vocab = cls.load_dictionary(os.path.join(args.data, 'dict.txt'))
        logger.info('[{}] dictionary: {} types'.format('Src + tgt',
                                                       len(vocab)))
        vocab.model = spm.SentencePieceProcessor(
            model_file=os.path.join(args.data, 'spm.model'))

        return cls(args, vocab)
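With the SentencePiece model attached as vocab.model, the task can segment raw text directly; a short sketch assuming the sentencepiece package and an illustrative model path:

    import sentencepiece as spm

    # Sketch: load a SentencePiece model and segment a raw sentence.
    sp = spm.SentencePieceProcessor(model_file='spm.model')  # illustrative path
    pieces = sp.encode('Hello world', out_type=str)
    print(pieces)  # subword pieces such as ['▁Hello', '▁world'], model-dependent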
Code example #28
    @classmethod
    def setup_task(cls, args, **kwargs):
        args.left_pad_source = options.eval_bool(args.left_pad_source)

        # Load dictionaries
        source_dict = MaskedLMDictionary.load(args.source_vocab_file)
        target_dict = MaskedLMDictionary.load(args.target_vocab_file)

        source_lang = args.source_lang or "src"
        target_lang = args.target_lang or "tgt"

        print(f"| [{source_lang}] dictionary: {len(source_dict)} types")
        print(f"| [{target_lang}] dictionary: {len(target_dict)} types")

        use_char_source = (
            (args.char_source_vocab_file != "")
            or (getattr(args, "arch", "") == "char_source")
            or (getattr(args, "arch", "") == "char_source_transformer")
            or getattr(args, "arch", "") == "char_source_hybrid")
        if use_char_source:
            char_source_dict = MaskedLMDictionary.load(
                args.char_source_vocab_file)
            # this attribute is used for CharSourceModel construction
            args.char_source_dict_size = len(char_source_dict)
        else:
            char_source_dict = None

        return cls(args, source_dict, target_dict, char_source_dict)
Code example #29
    @staticmethod
    def add_args(parser):
        """Add model-specific arguments to the parser."""
        # fmt: off
        parser.add_argument('--dropout', type=float, metavar='D',
                            help='dropout probability')
        parser.add_argument('--decoder-embed-dim', type=int, metavar='N',
                            help='decoder embedding dimension')
        parser.add_argument('--decoder-embed-path', type=str, metavar='STR',
                            help='path to pre-trained decoder embedding')
        parser.add_argument('--decoder-freeze-embed', action='store_true',
                            help='freeze decoder embeddings')
        parser.add_argument('--decoder-hidden-size', type=int, metavar='N',
                            help='decoder hidden size')
        parser.add_argument('--decoder-layers', type=int, metavar='N',
                            help='number of decoder layers')
        parser.add_argument('--decoder-out-embed-dim', type=int, metavar='N',
                            help='decoder output embedding dimension')
        parser.add_argument('--adaptive-softmax-cutoff', metavar='EXPR',
                            help='comma separated list of adaptive softmax cutoff points. '
                                 'Must be used with adaptive_loss criterion')
        parser.add_argument('--share-embed',
                            type=lambda x: options.eval_bool(x),
                            help='share input and output embeddings')
        parser.add_argument('--is-wordlm', action='store_true',
                            help='whether it is word LM or subword LM. Only '
                            'relevant for ASR decoding with LM, and it determines '
                            'how the underlying decoder instance gets the dictionary '
                            'from the task instance when calling cls.build_model()')

        # Granular dropout settings (if not specified these default to --dropout)
        parser.add_argument('--decoder-dropout-in', type=float, metavar='D',
                            help='dropout probability for decoder input embedding')
        parser.add_argument('--decoder-dropout-out', type=float, metavar='D',
                            help='dropout probability for decoder output')
Code example #30
    @classmethod
    def setup_task(cls, args, **kwargs):
        assert pytorch_translate_data.is_multilingual(
            args
        ), "Must set `--task pytorch_translate_multilingual` for multilingual training"
        args.left_pad_source = options.eval_bool(args.left_pad_source)

        def load_dicts(langs, paths):
            dicts = OrderedDict()
            for lang, dict_path in zip(langs, paths):
                d = pytorch_translate_dictionary.Dictionary.load(dict_path)
                dicts[lang] = d
                print(f"| [{lang}] dictionary: {len(d)} types")
            return dicts

        if not hasattr(args, "multiling_source_vocab_file"):
            args.multiling_encoder_lang = args.multiling_source_lang
            args.multiling_source_vocab_file = [args.source_vocab_file]
        if not hasattr(args, "multiling_target_vocab_file"):
            args.multiling_decoder_lang = args.multiling_target_lang
            args.multiling_target_vocab_file = [args.target_vocab_file]

        # Load dictionaries
        src_dicts = load_dicts(
            args.multiling_encoder_lang, args.multiling_source_vocab_file
        )
        tgt_dicts = load_dicts(
            args.multiling_decoder_lang, args.multiling_target_vocab_file
        )

        return cls(args, src_dicts, tgt_dicts)
Code example #31
    @classmethod
    def setup_task(cls, args, **kwargs):
        """Setup the task (e.g., load dictionaries).

        Args:
            args (argparse.Namespace): parsed command-line arguments
        """
        print(f'| args.data = {args.data}')
        args.left_pad_source = options.eval_bool(args.left_pad_source)
        # assert args.left_pad_source, 'Need left_pad_source True to use EOS as classification token'
        assert not args.left_pad_source, 'args.left_pad_source must be False as it is the root for classification'

        assert args.source_lang is not None
        # NOTE: the assert above makes this fallback unreachable
        if args.source_lang is None:
            args.source_lang = task_utils.infer_language_mono(args.data)

        dict_path = os.path.join(args.data, 'dict.txt')
        if not os.path.exists(dict_path):
            dict_path = os.path.join(args.data, f'dict.{args.source_lang}.txt')

        dictionary = None
        output_dictionary = None
        if args.data:
            dictionary = Dictionary.load(dict_path)
            print('| dictionary: {} types'.format(len(dictionary)))
            output_dictionary = dictionary
            if args.output_dictionary_size >= 0:
                output_dictionary = TruncatedDictionary(dictionary, args.output_dictionary_size)

        if args.source_lang is None:
            args.source_lang = task_utils.infer_language_mono(args.data)

        # dict_path = os.path.join(args.data, 'dict.txt')
        # src_dict = Dictionary.load(dict_path)
        # print('| [{}] dictionary: {} types'.format(args.source_lang, len(src_dict)))
        return cls(args, dictionary, output_dictionary)
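The TruncatedDictionary wrapper above caps the output vocabulary (useful for shrinking the softmax) while the full dictionary still handles input. A small sketch, assuming fairseq's Dictionary and TruncatedDictionary are importable from fairseq.data:

    from fairseq.data import Dictionary, TruncatedDictionary

    # Sketch: truncation caps the output vocabulary; out-of-range ids map to unk.
    d = Dictionary()
    for tok in ['low', 'lower', 'lowest']:
        d.add_symbol(tok)
    out_d = TruncatedDictionary(d, len(d) - 1)  # drop the last (rarest) entry
    print(len(d), len(out_d))                   # full vs. truncated size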