コード例 #1
0
ファイル: train_vae.py プロジェクト: kgalias/world_models
def main():
    """Train a VAE on stored rollouts and save its weights.

    Command-line flags choose the batch size, number of epochs, latent size,
    rollout length, KL clamp, learning rate, CUDA usage, rollouts directory
    and logging interval. Rollouts are read from
    ``DATA_DIR/rollouts/<dir_name>``. The trained ``state_dict`` is written to
    ``DATA_DIR/vae/<ISO timestamp>_<n_epochs>``.
    """
    parser = argparse.ArgumentParser(description='VAE')
    parser.add_argument('--batch_size',
                        type=int,
                        default=100,
                        help='Batch size for training (default=100)')
    parser.add_argument('--n_epochs',
                        type=int,
                        default=100,
                        help='Number of epochs to train (default=100)')
    parser.add_argument('--latent_dim',
                        type=int,
                        default=32,
                        help='Dimension of latent space (default=32)')
    parser.add_argument('--episode_len',
                        type=int,
                        default=1000,
                        help='Length of rollout (default=1000)')
    parser.add_argument(
        '--kl_bound',
        type=float,
        default=0.5,
        help='Clamp KL loss by kl_bound*latent_dim from below (default=0.5)')
    parser.add_argument('--learning_rate',
                        type=float,
                        default=1e-4,
                        help='Learning rate for optimizer (default=1e-4)')
    parser.add_argument('--cuda',
                        action='store_true',
                        default=False,
                        help='enables CUDA training (default=False)')
    # Required: the dataset path and its size are both derived from it, and
    # os.path.join(..., None) would otherwise fail with an opaque TypeError.
    parser.add_argument('--dir_name',
                        required=True,
                        help='Rollouts directory name')
    # const=2 makes a bare "--log_interval" (no value) use the default
    # instead of None, which would crash on the modulo below.
    parser.add_argument('--log_interval',
                        nargs='?',
                        const=2,
                        default=2,
                        type=int,
                        help='After how many batches to log (default=2)')
    args = parser.parse_args()

    # The trailing "_<N>" of the directory name is passed as the dataset size.
    # TODO: hack. fix?
    try:
        dataset_size = int(args.dir_name.split('_')[-1])
    except ValueError:
        parser.error("--dir_name must end in '_<size>' (got {!r})".format(
            args.dir_name))

    # Read in and prepare the data.
    dataset = RolloutDataset(
        path_to_dir=os.path.join(DATA_DIR, 'rollouts', args.dir_name),
        size=dataset_size)
    data_loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True)

    # Use GPU if available.
    use_cuda = args.cuda and torch.cuda.is_available()
    device = torch.device('cuda' if use_cuda else 'cpu')

    # Set up the model and the optimizer.
    vae = VAE(latent_dim=args.latent_dim).to(device)
    optimizer = optim.Adam(params=vae.parameters(), lr=args.learning_rate)

    def train(epoch):
        """Run one epoch of VAE training and print progress/summary lines."""
        vae.train()
        train_loss = 0
        n_seen = 0
        start_time = datetime.datetime.now()

        for batch_id, rollouts in enumerate(data_loader):
            obs = rollouts['obs']
            # The loader does not drop the final partial batch, so use the
            # real number of rollouts rather than args.batch_size.
            n_rollouts = obs.size(0)
            # Sample within the configured episode length, but never past the
            # observations actually stored for each rollout.
            n_steps = min(args.episode_len, obs.size(1))
            # Take one random observation from each rollout.
            batch = obs[
                torch.arange(n_rollouts, dtype=torch.long),
                torch.randint(high=n_steps, size=(n_rollouts, ),
                              dtype=torch.long)]
            # TODO: use all obs from the rollout (from the randomized start)?
            batch = batch.to(device)

            optimizer.zero_grad()

            recon_batch, mu, logvar = vae(batch)
            rec_loss, kl_loss = vae_loss(recon_batch,
                                         batch,
                                         mu,
                                         logvar,
                                         kl_bound=args.kl_bound)
            loss = rec_loss + kl_loss
            loss.backward()
            train_loss += loss.item()

            optimizer.step()

            # Running count, so the last (possibly smaller) batch is not
            # overcounted in the progress line.
            n_seen += len(batch)
            if batch_id % args.log_interval == 0:
                print(
                    'Epoch: {0:}\t| Examples: {1:} / {2:}({3:.0f}%)\t| Rec Loss: {4: .4f}\t| KL Loss: {5:.4f}'
                    .format(epoch, n_seen,
                            len(data_loader.dataset),
                            100. * (batch_id + 1) / len(data_loader),
                            rec_loss.item() / len(batch),
                            kl_loss.item() / len(batch)))

        duration = datetime.datetime.now() - start_time
        print(
            'Epoch {} average train loss was {:.4f} after {}m{}s of training.'.
            format(epoch, train_loss / len(data_loader.dataset),
                   *divmod(int(duration.total_seconds()), 60)))

    # TODO: add test for VAE?

    # Train loop.
    for i in range(1, args.n_epochs + 1):
        train(i)

    # Save the learned model (exist_ok avoids a check-then-create race).
    save_dir = os.path.join(DATA_DIR, 'vae')
    os.makedirs(save_dir, exist_ok=True)

    torch.save(
        vae.state_dict(),
        os.path.join(
            save_dir,
            datetime.datetime.today().isoformat() + '_' + str(args.n_epochs)))
