Exemplo n.º 1
0
def cl_build_ref_agent(self):
    """Build the reference agents required for CL training.

    Reads ``self.opt['ref_model_file']`` and loads two independent agents from
    it via ``create_agent_from_opt_file``: ``self.ref_agent`` and
    ``self.eval_ref_agent``. Also sets ``self.use_external_ref_model``:
    ``False`` when the reference agent has the same ``id`` as ``self`` and
    ``dict_same(self, ref_agent)`` is true, otherwise ``True``.

    Raises:
        RuntimeError: if no reference model file is configured (unset, empty,
            or the string ``"none"`` in any letter case), or if either
            reference agent fails to build.
    """
    ref_model_file = self.opt['ref_model_file']
    # An empty string is rejected here like an unset option, instead of being
    # handed to the agent factory and failing somewhere deeper.
    if not ref_model_file or ref_model_file.lower() == "none":
        raise RuntimeError("CL training requires reference model!")

    from parlai.core.agents import create_agent_from_opt_file

    def _build_ref():
        # Load one agent from the reference model file; fail loudly on None.
        agent = create_agent_from_opt_file(Opt({'model_file': ref_model_file}))
        if agent is None:
            raise RuntimeError(
                "Build reference model failed! check your `ref_model_file`:{}!"
                .format(ref_model_file))
        return agent

    # Validate the first agent before spending time loading the second one.
    ref_agent = _build_ref()
    eval_ref_agent = _build_ref()

    # The reference model counts as "internal" only when it is the same agent
    # type with an identical dictionary.
    self.use_external_ref_model = not (
        self.id == ref_agent.id and dict_same(self, ref_agent))
    # NOTE: an explicit tok2ind/ind2tok equality check between self.dict and
    # ref_agent.dict was deliberately dropped as unnecessary.

    self.eval_ref_agent = eval_ref_agent
    self.ref_agent = ref_agent
Exemplo n.º 2
0
    def __init__(self, model: str, device: str, maxlen: int) -> None:
        """Load a BlenderBot checkpoint from the ParlAI model zoo.

        Args:
            model: agent name, normalized via ``self.check_agent``. The zoo
                checkpoint size is chosen by the first size keyword it
                contains ("xxlarge", "xlarge", "large", "medium", "small").
            device: forwarded to ``self.set_options`` and the base class.
            maxlen: maximum length forwarded to the base class; values <= 0
                fall back to ``self.default_maxlen()``.

        Raises:
            ValueError: if ``model`` contains none of the size keywords.
        """
        model = self.check_agent(model)
        maxlen = maxlen if maxlen > 0 else self.default_maxlen()

        # Order matters: "xxlarge" contains "xlarge", and both contain
        # "large", so the longer keywords must be tested first.
        size_by_keyword = (
            ("xxlarge", "9B"),
            ("xlarge", "3B"),
            ("large", "1Bdistill"),
            ("medium", "400Mdistill"),
            ("small", "90M"),
        )
        size = next(
            (label for keyword, label in size_by_keyword if keyword in model),
            None,
        )
        if size is None:
            # ValueError subclasses Exception, so existing ``except Exception``
            # handlers keep catching it; the message now names the bad input.
            raise ValueError(f"wrong model: {model!r}")

        option = self.set_options(
            name=f"zoo:blender/blender_{size}/model",
            device=device,
        )

        super().__init__(
            name=model,
            suffix="\n",
            device=device,
            maxlen=maxlen,
            model=create_agent_from_opt_file(option),
        )
Exemplo n.º 3
0
    def test_init_from_from_checkpoint(self):
        """Check that a '.checkpoint' init_model is kept when loading from an opt file.

        The saved ``something_else.opt`` points ``init_model`` at a different
        file (``something``). The load-time opt sets ``load_from_checkpoint``
        and points ``init_model`` at ``something_else.checkpoint``. The test
        asserts that the agent built by ``create_agent_from_opt_file`` ends
        up with an ``init_model`` containing '.checkpoint'.
        """
        with testing_utils.tempdir() as tmp:
            model_file = os.path.join(tmp, 'something_else')
            checkpoint_path = os.path.join(tmp, 'something_else.checkpoint')

            # Options persisted next to the model file, as a prior run would.
            saved_opt = dict(
                datapath='dummy_path',
                model='repeat_label',
                init_model=os.path.join(tmp, 'something'),
                model_file=model_file,
            )
            # Options supplied at load time, requesting the checkpoint.
            requested_opt = Opt(
                {
                    'datapath': 'dummy_path',
                    'model': 'repeat_label',
                    'init_model': checkpoint_path,
                    'model_file': model_file,
                    'load_from_checkpoint': True,
                }
            )

            opt_path = os.path.join(tmp, 'something_else.opt')
            with open(opt_path, 'w') as handle:
                handle.write(json.dumps(saved_opt))

            loaded = create_agent_from_opt_file(requested_opt)
            loaded_init_model = loaded.opt['init_model']
            # The checkpoint path from the load-time opt must win.
            assert '.checkpoint' in loaded_init_model
Exemplo n.º 4
0
    def __init__(self, model, device, maxlen):
        """Wrap the ParlAI zoo sensitive-topics classifier.

        Args:
            model: name recorded for this agent (forwarded as ``name``).
            device: forwarded to ``self.set_options`` and the base class.
            maxlen: forwarded unchanged to the base class.
        """
        classifier_opt = self.set_options(
            name="zoo:sensitive_topics_classifier/model",
            device=device,
        )
        classifier = create_agent_from_opt_file(classifier_opt)

        # Outputs from this agent get no suffix appended.
        super(SensitiveAgent, self).__init__(
            name=model,
            suffix="",
            model=classifier,
            device=device,
            maxlen=maxlen,
        )
Exemplo n.º 5
0
    def __init__(self, model, device, maxlen):
        model = self.check_agent(model)
        maxlen = maxlen if maxlen > 0 else self.default_maxlen()

        model = model + "_ft" if model != "all_tasks_mt" else model
        name = f"zoo:dodecadialogue/{model.split('.')[-1]}/model"
        option = self.set_options(name, device)

        super().__init__(
            name=model,
            suffix="\n",
            device=device,
            maxlen=maxlen,
            model=create_agent_from_opt_file(option),
        )

        if "wizard_of_wikipedia" in name:
            inherit(self, (WizardOfWikipediaAgent, Seq2SeqLM))
            self.build_wizard_of_wikipedia()

        elif "convai2" in name:
            inherit(self, (ConvAI2Agent, Seq2SeqLM))