def test_incompatible_arrays(self):
    """Test with incompatible arrays."""
    rescont = MessageResultContainer()
    msgs = [{"a": np.zeros((10, 3))}, {"a": np.zeros((10, 4))}]
    for msg in msgs:
        rescont.add_message(msg)
    pytest.raises(ValueError, rescont.get_message)
def test_mixed_dict(self):
    """Test msg being a dict containing an array."""
    rescont = MessageResultContainer()
    msg1 = {
        "f": 2,
        "a": np.zeros((10, 3), 'int'),
        "b": "aaa",
        "c": 1,
    }
    msg2 = {
        "a": np.ones((15, 3), 'int'),
        "b": "bbb",
        "c": 3,
        "d": 1,
    }
    rescont.add_message(msg1)
    rescont.add_message(msg2)
    combined_msg = rescont.get_message()
    a = np.zeros((25, 3), 'int')
    a[10:] = 1
    reference_msg = {"a": a, "c": 4, "b": "aaabbb", "d": 1, "f": 2}
    # Compare the array entry separately, then the remaining dict entries.
    assert np.all(combined_msg["a"] == reference_msg["a"])
    combined_msg.pop("a")
    reference_msg.pop("a")
    assert combined_msg == reference_msg
def test_none_msg(self):
    """Test with one message being None."""
    rescont = MessageResultContainer()
    msgs = [None, {"a": 1}, None, {"a": 2, "b": 1}, None]
    for msg in msgs:
        rescont.add_message(msg)
    msg = rescont.get_message()
    assert msg == {"a": 3, "b": 1}
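# A minimal usage sketch (an addition, not part of the original test suite)
# summarizing the merge rules the tests above exercise. It assumes that
# ``MessageResultContainer`` and ``np`` (numpy) are importable exactly as in
# the surrounding tests.
def _example_message_merge():
    rescont = MessageResultContainer()
    # Arrays sharing a key are concatenated along the first axis, so their
    # remaining dimensions must match (see test_incompatible_arrays).
    rescont.add_message({"x": np.zeros((2, 2)), "n": 1, "s": "ab"})
    # Non-array values sharing a key are combined with ``+`` (numbers sum,
    # strings concatenate); keys seen only once are kept unchanged.
    rescont.add_message({"x": np.ones((3, 2)), "n": 2, "s": "cd"})
    merged = rescont.get_message()
    assert merged["x"].shape == (5, 2)
    assert merged["n"] == 3
    assert merged["s"] == "abcd"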
def use_results(self, results):
    """Use the results from the scheduler.

    During parallel training this starts the next training phase. During
    parallel execution this returns the result, like a normal execute
    would. In addition it joins any forked nodes.

    results -- Iterable containing the results, normally the return value
        of scheduler.ResultContainer.get_results(). The individual results
        can be the return values of the tasks.
    """
    if self.is_parallel_training:
        for result in results:
            self._flownode.join(result)
        # perform local stop_training with result check
        self._stop_training_hook()
        result = self._flownode.stop_training(
            self._stop_messages[self._i_train_node])
        self._post_stop_training_hook()
        if result is not None:
            target = result[2]
            # values of +1, -1 and EXIT_TARGET are tolerated
            if target not in [1, -1, EXIT_TARGET]:
                err = ("Target node not found in flow during "
                       "stop_training phase, last result: " + str(result))
                raise BiFlowException(err)
        self._flownode.bi_reset()
        if self.verbose:
            print("finished parallel training phase of node no. "
                  "%d in parallel flow" % (self._i_train_node + 1))
        if not self.flow[self._i_train_node].is_training():
            self._i_train_node += 1
        self._next_train_phase()
    elif self.is_parallel_executing:
        self._exec_data_iterator = None
        self._exec_msg_iterator = None
        self._exec_target_iterator = None
        y_results = []
        msg_results = MessageResultContainer()
        # use internal flownode to join all biflownodes
        self._flownode = BiFlowNode(BiFlow(self.flow))
        for result_tuple in results:
            result, forked_biflownode = result_tuple
            # consolidate results
            if isinstance(result, tuple) and (len(result) == 2):
                y, msg = result
                msg_results.add_message(msg)
            else:
                y = result
            if y is not None:
                try:
                    y_results.append(y)
                except AttributeError:
                    # y_results was set to None by an earlier None result,
                    # so appending raises AttributeError
                    err = "Some but not all y return values were None."
                    raise BiFlowException(err)
            else:
                y_results = None
            # join biflownode
            if forked_biflownode is not None:
                self._flownode.join(forked_biflownode)
        # return results
        if y_results is not None:
            y_results = n.concatenate(y_results)
        return (y_results, msg_results.get_message())
    else:
        err = "It seems that there are no results to retrieve."
        raise BiFlowException(err)
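# A standalone sketch (an illustrative addition, not MDP/bimdp API) of the
# consolidation performed in the parallel-execution branch above. It shows
# the expected shape of ``results``: a sequence of ``(result, forked_node)``
# pairs where each ``result`` is either ``y`` or ``(y, msg)``. It assumes
# ``MessageResultContainer`` is importable and numpy is available as ``np``.
def _consolidate_execution_results(results):
    y_results = []
    msg_results = MessageResultContainer()
    for result, _forked_node in results:
        if isinstance(result, tuple) and len(result) == 2:
            # task returned (y, msg): merge msg into the shared container
            y, msg = result
            msg_results.add_message(msg)
        else:
            y = result
        if y is None:
            # one None return forces all returns to be None
            y_results = None
        elif y_results is None:
            raise ValueError("Some but not all y return values were None.")
        else:
            y_results.append(y)
    if y_results is not None:
        # stack the per-task outputs along the sample axis
        y_results = np.concatenate(y_results)
    return y_results, msg_results.get_message()

# Example: two task results, the second one carrying a message.
#     y, msg = _consolidate_execution_results(
#         [(np.zeros((2, 3)), None), ((np.ones((4, 3)), {"n": 1}), None)])
#     # y.shape == (6, 3), msg == {"n": 1}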