Example #1
    def test_invalid_and_dim(self):
        action = ([3], [[14, 3]])
        with self.assertRaises(AssertionError):
            ActionProcesser(dim=5, rect_delta=7).process(*action)

        assert ActionProcesser(dim=15, rect_delta=2).process(*action) == \
               [FunctionCall(function=3, arguments=[[0], [12, 1], [14, 5]])]

        assert ActionProcesser(dim=40, rect_delta=2).process(*action) == \
               [FunctionCall(function=3, arguments=[[0], [12, 1], [16, 5]])]
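
The expected values above follow from clipping each rectangle corner to the valid screen range. A minimal sketch of that clamping (clamp_rect is an assumed helper for illustration, not part of ActionProcesser's API):

def clamp_rect(point, rect_delta, dim):
    # offset the point by rect_delta in both directions and clip to [0, dim - 1]
    y, x = point
    top_left = [max(0, y - rect_delta), max(0, x - rect_delta)]
    bottom_right = [min(dim - 1, y + rect_delta), min(dim - 1, x + rect_delta)]
    return top_left, bottom_right

# clamp_rect((14, 3), 2, 15) -> ([12, 1], [14, 5]), matching the dim=15 case above;
# with dim=40 nothing is clipped and the corners are [12, 1] and [16, 5].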
Example #2
    def process_logits(self, logits, obs, deterministic):
        """
        The SC2 environment requires special logic to mask out unused network outputs.
        :param logits:
        :return:
        """
        available_actions = obs['available_actions']
        actions, log_probs, entropies, head_masks = OrderedDict(), OrderedDict(), OrderedDict(), OrderedDict()
        headnames = logits.keys()

        for headname in headnames:
            actions[headname], log_probs[headname], entropies[headname] = super().process_logits(logits[headname], obs, deterministic)
            if headname == 'func_id':
                head_masks[headname] = torch.ones_like(entropies[headname])
            else:
                head_masks[headname] = torch.zeros_like(entropies[headname])

        function_calls = []
        # iterate over batch dimension
        for i in range(actions['func_id'].shape[0]):
            # force a no op if action is unavailable
            if actions['func_id'][i] not in available_actions[i]:
                function_calls.append(FunctionCall(0, []))
                continue

            # build the masks and the FunctionCalls in the same loop
            args = []
            func_id = actions['func_id'][i]
            required_heads = lookup_headnames_by_id(func_id)
            for headname in required_heads.keys():
                # toggle mask to 1 if the head is required
                head_masks[headname][i] = 1.

                # skip y's
                if '_y' in headname:
                    continue
                # if x, build the argument
                elif '_x' in headname:
                    args.append([actions[headname][i], actions[headname[:-2] + '_y'][i]])
                else:
                    args.append([actions[headname][i]])
            function_calls.append(FunctionCall(func_id, args))

        # apply masks to log_probs and entropies
        for headname in headnames:
            log_probs[headname] = log_probs[headname] * head_masks[headname]
            entropies[headname] = entropies[headname] * head_masks[headname]

        log_probs = torch.stack(tuple(v for v in log_probs.values()), dim=1)
        entropies = torch.stack(tuple(v for v in entropies.values()), dim=1)
        return function_calls, log_probs, entropies
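
The loop above assumes lookup_headnames_by_id returns the argument heads a function id requires, with spatial arguments split into '_x'/'_y' heads. A minimal sketch of such a lookup built on PySC2's FUNCTIONS table (the x/y split is the convention this example expects, not a PySC2 API):

from collections import OrderedDict
from pysc2.lib.actions import FUNCTIONS

def lookup_headnames_by_id(func_id):
    headnames = OrderedDict()
    for arg in FUNCTIONS[func_id].args:
        if len(arg.sizes) == 2:  # spatial argument: split into x and y heads
            headnames[arg.name + '_x'] = arg
            headnames[arg.name + '_y'] = arg
        else:
            headnames[arg.name] = arg
    return headnames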
Example #3
    def wrap_actions(self, actions):
        """
        根据action和action_args的矩阵,
        输出可以pysc2执行的Function call实例
        :param actions: action和action_args的矩阵
        :return: pysc2对应的action ID和pysc2可以执行的args
        """
        # 取出action和action的参数
        acts, args = actions[0], actions[1:]

        wrapped_actions = []
        for i, act in enumerate(acts):  # batch index i, action ID act
            act_args = []
            for arg_type in FUNCTIONS[act].args:  # look up the argument types for this action ID
                act_arg = [DEFAULT_ARGS[arg_type.name]]  # initialise with the default value from the config
                if arg_type.name in self.config.act_args:
                    act_arg = [args[self.config.arg_idx[arg_type.name]][i]]
                if is_spatial(
                        arg_type.name):  # spatial args, convert to coords
                    act_arg = [
                        act_arg[0] % self.config.sz,
                        act_arg[0] // self.config.sz
                    ]  # (y,x), fix for PySC2
                act_args.append(act_arg)
            wrapped_actions.append(FunctionCall(
                act, act_args))  # this is what PySC2 can actually execute; until now it was just raw data

        return wrapped_actions
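
For reference, a worked instance of the spatial conversion above (illustrative values): with a screen size sz = 32, a flattened index of 70 is split into its two coordinates by modulo and integer division.

sz = 32
spatial_id = 70
act_arg = [spatial_id % sz, spatial_id // sz]  # -> [6, 2]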
Example #4
    def test_simple(self):
        a = ActionProcesser(
            dim=40,
            rect_delta=5,
        )

        action_ids = [2, 331, 0, 1]
        coords = ((15, 4), (22, 33), (1, 1), (1, 1))

        actions = a.process(action_ids, coords)

        expected = [
            FunctionCall(function=2, arguments=[[0], (4, 15)]),
            FunctionCall(function=331, arguments=[[0], (33, 22)]),
            FunctionCall(function=0, arguments=[]),
            FunctionCall(function=1, arguments=[(1, 1)])
        ]

        assert actions == expected
Example #5
    def test_rectangle(self):
        dim = 48
        a = ActionProcesser(
            dim=dim,
            rect_delta=7,
        )

        action_ids = [3, 3, 3, 3]
        coords = ((15, 4), (22, 33), (1, 1), (45, 10))

        actions = a.process(action_ids, coords)

        expected = [
            FunctionCall(function=3, arguments=[[0], [8, 0], [22, 11]]),
            FunctionCall(function=3, arguments=[[0], [15, 26], [29, 40]]),
            FunctionCall(function=3, arguments=[[0], [0, 0], [8, 8]]),
            FunctionCall(function=3, arguments=[[0], [38, 3], [47, 17]])
        ]

        assert actions == expected
Example #6
    def run_model(self, model, model_number):
        obs = self.reset()
        print('Env {} reset completed'.format(self.pool_number))

        model.load_weights(self.save_dir + 'model_{}.h5'.format(model_number))
        print('Model {} loaded'.format(model_number))

        games_played = 0
        step = 0
        cumulative_score = 0

        while step < self.game_length:
            observations = obs[0].observation
            screen, minimap = self.translate_observations(observations)

            result = model.predict([np.array([screen]), np.array([minimap])])

            res_dict = self.result_to_dict(result)

            action = self.get_action(res_dict, observations)

            action_args = self.prepare_args(action, res_dict)

            # call action in a new step
            obs = self.step(actions=[FunctionCall(action.id, action_args)])
            step += 1

            if self.state == StepType.FIRST:
                games_played += 1
                current_reward = observations["score_cumulative"][0]
                print('Game {} has ended, score: {}'.format(
                    games_played, current_reward))
                cumulative_score += current_reward

        # TODO: avoid double-counting the score here
        games_played += 1
        observations = obs[0].observation
        current_reward = observations["score_cumulative"][0]
        print('Game {} has ended, score: {}'.format(games_played,
                                                    current_reward))
        cumulative_score += current_reward

        print("Model: {}, cumulative_score: {}".format(model_number,
                                                       cumulative_score))

        # TODO: it may be cleaner to use multiprocessing features instead of this save-to-file workaround
        # TODO: sometimes the score within one pool comes out identical; figure out why
        with open(self.save_dir + 'score_{}.txt'.format(model_number),
                  mode='w') as f:
            print(cumulative_score, file=f)
Example #7
 def to_sc2_action(self, action_id, action_args):
     # creates FunctionCall
     chosen_function = FUNCTIONS[action_id]
     f_type = chosen_function.function_type
     function_args = FUNCTION_TYPES[f_type]
     processed_args = []
     for arg in function_args:
         action_arg = action_args[arg]
         if is_spacial[arg]:
             x = action_arg % self.model_config.size
             y = action_arg // self.model_config.size
             action_arg = [x, y]
         processed_args.append(action_arg)
     return FunctionCall(chosen_function.id, processed_args)
Example #8
    def wrap_actions(self, actions):
        acts, args = actions[0], actions[1:]

        wrapped_actions = []
        for i, act in enumerate(acts): # for each action function in the batch
            act_args = []
            for arg_type in FUNCTIONS[act].args: # for each argument of this action function
                act_arg = [DEFAULT_ARGS[arg_type.name]] # initialise with the default value
                if arg_type.name in self.config.act_args:
                    act_arg = [args[self.config.arg_idx[arg_type.name]][i]] # take the i-th entry (i-th sample) of this argument in the batch
                if is_spatial(arg_type.name):  # spatial args, convert to coords
                    act_arg = [act_arg[0] % self.config.sz, act_arg[0] // self.config.sz]  # (y,x), fix for PySC2
                act_args.append(act_arg)
            wrapped_actions.append(FunctionCall(act, act_args))

        return wrapped_actions
Example #9
def main():
    map_dict = dict()

    map_dict['CollectMineralShards'] = CollectMineralShards
    map_dict['DefeatRoaches'] = DefeatRoaches
    map_dict['DefeatZerglingsAndBanelings'] = DefeatZerglingsAndBanelings
    map_dict['CollectMineralsAndGas'] = CollectMineralsAndGas
    map_dict['BuildMarines'] = BuildMarines
    agent = map_dict[args.map]()
    env = make_sc2env(
        map_name=args.map,
        battle_net_map=False,
        players=[sc2_env.Agent(sc2_env.Race.terran)],
        agent_interface_format=sc2_env.parse_agent_interface_format(
            feature_screen=32,
            feature_minimap=32,
            rgb_screen=None,
            rgb_minimap=None,
            action_space=None,
            use_feature_units=True,
            use_raw_units=False),
        step_mul=8,
        game_steps_per_episode=None,
        disable_fog=False,
        visualize=True)
    observation_spec = env.observation_spec()
    action_spec = env.action_spec()
    if args.map != 'CollectMineralShards' and args.map != 'DefeatRoaches':
        agent.setup(observation_spec[0], action_spec[0])
    agent.reset()

    timesteps = env.reset()
    episodes = 0
    sum_score = 0
    while True:

        a_0, a_1 = agent.step(timesteps[0])

        actions = FunctionCall(a_0, a_1)
        timesteps = env.step([actions])
        if timesteps[0].last():
            i = timesteps[0]
            score = i.observation['score_cumulative'][0]
            sum_score += score
            episodes += 1

            print("episode %d: score = %f" % (episodes, score))
Example #10
    def _actions_to_sc2(self, actions):
        def convert_arg(value, spec):
            if len(spec.sizes) == 2:
                value = np.unravel_index(value, spec.sizes)
                value = np.flip(value)
                return list(value)
            else:
                return [value]

        function = self._func_ids[actions['function_id']]
        args = [
            convert_arg(actions[arg.name].item(),
                        self.spec.action_spec[arg.name])
            for arg in FUNCTIONS[function].args
        ]

        return FunctionCall(function, args)
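
A quick check of convert_arg's spatial branch (illustrative values): with an argument spec whose sizes are (64, 64), a flat index of 131 unravels to (row, col) = (2, 3), and the flip produces the reversed ordering that is handed to FunctionCall.

import numpy as np

value = np.unravel_index(131, (64, 64))  # -> (2, 3)
value = list(np.flip(value))             # -> [3, 2]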
Example #11
    def _wrap_action(self, action):
        func_id = action['func_id'].item()
        required_heads = self._func_id_to_headnames[func_id]
        args = []

        for headname in required_heads.keys():
            if '_y' in headname:
                continue
            elif '_x' in headname:
                args.append([
                    action[headname].item(),
                    action[headname[:-2] + '_y'].item()
                ])
            else:
                args.append([action[headname].item()])

        return [FunctionCall(func_id, args)]
Example #12
def actions_to_pysc2(actions, size):
    """Convert agent action representation to FunctionCall representation."""
    height, width = size
    fn_id, arg_ids = actions
    actions_list = []
    for n in range(fn_id.shape[0]):
        a_0 = fn_id[n]
        a_l = []
        for arg_type in FUNCTIONS._func_list[a_0].args:
            arg_id = arg_ids[arg_type][n]
            if is_spatial_action[arg_type]:
                arg = [arg_id % width, arg_id // height]
            else:
                arg = [arg_id]
            a_l.append(arg)
        action = FunctionCall(a_0, a_l)
        actions_list.append(action)
    return actions_list
Example #13
def actions_to_pysc2(fn_id, arg_ids, size):
    height, width = size
    actions_list = []

    a_0 = int(fn_id)
    a_l = []
    for arg_type in FUNCTIONS._func_list[a_0].args:
        arg_id = int(arg_ids[arg_type])
        if is_spatial_action[arg_type]:
            arg = [arg_id % width, arg_id // height]
        else:
            arg = [arg_id]

        a_l.append(arg)

    action = FunctionCall(a_0, a_l)
    actions_list.append(action)

    return actions_list
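
A minimal usage sketch (illustrative call): function id 0 is PySC2's no_op, which takes no arguments, so the argument loop is skipped and a single empty call comes back.

calls = actions_to_pysc2(0, {}, (64, 64))
# -> [FunctionCall(function=0, arguments=[])]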
Example #14
 def reverse(self, actions):
     action_indexes = self._config._action_indexes
     action_arg_table = self._config._action_args_index_table
     size = self._config._size
     function_calls = []
     for action in zip(*actions):
         act_id = action_indexes[action[0]]
         act_args = []
         for arg_type in FUNCTIONS[act_id].args:
             arg_index = action_arg_table[arg_type.name]
             arg_value = action[1:][arg_index]
             if arg_type.name in SPATIAL_ARG_TYPES:
                 act_args.append(
                     [arg_value % size,
                      arg_value // size])
             else:
                 act_args.append([arg_value])
         function_call = FunctionCall(act_id, act_args)
         function_calls.append(function_call)
     return function_calls
Example #15
    def functioncall_action(self, actions, size):
        height, width = size
        fn_id, arg_ids = actions
        fn_id = fn_id.numpy().tolist()
        actions_list = []
        for n in range(len(fn_id)):
            a_0 = fn_id[n]
            a_l = []
            for arg_type in FUNCTIONS._func_list[a_0].args:
                arg_id = arg_ids[arg_type][n].detach().numpy().squeeze().tolist()
                if is_spatial_action[arg_type]:
                    arg = [arg_id % width, arg_id // height]
                else:
                    arg = [arg_id]
                a_l.append(arg)
            action = FunctionCall(a_0, a_l)

            actions_list.append(action)
        return actions_list
Example #16
 def localactions_to_pysc2(self, actions, extractor):
     """Convert agent action representation to FunctionCall representation."""
     height, width = (24, 24)
     fn_id, world_pt = actions
     actions_list = []
     reward_addon = 0
     a_0 = fn_id.item()
     a_0 = localaction_table[a_0]
     # print(extractor.raw_unit.x,extractor.raw_unit.y,world_pt.item() // width,world_pt.item() % width)
     x = extractor.raw_unit.x - (world_pt.item() // width - 11)
     y = extractor.raw_unit.y - (world_pt.item() % height - 11)
     if x < 0:
         x = 0
     if y < 0:
         y = 0
     arg = [x, y]
     a_0, args, r = extractor.get_action(
         a_0, [[0], [extractor.unit_tag], arg])
     reward_addon += r
     action = FunctionCall(a_0, args)
     # actions_list.append(action)
     return reward_addon, action
Example #17
def action_to_pysc2(agent_action):

    [[base_action, args, spatial_args]] = agent_action

    base_action_func = FUNCTIONS._func_list[base_action]
    arg_types = base_action_func.args
    arg_ids = [arg_types[i].id for i in range(len(arg_types))]

    arg_inputs = []
    spatial_arg_inputs = []
    for i in range(len(arg_ids)):
        id = arg_ids[i]
        if is_spatial_arg(id):
            spatial_arg_inputs.append(spatial_args[id])
        elif TYPES[id].values is not None:
            arg_inputs.append([TYPES[id].values(args[id - 3])])
        else:
            arg_inputs.append([args[id - 3]])

    function = FunctionCall(base_action_func.id,
                            arg_inputs + spatial_arg_inputs)
    return function
Example #18
 def wrap_actions(self, actions):
     pol_mask = [torch.zeros((self.envs.num_envs)).float().to(self.device) \
                 for _ in range(1 + len(ARG_TYPES))]
     pol_mask[0].fill_(1.0)
     acts, args = actions[0], actions[1:]
     wrapped_actions = []
     for i, act in enumerate(acts):
         fn = TERRAN_FUNCTIONS[act]
         act_args = []
         for arg_type in fn.args:
             act_arg = [args[self.config.arg_idx[arg_type.name]][i]]
             pol_mask[self.config.arg_idx[arg_type.name] + 1][i] = 1.
             if arg_type.name == 'queued':
                 act_arg = [False]
             if is_spatial(
                     arg_type.name):  # spatial args, convert to coords
                 act_arg = [
                     act_arg[0] % self.config.sz,
                     act_arg[0] // self.config.sz
                 ]
             act_args.append(act_arg)
         wrapped_actions.append(FunctionCall(fn.id, act_args))
     return wrapped_actions, pol_mask
Example #19
    def space_to_function(self, space) -> FunctionCall:
        # space is in the action_space defined in 'make_action_space()'.

        # first extract the action id
        action_id = space[0]
        # then determine the required arguments for this action
        chosen_args = []
        f = actions.FUNCTIONS[self._function_names[action_id]]
        for arg in f.args:
            option_space, option_type = self._option_space_and_type(arg.id)
            if option_type == 'point':
                chosen_args.append(space[1])
            elif option_type == 'enum':
                chosen_args.append(self._get_fixed_arg_value(arg.id))
            elif option_type == 'scalar':
                logging.warning('SCALAR NOT YET SUPPORTED!')
                chosen_args.append(space[1][0])
            else:
                # Note: screen2 would land here
                raise ValueError('unsupported argtype passed')

        # finally return the FunctionCall that can be executed in PySC2
        return FunctionCall(self._function_ids[action_id], chosen_args)
Example #20
def RuleBase(net):
    map_name = 'CollectMineralShards'
    total_episodes = 100
    total_updates = -1
    sum_score = 0
    n_steps = 8
    learning_rate = 1e-4
    optimizer = optim.Adam(net.parameters(), learning_rate, weight_decay=0.01)
    env = make_sc2env(
        map_name=map_name,
        battle_net_map=False,
        players=[sc2_env.Agent(sc2_env.Race.terran)],
        agent_interface_format=sc2_env.parse_agent_interface_format(
            feature_screen=32,
            feature_minimap=32,
            rgb_screen=None,
            rgb_minimap=None,
            action_space=None,
            use_feature_units=False,
            use_raw_units=False),
        step_mul=8,
        game_steps_per_episode=None,
        disable_fog=False,
        visualize=True)

    processor = Preprocessor(env.observation_spec()[0])
    observation_spec = env.observation_spec()
    action_spec = env.action_spec()
    agent = CollectMineralShards()
    episodes = 0
    agent.reset()
    timesteps = env.reset()
    while True:
        fn_ids = []
        args_ids = []
        observations = []
        for step in range(n_steps):
            a_0, a_1 = agent.step(timesteps[0])
            obs = processor.preprocess_obs(timesteps)
            observations.append(obs)
            actions = FunctionCall(a_0, a_1)
            fn_id = torch.LongTensor([a_0]).cuda()
            args_id = {}
            if a_0 == 7:
                for type in ACTION_TYPES:
                    if type.name == 'select_add':
                        args_id[type] = torch.LongTensor([a_1[0][0]]).cuda()
                    else:
                        args_id[type] = torch.LongTensor([-1]).cuda()
            elif a_0 == 331:
                for type in ACTION_TYPES:
                    if type.name == 'queued':
                        args_id[type] = torch.LongTensor([a_1[0][0]]).cuda()
                    elif type.name == 'screen':

                        args_id[type] = torch.LongTensor(
                            [a_1[1][1] * 32 + a_1[1][0]]).cuda()
                    else:
                        args_id[type] = torch.LongTensor([-1]).cuda()
            action = (fn_id, args_id)
            fn_ids.append(fn_id)
            args_ids.append(args_id)
            timesteps = env.step([actions])
            if timesteps[0].last():
                i = timesteps[0]
                score = i.observation['score_cumulative'][0]
                sum_score += score
                episodes += 1
                if episodes % 50 == 0:
                    torch.save(net.state_dict(),
                               './save/episode2' + str(episodes) + str('.pkl'))
                print("episode %d: score = %f" % (episodes, score))

        observations = flatten_first_dims_dict(
            stack_ndarray_dicts(observations))

        train_fn_ids = torch.cat(fn_ids)
        train_arg_ids = {}

        for k in args_ids[0].keys():
            temp = []
            temp = [d[k] for d in args_ids]

            train_arg_ids[k] = torch.cat(temp, dim=0)

        screen = torch.FloatTensor(observations['screen']).cuda()
        minimap = torch.FloatTensor(observations['minimap']).cuda()
        flat = torch.FloatTensor(observations['flat']).cuda()
        policy, _ = net(screen, minimap, flat)

        fn_pi, args_pi = policy
        available_actions = torch.FloatTensor(
            observations['available_actions']).cuda()
        function_pi = available_actions * fn_pi
        function_pi /= torch.sum(function_pi, dim=1, keepdim=True)
        Loss = nn.CrossEntropyLoss(reduction='none')
        loss = Loss(function_pi, train_fn_ids)

        for type in train_arg_ids.keys():
            id = train_arg_ids[type]
            pi = args_pi[type]
            arg_loss_list = []
            for i, p in zip(id, pi):
                if i == -1:
                    temp = torch.zeros((1)).cuda()
                else:
                    a = torch.LongTensor([i]).cuda()
                    b = torch.unsqueeze(p, dim=0).cuda()
                    temp = Loss(b, a)
                arg_loss_list.append(temp)

            arg_loss = torch.cat(arg_loss_list)
            loss += arg_loss
        loss = loss.mean()
        print(loss)
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(net.parameters(), 1.0)
        optimizer.step()

        if episodes >= total_episodes:
            break
    torch.save(net.state_dict(), './save/episode1' + str('.pkl'))
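
The -1 sentinel above marks argument heads that the scripted action does not use, and their loss terms are zeroed one sample at a time. The same masking idea can be written compactly with ignore_index (illustrative tensors; unlike the loop above this operates on raw logits rather than the normalized policy):

import torch
import torch.nn.functional as F

arg_logits = torch.randn(3, 32)         # per-sample outputs for one argument head
arg_labels = torch.tensor([5, -1, 17])  # -1 = argument unused for that sample
arg_loss = F.cross_entropy(arg_logits, arg_labels, ignore_index=-1, reduction='none')
# arg_loss[1] == 0, matching the torch.zeros((1)) branch in the loop above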
Example #21
def RuleBase6(net, map, process):
    map_name = 'CollectMineralsAndGas'
    value_coef = 0.01
    total_episodes = 20
    total_updates = -1
    sum_score = 0
    n_steps = 8
    learning_rate = 1e-5
    optimizer = optim.Adam(net.parameters(), learning_rate, weight_decay=0.01)
    env = make_sc2env(
        map_name=map_name,
        battle_net_map=False,
        players=[sc2_env.Agent(sc2_env.Race.terran)],
        agent_interface_format=sc2_env.parse_agent_interface_format(
            feature_screen=32,
            feature_minimap=32,
            rgb_screen=None,
            rgb_minimap=None,
            action_space=None,
            use_feature_units=True,
            use_raw_units=False),
        step_mul=8,
        game_steps_per_episode=None,
        disable_fog=False,
        visualize=True)

    processor = Preprocessor(env.observation_spec()[0], map, process)
    observation_spec = env.observation_spec()
    action_spec = env.action_spec()
    agent = CollectMineralsAndGas()
    agent.setup(observation_spec[0], action_spec[0])
    episodes = 0
    agent.reset()
    timesteps = env.reset()
    while True:
        fn_ids = []
        args_ids = []
        observations = []
        rewards = []
        dones = []
        for step in range(n_steps):
            a_0, a_1 = agent.step(timesteps[0])
            obs = processor.preprocess_obs(timesteps)
            observations.append(obs)
            actions = FunctionCall(a_0, a_1)
            fn_id = torch.LongTensor([a_0]).cuda()
            args_id = {}
            if a_0 == 2:
                for type in ACTION_TYPES:
                    if type.name == 'select_point_act':
                        args_id[type] = torch.LongTensor([a_1[0][0]]).cuda()
                    elif type.name == 'screen':
                        args_id[type] = torch.LongTensor(
                            [a_1[1][1] * 32 + a_1[1][0]]).cuda()
                    else:
                        args_id[type] = torch.LongTensor([-1]).cuda()
            elif a_0 == 91 or a_0 == 44 or a_0 == 264:
                for type in ACTION_TYPES:
                    if type.name == 'queued':
                        args_id[type] = torch.LongTensor([a_1[0][0]]).cuda()
                    elif type.name == 'screen':

                        args_id[type] = torch.LongTensor(
                            [a_1[1][1] * 32 + a_1[1][0]]).cuda()
                    else:
                        args_id[type] = torch.LongTensor([-1]).cuda()
            elif a_0 == 490:
                for type in ACTION_TYPES:
                    if type.name == 'queued':
                        args_id[type] = torch.LongTensor([a_1[0][0]]).cuda()
                    else:
                        args_id[type] = torch.LongTensor([-1]).cuda()
            elif a_0 == 0:
                for type in ACTION_TYPES:
                    args_id[type] = torch.LongTensor([-1]).cuda()
            action = (fn_id, args_id)
            fn_ids.append(fn_id)
            args_ids.append(args_id)
            timesteps = env.step([actions])
            rewards.append(torch.FloatTensor([timesteps[0].reward]).cuda())
            dones.append(torch.IntTensor([timesteps[0].last()]).cuda())

            if timesteps[0].last():
                i = timesteps[0]
                score = i.observation['score_cumulative'][0]
                sum_score += score
                episodes += 1
                if episodes % 1 == 0:
                    torch.save(net.state_dict(),
                               './save/game6_' + str(episodes) + str('.pkl'))
                print("episode %d: score = %f" % (episodes, score))
            # obs = processor.preprocess_obs(timesteps)
            # observations.append(obs)
        rewards = torch.cat(rewards)
        dones = torch.cat(dones)
        with torch.no_grad():
            obs = processor.preprocess_obs(timesteps)
            screen = torch.FloatTensor(obs['screen']).cuda()
            minimap = torch.FloatTensor(obs['minimap']).cuda()
            flat = torch.FloatTensor(obs['flat']).cuda()
            _, next_value = net(screen, minimap, flat)

        observations = flatten_first_dims_dict(
            stack_ndarray_dicts(observations))

        train_fn_ids = torch.cat(fn_ids)
        train_arg_ids = {}

        for k in args_ids[0].keys():
            temp = []
            temp = [d[k] for d in args_ids]

            train_arg_ids[k] = torch.cat(temp, dim=0)

        screen = torch.FloatTensor(observations['screen']).cuda()
        minimap = torch.FloatTensor(observations['minimap']).cuda()
        flat = torch.FloatTensor(observations['flat']).cuda()
        policy, value = net(screen, minimap, flat)

        returns = torch.zeros((rewards.shape[0] + 1, ), dtype=float)
        returns[-1] = next_value
        for i in reversed(range(rewards.shape[0])):
            next_rewards = 0.999 * returns[i + 1] * (1 - dones[i])
            returns[i] = rewards[i] + next_rewards
        returns = returns[:-1].cuda()

        fn_pi, args_pi = policy
        available_actions = torch.FloatTensor(
            observations['available_actions']).cuda()
        function_pi = available_actions * fn_pi
        function_pi /= torch.sum(function_pi, dim=1, keepdim=True)
        Loss = nn.CrossEntropyLoss(reduction='none')
        function_pi = torch.clamp(function_pi, 1e-4, 1 - (1e-4))
        policy_loss = Loss(function_pi, train_fn_ids)

        for type in train_arg_ids.keys():
            id = train_arg_ids[type]
            pi = args_pi[type]
            arg_loss_list = []
            for i, p in zip(id, pi):
                if i == -1:
                    temp = torch.zeros((1)).cuda()
                else:
                    a = torch.LongTensor([i]).cuda()
                    b = torch.unsqueeze(p, dim=0).cuda()
                    b = torch.clamp(b, 1e-4, 1 - (1e-4))
                    temp = Loss(b, a)
                arg_loss_list.append(temp)

            arg_loss = torch.cat(arg_loss_list)
            policy_loss += arg_loss
        policy_loss = policy_loss.mean()
        value_loss = (returns - value).pow(2).mean()
        print(policy_loss, value_loss)
        loss = policy_loss + value_coef * value_loss
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(net.parameters(), 1.0)
        optimizer.step()

        if episodes >= total_episodes:
            break
    torch.save(net.state_dict(), './save/game6_final' + str('.pkl'))