Ejemplo n.º 1
0
 def setUp(self):
   """Install a MultipleEventSequenceEncoder as `self.enc`.

   The multi-encoder wraps two OneHotEventSequenceEncoderDecoder instances,
   built from TrivialOneHotEncoding(2) and TrivialOneHotEncoding(3), in that
   order.
   """
   super().setUp()
   sub_encoders = [
       encoder_decoder.OneHotEventSequenceEncoderDecoder(
           testing_lib.TrivialOneHotEncoding(size))
       for size in (2, 3)
   ]
   self.enc = encoder_decoder.MultipleEventSequenceEncoder(sub_encoders)
Ejemplo n.º 2
0
 def setUp(self):
   """Install a ConditionalEventSequenceEncoderDecoder as `self.enc`.

   The first wrapped encoder uses TrivialOneHotEncoding(2) and the second
   uses TrivialOneHotEncoding(3); both are OneHotEventSequenceEncoderDecoder.
   """
   super().setUp()
   encoder_2 = encoder_decoder.OneHotEventSequenceEncoderDecoder(
       testing_lib.TrivialOneHotEncoding(2))
   encoder_3 = encoder_decoder.OneHotEventSequenceEncoderDecoder(
       testing_lib.TrivialOneHotEncoding(3))
   self.enc = encoder_decoder.ConditionalEventSequenceEncoderDecoder(
       encoder_2, encoder_3)
Ejemplo n.º 3
0
    def setUp(self):
        """Create a temporary sequence file and a small RNN config.

        Sets `self._sequence_file` to a NamedTemporaryFile and `self.config`
        to an EventSequenceRnnConfig wrapping a
        OneHotEventSequenceEncoderDecoder over TrivialOneHotEncoding(12) with
        fixed hyperparameters.
        """
        # Run the base-class fixture setup; the original version skipped it.
        super().setUp()

        self._sequence_file = tempfile.NamedTemporaryFile(
            prefix='EventSequenceRNNGraphTest')
        # Ensure the temp file is closed (and, with the default delete=True,
        # removed) after every test. Closing an already-closed
        # NamedTemporaryFile is a no-op, so this is safe even if a tearDown
        # elsewhere also closes it.
        self.addCleanup(self._sequence_file.close)

        self.config = events_rnn_model.EventSequenceRnnConfig(
            None,
            note_seq.OneHotEventSequenceEncoderDecoder(
                testing_lib.TrivialOneHotEncoding(12)),
            contrib_training.HParams(batch_size=128,
                                     rnn_layer_sizes=[128, 128],
                                     dropout_keep_prob=0.5,
                                     clip_norm=5,
                                     learning_rate=0.01))
Ejemplo n.º 4
0
  def testEmptyLookback(self):
    """Check a LookbackEventSequenceEncoderDecoder with no lookback distances.

    The encoder wraps TrivialOneHotEncoding(3) with an empty lookback list and
    a trailing constructor argument of 2. The test checks input size, class
    count, per-position input vectors, labels, and decoding of class indices
    back to events.
    """
    enc = encoder_decoder.LookbackEventSequenceEncoderDecoder(
        testing_lib.TrivialOneHotEncoding(3), [], 2)
    self.assertEqual(5, enc.input_size)
    self.assertEqual(3, enc.num_classes)

    events = [0, 1, 0, 2, 0]

    # The first three components are the one-hot event. The last two look
    # like a +/-1 binary encoding of the position (presumably driven by the
    # trailing `2` constructor argument) -- NOTE(review): confirm.
    self.assertEqual([1.0, 0.0, 0.0, 1.0, -1.0],
                     enc.events_to_input(events, 0))
    self.assertEqual([0.0, 1.0, 0.0, -1.0, 1.0],
                     enc.events_to_input(events, 1))
    self.assertEqual([1.0, 0.0, 0.0, 1.0, 1.0],
                     enc.events_to_input(events, 2))
    self.assertEqual([0.0, 0.0, 1.0, -1.0, -1.0],
                     enc.events_to_input(events, 3))
    self.assertEqual([1.0, 0.0, 0.0, 1.0, -1.0],
                     enc.events_to_input(events, 4))

    # With a trivial encoding, the label at each position is the event itself.
    for position, event in enumerate(events):
      self.assertEqual(event, enc.events_to_label(events, position))

    # Decode with the encoder under test (`enc`). The previous version called
    # `self.enc`, which checked whatever fixture setUp installed rather than
    # the empty-lookback encoder built above. With no lookbacks, every class
    # index maps straight back to the event of the same value, regardless of
    # the preceding events.
    for prefix_len in range(1, len(events) + 1):
      for class_index in range(enc.num_classes):
        self.assertEqual(
            class_index,
            enc.class_index_to_event(class_index, events[:prefix_len]))
Ejemplo n.º 5
0
 def setUp(self):
   """Install a OneHotEventSequenceEncoderDecoder as `self.enc`.

   The encoder is built from TrivialOneHotEncoding(3, num_steps=range(3)).
   NOTE(review): the meaning of `num_steps` is defined in testing_lib; confirm
   there.
   """
   super().setUp()
   one_hot_encoding = testing_lib.TrivialOneHotEncoding(
       3, num_steps=range(3))
   self.enc = encoder_decoder.OneHotEventSequenceEncoderDecoder(
       one_hot_encoding)