def _generate_sources(
    self, op: op_pb2.OpProto,
    operand_types: Iterable[xls_type_pb2.TypeProto],
    output_type: xls_type_pb2.TypeProto) -> Tuple[str, str]:
  """Builds the IR package and synthesized netlist for one LEC execution.

  The IR package wraps a single instance of `op`. This is the only form
  op_module_generator currently supports; composing several ops would be useful
  later to see how LEC time scales. The IR has operands of `operand_types` and
  a result of `output_type`. That IR is lowered to a Verilog module, which is
  sent to the synthesis server to obtain a netlist.

  Args:
    op: The XLS IR opcode for which to generate sources.
    operand_types: The types of the arguments to use for this op execution.
    output_type: The type of the operation output.

  Returns:
    An (ir_text, netlist_text) tuple of source texts.
  """
  # Drop the first three characters of the enum name (presumably an "OP_"
  # prefix) and lower-case the rest to get the IR op name.
  ir_op_name = op_pb2.OpProto.Name(op)[3:].lower()
  ir_operand_types = [self._proto_to_ir_type(t) for t in operand_types]
  ir_output_type = self._proto_to_ir_type(output_type)

  ir_text = op_module_generator.generate_ir_package(
      ir_op_name, ir_output_type, ir_operand_types, [], None)
  verilog_text = op_module_generator.generate_verilog_module(
      self._MODULE_NAME, ir_text).verilog_text

  credentials = client_credentials.get_credentials()
  with grpc.secure_channel(self._synthesis_server_address,
                           credentials) as channel:
    # Block until the synthesis server channel is ready.
    grpc.channel_ready_future(channel).result()
    stub = synthesis_service_pb2_grpc.SynthesisServiceStub(channel)
    request = synthesis_pb2.CompileRequest()
    # NOTE(review): vlog(logging.INFO, ...) is vlog level 0, i.e. always on;
    # confirm the module text is meant to be logged unconditionally.
    logging.vlog(logging.INFO, 'Module text:\n %s', verilog_text)
    request.module_text = verilog_text
    request.top_module_name = self._MODULE_NAME
    # Everything happens within a single cycle, so ask for the lowest rate.
    request.target_frequency_hz = 1
    netlist_text = stub.Compile(request).netlist

  return ir_text, netlist_text
def _convert_one_function(package: ir_package.Package,
                          module: ast.Module,
                          function: ast.Function,
                          type_info: type_info_mod.TypeInfo,
                          symbolic_bindings: Optional[SymbolicBindings] = None,
                          emit_positions: bool = True) -> Text:
  """Converts a single function into its emitted text form.

  Args:
    package: IR package we're converting the function into.
    module: Module we're converting a function within.
    function: Function we're converting.
    type_info: Type information about module from the typechecking phase.
    symbolic_bindings: Parametric bindings to use during conversion, if this
      function is parametric.
    emit_positions: Whether to emit position information into the IR based on
      the AST's source positions.

  Returns:
    The converted IR function text (not including callees).

  Raises:
    NotImplementedError: If the body references a free variable that is not a
      module function, a builtin, or a module-level constant.
    ValueError: If some free parametric key of `function` has no entry in
      `symbolic_bindings`.
  """
  function_by_name = module.get_function_by_name()
  constant_by_name = module.get_constant_by_name()
  converter = _IrConverterFb(
      package, module, type_info, emit_positions=emit_positions)

  freevars = function.body.get_free_variables(
      function.span.start).get_name_def_tups()
  logging.vlog(3, 'Free variables for function: %s', freevars)
  for identifier, name_def in freevars:
    if (identifier in function_by_name or
        isinstance(name_def, ast.BuiltinNameDef)):
      # Functions and builtins are resolved elsewhere; nothing to register.
      pass
    elif identifier in constant_by_name:
      # Module-level constants must be emitted before the function body.
      converter.add_constant_dep(constant_by_name[identifier])
    else:
      raise NotImplementedError(identifier)

  symbolic_binding_keys = set(k for k, _ in symbolic_bindings or ())
  f_parametric_keys = function.get_free_parametric_keys()
  # Every free parametric key must be bound. Use a subset test: a
  # proper-superset test (`>`) is False for incomparable sets, so it would
  # accept bindings that miss a key as long as they also name another one.
  if not set(f_parametric_keys) <= symbolic_binding_keys:
    raise ValueError(
        'Not enough symbolic bindings to convert function {!r}; need {!r} got {!r}'
        .format(function.name.identifier, f_parametric_keys,
                symbolic_binding_keys))

  logging.vlog(3, 'Converting function: %s; symbolic bindings: %s', function,
               symbolic_bindings)
  f = converter.visit_Function(function, symbolic_bindings)
  return f.dump_ir(recursive=False)
def _deduce_ModRef(self: ast.ModRef, ctx: DeduceCtx) -> ConcreteType:  # pytype: disable=wrong-arg-types
  """Deduces the type of an entity referenced via module reference.

  The referenced leaf name (`self.value`) is first looked up among the
  imported module's type definitions, then among its functions. Parametric
  functions that have not been typechecked yet (no entry in the imported type
  info) get their signature typechecked here. That result is merged into
  `ctx.type_info`.

  Args:
    self: The ModRef AST node (`mod` is the imported module, `value` the leaf
      name).
    ctx: Deduction context for the referring module.

  Returns:
    The concrete type of the referenced type definition or function.

  Raises:
    TypeInferenceError: If the referenced type or function is not public, or
      no function with the leaf name exists in the imported module.
  """
  imported_module, imported_type_info = ctx.type_info.get_imported(self.mod)
  leaf_name = self.value

  # May be a type definition reference.
  if leaf_name in imported_module.get_typedef_by_name():
    td = imported_module.get_typedef_by_name()[leaf_name]
    if not td.public:
      raise TypeInferenceError(
          self.span,
          type_=None,
          suffix='Attempted to refer to module type that is not public.')
    return imported_type_info[td.name]

  # May be a function reference.
  try:
    f = imported_module.get_function(leaf_name)
  except KeyError:
    raise TypeInferenceError(
        self.span,
        type_=None,
        suffix='Module {!r} function {!r} does not exist.'.format(
            imported_module.name, leaf_name))

  if not f.public:
    raise TypeInferenceError(
        self.span,
        type_=None,
        suffix='Attempted to refer to module {!r} function {!r} that is not public.'
        .format(imported_module.name, f.name))

  if f.name not in imported_type_info:
    logging.vlog(
        2, 'Function name not in imported_type_info; must be parametric: %r',
        f.name)
    assert f.is_parametric()
    # We don't type check parametric functions until invocations.
    # Let's typecheck this imported parametric function with respect to its
    # module (this will only get the type signature, body gets typechecked
    # after parametric instantiation).
    imported_ctx = DeduceCtx(imported_type_info, imported_module,
                             ctx.interpret_expr, ctx.check_function_in_module)
    # Reuse the caller's current function-stack entry in the imported context.
    imported_ctx.fn_stack.append(ctx.fn_stack[-1])
    ctx.check_function_in_module(f, imported_ctx)
    # Merge the newly deduced imported types back into the caller's type info.
    ctx.type_info.update(imported_ctx.type_info)
    imported_type_info = imported_ctx.type_info
  return imported_type_info[f.name]
