def main(run_dir, epoch, device, attack, epsilon, min_end_time, max_end_time,
         tol, batches, pgd_step_size, pgd_num_steps, pgd_random_start, _log):
    """Evaluate a saved checkpoint on adversarially perturbed test batches.

    The model and dataloaders are rebuilt from the ``model_ingredient`` and
    ``data_ingredient`` sources imported from ``run_dir``. The weights for
    ``epoch`` are then loaded, and the test loader is wrapped so every batch
    is attacked before being scored by ``validate``.

    Args:
        run_dir: Directory of the training run; passed to ``read_config``,
            ``import_source`` and ``get_model_path``.
        epoch: Checkpoint selector passed to ``get_model_path``.
        device: Forwarded to both factories. Also used as ``map_location``
            when the model is not wrapped in ``nn.DataParallel``.
        attack: Key into ``ATTACKS``. The value ``'pgd'`` also receives the
            ``pgd_*`` settings.
        epsilon: Perturbation budget handed to every attack.
        min_end_time, max_end_time: Copied onto ``model.odeblock`` when the
            model has that attribute.
        tol: Copied onto both ``atol`` and ``rtol`` of ``model.odeblock``.
        batches: Batch limit passed to ``TruncIterator``.
        pgd_step_size, pgd_num_steps, pgd_random_start: Extra attack keyword
            arguments, used only when ``attack == 'pgd'``.
        _log: Logger (presumably injected by Sacred; verify).

    Returns:
        The summary string with test loss and accuracy. It is also logged.
    """
    config = read_config(run_dir)
    _log.info(f"Read config from {run_dir}")

    model_source = import_source(run_dir, "model_ingredient")
    model_kwargs = {**config['model'], 'device': device}
    model = model_source.make_model(**model_kwargs, _log=_log)

    checkpoint = get_model_path(run_dir, epoch)
    # A DataParallel wrapper keeps the real network under ``.module``. That
    # path loads without ``map_location``; the plain path maps onto ``device``.
    wrapped = isinstance(model, nn.DataParallel)
    network = model.module if wrapped else model
    load_kwargs = {} if wrapped else {'map_location': device}
    network.load_state_dict(torch.load(checkpoint, **load_kwargs))
    model = model.eval()
    _log.info(f"Loaded state dict from {checkpoint}")

    if hasattr(model, "odeblock"):
        _log.info(f"Updated times to {[min_end_time, max_end_time]}")
        block = model.odeblock
        block.min_end_time, block.max_end_time = min_end_time, max_end_time
        block.atol = block.rtol = tol

    data_source = import_source(run_dir, "data_ingredient")
    data_kwargs = {**config['dataset'], 'device': device}
    _, _, _, test_loader = data_source.make_dataloaders(**data_kwargs, _log=_log)

    # Every attack receives epsilon. PGD additionally takes its step settings.
    attack_kwargs = {'epsilon': epsilon}
    if attack == 'pgd':
        attack_kwargs.update(
            step_size=pgd_step_size,
            num_steps=pgd_num_steps,
            random_start=pgd_random_start,
        )
    attack_fn = partial(ATTACKS[attack], **attack_kwargs)

    adv_test_loader = TruncIterator(
        adv.AdversarialLoader(model, test_loader, attack_fn), batches
    )

    _log.info("Testing model...")
    test_loss, test_acc = validate(model, adv_test_loader)
    summary = f"Test loss = {test_loss:.6f}, Test accuracy = {test_acc:.4f}"
    _log.info(summary)
    return summary
def main(run_dir, epoch, device, attack, epsilon, end_time_start, end_time_end,
         num_times, tol, batches, _log):
    """Sweep the ODE block's end time and evaluate adversarial accuracy at each.

    The model and dataloaders are rebuilt from the ``model_ingredient`` and
    ``data_ingredient`` sources imported from ``run_dir``, and the weights for
    ``epoch`` are loaded. For each of ``num_times`` evenly spaced end times in
    ``[end_time_start, end_time_end]`` (``np.linspace``), the block:

    - pins both end-time bounds of ``odeblock`` to that value,
    - runs ``validate`` on attacked test batches,
    - reports the result through ``ex.log_scalar`` and ``_log``.

    Args:
        run_dir: Directory of the training run; passed to ``read_config``,
            ``import_source`` and ``get_model_path``.
        epoch: Checkpoint selector passed to ``get_model_path``.
        device: Forwarded to both factories. Also used as ``map_location``
            when the model is not wrapped in ``nn.DataParallel``.
        attack: Key into ``ATTACKS``.
        epsilon: Perturbation budget handed to the attack.
        end_time_start, end_time_end, num_times: Arguments to ``np.linspace``
            defining the swept end times.
        tol: Copied onto both ``atol`` and ``rtol`` of ``odeblock``.
        batches: Batch limit passed to ``TruncIterator``.
        _log: Logger (presumably injected by Sacred; verify).

    Raises:
        ValueError: If the model (after unwrapping ``nn.DataParallel``) has no
            ``odeblock``; the sweep has nothing to vary without one.
    """
    config = read_config(run_dir)
    _log.info(f"Read config from {run_dir}")
    model_ing = import_source(run_dir, "model_ingredient")
    model = model_ing.make_model(**{**config['model'], 'device': device}, _log=_log)

    # nn.DataParallel stores the real network under ``.module`` and does not
    # expose its submodules as attributes. Unwrap it before looking up
    # ``odeblock``.
    net = model.module if isinstance(model, nn.DataParallel) else model
    if not hasattr(net, "odeblock"):
        # Fail before the (potentially slow) checkpoint and data loading.
        raise ValueError(
            f"Model built from {run_dir} has no 'odeblock'; "
            "the end-time sweep requires an ODE model"
        )

    path = get_model_path(run_dir, epoch)
    if isinstance(model, nn.DataParallel):
        model.module.load_state_dict(torch.load(path))
    else:
        model.load_state_dict(torch.load(path, map_location=device))
    model = model.eval()  # eval() returns the same module, so ``net`` stays valid
    _log.info(f"Loaded state dict from {path}")

    odeblock = net.odeblock
    odeblock.atol = tol
    odeblock.rtol = tol

    data_ing = import_source(run_dir, "data_ingredient")
    _, _, _, test_loader = data_ing.make_dataloaders(
        **{**config['dataset'], 'device': device}, _log=_log
    )

    attack_fn = partial(ATTACKS[attack], epsilon=epsilon)
    adv_test_loader = adv.AdversarialLoader(model, test_loader, attack_fn)
    adv_test_loader = TruncIterator(adv_test_loader, batches)

    _log.info("Testing model...")
    # NOTE(review): the same loader object is re-iterated for every end time;
    # this assumes TruncIterator/AdversarialLoader restart on each iteration.
    # Confirm.
    for end_time in np.linspace(end_time_start, end_time_end, num_times):
        odeblock.min_end_time = end_time
        odeblock.max_end_time = end_time
        # Created on CPU as before; presumably odeblock moves ``t`` to the
        # input's device. Confirm.
        odeblock.t = torch.tensor([0, end_time]).float()
        test_loss, test_acc = validate(model, adv_test_loader)
        ex.log_scalar("test_loss", test_loss)
        ex.log_scalar("test_acc", test_acc)
        ex.log_scalar("end_time", end_time)
        _log.info(
            f"end_time={end_time:.6f}, Test loss = {test_loss:.6f}, Test accuracy = {test_acc:.4f}"
        )