def run_environment(h_size, middle_size, lstm_layers, learning_starts,
                    learning_freq, target_update_freq, lr, gamma, batch_size,
                    replay_buffer_size, epsilon_decay_steps, final_epsilon,
                    root_dir, num):
    """Train one DQN_Agent configuration and return a scalar score for it.

    The first twelve parameters are forwarded, in order, as positional
    arguments to ``DQN_Agent``. The training setup (reward function name,
    reward settings, episodes, targets, register inits) is read from the
    module-level ``training`` mapping.

    Args:
        root_dir: Parent directory under which this run's log directory
            is created.
        num: Run index; formatted with the ``03`` format spec to name the
            log directory (e.g. ``7`` -> ``"007"``).

    Returns:
        ``agent.global_performance()`` plus ``1 / (1 + best_episode)``,
        where ``best_episode`` is the second item returned by
        ``agent.best_performance()``.

    Raises:
        OSError: if the run's log directory already exists (``os.makedirs``
            is called without ``exist_ok``).
    """
    run_dir = os.path.join(root_dir, format(num, "03"))
    os.makedirs(run_dir)

    # Positional hyperparameters, kept in the order DQN_Agent receives them.
    hyperparams = (
        h_size,
        middle_size,
        lstm_layers,
        learning_starts,
        learning_freq,
        target_update_freq,
        lr,
        gamma,
        batch_size,
        replay_buffer_size,
        epsilon_decay_steps,
        final_epsilon,
    )
    agent = DQN_Agent(*hyperparams, verbose=True, log_dir=run_dir)

    # The reward function is looked up by name on the `reward` module.
    reward_fn = getattr(reward, training['reward_func'])
    agent.train(
        reward_fn,
        training['reward_settings'],
        training['episodes'],
        training['targets'],
        training['reg_inits'],
    )
    agent.save("best", best=True)

    score = agent.global_performance()
    # Only the episode index of the best result is used; the value is ignored.
    _, best_episode = agent.best_performance()
    # Extra term shrinks as best_episode grows (for non-negative indices).
    episode_bonus = 1 / (1 + best_episode)
    return score + episode_bonus
# Example #2
# 0
def _prepare_output_dir(path, reset):
    """Ensure *path* exists as a directory, emptying it first if *reset* is truthy."""
    if reset:
        # ignore_errors=True makes this a no-op when the directory is absent.
        shutil.rmtree(path, ignore_errors=True)
    # exist_ok=True tolerates a directory that already exists, including one
    # that rmtree above failed to remove, and avoids an exists()/makedirs() race.
    os.makedirs(path, exist_ok=True)


def main():
    """Parse the command-line flags, then train or test a DQN_Agent.

    When ``args.train == 1`` the output folder is created, ``sys.stdout`` is
    replaced by a ``Logger`` writing to ``folder_prefix + logfile``, the
    ``PolicyModel/``, ``TargetModel/`` and ``RewardsCSV/`` output directories
    are created (and emptied first when ``args.reset_dir`` is truthy), and
    the agent is trained. Otherwise the agent is tested with the given weight
    file. ``agent.agent_close()`` is always called, even when training or
    testing raises.
    """
    args = parse_arguments()
    agent = DQN_Agent(args, memory_size=args.memory_size, burn_in=args.burn_in)

    try:
        if args.train == 1:
            os.makedirs(args.folder_prefix, exist_ok=True)

            # NOTE: sys.stdout stays redirected for the rest of the process.
            sys.stdout = Logger(args.folder_prefix + args.logfile)
            print_user_flags(args)

            # Paths are built by plain concatenation on purpose: folder_prefix
            # is used as a string prefix, and other code may rebuild the same
            # paths the same way.
            for subdir in ('PolicyModel/', 'TargetModel/', 'RewardsCSV/'):
                _prepare_output_dir(args.folder_prefix + subdir,
                                    args.reset_dir)

            agent.train()
        else:
            agent.test(test_epi=args.test_epi,
                       model_file=args.weight_file,
                       lookahead=agent.greedy_policy)
    finally:
        # Release the agent's resources even if train()/test() raised.
        agent.agent_close()