コード例 #1
0
ファイル: test_argload.py プロジェクト: ctallec/argload
    def nodir_test(self):
        """ Fail if logdir does not exist """
        parser = argparse.ArgumentParser()
        parser.add_argument('--a', type=int, nargs='*')
        parser.add_argument('--b', type=int)
        parser = argload.ArgumentLoader(parser, ['a'])
        with self.assertRaises(ValueError) as context:
            parser.parse_args(['--logdir', 'log', '--a', '3', '4', '--b', '2'])

        self.assertIn("Logdir does", str(context.exception))
コード例 #2
0
ファイル: test_argload.py プロジェクト: ctallec/argload
 def overwrite_default_test(self):
     """ You can overwrite defaults, and it behaves properly """
     try:
         parser = argparse.ArgumentParser()
         parser.add_argument('--a', type=float, default=1e-3)
         parser.add_argument('--b', type=float, default=2.)
         parser = argload.ArgumentLoader(parser, ['a', 'b'])
         mkdir('log')
         parser.parse_args(['--logdir', 'log', '--a', '1e-2', '--b', '3.'])
         args = parser.parse_args(
             ['--logdir', 'log', '--a', '1e-3', '--overwrite'])
         self.assertEqual(args.a, 1e-3)
         self.assertEqual(args.b, 3.)
     finally:
         shutil.rmtree('log')
コード例 #3
0
ファイル: test_argload.py プロジェクト: ctallec/argload
    def dump_no_overwrite_test(self):
        """ Fail if dump is used without overwrite """
        try:
            parser = argparse.ArgumentParser()
            parser.add_argument('--a', type=int, nargs='*')
            parser.add_argument('--b', type=int)
            parser = argload.ArgumentLoader(parser, ['a'])
            mkdir('log')
            parser.parse_args(['--logdir', 'log', '--a', '3', '4', '--b', '2'])
            with self.assertRaises(ValueError) as context:
                parser.parse_args(
                    ['--logdir', 'log', '--a', '2', '4', '--b', '2', '--dump'])

            self.assertIn("Dumping is not", str(context.exception))
        finally:
            shutil.rmtree('log')
コード例 #4
0
ファイル: test_argload.py プロジェクト: ctallec/argload
    def failed_overwrite_test(self):
        """ Trying to overwrite without overwrite flag """
        try:
            parser = argparse.ArgumentParser()
            parser.add_argument('--a', type=int, nargs='*')
            parser.add_argument('--b', type=int)
            parser = argload.ArgumentLoader(parser, ['a'])
            mkdir('log')
            parser.parse_args(['--logdir', 'log', '--a', '3', '4', '--b', '2'])

            with self.assertRaises(ValueError) as context:
                parser.parse_args(['--logdir', 'log', '--a', '4'])

            self.assertTrue("Overwritting" in str(context.exception))
        finally:
            shutil.rmtree('log')
コード例 #5
0
ファイル: test_argload.py プロジェクト: ctallec/argload
    def overwrite_no_old_test(self):
        """ Fail if overwrite or dump without first dump """
        try:
            parser = argparse.ArgumentParser()
            parser.add_argument('--a', type=int, nargs='*')
            parser.add_argument('--b', type=int)
            parser = argload.ArgumentLoader(parser, ['a'])
            mkdir('log')
            with self.assertRaises(ValueError) as context:
                parser.parse_args([
                    '--logdir', 'log', '--a', '2', '4', '--b', '2',
                    '--overwrite'
                ])

            self.assertIn("No old", str(context.exception))
        finally:
            shutil.rmtree('log')
コード例 #6
0
ファイル: test_argload.py プロジェクト: ctallec/argload
    def dump_test(self):
        """ Dumping modifies stored values """
        try:
            parser = argparse.ArgumentParser()
            parser.add_argument('--a', type=int, nargs='*')
            parser.add_argument('--b', type=int)
            parser = argload.ArgumentLoader(parser, ['a'])
            mkdir('log')
            parser.parse_args(['--logdir', 'log', '--a', '3', '4', '--b', '2'])

            parser.parse_args(
                ['--logdir', 'log', '--overwrite', '--dump', '--a', '3', '2'])
            args = parser.parse_args(['--logdir', 'log'])

            self.assertEqual([3, 2], args.a)
        finally:
            shutil.rmtree('log')
コード例 #7
0
ファイル: test_argload.py プロジェクト: ctallec/argload
    def normal_test(self):
        """ Test normal use case """
        try:
            parser = argparse.ArgumentParser()
            parser.add_argument('--a', type=int, nargs='*')
            parser.add_argument('--b', type=int)
            parser = argload.ArgumentLoader(parser, ['a'])
            mkdir('log')
            parser.parse_args(['--logdir', 'log', '--a', '3', '5', '--b', '2'])

            args = parser.parse_args(['--logdir', 'log'])

            self.assertTrue(args.a == [3, 5])
            self.assertTrue(args.b is None)
            self.assertTrue(args.logdir == 'log')

            args = parser.parse_args(
                ['--logdir', 'log', '--overwrite', '--a', '4'])

            self.assertTrue(args.a == [4])
            self.assertTrue(args.b is None)
        finally:
            shutil.rmtree('log')
