Example #1
def main(args):
    eval_name = str(os.path.basename(args.data).split('.')[0])
    config = tf.estimator.RunConfig(model_dir=args.model_dir)
    hparams = utils.create_hparams(args)

    vocab_list = utils.load_vocab(args.vocab)
    binf2phone_np = None
    binf2phone = None
    if hparams.decoder.binary_outputs:
        binf2phone = utils.load_binf2phone(args.binf_map, vocab_list)
        binf2phone_np = binf2phone.values

    def model_fn(features, labels, mode, config, params):
        return las_model_fn(features,
                            labels,
                            mode,
                            config,
                            params,
                            binf2phone=binf2phone_np,
                            run_name=eval_name)

    model = tf.estimator.Estimator(model_fn=model_fn,
                                   config=config,
                                   params=hparams)

    tf.logging.info('Evaluating on {}'.format(eval_name))
    model.evaluate(lambda: input_fn(args.data,
                                    args.vocab,
                                    args.norm,
                                    num_channels=args.num_channels,
                                    batch_size=args.batch_size,
                                    binf2phone=None),
                   name=eval_name)
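Example #1 (and several of the snippets below) relies on a closure: tf.estimator's model_fn signature has no slot for extra arguments such as binf2phone_np, so a wrapper captures them from the enclosing scope. A minimal sketch of the pattern (not part of the original listing, names reused from above):

def make_model_fn(binf2phone_np, run_name):
    # binf2phone_np and run_name are captured by the closure; the Estimator
    # only ever calls the standard (features, labels, mode, config, params).
    def model_fn(features, labels, mode, config, params):
        return las_model_fn(features, labels, mode, config, params,
                            binf2phone=binf2phone_np, run_name=run_name)
    return model_fn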
Example #2
def main(args):
    config = tf.estimator.RunConfig(model_dir=args.model_dir)
    hparams = utils.create_hparams(args)

    hparams.decoder.set_hparam('beam_width', args.beam_width)

    vocab_list = utils.load_vocab(args.vocab)
    vocab_list_orig = vocab_list
    binf2phone_np = None
    binf2phone = None
    mapping = None
    if hparams.decoder.binary_outputs:
        if args.mapping is not None:
            vocab_list, mapping = utils.get_mapping(args.mapping, args.vocab)
            hparams.del_hparam('mapping')
            hparams.add_hparam('mapping', mapping)

        binf2phone = utils.load_binf2phone(args.binf_map, vocab_list)
        binf2phone_np = binf2phone.values

    def model_fn(features, labels, mode, config, params):
        return las_model_fn(features,
                            labels,
                            mode,
                            config,
                            params,
                            binf2phone=binf2phone_np)

    model = tf.estimator.Estimator(model_fn=model_fn,
                                   config=config,
                                   params=hparams)

    phone_pred_key = 'sample_ids_phones_binf' if args.use_phones_from_binf else 'sample_ids'
    predict_keys = [phone_pred_key, 'embedding', 'alignment']
    if args.use_phones_from_binf:
        predict_keys.append('logits_binf')
        predict_keys.append('alignment_binf')

    audio, _ = librosa.load(args.waveform, sr=SAMPLE_RATE, mono=True)
    features = [calculate_acoustic_features(args, audio)]

    predictions = model.predict(
        input_fn=lambda: input_fn(features,
                                  args.vocab,
                                  args.norm,
                                  num_channels=features[0].shape[-1],
                                  batch_size=args.batch_size),
        predict_keys=predict_keys)
    predictions = list(predictions)
    for p in predictions:
        beams = p[phone_pred_key].T
        if len(beams.shape) > 1:
            i = beams[0]
        else:
            i = beams
        i = i.tolist() + [utils.EOS_ID]
        i = i[:i.index(utils.EOS_ID)]
        text = to_text(vocab_list, i)
        text = text.split(args.delimiter)
        print(text)
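The append-then-slice idiom above (i.tolist() + [utils.EOS_ID], then slicing at index(EOS_ID)) guarantees that index() always finds an end-of-sequence token, so the output is trimmed cleanly even when the decoder never emitted EOS. A tiny illustration, assuming EOS_ID == 1 for the sake of the example:

ids = [5, 9, 2]               # decoder output without an EOS token
ids = ids + [1]               # append the sentinel (EOS_ID == 1 here)
trimmed = ids[:ids.index(1)]  # -> [5, 9, 2]; index() cannot raise ValueError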
Example #3
def main(args):

    vocab_list = np.array(utils.load_vocab(args.vocab))

    vocab_size = len(vocab_list)

    config = tf.estimator.RunConfig(model_dir=args.model_dir)
    hparams = utils.create_hparams(
        args, vocab_size, utils.SOS_ID, utils.EOS_ID)

    hparams.decoder.set_hparam('beam_width', args.beam_width)

    model = tf.estimator.Estimator(
        model_fn=las_model_fn,
        config=config,
        params=hparams)

    predictions = model.predict(
        input_fn=lambda: input_fn(
            args.data, args.vocab, num_channels=args.num_channels, batch_size=args.batch_size, num_epochs=1),
        predict_keys='sample_ids')

    if args.beam_width > 0:
        predictions = [vocab_list[y['sample_ids'][:, 0]].tolist() + [utils.EOS]
                       for y in predictions]
    else:
        predictions = [vocab_list[y['sample_ids']].tolist() + [utils.EOS]
                       for y in predictions]

    predictions = [' '.join(y[:y.index(utils.EOS)]) for y in predictions]

    with open(args.save, 'w') as f:
        f.write('\n'.join(predictions))
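When beam search is enabled, sample_ids carries a beam dimension and [:, 0] selects the highest-scoring beam; with beam_width == 0 it is a plain 1-D sequence, which is why the two branches above index differently. A small NumPy sketch of the lookup, using a toy vocabulary that is not from the project:

import numpy as np

vocab = np.array(['a', 'b', 'c', '</s>'])
sample_ids = np.array([[2, 1],
                       [0, 0],
                       [3, 3]])              # shape [max_time, beam_width]
top_beam = vocab[sample_ids[:, 0]].tolist()  # -> ['c', 'a', '</s>']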
Example #4
def main(args):
    config = tf.estimator.RunConfig(model_dir=args.model_dir)
    hparams = utils.create_hparams(args)
    hparams.decoder.set_hparam('beam_width', args.beam_width)

    vocab_list = utils.load_vocab(args.vocab)
    binf2phone_np = None
    if hparams.decoder.binary_outputs:
        if args.mapping is not None:
            vocab_list, mapping = utils.get_mapping(args.mapping, args.vocab)
            hparams.del_hparam('mapping')
            hparams.add_hparam('mapping', mapping)

        binf2phone = utils.load_binf2phone(args.binf_map, vocab_list)
        binf2phone_np = binf2phone.values

    def model_fn(features, labels, mode, config, params):
        return las_model_fn(features,
                            labels,
                            mode,
                            config,
                            params,
                            binf2phone=binf2phone_np,
                            transparent_projection=args.use_phones_from_binf)

    model = tf.estimator.Estimator(model_fn=model_fn,
                                   config=config,
                                   params=hparams)

    audio, _ = librosa.load(args.waveform, sr=SAMPLE_RATE, mono=True)
    features = [calculate_acoustic_features(args, audio)]

    predictions = model.predict(
        input_fn=lambda: input_fn(features, args.vocab, args.norm))
    predictions = list(predictions)
    for p in predictions:
        phone_pred_key = next(k for k in p.keys()
                              if k.startswith('sample_ids'))
        beams = p[phone_pred_key].T
        if len(beams.shape) > 1:
            i = beams[0]
        else:
            i = beams
        i = i.tolist() + [utils.EOS_ID]
        i = i[:i.index(utils.EOS_ID)]
        text = to_text(vocab_list, i)
        text = text.split(args.delimiter)
        for k in p.keys():
            print(f'{k}: {p[k].shape}')
        print(text)
    if args.output_file:
        dump(predictions, args.output_file)
        print(f'Predictions are saved to {args.output_file}')
Example #5
File: train.py Project: Hak333m/phones-las
def main(args):
    vocab_list = utils.load_vocab(args.vocab)
    binf2phone_np = None
    binf2phone = None
    mapping = None
    vocab_size = len(vocab_list)
    binf_count = None
    if args.binary_outputs:
        if args.mapping is not None:
            vocab_list, mapping = utils.get_mapping(args.mapping, args.vocab)
            args.mapping = None
        binf2phone = utils.load_binf2phone(args.binf_map, vocab_list)
        binf_count = len(binf2phone.index)
        if args.output_ipa:
            binf2phone_np = binf2phone.values

    config = tf.estimator.RunConfig(model_dir=args.model_dir)
    hparams = utils.create_hparams(
        args, vocab_size, binf_count, utils.SOS_ID, utils.EOS_ID)
    if mapping is not None:
        hparams.del_hparam('mapping')
        hparams.add_hparam('mapping', mapping)

    def model_fn(features, labels, mode, config, params):
        return las_model_fn(features, labels, mode, config, params,
                            binf2phone=binf2phone_np)

    model = tf.estimator.Estimator(
        model_fn=model_fn,
        config=config,
        params=hparams)

    if args.valid:
        train_spec = tf.estimator.TrainSpec(
            input_fn=lambda: input_fn(
                args.train, args.vocab, args.norm, num_channels=args.num_channels, batch_size=args.batch_size,
                num_epochs=args.num_epochs, binf2phone=None, num_parallel_calls=args.num_parallel_calls))

        eval_spec = tf.estimator.EvalSpec(
            input_fn=lambda: input_fn(
                args.valid or args.train, args.vocab, args.norm, num_channels=args.num_channels,
                batch_size=args.batch_size, binf2phone=None, num_parallel_calls=args.num_parallel_calls),
            start_delay_secs=60,
            throttle_secs=args.eval_secs)

        tf.estimator.train_and_evaluate(model, train_spec, eval_spec)
    else:
        model.train(
            input_fn=lambda: input_fn(
                args.train, args.vocab, args.norm, num_channels=args.num_channels, batch_size=args.batch_size,
                num_epochs=args.num_epochs, binf2phone=None, num_parallel_calls=args.num_parallel_calls))
Example #6
def main(args):
    eval_name = str(os.path.basename(args.data).split('.')[0])
    config = tf.estimator.RunConfig(model_dir=args.model_dir)
    hparams = utils.create_hparams(args)

    vocab_name = args.vocab if not args.t2t_format else os.path.join(
        args.data, 'vocab.txt')
    vocab_list = utils.load_vocab(vocab_name)
    binf2phone_np = None
    binf2phone = None
    if hparams.decoder.binary_outputs:
        binf2phone = utils.load_binf2phone(args.binf_map, vocab_list)
        binf2phone_np = binf2phone.values

    def model_fn(features, labels, mode, config, params):
        return las_model_fn(features,
                            labels,
                            mode,
                            config,
                            params,
                            binf2phone=binf2phone_np,
                            run_name=eval_name)

    model = tf.estimator.Estimator(model_fn=model_fn,
                                   config=config,
                                   params=hparams)

    tf.logging.info('Evaluating on {}'.format(eval_name))
    if args.t2t_format:
        input_fn = lambda: utils.input_fn_t2t(args.data,
                                              tf.estimator.ModeKeys.EVAL,
                                              hparams,
                                              args.t2t_problem_name,
                                              batch_size=args.batch_size,
                                              features_hparams_override=args.t2t_features_hparams_override)
    else:
        input_fn = lambda: utils.input_fn(args.data,
                                          args.vocab,
                                          args.norm,
                                          num_channels=args.num_channels,
                                          batch_size=args.batch_size)
    model.evaluate(input_fn, name=eval_name)
Example #7
def infer():
    tf.logging.set_verbosity(tf.logging.INFO)
    args = parse_args_infer()
    args.model_dir = MODEL_DIR
    args.beam_width = 32
    args.data = TEST_TF
    args.save = INFER_RESULT
    args.vocab = VOCAB_TABLE
    vocab_list = np.array(utils.load_vocab(args.vocab))
    vocab_size = len(vocab_list)

    conf = tf.estimator.RunConfig(model_dir=args.model_dir)
    hparams = utils.create_hparams(args, vocab_size, utils.SOS_ID,
                                   utils.EOS_ID)
    hparams.decoder.set_hparam('beam_width', args.beam_width)

    model = tf.estimator.Estimator(model_fn=las_model_fn,
                                   config=conf,
                                   params=hparams)
    predictions = model.predict(input_fn=lambda: input_fn_infer(
        args.data,
        args.vocab,
        num_channels=args.num_channels,
        batch_size=args.batch_size,
        num_epochs=1,
    ),
                                predict_keys='sample_ids')
    if args.beam_width > 0:
        predictions = [
            vocab_list[y['sample_ids'][:, 0]].tolist() + [utils.EOS]
            for y in predictions
        ]
    else:
        predictions = [
            vocab_list[y['sample_ids']].tolist() + [utils.EOS]
            for y in predictions
        ]

    predictions = [' '.join(y[:y.index(utils.EOS)]) for y in predictions]
    with open(args.save, 'w') as f:
        f.write('\n'.join(predictions))
Example #8
def train():
    tf.logging.set_verbosity(tf.logging.INFO)
    args = parse_args()
    args.train = TRAIN_TF
    args.valid = TEST_TF
    args.vocab = VOCAB_TABLE
    args.model_dir = MODEL_DIR
    vocab_list = utils.load_vocab(args.vocab)
    vocab_size = len(vocab_list)

    conf = tf.estimator.RunConfig(model_dir=args.model_dir)
    hparams = utils.create_hparams(args, vocab_size, utils.SOS_ID,
                                   utils.EOS_ID)

    model = tf.estimator.Estimator(model_fn=las_model_fn,
                                   config=conf,
                                   params=hparams)

    if args.valid:
        train_spec = tf.estimator.TrainSpec(
            input_fn=lambda: input_fn(args.train,
                                      args.vocab,
                                      num_channels=args.num_channels,
                                      batch_size=args.batch_size,
                                      num_epochs=args.num_epochs))
        eval_spec = tf.estimator.EvalSpec(
            input_fn=lambda: input_fn(args.valid or args.train,
                                      args.vocab,
                                      num_channels=args.num_channels,
                                      batch_size=args.batch_size),
            start_delay_secs=60,
            throttle_secs=args.eval_secs)

        tf.estimator.train_and_evaluate(model, train_spec, eval_spec)
    else:
        model.train(input_fn=lambda: input_fn(args.train,
                                              args.vocab,
                                              num_channels=args.num_channels,
                                              batch_size=args.batch_size,
                                              num_epochs=args.num_epochs))
Example #9
def main(args):

    vocab_list = utils.load_vocab(args.vocab)

    vocab_size = len(vocab_list)

    config = tf.estimator.RunConfig(model_dir=args.model_dir)
    hparams = utils.create_hparams(args, vocab_size, utils.SOS_ID,
                                   utils.EOS_ID)

    model = tf.estimator.Estimator(model_fn=las_model_fn,
                                   config=config,
                                   params=hparams)

    if args.valid:
        train_spec = tf.estimator.TrainSpec(
            input_fn=lambda: input_fn(args.train,
                                      args.vocab,
                                      num_channels=args.num_channels,
                                      batch_size=args.batch_size,
                                      num_epochs=args.num_epochs))

        eval_spec = tf.estimator.EvalSpec(
            input_fn=lambda: input_fn(args.valid or args.train,
                                      args.vocab,
                                      num_channels=args.num_channels,
                                      batch_size=args.batch_size),
            start_delay_secs=60,
            throttle_secs=args.eval_secs)

        tf.estimator.train_and_evaluate(model, train_spec, eval_spec)
    else:
        model.train(input_fn=lambda: input_fn(args.train,
                                              args.vocab,
                                              num_channels=args.num_channels,
                                              batch_size=args.batch_size,
                                              num_epochs=args.num_epochs))
Example #10
    return serving_input_receiver_fn


if __name__ == '__main__':
    parser = ArgumentParser()
    parser.add_argument('--model_dir',
                        type=str,
                        required=True,
                        help='path to model')
    parser.add_argument('--num_channels',
                        type=int,
                        required=True,
                        help='number of input channels')
    parser.add_argument('--export_dir',
                        type=str,
                        required=True,
                        help='path where to save exported model')
    args = parser.parse_args()

    config = tf.estimator.RunConfig(model_dir=args.model_dir)
    hparams = utils.create_hparams(args,
                                   sos_id=utils.SOS_ID,
                                   eos_id=utils.EOS_ID)

    model = tf.estimator.Estimator(model_fn=export_las_model_fn,
                                   config=config,
                                   params=hparams)
    model.export_saved_model(args.export_dir,
                             serving_input_receiver_fn=serving_input_factory(
                                 args.num_channels))
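Example #10 starts mid-file, so the body of serving_input_factory is cut off above its final return statement. A rough, hypothetical sketch of what such a factory could look like, assuming float32 acoustic-feature inputs (this is a guess, not the project's actual code):

def serving_input_factory(num_channels):
    def serving_input_receiver_fn():
        # [batch, time, channels] acoustic features supplied at serving time.
        inputs = tf.placeholder(tf.float32, [None, None, num_channels],
                                name='inputs')
        return tf.estimator.export.ServingInputReceiver(
            features=inputs, receiver_tensors={'inputs': inputs})
    return serving_input_receiver_fn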
Example #11
    #print('before segmentation:', inputs[:10])
    #print('before segmentation:', outputs[:10])
    inputs = cn_segment(inputs)
    outputs = en_segment(outputs)
    #print('after segmentation:', inputs[:10])
    #print('after segmentation:', outputs[:10])
    # print(outputs)

encoder_vocab, decoder_vocab = make_vocab(inputs, outputs)
print('\n-----------vocab has been built-----------')

encoder_inputs, decoder_inputs, decoder_targets = data_format(
    inputs, outputs, encoder_vocab, decoder_vocab)

arg = create_hparams()
arg.input_vocab_size = len(encoder_vocab)
arg.label_vocab_size = len(decoder_vocab)
arg.epochs = epoch
arg.batch_size = batch_size

g = Graph(arg)

saver = tf.train.Saver()
with tf.Session() as sess:
    merged = tf.summary.merge_all()
    sess.run(tf.global_variables_initializer())
    add_num = 0
    if os.path.exists('model_self/checkpoint'):
        print('loading  model_self...')
        latest = tf.train.latest_checkpoint('model_self')  # assumed argument; the original listing is truncated here
Example #12
def main(args):
    config = tf.estimator.RunConfig(model_dir=args.model_dir)
    hparams = utils.create_hparams(args)

    hparams.decoder.set_hparam('beam_width', args.beam_width)

    vocab_list = utils.load_vocab(args.vocab)
    vocab_list_orig = vocab_list
    binf2phone_np = None
    binf2phone = None
    mapping = None
    if hparams.decoder.binary_outputs:
        if args.mapping is not None:
            vocab_list, mapping = utils.get_mapping(args.mapping, args.vocab)
            hparams.del_hparam('mapping')
            hparams.add_hparam('mapping', mapping)

        binf2phone = utils.load_binf2phone(args.binf_map, vocab_list)
        binf2phone_np = binf2phone.values

    def model_fn(features, labels, mode, config, params):
        return las_model_fn(features,
                            labels,
                            mode,
                            config,
                            params,
                            binf2phone=binf2phone_np)

    model = tf.estimator.Estimator(model_fn=model_fn,
                                   config=config,
                                   params=hparams)

    phone_pred_key = 'sample_ids_phones_binf' if args.use_phones_from_binf else 'sample_ids'
    predict_keys = [phone_pred_key, 'embedding', 'alignment']
    if args.use_phones_from_binf:
        predict_keys.append('logits_binf')
        predict_keys.append('alignment_binf')
    predictions = model.predict(
        input_fn=lambda: utils.input_fn(args.data,
                                        args.vocab,
                                        args.norm,
                                        num_channels=args.num_channels,
                                        batch_size=args.batch_size,
                                        take=args.take,
                                        is_infer=True),
        predict_keys=predict_keys)

    if args.calc_frame_binf_accuracy:
        with open(args.mapping_for_frame_accuracy, 'r') as fid:
            # one tab-separated mapping per line; split on newlines, not on all whitespace
            mapping_lines = fid.read().strip().split('\n')
        mapping_targets = dict()
        for line in mapping_lines:
            phones = line.split('\t')
            if len(phones) < 3:
                mapping_targets[phones[0]] = None
            else:
                mapping_targets[phones[0]] = phones[-1]

    predictions = list(predictions)
    if args.plain_targets:
        targets = []
        for line in open(args.plain_targets, 'r'):
            delim = ','
            if '\t' in line:
                delim = '\t'
            cells = line.split(delim)
            sound, lang, phrase = cells[:3]
            if args.convert_targets_to_ipa:
                if len(cells) == 4:
                    phrase = cells[-1].split(',')
                else:
                    phrase = get_ipa(phrase, lang)
            else:
                phrase = phrase.split()
                phrase = [x.strip().lower() for x in phrase]
            if args.calc_frame_binf_accuracy:
                markup, binfs, nsamples = get_timit_binf_markup(
                    sound, binf2phone, mapping_targets)
                targets.append((phrase, markup, binfs, nsamples))
            else:
                targets.append(phrase)
        save_to = os.path.join(args.model_dir, 'infer_targets.txt')
        with open(save_to, 'w') as f:
            # targets holds (phrase, markup, binfs, nsamples) tuples in the
            # frame-accuracy case, so join only the phrase part
            f.write('\n'.join(
                args.delimiter.join(t[0] if args.calc_frame_binf_accuracy else t)
                for t in targets))
        err = 0
        tot = 0
        optimistic_err = 0
        if args.calc_frame_binf_accuracy:
            frames_count = 0
            correct_frames_count = np.zeros((len(binf2phone.index)))
        for p, target in tqdm(zip(predictions, targets)):
            first_text = []
            min_err = 100000
            beams = p[phone_pred_key].T
            if len(beams.shape) > 1:
                for bi, i in enumerate(beams):
                    i = i.tolist() + [utils.EOS_ID]
                    i = i[:i.index(utils.EOS_ID)]
                    text = to_text(vocab_list, i)
                    text = text.split(args.delimiter)
                    min_err = min([min_err, edist(text, t)])
                    if bi == 0:
                        first_text = text.copy()
                i = first_text
            else:
                i = beams.tolist() + [utils.EOS_ID]
                i = i[:i.index(utils.EOS_ID)]
                text = to_text(vocab_list, i)
                first_text = text.split(args.delimiter)

            t = target[0] if args.calc_frame_binf_accuracy else target
            if mapping is not None:
                target_ids = np.array([vocab_list_orig.index(ph) for ph in t])
                target_ids = np.array(mapping)[target_ids]
                t = [vocab_list[i] for i in target_ids]
            err += edist(first_text, t)
            optimistic_err += min_err
            tot += len(t)

            if args.calc_frame_binf_accuracy:
                attention = p['alignment'][:len(first_text), :]
                binf_preds = None
                if args.use_phones_from_binf:
                    logits = p['logits_binf'][:-1, :]
                    binf_preds = np.round(1 / (1 + np.exp(-logits)))
                    attention = p['alignment_binf'][:binf_preds.shape[0], :]

                markup, binfs, nsamples = target[1:]
                markup = np.minimum(markup, nsamples / SAMPLE_RATE)
                markup_frames = librosa.time_to_frames(
                    markup, SAMPLE_RATE,
                    args.encoder_frame_step * SAMPLE_RATE / 1000, WIN_LEN)
                markup_frames[markup_frames < 0] = 0
                markup_frames_binf = get_binf_markup_frames(
                    markup_frames, binfs)

                if not args.use_markup_segments:
                    alignment, _ = attention_to_segments(attention)
                    pred_frames_binf = segs_phones_to_frame_binf(
                        alignment, first_text, binf2phone, binf_preds)
                else:
                    binfs_pred = segments_to_attention(attention,
                                                       markup_frames,
                                                       first_text, binf2phone,
                                                       binf_preds)
                    pred_frames_binf = get_binf_markup_frames(
                        markup_frames, binfs_pred)

                if pred_frames_binf.shape[0] != markup_frames_binf.shape[0]:
                    print(
                        'Warning: sound {} prediction frames {} target frames {}'
                        .format(t, pred_frames_binf.shape[0],
                                markup_frames_binf.shape[0]))
                    nframes_fixed = min(pred_frames_binf.shape[0],
                                        markup_frames_binf.shape[0])
                    pred_frames_binf = pred_frames_binf[:nframes_fixed, :]
                    markup_frames_binf = markup_frames_binf[:nframes_fixed, :]
                correct_frames_count += count_correct(pred_frames_binf,
                                                      markup_frames_binf)
                frames_count += markup_frames_binf.shape[0]

            # Compare binary feature vectors

        print(f'PER: {100 * err / tot:2.2f}%')
        print(f'Optimistic PER: {100 * optimistic_err / tot:2.2f}%')

        if args.calc_frame_binf_accuracy:
            df = pd.DataFrame({'correct': correct_frames_count / frames_count},
                              index=binf2phone.index)
            print(df)

    if args.beam_width > 0:
        predictions = [{
            'transcription':
            to_text(vocab_list, y[phone_pred_key][:, 0])
        } for y in predictions]
    else:
        predictions = [{
            'transcription': to_text(vocab_list, y[phone_pred_key])
        } for y in predictions]

    save_to = os.path.join(args.model_dir, 'infer.txt')
    with open(save_to, 'w') as f:
        f.write('\n'.join(p['transcription'] for p in predictions))

    save_to = os.path.join(args.model_dir, 'infer.dmp')
    dump(predictions, save_to)
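The hard binary-feature decisions in Example #12 come from np.round(1 / (1 + np.exp(-logits))): the sigmoid squashes each logit into (0, 1) and rounding thresholds it at logit 0. A quick check with made-up logits:

import numpy as np

logits = np.array([-2.0, 0.3, 4.1])
binf_preds = np.round(1 / (1 + np.exp(-logits)))  # -> array([0., 1., 1.])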
Example #13
File: train.py Project: madved/phones-las
def main(args):
    vocab_list = utils.load_vocab(args.vocab)
    binf2phone_np = None
    mapping = None
    vocab_size = len(vocab_list)
    binf_count = None
    if args.binary_outputs:
        if args.mapping is not None:
            vocab_list, mapping = utils.get_mapping(args.mapping, args.vocab)
            args.mapping = None
        binf2phone = utils.load_binf2phone(args.binf_map, vocab_list)
        binf_count = len(binf2phone.index)
        if args.output_ipa:
            binf2phone_np = binf2phone.values

    if args.tpu_name:
        iterations_per_loop = 100
        tpu_cluster_resolver = None
        if args.tpu_name != 'fake':
            tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
                args.tpu_name)
        config = tf.estimator.tpu.RunConfig(
            cluster=tpu_cluster_resolver,
            model_dir=args.model_dir,
            save_checkpoints_steps=max(600, iterations_per_loop),
            tpu_config=tf.estimator.tpu.TPUConfig(
                iterations_per_loop=iterations_per_loop,
                per_host_input_for_training=tf.estimator.tpu.InputPipelineConfig.PER_HOST_V2))
    else:
        config = tf.estimator.RunConfig(model_dir=args.model_dir)
    hparams = utils.create_hparams(args, vocab_size, binf_count, utils.SOS_ID,
                                   utils.EOS_ID)
    if mapping is not None:
        hparams.del_hparam('mapping')
        hparams.add_hparam('mapping', mapping)

    def model_fn(features, labels, mode, config, params):
        binf_map = binf2phone_np
        return las_model_fn(features,
                            labels,
                            mode,
                            config,
                            params,
                            binf2phone=binf_map)

    if args.tpu_name:
        model = tf.estimator.tpu.TPUEstimator(model_fn=model_fn,
                                              config=config,
                                              params=hparams,
                                              eval_on_tpu=False,
                                              train_batch_size=args.batch_size,
                                              use_tpu=args.tpu_name != 'fake')
    else:
        model = tf.estimator.Estimator(model_fn=model_fn,
                                       config=config,
                                       params=hparams)

    if args.valid:
        train_spec = tf.estimator.TrainSpec(input_fn=lambda params: input_fn(
            args.train,
            args.vocab,
            args.norm,
            num_channels=args.num_channels,
            batch_size=params.batch_size
            if 'batch_size' in params else args.batch_size,
            num_epochs=args.num_epochs,
            binf2phone=None,
            num_parallel_calls=args.num_parallel_calls,
            max_frames=args.max_frames,
            max_symbols=args.max_symbols),
                                            max_steps=args.num_epochs * 1000 *
                                            args.batch_size)

        eval_spec = tf.estimator.EvalSpec(input_fn=lambda params: input_fn(
            args.valid or args.train,
            args.vocab,
            args.norm,
            num_channels=args.num_channels,
            batch_size=params.batch_size
            if 'batch_size' in params else args.batch_size,
            binf2phone=None,
            num_parallel_calls=args.num_parallel_calls,
            max_frames=args.max_frames,
            max_symbols=args.max_symbols),
                                          start_delay_secs=60,
                                          throttle_secs=args.eval_secs)

        tf.estimator.train_and_evaluate(model, train_spec, eval_spec)
    else:
        tf.logging.warning('Training without evaluation!')
        model.train(input_fn=lambda params: input_fn(
            args.train,
            args.vocab,
            args.norm,
            num_channels=args.num_channels,
            batch_size=params.batch_size
            if 'batch_size' in params else args.batch_size,
            num_epochs=args.num_epochs,
            binf2phone=None,
            num_parallel_calls=args.num_parallel_calls,
            max_frames=args.max_frames,
            max_symbols=args.max_symbols),
                    steps=args.num_epochs * 1000 * args.batch_size)
Example #14
def main(argv):
    """Runs supervised wavefunction optimization.

  This pipeline optimizes wavefunction by matching amplitudes of a target state.

  """
    del argv  # Not used.

    supervisor_path = os.path.join(FLAGS.supervisor_dir, 'hparams.pbtxt')
    supervisor_hparams = utils.load_hparams(supervisor_path)

    hparams = utils.create_hparams()
    hparams.set_hparam('num_sites', supervisor_hparams.num_sites)
    hparams.set_hparam('checkpoint_dir', FLAGS.checkpoint_dir)
    hparams.set_hparam('supervisor_dir', FLAGS.supervisor_dir)
    hparams.set_hparam('basis_file_path', FLAGS.basis_file_path)
    hparams.set_hparam('num_epochs', FLAGS.num_epochs)
    hparams.set_hparam('wavefunction_type', FLAGS.wavefunction_type)
    hparams.parse(FLAGS.hparams)
    hparams_path = os.path.join(hparams.checkpoint_dir, 'hparams.pbtxt')

    if not os.path.exists(FLAGS.checkpoint_dir):
        os.makedirs(FLAGS.checkpoint_dir)

    if os.path.exists(hparams_path) and not FLAGS.override:
        print('Hparams file already exists')
        exit()

    with tf.gfile.GFile(hparams_path, 'w') as file:
        file.write(str(hparams.to_proto()))

    target_wavefunction = wavefunctions.build_wavefunction(supervisor_hparams)
    wavefunction = wavefunctions.build_wavefunction(hparams)

    wavefunction_optimizer = training.SUPERVISED_OPTIMIZERS[FLAGS.optimizer]()

    shared_resources = {}

    graph_building_args = {
        'wavefunction': wavefunction,
        'target_wavefunction': target_wavefunction,
        'hparams': hparams,
        'shared_resources': shared_resources
    }

    train_ops = wavefunction_optimizer.build_opt_ops(**graph_building_args)

    session = tf.Session()
    init = tf.global_variables_initializer()
    init_l = tf.local_variables_initializer()
    session.run([init, init_l])

    target_saver = tf.train.Saver(
        target_wavefunction.get_trainable_variables())
    supervisor_checkpoint = tf.train.latest_checkpoint(FLAGS.supervisor_dir)
    target_saver.restore(session, supervisor_checkpoint)
    checkpoint_saver = tf.train.Saver(wavefunction.get_trainable_variables(),
                                      max_to_keep=5)

    if FLAGS.resume_training:
        latest_checkpoint = tf.train.latest_checkpoint(hparams.checkpoint_dir)
        checkpoint_saver.restore(session, latest_checkpoint)

    for epoch_number in range(FLAGS.num_epochs):
        wavefunction_optimizer.run_optimization_epoch(train_ops, session,
                                                      hparams, epoch_number)
        if epoch_number % FLAGS.checkpoint_frequency == 0:
            checkpoint_name = 'model_after_{}_epochs'.format(epoch_number)
            save_path = os.path.join(hparams.checkpoint_dir, checkpoint_name)
            checkpoint_saver.save(session, save_path)

    if FLAGS.generate_vectors:
        vector_generator = evaluation.VectorWavefunctionEvaluator()
        eval_ops = vector_generator.build_eval_ops(wavefunction, None, hparams,
                                                   shared_resources)
        vector_generator.run_evaluation(eval_ops, session, hparams,
                                        FLAGS.num_epochs)
Example #15
def main(args):
    train_dir = os.path.dirname(args.train)
    vocab_name = os.path.join(train_dir, 'vocab.txt')
    norm_name = os.path.join(train_dir, 'norm.dmp')
    vocab_list = utils.load_vocab(vocab_name)
    binf2phone_np = None
    mapping = None
    vocab_size = len(vocab_list)
    binf_count = None
    if args.binary_outputs:
        if args.mapping is not None:
            vocab_list, mapping = utils.get_mapping(args.mapping, vocab_name)
            args.mapping = None
        binf2phone = utils.load_binf2phone(args.binf_map, vocab_list)
        binf_count = len(binf2phone.index)
        if args.output_ipa:
            binf2phone_np = binf2phone.values

    if args.tpu_name:
        iterations_per_loop = 100
        tpu_cluster_resolver = None
        if args.tpu_name != 'fake':
            tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(args.tpu_name)
        config = tf.estimator.tpu.RunConfig(
            cluster=tpu_cluster_resolver,
            model_dir=args.model_dir,
            save_checkpoints_steps=max(args.tpu_checkpoints_interval, iterations_per_loop),
            tpu_config=tf.estimator.tpu.TPUConfig(
                iterations_per_loop=iterations_per_loop,
                per_host_input_for_training=tf.estimator.tpu.InputPipelineConfig.PER_HOST_V2))
    else:
        config = tf.estimator.RunConfig(model_dir=args.model_dir)
    hparams = utils.create_hparams(
        args, vocab_size, binf_count, utils.SOS_ID if not args.t2t_format else PAD_ID,
        utils.EOS_ID if not args.t2t_format else EOS_ID)
    if mapping is not None:
        hparams.del_hparam('mapping')
        hparams.add_hparam('mapping', mapping)

    def model_fn(features, labels, mode, config, params):
        return las_model_fn(features, labels, mode, config, params,
                            binf2phone=binf2phone_np)

    if args.tpu_name:
        model = tf.estimator.tpu.TPUEstimator(
            model_fn=model_fn, config=config, params=hparams, eval_on_tpu=False,
            train_batch_size=args.batch_size, use_tpu=args.tpu_name != 'fake'
        )
    else:
        model = tf.estimator.Estimator(
            model_fn=model_fn,
            config=config,
            params=hparams)
    def create_input_fn(mode):
        if args.t2t_format:
            return lambda params: utils.input_fn_t2t(args.train, mode, hparams,
                args.t2t_problem_name,
                batch_size=params.batch_size if 'batch_size' in params else args.batch_size,
                num_epochs=args.num_epochs if mode == tf.estimator.ModeKeys.TRAIN else 1,
                num_parallel_calls=64 if args.tpu_name and args.tpu_name != 'fake' else args.num_parallel_calls,
                max_frames=args.max_frames, max_symbols=args.max_symbols,
                features_hparams_override=args.t2t_features_hparams_override)
        else:
            return lambda params: utils.input_fn(
                args.valid if mode == tf.estimator.ModeKeys.EVAL and args.valid else args.train,
                vocab_name, norm_name,
                num_channels=args.num_channels if args.num_channels is not None else hparams.get_hparam('num_channels'),
                batch_size=params.batch_size if 'batch_size' in params else args.batch_size,
                num_epochs=args.num_epochs if mode == tf.estimator.ModeKeys.TRAIN else 1,
                num_parallel_calls=64 if args.tpu_name and args.tpu_name != 'fake' else args.num_parallel_calls,
                max_frames=args.max_frames, max_symbols=args.max_symbols)


    if args.valid or args.t2t_format:
        train_spec = tf.estimator.TrainSpec(
            input_fn=create_input_fn(tf.estimator.ModeKeys.TRAIN),
            max_steps=args.num_epochs * 1000 * args.batch_size
        )

        eval_spec = tf.estimator.EvalSpec(
            input_fn=create_input_fn(tf.estimator.ModeKeys.EVAL),
            start_delay_secs=60,
            throttle_secs=args.eval_secs)

        tf.estimator.train_and_evaluate(model, train_spec, eval_spec)
    else:
        tf.logging.warning('Training without evaluation!')
        model.train(
            input_fn=create_input_fn(tf.estimator.ModeKeys.TRAIN),
            steps=args.num_epochs * 1000 * args.batch_size
        )
Example #16
def main(argv):
    """Runs wavefunction optimization.

  This pipeline optimizes wavefunction specified in flags on a Marshal sign
  included Heisenberg model. Bonds should be specified in the file J.txt in
  checkpoint directory, otherwise will default to 1D PBC system. For other
  tunable parameters see flags description.
  """
    del argv  # Not used.
    n_sites = FLAGS.num_sites
    hparams = utils.create_hparams()
    hparams.set_hparam('checkpoint_dir', FLAGS.checkpoint_dir)
    hparams.set_hparam('basis_file_path', FLAGS.basis_file_path)
    hparams.set_hparam('num_sites', FLAGS.num_sites)
    hparams.set_hparam('num_epochs', FLAGS.num_epochs)
    hparams.set_hparam('wavefunction_type', FLAGS.wavefunction_type)
    hparams.set_hparam('wavefunction_optimizer_type', FLAGS.optimizer)
    hparams.parse(FLAGS.hparams)
    hparams_path = os.path.join(hparams.checkpoint_dir, 'hparams.pbtxt')

    if not os.path.exists(FLAGS.checkpoint_dir):
        os.makedirs(FLAGS.checkpoint_dir)

    if os.path.exists(hparams_path) and not FLAGS.override:
        print('Hparams file already exists')
        exit()

    with tf.gfile.GFile(hparams_path, 'w') as file:
        file.write(str(hparams.to_proto()))

    bonds_file_path = os.path.join(FLAGS.checkpoint_dir, 'J.txt')
    heisenberg_jx = FLAGS.heisenberg_jx
    if os.path.exists(bonds_file_path):
        heisenberg_data = np.genfromtxt(bonds_file_path, dtype=int)
        heisenberg_bonds = [[bond[0], bond[1]] for bond in heisenberg_data]
    else:
        heisenberg_bonds = [(i, (i + 1) % n_sites) for i in range(0, n_sites)]

    wavefunction = wavefunctions.build_wavefunction(hparams)
    hamiltonian = operators.HeisenbergHamiltonian(heisenberg_bonds,
                                                  heisenberg_jx, 1.)

    wavefunction_optimizer = training.GROUND_STATE_OPTIMIZERS[
        FLAGS.optimizer]()

    # TODO(dkochkov) change the pipeline to avoid adding elements to dictionary
    shared_resources = {}

    graph_building_args = {
        'wavefunction': wavefunction,
        'hamiltonian': hamiltonian,
        'hparams': hparams,
        'shared_resources': shared_resources
    }

    train_ops = wavefunction_optimizer.build_opt_ops(**graph_building_args)

    session = tf.Session()
    init = tf.global_variables_initializer()
    init_l = tf.local_variables_initializer()
    session.run([init, init_l])

    checkpoint_saver = tf.train.Saver(wavefunction.get_trainable_variables(),
                                      max_to_keep=5)

    if FLAGS.resume_training:
        latest_checkpoint = tf.train.latest_checkpoint(hparams.checkpoint_dir)
        checkpoint_saver.restore(session, latest_checkpoint)

    # TODO(kochkov92) use custom output file.
    training_metrics_file = os.path.join(hparams.checkpoint_dir, 'metrics.txt')
    for epoch_number in range(FLAGS.num_epochs):
        checkpoint_name = 'model_prior_{}_epochs'.format(epoch_number)
        save_path = os.path.join(hparams.checkpoint_dir, checkpoint_name)
        checkpoint_saver.save(session, save_path)

        metrics_record = wavefunction_optimizer.run_optimization_epoch(
            train_ops, session, hparams)

        with open(training_metrics_file, 'a') as metrics_file_output:
            metrics_file_output.write('{}\n'.format(metrics_record))

    if FLAGS.generate_vectors:
        vector_generator = evaluation.VectorWavefunctionEvaluator()
        eval_ops = vector_generator.build_eval_ops(wavefunction, None, hparams,
                                                   shared_resources)
        vector_generator.run_evaluation(eval_ops, session, hparams,
                                        FLAGS.num_epochs)
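The fallback bond list in Example #16 wires the sites into a ring, i.e. a 1-D chain with periodic boundary conditions. For instance, with n_sites = 4:

n_sites = 4
bonds = [(i, (i + 1) % n_sites) for i in range(n_sites)]
# -> [(0, 1), (1, 2), (2, 3), (3, 0)]; the modulo wraps the last site back to 0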