示例#1
0
def train():
    """Train a BiRNN text classifier, evaluating and checkpointing each epoch.

    The training file named by ``FLAGS`` is used to build the dictionary and
    the training batches; the test file is batched with the dictionary loaded
    from ``FLAGS.data_dir + '/dictionary'``.  ``FLAGS.vocab_size`` and
    ``FLAGS.n_classes`` are overwritten with the values reported by the
    training loader.

    Every epoch sets the learning rate to
    ``FLAGS.learning_rate * FLAGS.decay_rate ** epoch``, runs one optimizer
    step per training batch, prints the accuracy of every test batch and
    saves a checkpoint under ``FLAGS.save_dir``.
    """
    train_path = FLAGS.data_dir + '/' + FLAGS.train_file
    test_path = FLAGS.data_dir + '/' + FLAGS.test_file

    data_loader = InputHelper()
    data_loader.create_dictionary(train_path, FLAGS.data_dir + '/')
    data_loader.create_batches(train_path, FLAGS.batch_size,
                               FLAGS.sequence_length)
    FLAGS.vocab_size = data_loader.vocab_size
    FLAGS.n_classes = data_loader.n_classes

    test_data_loader = InputHelper()
    test_data_loader.load_dictionary(FLAGS.data_dir + '/dictionary')
    # The test set is evaluated in batches of 1000 examples.
    test_data_loader.create_batches(test_path, 1000, FLAGS.sequence_length)

    model = BiRNN(FLAGS.rnn_size, FLAGS.layer_size, FLAGS.vocab_size,
                  FLAGS.batch_size, FLAGS.sequence_length, FLAGS.n_classes,
                  FLAGS.grad_clip)

    # Build the learning-rate update once.  Calling tf.assign() inside the
    # epoch loop would add a new op (and constant) to the graph every epoch.
    new_lr = tf.placeholder(model.lr.dtype.base_dtype, shape=[])
    update_lr = tf.assign(model.lr, new_lr)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        saver = tf.train.Saver(tf.global_variables())
        checkpoint_path = os.path.join(FLAGS.save_dir, 'model.ckpt')

        for e in range(FLAGS.num_epochs):
            data_loader.reset_batch()
            num_batches = data_loader.num_batches
            total_steps = FLAGS.num_epochs * num_batches

            # Exponential learning-rate decay, applied once per epoch.
            decayed_lr = FLAGS.learning_rate * (FLAGS.decay_rate ** e)
            sess.run(update_lr, feed_dict={new_lr: decayed_lr})

            for b in range(num_batches):
                start = time.time()
                x, y = data_loader.next_batch()
                feed = {
                    model.input_data: x,
                    model.targets: y,
                    model.output_keep_prob: FLAGS.dropout_keep_prob
                }
                train_loss, _ = sess.run([model.cost, model.train_op],
                                         feed_dict=feed)
                end = time.time()
                print("{}/{} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}"
                      .format(e * num_batches + b, total_steps, e,
                              train_loss, end - start))

            # Evaluate on the test set with dropout disabled.
            test_data_loader.reset_batch()
            for _ in range(test_data_loader.num_batches):
                test_x, test_y = test_data_loader.next_batch()
                feed = {
                    model.input_data: test_x,
                    model.targets: test_y,
                    model.output_keep_prob: 1.0
                }
                accuracy = sess.run(model.accuracy, feed_dict=feed)
                print('accuracy:{0}'.format(accuracy))

            # The checkpoint is tagged with the step count at the START of
            # this epoch, matching the existing checkpoint naming.
            saver.save(sess,
                       checkpoint_path,
                       global_step=e * num_batches)
            print('model saved to {}'.format(checkpoint_path))
