def test_map_samples_with_model(self):
    """Smoke-test StructGenotypingModel.map_sbi_messages on a CPU-only model.

    Builds the structured SBI mappers, measures the mapped feature size on two copies of
    a sample record, creates a model with batching enabled (use_cuda=False) and prints the
    features produced for a mini-batch of two identical records.
    """
    import ujson

    sbi_mappers_configuration = configure_mappers(ploidy=2, extra_genotypes=2, num_samples=1,
                                                  count_dim=16, sample_dim=32)
    sbi_mapper = BatchOfInstances(*sbi_mappers_configuration)
    json_string = '{"type":"BaseInformation","referenceBase":"A","genomicSequenceContext":"GCAGATATACTTCACAGCCCACGCTGACTCTGCCAAGCACA","samples":[{"type":"SampleInfo","counts":[{"type":"CountInfo","matchesReference":true,"isCalled":true,"isIndel":false,"fromSequence":"A","toSequence":"A","genotypeCountForwardStrand":7,"genotypeCountReverseStrand":32,"gobyGenotypeIndex":0},{"type":"CountInfo","matchesReference":false,"isCalled":false,"isIndel":false,"fromSequence":"A","toSequence":"C","genotypeCountForwardStrand":0,"genotypeCountReverseStrand":1,"gobyGenotypeIndex":2},{"type":"CountInfo","matchesReference":false,"isCalled":false,"isIndel":false,"fromSequence":"A","toSequence":"T","genotypeCountForwardStrand":0,"genotypeCountReverseStrand":0,"gobyGenotypeIndex":1},{"type":"CountInfo","matchesReference":false,"isCalled":false,"isIndel":false,"fromSequence":"A","toSequence":"G","genotypeCountForwardStrand":0,"genotypeCountReverseStrand":0,"gobyGenotypeIndex":3},{"type":"CountInfo","matchesReference":false,"isCalled":false,"isIndel":false,"fromSequence":"A","toSequence":"N","genotypeCountForwardStrand":0,"genotypeCountReverseStrand":0,"gobyGenotypeIndex":4}]}]}'
    # determine feature size:
    record = ujson.loads(json_string)
    mapped_features_size = sbi_mapper([record, record], tensor_cache=NoCache(), cuda=False).size(1)
    problem = StructuredSbiGenotypingProblem(
        mini_batch_size=2, code="struct_genotyping:/data/struct/CNG-NA12878-softmax-indels")
    output_size = problem.output_size("softmaxGenotype")
    parser = define_train_auto_encoder_parser()
    # Parse an explicit empty argument list: parse_args() with no argument would read the
    # test runner's own sys.argv and reject its arguments.
    args = parser.parse_args([])
    model = StructGenotypingModel(args, sbi_mapper, mapped_features_size, output_size,
                                  use_cuda=False, use_batching=True)
    # The model was built with use_cuda=False, so the mapping must stay on the CPU too.
    print(model.map_sbi_messages(sbi_records=[record] * 2, cuda=False))
def collect_inputs(self, sample, phase=0, tensor_cache=NoCache(), cuda=None, batcher=None):
    """Collect input data for all counts in this sample.

    Phases 0 and 1 delegate to ``self.count_mapper.collect_inputs`` for each observed
    count (at most ``self.num_counts`` of them) and store the returned indices in that
    count's message.

    Phase 2 does the following:
    - fetches the batched forward result for each observed count;
    - pads the list with zero tensors up to ``self.num_counts`` entries;
    - concatenates the list along the last dimension;
    - passes the result to ``self.reduce_counts.collect_inputs``, storing the returned
      indices in ``sample``.

    :return: an empty list for phases 0-2 (indices live in the messages); None otherwise.
    """
    if 'indices' not in sample:
        sample['indices'] = {}
    observed_counts = self.get_observed_counts(sample)[0:self.num_counts]
    if phase < 2:
        for count in observed_counts:
            store_indices_in_message(mapper=self.count_mapper, message=count,
                                     indices=self.count_mapper.collect_inputs(
                                         count, tensor_cache=tensor_cache, cuda=cuda,
                                         phase=phase, batcher=batcher))
        return []
    if phase == 2:
        list_mapped_counts = []
        for count in observed_counts:
            mapped_count = batcher.get_forward_for_example(
                self.count_mapper, example_indices=count['indices'][id(self.count_mapper)])
            list_mapped_counts += [mapped_count]
        # NOTE(review): padding copies the size of the first mapped count, so a sample with
        # no observed counts would raise IndexError here — confirm that cannot happen.
        while len(list_mapped_counts) < self.num_counts:
            # pad the list with zeros:
            variable = Variable(torch.zeros(*list_mapped_counts[0].size()), requires_grad=True)
            if cuda:
                # `async` is a reserved word since Python 3.7 (the old spelling is a
                # SyntaxError); PyTorch >= 0.4 names this argument `non_blocking`.
                variable = variable.cuda(non_blocking=True)
            list_mapped_counts += [variable]
        cat_list_mapped_counts = torch.cat(list_mapped_counts, dim=-1)
        store_indices_in_message(mapper=self.reduce_counts, message=sample,
                                 indices=self.reduce_counts.collect_inputs(
                                     cat_list_mapped_counts, tensor_cache=tensor_cache,
                                     cuda=cuda, phase=phase, batcher=batcher))
        return []
