Example #1
File: gail.py Project: shibei00/rltoolkit
    def run_gail(seed: int):
        env = utils.create_environment(environment)
        set_global_seeds(seed)

        policy_fn = make_policy(hidden_layers=(64, 64),
                                hidden_activation=tf.tanh)
        gail = GAIL(env,
                    policy_fn,
                    expert_rollouts,
                    episode_length=2048,
                    lam=0.95,
                    gamma=0.99,
                    ent_coef=0.01,
                    lr=3e-4,
                    cliprange=0.2)
        video_saver = make_video_saver(folder_name=osp.join(
            logger.get_dir(), 'videos'),
                                       prefix='gail',
                                       interval=20)
        gail.train(nminibatches=64,
                   noptepochs=10,
                   log_interval=1,
                   total_timesteps=num_timesteps,
                   callback=video_saver)
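run_gail above is a nested helper: environment, expert_rollouts and num_timesteps are captured from its enclosing scope rather than passed in. A minimal, hypothetical driver for it could simply loop over a few seeds (the seed values below are arbitrary):

# Hypothetical driver loop, assuming run_gail and the names it closes over are in scope.
for seed in (0, 1, 2):
    run_gail(seed)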
def train_joint(args):
    """
    Simultaneously train Initialisation Policy and Baseline

    Arguments
    ---------
    args : dict
        Dictionary containing command line arguments


    """
    # Buffers to store statistics
    reward_init_buffer = list()
    loss_init_buffer = list()

    rewards_local_buffer = list()
    loss_local_buffer = list()

    mean_square_error_per_epoch = list()
    mean_relative_distance_per_epoch = list()

    generator = instance_generator(args.problem)

    # Initialise Initialisation policy and set its optimizer
    init_policy = select_initialization_policy(args)
    init_opt = T.optim.Adam(init_policy.parameters(), lr=args.init_lr_rate)

    # Initialize Baseline if required
    if args.use_baseline:
        baseline_net = Baseline(args.dim_context, args.dim_hidden)
        opt_base = T.optim.Adam(baseline_net.parameters(), lr=1e-4)
        loss_base_fn = T.nn.MSELoss()

    # Initialise local move policy
    local_move_policy = A2CLocalMovePolicy(args.dim_context,
                                           args.dim_problem,
                                           args.window_size,
                                           args.num_of_scenarios_in_state,
                                           gamma=args.gamma,
                                           beta_entropy=args.beta,
                                           num_local_move=args.num_local_move)

    # Train
    for epoch in range(1, args.epochs + 1):
        print("******************************************************")
        print(f"Epoch : {epoch}")
        # Generate instance and environment
        instance = generator.generate_instance()
        context = instance.get_context()

        env = create_environment(args, instance)

        # Learn using REINFORCE
        # If using baseline, update the baseline net
        if args.use_baseline:
            baseline_reward = baseline_net.forward(context)
            reward_init, loss_init, start_state = init_policy.REINFORCE(
                init_opt, env, context, baseline_reward, True)
            update_baseline_model(loss_base_fn, baseline_reward, reward_init,
                                  opt_base)
        # Without using baseline
        else:
            reward_init, loss_init, start_state = init_policy.REINFORCE(
                init_opt, env, context)
        reward_init_buffer.append(reward_init)
        loss_init_buffer.append(loss_init)

        # Learn using A2C
        rewards_local, loss_local = local_move_policy.train(start_state, env)
        rewards_local_buffer.append(rewards_local)
        loss_local_buffer.append(loss_local)

        # Save stats and model
        if epoch % 100 == 0:
            eval_stats = evaluate_model(args,
                                        env,
                                        generator,
                                        init_policy=init_policy,
                                        local_move_policy=local_move_policy)

            mean_square_error_per_epoch.append(eval_stats["mean_square_error"])
            mean_relative_distance_per_epoch.append(
                eval_stats["mean_relative_distance"])

            # Save init policy stats
            save_stats_and_model(args, epoch, reward_init_buffer,
                                 loss_init_buffer, mean_square_error_per_epoch,
                                 mean_relative_distance_per_epoch, init_policy,
                                 INIT)

            # Save local move policy stats
            save_stats_and_model(args, epoch, rewards_local_buffer,
                                 loss_local_buffer,
                                 mean_square_error_per_epoch,
                                 mean_relative_distance_per_epoch,
                                 local_move_policy, LOCAL)
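train_joint relies on a helper update_baseline_model that is not shown here. Its usual role in REINFORCE-with-baseline training is a single MSE regression step that pulls the baseline prediction towards the observed reward; the sketch below only illustrates that idea and is not the project's implementation (the function name and tensor handling are hypothetical).

import torch as T

def update_baseline_sketch(loss_fn, predicted_reward, observed_reward, optimizer):
    # Hypothetical stand-in for update_baseline_model: one MSE regression step
    # moving the baseline prediction towards the reward observed this epoch.
    target = T.as_tensor(float(observed_reward), dtype=T.float32).reshape(predicted_reward.shape)
    optimizer.zero_grad()
    loss = loss_fn(predicted_reward, target)
    loss.backward()
    optimizer.step()
    return loss.item()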
def train_local_move_policy(args):
    """
    Train Local Move Policy only

    Arguments
    ---------
    args : dict
        Dictionary containing command line arguments
    """
    rewards_per_epoch = list()
    loss_per_epoch = list()
    mean_square_error_per_epoch = list()
    mean_relative_distance_per_epoch = list()

    # Instance generator
    generator = instance_generator(args.problem)

    # Initialise local move policy
    local_move_policy = A2CLocalMovePolicy(args.dim_context,
                                           args.dim_problem,
                                           args.window_size,
                                           args.num_of_scenarios_in_state,
                                           gamma=args.gamma,
                                           beta_entropy=args.beta,
                                           num_local_move=args.num_local_move,
                                           lr_a2c=args.lr_a2c)

    # Train
    for epoch in range(1, args.epochs + 1):
        start_time = time.time()
        print("******************************************************")
        print(f"Epoch : {epoch}")

        # Generate instance and environment
        instance = generator.generate_instance()
        context = instance.get_context()
        env = create_environment(args, instance)

        start_state = generate_dummy_start_state(env, args.dim_problem)

        # Take num_local_move local moves to improve the provided initial solution
        rewards, loss = local_move_policy.train(start_state, env)

        rewards_per_epoch.append(rewards)
        loss_per_epoch.append(loss)

        # Save stats and model
        if epoch % 100 == 0:
            eval_stats = evaluate_model(args,
                                        env,
                                        generator,
                                        local_move_policy=local_move_policy)
            mean_square_error_per_epoch.append(eval_stats["mean_square_error"])
            mean_relative_distance_per_epoch.append(
                eval_stats["mean_relative_distance"])

            save_stats_and_model(args, epoch, rewards_per_epoch,
                                 loss_per_epoch, mean_square_error_per_epoch,
                                 mean_relative_distance_per_epoch,
                                 local_move_policy, LOCAL)

        print(
            f"Took {time.time() - start_time:.2f}s in epoch {epoch}/{args.epochs}")
def train_init_policy(args):
    """
    Train the Initialisation Policy

    Arguments
    ---------
    args : dict
        Dictionary containing command line arguments
    """
    rewards_per_epoch = list()
    loss_per_epoch = list()
    mean_square_error_per_epoch = list()
    mean_relative_distance_per_epoch = list()

    generator = instance_generator(args.problem)

    # Initialise Initialisation policy and set its optimizer
    init_policy = select_initialization_policy(args)
    init_opt = T.optim.Adam(init_policy.parameters(), lr=args.init_lr_rate)

    # Initialize Baseline if required
    if args.use_baseline:
        baseline_net = Baseline(args.dim_context, args.dim_hidden)
        opt_base = T.optim.Adam(baseline_net.parameters(), lr=1e-4)
        loss_base_fn = T.nn.MSELoss()

    # Train
    for epoch in range(1, args.epochs + 1):
        print("******************************************************")
        print(f"Epoch : {epoch}")
        # Generate instance and environment
        instance = generator.generate_instance()
        context = instance.get_context()

        env = create_environment(args, instance)

        # Learn using REINFORCE
        # If using baseline, update the baseline net
        if args.use_baseline:
            baseline_reward = baseline_net.forward(context)
            reward_, loss_init_, start_state = init_policy.REINFORCE(
                init_opt, env, context, baseline_reward, True)
            update_baseline_model(loss_base_fn, baseline_reward, reward_,
                                  opt_base)
        # Without using baseline
        else:
            reward_, loss_init_, start_state = init_policy.REINFORCE(
                init_opt, env, context)

        rewards_per_epoch.append(reward_.item())
        loss_per_epoch.append(loss_init_.item())

        # Save stats and model
        if epoch % 50 == 0:
            eval_stats = evaluate_model(args,
                                        env,
                                        generator,
                                        init_policy=init_policy)

            mean_square_error_per_epoch.append(eval_stats["mean_square_error"])
            mean_relative_distance_per_epoch.append(
                eval_stats["mean_relative_distance"])

            # Save init policy stats
            save_stats_and_model(args, epoch, rewards_per_epoch,
                                 loss_per_epoch, mean_square_error_per_epoch,
                                 mean_relative_distance_per_epoch, init_policy,
                                 INIT)
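All three training entry points above consume the same parsed command-line arguments. A hypothetical dispatcher tying them together might look like the following (the --mode flag and its values are assumptions, not part of the original repository):

def dispatch_training(args):
    # Hypothetical routing between the training routines defined above.
    if args.mode == "joint":
        train_joint(args)
    elif args.mode == "local":
        train_local_move_policy(args)
    elif args.mode == "init":
        train_init_policy(args)
    else:
        raise ValueError(f"Unknown training mode: {args.mode}")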
Example #5
File: rl.py Project: shibei00/rltoolkit
def env_fn():
    return utils.create_environment(kwargs['environment'])
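Zero-argument factories like env_fn are typically handed to vectorised-environment wrappers. A sketch using OpenAI baselines' DummyVecEnv, assuming env_fn is defined with its enclosing kwargs in scope:

from baselines.common.vec_env.dummy_vec_env import DummyVecEnv

# Wrap the environment factory in a single-environment vectorised env.
venv = DummyVecEnv([env_fn])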
Example #6
def main(logdir, checkpoint, human_render, num_rollouts, max_episode_length,
         save_videos, save_rollouts, save_separate_rollouts):
    if not osp.exists(osp.join(logdir, 'run.json')):
        raise FileNotFoundError("Could not find run.json.")

    configuration = json.load(open(osp.join(logdir, 'run.json'), 'r'))
    if configuration["settings"]["method"] not in ["trpo", "ppo"]:
        raise NotImplementedError(
            "Playback for %s has not been implemented yet." %
            configuration["method"])

    env = utils.create_environment(configuration["settings"]["environment"])

    # build policy network
    # TODO this needs to be more general
    from baselines.ppo1 import mlp_policy
    tf.Session().__enter__()
    pi = mlp_policy.MlpPolicy(
        name="pi",
        ob_space=env.observation_space,
        ac_space=env.action_space,
        hid_size=configuration["settings"].get('pi_hid_size', 150),
        num_hid_layers=configuration["settings"].get('pi_num_hid_layers', 3))

    # find latest policy checkpoint
    saver = tf.train.Saver()
    if checkpoint is None:
        files = glob.glob(osp.join(logdir, 'checkpoints') + '/*.index')
        files = [(int(re.findall(r".*?_(\d+)\.", f)[0]), f) for f in files]
        files = sorted(files, key=operator.itemgetter(0))
        checkpoint = files[-1][1]
    elif not osp.isabs(checkpoint):
        if not osp.exists(osp.join(logdir, 'checkpoints')):
            raise FileNotFoundError("Could not find checkpoints folder")
        else:
            checkpoint = osp.join(logdir, 'checkpoints', checkpoint)
    if checkpoint.endswith(".index"):
        checkpoint = checkpoint[:-len(".index")]
    print("Loading checkpoint %s." % checkpoint)
    saver.restore(tf.get_default_session(), checkpoint)

    # generate rollouts
    rollouts = []
    for i_rollout in tqdm(range(num_rollouts), "Computing rollouts"):
        observation = env.reset()
        rollout = {"observation": [], "reward": [], "action": []}
        video = []
        for i_episode in range(max_episode_length):
            action, _ = pi.act(stochastic=False, ob=observation)
            observation, reward, done, _ = env.step(action)
            if human_render:
                env.render(mode='human')
            if save_videos is not None:
                video.append(env.render(mode='rgb_array'))
            if save_rollouts is not None:
                rollout["observation"].append(observation)
                rollout["reward"].append(reward)
                rollout["action"].append(action)
            if done:
                break

        if save_videos is not None:
            imageio.mimsave(osp.join(save_videos,
                                     'rollout_%i.mp4' % i_rollout),
                            video,
                            fps=env.metadata.get('video.frames_per_second',
                                                 50))
        if save_rollouts is not None and save_separate_rollouts:
            pkl.dump(
                rollout,
                open(osp.join(save_rollouts, 'rollout_%i.pkl' % i_rollout),
                     "wb"))
        else:
            rollouts.append(rollout)

    if save_rollouts is not None and not save_separate_rollouts:
        pkl.dump(rollouts, open(osp.join(save_rollouts, 'rollouts.pkl'), "wb"))
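Rollouts written by the script above can be read back with pickle; a minimal sketch, assuming the combined rollouts.pkl was saved and save_rollouts is the same output directory:

import os.path as osp
import pickle as pkl

# Reload the combined rollout list written by the playback script above.
with open(osp.join(save_rollouts, 'rollouts.pkl'), 'rb') as f:
    rollouts = pkl.load(f)
print(f"Loaded {len(rollouts)} rollouts; first has {len(rollouts[0]['observation'])} steps")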
import numpy as np
import math
from utils import create_environment, plot_metrics
from dqn_agent import DQNAgent

if __name__ == '__main__':

    #### Execution variables ####
    max_score = -math.inf  # Initial max_score depends on the game (Pong --> -inf | Breakout --> 0)
    train = False
    load_model = False
    create_capture = True
    num_games = 400  # Training-episodes
    ####

    env = create_environment('PongNoFrameskip-v4')
    agent_dqn = DQNAgent(gamma=0.99,
                         epsilon=1.0,
                         learning_rate=0.0001,
                         num_actions=env.action_space.n,
                         input_dim=(env.observation_space.shape),
                         memory_size=50000,
                         eps_min=0.1,
                         eps_dec=1e-5,
                         batch_size=32,
                         update=1000,
                         chkpt_dir='models/',
                         algorithm='DQN',
                         env_name='PongNoFrameskip-v4')

    if load_model: