def main():
    """Train a VAE on stored rollout observations and save its weights.

    Command-line driven: parses args, loads rollouts from
    DATA_DIR/rollouts/<dir_name>, trains for --n_epochs epochs, and writes
    the state dict to DATA_DIR/vae/<timestamp>_<n_epochs>.
    """
    parser = argparse.ArgumentParser(description='VAE')
    parser.add_argument('--batch_size', type=int, default=100,
                        help='Batch size for training (default=100)')
    parser.add_argument('--n_epochs', type=int, default=100,
                        help='Number of epochs to train (default=100)')
    parser.add_argument('--latent_dim', type=int, default=32,
                        help='Dimension of latent space (default=32)')
    parser.add_argument('--episode_len', type=int, default=1000,
                        help='Length of rollout (default=1000)')
    parser.add_argument('--kl_bound', type=float, default=0.5,
                        help='Clamp KL loss by kl_bound*latent_dim from below (default=0.5)')
    parser.add_argument('--learning_rate', type=float, default=1e-4,
                        help='Learning rate for optimizer (default=1e-4)')
    parser.add_argument('--cuda', action='store_true', default=False,
                        help='enables CUDA training (default=False)')
    parser.add_argument('--dir_name', help='Rollouts directory name')
    # Default was the string '2' before; use a plain int.
    parser.add_argument('--log_interval', nargs='?', default=2, type=int,
                        help='After how many batches to log (default=2)')
    args = parser.parse_args()

    # Read in and prepare the data. The rollout count is encoded in the
    # directory name suffix.
    dataset = RolloutDataset(
        path_to_dir=os.path.join(DATA_DIR, 'rollouts', args.dir_name),
        size=int(args.dir_name.split('_')[-1]))  # TODO: hack. fix?
    data_loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True)

    # Use GPU if available.
    use_cuda = args.cuda and torch.cuda.is_available()
    device = torch.device('cuda' if use_cuda else 'cpu')

    # Set up the model and the optimizer.
    vae = VAE(latent_dim=args.latent_dim).to(device)
    optimizer = optim.Adam(params=vae.parameters(), lr=args.learning_rate)

    # Training procedure.
    def train(epoch):
        """Run one training epoch, logging every args.log_interval batches."""
        vae.train()
        train_loss = 0
        start_time = datetime.datetime.now()

        for batch_id, batch in enumerate(data_loader):
            batch = batch['obs']
            # Take one random observation from each rollout. Use the actual
            # batch dimension (not args.batch_size) so the final, possibly
            # smaller, batch does not raise an IndexError, and draw indices
            # from args.episode_len instead of a hard-coded 1000.
            n = batch.size(0)
            batch = batch[
                torch.arange(n, dtype=torch.long),
                torch.randint(high=args.episode_len, size=(n,), dtype=torch.long)]
            # TODO: use all obs from the rollout (from the randomized start)?
            batch = batch.to(device)

            optimizer.zero_grad()

            recon_batch, mu, logvar = vae(batch)
            rec_loss, kl_loss = vae_loss(recon_batch, batch, mu, logvar,
                                         kl_bound=args.kl_bound)
            loss = rec_loss + kl_loss
            loss.backward()
            train_loss += loss.item()

            optimizer.step()

            if batch_id % args.log_interval == 0:
                print(
                    'Epoch: {0:}\t| Examples: {1:} / {2:}({3:.0f}%)\t| Rec Loss: {4: .4f}\t| KL Loss: {5:.4f}'
                    .format(epoch, (batch_id + 1) * len(batch),
                            len(data_loader.dataset),
                            100. * (batch_id + 1) / len(data_loader),
                            rec_loss.item() / len(batch),
                            kl_loss.item() / len(batch)))

        duration = datetime.datetime.now() - start_time
        print(
            'Epoch {} average train loss was {:.4f} after {}m{}s of training.'.
            format(epoch, train_loss / len(data_loader.dataset),
                   *divmod(int(duration.total_seconds()), 60)))

    # TODO: add test for VAE?

    # Train loop.
    for i in range(1, args.n_epochs + 1):
        train(i)

    # Save the learned model; exist_ok avoids a race between the existence
    # check and the directory creation.
    save_dir = os.path.join(DATA_DIR, 'vae')
    os.makedirs(save_dir, exist_ok=True)
    torch.save(
        vae.state_dict(),
        os.path.join(save_dir,
                     datetime.datetime.today().isoformat() + '_' +
                     str(args.n_epochs)))
