def _init(self, args):
    """Create the MiniRTSNet trunk plus policy, value and Wt/Wt2/Wt3 layers.

    Args:
        args: parsed arguments object. ``args.params`` must map
            ``"num_action"`` to an int; when ``"model_no_spatial"`` is truthy
            it must also provide ``"num_unit_type"``.
    """
    params = args.params
    num_action = params["num_action"]
    assert isinstance(num_action, int), (
        "num_action has to be a number. action = " + str(num_action))
    self.params = params
    self.net = MiniRTSNet(args)

    channels = self.net.num_channels[-1]
    no_spatial = self.params.get("model_no_spatial", False)
    if no_spatial:
        self.num_unit = params["num_unit_type"]
    # Non-spatial: the feature width is just the last trunk channel count.
    # Spatial: channels are flattened over 25 positions (presumably a 5x5
    # map — confirm against MiniRTSNet's output).
    hidden = channels if no_spatial else channels * 25

    # NOTE: keep submodule creation order stable; it fixes parameter
    # initialization order and state_dict / parameters() ordering.
    self.linear_policy = nn.Linear(hidden, num_action)
    self.linear_value = nn.Linear(hidden, 1)
    self.relu = nn.LeakyReLU(0.1)
    # Wt's input is hidden + num_action wide (presumably features concatenated
    # with an action encoding — confirm in forward).
    self.Wt = nn.Linear(hidden + num_action, hidden)
    self.Wt2 = nn.Linear(hidden, hidden)
    self.Wt3 = nn.Linear(hidden, hidden)
    self.softmax = nn.Softmax()
def _init(self, args):
    """Create a spatial MiniRTSNet trunk and the command/location/value heads.

    Args:
        args: parsed arguments object. ``args.params`` must provide an int
            ``"num_action"`` plus ``"num_unit_type"``, ``"num_planes"``,
            ``"num_cmd_type"``, ``"map_x"`` and ``"map_y"``.
    """
    params = args.params
    assert isinstance(params["num_action"], int), (
        "num_action has to be a number. action = " + str(params["num_action"]))
    self.params = params
    # output1d=False requests the trunk's non-flattened output (presumably
    # num_planes x map_y x map_x — confirm in MiniRTSNet).
    self.net = MiniRTSNet(args, output1d=False)

    self.na = params["num_action"]
    self.num_unit = params["num_unit_type"]
    self.num_planes = params["num_planes"]
    self.num_cmd_type = params["num_cmd_type"]
    self.mapx = params["map_x"]
    self.mapy = params["map_y"]
    flat_dim = self.num_planes * self.mapx * self.mapy

    # Heads predicted from the trunk output:
    #   - unit location and target location: 1-channel 3x3 convolutions
    #     (padding=1 keeps the spatial size) over num_planes input planes;
    #   - command type and build type: linear layers over the flattened
    #     num_planes * map_x * map_y features;
    #   - value: a single scalar from the same flattened features.
    # NOTE: keep this creation order; it fixes init and state_dict ordering.
    self.unit_locs = nn.Conv2d(self.num_planes, 1, kernel_size=3, padding=1)
    self.target_locs = nn.Conv2d(self.num_planes, 1, kernel_size=3, padding=1)
    self.cmd_types = nn.Linear(flat_dim, self.num_cmd_type)
    self.build_types = nn.Linear(flat_dim, self.num_unit)
    self.value = nn.Linear(flat_dim, 1)
    self.relu = nn.LeakyReLU(0.1)
    self.softmax = nn.Softmax()
def get_define_args():
    """Return MiniRTSNet's argument definitions extended with this model's options.

    Adds ``ratio_skip_observation`` (default 0.0) and the store-true flags
    ``concat``, ``enable_transition_model`` and ``gating``.
    """
    extra_args = [
        ("ratio_skip_observation", 0.0),
        ("concat", {"action": "store_true"}),
        ("enable_transition_model", {"action": "store_true"}),
        ("gating", {"action": "store_true"}),
    ]
    return MiniRTSNet.get_define_args() + extra_args