コード例 #8
0
ファイル: parse.py プロジェクト: zh0123210/continuous-rl
def setup_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--eval_gap', type=float, default=.1,
                        help='evaluation is performed every .1/dt epochs.')
    parser.add_argument('--algo', type=str, default='approximate_advantage',
                        help='algorithm used.')
    parser.add_argument('--dt', type=float, default=.02,
                        help='temporal discretization.')
    parser.add_argument('--steps_btw_train', type=int, default=10,
                        help='number of environment steps between two training periods.')
    parser.add_argument('--env_id', type=str, default='pendulum',
                        help='environment.')
    parser.add_argument('--noise_type', type=str, default='coherent', choices=['coherent', 'independent'],
                        help='noise type used')
    parser.add_argument('--batch_size', type=int, default=64,
                        help='training batch size')
    parser.add_argument('--hidden_size', type=int, default=64,
                        help='number of hidden units per layer.')
    parser.add_argument('--nb_layers', type=int, default=1,
                        help='number of layers (careful, the "true number of layers" is this number + 1).')
    parser.add_argument('--gamma', type=float, default=.8,
                        help='discount factor.')
    parser.add_argument('--n_step', type=int, default=20,
                        help='Number of steps in a2c')
    parser.add_argument('--nb_true_epochs', type=float, default=50,
                        help='number of true epochs (epochs / dt) to train on.')
    parser.add_argument('--nb_steps', type=int, default=100,
                        help='number of environment steps in an epoch')
    parser.add_argument('--sigma', type=float, default=1.5,
                        help='OU noise parameter.')
    parser.add_argument('--theta', type=float, default=7.5,
                        help='OU stiffness parameter.')
    parser.add_argument('--c_entropy', type=float, default=1e-4,
                        help='entropy regularization')
    parser.add_argument('--eps_clamp', type=float, default=0.2,
                        help='Clipping value for PPO, epsilon in the original paper')
    parser.add_argument('--c_kl', type=float, default=0.,
                        help='KL regularization for PPO, beta in the original paper')
    parser.add_argument('--nb_train_env', type=int, default=32,
                        help='number of parallel environments during training.')
    parser.add_argument('--nb_eval_env', type=int, default=16,
                        help='number of parallel environments used to evaluate.')
    parser.add_argument('--memory_size', type=int, default=1000000,
                        help='size of the memory buffer.')
    parser.add_argument('--learn_per_step', type=int, default=50,
                        help='number of gradient step in one learning step')
    parser.add_argument('--normalize_state', action='store_true',
                        help='is state normalization used.')
    parser.add_argument('--lr', type=float, default=.03,
                        help='critic learning rate.')
    parser.add_argument('--policy_lr', type=float, default=None,
                        help='policy learning rate (for approximate policies).')
    parser.add_argument('--time_limit', type=float, default=None,
                        help='specify environment time limite (physical time).')
    parser.add_argument('--redirect_stdout', action='store_true',
                        help='should we redirect stdout to a log file?')
    parser.add_argument('--weight_decay', type=float, default=0,
                        help='actor weight decay.')
    parser.add_argument('--alpha', type=float, default=None,
                        help='prioritized replay buffer alpha (untested).')
    parser.add_argument('--beta', type=float, default=None,
                        help='prioritized replay buffer beta (untested).')
    parser.add_argument('--tau', type=float, default=.99,
                        help='target network update rate (works for all '
                        'algo, do not expect it to work with dau).')
    parser.add_argument('--epsilon', type=float, default=.1,
                        help='epsilon greedy coeficient')
    parser.add_argument('--noscale', action='store_true',
                        help='use unscaled ddpg when set')
    parser.add_argument('--optimizer', type=str, choices=['sgd', 'rmsprop', 'adam'],
                        default='sgd')
    parser.add_argument('--noreload', action='store_true',
                        help='do not reload previously saved model when set.')
    parser.add_argument('--snapshot', action='store_true',
                        help='if true, stores snapshot every once in a while')
    parser = argload.ArgumentLoader(parser, to_reload=[
        'algo', 'dt', 'steps_btw_train', 'env_id', 'noise_type',
        'batch_size', 'hidden_size', 'nb_layers', 'gamma', 'n_step', 'nb_true_epochs', 'nb_steps',
        'sigma', 'theta', 'c_entropy', 'eps_clamp', 'c_kl', 'nb_train_env', 'nb_eval_env', 'memory_size',
        'learn_per_step', 'normalize_state', 'lr', 'time_limit',
        'policy_lr', 'alpha', 'beta', 'weight_decay', 'optimizer',
        'tau', 'eval_gap', 'noscale', 'epsilon', 'snapshot'
    ])
    args = parser.parse_args()

    # args translation
    if args.algo == 'ddpg':
        args.algo = 'approximate_value'
    elif args.algo == 'dqn':
        args.algo = 'discrete_value'
    elif args.algo == 'ddau':
        args.algo = 'discrete_advantage'
    elif args.algo == 'cdau':
        args.algo = 'approximate_advantage'

    return args