def create_struct_model(self, problem, args, use_cuda):
    """Build a StructGenotypingModel for ``problem`` from parsed ``args``.

    :param problem: supplies the size of the "softmaxGenotype" output.
    :param args: parsed arguments. The struct_ploidy, struct_extra_genotypes,
        struct_count_dim and struct_sample_dim fields configure the mappers, and
        ``args.use_batching`` is forwarded to the model.
    :param use_cuda: when true, the mappers are configured for and moved to the GPU, and
        the model is created with CUDA enabled.
    :return: the new model, which is also printed.
    """
    sbi_mappers_configuration = configure_mappers(
        ploidy=args.struct_ploidy, extra_genotypes=args.struct_extra_genotypes, num_samples=1,
        count_dim=args.struct_count_dim, sample_dim=args.struct_sample_dim, use_cuda=use_cuda)
    sbi_mapper = BatchOfInstances(*sbi_mappers_configuration)
    # determine feature size by mapping the module-level example record sbi_json_string:
    import ujson
    record = ujson.loads(sbi_json_string)
    # Use the use_cuda argument throughout: the mappers above were configured with it,
    # so moving them (or building the model) based on a different flag would be inconsistent.
    if use_cuda:
        sbi_mapper.cuda()
    mapped_features_size = sbi_mapper([record], tensor_cache=NoCache(), cuda=use_cuda).size(1)
    output_size = problem.output_size("softmaxGenotype")
    model = StructGenotypingModel(args, sbi_mapper, mapped_features_size, output_size,
                                  use_cuda, args.use_batching)
    print(model)
    return model
def collect_inputs(self, nwf_list, phase=0, tensor_cache=NoCache(), cuda=None, batcher=None):
    """Queue the numbers and frequencies of ``nwf_list`` for batched mapping.

    In phase 0, resets ``nwf_list['indices']`` and stores, in it, the indices returned by
    ``self.map_number.collect_inputs`` (for ``nwf_list['number']``) and by
    ``self.map_frequency.collect_inputs`` (for ``nwf_list['frequency']``). Other phases
    do nothing.

    NOTE(review): this indexes ``nwf_list`` with string keys, while ``forward`` iterates it
    as a list of NumberWithFrequency dicts — confirm which structure callers pass here.

    :return: an empty list; the indices are kept in the message.
    """
    if phase == 0:
        # the following tensors are batched:
        nwf_list['indices'] = {}
        store_indices_in_message(mapper=self.map_number, message=nwf_list,
                                 indices=self.map_number.collect_inputs(
                                     values=nwf_list['number'], phase=phase, cuda=cuda,
                                     batcher=batcher))
        store_indices_in_message(mapper=self.map_frequency, message=nwf_list,
                                 indices=self.map_frequency.collect_inputs(
                                     values=nwf_list['frequency'], phase=phase, cuda=cuda,
                                     batcher=batcher))
    # Return a list like the other collect_inputs implementations, so callers that
    # accumulate results with `+=` do not fail on None.
    return []
def test_mapper(self):
    """Map a single BaseInformation record with freshly configured mappers and print it."""
    sbi_json = '{"type":"BaseInformation","referenceBase":"A","genomicSequenceContext":"GCAGATATACTTCACAGCCCACGCTGACTCTGCCAAGCACA","samples":[{"type":"SampleInfo","counts":[{"type":"CountInfo","matchesReference":true,"isCalled":true,"isIndel":false,"fromSequence":"A","toSequence":"A","genotypeCountForwardStrand":7,"genotypeCountReverseStrand":32,"gobyGenotypeIndex":0},{"type":"CountInfo","matchesReference":false,"isCalled":false,"isIndel":false,"fromSequence":"A","toSequence":"C","genotypeCountForwardStrand":0,"genotypeCountReverseStrand":1,"gobyGenotypeIndex":2},{"type":"CountInfo","matchesReference":false,"isCalled":false,"isIndel":false,"fromSequence":"A","toSequence":"T","genotypeCountForwardStrand":0,"genotypeCountReverseStrand":0,"gobyGenotypeIndex":1},{"type":"CountInfo","matchesReference":false,"isCalled":false,"isIndel":false,"fromSequence":"A","toSequence":"G","genotypeCountForwardStrand":0,"genotypeCountReverseStrand":0,"gobyGenotypeIndex":3},{"type":"CountInfo","matchesReference":false,"isCalled":false,"isIndel":false,"fromSequence":"A","toSequence":"N","genotypeCountForwardStrand":0,"genotypeCountReverseStrand":0,"gobyGenotypeIndex":4}]}]}'
    import ujson
    parsed_record = ujson.loads(sbi_json)
    print(parsed_record)
    configured_mappers, configured_modules = configure_mappers(ploidy=2, extra_genotypes=3,
                                                               num_samples=1)
    batch_mapper = BatchOfInstances(mappers=configured_mappers, all_modules=configured_modules)
    mapped_output = batch_mapper([parsed_record], tensor_cache=NoCache())
    print(mapped_output)
