示例#1
0
    def __init__(self, iter, std_dict, mean_dict=None):
        """
        Wrap an iterator and validate the Gaussian noise settings.

        Args:
            iter (DataIterator):
                The wrapped DataIterator whose data the noise should be
                added to. Its ``data_shapes`` and ``length`` are handed on
                to ``DataIterator.__init__``.
            std_dict (dict[str, float]):
                Standard deviation of the noise, keyed by data name. Every
                name has to be present in ``iter.data_shapes``.
            mean_dict (Optional(dict[str, float])):
                Mean of the gaussian noise, keyed by data name. If given,
                it has to use exactly the same names as ``std_dict``.
                Defaults to None meaning all means are treated as 0.

        Raises:
            IteratorValidationError:
                If ``mean_dict`` and ``std_dict`` name different data items
                or if a name is unknown to the wrapped iterator.
        """
        DataIterator.__init__(self, iter.data_shapes, iter.length)

        if mean_dict is None:
            mean_names = set()
        else:
            mean_names = set(mean_dict.keys())
        std_names = set(std_dict.keys())

        # An omitted mean_dict is always acceptable; a given one must match.
        names_differ = mean_dict is not None and mean_names != std_names
        if names_differ:
            raise IteratorValidationError(
                "means and standard deviations must be provided for the "
                "same data names. But {} != {}".format(mean_names, std_names))

        for name in std_names:
            if name in iter.data_shapes:
                continue
            raise IteratorValidationError(
                "key {} is not present in iterator. Available keys: {}"
                .format(name, iter.data_shapes.keys()))

        self.mean_dict = mean_dict if mean_dict is not None else {}
        self.std_dict = std_dict
        self.iter = iter
示例#2
0
    def __init__(self, iter, prob_dict, ratio_dict=None):
        """
        Wrap an iterator and validate the salt-and-pepper noise settings.

        Args:
            iter (DataIterator):
                The wrapped DataIterator whose data the noise should be
                added to. Its ``data_shapes`` and ``length`` are handed on
                to ``DataIterator.__init__``.
            prob_dict (dict[str, float]):
                Probability that an input is affected, keyed by data name.
                Every name has to be present in ``iter.data_shapes``.
                Omitted data items are treated as having an amount of 0.
            ratio_dict (Optional(dict[str, float])):
                Ratio of salt among all corrupted inputs, keyed by data
                name. If given, it has to use exactly the same names as
                ``prob_dict``.
                Defaults to None meaning the ratio is treated as 0.5.

        Raises:
            IteratorValidationError:
                If ``ratio_dict`` and ``prob_dict`` name different data
                items or if a name is unknown to the wrapped iterator.
        """
        DataIterator.__init__(self, iter.data_shapes, iter.length)

        if ratio_dict is None:
            ratio_names = set()
        else:
            ratio_names = set(ratio_dict.keys())
        prob_names = set(prob_dict.keys())

        # An omitted ratio_dict is always acceptable; a given one must match.
        names_differ = ratio_dict is not None and ratio_names != prob_names
        if names_differ:
            raise IteratorValidationError(
                "probabilities and ratios must be provided for the same "
                "data names. But {} != {}".format(prob_names, ratio_names))

        for name in prob_names:
            if name in iter.data_shapes:
                continue
            raise IteratorValidationError(
                "key {} is not present in iterator. Available keys: {}"
                .format(name, iter.data_shapes.keys()))

        self.ratio_dict = ratio_dict if ratio_dict is not None else {}
        self.prob_dict = prob_dict
        self.iter = iter
示例#3
0
 def __init__(self, iter, shape_dict):
     """
     Wrap an iterator and validate the requested crop shapes.

     Args:
         iter (DataIterator):
             A DataIterator which iterates over data to be cropped.
         shape_dict (dict[str, (int, int)]):
             Crop shape as a ``(height, width)`` tuple, keyed by data name.
             Each name has to refer to a 5D entry of ``iter.data_shapes``;
             the crop height is checked against entry 2 and the crop width
             against entry 3 of that shape.

     Raises:
         IteratorValidationError:
             If a name is unknown to the wrapped iterator, a crop shape is
             not a tuple of size 2, the data is not 5D, or the crop height
             or width is negative or larger than the data.
     """
     super(RandomCrop, self).__init__(iter.data_shapes, iter.length)

     for name, crop in shape_dict.items():
         if name not in iter.data_shapes:
             raise IteratorValidationError(
                 "key {} is not present in iterator. Available keys: {}"
                 .format(name, iter.data_shapes.keys()))
         if not isinstance(crop, tuple) or len(crop) != 2:
             raise IteratorValidationError("Shape must be a size 2 tuple")

         full_shape = iter.data_shapes[name]
         if len(full_shape) != 5:
             raise IteratorValidationError("Only 5D data is supported")

         crop_height, crop_width = crop
         if crop_height > full_shape[2] or crop_height < 0:
             raise IteratorValidationError("Invalid crop height")
         if crop_width > full_shape[3] or crop_width < 0:
             raise IteratorValidationError("Invalid crop width")

     self.shape_dict = shape_dict
     self.iter = iter
