Example #2
    def __init__(self, envs):
        self.value_loss_coefficient = 0.5
        self.entropy_coefficient = 0.05
        self.learning_rate = 1e-4
        self.envs = envs
        self.processor = Preprocessor(self.envs.observation_spec()[0])
        self.sum_score = 0
        self.last_score = 0
        self.n_steps = 8
        self.gamma = 0.99
        self.sum_episode = 0
        self.total_updates = -1
        self.net = CNN().cuda()
        self.optimizer = optim.Adam(self.net.parameters(),
                                    self.learning_rate,
                                    weight_decay=0.01)
Example #4
class PPO():

    def __init__(self, envs):
        self.value_loss_coefficient = 0.5
        self.entropy_coefficient = 0.05
        self.learning_rate = 1e-4
        self.envs = envs
        self.env_num = 8
        self.processor = Preprocessor(self.envs.observation_spec()[0])
        self.sum_score = 0
        self.n_steps = 512
        self.gamma = 0.999
        self.clip = 0.27
        self.sum_episode = 0
        self.total_updates = -1
        self.net = CNN().cuda()
        self.old_net = copy.deepcopy(self.net)
        self.old_net.cuda()
        self.epoch = 4
        self.batch_size = 8
        self.optimizer = optim.Adam(
            self.net.parameters(), self.learning_rate, weight_decay=0.01)

    def reset(self):
        self.obs_start = self.envs.reset()
        self.last_obs = self.processor.preprocess_obs(self.obs_start)

    def grad_step(self, observation):
        screen = torch.FloatTensor(observation['screen']).cuda()
        minimap = torch.FloatTensor(observation['minimap']).cuda()
        flat = torch.FloatTensor(observation['flat']).cuda()
        policy, value = self.net(screen, minimap, flat)
        return policy, value

    def step(self, observation):
        screen = torch.FloatTensor(observation['screen']).cuda()
        minimap = torch.FloatTensor(observation['minimap']).cuda()
        flat = torch.FloatTensor(observation['flat']).cuda()
        with torch.no_grad():
            policy, value = self.net(screen, minimap, flat)
        return policy, value

    def old_step(self, observation):
        screen = torch.FloatTensor(observation['screen']).cuda()
        minimap = torch.FloatTensor(observation['minimap']).cuda()
        flat = torch.FloatTensor(observation['flat']).cuda()
        with torch.no_grad():
            policy, value = self.old_net(screen, minimap, flat)
        return policy, value

    def select_actions(self, policy, last_obs):
        available_actions = last_obs['available_actions']

        def sample(prob):
            actions = Categorical(prob).sample()
            return actions
        function_pi, args_pi = policy
        available_actions = torch.FloatTensor(available_actions)
        function_pi = available_actions*function_pi.cpu()
        function_pi /= torch.sum(function_pi, dim=1, keepdim=True)
        try:
            function_sample = sample(function_pi)
        except Exception:
            # masking left no valid probability mass; signal the caller to abort this rollout
            return 0
        args_sample = dict()
        for type, pi in args_pi.items():
            if type.name == 'queued':
                args_sample[type] = torch.zeros((self.env_num,), dtype=int)
            else:
                args_sample[type] = sample(pi).cpu()
        return function_sample, args_sample

    def mask_unused_action(self, actions):
        fn_id, arg_ids = actions
        for n in range(fn_id.shape[0]):
            a_0 = fn_id[n]
            unused_types = set(ACTION_TYPES) - \
                set(FUNCTIONS._func_list[a_0].args)
            for arg_type in unused_types:
                arg_ids[arg_type][n] = -1
        return (fn_id, arg_ids)

    def functioncall_action(self, actions, size):
        height, width = size
        fn_id, arg_ids = actions
        fn_id = fn_id.numpy().tolist()
        actions_list = []
        for n in range(len(fn_id)):
            a_0 = fn_id[n]
            a_l = []
            for arg_type in FUNCTIONS._func_list[a_0].args:
                arg_id = arg_ids[arg_type][n].detach().numpy().squeeze().tolist()
                if is_spatial_action[arg_type]:
                    # flat id is y * width + x; width equals height here (square screen)
                    arg = [arg_id % width, arg_id // width]
                else:
                    arg = [arg_id]
                a_l.append(arg)
            action = FunctionCall(a_0, a_l)

            actions_list.append(action)
        return actions_list

    def get_value(self, observation):
        screen = torch.FloatTensor(observation['screen']).cuda()
        minimap = torch.FloatTensor(observation['minimap']).cuda()
        flat = torch.FloatTensor(observation['flat']).cuda()
        with torch.no_grad():
            _, value = self.net(screen, minimap, flat)
        return value

    def train(self):
        obs_raw = self.obs_start
        shape = (self.n_steps, self.envs.n_envs)
        sample_values = np.zeros(shape, dtype=np.float32)
        sample_observation = []
        sample_rewards = np.zeros(shape, dtype=np.float32)
        sample_actions = []
        sample_dones = np.zeros(shape, dtype=np.float32)
        scores = []
        last_obs = self.last_obs
        for step in range(self.n_steps):
            policy, value = self.step(last_obs)

            actions = self.select_actions(policy, last_obs)
            if actions == 0:
                self.sum_episode = 7
                self.sum_score = 0
                return
            actions = self.mask_unused_action(actions)

            size = last_obs['screen'].shape[2:4]
            sample_values[step, :] = value.cpu()
            sample_observation.append(last_obs)
            sample_actions.append(actions)
            pysc2_action = self.functioncall_action(actions, size)

            obs_raw = self.envs.step(pysc2_action)
            # print("0:",pysc2_action[0].function)
            # print("1:",pysc2_action[1].function)

            last_obs = self.processor.preprocess_obs(obs_raw)
            sample_rewards[step, :] = [
                i.reward for i in obs_raw]
            sample_dones[step, :] = [i.last() for i in obs_raw]

            for i in obs_raw:
                if i.last():
                    score = i.observation['score_cumulative'][0]
                    self.sum_score += score
                    self.sum_episode += 1
                    print("episode %d: score = %f" % (self.sum_episode, score))
                    # if self.sum_episode % 10 == 0:
                    #     torch.save(self.net.state_dict(), './save/episode' +
                    #                str(self.sum_episode)+'_score'+str(score)+'.pkl')

        self.last_obs = last_obs
        next_value = self.get_value(last_obs).cpu()

        returns = np.zeros(
            [sample_rewards.shape[0]+1, sample_rewards.shape[1]])
        returns[-1, :] = next_value
        for i in reversed(range(sample_rewards.shape[0])):
            next_rewards = self.gamma*returns[i+1, :]*(1-sample_dones[i, :])
            returns[i, :] = sample_rewards[i, :]+next_rewards
        returns = returns[:-1, :]
        advantages = returns-sample_values
        self.old_net.load_state_dict(self.net.state_dict())
        actions = stack_and_flatten_actions(sample_actions)
        observation = flatten_first_dims_dict(
            stack_ndarray_dicts(sample_observation))
        returns = flatten_first_dims(returns)
        advantages = flatten_first_dims(advantages)
        self.learn(observation, actions, returns, advantages)

    def learn(self, observation, actions, returns, advantages):
        # one epoch = one full pass over the rollout in shuffled minibatches
        indices = np.arange(returns.shape[0])
        minibatch_size = returns.shape[0] // self.batch_size
        screen = observation['screen']
        flat = observation['flat']
        minimap = observation['minimap']
        a_actions = observation['available_actions']
        args_id = actions[1]
        for _ in range(self.epoch):
            np.random.shuffle(indices)
            for i in range(0, returns.shape[0], minibatch_size):
                batch_idx = indices[i:i + minibatch_size]
                batch_screen = screen[batch_idx]
                batch_minimap = minimap[batch_idx]
                batch_flat = flat[batch_idx]
                batch_a_actions = a_actions[batch_idx]
                batch_observation = {'screen': batch_screen,
                                     'minimap': batch_minimap,
                                     'flat': batch_flat,
                                     'available_actions': batch_a_actions}
                batch_advantages = advantages[batch_idx]
                batch_fn_id = actions[0][batch_idx]

                batch_args_id = {k: v[batch_idx] for k, v in args_id.items()}
                batch_actions = (batch_fn_id, batch_args_id)
                batch_returns = returns[batch_idx]

                batch_advantages = torch.FloatTensor(batch_advantages).cuda()
                batch_returns = torch.FloatTensor(batch_returns).cuda()
                batch_advantages = (batch_advantages - batch_advantages.mean()) / \
                    (batch_advantages.std() + 1e-8)

                policy, batch_value = self.grad_step(batch_observation)
                log_probs = compute_policy_log_probs(
                    batch_observation['available_actions'], policy,
                    batch_actions).squeeze()

                old_policy, _ = self.old_step(batch_observation)
                old_log_probs = compute_policy_log_probs(
                    batch_observation['available_actions'], old_policy,
                    batch_actions).squeeze().detach()
                # clipped surrogate objective
                ratio = torch.exp(log_probs - old_log_probs)
                surr1 = ratio * batch_advantages
                surr2 = torch.clamp(ratio, 1 - self.clip,
                                    1 + self.clip) * batch_advantages
                policy_loss = -torch.min(surr1, surr2).mean()

                value_loss = (batch_returns - batch_value).pow(2).mean()
                entropy_loss = compute_policy_entropy(
                    batch_observation['available_actions'], policy,
                    batch_actions)
                loss = policy_loss + value_loss * self.value_loss_coefficient + \
                    entropy_loss * self.entropy_coefficient
                self.optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.net.parameters(), 1.0)
                self.optimizer.step()
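
A driver loop is not part of the listing; the sketch below is an assumption about how the PPO agent above (and the A2C class that follows, which exposes the same reset()/train() interface) might be run. make_envs() is a hypothetical helper for building the vectorized SC2 environments, and the update count is arbitrary.

if __name__ == '__main__':
    envs = make_envs()                    # hypothetical helper returning the vectorized SC2 envs
    agent = PPO(envs)
    agent.reset()
    for update in range(1000):            # arbitrary number of rollout/update cycles
        agent.train()                     # collects n_steps=512 per env, then runs 4 PPO epochs
    torch.save(agent.net.state_dict(), './save/ppo_final.pkl')
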
class A2C():
    def __init__(self, envs, args):
        self.value_loss_coefficient = args.value_loss_weight
        self.entropy_coefficient = args.entropy_weight
        self.learning_rate = args.lr
        self.envs = envs

        self.map = args.map
        self.env_num = args.envs
        self.save = args.save_eposides
        self.save_dir = args.save_dir
        self.processor = Preprocessor(self.envs.observation_spec()[0],
                                      self.map, args.process_screen)
        self.sum_score = 0
        self.n_steps = 8
        self.gamma = 0.999
        self.sum_episode = 0
        self.total_updates = -1

        if args.process_screen:
            self.net = CNN(348, 1985).cuda()
        else:
            self.net = CNN().cuda()
        self.optimizer = optim.Adam(self.net.parameters(),
                                    self.learning_rate,
                                    weight_decay=0.01)

    def reset(self):
        self.obs_start = self.envs.reset()
        self.last_obs = self.processor.preprocess_obs(self.obs_start)

    def grad_step(self, observation):
        screen = torch.FloatTensor(observation['screen']).cuda()
        minimap = torch.FloatTensor(observation['minimap']).cuda()
        flat = torch.FloatTensor(observation['flat']).cuda()
        policy, value = self.net(screen, minimap, flat)
        return policy, value

    def step(self, observation):
        screen = torch.FloatTensor(observation['screen']).cuda()
        minimap = torch.FloatTensor(observation['minimap']).cuda()
        flat = torch.FloatTensor(observation['flat']).cuda()
        with torch.no_grad():
            policy, value = self.net(screen, minimap, flat)
        return policy, value

    def select_actions(self, policy, last_obs):
        available_actions = last_obs['available_actions']

        def sample(prob):
            actions = Categorical(prob).sample()
            return actions

        function_pi, args_pi = policy
        available_actions = torch.FloatTensor(available_actions)
        function_pi = available_actions * function_pi.cpu()
        function_pi /= torch.sum(function_pi, dim=1, keepdim=True)
        try:
            function_sample = sample(function_pi)
        except Exception:
            # masking left no valid probability mass; signal the caller to abort this rollout
            return 0
        args_sample = dict()
        for type, pi in args_pi.items():
            if type.name == 'queued':
                args_sample[type] = torch.zeros((self.env_num, ), dtype=int)
            else:
                args_sample[type] = sample(pi).cpu()
        return function_sample, args_sample

    def determined_actions(self, policy, last_obs):
        available_actions = last_obs['available_actions']

        def sample(prob):
            actions = torch.argmax(prob, dim=1)
            return actions

        function_pi, args_pi = policy
        available_actions = torch.FloatTensor(available_actions)
        function_pi = available_actions * function_pi.cpu()
        function_pi /= torch.sum(function_pi, dim=1, keepdim=True)
        try:
            function_sample = sample(function_pi)
        except Exception:
            # masking left no valid probability mass; signal the caller to abort this rollout
            return 0
        args_sample = dict()
        for type, pi in args_pi.items():
            if type.name == 'queued':
                args_sample[type] = torch.zeros((self.env_num, ), dtype=int)
            else:
                args_sample[type] = sample(pi).cpu()
        return function_sample, args_sample

    def mask_unused_action(self, actions):
        fn_id, arg_ids = actions
        for n in range(fn_id.shape[0]):
            a_0 = fn_id[n]
            unused_types = set(ACTION_TYPES) - \
                set(FUNCTIONS._func_list[a_0].args)
            for arg_type in unused_types:
                arg_ids[arg_type][n] = -1
        return (fn_id, arg_ids)

    def functioncall_action(self, actions, size):
        height, width = size
        fn_id, arg_ids = actions
        fn_id = fn_id.numpy().tolist()
        actions_list = []
        for n in range(len(fn_id)):
            a_0 = fn_id[n]
            a_l = []
            for arg_type in FUNCTIONS._func_list[a_0].args:
                arg_id = arg_ids[arg_type][n].detach().numpy().squeeze().tolist()
                if is_spatial_action[arg_type]:
                    # flat id is y * width + x; width equals height here (square screen)
                    arg = [arg_id % width, arg_id // width]
                else:
                    arg = [arg_id]
                a_l.append(arg)
            action = FunctionCall(a_0, a_l)

            actions_list.append(action)
        return actions_list

    def get_value(self, observation):
        screen = torch.FloatTensor(observation['screen']).cuda()
        minimap = torch.FloatTensor(observation['minimap']).cuda()
        flat = torch.FloatTensor(observation['flat']).cuda()
        with torch.no_grad():
            _, value = self.net(screen, minimap, flat)
        return value

    def train(self):
        obs_raw = self.obs_start
        shape = (self.n_steps, self.envs.n_envs)
        sample_values = np.zeros(shape, dtype=np.float32)
        sample_observation = []
        sample_rewards = np.zeros(shape, dtype=np.float32)
        sample_actions = []
        sample_dones = np.zeros(shape, dtype=np.float32)
        scores = []
        last_obs = self.last_obs
        for step in range(self.n_steps):
            policy, value = self.step(last_obs)

            actions = self.select_actions(policy, last_obs)
            if actions == 0:
                self.sum_episode = 7
                self.sum_score = 0
                return
            actions = self.mask_unused_action(actions)

            size = last_obs['screen'].shape[2:4]
            sample_values[step, :] = value.cpu()
            sample_observation.append(last_obs)
            sample_actions.append(actions)
            pysc2_action = self.functioncall_action(actions, size)

            obs_raw = self.envs.step(pysc2_action)
            # print("0:",pysc2_action[0].function)
            # print("1:",pysc2_action[1].function)

            last_obs = self.processor.preprocess_obs(obs_raw)
            sample_rewards[step, :] = [
                1 if i.reward else -0.1 for i in obs_raw
            ]
            sample_dones[step, :] = [i.last() for i in obs_raw]

            for i in obs_raw:
                if i.last():
                    score = i.observation['score_cumulative'][0]
                    self.sum_score += score
                    self.sum_episode += 1
                    print("episode %d: score = %f" % (self.sum_episode, score))
                    if self.sum_episode % self.save == 0:
                        torch.save(
                            self.net.state_dict(),
                            self.save_dir + '/' + str(self.sum_episode) +
                            '_score' + str(score) + '.pkl')

        self.last_obs = last_obs
        next_value = self.get_value(last_obs).cpu()

        returns = np.zeros(
            [sample_rewards.shape[0] + 1, sample_rewards.shape[1]])
        returns[-1, :] = next_value
        for i in reversed(range(sample_rewards.shape[0])):
            next_rewards = self.gamma * returns[i + 1, :] * (
                1 - sample_dones[i, :])
            returns[i, :] = sample_rewards[i, :] + next_rewards
        returns = returns[:-1, :]
        advantages = returns - sample_values
        actions = stack_and_flatten_actions(sample_actions)
        observation = flatten_first_dims_dict(
            stack_ndarray_dicts(sample_observation))
        returns = flatten_first_dims(returns)
        advantages = flatten_first_dims(advantages)
        self.learn(observation, actions, returns, advantages)

    def learn(self, observation, actions, returns, advantages):
        advantages = torch.FloatTensor(advantages).cuda()
        returns = torch.FloatTensor(returns).cuda()
        policy, value = self.grad_step(observation)
        log_probs = compute_policy_log_probs(observation['available_actions'],
                                             policy, actions).squeeze()

        policy_loss = -(log_probs * advantages).mean()
        value_loss = (returns - value).pow(2).mean()
        entropy_loss = compute_policy_entropy(observation['available_actions'],
                                              policy, actions)
        loss = policy_loss + value_loss * self.value_loss_coefficient + \
            entropy_loss * self.entropy_coefficient
        self.optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.net.parameters(), 1.0)
        self.optimizer.step()
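
The A2C class adds determined_actions() for greedy (argmax) action selection, but the listing contains no evaluation loop. The sketch below is an assumption built only from the class's own helpers and mirrors the call pattern of train(); the episode count is arbitrary.

def evaluate(agent, episodes=10):
    # greedy rollout: argmax actions instead of sampling, no gradient updates
    agent.reset()
    last_obs = agent.last_obs
    finished = 0
    while finished < episodes:
        policy, _ = agent.step(last_obs)
        actions = agent.determined_actions(policy, last_obs)
        if actions == 0:                  # masking left no valid action; restart the episode
            agent.reset()
            last_obs = agent.last_obs
            continue
        actions = agent.mask_unused_action(actions)
        size = last_obs['screen'].shape[2:4]
        obs_raw = agent.envs.step(agent.functioncall_action(actions, size))
        last_obs = agent.processor.preprocess_obs(obs_raw)
        for t in obs_raw:
            if t.last():
                finished += 1
                print("eval episode %d: score = %f"
                      % (finished, t.observation['score_cumulative'][0]))
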
Example #6
def RuleBase(net):
    map_name = 'CollectMineralShards'
    total_episodes = 100
    total_updates = -1
    sum_score = 0
    n_steps = 8
    learning_rate = 1e-4
    optimizer = optim.Adam(net.parameters(), learning_rate, weight_decay=0.01)
    env = make_sc2env(
        map_name=map_name,
        battle_net_map=False,
        players=[sc2_env.Agent(sc2_env.Race.terran)],
        agent_interface_format=sc2_env.parse_agent_interface_format(
            feature_screen=32,
            feature_minimap=32,
            rgb_screen=None,
            rgb_minimap=None,
            action_space=None,
            use_feature_units=False,
            use_raw_units=False),
        step_mul=8,
        game_steps_per_episode=None,
        disable_fog=False,
        visualize=True)

    processor = Preprocessor(env.observation_spec()[0])
    observation_spec = env.observation_spec()
    action_spec = env.action_spec()
    agent = CollectMineralShards()
    episodes = 0
    agent.reset()
    timesteps = env.reset()
    while True:
        fn_ids = []
        args_ids = []
        observations = []
        for step in range(n_steps):
            a_0, a_1 = agent.step(timesteps[0])
            obs = processor.preprocess_obs(timesteps)
            observations.append(obs)
            actions = FunctionCall(a_0, a_1)
            fn_id = torch.LongTensor([a_0]).cuda()
            args_id = {}
            if a_0 == 7:
                for type in ACTION_TYPES:
                    if type.name == 'select_add':
                        args_id[type] = torch.LongTensor([a_1[0][0]]).cuda()
                    else:
                        args_id[type] = torch.LongTensor([-1]).cuda()
            elif a_0 == 331:
                for type in ACTION_TYPES:
                    if type.name == 'queued':
                        args_id[type] = torch.LongTensor([a_1[0][0]]).cuda()
                    elif type.name == 'screen':

                        args_id[type] = torch.LongTensor(
                            [a_1[1][1] * 32 + a_1[1][0]]).cuda()
                    else:
                        args_id[type] = torch.LongTensor([-1]).cuda()
            action = (fn_id, args_id)
            fn_ids.append(fn_id)
            args_ids.append(args_id)
            timesteps = env.step([actions])
            if timesteps[0].last():
                i = timesteps[0]
                score = i.observation['score_cumulative'][0]
                sum_score += score
                episodes += 1
                if episodes % 50 == 0:
                    torch.save(net.state_dict(),
                               './save/episode2' + str(episodes) + '.pkl')
                print("episode %d: score = %f" % (episodes, score))

        observations = flatten_first_dims_dict(
            stack_ndarray_dicts(observations))

        train_fn_ids = torch.cat(fn_ids)
        train_arg_ids = {}

        for k in args_ids[0].keys():
            train_arg_ids[k] = torch.cat([d[k] for d in args_ids], dim=0)

        screen = torch.FloatTensor(observations['screen']).cuda()
        minimap = torch.FloatTensor(observations['minimap']).cuda()
        flat = torch.FloatTensor(observations['flat']).cuda()
        policy, _ = net(screen, minimap, flat)

        fn_pi, args_pi = policy
        available_actions = torch.FloatTensor(
            observations['available_actions']).cuda()
        function_pi = available_actions * fn_pi
        function_pi /= torch.sum(function_pi, dim=1, keepdim=True)
        Loss = nn.CrossEntropyLoss(reduction='none')
        loss = Loss(function_pi, train_fn_ids)

        for type in train_arg_ids.keys():
            id = train_arg_ids[type]
            pi = args_pi[type]
            arg_loss_list = []
            for i, p in zip(id, pi):
                if i == -1:
                    temp = torch.zeros((1)).cuda()
                else:
                    a = torch.LongTensor([i]).cuda()
                    b = torch.unsqueeze(p, dim=0).cuda()
                    temp = Loss(b, a)
                arg_loss_list.append(temp)

            arg_loss = torch.cat(arg_loss_list)
            loss += arg_loss
        loss = loss.mean()
        print(loss)
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(net.parameters(), 1.0)
        optimizer.step()

        if episodes >= total_episodes:
            break
    torch.save(net.state_dict(), './save/episode1.pkl')
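
The scripted pass above packs a screen coordinate into a single id row-major on the 32x32 feature screen (id = y * 32 + x), which is exactly what functioncall_action inverts with arg_id % width and arg_id // width (width and height are interchangeable here only because the screen is square). A quick round-trip check under that 32x32 assumption:

x, y = 5, 7
arg_id = y * 32 + x                              # 229, as encoded for the 'screen' argument above
assert (arg_id % 32, arg_id // 32) == (x, y)     # the decode in functioncall_action recovers (x, y)
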
Example #7
def RuleBase6(net, map, process):
    map_name = 'CollectMineralsAndGas'
    value_coef = 0.01
    total_episodes = 20
    total_updates = -1
    sum_score = 0
    n_steps = 8
    learning_rate = 1e-5
    optimizer = optim.Adam(net.parameters(), learning_rate, weight_decay=0.01)
    env = make_sc2env(
        map_name=map_name,
        battle_net_map=False,
        players=[sc2_env.Agent(sc2_env.Race.terran)],
        agent_interface_format=sc2_env.parse_agent_interface_format(
            feature_screen=32,
            feature_minimap=32,
            rgb_screen=None,
            rgb_minimap=None,
            action_space=None,
            use_feature_units=True,
            use_raw_units=False),
        step_mul=8,
        game_steps_per_episode=None,
        disable_fog=False,
        visualize=True)

    processor = Preprocessor(env.observation_spec()[0], map, process)
    observation_spec = env.observation_spec()
    action_spec = env.action_spec()
    agent = CollectMineralsAndGas()
    agent.setup(observation_spec[0], action_spec[0])
    episodes = 0
    agent.reset()
    timesteps = env.reset()
    while True:
        fn_ids = []
        args_ids = []
        observations = []
        rewards = []
        dones = []
        for step in range(n_steps):
            a_0, a_1 = agent.step(timesteps[0])
            obs = processor.preprocess_obs(timesteps)
            observations.append(obs)
            actions = FunctionCall(a_0, a_1)
            fn_id = torch.LongTensor([a_0]).cuda()
            args_id = {}
            if a_0 == 2:
                for type in ACTION_TYPES:
                    if type.name == 'select_point_act':
                        args_id[type] = torch.LongTensor([a_1[0][0]]).cuda()
                    elif type.name == 'screen':
                        args_id[type] = torch.LongTensor(
                            [a_1[1][1] * 32 + a_1[1][0]]).cuda()
                    else:
                        args_id[type] = torch.LongTensor([-1]).cuda()
            elif a_0 == 91 or a_0 == 44 or a_0 == 264:
                for type in ACTION_TYPES:
                    if type.name == 'queued':
                        args_id[type] = torch.LongTensor([a_1[0][0]]).cuda()
                    elif type.name == 'screen':

                        args_id[type] = torch.LongTensor(
                            [a_1[1][1] * 32 + a_1[1][0]]).cuda()
                    else:
                        args_id[type] = torch.LongTensor([-1]).cuda()
            elif a_0 == 490:
                for type in ACTION_TYPES:
                    if type.name == 'queued':
                        args_id[type] = torch.LongTensor([a_1[0][0]]).cuda()
                    else:
                        args_id[type] = torch.LongTensor([-1]).cuda()
            elif a_0 == 0:
                for type in ACTION_TYPES:
                    args_id[type] = torch.LongTensor([-1]).cuda()
            action = (fn_id, args_id)
            fn_ids.append(fn_id)
            args_ids.append(args_id)
            timesteps = env.step([actions])
            rewards.append(torch.FloatTensor([timesteps[0].reward]).cuda())
            dones.append(torch.IntTensor([timesteps[0].last()]).cuda())

            if timesteps[0].last():
                i = timesteps[0]
                score = i.observation['score_cumulative'][0]
                sum_score += score
                episodes += 1
                if episodes % 1 == 0:  # save after every episode
                    torch.save(net.state_dict(),
                               './save/game6_' + str(episodes) + '.pkl')
                print("episode %d: score = %f" % (episodes, score))
            # obs = processor.preprocess_obs(timesteps)
            # observations.append(obs)
        rewards = torch.cat(rewards)
        dones = torch.cat(dones)
        with torch.no_grad():
            obs = processor.preprocess_obs(timesteps)
            screen = torch.FloatTensor(obs['screen']).cuda()
            minimap = torch.FloatTensor(obs['minimap']).cuda()
            flat = torch.FloatTensor(obs['flat']).cuda()
            _, next_value = net(screen, minimap, flat)

        observations = flatten_first_dims_dict(
            stack_ndarray_dicts(observations))

        train_fn_ids = torch.cat(fn_ids)
        train_arg_ids = {}

        for k in args_ids[0].keys():
            train_arg_ids[k] = torch.cat([d[k] for d in args_ids], dim=0)

        screen = torch.FloatTensor(observations['screen']).cuda()
        minimap = torch.FloatTensor(observations['minimap']).cuda()
        flat = torch.FloatTensor(observations['flat']).cuda()
        policy, value = net(screen, minimap, flat)

        returns = torch.zeros((rewards.shape[0] + 1, ), dtype=float)
        returns[-1] = next_value
        for i in reversed(range(rewards.shape[0])):
            next_rewards = 0.999 * returns[i + 1] * (1 - dones[i])
            returns[i] = rewards[i] + next_rewards
        returns = returns[:-1].cuda()

        fn_pi, args_pi = policy
        available_actions = torch.FloatTensor(
            observations['available_actions']).cuda()
        function_pi = available_actions * fn_pi
        function_pi /= torch.sum(function_pi, dim=1, keepdim=True)
        Loss = nn.CrossEntropyLoss(reduction='none')
        function_pi = torch.clamp(function_pi, 1e-4, 1 - (1e-4))
        policy_loss = Loss(function_pi, train_fn_ids)

        for type in train_arg_ids.keys():
            id = train_arg_ids[type]
            pi = args_pi[type]
            arg_loss_list = []
            for i, p in zip(id, pi):
                if i == -1:
                    temp = torch.zeros((1)).cuda()
                else:
                    a = torch.LongTensor([i]).cuda()
                    b = torch.unsqueeze(p, dim=0).cuda()
                    b = torch.clamp(b, 1e-4, 1 - (1e-4))
                    temp = Loss(b, a)
                arg_loss_list.append(temp)

            arg_loss = torch.cat(arg_loss_list)
            policy_loss += arg_loss
        policy_loss = policy_loss.mean()
        value_loss = (returns - value).pow(2).mean()
        print(policy_loss, value_loss)
        loss = policy_loss + value_coef * value_loss
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(net.parameters(), 1.0)
        optimizer.step()

        if episodes >= total_episodes:
            break
    torch.save(net.state_dict(), './save/game6_final.pkl')
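
How these scripted passes are wired into training is not shown in the listing; a hedged reading is that they warm-start the CNN by imitating the scripted agents before reinforcement learning. A minimal sketch under that assumption (the argument values are guesses consistent with RuleBase6's own defaults, and the checkpoint path is the one RuleBase6 writes):

net = CNN().cuda()                               # default CNN, i.e. process_screen disabled
RuleBase6(net, 'CollectMineralsAndGas', False)   # scripted-teacher pre-training pass (assumed arguments)
# The saved weights could then initialise an A2C/PPO agent's network, e.g.:
# agent.net.load_state_dict(torch.load('./save/game6_final.pkl'))
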