def testEncodeDecode(self):
        """Check known events encode to fixed indices and decode back."""
        # (event, expected class index) pairs, verified in this order.
        cases = [
            (PolyphonicEvent(event_type=PolyphonicEvent.START, pitch=0), 0),
            (PolyphonicEvent(event_type=PolyphonicEvent.STEP_END, pitch=0),
             2),
            (PolyphonicEvent(event_type=PolyphonicEvent.NEW_NOTE, pitch=0),
             3),
            (PolyphonicEvent(event_type=PolyphonicEvent.CONTINUED_NOTE,
                             pitch=60),
             191),
            (PolyphonicEvent(event_type=PolyphonicEvent.CONTINUED_NOTE,
                             pitch=127),
             258),
        ]

        for source_event, expected_index in cases:
            encoded = self.enc.encode_event(source_event)
            self.assertEqual(expected_index, encoded)
            # Decoding the index must give back an equal event.
            decoded = self.enc.decode_event(encoded)
            self.assertEqual(source_event, decoded)
    def decode_event(self, index):
        """Map a class index back to the PolyphonicEvent it stands for.

        Indices below len(EVENT_CLASSES_WITHOUT_PITCH) select a pitchless
        event type and are decoded with pitch=0. The next
        len(EVENT_CLASSES_WITH_PITCH) * PITCH_CLASSES indices are split into
        an event type and a pitch.

        Args:
          index: Non-negative integer class index.

        Returns:
          The decoded PolyphonicEvent.

        Raises:
          ValueError: If index is negative or beyond the last known index.
        """
        # Without this guard a negative index would wrap around in the list
        # lookup below and silently decode to an unrelated pitchless event.
        if index < 0:
            raise ValueError('Unknown event index: %s' % index)

        num_pitchless = len(EVENT_CLASSES_WITHOUT_PITCH)
        if index < num_pitchless:
            return PolyphonicEvent(
                event_type=EVENT_CLASSES_WITHOUT_PITCH[index], pitch=0)

        pitched_index = index - num_pitchless
        if pitched_index < len(EVENT_CLASSES_WITH_PITCH) * PITCH_CLASSES:
            # Pitched event types are numbered right after the pitchless
            # ones; presumably this matches the PolyphonicEvent type
            # constants -- verify against their definitions.
            event_type = num_pitchless + pitched_index // PITCH_CLASSES
            pitch = pitched_index % PITCH_CLASSES
            return PolyphonicEvent(event_type=event_type, pitch=pitch)

        raise ValueError('Unknown event index: %s' % index)
 def testEventToNumSteps(self):
     """Check the step count reported for each event type."""
     # (expected steps, event type, pitch) rows, verified in this order.
     cases = (
         (0, PolyphonicEvent.START, 0),
         (0, PolyphonicEvent.END, 0),
         (1, PolyphonicEvent.STEP_END, 0),
         (0, PolyphonicEvent.NEW_NOTE, 60),
         (0, PolyphonicEvent.CONTINUED_NOTE, 72),
     )
     for expected_steps, kind, pitch in cases:
         event = PolyphonicEvent(event_type=kind, pitch=pitch)
         self.assertEqual(expected_steps,
                          self.enc.event_to_num_steps(event))
 def default_event(self):
     """Return a STEP_END event with pitch 0 as the default event."""
     step_end = PolyphonicEvent.STEP_END
     return PolyphonicEvent(event_type=step_end, pitch=0)