예제 #1
0
 def _deal_with_tensor(value):
     """
     Assemble an OpTensor from a received tensor message.

     Args:
         value (dict): Received message fields. Keys read here:

             - tensor_proto (TensorProto): The tensor proto. Its
               `tensor_content` field is cleared in place.
             - step (int): The step of the tensor (defaults to 0).
             - tensor_contents (list[bytes]): Content chunks, concatenated
               in the order given.
             - oversize: When truthy, the built tensor has
               `clean_tensor_value(oversize=True)` applied to it.

     Returns:
         OpTensor, the assembled tensor.
     """
     proto = value.get('tensor_proto')
     # NOTE(review): the payload is presumably delivered separately via
     # `tensor_contents`, so the embedded field is dropped — confirm.
     proto.ClearField('tensor_content')
     cur_step = value.get('step', 0)
     # When the proto's `iter` flag is set and the step is positive, the
     # tensor is attributed to the preceding step.
     if proto.iter and cur_step > 0:
         log.debug("Received previous tensor.")
         cur_step = cur_step - 1
     content_bytes = b''.join(value.get('tensor_contents'))
     result = OpTensor(proto, content_bytes, cur_step)
     if value.get('oversize'):
         result.clean_tensor_value(oversize=True)
     return result
예제 #2
0
    def put(self, value):
        """
        Store a tensor received from the grpc server in the tensor cache.

        Args:
            value (dict): Received message fields. Keys read here:

                - step (int): The current step of tensor (defaults to 0).
                - tensor_proto (TensorProto): The tensor proto. Its
                  `tensor_content` field is cleared in place.
                - tensor_contents (list[bytes]): Content chunks, concatenated
                  in the order given.

        Returns:
            bool, the flag from `_put_tensor_into_cache`, i.e. whether the
            tensor was updated successfully.
        """
        proto = value.get('tensor_proto')
        proto.ClearField('tensor_content')
        cur_step = value.get('step', 0)
        # When the proto's `iter` flag is set and the step is positive, the
        # tensor is attributed to the preceding step.
        if proto.iter and cur_step > 0:
            log.debug("Received previous tensor.")
            cur_step = cur_step - 1
        tensor = OpTensor(proto, b''.join(value.get('tensor_contents')), cur_step)
        updated = self._put_tensor_into_cache(tensor, cur_step)
        log.info("Put tensor %s of step: %d, into cache. Flag: %s", tensor.name, cur_step, updated)
        return updated
예제 #3
0
    def put_const_vals(self, const_vals):
        """
        Register constant values in the tensor cache.

        Entries missing either a value or a key are skipped. For values whose
        dtype name is "DT_TENSOR", the embedded tensor proto is modified in
        place: its node name is set to the entry's key and its slot to '0'.
        The proto is then wrapped in an OpTensor. All other values are
        wrapped in a ConstTensor. Each result is stored in `self._const_vals`
        under the tensor's `name`, replacing any earlier entry with that name.

        Args:
            const_vals (list[NamedValueProto]): List of const values.
        """
        for item in const_vals:
            if not item.value or not item.key:
                continue
            if DataType.Name(item.value.dtype) != "DT_TENSOR":
                wrapped = ConstTensor(item)
            else:
                proto = item.value.tensor_val
                proto.node_name = item.key
                proto.slot = '0'
                wrapped = OpTensor(proto)
            self._const_vals[wrapped.name] = wrapped
예제 #4
0
    def put(self, value):
        """
        Store a tensor received from the grpc server in the tensor cache.

        The incoming tensor protos are combined with `_get_merged_tensor`
        before being wrapped in an OpTensor.

        Args:
            value (dict): Received message fields. Keys read here:

                - step (int): The current step of tensor (defaults to 0).

                - tensor_protos (list[TensorProto]): The tensor protos to merge.
        """
        merged = self._get_merged_tensor(value.get('tensor_protos'))
        cur_step = value.get('step', 0)
        # When the merged proto's `iter` flag is set and the step is positive,
        # the tensor is attributed to the preceding step.
        if merged.iter and cur_step > 0:
            log.debug("Received previous tensor.")
            cur_step = cur_step - 1
        tensor = OpTensor(merged, cur_step)
        # NOTE(review): whatever `_put_tensor_into_cache` returns is discarded
        # here; this method always returns None.
        self._put_tensor_into_cache(tensor, cur_step)
        log.info("Put tensor %s of step: %d, into cache", tensor.name, cur_step)