Exemplo n.º 1
0
        def _FromParamValue(param_pb: hyperparams_pb2.HyperparamValue) -> Any:
            """Converts a HyperparamValue proto back into a Python value.

            Args:
              param_pb: The HyperparamValue message to decode.

            Returns:
              The decoded value. Container kinds are decoded recursively.
              None is returned when no `kind` field is set, or when a
              named-tuple / enum / proto payload names a class of the wrong
              kind.

            Raises:
              TypeError: for `string_repr_val` payloads, or when
                `symbolic_val` does not unpickle into a symbolic expression.
            """
            kind = param_pb.WhichOneof('kind')
            if not kind:
                return None

            if kind == 'param_val':
                return _FromParam(param_pb.param_val)

            if kind == 'list_val':
                return list(map(_FromParamValue, param_pb.list_val.items))

            if kind == 'named_tuple_val':
                target_cls = _LoadClass(param_pb.named_tuple_val.type)
                # Only dataclasses and tuple subclasses can be rebuilt.
                if (dataclasses.is_dataclass(target_cls) or
                        issubclass(target_cls, tuple)):
                    field_values = [
                        _FromParamValue(item)
                        for item in param_pb.named_tuple_val.items
                    ]
                    return target_cls(*field_values)
                return None

            if kind == 'tuple_val':
                return tuple(
                    _FromParamValue(item) for item in param_pb.tuple_val.items)

            if kind == 'dict_val':
                entries = param_pb.dict_val.items
                return {name: _FromParamValue(entries[name]) for name in entries}

            if kind == 'type_val':
                tokens = param_pb.type_val.split('/')
                assert len(tokens) == 2
                module_name, attr_name = tokens
                return getattr(importlib.import_module(module_name), attr_name)

            if kind == 'dtype_val':
                return tf.as_dtype(param_pb.dtype_val)

            if kind == 'enum_val':
                enum_cls = _LoadClass(param_pb.enum_val.type)
                if issubclass(enum_cls, enum.Enum):
                    return enum_cls[param_pb.enum_val.name]
                return None

            if kind == 'proto_val':
                proto_cls = _LoadClass(param_pb.proto_val.type)
                if not issubclass(proto_cls, message.Message):
                    return None
                decoded = proto_cls()
                decoded.ParseFromString(param_pb.proto_val.val)
                return decoded

            if kind == 'symbolic_val':
                # NOTE(review): pickle.loads can execute arbitrary code; this
                # is only safe if the serialized params come from a trusted
                # source.
                expr = pickle.loads(param_pb.symbolic_val)
                if symbolic.IsExpr(expr):
                    return expr
                raise TypeError(
                    'Unexpected result when deserializing symbolic expr.')

            if kind == 'string_repr_val':
                raise TypeError(
                    'Cannot deserialize string_repr_val instance: %s' %
                    param_pb.string_repr_val)

            # Any other kind is read directly off the proto (presumably the
            # scalar string/int/float/bool fields).
            return getattr(param_pb, kind)
Exemplo n.º 2
0
    def _FromParamValue(param_pb):
      """Deserializes HyperparamValue proto."""

      which_oneof = param_pb.WhichOneof('kind')
      if which_oneof == 'param_val':
        return _FromParam(param_pb.param_val)
      elif which_oneof == 'list_val':
        return [_FromParamValue(val) for val in param_pb.list_val.items]
      elif which_oneof == 'tuple_val':
        return tuple([_FromParamValue(val) for val in param_pb.tuple_val.items])
      elif which_oneof == 'dict_val':
        dict_val = dict()
        for k in param_pb.dict_val.items:
          dict_val[k] = _FromParamValue(param_pb.dict_val.items[k])
        return dict_val
      elif which_oneof == 'type_val':
        tokens = param_pb.type_val.split('/')
        assert len(tokens) == 2
        return getattr(importlib.import_module(tokens[0]), tokens[1])
      elif which_oneof == 'dtype_val':
        return tf.as_dtype(param_pb.dtype_val)
      elif which_oneof == 'string_val':
        return param_pb.string_val
      elif which_oneof == 'int_val':
        return param_pb.int_val
      elif which_oneof == 'float_val':
        return param_pb.float_val
      elif which_oneof == 'bool_val':
        return param_pb.bool_val
      else:
        # If nothing is set, it's the None type.
        return None
