def test_RepeaterPipe():
    """Exercise pl.RepeaterPipe through both ways of consuming it.

    The pipe is first drained with a ``for`` loop over a source built by
    ``one_way_iter_dummy()``, then drained by calling ``__next__`` directly
    over a source built by ``one_way_dummy()``. In both cases the collected
    Message must have length ``repetitions * 20``.

    NOTE(review): the factor 20 presumably is the number of elements a single
    pass over a dummy source yields -- confirm against the dummy definitions.
    """
    # --- Consumption through the iterator protocol (for loop) ---
    source = one_way_iter_dummy()
    repeater = pl.RepeaterPipe(source)
    collected = Message()
    assert len(collected) == 0
    for batch in repeater:
        collected = collected.append(batch)
    assert len(collected) == repeater.repetitions * 20

    # --- Consumption through explicit __next__ calls after a reset ---
    source = one_way_dummy()
    repeater = pl.RepeaterPipe(source)
    collected = Message()
    repeater.reset()
    assert len(collected) == 0
    pulls = 0
    while True:
        try:
            collected = collected.append(repeater.__next__())
            pulls += 1
            # If something goes horribly wrong (the pipe never raises
            # StopIteration), cancel the test instead of looping forever.
            assert pulls <= 1000
        except StopIteration:
            break
    assert len(collected) == repeater.repetitions * 20
def test_append():
    """Check ``append`` on Message and TensorMessage for several operand types.

    Covers appending to an empty Message, appending Messages to each other,
    and appending raw dicts / DataFrames / TensorMessages, comparing each
    result against an explicitly constructed expected message.

    NOTE(review): the module-level fixtures ``tensors`` and ``vectors`` are
    assumed to hold three rows each ('a'/'b' tensors and 'c'/'d' arrays) --
    that is what the expected values below imply; confirm against the fixtures.
    """
    tensor_part = tensors
    vector_part = vectors

    def doubled_tensors():
        # Built fresh on every call so no expected value shares storage
        # with another one.
        return {
            'a': torch.Tensor([1, 2, 3, 1, 2, 3]),
            'b': torch.Tensor([4, 5, 6, 4, 5, 6]),
        }

    def doubled_vectors():
        # Same idea as doubled_tensors, for the numpy-backed columns.
        return {
            'c': np.array([7, 8, 9, 7, 8, 9]),
            'd': np.array([10, 11, 12, 10, 11, 12]),
        }

    both_first = Message(tensor_part, vector_part)
    both_second = Message(tensor_part, vector_part)
    tensors_only = Message(tensor_part)
    tensor_message = TensorMessage(tensor_part)
    frame_message = Message(pd.DataFrame(vector_part))
    frame = pd.DataFrame(vector_part)
    empty = Message()

    # Appending to an empty message should reproduce the appended message.
    assert len(empty) == 0
    for parts in ((tensor_part,), (vector_part,), (tensor_part, vector_part)):
        result = empty.append(Message(*parts))
        assert result == Message(*parts)

    # Two full messages (tensors + vectors) appended together.
    result = both_first.append(both_second)
    assert len(result) == 6
    assert result == Message(doubled_tensors(), doubled_vectors())

    # Tensor-only Message extended by a dict, a Message and a TensorMessage.
    for other in (tensor_part, tensors_only, tensor_message):
        result = tensors_only.append(other)
        assert len(result) == 6
        assert result == Message(doubled_tensors())

    # TensorMessage extended by the same three operand kinds.
    for other in (tensor_part, tensors_only, tensor_message):
        result = tensor_message.append(other)
        assert len(result) == 6
        assert result == TensorMessage(doubled_tensors())

    # DataFrame-backed Message extended by a dict, a Message and a DataFrame.
    for other in (vector_part, frame_message, frame):
        result = frame_message.append(other)
        assert len(result) == 6
        assert result == Message(doubled_vectors())

    # Test type conversions on appending to TensorMessage
    result = tensor_message.append({'a': np.array([42]), 'b': np.array([24])})
    assert len(result) == 4
    assert result == TensorMessage({
        'a': torch.Tensor([1, 2, 3, 42]),
        'b': torch.Tensor([4, 5, 6, 24]),
    })
