# Shared imports for the examples below (added for completeness). The
# project-local helpers -- RnnAutoencoder, DoubanReader, BatchIterator,
# prepare_data, print_predict, average_gradients -- and the module-level
# `model_dir` and `gpus` settings are assumed to come from the surrounding
# repository.
import os
import sys
import time
import cPickle as pkl

import numpy as np
import tensorflow as tf


def reload(args):
    with open(os.path.join(args.save_dir, 'config.pkl'), 'rb') as f:
        saved_args = pkl.load(f)
    with open(os.path.join(args.save_dir, 'vocab.pkl'), 'rb') as f:
        vocab = pkl.load(f)
        idx2word = {v: k for (k, v) in vocab.items()}
        # Reserved indices: 0 decodes to the end-of-sequence '\n', 1 to '<UNK>'.
        idx2word[0] = '\n'
        idx2word[1] = '<UNK>'
        with open("%s/rnn.vcb" % (model_dir), "w") as vcb_f:
            for (k, v) in vocab.items():
                vcb_f.write(k.encode("utf8", "ignore") + "\t" + str(v) + "\n")


    g1 = tf.Graph()
    with tf.device(args.device), g1.as_default():
        m = RnnAutoencoder(rnn_type=saved_args.rnn_type,
                           batch_size=saved_args.batch_size,
                           dim_emb=saved_args.dim_emb,
                           num_units=saved_args.num_units,
                           vocab_size=saved_args.vocab_size,
                           seq_len=saved_args.seq_len,
                           grad_clip=saved_args.grad_clip,
                           learning_rate=saved_args.learning_rate,
                           infer=True)

        inputs = m.get_inputs(infer=True)
        m.build_model(inputs, infer=True)

        #sym_x, sym_lx, sym_y = inputs[0], inputs[1], inputs[2]

        with tf.Session(config=tf.ConfigProto(allow_soft_placement=True,
                                              log_device_placement=False)) as sess:
            tf.initialize_all_variables().run()
            saver = tf.train.Saver(tf.all_variables())
            ckpt = tf.train.get_checkpoint_state(args.save_dir)
            if ckpt and ckpt.model_checkpoint_path:
                saver.restore(sess, ckpt.model_checkpoint_path)


            # Pull the trained weights out of the session, bake them into a
            # fresh graph as constants, and write binary and text GraphDefs.
            save_vars = {}
            for v in tf.trainable_variables():
                save_vars[v.value().name] = sess.run(v)
            g2 = tf.Graph()
            with g2.as_default():
                consts = {k: tf.constant(v) for k, v in save_vars.items()}
                tf.import_graph_def(g1.as_graph_def(), input_map=consts)
                tf.train.write_graph(g2.as_graph_def(), model_dir, 'rnn.pb', False)
                tf.train.write_graph(g2.as_graph_def(), model_dir, 'rnn.txt')
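
# A minimal sketch (an assumption, not part of the original source) of how the
# frozen 'rnn.pb' written above could be loaded back for inference. The tensor
# names inside the imported graph must be looked up in the actual GraphDef.
def load_frozen_graph(pb_path):
    graph_def = tf.GraphDef()
    with open(pb_path, 'rb') as f:
        graph_def.ParseFromString(f.read())
    g = tf.Graph()
    with g.as_default():
        # name='' keeps the original node names instead of an 'import/' prefix.
        tf.import_graph_def(graph_def, name='')
    return g
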
def train(args):
    data_reader = DoubanReader(args.data_dir)
    args.vocab_size = data_reader.vocab_size
    print(args.vocab_size)
    # Sanity check: the largest token id in the training data must stay
    # below vocab_size.
    maxi = 0
    for i in data_reader.train_data:
        for j in i:
            for k in j:
                maxi = max(maxi, k)
    print(maxi)

    if args.init_from is not None:
        assert os.path.isdir(
            args.init_from), "%s must be a path" % args.init_from
        assert os.path.isfile(
            os.path.join(args.init_from, 'config.pkl')
        ), 'config.pkl file does not exist in path %s' % args.init_from
        assert os.path.isfile(
            os.path.join(args.init_from, 'vocab.pkl')
        ), 'vocab.pkl file does not exist in path %s' % args.init_from
        ckpt = tf.train.get_checkpoint_state(args.init_from)
        assert ckpt, 'No checkpoint found'
        assert ckpt.model_checkpoint_path, 'No model path found in checkpoint'

        with open(os.path.join(args.init_from, 'config.pkl'), 'rb') as f:
            saved_model_args = pkl.load(f)
        need_to_be_same = [
            'rnn_type',
            'num_units',
            'dim_emb',
        ]
        for checkme in need_to_be_same:
            assert vars(saved_model_args)[checkme] == vars(
                args
            )[checkme], 'Command line argument and saved model disagree on %s' % checkme
    with open(os.path.join(args.save_dir, 'config.pkl'), 'wb') as f:
        pkl.dump(args, f)
    with open(os.path.join(args.save_dir, 'vocab.pkl'), 'wb') as f:
        pkl.dump(data_reader.vocab, f)
    g1 = tf.Graph()
    with tf.device(args.device), g1.as_default():
        m = RnnAutoencoder(
            rnn_type=args.rnn_type,
            batch_size=args.batch_size,
            dim_emb=args.dim_emb,
            num_units=args.num_units,
            vocab_size=args.vocab_size,
            seq_len=args.seq_len,
            grad_clip=args.grad_clip,
            learning_rate=args.learning_rate,
            infer=False,
        )
        inputs = m.get_inputs(infer=False)
        cost = m.build_model(inputs, infer=False)

        global_step = tf.Variable(0, name="global_step", trainable=False)

        # Compute gradients, optionally clip their global norm, and apply Adam.
        tvars = tf.trainable_variables()
        grads = tf.gradients(cost, tvars)
        if args.grad_clip:
            grads, _ = tf.clip_by_global_norm(grads, args.grad_clip)
        optimizer = tf.train.AdamOptimizer(args.learning_rate)
        train_op = optimizer.apply_gradients(zip(grads, tvars),
                                             global_step=global_step)

        # todo: valid
        it_train = BatchIterator(len(data_reader.train_data[0]),
                                 args.batch_size,
                                 data_reader.train_data,
                                 testing=False)
        it_valid = BatchIterator(len(data_reader.valid_data[0]),
                                 args.batch_size,
                                 data_reader.valid_data,
                                 testing=False)
        it_test = BatchIterator(len(data_reader.test_data[0]),
                                args.batch_size,
                                data_reader.test_data,
                                testing=False)

        num_batches_train = len(data_reader.train_data[0]) // args.batch_size
        num_batches_valid = len(data_reader.valid_data[0]) // args.batch_size
        num_batches_test = len(data_reader.test_data[0]) // args.batch_size
        with tf.Session(config=tf.ConfigProto(
                allow_soft_placement=True, log_device_placement=True)) as sess:
            out_dir = args.save_dir
            # Train Summaries
            print "Train Summaries"

            loss_summary = tf.scalar_summary("loss", m.loss)

            train_summary_op = tf.merge_summary([loss_summary])
            train_summary_dir = os.path.join(out_dir, "summaries", "train")
            train_summary_writer = tf.train.SummaryWriter(
                train_summary_dir, sess.graph_def)

            tf.initialize_all_variables().run()
            saver = tf.train.Saver(tf.all_variables())
            if args.init_from is not None:
                saver.restore(sess, ckpt.model_checkpoint_path)
            for e in range(args.num_epochs):
                # todo: learningrate
                #state = m.initial_state.eval()
                # train
                outs = []
                cnt_words = 0
                for b in range(num_batches_train):
                    x, y = it_train.next()
                    x, mx, lx = prepare_data(x, args.seq_len)
                    y, my, ly = prepare_data(y, args.seq_len)
                    feed = dict(zip(inputs, [x, lx, y, my]))
                    out, _, step, summaries = sess.run(
                        [m.cost, train_op, global_step, train_summary_op],
                        feed)
                    outs.append(out)
                    cnt_words += np.sum(ly)
                    train_summary_writer.add_summary(summaries, step)

                    # save the model
                    current_step = e * num_batches_train + b + 1
                    if current_step % args.save_every == 0 \
                            or (e == args.num_epochs - 1 and b == num_batches_train - 1):
                        print('Save at step {}: {:.3f}'.format(
                            current_step - 1,
                            np.exp(out * args.batch_size / np.sum(ly))))
                        checkpoint_path = os.path.join(args.save_dir,
                                                       'model_ckpt')
                        saver.save(sess,
                                   checkpoint_path,
                                   global_step=current_step - 1)

                        # Freeze the trained weights into constants, export
                        # binary/text GraphDefs, and stop after the first
                        # export.
                        save_vars = {}
                        for v in tf.trainable_variables():
                            save_vars[v.value().name] = sess.run(v)

                        g2 = tf.Graph()
                        with g2.as_default():
                            consts = {k: tf.constant(v)
                                      for k, v in save_vars.items()}
                            tf.import_graph_def(g1.as_graph_def(),
                                                input_map=consts)
                            tf.train.write_graph(g2.as_graph_def(), model_dir,
                                                 'cnn.pb%s' % current_step,
                                                 False)
                            tf.train.write_graph(g2.as_graph_def(), model_dir,
                                                 'cnn.txt%s' % current_step)
                        sys.exit(1)
                print('Epoch {}: train loss {}'.format(
                    e, np.exp(np.sum(outs) * args.batch_size / cnt_words)))
                # valid
                outs = []
                cnt_words = 0
                for b in range(num_batches_valid):
                    x, y = it_valid.next()
                    x, mx, lx = prepare_data(x, args.seq_len)
                    y, my, ly = prepare_data(y, args.seq_len)
                    feed = dict(zip(inputs, [x, lx, y, my]))
                    out, = sess.run([
                        m.cost,
                    ], feed)
                    #print(np.exp(out*args.batch_size/np.sum(ly)))
                    outs.append(out)
                    cnt_words += np.sum(ly)
                print('Epoch {}: valid loss {}'.format(
                    e, np.exp(np.sum(outs) * args.batch_size / cnt_words)))

                # test
                outs = []
                cnt_words = 0
                for b in range(num_batches_test):
                    x, y = it_test.next()
                    x, mx, lx = prepare_data(x, args.seq_len)
                    y, my, ly = prepare_data(y, args.seq_len)
                    feed = dict(zip(inputs, [x, lx, y, my]))
                    out, = sess.run([
                        m.cost,
                    ], feed)
                    outs.append(out)
                    cnt_words += np.sum(ly)
                print('Epoch {}: test loss {}'.format(
                    e, np.exp(np.sum(outs) * args.batch_size / cnt_words)))
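
# Hypothetical sketch of the project helper prepare_data (an assumption; the
# real implementation lives elsewhere in the repository): pad a batch of
# token-id lists to seq_len and return the padded matrix, a float mask, and
# the lengths -- the (x, mx, lx) triple unpacked in the loops above.
def prepare_data_sketch(seqs, seq_len):
    n = len(seqs)
    x = np.zeros((n, seq_len), dtype='int32')
    mask = np.zeros((n, seq_len), dtype='float32')
    lengths = np.zeros((n,), dtype='int32')
    for i, s in enumerate(seqs):
        s = list(s)[:seq_len]
        x[i, :len(s)] = s
        mask[i, :len(s)] = 1.0
        lengths[i] = len(s)
    return x, mask, lengths
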
def sample(args):
    print(args)
    # Load the reply blacklist: one reply per line, first tab-separated field.
    black_reply_dict = {}
    for line in open(args.black):
        black_reply_dict[line.strip().split("\t")[0].strip()] = 1
    print("black_reply_dict size:", len(black_reply_dict))

    with open(os.path.join(args.save_dir, 'config.pkl'), 'rb') as f:
        saved_args = pkl.load(f)
    with open(os.path.join(args.save_dir, 'vocab.pkl'), 'rb') as f:
        vocab = pkl.load(f)
        idx2word = {v: k for (k, v) in vocab.items()}
        idx2word[0] = '\n'
        idx2word[1] = '<UNK>'

    with tf.device(args.device):
        m = RnnAutoencoder(rnn_type=saved_args.rnn_type,
                           batch_size=saved_args.batch_size,
                           dim_emb=saved_args.dim_emb,
                           num_units=saved_args.num_units,
                           vocab_size=saved_args.vocab_size,
                           seq_len=saved_args.seq_len,
                           grad_clip=saved_args.grad_clip,
                           learning_rate=saved_args.learning_rate,
                           infer=True)

        inputs = m.get_inputs(infer=True)
        m.build_model(inputs, infer=True)

    sym_x, sym_lx, sym_y = inputs[0], inputs[1], inputs[2]

    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True,
                                          log_device_placement=False)) as sess:
        tf.initialize_all_variables().run()
        saver = tf.train.Saver(tf.all_variables())
        ckpt = tf.train.get_checkpoint_state(args.save_dir)
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(sess, ckpt.model_checkpoint_path)
            print("Model load done")
            while 1:
                query = sys.stdin.readline()
                if not query:
                    break
                start = time.time()
                query = query.strip()

                str_in = query
                x = str_in.split()
                # Debug: print each token and whether it is in the vocabulary.
                for w in x:
                    print(w),
                    print(w.decode('utf-8') in vocab)
                # Map tokens to ids; out-of-vocabulary words fall back to 1 (<UNK>).
                x = [[vocab.get(w.decode('utf-8'), 1) for w in x]]
                x, mx, lx = prepare_data(x, saved_args.seq_len)
                print(x)

                #lx_in = np.asarray([len(wx_in)], dtype='int32')
                sents, scores, common_sents, common_scores = m.generate_original(sess, vocab, sym_x, sym_lx, sym_y, x, lx, \
                        args.beam_size, args.max_seq_len, idx2word, black_reply_dict)
                print("query:", query)
                print("--" * 10, "generated reply", "--" * 10)
                ss = print_predict(sents, scores, idx2word, 1)
                #for s in ss:
                #    s = s.strip().encode("utf8", "ignore")
                #    if s in black_reply_dict:
                #        continue
                #    else:
                #        print("\t%s" % (s))
                print("--" * 10, "common replys", "--" * 10)
                ss = print_predict(common_sents, common_scores, idx2word, 1)

                finish = time.time()
                generate_time_ms = (finish - start) * 1000
                print("generating cost time :", generate_time_ms, " ms")
def sample(args):
    print(args)
    with open(os.path.join(args.save_dir, 'config.pkl'), 'rb') as f:
        saved_args = pkl.load(f)
    with open(os.path.join(args.save_dir, 'vocab.pkl'), 'rb') as f:
        vocab = pkl.load(f)
        idx2word = {v: k for (k, v) in vocab.items()}
        idx2word[0] = '\n'
        idx2word[1] = '<UNK>'

    with tf.device(args.device):
        m = RnnAutoencoder(rnn_type=saved_args.rnn_type,
                           batch_size=saved_args.batch_size,
                           dim_emb=saved_args.dim_emb,
                           num_units=saved_args.num_units,
                           vocab_size=saved_args.vocab_size,
                           seq_len=saved_args.seq_len,
                           grad_clip=saved_args.grad_clip,
                           learning_rate=saved_args.learning_rate,
                           infer=True)

        inputs = m.get_inputs(infer=True)
        m.build_model(inputs, infer=True)

    sym_x, sym_lx, sym_y = inputs[0], inputs[1], inputs[2]

    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True,
                                          log_device_placement=False)) as sess:
        tf.initialize_all_variables().run()
        saver = tf.train.Saver(tf.all_variables())
        ckpt = tf.train.get_checkpoint_state(args.save_dir)
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(sess, ckpt.model_checkpoint_path)
            print("Model load done")
            while 1:
                query = sys.stdin.readline()
                if not query:
                    break
                query = query.strip()

                str_in = query
                x = str_in.split()
                # Debug: print each token and whether it is in the vocabulary.
                for w in x:
                    print(w),
                    print(w.decode('utf-8') in vocab)
                # Out-of-vocabulary words should fall back to 1 (<UNK>), not 0,
                # which idx2word reserves for the end-of-sequence '\n'.
                x = [[vocab.get(w.decode('utf-8'), 1) for w in x]]
                x, mx, lx = prepare_data(x, saved_args.seq_len)
                print(x)

                #lx_in = np.asarray([len(wx_in)], dtype='int32')
                sents, scores = m.generate(sess, vocab, sym_x, sym_lx, sym_y, x, lx, \
                        args.beam_size, args.max_seq_len)
                sents = [[idx2word[i] for i in s] for s in sents]

                # Join tokens into strings and drop duplicate hypotheses.
                ns = []
                for s in sents:
                    tmp = ''.join(s).strip()
                    if tmp not in ns:
                        ns.append(tmp)

                for s in ns:
                    print(s.encode('utf-8'))
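
# Hypothetical sketch of the project helper BatchIterator (an assumption; the
# real class lives elsewhere in the repository): cycle over parallel
# (query, reply) lists in minibatches, reshuffling between passes unless
# testing, and yielding (x, y) via next() as the training loops expect.
class BatchIteratorSketch(object):
    def __init__(self, n, batch_size, data, testing=False):
        self.n, self.batch_size, self.data = n, batch_size, data
        self.testing = testing
        self.order = np.arange(n)
        self.pos = 0

    def next(self):
        if self.pos + self.batch_size > self.n:
            if not self.testing:
                np.random.shuffle(self.order)
            self.pos = 0
        idx = self.order[self.pos:self.pos + self.batch_size]
        self.pos += self.batch_size
        x = [self.data[0][i] for i in idx]
        y = [self.data[1][i] for i in idx]
        return x, y
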
def train(args):
    data_reader = DoubanReader(args.data_dir)
    args.vocab_size = data_reader.vocab_size
    print(args.vocab_size)

    if args.init_from is not None:
        assert os.path.isdir(
            args.init_from), "%s must be a path" % args.init_from
        assert os.path.isfile(os.path.join(args.init_from, 'config.pkl')), \
                'config.pkl file does not exist in path %s' % args.init_from
        assert os.path.isfile(os.path.join(args.init_from, 'vocab.pkl')), \
                'vocab.pkl file does not exist in path %s' % args.init_from
        ckpt = tf.train.get_checkpoint_state(args.init_from)
        assert ckpt, 'No checkpoint found'
        assert ckpt.model_checkpoint_path, 'No model path found in checkpoint'

        with open(os.path.join(args.init_from, 'config.pkl'), 'rb') as f:
            saved_model_args = pkl.load(f)
        need_to_be_same = [
            'rnn_type',
            'num_units',
            'dim_emb',
        ]
        for checkme in need_to_be_same:
            assert vars(saved_model_args)[checkme]==vars(args)[checkme], \
                    'Command line argument and saved model disagree on %s' % checkme

    it_train = BatchIterator(len(data_reader.train_data[0]),
                             args.batch_size,
                             data_reader.train_data,
                             testing=False)
    it_valid = BatchIterator(len(data_reader.valid_data[0]),
                             args.batch_size,
                             data_reader.valid_data,
                             testing=False)
    it_test = BatchIterator(len(data_reader.test_data[0]),
                            args.batch_size,
                            data_reader.test_data,
                            testing=False)

    num_batches_train = len(data_reader.train_data[0]) // args.batch_size
    num_batches_valid = len(data_reader.valid_data[0]) // args.batch_size
    num_batches_test = len(data_reader.test_data[0]) // args.batch_size

    with open(os.path.join(args.save_dir, 'config.pkl'), 'wb') as f:
        pkl.dump(args, f)
    with open(os.path.join(args.save_dir, 'vocab.pkl'), 'wb') as f:
        pkl.dump(data_reader.vocab, f)

    with tf.Graph().as_default():
        # todo: sgd?
        opt = tf.train.AdamOptimizer(args.learning_rate)
        m = RnnAutoencoder(
            rnn_type=args.rnn_type,
            # Per-tower batch: `gpus` is assumed to be a module-level list of
            # GPU ids, and each tower gets an equal slice of the batch.
            batch_size=args.batch_size // len(gpus),
            dim_emb=args.dim_emb,
            num_units=args.num_units,
            vocab_size=args.vocab_size,
            seq_len=args.seq_len,
            grad_clip=args.grad_clip,
            learning_rate=args.learning_rate,
            infer=False,
        )

        # Get the input placeholders; the batch dimension is left as None so
        # each tower can take its own slice.
        sx, slx, sy, smy = m.get_inputs(infer=False)

        # Create the variables on /cpu:0 first (this also builds a redundant
        # copy of the ops there); every GPU tower then reuses them.
        with tf.device('/cpu:0'):
            with tf.name_scope('cpu_aux'):
                m.build_model([sx, slx, sy, smy], infer=False)
                tf.get_variable_scope().reuse_variables()

        tower_grads = []
        total_cost = []
        for i in range(len(gpus)):
            with tf.device('/gpu:%d' % gpus[i]):
                print("USE GPU ID:", gpus[i])
                with tf.name_scope('%s_%d' % (m.TOWER_NAME, i)) as scope:
                    # Split the batch so each device gets an equal contiguous slice.
                    lo = i * args.batch_size // len(gpus)
                    hi = (i + 1) * args.batch_size // len(gpus)
                    x_slice = tf.gather(sx, range(lo, hi))
                    lx_slice = tf.gather(slx, range(lo, hi))
                    y_slice = tf.gather(sy, range(lo, hi))
                    my_slice = tf.gather(smy, range(lo, hi))
                    input_slice = [x_slice, lx_slice, y_slice, my_slice]

                    cost = m.build_model(input_slice, infer=False)
                    tf.get_variable_scope().reuse_variables()

                    # Keep only the last tower's summary collection (the same
                    # convention as the TF multi-GPU CIFAR-10 example this
                    # pattern appears to follow).
                    summaries = tf.get_collection(tf.GraphKeys.SUMMARIES,
                                                  scope)
                    grads = opt.compute_gradients(cost)
                    tower_grads.append(grads)
                    total_cost.append(cost)

        # Average the per-tower gradients and apply a single update.
        grads = average_gradients(tower_grads)
        apply_gradient_op = opt.apply_gradients(grads)
        total_cost = tf.add_n(total_cost) / len(gpus)

        for var in tf.trainable_variables():
            summaries.append(tf.histogram_summary(var.op.name, var))

        saver = tf.train.Saver(tf.all_variables())
        summary_op = tf.merge_summary(summaries)
        init = tf.initialize_all_variables()
        sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True,
                                                log_device_placement=False))
        sess.run(init)
        #tf.train.start_queue_runners(sess=sess)
        summary_writer = tf.train.SummaryWriter(args.save_dir, sess.graph)

        if args.init_from is not None:
            saver.restore(sess, ckpt.model_checkpoint_path)

        for e in range(args.num_epochs):
            # train
            outs = []
            cnt_words = 0
            for b in range(num_batches_train):
                x, y = it_train.next()
                x, mx, lx = prepare_data(x, args.seq_len)
                y, my, ly = prepare_data(y, args.seq_len)
                out, _ = sess.run([total_cost, apply_gradient_op], {
                    sx: x,
                    slx: lx,
                    sy: y,
                    smy: my,
                })
                outs.append(out)
                cnt_words += np.sum(ly)

                # save the model
                if (e*num_batches_train + b + 1) % args.save_every == 0 \
                        or (e == args.num_epochs - 1 and b == num_batches_train - 1):
                    print('Save at step {}: {:.3f}'.format(
                        e * num_batches_train + b,
                        np.exp(np.sum(outs) * args.batch_size / cnt_words)))
                    checkpoint_path = os.path.join(args.save_dir, 'model_ckpt')
                    saver.save(sess,
                               checkpoint_path,
                               global_step=e * num_batches_train + b)
            print('Epoch {}: train loss {}'.format(
                e, np.exp(np.sum(outs) * args.batch_size / cnt_words)))

            # valid
            outs = []
            cnt_words = 0
            for b in range(num_batches_valid):
                x, y = it_valid.next()
                x, mx, lx = prepare_data(x, args.seq_len)
                y, my, ly = prepare_data(y, args.seq_len)
                # Evaluate the tower-averaged cost; `cost` alone would score
                # only the last tower's slice of the batch.
                out, = sess.run([
                    total_cost,
                ], {
                    sx: x,
                    slx: lx,
                    sy: y,
                    smy: my,
                })
                #print(np.exp(out*args.batch_size/np.sum(ly)))
                outs.append(out)
                cnt_words += np.sum(ly)
            print('Epoch {}: valid loss {}'.format(
                e, np.exp(np.sum(outs) * args.batch_size / cnt_words)))

            # test
            outs = []
            cnt_words = 0
            for b in range(num_batches_test):
                x, y = it_test.next()
                x, mx, lx = prepare_data(x, args.seq_len)
                y, my, ly = prepare_data(y, args.seq_len)
                out, = sess.run([
                    total_cost,
                ], {
                    sx: x,
                    slx: lx,
                    sy: y,
                    smy: my,
                })
                outs.append(out)
                cnt_words += np.sum(ly)
            print('Epoch {}: test loss {}'.format(
                e, np.exp(np.sum(outs) * args.batch_size / cnt_words)))
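
# A sketch of average_gradients (an assumption: the helper appears to follow
# the TensorFlow multi-GPU CIFAR-10 tutorial): average each variable's
# gradient across towers, keeping the variable reference from the first
# tower. Uses the TF 0.x tf.concat(dim, values) signature, matching the API
# level of the rest of this file, and assumes no gradient is None.
def average_gradients_sketch(tower_grads):
    average_grads = []
    for grad_and_vars in zip(*tower_grads):
        # grad_and_vars is ((grad_gpu0, var), (grad_gpu1, var), ...).
        grads = [tf.expand_dims(g, 0) for g, _ in grad_and_vars]
        grad = tf.reduce_mean(tf.concat(0, grads), 0)
        average_grads.append((grad, grad_and_vars[0][1]))
    return average_grads
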
def train(args):
    data_reader = DoubanReader(args.data_dir)
    args.vocab_size = data_reader.vocab_size
    print(args.vocab_size)
    maxi = 0
    for i in data_reader.train_data:
        for j in i:
            for k in j:
                maxi = np.max([maxi, k])
    print(maxi)

    if args.init_from is not None:
        assert os.path.isdir(
            args.init_from), "%s must be a path" % args.init_from
        assert os.path.isfile(
            os.path.join(args.init_from, 'config.pkl')
        ), 'config.pkl file does not exist in path %s' % args.init_from
        assert os.path.isfile(
            os.path.join(args.init_from, 'vocab.pkl')
        ), 'vocab.pkl file does not exist in path %s' % args.init_from
        ckpt = tf.train.get_checkpoint_state(args.init_from)
        assert ckpt, 'No checkpoint found'
        assert ckpt.model_checkpoint_path, 'No model path found in checkpoint'

        with open(os.path.join(args.init_from, 'config.pkl'), 'rb') as f:
            saved_model_args = pkl.load(f)
        need_to_be_same = [
            'rnn_type',
            'num_units',
            'dim_emb',
        ]
        for checkme in need_to_be_same:
            assert vars(saved_model_args)[checkme] == vars(
                args
            )[checkme], 'Command line argument and saved model disagree on %s' % checkme
    with open(os.path.join(args.save_dir, 'config.pkl'), 'wb') as f:
        pkl.dump(args, f)
    with open(os.path.join(args.save_dir, 'vocab.pkl'), 'wb') as f:
        pkl.dump(data_reader.vocab, f)
    with tf.device(args.device):
        m = RnnAutoencoder(
            rnn_type=args.rnn_type,
            batch_size=args.batch_size,
            dim_emb=args.dim_emb,
            num_units=args.num_units,
            vocab_size=args.vocab_size,
            seq_len=args.seq_len,
            grad_clip=args.grad_clip,
            learning_rate=args.learning_rate,
            infer=False,
        )
        inputs = m.get_inputs(infer=False)
        cost = m.build_model(inputs, infer=False)

        global_step = tf.Variable(0, name="global_step", trainable=False)

        tvars = tf.trainable_variables()
        grads = tf.gradients(cost, tvars)
        if args.grad_clip:
            grads, _ = tf.clip_by_global_norm(grads, args.grad_clip)
        optimizer = tf.train.AdamOptimizer(args.learning_rate)
        train_op = optimizer.apply_gradients(zip(grads, tvars),
                                             global_step=global_step)

    # todo: valid
    it_train = BatchIterator(len(data_reader.train_data[0]),
                             args.batch_size,
                             data_reader.train_data,
                             testing=False)
    it_valid = BatchIterator(len(data_reader.valid_data[0]),
                             args.batch_size,
                             data_reader.valid_data,
                             testing=False)
    it_test = BatchIterator(len(data_reader.test_data[0]),
                            args.batch_size,
                            data_reader.test_data,
                            testing=False)

    num_batches_train = len(data_reader.train_data[0]) // args.batch_size
    num_batches_valid = len(data_reader.valid_data[0]) // args.batch_size
    num_batches_test = len(data_reader.test_data[0]) // args.batch_size
    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True,
                                          log_device_placement=True)) as sess:
        out_dir = args.save_dir
        # Train Summaries
        print "Train Summaries"

        loss_summary = tf.scalar_summary("loss", m.loss)

        train_summary_op = tf.merge_summary([loss_summary])
        train_summary_dir = os.path.join(out_dir, "summaries", "train")
        train_summary_writer = tf.train.SummaryWriter(train_summary_dir,
                                                      sess.graph_def)

        tf.initialize_all_variables().run()
        saver = tf.train.Saver(tf.all_variables())
        if args.init_from is not None:
            saver.restore(sess, ckpt.model_checkpoint_path)
        for e in range(args.num_epochs):
            # todo: learningrate
            #state = m.initial_state.eval()
            # train
            outs = []
            cnt_words = 0
            for b in range(num_batches_train):
                x, y = it_train.next()
                x, mx, lx = prepare_data(x, args.seq_len)
                y, my, ly = prepare_data(y, args.seq_len)
                feed = dict(zip(inputs, [x, lx, y, my]))
                out, _, step, summaries = sess.run(
                    [m.cost, train_op, global_step, train_summary_op], feed)
                outs.append(out)
                cnt_words += np.sum(ly)
                train_summary_writer.add_summary(summaries, step)

                # save the model
                if (e*num_batches_train + b + 1) % args.save_every == 0 \
                        or (e == args.num_epochs - 1 and b == num_batches_train - 1):
                    print('Save at step {}: {:.3f}'.format(
                        e * num_batches_train + b,
                        np.exp(out * args.batch_size / np.sum(ly))))
                    checkpoint_path = os.path.join(args.save_dir, 'model_ckpt')
                    saver.save(sess,
                               checkpoint_path,
                               global_step=e * num_batches_train + b)
            print('Epoch {}: train loss {}'.format(
                e, np.exp(np.sum(outs) * args.batch_size / cnt_words)))

            # valid
            outs = []
            cnt_words = 0
            for b in range(num_batches_valid):
                x, y = it_valid.next()
                x, mx, lx = prepare_data(x, args.seq_len)
                y, my, ly = prepare_data(y, args.seq_len)
                feed = dict(zip(inputs, [x, lx, y, my]))
                out, = sess.run([
                    m.cost,
                ], feed)
                #print(np.exp(out*args.batch_size/np.sum(ly)))
                outs.append(out)
                cnt_words += np.sum(ly)
            print('Epoch {}: valid loss {}'.format(
                e, np.exp(np.sum(outs) * args.batch_size / cnt_words)))

            # test
            outs = []
            cnt_words = 0
            for b in range(num_batches_test):
                x, y = it_test.next()
                x, mx, lx = prepare_data(x, args.seq_len)
                y, my, ly = prepare_data(y, args.seq_len)
                feed = dict(zip(inputs, [x, lx, y, my]))
                out, = sess.run([
                    m.cost,
                ], feed)
                outs.append(out)
                cnt_words += np.sum(ly)
            print('Epoch {}: test loss {}'.format(
                e, np.exp(np.sum(outs) * args.batch_size / cnt_words)))
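
# Minimal hypothetical driver (not in the original source) wiring the
# functions above to a command line. The flag names mirror the attributes
# read from `args`; the defaults are illustrative only. With the duplicate
# definitions in this file, the last train/sample bodies are the ones called.
if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--mode', choices=['train', 'sample', 'reload'],
                        default='train')
    parser.add_argument('--data_dir', default='data')
    parser.add_argument('--save_dir', default='save')
    parser.add_argument('--init_from', default=None)
    parser.add_argument('--device', default='/gpu:0')
    parser.add_argument('--rnn_type', default='gru')
    parser.add_argument('--batch_size', type=int, default=32)
    parser.add_argument('--dim_emb', type=int, default=128)
    parser.add_argument('--num_units', type=int, default=256)
    parser.add_argument('--seq_len', type=int, default=20)
    parser.add_argument('--grad_clip', type=float, default=5.0)
    parser.add_argument('--learning_rate', type=float, default=1e-3)
    parser.add_argument('--num_epochs', type=int, default=10)
    parser.add_argument('--save_every', type=int, default=1000)
    parser.add_argument('--beam_size', type=int, default=10)
    parser.add_argument('--max_seq_len', type=int, default=20)
    args = parser.parse_args()
    {'train': train, 'sample': sample, 'reload': reload}[args.mode](args)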