def test_huber_loss_delta_3(self):
    """Check ``modules.HuberLoss(3)`` on one large and one small error.

    Each case feeds a single-element batch through the criterion and
    compares the result against a hand-computed value.
    """
    criterion = modules.HuberLoss(3)
    # (x, x_hat, expected loss) triples, evaluated in order.
    cases = (
        # |x - x_hat| = 5: expected value uses the form 3 * (5 - 3/2).
        ([[0]], [[5]], [3 * (5 - 3/2)]),
        # |x - x_hat| = 2: expected value uses the form 0.5 * 2 * 2.
        ([[4]], [[6]], [0.5 * 2 * 2]),
    )
    for x_rows, x_hat_rows, expected_rows in cases:
        x = np.array(x_rows)
        x_hat = np.array(x_hat_rows)
        expected_loss = np.array(expected_rows)
        x_var = ptu.Variable(ptu.from_numpy(x).float())
        x_hat_var = ptu.Variable(ptu.from_numpy(x_hat).float())
        result = ptu.get_numpy(criterion(x_var, x_hat_var))
        self.assertNpAlmostEqual(expected_loss, result)
def __init__(self, matrix_input_size, vector_size):
    """Build the linear layer and the constant masks used by this module.

    :param matrix_input_size: Number of input features of the linear layer.
    :param vector_size: Side length of the square matrix; the linear layer
        emits ``vector_size ** 2`` values.
    """
    super().__init__()
    self.vector_size = vector_size
    self.L = nn.Linear(matrix_input_size, vector_size**2)
    # Shrink the default initialisation by a factor of ten.
    self.L.weight.data.mul_(0.1)
    self.L.bias.data.mul_(0.1)
    ones = torch.ones(vector_size, vector_size)
    # The diagonal offset is passed positionally: the keyword is named
    # `diagonal` in current PyTorch (it was `k` in old releases), so the
    # positional form is the only spelling accepted by both.
    # Strictly-lower-triangular mask with a leading singleton dimension.
    self.tril_mask = ptu.Variable(torch.tril(ones, -1).unsqueeze(0))
    # Diagonal-only mask with a leading singleton dimension.
    self.diag_mask = ptu.Variable(
        torch.diag(torch.diag(ones)).unsqueeze(0))
def reparameterize(self, mu, logvar):
    """Return a latent sample in training mode, or ``mu`` otherwise.

    :param mu: Mean of the latent distribution.
    :param logvar: Log-variance of the latent distribution.
    :return: ``mu + eps * exp(0.5 * logvar)`` with standard-normal ``eps``
        when ``self.training`` is truthy; ``mu`` unchanged otherwise.
    """
    if not self.training:
        return mu
    std = logvar.mul(0.5).exp_()
    # Noise is allocated with the same tensor type and size as `std`.
    noise = ptu.Variable(std.data.new(std.size()).normal_())
    return noise.mul(std).add_(mu)
def __init__(
        self,
        env,
        qf,
        replay_buffer,
        num_epochs=100,
        num_batches_per_epoch=100,
        qf_learning_rate=1e-3,
        batch_size=100,
        num_unique_batches=1000,
):
    """Store the training configuration and create the Q-function optimizer.

    :param env: Environment handle (stored as-is).
    :param qf: Q-function whose ``parameters()`` are optimised with Adam.
    :param replay_buffer: Replay buffer (stored as-is).
    :param num_epochs: Stored as ``self.num_epochs``.
    :param num_batches_per_epoch: Stored as ``self.num_batches_per_epoch``.
    :param qf_learning_rate: Learning rate handed to Adam.
    :param batch_size: Stored, and used as the row count of ``self.discount``.
    :param num_unique_batches: Stored as ``self.num_unique_batches``.
    """
    # Collaborators.
    self.env = env
    self.qf = qf
    self.replay_buffer = replay_buffer

    # Hyper-parameters.
    self.num_epochs = num_epochs
    self.num_batches_per_epoch = num_batches_per_epoch
    self.qf_learning_rate = qf_learning_rate
    self.batch_size = batch_size
    self.num_unique_batches = num_unique_batches

    self.qf_optimizer = optim.Adam(
        self.qf.parameters(),
        lr=self.qf_learning_rate,
    )

    # Batch-iterator state starts empty.
    self.batch_iterator = None
    self.mode_to_batch_iterator = {}

    # All-zero float discount of shape (batch_size, 1).
    zero_discount = np.zeros((batch_size, 1))
    self.discount = ptu.Variable(ptu.from_numpy(zero_discount).float())
