Exemplo n.º 1
0
def decode():
    """Decode the configured source file with a restored model and write the results.

    Everything is driven by the module-level ``FLAGS``: the model config is
    built with ``load_config(FLAGS)``, the source text is read through a
    ``TextIterator`` and predictions are converted to words with the inverse
    target dictionary.

    Output files:
        * ``FLAGS.write_n_best`` unset: a single file ``FLAGS.decode_output``
          that receives column 0 of the predictions.
        * ``FLAGS.write_n_best`` set: one file per beam index ``k`` in
          ``range(FLAGS.beam_width)``, named ``"<decode_output>_<k>"``, that
          receives column ``k`` of the predictions.

    An ``IOError`` raised while opening or writing the output is reported on
    stdout and not re-raised; every output file that was opened is closed.
    """
    # Load model config
    config = load_config(FLAGS)

    # Load source data to decode
    test_set = TextIterator(source=config['decode_input'],
                            batch_size=config['decode_batch_size'],
                            source_dict=config['source_vocabulary'],
                            maxlen=None,
                            n_words_source=config['num_encoder_symbols'])

    # Load inverse dictionary used in decoding
    target_inverse_dict = data_utils.load_inverse_dict(
        config['target_vocabulary'])

    session_config = tf.ConfigProto(
        allow_soft_placement=FLAGS.allow_soft_placement,
        log_device_placement=FLAGS.log_device_placement,
        gpu_options=tf.GPUOptions(allow_growth=True))

    # Initiate TF session
    with tf.Session(config=session_config) as sess:

        # Reload existing checkpoint
        model = load_model(sess, config)

        # Bound before the try block so the finally clause can always iterate
        # it, even when the very first fopen() fails.
        fout = []
        try:
            print('Decoding {}..'.format(FLAGS.decode_input))
            if FLAGS.write_n_best:
                output_paths = ["%s_%d" % (FLAGS.decode_output, k)
                                for k in range(FLAGS.beam_width)]
            else:
                output_paths = [FLAGS.decode_output]

            # Open the files one at a time: if a later open fails, the
            # handles opened so far are already in fout and get closed.
            for path in output_paths:
                fout.append(data_utils.fopen(path, 'w'))

            # NOTE(review): this enumerates the result of a single
            # test_set.next() call rather than the iterator itself —
            # presumably only the first batch is visited; confirm against
            # TextIterator before relying on full-file decoding.
            for idx, source_seq in enumerate(test_set.next()):
                print('Source', source_seq)
                source, source_len = prepare_batch(source_seq)
                print('Source', source, 'Source Len', source_len)
                # predicted_ids: GreedyDecoder; [batch_size, max_time_step, 1]
                # BeamSearchDecoder; [batch_size, max_time_step, beam_width]
                predicted_ids = model.predict(sess,
                                              encoder_inputs=source,
                                              encoder_inputs_length=source_len)
                print(predicted_ids)
                # Write decoding results. fout holds exactly one handle when
                # write_n_best is off, so only column 0 is written then.
                for k, f in reversed(list(enumerate(fout))):
                    for seq in predicted_ids:
                        words = data_utils.seq2words(seq[:, k],
                                                     target_inverse_dict)
                        f.write(str(words) + '\n')
                print('{}th line decoded'.format(idx *
                                                 FLAGS.decode_batch_size))

            print('Decoding terminated')
        except IOError as err:
            # Stay best-effort (no re-raise) but make the failure visible.
            print('Decoding aborted by I/O error: {}'.format(err))
        finally:
            for f in fout:
                f.close()
