示例#1
0
def test_fastqa():
    """Smoke-test FastQA: build a reader on a SQuAD snippet and run Adagrad steps.

    Trains for 10 passes over the snippet with batch size 1 and prints the
    loss of every step. The test makes no assertions; it only checks that
    graph construction and training run without raising.
    """
    tf.reset_default_graph()

    dataset = load_jack('tests/test_data/squad/snippet_jtr.json')

    # FastQA must be initialised from existing embeddings, so load a small GloVe subset.
    glove = load_embeddings('./tests/test_data/glove.840B.300d_top256.txt', 'glove')

    # Seed the vocabulary from the embeddings (required by the FastQA reader here).
    vocab = Vocab(emb=glove, init_from_embeddings=True)

    # Reader configuration; the input dimension follows the embedding width.
    settings = dict(
        batch_size=1,
        repr_dim=10,
        repr_dim_input=glove.lookup.shape[1],
        with_char_embeddings=True,
    )

    # Assemble the reader from its input, model and output modules.
    resources = SharedResources(vocab, settings)
    input_module = XQAInputModule(resources)
    model = FastQAModule(resources)
    fastqa = TFReader(resources, input_module, model, XQAOutputModule())
    fastqa.setup_from_data(dataset, is_training=True)

    loss = fastqa.model_module.tensors[Ports.loss]
    train_op = tf.train.AdagradOptimizer(learning_rate=0.01).minimize(loss)

    sess = model.tf_session
    sess.run(tf.global_variables_initializer())

    for _epoch in range(10):
        for batch in fastqa.input_module.batch_generator(dataset, 1, False):
            feed = fastqa.model_module.convert_to_feed_dict(batch)
            step_loss, _ = sess.run((loss, train_op), feed_dict=feed)
            print(step_loss)
示例#2
0
文件: jack-train.py 项目: jg8610/jack
def main(config,
         loader,
         debug,
         debug_examples,
         embedding_file,
         embedding_format,
         experiments_db,
         reader,
         train,
         num_train_examples,
         dev,
         num_dev_examples,
         test,
         vocab_from_embeddings):
    """Load data and embeddings, build the reader named `reader`, and train it.

    `config` is accepted but not read here. Parameter values are presumably
    supplied by the experiment framework behind `ex` (TODO confirm).

    A scratch directory is exposed via the ``JACK_TEMP`` environment variable.
    If the variable is unset, a fresh ``<tempdir>/jack/<uuid4>`` path is used.
    The directory is removed when this function exits, but only if this
    function created it. A directory that already existed, for example a
    user-provided cache, is left untouched.

    Raises:
        RuntimeError: outside debug mode, if `vocab_from_embeddings` is set
            but no embeddings were provided.
    """
    logger.info("TRAINING")

    if 'JACK_TEMP' not in os.environ:
        jack_temp = os.path.join(tempfile.gettempdir(), 'jack', str(uuid.uuid4()))
        os.environ['JACK_TEMP'] = jack_temp
        logger.info("JACK_TEMP not set, setting it to %s. Might be used for caching." % jack_temp)
    else:
        jack_temp = os.environ['JACK_TEMP']
    # Only a directory we create ourselves is safe to delete afterwards.
    created_jack_temp = not os.path.exists(jack_temp)
    if created_jack_temp:
        os.makedirs(jack_temp)

    # Everything after directory creation runs inside try/finally, so a
    # failure while loading data or validating options does not leak the dir.
    try:
        if experiments_db is not None:
            ex.observers.append(SqlObserver.create('sqlite:///%s' % experiments_db))

        if debug:
            train_data = loaders[loader](train, debug_examples)

            logger.info('loaded {} samples as debug train/dev/test dataset '.format(debug_examples))

            # In debug mode the same small sample serves as train, dev and test.
            dev_data = train_data
            test_data = train_data

            if embedding_file is not None and embedding_format is not None:
                # NOTE: in debug mode the given embedding_file is only used as a
                # switch; a fixed 50-d GloVe file is loaded instead.
                emb_file = 'glove.6B.50d.txt'
                embeddings = load_embeddings(path.join('data', 'GloVe', emb_file), 'glove')
                logger.info('loaded pre-trained embeddings ({})'.format(emb_file))
                ex.current_run.config["repr_dim_input"] = 50
            else:
                embeddings = Embeddings(None, None)
        else:
            train_data = loaders[loader](train, num_train_examples)
            dev_data = loaders[loader](dev, num_dev_examples)
            test_data = loaders[loader](test) if test else None

            logger.info('loaded train/dev/test data')
            if embedding_file is not None and embedding_format is not None:
                embeddings = load_embeddings(embedding_file, embedding_format)
                logger.info('loaded pre-trained embeddings ({})'.format(embedding_file))
                # Input representation size = width of the first embedding row.
                ex.current_run.config["repr_dim_input"] = embeddings.lookup[0].shape[0]
            else:
                embeddings = None
                # Guard the same flag that is passed to Vocab below.
                if vocab_from_embeddings:
                    raise RuntimeError("If you want to create vocab from embeddings, embeddings have to be provided")

        vocab = Vocab(emb=embeddings, init_from_embeddings=vocab_from_embeddings)

        # build JTReader
        checkpoint()
        parsed_config = ex.current_run.config
        ex.run('print_config', config_updates=parsed_config)

        # name defaults to name of the model
        if 'name' not in parsed_config or parsed_config['name'] is None:
            parsed_config['name'] = reader

        shared_resources = SharedResources(vocab, parsed_config)
        jtreader = readers.readers[reader](shared_resources)

        checkpoint()

        jtrain(jtreader, train_data, test_data, dev_data, parsed_config, debug=debug)
    finally:
        # Clean up the temporary dir, but never delete one we did not create.
        if created_jack_temp and os.path.exists(jack_temp):
            shutil.rmtree(jack_temp)
