Пример #1
0
    def test_map_samples_with_model(self):
        """Build a CPU StructGenotypingModel around a BatchOfInstances mapper and map two records.

        The mapped feature width is measured by mapping the record twice with cuda=False; the
        model is then constructed with use_cuda=False and use_batching=True, so the final
        mapping must also stay on the CPU.
        """
        sbi_mappers_configuration = configure_mappers(ploidy=2,
                                                      extra_genotypes=2,
                                                      num_samples=1,
                                                      count_dim=16,
                                                      sample_dim=32)
        sbi_mapper = BatchOfInstances(*sbi_mappers_configuration)
        json_string = '{"type":"BaseInformation","referenceBase":"A","genomicSequenceContext":"GCAGATATACTTCACAGCCCACGCTGACTCTGCCAAGCACA","samples":[{"type":"SampleInfo","counts":[{"type":"CountInfo","matchesReference":true,"isCalled":true,"isIndel":false,"fromSequence":"A","toSequence":"A","genotypeCountForwardStrand":7,"genotypeCountReverseStrand":32,"gobyGenotypeIndex":0},{"type":"CountInfo","matchesReference":false,"isCalled":false,"isIndel":false,"fromSequence":"A","toSequence":"C","genotypeCountForwardStrand":0,"genotypeCountReverseStrand":1,"gobyGenotypeIndex":2},{"type":"CountInfo","matchesReference":false,"isCalled":false,"isIndel":false,"fromSequence":"A","toSequence":"T","genotypeCountForwardStrand":0,"genotypeCountReverseStrand":0,"gobyGenotypeIndex":1},{"type":"CountInfo","matchesReference":false,"isCalled":false,"isIndel":false,"fromSequence":"A","toSequence":"G","genotypeCountForwardStrand":0,"genotypeCountReverseStrand":0,"gobyGenotypeIndex":3},{"type":"CountInfo","matchesReference":false,"isCalled":false,"isIndel":false,"fromSequence":"A","toSequence":"N","genotypeCountForwardStrand":0,"genotypeCountReverseStrand":0,"gobyGenotypeIndex":4}]}]}'

        # determine feature size:
        import ujson
        record = ujson.loads(json_string)

        mapped_features_size = sbi_mapper([record, record],
                                          tensor_cache=NoCache(),
                                          cuda=False).size(1)
        problem = StructuredSbiGenotypingProblem(
            mini_batch_size=2,
            code="struct_genotyping:/data/struct/CNG-NA12878-softmax-indels")
        output_size = problem.output_size("softmaxGenotype")
        parser = define_train_auto_encoder_parser()
        # Parse an explicit empty argument list: under a test runner sys.argv holds the
        # runner's own options, which this parser does not understand.
        args = parser.parse_args([])
        model = StructGenotypingModel(args,
                                      sbi_mapper,
                                      mapped_features_size,
                                      output_size,
                                      use_cuda=False,
                                      use_batching=True)
        # Stay on the CPU, like the model and the feature-size probe above.
        print(model.map_sbi_messages(sbi_records=[record] * 2, cuda=False))
Пример #2
0
    def collect_inputs(self, sample, phase=0, tensor_cache=NoCache(), cuda=None, batcher=None):
        """Collect input data for all counts in this sample.

        Only the first ``self.num_counts`` observed counts are used.

        Phases 0 and 1: each count is handed to ``self.count_mapper.collect_inputs`` and the
        returned indices are recorded on that count message.

        Phase 2: the batched forward result of every count is fetched, the list is padded with
        zero tensors up to ``self.num_counts`` entries, concatenated along the last dimension
        and handed to ``self.reduce_counts.collect_inputs``; the returned indices are recorded
        on the sample message.

        :return: ``[]`` for phases 0, 1 and 2 (indices live in the messages); None otherwise.
        """
        if 'indices' not in sample:
            sample['indices'] = {}
        observed_counts = self.get_observed_counts(sample)[0:self.num_counts]
        if phase < 2:
            for count in observed_counts:
                count_indices = self.count_mapper.collect_inputs(count, tensor_cache=tensor_cache,
                                                                 cuda=cuda, phase=phase, batcher=batcher)
                store_indices_in_message(mapper=self.count_mapper, message=count, indices=count_indices)
            return []

        if phase == 2:
            list_mapped_counts = []
            for count in observed_counts:
                mapped_count = batcher.get_forward_for_example(self.count_mapper,
                                                               example_indices=count['indices'][id(self.count_mapper)])
                list_mapped_counts.append(mapped_count)
            while len(list_mapped_counts) < self.num_counts:
                # pad the list with zeros so the concatenated width does not depend on how many
                # counts were observed.
                # NOTE(review): raises IndexError when a sample has no observed counts — confirm
                # callers never pass such a sample.
                variable = Variable(torch.zeros(*list_mapped_counts[0].size()), requires_grad=True)
                if cuda:
                    # `async` is a reserved keyword since Python 3.7 (SyntaxError);
                    # PyTorch renamed this argument to `non_blocking`.
                    variable = variable.cuda(non_blocking=True)
                list_mapped_counts.append(variable)
            cat_list_mapped_counts = torch.cat(list_mapped_counts, dim=-1)

            reduced_indices = self.reduce_counts.collect_inputs(cat_list_mapped_counts,
                                                                tensor_cache=tensor_cache,
                                                                cuda=cuda,
                                                                phase=phase,
                                                                batcher=batcher)
            store_indices_in_message(mapper=self.reduce_counts, message=sample, indices=reduced_indices)
            return []