Exemplo n.º 2
0
def decode(config):
    """Decode ``config['decode_input']`` step by step and write one line per sequence.

    The model and its effective config are restored with ``load_model``.
    For every batch the source is encoded once, then the decoder is re-run
    ``config['max_decode_step']`` times; at step ``t`` the arg-max token of
    column ``t`` is stored and fed back as decoder input for step ``t + 1``.
    Each decoded row is converted with ``seq2words`` and written to
    ``config['decode_output']``.

    Args:
        config: configuration mapping passed to ``load_model``; the mapping
            returned by ``load_model`` is the one actually used afterwards.

    An ``IOError`` raised while opening or writing the output is reported on
    stdout and not re-raised; the output file is closed if it was opened.
    """
    model, config = load_model(config)

    # Load source data to decode
    test_set = TextIterator(source=config['decode_input'],
                            source_dict=config['src_vocab'],
                            batch_size=config['batch_size'],
                            n_words_source=config['num_enc_symbols'],
                            maxlen=None)
    target_inv_dict = load_inv_dict(config['tgt_vocab'])

    if use_cuda:
        print('Using gpu..')
        model = model.cuda()

    max_decode_step = config['max_decode_step']

    # Bound before the try block so the finally clause never hits an
    # unbound name when fopen() itself raises.
    fout = None
    try:
        print('Decoding starts..')
        fout = fopen(config['decode_output'], 'w')
        for idx, source_seq in enumerate(test_set):
            source, source_len = prepare_batch(source_seq)

            # Size the buffers from the batch actually received: the last
            # batch of the file may hold fewer than config['batch_size']
            # sequences. Assumes dim 0 of source is the batch dimension —
            # TODO confirm against prepare_batch.
            batch_size = len(source)

            # Decoder inputs: column 0 is the start token, later columns are
            # filled in with the previous step's prediction.
            preds_prev = torch.zeros(batch_size, max_decode_step).long()
            preds_prev[:, 0] += data_utils.start_token
            preds = torch.zeros(batch_size, max_decode_step).long()

            if use_cuda:
                source = Variable(source.cuda())
                source_len = Variable(source_len.cuda())
                preds_prev = Variable(preds_prev.cuda())
                preds = preds.cuda()
            else:
                source = Variable(source)
                source_len = Variable(source_len)
                preds_prev = Variable(preds_prev)

            # Only the memories are consumed below; the decoder is called
            # with None in place of the encoder states.
            _, memories = model.encode(source, source_len)

            for t in xrange(max_decode_step):
                # logits: [batch_size x max_decode_step, tgt_vocab_size]
                _, logits = model.decode(preds_prev, None, memories,
                                         keep_len=True)
                # outputs: [batch_size, max_decode_step]
                outputs = torch.max(logits, dim=1)[1].view(batch_size, -1)
                preds[:, t] = outputs[:, t].data
                if t < max_decode_step - 1:
                    preds_prev[:, t + 1] = outputs[:, t]

            for i in xrange(len(preds)):
                fout.write(str(seq2words(preds[i], target_inv_dict)) + '\n')
                fout.flush()

            print('  {}th line decoded'.format(idx * config['batch_size']))
        print('Decoding terminated')

    except IOError as err:
        # Stay best-effort (no re-raise) but make the failure visible.
        print('Decoding aborted by I/O error: {}'.format(err))
    finally:
        if fout is not None:
            fout.close()
Exemplo n.º 3
0
def decode(config):
    """Decode ``config['decode_input']`` and write one output line per sequence.

    The model and its effective config come from ``load_model``. Batches are
    read in file order (shuffling and length sorting are switched off). Each
    batch is encoded once; the decoder is then called on a growing prefix of
    the previously predicted tokens, and the arg-max token of the newest
    position is recorded and appended to that prefix. Decoded rows are
    converted with ``seq2words`` and written, flushed line by line, to
    ``config['decode_output']``.

    Args:
        config: configuration mapping passed to ``load_model``; the mapping
            returned by ``load_model`` is the one actually used afterwards.
    """
    model, config = load_model(config)

    # Load source data to decode
    test_set = TextIterator(
        source=config['decode_input'],
        source_dict=config['src_vocab'],
        batch_size=config['batch_size'],
        maxlen=None,
        n_words_source=config['num_enc_symbols'],
        shuffle_each_epoch=False,
        sort_by_length=False,
    )
    target_inv_dict = load_inv_dict(config['tgt_vocab'])

    decoded_total = 0
    n_steps = config['max_decode_step']
    print('Decoding starts..')
    with fopen(config['decode_output'], 'w') as fout:
        for source_seq in test_set:
            source, source_len = prepare_batch(source_seq)
            n_rows = len(source)

            # Decoder input buffer: column 0 carries the start token, the
            # remaining columns are filled as predictions become available.
            preds_prev = torch.zeros(n_rows, n_steps).long()
            preds_prev[:, 0] += data_utils.start_token
            preds = torch.zeros(n_rows, n_steps).long()

            # Move the raw tensors to the GPU first (when enabled), then wrap
            # the model inputs; preds stays a plain tensor.
            if use_cuda:
                source = source.cuda()
                source_len = source_len.cuda()
                preds_prev = preds_prev.cuda()
                preds = preds.cuda()
            source = Variable(source)
            source_len = Variable(source_len)
            preds_prev = Variable(preds_prev)

            states, memories = model.encode(source, source_len)

            for step in xrange(n_steps):
                # Feed only the prefix decoded so far; presumably logits has
                # one row per (sequence, position) pair of that prefix.
                prefix = preds_prev[:, :step + 1]
                _, logits = model.decode(prefix, states, memories)
                # Arg-max token ids, reshaped to one row per sequence.
                best = torch.max(logits, dim=1)[1].view(n_rows, -1)
                preds[:, step] = best[:, step].data
                if step + 1 < n_steps:
                    preds_prev[:, step + 1] = best[:, step]

            for row in xrange(len(preds)):
                words = seq2words(preds[row], target_inv_dict)
                fout.write(str(words) + '\n')
                fout.flush()

            decoded_total += source.size(0)
            print('  {}th line decoded'.format(decoded_total))
        print('Decoding terminated')