Example #1
def build(game, representation, model_path, n_agents, make_gif, gif_name, **kwargs):

    env_name = '{}-{}-v0'.format(game, representation)

    if game == "binary":
        kwargs['cropped_size'] = 28
    elif game == "zelda":
        kwargs['cropped_size'] = 22
    elif game == "sokoban":
        kwargs['cropped_size'] = 10
    kwargs['render'] = True

    crop_size = kwargs.get('cropped_size',28)

    temp_env = wrappers.CroppedImagePCGRLWrapper(env_name, crop_size, n_agents, **kwargs)

    if kwargs['restrict_map']:
        # Restrict each of the two agents to one half of the map (left / right split).
        map_restrictions = [{
            'x': (0, (temp_env.pcgrl_env._prob._width - 1) // 2 - 1),
            'y': (0, temp_env.pcgrl_env._prob._height - 1)
        }, {
            'x': ((temp_env.pcgrl_env._prob._width - 1) // 2,
                  temp_env.pcgrl_env._prob._width - 1),
            'y': (0, temp_env.pcgrl_env._prob._height - 1)
        }]
        kwargs['map_restrictions'] = map_restrictions

    env = wrappers.CroppedImagePCGRLWrapper(env_name, crop_size, n_agents, **kwargs)

    n_actions = env.action_space.n
    obs = env.reset()
    # print(obs.shape)


    models, optimizers, _, _ = load_models(device, model_path, n_agents,
                                           obs.shape[-1], crop_size, n_actions)

    obs = reshape_obs(obs)
    # print(obs.shape)

    frames = []

    done = False
    while not done:
        if make_gif:
            frames.append(env.render(mode='rgb_array'))
        env.render()
        actions = []
        for i in range(n_agents):
            # Sample one action per agent from its policy head.
            pi, _ = models[i](obs_to_torch(obs[i]))
            a = pi.sample()
            actions.append(a)
            # print(actions)
        obs, rewards, dones, info, active_agent = env.step(actions)
        obs = reshape_obs(obs)
        done = any(dones)
    print(info)

    if make_gif:
        frames[0].save(gif_name, save_all=True, append_images=frames[1:])

    time.sleep(10)
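
The build() function above relies on two helpers that are not shown, obs_to_torch and reshape_obs, plus a global device. The code below is only a sketch of what they might look like, assuming observations arrive as uint8 arrays of shape (n_agents, H, W, C) and the models expect batched float tensors; the real helpers in the source may differ.

import numpy as np
import torch

# Assumed device setup; the original snippet references a global `device`.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def obs_to_torch(obs: np.ndarray) -> torch.Tensor:
    # Hypothetical helper: turn one agent's (H, W, C) uint8 observation into a
    # float tensor with a leading batch dimension on the target device.
    return torch.tensor(obs, dtype=torch.float32, device=device).unsqueeze(0)

def reshape_obs(obs) -> np.ndarray:
    # Hypothetical helper: stack the per-agent observations returned by the
    # wrapper into a single (n_agents, H, W, C) uint8 array.
    return np.asarray(obs, dtype=np.uint8)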
Example #2
    def _thunk():
        if representation == 'wide':
            env = wrappers.ActionMapImagePCGRLWrapper(env_name, **kwargs)
        else:
            crop_size = kwargs.get('cropped_size', 28)
            env = wrappers.CroppedImagePCGRLWrapper(env_name, crop_size,
                                                    **kwargs)
        # RenderMonitor must come last
        if render or (log_dir is not None and len(log_dir) > 0):
            env = RenderMonitor(env, rank, log_dir, **kwargs)
        return env
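
Example #2's _thunk is a zero-argument factory, which is the shape expected by vectorized-environment utilities. As a hedged illustration only: the enclosing make_env function and the Stable-Baselines3 SubprocVecEnv import below are assumptions and do not appear in this listing, while wrappers and RenderMonitor are the same names used by the examples. The factory is typically built once per rank and handed to the vectorized environment as a list of callables.

from stable_baselines3.common.vec_env import SubprocVecEnv

def make_env(env_name, representation, rank, log_dir, render=False, **kwargs):
    # Assumed outer factory: closes over the per-worker settings and returns
    # the zero-argument _thunk shown in Example #2.
    def _thunk():
        if representation == 'wide':
            env = wrappers.ActionMapImagePCGRLWrapper(env_name, **kwargs)
        else:
            crop_size = kwargs.get('cropped_size', 28)
            env = wrappers.CroppedImagePCGRLWrapper(env_name, crop_size, **kwargs)
        if render or (log_dir is not None and len(log_dir) > 0):
            env = RenderMonitor(env, rank, log_dir, **kwargs)  # must come last
        return env
    return _thunk

vec_env = SubprocVecEnv([make_env('binary-turtle-v0', 'turtle', rank=i, log_dir='logs/')
                         for i in range(8)])
obs = vec_env.reset()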
Example #3
    def _thunk():
        if representation == 'wide':
            ca_action = kwargs.get('ca_action', False)
            if ca_action:
                raise Exception
#               env = wrappers.CAactionWrapper(env_name, **kwargs)
            else:
                env = wrappers.ActionMapImagePCGRLWrapper(env_name, **kwargs)

        else:
            crop_size = kwargs.get('cropped_size', 28)
            env = wrappers.CroppedImagePCGRLWrapper(env_name, crop_size,
                                                    **kwargs)


#       if evo_compare:
#           # FIXME: THIS DOES NOT WORK

#           # Give a little wiggle room from targets, to allow for some diversity
#           if "binary" in env_name:
#               path_trg = env.unwrapped._prob.static_trgs['path-length']
#               env.unwrapped._prob.static_trgs.update({'path-length': (path_trg - 20, path_trg)})
#           elif "zelda" in env_name:
#               path_trg = env.unwrapped._prob.static_trgs['path-length']
#               env.unwrapped._prob.static_trgs.update({'path-length': (path_trg - 40, path_trg)})
#           elif "sokoban" in env_name:
#               sol_trg = env.unwrapped._prob.static_trgs['sol-length']
#               env.unwrapped._prob.static_trgs.update({'sol-length': (sol_trg - 10, sol_trg)})
#           elif "smb" in env_name:
#               pass
#           else:
#               raise NotImplementedError
        env.configure(**kwargs)
        if max_step is not None:
            env = wrappers.MaxStep(env, max_step)
        if log_dir is not None and kwargs.get('add_bootstrap', False):
            env = wrappers.EliteBootStrapping(
                env, os.path.join(log_dir, "bootstrap{}/".format(rank)))
        env = conditional_wrappers.ParamRew(
            env, cond_metrics=kwargs.pop('cond_metrics'), **kwargs)
        if not evaluate:
            if not ALP_GMM:
                env = conditional_wrappers.UniformNoiseyTargets(env, **kwargs)
            elif conditional:
                env = conditional_wrappers.ALPGMMTeacher(env, **kwargs)
            # If not conditional, the ParamRew wrapper just stays fixed at its default static targets.
        if render or (log_dir is not None and len(log_dir) > 0):
            # RenderMonitor must come last
            env = RenderMonitor(env, rank, log_dir, **kwargs)

        return env
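
Example #3 chains several wrappers (MaxStep, EliteBootStrapping, ParamRew, the conditional teachers, RenderMonitor) around the base environment, but their implementations are not part of this listing. As a hedged sketch of the general pattern only, a MaxStep-style limiter can be written as an ordinary gym.Wrapper that truncates the episode after a fixed number of steps; the class below is hypothetical and uses the old four-value step API that the examples above rely on.

import gym

class MaxStepSketch(gym.Wrapper):
    """Hypothetical stand-in for wrappers.MaxStep: end the episode after max_step steps."""

    def __init__(self, env, max_step):
        super().__init__(env)
        self.max_step = max_step
        self._elapsed = 0

    def reset(self, **kwargs):
        self._elapsed = 0
        return self.env.reset(**kwargs)

    def step(self, action):
        obs, reward, done, info = self.env.step(action)
        self._elapsed += 1
        if self._elapsed >= self.max_step:
            done = True
        return obs, reward, done, info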
Example #4
    def _thunk():
        if representation == "wide":
            env = wrappers.ActionMapImagePCGRLWrapper(env_name, **kwargs)
        else:
            crop_size = kwargs.get("cropped_size", 28)
            env = wrappers.CroppedImagePCGRLWrapper(env_name, crop_size,
                                                    **kwargs)

        if max_step is not None:
            env = wrappers.MaxStep(env, max_step)

        if log_dir is not None and kwargs.get("add_bootstrap", False):
            env = wrappers.EliteBootStrapping(
                env, os.path.join(log_dir, "bootstrap{}/".format(rank)))
        # RenderMonitor must come last
        if render or (log_dir is not None and len(log_dir) > 0):
            env = RenderMonitor(env, rank, log_dir, **kwargs)

        return env
Example #5
def worker_process(remote: multiprocessing.connection.Connection,
                   env_name: str, crop_size: int, kwargs: Dict):
    """Run one wrapped environment in a subprocess, serving commands received over a pipe."""

    game = wrappers.CroppedImagePCGRLWrapper(env_name, crop_size, **kwargs)

    while True:
        cmd, data = remote.recv()
        if cmd == "step":
            # print('stepping')
            temp = game.step(data)
            # print(temp)
            remote.send(temp)
        elif cmd == "reset":
            # print('resetting')
            temp = game.reset()
            # print(temp)
            remote.send(temp)
        elif cmd == "close":
            remote.close()
            break
        else:
            raise NotImplementedError
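
worker_process expects to sit at the child end of a pipe. Example #6 constructs Worker objects and talks to them through worker.child, but the Worker class itself is not shown anywhere in this listing. The sketch below is a minimal, assumed implementation consistent with Example #5's signature; note that Example #6 also passes n_agents to Worker, so the real class in the source likely differs.

import multiprocessing

class Worker:
    # Hedged sketch, not the source implementation: spawn worker_process in a
    # child process and keep the parent end of the pipe as `child`.
    def __init__(self, env_name, crop_size, **kwargs):
        self.child, parent = multiprocessing.Pipe()
        self.process = multiprocessing.Process(
            target=worker_process,
            args=(parent, env_name, crop_size, kwargs))
        self.process.start()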
Example #6
    def __init__(self):

        self.load_model = False

        self.gamma = 0.99
        self.lamda = 0.95

        self.updates = 10000
        self.update_start = 0
        self.save_period = 50

        self.epochs = 4

        self.n_workers = 8

        self.n_agents = 2
        self.active_agent = 0

        self.worker_steps = 128

        self.n_mini_batch = 4

        self.batch_size = self.n_workers * self.worker_steps

        self.mini_batch_size = self.batch_size // self.n_mini_batch

        self.logging = True

        assert self.batch_size % self.n_mini_batch == 0

        game = 'binary'
        representation = 'turtle'

        kwargs = {
            'change_percentage': 0.4,
            'verbose': True,
            'negative_switch': False,
            'render': False,
            'restrict_map': False
        }

        self.negative_switch = kwargs['negative_switch']
        if self.negative_switch:
            self.updates = int(self.updates * 1.5)

        self.env_name = '{}-{}-v0'.format(game, representation)

        if game == "binary":
            kwargs['cropped_size'] = 28
        elif game == "zelda":
            kwargs['cropped_size'] = 22
        elif game == "sokoban":
            kwargs['cropped_size'] = 10

        self.save_path = 'models/{}/{}/{}{}'.format(
            game, representation,
            'negative_switch_' if kwargs['negative_switch'] else '',
            'map_restricted_' if kwargs['restrict_map'] else '')

        if self.logging:
            log_path = 'logs/self_play_{}_{}_{}{}log.txt'.format(
                game, representation,
                'negative_switch_' if kwargs['negative_switch'] else '',
                'map_restricted_' if kwargs['restrict_map'] else '')

            if self.load_model:
                # Resuming from a saved model: append to the existing log.
                self.logfile = open(log_path, 'a+')
            else:
                # Fresh run: start a new log file.
                self.logfile = open(log_path, 'w')

        self.crop_size = kwargs.get('cropped_size', 28)

        temp_env = wrappers.CroppedImagePCGRLWrapper(self.env_name,
                                                     self.crop_size,
                                                     self.n_agents, **kwargs)

        if kwargs['restrict_map']:
            map_restrictions = [{
                'x': (0, (temp_env.pcgrl_env._prob._width - 1) // 2 - 1),
                'y': (0, temp_env.pcgrl_env._prob._height - 1)
            }, {
                'x': ((temp_env.pcgrl_env._prob._width - 1) // 2,
                      (temp_env.pcgrl_env._prob._width - 1)),
                'y': (0, temp_env.pcgrl_env._prob._height - 1)
            }]
            kwargs['map_restrictions'] = map_restrictions

        kwargs['step_length'] = [10, 1]

        n_actions = temp_env.action_space.n

        temp = []

        self.workers = [
            Worker(self.env_name, self.crop_size, self.n_agents, **kwargs)
            for i in range(self.n_workers)
        ]

        for worker in self.workers:
            worker.child.send(("reset", None))
        for i, worker in enumerate(self.workers):
            temp.append(worker.child.recv())
            #self.obs[:,i] = temp

        self.obs = np.zeros((self.n_agents, self.n_workers, self.crop_size,
                             self.crop_size, temp[0].shape[3]),
                            dtype=np.uint8)

        for i in range(len(temp)):
            self.obs[:, i] = temp[i]

        if self.load_model:  # load_model(device, path, in_channels, map_size, out_length)
            self.model, optimizer, epoch, self.update_start = load_model(
                device, self.save_path, self.obs.shape[-1], self.crop_size,
                n_actions)
            self.updates = self.updates - self.update_start

            self.trainer = SelfTrainer(self.model, optimizer=optimizer)

        else:
            self.model = Model(self.obs.shape[-1], self.crop_size, n_actions)
            self.model.to(device)

            self.trainer = SelfTrainer(self.model)
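
After __init__ has reset every worker and filled self.obs, a training loop would drive the same command protocol that worker_process implements in Example #5. The two methods below are hypothetical helpers, not taken from the source, showing how the "step" and "close" commands might be issued per worker.

    def step_workers(self, actions):
        # Hypothetical helper: send one "step" command per worker, then collect
        # the replies in worker order.
        for w, worker in enumerate(self.workers):
            worker.child.send(("step", actions[w]))
        return [worker.child.recv() for worker in self.workers]

    def close_workers(self):
        # Hypothetical helper: shut the subprocesses down cleanly.
        for worker in self.workers:
            worker.child.send(("close", None))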