# generate_data(game='SonicTheHedgehog-Genesis', state=level, extension_name=VAE_TRAINING_EXT, frame_jump=4, save_actions=False)

'''
===============================================
                2 - VAE training
===============================================
'''
vae = VAE()

# /!\ If the memory of your graphics card is too low, you can choose a smaller batch_size
# print('\nVAE training.')

# Do one training run with dropout
# vae.train(filepath=SAVED_MODELS_DIR + '/VAE_GreenHillZone.h5', batch_size=16, epochs=200)

# Then load the trained network and push the training a bit further without dropout
# (you need to comment out the dropout lines in the VAE.py file)
vae.load_weights(file_path=SAVED_MODELS_DIR + '/VAE_GreenHillZone.h5')
vae.train(filepath=SAVED_MODELS_DIR + '/VAE_GreenHillZone.h5', batch_size=16, epochs=200)

'''
===============================================
        3 - Visualization of the VAE
===============================================
'''
# print('VAE visualization')
# images_array_path = IMG_DIR + '/GreenHillZone.Act2.vae_train1.npy'
# vae.generate_render(data_path=images_array_path)

'''
===============================================
   4 - Generation of the LSTM's training dataset
===============================================
This training makes the LSTM learn the logic and physics of the game in order
to predict the next (latent) frame. Make sure to record a lot of different
situations (Sonic stuck in front of a wall, jumping and moving in the air, ...).
'''
class MFEC:
    def __init__(self, env, args, device='cpu'):
        """
        Instantiate an MFEC Agent
        ----------
        env: gym.Env
            gym environment to train on
        args: args class from argparser
            args are from train.py: see train.py for help with each arg
        device: string
            'cpu' or 'cuda:0' depending on use_cuda flag from train.py
        """
        self.environment_type = args.environment_type
        self.env = env
        self.actions = range(self.env.action_space.n)
        self.frames_to_stack = args.frames_to_stack
        self.Q_train_algo = args.Q_train_algo
        self.use_Q_max = args.use_Q_max
        self.force_knn = args.force_knn
        self.weight_neighbors = args.weight_neighbors
        self.delta = args.delta
        self.device = device
        self.rs = np.random.RandomState(args.seed)

        # Hyperparameters
        self.epsilon = args.initial_epsilon
        self.final_epsilon = args.final_epsilon
        self.epsilon_decay = args.epsilon_decay
        self.gamma = args.gamma
        self.lr = args.lr
        self.q_lr = args.q_lr

        # Autoencoder for state embedding network
        self.vae_batch_size = args.vae_batch_size  # batch size for training VAE
        self.vae_epochs = args.vae_epochs  # number of epochs to run VAE
        self.embedding_type = args.embedding_type
        self.SR_embedding_type = args.SR_embedding_type
        self.embedding_size = args.embedding_size
        self.in_height = args.in_height
        self.in_width = args.in_width
        if self.embedding_type == 'VAE':
            self.vae_train_frames = args.vae_train_frames
            self.vae_loss = VAELoss()
            self.vae_print_every = args.vae_print_every
            self.load_vae_from = args.load_vae_from
            self.vae_weights_file = args.vae_weights_file
            self.vae = VAE(self.frames_to_stack, self.embedding_size,
                           self.in_height, self.in_width)
            self.vae = self.vae.to(self.device)
            self.optimizer = get_optimizer(args.optimizer,
                                           self.vae.parameters(), self.lr)
        elif self.embedding_type == 'random':
            self.projection = self.rs.randn(
                self.embedding_size,
                self.in_height * self.in_width * self.frames_to_stack
            ).astype(np.float32)
        elif self.embedding_type == 'SR':
            self.SR_train_algo = args.SR_train_algo
            self.SR_gamma = args.SR_gamma
            self.SR_epochs = args.SR_epochs
            self.SR_batch_size = args.SR_batch_size
            self.n_hidden = args.n_hidden
            self.SR_train_frames = args.SR_train_frames
            self.SR_filename = args.SR_filename
            if self.SR_embedding_type == 'random':
                self.projection = np.random.randn(
                    self.embedding_size,
                    self.in_height * self.in_width).astype(np.float32)
                if self.SR_train_algo == 'TD':
                    self.mlp = MLP(self.embedding_size, self.n_hidden)
                    self.mlp = self.mlp.to(self.device)
                    self.loss_fn = nn.MSELoss(reduction='mean')
                    params = self.mlp.parameters()
                    self.optimizer = get_optimizer(args.optimizer, params,
                                                   self.lr)

        # QEC
        self.max_memory = args.max_memory
        self.num_neighbors = args.num_neighbors
        self.qec = QEC(self.actions, self.max_memory, self.num_neighbors,
                       self.use_Q_max, self.force_knn, self.weight_neighbors,
                       self.delta, self.q_lr)

        #self.state = np.empty(self.embedding_size, self.projection.dtype)
        #self.action = int
        self.memory = []
        self.print_every = args.print_every
        self.episodes = 0

    def choose_action(self, values):
        """
        Choose an action epsilon-greedily according to the Q-estimates
        """
        # Exploration
        if self.rs.random_sample() < self.epsilon:
            self.action = self.rs.choice(self.actions)
        # Exploitation
        else:
            best_actions = np.argwhere(values == np.max(values)).flatten()
            self.action = self.rs.choice(best_actions)
        return self.action

    def TD_update(self, prev_embedding, prev_action, reward, values, time):
        # On-policy value estimate of the current state (epsilon-greedy):
        # Expected Sarsa
        v_t = (1 - self.epsilon) * np.max(values) + self.epsilon * np.mean(values)
        value = reward + self.gamma * v_t
        self.qec.update(prev_embedding, prev_action, value, time - 1)

    def MC_update(self):
        value = 0.0
        for _ in range(len(self.memory)):
            experience = self.memory.pop()
            value = value * self.gamma + experience["reward"]
            self.qec.update(
                experience["state"],
                experience["action"],
                value,
                experience["time"],
            )

    def add_to_memory(self, state_embedding, action, reward, time):
        self.memory.append({
            "state": state_embedding,
            "action": action,
            "reward": reward,
            "time": time,
        })

    def run_episode(self):
        """
        Train an MFEC agent for a single episode:
            - Interact with the environment
            - Perform the update
        """
        self.episodes += 1
        RENDER_SPEED = 0.04
        RENDER = False

        episode_frames = 0
        total_reward = 0
        total_steps = 0

        # Update epsilon
        if self.epsilon > self.final_epsilon:
            self.epsilon = self.epsilon * self.epsilon_decay

        #self.env.seed(random.randint(0, 1000000))
        state = self.env.reset()
        if self.environment_type == 'fourrooms':
            fewest_steps = self.env.shortest_path_length(self.env.state)
        done = False
        time = 0
        while not done:
            time += 1
            if self.embedding_type == 'random':
                state = np.array(state).flatten()
                state_embedding = np.dot(self.projection, state)
            elif self.embedding_type == 'VAE':
                state = torch.tensor(state).permute(2, 0, 1)  # (H,W,C) -> (C,H,W)
                state = state.unsqueeze(0).to(self.device)
                with torch.no_grad():
                    mu, logvar = self.vae.encoder(state)
                    state_embedding = torch.cat([mu, logvar], 1)
                state_embedding = state_embedding.squeeze()
                state_embedding = state_embedding.cpu().numpy()
            elif self.embedding_type == 'SR':
                if self.SR_train_algo == 'TD':
                    state = np.array(state).flatten()
                    state_embedding = np.dot(self.projection, state)
                    with torch.no_grad():
                        state_embedding = self.mlp(
                            torch.tensor(state_embedding)).cpu().numpy()
                elif self.SR_train_algo == 'DP':
                    s = self.env.state
                    state_embedding = self.true_SR_dict[s]
                state_embedding = state_embedding / np.linalg.norm(state_embedding)
            if RENDER:
                self.env.render()
                # NOTE: the step counter `time` above shadows the `time` module,
                # so this call only works if the counter is renamed
                time.sleep(RENDER_SPEED)

            # Get estimated value of each action
            values = [
                self.qec.estimate(state_embedding, action)
                for action in self.actions
            ]
            action = self.choose_action(values)
            state, reward, done, _ = self.env.step(action)
            if self.Q_train_algo == 'MC':
                self.add_to_memory(state_embedding, action, reward, time)
            elif self.Q_train_algo == 'TD':
                if time > 1:
                    self.TD_update(prev_embedding, prev_action, prev_reward,
                                   values, time)
                prev_reward = reward
                prev_embedding = state_embedding
                prev_action = action
            total_reward += reward
            total_steps += 1
            episode_frames += self.env.skip

        if self.Q_train_algo == 'MC':
            self.MC_update()
        if self.episodes % self.print_every == 0:
            print("KNN usage:", np.mean(self.qec.knn_usage))
            self.qec.knn_usage = []
            print("Proportion of replace:", np.mean(self.qec.replace_usage))
            self.qec.replace_usage = []
        if self.environment_type == 'fourrooms':
            n_extra_steps = total_steps - fewest_steps
            return n_extra_steps, episode_frames, total_reward
        else:
            return episode_frames, total_reward

    def warmup(self):
        """
        Collect 1 million frames from a random policy and train the VAE
        """
        if self.embedding_type == 'VAE':
            if self.load_vae_from is not None:
                self.vae.load_state_dict(torch.load(self.load_vae_from))
                self.vae = self.vae.to(self.device)
            else:
                # Collect 1 million frames from random policy
                print("Generating dataset to train VAE from random policy")
                vae_data = []
                state = self.env.reset()
                total_frames = 0
                while total_frames < self.vae_train_frames:
                    action = random.randint(0, self.env.action_space.n - 1)
                    state, reward, done, _ = self.env.step(action)
                    vae_data.append(state)
                    total_frames += self.env.skip
                    if done:
                        state = self.env.reset()
                # Dataset, Dataloader for 1 million frames
                vae_data = torch.tensor(
                    vae_data
                )  # (N x H x W x C) - (1mill/skip X 84 X 84 X frames_to_stack)
                vae_data = vae_data.permute(0, 3, 1, 2)  # (N x C x H x W)
                vae_dataset = TensorDataset(vae_data)
                vae_dataloader = DataLoader(vae_dataset,
                                            batch_size=self.vae_batch_size,
                                            shuffle=True)
                # Training loop
                print("Training VAE")
                self.vae.train()
                for epoch in range(self.vae_epochs):
                    train_loss = 0
                    for batch_idx, batch in enumerate(vae_dataloader):
                        batch = batch[0].to(self.device)
                        self.optimizer.zero_grad()
                        recon_batch, mu, logvar = self.vae(batch)
                        loss = self.vae_loss(recon_batch, batch, mu, logvar)
                        train_loss += loss.item()
                        loss.backward()
                        self.optimizer.step()
                        if batch_idx % self.vae_print_every == 0:
                            msg = 'VAE Epoch: {} [{}/{}]\tLoss: {:.6f}'.format(
                                epoch, batch_idx * len(batch),
                                len(vae_dataloader.dataset),
                                loss.item() / len(batch))
                            print(msg)
                    print('====> Epoch {} Average loss: {:.4f}'.format(
                        epoch, train_loss / len(vae_dataloader.dataset)))
                if self.vae_weights_file is not None:
                    torch.save(self.vae.state_dict(), self.vae_weights_file)
            self.vae.eval()
        elif self.embedding_type == 'SR':
            if self.SR_embedding_type == 'random':
                if self.SR_train_algo == 'TD':
                    total_frames = 0
                    transitions = []
                    while total_frames < self.SR_train_frames:
                        observation = self.env.reset()
                        s_t = self.env.state  # will not work on Atari
                        done = False
                        while not done:
                            action = np.random.randint(self.env.action_space.n)
                            observation, reward, done, _ = self.env.step(action)
                            s_tp1 = self.env.state  # will not work on Atari
                            transitions.append((s_t, s_tp1))
                            total_frames += self.env.skip
                            s_t = s_tp1
                    # Dataset, Dataloader
                    dataset = SRDataset(self.env, self.projection, transitions)
                    dataloader = DataLoader(dataset,
                                            batch_size=self.SR_batch_size,
                                            shuffle=True)
                    train_losses = []
                    # Training loop
                    for epoch in range(self.SR_epochs):
                        for batch_idx, batch in enumerate(dataloader):
                            self.optimizer.zero_grad()
                            e_t, e_tp1 = batch
                            e_t = e_t.to(self.device)
                            e_tp1 = e_tp1.to(self.device)
                            mhat_t = self.mlp(e_t)
                            mhat_tp1 = self.mlp(e_tp1)
                            target = e_t + self.gamma * mhat_tp1.detach()
                            loss = self.loss_fn(mhat_t, target)
                            loss.backward()
                            self.optimizer.step()
                            train_losses.append(loss.item())
                        print("Epoch:", epoch, "Average loss",
                              np.mean(train_losses))
                    emb_reps = np.zeros(
                        [self.env.n_states, self.embedding_size])
                    SR_reps = np.zeros(
                        [self.env.n_states, self.embedding_size])
                    labels = []
                    room_size = self.env.room_size
                    for i, (state, obs) in enumerate(self.env.state_dict.items()):
                        emb = np.dot(self.projection, obs.flatten())
                        emb_reps[i, :] = emb
                        with torch.no_grad():
                            emb = torch.tensor(emb).to(self.device)
                            SR = self.mlp(emb).cpu().numpy()
                        SR_reps[i, :] = SR
                        if state[0] < room_size + 1 and state[1] < room_size + 1:
                            label = 0
                        elif state[0] > room_size + 1 and state[1] < room_size + 1:
                            label = 1
                        elif state[0] < room_size + 1 and state[1] > room_size + 1:
                            label = 2
                        elif state[0] > room_size + 1 and state[1] > room_size + 1:
                            label = 3
                        else:
                            label = 4
                        labels.append(label)
                    np.save('%s_SR_reps.npy' % (self.SR_filename), SR_reps)
                    np.save('%s_emb_reps.npy' % (self.SR_filename), emb_reps)
                    np.save('%s_labels.npy' % (self.SR_filename), labels)
                elif self.SR_train_algo == 'MC':
                    pass
                elif self.SR_train_algo == 'DP':
                    # Use this to ensure same order every time
                    idx_to_state = {
                        i: state
                        for i, state in enumerate(self.env.state_dict.keys())
                    }
                    state_to_idx = {v: k for k, v in idx_to_state.items()}
                    T = np.zeros([self.env.n_states, self.env.n_states])
                    for i, s in idx_to_state.items():
                        for a in range(4):
                            self.env.state = s
                            _, _, _, _ = self.env.step(a)
                            s_tp1 = self.env.state
                            T[state_to_idx[s], state_to_idx[s_tp1]] += 0.25
                    true_SR = np.eye(self.env.n_states)
                    done = False
                    t = 0
                    while not done:
                        t += 1
                        new_SR = true_SR + (self.SR_gamma**t) * (np.matmul(true_SR, T))
                        done = np.max(np.abs(true_SR - new_SR)) < 1e-10
                        true_SR = new_SR
                    self.true_SR_dict = {}
                    for s, obs in self.env.state_dict.items():
                        idx = state_to_idx[s]
                        self.true_SR_dict[s] = true_SR[idx, :]
        else:
            pass  # random projection doesn't require warmup
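
# The DP branch above builds a successor representation from the transition
# matrix T. For reference, the textbook SR M = sum_t gamma^t T^t has the closed
# form M = (I - gamma * T)^{-1} when T is row-stochastic and gamma < 1. A
# minimal sketch for use as an independent reference point; the helper name is
# ours, not part of the repository:
import numpy as np

def closed_form_successor_representation(T, gamma):
    """Closed-form SR for a row-stochastic transition matrix T and discount gamma."""
    n = T.shape[0]
    return np.linalg.inv(np.eye(n) - gamma * T)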
# TestFile.py
# The two lines below add Graphviz to the PATH on Windows
#import os
#os.environ["PATH"] += os.pathsep + 'C:/Program Files (x86)/Graphviz2.38/bin/'
from models.VAE import VAE
from utils.loader import load_celeb_generator, load_mnist

# Constructor arguments (presumably): input shape, encoder conv filters / kernel
# sizes / strides, decoder conv filters / kernel sizes / strides, latent dim
vae = VAE((28, 28, 1), [32, 64, 64], [3, 3, 3], [2, 2, 1],
          [64, 64, 1], [3, 3, 3], [2, 2, 1], 200)

(x_train, y_train), (x_test, y_test) = load_mnist()
vae.compile(0.005, 1000)  # learning rate and (presumably) reconstruction-loss weight
vae.train(x_train, 100, 1024)
from tensorflow.keras.datasets import mnist
from models.VAE import VAE
import tensorflow as tf
import numpy as np

# Load MNIST and scale pixel values to [0, 1]
(x_train, _), (x_test, _) = mnist.load_data()
x_train = x_train.astype('float32') / 255.
x_test = x_test.astype('float32') / 255.
x_train = np.reshape(x_train, (len(x_train), 28, 28, 1))
x_test = np.reshape(x_test, (len(x_test), 28, 28, 1))

vae = VAE()
vae.build()
vae.train(x_train, x_test)
def train_vae(args, dtype=torch.float32):
    torch.set_default_dtype(dtype)
    state_dim = args.state_dim
    output_path = args.output_path

    # generate state pairs
    expert_traj_raw = list(pickle.load(open(args.expert_traj_path, "rb")))
    state_pairs = generate_pairs(expert_traj_raw, state_dim,
                                 args.size_per_traj, max_step=10,
                                 min_step=5)  # tune the step size if needed

    # shuffle and split
    idx = np.arange(state_pairs.shape[0])
    np.random.shuffle(idx)
    state_pairs = state_pairs[idx, :]
    split = (state_pairs.shape[0] * 19) // 20
    state_tuples = state_pairs[:split, :]
    test_state_tuples = state_pairs[split:, :]
    print(state_tuples.shape)
    print(test_state_tuples.shape)

    goal_model = VAE(state_dim, latent_dim=128)
    optimizer_vae = torch.optim.Adam(goal_model.parameters(), lr=args.model_lr)

    save_path = '{}_softbc_{}_{}'.format(
        datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S"),
        args.env_name, args.beta)
    writer = SummaryWriter(log_dir=os.path.join(output_path, 'runs/' + save_path))

    if args.weight:
        state_dim = state_dim + 1
        state_tuples = torch.from_numpy(state_pairs).to(dtype)
        s, t = state_tuples[:, :state_dim - 1], state_tuples[:, state_dim:2 * state_dim]
        state_tuples_test = torch.from_numpy(test_state_tuples).to(dtype)
        s_test, t_test = state_tuples_test[:, :state_dim - 1], state_tuples_test[:, state_dim:2 * state_dim]
    else:
        state_tuples = torch.from_numpy(state_pairs).to(dtype)
        s, t = state_tuples[:, :state_dim], state_tuples[:, state_dim:2 * state_dim]
        state_tuples_test = torch.from_numpy(test_state_tuples).to(dtype)
        s_test, t_test = state_tuples_test[:, :state_dim], state_tuples_test[:, state_dim:2 * state_dim]

    for i in range(1, args.iter + 1):
        loss = goal_model.train(s, t, epoch=args.epoch, optimizer=optimizer_vae,
                                batch_size=args.optim_batch_size, beta=args.beta,
                                use_weight=args.weight)
        next_states = goal_model.get_next_states(s_test)
        if args.weight:
            val_error = (t_test[:, -1].unsqueeze(1) * (t_test[:, :-1] - next_states)**2).mean()
        else:
            val_error = ((t_test[:, :-1] - next_states)**2).mean()
        writer.add_scalar('loss/vae', loss, i)
        writer.add_scalar('valid/vae', val_error, i)
        if i % args.lr_decay_rate == 0:
            adjust_lr(optimizer_vae, 2.)

    torch.save(
        goal_model.state_dict(),
        os.path.join(output_path, '{}_{}_vae.pt'.format(args.env_name, str(args.beta))))
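
# A hedged sketch of what `generate_pairs` (used above but not shown here) might
# do: sample (s_t, s_{t+k}) pairs from each expert trajectory with a random gap
# k in [min_step, max_step]. The helper name, the per-trajectory layout, and the
# omission of the optional weight column are assumptions for illustration only.
import numpy as np

def generate_pairs_sketch(trajs, state_dim, size_per_traj, max_step=10, min_step=5):
    pairs = []
    for traj in trajs:
        traj = np.asarray(traj)[:, :state_dim]  # keep only the state columns
        if len(traj) <= max_step:
            continue
        for _ in range(size_per_traj):
            t = np.random.randint(0, len(traj) - max_step)
            k = np.random.randint(min_step, max_step + 1)
            pairs.append(np.concatenate([traj[t], traj[t + k]]))
    return np.stack(pairs)  # shape: (N, 2 * state_dim)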
dataLoader_train = dataLoader_val
trainLoader = valLoader
epoch = 0
gen_iterations = 1
display_iters = 5
recon_loss_sum, kl_loss_sum = 0, 0
Validation_loss = []
loss_L1 = lossL1()

while epoch < niter:
    epoch += 1

    # training phase
    vae3d.train()
    for idx, (rdfs, masks, weights, qsms) in enumerate(trainLoader):
        if gen_iterations % display_iters == 0:
            print('epochs: [%d/%d], batches: [%d/%d], time: %ds, case_validation: %f'
                  % (epoch, niter, idx,
                     dataLoader_train.num_samples // batch_size + 1,
                     time.time() - t0, opt['case_validation']))
            print('Recon loss: %f, kl_loss: %f'
                  % (recon_loss_sum / display_iters, kl_loss_sum / display_iters))
            if epoch > 1:
                print('Validation loss of last epoch: %f' % (Validation_loss[-1]))
            recon_loss_sum, kl_loss_sum = 0, 0

        qsms = (qsms.to(device, dtype=torch.float) + trans) * scale
        masks = masks.to(device, dtype=torch.float)
        qsms = qsms * masks

        recon_loss, kl_loss = vae_train(model=vae3d, optimizer=optimizer,
                                        x=qsms, mask=masks)
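
# A minimal sketch of a masked VAE training step with the same call signature as
# the `vae_train` used above (the real implementation is not shown in this
# snippet). It assumes `model(x)` returns (reconstruction, mu, logvar); the
# masked-MSE / KL split and the helper name are assumptions.
import torch
import torch.nn.functional as F

def vae_train_step_sketch(model, optimizer, x, mask):
    optimizer.zero_grad()
    recon, mu, logvar = model(x)
    # Reconstruction error restricted to voxels inside the mask
    recon_loss = F.mse_loss(recon * mask, x * mask, reduction='sum') / mask.sum()
    # KL divergence between q(z|x) = N(mu, sigma^2) and the unit Gaussian prior
    kl_loss = -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp())
    (recon_loss + kl_loss).backward()
    optimizer.step()
    return recon_loss.item(), kl_loss.item()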
class VAE_TRAINER():
    def __init__(self, params):
        self.params = params
        self.loss_function = {
            'ms-ssim': ms_ssim_loss,
            'mse': mse_loss,
            'mix': mix_loss
        }[params["loss"]]

        # Choose device
        self.cuda = params["cuda"] and torch.cuda.is_available()
        torch.manual_seed(params["seed"])
        # Fix numeric divergence due to bug in Cudnn
        torch.backends.cudnn.benchmark = True
        self.device = torch.device("cuda" if self.cuda else "cpu")

        # Prepare data transformations
        red_size = params["img_size"]
        transform_train = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize((red_size, red_size)),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
        ])
        transform_val = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize((red_size, red_size)),
            transforms.ToTensor(),
        ])

        # Initialize Data loaders
        op_dataset = RolloutObservationDataset(params["path_data"],
                                               transform_train,
                                               train=True)
        val_dataset = RolloutObservationDataset(params["path_data"],
                                                transform_val,
                                                train=False)
        self.train_loader = torch.utils.data.DataLoader(
            op_dataset,
            batch_size=params["batch_size"],
            shuffle=True,
            num_workers=0)
        self.eval_loader = torch.utils.data.DataLoader(
            val_dataset,
            batch_size=params["batch_size"],
            shuffle=False,
            num_workers=0)

        # Initialize model and hyperparams
        self.model = VAE(nc=3,
                         ngf=64,
                         ndf=64,
                         latent_variable_size=params["latent_size"],
                         cuda=self.cuda).to(self.device)
        self.optimizer = optim.Adam(self.model.parameters())
        self.init_vae_model()
        self.visualize = params["visualize"]
        if self.visualize:
            self.plotter = VisdomLinePlotter(env_name=params['env'])
            self.img_plotter = VisdomImagePlotter(env_name=params['env'])
        self.alpha = params["alpha"] if params["alpha"] else 1.0

    def train(self, epoch):
        self.model.train()
        # dataset_train.load_next_buffer()
        mse_loss = 0
        ssim_loss = 0
        train_loss = 0

        # Train step
        for batch_idx, data in enumerate(self.train_loader):
            data = data.to(self.device)
            self.optimizer.zero_grad()
            recon_batch, mu, logvar = self.model(data)
            loss, mse, ssim = self.loss_function(recon_batch, data, mu, logvar,
                                                 self.alpha)
            loss.backward()
            train_loss += loss.item()
            ssim_loss += ssim
            mse_loss += mse
            self.optimizer.step()
            if batch_idx % self.params["log_interval"] == 0:
                print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    epoch, batch_idx * len(data),
                    len(self.train_loader.dataset),
                    100. * batch_idx / len(self.train_loader), loss.item()))
                print('MSE: {} , SSIM: {:.4f}'.format(mse, ssim))

        step = len(self.train_loader.dataset) / float(self.params["batch_size"])
        mean_train_loss = train_loss / step
        mean_ssim_loss = ssim_loss / step
        mean_mse_loss = mse_loss / step
        print('-- Epoch: {} Average loss: {:.4f}'.format(epoch, mean_train_loss))
        print('-- Average MSE: {:.5f} Average SSIM: {:.4f}'.format(
            mean_mse_loss, mean_ssim_loss))
        if self.visualize:
            self.plotter.plot('loss', 'train', 'VAE Train Loss', epoch,
                              mean_train_loss)
        return

    def eval(self):
        self.model.eval()
        # dataset_test.load_next_buffer()
        eval_loss = 0
        mse_loss = 0
        ssim_loss = 0
        vis = True

        with torch.no_grad():
            # Eval step
            for data in self.eval_loader:
                data = data.to(self.device)
                recon_batch, mu, logvar = self.model(data)
                loss, mse, ssim = self.loss_function(recon_batch, data, mu,
                                                     logvar, self.alpha)
                eval_loss += loss.item()
                ssim_loss += ssim
                mse_loss += mse
                if vis:
                    # `epoch` is expected to be defined at module level by the
                    # training script that drives this class
                    org_title = "Epoch: " + str(epoch)
                    comparison1 = torch.cat([
                        data[:4],
                        recon_batch.view(self.params["batch_size"], 3,
                                         self.params["img_size"],
                                         self.params["img_size"])[:4]
                    ])
                    if self.visualize:
                        self.img_plotter.plot(comparison1, org_title)
                    vis = False

        step = len(self.eval_loader.dataset) / float(self.params["batch_size"])
        mean_eval_loss = eval_loss / step
        mean_ssim_loss = ssim_loss / step
        mean_mse_loss = mse_loss / step
        print('-- Eval set loss: {:.4f}'.format(mean_eval_loss))
        print('-- Eval MSE: {:.5f} Eval SSIM: {:.4f}'.format(
            mean_mse_loss, mean_ssim_loss))
        if self.visualize:
            self.plotter.plot('loss', 'eval', 'VAE Eval Loss', epoch,
                              mean_eval_loss)
            self.plotter.plot('loss', 'mse train', 'VAE MSE Loss', epoch,
                              mean_mse_loss)
            self.plotter.plot('loss', 'ssim train', 'VAE SSIM Loss', epoch,
                              mean_ssim_loss)
        return mean_eval_loss

    def init_vae_model(self):
        self.vae_dir = os.path.join(self.params["logdir"], 'vae')
        check_dir(self.vae_dir, 'samples')
        if not self.params["noreload"]:  # and os.path.exists(reload_file):
            reload_file = os.path.join(self.params["vae_location"], 'best.tar')
            state = torch.load(reload_file)
            print("Reloading model at epoch {}, with eval error {}".format(
                state['epoch'], state['precision']))
            self.model.load_state_dict(state['state_dict'])
            self.optimizer.load_state_dict(state['optimizer'])

    def checkpoint(self, cur_best, eval_loss):
        # Save the best and last checkpoint
        best_filename = os.path.join(self.vae_dir, 'best.tar')
        filename = os.path.join(self.vae_dir, 'checkpoint.tar')
        is_best = not cur_best or eval_loss < cur_best
        if is_best:
            cur_best = eval_loss
        save_checkpoint(
            {
                'epoch': epoch,
                'state_dict': self.model.state_dict(),
                'precision': eval_loss,
                'optimizer': self.optimizer.state_dict()
            }, is_best, filename, best_filename)
        return cur_best

    def plot(self, train, eval, epochs):
        plt.plot(epochs, train, label="train loss")
        plt.plot(epochs, eval, label="eval loss")
        plt.legend()
        plt.grid()
        plt.savefig(self.params["logdir"] + "/vae_training_curve.png")
        plt.close()
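
# A hedged sketch of the plain-MSE variant of the loss functions selected in
# VAE_TRAINER.__init__ (ms_ssim_loss / mse_loss / mix_loss are defined elsewhere
# and not shown here). The (loss, mse, ssim) return layout is inferred from how
# the trainer consumes it; the alpha weighting and the zero SSIM term are
# assumptions made for illustration.
import torch
import torch.nn.functional as F

def mse_loss_sketch(recon_x, x, mu, logvar, alpha=1.0):
    mse = F.mse_loss(recon_x, x, reduction='sum') / x.size(0)
    kld = -0.5 * torch.mean(torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=1))
    loss = alpha * mse + kld
    return loss, mse.item(), 0.0  # no structural-similarity term in the pure-MSE variant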
def train_vae(
    model: VAE,
    optimizer,
    train_loader: DataLoader,
    valid_loader: DataLoader,
    nr_epochs: int,
    device: str,
    results_writer: ResultsWriter,
    config: Config,
    decoder,
):
    """
    Train the VAE, based on a config file
    """
    # Define highest values
    best_valid_loss = np.inf
    previous_valid_loss = np.inf

    # Create ELBO loss function
    loss_fn = make_elbo_criterion(
        vocab_size=model.vocab_size,
        latent_size=model.latent_size,
        freebits_param=config.freebits_param,
        mu_force_beta_param=config.mu_force_beta_param)

    for epoch in range(nr_epochs):
        for idx, (train_batch, batch_sent_lengths) in enumerate(train_loader):
            it = epoch * len(train_loader) + idx

            batch_loss, perp, preds = train_batch_vae(
                model, optimizer, loss_fn, train_batch, device,
                config.mu_force_beta_param, batch_sent_lengths,
                results_writer, it)
            elbo_loss, kl_loss, nlll, mu_loss = batch_loss

            if it % config.print_every == 0:
                print(
                    f'Iteration: {it} || NLLL: {nlll} || Perp: {perp} || KL Loss: {kl_loss} || Mu Loss: {mu_loss} || Total: {elbo_loss}'
                )
                # Store in the table
                train_vae_results = make_vae_results_dict(
                    batch_loss, perp, model, config, epoch, it)
                results_writer.add_train_batch_results(train_vae_results)

            if it % config.train_text_gen_every == 0:
                with torch.no_grad():
                    decoded_first_pred = decoder(preds.detach())
                    decoded_first_true = decoder(train_batch[:, 1:])
                    results_writer.add_sentence_predictions(
                        decoded_first_pred, decoded_first_true, it)
                    print(f'VAE is generating sentences on {it}: \n')
                    print(f'\t The true sentence is: "{decoded_first_true}" \n')
                    print(f'\t The predicted sentence is: "{decoded_first_pred}" \n')

            if idx % config.validate_every == 0 and it != 0:
                print('Validating model')
                valid_losses, valid_perp = evaluate_vae(
                    model,
                    valid_loader,
                    epoch,
                    device,
                    loss_fn,
                    config.mu_force_beta_param,
                    iteration=it)

                # Store validation results
                valid_vae_results = make_vae_results_dict(
                    valid_losses, valid_perp, model, config, epoch, it)
                results_writer.add_valid_results(valid_vae_results)

                valid_elbo_loss, valid_kl_loss, valid_nll_loss, valid_mu_loss = valid_losses
                print(
                    f'Validation Results || Elbo loss: {valid_elbo_loss} || KL loss: {valid_kl_loss} || NLLL: {valid_nll_loss} || Perp: {valid_perp} || Mu loss: {valid_mu_loss}'
                )

                # Check if the model is better and save
                previous_valid_loss = valid_elbo_loss
                if previous_valid_loss < best_valid_loss:
                    print(f'New best validation score of {previous_valid_loss}!')
                    best_valid_loss = previous_valid_loss
                    save_model(
                        f'vae_best_mu{config.mu_force_beta_param}_wd{model.param_wdropout_k}_fb{config.freebits_param}',
                        model, optimizer, it)
                model.train()

    results_writer.save_train_results()
    results_writer.save_valid_results()
    print('Done training the VAE')
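
# A minimal sketch of an ELBO criterion with a "free bits" floor on the KL term,
# in the spirit of the `make_elbo_criterion` factory used above. The real factory
# is not shown here; the returned callable's signature and the omission of the
# mu-force regularizer are assumptions made for illustration.
import torch
import torch.nn.functional as F

def make_elbo_criterion_sketch(vocab_size, freebits_param):
    def elbo(logits, targets, mu, logvar):
        # Token-level negative log-likelihood, summed over the sequence and
        # averaged over the batch (logits: [B, T, V], targets: [B, T])
        nll = F.cross_entropy(logits.reshape(-1, vocab_size),
                              targets.reshape(-1),
                              reduction='sum') / targets.size(0)
        # KL(q(z|x) || N(0, I)) per latent dimension, averaged over the batch
        kl_per_dim = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp()).mean(dim=0)
        # Free bits: each dimension contributes at least `freebits_param` nats
        kl = torch.clamp(kl_per_dim, min=freebits_param).sum()
        return nll + kl, kl, nll
    return elbo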