Example 1
import json
import logging
import os
import sys

import settings        # project-local module providing BASE_DIR
import data_manager    # project-local module for loading chart/training data
import get_data        # project-local module for fetching raw price data


def executor(args):
    # Configure the Keras backend
    if args.backend == 'tensorflow':
        os.environ['KERAS_BACKEND'] = 'tensorflow'
    elif args.backend == 'plaidml':
        os.environ['KERAS_BACKEND'] = 'plaidml.keras.backend'

    # Set up the output path
    output_path = os.path.join(
        settings.BASE_DIR, 'output/{}_{}_{}'.format(args.output_name,
                                                    args.rl_method, args.net))
    if not os.path.isdir(output_path):
        os.makedirs(output_path)

    # Save the run parameters as JSON
    with open(os.path.join(output_path, 'params.json'), 'w') as f:
        f.write(json.dumps(vars(args)))

    # Configure logging
    file_handler = logging.FileHandler(filename=os.path.join(
        output_path, "{}.log".format(args.output_name)),
                                       encoding='utf-8')
    stream_handler = logging.StreamHandler(sys.stdout)
    file_handler.setLevel(logging.DEBUG)
    stream_handler.setLevel(logging.INFO)
    logging.basicConfig(format="%(message)s",
                        handlers=[file_handler, stream_handler],
                        level=logging.DEBUG)

    # Logging and the Keras backend must be configured first; the RLTrader modules are imported afterwards
    from agent import Agent
    from learners import DQNLearner

    # Prepare the model paths
    value_network_path = ''
    policy_network_path = ''
    if args.value_network_name is not None:
        value_network_path = os.path.join(
            settings.BASE_DIR, 'models/{}.h5'.format(args.value_network_name))
    else:
        value_network_path = os.path.join(
            output_path, '{}_{}_value_{}.h5'.format(args.rl_method, args.net,
                                                    args.output_name))
    if args.policy_network_name is not None:
        policy_network_path = os.path.join(
            settings.BASE_DIR, 'models/{}.h5'.format(args.policy_network_name))
    else:
        policy_network_path = os.path.join(
            output_path, '{}_{}_policy_{}.h5'.format(args.rl_method, args.net,
                                                     args.output_name))

    common_params = {}
    list_stock_code = []
    list_chart_data = []
    list_training_data = []
    list_min_trading_unit = []
    list_max_trading_unit = []

    stock_code = args.stock_code

    get_data.get_data(stock_code, args.start_date, args.end_date, ver=args.ver)

    # Prepare the chart data and training data
    chart_data, training_data = data_manager.load_data(os.path.join(
        settings.BASE_DIR, 'data/{}/{}.csv'.format(args.ver, stock_code)),
                                                       args.start_date,
                                                       args.end_date,
                                                       ver=args.ver)

    # Set the minimum/maximum trading units based on the latest closing price
    min_trading_unit = max(int(10000 / chart_data.iloc[-1]['close']), 1)
    max_trading_unit = max(int(100000 / chart_data.iloc[-1]['close']), 1)

    # Set the common learner parameters
    common_params = {
        'rl_method': args.rl_method,
        'delayed_reward_threshold': args.delayed_reward_threshold,
        'net': args.net,
        'num_steps': args.num_steps,
        'lr': args.lr,
        'output_path': output_path,
        'reuse_models': args.reuse_models
    }

    # Start reinforcement learning
    learner = None

    common_params.update({
        'stock_code': stock_code,
        'chart_data': chart_data,
        'training_data': training_data,
        'min_trading_unit': min_trading_unit,
        'max_trading_unit': max_trading_unit
    })

    learner = DQNLearner(
        **{
            **common_params, 'value_network_path': value_network_path
        })

    if learner is not None:
        pvs = learner.run(balance=args.balance,
                          num_epoches=args.num_epoches,
                          discount_factor=args.discount_factor,
                          start_epsilon=args.start_epsilon,
                          learning=args.learning)
        learner.save_models()

    return chart_data, pvs
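
For reference, here is a minimal, hypothetical driver for executor(), assuming the function lives in a runnable script. The attribute names mirror the fields read from args in the function above; every value shown is an illustrative placeholder, not a default from the original project.

import argparse

if __name__ == '__main__':
    # Placeholder arguments; in the original project these would normally
    # come from a command-line argument parser.
    args = argparse.Namespace(
        backend='tensorflow',        # or 'plaidml'
        output_name='test_run',
        rl_method='dqn',
        net='dnn',
        num_steps=1,
        lr=0.001,
        stock_code='005930',         # placeholder ticker symbol
        ver='v2',                    # data version passed to get_data/load_data
        start_date='20200101',
        end_date='20201231',
        value_network_name=None,
        policy_network_name=None,
        reuse_models=False,
        delayed_reward_threshold=0.05,
        balance=10000000,
        num_epoches=100,
        discount_factor=0.9,
        start_epsilon=1.0,
        learning=True,
    )
    chart_data, pvs = executor(args)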
Example 2
            'reuse_models': args.reuse_models
        }

        # Start reinforcement learning
        learner = None
        if args.rl_method != 'a3c':
            common_params.update({
                'stock_code': stock_code,
                'chart_data': chart_data,
                'training_data': training_data,
                'min_trading_unit': min_trading_unit,
                'max_trading_unit': max_trading_unit
            })
            if args.rl_method == 'dqn':
                learner = DQNLearner(**{
                    **common_params, 'value_network_path':
                    value_network_path
                })
            elif args.rl_method == 'pg':
                learner = PolicyGradientLearner(**{
                    **common_params, 'policy_network_path':
                    policy_network_path
                })
            elif args.rl_method == 'ac':
                learner = ActorCriticLearner(
                    **{
                        **common_params, 'value_network_path':
                        value_network_path,
                        'policy_network_path': policy_network_path
                    })
            elif args.rl_method == 'a2c':
                learner = A2CLearner(