def excess_wrong_bits(f, *args):
  """Measures float32 error in `f(*args)` beyond what ill-conditioning explains.

  The result is the number of wrong bits observed when evaluating `f` in
  float32, minus the number of wrong bits attributable to rounding the inputs
  (i.e. to the conditioning of `f`). A positive value suggests the
  implementation of `f` introduces avoidable numerical error at `args`.

  Assumptions: `f` is differentiable, and float64 error in `f` and its
  derivatives is negligible. `f` must also be dtype-polymorphic: its internal
  computations (and those of its derivatives) must run in the dtype of its
  arguments.

  Args:
    f: Differentiable, dtype-polymorphic function to evaluate.
    *args: Arguments at which to test the accuracy of `f`.

  Returns:
    Wrong bits of float32 `f(*args)` in excess of the conditioning-induced
    wrong bits.
  """
  observed_err = relative_error_at(f, *args)
  logging.vlog(1, 'Relative error: %s', observed_err)
  rounding_err = error_due_to_ill_conditioning(f, *args)
  logging.vlog(1, 'Relative error due to input rounding: %s', rounding_err)

  observed_bits, rounding_bits = wrong_bits(observed_err), wrong_bits(rounding_err)
  logging.vlog(1, 'Wrong bits: %s', observed_bits)
  logging.vlog(1, 'Wrong bits due to input rounding: %s', rounding_bits)
  return observed_bits - rounding_bits
def enable_v2_tensorshape():
  """Opts in to the TensorFlow 2.0 TensorShape indexing/iteration behavior.

  In V1, `tensor_shape[i]` (and iterating over a TensorShape) yields
  `Dimension` instances. In V2 it yields plain values: an integer, or None
  for an unknown dimension. Calling this function enables the V2 behavior.

  Examples:

  ```
  #######################
  # If you had this in V1:
  value = tensor_shape[i].value

  # Do this in V2 instead:
  value = tensor_shape[i]

  #######################
  # If you had this in V1:
  for dim in tensor_shape:
    value = dim.value
    print(value)

  # Do this in V2 instead:
  for value in tensor_shape:
    print(value)

  #######################
  # If you had this in V1:
  dim = tensor_shape[i]
  dim.assert_is_compatible_with(other_shape)  # or using any other shape method

  # Do this in V2 instead:
  if tensor_shape.rank is None:
    dim = Dimension(None)
  else:
    dim = tensor_shape.dims[i]
  dim.assert_is_compatible_with(other_shape)  # or using any other shape method

  # The V2 suggestion above is more explicit, which will save you from
  # the following trap (present in V1):
  # you might do in-place modifications to `dim` and expect them to be reflected
  # in `tensor_shape[i]`, but they would not be.
  ```
  """
  # Flip the module-level override flag read by TensorShape.
  global _TENSORSHAPE_V2_OVERRIDE  # pylint: disable=invalid-name
  _TENSORSHAPE_V2_OVERRIDE = True
  logging.vlog(1, "Enabling v2 tensorshape")
  # Report the opt-in through the API-usage gauge.
  _api_usage_gauge.get_cell().set(True)
def collect_trajectories(self, evaluate):
  """Runs the policy trainer on the real environment to gather trajectories.

  Points the policy trainer at `self.train_env` and dumps trajectories under a
  per-epoch subdirectory of the trajectory dump root. It then advances the
  policy epoch counter by `self._n_real_epochs` and runs the training loop up
  to that epoch.

  Args:
    evaluate: Forwarded to the policy trainer's `training_loop`.
  """
  logging.info('SimPLe epoch [% 6d]: collecting data.', self._simple_epoch)
  started_at = time.time()

  self.policy_trainer.train_env = self.train_env
  dump_dir = os.path.join(self._trajectory_dump_root_dir, str(self.epoch))
  self.policy_trainer.trajectory_dump_dir = dump_dir
  self._policy_epoch += self._n_real_epochs
  self.policy_trainer.training_loop(self._policy_epoch, evaluate=evaluate)

  elapsed = time.time() - started_at
  logging.vlog(1, 'Collecting trajectories took %0.2f sec.', elapsed)
def barrier_wait(logging_name: Optional[str] = None):
  """Blocks the calling thread until all current outfeed is processed.

  Waits until all outfeed from computations already running on all devices
  has been received and processed by the Python callbacks. Raises
  TapFunctionException if there were exceptions while processing the callbacks.

  This works by enqueueing a special tap computation to all devices to which
  we are listening for outfeed. Once all those tap computations are done, we
  return from barrier_wait.

  Note: If any of the devices are busy and cannot accept new computations,
  this will deadlock.

  Args:
    logging_name: an optional string that will be used in the logging statements
      for this invocation. See `Debugging` in the module documentation.

  Raises:
    TapFunctionException: If any tap callback raised since the last check; the
      exception counter is reset before raising.
  """
  logging_name = logging_name or ""
  logging.vlog(2, f"barrier_wait[{logging_name}]: start")
  if not _outfeed_receiver.receiver:
    logging.vlog(2, f"barrier_wait[{logging_name}]: receiver not started")
    return

  lock = threading.Lock()
  cv = threading.Condition(lock=lock)
  num_at_large = len(_outfeed_receiver.devices)  # Protected by lock

  def barrier_tap(dev_idx, _):
    nonlocal num_at_large
    # The message literal was previously split across a physical line, which
    # is a syntax error; the emitted text is unchanged.
    logging.vlog(
        2, f"barrier_wait[{logging_name}]: at barrier_tap for device "
        f"{_outfeed_receiver.devices[dev_idx]} "
        f". Thread {threading.current_thread()}")
    with lock:
      num_at_large -= 1
      logging.vlog(
          2, f"barrier_wait[{logging_name}]: still waiting for "
          f"{num_at_large} barrier_tap")
      cv.notify()

  # Enqueue one tap per device; each decrements the counter when it runs.
  for d_idx, d in enumerate(_outfeed_receiver.devices):
    logging.vlog(
        2, f"barrier_wait[{logging_name}]: enqueueing barrier on device {d}")
    x_on_dev = api.device_put(d_idx, device=d)
    api.jit(lambda x: id_tap(barrier_tap, x), device=d)(x_on_dev)

  logging.vlog(2, f"barrier_wait[{logging_name}]: waiting for callbacks")
  with lock:
    cv.wait_for(lambda: num_at_large == 0)
  logging.vlog(2, f"barrier_wait[{logging_name}]: done")
  if _outfeed_receiver.num_tap_exceptions > 0:
    _outfeed_receiver.num_tap_exceptions = 0
    raise TapFunctionException(
        "There were exceptions during id_tap processing.")
