Example #1
    def setUpClass(cls):
        """Connect to the test database and seed it with pseudo price data.

        Creates the schema, wipes any leftover rows from previous runs,
        and inserts the pseudo prices under ``dir_resources``. If any
        step fails, the whole test class is skipped.
        """
        try:
            cls.config = Config('config.toml')
            engine = create_engine(cls.config.db_uri_test)
            SessionMaker = sessionmaker(bind=engine)
            cls.session = SessionMaker()

            create_tables(engine)

            # Start from a clean slate: remove rows left by earlier runs.
            cls.session.query(Close).delete()
            cls.session.query(Price).delete()
            cls.session.query(PriceSeq).delete()
            cls.session.commit()

            dir_resources = Path(cls.config.dir_resources)
            dir_prices = dir_resources / Path('pseudo-data') / Path('prices')
            missing_rics = ['.TEST']
            logger = create_logger(Path('test.log'),
                                   is_debug=False,
                                   is_temporary=True)

            # insert database
            insert_prices(cls.session, dir_prices, missing_rics, dir_resources,
                          logger)

        except Exception as e:
            # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit
            # still propagate; chain the original cause for easier debugging.
            raise unittest.SkipTest('Cannot establish connection') from e
Example #2
def main() -> None:
    """Generate a market comment for the given time/RIC and print it."""
    args = parse_args()

    config = Config(args.dest_config)
    device = torch.device(args.device)
    output_path = Path(args.output)
    predictor = Predictor(config, device, output_path)

    tokens = predictor.predict(args.time, args.ric)

    # Render every token wrapped in double quotes, comma-separated.
    print('"{}"'.format('", "'.join(tokens)))
Example #3
def main() -> None:
    """Train, validate and test the encoder-decoder reporter model.

    Pipeline: parse args -> build config/logger -> export DB alignments
    to JSON (only if any phase file is missing) -> build datasets ->
    train with early stopping on validation BLEU -> reload the best
    checkpoint and evaluate it on the test split.
    """

    args = parse_args()

    if not args.is_debug:
        warnings.simplefilter(action='ignore', category=FutureWarning)

    config = Config(args.dest_config)

    device = torch.device(args.device)

    # Results go to a timestamped directory unless an explicit subdir is given.
    now = datetime.today().strftime('reporter-%Y-%m-%d-%H-%M-%S')
    dest_dir = config.dir_output / Path(now) \
        if args.output_subdir is None \
        else config.dir_output / Path(args.output_subdir)

    dest_log = dest_dir / Path('reporter.log')

    # NOTE(review): assumes create_logger creates dest_dir when missing —
    # otherwise the later dest_vocab/dest_model writes would fail; confirm.
    logger = create_logger(dest_log, is_debug=args.is_debug)
    config.write_log(logger)

    message = 'start main (is_debug: {}, device: {})'.format(args.is_debug, args.device)
    logger.info(message)

    # === Alignment ===
    # True only if the alignment JSON exists for every phase (train/valid/test).
    has_all_alignments = \
        reduce(lambda x, y: x and y,
               [(config.dir_output / Path('alignment-{}.json'.format(phase.value))).exists()
                for phase in list(Phase)])

    if not has_all_alignments:

        # Rebuild the alignment files from the database.
        engine = create_engine(config.db_uri)
        SessionMaker = sessionmaker(bind=engine)
        pg_session = SessionMaker()
        create_tables(engine)

        prepare_resources(config, pg_session, logger)
        for phase in list(Phase):
            config.dir_output.mkdir(parents=True, exist_ok=True)
            dest_alignments = config.dir_output / Path('alignment-{}.json'.format(phase.value))
            alignments = load_alignments_from_db(pg_session, phase, logger)
            # One JSON object per alignment, written as JSON Lines.
            with dest_alignments.open(mode='w') as f:
                writer = jsonlines.Writer(f)
                writer.write_all(alignments)
        pg_session.close()

    # === Dataset ===
    (vocab, train, valid, test) = create_dataset(config, device)

    vocab_size = len(vocab)
    # Persist the vocabulary next to the model so prediction can reload it.
    dest_vocab = dest_dir / Path('reporter.vocab')
    with dest_vocab.open(mode='wb') as f:
        torch.save(vocab, f)
    seqtypes = []
    attn = setup_attention(config, seqtypes)
    encoder = Encoder(config, device)
    decoder = Decoder(config, vocab_size, attn, device)
    model = EncoderDecoder(encoder, decoder, device)
    optimizer = torch.optim.Adam(model.parameters(), lr=config.learning_rate)
    # Padding positions are excluded from the loss.
    # NOTE(review): 'elementwise_mean' was renamed to 'mean' in PyTorch >= 1.0;
    # this string only works on older torch versions — confirm target version.
    criterion = torch.nn.NLLLoss(reduction='elementwise_mean',
                                 ignore_index=vocab.stoi[SpecialToken.Padding.value])

    # === Train ===
    dest_model = dest_dir / Path('reporter.model')
    prev_valid_bleu = 0.0
    max_bleu = 0.0
    best_epoch = 0
    early_stop_counter = 0
    for epoch in range(config.n_epochs):
        logger.info('start epoch {}'.format(epoch))
        train_result = run(train,
                           vocab,
                           model,
                           optimizer,
                           criterion,
                           Phase.Train,
                           logger)
        train_bleu = calc_bleu(train_result.gold_sents, train_result.pred_sents)
        valid_result = run(valid,
                           vocab,
                           model,
                           optimizer,
                           criterion,
                           Phase.Valid,
                           logger)
        valid_bleu = calc_bleu(valid_result.gold_sents, valid_result.pred_sents)

        s = ' | '.join(['epoch: {0:4d}'.format(epoch),
                        'training loss: {:.2f}'.format(train_result.loss),
                        'training BLEU: {:.4f}'.format(train_bleu),
                        'validation loss: {:.2f}'.format(valid_result.loss),
                        'validation BLEU: {:.4f}'.format(valid_bleu)])
        logger.info(s)

        # Checkpoint whenever validation BLEU reaches a new maximum.
        if max_bleu < valid_bleu:
            torch.save(model.state_dict(), str(dest_model))
            max_bleu = valid_bleu
            best_epoch = epoch

        # Count consecutive epochs where validation BLEU dropped; any
        # non-drop resets the counter. Stop after `patience` drops in a row.
        early_stop_counter = early_stop_counter + 1 \
            if prev_valid_bleu > valid_bleu \
            else 0
        if early_stop_counter == config.patience:
            logger.info('EARLY STOPPING')
            break
        prev_valid_bleu = valid_bleu

    # === Test ===
    # Reload the best checkpoint before evaluating on the test split.
    with dest_model.open(mode='rb') as f:
        model.load_state_dict(torch.load(f))
    test_result = run(test,
                      vocab,
                      model,
                      optimizer,
                      criterion,
                      Phase.Test,
                      logger)
    test_bleu = calc_bleu(test_result.gold_sents, test_result.pred_sents)

    s = ' | '.join(['epoch: {:04d}'.format(best_epoch),
                    'Test Loss: {:.2f}'.format(test_result.loss),
                    'Test BLEU: {:.10f}'.format(test_bleu)])
    logger.info(s)

    export_results_to_csv(dest_dir, test_result)