def cat_inputs(self, mapper, list_of_values, tensor_cache=NoCache(), phase=0, cuda=False,
               direct_forward=False):
    """Map each value in ``list_of_values`` with ``mapper`` and concatenate along dim 1.

    :param direct_forward: when true, each value goes through ``mapper(...)`` directly;
        otherwise through ``mapper.collect_inputs(...)`` for the given ``phase``.
    :return: ``torch.cat`` of the per-value results, in input order, along dimension 1.
    """
    def map_one(value):
        if direct_forward:
            return mapper(value, tensor_cache=tensor_cache, cuda=cuda)
        return mapper.collect_inputs(value, tensor_cache=tensor_cache, phase=phase, cuda=cuda)

    return torch.cat([map_one(value) for value in list_of_values], dim=1)
def test_mapper_number_freqs(self):
    """Map a BaseInformation record whose counts carry NumberWithFrequency lists, and print it."""
    # The JSON document is split into adjacent string literals, which Python concatenates
    # into a single string at compile time.
    sbi_json = (
        '{"type":"BaseInformation","referenceBase":"A","genomicSequenceContext":"GCAGATATACTTCACAGCCCACGCTGACTCTGCCAAGCACA","samples":[{"type":"SampleInfo","counts":[{"type":"CountInfo","matchesReference":true,"isCalled":true,"isIndel":false,"fromSequence":"A","toSequence":"A","genotypeCountForwardStrand":7,"genotypeCountReverseStrand":32,"gobyGenotypeIndex":0,"qualityScoresForwardStrand":[{"type":"NumberWithFrequency","frequency":7,"number":40}],"qualityScoresReverseStrand":[{"type":"NumberWithFrequency","frequency":32,"number":40}],"readIndicesForwardStrand":[{"type":"NumberWithFrequency","frequency":1,"number":23},{"type":"NumberWithFrequency","frequency":1,"number":30},{"type":"NumberWithFrequency","frequency":5,"number":34}],"readIndicesReverseStrand":[{"type":"NumberWithFrequency","frequency":1,"number":6},{"type":"NumberWithFrequency","frequency":1,"number":22},{"type":"NumberWithFrequency","frequency":1,"number":28},{"type":"NumberWithFrequency","frequency":1,"number":31},{"type":"NumberWithFrequency","frequency":1,"number":34},{"type":"NumberWithFrequency","frequency":1,"number":35},{"type":"NumberWithFrequency","frequency":1,"number":44},{"type":"NumberWithFrequency","frequency":1,"number":145},{"type":"NumberWithFrequency","frequency":1,"number":150},{"type":"NumberWithFrequency","frequency":5,"number":151},{"type":"NumberWithFrequency","frequency":2,"number":171},{"type":"NumberWithFrequency","frequency":4,"number":172}],"queryPositions":[{"type":"NumberWithFrequency","frequency":39,"number":0}],"pairFlags":[{"type":"NumberWithFrequency","frequency":6,"number":16},{"type":"NumberWithFrequency","frequency":14,"number":83},{"type":"NumberWithFrequency","frequency":6,"number":99},{"type":"NumberWithFrequency","frequency":12,"number":147},{"type":"NumberWithFrequency","frequency":1,"number":163}],"distancesToReadVariationsForwardStrand":[{"type":"NumberWithFrequency","frequency":2,"number":-70},{"type":"NumberWithFrequ'
        'ency","frequency":4,"number":-29}],"distancesToReadVariationsReverseStrand":[{"type":"NumberWithFrequency","frequency":2,"number":-24},{"type":"NumberWithFrequency","frequency":1,"number":-15},{"type":"NumberWithFrequency","frequency":1,"number":-2},{"type":"NumberWithFrequency","frequency":1,"number":12},{"type":"NumberWithFrequency","frequency":1,"number":13},{"type":"NumberWithFrequency","frequency":1,"number":15},{"type":"NumberWithFrequency","frequency":13,"number":29},{"type":"NumberWithFrequency","frequency":1,"number":49},{"type":"NumberWithFrequency","frequency":3,"number":62},{"type":"NumberWithFrequency","frequency":9,"number":70},{"type":"NumberWithFrequency","frequency":1,"number":73}],"distanceToStartOfRead":[{"type":"NumberWithFrequency","frequency":1,"number":18},{"type":"NumberWithFrequency","frequency":1,"number":23},{"type":"NumberWithFrequency","frequency":1,"number":26},{"type":"NumberWithFrequency","frequency":1,"number":30},{"type":"NumberWithFrequency","frequency":30,"number":33},{"type":"NumberWithFrequency","frequency":5,"number":34}],"distanceToEndOfRead":[{"type":"NumberWithFrequency","frequency":1,"number":6},{"type":"NumberWithFrequency","frequency":1,"number":22},{"type":"NumberWithFrequency","frequency":1,"number":28},{"type":"NumberWithFrequency","frequency":1,"number":31},{"type":"NumberWithFrequency","frequency":1,"number":34},{"type":"NumberWithFrequency","frequency":2,"number":35},{"type":"NumberWithFrequency","frequency":1,"number":44},{"type":"NumberWithFrequency","frequency":1,"number":48},{"type":"NumberWithFrequency","frequency":1,"number":49},{"type":"NumberWithFrequency","frequency":1,"number":50},{"type":"NumberWithFrequency","frequency":1,"number":52},{"type":"NumberWithFrequency","frequency":1,"number":62},{"type":"NumberWithFrequency","frequency":1,"number":63},{"type":"NumberWithFrequency","frequency":1,"number":68},{"type":"NumberWithFrequency","frequency":2,"number":75},{"type":"NumberWithFrequency","frequency":2,"n'
        'umber":76},{"type":"NumberWithFrequency","frequency":1,"number":81},{"type":"NumberWithFrequency","frequency":1,"number":83},{"type":"NumberWithFrequency","frequency":1,"number":88},{"type":"NumberWithFrequency","frequency":1,"number":89},{"type":"NumberWithFrequency","frequency":1,"number":100},{"type":"NumberWithFrequency","frequency":1,"number":104},{"type":"NumberWithFrequency","frequency":1,"number":109},{"type":"NumberWithFrequency","frequency":1,"number":111},{"type":"NumberWithFrequency","frequency":1,"number":117},{"type":"NumberWithFrequency","frequency":1,"number":118},{"type":"NumberWithFrequency","frequency":1,"number":121},{"type":"NumberWithFrequency","frequency":1,"number":125},{"type":"NumberWithFrequency","frequency":1,"number":128},{"type":"NumberWithFrequency","frequency":1,"number":133},{"type":"NumberWithFrequency","frequency":2,"number":138},{"type":"NumberWithFrequency","frequency":4,"number":139}]},{"type":"CountInfo","matchesReference":false,"isCalled":false,"isIndel":false,"fromSequence":"A","toSequence":"C","genotypeCountForwardStrand":0,"genotypeCountReverseStrand":1,"gobyGenotypeIndex":2,"qualityScoresForwardStrand":[],"qualityScoresReverseStrand":[{"type":"NumberWithFrequency","frequency":1,"number":7}],"readIndicesForwardStrand":[],"readIndicesReverseStrand":[{"type":"NumberWithFrequency","frequency":1,"number":115}],"readMappingQualityForwardStrand":[],"readMappingQualityReverseStrand":[{"type":"NumberWithFrequency","frequency":1,"number":60}],"numVariationsInReads":[{"type":"NumberWithFrequency","frequency":1,"number":2}],"insertSizes":[{"type":"NumberWithFrequency","frequency":1,"number":-301}],"targetAlignedLengths":[{"type":"NumberWithFrequency","frequency":2,"number":148}],"queryAlignedLengths":[{"type":"NumberWithFrequency","frequency":1,"number":148}],"queryPositions":[{"type":"NumberWithFrequency","frequency":1,"number":0}],"pairFlags":[{"type":"NumberWithFrequency","frequency":1,"number":147}],"distancesToReadVariationsForwa'
        'rdStrand":[],"distancesToReadVariationsReverseStrand":[{"type":"NumberWithFrequency","frequency":1,"number":-29},{"type":"NumberWithFrequency","frequency":1,"number":0}],"distanceToStartOfRead":[{"type":"NumberWithFrequency","frequency":1,"number":33}],"distanceToEndOfRead":[{"type":"NumberWithFrequency","frequency":1,"number":115}]},{"type":"CountInfo","matchesReference":false,"isCalled":false,"isIndel":false,"fromSequence":"A","toSequence":"T","genotypeCountForwardStrand":0,"genotypeCountReverseStrand":0,"gobyGenotypeIndex":1,"qualityScoresForwardStrand":[],"qualityScoresReverseStrand":[],"readIndicesForwardStrand":[],"readIndicesReverseStrand":[],"readMappingQualityForwardStrand":[],"readMappingQualityReverseStrand":[],"numVariationsInReads":[],"insertSizes":[],"targetAlignedLengths":[],"queryAlignedLengths":[],"queryPositions":[],"pairFlags":[],"distancesToReadVariationsForwardStrand":[],"distancesToReadVariationsReverseStrand":[],"distanceToStartOfRead":[],"distanceToEndOfRead":[]},{"type":"CountInfo","matchesReference":false,"isCalled":false,"isIndel":false,"fromSequence":"A","toSequence":"G","genotypeCountForwardStrand":0,"genotypeCountReverseStrand":0,"gobyGenotypeIndex":3,"qualityScoresForwardStrand":[],"qualityScoresReverseStrand":[],"readIndicesForwardStrand":[],"readIndicesReverseStrand":[],"readMappingQualityForwardStrand":[],"readMappingQualityReverseStrand":[],"numVariationsInReads":[],"insertSizes":[],"targetAlignedLengths":[],"queryAlignedLengths":[],"queryPositions":[],"pairFlags":[],"distancesToReadVariationsForwardStrand":[],"distancesToReadVariationsReverseStrand":[],"distanceToStartOfRead":[],"distanceToEndOfRead":[]},{"type":"CountInfo","matchesReference":false,"isCalled":false,"isIndel":false,"fromSequence":"A","toSequence":"N","genotypeCountForwardStrand":0,"genotypeCountReverseStrand":0,"gobyGenotypeIndex":4,"qualityScoresForwardStrand":[],"qualityScoresReverseStrand":[],"readIndicesForwardStrand":[],"readIndicesReverseStrand":[],"readMappi'
        'ngQualityForwardStrand":[],"readMappingQualityReverseStrand":[],"numVariationsInReads":[],"insertSizes":[],"targetAlignedLengths":[],"queryAlignedLengths":[],"queryPositions":[],"pairFlags":[],"distancesToReadVariationsForwardStrand":[],"distancesToReadVariationsReverseStrand":[],"distanceToStartOfRead":[],"distanceToEndOfRead":[]}]}]}'
    )
    import ujson
    parsed_record = ujson.loads(sbi_json)
    print(parsed_record)
    configured_mappers, configured_modules = configure_mappers(ploidy=2, extra_genotypes=3,
                                                               num_samples=1)
    batch_mapper = BatchOfInstances(mappers=configured_mappers, all_modules=configured_modules)
    mapped_output = batch_mapper([parsed_record], tensor_cache=NoCache())
    print(mapped_output)
