def nodir_test(self):
    """
    Parsing must be rejected when the requested logdir is missing
    """
    base_parser = argparse.ArgumentParser()
    base_parser.add_argument('--a', type=int, nargs='*')
    base_parser.add_argument('--b', type=int)
    loader = argload.ArgumentLoader(base_parser, ['a'])

    # Note that this test never creates the 'log' directory.
    argv = ['--logdir', 'log', '--a', '3', '4', '--b', '2']
    with self.assertRaises(ValueError) as context:
        loader.parse_args(argv)
    self.assertIn("Logdir does", str(context.exception))
def overwrite_default_test(self):
    """
    You can overwrite defaults, and it behaves properly

    A first run stores non-default values for both ``a`` and ``b``. A
    second run passes ``--overwrite`` with ``a`` set back to its default
    value; the test expects ``a`` to take the new value while ``b`` keeps
    the value from the first run.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--a', type=float, default=1e-3)
    parser.add_argument('--b', type=float, default=2.)
    parser = argload.ArgumentLoader(parser, ['a', 'b'])

    # Create the directory *before* entering the try block: if mkdir fails
    # (e.g. 'log' already exists) the cleanup below must not run, otherwise
    # it would delete a directory this test does not own, or mask the real
    # error with a FileNotFoundError.
    mkdir('log')
    try:
        parser.parse_args(['--logdir', 'log', '--a', '1e-2', '--b', '3.'])
        args = parser.parse_args(
            ['--logdir', 'log', '--a', '1e-3', '--overwrite'])
        self.assertEqual(args.a, 1e-3)
        self.assertEqual(args.b, 3.)
    finally:
        shutil.rmtree('log')
def dump_no_overwrite_test(self):
    """
    Fail if dump is used without overwrite

    After a first successful run, a second run passing ``--dump`` without
    ``--overwrite`` is expected to raise a ValueError whose message
    contains "Dumping is not".
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--a', type=int, nargs='*')
    parser.add_argument('--b', type=int)
    parser = argload.ArgumentLoader(parser, ['a'])

    # mkdir stays outside the try block so that the cleanup only ever
    # removes a directory this test actually created.
    mkdir('log')
    try:
        parser.parse_args(['--logdir', 'log', '--a', '3', '4', '--b', '2'])
        with self.assertRaises(ValueError) as context:
            parser.parse_args(
                ['--logdir', 'log', '--a', '2', '4', '--b', '2', '--dump'])
        self.assertIn("Dumping is not", str(context.exception))
    finally:
        shutil.rmtree('log')
def failed_overwrite_test(self):
    """
    Trying to overwrite without overwrite flag

    After a first successful run, a second run giving a different value
    for the reloaded argument ``a`` without ``--overwrite`` is expected to
    raise a ValueError whose message contains "Overwritting".
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--a', type=int, nargs='*')
    parser.add_argument('--b', type=int)
    parser = argload.ArgumentLoader(parser, ['a'])

    # mkdir stays outside the try block so that the cleanup only ever
    # removes a directory this test actually created.
    mkdir('log')
    try:
        parser.parse_args(['--logdir', 'log', '--a', '3', '4', '--b', '2'])
        with self.assertRaises(ValueError) as context:
            parser.parse_args(['--logdir', 'log', '--a', '4'])
        # assertIn reports both operands on failure, unlike assertTrue.
        # The spelling "Overwritting" is what the test has always matched.
        self.assertIn("Overwritting", str(context.exception))
    finally:
        shutil.rmtree('log')
def overwrite_no_old_test(self):
    """
    Fail if overwrite or dump without first dump

    The logdir exists but no previous run has been made in it, so passing
    ``--overwrite`` is expected to raise a ValueError whose message
    contains "No old".
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--a', type=int, nargs='*')
    parser.add_argument('--b', type=int)
    parser = argload.ArgumentLoader(parser, ['a'])

    # mkdir stays outside the try block so that the cleanup only ever
    # removes a directory this test actually created.
    mkdir('log')
    try:
        with self.assertRaises(ValueError) as context:
            parser.parse_args([
                '--logdir', 'log', '--a', '2', '4', '--b', '2', '--overwrite'
            ])
        self.assertIn("No old", str(context.exception))
    finally:
        shutil.rmtree('log')
def dump_test(self):
    """
    Dumping modifies stored values

    A first run stores ``a = [3, 4]``. A second run passes ``--overwrite``
    and ``--dump`` with ``a = [3, 2]``. A third run giving only the logdir
    is then expected to yield ``a == [3, 2]``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--a', type=int, nargs='*')
    parser.add_argument('--b', type=int)
    parser = argload.ArgumentLoader(parser, ['a'])

    # mkdir stays outside the try block so that the cleanup only ever
    # removes a directory this test actually created.
    mkdir('log')
    try:
        parser.parse_args(['--logdir', 'log', '--a', '3', '4', '--b', '2'])
        parser.parse_args(
            ['--logdir', 'log', '--overwrite', '--dump', '--a', '3', '2'])
        args = parser.parse_args(['--logdir', 'log'])
        self.assertEqual([3, 2], args.a)
    finally:
        shutil.rmtree('log')
def normal_test(self):
    """
    Test normal use case

    A first run sets ``a`` and ``b``. A second run giving only the logdir
    is expected to get ``a`` back (it is in the reload list) while ``b``
    is None. A third run with ``--overwrite`` is expected to replace
    ``a`` with the newly given value.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--a', type=int, nargs='*')
    parser.add_argument('--b', type=int)
    parser = argload.ArgumentLoader(parser, ['a'])

    # mkdir stays outside the try block so that the cleanup only ever
    # removes a directory this test actually created.
    mkdir('log')
    try:
        parser.parse_args(['--logdir', 'log', '--a', '3', '5', '--b', '2'])

        args = parser.parse_args(['--logdir', 'log'])
        # Dedicated assertions print the offending values on failure.
        self.assertEqual(args.a, [3, 5])
        self.assertIsNone(args.b)
        self.assertEqual(args.logdir, 'log')

        args = parser.parse_args(
            ['--logdir', 'log', '--overwrite', '--a', '4'])
        self.assertEqual(args.a, [4])
        self.assertIsNone(args.b)
    finally:
        shutil.rmtree('log')
