def run_vqvae(info, vqvae_model, opt, train_buffer, valid_buffer, num_samples_to_train=10000, save_every_samples=1000, batches=0):
    """Drive the VQ-VAE training loop over minibatches from ``train_buffer``.

    Resumes the sample counter from ``info['model_train_cnts']`` when present,
    trains until ``num_samples_to_train`` samples have been consumed, and
    checkpoints via ``save_vqvae`` every ``save_every_samples`` samples
    (``save_vqvae`` is presumably what advances ``info['model_last_save']`` —
    TODO confirm against its definition).

    :param info: experiment-state dict; reads 'model_train_cnts',
        'MODEL_BATCH_SIZE', 'model_last_save'
    :param vqvae_model: model being trained
    :param opt: optimizer for ``vqvae_model``
    :param train_buffer: buffer supplying training minibatches
    :param valid_buffer: buffer supplying validation minibatches for checkpoints
    :param num_samples_to_train: stop once the sample counter reaches this value
    :param save_every_samples: checkpoint interval, in samples
    :param batches: starting batch counter; 0 forces a checkpoint on the first batch
    :return: tuple of (vqvae_model, opt)
    """
    # Resume the sample counter from a previous run if one is recorded.
    if len(info['model_train_cnts']):
        train_cnt = info['model_train_cnts'][-1]
    else:
        train_cnt = 0
    # BUGFIX: start the clock once, before the loop. Previously ``st`` was
    # reset at the top of every iteration, so the periodic print reported only
    # the duration of the most recent batch instead of elapsed training time
    # (the sibling training loops in this file all time the whole loop).
    st = time.time()
    while train_cnt < num_samples_to_train:
        batch = train_buffer.get_minibatch(info['MODEL_BATCH_SIZE'])
        opt.zero_grad()
        # train_vqvae computes the losses and runs backward(); we only clip
        # and step here.
        avg_train_losses, _, _, _ = train_vqvae(vqvae_model, info, batch)
        parameters = list(vqvae_model.parameters())
        clip_grad_value_(parameters, 5)
        opt.step()
        # (removed a redundant opt.zero_grad() here — gradients are already
        # zeroed at the top of every iteration before train_vqvae runs)
        train_cnt += info['MODEL_BATCH_SIZE']
        # Checkpoint periodically, and always on the very first batch.
        if (((train_cnt - info['model_last_save']) >= save_every_samples) or batches == 0):
            valid_batch = valid_buffer.get_minibatch(info['MODEL_BATCH_SIZE'])
            save_vqvae(info, train_cnt, vqvae_model, opt, avg_train_losses, valid_batch)
        batches += 1
        if not batches % 500:
            print("finished %s epoch after %s seconds at cnt %s" % (batches, time.time() - st, train_cnt))
    return vqvae_model, opt
def train_gp(self, seed=0, n_iter=300, pred_interval=5, test=None, verbose=True, n_epoch=50):
    """Optimize the GP model with Adam on minibatch NLL.

    Runs ``n_iter`` outer iterations; each iteration consumes at most
    ``n_epoch + 1`` batches from ``self.data``.  Every ``pred_interval``
    iterations the mean training NLL (and, when ``test`` is given, the test
    RMSE) is appended to ``self.history``.

    :param seed: RNG seed passed to ``set_seed``
    :param n_iter: number of outer optimization iterations
    :param pred_interval: record/report metrics every this many iterations
    :param test: optional dict with 'X' and 'Y' for held-out RMSE
    :param verbose: print each metrics record when True
    :param n_epoch: cap on batches consumed per outer iteration
    :return: ``self.history`` with the accumulated metric records
    """
    set_seed(seed)
    print('SEED=', seed)
    optimizer = opt.Adam(self.model.parameters())
    for iteration in range(n_iter):
        total_nll = 0.0
        seen_batches = 0
        for x_batch, y_batch in self.data:
            # Stop this iteration once the batch budget is exhausted.
            if seen_batches > n_epoch:
                break
            seen_batches += 1
            self.model.train()
            optimizer.zero_grad()
            nll = self.gp.NLL_batch(x_batch, y_batch)
            # Weight by batch size so the recorded value is a per-sample mean.
            total_nll = total_nll + nll * x_batch.shape[0]
            nll.backward()
            clip_grad_value_(self.model.parameters(), 20)
            optimizer.step()
            torch.cuda.empty_cache()
        # Only record/report on the prediction interval.
        if iteration % pred_interval:
            continue
        record = {'nll': total_nll.item() / self.train['X'].shape[0], 'iter': iteration}
        if test is not None:
            record['rmse'] = rmse(self.gp(test['X'])[0], test['Y']).item()
        if verbose:
            print(record)
        self.history.append(record)
    return self.history
def train_vqvae(train_cnt): st = time.time() #for batch_idx, (data, label, data_index) in enumerate(train_loader): batches = 0 while train_cnt < args.num_examples_to_train: vqvae_model.train() opt.zero_grad() #states, actions, rewards, next_states, terminals, is_new_epoch, relative_indexes = train_data_loader.get_unique_minibatch() states, actions, rewards, values, pred_states, terminals, is_new_epoch, relative_indexes = train_data_loader.get_framediff_minibatch( ) # because we have 4 layers in vqvae, need to be divisible by 2, 4 times states = (2 * reshape_input(states) - 1).to(DEVICE) rec = (2 * reshape_input(pred_states[:, 0][:, None]) - 1).to(DEVICE) # dont normalize diff diff = (reshape_input(pred_states[:, 1][:, None])).to(DEVICE) x_d, z_e_x, z_q_x, latents = vqvae_model(states) # (args.nr_logistic_mix/2)*3 is needed for each reconstruction z_q_x.retain_grad() rec_est = x_d[:, :nmix] diff_est = x_d[:, nmix:] loss_rec = discretized_mix_logistic_loss(rec_est, rec, nr_mix=args.nr_logistic_mix, DEVICE=DEVICE) loss_diff = discretized_mix_logistic_loss(diff_est, diff, nr_mix=args.nr_logistic_mix, DEVICE=DEVICE) loss_2 = F.mse_loss(z_q_x, z_e_x.detach()) loss_3 = args.beta * F.mse_loss(z_e_x, z_q_x.detach()) loss_rec.backward(retain_graph=True) loss_diff.backward(retain_graph=True) vqvae_model.embedding.zero_grad() z_e_x.backward(z_q_x.grad, retain_graph=True) loss_2.backward(retain_graph=True) loss_3.backward() parameters = list(vqvae_model.parameters()) clip_grad_value_(parameters, 10) opt.step() bs = float(x_d.shape[0]) loss_list = [ loss_rec.item() / bs, loss_diff.item() / bs, loss_2.item() / bs, loss_3.item() / bs ] if batches > 5: handle_checkpointing(train_cnt, loss_list) train_cnt += len(states) batches += 1 if not batches % 1000: print("finished %s epoch after %s seconds at cnt %s" % (batches, time.time() - st, train_cnt)) return train_cnt
def _train_step(self) -> None:
    """Run one training step: ``iter_per_step`` optimizer updates over the training data."""
    self.model.train()
    tb_prefix = f"{self.tb_log_prefix} " if self.tb_log_prefix else ""

    def next_batch():
        # Pull the next batch, recreating the iterator when the epoch ends.
        try:
            return next(self._train_iterator)
        except StopIteration:
            self._create_train_iterator()
            return next(self._train_iterator)

    with torch.enable_grad():
        for iter_idx in range(self.iter_per_step):
            self.optimizer.zero_grad()
            accumulated_loss = 0.0
            # Accumulate gradients over several batches before a single
            # optimizer step; each batch's loss is scaled so the total
            # matches a mean over the accumulated batches.
            for _ in range(self.batches_per_iter):
                batch = self._batch_to_device(next_batch())
                loss = self._compute_loss(batch) / self.batches_per_iter
                accumulated_loss += loss.item()
                loss.backward()
            global_step = (self.iter_per_step * self._step) + iter_idx
            # Optional gradient clipping: by norm and/or by absolute value.
            if self.max_grad_norm:
                clip_grad_norm_(self.model.parameters(), self.max_grad_norm)
            if self.max_grad_abs_val:
                clip_grad_value_(self.model.parameters(), self.max_grad_abs_val)
            log(f'{tb_prefix}Training/Loss', accumulated_loss, global_step)
            log(f'{tb_prefix}Training/Gradient_Norm', self.model.gradient_norm, global_step)
            log(f'{tb_prefix}Training/Parameter_Norm', self.model.parameter_norm, global_step)
            self.optimizer.step()
            # Per-iteration LR schedule, logged before stepping.
            if self.iter_scheduler is not None:
                learning_rate = self.iter_scheduler.get_lr()[0]  # type: ignore
                log(f'{tb_prefix}Training/LR', learning_rate, global_step)
                self.iter_scheduler.step()  # type: ignore
    # Leave no stale gradients behind for whatever runs next.
    self.optimizer.zero_grad()
def train_acn(train_cnt):
    """Jointly train the global ACN encoder, GMP prior and PixelCNN decoder.

    Loops until the sample counter reaches ``args.num_examples_to_train``.
    Each step encodes a minibatch, conditions the PixelCNN on the code and
    actions to reconstruct the next frame, writes the code into the prior's
    code table, and optimizes KL + reconstruction loss.  Uses module-level
    globals: ``encoder_model``, ``prior_model``, ``pcnn_decoder``, ``opt``,
    ``train_data_loader``, ``args``, ``DEVICE``, ``handle_checkpointing``.

    :param train_cnt: number of samples already consumed (resume point)
    :return: updated sample count
    """
    train_kl_loss = 0.0
    train_rec_loss = 0.0
    init_cnt = train_cnt
    st = time.time()
    #for batch_idx, (data, label, data_index) in enumerate(train_loader):
    batches = 0
    while train_cnt < args.num_examples_to_train:
        encoder_model.train()
        prior_model.train()
        pcnn_decoder.train()
        opt.zero_grad()
        states, actions, rewards, next_states, terminals, is_new_epoch, relative_indexes = train_data_loader.get_unique_minibatch()
        states = states.to(DEVICE)
        # 1 channel expected
        next_states = next_states[:, args.number_condition-1:].to(DEVICE)
        actions = actions.to(DEVICE)
        z, u_q = encoder_model(data) if False else encoder_model(states)
        np_uq = u_q.detach().cpu().numpy()
        # drop into an interactive debugger if the codes blow up
        if np.isinf(np_uq).sum() or np.isnan(np_uq).sum():
            print('train bad')
            embed()
        # add the predicted codes to the input
        yhat_batch = torch.sigmoid(pcnn_decoder(x=next_states, class_condition=actions, float_condition=z))
        #yhat_batch = torch.sigmoid(pcnn_decoder(x=next_states, float_condition=z))
        #print(train_cnt)
        # store this batch's codes in the prior's code table at their indexes
        prior_model.codes[relative_indexes-args.number_condition] = u_q.detach().cpu().numpy()
        # NOTE(review): np_uq recomputed here although u_q is unchanged since
        # the first check above — this second check looks redundant
        np_uq = u_q.detach().cpu().numpy()
        if np.isinf(np_uq).sum() or np.isnan(np_uq).sum():
            print('train bad')
            embed()
        mix, u_ps, s_ps = prior_model(u_q)
        kl_loss, rec_loss = acn_gmp_loss_function(yhat_batch, next_states, u_q, mix, u_ps, s_ps)
        loss = kl_loss + rec_loss
        loss.backward()
        parameters = list(encoder_model.parameters()) + list(prior_model.parameters()) + list(pcnn_decoder.parameters())
        clip_grad_value_(parameters, 10)
        train_kl_loss += kl_loss.item()
        train_rec_loss += rec_loss.item()
        opt.step()
        # add batch size because it hasn't been added to train cnt yet
        avg_train_kl_loss = train_kl_loss/float((train_cnt+states.shape[0])-init_cnt)
        avg_train_rec_loss = train_rec_loss/float((train_cnt+states.shape[0])-init_cnt)
        handle_checkpointing(train_cnt, avg_train_kl_loss, avg_train_rec_loss)
        train_cnt += len(states)
        batches += 1
        if not batches % 1000:
            print("finished %s epoch after %s seconds at cnt %s" % (batches, time.time()-st, train_cnt))
    return train_cnt
