Example #1
def loss_lookahead_diff(model: NeuralTeleportationModel, data: Tensor, target: Tensor,
                        metrics: TrainingMetrics, config: OptimalTeleportationTrainingConfig, **kwargs) -> Number:
    # Save a copy of the model's state prior to performing the lookahead
    # (values are cloned so the in-place optimizer update below does not overwrite them)
    state_dict = {k: v.clone() for k, v in model.state_dict().items()}

    # Initialize a new optimizer to perform lookahead
    optimizer = get_optimizer_from_model_and_config(model, config)
    optimizer.zero_grad()

    # Compute loss at the teleported point
    loss = torch.stack([metrics.criterion(model(data_batch), target_batch)
                        for data_batch, target_batch in zip(data, target)]).mean(dim=0)

    # Take a step using the gradient at the teleported point
    loss.backward()
    optimizer.step()

    # Compute loss after the optimizer step
    lookahead_loss = torch.stack([metrics.criterion(model(data_batch), target_batch)
                                  for data_batch, target_batch in zip(data, target)]).mean(dim=0)

    # Restore the state of the model prior to the lookahead
    model.load_state_dict(state_dict)

    # Compute the difference between the lookahead loss and the original loss
    return (loss - lookahead_loss).item()
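
The same lookahead computation can be reproduced with plain PyTorch, without the repository's classes. The sketch below is only an illustration of the logic above, using a throwaway linear model and random data as placeholders: it clones the weights, takes one optimizer step, measures the loss again, and restores the original weights.

import torch
import torch.nn as nn

model = nn.Linear(4, 2)
criterion = nn.CrossEntropyLoss()
data = torch.randn(8, 4)
target = torch.randint(0, 2, (8,))

# Clone the weights so the in-place optimizer update does not overwrite the saved copy
state_dict = {k: v.clone() for k, v in model.state_dict().items()}
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

optimizer.zero_grad()
loss = criterion(model(data), target)
loss.backward()
optimizer.step()  # one lookahead step

lookahead_loss = criterion(model(data), target)
model.load_state_dict(state_dict)  # restore the weights from before the lookahead

# A larger difference means the gradient step from this point decreased the loss more
print((loss - lookahead_loss).item())
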
Example #2
def train_epoch(model: nn.Module,
                metrics: TrainingMetrics,
                optimizer: Optimizer,
                train_loader: DataLoader,
                epoch: int,
                device: str = 'cpu',
                progress_bar: bool = True,
                config: TrainingConfig = None,
                lr_scheduler=None) -> None:
    lr_scheduler_interval = None
    if config is not None and config.lr_scheduler is not None:
        lr_scheduler_interval = config.lr_scheduler[1]

    # Init data structures to keep track of the metrics at each batch
    metrics_by_batch = {metric.__name__: [] for metric in metrics.metrics}
    metrics_by_batch.update(loss=[])

    model.train()
    pbar = tqdm(enumerate(train_loader))
    for batch_idx, (data, target) in pbar:
        if config is not None and batch_idx == config.max_batch:
            break
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = metrics.criterion(output, target)
        metrics_by_batch["loss"].append(loss.item())
        for metric in metrics.metrics:
            metrics_by_batch[metric.__name__].append(metric(output, target))
        loss.backward()
        optimizer.step()
        if progress_bar:
            status = 'Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, (batch_idx + 1) * train_loader.batch_size,
                len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item())
            pbar.set_postfix_str(status)
        if lr_scheduler and lr_scheduler_interval == "step":
            lr_scheduler.step()
    pbar.update()
    pbar.close()

    # Log the mean of each metric at the end of the epoch
    if config is not None and config.logger is not None:
        reduced_metrics = {
            metric: mean(values_by_batch)
            for metric, values_by_batch in metrics_by_batch.items()
        }
        config.logger.log_metrics(reduced_metrics, epoch=epoch)
        for metric_name, value in reduced_metrics.items():
            config.logger.add_scalar(metric_name, value, epoch)
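
A hedged usage sketch for train_epoch: it assumes the model, metrics, config, and mnist_train objects built in Example #4 below are in scope, and relies on the batch_size and device attributes that TrainingConfig exposes elsewhere in these examples; it is not the repository's own training driver.

from torch.utils.data import DataLoader

train_loader = DataLoader(mnist_train, batch_size=config.batch_size, shuffle=True)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
for epoch in range(3):
    train_epoch(model, metrics, optimizer, train_loader, epoch,
                device=config.device, config=config)
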
Example #3
def test(model: nn.Module,
         dataset: Dataset,
         metrics: TrainingMetrics,
         config: TrainingConfig,
         eval_mode: bool = True) -> Dict[str, Any]:
    test_loader = DataLoader(dataset, batch_size=config.batch_size)
    if eval_mode:
        model.eval()
    results = defaultdict(list)
    pbar = tqdm(enumerate(test_loader))
    with torch.no_grad():
        for i, (data, target) in pbar:
            if i == config.max_batch:
                break
            data, target = data.to(config.device), target.to(config.device)
            output = model(data)
            results['loss'].append(metrics.criterion(output, target).item())

            batch_results = compute_metrics(metrics.metrics,
                                            y=target,
                                            y_hat=output,
                                            to_tensor=False)
            for k in batch_results.keys():
                results[k].append(batch_results[k])

            pbar.update()
            pbar.set_postfix(loss=pd.DataFrame(results['loss']).mean().values,
                             accuracy=pd.DataFrame(results['accuracy']).mean().values)

    pbar.close()
    reduced_results = dict(pd.DataFrame(results).mean())
    if config.logger is not None:
        config.logger.log_metrics(reduced_results, epoch=0)
    return reduced_results
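
The per-batch values are reduced to a single number per metric with a pandas column mean. The standalone snippet below illustrates just that reduction step with made-up values:

import pandas as pd

results = {"loss": [1.0, 0.75, 0.5], "accuracy": [0.5, 0.625, 0.75]}
reduced_results = dict(pd.DataFrame(results).mean())
print(reduced_results)  # {'loss': 0.75, 'accuracy': 0.625}
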
Example #4
    import torch
    import torch.nn as nn
    import torchvision.transforms as transforms
    from torch.nn.modules import Flatten
    from torchvision.datasets import MNIST

    from neuralteleportation.metrics import accuracy

    mnist_train = MNIST('/tmp',
                        train=True,
                        download=True,
                        transform=transforms.ToTensor())
    mnist_val = MNIST('/tmp',
                      train=False,
                      download=True,
                      transform=transforms.ToTensor())
    mnist_test = MNIST('/tmp',
                       train=False,
                       download=True,
                       transform=transforms.ToTensor())

    model = torch.nn.Sequential(Flatten(), nn.Linear(784, 128), nn.ReLU(),
                                nn.Linear(128, 10))

    config = TrainingConfig()
    metrics = TrainingMetrics(nn.CrossEntropyLoss(), [accuracy])

    train(model,
          train_dataset=mnist_train,
          metrics=metrics,
          config=config,
          val_dataset=mnist_val)
    print(test(model, mnist_test, metrics, config))
Example #5
                        type=str,
                        default="resnet18COB",
                        choices=get_model_names())

    return parser.parse_args()


if __name__ == '__main__':
    args = argument_parser()

    device = 'cuda' if cuda_avail() else 'cpu'

    trainset, valset, testset = get_dataset_subsets("cifar10")

    model = get_model("cifar10", args.model, device=device)
    metric = TrainingMetrics(criterion=nn.CrossEntropyLoss(),
                             metrics=[accuracy])
    config = LandscapeConfig(optimizer=(args.optimizer, {"lr": args.lr}),
                             epochs=args.epochs,
                             batch_size=args.batch_size,
                             cob_range=args.cob_range,
                             cob_sampling=args.cob_sampling,
                             teleport_at=[args.epochs],
                             device=device)
    if args.train:
        train(model, trainset, metric, config)
    a = torch.linspace(args.x[0], args.x[1], int(args.x[2]))
    param_o = model.get_params()
    model.random_teleport(args.cob_range, args.cob_sampling)
    param_t = model.get_params()


def run_experiment(config_path: Path,
                   out_root: Path,
                   data_root_dir: Path = None,
                   save_weights=False,
                   enable_comet=False) -> None:
    with open(str(config_path), 'r') as stream:
        config = yaml.safe_load(stream)

    # Setup metrics to compute
    metrics = TrainingMetrics(nn.CrossEntropyLoss(), [accuracy, accuracy_top5])

    # Get training params
    all_training_params = (config["training_params"] if isinstance(config["training_params"], list)
                           else [config["training_params"]])

    # datasets
    for dataset_name in config["datasets"]:
        dataset_kwargs = {}
        if data_root_dir is not None:
            dataset_kwargs.update(root=data_root_dir, download=False)
        train_set, val_set, test_set = get_dataset_subsets(
            dataset_name, **dataset_kwargs)
        for training_params in all_training_params:
            # models
            for model_obj in config["models"]:
                model_obj = copy.deepcopy(model_obj)
                model_kwargs = {}
                model_name = model_obj
                if not isinstance(model_obj, str):
                    model_name = model_obj.pop("cls")
                    model_kwargs = model_obj
                # initializers
                for initializer in config["initializers"]:
                    config['initializer'] = initializer

                    # optimizers
                    for optimizer_kwargs in config["optimizers"]:
                        optimizer_kwargs = copy.deepcopy(optimizer_kwargs)
                        optimizer_name = optimizer_kwargs.pop("cls")
                        lr_scheduler_kwargs = optimizer_kwargs.pop("lr_scheduler", None)
                        has_scheduler = False
                        if lr_scheduler_kwargs:
                            lr_scheduler_name = lr_scheduler_kwargs.pop("cls")
                            lr_scheduler_interval = lr_scheduler_kwargs.pop("interval", "epoch")
                            if "lr_lambda" in lr_scheduler_kwargs.keys():
                                # WARNING: Take care of what you pass in as lr_lambda as the string is directly
                                # evaluated
                                # This is needed to transform lambda functions defined as strings to a python callable
                                lr_scheduler_kwargs["lr_lambda"] = eval(
                                    lr_scheduler_kwargs.pop("lr_lambda"))

                            if "steps_per_epoch" in lr_scheduler_kwargs.keys():
                                steps = len(train_set) / training_params['batch_size']
                                lr_scheduler_kwargs['steps_per_epoch'] = (
                                    math.floor(steps) if training_params['drop_last_batch']
                                    else math.ceil(steps))

                            has_scheduler = True

                        # teleport configuration
                        for teleport, teleport_config_kwargs in config["teleportations"].items():

                            # w/o teleport configuration
                            if teleport == "no_teleport":
                                training_config_cls = __training_configs__["no_teleport"]
                                # Ensure config collections are iterable, even if no config was defined
                                # This is done to simplify the generation of the configuration matrix
                                teleport_config_kwargs, teleport_mode_configs = {}, [(training_config_cls, {})]

                            # w/ teleport configuration
                            else:  # teleport == "teleport"
                                # Copy the config to play around with its content without affecting the config loaded
                                #  in memory
                                teleport_config_kwargs = copy.deepcopy(teleport_config_kwargs)

                                teleport_mode_obj = teleport_config_kwargs.pop("mode")
                                teleport_mode_configs = []
                                for teleport_mode, teleport_mode_config_kwargs in teleport_mode_obj.items():
                                    training_config_cls = __training_configs__[teleport_mode]
                                    if teleport_mode == "optim":
                                        teleport_mode_config_kwargs["optim_metric"] = [
                                            getattr(teleport_optim, metric)
                                            for metric in teleport_mode_config_kwargs.pop("metric")
                                        ]

                                    # Ensure config collections are iterable, even if no config was defined
                                    # This is done to simplify the generation of the configuration matrix
                                    if teleport_mode_config_kwargs is None:
                                        teleport_mode_config_kwargs = {}

                                    for teleport_mode_single_config_kwargs in dict_values_product(
                                            teleport_mode_config_kwargs):
                                        teleport_mode_configs.append(
                                            (training_config_cls, teleport_mode_single_config_kwargs))

                            # generate matrix of training configuration
                            # (cartesian product of values for each training config kwarg)
                            teleport_configs = dict_values_product(
                                teleport_config_kwargs)
                            config_matrix = itertools.product(
                                teleport_configs, teleport_mode_configs)

                            # Iterate over different possible training configurations
                            for teleport_config_kwargs, (
                                    training_config_cls,
                                    teleport_mode_config_kwargs
                            ) in config_matrix:
                                num_runs = (int(config["runs_per_config"])
                                            if "runs_per_config" in config else 1)
                                for _ in range(num_runs):
                                    experiment_path, experiment_id = make_experiment(out_root)

                                    if enable_comet:
                                        logger = MultiLogger([
                                            DiskLogger(experiment_path),
                                            CometLogger(experiment_id)
                                        ])
                                    else:
                                        logger = DiskLogger(experiment_path)

                                    lr_scheduler_config = None
                                    if has_scheduler:
                                        lr_scheduler_config = (
                                            lr_scheduler_name, lr_scheduler_interval, lr_scheduler_kwargs)
                                    training_config = training_config_cls(
                                        optimizer=(optimizer_name, optimizer_kwargs),
                                        lr_scheduler=lr_scheduler_config,
                                        device='cuda' if cuda_avail() else 'cpu',
                                        logger=logger,
                                        **training_params,
                                        **teleport_config_kwargs,
                                        **teleport_mode_config_kwargs,
                                    )

                                    # Run experiment (setting up a new model and optimizer for each experiment)
                                    model = get_model(
                                        dataset_name,
                                        model_name,
                                        device=training_config.device,
                                        initializer=initializer,
                                        **model_kwargs)
                                    optimizer = getattr(optim, optimizer_name)(
                                        model.parameters(), **optimizer_kwargs)
                                    lr_scheduler = None
                                    if has_scheduler:
                                        lr_scheduler = getattr(optim.lr_scheduler, lr_scheduler_name)(
                                            optimizer, **lr_scheduler_kwargs)
                                    run_model(model,
                                              training_config,
                                              metrics,
                                              train_set,
                                              test_set,
                                              val_set=val_set,
                                              optimizer=optimizer,
                                              lr_scheduler=lr_scheduler)

                                    if save_weights:
                                        torch.save(
                                            model.state_dict(),
                                            experiment_path / 'weights.pt')
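
A hedged sketch of how run_experiment might be called: the paths below are placeholders, and the referenced YAML file is assumed to define the keys the function reads above ("datasets", "training_params", "models", "initializers", "optimizers" entries with a "cls" field, "teleportations", and optionally "runs_per_config").

from pathlib import Path

run_experiment(config_path=Path("configs/my_experiment.yml"),
               out_root=Path("results"),
               save_weights=True,
               enable_comet=False)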