import copy

import torch
import torch.nn as nn


def forward(self, agent, state, detach_params=False, direct=False, **kwargs):
    # run the direct inference model to initialize the approximate posterior
    if detach_params:
        # deep copy so gradients do not flow into the inference model parameters
        inference_model = copy.deepcopy(self.inference_model)
    else:
        inference_model = self.inference_model
    inf_input = inference_model(state=state)
    if direct and agent.direct_approx_post is not None:
        agent.direct_approx_post.step(inf_input, detach_params)
    else:
        agent.approx_post.step(inf_input, detach_params)

    # refine the distribution parameters with gradient-based optimization
    dist_params = agent.approx_post.get_dist_params()
    params = [param for _, param in dist_params.items()]
    act_opt = self.optimizer(params, lr=self.lr)
    act_opt.zero_grad()
    for _ in range(self.n_inf_iters):
        # sample actions, evaluate the objective, and step the optimizer
        actions = agent.approx_post.sample(agent.n_action_samples)
        obj = agent.estimate_objective(state, actions)
        obj = -obj.view(agent.n_action_samples, -1, 1).mean(dim=0)
        self.estimated_objectives.append(obj.detach())
        obj.sum().backward(retain_graph=True)
        act_opt.step()
        act_opt.zero_grad()
    # clear any gradients accumulated in the generative parameters
    clear_gradients(agent.generative_parameters())

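# The methods in this listing call a clear_gradients utility that is not
# shown here. A minimal sketch consistent with how it is used (the actual
# implementation in the source may differ):
def clear_gradients(parameters):
    """Zero and detach the .grad of every parameter in an iterable."""
    for param in parameters:
        if param.grad is not None:
            param.grad.detach_()
            param.grad.zero_()
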
def forward(self, agent, state, **kwargs):
    # detach the current distribution parameters and re-attach them as leaf
    # variables so the optimizer can update them directly
    dist_params = {k: v.data.requires_grad_()
                   for k, v in agent.approx_post.get_dist_params().items()}
    agent.approx_post.reset(batch_size=state.shape[0], dist_params=dist_params)
    params = [param for _, param in dist_params.items()]
    act_opt = self.optimizer(params, lr=self.lr)
    act_opt.zero_grad()
    for _ in range(self.n_inf_iters):
        # sample actions, evaluate the objective, and step the optimizer
        actions = agent.approx_post.sample(agent.n_action_samples)
        obj = agent.estimate_objective(state, actions)
        obj = -obj.view(agent.n_action_samples, -1, 1).mean(dim=0)
        self.estimated_objectives.append(obj.detach())
        obj.sum().backward(retain_graph=True)
        act_opt.step()
        act_opt.zero_grad()
        # clear the cached sample to force resampling on the next iteration
        agent.approx_post._sample = None
    # clear any gradients accumulated in the generative parameters
    clear_gradients(agent.generative_parameters())

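# For reference, the leaf-variable trick used above in isolation: calling
# .data.requires_grad_() yields a graph-detached leaf tensor that a torch
# optimizer can update directly. Toy demo (names and shapes are illustrative,
# not from the source):
def _leaf_param_demo():
    loc = torch.zeros(8, 2).data.requires_grad_()
    opt = torch.optim.Adam([loc], lr=1e-2)
    loss = (loc - 1.0).pow(2).sum()
    loss.backward()
    opt.step()  # loc moves toward 1.0
    return loc
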
def forward(self, agent, state, target=False, **kwargs):
    approx_post = agent.approx_post if not target else agent.target_approx_post
    # record the initial distribution parameters
    self.dist_params.append(
        {k: v.detach() for k, v in approx_post.get_dist_params().items()})
    for _ in range(self.n_inf_iters):
        # sample actions, evaluate the objective, backprop to get gradients
        actions = approx_post.sample(agent.n_action_samples)
        obj = agent.estimate_objective(state, actions, target=target)
        obj = -obj.view(agent.n_action_samples, -1, 1).mean(dim=0)
        self.estimated_objectives.append(obj.detach())
        # TODO: should this be multiplied by valid and done?
        obj.sum().backward(retain_graph=True)
        # update the approximate posterior using the iterative inference model
        params, grads = approx_post.params_and_grads()
        if self.encoding_type == 'grads':
            inf_input = self.inference_model(params=params, grads=grads,
                                             state=state)
        else:
            # encode normalized value-estimation errors instead of gradients
            error_dict = agent.q_value_estimator.get_errors()
            errors = [error.detach().view(agent.n_action_samples, -1,
                                          error.shape[-1]).mean(dim=0)
                      for _, error in error_dict.items()]
            for ind, error in enumerate(errors):
                mean = error.mean(dim=1, keepdim=True)
                std = error.std(dim=1, keepdim=True)
                errors[ind] = (error - mean) / (std + 1e-7)
            errors = torch.cat(errors, dim=1)
            inf_input = self.inference_model(params=params, errors=errors,
                                             state=state)
        approx_post.step(inf_input)
        approx_post.retain_grads()
        # record the updated distribution parameters
        self.dist_params.append(
            {k: v.detach() for k, v in approx_post.get_dist_params().items()})
    # clear any gradients in the generative parameters
    clear_gradients(agent.generative_parameters())
    if target:
        # clear model gradients if this is the target inference optimizer
        target_params = nn.ParameterList()
        target_params.extend(list(self.inference_model.parameters()))
        target_params.extend(list(agent.target_approx_post.parameters()))
        clear_gradients(target_params)

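# The error encoding above standardizes each error tensor across its feature
# dimension before concatenation. The same operation in isolation, with toy
# shapes (illustrative, not from the source):
def _normalize_errors_demo():
    e = torch.randn(16, 4)  # [batch_size, error_dim]
    mean = e.mean(dim=1, keepdim=True)
    std = e.std(dim=1, keepdim=True)
    return (e - mean) / (std + 1e-7)  # zero mean, unit std per batch element
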
def __call__(self, agent, state, **kwargs):
    dist_params = agent.approx_post.get_dist_params()
    params = [param for _, param in dist_params.items()]
    act_opt = self.optimizer(params, lr=self.lr)
    act_opt.zero_grad()
    for _ in range(self.n_inf_iters):
        # sample actions, evaluate the objective, and step the optimizer
        actions = agent.approx_post.sample(agent.n_action_samples)
        obj = agent.estimate_objective(state, actions)
        obj = -obj.view(agent.n_action_samples, -1, 1).mean(dim=0)
        self.estimated_objectives.append(obj.detach())
        obj.sum().backward(retain_graph=True)
        act_opt.step()
        act_opt.zero_grad()
        # clear the cached sample to force resampling on the next iteration
        agent.approx_post._sample = None
    # clear any gradients accumulated in the generative parameters
    clear_gradients(agent.generative_parameters())

def forward(self, agent, state, target=False, **kwargs):
    approx_post = agent.approx_post if not target else agent.target_approx_post
    # record the initial distribution parameters
    self.dist_params.append(
        {k: v.detach() for k, v in approx_post.get_dist_params().items()})
    for _ in range(self.n_inf_iters):
        # sample actions, evaluate the objective, backprop to get gradients
        actions = approx_post.sample(agent.n_action_samples)
        obj = agent.estimate_objective(state, actions, target=target)
        obj = -obj.view(agent.n_action_samples, -1, 1).mean(dim=0)
        self.estimated_objectives.append(obj.detach())
        # TODO: should this be multiplied by valid and done?
        obj.sum().backward(retain_graph=True)
        # update the approximate posterior using the iterative inference model
        params, grads = approx_post.params_and_grads()
        inf_input = self.inference_model(params=params, grads=grads, state=state)
        approx_post.step(inf_input)
        approx_post.retain_grads()
        # record the updated distribution parameters
        self.dist_params.append(
            {k: v.detach() for k, v in approx_post.get_dist_params().items()})
    # clear any gradients in the generative parameters
    clear_gradients(agent.generative_parameters())
    if target:
        # clear model gradients if this is the target inference optimizer
        target_params = nn.ParameterList()
        target_params.extend(list(self.inference_model.parameters()))
        target_params.extend(list(agent.target_approx_post.parameters()))
        clear_gradients(target_params)

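# For intuition, a self-contained toy version of the gradient-based variants
# above: optimize a diagonal Gaussian over actions against a fixed
# differentiable objective via reparameterized samples. All names here are
# illustrative, not from the source.
def toy_gradient_inference(q_fn, batch_size=4, action_dim=2,
                           n_samples=10, n_iters=5, lr=0.1):
    loc = torch.zeros(batch_size, action_dim, requires_grad=True)
    log_scale = torch.zeros(batch_size, action_dim, requires_grad=True)
    opt = torch.optim.Adam([loc, log_scale], lr=lr)
    for _ in range(n_iters):
        dist = torch.distributions.Normal(loc, log_scale.exp())
        actions = dist.rsample((n_samples,))  # [n_samples, batch, action_dim]
        obj = -q_fn(actions).mean(dim=0)      # negate: the optimizer minimizes
        opt.zero_grad()
        obj.sum().backward()
        opt.step()
    return loc.detach(), log_scale.exp().detach()

# e.g., maximizing -(a - 0.5)^2 pulls the mean toward 0.5:
# loc, scale = toy_gradient_inference(lambda a: -(a - 0.5).pow(2).sum(-1))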