def train_vqvae(train_cnt): st = time.time() #for batch_idx, (data, label, data_index) in enumerate(train_loader): batches = 0 while train_cnt < args.num_examples_to_train: vqenc.train() pcnn_decoder.train() opt.zero_grad() states, actions, rewards, next_states, terminals, is_new_epoch, relative_indexes = train_data_loader.get_unique_minibatch( ) # because we have 4 layers in vqvae, need to be divisible by 2, 4 times states = reshape_input(states).to(DEVICE) # only predict future observation - normalize targets = (2 * states[:, -1:] - 1).to(DEVICE) #actions = actions.to(DEVICE) x_d, z_e_x, z_q_x, latents = vqvae_model(states, targets) #z_e_x, z_q_x, latents = vqenc(states) #float_condition = latents.view(latents.shape[0], latents.shape[1]*latents.shape[2]).float() #x_d = pcnn_decoder(targets, class_condition=actions, float_condition=float_condition) z_q_x.retain_grad() vqvae_model.spatial_condition.retain_grad() loss_1 = discretized_mix_logistic_loss(x_d, targets, nr_mix=args.nr_logistic_mix, DEVICE=DEVICE) loss_2 = F.mse_loss(z_q_x, z_e_x.detach()) loss_3 = args.beta * F.mse_loss(z_e_x, z_q_x.detach()) #loss_1, loss_2, loss_3 = get_vqvae_loss(x_d, targets, z_e_x, z_q_x, nr_logistic_mix=args.nr_logistic_mix, beta=args.beta, device=DEVICE) loss_1.backward(retain_graph=True) #vqvae_model.encoder.embedding.zero_grad() #z_e_x.backward(z_q_x.grad, retain_graph=True) z_e_x.backward(vqvae_model.spatial_condition.grad, retain_graph=True) loss_2.backward(retain_graph=True) loss_3.backward() parameters = list(vqvae_model.parameters()) clip_grad_value_(parameters, 10) opt.step() bs = float(x_d.shape[0]) handle_checkpointing(train_cnt, loss_1.item() / bs, loss_2.item() / bs, loss_3.item() / bs) train_cnt += len(states) batches += 1 if not batches % 1000: print("finished %s epoch after %s seconds at cnt %s" % (batches, time.time() - st, train_cnt)) return train_cnt
def train_acn(train_cnt):
    """Run one epoch of joint ACN encoder / prior / PixelCNN training.

    Iterates once over the global ``train_loader``, optimizing the sum of the
    KL regularizer and reconstruction loss, storing each batch's codes in the
    prior's code table, and checkpointing after every batch.  Uses
    module-level globals: ``encoder_model``, ``prior_model``,
    ``pcnn_decoder``, ``opt``, ``train_loader``, ``DEVICE``,
    ``handle_checkpointing``.

    :param train_cnt: number of samples already consumed (resume point)
    :return: updated sample count
    """
    running_loss = 0
    start_cnt = train_cnt
    epoch_start = time.time()
    for step, (data, label, data_index) in enumerate(train_loader):
        # all three components train jointly
        encoder_model.train()
        prior_model.train()
        pcnn_decoder.train()
        batch_start = time.time()  # unused; kept from the original timing scaffolding
        data = data.to(DEVICE)
        opt.zero_grad()
        z, u_q = encoder_model(data)
        #yhat_batch = encoder_model.decode(u_q, s_q, data)
        # condition the pixelcnn on the predicted codes
        # TODO - this isn't how you sample pcnn
        yhat_batch = torch.sigmoid(pcnn_decoder(x=data, float_condition=z))
        np_uq = u_q.detach().cpu().numpy()
        # drop into an interactive debugger if the codes blow up
        if np.isinf(np_uq).sum() or np.isnan(np_uq).sum():
            print('train bad')
            embed()
        # record this batch's codes in the prior's code table
        prior_model.codes[data_index] = np_uq
        #prior_model.fit_knn(prior_model.codes)
        # prior outputs a gaussian mixture (gmp)
        mixtures, u_ps, s_ps = prior_model(u_q)
        kl_reg, rec_loss = acn_gmp_loss_function(yhat_batch, data, u_q, mixtures, u_ps, s_ps)
        loss = kl_reg + rec_loss
        if not step % 10:
            print(train_cnt, step, kl_reg.item(), rec_loss.item())
        loss.backward()
        all_params = (list(encoder_model.parameters())
                      + list(prior_model.parameters())
                      + list(pcnn_decoder.parameters()))
        clip_grad_value_(all_params, 10)
        running_loss += loss.item()
        opt.step()
        # add batch size because it hasn't been added to train_cnt yet
        avg_train_loss = running_loss / float((train_cnt + data.shape[0]) - start_cnt)
        print('batch', train_cnt, avg_train_loss)
        handle_checkpointing(train_cnt, avg_train_loss)
        train_cnt += len(data)
    print(running_loss)
    print("finished epoch after %s seconds at cnt %s" % (time.time() - epoch_start, train_cnt))
    return train_cnt
def train(train_input_images, train_actions, encoder, decoder, criterion, optimizer, model_paras, config):
    """Train encoder/decoder over consecutive sequence windows of the input.

    The image stream is cut into windows of ``config.seq_len`` steps; for each
    window the action targets are prefixed with ``config.init_y`` so the
    decoder predicts a one-step-shifted sequence.  Accuracy is computed and
    printed for the final window only.

    :param train_input_images: image tensor, windows taken along dim 1
    :param train_actions: dict of action name -> value tensor
    :param encoder: sequence encoder module
    :param decoder: sequence decoder module
    :param criterion: per-action classification loss
    :param optimizer: optimizer over ``model_paras``
    :param model_paras: parameters to clip and optimize
    :param config: provides seq_len, init_y, decoder_batch_size, y_keys_info, clip_value
    :return: mean training loss per window
    """
    seq_len = config.seq_len
    n_windows = int(train_input_images.shape[1] / seq_len)
    loss_total = 0
    for w in range(n_windows):
        lo, hi = w * seq_len, (w + 1) * seq_len
        image_seq = train_input_images[:, lo:hi].cuda()
        # prepend the initial y so targets are shifted one step from inputs
        action_seq = {
            name: torch.cat((config.init_y[name], vals[:, lo:hi]), dim=1).cuda()
            for name, vals in train_actions.items()
        }
        encoder.zero_grad()
        decoder.zero_grad()
        enc_out = encoder.forward(image_seq, config.decoder_batch_size, seq_len)
        y = decoder.forward(enc_out, action_seq, config.decoder_batch_size, seq_len)
        # sum per-action losses for a single joint update
        window_loss = sum(criterion(y[name], action_seq[name][:, 1:])
                          for name in config.y_keys_info.keys())
        window_loss.backward()
        loss_total += window_loss.item()
        clip_grad_value_(model_paras, config.clip_value)
        optimizer.step()
    # accuracy is computed on the LAST window only (y/action_seq from the
    # final loop iteration)
    accuracy = {}
    for name in config.y_keys_info.keys():
        _, y_pred = y[name].max(dim=1)
        accuracy[name] = (y_pred == action_seq[name][:, 1:]).sum().item() / (config.decoder_batch_size * seq_len)
    print(accuracy)
    return loss_total / n_windows
def train_vqvae(train_cnt): st = time.time() #for batch_idx, (data, label, data_index) in enumerate(train_loader): batches = 0 while train_cnt < args.num_examples_to_train: vqvae_model.train() opt.zero_grad() states, actions, rewards, next_states, terminals, is_new_epoch, relative_indexes = train_data_loader.get_unique_minibatch( ) # because we have 4 layers in vqvae, need to be divisible by 2, 4 times states = (2 * reshape_input(states[:, -1:]) - 1).to(DEVICE) x_d, z_e_x, z_q_x, latents = vqvae_model(states) z_q_x.retain_grad() loss_1 = discretized_mix_logistic_loss(x_d, states, nr_mix=args.nr_logistic_mix, DEVICE=DEVICE) loss_2 = F.mse_loss(z_q_x, z_e_x.detach()) loss_3 = args.beta * F.mse_loss(z_e_x, z_q_x.detach()) loss_1.backward(retain_graph=True) vqvae_model.embedding.zero_grad() z_e_x.backward(z_q_x.grad, retain_graph=True) loss_2.backward(retain_graph=True) loss_3.backward() parameters = list(vqvae_model.parameters()) clip_grad_value_(parameters, 10) opt.step() bs = float(x_d.shape[0]) handle_checkpointing(train_cnt, loss_1.item() / bs, loss_2.item() / bs, loss_3.item() / bs) train_cnt += len(states) batches += 1 if not batches % 1000: print("finished %s epoch after %s seconds at cnt %s" % (batches, time.time() - st, train_cnt)) return train_cnt
def pt_latent_learn():  # latent_states, actions, rewards, latent_next_states, terminal_flags, masks):
    """One joint update of the VQ-VAE representation and the ensemble DQN.

    Samples a minibatch from the global ``replay_memory``, runs the VQ-VAE
    training pass (which performs its own backward), computes masked
    smooth-L1 Q-learning losses for each ensemble head on the encoder
    representations, backpropagates, clips, and steps the shared optimizer.
    Uses module-level globals: ``opt``, ``replay_memory``, ``info``,
    ``train_vqvae``, ``vqvae_model``, ``policy_net``, ``target_net``.

    :return: mean per-head DQN loss plus the summed VQ-VAE training losses
    """
    opt.zero_grad()
    batch = replay_memory.get_minibatch(info['BATCH_SIZE'])
    # BUGFIX: np.int was deprecated in NumPy 1.20 and removed in 1.24;
    # use the explicit fixed-width np.int64 instead (same values — the arrays
    # are converted to float tensors immediately afterwards anyway).
    terminal_flags = torch.Tensor(batch[4].astype(np.int64)).to(info['DEVICE'])
    masks = torch.FloatTensor(batch[5].astype(np.int64)).to(info['DEVICE'])
    # do losses based on next state estimate
    # backward() for state_representation happens in train_vqvae - still need to
    # step
    avg_train_losses, next_z_q_x, next_z_e_x, pt_data = train_vqvae(vqvae_model, info, batch)
    states, actions, rewards, next_states = pt_data
    rewards = rewards.float()
    # get representation for current state
    x_d, z_e_x, z_q_x, latents, pred_actions, pred_rewards = vqvae_model(states)
    # min history to learn is 200,000 frames in dqn - 50000 steps
    losses = [0.0 for _ in range(info['N_ENSEMBLE'])]
    q_policy_vals = policy_net(z_e_x, None)
    next_q_target_vals = target_net(next_z_e_x, None)
    next_q_policy_vals = policy_net(next_z_e_x, None)
    z_e_x.retain_grad()
    next_z_e_x.retain_grad()
    # trying to work on e rather than q to look at grads
    # z_q_x dows not give .grads to vqvae_model
    #q_policy_vals = policy_net(z_q_x, None)
    #next_q_target_vals = target_net(next_z_q_x, None)
    #next_q_policy_vals = policy_net(next_z_q_x, None)
    #z_q_x.retain_grad()
    #next_z_q_x.retain_grad()
    cnt_losses = []
    for k in range(info['N_ENSEMBLE']):
        #TODO finish masking
        # only heads with at least one active sample contribute to the loss
        total_used = torch.sum(masks[:, k])
        if total_used > 0.0:
            next_q_vals = next_q_target_vals[k].data
            if info['DOUBLE_DQN']:
                # double DQN: select actions with the policy net, evaluate
                # them with the target net
                next_actions = next_q_policy_vals[k].data.max(1, True)[1]
                next_qs = next_q_vals.gather(1, next_actions).squeeze(1)
            else:
                next_qs = next_q_vals.max(1)[0]  # max returns a pair
            preds = q_policy_vals[k].gather(1, actions[:, None]).squeeze(1)
            # one-step TD target; terminal states contribute reward only
            targets = rewards + info['GAMMA'] * next_qs * (1 - terminal_flags)
            l1loss = F.smooth_l1_loss(preds, targets, reduction='mean')
            full_loss = masks[:, k] * l1loss
            loss = torch.sum(full_loss / total_used)
            cnt_losses.append(loss)
            losses[k] = loss.cpu().detach().item()
    loss = sum(cnt_losses) / info['N_ENSEMBLE']
    loss.backward()
    vqvae_parameters = list(vqvae_model.parameters())
    clip_grad_value_(vqvae_parameters, 5)
    for param in policy_net.core_net.parameters():
        if param.grad is not None:
            # divide grads in core: the shared trunk accumulated gradients
            # from every ensemble head
            param.grad.data *= 1.0 / float(info['N_ENSEMBLE'])
    nn.utils.clip_grad_norm_(policy_net.parameters(), info['CLIP_GRAD'])
    opt.step()
    return np.mean(losses) + np.sum(avg_train_losses)
def train_forward(train_cnt, conv_forward_model, opt, latents, actions, rewards, next_latents):
    """Train the forward model to predict next latents from (latent, action, reward).

    Loops over the SAME provided arrays until ``train_cnt`` reaches
    ``args.num_examples_to_train``, scoring predicted next-latent indexes with
    NLL plus auxiliary action/reward prediction losses obtained by decoding
    the predicted latents through the frozen ``vqvae_model``.  Uses
    module-level globals: ``args``, ``DEVICE``, ``num_actions``,
    ``num_rewards``, ``num_k``, ``vq_largs``, ``vqvae_model``,
    ``actions_weight``, ``rewards_weight``, ``handle_checkpointing``.

    NOTE(review): on the second loop iteration the parameter arrays have
    already been converted to tensors and permuted to (N, H, W, C), yet they
    are re-wrapped and re-indexed as if fresh inputs — looks like this loop
    was only ever exercised for datasets consumed in a single pass; confirm
    before relying on multi-iteration behavior.

    :param train_cnt: number of samples already consumed (resume point)
    :return: updated sample count (a float, since the batch size is floated)
    """
    st = time.time()
    batches = 0
    while train_cnt < args.num_examples_to_train:
        conv_forward_model.train()
        opt.zero_grad()
        # we want the forward model to produce a next latent in which the vq
        # model can determine the action we gave it.
        latents = torch.FloatTensor(latents[:, None]).to(DEVICE)
        # next_latents is long because of prediction
        next_latents = torch.LongTensor(next_latents[:, None]).to(DEVICE)
        rewards = torch.LongTensor(rewards).to(DEVICE)
        actions = torch.LongTensor(actions).to(DEVICE)
        # put actions into channel for conditioning (one-hot over channels)
        bs, _, h, w = latents.shape
        channel_actions = torch.zeros((bs, num_actions, h, w)).to(DEVICE)
        for a in range(num_actions):
            channel_actions[actions == a, a] = 1
        channel_rewards = torch.zeros((bs, num_rewards, h, w)).to(DEVICE)
        for r in range(num_rewards):
            channel_rewards[rewards == r, r] = 1
        # combine input together
        state_input = torch.cat((channel_actions, channel_rewards, latents, next_latents), dim=1)
        bs = float(latents.shape[0])
        pred_next_latents = conv_forward_model(state_input)
        # pred_next_latents shape is bs, c, h, w - need to permute shape for
        # don't optimize vqmodel - just optimize against its understanding of
        # the latent data
        with torch.no_grad():
            N, _, H, W = latents.shape
            C = vq_largs.num_z
            pred_next_latent_inds = torch.argmax(pred_next_latents, dim=1)
            x_tilde, pred_z_q_x, pred_actions, pred_rewards = vqvae_model.decode_clusters(pred_next_latent_inds, N, H, W, C)
        # should be able to predict the input action that got us to this
        # timestep
        loss_act = args.alpha_act * F.nll_loss(pred_actions, actions, weight=actions_weight)
        loss_reward = args.alpha_rew * F.nll_loss(pred_rewards, rewards, weight=rewards_weight)
        # determine which values of latents change over this time step
        ts_change = (torch.abs(latents.long() - next_latents) > 1).view(-1)
        # flatten to (N*H*W, num_k) for the NLL losses below
        pred_next_latents = pred_next_latents.permute(0, 2, 3, 1).contiguous()
        next_latents = next_latents.permute(0, 2, 3, 1).contiguous()
        latents = latents.permute(0, 2, 3, 1).contiguous()
        loss_rec = args.alpha_rec * F.nll_loss(pred_next_latents.view(-1, num_k), next_latents.view(-1), reduction='mean')
        # we want to penalize when these are wrong in particular
        loss_diff_rec = args.alpha_rec * F.nll_loss(pred_next_latents.view(-1, num_k)[ts_change == 1], next_latents.view(-1)[ts_change == 1], reduction='mean')
        loss = loss_reward + loss_act + loss_rec + loss_diff_rec
        loss.backward(retain_graph=True)
        parameters = list(conv_forward_model.parameters())
        clip_grad_value_(parameters, 10)
        opt.step()
        # per-sample losses, in order: reward, action, reconstruction, changed-cell reconstruction
        loss_list = [loss_reward.item() / bs, loss_act.item() / bs, loss_rec.item() / bs, loss_diff_rec.item() / bs]
        # skip checkpointing during warmup batches
        if batches > 1000:
            handle_checkpointing(train_cnt, loss_list)
        train_cnt += bs
        batches += 1
        if not batches % 1000:
            print("finished %s epoch after %s seconds at cnt %s" % (batches, time.time() - st, train_cnt))
    return train_cnt