def train():
    """Train a BiRNN classifier with TensorBoard logging and checkpoints.

    Steps performed:

    * build the dictionary and training batches from the training file and
      batch the test file (batch size 100) with the stored dictionary;
    * optionally load pre-trained embeddings from ``FLAGS.pre_trained_vec``,
      which overrides ``FLAGS.vocab_size`` and ``FLAGS.embedding_size`` with
      the matrix shape;
    * optionally restore the latest checkpoint found in ``FLAGS.init_from``;
    * run ``FLAGS.num_epochs`` epochs, writing a summary every 20 steps and a
      checkpoint every ``FLAGS.save_steps`` steps, and print the mean test
      accuracy after each epoch.

    Raises:
        AssertionError: if ``FLAGS.init_from`` is set but is not a directory
            or holds no usable checkpoint.
    """
    train_path = FLAGS.data_dir + '/' + FLAGS.train_file
    test_path = FLAGS.data_dir + '/' + FLAGS.test_file

    data_loader = InputHelper()
    data_loader.create_dictionary(train_path, FLAGS.data_dir + '/')
    data_loader.create_batches(train_path, FLAGS.batch_size,
                               FLAGS.sequence_length)
    FLAGS.vocab_size = data_loader.vocab_size
    FLAGS.n_classes = data_loader.n_classes
    FLAGS.num_batches = data_loader.num_batches

    test_data_loader = InputHelper()
    test_data_loader.load_dictionary(FLAGS.data_dir + '/dictionary')
    test_data_loader.create_batches(test_path, 100, FLAGS.sequence_length)

    if FLAGS.pre_trained_vec:
        embeddings = np.load(FLAGS.pre_trained_vec)
        print(embeddings.shape)
        # The embedding matrix dictates the model's vocabulary/vector sizes.
        FLAGS.vocab_size = embeddings.shape[0]
        FLAGS.embedding_size = embeddings.shape[1]

    if FLAGS.init_from is not None:
        assert os.path.isdir(FLAGS.init_from), '{} must be a directory'.format(
            FLAGS.init_from)
        ckpt = tf.train.get_checkpoint_state(FLAGS.init_from)
        assert ckpt, 'No checkpoint found'
        assert ckpt.model_checkpoint_path, 'No model path found in checkpoint'

    # Define specified Model
    model = BiRNN(embedding_size=FLAGS.embedding_size,
                  rnn_size=FLAGS.rnn_size,
                  layer_size=FLAGS.layer_size,
                  vocab_size=FLAGS.vocab_size,
                  attn_size=FLAGS.attn_size,
                  sequence_length=FLAGS.sequence_length,
                  n_classes=FLAGS.n_classes,
                  grad_clip=FLAGS.grad_clip,
                  learning_rate=FLAGS.learning_rate)

    # Scalars reported to TensorBoard.
    tf.summary.scalar('train_loss', model.cost)
    tf.summary.scalar('accuracy', model.accuracy)
    merged = tf.summary.merge_all()

    # Let TensorFlow grow its GPU memory allocation on demand.
    tf_config = tf.ConfigProto()
    tf_config.gpu_options.allow_growth = True

    with tf.Session(config=tf_config) as sess:
        train_writer = tf.summary.FileWriter(FLAGS.log_dir, sess.graph)
        try:
            sess.run(tf.global_variables_initializer())
            saver = tf.train.Saver(tf.global_variables())

            # Overwrite the embedding variable with the pre-trained matrix.
            if FLAGS.pre_trained_vec:
                sess.run(model.embedding.assign(embeddings))
                del embeddings

            # Restoring happens last, so checkpoint values take precedence.
            if FLAGS.init_from is not None:
                saver.restore(sess, ckpt.model_checkpoint_path)

            checkpoint_path = os.path.join(FLAGS.save_dir, 'model.ckpt')
            total_steps = FLAGS.num_epochs * FLAGS.num_batches
            for e in range(FLAGS.num_epochs):
                data_loader.reset_batch()
                for b in range(FLAGS.num_batches):
                    start = time.time()
                    x, y = data_loader.next_batch()
                    feed = {
                        model.input_data: x,
                        model.targets: y,
                        model.output_keep_prob: FLAGS.dropout_keep_prob
                    }
                    train_loss, summary, _ = sess.run(
                        [model.cost, merged, model.train_op], feed_dict=feed)
                    end = time.time()

                    global_step = e * FLAGS.num_batches + b

                    print(
                        '{}/{} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}'
                        .format(global_step, total_steps, e, train_loss,
                                end - start))

                    if global_step % 20 == 0:
                        train_writer.add_summary(summary, global_step)

                    if global_step % FLAGS.save_steps == 0:
                        saver.save(sess, checkpoint_path,
                                   global_step=global_step)
                        print('model saved to {}'.format(checkpoint_path))

                # Evaluate on the test set with dropout disabled.
                test_data_loader.reset_batch()
                test_accuracy = []
                for _ in range(test_data_loader.num_batches):
                    test_x, test_y = test_data_loader.next_batch()
                    feed = {
                        model.input_data: test_x,
                        model.targets: test_y,
                        model.output_keep_prob: 1.0
                    }
                    test_accuracy.append(
                        sess.run(model.accuracy, feed_dict=feed))
                print('test accuracy:{0}'.format(np.average(test_accuracy)))
        finally:
            # Flush pending summaries and release the event file, even when
            # training is interrupted by an exception.
            train_writer.close()
