def test_auto_neb(self):
    """Run the complete AutoNEB procedure between graph nodes 1 and 2.

    Every entry of ``self.minima`` is added as a node of a ``MultiGraph``,
    numbered from 1.

    Four NEB cycles are scheduled:
    - one cycle inserting pivots with ``equal`` (count 2) at lr 0.1;
    - three cycles inserting pivots with ``highest`` on ``dense_train_loss``
      (count 3), at lr 0.1 and then twice at lr 0.01.

    After the run, the number of graph edges must equal the cycle count.
    """
    graph = MultiGraph()
    for node_id, minimum in enumerate(self.minima, start=1):
        graph.add_node(node_id, **minimum)

    # Settings shared by every cycle of the AutoNEB schedule
    spring_constant = float("inf")
    weight_decay = 0
    subsample_pivot_count = 1
    eval_config = EvalConfig(128)
    coarse_optim = OptimConfig(10, SGD, {"lr": 0.1}, None, None, eval_config)
    fine_optim = OptimConfig(10, SGD, {"lr": 0.01}, None, None, eval_config)

    first_cycle = NEBConfig(spring_constant, weight_decay, equal, {"count": 2},
                            subsample_pivot_count, coarse_optim)
    # A fresh NEBConfig (and fresh insert-args dict) is built for every cycle.
    refinement_cycles = [
        NEBConfig(spring_constant, weight_decay, highest,
                  {"count": 3, "key": "dense_train_loss"},
                  subsample_pivot_count, optim)
        for optim in (coarse_optim, fine_optim, fine_optim)
    ]
    neb_configs = [first_cycle, *refinement_cycles]

    auto_neb_config = AutoNEBConfig(neb_configs)
    self.assertEqual(auto_neb_config.cycle_count, len(neb_configs))

    # Run AutoNEB between nodes 1 and 2
    auto_neb(1, 2, graph, self.model, auto_neb_config)
    self.assertEqual(len(graph.edges), auto_neb_config.cycle_count)
def test_neb(self):
    """Run a short NEB cycle between the first two minima and sanity-check it.

    The cycle uses 10 Adam steps and inserts 3 pivots with ``equal``. Every
    expected key must be present in the result and free of NaN values. Saddle
    values are printed for manual inspection.
    """
    minima = self.minima[:2]
    neb_eval_config = EvalConfig(128)
    neb_optim_config = OptimConfig(10, Adam, {}, None, None, neb_eval_config)
    neb_config = NEBConfig(float("inf"), 1e-5, equal, {"count": 3}, 1, neb_optim_config)
    result = neb({
        "path_coords": torch.cat([m["coords"].view(1, -1) for m in minima]),
        "target_distances": torch.ones(1)
    }, self.model, neb_config)

    required_keys = [
        "path_coords",
        "target_distances",
        "saddle_train_error",
        "saddle_train_loss",
        "saddle_test_error",
        "saddle_test_loss",
        "dense_train_error",
        "dense_train_loss",
        "dense_test_error",
        "dense_test_loss",
    ]
    for key in required_keys:
        # assertIn reports the missing key and the container on failure
        self.assertIn(key, result, f"{key} not in result")
        value = result[key]
        self.assertFalse(torch.isnan(value).any().item(), f"{key} contains a NaN value")
        if key.startswith("saddle_"):
            # assumes saddle_* entries are single-element tensors (.item()) — TODO confirm
            print(key, value.item())