def debug_statistics(self):
    r"""Compare the reconstruction of one image against decoded prior samples.

    Given an image $$x$$, samples a bunch of latents from the prior
    $$z_i$$ and decode them $$\hat x_i$$.
    Compare this to $$\hat x$$, the reconstruction of $$x$$.
    Ideally
     - All the $$\hat x_i$$s do worse than $$\hat x$$ (makes sure VAE
       isn't ignoring the latent)
     - Some $$\hat x_i$$ do better than other $$\hat x_i$$ (tests for
       coverage)

    :return: Ordered dict of statistics over
        ``random_mse - recon_mse`` plus one scalar MSE entry.
    """
    debug_batch_size = 64
    data = self.get_batch(train=False)
    recon_batch, mu, logvar, predicted_alpha, alpha = self.model(data)
    # Only the first element of the batch is analysed.
    img = data[0]
    recon_mse = ((recon_batch[0] - img)**2).mean()
    img_repeated = img.expand((debug_batch_size, img.shape[0]))

    # Decode latents drawn from a standard normal.
    samples = ptu.Variable(
        torch.randn(debug_batch_size, self.representation_size))
    random_imgs = self.model.decode(samples)
    random_mse = ((random_imgs - img_repeated)**2).mean(dim=1)

    mse_improvement = ptu.get_numpy(random_mse - recon_mse)
    stats = create_stats_ordered_dict(
        'debug/MSE improvement over random',
        mse_improvement,
    )
    # `recon_mse` may be 0-dimensional (newer PyTorch) or hold a single
    # element (older PyTorch); flattening before indexing handles both.
    # NOTE(review): the key mentions the *random* reconstruction but the
    # stored value is the MSE of the actual reconstruction -- confirm
    # which one is intended before renaming the key.
    stats['debug/MSE of random reconstruction'] = ptu.get_numpy(
        recon_mse).reshape(-1)[0]
    return stats
def _get_training_batch(self):
    """Sample a context batch, an optional context mask and a policy batch.

    In few-shot mode ``max_context_size`` trajectories are sampled per task
    and a 0/1 mask of shape (num_tasks, max_context_size, 1) marks a random
    prefix of each task's context as active. Otherwise
    ``num_context_trajs_for_training`` trajectories are sampled and the
    mask is ``None``.

    :return: ``(context_batch, mask, policy_batch)`` where ``policy_batch``
        is a dict with keys ``observations``, ``actions``, ``terminals``
        and ``next_observations``, each concatenated over tasks on axis 0.
    """
    few_shot = self.few_shot_version
    if few_shot:
        trajs_per_task = self.max_context_size
    else:
        trajs_per_task = self.num_context_trajs_for_training

    context_batch, task_identifiers_list = (
        self.train_context_expert_replay_buffer.sample_trajs(
            trajs_per_task,
            num_tasks=self.num_tasks_used_per_update,
            keys=['observations', 'actions', 'next_observations'],
        )
    )

    mask = None
    if few_shot:
        mask = ptu.Variable(
            torch.zeros(
                self.num_tasks_used_per_update, self.max_context_size, 1))
        # One context size per task, drawn from [min, max] inclusive.
        this_context_sizes = np.random.randint(
            self.min_context_size,
            self.max_context_size + 1,
            size=self.num_tasks_used_per_update,
        )
        for task_idx, context_size in enumerate(this_context_sizes):
            mask[task_idx, :context_size, :] = 1.0

    # Transitions for the same tasks, taken from the test expert buffer.
    per_task_batches, _ = (
        self.train_test_expert_replay_buffer.sample_random_batch(
            self.policy_optim_batch_size_per_task,
            task_identifiers_list=task_identifiers_list,
        )
    )
    # Each entry ends up as (N_tasks * batch_size) x Dim.
    policy_keys = (
        'observations', 'actions', 'terminals', 'next_observations')
    policy_batch = {
        key: np.concatenate([d[key] for d in per_task_batches], axis=0)
        for key in policy_keys
    }
    return context_batch, mask, policy_batch
def forward(
        self,
        obs,
        deterministic=False,
        return_log_prob=False,
        return_entropy=False,
        return_log_prob_of_mean=False,
):
    """Append an indicator encoding of tau to ``obs`` and run the parent policy.

    ``obs`` is split into observation and tau by ``split_tau``; tau is
    clamped to be non-negative and used as a column index into a zero
    tensor of width ``self.max_tau + 1``.

    :param obs: Flat input containing both the observation and tau.
    :return: Whatever the parent class's ``forward`` returns for the
        augmented observation.
    """
    obs, taus = split_tau(obs)
    num_rows = obs.size()[0]

    indicator = ptu.FloatTensor(num_rows, self.max_tau + 1)
    indicator.zero_()
    # Negative taus are mapped onto index 0.
    tau_index = torch.clamp(taus.data.long(), min=0)
    indicator.scatter_(1, tau_index, 1)

    augmented_obs = torch.cat((obs, ptu.Variable(indicator)), dim=1)
    return super().forward(
        obs=augmented_obs,
        deterministic=deterministic,
        return_log_prob=return_log_prob,
        return_entropy=return_entropy,
        return_log_prob_of_mean=return_log_prob_of_mean,
    )
def dump_samples(self, epoch):
    """Decode 64 random latents and save them as ``s<epoch>.png``.

    The model is switched to eval mode; the image is written into the
    logger's snapshot directory.

    :param epoch: Integer used in the output file name.
    """
    self.model.eval()
    latents = ptu.Variable(torch.randn(64, self.representation_size))
    decoded = self.model.decode(latents).cpu()
    out_path = osp.join(logger.get_snapshot_dir(), 's%d.png' % epoch)
    images = decoded.data.view(
        64, self.input_channels, self.imsize, self.imsize)
    save_image(images, out_path)
def test_batch_square_diagonal_module(self):
    """Check ``modules.BatchSquareDiagonal(2)`` on a single hand-worked row."""
    vector = np.array([[2, 7]])
    diagonal = np.array([[2, 1]])
    # 2^2 * 2 + 7^2 * 1 = 8 + 49 = 57
    expected = np.array([[57]])

    vector_var = ptu.Variable(ptu.from_numpy(vector).float())
    diagonal_var = ptu.Variable(ptu.from_numpy(diagonal).float())
    net = modules.BatchSquareDiagonal(2)
    output = net(vector=vector_var, diag_values=diagonal_var)

    self.assertNpAlmostEqual(expected, ptu.get_numpy(output))
def get_train_dict(self, subtraj_batch):
    """Compute policy and critic losses (plus diagnostics) for a sub-trajectory batch.

    Discounted cumulative rewards are stored into ``subtraj_batch`` under
    ``'returns'`` (the input dict is mutated), the batch is flattened, and
    the losses are computed on the flattened tensors.

    :param subtraj_batch: Dict of tensors; ``'rewards'`` must have a
        squeezable third axis.
    :return: ``OrderedDict`` of named tensors including the two losses.
    """
    rewards_np = ptu.get_numpy(subtraj_batch['rewards']).squeeze(2)
    discounted = np_util.batch_discounted_cumsum(rewards_np, self.discount)
    discounted = np.expand_dims(discounted, 2)
    discounted = np.ascontiguousarray(discounted).astype(np.float32)
    subtraj_batch['returns'] = ptu.Variable(ptu.from_numpy(discounted))

    batch = flatten_subtraj_batch(subtraj_batch)
    returns = batch['returns']
    terminals = batch['terminals']
    obs = batch['observations']
    actions = batch['actions']
    next_obs = batch['next_observations']

    # Policy operations.
    policy_actions = self.policy(obs)
    q = self.qf(obs, policy_actions)
    policy_loss = -q.mean()

    # Critic operations.
    next_actions = self.policy(next_obs)
    # TODO: try to get this to work
    # next_actions = None
    q_target = self.target_qf(next_obs, next_actions)

    # Tile the per-step discount column once per sub-trajectory.
    num_subtrajs = q_target.size()[0] // self.subtraj_length
    discount_factors = self.discount_factors.repeat(num_subtrajs, 1)
    bootstrap = (1. - terminals) * discount_factors * q_target
    y_target = self.reward_scale * returns + bootstrap
    # noinspection PyUnresolvedReferences
    y_target = y_target.detach()
    y_pred = self.qf(obs, actions)
    bellman_errors = (y_pred - y_target)**2
    qf_loss = self.qf_criterion(y_pred, y_target)

    train_dict = OrderedDict()
    train_dict['Policy Actions'] = policy_actions
    train_dict['Policy Loss'] = policy_loss
    train_dict['Policy Q Values'] = q
    train_dict['Target Y'] = y_target
    train_dict['Predicted Y'] = y_pred
    train_dict['Bellman Errors'] = bellman_errors
    train_dict['Y targets'] = y_target
    train_dict['Y predictions'] = y_pred
    train_dict['QF Loss'] = qf_loss
    return train_dict
def __init__(
        self,
        obs_dim,
        action_dim,
        hidden_size,
        use_batchnorm=False,
        b_init_value=0.01,
        hidden_init=ptu.fanin_init,
        use_exp_for_diagonal_not_square=True,
):
    """Create the layers and constant masks of the NAF policy.

    :param obs_dim: Observation dimensionality.
    :param action_dim: Action dimensionality; the ``L`` head emits
        ``action_dim ** 2`` values.
    :param hidden_size: Width of both hidden layers.
    :param use_batchnorm: If true, a ``BatchNorm1d`` over the observation
        is created with weight 1 and bias 0.
    :param b_init_value: Constant every layer bias is filled with.
    :param hidden_init: Callable applied to every layer's weight tensor.
    :param use_exp_for_diagonal_not_square: Stored as-is.
    """
    super(NafPolicy, self).__init__()

    self.obs_dim = obs_dim
    self.action_dim = action_dim
    self.use_batchnorm = use_batchnorm
    self.use_exp_for_diagonal_not_square = use_exp_for_diagonal_not_square

    if use_batchnorm:
        self.bn_state = nn.BatchNorm1d(obs_dim)
        self.bn_state.weight.data.fill_(1)
        self.bn_state.bias.data.fill_(0)

    self.linear1 = nn.Linear(obs_dim, hidden_size)
    self.linear2 = nn.Linear(hidden_size, hidden_size)
    self.V = nn.Linear(hidden_size, 1)
    self.mu = nn.Linear(hidden_size, action_dim)
    self.L = nn.Linear(hidden_size, action_dim**2)

    # Strictly-lower-triangular and diagonal masks, each with a leading
    # singleton dimension.
    strict_lower = torch.tril(torch.ones(action_dim, action_dim), -1)
    diagonal = torch.diag(torch.diag(torch.ones(action_dim, action_dim)))
    self.tril_mask = ptu.Variable(strict_lower.unsqueeze(0))
    self.diag_mask = ptu.Variable(diagonal.unsqueeze(0))

    # Same initialisation for every layer, applied in a fixed order.
    for layer in (self.linear1, self.linear2, self.V, self.L, self.mu):
        hidden_init(layer.weight)
        layer.bias.data.fill_(b_init_value)