Exemplo n.º 3
0
        def _FromParamValue(param_pb):
            """Deserializes HyperparamValue proto."""

            which_oneof = param_pb.WhichOneof('kind')
            if which_oneof == 'param_val':
                return _FromParam(param_pb.param_val)
            elif which_oneof == 'list_val':
                return [
                    _FromParamValue(val) for val in param_pb.list_val.items
                ]
            elif which_oneof == 'named_tuple_val':
                named_tuple_cls = _LoadClass(param_pb.named_tuple_val.type)
                if not dataclasses.is_dataclass(
                        named_tuple_cls) and not issubclass(
                            named_tuple_cls, tuple):
                    return None
                return named_tuple_cls(*[
                    _FromParamValue(val)
                    for val in param_pb.named_tuple_val.items
                ])
            elif which_oneof == 'tuple_val':
                return tuple(
                    [_FromParamValue(val) for val in param_pb.tuple_val.items])
            elif which_oneof == 'dict_val':
                dict_val = dict()
                for k in param_pb.dict_val.items:
                    dict_val[k] = _FromParamValue(param_pb.dict_val.items[k])
                return dict_val
            elif which_oneof == 'type_val':
                tokens = param_pb.type_val.split('/')
                assert len(tokens) == 2
                return getattr(importlib.import_module(tokens[0]), tokens[1])
            elif which_oneof == 'dtype_val':
                return tf.as_dtype(param_pb.dtype_val)
            elif which_oneof == 'string_val':
                return param_pb.string_val
            elif which_oneof == 'int_val':
                return param_pb.int_val
            elif which_oneof == 'float_val':
                return param_pb.float_val
            elif which_oneof == 'bool_val':
                return param_pb.bool_val
            elif which_oneof == 'enum_val':
                enum_cls = _LoadClass(param_pb.enum_val.type)
                if not issubclass(enum_cls, enum.Enum):
                    return None
                return enum_cls[param_pb.enum_val.name]
            elif which_oneof == 'proto_val':
                proto_cls = _LoadClass(param_pb.proto_val.type)
                if not issubclass(proto_cls, message.Message):
                    return None
                proto_msg = proto_cls()
                proto_msg.ParseFromString(param_pb.proto_val.val)
                return proto_msg
            else:
                # If nothing is set, it's the None type.
                return None
Exemplo n.º 4
0
    def _config_infeed(self,
                       num_partitions,
                       device_assignment,
                       batch_size,
                       key_size=2,
                       return_tgt_mask=False,
                       use_partitioned_infeed_queue=False):
        """Builds the infeed queue, its placeholder args and its enqueue op.

        Side effects: sets `self.infeed_args` to one placeholder per element
        of the zero batch and `self.infeed_op` to the enqueue op(s) fed by
        those placeholders.

        Args:
          num_partitions: Number of partitions applied to the leading
            dimension of every tensor; only consumed by the partitioned queue.
          device_assignment: TPU device assignment; used to find the host task
            and to place the enqueue ops.
          batch_size: Batch size of the zero batch that defines the infeed
            signature.
          key_size: Forwarded to `get_zero_batch`.
          return_tgt_mask: Forwarded to `get_zero_batch`.
          use_partitioned_infeed_queue: If True, use
            `tpu_feed._PartitionedInfeedQueue` instead of
            `tpu_feed.InfeedQueue`.

        Returns:
          The infeed queue object.
        """
        sample_batch = get_zero_batch(batch_size=batch_size,
                                      max_len=self._prefix_max_len,
                                      key_size=key_size,
                                      return_tgt_mask=return_tgt_mask)
        num_elements = len(sample_batch)

        # Device strings look like '.../task:<id>/device:...'; pull out <id>.
        host_device = device_assignment.host_device(replica=0, job=self._tpu)
        task_suffix = host_device.split('/task:')[1]
        host_id = int(task_suffix.split('/device:')[0])

        # Only the leading dimension of each tensor is split across partitions.
        partition_dims = [
            [num_partitions] + [1] * (len(t.shape) - 1) for t in sample_batch
        ]

        if use_partitioned_infeed_queue:
            infeed = tpu_feed._PartitionedInfeedQueue(  # pylint: disable=protected-access
                number_of_tuple_elements=num_elements,
                host_id=host_id,
                input_partition_dims=partition_dims,
                device_assignment=device_assignment)
        else:
            infeed = tpu_feed.InfeedQueue(number_of_tuple_elements=num_elements)

        # One placeholder per zero-batch tensor, mirroring its dtype and shape.
        self.infeed_args = []
        for t in sample_batch:
            self.infeed_args.append(
                tf.placeholder(tf.as_dtype(t.dtype), shape=t.shape))

        if use_partitioned_infeed_queue:
            self.infeed_op = infeed.generate_enqueue_ops([self.infeed_args])
        else:
            self.infeed_op = infeed.split_inputs_and_generate_enqueue_ops(
                self.infeed_args, device_assignment=device_assignment)
        return infeed