示例#3
0
def train():
    """Train the interaction-aware BiRNN and report per-epoch metrics.

    Loads the embedding file and the training data, shares the resulting
    vocabulary, label dictionary and embeddings with a second loader for the
    test data, builds the model and runs ``FLAGS.num_epochs`` epochs.  After
    every epoch the mean training loss, the training and test hit rates and
    the elapsed time are printed and a checkpoint tagged with the epoch
    number is saved to ``FLAGS.save_dir``.

    The reported hit rate counts a prediction as correct only when it equals
    the reference label AND is not label ``4`` (presumably a filler/"other"
    class -- TODO confirm against the label dictionary).
    """

    def hit_rate(predicted, expected):
        """Return the share of positions predicted correctly, ignoring 4."""
        hits = sum(1 for p, t in zip(predicted, expected)
                   if p == t and p != 4)
        return hits * 1.0 / len(predicted)

    # Training data.
    data_loader = InputHelper(log=log)
    data_loader.load_embedding(FLAGS.embedding_file, FLAGS.embedding_size)
    train_data = data_loader.load_data(
        FLAGS.data_dir + '/' + FLAGS.train_file, FLAGS.data_dir + '/',
        FLAGS.interaction_rounds, FLAGS.sequence_length)
    (x_batch, y_batch, train_interaction_point,
     train_word_point) = data_loader.generate_batches(
         train_data, FLAGS.batch_size, FLAGS.interaction_rounds)
    FLAGS.vocab_size = len(data_loader.word2idx)
    FLAGS.n_classes = len(data_loader.label_dictionary)
    print(FLAGS.n_classes)
    FLAGS.num_batches = data_loader.num_batches
    FLAGS.embeddings = data_loader.embeddings

    # Test data, encoded with the training vocabulary and labels.
    test_data_loader = InputHelper(log=log)
    test_data_loader.load_info(embeddings=FLAGS.embeddings,
                               word2idx=data_loader.word2idx,
                               idx2word=data_loader.idx2word,
                               label_dictionary=data_loader.label_dictionary)
    test_data = test_data_loader.load_data(
        FLAGS.data_dir + '/' + FLAGS.test_file, FLAGS.data_dir + '/',
        FLAGS.interaction_rounds, FLAGS.sequence_length)
    (test_x_batch, test_y_batch, test_interaction_point,
     test_word_point) = test_data_loader.generate_batches(
         test_data, FLAGS.batch_size, FLAGS.interaction_rounds)

    # Define specified Model
    model = BiRNN(embedding_size=FLAGS.embedding_size,
                  rnn_size=FLAGS.rnn_size,
                  layer_size=FLAGS.layer_size,
                  vocab_size=FLAGS.vocab_size,
                  attn_size=FLAGS.attn_size,
                  sequence_length=FLAGS.sequence_length,
                  n_classes=FLAGS.n_classes,
                  interaction_rounds=FLAGS.interaction_rounds,
                  batch_size=FLAGS.batch_size,
                  embeddings=FLAGS.embeddings,
                  grad_clip=FLAGS.grad_clip,
                  learning_rate=FLAGS.learning_rate)

    # Scalars reported to TensorBoard.
    tf.summary.scalar('train_loss', model.cost)
    tf.summary.scalar('accuracy', model.accuracy)
    merged = tf.summary.merge_all()

    # Let TensorFlow grow its GPU memory allocation on demand.
    tf_config = tf.ConfigProto()
    tf_config.gpu_options.allow_growth = True

    with tf.Session(config=tf_config) as sess:
        train_writer = tf.summary.FileWriter(FLAGS.log_dir, sess.graph)
        try:
            sess.run(tf.global_variables_initializer())
            saver = tf.train.Saver(max_to_keep=1000)
            checkpoint_path = os.path.join(FLAGS.save_dir, 'model.ckpt')

            for e in range(FLAGS.num_epochs):
                data_loader.reset_batch()
                epoch_losses = []
                train_rates = []
                start = time.time()
                for b in range(FLAGS.num_batches):
                    x, y, z, m = data_loader.next_batch(
                        x_batch, y_batch, train_interaction_point,
                        train_word_point)
                    feed = {
                        model.input_data: x,
                        model.targets: y,
                        model.output_keep_prob: FLAGS.dropout_keep_prob,
                        model.word_point: m,
                        model.sentence_point: z
                    }
                    # model.accuracy is still fetched so the executed ops
                    # stay the same; its value is not part of the report.
                    train_loss, _, predicted, expected, summary, _ = sess.run(
                        [model.cost, model.accuracy, model.y_results,
                         model.y_tr, merged, model.train_op],
                        feed_dict=feed)
                    epoch_losses.append(train_loss)

                    global_step = e * FLAGS.num_batches + b
                    if global_step % 20 == 0:
                        train_writer.add_summary(summary, global_step)
                    train_rates.append(hit_rate(predicted, expected))

                # Evaluate on the test set with dropout disabled.
                test_data_loader.reset_batch()
                test_rates = []
                for _ in range(test_data_loader.num_batches):
                    test_x, test_y, test_z, test_m = (
                        test_data_loader.next_batch(
                            test_x_batch, test_y_batch,
                            test_interaction_point, test_word_point))
                    feed = {
                        model.input_data: test_x,
                        model.targets: test_y,
                        model.output_keep_prob: 1.0,
                        model.word_point: test_m,
                        model.sentence_point: test_z
                    }
                    _, predicted, expected = sess.run(
                        [model.accuracy, model.y_results, model.y_tr],
                        feed_dict=feed)
                    test_rates.append(hit_rate(predicted, expected))
                end = time.time()

                # The elapsed time now has its own placeholder; previously
                # the value was passed to format() but never printed.
                print('e{},loss = {:.3f}, train_acc = {:.3f}, '
                      'test_acc = {:.3f}, time/epoch = {:.3f}'.format(
                          e, np.average(epoch_losses),
                          np.average(train_rates), np.average(test_rates),
                          end - start))

                saver.save(sess, checkpoint_path, global_step=e)
        finally:
            # Flush pending summaries and release the event file, even when
            # training is interrupted by an exception.
            train_writer.close()
