예제 #1
0
    def _load_known_opts(self, optfile, parsed):
        """
        Pull in CLI args for proper models/tasks/etc. from a saved opt file.

        Runs before args are parsed: every key from ``optfile`` is copied into
        ``parsed`` unless ``parsed`` already holds a non-``None`` value for it.
        ``_load_opts`` is used for actually overriding opts after parsing.

        :param optfile:
            path handed to ``Opt.load``
        :param parsed:
            mapping of already-parsed values; updated in place
        """
        for name, saved_value in Opt.load(optfile).items():
            already_set = name in parsed and parsed[name] is not None
            if already_set:
                # Existing command line parameters take priority.
                continue
            parsed[name] = saved_value
예제 #2
0
 def test_nodatafile(self):
     """
     Building ``NoDatafileTeacher`` with an opt that has no datafile should
     raise ``KeyError`` for each of the listed datatypes.
     """
     datatypes = (
         'train:ordered',
         'train:stream:ordered',
         'valid',
         'test',
         'valid:stream',
         'test:stream',
     )
     for datatype in datatypes:
         teacher_opt = Opt({'datatype': datatype, 'datapath': '/tmp', 'task': 'test'})
         with self.assertRaises(KeyError):
             NoDatafileTeacher(teacher_opt)
예제 #3
0
 def _load_opts(self, opt):
     """
     Override ``opt`` in place with the values saved in ``opt['init_opt']``.

     Keys already present in ``opt['override']`` are left untouched, so
     existing command line parameters take priority. Every other key from the
     file is written into both ``opt`` and ``opt['override']``.

     :raises RuntimeError:
         if the file contains a key that ``opt`` does not have.
     """
     file_opt = Opt.load(opt.get('init_opt'))
     for name, file_value in file_opt.items():
         if name not in opt:
             raise RuntimeError(
                 'Trying to set opt from file that does not exist: ' + str(name)
             )
         if name in opt['override']:
             # Existing command line parameters take priority.
             continue
         opt[name] = file_value
         opt['override'][name] = file_value
예제 #4
0
    def _initialize_bart(self, opt: Opt) -> Opt:
        """
        Download and convert BART pre-trained models.

        Unless ``opt['converting']`` is set, the BART data is downloaded into
        ``opt['datapath']`` and ``init_model`` is pointed at the bundled
        ``bart_large`` model. A checkpoint given via ``init_fairseq_model`` is
        converted with ``self._convert_model``.

        :param opt:
            ParlAI-parsed options

        :return opt:
            opt updated with ``BART_ARGS``.
        """
        converting = opt.get('converting')
        if not converting:
            datapath = opt['datapath']
            download(datapath)
            opt['init_model'] = os.path.join(datapath, 'models/bart/bart_large/model')

        if opt.get('init_fairseq_model'):
            opt = self._convert_model(opt)

        opt.update(BART_ARGS)
        # NOTE(review): opt is compared against itself; presumably the helper
        # checks it against the opt file stored with init_model — confirm.
        compare_init_model_opts(opt, opt)
        return opt
예제 #5
0
def get_dialogue_task_mutators(opt: Opt) -> str:
    """
    Set the mutators appropriately for the dialogue tasks.

    Returns the fixed dialogue mutators joined with ``+``, followed by any
    user-supplied ``opt['mutators']``. The final value is logged as a warning.
    """
    parts = [
        'flatten',
        'extract_entity_for_response_model',
        'skip_retrieval_mutator',
    ]
    user_mutators = opt.get('mutators')
    if user_mutators:
        parts.append(user_mutators)
    mutators = '+'.join(parts)
    logging.warning(f'overriding mutators to {mutators}')
    return mutators
