Example #1
 def __init__(self,
              output_dir=None,
              output_fname='progress.txt',
              exp_name=None):
     if proc_id() == 0:
         self.output_dir = output_dir or "/tmp/experiments/%i" % int(time.time())
         if osp.exists(self.output_dir):
             print(
                 "Warning: Log dir %s already exists! Storing info there anyway."
                 % self.output_dir)
         else:
             os.makedirs(self.output_dir)
         self.output_file = open(osp.join(self.output_dir, output_fname),
                                 'w')
         atexit.register(self.output_file.close)
         print(
             colorize("Logging data to %s" % self.output_file.name,
                      'green',
                      bold=True))
     else:
         self.output_dir = None
         self.output_file = None
     self.first_row = True
     self.log_headers = []
     self.log_current_row = {}
     self.exp_name = exp_name
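
A minimal usage sketch, assuming this constructor belongs to the EpochLogger used in Example #6 (the path and experiment name here are illustrative):

logger = EpochLogger(output_dir='/tmp/experiments/demo',
                     exp_name='demo')
# Under MPI, only process 0 creates the directory and opens progress.txt;
# every other rank keeps output_dir and output_file as None.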
Example #2
 def save_state(self, state_dict, itr=None):
     if proc_id() == 0:
         fname = 'vars.pkl' if itr is None else 'vars%d.pkl' % itr
         try:
             joblib.dump(state_dict, osp.join(self.output_dir, fname))
         except Exception:
             self.log("Warning: could not pickle state_dict", color='red')
         # If a PyTorch saver has been set up, also save the model itself.
         if hasattr(self, 'pytorch_saver_elements'):
             self._pytorch_simple_save(itr)
Example #3
 def _pytorch_simple_save(self, itr=None):
     if proc_id() == 0:
         assert hasattr(self, 'pytorch_saver_elements'), \
             "First have to setup saving with self.setup_pytorch_saver"
         fpath = 'pyt_save'
         fpath = osp.join(self.output_dir, fpath)
         fname = 'model' + ('%d' % itr if itr is not None else '') + '.pt'
         fname = osp.join(fpath, fname)
         os.makedirs(fpath, exist_ok=True)
         with warnings.catch_warnings():
             warnings.simplefilter("ignore")
             torch.save(self.pytorch_saver_elements, fname)
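
save_state in Example #2 delegates model serialization to this method, which requires a prior setup_pytorch_saver call. A sketch of the expected call order, mirroring how Example #6 uses it (ac is any torch.nn.Module):

logger.setup_pytorch_saver(ac)           # register what to save
logger.save_state({'env': env}, itr=0)   # writes vars0.pkl and pyt_save/model0.pt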
Example #4
 def save_config(self, config):
     config_json = convert_json(config)
     if self.exp_name is not None:
         config_json['exp_name'] = self.exp_name
     if proc_id() == 0:
         output = json.dumps(config_json,
                             separators=(',', ':\t'),
                             indent=4,
                             sort_keys=True)
         print(colorize('Saving config:\n', color='cyan', bold=True))
         print(output)
         with open(osp.join(self.output_dir, "config.json"), 'w') as out:
             out.write(output)
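
Passing locals() right after a function binds its arguments captures every hyperparameter in one call, which is exactly how Example #6 uses this method. A sketch with a hypothetical training function:

def train(lr=3e-4, seed=0, logger_kwargs=dict()):
    logger = EpochLogger(**logger_kwargs)
    logger.save_config(locals())  # writes lr, seed, ... to config.json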
Example #5
 def dump_tabular(self):
     if proc_id() == 0:
         vals = []
         key_lens = [len(key) for key in self.log_headers]
         max_key_len = max(15, max(key_lens))
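         # Build a fixed-width row format such as "| %15s | %15s |".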
         keystr = '%' + '%d' % max_key_len
         fmt = "| " + keystr + "s | %15s |"
         n_slashes = 22 + max_key_len
         print("-" * n_slashes)
         for key in self.log_headers:
             val = self.log_current_row.get(key, "")
             valstr = "%8.3g" % val if hasattr(val, '__float__') else val
             print(fmt % (key, valstr))
             vals.append(val)
         print("-" * n_slashes, flush=True)
         if self.output_file is not None:
             if self.first_row:
                 self.output_file.write("\t".join(self.log_headers) + "\n")
             self.output_file.write("\t".join(map(str, vals)) + "\n")
             self.output_file.flush()
     self.log_current_row.clear()
     self.first_row = False
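
The intended per-epoch cycle, sketched with keys from Example #6: stage values with log_tabular, then flush them all at once; dump_tabular prints the aligned table and appends one tab-separated row to progress.txt:

logger.log_tabular('Epoch', epoch)
logger.log_tabular('EpRet', with_min_and_max=True)
logger.dump_tabular()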
Example #6
def ppo(env_fn,
        actor_critic=sc2_nets.SC2AtariNetActorCritic,
        ac_kwargs=dict(),
        seed=0,
        steps_per_epoch=10000,
        epochs=1000000,
        gamma=0.99,
        clip_ratio=0.2,
        lr=3e-4,
        vf_coeff=0.5,
        ent_coeff=0.01,
        train_iters=10,
        lam=0.97,
        max_ep_len=1000,
        target_kl=0.03,
        batch_size=64,
        logger_kwargs=dict(),
        save_freq=100,
        device=torch.device("cpu")):
    setup_pytorch_for_mpi()

    print("device - ", device)

    logger = EpochLogger(**logger_kwargs)
    logger.save_config(locals())

    seed += 10000 * proc_id()
    torch.manual_seed(seed)
    np.random.seed(seed)

    env = env_fn()
    # TODO: generalize this; the extraction below depends on the
    # environment and the network structure.
    obs_space = env.observation_gym_space
    obs_dim = obs_space['feature_screen'].shape
    act_dim = env.action_gym_space.nvec.shape
    print("obs_dim, act_dim = ", obs_dim, act_dim)

    action_spec, action_mask = env.action_set.get_action_spec_and_action_mask()
    ac = actor_critic(env.observation_gym_space,
                      action_spec=action_spec,
                      action_mask=action_mask,
                      device=device,
                      **ac_kwargs)

    sync_params(ac)

    var_counts = tuple(core.count_vars(module) for module in [ac.pi, ac.v])
    logger.log('\nNumber of parameters: \t pi: %d, \t v: %d\n' % var_counts)

    local_steps_per_epoch = int(steps_per_epoch / num_procs())
    buf = PPOBuffer(obs_dim, act_dim, local_steps_per_epoch, gamma, lam,
                    device)

    def compute_loss_pi(data, start, end):
        obs, act, adv, logp_old = data['obs'][start:end], data['act'][start:end], \
                                  data['adv'][start:end], data['logp'][start:end]

        pis, logp = ac.pi(obs, act)
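        # PPO-Clip objective: maximize min(r * A, clip(r, 1-eps, 1+eps) * A),
        # with probability ratio r = pi(a|s) / pi_old(a|s) = exp(logp - logp_old).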
        ratio = torch.exp(logp - logp_old)
        clip_adv = torch.clamp(ratio, 1 - clip_ratio, 1 + clip_ratio) * adv
        loss_pi = -(torch.min(ratio * adv, clip_adv)).mean()

        approx_kl = (logp_old - logp).mean().item()
        ent = 0
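        # Sum entropy over all action heads; tuple entries hold two distributions.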
        for pi in pis:
            if isinstance(pi, tuple):
                ent += pi[0].entropy().mean() + pi[1].entropy().mean()
            else:
                ent += pi.entropy().mean()
        clipped = ratio.gt(1 + clip_ratio) | ratio.lt(1 - clip_ratio)
        clipfrac = torch.as_tensor(clipped, dtype=torch.float32,
                                   device=device).mean().item()
        pi_info = dict(kl=approx_kl, ent=ent.item(), cf=clipfrac)

        return loss_pi, ent, pi_info

    def compute_loss_v(data, start, end):
        obs, ret = data['obs'][start:end], data['ret'][start:end]
        return ((ac.v(obs) - ret)**2).mean()

    optimizer = Adam(ac.parameters(), lr=lr)
    logger.setup_pytorch_saver(ac)

    def update():
        data = buf.get()

        pi_l_old, ent_old, pi_info_old = compute_loss_pi(
            data, 0, len(data['obs']))
        pi_l_old = pi_l_old.item()
        ent_old = ent_old.item()
        v_l_old = compute_loss_v(data, 0, len(data['obs'])).item()

        # TODO: consider switching to combined-batch training.

        for i in range(train_iters):
            # Mini-batch training instead of one full-batch step.
            print("training pi and v ...")
            sys.stdout.flush()
            data_length = len(data['obs'])
            # Ceil division so the final partial batch is not dropped.
            for j in range((data_length + batch_size - 1) // batch_size):
                print("doing batch {}".format(j))
                sys.stdout.flush()
                start = batch_size * j
                end = min(batch_size * (j + 1), data_length)
                optimizer.zero_grad()
                loss_pi, entropy, pi_info = compute_loss_pi(data, start, end)
                loss_v = compute_loss_v(data, start, end)
                #kl = mpi_avg(pi_info['kl'])
                #if kl > 1.5 * target_kl:
                #    logger.log('Early stopping at step %d due to reaching max kl.' % i)
                #    break
                (loss_pi + vf_coeff * loss_v - ent_coeff * entropy).backward()
                mpi_avg_grads(ac)
                optimizer.step()

        logger.store(StopIter=i)
        kl, ent, cf = pi_info['kl'], pi_info_old['ent'], pi_info['cf']
        logger.store(LossPi=pi_l_old,
                     LossV=v_l_old,
                     KL=kl,
                     Entropy=ent,
                     ClipFrac=cf,
                     DeltaLossPi=(loss_pi.item() - pi_l_old),
                     DeltaLossV=(loss_v.item() - v_l_old))

    start_time = time.time()
    o, ep_ret, ep_len = env.reset(), 0, 0

    for epoch in range(epochs):
        for t in range(local_steps_per_epoch):
            o = o['feature_screen']
            a, v, logp = ac.step(
                torch.as_tensor(o,
                                dtype=torch.float32).to(device).unsqueeze(0))

            #print("a v logp -- ", a, v, logp)
            #print(o.shape)

            #for i in range(84):
            #    for j in range(84):
            #        print("{:3.1f} ".format((o[3, i, j])), end='')
            #    print()

            print(".", end='')
            sys.stdout.flush()

            next_o, r, d, _ = env.step(a)
            ep_ret += r
            ep_len += 1

            buf.store(o, a, r, v, logp)
            logger.store(VVals=v, incre=True)

            o = next_o

            timeout = ep_len == max_ep_len
            terminal = d or timeout
            epoch_ended = t == local_steps_per_epoch - 1

            if terminal or epoch_ended:
                print("episode ended {}".format(t))
                if epoch_ended and not terminal:
                    print('Warning: trajectory cut off by epoch at %d steps.' %
                          ep_len,
                          flush=True)
                if timeout or epoch_ended:
                    o = o['feature_screen']
                    _, v, _ = ac.step(
                        torch.as_tensor(
                            o, dtype=torch.float32).to(device).unsqueeze(0))
                else:
                    v = 0
                buf.finish_path(v)
                if terminal:
                    logger.store(EpRet=ep_ret, EpLen=ep_len)
                o, ep_ret, ep_len = env.reset(), 0, 0

        if (epoch % save_freq == 0) or (epoch == epochs - 1):
            logger.save_state({'env': env}, epoch)

        print("update started....")
        sys.stdout.flush()
        update()
        print("update ended....")
        sys.stdout.flush()

        logger.log_tabular('Epoch', epoch)
        logger.log_tabular('EpRet', with_min_and_max=True)
        logger.log_tabular('EpLen', average_only=True)
        logger.log_tabular('VVals', with_min_and_max=True)
        logger.log_tabular('TotalEnvInteracts', (epoch + 1) * steps_per_epoch)
        logger.log_tabular('LossPi', average_only=True)
        logger.log_tabular('LossV', average_only=True)
        logger.log_tabular('DeltaLossPi', average_only=True)
        logger.log_tabular('DeltaLossV', average_only=True)
        logger.log_tabular('Entropy', average_only=True)
        logger.log_tabular('KL', average_only=True)
        logger.log_tabular('ClipFrac', average_only=True)
        logger.log_tabular('StopIter', average_only=True)
        logger.log_tabular('Time', time.time() - start_time)
        logger.dump_tabular()
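
A hedged launch sketch for the ppo function above; make_sc2_env is a hypothetical environment factory, and everything else uses ppo's own keyword arguments:

if __name__ == '__main__':
    ppo(env_fn=lambda: make_sc2_env('MoveToBeacon'),  # hypothetical factory
        steps_per_epoch=4000,
        epochs=50,
        logger_kwargs=dict(output_dir='/tmp/experiments/ppo_sc2',
                           exp_name='ppo_sc2'),
        device=torch.device('cuda' if torch.cuda.is_available() else 'cpu'))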
Example #7
 def log(self, msg, color='green'):
     if proc_id() == 0:
         print(colorize(msg, color, bold=True))