Example #1
0
 def forward(self, state, deterministic):
     """Query the policy for an action and wrap it in a CentralizedOutput.

     Everything runs inside ``torch.no_grad()``. A NumPy ``state`` is first
     passed through ``from_np(state, 'cpu')`` (presumably yielding a CPU
     tensor -- confirm against ``from_np``); any other input is forwarded
     to the policy untouched.

     Args:
         state: Observation handed to ``self.policy.select_action``.
         deterministic: Flag forwarded unchanged to ``select_action``.

     Returns:
         A ``CentralizedOutput`` whose ``action`` is the selected action
         wrapped in ``LiteralActionTransformation`` and whose
         ``dist_params`` is the second value returned by the policy.
     """
     with torch.no_grad():
         state_is_numpy = isinstance(state, np.ndarray)
         if state_is_numpy:
             state = from_np(state, 'cpu')
         selection = self.policy.select_action(state, deterministic)
         raw_action, dist_params = selection
         wrapped_action = LiteralActionTransformation(raw_action)
         return CentralizedOutput(
             action=wrapped_action, dist_params=dist_params)
Example #2
0
 def forward(self, state, deterministic):
     """Query the policy and resolve its choice into a transformation.

     Everything runs inside ``torch.no_grad()``. A NumPy ``state`` is first
     passed through ``from_np(state, 'cpu')``; any other input is forwarded
     to the policy untouched.

     Args:
         state: Observation handed to ``self.policy.select_action``.
         deterministic: Flag forwarded unchanged to ``select_action``.

     Returns:
         A ``CentralizedOutput`` carrying the resolved action and the
         ``dist_params`` returned by the policy. When
         ``self.policy.discrete`` is truthy the action is the entry of
         ``self.transformations`` keyed by the policy's choice; otherwise
         the choice is wrapped in ``LiteralActionTransformation``.
     """
     with torch.no_grad():
         state_is_numpy = isinstance(state, np.ndarray)
         if state_is_numpy:
             state = from_np(state, 'cpu')
         selection = self.policy.select_action(state, deterministic)
         choice, dist_params = selection
         # A non-discrete policy must be a leaf policy, so its output is
         # wrapped as a literal action instead of being looked up.
         resolved = (self.transformations[choice]
                     if self.policy.discrete
                     else LiteralActionTransformation(choice))
         return CentralizedOutput(action=resolved, dist_params=dist_params)
Example #3
0
    def create_transformation_builder(cls, state_dim, action_dim, args):
        """Return a one-argument factory mapping ``id_num`` to a transformation.

        Without ``args.hrl`` the factory simply wraps ``id_num`` in a
        ``LiteralActionTransformation``. With ``args.hrl`` each call builds a
        ``SubpolicyTransformation`` that owns a freshly constructed policy,
        a value function, one ``LiteralActionTransformation`` per index in
        ``range(action_dim)`` and a new ``PathMemory`` replay buffer.

        Args:
            state_dim: Forwarded to ``cls.centralized_policy_switch`` and
                ``cls.value_switch``.
            action_dim: Forwarded to ``cls.centralized_policy_switch``; also
                the number of literal child transformations per subpolicy.
            args: Configuration object; this method reads ``hrl``,
                ``shared_vfn`` and ``max_buffer_size`` from it.

        Returns:
            A callable taking ``id_num`` and returning the transformation.
        """
        if not args.hrl:
            def build_literal(id_num):
                return LiteralActionTransformation(id_num)

            return build_literal

        make_policy = cls.centralized_policy_switch(
            state_dim=state_dim, action_dim=action_dim, args=args)
        make_valuefn = cls.value_switch(state_dim=state_dim, args=args)

        if args.shared_vfn:
            # Instantiate once, up front, so that every transformation
            # produced by the builder refers to the same value function.
            shared_valuefn = make_valuefn()

            def obtain_valuefn():
                return shared_valuefn
        else:
            # A separate value function is constructed on every build.
            obtain_valuefn = make_valuefn

        def build_subpolicy(id_num):
            # Construction order (policy, value function, children, buffer)
            # is kept the same as the keyword order of the final call.
            networks = dict(
                policy=make_policy(),
                valuefn=obtain_valuefn(),
            )
            children = OrderedDict(
                (child_id, LiteralActionTransformation(id_num=child_id))
                for child_id in range(action_dim))
            memory = PathMemory(max_replay_buffer_size=args.max_buffer_size)
            return SubpolicyTransformation(
                id_num=id_num,
                networks=networks,
                transformations=children,
                replay_buffer=memory,
                args=args)

        return build_subpolicy