def _FromParamValue(param_pb: hyperparams_pb2.HyperparamValue) -> Any:
  """Decodes a HyperparamValue proto into the Python value it represents.

  The field set in the proto's 'kind' oneof selects the decoding. Container
  kinds (list, tuple, named tuple, dict) are decoded recursively.

  Args:
    param_pb: The HyperparamValue proto to decode.

  Returns:
    The decoded value. None is returned when no 'kind' field is set, and also
    when the class loaded for a named tuple, enum or proto value is not of the
    expected kind. Kinds without a dedicated branch are returned as the raw
    field value.

  Raises:
    TypeError: If the value is a string_repr_val, or if a symbolic_val does
      not unpickle to a symbolic expression.
  """
  kind = param_pb.WhichOneof('kind')
  if not kind:
    return None

  if kind == 'param_val':
    return _FromParam(param_pb.param_val)

  if kind == 'list_val':
    return [_FromParamValue(item) for item in param_pb.list_val.items]

  if kind == 'named_tuple_val':
    cls = _LoadClass(param_pb.named_tuple_val.type)
    if not (dataclasses.is_dataclass(cls) or issubclass(cls, tuple)):
      return None
    field_values = [
        _FromParamValue(item) for item in param_pb.named_tuple_val.items
    ]
    return cls(*field_values)

  if kind == 'tuple_val':
    return tuple(_FromParamValue(item) for item in param_pb.tuple_val.items)

  if kind == 'dict_val':
    entries = param_pb.dict_val.items
    return {name: _FromParamValue(entries[name]) for name in entries}

  if kind == 'type_val':
    # Encoded as '<module>/<attribute>'.
    parts = param_pb.type_val.split('/')
    assert len(parts) == 2
    module_name, attr_name = parts
    return getattr(importlib.import_module(module_name), attr_name)

  if kind == 'dtype_val':
    return tf.as_dtype(param_pb.dtype_val)

  if kind == 'enum_val':
    enum_cls = _LoadClass(param_pb.enum_val.type)
    if issubclass(enum_cls, enum.Enum):
      return enum_cls[param_pb.enum_val.name]
    return None

  if kind == 'proto_val':
    proto_cls = _LoadClass(param_pb.proto_val.type)
    if issubclass(proto_cls, message.Message):
      decoded = proto_cls()
      decoded.ParseFromString(param_pb.proto_val.val)
      return decoded
    return None

  if kind == 'symbolic_val':
    # NOTE: pickle.loads can execute arbitrary code while unpickling; this is
    # only safe if the serialized proto comes from a trusted source.
    expr = pickle.loads(param_pb.symbolic_val)
    if symbolic.IsExpr(expr):
      return expr
    raise TypeError('Unexpected result when deserializing symbolic expr.')

  if kind == 'string_repr_val':
    raise TypeError('Cannot deserialize string_repr_val instance: %s' %
                    param_pb.string_repr_val)

  # Every remaining kind is returned as the raw field value.
  return getattr(param_pb, kind)
def _FromParamValue(param_pb):
  """Decodes a HyperparamValue proto into the Python value it represents.

  The field set in the proto's 'kind' oneof selects the decoding. Lists,
  tuples and dicts are decoded recursively, element by element.

  Args:
    param_pb: The HyperparamValue proto to decode.

  Returns:
    The decoded value, or None when no recognized 'kind' field is set.
  """
  kind = param_pb.WhichOneof('kind')

  if kind == 'param_val':
    return _FromParam(param_pb.param_val)

  if kind == 'list_val':
    return [_FromParamValue(item) for item in param_pb.list_val.items]

  if kind == 'tuple_val':
    return tuple(_FromParamValue(item) for item in param_pb.tuple_val.items)

  if kind == 'dict_val':
    entries = param_pb.dict_val.items
    return {name: _FromParamValue(entries[name]) for name in entries}

  if kind == 'type_val':
    # Encoded as '<module>/<attribute>'.
    parts = param_pb.type_val.split('/')
    assert len(parts) == 2
    module_name, attr_name = parts
    return getattr(importlib.import_module(module_name), attr_name)

  if kind == 'dtype_val':
    return tf.as_dtype(param_pb.dtype_val)

  if kind in ('string_val', 'int_val', 'float_val', 'bool_val'):
    # Scalar kinds are stored directly in the field named by the oneof.
    return getattr(param_pb, kind)

  # If nothing is set, it's the None type.
  return None
