Beispiel #1
0
    def test_validate_input_duplicate_removal(self):
        """TableWaveform._validate_input drops the redundant entries of the input table."""
        raw_entries = [
            TableWaveformEntry(0.0, 0.2, HoldInterpolationStrategy()),
            TableWaveformEntry(0.1, 0.2, LinearInterpolationStrategy()),
            TableWaveformEntry(0.1, 0.3, JumpInterpolationStrategy()),
            TableWaveformEntry(0.1, 0.3, HoldInterpolationStrategy()),
            TableWaveformEntry(0.2, 0.3, LinearInterpolationStrategy()),
            TableWaveformEntry(0.3, 0.3, JumpInterpolationStrategy()),
        ]
        validated = TableWaveform._validate_input(raw_entries)

        # Only these four entries are expected to survive validation.
        expected = (
            TableWaveformEntry(0.0, 0.2, HoldInterpolationStrategy()),
            TableWaveformEntry(0.1, 0.2, LinearInterpolationStrategy()),
            TableWaveformEntry(0.1, 0.3, HoldInterpolationStrategy()),
            TableWaveformEntry(0.3, 0.3, JumpInterpolationStrategy()),
        )
        self.assertEqual(validated, expected)
    def test_build_waveform_time_type(self):
        """build_waveform accepts TimeType parameter values and stores them in the resulting table."""
        from qupulse.utils.types import TimeType

        template = TablePulseTemplate(
            {0: [(0, 0), ('foo', 'v', 'linear'), ('bar', 0, 'jump')]},
            parameter_constraints=['foo>1'],
            measurements=[('M', 'b', 'l'), ('N', 1, 2)])

        params = dict(v=2.3,
                      foo=TimeType.from_float(1.),
                      bar=TimeType.from_float(4),
                      b=TimeType.from_float(2),
                      l=TimeType.from_float(1))
        mapping = {0: 'ch'}

        # foo == 1 violates the 'foo>1' constraint
        with self.assertRaises(ParameterConstraintViolation):
            template.build_waveform(parameters=params, channel_mapping=mapping)

        params['foo'] = TimeType.from_float(1.1)
        waveform = template.build_waveform(parameters=params, channel_mapping=mapping)

        self.assertIsInstance(waveform, TableWaveform)
        expected_table = ((0, 0, HoldInterpolationStrategy()),
                          (TimeType.from_float(1.1), 2.3, LinearInterpolationStrategy()),
                          (4, 0, JumpInterpolationStrategy()))
        self.assertEqual(waveform._table, expected_table)
        self.assertEqual(waveform._channel_id, 'ch')
    def test_build_waveform_single_channel(self):
        """A single mapped channel yields a TableWaveform once the 'foo>1' constraint is satisfied."""
        template = TablePulseTemplate(
            {0: [(0, 0), ('foo', 'v', 'linear'), ('bar', 0, 'jump')]},
            parameter_constraints=['foo>1'],
            measurements=[('M', 'b', 'l'), ('N', 1, 2)])

        params = dict(v=2.3, foo=1, bar=4, b=2, l=1)
        mapping = {0: 'ch'}

        # foo == 1 violates the 'foo>1' constraint
        with self.assertRaises(ParameterConstraintViolation):
            template.build_waveform(parameters=params, channel_mapping=mapping)

        params['foo'] = 1.1
        waveform = template.build_waveform(parameters=params, channel_mapping=mapping)

        self.assertIsInstance(waveform, TableWaveform)
        expected_table = ((0, 0, HoldInterpolationStrategy()),
                          (1.1, 2.3, LinearInterpolationStrategy()),
                          (4, 0, JumpInterpolationStrategy()))
        self.assertEqual(waveform._table, expected_table)
        self.assertEqual(waveform._channel_id, 'ch')
    def test_build_waveform_multi_channel(self):
        """Two mapped channels yield a MultiChannelWaveform; channel 0 is padded up to the duration bar+foo."""
        per_channel_entries = {
            0: [(0, 0), ('foo', 'v', 'linear'), ('bar', 0, 'jump')],
            3: [(0, 1), ('bar+foo', 0, 'linear')],
        }
        template = TablePulseTemplate(per_channel_entries,
                                      parameter_constraints=['foo>1'],
                                      measurements=[('M', 'b', 'l'), ('N', 1, 2)])

        params = dict(v=2.3, foo=1, bar=4, b=2, l=1)
        mapping = {0: 'ch', 3: 'oh'}

        # foo == 1 violates the 'foo>1' constraint
        with self.assertRaises(ParameterConstraintViolation):
            template.build_waveform(parameters=params, channel_mapping=mapping)

        params['foo'] = 1.1
        waveform = template.build_waveform(parameters=params, channel_mapping=mapping)

        self.assertIsInstance(waveform, MultiChannelWaveform)

        ch_table = ((0, 0, HoldInterpolationStrategy()),
                    (1.1, 2.3, LinearInterpolationStrategy()),
                    (4, 0, JumpInterpolationStrategy()),
                    (5.1, 0, HoldInterpolationStrategy()))
        oh_table = ((0, 1, HoldInterpolationStrategy()),
                    (5.1, 0, LinearInterpolationStrategy()))
        expected_waveforms = (TableWaveform.from_table('ch', ch_table),
                              TableWaveform.from_table('oh', oh_table))

        self.assertEqual(waveform._sub_waveforms, expected_waveforms)