Пример #3
0
    def create_struct_model(self, problem, args, use_cuda):
        """Build and print a StructGenotypingModel for a structured genotyping problem.

        :param problem: queried for ``output_size("softmaxGenotype")``.
        :param args: options; reads struct_ploidy, struct_extra_genotypes, struct_count_dim,
            struct_sample_dim and use_batching.
        :param use_cuda: forwarded to configure_mappers only.
        :return: the constructed model.
        """
        # NOTE(review): configure_mappers receives the use_cuda argument, while every later
        # device decision below reads self.use_cuda — confirm callers always pass the same value.
        mapper_parts = configure_mappers(
            ploidy=args.struct_ploidy,
            extra_genotypes=args.struct_extra_genotypes,
            num_samples=1,
            count_dim=args.struct_count_dim,
            sample_dim=args.struct_sample_dim,
            use_cuda=use_cuda)
        sbi_mapper = BatchOfInstances(*mapper_parts)

        # Map one example record to discover how wide the mapped feature vector is.
        import ujson
        example_record = ujson.loads(sbi_json_string)
        if self.use_cuda:
            sbi_mapper.cuda()
        feature_width = sbi_mapper([example_record],
                                   tensor_cache=NoCache(),
                                   cuda=self.use_cuda).size(1)

        genotype_outputs = problem.output_size("softmaxGenotype")
        model = StructGenotypingModel(args, sbi_mapper, feature_width,
                                      genotype_outputs, self.use_cuda,
                                      args.use_batching)
        print(model)
        return model
Пример #4
0
    def collect_inputs(self, nwf_list, phase=0, tensor_cache=NoCache(), cuda=None, batcher=None):
        """Register the 'number' and 'frequency' values of *nwf_list* with the batcher.

        In phase 0, ``nwf_list['indices']`` is reset. The indices returned by
        ``map_number.collect_inputs`` and ``map_frequency.collect_inputs`` are then stored
        on it. No phase stores inputs for this mapper itself.

        NOTE(review): assumes nwf_list is a mapping holding 'number' and 'frequency' value
        lists — confirm against callers.

        :return: always ``[]``. This matches the other collect_inputs implementations, so
            callers can safely do ``indices += batcher.collect_inputs(...)``.
        """
        if phase == 0:
            # the following tensors are batched:
            nwf_list['indices'] = {}
            store_indices_in_message(mapper=self.map_number, message=nwf_list,
                                     indices=self.map_number.collect_inputs(
                                         values=nwf_list['number'], phase=phase, cuda=cuda, batcher=batcher))

            store_indices_in_message(mapper=self.map_frequency, message=nwf_list,
                                     indices=self.map_frequency.collect_inputs(
                                         values=nwf_list['frequency'], phase=phase, cuda=cuda, batcher=batcher))
        # No indices for this mapper itself (previously fell through and returned None).
        return []
Пример #5
0
 def test_mapper(self):
     """Smoke test: parse one BaseInformation record, map it with BatchOfInstances and print the output."""
     raw_json = '{"type":"BaseInformation","referenceBase":"A","genomicSequenceContext":"GCAGATATACTTCACAGCCCACGCTGACTCTGCCAAGCACA","samples":[{"type":"SampleInfo","counts":[{"type":"CountInfo","matchesReference":true,"isCalled":true,"isIndel":false,"fromSequence":"A","toSequence":"A","genotypeCountForwardStrand":7,"genotypeCountReverseStrand":32,"gobyGenotypeIndex":0},{"type":"CountInfo","matchesReference":false,"isCalled":false,"isIndel":false,"fromSequence":"A","toSequence":"C","genotypeCountForwardStrand":0,"genotypeCountReverseStrand":1,"gobyGenotypeIndex":2},{"type":"CountInfo","matchesReference":false,"isCalled":false,"isIndel":false,"fromSequence":"A","toSequence":"T","genotypeCountForwardStrand":0,"genotypeCountReverseStrand":0,"gobyGenotypeIndex":1},{"type":"CountInfo","matchesReference":false,"isCalled":false,"isIndel":false,"fromSequence":"A","toSequence":"G","genotypeCountForwardStrand":0,"genotypeCountReverseStrand":0,"gobyGenotypeIndex":3},{"type":"CountInfo","matchesReference":false,"isCalled":false,"isIndel":false,"fromSequence":"A","toSequence":"N","genotypeCountForwardStrand":0,"genotypeCountReverseStrand":0,"gobyGenotypeIndex":4}]}]}'
     import ujson
     parsed = ujson.loads(raw_json)
     print(parsed)
     sub_mappers, modules = configure_mappers(ploidy=2, extra_genotypes=3, num_samples=1)
     batch_mapper = BatchOfInstances(mappers=sub_mappers, all_modules=modules)
     print(batch_mapper([parsed], tensor_cache=NoCache()))
