コード例 #1
0
 def compute_mask(self, inputs, mask=None):
   """Combine the per-input masks into one mask for the concatenated output.

   Args:
     inputs: List or tuple of input tensors.
     mask: `None`, or a list or tuple with one entry per element of
       `inputs`; each entry is a mask tensor or `None`.

   Returns:
     `None` when `mask` is `None` or when every entry of `mask` is `None`.
     Otherwise the result of `K.all` over the last axis of the per-input
     masks concatenated along `self.axis`.

   Raises:
     ValueError: If `mask` or `inputs` is neither a list nor a tuple, or
       if the two have different lengths.
   """
   if mask is None:
     return None
   # Tuples are accepted as well as lists: everything below only needs
   # `len()` and iteration, so rejecting tuples was needlessly strict.
   if not isinstance(mask, (list, tuple)):
     raise ValueError('`mask` should be a list.')
   if not isinstance(inputs, (list, tuple)):
     raise ValueError('`inputs` should be a list.')
   if len(mask) != len(inputs):
     raise ValueError('The lists `inputs` and `mask` '
                      'should have the same length.')
   if all(m is None for m in mask):
     return None
   # Make a list of masks while making sure
   # the dimensionality of each mask
   # is the same as the corresponding input.
   masks = []
   for input_i, mask_i in zip(inputs, mask):
     if mask_i is None:
       # Input is unmasked. Append all 1s to masks,
       # but cast it to bool first
       masks.append(K.cast(K.ones_like(input_i), 'bool'))
     elif K.ndim(mask_i) < K.ndim(input_i):
       # Mask has fewer dimensions than the input, expand it
       masks.append(K.expand_dims(mask_i))
     else:
       masks.append(mask_i)
   concatenated = K.concatenate(masks, axis=self.axis)
   return K.all(concatenated, axis=-1, keepdims=False)
コード例 #2
0
 def compute_mask(self, inputs, mask=None):
   """Merge the per-input masks into a single mask.

   Args:
     inputs: List or tuple of input tensors.
     mask: `None`, or a list or tuple with one entry per element of
       `inputs`; each entry is a mask tensor or `None`.

   Returns:
     `None` when `mask` is `None` or when every entry of `mask` is `None`.
     Otherwise the result of `K.all` across the non-`None` masks, taken
     over a new leading axis.

   Raises:
     ValueError: If `mask` or `inputs` is neither a list nor a tuple, or
       if the two have different lengths.
   """
   if mask is None:
     return None
   # Tuples are accepted as well as lists: everything below only needs
   # `len()` and iteration, so rejecting tuples was needlessly strict.
   if not isinstance(mask, (list, tuple)):
     raise ValueError('`mask` should be a list.')
   if not isinstance(inputs, (list, tuple)):
     raise ValueError('`inputs` should be a list.')
   if len(mask) != len(inputs):
     raise ValueError('The lists `inputs` and `mask` '
                      'should have the same length.')
   if all(m is None for m in mask):
     return None
   # Give every non-None mask a new leading axis, join them along it, and
   # reduce with `K.all` over that axis; `None` entries are skipped.
   masks = [K.expand_dims(m, 0) for m in mask if m is not None]
   return K.all(K.concatenate(masks, axis=0), axis=0, keepdims=False)