def forward(self, input, tensor_cache=NoCache(), cuda=None):
    """Map a BaseInformation message to a single tensor via ``self.reduce_samples``.

    The list handed to ``reduce_samples`` holds, in order:
    - the mapped reference base;
    - the mapped genomic sequence context;
    - for each of the first ``self.num_samples`` samples, the dim-1 concatenation of the
      mapped sample and the mapped ``fromSequence`` of that sample's first count.

    :param cuda: when None, taken from the device of the first module parameter.
    """
    if cuda is None:
        cuda = next(self.parameters()).data.is_cuda

    def map_seq(sequence):
        return self.map_sequence(sequence, tensor_cache=tensor_cache, cuda=cuda)

    mapped_parts = [map_seq(input['referenceBase']), map_seq(input['genomicSequenceContext'])]
    # each sample has a unique from (same across all counts), which gets mapped and
    # concatenated with the sample mapped reduction:
    for sample in input['samples'][0:self.num_samples]:
        mapped_sample = self.sample_mapper(sample, tensor_cache=tensor_cache, cuda=cuda)
        mapped_from = map_seq(sample['counts'][0]['fromSequence'])
        mapped_parts.append(torch.cat([mapped_sample, mapped_from], dim=1))
    return self.reduce_samples(mapped_parts, cuda)