コード例 #2
0
ファイル: experiment.py プロジェクト: taoyilee/ml-hackathon
def run(batch_size, max_batch_steps, epochs, annealing_epochs, temp, min_af, loader_workers, eval_freq, _run: "Run"):
    """Train the wildfire VAE with Pyro SVI and log results to ``_run``.

    Args:
        batch_size: Number of samples per DataLoader batch.
        max_batch_steps: If > 0, maximum number of batches processed per
            epoch; otherwise every batch is used.
        epochs: Number of training epochs.
        annealing_epochs: Number of initial epochs during which the annealing
            factor ramps linearly from ``min_af`` towards 1.0.
        temp: Directory path (supports ``/``) where artifacts are written
            before being attached to ``_run``.
        min_af: Annealing factor used at epoch 0.
        loader_workers: ``num_workers`` passed to the DataLoader.
        eval_freq: Run ``eval_light`` every ``eval_freq`` epochs, skipping
            epoch 0; disabled when <= 0.
        _run: Run object used for artifacts and scalar logging (its API looks
            like a Sacred ``Run`` — confirm).
    """
    pyro.clear_param_store()
    _run.add_artifact(_run.config["config_file"])

    # Seed randomness for repeatability
    seed_random()

    # dataset
    wildfire_dataset = WildFireDataset(train=True, config_file="config.ini")
    data_loader = DataLoader(wildfire_dataset, batch_size=batch_size, shuffle=True, num_workers=loader_workers)
    # Number of batches actually processed per epoch: the whole loader
    # (ceil(len / batch_size)), capped by max_batch_steps when that is > 0.
    n_batches = len(data_loader)
    if max_batch_steps > 0:
        n_batches = min(n_batches, max_batch_steps)
    vae_config = get_vae_config()

    config_path = temp / "vae_config.json"
    with open(config_path, "w", encoding="utf-8") as fptr:
        json.dump(vae_config.__dict__, fptr, indent=1)

    _run.add_artifact(config_path)

    vae = VAE(vae_config)
    svi = SVI(vae.model, vae.guide, vae.optimizer, loss=Trace_ELBO())

    from src.data.dataset import _ct

    bar_format = '{desc:<8.5}{percentage:3.0f}%|{bar:40}{r_bar}'
    for step in trange(epochs, desc="Epoch: ", ascii=False, dynamic_ncols=True, bar_format=bar_format):
        # Linear KL annealing over the first annealing_epochs epochs.
        if step < annealing_epochs:
            annealing_factor = min_af + (1.0 - min_af) * step / annealing_epochs
        else:
            annealing_factor = 1.0
        _run.log_scalar("annealing_factor", annealing_factor, step=step)

        epoch_elbo = 0.0
        epoch_time_slices = 0
        for batch_steps_i, d in tqdm(enumerate(data_loader), desc="Batch: ", ascii=False, dynamic_ncols=True,
                                     bar_format=bar_format,
                                     total=n_batches,
                                     leave=False):
            epoch_elbo += svi.step(_ct(d.diurnality), _ct(d.viirs), _ct(d.land_cover), _ct(d.latitude),
                                   _ct(d.longitude), _ct(d.meteorology), annealing_factor)
            # NOTE(review): counts batch * dim-1 entries, assuming viirs is
            # batched as (batch, time, ...) — confirm. Squaring the batch
            # size (shape[0] * shape[0]) does not count time slices.
            epoch_time_slices += d.viirs.shape[0] * d.viirs.shape[1]
            # Stop after exactly n_batches batches; an equality test on the
            # 0-based index would process max_batch_steps + 1 batches.
            if batch_steps_i + 1 >= n_batches:
                break
        elbo = -epoch_elbo / epoch_time_slices
        print(f" [{step:05d}] ELBO: {elbo:.3f}", end="")
        _run.log_scalar("elbo", elbo, step=step)
        alpha = pyro.param("alpha").item()
        beta = pyro.param("beta").item()
        _run.log_scalar("alpha", alpha, step=step)
        _run.log_scalar("beta", beta, step=step)
        inferred_mean, inferred_std = beta_to_mean_std(alpha, beta)
        _run.log_scalar("inferred_mean", inferred_mean, step=step)
        _run.log_scalar("inferred_std", inferred_std, step=step)

        if eval_freq > 0 and step > 0 and step % eval_freq == 0:
            logger.info("Evaluating")
            eval_light(Path(_run.observers[0].dir), vae, data_loader, wildfire_dataset, step)
            # Restore training mode after evaluation.
            vae.train()

    torch.save(vae.state_dict(), temp / "model_final.pt")
    _run.add_artifact(temp / "model_final.pt")
    vae.optimizer.save(temp / "optimizer.pt")
    _run.add_artifact(temp / "optimizer.pt")