示例#3
0
def main(batch_size, clip_value, config, loader, debug, debug_examples, dev,
         embedding_file, embedding_format, experiments_db, epochs, l2,
         optimizer, learning_rate, learning_rate_decay, log_interval,
         validation_interval, model, model_dir, seed, tensorboard_folder, test,
         train, vocab_from_embeddings, write_metrics_to):
    """Load datasets and embeddings, build the reader named `model`, and train it.

    `config` is accepted but not read here. The training hyper-parameters
    (seed, clip_value, batch_size, ...) are bundled into a dict and passed to
    `jtrain` together with the train/test/dev data.

    Raises:
        RuntimeError: outside debug mode, if the run config requests a vocab
            built from embeddings but no embeddings were provided.
    """
    logger.info("TRAINING")

    if experiments_db is not None:
        db_url = 'sqlite:///%s' % experiments_db
        ex.observers.append(SqlObserver.create(db_url))

    load = loaders[loader]

    if debug:
        train_data = load(train, debug_examples)
        logger.info(
            'loaded {} samples as debug train/dev/test dataset '.format(
                debug_examples))
        # The debug sample doubles as dev and test data.
        dev_data = test_data = train_data

        if embedding_file is None or embedding_format is None:
            embeddings = Embeddings(None, None)
        else:
            # NOTE: debug mode ignores the given file and loads fixed 50-d GloVe.
            emb_file = 'glove.6B.50d.txt'
            glove_path = path.join('data', 'GloVe', emb_file)
            embeddings = load_embeddings(glove_path, 'glove')
            logger.info('loaded pre-trained embeddings ({})'.format(emb_file))
            ex.current_run.config["repr_dim_input"] = 50
    else:
        train_data = load(train)
        dev_data = load(dev)
        test_data = load(test) if test else None
        logger.info('loaded train/dev/test data')

        if embedding_file is None or embedding_format is None:
            embeddings = None
            if ex.current_run.config["vocab_from_embeddings"]:
                raise RuntimeError(
                    "If you want to create vocab from embeddings, embeddings have to be provided"
                )
        else:
            embeddings = load_embeddings(embedding_file, embedding_format)
            logger.info(
                'loaded pre-trained embeddings ({})'.format(embedding_file))
            first_row = embeddings.lookup[0]
            ex.current_run.config["repr_dim_input"] = first_row.shape[0]

    vocab = Vocab(emb=embeddings, init_from_embeddings=vocab_from_embeddings)

    # Build the JTReader.
    checkpoint()
    parsed_config = ex.current_run.config
    ex.run('print_config', config_updates=parsed_config)

    # Default the run name to the model name.
    if parsed_config.get('name') is None:
        parsed_config['name'] = model

    resources = SharedResources(vocab, parsed_config)
    jt_reader = readers.readers[model](resources)

    checkpoint()

    training_settings = dict(
        seed=seed,
        clip_value=clip_value,
        batch_size=batch_size,
        epochs=epochs,
        l2=l2,
        optimizer=optimizer,
        learning_rate=learning_rate,
        learning_rate_decay=learning_rate_decay,
        log_interval=log_interval,
        validation_interval=validation_interval,
        tensorboard_folder=tensorboard_folder,
        model=model,
        model_dir=model_dir,
        write_metrics_to=write_metrics_to,
    )

    jtrain(jt_reader, train_data, test_data, dev_data, training_settings, debug=debug)