예제 #6
0
 def __init__(self, opt: Opt, shared=None):
     """
     Initializes reranker.

     Reads the reranker settings from ``opt``, resolves the predictor model
     file through ``modelzoo_path`` and then calls ``init_predictor``.
     """
     datapath = opt['datapath']
     self.predictor_model_file = modelzoo_path(datapath, opt['predictor_model_file'])
     self.reranker_strategy = opt['reranker_strategy']
     self.normalize_candidates = opt['normalize_candidates']
     self.delimiter = opt.get('delimiter', '\n')
     # Fixed defaults for this reranker.
     self.include_context, self.include_label_cand_only = True, False
     self.init_predictor(opt, shared)
예제 #7
0
    def test_multitask(self):
        """
        Test that model correctly handles multiple inputs.

        Random chance is 10%, so validation accuracy above 0.2 should be
        reachable quickly.
        """
        multitask_opt = Opt({**self.base_args, **self.multitask_args})
        valid, _test = testing_utils.train_model(multitask_opt)
        accuracy = valid['accuracy']
        assert (
            accuracy > 0.2
        ), f'ImagePolyencoderAgent val-set accuracy on a simple task was {valid["accuracy"].value():0.2f}.'
예제 #8
0
파일: worlds.py 프로젝트: rewicks/ParlAI
def create_task(opt: Opt, user_agents, default_world=None):
    """
    Create a world + task_agents (aka a task).

    Assuming ``opt['task']="task_dir:teacher_class:options"`` e.g. ``"babi:Task1k:1"``
    or ``"#babi-1k"`` or ``"#QA"``, see ``parlai/tasks/tasks.py`` and see
    ``parlai/tasks/task_list.py`` for list of tasks.

    The world is wrapped in ``HogwildWorld`` when ``numthreads > 1``,
    otherwise in ``BatchWorld`` when ``batchsize > 1``.

    :raises RuntimeError: if ``opt`` has no task.
    """
    if not opt.get('task'):
        raise RuntimeError(
            'No task specified. Please select a task with ' + '--task {task_name}.'
        )
    agents = user_agents if type(user_agents) == list else [user_agents]

    # Work on a copy: hashtag task labels (e.g. "#QA") are expanded into the
    # task directory path names they stand for.
    task_opt = copy.deepcopy(opt)
    task_opt['task'] = ids_to_tasks(task_opt['task'])
    print('[creating task(s): ' + task_opt['task'] + ']')

    if ',' in task_opt['task']:
        # Multitask teacher/agent
        # TODO: remove and replace with multiteachers only?
        world = MultiWorld(task_opt, agents, default_world=default_world)
    else:
        world = create_task_world(task_opt, agents, default_world=default_world)

    if task_opt.get('numthreads', 1) > 1:
        # Hogwild world will create sub batch worlds as well if bsz > 1.
        return HogwildWorld(task_opt, world)
    if task_opt.get('batchsize', 1) > 1:
        return BatchWorld(task_opt, world)
    return world
예제 #9
0
 def __init__(self, opt: Opt, agents=None, shared=None, default_world=None):
     """
     Build one sub-world per comma-separated task in ``opt['task']``.

     With ``shared``, each sub-world is rebuilt from ``shared['worlds']``;
     otherwise a world is created per task from a deep copy of ``opt``.
     Also computes cumulative task weights from ``multitask_weights`` and
     rejects sub-worlds whose IDs collide.

     :raises AssertionError: if two sub-worlds share the same ID.
     """
     super().__init__(opt)
     self.worlds: List[World] = []
     for index, task_name in enumerate(opt['task'].split(',')):
         task_name = task_name.strip()
         if not task_name:
             continue
         if shared:
             # Create worlds based on shared data.
             world_share = shared['worlds'][index]
             self.worlds.append(
                 world_share['world_class'](world_share['opt'], None, world_share)
             )
             continue
         # Agents are already specified.
         single_opt = copy.deepcopy(opt)
         single_opt['task'] = task_name
         self.worlds.append(
             create_task_world(single_opt, agents, default_world=default_world)
         )

     self.world_idx = -1
     self.new_world = True
     self.parleys = -1
     # Check to see if we are training
     self.is_training = DatatypeHelper.is_training(opt.get('datatype'))

     # Cumulative multi-task weights; tasks without an explicit weight get 1.
     self.task_choices = range(len(self.worlds))
     weights = self.opt.get('multitask_weights', [1])
     if weights == 'stochastic':
         weights = [w.num_episodes() for w in self.worlds]
     running_total = 0
     self.cum_task_weights = []
     for i in self.task_choices:
         running_total += weights[i] if i < len(weights) else 1
         self.cum_task_weights.append(running_total)

     # Having overlap in teacher ids will cause issues for metrics aggregation.
     task_ids: Dict[str, Teacher] = {}
     for each_world in self.worlds:
         world_id = each_world.getID()
         if world_id in task_ids:
             raise AssertionError(
                 '{} and {} teachers have overlap in id {}.'.format(
                     task_ids[world_id],
                     each_world.get_agents()[0].__class__,
                     world_id,
                 )
             )
         task_ids[world_id] = each_world.get_task_agent()
예제 #10
0
def bpe_factory(opt: Opt, shared: TShared) -> 'BPEHelper':
    """
    BPE Helper Factory.

    Returns the appropriate BPE helper given the opt
    as well as available libraries. When ``bytelevelbpe`` is requested but
    the HuggingFace ``tokenizers`` package cannot be imported, a dictionary
    that is already loaded falls back to ``slow_bytelevel_bpe``; otherwise an
    ``ImportError`` is raised.

    :param opt:
        options
    :param shared:
        shared dict

    :return BPEHelper:
        returns the appropriate BPEHelper object
    """
    from parlai.core.dict import DictionaryAgent

    tokenizer = opt.get('dict_tokenizer', DictionaryAgent.default_tok)
    bpe_helper: Optional[BPEHelper] = None

    if tokenizer == 'bytelevelbpe':
        # Attempt to instantiate HF tokenizer
        try:
            bpe_helper = HuggingFaceBpeHelper(opt, shared)
        except ImportError:
            if not opt['dict_loaded']:
                raise ImportError(
                    'Please install HuggingFace tokenizer with: pip install tokenizers.\n'
                )
            warn_once(
                ''
                '\n\n--------------------------------------------------\n\n'
                'WARNING: You have chosen to use Huggingface\'s tokenizer.\n'
                'Please install HuggingFace tokenizer with: pip install tokenizers.\n'
                'For now, defaulting to the GPT2Tokenizer.'
                '\n\n--------------------------------------------------\n\n'
            )
            # Fall through to the slow byte-level helper below.
            tokenizer = 'slow_bytelevel_bpe'

    if tokenizer == 'slow_bytelevel_bpe':
        bpe_helper = SlowBytelevelBPE(opt, shared)
    elif tokenizer == 'gpt2':
        bpe_helper = Gpt2BpeHelper(opt, shared)
    elif tokenizer == 'bpe':
        bpe_helper = SubwordBPEHelper(opt, shared)

    assert (
        bpe_helper is not None
    ), f"bpe_factory called with invalid tokenizer: {tokenizer}"
    return bpe_helper
 def test_gpt2_bpe_tokenize(self):
     """
     Round-trip a string containing an emoji through the GPT-2 tokenizer.
     """
     text = u'Hello, ParlAI! \U0001f600'  # grinning face emoji
     agent = DictionaryAgent(Opt({'dict_tokenizer': 'gpt2', 'datapath': './data'}))
     self.assertEqual(agent.gpt2_tokenize(text), GPT2_BPE_RESULT)
     token_ids = (agent.tok2ind[tok] for tok in GPT2_BPE_RESULT)
     self.assertEqual(agent.vec2txt(token_ids), text)
