Example #1
def train():

    parser = argparse.ArgumentParser()
    parser.add_argument('--policy', type=str, default='td3')
    parser.add_argument('--policy_type', type=str, default='mlp')
    parser.add_argument('--env', type=str, default='HalfCheetah-v3')
    parser.add_argument('--hidden_sizes', nargs='+', type=int)
    # parser.add_argument('--layers_len', type=int, default=2)
    # parser.add_argument('--hidden_size', type=int, default=256)
    parser.add_argument('--seed', '-s', type=int, default=0)
    parser.add_argument('--start_timesteps', type=int, default=int(25e3))
    parser.add_argument('--gamma', type=float, default=0.99)
    parser.add_argument('--epochs', type=int, default=250)
    parser.add_argument('--steps_per_epoch', type=int, default=5000)
    parser.add_argument('--max_episode_len', type=int, default=1000)
    parser.add_argument('--buffer_size', type=int, default=int(1e6))
    parser.add_argument('--save_freq', type=int, default=10)
    parser.add_argument('--pi_lr', type=float, default=1e-3)
    parser.add_argument('--q_lr', type=float, default=1e-3)
    parser.add_argument('--expl_noise', type=float, default=0.1)
    parser.add_argument('--batch_size', type=int, default=100)
    parser.add_argument('--tau', type=float, default=0.005)
    parser.add_argument('--policy_noise', type=float, default=0.2)
    parser.add_argument('--policy_freq', type=int, default=2)
    parser.add_argument('--noise_clip', type=float, default=0.5)
    parser.add_argument('--device', type=str, default='cuda:0')
    parser.add_argument('--cpu', type=int, default=1)
    parser.add_argument('--datestamp', action='store_true')
    parser.add_argument('--obs_norm', action='store_true')
    parser.add_argument('--obs_clip', type=float, default=5.0)
    args = parser.parse_args()

    if args.cpu > 1:
        args.policy = 'mp_' + args.policy

    file_name = f'{args.policy}_{args.env}_{args.seed}'
    print('-----' * 8)
    print(f'Policy: {args.policy}, Env: {args.env}, Seed: {args.seed}')
    print('-----' * 8)

    if not os.path.exists(DEFAULT_MODEL_DIR):
        os.makedirs(DEFAULT_MODEL_DIR)

    # Setup logger and save config
    logger_kwargs = setup_logger_kwargs(f'{args.policy}_{args.env}',
                                        args.seed,
                                        datestamp=args.datestamp)
    logger = EpochLogger(**logger_kwargs)
    logger.save_config(args)

    # Init Environment
    env = make_envs(args.env, args.cpu, args.seed)
    obs_dim = env.observation_space.shape
    act_dim = env.action_space.shape

    # Set Seeds
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    kwargs = {
        'env': env,
        'actor_critic': MLPDetActorCritic,
        # 'ac_kwargs': dict(hidden_sizes=[args.hidden_size]*args.layers_len),
        # 'ac_kwargs': dict(hidden_sizes=[400, 300]),
        # NOTE: pso_kwargs is not defined in this function; it must be provided at module scope
        'ac_kwargs': dict(hidden_sizes=args.hidden_sizes,
                          pso_kwargs=pso_kwargs),
        'gamma': args.gamma,
        'batch_size': args.batch_size,
        'tau': args.tau,
        'expl_noise': args.expl_noise,
        'policy_noise': args.policy_noise,
        'policy_freq': args.policy_freq,
        'pi_lr': args.pi_lr,
        'q_lr': args.q_lr,
        'noise_clip': args.noise_clip,
        'device': args.device,
        'logger': logger
    }

    policy = TD3(**kwargs)  # both policy_type branches currently construct TD3

    # Count variables
    if args.policy_type == 'mlp':
        var_counts = tuple(
            core.count_vars(module)
            for module in [policy.ac.pi, policy.ac.q1, policy.ac.q2])
        logger.log(
            '\nNumber of parameters: \t pi: %d, \t q1: %d,  \t q2: %d\n' %
            var_counts)
    else:
        var_counts = core.count_vars(policy.ac)
        logger.log('\nNumber of parameters: \t  pi_q1_q2: %d\n' % var_counts)

    # Set up model saving
    logger.setup_pytorch_saver(policy.ac.state_dict())

    buf = core.ReplayBuffer(obs_dim,
                            act_dim,
                            size=args.buffer_size,
                            device=args.device)

    # Prepare for interaction with environment
    total_steps = args.steps_per_epoch * args.epochs
    start_time = time.time()
    obs, done = env.reset(), [False for _ in range(args.cpu)]
    if args.obs_norm:
        ObsNormal = core.ObsNormalize(
            obs_dim, args.cpu, args.obs_clip)  # Normalize the observation
        obs = ObsNormal.normalize_all(obs)
    episode_rew = np.zeros(args.cpu, dtype=np.float32)
    episode_len = np.zeros(args.cpu)

    # Main loop: collect experience in env and update/log each epoch
    for t in range(0, total_steps, args.cpu):

        # Until start_timesteps have elapsed, randomly sample actions
        # from a uniform distribution for better exploration. Afterwards,
        # use the learned policy (with some noise, via expl_noise).
        if t < args.start_timesteps:
            act = env.action_space.sample()[None]
        else:
            act = policy.select_action(obs)

        # Step the env
        if args.cpu == 1:
            next_obs, rew, done, info = env.step(act.squeeze(axis=0))
        else:
            next_obs, rew, done, info = env.step(act)

        if args.obs_norm:
            next_obs = ObsNormal.normalize_all(next_obs)
        episode_rew += rew
        episode_len += 1

        # Ignore the "done" signal if it comes from hitting the time
        # horizon (that is, when it's an artificial terminal signal
        # that isn't based on the agent's state).
        for idx, d in enumerate(done):
            done[idx] = False if episode_len[idx] == args.max_episode_len else d

        # Store experience in the replay buffer
        buf.add(obs, act, rew, next_obs, done)

        # Super critical, easy to overlook step: make sure to update
        # most recent observation!
        obs = next_obs

        for idx in range(args.cpu):
            if done[idx] or (episode_len[idx] == args.max_episode_len):
                logger.store(EpRet=episode_rew[idx], EpLen=episode_len[idx])
                obs[idx] = env.reset(idx)
                episode_rew[idx], episode_len[idx] = 0, 0

        # Update handling: after the warm-up period, do one gradient update per loop iteration
        if t > args.start_timesteps:
            batch = buf.sample(args.batch_size)
            policy.update(batch)

        # End of epoch handling
        if t >= args.start_timesteps and (t + 1) % args.steps_per_epoch == 0:
            epoch = (t + 1) // args.steps_per_epoch

            # Test the performance of the deterministic version of the agent
            policy.eval_policy(args.env, args.seed)

            # Log info about epoch
            logger.log_tabular('Exp', file_name)
            logger.log_tabular('Epoch', epoch)
            logger.log_tabular('EpRet', with_min_and_max=True)
            logger.log_tabular('TestEpRet', with_min_and_max=True)
            logger.log_tabular('EpLen', average_only=True)
            logger.log_tabular('TestEplen', average_only=True)
            logger.log_tabular('TotalEnvInteracts', t + 1)
            logger.log_tabular('Q1Vals', with_min_and_max=True)
            logger.log_tabular('Q2Vals', with_min_and_max=True)
            logger.log_tabular('LossPi', average_only=True)
            logger.log_tabular('LossQ', average_only=True)
            logger.log_tabular('Time', (time.time() - start_time) / 60)
            if args.obs_norm:
                logger.log_tabular('obs_mean', ObsNormal.mean.mean())
                logger.log_tabular('obs_std', np.sqrt(ObsNormal.var).mean())
            logger.dump_tabular()

            # Save model
            if (epoch % args.save_freq == 0) or (epoch == args.epochs):
                logger.save_state(
                    dict(obs_normal=ObsNormal if args.obs_norm else None),
                    None)
Example #2
def ppo():

    parser = argparse.ArgumentParser()
    parser.add_argument('--policy', type=str, default='ppo')
    parser.add_argument('--env', type=str, default='HalfCheetah-v3')
    parser.add_argument('--hidden', type=int, default=64)
    parser.add_argument('--layers_len', type=int, default=2)
    parser.add_argument('--gamma', type=float, default=0.99)
    parser.add_argument('--clip_ratio', type=float, default=0.2)
    parser.add_argument('--pi_lr', type=float, default=3e-4)
    parser.add_argument('--vf_lr', type=float, default=1e-3)
    parser.add_argument('--train_pi_iters', type=int, default=80)
    parser.add_argument('--train_v_iters', type=int, default=80)
    parser.add_argument('--lam', type=float, default=0.97)
    parser.add_argument('--target_kl', type=float, default=0.01)
    parser.add_argument('--device', type=str, default='cuda:3')
    parser.add_argument('--datestamp', action='store_true')
    parser.add_argument('--seed', '-s', type=int, default=0)
    parser.add_argument('--cpu', type=int, default=1)
    parser.add_argument('--step_per_epoch', type=int, default=4000)
    parser.add_argument('--max_episode_len', type=int, default=1000)
    parser.add_argument('--epochs', type=int, default=250)
    parser.add_argument('--max_timesteps', type=int, default=1e6)
    parser.add_argument('--save_freq', type=int, default=10)
    parser.add_argument('--obs_norm', action='store_true')
    parser.add_argument('--obs_clip', type=float, default=5.0)
    parser.add_argument('--use_clipped_value_loss', action='store_true')
    parser.add_argument('--clip_val_param', type=float, default=80.0)
    args = parser.parse_args()

    file_name = f'{args.policy}_{args.env}_{args.seed}'
    print('-----' * 8)
    print(f'Policy: {args.policy}, Env: {args.env}, Seed: {args.seed}')
    print('-----' * 8)

    if not os.path.exists(DEFAULT_MODEL_DIR):
        os.makedirs(DEFAULT_MODEL_DIR)

    # Set up logger and save configuration
    logger_kwargs = setup_logger_kwargs(f'{args.policy}_{args.env}',
                                        args.seed,
                                        datestamp=args.datestamp)
    logger = EpochLogger(**logger_kwargs)
    logger.save_config(args)

    # Init Environment
    env = gym.make(args.env)
    obs_dim = env.observation_space.shape
    act_dim = env.action_space.shape

    # Set seeds
    env.seed(args.seed)
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    env.action_space.seed(args.seed)

    kwargs = {
        "env": env,
        "actor_critic": core.MLPActorCritic,
        "ac_kwargs": dict(hidden_sizes=[args.hidden] * args.layers_len),
        "gamma": args.gamma,
        "clip_ratio": args.clip_ratio,
        "pi_lr": args.pi_lr,
        "vf_lr": args.vf_lr,
        "train_pi_iters": args.train_pi_iters,
        "train_v_iters": args.train_v_iters,
        "lam": args.lam,
        "target_kl": args.target_kl,
        "use_clipped_value_loss": args.use_clipped_value_loss,
        "clip_val_param": args.clip_val_param,
        "device": args.device,
        "logger": logger
    }

    policy = PPO(**kwargs)

    # Count variables
    var_counts = tuple(
        core.count_vars(module) for module in [policy.ac.pi, policy.ac.v])
    logger.log('\nNumber of parameters: \t pi: %d, \t v: %d\n' % var_counts)

    # Set up model saving
    logger.setup_pytorch_saver(policy.ac.state_dict())

    local_steps_per_epoch = int(args.step_per_epoch)
    buf = core.PPOBuffer(obs_dim, act_dim, local_steps_per_epoch, args.gamma,
                         args.lam)
    # Prepare for interaction with environment
    start_time = time.time()
    obs, done = env.reset(), False
    if args.obs_norm:
        ObsNormal = core.ObsNormalize(
            obs.shape, args.obs_clip)  # Normalize the observation
        obs = ObsNormal.normalize(obs)
    episode_ret = 0.
    episode_len = 0.

    # Main loop: collect experience in env and update/log each epoch
    for epoch in range(args.epochs):
        for t in range(local_steps_per_epoch):
            act, val, logp = policy.step(obs)

            next_obs, ret, done, _ = env.step(act)
            if args.obs_norm:
                next_obs = ObsNormal.normalize(next_obs)
            episode_ret += ret
            episode_len += 1

            # save and log
            buf.add(obs, act, ret, val, logp)
            logger.store(VVals=val)

            # Update obs (critical!)
            obs = next_obs

            timeout = episode_len == args.max_episode_len
            terminal = done or timeout
            epoch_ended = t == local_steps_per_epoch - 1
            if epoch_ended or terminal:
                if epoch_ended and not terminal:
                    print(
                        f'Warning: Trajectory cut off by epoch at {episode_len} steps',
                        flush=True)
                # if trajectory didn't reach terminal state, bootstrap value target
                if timeout or epoch_ended:
                    _, val, _ = policy.step(obs)
                else:
                    val = 0
                buf.finish_path(val)
                if terminal:
                    # only save EpRet / EpLen if trajectory finished
                    logger.store(EpRet=episode_ret, EpLen=episode_len)
                obs, episode_ret, episode_len = env.reset(), 0, 0
                # if args.obs_norm: obs = ObsNormal.normalize(obs)

        policy.update(buf)

        # Log info about epoch
        logger.log_tabular('Exp', file_name)
        logger.log_tabular('Epoch', epoch)
        logger.log_tabular('EpRet', with_min_and_max=True)
        logger.log_tabular('EpLen', average_only=True)
        logger.log_tabular('VVals', with_min_and_max=True)
        logger.log_tabular('TotalEnvInteracts',
                           (epoch + 1) * args.step_per_epoch)
        logger.log_tabular('LossPi', average_only=True)
        logger.log_tabular('LossV', average_only=True)
        logger.log_tabular('DeltaLossPi', average_only=True)
        logger.log_tabular('DeltaLossV', average_only=True)
        logger.log_tabular('Entropy', average_only=True)
        logger.log_tabular('KL', average_only=True)
        logger.log_tabular('ClipFrac', average_only=True)
        logger.log_tabular('StopIter', average_only=True)
        logger.log_tabular('Time', int((time.time() - start_time) / 60))
        if args.obs_norm:
            logger.log_tabular('obs_mean', ObsNormal.mean.mean())
            logger.log_tabular('obs_std', np.sqrt(ObsNormal.var).mean())
        logger.dump_tabular()

        # Save model
        if (epoch % args.save_freq == 0) or (epoch == args.epochs - 1):
            torch.save(policy.ac.state_dict(),
                       f'{DEFAULT_MODEL_DIR}/{file_name}.pth')
            logger.save_state({'env': env}, None)
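Both PPO examples depend on a GAE-lambda buffer (core.PPOBuffer here, core.PPO_mp_Buffer in the multi-env variant) whose source is not included. A minimal single-env sketch with the same add/finish_path interface follows; the get() method name and the advantage normalization are assumptions about the retrieval side.

import numpy as np
import scipy.signal


def discount_cumsum(x, discount):
    # [x0 + d*x1 + d^2*x2, x1 + d*x2, x2] for input [x0, x1, x2]
    return scipy.signal.lfilter([1], [1, float(-discount)], x[::-1], axis=0)[::-1]


class PPOBuffer:
    """On-policy buffer computing GAE-lambda advantages and rewards-to-go (sketch)."""

    def __init__(self, obs_dim, act_dim, size, gamma=0.99, lam=0.95):
        self.obs_buf = np.zeros((size, *obs_dim), dtype=np.float32)
        self.act_buf = np.zeros((size, *act_dim), dtype=np.float32)
        self.rew_buf = np.zeros(size, dtype=np.float32)
        self.val_buf = np.zeros(size, dtype=np.float32)
        self.logp_buf = np.zeros(size, dtype=np.float32)
        self.adv_buf = np.zeros(size, dtype=np.float32)
        self.ret_buf = np.zeros(size, dtype=np.float32)
        self.gamma, self.lam = gamma, lam
        self.ptr, self.path_start_idx, self.max_size = 0, 0, size

    def add(self, obs, act, rew, val, logp):
        assert self.ptr < self.max_size
        self.obs_buf[self.ptr] = obs
        self.act_buf[self.ptr] = act
        self.rew_buf[self.ptr] = rew
        self.val_buf[self.ptr] = val
        self.logp_buf[self.ptr] = logp
        self.ptr += 1

    def finish_path(self, last_val=0):
        # Bootstrap with last_val when the trajectory was cut off, 0 on a true terminal.
        path_slice = slice(self.path_start_idx, self.ptr)
        rews = np.append(self.rew_buf[path_slice], last_val)
        vals = np.append(self.val_buf[path_slice], last_val)
        deltas = rews[:-1] + self.gamma * vals[1:] - vals[:-1]
        self.adv_buf[path_slice] = discount_cumsum(deltas, self.gamma * self.lam)
        self.ret_buf[path_slice] = discount_cumsum(rews, self.gamma)[:-1]
        self.path_start_idx = self.ptr

    def get(self):
        assert self.ptr == self.max_size  # buffer must be full before an update
        self.ptr, self.path_start_idx = 0, 0
        adv = (self.adv_buf - self.adv_buf.mean()) / (self.adv_buf.std() + 1e-8)
        return dict(obs=self.obs_buf, act=self.act_buf, ret=self.ret_buf,
                    adv=adv, logp=self.logp_buf)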
Example #3
def train():
    parser = argparse.ArgumentParser()
    parser.add_argument('--policy', type=str, default='ppo')
    parser.add_argument('--policy_type', type=str, default='mlp')
    parser.add_argument('--env', type=str, default='HalfCheetah-v3')
    parser.add_argument('--hidden', type=int, default=64)
    parser.add_argument('--layers_len', type=int, default=2)
    parser.add_argument('--gamma', type=float, default=0.99)
    parser.add_argument('--clip_ratio', type=float, default=0.2)
    parser.add_argument('--pi_lr', type=float, default=3e-4)
    parser.add_argument('--vf_lr', type=float, default=1e-3)
    parser.add_argument('--train_pi_iters', type=int, default=80)
    parser.add_argument('--train_v_iters', type=int, default=80)
    parser.add_argument('--lam', type=float, default=0.97)
    parser.add_argument('--target_kl', type=float, default=0.01)
    parser.add_argument('--device', type=str, default='cuda:3')
    parser.add_argument('--seed', '-s', type=int, default=0)
    parser.add_argument('--cpu', type=int, default=4)
    parser.add_argument('--datestamp', action='store_true')
    parser.add_argument('--steps_per_epoch', type=int, default=4000)
    parser.add_argument('--max_episode_len', type=int, default=1000)
    parser.add_argument('--epochs', type=int, default=250)
    parser.add_argument('--max_timesteps', type=int, default=1e6)
    parser.add_argument('--save_freq', type=int, default=10)
    parser.add_argument('--obs_norm', action='store_true')
    parser.add_argument('--obs_clip', type=float, default=5.0)
    parser.add_argument('--use_clipped_value_loss', action='store_true')
    parser.add_argument('--clip_val_param', type=float, default=80.0)
    args = parser.parse_args()

    if args.cpu > 1: args.policy = 'mp_' + args.policy

    file_name = f'{args.policy}_{args.env}_{args.seed}'
    print('-----' * 8)
    print(f'Policy: {args.policy}, Env: {args.env}, Seed: {args.seed}')
    print('-----' * 8)

    if not os.path.exists(DEFAULT_MODEL_DIR):
        os.makedirs(DEFAULT_MODEL_DIR)

    # Set up logger and save configuration
    logger_kwargs = setup_logger_kwargs(f'{args.policy}_{args.env}',
                                        args.seed,
                                        datestamp=args.datestamp)
    logger = EpochLogger(**logger_kwargs)
    logger.save_config(args)

    # Init Environment
    # SingleEnv wrapper; also seeds the env and its action_space
    env = make_envs(args.env, args.cpu, args.seed)
    obs_dim = env.observation_space.shape
    act_dim = env.action_space.shape if isinstance(env.action_space,
                                                   gym.spaces.Box) else (1, )

    # Set seeds
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    kwargs = {
        "env": env,
        "actor_critic": (core.CNNActorCritic
                         if args.policy_type == 'cnn' else core.MLPActorCritic),
        "ac_kwargs": dict(hidden_sizes=[args.hidden] * args.layers_len),
        "gamma": args.gamma,
        "clip_ratio": args.clip_ratio,
        "pi_lr": args.pi_lr,
        "vf_lr": args.vf_lr,
        "train_pi_iters": args.train_pi_iters,
        "train_v_iters": args.train_v_iters,
        "lam": args.lam,
        "target_kl": args.target_kl,
        "use_clipped_value_loss": args.use_clipped_value_loss,
        "clip_val_param": args.clip_val_param,
        "device": args.device,
        "logger": logger
    }

    policy = PPO(**kwargs)
    # policy.ac.share_memory()

    # Count variables
    var_counts = tuple(
        core.count_vars(module) for module in [policy.ac.pi, policy.ac.v])
    logger.log('\nNumber of parameters: \t pi: %d, \t v: %d\n' % var_counts)

    # Set up model saving
    logger.setup_pytorch_saver(policy.ac.state_dict())

    local_steps_per_epoch = int(args.steps_per_epoch / args.cpu)
    buf = core.PPO_mp_Buffer(obs_dim, act_dim, local_steps_per_epoch,
                             args.gamma, args.lam, args.cpu, args.device)
    # Prepare for interaction with environment
    start_time = time.time()
    obs, done = env.reset(), [False for _ in range(args.cpu)]
    if args.obs_norm:
        ObsNormal = core.ObsNormalize(
            obs_dim, args.cpu, args.obs_clip)  # Normalize the observation
        obs = ObsNormal.normalize_all(obs)
    episode_ret = np.zeros(args.cpu, dtype=np.float32)
    episode_len = np.zeros(args.cpu)

    # Main loop: collect experience in env and update/log each epoch
    for epoch in range(args.epochs):
        for t in range(local_steps_per_epoch):
            act, val, logp = policy.step(obs)

            next_obs, ret, done, info = env.step(act)
            if args.obs_norm:
                next_obs = ObsNormal.normalize_all(next_obs)
            episode_ret += ret
            episode_len += 1

            # save and log
            buf.add(obs, act, ret, val, logp)
            logger.store(VVals=val)

            # Update obs (critical!)
            obs = next_obs
            # In a multiprocess env, a finished episode used to be reset automatically,
            # with the true terminal observation stored in info['terminal_observation'].
            # Update: automatic reset has been removed, so episodes are reset manually below.

            timeout = episode_len == args.max_episode_len
            terminal = np.logical_or(done, timeout)  # per-env episode end
            epoch_ended = t == local_steps_per_epoch - 1

            # This feels too verbose; haven't found a cleaner way to write it yet.

            for idx in range(args.cpu):
                if epoch_ended or terminal[idx]:
                    if epoch_ended and not terminal[idx]:
                        print(
                            f'Warning: Trajectory {idx} cut off by epoch at {episode_len[idx]} steps',
                            flush=True)
                    # if trajectory didn't reach terminal state, bootstrap value target
                    if timeout[idx] or epoch_ended:
                        _, val, _ = policy.step(obs[idx][None])
                    else:
                        val = 0
                    buf.finish_path(val, idx)
                    if terminal[idx]:
                        # only save EpRet / EpLen if trajectory finished
                        logger.store(EpRet=episode_ret[idx],
                                     EpLen=episode_len[idx])
                    obs[idx] = env.reset(idx)
                    episode_ret[idx], episode_len[idx] = 0, 0
                    # if args.obs_norm: obs = ObsNormal.normalize(obs)

        # During experiments, resetting without normalizing the observation performed better.

        policy.update(buf)

        # Log info about epoch
        logger.log_tabular('Exp', file_name)
        logger.log_tabular('Epoch', epoch)
        logger.log_tabular('EpRet', with_min_and_max=True)
        logger.log_tabular('EpLen', average_only=True)
        logger.log_tabular('VVals', with_min_and_max=True)
        logger.log_tabular('TotalEnvInteracts',
                           (epoch + 1) * args.steps_per_epoch)
        logger.log_tabular('LossPi', average_only=True)
        logger.log_tabular('LossV', average_only=True)
        logger.log_tabular('DeltaLossPi', average_only=True)
        logger.log_tabular('DeltaLossV', average_only=True)
        logger.log_tabular('Entropy', average_only=True)
        logger.log_tabular('KL', average_only=True)
        logger.log_tabular('ClipFrac', average_only=True)
        logger.log_tabular('StopIter', average_only=True)
        logger.log_tabular('Time', int((time.time() - start_time) // 60))
        if args.obs_norm:
            logger.log_tabular('obs_mean', ObsNormal.mean.mean())
            logger.log_tabular('obs_std', np.sqrt(ObsNormal.var).mean())
        logger.dump_tabular()

        # Save model
        if (epoch % args.save_freq == 0) or (epoch == args.epochs - 1):
            torch.save(policy.ac.state_dict(),
                       f'{DEFAULT_MODEL_DIR}/{file_name}.pth')
            logger.save_state(
                dict(obs_normal=ObsNormal if args.obs_norm else None), None)
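When --obs_norm is set, these scripts route observations through core.ObsNormalize, which tracks running statistics and clips to +/- obs_clip. Below is a sketch of such a normalizer with the normalize_all / mean / var interface used above; the constructor signature and the parallel-variance update rule are assumptions, not the actual implementation.

import numpy as np


class ObsNormalize:
    """Running mean/std observation normalizer with clipping (sketch)."""

    def __init__(self, obs_dim, n_envs=1, clip=5.0, eps=1e-8):
        self.mean = np.zeros(obs_dim, dtype=np.float64)
        self.var = np.ones(obs_dim, dtype=np.float64)
        self.count = eps
        self.clip = clip
        self.n_envs = n_envs

    def update(self, batch):
        # Fold batch statistics into the running statistics (parallel variance formula).
        batch_mean = batch.mean(axis=0)
        batch_var = batch.var(axis=0)
        batch_count = batch.shape[0]
        delta = batch_mean - self.mean
        tot = self.count + batch_count
        new_mean = self.mean + delta * batch_count / tot
        m_a = self.var * self.count
        m_b = batch_var * batch_count
        m2 = m_a + m_b + delta ** 2 * self.count * batch_count / tot
        self.mean, self.var, self.count = new_mean, m2 / tot, tot

    def normalize_all(self, obs):
        # obs has shape (n_envs, *obs_dim); update stats, then standardize and clip.
        obs = np.asarray(obs, dtype=np.float64)
        self.update(obs)
        normed = (obs - self.mean) / np.sqrt(self.var + 1e-8)
        return np.clip(normed, -self.clip, self.clip).astype(np.float32)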
Example #4
    # Count variables
    var_counts = tuple(
        core.count_vars(module) for module in [policy.ac.pi, policy.ac.v])
    logger.log('\nNumber of parameters: \t pi: %d, \t v: %d\n' % var_counts)

    # Set up model saving
    logger.setup_pytorch_saver(policy.ac.state_dict())

    local_steps_per_epoch = int(args.step_per_epoch / args.cpu)
    buf = core.PPO_mp_Buffer(obs_dim, act_dim, local_steps_per_epoch,
                             args.gamma, args.lam, args.cpu)
    # Prepare for interaction with environment
    start_time = time.time()
    obs, done = env.reset(), [False for _ in range(args.cpu)]
    if args.obs_norm:
        ObsNormal = core.ObsNormalize(
            obs_dim, args.cpu, args.obs_clip)  # Normalize the observation
        obs = ObsNormal.normalize_all(obs)
    episode_ret = np.zeros(args.cpu, dtype=np.float32)
    episode_len = np.zeros(args.cpu)

    # Main loop: collect experience in env and update/log each epoch
    for epoch in range(args.epochs):
        for t in range(local_steps_per_epoch):
            act, val, logp = policy.step(obs)

            next_obs, ret, done, info = env.step(act)
            if args.obs_norm:
                next_obs = ObsNormal.normalize_all(next_obs)
            episode_ret += ret
            episode_len += 1
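
All four snippets log parameter counts via core.count_vars. Its source is not included; a typical implementation (assumed here) simply sums the parameter shapes of a torch.nn.Module.

import numpy as np


def count_vars(module):
    # Total number of scalar parameters in a torch.nn.Module.
    return sum(int(np.prod(p.shape)) for p in module.parameters())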