Пример #6
0
    def cat_inputs(self, mapper, list_of_values, tensor_cache=NoCache(), phase=0, cuda=False, direct_forward=False):
        """Map every value with *mapper* and concatenate the results along dimension 1.

        :param mapper: module used for each value.
        :param list_of_values: values mapped one at a time, in order.
        :param direct_forward: when True, call ``mapper(value, ...)`` directly; otherwise call
            ``mapper.collect_inputs(value, ...)`` with *phase*.
        :return: ``torch.cat`` of the per-value results along dim 1.

        NOTE(review): with direct_forward=False the per-value results are whatever
        collect_inputs returns, which must be tensors for torch.cat to succeed — confirm.
        """
        if direct_forward:
            mapped = [mapper(value, tensor_cache=tensor_cache, cuda=cuda)
                      for value in list_of_values]
        else:
            mapped = [mapper.collect_inputs(value, tensor_cache=tensor_cache, phase=phase, cuda=cuda)
                      for value in list_of_values]
        return torch.cat(mapped, dim=1)
Пример #7
0
 def test_mapper_number_freqs(self):
     """Smoke test: map a record whose counts carry NumberWithFrequency lists and print the output."""
     # The record is written as adjacent string literals that Python joins at compile time
     # into one JSON document. Previously it was hard-wrapped inside a single quoted
     # string, which is a SyntaxError.
     json_string = (
         '{"type":"BaseInformation","referenceBase":"A","genomicSequenceContext":"GCAGATATACTTCACAGCCCACGCTGACTCTGCCAAGCACA","samples":[{"type":"SampleInfo","counts":[{"type":"CountInfo","matchesReference":true,"isCalled":true,"isIndel":false,"fromSequence":"A","toSequence":"A","genotypeCountForwardStrand":7,"genotypeCountReverseStrand":32,"gobyGenotypeIndex":0,"qualityScoresForwardStrand":[{"type":"NumberWithFrequency","frequency":7,"number":40}],"qualityScoresReverseStrand":[{"type":"NumberWithFrequency","frequency":32,"number":40}],"readIndicesForwardStrand":[{"type":"NumberWithFrequency","frequency":1,"number":23},{"type":"NumberWithFrequency","frequency":1,"number":30},{"type":"NumberWithFrequency","frequency":5,"number":34}],"readIndicesReverseStrand":[{"type":"NumberWithFrequency","frequency":1,"number":6},{"type":"NumberWithFrequency","frequency":1,"number":22},{"type":"NumberWithFrequency","frequency":1,"number":28},{"type":"NumberWithFrequency","frequency":1,"number":31},{"type":"NumberWithFrequency","frequency":1,"number":34},{"type":"NumberWithFrequency","frequency":1,"number":35},{"type":"NumberWithFrequency","frequency":1,"number":44},{"type":"NumberWithFrequency","frequency":1,"number":145},{"type":"NumberWithFrequency","frequency":1,"number":150},{"type":"NumberWithFrequency","frequency":5,"number":151},{"type":"NumberWithFrequency","frequency":2,"number":171},{"type":"NumberWithFrequency","frequency":4,"number":172}],"queryPositions":[{"type":"NumberWithFrequency","frequency":39,"number":0}],"pairFlags":[{"type":"NumberWithFrequency","frequency":6,"number":16},{"type":"NumberWithFrequency","frequency":14,"number":83},{"type":"NumberWithFrequency","frequency":6,"number":99},{"type":"NumberWithFrequency","frequency":12,"number":147},{"type":"NumberWithFrequency","frequency":1,"number":163}],"distancesToReadVariationsForwardStrand":[{"type":"NumberWithFrequency","frequency":2,"number":-70},{"type":"NumberWithFrequency","frequency":4,"number":-29}],'
         '"distancesToReadVariationsReverseStrand":[{"type":"NumberWithFrequency","frequency":2,"number":-24},{"type":"NumberWithFrequency","frequency":1,"number":-15},{"type":"NumberWithFrequency","frequency":1,"number":-2},{"type":"NumberWithFrequency","frequency":1,"number":12},{"type":"NumberWithFrequency","frequency":1,"number":13},{"type":"NumberWithFrequency","frequency":1,"number":15},{"type":"NumberWithFrequency","frequency":13,"number":29},{"type":"NumberWithFrequency","frequency":1,"number":49},{"type":"NumberWithFrequency","frequency":3,"number":62},{"type":"NumberWithFrequency","frequency":9,"number":70},{"type":"NumberWithFrequency","frequency":1,"number":73}],"distanceToStartOfRead":[{"type":"NumberWithFrequency","frequency":1,"number":18},{"type":"NumberWithFrequency","frequency":1,"number":23},{"type":"NumberWithFrequency","frequency":1,"number":26},{"type":"NumberWithFrequency","frequency":1,"number":30},{"type":"NumberWithFrequency","frequency":30,"number":33},{"type":"NumberWithFrequency","frequency":5,"number":34}],"distanceToEndOfRead":[{"type":"NumberWithFrequency","frequency":1,"number":6},{"type":"NumberWithFrequency","frequency":1,"number":22},{"type":"NumberWithFrequency","frequency":1,"number":28},{"type":"NumberWithFrequency","frequency":1,"number":31},{"type":"NumberWithFrequency","frequency":1,"number":34},{"type":"NumberWithFrequency","frequency":2,"number":35},{"type":"NumberWithFrequency","frequency":1,"number":44},{"type":"NumberWithFrequency","frequency":1,"number":48},{"type":"NumberWithFrequency","frequency":1,"number":49},{"type":"NumberWithFrequency","frequency":1,"number":50},{"type":"NumberWithFrequency","frequency":1,"number":52},{"type":"NumberWithFrequency","frequency":1,"number":62},{"type":"NumberWithFrequency","frequency":1,"number":63},{"type":"NumberWithFrequency","frequency":1,"number":68},{"type":"NumberWithFrequency","frequency":2,"number":75},{"type":"NumberWithFrequency","frequency":2,"number":76},'
         '{"type":"NumberWithFrequency","frequency":1,"number":81},{"type":"NumberWithFrequency","frequency":1,"number":83},{"type":"NumberWithFrequency","frequency":1,"number":88},{"type":"NumberWithFrequency","frequency":1,"number":89},{"type":"NumberWithFrequency","frequency":1,"number":100},{"type":"NumberWithFrequency","frequency":1,"number":104},{"type":"NumberWithFrequency","frequency":1,"number":109},{"type":"NumberWithFrequency","frequency":1,"number":111},{"type":"NumberWithFrequency","frequency":1,"number":117},{"type":"NumberWithFrequency","frequency":1,"number":118},{"type":"NumberWithFrequency","frequency":1,"number":121},{"type":"NumberWithFrequency","frequency":1,"number":125},{"type":"NumberWithFrequency","frequency":1,"number":128},{"type":"NumberWithFrequency","frequency":1,"number":133},{"type":"NumberWithFrequency","frequency":2,"number":138},{"type":"NumberWithFrequency","frequency":4,"number":139}]},{"type":"CountInfo","matchesReference":false,"isCalled":false,"isIndel":false,"fromSequence":"A","toSequence":"C","genotypeCountForwardStrand":0,"genotypeCountReverseStrand":1,"gobyGenotypeIndex":2,"qualityScoresForwardStrand":[],"qualityScoresReverseStrand":[{"type":"NumberWithFrequency","frequency":1,"number":7}],"readIndicesForwardStrand":[],"readIndicesReverseStrand":[{"type":"NumberWithFrequency","frequency":1,"number":115}],"readMappingQualityForwardStrand":[],"readMappingQualityReverseStrand":[{"type":"NumberWithFrequency","frequency":1,"number":60}],"numVariationsInReads":[{"type":"NumberWithFrequency","frequency":1,"number":2}],"insertSizes":[{"type":"NumberWithFrequency","frequency":1,"number":-301}],"targetAlignedLengths":[{"type":"NumberWithFrequency","frequency":2,"number":148}],"queryAlignedLengths":[{"type":"NumberWithFrequency","frequency":1,"number":148}],"queryPositions":[{"type":"NumberWithFrequency","frequency":1,"number":0}],"pairFlags":[{"type":"NumberWithFrequency","frequency":1,"number":147}],"distancesToReadVariationsForwardStrand":[],'
         '"distancesToReadVariationsReverseStrand":[{"type":"NumberWithFrequency","frequency":1,"number":-29},{"type":"NumberWithFrequency","frequency":1,"number":0}],"distanceToStartOfRead":[{"type":"NumberWithFrequency","frequency":1,"number":33}],"distanceToEndOfRead":[{"type":"NumberWithFrequency","frequency":1,"number":115}]},{"type":"CountInfo","matchesReference":false,"isCalled":false,"isIndel":false,"fromSequence":"A","toSequence":"T","genotypeCountForwardStrand":0,"genotypeCountReverseStrand":0,"gobyGenotypeIndex":1,"qualityScoresForwardStrand":[],"qualityScoresReverseStrand":[],"readIndicesForwardStrand":[],"readIndicesReverseStrand":[],"readMappingQualityForwardStrand":[],"readMappingQualityReverseStrand":[],"numVariationsInReads":[],"insertSizes":[],"targetAlignedLengths":[],"queryAlignedLengths":[],"queryPositions":[],"pairFlags":[],"distancesToReadVariationsForwardStrand":[],"distancesToReadVariationsReverseStrand":[],"distanceToStartOfRead":[],"distanceToEndOfRead":[]},{"type":"CountInfo","matchesReference":false,"isCalled":false,"isIndel":false,"fromSequence":"A","toSequence":"G","genotypeCountForwardStrand":0,"genotypeCountReverseStrand":0,"gobyGenotypeIndex":3,"qualityScoresForwardStrand":[],"qualityScoresReverseStrand":[],"readIndicesForwardStrand":[],"readIndicesReverseStrand":[],"readMappingQualityForwardStrand":[],"readMappingQualityReverseStrand":[],"numVariationsInReads":[],"insertSizes":[],"targetAlignedLengths":[],"queryAlignedLengths":[],"queryPositions":[],"pairFlags":[],"distancesToReadVariationsForwardStrand":[],"distancesToReadVariationsReverseStrand":[],"distanceToStartOfRead":[],"distanceToEndOfRead":[]},{"type":"CountInfo","matchesReference":false,"isCalled":false,"isIndel":false,"fromSequence":"A","toSequence":"N","genotypeCountForwardStrand":0,"genotypeCountReverseStrand":0,"gobyGenotypeIndex":4,"qualityScoresForwardStrand":[],"qualityScoresReverseStrand":[],"readIndicesForwardStrand":[],"readIndicesReverseStrand":[],"readMappingQualityForwardStrand":[],'
         '"readMappingQualityReverseStrand":[],"numVariationsInReads":[],"insertSizes":[],"targetAlignedLengths":[],"queryAlignedLengths":[],"queryPositions":[],"pairFlags":[],"distancesToReadVariationsForwardStrand":[],"distancesToReadVariationsReverseStrand":[],"distanceToStartOfRead":[],"distanceToEndOfRead":[]}]}]}'
     )
     import ujson
     record = ujson.loads(json_string)
     print(record)
     mappers, all_modules = configure_mappers(ploidy=2,
                                              extra_genotypes=3,
                                              num_samples=1)
     mapper = BatchOfInstances(mappers=mappers, all_modules=all_modules)
     out = mapper([record], tensor_cache=NoCache())
     print(out)