예제 #12
0
def get_dialogue_task_mutators(opt: Opt) -> str:
    """
    Set the mutators appropriately for the dialogue tasks.

    Returns the fixed dialogue mutators joined with ``+``, followed by any
    user-supplied ``opt['mutators']``. The final value is logged as a warning.
    """
    parts = [
        'flatten',
        'skip_retrieval_mutator',
        'bst_tasks_maybe_generate_search_query_mutator',
    ]
    user_mutators = opt.get('mutators')
    if user_mutators:
        parts.append(user_mutators)
    mutators = '+'.join(parts)
    logging.warning(f'overriding mutators to {mutators}')
    return mutators
예제 #13
0
 def __init__(self, opt: Opt, shared=None):
     """
     Setup reranker.

     The reranker is built fresh when ``shared`` is falsy, otherwise from
     ``shared['reranker']``. ``inference_strategies`` falls back to
     ``opt['inference']`` when empty.
     """
     super().__init__(opt, shared)
     reranker_class = self.get_reranker_class()
     strategies = opt['inference_strategies'] or opt['inference']
     self.inference_strategies = strategies.split(',')
     self.debug_mode = opt.get('debug_mode', False)
     reranker_shared = shared['reranker'] if shared else None
     self.reranker = reranker_class(opt, shared=reranker_shared)
예제 #14
0
    def test_safe_personas(self):
        """
        Check the number of loaded personas with and without
        ``safe_personas_only``.
        """
        base_kwargs = Opt({'datatype': 'train', 'task': 'blended_skill_talk'})
        expected_counts = {False: 4819, True: 3890}
        for safe_only, expected in expected_counts.items():
            full_kwargs = {**base_kwargs, 'safe_personas_only': safe_only}
            parser = setup_args()
            parser.set_defaults(**full_kwargs)
            parsed_opt = parser.parse_args([])
            self.assertEqual(len(_load_personas(parsed_opt)), expected)
예제 #15
0
    def test_opt(self):
        """
        ``Opt`` records each assignment in ``history`` and keeps it through
        ``deepcopy``.
        """
        opt = Opt({'x': 0})
        opt['x'] += 1
        opt['x'] = 10

        original_history = opt.history['x']
        self.assertEqual(original_history[0][1], 1, 'History not set properly')
        self.assertEqual(original_history[1][1], 10, 'History not set properly')

        copied_history = deepcopy(opt).history['x']
        self.assertEqual(copied_history[0][1], 1, 'Deepcopy history not set properly')
        self.assertEqual(copied_history[1][1], 10, 'Deepcopy history not set properly')
예제 #16
0
def _all_split_datafiles(opt: Opt) -> List[str]:
    """
    List the datafiles of every split for ``opt['cmu_dog_split_type']``.

    For the seen/unseen split types the full set of dialogs is spread across
    the seen train/valid/test files plus the unseen test file; any other split
    type contributes one file per train/valid/test split.
    """
    split_type = SplitType(opt.get("cmu_dog_split_type"))
    splits = ('train', 'valid', 'test')
    if split_type in {SplitType.SEEN, SplitType.UNSEEN}:
        return [
            *(_datafile(split, SplitType.SEEN) for split in splits),
            _datafile('test', SplitType.UNSEEN),
        ]
    return [_datafile(split, split_type) for split in splits]