def _init(self, args):
    """Create the MiniRTSNet trunk and linear policy/value heads.

    Args:
        args: parsed arguments object. ``args.params`` must provide an int
            ``"num_action"`` and ``"num_unit_type"``.
    """
    params = args.params
    num_action = params["num_action"]
    assert isinstance(num_action, int), (
        "num_action has to be a number. action = " + str(num_action))
    self.params = params
    self.net = MiniRTSNet(args)

    no_spatial = self.params.get("model_no_spatial", False)
    if no_spatial:
        self.num_unit = params["num_unit_type"]
    # Per-position width is num_unit_type + 7 (what the extra 7 features are
    # is not visible here — confirm against MiniRTSNet). In the spatial case
    # it is flattened over 25 positions (presumably a 5x5 map).
    per_position = params["num_unit_type"] + 7
    in_features = per_position if no_spatial else per_position * 25

    # NOTE: keep creation order; it fixes init and state_dict ordering.
    self.linear_policy = nn.Linear(in_features, num_action)
    self.linear_value = nn.Linear(in_features, 1)
    self.softmax = nn.Softmax()
def _init(self, args):
    """Create the MiniRTSNet trunk and one linear head with a value per action.

    Args:
        args: parsed arguments object. ``args.params`` must map
            ``"num_action"`` to an int; when ``"model_no_spatial"`` is truthy
            it must also provide ``"num_unit_type"``.
    """
    params = args.params
    num_action = params["num_action"]
    assert isinstance(num_action, int), (
        "num_action has to be a number. action = " + str(num_action))
    self.params = params
    self.net = MiniRTSNet(args)

    channels = self.net.num_channels[-1]
    no_spatial = self.params.get("model_no_spatial", False)
    if no_spatial:
        self.num_unit = params["num_unit_type"]
    # Spatial case flattens over 25 positions (presumably a 5x5 map — confirm).
    in_features = channels if no_spatial else channels * 25

    # Despite its name, this head emits num_action outputs (one per action,
    # presumably Q-values — confirm in forward).
    self.linear_value = nn.Linear(in_features, num_action)
def _init(self, args):
    """Create the trunk, policy/value heads, Wt, optional gating and transition layers.

    Args:
        args: parsed arguments object. ``args.params`` must provide an int
            ``"num_action"`` and ``"num_unit_type"``. The optional boolean
            flags ``args.concat`` and ``args.gating`` select Wt's input width
            and whether the gating layers are built; a missing flag is
            treated as False.
    """
    params = args.params
    assert isinstance(params["num_action"], int), (
        "num_action has to be a number. action = " + str(params["num_action"]))
    self.params = params
    self.net = MiniRTSNet(args)
    if self.params.get("model_no_spatial", False):
        self.num_unit = params["num_unit_type"]
        # Per-position width num_unit_type + 7 (the meaning of the extra 7 is
        # not visible here — confirm against MiniRTSNet).
        linear_in_dim = params["num_unit_type"] + 7
    else:
        # Flattened over 25 positions (presumably a 5x5 map — confirm).
        linear_in_dim = (params["num_unit_type"] + 7) * 25
    self.na = params["num_action"]
    # NOTE: keep submodule creation order; it fixes init and state_dict order.
    self.linear_policy = nn.Linear(linear_in_dim, self.na)
    self.linear_value = nn.Linear(linear_in_dim, 1)
    self.relu = nn.LeakyReLU(0.1)
    # Read the flag from ``args`` (exactly like ``gating`` below) instead of
    # ``self.args``, which this method never assigns; missing means off.
    if getattr(args, "concat", False):
        # Twice-wide input (presumably two concatenated feature vectors —
        # confirm in forward).
        self.Wt = nn.Linear(linear_in_dim * 2, linear_in_dim)
    else:
        self.Wt = nn.Linear(linear_in_dim, linear_in_dim)
    self.sigmoid = nn.Sigmoid()
    # Adaptive gating: layers exist only when the ``gating`` flag is set.
    if getattr(args, "gating", False):
        self.gate_1 = nn.Linear(linear_in_dim, linear_in_dim)
        self.gate_2 = nn.Linear(linear_in_dim, 1)
    # transition1's input is linear_in_dim + num_action wide (presumably
    # features concatenated with an action encoding — confirm in forward).
    self.transition1 = nn.Linear(linear_in_dim + self.na, linear_in_dim)
    self.transition2 = nn.Linear(linear_in_dim, linear_in_dim)
    self.num_hidden_dim = linear_in_dim
    self.softmax = nn.Softmax()
def get_define_args():
    """Return the MiniRTSNet trunk's argument definitions unchanged.

    This model declares no options of its own.
    """
    trunk_args = MiniRTSNet.get_define_args()
    return trunk_args