Example #1
    def forward(self,
                agent,
                state,
                detach_params=False,
                direct=False,
                **kwargs):
        # optionally deep copy the inference model (requires `import copy`)
        # so that no gradients flow back into its parameters
        if detach_params:
            inference_model = copy.deepcopy(self.inference_model)
        else:
            inference_model = self.inference_model
        # amortized proposal: map the state directly to distribution parameters
        inf_input = inference_model(state=state)
        if direct and agent.direct_approx_post is not None:
            agent.direct_approx_post.step(inf_input, detach_params)
        else:
            agent.approx_post.step(inf_input, detach_params)

            # semi-amortized refinement: further optimize the proposed
            # distribution parameters by gradient descent on the negated objective
            dist_params = agent.approx_post.get_dist_params()
            params = list(dist_params.values())
            act_opt = self.optimizer(params, lr=self.lr)
            act_opt.zero_grad()

            for _ in range(self.n_inf_iters):
                # sample actions, estimate the objective, and backpropagate
                actions = agent.approx_post.sample(agent.n_action_samples)
                obj = agent.estimate_objective(state, actions)
                # negate and average over action samples: minimizing -obj maximizes obj
                obj = -obj.view(agent.n_action_samples, -1, 1).mean(dim=0)
                self.estimated_objectives.append(obj.detach())
                obj.sum().backward(retain_graph=True)
                act_opt.step()
                act_opt.zero_grad()

            # clear any gradients accumulated in the generative parameters
            clear_gradients(agent.generative_parameters())
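
Every example ends by calling a clear_gradients helper that is not shown here. A minimal sketch of what it plausibly does, assuming it simply zeroes any accumulated .grad attributes on an iterable of parameters (name and behavior are inferred from usage, not confirmed by the source):

    def clear_gradients(parameters):
        # hypothetical implementation inferred from usage:
        # zero out any gradient accumulated on each parameter
        for param in parameters:
            if param.grad is not None:
                param.grad.detach_()
                param.grad.zero_()
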
Example #2
    def forward(self, agent, state, **kwargs):
        # re-wrap the current distribution parameters as leaf tensors
        # so the optimizer can update them in place
        dist_params = {
            k: v.data.requires_grad_()
            for k, v in agent.approx_post.get_dist_params().items()
        }
        agent.approx_post.reset(batch_size=state.shape[0],
                                dist_params=dist_params)
        params = list(dist_params.values())
        act_opt = self.optimizer(params, lr=self.lr)
        act_opt.zero_grad()

        for _ in range(self.n_inf_iters):
            # sample actions, estimate the objective, and backpropagate
            actions = agent.approx_post.sample(agent.n_action_samples)
            obj = agent.estimate_objective(state, actions)
            obj = -obj.view(agent.n_action_samples, -1, 1).mean(dim=0)
            self.estimated_objectives.append(obj.detach())
            obj.sum().backward(retain_graph=True)
            act_opt.step()
            act_opt.zero_grad()
            # clear the cached sample to force resampling next iteration
            agent.approx_post._sample = None

        # clear any gradients accumulated in the generative parameters
        clear_gradients(agent.generative_parameters())
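
The gradient-based variants (#1, #2, and #4) assume the enclosing class provides self.optimizer, self.lr, self.n_inf_iters, and an estimated_objectives buffer. A hedged sketch of such a constructor, with all names and defaults assumed rather than taken from the source:

    import torch.optim as optim

    class GradientBasedInference:
        def __init__(self, lr=1e-3, n_inf_iters=10, optimizer=optim.SGD):
            self.optimizer = optimizer      # optimizer class, instantiated per call
            self.lr = lr                    # inner-loop step size
            self.n_inf_iters = n_inf_iters  # number of refinement iterations
            # per-iteration objective estimates, e.g. for logging
            self.estimated_objectives = []
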
Example #3
    def forward(self, agent, state, target=False, **kwargs):

        approx_post = agent.approx_post if not target else agent.target_approx_post
        self.dist_params.append(
            {k: v.detach()
             for k, v in approx_post.get_dist_params().items()})

        for _ in range(self.n_inf_iters):
            # sample actions, evaluate objective, backprop to get gradients
            actions = approx_post.sample(agent.n_action_samples)
            obj = agent.estimate_objective(state, actions, target=target)
            obj = -obj.view(agent.n_action_samples, -1, 1).mean(dim=0)
            self.estimated_objectives.append(obj.detach())
            # TODO: should this be multiplied by valid and done?
            obj.sum().backward(retain_graph=True)

            # update the approximate posterior using the iterative inference model
            params, grads = approx_post.params_and_grads()
            if self.encoding_type == 'grads':
                inf_input = self.inference_model(params=params,
                                                 grads=grads,
                                                 state=state)
            else:
                # alternatively, encode normalized value-estimation errors
                # in place of the parameter gradients
                error_dict = agent.q_value_estimator.get_errors()
                errors = [
                    error.detach().view(agent.n_action_samples, -1,
                                        error.shape[-1]).mean(dim=0)
                    for _, error in error_dict.items()
                ]
                # standardize each error signal across the batch dimension
                for ind, error in enumerate(errors):
                    mean = error.mean(dim=1, keepdim=True)
                    std = error.std(dim=1, keepdim=True)
                    errors[ind] = (error - mean) / (std + 1e-7)
                errors = torch.cat(errors, dim=1)
                inf_input = self.inference_model(params=params,
                                                 errors=errors,
                                                 state=state)
            approx_post.step(inf_input)
            # retain gradients on the updated parameters for the next iteration
            approx_post.retain_grads()
            self.dist_params.append({
                k: v.detach()
                for k, v in approx_post.get_dist_params().items()
            })

        # clear any gradients in the generative parameters
        clear_gradients(agent.generative_parameters())

        if target:
            # clear model gradients if this is the target inference optimizer
            target_params = nn.ParameterList()
            target_params.extend(list(self.inference_model.parameters()))
            target_params.extend(list(agent.target_approx_post.parameters()))
            clear_gradients(target_params)
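
The iterative variants (#3 and #5) call approx_post.params_and_grads(), which presumably packages the current distribution parameters together with their freshly computed gradients as input to the inference model. A sketch under that assumption; the real method may flatten or structure these differently:

    import torch

    def params_and_grads(self):
        # hypothetical: concatenate detached parameter values and their
        # gradients across the distribution's parameter dict
        dist_params = self.get_dist_params()
        params = torch.cat([p.detach() for p in dist_params.values()], dim=-1)
        grads = torch.cat([p.grad.detach() for p in dist_params.values()], dim=-1)
        return params, grads
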
Example #4
    def __call__(self, agent, state, **kwargs):
        # build an optimizer over the current distribution parameters
        dist_params = agent.approx_post.get_dist_params()
        params = list(dist_params.values())
        act_opt = self.optimizer(params, lr=self.lr)
        act_opt.zero_grad()

        for _ in range(self.n_inf_iters):
            # sample actions, estimate the objective, and backpropagate
            actions = agent.approx_post.sample(agent.n_action_samples)
            obj = agent.estimate_objective(state, actions)
            obj = -obj.view(agent.n_action_samples, -1, 1).mean(dim=0)
            self.estimated_objectives.append(obj.detach())
            obj.sum().backward(retain_graph=True)
            act_opt.step()
            act_opt.zero_grad()
            # clear the cached sample to force resampling next iteration
            agent.approx_post._sample = None

        # clear any gradients accumulated in the generative parameters
        clear_gradients(agent.generative_parameters())
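
A hypothetical call site for this version, assuming the constructor sketched after Example #2 and the __call__ defined above; agent and state would come from the surrounding environment loop, and all names here are assumptions:

    # run the inner optimization loop, then act from the refined posterior
    inf_opt = GradientBasedInference(lr=1e-3, n_inf_iters=10)
    inf_opt(agent, state)
    action = agent.approx_post.sample(1)
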
Example #5
    def forward(self, agent, state, target=False, **kwargs):

        approx_post = agent.approx_post if not target else agent.target_approx_post
        self.dist_params.append(
            {k: v.detach()
             for k, v in approx_post.get_dist_params().items()})

        for _ in range(self.n_inf_iters):
            # sample actions, evaluate objective, backprop to get gradients
            actions = approx_post.sample(agent.n_action_samples)
            obj = agent.estimate_objective(state, actions, target=target)
            obj = -obj.view(agent.n_action_samples, -1, 1).mean(dim=0)
            self.estimated_objectives.append(obj.detach())
            # TODO: should this be multiplied by valid and done?
            obj.sum().backward(retain_graph=True)

            # update the approximate posterior using the iterative inference model
            params, grads = approx_post.params_and_grads()
            inf_input = self.inference_model(params=params,
                                             grads=grads,
                                             state=state)
            approx_post.step(inf_input)
            # retain gradients on the updated parameters for the next iteration
            approx_post.retain_grads()
            self.dist_params.append({
                k: v.detach()
                for k, v in approx_post.get_dist_params().items()
            })

        # clear any gradients in the generative parameters
        clear_gradients(agent.generative_parameters())

        if target:
            # clear model gradients if this is the target inference optimizer
            target_params = nn.ParameterList()
            target_params.extend(list(self.inference_model.parameters()))
            target_params.extend(list(agent.target_approx_post.parameters()))
            clear_gradients(target_params)
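
Both iterative examples feed keyword inputs into self.inference_model. A minimal sketch of what such a network might look like for the 'grads' encoding, with all dimensions, layer sizes, and the class name assumed:

    import torch
    import torch.nn as nn

    class IterativeInferenceModel(nn.Module):
        # hypothetical: maps current params, their gradients, and the state
        # to an update signal consumed by approx_post.step()
        def __init__(self, param_dim, state_dim, hidden_dim=256):
            super().__init__()
            self.net = nn.Sequential(
                nn.Linear(2 * param_dim + state_dim, hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, param_dim),
            )

        def forward(self, params, grads, state):
            return self.net(torch.cat([params, grads, state], dim=-1))
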