예제 #17
0
def create_task(opt: Opt, user_agents, default_world=None):
    """
    Create a world + task_agents (aka a task).

    Assuming ``opt['task']="task_dir:teacher_class:options"`` e.g. ``"babi:Task1k:1"``
    or ``"#babi-1k"`` or ``"#QA"``, see ``parlai/tasks/tasks.py`` and see
    ``parlai/tasks/task_list.py`` for list of tasks.

    When ``batchsize > 1`` the world is wrapped in ``DynamicBatchWorld`` if
    ``dynamic_batching`` is set, otherwise in ``BatchWorld``.

    :raises RuntimeError: if ``opt`` has no task.
    """
    if not opt.get('task'):
        raise RuntimeError('No task specified. Please select a task with ' +
                           '--task {task_name}.')
    agents = user_agents if type(user_agents) == list else [user_agents]

    # Work on a copy: hashtag task labels (e.g. "#QA") are expanded into the
    # task directory path names they stand for.
    task_opt = copy.deepcopy(opt)
    task_opt['task'] = ids_to_tasks(task_opt['task'])
    logging.info(f"creating task(s): {task_opt['task']}")

    if ',' in task_opt['task']:
        # Multitask teacher/agent
        # TODO: remove and replace with multiteachers only?
        world = MultiWorld(task_opt, agents, default_world=default_world)
    else:
        world = create_task_world(task_opt, agents, default_world=default_world)

    if task_opt.get('batchsize', 1) > 1:
        if task_opt.get('dynamic_batching'):
            world = DynamicBatchWorld(task_opt, world)
        else:
            world = BatchWorld(task_opt, world)
    return world
예제 #18
0
    def __init__(self, opt: Opt, dictionary: DictionaryAgent):
        """
        Build the image-aware context encoder.

        ``n_positions`` comes from ``opt['n_positions']`` when set; otherwise
        it is the largest of ``truncate``, ``text_truncate`` and
        ``label_truncate``, defaulting to 1024 when none of them is set.
        """
        if opt.get('n_positions'):
            # The number of positions is explicitly provided; use that.
            n_positions = opt['n_positions']
        else:
            # Otherwise use the worst case from truncate, or default to 1024.
            truncations = (
                opt.get('truncate') or 0,
                opt.get('text_truncate') or 0,
                opt.get('label_truncate') or 0,
            )
            n_positions = max(truncations) or 1024

        super().__init__(opt, dictionary)
        encoder_kwargs = dict(
            n_heads=opt['n_heads'],
            n_layers=opt['n_layers'],
            embedding_size=opt['embedding_size'],
            ffn_size=opt['ffn_size'],
            vocabulary_size=len(dictionary),
            embedding=self.embeddings,
            dropout=opt['dropout'],
            attention_dropout=opt['attention_dropout'],
            relu_dropout=opt['relu_dropout'],
            padding_idx=self.pad_idx,
            learn_positional_embeddings=opt['learn_positional_embeddings'],
            embeddings_scale=opt['embeddings_scale'],
            n_positions=n_positions,
            n_segments=opt.get('n_segments', 0),
            activation=opt['activation'],
            variant=opt['variant'],
            output_scaling=opt['output_scaling'],
            image_encoder_num_layers=opt['image_encoder_num_layers'],
            image_features_dim=opt['image_features_dim'],
            fusion=opt['image_fusion_type'],
        )
        self.encoder = ContextWithImageEncoder(**encoder_kwargs)
예제 #19
0
    def test_asymmetry(self):
        """
        Check encoder/decoder layer counts for combinations of ``n_layers``,
        ``n_encoder_layers`` and ``n_decoder_layers``.
        """
        # (opt overrides, expected encoder layers, expected decoder layers)
        cases = [
            ({'n_layers': 1}, 1, 1),
            ({'n_layers': 1, 'n_encoder_layers': 2}, 2, 1),
            ({'n_layers': 1, 'n_encoder_layers': 2, 'n_decoder_layers': 4}, 2, 4),
            ({'n_layers': 1, 'n_decoder_layers': 4}, 1, 4),
            ({}, 2, 2),
        ]
        for overrides, encoder_layers, decoder_layers in cases:
            agent = create_agent(Opt({'model': 'transformer/generator', **overrides}))
            self.assertEqual(agent.model.encoder.n_layers, encoder_layers)
            self.assertEqual(agent.model.decoder.n_layers, decoder_layers)
예제 #20
0
    def get_encoder(
        self,
        opt: Opt,
        dict_: DictionaryAgent,
        embeddings: torch.nn.Embedding,
        module_klass: Type[TransformerEncoder],
        null_idx: int,
        reduction_type: Optional[str] = None,
    ):
        """
        Return encoder, given options.

        Ensures that multiobjective options are copied correctly: the encoder
        is built from a deep copy of ``opt`` whose ``n_heads`` comes from
        ``n_multiobjective_heads`` (default 4) and whose ``n_layers``,
        ``n_encoder_layers`` and ``n_decoder_layers`` all come from
        ``n_multiobjective_layers`` (default 2). The caller's ``opt`` is not
        modified.

        :param opt:
            opt dict
        :param dict_:
            dictionary agent; its length is used as the vocabulary size
        :param embeddings:
            embedding module passed to the encoder
        :param module_klass:
            encoder class to instantiate
        :param null_idx:
            null/pad index into dict
        :param reduction_type:
            reduction type for the encoder
        :return:
            an instance of ``module_klass``, initialized correctly
        """
        opt = copy.deepcopy(opt)
        n_layers = opt.get('n_multiobjective_layers', 2)
        opt['n_heads'] = opt.get('n_multiobjective_heads', 4)
        # Every layer-count option is pinned to the multiobjective value.
        opt['n_layers'] = n_layers
        opt['n_encoder_layers'] = n_layers
        opt['n_decoder_layers'] = n_layers
        return module_klass(
            opt=opt,
            vocabulary_size=len(dict_),
            embedding=embeddings,
            padding_idx=null_idx,
            reduction_type=reduction_type,
            n_segments=opt.get('n_segments', 2),
        )
예제 #21
0
def _create_task_agents(opt: Opt):
    """
    Create task agent(s) for the given task name.

    It does this by calling the create_agent function in agents.py of the given task. If
    create_agents function does not exist, it just looks for the teacher (agent) class
    defined by the task name directly.  (This saves the task creator bothering to define
    the create_agents function when it is not needed.)

    Interactive and self-chat settings get no task agents (an empty list).
    """
    if opt.get('interactive_task', False) or opt.get('selfchat_task', False):
        return []

    try:
        task_module = load_task_module(opt['task'])
        created = task_module.create_agents(opt)  # type: ignore
    except (ModuleNotFoundError, AttributeError):
        # No create_agents found, so create the teacher directly.
        return create_task_agent_from_taskname(opt)
    return created if type(created) == list else [created]
예제 #22
0
 def test_gpt2_bpe_tokenize(self):
     """
     Round-trip a string containing an emoji through the GPT-2 tokenizer,
     using the default datapath from ``ParlaiParser``.
     """
     parsed = ParlaiParser().parse_args([], print_args=False)
     agent = DictionaryAgent(Opt({'dict_tokenizer': 'gpt2', 'datapath': parsed['datapath']}))
     text = u'Hello, ParlAI! \U0001f600'  # grinning face emoji
     self.assertEqual(agent.gpt2_tokenize(text), GPT2_BPE_RESULT)
     token_ids = (agent.tok2ind[tok] for tok in GPT2_BPE_RESULT)
     self.assertEqual(agent.vec2txt(token_ids), text)
예제 #23
0
 def __init__(self, opt: Opt, dictionary: DictionaryAgent, retriever_shared=None):
     """
     Set up the retriever for this model.

     With ``filter_docs_with_label`` set, the retriever type must be the
     search engine and a ``FilterDocsForLabelSearchEngineRetrieverCombo`` is
     built; otherwise ``combo_fid_retriever_factory`` chooses the retriever.
     """
     super().__init__(opt, dictionary, retriever_shared)
     if not opt.get('filter_docs_with_label'):
         self.retriever = combo_fid_retriever_factory(
             opt, dictionary, shared=retriever_shared
         )
     else:
         assert (
             RetrieverType(opt['rag_retriever_type']) == RetrieverType.SEARCH_ENGINE
         )
         self.retriever = FilterDocsForLabelSearchEngineRetrieverCombo(
             opt, dictionary, shared=retriever_shared
         )  # type: ignore
     self.top_docs: List[List[Document]] = []
