def __init__(
    self,
    backbone_cfg: ConfigDict,
    head_cfg: ConfigDict,
    action_embedder_cfg: ConfigDict,
    shared_backbone: nn.Module = None,
):
    """Build the backbone, an optional action embedder, and the head.

    The head's ``input_size`` is computed here and written back into
    ``head_cfg.configs`` (the config is mutated in place). It is the size of
    the state features plus either the action-embedder's ``output_size``
    (when ``action_embedder_cfg`` is truthy) or the raw ``action_size``.

    Args:
        backbone_cfg: backbone config. A falsy value means no backbone:
            ``identity`` is used and ``state_size[0]`` is taken as the
            feature size. Not consulted when ``shared_backbone`` is given.
        head_cfg: head config; its ``configs.input_size`` is overwritten.
        action_embedder_cfg: config for a head that embeds actions. A falsy
            value disables the embedder. Its ``configs.input_size`` is
            overwritten with ``head_cfg.configs.action_size``.
        shared_backbone: an already-built backbone module to reuse instead
            of building one from ``backbone_cfg``.
    """
    # nn.Module.__init__ is called directly instead of super().__init__().
    # NOTE(review): presumably to bypass a parent __init__ with a different
    # signature — confirm against the class definition.
    nn.Module.__init__(self)
    head_configs = head_cfg.configs

    if shared_backbone is None and not backbone_cfg:
        # No backbone at all: states reach the head unchanged.
        self.backbone = identity
        head_configs.input_size = head_configs.state_size[0]
    else:
        if shared_backbone is not None:
            self.backbone = shared_backbone
        else:
            self.backbone = build_backbone(backbone_cfg)
        # Assigned after self.backbone is set, as in the original.
        # NOTE(review): presumably calculate_fc_input_size probes self.backbone.
        head_configs.input_size = self.calculate_fc_input_size(
            head_configs.state_size
        )

    if action_embedder_cfg:
        # Actions go through their own head; the main head's input grows by
        # the embedder's output width instead of the raw action size.
        action_embedder_cfg.configs.input_size = head_configs.action_size
        self.action_embedder = build_head(action_embedder_cfg)
        action_width = action_embedder_cfg.configs.output_size
    else:
        self.action_embedder = None
        action_width = head_configs.action_size
    head_configs.input_size += action_width

    self.head = build_head(head_cfg)
def _init_network(self):
    """Initialize networks and optimizers.

    Actor and critic:
        If ``backbone_cfg.shared_actor_critic`` is truthy, one backbone is
        built from it and passed to both the actor and critic ``Brain``s.
        Otherwise each ``Brain`` is built from ``backbone_cfg.actor`` or
        ``backbone_cfg.critic`` respectively.

    Discriminator:
        Built from ``backbone_cfg.discriminator``,
        ``head_cfg.discriminator`` and the action-embedder head config.

    All networks are moved to ``self.device``. Each network gets its own
    Adam optimizer, configured from ``optim_cfg``. Parameters are loaded
    from ``self.load_from`` when it is not None.
    """
    # create actor and critic
    if self.backbone_cfg.shared_actor_critic:
        shared_backbone = build_backbone(self.backbone_cfg.shared_actor_critic)
        self.actor = Brain(
            self.backbone_cfg.shared_actor_critic,
            self.head_cfg.actor,
            shared_backbone,
        )
        self.critic = Brain(
            self.backbone_cfg.shared_actor_critic,
            self.head_cfg.critic,
            shared_backbone,
        )
        self.actor = self.actor.to(self.device)
        self.critic = self.critic.to(self.device)
    else:
        self.actor = Brain(self.backbone_cfg.actor, self.head_cfg.actor).to(
            self.device
        )
        self.critic = Brain(self.backbone_cfg.critic, self.head_cfg.critic).to(
            self.device
        )

    # create discriminator
    # The action-embedder config was read from the misspelled key
    # "aciton_embedder". Prefer the correct spelling, and fall back to the
    # legacy key so existing configs keep working. A config with neither key
    # still fails on the lookup, as before.
    # NOTE(review): assumes head_cfg is dict-like (e.g. ConfigDict) so that
    # `in` tests key membership — confirm.
    if "action_embedder" in self.head_cfg:
        action_embedder_cfg = self.head_cfg.action_embedder
    else:
        action_embedder_cfg = self.head_cfg.aciton_embedder
    self.discriminator = Discriminator(
        self.backbone_cfg.discriminator,
        self.head_cfg.discriminator,
        action_embedder_cfg,
    ).to(self.device)

    # create optimizers
    # NOTE(review): with a shared backbone, its parameters appear in both the
    # actor and critic optimizers — confirm this is intended.
    self.actor_optim = optim.Adam(
        self.actor.parameters(),
        lr=self.optim_cfg.lr_actor,
        weight_decay=self.optim_cfg.weight_decay,
    )
    self.critic_optim = optim.Adam(
        self.critic.parameters(),
        lr=self.optim_cfg.lr_critic,
        weight_decay=self.optim_cfg.weight_decay,
    )
    self.discriminator_optim = optim.Adam(
        self.discriminator.parameters(),
        lr=self.optim_cfg.lr_discriminator,
        weight_decay=self.optim_cfg.weight_decay,
    )

    # load model parameters
    if self.load_from is not None:
        self.load_params(self.load_from)
def __init__(
    self,
    backbone_cfg: ConfigDict,
    head_cfg: ConfigDict,
    shared_backbone: nn.Module = None,
):
    """Initialize the backbone and the head.

    The head's ``input_size`` is computed here and written back into
    ``head_cfg.configs`` (the config is mutated in place).

    Args:
        backbone_cfg: backbone config. A falsy value (with no
            ``shared_backbone``) means no backbone: ``identity`` is used and
            ``state_size[0]`` is taken as the head input size.
        head_cfg: head config; its ``configs.input_size`` is overwritten.
        shared_backbone: optional already-built backbone module to reuse
            (e.g. one backbone shared by actor and critic). When given,
            ``backbone_cfg`` is not used to build a new backbone. Defaults to
            None, which keeps the previous two-argument behavior.
    """
    nn.Module.__init__(self)
    if shared_backbone is not None:
        # Reuse the caller's backbone instead of building a second copy.
        self.backbone = shared_backbone
        head_cfg.configs.input_size = self.calculate_fc_input_size(
            head_cfg.configs.state_size
        )
    elif not backbone_cfg:
        self.backbone = identity
        head_cfg.configs.input_size = head_cfg.configs.state_size[0]
    else:
        self.backbone = build_backbone(backbone_cfg)
        head_cfg.configs.input_size = self.calculate_fc_input_size(
            head_cfg.configs.state_size
        )

    self.head = build_head(head_cfg)