def _train_step(self) -> None:
    """Run a training step over the training data.

    Performs ``iter_per_step`` optimizer updates, each accumulating gradients
    over ``batches_per_iter`` batches.  Loss routing depends on
    ``self.top_k``: the string sentinel 'None' selects ``_compute_batch``
    (which also updates ``metrics_with_states``), anything else selects
    ``_compute_kl_loss``.  Handles optional gradient clipping, per-iteration
    LR scheduling, KL-annealing beta updates, and interval-based extra-metric
    logging.

    NOTE(review): the bare print('LOSS')/print(loss)/'TOTAL TIME' statements
    look like leftover debugging output — consider removing or demoting to a
    logger.  Also note self.top_k is compared against the STRING 'None', not
    the None singleton — confirm that is intentional.
    """
    self.model.train()
    # one fresh (metric, state) pair per configured training metric
    metrics_with_states: List[Tuple] = [(metric, {}) for metric in self.training_metrics]
    self._last_train_log_step = 0
    log_prefix = f"{self.tb_log_prefix} " if self.tb_log_prefix else ""
    log_prefix += 'Training'
    with torch.enable_grad():
        for i in range(self.iter_per_step):
            # print('MEMORY ALLOCATED %f' % float(torch.cuda.memory_allocated() / BYTES_IN_GB))
            # print('MEMORY CACHED %f' % float(torch.cuda.memory_cached() / BYTES_IN_GB))
            t = time.time()
            # Zero the gradients and clear the accumulated loss
            self.optimizer.zero_grad()
            accumulated_loss = 0.0
            for _ in range(self.batches_per_iter):
                # Get next batch, recreating the iterator at epoch end
                try:
                    batch = next(self._train_iterator)
                except StopIteration:
                    self._create_train_iterator()
                    batch = next(self._train_iterator)
                # with multiple devices, batch placement is presumably handled
                # elsewhere (e.g. DataParallel scatter) — TODO confirm
                if self.device_count == 1:
                    batch = self._batch_to_device(batch)
                # Compute loss
                if self.top_k == 'None':
                    _, _, loss = self._compute_batch(batch, metrics_with_states)
                else:
                    loss = self._compute_kl_loss(batch)
                print('LOSS')
                print(loss)
                accumulated_loss += loss.item() / self.batches_per_iter
                loss.backward()
                # try:
                #     loss.backward()
                # except RuntimeError:
                #     torch.cuda.empty_cache()
                #     print('EMPTIED CACHE FOR LOSS')
                #     continue
            # Log loss
            global_step = (self.iter_per_step * self._step) + i
            self.beta = self.get_beta(global_step)
            # Clip gradients if necessary
            if self.max_grad_norm:
                clip_grad_norm_(self.model.parameters(), self.max_grad_norm)
            if self.max_grad_abs_val:
                clip_grad_value_(self.model.parameters(), self.max_grad_abs_val)
            log(f'{log_prefix}/Loss', accumulated_loss, global_step)
            # under DataParallel the real model lives at .module
            if self.device_count > 1:
                log(f'{log_prefix}/Gradient_Norm', self.model.module.gradient_norm, global_step)
                log(f'{log_prefix}/Parameter_Norm', self.model.module.parameter_norm, global_step)
            else:
                log(f'{log_prefix}/Gradient_Norm', self.model.gradient_norm, global_step)
                log(f'{log_prefix}/Parameter_Norm', self.model.parameter_norm, global_step)
            log(f'{log_prefix}/Beta', self.beta, global_step)
            # Optimize
            self.optimizer.step()
            # Update iter scheduler
            if self.iter_scheduler is not None:
                lr = self.optimizer.param_groups[0]['lr']  # type: ignore
                log(f'{log_prefix}/LR', lr, global_step)
                self.iter_scheduler.step()  # type: ignore
            # Zero the gradients when exiting a train step
            self.optimizer.zero_grad()
            # logging train metrics
            if self.extra_training_metrics_log_interval > self._last_train_log_step:
                self._log_metrics(log_prefix, metrics_with_states, global_step)
                self._last_train_log_step = i
            print('TOTAL TIME: %f' % (time.time() - t))
        if self._last_train_log_step != i:
            # log again at end of step, if not logged at the end of
            # step before
            self._log_metrics(log_prefix, metrics_with_states, global_step)
def train_acn(train_cnt):
    """Jointly train the global ACN encoder, GMP prior and PixelCNN decoder.

    Loops until the sample counter reaches ``args.num_examples_to_train``,
    pulling (data, label) batches from the global ``data_loader``.  The
    PixelCNN reconstructs ``label`` conditioned on the encoder's code ``z``;
    the loss is KL (code vs. mixture prior) plus reconstruction.  Uses
    module-level globals: ``encoder_model``, ``prior_model``,
    ``pcnn_decoder``, ``opt``, ``data_loader``, ``args``, ``DEVICE``,
    ``handle_checkpointing``.

    :param train_cnt: number of samples already consumed (resume point)
    :return: updated sample count
    """
    train_kl_loss = 0.0
    train_rec_loss = 0.0
    init_cnt = train_cnt
    st = time.time()
    #for batch_idx, (data, label, data_index) in enumerate(train_loader):
    batches = 0
    while train_cnt < args.num_examples_to_train:
        encoder_model.train()
        prior_model.train()
        pcnn_decoder.train()
        opt.zero_grad()
        lst = time.time()
        data, label, data_index, is_new_epoch = data_loader.next_unique_batch()
        if is_new_epoch:
            # prior_model.new_epoch()
            print(train_cnt, 'train, is new epoch', prior_model.available_indexes.shape)
        data = data.to(DEVICE)
        label = label.to(DEVICE)
        # inf happens sometime after 0001,680,896
        z, u_q = encoder_model(data)
        np_uq = u_q.detach().cpu().numpy()
        # drop into an interactive debugger if the codes blow up
        if np.isinf(np_uq).sum() or np.isnan(np_uq).sum():
            print('train bad')
            embed()
        # add the predicted codes to the input
        yhat_batch = torch.sigmoid(pcnn_decoder(x=label, float_condition=z))
        #print(train_cnt)
        # store this batch's codes in the prior's code table at their indexes
        prior_model.codes[data_index - args.number_condition] = u_q.detach().cpu().numpy()
        #mixtures, u_ps, s_ps = prior_model(u_q)
        #loss = acn_gmp_loss_function(yhat_batch, label, u_q, mixtures, u_ps, s_ps)
        # NOTE(review): np_uq recomputed although u_q is unchanged since the
        # check above — this second check looks redundant
        np_uq = u_q.detach().cpu().numpy()
        if np.isinf(np_uq).sum() or np.isnan(np_uq).sum():
            print('train bad')
            embed()
        #loss.backward()
        # parameters = list(encoder_model.parameters()) + list(prior_model.parameters()) + list(pcnn_decoder.parameters())
        # clip_grad_value_(parameters, 10)
        # train_loss+= loss.item()
        mix, u_ps, s_ps = prior_model(u_q)
        #kl_loss, rec_loss = acn_loss_function(yhat_batch, data, u_q, u_ps, s_ps)
        kl_loss, rec_loss = acn_gmp_loss_function(yhat_batch, label, u_q, mix, u_ps, s_ps)
        loss = kl_loss + rec_loss
        loss.backward()
        parameters = list(encoder_model.parameters()) + list(prior_model.parameters()) + list(pcnn_decoder.parameters())
        clip_grad_value_(parameters, 10)
        train_kl_loss += kl_loss.item()
        train_rec_loss += rec_loss.item()
        opt.step()
        # add batch size because it hasn't been added to train cnt yet
        avg_train_kl_loss = train_kl_loss / float((train_cnt + data.shape[0]) - init_cnt)
        avg_train_rec_loss = train_rec_loss / float((train_cnt + data.shape[0]) - init_cnt)
        handle_checkpointing(train_cnt, avg_train_kl_loss, avg_train_rec_loss)
        train_cnt += len(data)
        # add batch size because it hasn't been added to train cnt yet
        # avg_train_loss = train_loss/float((train_cnt+data.shape[0])-init_cnt)
        batches += 1
        if not batches % 1000:
            print("finished %s epoch after %s seconds at cnt %s" % (batches, time.time() - st, train_cnt))
    return train_cnt