Exemplo n.º 5
0
    def FromText(self, text, type_overrides=None):
        """Merges params specified in 'text' into 'params'.

        'text' follows the simple text format as produced by
        ParamsToSimpleText.  For a param specified in both 'params' and
        'text', overwrites the value in 'params' according to 'text'.
        Every key found in 'text' is looked up with self.Get(); NOTE(review):
        that presumably raises for keys that are not in 'params' (see
        AttributeError below) rather than ignoring them -- confirm.

        Each raw value is converted according to the type name of the param's
        current value, or the name given in `type_overrides`: bool, int,
        float, DType, list/tuple/dict (via ast.literal_eval), str (unquoted;
        a bracketed string is tried as a list literal first), enums
        ('<EnumClass>.<MEMBER>'), and types / protos / None
        ('NoneType', true/false, 'type/<module>/<name>' or
        'proto/<module>/<name>/<text proto>').

        Args:
          text: A text representation of params.
          type_overrides: Optional dict mapping a param key to the type name
            (e.g. 'int', 'str', 'list') used to parse its value instead of
            the type name of the param's current value.

        Raises:
          TypeError: this Params instance is immutable.
          AttributeError: text contains invalid parameter key
          ValueError: text contains invalid parameter value
        """
        if self._immutable:
            raise TypeError('This Params instance is immutable.')
        type_overrides = type_overrides or {}

        # Pass 1: split 'text' into raw key -> value strings. A value that
        # starts with a quote and does not end with the same quote on that
        # line is a multi-line string, accumulated until the closing quote.
        kv = {}
        string_continue = None  # None or (key, quote, value)
        for line in text.split('\n'):
            # Continuing a multi-line string.
            if string_continue:
                value_stripped = line.rstrip()
                if not _EndsWithTerminalQuote(value_stripped,
                                              string_continue[1]):
                    # String continues
                    string_continue = (string_continue[0], string_continue[1],
                                       string_continue[2] + '\n' + line)
                    continue
                # String terminates.
                kv[string_continue[0]] = (string_continue[2] + '\n' +
                                          value_stripped)
                string_continue = None
                continue

            # Regular line.
            line = line.strip()
            if not line or line[0] == '#':
                # empty line or comment
                continue
            # Lines without a ':' separator are ignored.
            pair = line.split(':', 1)
            if len(pair) == 2:
                key = pair[0].strip()
                value = pair[1].lstrip()
                value_stripped = value.rstrip()
                # Detect single vs multi-line string start.
                if value and value[0] in ['"', '\'']:
                    quote_char = value[0]
                    if not _EndsWithTerminalQuote(value[1:], quote_char):
                        # Multi-line string.
                        string_continue = (key, quote_char, value)
                        continue
                kv[key] = value_stripped
        # NOTE(review): if 'text' ends inside an unterminated multi-line
        # string, that key is silently dropped.

        # Pass 2: convert every raw string into a typed value and set it.
        for key, val in six.iteritems(kv):
            old_val = self.Get(key)
            val_type = type(old_val).__name__
            if isinstance(old_val, (six.string_types, six.text_type)):
                val_type = 'str'
            if key in type_overrides:
                val_type = type_overrides[key]
            # Converts val (a string) to a best-guessed typed value.
            if val_type == 'bool':
                # bool() so that an empty value yields False rather than ''.
                val = bool(val) and val not in ('False', 'false')
            elif val_type == 'int':
                val = int(val)
            elif val_type == 'float':
                val = float(val)
            elif val_type == 'DType':
                val = tf.as_dtype(val)
            elif val_type in ['list', 'tuple']:
                val = ast.literal_eval(val)
            elif val_type == 'dict':
                val = ast.literal_eval(val) if val != 'dict' else {}
            elif val_type == 'str':
                val = _UnquoteString(val)
                if val.startswith('[') and val.endswith(']'):
                    # We may have stored a list as a string; try converting it
                    # back. literal_eval raises ValueError, SyntaxError or
                    # TypeError on malformed input (e.g. '[a b]'); in that
                    # case keep the string as is.
                    try:
                        val = ast.literal_eval(val)
                    except (ValueError, SyntaxError, TypeError):
                        pass
            elif isinstance(old_val, enum.Enum):
                cls, _, name = val.rpartition('.')
                if val_type != cls:
                    raise ValueError('Expected enum of class %s but got %s' %
                                     (val_type, cls))
                val = type(old_val)[name]
            elif (isinstance(old_val, type)
                  or isinstance(old_val, message.Message) or old_val is None):
                if val == 'NoneType':
                    val = None
                elif old_val is None and val in ('False', 'false'):
                    val = False
                elif old_val is None and val in ('True', 'true'):
                    val = True
                else:
                    # Expected forms: 'type/<module>/<name>' or
                    # 'proto/<module>/<name>/<text proto>'. Any other prefix
                    # leaves val as the raw string.
                    try:
                        kind_prefix, pkg, cls = val.split('/', 2)
                        if kind_prefix == 'type':
                            val = getattr(sys.modules[pkg], cls)
                        elif kind_prefix == 'proto':
                            cls, proto_str = cls.split('/', 1)
                            proto_cls = getattr(sys.modules[pkg], cls)
                            if not issubclass(proto_cls, message.Message):
                                raise ValueError('%s is not a proto class.' %
                                                 proto_cls)
                            val = text_format.Parse(proto_str, proto_cls())
                    except ValueError as e:
                        raise ValueError('Error processing %r : %r with %r' %
                                         (key, val, e))
            else:
                raise ValueError('Failed to read a parameter: %r : %r' %
                                 (key, val))
            self.Set(**{key: val})