def test_long_run(self):
    """Run a long NEB cycle (1000 SGD steps, 20 inserted pivots) on ``Eggcarton(2)``.

    Two minima are obtained from two calls to ``find_minimum``, then a single
    NEB cycle is run between them.

    Previously nothing was asserted, so only crashes were caught. The result
    must now contain a ``path_coords`` tensor without NaN values.
    """
    eggcarton = Eggcarton(2)
    model = ModelWrapper(eggcarton)
    minima = [find_minimum(model, OptimConfig(1000, SGD, {"lr": 0.1}, None, None, None)) for _ in range(2)]
    neb_optim_config = OptimConfig(1000, SGD, {"lr": 0.1}, None, None, None)
    neb_config = NEBConfig(float("inf"), 1e-5, equal, {"count": 20}, 1, neb_optim_config)
    result = neb({
        "path_coords": torch.cat([m["coords"].view(1, -1) for m in minima]),
        "target_distances": torch.ones(1)
    }, model, neb_config)

    # A diverging long run (e.g. NaN coordinates) should fail this test.
    self.assertIn("path_coords", result)
    self.assertFalse(torch.isnan(result["path_coords"]).any().item(),
                     "path_coords contains a NaN value")
def neb(previous_cycle_data, model: models.ModelWrapper, neb_config: config.NEBConfig) -> dict:
    """Run one Nudged Elastic Band (NEB) optimisation cycle.

    The starting chain is produced by ``neb_config.insert_method`` from
    ``previous_cycle_data``. It is wrapped in a ``neb_model.NEB`` module and
    optimised for ``neb_config.optim_config.nsteps`` steps with the configured
    optimiser and optional LR scheduler.

    Args:
        previous_cycle_data: Data of the previous cycle, passed unchanged to
            ``neb_config.insert_method``. Callers in this file pass a dict with
            "path_coords" and "target_distances" tensors.
        model: The wrapped model whose loss landscape the path lives in.
        neb_config: Pivot insertion method and arguments, optimiser and
            scheduler configuration, and ``subsample_pivot_count`` for the
            final analysis.

    Returns:
        A dict with:
        - "path_coords": a CPU copy of the optimised path coordinates;
        - "target_distances": the target distances, on CPU;
        - every entry returned by ``neb_mod.analyse``.

    Raises:
        AssertionError: (when assertions are enabled) if the optimiser is
            configured with non-zero ``weight_decay``. Weight decay must be
            set on NEB instead.
    """
    # Initialise chain by inserting pivots
    start_path, target_distances = neb_config.insert_method(previous_cycle_data, **neb_config.insert_args)

    # Model
    neb_mod = neb_model.NEB(model, start_path, target_distances)
    neb_mod.adapt_to_config(neb_config)

    # Load optimiser
    optim_config = neb_config.optim_config
    # HACK: Optimisers only like parameters registered to autograd -> proper solution would keep several model instances as path and nudge their gradients after backward.
    neb_mod.path_coords.requires_grad_(True)
    optimiser = optim_config.algorithm_type(neb_mod.parameters(), **optim_config.algorithm_args)  # type: optim.Optimizer
    # HACK END: We don't want autograd to mingle with our computations
    neb_mod.path_coords.requires_grad_(False)
    # Weight decay on the optimiser would act on the whole path, so it must be
    # configured on NEB instead.
    assert optimiser.defaults.get("weight_decay", 0) == 0, \
        "NEB is not compatible with weight decay on the optimiser. Set weight decay on NEB instead."

    # Scheduler
    if optim_config.scheduler_type is not None:
        scheduler = optim_config.scheduler_type(optimiser, **optim_config.scheduler_args)
    else:
        scheduler = None

    # Optimise. The scheduler must be stepped AFTER the optimiser. Since
    # PyTorch 1.1, stepping it first skips the first value of the LR
    # schedule and triggers a warning.
    for _ in helper.pbar(range(optim_config.nsteps), "NEB"):
        # presumably computes the NEB update and stores it as the gradient of
        # path_coords — TODO confirm in neb_model.NEB.apply
        neb_mod.apply(gradient=True)
        optimiser.step()
        if scheduler is not None:
            scheduler.step()

    result = {
        "path_coords": neb_mod.path_coords.clone().to("cpu"),
        "target_distances": target_distances.to("cpu")
    }

    # Analyse
    analysis = neb_mod.analyse(neb_config.subsample_pivot_count)
    saddle_analysis = {key: value for key, value in analysis.items() if "saddle_" in key}
    logger.debug(f"Found saddle: {saddle_analysis}.")
    result.update(analysis)
    return result