Пример #8
0
    def forward(self, input, tensor_cache=NoCache(), cuda=None):
        """Map a BaseInformation-style record and pass the parts to reduce_samples.

        The parts are, in order:
        - the mapped referenceBase;
        - the mapped genomicSequenceContext;
        - for each of the first ``self.num_samples`` samples, the sample_mapper output
          concatenated (dim 1) with the mapped fromSequence of that sample's first count.
        """
        if cuda is None:
            # Follow the device of the module's first parameter when the caller does not say.
            cuda = next(self.parameters()).data.is_cuda

        def embed_sequence(text):
            return self.map_sequence(text, tensor_cache=tensor_cache, cuda=cuda)

        reference = embed_sequence(input['referenceBase'])
        context = embed_sequence(input['genomicSequenceContext'])

        # each sample has a unique from (same across all counts), which gets mapped and
        # concatenated with the sample mapped reduction:
        per_sample = []
        for sample in input['samples'][0:self.num_samples]:
            sample_embedding = self.sample_mapper(sample, tensor_cache=tensor_cache, cuda=cuda)
            from_embedding = embed_sequence(sample['counts'][0]['fromSequence'])
            per_sample.append(torch.cat([sample_embedding, from_embedding], dim=1))

        return self.reduce_samples([reference, context] + per_sample, cuda)