def run(batch_size, max_batch_steps, epochs, annealing_epochs, temp, min_af,
        loader_workers, eval_freq, _run: "Run"):
    """Train the wildfire VAE with SVI, logging metrics to the sacred run.

    `temp` is a working directory (Path-like), `min_af` the minimum
    KL-annealing factor, `max_batch_steps` an optional cap on batches per
    epoch (<= 0 disables it), and `eval_freq` how often, in epochs, to run
    the light evaluation (<= 0 disables it).
    """
    pyro.clear_param_store()
    _run.add_artifact(_run.config["config_file"])

    # Seed randomness for repeatability
    seed_random()

    # Training data.
    wildfire_dataset = WildFireDataset(train=True, config_file="config.ini")
    data_loader = DataLoader(wildfire_dataset, batch_size=batch_size,
                             shuffle=True, num_workers=loader_workers)

    # Per-epoch batch count, used only as the progress-bar total.
    batches_per_epoch = np.ceil(len(wildfire_dataset) / batch_size)
    if max_batch_steps > 0:
        batches_per_epoch = max_batch_steps

    # Persist the model configuration as a run artifact.
    vae_config = get_vae_config()
    with open(temp / "vae_config.json", "w") as fptr:
        json.dump(vae_config.__dict__, fptr, indent=1)
    _run.add_artifact(temp / "vae_config.json")

    vae = VAE(vae_config)
    svi = SVI(vae.model, vae.guide, vae.optimizer, loss=Trace_ELBO())

    from src.data.dataset import _ct

    for epoch in trange(epochs, desc="Epoch: ", ascii=False,
                        dynamic_ncols=True,
                        bar_format='{desc:<8.5}{percentage:3.0f}%|{bar:40}{r_bar}'):
        # Linear KL-annealing warm-up from min_af up to 1.0.
        if epoch >= annealing_epochs:
            annealing_factor = 1.0
        else:
            annealing_factor = min_af + (1.0 - min_af) * epoch / annealing_epochs
        _run.log_scalar("annealing_factor", annealing_factor, step=epoch)

        running_elbo = 0.0
        slice_count = 0
        for batch_idx, batch in tqdm(enumerate(data_loader), desc="Batch: ",
                                     ascii=False, dynamic_ncols=True,
                                     bar_format='{desc:<8.5}{percentage:3.0f}%|{bar:40}{r_bar}',
                                     total=batches_per_epoch, leave=False):
            running_elbo += svi.step(_ct(batch.diurnality), _ct(batch.viirs),
                                     _ct(batch.land_cover), _ct(batch.latitude),
                                     _ct(batch.longitude),
                                     _ct(batch.meteorology), annealing_factor)
            # NOTE(review): this squares the batch dimension
            # (shape[0] * shape[0]); shape[0] * shape[1], i.e. batch size
            # times time slices, looks like the intent — confirm before
            # changing, since it rescales the reported ELBO.
            slice_count += batch.viirs.shape[0] * batch.viirs.shape[0]
            if max_batch_steps > 0 and batch_idx == max_batch_steps:
                break

        elbo = -running_elbo / slice_count
        print(f" [{epoch:05d}] ELBO: {elbo:.3f}", end="")
        _run.log_scalar("elbo", elbo, step=epoch)

        # Log the learned Beta parameters and their implied mean/std.
        alpha = pyro.param("alpha").item()
        beta = pyro.param("beta").item()
        _run.log_scalar("alpha", alpha, step=epoch)
        _run.log_scalar("beta", beta, step=epoch)
        inferred_mean, inferred_std = beta_to_mean_std(alpha, beta)
        _run.log_scalar("inferred_mean", inferred_mean, step=epoch)
        _run.log_scalar("inferred_std", inferred_std, step=epoch)

        if eval_freq > 0 and epoch > 0 and epoch % eval_freq == 0:
            logger.info("Evaluating")
            eval_light(Path(_run.observers[0].dir), vae, data_loader,
                       wildfire_dataset, epoch)
            # eval_light presumably flips the model to eval mode; restore.
            vae.train()

    # Save final model and optimizer state as run artifacts.
    torch.save(vae.state_dict(), temp / "model_final.pt")
    _run.add_artifact(temp / "model_final.pt")
    vae.optimizer.save(temp / "optimizer.pt")
    _run.add_artifact(temp / "optimizer.pt")