Example #1
def train(session,
          model,
          length_from=3,
          length_to=8,
          vocab_lower=2,
          vocab_upper=10,
          batch_size=100,
          max_batches=5000,
          batches_in_epoch=1000,
          verbose=True,
          input_keep_prob=1,
          output_keep_prob=1,
          state_keep_prob=1):

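    # these presumably feed DropoutWrapper-style placeholders inside the model;
    # a keep probability of 1 keeps every unit, i.e. disables dropout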
    print(input_keep_prob)
    batches = helpers.random_sequences(length_from=length_from,
                                       length_to=length_to,
                                       vocab_lower=vocab_lower,
                                       vocab_upper=vocab_upper,
                                       batch_size=batch_size)
    loss_track = []
    try:
        for batch in range(max_batches + 1):
            batch_data = next(batches)
            fd = model.make_train_inputs_II(batch_data, batch_data,
                                            input_keep_prob, output_keep_prob,
                                            state_keep_prob)
            _, l = session.run([model.train_op, model.loss], fd)
            loss_track.append(l)

            if verbose:
                if batch == 0 or batch % batches_in_epoch == 0:
                    print('batch {}'.format(batch))
                    print('  minibatch loss: {}'.format(
                        session.run(model.loss, fd)))
                    for i, (e_in, dt_pred) in enumerate(
                            zip(
                                fd[model.encoder_inputs],
                                session.run(model.decoder_prediction_train,
                                            fd))):
                        print('  sample {}:'.format(i + 1))
                        print('    enc input           > {}'.format(e_in))
                        print('    dec train predicted > {}'.format(dt_pred))
                        if i >= 2:
                            break
                    print()
    except KeyboardInterrupt:
        print('training interrupted')

    return loss_track
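
All of these snippets depend on a helpers module that is never shown. Below is a minimal sketch of what helpers.random_sequences presumably does, inferred only from the keyword arguments used above; the actual implementation may differ:

import numpy as np

def random_sequences(length_from, length_to, vocab_lower, vocab_upper, batch_size):
    # Endlessly yield batches of `batch_size` random integer sequences whose
    # lengths fall in [length_from, length_to] and whose symbols fall in
    # [vocab_lower, vocab_upper). Assumed behavior, inferred from the callers.
    if length_from > length_to:
        raise ValueError('length_from > length_to')
    while True:
        yield [
            np.random.randint(vocab_lower, vocab_upper,
                              size=np.random.randint(length_from, length_to + 1)).tolist()
            for _ in range(batch_size)
        ]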
Example #2
def train_on_copy_task(session,
                       model,
                       length_from=3,
                       length_to=8,
                       vocab_lower=2,
                       vocab_upper=10,
                       batch_size=100,
                       max_batches=5000,
                       batches_in_epoch=1000,
                       verbose=True):

    batches = helpers.random_sequences(length_from=length_from,
                                       length_to=length_to,
                                       vocab_lower=vocab_lower,
                                       vocab_upper=vocab_upper,
                                       batch_size=batch_size)
    loss_track = []
    try:
        for batch in range(max_batches + 1):
            batch_data = next(batches)
            fd = model.make_train_inputs(batch_data, batch_data)
            _, l = session.run([model.train_op, model.loss], fd)
            loss_track.append(l)

            if verbose:
                if batch == 0 or batch % batches_in_epoch == 0:
                    print('batch {}'.format(batch))
                    print('  minibatch loss: {}'.format(
                        session.run(model.loss, fd)))
                    for i, (e_in, dt_pred) in enumerate(
                            zip(
                                fd[model.encoder_inputs].T,
                                session.run(model.decoder_prediction_train,
                                            fd).T)):
                        print('  sample {}:'.format(i + 1))
                        print('    enc input           > {}'.format(e_in))
                        print('    dec train predicted > {}'.format(dt_pred))
                        if i >= 2:
                            break
                    print()

        print("Doing inference with trained model")
        fd = model.make_inference_inputs([[5, 4, 6, 7], [6, 6]])
        inf_out = session.run(model.decoder_prediction_inference, fd)
        print(inf_out)
    except KeyboardInterrupt:
        print('training interrupted')

    return loss_track
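
A minimal driver for train_on_copy_task, assuming a Seq2SeqModel class whose constructor matches make_model in Example #10 below; the argument values here are illustrative only:

import tensorflow as tf
from tensorflow.contrib.rnn import LSTMCell

tf.reset_default_graph()
with tf.Session() as session:
    # hypothetical constructor; mirrors make_model() in Example #10
    model = Seq2SeqModel(encoder_cell=LSTMCell(10),
                         decoder_cell=LSTMCell(10),
                         vocab_size=10,
                         embedding_size=10,
                         attention=False,
                         bidirectional=False,
                         debug=False)
    session.run(tf.global_variables_initializer())
    loss_track = train_on_copy_task(session, model, max_batches=2000)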
Example #3
    def train_on_copy_task(self,
                           length_from=3,
                           length_to=8,
                           vocab_lower=3,
                           vocab_upper=10,
                           batch_size=64,
                           max_batches=5000,
                           batches_in_epoch=1000,
                           verbose=True):
        """ Feed small inputs into the seq2seq graph to ensure it is functioning
            correctly. Only used in the early stages of the project for debugging
        """
        batches = helpers.random_sequences(length_from=length_from,
                                           length_to=length_to,
                                           vocab_lower=vocab_lower,
                                           vocab_upper=vocab_upper,
                                           batch_size=batch_size)
        loss_track = []
        try:
            for batch in range(max_batches + 1):
                batch_data = next(batches)
                fd = self.make_train_inputs(batch_data, batch_data)
                _, l = self.session.run([self.train_op, self.loss], fd)
                loss_track.append(l)

                if verbose:
                    if batch == 0 or batch % batches_in_epoch == 0:
                        print('batch {}'.format(batch))
                        print('  minibatch loss: {}'.format(
                            self.session.run(self.loss, fd)))
                        for i, (e_in, dt_pred) in enumerate(
                                zip(
                                    fd[self.encoder_inputs].T,
                                    self.session.run(
                                        self.decoder_prediction_train, fd).T)):
                            print('  sample {}:'.format(i + 1))
                            print('    enc input           > {}'.format(e_in))
                            print(
                                '    dec train predicted > {}'.format(dt_pred))
                            if i >= 2:
                                break
                        print()
        except KeyboardInterrupt:
            print('training interrupted')

        return loss_track
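
make_train_inputs is referenced throughout but never shown. Since the feed dict is keyed by self.encoder_inputs and the printed samples are transposed with .T (i.e. the tensors are time-major), it plausibly resembles the sketch below; every attribute name is an assumption, and helpers.batch is sketched after Example #5:

def make_train_inputs(self, input_seq, target_seq):
    # pad both sequences into time-major [max_time, batch] matrices (assumed)
    inputs_, inputs_length_ = helpers.batch(input_seq)
    targets_, targets_length_ = helpers.batch(
        [seq + [self.EOS] for seq in target_seq])
    return {
        self.encoder_inputs: inputs_,
        self.encoder_inputs_length: inputs_length_,
        self.decoder_targets: targets_,
        self.decoder_targets_length: targets_length_,
    }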
Example #4
    def train(self, length, vocab, batches, directory):
        help_batch = helpers.random_sequences(length_from=length['from'],
                                              length_to=length['to'],
                                              vocab_lower=vocab['lower'],
                                              vocab_upper=vocab['size'],
                                              batch_size=batches['size'])

        saver = tf.train.Saver(self.seq2seq_vars)
        loss_track = []
        for batch in range(batches['max'] + 1):
            seq_batch = next(help_batch)
            fd = self.make_train_inputs(seq_batch, seq_batch)
            _, loss, state = self.sess.run(
                [self.train_op, self.loss, self.encoder_final_state[0]], fd)
            loss_track.append(loss)
            print('\rBatch {}/{}\tloss: {}\tshape: {}'.format(
                batch, batches['max'], loss, state.shape),
                  end="")

        print('\nLoss {:.4f} after {} examples (batch_size={})'.format(
            loss_track[-1],
            len(loss_track) * batches['size'], batches['size']))
        path = saver.save(self.sess, directory + '/seq2seq.ckpt')
        print("Trained model saved to {}".format(path))
Example #5

batch_size = 10
# def batch_generator():
#     while True:
#         yield [encode_str(train_mails[1][0]) for x in range(batch_size)]

batches = helpers.random_sequences(length_from=4,
                                   length_to=8,
                                   vocab_lower=50,
                                   vocab_upper=vocab_size,
                                   batch_size=batch_size)


def next_feed():
    batch = next(batches)
    encoder_inputs_, _ = helpers.batch(batch)
    decoder_targets_, _ = helpers.batch(batch)
    decoder_inputs_, _ = helpers.batch(batch)
    return {
        encoder_inputs: encoder_inputs_,
        decoder_inputs: decoder_inputs_,
        decoder_targets: decoder_targets_,
    }
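
helpers.batch pads a list of sequences into a single matrix. The sketch below is consistent with its use here, with the max_sequence_length keyword passed in Example #9, and with the .T transposes elsewhere, which imply a time-major layout; the real implementation is an assumption:

import numpy as np

def batch(inputs, max_sequence_length=None):
    # pad integer sequences with zeros (PAD) into a time-major
    # [max_time, batch_size] int32 matrix; also return the true lengths
    sequence_lengths = [len(seq) for seq in inputs]
    if max_sequence_length is None:
        max_sequence_length = max(sequence_lengths)
    padded = np.zeros((len(inputs), max_sequence_length), dtype=np.int32)
    for i, seq in enumerate(inputs):
        padded[i, :len(seq)] = seq
    return padded.swapaxes(0, 1), sequence_lengths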
Example #6
print('decoder inputs:\n' + str(din_))

pred_ = sess.run(decoder_prediction,
                 feed_dict={
                     encoder_inputs: batch_,
                     decoder_inputs: din_,
                 })

# Training on the toy task

batch_size = 100

batches = helpers.random_sequences(length_from=3,
                                   length_to=8,
                                   vocab_lower=2,
                                   vocab_upper=10,
                                   batch_size=batch_size)

print('head of the batch:')
for seq in next(batches)[:10]:
    print(seq)


def next_feed():
    batch = next(batches)
    encoder_inputs_, _ = helpers.batch(batch)
    decoder_targets_, _ = helpers.batch([(sequence) + [EOS]
                                         for sequence in batch])
    decoder_inputs_, _ = helpers.batch([[EOS] + (sequence)
                                        for sequence in batch])
    # the snippet breaks off here; the feed dict mirrors Examples #5 and #9
    return {
        encoder_inputs: encoder_inputs_,
        decoder_inputs: decoder_inputs_,
        decoder_targets: decoder_targets_,
    }
Example #7
def training_model(word_size, max_length, feature_length):
  tf.set_random_seed(9487)
  vocab_size = word_size
  encoder_input_size = feature_length
  decoder_output_size = max_length
  encoder_hidden_units = 20
  decoder_hidden_units = 20
  rnn_size = 64

  encoder_inputs = tf.placeholder(shape=(None, None), dtype=tf.int32, name='encoder_inputs')
  encoder_inputs_length = tf.placeholder(shape=(None,), dtype=tf.int32, name='encoder_inputs_length')
  decoder_targets = tf.placeholder(shape=(None, None), dtype=tf.int32, name='decoder_targets')

  embeddings = tf.Variable(tf.random_uniform([vocab_size, encoder_input_size], -1.0, 1.0), dtype=tf.float32)
  encoder_inputs_embedded = tf.nn.embedding_lookup(embeddings, encoder_inputs)

  encoder_cell = LSTMCell(encoder_hidden_units)

  encoder_outputs, encoder_final_state = (tf.nn.dynamic_rnn(cell=encoder_cell,
                      inputs=encoder_inputs_embedded, 
                      sequence_length=encoder_inputs_length,
                      dtype=tf.float32, time_major=True))


  decoder_cell = LSTMCell(decoder_hidden_units)
  encoder_max_time, batch_size = tf.unstack(tf.shape(encoder_inputs))
  decoder_lengths = encoder_inputs_length + 3

  attention_mechanism = tf.contrib.seq2seq.LuongAttention(
                rnn_size, 
                memory=encoder_outputs,
                memory_sequence_length=encoder_inputs_length)
  
  decoder_cell = tf.contrib.seq2seq.AttentionWrapper(
                cell=decoder_cell, 
                attention_mechanism=attention_mechanism, 
                attention_layer_size=rnn_size, 
                name='Attention_Wrapper')

  #weights; the projection width must match the AttentionWrapper output size
  #(attention_layer_size=rnn_size), not the raw decoder cell size
  W = tf.Variable(tf.random_uniform([rnn_size, vocab_size], -1, 1), dtype=tf.float32)
  #bias
  b = tf.Variable(tf.zeros([vocab_size]), dtype=tf.float32)

  eos_time = tf.ones([batch_size], dtype=tf.int32, name='EOS')
  pad_time = tf.zeros([batch_size], dtype=tf.int32, name='PAD')

  #retrieves rows of the params tensor, similar to integer array indexing in numpy;
  #embedding_lookup needs integer ids, so the (int32) decoder_targets are embedded
  #here as decoder inputs -- an assumption made to keep the graph runnable
  embeddings = tf.Variable(tf.random_uniform([vocab_size, decoder_output_size], -1.0, 1.0), dtype=tf.float32)
  decoder_inputs_embedded = tf.nn.embedding_lookup(embeddings, decoder_targets)

  decoder_outputs, decoder_final_state = (tf.nn.dynamic_rnn(cell=decoder_cell,
                      inputs=decoder_inputs_embedded, 
                      sequence_length= decoder_lengths,
                      dtype=tf.float32, time_major=True))

  decoder_max_steps, decoder_batch_size, decoder_dim = tf.unstack(tf.shape(decoder_outputs))
  decoder_outputs_flat = tf.reshape(decoder_outputs, (-1, decoder_dim))

  decoder_logits_flat = tf.add(tf.matmul(decoder_outputs_flat, W), b)
  decoder_logits = tf.reshape(decoder_logits_flat, (decoder_max_steps, decoder_batch_size, vocab_size))

  decoder_prediction = tf.argmax(decoder_logits, 2)

  stepwise_cross_entropy = tf.nn.softmax_cross_entropy_with_logits(
    labels=tf.one_hot(decoder_targets, depth=vocab_size, dtype=tf.float32),
    logits=decoder_logits,
  )
  #loss function
  loss = tf.reduce_mean(stepwise_cross_entropy)
  #train it 
  train_op = tf.train.AdamOptimizer().minimize(loss)
  # create the session and loss history used by the training loop below
  sess = tf.Session()
  sess.run(tf.global_variables_initializer())
  losses = []

  batch_size = 100

  batches = helpers.random_sequences(length_from=3, length_to=8,
                                   vocab_lower=2, vocab_upper=10,
                                   batch_size=batch_size)

  max_batches = 1972

  for batch in range(max_batches):
    # next_feed is assumed to build the feed dict; see the sketch after this example
    fd = next_feed(batches)
    _, l = sess.run([train_op, loss], fd)
    losses.append(l)

    if batch == 0 or batch % 100 == 0:
      print('batch {}'.format(batch))
      print('  minibatch loss: {}'.format(sess.run(loss, fd)))
      predict_ = sess.run(decoder_prediction, fd)
      for i, (inp, pred) in enumerate(zip(fd[encoder_inputs].T, predict_.T)):
        print('    predicted > {}'.format(pred))
        if i >= 2:
          break
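
Example #7 calls next_feed(batches) without defining it. The sketch below is one feed builder consistent with the placeholders above; it would need access to encoder_inputs, encoder_inputs_length, and decoder_targets (for instance by living inside training_model), and the padding to max_len + 3 matches decoder_lengths = encoder_inputs_length + 3:

def next_feed(batches):
    EOS, PAD = 1, 0  # id convention used in the other examples; assumed here
    batch = next(batches)
    max_len = max(len(seq) for seq in batch)
    encoder_inputs_, encoder_input_lengths_ = helpers.batch(batch)
    # the decoder runs encoder length + 3 steps, so pad targets to max_len + 3
    decoder_targets_, _ = helpers.batch(
        [seq + [EOS] + [PAD] * (max_len + 2 - len(seq)) for seq in batch])
    return {
        encoder_inputs: encoder_inputs_,
        encoder_inputs_length: encoder_input_lengths_,
        decoder_targets: decoder_targets_,
    }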
Example #8
def train():
    print("start training!!")

    restore = True

    args = sys.argv
    args = args[1:]

    for _i in range(int(len(args) / 2)):
        arg_idx = _i * 2
        val_idx = _i * 2 + 1

        arg, value = args[arg_idx], args[val_idx]

        if arg == '-r':
            restore = value

    print(restore)

    vocab_size = cf.vocab_size
    input_embedding_size = cf.input_embedding_size
    encoder_hidden_units = cf.encoder_hidden_units
    batch_size = cf.batch_size

    params = dict()
    params['vocab_size'] = vocab_size
    params['input_embedding_size'] = input_embedding_size
    params['encoder_hidden_units'] = encoder_hidden_units
    params['batch_size'] = batch_size

    model = Model(params)
    saver = tf.train.Saver()

    sentences = wp.get_sentences()

    max_encoder_length = 0
    for b in sentences:
        if len(b) > max_encoder_length:
            max_encoder_length = len(b)
    #batch = [[3,4,5,2,6,7,8,9],[6,7,3,4,5],[2,2,4,5,6]]
    max_decoder_length = max_encoder_length + 3

    encoder_input_list = list()
    encoder_input_length_list = list()
    decoder_target_list = list()
    decoder_length_list = list()


    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
        #with tf.Session(config=tf.ConfigProto(allow_soft_placement=True, log_device_placement=True)) as sess:
        sess.run(tf.global_variables_initializer())

        model_ckpt_file = './status/model.ckpt'

        print("&&&&&&&&&&&&&&&&&&&&&&&&&&")
        print("&&&&&&&&&&&&&&&&&&&&&&&&&&")
        print("&&&&&&&&&&&&&&&&&&&&&&&&&&")
        if restore == 'T':
            print("restoring.... ")
            saver.restore(sess, model_ckpt_file)
        else:
            print("not restoring....")

        PAD = 0
        EOS = 1

        batches = helpers.random_sequences(length_from=3,
                                           length_to=8,
                                           vocab_lower=2,
                                           vocab_upper=10,
                                           batch_size=batch_size)

        def next_feed():
            batch = next(batches)
            #batch = [[3,4,5,2,6,7,8,9],[6,7,3,4,5],[2,2,4,5,6]]
            max_decoder_length = max_encoder_length + 3

            encoder_inputs_, encoder_input_lengths_ = helpers.batch(batch)
            decoder_targets_, _ = helpers.batch(
                #[(sequence) + [EOS] + [PAD] * 2 for sequence in batch]
                [(sequence) + [EOS] + [PAD] *
                 (max_decoder_length - len(sequence) - 1)
                 for sequence in batch])

            return {
                model.encoder_inputs:
                encoder_inputs_,
                model.encoder_inputs_length:
                encoder_input_lengths_,
                model.decoder_targets:
                decoder_targets_,
                model.decoder_lengths:
                [max_decoder_length for v in encoder_input_lengths_]
            }

        def next_feed_word(start, end):
            batch = sentences[start:end]
            '''
			max_encoder_length = 0
			for b in batch:
				if len(b) > max_encoder_length:
					max_encoder_length = len(b)			
			#batch = [[3,4,5,2,6,7,8,9],[6,7,3,4,5],[2,2,4,5,6]]
			max_decoder_length = max_encoder_length + 3
			'''
            encoder_inputs_, encoder_input_lengths_ = helpers.batch(batch)
            decoder_targets_, _ = helpers.batch(
                #[(sequence) + [EOS] + [PAD] * 2 for sequence in batch]
                [(sequence) + [EOS] + [PAD] *
                 (max_decoder_length - len(sequence) - 1)
                 for sequence in batch])

            return {
                model.encoder_inputs:
                encoder_inputs_,
                model.encoder_inputs_length:
                encoder_input_lengths_,
                model.decoder_targets:
                decoder_targets_,
                model.decoder_lengths:
                [max_decoder_length for v in encoder_input_lengths_]
            }

        loss_track = []

        max_batches = 30
        batches_in_epoch = 5

        try:
            for e in range(max_batches):
                start_time_out = dt.datetime.now()
                #print(batch)
                print(e, " epoch start...")
                #fd = next_feed()
                for i in range(int(len(sentences) / batch_size)):
                    start_time = dt.datetime.now()

                    start = i * batch_size
                    end = start + batch_size
                    print("get data")
                    fd = next_feed_word(start, end)
                    print("batch processing...")
                    _, l = sess.run([model.train_op, model.loss], fd)
                    #_, l = sess.run([train_op_gd, loss], fd)
                    print("Took", str(
                        (dt.datetime.now() - start_time).seconds),
                          "seconds for batch", str(i), "of",
                          str(len(sentences) // batch_size))

                print("Took", str(
                    (dt.datetime.now() - start_time_out).seconds),
                      "seconds for epoch", str(e))
                if e == 0 or e % batches_in_epoch == 0:
                    print('e {}'.format(e))
                    print('  minibatch loss: {}'.format(
                        sess.run(model.loss, fd)))
                    predict_ = sess.run(model.decoder_prediction, fd)
                    for i, (inp, pred) in enumerate(
                            zip(fd[model.encoder_inputs].T, predict_.T)):
                        print('  sample {}:'.format(i + 1))
                        print('    input     > {}'.format(inp))
                        print('    predicted > {}'.format(pred))
                        if i >= 10:
                            break

                    saver.save(sess, model_ckpt_file)
                    print("model saved to ", model_ckpt_file)

        except KeyboardInterrupt:
            print('training interrupted')
Example #9
def main(_):
    vocab_size = 10
    input_embedding_size = 20

    encoder_hidden_units = 20
    decoder_hidden_units = 20

    with tf.name_scope('encoder_inputs'):
        encoder_inputs = tf.placeholder(shape=(None, None),
                                        dtype=tf.int32,
                                        name='encoder_inputs')
    with tf.name_scope('decoder_targets'):
        decoder_targets = tf.placeholder(shape=(None, None),
                                         dtype=tf.int32,
                                         name='decoder_targets')

    with tf.name_scope('decoder_inputs'):
        decoder_inputs = tf.placeholder(shape=(None, None),
                                        dtype=tf.int32,
                                        name='decoder_inputs')

    with tf.name_scope('embeddings'):
        embeddings = tf.Variable(tf.random_uniform(
            [vocab_size, input_embedding_size], -0.1, 1.0),
                                 dtype=tf.float32)

    with tf.name_scope('encoder_inputs_embedded'):
        encoder_inputs_embedded = tf.nn.embedding_lookup(
            embeddings, encoder_inputs)
    with tf.name_scope('decoder_inputs_embedded'):
        decoder_inputs_embedded = tf.nn.embedding_lookup(
            embeddings, decoder_inputs)

    with tf.name_scope('encoder_cell'):
        encoder_cell = tf.contrib.rnn.LSTMCell(encoder_hidden_units)
    with tf.name_scope('encoder_dynamic'):
        encoder_outputs, encoder_final_state = tf.nn.dynamic_rnn(
            encoder_cell,
            encoder_inputs_embedded,
            dtype=tf.float32,
            time_major=True,
        )
    del encoder_outputs

    with tf.name_scope('decoder_cell'):
        decoder_cell = tf.contrib.rnn.LSTMCell(decoder_hidden_units)
    with tf.name_scope('decoder_dynamic'):
        decoder_outputs, decoder_final_state = tf.nn.dynamic_rnn(
            decoder_cell,
            decoder_inputs_embedded,
            initial_state=encoder_final_state,
            dtype=tf.float32,
            time_major=True,
            scope="plain_decoder",
        )

    with tf.name_scope('decoder_logits'):
        decoder_logits = tf.contrib.layers.linear(decoder_outputs, vocab_size)
    with tf.name_scope('decoder_prediction'):
        decoder_prediction = tf.argmax(decoder_logits, 2)

    with tf.name_scope('stepwise_cross_entropy'):
        stepwise_cross_entropy = tf.nn.softmax_cross_entropy_with_logits(
            labels=tf.one_hot(decoder_targets,
                              depth=vocab_size,
                              dtype=tf.float32),
            logits=decoder_logits,
        )

    with tf.name_scope('loss'):
        loss = tf.reduce_mean(stepwise_cross_entropy)
    with tf.name_scope('train_op'):
        train_op = tf.train.AdamOptimizer().minimize(loss)


    sv = tf.train.Supervisor(logdir=FLAGS.save_path)

    with sv.managed_session() as session:
        batch_ = [[6], [3, 4], [9, 8, 7]]
        batch_, batch_length_ = helpers.batch(batch_)
        #print('batch_encoded:\n' + str(batch_))

        din_, dlen_ = helpers.batch(np.ones(shape=(3, 1), dtype=np.int32),
                                    max_sequence_length=4)

        #print('decoder inputs:\n' + str(din_))

        pred_ = session.run(decoder_prediction,
                            feed_dict={
                                encoder_inputs: batch_,
                                decoder_inputs: din_,
                            })

        if FLAGS.save_path:
            sv.saver.save(session, FLAGS.save_path, global_step=sv.global_step)

        #print('decoder predictions:\n' + str(pred_))

    batch_size = 100
    batches = helpers.random_sequences(length_from=3,
                                       length_to=8,
                                       vocab_lower=2,
                                       vocab_upper=10,
                                       batch_size=batch_size)

    print('head of the batch:')

    #   for seq in next(batches)[:10]:
    #       print(seq)

    def next_feed():
        batch = next(batches)
        #       print('nex_feed batch EOS:{}'.format(EOS))
        for seq in batch:
            print(seq)
        encoder_inputs_, _ = helpers.batch(batch)
        #       print('encode_input {}'.format(encoder_inputs_))
        decoder_targets_, _ = helpers.batch([((sequence) + [EOS])
                                             for sequence in batch])
        #       print('decoder_targets_{}'.format(decoder_targets_))
        decoder_inputs_, _ = helpers.batch([([EOS] + (sequence))
                                            for sequence in batch])
        #       print('decode_input {}'.format(decoder_inputs_))
        return {
            encoder_inputs: encoder_inputs_,
            decoder_inputs: decoder_inputs_,
            decoder_targets: decoder_targets_,
        }

    loss_track = []

    max_batches = 3001
    batches_in_epoch = 1000

    # create a session for the training loop (the managed_session above has closed)
    sess = tf.Session()
    sess.run(tf.global_variables_initializer())

    try:
        for batch in range(max_batches):
            fd = next_feed()
            _, l = sess.run([train_op, loss], fd)
            loss_track.append(l)

            if batch == 0 or batch % batches_in_epoch == 0:
                #               print('batch{}'.format(batch))
                #               print('minibatch loss: {}'.format(sess.run(loss, fd)))
                predict_ = sess.run(decoder_prediction, fd)
                for i, (inp, pred) in enumerate(
                        zip(fd[encoder_inputs].T, predict_.T)):
                    print('  sample {}:'.format(i + 1))
                    print('    input     > {}'.format(inp))
                    print('    predicted > {}'.format(pred))
                    if i >= 2:
                        break
    except KeyboardInterrupt:
        print('training interrupted')
    plt.plot(loss_track)
    #   plt.show()
    print('loss {:.4f} after {} examples (batch_size={})'.format(
        loss_track[-1],
        len(loss_track) * batch_size, batch_size))
Example #10
def train(id_, inv_, x_):
    agg_ops = ['', 'MAX', 'MIN', 'COUNT', 'SUM', 'AVG']
    cond_ops = ['=', '>', '<', 'OP']
    syms = [
        'SELECT', 'WHERE', 'AND', 'COL', 'TABLE', 'CAPTION', 'PAGE', 'SECTION',
        'OP', 'COND', 'QUESTION', 'AGG', 'AGGOPS', 'CONDOPS'
    ]

    PAD = 0
    EOS = 1

    vocab_size = len(id_)  # 17
    input_embedding_size = 20

    encoder_hidden_units = 20
    decoder_hidden_units = encoder_hidden_units  #*2

    encoder_inputs = tf.placeholder(shape=(None, None),
                                    dtype=tf.int32,
                                    name='encoder_inputs')
    encoder_inputs_length = tf.placeholder(shape=(None, ),
                                           dtype=tf.int32,
                                           name='encoder_inputs_length')

    decoder_targets = tf.placeholder(shape=(None, None),
                                     dtype=tf.int32,
                                     name='decoder_targets')

    embeddings = tf.Variable(tf.random_uniform(
        [vocab_size, input_embedding_size], -1.0, 1.0),
                             dtype=tf.float32)

    encoder_inputs_embedded = tf.nn.embedding_lookup(embeddings,
                                                     encoder_inputs)

    encoder_cell = tf.contrib.rnn.LSTMCell(encoder_hidden_units)

    ################################## attention model #############################
    loss_tracks = dict()

    def do_train(session, model):
        return train_on_copy_task(session,
                                  model,
                                  length_from=3,
                                  length_to=8,
                                  vocab_lower=2,
                                  vocab_upper=10,
                                  batch_size=100,
                                  max_batches=5000,
                                  batches_in_epoch=1000,
                                  verbose=False)

    def make_model(**kwa):
        args = dict(cell_class=LSTMCell,
                    num_units_encoder=10,
                    vocab_size=10,
                    embedding_size=10,
                    attention=False,
                    bidirectional=False,
                    debug=False)

        args.update(kwa)
        cell_class = args.pop('cell_class')
        num_units_encoder = args.pop('num_units_encoder')
        num_units_decoder = num_units_encoder

        if args['bidirectional']:
            num_units_decoder *= 2

        args['encoder_cell'] = cell_class(num_units_encoder)
        args['decoder_cell'] = cell_class(num_units_decoder)
        return Seq2SeqModel(**args)

    tf.reset_default_graph()
    tf.set_random_seed(1)
    with tf.Session() as session:
        model = make_model(bidirectional=False, attention=True)
        session.run(tf.global_variables_initializer())
        loss_tracks['forward encoder, with attention'] = do_train(
            session, model)

    #################################### train #########################################
    batch_size = 4
    batches = helpers.random_sequences(length_from=3,
                                       length_to=8,
                                       vocab_lower=2,
                                       vocab_upper=10,
                                       batch_size=batch_size)
    sql_batches = [i[1] for i in x_[1:20]]

    print('head of the batch:')
    for seq in sql_batches[1:5]:
        print([id_[s] for s in seq])

    def next_feed():
        batch = sql_batches
        encoder_inputs_, encoder_input_lengths_ = helpers.batch(batch)
        decoder_targets_, _ = helpers.batch([(sequence) + [EOS] + [PAD] * 2
                                             for sequence in batch])
        return {
            encoder_inputs: encoder_inputs_,
            encoder_inputs_length: encoder_input_lengths_,
            decoder_targets: decoder_targets_,
        }

    # train_op, loss, decoder_prediction, and sess are assumed to be built on
    # top of encoder_cell as in Example #9; this snippet omits that part
    loss_track = []
    max_batches = 200
    batches_in_epoch = 20
    try:
        for batch in range(max_batches):
            fd = next_feed()
            _, l = sess.run([train_op, loss], fd)
            loss_track.append(l)

            if batch == 0 or batch % batches_in_epoch == 0:
                print('batch {}'.format(batch))
                print('  minibatch loss: {}'.format(sess.run(loss, fd)))
                predict_ = sess.run(decoder_prediction, fd)
                for i, (inp, pred) in enumerate(
                        zip(fd[encoder_inputs].T, predict_.T)):
                    print('  sample {}:'.format(i + 1))
                    print('    input     > {}'.format([id_[q] for q in inp]))
                    print('    predicted > {}'.format([id_[q] for q in pred]))
                    if i >= 2:
                        break
                print()
    except KeyboardInterrupt:
        print('training interrupted')
    plt.plot(loss_track)
    plt.savefig('loss_track.png')
    print('loss {:.4f} after {} examples (batch_size={})'.format(
        loss_track[-1],
        len(loss_track) * batch_size, batch_size))
def train_on_copy_task(session, model,
                       length_from=3, length_to=8,
                       vocab_lower=2, vocab_upper=10,
                       batch_size=100,
                       max_batches=5000,
                       batches_in_epoch=1000,
                       verbose=True):

    batches = helpers.random_sequences(
        length_from=length_from, 
        length_to=length_to,
        vocab_lower=vocab_lower, 
        vocab_upper=vocab_upper,
        batch_size=batch_size)
    print(batches)
    loss_track = []
    try:
        for batch in range(max_batches+1):
            batch_data = next(batches)
            # print(len(batch_data))
            fd = model.make_train_inputs(batch_data, batch_data)
            # debugging probes, left commented out:
            # print(np.shape(session.run([model.decoder_logits_train], fd)))
            # print(np.shape(session.run([model.PAD_SLICE], fd)))
            # logits, targets, weights = session.run(
            #     [model.logits, model.targets, model.loss_weights], fd)
            # print(np.shape(logits), np.shape(targets), np.shape(weights))
            # print(np.shape(session.run(
            #     [model.decoder_outputs_train.sample_id], fd)))

            _, l = session.run([model.train_op, model.loss], fd)
            loss_track.append(l)
            
            if verbose:
                if batch == 0 or batch % batches_in_epoch == 0:
                    print('batch {}'.format(batch))
                    print('  minibatch loss: {}'.format(session.run(model.loss, fd)))
                    for i, (e_in, dt_pred) in enumerate(zip(
                            fd[model.encoder_inputs].T,
                            session.run(model.decoder_prediction_train, fd).T
                        )):
                        print('  sample {}:'.format(i + 1))
                        print('    enc input           > {}'.format(e_in))
                        print('    dec train predicted > {}'.format(dt_pred))
                        if i >= 2:
                            break
                    print()
    except KeyboardInterrupt:
        print('training interrupted')

    return loss_track
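
As a quick usage note, the returned loss history plots directly, just as Example #10 does; session and model here are as in the driver sketch after Example #2:

import matplotlib.pyplot as plt

loss_track = train_on_copy_task(session, model, max_batches=2000)
plt.plot(loss_track)
plt.xlabel('batch')
plt.ylabel('minibatch loss')
plt.savefig('loss_track.png')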