def _FromParamValue(param_pb):
  """Decodes a HyperparamValue proto into the Python value it represents.

  The field set in the proto's 'kind' oneof selects the decoding. Container
  kinds (list, tuple, named tuple, dict) are decoded recursively.

  Args:
    param_pb: The HyperparamValue proto to decode.

  Returns:
    The decoded value. None is returned when no recognized 'kind' field is
    set, and also when the class loaded for a named tuple, enum or proto value
    is not of the expected kind.
  """
  kind = param_pb.WhichOneof('kind')

  if kind == 'param_val':
    return _FromParam(param_pb.param_val)

  if kind == 'list_val':
    return [_FromParamValue(item) for item in param_pb.list_val.items]

  if kind == 'named_tuple_val':
    cls = _LoadClass(param_pb.named_tuple_val.type)
    if not (dataclasses.is_dataclass(cls) or issubclass(cls, tuple)):
      return None
    field_values = [
        _FromParamValue(item) for item in param_pb.named_tuple_val.items
    ]
    return cls(*field_values)

  if kind == 'tuple_val':
    return tuple(_FromParamValue(item) for item in param_pb.tuple_val.items)

  if kind == 'dict_val':
    entries = param_pb.dict_val.items
    return {name: _FromParamValue(entries[name]) for name in entries}

  if kind == 'type_val':
    # Encoded as '<module>/<attribute>'.
    parts = param_pb.type_val.split('/')
    assert len(parts) == 2
    module_name, attr_name = parts
    return getattr(importlib.import_module(module_name), attr_name)

  if kind == 'dtype_val':
    return tf.as_dtype(param_pb.dtype_val)

  if kind in ('string_val', 'int_val', 'float_val', 'bool_val'):
    # Scalar kinds are stored directly in the field named by the oneof.
    return getattr(param_pb, kind)

  if kind == 'enum_val':
    enum_cls = _LoadClass(param_pb.enum_val.type)
    if issubclass(enum_cls, enum.Enum):
      return enum_cls[param_pb.enum_val.name]
    return None

  if kind == 'proto_val':
    proto_cls = _LoadClass(param_pb.proto_val.type)
    if issubclass(proto_cls, message.Message):
      decoded = proto_cls()
      decoded.ParseFromString(param_pb.proto_val.val)
      return decoded
    return None

  # If nothing is set, it's the None type.
  return None
def _config_infeed(self,
                   num_partitions,
                   device_assignment,
                   batch_size,
                   key_size=2,
                   return_tgt_mask=False,
                   use_partitioned_infeed_queue=False):
  """Creates the infeed queue, its input placeholders and its enqueue op(s).

  One placeholder is created per element of the batch returned by
  get_zero_batch(), using that element's dtype and shape. The placeholders are
  stored in self.infeed_args and the enqueue op(s) in self.infeed_op.

  Args:
    num_partitions: Size of the leading partition dimension passed to the
      partitioned infeed queue for every input; the remaining dimensions are
      partitioned 1-way.
    device_assignment: Device assignment used to find the host device of
      replica 0 and passed through to the infeed queue / enqueue ops.
    batch_size: Forwarded to get_zero_batch().
    key_size: Forwarded to get_zero_batch().
    return_tgt_mask: Forwarded to get_zero_batch().
    use_partitioned_infeed_queue: If True, uses
      tpu_feed._PartitionedInfeedQueue; otherwise uses tpu_feed.InfeedQueue.

  Returns:
    The infeed queue object that was created.
  """
  template_batch = get_zero_batch(
      batch_size=batch_size,
      max_len=self._prefix_max_len,
      key_size=key_size,
      return_tgt_mask=return_tgt_mask)

  # The host id is the number between '/task:' and '/device:' in the host
  # device string of replica 0.
  host_device = device_assignment.host_device(replica=0, job=self._tpu)
  after_task = host_device.split('/task:')[1]
  host_id = int(after_task.split('/device:')[0])

  input_partition_dims = [
      [num_partitions] + [1] * (len(x.shape) - 1) for x in template_batch
  ]

  if use_partitioned_infeed_queue:
    infeed = tpu_feed._PartitionedInfeedQueue(  # pylint: disable=protected-access
        number_of_tuple_elements=len(template_batch),
        host_id=host_id,
        input_partition_dims=input_partition_dims,
        device_assignment=device_assignment)
  else:
    infeed = tpu_feed.InfeedQueue(
        number_of_tuple_elements=len(template_batch))

  self.infeed_args = [
      tf.placeholder(tf.as_dtype(x.dtype), shape=x.shape)
      for x in template_batch
  ]

  if use_partitioned_infeed_queue:
    self.infeed_op = infeed.generate_enqueue_ops([self.infeed_args])
  else:
    self.infeed_op = infeed.split_inputs_and_generate_enqueue_ops(
        self.infeed_args, device_assignment=device_assignment)
  return infeed
def FromText(self, text, type_overrides=None):
  """Merges params specified in 'text' into 'params'.

  'text' follows the simple text format as produced by ParamsToSimpleText.
  For a param specified in both 'params' and 'text', overwrites the value in
  'params' according to 'text'. Params specified in 'text' but not in 'params'
  are ignored.

  NOTE(review): the "are ignored" statement above and the AttributeError
  listed below appear to conflict; this method calls self.Get(key) for every
  key found in 'text' without filtering. Confirm against Get().

  Each non-empty, non-comment line has the form 'key : value'. A value that
  starts with a quote character and does not end with the matching terminal
  quote continues over the following lines until one does. Lines without a
  ':' are skipped.

  Args:
    text: A text representation of params.
    type_overrides: Overrides for the types of the params. Maps a param key to
      a type name (e.g. 'bool', 'int', 'str') that is used instead of the type
      name of the param's current value.

  Raises:
    TypeError: this Params instance is immutable.
    AttributeError: text contains invalid parameter key
    ValueError: text contains invalid parameter value
  """
  if self._immutable:
    raise TypeError('This Params instance is immutable.')
  kv = {}
  type_overrides = type_overrides or {}
  string_continue = None  # None or (key, quote, value)
  for line in text.split('\n'):
    # Continuing a multi-line string.
    if string_continue:
      value_stripped = line.rstrip()
      if not _EndsWithTerminalQuote(value_stripped, string_continue[1]):
        # String continues
        string_continue = (string_continue[0], string_continue[1],
                           string_continue[2] + '\n' + line)
        continue
      # String terminates.
      kv[string_continue[0]] = string_continue[2] + '\n' + value_stripped
      string_continue = None
      continue

    # Regular line.
    line = line.strip()
    if not line or line[0] == '#':
      # empty line or comment
      continue
    pair = line.split(':', 1)
    if len(pair) == 2:
      key = pair[0].strip()
      value = pair[1].lstrip()
      value_stripped = value.rstrip()
      # Detect single vs multi-line string start.
      if value and value[0] in ['"', '\'']:
        quote_char = value[0]
        if not _EndsWithTerminalQuote(value[1:], quote_char):
          # Multi-line string.
          string_continue = (key, quote_char, value)
          continue
      kv[key] = value_stripped
  # NOTE: a multi-line string that is never terminated is silently dropped.

  for key, val in six.iteritems(kv):
    old_val = self.Get(key)
    val_type = type(old_val).__name__
    if isinstance(old_val, (six.string_types, six.text_type)):
      val_type = 'str'
    if key in type_overrides:
      val_type = type_overrides[key]
    # Converts val (a string) to a best-guessed typed value.
    if val_type == 'bool':
      # Always produce a real bool: an empty value must become False, not ''.
      val = bool(val) and val not in ('False', 'false')
    elif val_type == 'int':
      val = int(val)
    elif val_type == 'float':
      val = float(val)
    elif val_type == 'DType':
      val = tf.as_dtype(val)
    elif val_type in ['list', 'tuple']:
      val = ast.literal_eval(val)
    elif val_type == 'dict':
      val = ast.literal_eval(val) if val != 'dict' else {}
    elif val_type == 'str':
      val = _UnquoteString(val)
      if val.startswith('[') and val.endswith(']'):
        # We may have stored a list as a string, try converting to a list.
        # If it is not a valid literal, use the string as is. literal_eval
        # reports malformed input with SyntaxError (e.g. '[a b]') or TypeError
        # as well as ValueError, so all of them must fall back to the string.
        try:
          val = ast.literal_eval(val)
        except (ValueError, SyntaxError, TypeError):
          pass
    elif isinstance(old_val, enum.Enum):
      # Enums are written as '<ClassName>.<MEMBER>'.
      cls, _, name = val.rpartition('.')
      if val_type != cls:
        raise ValueError('Expected enum of class %s but got %s' %
                         (val_type, cls))
      val = type(old_val)[name]
    elif (isinstance(old_val, type) or
          isinstance(old_val, message.Message) or old_val is None):
      if val == 'NoneType':
        val = None
      elif old_val is None and val in ('False', 'false'):
        val = False
      elif old_val is None and val in ('True', 'true'):
        val = True
      else:
        # Expected forms: 'type/<pkg>/<cls>' or 'proto/<pkg>/<cls>/<text>'.
        # Any other prefix leaves val as the raw string.
        try:
          val_type, pkg, cls = val.split('/', 2)
          if val_type == 'type':
            val = getattr(sys.modules[pkg], cls)
          elif val_type == 'proto':
            cls, proto_str = cls.split('/', 1)
            proto_cls = getattr(sys.modules[pkg], cls)
            if not issubclass(proto_cls, message.Message):
              raise ValueError('%s is not a proto class.' % proto_cls)
            val = text_format.Parse(proto_str, proto_cls())
        except ValueError as e:
          raise ValueError('Error processing %r : %r with %r' % (key, val, e))
    else:
      raise ValueError('Failed to read a parameter: %r : %r' % (key, val))
    self.Set(**{key: val})
