Example #1
    def complete(self, text):
        hparams = model.default_hparams()
        with open(os.path.join('models', self.model_name,
                               'hparams.json')) as f:
            hparams.override_from_dict(json.load(f))

        if self.length is None:
            self.length = hparams.n_ctx // 2
        elif self.length > hparams.n_ctx:
            raise ValueError("Can't get samples longer than window size: %s" %
                             hparams.n_ctx)

        with tf.Session(graph=tf.Graph()) as sess:
            context = tf.placeholder(tf.int32, [self.batch_size, None])
            output = sample.sample_sequence(hparams=hparams,
                                            length=self.length,
                                            context=context,
                                            batch_size=self.batch_size,
                                            temperature=self.temperature,
                                            top_k=self.top_k)

            saver = tf.train.Saver()
            ckpt = tf.train.latest_checkpoint(
                os.path.join('models', self.model_name))
            saver.restore(sess, ckpt)
            context_tokens = self.encoder.encode(text)
            out = sess.run(output,
                           feed_dict={context:
                                      [context_tokens]})[:,
                                                         len(context_tokens):]
            generated_text = self.encoder.decode(out[0])
            return generated_text
Example #2
    def __init__(self):
        if ChatBotModel == "gpt2":
            model_name='355M'
            seed=None
            length=None
            temperature=0.7
            top_k=40
            top_p=0.2
            models_dir = os.path.join(root_file_path, 'gpt2/models')
            models_dir = os.path.expanduser(os.path.expandvars(models_dir))

            self.enc = encoder.get_encoder(model_name, models_dir)
            hparams = model.default_hparams()
            with open(os.path.join(models_dir, model_name, 'hparams.json')) as f:
                hparams.override_from_dict(json.load(f))

            if length is None:
                length = hparams.n_ctx // 2
            elif length > hparams.n_ctx:
                raise ValueError("Can't get samples longer than window size: %s" % hparams.n_ctx)
            gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.3, allow_growth=True)
            self.sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))
            self.context = tf.placeholder(tf.int32, [1, None])
            np.random.seed(seed)
            tf.set_random_seed(seed)
            self.output = sample.sample_sequence(
                hparams=hparams, length=length,
                context=self.context,
                batch_size=1,
                temperature=temperature, top_k=top_k, top_p=top_p
            )
            saver = tf.train.Saver()
            ckpt = tf.train.latest_checkpoint(os.path.join(models_dir, model_name))
            saver.restore(self.sess, ckpt)
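Every snippet on this page repeats the same bootstrap step: read models/<model_name>/hparams.json and override the default hyperparameters before building the graph. The sketch below is only an illustration of that pattern using a plain dict; the real default_hparams() returns an HParams object with an override_from_dict() method, and the values shown are those of the 124M model.

import json
import os

# Illustrative stand-in for model.default_hparams() (124M-sized defaults).
DEFAULT_HPARAMS = {'n_vocab': 50257, 'n_ctx': 1024,
                   'n_embd': 768, 'n_head': 12, 'n_layer': 12}

def load_hparams(models_dir, model_name):
    hparams = dict(DEFAULT_HPARAMS)
    with open(os.path.join(models_dir, model_name, 'hparams.json')) as f:
        hparams.update(json.load(f))   # analogous to hparams.override_from_dict(...)
    return hparams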
Example #3
def interact_model(
    model_name='117M',
    seed=None,
    nsamples=1,
    batch_size=None,
    length=None,
    temperature=1,
    top_k=0,
):
    if batch_size is None:
        batch_size = 1
    assert nsamples % batch_size == 0
    np.random.seed(seed)
    tf.set_random_seed(seed)

    enc = encoder.get_encoder(model_name)
    hparams = model.default_hparams()
    with open(os.path.join('models', model_name, 'hparams.json')) as f:
        hparams.override_from_dict(json.load(f))

    if length is None:
        length = hparams.n_ctx // 2
    elif length > hparams.n_ctx:
        raise ValueError("Can't get samples longer than window size: %s" %
                         hparams.n_ctx)

    with tf.Session(graph=tf.Graph()) as sess:
        context = tf.placeholder(tf.int32, [batch_size, None])
        output = sample.sample_sequence(hparams=hparams,
                                        length=length,
                                        context=context,
                                        batch_size=batch_size,
                                        temperature=temperature,
                                        top_k=top_k)

        saver = tf.train.Saver()
        ckpt = tf.train.latest_checkpoint(os.path.join('models', model_name))
        saver.restore(sess, ckpt)

        while True:
            raw_text = input("Model prompt >>> ")
            while not raw_text:
                print('Prompt should not be empty!')
                raw_text = input("Model prompt >>> ")
            context_tokens = enc.encode(raw_text)
            generated = 0
            for _ in range(nsamples // batch_size):
                out = sess.run(output,
                               feed_dict={
                                   context:
                                   [context_tokens for _ in range(batch_size)]
                               })[:, len(context_tokens):]
                for i in range(batch_size):
                    generated += 1
                    text = enc.decode(out[i])
                    print("=" * 40 + " SAMPLE " + str(generated) + " " +
                          "=" * 40)
                    print(text)
            print("=" * 80)
Example #4
    def preinit_model(self):
        np.random.seed(self.seed)
        tf.set_random_seed(self.seed)
        self.enc = encoder.get_encoder(self.model_name)
        self.hparams = model.default_hparams()
        with open(os.path.join('models', self.model_name, 'hparams.json')) as f:
            self.hparams.override_from_dict(json.load(f))

        if self.length is None:
            self.length = self.hparams.n_ctx // 2
        elif self.length > self.hparams.n_ctx:
            logging.error("Can't get samples longer than window size: %s" % self.hparams.n_ctx)
Example #5
    def __init__(self,
                 model_name='117M',
                 seed=None,
                 batch_size=6,
                 length=1,
                 temperature=1,
                 top_k=0,
                 top_p=0.0,
                 ckpt='checkpoint/12-8'):
        """
		:model_name=117M : String, which model to use
		:seed=None : Integer seed for random number generators, fix seed to reproduce
		 results
		:batch_size=1 : Number of batches (only affects speed/memory).
		:length=None : Number of tokens in generated text, if None (default), is
		 determined by model hyperparameters
		:temperature=1 : Float value controlling randomness in boltzmann
		 distribution. Lower temperature results in less random completions. As the
		 temperature approaches zero, the model will become deterministic and
		 repetitive. Higher temperature results in more random completions.
		:top_k=0 : Integer value controlling diversity. 1 means only 1 word is
		 considered for each step (token), resulting in deterministic completions,
		 while 40 means 40 words are considered at each step. 0 (default) is a
		 special setting meaning no restrictions. 40 generally is a good value.
		:top_p=0.0 : Float value controlling diversity. Implements nucleus sampling,
		 overriding top_k if set to a value > 0. A good setting is 0.9.
		"""

        self.seed = seed
        self.batch_size = batch_size
        self.enc = encoder.get_encoder(model_name)
        self.hparams = model.default_hparams()
        self.temperature = temperature
        self.top_k = top_k
        self.top_p = top_p
        self.model_name = model_name
        self.length = length
        self.endoftext = self.enc.encode('<|endoftext|>')

        if ckpt:
            self.ckpt = ckpt
        else:
            self.ckpt = os.path.join('models', self.model_name)

        with open(os.path.join('models', model_name, 'hparams.json')) as f:
            self.hparams.override_from_dict(json.load(f))
        if length is None:
            self.length = self.hparams.n_ctx // 2
        elif length > self.hparams.n_ctx:
            raise ValueError("Can't get samples longer than window size: %s" %
                             self.hparams.n_ctx)
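The top_k and top_p values documented above are applied to the model's logits inside sample.sample_sequence before a token is drawn. The NumPy function below is only a stand-alone illustration of those two filters, not the TensorFlow implementation used by the snippets on this page.

import numpy as np

def filter_logits(logits, top_k=0, top_p=0.0):
    # Illustrative top-k / nucleus (top-p) filtering over a 1-D logits vector.
    logits = np.asarray(logits, dtype=np.float64).copy()
    if top_k > 0:
        kth_largest = np.sort(logits)[-top_k]
        logits[logits < kth_largest] = -np.inf      # drop everything outside the top k
    if top_p > 0.0:
        order = np.argsort(logits)[::-1]            # most likely tokens first
        probs = np.exp(logits[order] - np.max(logits))
        probs /= probs.sum()
        keep = np.searchsorted(np.cumsum(probs), top_p) + 1  # smallest nucleus covering top_p
        logits[order[keep:]] = -np.inf
    return logits

# With top_k=2, only the two largest logits (3.0 and 2.0) survive; the rest become -inf.
print(filter_logits([1.0, 3.0, 2.0, 0.5], top_k=2))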
Example #6
def generate(hints, model_name='345M', seed=None,
             nsamples=10, batch_size=1, length=None,
             temperature=1, top_k=0, top_p=1, models_dir='models'):
    models_dir = os.path.expanduser(os.path.expandvars(models_dir))

    batch_size = batch_size or 1
    assert nsamples % batch_size == 0

    enc = encoder.get_encoder(model_name, models_dir)
    hparams = model.default_hparams()
    with open(os.path.join(models_dir, model_name, 'hparams.json')) as f:
        hparams.override_from_dict(json.load(f))

    if length is None:
        length = hparams.n_ctx // 2
    elif length > hparams.n_ctx:
        raise ValueError("Can't get samples longer than window size: %s" % hparams.n_ctx)

    results = defaultdict(set)
    with tf.Session(graph=tf.Graph()) as sess:
        context = tf.placeholder(tf.int32, [batch_size, None])
        np.random.seed(seed)
        tf.set_random_seed(seed)
        output = sample.sample_sequence(
            hparams=hparams, length=length,
            context=context,
            batch_size=batch_size,
            temperature=temperature, top_k=top_k, top_p=top_p
        )

        saver = tf.train.Saver()
        ckpt = tf.train.latest_checkpoint(os.path.join(models_dir, model_name))
        saver.restore(sess, ckpt)

        for hint in hints:
            print("[%s]begin to generate for: %s" % (datetime.utcnow(), hint))
            context_tokens = enc.encode(hint)
            for _ in range(nsamples // batch_size):
                out = sess.run(output, feed_dict={
                    context: [context_tokens for _ in range(batch_size)]
                })[:, len(context_tokens):]
                for out_data in out:
                    text = enc.decode(out_data)
                    text = postprocess(hint, text.strip())
                    results[hint].add(text)

            print("[%s]finished generating for: %s" % (datetime.utcnow(), hint))

    return results
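A possible call to generate() above, assuming the 345M weights have been downloaded into models/ and the postprocess helper used inside the loop is importable; the hint strings are placeholders:

results = generate(hints=['The meaning of life is', 'Once upon a time'],
                   model_name='345M', nsamples=4, batch_size=2,
                   temperature=0.9, top_k=40)
for hint, completions in results.items():
    print('%s -> %d distinct completions' % (hint, len(completions)))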
Example #7
def sample_model(
    model_name='117M',
    seed=None,
    nsamples=0,
    batch_size=1,
    length=None,
    temperature=1,
    top_k=0,
):
    np.random.seed(seed)
    tf.set_random_seed(seed)

    enc = encoder.get_encoder(model_name)
    hparams = model.default_hparams()
    with open(os.path.join('models', model_name, 'hparams.json')) as f:
        hparams.override_from_dict(json.load(f))

    if length is None:
        length = hparams.n_ctx
    elif length > hparams.n_ctx:
        raise ValueError("Can't get samples longer than window size: %s" %
                         hparams.n_ctx)

    with tf.Session(graph=tf.Graph()) as sess:
        output = sample.sample_sequence(
            hparams=hparams,
            length=length,
            start_token=enc.encoder['<|endoftext|>'],
            batch_size=batch_size,
            temperature=temperature,
            top_k=top_k)[:, 1:]

        saver = tf.train.Saver()
        ckpt = tf.train.latest_checkpoint(os.path.join('models', model_name))
        saver.restore(sess, ckpt)

        generated = 0
        while nsamples == 0 or generated < nsamples:
            out = sess.run(output)
            for i in range(batch_size):
                generated += batch_size
                text = enc.decode(out[i])
                print("=" * 40 + " SAMPLE " + str(generated) + " " + "=" * 40)
                print(text)
Example #8
def train_main(dataset,
               model_name='117M',
               seed=None,
               batch_size=1,
               sample_length=1023,
               sample_num=50,
               sample_every=100,
               run_name='dnd_biographies08',
               restore_from='latest',
               mode="test",
               max_iterations=50000,
               loss_threshold=0.8,
               save_every=1000):

    enc = encoder.get_encoder(model_name)
    hparams = model.default_hparams()
    with open(os.path.join('models', model_name, 'hparams.json')) as f:
        hparams.override_from_dict(json.load(f))

    if sample_length is None:
        sample_length = hparams.n_ctx // 2
    elif sample_length > hparams.n_ctx:
        raise ValueError("Can't get samples longer than window size: %s" %
                         hparams.n_ctx)

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    with tf.Session(config=config) as sess:
        context = tf.placeholder(tf.int32, [batch_size, None])
        np.random.seed(seed)
        tf.set_random_seed(seed)
        output = model.model(hparams=hparams, X=context)
        loss = tf.reduce_mean(
            tf.nn.sparse_softmax_cross_entropy_with_logits(
                labels=context[:, 1:], logits=output['logits'][:, :-1]))

        tf_sample = sample.sample_sequence(hparams=hparams,
                                           length=sample_length,
                                           context=context,
                                           batch_size=batch_size,
                                           temperature=1.0,
                                           top_k=40)

        train_vars = [v for v in tf.trainable_variables() if 'model' in v.name]
        opt = tf.train.AdamOptimizer(1e-4).minimize(loss, var_list=train_vars)

        saver = tf.train.Saver(var_list=train_vars,
                               max_to_keep=5,
                               keep_checkpoint_every_n_hours=2)
        sess.run(tf.global_variables_initializer())

        if restore_from == 'latest':
            ckpt = tf.train.latest_checkpoint(
                os.path.join(CHECKPOINT_DIR, run_name))
            if ckpt is None:
                # Get fresh GPT weights if new run.
                ckpt = tf.train.latest_checkpoint(
                    os.path.join('models', model_name))
        elif restore_from == 'fresh':
            ckpt = tf.train.latest_checkpoint(
                os.path.join('models', model_name))
        else:
            ckpt = tf.train.latest_checkpoint(restore_from)
        print('Loading checkpoint', ckpt)
        saver.restore(sess, ckpt)

        print('Loading dataset...')
        chunks = load_dataset(enc, dataset)
        data_sampler = Sampler(chunks)
        print('dataset has', data_sampler.total_size, 'tokens')
        print('Training...')

        counter = 1
        if os.path.exists(os.path.join(CHECKPOINT_DIR, run_name, 'counter')):
            # Load the step number if we're resuming a run
            # Add 1 so we don't immediately try to save again
            with open(os.path.join(CHECKPOINT_DIR, run_name, 'counter'),
                      'r') as fp:
                counter = int(fp.read()) + 1

        def save():
            maketree(os.path.join(CHECKPOINT_DIR, run_name))
            print(
                'Saving',
                os.path.join(CHECKPOINT_DIR, run_name,
                             'model-{}').format(counter))
            saver.save(sess,
                       os.path.join(CHECKPOINT_DIR, run_name, 'model'),
                       global_step=counter)
            with open(os.path.join(CHECKPOINT_DIR, run_name, 'counter'),
                      'w') as fp:
                fp.write(str(counter) + '\n')

        def generate_samples():
            context_tokens = data_sampler.sample(1)
            all_text = []
            index = 0
            while index < sample_num or sample_num == 0:
                out = sess.run(
                    tf_sample,
                    feed_dict={context: batch_size * [context_tokens]})
                for i in range(min(sample_num - index, batch_size)):
                    text = enc.decode(out[i])
                    text = '======== SAMPLE {} ========\n{}\n'.format(
                        index + 1, text)
                    all_text.append(text)
                    index += 1
                    print(text)
            # print(''.join(all_text))
            maketree(os.path.join(SAMPLE_DIR, run_name))
            with open(
                    os.path.join(SAMPLE_DIR, run_name,
                                 'samples-{}').format(counter), 'w') as fp:
                fp.write('\n'.join(all_text))

        avg_loss = (0.0, 0.0)
        start_time = time.time()

        try:
            if mode == "train":
                while counter <= max_iterations:
                    if counter % save_every == 0:
                        save()
                    if counter % sample_every == 0:
                        generate_samples()

                    batch = [
                        data_sampler.sample(1024) for _ in range(batch_size)
                    ]

                    _, lv = sess.run((opt, loss), feed_dict={context: batch})

                    avg_loss = (avg_loss[0] * 0.99 + lv,
                                avg_loss[1] * 0.99 + 1.0)

                    print(
                        '[{counter} | {time:2.2f}] loss={loss:2.2f} avg={avg:2.2f}'
                        .format(counter=counter,
                                time=time.time() - start_time,
                                loss=lv,
                                avg=avg_loss[0] / avg_loss[1]))

                    counter += 1
                    if counter > 100:
                        if (avg_loss[0] / avg_loss[1]) < loss_threshold:
                            counter = max_iterations + 1
            else:
                generate_samples()
        except KeyboardInterrupt:
            print('interrupted')
            save()
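The avg value logged every step in train_main is an exponentially weighted moving average of the loss with decay 0.99; tracking the decayed weight sum in avg_loss[1] corrects the bias of the first few steps. The bookkeeping can be checked in isolation:

losses = [3.2, 3.0, 2.8, 2.9]          # made-up per-step losses
avg_loss = (0.0, 0.0)
for lv in losses:
    avg_loss = (avg_loss[0] * 0.99 + lv, avg_loss[1] * 0.99 + 1.0)
    print('loss={:.2f} avg={:.2f}'.format(lv, avg_loss[0] / avg_loss[1]))
# Even on the first step the reported average equals the step loss,
# because numerator and denominator are decayed by the same factor.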
Example #9
def interact_model(
    model_name='124M',
    seed=None,
    nsamples=1,
    batch_size=1,
    length=None,
    temperature=1,
    top_k=0,
    top_p=1,
    models_dir='models',
):
    """
    Interactively run the model
    :model_name=124M : String, which model to use
    :seed=None : Integer seed for random number generators, fix seed to reproduce
     results
    :nsamples=1 : Number of samples to return total
    :batch_size=1 : Number of batches (only affects speed/memory).  Must divide nsamples.
    :length=None : Number of tokens in generated text, if None (default), is
     determined by model hyperparameters
    :temperature=1 : Float value controlling randomness in boltzmann
     distribution. Lower temperature results in less random completions. As the
     temperature approaches zero, the model will become deterministic and
     repetitive. Higher temperature results in more random completions.
    :top_k=0 : Integer value controlling diversity. 1 means only 1 word is
     considered for each step (token), resulting in deterministic completions,
     while 40 means 40 words are considered at each step. 0 (default) is a
     special setting meaning no restrictions. 40 generally is a good value.
     :models_dir : path to parent folder containing model subfolders
     (i.e. contains the <model_name> folder)
    """
    models_dir = os.path.expanduser(os.path.expandvars(models_dir))
    if batch_size is None:
        batch_size = 1
    assert nsamples % batch_size == 0

    enc = encoder.get_encoder(model_name, models_dir)
    hparams = model.default_hparams()
    with open(os.path.join(models_dir, model_name, 'hparams.json')) as f:
        hparams.override_from_dict(json.load(f))

    if length is None:
        length = hparams.n_ctx // 2
    elif length > hparams.n_ctx:
        raise ValueError("Can't get samples longer than window size: %s" % hparams.n_ctx)

    with tf.Session(graph=tf.Graph()) as sess:
        context = tf.placeholder(tf.int32, [batch_size, None])
        np.random.seed(seed)
        tf.set_random_seed(seed)
        output = sample.sample_sequence(
            hparams=hparams, length=length,
            context=context,
            batch_size=batch_size,
            temperature=temperature, top_k=top_k, top_p=top_p
        )

        saver = tf.train.Saver()
        ckpt = tf.train.latest_checkpoint(os.path.join(models_dir, model_name))
        saver.restore(sess, ckpt)

        while True:
            raw_text = input("Model prompt >>> ")
            while not raw_text:
                print('Prompt should not be empty!')
                raw_text = input("Model prompt >>> ")
            context_tokens = enc.encode(raw_text)
            generated = 0
            for _ in range(nsamples // batch_size):
                out = sess.run(output, feed_dict={
                    context: [context_tokens for _ in range(batch_size)]
                })[:, len(context_tokens):]
                for i in range(batch_size):
                    generated += 1
                    text = enc.decode(out[i])
                    print("=" * 40 + " SAMPLE " + str(generated) + " " + "=" * 40)
                    print(text)
            print("=" * 80)
Example #10
def sample_model(
    #model_name="117M",
    model_name="345M",
    seed=None,
    nsamples=5,
    batch_size=1,
    length=200,
    temperature=1,
    top_k=40,
):
    """
    Run the sample_model
    :model_name=117M : String, which model to use
    :seed=None : Integer seed for random number generators, fix seed to
     reproduce results
    :nsamples=0 : Number of samples to return, if 0, continues to
     generate samples indefinitely.
    :batch_size=1 : Number of batches (only affects speed/memory).
    :length=None : Number of tokens in generated text, if None (default), is
     determined by model hyperparameters
    :temperature=1 : Float value controlling randomness in boltzmann
     distribution. Lower temperature results in less random completions. As the
     temperature approaches zero, the model will become deterministic and
     repetitive. Higher temperature results in more random completions.
    :top_k=0 : Integer value controlling diversity. 1 means only 1 word is
     considered for each step (token), resulting in deterministic completions,
     while 40 means 40 words are considered at each step. 0 (default) is a
     special setting meaning no restrictions. 40 generally is a good value.
    """
    enc = encoder.get_encoder(model_name)
    hparams = model.default_hparams()

    with open(os.path.join('models', model_name, 'hparams.json')) as f:
        hparams.override_from_dict(json.load(f))
    #print(hparams)
    if length is None:
        length = hparams.n_ctx
    elif length > hparams.n_ctx:
        raise ValueError("Can't get samples longer than window size: %s" %
                         hparams.n_ctx)

    with tf.Session(graph=tf.Graph()) as sess:
        np.random.seed(seed)
        tf.set_random_seed(seed)

        output = sample.sample_sequence(
            hparams=hparams,
            length=length,
            start_token=enc.encoder['<|endoftext|>'],
            batch_size=batch_size,
            temperature=temperature,
            top_k=top_k)[:, 1:]

        saver = tf.train.Saver()
        ckpt = tf.train.latest_checkpoint(os.path.join('models', model_name))
        saver.restore(sess, ckpt)

        generated = 0
        while nsamples == 0 or generated < nsamples:
            out = sess.run(output)
            #print(out)
            for i in range(batch_size):
                generated += batch_size
                text = enc.decode(out[i])
                print("=" * 40 + " SAMPLE " + str(generated) + " " + "=" * 40)
                print(text)
Example #11
def conditional_model(
    model_name='1558M',
    seed=None,
    nsamples=1,
    batch_size=1,
    length=None,
    temperature=1,
    top_k=40,
    top_p=0,
    models_dir='models',
    sentences=None,
    ):
    """
    Run the model on multiple sentences and return a dict.
    :model_name : String, which model to use
    :seed=None : Integer seed for random number generators, fix seed to reproduce
     results
    :nsamples=1 : Number of samples to return total
    :batch_size=1 : Number of batches (only affects speed/memory).  Must divide nsamples.
    :length=None : Number of tokens in generated text, if None (default), is
     determined by model hyperparameters
    :temperature=1 : Float value controlling randomness in boltzmann
     distribution. Lower temperature results in less random completions. As the
     temperature approaches zero, the model will become deterministic and
     repetitive. Higher temperature results in more random completions.
    :top_k=40 : Integer value controlling diversity. 1 means only 1 word is
     considered for each step (token), resulting in deterministic completions,
     while 40 means 40 words are considered at each step. 0 (default) is a
     special setting meaning no restrictions. 40 generally is a good value.
    :top_p=0.0 : Float value controlling diversity. Implements nucleus sampling,
     overriding top_k if set to a value > 0. A good setting is 0.9.
    :sentences : List of strings or string. Model returns an answer or a continuation
     to that string. If list of strings the model return a dictionary of sentences and their
     respective model replies.
    """
    if batch_size is None:
        batch_size = 1
    assert nsamples % batch_size == 0
    
    if sentences is None:
        raise ValueError('Sentences cannot be None')

    enc = encoder.get_encoder(model_name)
    hparams = model.default_hparams()
    with open(os.path.join('../models', model_name, 'hparams.json')) as f:
        hparams.override_from_dict(json.load(f))

    if length is None:
        length = hparams.n_ctx // 2
    elif length > hparams.n_ctx:
        raise ValueError("Can't get samples longer than window size: %s" % hparams.n_ctx)

    with tf.Session(graph=tf.Graph()) as sess:
        context = tf.placeholder(tf.int32, [batch_size, None])
        np.random.seed(seed)
        tf.set_random_seed(seed)
        output = sample.sample_sequence(
            hparams=hparams, length=length,
            context=context,
            batch_size=batch_size,
            temperature=temperature, top_k=top_k, top_p=top_p
        )

        saver = tf.train.Saver()
        ckpt = tf.train.latest_checkpoint(os.path.join('../models', model_name))
        saver.restore(sess, ckpt)
        listy = []
        n = 0
        
        if isinstance(sentences, list):
            for i in sentences:
                context_tokens = enc.encode(i)
                for _ in range(nsamples // batch_size):
                    out = sess.run(output, feed_dict={
                        context: [context_tokens for _ in range(batch_size)]
                    })[:, len(context_tokens):]
                text = i + enc.decode(out[0])
                listy.append(text)
                n += 1
                print(n)
            return dict(zip(sentences,listy))
        else:
            context_tokens = enc.encode(sentences)
            for _ in range(nsamples // batch_size):
                out = sess.run(output, feed_dict={
                    context: [context_tokens for _ in range(batch_size)]
                })[:, len(context_tokens):]
            text = sentences + enc.decode(out[0])
            
            return {sentences: text}
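A usage sketch for conditional_model(): the prompts are placeholders, and the call assumes the 1558M weights are available under ../models as the function expects.

replies = conditional_model(model_name='1558M',
                            sentences=['What is GPT-2?',
                                       'Write a short poem about rain.'],
                            length=60, top_k=40)
for prompt, answer in replies.items():
    print(prompt, '->', answer[:80])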
Example #12
def main():

    # args = parser.parse_args()
    args = Opts()
    enc = encoder.get_encoder(args.model_name)
    hparams = model.default_hparams()

    with open(os.path.join('models', args.model_name, 'hparams.json')) as f:
        # hparams.override_from_dict(json.load(f))
        hparams.override_from_dict(json.loads(f.read()))

    if args.sample_length > hparams.n_ctx:
        raise ValueError("Can't get samples longer than window size: %s" %
                         hparams.n_ctx)

    if args.model_name == '345M':
        args.memory_saving_gradients = True
        if args.optimizer == 'adam':
            args.only_train_transformer_layers = True

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.graph_options.rewrite_options.layout_optimizer = rewriter_config_pb2.RewriterConfig.OFF
    with tf.Session(config=config) as sess:
        context = tf.placeholder(tf.int32, [args.batch_size, None])
        context_in = randomize(context, hparams, args.noise)
        output = model.model(hparams=hparams, X=context_in)
        loss = tf.reduce_mean(
            tf.nn.sparse_softmax_cross_entropy_with_logits(
                labels=context[:, 1:], logits=output['logits'][:, :-1]))

        if args.val_every > 0:
            val_context = tf.placeholder(tf.int32, [args.val_batch_size, None])
            val_output = model.model(hparams=hparams, X=val_context)
            val_loss = tf.reduce_mean(
                tf.nn.sparse_softmax_cross_entropy_with_logits(
                    labels=val_context[:, 1:],
                    logits=val_output['logits'][:, :-1]))
            val_loss_summary = tf.summary.scalar('val_loss', val_loss)

        tf_sample = sample.sample_sequence(hparams=hparams,
                                           length=args.sample_length,
                                           context=context,
                                           batch_size=args.batch_size,
                                           temperature=1.0,
                                           top_k=args.top_k,
                                           top_p=args.top_p)

        all_vars = [v for v in tf.trainable_variables() if 'model' in v.name]
        train_vars = [v for v in all_vars if '/h' in v.name
                      ] if args.only_train_transformer_layers else all_vars

        if args.optimizer == 'adam':
            opt = tf.train.AdamOptimizer(learning_rate=args.learning_rate)
        elif args.optimizer == 'sgd':
            opt = tf.train.GradientDescentOptimizer(
                learning_rate=args.learning_rate)
        else:
            exit('Bad optimizer: ' + args.optimizer)

        if args.accumulate_gradients > 1:
            if args.memory_saving_gradients:
                exit(
                    "Memory saving gradients are not implemented for gradient accumulation yet."
                )
            opt = AccumulatingOptimizer(opt=opt, var_list=train_vars)
            opt_reset = opt.reset()
            opt_compute = opt.compute_gradients(loss)
            opt_apply = opt.apply_gradients()
            summary_loss = tf.summary.scalar('loss', opt_apply)
        else:
            if args.memory_saving_gradients:
                opt_grads = memory_saving_gradients.gradients(loss, train_vars)
            else:
                opt_grads = tf.gradients(loss, train_vars)
            opt_grads = list(zip(opt_grads, train_vars))
            opt_apply = opt.apply_gradients(opt_grads)
            summary_loss = tf.summary.scalar('loss', loss)

        summary_lr = tf.summary.scalar('learning_rate', args.learning_rate)
        summaries = tf.summary.merge([summary_lr, summary_loss])

        summary_log = tf.summary.FileWriter(
            os.path.join(CHECKPOINT_DIR, args.run_name))

        saver = tf.train.Saver(var_list=all_vars,
                               max_to_keep=5,
                               keep_checkpoint_every_n_hours=2)
        sess.run(tf.global_variables_initializer())

        if args.restore_from == 'latest':
            ckpt = tf.train.latest_checkpoint(
                os.path.join(CHECKPOINT_DIR, args.run_name))
            if ckpt is None:
                # Get fresh GPT weights if new run.
                ckpt = tf.train.latest_checkpoint(
                    os.path.join('models', args.model_name))
        elif args.restore_from == 'fresh':
            ckpt = tf.train.latest_checkpoint(
                os.path.join('models', args.model_name))
        else:
            ckpt = tf.train.latest_checkpoint(args.restore_from)
        print('Loading checkpoint', ckpt)
        saver.restore(sess, ckpt)

        print('Loading dataset...')
        chunks = load_dataset(enc,
                              args.dataset,
                              args.combine,
                              encoding=args.encoding)
        data_sampler = Sampler(chunks)
        if args.val_every > 0:
            if args.val_dataset:
                val_chunks = load_dataset(enc,
                                          args.val_dataset,
                                          args.combine,
                                          encoding=args.encoding)
            else:
                val_chunks = chunks
        print('dataset has', data_sampler.total_size, 'tokens')
        print('Training...')

        if args.val_every > 0:
            # Sample from validation set once with fixed seed to make
            # it deterministic during training as well as across runs.
            val_data_sampler = Sampler(val_chunks, seed=1)
            val_batches = [[
                val_data_sampler.sample(1024)
                for _ in range(args.val_batch_size)
            ] for _ in range(args.val_batch_count)]

        counter = 1
        counter_path = os.path.join(CHECKPOINT_DIR, args.run_name, 'counter')
        if os.path.exists(counter_path):
            # Load the step number if we're resuming a run
            # Add 1 so we don't immediately try to save again
            with open(counter_path, 'r') as fp:
                counter = int(fp.read()) + 1

        def save():
            maketree(os.path.join(CHECKPOINT_DIR, args.run_name))
            print(
                'Saving',
                os.path.join(CHECKPOINT_DIR, args.run_name,
                             'model-{}').format(counter))
            saver.save(sess,
                       os.path.join(CHECKPOINT_DIR, args.run_name, 'model'),
                       global_step=counter)
            with open(counter_path, 'w') as fp:
                fp.write(str(counter) + '\n')

        def generate_samples():
            print('Generating samples...')
            context_tokens = data_sampler.sample(1)
            all_text = []
            index = 0
            while index < args.sample_num:
                out = sess.run(
                    tf_sample,
                    feed_dict={context: args.batch_size * [context_tokens]})
                for i in range(min(args.sample_num - index, args.batch_size)):
                    text = enc.decode(out[i])
                    text = '======== SAMPLE {} ========\n{}\n'.format(
                        index + 1, text)
                    all_text.append(text)
                    index += 1
            print(text)
            maketree(os.path.join(SAMPLE_DIR, args.run_name))
            with open(os.path.join(SAMPLE_DIR, args.run_name,
                                   'samples-{}').format(counter),
                      'w',
                      encoding=args.encoding) as fp:
                fp.write('\n'.join(all_text))

        def validation():
            print('Calculating validation loss...')
            losses = []
            for batch in tqdm.tqdm(val_batches):
                losses.append(
                    sess.run(val_loss, feed_dict={val_context: batch}))
            v_val_loss = np.mean(losses)
            v_summary = sess.run(val_loss_summary,
                                 feed_dict={val_loss: v_val_loss})
            summary_log.add_summary(v_summary, counter)
            summary_log.flush()
            print('[{counter} | {time:2.2f}] validation loss = {loss:2.2f}'.
                  format(counter=counter,
                         time=time.time() - start_time,
                         loss=v_val_loss))

        def sample_batch():
            return [data_sampler.sample(1024) for _ in range(args.batch_size)]

        avg_loss = (0.0, 0.0)
        start_time = time.time()

        try:
            while True:
                if counter % args.save_every == 0:
                    save()
                if counter % args.sample_every == 0:
                    generate_samples()
                if args.val_every > 0 and (counter % args.val_every == 0
                                           or counter == 1):
                    validation()

                if args.accumulate_gradients > 1:
                    sess.run(opt_reset)
                    for _ in range(args.accumulate_gradients):
                        sess.run(opt_compute,
                                 feed_dict={context: sample_batch()})
                    (v_loss, v_summary) = sess.run((opt_apply, summaries))
                else:
                    (_, v_loss, v_summary) = sess.run(
                        (opt_apply, loss, summaries),
                        feed_dict={context: sample_batch()})

                summary_log.add_summary(v_summary, counter)

                avg_loss = (avg_loss[0] * 0.99 + v_loss,
                            avg_loss[1] * 0.99 + 1.0)

                print(
                    '[{counter} | {time:2.2f}] loss={loss:2.2f} avg={avg:2.2f}'
                    .format(counter=counter,
                            time=time.time() - start_time,
                            loss=v_loss,
                            avg=avg_loss[0] / avg_loss[1]))

                counter += 1
        except KeyboardInterrupt:
            print('interrupted')
            save()
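main() reads its settings from an Opts() object whose definition is not part of this snippet. A hypothetical stand-in covering every attribute the function accesses might look like the class below; the attribute names appear to follow the command-line flags of nshepperd's gpt-2 train.py, and the values are only illustrative defaults.

class Opts:
    # Hypothetical replacement for parsed argparse flags.
    model_name = '117M'
    run_name = 'run1'
    dataset = 'data.txt'            # path to the training text
    combine = 50000
    encoding = 'utf-8'
    batch_size = 1
    learning_rate = 1e-4
    optimizer = 'adam'              # 'adam' or 'sgd'
    accumulate_gradients = 1
    memory_saving_gradients = False
    only_train_transformer_layers = False
    noise = 0.0
    top_k = 40
    top_p = 0.0
    sample_length = 1023
    sample_num = 1
    sample_every = 100
    save_every = 1000
    restore_from = 'latest'         # 'latest', 'fresh', or a checkpoint path
    val_every = 0                   # 0 disables validation
    val_dataset = None
    val_batch_size = 2
    val_batch_count = 40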
Example #13
def interact_model(
    raw_poem=None,
    model_name='100_fal',
    seed=150,
    nsamples=1,
    batch_size=1,
    length=140,
    temperature=1,
    top_k=0,
    top_p=0.0,
):
    """
    Interactively run the model
    :model_name=117M : String, which model to use
    :seed=None : Integer seed for random number generators, fix seed to reproduce
     results
    :nsamples=1 : Number of samples to return total
    :batch_size=1 : Number of batches (only affects speed/memory).  Must divide nsamples.
    :length=None : Number of tokens in generated text, if None (default), is
     determined by model hyperparameters
    :temperature=1 : Float value controlling randomness in boltzmann
     distribution. Lower temperature results in less random completions. As the
     temperature approaches zero, the model will become deterministic and
     repetitive. Higher temperature results in more random completions.
    :top_k=40 : Integer value controlling diversity. 1 means only 1 word is
     considered for each step (token), resulting in deterministic completions,
     while 40 means 40 words are considered at each step. 0 (default) is a
     special setting meaning no restrictions. 40 generally is a good value.
    """
    input_list = [raw_poem]

    if batch_size is None:
        batch_size = 1
    assert nsamples % batch_size == 0

    enc = encoder.get_encoder(model_name)
    hparams = model.default_hparams()
    with open(os.path.join(basedir, 'models', model_name, 'hparams.json')) as f:
        hparams.override_from_dict(json.load(f))

    if length is None:
        length = hparams.n_ctx // 2
    elif length > hparams.n_ctx:
        raise ValueError("Can't get samples longer than window size: %s" % hparams.n_ctx)

    with tf.Session(graph=tf.Graph()) as sess:
        context = tf.placeholder(tf.int32, [batch_size, None])
        np.random.seed(seed)
        tf.set_random_seed(seed)
        output = sample.sample_sequence(
            hparams=hparams, length=length,
            context=context,
            batch_size=batch_size,
            temperature=temperature, top_k=top_k, top_p=top_p
        )

        saver = tf.train.Saver()
        ckpt = tf.train.latest_checkpoint(os.path.join(basedir, 'models', model_name))
        saver.restore(sess, ckpt)

        out_text = []

        if raw_poem in input_list:
            context_tokens = enc.encode(raw_poem)
            generated = 0
            for _ in range(nsamples // batch_size):
                out = sess.run(output, feed_dict={
                    context: [context_tokens for _ in range(batch_size)]
                })[:, len(context_tokens):]
                for i in range(batch_size):
                    generated += 1
                    text = enc.decode(out[i])
                    out_text.append(text)

        text = get_poem_chunks(out_text)

        return textFilter(text)
Example #14
def train(sess,
          data,
          labels,
          steps,
          run_name,
          batch_size=1,
          n_heads=None,
          n_layers=None,
          learning_rate=0.0001,
          print_each=1,
          save_every=1000,
          accumulate=5,
          use_class_entropy=False,
          model_path="checkpoint/"):

    model_path = os.path.join(model_path, run_name)

    if not os.path.exists(model_path):
        os.mkdir(model_path)

    new_run = 'counter' not in os.listdir(model_path)

    hparams = model.default_hparams()
    #Set HyperParams
    if n_layers: hparams.n_layer = n_layers
    if n_heads: hparams.n_head = n_heads
    if os.path.exists(model_path + "/hparams.json"):
        with open(os.path.join(model_path, 'hparams.json')) as f:
            hparams.override_from_dict(json.load(f))

    #Spectrogram dimensions
    d_shape = np.shape(data)
    print(d_shape)
    hparams.n_timestep = d_shape[1]
    hparams.n_freq = d_shape[2]
    hparams.n_cat = len(labels[0])

    #Create TF graph
    inp_specs = tf.placeholder(
        tf.float32, [batch_size, hparams.n_timestep, hparams.n_freq])
    logits = model.model(hparams, inp_specs, reuse=tf.AUTO_REUSE)
    #Loss tensor = Softmax cross entropy
    label_exp = tf.placeholder(tf.int8, [batch_size, hparams.n_cat])
    loss = tf.reduce_mean(
        tf.nn.softmax_cross_entropy_with_logits_v2(labels=label_exp,
                                                   logits=logits['logits']))

    all_vars = [v for v in tf.trainable_variables() if 'model' in v.name]
    print("Using {} Parameter Network".format(str(len(all_vars))))

    lr = tf.placeholder(tf.float32)
    if accumulate > 1:
        #Train step using AdamOtimizer with Accumulating gradients
        opt = AccumulatingOptimizer(
            opt=tf.train.AdamOptimizer(learning_rate=lr), var_list=all_vars)
        opt_reset = opt.reset()
        opt_compute = opt.compute_gradients(loss)
        opt_apply = opt.apply_gradients()
    else:
        opt = tf.train.AdamOptimizer(learning_rate=lr)
        opt_grads = tf.gradients(loss, all_vars)
        opt_grads = list(zip(opt_grads, all_vars))
        opt_apply = opt.apply_gradients(opt_grads)

    #Create saveable graph and checkpoint + counter
    saver = tf.train.Saver(var_list=all_vars)
    sess.run(tf.global_variables_initializer())
    if new_run:
        saver.save(sess, model_path + "/{}.ckpt".format(run_name))
    ckpt = tf.train.latest_checkpoint(model_path)
    print('Restoring checkpoint', ckpt)
    saver.restore(sess, ckpt)

    #Training SetUp
    #Get counter
    counter = 1
    counter_path = os.path.join(model_path, 'counter')
    if os.path.exists(counter_path):
        with open(counter_path, 'r') as fp:
            counter = int(fp.read()) + 1
    counter_base = counter

    def save():
        print('Saving',
              os.path.join(model_path, 'model-{}').format(counter - 1))
        saver.save(sess,
                   os.path.join(model_path, 'model'),
                   global_step=counter - 1)
        with open(counter_path, 'w') as fp:
            fp.write(str(counter - 1) + '\n')

    def next_batch(num, data, lab):
        '''
        Return a total of `num` random samples and labels.
        '''
        idx = np.arange(0, len(data))
        np.random.shuffle(idx)
        idx = idx[:num]
        data_shuffle = [data[i] for i in idx]
        labels_shuffle = [lab[i] for i in idx]
        return np.asarray(data_shuffle), np.asarray(labels_shuffle)

    avg_loss = (0.0, 0.0)
    start_time = time.time()

    def class_entropy(y):
        y = np.sum(y, 0)
        e = sum([(i / sum(y)) * np.log(i / sum(y)) if i > 0 else 0 for i in y])

        return np.abs(1 - (-np.log(1 / len(y)) + e))

    try:
        while counter < (counter_base + steps):
            if (counter - 1) % save_every == 0 and counter > 1:
                save()

            # Get batch of specified size
            x, lab = next_batch(batch_size, data, labels)
            lrate = learning_rate * class_entropy(
                lab) if use_class_entropy else learning_rate

            if accumulate > 1:
                sess.run(opt_reset)
                #Run Gradient accumulation steps
                for _ in range(accumulate):
                    sess.run(opt_compute,
                             feed_dict={
                                 inp_specs: x,
                                 label_exp: lab,
                                 "model/drop:0": 1.0
                             })
                # Apply the accumulated gradients. apply_gradients() returns the
                # accumulated loss (see the scalar summary pattern in Example #12),
                # so its value serves as v_loss for the logging below.
                v_loss = sess.run(opt_apply, feed_dict={lr: lrate})
            else:
                _, v_loss = sess.run((opt_apply, loss),
                                     feed_dict={
                                         inp_specs: x,
                                         label_exp: lab,
                                         lr: lrate,
                                         "model/drop:0": 1.0
                                     })

            avg_loss = (avg_loss[0] * 0.99 + v_loss, avg_loss[1] * 0.99 + 1.0)
            print(
                '[{counter} | {time:2.2f}] loss={loss:2.2f} avg={avg:2.2f} lrate={lrate}'
                .format(counter=counter,
                        time=time.time() - start_time,
                        loss=v_loss,
                        avg=avg_loss[0] / avg_loss[1],
                        lrate=str(lrate)))
            if counter % print_each == 0:
                sample = next_batch(batch_size, data, labels)
                out = sess.run(logits,
                               feed_dict={
                                   inp_specs: sample[0],
                                   "model/drop:0": 1.0
                               })
                acc = sum(
                    np.argmax(np.asarray(out['logits']), axis=1) == np.argmax(
                        sample[1], axis=1)) / batch_size
                print("[Summary Step] Accuracy {}% for {} distribution".format(
                    str(acc * 100), str(np.sum(sample[1], 0))))
                print("Class Entropy: {}".format(str(class_entropy(
                    sample[1]))))
            counter += 1
        save()

    except KeyboardInterrupt:
        print('interrupted')
        save()
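When use_class_entropy is set, the learning rate is scaled by class_entropy(lab): the value is 1.0 for a perfectly balanced batch and below 1.0 otherwise (for this 5-class setup). A quick stand-alone check of that behaviour with one-hot labels:

import numpy as np

def class_entropy(y):
    # Same computation as in train() above.
    y = np.sum(y, 0)
    e = sum([(i / sum(y)) * np.log(i / sum(y)) if i > 0 else 0 for i in y])
    return np.abs(1 - (-np.log(1 / len(y)) + e))

balanced = np.eye(5, dtype=int)               # one example per class
skewed = np.tile([1, 0, 0, 0, 0], (5, 1))     # five examples of the same class
print(class_entropy(balanced))                # 1.0  -> full learning rate
print(class_entropy(skewed))                  # ~0.61 -> reduced learning rate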
Example #15
    def __init__(
        self,
        model_name='345M',
        seed=None,
        nsamples=1,
        batch_size=1,
        length=None,
        temperature=1,
        top_k=0,
        raw_text="",
    ):
        """
        Interactively run the model
        :model_name=117M : String, which model to use
        :seed=None : Integer seed for random number generators, fix seed to reproduce
         results
        :nsamples=1 : Number of samples to return total
        :batch_size=1 : Number of batches (only affects speed/memory).  Must divide nsamples.
        :length=None : Number of tokens in generated text, if None (default), is
         determined by model hyperparameters
        :temperature=1 : Float value controlling randomness in boltzmann
         distribution. Lower temperature results in less random completions. As the
         temperature approaches zero, the model will become deterministic and
         repetitive. Higher temperature results in more random completions.
        :top_k=0 : Integer value controlling diversity. 1 means only 1 word is
         considered for each step (token), resulting in deterministic completions,
         while 40 means 40 words are considered at each step. 0 (default) is a
         special setting meaning no restrictions. 40 generally is a good value.
        """
        if batch_size is None:
            batch_size = 1
        assert nsamples % batch_size == 0

        self.nsamples = nsamples
        self.batch_size = batch_size

        self.enc = encoder.get_encoder(model_name)
        hparams = model.default_hparams()
        with open(os.path.join('models', model_name, 'hparams.json')) as f:
            hparams.override_from_dict(json.load(f))

        if length is None:
            length = hparams.n_ctx // 2
        elif length > hparams.n_ctx:
            raise ValueError("Can't get samples longer than window size: %s" %
                             hparams.n_ctx)

        self.sess = tf.Session(graph=tf.Graph())
        self.sess.__enter__()

        self.context = tf.placeholder(tf.int32, [batch_size, None])
        np.random.seed(seed)
        tf.set_random_seed(seed)
        self.output = sample.sample_sequence(hparams=hparams,
                                             length=length,
                                             context=self.context,
                                             batch_size=batch_size,
                                             temperature=temperature,
                                             top_k=top_k)

        saver = tf.train.Saver()
        self.ckpt = tf.train.latest_checkpoint(
            os.path.join('models', model_name))
        saver.restore(self.sess, self.ckpt)
Example #16
def nointeract_model(model_name, seed, nsamples, batch_size, length,
                     temperature, top_k, top_p, models_dir, inputGPT2, nrepeat,
                     filePath, nconcepts, nphases, ntest):
    """
    Non Interactively run the model
    :model_name=124M : String, which model to use
    :seed=None : Integer seed for random number generators, fix seed to reproduce
     results
    :nsamples=1 : Number of samples to return total
    :batch_size=1 : Number of batches (only affects speed/memory).  Must divide nsamples.
    :length=None : Number of tokens in generated text, if None (default), is
     determined by model hyperparameters
    :temperature=1 : Float value controlling randomness in boltzmann
     distribution. Lower temperature results in less random completions. As the
     temperature approaches zero, the model will become deterministic and
     repetitive. Higher temperature results in more random completions.
    :top_k=0 : Integer value controlling diversity. 1 means only 1 word is
     considered for each step (token), resulting in deterministic completions,
     while 40 means 40 words are considered at each step. 0 (default) is a
     special setting meaning no restrictions. 40 generally is a good value.
     :models_dir : path to parent folder containing model subfolders
     (i.e. contains the <model_name> folder)
    """
    models_dir = os.path.expanduser(os.path.expandvars(models_dir))
    if batch_size is None:
        batch_size = 1
    assert nsamples % batch_size == 0

    enc = encoder.get_encoder(model_name, models_dir)
    hparams = model.default_hparams()
    with open(os.path.join(models_dir, model_name, 'hparams.json')) as f:
        hparams.override_from_dict(json.load(f))

    if length is None:
        length = hparams.n_ctx // 2
    elif length > hparams.n_ctx:
        raise ValueError("Can't get samples longer than window size: %s" %
                         hparams.n_ctx)

    with tf.Session(graph=tf.Graph()) as sess:
        context = tf.placeholder(tf.int32, [batch_size, None])
        np.random.seed(seed)
        tf.set_random_seed(seed)
        output = sample.sample_sequence(hparams=hparams,
                                        length=length,
                                        context=context,
                                        batch_size=batch_size,
                                        temperature=temperature,
                                        top_k=top_k,
                                        top_p=top_p)

        saver = tf.train.Saver()
        ckpt = tf.train.latest_checkpoint(os.path.join(models_dir, model_name))
        saver.restore(sess, ckpt)

        f = open(filePath, 'w')

        for c in range(1, nconcepts + 1):
            for p in range(1, nphases + 1):
                for t in range(1, ntest + 1):
                    pos = (c - 1) * nphases * ntest + (t - 1)
                    for r in range(1, nrepeat + 1):
                        raw_text = inputGPT2[pos]
                        context_tokens = enc.encode(raw_text)
                        for _ in range(nsamples // batch_size):
                            out = sess.run(output,
                                           feed_dict={
                                               context: [
                                                   context_tokens
                                                   for _ in range(batch_size)
                                               ]
                                           })[:, len(context_tokens):]
                            for i in range(batch_size):
                                text = enc.decode(out[i])
                                f.write("Concept " + str(c) + " Phase " +
                                        str(p) + " Test " + str(t) +
                                        " Repetition " + str(r) +
                                        " // Input= " + raw_text +
                                        " // Output= " + text)
Example #17
    def generate(self,
                 sess,
                 return_as_list=False,
                 truncate=None,
                 destination_path=None,
                 sample_delim='=' * 20 + '\n',
                 prefix=None,
                 seed=None,
                 batch_size=1,
                 nsamples=1,
                 length=1023,
                 temperature=0.7,
                 top_k=0,
                 run_name='run1',
                 include_prefix=True):
        """
        Generates text from a model loaded into memory.
        Adapted from https://github.com/openai/gpt-2/blob/master/src/interactive_conditional_samples.py
        """

        if batch_size is None:
            batch_size = 1
        assert nsamples % batch_size == 0

        if nsamples == 1:
            sample_delim = ''

        if prefix:
            context = tf.placeholder(tf.int32, [batch_size, None])

        CHECKPOINT_DIR = 'checkpoint'

        checkpoint_path = os.path.join(model_path, CHECKPOINT_DIR, run_name)

        enc = encoder.get_encoder(checkpoint_path)
        hparams = model.default_hparams()
        with open(os.path.join(checkpoint_path, 'hparams.json')) as f:
            hparams.override_from_dict(json.load(f))

        np.random.seed(seed)
        tf.set_random_seed(seed)

        output = model.sample_sequence(
            hparams=hparams,
            length=length,
            start_token=enc.encoder['<|endoftext|>'] if not prefix else None,
            context=context if prefix else None,
            batch_size=batch_size,
            temperature=temperature,
            top_k=top_k)[:, 1:]

        if destination_path:
            f = open(destination_path, 'w')
        if prefix:
            context_tokens = enc.encode(prefix)
        generated = 0
        gen_texts = []
        while generated < nsamples:
            if not prefix:
                out = sess.run(output)
            else:
                out = sess.run(
                    output, feed_dict={context: batch_size * [context_tokens]})
            for i in range(batch_size):
                generated += 1
                gen_text = enc.decode(out[i])
                if prefix:
                    gen_text = enc.decode([context_tokens[0]]) + gen_text
                if truncate:
                    truncate_esc = re.escape(truncate)
                    if prefix and not include_prefix:
                        prefix_esc = re.escape(prefix)
                        pattern = '(?:{})(.*?)(?:{})'.format(
                            prefix_esc, truncate_esc)
                    else:
                        pattern = '(.*?)(?:{})'.format(truncate_esc)

                    trunc_text = re.search(pattern, gen_text, re.S)
                    if trunc_text:
                        gen_text = trunc_text.group(1)
                if destination_path:
                    f.write("{}\n{}".format(gen_text, sample_delim))
                if not return_as_list and not destination_path:
                    print("{}\n{}".format(gen_text, sample_delim))
                gen_texts.append(gen_text)

        if destination_path:
            f.close()

        if return_as_list:
            return gen_texts
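A hedged usage sketch for the generate() method above: gen is assumed to be an instance of whatever class defines this method, and sess a tf.Session in which its graph and fine-tuned checkpoint have already been loaded.

texts = gen.generate(sess,
                     return_as_list=True,
                     nsamples=2,
                     prefix='Once upon a time',
                     truncate='<|endoftext|>',
                     length=200,
                     temperature=0.7,
                     run_name='run1')
for t in texts:
    print(t)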
Example #18
def predict(sess,
            data,
            run_name,
            batch_size,
            num_categories,
            category_names,
            model_path="checkpoint/"):

    model_path = os.path.join(model_path, run_name)

    # Load Hyperparams from model
    hparams = model.default_hparams()
    if os.path.exists(model_path + "/hparams.json"):
        with open(os.path.join(model_path, 'hparams.json')) as f:
            hparams.override_from_dict(json.load(f))

    d_shape = np.shape(data)
    print("Precicting for data: " + str(d_shape))
    hparams.n_timestep = d_shape[1]
    hparams.n_freq = d_shape[2]
    hparams.n_cat = num_categories

    # Create TF graph
    inp_specs = tf.placeholder(
        tf.float32, [batch_size, hparams.n_timestep, hparams.n_freq])
    prediction = model.model(hparams, inp_specs)

    # Get Model vars
    all_vars = [v for v in tf.trainable_variables() if 'model' in v.name]
    saver = tf.train.Saver(var_list=all_vars)
    sess.run(tf.global_variables_initializer())
    ckpt = tf.train.latest_checkpoint(model_path)
    saver.restore(sess, ckpt)

    predictions = np.zeros((len(data), num_categories))
    num_batches = int(np.ceil(len(data) / batch_size))

    for i in tqdm(range(num_batches)):
        c = batch_size

        if i * batch_size + c > len(data):
            add = (i * batch_size + c) - len(data)
            pred = sess.run(prediction,
                            feed_dict={
                                inp_specs:
                                np.concatenate((data[i * batch_size:],
                                                np.zeros(
                                                    (add, hparams.n_timestep,
                                                     hparams.n_freq)))),
                                "model/drop:0":
                                1.0
                            })['logits']
            predictions[i * batch_size:] = pred[:-add]
        else:

            predictions[i*batch_size: i*batch_size+c] =\
                sess.run(prediction, feed_dict={inp_specs: data[i*batch_size: i*batch_size+batch_size],
                                                "model/drop:0": 1.0})['logits']

    cats = np.argmax(predictions, axis=1)

    return {
        "raw": predictions,
        "category": cats,
        "predictName": ["N", "S", "V", "F", "Q"],
        "names": category_names
    }
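A usage sketch for predict(): the spectrogram batch below is random placeholder data with made-up dimensions, and a trained checkpoint is assumed to exist under checkpoint/<run_name>.

import numpy as np
import tensorflow as tf

fake_specs = np.random.rand(8, 128, 64).astype(np.float32)   # 8 spectrograms, 128x64

with tf.Session() as sess:
    result = predict(sess,
                     data=fake_specs,
                     run_name='run1',
                     batch_size=4,
                     num_categories=5,
                     category_names=['N', 'S', 'V', 'F', 'Q'])
    print(result['category'])   # predicted class index for each spectrogram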