Esempio n. 1
0
def load_experiment(run_uuid: str, checkpoint: Optional[int] = None):
    """
    Restore a saved experiment from [train model](train_model.html) for evaluation.

    * `run_uuid` is the UUID of the training run whose configs and models are loaded
    * `checkpoint` is forwarded unchanged to `experiment.load` to pick the checkpoint

    Returns the `Configs` object; its models are loaded once `experiment.start()` runs.
    """

    configs = Configs()

    # Pull in the custom configurations of the training run and ask for the
    # inputs to the feed forward layer, $f(c_i)$, to be kept
    saved_configs = experiment.load_configs(run_uuid)
    saved_configs['is_save_ff_input'] = True

    # Evaluation only: nothing is tracked or saved
    experiment.evaluate()
    experiment.configs(configs, saved_configs, 'run')

    # Register the modules to restore, then point at the run/checkpoint holding them
    experiment.add_pytorch_models(get_modules(configs))
    experiment.load(run_uuid, checkpoint)

    # Starting the experiment is what actually loads the models
    experiment.start()

    return configs
Esempio n. 2
0
def get_predictor():
    """
    Build a `Predictor` from the pretrained autocomplete checkpoint bundle.

    The bundle is downloaded as `saved_checkpoint.tar.gz` under the lab path and
    extracted. Its saved configs and model are then restored in evaluation mode.

    To use a locally trained model instead, replace the `load_bundle` call with
    `run_uuid = 'RUN_UUID'` and `checkpoint = None` (the latest checkpoint).
    """
    configs = Configs()
    experiment.evaluate()

    bundle_url = 'https://github.com/lab-ml/python_autocomplete/releases/download/0.0.4/transformer_checkpoint.tar.gz'
    bundle_path = lab.get_path() / 'saved_checkpoint.tar.gz'
    run_uuid, checkpoint = experiment.load_bundle(bundle_path, url=bundle_url)

    saved_configs = experiment.load_configs(run_uuid)
    experiment.configs(configs, saved_configs)
    experiment.add_pytorch_models(get_modules(configs))
    experiment.load(run_uuid, checkpoint)
    experiment.start()

    # Put the model in inference mode before handing it to the predictor
    configs.model.eval()

    # stoi/itos are presumably the token<->index mappings; `cache` wraps them by key
    stoi = cache('stoi', lambda: configs.text.stoi)
    itos = cache('itos', lambda: configs.text.itos)
    return Predictor(configs.model, stoi, itos)
Esempio n. 3
0
def load_experiment() -> Configs:
    """
    Load the pretrained autocomplete model bundle for evaluation.

    Downloads `bundle.tar.gz` from the `python_autocomplete` 0.0.5 release as
    `saved_checkpoint.tar.gz` under the lab path and extracts it. Then restores
    the saved configs and model weights and returns the `Configs` object.

    To load a locally trained model instead, use `run_uuid = 'RUN_UUID'`
    (e.g. the BPE run `a6cff3706ec411ebadd9bf753b33bae6`) and
    `checkpoint = None` for the latest checkpoint, in place of `load_bundle`.
    """
    configs = Configs()
    experiment.evaluate()

    run_uuid, checkpoint = experiment.load_bundle(
        lab.get_path() / 'saved_checkpoint.tar.gz',
        url='https://github.com/lab-ml/python_autocomplete/releases/download/0.0.5/bundle.tar.gz',
    )

    saved_configs = experiment.load_configs(run_uuid)
    # Presumably keeps the text dataset from being loaded, since only the model
    # is needed here -- confirm against the `text` configs
    saved_configs['text.is_load_data'] = False
    experiment.configs(configs, saved_configs)

    experiment.add_pytorch_models(get_modules(configs))
    experiment.load(run_uuid, checkpoint)
    experiment.start()

    return configs
