def create_variables(
    network: snt.Module,
    input_spec: List[OLT],
) -> Optional[tf.TensorSpec]:
  """Instantiates `network`'s variables with a single dummy forward pass.

  Each `OLT` spec in `input_spec` is turned into an unbatched placeholder:
  zeros for `observation` and `terminal`, ones for `legal_actions`. If
  `network` is an `snt.RNNCore`, its initial state is appended as an extra
  trailing argument. That state is built for a batch of one and then has the
  batch dim squeezed out. All placeholders get a batch dimension before the
  call.

  As a side effect, the input signature is stored on
  `network._input_signature`. It is built from the placeholders, with a
  leading `None` batch dimension, so that snapshotting code can recover it
  later.

  Args:
    network: Sonnet module whose variables should be created.
    input_spec: one `OLT` spec per (non-state) positional argument of
      `network`.

  Returns:
    A structure mirroring the network output. Every `tf.Tensor` leaf becomes
    a `tf.TensorSpec`, with the batch dim squeezed out unless the tensor is a
    scalar. Every non-tensor leaf (e.g. a distribution object) becomes None.
  """

  def _placeholder(in_spec: OLT) -> OLT:
    # Zeros everywhere except `legal_actions`, which is filled with ones
    # (presumably so that every action looks legal to the network).
    return OLT(
        observation=zeros_like(in_spec.observation),
        legal_actions=ones_like(in_spec.legal_actions),
        terminal=zeros_like(in_spec.terminal),
    )

  # Unbatched placeholders, one per spec, in order.
  dummy_input = list(map(_placeholder, input_spec))

  if isinstance(network, snt.RNNCore):
    # Recurrent cores receive their hidden state as a trailing argument.
    dummy_input.append(squeeze_batch_dim(network.initial_state(1)))

  # Calling the network builds its variables as a side effect.
  dummy_output = network(*add_batch_dim(dummy_input))

  # Record the batch-agnostic input signature on the network itself, because
  # snapshotting code may not know the exact layout of the inputs later on.
  def _batched_spec(t):
    return tf.TensorSpec((None, ) + t.shape, t.dtype)

  network._input_signature = tree.map_structure(  # pylint: disable=protected-access
      _batched_spec, dummy_input)

  def _output_spec(output) -> Optional[tf.TensorSpec]:
    # Non-tensor outputs (e.g. distributions) have no well-defined spec.
    if not isinstance(output, tf.Tensor):
      return None
    # Drop the batch dim unless the output is a scalar.
    unbatched = squeeze_batch_dim(output) if tf.rank(output) > 0 else output
    return tf.TensorSpec(unbatched.shape, unbatched.dtype)

  return tree.map_structure(_output_spec, dummy_output)
def _get_input_signature(module: snt.Module) -> Optional[tf.TensorSpec]:
  """Gets a module's input signature.

  Works even if the module carrying the signature is wrapped in (one or more)
  snt.Sequential or snt.DeepRNN modules.

  Args:
    module: the module whose input signature we need. It must either carry
      `_input_signature` itself (i.e. `create_variables` was run on it), or
      be a module with such a signature wrapped in (one or multiple)
      snt.Sequential or snt.DeepRNN modules.

  Returns:
    Input signature of the module, or None if it is not available. That
    includes wrappers with no layers and wrappers whose first layer has no
    signature.
  """
  if hasattr(module, '_input_signature'):
    return module._input_signature  # pylint: disable=protected-access

  if isinstance(module, snt.Sequential):
    layers = module._layers  # pylint: disable=protected-access
    if not layers:
      return None
    return _get_input_signature(layers[0])

  if isinstance(module, snt.DeepRNN):
    layers = module._layers  # pylint: disable=protected-access
    if not layers:
      return None
    inner_signature = _get_input_signature(layers[0])
    if inner_signature is None:
      # Nothing to update; report "unavailable" as documented instead of
      # failing on item assignment below.
      return None

    # Copy before editing: the recursive call may return the very list stored
    # on the first layer, and writing into it would overwrite that layer's own
    # recorded state spec with the DeepRNN's.
    input_signature = list(inner_signature)

    # Wrapping a module in DeepRNN changes its state shape, so the trailing
    # (state) entry is replaced with the wrapper's state spec. The batch dim
    # of the initial state is swapped for `None`.
    state = module.initial_state(1)
    input_signature[-1] = tree.map_structure(
        lambda t: tf.TensorSpec((None, ) + t.shape[1:], t.dtype), state)

    return input_signature

  return None