Exemplo n.º 6
0
    def _OutfeedDequeueLoop(self, per_example_tensors, num_loops, num_devices):
        """Dequeues all per-example tensor outfeed data for a TPU sess.run.

        Args:
          per_example_tensors: dict of key -> tensor as generated by
            TpuTrainStep.
          num_loops: number of times that TpuTrainStep will be executed by
            TpuTrain.
          num_devices: number of TPU cores assigned to this process.

        Returns:
          `tf.no_op()` if `per_example_tensors` is empty. Otherwise a dict
          mapping each key of `per_example_tensors` to the concatenation
          (`TensorArray.concat`) of every value dequeued for it across all
          loop iterations and all (replica, core) outfeeds.
        """
        if not per_example_tensors:
            return tf.no_op()

        keys = sorted(per_example_tensors)
        tensor_shapes = [py_utils.GetShape(per_example_tensors[k]) for k in keys]
        tensor_types = [tf.as_dtype(per_example_tensors[k].dtype) for k in keys]

        def LoopBody(i, *input_arrays):
            """Dequeues one TpuTrainStep's outfeed into the tensor arrays.

            Args:
              i: current loop index.
              *input_arrays: One tf.TensorArray per outfeed tensor.

            Returns:
              i+1 (new index) plus post-write tf.TensorArray handles.
            """
            # Outfeed ops execute on each JF node, so they must be located on
            # the nodes.
            dequeued = []  # One tuple of tensors per (replica, core).
            device_assignment = py_utils.GetTpuDeviceAssignment()
            assert device_assignment
            for replica in range(device_assignment.num_replicas):
                for core in range(device_assignment.num_cores_per_replica):
                    host = device_assignment.host_device(replica, core)
                    with tf.device(host):
                        dequeued.append(
                            tpu_ops.outfeed_dequeue_tuple(
                                tensor_types,
                                tensor_shapes,
                                device_ordinal=device_assignment.tpu_ordinal(
                                    replica, core)))
            base = i * num_devices
            arrays = list(input_arrays)
            # Array j collects per-example tensor j; write that tensor from
            # every (replica, core) outfeed, in dequeue order.
            for j, array in enumerate(arrays):
                for k, outputs in enumerate(dequeued):
                    array = array.write(base + k, outputs[j])
                arrays[j] = array

            return tuple([i + 1] + arrays)

        def LoopCond(i, *unused_arrays):
            del unused_arrays
            return i < num_loops

        initial_arrays = [
            tf.TensorArray(dtype,
                           size=num_loops * num_devices,
                           element_shape=shape)
            for dtype, shape in zip(tensor_types, tensor_shapes)
        ]
        # Loop once for each time that TpuTrainStep runs; drop the final index.
        final_arrays = tf.while_loop(LoopCond,
                                     LoopBody, [0] + initial_arrays,
                                     parallel_iterations=1)[1:]
        return dict(zip(keys, [ta.concat() for ta in final_arrays]))
