def importance_sample(self, x):
    with torch.enable_grad():
        q_z_x, _ = self.encoder.forward(x)
        mu_svi = q_z_x.mu
        logvar_svi = q_z_x.logvar
        for i in range(self.n_svi_step):
            q_z_x = distribution.DiagonalGaussian(mu_svi, logvar_svi)
            z = q_z_x.sample()
            p_x_z = self.decoder.forward(z)
            loss = self.loss(x, z, p_x_z, self.prior, q_z_x)
            # create_graph=True lets gradients flow back through the SVI
            # refinement steps when the outer training loss is backpropagated.
            mu_svi_grad, logvar_svi_grad = torch.autograd.grad(
                loss, inputs=(mu_svi, logvar_svi), create_graph=True)
            # mu_svi_grad = utils.clip_grad(mu_svi_grad, 1)
            # logvar_svi_grad = utils.clip_grad(logvar_svi_grad, 1)
            mu_svi_grad = utils.clip_grad_norm(mu_svi_grad, 5)
            logvar_svi_grad = utils.clip_grad_norm(logvar_svi_grad, 5)
            # gradient ascent on the ELBO w.r.t. the variational parameters
            mu_svi = mu_svi + self.svi_lr * mu_svi_grad
            logvar_svi = logvar_svi + self.svi_lr * logvar_svi_grad
        # obtain z_K from the refined posterior, then importance-sample
        q_z_x = distribution.DiagonalGaussian(mu_svi, logvar_svi)
        # the sample count was an unbound name; assumed stored on the model
        q_z_x = q_z_x.repeat(self.n_importance_sample)
        z = q_z_x.sample()
        p_x_z = self.decoder(z)
        x = x.view(-1, self.image_size).unsqueeze(1).repeat(
            1, self.n_importance_sample, 1).view(-1, self.image_size)
        return self.importance_weighting(x, z, p_x_z, self.prior, q_z_x)
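# ---------------------------------------------------------------------------
# Editor's note: `utils.clip_grad_norm` is called above with a raw gradient
# tensor, and in the training loops below with an optimizer. The utils module
# is not shown; a minimal dual-purpose sketch consistent with both call sites
# (signature and behavior here are assumptions) might be:
import torch


def clip_grad_norm(grad_or_optimizer, max_norm):
    """Rescale gradient(s) so their L2 norm is at most `max_norm`."""
    if isinstance(grad_or_optimizer, torch.Tensor):
        # tensor mode: return a rescaled copy so autograd can still flow
        # through it (used by the SVI refinement steps above)
        norm = grad_or_optimizer.norm().item()
        if norm > max_norm:
            return grad_or_optimizer * (max_norm / (norm + 1e-6))
        return grad_or_optimizer
    # optimizer mode: clip every parameter that currently has a gradient
    params = [p for group in grad_or_optimizer.param_groups
              for p in group['params'] if p.grad is not None]
    return torch.nn.utils.clip_grad_norm_(params, max_norm)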
def step(self, minibatch):
    self.optimizer.zero_grad()
    loss = self.loss(minibatch)
    loss.backward()
    utils.clip_grad_norm(self.optimizer, 100)
    self.optimizer.step()
    return loss
def train(epoch, net, trainloader, device, optimizer, loss_fn, max_grad_norm,
          writer):
    print('\nEpoch: %d' % epoch)
    net.train()
    loss_meter = utils.AverageMeter()
    acc_meter = utils.AverageMeter()
    with tqdm(total=len(trainloader.dataset)) as progress_bar:
        for x, _ in trainloader:
            x = x.to(device)
            optimizer.zero_grad()
            z = net(x)
            sldj = net.module.logdet()
            loss = loss_fn(z, sldj=sldj)
            loss_meter.update(loss.item(), x.size(0))
            loss.backward()
            utils.clip_grad_norm(optimizer, max_grad_norm)
            optimizer.step()
            progress_bar.set_postfix(loss=loss_meter.avg,
                                     bpd=utils.bits_per_dim(x, loss_meter.avg))
            progress_bar.update(x.size(0))
    writer.add_scalar("train/loss", loss_meter.avg, epoch)
    writer.add_scalar("train/bpd", utils.bits_per_dim(x, loss_meter.avg),
                      epoch)
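# ---------------------------------------------------------------------------
# Editor's note: `utils.bits_per_dim` is not defined in this file. The
# standard conversion, matching how it is called above (an average NLL in
# nats plus the data shape -> bits per dimension), is sketched below; the
# real helper may differ in detail (e.g. a dequantization constant).
import numpy as np


def bits_per_dim(x, nll):
    """Convert an average negative log-likelihood in nats to bits/dim of x."""
    dim = float(np.prod(x.shape[1:]))  # dimensions per example (C*H*W)
    return nll / (np.log(2.0) * dim)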
def forward(self, x):
    with torch.enable_grad():
        q_z_x, _ = self.encoder.forward(x)
        mu_svi = q_z_x.mu
        logvar_svi = q_z_x.logvar
        for i in range(self.n_svi_step):
            q_z_x = distribution.DiagonalGaussian(mu_svi, logvar_svi)
            z = q_z_x.sample()
            p_x_z = self.decoder.forward(z)
            loss = self.loss(x, z, p_x_z, self.prior, q_z_x)
            # create_graph=True lets gradients flow back through the SVI
            # refinement steps when the outer training loss is backpropagated.
            mu_svi_grad, logvar_svi_grad = torch.autograd.grad(
                loss, inputs=(mu_svi, logvar_svi), create_graph=True)
            # mu_svi_grad = utils.clip_grad(mu_svi_grad, 1)
            # logvar_svi_grad = utils.clip_grad(logvar_svi_grad, 1)
            mu_svi_grad = utils.clip_grad_norm(mu_svi_grad, 5)
            logvar_svi_grad = utils.clip_grad_norm(logvar_svi_grad, 5)
            # gradient ascent on the ELBO w.r.t. the variational parameters
            mu_svi = mu_svi + self.svi_lr * mu_svi_grad
            logvar_svi = logvar_svi + self.svi_lr * logvar_svi_grad
        # obtain z_K from the refined posterior
        q_z_x = distribution.DiagonalGaussian(mu_svi, logvar_svi)
        z_K = q_z_x.sample()
        p_x_z = self.decoder.forward(z_K)
        loss = self.loss(x, z_K, p_x_z, self.prior, q_z_x)
    return loss
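# ---------------------------------------------------------------------------
# Editor's note: `distribution.DiagonalGaussian` is used throughout the SVI
# loops here but not shown. A minimal implementation consistent with the
# methods called (sample, log_probability, repeat) could look as follows;
# this is a sketch, not the original module.
import math
import torch


class DiagonalGaussian:
    def __init__(self, mu, logvar):
        self.mu, self.logvar = mu, logvar

    def sample(self):
        # reparameterized sample mu + sigma * eps, differentiable in mu/logvar
        eps = torch.randn_like(self.mu)
        return self.mu + torch.exp(0.5 * self.logvar) * eps

    def log_probability(self, z):
        # sum of per-dimension Gaussian log-densities
        return -0.5 * torch.sum(
            math.log(2 * math.pi) + self.logvar
            + (z - self.mu) ** 2 / torch.exp(self.logvar), dim=-1)

    def repeat(self, n):
        # tile each example n times along a new sample axis, then flatten,
        # matching how x is tiled in importance_sample above
        d = self.mu.size(-1)
        mu = self.mu.unsqueeze(1).repeat(1, n, 1).view(-1, d)
        logvar = self.logvar.unsqueeze(1).repeat(1, n, 1).view(-1, d)
        return DiagonalGaussian(mu, logvar)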
def train(epoch, net, trainloader, device, optimizer, scheduler, loss_fn,
          max_grad_norm, conditional):
    global global_step
    print('\nEpoch: %d' % epoch)
    net.train()
    loss_meter = util.AverageMeter()
    with tqdm(total=len(trainloader.dataset)) as progress_bar:
        for x in trainloader:
            optimizer.zero_grad()
            if conditional:
                x, x2 = x
                x = x.to(device)
                x2 = x2.to(device)
                z, sldj = net(x, x2, reverse=False)
            else:
                x = x.to(device)
                z, sldj = net(x, reverse=False)
            loss = loss_fn(z, sldj)
            loss_meter.update(loss.item(), x.size(0))
            loss.backward()
            if max_grad_norm > 0:
                util.clip_grad_norm(optimizer, max_grad_norm)
            optimizer.step()
            scheduler.step(global_step)
            progress_bar.set_postfix(nll=loss_meter.avg,
                                     bpd=util.bits_per_dim(x, loss_meter.avg),
                                     lr=optimizer.param_groups[0]['lr'])
            progress_bar.update(x.size(0))
            global_step += x.size(0)
def _perform_train(self, X, Act, Rew, Done, Mask, n_samples):
    batch_size = Act.size(0)
    seq_len = Act.size(1)
    # clear grad
    self.optim.zero_grad()
    # forward pass
    X = Variable(X)
    P, L, V, _ = self.policy(X, self.init_h)
    L = L[:, :seq_len, :]  # logits
    P = P[:, :seq_len, :]  # remove the extra last step
    # compute accumulated reward, bootstrapped from the final value estimate
    V_data = V.data
    cur_r = V_data[:, seq_len]
    V = V[:, :seq_len]
    V_data = V_data[:, :seq_len]
    R_list = []
    for t in range(seq_len - 1, -1, -1):
        cur_r = Rew[:, t] + self.gamma * Done[:, t] * cur_r
        R_list.append(cur_r)
    R_list.reverse()
    R = torch.stack(R_list, dim=1)
    # advantage normalization
    Adv = (R - V_data) * Mask  # advantage
    avg_val = Adv.sum() / n_samples
    Adv = (Adv - avg_val) * Mask  # subtract mean
    std_val = np.sqrt(torch.sum(Adv ** 2) / n_samples)  # standard deviation
    Adv = Variable(Adv / max(std_val, 0.1))
    # critic loss
    R = Variable(R)
    Mask = Variable(Mask)
    critic_loss = torch.sum(Mask * (R - V) ** 2) / n_samples
    # policy-gradient loss
    Act = Variable(Act)  # [batch_size, seq_len]
    Act = Act.unsqueeze(2)  # [batch_size, seq_len, 1]
    P_Act = torch.gather(P, 2, Act).squeeze(dim=2)  # [batch_size, seq_len]
    pg_loss = -torch.sum(P_Act * Adv * Mask) / n_samples
    # entropy bonus
    P_Ent = torch.sum(self.policy.entropy(L) * Mask) / n_samples
    pg_loss -= self.entropy_penalty * P_Ent
    # backprop
    loss = pg_loss + critic_loss
    loss.backward()
    L_norm = torch.sum(torch.sum(L ** 2, dim=-1) * Mask) / n_samples
    ret_dict = dict(pg_loss=pg_loss.data.cpu().numpy()[0],
                    policy_entropy=P_Ent.data.cpu().numpy()[0],
                    critic_loss=critic_loss.data.cpu().numpy()[0],
                    logits_norm=L_norm.data.cpu().numpy()[0])
    # gradient clip
    utils.clip_grad_norm(self.policy.parameters(), self.grad_clip)
    # apply SGD step
    self.optim.step()
    return ret_dict
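# ---------------------------------------------------------------------------
# Editor's note: the backward loop above implements the standard discounted
# return recursion R_t = r_t + gamma * cont_t * R_{t+1}, bootstrapped from
# the critic's value at the final step (in `_perform_train`, `Done` evidently
# already stores the continuation mask). In isolation, as a sketch:
import torch


def discounted_returns(rew, cont, bootstrap, gamma):
    """rew, cont: [batch, seq_len]; bootstrap: [batch].
    cont[:, t] = 1 means the episode continues past step t."""
    R, cur_r = [], bootstrap
    for t in range(rew.size(1) - 1, -1, -1):
        cur_r = rew[:, t] + gamma * cont[:, t] * cur_r
        R.append(cur_r)
    R.reverse()
    return torch.stack(R, dim=1)  # [batch, seq_len]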
def train(epoch, net, trainloader, device, optimizer, loss_fn, max_grad_norm,
          writer):
    print('\nEpoch: %d' % epoch)
    net.train()
    loss_meter = utils.AverageMeter()
    loss_unsup_meter = utils.AverageMeter()
    loss_nll_meter = utils.AverageMeter()
    acc_meter = utils.AverageMeter()
    with tqdm(total=len(trainloader.dataset)) as progress_bar:
        for x, y in trainloader:
            x = x.to(device)
            y = y.to(device)
            optimizer.zero_grad()
            z = net(x)
            sldj = net.module.logdet()
            logits = loss_fn.prior.class_logits(z.reshape((len(z), -1)))
            loss_nll = F.cross_entropy(logits, y)
            loss_unsup = loss_fn(z, sldj=sldj)
            loss = loss_nll + loss_unsup
            loss.backward()
            utils.clip_grad_norm(optimizer, max_grad_norm)
            optimizer.step()
            preds = torch.argmax(logits, dim=1)
            acc = (preds == y).float().mean().item()
            acc_meter.update(acc, x.size(0))
            loss_meter.update(loss.item(), x.size(0))
            loss_unsup_meter.update(loss_unsup.item(), x.size(0))
            loss_nll_meter.update(loss_nll.item(), x.size(0))
            progress_bar.set_postfix(loss=loss_meter.avg,
                                     bpd=utils.bits_per_dim(
                                         x, loss_unsup_meter.avg),
                                     acc=acc_meter.avg)
            progress_bar.update(x.size(0))
    x_img = torchvision.utils.make_grid(x[:10], nrow=2, padding=2,
                                        pad_value=255)
    writer.add_image("data/x", x_img)
    writer.add_scalar("train/loss", loss_meter.avg, epoch)
    writer.add_scalar("train/loss_unsup", loss_unsup_meter.avg, epoch)
    writer.add_scalar("train/loss_nll", loss_nll_meter.avg, epoch)
    writer.add_scalar("train/acc", acc_meter.avg, epoch)
    writer.add_scalar("train/bpd",
                      utils.bits_per_dim(x, loss_unsup_meter.avg), epoch)
def write_summary(self, x, writer, epoch):
    q_z_x, _ = self.encoder.forward(x)
    mu_svi = q_z_x.mu
    logvar_svi = q_z_x.logvar
    for i in range(self.n_svi_step):
        q_z_x = distribution.DiagonalGaussian(mu_svi, logvar_svi)
        z = q_z_x.sample()
        p_x_z = self.decoder.forward(z)
        loss = self.loss(x, z, p_x_z, self.prior, q_z_x)
        # create_graph=True lets gradients flow back through the SVI
        # refinement steps when the outer training loss is backpropagated.
        mu_svi_grad, logvar_svi_grad = torch.autograd.grad(
            loss, inputs=(mu_svi, logvar_svi), create_graph=True)
        # mu_svi_grad = utils.clip_grad(mu_svi_grad, 1)
        # logvar_svi_grad = utils.clip_grad(logvar_svi_grad, 1)
        mu_svi_grad = utils.clip_grad_norm(mu_svi_grad, 5)
        logvar_svi_grad = utils.clip_grad_norm(logvar_svi_grad, 5)
        # gradient ascent on the ELBO w.r.t. the variational parameters
        mu_svi = mu_svi + self.svi_lr * mu_svi_grad
        logvar_svi = logvar_svi + self.svi_lr * logvar_svi_grad
    # obtain z_K from the refined posterior
    q_z_x = distribution.DiagonalGaussian(mu_svi, logvar_svi)
    z_K = q_z_x.sample()
    p_x_z = self.decoder.forward(z_K)
    writer.add_scalar(
        'kl_div',
        torch.mean(-self.prior.log_probability(z_K) +
                   q_z_x.log_probability(z_K)).item(), epoch)
    writer.add_scalar('recon_error',
                      -torch.mean(p_x_z.log_probability(x)).item(), epoch)
    writer.add_image('data', vutils.make_grid(self.dataset.unpreprocess(x)),
                     epoch)
    writer.add_image(
        'reconstruction_z',
        vutils.make_grid(self.dataset.unpreprocess(p_x_z.mu).clamp(0, 1)),
        epoch)
    sample = torch.randn(len(x), z.shape[1]).cuda()
    sample = self.decoder(sample).mu
    writer.add_image(
        'generated',
        vutils.make_grid(self.dataset.unpreprocess(sample).clamp(0, 1)),
        epoch)
def train(epoch, net, trainloader, device, optimizer, loss_fn, max_grad_norm,
          writer, num_samples=10, sampling=True, tb_freq=100):
    print('\nEpoch: %d' % epoch)
    net.train()
    loss_meter = utils.AverageMeter()
    iter_count = 0
    batch_count = 0
    with tqdm(total=len(trainloader.dataset)) as progress_bar:
        for x, _ in trainloader:
            iter_count += 1
            batch_count += x.size(0)
            x = x.to(device)
            optimizer.zero_grad()
            z = net(x)
            sldj = net.module.logdet()
            loss = loss_fn(z, sldj=sldj)
            loss.backward()
            utils.clip_grad_norm(optimizer, max_grad_norm)
            optimizer.step()
            loss_meter.update(loss.item(), x.size(0))
            progress_bar.set_postfix(loss=loss_meter.avg,
                                     bpd=utils.bits_per_dim(x, loss_meter.avg))
            progress_bar.update(x.size(0))
            if iter_count % tb_freq == 0 or batch_count == len(
                    trainloader.dataset):
                tb_step = epoch * len(trainloader.dataset) + batch_count
                writer.add_scalar("train/loss", loss_meter.avg, tb_step)
                writer.add_scalar("train/bpd",
                                  utils.bits_per_dim(x, loss_meter.avg),
                                  tb_step)
                if sampling:
                    net.eval()
                    draw_samples(net, writer, loss_fn, num_samples, device,
                                 tuple(x[0].shape), tb_step)
                    net.train()
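# ---------------------------------------------------------------------------
# Editor's note: `draw_samples` is called in several training loops but not
# defined in this file. Since the flow exposes an inverse (cf. the
# commented-out `net.module.inverse` in a later loop), a plausible sketch is
# to draw latents from a standard-normal prior and invert the flow; the
# prior choice and the logging tag are assumptions.
import torch
import torchvision


def draw_samples(net, writer, loss_fn, num_samples, device, img_shape,
                 tb_step):
    with torch.no_grad():
        # a flow maps x -> z of the same shape, so sample z with img_shape
        z = torch.randn(num_samples, *img_shape, device=device)
        x = net.module.inverse(z)
        grid = torchvision.utils.make_grid(x.clamp(0, 1), nrow=2)
        writer.add_image("samples", grid, tb_step)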
def train(model, optimizer, train_iter, epoch, args):
    """Train with mini-batches."""
    model.train()
    total_stats = Statistics()
    batch_stats = Statistics()
    num_batches = len(train_iter)
    for i, batch in enumerate(train_iter):
        if args.warmup > 0:
            # linearly anneal the KL weight beta to 1 over the warmup epochs
            args.beta = min(1, args.beta + 1.0 / (args.warmup * num_batches))
        sents = batch.sent
        loss, stats = model(sents, args.beta)
        optimizer.zero_grad()
        loss.backward()
        utils.clip_grad_norm(optimizer, args)
        optimizer.step()
        total_stats.update(stats)
        batch_stats.update(stats)
        batch_stats = report_batch(batch_stats, epoch, i, num_batches, args)
        torch.cuda.empty_cache()
    return total_stats
def update(self):
    if (self.sample_counter < self.args['update_freq']) or \
            not self.replay_buffer.can_sample(
                self.batch_size * self.args['episode_len']):
        return None
    self.sample_counter = 0
    self.train()
    tt = time.time()
    obs, full_act, rew, obs_next, done = \
        self.replay_buffer.sample(self.batch_size)
    # act = split_batched_array(full_act, self.act_shape)
    time_counter[-1] += time.time() - tt

    tt = time.time()
    # convert to variables
    obs_n = self._process_frames(obs)
    obs_next_n = self._process_frames(obs_next, volatile=True)
    full_act_n = Variable(torch.from_numpy(full_act)).type(FloatTensor)
    rew_n = Variable(torch.from_numpy(rew), volatile=True).type(FloatTensor)
    done_n = Variable(torch.from_numpy(done), volatile=True).type(FloatTensor)
    time_counter[0] += time.time() - tt

    tt = time.time()
    # train q network
    common.debugger.print('Grad Stats of Q Update ...', False)
    target_act_next = self.target_p(obs_next_n)
    target_q_next = self.target_q(obs_next_n, target_act_next)
    target_q = rew_n + self.gamma * (1.0 - done_n) * target_q_next
    target_q.volatile = False
    current_q = self.q(obs_n, full_act_n)
    q_norm = (current_q * current_q).mean().squeeze()  # L2 norm
    q_loss = F.smooth_l1_loss(current_q, target_q) \
        + self.args['critic_penalty'] * q_norm  # Huber loss + critic penalty
    common.debugger.print('>> Q_Loss = {}'.format(q_loss.data.mean()), False)
    self.q_optim.zero_grad()
    q_loss.backward()
    common.debugger.print('Stats of Q Network (*before* clip and opt)....',
                          False)
    utils.log_parameter_stats(common.debugger, self.q)
    if self.grad_norm_clip is not None:
        # nn.utils.clip_grad_norm(self.q.parameters(), self.grad_norm_clip)
        utils.clip_grad_norm(self.q.parameters(), self.grad_norm_clip)
    self.q_optim.step()

    # train p network
    new_act_n = self.p(obs_n)  # NOTE: maybe use <gumbel_noise=None> ?
    q_val = self.q(obs_n, new_act_n)
    p_loss = -q_val.mean().squeeze()
    p_ent = self.p.entropy().mean().squeeze()
    if self.args['ent_penalty'] is not None:
        p_loss -= self.args['ent_penalty'] * p_ent  # encourage exploration
    common.debugger.print('>> P_Loss = {}'.format(p_loss.data.mean()), False)
    self.p_optim.zero_grad()
    self.q_optim.zero_grad()  # important!! clear the grad in Q
    p_loss.backward()
    if self.grad_norm_clip is not None:
        # nn.utils.clip_grad_norm(self.p.parameters(), self.grad_norm_clip)
        utils.clip_grad_norm(self.p.parameters(), self.grad_norm_clip)
    self.p_optim.step()
    common.debugger.print('Stats of Q Network (in the phase of P-Update)....',
                          False)
    utils.log_parameter_stats(common.debugger, self.q)
    common.debugger.print('Stats of P Network (after clip and opt)....', False)
    utils.log_parameter_stats(common.debugger, self.p)
    time_counter[1] += time.time() - tt

    tt = time.time()
    # update target networks
    make_update_exp(self.p, self.target_p, rate=self.target_update_rate)
    make_update_exp(self.q, self.target_q, rate=self.target_update_rate)
    common.debugger.print('Stats of Q Target Network (After Update)....',
                          False)
    utils.log_parameter_stats(common.debugger, self.target_q)
    common.debugger.print('Stats of P Target Network (After Update)....',
                          False)
    utils.log_parameter_stats(common.debugger, self.target_p)
    time_counter[2] += time.time() - tt

    return dict(policy_loss=p_loss.data.cpu().numpy()[0],
                policy_entropy=p_ent.data.cpu().numpy()[0],
                critic_norm=q_norm.data.cpu().numpy()[0],
                critic_loss=q_loss.data.cpu().numpy()[0])
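# ---------------------------------------------------------------------------
# Editor's note: `make_update_exp` above is the usual Polyak / exponential
# moving-average target-network update. A minimal sketch consistent with its
# call signature (whether `rate` weights the source or the target is an
# assumption):
def make_update_exp(source, target, rate):
    """Parameter-wise soft update: target <- (1 - rate)*target + rate*source."""
    for tgt, src in zip(target.parameters(), source.parameters()):
        tgt.data.mul_(1.0 - rate).add_(rate * src.data)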
def update(self, obs, init_hidden, act, rew, done, target=None,
           supervision_mask=None, mask_input=None, return_kl_divergence=True):
    """
    :param obs: list of list of [dims]...
    :param init_hidden: list of [layer, 1, units]
    :param act: [batch, seq_len]
    :param rew: [batch, seq_len]
    :param done: [batch, seq_len]
    :param target: [batch, seq_len, n_instruction] or None (when single-target)
    :param supervision_mask: timesteps marked with a supervised-learning loss,
        [batch, seq_len] or None (pure RL)
    """
    tt = time.time()
    # reward clipping
    if self.rew_clip is not None:
        rew = np.clip(rew, -self.rew_clip, self.rew_clip)
    # convert data to Variables
    obs = self._create_gpu_tensor(obs, return_variable=True)  # [batch, t_max+1, dims...]
    init_hidden = self._create_gpu_hidden(init_hidden, return_variable=True)  # [layers, batch, units]
    if target is not None:
        target = self._create_target_tensor(target, return_variable=True)
    if mask_input is not None:
        mask_input = self._create_feature_tensor(mask_input,
                                                 return_variable=True)
    act = Variable(torch.from_numpy(act).type(LongTensor))  # [batch, t_max]
    mask = 1.0 - torch.from_numpy(done).type(FloatTensor)  # [batch, t_max]
    mask_var = Variable(mask)
    sup_mask = None if supervision_mask is None else \
        torch.from_numpy(supervision_mask).type(ByteTensor)  # [batch, t_max]
    time_counter[0] += time.time() - tt

    batch_size = self.batch_size
    t_max = self.t_max
    gamma = self.gamma

    tt = time.time()
    if self.accu_grad_steps == 0:  # clear grad
        self.optim.zero_grad()
    # forward pass
    logits = []
    logprobs = []
    values = []
    t_obs_slices = torch.chunk(obs, t_max + 1, dim=1)
    obs_slices = [t.contiguous() for t in t_obs_slices]
    if target is not None:
        t_target_slices = torch.chunk(target, t_max + 1, dim=1)
        target_slices = [t.contiguous() for t in t_target_slices]
    if mask_input is not None:
        t_mask_input_slices = torch.chunk(mask_input, t_max + 1, dim=1)
        mask_input_slices = [m.contiguous() for m in t_mask_input_slices]
    cur_h = init_hidden
    for t in range(t_max):
        cur_obs = obs_slices[t]
        t_target = None if target is None else target_slices[t]
        t_mask = None if mask_input is None else mask_input_slices[t]
        cur_logp, cur_val, nxt_h = self.policy(cur_obs, cur_h, target=t_target,
                                               extra_input_feature=t_mask)
        cur_h = self.policy.mark_hidden_states(nxt_h, mask_var[:, t:t + 1])
        values.append(cur_val)
        logprobs.append(cur_logp)
        logits.append(self.policy.logits)
    cur_obs = obs_slices[-1]
    t_target = None if target is None else target_slices[-1]
    t_mask = None if mask_input is None else mask_input_slices[-1]
    nxt_val = self.policy(cur_obs, cur_h, only_value=True, return_tensor=True,
                          target=t_target, extra_input_feature=t_mask)
    V = torch.cat(values, dim=1)  # [batch, t_max]
    P = torch.cat(logprobs, dim=1)  # [batch, t_max, n_act]
    L = torch.cat(logits, dim=1)
    p_ent = torch.mean(self.policy.entropy(L))  # compute entropy
    L_norm = torch.mean(torch.sum(L * L, dim=-1))  # L^2 penalty on logits

    # estimate accumulated rewards
    rew = torch.from_numpy(rew).type(FloatTensor)  # [batch, t_max]
    R = []
    cur_R = nxt_val.squeeze()  # [batch]
    for t in range(t_max - 1, -1, -1):
        cur_mask = mask[:, t]
        cur_R = rew[:, t] + gamma * cur_R * cur_mask
        R.append(cur_R)
    R.reverse()
    R = Variable(torch.stack(R, dim=1))  # [batch, t_max]

    # estimate advantage
    A_dat = R.data - V.data  # stop gradient here
    std_val = None
    if self.adv_norm:  # perform advantage normalization
        std_val = max(A_dat.std(), 0.1)
        A_dat = (A_dat - A_dat.mean()) / (std_val + 1e-10)
    if sup_mask is not None:
        # supervision: change A * log P(a) to log P(supervised_a);
        # act has been modified in zmq_util
        A_dat[sup_mask > 0] = 1.0
    A = Variable(A_dat)
    # [optional] A = Variable(rew) - V

    # compute loss
    # critic_loss = F.smooth_l1_loss(V, R)
    critic_loss = torch.mean((R - V) ** 2)
    pg_loss = -torch.mean(self.policy.logprob(act, P) * A)
    if self.args['entropy_penalty'] is not None:
        pg_loss -= self.args['entropy_penalty'] * p_ent  # encourage exploration
    loss = self.q_loss_coef * critic_loss + pg_loss
    if self.logit_loss_coef is not None:
        loss += self.logit_loss_coef * L_norm

    # backprop
    if self.grad_batch > 1:
        loss = loss / float(self.grad_batch)
    loss.backward()
    ret_dict = dict(pg_loss=pg_loss.data.cpu().numpy()[0],
                    policy_entropy=p_ent.data.cpu().numpy()[0],
                    critic_loss=critic_loss.data.cpu().numpy()[0],
                    logits_norm=L_norm.data.cpu().numpy()[0])
    if std_val is not None:
        ret_dict['adv_norm'] = std_val

    # accumulate stats across gradient-accumulation steps
    if self.accu_grad_steps == 0:
        self.accu_ret_dict = ret_dict
    else:
        for k in ret_dict:
            self.accu_ret_dict[k] += ret_dict[k]
    self.accu_grad_steps += 1
    if self.accu_grad_steps < self.grad_batch:  # do not update parameters yet
        time_counter[1] += time.time() - tt
        return None
    # update parameters
    for k in self.accu_ret_dict:
        self.accu_ret_dict[k] /= self.grad_batch
    ret_dict = self.accu_ret_dict
    self.accu_grad_steps = 0

    # grad clip
    if self.grad_norm_clip is not None:
        utils.clip_grad_norm(self.policy.parameters(), self.grad_norm_clip)
    self.optim.step()

    if return_kl_divergence:
        cur_h = init_hidden
        new_logprobs = []
        for t in range(t_max):
            cur_obs = obs_slices[t]
            t_target = target_slices[t] if self.multi_target else None
            t_mask = None if mask_input is None else mask_input_slices[t]
            cur_logp, nxt_h = self.policy(cur_obs, cur_h, return_value=False,
                                          target=t_target,
                                          extra_input_feature=t_mask)
            cur_h = self.policy.mark_hidden_states(nxt_h, mask_var[:, t:t + 1])
            new_logprobs.append(cur_logp)
        new_P = torch.cat(new_logprobs, dim=1)
        kl = self.policy.kl_divergence(new_P, P).mean().data.cpu()[0]
        ret_dict['KL(P_new||P_old)'] = kl
        # simple adaptive learning rate driven by the policy's KL divergence
        if kl > flag_max_kl_diff:
            self.lrate /= flag_lrate_coef
            self.optim.param_groups[0]['lr'] = self.lrate
            ret_dict['!!![NOTE]:'] = (
                '------>>>> KL is too large (%.6f), decrease lrate to %.5f'
                % (kl, self.lrate))
        elif (kl < flag_min_kl_diff) and (self.lrate < flag_max_lrate):
            self.lrate *= flag_lrate_coef
            self.optim.param_groups[0]['lr'] = self.lrate
            ret_dict['!!![NOTE]:'] = (
                '------>>>> KL is too small (%.6f), increase lrate to %.5f'
                % (kl, self.lrate))
    time_counter[1] += time.time() - tt
    return ret_dict
def update(self):
    if (self.a is not None) or \
            not self.replay_buffer.can_sample(self.batch_size * 4):
        return None
    self.sample_counter = 0
    self.train()
    tt = time.time()
    obs, full_act, rew, msk, done, total_length = \
        self.replay_buffer.sample(self.batch_size, seq_len=self.batch_len)
    total_length = float(total_length)
    # act = split_batched_array(full_act, self.act_shape)
    time_counter[-1] += time.time() - tt

    tt = time.time()
    # convert to variables
    _full_obs_n = self._process_frames(obs, merge_dim=False,
                                       return_variable=False)  # [batch, seq_len+1, ...]
    batch = _full_obs_n.size(0)
    seq_len = _full_obs_n.size(1) - 1
    full_obs_n = Variable(_full_obs_n, volatile=True)
    obs_n = Variable(_full_obs_n[:, :-1, ...]).contiguous()  # [batch, seq_len, ...]
    obs_next_n = Variable(_full_obs_n[:, 1:, ...], volatile=True).contiguous()
    img_c, img_h, img_w = obs_n.size(-3), obs_n.size(-2), obs_n.size(-1)
    packed_obs_n = obs_n.view(-1, img_c, img_h, img_w)
    packed_obs_next_n = obs_next_n.view(-1, img_c, img_h, img_w)
    full_act_n = Variable(torch.from_numpy(full_act)).type(FloatTensor)  # [batch, seq_len, ...]
    act_padding = Variable(torch.zeros(self.batch_size, 1,
                                       full_act_n.size(-1))).type(FloatTensor)
    pad_act_n = torch.cat([act_padding, full_act_n], dim=1)  # [batch, seq_len+1, ...]
    rew_n = Variable(torch.from_numpy(rew), volatile=True).type(FloatTensor)
    msk_n = Variable(torch.from_numpy(msk)).type(FloatTensor)  # [batch, seq_len]
    done_n = Variable(torch.from_numpy(done)).type(FloatTensor)  # [batch, seq_len]
    time_counter[0] += time.time() - tt

    tt = time.time()
    # train q network
    common.debugger.print('Grad Stats of Q Update ...', False)
    full_target_act, _ = self.target_p(full_obs_n, act=pad_act_n)  # list([batch, seq_len+1, act_dim])
    target_act_next = torch.cat(full_target_act, dim=-1)[:, 1:, :]
    act_dim = target_act_next.size(-1)
    target_act_next = target_act_next.resize(batch * seq_len, act_dim)
    target_q_next = self.target_q(packed_obs_next_n,
                                  act=target_act_next)  # [batch * seq_len]
    target_q_next = target_q_next.view(batch, seq_len)
    target_q = (rew_n + self.gamma * done_n * target_q_next) * msk_n
    target_q = target_q.view(-1)
    target_q.volatile = False
    current_q = self.q(packed_obs_n,
                       act=full_act_n.view(-1, act_dim)) * msk_n.view(-1)
    q_norm = (current_q * current_q).sum() / total_length  # L2 norm
    q_loss = F.smooth_l1_loss(current_q, target_q,
                              size_average=False) / total_length \
        + self.args['critic_penalty'] * q_norm  # Huber loss + critic penalty
    common.debugger.print('>> Q_Loss = {}'.format(q_loss.data.mean()), False)
    self.q_optim.zero_grad()
    q_loss.backward()
    common.debugger.print('Stats of Q Network (*before* clip and opt)....',
                          False)
    utils.log_parameter_stats(common.debugger, self.q)
    if self.grad_norm_clip is not None:
        # nn.utils.clip_grad_norm(self.q.parameters(), self.grad_norm_clip)
        utils.clip_grad_norm(self.q.parameters(), self.grad_norm_clip)
    self.q_optim.step()

    # train p network
    new_act_n, _ = self.p(obs_n, act=pad_act_n[:, :-1, :])  # [batch, seq_len, act_dim]
    new_act_n = torch.cat(new_act_n, dim=-1)
    new_act_n = new_act_n.view(-1, act_dim)
    q_val = self.q(packed_obs_n, new_act_n) * msk_n.view(-1)
    p_loss = -q_val.sum() / total_length
    p_ent = self.p.entropy(weight=msk_n).sum() / total_length
    if self.args['ent_penalty'] is not None:
        p_loss -= self.args['ent_penalty'] * p_ent  # encourage exploration
    common.debugger.print('>> P_Loss = {}'.format(p_loss.data.mean()), False)
    self.p_optim.zero_grad()
    self.q_optim.zero_grad()  # important!! clear the grad in Q
    p_loss.backward()
    if self.grad_norm_clip is not None:
        # nn.utils.clip_grad_norm(self.p.parameters(), self.grad_norm_clip)
        utils.clip_grad_norm(self.p.parameters(), self.grad_norm_clip)
    self.p_optim.step()
    common.debugger.print('Stats of Q Network (in the phase of P-Update)....',
                          False)
    utils.log_parameter_stats(common.debugger, self.q)
    common.debugger.print('Stats of P Network (after clip and opt)....', False)
    utils.log_parameter_stats(common.debugger, self.p)
    time_counter[1] += time.time() - tt

    tt = time.time()
    # update target networks
    make_update_exp(self.p, self.target_p, rate=self.target_update_rate)
    make_update_exp(self.q, self.target_q, rate=self.target_update_rate)
    common.debugger.print('Stats of Q Target Network (After Update)....',
                          False)
    utils.log_parameter_stats(common.debugger, self.target_q)
    common.debugger.print('Stats of P Target Network (After Update)....',
                          False)
    utils.log_parameter_stats(common.debugger, self.target_p)
    time_counter[2] += time.time() - tt

    return dict(policy_loss=p_loss.data.cpu().numpy()[0],
                policy_entropy=p_ent.data.cpu().numpy()[0],
                critic_norm=q_norm.data.cpu().numpy()[0],
                critic_loss=q_loss.data.cpu().numpy()[0])
def train(epoch, net, trainloader, ood_loader, device, optimizer, loss_fn,
          max_grad_norm, writer, negative_val=-1e5, num_samples=10,
          tb_freq=100):
    print('\nEpoch: %d' % epoch)
    net.train()
    loss_meter = utils.AverageMeter()
    loss_positive_meter = utils.AverageMeter()
    loss_negative_meter = utils.AverageMeter()
    iter_count = 0
    batch_count = 0
    pooler = MedianPool2d(7, padding=3)  # NOTE: unused below
    with tqdm(total=len(trainloader.dataset)) as progress_bar:
        for (x, _), (x_transposed, _) in zip(trainloader, ood_loader):
            bs = x.shape[0]
            x = torch.cat((x, x_transposed), dim=0)
            iter_count += 1
            batch_count += bs
            x = x.to(device)
            optimizer.zero_grad()
            z = net(x)
            sldj = net.module.logdet()
            loss = loss_fn(z, sldj=sldj, mean=False)
            # maximize likelihood on in-distribution data, minimize it on OOD
            loss[bs:] *= -1
            loss_positive = loss[:bs]
            loss_negative = loss[bs:]
            if (loss_negative > negative_val).sum() > 0:
                # only penalize OOD examples whose loss is above the floor
                loss_negative = loss_negative[loss_negative > negative_val]
                loss_negative = loss_negative.mean()
                loss_positive = loss_positive.mean()
                loss = 0.5 * (loss_positive + loss_negative)
            else:
                loss_negative = torch.tensor(0.)
                loss_positive = loss_positive.mean()
                loss = loss_positive
            loss.backward()
            utils.clip_grad_norm(optimizer, max_grad_norm)
            optimizer.step()
            loss_meter.update(loss.item(), bs)
            loss_positive_meter.update(loss_positive.item(), bs)
            loss_negative_meter.update(loss_negative.item(), bs)
            progress_bar.set_postfix(
                pos_bpd=utils.bits_per_dim(x[:bs], loss_positive_meter.avg),
                neg_bpd=utils.bits_per_dim(x[bs:], -loss_negative_meter.avg),
                neg_loss=loss_negative.mean().item())
            progress_bar.update(bs)
            if iter_count % tb_freq == 0 or batch_count == len(
                    trainloader.dataset):
                tb_step = epoch * len(trainloader.dataset) + batch_count
                writer.add_scalar("train/loss", loss_meter.avg, tb_step)
                writer.add_scalar("train/loss_positive",
                                  loss_positive_meter.avg, tb_step)
                writer.add_scalar("train/loss_negative",
                                  loss_negative_meter.avg, tb_step)
                writer.add_scalar(
                    "train/bpd_positive",
                    utils.bits_per_dim(x[:bs], loss_positive_meter.avg),
                    tb_step)
                writer.add_scalar(
                    "train/bpd_negative",
                    utils.bits_per_dim(x[bs:], -loss_negative_meter.avg),
                    tb_step)
                x1_img = torchvision.utils.make_grid(x[:10], nrow=2, padding=2,
                                                     pad_value=255)
                x2_img = torchvision.utils.make_grid(x[-10:], nrow=2,
                                                     padding=2, pad_value=255)
                writer.add_image("data/x", x1_img)
                writer.add_image("data/x_transposed", x2_img)
                net.eval()
                draw_samples(net, writer, loss_fn, num_samples, device,
                             tuple(x[0].shape), tb_step)
                net.train()
def update(self, cpu_batch, gpu_batch):
    self.update_counter += 1
    self.train()
    tt = time.time()
    obs_n, obs_next_n, full_act_n, rew_n, done_n = self._process_elf_frames(
        gpu_batch, keep_time=False)  # collapse all the samples
    obs_n = (obs_n.type(FloatTensor) - 128.0) / 256.0
    obs_n = Variable(obs_n)
    obs_next_n = (obs_next_n.type(FloatTensor) - 128.0) / 256.0
    obs_next_n = Variable(obs_next_n, volatile=True)
    full_act_n = Variable(full_act_n)
    rew_n = Variable(rew_n, volatile=True)
    done_n = Variable(done_n, volatile=True)
    self.sample_counter += obs_n.size(0)
    time_counter[0] += time.time() - tt

    tt = time.time()
    self.optim.zero_grad()
    # train p network
    q_val = self.net(obs_n, action=None, output_critic=True)
    p_loss = -q_val.mean().squeeze()
    p_ent = self.net.entropy().mean().squeeze()
    if self.args['ent_penalty'] is not None:
        p_loss -= self.args['ent_penalty'] * p_ent  # encourage exploration
    common.debugger.print('>> P_Loss = {}'.format(p_loss.data.mean()), False)
    p_loss.backward()
    # we do not need to compute q-gradients for the actor!!!
    self.net.clear_critic_specific_grad()

    # train q network
    common.debugger.print('Grad Stats of Q Update ...', False)
    target_q_next = self.target_net(obs_next_n, output_critic=True)
    target_q = rew_n + self.gamma * (1.0 - done_n) * target_q_next
    target_q.volatile = False
    current_q = self.net(obs_n, action=full_act_n, output_critic=True)
    q_norm = (current_q * current_q).mean().squeeze()  # L2 norm
    q_loss = F.smooth_l1_loss(current_q, target_q) \
        + self.args['critic_penalty'] * q_norm  # Huber loss + critic penalty
    common.debugger.print('>> Q_Loss = {}'.format(q_loss.data.mean()), False)
    q_loss = q_loss * self.q_loss_coef
    q_loss.backward()
    # total_loss = q_loss + p_loss
    # grad clip
    if self.grad_norm_clip is not None:
        utils.clip_grad_norm(self.net.parameters(), self.grad_norm_clip)
    self.optim.step()
    common.debugger.print('Stats of P Network (after clip and opt)....', False)
    utils.log_parameter_stats(common.debugger, self.net)
    time_counter[1] += time.time() - tt

    tt = time.time()
    # update target networks
    make_update_exp(self.net, self.target_net, rate=self.target_update_rate)
    common.debugger.print('Stats of Target Network (After Update)....', False)
    utils.log_parameter_stats(common.debugger, self.target_net)
    time_counter[2] += time.time() - tt

    stats = dict(policy_loss=p_loss.data.cpu().numpy()[0],
                 policy_entropy=p_ent.data.cpu().numpy()[0],
                 critic_norm=q_norm.data.cpu().numpy()[0],
                 critic_loss=q_loss.data.cpu().numpy()[0] / self.q_loss_coef,
                 eplen=cpu_batch[-1]['stats_eplen'].mean(),
                 avg_rew=cpu_batch[-1]['stats_rew'].mean())
    self.print_log(stats)
def update(self):
    if (self.sample_counter < self.args['update_freq']) or \
            not self.replay_buffer.can_sample(
                self.batch_size * min(self.args['update_freq'], 20)):
        return None
    self._update_counter += 1
    self.sample_counter = 0
    self.train()
    tt = time.time()
    obs, act, rew, obs_next, done = \
        self.replay_buffer.sample(self.batch_size)
    if self.multi_target:
        target_idx = self.target_buffer[self.replay_buffer._idxes]
        targets = np.zeros((self.batch_size, common.n_target_instructions),
                           dtype=np.uint8)
        targets[list(range(self.batch_size)), target_idx] = 1
    # act = split_batched_array(full_act, self.act_shape)
    time_counter[-1] += time.time() - tt

    tt = time.time()
    # convert to variables
    obs_n = self._process_frames(obs)
    obs_next_n = self._process_frames(obs_next, volatile=True)
    act_n = Variable(torch.from_numpy(act)).type(LongTensor)
    rew_n = Variable(torch.from_numpy(rew), volatile=True).type(FloatTensor)
    done_n = Variable(torch.from_numpy(done), volatile=True).type(FloatTensor)
    if self.multi_target:
        target_n = Variable(torch.from_numpy(targets).type(FloatTensor))
    else:
        target_n = None
    time_counter[0] += time.time() - tt

    tt = time.time()
    # compute critic loss
    target_q_val_next = self.target_net(obs_next_n, only_q_value=True,
                                        target=target_n)
    # double Q-learning: select the next action with the online net,
    # evaluate it with the target net
    target_act_next = torch.max(self.net(obs_next_n, only_q_value=True,
                                         target=target_n),
                                dim=1, keepdim=True)[1]
    target_q_next = torch.gather(target_q_val_next, 1,
                                 target_act_next).squeeze()
    target_q = rew_n + self.gamma * (1.0 - done_n) * target_q_next
    target_q.volatile = False
    current_q_val = self.net(obs_n, only_q_value=True, target=target_n)
    current_q = torch.gather(current_q_val, 1, act_n.view(-1, 1)).squeeze()
    q_norm = (current_q * current_q).mean().squeeze()
    q_loss = F.smooth_l1_loss(current_q, target_q)
    common.debugger.print('>> Q_Loss = {}'.format(q_loss.data.mean()), False)
    common.debugger.print('>> Q_Norm = {}'.format(q_norm.data.mean()), False)
    total_loss = q_loss.mean()
    if self.args['critic_penalty'] > 1e-10:
        total_loss += self.args['critic_penalty'] * q_norm

    # compute gradient
    self.optim.zero_grad()
    # autograd.backward([total_loss, current_act], [torch.ones(1), None])
    total_loss.backward()
    if self.grad_norm_clip is not None:
        # nn.utils.clip_grad_norm(self.q.parameters(), self.grad_norm_clip)
        utils.clip_grad_norm(self.net.parameters(), self.grad_norm_clip)
    self.optim.step()
    common.debugger.print('Stats of Model (*after* clip and opt)....', False)
    utils.log_parameter_stats(common.debugger, self.net)
    time_counter[1] += time.time() - tt

    tt = time.time()
    # update target networks: hard copy every target_net_update_freq updates,
    # otherwise an exponential moving average
    if self.target_net_update_freq is not None:
        if self._update_counter == self.target_net_update_freq:
            self._update_counter = 0
            self.target_net.load_state_dict(self.net.state_dict())
    else:
        make_update_exp(self.net, self.target_net,
                        rate=self.target_update_rate)
    common.debugger.print('Stats of Target Network (After Update)....', False)
    utils.log_parameter_stats(common.debugger, self.target_net)
    time_counter[2] += time.time() - tt

    return dict(critic_norm=q_norm.data.cpu().numpy()[0],
                critic_loss=q_loss.data.cpu().numpy()[0])
def update(self, obs, init_hidden, act, rew, done, target=None,
           aux_target=None, return_kl_divergence=True):
    """
    :param obs: list of list of [dims]...
    :param init_hidden: list of [layer, 1, units]
    :param act: [batch, seq_len]
    :param rew: [batch, seq_len]
    :param done: [batch, seq_len]
    :param target: [batch, seq_len, n_instruction] or None (when single-target)
    :param aux_target: 0/1 label matrix [batch, seq_len, n_aux_pred] or None
        (not updating the aux loss)
    """
    assert aux_target is not None, 'AuxTrainer must be given <aux_target>'
    tt = time.time()
    # reward clipping
    rew = np.clip(rew, -1, 1)
    # convert data to Variables
    obs = self._create_gpu_tensor(obs, return_variable=True)  # [batch, t_max+1, dims...]
    init_hidden = self._create_gpu_hidden(init_hidden, return_variable=True)  # [layers, batch, units]
    if target is not None:
        target = self._create_target_tensor(target, return_variable=True)
    aux_target = self._create_aux_target_tensor(aux_target)
    act = Variable(torch.from_numpy(act).type(LongTensor))  # [batch, t_max]
    mask = 1.0 - torch.from_numpy(done).type(FloatTensor)  # [batch, t_max]
    mask_var = Variable(mask)
    time_counter[0] += time.time() - tt

    batch_size = self.batch_size
    t_max = self.t_max
    gamma = self.gamma

    tt = time.time()
    self.optim.zero_grad()
    # forward pass
    logits = []
    logprobs = []
    values = []
    aux_preds = []
    t_obs_slices = torch.chunk(obs, t_max + 1, dim=1)
    obs_slices = [t.contiguous() for t in t_obs_slices]
    if target is not None:
        t_target_slices = torch.chunk(target, t_max + 1, dim=1)
        target_slices = [t.contiguous() for t in t_target_slices]
    cur_h = init_hidden
    for t in range(t_max):
        cur_obs = obs_slices[t]
        if target is not None:
            ret_vals = self.policy(
                cur_obs, cur_h, target=target_slices[t],
                compute_aux_pred=True,
                return_aux_logprob=self.use_supervised_loss)
        else:
            ret_vals = self.policy(
                cur_obs, cur_h, compute_aux_pred=True,
                return_aux_logprob=self.use_supervised_loss)
        cur_logp, cur_val, nxt_h, aux_p = ret_vals
        cur_h = self.policy.mark_hidden_states(nxt_h, mask_var[:, t:t + 1])
        values.append(cur_val)
        logprobs.append(cur_logp)
        logits.append(self.policy.logits)
        aux_preds.append(aux_p)
    cur_obs = obs_slices[-1]
    if target is not None:
        nxt_val = self.policy(cur_obs, cur_h, only_value=True,
                              return_tensor=True, target=target_slices[-1])
    else:
        nxt_val = self.policy(cur_obs, cur_h, only_value=True,
                              return_tensor=True)
    V = torch.cat(values, dim=1)  # [batch, t_max]
    P = torch.cat(logprobs, dim=1)  # [batch, t_max, n_act]
    L = torch.cat(logits, dim=1)
    p_ent = torch.mean(self.policy.entropy(L))  # compute entropy
    Aux_P = torch.cat(aux_preds, dim=1)  # [batch, t_max, n_aux_pred]

    # estimate accumulated rewards
    rew = torch.from_numpy(rew).type(FloatTensor)  # [batch, t_max]
    R = []
    cur_R = nxt_val.squeeze()  # [batch]
    for t in range(t_max - 1, -1, -1):
        cur_mask = mask[:, t]
        cur_R = rew[:, t] + gamma * cur_R * cur_mask
        R.append(cur_R)
    R.reverse()
    R = Variable(torch.stack(R, dim=1))  # [batch, t_max]

    # estimate advantage
    A = Variable(R.data - V.data)  # stop gradient here
    # [optional] A = Variable(rew) - V

    # compute loss
    # critic_loss = F.smooth_l1_loss(V, R)
    critic_loss = torch.mean((R - V) ** 2)
    pg_loss = -torch.mean(self.policy.logprob(act, P) * A)
    if self.args['entropy_penalty'] is not None:
        pg_loss -= self.args['entropy_penalty'] * p_ent  # encourage exploration
    # aux-task loss: cross-entropy against the 0/1 label matrix
    aux_loss = -(Aux_P * aux_target).sum(dim=-1).mean()
    loss = self.q_loss_coef * critic_loss + pg_loss \
        + self.aux_loss_coef * aux_loss

    # backprop
    loss.backward()
    # grad clip
    if self.grad_norm_clip is not None:
        utils.clip_grad_norm(self.policy.parameters(), self.grad_norm_clip)
    self.optim.step()
    ret_dict = dict(pg_loss=pg_loss.data.cpu().numpy()[0],
                    aux_task_loss=aux_loss.data.cpu().numpy()[0],
                    policy_entropy=p_ent.data.cpu().numpy()[0],
                    critic_loss=critic_loss.data.cpu().numpy()[0])

    if return_kl_divergence:
        cur_h = init_hidden
        new_logprobs = []
        for t in range(t_max):
            cur_obs = obs_slices[t]
            cur_target = target_slices[t] if self.multi_target else None
            cur_logp, nxt_h = self.policy(cur_obs, cur_h, return_value=False,
                                          target=cur_target)
            cur_h = self.policy.mark_hidden_states(nxt_h, mask_var[:, t:t + 1])
            new_logprobs.append(cur_logp)
        new_P = torch.cat(new_logprobs, dim=1)
        kl = self.policy.kl_divergence(new_P, P).mean().data.cpu()[0]
        ret_dict['KL(P_new||P_old)'] = kl
        # simple adaptive learning rate driven by the policy's KL divergence
        if kl > flag_max_kl_diff:
            self.lrate /= flag_lrate_coef
            self.optim.param_groups[0]['lr'] = self.lrate
            ret_dict['!!![NOTE]:'] = (
                '------>>>> KL is too large (%.6f), decrease lrate to %.5f'
                % (kl, self.lrate))
        elif (kl < flag_min_kl_diff) and (self.lrate < flag_max_lrate):
            self.lrate *= flag_lrate_coef
            self.optim.param_groups[0]['lr'] = self.lrate
            ret_dict['!!![NOTE]:'] = (
                '------>>>> KL is too small (%.6f), increase lrate to %.5f'
                % (kl, self.lrate))
    time_counter[1] += time.time() - tt
    return ret_dict
def train(epoch, net, trainloader, device, optimizer, loss_fn, max_grad_norm,
          writer, num_samples=10, sampling=True, tb_freq=100):
    print('\nEpoch: %d' % epoch)
    net.train()
    loss_meter = utils.AverageMeter()
    loss_unsup_meter = utils.AverageMeter()
    loss_reconstr_meter = utils.AverageMeter()
    kl_loss_meter = utils.AverageMeter()
    acc_meter = utils.AverageMeter()
    iter_count = 0
    batch_count = 0
    with tqdm(total=len(trainloader.dataset)) as progress_bar:
        for x, _ in trainloader:
            iter_count += 1
            batch_count += x.size(0)
            x = x.to(device)
            optimizer.zero_grad()
            z = net(x)
            sldj = net.module.logdet()
            loss_unsup = loss_fn(z, sldj=sldj)
            # if vae_loss:
            #     logvar_z = -logvar_net(z)
            #     z_perturbed = z + torch.randn_like(z) * torch.exp(0.5 * logvar_z)
            #     x_reconstr = net.module.inverse(z_perturbed)
            #     if decoder_likelihood == 'binary_ce':
            #         loss_reconstr = F.binary_cross_entropy(x_reconstr, x, reduction='sum') / x.size(0)
            #     else:
            #         loss_reconstr = F.mse_loss(x_reconstr, x, reduction='sum') / x.size(0)
            #     kl_loss = -0.5 * (logvar_z - logvar_z.exp()).sum(dim=[1])
            #     kl_loss = kl_loss.mean()
            #     loss = loss_unsup + loss_reconstr * reconstr_weight + kl_loss * reconstr_weight
            # else:
            logvar_z = torch.tensor([0.])
            loss_reconstr = torch.tensor([0.])
            kl_loss = torch.tensor([0.])
            loss = loss_unsup
            loss.backward()
            utils.clip_grad_norm(optimizer, max_grad_norm)
            optimizer.step()
            loss_unsup_meter.update(loss_unsup.item(), x.size(0))
            loss_reconstr_meter.update(loss_reconstr.item(), x.size(0))
            kl_loss_meter.update(kl_loss.item(), x.size(0))
            loss_meter.update(loss.item(), x.size(0))
            progress_bar.set_postfix(loss=loss_meter.avg,
                                     bpd=utils.bits_per_dim(x, loss_meter.avg))
            progress_bar.update(x.size(0))
            if iter_count % tb_freq == 0 or batch_count == len(
                    trainloader.dataset):
                tb_step = epoch * len(trainloader.dataset) + batch_count
                writer.add_scalar("train/loss", loss_meter.avg, tb_step)
                writer.add_scalar("train/loss_unsup", loss_unsup_meter.avg,
                                  tb_step)
                writer.add_scalar("train/loss_reconstr",
                                  loss_reconstr_meter.avg, tb_step)
                writer.add_scalar("train/kl_loss", kl_loss_meter.avg, tb_step)
                writer.add_scalar("train/bpd",
                                  utils.bits_per_dim(x, loss_unsup_meter.avg),
                                  tb_step)
                writer.add_histogram('train/logvar_z', logvar_z, tb_step)
                if sampling:
                    net.eval()
                    draw_samples(net, writer, loss_fn, num_samples, device,
                                 tuple(x[0].shape), tb_step)
                    net.train()
def train(epoch, net, trainloader, device, optimizer, loss_fn, label_weight,
          max_grad_norm, writer, use_unlab=True):
    print('\nEpoch: %d' % epoch)
    net.train()
    loss_meter = utils.AverageMeter()
    loss_unsup_meter = utils.AverageMeter()
    loss_nll_meter = utils.AverageMeter()
    jaclogdet_meter = utils.AverageMeter()
    acc_meter = utils.AverageMeter()
    with tqdm(total=trainloader.batch_sampler.num_labeled) as progress_bar:
        for x1, y in trainloader:
            x1 = x1.to(device)
            y = y.to(device)
            labeled_mask = (y != NO_LABEL)
            optimizer.zero_grad()
            z1 = net(x1)
            sldj = net.module.logdet()
            z_labeled = z1.reshape((len(z1), -1))
            z_labeled = z_labeled[labeled_mask]
            y_labeled = y[labeled_mask]
            logits_labeled = loss_fn.prior.class_logits(z_labeled)
            loss_nll = F.cross_entropy(logits_labeled, y_labeled)
            if use_unlab:
                loss_unsup = loss_fn(z1, sldj=sldj)
                loss = loss_nll * label_weight + loss_unsup
            else:
                loss_unsup = torch.tensor([0.])
                loss = loss_nll
            loss.backward()
            utils.clip_grad_norm(optimizer, max_grad_norm)
            optimizer.step()
            preds = torch.argmax(logits_labeled, dim=1)
            acc = (preds == y_labeled).float().mean().item()
            acc_meter.update(acc, x1.size(0))
            loss_meter.update(loss.item(), x1.size(0))
            loss_unsup_meter.update(loss_unsup.item(), x1.size(0))
            loss_nll_meter.update(loss_nll.item(), x1.size(0))
            jaclogdet_meter.update(sldj.mean().item(), x1.size(0))
            progress_bar.set_postfix(loss=loss_meter.avg,
                                     bpd=utils.bits_per_dim(
                                         x1, loss_unsup_meter.avg),
                                     acc=acc_meter.avg)
            progress_bar.update(y_labeled.size(0))
    writer.add_scalar("train/loss", loss_meter.avg, epoch)
    writer.add_scalar("train/loss_unsup", loss_unsup_meter.avg, epoch)
    writer.add_scalar("train/loss_nll", loss_nll_meter.avg, epoch)
    writer.add_scalar("train/jaclogdet", jaclogdet_meter.avg, epoch)
    writer.add_scalar("train/acc", acc_meter.avg, epoch)
    writer.add_scalar("train/bpd",
                      utils.bits_per_dim(x1, loss_unsup_meter.avg), epoch)
def update(self, obs, act, length_mask, target=None, mask_input=None,
           hidden=None):
    """
    All input params are numpy arrays.
    :param obs: [batch, seq_len, n, m, channel]
    :param act: [batch, seq_len]
    :param length_mask: [batch, seq_len]
    :param target: [batch] or None (when single-target)
    :param mask_input: (optional) [batch, seq_len, feat_dim]
    """
    tt = time.time()
    # convert data to Variables
    batch_size = obs.shape[0]
    seq_len = obs.shape[1]
    total_samples = float(np.sum(length_mask))
    obs = self._create_gpu_tensor(obs, return_variable=True)  # [batch, t_max, dims...]
    if hidden is None:
        hidden = self.policy.get_zero_state(
            batch=batch_size, return_variable=True,
            hidden_batch_first=self._is_multigpu)
    if target is not None:
        target = self._create_target_tensor(target, seq_len,
                                            return_variable=True)
    if mask_input is not None:
        mask_input = self._create_feature_tensor(mask_input,
                                                 return_variable=True)
    length_mask = self._create_feature_tensor(length_mask,
                                              return_variable=True)  # [batch, t_max]
    # create a one-hot action tensor
    # act = Variable(torch.from_numpy(act).type(LongTensor))  # [batch, t_max]
    act_n = torch.zeros(batch_size, seq_len,
                        self.policy.out_dim).type(FloatTensor)
    ids = torch.from_numpy(np.array(act)).type(LongTensor).view(
        batch_size, seq_len, 1)
    act_n.scatter_(2, ids, 1.0)
    act_n = Variable(act_n)
    time_counter[0] += time.time() - tt

    tt = time.time()
    if self.accu_grad_steps == 0:  # clear grad
        self.optim.zero_grad()
    # forward pass; logits: [batch, seq_len, n_act]
    logits, _ = self.net(obs, hidden, return_value=False, sample_action=False,
                         return_tensor=False, target=target,
                         extra_input_feature=mask_input, return_logits=True,
                         hidden_batch_first=self._is_multigpu)
    # compute the masked negative log-likelihood loss
    block_size = batch_size * seq_len
    act_size = logits.size(-1)
    flat_logits = logits.view(block_size, act_size)
    logp = torch.sum(F.log_softmax(flat_logits).view(
        batch_size, seq_len, act_size) * act_n, dim=-1) * length_mask
    loss = -torch.sum(logp) / total_samples
    # entropy penalty
    L_ent = torch.sum(self.policy.entropy(logits=logits)
                      * length_mask) / total_samples
    if self.args['entropy_penalty'] is not None:
        loss -= self.args['entropy_penalty'] * L_ent
    # L^2 penalty
    L_norm = torch.sum(torch.sum(logits * logits, dim=-1)
                       * length_mask) / total_samples
    if self.args['logits_penalty'] is not None:
        loss += self.args['logits_penalty'] * L_norm
    # compute accuracy
    _, max_idx = torch.max(logits.data, dim=-1, keepdim=True)
    L_accu = torch.sum(
        (max_idx == ids).type(FloatTensor)
        * length_mask.data.view(batch_size, seq_len, 1)) / total_samples
    ret_dict = dict(loss=loss.data.cpu().numpy()[0],
                    entropy=L_ent.data.cpu().numpy()[0],
                    logits_norm=L_norm.data.cpu().numpy()[0],
                    accuracy=L_accu)

    # backprop
    if self.grad_batch > 1:
        loss = loss / float(self.grad_batch)
    loss.backward()
    # accumulate stats across gradient-accumulation steps
    if self.accu_grad_steps == 0:
        self.accu_ret_dict = ret_dict
    else:
        for k in ret_dict:
            self.accu_ret_dict[k] += ret_dict[k]
    self.accu_grad_steps += 1
    if self.accu_grad_steps < self.grad_batch:  # do not update parameters yet
        time_counter[1] += time.time() - tt
        return None
    # update stats
    for k in self.accu_ret_dict:
        self.accu_ret_dict[k] /= self.grad_batch
    ret_dict = self.accu_ret_dict
    self.accu_grad_steps = 0
    # grad clip
    if self.grad_norm_clip is not None:
        utils.clip_grad_norm(self.policy.parameters(), self.grad_norm_clip)
    self.optim.step()
    time_counter[1] += time.time() - tt
    return ret_dict
def update(self, obs, label):
    """
    All input params are numpy arrays.
    :param obs: [batch, n, m, channel] or [batch, stack_frame, n, m, channel]
    :param label: [batch, n_class] (sigmoid) or [batch] (softmax)
    """
    tt = time.time()
    # convert data to Variables
    batch_size = obs.shape[0]
    obs = self._create_gpu_tensor(obs, return_variable=True)  # [batch, channel, n, m]
    # create label tensor
    if self.multi_label:
        t_label = torch.from_numpy(np.array(label)).type(FloatTensor)
    else:
        t_label = torch.from_numpy(np.array(label)).type(LongTensor)
    label = Variable(t_label)
    time_counter[0] += time.time() - tt

    tt = time.time()
    if self.accu_grad_steps == 0:  # clear grad
        self.optim.zero_grad()
    # forward pass; logits: [batch, n_class]
    logits = self.policy(obs, return_logits=True)
    # compute loss
    if self.multi_label:
        loss = torch.mean(F.binary_cross_entropy_with_logits(logits, label))
    else:
        loss = torch.mean(F.cross_entropy(logits, label))
    # entropy penalty
    L_ent = torch.mean(self.policy.entropy(logits=logits))
    if self.args['entropy_penalty'] is not None:
        loss -= self.args['entropy_penalty'] * L_ent
    # L^2 penalty
    L_norm = torch.mean(torch.sum(logits * logits, dim=-1))
    if self.args['logits_penalty'] is not None:
        loss += self.args['logits_penalty'] * L_norm
    # compute accuracy
    if self.multi_label:
        max_idx = (logits.data > 0.5).type(FloatTensor)
        total_sample = batch_size * self.out_dim
    else:
        _, max_idx = torch.max(logits.data, dim=-1, keepdim=False)
        total_sample = batch_size
    # normalize by the number of predictions (per label in the
    # multi-label case), which is what total_sample is computed for
    L_accu = torch.sum((max_idx == t_label).type(FloatTensor)) / total_sample
    ret_dict = dict(loss=loss.data.cpu().numpy()[0],
                    entropy=L_ent.data.cpu().numpy()[0],
                    logits_norm=L_norm.data.cpu().numpy()[0],
                    accuracy=L_accu)

    # backprop
    if self.grad_batch > 1:
        loss = loss / float(self.grad_batch)
    loss.backward()
    # accumulate stats across gradient-accumulation steps
    if self.accu_grad_steps == 0:
        self.accu_ret_dict = ret_dict
    else:
        for k in ret_dict:
            self.accu_ret_dict[k] += ret_dict[k]
    self.accu_grad_steps += 1
    if self.accu_grad_steps < self.grad_batch:  # do not update parameters yet
        time_counter[1] += time.time() - tt
        return None
    # update stats
    for k in self.accu_ret_dict:
        self.accu_ret_dict[k] /= self.grad_batch
    ret_dict = self.accu_ret_dict
    self.accu_grad_steps = 0
    # grad clip
    if self.grad_norm_clip is not None:
        utils.clip_grad_norm(self.policy.parameters(), self.grad_norm_clip)
    self.optim.step()
    time_counter[1] += time.time() - tt
    return ret_dict
def update(self):
    if (self.sample_counter < self.args['update_freq']) or \
            not self.replay_buffer.can_sample(
                self.batch_size * self.args['episode_len']):
        return None
    self.sample_counter = 0
    self.train()
    tt = time.time()
    obs, full_act, rew, obs_next, done = \
        self.replay_buffer.sample(self.batch_size)
    # act = split_batched_array(full_act, self.act_shape)
    time_counter[-1] += time.time() - tt

    tt = time.time()
    # convert to variables
    obs_n = self._process_frames(obs)
    obs_next_n = self._process_frames(obs_next, volatile=True)
    full_act_n = Variable(torch.from_numpy(full_act)).type(FloatTensor)
    rew_n = Variable(torch.from_numpy(rew), volatile=True).type(FloatTensor)
    done_n = Variable(torch.from_numpy(done), volatile=True).type(FloatTensor)
    time_counter[0] += time.time() - tt

    tt = time.time()
    self.optim.zero_grad()
    # train p network
    q_val = self.net(obs_n, action=None, output_critic=True)
    p_loss = -q_val.mean().squeeze()
    p_ent = self.net.entropy().mean().squeeze()
    if self.args['ent_penalty'] is not None:
        p_loss -= self.args['ent_penalty'] * p_ent  # encourage exploration
    common.debugger.print('>> P_Loss = {}'.format(p_loss.data.mean()), False)
    p_loss.backward()
    # we do not need to compute q-gradients for the actor!!!
    self.net.clear_critic_specific_grad()
    if self.grad_norm_clip is not None:
        utils.clip_grad_norm(self.net.parameters(), self.grad_norm_clip)
    self.optim.step()

    # train q network
    self.optim.zero_grad()
    common.debugger.print('Grad Stats of Q Update ...', False)
    target_q_next = self.target_net(obs_next_n, output_critic=True)
    target_q = rew_n + self.gamma * (1.0 - done_n) * target_q_next
    target_q.volatile = False
    current_q = self.net(obs_n, action=full_act_n, output_critic=True)
    q_norm = (current_q * current_q).mean().squeeze()  # L2 norm
    q_loss = F.smooth_l1_loss(current_q, target_q) \
        + self.args['critic_penalty'] * q_norm  # Huber loss + critic penalty
    common.debugger.print('>> Q_Loss = {}'.format(q_loss.data.mean()), False)
    # q_loss = q_loss * 50
    q_loss.backward()
    # total_loss = q_loss + p_loss
    # grad clip
    if self.grad_norm_clip is not None:
        utils.clip_grad_norm(self.net.parameters(), self.grad_norm_clip)
    self.optim.step()
    common.debugger.print('Stats of P Network (after clip and opt)....', False)
    utils.log_parameter_stats(common.debugger, self.net)
    time_counter[1] += time.time() - tt

    tt = time.time()
    # update target networks
    make_update_exp(self.net, self.target_net, rate=self.target_update_rate)
    common.debugger.print('Stats of Target Network (After Update)....', False)
    utils.log_parameter_stats(common.debugger, self.target_net)
    time_counter[2] += time.time() - tt

    return dict(policy_loss=p_loss.data.cpu().numpy()[0],
                policy_entropy=p_ent.data.cpu().numpy()[0],
                critic_norm=q_norm.data.cpu().numpy()[0],
                critic_loss=q_loss.data.cpu().numpy()[0])
def update(self):
    if (self.sample_counter < self.args['update_freq']) or \
            not self.replay_buffer.can_sample(
                self.batch_size * min(self.args['update_freq'], 10)):
        return None
    self.sample_counter = 0
    self.train()
    tt = time.time()
    obs, act, rew, obs_next, done = \
        self.replay_buffer.sample(self.batch_size)
    # act = split_batched_array(full_act, self.act_shape)
    time_counter[-1] += time.time() - tt

    tt = time.time()
    # convert to variables
    obs_n = self._process_frames(obs)
    obs_next_n = self._process_frames(obs_next, volatile=True)
    act_n = torch.from_numpy(act).type(LongTensor)
    rew_n = Variable(torch.from_numpy(rew), volatile=True).type(FloatTensor)
    done_n = Variable(torch.from_numpy(done), volatile=True).type(FloatTensor)
    time_counter[0] += time.time() - tt

    tt = time.time()
    # compute critic loss
    target_q_next = self.target_net(obs_next_n, only_value=True)
    target_q = rew_n + self.gamma * (1.0 - done_n) * target_q_next
    target_q.volatile = False
    current_act, current_q = self.net(obs_n, return_value=True)
    q_norm = (current_q * current_q).mean().squeeze()
    q_loss = F.smooth_l1_loss(current_q, target_q)
    common.debugger.print('>> Q_Loss = {}'.format(q_loss.data.mean()), False)
    common.debugger.print('>> Q_Norm = {}'.format(q_norm.data.mean()), False)
    total_loss = q_loss.mean()
    if self.args['critic_penalty'] > 1e-10:
        total_loss += self.args['critic_penalty'] * q_norm

    # compute policy loss
    # NOTE: currently 1-step lookahead!!! TODO: multi-step lookahead
    raw_adv_ts = (rew_n - current_q).data
    # raw_adv_ts = (target_q - current_q).data  # use estimated advantage??
    adv_ts = (raw_adv_ts - raw_adv_ts.mean()) / (raw_adv_ts.std() + 1e-15)
    # current_act.reinforce(adv_ts)
    p_ent = self.net.entropy().mean()
    p_loss = self.net.logprob(act_n)
    p_loss = p_loss * Variable(adv_ts)
    p_loss = p_loss.mean()
    total_loss -= p_loss
    if self.args['ent_penalty'] is not None:
        total_loss -= self.args['ent_penalty'] * p_ent  # encourage exploration
    common.debugger.print('>> P_Loss = {}'.format(p_loss.data.mean()), False)
    common.debugger.print('>> P_Entropy = {}'.format(p_ent.data.mean()), False)

    # compute gradient
    self.optim.zero_grad()
    # autograd.backward([total_loss, current_act], [torch.ones(1), None])
    total_loss.backward()
    if self.grad_norm_clip is not None:
        # nn.utils.clip_grad_norm(self.q.parameters(), self.grad_norm_clip)
        utils.clip_grad_norm(self.net.parameters(), self.grad_norm_clip)
    self.optim.step()
    common.debugger.print('Stats of Model (*after* clip and opt)....', False)
    utils.log_parameter_stats(common.debugger, self.net)
    time_counter[1] += time.time() - tt

    tt = time.time()
    # update target networks
    make_update_exp(self.net, self.target_net, rate=self.target_update_rate)
    common.debugger.print('Stats of Target Network (After Update)....', False)
    utils.log_parameter_stats(common.debugger, self.target_net)
    time_counter[2] += time.time() - tt

    return dict(policy_loss=p_loss.data.cpu().numpy()[0],
                policy_entropy=p_ent.data.cpu().numpy()[0],
                critic_norm=q_norm.data.cpu().numpy()[0],
                critic_loss=q_loss.data.cpu().numpy()[0])