Пример #1
0
from anode.training import Trainer

hidden_dim = 32

model = ODENet(device,
               data_dim,
               hidden_dim,
               time_dependent=True,
               non_linearity='relu')

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

from viz.plots import get_feature_history

# Set up trainer
trainer = Trainer(model, optimizer, device)
num_epochs = 12

# Optionally record how the features evolve during training
visualize_features = True

if visualize_features:
    feature_history = get_feature_history(trainer, dataloader, inputs, targets,
                                          num_epochs)
else:
    # If we don't record feature evolution, simply train model
    trainer.train(dataloader, num_epochs)

from viz.plots import multi_feature_plt

multi_feature_plt(feature_history[::2], targets, save_fig='node_feats.png')
Пример #2
0
def run_and_save_experiments_img(device, path_to_config):
    """Runs and saves experiments as they are produced (so results are still
    saved even if NFEs become excessively large or underflow occurs).

    Parameters
    ----------
    device : torch.device

    path_to_config : string
        Path to config json file.
    """
    # Open config file
    with open(path_to_config) as config_file:
        config = json.load(config_file)

    # Create a folder to store experiment results
    timestamp = time.strftime("%Y-%m-%d_%H-%M")
    directory = "img_results_{}_{}".format(timestamp, config["id"])
    if not os.path.exists(directory):
        os.makedirs(directory)

    # Save config file in experiment directory
    with open(directory + '/config.json', 'w') as config_file:
        json.dump(config, config_file)

    num_reps = config["num_reps"]
    dataset = config["dataset"]
    model_configs = config["model_configs"]
    training_config = config["training_config"]

    results = {"dataset": dataset, "model_info": []}

    if dataset == 'mnist':
        data_loader, test_loader = mnist(training_config["batch_size"])
        img_size = (1, 28, 28)
        output_dim = 10

    if dataset == 'cifar10':
        data_loader, test_loader = cifar10(training_config["batch_size"])
        img_size = (3, 32, 32)
        output_dim = 10

    if dataset == 'imagenet':
        data_loader = tiny_imagenet(training_config["batch_size"])
        img_size = (3, 64, 64)
        output_dim = 200

    only_success = True  # Boolean to keep track of any experiments failing

    for i, model_config in enumerate(model_configs):
        results["model_info"].append({})
        # Keep track of losses and nfes
        loss_histories = []
        nfe_histories = []
        bnfe_histories = []
        total_nfe_histories = []
        epoch_loss_histories = []
        epoch_nfe_histories = []
        epoch_bnfe_histories = []
        epoch_total_nfe_histories = []
        # Keep track of models potentially failing
        model_stats = {
            "exceeded": {
                "count": 0,
                "final_losses": [],
                "final_nfes": [],
                "final_bnfes": []
            },
            "underflow": {
                "count": 0,
                "final_losses": [],
                "final_nfes": [],
                "final_bnfes": []
            },
            "success": {
                "count": 0,
                "final_losses": [],
                "final_nfes": [],
                "final_bnfes": []
            }
        }

        if model_config["validation"]:
            epoch_loss_val_histories = []

        is_ode = model_config["type"] == "odenet" or model_config[
            "type"] == "anode"

        for j in range(num_reps):
            print("{}/{} model, {}/{} rep".format(i + 1, len(model_configs),
                                                  j + 1, num_reps))

            if is_ode:
                if model_config["type"] == "odenet":
                    augment_dim = 0
                else:
                    augment_dim = model_config["augment_dim"]

                model = ConvODENet(
                    device,
                    img_size,
                    model_config["num_filters"],
                    output_dim=output_dim,
                    augment_dim=augment_dim,
                    time_dependent=model_config["time_dependent"],
                    non_linearity=model_config["non_linearity"],
                    adjoint=True)
            else:
                model = ResNet(data_dim,
                               model_config["hidden_dim"],
                               model_config["num_layers"],
                               output_dim=output_dim,
                               is_img=True)

            model.to(device)

            optimizer = torch.optim.Adam(
                model.parameters(),
                lr=model_config["lr"],
                weight_decay=model_config["weight_decay"])

            trainer = Trainer(model,
                              optimizer,
                              device,
                              classification=True,
                              print_freq=training_config["print_freq"],
                              record_freq=training_config["record_freq"],
                              verbose=True,
                              save_dir=(directory, '{}_{}'.format(i, j)))

            loss_histories.append([])
            epoch_loss_histories.append([])
            nfe_histories.append([])
            epoch_nfe_histories.append([])
            bnfe_histories.append([])
            epoch_bnfe_histories.append([])
            total_nfe_histories.append([])
            epoch_total_nfe_histories.append([])

            if model_config["validation"]:
                epoch_loss_val_histories.append([])

            # Train one epoch at a time, as NODEs can underflow or exceed the
            # maximum NFEs
            for epoch in range(training_config["epochs"]):
                print("\nEpoch {}".format(epoch + 1))
                try:
                    trainer.train(data_loader, 1, epoch)
                    end_training = False
                except AssertionError as e:
                    only_success = False
                    # Assertion error means we either underflowed or exceeded
                    # the maximum number of steps
                    error_message = e.args[0]
                    # Error message in torchdiffeq for max_num_steps starts
                    # with 'max_num_steps'
                    if error_message.startswith("max_num_steps"):
                        print("Maximum number of steps exceeded")
                        file_name_root = 'exceeded'
                    elif error_message.startswith("underflow"):
                        print("Underflow")
                        file_name_root = 'underflow'
                    else:
                        print("Unknown assertion error")
                        file_name_root = 'unknown'

                    model_stats[file_name_root]["count"] += 1

                    if len(trainer.buffer['loss']):
                        final_loss = np.mean(trainer.buffer['loss'])
                    else:
                        final_loss = None
                    model_stats[file_name_root]["final_losses"].append(
                        final_loss)

                    if len(trainer.buffer['nfe']):
                        final_nfes = np.mean(trainer.buffer['nfe'])
                    else:
                        final_nfes = None
                    model_stats[file_name_root]["final_nfes"].append(
                        final_nfes)

                    if len(trainer.buffer['bnfe']):
                        final_bnfes = np.mean(trainer.buffer['bnfe'])
                    else:
                        final_bnfes = None
                    model_stats[file_name_root]["final_bnfes"].append(
                        final_bnfes)

                    # Save final NFEs before error happened
                    with open(
                            directory +
                            '/{}_{}_{}.json'.format(file_name_root, i, j),
                            'w') as f:
                        json.dump(
                            {
                                "forward": trainer.nfe_buffer,
                                "backward": trainer.bnfe_buffer
                            }, f)

                    end_training = True

                # Save info at every epoch
                loss_histories[-1] = trainer.histories['loss_history']
                epoch_loss_histories[-1] = trainer.histories[
                    'epoch_loss_history']
                if is_ode:
                    nfe_histories[-1] = trainer.histories['nfe_history']
                    epoch_nfe_histories[-1] = trainer.histories[
                        'epoch_nfe_history']
                    bnfe_histories[-1] = trainer.histories['bnfe_history']
                    epoch_bnfe_histories[-1] = trainer.histories[
                        'epoch_bnfe_history']
                    total_nfe_histories[-1] = trainer.histories[
                        'total_nfe_history']
                    epoch_total_nfe_histories[-1] = trainer.histories[
                        'epoch_total_nfe_history']

                if model_config["validation"]:
                    epoch_loss_val = dataset_mean_loss(trainer, test_loader,
                                                       device)
                    if epoch == 0:
                        epoch_loss_val_histories[-1] = [epoch_loss_val]
                    else:
                        epoch_loss_val_histories[-1].append(epoch_loss_val)

                results["model_info"][-1]["type"] = model_config["type"]
                results["model_info"][-1]["loss_history"] = loss_histories
                results["model_info"][-1][
                    "epoch_loss_history"] = epoch_loss_histories
                if model_config["validation"]:
                    results["model_info"][-1][
                        "epoch_loss_val_history"] = epoch_loss_val_histories

                if is_ode:
                    results["model_info"][-1][
                        "epoch_nfe_history"] = epoch_nfe_histories
                    results["model_info"][-1]["nfe_history"] = nfe_histories
                    results["model_info"][-1][
                        "epoch_bnfe_history"] = epoch_bnfe_histories
                    results["model_info"][-1]["bnfe_history"] = bnfe_histories
                    results["model_info"][-1][
                        "epoch_total_nfe_history"] = epoch_total_nfe_histories
                    results["model_info"][-1][
                        "total_nfe_history"] = total_nfe_histories

                # Save losses and nfes at every epoch
                with open(directory + '/losses_and_nfes.json', 'w') as f:
                    json.dump(results['model_info'], f)

                # If training failed, move on to next rep
                if end_training:
                    break

                # If we reached end of training, increment success counter
                if epoch == training_config["epochs"] - 1:
                    model_stats["success"]["count"] += 1

                    if len(trainer.buffer['loss']):
                        final_loss = np.mean(trainer.buffer['loss'])
                    else:
                        final_loss = None
                    model_stats["success"]["final_losses"].append(final_loss)

                    if len(trainer.buffer['nfe']):
                        final_nfes = np.mean(trainer.buffer['nfe'])
                    else:
                        final_nfes = None
                    model_stats["success"]["final_nfes"].append(final_nfes)

                    if len(trainer.buffer['bnfe']):
                        final_bnfes = np.mean(trainer.buffer['bnfe'])
                    else:
                        final_bnfes = None
                    model_stats["success"]["final_bnfes"].append(final_bnfes)

        # Save model stats
        with open(directory + '/model_stats{}.json'.format(i), 'w') as f:
            json.dump(model_stats, f)

    # Create plots

    # Extract size of augmented dims
    augment_labels = [
        'p = 0' if model_config['type'] == 'odenet' else 'p = {}'.format(
            model_config['augment_dim'])
        for model_config in config['model_configs']
    ]
    # Create losses figure
    # Note that we can only calculate mean loss if all models trained to
    # completion. Therefore we only include mean if only_success is True
    histories_plt(results["model_info"],
                  plot_type='loss',
                  labels=augment_labels,
                  include_mean=only_success,
                  save_fig=directory + '/losses.png')
    histories_plt(results["model_info"],
                  plot_type='loss',
                  labels=augment_labels,
                  include_mean=only_success,
                  shaded_err=True,
                  save_fig=directory + '/losses_shaded.png')

    # Create NFE plots if ODE model is included
    contains_ode = False
    for model_config in config["model_configs"]:
        if model_config["type"] == "odenet" or model_config["type"] == "anode":
            contains_ode = True
            break

    if contains_ode:
        # If adjoint method was used, plot forwards, backwards and total nfes
        if trainer.model.odeblock.adjoint:
            nfe_types = ['nfe', 'bnfe', 'total_nfe']
        else:
            nfe_types = ['nfe']

        for nfe_type in nfe_types:
            histories_plt(results["model_info"],
                          plot_type='nfe',
                          labels=augment_labels,
                          include_mean=only_success,
                          nfe_type=nfe_type,
                          save_fig=directory + '/{}s.png'.format(nfe_type))
            histories_plt(results["model_info"],
                          plot_type='nfe',
                          labels=augment_labels,
                          include_mean=only_success,
                          shaded_err=True,
                          nfe_type=nfe_type,
                          save_fig=directory +
                          '/{}s_shaded.png'.format(nfe_type))
            histories_plt(results["model_info"],
                          plot_type='nfe_vs_loss',
                          labels=augment_labels,
                          include_mean=only_success,
                          nfe_type=nfe_type,
                          save_fig=directory +
                          '/{}_vs_loss.png'.format(nfe_type))