Example #1
def main(argv):
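    # Build the environment; Pendulum-v0 gets a dedicated wrapper and a fixed 200-step episode limit.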
    env = gym.make(FLAGS.env)
    env = env.unwrapped
    if FLAGS.env == 'Pendulum-v0':
        env = WrappedPendulumEnv(env)
        FLAGS.step = 200

    message = OrderedDict({
        "Env": FLAGS.env,
        "Agent": FLAGS.agent,
        "Episode": FLAGS.n_episode,
        "Max_Step": FLAGS.step,
        "batch_size": FLAGS.batch_size,
        "Optimizer": FLAGS.opt,
        "learning_rate": FLAGS.lr,
        "Priority": FLAGS.priority,
        "multi_step": FLAGS.multi_step,
        "Categorical": FLAGS.category,
        "n_warmup": FLAGS.n_warmup,
        "model_update": FLAGS.model_update,
        "init_model": FLAGS.init_model
    })

    out_dim = set_output_dim(FLAGS, env.action_space.shape[0])
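    # set_output_dim adapts the head size to the flags (e.g. for a categorical/distributional output);
    # eval(FLAGS.agent) below resolves the agent class named on the command line.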

    agent = eval(FLAGS.agent)(model=set_model(outdim=out_dim),
                              n_actions=env.action_space.shape[0],
                              n_features=env.observation_space.shape[0],
                              learning_rate=FLAGS.lr,
                              batch_size=FLAGS.batch_size,
                              e_greedy=0.9,
                              replace_target_iter=1,
                              e_greedy_decrement=0.01,
                              optimizer=FLAGS.opt,
                              is_categorical=FLAGS.category,
                              max_action=env.action_space.high[0],
                              gpu=find_gpu())

    trainer = Trainer(agent=agent,
                      env=env,
                      n_episode=FLAGS.n_episode,
                      max_step=FLAGS.step,
                      replay_size=FLAGS.batch_size,
                      data_size=10**6,
                      n_warmup=FLAGS.n_warmup,
                      priority=FLAGS.priority,
                      multi_step=FLAGS.multi_step,
                      render=FLAGS.render,
                      test_episode=2,
                      test_interval=50,
                      test_frame=FLAGS.rec,
                      test_render=FLAGS.test_render,
                      metrics=message,
                      init_model_dir=FLAGS.init_model)

    trainer.train()
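
Every example on this page passes gpu=find_gpu() to the agent, but the helper itself is not shown. Below is a minimal sketch of what such a helper could look like, assuming nvidia-smi is on the PATH and that the convention is to return -1 when no GPU is usable; the real find_gpu() may differ.

import subprocess

def find_gpu():
    # Hypothetical sketch: pick the visible GPU with the most free memory,
    # or return -1 when nvidia-smi is unavailable or reports no devices.
    try:
        out = subprocess.check_output(
            ['nvidia-smi', '--query-gpu=memory.free',
             '--format=csv,noheader,nounits'],
            encoding='utf-8')
    except (OSError, subprocess.CalledProcessError):
        return -1
    free = [int(line) for line in out.strip().splitlines() if line]
    if not free:
        return -1
    return max(range(len(free)), key=free.__getitem__)
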
Example #2
def main(argv):
    env = make_atari(FLAGS.env)
    env = wrap_deepmind(env, frame_stack=True)
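    # Rainbow bundles dueling networks, multi-step returns, a categorical head and noisy layers,
    # so those flags are forced on below.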

    if FLAGS.agent == 'Rainbow':
        FLAGS.network = 'Dueling_Net'
        FLAGS.multi_step = 3
        FLAGS.category = True
        FLAGS.noise = True

    message = OrderedDict({
        "Env": env,
        "Agent": FLAGS.agent,
        "Network": FLAGS.network,
        "Episode": FLAGS.n_episode,
        "Max_Step": FLAGS.step,
        "Categorical": FLAGS.category,
        "init_model": FLAGS.model
    })

    out_dim = set_output_dim(FLAGS, env.action_space.n)

    agent = eval(FLAGS.agent)(model=set_model(outdim=out_dim),
                              n_actions=env.action_space.n,
                              n_features=env.observation_space.shape,
                              learning_rate=0,
                              e_greedy=0,
                              reward_decay=0,
                              replace_target_iter=0,
                              e_greedy_increment=0,
                              optimizer=None,
                              network=FLAGS.network,
                              trainable=False,
                              is_categorical=FLAGS.category,
                              is_noise=FLAGS.noise,
                              gpu=find_gpu())
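    # Evaluation only: learning rate and exploration are zeroed and trainable=False,
    # since this script ends with trainer.test().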

    if FLAGS.agent == 'PolicyGradient':
        trainer = PolicyTrainer(agent=agent,
                                env=env,
                                n_episode=FLAGS.n_episode,
                                max_step=FLAGS.step,
                                replay_size=0,
                                data_size=0,
                                n_warmup=0,
                                priority=None,
                                multi_step=0,
                                render=FLAGS.render,
                                test_episode=5,
                                test_interval=0,
                                test_frame=FLAGS.rec,
                                test_render=FLAGS.test_render,
                                metrics=message,
                                init_model_dir=FLAGS.model)

    elif FLAGS.agent in ('A3C', 'Ape_X'):
        trainer = DistributedTrainer(agent=agent,
                                     n_workers=0,
                                     env=env,
                                     n_episode=FLAGS.n_episode,
                                     max_step=FLAGS.step,
                                     replay_size=0,
                                     data_size=0,
                                     n_warmup=0,
                                     priority=None,
                                     multi_step=0,
                                     render=False,
                                     test_episode=5,
                                     test_interval=0,
                                     test_frame=FLAGS.rec,
                                     test_render=FLAGS.test_render,
                                     metrics=message,
                                     init_model_dir=FLAGS.model)

    else:
        trainer = Trainer(agent=agent,
                          env=env,
                          n_episode=FLAGS.n_episode,
                          max_step=FLAGS.step,
                          replay_size=0,
                          data_size=0,
                          n_warmup=0,
                          priority=None,
                          multi_step=0,
                          render=FLAGS.render,
                          test_episode=5,
                          test_interval=0,
                          test_frame=FLAGS.rec,
                          test_render=FLAGS.test_render,
                          metrics=message,
                          init_model_dir=FLAGS.model)

    trainer.test()

    return
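
These main(argv) entry points rely on a module-level FLAGS object. A minimal launcher sketch follows, assuming absl.flags is the flag library behind FLAGS (the project may equally use tf.app.flags); the flag names, defaults and help strings below are illustrative only.

from absl import app, flags

flags.DEFINE_string('env', 'BreakoutNoFrameskip-v4', 'Gym/Atari environment id')
flags.DEFINE_string('agent', 'DQN', 'Agent class name, e.g. DQN, Rainbow, Ape_X')
flags.DEFINE_integer('n_episode', 10000, 'Number of episodes to run')
flags.DEFINE_integer('step', 0, 'Maximum steps per episode (0 = no limit)')
FLAGS = flags.FLAGS

if __name__ == '__main__':
    app.run(main)  # parses the flags, then calls main(argv) as defined above
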
Example #3
parser.add_argument('--last_layer',
                    default=200,
                    type=int,
                    help='The size of the last layer, 0-11')
parser.add_argument('--dataset',
                    default='tiny',
                    type=str,
                    help='The dataset being used')
parser.add_argument('--pickle_path',
                    default='saved_pickle',
                    type=str,
                    help='Directory to save all of the policies with pickle.')

args = parser.parse_args()

best_gpu = find_gpu()
cuda_device = 'cuda:' + str(best_gpu)

device = cuda_device if torch.cuda.is_available() else 'cpu'
best_acc = 0  # best test accuracy
start_epoch = 0  # start from epoch 0 or last checkpoint epoch
ll = args.last_layer

networks_names = ('DenseNet121', 'DPN92', 'GoogLeNet', 'MobileNet',
                  'MobileNetV2', 'PreActResNet18', 'ResNet18',
                  'ResNeXt29_2x64d', 'SENet18', 'ShuffleNetG2', 'VGG',
                  'ShuffleNetV2')

root = args.data

print('==> Building model..')
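
The excerpt stops right after selecting the device string. A hedged continuation shows how that string is typically consumed in PyTorch; the stand-in net below is hypothetical, whereas the real script builds one of the classes listed in networks_names.

import torch
import torch.nn as nn

net = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 10))  # hypothetical stand-in model
net = net.to(device)                       # device is 'cuda:<best_gpu>' or 'cpu' from above
if device != 'cpu':
    torch.backends.cudnn.benchmark = True  # common speed-up for fixed-size image batches
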
Example #4
def main(argv):
    #env = make_atari(FLAGS.env)
    #env = wrap_deepmind(gym.make(FLAGS.env), frame_stack=True)
    env = wrap_dqn(gym.make(FLAGS.env))
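    # For Rainbow, the flags below additionally force prioritized replay and Adam
    # with a learning rate of 6.25e-5 (the value used in the Rainbow paper).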

    if FLAGS.agent == 'Rainbow':
        FLAGS.network = 'Dueling_Net'
        FLAGS.priority = True
        FLAGS.multi_step = 3
        FLAGS.category = True
        FLAGS.noise = True
        FLAGS.opt = 'Adam'
        FLAGS.lr = 0.0000625

    message = OrderedDict({
        "Env": FLAGS.env,
        "Agent": FLAGS.agent,
        "Network": FLAGS.network,
        "Episode": FLAGS.n_episode,
        "Max_Step": FLAGS.step,
        "batch_size": FLAGS.batch_size,
        "Optimizer": FLAGS.opt,
        "learning_rate": FLAGS.lr,
        "Priority": FLAGS.priority,
        "multi_step": FLAGS.multi_step,
        "Categorical": FLAGS.category,
        "Noisy": FLAGS.noise,
        "n_warmup": FLAGS.n_warmup,
        "model_update": FLAGS.model_update,
        "init_model": FLAGS.init_model
    })

    out_dim = set_output_dim(FLAGS, env.action_space.n)

    agent = eval(FLAGS.agent)(model=set_model(outdim=out_dim),
                              n_actions=env.action_space.n,
                              n_features=env.observation_space.shape,
                              learning_rate=FLAGS.lr,
                              e_greedy=0.0 if FLAGS.noise else 0.1,
                              reward_decay=0.99,
                              replace_target_iter=FLAGS.model_update,
                              optimizer=FLAGS.opt,
                              network=FLAGS.network,
                              is_categorical=FLAGS.category,
                              is_noise=FLAGS.noise,
                              gpu=find_gpu())
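    # Training run: same trainer dispatch as the evaluation script above, but with real
    # replay and warm-up sizes, ending in trainer.train().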

    if FLAGS.agent == 'PolicyGradient':
        trainer = PolicyTrainer(agent=agent,
                                env=env,
                                n_episode=FLAGS.n_episode,
                                max_step=FLAGS.step,
                                replay_size=FLAGS.batch_size,
                                data_size=256,
                                n_warmup=FLAGS.n_warmup,
                                priority=FLAGS.priority,
                                multi_step=0,
                                render=FLAGS.render,
                                test_episode=2,
                                test_interval=50,
                                test_frame=FLAGS.rec,
                                test_render=FLAGS.test_render,
                                metrics=message,
                                init_model_dir=FLAGS.init_model)

    elif FLAGS.agent in ('A3C', 'Ape_X'):
        trainer = DistributedTrainer(agent=agent,
                                     n_workers=FLAGS.n_workers,
                                     env=env,
                                     n_episode=FLAGS.n_episode,
                                     max_step=FLAGS.step,
                                     replay_size=FLAGS.batch_size,
                                     data_size=500,
                                     n_warmup=FLAGS.n_warmup,
                                     priority=FLAGS.priority,
                                     multi_step=0,
                                     render=FLAGS.render,
                                     test_episode=2,
                                     test_interval=50,
                                     test_frame=FLAGS.rec,
                                     test_render=FLAGS.test_render,
                                     metrics=message,
                                     init_model_dir=FLAGS.init_model)

    else:
        trainer = Trainer(agent=agent,
                          env=env,
                          n_episode=FLAGS.n_episode,
                          max_step=FLAGS.step,
                          replay_size=FLAGS.batch_size,
                          data_size=10**6,
                          n_warmup=FLAGS.n_warmup,
                          priority=FLAGS.priority,
                          multi_step=FLAGS.multi_step,
                          render=FLAGS.render,
                          test_episode=2,
                          test_interval=5000,
                          test_frame=FLAGS.rec,
                          test_render=FLAGS.test_render,
                          metrics=message,
                          init_model_dir=FLAGS.init_model)

    trainer.train()