Пример #9
0
    def forward(self, nwf_list, tensor_cache=NoCache(), cuda=None, nf_name="unknown"):
        """Map a list of NumberWithFrequency messages to one embedding.

        For a non-empty list:
        - The mapped 'number' values and mapped 'frequency' values are concatenated along dim 1.
        - The result is summarised with mean_sequence when the module-level flag
          ``use_mean_to_map_nwf`` is set, and with map_sequence otherwise.

        An empty list yields a ``(1, self.embedding_size)`` zero Variable, moved to the GPU
        when *cuda* is truthy.

        :param nf_name: accepted for callers' convenience; not used here.
        """
        if len(nwf_list) == 0:
            variable = Variable(torch.zeros(1, self.embedding_size), requires_grad=True)
            if cuda:
                # `async` is a reserved keyword since Python 3.7 (SyntaxError);
                # PyTorch renamed this argument to `non_blocking`.
                variable = variable.cuda(non_blocking=True)
            return variable

        # NOTE(review): map_frequency is called without tensor_cache, unlike map_number —
        # presumably its forward takes (values, cuda); confirm.
        mapped_frequencies = torch.cat([
            self.map_number([nwf['number'] for nwf in nwf_list], tensor_cache, cuda),
            self.map_frequency([nwf['frequency'] for nwf in nwf_list], cuda)], dim=1)

        if use_mean_to_map_nwf:
            return self.mean_sequence(mapped_frequencies, cuda=cuda)
        return self.map_sequence(mapped_frequencies, cuda=cuda)