def forward(self, nwf_list, tensor_cache=NoCache(), cuda=None, nf_name="unknown"):
    """Map a list of NumberWithFrequency messages to a single tensor.

    Numbers and frequencies are mapped separately and concatenated along dim 1. The
    concatenation is reduced with ``self.mean_sequence`` when the module-level flag
    ``use_mean_to_map_nwf`` is true, and with ``self.map_sequence`` otherwise.
    An empty list maps to a zero tensor of shape (1, self.embedding_size) that requires grad.

    :param nf_name: not used by this implementation.
    """
    if len(nwf_list) == 0:
        variable = Variable(torch.zeros(1, self.embedding_size), requires_grad=True)
        if cuda:
            # `async` is a reserved word since Python 3.7 (the old spelling is a
            # SyntaxError); PyTorch >= 0.4 names this argument `non_blocking`.
            variable = variable.cuda(non_blocking=True)
        return variable
    mapped_numbers = self.map_number([nwf['number'] for nwf in nwf_list], tensor_cache, cuda)
    # NOTE(review): map_frequency receives cuda as its second positional argument, whereas
    # map_number receives (values, tensor_cache, cuda) — confirm map_frequency's signature.
    mapped_freqs = self.map_frequency([nwf['frequency'] for nwf in nwf_list], cuda)
    mapped_frequencies = torch.cat([mapped_numbers, mapped_freqs], dim=1)
    if use_mean_to_map_nwf:
        return self.mean_sequence(mapped_frequencies, cuda=cuda)
    return self.map_sequence(mapped_frequencies, cuda=cuda)