def _OutfeedDequeueLoop(self, per_example_tensors, num_loops, num_devices):
  """Process all per-example tensor outfeed data for a TPU sess.run.

  Builds a tf.while_loop that runs num_loops times. Each iteration dequeues
  one outfeed tuple per (replica, core) of the TPU device assignment and
  writes every dequeued tensor into a per-key tf.TensorArray. After the loop,
  each array is concatenated.

  Args:
    per_example_tensors: dict of key -> tensor as generated by TpuTrainStep.
    num_loops: number of times that TpuTrainStep will be executed by TpuTrain.
    num_devices: number of TPU cores assigned to this process.

  Returns:
    A dict of per-example tensors from the latest TpuTrainStep, keyed like
    per_example_tensors. If per_example_tensors is empty, a tf.no_op() is
    returned instead.
  """
  if not per_example_tensors:
    return tf.no_op()

  # Keys are processed in sorted order so that shapes, dtypes, arrays and the
  # dequeued tuple elements all line up by position.
  sorted_keys = sorted(per_example_tensors)
  tensor_shapes = [
      py_utils.GetShape(per_example_tensors[name]) for name in sorted_keys
  ]
  tensor_types = [
      tf.as_dtype(per_example_tensors[name].dtype) for name in sorted_keys
  ]

  def LoopBody(i, *input_arrays):
    """Process outfeed data for a single TpuTrainStep.

    Args:
      i: current loop index.
      *input_arrays: One tf.TensorArray per outfeed tensor.

    Returns:
      i+1 (new index) plus post-write tf.TensorArray handles.
    """
    # Outfeed ops execute on each JF node, so they must be located on the
    # nodes.
    dequeued_tuples = []
    device_assignment = py_utils.GetTpuDeviceAssignment()
    assert device_assignment
    for replica in range(device_assignment.num_replicas):
      for core in range(device_assignment.num_cores_per_replica):
        with tf.device(device_assignment.host_device(replica, core)):
          dequeued_tuples.append(
              tpu_ops.outfeed_dequeue_tuple(
                  tensor_types,
                  tensor_shapes,
                  device_ordinal=device_assignment.tpu_ordinal(replica, core)))

    base_index = i * num_devices
    # Each array holds a different per-example tensor. We get results for
    # each tensor from each TPU for each TpuTrainStep call.
    written_arrays = []
    for j, array in enumerate(input_arrays):
      for k, device_tuple in enumerate(dequeued_tuples):
        array = array.write(base_index + k, device_tuple[j])
      written_arrays.append(array)
    return tuple([i + 1] + written_arrays)

  def LoopCond(i, *unused_arrays):
    """Continues looping until num_loops iterations have run."""
    return i < num_loops

  initial_arrays = [
      tf.TensorArray(
          dtype, size=num_loops * num_devices, element_shape=shape)
      for dtype, shape in zip(tensor_types, tensor_shapes)
  ]
  # Loop once for each time that TpuTrainStep runs.
  loop_outputs = tf.while_loop(
      LoopCond, LoopBody, [0] + initial_arrays, parallel_iterations=1)
  # Drop the loop counter; what remains is one TensorArray per key.
  final_arrays = loop_outputs[1:]
  concatenated_arrays = [array.concat() for array in final_arrays]
  return dict(zip(sorted_keys, concatenated_arrays))