Exemplo n.º 7
0
 def _ValueFromText(key, old_val, val):
   """Returns the new param value parsed from its text representation.

   The target type is the type name of `old_val` (the param's current value)
   unless `type_overrides` names a type for `key`. `type_overrides` is a free
   name here -- presumably the enclosing FromText's argument; confirm.

   Args:
     key: Param name; used for the `type_overrides` lookup and in errors.
     old_val: Current value of the param; its type guides the conversion.
     val: The new value as a string.

   Returns:
     The converted value.

   Raises:
     ValueError: if `val` cannot be parsed as the expected type. Other
       errors can also propagate, e.g. KeyError for an unknown enum member.
   """
   val_type = type(old_val).__name__
   if isinstance(old_val, str):
     val_type = 'str'
   if key in type_overrides:
     val_type = type_overrides[key]
   # is_dataclass() is also true for dataclass *classes*; those are plain
   # type-valued params and must reach the 'type/...' branch below.
   is_dataclass_instance = (dataclasses.is_dataclass(old_val) and
                            not isinstance(old_val, type))
   # Converts val (a string) to a best-guessed typed value.
   if val_type == 'bool':
     # bool() so that an empty value yields False rather than ''.
     return bool(val) and val not in ('False', 'false')
   elif val_type == 'int':
     return int(val)
   elif val_type == 'float':
     return float(val)
   elif val_type == 'DType':
     return tf.as_dtype(val)
   elif is_dataclass_instance or _IsNamedTuple(old_val):
     # Maps field name to new value (or its string repr, if non-POD).
     name_to_new_value = ast.literal_eval(val)
     if is_dataclass_instance:
       # Declared init fields rather than __dict__: slotted dataclasses have
       # no __dict__, and __dict__ may hold attributes the constructor
       # does not accept.
       items = [(f.name, getattr(old_val, f.name))
                for f in dataclasses.fields(old_val)
                if f.init]
     else:
       items = old_val._asdict().items()
     contents = {}
     for k, old_field_value in items:
       new_field_value = name_to_new_value[k]
       # Recurse to parse any non-POD contents not converted by
       # literal_eval().
       if isinstance(new_field_value, str):
         contents[k] = _ValueFromText(k, old_field_value, new_field_value)
       else:
         contents[k] = new_field_value
     return type(old_val)(**contents)
   elif val_type in ['list', 'tuple']:
     return ast.literal_eval(val)
   elif val_type == 'dict':
     return ast.literal_eval(val) if val != 'dict' else {}
   elif val_type == 'str':
     val = _UnquoteString(val)
     if val.startswith('[') and val.endswith(']'):
       # We may have stored a list as a string; try converting it back.
       # literal_eval raises ValueError, SyntaxError or TypeError on
       # malformed input (e.g. '[a b]'); keep the string as is then.
       try:
         return ast.literal_eval(val)
       except (ValueError, SyntaxError, TypeError):
         pass
     return val
   elif isinstance(old_val, enum.Enum):
     cls, _, name = val.rpartition('.')
     if val_type != cls:
       raise ValueError('Expected enum of class %s but got %s' %
                        (val_type, cls))
     return type(old_val)[name]
   elif (isinstance(old_val, type) or isinstance(old_val, message.Message) or
         old_val is None):
     if val == 'NoneType':
       return None
     elif old_val is None and val in ('False', 'false'):
       return False
     elif old_val is None and val in ('True', 'true'):
       return True
     # Expected forms: 'type/<module>/<name>' or
     # 'proto/<module>/<name>/<text proto>'.
     try:
       kind_prefix, pkg, cls = val.split('/', 2)
       if kind_prefix == 'type':
         return getattr(sys.modules[pkg], cls)
       elif kind_prefix == 'proto':
         cls, proto_str = cls.split('/', 1)
         proto_cls = getattr(sys.modules[pkg], cls)
         if not issubclass(proto_cls, message.Message):
           raise ValueError('%s is not a proto class.' % proto_cls)
         return text_format.Parse(proto_str, proto_cls())
     except ValueError as e:
       raise ValueError('Error processing %r : %r with %r' %
                        (key, val, e)) from e
     # Any other prefix yields None, silently discarding the text value.
     # NOTE(review): consider raising ValueError here instead.
     return None
   else:
     raise ValueError('Failed to read a parameter: %r : %r' % (key, val))