示例#4
0
def run(loader, debug, debug_examples, embedding_file, embedding_format,
        repr_dim_task_embedding, reader, train, num_train_examples, dev,
        num_dev_examples, test, vocab_from_embeddings, **kwargs):
    """Load data and embeddings, build the reader named `reader`, and train it.

    Extra keyword arguments (`kwargs`) are accepted but not read here.

    A scratch directory is exposed via the ``JACK_TEMP`` environment variable.
    If the variable is unset, a fresh ``<tempdir>/jack/<uuid4>`` path is used.
    The directory is removed when this function exits, including on error,
    but only if this function created it. A pre-existing directory, for
    example a user-provided cache, is never deleted.

    Raises:
        ValueError: outside debug mode, if `vocab_from_embeddings` is set but no
            embeddings were provided; or if no embeddings are available and
            `repr_dim_task_embedding` < 1.
    """
    logger.info("TRAINING")

    # build JTReader
    parsed_config = ex.current_run.config
    ex.run('print_config', config_updates=parsed_config)

    if 'JACK_TEMP' not in os.environ:
        jack_temp = os.path.join(tempfile.gettempdir(), 'jack',
                                 str(uuid.uuid4()))
        os.environ['JACK_TEMP'] = jack_temp
        logger.info(
            "JACK_TEMP not set, setting it to %s. Might be used for caching." %
            jack_temp)
    else:
        jack_temp = os.environ['JACK_TEMP']
    # Only a directory we create ourselves is safe to delete afterwards.
    created_jack_temp = not os.path.exists(jack_temp)
    if created_jack_temp:
        os.makedirs(jack_temp)

    # Everything after directory creation runs inside try/finally, so the
    # validation errors below do not leak the freshly created directory.
    try:
        if debug:
            train_data = loaders[loader](train, debug_examples)

            logger.info(
                'loaded {} samples as debug train/dev/test dataset '.format(
                    debug_examples))

            # In debug mode the same small sample serves as train, dev and test.
            dev_data = train_data
            test_data = train_data

            if embedding_file is not None and embedding_format is not None:
                # NOTE: in debug mode the given embedding_file is only used as a
                # switch; a fixed 50-d GloVe file is loaded instead.
                emb_file = 'glove.6B.50d.txt'
                embeddings = load_embeddings(path.join('data', 'GloVe', emb_file),
                                             'glove')
                logger.info('loaded pre-trained embeddings ({})'.format(emb_file))
            else:
                embeddings = None
        else:
            train_data = loaders[loader](train, num_train_examples)
            dev_data = loaders[loader](dev, num_dev_examples)
            test_data = loaders[loader](test) if test else None

            logger.info('loaded train/dev/test data')
            if embedding_file is not None and embedding_format is not None:
                embeddings = load_embeddings(embedding_file, embedding_format)
                logger.info(
                    'loaded pre-trained embeddings ({})'.format(embedding_file))
            else:
                embeddings = None
                if vocab_from_embeddings:
                    raise ValueError(
                        "If you want to create vocab from embeddings, embeddings have to be provided"
                    )

        # Seed the vocab from the embedding vocabulary only when requested and available.
        vocab = Vocab(vocab=embeddings.vocabulary if vocab_from_embeddings
                      and embeddings is not None else None)

        if repr_dim_task_embedding < 1 and embeddings is None:
            raise ValueError(
                "Either provide pre-trained embeddings or set repr_dim_task_embedding > 0."
            )

        # name defaults to name of the model
        if 'name' not in parsed_config or parsed_config['name'] is None:
            parsed_config['name'] = reader

        shared_resources = SharedResources(vocab, parsed_config, embeddings)
        jtreader = readers.readers[reader](shared_resources)

        jtrain(jtreader,
               train_data,
               test_data,
               dev_data,
               parsed_config,
               debug=debug)
    finally:
        # Clean up the temporary dir, but never delete one we did not create.
        if created_jack_temp and os.path.exists(jack_temp):
            shutil.rmtree(jack_temp)