예제 #24
0
    def __init__(self, opt: Opt, shared=None):
        """
        Set up a dictionary backed by a HuggingFace tokenizer.

        The tokenizer and its token/index maps are built via
        ``self.get_tokenizer`` or reused from ``shared``. Every vocabulary
        token gets a frequency of 1.
        """
        if shared:
            self.hf_tokenizer = shared['hf_tokenizer']
            self.tok2ind = shared['tok2ind']
            self.ind2tok = shared['ind2tok']
        else:
            self.hf_tokenizer = self.get_tokenizer(opt)
            self.tok2ind = self.hf_tokenizer.get_vocab()
            self.ind2tok = {index: token for token, index in self.tok2ind.items()}

        self.freq = defaultdict(int, dict.fromkeys(self.tok2ind, 1))
        self.minfreq = opt.get('dict_minfreq', DictionaryAgent.default_minfreq)

        self._unk_token_idx = self.hf_tokenizer.unk_token_id
        self.override_special_tokens(opt)

        self.lower = opt.get('dict_lower', DictionaryAgent.default_lower)
        self.tokenizer = 'hf'
        self.opt = opt
        # Fall back to the tokenizer's own limit when text_truncate is unset.
        self.max_length = (self.opt.get('text_truncate')
                           or self.hf_tokenizer.model_max_length)
예제 #25
0
    def __init__(self, opt: Opt, shared: TShared = None):
        """
        Set up the teacher's options and episode count.

        Dialogs with more than one entry count as two episodes. When not
        shared, each dialog's personalities are replaced, with probability
        ``category_frac``, by their polarity category.
        """
        self.idx_to_ep = shared['idx_to_ep'] if shared else {}
        self.prepend_personality = opt.get('prepend_personality', True)
        self.include_dialogue_history = opt.get('include_dialogue_history', True)
        self.category_frac = opt.get('category_frac', 0.0)
        super().__init__(opt, shared)
        multi_entry = len([d for d in self.data if len(d['dialog']) > 1])
        self.num_eps = len(self.data) + multi_entry

        if shared:
            return
        # Replace personalities with polarity categories ("positive/neutral" or
        # "negative"), with probability self.category_frac
        category_map = get_category_map(self.personalities)
        for i, episode in enumerate(self.data):
            if random.random() < self.category_frac:
                self.data[i]['dialog'] = [
                    [category_map[personality], label]
                    for personality, label in episode['dialog']
                ]
예제 #26
0
 def init_predictor(self, opt: Opt, shared=None):
     """
     Initializes Predictor Module.

     With ``shared``, the predictor is reused from ``shared['predictor']``.
     Otherwise it is loaded from ``self.predictor_model_file``, unless
     ``opt['predictor_model_file']`` is unset. In that case a warning is
     logged and no predictor is created.
     """
     if shared:
         self.predictor = shared['predictor']
         return
     if not opt.get("predictor_model_file"):
         # Use ``warning``: ``warn`` is only a deprecated alias of it.
         logging.warning(
             'Reranker MUST specify predictor_model_file unless subclass __init__() sets up the model in its own way (unusual). Skipping predictor setup!'
         )
         return
     self.predictor = create_agent_from_model_file(self.predictor_model_file)
예제 #27
0
파일: agents.py 프로젝트: simplecoka/cortx
    def __init__(self, opt: Opt, shared: TShared = None):
        """
        Set up the teacher, reusing ``shared['data']`` when available.

        ``--flatten-delimiter`` must equal ``--delimiter`` (default ``'\\n'``).
        Without shared data, word lists and data are built from ``opt``.
        """
        assert opt['flatten_delimiter'] == opt.get(
            'delimiter', '\n'
        ), '--flatten-delimiter and --delimiter are set differently, please inspect and set to the same to avoid unexpected results'
        self.opt = opt

        has_shared_data = bool(shared) and 'data' in shared
        if not has_shared_data:
            self.word_lists = self.build_wordlists(opt)
            self.data = self._setup_data(opt)
        else:
            self.data = shared['data']

        super().__init__(opt, shared)
        self.reset()