Beispiel #5
0
    def test_from_table(self):
        """A two-entry table with a constant value compares equal to a ConstantWaveform for every strategy."""
        expected = ConstantWaveform(0.1, 0.2, 'A')

        all_strategies = (HoldInterpolationStrategy(), JumpInterpolationStrategy(), LinearInterpolationStrategy())
        for strategy in all_strategies:
            constant_entries = [TableWaveformEntry(0.0, 0.2, strategy),
                                TableWaveformEntry(0.1, 0.2, strategy)]
            self.assertEqual(expected, TableWaveform.from_table('A', constant_entries))
Beispiel #6
0
 def test_jump_interpolation(self):
     """JumpInterpolationStrategy yields the end value at every sample and raises ValueError when start/end are swapped."""
     first, last = (-1, -1), (3, 3)
     sample_times = np.linspace(-1, 3, 100)
     jump = JumpInterpolationStrategy()

     values = jump(first, last, sample_times)
     self.assertTrue(all(values == 3))

     with self.assertRaises(ValueError):
         jump(last, first, sample_times)
Beispiel #7
0
    def test_known_interpolation_strategies(self):
        """Each interpolation name given to TableEntry resolves to the matching strategy object."""
        name_to_strategy = (
            ("linear", LinearInterpolationStrategy()),
            ("hold", HoldInterpolationStrategy()),
            ("jump", JumpInterpolationStrategy()),
        )

        for name, expected_strategy in name_to_strategy:
            table_entry = TableEntry('a', Expression('b'), name)

            self.assertEqual(table_entry.t, Expression('a'))
            self.assertEqual(table_entry.v, Expression('b'))
            self.assertEqual(table_entry.interp, expected_strategy)