def save(self):
  """Save the agent parameters.

  Pickles `self._policy_and_value_net_params` to `model-<epoch:06d>.pkl` in
  the output directory, then removes the checkpoint files that existed before
  this call (except the one just written). Finally resets the trajectory
  counter and records the current epoch as the last saved one.
  """
  logging.vlog(1, "Epoch [% 6d] saving model.", self._epoch)
  old_model_files = gfile.glob(
      os.path.join(self._output_dir, "model-??????.pkl"))
  params_file = os.path.join(self._output_dir, "model-%06d.pkl" % self._epoch)
  with gfile.GFile(params_file, "wb") as f:
    pickle.dump(self._policy_and_value_net_params, f)
  # Remove the old model files. If this epoch was already saved once, its path
  # is also in `old_model_files`; never delete the file we just wrote.
  for path in old_model_files:
    if path != params_file:
      gfile.remove(path)
  # Reset this number.
  self._n_trajectories_done = 0
  self._last_saved_at = self._epoch
def get_compile_options(num_replicas, num_partitions, device_assignment=None,
                        use_spmd_partitioning=True):
  """Returns the compile options to use, as derived from flag values.

  Args:
    num_replicas: int indicating the number of replicas for which to compile.
    num_partitions: int indicating the number of partitions for which to compile.
    device_assignment: Optional tuple of integers indicating the assignment of
      logical replicas to physical devices (default inherited from
      xla_client.CompileOptions). Must be consistent with `num_replicas` and
      `num_partitions`.
    use_spmd_partitioning: boolean indicating whether to enable SPMD or MPMD
      partitioning in XLA.

  Returns:
    An `xla_client.CompileOptions` with replica/partition counts, SPMD setting,
    optional device assignment and (if the flag is set) reduced optimization
    debug options.

  Raises:
    ValueError: If `device_assignment` is not 2-D (a 1-D assignment is only
      accepted when `num_partitions == 1`), or its shape does not match
      `(num_replicas, num_partitions)`.
  """
  compile_options = xla_client.CompileOptions()
  compile_options.num_replicas = num_replicas
  compile_options.num_partitions = num_partitions
  build_options = compile_options.executable_build_options
  build_options.use_spmd_partitioning = use_spmd_partitioning
  if device_assignment is not None:
    logging.vlog(
        2,
        'get_compile_options: num_replicas=%s num_partitions=%s device_assignment=%s',
        num_replicas, num_partitions, device_assignment)
    device_assignment = np.array(device_assignment)

    # Allow 1D device assignment if num_partitions is 1.
    if (device_assignment.ndim == 1) and (num_partitions == 1):
      device_assignment = device_assignment[:, None]

    # Any other rank would fail below with an opaque IndexError on `shape`.
    if device_assignment.ndim != 2:
      raise ValueError(
          'device_assignment must have shape (num_replicas, num_partitions); '
          'got shape {}.'.format(device_assignment.shape))

    if num_replicas != device_assignment.shape[0]:
      msg = 'device_assignment does not match num_replicas: {} vs {}.'
      raise ValueError(msg.format(device_assignment, num_replicas))

    if num_partitions != device_assignment.shape[1]:
      msg = 'device_assignment does not match num_partitions: {} vs {}.'
      raise ValueError(msg.format(device_assignment, num_partitions))

    device_assignment = xla_client.DeviceAssignment.create(device_assignment)
    assert device_assignment.replica_count() == num_replicas
    assert device_assignment.computation_count() == num_partitions
    compile_options.device_assignment = device_assignment

  if FLAGS.jax_disable_most_optimizations:
    debug_options = compile_options.executable_build_options.debug_options
    debug_options.xla_backend_optimization_level = 0
    debug_options.xla_llvm_disable_expensive_passes = True
    debug_options.xla_test_all_input_layouts = False

  return compile_options