Esempio n. 4
0
def get_predictor(run_uuid: str = '39b03a1e454011ebbaff2b26e3148b3d', checkpoint=None):
    """
    Build a `Predictor` from a trained experiment.

    * `run_uuid` -- UUID of the training run to load. Defaults to the run this
      script was originally written against. Pass your own training run's UUID
      instead of editing the source.
    * `checkpoint` -- checkpoint to load, forwarded to `experiment.load`.
      `None` (the default) selects the latest checkpoint.
    """
    conf = Configs()
    experiment.evaluate()

    conf_dict = experiment.load_configs(run_uuid)
    experiment.configs(conf, conf_dict)
    experiment.add_pytorch_models(get_modules(conf))
    experiment.load(run_uuid, checkpoint)

    experiment.start()
    # Put the model in inference mode before wrapping it
    conf.model.eval()
    return Predictor(conf.model, cache('stoi', lambda: conf.text.stoi), cache('itos', lambda: conf.text.itos))
Esempio n. 5
0
def main(run_uuid: str = "a44333ea251411ec8007d1a1762ed686", *, interpolate: bool = False):
    """
    Generate samples

    Loads `eps_model` from a training run and uses a `Sampler` to render a
    denoising animation.

    * `run_uuid` -- UUID of the training run to load. Defaults to the run this
      script was originally written against.
    * `interpolate` -- when `True`, also takes the first two entries of a batch
      from `configs.data_loader` and renders an interpolation animation between
      them. Off by default.
    """

    # Start an evaluation
    experiment.evaluate()

    # Create configs
    configs = Configs()
    # Load custom configuration of the training run
    configs_dict = experiment.load_configs(run_uuid)
    # Set configurations
    experiment.configs(configs, configs_dict)

    # Initialize
    configs.init()

    # Set PyTorch modules for saving and loading
    experiment.add_pytorch_models({'eps_model': configs.eps_model})

    # Load training experiment
    experiment.load(run_uuid)

    # Create sampler
    sampler = Sampler(diffusion=configs.diffusion,
                      image_channels=configs.image_channels,
                      image_size=configs.image_size,
                      device=configs.device)

    # Start evaluation
    with experiment.start():
        # No gradients are needed for sampling
        with torch.no_grad():
            # Sample an image with a denoising animation
            sampler.sample_animation()

            if interpolate:
                # Get some images from the data
                data = next(iter(configs.data_loader)).to(configs.device)

                # Create an interpolation animation between the first two entries
                sampler.interpolate_animate(data[0], data[1])
Esempio n. 6
0
def main(run_uuid: str = '39b03a1e454011ebbaff2b26e3148b3d', checkpoint=None):
    """
    Load a trained autocomplete model and run `evaluate` on `sample.py` from the
    lab data directory.

    * `run_uuid` -- UUID of the training run to load. Defaults to the run this
      script was originally written against. Pass your own training run's UUID
      instead of editing the source.
    * `checkpoint` -- checkpoint to load, forwarded to `experiment.load`.
      `None` (the default) selects the latest checkpoint.
    """
    conf = Configs()
    experiment.evaluate()

    conf_dict = experiment.load_configs(run_uuid)
    experiment.configs(conf, conf_dict)
    experiment.add_pytorch_models(get_modules(conf))
    experiment.load(run_uuid, checkpoint)

    experiment.start()
    predictor = Predictor(conf.model, cache('stoi', lambda: conf.text.stoi),
                          cache('itos', lambda: conf.text.itos))
    conf.model.eval()

    # Python source files are UTF-8 by default (PEP 3120). Read explicitly as
    # UTF-8 rather than relying on the platform's locale encoding.
    with open(lab.get_data_path() / 'sample.py', 'r', encoding='utf-8') as f:
        sample = f.read()
    evaluate(predictor, sample)
Esempio n. 7
0
File: __init__.py Progetto: wx-b/nn
def get_saved_model(run_uuid: str, checkpoint: int):
    """
    ### Load [trained large model](large.html)

    Restores the large model from training run `run_uuid` at `checkpoint`
    in evaluation mode (nothing is recorded) and returns it.
    """

    from labml_nn.distillation.large import Configs as LargeConfigs

    # Evaluation mode: nothing gets recorded
    experiment.evaluate()

    # Recreate the large model's configs from the saved training run
    large_conf = LargeConfigs()
    saved_configs = experiment.load_configs(run_uuid)
    experiment.configs(large_conf, saved_configs)

    # Register the model to restore and which run/checkpoint to restore it from
    modules = {'model': large_conf.model}
    experiment.add_pytorch_models(modules)
    experiment.load(run_uuid, checkpoint)

    # Starting the experiment loads the model and prepares everything
    experiment.start()

    return large_conf.model