Code example #1
    @classmethod
    def setup_task(cls, args, **kwargs):
        """Setup the task (e.g., load dictionaries).

        Args:
            args (argparse.Namespace): parsed command-line arguments
        """
        args.left_pad_source = options.eval_bool(args.left_pad_source)
        args.left_pad_target = options.eval_bool(args.left_pad_target)
        if getattr(args, 'raw_text', False):
            utils.deprecation_warning(
                '--raw-text is deprecated, please use --dataset-impl=raw')
            args.dataset_impl = 'raw'
        elif getattr(args, 'lazy_load', False):
            utils.deprecation_warning(
                '--lazy-load is deprecated, please use --dataset-impl=lazy')
            args.dataset_impl = 'lazy'

        paths = args.data.split(':')
        assert len(paths) > 0
        # find language pair automatically
        if args.source_lang is None or args.target_lang is None:
            args.source_lang, args.target_lang = data_utils.infer_language_pair(
                paths[0])
        if args.source_lang is None or args.target_lang is None:
            raise Exception(
                'Could not infer language pair, please provide it explicitly')

        # load dictionaries
        src_dict = cls.load_dictionary(
            os.path.join(paths[0], 'dict.{}.txt'.format(args.source_lang)))
        tgt_dict = cls.load_dictionary(
            os.path.join(paths[0], 'dict.{}.txt'.format(args.target_lang)))
        assert src_dict.pad() == tgt_dict.pad()
        assert src_dict.eos() == tgt_dict.eos()
        assert src_dict.unk() == tgt_dict.unk()
        print('| [{}] dictionary: {} types'.format(args.source_lang,
                                                   len(src_dict)))
        print('| [{}] dictionary: {} types'.format(args.target_lang,
                                                   len(tgt_dict)))

        return cls(args, src_dict, tgt_dict)
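
The pattern above is worth isolating: deprecated flags (--raw-text, --lazy-load) are warned about once and then rewritten onto the newer --dataset-impl option, so the rest of the task only ever sees the new flag. A minimal, self-contained sketch of that flag-upgrade idiom, with plain warnings.warn and argparse.Namespace standing in for the fairseq helpers:

    import warnings
    from argparse import Namespace

    def upgrade_dataset_flags(args):
        """Map the deprecated --raw-text/--lazy-load flags onto --dataset-impl."""
        if getattr(args, 'raw_text', False):
            warnings.warn('--raw-text is deprecated, please use --dataset-impl=raw')
            args.dataset_impl = 'raw'
        elif getattr(args, 'lazy_load', False):
            warnings.warn('--lazy-load is deprecated, please use --dataset-impl=lazy')
            args.dataset_impl = 'lazy'
        return args

    args = upgrade_dataset_flags(Namespace(raw_text=True))
    assert args.dataset_impl == 'raw'
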
Code example #2
    def get_meter(self, name):
        """[deprecated] Get a specific meter by name."""
        from fairseq import meters

        if "get_meter" not in self._warn_once:
            self._warn_once.add("get_meter")
            utils.deprecation_warning(
                "Trainer.get_meter is deprecated. Please use fairseq.metrics instead."
            )

        train_meters = metrics.get_meters("train")
        if train_meters is None:
            train_meters = {}

        if name == "train_loss" and "loss" in train_meters:
            return train_meters["loss"]
        elif name == "train_nll_loss":
            # support for legacy train.py, which assumed this meter is
            # always initialized
            m = train_meters.get("nll_loss", None)
            return m or meters.AverageMeter()
        elif name == "wall":
            # support for legacy train.py, which assumed this meter is
            # always initialized
            m = metrics.get_meter("default", "wall")
            return m or meters.TimeMeter()
        elif name == "wps":
            m = metrics.get_meter("train", "wps")
            return m or meters.TimeMeter()
        elif name in {"valid_loss", "valid_nll_loss"}:
            # support for legacy train.py, which assumed these meters
            # are always initialized
            k = name[len("valid_"):]
            m = metrics.get_meter("valid", k)
            return m or meters.AverageMeter()
        elif name == "oom":
            return meters.AverageMeter()
        elif name in train_meters:
            return train_meters[name]
        return None
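
The example above uses a small set (self._warn_once) to make sure the deprecation warning fires only once per trainer instance, however often the legacy accessor is called. A minimal sketch of that warn-once idiom, with illustrative names (not fairseq API):

    import warnings

    class LegacyAccessor:
        def __init__(self):
            self._warn_once = set()

        def get_meter(self, name):
            # warn only the first time this deprecated entry point is hit
            if 'get_meter' not in self._warn_once:
                self._warn_once.add('get_meter')
                warnings.warn('get_meter is deprecated')
            return None

    acc = LegacyAccessor()
    acc.get_meter('wps')  # emits the warning
    acc.get_meter('wps')  # silent on repeat calls
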
Code example #3
    @classmethod
    def setup_task(cls, args, **kwargs):
        """Setup the task (e.g., load dictionaries).

        Args:
            args (argparse.Namespace): parsed command-line arguments
        """
        if getattr(args, 'raw_text', False):
            utils.deprecation_warning('--raw-text is deprecated, please use --dataset-impl=raw')
            args.dataset_impl = 'raw'
        elif getattr(args, 'lazy_load', False):
            utils.deprecation_warning('--lazy-load is deprecated, please use --dataset-impl=lazy')
            args.dataset_impl = 'lazy'

        dictionary = None
        output_dictionary = None
        if args.data:
            paths = args.data.split(':')
            assert len(paths) > 0
            dictionary = Dictionary.load(os.path.join(paths[0], 'dict.txt'))
            print('| dictionary: {} types'.format(len(dictionary)))
            output_dictionary = dictionary
            if args.output_dictionary_size >= 0:
                output_dictionary = TruncatedDictionary(dictionary, args.output_dictionary_size)

        # upgrade old checkpoints
        if hasattr(args, 'exclude_self_target'):
            args.self_target = not args.exclude_self_target

        targets = []
        if getattr(args, 'self_target', False):
            targets.append('self')
        if getattr(args, 'future_target', False):
            targets.append('future')
        if getattr(args, 'past_target', False):
            targets.append('past')
        if len(targets) == 0:
            # standard language modeling
            targets = ['future']

        return cls(args, dictionary, output_dictionary, targets=targets)
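
The target-selection logic at the end is worth calling out: if none of the self/future/past flags are set, the task defaults to the standard left-to-right ('future') language-modeling target. An illustrative, self-contained check of that fallback (names mirror the flags above):

    from argparse import Namespace

    def select_targets(args):
        targets = [t for t in ('self', 'future', 'past')
                   if getattr(args, t + '_target', False)]
        return targets or ['future']  # standard language modeling by default

    assert select_targets(Namespace()) == ['future']
    assert select_targets(Namespace(past_target=True)) == ['past']
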
Code example #4
    def reduce_metrics(self, logging_outputs, criterion):
        """Aggregate logging outputs from data parallel training."""
        # backward compatibility for tasks that override aggregate_logging_outputs
        base_func = FairseqTask.aggregate_logging_outputs
        self_func = getattr(self, 'aggregate_logging_outputs').__func__
        if self_func is not base_func:
            utils.deprecation_warning(
                'Tasks should implement the reduce_metrics API. '
                'Falling back to deprecated aggregate_logging_outputs API.')
            agg_logging_outputs = self.aggregate_logging_outputs(
                logging_outputs, criterion)
            for k, v in agg_logging_outputs.items():
                metrics.log_scalar(k, v)
            return

        if not any('ntokens' in log for log in logging_outputs):
            warnings.warn(
                'ntokens not found in Criterion logging outputs, cannot log wpb or wps'
            )
        else:
            ntokens = utils.item(
                sum(log.get('ntokens', 0) for log in logging_outputs))
            metrics.log_scalar('wpb', ntokens, priority=180, round=1)
            metrics.log_speed('wps',
                              ntokens,
                              ignore_first=10,
                              priority=90,
                              round=1)

        if not any('nsentences' in log for log in logging_outputs):
            warnings.warn(
                'nsentences not found in Criterion logging outputs, cannot log bsz'
            )
        else:
            nsentences = utils.item(
                sum(log.get('nsentences', 0) for log in logging_outputs))
            metrics.log_scalar('bsz', nsentences, priority=190, round=1)

        criterion.__class__.reduce_metrics(logging_outputs)
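
The backward-compatibility check at the top of reduce_metrics relies on a general Python trick: comparing a bound method's __func__ against the base-class function reveals whether a subclass overrode it. A minimal sketch of that override-detection idiom (illustrative class names):

    class Base:
        def aggregate(self):
            return 'base'

    class Child(Base):
        def aggregate(self):
            return 'child'

    def overrides_aggregate(obj):
        # bound_method.__func__ is the underlying function object
        return obj.aggregate.__func__ is not Base.aggregate

    assert overrides_aggregate(Child()) is True
    assert overrides_aggregate(Base()) is False
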
Code example #5
    @classmethod
    def setup_task(cls, args, **kwargs):
        """Setup the task (e.g., load dictionaries).

        Args:
            args (argparse.Namespace): parsed command-line arguments
        """
        args.left_pad_source = options.eval_bool(args.left_pad_source)
        args.left_pad_target = options.eval_bool(args.left_pad_target)
        if getattr(args, 'raw_text', False):
            utils.deprecation_warning('--raw-text is deprecated, please use --dataset-impl=raw')
            args.dataset_impl = 'raw'
        elif getattr(args, 'lazy_load', False):
            utils.deprecation_warning('--lazy-load is deprecated, please use --dataset-impl=lazy')
            args.dataset_impl = 'lazy'

        paths = args.data.split(':')
        assert len(paths) > 0
        # find language pair automatically
        if args.source_lang is None or args.target_lang is None:
            args.source_lang, args.target_lang = data_utils.infer_language_pair(paths[0])
        if args.source_lang is None or args.target_lang is None:
            raise Exception('Could not infer language pair, please provide it explicitly')

        # load dictionaries
        assert args.target_lang == 'actions', 'target extension must be "actions"'
        args.target_lang_nopos = 'actions_nopos'    # only build dictionary without pointer values
        args.target_lang_pos = 'actions_pos'
        args.target_lang_vocab_nodes = 'actions.vocab.nodes'
        args.target_lang_vocab_others = 'actions.vocab.others'
        src_dict = cls.load_dictionary(os.path.join(paths[0], 'dict.{}.txt'.format(args.source_lang)))
        # tgt_dict = cls.load_dictionary(os.path.join(paths[0], 'dict.{}.txt'.format(args.target_lang_nopos)))

        # NOTE rebuild the dictionary every time
        tgt_dict = cls.build_dictionary_bart_extend(
            node_freq_min=args.node_freq_min,
            node_file_path=os.path.join(paths[0], args.target_lang_vocab_nodes),
            others_file_path=os.path.join(paths[0], args.target_lang_vocab_others)
            )

        # TODO target dictionary 'actions_nopos' is hard coded now; change it later
        assert src_dict.pad() == tgt_dict.pad()
        assert src_dict.eos() == tgt_dict.eos()
        assert src_dict.unk() == tgt_dict.unk()
        print('| [{}] dictionary: {} types'.format(args.source_lang, len(src_dict)))
        print('| [{}] dictionary: {} types'.format(args.target_lang_nopos, len(tgt_dict)))

        # ========== load the pretrained BART model ==========
        if getattr(args, 'arch', None):
            # training time: pretrained BART needs to be used for initialization
            if 'bart_base' in args.arch or 'bartsv_base' in args.arch:
                print('-' * 10 + ' loading pretrained bart.base model ' + '-' * 10)
                bart = torch.hub.load('pytorch/fairseq', 'bart.base')
            elif 'bart_large' in args.arch or 'bartsv_large' in args.arch:
                print('-' * 10 + ' loading pretrained bart.large model ' + '-' * 10)
                bart = torch.hub.load('pytorch/fairseq', 'bart.large')
            else:
                raise ValueError(
                    'unsupported arch {} for pretrained BART initialization'.format(args.arch))
        else:
            # inference time: pretrained BART is only used for dictionary related things; size does not matter
            # NOTE size does matter; update this later in model initialization if model is with "bart.large"
            print('-' * 10 + ' (for bpe vocab and embed size at inference time) loading pretrained bart.base model '
                  + '-' * 10)
            bart = torch.hub.load('pytorch/fairseq', 'bart.base')

        bart.eval()    # the pretrained BART model is only for assistance
        # ====================================================

        return cls(args, src_dict, tgt_dict, bart)
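
The BART-loading branch above goes through torch.hub, which downloads the checkpoint over the network on first use and caches it locally (typically under ~/.cache/torch/hub), so repeated task setups reuse the cache. A hedged sketch of that step in isolation (requires network access and fairseq installed):

    import torch

    # first call downloads and caches the checkpoint; later calls hit the cache
    bart = torch.hub.load('pytorch/fairseq', 'bart.base')
    bart.eval()  # used only as a frozen helper model, as in the task above
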
Code example #6
    @classmethod
    def setup_task(cls, args, **kwargs):
        """Setup the task (e.g., load dictionaries).

        Args:
            args (argparse.Namespace): parsed command-line arguments
        """

        print("These are the arguments", args)
        print("\n")

        if getattr(args, 'raw_text', False):
            utils.deprecation_warning(
                '--raw-text is deprecated, please use --dataset-impl=raw')
            args.dataset_impl = 'raw'
        elif getattr(args, 'lazy_load', False):
            utils.deprecation_warning(
                '--lazy-load is deprecated, please use --dataset-impl=lazy')
            args.dataset_impl = 'lazy'

        dictionary = None
        output_dictionary = None

        def read_config(path):
            import yaml
            with open(path) as config:
                # yaml.load without an explicit Loader is deprecated in
                # PyYAML >= 5.1 and unsafe on untrusted input
                return yaml.safe_load(config.read())

        # if args.data:
        #     paths = args.data.split(':')
        #     assert len(paths) > 0
        #     dictionary = Dictionary.load(os.path.join(paths[0], 'dict.txt'))
        #     print('| dictionary: {} types'.format(len(dictionary)))
        #     output_dictionary = dictionary
        #     if args.output_dictionary_size >= 0:
        #         output_dictionary = TruncatedDictionary(dictionary, args.output_dictionary_size)

        ######################### CHANGE THIS #################################################
        args.data = read_config(
            "/content/drive/My Drive/IIIT-H RA/ICON/fairseq-working/config.yaml"
        )
        ########################################################################################
        if args.data:
            path = args.data['dictionary']['src']
            dictionary = Dictionary.load(path)
            print('| dictionary: {} types'.format(len(dictionary)))
            output_dictionary = dictionary
            if args.output_dictionary_size >= 0:
                output_dictionary = TruncatedDictionary(
                    dictionary, args.output_dictionary_size)

        # upgrade old checkpoints
        if hasattr(args, 'exclude_self_target'):
            args.self_target = not args.exclude_self_target

        targets = []
        if getattr(args, 'self_target', False):
            targets.append('self')
        if getattr(args, 'future_target', False):
            targets.append('future')
        if getattr(args, 'past_target', False):
            targets.append('past')
        if len(targets) == 0:
            # standard language modeling
            targets = ['future']

        return cls(args,
                   args.data,
                   dictionary,
                   output_dictionary,
                   targets=targets)
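
Given how args.data is consumed above (args.data['dictionary']['src']), the YAML config needs at least a dictionary/src entry. A self-contained illustration of parsing that shape with yaml.safe_load (the path value is a placeholder):

    import yaml

    sample = 'dictionary:\n  src: /path/to/dict.txt\n'
    data = yaml.safe_load(sample)
    assert data['dictionary']['src'] == '/path/to/dict.txt'
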
Code example #7
File: model.py Project: dangss/Align-SCA
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        utils.deprecation_warning('hello', stacklevel=4)
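
The example above passes an explicit stacklevel, which hints at what the helper does. A hedged sketch of roughly what utils.deprecation_warning amounts to, inferred from its call sites in these examples (an approximation, not necessarily the library's exact source): it forwards to warnings.warn as a plain UserWarning, since DeprecationWarning is hidden by default, with stacklevel pointing the report at the deprecated caller rather than at the helper itself.

    import warnings

    def deprecation_warning(message, stacklevel=3):
        # a plain UserWarning is shown by default, whereas DeprecationWarning
        # would be silently ignored in a normal Python run
        warnings.warn(message, stacklevel=stacklevel)
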
Code example #8
    @classmethod
    def setup_task(cls, args, **kwargs):
        """Setup the task (e.g., load dictionaries).

        Args:
            args (argparse.Namespace): parsed command-line arguments
        """
        args.left_pad_source = options.eval_bool(args.left_pad_source)
        args.left_pad_target = options.eval_bool(args.left_pad_target)
        if getattr(args, 'raw_text', False):
            utils.deprecation_warning(
                '--raw-text is deprecated, please use --dataset-impl=raw')
            args.dataset_impl = 'raw'
        elif getattr(args, 'lazy_load', False):
            utils.deprecation_warning(
                '--lazy-load is deprecated, please use --dataset-impl=lazy')
            args.dataset_impl = 'lazy'

        paths = args.data.split(':')
        assert len(paths) > 0
        # language pair must be given explicitly; automatic inference is disabled here
        if args.source_lang is None or args.target_lang is None:
            raise Exception(
                'Could not infer language pair, please provide it explicitly'
            )
            # TRY ADDING: --source-lang true_w,true_p --target-lang reco_w
            # OR ENABLING THE HARDCODED DEFAULTS BELOW:
            # args.source_lang = "true_w,true_p"
            # args.target_lang = "reco_w"
            # args.source_lang, args.target_lang = data_utils.infer_language_pair(paths[0])

        assert isinstance(args.source_lang, str)
        if ',' not in args.source_lang:
            raise Exception(
                'source-lang is "{}" but needs to contain two '
                'comma-separated strings'.format(args.source_lang))
        # load dictionaries
        src_lang1, src_lang2 = args.source_lang.split(',')
        src_dict1 = cls.load_dictionary(
            os.path.join(paths[0], 'dict.{}.txt'.format(src_lang1)))
        src_dict2 = cls.load_dictionary(
            os.path.join(paths[0], 'dict.{}.txt'.format(src_lang2)))
        if ',' in args.target_lang:
            assert args.criterion == 'cross_entropy_dual'
            target_lang, target_lang_extra = args.target_lang.split(',')
            tgt_dict = cls.load_dictionary(
                os.path.join(paths[0], 'dict.{}.txt'.format(target_lang)))
            tgt_dict_extra = cls.load_dictionary(
                os.path.join(paths[0],
                             'dict.{}.txt'.format(target_lang_extra)))
            dual_decoder = True
        else:
            tgt_dict = cls.load_dictionary(
                os.path.join(paths[0], 'dict.{}.txt'.format(args.target_lang)))
            dual_decoder = False
            tgt_dict_extra = None

        assert src_dict1.pad() == tgt_dict.pad()
        assert src_dict1.eos() == tgt_dict.eos()
        assert src_dict1.unk() == tgt_dict.unk()
        assert src_dict2.pad() == tgt_dict.pad()
        assert src_dict2.eos() == tgt_dict.eos()
        assert src_dict2.unk() == tgt_dict.unk()
        print('| [{}] dictionary: {} types'.format(src_lang1, len(src_dict1)))
        print('| [{}] dictionary: {} types'.format(src_lang2, len(src_dict2)))

        if dual_decoder:
            print('| [{}] dictionary: {} types'.format(target_lang,
                                                       len(tgt_dict)))
            print('| [{}] dictionary: {} types'.format(target_lang_extra,
                                                       len(tgt_dict_extra)))
        else:
            print('| [{}] dictionary: {} types'.format(args.target_lang,
                                                       len(tgt_dict)))
        return cls(args, src_dict1, src_dict2, tgt_dict, tgt_dict_extra)
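
The dispatch in the example above keys off a comma in --target-lang: two comma-separated targets mean a dual decoder with two target dictionaries, a single target means one. A tiny illustrative check of that rule (the helper name is hypothetical):

    def is_dual_decoder(target_lang):
        # mirrors the `',' in args.target_lang` branch in the example
        return ',' in target_lang

    assert is_dual_decoder('reco_w,extra') is True
    assert is_dual_decoder('reco_w') is False
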