Пример #10
0
    def collect_inputs(self, c, phase=0, tensor_cache=NoCache(), cuda=None, batcher=None):
        """Two-phase input collection for one CountInfo message *c*.

        Phase 0:
        - Resets ``c['indices']``.
        - Registers gobyGenotypeIndex, the forward/reverse genotype counts and the
          isIndel/matchesReference flags with their batched mappers.
        - Maps fromSequence/toSequence immediately (direct forward) into
          ``c['mapped-not-batched']``.
        - Returns ``[]``.

        Phase 1:
        - Fetches this message's batched outputs and flattens each to one row.
        - Appends the directly mapped sequences and reduces everything with reduce_batched.
        - Stores the result as this mapper's input and returns the indices from
          batcher.store_inputs.

        Any other phase returns None.
        """
        if phase == 0:
            c['indices'] = {}
            # the following tensors are batched:
            goby_indices = self.map_gobyGenotypeIndex.collect_inputs(
                values=[c['gobyGenotypeIndex']], phase=phase, cuda=cuda, batcher=batcher)
            store_indices_in_message(mapper=self.map_gobyGenotypeIndex, message=c, indices=goby_indices)

            count_indices = self.map_count.collect_inputs(
                values=[c['genotypeCountForwardStrand'], c['genotypeCountReverseStrand']],
                phase=phase, cuda=cuda, batcher=batcher)
            store_indices_in_message(mapper=self.map_count, message=c, indices=count_indices)

            flag_indices = self.map_boolean.collect_inputs(
                values=[c['isIndel'], c['matchesReference']],
                tensor_cache=tensor_cache, phase=phase, cuda=cuda, batcher=batcher)
            store_indices_in_message(mapper=self.map_boolean, message=c, indices=flag_indices)

            # the following tensors are not batched, but computed once per instance, right here
            # with direct_forward=True:
            not_batched = c['mapped-not-batched'] = {}
            not_batched[id(self.map_sequence)] = self.cat_inputs(
                self.map_sequence,
                [c['fromSequence'], c['toSequence']],
                tensor_cache=tensor_cache,
                phase=phase,
                cuda=cuda,
                direct_forward=True)
            return []

        if phase == 1:
            def batched_row(sub_mapper):
                return batcher.get_forward_for_example(mapper=sub_mapper, message=c).view(1, -1)

            all_mapped = [
                batched_row(self.map_gobyGenotypeIndex),
                batched_row(self.map_count),
                batched_row(self.map_boolean),
                # mapped_sequences are not currently batchable, so we get the input from the prior phase:
                c['mapped-not-batched'][id(self.map_sequence)].view(1, -1),
            ]
            return batcher.store_inputs(mapper=self, inputs=self.reduce_batched(all_mapped))
Пример #11
0
            def collect_inputs(self,
                               message,
                               phase=0,
                               tensor_cache=NoCache(),
                               cuda=None,
                               batcher=None):
                """Two-phase batching for a Message with integer fields 'a' and 'b'.

                Phase 0: registers ``message['a']`` and ``message['b']`` with map_a and map_b
                and records the returned indices on the message. Returns ``[]`` because
                nothing is stored for this mapper yet.

                Phase 1: fetches the batched outputs for 'a' and 'b', stores their sum as this
                mapper's input, and returns the indices from batcher.store_inputs.

                Other phases return None.
                """
                assert isinstance(message, Message)

                if phase == 0:
                    message['indices'] = {}
                    for sub_mapper, field in ((self.map_a, 'a'), (self.map_b, 'b')):
                        sub_indices = sub_mapper.collect_inputs(values=[message[field]],
                                                                phase=phase,
                                                                tensor_cache=tensor_cache,
                                                                cuda=cuda,
                                                                batcher=batcher)
                        store_indices_in_message(mapper=sub_mapper,
                                                 message=message,
                                                 indices=sub_indices)
                    # no indices yet for this mapper.
                    return []

                if phase == 1:
                    my_a, my_b = (
                        batcher.get_forward_for_example(
                            mapper=sub_mapper,
                            example_indices=get_indices_in_message(mapper=sub_mapper,
                                                                   message=message))
                        for sub_mapper in (self.map_a, self.map_b))
                    return batcher.store_inputs(mapper=self, inputs=(my_a + my_b))