def train():
    """Train a BiLSTM classifier, optionally seeded with word2vec vectors.

    The loader reads its data file, builds the dictionary under
    ``FLAGS.save_dir`` and batches the tokenized training split; a second
    loader batches the tokenized test split (batch size 100).  When
    ``FLAGS.pre_trained_vec_path`` is set, the word2vec vectors stored there
    are re-ordered so that row ``k`` holds the vector of the word whose
    dictionary index is ``k``; the matrix is handed to the model and its
    shape overrides ``FLAGS.vocab_size`` / ``FLAGS.embedding_size``.

    Training runs ``FLAGS.num_epochs`` epochs, writes a TensorBoard summary
    every 20 steps and a checkpoint every ``FLAGS.save_steps`` steps, and
    prints the mean test accuracy after each epoch.

    Restoring from ``FLAGS.init_from`` (fine-tuning) is not performed by this
    function.
    """
    data_loader = InputHelper()
    # Build the dictionary.
    data_loader.load_file()
    data_loader.create_dictionary_v2(FLAGS.save_dir + '/')
    x_train = data_loader.data_token(data_loader.x_train)
    data_loader.create_batches(x_train, data_loader.y_train, FLAGS.batch_size,
                               FLAGS.sequence_length)
    FLAGS.vocab_size = data_loader.vocab_size
    FLAGS.n_classes = data_loader.n_classes
    FLAGS.num_batches = data_loader.num_batches

    test_data_loader = InputHelper()
    test_data_loader.load_dictionary(FLAGS.save_dir + '/dictionary',
                                     data_loader.y_train)
    x_test = data_loader.data_token(data_loader.x_test)
    test_data_loader.create_batches(x_test, data_loader.y_test, 100,
                                    FLAGS.sequence_length)

    embeddings_reshape = None
    if FLAGS.pre_trained_vec_path:
        print('将原始的embedding矩阵重置')
        embeddings = np.load(FLAGS.pre_trained_vec_path +
                             '/word2vec.model.wv.vectors.npy',
                             allow_pickle=True)
        # Kept under its own name so it is not confused with the BiLSTM
        # model created further down.
        word2vec = Word2Vec.load(FLAGS.pre_trained_vec_path +
                                 '/word2vec.model')
        embeddings_reshape = np.zeros(embeddings.shape)
        print('embeddings_shape:', embeddings_reshape.shape)
        token_index = data_loader.token_dictionary
        print(len(token_index))
        # Debug trace: print each new lowest dictionary index below 20.
        lowest_index = 20
        for word in word2vec.wv.index2word:
            row = token_index[word]
            if row < lowest_index:
                lowest_index = row
                print(lowest_index)
            # Place the vector at the row given by the word's dictionary id.
            embeddings_reshape[row] = word2vec.wv[word]

        print(embeddings_reshape.shape)
        FLAGS.vocab_size = embeddings_reshape.shape[0]
        FLAGS.embedding_size = embeddings_reshape.shape[1]

    print('create model...')
    # Define specified Model
    model = BiLSTM(embedding_size=FLAGS.embedding_size,
                   rnn_size=FLAGS.rnn_size,
                   vocab_size=FLAGS.vocab_size,
                   sequence_length=FLAGS.sequence_length,
                   n_classes=FLAGS.n_classes,
                   learning_rate=FLAGS.learning_rate,
                   embedding_w=embeddings_reshape)

    # Scalars reported to TensorBoard.
    tf.summary.scalar('train_loss', model.loss)
    tf.summary.scalar('accuracy', model.accuracy)
    merged = tf.summary.merge_all()

    init = tf.global_variables_initializer()
    print('start training...')
    with tf.Session() as sess:
        train_writer = tf.summary.FileWriter(FLAGS.logs_dir, sess.graph)
        try:
            saver = tf.train.Saver(tf.global_variables())

            # Initialise global and local variables before training.
            sess.run(init)
            sess.run(tf.local_variables_initializer())

            checkpoint_path = os.path.join(FLAGS.save_dir, 'model.ckpt')
            total_steps = FLAGS.num_epochs * FLAGS.num_batches

            for e in range(FLAGS.num_epochs):
                data_loader.reset_batch()  # reshuffle
                for b in range(FLAGS.num_batches):
                    start = time.time()
                    x, y = data_loader.next_batch()

                    feed = {
                        model.input_data: x,
                        model.targets: y,
                        model.output_keep_prob: FLAGS.dropout_keep_prob
                    }
                    train_loss, summary, _, accuracy = sess.run(
                        [model.loss, merged, model.train_op, model.accuracy],
                        feed_dict=feed)
                    end = time.time()

                    global_step = e * FLAGS.num_batches + b

                    print(
                        '{}/{} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f},acc = {:.3f}'
                        .format(global_step, total_steps, e, train_loss,
                                end - start, accuracy))

                    if global_step % 20 == 0:
                        train_writer.add_summary(summary, global_step)

                    if global_step % FLAGS.save_steps == 0:
                        saver.save(sess, checkpoint_path,
                                   global_step=global_step)
                        print('model saved to {}'.format(checkpoint_path))

                # Evaluate on the test set with dropout disabled.
                test_data_loader.reset_batch()
                test_accuracy = []
                for _ in range(test_data_loader.num_batches):
                    test_x, test_y = test_data_loader.next_batch()
                    feed = {
                        model.input_data: test_x,
                        model.targets: test_y,
                        model.output_keep_prob: 1.0
                    }
                    test_accuracy.append(
                        sess.run(model.accuracy, feed_dict=feed))
                print(np.average(test_accuracy))
        finally:
            # Flush pending summaries and release the event file, even when
            # training is interrupted by an exception.
            train_writer.close()
