def loss_lookahead_diff(model: NeuralTeleportationModel, data: Tensor, target: Tensor,
                        metrics: TrainingMetrics, config: OptimalTeleportationTrainingConfig,
                        **kwargs) -> Number:
    # Save the state of the model, prior to performing the lookahead
    state_dict = model.state_dict()

    # Initialize a new optimizer to perform lookahead
    optimizer = get_optimizer_from_model_and_config(model, config)
    optimizer.zero_grad()

    # Compute loss at the teleported point
    loss = torch.stack([metrics.criterion(model(data_batch), target_batch)
                        for data_batch, target_batch in zip(data, target)]).mean(dim=0)

    # Take a step using the gradient at the teleported point
    loss.backward()
    optimizer.step()

    # Compute loss after the optimizer step
    lookahead_loss = torch.stack([metrics.criterion(model(data_batch), target_batch)
                                  for data_batch, target_batch in zip(data, target)]).mean(dim=0)

    # Restore the state of the model prior to the lookahead
    model.load_state_dict(state_dict)

    # Compute the difference between the lookahead loss and the original loss
    return (loss - lookahead_loss).item()
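
# Hedged usage sketch (not from the original source): `loss_lookahead_diff` can score
# candidate teleportations of the same network -- the larger the drop in loss after one
# lookahead optimizer step, the more promising the teleported point. The names
# `candidate_models`, `batched_data` and `batched_targets` are hypothetical; the latter
# two are sequences of batches, matching the zip() over batches above.
def pick_best_teleportation(candidate_models, batched_data, batched_targets, metrics, config):
    scores = [loss_lookahead_diff(m, batched_data, batched_targets, metrics, config)
              for m in candidate_models]
    # Keep the candidate whose lookahead step improves the loss the most
    return candidate_models[scores.index(max(scores))]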
def train_epoch(model: nn.Module, metrics: TrainingMetrics, optimizer: Optimizer,
                train_loader: DataLoader, epoch: int, device: str = 'cpu',
                progress_bar: bool = True, config: TrainingConfig = None,
                lr_scheduler=None) -> None:
    lr_scheduler_interval = None
    if config is not None and config.lr_scheduler is not None:
        lr_scheduler_interval = config.lr_scheduler[1]

    # Init data structures to keep track of the metrics at each batch
    metrics_by_batch = {metric.__name__: [] for metric in metrics.metrics}
    metrics_by_batch.update(loss=[])

    model.train()
    pbar = tqdm(enumerate(train_loader))
    for batch_idx, (data, target) in pbar:
        if config is not None and batch_idx == config.max_batch:
            break
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = metrics.criterion(output, target)

        metrics_by_batch["loss"].append(loss.item())
        for metric in metrics.metrics:
            metrics_by_batch[metric.__name__].append(metric(output, target))

        loss.backward()
        optimizer.step()

        if progress_bar:
            status = 'Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, (batch_idx + 1) * train_loader.batch_size, len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item())
            pbar.set_postfix_str(status)

        if lr_scheduler and lr_scheduler_interval == "step":
            lr_scheduler.step()
        pbar.update()
    pbar.close()

    # Log the mean of each metric at the end of the epoch
    if config is not None and config.logger is not None:
        reduced_metrics = {metric: mean(values_by_batch)
                           for metric, values_by_batch in metrics_by_batch.items()}
        config.logger.log_metrics(reduced_metrics, epoch=epoch)
        for metric_name, value in reduced_metrics.items():
            config.logger.add_scalar(metric_name, value, epoch)
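
# Hedged usage sketch (not part of the original source): running a single epoch with the
# loop above. Assumes `TrainingConfig`/`TrainingMetrics` expose the fields `train_epoch`
# reads (device, max_batch, logger, criterion, metrics); the SGD settings are illustrative.
def example_train_one_epoch(model, train_loader, metrics, config):
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
    train_epoch(model, metrics, optimizer, train_loader, epoch=1,
                device=config.device, config=config)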
def test(model: nn.Module, dataset: Dataset, metrics: TrainingMetrics, config: TrainingConfig,
         eval_mode: bool = True) -> Dict[str, Any]:
    test_loader = DataLoader(dataset, batch_size=config.batch_size)
    if eval_mode:
        model.eval()
    results = defaultdict(list)
    pbar = tqdm(enumerate(test_loader))
    with torch.no_grad():
        for i, (data, target) in pbar:
            if i == config.max_batch:
                break
            data, target = data.to(config.device), target.to(config.device)
            output = model(data)
            results['loss'].append(metrics.criterion(output, target).item())

            # `metrics` is a required argument (its criterion is already used above),
            # so the per-batch metrics can be computed unconditionally
            batch_results = compute_metrics(metrics.metrics, y=target, y_hat=output,
                                            to_tensor=False)
            for k in batch_results.keys():
                results[k].append(batch_results[k])

            pbar.update()
            pbar.set_postfix(loss=pd.DataFrame(results['loss']).mean().values,
                             accuracy=pd.DataFrame(results['accuracy']).mean().values)
    pbar.close()

    reduced_results = dict(pd.DataFrame(results).mean())
    if config.logger is not None:
        config.logger.log_metrics(reduced_results, epoch=0)
    return reduced_results
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torch.nn.modules import Flatten
from torchvision.datasets import MNIST

# Repo-internal imports (paths assumed from the project layout)
from neuralteleportation.metrics import accuracy
from neuralteleportation.training.config import TrainingConfig, TrainingMetrics
from neuralteleportation.training.training import train, test

mnist_train = MNIST('/tmp', train=True, download=True, transform=transforms.ToTensor())
mnist_val = MNIST('/tmp', train=False, download=True, transform=transforms.ToTensor())
mnist_test = MNIST('/tmp', train=False, download=True, transform=transforms.ToTensor())

model = torch.nn.Sequential(Flatten(),
                            nn.Linear(784, 128),
                            nn.ReLU(),
                            nn.Linear(128, 10))

config = TrainingConfig()
metrics = TrainingMetrics(nn.CrossEntropyLoss(), [accuracy])

train(model, train_dataset=mnist_train, metrics=metrics, config=config, val_dataset=mnist_val)
print(test(model, mnist_test, metrics, config))
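
# Hedged follow-up sketch (assumptions: the `NeuralTeleportationModel` wrapper lives at
# the path below and takes `network` and `input_shape`; the cob_range and sampling
# strategy are illustrative, mirroring their positional use elsewhere in this repo).
# Teleportation preserves the network function, so the results printed here should
# match the un-teleported results above.
from neuralteleportation.neuralteleportationmodel import NeuralTeleportationModel

tele_model = NeuralTeleportationModel(network=model, input_shape=(config.batch_size, 1, 28, 28))
tele_model.random_teleport(0.5, 'intra_landscape')
print(test(tele_model, mnist_test, metrics, config))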
type=str, default="resnet18COB", choices=get_model_names()) return parser.parse_args() if __name__ == '__main__': args = argument_parser() device = 'cuda' if cuda_avail() else 'cpu' trainset, valset, testset = get_dataset_subsets("cifar10") model = get_model("cifar10", args.model, device=device) metric = TrainingMetrics(criterion=nn.CrossEntropyLoss(), metrics=[accuracy]) config = LandscapeConfig(optimizer=(args.optimizer, { "lr": args.lr }), epochs=args.epochs, batch_size=args.batch_size, cob_range=args.cob_range, cob_sampling=args.cob_sampling, teleport_at=[args.epochs], device=device) if args.train: train(model, trainset, metric, config) a = torch.linspace(args.x[0], args.x[1], int(args.x[2])) param_o = model.get_params() model.random_teleport(args.cob_range, args.cob_sampling) param_t = model.get_params()
def run_experiment(config_path: Path, out_root: Path, data_root_dir: Path = None,
                   save_weights=False, enable_comet=False) -> None:
    with open(str(config_path), 'r') as stream:
        config = yaml.safe_load(stream)

    # Setup metrics to compute
    metrics = TrainingMetrics(nn.CrossEntropyLoss(), [accuracy, accuracy_top5])

    # Get training params
    all_training_params = config["training_params"] if isinstance(config["training_params"], list) \
        else [config["training_params"]]

    # datasets
    for dataset_name in config["datasets"]:
        dataset_kwargs = {}
        if data_root_dir is not None:
            dataset_kwargs.update(root=data_root_dir, download=False)
        train_set, val_set, test_set = get_dataset_subsets(dataset_name, **dataset_kwargs)

        for training_params in all_training_params:

            # models
            for model_obj in config["models"]:
                model_obj = copy.deepcopy(model_obj)
                model_kwargs = {}
                model_name = model_obj
                if not isinstance(model_obj, str):
                    model_name = model_obj.pop("cls")
                    model_kwargs = model_obj

                # initializers
                for initializer in config["initializers"]:
                    config['initializer'] = initializer

                    # optimizers
                    for optimizer_kwargs in config["optimizers"]:
                        optimizer_kwargs = copy.deepcopy(optimizer_kwargs)
                        optimizer_name = optimizer_kwargs.pop("cls")

                        lr_scheduler_kwargs = optimizer_kwargs.pop("lr_scheduler", None)
                        has_scheduler = False
                        if lr_scheduler_kwargs:
                            lr_scheduler_name = lr_scheduler_kwargs.pop("cls")
                            lr_scheduler_interval = lr_scheduler_kwargs.pop("interval", "epoch")
                            if "lr_lambda" in lr_scheduler_kwargs.keys():
                                # WARNING: Take care of what you pass in as lr_lambda,
                                # as the string is directly evaluated.
                                # This is needed to transform lambda functions defined as
                                # strings into a python callable
                                lr_scheduler_kwargs["lr_lambda"] = eval(lr_scheduler_kwargs.pop("lr_lambda"))
                            if "steps_per_epoch" in lr_scheduler_kwargs.keys():
                                steps = len(train_set) / training_params['batch_size']
                                lr_scheduler_kwargs['steps_per_epoch'] = \
                                    math.floor(steps) if training_params['drop_last_batch'] else math.ceil(steps)
                            has_scheduler = True

                        # teleport configuration
                        for teleport, teleport_config_kwargs in config["teleportations"].items():

                            # w/o teleport configuration
                            if teleport == "no_teleport":
                                training_config_cls = __training_configs__["no_teleport"]
                                # Ensure config collections are iterable, even if no config was defined
                                # This is done to simplify the generation of the configuration matrix
                                teleport_config_kwargs, teleport_mode_configs = {}, [(training_config_cls, {})]

                            # w/ teleport configuration
                            else:  # teleport == "teleport"
                                # Copy the config to play around with its content without
                                # affecting the config loaded in memory
                                teleport_config_kwargs = copy.deepcopy(teleport_config_kwargs)
                                teleport_mode_obj = teleport_config_kwargs.pop("mode")

                                teleport_mode_configs = []
                                for teleport_mode, teleport_mode_config_kwargs in teleport_mode_obj.items():
                                    training_config_cls = __training_configs__[teleport_mode]
                                    if teleport_mode == "optim":
                                        teleport_mode_config_kwargs["optim_metric"] = [
                                            getattr(teleport_optim, metric)
                                            for metric in teleport_mode_config_kwargs.pop("metric")
                                        ]
                                    # Ensure config collections are iterable, even if no config was defined
                                    # This is done to simplify the generation of the configuration matrix
                                    if teleport_mode_config_kwargs is None:
                                        teleport_mode_config_kwargs = {}
                                    for teleport_mode_single_config_kwargs in \
                                            dict_values_product(teleport_mode_config_kwargs):
                                        teleport_mode_configs.append(
                                            (training_config_cls, teleport_mode_single_config_kwargs))

                            # generate matrix of training configuration
                            # (cartesian product of values for each training config kwarg)
                            teleport_configs = dict_values_product(teleport_config_kwargs)
                            config_matrix = itertools.product(teleport_configs, teleport_mode_configs)

                            # Iterate over different possible training configurations
                            for teleport_config_kwargs, (training_config_cls, teleport_mode_config_kwargs) \
                                    in config_matrix:
                                num_runs = int(config["runs_per_config"]) if "runs_per_config" in config.keys() else 1
                                for _ in range(num_runs):
                                    experiment_path, experiment_id = make_experiment(out_root)
                                    if enable_comet:
                                        logger = MultiLogger([DiskLogger(experiment_path),
                                                              CometLogger(experiment_id)])
                                    else:
                                        logger = DiskLogger(experiment_path)

                                    training_config = training_config_cls(
                                        optimizer=(optimizer_name, optimizer_kwargs),
                                        lr_scheduler=(lr_scheduler_name, lr_scheduler_interval,
                                                      lr_scheduler_kwargs) if has_scheduler else None,
                                        device='cuda' if cuda_avail() else 'cpu',
                                        logger=logger,
                                        **training_params,
                                        **teleport_config_kwargs,
                                        **teleport_mode_config_kwargs,
                                    )

                                    # Run experiment (setting up a new model and optimizer
                                    # for each experiment)
                                    model = get_model(dataset_name, model_name,
                                                      device=training_config.device,
                                                      initializer=initializer, **model_kwargs)
                                    optimizer = getattr(optim, optimizer_name)(model.parameters(),
                                                                               **optimizer_kwargs)
                                    lr_scheduler = None
                                    if has_scheduler:
                                        lr_scheduler = getattr(optim.lr_scheduler, lr_scheduler_name)(
                                            optimizer, **lr_scheduler_kwargs)
                                    run_model(model, training_config, metrics, train_set, test_set,
                                              val_set=val_set, optimizer=optimizer,
                                              lr_scheduler=lr_scheduler)

                                    if save_weights:
                                        torch.save(model.state_dict(), experiment_path / 'weights.pt')