Example no. 1
    def create_step_record(self):
        opt = self.opt
        record = DotDic({})
        # s_t: environment state at time t
        record.s_t = None
        # r_t: reward at time t per agent
        record.r_t = torch.zeros(opt.bs, opt.game_nagents)
        # terminal: whether each episode in the batch has ended
        record.terminal = torch.zeros(opt.bs)

        record.agent_inputs = []

        # Track actions at time t per agent
        record.a_t = torch.zeros(opt.bs, opt.game_nagents, dtype=torch.long)
        # Track each action's argument at time t per agent
        record.a_a_t = torch.zeros(opt.bs, opt.game_nagents, dtype=torch.long)
        if not opt.model_dial:
            record.a_comm_t = torch.zeros(opt.bs,
                                          opt.game_nagents,
                                          dtype=torch.long)

        # Track messages sent at time t per agent
        if opt.comm_enabled:
            # Messages are stored as floats for both DIAL and RIAL
            comm_dtype = torch.float
            record.comm = torch.zeros(opt.bs,
                                      opt.game_nagents,
                                      opt.game_comm_bits,
                                      dtype=comm_dtype)
            if opt.model_dial and opt.model_target:
                record.comm_target = record.comm.clone()

        # Track hidden state per time t per agent
        record.hidden = torch.zeros(opt.game_nagents, opt.model_rnn_layers,
                                    opt.bs, opt.model_rnn_size)
        record.hidden_target = torch.zeros(opt.game_nagents, opt.model_rnn_layers,
                                           opt.bs, opt.model_rnn_size)

        # Track Q(a_t) and Q(a_max_t) per agent
        record.q_a_t = torch.zeros(opt.bs, opt.game_nagents)
        record.q_a_max_t = torch.zeros(opt.bs, opt.game_nagents)
        # Track Q(a_a_t) and Q(a_a_max_t) per agent
        record.q_a_a_t = torch.zeros(opt.bs, opt.game_nagents)
        record.q_a_a_max_t = torch.zeros(opt.bs, opt.game_nagents)

        # Track Q(m_t) and Q(m_max_t) per agent
        if not opt.model_dial:
            record.q_comm_t = torch.zeros(opt.bs, opt.game_nagents)
            record.q_comm_max_t = torch.zeros(opt.bs, opt.game_nagents)

        return record
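Note: every example here relies on DotDic for configs and records, using attribute access (opt.bs) and item access (opt["bs"]) interchangeably. Below is a minimal sketch of a class with those semantics; it is an assumption based on how the examples use it, not the actual utils.dotdic implementation.

# Minimal DotDic sketch: a dict whose keys are also attributes.
class DotDic(dict):
    __getattr__ = dict.get          # missing keys read as None instead of raising
    __setattr__ = dict.__setitem__  # record.s_t = None behaves like record['s_t'] = None
    __delattr__ = dict.__delitem__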
Example no. 2
    def __init__(self, opt):
        self.step_count = 0
        self.game_actions = DotDic({'NOTHING': 1, 'TELL': 2})

        self.game_states = DotDic({
            'OUTSIDE': 0,
            'INSIDE': 1,
        })

        self.opt = opt

        self.reward_all_live = 1
        self.reward_all_die = -1

        self.reward = torch.zeros(self.opt["bs"], self.opt["game_nagents"])
        self.has_been = torch.zeros(self.opt["bs"], self.opt["nsteps"],
                                    self.opt["game_nagents"])
        self.terminal = torch.zeros(self.opt["bs"], dtype=torch.long)
        self.active_agent = torch.zeros(self.opt["bs"],
                                        self.opt["nsteps"],
                                        dtype=torch.long)  # 1-indexed agents

        self.reset()
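A hypothetical instantiation of the constructor above. The class name and the concrete config values are assumptions (only the keys bs, game_nagents, and nsteps are actually read here), and reset() is assumed to leave the tensor shapes intact.

# Hypothetical usage of the __init__ above.
opt = DotDic({'bs': 32, 'game_nagents': 3, 'nsteps': 6})
game = SwitchGame(opt)      # class name assumed; the snippet shows only __init__
print(game.reward.shape)    # torch.Size([32, 3])
print(game.has_been.shape)  # torch.Size([32, 6, 3])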
Example no. 3
def main(unused_arg):
    with open(FLAGS.config_path, 'r') as f:
        opt = DotDic(json.load(f))

    result_path = None
    if FLAGS.results_path:
        if FLAGS.config_path:
            result_path = os.path.join(FLAGS.results_path,
                                       Path(FLAGS.config_path).stem)
        else:
            result_path = os.path.join(
                FLAGS.results_path,
                'result-' + datetime.datetime.now().isoformat())

    for i in range(FLAGS.ntrials):
        trial_result_path = None
        if result_path:
            trial_result_path = result_path + '_' + str(
                i + FLAGS.start_index) + '.csv'
        trial_opt = copy.deepcopy(opt)
        run_trial(trial_opt,
                  result_path=trial_result_path,
                  verbose=FLAGS.verbose)
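The main() above reads several FLAGS whose definitions are not shown. A sketch of what they might look like with absl, with flag names inferred from usage and defaults assumed:

# Assumed absl flag definitions matching the reads in main() above.
from absl import app, flags

FLAGS = flags.FLAGS
flags.DEFINE_string('config_path', None, 'Path to the JSON config file.')
flags.DEFINE_string('results_path', None, 'Directory to write trial CSVs to.')
flags.DEFINE_integer('ntrials', 1, 'Number of independent trials to run.')
flags.DEFINE_integer('start_index', 0, 'Offset added to trial file indices.')
flags.DEFINE_boolean('verbose', False, 'Print per-trial progress.')

if __name__ == '__main__':
    app.run(main)  # absl passes argv through, hence main(unused_arg)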
Example no. 4
    def create_step_record(self):
        opt = self.opt
        record = DotDic({})
        record.s_t = None
        record.r_t = torch.zeros(opt["bs"], opt["game_nagents"])
        record.terminal = torch.zeros(opt["bs"])

        record.agent_inputs = []

        # Track actions at time t per agent
        record.a_t = torch.zeros(opt["bs"],
                                 opt["game_nagents"],
                                 dtype=torch.long)
        if not opt["model_dial"]:
            record.a_comm_t = torch.zeros(opt["bs"],
                                          opt["game_nagents"],
                                          dtype=torch.long)

        # Track messages sent at time t per agent
        if opt["comm_enabled"]:
            comm_dtype = torch.float
            record.comm = torch.zeros(opt["bs"],
                                      opt["game_nagents"],
                                      opt["game_comm_bits"],
                                      dtype=comm_dtype)
            if opt["model_dial"] and opt["model_target"]:
                record.comm_target = record.comm.clone()

        # Track hidden state per time t per agent
        record.hidden = torch.zeros(opt["game_nagents"],
                                    opt["model_rnn_layers"], opt["bs"],
                                    opt["model_rnn_size"])
        record.hidden_target = torch.zeros(opt["game_nagents"],
                                           opt["model_rnn_layers"], opt["bs"],
                                           opt["model_rnn_size"])

        # Track Q(a_t) and Q(a_max_t) per agent
        record.q_a_t = torch.zeros(opt["bs"], opt["game_nagents"])
        record.q_a_max_t = torch.zeros(opt["bs"], opt["game_nagents"])

        # Track Q(m_t) and Q(m_max_t) per agent
        if not opt["model_dial"]:
            record.q_comm_t = torch.zeros(opt["bs"], opt["game_nagents"])
            record.q_comm_max_t = torch.zeros(opt["bs"], opt["game_nagents"])

        return record
Example no. 5
    def __init__(self, opt, size):
        self.opt = opt
        self.game_actions = DotDic({
            'NOTHING': 0,
            'UP': 1,
            'DOWN': 2,
            'LEFT': 3,
            'RIGHT': 4
        })
        if self.opt.game_action_space != len(self.game_actions):
            raise ValueError(
                "Config action space doesn't match game's ({} != {}).".format(
                    self.opt.game_action_space, len(self.game_actions)))

        self.H = size[0]
        self.W = size[1]
        self.goal_reward = 10
        self.reset()
Example no. 6
    def __init__(self, opt):
        self.opt = opt
        if opt.bs != 1:
            raise NotImplementedError()

        # Set game defaults
        opt_game_default = DotDic({
            'render': True,
            'feature_screen_size': 48,
            'feature_minimap_size': 48,
            'rgb_screen_size': None,
            'rgb_minimap_size': None,
            'action_space': 'RAW',
            'use_feature_units': True,
            'use_raw_units': True,
            'disable_fog': True,
            'max_agent_step': 0,
            'game_steps_per_episode': None,
            'max_episodes': 0,
            'step_mul': 4,
            'agent': 'pysc2.agents.random_agent.RandomAgent',
            'agent_name': None,
            'agent_race': 'random',
            'agent2': 'Bot',
            'agent2_name': None,
            'agent2_race': 'random',
            'difficulty': 'very_easy',
            'bot_build': 'random',
            'save_replay': False,
            'map': '1',
            'battle_net_map': False
        })
        for k in opt_game_default:
            if k not in self.opt:
                self.opt[k] = opt_game_default[k]

        self.env = make_env(self.opt)

        self.reset()
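The default-filling loop above only sets keys the caller has not provided, so caller-supplied values always win. A standalone illustration (values here are hypothetical, no pysc2 needed):

# Hypothetical: caller keys survive, missing keys get the defaults.
opt = DotDic({'bs': 1, 'render': False, 'map': 'MoveToBeacon'})
defaults = DotDic({'render': True, 'step_mul': 4})
for k in defaults:
    if k not in opt:
        opt[k] = defaults[k]
assert opt.render is False   # caller-provided value wins
assert opt.step_mul == 4     # missing key filled from defaults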
Example no. 7
from envs.grid_game_flat import GridGame
from utils.dotdic import DotDic
import json
import torch

with open('config/grid_3_dial.json', 'r') as f:
    opt = DotDic(json.load(f))
opt.bs = 3
opt.game_nagents = 4
opt.game_action_space_total = 6


g = GridGame(opt, (4, 4))
g.show(vid=False)
u = torch.zeros((opt.bs, opt.game_nagents)) + 4
g.get_reward(u)
g.show(vid=False)
# print(g.get_action_range(None, None))
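To turn the probe above into a short random rollout, something like the following could work. The step budget is arbitrary, and get_reward's return value is only captured, since its exact contract is not shown in the snippet.

# Hypothetical random rollout on the flat grid game.
for step in range(10):  # arbitrary probe length
    u = torch.randint(0, opt.game_action_space_total,
                      (opt.bs, opt.game_nagents))
    reward = g.get_reward(u)
    g.show(vid=False)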