def _batch_shape_tensor(self, logits_or_probs=None, total_count=None):
    """Broadcast the shape of the success parameter against `total_count`.

    Args:
      logits_or_probs: Optional override for the success parameter. When
        `None`, the stored `self._logits` is used if `self._probs` is `None`,
        otherwise the stored `self._probs` is used.
      total_count: Optional override for the count parameter. When `None`,
        `self._total_count` is used.

    Returns:
      The result of `prefer_static.broadcast_shape` applied to the shape of
      the selected success parameter and the shape of the count parameter.
    """
    if logits_or_probs is None:
        # Use whichever parameterization is stored: `_probs` when it is set,
        # otherwise `_logits`. (Both branches previously returned `_logits`.)
        logits_or_probs = self._logits if self._probs is None else self._probs
    total_count = self._total_count if total_count is None else total_count
    return prefer_static.broadcast_shape(
        prefer_static.shape(logits_or_probs),
        prefer_static.shape(total_count))
Exemplo n.º 2
0
 def _batch_shape_tensor(self, temperature=None, logits=None):
     """Broadcast temperature against the parameter's leading dimensions.

     Args:
       temperature: Optional override; `self.temperature` is used when
         `None`.
       logits: Optional override for the distribution parameter. When
         `None`, `self._logits` is used if it is not `None`, otherwise
         `self._probs`.

     Returns:
       The result of `prefer_static.broadcast_shape` applied to the shape of
       the temperature and the parameter's shape with its last dimension
       dropped.
     """
     if logits is not None:
         param = logits
     elif self._logits is not None:
         param = self._logits
     else:
         param = self._probs
     if temperature is None:
         temperature = self.temperature
     temperature_shape = prefer_static.shape(temperature)
     # Exclude the parameter's last dimension (presumably the category/event
     # axis -- confirm); only the leading dimensions take part in broadcasting.
     param_leading_shape = prefer_static.shape(param)[:-1]
     return prefer_static.broadcast_shape(temperature_shape, param_leading_shape)
Exemplo n.º 3
0
 def _batch_shape_tensor(self, concentration1=None, concentration0=None):
     """Broadcast the shapes of the two concentration parameters.

     Args:
       concentration1: Optional override; `self.concentration1` is used when
         `None`.
       concentration0: Optional override; `self.concentration0` is used when
         `None`.

     Returns:
       The result of `prefer_static.broadcast_shape` applied to the shapes of
       the two concentration parameters.
     """
     if concentration1 is None:
         concentration1 = self.concentration1
     first_shape = prefer_static.shape(concentration1)
     if concentration0 is None:
         concentration0 = self.concentration0
     second_shape = prefer_static.shape(concentration0)
     return prefer_static.broadcast_shape(first_shape, second_shape)
Exemplo n.º 4
0
 def _batch_shape_tensor(self, loc=None, scale=None):
   """Broadcast the shapes of the location and scale parameters.

   Args:
     loc: Optional override for the location parameter; `self.loc` is used
       when `None`.
     scale: Optional override for the scale parameter; `self.scale` is used
       when `None`.

   Returns:
     The result of `prefer_static.broadcast_shape` applied to the shapes of
     the location and scale parameters.
   """
   if loc is None:
     loc = self.loc
   loc_shape = prefer_static.shape(loc)
   if scale is None:
     scale = self.scale
   scale_shape = prefer_static.shape(scale)
   return prefer_static.broadcast_shape(loc_shape, scale_shape)
Exemplo n.º 5
0
 def _batch_shape_tensor(self, concentration=None, rate=None):
   """Broadcast the shapes of the concentration and rate parameters.

   Args:
     concentration: Optional override; `self.concentration` is used when
       `None`.
     rate: Optional override; `self.rate` is used when `None`.

   Returns:
     The result of `prefer_static.broadcast_shape` applied to the shapes of
     the concentration and rate parameters.
   """
   if concentration is None:
     concentration = self.concentration
   concentration_shape = prefer_static.shape(concentration)
   if rate is None:
     rate = self.rate
   rate_shape = prefer_static.shape(rate)
   return prefer_static.broadcast_shape(concentration_shape, rate_shape)