def _ValueFromText(key, old_val, val):
  """Returns the new param value from its text representation.

  The target type is taken from the type name of old_val, unless
  type_overrides (from the enclosing scope) has an entry for key.

  Args:
    key: Name of the param; used for type_overrides lookup and error messages.
    old_val: The param's current value; determines how val is parsed.
    val: The text representation to convert.

  Returns:
    The parsed value.

  Raises:
    ValueError: If val cannot be parsed for the target type.
  """
  val_type = type(old_val).__name__
  if isinstance(old_val, str):
    val_type = 'str'
  if key in type_overrides:
    val_type = type_overrides[key]
  # Converts val (a string) to a best-guessed typed value.
  if val_type == 'bool':
    # Always produce a real bool: an empty value must become False, not ''.
    return bool(val) and val not in ('False', 'false')
  elif val_type == 'int':
    return int(val)
  elif val_type == 'float':
    return float(val)
  elif val_type == 'DType':
    return tf.as_dtype(val)
  elif dataclasses.is_dataclass(old_val) or _IsNamedTuple(old_val):
    # Maps field name to new value (or its string repr, if non-POD).
    name_to_new_value = ast.literal_eval(val)
    contents = {}
    items = old_val.__dict__.items() if dataclasses.is_dataclass(
        old_val) else old_val._asdict().items()
    for k, old_field_value in items:
      new_field_value = name_to_new_value[k]
      # Recurse to parse any non-POD contents not converted by
      # literal_eval().
      if isinstance(new_field_value, str):
        contents[k] = _ValueFromText(k, old_field_value, new_field_value)
      else:
        contents[k] = new_field_value
    return type(old_val)(**contents)
  elif val_type in ['list', 'tuple']:
    return ast.literal_eval(val)
  elif val_type == 'dict':
    return ast.literal_eval(val) if val != 'dict' else {}
  elif val_type == 'str':
    val = _UnquoteString(val)
    if val.startswith('[') and val.endswith(']'):
      # We may have stored a list as a string, try converting to a list.
      # If it is not a valid literal, use the string as is. literal_eval
      # reports malformed input with SyntaxError (e.g. '[a b]') or TypeError
      # as well as ValueError, so all of them must fall back to the string.
      try:
        return ast.literal_eval(val)
      except (ValueError, SyntaxError, TypeError):
        pass
    return val
  elif isinstance(old_val, enum.Enum):
    # Enums are written as '<ClassName>.<MEMBER>'.
    cls, _, name = val.rpartition('.')
    if val_type != cls:
      raise ValueError('Expected enum of class %s but got %s' %
                       (val_type, cls))
    return type(old_val)[name]
  elif (isinstance(old_val, type) or isinstance(old_val, message.Message) or
        old_val is None):
    if val == 'NoneType':
      return None
    elif old_val is None and val in ('False', 'false'):
      return False
    elif old_val is None and val in ('True', 'true'):
      return True
    else:
      # Expected forms: 'type/<pkg>/<cls>' or 'proto/<pkg>/<cls>/<text>'.
      # NOTE: with any other prefix no branch below returns, so the function
      # falls through and implicitly returns None.
      try:
        val_type, pkg, cls = val.split('/', 2)
        if val_type == 'type':
          return getattr(sys.modules[pkg], cls)
        elif val_type == 'proto':
          cls, proto_str = cls.split('/', 1)
          proto_cls = getattr(sys.modules[pkg], cls)
          if not issubclass(proto_cls, message.Message):
            raise ValueError('%s is not a proto class.' % proto_cls)
          return text_format.Parse(proto_str, proto_cls())
      except ValueError as e:
        raise ValueError('Error processing %r : %r with %r' % (key, val, e))
  else:
    raise ValueError('Failed to read a parameter: %r : %r' % (key, val))