def collect_inputs(self, c, phase=0, tensor_cache=NoCache(), cuda=None, batcher=None):
    """Collect batched inputs for the CountInfo message ``c``.

    Phase 0 does the following, then returns [].
    - Resets ``c['indices']``.
    - Queues gobyGenotypeIndex, the forward/reverse genotype counts and the
      isIndel/matchesReference flags with their batched mappers, storing the returned
      indices in ``c``.
    - Maps fromSequence/toSequence directly (not batched) and keeps the result in
      ``c['mapped-not-batched']``.

    Phase 1 does the following, returning what ``batcher.store_inputs`` returns.
    - Fetches this count's rows from the batched forward results and flattens each to
      one row.
    - Combines them with ``self.reduce_batched``.
    - Stores the combination as this mapper's batch input.

    Any other phase returns None.
    """
    if phase == 0:
        c['indices'] = {}
        # the following tensors are batched:
        store_indices_in_message(mapper=self.map_gobyGenotypeIndex, message=c,
                                 indices=self.map_gobyGenotypeIndex.collect_inputs(
                                     values=[c['gobyGenotypeIndex']], phase=phase, cuda=cuda,
                                     batcher=batcher))
        store_indices_in_message(mapper=self.map_count, message=c,
                                 indices=self.map_count.collect_inputs(
                                     values=[c['genotypeCountForwardStrand'],
                                             c['genotypeCountReverseStrand']],
                                     phase=phase, cuda=cuda, batcher=batcher))
        store_indices_in_message(mapper=self.map_boolean, message=c,
                                 indices=self.map_boolean.collect_inputs(
                                     values=[c['isIndel'], c['matchesReference']],
                                     tensor_cache=tensor_cache, phase=phase, cuda=cuda,
                                     batcher=batcher))
        c['mapped-not-batched'] = {}
        # the following tensors are not batched, but computed once per instance, right here
        # with direct_forward=True:
        c['mapped-not-batched'][id(self.map_sequence)] = self.cat_inputs(
            self.map_sequence, [c['fromSequence'], c['toSequence']], tensor_cache=tensor_cache,
            phase=phase, cuda=cuda, direct_forward=True)
        return []
    if phase == 1:
        def batched_row(mapper):
            # store_indices_in_message kept this count's indices under c['indices'][id(mapper)]
            # in phase 0; get_forward_for_example takes them as `example_indices`, as at its
            # other call sites.
            return batcher.get_forward_for_example(
                mapper=mapper, example_indices=c['indices'][id(mapper)]).view(1, -1)

        mapped_goby_genotype_indices = batched_row(self.map_gobyGenotypeIndex)
        mapped_counts = batched_row(self.map_count)
        mapped_booleans = batched_row(self.map_boolean)
        # mapped_sequences are not currently batchable, so we get the input from the prior phase:
        mapped_sequences = c['mapped-not-batched'][id(self.map_sequence)].view(1, -1)
        all_mapped = [mapped_goby_genotype_indices, mapped_counts, mapped_booleans,
                      mapped_sequences]
        return batcher.store_inputs(mapper=self, inputs=self.reduce_batched(all_mapped))
def collect_inputs(self, message, phase=0, tensor_cache=NoCache(), cuda=None, batcher=None):
    """Two-phase batched input collection for a Message with integer fields 'a' and 'b'.

    Phase 0: resets ``message['indices']`` and queues 'a' with ``self.map_a`` and 'b' with
    ``self.map_b``, recording their batch indices in the message; returns [].
    Phase 1: adds the batched outputs for 'a' and 'b' and stores the sum as this mapper's
    batch input, returning what ``batcher.store_inputs`` returns.
    """
    assert isinstance(message, Message)
    if phase == 0:
        message['indices'] = {}
        for field_mapper, field in ((self.map_a, 'a'), (self.map_b, 'b')):
            field_indices = field_mapper.collect_inputs(
                values=[message[field]], phase=phase, tensor_cache=tensor_cache, cuda=cuda,
                batcher=batcher)
            store_indices_in_message(mapper=field_mapper, message=message, indices=field_indices)
        # no indices yet for this mapper.
        return []
    if phase == 1:
        forward_a = batcher.get_forward_for_example(
            mapper=self.map_a,
            example_indices=get_indices_in_message(mapper=self.map_a, message=message))
        forward_b = batcher.get_forward_for_example(
            mapper=self.map_b,
            example_indices=get_indices_in_message(mapper=self.map_b, message=message))
        summed = forward_a + forward_b
        return batcher.store_inputs(mapper=self, inputs=summed)
