def _collect_grads(flat_grad, np_stat):
    if self.grad_clip is not None:
        gradnorm = np.linalg.norm(flat_grad)
        if gradnorm > 1:
            flat_grad /= gradnorm
        logger.logkv_mean('gradnorm', gradnorm)
        logger.logkv_mean('gradclipfrac', float(gradnorm > 1))
    self.comm.Allreduce(flat_grad, buf, op=MPI.SUM)
    np.divide(buf, float(total_weight), out=buf)
    if countholder[0] % 100 == 0:
        check_synced(np_stat, self.comm)
    countholder[0] += 1
    return buf
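In this snippet, `self`, `buf`, `total_weight`, and `countholder` are closed over from the enclosing optimizer, so it does not run on its own. A self-contained sketch of the same pattern, clipping the flattened gradient to unit L2 norm and then computing a weighted average across MPI workers, could look like the following; the `average_clipped_grad` helper and the per-worker `weight` argument are illustrative assumptions, not part of the original code.

import numpy as np
from mpi4py import MPI


def average_clipped_grad(flat_grad, comm=MPI.COMM_WORLD, weight=1.0):
    # Clip the flattened gradient to unit L2 norm before averaging.
    gradnorm = np.linalg.norm(flat_grad)
    if gradnorm > 1:
        flat_grad = flat_grad / gradnorm
    # Sum the weighted gradients over all workers, then divide by the
    # total weight to obtain a weighted average.
    send = np.asarray(flat_grad, dtype=np.float64) * weight
    buf = np.zeros_like(send)
    comm.Allreduce(send, buf, op=MPI.SUM)
    total_weight = comm.allreduce(weight, op=MPI.SUM)
    buf /= total_weight
    return buf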
Example #2
    def step(self, action):
        obs, reward, done, info = self.env.step(action)

        if done:
            logger.logkv_mean('action_mean', action)
            logger.logkv('action', action)

            if 'score' in info:
                score = info['score']
                logger.logkv_mean('score_mean', score)
                logger.logkv('score', score)

            logger.dumpkvs()

        return obs, reward, done, info
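The distinction between `logkv` and `logkv_mean` above matters: in the OpenAI-baselines-style logger these snippets appear to use, `logkv` stores the latest value for a key, `logkv_mean` keeps a running average of everything logged under that key since the last dump, and `dumpkvs` writes the accumulated key/value pairs and clears them. A minimal usage sketch under that assumption (the import path and output directory are hypothetical):

from baselines import logger  # assumes the OpenAI baselines-style logger

logger.configure('/tmp/logkv_demo')           # hypothetical output directory
for reward in (1.0, 2.0, 3.0):
    logger.logkv_mean('reward_mean', reward)  # running average -> 2.0
    logger.logkv('reward_last', reward)       # overwritten -> 3.0
logger.dumpkvs()                              # writes the pairs and clears them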
Example #3
def test_mpi_weighted_mean():
    from mpi4py import MPI
    comm = MPI.COMM_WORLD
    with logger.scoped_configure(comm=comm):
        if comm.rank == 0:
            name2valcount = {'a': (10, 2), 'b': (20, 3)}
        elif comm.rank == 1:
            name2valcount = {'a': (19, 1), 'c': (42, 3)}
        else:
            raise NotImplementedError
        d = mpi_util.mpi_weighted_mean(comm, name2valcount)
        correctval = {'a': (10 * 2 + 19) / 3.0, 'b': 20, 'c': 42}
        if comm.rank == 0:
            assert d == correctval, '{} != {}'.format(d, correctval)

        for name, (val, count) in name2valcount.items():
            for _ in range(count):
                logger.logkv_mean(name, val)
        d2 = logger.dumpkvs()
        if comm.rank == 0:
            assert d2 == correctval
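The expected values follow from a count-weighted mean per key: each rank reports (value, count) pairs, and the result for a key is the sum of value * count divided by the sum of counts over all ranks that report it. A plain-Python sketch of that reduction, without MPI, using the same dictionaries as the test:

def weighted_mean(rank_dicts):
    # rank_dicts: list of {name: (value, count)} dictionaries, one per rank.
    sums, counts = {}, {}
    for d in rank_dicts:
        for name, (val, count) in d.items():
            sums[name] = sums.get(name, 0.0) + val * count
            counts[name] = counts.get(name, 0) + count
    return {name: sums[name] / counts[name] for name in sums}

weighted_mean([{'a': (10, 2), 'b': (20, 3)},
               {'a': (19, 1), 'c': (42, 3)}])
# -> {'a': 13.0, 'b': 20.0, 'c': 42.0}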
Example #4
def sac(env,
        policy,
        q_func=None,
        val_fn=None,
        total_steps=TOTAL_STEPS_DEFAULT,
        gamma=0.99,
        replay_size=REPLAY_SIZE_DEFAULT,
        polyak=0.995,
        start_steps=10000,
        epoch=5000,
        mb_size=100,
        lr=1e-3,
        alpha=0.2,
        target_entropy=None,
        reward_scale=1.0,
        updates_per_step=1.0,
        max_ep_length=1000,
        **saver_kwargs):

    # Set and save experiment hyperparameters
    q_func = q_func or ContinuousQFunction.from_policy(policy)
    val_fn = val_fn or ValueFunction.from_policy(policy)
    save_config(locals())
    saver = SnapshotSaver(logger.get_dir(), locals(), **saver_kwargs)

    # Initialize environments, models and replay buffer
    vec_env = VecEnvMaker(env)()
    test_env = VecEnvMaker(env)(train=False)
    ob_space, ac_space = vec_env.observation_space, vec_env.action_space
    pi_class, pi_args = policy.pop("class"), policy
    qf_class, qf_args = q_func.pop("class"), q_func
    vf_class, vf_args = val_fn.pop("class"), val_fn
    policy = pi_class(vec_env, **pi_args)
    q1func = qf_class(vec_env, **qf_args)
    q2func = qf_class(vec_env, **qf_args)
    val_fn = vf_class(vec_env, **vf_args)
    replay = ReplayBuffer(replay_size, ob_space, ac_space)
    if target_entropy is not None:
        log_alpha = torch.nn.Parameter(torch.zeros([]))
        if target_entropy == "auto":
            target_entropy = -np.prod(ac_space.shape)

    # Initialize optimizers and target networks
    loss_fn = torch.nn.MSELoss()
    pi_optim = torch.optim.Adam(policy.parameters(), lr=lr)
    qf_optim = torch.optim.Adam(list(q1func.parameters()) +
                                list(q2func.parameters()),
                                lr=lr)
    vf_optim = torch.optim.Adam(val_fn.parameters(), lr=lr)
    vf_targ = vf_class(vec_env, **vf_args)
    vf_targ.load_state_dict(val_fn.state_dict())
    if target_entropy is not None:
        al_optim = torch.optim.Adam([log_alpha], lr=lr)

    # Save initial state
    state = dict(
        alg=dict(samples=0),
        policy=policy.state_dict(),
        q1func=q1func.state_dict(),
        q2func=q2func.state_dict(),
        val_fn=val_fn.state_dict(),
        pi_optim=pi_optim.state_dict(),
        qf_optim=qf_optim.state_dict(),
        vf_optim=vf_optim.state_dict(),
        vf_targ=vf_targ.state_dict(),
    )
    if target_entropy is not None:
        state["log_alpha"] = log_alpha
        state["al_optim"] = al_optim.state_dict()
    saver.save_state(index=0, state=state)

    # Setup and run policy tests
    ob, don = test_env.reset(), False

    @torch.no_grad()
    def test_policy():
        nonlocal ob, don
        policy.eval()
        for _ in range(10):
            while not don:
                act = policy.actions(torch.from_numpy(ob))
                ob, _, don, _ = test_env.step(act.numpy())
            don = False
        policy.train()
        log_reward_statistics(test_env, num_last_eps=10, prefix="Test")

    test_policy()
    logger.logkv("Epoch", 0)
    logger.logkv("TotalNSamples", 0)
    logger.dumpkvs()

    # Set action sampling strategies
    def rand_uniform_actions(_):
        return np.stack([ac_space.sample() for _ in range(vec_env.num_envs)])

    @torch.no_grad()
    def stoch_policy_actions(obs):
        return policy.actions(torch.from_numpy(obs)).numpy()

    # Algorithm main loop
    obs1, ep_length = vec_env.reset(), 0
    for samples in trange(1, total_steps + 1, desc="Training", unit="step"):
        if samples <= start_steps:
            actions = rand_uniform_actions
        else:
            actions = stoch_policy_actions

        acts = actions(obs1)
        obs2, rews, dones, _ = vec_env.step(acts)
        ep_length += 1
        dones[0] = False if ep_length == max_ep_length else dones[0]

        as_tensors = map(torch.from_numpy,
                         (obs1, acts, rews, obs2, dones.astype("f")))
        for ob1, act, rew, ob2, done in zip(*as_tensors):
            replay.store(ob1, act, rew, ob2, done)
        obs1 = obs2

        if (dones[0] or ep_length == max_ep_length) and replay.size >= mb_size:
            for _ in range(int(ep_length * updates_per_step)):
                ob_1, act_, rew_, ob_2, done_ = replay.sample(mb_size)
                dist = policy(ob_1)
                pi_a = dist.rsample()
                logp = dist.log_prob(pi_a)
                if target_entropy is not None:
                    al_optim.zero_grad()
                    alpha_loss = torch.mean(
                        log_alpha * (logp.detach() + target_entropy)).neg()
                    alpha_loss.backward()
                    al_optim.step()
                    logger.logkv_mean("AlphaLoss", alpha_loss.item())
                    alpha = log_alpha.exp().item()

                with torch.no_grad():
                    y_qf = reward_scale * rew_ + gamma * (
                        1 - done_) * vf_targ(ob_2)
                    y_vf = (torch.min(q1func(ob_1, pi_a), q2func(ob_1, pi_a)) -
                            alpha * logp)

                qf_optim.zero_grad()
                q1_val = q1func(ob_1, act_)
                q2_val = q2func(ob_1, act_)
                q1_loss = loss_fn(q1_val, y_qf).div(2)
                q2_loss = loss_fn(q2_val, y_qf).div(2)
                q1_loss.add(q2_loss).backward()
                qf_optim.step()

                vf_optim.zero_grad()
                vf_val = val_fn(ob_1)
                vf_loss = loss_fn(vf_val, y_vf).div(2)
                vf_loss.backward()
                vf_optim.step()

                pi_optim.zero_grad()
                qpi_val = q1func(ob_1, pi_a)
                # qpi_val = torch.min(q1func(ob_1, pi_a), q2func(ob_1, pi_a))
                pi_loss = qpi_val.sub(logp, alpha=alpha).mean().neg()
                pi_loss.backward()
                pi_optim.step()

                update_polyak(val_fn, vf_targ, polyak)

                logger.logkv_mean("Entropy", logp.mean().neg().item())
                logger.logkv_mean("Q1Val", q1_val.mean().item())
                logger.logkv_mean("Q2Val", q2_val.mean().item())
                logger.logkv_mean("VFVal", vf_val.mean().item())
                logger.logkv_mean("QPiVal", qpi_val.mean().item())
                logger.logkv_mean("Q1Loss", q1_loss.item())
                logger.logkv_mean("Q2Loss", q2_loss.item())
                logger.logkv_mean("VFLoss", vf_loss.item())
                logger.logkv_mean("PiLoss", pi_loss.item())
                logger.logkv_mean("Alpha", alpha)

            ep_length = 0

        if samples % epoch == 0:
            test_policy()
            logger.logkv("Epoch", samples // epoch)
            logger.logkv("TotalNSamples", samples)
            log_reward_statistics(vec_env)
            logger.dumpkvs()

            state = dict(
                alg=dict(samples=samples),
                policy=policy.state_dict(),
                q1func=q1func.state_dict(),
                q2func=q2func.state_dict(),
                val_fn=val_fn.state_dict(),
                pi_optim=pi_optim.state_dict(),
                qf_optim=qf_optim.state_dict(),
                vf_optim=vf_optim.state_dict(),
                vf_targ=vf_targ.state_dict(),
            )
            if target_entropy is not None:
                state["log_alpha"] = log_alpha
                state["al_optim"] = al_optim.state_dict()
            saver.save_state(index=samples // epoch, state=state)
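`update_polyak` is not shown in these snippets; judging by how it is called with `polyak=0.995`, it presumably performs the usual soft target update, target <- polyak * target + (1 - polyak) * online. A minimal sketch under that assumption:

import torch

@torch.no_grad()
def update_polyak(online, target, polyak):
    # Soft target update: target <- polyak * target + (1 - polyak) * online.
    for p, p_targ in zip(online.parameters(), target.parameters()):
        p_targ.mul_(polyak).add_(p, alpha=1 - polyak)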
Example #5
def td3(env,
        policy,
        q_func=None,
        total_steps=TOTAL_STEPS_DEFAULT,
        gamma=0.99,
        replay_size=REPLAY_SIZE_DEFAULT,
        polyak=0.995,
        start_steps=10000,
        epoch=5000,
        pi_lr=1e-3,
        qf_lr=1e-3,
        mb_size=100,
        act_noise=0.1,
        max_ep_length=1000,
        target_noise=0.2,
        noise_clip=0.5,
        policy_delay=2,
        updates_per_step=1.0,
        **saver_kwargs):
    # Set and save experiment hyperparameters
    q_func = q_func or ContinuousQFunction.from_policy(policy)
    save_config(locals())
    saver = SnapshotSaver(logger.get_dir(), locals(), **saver_kwargs)

    # Initialize environments, models and replay buffer
    vec_env = VecEnvMaker(env)()
    test_env = VecEnvMaker(env)(train=False)
    ob_space, ac_space = vec_env.observation_space, vec_env.action_space
    pi_class, pi_args = policy.pop("class"), policy
    qf_class, qf_args = q_func.pop("class"), q_func
    policy = pi_class(vec_env, **pi_args)
    q1func = qf_class(vec_env, **qf_args)
    q2func = qf_class(vec_env, **qf_args)
    replay = ReplayBuffer(replay_size, ob_space, ac_space)

    # Initialize optimizers and target networks
    loss_fn = torch.nn.MSELoss()
    pi_optim = torch.optim.Adam(policy.parameters(), lr=pi_lr)
    qf_optim = torch.optim.Adam(chain(q1func.parameters(),
                                      q2func.parameters()),
                                lr=qf_lr)
    pi_targ = pi_class(vec_env, **pi_args)
    q1_targ = qf_class(vec_env, **qf_args)
    q2_targ = qf_class(vec_env, **qf_args)
    pi_targ.load_state_dict(policy.state_dict())
    q1_targ.load_state_dict(q1func.state_dict())
    q2_targ.load_state_dict(q2func.state_dict())

    # Save initial state
    saver.save_state(
        index=0,
        state=dict(
            alg=dict(samples=0),
            policy=policy.state_dict(),
            q1func=q1func.state_dict(),
            q2func=q2func.state_dict(),
            pi_optim=pi_optim.state_dict(),
            qf_optim=qf_optim.state_dict(),
            pi_targ=pi_targ.state_dict(),
            q1_targ=q1_targ.state_dict(),
            q2_targ=q2_targ.state_dict(),
        ),
    )

    # Setup and run policy tests
    ob, don = test_env.reset(), False

    @torch.no_grad()
    def test_policy():
        nonlocal ob, don
        for _ in range(10):
            while not don:
                act = policy.actions(torch.from_numpy(ob))
                ob, _, don, _ = test_env.step(act.numpy())
            don = False
        log_reward_statistics(test_env, num_last_eps=10, prefix="Test")

    test_policy()
    logger.logkv("Epoch", 0)
    logger.logkv("TotalNSamples", 0)
    logger.dumpkvs()

    # Set action sampling strategies
    def rand_uniform_actions(_):
        return np.stack([ac_space.sample() for _ in range(vec_env.num_envs)])

    act_low, act_high = map(torch.Tensor, (ac_space.low, ac_space.high))

    @torch.no_grad()
    def noisy_policy_actions(obs):
        acts = policy.actions(torch.from_numpy(obs))
        acts += act_noise * torch.randn_like(acts)
        return np.clip(acts.numpy(), ac_space.low, ac_space.high)

    # Algorithm main loop
    obs1, ep_length, critic_updates = vec_env.reset(), 0, 0
    for samples in trange(1, total_steps + 1, desc="Training", unit="step"):
        if samples <= start_steps:
            actions = rand_uniform_actions
        else:
            actions = noisy_policy_actions

        acts = actions(obs1)
        obs2, rews, dones, _ = vec_env.step(acts)
        ep_length += 1
        dones[0] = False if ep_length == max_ep_length else dones[0]

        as_tensors = map(torch.from_numpy,
                         (obs1, acts, rews, obs2, dones.astype("f")))
        for ob1, act, rew, ob2, done in zip(*as_tensors):
            replay.store(ob1, act, rew, ob2, done)
        obs1 = obs2

        if (dones[0] or ep_length == max_ep_length) and replay.size >= mb_size:
            for _ in range(int(ep_length * updates_per_step)):
                ob_1, act_, rew_, ob_2, done_ = replay.sample(mb_size)
                with torch.no_grad():
                    atarg = pi_targ(ob_2)
                    atarg += torch.clamp(
                        target_noise * torch.randn_like(atarg), -noise_clip,
                        noise_clip)
                    atarg = torch.max(torch.min(atarg, act_high), act_low)
                    targs = rew_ + gamma * (1 - done_) * torch.min(
                        q1_targ(ob_2, atarg), q2_targ(ob_2, atarg))

                qf_optim.zero_grad()
                q1_val = q1func(ob_1, act_)
                q2_val = q2func(ob_1, act_)
                q1_loss = loss_fn(q1_val, targs).div(2)
                q2_loss = loss_fn(q2_val, targs).div(2)
                q1_loss.add(q2_loss).backward()
                qf_optim.step()

                critic_updates += 1
                if critic_updates % policy_delay == 0:
                    pi_optim.zero_grad()
                    qpi_val = q1func(ob_1, policy(ob_1))
                    pi_loss = qpi_val.mean().neg()
                    pi_loss.backward()
                    pi_optim.step()

                    update_polyak(policy, pi_targ, polyak)
                    update_polyak(q1func, q1_targ, polyak)
                    update_polyak(q2func, q2_targ, polyak)

                    logger.logkv_mean("QPiVal", qpi_val.mean().item())
                    logger.logkv_mean("PiLoss", pi_loss.item())

                logger.logkv_mean("Q1Val", q1_val.mean().item())
                logger.logkv_mean("Q2Val", q2_val.mean().item())
                logger.logkv_mean("Q1Loss", q1_loss.item())
                logger.logkv_mean("Q2Loss", q2_loss.item())

            ep_length = 0

        if samples % epoch == 0:
            test_policy()
            logger.logkv("Epoch", samples // epoch)
            logger.logkv("TotalNSamples", samples)
            log_reward_statistics(vec_env)
            logger.dumpkvs()

            saver.save_state(
                index=samples // epoch,
                state=dict(
                    alg=dict(samples=samples),
                    policy=policy.state_dict(),
                    q1func=q1func.state_dict(),
                    q2func=q2func.state_dict(),
                    pi_optim=pi_optim.state_dict(),
                    qf_optim=qf_optim.state_dict(),
                    pi_targ=pi_targ.state_dict(),
                    q1_targ=q1_targ.state_dict(),
                    q2_targ=q2_targ.state_dict(),
                ),
            )
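`ReplayBuffer` is also not shown. A minimal ring-buffer sketch compatible with how it is used in both examples (constructed with a capacity and the observation/action spaces, filled via `store`, checked via `size`, and sampled as a tuple of tensors) might look as follows; the array layout and dtypes are assumptions.

import numpy as np
import torch

class ReplayBuffer:
    # Minimal ring buffer matching the interface used above:
    # store(ob1, act, rew, ob2, done) and sample(batch_size) -> tensors.
    def __init__(self, capacity, ob_space, ac_space):
        self.obs1 = np.zeros((capacity, *ob_space.shape), dtype=np.float32)
        self.acts = np.zeros((capacity, *ac_space.shape), dtype=np.float32)
        self.rews = np.zeros(capacity, dtype=np.float32)
        self.obs2 = np.zeros((capacity, *ob_space.shape), dtype=np.float32)
        self.dones = np.zeros(capacity, dtype=np.float32)
        self.capacity, self.ptr, self.size = capacity, 0, 0

    def store(self, ob1, act, rew, ob2, done):
        i = self.ptr
        self.obs1[i], self.acts[i], self.rews[i] = ob1, act, rew
        self.obs2[i], self.dones[i] = ob2, done
        self.ptr = (self.ptr + 1) % self.capacity
        self.size = min(self.size + 1, self.capacity)

    def sample(self, batch_size):
        idx = np.random.randint(0, self.size, size=batch_size)
        return tuple(torch.as_tensor(a[idx]) for a in
                     (self.obs1, self.acts, self.rews, self.obs2, self.dones))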