def test_hdf5_load(self):
    """Check that a model saved to HDF5 and loaded back matches the original.

    Fits the type and state models briefly, saves them with ``save_models``,
    loads the file into a fresh ``Astir`` instance via ``load_model``, and
    compares the run info and the losses reported by both instances.

    Raises:
        AssertionError: If any run-info entry or any loss value differs
            between the original and the loaded model.
    """
    # Imported locally so this method does not rely on the module's
    # top-level import list.
    import os
    import tempfile

    orig_ast = Astir(self.expr, self.marker_dict)
    orig_ast.fit_type(max_epochs=5, n_init=1, n_init_epochs=1)
    orig_ast.fit_state(max_epochs=5, n_init=1, n_init_epochs=1)

    # Write the summary into a throwaway directory instead of the current
    # working directory, so the test leaves no file behind. The directory is
    # kept alive for the whole comparison in case the loaded model still
    # reads from the file.
    with tempfile.TemporaryDirectory() as tmp_dir:
        hdf5_summary = os.path.join(tmp_dir, "celltype_summary.hdf5")
        orig_ast.save_models(hdf5_summary)

        new_ast = Astir()
        new_ast.load_model(hdf5_summary)

        orig_type_run_info = orig_ast.get_type_run_info()
        orig_state_run_info = orig_ast.get_state_run_info()
        new_type_run_info = new_ast.get_type_run_info()
        new_state_run_info = new_ast.get_state_run_info()

        run_info_pairs = (
            (orig_type_run_info, new_type_run_info),
            (orig_state_run_info, new_state_run_info),
        )
        for orig_info, new_info in run_info_pairs:
            for key, val in orig_info.items():
                if val != new_info[key]:
                    raise AssertionError(
                        "variable "
                        + key
                        + " is different in original model and loaded model"
                    )

        orig_type_losses = orig_ast.get_type_losses()
        orig_state_losses = orig_ast.get_state_losses()
        new_type_losses = new_ast.get_type_losses()
        new_state_losses = new_ast.get_state_losses()

        # Element-wise comparison; every entry must match for both models.
        type_losses_match = all(orig_type_losses == new_type_losses)
        if not (
            type_losses_match
            and all(orig_state_losses == new_state_losses)
        ):
            raise AssertionError(
                "loss is different in original model and loaded model"
            )
def test_cellstate_diff_seed_diff_result(self):
    """Verify that different random seeds lead to different state losses.

    Two models are built on the same expression data and marker dictionary
    but with random seeds 42 and 1234. Each is fitted with ``fit_state`` for
    at most 5 epochs; the test fails if the last entries of the two loss
    histories differ by less than 1e-6.
    """
    warnings.filterwarnings("ignore", category=UserWarning)

    # Everything except the seed is shared between the two models.
    shared_kwargs = {
        "input_expr": self.expr,
        "marker_dict": self.marker_dict,
        "design": None,
    }
    seeded_first = Astir(random_seed=42, **shared_kwargs)
    seeded_second = Astir(random_seed=1234, **shared_kwargs)

    # Keep the fit order (first model, then second) unchanged.
    seeded_first.fit_state(max_epochs=5)
    first_losses = seeded_first.get_state_losses()
    seeded_second.fit_state(max_epochs=5)
    second_losses = seeded_second.get_state_losses()

    final_gap = np.abs(first_losses - second_losses)[-1]
    losses_match = final_gap < 1e-6
    self.assertFalse(losses_match)