示例#5
0
文件: train.py 项目: gfmei/python
def train():
    """Train an attention BiRNN classifier with logging and checkpoints.

    Both loaders are created with the stop-word file
    ``'data/stop_words.pkl'``.  The training loader builds the dictionary and
    the training batches; the test loader reuses the stored dictionary and
    batches the test file with batch size 100.  Pre-trained embeddings
    (``FLAGS.pre_trained_vec``) and a starting checkpoint
    (``FLAGS.init_from``) are both optional.

    During training a TensorBoard summary is written every 20 steps and a
    checkpoint every ``FLAGS.save_steps`` steps.  After each epoch the mean
    test loss and mean test accuracy are printed.

    Raises:
        AssertionError: if ``FLAGS.init_from`` is set but is not a directory
            or holds no usable checkpoint.
    """
    train_path = FLAGS.data_dir + '/' + FLAGS.train_file
    test_path = FLAGS.data_dir + '/' + FLAGS.test_file

    data_loader = InputHelper('data/stop_words.pkl')
    data_loader.create_dictionary(train_path, FLAGS.data_dir + '/')
    data_loader.create_batches(train_path, FLAGS.batch_size,
                               FLAGS.sequence_length)
    FLAGS.vocab_size = data_loader.vocab_size
    FLAGS.n_classes = data_loader.n_classes
    FLAGS.num_batches = data_loader.num_batches

    test_data_loader = InputHelper('data/stop_words.pkl')
    test_data_loader.load_dictionary(FLAGS.data_dir + '/dictionary')
    test_data_loader.create_batches(test_path, 100, FLAGS.sequence_length)

    if FLAGS.pre_trained_vec:
        embeddings = np.load(FLAGS.pre_trained_vec)
        print(embeddings.shape)
        # The embedding matrix dictates the model's vocabulary/vector sizes.
        FLAGS.vocab_size = embeddings.shape[0]
        FLAGS.embedding_size = embeddings.shape[1]

    if FLAGS.init_from is not None:
        assert os.path.isdir(FLAGS.init_from), '{} must be a directory'.format(
            FLAGS.init_from)
        ckpt = tf.train.get_checkpoint_state(FLAGS.init_from)
        assert ckpt, 'No checkpoint found'
        assert ckpt.model_checkpoint_path, 'No model path found in checkpoint'

    # Define specified Model
    model = AttentionBiRNN(embedding_size=FLAGS.embedding_size,
                           rnn_size=FLAGS.rnn_size,
                           layer_size=FLAGS.layer_size,
                           vocab_size=FLAGS.vocab_size,
                           attn_size=FLAGS.attn_size,
                           sequence_length=FLAGS.sequence_length,
                           n_classes=FLAGS.n_classes,
                           grad_clip=FLAGS.grad_clip,
                           learning_rate=FLAGS.learning_rate)

    # Scalars reported to TensorBoard.
    tf.summary.scalar('train_loss', model.cost)
    tf.summary.scalar('accuracy', model.accuracy)
    merged = tf.summary.merge_all()

    # Let TensorFlow grow its GPU memory allocation on demand.
    tf_config = tf.ConfigProto()
    tf_config.gpu_options.allow_growth = True

    with tf.Session(config=tf_config) as sess:
        train_writer = tf.summary.FileWriter(FLAGS.log_dir, sess.graph)
        try:
            sess.run(tf.global_variables_initializer())
            saver = tf.train.Saver(tf.global_variables())

            # Overwrite the embedding variable with the pre-trained matrix.
            if FLAGS.pre_trained_vec:
                sess.run(model.embedding.assign(embeddings))
                del embeddings

            # Restoring happens last, so checkpoint values take precedence.
            if FLAGS.init_from is not None:
                saver.restore(sess, ckpt.model_checkpoint_path)

            checkpoint_path = os.path.join(FLAGS.save_dir, 'model.ckpt')
            total_steps = FLAGS.num_epochs * FLAGS.num_batches
            for e in range(FLAGS.num_epochs):
                data_loader.reset_batch()
                for b in range(FLAGS.num_batches):
                    start = time.time()
                    x, y = data_loader.next_batch()
                    feed = {
                        model.input_data: x,
                        model.targets: y,
                        model.output_keep_prob: FLAGS.dropout_keep_prob
                    }
                    train_loss, summary, acc, _ = sess.run(
                        [model.cost, merged, model.accuracy, model.train_op],
                        feed_dict=feed)
                    end = time.time()

                    global_step = e * FLAGS.num_batches + b

                    print('{}/{}(epoch {}), train_loss = {:.3f}, time/batch = {:.3f}, accuracy = {:.3f}'
                          .format(global_step, total_steps, e, train_loss,
                                  end - start, acc))

                    if global_step % 20 == 0:
                        train_writer.add_summary(summary, global_step)

                    if global_step % FLAGS.save_steps == 0:
                        saver.save(sess, checkpoint_path,
                                   global_step=global_step)
                        print('model saved to {}'.format(checkpoint_path))

                # Evaluate on the test set with dropout disabled.
                test_data_loader.reset_batch()
                test_accuracy, test_loss = [], []
                for _ in range(test_data_loader.num_batches):
                    test_x, test_y = test_data_loader.next_batch()
                    feed = {
                        model.input_data: test_x,
                        model.targets: test_y,
                        model.output_keep_prob: 1.0
                    }
                    loss, accuracy = sess.run([model.cost, model.accuracy],
                                              feed_dict=feed)
                    test_accuracy.append(accuracy)
                    test_loss.append(loss)
                print('test_loss:{:.5f}, test accuracy:{:.5f}'.format(
                    np.average(test_loss), np.average(test_accuracy)))
        finally:
            # Flush pending summaries and release the event file, even when
            # training is interrupted by an exception.
            train_writer.close()
示例#6
0
def train():

    train_data_loader = InputHelper(FLAGS.data_dir, FLAGS.train_file,
                                    FLAGS.batch_size, FLAGS.sequence_length)
    FLAGS.num_batches = train_data_loader.num_batches
    FLAGS.vocab_size = len(train_data_loader.vocab_processor.vocabulary_)
    print len(train_data_loader.vocab_processor.vocabulary_)

    if FLAGS.init_from is not None:
        assert os.path.isdir(FLAGS.init_from), '{} must be a directory'.format(
            FLAGS.init_from)
        ckpt = tf.train.get_checkpoint_state(FLAGS.init_from)
        assert ckpt, 'No checkpoint found'
        assert ckpt.model_checkpoint_path, 'No model path found in checkpoint'

    model = SiameseLSTM(FLAGS.rnn_size, FLAGS.layer_size, FLAGS.vocab_size,
                        FLAGS.sequence_length, FLAGS.dropout_keep_prob,
                        FLAGS.grad_clip)

    tf.summary.scalar('train_loss', model.cost)
    merged = tf.summary.merge_all()

    with tf.Session() as sess:
        train_writer = tf.summary.FileWriter(FLAGS.log_dir, sess.graph)

        sess.run(tf.global_variables_initializer())
        saver = tf.train.Saver(tf.global_variables())

        # restore model
        if FLAGS.init_from is not None:
            saver.restore(sess, ckpt.model_checkpoint_path)
        for e in xrange(FLAGS.num_epochs):
            train_data_loader.reset_batch()
            b = 0
            while not train_data_loader.eos:
                b += 1
                start = time.time()
                x1_batch, x2_batch, y_batch = train_data_loader.next_batch()
                # random exchange x1_batch and x2_batch
                if random.random() > 0.5:
                    feed = {
                        model.input_x1: x1_batch,
                        model.input_x2: x2_batch,
                        model.y_data: y_batch
                    }
                else:
                    feed = {
                        model.input_x1: x2_batch,
                        model.input_x2: x1_batch,
                        model.y_data: y_batch
                    }
                train_loss, summary, _ = sess.run(
                    [model.cost, merged, model.train_op], feed_dict=feed)
                end = time.time()
                print '{}/{} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}'.format(
                    e * FLAGS.num_batches + b,
                    FLAGS.num_epochs * FLAGS.num_batches, e, train_loss,
                    end - start)

                if (e * FLAGS.num_batches + b) % 500 == 0:  #500
                    checkpoint_path = os.path.join(FLAGS.save_dir,
                                                   'model.ckpt')
                    saver.save(sess,
                               checkpoint_path,
                               global_step=e * FLAGS.num_batches + b)
                    print 'model saved to {}'.format(checkpoint_path)

                if b % 20 == 0:
                    train_writer.add_summary(summary,
                                             e * FLAGS.num_batches + b)