def compile_replicated(jaxpr, backend, axis_name, axis_size, global_axis_size,
                       devices, consts, tuple_args, *abstract_args):
  """Builds and compiles `jaxpr` as a replicated XLA computation.

  The replica count is the (local or global) axis size times the replicas the
  jaxpr itself needs. When `devices` is given, it is validated against those
  counts and used as the device assignment.

  Args:
    jaxpr: The jaxpr to compile.
    backend: Backend name passed through to `xb`/`xla` helpers.
    axis_name: Name of the mapped axis, used to build the AxisEnv.
    axis_size: Size of the mapped axis on this host.
    global_axis_size: Size of the mapped axis across all hosts.
    devices: Optional explicit list of devices; None means let XLA choose.
    consts: Constants forwarded to `xla.jaxpr_computation`.
    tuple_args: Whether arguments are passed as a tuple.
    *abstract_args: Abstract values whose XLA shapes form the argument shapes.

  Returns:
    A (compiled executable, num_local_replicas) tuple.

  Raises:
    ValueError: If the required replicas exceed the available devices, or the
      given `devices` do not match the local/global replica counts.
  """
  jaxpr_replicas = xla.jaxpr_replicas(jaxpr)
  num_local_replicas = axis_size * jaxpr_replicas
  num_replicas = global_axis_size * jaxpr_replicas
  logging.vlog(
      1, "compile_replicated: axis_size=%d global_axis_size=%d jaxpr_replicas=%d"
      % (axis_size, global_axis_size, jaxpr_replicas))

  if devices is None:
    if num_replicas > xb.device_count(backend):
      msg = (
          "compiling computation that requires {} replicas, but only {} XLA "
          "devices are available")
      raise ValueError(msg.format(num_replicas, xb.device_count(backend)))
    device_assignment = None
  else:
    # NOTE(review): these asserts are stripped under `python -O`.
    assert any(d.host_id == xb.host_id() for d in devices)
    local_devices = [d for d in devices if d.host_id == xb.host_id()]
    assert len(local_devices) > 0
    if num_local_replicas != len(local_devices):
      local_devices_str = ", ".join(map(str, local_devices))
      raise ValueError(
          "Leading axis size of input to pmapped function must equal the "
          "number of local devices passed to pmap. Got axis_size=%d, "
          "num_local_devices=%d.\n(Local devices passed to pmap: %s)"
          % (axis_size, len(local_devices), local_devices_str))
    if num_replicas != len(devices):
      raise ValueError(
          "compiling computation that requires %s replicas, "
          "but %s devices were specified"
          % (num_replicas, len(devices)))
    device_assignment = tuple(d.id for d in devices)

  axis_env = xla.AxisEnv(num_replicas, [axis_name], [global_axis_size], devices)
  arg_shapes = list(map(aval_to_xla_shape, abstract_args))
  built_c = xla.jaxpr_computation(jaxpr, backend, axis_env, consts, (),
                                  arg_shapes, tuple_args=tuple_args,
                                  inner=False)
  # NOTE(review): device_assignment is passed as the second positional
  # argument; confirm this matches the signature of xb.get_compile_options in
  # this version (some versions take num_partitions second).
  compiled = built_c.Compile(
      compile_options=xb.get_compile_options(num_replicas, device_assignment),
      backend=xb.get_backend(backend))
  return compiled, num_local_replicas
def _initialize_outfeed_receiver(
    clients: Optional[List[XlaLocalClient]] = None,
    max_callback_queue_size_bytes: int = int(256 * 1e6)):
  """Creates and starts the outfeed_receiver.

  This function is called lazily only when we compile an id_tap. It is a
  no-op if a receiver has already been started. It also registers an atexit
  handler that runs `barrier_wait()`, so pending callbacks are drained at
  interpreter exit.

  Args:
    * clients: the list of clients (backends) on whose devices to listen on.
      Defaults to all local backends except the interpreter.
    * max_callback_queue_size_bytes: an optional integer to bound the maximum
      size of arrays in the callback queue. When this limit is reached the
      device listener pauses.

  Raises:
    NotImplementedError: If the installed jaxlib lacks `outfeed_receiver`.
  """
  try:
    outfeed_receiver_module = xla_extension.outfeed_receiver
  except AttributeError:
    raise NotImplementedError(
        "id_tap works only with jaxlib version 0.1.51 and higher")

  # Serialize initialization so only one receiver is ever started.
  with _outfeed_receiver.lock:
    if _outfeed_receiver.receiver is not None:
      return

    if clients is None:
      # By default, all devices on all backends
      clients = xla_client._get_local_backends().values()  # type: ignore[protected-class]
      # Drop the interpreter clients
      clients = tuple([c for c in clients if c.platform != "interpreter"])  # type: ignore
    devices = list(itertools.chain(*[backend.devices() for backend in clients]))
    _outfeed_receiver.clients = clients  # type: ignore[assignment]
    _outfeed_receiver.devices = devices  # type: ignore[assignment]
    logging.vlog(
        2, f"Starting outfeed_receiver for {[str(d) for d in devices]}. "
        f"max_callback_queue_size_bytes={max_callback_queue_size_bytes}")
    _outfeed_receiver.receiver = outfeed_receiver_module.start(
        _outfeed_receiver_callback, tuple(clients),
        max_callback_queue_size_bytes)

    def exit_handler():
      # Prevent logging usage during compilation, gives errors under pytest
      xla._on_exit = True
      logging.vlog(2, "Barrier wait atexit")
      barrier_wait()

    atexit.register(exit_handler)  # We wait as long as we have callbacks
def _chain_gets_correct_expectations(self, x, independent_chain_ndims):
  """Checks HMC on a log-gamma target recovers E[x] and E[exp(x)].

  Also checks the number of target log-prob calls, and that acceptance
  probabilities lie in (0.5, 1].

  Args:
    x: Initial chain state.
    independent_chain_ndims: Number of leading dimensions treated as
      independent chains; the remaining dimensions are event dimensions.
  """
  counter = collections.Counter()

  def log_gamma_log_prob(x):
    counter['target_calls'] += 1
    event_dims = ps.range(independent_chain_ndims, ps.rank(x))
    return self._log_gamma_log_prob(x, event_dims)

  samples, kernel_results = tfp.mcmc.sample_chain(
      num_results=150,
      current_state=x,
      kernel=tfp.mcmc.HamiltonianMonteCarlo(
          target_log_prob_fn=log_gamma_log_prob,
          step_size=0.05,
          num_leapfrog_steps=2),
      num_burnin_steps=150,
      seed=test_util.test_seed())

  if tf.executing_eagerly() and not JAX_MODE:  # JAX always traces loops
    # TODO(b/79991421): Figure out why this is approx twice as many as it
    # should be. I.e., `expected_calls = (150 + 150) * 2 + 1`.
    expected_calls = 1202
  else:
    expected_calls = 4
  self.assertAllEqual(dict(target_calls=expected_calls), counter)

  expected_x = (tf.math.digamma(self._shape_param) - np.log(self._rate_param))
  expected_exp_x = self._shape_param / self._rate_param

  log_accept_ratio_, samples_, expected_x_ = self.evaluate(
      [kernel_results.log_accept_ratio, samples, expected_x])

  actual_x = samples_.mean()
  actual_exp_x = np.exp(samples_).mean()
  acceptance_probs = np.exp(np.minimum(log_accept_ratio_, 0.))

  logging.vlog(
      1, 'True E[x, exp(x)]: {}\t{}'.format(expected_x_, expected_exp_x))
  logging.vlog(
      1, 'Estimated E[x, exp(x)]: {}\t{}'.format(actual_x, actual_exp_x))
  self.assertAllClose(actual_x, expected_x_, atol=.045, rtol=0.)
  self.assertAllClose(actual_exp_x, expected_exp_x, atol=.02, rtol=0.)
  # `np.bool` was deprecated in NumPy 1.20 and removed in 1.24; the builtin
  # `bool` is the equivalent dtype.
  self.assertAllEqual(np.ones_like(acceptance_probs, bool),
                      acceptance_probs > 0.5)
  self.assertAllEqual(np.ones_like(acceptance_probs, bool),
                      acceptance_probs <= 1.)
