Example #1
    # The snippet is cut off above; model_name is evidently built from the
    # bundle parameters, along the lines of:
    model_name = 'betas={}_gammas={}.pt'.format(betas,
                                                gammas)

    try:
        # Try to load a previously trained model; if none exists, train it
        checkpoint = torch.load(
            ROOT_DIR + '/models/SIR_bundle_total/{}'.format(model_name))
    except FileNotFoundError:
        # Train
        optimizer = torch.optim.Adam(sir.parameters(), lr=lr)
        writer = SummaryWriter('runs/{}'.format(model_name))
        sir, train_losses, run_time, optimizer = train_bundle(sir, initial_conditions_set, t_final=t_final,
                                                              epochs=epochs,
                                                              num_batches=10, hack_trivial=hack_trivial,
                                                              train_size=train_size, optimizer=optimizer,
                                                              decay=decay,
                                                              writer=writer, betas=betas,
                                                              gammas=gammas)
        # Save the model
        torch.save({'model_state_dict': sir.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict()},
                   ROOT_DIR + '/models/SIR_bundle_total/{}'.format(model_name))

        # Re-load the checkpoint just saved, so both branches end in the same state
        checkpoint = torch.load(
            ROOT_DIR + '/models/SIR_bundle_total/{}'.format(model_name))

    # Restore the trained weights into the network
    sir.load_state_dict(checkpoint['model_state_dict'])
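
Both examples rely on the same load-or-train idiom: try to restore a checkpoint, fall back to training, save, and then re-load so that both paths end with identical state. Below is a minimal, self-contained sketch of that idiom; the network, loss, and path are placeholders, not the repository's SIRNetwork or bundle loss.

import os
import torch
import torch.nn as nn

CKPT_PATH = 'models/demo/checkpoint.pt'   # placeholder path, not the repo's

model = nn.Sequential(nn.Linear(1, 50), nn.Tanh(), nn.Linear(50, 3))

try:
    # Restore a previous run if a checkpoint exists
    checkpoint = torch.load(CKPT_PATH)
except FileNotFoundError:
    # No checkpoint yet: train, save, then re-load so that both
    # branches of the try/except end in exactly the same state
    optimizer = torch.optim.Adam(model.parameters(), lr=8e-4)
    for _ in range(100):
        t = torch.rand(64, 1)
        loss = model(t).pow(2).mean()   # stand-in for the real bundle loss
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    os.makedirs(os.path.dirname(CKPT_PATH), exist_ok=True)
    torch.save({'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict()},
               CKPT_PATH)
    checkpoint = torch.load(CKPT_PATH)

model.load_state_dict(checkpoint['model_state_dict'])

Re-loading immediately after saving looks redundant, but it guarantees that the code after the try/except runs from the exact same serialized state regardless of which branch executed.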
Example #2
    try:
        # Try to load a previously trained source model; if none exists, train it
        checkpoint = torch.load(
            ROOT_DIR + '/models/SIR_bundle_total/{}'.format(source_model_name))
    except FileNotFoundError:
        # Train
        optimizer = torch.optim.Adam(sir.parameters(), lr=lr)
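        # From-scratch hyperparameters for the source model (cf. the '_scratch' run name below)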
        source_epochs = 20000
        source_hack_trivial = 0
        source_train_size = 2000
        source_decay = 1e-2
        writer = SummaryWriter('runs/{}_scratch'.format(source_model_name))
        sir, train_losses, run_time, optimizer = train_bundle(
            sir,
            initial_conditions_set,
            t_final=t_final,
            epochs=source_epochs,
            model_name=source_model_name,
            num_batches=10,
            hack_trivial=source_hack_trivial,
            train_size=source_train_size,
            optimizer=optimizer,
            decay=source_decay,
            writer=writer,
            betas=source_betas,
            gammas=source_gammas)
        # Save the model
        torch.save(
            {
                'model_state_dict': sir.state_dict(),
                'optimizer_state_dict': optimizer.state_dict()
            },
            ROOT_DIR + '/models/SIR_bundle_total/{}'.format(source_model_name))

        # Re-load the checkpoint just saved, so both branches end in the same state
        checkpoint = torch.load(
            ROOT_DIR + '/models/SIR_bundle_total/{}'.format(source_model_name))
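
Example #2 only trains the source model (note the source_ prefixes and the '_scratch' run name); a transfer-learning stage would presumably restore these weights and fine-tune on a new parameter bundle. The sketch below is hypothetical, not code from the repository: it reuses train_bundle with the signature shown above, assumes sir, initial_conditions_set, t_final, lr, source_model_name and checkpoint are still in scope, and invents all target_* values.

# Hypothetical fine-tuning step; every target_* name is made up for illustration
sir.load_state_dict(checkpoint['model_state_dict'])  # warm-start from the source weights

target_betas = [0.4, 0.8]      # assumed target transmission-rate bundle
target_gammas = [0.3, 0.7]     # assumed target recovery-rate bundle
finetune_epochs = 2000         # far fewer than the 20000 used from scratch

optimizer = torch.optim.Adam(sir.parameters(), lr=lr)
writer = SummaryWriter('runs/{}_finetuned'.format(source_model_name))
sir, train_losses, run_time, optimizer = train_bundle(
    sir,
    initial_conditions_set,
    t_final=t_final,
    epochs=finetune_epochs,
    model_name=source_model_name,
    num_batches=10,
    hack_trivial=0,
    train_size=2000,
    optimizer=optimizer,
    decay=1e-2,
    writer=writer,
    betas=target_betas,
    gammas=target_gammas)

Starting from the source weights rather than a fresh initialization is the point of the two-stage setup: the fine-tuning run can use a fraction of the from-scratch epoch budget.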