예제 #1
0
    def test_detect_input_type(self):
        """Check that detect_input_type's reported name matches its result.

        For each input, the datatype name returned by ``detect_input_type``
        must equal ``input_type_name`` of the returned object's type. Covers
        a DbInput, a ChordInput, a DbBulkInput (with ``allow_bulk=True``),
        and the ``allowed`` restriction, both accepting and rejecting.
        """
        # Load some input: DbInput
        dbi = DbInput.from_file(DB_SEQUENCES_FILE, {"index": 0})
        # Run it through the preprocessor
        datatype, obj = detect_input_type(dbi)
        # The reported datatype must agree with the registered name for the
        # type of the object that came back
        datatype2 = input_type_name(type(obj))
        self.assertEqual(datatype, datatype2)

        # Do the same with ChordInput
        ci = ChordInput.from_file(CHORDS_FILE, options={"roman": True})
        datatype, obj = detect_input_type(ci)
        datatype2 = input_type_name(type(obj))
        self.assertEqual(datatype, datatype2)

        # Try some bulk input
        bulk = DbBulkInput.from_file(DB_SEQUENCES_FILE)
        datatype, obj = detect_input_type(bulk, allow_bulk=True)
        datatype2 = input_type_name(type(obj))
        self.assertEqual(datatype, datatype2)

        # Restricting the allowed types to one that matches should succeed
        datatype, obj = detect_input_type(ci, allowed=["chords"])
        self.assertEqual(datatype, input_type_name(type(obj)))
        # ...and restricting to a non-matching type should be rejected.
        # assertRaises forwards trailing positional/keyword args straight to
        # the callable, so they must not be wrapped in a tuple/dict (that
        # would call detect_input_type((ci,), {...}) instead).
        self.assertRaises(InputTypeError, detect_input_type, ci,
                          allowed=["db"])
예제 #2
0
    def test_detect_input_type(self):
        """Check that detect_input_type's reported name matches its result.

        For each input, the datatype name returned by ``detect_input_type``
        must equal ``input_type_name`` of the returned object's type. Covers
        a DbInput, a ChordInput, a DbBulkInput (with ``allow_bulk=True``),
        and the ``allowed`` restriction, both accepting and rejecting.
        """
        # Load some input: DbInput
        dbi = DbInput.from_file(DB_SEQUENCES_FILE, {'index': 0})
        # Run it through the preprocessor
        datatype, obj = detect_input_type(dbi)
        # The reported datatype must agree with the registered name for the
        # type of the object that came back
        datatype2 = input_type_name(type(obj))
        self.assertEqual(datatype, datatype2)

        # Do the same with ChordInput
        ci = ChordInput.from_file(CHORDS_FILE, options={'roman': True})
        datatype, obj = detect_input_type(ci)
        datatype2 = input_type_name(type(obj))
        self.assertEqual(datatype, datatype2)

        # Try some bulk input
        bulk = DbBulkInput.from_file(DB_SEQUENCES_FILE)
        datatype, obj = detect_input_type(bulk, allow_bulk=True)
        datatype2 = input_type_name(type(obj))
        self.assertEqual(datatype, datatype2)

        # Restricting the allowed types to one that matches should succeed
        datatype, obj = detect_input_type(ci, allowed=['chords'])
        self.assertEqual(datatype, input_type_name(type(obj)))
        # ...and restricting to a non-matching type should be rejected.
        # assertRaises forwards trailing positional/keyword args straight to
        # the callable, so they must not be wrapped in a tuple/dict (that
        # would call detect_input_type((ci, ), {...}) instead).
        self.assertRaises(InputTypeError, detect_input_type, ci,
                          allowed=['db'])
예제 #3
0
 def test_from_file(self):
     """Loading a sequence index file should give a non-empty list of DbInput."""
     # Draining the bulk input's iterator into a list collects every sequence
     loaded = list(DbBulkInput.from_file(DB_SEQUENCES_FILE))
     # At least one sequence must have been read from the file
     self.assertNotEqual(len(loaded), 0)
     # Each item should be an individual DbInput; inspect the first one
     self.assertIsInstance(loaded[0], DbInput)
예제 #4
0
 def test_from_file(self):
     """Check that DbBulkInput.from_file reads at least one DbInput."""
     source_path = DB_SEQUENCES_FILE
     bulk_input = DbBulkInput.from_file(source_path)
     # The bulk input is iterable; converting it to a list yields all sequences
     sequences = list(bulk_input)
     sequence_count = len(sequences)
     # Loading must have produced something...
     self.assertNotEqual(sequence_count, 0)
     # ...and the items must be DbInput objects (checked on the first one)
     head = sequences[0]
     self.assertIsInstance(head, DbInput)
예제 #5
0
     # Get the bulk input to supply names
     name_getter = iter(input_data.get_identifiers())
     num_inputs = len(input_data)
     # Fill the progress record with names and mark as incomplete
     completed_parses = dict([(name,False) \
                             for name in input_data.get_identifiers()])
     if partitions > 1:
         if options.sequence_partitions is not None:
             # Split the inputs up into partitions on the basis of 
             #  an even partitioning of chord sequences
             # This can only be done with 
             if not isinstance(input_data, SegmentedMidiBulkInput):
                 logger.error("option --sequence-partitions is only "\
                     "valid with bulk midi input data")
                 return 1
             chord_seqs = DbBulkInput.from_file(options.sequence_partitions)
             # Partition the chord sequences: we only need indices
             seq_indices = enumerate(partition(
                         [i for i in range(len(chord_seqs))], partitions))
             seq_partitions = dict(
                 sum([[(index,part_num) for index in part] for 
                         (part_num,part) in seq_indices], []) )
             # Associate a partition num with each midi input
             partition_numbers = [
                 seq_partitions[midi.sequence_index] for midi in input_data]
         else:
             # Prepare a list of partition numbers to append to model names
             partition_numbers = sum([
                 [partnum for i in part] for (partnum,part) in \
                  enumerate(partition(range(num_inputs), partitions))], [])
 else: