Exemplo n.º 1
0
def describe(cfg):
    """Print a summary of the model described by ``cfg``.

    Builds the model class named by ``cfg.model.klass`` (resolved through
    ``utils.import_``) with ``cfg.model.args`` / ``cfg.model.kwargs`` and
    prints it, followed by the number of trainable parameters (those with
    ``requires_grad``) and the name and shape of every parameter.

    If ``cfg`` has a ``state_dict`` entry, the weights at that path (``~`` is
    expanded) are loaded first. Each parameter's values are then also
    printed, rounded to the nearest integer.

    Args:
        cfg: config object supporting attribute access and ``'state_dict' in
            cfg`` (presumably an OmegaConf / attr-dict style object — TODO
            confirm against callers).
    """
    from pathlib import Path

    import torch

    from utils import import_

    klass = import_(cfg.model.klass)
    model = klass(*cfg.model.args, **cfg.model.kwargs)

    has_weights = 'state_dict' in cfg
    if has_weights:
        weights_path = Path(cfg.state_dict).expanduser().resolve()
        # The freshly built model lives on the CPU. Without map_location, a
        # checkpoint saved from CUDA would fail to load on a machine without
        # a GPU.
        model.load_state_dict(torch.load(weights_path, map_location='cpu'))

    print(model)
    n_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f'Number of parameters: {n_trainable}')

    for name, parameter in model.named_parameters():
        print(f'{name} {tuple(parameter.shape)}:')
        if has_weights:
            # Tensor.numpy() refuses tensors that require grad (the default
            # for parameters), so detach and make sure it is on the CPU.
            print(parameter.detach().cpu().numpy().round())
            print()
# Record the session: `sort_dict` (defined elsewhere) presumably orders the
# session's keys as listed. The session is then attached to the experiment
# and the whole experiment is printed.
sort_dict(session, [
    'epochs', 'batch_size', 'losses', 'seed', 'cpus', 'device', 'samples',
    'status', 'datetime_started', 'datetime_completed', 'data', 'log',
    'checkpoint', 'git', 'gpus'
])
experiment.sessions.append(session)
pyaml.pprint(experiment, sort_dicts=False, width=200)
del session
# endregion

# region Building phase
# Seeds (set them after the random run id is generated)
set_seeds(experiment.session.seed)

# Model
model: torch.nn.Module = import_(experiment.model.fn)(
    *experiment.model.args, **experiment.model.kwargs)
if 'state_dict' in experiment.model:
    # The model was just built on the CPU and is moved to the session device
    # right below. Loading onto the CPU avoids a failure when the checkpoint
    # was saved from CUDA and this machine has no GPU.
    model.load_state_dict(
        torch.load(experiment.model.state_dict, map_location='cpu'))
model.to(experiment.session.device)

# Optimizer
optimizer: torch.optim.Optimizer = import_(
    experiment.optimizer.fn)(model.parameters(), *experiment.optimizer.args,
                             **experiment.optimizer.kwargs)
if 'state_dict' in experiment.optimizer:
    # Optimizer.load_state_dict moves the loaded state onto each parameter's
    # device (see PyTorch docs), so loading onto the CPU first is safe and
    # works on CPU-only machines.
    optimizer.load_state_dict(
        torch.load(experiment.optimizer.state_dict, map_location='cpu'))

# Logger
if len(experiment.session.log.when) > 0:
    logger = SummaryWriter(experiment.session.log.folder)
    logger.add_text(
    raise ValueError(f'Invalid number of cpus: {options.cpus}')
if options.output.exists() and not options.output.is_dir():
    # The output location must be a directory; it is created below if missing.
    raise ValueError(f'Invalid output path {options.output}')

# Print the model config, options and data inputs before doing any work.
pyaml.pprint({
    'model': model,
    'options': options,
    'data': data
},
             sort_dicts=False,
             width=200)
# endregion

# region Building phase
# Model
net: torch.nn.Module = import_(model.fn)(*model.args, **model.kwargs)
# Load the weights onto the CPU, where `net` was just built, then move the
# network to the target device. Without map_location, a checkpoint saved
# from CUDA fails to load on a CPU-only machine.
net.load_state_dict(torch.load(model.state_dict, map_location='cpu'))
net.to(options.device)

# Output folder
options.output.mkdir(parents=True, exist_ok=True)
# endregion

# region Training
# Dataset and dataloader
dataset_predict: InfectionDataset = torch.load(data[0])
dataloader_predict = torch.utils.data.DataLoader(
    dataset_predict,
    shuffle=False,
    num_workers=min(options.cpus, 1)
    if 'cuda' in options.device else options.cpus,