Example #4
def config():
    """Load and return the project configuration from ``config.toml``."""
    cfg = Config('config.toml')
    return cfg
Example #5
from sqlalchemy import func

from reporter.database.misc import in_jst, in_utc
from reporter.database.model import GenerationResult, Headline, HumanEvaluation
from reporter.database.read import fetch_date_range, fetch_max_t_of_prev_trading_day, fetch_rics
from reporter.predict import Predictor
from reporter.util.config import Config
from reporter.util.constant import JST, NIKKEI_DATETIME_FORMAT, UTC, Code
from reporter.webapp.chart import (fetch_all_closes_fast,
                                   fetch_all_points_fast, fetch_close,
                                   fetch_points)
from reporter.webapp.human_evaluation import populate_for_human_evaluation
from reporter.webapp.search import construct_constraint_query
from reporter.webapp.table import Table, create_ric_tables, load_ric_to_ric_info

# Module-level application setup: configuration, Flask app, stylesheet
# compilation, database binding, and demo data — all run at import time.
config = Config('config.toml')
app = flask.Flask(__name__)
app.config['TESTING'] = True
app.config['SQLALCHEMY_DATABASE_URI'] = config.db_uri
app.config['SQLALCHEMY_TRACK_MODIFICATIONS'] = False
# Enable rendering of .pug templates through the PyPugJS Jinja extension.
app.jinja_env.add_extension('pypugjs.ext.jinja.PyPugJSExtension')

# Compile all SCSS sources into expanded CSS once, at startup.
dir_scss = Path('reporter/webapp/static/scss').resolve()
dir_css = Path('reporter/webapp/static/css').resolve()
sass.compile(dirname=(str(dir_scss), str(dir_css)), output_style='expanded')

db = SQLAlchemy(app)

ric_to_ric_info = load_ric_to_ric_info()
# Seed the human-evaluation tables from previously generated results.
populate_for_human_evaluation(db.session, config.result)
demo_initial_date = config.demo_initial_date