def __init__(self,
             value_func: snt.Module,
             instrumental_feature: snt.Module,
             policy_net: snt.Module,
             discount: float,
             value_learning_rate: float,
             instrumental_learning_rate: float,
             value_reg: float,
             instrumental_reg: float,
             stage1_reg: float,
             stage2_reg: float,
             instrumental_iter: int,
             value_iter: int,
             dataset: tf.data.Dataset,
             d_tm1_weight: float = 1.0,
             counter: Optional[counting.Counter] = None,
             logger: Optional[loggers.Logger] = None,
             checkpoint: bool = True,
             checkpoint_interval_minutes: float = 10.0):
  """Initializes the learner.

  Args:
    value_func: value function network. Its `_feature` sub-module and its
      `feature_dim()` are read here.
    instrumental_feature: dual function network; its `feature_dim()` sizes
      the stage 1 weight matrix.
    policy_net: policy network, stored as `self.policy`.
    discount: global discount.
    value_learning_rate: Adam learning rate for the value function update.
    instrumental_learning_rate: Adam learning rate for the instrumental
      feature update.
    value_reg: L2 regularizer for the value network.
    instrumental_reg: L2 regularizer for the instrumental network.
    stage1_reg: ridge regularizer for the stage 1 regression.
    stage2_reg: ridge regularizer for the stage 2 regression.
    instrumental_iter: number of iterations for the instrumental network.
    value_iter: number of iterations for the value function.
    dataset: dataset to learn from; an iterator over it is created here.
    d_tm1_weight: weights for terminal state transitions. Ignored in this
      variant.
    counter: Counter object for (potentially distributed) counting. A new
      `counting.Counter()` is created when not given.
    logger: Logger object for writing logs to. Defaults to a
      `TerminalLogger` labelled 'learner'.
    checkpoint: whether to checkpoint (and snapshot) the learner.
    checkpoint_interval_minutes: checkpoint interval in minutes.
  """
  self._counter = counter or counting.Counter()
  self._logger = logger or loggers.TerminalLogger('learner', time_delta=1.)
  self.stage1_reg = stage1_reg
  self.stage2_reg = stage2_reg
  self.instrumental_iter = instrumental_iter
  self.value_iter = value_iter
  self.discount = discount
  self.value_reg = value_reg
  self.instrumental_reg = instrumental_reg
  # Accepted for interface compatibility only; this variant does not use it.
  del d_tm1_weight

  # Get an iterator over the dataset.
  self._iterator = iter(dataset)  # pytype: disable=wrong-arg-types

  self.value_func = value_func
  # NOTE(review): assumes `value_func` exposes a `_feature` sub-module.
  self.value_feature = value_func._feature  # pylint: disable=protected-access
  self.instrumental_feature = instrumental_feature
  self.policy = policy_net
  self._value_func_optimizer = snt.optimizers.Adam(value_learning_rate,
                                                   beta1=0.5,
                                                   beta2=0.9)
  self._instrumental_func_optimizer = snt.optimizers.Adam(
      instrumental_learning_rate, beta1=0.5, beta2=0.9)

  # Define additional variables.
  # Stage 1 weight, shaped (instrumental feature dim, value feature dim) and
  # zero-initialized.
  self.stage1_weight = tf.Variable(
      tf.zeros(
          (instrumental_feature.feature_dim(), value_func.feature_dim()),
          dtype=tf.float32))
  self._num_steps = tf.Variable(0, dtype=tf.int32)

  self._variables = [
      self.value_func.trainable_variables,
      self.instrumental_feature.trainable_variables,
      self.stage1_weight,
  ]

  # Create a checkpointer object. Both savers stay None when checkpointing is
  # disabled. These must come last: `self.state` (defined elsewhere on the
  # class) presumably reads the attributes assigned above.
  self._checkpointer = None
  self._snapshotter = None

  if checkpoint:
    self._checkpointer = tf2_savers.Checkpointer(
        objects_to_save=self.state,
        time_delta_minutes=checkpoint_interval_minutes,
        checkpoint_ttl_seconds=_CHECKPOINT_TTL)
    # Snapshots of the two networks are taken on a fixed hourly cadence,
    # independent of `checkpoint_interval_minutes`.
    self._snapshotter = tf2_savers.Snapshotter(
        objects_to_save={
            'value_func': self.value_func,
            'instrumental_feature': self.instrumental_feature,
        },
        time_delta_minutes=60.)