Пример #12
0
    def test_integer_batching(self):
        """Each IntegerModel output obtained through the Batcher must equal a direct call."""
        mapper = IntegerModel(100, 2)
        batcher = Batcher()
        inputs = [12, 3, 2]
        no_cache = NoCache()

        collected = []
        for number in inputs:
            collected.extend(batcher.collect_inputs(mapper, [number]))

        print(batcher.get_batched_input(mapper))
        print(batcher.forward_batch(mapper))
        for idx in collected:
            print(batcher.get_forward_for_example(mapper, idx))

        # Example i of the batch corresponds to the i-th collected value.
        for position, number in enumerate(inputs):
            self.assertEqual(str(mapper([number], no_cache).data),
                             str(batcher.get_forward_for_example(mapper, position).data))
Пример #13
0
    def map_sbi_messages(self, sbi_records, tensor_cache=NoCache(), cuda=None):
        """Map a list of SBI records to features.

        With ``self.use_batching``:
        - Every sample of every record goes through phases 0, 1 and 2 of the "SampleInfo"
          mapper, using a fresh Batcher.
        - The phase-2 ``forward_batch`` result is returned.

        Otherwise the whole list is mapped directly by ``self.sbi_mapper``.

        :param sbi_records: parsed records; each record's 'samples' entries are mapped.
        :param tensor_cache: used on the batched path only. The unbatched path always
            creates a fresh TensorCache.
        :param cuda: device flag for both paths; None means ``self.use_cuda``. Previously the
            unbatched path ignored this argument and the batched path passed None through.
        """
        if cuda is None:
            cuda = self.use_cuda

        if self.use_batching:
            batcher = Batcher()
            mapper = self.sbi_mapper.mappers.mapper_for_type("SampleInfo")
            features = None
            for phase in (0, 1, 2):
                for record in sbi_records:
                    for sample in record['samples']:
                        batcher.collect_inputs(mapper=mapper,
                                               example=sample,
                                               phase=phase,
                                               cuda=cuda,
                                               tensor_cache=tensor_cache)
                features = batcher.forward_batch(mapper=mapper, phase=phase)
            return features

        # Create a new cache for each mini-batch because we cache embeddings:
        return self.sbi_mapper(sbi_records,
                               tensor_cache=TensorCache(),
                               cuda=cuda)
Пример #14
0
    def test_boolean_batching(self):
        """Each map_Boolean output obtained through the Batcher must equal a direct call."""
        mapper = map_Boolean()
        batcher = Batcher()
        flags = [True, True, False]
        no_cache = NoCache()

        collected = []
        for flag in flags:
            collected.extend(batcher.collect_inputs(mapper, flag))

        print("batched input={}".format(batcher.get_batched_input(mapper)))
        print(batcher.forward_batch(mapper))
        for idx in collected:
            forwarded = batcher.get_forward_for_example(mapper, idx)
            print("example {} = {}".format(idx, forwarded))

        # Example i of the batch corresponds to the i-th collected flag.
        for position, flag in enumerate(flags):
            self.assertEqual(str(mapper(flag, no_cache).data),
                             str(batcher.get_forward_for_example(mapper, position).data))
Пример #15
0
    def collect_inputs(self,
                       mapper,
                       example,
                       phase=0,
                       tensor_cache=NoCache(),
                       cuda=None):
        """
        Use the mapper on an example to obtain inputs for a batch.

        Initializes this batcher's per-mapper state for *mapper* and records the mapper in
        ``self.all_mappers``. It then delegates to ``mapper.collect_inputs`` with
        ``batcher=self``.

        :param mapper: module whose ``collect_inputs`` is called.
        :param example: message or value handed to the mapper.
        :param phase: batching phase, forwarded to the mapper.
        :param tensor_cache: cache forwarded to the mapper.
        :param cuda: device flag forwarded to the mapper.
        :return: whatever ``mapper.collect_inputs`` returns. Typically this is a list of
            example indices (as produced by ``store_inputs``), or ``[]`` when the mapper
            stored nothing in this phase.
        """
        id_mapper = id(mapper)
        self.initialize_mapper_variables(id_mapper)

        # keep track of all mappers used:
        self.all_mappers.update([mapper])
        input_indices_in_batch = mapper.collect_inputs(
            example,
            phase=phase,
            batcher=self,
            tensor_cache=tensor_cache,
            cuda=cuda)
        return input_indices_in_batch
Пример #16
0
 def forward(self, sequence_field, tensor_cache=NoCache(), cuda=None):
     """Map each character of *sequence_field* through base_to_index and map_bases, then map_sequence."""
     # A list comprehension already yields a list; the former list(...) wrapper was redundant.
     base_indices = [self.base_to_index[b] for b in sequence_field]
     return self.map_sequence(
         self.map_bases(base_indices, tensor_cache=tensor_cache, cuda=cuda),
         cuda)
