def _recover_from_desc(self):
    """Rebuild this module's in-memory state from ``self.desc``.

    Restores, in this order: every signature listed in
    ``self.desc.sign2var`` (stored into ``self.signatures``), the default
    signature, the module info fields (name, author, author_email,
    version, type, summary), the ``extra_info`` dict and the name prefix.
    """

    def _collect(entries):
        # One pass per entry list: look up the program variable by its
        # recorded name and keep the alias at the same position.
        variables, aliases = [], []
        for entry in entries:
            variables.append(
                self.program.global_block().vars[entry.var_name])
            aliases.append(entry.alias)
        return variables, aliases

    # recover signature
    for sign_key, sign_desc in self.desc.sign2var.items():
        feed_vars, feed_aliases = _collect(sign_desc.feed_desc)
        fetch_vars, fetch_aliases = _collect(sign_desc.fetch_desc)
        self.signatures[sign_key] = create_signature(
            sign_key,
            inputs=feed_vars,
            outputs=fetch_vars,
            feed_names=feed_aliases,
            fetch_names=fetch_aliases)

    to_pyobj = utils.from_module_attr_to_pyobj

    # recover default signature
    default_name = to_pyobj(self.desc.attr.map.data['default_signature'])
    if default_name:
        self.default_signature = self.signatures[default_name]
    else:
        self.default_signature = None

    # recover module info
    info_data = self.desc.attr.map.data['module_info'].map.data
    for field in ('name', 'author', 'author_email', 'version', 'type',
                  'summary'):
        setattr(self, field, to_pyobj(info_data[field]))

    # recover extra info
    extra_desc = self.desc.attr.map.data['extra_info']
    self.extra_info = {}
    for extra_key, extra_value in extra_desc.map.data.items():
        self.extra_info[extra_key] = to_pyobj(extra_value)

    # recover name prefix
    self.name_prefix = to_pyobj(self.desc.attr.map.data["name_prefix"])
def context(self,
            sign_name=None,
            for_test=False,
            trainable=True,
            regularizer=None,
            max_seq_len=128,
            learning_rate=1e-3):
    """Clone the module program and build its feed/fetch lookup tables.

    Args:
        sign_name(str): name of the signature to expose. When empty or
            None, a temporary signature named "hub_temp_signature" is
            created that concatenates the inputs, outputs, feed names and
            fetch names of every signature in ``self.signatures``.
        for_test(bool): passed to ``program.clone(for_test=...)`` and to
            ``paddle_helper.set_op_attr(is_test=...)``. The trainable,
            learning-rate and regularizer settings below are applied only
            when this is False.
        trainable(bool): forwarded to
            ``paddle_helper.set_parameter_trainable``.
        regularizer: forwarded to
            ``paddle_helper.set_parameter_regularizer``.
        max_seq_len(int): maximum sequence length, this option is only
            available for BERT/ERNIE module (a module whose name starts
            with "bert" or "ernie"); must be in the range [1, 512].
        learning_rate(float): forwarded to
            ``paddle_helper.set_parameter_learning_rate``.

    Returns:
        tuple: ``(feed_dict, fetch_dict, program)``. Each dict maps both
        the positional index and, when non-empty, the feed/fetch alias of
        a signature variable to the matching variable of the cloned
        program.

    Raises:
        KeyError: if ``sign_name`` is given but is not a known signature,
            or if a BERT/ERNIE module lacks one of the expected feed
            aliases.
        ValueError: if the module is BERT/ERNIE and ``max_seq_len`` is
            outside [1, 512].
    """
    if sign_name:
        if sign_name not in self.signatures:
            raise KeyError(
                "Module did not have a signature with name %s" % sign_name)
        signature = self.signatures[sign_name]
    else:
        inputs = [
            in_var for sign in self.signatures.values()
            for in_var in sign.inputs
        ]
        outputs = [
            out_var for sign in self.signatures.values()
            for out_var in sign.outputs
        ]
        feed_names = [
            feed_name for sign in self.signatures.values()
            for feed_name in sign.feed_names
        ]
        fetch_names = [
            fetch_name for sign in self.signatures.values()
            for fetch_name in sign.fetch_names
        ]
        signature = create_signature(
            name="hub_temp_signature",
            inputs=inputs,
            outputs=outputs,
            feed_names=feed_names,
            fetch_names=fetch_names,
            for_predict=False)

    # Validate max_seq_len before cloning the program and restoring
    # parameters, so that invalid input fails without doing that work.
    MAX_SEQ_LENGTH = 512
    is_bert_like = self.name.startswith(("bert", "ernie"))
    if is_bert_like and (max_seq_len > MAX_SEQ_LENGTH or max_seq_len <= 0):
        raise ValueError(
            "max_seq_len({}) should be in the range of [1, {}]".format(
                max_seq_len, MAX_SEQ_LENGTH))

    program = self.program.clone(for_test=for_test)
    paddle_helper.remove_feed_fetch_op(program)

    if not for_test:
        paddle_helper.set_parameter_trainable(program, trainable)
        paddle_helper.set_parameter_learning_rate(program, learning_rate)
        paddle_helper.set_parameter_regularizer(program, regularizer)

    self._restore_parameter(program)
    self._recover_variable_info(program)

    paddle_helper.set_op_attr(program, is_test=for_test)

    block = program.global_block()

    # TODO(wuzewu): return feed_list and fetch_list directly
    # Every variable is registered under its position and, when it has a
    # non-empty alias, under that alias as well.
    feed_dict = {}
    fetch_dict = {}
    for index, var in enumerate(signature.inputs):
        cloned_var = block.var(var.name)
        feed_dict[index] = cloned_var
        key = signature.feed_names[index]
        if key:
            feed_dict[key] = cloned_var

    for index, var in enumerate(signature.outputs):
        cloned_var = block.var(var.name)
        fetch_dict[index] = cloned_var
        key = signature.fetch_names[index]
        if key:
            fetch_dict[key] = cloned_var

    # TODO(ZeyuChen) encapsulate into a function
    # update BERT/ERNIE's input tensor's sequence length to max_seq_len
    if is_bert_like:
        logger.info(
            "Set maximum sequence length of input tensor to {}".format(
                max_seq_len))
        if self.name.startswith("ernie_v2"):
            feed_list = [
                "input_ids", "position_ids", "segment_ids", "input_mask",
                "task_ids"
            ]
            logger.warning(
                "%s will exploite task_id, the arguement use_taskid of Reader class must be True."
                % self.name)
        else:
            feed_list = [
                "input_ids", "position_ids", "segment_ids", "input_mask"
            ]
            logger.warning(
                "%s has no task_id, the arguement use_taskid of Reader class must be False."
                % self.name)
        for tensor_name in feed_list:
            seq_tensor_shape = [-1, max_seq_len, 1]
            logger.info("The shape of input tensor[{}] set to {}".format(
                tensor_name, seq_tensor_shape))
            block.var(feed_dict[tensor_name].name).desc.set_shape(
                seq_tensor_shape)

    # record num parameters loaded by paddlehub
    num_param_loaded = sum(1 for _ in block.iter_parameters())
    logger.info(
        "%d pretrained paramaters loaded by PaddleHub" % num_param_loaded)

    return feed_dict, fetch_dict, program