def test_junction():
    """Drain a RandomHubJunction and a ClockworkHubJunction over three inputs.

    For both junction types the test asserts that every input ends with
    ``count == 21``, that 60 elements were collected in total, and that each
    value 0..19 of the 'count' column appears exactly three times (once per
    input). For the clockwork junction it additionally asserts the ordering:
    the element at position ``i`` carries count ``floor(i / 3)``.
    """
    dumbo = one_way_dummy()
    bumbo = one_way_dummy()
    gumbo = one_way_dummy()

    def make_components():
        # A fresh mapping per junction, so the two junctions never share
        # the same dict object.
        return {
            'dumbo': dumbo,
            'bumbo': bumbo,
            'gumbo': gumbo
        }

    def occurrences(values, target):
        # Counts how often target appears in values
        return sum(1 for value in values if value == target)

    # --- RandomHubJunction, consumed through explicit __next__ calls ---
    random_hub = jn.RandomHubJunction(components=make_components())
    random_hub.reset()
    assert random_hub._available_inputs == {'dumbo', 'bumbo', 'gumbo'}

    collected = Message()
    while True:
        try:
            item = random_hub.__next__()
            collected = collected.append(item)
        except StopIteration:
            break

    assert dumbo.count == 21
    assert bumbo.count == 21
    assert gumbo.count == 21
    assert len(collected) == 60
    tally = {value: occurrences(collected['count'], value) for value in range(20)}
    for seen in tally.values():
        # Make sure each element showed up 3 times, corresponding to the 3 inputs
        assert seen == 3

    # --- ClockworkHubJunction, consumed through a for loop ---
    clockwork_hub = jn.ClockworkHubJunction(components=make_components())
    ordered = Message()
    for item in clockwork_hub:
        ordered = ordered.append(item)

    assert dumbo.count == 21
    assert bumbo.count == 21
    assert gumbo.count == 21
    assert len(ordered) == 60
    # Position i must hold count floor(i / 3): every count value is emitted
    # three times in a row before the next one appears.
    for position, row in enumerate(ordered):
        assert row['count'][0] == position // 3
    tally = {value: occurrences(ordered['count'], value) for value in range(20)}
    for seen in tally.values():
        # Make sure each element showed up 3 times, corresponding to the 3 inputs
        assert seen == 3
class ModelSaverMetric(Metric):
    """Metric that records snapshots of ``engine.state.output['state']``.

    Every ``log_interval`` iterations a deep copy of the state is wrapped in
    a Message, tagged with the iteration index, and appended to
    ``self.model_state``. ``compute`` returns the last recorded snapshot.
    """

    def __init__(self, output_transform=lambda x: x, log_interval=100):
        """Set up the snapshot store.

        Args:
            output_transform: forwarded unchanged to ``Metric.__init__``.
            log_interval: a snapshot is taken whenever the zero-based
                iteration index is a multiple of this value.
        """
        # The store is created before the base initialiser runs and
        # log_interval after it, so the attribute setup order is unchanged.
        self.model_state = Message()
        Metric.__init__(self, output_transform=output_transform)
        self.log_interval = log_interval

    def iteration_completed(self, engine):
        """Snapshot the engine's output state on every ``log_interval``-th step."""
        # engine.state.iteration is shifted down by one before the interval
        # check (presumably because the engine counts from 1 -- confirm).
        step = engine.state.iteration - 1
        if step % self.log_interval != 0:
            return

        # Deep copy so later changes to the live state cannot alter the
        # stored snapshot.
        snapshot = Message.from_objects(
            deepcopy(engine.state.output['state']))
        snapshot['iteration'] = [step]
        self.model_state = self.model_state.append(snapshot)

    def compute(self):
        """Return the most recent model state."""
        last_index = len(self.model_state) - 1
        return self.model_state[last_index]

    def reset(self):
        """Do nothing; recorded snapshots are kept across resets."""

    def update(self, output):
        """Do nothing; snapshots are taken in ``iteration_completed`` instead."""