Beispiel #1
0
    def _recover_from_desc(self):
        """Rebuild the module's runtime attributes from ``self.desc``.

        Restores, in this order:
          * one signature per entry of ``self.desc.sign2var`` into
            ``self.signatures``;
          * ``self.default_signature`` (None when no default name is stored);
          * the module info fields name, author, author_email, version, type
            and summary;
          * ``self.extra_info`` as a plain dict;
          * ``self.name_prefix``.
        Stored attribute values are decoded with
        ``utils.from_module_attr_to_pyobj``.
        """

        def _resolve(var_descs):
            # Turn serialized var descriptions into (program variables, aliases).
            pairs = [(self.program.global_block().vars[desc.var_name],
                      desc.alias) for desc in var_descs]
            return [pair[0] for pair in pairs], [pair[1] for pair in pairs]

        # Signatures: inputs come from feed_desc, outputs from fetch_desc.
        for sign_name, module_var in self.desc.sign2var.items():
            in_vars, in_aliases = _resolve(module_var.feed_desc)
            out_vars, out_aliases = _resolve(module_var.fetch_desc)
            self.signatures[sign_name] = create_signature(
                sign_name,
                inputs=in_vars,
                outputs=out_vars,
                feed_names=in_aliases,
                fetch_names=out_aliases)

        to_pyobj = utils.from_module_attr_to_pyobj
        attr_data = self.desc.attr.map.data

        # Default signature: an empty/falsy stored name means "no default".
        default_name = to_pyobj(attr_data['default_signature'])
        if default_name:
            self.default_signature = self.signatures[default_name]
        else:
            self.default_signature = None

        # Module info: each field is stored under the same key as the attribute.
        info_data = attr_data['module_info'].map.data
        for field in ('name', 'author', 'author_email', 'version', 'type',
                      'summary'):
            setattr(self, field, to_pyobj(info_data[field]))

        # Extra info: decode every entry into a fresh dict.
        extra_attr = attr_data['extra_info']
        self.extra_info = {}
        self.extra_info.update((key, to_pyobj(value))
                               for key, value in extra_attr.map.data.items())

        # Name prefix.
        self.name_prefix = to_pyobj(attr_data["name_prefix"])
Beispiel #2
0
    def context(self,
                sign_name=None,
                for_test=False,
                trainable=True,
                regularizer=None,
                max_seq_len=128,
                learning_rate=1e-3):
        """
        Build a cloned program plus feed/fetch mappings for this module.

        Args:
            sign_name(str|None): name of the signature to expose. When falsy,
                the inputs/outputs/feed names/fetch names of all signatures
                are concatenated into a temporary signature named
                "hub_temp_signature".
            for_test(bool): forwarded to ``program.clone`` and, as
                ``is_test``, to ``paddle_helper.set_op_attr``. When False the
                parameters are also configured (trainable, learning rate,
                regularizer) and restored via ``self._restore_parameter``.
            trainable(bool): forwarded to
                ``paddle_helper.set_parameter_trainable`` (only if not
                for_test).
            regularizer: forwarded to
                ``paddle_helper.set_parameter_regularizer`` (only if not
                for_test).
            max_seq_len(int): maximum sequence length, this option is only
                available for BERT/ERNIE module; must lie in [1, 512].
            learning_rate(float): forwarded to
                ``paddle_helper.set_parameter_learning_rate`` (only if not
                for_test).

        Returns:
            tuple: ``(feed_dict, fetch_dict, program)``. Each dict maps the
            positional index of every signature input/output to the matching
            variable of the cloned ``program``. It also maps the feed/fetch
            name to that variable when the name is non-empty.

        Raises:
            KeyError: if ``sign_name`` is given but is not a known signature.
            ValueError: if the module name starts with "bert" or "ernie" and
                ``max_seq_len`` is outside [1, 512].
        """

        if sign_name:
            if sign_name not in self.signatures:
                raise KeyError(
                    "Module did not have a signature with name %s" % sign_name)
            signature = self.signatures[sign_name]
        else:
            # No signature requested: expose the union of all signatures.
            signatures = list(self.signatures.values())
            signature = create_signature(
                name="hub_temp_signature",
                inputs=[i for sig in signatures for i in sig.inputs],
                outputs=[o for sig in signatures for o in sig.outputs],
                feed_names=[n for sig in signatures for n in sig.feed_names],
                fetch_names=[n for sig in signatures for n in sig.fetch_names],
                for_predict=False)

        program = self.program.clone(for_test=for_test)
        paddle_helper.remove_feed_fetch_op(program)

        if not for_test:
            paddle_helper.set_parameter_trainable(program, trainable)
            paddle_helper.set_parameter_learning_rate(program, learning_rate)
            paddle_helper.set_parameter_regularizer(program, regularizer)
            self._restore_parameter(program)

        self._recover_variable_info(program)

        paddle_helper.set_op_attr(program, is_test=for_test)
        #TODO(wuzewu): return feed_list and fetch_list directly
        feed_dict = {}
        fetch_dict = {}
        for index, var in enumerate(signature.inputs):
            program_var = program.global_block().var(var.name)
            feed_dict[index] = program_var
            key = signature.feed_names[index]
            if key:
                feed_dict[key] = program_var

        for index, var in enumerate(signature.outputs):
            program_var = program.global_block().var(var.name)
            fetch_dict[index] = program_var
            key = signature.fetch_names[index]
            if key:
                fetch_dict[key] = program_var

        # TODO(ZeyuChen) encapsulate into a funtion
        # update BERT/ERNIE's input tensor's sequence length to max_seq_len
        if self.name.startswith(("bert", "ernie")):
            MAX_SEQ_LENGTH = 512
            if max_seq_len > MAX_SEQ_LENGTH or max_seq_len <= 0:
                # Both placeholders must be filled; a single argument made
                # str.format raise IndexError instead of this ValueError.
                raise ValueError(
                    "max_seq_len({}) should be in the range of [1, {}]".format(
                        max_seq_len, MAX_SEQ_LENGTH))
            logger.info(
                "Set maximum sequence length of input tensor to {}".format(
                    max_seq_len))
            if self.name.startswith("ernie_v2"):
                feed_list = [
                    "input_ids", "position_ids", "segment_ids", "input_mask",
                    "task_ids"
                ]
                logger.warning(
                    "%s will exploite task_id, the arguement use_taskid of Reader class must be True."
                    % self.name)
            else:
                feed_list = [
                    "input_ids", "position_ids", "segment_ids", "input_mask"
                ]
                logger.warning(
                    "%s has no task_id, the arguement use_taskid of Reader class must be False."
                    % self.name)
            # The same shape applies to every sequence input tensor.
            seq_tensor_shape = [-1, max_seq_len, 1]
            for tensor_name in feed_list:
                logger.info("The shape of input tensor[{}] set to {}".format(
                    tensor_name, seq_tensor_shape))
                program.global_block().var(
                    feed_dict[tensor_name].name).desc.set_shape(
                        seq_tensor_shape)

        # record num parameters loaded by paddlehub
        num_param_loaded = sum(
            1 for _ in program.global_block().iter_parameters())
        logger.info(
            "%d pretrained paramaters loaded by PaddleHub" % num_param_loaded)

        return feed_dict, fetch_dict, program