Beispiel #8
0
    def test_repr_str(self):
        """Strategies are hashable, render the expected str/repr and compare unequal to each other."""
        # Using the strategy instances as dict keys exercises __hash__.
        expected_text = {
            LinearInterpolationStrategy(): ("linear", "<Linear Interpolation>"),
            HoldInterpolationStrategy(): ("hold", "<Hold Interpolation>"),
            JumpInterpolationStrategy(): ("jump", "<Jump Interpolation>"),
        }

        for strategy, (str_text, repr_text) in expected_text.items():
            self.assertEqual(repr(strategy), repr_text)
            self.assertEqual(str(strategy), str_text)

        self.assertNotEqual(LinearInterpolationStrategy(), HoldInterpolationStrategy())
        self.assertNotEqual(LinearInterpolationStrategy(), JumpInterpolationStrategy())
        self.assertNotEqual(JumpInterpolationStrategy(), HoldInterpolationStrategy())
 def test_get_entries_auto_insert(self) -> None:
     """A leading (0, v) hold entry and a trailing hold entry up to the longest channel are inserted."""
     template = TablePulseTemplate({0: [('foo', 'v', 'linear'), ('bar', 0, 'jump')],
                                    1: [(0, 3, 'linear'), ('bar+foo', 2, 'linear')]})
     actual = template.get_entries_instantiated({'v': 2.3, 'foo': 1, 'bar': 4})

     # channel 0 starts at t=1 and ends at t=4, channel 1 spans the full duration 5
     expected = {
         0: [(0, 2.3, HoldInterpolationStrategy()),
             (1, 2.3, LinearInterpolationStrategy()),
             (4, 0, JumpInterpolationStrategy()),
             (5, 0, HoldInterpolationStrategy())],
         1: [(0, 3, LinearInterpolationStrategy()),
             (5, 2, LinearInterpolationStrategy())],
     }
     self.assertEqual(expected, actual)
    def test_build_waveform_multi_channel(self):
        """Two mapped channels yield a MultiChannelWaveform containing one TableWaveform per output channel."""
        template = TablePulseTemplate(
            {0: [(0, 0), ('foo', 'v', 'linear'), ('bar', 0, 'jump')],
             3: [(0, 1), ('bar+foo', 0, 'linear')]},
            parameter_constraints=['foo>1'],
            measurements=[('M', 'b', 'l'), ('N', 1, 2)])

        params = dict(v=2.3, foo=1, bar=4, b=2, l=1)
        mapping = {0: 'ch', 3: 'oh'}

        # foo == 1 violates the 'foo>1' constraint
        with self.assertRaises(ParameterConstraintViolation):
            template.build_waveform(parameters=params, channel_mapping=mapping)

        params['foo'] = 1.1
        waveform = template.build_waveform(parameters=params, channel_mapping=mapping)

        self.assertIsInstance(waveform, MultiChannelWaveform)
        self.assertEqual(len(waveform._sub_waveforms), 2)

        # every output channel has to show up exactly once
        unseen = {'oh', 'ch'}
        for sub_waveform in waveform._sub_waveforms:
            self.assertIsInstance(sub_waveform, TableWaveform)
            self.assertIn(sub_waveform._channel_id, unseen)
            unseen.remove(sub_waveform._channel_id)

            if sub_waveform.defined_channels == {'ch'}:
                expected_table = ((0, 0, HoldInterpolationStrategy()),
                                  (1.1, 2.3, LinearInterpolationStrategy()),
                                  (4, 0, JumpInterpolationStrategy()),
                                  (5.1, 0, HoldInterpolationStrategy()))
            elif sub_waveform.defined_channels == {'oh'}:
                expected_table = ((0, 1, HoldInterpolationStrategy()),
                                  (5.1, 0, LinearInterpolationStrategy()))
            else:
                continue
            self.assertEqual(sub_waveform._table, expected_table)