def test_create_contiguous_submeshes_for_tpu_v4(self):
  """Every transpose-trick mesh shape yields contiguous submeshes on TPU v4."""
  tpu_v4 = mesh_utils._TPU_V4
  for topology, shapes in mesh_utils._TRANSPOSE_TRICKS.items():
    logging.vlog(1, "topology: %s", topology)
    dim_x, dim_y, dim_z = topology[0], topology[1], topology[2]
    fake_devices = mock_devices(dim_x, dim_y, dim_z, tpu_v4,
                                one_device_per_chip=True)
    for mesh_shape in shapes:
      logging.vlog(1, " mesh_shape: %s", mesh_shape)
      self._assert_contiguous_submeshes(
          mesh_utils.create_device_mesh(mesh_shape,
                                        devices=fake_devices,
                                        contiguous_submeshes=True))
def save(self):
  """Save the agent parameters.

  Pickles (optimizer state, model state, total optimizer step) to
  `model-<epoch:06d>.pkl` in the output directory. It then removes the
  policy model files found before writing, keeping the new file. Finally it
  resets the trajectory counter and records the current epoch as last saved.
  """
  logging.vlog(1, "PPO epoch [% 6d]: saving model.", self._epoch)
  previous_files = ppo.get_policy_model_files(self._output_dir)
  checkpoint_path = os.path.join(self._output_dir,
                                 "model-%06d.pkl" % self._epoch)
  with gfile.GFile(checkpoint_path, "wb") as out:
    state = (self._policy_and_value_opt_state, self._model_state,
             self._total_opt_step)
    pickle.dump(state, out)

  # Delete stale checkpoints, but never the one just written.
  stale_files = [p for p in previous_files if p != checkpoint_path]
  for stale in stale_files:
    gfile.remove(stale)

  self._n_trajectories_done = 0
  self._last_saved_at = self._epoch
def evaluate(self):
  """Evaluate the agent.

  Splits a fresh RNG key and runs `ppo.evaluate_policy` on the eval
  environment, updating `self._model_state`. The resulting reward statistics
  are written as eval summaries for the current epoch.
  """
  logging.vlog(1, "PPO epoch [% 6d]: evaluating policy.", self._epoch)
  self._rng, eval_rng = jax_random.split(self._rng, num=2)
  eval_kwargs = dict(
      temperatures=self._eval_temperatures,
      max_timestep=self._max_timestep_eval,
      n_evals=self._n_evals,
      len_history_for_policy=self._len_history_for_policy,
      state=self._model_state,
      rng=eval_rng,
  )
  reward_stats, self._model_state = ppo.evaluate_policy(
      self.eval_env, self._get_predictions, **eval_kwargs)
  ppo.write_eval_reward_summaries(reward_stats, self._eval_sw,
                                  epoch=self._epoch)
def _check_shapes(array_name, expected_shape_string, array, expected_shape,
                  array_prefix=None):
  """Raises ValueError unless `array.shape[:array_prefix]` equals `expected_shape`.

  Args:
    array_name: Name of the array, used in log and error messages.
    expected_shape_string: Symbolic description of the expected shape, used in
      the error message.
    array: Object with a sliceable `.shape`.
    expected_shape: Shape the (possibly truncated) `array.shape` must equal.
    array_prefix: If set (truthy), only the first `array_prefix` dimensions are
      compared.

  Raises:
    ValueError: If the compared shape differs from `expected_shape`.
  """
  observed = array.shape[:array_prefix]
  slice_label = f'[:{array_prefix}]' if array_prefix else ''
  logging.vlog(1, f'Shape of {array_name}{slice_label} is {observed}.')
  if array_prefix:
    logging.vlog(1, f'Shape of {array_name} is {array.shape}.')
  if observed != expected_shape:
    raise ValueError(
        f'Shape of {array_name}{slice_label} is expected to be '
        f'{expected_shape_string} which is {expected_shape}, but is '
        f'{observed} instead.')
def evaluate(self):
  """Evaluate the agent.

  Does nothing unless separate evaluation is enabled. If a controller is set,
  it first computes nontrainable-parameter updates for the current epoch.
  Those updates are applied to both `self._nontrainable_params` and the
  optimizer params, and each parameter is logged. Evaluation is then
  delegated to the base class.
  """
  if self._separate_eval:
    logging.vlog(1, 'PPO epoch [% 6d]: evaluating policy.', self.epoch)
    if self._controller is not None:
      updates = self._controller(self._history)(self.epoch)
      self._nontrainable_params.update(updates)
      _, _, optimizer_params = self._policy_and_value_opt_state
      optimizer_params.update(updates)
      for param_name, param_value in self._nontrainable_params.items():
        self._log('train', f'training/{param_name}', param_value)
    super(PPO, self).evaluate()
def is_constant(cls, expr: Expr) -> bool:
  """Returns true iff 'expr' is ok to hold in an ast.Constant value.

  Constant arrays, enum refs, numbers and const refs qualify directly. A tuple
  qualifies when all its members do, and a cast when its operand does.
  """
  if isinstance(expr, (ConstantArray, EnumRef, Number, ConstRef)):
    return True
  if isinstance(expr, XlsTuple):
    return all(cls.is_constant(member) for member in expr.members)
  if isinstance(expr, Cast):
    return cls.is_constant(expr.expr)
  logging.vlog(5, 'Not constant: %r', expr)
  return False
