Esempio n. 1
0
    def validate(self):
        """Verify the shapes of the optional reward and context inputs.

        Reward input (skipped when its tensor is ``None``) must have one of the shapes:
          * ``(1,)`` -- a single float, translated into a (positive, negative) reward tuple
            that is used for all experts,
          * ``(2,)`` -- presumably a single (positive, negative) pair shared by all experts,
          * ``(flock_size, 2)`` or ``flock_shape + (2,)`` -- a separate tuple per expert.

        Context input (skipped when its tensor is ``None``) must match the pattern
        ``(flock_size | flock_shape, n_providers, NUMBER_OF_CONTEXT_TYPES, incoming_context_size)``
        as checked by ``TensorShapePatternMatcher``.

        Raises:
            NodeValidationException: if either input has an unexpected shape.
        """
        data_shape = self.inputs.sp.data_input.tensor.shape
        flock_size = self.params.flock_size
        flock_shape = derive_flock_shape(data_shape, flock_size)

        rewards = self.inputs.tp.reward_input.tensor
        if rewards is not None:
            allowed = [(1,), (flock_size, 2), flock_shape + (2,), (2,)]
            actual = rewards.size()
            if actual not in allowed:
                raise NodeValidationException(f"Reward input has unexpected shape {actual}, "
                                              f"expected one of {allowed}.")

        context = self.inputs.tp.context_input.tensor
        if context is None:
            return

        temporal = self.params.temporal
        pattern = (
            TensorShapePatternMatcher.Sum(flock_size),
            TensorShapePatternMatcher.Sum(temporal.n_providers, greedy=True),
            TensorShapePatternMatcher.Exact((NUMBER_OF_CONTEXT_TYPES,)),
            TensorShapePatternMatcher.Sum(temporal.incoming_context_size, greedy=True),
        )
        if TensorShapePatternMatcher(pattern).matches(context.shape):
            return

        pattern_str = ", ".join(str(part) for part in pattern)
        raise NodeValidationException(
            f"Context input has unexpected shape {list(context.shape)}, "
            f"expected pattern: [{pattern_str}]")
Esempio n. 2
0
 def validate(self):
     """Ensure every symbol used in the configured sequences is a known symbol.

     The check only runs when ``self._params.mode`` is
     ``DatasetAlphabetMode.SEQUENCE_PROBS``; in any other mode nothing is validated.

     Raises:
         NodeValidationException: if an element of a sequence in
             ``self._params.sequence_probs.seqs`` does not occur in ``self._params.symbols``.
     """
     if self._params.mode != DatasetAlphabetMode.SEQUENCE_PROBS:
         return
     symbols = self._params.symbols
     for seq in self._params.sequence_probs.seqs:
         for s in seq:
             # Substring membership, same semantics as the former `symbols.find(s) == -1`.
             if s not in symbols:
                 raise NodeValidationException(
                     f'Symbol "{s}" in sequence "{seq}" was not found in all symbols "{symbols}". '
                     f'Remove it from the sequence or add it to symbols.')
Esempio n. 3
0
 def validate(self):
     """Check that the outputs exactly partition the input along ``self._dim``.

     The input's size along ``self._dim`` must equal the sum of the outputs'
     sizes along that same dimension.

     Raises:
         NodeValidationException: if the sizes do not add up.
     """
     dim = self._dim
     part_sizes = [out.tensor.shape[dim] for out in self.outputs]
     total = sum(part_sizes)
     input_size = self.inputs.input.tensor.shape[dim]
     if input_size == total:
         return
     raise NodeValidationException(
         f"Input tensor selected dim {input_size} "
         f"is not corresponding with output dims "
         f"{' + '.join(str(size) for size in part_sizes)} = "
         f"{total}.")
Esempio n. 4
0
 def validate(self):
     """Raise ``NodeValidationException`` whenever ``self.fails_validation`` is truthy."""
     if not self.fails_validation:
         return
     raise NodeValidationException('Node failed to validate')
Esempio n. 5
0
 def validate(self):
     """Require ``self._params.location_filter_ratio`` to lie in the interval (0, 1].

     Any value for which ``0 < ratio <= 1.0`` is not true (including NaN for
     floats) is rejected.

     Raises:
         NodeValidationException: if the ratio is outside (0, 1].
     """
     ratio = self._params.location_filter_ratio
     if 0 < ratio <= 1.0:
         return
     raise NodeValidationException(
         'validation error: location_filter_ratio expected to be from (0,1>'
     )
Esempio n. 6
0
 def validate(self):
     """Ensure the input shape equals the shape this pass-through node declared for its output.

     Both shapes are converted to tuples before comparison, so list/tuple/size
     objects with the same entries compare equal.

     Raises:
         NodeValidationException: if the shapes differ.
     """
     input_shape = self.inputs.input.tensor.shape
     if tuple(input_shape) == tuple(self._output_shape):
         return
     raise NodeValidationException(
         f"In PassNode, input {input_shape} must equal declared "
         f"output {self._output_shape}.")
Esempio n. 7
0
 def validate(self):
     """Check that the input's last dimension has size 3 (the RGB channels).

     Validation is skipped when ``self.inputs.input`` is ``None``.

     Raises:
         NodeValidationException: if the input shape is empty or its last
             dimension is not 3.
     """
     if self.inputs.input is None:
         return
     # NOTE(review): only the slot itself is checked for None; if the slot can exist
     # with an unset tensor, `.tensor.shape` below raises AttributeError -- confirm.
     shape = tuple(self.inputs.input.tensor.shape)
     # A 0-d (empty) shape has no last dimension; reject it explicitly instead of
     # letting `shape[-1]` escape as an IndexError.
     if not shape or shape[-1] != 3:
         raise NodeValidationException(
             f"In GrayscaleNode, the RGB input shape {self.inputs.input.tensor.shape} "
             f"must end with 3 (the RGB channels).")