Example #1
def simulate_teleportation_sphere(model: NeuralTeleportationModel,
                                  config: "PseudoTeleportationTrainingConfig",
                                  **kwargs) -> NeuralTeleportationModel:
    print(
        f"Shifting weights on a sphere similar to a {config.cob_sampling} teleportation w/ {config.cob_range} "
        f"COB range.")

    model.cpu()  # Move the model to the CPU to avoid keeping 2 models on the GPU (possible CUDA OOM error)

    teleported_model = deepcopy(model).random_teleport(
        cob_range=config.cob_range, sampling_type=config.cob_sampling)

    init_layers = model.get_weights(concat=False)
    teleported_layers = teleported_model.get_weights(concat=False)

    pseudo_teleported_layers = []
    for init_layer, teleported_layer in zip(init_layers, teleported_layers):
        layer_shift = torch.randn_like(init_layer)
        layer_shift = normalize(layer_shift, p=1, dim=0) * torch.norm(
            teleported_layer - init_layer, 1)
        pseudo_teleported_layer = init_layer + layer_shift
        pseudo_teleported_layers.append(pseudo_teleported_layer)

    pseudo_teleported_weights = torch.cat(pseudo_teleported_layers)
    model.set_weights(pseudo_teleported_weights)
    return model.to(config.device)
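As a point of reference, the per-layer shift above can be reproduced on plain tensors. The sketch below is not part of the library and uses toy sizes; it only shows that the pseudo-teleported layer moves by the same L1 distance as a real teleportation would, just in a random direction.

import torch
from torch.nn.functional import normalize

init_layer = torch.randn(100)           # stand-in for one flattened weight layer
teleported_layer = init_layer * 1.5     # stand-in for the same layer after teleportation

# Random direction, L1-normalized, then rescaled to the L1 norm of the real displacement.
layer_shift = (normalize(torch.randn_like(init_layer), p=1, dim=0)
               * torch.norm(teleported_layer - init_layer, 1))
pseudo_teleported_layer = init_layer + layer_shift

# Both displacements have (approximately) the same L1 norm.
print(torch.norm(pseudo_teleported_layer - init_layer, 1).item(),
      torch.norm(teleported_layer - init_layer, 1).item())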
Example #2
def start_training(model: NeuralTeleportationModel,
                   trainloader: DataLoader,
                   valset: VisionDataset,
                   metric: TrainingMetrics,
                   config: CompareTrainingConfig,
                   teleport_chance: float) -> np.ndarray:
    """
        This function starts a model training with a specific Scenario configuration.

        Scenario 1: train the model without using teleportation (teleportation_chance = 0.0)
        Scenario 2: train the model using a probability of teleporting every Xth epochs
        (0 < teleportation_chance < 1.0)
        Scenario 3: train the model using teleportation every Xth epochs (teleportation_chance = 1.0)

        returns:
            np.array containing the validation accuracy results of every epochs.
    """
    model.to(config.device)
    optimizer = get_optimizer_from_model_and_config(model, config)

    results = []
    for e in np.arange(1, config.epochs + 1):
        train_epoch(model=model, metrics=metric, optimizer=optimizer, train_loader=trainloader, epoch=e,
                    device=config.device)
        results.append(test(model=model, dataset=valset, metrics=metric, config=config)['accuracy'])
        model.train()

        if e % config.every_n_epochs == 0 and random.random() <= teleport_chance:
            print("teleported model")
            if config.targeted_teleportation:
                # TODO: use teleportation function here when they are available.
                raise NotImplementedError
            else:
                model.random_teleport(cob_range=config.cob_range, sampling_type=config.cob_sampling)
                optimizer = get_optimizer_from_model_and_config(model, config)

    model.cpu()  # Move the network off the GPU to free CUDA memory.

    return np.array(results)
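The trigger in the loop above fires only every `every_n_epochs` epochs, and then only with probability `teleport_chance`. Below is a small standalone sketch, independent of the library and with illustrative values only, of how often each scenario ends up teleporting over a run.

import random

def count_teleports(epochs=100, every_n_epochs=4, teleport_chance=0.5, seed=0):
    # Same condition as in the training loop: e % every_n_epochs == 0 and random() <= teleport_chance.
    random.seed(seed)
    return sum(1 for e in range(1, epochs + 1)
               if e % every_n_epochs == 0 and random.random() <= teleport_chance)

# Scenario 1 (chance 0.0) effectively never teleports, Scenario 3 (chance 1.0)
# teleports on every every_n_epochs-th epoch, and Scenario 2 lands in between.
print(count_teleports(teleport_chance=0.0),
      count_teleports(teleport_chance=0.5),
      count_teleports(teleport_chance=1.0))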
Example #3
def simulate_teleportation_distribution(
        model: NeuralTeleportationModel,
        config: "DistributionTeleportationTrainingConfig",
        **kwargs) -> NeuralTeleportationModel:
    print(
        f"Shifting weights to a distribution similar to that of a {config.cob_sampling} teleportation w/ "
        f"{config.cob_range} COB range.")

    model.cpu()
    teleported_model = deepcopy(model).random_teleport(
        cob_range=config.cob_range, sampling_type=config.cob_sampling)
    teleported_weights = teleported_model.get_weights(
        concat=True).cpu().detach().numpy()
    hist, bin_edges = np.histogram(teleported_weights, bins=1000)
    hist = hist / hist.sum()
    model.init_like_histogram(hist, bin_edges)
    return model.to(config.device)
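`init_like_histogram` belongs to `NeuralTeleportationModel` and its body is not shown here; assuming it re-draws the weights from the normalized histogram, the core idea can be sketched with plain NumPy (toy data, no library helpers):

import numpy as np

rng = np.random.default_rng(0)

# Stand-in for the teleported weights whose empirical distribution we want to imitate.
teleported_weights = rng.normal(size=10_000)
hist, bin_edges = np.histogram(teleported_weights, bins=1000)
hist = hist / hist.sum()  # turn counts into bin probabilities

# Draw new weights from that distribution: pick a bin with probability `hist`,
# then sample uniformly inside the chosen bin.
bins = rng.choice(len(hist), size=5_000, p=hist)
new_weights = rng.uniform(bin_edges[bins], bin_edges[bins + 1])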
Example #4
def get_model(dataset_name: str,
              model_name: str,
              device: str = 'cpu',
              initializer: Dict[str, Union[str, float]] = None,
              **model_kwargs) -> NeuralTeleportationModel:
    # Look up if the requested model is available in the model zoo
    model_factories = _get_model_factories()
    if model_name not in model_factories:
        raise KeyError(f"{model_name} was not found in the model zoo")

    # Dynamically determine the parameters for initializing the model based on the dataset
    model_kwargs.update(get_dataset_info(dataset_name, "num_classes"))
    if "mlp" in model_name.lower():
        input_channels, image_size = get_dataset_info(dataset_name,
                                                      "input_channels",
                                                      "image_size").values()
        model_kwargs.update(input_shape=(input_channels, *image_size))
    else:
        model_kwargs.update(get_dataset_info(dataset_name, "input_channels"))

    if "cifar" in dataset_name and ("resnet" in model_name):
        model_kwargs.update({"for_dataset": "cifar"})
    # Instantiate the model
    model_factory = model_factories[model_name]
    model = model_factory(**model_kwargs)
    # Initialize the model
    if initializer is not None:
        init_gain = (None if "gain" not in initializer and initializer["type"] == "none"
                     else initializer["gain"])
        init_non_linearity = (None if "non_linearity" not in initializer
                              else initializer["non_linearity"])
        model = initialize_model(model,
                                 init_type=initializer["type"],
                                 init_gain=init_gain,
                                 non_linearity=init_non_linearity)

    # Transform the base ``nn.Module`` to a ``NeuralTeleportationModel``
    input_channels, image_size = get_dataset_info(dataset_name,
                                                  "input_channels",
                                                  "image_size").values()
    model = NeuralTeleportationModel(network=model,
                                     input_shape=(2, input_channels,
                                                  *image_size))

    return model.to(device)
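A hypothetical call for reference; the dataset name, model name, and initializer fields below are illustrative only and are assumptions about what the model zoo and `initialize_model` accept:

# Illustrative values only; whether "resnet18"/"cifar10" exist depends on the model zoo.
model = get_model("cifar10",
                  "resnet18",
                  device="cuda",
                  initializer={"type": "xavier", "gain": 1.0, "non_linearity": "relu"})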