Example #1
File: tf_utils.py  Project: NetColby/DNRL
from typing import List, Optional

import sonnet as snt
import tensorflow as tf
import tree

# OLT and the zeros_like/ones_like/add_batch_dim/squeeze_batch_dim helpers are
# defined elsewhere in this module/project.


def create_variables(
    network: snt.Module,
    input_spec: List[OLT],
) -> Optional[tf.TensorSpec]:
    """Builds the network with dummy inputs to create the necessary variables.
    Args:
      network: Sonnet Module whose variables are to be created.
      input_spec: list of input specs to the network. The length of this list
        should match the number of arguments expected by `network`.
    Returns:
      output_spec: only returns an output spec if the output is a tf.Tensor, else
          it doesn't return anything (None); e.g. if the output is a
          tfp.distributions.Distribution.
    """
    # Create a dummy observation with no batch dimension.
    dummy_input = [
        OLT(
            observation=zeros_like(in_spec.observation),
            legal_actions=ones_like(in_spec.legal_actions),
            terminal=zeros_like(in_spec.terminal),
        ) for in_spec in input_spec
    ]

    # If we have an RNNCore the hidden state will be an additional input.
    if isinstance(network, snt.RNNCore):
        initial_state = squeeze_batch_dim(network.initial_state(1))
        dummy_input += [initial_state]

    # Forward pass of the network which will create variables as a side effect.
    dummy_output = network(*add_batch_dim(dummy_input))

    # Evaluate the input signature by converting the dummy input into a
    # TensorSpec. We then save the signature as a property of the network. This is
    # done so that we can later use it when creating snapshots. We do this here
    # because the snapshot code may not have access to the precise form of the
    # inputs.
    input_signature = tree.map_structure(
        lambda t: tf.TensorSpec((None, ) + t.shape, t.dtype), dummy_input)
    network._input_signature = input_signature  # pylint: disable=protected-access

    def spec(output: tf.Tensor) -> tf.TensorSpec:
        # If the output is not a Tensor, return None as spec is ill-defined.
        if not isinstance(output, tf.Tensor):
            return None
        # If this is not a scalar Tensor, make sure to squeeze out the batch dim.
        if tf.rank(output) > 0:
            output = squeeze_batch_dim(output)
        return tf.TensorSpec(output.shape, output.dtype)

    return tree.map_structure(spec, dummy_output)
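
A minimal usage sketch, assuming `OLT` is the observation/legal_actions/terminal named tuple used by this project; the network and shapes below are hypothetical. The single forward pass on zeroed dummy inputs is what creates the variables:

# Usage sketch (hypothetical network and shapes).
class MyQNetwork(snt.Module):
    """Toy module that consumes an OLT and returns action values."""

    def __init__(self):
        super().__init__(name='my_q_network')
        self._mlp = snt.nets.MLP([64, 5])

    def __call__(self, olt: OLT) -> tf.Tensor:
        # Only the observation component is used here.
        return self._mlp(olt.observation)

network = MyQNetwork()
spec = OLT(
    observation=tf.TensorSpec((8,), tf.float32),
    legal_actions=tf.TensorSpec((5,), tf.float32),
    terminal=tf.TensorSpec((1,), tf.float32),
)
output_spec = create_variables(network, [spec])
# output_spec == tf.TensorSpec((5,), tf.float32); the variables now exist and
# network._input_signature has been set for later snapshotting.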
Example #2
from typing import Optional

import sonnet as snt
import tensorflow as tf
import tree


def _get_input_signature(module: snt.Module) -> Optional[tf.TensorSpec]:
    """Get module input signature.

  Works even if the module with signature is wrapper into snt.Sequentual or
  snt.DeepRNN.

  Args:
    module: the module which input signature we need to get. The module has to
      either have input_signature itself (i.e. you have to run create_variables
      on the module), or it has to be a module (with input_signature) wrapped in
      (one or multiple) snt.Sequential or snt.DeepRNNs.

  Returns:
    Input signature of the module or None if it's not available.
  """
    if hasattr(module, '_input_signature'):
        return module._input_signature  # pylint: disable=protected-access

    if isinstance(module, snt.Sequential):
        first_layer = module._layers[0]  # pylint: disable=protected-access
        return _get_input_signature(first_layer)

    if isinstance(module, snt.DeepRNN):
        first_layer = module._layers[0]  # pylint: disable=protected-access
        input_signature = _get_input_signature(first_layer)

        # Wrapping a module in DeepRNN changes its state shape, so we need to bring
        # it up to date.
        state = module.initial_state(1)
        input_signature[-1] = tree.map_structure(
            lambda t: tf.TensorSpec((None, ) + t.shape[1:], t.dtype), state)

        return input_signature

    return None
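
A short sketch of how this lookup behaves once a module has been through create_variables (hypothetical wrapping; tf.nn.relu is just a stand-in layer):

# Sketch: `network` carries `_input_signature` after create_variables (e.g. the
# network from Example #1's sketch). Wrapping it in snt.Sequential does not
# copy the attribute, so the lookup recurses into the first layer to find it.
wrapped = snt.Sequential([network, tf.nn.relu])
assert _get_input_signature(wrapped) is _get_input_signature(network)

# For snt.DeepRNN the recursion additionally replaces the state entry of the
# signature, since wrapping changes the recurrent state shape.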
Example #3
    def __init__(self,
                 value_func: snt.Module,
                 instrumental_feature: snt.Module,
                 policy_net: snt.Module,
                 discount: float,
                 value_learning_rate: float,
                 instrumental_learning_rate: float,
                 value_reg: float,
                 instrumental_reg: float,
                 stage1_reg: float,
                 stage2_reg: float,
                 instrumental_iter: int,
                 value_iter: int,
                 dataset: tf.data.Dataset,
                 d_tm1_weight: float = 1.0,
                 counter: counting.Counter = None,
                 logger: loggers.Logger = None,
                 checkpoint: bool = True,
                 checkpoint_interval_minutes: float = 10.0):
        """Initializes the learner.

        Args:
          value_func: value function network.
          instrumental_feature: dual function network.
          policy_net: policy network.
          discount: global discount.
          value_learning_rate: learning rate for the treatment_net update.
          instrumental_learning_rate: learning rate for the instrumental_net update.
          value_reg: L2 regularizer for value net.
          instrumental_reg: L2 regularizer for instrumental net.
          stage1_reg: ridge regularizer for stage 1 regression.
          stage2_reg: ridge regularizer for stage 2 regression.
          instrumental_iter: number of iterations for the instrumental net.
          value_iter: number of iterations for the value function.
          dataset: dataset to learn from.
          d_tm1_weight: weight for terminal state transitions; ignored in this
            variant.
          counter: Counter object for (potentially distributed) counting.
          logger: Logger object for writing logs to.
          checkpoint: boolean indicating whether to checkpoint the learner.
          checkpoint_interval_minutes: checkpoint interval in minutes.
        """

        self._counter = counter or counting.Counter()
        self._logger = logger or loggers.TerminalLogger('learner',
                                                        time_delta=1.)

        self.stage1_reg = stage1_reg
        self.stage2_reg = stage2_reg
        self.instrumental_iter = instrumental_iter
        self.value_iter = value_iter
        self.discount = discount
        self.value_reg = value_reg
        self.instrumental_reg = instrumental_reg
        del d_tm1_weight

        # Get an iterator over the dataset.
        self._iterator = iter(dataset)  # pytype: disable=wrong-arg-types

        self.value_func = value_func
        self.value_feature = value_func._feature  # pylint: disable=protected-access
        self.instrumental_feature = instrumental_feature
        self.policy = policy_net
        self._value_func_optimizer = snt.optimizers.Adam(value_learning_rate,
                                                         beta1=0.5,
                                                         beta2=0.9)
        self._instrumental_func_optimizer = snt.optimizers.Adam(
            instrumental_learning_rate, beta1=0.5, beta2=0.9)

        # Define additional variables.
        self.stage1_weight = tf.Variable(
            tf.zeros(
                (instrumental_feature.feature_dim(), value_func.feature_dim()),
                dtype=tf.float32))
        self._num_steps = tf.Variable(0, dtype=tf.int32)

        self._variables = [
            self.value_func.trainable_variables,
            self.instrumental_feature.trainable_variables,
            self.stage1_weight,
        ]

        # Create a checkpointer object.
        self._checkpointer = None
        self._snapshotter = None

        if checkpoint:
            self._checkpointer = tf2_savers.Checkpointer(
                objects_to_save=self.state,
                time_delta_minutes=checkpoint_interval_minutes,
                checkpoint_ttl_seconds=_CHECKPOINT_TTL)
            self._snapshotter = tf2_savers.Snapshotter(
                objects_to_save={
                    'value_func': self.value_func,
                    'instrumental_feature': self.instrumental_feature,
                },
                time_delta_minutes=60.)
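
For context on `stage1_weight`: its shape, (instrumental feature dim, value feature dim), is consistent with the closed-form ridge regression that stage 1 of a two-stage instrumental-variable learner solves. A minimal sketch of such a solve under that assumption; the helper name is hypothetical and the learner's actual update lives in its step methods, not shown here:

# Hedged sketch (assumed helper, not this class's actual update): a ridge
# regression from instrumental features onto value-function features.
def ridge_regression_weight(phi: tf.Tensor, psi: tf.Tensor,
                            reg: float) -> tf.Tensor:
    """Solves argmin_W ||phi @ W - psi||^2 + reg * n * ||W||^2 in closed form.

    phi: [n, d_instrumental] instrumental features.
    psi: [n, d_value] value-function features.
    Returns W of shape [d_instrumental, d_value], like stage1_weight.
    """
    n = tf.cast(tf.shape(phi)[0], phi.dtype)
    d = tf.shape(phi)[1]
    a = tf.matmul(phi, phi, transpose_a=True) + reg * n * tf.eye(d, dtype=phi.dtype)
    b = tf.matmul(phi, psi, transpose_a=True)
    return tf.linalg.solve(a, b)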