コード例 #1
0
  def __init__(self):
    """Set up the policy in its first (stage 0) state.

    The current train and eval datasets start out unset. An empty
    `utils.VolatileTrackable` is created, and then filled with the optimizer
    and model that `get_optimizer` / `get_model` return for stage 0.
    `get_model` receives `old_model=None` because there is no earlier-stage
    model to hand over at construction time.
    """
    self._cur_train_dataset = None
    self._cur_eval_dataset = None
    self._volatiles = utils.VolatileTrackable(optimizer=None, model=None)

    # The stage counter is a non-trainable scalar (shape=[]) int64 variable.
    # ONLY_FIRST_REPLICA means that, under a distribution strategy, updates
    # made in replica context keep only the first replica's value.
    initial_stage = 0
    self._stage_id = tf.Variable(
        initial_value=initial_stage,
        trainable=False,
        dtype=tf.int64,
        aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA,
        shape=[])

    # Ask for the optimizer before the model, matching the evaluation order
    # a single keyword-argument call would use, in case the hooks have
    # ordering-dependent side effects.
    stage_optimizer = self.get_optimizer(initial_stage)
    stage_model = self.get_model(initial_stage, old_model=None)
    self._volatiles.reassign_trackable(
        optimizer=stage_optimizer, model=stage_model)
コード例 #2
0
ファイル: policies.py プロジェクト: xiangww00/models
    def __init__(self):
        """Set up the policy in its first (stage 0) state.

        The current train and eval datasets start out unset. An empty
        `utils.VolatileTrackable` is created, and then filled with the
        optimizer and model returned by `get_optimizer` / `get_model` for
        stage 0. There is no previous-stage model yet, so `get_model` gets
        `old_model=None`.

        Each construction also increments
        `streamz_counters.progressive_policy_creation_counter` by one
        (presumably a usage/monitoring metric -- confirm with its owner).
        """
        self._cur_train_dataset = None
        self._cur_eval_dataset = None
        self._volatiles = utils.VolatileTrackable(optimizer=None, model=None)

        # Non-trainable scalar (shape=[]) int64 stage counter; with
        # ONLY_FIRST_REPLICA, replica-context updates under a distribution
        # strategy keep only the first replica's value.
        first_stage = 0
        self._stage_id = tf.Variable(
            initial_value=first_stage,
            trainable=False,
            dtype=tf.int64,
            aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA,
            shape=[])

        # Optimizer is requested before the model (same order as before).
        # The get_model call stays on the call's closing line so the pytype
        # directive below still applies to it.
        stage_optimizer = self.get_optimizer(first_stage)
        self._volatiles.reassign_trackable(
            optimizer=stage_optimizer,
            model=self.get_model(first_stage, old_model=None))  # pytype: disable=wrong-arg-types  # typed-keras

        creation_cell = (
            streamz_counters.progressive_policy_creation_counter.get_cell())
        creation_cell.increase_by(1)