def test_integer_batching(self):
    """Batched IntegerModel outputs must equal mapping each integer on its own."""
    int_mapper = IntegerModel(100, 2)
    int_batcher = Batcher()
    list_of_ints = [12, 3, 2]
    no_cache = NoCache()
    collected_indices = []
    for number in list_of_ints:
        collected_indices += int_batcher.collect_inputs(int_mapper, [number])
    print(int_batcher.get_batched_input(int_mapper))
    print(int_batcher.forward_batch(int_mapper))
    for index in collected_indices:
        print(int_batcher.get_forward_for_example(int_mapper, index))
    # The i-th queued integer must map to the i-th row of the batch.
    for position, number in enumerate(list_of_ints):
        expected = str(int_mapper([number], no_cache).data)
        actual = str(int_batcher.get_forward_for_example(int_mapper, position).data)
        self.assertEqual(expected, actual)
def map_sbi_messages(self, sbi_records, tensor_cache=NoCache(), cuda=None):
    """Map a mini-batch of SBI records to features.

    :param sbi_records: records whose 'samples' are mapped (batched path) or which are
        mapped as a whole by ``self.sbi_mapper`` (non-batched path).
    :param tensor_cache: cache forwarded to the batched mappers. The non-batched path uses
        a fresh TensorCache per call instead.
    :param cuda: whether tensors go to the GPU; defaults to ``self.use_cuda`` when None.
    :return: with ``self.use_batching``, the phase-2 result of ``batcher.forward_batch`` for
        the "SampleInfo" mapper; otherwise the output of ``self.sbi_mapper``.
    """
    if cuda is None:
        # Resolve the device flag once so the batched and non-batched paths agree and an
        # explicit `cuda` argument is honoured by both.
        cuda = self.use_cuda
    if not self.use_batching:
        # Create a new cache for each mini-batch because we cache embeddings:
        return self.sbi_mapper(sbi_records, tensor_cache=TensorCache(), cuda=cuda)

    # NOTE(review): the batched path maps only the SampleInfo messages of each record, while
    # the non-batched path maps whole records — confirm both yield compatible features.
    batcher = Batcher()
    mapper = self.sbi_mapper.mappers.mapper_for_type("SampleInfo")
    features = None
    for phase in [0, 1, 2]:
        for record in sbi_records:
            for sample in record['samples']:
                batcher.collect_inputs(mapper=mapper, example=sample, phase=phase, cuda=cuda,
                                       tensor_cache=tensor_cache)
        features = batcher.forward_batch(mapper=mapper, phase=phase)
    return features
def test_boolean_batching(self):
    """Batched map_Boolean outputs must equal mapping each boolean on its own."""
    bool_mapper = map_Boolean()
    bool_batcher = Batcher()
    list_of_bools = [True, True, False]
    no_cache = NoCache()
    collected_indices = []
    for flag in list_of_bools:
        collected_indices += bool_batcher.collect_inputs(bool_mapper, flag)
    print("batched input={}".format(bool_batcher.get_batched_input(bool_mapper)))
    print(bool_batcher.forward_batch(bool_mapper))
    for index in collected_indices:
        row = bool_batcher.get_forward_for_example(bool_mapper, index)
        print("example {} = {}".format(index, row))
    # The i-th queued boolean must map to the i-th row of the batch.
    for position, flag in enumerate(list_of_bools):
        expected = str(bool_mapper(flag, no_cache).data)
        actual = str(bool_batcher.get_forward_for_example(bool_mapper, position).data)
        self.assertEqual(expected, actual)
def collect_inputs(self, mapper, example, phase=0, tensor_cache=NoCache(), cuda=None):
    """Use the mapper on an example to obtain inputs for a batch.

    Initializes this batcher's bookkeeping for ``mapper`` and records ``mapper`` in
    ``self.all_mappers``. It then delegates to
    ``mapper.collect_inputs(example, phase=..., batcher=self, ...)``; the mapper queues its
    inputs on this batcher.

    :param mapper: the mapper whose inputs are collected.
    :param example: the message/value handed to ``mapper.collect_inputs``.
    :param phase: collection phase forwarded to the mapper.
    :return: whatever ``mapper.collect_inputs`` returns. At the visible call sites this is a
        list of example indices into the batch (possibly empty when the mapper keeps its
        indices in the example itself), not a dictionary.
    """
    id_mapper = id(mapper)
    self.initialize_mapper_variables(id_mapper)
    # keep track of all mappers used:
    self.all_mappers.update([mapper])
    input_indices_in_batch = mapper.collect_inputs(
        example, phase=phase, batcher=self, tensor_cache=tensor_cache, cuda=cuda)
    return input_indices_in_batch