def get_encoding_and_suff_stats(self, x):
    """Encode ``x`` and sample a latent from the predicted Gaussian.

    Column 0 of the network output is read as the mean and column 1 as
    the log standard deviation.

    :param x: Input passed to ``self(x)``.
    :return: ``(latents, means, log_stds, stds)``.
    """
    output = self(x)
    means = output[:, 0:1]
    log_stds = output[:, 1:2]
    stds = log_stds.exp()
    # Standard-normal noise with the same size as the means.
    noise = ptu.Variable(torch.randn(*means.size()))
    latents = noise * stds + means
    return latents, means, log_stds, stds
def forward(self, flat_obs, actions=None):
    """Evaluate the network on the observation, a broadcast tau vector and optional actions.

    ``flat_obs`` is split into observation and tau by ``split_tau``; tau
    is broadcast into a tensor of width ``self.tau_vector_len`` and
    inserted between the observation and the actions.

    :param flat_obs: Flat input containing both the observation and tau.
    :param actions: Optional actions concatenated after the tau vector.
    :return: ``-|last_fc(h)|``, i.e. a non-positive output.
    """
    obs, taus = split_tau(flat_obs)
    # The batch size is read directly from `obs`. The previous code
    # concatenated `obs` with the undefined name `action` just to read the
    # size, which raised NameError whenever actions were supplied.
    batch_size = obs.size()[0]
    tau_vector = torch.zeros((batch_size, self.tau_vector_len)) + taus.data

    pieces = [obs, ptu.Variable(tau_vector)]
    if actions is not None:
        pieces.append(actions)
    h = torch.cat(tuple(pieces), dim=1)

    for fc in self.fcs:
        h = self.hidden_activation(fc(h))
    return -torch.abs(self.last_fc(h))
def train_epoch(self, epoch):
    """Run ``self.num_batches`` optimisation steps and log training losses.

    Logs the epoch, the mean criterion loss and the mean squared error of
    every output dimension through ``logger.record_tabular``.

    :param epoch: Epoch index, recorded under ``train/epoch``.
    """
    self.model.train()
    batch_losses = []
    dim_sq_errors = np.zeros((self.num_batches, self.y_train.shape[1]))

    for batch_idx in range(self.num_batches):
        inputs_np, labels_np = self.random_batch(
            self.X_train, self.y_train, batch_size=self.batch_size)
        inputs = ptu.Variable(ptu.from_numpy(inputs_np))
        labels = ptu.Variable(ptu.from_numpy(labels_np))

        self.optimizer.zero_grad()
        outputs = self.model(inputs)
        loss = self.criterion(outputs, labels)
        loss.backward()
        self.optimizer.step()

        # NOTE(review): `loss.data[0]` assumes a one-element (not 0-dim)
        # loss tensor -- confirm against the PyTorch version in use.
        batch_losses.append(loss.data[0])
        residuals = ptu.get_numpy(outputs-labels)
        dim_sq_errors[batch_idx] = np.mean(np.power(residuals, 2), axis=0)

    logger.record_tabular("train/epoch", epoch)
    logger.record_tabular("train/loss", np.mean(np.array(batch_losses)))
    for dim in range(self.y_train.shape[1]):
        logger.record_tabular(
            "train/dim "+str(dim)+" loss", np.mean(dim_sq_errors[:, dim]))
def forward(self, flat_obs, actions=None):
    """Evaluate the network on the observation, a binary tau encoding and optional actions.

    ``flat_obs`` is split into observation and tau by ``split_tau``; tau
    is encoded by ``make_binary_tensor`` and inserted between the
    observation and the actions.

    :param flat_obs: Flat input containing both the observation and tau.
    :param actions: Optional actions concatenated after the tau encoding.
    :return: ``-|last_fc(h)|``, i.e. a non-positive output.
    """
    obs, taus = split_tau(flat_obs)
    # The earlier `torch.cat((obs, actions))` whose result was always
    # overwritten has been dropped: it only cost an allocation.
    batch_size = taus.size()[0]
    # NOTE(review): `len(self.max_tau)` requires `max_tau` to be a sized
    # object here -- confirm against the class constructor.
    y_binary = make_binary_tensor(taus, len(self.max_tau), batch_size)

    pieces = [obs, ptu.Variable(y_binary)]
    if actions is not None:
        pieces.append(actions)
    h = torch.cat(tuple(pieces), dim=1)

    for fc in self.fcs:
        h = self.hidden_activation(fc(h))
    return -torch.abs(self.last_fc(h))
def test_epoch(
        self,
        epoch,
):
    """Evaluate on ``self.num_batches`` random test batches and dump the log.

    Logs the epoch, the mean criterion loss and the mean squared error of
    every output dimension, then calls ``logger.dump_tabular()``.

    :param epoch: Epoch index, recorded under ``test/epoch``.
    """
    self.model.eval()
    batch_losses = []
    dim_sq_errors = np.zeros((self.num_batches, self.y_train.shape[1]))

    for batch_idx in range(self.num_batches):
        inputs_np, labels_np = self.random_batch(
            self.X_test, self.y_test, batch_size=self.batch_size)
        inputs = ptu.Variable(ptu.from_numpy(inputs_np))
        labels = ptu.Variable(ptu.from_numpy(labels_np))

        outputs = self.model(inputs)
        loss = self.criterion(outputs, labels)

        # NOTE(review): `loss.data[0]` assumes a one-element (not 0-dim)
        # loss tensor -- confirm against the PyTorch version in use.
        batch_losses.append(loss.data[0])
        residuals = ptu.get_numpy(outputs - labels)
        dim_sq_errors[batch_idx] = np.mean(np.power(residuals, 2), axis=0)

    logger.record_tabular("test/epoch", epoch)
    logger.record_tabular("test/loss", np.mean(np.array(batch_losses)))
    for dim in range(self.y_train.shape[1]):
        logger.record_tabular(
            "test/dim "+str(dim)+" loss", np.mean(dim_sq_errors[:, dim]))
    logger.dump_tabular()
def _get_training_batch(self):
    """Sample a context batch, an optional mask, and colour-classification pairs.

    For every sampled task, ``classification_batch_size_per_task`` pairs of
    colours are drawn from the environment: one within
    ``env.same_color_radius`` of the task parameters and one sampled with
    that radius as minimum distance. Each pair is concatenated in random
    order; the label is 0 when the within-radius colour comes first and 1
    otherwise.

    :return: ``(context_batch, mask, obs_task_params,
        classification_inputs, classification_labels)``; ``mask`` is
        ``None`` outside few-shot mode.
    """
    keys_to_get = ['observations', 'actions', 'next_observations']
    few_shot = self.few_shot_version
    if few_shot:
        trajs_per_task = self.max_context_size
    else:
        trajs_per_task = self.num_context_trajs_for_training

    context_batch, task_identifiers_list = (
        self.train_context_expert_replay_buffer.sample_trajs(
            trajs_per_task,
            num_tasks=self.num_tasks_used_per_update,
            keys=keys_to_get,
        )
    )

    mask = None
    if few_shot:
        mask = ptu.Variable(
            torch.zeros(
                self.num_tasks_used_per_update, self.max_context_size, 1))
        # One context size per task, drawn from [min, max] inclusive.
        this_context_sizes = np.random.randint(
            self.min_context_size,
            self.max_context_size + 1,
            size=self.num_tasks_used_per_update,
        )
        for task_idx, context_size in enumerate(this_context_sizes):
            mask[task_idx, :context_size, :] = 1.0

    obs_task_params = np.array([
        self.env.task_id_to_obs_task_params(tid)
        for tid in task_identifiers_list
    ])

    # Now sample the points used for classification.
    pair_inputs = []
    pair_labels = []
    for task in obs_task_params:
        for _ in range(self.classification_batch_size_per_task):
            good = self.env._sample_color_within_radius(
                task, self.env.same_color_radius)
            bad = self.env._sample_color_with_min_dist(
                task, self.env.same_color_radius)
            good_first = np.random.uniform() > 0.5
            ordered, label = ((good, bad), 0) if good_first else ((bad, good), 1)
            pair_inputs.append(np.concatenate(ordered))
            pair_labels.append([label])

    classification_inputs = Variable(
        ptu.from_numpy(np.array(pair_inputs)))
    classification_labels = Variable(
        ptu.from_numpy(np.array(pair_labels)))

    return (
        context_batch,
        mask,
        obs_task_params,
        classification_inputs,
        classification_labels,
    )
def train_network(net, title):
    """Train ``net`` with Adam on MSE for ``N_EPOCHS`` epochs, plotting as it goes.

    After every epoch the train and test losses are evaluated on the
    module-level ``train_x``/``train_y`` and ``test_x``/``test_y`` and the
    loss curves are redrawn. The final losses are printed at the end.

    :param net: Model to optimise.
    :param title: Plot title, also printed with the final losses.
    """
    optimizer = Adam(net.parameters(), lr=1e-3)
    criterion = nn.MSELoss()
    epochs_seen = []
    train_losses = []
    test_losses = []

    for epoch in range(N_EPOCHS):
        for _, (x, y) in enumerate(dataloader):
            prediction = net(ptu.Variable(x))
            loss = criterion(prediction, ptu.Variable(y))
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # The test loss is evaluated before the train loss.
        test_loss = float(criterion(net(test_x), test_y))
        test_losses.append(test_loss)
        train_loss = float(criterion(net(train_x), train_y))
        train_losses.append(train_loss)
        epochs_seen.append(epoch)

        # Redraw both curves: dashed for train, solid for test.
        plt.gcf().clear()
        plt.plot(epochs_seen, train_losses, '--')
        plt.plot(epochs_seen, test_losses, '-')
        plt.title(title)
        plt.draw()
        plt.pause(0.05)

    print(title)
    print("\tfinal train loss: {}".format(train_loss))
    print("\tfinal test loss: {}".format(test_loss))
def forward(self, flat_obs, actions=None):
    """Evaluate the network on the observation, an indicator encoding of tau and optional actions.

    ``flat_obs`` is split into observation and tau by ``split_tau``; tau
    is clamped to be non-negative and used as a column index into a zero
    tensor of width ``self.max_tau + 1``.

    :param flat_obs: Flat input containing both the observation and tau.
    :param actions: Optional actions concatenated after the tau encoding.
    :return: ``-|last_fc(h)|``, i.e. a non-positive output.
    """
    obs, taus = split_tau(flat_obs)
    # The batch size is read from `obs` directly; concatenating `obs` and
    # `actions` only to read the row count was a wasted allocation.
    batch_size = obs.size()[0]

    y_binary = ptu.FloatTensor(batch_size, self.max_tau + 1)
    y_binary.zero_()
    # Negative taus are mapped onto index 0.
    t = torch.clamp(taus.data.long(), min=0)
    y_binary.scatter_(1, t, 1)

    pieces = [obs, ptu.Variable(y_binary)]
    if actions is not None:
        pieces.append(actions)
    h = torch.cat(tuple(pieces), dim=1)

    for fc in self.fcs:
        h = self.hidden_activation(fc(h))
    return -torch.abs(self.last_fc(h))
def forward(
        self,
        flat_obs,
        return_preactivations=False,
):
    """Append a binary encoding of tau to the observation and run the parent network.

    :param flat_obs: Flat input containing both the observation and tau.
    :param return_preactivations: Forwarded unchanged to the parent.
    :return: Whatever the parent class's ``forward`` returns.
    """
    obs, taus = split_tau(flat_obs)
    # NOTE(review): `len(self.max_tau)` requires `max_tau` to be a sized
    # object here -- confirm against the class constructor.
    tau_encoding = make_binary_tensor(
        taus, len(self.max_tau), taus.size()[0])
    augmented = torch.cat((obs, ptu.Variable(tau_encoding)), dim=1)
    return super().forward(
        augmented, return_preactivations=return_preactivations)
def get_batch(self, training=True):
    """Sample a random batch and convert it to float torch variables.

    At most ``self.batch_size`` transitions are drawn (fewer if the buffer
    cannot provide that many). ``rewards`` and ``terminals`` get an extra
    trailing dimension.

    :param training: Passed to ``replay_buffer.get_replay_buffer`` to pick
        the underlying buffer.
    :return: Dict mapping each batch key to a variable that does not
        require gradients.
    """
    buffer = self.replay_buffer.get_replay_buffer(training)
    num_samples = min(buffer.num_steps_can_sample(), self.batch_size)
    np_batch = buffer.random_batch(num_samples)

    torch_batch = {}
    for key, array in np_batch.items():
        torch_batch[key] = ptu.Variable(
            ptu.from_numpy(array).float(), requires_grad=False)

    rewards = torch_batch['rewards']
    terminals = torch_batch['terminals']
    torch_batch['rewards'] = rewards.unsqueeze(-1)
    torch_batch['terminals'] = terminals.unsqueeze(-1)
    return torch_batch
def forward(
        self,
        flat_obs,
        return_preactivations=False,
):
    """Append a broadcast tau vector to the observation and run the parent network.

    Tau is added to a zero tensor of width ``self.tau_vector_len``, so it
    is broadcast across that width.

    :param flat_obs: Flat input containing both the observation and tau.
    :param return_preactivations: Forwarded unchanged to the parent.
    :return: Whatever the parent class's ``forward`` returns.
    """
    obs, taus = split_tau(flat_obs)
    num_rows = obs.size()[0]
    tau_block = torch.zeros((num_rows, self.tau_vector_len)) + taus.data
    augmented = torch.cat((obs, ptu.Variable(tau_block)), dim=1)
    return super().forward(
        augmented, return_preactivations=return_preactivations)
