示例#1
0
    BUCKETS = 4

    # get_model = get_permute_model
    # get_model = get_basic_modele
    get_model = get_selective_model
    # get_model = get_bowv_model

    CONCAT_SENTENCES = True
    if get_model == get_selective_model or get_model == get_bowv_model:
        CONCAT_SENTENCES = False

    from jtr.preprocess.vocab import NeuralVocab
    from jtr.load.embeddings.embeddings import load_embeddings

    if USE_PRETRAINED_EMBEDDINGS:
        embeddings = load_embeddings('./data/TBD/GloVe/glove.6B.100d.txt',
                                     'glove')
        #embeddings = load_embeddings('./jtr/data/word2vec/GoogleNews-vectors-negative300.bin.gz', 'glove')
        emb = embeddings.get
    else:
        emb = None
    vocab = Vocab(emb=emb)

    print('loading corpus..')
    if DEBUG:
        train_corpus = load_corpus("debug", USE_PERMUTATION_INDEX)
        dev_corpus = load_corpus("debug", USE_PERMUTATION_INDEX)
        test_corpus = load_corpus("debug", USE_PERMUTATION_INDEX)
    else:
        train_corpus = load_corpus("train", USE_PERMUTATION_INDEX)
        dev_corpus = load_corpus("dev", USE_PERMUTATION_INDEX)
        test_corpus = load_corpus("test", USE_PERMUTATION_INDEX)
示例#2
0
    if DEBUG:
        train_data = load("./jtr/data/SNLI/snli_1.0/snli_1.0_train.jsonl",
                          DEBUG_EXAMPLES)
        dev_data = train_data
        test_data = train_data
    else:
        train_data, dev_data, test_data = [load("./jtr/data/SNLI/snli_1.0/snli_1.0_%s.jsonl" % name)\
                                           for name in ["train", "dev", "test"]]

    print(train_data)

    print('loaded train/dev/test data')
    if pretrain:
        if DEBUG:
            emb_file = 'glove.6B.50d.txt'
            embeddings = load_embeddings(
                path.join('jtr', 'data', 'GloVe', emb_file), 'glove')
        else:
            #emb_file = 'GoogleNews-vectors-negative300.bin.gz'
            #embeddings = load_embeddings(path.join('jtr', 'data', 'word2vec', emb_file),'word2vec')
            emb_file = 'glove.840B.300d.zip'
            embeddings = load_embeddings(
                path.join('jtr', 'data', 'GloVe', emb_file), 'glove')
        print('loaded pre-trained embeddings')

    emb = embeddings.get if pretrain else None

    checkpoint()
    print('encode train data')
    train_data, train_vocab, train_target_vocab = pipeline(train_data,
                                                           emb=emb,
                                                           normalize=True)
def main():
    """Train a JTReader and evaluate it on the dev and test sets.

    Parses the command-line options, loads the train/dev/test data (and,
    with ``--pretrain``, pre-trained embeddings), builds the reader selected
    by ``--model`` and trains it with Adam. After each evaluation on the dev
    set the learning rate is decayed if the preferred metric dropped;
    otherwise the model is stored in ``--model_dir`` when the metric exceeds
    the best score so far. Finally the stored model is reloaded and
    evaluated on the test set, if one is given.
    """

    def _str2bool(value):
        """Parse a boolean command-line value such as 'True', 'false', '1', 'no'.

        ``type=bool`` must not be used for this: ``bool('False')`` is True,
        so every non-empty string would be read as True.
        """
        if isinstance(value, bool):
            return value
        lowered = value.strip().lower()
        if lowered in ('true', 't', 'yes', 'y', '1'):
            return True
        if lowered in ('false', 'f', 'no', 'n', '0'):
            return False
        raise argparse.ArgumentTypeError(
            'expected a boolean value, got {!r}'.format(value))

    support_alts = {'none', 'single', 'multiple'}
    question_alts = answer_alts = {'single', 'multiple'}
    candidate_alts = {'open', 'per-instance', 'fixed'}

    train_default = 'tests/test_data/SNLI/train.json'
    dev_default = 'tests/test_data/SNLI/dev.json'
    test_default = 'tests/test_data/SNLI/test.json'

    # ArgumentDefaultsHelpFormatter appends "(default: ...)" to every help
    # string, so each help value must be a single str (never a tuple).
    parser = argparse.ArgumentParser(
        description='Train and Evaluate a Machine Reader',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument(
        '--debug',
        action='store_true',
        help=
        "Run in debug mode, in which case the training file is also used for testing"
    )

    parser.add_argument(
        '--debug_examples',
        default=10,
        type=int,
        help="If in debug mode, how many examples should be used"
    )
    parser.add_argument('--train',
                        default=train_default,
                        type=str,
                        help="jtr training file")
    parser.add_argument('--dev',
                        default=dev_default,
                        type=str,
                        help="jtr dev file")
    parser.add_argument('--test',
                        default=test_default,
                        type=str,
                        help="jtr test file")
    parser.add_argument(
        '--supports',
        default='single',
        choices=sorted(support_alts),
        help=
        "None, single (default) or multiple supporting statements per instance; multiple_flat reads multiple instances creates a separate instance for every support"
    )
    parser.add_argument(
        '--questions',
        default='single',
        choices=sorted(question_alts),
        help="None, single (default), or multiple questions per instance")
    parser.add_argument(
        '--candidates',
        default='fixed',
        choices=sorted(candidate_alts),
        help="Open, per-instance, or fixed (default) candidates")
    parser.add_argument('--answers',
                        default='single',
                        choices=sorted(answer_alts),
                        help="Single or multiple")
    parser.add_argument('--batch_size',
                        default=128,
                        type=int,
                        help="Batch size for training data, default 128")
    parser.add_argument('--dev_batch_size',
                        default=128,
                        type=int,
                        help="Batch size for development data, default 128")
    parser.add_argument(
        '--repr_dim_input',
        default=128,
        type=int,
        # implicit string concatenation: a comma here would turn help into a
        # tuple and make `--help` crash inside the help formatter
        help=("Size of the input representation (embeddings), "
              "default 128 (embeddings cut off or extended if not "
              "matched with pretrained embeddings)"))
    parser.add_argument('--repr_dim',
                        default=128,
                        type=int,
                        help="Size of the hidden representations, default 128")

    parser.add_argument(
        '--pretrain',
        action='store_true',
        help=
        "Use pretrained embeddings, by default the initialisation is random")
    parser.add_argument(
        '--with_char_embeddings',
        action='store_true',
        help="Use also character based embeddings in readers which support it."
    )
    parser.add_argument('--vocab_from_embeddings',
                        action='store_true',
                        help="Use fixed vocab of pretrained embeddings")
    parser.add_argument(
        '--train_pretrain',
        action='store_true',
        help=
        "Continue training pretrained embeddings together with model parameters"
    )
    parser.add_argument(
        '--normalize_pretrain',
        action='store_true',
        help=
        "Normalize pretrained embeddings (randomly initialized embeddings have expected unit norm too)"
    )

    parser.add_argument('--embedding_format',
                        default='word2vec',
                        choices=["glove", "word2vec"],
                        help="format of embeddings to be loaded")
    parser.add_argument(
        '--embedding_file',
        default='jtr/data/SG_GoogleNews/GoogleNews-vectors-negative300.bin.gz',
        type=str,
        help="path of the embeddings file to be loaded")

    parser.add_argument('--vocab_maxsize', default=sys.maxsize, type=int)
    parser.add_argument('--vocab_minfreq', default=2, type=int)
    parser.add_argument(
        '--vocab_sep',
        default=True,
        type=_str2bool,
        help=
        'Should there be separate vocabularies for questions, supports, candidates and answers. This needs to be set to True for candidate-based methods.'
    )
    parser.add_argument('--model',
                        default='snli_reader',
                        choices=sorted(readers.readers.keys()),
                        help="Reading model to use")
    parser.add_argument('--learning_rate',
                        default=0.001,
                        type=float,
                        help="Learning rate, default 0.001")
    parser.add_argument('--learning_rate_decay',
                        default=0.5,
                        type=float,
                        help="Learning rate decay, default 0.5")
    parser.add_argument('--l2',
                        default=0.0,
                        type=float,
                        help="L2 regularization weight, default 0.0")
    parser.add_argument(
        '--clip_value',
        default=0.0,
        type=float,
        help=
        "Gradients clipped between [-clip_value, clip_value] (default 0.0; no clipping)"
    )
    parser.add_argument(
        '--dropout',
        default=0.0,
        type=float,
        help="Probability for dropout on output (set to 0.0 for no dropout)")
    parser.add_argument('--epochs',
                        default=5,
                        type=int,
                        help="Number of epochs to train for, default 5")
    parser.add_argument('--checkpoint',
                        default=None,
                        type=int,
                        help="Number of batches before evaluation on devset.")

    parser.add_argument(
        '--negsamples',
        default=0,
        type=int,
        help="Number of negative samples, default 0 (= use full candidate list)"
    )
    parser.add_argument('--tensorboard_folder',
                        default=None,
                        help='Folder for tensorboard logs')
    parser.add_argument('--write_metrics_to',
                        default=None,
                        type=str,
                        help='Filename to log the metrics of the EvalHooks')
    parser.add_argument(
        '--prune',
        default='False',
        help='If the vocabulary should be pruned to the most frequent words.')
    parser.add_argument('--model_dir',
                        default='/tmp/jtreader',
                        type=str,
                        help="Directory to write reader to.")
    parser.add_argument('--out_pred',
                        default='out_pred.txt',
                        type=str,
                        help="File to write predictions to.")
    parser.add_argument('--log_interval',
                        default=100,
                        type=int,
                        help="interval for logging eta, training loss, etc.")
    parser.add_argument('--device',
                        default='/cpu:0',
                        type=str,
                        help='device setting for tensorflow')
    parser.add_argument('--lowercase',
                        action='store_true',
                        help='lowercase texts.')
    parser.add_argument('--seed', default=325, type=int, help="Seed for rngs.")
    parser.add_argument('--answer_size',
                        default=3,
                        type=int,
                        help=("How many answer does the output have. Used for "
                              "classification."))
    parser.add_argument(
        '--max_support_length',
        default=-1,
        type=int,
        help=
        "How large the support should be. Can be used for cutting or filtering QA examples."
    )

    args = parser.parse_args()

    # make everything deterministic
    random.seed(args.seed)
    tf.set_random_seed(args.seed)

    # a clip value of 0.0 means "no gradient clipping"
    clip_value = None
    if args.clip_value != 0.0:
        clip_value = -abs(args.clip_value), abs(args.clip_value)

    logger.info('configuration:')
    for arg in vars(args):
        logger.info('\t{} : {}'.format(str(arg), str(getattr(args, arg))))

    # Get information about available CPUs and GPUs:
    # to set specific device, add CUDA_VISIBLE_DEVICES environment variable, e.g.
    # $ CUDA_VISIBLE_DEVICES=0 ./jtr_script.py

    logger.info('available devices:')
    for device in device_lib.list_local_devices():
        logger.info('device info: ' + str(device).replace("\n", " "))

    if args.debug:
        train_data = load_labelled_data(args.train, args.debug_examples,
                                        **vars(args))

        logger.info(
            'loaded {} samples as debug train/dev/test dataset '.format(
                args.debug_examples))

        # in debug mode the (truncated) training set doubles as dev/test set
        dev_data = train_data
        test_data = train_data
        if args.pretrain:
            emb_file = 'glove.6B.50d.txt'
            embeddings = load_embeddings(path.join('data', 'GloVe', emb_file),
                                         'glove')
            logger.info('loaded pre-trained embeddings ({})'.format(emb_file))
            # the input size must match the dimensionality of the embeddings
            args.repr_dim_input = embeddings.lookup.shape[1]
        else:
            embeddings = Embeddings(None, None)
    else:
        train_data, dev_data = [
            load_labelled_data(name, **vars(args))
            for name in [args.train, args.dev]
        ]
        test_data = load_labelled_data(args.test, **
                                       vars(args)) if args.test else None
        logger.info('loaded train/dev/test data')
        if args.pretrain:
            embeddings = load_embeddings(args.embedding_file,
                                         args.embedding_format)
            logger.info('loaded pre-trained embeddings ({})'.format(
                args.embedding_file))
            # the input size must match the dimensionality of the embeddings
            args.repr_dim_input = embeddings.lookup.shape[1]
        else:
            embeddings = Embeddings(None, None)

    vocab = Vocab(emb=embeddings,
                  init_from_embeddings=args.vocab_from_embeddings)

    with tf.device(args.device):
        # build JTReader
        checkpoint()
        reader = readers.readers[args.model](vocab, vars(args))
        checkpoint()

        # learning rate lives in a non-trainable variable so that it can be
        # decayed in-place by the dev-set side effect below
        learning_rate = tf.get_variable("learning_rate",
                                        initializer=args.learning_rate,
                                        dtype=tf.float32,
                                        trainable=False)
        lr_decay_op = learning_rate.assign(args.learning_rate_decay *
                                           learning_rate)
        optim = tf.train.AdamOptimizer(learning_rate)

        if args.tensorboard_folder is not None:
            # start with a clean log folder (removes logs of earlier runs)
            if os.path.exists(args.tensorboard_folder):
                shutil.rmtree(args.tensorboard_folder)
            sw = tf.summary.FileWriter(args.tensorboard_folder)
        else:
            sw = None

        # Hooks
        iter_interval = 1 if args.debug else args.log_interval
        hooks = [
            LossHook(reader, iter_interval, summary_writer=sw),
            ExamplesPerSecHook(reader, args.batch_size, iter_interval, sw),
            ETAHook(reader, iter_interval,
                    math.ceil(len(train_data) / args.batch_size), args.epochs,
                    args.checkpoint, sw)
        ]

        preferred_metric, best_metric = readers.eval_hooks[
            args.model].preferred_metric_and_best_score()

        def side_effect(metrics, prev_metric):
            """Decay the learning rate or store the model after a dev evaluation.

            If the preferred metric dropped compared to the previous
            evaluation, the learning rate is decayed; otherwise, if it beats
            the best score so far, the model is stored in ``args.model_dir``.

            Returns: a state (in this case a metric) that is used as input for the next call
            """
            m = metrics[preferred_metric]
            if prev_metric is not None and m < prev_metric:
                reader.sess.run(lr_decay_op)
                logger.info("Decayed learning rate to: %.5f" %
                            reader.sess.run(learning_rate))
            elif m > best_metric[0]:
                best_metric[0] = m
                if prev_metric is None:  # store whole model only at beginning of training
                    reader.store(args.model_dir)
                else:
                    reader.model_module.store(
                        reader.sess,
                        os.path.join(args.model_dir, "model_module"))
                logger.info("Saving model to: %s" % args.model_dir)
            return m

        # this is the standard hook for the model
        hooks.append(readers.eval_hooks[args.model](
            reader,
            dev_data,
            summary_writer=sw,
            side_effect=side_effect,
            iter_interval=args.checkpoint,
            epoch_interval=(1 if args.checkpoint is None else None),
            write_metrics_to=args.write_metrics_to))

        # Train
        reader.train(optim,
                     training_set=train_data,
                     max_epochs=args.epochs,
                     hooks=hooks,
                     l2=args.l2,
                     clip=clip_value,
                     clip_op=tf.clip_by_value,
                     device=args.device)

        # Test final model (the best model stored during training)
        if test_data is not None:
            test_eval_hook = readers.eval_hooks[args.model](
                reader,
                test_data,
                summary_writer=sw,
                epoch_interval=1,
                write_metrics_to=args.write_metrics_to)

            reader.load(args.model_dir)
            test_eval_hook.at_test_time(1)

        if sw is not None:
            # flush pending summaries to disk and release the event file
            sw.close()
def main():
    """Train and evaluate a JTR reader model from the command line.

    Steps (numbered as in the comments below): collect the registered
    models, parse the options, load the data (and optionally pre-trained
    embeddings), build vocabularies and encode the data with ``pipeline``,
    create the NeuralVocabs, build the train and valid/test graphs, batch
    the data, attach evaluation hooks and train.
    """

    def _str2bool(value):
        """Parse a boolean command-line value such as 'True', 'false', '1', 'no'.

        ``type=bool`` must not be used for this: ``bool('False')`` is True,
        so every non-empty string would be read as True.
        """
        if isinstance(value, bool):
            return value
        lowered = value.strip().lower()
        if lowered in ('true', 't', 'yes', 'y', '1'):
            return True
        if lowered in ('false', 'f', 'no', 'n', '0'):
            return False
        raise argparse.ArgumentTypeError(
            'expected a boolean value, got {!r}'.format(value))

    t0 = time()

    # (1) Defined JTR models
    # Please add new models to models.__models__ when they work
    reader_models = {
        model_name: models.get_function(model_name)
        for model_name in models.__models__
    }

    support_alts = {'none', 'single', 'multiple'}
    question_alts = answer_alts = {'single', 'multiple'}
    candidate_alts = {'open', 'per-instance', 'fixed'}

    train_default = dev_default = test_default = '../tests/test_data/sentihood/overfit.json'

    # (2) Parse the input arguments
    parser = argparse.ArgumentParser(
        description='Train and Evaluate a Machine Reader')

    parser.add_argument(
        '--debug',
        action='store_true',
        help=
        "Run in debug mode, in which case the training file is also used for testing"
    )
    parser.add_argument(
        '--debug_examples',
        default=10,
        type=int,
        help="If in debug mode, how many examples should be used (default 10)"
    )
    # FileType arguments are opened by argparse itself; string defaults are
    # opened too when the option is not given on the command line.
    parser.add_argument('--train',
                        default=train_default,
                        type=argparse.FileType('r'),
                        help="jtr training file")
    parser.add_argument('--dev',
                        default=dev_default,
                        type=argparse.FileType('r'),
                        help="jtr dev file")
    parser.add_argument('--test',
                        default=test_default,
                        type=argparse.FileType('r'),
                        help="jtr test file")
    parser.add_argument(
        '--supports',
        default='single',
        choices=sorted(support_alts),
        help=
        "None, single (default) or multiple supporting statements per instance; "
        "multiple_flat reads multiple instances creates a separate instance for every support"
    )
    parser.add_argument(
        '--questions',
        default='single',
        choices=sorted(question_alts),
        help="None, single (default), or multiple questions per instance")
    parser.add_argument(
        '--candidates',
        default='fixed',
        choices=sorted(candidate_alts),
        help="Open, per-instance, or fixed (default) candidates")
    parser.add_argument('--answers',
                        default='single',
                        choices=sorted(answer_alts),
                        help="Single or multiple")
    parser.add_argument('--batch_size',
                        default=128,
                        type=int,
                        help="Batch size for training data, default 128")
    parser.add_argument('--dev_batch_size',
                        default=128,
                        type=int,
                        help="Batch size for development data, default 128")
    parser.add_argument(
        '--repr_dim_input',
        default=300,
        type=int,
        help="Size of the input representation (embeddings), "
        "default 300 (embeddings cut off or extended if not matched with pretrained embeddings)"
    )
    parser.add_argument(
        '--repr_dim_input_trf',
        default=100,
        type=int,
        help=
        "Size of the input embeddings after reducing with fully_connected layer (default 100)"
    )
    parser.add_argument('--repr_dim_output',
                        default=100,
                        type=int,
                        help="Size of the output representation, default 100")

    parser.add_argument(
        '--pretrain',
        action='store_true',
        help=
        "Use pretrained embeddings, by default the initialisation is random")
    parser.add_argument(
        '--train_pretrain',
        action='store_true',
        help=
        "Continue training pretrained embeddings together with model parameters"
    )
    parser.add_argument(
        '--normalize_pretrain',
        action='store_true',
        help="Normalize pretrained embeddings, default False "
        "(randomly initialized embeddings have expected unit norm too)")

    parser.add_argument('--vocab_maxsize', default=sys.maxsize, type=int)
    parser.add_argument('--vocab_minfreq', default=2, type=int)
    parser.add_argument(
        '--vocab_sep',
        default=True,
        type=_str2bool,
        help='Should there be separate vocabularies for questions and supports, '
        'vs. candidates and answers. This needs to be set to True for candidate-based methods.'
    )
    parser.add_argument('--model',
                        default='bicond_singlesupport_reader',
                        choices=sorted(reader_models.keys()),
                        help="Reading model to use")
    parser.add_argument('--learning_rate',
                        default=0.001,
                        type=float,
                        help="Learning rate, default 0.001")
    parser.add_argument('--l2',
                        default=0.0,
                        type=float,
                        help="L2 regularization weight, default 0.0")
    parser.add_argument(
        '--clip_value',
        default=None,
        type=float,
        help=
        "Gradients clipped between [-clip_value, clip_value] (default: no clipping)"
    )
    parser.add_argument(
        '--drop_keep_prob',
        default=1.0,
        type=float,
        help=
        "Keep probability for dropout on output (set to 1.0 for no dropout)")
    parser.add_argument('--epochs',
                        default=5,
                        type=int,
                        help="Number of epochs to train for, default 5")

    parser.add_argument('--tokenize',
                        dest='tokenize',
                        action='store_true',
                        help="Tokenize question and support")
    parser.add_argument('--no-tokenize',
                        dest='tokenize',
                        action='store_false',
                        help="Do not tokenize question and support")
    parser.set_defaults(tokenize=True)
    parser.add_argument('--lowercase',
                        dest='lowercase',
                        action='store_true',
                        help="Lowercase data")

    parser.add_argument(
        '--negsamples',
        default=0,
        type=int,
        help="Number of negative samples, default 0 (= use full candidate list)"
    )
    parser.add_argument('--tensorboard_folder',
                        default='./.tb/',
                        help='Folder for tensorboard logs')
    parser.add_argument('--write_metrics_to',
                        default=None,
                        type=str,
                        help='Filename to log the metrics of the EvalHooks')
    parser.add_argument(
        '--prune',
        default='False',
        help='If the vocabulary should be pruned to the most frequent words.')
    parser.add_argument('--seed', default=1337, type=int, help='random seed')
    parser.add_argument('--logfile', default=None, type=str, help='log file')

    args = parser.parse_args()

    # no value (or 0.0) means no gradient clipping
    clip_range = (-abs(args.clip_value),
                  abs(args.clip_value)) if args.clip_value else None

    if args.logfile:
        fh = logging.FileHandler(args.logfile)
        fh.setLevel(logging.INFO)
        fh.setFormatter(
            logging.Formatter('%(levelname)s:%(name)s:\t%(message)s'))
        logger.addHandler(fh)

    logger.info('Configuration:')
    for arg in vars(args):
        logger.info('\t{} : {}'.format(str(arg), str(getattr(args, arg))))

    # set random seed
    tf.set_random_seed(args.seed)
    DefaultRandomState(args.seed)

    # Get information about available CPUs and GPUs:
    # to set specific device, add CUDA_VISIBLE_DEVICES environment variable, e.g.
    # $ CUDA_VISIBLE_DEVICES=0 ./jtr_script.py

    logger.info('available devices:')
    for device in device_lib.list_local_devices():
        logger.info('device info: ' + str(device).replace("\n", " "))

    # (3) Read the train, dev, and test data (with optionally loading pre-trained embeddings
    embeddings = None
    train_data, dev_data, test_data = None, None, None

    if args.debug:
        train_data = jtr_load(args.train, args.debug_examples, **vars(args))
        # in debug mode the (truncated) training set doubles as dev/test set
        dev_data, test_data = train_data, train_data

        logger.info(
            'Loaded {} samples as debug train/dev/test dataset '.format(
                args.debug_examples))

        if args.pretrain:
            emb_file = 'glove.6B.50d.txt'
            embeddings = load_embeddings(
                path.join('jtr', 'data', 'GloVe', emb_file), 'glove')
            logger.info('loaded pre-trained embeddings ({})'.format(emb_file))
    else:
        if args.train:
            train_data = jtr_load(args.train, **vars(args))

        if args.dev:
            dev_data = jtr_load(args.dev, **vars(args))

        if args.test:
            test_data = jtr_load(args.test, **vars(args))

        logger.info('loaded train/dev/test data')
        if args.pretrain:
            emb_file = 'GoogleNews-vectors-negative300.bin.gz'
            embeddings = load_embeddings(
                path.join('jtr', 'data', 'SG_GoogleNews', emb_file),
                'word2vec')
            logger.info('loaded pre-trained embeddings ({})'.format(emb_file))

    emb = embeddings.get if args.pretrain else None

    checkpoint()

    #  (4) Preprocesses the data (tokenize, normalize, add
    #  start and end of sentence tags) via the JTR pipeline method

    if args.vocab_minfreq != 0 and args.vocab_maxsize != 0:
        # two passes: first build the vocabularies, optionally prune them,
        # then encode the training data with the frozen vocabularies
        logger.info('build vocab based on train data')
        _, train_vocab, train_answer_vocab, train_candidate_vocab = pipeline(
            train_data,
            normalize=True,
            sepvocab=args.vocab_sep,
            tokenization=args.tokenize,
            lowercase=args.lowercase,
            emb=emb)
        if args.prune == 'True':
            train_vocab = train_vocab.prune(args.vocab_minfreq,
                                            args.vocab_maxsize)

        logger.info('encode train data')
        train_data, _, _, _ = pipeline(train_data,
                                       train_vocab,
                                       train_answer_vocab,
                                       train_candidate_vocab,
                                       normalize=True,
                                       freeze=True,
                                       sepvocab=args.vocab_sep,
                                       tokenization=args.tokenize,
                                       lowercase=args.lowercase,
                                       negsamples=args.negsamples)
    else:
        train_data, train_vocab, train_answer_vocab, train_candidate_vocab = pipeline(
            train_data,
            emb=emb,
            normalize=True,
            tokenization=args.tokenize,
            lowercase=args.lowercase,
            negsamples=args.negsamples,
            sepvocab=args.vocab_sep)

    N_oov = train_vocab.count_oov()
    N_pre = train_vocab.count_pretrained()
    logger.info(
        'In Training data vocabulary: {} pre-trained, {} out-of-vocab.'.format(
            N_pre, N_oov))

    vocab_size = len(train_vocab)
    answer_size = len(train_answer_vocab)

    # Sizes derived from the training data are stored on ``args`` so that
    # they are forwarded to the reader model together with the user options.
    # They are set directly instead of re-running parse_args(), which would
    # re-open every FileType argument and leak the file handles.
    args.vocab_size = vocab_size
    args.answer_size = answer_size

    # NOTE(review): train data is encoded with normalize=True, dev/test below
    # are not; confirm this matches pipeline()'s defaults and is intended.
    checkpoint()
    logger.info('encode dev data')
    dev_data, _, _, _ = pipeline(dev_data,
                                 train_vocab,
                                 train_answer_vocab,
                                 train_candidate_vocab,
                                 freeze=True,
                                 tokenization=args.tokenize,
                                 lowercase=args.lowercase,
                                 sepvocab=args.vocab_sep)
    checkpoint()
    logger.info('encode test data')
    test_data, _, _, _ = pipeline(test_data,
                                  train_vocab,
                                  train_answer_vocab,
                                  train_candidate_vocab,
                                  freeze=True,
                                  tokenization=args.tokenize,
                                  lowercase=args.lowercase,
                                  sepvocab=args.vocab_sep)
    checkpoint()

    # (5) Create NeuralVocab
    logger.info('build NeuralVocab')
    nvocab = NeuralVocab(train_vocab,
                         input_size=args.repr_dim_input,
                         reduced_input_size=args.repr_dim_input_trf,
                         use_pretrained=args.pretrain,
                         train_pretrained=args.train_pretrain,
                         unit_normalize=args.normalize_pretrain)

    with tf.variable_scope("candvocab"):
        candvocab = NeuralVocab(train_candidate_vocab,
                                input_size=args.repr_dim_input,
                                reduced_input_size=args.repr_dim_input_trf,
                                use_pretrained=args.pretrain,
                                train_pretrained=args.train_pretrain,
                                unit_normalize=args.normalize_pretrain)
    checkpoint()

    # (6) Create TensorFlow placeholders and initialize model
    logger.info('create placeholders')
    placeholders = create_placeholders(train_data)
    logger.info('build model {}'.format(args.model))

    # add dropout on the model level
    # todo: more general solution
    options_train = vars(args)
    with tf.name_scope("Train"):
        with tf.variable_scope("Model", reuse=None):
            (logits_train, loss_train,
             predict_train) = reader_models[args.model](placeholders,
                                                        nvocab,
                                                        candvocab=candvocab,
                                                        **options_train)

    # the valid/test graph shares the "Model" variables but disables dropout
    options_valid = {k: v for k, v in options_train.items()}
    options_valid["drop_keep_prob"] = 1.0
    with tf.name_scope("Valid_Test"):
        with tf.variable_scope("Model", reuse=True):
            (logits_valid, loss_valid,
             predict_valid) = reader_models[args.model](placeholders,
                                                        nvocab,
                                                        candvocab=candvocab,
                                                        **options_valid)

    # (7) Batch the data via jtr.batch.get_feed_dicts
    if args.supports != "none":
        # composite buckets; first over question, then over support
        bucket_order = ('question', 'support')
        # will result in 16 composite buckets, evenly spaced over questions and supports
        bucket_structure = (1, 1)  # (4, 4)
    else:
        # question buckets
        bucket_order = ('question', )
        # 4 buckets, evenly spaced over questions
        bucket_structure = (1, )  # (4,)

    train_feed_dicts = get_feed_dicts(train_data,
                                      placeholders,
                                      args.batch_size,
                                      bucket_order=bucket_order,
                                      bucket_structure=bucket_structure,
                                      exact_epoch=False)
    dev_feed_dicts = get_feed_dicts(dev_data,
                                    placeholders,
                                    args.dev_batch_size,
                                    exact_epoch=True)

    test_feed_dicts = get_feed_dicts(test_data,
                                     placeholders,
                                     args.dev_batch_size,
                                     exact_epoch=True)

    optim = tf.train.AdamOptimizer(args.learning_rate)

    sw = tf.summary.FileWriter(args.tensorboard_folder)

    answname = "targets" if "cands" in args.model else "answers"

    # (8) Add hooks
    hooks = [
        # report_loss
        LossHook(1, args.batch_size, summary_writer=sw),
        ExamplesPerSecHook(100, args.batch_size, summary_writer=sw),

        # evaluate on train data after each epoch
        EvalHook(train_feed_dicts,
                 logits_valid,
                 predict_valid,
                 placeholders[answname],
                 at_every_epoch=1,
                 metrics=['Acc', 'macroF1'],
                 print_details=False,
                 write_metrics_to=args.write_metrics_to,
                 info="training",
                 summary_writer=sw),

        # evaluate on dev data after each epoch
        EvalHook(dev_feed_dicts,
                 logits_valid,
                 predict_valid,
                 placeholders[answname],
                 at_every_epoch=1,
                 metrics=['Acc', 'macroF1'],
                 print_details=False,
                 write_metrics_to=args.write_metrics_to,
                 info="development",
                 summary_writer=sw),

        # evaluate on test data after training
        EvalHook(test_feed_dicts,
                 logits_valid,
                 predict_valid,
                 placeholders[answname],
                 at_every_epoch=args.epochs,
                 metrics=['Acc', 'macroP', 'macroR', 'macroF1'],
                 print_details=False,
                 write_metrics_to=args.write_metrics_to,
                 info="test")
    ]

    # (9) Train the model
    train(loss_train,
          optim,
          train_feed_dicts,
          max_epochs=args.epochs,
          l2=args.l2,
          clip=clip_range,
          hooks=hooks)

    # flush pending summaries to disk and release the event file
    sw.close()

    # elapsed wall-clock time, in hours
    logger.info('finished in {0:.3g}'.format((time() - t0) / 3600.))