def setup_args():
    """
    Declare the experiment options, parse them and normalize ``algo``.

    The options are declared on a plain ``argparse`` parser, which is then
    wrapped in an ``argload.ArgumentLoader`` together with the list of
    option names passed as ``to_reload``. ``parse_args`` is called without
    explicit arguments. Finally the short algorithm aliases (``ddpg``,
    ``dqn``, ``ddau``, ``cdau``) are replaced by their long names.

    Returns the parsed arguments object.
    """
    base_parser = argparse.ArgumentParser()
    add = base_parser.add_argument

    add('--eval_gap', type=float, default=0.1,
        help='evaluation is performed every .1/dt epochs.')
    add('--algo', type=str, default='approximate_advantage',
        help='algorithm used.')
    add('--dt', type=float, default=0.02,
        help='temporal discretization.')
    add('--steps_btw_train', type=int, default=10,
        help='number of environment steps between two training periods.')
    add('--env_id', type=str, default='pendulum',
        help='environment.')
    add('--noise_type', type=str, default='coherent',
        choices=['coherent', 'independent'],
        help='noise type used')
    add('--batch_size', type=int, default=64,
        help='training batch size')
    add('--hidden_size', type=int, default=64,
        help='number of hidden units per layer.')
    add('--nb_layers', type=int, default=1,
        help='number of layers (careful, the "true number of layers" '
             'is this number + 1).')
    add('--gamma', type=float, default=0.8,
        help='discount factor.')
    add('--n_step', type=int, default=20,
        help='Number of steps in a2c')
    add('--nb_true_epochs', type=float, default=50,
        help='number of true epochs (epochs / dt) to train on.')
    add('--nb_steps', type=int, default=100,
        help='number of environment steps in an epoch')
    add('--sigma', type=float, default=1.5,
        help='OU noise parameter.')
    add('--theta', type=float, default=7.5,
        help='OU stiffness parameter.')
    add('--c_entropy', type=float, default=1e-4,
        help='entropy regularization')
    add('--eps_clamp', type=float, default=0.2,
        help='Clipping value for PPO, epsilon in the original paper')
    add('--c_kl', type=float, default=0.,
        help='KL regularization for PPO, beta in the original paper')
    add('--nb_train_env', type=int, default=32,
        help='number of parallel environments during training.')
    add('--nb_eval_env', type=int, default=16,
        help='number of parallel environments used to evaluate.')
    add('--memory_size', type=int, default=1000000,
        help='size of the memory buffer.')
    add('--learn_per_step', type=int, default=50,
        help='number of gradient step in one learning step')
    add('--normalize_state', action='store_true',
        help='is state normalization used.')
    add('--lr', type=float, default=0.03,
        help='critic learning rate.')
    add('--policy_lr', type=float, default=None,
        help='policy learning rate (for approximate policies).')
    add('--time_limit', type=float, default=None,
        help='specify environment time limite (physical time).')
    add('--redirect_stdout', action='store_true',
        help='should we redirect stdout to a log file?')
    add('--weight_decay', type=float, default=0,
        help='actor weight decay.')
    add('--alpha', type=float, default=None,
        help='prioritized replay buffer alpha (untested).')
    add('--beta', type=float, default=None,
        help='prioritized replay buffer beta (untested).')
    add('--tau', type=float, default=0.99,
        help='target network update rate (works for all '
             'algo, do not expect it to work with dau).')
    add('--epsilon', type=float, default=0.1,
        help='epsilon greedy coeficient')
    add('--noscale', action='store_true',
        help='use unscaled ddpg when set')
    add('--optimizer', type=str, choices=['sgd', 'rmsprop', 'adam'],
        default='sgd')
    add('--noreload', action='store_true',
        help='do not reload previously saved model when set.')
    add('--snapshot', action='store_true',
        help='if true, stores snapshot every once in a while')

    # Option names handed to the loader as ``to_reload``; 'redirect_stdout'
    # and 'noreload' are not part of this list.
    reloadable = [
        'algo', 'dt', 'steps_btw_train', 'env_id', 'noise_type',
        'batch_size', 'hidden_size', 'nb_layers', 'gamma', 'n_step',
        'nb_true_epochs', 'nb_steps', 'sigma', 'theta', 'c_entropy',
        'eps_clamp', 'c_kl', 'nb_train_env', 'nb_eval_env', 'memory_size',
        'learn_per_step', 'normalize_state', 'lr', 'time_limit',
        'policy_lr', 'alpha', 'beta', 'weight_decay', 'optimizer', 'tau',
        'eval_gap', 'noscale', 'epsilon', 'snapshot'
    ]
    loader = argload.ArgumentLoader(base_parser, to_reload=reloadable)
    args = loader.parse_args()

    # args translation: short aliases -> long algorithm names
    algo_aliases = {
        'ddpg': 'approximate_value',
        'dqn': 'discrete_value',
        'ddau': 'discrete_advantage',
        'cdau': 'approximate_advantage',
    }
    if args.algo in algo_aliases:
        args.algo = algo_aliases[args.algo]
    return args