示例#4
0
 def __init__(self, iter, size_dict, value_dict=None):
     """
     Wrap an iterator and validate the padding settings.

     Args:
         iter (DataIterator):
             A DataIterator which iterates over the images to be padded.
         size_dict (dict[str, int]):
             Padding size, keyed by data name. Each name has to refer to a
             5D entry of ``iter.data_shapes``.
         value_dict (Optional(dict[str, int])):
             Pad value, keyed by data name. If given, it has to use exactly
             the same names as ``size_dict``. Defaults to None, in which
             case an empty dict is stored.

     Raises:
         IteratorValidationError:
             If ``size_dict`` and ``value_dict`` name different data items,
             a name is unknown to the wrapped iterator, or the data is not
             5D.
     """
     super(Pad, self).__init__(iter.data_shapes, iter.length)

     names_differ = (value_dict is not None and
                     set(size_dict.keys()) != set(value_dict.keys()))
     if names_differ:
         raise IteratorValidationError(
             "padding sizes and values must be provided for the same data "
             "names")

     for name in size_dict.keys():
         if name not in iter.data_shapes:
             raise IteratorValidationError(
                 "key {} is not present in iterator. Available keys: {}"
                 .format(name, iter.data_shapes.keys()))
         nr_dims = len(iter.data_shapes[name])
         if nr_dims != 5:
             raise IteratorValidationError("Only 5D data is supported")

     self.value_dict = value_dict if value_dict is not None else {}
     self.size_dict = size_dict
     self.iter = iter
示例#5
0
 def __init__(self, iter, vocab_size_dict):
     """
     Wrap an iterator and validate the one hot vocabulary sizes.

     Args:
         iter (DataIterator):
             DataIterator which iterates over the data that the one hot
             vectors are built from.
         vocab_size_dict (dict[str, int]):
             Specifies the size of one hot vectors (the vocabulary size)
             for some named data items. Each name has to refer to an entry
             of ``iter.data_shapes`` that has 3 dimensions and a last
             dimension of size 1.

     Raises:
         IteratorValidationError:
             If a name is unknown to the wrapped iterator, a vocabulary
             size is not an int, or the data shape is not 3D with a
             trailing dimension of 1.
     """
     DataIterator.__init__(self, iter.data_shapes, iter.length)

     for name, vocab_size in vocab_size_dict.items():
         if name not in iter.data_shapes:
             raise IteratorValidationError(
                 "key {} is not present in iterator. Available keys: {}"
                 .format(name, iter.data_shapes.keys()))
         if not isinstance(vocab_size, int):
             raise IteratorValidationError("Vocabulary size must be int")

         item_shape = iter.data_shapes[name]
         # The message only mentions 3D, but a trailing size of 1 is
         # required as well.
         if item_shape[-1] != 1 or len(item_shape) != 3:
             raise IteratorValidationError("Only 3D data is supported")

     self.vocab_size_dict = vocab_size_dict
     self.iter = iter
示例#6
0
 def __init__(self, iter, prob_dict=None):
     """
     Wrap an iterator and validate the flip probabilities.

     Args:
         iter (DataIterator):
             Any DataIterator which iterates over data to be flipped.
         prob_dict (Optional(dict[str, float])):
             Probability of flipping, keyed by data name. Each name has to
             refer to a 5D entry of ``iter.data_shapes`` and each
             probability has to lie within [0.0, 1.0].
             Defaults to None, which is replaced by ``{'default': 0.5}``;
             that name is validated like any other.

     Raises:
         IteratorValidationError:
             If a name is unknown to the wrapped iterator, a probability
             is outside of [0.0, 1.0], or the data is not 5D.
     """
     Seedable.__init__(self)
     super(Flip, self).__init__(iter.data_shapes, iter.length)

     if prob_dict is None:
         prob_dict = {'default': 0.5}

     for name, probability in prob_dict.items():
         if name not in iter.data_shapes:
             raise IteratorValidationError(
                 "key {} is not present in iterator. Available keys: {}"
                 .format(name, iter.data_shapes.keys()))
         if probability > 1.0 or probability < 0.0:
             raise IteratorValidationError("Invalid probability")
         nr_dims = len(iter.data_shapes[name])
         if nr_dims != 5:
             raise IteratorValidationError("Only 5D data is supported")

     self.prob_dict = prob_dict
     self.iter = iter
示例#7
0
def _assert_correct_data_format(named_data):
    """
    Validate the layout of all named data items and return their common size.

    Every item must expose a ``shape`` attribute with at least three
    dimensions, of which the first two are interpreted as time_size and
    batch_size. All items have to agree on both of them.

    Args:
        named_data (dict[str, object]):
            Maps data names to objects that provide a ``shape`` attribute.

    Returns:
        tuple:
            ``(nr_sequences, nr_timesteps)``: the shared size of the second
            dimension (converted to ``int``) and of the first dimension.

    Raises:
        IteratorValidationError:
            If no data item is given, if an item has no ``shape`` attribute
            or fewer than 3 dimensions, or if the items disagree on the
            number of sequences or on the number of time steps.
    """
    nr_sequences = {}
    nr_timesteps = {}
    for name, data in named_data.items():
        if not hasattr(data, 'shape'):
            raise IteratorValidationError(
                "{} has a wrong type. (no shape attribute)".format(name))
        if len(data.shape) < 3:
            raise IteratorValidationError(
                'All inputs have to have at least 3 dimensions, where the '
                'first two are time_size and batch_size.')
        nr_timesteps[name] = data.shape[0]
        nr_sequences[name] = data.shape[1]

    if not nr_sequences:
        # Without this guard min()/max() below would fail on the empty
        # collection with an unrelated ValueError instead of reporting a
        # validation problem.
        raise IteratorValidationError(
            'No data was provided, at least one named input is required.')

    fewest_sequences = min(nr_sequences.values())
    if fewest_sequences != max(nr_sequences.values()):
        raise IteratorValidationError(
            'The number of sequences of all inputs must be equal, but got {}'.
            format(nr_sequences))

    fewest_timesteps = min(nr_timesteps.values())
    if fewest_timesteps != max(nr_timesteps.values()):
        raise IteratorValidationError(
            'The number of time steps of all inputs must be equal, '
            'but got {}'.format(nr_timesteps))

    return int(fewest_sequences), fewest_timesteps