def report(world, stats):
    report = world.report()
    log = {
        'word_offenses': stats['bad_words_cnt'],
        'classifier_offenses%': 100 * (stats['classifier_offensive'] / stats['total']),
        'string_offenses%': 100 * (stats['string_offensive'] / stats['total']),
        'total_offenses%': 100 * (stats['total_offensive'] / stats['total']),
    }
    text, log = log_time.log(report['exs'], world.num_examples(), log)
    logging.info(text)
    return log
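
# A minimal driving sketch: `world` and `log_time` are assumed to be set up
# by the surrounding evaluation script, and the stats values are illustrative.
stats = {
    'bad_words_cnt': 12,        # offensive-word hits
    'classifier_offensive': 7,  # flagged by the safety classifier
    'string_offensive': 5,      # flagged by the string matcher
    'total_offensive': 9,       # flagged by either method
    'total': 200,               # examples seen so far
}
log = report(world, stats)  # logs progress text and returns the metrics dict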
Example #2
    def run(self):
        self.opt['no_cuda'] = True
        if ('ordered' not in self.opt['datatype']
                and 'train' in self.opt['datatype']):
            self.opt['datatype'] = self.opt['datatype'] + ':ordered'
        agent = create_agent(self.opt)
        agent.opt.log()
        num_examples = self.opt['num_examples']
        field = self.opt['field'] + '_vec'
        if num_examples < 0:
            num_examples = float('inf')
        assert self.opt['batchsize'] == 1
        assert isinstance(agent, TorchAgent)

        world = create_task(self.opt, agent)
        teacher = world.get_task_agent()

        # set up logging
        log_every_n_secs = self.opt.get('log_every_n_secs', -1)
        if log_every_n_secs <= 0:
            log_every_n_secs = float('inf')
        log_time = TimeLogger()

        lengths = []

        cnt = 0
        total = min(teacher.num_examples(), num_examples)
        while not teacher.epoch_done() and cnt < num_examples:
            act = teacher.act()
            processed = agent.observe(act)
            try:
                text_vec = processed[field]
            except KeyError:
                raise KeyError(f"Pick one of {list(processed.keys())}")
            if text_vec is not None and (not self.opt['final_only']
                                         or act.get('episode_done')):
                cnt += 1
                lengths.append(float(len(text_vec)))
            agent.self_observe({})

            if log_time.time() > log_every_n_secs:
                report = self._compute_stats(lengths)
                text, report = log_time.log(report['exs'], total, report)
                logging.info(text)

        report = self._compute_stats(lengths)
        print(nice_report(report))
        return report
Example #3
@contextlib.contextmanager
def slurm_distributed_context(opt):
    """
    Initialize a distributed context, using the SLURM environment.

    Does some work to read the environment to find a list of participating nodes
    and the main node.

    :param opt:
        Command line options.
    """
    # We can determine the init method automatically for Slurm.
    # double check we're using SLURM
    node_list = os.environ.get('SLURM_JOB_NODELIST')
    if node_list is None:
        raise RuntimeError(
            'Does not appear to be in a SLURM environment. '
            'You should not call this script directly; see launch_distributed.py'
        )
    try:
        # Figure out the main host, and which rank we are.
        hostnames = subprocess.check_output(
            ['scontrol', 'show', 'hostnames', node_list])
    except FileNotFoundError as e:
        # Slurm is not installed
        raise RuntimeError(
            f'SLURM does not appear to be installed. Missing file: {e.filename}'
        )

    main_host = hostnames.split()[0].decode('utf-8')
    distributed_rank = int(os.environ['SLURM_PROCID'])
    if opt.get('model_parallel'):
        # -1 signals to multiprocessing_train to use all GPUs available.
        # (A value of None signals to multiprocessing_train to use the GPU
        # corresponding to the rank.)
        device_id = -1
    else:
        device_id = int(os.environ['SLURM_LOCALID'])
    port = opt['port']
    logging.info(
        f'Initializing host {socket.gethostname()} as rank {distributed_rank}, '
        f'main is {main_host}')
    # Begin distributed training
    with distributed_context(distributed_rank,
                             opt,
                             0,
                             device_id,
                             init_method=f"tcp://{main_host}:{port}") as opt:
        yield opt
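
# Usage sketch: because the function yields the updated options, it is consumed
# as a context manager (this assumes the @contextlib.contextmanager decorator
# shown above). TrainLoop is a hypothetical training entry point.
with slurm_distributed_context(opt) as opt:
    # opt now reflects this worker's rank and device assignment
    TrainLoop(opt).train()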
Example #4
File: agents.py Project: swycha/ParlAI
def create_agent(opt: Opt, requireModelExists=False):
    """
    Create an agent from the options ``model``, ``model_params`` and ``model_file``.

    The input is either of the form
    ``parlai.agents.ir_baseline.agents:IrBaselineAgent`` (i.e. the path
    followed by the class name) or else just ``ir_baseline`` which
    assumes the path above, and a class name suffixed with 'Agent'.

    If ``model-file`` is available in the options this function can also
    attempt to load the model from that location instead. This avoids having to
    specify all the other options necessary to set up the model including its
    name as they are all loaded from the options file if it exists (the file
    opt['model_file'] + '.opt' must exist and contain a pickled or json dict
    containing the model's options).
    """
    if opt.get('datapath', None) is None:
        add_datapath_and_model_args(opt)

    if opt.get('model_file'):
        opt['model_file'] = modelzoo_path(opt.get('datapath'),
                                          opt['model_file'])
        if requireModelExists and not os.path.isfile(opt['model_file']):
            raise RuntimeError(
                'WARNING: Model file does not exist, check to make '
                'sure it is correct: {}'.format(opt['model_file']))
        # Attempt to load the model from the model file first (this way we do
        # not even have to specify the model name as a parameter)
        model = create_agent_from_opt_file(opt)
        if model is not None:
            return model
        else:
            logging.info(
                f"No model with opt yet at: {opt['model_file']}(.opt)")

    if opt.get('model'):
        model_class = load_agent_module(opt['model'])
        # if we want to load weights from --init-model, compare opts with
        # loaded ones
        compare_init_model_opts(opt, opt)
        model = model_class(opt)
        if (requireModelExists and hasattr(model, 'load')
                and not opt.get('model_file')):
            # double check that we didn't forget to set model_file on loadable model
            logging.warn('model_file unset but model has a `load` function.')
        return model
    else:
        raise RuntimeError('Need to set `model` argument to use create_agent.')
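
# A hedged usage sketch of the two model-specification forms described in the
# docstring; the option values are illustrative.
from parlai.core.params import ParlaiParser

parser = ParlaiParser(add_model_args=True)

# Short form: 'ir_baseline' expands to
# parlai.agents.ir_baseline.agents:IrBaselineAgent.
opt = parser.parse_args(['--model', 'ir_baseline'])
agent = create_agent(opt)

# Alternatively, load the model (and all of its options) from a saved model
# file; the zoo path here is illustrative.
opt = parser.parse_args(['--model-file', 'zoo:blender/blender_90M/model'])
agent = create_agent(opt, requireModelExists=True)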
Example #5
    def save_data(self, data: Tuple[torch.Tensor, List[str]]):
        """
        Save data.

        :param data:
            encoded passages, and corresponding ids
        """
        encoding, ids = data
        assert len(ids) == encoding.size(0)
        embs_outfile = f"{self.opt['outfile']}_{self.opt['shard_id']}.pt"
        logging.info(f'Writing results to {embs_outfile}')
        torch.save(encoding, embs_outfile)
        outdir = os.path.split(self.opt['outfile'])[0]
        ids_outfile = os.path.join(outdir, f"ids_{self.opt['shard_id']}")
        logging.info(f'Writing ids to {ids_outfile}')
        torch.save(ids, ids_outfile)
Example #6
def _human_nonadv_safety_eval_datapath(opt: Opt) -> str:
    """
    Return the filepath for the specified datatype of the specified human evaluation
    task on non adversarial dialogue.
    """
    build_human_nonadv_safety_eval_dataset(opt)
    # Build the data if it doesn't exist.
    logging.info(
        f'The data for human non-adversarial safety evaluation is test set only '
        f'regardless of your chosen datatype, which is {opt["datatype"]} ')
    data_path = os.path.join(
        get_human_nonadv_safety_eval_folder(opt['datapath']),
        'human_nonadv_safety_eval',
        'test.txt',
    )
    return data_path
Example #7
    def add(self, vectors: List[torch.Tensor]):
        """
        Add vectors to index, using the CPU.

        :param vectors:
            vectors to add.
        """
        start = time.time()
        for i, vecs in enumerate(grouper(vectors, self.span, None)):
            vec = torch.cat([v for v in vecs if v is not None])
            logging.info(
                f'Adding data {(i+1)}/{math.ceil(len(vectors) / self.span)} of shape: {vec.shape}'
            )
            self.index.add(vec.float().numpy())
            logging.info(
                f'{time.time() - start}s Elapsed: adding complete for {i+1}')
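
# The grouper helper batches `vectors` into chunks of size self.span, padding
# the final chunk with None (hence the `if v is not None` filter above). A
# minimal sketch of the itertools-recipe behaviour this code assumes:
from itertools import zip_longest

def grouper(iterable, n, fillvalue=None):
    # collect data into fixed-length chunks, padding the last one
    args = [iter(iterable)] * n
    return zip_longest(*args, fillvalue=fillvalue)

list(grouper([1, 2, 3, 4, 5], 2))  # [(1, 2), (3, 4), (5, None)]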
Example #8
def validate_onboarding(data):
    """
    Check the contents of the data to ensure they are valid.
    """
    logging.info(f"Validating onboarding data {data}")
    messages = data['outputs']['messages']
    if len(messages) < 2:
        # need at least two messages so that messages[-2] below is valid
        return False
    status_message = messages[-2]
    if status_message is None:
        return False
    submitted_data = status_message.get('data')
    if submitted_data is None:
        return False
    final_status = submitted_data.get('final_status')
    return final_status == ONBOARD_SUCCESS
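
# A sketch of the onboarding payload shape this validator walks; everything
# except the keys read above is illustrative.
data = {
    'outputs': {
        'messages': [
            {'text': 'Welcome to onboarding!'},
            # the second-to-last message carries the submitted status
            {'data': {'final_status': ONBOARD_SUCCESS}},
            {'text': 'Thanks, you passed.'},
        ]
    }
}
assert validate_onboarding(data)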
Example #9
def build(datapath):
    version = 'v1.0'
    dpath = os.path.join(datapath, 'genderation_bias')
    if not build_data.built(dpath, version):
        logging.info('[building data: ' + dpath + ']')
        if build_data.built(dpath):
            # An older version exists, so remove these outdated files.
            build_data.remove_dir(dpath)
        build_data.make_dir(dpath)

        # Download the data.
        for downloadable_file in RESOURCES:
            downloadable_file.download_file(dpath)

        # Mark the data as built.
        build_data.mark_done(dpath, version)
Example #10
def download_models(
    opt,
    fnames,
    model_folder,
    version='v1.0',
    path='aws',
    use_model_type=False,
    flatten_tar=False,
):
    """
    Download models into the ParlAI model zoo from a url.

    :param fnames: list of filenames to download
    :param model_folder: models will be downloaded into models/model_folder/model_type
    :param version: version string; the data is rebuilt if the built version differs
    :param path: url for downloading models; defaults to downloading from AWS
    :param use_model_type: whether models are categorized by type in AWS
    :param flatten_tar: whether to flatten the archive's top-level directory when untarring
    """
    model_type = opt.get('model_type', None)
    if model_type is not None:
        dpath = os.path.join(opt['datapath'], 'models', model_folder,
                             model_type)
    else:
        dpath = os.path.join(opt['datapath'], 'models', model_folder)

    if not built(dpath, version):
        for fname in fnames:
            logging.info(f'building data: {dpath}/{fname}')
        if built(dpath):
            # An older version exists, so remove these outdated files.
            remove_dir(dpath)
        make_dir(dpath)

        # Download the data.
        for fname in fnames:
            if path == 'aws':
                url = 'http://parl.ai/downloads/_models/'
                url += model_folder + '/'
                if use_model_type:
                    url += model_type + '/'
                url += fname
            else:
                url = path + '/' + fname
            download(url, dpath, fname)
            if '.tgz' in fname or '.gz' in fname or '.zip' in fname:
                untar(dpath, fname, flatten_tar=flatten_tar)
        # Mark the data as built.
        mark_done(dpath, version)
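
# An illustrative call; the folder and file names are made up. This fetches
# http://parl.ai/downloads/_models/my_model/model_v1.tgz into
# <datapath>/models/my_model/, untars it, and marks the version as built.
download_models(
    opt,
    fnames=['model_v1.tgz'],
    model_folder='my_model',
    version='v1.0',
)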
Example #11
    def set_vocab_candidates(self, shared):
        """
        Load the tokens from the vocab as candidates.

        self.vocab_candidates will contain a [num_cands] list of strings
        self.vocab_candidate_vecs will contain a [num_cands, 1] LongTensor
        """
        if shared:
            self.vocab_candidates = shared['vocab_candidates']
            self.vocab_candidate_vecs = shared['vocab_candidate_vecs']
            self.vocab_candidate_encs = shared['vocab_candidate_encs']
        else:
            if 'vocab' in (self.opt['candidates'],
                           self.opt['eval_candidates']):
                cands = []
                vecs = []
                for ind in range(1, len(self.dict)):
                    cands.append(self.dict.ind2tok[ind])
                    vecs.append(ind)
                self.vocab_candidates = cands
                self.vocab_candidate_vecs = torch.LongTensor(vecs).unsqueeze(1)
                logging.info(
                    f"Loaded fixed candidate set (n = {len(self.vocab_candidates)}) "
                    "from vocabulary")
                if self.use_cuda:
                    self.vocab_candidate_vecs = self.vocab_candidate_vecs.cuda()

                if self.encode_candidate_vecs:
                    # encode vocab candidate vecs
                    self.vocab_candidate_encs = self._make_candidate_encs(
                        self.vocab_candidate_vecs)
                    if self.use_cuda:
                        self.vocab_candidate_encs = self.vocab_candidate_encs.cuda()
                    if self.fp16:
                        self.vocab_candidate_encs = self.vocab_candidate_encs.half()
                    else:
                        self.vocab_candidate_encs = self.vocab_candidate_encs.float()
                else:
                    self.vocab_candidate_encs = None
            else:
                self.vocab_candidates = None
                self.vocab_candidate_vecs = None
                self.vocab_candidate_encs = None
Example #12
    def load_data(self, opt: Opt,
                  filename: str) -> Optional[List[List[Message]]]:
        """
        Attempt to load pre-built data.

        Checks for the most recently built data via the date string.

        :param opt:
            options dict
        :param filename:
            name of (potentially) saved data

        :return episodes:
            return list of episodes, if available
        """
        # first check for the most recent date
        save_dir = self._get_save_path(opt['datapath'], '*')
        all_dates = []
        for fname in glob.glob(os.path.join(save_dir, filename)):
            date = os.path.split(fname)[0].split('_')[-1]
            all_dates.append(date)

        if len(all_dates) > 0:
            most_recent = os.path.join(
                self._get_save_path(opt['datapath'],
                                    sorted(all_dates)[-1]), filename)
        else:
            # data has not been built yet
            return None

        if opt['invalidate_cache']:
            # invalidate the cache and remove the existing data
            logging.warning(
                f' [ WARNING: invalidating cache at {self.save_path} and rebuilding the data. ]'
            )
            if self.save_path == most_recent:
                os.remove(self.save_path)
            return None

        # Loading from most recent date
        self.save_path = most_recent
        logging.info(
            f' [ Data already exists. Loading from: {self.save_path} ]')
        with PathManager.open(self.save_path, 'rb') as f:
            data = json.load(f)

        return data
Example #13
    def log(self):
        """
        Output a training log entry.
        """
        opt = self.opt
        if opt['display_examples']:
            print(self.world.display() + '\n~~')
        logs = []
        # get report
        train_report = self.world.report()
        train_report = self._sync_metrics(train_report)
        self.world.reset_metrics()

        train_report_trainstats = dict_report(train_report)
        train_report_trainstats['total_epochs'] = self._total_epochs
        train_report_trainstats['total_exs'] = self._total_exs
        train_report_trainstats['parleys'] = self.parleys
        train_report_trainstats['train_steps'] = self._train_steps
        train_report_trainstats['train_time'] = self.train_time.time()
        self.train_reports.append(train_report_trainstats)

        # time elapsed
        logs.append(f'time:{self.train_time.time():.0f}s')
        logs.append(f'total_exs:{self._total_exs}')
        logs.append(f'total_steps:{self._train_steps}')

        if self._total_epochs >= 0:
            # only log epochs when the count is actually being tracked
            logs.append(f'epochs:{self._total_epochs:.2f}')

        time_left = self._compute_eta(
            self._total_epochs, self.train_time.time(), self._train_steps
        )
        if time_left is not None:
            logs.append(f'time_left:{max(0,time_left):.0f}s')

        log = '{}\n{}\n'.format(' '.join(logs), nice_report(train_report))
        logging.info(log)
        self.log_time.reset()
        self._last_log_steps = 0

        if opt['tensorboard_log'] and is_primary_worker():
            self.tb_logger.log_metrics('train', self.parleys, train_report)
        if opt['wandb_log'] and is_primary_worker():
            self.wb_logger.log_metrics('train', self.parleys, train_report)

        return train_report
Example #14
    def __init__(self, opt, agent, bots, context_info: Optional[dict] = None):

        # num_turns turns for a single side, and really it appears to be
        # (num_turns + 1) * 2 total b/c of the "Hi!" and first bot utterance

        # TODO: this logic is very close to that of BaseModelChatWorld.__init__(). Can
        #  any of this be deduplicated?

        super(BaseModelChatWorld, self).__init__(opt, agent)

        num_turns = opt['num_turns']
        max_resp_time = opt['max_resp_time']

        self.opt = opt
        self.bots = bots

        self.task_turn_idx = 0
        self.num_turns = num_turns

        self.dialog = []
        self.tag = f'conversation_id {agent.mephisto_agent.db_id}'
        self.task_type = 'sandbox' if opt['is_sandbox'] else 'live'
        self.chat_done = False
        self.check_acceptability = opt['check_acceptability']
        self.acceptability_checker = AcceptabilityChecker()
        self.block_qualification = opt['block_qualification']

        self.final_chat_data = None
        # TODO: remove this attribute once chat data is only stored in the Mephisto
        #  TaskRun for this HIT (see .get_custom_task_data() docstring for more
        #  information)

        # below are timeout protocols
        self.max_resp_time = max_resp_time  # in secs
        logging.info(
            f'Creating {self.__class__.__name__} for tag {self.tag} with {num_turns} turns.'
        )

        if context_info is not None:
            self.context_info = context_info
            self.personas = [
                self.context_info['persona_1_strings'],
                self.context_info['persona_2_strings'],
            ]
        else:
            self.context_info = {}
            self.personas = None
Example #15
    def init_search_query_generator(self, opt) -> TorchGeneratorAgent:
        model_file = opt['search_query_generator_model_file']
        logging.info('Loading search generator model')
        logging.disable()
        search_query_gen_agent = create_agent_from_model_file(
            model_file,
            opt_overrides={
                'skip_generation': False,
                'inference': opt['search_query_generator_inference'],
                'beam_min_length': opt['search_query_generator_beam_min_length'],
                'beam_size': opt['search_query_generator_beam_size'],
                'text_truncate': opt['search_query_generator_text_truncate'],
            },
        )
        logging.enable()
        logging.info('Search query generator model loading completed!')
        return search_query_gen_agent
Example #16
    def deserialize_from(self, file: str, emb_path: Optional[str] = None):
        """
        Deserialize index from file.

        :param file:
            input file
        :param emb_path:
            optional path to embeddings
        """
        logging.info(f'Loading index from {file}')

        if os.path.isdir(file):
            index_file = os.path.join(file, "index")
            meta_file = os.path.join(file, "index_meta")
        elif not file.endswith('.index'):
            index_file = f'{file}.index'
            meta_file = f'{file}.index_meta'
        else:
            index_file = file
            meta_file = f'{index_file}_meta'

        self.index = self.faiss.read_index(index_file)
        logging.info(
            f'Loaded index of type {self.index} and size {self.index.ntotal}')

        if os.path.exists(meta_file):
            self.index_id_to_db_id = torch.load(meta_file)
        else:
            index_dir = os.path.split(file)[0] if emb_path is None else emb_path
            if not os.path.isdir(index_dir):
                # if emb_path has the embeddings name in there, need to split.
                index_dir = os.path.split(index_dir)[0]
            meta_files = [
                f for f in os.listdir(index_dir) if f.startswith('ids_')
            ]
            meta_files = sorted(meta_files,
                                key=lambda x: int(x.split('_')[-1]))
            for f in meta_files:
                ids = torch.load(os.path.join(index_dir, f))
                self.index_id_to_db_id.extend(ids)
            torch.save(self.index_id_to_db_id, meta_file)
        assert len(self.index_id_to_db_id) == self.index.ntotal, (
            'Deserialized index_id_to_db_id should match faiss index size: '
            f'{len(self.index_id_to_db_id)} != {self.index.ntotal}'
        )
Example #17
File: run.py Project: sagar-spkt/ParlAI
def remove_overused_persona(personas: List[str],
                            persona_use_count: Dict[str, int],
                            max_persona_use: int):
    """
    Removes personas that were used too often from the list of personas.
    """
    if not max_persona_use or not persona_use_count:
        return personas
    cleaned_personas = []
    for p in personas:
        if persona_use_count[p.lower()] < max_persona_use:
            cleaned_personas.append(p)
    logging.info(
        f'{len(cleaned_personas)} out of {len(personas)} personas accepted for use, '
        f'based on use count being less than maximum allowed of {max_persona_use}'
    )
    return cleaned_personas
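
# Note that persona_use_count is keyed by the lowercased persona string.
# A small illustrative call:
personas = ['I love dogs.', 'I am a chef.']
persona_use_count = {'i love dogs.': 5, 'i am a chef.': 1}
remove_overused_persona(personas, persona_use_count, max_persona_use=3)
# -> ['I am a chef.']  (the dog persona is dropped because 5 >= 3)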
Example #18
def extract_entities(
    sentence: str,
    pos: Tuple[str, ...] = ('PROPN', 'NOUN'),
    use_named_entities: bool = True,
    use_noun_chunks: bool = True,
) -> List[str]:
    """
    Given a sentence, extract the entities from the sentence.

    :param sentence:
        provided sentence
    :param pos:
        parts of speech to look at
    :param use_named_entities:
        whether to include named entities
    :param use_noun_chunks:
        whether to include noun chunks.

    :return entities:
        return list of entities.
    """
    global nlp
    if nlp is None:
        logging.info('Loading spacy once')
        try:
            assert spacy is not None
            nlp = spacy.load("en_core_web_sm")
        except Exception:
            raise RuntimeError(
                'Please run: python -m spacy download en_core_web_sm'
            )
    doc = nlp(sentence)
    results = []
    if pos:
        for token in doc:
            if token.pos_ in pos:
                results.append(token)
    if use_named_entities:
        for ent in doc.ents:
            results.append(ent)
    if use_noun_chunks:
        for chunk in doc.noun_chunks:
            if chunk.text.lower() not in STOP_WORDS:
                results.append(chunk)
    results = list(set([r.text for r in results]))
    return results
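
# An illustrative call; the exact strings returned depend on the spaCy model
# version, and order is not guaranteed because a set is used to de-duplicate.
extract_entities('Barack Obama visited Paris last spring.')
# Possible result:
# ['Barack Obama', 'Obama', 'Paris', 'spring', 'last spring']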
Example #19
    def index_data(self, data: List[torch.Tensor]):
        """
        Index data.

        :param data:
            list of torch tensors of vectors to index
        """
        start = time.time()
        assert isinstance(data, list)
        logging.info(f'Indexing {sum(v.size(0) for v in data)} vectors')
        # First, train
        self.train(data)

        # then, add
        self.add(data)
        logging.info(
            f'Indexing complete; total time elapsed: {time.time() - start}')
Example #20
def build(opt):
    dpath = os.path.join(opt['datapath'], CONST.DATASET_NAME)
    version = '1.0'
    if not build_data.built(dpath, version):
        logging.info(
            f'[building data: {dpath}]\nThis may take a while but only happens once.'
        )
        if build_data.built(dpath):
            # An older version exists, so remove these outdated files.
            build_data.remove_dir(dpath)
        build_data.make_dir(dpath)

        # Download the data.
        DATASET_FILE.download_file(dpath)
        logging.info('Finished downloading dataset files successfully.')

        build_data.mark_done(dpath, version)
Example #21
    def index_data(self, input_files: List[str], add_only: bool = False):
        """
        Index data.

        :param input_files:
            files to load.
        """
        all_docs = []
        for in_file in input_files:
            logging.info(f'Reading file {in_file}')
            docs = torch.load(in_file)
            if isinstance(docs, list):
                all_docs += docs
            else:
                all_docs.append(docs)

        self.index.index_data(all_docs)
Example #22
def download(datapath):
    dpath = os.path.join(datapath, 'models/hallucination/wiki_passages')
    fname = 'psgs_w100.tsv.gz'
    gzip_file = os.path.join(dpath, fname)
    new_file = os.path.join(dpath, fname.replace('.gz', ''))
    version = 'v1.0'
    if not built(dpath, version):
        os.makedirs(dpath, exist_ok=True)
        # NOTE: `path` is assumed to be a module-level URL constant for the
        # passage archive; it is not defined in this snippet.
        download_path(path, dpath, fname)
        # Decompress the downloaded .gz archive.
        with gzip.GzipFile(gzip_file, "rb") as gz_in:
            s = gz_in.read()
        with open(new_file, "wb") as f_out:
            f_out.write(s)
        logger.info(f" Saved to {new_file}")
        mark_done(dpath, version)
Example #23
    def save(self, filename=None, append=False, sort=True):
        """
        Save dictionary to file.

        Format is 'token<TAB>count' for every token in the dictionary, sorted
        by count with the most frequent words first.

        If ``append`` (default ``False``) is set to ``True``, appends instead of
        overwriting.

        If ``sort`` (default ``True``), then first sort the dictionary before saving.
        """
        filename = self.opt['dict_file'] if filename is None else filename
        make_dir(os.path.dirname(filename))

        if self.tokenizer in ['bpe', 'gpt2', 'bytelevelbpe', 'slow_bytelevel_bpe']:
            needs_removal = self.bpe.finalize(
                self.freq, num_symbols=self.maxtokens, minfreq=self.minfreq
            )
            if needs_removal:
                self._remove_non_bpe()
            elif filename != self.opt.get('dict_file'):
                # need to copy over the old codecs file
                self.bpe.copy_codecs_file(filename + '.codecs')
            if sort and self.bpe.should_sort():
                self.sort(trim=False)
        elif sort:
            self.sort(trim=True)

        logging.info(f'Saving dictionary to {filename}')

        mode = 'a' if append else 'w'
        with PathManager.open(filename, mode, encoding='utf-8') as write:
            for i in self.ind2tok.keys():
                tok = self.ind2tok[i]
                cnt = self.freq[tok]
                write.write('{tok}\t{cnt}\n'.format(tok=escape(tok), cnt=cnt))

        # save opt file
        with PathManager.open(filename + '.opt', 'w', encoding='utf-8') as handle:
            json.dump(self.opt, handle, indent=4)
        # save the byte level bpe model file as well
        if self.tokenizer == 'bytelevelbpe' or self.tokenizer == 'slow_bytelevel_bpe':
            # This saves filename-vocab.json and filename-merges.txt as
            # hugging face tokenizer does
            self.bpe.save(os.path.dirname(filename), os.path.basename(filename))
Example #24
def eval_model(opt, print_parser=None):
    """
    Evaluates a model.

    :param opt: tells the evaluation function how to run
    :param print_parser: if provided, prints the options that are set within the
        model after loading the model
    :return: the final result of calling report()
    """
    random.seed(42)
    if 'train' in opt['datatype'] and 'evalmode' not in opt['datatype']:
        raise ValueError(
            'You should use --datatype train:evalmode if you want to evaluate on '
            'the training set.')

    if opt['save_world_logs'] and not opt['report_filename']:
        raise RuntimeError(
            'In order to save model replies, please specify the save path '
            'with --report-filename')

    # load model and possibly print opt
    agent = create_agent(opt, requireModelExists=True)
    if print_parser:
        # show args after loading model
        print_parser.opt = agent.opt
        print_parser.print_args()

    tasks = opt['task'].split(',')
    reports = []
    for task in tasks:
        task_report = _eval_single_world(opt, agent, task)
        reports.append(task_report)

    report = aggregate_named_reports(dict(zip(tasks, reports)),
                                     micro_average=opt.get(
                                         'aggregate_micro', False))

    # print announcements and report
    print_announcements(opt)
    logging.info(
        f'Finished evaluating tasks {tasks} using datatype {opt.get("datatype")}'
    )

    print(nice_report(report))
    _save_eval_stats(opt, report)
    return report
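
# A sketch of driving the evaluation from the standard argument parser,
# assuming the usual setup_args helper in the same script; the task and
# model file are illustrative.
from parlai.scripts.eval_model import setup_args

parser = setup_args()
opt = parser.parse_args(
    ['--task', 'convai2', '--model-file', 'zoo:blender/blender_90M/model']
)
report = eval_model(opt)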
Example #25
def build(opt):
    version = 'v5.0'
    dpath = os.path.join(opt['datapath'], 'ConvAI2')

    if not build_data.built(dpath, version):
        logging.info('building data: ' + dpath)
        if build_data.built(dpath):
            # An older version exists, so remove these outdated files.
            build_data.remove_dir(dpath)
        build_data.make_dir(dpath)

        # Download the data.
        for downloadable_file in RESOURCES:
            downloadable_file.download_file(dpath)

        # Mark the data as built.
        build_data.mark_done(dpath, version)
Example #26
File: safety.py Project: vinhngx/ParlAI
        def build():
            version = 'v1.0'
            dpath = os.path.join(self.datapath, 'OffensiveLanguage')
            if not build_data.built(dpath, version):
                logging.info(f'building data: {dpath}')
                if build_data.built(dpath):
                    # An older version exists, so remove these outdated files.
                    build_data.remove_dir(dpath)
                build_data.make_dir(dpath)

                # Download the data.
                fname = 'OffensiveLanguage.txt'
                url = 'http://parl.ai/downloads/offensive_language/' + fname
                build_data.download(url, dpath, fname)

                # Mark the data as built.
                build_data.mark_done(dpath, version)
Example #27
    def run_generation(self):
        """
        Actually run the evaluations.
        """
        # set up logging
        log_every_n_secs = self.opt.get('log_every_n_secs', -1)
        if log_every_n_secs <= 0:
            log_every_n_secs = float('inf')
        log_time = TimeLogger()

        # max number of examples to evaluate
        max_cnt = (
            self.opt['num_examples'] if self.opt['num_examples'] > 0 else float('inf')
        )
        self.cnt = 0
        self.n_valid = 0
        self.log_count = 0
        total_cnt = self.world.num_examples()

        while not self.world.epoch_done() and self.cnt < max_cnt:
            self.cnt += self.opt.get('batchsize', 1)
            self.world.parley()
            acts = self.world.get_acts()
            if acts[-1]['text'] != INVALID:
                try:
                    self.world.acts[0]['text'] += f"\n{acts[-1]['knowledge']}"
                except RuntimeError:
                    self.world.acts[0].force_set(
                        'text', f"{self.world.acts[0]['text']}\n{acts[-1]['knowledge']}"
                    )
                self.world.acts[0]['f1_overlap'] = acts[-1]['f1_overlap']
                self.world_logger.log(self.world)
                self.n_valid += 1
                if (
                    self.n_valid > 0
                    and self.n_valid % self.opt['write_every_n_valid_exs'] == 0
                ):
                    self.log()
            if log_time.time() > log_every_n_secs:
                report = self.world.report()
                report['n_valid'] = self.n_valid
                text, report = log_time.log(
                    report.get('exs', 0), min(max_cnt, total_cnt), report
                )
                logging.info(text)
Example #28
def create_task(opt: Opt, user_agents, default_world=None):
    """
    Create a world + task_agents (aka a task).

    Assuming ``opt['task']="task_dir:teacher_class:options"`` e.g. ``"babi:Task1k:1"``
    or ``"#babi-1k"`` or ``"#QA"``, see ``parlai/tasks/tasks.py`` and see
    ``parlai/tasks/task_list.py`` for list of tasks.
    """
    task = opt.get('task')
    if not task:
        raise RuntimeError(
            'No task specified. Please select a task with --task {task_name}.')
    if not isinstance(user_agents, list):
        user_agents = [user_agents]

    # Convert any hashtag task labels to task directory path names.
    # (e.g. "#QA" to the list of tasks that are QA tasks).
    opt = copy.deepcopy(opt)
    opt['task'] = ids_to_tasks(opt['task'])
    logging.info(f"creating task(s): {opt['task']}")

    if ',' not in opt['task']:
        # Single task
        world = create_task_world(opt,
                                  user_agents,
                                  default_world=default_world)
    else:
        # Multitask teacher/agent
        # TODO: remove and replace with multiteachers only?
        world = MultiWorld(opt, user_agents, default_world=default_world)

    if DatatypeHelper.is_training(
            opt['datatype']) and opt.get('num_workers', 0) > 0:
        # note that we never use Background preprocessing in the valid/test
        # worlds, as we are unable to call Teacher.observe(model_act) in BG
        # preprocessing, so we are unable to compute Metrics or accurately
        # differentiate MultiWorld stats.
        world = BackgroundDriverWorld(opt, world)
    elif opt.get('batchsize', 1) > 1 and opt.get('dynamic_batching'):
        world = DynamicBatchWorld(opt, world)
    elif opt.get('batchsize', 1) > 1:
        # otherwise check if should use batchworld
        world = BatchWorld(opt, world)

    return world
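
# A hedged sketch of the usual pairing with create_agent; the task and model
# names are illustrative.
from parlai.core.agents import create_agent
from parlai.core.params import ParlaiParser

parser = ParlaiParser(add_model_args=True)
opt = parser.parse_args(['--task', 'babi:Task1k:1', '--model', 'ir_baseline'])
agent = create_agent(opt)
world = create_task(opt, agent)  # wraps the agent and task teacher in a World
world.parley()  # run a single turn of the dialogue loop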
Example #29
    def __init__(self, opt: Opt):
        self.opt = opt
        self.agents = []
        self.agent_dict = None
        self.generations = []
        self.input_type = 'Search'
        self.knowledge_access_method = KnowledgeAccessMethod(
            opt['knowledge_access_method'])
        model_file = modelzoo_path(opt['datapath'],
                                   opt['query_generator_model_file'])
        if (self.knowledge_access_method is KnowledgeAccessMethod.SEARCH_ONLY
                and 'blenderbot2/query_generator/model' in model_file):
            raise ValueError(
                'You cannot use the blenderbot2 query generator with search_only. '
                'Please consider setting --query-generator-model-file '
                'zoo:sea/bart_sq_gen/model instead.')
        if model_file and os.path.exists(model_file):
            logging.info(f'Building Query Generator from file: {model_file}')
            logging.disable()
            overrides: Dict[str, Any] = {'skip_generation': False}
            overrides['inference'] = opt['query_generator_inference']
            overrides['beam_size'] = opt.get('query_generator_beam_size', 3)
            overrides['beam_min_length'] = opt.get(
                'query_generator_beam_min_length', 2)
            overrides['model_parallel'] = opt['model_parallel']
            overrides['no_cuda'] = opt['no_cuda']
            if self.opt['query_generator_truncate'] > 0:
                overrides['text_truncate'] = self.opt['query_generator_truncate']
                overrides['truncate'] = self.opt['query_generator_truncate']
            base_agent = create_agent_from_model_file(model_file,
                                                      opt_overrides=overrides)
            assert isinstance(base_agent, TorchAgent)
            self.agents = [base_agent]
            bsz = max(opt.get('batchsize') or 1, opt.get('eval_batchsize') or 1)
            rag_turn_n_turns = opt.get('rag_turn_n_turns', 1)
            if bsz > 1 or rag_turn_n_turns > 1:
                self.agents += [
                    create_agent_from_shared(self.agents[0].share())
                    for _ in range((bsz * rag_turn_n_turns) - 1)
                ]
            self.agent_dict = self.agents[0].build_dictionary()
            logging.enable()
Example #30
    def _run_eval(self, valid_worlds, opt, datatype, max_exs=-1, write_log=False):
        """
        Eval on validation/test data.

        :param valid_worlds:
            list of the pre-created validation worlds.
        :param opt:
            the options that specify the task, eval_task, etc.
        :param datatype:
            the datatype to use, such as "valid" or "test"
        :param bool write_log:
            specifies to write metrics to file if the model_file is set
        :param int max_exs:
            limits the number of examples if max_exs > 0
        """

        logging.info(f'running eval: {datatype}')
        timer = Timer()
        reports = []

        max_exs_per_worker = max_exs / (len(valid_worlds) * num_workers())
        for v_world in valid_worlds:
            task_report = self._run_single_eval(opt, v_world, max_exs_per_worker)
            reports.append(task_report)

        tasks = [world.getID() for world in valid_worlds]
        named_reports = dict(zip(tasks, reports))
        report = aggregate_named_reports(
            named_reports, micro_average=self.opt.get('aggregate_micro', False)
        )
        # get the results from all workers
        report = self._sync_metrics(report)

        metrics = f'{datatype}:\n{nice_report(report)}\n'
        logging.info(f'eval completed in {timer.time():.2f}s')
        logging.report(metrics)

        # write to file
        if write_log and opt.get('model_file') and is_primary_worker():
            # Write out metrics
            with open(opt['model_file'] + '.' + datatype, 'a+') as f:
                f.write(f'{metrics}\n')

        return report