예제 #28
0
def combo_fid_retriever_factory(opt: Opt,
                                dictionary: DictionaryAgent,
                                shared=None) -> Optional[RagRetriever]:
    """
    Bypass call to standard retriever factory to possibly build our own retriever.

    Returns ``None`` while converting. The search-engine retriever type gets
    a ``ComboFidSearchQuerySearchEngineRetriever``; every other type is
    delegated to ``retriever_factory``.
    """
    if opt.get('converting'):
        return None
    if RetrieverType(opt['rag_retriever_type']) is not RetrieverType.SEARCH_ENGINE:
        return retriever_factory(opt, dictionary, shared)
    return ComboFidSearchQuerySearchEngineRetriever(
        opt, dictionary, shared=shared)  # type: ignore
예제 #29
0
 def init_predictor(self, opt: Opt, shared=None):
     """
     Load the predictor agent, or reuse ``shared['predictor']``.

     A fresh predictor is created from ``self.predictor_model_file`` with
     overrides that request candidate scores and inline interactive
     candidates, in ``valid`` mode.
     """
     if shared:
         self.predictor = shared['predictor']
         return
     # 'valid' datatype so the predictor does not init an optimizer.
     predictor_overrides = {
         'return_cand_scores': True,
         'datatype': 'valid',
         'interactive_mode': opt.get('interactive_mode', True),
         'ignore_bad_candidates': True,
         'encode_candidate_vecs': True,
         'interactive_candidates': 'inline',
     }
     self.predictor = create_agent_from_model_file(
         self.predictor_model_file, opt_overrides=predictor_overrides)
예제 #30
0
    def update_state_dict(opt: Opt, state_dict: Dict[str, torch.Tensor],
                          model: torch.nn.Module):
        """
        Update the given state dict to be RAG-ified.

        ``state_dict`` is modified in place and also returned.

        :param opt:
            options
        :param state_dict:
            weights to load
        :param model:
            underlying model that will load the state_dict

        :return updated_state_dict:
            return state_dict with appropriate keys/values
        """
        # 1. Rename "encoder*"/"decoder*" keys to "seq2seq_encoder*"/"seq2seq_decoder*",
        #    unless some key already carries the "seq2seq" prefix.
        if not any(name.startswith('seq2seq') for name in state_dict):
            for name in list(state_dict):
                if name.startswith(('encoder', 'decoder')):
                    state_dict[f'seq2seq_{name}'] = state_dict.pop(name)

        # 2. Add the model's own retriever weights when the dict has none.
        if not any('retriever' in name for name in state_dict):
            state_dict.update({
                f"retriever.{name}": tensor
                for name, tensor in model.retriever.state_dict().items()  # type: ignore
            })

        # 3. Handle n_positional difference: append the last n_extra_positions
        #    rows of the model's encoder position embeddings.
        n_extra = opt.get('n_extra_positions', 0)
        if n_extra > 0:
            key = 'seq2seq_encoder.position_embeddings.weight'
            init_weight = model.seq2seq_encoder.position_embeddings.weight  # type: ignore
            # Make sure we're not adding more positions to a model trained
            # with extra positions.
            if state_dict[key].size(0) < opt['n_positions'] + opt['n_extra_positions']:
                state_dict[key] = torch.cat(
                    [
                        state_dict[key].to(init_weight),  # type: ignore
                        init_weight[-opt['n_extra_positions']:, :],  # type: ignore
                    ],
                    dim=0,
                )
        return state_dict