Code Example #1
def main(args, base_dir):
    """Execute multiple training operations."""
    for i in range(args.n_training):
        # value of the next seed
        seed = args.seed + i

        # create the save directory (if it doesn't exist)
        dir_name = os.path.join(
            base_dir,
            '{}/{}'.format(args.env_name, strftime("%Y-%m-%d-%H:%M:%S")))
        ensure_dir(dir_name)

        # get the hyperparameters
        hp = get_hyperparameters(args, FeedForwardPolicy)

        # add the seed for logging purposes
        params_with_extra = hp.copy()
        params_with_extra['seed'] = seed
        params_with_extra['env_name'] = args.env_name
        params_with_extra['policy_name'] = "FeedForwardPolicy"

        # save the hyperparameters to a json file in the save directory
        with open(os.path.join(dir_name, 'hyperparameters.json'), 'w') as f:
            json.dump(params_with_extra, f, sort_keys=True, indent=4)

        run_exp(env=args.env_name,
                hp=hp,
                steps=args.total_steps,
                dir_name=dir_name,
                evaluate=args.evaluate,
                seed=seed,
                eval_interval=args.eval_interval,
                log_interval=args.log_interval,
                save_interval=args.save_interval)
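
The main function above assumes module-level imports of os, json, and strftime (from the time module), plus an ensure_dir helper from the project's utilities; the same dependencies apply to the later examples. A minimal sketch of such a helper, assuming it is a thin wrapper around os.makedirs (the actual h-baselines utility may differ):

import os


def ensure_dir(path):
    """Create the directory at path if it does not already exist."""
    if not os.path.exists(path):
        os.makedirs(path)
    return path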
Code Example #2
def main(args, base_dir):
    """Execute multiple training operations."""
    for i in range(args.n_training):
        # value of the next seed
        seed = args.seed + i

        # The time when the current experiment started.
        now = strftime("%Y-%m-%d-%H:%M:%S")

        # Create the save directory (if it doesn't exist).
        if args.log_dir is not None:
            dir_name = args.log_dir
        else:
            dir_name = os.path.join(base_dir, '{}/{}'.format(
                args.env_name, now))
        ensure_dir(dir_name)

        # Get the policy class.
        if args.alg == "TD3":
            from hbaselines.fcnet.td3 import FeedForwardPolicy
        elif args.alg == "SAC":
            from hbaselines.fcnet.sac import FeedForwardPolicy
        elif args.alg == "PPO":
            from hbaselines.fcnet.ppo import FeedForwardPolicy
        elif args.alg == "TRPO":
            from hbaselines.fcnet.trpo import FeedForwardPolicy
        else:
            raise ValueError("Unknown algorithm: {}".format(args.alg))

        # Get the hyperparameters.
        hp = get_hyperparameters(args, FeedForwardPolicy)

        # Add the seed for logging purposes.
        params_with_extra = hp.copy()
        params_with_extra['seed'] = seed
        params_with_extra['env_name'] = args.env_name
        params_with_extra['policy_name'] = "FeedForwardPolicy"
        params_with_extra['algorithm'] = args.alg
        params_with_extra['date/time'] = now

        # Save the hyperparameters to a json file in the save directory.
        with open(os.path.join(dir_name, 'hyperparameters.json'), 'w') as f:
            json.dump(params_with_extra, f, sort_keys=True, indent=4)

        run_exp(
            env=args.env_name,
            policy=FeedForwardPolicy,
            hp=hp,
            dir_name=dir_name,
            evaluate=args.evaluate,
            seed=seed,
            eval_interval=args.eval_interval,
            log_interval=args.log_interval,
            save_interval=args.save_interval,
            initial_exploration_steps=args.initial_exploration_steps,
            ckpt_path=args.ckpt_path,
        )
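
The loop above handles only the training runs themselves; it still has to be driven by command-line parsing. A hypothetical entry point, assuming the parse_options(description, example_usage, args) signature used in the later examples and a local training_data base directory:

import os
import sys


if __name__ == "__main__":
    args = parse_options(
        description="Run training over one or more seeds.",
        example_usage="python train.py AntMaze --alg TD3 --n_training 3",
        args=sys.argv[1:])
    main(args, base_dir=os.path.join(os.getcwd(), "training_data"))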
Code Example #3
File: test_utils.py  Project: le-horizon/h-baselines
    def test_parse_options(self):
        # Test the default case.
        args = parse_options("", "", args=["AntMaze"])
        expected_args = {
            'env_name': 'AntMaze',
            'alg': 'TD3',
            'evaluate': False,
            'n_training': 1,
            'total_steps': 1000000,
            'seed': 1,
            'log_interval': 2000,
            'eval_interval': 50000,
            'save_interval': 50000,
            'initial_exploration_steps': 10000,
            'nb_train_steps': 1,
            'nb_rollout_steps': 1,
            'nb_eval_episodes': 50,
            'reward_scale': 1,
            'render': False,
            'render_eval': False,
            'verbose': 2,
            'actor_update_freq': 2,
            'meta_update_freq': 10,
            'noise': TD3_PARAMS['noise'],
            'target_policy_noise': TD3_PARAMS['target_policy_noise'],
            'target_noise_clip': TD3_PARAMS['target_noise_clip'],
            'target_entropy': SAC_PARAMS['target_entropy'],
            'buffer_size': FEEDFORWARD_PARAMS['buffer_size'],
            'batch_size': FEEDFORWARD_PARAMS['batch_size'],
            'actor_lr': FEEDFORWARD_PARAMS['actor_lr'],
            'critic_lr': FEEDFORWARD_PARAMS['critic_lr'],
            'tau': FEEDFORWARD_PARAMS['tau'],
            'gamma': FEEDFORWARD_PARAMS['gamma'],
            'layer_norm': False,
            'use_huber': False,
            'num_levels': GOAL_CONDITIONED_PARAMS['num_levels'],
            'meta_period': GOAL_CONDITIONED_PARAMS['meta_period'],
            'intrinsic_reward_scale':
                GOAL_CONDITIONED_PARAMS['intrinsic_reward_scale'],
            'relative_goals': False,
            'off_policy_corrections': False,
            'hindsight': False,
            'subgoal_testing_rate':
                GOAL_CONDITIONED_PARAMS['subgoal_testing_rate'],
            'use_fingerprints': False,
            'centralized_value_functions': False,
            'connected_gradients': False,
            'cg_weights': GOAL_CONDITIONED_PARAMS['cg_weights'],
            'shared': False,
            'maddpg': False,
        }
        self.assertDictEqual(vars(args), expected_args)

        # Test custom cases.
        args = parse_options("",
                             "",
                             args=[
                                 "AntMaze",
                                 '--evaluate',
                                 '--n_training',
                                 '1',
                                 '--total_steps',
                                 '2',
                                 '--seed',
                                 '3',
                                 '--log_interval',
                                 '4',
                                 '--eval_interval',
                                 '5',
                                 '--save_interval',
                                 '6',
                                 '--nb_train_steps',
                                 '7',
                                 '--nb_rollout_steps',
                                 '8',
                                 '--nb_eval_episodes',
                                 '9',
                                 '--reward_scale',
                                 '10',
                                 '--render',
                                 '--render_eval',
                                 '--verbose',
                                 '11',
                                 '--actor_update_freq',
                                 '12',
                                 '--meta_update_freq',
                                 '13',
                                 '--buffer_size',
                                 '14',
                                 '--batch_size',
                                 '15',
                                 '--actor_lr',
                                 '16',
                                 '--critic_lr',
                                 '17',
                                 '--tau',
                                 '18',
                                 '--gamma',
                                 '19',
                                 '--noise',
                                 '20',
                                 '--target_policy_noise',
                                 '21',
                                 '--target_noise_clip',
                                 '22',
                                 '--layer_norm',
                                 '--use_huber',
                                 '--num_levels',
                                 '23',
                                 '--meta_period',
                                 '24',
                                 '--intrinsic_reward_scale',
                                 '25',
                                 '--relative_goals',
                                 '--off_policy_corrections',
                                 '--hindsight',
                                 '--subgoal_testing_rate',
                                 '26',
                                 '--use_fingerprints',
                                 '--centralized_value_functions',
                                 '--connected_gradients',
                                 '--cg_weights',
                                 '27',
                                 '--shared',
                                 '--maddpg',
                             ])
        hp = get_hyperparameters(args, GoalConditionedPolicy)
        expected_hp = {
            'nb_train_steps': 7,
            'nb_rollout_steps': 8,
            'nb_eval_episodes': 9,
            'reward_scale': 10.0,
            'render': True,
            'render_eval': True,
            'verbose': 11,
            'actor_update_freq': 12,
            'meta_update_freq': 13,
            '_init_setup_model': True,
            'policy_kwargs': {
                'buffer_size': 14,
                'batch_size': 15,
                'actor_lr': 16.0,
                'critic_lr': 17.0,
                'tau': 18.0,
                'gamma': 19.0,
                'noise': 20.0,
                'target_policy_noise': 21.0,
                'target_noise_clip': 22.0,
                'layer_norm': True,
                'use_huber': True,
                'num_levels': 23,
                'meta_period': 24,
                'intrinsic_reward_scale': 25.0,
                'relative_goals': True,
                'off_policy_corrections': True,
                'hindsight': True,
                'subgoal_testing_rate': 26.0,
                'use_fingerprints': True,
                'centralized_value_functions': True,
                'connected_gradients': True,
                'cg_weights': 27.0,
            }
        }
        self.assertDictEqual(hp, expected_hp)
        self.assertEqual(args.log_interval, 4)
        self.assertEqual(args.eval_interval, 5)

        hp = get_hyperparameters(args, MultiFeedForwardPolicy)
        expected_hp = {
            'nb_train_steps': 7,
            'nb_rollout_steps': 8,
            'nb_eval_episodes': 9,
            'actor_update_freq': 12,
            'meta_update_freq': 13,
            'reward_scale': 10.0,
            'render': True,
            'render_eval': True,
            'verbose': 11,
            '_init_setup_model': True,
            'policy_kwargs': {
                'buffer_size': 14,
                'batch_size': 15,
                'actor_lr': 16.0,
                'critic_lr': 17.0,
                'tau': 18.0,
                'gamma': 19.0,
                'layer_norm': True,
                'use_huber': True,
                'noise': 20.0,
                'target_policy_noise': 21.0,
                'target_noise_clip': 22.0,
                'shared': True,
                'maddpg': True,
            }
        }
        self.assertDictEqual(hp, expected_hp)
        self.assertEqual(args.log_interval, 4)
        self.assertEqual(args.eval_interval, 5)
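
The expected_hp dictionaries above show the split that get_hyperparameters performs: run-level settings stay at the top level while policy-specific options are nested under policy_kwargs. A simplified sketch of that behavior (the real implementation derives the supported names from the policy class; the fixed lists here are an assumption for illustration only):

def get_hyperparameters_sketch(args, policy):
    """Split parsed args into run-level settings and nested policy_kwargs."""
    # In the real utility, `policy` determines which extra keyword arguments
    # (e.g. TD3 noise terms or goal-conditioned options) are included.
    algorithm_params = [
        'nb_train_steps', 'nb_rollout_steps', 'nb_eval_episodes',
        'actor_update_freq', 'meta_update_freq', 'reward_scale',
        'render', 'render_eval', 'verbose',
    ]
    policy_params = [
        'buffer_size', 'batch_size', 'actor_lr', 'critic_lr',
        'tau', 'gamma', 'layer_norm', 'use_huber',
    ]

    hp = {name: getattr(args, name) for name in algorithm_params}
    hp['_init_setup_model'] = True
    hp['policy_kwargs'] = {name: getattr(args, name) for name in policy_params}
    return hp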
Code Example #4
File: train.py  Project: danieljeswin/RLProject
def train_h_baselines(env_name, args, multiagent):
    """Train policies using SAC and TD3 with h-baselines."""
    from hbaselines.algorithms import OffPolicyRLAlgorithm
    from hbaselines.utils.train import parse_options, get_hyperparameters

    # Get the command-line arguments that are relevant here
    args = parse_options(description="", example_usage="", args=args)

    # the base directory that the logged data will be stored in
    base_dir = "training_data"

    for i in range(args.n_training):
        # value of the next seed
        seed = args.seed + i

        # The time when the current experiment started.
        now = strftime("%Y-%m-%d-%H:%M:%S")

        # Create the save directory (if it doesn't exist).
        dir_name = os.path.join(base_dir, '{}/{}'.format(args.env_name, now))
        ensure_dir(dir_name)

        # Get the policy class.
        if args.alg == "TD3":
            if multiagent:
                from hbaselines.multi_fcnet.td3 import MultiFeedForwardPolicy
                policy = MultiFeedForwardPolicy
            else:
                from hbaselines.fcnet.td3 import FeedForwardPolicy
                policy = FeedForwardPolicy
        elif args.alg == "SAC":
            if multiagent:
                from hbaselines.multi_fcnet.sac import MultiFeedForwardPolicy
                policy = MultiFeedForwardPolicy
            else:
                from hbaselines.fcnet.sac import FeedForwardPolicy
                policy = FeedForwardPolicy
        else:
            raise ValueError("Unknown algorithm: {}".format(args.alg))

        # Get the hyperparameters.
        hp = get_hyperparameters(args, policy)

        # Add the seed for logging purposes.
        params_with_extra = hp.copy()
        params_with_extra['seed'] = seed
        params_with_extra['env_name'] = args.env_name
        params_with_extra['policy_name'] = policy.__name__
        params_with_extra['algorithm'] = args.alg
        params_with_extra['date/time'] = now

        # Save the hyperparameters to a json file in the save directory.
        with open(os.path.join(dir_name, 'hyperparameters.json'), 'w') as f:
            json.dump(params_with_extra, f, sort_keys=True, indent=4)

        # Create the algorithm object.
        alg = OffPolicyRLAlgorithm(
            policy=policy,
            env="flow:{}".format(env_name),
            eval_env="flow:{}".format(env_name) if args.evaluate else None,
            **hp)

        # Perform training.
        alg.learn(
            total_steps=args.total_steps,
            log_dir=dir_name,
            log_interval=args.log_interval,
            eval_interval=args.eval_interval,
            save_interval=args.save_interval,
            initial_exploration_steps=args.initial_exploration_steps,
            seed=seed,
        )
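
A hypothetical invocation of train_h_baselines; the environment name and argument list are placeholders, and passing multiagent=True would select the MultiFeedForwardPolicy variants instead:

if __name__ == "__main__":
    train_h_baselines(
        env_name="ring",
        args=["ring", "--alg", "TD3", "--total_steps", "1000000"],
        multiagent=False,
    )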