def train_epoch(self, model: nn.Module, train_loader: DataLoader,
                val_clean_loader: DataLoader, val_triggered_loader: DataLoader,
                epoch_num: int, use_amp: bool = False):
    """
    Runs one epoch of training on the specified model

    :param model: the model to train for one epoch
    :param train_loader: a DataLoader object pointing to the training dataset
    :param val_clean_loader: a DataLoader object pointing to the validation dataset that is clean
    :param val_triggered_loader: a DataLoader object pointing to the validation dataset that is triggered
    :param epoch_num: the epoch number that is being trained
    :param use_amp: if True use automated mixed precision for FP16 training.
    :return: a list of statistics for batches where statistics were computed
    """
    pid = os.getpid()
    train_dataset_len = len(train_loader.dataset)
    loop = tqdm(train_loader, disable=self.optimizer_cfg.reporting_cfg.disable_progress_bar)
    # GradScaler is only needed for mixed-precision loss scaling
    scaler = None
    if use_amp:
        scaler = torch.cuda.amp.GradScaler()
    train_n_correct, train_n_total = None, None
    sum_batchmean_train_loss = 0
    running_train_acc = 0
    num_batches = len(train_loader)
    model.train()
    for batch_idx, (x, y_truth) in enumerate(loop):
        x = x.to(self.device)
        y_truth = y_truth.to(self.device)
        # put network into training mode & zero out previous gradient computations
        self.optimizer.zero_grad()
        # get predictions based on input & weights learned so far
        if use_amp:
            with torch.cuda.amp.autocast():
                y_hat = model(x)
                # compute metrics
                batch_train_loss = self._eval_loss_function(y_hat, y_truth)
        else:
            y_hat = model(x)
            # compute metrics
            batch_train_loss = self._eval_loss_function(y_hat, y_truth)
        sum_batchmean_train_loss += batch_train_loss.item()
        running_train_acc, train_n_total, train_n_correct = _running_eval_acc(
            y_hat, y_truth,
            n_total=train_n_total,
            n_correct=train_n_correct,
            soft_to_hard_fn=self.soft_to_hard_fn,
            soft_to_hard_fn_kwargs=self.soft_to_hard_fn_kwargs)
        # if np.isnan(sum_batchmean_train_loss) or np.isnan(running_train_acc):
        #     _save_nandata(x, y_hat, y_truth, batch_train_loss, sum_batchmean_train_loss, running_train_acc,
        #                   train_n_total, train_n_correct, model)
        # compute gradient
        if use_amp:
            # Scales loss. Calls backward() on scaled loss to create scaled gradients.
            # Backward passes under autocast are not recommended.
            # Backward ops run in the same dtype autocast chose for corresponding forward ops.
            scaler.scale(batch_train_loss).backward()
        else:
            # NOTE(review): this NaN dump only runs on the non-AMP path —
            # the AMP path relies on the scaler skipping inf/NaN steps
            if np.isnan(sum_batchmean_train_loss) or np.isnan(running_train_acc):
                _save_nandata(x, y_hat, y_truth, batch_train_loss, sum_batchmean_train_loss, running_train_acc,
                              train_n_total, train_n_correct, model)
            batch_train_loss.backward()
        # perform gradient clipping if configured
        if self.optimizer_cfg.training_cfg.clip_grad:
            if use_amp:
                # Unscales the gradients of optimizer's assigned params in-place
                # (clipping must operate on true-scale gradients)
                scaler.unscale_(self.optimizer)
            if self.optimizer_cfg.training_cfg.clip_type == 'norm':
                # clip_grad_norm_ modifies gradients in place
                # see: https://pytorch.org/docs/stable/_modules/torch/nn/utils/clip_grad.html
                torch_clip_grad.clip_grad_norm_(model.parameters(), self.optimizer_cfg.training_cfg.clip_val,
                                                **self.optimizer_cfg.training_cfg.clip_kwargs)
            elif self.optimizer_cfg.training_cfg.clip_type == 'val':
                # clip_grad_val_ modifies gradients in place
                # see: https://pytorch.org/docs/stable/_modules/torch/nn/utils/clip_grad.html
                torch_clip_grad.clip_grad_value_(model.parameters(), self.optimizer_cfg.training_cfg.clip_val)
            else:
                msg = "Unknown clipping type for gradient clipping!"
                logger.error(msg)
                raise ValueError(msg)
        if use_amp:
            # scaler.step() first unscales the gradients of the optimizer's assigned params.
            # If these gradients do not contain infs or NaNs, optimizer.step() is then called,
            # otherwise, optimizer.step() is skipped.
            scaler.step(self.optimizer)
            # Updates the scale for next iteration.
            scaler.update()
        else:
            self.optimizer.step()
        loop.set_description('Epoch {}/{}'.format(epoch_num + 1, self.num_epochs))
        loop.set_postfix(avg_train_loss=batch_train_loss.item())
        # report batch statistics to tensorflow
        if self.tb_writer:
            try:
                batch_num = int(epoch_num * num_batches + batch_idx)
                self.tb_writer.add_scalar(self.optimizer_cfg.reporting_cfg.experiment_name + '-train_loss',
                                          batch_train_loss.item(), global_step=batch_num)
                self.tb_writer.add_scalar(self.optimizer_cfg.reporting_cfg.experiment_name + '-running_train_acc',
                                          running_train_acc, global_step=batch_num)
            except:  # TODO: catch specific expcetions
                pass
        if batch_idx % self.num_batches_per_logmsg == 0:
            logger.info('{}\tTrain Epoch: {} [{}/{} ({:.0f}%)]\tTrainLoss: {:.6f}\tTrainAcc: {:.6f}'.format(
                pid, epoch_num, batch_idx * len(x), train_dataset_len,
                100. * batch_idx / num_batches, batch_train_loss.item(), running_train_acc))
    train_stats = EpochTrainStatistics(running_train_acc, sum_batchmean_train_loss / float(num_batches))
    # if we have validation data, we compute on the validation dataset
    num_val_batches_clean = len(val_clean_loader)
    if num_val_batches_clean > 0:
        logger.info('Running Validation on Clean Data')
        running_val_clean_acc, _, _, val_clean_loss = \
            _eval_acc(val_clean_loader, model, self.device,
                      self.soft_to_hard_fn, self.soft_to_hard_fn_kwargs, self._eval_loss_function)
    else:
        logger.info("No dataset computed for validation on clean dataset!")
        running_val_clean_acc = None
        val_clean_loss = None
    num_val_batches_triggered = len(val_triggered_loader)
    if num_val_batches_triggered > 0:
        logger.info('Running Validation on Triggered Data')
        running_val_triggered_acc, _, _, val_triggered_loss = \
            _eval_acc(val_triggered_loader, model, self.device,
                      self.soft_to_hard_fn, self.soft_to_hard_fn_kwargs, self._eval_loss_function)
    else:
        logger.info("No dataset computed for validation on triggered dataset!")
        running_val_triggered_acc = None
        val_triggered_loss = None
    validation_stats = EpochValidationStatistics(running_val_clean_acc, val_clean_loss,
                                                 running_val_triggered_acc, val_triggered_loss)
    if num_val_batches_clean > 0:
        logger.info('{}\tTrain Epoch: {} \tCleanValLoss: {:.6f}\tCleanValAcc: {:.6f}'.format(
            pid, epoch_num, val_clean_loss, running_val_clean_acc))
    if num_val_batches_triggered > 0:
        logger.info('{}\tTrain Epoch: {} \tTriggeredValLoss: {:.6f}\tTriggeredValAcc: {:.6f}'.format(
            pid, epoch_num, val_triggered_loss, running_val_triggered_acc))
    # report validation statistics to tensorboard at the end-of-epoch step
    if self.tb_writer:
        try:
            batch_num = int((epoch_num + 1) * num_batches)
            if num_val_batches_clean > 0:
                self.tb_writer.add_scalar(self.optimizer_cfg.reporting_cfg.experiment_name + '-clean-val-loss',
                                          val_clean_loss, global_step=batch_num)
                self.tb_writer.add_scalar(self.optimizer_cfg.reporting_cfg.experiment_name + '-clean-val_acc',
                                          running_val_clean_acc, global_step=batch_num)
            if num_val_batches_triggered > 0:
                self.tb_writer.add_scalar(self.optimizer_cfg.reporting_cfg.experiment_name + '-triggered-val-loss',
                                          val_triggered_loss, global_step=batch_num)
                self.tb_writer.add_scalar(self.optimizer_cfg.reporting_cfg.experiment_name + '-triggered-val_acc',
                                          running_val_triggered_acc, global_step=batch_num)
        except:
            pass
    # update the lr-scheduler if necessary
    if self.lr_scheduler is not None:
        if self.optimizer_cfg.training_cfg.lr_scheduler_call_arg is None:
            self.lr_scheduler.step()
        elif self.optimizer_cfg.training_cfg.lr_scheduler_call_arg.lower() == 'val_acc':
            val_acc = validation_stats.get_val_acc()
            if val_acc is not None:
                self.lr_scheduler.step(val_acc)
            else:
                msg = "val_clean_acc not defined b/c validation dataset is not defined! Ignoring LR step!"
                logger.warning(msg)
        elif self.optimizer_cfg.training_cfg.lr_scheduler_call_arg.lower() == 'val_loss':
            val_loss = validation_stats.get_val_loss()
            if val_loss is not None:
                self.lr_scheduler.step(val_loss)
            else:
                msg = "val_clean_loss not defined b/c validation dataset is not defined! Ignoring LR step!"
                logger.warning(msg)
        else:
            msg = "Unknown mode for calling lr_scheduler!"
            logger.error(msg)
            raise ValueError(msg)
    return train_stats, validation_stats
def train_epoch(self, model: nn.Module, train_loader: TextDataIterator,
                val_loader: TextDataIterator, epoch_num: int,
                progress_bar_disable: bool = False):
    """
    Runs one epoch of training on the specified model

    :param model: the model to train for one epoch
    :param train_loader: a DataLoader object pointing to the training dataset
    :param val_loader: a DataLoader object pointing to the validation dataset
    :param epoch_num: the epoch number that is being trained
    :param progress_bar_disable: if True, disables the progress bar
    :return: a (train_stats, validation_stats) tuple; validation_stats is None
        when val_loader is empty
    """
    pid = os.getpid()  # included in log lines to disambiguate parallel workers
    train_dataset_len = len(train_loader.dataset)
    loop = tqdm(train_loader, disable=progress_bar_disable)

    # running accumulators; None tells _running_eval_acc to initialize them
    train_n_correct, train_n_total = None, None
    val_n_correct, val_n_total = None, None
    sum_batchmean_train_loss = 0
    running_train_acc = 0
    num_batches = len(train_loader)

    # put network into training mode
    model.train()
    for batch_idx, batch in enumerate(loop):
        # zero out previous gradient computations
        self.optimizer.zero_grad()

        # get predictions based on input & weights learned so far
        if model.packed_padded_sequences:
            # model expects (text, lengths) so it can pack padded sequences
            text, text_lengths = batch.text
            x = (text, text_lengths)
            predictions = model(text, text_lengths).squeeze(1)
        else:
            x = batch.text
            predictions = model(batch.text).squeeze(1)

        # compute metrics
        batch_train_loss = self._eval_loss_function(predictions, batch.label)
        sum_batchmean_train_loss += batch_train_loss.item()
        running_train_acc, train_n_total, train_n_correct = \
            _running_eval_acc(predictions, batch.label,
                              n_total=train_n_total,
                              n_correct=train_n_correct,
                              soft_to_hard_fn=self.optimizer_cfg.training_cfg.soft_to_hard_fn,
                              soft_to_hard_fn_kwargs=self.optimizer_cfg.training_cfg.soft_to_hard_fn_kwargs)

        # dump diagnostic state to disk if loss/accuracy went NaN
        if np.isnan(sum_batchmean_train_loss) or np.isnan(running_train_acc):
            _save_nandata(x, predictions, batch.label, batch_train_loss,
                          sum_batchmean_train_loss, running_train_acc,
                          train_n_total, train_n_correct, model)

        # compute gradient
        batch_train_loss.backward()

        # perform gradient clipping if configured
        if self.optimizer_cfg.training_cfg.clip_grad:
            if self.optimizer_cfg.training_cfg.clip_type == 'norm':
                # clip_grad_norm_ modifies gradients in place
                # see: https://pytorch.org/docs/stable/_modules/torch/nn/utils/clip_grad.html
                torch_clip_grad.clip_grad_norm_(model.parameters(),
                                                self.optimizer_cfg.training_cfg.clip_val,
                                                **self.optimizer_cfg.training_cfg.clip_kwargs)
            elif self.optimizer_cfg.training_cfg.clip_type == 'val':
                # clip_grad_val_ modifies gradients in place
                # see: https://pytorch.org/docs/stable/_modules/torch/nn/utils/clip_grad.html
                torch_clip_grad.clip_grad_value_(model.parameters(),
                                                 self.optimizer_cfg.training_cfg.clip_val)
            else:
                msg = "Unknown clipping type for gradient clipping!"
                logger.error(msg)
                raise ValueError(msg)

        self.optimizer.step()
        loop.set_description('Epoch {}/{}'.format(epoch_num + 1, self.num_epochs))
        loop.set_postfix(avg_train_loss=batch_train_loss.item())

        # report batch statistics to tensorboard
        if self.tb_writer:
            try:
                batch_num = int(epoch_num * num_batches + batch_idx)
                self.tb_writer.add_scalar(self.optimizer_cfg.reporting_cfg.experiment_name +
                                          '-train_loss', batch_train_loss.item(),
                                          global_step=batch_num)
                self.tb_writer.add_scalar(self.optimizer_cfg.reporting_cfg.experiment_name +
                                          '-running_train_acc', running_train_acc,
                                          global_step=batch_num)
            except:
                # TODO: catch specific exceptions!  Bare except keeps tensorboard
                # failures from aborting training, but also hides real bugs.
                pass

        if batch_idx % self.num_batches_per_logmsg == 0:
            logger.info('{}\tTrain Epoch: {} [{}/{} ({:.0f}%)]\tTrainLoss: {:.6f}\tTrainAcc: {:.6f}'.format(
                pid, epoch_num, batch_idx * len(batch), train_dataset_len,
                100. * batch_idx / num_batches, batch_train_loss.item(), running_train_acc))

    # epoch-level train stats: final running accuracy + mean of per-batch losses
    train_stats = EpochTrainStatistics(running_train_acc,
                                       sum_batchmean_train_loss / float(num_batches))

    # if we have validation data, we compute on the validation dataset
    validation_stats = None
    num_val_batches = len(val_loader)
    if num_val_batches > 0:
        logger.info('Running validation')
        val_acc, _, _, val_loss = TorchTextOptimizer._eval_acc(
            val_loader, model, device=self.device,
            soft_to_hard_fn=self.optimizer_cfg.training_cfg.soft_to_hard_fn,
            soft_to_hard_fn_kwargs=self.optimizer_cfg.training_cfg.soft_to_hard_fn_kwargs,
            loss_fn=self._eval_loss_function)
        # last two slots (triggered acc/loss) unused by this optimizer
        validation_stats = EpochValidationStatistics(val_acc, val_loss, None, None)
        logger.info('{}\tTrain Epoch: {} \tValLoss: {:.6f}\tValAcc: {:.6f}'.format(
            pid, epoch_num, val_loss, val_acc))

        if self.tb_writer:
            try:
                batch_num = int((epoch_num + 1) * num_batches)
                self.tb_writer.add_scalar(self.optimizer_cfg.reporting_cfg.experiment_name +
                                          '-validation_loss', val_loss,
                                          global_step=batch_num)
                self.tb_writer.add_scalar(self.optimizer_cfg.reporting_cfg.experiment_name +
                                          '-validation_acc', val_acc,
                                          global_step=batch_num)
            except:
                # TODO: catch specific exceptions!
                pass

    # update the lr-scheduler if necessary
    if self.lr_scheduler is not None:
        if self.optimizer_cfg.training_cfg.lr_scheduler_call_arg is None:
            self.lr_scheduler.step()
        elif self.optimizer_cfg.training_cfg.lr_scheduler_call_arg.lower() == 'val_acc':
            if num_val_batches > 0:  # this check ensures that this variable is defined
                self.lr_scheduler.step(val_acc)
            else:
                msg = "val_acc not defined b/c validation dataset is not defined! Ignoring LR step!"
                logger.warning(msg)
        elif self.optimizer_cfg.training_cfg.lr_scheduler_call_arg.lower() == 'val_loss':
            if num_val_batches > 0:
                self.lr_scheduler.step(val_loss)
            else:
                msg = "val_loss not defined b/c validation dataset is not defined! Ignoring LR step!"
                logger.warning(msg)
        else:
            msg = "Unknown mode for calling lr_scheduler!"
            logger.error(msg)
            raise ValueError(msg)
    return train_stats, validation_stats