def forward(self, sequence_field, tensor_cache=NoCache(), cuda=None):
    """Map a sequence of bases to a single tensor.

    Each element of ``sequence_field`` is converted to an index via ``self.base_to_index``.
    The indices are embedded with ``self.map_bases``, and the embeddings are reduced by
    ``self.map_sequence``.
    """
    base_indices = [self.base_to_index[base] for base in sequence_field]
    embedded_bases = self.map_bases(base_indices, tensor_cache=tensor_cache, cuda=cuda)
    return self.map_sequence(embedded_bases, cuda)
def test_message(self):
    """Batched mapping of a two-field message must equal mapping each message directly.

    Defines a minimal dict-backed Message and a MapMessage embedding (sum of two
    IntegerModel embeddings), runs the two-phase batching protocol on two messages, and
    compares each batched result with the direct forward pass.
    """
    torch.manual_seed(1212)

    class Message:
        def __init__(self, a, b):
            self.dict = {'type': "Message", 'a': a, 'b': b}

        def __getitem__(self, key):
            return self.dict[key]

        def __setitem__(self, key, value):
            self.dict[key] = value

    def store_indices_in_message(mapper, message, indices):
        message['indices'][id(mapper)] = indices

    def get_indices_in_message(mapper, message):
        return message['indices'][id(mapper)]

    class MapMessage(StructuredEmbedding):
        def __init__(self):
            super().__init__(embedding_size=2)
            self.map_a = IntegerModel(distinct_numbers=100, embedding_size=2)
            self.map_b = IntegerModel(distinct_numbers=100, embedding_size=2)

        def forward(self, message, tensor_cache, cuda=None):
            mapped_a = self.map_a([message['a']], tensor_cache, cuda)
            mapped_b = self.map_b([message['b']], tensor_cache, cuda)
            return mapped_a + mapped_b

        def collect_inputs(self, message, phase=0, tensor_cache=NoCache(), cuda=None,
                           batcher=None):
            assert isinstance(message, Message)
            if phase == 0:
                message['indices'] = {}
                for field_mapper, field in ((self.map_a, 'a'), (self.map_b, 'b')):
                    field_indices = field_mapper.collect_inputs(
                        values=[message[field]], phase=phase, tensor_cache=tensor_cache,
                        cuda=cuda, batcher=batcher)
                    store_indices_in_message(mapper=field_mapper, message=message,
                                             indices=field_indices)
                # no indices yet for this mapper.
                return []
            if phase == 1:
                forward_a = batcher.get_forward_for_example(
                    mapper=self.map_a,
                    example_indices=get_indices_in_message(mapper=self.map_a, message=message))
                forward_b = batcher.get_forward_for_example(
                    mapper=self.map_b,
                    example_indices=get_indices_in_message(mapper=self.map_b, message=message))
                return batcher.store_inputs(mapper=self, inputs=(forward_a + forward_b))

        def forward_batch(self, batcher, phase=0):
            if phase == 0:
                # do forward for batches, results are kept in batcher.
                batched_a = batcher.forward_batch(mapper=self.map_a)
                batched_b = batcher.forward_batch(mapper=self.map_b)
                return (batched_a, batched_b)
            if phase == 1:
                return batcher.get_batched_input(mapper=self)

    messages = [Message(12, 4), Message(1, 6)]
    batcher = Batcher()
    mapper = MapMessage()
    no_cache = NoCache()

    # Phase 0: queue the integer fields, then run the batched IntegerModel forwards.
    for message in messages:
        batcher.collect_inputs(mapper, message)
    batcher.forward_batch(mapper=mapper, phase=0)
    # Phase 1: queue each message's summed embedding and remember where it was stored.
    for message in messages:
        message.dict['indices'][id(mapper)] = batcher.collect_inputs(mapper, message, phase=1)
    batcher.forward_batch(mapper=mapper, phase=1)

    def batched_result(message):
        return batcher.get_forward_for_example(
            mapper, example_indices=message.dict['indices'][id(mapper)])

    for message in messages:
        print(batched_result(message))
    for message in messages:
        self.assertEqual(str(mapper(message, no_cache).data), str(batched_result(message).data))
def collect_inputs(self, values, phase=0, tensor_cache=NoCache(), cuda=None, batcher=None):
    """Queue ``values`` as this mapper's batch inputs during phase 0.

    The values are converted with ``self.convert_list_of_floats`` before being stored.

    :return: what ``batcher.store_inputs`` returns in phase 0; an empty list otherwise.
    """
    if phase != 0:
        return []
    converted_values = self.convert_list_of_floats(values)
    return batcher.store_inputs(mapper=self, inputs=converted_values)