Пример #17
0
    def test_message(self):
        """Two-phase batched mapping of a composite message must match direct mapping.

        MapMessage embeds a message as map_a(a) + map_b(b). The test proceeds as follows:
        - Phase 0: collect inputs for two messages and run the sub-mapper batches.
        - Phase 1: collect the summed inputs and run that batch.
        - Check each example's batched output against ``mapper(message, no_cache)``.
        """
        torch.manual_seed(1212)

        class Message:
            """Minimal dict-backed message supporting item get/set."""

            def __init__(self, a, b):
                self.dict = {'type': "Message", 'a': a, 'b': b}

            def __getitem__(self, key):
                return self.dict[key]

            def __setitem__(self, key, value):
                self.dict[key] = value

        def store_indices_in_message(mapper, message, indices):
            # Indices are keyed by mapper identity so several mappers can share one message.
            message['indices'][id(mapper)] = indices

        def get_indices_in_message(mapper, message):
            return message['indices'][id(mapper)]

        class MapMessage(StructuredEmbedding):
            """Embeds a Message as map_a(message['a']) + map_b(message['b'])."""

            def __init__(self):
                super().__init__(embedding_size=2)

                self.map_a = IntegerModel(distinct_numbers=100,
                                          embedding_size=2)
                self.map_b = IntegerModel(distinct_numbers=100,
                                          embedding_size=2)

            def forward(self, message, tensor_cache, cuda=None):
                return self.map_a([message['a']], tensor_cache,
                                  cuda) + self.map_b([message['b']],
                                                     tensor_cache, cuda)

            def collect_inputs(self,
                               message,
                               phase=0,
                               tensor_cache=NoCache(),
                               cuda=None,
                               batcher=None):
                assert isinstance(message, Message)

                if phase == 0:
                    message['indices'] = {}
                    store_indices_in_message(mapper=self.map_a,
                                             message=message,
                                             indices=self.map_a.collect_inputs(
                                                 values=[message['a']],
                                                 phase=phase,
                                                 tensor_cache=tensor_cache,
                                                 cuda=cuda,
                                                 batcher=batcher))
                    store_indices_in_message(mapper=self.map_b,
                                             message=message,
                                             indices=self.map_b.collect_inputs(
                                                 values=[message['b']],
                                                 phase=phase,
                                                 tensor_cache=tensor_cache,
                                                 cuda=cuda,
                                                 batcher=batcher))
                    # no indices yet for this mapper.
                    return []
                if phase == 1:

                    my_a = batcher.get_forward_for_example(
                        mapper=self.map_a,
                        example_indices=get_indices_in_message(
                            mapper=self.map_a, message=message))
                    my_b = batcher.get_forward_for_example(
                        mapper=self.map_b,
                        example_indices=get_indices_in_message(
                            mapper=self.map_b, message=message))
                    return batcher.store_inputs(mapper=self,
                                                inputs=(my_a + my_b))

            def forward_batch(self, batcher, phase=0):
                if phase == 0:
                    # do forward for batches, results are kept in batcher.
                    batched_a = batcher.forward_batch(mapper=self.map_a)
                    batched_b = batcher.forward_batch(mapper=self.map_b)
                    return (batched_a, batched_b)

                if phase == 1:
                    return batcher.get_batched_input(mapper=self)

        messages = [Message(12, 4), Message(1, 6)]
        batcher = Batcher()
        mapper = MapMessage()
        no_cache = NoCache()

        # Phase 0: register every message's 'a'/'b' values with the sub-mappers.
        for message in messages:
            batcher.collect_inputs(mapper, message)

        batcher.forward_batch(mapper=mapper, phase=0)
        # Phase 1: store the summed embeddings and record the indices returned for each message.
        for message in messages:
            store_indices_in_message(mapper, message,
                                     batcher.collect_inputs(mapper, message, phase=1))

        batcher.forward_batch(mapper=mapper, phase=1)
        for message in messages:
            print(
                batcher.get_forward_for_example(
                    mapper,
                    example_indices=get_indices_in_message(mapper, message)))

        for message in messages:
            self.assertEqual(
                str(mapper(message, no_cache).data),
                str(
                    batcher.get_forward_for_example(
                        mapper,
                        example_indices=get_indices_in_message(mapper, message)).data))
Пример #18
0
    def collect_inputs(self, values, phase=0, tensor_cache=NoCache(), cuda=None, batcher=None):
        """Store the float encoding of *values* in the batcher during phase 0.

        :return: the indices returned by ``batcher.store_inputs`` in phase 0; ``[]`` in any
            other phase, since this mapper has nothing further to collect.
        """
        if phase != 0:
            return []
        encoded = self.convert_list_of_floats(values)
        return batcher.store_inputs(mapper=self, inputs=encoded)