def train():
    """Train the encoder/decoder seq2seq model with attention.

    Python 2 code (print statements, xrange).  Relies on module-level globals:
    BATCH_SIZE, cuda, device, and the helpers get_data / tokenize_text /
    pad_text / build_model / build_trainers / make_packpadded / untokenize /
    save_state.  Trains for 13 epochs, checkpointing after each epoch.
    """
    # decoder side: target text, capped at a 20k-word vocabulary
    input_text, input_rems_text = get_data(train=True)
    dec_idx_to_word, dec_word_to_idx, dec_tok_text, dec_bias = tokenize_text(
        input_text, lower_case=True, vsize=20000)
    dec_padded_text = pad_text(dec_tok_text)
    dec_vocab_size = len(dec_idx_to_word)
    # encoder side: source ("rems") text, unbounded vocabulary
    enc_idx_to_word, enc_word_to_idx, enc_tok_text, _ = tokenize_text(
        input_rems_text)
    enc_padded_text = pad_text(enc_tok_text)
    enc_vocab_size = len(enc_idx_to_word)
    dec_text_tensor = torch.tensor(dec_padded_text, requires_grad=False)
    if cuda:
        dec_text_tensor = dec_text_tensor.cuda(device=device)
    enc, dec = build_model(enc_vocab_size, dec_vocab_size, dec_bias=dec_bias)
    enc_optim, dec_optim, lossfunc = build_trainers(enc, dec)
    # Python-2 integer division; any remainder batch is dropped
    num_batches = enc_padded_text.shape[0] / BATCH_SIZE
    sm_loss = None  # exponentially-smoothed loss for display only
    enc.train()
    dec.train()
    for epoch in xrange(0, 13):
        print "Starting New Epoch: %d" % epoch
        # reshuffle source and target in lockstep each epoch
        order = np.arange(enc_padded_text.shape[0])
        np.random.shuffle(order)
        enc_padded_text = enc_padded_text[order]
        dec_text_tensor.data = dec_text_tensor.data[order]
        for i in xrange(num_batches):
            s = i * BATCH_SIZE
            e = (i + 1) * BATCH_SIZE
            _, enc_pp, dec_pp, enc_lengths = make_packpadded(
                s, e, enc_padded_text, dec_text_tensor)
            enc.zero_grad()
            dec.zero_grad()
            hid = enc.initHidden(BATCH_SIZE)
            out_enc, hid_enc = enc.forward(enc_pp, hid, enc_lengths)
            # merge the bidirectional encoder's two final hidden states into
            # a single layer for the decoder
            hid_enc = torch.cat([hid_enc[0, :, :], hid_enc[1, :, :]],
                                dim=1).unsqueeze(0)
            # teacher forcing: feed target[:-1], predict target[1:]
            out_dec, hid_dec, attn = dec.forward(dec_pp[:, :-1], hid_enc, out_enc)
            out_perm = out_dec.permute(0, 2, 1)
            # NOTE(review): no-op expression statement — dead code, safe to delete
            dec_text_tensor.shape
            loss = lossfunc(out_perm, dec_pp[:, 1:])
            if sm_loss is None:
                sm_loss = loss.data
            else:
                sm_loss = sm_loss * 0.95 + 0.05 * loss.data
            loss.backward()
            # clip by value directly on the optimizers' parameter lists
            clip_grad_value_(enc_optim.param_groups[0]['params'], 5.0)
            clip_grad_value_(dec_optim.param_groups[0]['params'], 5.0)
            enc_optim.step()
            dec_optim.step()
            #del loss
            if i % 100 == 0:
                # periodic progress dump: smoothed loss + sample generations
                print "Epoch: %.3f" % (i / float(num_batches) + epoch, ), "Loss:", sm_loss
                print "GEN:", untokenize(
                    torch.argmax(out_dec, dim=2)[0, :], dec_idx_to_word)
                #print "GEN:", untokenize(torch.argmax(out_dec,dim=2)[1,:], dec_idx_to_word)
                print "GT:", untokenize(dec_pp[0, :], dec_idx_to_word)
                print "IN:", untokenize(enc_pp[0, :], enc_idx_to_word)
                print torch.argmax(attn[0], dim=1)
                print "--------------"
        # checkpoint model, optimizers and vocabularies after each epoch
        save_state(enc, dec, enc_optim, dec_optim, dec_idx_to_word,
                   dec_word_to_idx, enc_idx_to_word, enc_word_to_idx, epoch)
def train_epoch(self, model: nn.Module, train_loader: TextDataIterator,
                val_clean_loader: TextDataIterator,
                val_triggered_loader: TextDataIterator, epoch_num: int,
                progress_bar_disable: bool = False, use_amp: bool = False):
    """
    Runs one epoch of training on the specified model

    :param model: the model to train for one epoch
    :param train_loader: a DataLoader object pointing to the training dataset
    :param val_clean_loader: a DataLoader over the clean validation dataset
    :param val_triggered_loader: a DataLoader over the triggered validation dataset
    :param epoch_num: the epoch number that is being trained
    :param progress_bar_disable: if True, disables the progress bar
    :param use_amp: if True, train under torch.cuda.amp mixed precision
    :return: a (train_stats, validation_stats) tuple
    """
    pid = os.getpid()  # included in log lines to disambiguate parallel workers
    train_dataset_len = len(train_loader.dataset)
    scaler = None
    if use_amp:
        # loss scaler for mixed-precision training
        scaler = torch.cuda.amp.GradScaler()

    train_n_correct, train_n_total = None, None
    val_n_correct, val_n_total = None, None
    sum_batchmean_train_loss = 0
    running_train_acc = 0
    num_batches = len(train_loader)

    # put network and embedding into training mode
    model.train()
    loop = tqdm(
        train_loader,
        disable=self.optimizer_cfg.reporting_cfg.disable_progress_bar)
    for batch_idx, (input_ids, attention_mask, labels, label_mask) in enumerate(loop):
        # zero out previous gradient computations
        self.optimizer.zero_grad()
        # batch = tuple(t.to(self.device) for t in batch)
        # input_ids, attention_mask, labels, label_mask, valid_ids, token_type_ids = batch
        input_ids = input_ids.to(self.device)
        attention_mask = attention_mask.to(self.device)
        labels = labels.to(self.device)
        label_mask = label_mask.to(self.device)

        # forward pass; the model returns (loss, predictions) directly
        if use_amp:
            with torch.cuda.amp.autocast():
                batch_train_loss, predictions = model(
                    input_ids, attention_mask=attention_mask, labels=labels)
        else:
            batch_train_loss, predictions = model(
                input_ids, attention_mask=attention_mask, labels=labels)

        sum_batchmean_train_loss += batch_train_loss.item()
        running_train_acc, train_n_total, train_n_correct = \
            _running_eval_acc(predictions, labels, label_mask,
                              n_total=train_n_total,
                              n_correct=train_n_correct,
                              soft_to_hard_fn=self.optimizer_cfg.training_cfg.soft_to_hard_fn,
                              soft_to_hard_fn_kwargs=self.optimizer_cfg.training_cfg.soft_to_hard_fn_kwargs)

        # backward pass
        if use_amp:
            # Scales loss.  Calls backward() on scaled loss to create scaled gradients.
            # Backward passes under autocast are not recommended.
            # Backward ops run in the same dtype autocast chose for corresponding forward ops.
            scaler.scale(batch_train_loss).backward()
        else:
            # NOTE(review): the NaN check only runs on the non-AMP path — confirm
            # whether that asymmetry is intended
            if np.isnan(sum_batchmean_train_loss) or np.isnan(running_train_acc):
                # TODO: Figure out how to pass the original text ... input ids is tokenized
                _save_nandata(input_ids, predictions, labels, batch_train_loss,
                              sum_batchmean_train_loss, running_train_acc,
                              train_n_total, train_n_correct, model)
            batch_train_loss.backward()

        # perform gradient clipping if configured
        if self.optimizer_cfg.training_cfg.clip_grad:
            if use_amp:
                # Unscales the gradients of optimizer's assigned params in-place
                scaler.unscale_(self.optimizer)
            if self.optimizer_cfg.training_cfg.clip_type == 'norm':
                # clip_grad_norm_ modifies gradients in place
                # see: https://pytorch.org/docs/stable/_modules/torch/nn/utils/clip_grad.html
                torch_clip_grad.clip_grad_norm_(
                    model.parameters(),
                    self.optimizer_cfg.training_cfg.clip_val,
                    **self.optimizer_cfg.training_cfg.clip_kwargs)
            elif self.optimizer_cfg.training_cfg.clip_type == 'val':
                # clip_grad_val_ modifies gradients in place
                # see: https://pytorch.org/docs/stable/_modules/torch/nn/utils/clip_grad.html
                torch_clip_grad.clip_grad_value_(
                    model.parameters(),
                    self.optimizer_cfg.training_cfg.clip_val)
            else:
                msg = "Unknown clipping type for gradient clipping!"
                logger.error(msg)
                raise ValueError(msg)

        if use_amp:
            # scaler.step() first unscales the gradients of the optimizer's assigned params.
            # If these gradients do not contain infs or NaNs, optimizer.step() is then called,
            # otherwise, optimizer.step() is skipped.
            scaler.step(self.optimizer)
            # Updates the scale for next iteration.
            scaler.update()
        else:
            self.optimizer.step()
        if self.lr_scheduler is not None:
            # per-batch LR schedule step
            self.lr_scheduler.step()

        # report batch statistics to tensorflow
        if self.tb_writer:
            try:
                batch_num = int(epoch_num * num_batches + batch_idx)
                self.tb_writer.add_scalar(
                    self.optimizer_cfg.reporting_cfg.experiment_name +
                    '-train_loss', batch_train_loss.item(),
                    global_step=batch_num)
                self.tb_writer.add_scalar(
                    self.optimizer_cfg.reporting_cfg.experiment_name +
                    '-running_train_acc', running_train_acc,
                    global_step=batch_num)
            except:
                # TODO: catch specific expcetions
                pass

        loop.set_description('Epoch {}/{}'.format(epoch_num + 1, self.num_epochs))
        loop.set_postfix(avg_train_loss=batch_train_loss.item())

        if batch_idx % self.num_batches_per_logmsg == 0:
            # TODO: Determine best way to get text (possibly from tokenizer if needed...)
            # acc_per_label = "Accuracy Per item: "
            # for k in train_n_total.keys():
            #     acc_per_label += "{}: {}, ".format(k, 0 if train_n_total[k] == 0 else float(train_n_correct[k]) / float(train_n_total[k]))
            logger.info(
                '{}\tTrain Epoch: {} [{}/{} ({:.0f}%)]\tTrainLoss: {:.6f}\tTrainAcc: {:.6f}'
                .format(
                    pid, epoch_num, batch_idx * input_ids.shape[0], train_dataset_len,
                    # pid, epoch_num, batch_idx * len(text), train_dataset_len,
                    100. * batch_idx / num_batches, batch_train_loss.item(),
                    running_train_acc))

    # epoch-level train stats: final running accuracy + mean of per-batch losses
    train_stats = EpochTrainStatistics(
        running_train_acc, sum_batchmean_train_loss / float(num_batches))

    clean_counts = None
    triggered_counts = None
    # if we have validation data, we compute on the validation dataset
    num_val_batches_clean = len(val_clean_loader)
    if num_val_batches_clean > 0:
        logger.info('Running Validation on Clean Data')
        running_val_clean_acc, clean_n_total, clean_n_correct, val_clean_loss, clean_counts = \
            self._eval_acc(val_clean_loader, model)
    else:
        logger.info("No dataset computed for validation on clean dataset!")
        running_val_clean_acc = None
        val_clean_loss = None

    num_val_batches_triggered = len(val_triggered_loader)
    if num_val_batches_triggered > 0:
        logger.info('Running Validation on Triggered Data')
        running_val_triggered_acc, triggered_n_total, triggered_n_correct, val_triggered_loss, triggered_counts = \
            self._eval_acc(val_triggered_loader, model)
    else:
        logger.info(
            "No dataset computed for validation on triggered dataset!")
        running_val_triggered_acc = None
        val_triggered_loss = None

    validation_stats = EpochValidationStatistics(
        running_val_clean_acc, val_clean_loss, running_val_triggered_acc,
        val_triggered_loss)

    if num_val_batches_clean > 0:
        # NOTE(review): iterates train_n_total.keys() but indexes clean_n_total —
        # assumes both dicts share the same label keys; verify
        acc_per_label = "{"
        for k in train_n_total.keys():
            acc_per_label += "{}: {}, ".format(
                k, 0 if clean_n_total[k] == 0 else
                float(clean_n_correct[k]) / float(clean_n_total[k]))
        acc_per_label += "}"
        logger.info(
            '{}\tTrain Epoch: {} \tCleanValLoss: {:.6f}\tCleanValAcc: {:.6f}\nCleanPerLabelAcc: {}\nclean_total: {}\nclean_correct: {}'
            .format(pid, epoch_num, val_clean_loss, running_val_clean_acc,
                    acc_per_label, clean_n_total, clean_n_correct))
    if num_val_batches_triggered > 0:
        acc_per_label = "{"
        for k in triggered_n_total.keys():
            acc_per_label += "{}: {}, ".format(
                k, 0 if triggered_n_total[k] == 0 else
                float(triggered_n_correct[k]) / float(triggered_n_total[k]))
        acc_per_label += "}"
        logger.info(
            '{}\tTrain Epoch: {} \tTriggeredValLoss: {:.6f}\tTriggeredValAcc: {:.6f}\tTriggeredPerLabelAcc: {}'
            .format(pid, epoch_num, val_triggered_loss,
                    running_val_triggered_acc, acc_per_label))

    if self.tb_writer:
        try:
            batch_num = int((epoch_num + 1) * num_batches)
            if num_val_batches_clean > 0:
                self.tb_writer.add_scalar(
                    self.optimizer_cfg.reporting_cfg.experiment_name +
                    '-clean-val-loss', val_clean_loss, global_step=batch_num)
                self.tb_writer.add_scalar(
                    self.optimizer_cfg.reporting_cfg.experiment_name +
                    '-clean-val_acc', running_val_clean_acc,
                    global_step=batch_num)
            if num_val_batches_triggered > 0:
                self.tb_writer.add_scalar(
                    self.optimizer_cfg.reporting_cfg.experiment_name +
                    '-triggered-val-loss', val_triggered_loss,
                    global_step=batch_num)
                self.tb_writer.add_scalar(
                    self.optimizer_cfg.reporting_cfg.experiment_name +
                    '-triggered-val_acc', running_val_triggered_acc,
                    global_step=batch_num)
        except:
            # swallow tensorboard failures so they can't abort training
            pass

    # Add epoch to stats
    self.ner_metrics.add_epoch_stats(epoch_num, clean_counts, triggered_counts)

    return train_stats, validation_stats