def _conv_kernel_size(node, name_to_node): """Computes kernel size given a TF convolution node. Args: node: Tensorflow node (NodeDef proto). name_to_node: Dict keyed by node name, each entry containing the node's NodeDef. Returns: kernel_size_x: Kernel size for horizontal direction (integer). kernel_size_y: Kernel size for vertical direction (integer). Raises: ValueError: If the weight layer node is misconfigured. """ weights_layer_read_name = node.input[1] print(weights_layer_read_name) if weights_layer_read_name.endswith("/read"): weights_layer_param_name = weights_layer_read_name[:-5] elif weights_layer_read_name.endswith("/Conv2D/ReadVariableOp"): weights_layer_param_name = weights_layer_read_name + "/resource" #weights_layer_param_name = weights_layer_read_name[:-22] + "/kernel" elif weights_layer_read_name.endswith("/depthwise/ReadVariableOp"): weights_layer_param_name = weights_layer_read_name + "/resource" #weights_layer_param_name = weights_layer_read_name[:-22] + "/kernel" else: raise ValueError( "Weight layer's name input to conv layer does not end with '/read' or " "'/Conv2D/ReadVariableOp': %s" % weights_layer_read_name) weights_node = name_to_node[weights_layer_param_name] if weights_node.op == "VariableV2" or weights_node.op == "VarHandleOp": shape_dim = weights_node.attr["shape"].shape.dim elif weights_node.op == "Const": shape_dim = weights_node.attr["value"].tensor.tensor_shape.dim else: raise ValueError( "Weight layer {} is not of type VariableV2, VarHandleOp or Const: {}" .format(weights_layer_param_name, weights_node.op)) if len(shape_dim) != 4: raise ValueError( "Weight layer {} does not have rank 4. Instead, it has: {}".format( weights_layer_param_name, len(shape_dim))) logging.vlog(4, "weight shape = %s", shape_dim) kernel_size_y = shape_dim[0].size kernel_size_x = shape_dim[1].size return kernel_size_x, kernel_size_y
def prediction_input_fn(self, params):
  """Implementation of `input_fn` contract for prediction mode.

  Lists the (sharded) input files without shuffling and reads them as
  TFRecords with `self.input_read_threads` parallel readers. Records are
  parsed with `self.parse_tfexample` and batched using
  `self.input_map_threads` parallel batches. The batches are prefetched
  with AUTOTUNE.

  Args:
    params: a dict containing an integer value for key 'batch_size'.

  Returns:
    A `tf.data.Dataset` of batched, parsed examples (not a (features, labels)
    tuple). Each element is a batch of whatever `self.parse_tfexample` returns.
    Presumably that is a features dict with keys 'image', 'variant' and
    'alt_allele_indices', non-image values possibly TPU-encoded; confirm
    against parse_tfexample.
  """
  batch_size = params['batch_size']
  compression_type = tf_utils.compression_type_of_files(self.input_files)

  def read_records(filename):
    return tf.data.TFRecordDataset(
        filename,
        buffer_size=self.prefetch_dataset_buffer_size,
        compression_type=compression_type)

  file_pattern = sharded_file_utils.normalize_to_sharded_file_pattern(
      self.input_file_spec)
  files = tf.data.Dataset.list_files(file_pattern, shuffle=False)

  logging.vlog(
      3, 'self.input_read_threads={}'.format(self.input_read_threads))
  records = files.apply(
      tf.data.experimental.parallel_interleave(
          read_records,
          cycle_length=self.input_read_threads,
          sloppy=self.sloppy))

  logging.vlog(
      3, 'self.input_map_threads={}'.format(self.input_map_threads))
  batches = records.apply(
      tf.data.experimental.map_and_batch(
          self.parse_tfexample,
          batch_size=batch_size,
          num_parallel_batches=self.input_map_threads))
  return batches.prefetch(tf.data.experimental.AUTOTUNE)
def save(self):
  """Save the agent parameters.

  When checkpointing is enabled, writes optimizer state, model state, epoch,
  total optimizer step and history via `ppo.save_opt_state`. It then resets
  the trajectory counter and records the current epoch as last saved.
  Otherwise does nothing.
  """
  if self._should_save_checkpoints:
    logging.vlog(1, 'PPO epoch [% 6d]: saving model.', self._epoch)
    ppo.save_opt_state(
        self._output_dir,
        self._policy_and_value_opt_state,
        self._model_state,
        self._epoch,
        self._total_opt_step,
        self._history,
    )
    # Start counting trajectories afresh for the next checkpoint interval.
    self._n_trajectories_done = 0
    self._last_saved_at = self._epoch
async def initialize(self) -> None:
  """Loads and parses the card database, then marks the CardDb initialized.

  The JSON is read from FLAGS.carddb_local_file when that flag is set.
  Otherwise it is downloaded from CARDDB_DB_FILE in the CARDDB_BUCKET
  storage bucket.
  """
  logging.vlog(1, 'Initializing CardDb.')
  if FLAGS.carddb_local_file:
    logging.info('Initializing CardDb from local file: %s',
                 FLAGS.carddb_local_file)
    # JSON text is UTF-8 (RFC 8259); don't depend on the platform's locale
    # encoding, which open() uses when no encoding is given.
    with open(FLAGS.carddb_local_file, 'r', encoding='utf-8') as fin:
      db_json = fin.read()
  else:
    logging.info('Initializing CardDb from cloud file: %s/%s', CARDDB_BUCKET,
                 CARDDB_DB_FILE)
    storage = Storage()
    bucket = storage.get_bucket(CARDDB_BUCKET)
    blob = await bucket.get_blob(CARDDB_DB_FILE)
    db_json = await blob.download()
    logging.info('Loaded cloud file.')

  await self._parse_db_json(db_json)
  self._is_initialized.set()
def train_policy(self):
  """Trains the policy on the simulated environment.

  Re-initializes the simulated environment with an all-None history stream,
  points the policy trainer at it with trajectory dumping disabled, and
  advances the policy epoch counter by `self._n_simulated_epochs`. It then
  runs the trainer's loop up to that epoch without evaluation.
  """
  logging.info("SimPLe epoch [% 6d]: training policy.", self._simple_epoch)
  t_start = time.time()

  self._sim_env.initialize(
      batch_size=self._simulated_batch_size,
      history_stream=itertools.repeat(None),
  )
  trainer = self._policy_trainer
  trainer.train_env = self._sim_env
  # Don't dump trajectories from the simulated environment.
  trainer.trajectory_dump_dir = None
  self._policy_epoch += self._n_simulated_epochs
  trainer.training_loop(self._policy_epoch, evaluate=False)

  elapsed = time.time() - t_start
  logging.vlog(1, "Training policy took %0.2f sec.", elapsed)