Exemplo n.º 8
0
    def FromText(self, text):
        """Merges params specified in 'text' into 'params'.

        'text' follows the simple text format as produced by
        ParamsToSimpleText.  For a param specified in both 'params' and
        'text', overwrites the value in 'params' according to 'text'.
        Every key found in 'text' is looked up with self.Get(); NOTE(review):
        that presumably raises for keys that are not in 'params' (see
        AttributeError below) rather than ignoring them -- confirm.

        Each raw value is converted according to the type of the param's
        current value: bool, int, float, tf.DType, str (unquoted),
        list/tuple/dict (via ast.literal_eval), and types / None ('NoneType',
        true/false, or a '<kind>/<module>/<name>' reference).

        Args:
          text: A text representation of params.

        Raises:
          TypeError: this Params instance is immutable.
          AttributeError: text contains invalid parameter key
          ValueError: text contains invalid parameter value
        """
        if self._immutable:
            raise TypeError('This Params instance is immutable.')

        # Pass 1: split 'text' into raw key -> value strings. A value that
        # starts with a quote and does not end with the same quote on that
        # line is a multi-line string, accumulated until the closing quote.
        kv = {}
        string_continue = None  # None or (key, quote, value)
        for line in text.split('\n'):
            # Continuing a multi-line string.
            if string_continue:
                value_stripped = line.rstrip()
                if not _EndsWithTerminalQuote(value_stripped,
                                              string_continue[1]):
                    # String continues
                    string_continue = (string_continue[0], string_continue[1],
                                       string_continue[2] + '\n' + line)
                    continue
                # String terminates.
                kv[string_continue[0]] = (string_continue[2] + '\n' +
                                          value_stripped)
                string_continue = None
                continue

            # Regular line.
            line = line.strip()
            if not line or line[0] == '#':
                # empty line or comment
                continue
            # Lines without a ':' separator are ignored.
            pair = line.split(':', 1)
            if len(pair) == 2:
                key = pair[0].strip()
                value = pair[1].lstrip()
                value_stripped = value.rstrip()
                # Detect single vs multi-line string start.
                if value and value[0] in ['"', '\'']:
                    quote_char = value[0]
                    if not _EndsWithTerminalQuote(value[1:], quote_char):
                        # Multi-line string.
                        string_continue = (key, quote_char, value)
                        continue
                kv[key] = value_stripped
        # NOTE(review): if 'text' ends inside an unterminated multi-line
        # string, that key is silently dropped.

        # Pass 2: convert every raw string into a typed value and set it.
        for key, val in six.iteritems(kv):
            old_val = self.Get(key)
            # Converts val (a string) to a best-guessed typed value. bool is
            # checked before int because bool is a subclass of int.
            if isinstance(old_val, bool):
                # bool() so that an empty value yields False rather than ''.
                val = bool(val) and val not in ('False', 'false')
            elif isinstance(old_val, int):
                val = int(val)
            elif isinstance(old_val, float):
                val = float(val)
            elif isinstance(old_val, tf.DType):
                val = tf.as_dtype(val)
            elif isinstance(old_val, (six.string_types, six.text_type)):
                val = _UnquoteString(val)
            elif isinstance(old_val, (list, tuple)):
                val = ast.literal_eval(val)
            elif isinstance(old_val, dict):
                val = ast.literal_eval(val) if val != 'dict' else {}
            elif isinstance(old_val, type) or old_val is None:
                if val == 'NoneType':
                    val = None
                elif old_val is None and val in ('False', 'false'):
                    val = False
                elif old_val is None and val in ('True', 'true'):
                    val = True
                else:
                    # Expected form: '<kind>/<module>/<name>'; the leading
                    # kind token is not checked.
                    try:
                        _, pkg, cls = val.split('/')
                        val = getattr(sys.modules[pkg], cls)
                    except ValueError as e:
                        raise ValueError('Error processing %r : %r with %r' %
                                         (key, val, e))
            else:
                raise ValueError('Failed to read a parameter: %r : %r' %
                                 (key, val))
            self.Set(**{key: val})
Exemplo n.º 9
0
 def Wrap(val):
   """Returns `val` as the output of an input-less `tf.py_func`.

   Args:
     val: Value exposing a `dtype` attribute; the py_func simply returns it.

   Returns:
     The result of `tf.py_func`, declared with dtype `tf.as_dtype(val.dtype)`.
   """
   out_dtype = tf.as_dtype(val.dtype)
   # tf.string is not supported by py_func.
   assert out_dtype != tf.string

   def _ReturnVal():
     return val

   return tf.py_func(_ReturnVal, [], out_dtype)