def train_gp(self, seed=0, n_iter=300, pred_interval=5, test=False,
             verbose=True, n_epoch=20):
    """Fit the GP by minibatch NLL minimization, timing with CUDA events.

    Every ``pred_interval`` iterations a record is appended to
    ``self.history`` holding the dataset-averaged NLL, the iteration index,
    the elapsed training time (prediction time subtracted), and — when
    ``test`` is truthy — the RMSE over the full test loader.

    :param seed: RNG seed passed to set_seed
    :param n_iter: number of outer optimization iterations
    :param pred_interval: record/evaluate every this many iterations
    :param test: if truthy, evaluate RMSE on self.test at each record point
    :param verbose: if True, print each record as it is produced
    :param n_epoch: cap on minibatches consumed per outer iteration
    :return: self.history
    """
    set_seed(seed)
    print('SEED=', seed)
    optimizer = opt.Adam(self.model.parameters())

    # CUDA events for wall-clock timing; prediction time is accumulated
    # separately so it can be subtracted from the reported training time.
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    pred_start = torch.cuda.Event(enable_timing=True)
    pred_end = torch.cuda.Event(enable_timing=True)
    pred_time = 0.0

    start.record()
    for step in range(n_iter):
        nll_total = 0.0
        batches_seen = 0
        for X, Y in self.data:
            # stop after n_epoch minibatches of this outer iteration
            if batches_seen > n_epoch:
                break
            batches_seen += 1
            self.model.train()
            optimizer.zero_grad()
            nll = self.gp.NLL(X, Y)
            # weight by batch size so the record averages per-sample
            nll_total = nll_total + nll * X.shape[0]
            nll.backward()
            clip_grad_value_(self.model.parameters(), 10)
            optimizer.step()
            torch.cuda.empty_cache()

        # only evaluate/record every pred_interval iterations
        if step % pred_interval != 0:
            continue

        record = {'nll': nll_total.item() / self.train['X'].shape[0],
                  'iter': step}
        if test:
            pred_start.record()
            targets = []
            preds = []
            n_test_batches = 0
            for Xt, Yt in self.test:
                n_test_batches += 1
                targets.append(Yt)
                preds.append(self.gp(Xt, self.train['X'], self.train['Y'],
                                     var=False))
            record['rmse'] = rmse(torch.cat(preds), torch.cat(targets)).item()
            pred_end.record()
            torch.cuda.synchronize()
            pred_time += pred_start.elapsed_time(pred_end)
        end.record()
        torch.cuda.synchronize()
        record['time'] = start.elapsed_time(end) - pred_time
        if verbose:
            print(record)
        self.history.append(record)
    return self.history
def train_acn(info, model_dict, data_buffers, phase='train'):
    """Train the ACN encoder + GMP prior, alternating train/valid phases.

    Loops until train_cnt reaches 10M samples.  Every MODEL_SAVE_EVERY
    training samples the phase flips to 'valid' for one pass, after which the
    models are checkpointed and training resumes.

    :param info: mutable experiment-state dict (counters, hyperparams, device)
    :param model_dict: dict holding encoder_model/prior_model/pcnn_decoder/opt
    :param data_buffers: dict of replay buffers keyed 'train'/'valid'
    :param phase: which buffer/mode to start in ('train' or 'valid')
    """
    encoder_model = model_dict['encoder_model']
    prior_model = model_dict['prior_model']
    pcnn_decoder = model_dict['pcnn_decoder']
    opt = model_dict['opt']
    # add one to the rewards so that they are all positive
    # use next_states because that is the t-1 action
    if len(info['model_train_cnts']):
        # resume from the last recorded sample count
        train_cnt = info['model_train_cnts'][-1]
    else:
        train_cnt = 0
    num_batches = data_buffers['train'].count // info['MODEL_BATCH_SIZE']
    while train_cnt < 10000000:
        # eval vs train mode per phase
        if phase == 'valid':
            encoder_model.eval()
            prior_model.eval()
            pcnn_decoder.eval()
        else:
            encoder_model.train()
            prior_model.train()
            pcnn_decoder.train()
        batch_num = 0
        data_buffers[phase].reset_unique()
        print('-------------new epoch %s------------------' % phase)
        print('num batches', num_batches)
        # NOTE(review): `phase` can flip mid-epoch below while this loop keeps
        # reading data_buffers[phase] — confirm that switching buffers
        # mid-iteration is intended
        while data_buffers[phase].unique_available:
            opt.zero_grad()
            batch = data_buffers[phase].get_unique_minibatch(
                info['MODEL_BATCH_SIZE'])
            relative_indices = batch[-1]
            states, actions, rewards, next_states = make_state(
                batch[:-1], info['DEVICE'], info['NORM_BY'])
            # target is the last frame of the next-state stack
            next_state = next_states[:, -1:]
            bs = states.shape[0]
            #states, actions, rewards, next_states, terminals, is_new_epoch, relative_indexes = train_data_loader.get_unique_minibatch()
            z, u_q = encoder_model(states)
            # add the predicted codes to the input
            #yhat_batch = torch.sigmoid(pcnn_decoder(x=next_state,
            #                                        class_condition=actions,
            #                                        float_condition=z))
            yhat_batch = encoder_model.decode(z)
            # cache the posterior means for this batch in the prior's code book
            prior_model.codes[relative_indices] = u_q.detach().cpu().numpy()
            mix, u_ps, s_ps = prior_model(u_q)
            # track losses
            kl_loss, rec_loss = acn_gmp_loss_function(yhat_batch, next_state,
                                                      u_q, mix, u_ps, s_ps)
            loss = kl_loss + rec_loss
            # aatch size because it hasn't been added to train cnt yet
            if not phase == 'valid':
                loss.backward()
                # pcnn_decoder params are excluded here (decode goes through
                # encoder_model.decode above)
                #parameters = list(encoder_model.parameters()) + list(prior_model.parameters()) + list(pcnn_decoder.parameters())
                parameters = list(encoder_model.parameters()) + list(
                    prior_model.parameters())
                clip_grad_value_(parameters, 10)
                opt.step()
                train_cnt += bs
            if not batch_num % info['MODEL_LOG_EVERY_BATCHES']:
                print(phase, train_cnt, batch_num, kl_loss.item(),
                      rec_loss.item())
                info = add_losses(info, train_cnt, phase, kl_loss.item(),
                                  rec_loss.item())
            batch_num += 1
            if (((train_cnt - info['model_last_save']) >= info['MODEL_SAVE_EVERY'])):
                info = add_losses(info, train_cnt, phase, kl_loss.item(),
                                  rec_loss.item())
                if phase == 'train':
                    # run as valid phase and get back to here
                    phase = 'valid'
                else:
                    model_dict = {
                        'encoder_model': encoder_model,
                        'prior_model': prior_model,
                        'pcnn_decoder': pcnn_decoder,
                        'opt': opt
                    }
                    info = save_model(info, model_dict)
                    phase = 'train'
    # final checkpoint once the 10M-sample budget is exhausted
    model_dict = {
        'encoder_model': encoder_model,
        'prior_model': prior_model,
        'pcnn_decoder': pcnn_decoder,
        'opt': opt
    }
    info = save_model(info, model_dict)
def train(latent_dim, datasource, num_actions, num_rewards, encoder, decoder,
          reward_predictor, discriminator, transition):
    """Train the latent world model (encoder/decoder/transition/reward nets).

    Rolls the transition model forward over trajectories sampled from
    `datasource`, accumulating reward-prediction, reconstruction, optional
    latent-overshooting, and optional counterfactual regularization losses,
    then takes one optimizer step per training iteration.

    Reads hyperparameters from the module-level `args` namespace and the
    globals ITERS_PER_VIDEO, CF_REGULARIZATION_RATE, CF_REGULARIZATION_LAMBDA,
    enable_cf_shuffle_loss, enable_control_bias_loss (presumed module-level —
    verify).  Requires CUDA.

    :param latent_dim: dimensionality of the latent state z
    :param datasource: provides get_trajectories(batch_size, horizon)
    :param num_actions: size of the discrete action space (one-hot width)
    :param num_rewards: declared reward dimensionality (unused in this body)
    :param encoder/decoder/reward_predictor/discriminator/transition: networks
    """
    batch_size = args.batch_size
    train_iters = args.train_iters
    td_lambda_coef = args.td_lambda
    td_steps = args.td_steps
    truncate_bptt = args.truncate_bptt
    enable_td = args.latent_td
    enable_latent_overshooting = args.latent_overshooting
    learning_rate = args.learning_rate
    min_prediction_horizon = args.horizon_min
    max_prediction_horizon = args.horizon_max
    # NOTE(review): `finetune_reward` is captured here but the step-gating
    # below reads args.finetune_reward directly — same value, inconsistent use
    finetune_reward = args.finetune_reward
    REWARD_COEF = args.reward_coef
    ACTIVATION_L1_COEF = args.activation_l1_coef
    TRANSITION_L1_COEF = args.transition_l1_coef
    counterfactual_horizon = args.counterfactual_horizon
    start_iter = args.start_iter

    # one Adam optimizer per sub-network
    opt_enc = torch.optim.Adam(encoder.parameters(), lr=learning_rate)
    opt_dec = torch.optim.Adam(decoder.parameters(), lr=learning_rate)
    opt_trans = torch.optim.Adam(transition.parameters(), lr=learning_rate)
    # NOTE(review): opt_disc is created but never stepped in this function
    opt_disc = torch.optim.Adam(discriminator.parameters(), lr=learning_rate)
    opt_pred = torch.optim.Adam(reward_predictor.parameters(), lr=learning_rate)
    ts = TimeSeries('Training Model', train_iters, tensorboard=True)

    for train_iter in range(start_iter, train_iters + 1):
        # periodic evaluation + checkpointing
        if train_iter % ITERS_PER_VIDEO == 0:
            print('Evaluating networks...')
            evaluate(datasource, encoder, decoder, transition, discriminator,
                     reward_predictor, latent_dim, train_iter=train_iter)
            print('Saving networks to filesystem...')
            torch.save(transition.state_dict(), 'model-transition.pth')
            torch.save(encoder.state_dict(), 'model-encoder.pth')
            torch.save(decoder.state_dict(), 'model-decoder.pth')
            torch.save(discriminator.state_dict(), 'model-discriminator.pth')
            torch.save(reward_predictor.state_dict(), 'model-reward_predictor.pth')

        # curriculum: horizon grows linearly from horizon_min to horizon_max,
        # and theta (training progress in [0,1]) scales several loss terms
        theta = train_iter / train_iters
        pred_delta = max_prediction_horizon - min_prediction_horizon
        prediction_horizon = min_prediction_horizon + int(pred_delta * theta)

        train_mode(
            [encoder, decoder, transition, discriminator, reward_predictor])

        # Train encoder/transition/decoder
        opt_enc.zero_grad()
        opt_dec.zero_grad()
        opt_trans.zero_grad()
        opt_pred.zero_grad()

        states, rewards, dones, actions = datasource.get_trajectories(
            batch_size, prediction_horizon)
        states = torch.Tensor(states).cuda()
        rewards = torch.Tensor(rewards).cuda()
        dones = torch.Tensor(dones.astype(int)).cuda()

        # Encode the initial state (using the first 3 frames)
        # Given t, t+1, t+2, encoder outputs the state at time t+1
        z = encoder(states[:, 0:3])
        z_orig = z.clone()

        # But wait, here's the problem: We can't use the encoded initial state as
        # an initial state of the dynamical system and expect the system to work
        # The dynamical system needs to have something like the Echo State Property
        # So the dynamical parts need to run long enough to reach a steady state

        # Keep track of "done" states to stop a training trajectory at the final time step
        active_mask = torch.ones(batch_size).cuda()
        loss = 0
        lo_loss = 0
        lo_z_set = {}
        # Given the state encoded at t=2, predict state at t=3, t=4, ...
        for t in range(1, prediction_horizon - 1):
            # zero out trajectories that have already terminated
            active_mask = active_mask * (1 - dones[:, t])

            # Predict reward
            expected_reward = reward_predictor(z)
            actual_reward = rewards[:, t]
            reward_difference = torch.mean(
                torch.mean(
                    (expected_reward - actual_reward)**2, dim=1) * active_mask)
            ts.collect('Rd Loss t={}'.format(t), reward_difference)
            loss += theta * REWARD_COEF * reward_difference

            # Normalize by height * width
            # Reconstruction loss
            target_pixels = states[:, t]
            predicted = torch.sigmoid(decoder(z))
            rec_loss_batch = decoder_pixel_loss(target_pixels, predicted)
            if truncate_bptt and t > 1:
                # cut the BPTT graph so gradients only flow one step back
                z.detach_()
            rec_loss = torch.mean(rec_loss_batch * active_mask)
            ts.collect('Reconstruction t={}'.format(t), rec_loss)
            loss += rec_loss

            # Apply activation L1 loss
            #l1_values = z.abs().mean(-1).mean(-1).mean(-1)
            #l1_loss = ACTIVATION_L1_COEF * torch.mean(l1_values * active_mask)
            #ts.collect('L1 t={}'.format(t), l1_loss)
            #loss += theta * l1_loss

            # Predict transition to the next state
            onehot_a = torch.eye(num_actions)[actions[:, t]].cuda()
            new_z = transition(z, onehot_a)

            # Apply transition L1 loss
            #t_l1_values = ((new_z - z).abs().mean(-1).mean(-1).mean(-1))
            #t_l1_loss = TRANSITION_L1_COEF * torch.mean(t_l1_values * active_mask)
            #ts.collect('T-L1 t={}'.format(t), t_l1_loss)
            #loss += theta * t_l1_loss

            z = new_z

            if enable_latent_overshooting:
                # Latent Overshooting, Hafner et al.
                lo_z_set[t] = encoder(states[:, t - 1:t + 2])
                # For each previous t_left, step forward to t
                for t_left in range(1, t):
                    a = torch.eye(num_actions)[actions[:, t - 1]].cuda()
                    lo_z_set[t_left] = transition(lo_z_set[t_left], a)
                for t_a in range(2, t - 1):
                    # It's like TD but only N:1 for all N
                    predicted_activations = lo_z_set[t_a]
                    target_activations = lo_z_set[t].detach()
                    lo_loss_batch = latent_state_loss(target_activations,
                                                      predicted_activations)
                    lo_loss += td_lambda_coef * torch.mean(
                        lo_loss_batch * active_mask)

        if enable_latent_overshooting:
            ts.collect('LO total', lo_loss)
            loss += theta * lo_loss

        # COUNTERFACTUAL DISENTANGLEMENT REGULARIZATION
        # Suppose that our representation is ideally, perfectly disentangled
        # Then the PGM has no edges, the causal graph is just nodes with no relationships
        # In this case, it should be true that intervening on any one factor has no effect on the others
        # One fun way of intervening is swapping factors, a la FactorVAE
        # If we intervene on some dimensions, the other dimensions should be unaffected
        if enable_cf_shuffle_loss and train_iter % CF_REGULARIZATION_RATE == 0:
            # Counterfactual scenario A: our memory of what really happened
            z_cf_a = z.clone()
            # Counterfactual scenario B: a bizzaro world where two dimensions are swapped
            # NOTE(review): z_orig is aliased (not cloned) and mutated below —
            # confirm that in-place edits to z_orig are acceptable here
            z_cf_b = z_orig
            unswapped_factor_map = torch.ones((batch_size, latent_dim)).cuda()
            for i in range(batch_size):
                idx_a = np.random.randint(latent_dim)
                idx_b = np.random.randint(latent_dim)
                unswapped_factor_map[i, idx_a] = 0
                unswapped_factor_map[i, idx_b] = 0
                z_cf_b[i, idx_a], z_cf_b[i, idx_b] = z_cf_b[i, idx_b], z_cf_b[i, idx_a]
            # But we take the same actions
            for t in range(1, counterfactual_horizon):
                onehot_a = torch.eye(num_actions)[actions[:, t]].cuda()
                z_cf_b = transition(z_cf_b, onehot_a)
            # Every UNSWAPPED dimension should be as similar as possible to its bizzaro-world equivalent
            cf_loss = torch.abs(z_cf_a - z_cf_b).mean(-1).mean(
                -1) * unswapped_factor_map
            cf_loss = CF_REGULARIZATION_LAMBDA * torch.mean(
                cf_loss.mean(-1) * active_mask)
            loss += cf_loss
            ts.collect('CF Disentanglement Loss', cf_loss)

        # COUNTERFACTUAL ACTION-CONTROL REGULARIZATION
        # In difficult POMDPs, deep neural networks can suffer from learned helplessness
        # They learn, rationally, that their actions have no causal influence on the reward
        # This is undesirable: the learned model should assume that outcomes are controllable
        if enable_control_bias_loss and train_iter % CF_REGULARIZATION_RATE == 0:
            # Counterfactual scenario A: our memory of what really happened
            z_cf_a = z.clone()
            # Counterfactual scenario B: our imagination of what might have happened
            z_cf_b = z_orig
            # Instead of the regular actions, apply an alternate policy
            cf_actions = actions.copy()
            np.random.shuffle(cf_actions)
            for t in range(1, counterfactual_horizon):
                onehot_a = torch.eye(num_actions)[cf_actions[:, t]].cuda()
                z_cf_b = transition(z_cf_b, onehot_a)
            eps = .001  # for numerical stability
            cf_loss = -torch.log(
                torch.abs(z_cf_a - z_cf_b).mean(-1).mean(-1).mean(-1) + eps)
            cf_loss = CF_REGULARIZATION_LAMBDA * torch.mean(
                cf_loss * active_mask)
            loss += cf_loss
            ts.collect('CF Control Bias Loss', cf_loss)

        loss.backward()

        # local import kept as in the original; hoisting to module scope would
        # be cosmetic only
        from torch.nn.utils.clip_grad import clip_grad_value_
        clip_grad_value_(encoder.parameters(), 0.1)
        clip_grad_value_(transition.parameters(), 0.1)
        clip_grad_value_(decoder.parameters(), 0.1)

        opt_pred.step()
        if not args.finetune_reward:
            # when finetuning the reward head, freeze everything else
            opt_enc.step()
            opt_dec.step()
            opt_trans.step()
        ts.print_every(10)
    print(ts)
    print('Finished')
def train_gp(self, seed=0, burn_iter=100, n_iter=100, lmbda=1.0,
             pred_interval=5, test=None, verbose=True, n_epoch=50):
    """Jointly train the VAE and GP by minimizing ``-dsELBO + lmbda * NLL``.

    Each minibatch is round-tripped through the VAE (encode then decode) and
    the reconstruction is used as the GP's training inputs.  Every
    ``pred_interval`` iterations a record with dataset-averaged nll/elbo/loss
    is appended to ``self.history``; when ``test`` is given, the GP data is
    rebuilt from ``self.original`` and test RMSE is added to the record.

    Fix: ``X.to(self.device)`` / ``Y.to(self.device)`` previously discarded
    their results (``Tensor.to`` is not in-place), so batches were never
    moved to the configured device; the results are now assigned back.

    :param seed: RNG seed passed to set_seed
    :param burn_iter: unused in this body (kept for interface compatibility)
    :param n_iter: number of outer optimization iterations
    :param lmbda: weight of the GP NLL term in the joint loss
    :param pred_interval: record/evaluate every this many iterations
    :param test: optional dict with 'X'/'Y' test tensors; enables RMSE eval
    :param verbose: if True, print each record as it is produced
    :param n_epoch: cap on minibatches consumed per outer iteration
    :return: self.history
    """
    set_seed(seed)
    print('SEED=', seed)
    optimizer = opt.Adam(self.model.parameters())
    for i in range(n_iter):
        batch_nll = 0.0
        batch_elbo = 0.0
        batch_loss = 0.0
        #batch_validation = 0.0
        epoch = 0
        for (X, Y) in self.data:
            # stop after n_epoch minibatches of this outer iteration
            if epoch > n_epoch:
                break
            epoch += 1
            # BUGFIX: Tensor.to returns a new tensor; assign it back so the
            # batch actually lands on self.device
            X = X.to(self.device)
            Y = Y.to(self.device)
            '''
            Z = self.vae(X)
            self.gp.data['X'] = Z
            self.gp.data['Y'] = Y
            self.gp.n_data = Z.shape[0]
            self.gp.n_dim = Z.shape[1]
            '''
            # encode then decode: GP trains on the VAE reconstruction of X
            Xr = self.vae(self.vae(X), encode=False)
            self.gp.data['X'] = Xr
            #print(self.gp.data['X'].shape)
            self.gp.data['Y'] = Y
            self.gp.n_data = Xr.shape[0]
            self.gp.n_dim = Xr.shape[1]
            self.model.train()
            optimizer.zero_grad()
            delbo = self.vae.dsELBO(X, alpha=1.0, beta=1.0, gamma=1.0,
                                    verbose=False)
            nll = self.gp.NLL_batch(Xr, Y)
            #validation = rmse(self.gp(Xr, grad=True)[0], Y)
            # maximize the ELBO while minimizing the (weighted) GP NLL
            loss = - delbo + lmbda * nll
            #batch_validation += validation * X.shape[0]
            # weight running sums by batch size for per-sample averages below
            batch_nll += nll * X.shape[0]
            batch_elbo += delbo * X.shape[0]
            batch_loss += loss * X.shape[0]
            loss.backward()
            clip_grad_value_(self.model.parameters(), 20)
            optimizer.step()
            torch.cuda.empty_cache()
        if i % pred_interval == 0:
            record = {'nll': batch_nll.item() / self.train['X'].shape[0],
                      'elbo': batch_elbo.item() / self.train['X'].shape[0],
                      'loss': batch_loss.item() / self.train['X'].shape[0],
                      'iter': i}
            if test is not None:
                # rebuild the GP's training data from the full original set
                # (reconstructed through the VAE, gradients disabled)
                #self.gp.data['X'] = self.vae(self.original['X'], grad=False)
                self.gp.data['X'] = self.vae(
                    self.vae(self.original['X'], grad=False),
                    encode=False, grad=False)
                #print(self.gp.data['X'].shape)
                self.gp.data['Y'] = self.original['Y']
                self.gp.n_data = self.gp.data['X'].shape[0]
                self.gp.n_dim = self.gp.data['X'].shape[1]
                Xtr = self.vae(self.vae(test['X'], grad=False),
                               encode=False, grad=False)
                err = rmse(self.gp(Xtr)[0], test['Y'])
                record['rmse'] = err.item()
            if verbose:
                print(record)
            self.history.append(record)
    '''
    for i in range(n_iter):
        batch_nll = 0.0
        epoch = 0
        for (X, Y) in self.data:
            if epoch > n_epoch:
                break
            epoch += 1
            X.to(self.device)
            Y.to(self.device)
            Xr = self.vae(self.vae(X), encode=False)
            self.gp.data['X'] = Z
            self.gp.data['Y'] = Y
            self.gp.n_data = Z.shape[0]
            self.gp.n_dim = Z.shape[1]
            self.gp.data['X'] = Xr
            self.gp.data['Y'] = Y
            self.gp.n_data = Xr.shape[0]
            self.gp.n_dim = Xr.shape[1]
            self.gp_model.train()
            optimizer.zero_grad()
            nll = self.gp.NLL_batch(Xr, Y)
            batch_nll += nll * X.shape[0]
            nll.backward()
            optimizer.step()
            torch.cuda.empty_cache()
        if i % pred_interval == 0:
            record = {'nll': batch_nll.item() / self.train['X'].shape[0], 'iter': i}
            if test is not None:
                self.gp.data['X'] = self.vae(self.original['X'])
                self.gp.data['Y'] = self.original['Y']
                self.gp.n_data = self.gp.data['X'].shape[0]
                self.gp.n_dim = self.gp.data['X'].shape[1]
                err = rmse(self.gp(self.vae(test['X']))[0], test['Y'])
                record['rmse'] = err.item()
            if verbose:
                print(record)
            self.history.append(record)
    '''
    return self.history
def _train_step(self) -> None:
    """Run a training step over the training data."""
    self.model.train()
    # Pair each training metric with a fresh, empty accumulation state.
    metric_states: List[Tuple] = [(m, {}) for m in self.training_metrics]
    self._last_train_log_step = 0
    prefix = f"{self.tb_log_prefix} " if self.tb_log_prefix else ""
    prefix += 'Training'
    with torch.enable_grad():
        for it in range(self.iter_per_step):
            # Fresh gradients and loss total for this iteration.
            self.optimizer.zero_grad()
            loss_total = 0.0
            for _ in range(self.batches_per_iter):
                try:
                    batch = next(self._train_iterator)
                except StopIteration:
                    # Iterator exhausted: rebuild it and fetch again.
                    self._create_train_iterator()
                    batch = next(self._train_iterator)
                batch = self._batch_to_device(batch)
                _, _, loss = self._compute_batch(batch, metric_states)
                # Average the reported loss over gradient-accumulation batches.
                loss_total += loss.item() / self.batches_per_iter
                loss.backward()
            global_step = (self.iter_per_step * self._step) + it
            # Optional gradient clipping (by norm and/or by absolute value).
            if self.max_grad_norm:
                clip_grad_norm_(self.model.parameters(), self.max_grad_norm)
            if self.max_grad_abs_val:
                clip_grad_value_(self.model.parameters(), self.max_grad_abs_val)
            log(f'{prefix}/Loss', loss_total, global_step)
            log(f'{prefix}/Gradient_Norm', self.model.gradient_norm, global_step)
            log(f'{prefix}/Parameter_Norm', self.model.parameter_norm, global_step)
            self.optimizer.step()
            # Per-iteration LR schedule, when configured.
            if self.iter_scheduler is not None:
                learning_rate = self.iter_scheduler.get_lr()[0]  # type: ignore
                log(f'{prefix}/LR', learning_rate, global_step)
                self.iter_scheduler.step()  # type: ignore
            # Leave no stale gradients behind after the update.
            self.optimizer.zero_grad()
            # Periodic extra-metrics logging.
            if self.extra_training_metrics_log_interval > self._last_train_log_step:
                self._log_metrics(prefix, metric_states, global_step)
                self._last_train_log_step = it
    if self._last_train_log_step != it:
        # log again at end of step, if not logged at the end of step before
        self._log_metrics(prefix, metric_states, global_step)
def clip_grad_value(params, clip_value=10):
    """Clamp the gradients of all trainable params to [-clip_value, clip_value]."""
    # slow down training
    trainable = (p for p in params if p.requires_grad)
    clip_grad.clip_grad_value_(trainable, clip_value=clip_value)