def FromText(self, text):
  """Merges params specified in 'text' into 'params'.

  'text' follows the simple text format as produced by ParamsToSimpleText.
  For a param specified in both 'params' and 'text', overwrites the value in
  'params' according to 'text'. Params specified in 'text' but not in 'params'
  are ignored.

  NOTE(review): the "are ignored" statement above and the AttributeError
  listed below appear to conflict; this method calls self.Get(key) for every
  key found in 'text' without filtering. Confirm against Get().

  Each non-empty, non-comment line has the form 'key : value'. A value that
  starts with a quote character and does not end with the matching terminal
  quote continues over the following lines until one does. Lines without a
  ':' are skipped.

  Args:
    text: A text representation of params.

  Raises:
    TypeError: this Params instance is immutable.
    AttributeError: text contains invalid parameter key
    ValueError: text contains invalid parameter value
  """
  if self._immutable:
    raise TypeError('This Params instance is immutable.')
  kv = {}
  string_continue = None  # None or (key, quote, value)
  for line in text.split('\n'):
    # Continuing a multi-line string.
    if string_continue:
      value_stripped = line.rstrip()
      if not _EndsWithTerminalQuote(value_stripped, string_continue[1]):
        # String continues
        string_continue = (string_continue[0], string_continue[1],
                           string_continue[2] + '\n' + line)
        continue
      # String terminates.
      kv[string_continue[0]] = string_continue[2] + '\n' + value_stripped
      string_continue = None
      continue

    # Regular line.
    line = line.strip()
    if not line or line[0] == '#':
      # empty line or comment
      continue
    pair = line.split(':', 1)
    if len(pair) == 2:
      key = pair[0].strip()
      value = pair[1].lstrip()
      value_stripped = value.rstrip()
      # Detect single vs multi-line string start.
      if value and value[0] in ['"', '\'']:
        quote_char = value[0]
        if not _EndsWithTerminalQuote(value[1:], quote_char):
          # Multi-line string.
          string_continue = (key, quote_char, value)
          continue
      kv[key] = value_stripped
  # NOTE: a multi-line string that is never terminated is silently dropped.

  for key, val in six.iteritems(kv):
    old_val = self.Get(key)
    # Converts val (a string) to a best-guessed typed value.
    # bool is tested before int because bool is a subclass of int.
    if isinstance(old_val, bool):
      # Always produce a real bool: an empty value must become False, not ''.
      val = bool(val) and val not in ('False', 'false')
    elif isinstance(old_val, int):
      val = int(val)
    elif isinstance(old_val, float):
      val = float(val)
    elif isinstance(old_val, tf.DType):
      val = tf.as_dtype(val)
    elif isinstance(old_val, (six.string_types, six.text_type)):
      val = _UnquoteString(val)
    elif isinstance(old_val, (list, tuple)):
      val = ast.literal_eval(val)
    elif isinstance(old_val, dict):
      val = ast.literal_eval(val) if val != 'dict' else {}
    elif isinstance(old_val, type) or old_val is None:
      if val == 'NoneType':
        val = None
      elif old_val is None and val in ('False', 'false'):
        val = False
      elif old_val is None and val in ('True', 'true'):
        val = True
      else:
        # Expected form: '<prefix>/<pkg>/<cls>'; the prefix is not inspected.
        try:
          _, pkg, cls = val.split('/')
          val = getattr(sys.modules[pkg], cls)
        except ValueError as e:
          raise ValueError('Error processing %r : %r with %r' % (key, val, e))
    else:
      raise ValueError('Failed to read a parameter: %r : %r' % (key, val))
    self.Set(**{key: val})
def Wrap(val):
  """Wraps val in a tf.py_func op that takes no inputs and returns val.

  Args:
    val: A value with a `dtype` attribute convertible by tf.as_dtype. Must not
      have a string dtype.

  Returns:
    The result of tf.py_func, declared with val's dtype.
  """
  wrapped_dtype = tf.as_dtype(val.dtype)
  # tf.string is not supported by py_func.
  assert wrapped_dtype != tf.string

  def _ReturnVal():
    return val

  return tf.py_func(_ReturnVal, [], wrapped_dtype)