def _hash_xla_flags(hash_obj):
  """Feeds the XLA flags in effect into `hash_obj` for the compilation cache key.

  Flags come from the whitespace-separated XLA_FLAGS environment variable,
  followed by any `--xla_*` command-line arguments. A flag is skipped if its
  name (the text before the first '=') is in
  `_xla_flags_to_exclude_from_cache_key`.
  """
  flags = []
  env_flags = os.getenv("XLA_FLAGS")
  if env_flags:
    flags += env_flags.split()
  flags += [arg for arg in sys.argv if arg.startswith("--xla_")]

  # N.B. all XLA flags that take an argument must use '=' and not a space
  # (e.g. --xla_force_host_platform_device_count=8) (I think).
  for flag in flags:
    flag_name = flag.split('=')[0]
    if flag_name not in _xla_flags_to_exclude_from_cache_key:
      logging.vlog(1, "Including XLA flag in cache key: %s", flag)
      _hash_string(hash_obj, flag)
    else:
      logging.vlog(1, "Not including XLA flag in cache key: %s", flag)
def backends():
  """Initializes the registered backends on first call and returns them.

  Each factory in `_backend_factories` is tried once. Backends reporting at
  least one device are stored in `_backends` by name. The successfully
  created backend with the highest priority becomes `_default_backend`.
  Initialization errors are recorded in `_backends_errors`, except for the
  'cpu' and 'interpreter' backends, whose errors are re-raised. Later calls
  return the cached dict.

  Returns:
    Dict mapping backend name to initialized backend.
  """
  global _backends
  global _backends_errors
  global _default_backend

  with _backend_lock:
    if _backends is not None:
      return _backends

    default_priority = -1000
    _backends = {}
    _backends_errors = {}
    for name, (factory, priority) in _backend_factories.items():
      logging.vlog(1, "Initializing backend '%s'" % name)
      try:
        backend = factory()
        if backend is not None:
          # Only backends that actually have devices are exposed.
          if backend.device_count() > 0:
            _backends[name] = backend
          util.distributed_debug_log(
              ("Initialized backend", backend.platform),
              ("process_index", backend.process_index()),
              ("device_count", backend.device_count()),
              ("local_devices", backend.local_devices()))
          logging.vlog(1, "Backend '%s' initialized" % name)
          if priority > default_priority:
            _default_backend = backend
            default_priority = priority
      except Exception as err:
        if name in ('cpu', 'interpreter'):
          # We always expect the CPU and interpreter backends to initialize
          # successfully.
          raise
        else:
          # If the backend isn't built into the binary, or if it has no devices,
          # we expect a RuntimeError.
          logging.info("Unable to initialize backend '%s': %s" % (name, err))
          _backends_errors[name] = str(err)
          continue
    # NOTE(review): if no factory returned a backend, `_default_backend` may
    # still be unset/None here and `.platform` would fail; confirm the CPU
    # factory always returns a backend.
    if _default_backend.platform == "cpu" and FLAGS.jax_platform_name != 'cpu':
      logging.warning(
          'No GPU/TPU found, falling back to CPU. '
          '(Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)')
    return _backends
def evaluate(self):
  """Evaluate the agent.

  Does nothing unless separate evaluation is enabled. If a controller is set,
  its nontrainable-parameter updates for this epoch are applied and logged.
  Then, for `self._n_evals` rounds, trajectories are collected at every eval
  temperature. The mean/std of the per-trajectory sums of elements 2
  ('processed') and 3 ('raw') are written as eval reward summaries.
  """
  if not self._separate_eval:
    return
  logging.vlog(1, 'PPO epoch [% 6d]: evaluating policy.', self._epoch)

  if self._controller is not None:
    updates = self._controller(self._history)(self._epoch)
    self._nontrainable_params.update(updates)
    _, _, optimizer_params = self._policy_and_value_opt_state
    optimizer_params.update(updates)
    for param_name, param_value in self._nontrainable_params.items():
      self._log('train', f'training/{param_name}', param_value)

  processed_by_temp = collections.defaultdict(list)
  raw_by_temp = collections.defaultdict(list)
  for _ in range(self._n_evals):
    for temp in self._eval_temperatures:
      trajectories, _, _, self._model_state = self.collect_trajectories(
          train=False, temperature=temp)
      processed_by_temp[temp].extend(sum(t[2]) for t in trajectories)
      raw_by_temp[temp].extend(sum(t[3]) for t in trajectories)

  def summarize(rewards_by_temp):
    """Maps each temperature to the mean and std of its reward sums."""
    summary = {}
    for temp, rewards in rewards_by_temp.items():
      summary[temp] = {'mean': onp.mean(rewards), 'std': onp.std(rewards)}
    return summary

  reward_stats = {
      'processed': summarize(processed_by_temp),
      'raw': summarize(raw_by_temp),
  }
  ppo.write_eval_reward_summaries(reward_stats, self._log, epoch=self._epoch)
async def decklist(ctx: Context, url: str, mode: str = 'compact') -> None:
  """Looks up the decklist at `url` and sends it to `ctx` as an embed.

  A URL wrapped in angle brackets is unwrapped first. Request failures from
  the handler are logged and treated as "no decklist found". The embed is
  rendered flat when `mode == 'flat'`.
  """
  # Strip surrounding angle brackets (presumably the chat `<url>` form that
  # suppresses link previews).
  if url.startswith('<') and url.endswith('>'):
    url = url[1:-1]
  logging.info('Looking up decklist for: %s', url)
  handler = decklist_handlers.lookup(url)
  try:
    found = await handler(ctx, url)
  except requests.RequestException:
    logging.exception('RequestException during decklist handler.')
    found = None

  if not found:
    logging.info('No decklist found for: %s', url)
    return

  logging.info('Found decklist named: %s', found.name)
  if logging.vlog_is_on(1):
    # Only build the pretty-printed dump when it will actually be logged.
    logging.vlog(1, 'Decklist contents: %s',
                 pprint.pformat(found.to_embed().to_dict()))
  await ctx.send(embed=found.to_embed(mode == 'flat'))