def _train_step(self) -> None:
    """Run a training step over the training data.

    Each of the ``iter_per_step`` iterations accumulates gradients over
    ``batches_per_iter`` batches, optionally clips them, and — once
    ``global_step`` passes ``iter_before_pruning`` — also optimizes a
    Lagrangian sparsity penalty via two extra optimizers (alphas, lambdas).
    """
    self.model.train()
    tb_prefix = f"{self.tb_log_prefix} " if self.tb_log_prefix else ""
    with torch.enable_grad():
        for i in range(self.iter_per_step):
            # Zero the gradients and clear the accumulated loss on all
            # three optimizers (model weights, alphas, lambdas).
            self.optimizer.zero_grad()
            self.optimizer_alphas.zero_grad()
            self.optimizer_lambdas.zero_grad()
            accumulated_loss = 0.0
            for _ in range(self.batches_per_iter):
                # Get next batch
                batch = next(self._train_iterator)
                batch = self._batch_to_device(batch)
                # Compute loss, pre-scaled so the accumulated gradient
                # equals the mean over batches_per_iter batches.
                loss = self._compute_loss(batch) / self.batches_per_iter
                accumulated_loss += loss.item()
                loss.backward()
            # Log loss
            global_step = (self.iter_per_step * self._step) + i
            # Clip gradients if necessary
            if self.max_grad_norm:
                clip_grad_norm_(self.model.parameters(), self.max_grad_norm)
            if self.max_grad_abs_val:
                clip_grad_value_(self.model.parameters(), self.max_grad_abs_val)
            log(f'{tb_prefix}Training/Loss', accumulated_loss, global_step)
            log(f'{tb_prefix}Training/Gradient_Norm', self.model.gradient_norm, global_step)
            log(f'{tb_prefix}Training/Parameter_Norm', self.model.parameter_norm, global_step)
            if global_step >= self.iter_before_pruning:
                pruning_step = global_step - self.iter_before_pruning
                # Expected sparsity = fraction of prunable parameters the
                # hard-concrete gates are expected to remove.
                num_parameters = get_num_params(self.hard_concrete_modules, train=True)
                expected_sparsity = 1. - (num_parameters / self.max_prunable)
                # Linearly warm the sparsity target up to self.target_sparsity.
                if self.target_sparsity_warmup > 0:
                    factor = min(1.0, pruning_step / self.target_sparsity_warmup)
                    target_sparsity = self.target_sparsity * factor
                else:
                    target_sparsity = self.target_sparsity
                # Lagrangian penalty: linear + quadratic terms on the gap
                # between target and expected sparsity (lambdas are learned).
                lagrangian_loss = self.lambda_1 * (target_sparsity - expected_sparsity)
                lagrangian_loss += self.lambda_2 * (target_sparsity - expected_sparsity) ** 2
                lagrangian_loss.backward()
                log("Expected_sparsity", float(expected_sparsity), global_step)
                log("Lagrangian_loss", lagrangian_loss.item(), global_step)
                log("Target_sparsity", target_sparsity, global_step)
                log("lambda_1", self.lambda_1.item(), global_step)
                log("lambda_2", self.lambda_2.item(), global_step)
                # NOTE(review): lambdas and alphas step every iteration once
                # pruning starts, each with its own scheduler keyed to
                # pruning_step (not global_step) — presumably intentional.
                self.optimizer_lambdas.step()
                self.lambdas_scheduler.step(pruning_step)
                self.optimizer_alphas.step()
                self.alphas_scheduler.step(pruning_step)
            # Optimize
            self.optimizer.step()
            self.lr_scheduler.step(global_step)
    # Zero the gradients when exiting a train step
    self.optimizer.zero_grad()
    self.optimizer_lambdas.zero_grad()
    self.optimizer_alphas.zero_grad()
def train_gp(self, seed=0, n_iter=100, n_epoch=50, lmbda=1.0, pred_interval=5, test=False, verbose=True):
    """Jointly train the VAE and GP, timing training and prediction with CUDA events.

    Minimizes ``-0.01 * dsELBO + lmbda * NLL`` over ``n_iter`` outer
    iterations (at most ``n_epoch`` minibatches each).  Every
    ``pred_interval`` iterations, records averaged losses, elapsed wall
    time (prediction time excluded), and — when ``test`` is truthy —
    RMSE over the batched test loader ``self.test``.

    Args:
        seed: RNG seed passed to ``set_seed``.
        n_iter: number of outer optimization iterations.
        n_epoch: max minibatches consumed per outer iteration.
        lmbda: weight of the GP negative log-likelihood in the joint loss.
        pred_interval: record/evaluate every this many iterations.
        test: when truthy, evaluate RMSE on ``self.test`` batches.
        verbose: print each record as it is produced.

    Returns:
        self.history, the list of per-interval records.
    """
    set_seed(seed)
    print('SEED=', seed)
    optimizer = opt.Adam(self.model.parameters())
    # CUDA events for accurate GPU timing; pred_time is subtracted so the
    # reported 'time' covers training only.
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    pred_start = torch.cuda.Event(enable_timing=True)
    pred_end = torch.cuda.Event(enable_timing=True)
    pred_time = 0.0
    start.record()
    for i in range(n_iter):
        batch_nll = 0.0
        batch_elbo = 0.0
        batch_loss = 0.0
        epoch = 0
        for (X, Y) in self.data:
            if epoch > n_epoch:
                break
            epoch += 1
            # BUG FIX: Tensor.to() is not in-place; the original discarded
            # the returned tensor, so data never moved to self.device.
            X = X.to(self.device)
            Y = Y.to(self.device)
            # Encode then decode: Xa is the VAE reconstruction of X.
            Xa = self.vae(self.vae(X), encode=False)
            self.model.train()
            optimizer.zero_grad()
            delbo = self.vae.dsELBO(X, alpha=1.0, beta=1.0, gamma=1.0, verbose=False)
            nll = self.gp.NLL(Xa, Y)
            # ELBO is down-weighted by 0.01 relative to the GP NLL term.
            loss = -0.01 * delbo + lmbda * nll
            # Sample-weighted sums; divided by dataset size when recorded.
            batch_nll += nll * X.shape[0]
            batch_elbo += delbo * X.shape[0]
            batch_loss += loss * X.shape[0]
            loss.backward()
            clip_grad_value_(self.model.parameters(), 10)
            optimizer.step()
            torch.cuda.empty_cache()
        if i % pred_interval == 0:
            torch.cuda.synchronize()
            record = {
                'iter': i,
                'nll': batch_nll.item() / self.train['X'].shape[0],
                'elbo': batch_elbo.item() / self.train['X'].shape[0],
                'loss': batch_loss.item() / self.train['X'].shape[0],
            }
            if test:
                pred_start.record()
                Ypred_full = []
                Yt_full = []
                # GP conditioning inputs: reconstructions of the full
                # original training set (no gradients tracked).
                Xr = self.vae(self.vae(self.original['X'], grad=False), encode=False, grad=False)
                for (Xt, Yt) in self.test:
                    Xtr = self.vae(self.vae(Xt, grad=False), encode=False, grad=False)
                    Ypred = self.gp(Xtr, Xr, self.original['Y'], var=False)
                    Ypred_full.append(Ypred)
                    Yt_full.append(Yt)
                    del Xtr
                    torch.cuda.empty_cache()
                Ypred_full = torch.cat(Ypred_full)
                Yt_full = torch.cat(Yt_full)
                record['rmse'] = rmse(Ypred_full, Yt_full).item()
                pred_end.record()
                torch.cuda.synchronize()
                pred_time += pred_start.elapsed_time(pred_end)
                del Xr
                torch.cuda.empty_cache()
            end.record()
            torch.cuda.synchronize()
            # Report training time only (prediction time excluded).
            record['time'] = start.elapsed_time(end) - pred_time
            if verbose:
                print(record)
            self.history.append(record)
    return self.history
def clip_parameters(model_dict, clip_val=10):
    """Clip gradients in place for every entry whose key contains 'model'.

    Entries whose key does not contain 'model' (e.g. optimizers) are
    skipped.  Returns the same dict for chaining.
    """
    for key in model_dict:
        if 'model' not in key:
            continue
        clip_grad_value_(model_dict[key].parameters(), clip_val)
    return model_dict
def train_vqvae(train_cnt, vqvae_model, opt, info, train_data_loader, valid_data_loader):
    """Train a VQ-VAE with action/reward heads on frame-difference minibatches.

    Loops until ``train_cnt`` reaches ``info['VQ_NUM_EXAMPLES_TO_TRAIN']``,
    periodically validating and checkpointing.  The backward-pass ordering
    (multiple ``retain_graph=True`` calls, ``embedding.zero_grad()``, then
    routing ``z_q_x.grad`` into ``z_e_x``) implements the VQ-VAE
    straight-through estimator and must not be reordered.

    Returns the updated ``train_cnt``.
    """
    st = time.time()
    #for batch_idx, (data, label, data_index) in enumerate(train_loader):
    batches = 0
    while train_cnt < info['VQ_NUM_EXAMPLES_TO_TRAIN']:
        vqvae_model.train()
        opt.zero_grad()
        states, actions, rewards, values, pred_states, terminals, is_new_epoch, relative_indexes = train_data_loader.get_framediff_minibatch(
        )
        # because we have 4 layers in vqvae, need to be divisible by 2, 4 times
        # Inputs scaled from [0, 1] to [-1, 1] — assumes loader returns
        # values in [0, 1]; TODO confirm.
        states = (2 * reshape_input(torch.FloatTensor(states)) - 1).to(
            info['DEVICE'])
        # Reconstruction target: channel 0 of pred_states, same scaling.
        rec = (
            2 * reshape_input(torch.FloatTensor(pred_states)[:, 0][:, None])
            - 1).to(info['DEVICE'])
        actions = torch.LongTensor(actions).to(info['DEVICE'])
        rewards = torch.LongTensor(rewards).to(info['DEVICE'])
        # dont normalize diff
        diff = (reshape_input(
            torch.FloatTensor(pred_states)[:, 1][:, None])).to(info['DEVICE'])
        x_d, z_e_x, z_q_x, latents, pred_actions, pred_rewards = vqvae_model(
            states)
        # Keep z_q_x.grad so it can be copied onto z_e_x below
        # (straight-through estimator).
        z_q_x.retain_grad()
        # Decoder output is split: first nmix channels reconstruct `rec`,
        # the rest model `diff`.
        rec_est = x_d[:, :info['nmix']]
        diff_est = x_d[:, info['nmix']:]
        loss_rec = info['ALPHA_REC'] * discretized_mix_logistic_loss(
            rec_est, rec, nr_mix=info['NR_LOGISTIC_MIX'],
            DEVICE=info['DEVICE'])
        loss_diff = discretized_mix_logistic_loss(
            diff_est, diff, nr_mix=info['NR_LOGISTIC_MIX'],
            DEVICE=info['DEVICE'])
        # Class-weighted NLL for the auxiliary action/reward predictors.
        loss_act = info['ALPHA_ACT'] * F.nll_loss(
            pred_actions, actions, weight=info['actions_weight'])
        loss_rewards = info['ALPHA_REW'] * F.nll_loss(
            pred_rewards, rewards, weight=info['rewards_weight'])
        # Codebook loss: move codebook vectors toward encoder outputs.
        loss_2 = F.mse_loss(z_q_x, z_e_x.detach())
        loss_act.backward(retain_graph=True)
        loss_rec.backward(retain_graph=True)
        loss_diff.backward(retain_graph=True)
        # Commitment loss: keep encoder outputs near codebook vectors.
        loss_3 = info['BETA'] * F.mse_loss(z_e_x, z_q_x.detach())
        # Clear embedding grads accumulated above, then pass the decoder's
        # gradient w.r.t. z_q_x straight through to the encoder output.
        vqvae_model.embedding.zero_grad()
        z_e_x.backward(z_q_x.grad, retain_graph=True)
        loss_2.backward(retain_graph=True)
        loss_3.backward()
        parameters = list(vqvae_model.parameters())
        clip_grad_value_(parameters, 10)
        opt.step()
        bs = float(x_d.shape[0])
        # Per-sample average of each loss term, in fixed order:
        # rewards, actions, reconstruction, diff, codebook, commitment.
        avg_train_losses = [
            loss_rewards.item() / bs,
            loss_act.item() / bs,
            loss_rec.item() / bs,
            loss_diff.item() / bs,
            loss_2.item() / bs,
            loss_3.item() / bs
        ]
        if batches > info['VQ_MIN_BATCHES_BEFORE_SAVE']:
            if ((train_cnt - info['vq_last_save']) >= info['VQ_SAVE_EVERY']):
                info['vq_last_save'] = train_cnt
                info['vq_save_times'].append(time.time())
                avg_valid_losses = valid_vqvae(train_cnt, vqvae_model, info,
                                               valid_data_loader)
                handle_plot_ckpt(train_cnt, info, avg_train_losses,
                                 avg_valid_losses)
                filename = info[
                    'vq_model_base_filepath'] + "_%010dex.pt" % train_cnt
                print("SAVING MODEL:%s" % filename)
                # NOTE(review): vq_last_save was just set to train_cnt, so
                # "cnt since last saved" always prints 0 here.
                print("Saving model at cnt:%s cnt since last saved:%s" %
                      (train_cnt, train_cnt - info['vq_last_save']))
                state = {
                    'vqvae_state_dict': vqvae_model.state_dict(),
                    'vq_optimizer': opt.state_dict(),
                    'vq_embedding': vqvae_model.embedding,
                    'vq_info': info,
                }
                save_checkpoint(state, filename=filename)
        train_cnt += len(states)
        batches += 1
        if not batches % 1000:
            print("finished %s epoch after %s seconds at cnt %s" %
                  (batches, time.time() - st, train_cnt))
    return train_cnt