예제 #1
0
    def _init(self, args):
        """Construct the trunk network and every head used by this model.

        Args:
            args: option namespace whose ``params`` mapping supplies
                ``num_action`` (must be an int), ``num_unit_type`` and the
                optional ``model_no_spatial`` flag.

        Raises:
            AssertionError: if ``params["num_action"]`` is not an int
                (the check is skipped under ``python -O``).
        """
        params = args.params
        assert isinstance(params["num_action"], int), (
            "num_action has to be a number. action = "
            + str(params["num_action"])
        )
        self.params = params

        # The trunk is built first; its last channel count sizes every head.
        self.net = MiniRTSNet(args)
        channels = self.net.num_channels[-1]

        spatial = not self.params.get("model_no_spatial", False)
        if not spatial:
            self.num_unit = params["num_unit_type"]
        # In spatial mode the head input width is 25 * channels
        # (NOTE(review): presumably a flattened 5x5 map — confirm against
        # MiniRTSNet).
        feat = channels * 25 if spatial else channels
        n_act = params["num_action"]

        # Keep this construction order: each nn.Linear draws its initial
        # weights from the global RNG when it is created.
        self.linear_policy = nn.Linear(feat, n_act)
        self.linear_value = nn.Linear(feat, 1)

        self.relu = nn.LeakyReLU(0.1)

        # Wt takes the feature vector widened by num_action entries;
        # Wt2 and Wt3 map features to features.
        self.Wt = nn.Linear(feat + n_act, feat)
        self.Wt2 = nn.Linear(feat, feat)
        self.Wt3 = nn.Linear(feat, feat)

        # NOTE(review): nn.Softmax() without ``dim`` uses PyTorch's deprecated
        # implicit-dim choice (dim=1 for 2-D input); pass dim explicitly once
        # the minimum supported PyTorch accepts the keyword.
        self.softmax = nn.Softmax()
예제 #2
0
    def _init(self, args):
        """Build the 2-D trunk, the action heads and the value head.

        Args:
            args: option namespace; ``args.params`` provides ``num_action``
                (must be an int), ``num_unit_type``, ``num_planes``,
                ``num_cmd_type``, ``map_x`` and ``map_y``.

        Raises:
            AssertionError: if ``params["num_action"]`` is not an int
                (the check is skipped under ``python -O``).
        """
        params = args.params
        assert isinstance(params["num_action"], int), (
            "num_action has to be a number. action = "
            + str(params["num_action"])
        )
        self.params = params
        # NOTE(review): output1d=False presumably keeps the trunk output as
        # (planes, map_x, map_y) maps, which the Conv2d heads below consume —
        # confirm against MiniRTSNet.
        self.net = MiniRTSNet(args, output1d=False)

        # Copy the sizes this model needs from params onto the instance.
        for attr, key in (
            ("na", "num_action"),
            ("num_unit", "num_unit_type"),
            ("num_planes", "num_planes"),
            ("num_cmd_type", "num_cmd_type"),
            ("mapx", "map_x"),
            ("mapy", "map_y"),
        ):
            setattr(self, attr, params[key])

        flat_dim = self.num_planes * self.mapx * self.mapy

        # Heads after the trunk: two single-channel location maps (unit
        # location, target location; 3x3 kernel with padding 1 keeps the map
        # size), then command type, build type and value from the flattened
        # trunk output. Creation order fixes the RNG-driven initialisation.
        self.unit_locs = nn.Conv2d(self.num_planes, 1, kernel_size=3, padding=1)
        self.target_locs = nn.Conv2d(self.num_planes, 1, kernel_size=3, padding=1)
        self.cmd_types = nn.Linear(flat_dim, self.num_cmd_type)
        self.build_types = nn.Linear(flat_dim, self.num_unit)
        self.value = nn.Linear(flat_dim, 1)

        self.relu = nn.LeakyReLU(0.1)
        # NOTE(review): implicit softmax dim is deprecated in newer PyTorch;
        # pass dim explicitly once the minimum supported version allows it.
        self.softmax = nn.Softmax()
예제 #3
0
파일: model_lstm.py 프로젝트: jma127/oldELF
 def get_define_args():
     """Return MiniRTSNet's option definitions extended with this model's.

     Appends ``ratio_skip_observation`` (paired with 0.0, presumably its
     default) and three ``store_true`` switches: ``concat``,
     ``enable_transition_model`` and ``gating``.
     """
     extra = [("ratio_skip_observation", 0.0)]
     # A fresh dict per switch, matching the original's separate literals.
     extra += [
         (name, {"action": "store_true"})
         for name in ("concat", "enable_transition_model", "gating")
     ]
     return MiniRTSNet.get_define_args() + extra
예제 #4
0
파일: model_lstm.py 프로젝트: GenjiWu/ELF
 def get_define_args():
     """Return the trunk's option list plus the LSTM-model-specific options.

     Added entries, in order: ``ratio_skip_observation`` (paired with 0.0)
     followed by the ``store_true`` flags ``concat``,
     ``enable_transition_model`` and ``gating``.
     """
     flags = ("concat", "enable_transition_model", "gating")
     own_args = [("ratio_skip_observation", 0.0)] + [
         (flag, dict(action="store_true")) for flag in flags
     ]
     base_args = MiniRTSNet.get_define_args()
     return base_args + own_args