def barrier_wait():
  """Blocks the calling thread until all current outfeed is processed.

  Waits until all outfeed from computations already running on all devices
  has been received and processed by the Python callbacks. Raises
  TapFunctionException if there were exceptions while processing the callbacks.

  This works by enqueueing a special tap computation to all devices to which
  we are listening for outfeed. Once all those tap computations are done, we
  return from barrier_wait.

  Note: If any of the devices are busy and cannot accept new computations,
  this will deadlock.

  Raises:
    TapFunctionException: If any tap callback raised since the last check; the
      exception counter is reset before raising.
  """
  logging.vlog(2, "barrier_wait: start")
  if not _outfeed_receiver.receiver:
    logging.vlog(2, "barrier_wait: receiver not started")
    return

  lock = threading.Lock()
  cv = threading.Condition(lock=lock)
  num_at_large = len(_outfeed_receiver.devices)  # Protected by lock

  def barrier_tap(dev_idx):
    nonlocal num_at_large
    logging.vlog(
        2, f"barrier_wait: thread {threading.current_thread()} for "
        f"device {_outfeed_receiver.devices[dev_idx]} at barrier_tap")
    with lock:
      num_at_large -= 1
      cv.notify()

  # Enqueue one tap per device; each decrements the counter when it runs.
  for d_idx, d in enumerate(_outfeed_receiver.devices):
    logging.vlog(2, f"barrier_wait: enqueueing barrier on device {d}")
    x_on_dev = api.device_put(d_idx, device=d)
    api.jit(lambda x: id_tap(barrier_tap, x), device=d)(x_on_dev)
  logging.vlog(2, "barrier_wait: waiting for callbacks")
  with lock:
    cv.wait_for(lambda: num_at_large == 0)
  logging.vlog(2, "Done barrier_wait")
  if _outfeed_receiver.num_tap_exceptions > 0:
    _outfeed_receiver.num_tap_exceptions = 0
    raise TapFunctionException(
        "There were exceptions during id_tap processing.")
def _visit_Function(
    self, node: ast.Function,
    symbolic_bindings: Optional[SymbolicBindings]) -> ir_function.Function:
  """Converts a DSLX function AST node into a built, verified IR function.

  Parameters are visited first. Each parametric binding is then defined as a
  constant, using its value from `symbolic_bindings`. Pending module-constant
  dependencies are emitted and cleared, and finally the body is converted.

  Args:
    node: The function AST node to convert.
    symbolic_bindings: Parametric bindings for this instantiation, or None for
      a non-parametric function.

  Returns:
    The built IR function, after verification.
  """
  self.symbolic_bindings = {} if symbolic_bindings is None else dict(
      symbolic_bindings)
  self._extract_module_level_constants(self.module)
  # We use a function builder for the duration of converting this
  # ast.Function. When it's done being built, we drop the reference to it (by
  # setting self.fb to None).
  self.fb = function_builder.FunctionBuilder(
      mangle_dslx_name(node.name.identifier, node.get_free_parametric_keys(),
                       self.module, symbolic_bindings), self.package)
  try:
    for param in node.params:
      param.accept(self)

    for parametric_binding in node.parametric_bindings:
      logging.vlog(4, 'Resolving parametric binding %s', parametric_binding)

      # Raises KeyError if the binding was not supplied by the caller.
      sb_value = self.symbolic_bindings[parametric_binding.name.identifier]
      value = self._resolve_dim(sb_value)
      assert isinstance(value, int), \
          'Expect integral parametric binding; got {!r}'.format(value)
      self._def_const(
          parametric_binding, value,
          self._resolve_type(parametric_binding.type_).get_total_bit_count())
      self._def_alias(parametric_binding, to=parametric_binding.name)

    # Emit module-level constants the body depends on, then clear the list so
    # they are not emitted again for the next function.
    for dep in self._constant_deps:
      dep.accept(self)
    del self._constant_deps[:]

    node.body.accept(self)

    # If the body ends in a bare name reference, add an identity node for it
    # (presumably so the function's return value is a node of its own).
    last_expression = self.last_expression or node.body
    if isinstance(last_expression, ast.NameRef):
      self._def(last_expression, self.fb.add_identity,
                self._use(last_expression))

    f = self.fb.build()
    logging.vlog(3, 'Built function: %s', f.name)
    verifier_mod.verify_function(f)
    return f
  finally:
    self.fb = None
def make_conv_sep2d_layer(input_node,
                          in_channels,
                          channel_multiplier,
                          out_channels,
                          layer_name,
                          filter_size,
                          filter_size_2=None,
                          batch_norm=False,
                          is_training=True,
                          atrou_rate=1,
                          data_format='NHWC',
                          stddev=0.01):
  """Use separable convolutions.

  Builds a depthwise-separable 2-D convolution ('SAME' padding, unit strides,
  dilation `atrou_rate`) under variable scope `layer_name`. It is followed
  either by batch norm or by a bias add.

  Args:
    input_node: Input tensor, laid out according to `data_format`.
    in_channels: Number of input channels.
    channel_multiplier: Depthwise channel multiplier.
    out_channels: Number of output channels of the pointwise conv.
    layer_name: Variable scope name for the layer.
    filter_size: First spatial size of the depthwise filter.
    filter_size_2: Second spatial size; defaults to `filter_size`.
    batch_norm: If True apply `batch_norm_layer`, otherwise add a bias.
    is_training: Forwarded to `batch_norm_layer`.
    atrou_rate: Dilation rate used for both spatial dimensions.
    data_format: Tensor layout, e.g. 'NHWC'.
    stddev: Stddev passed to `weight_variable`.

  Returns:
    The layer's output tensor.
  """
  second_filter_size = filter_size if filter_size_2 is None else filter_size_2
  logging.vlog(1, 'layer %s in %d out %d chan mult %d', layer_name, in_channels,
               out_channels, channel_multiplier)
  with tf.variable_scope(layer_name):
    with tf.variable_scope('depthwise'):
      depthwise_filter = weight_variable(
          [filter_size, second_filter_size, in_channels, channel_multiplier],
          stddev=stddev)
    with tf.variable_scope('pointwise'):
      pointwise_filter = weight_variable(
          [1, 1, in_channels * channel_multiplier, out_channels],
          stddev=stddev)
    output = tf.nn.separable_conv2d(
        input_node,
        depthwise_filter,
        pointwise_filter,
        padding='SAME',
        strides=[1, 1, 1, 1],
        rate=[atrou_rate, atrou_rate],
        data_format=data_format)
    if not batch_norm:
      bias = bias_variable([out_channels])
      output = tf.nn.bias_add(output, bias, data_format=data_format)
    else:
      output = batch_norm_layer(output, layer_name=layer_name,
                                is_training=is_training,
                                data_format=data_format)
  return output