def train(epoch):
    """Run one optimisation pass over ``train_loader`` and print progress.

    Relies on the module-level ``eval_model``, ``loss_fnct``, ``optimizer``
    and ``line_logger``. Progress is printed every
    ``num_batches_per_print`` batches.

    :param epoch: Epoch index shown in the progress line.
    """
    for batch_idx, (state, action, q_target) in enumerate(train_loader):
        q_estim = eval_model(state, action)
        target = ptu.Variable(q_target, requires_grad=False)
        loss = loss_fnct(q_estim, target)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch_idx % num_batches_per_print != 0:
            continue
        # NOTE(review): `loss.data[0]` assumes a one-element (not 0-dim)
        # loss tensor -- confirm against the PyTorch version in use.
        progress = 'Train Epoch: {} [{}/{}]\tLoss: {:.6f}'.format(
            epoch, batch_size * batch_idx, train_size, loss.data[0])
        line_logger.print_over(progress)
def _get_training_batch(self, epoch):
    """Sample a context batch, an optional context mask and a policy batch.

    The policy batch is drawn from the test expert buffer when
    ``epoch == 0`` and from ``self.replay_buffer`` for every other epoch.

    :param epoch: Current epoch; selects the source of the policy batch.
    :return: ``(context_batch, mask, policy_batch)`` where ``policy_batch``
        is a dict with keys ``observations`` and ``actions``, each
        concatenated over tasks on axis 0; ``mask`` is ``None`` outside
        few-shot mode.
    """
    few_shot = self.few_shot_version
    if few_shot:
        trajs_per_task = self.max_context_size
    else:
        trajs_per_task = self.num_context_trajs_for_training

    context_batch, task_identifiers_list = (
        self.train_context_expert_replay_buffer.sample_trajs(
            trajs_per_task,
            num_tasks=self.num_tasks_used_per_update,
            keys=['observations', 'actions', 'next_observations'],
        )
    )

    mask = None
    if few_shot:
        mask = ptu.Variable(
            torch.zeros(
                self.num_tasks_used_per_update, self.max_context_size, 1))
        # One context size per task, drawn from [min, max] inclusive.
        this_context_sizes = np.random.randint(
            self.min_context_size,
            self.max_context_size + 1,
            size=self.num_tasks_used_per_update,
        )
        for task_idx, context_size in enumerate(this_context_sizes):
            mask[task_idx, :context_size, :] = 1.0

    # Epoch 0 uses only expert data; later epochs use the replay buffer.
    if epoch == 0:
        source_buffer = self.train_test_expert_replay_buffer
    else:
        source_buffer = self.replay_buffer
    per_task_batches, _ = source_buffer.sample_random_batch(
        self.policy_optim_batch_size_per_task,
        task_identifiers_list=task_identifiers_list,
    )

    # Each entry ends up as (N_tasks * batch_size) x Dim.
    policy_batch = {
        key: np.concatenate([d[key] for d in per_task_batches], axis=0)
        for key in ('observations', 'actions')
    }
    return context_batch, mask, policy_batch
def forward(self, flat_obs, return_preactivations=False):
    """Append an indicator encoding of tau to the observation and run the parent network.

    Tau is clamped to be non-negative and used as a column index into a
    zero tensor of width ``self.max_tau + 1``.

    :param flat_obs: Flat input containing both the observation and tau.
    :param return_preactivations: Forwarded unchanged to the parent.
    :return: Whatever the parent class's ``forward`` returns.
    """
    obs, taus = split_tau(flat_obs)
    num_rows = obs.size()[0]

    indicator = ptu.FloatTensor(num_rows, self.max_tau + 1)
    indicator.zero_()
    # Negative taus are mapped onto index 0.
    tau_index = torch.clamp(taus.data.long(), min=0)
    indicator.scatter_(1, tau_index, 1)

    augmented = torch.cat((obs, ptu.Variable(indicator)), dim=1)
    return super().forward(
        augmented,
        return_preactivations=return_preactivations,
    )
def simulate_policy(args):
    """Load a pickled model, decode 64 random latents and show them as an 8x8 grid.

    :param args: Object with a ``file`` attribute holding the path of the
        pickled model. The model must expose ``representation_size``,
        ``input_channels``, ``imsize`` and ``decode``.
    """
    ptu.set_gpu_mode(True)
    # SECURITY: unpickling runs arbitrary code embedded in the file; only
    # load snapshots from a trusted source.
    # The context manager closes the file handle, which the previous
    # `pickle.load(open(...))` never did.
    with open(args.file, "rb") as snapshot_file:
        model = pickle.load(snapshot_file)
    model.to(ptu.device)
    # A leftover `ipdb.set_trace()` breakpoint used to sit here and halted
    # every run; it has been removed.

    samples = ptu.Variable(torch.randn(64, model.representation_size))
    # The decoded batch is moved to the CPU once; the second `.cpu()` call
    # on the reshaped tensor was redundant.
    samples = model.decode(samples).cpu()
    tensor = samples.data.view(
        64, model.input_channels, model.imsize, model.imsize)

    grid = make_grid(tensor, nrow=8)
    # Scale to [0, 255] bytes and move channels last for PIL.
    ndarr = grid.mul(255).clamp(0, 255).byte().permute(1, 2, 0).numpy()
    im = Image.fromarray(ndarr)
    im.show()
def forward(
        self,
        obs,
        deterministic=False,
        return_log_prob=False,
        return_entropy=False,
        return_log_prob_of_mean=False,
):
    """Append a broadcast tau vector to ``obs`` and run the parent policy.

    Tau is added to a zero tensor of width ``self.tau_vector_len``, so it
    is broadcast across that width.

    :param obs: Flat input containing both the observation and tau.
    :return: Whatever the parent class's ``forward`` returns for the
        augmented observation.
    """
    obs, taus = split_tau(obs)
    num_rows = obs.size()[0]
    tau_block = torch.zeros((num_rows, self.tau_vector_len)) + taus.data
    augmented_obs = torch.cat((obs, ptu.Variable(tau_block)), dim=1)
    return super().forward(
        obs=augmented_obs,
        deterministic=deterministic,
        return_log_prob=return_log_prob,
        return_entropy=return_entropy,
        return_log_prob_of_mean=return_log_prob_of_mean,
    )
def __init__(self, *args, subtraj_length=10, **kwargs):
    """Set up sub-trajectory discounting and a split sub-trajectory replay buffer.

    :param args: Forwarded to the parent constructor.
    :param subtraj_length: Number of steps per sub-trajectory.
    :param kwargs: Forwarded to the parent constructor.
    """
    super().__init__(*args, **kwargs)
    self.subtraj_length = subtraj_length

    # Cumulative product of the discount: one factor per step, stored as
    # a column.
    self.gammas = self.discount * torch.ones(self.subtraj_length)
    cumulative_discounts = torch.cumprod(self.gammas, dim=0)
    self.discount_factors = ptu.Variable(
        cumulative_discounts.view(-1, 1),
        requires_grad=False,
    )

    def make_buffer():
        # Both halves of the split buffer share the same configuration.
        return SubtrajReplayBuffer(
            max_replay_buffer_size=self.replay_buffer_size,
            env=self.env,
            subtraj_length=self.subtraj_length,
        )

    first_buffer = make_buffer()
    second_buffer = make_buffer()
    self.replay_buffer = SplitReplayBuffer(
        first_buffer,
        second_buffer,
        fraction_paths_in_train=0.8,
    )
def test(epoch):
    """Evaluate on ``test_loader``, print the mean loss and extend the report.

    Relies on the module-level ``eval_model``, ``loss_fnct``,
    ``line_logger``, ``report``, ``q_function`` and ``eval_model_np``. Two
    figures (true and estimated Q function) are added to a new report row.

    :param epoch: Epoch index used in the printed line and report header.
    """
    losses = []
    for state, action, q_target in test_loader:
        q_estim = eval_model(state, action)
        target = ptu.Variable(q_target, requires_grad=False)
        # NOTE(review): `loss.data[0]` assumes a one-element (not 0-dim)
        # loss tensor -- confirm against the PyTorch version in use.
        losses.append(loss_fnct(q_estim, target).data[0])

    line_logger.newline()
    print('Test Epoch: {0}. Loss: {1}'.format(epoch, np.mean(losses)))

    report.add_header("Epoch = {}".format(epoch))
    # The caption doubles as the figure title.
    figures = (
        (q_function, "True Q Function"),
        (eval_model_np, 'Estimated Q Function'),
    )
    for model_fn, caption in figures:
        fig = visualize_model(model_fn, caption)
        img = vu.save_image(fig)
        report.add_image(img, txt=caption)
    report.new_row()
def forward(
        self,
        obs,
        deterministic=False,
        return_log_prob=False,
        return_entropy=False,
        return_log_prob_of_mean=False,
):
    """Append a binary encoding of tau to ``obs`` and run the parent policy.

    :param obs: Flat input containing both the observation and tau.
    :return: Whatever the parent class's ``forward`` returns for the
        augmented observation.
    """
    obs, taus = split_tau(obs)
    # NOTE(review): `len(self.max_tau)` requires `max_tau` to be a sized
    # object here -- confirm against the class constructor.
    tau_encoding = make_binary_tensor(
        taus, len(self.max_tau), taus.size()[0])
    augmented_obs = torch.cat((obs, ptu.Variable(tau_encoding)), dim=1)
    return super().forward(
        obs=augmented_obs,
        deterministic=deterministic,
        return_log_prob=return_log_prob,
        return_entropy=return_entropy,
        return_log_prob_of_mean=return_log_prob_of_mean,
    )