예제 #5
0
    def _init(self, args):
        """Create the trunk plus linear policy and value heads.

        Args:
            args: option namespace; ``args.params`` provides ``num_action``
                (must be an int), ``num_unit_type`` and the optional
                ``model_no_spatial`` flag.

        Raises:
            AssertionError: if ``params["num_action"]`` is not an int
                (the check is skipped under ``python -O``).
        """
        params = args.params
        assert isinstance(params["num_action"], int), (
            "num_action has to be a number. action = "
            + str(params["num_action"])
        )
        self.params = params
        self.net = MiniRTSNet(args)

        n_unit_types = params["num_unit_type"]
        # NOTE(review): "+ 7" looks like a count of extra per-cell channels
        # beyond the unit types, and 25 a flattened 5x5 map — confirm both
        # against MiniRTSNet's output.
        per_cell = n_unit_types + 7
        no_spatial = self.params.get("model_no_spatial", False)
        if no_spatial:
            self.num_unit = n_unit_types
        in_dim = per_cell if no_spatial else per_cell * 25

        n_act = params["num_action"]
        self.linear_policy = nn.Linear(in_dim, n_act)
        self.linear_value = nn.Linear(in_dim, 1)
        # NOTE(review): implicit softmax dim is deprecated in newer PyTorch.
        self.softmax = nn.Softmax()
예제 #6
0
    def _init(self, args):
        """Create the trunk and a single linear head with num_action outputs.

        Args:
            args: option namespace; ``args.params`` provides ``num_action``
                (must be an int), ``num_unit_type`` and the optional
                ``model_no_spatial`` flag.

        Raises:
            AssertionError: if ``params["num_action"]`` is not an int
                (the check is skipped under ``python -O``).
        """
        params = args.params
        assert isinstance(params["num_action"], int), (
            "num_action has to be a number. action = "
            + str(params["num_action"])
        )
        self.params = params
        self.net = MiniRTSNet(args)
        width = self.net.num_channels[-1]

        no_spatial = self.params.get("model_no_spatial", False)
        if no_spatial:
            self.num_unit = params["num_unit_type"]
        # NOTE(review): 25 presumably counts cells of a flattened 5x5 map.
        in_dim = width if no_spatial else 25 * width

        # NOTE(review): named a value head but sized by num_action —
        # presumably one value per action (e.g. Q-values); confirm in forward().
        self.linear_value = nn.Linear(in_dim, params["num_action"])
예제 #7
0
파일: model_lstm.py 프로젝트: jma127/oldELF
    def _init(self, args):
        """Build the MiniRTS trunk and the heads of the LSTM/transition model.

        Creates, in this order: the trunk (``MiniRTSNet``), linear policy and
        value heads, a LeakyReLU, the mixing layer ``Wt``, a sigmoid, the
        optional gate layers, a two-layer transition model and a softmax.

        Args:
            args: option namespace. ``args.params`` is a mapping with at
                least ``num_action`` (int) and ``num_unit_type``;
                ``model_no_spatial`` is optional. The optional flags
                ``args.concat`` and ``args.gating`` select the ``Wt`` input
                width and whether the gate layers are built; each is treated
                as False when absent.

        Raises:
            AssertionError: if ``params["num_action"]`` is not an int
                (the check is skipped under ``python -O``).
        """
        params = args.params
        assert isinstance(params["num_action"], int), (
            "num_action has to be a number. action = "
            + str(params["num_action"])
        )
        self.params = params
        self.net = MiniRTSNet(args)

        # NOTE(review): "+ 7" presumably counts extra per-cell channels beyond
        # the unit types, and 25 the cells of a flattened 5x5 map — confirm
        # against MiniRTSNet's output.
        cell_dim = params["num_unit_type"] + 7
        if self.params.get("model_no_spatial", False):
            self.num_unit = params["num_unit_type"]
            linear_in_dim = cell_dim
        else:
            linear_in_dim = cell_dim * 25

        self.na = params["num_action"]

        self.linear_policy = nn.Linear(linear_in_dim, self.na)
        self.linear_value = nn.Linear(linear_in_dim, 1)

        self.relu = nn.LeakyReLU(0.1)

        # Read the flag from the ``args`` argument, exactly like ``gating``
        # below, instead of ``self.args``: this method then does not depend
        # on a base class having stored ``self.args``, and a missing flag
        # means "off" rather than an AttributeError.
        if getattr(args, "concat", False):
            # Input is twice the feature width.
            self.Wt = nn.Linear(linear_in_dim * 2, linear_in_dim)
        else:
            self.Wt = nn.Linear(linear_in_dim, linear_in_dim)

        self.sigmoid = nn.Sigmoid()

        # Adaptive gating: optional layers ending in a single output unit.
        if getattr(args, "gating", False):
            self.gate_1 = nn.Linear(linear_in_dim, linear_in_dim)
            self.gate_2 = nn.Linear(linear_in_dim, 1)

        # Transition model: input width is the feature width plus num_action
        # (presumably features concatenated with an action encoding).
        self.transition1 = nn.Linear(linear_in_dim + self.na, linear_in_dim)
        self.transition2 = nn.Linear(linear_in_dim, linear_in_dim)
        self.num_hidden_dim = linear_in_dim

        # NOTE(review): nn.Softmax() without ``dim`` uses PyTorch's deprecated
        # implicit-dim choice (dim=1 for 2-D input); pass dim explicitly once
        # the minimum supported PyTorch accepts the keyword.
        self.softmax = nn.Softmax()
예제 #8
0
 def get_define_args():
     """Return this model's option definitions.

     The model declares no options of its own; it forwards the trunk's list.
     """
     trunk_args = MiniRTSNet.get_define_args()
     return trunk_args
예제 #9
0
파일: model.py 프로젝트: GenjiWu/ELF
 def get_define_args():
     """Expose MiniRTSNet's option definitions unchanged as this model's."""
     defined = MiniRTSNet.get_define_args()
     return defined