Example #1
def main(args):
    config = tf.estimator.RunConfig(model_dir=args.model_dir)
    hparams = utils.create_hparams(args)

    hparams.decoder.set_hparam('beam_width', args.beam_width)

    vocab_list = utils.load_vocab(args.vocab)
    vocab_list_orig = vocab_list
    binf2phone_np = None
    binf2phone = None
    mapping = None
    if hparams.decoder.binary_outputs:
        if args.mapping is not None:
            vocab_list, mapping = utils.get_mapping(args.mapping, args.vocab)
            hparams.del_hparam('mapping')
            hparams.add_hparam('mapping', mapping)

        binf2phone = utils.load_binf2phone(args.binf_map, vocab_list)
        binf2phone_np = binf2phone.values

    def model_fn(features, labels, mode, config, params):
        return las_model_fn(features,
                            labels,
                            mode,
                            config,
                            params,
                            binf2phone=binf2phone_np)

    model = tf.estimator.Estimator(model_fn=model_fn,
                                   config=config,
                                   params=hparams)

    phone_pred_key = 'sample_ids_phones_binf' if args.use_phones_from_binf else 'sample_ids'
    predict_keys = [phone_pred_key, 'embedding', 'alignment']
    if args.use_phones_from_binf:
        predict_keys.append('logits_binf')
        predict_keys.append('alignment_binf')

    audio, _ = librosa.load(args.waveform, sr=SAMPLE_RATE, mono=True)
    features = [calculate_acoustic_features(args, audio)]

    predictions = model.predict(
        input_fn=lambda: input_fn(features,
                                  args.vocab,
                                  args.norm,
                                  num_channels=features[0].shape[-1],
                                  batch_size=args.batch_size),
        predict_keys=predict_keys)
    predictions = list(predictions)
    for p in predictions:
        beams = p[phone_pred_key].T
        if len(beams.shape) > 1:
            i = beams[0]
        else:
            i = beams
        i = i.tolist() + [utils.EOS_ID]
        i = i[:i.index(utils.EOS_ID)]
        text = to_text(vocab_list, i)
        text = text.split(args.delimiter)
        print(text)
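The decoding idiom above (transpose so beams come first, take the top beam, append EOS so index() always succeeds, then truncate at the first EOS) recurs in examples #3 and #6. A minimal self-contained sketch with toy values; EOS_ID and the vocabulary below are illustrative stand-ins, not the project's real ones:

import numpy as np

EOS_ID = 2  # hypothetical id; the project reads this from utils.EOS_ID
vocab_list = ['<sos>', '<unk>', '</s>', 'ae', 'b']

sample_ids = np.array([[3, 4], [4, 3], [2, 2]])  # shape [time, beam_width]
beams = sample_ids.T                             # beams first: [beam_width, time]
i = beams[0] if len(beams.shape) > 1 else beams  # best beam
i = i.tolist() + [EOS_ID]                        # guarantee index() finds an EOS
i = i[:i.index(EOS_ID)]                          # cut at the first EOS
print([vocab_list[idx] for idx in i])            # ['ae', 'b']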
Example #2
def main(args):
    eval_name = str(os.path.basename(args.data).split('.')[0])
    config = tf.estimator.RunConfig(model_dir=args.model_dir)
    hparams = utils.create_hparams(args)

    vocab_list = utils.load_vocab(args.vocab)
    binf2phone_np = None
    binf2phone = None
    if hparams.decoder.binary_outputs:
        binf2phone = utils.load_binf2phone(args.binf_map, vocab_list)
        binf2phone_np = binf2phone.values

    def model_fn(features, labels, mode, config, params):
        return las_model_fn(features,
                            labels,
                            mode,
                            config,
                            params,
                            binf2phone=binf2phone_np,
                            run_name=eval_name)

    model = tf.estimator.Estimator(model_fn=model_fn,
                                   config=config,
                                   params=hparams)

    tf.logging.info('Evaluating on {}'.format(eval_name))
    model.evaluate(lambda: input_fn(args.data,
                                    args.vocab,
                                    args.norm,
                                    num_channels=args.num_channels,
                                    batch_size=args.batch_size,
                                    binf2phone=None),
                   name=eval_name)
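Passing name=eval_name here is deliberate: the Estimator writes the evaluation summaries to a separate eval_<name> subdirectory under model_dir, so evaluating the same checkpoints on several datasets keeps their metrics apart.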
Example #3
def main(args):
    config = tf.estimator.RunConfig(model_dir=args.model_dir)
    hparams = utils.create_hparams(args)
    hparams.decoder.set_hparam('beam_width', args.beam_width)

    vocab_list = utils.load_vocab(args.vocab)
    binf2phone_np = None
    if hparams.decoder.binary_outputs:
        if args.mapping is not None:
            vocab_list, mapping = utils.get_mapping(args.mapping, args.vocab)
            hparams.del_hparam('mapping')
            hparams.add_hparam('mapping', mapping)

        binf2phone = utils.load_binf2phone(args.binf_map, vocab_list)
        binf2phone_np = binf2phone.values

    def model_fn(features, labels, mode, config, params):
        return las_model_fn(features,
                            labels,
                            mode,
                            config,
                            params,
                            binf2phone=binf2phone_np,
                            transparent_projection=args.use_phones_from_binf)

    model = tf.estimator.Estimator(model_fn=model_fn,
                                   config=config,
                                   params=hparams)

    audio, _ = librosa.load(args.waveform, sr=SAMPLE_RATE, mono=True)
    features = [calculate_acoustic_features(args, audio)]

    predictions = model.predict(
        input_fn=lambda: input_fn(features, args.vocab, args.norm))
    predictions = list(predictions)
    for p in predictions:
        phone_pred_key = next(k for k in p.keys()
                              if k.startswith('sample_ids'))
        beams = p[phone_pred_key].T
        if len(beams.shape) > 1:
            i = beams[0]
        else:
            i = beams
        i = i.tolist() + [utils.EOS_ID]
        i = i[:i.index(utils.EOS_ID)]
        text = to_text(vocab_list, i)
        text = text.split(args.delimiter)
        for k in p.keys():
            print(f'{k}: {p[k].shape}')
        print(text)
    if args.output_file:
        dump(predictions, args.output_file)
        print(f'Predictions are saved to {args.output_file}')
Example #4
File: train.py Project: Hak333m/phones-las
def main(args):
    vocab_list = utils.load_vocab(args.vocab)
    binf2phone_np = None
    binf2phone = None
    mapping = None
    vocab_size = len(vocab_list)
    binf_count = None
    if args.binary_outputs:
        if args.mapping is not None:
            vocab_list, mapping = utils.get_mapping(args.mapping, args.vocab)
            args.mapping = None
        binf2phone = utils.load_binf2phone(args.binf_map, vocab_list)
        binf_count = len(binf2phone.index)
        if args.output_ipa:
            binf2phone_np = binf2phone.values

    config = tf.estimator.RunConfig(model_dir=args.model_dir)
    hparams = utils.create_hparams(
        args, vocab_size, binf_count, utils.SOS_ID, utils.EOS_ID)
    if mapping is not None:
        hparams.del_hparam('mapping')
        hparams.add_hparam('mapping', mapping)

    def model_fn(features, labels, mode, config, params):
        return las_model_fn(features, labels, mode, config, params,
                            binf2phone=binf2phone_np)

    model = tf.estimator.Estimator(
        model_fn=model_fn,
        config=config,
        params=hparams)

    if args.valid:
        train_spec = tf.estimator.TrainSpec(
            input_fn=lambda: input_fn(
                args.train, args.vocab, args.norm, num_channels=args.num_channels, batch_size=args.batch_size,
                num_epochs=args.num_epochs, binf2phone=None, num_parallel_calls=args.num_parallel_calls))

        eval_spec = tf.estimator.EvalSpec(
            input_fn=lambda: input_fn(
                args.valid or args.train, args.vocab, args.norm, num_channels=args.num_channels,
                batch_size=args.batch_size, binf2phone=None, num_parallel_calls=args.num_parallel_calls),
            start_delay_secs=60,
            throttle_secs=args.eval_secs)

        tf.estimator.train_and_evaluate(model, train_spec, eval_spec)
    else:
        model.train(
            input_fn=lambda: input_fn(
                args.train, args.vocab, args.norm, num_channels=args.num_channels, batch_size=args.batch_size,
                num_epochs=args.num_epochs, binf2phone=None, num_parallel_calls=args.num_parallel_calls))
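Every example defines a thin model_fn closure instead of handing las_model_fn to the Estimator directly. tf.estimator.Estimator only supplies the fixed arguments (features, labels, mode, params, config) to model_fn, so extra inputs such as the binf2phone matrix have to be captured from the enclosing scope, as done here.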
Example #5
def main(args):
    eval_name = str(os.path.basename(args.data).split('.')[0])
    config = tf.estimator.RunConfig(model_dir=args.model_dir)
    hparams = utils.create_hparams(args)

    vocab_name = args.vocab if not args.t2t_format else os.path.join(
        args.data, 'vocab.txt')
    vocab_list = utils.load_vocab(vocab_name)
    binf2phone_np = None
    binf2phone = None
    if hparams.decoder.binary_outputs:
        binf2phone = utils.load_binf2phone(args.binf_map, vocab_list)
        binf2phone_np = binf2phone.values

    def model_fn(features, labels, mode, config, params):
        return las_model_fn(features,
                            labels,
                            mode,
                            config,
                            params,
                            binf2phone=binf2phone_np,
                            run_name=eval_name)

    model = tf.estimator.Estimator(model_fn=model_fn,
                                   config=config,
                                   params=hparams)

    tf.logging.info('Evaluating on {}'.format(eval_name))
    if args.t2t_format:
        input_fn = lambda: utils.input_fn_t2t(
            args.data,
            tf.estimator.ModeKeys.EVAL,
            hparams,
            args.t2t_problem_name,
            batch_size=args.batch_size,
            features_hparams_override=args.t2t_features_hparams_override)
    else:
        input_fn = lambda: utils.input_fn(args.data,
                                          args.vocab,
                                          args.norm,
                                          num_channels=args.num_channels,
                                          batch_size=args.batch_size)
    model.evaluate(input_fn, name=eval_name)
Example #6
def main(args):
    config = tf.estimator.RunConfig(model_dir=args.model_dir)
    hparams = utils.create_hparams(args)

    hparams.decoder.set_hparam('beam_width', args.beam_width)

    vocab_list = utils.load_vocab(args.vocab)
    vocab_list_orig = vocab_list
    binf2phone_np = None
    binf2phone = None
    mapping = None
    if hparams.decoder.binary_outputs:
        if args.mapping is not None:
            vocab_list, mapping = utils.get_mapping(args.mapping, args.vocab)
            hparams.del_hparam('mapping')
            hparams.add_hparam('mapping', mapping)

        binf2phone = utils.load_binf2phone(args.binf_map, vocab_list)
        binf2phone_np = binf2phone.values

    def model_fn(features, labels, mode, config, params):
        return las_model_fn(features,
                            labels,
                            mode,
                            config,
                            params,
                            binf2phone=binf2phone_np)

    model = tf.estimator.Estimator(model_fn=model_fn,
                                   config=config,
                                   params=hparams)

    phone_pred_key = 'sample_ids_phones_binf' if args.use_phones_from_binf else 'sample_ids'
    predict_keys = [phone_pred_key, 'embedding', 'alignment']
    if args.use_phones_from_binf:
        predict_keys.append('logits_binf')
        predict_keys.append('alignment_binf')
    predictions = model.predict(
        input_fn=lambda: utils.input_fn(args.data,
                                        args.vocab,
                                        args.norm,
                                        num_channels=args.num_channels,
                                        batch_size=args.batch_size,
                                        take=args.take,
                                        is_infer=True),
        predict_keys=predict_keys)

    if args.calc_frame_binf_accuracy:
        with open(args.mapping_for_frame_accuracy, 'r') as fid:
            mapping_lines = fid.read().strip().split('\n')
        mapping_targets = dict()
        for line in mapping_lines:
            phones = line.split('\t')
            if len(phones) < 3:
                mapping_targets[phones[0]] = None
            else:
                mapping_targets[phones[0]] = phones[-1]

    predictions = list(predictions)
    if args.plain_targets:
        targets = []
        for line in open(args.plain_targets, 'r'):
            delim = ','
            if '\t' in line:
                delim = '\t'
            cells = line.split(delim)
            sound, lang, phrase = cells[:3]
            if args.convert_targets_to_ipa:
                if len(cells) == 4:
                    phrase = cells[-1].split(',')
                else:
                    phrase = get_ipa(phrase, lang)
            else:
                phrase = phrase.split()
                phrase = [x.strip().lower() for x in phrase]
            if args.calc_frame_binf_accuracy:
                markup, binfs, nsamples = get_timit_binf_markup(
                    sound, binf2phone, mapping_targets)
                targets.append((phrase, markup, binfs, nsamples))
            else:
                targets.append(phrase)
        save_to = os.path.join(args.model_dir, 'infer_targets.txt')
        with open(save_to, 'w') as f:
            f.write('\n'.join(args.delimiter.join(t) for t in targets))
        err = 0
        tot = 0
        optimistic_err = 0
        if args.calc_frame_binf_accuracy:
            frames_count = 0
            correct_frames_count = np.zeros((len(binf2phone.index)))
        for p, target in tqdm(zip(predictions, targets)):
            # Resolve the reference transcript first; it is needed inside the
            # beam loop below as well as for the final error counts.
            t = target[0] if args.calc_frame_binf_accuracy else target
            if mapping is not None:
                target_ids = np.array([vocab_list_orig.index(ph) for ph in t])
                target_ids = np.array(mapping)[target_ids]
                t = [vocab_list[i] for i in target_ids]

            first_text = []
            min_err = 100000
            beams = p[phone_pred_key].T
            if len(beams.shape) > 1:
                for bi, i in enumerate(beams):
                    i = i.tolist() + [utils.EOS_ID]
                    i = i[:i.index(utils.EOS_ID)]
                    text = to_text(vocab_list, i)
                    text = text.split(args.delimiter)
                    min_err = min(min_err, edist(text, t))
                    if bi == 0:
                        first_text = text.copy()
            else:
                i = beams.tolist() + [utils.EOS_ID]
                i = i[:i.index(utils.EOS_ID)]
                text = to_text(vocab_list, i)
                first_text = text.split(args.delimiter)
                min_err = edist(first_text, t)

            err += edist(first_text, t)
            optimistic_err += min_err
            tot += len(t)

            if args.calc_frame_binf_accuracy:
                attention = p['alignment'][:len(first_text), :]
                binf_preds = None
                if args.use_phones_from_binf:
                    logits = p['logits_binf'][:-1, :]
                    binf_preds = np.round(1 / (1 + np.exp(-logits)))
                    attention = p['alignment_binf'][:binf_preds.shape[0], :]

                markup, binfs, nsamples = target[1:]
                markup = np.minimum(markup, nsamples / SAMPLE_RATE)
                markup_frames = librosa.time_to_frames(
                    markup, SAMPLE_RATE,
                    args.encoder_frame_step * SAMPLE_RATE / 1000, WIN_LEN)
                markup_frames[markup_frames < 0] = 0
                markup_frames_binf = get_binf_markup_frames(
                    markup_frames, binfs)

                if not args.use_markup_segments:
                    alignment, _ = attention_to_segments(attention)
                    pred_frames_binf = segs_phones_to_frame_binf(
                        alignment, first_text, binf2phone, binf_preds)
                else:
                    binfs_pred = segments_to_attention(attention,
                                                       markup_frames,
                                                       first_text, binf2phone,
                                                       binf_preds)
                    pred_frames_binf = get_binf_markup_frames(
                        markup_frames, binfs_pred)

                if pred_frames_binf.shape[0] != markup_frames_binf.shape[0]:
                    print(
                        'Warning: sound {} prediction frames {} target frames {}'
                        .format(t, pred_frames_binf.shape[0],
                                markup_frames_binf.shape[0]))
                    nframes_fixed = min(pred_frames_binf.shape[0],
                                        markup_frames_binf.shape[0])
                    pred_frames_binf = pred_frames_binf[:nframes_fixed, :]
                    markup_frames_binf = markup_frames_binf[:nframes_fixed, :]
                correct_frames_count += count_correct(pred_frames_binf,
                                                      markup_frames_binf)
                frames_count += markup_frames_binf.shape[0]

            # Compare binary feature vectors

        print(f'PER: {100 * err / tot:2.2f}%')
        print(f'Optimistic PER: {100 * optimistic_err / tot:2.2f}%')

        if args.calc_frame_binf_accuracy:
            df = pd.DataFrame({'correct': correct_frames_count / frames_count},
                              index=binf2phone.index)
            print(df)

    if args.beam_width > 0:
        predictions = [{
            'transcription':
            to_text(vocab_list, y[phone_pred_key][:, 0])
        } for y in predictions]
    else:
        predictions = [{
            'transcription': to_text(vocab_list, y[phone_pred_key])
        } for y in predictions]

    save_to = os.path.join(args.model_dir, 'infer.txt')
    with open(save_to, 'w') as f:
        f.write('\n'.join(p['transcription'] for p in predictions))

    save_to = os.path.join(args.model_dir, 'infer.dmp')
    dump(predictions, save_to)
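A side note on the binary-feature decision rule used above: rounding the sigmoid of the logits is the same as thresholding the raw logits at zero, which avoids the exp() entirely. A quick self-contained check:

import numpy as np

logits = np.array([-3.2, -0.1, 0.1, 4.0])
probs = 1 / (1 + np.exp(-logits))  # what the example computes
assert np.array_equal(np.round(probs), (logits > 0).astype(float))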
Example #7
    parser.add_argument('--n_jobs', help='Number of parallel jobs.', type=int, default=4)
    parser.add_argument('--targets', help='Determines targets type.', type=str,
                        choices=['words', 'phones', 'binary_features', 'chars'], default='words')
    parser.add_argument('--binf_map', help='Path to CSV with phonemes to binary features map',
                        type=str, default='misc/binf_map.csv')
    parser.add_argument('--remove_diacritics', help='Remove diacritics from IPA targets',
                        action='store_true')
    parser.add_argument('--split_diphthongs', help='Split diphthongs in IPA targets',
                        action='store_true')
    parser.add_argument('--start', help='Index of example to start from', type=int, default=0)
    parser.add_argument('--count', help='Maximal phrases count, -1 for all phrases', type=int, default=-1)
    parser.add_argument('--delimiter', help='CSV delimiter', type=str, default=',')
    args = parser.parse_args()

    if args.targets in ('phones', 'binary_features'):
        binf2phone = load_binf2phone(args.binf_map)
    if args.feature_type == 'lyon' or args.backend == 'speechpy':
        print('Forcing n_jobs = 1 for selected configuration.')
        args.n_jobs = 1
    print('Processing audio dataset from file {}.'.format(args.input_file))
    window = int(SAMPLE_RATE * args.window / 1000.0)
    step = int(SAMPLE_RATE * args.step / 1000.0)
    lines = open(args.input_file, 'r').readlines()

    count = len(lines) - args.start
    if args.count > 0 and args.count < len(lines):
        count = args.count
    lines = lines[args.start:count+args.start]

    par_handle = tqdm(unit='sound')
    with tf.io.TFRecordWriter(args.output_file) as writer:
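The excerpt is cut off inside the TFRecordWriter context, so the actual serialization code is not shown. For orientation only, a generic sketch of writing one float feature matrix through such a writer; the helper and field names here are hypothetical, not necessarily the project's schema:

import tensorflow as tf

def write_features(writer, feats):
    # feats: float32 matrix [time, channels], flattened into one list.
    example = tf.train.Example(features=tf.train.Features(feature={
        'features': tf.train.Feature(
            float_list=tf.train.FloatList(value=feats.flatten().tolist())),
        'shape': tf.train.Feature(
            int64_list=tf.train.Int64List(value=list(feats.shape))),
    }))
    writer.write(example.SerializeToString())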
Example #8
File: train.py Project: madved/phones-las
def main(args):
    vocab_list = utils.load_vocab(args.vocab)
    binf2phone_np = None
    mapping = None
    vocab_size = len(vocab_list)
    binf_count = None
    if args.binary_outputs:
        if args.mapping is not None:
            vocab_list, mapping = utils.get_mapping(args.mapping, args.vocab)
            args.mapping = None
        binf2phone = utils.load_binf2phone(args.binf_map, vocab_list)
        binf_count = len(binf2phone.index)
        if args.output_ipa:
            binf2phone_np = binf2phone.values

    if args.tpu_name:
        iterations_per_loop = 100
        tpu_cluster_resolver = None
        if args.tpu_name != 'fake':
            tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
                args.tpu_name)
        config = tf.estimator.tpu.RunConfig(
            cluster=tpu_cluster_resolver,
            model_dir=args.model_dir,
            save_checkpoints_steps=max(600, iterations_per_loop),
            tpu_config=tf.estimator.tpu.TPUConfig(
                iterations_per_loop=iterations_per_loop,
                per_host_input_for_training=(
                    tf.estimator.tpu.InputPipelineConfig.PER_HOST_V2)))
    else:
        config = tf.estimator.RunConfig(model_dir=args.model_dir)
    hparams = utils.create_hparams(args, vocab_size, binf_count, utils.SOS_ID,
                                   utils.EOS_ID)
    if mapping is not None:
        hparams.del_hparam('mapping')
        hparams.add_hparam('mapping', mapping)

    def model_fn(features, labels, mode, config, params):
        return las_model_fn(features,
                            labels,
                            mode,
                            config,
                            params,
                            binf2phone=binf2phone_np)

    if args.tpu_name:
        model = tf.estimator.tpu.TPUEstimator(model_fn=model_fn,
                                              config=config,
                                              params=hparams,
                                              eval_on_tpu=False,
                                              train_batch_size=args.batch_size,
                                              use_tpu=args.tpu_name != 'fake')
    else:
        model = tf.estimator.Estimator(model_fn=model_fn,
                                       config=config,
                                       params=hparams)

    if args.valid:
        train_spec = tf.estimator.TrainSpec(
            input_fn=lambda params: input_fn(
                args.train,
                args.vocab,
                args.norm,
                num_channels=args.num_channels,
                batch_size=params.batch_size if 'batch_size' in params else args.batch_size,
                num_epochs=args.num_epochs,
                binf2phone=None,
                num_parallel_calls=args.num_parallel_calls,
                max_frames=args.max_frames,
                max_symbols=args.max_symbols),
            max_steps=args.num_epochs * 1000 * args.batch_size)

        eval_spec = tf.estimator.EvalSpec(
            input_fn=lambda params: input_fn(
                args.valid or args.train,
                args.vocab,
                args.norm,
                num_channels=args.num_channels,
                batch_size=params.batch_size if 'batch_size' in params else args.batch_size,
                binf2phone=None,
                num_parallel_calls=args.num_parallel_calls,
                max_frames=args.max_frames,
                max_symbols=args.max_symbols),
            start_delay_secs=60,
            throttle_secs=args.eval_secs)

        tf.estimator.train_and_evaluate(model, train_spec, eval_spec)
    else:
        tf.logging.warning('Training without evaluation!')
        model.train(
            input_fn=lambda params: input_fn(
                args.train,
                args.vocab,
                args.norm,
                num_channels=args.num_channels,
                batch_size=params.batch_size if 'batch_size' in params else args.batch_size,
                num_epochs=args.num_epochs,
                binf2phone=None,
                num_parallel_calls=args.num_parallel_calls,
                max_frames=args.max_frames,
                max_symbols=args.max_symbols),
            steps=args.num_epochs * 1000 * args.batch_size)
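The lambda params: signature on the input functions is what lets the same code run on TPU and CPU/GPU: TPUEstimator calls input_fn with a params dict carrying the per-shard 'batch_size' it derives from train_batch_size, while the plain Estimator passes along the HParams object given above, so the 'batch_size' in params check picks whichever is available.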
Example #9
def main(args):
    train_dir = os.path.dirname(args.train)
    vocab_name = os.path.join(train_dir, 'vocab.txt')
    norm_name = os.path.join(train_dir, 'norm.dmp')
    vocab_list = utils.load_vocab(vocab_name)
    binf2phone_np = None
    mapping = None
    vocab_size = len(vocab_list)
    binf_count = None
    if args.binary_outputs:
        if args.mapping is not None:
            vocab_list, mapping = utils.get_mapping(args.mapping, vocab_name)
            args.mapping = None
        binf2phone = utils.load_binf2phone(args.binf_map, vocab_list)
        binf_count = len(binf2phone.index)
        if args.output_ipa:
            binf2phone_np = binf2phone.values

    if args.tpu_name:
        iterations_per_loop = 100
        tpu_cluster_resolver = None
        if args.tpu_name != 'fake':
            tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(args.tpu_name)
        config = tf.estimator.tpu.RunConfig(
            cluster=tpu_cluster_resolver,
            model_dir=args.model_dir,
            save_checkpoints_steps=max(args.tpu_checkpoints_interval, iterations_per_loop),
            tpu_config=tf.estimator.tpu.TPUConfig(
                iterations_per_loop=iterations_per_loop,
                per_host_input_for_training=tf.estimator.tpu.InputPipelineConfig.PER_HOST_V2))
    else:
        config = tf.estimator.RunConfig(model_dir=args.model_dir)
    hparams = utils.create_hparams(
        args, vocab_size, binf_count, utils.SOS_ID if not args.t2t_format else PAD_ID,
        utils.EOS_ID if not args.t2t_format else EOS_ID)
    if mapping is not None:
        hparams.del_hparam('mapping')
        hparams.add_hparam('mapping', mapping)

    def model_fn(features, labels, mode, config, params):
        return las_model_fn(features, labels, mode, config, params,
                            binf2phone=binf2phone_np)

    if args.tpu_name:
        model = tf.estimator.tpu.TPUEstimator(
            model_fn=model_fn, config=config, params=hparams, eval_on_tpu=False,
            train_batch_size=args.batch_size, use_tpu=args.tpu_name != 'fake'
        )
    else:
        model = tf.estimator.Estimator(
            model_fn=model_fn,
            config=config,
            params=hparams)
    def create_input_fn(mode):
        if args.t2t_format:
            return lambda params: utils.input_fn_t2t(args.train, mode, hparams,
                args.t2t_problem_name,
                batch_size=params.batch_size if 'batch_size' in params else args.batch_size,
                num_epochs=args.num_epochs if mode == tf.estimator.ModeKeys.TRAIN else 1,
                num_parallel_calls=64 if args.tpu_name and args.tpu_name != 'fake' else args.num_parallel_calls,
                max_frames=args.max_frames, max_symbols=args.max_symbols,
                features_hparams_override=args.t2t_features_hparams_override)
        else:
            return lambda params: utils.input_fn(
                args.valid if mode == tf.estimator.ModeKeys.EVAL and args.valid else args.train,
                vocab_name, norm_name,
                num_channels=args.num_channels if args.num_channels is not None else hparams.get_hparam('num_channels'),
                batch_size=params.batch_size if 'batch_size' in params else args.batch_size,
                num_epochs=args.num_epochs if mode == tf.estimator.ModeKeys.TRAIN else 1,
                num_parallel_calls=64 if args.tpu_name and args.tpu_name != 'fake' else args.num_parallel_calls,
                max_frames=args.max_frames, max_symbols=args.max_symbols)


    if args.valid or args.t2t_format:
        train_spec = tf.estimator.TrainSpec(
            input_fn=create_input_fn(tf.estimator.ModeKeys.TRAIN),
            max_steps=args.num_epochs * 1000 * args.batch_size
        )

        eval_spec = tf.estimator.EvalSpec(
            input_fn=create_input_fn(tf.estimator.ModeKeys.EVAL),
            start_delay_secs=60,
            throttle_secs=args.eval_secs)

        tf.estimator.train_and_evaluate(model, train_spec, eval_spec)
    else:
        tf.logging.warning('Training without evaluation!')
        model.train(
            input_fn=create_input_fn(tf.estimator.ModeKeys.TRAIN),
            steps=args.num_epochs * 1000 * args.batch_size
        )