Beispiel #11
0
class TablePulseTemplate(AtomicPulseTemplate, ParameterConstrainer):
    """The TablePulseTemplate class implements pulses described by a table with time, voltage and interpolation strategy
    inputs. The interpolation strategy describes how the voltage between the entries is interpolated (see also
    InterpolationStrategy). It can define multiple channels of which each has a separate table. If they do not have the
    same length the shorter channels are extended to the longest duration.

    If the time entries of all channels are equal it is more convenient to use the :class:`PointPulseTemplate`."""
    # Interpolation names accepted in table entries mapped to their strategy objects. 'default' is the key that
    # from_entry_list inserts for entries without an explicit interpolation strategy.
    interpolation_strategies = {
        'linear': LinearInterpolationStrategy(),
        'hold': HoldInterpolationStrategy(),
        'jump': JumpInterpolationStrategy(),
        'default': HoldInterpolationStrategy()
    }

    def __init__(self,
                 entries: Dict[ChannelID, Sequence[EntryInInit]],
                 identifier: Optional[str] = None,
                 *,
                 parameter_constraints: Optional[List[Union[
                     str, ParameterConstraint]]] = None,
                 measurements: Optional[List[MeasurementDeclaration]] = None,
                 consistency_check: bool = True,
                 registry: PulseRegistryType = None) -> None:
        """
        Construct a `TablePulseTemplate` from a dict which maps channels to their entries. By default the consistency
        of the provided entries is checked. There are two static functions for convenience construction: from_array and
        from_entry_list.

        Args:
            entries: A dictionary that maps channel ids to a list of entries. An entry is a
                (time, voltage[, interpolation strategy]) tuple or a TableEntry
            identifier: Used for serialization
            parameter_constraints: Constraint list that is forwarded to the ParameterConstrainer superclass
            measurements: Measurement declaration list that is forwarded to the AtomicPulseTemplate constructor
            consistency_check: If True the consistency of the times will be checked on construction as far as possible
            registry: Passed to ``self._register`` after all checks succeeded

        Raises:
            ValueError: If ``entries`` or one of its channels is empty, if an entry time is provably negative or
                provably smaller than a previous time of the same channel, or if the consistency check detects an
                impossible parametrization.
        """
        AtomicPulseTemplate.__init__(self,
                                     identifier=identifier,
                                     measurements=measurements)
        ParameterConstrainer.__init__(
            self, parameter_constraints=parameter_constraints)

        if not entries:
            raise ValueError(
                "Cannot construct an empty TablePulseTemplate (no entries given). There is currently no "
                "specific reason for this. Please submit an issue if you need this 'feature'."
            )

        self._entries = {ch: [] for ch in entries}
        for channel, channel_entries in entries.items():
            if len(channel_entries) == 0:
                raise ValueError('Channel {} is empty'.format(channel))

            for entry in channel_entries:
                self._add_entry(channel, TableEntry(*entry))

        self._duration = self.calculate_duration()
        self._table_parameters = set(
            var for channel_entries in self.entries.values()
            for entry in channel_entries
            for var in itertools.chain(entry.t.variables, entry.v.variables
                                       )) | self.constrained_parameters

        if self.duration == 0:
            warnings.warn(
                'Table pulse template with duration 0 on construction.',
                category=ZeroDurationTablePulseTemplate)

        if consistency_check:
            # perform a simple consistency check. All inequalities with more than one free variable are ignored as the
            # sympy solver does not support them

            # collect all conditions: the parameter constraints plus "times are non-decreasing" within each channel
            inequalities = [eq.sympified_expression for eq in self._parameter_constraints] +\
                           [sympy.Le(previous_entry.t.underlying_expression, entry.t.underlying_expression)
                            for channel_entries in self._entries.values()
                            for previous_entry, entry in zip(channel_entries, channel_entries[1:])]

            # test if any condition is already dissatisfied
            if any(
                    isinstance(eq, BooleanAtom) and bool(eq) is False
                    for eq in inequalities):
                raise ValueError(
                    'Table pulse template has impossible parametrization')

            # filter conditions that are inequalities with one free variable and test if the solution set is empty
            inequalities = [
                eq for eq in inequalities
                if isinstance(eq, sympy.Rel) and len(eq.free_symbols) == 1
            ]
            if not sympy.reduce_inequalities(inequalities):
                raise ValueError(
                    'Table pulse template has impossible parametrization')

        self._register(registry=registry)

    def _add_entry(self, channel, new_entry: TableEntry) -> None:
        """Append ``new_entry`` to the entry list of ``channel`` after checking its time.

        Only violations that can already be decided are rejected; comparisons that depend on free parameters
        are accepted here.

        Raises:
            ValueError: If the entry time is provably negative or provably smaller than a previous entry time.
        """
        ch_entries = self._entries[channel]

        # comparisons with Expression can yield None -> use 'is True' and 'is False'
        if (new_entry.t < 0) is True:
            raise ValueError(
                'Time parameter number {} of channel {} is negative.'.format(
                    len(ch_entries), channel))

        for previous_entry in ch_entries:
            if (new_entry.t < previous_entry.t) is True:
                raise ValueError(
                    'Time parameter number {} of channel {} is smaller than a previous one'
                    .format(len(ch_entries), channel))

        self._entries[channel].append(new_entry)

    @property
    def entries(self) -> Dict[ChannelID, List[TableEntry]]:
        """The (uninstantiated) table entries of each channel."""
        return self._entries

    def get_entries_instantiated(self, parameters: Dict[str, numbers.Real]) \
            -> Dict[ChannelID, List[TableWaveformEntry]]:
        """Compute an instantiated list of the table's entries for every channel.

        A (0, v, hold) entry is prepended to every channel whose first entry lies after t=0, and a
        (duration, v, hold) entry is appended to every channel that ends before the longest channel, so that all
        channels start at 0 and share the same duration.

        Args:
            parameters: A mapping of parameter names to concrete numeric values.

        Returns:
            Dict mapping each channel to its list of instantiated TableWaveformEntry objects.

        Raises:
            ParameterNotProvidedException: If a table parameter is missing from ``parameters``.
        """
        if not (self.table_parameters <= set(parameters.keys())):
            raise ParameterNotProvidedException(
                (self.table_parameters - set(parameters.keys())).pop())

        instantiated_entries = dict()  # type: Dict[ChannelID, List[TableWaveformEntry]]

        for channel, channel_entries in self._entries.items():
            instantiated = [
                entry.instantiate(parameters) for entry in channel_entries
            ]

            # Add (0, v) entry if wf starts at finite time
            if instantiated[0].t > 0:
                instantiated.insert(
                    0,
                    TableWaveformEntry(
                        0, instantiated[0].v,
                        TablePulseTemplate.interpolation_strategies['hold']))
            instantiated_entries[channel] = instantiated

        duration = max(instantiated[-1].t
                       for instantiated in instantiated_entries.values())

        # ensure that all channels have equal duration
        for channel, instantiated in instantiated_entries.items():
            final_entry = instantiated[-1]
            if final_entry.t < duration:
                instantiated.append(
                    TableWaveformEntry(
                        duration, final_entry.v,
                        TablePulseTemplate.interpolation_strategies['hold']))
            instantiated_entries[channel] = instantiated
        return instantiated_entries

    @property
    def table_parameters(self) -> Set[str]:
        """Parameters used in entry times/voltages plus the constrained parameters."""
        return self._table_parameters

    @property
    def parameter_names(self) -> Set[str]:
        return self.table_parameters | self.measurement_parameters | self.constrained_parameters

    @property
    def duration(self) -> ExpressionScalar:
        return self._duration

    def calculate_duration(self) -> ExpressionScalar:
        """Return the maximum of the last entry times of all channels as an expression."""
        duration_expressions = [
            entries[-1].t for entries in self._entries.values()
        ]
        duration_expression = sympy.Max(*(expr.sympified_expression
                                          for expr in duration_expressions))
        return ExpressionScalar(duration_expression)

    @property
    def defined_channels(self) -> Set[ChannelID]:
        return set(self._entries.keys())

    def get_serialization_data(self,
                               serializer: Optional[Serializer] = None
                               ) -> Dict[str, Any]:
        data = super().get_serialization_data(serializer)

        if serializer:  # compatibility to old serialization routines, deprecated
            data = dict()

        local_data = dict(
            entries=dict(
                (channel,
                 [entry.get_serialization_data() for entry in channel_entries])
                for channel, channel_entries in self.entries.items()),
            parameter_constraints=[str(c) for c in self.parameter_constraints],
            measurements=self.measurement_declarations)
        data.update(**local_data)
        return data

    def build_waveform(
        self, parameters: Dict[str, numbers.Real],
        channel_mapping: Dict[ChannelID, Optional[ChannelID]]
    ) -> Optional[Union[TableWaveform, MultiChannelWaveform]]:
        """Instantiate the table with ``parameters`` and build the waveform of the mapped channels.

        Args:
            parameters: Concrete values for the parameters of this template.
            channel_mapping: Maps every defined channel to an output channel id, or to None to drop it.

        Returns:
            None if every channel is mapped to None or the duration evaluates to 0, a single TableWaveform if
            exactly one channel remains, otherwise a MultiChannelWaveform of the per-channel TableWaveforms.

        Raises:
            Whatever ``validate_parameter_constraints`` raises for violated parameter constraints.
        """
        self.validate_parameter_constraints(parameters, volatile=set())

        if all(channel_mapping[channel] is None
               for channel in self.defined_channels):
            return None

        instantiated = [(channel_mapping[channel], instantiated_channel)
                        for channel, instantiated_channel in
                        self.get_entries_instantiated(parameters).items()
                        if channel_mapping[channel] is not None]

        if self.duration.evaluate_numeric(**parameters) == 0:
            return None

        waveforms = [
            TableWaveform(*ch_instantiated) for ch_instantiated in instantiated
        ]

        if len(waveforms) == 1:
            return waveforms.pop()
        else:
            return MultiChannelWaveform(waveforms)

    @staticmethod
    def from_array(times: np.ndarray, voltages: np.ndarray,
                   channels: List[ChannelID]) -> 'TablePulseTemplate':
        """Static constructor to build a TablePulse from numpy arrays.

        Args:
            times: 1D numpy array with time values shared by all channels, or 2D array with one row per channel
            voltages: 1D numpy array with voltage values shared by all channels, or 2D array with one row per channel
            channels: channels to define

        Returns:
            TablePulseTemplate with the given values, hold interpolation everywhere and no free
            parameters.

        Raises:
            ValueError: If an input is zero- or more than two-dimensional, if a 2D input does not have one row per
                channel, or if times and voltages have a different number of entries.
        """
        if times.ndim == 0 or voltages.ndim == 0:
            raise ValueError('Zero dimensional input is not accepted.')

        if times.ndim > 2 or voltages.ndim > 2:
            raise ValueError(
                'Three or higher dimensional input is not accepted.')

        if times.ndim == 2 and times.shape[0] != len(channels):
            raise ValueError(
                'First dimension of times must be equal to the number of channels'
            )

        if voltages.ndim == 2 and voltages.shape[0] != len(channels):
            raise ValueError(
                'First dimension of voltages must be equal to the number of channels'
            )

        # zip() below would silently truncate mismatched inputs, so this has to be an error
        if voltages.shape[-1] != times.shape[-1]:
            raise ValueError('Different number of entries for times and voltages')

        return TablePulseTemplate(
            dict((channel,
                  list(
                      zip(times if times.ndim == 1 else times[i, :],
                          voltages if voltages.ndim == 1 else voltages[i, :])))
                 for i, channel in enumerate(channels)))

    @staticmethod
    def from_entry_list(entry_list: List[Tuple],
                        channel_names: Optional[List[ChannelID]] = None,
                        **kwargs) -> 'TablePulseTemplate':
        """Static constructor for a TablePulseTemplate where all channel's entries share the same times.

        :param entry_list: List of tuples of the form (t, v_1, ..., v_N[, interp])
        :param channel_names: Optional list of channel identifiers to use. Default is [0, ..., N-1]
        :param kwargs: Forwarded to TablePulseTemplate constructor
        :return: TablePulseTemplate with one channel per voltage column; entries without an explicit interpolation
            strategy use the 'default' strategy
        :raises ValueError: If entry_list is empty, the entries have contradicting lengths or the number of channel
            names does not match the number of voltage columns
        """

        def is_valid_interpolation_strategy(inter):
            return inter in TablePulseTemplate.interpolation_strategies or isinstance(
                inter, InterpolationStrategy)

        if not entry_list:
            raise ValueError('Cannot construct a TablePulseTemplate from an empty entry list.')

        # determine number of channels
        max_len = max(len(data) for data in entry_list)
        min_len = min(len(data) for data in entry_list)

        if max_len - min_len > 1:
            raise ValueError(
                'There are entries of contradicting lengths: {}'.format(
                    set(len(t) for t in entry_list)))
        elif max_len - min_len == 1:
            # the shorter entries lack the interpolation strategy: (t, v_1, ..., v_N)
            num_chan = min_len - 1
        else:
            # figure out whether all last entries are interpolation strategies
            if all(
                    is_valid_interpolation_strategy(interp)
                    for *data, interp in entry_list):
                num_chan = min_len - 2
            else:
                num_chan = min_len - 1

        # insert default interpolation strategy key
        entry_list = [(t, *data, interp) if len(data) == num_chan else
                      (t, *data, interp, 'default')
                      for t, *data, interp in entry_list]

        for *_, last_voltage, _ in entry_list:
            if last_voltage in TablePulseTemplate.interpolation_strategies:
                warnings.warn(
                    '{} is also an interpolation strategy name but handled as a voltage. Is it intended?'
                    .format(last_voltage), AmbiguousTablePulseEntry)

        if channel_names is None:
            channel_names = list(range(num_chan))
        elif len(channel_names) != num_chan:
            raise ValueError(
                'Number of channel identifiers does not correspond to the number of channels.'
            )

        parsed = {channel_name: [] for channel_name in channel_names}

        for time, *voltages, interp in entry_list:
            for channel_name, volt in zip(channel_names, voltages):
                parsed[channel_name].append((time, volt, interp))

        return TablePulseTemplate(parsed, **kwargs)

    @property
    def integral(self) -> Dict[ChannelID, ExpressionScalar]:
        """Symbolic integral of every channel, computed by TableEntry._sequence_integral."""
        expressions = dict()
        for channel, channel_entries in self._entries.items():
            # Pad with an entry at t=0 carrying the first value and a 'hold' entry at the template duration carrying
            # the last value; presumably this makes the integral span [0, duration] like get_entries_instantiated.
            pre_entry = TableEntry(0, channel_entries[0].v, None)
            post_entry = TableEntry(self.duration, channel_entries[-1].v,
                                    'hold')
            channel_entries = [pre_entry] + channel_entries + [post_entry]
            expressions[channel] = TableEntry._sequence_integral(
                channel_entries, lambda v: v.sympified_expression)

        return expressions

    def _as_expression(self) -> Dict[ChannelID, ExpressionScalar]:
        """Per-channel expression of the table in the time symbol ``self._AS_EXPRESSION_TIME``.

        The first and last entry values are passed as ``pre_value``/``post_value``; how they are used is up to
        TableEntry._sequence_as_expression (presumably the value before the first and after the last entry).
        """
        expressions = dict()
        for channel, channel_entries in self._entries.items():
            pre_value = channel_entries[0].v.sympified_expression
            post_value = channel_entries[-1].v.sympified_expression

            expressions[channel] = TableEntry._sequence_as_expression(
                channel_entries,
                lambda v: v.sympified_expression,
                t=self._AS_EXPRESSION_TIME,
                pre_value=pre_value,
                post_value=post_value)
        return expressions