Example #1
    def toShakespeare(self):
        """Given a line of text, return that text in the indicated style.
        
        Args:
          modern_text: (string) The input.
          
        Returns:
          string: The translated text, if generated.
        """

        args = load_arguments()
        vocab = Vocabulary(self.vocab_path, args.embedding, args.dim_emb)

        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True

        with tf.Session(config=config) as sess:
            model = Model(args, vocab)
            model.saver.restore(sess, args.model)

            # Select the decoding strategy: beam search when the beam width
            # is greater than 1, greedy decoding otherwise.
            if args.beam > 1:
                decoder = beam_search.Decoder(sess, args, vocab, model)
            else:
                decoder = greedy_decoding.Decoder(sess, args, vocab, model)

            # Decode with whichever decoder was selected; this block must not
            # sit inside the else branch, or `out` is never bound when beam
            # search is used.
            batch = get_batch([self.modern_text], [1], vocab.word2id)
            ori, tsf = decoder.rewrite(batch)

            out = ' '.join(tsf[0])

        return out
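The method above reads its input from instance attributes (`vocab_path`, `modern_text`) that the snippet does not show. A minimal, hypothetical wrapper class to illustrate the expected usage; the class name and constructor are assumptions, not part of the source:

class ShakespeareTranslator:
    """Hypothetical host class for the toShakespeare method above."""

    def __init__(self, vocab_path, modern_text):
        self.vocab_path = vocab_path    # path to the model's vocabulary file
        self.modern_text = modern_text  # modern-English line to translate

    # def toShakespeare(self): ...   (paste the method from Example #1 here)

translator = ShakespeareTranslator('model/vocab', 'where are you going?')
print(translator.toShakespeare())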
Example #2
def transform_text(text):
    tf.compat.v1.disable_eager_execution()
    args = load_arguments()
    ah = vars(args)
    ah['vocab'] = '../model/yelp.vocab'
    ah['model'] = '../model/model'
    ah['load_model'] = True
    ah['beam'] = 8
    ah['batch_size'] = 1
    inp = [text]

    vocab = Vocabulary(args.vocab, args.embedding, args.dim_emb)
    print('vocabulary size:', vocab.size)

    config = tf.compat.v1.ConfigProto()
    config.gpu_options.allow_growth = True

    with tf.compat.v1.Session(config=config) as sess:
        model = create_model(sess, args, vocab)
        decoder = beam_search.Decoder(sess, args, vocab, model)

        batches, order0, order1 = get_batches(inp, inp, vocab.word2id,
                                              args.batch_size)

        data0_tsf, data1_tsf = [], []
        for batch in batches:
            rec, tsf = decoder.rewrite(batch)
            half = batch['size'] // 2
            print('rec:', rec)
            print('tsf:', tsf)
            data0_tsf += tsf[:half]
            data1_tsf += tsf[half:]

        # Restore the original sentence order before printing/returning.
        n0, n1 = len(inp), len(inp)
        data0_tsf = reorder(order0, data0_tsf)[:n0]
        data1_tsf = reorder(order1, data1_tsf)[:n1]
        print(data0_tsf)
        print(data1_tsf)
        return data0_tsf, data1_tsf
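Assuming the hard-coded relative paths (`../model/yelp.vocab`, `../model/model`) point at a trained model, the function can be called directly. A hypothetical usage sketch; the same input is fed in as both style-0 and style-1 data, so both transfer directions come back:

# Hypothetical call; requires the pretrained model files referenced above.
style0_to_1, style1_to_0 = transform_text('the food was amazing !')
print(style0_to_1)
print(style1_to_0)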
Example #3
        if args.load_model:
            print('Loading model from', args.model)
            ckpt = tf.train.get_checkpoint_state(args.model)
            if ckpt and ckpt.model_checkpoint_path:
                try:
                    print("Trying to restore from a checkpoint...")
                    model.saver.restore(sess, ckpt.model_checkpoint_path)
                    print("Model is restored from checkpoint {}".format(
                        ckpt.model_checkpoint_path))
                except Exception as e:
                    print("Cannot restore from checkpoint due to {}".format(e))

        # Select the decoding strategy (beam search vs. greedy).
        if args.beam > 1:
            decoder = beam_search.Decoder(sess, args, vocab, model)
        else:
            decoder = greedy_decoding.Decoder(sess, args, vocab, model)

        if args.train:
            batches, _, _ = get_batches(train0,
                                        train1,
                                        vocab.word2id,
                                        args.batch_size,
                                        noisy=True,
                                        unparallel=False,
                                        max_seq_len=args.max_seq_length)
            random.shuffle(batches)

            start_time = time.time()
            step = 0
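The excerpt ends just as the step counter is initialized. A minimal sketch of the epoch/step loop that conventionally follows in scripts of this shape; `run_train_step`, `args.max_epochs`, and `args.steps_per_checkpoint` are assumed names here, not identifiers confirmed by the source:

# Hypothetical continuation of the setup above.
for epoch in range(args.max_epochs):          # assumed flag
    for batch in batches:
        step += 1
        # Stand-in for the script's actual parameter-update call.
        loss = run_train_step(sess, model, batch)

        if step % args.steps_per_checkpoint == 0:   # assumed flag
            elapsed = time.time() - start_time
            print('step %d, time %.0fs, loss %.2f' % (step, elapsed, loss))
            model.saver.save(sess, args.model)  # same saver used for restore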