Пример #1
0
 def __init__(self, creator: TensorCreator,
              params: DatasetSimplePointGravityParams,
              random: np.random.RandomState):
     """Set up the point-gravity dataset state and its tensors.

     Args:
         creator: Tensor factory. Its ``device`` is passed to the base class,
             and it allocates the output canvas and the frame back buffer.
         params: Dataset configuration. ``point_pos``, ``attractor_distance``,
             ``canvas_shape`` and ``move_strategy`` are read here.
         random: Random state, stored as ``self._random``. The annotation was
             previously ``np.random``, which is a module rather than a type.
     """
     super().__init__(creator.device)
     self._random = random
     self.point_pos = params.point_pos
     self.attractor_distance = params.attractor_distance
     # Zero-filled output tensor of shape ``params.canvas_shape``.
     # ``self._device`` is presumably set by the base-class __init__ from
     # ``creator.device`` -- TODO confirm.
     self.output_data = creator.zeros(*params.canvas_shape, device=self._device)
     self._state = States.BLANK
     self.move_strategy = params.move_strategy
     # Zero tensor matching output_data. Presumably frames are composed here
     # before being published to output_data -- TODO confirm.
     self._frame_backbuffer = creator.zeros_like(self.output_data)
Пример #2
0
    def __init__(self,
                 creator: TensorCreator,
                 params: DatasetAlphabetParams,
                 random: Optional[RandomState] = None):
        """Allocate the alphabet dataset tensors and, if requested, a sequence generator.

        Args:
            creator: Tensor factory used for every allocation. Its ``device``
                is also handed to the base class.
            params: Dataset configuration. It is passed to
                ``_validate_params`` before any other field is read.
            random: Random state forwarded to ``SequenceGenerator`` in
                ``SEQUENCE_PROBS`` mode. If it is falsy (e.g. ``None``), a
                fresh ``np.random.RandomState()`` is used.
        """
        device = creator.device
        super().__init__(device)
        self._validate_params(params)

        rng = random or np.random.RandomState()

        # Render every symbol once, then copy the result into a tensor
        # allocated by the creator on the target device.
        rendered = AlphabetGenerator(params.padding_right).create_symbols(params.symbols)
        self.all_symbols = creator.zeros_like(rendered)
        self.all_symbols.copy_(rendered.to(device))

        # dims[0] becomes _n_symbols; the remaining dims size output_data.
        dims = list(self.all_symbols.shape)

        self.output_data = creator.zeros(dims[1:], device=device)
        self.output_label = creator.zeros(1, dtype=torch.int64, device=device)
        self.output_sequence_id = creator.zeros(1, dtype=torch.int64, device=device)
        sequence_count = len(params.sequence_probs.seqs)
        self.output_sequence_id_one_hot = creator.zeros((1, sequence_count),
                                                        dtype=self._float_dtype,
                                                        device=device)

        if params.mode == DatasetAlphabetMode.SEQUENCE_PROBS:
            seqs = [self.convert_string_to_positions(params.symbols, text)
                    for text in params.sequence_probs.seqs]
            # Any falsy transition_probs (None, empty) falls back to defaults.
            probs = params.sequence_probs.transition_probs
            if not probs:
                probs = SequenceGenerator.default_transition_probs(seqs)
            self.seq = SequenceGenerator(seqs, probs, random=rng)
            self._current = next(self.seq)

        self._n_symbols = dims[0]