Example #1
    def run(self):
        # load data
        args_dict = self._default_configs()
        args = dotdict(args_dict)
        batchedData, maxTimeSteps, totalN = self.load_data(args,
                                                           mode=mode,
                                                           type=level)
        model = model_fn(args, maxTimeSteps)

        # count the number of parameters
        num_params = count_params(model, mode='trainable')
        all_num_params = count_params(model, mode='all')
        model.config['trainable params'] = num_params
        model.config['all params'] = all_num_params
        print(model.config)

        with tf.Session(graph=model.graph) as sess:
            # restore from a stored model if requested
            if keep:
                ckpt = tf.train.get_checkpoint_state(savedir)
                if ckpt and ckpt.model_checkpoint_path:
                    model.saver.restore(sess, ckpt.model_checkpoint_path)
                    print('Model restored from:' + savedir)
                else:
                    # no checkpoint found; fall back to fresh initialization
                    sess.run(model.initial_op)
            else:
                print('Initializing')
                sess.run(model.initial_op)

            for epoch in range(num_epochs):
                ## training
                start = time.time()
                if mode == 'train':
                    print('Epoch', epoch + 1, '...')
                batchErrors = np.zeros(len(batchedData))
                batchRandIxs = np.random.permutation(len(batchedData))

                for batch, batchOrigI in enumerate(batchRandIxs):
                    batchInputs, batchTargetSparse, batchSeqLengths = batchedData[
                        batchOrigI]
                    batchTargetIxs, batchTargetVals, batchTargetShape = batchTargetSparse
                    feedDict = {
                        model.inputX: batchInputs,
                        model.targetIxs: batchTargetIxs,
                        model.targetVals: batchTargetVals,
                        model.targetShape: batchTargetShape,
                        model.seqLengths: batchSeqLengths
                    }

                    if level == 'cha':
                        if mode == 'train':
                            _, l, pre, y, er = sess.run([
                                model.optimizer, model.loss, model.predictions,
                                model.targetY, model.errorRate
                            ],
                                                        feed_dict=feedDict)

                            batchErrors[batch] = er
                            print(
                                '\n{} mode, total:{},batch:{}/{},epoch:{}/{},train loss={:.3f},mean train CER={:.3f}\n'
                                .format(level, totalN, batch + 1,
                                        len(batchRandIxs), epoch + 1,
                                        num_epochs, l, er / batch_size))

                        elif mode == 'test':
                            l, pre, y, er = sess.run([
                                model.loss, model.predictions, model.targetY,
                                model.errorRate
                            ],
                                                     feed_dict=feedDict)
                            batchErrors[batch] = er
                            print(
                                '\n{} mode, total:{},batch:{}/{},test loss={:.3f},mean test CER={:.3f}\n'
                                .format(level, totalN, batch + 1,
                                        len(batchRandIxs), l, er / batch_size))

                    elif level == 'phn':
                        if mode == 'train':
                            _, l, pre, y = sess.run([
                                model.optimizer, model.loss, model.predictions,
                                model.targetY
                            ],
                                                    feed_dict=feedDict)

                            er = get_edit_distance([pre.values], [y.values],
                                                   True, level)
                            print(
                                '\n{} mode, total:{},batch:{}/{},epoch:{}/{},train loss={:.3f},mean train PER={:.3f}\n'
                                .format(level, totalN, batch + 1,
                                        len(batchRandIxs), epoch + 1,
                                        num_epochs, l, er))
                            batchErrors[batch] = er * len(batchSeqLengths)
                        elif mode == 'test':
                            l, pre, y = sess.run(
                                [model.loss, model.predictions, model.targetY],
                                feed_dict=feedDict)
                            er = get_edit_distance([pre.values], [y.values],
                                                   True, level)
                            print(
                                '\n{} mode, total:{},batch:{}/{},test loss={:.3f},mean test PER={:.3f}\n'
                                .format(level, totalN, batch + 1,
                                        len(batchRandIxs), l, er))
                            batchErrors[batch] = er * len(batchSeqLengths)

                    # NOTE: stop the epoch early if the batch error rate saturates at 1.0
                    if er / batch_size == 1.0:
                        break

                    if batch % 30 == 0:
                        print('Truth:\n' + output_to_sequence(y, type=level))
                        print('Output:\n' +
                              output_to_sequence(pre, type=level))

                    if mode == 'train' and (
                        (epoch * len(batchRandIxs) + batch + 1) % 20 == 0 or
                        (epoch == num_epochs - 1
                         and batch == len(batchRandIxs) - 1)):
                        checkpoint_path = os.path.join(savedir, 'model.ckpt')
                        model.saver.save(sess,
                                         checkpoint_path,
                                         global_step=epoch)
                        print('Model has been saved in {}'.format(savedir))
                end = time.time()
                delta_time = end - start
                print('Epoch ' + str(epoch + 1) + ' needs time:' +
                      str(delta_time) + ' s')

                if mode == 'train':
                    # save a checkpoint at the end of every epoch
                    checkpoint_path = os.path.join(savedir, 'model.ckpt')
                    model.saver.save(sess, checkpoint_path, global_step=epoch)
                    print('Model has been saved in {}'.format(savedir))
                    epochER = batchErrors.sum() / totalN
                    print('Epoch', epoch + 1, 'mean train error rate:',
                          epochER)
                    logging(model,
                            logfile,
                            epochER,
                            epoch,
                            delta_time,
                            mode='config')
                    logging(model,
                            logfile,
                            epochER,
                            epoch,
                            delta_time,
                            mode=mode)

                if mode == 'test':
                    with open(os.path.join(resultdir, level + '_result.txt'),
                              'a') as result:
                        result.write(output_to_sequence(y, type=level) + '\n')
                        result.write(
                            output_to_sequence(pre, type=level) + '\n')
                        result.write('\n')
                    epochER = batchErrors.sum() / totalN
                    print(' test error rate:', epochER)
                    logging(model, logfile, epochER, mode=mode)
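
All of these examples wrap the raw config dict in a `dotdict` before use. A minimal sketch of such a helper, assuming it is simply a dict with attribute access (the class shipped with the source code may differ):

    class dotdict(dict):
        """Dict whose keys can also be read and written as attributes."""
        __getattr__ = dict.get          # missing keys yield None instead of raising
        __setattr__ = dict.__setitem__
        __delattr__ = dict.__delitem__

    # usage:
    # args = dotdict({'num_epochs': 10, 'batch_size': 32})
    # print(args.num_epochs)  # 10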
Example #2
    def run(self):
        # load data
        args_dict = self._default_configs()

        args = dotdict(args_dict)

        feature_dirs, label_dirs = get_data(datadir, level, train_dataset,
                                            mode)

        # pair the feature/label dirs; no shuffling here (cf. Example #3)
        FL_pair = list(zip(feature_dirs, label_dirs))
        feature_dirs, label_dirs = zip(*FL_pair)

        for feature_dir, label_dir in zip(feature_dirs, label_dirs):

            batchedData, maxTimeSteps, totalN = self.load_data(
                feature_dir, label_dir, mode, level)

            model = model_fn(args, maxTimeSteps)

            num_params = count_params(model, mode='trainable')
            all_num_params = count_params(model, mode='all')
            model.config['trainable params'] = num_params
            model.config['all params'] = all_num_params

            with tf.Session(graph=model.graph) as sess:

                print('Initializing')
                sess.run(model.initial_op)

                for epoch in range(num_epochs):
                    ## training
                    start = time.time()
                    if mode == 'train':
                        print('Epoch {} ...'.format(epoch + 1))

                    avgloss = 0

                    batchErrors = np.zeros(len(batchedData))

                    batchRandIxs = np.arange(len(batchedData))

                    for batch, batchOrigI in enumerate(batchRandIxs):
                        batchInputs, batchTargetSparse, batchSeqLengths = batchedData[
                            batchOrigI]
                        batchTargetIxs, batchTargetVals, batchTargetShape = batchTargetSparse

                        feedDict = {
                            model.inputX: batchInputs,
                            model.targetIxs: batchTargetIxs,
                            model.targetVals: batchTargetVals,
                            model.targetShape: batchTargetShape,
                            model.seqLengths: batchSeqLengths
                        }

                        if level == 'cha':
                            if mode == 'train':

                                _, l, pre, y, er = sess.run([
                                    model.optimizer, model.loss,
                                    model.predictions, model.targetY,
                                    model.errorRate
                                ],
                                                            feed_dict=feedDict)

                                batchErrors[batch] = er

                                avgloss = avgloss + l

                                print(
                                    '\n{} mode, total:{},batch:{}/{},epoch:{}/{},train loss={:.4f},mean train CER={:.4f}\n'
                                    .format(level, totalN, batch + 1,
                                            len(batchRandIxs), epoch + 1,
                                            num_epochs, l, er / batch_size))

                    end = time.time()
                    delta_time = end - start
                    print('Epoch ' + str(epoch + 1) + ' needs time:' +
                          str(delta_time) + ' s')

                    if mode == 'train':
                        # 8440 appears to be the hard-coded size of the training set
                        ctcloss = avgloss / (8440 / batch_size)
                        print('average CTC loss = ', ctcloss)

                        # save a checkpoint at the end of every epoch
                        checkpoint_path = os.path.join(savedir, 'model.ckpt')
                        model.saver.save(sess,
                                         checkpoint_path,
                                         global_step=epoch + 1)
                        print('Model has been saved in {}'.format(savedir))
                        epochER = batchErrors.sum() / totalN
                        print('Epoch', epoch + 1, 'mean train error rate:',
                              epochER)
                        # log the model config only once, on the first epoch
                        if epoch < 1:
                            logging(model,
                                    logfile,
                                    epochER,
                                    ctcloss,
                                    epoch,
                                    delta_time,
                                    mode='config')
                        logging(model,
                                logfile,
                                epochER,
                                ctcloss,
                                epoch,
                                delta_time,
                                mode=mode)
                        # early stopping once the epoch error rate is near zero
                        if epochER <= 0.003:
                            break
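
`count_params` is called but never defined in these examples. A plausible TensorFlow 1.x sketch, assuming the model object exposes the same `graph` attribute the sessions above use:

    import numpy as np
    import tensorflow as tf

    def count_params(model, mode='trainable'):
        """Total number of scalar parameters in the model's variables."""
        with model.graph.as_default():
            variables = (tf.trainable_variables() if mode == 'trainable'
                         else tf.global_variables())
            return int(sum(np.prod(v.get_shape().as_list()) for v in variables))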
Example #3
    def run(self):
        # load data
        args_dict = self._default_configs()
        args = dotdict(args_dict)
        feature_dirs, label_dirs = get_data(datadir, level, train_dataset, dev_dataset, test_dataset, mode)
        # initial load/build; both are redone per sub-directory in the loop below
        batchedData, maxTimeSteps, totalN = self.load_data(feature_dirs[0], label_dirs[0], mode, level)
        model = model_fn(args, maxTimeSteps)

        ## shuffle feature_dir and label_dir by same order
        FL_pair = list(zip(feature_dirs, label_dirs))
        random.shuffle(FL_pair)
        feature_dirs, label_dirs = zip(*FL_pair)

        for id_dir, (feature_dir, label_dir) in enumerate(zip(feature_dirs, label_dirs)):
            print('dir id:{}'.format(id_dir))
            batchedData, maxTimeSteps, totalN = self.load_data(feature_dir, label_dir, mode, level)
            model = model_fn(args, maxTimeSteps)
            num_params = count_params(model, mode='trainable')
            all_num_params = count_params(model, mode='all')
            model.config['trainable params'] = num_params
            model.config['all params'] = all_num_params
            print(model.config)
            with tf.Session(graph=model.graph) as sess:
                # restore from a stored model if requested
                if keep:
                    ckpt = tf.train.get_checkpoint_state(savedir)
                    if ckpt and ckpt.model_checkpoint_path:
                        model.saver.restore(sess, ckpt.model_checkpoint_path)
                        print('Model restored from:' + savedir)
                    else:
                        # no checkpoint found; fall back to fresh initialization
                        sess.run(model.initial_op)
                else:
                    print('Initializing')
                    sess.run(model.initial_op)

                for epoch in range(num_epochs):
                    ## training
                    start = time.time()
                    if mode == 'train':
                        print('Epoch {} ...'.format(epoch + 1))

                    batchErrors = np.zeros(len(batchedData))
                    batchRandIxs = np.random.permutation(len(batchedData))

                    for batch, batchOrigI in enumerate(batchRandIxs):
                        batchInputs, batchTargetSparse, batchSeqLengths = batchedData[batchOrigI]
                        batchTargetIxs, batchTargetVals, batchTargetShape = batchTargetSparse
                        feedDict = {model.inputX: batchInputs, model.targetIxs: batchTargetIxs,
                                    model.targetVals: batchTargetVals, model.targetShape: batchTargetShape,
                                    model.seqLengths: batchSeqLengths}

                        if level == 'cha':
                            if mode == 'train':
                                _, l, pre, y, er = sess.run([model.optimizer, model.loss,
                                    model.predictions, model.targetY, model.errorRate],
                                    feed_dict=feedDict)

                                batchErrors[batch] = er
                                print('\n{} mode, total:{},subdir:{}/{},batch:{}/{},epoch:{}/{},train loss={:.3f},mean train CER={:.3f}\n'.format(
                                    level, totalN, id_dir+1, len(feature_dirs), batch+1, len(batchRandIxs), epoch+1, num_epochs, l, er/batch_size))

                            elif mode == 'dev':
                                l, pre, y, er = sess.run([model.loss, model.predictions, 
                                    model.targetY, model.errorRate], feed_dict=feedDict)
                                batchErrors[batch] = er
                                print('\n{} mode, total:{},subdir:{}/{},batch:{}/{},dev loss={:.3f},mean dev CER={:.3f}\n'.format(
                                    level, totalN, id_dir+1, len(feature_dirs), batch+1, len(batchRandIxs), l, er/batch_size))

                            elif mode == 'test':
                                l, pre, y, er = sess.run([model.loss, model.predictions, 
                                    model.targetY, model.errorRate], feed_dict=feedDict)
                                batchErrors[batch] = er
                                print('\n{} mode, total:{},subdir:{}/{},batch:{}/{},test loss={:.3f},mean test CER={:.3f}\n'.format(
                                    level, totalN, id_dir+1, len(feature_dirs), batch+1, len(batchRandIxs), l, er/batch_size))
                        elif level == 'seq2seq':
                            raise ValueError('level %s is not supported yet' % str(level))


                        # NOTE: stop the epoch early if the batch error rate saturates at 1.0
                        if er / batch_size == 1.0:
                            break

                        if batch % 20 == 0:
                            print('Truth:\n' + output_to_sequence(y, type=level))
                            print('Output:\n' + output_to_sequence(pre, type=level))

                    
                        if mode == 'train' and ((epoch * len(batchRandIxs) + batch + 1) % 20 == 0 or (
                                epoch == num_epochs - 1 and batch == len(batchRandIxs) - 1)):
                            checkpoint_path = os.path.join(savedir, 'model.ckpt')
                            model.saver.save(sess, checkpoint_path, global_step=epoch)
                            print('Model has been saved in {}'.format(savedir))

                    end = time.time()
                    delta_time = end - start
                    print('Epoch ' + str(epoch + 1) + ' needs time:' + str(delta_time) + ' s')

                    if mode == 'train':
                        # save a checkpoint at the end of every epoch
                        checkpoint_path = os.path.join(savedir, 'model.ckpt')
                        model.saver.save(sess, checkpoint_path, global_step=epoch)
                        print('Model has been saved in {}'.format(savedir))
                        epochER = batchErrors.sum() / totalN
                        print('Epoch', epoch + 1, 'mean train error rate:', epochER)
                        logging(model, logfile, epochER, epoch, delta_time, mode='config')
                        logging(model, logfile, epochER, epoch, delta_time, mode=mode)

                    if mode == 'test' or mode == 'dev':
                        with open(os.path.join(resultdir, level + '_result.txt'), 'a') as result:
                            result.write(output_to_sequence(y, type=level) + '\n')
                            result.write(output_to_sequence(pre, type=level) + '\n')
                            result.write('\n')
                        epochER = batchErrors.sum() / totalN
                        print('{} error rate:'.format(mode), epochER)
                        logging(model, logfile, epochER, mode=mode)
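
Each entry of `batchedData` carries its targets as the `(indices, values, shape)` triple that the sparse placeholders in `feedDict` expect. A self-contained sketch of how a batch of label sequences could be packed into that triple (the loader in the source code may differ in detail):

    import numpy as np

    def to_sparse_tuple(sequences):
        """Pack a list of 1-D int label arrays into the (indices, values,
        shape) triple accepted by a tf.sparse_placeholder."""
        indices, values = [], []
        for n, seq in enumerate(sequences):
            indices.extend((n, t) for t in range(len(seq)))
            values.extend(seq)
        shape = np.array([len(sequences), max(len(s) for s in sequences)],
                         dtype=np.int64)
        return (np.array(indices, dtype=np.int64),
                np.array(values, dtype=np.int32),
                shape)

    # usage:
    # batchTargetIxs, batchTargetVals, batchTargetShape = to_sparse_tuple(labels)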
Example #4
    def test(self):
        # load data
        args = self.args
        batchedData, maxTimeSteps, totalN = self.load_data(args,
                                                           mode='test',
                                                           type=args.level)
        if args.model == 'ResNet':
            model = ResNet(args, maxTimeSteps)
        elif args.model == 'BiRNN':
            model = BiRNN(args, maxTimeSteps)
        elif args.model == 'DBiRNN':
            model = DBiRNN(args, maxTimeSteps)
        else:
            # fail fast instead of leaving `model` unbound below
            raise ValueError('unknown model: {}'.format(args.model))

        num_params = count_params(model, mode='trainable')
        all_num_params = count_params(model, mode='all')
        model.config['trainable params'] = num_params
        model.config['all params'] = all_num_params
        with tf.Session(graph=model.graph) as sess:
            ckpt = tf.train.get_checkpoint_state(args.save_dir)
            if ckpt and ckpt.model_checkpoint_path:
                model.saver.restore(sess, ckpt.model_checkpoint_path)
                print('Model restored from:' + args.save_dir)

            batchErrors = np.zeros(len(batchedData))
            batchRandIxs = np.random.permutation(len(batchedData))
            for batch, batchOrigI in enumerate(batchRandIxs):
                batchInputs, batchTargetSparse, batchSeqLengths = batchedData[
                    batchOrigI]
                batchTargetIxs, batchTargetVals, batchTargetShape = batchTargetSparse
                feedDict = {
                    model.inputX: batchInputs,
                    model.targetIxs: batchTargetIxs,
                    model.targetVals: batchTargetVals,
                    model.targetShape: batchTargetShape,
                    model.seqLengths: batchSeqLengths
                }

                if args.level == 'cha':
                    l, pre, y, er = sess.run([
                        model.loss, model.predictions, model.targetY,
                        model.errorRate
                    ],
                                             feed_dict=feedDict)
                    batchErrors[batch] = er
                    print(
                        '\ntotal:{},batch:{}/{},loss={:.3f},mean CER={:.3f}\n'.
                        format(totalN, batch + 1, len(batchRandIxs), l,
                               er / args.batch_size))

                elif args.level == 'phn':
                    l, pre, y = sess.run(
                        [model.loss, model.predictions, model.targetY],
                        feed_dict=feedDict)
                    er = get_edit_distance([pre.values], [y.values], True,
                                           'test', args.level)
                    print(
                        '\ntotal:{},batch:{}/{},loss={:.3f},mean PER={:.3f}\n'.
                        format(totalN, batch + 1, len(batchRandIxs), l,
                               er / args.batch_size))
                    batchErrors[batch] = er * len(batchSeqLengths)

                print('Truth:\n' + output_to_sequence(y, type=args.level))
                print('Output:\n' + output_to_sequence(pre, type=args.level))
                with open(args.task + '_result.txt', 'a') as result:
                    result.write(output_to_sequence(y, type=args.level) + '\n')
                    result.write(
                        output_to_sequence(pre, type=args.level) + '\n')
                    result.write('\n')
            epochER = batchErrors.sum() / totalN
            print(args.task + ' test error rate:', epochER)
            logging(model, self.logfile, epochER, mode='test')
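
The phoneme branches rely on `get_edit_distance` (note its arity varies: four arguments in Examples #1 and #3, five here). A minimal four-argument sketch built on `tf.edit_distance`, reusing the `to_sparse_tuple` helper sketched after Example #3; the version in the source code may additionally map phonemes to a reduced set for level='phn':

    import tensorflow as tf

    def get_edit_distance(hyp_arr, truth_arr, normalize, level):
        """Mean edit distance between two batches of label sequences."""
        graph = tf.Graph()
        with graph.as_default():
            hyp = tf.sparse_placeholder(tf.int32)
            truth = tf.sparse_placeholder(tf.int32)
            dist = tf.reduce_mean(
                tf.edit_distance(hyp, truth, normalize=normalize))
        with tf.Session(graph=graph) as session:
            return session.run(dist, feed_dict={
                hyp: to_sparse_tuple(hyp_arr),
                truth: to_sparse_tuple(truth_arr)})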
Example #5
    def run(self):
        # load data
        args_dict = self._default_configs()
        args = dotdict(args_dict)

        X, labels = get_data(level, train_dataset, test_dataset, mode)
        totalN = len(X)
        print("X :", len(X))
        num_batches = len(X) // batch_size  # integer division: np.zeros/permutation below need an int
        # round the longest input up to a multiple of 5000 time steps
        maxLength = 0
        for x in X:
            maxLength = max(maxLength, x.shape[1])
        if maxLength % 5000 != 0:
            maxLength = maxLength + 5000 - maxLength % 5000

        #batchedData, maxTimeSteps, totalN = self.load_data(X,labels,batch_size,mode,level)
        maxTimeSteps = maxLength
        model = model_fn(args, maxTimeSteps)
        model.build_graph(args, maxTimeSteps)
        #print("hello")
        #num_params = count_params(model, mode='trainable')
        #all_num_params = count_params(model, mode='all')
        #model.config['trainable params'] = num_params
        #model.config['all params'] = all_num_params
        print(model.config)

        with tf.Session(graph=model.graph) as sess:
            # restore from a stored model if requested
            if keep:
                ckpt = tf.train.get_checkpoint_state(savedir)
                if ckpt and ckpt.model_checkpoint_path:
                    model.saver.restore(sess, ckpt.model_checkpoint_path)
                    print('Model restored from:' + savedir)
                else:
                    sess.run(model.initial_op)
            else:
                print('Initializing')
                sess.run(model.initial_op)

            if mode == 'train':
                writer = tf.summary.FileWriter("loggingdir", graph=model.graph)
                for epoch in range(num_epochs):
                    # training
                    start = time.time()
                    print('Epoch {} ...'.format(epoch + 1))

                    batchErrors = np.zeros(num_batches)
                    batchRandIxs = np.random.permutation(num_batches)
                    for batch, batchOrigI in enumerate(batchRandIxs):
                        # NOTE: a fresh generator is created on every call, so
                        # this relies on load_batched_data sampling batches itself
                        batchInputs, batchTargetSparse, batchSeqLengths = next(
                            load_batched_data(X, labels, batch_size, mode,
                                              level))
                        batchTargetIxs, batchTargetVals, batchTargetShape = batchTargetSparse
                        feedDict = {
                            model.inputX: batchInputs,
                            model.targetIxs: batchTargetIxs,
                            model.targetVals: batchTargetVals,
                            model.targetShape: batchTargetShape,
                            model.seqLengths: batchSeqLengths
                        }
                        if level == 'cha':
                            _, l, pre, y, er, summary = sess.run(
                                [
                                    model.optimizer, model.loss,
                                    model.predictions, model.targetY,
                                    model.errorRate, model.summary_op
                                ],
                                feed_dict=feedDict)
                            writer.add_summary(summary,
                                               epoch * num_batches + batch)

                            batchErrors[batch] = er
                            print(
                                '\n{} mode, batch:{}/{},epoch:{}/{},train loss={:.3f},mean train CER={:.3f}\n'
                                .format(level, batch + 1, len(batchRandIxs),
                                        epoch + 1, num_epochs, l,
                                        er / batch_size))

                        if (batch + 1) % 10 == 0:
                            print('Truth:\n' +
                                  output_to_sequence(y, type=level))
                            print('Output:\n' +
                                  output_to_sequence(pre, type=level))

                        if ((epoch * len(batchRandIxs) + batch + 1) % 20 == 0
                                or (epoch == num_epochs - 1
                                    and batch == len(batchRandIxs) - 1)):
                            checkpoint_path = os.path.join(
                                savedir, 'model.ckpt')
                            model.saver.save(sess,
                                             checkpoint_path,
                                             global_step=epoch)
                            print('Model has been saved in {}'.format(savedir))

                    end = time.time()
                    delta_time = end - start
                    print('Epoch ' + str(epoch + 1) + ' needs time:' +
                          str(delta_time) + ' s')

                    # save a checkpoint at the end of every epoch
                    checkpoint_path = os.path.join(savedir, 'model.ckpt')
                    model.saver.save(sess,
                                     checkpoint_path,
                                     global_step=epoch)
                    print('Model has been saved in {}'.format(savedir))
                    epochER = batchErrors.sum() / totalN
                    print('Epoch', epoch + 1, 'mean train error rate:',
                          epochER)
                    logging(model,
                            logfile,
                            epochER,
                            epoch,
                            delta_time,
                            mode='config')
                    logging(model,
                            logfile,
                            epochER,
                            epoch,
                            delta_time,
                            mode=mode)

            elif mode == 'test':
                testErrors = []
                for data in load_batched_data(X, labels, batch_size, mode,
                                              level):
                    batchInputs, batchTargetSparse, batchSeqLengths = data
                    batchTargetIxs, batchTargetVals, batchTargetShape = batchTargetSparse
                    feedDict = {
                        model.inputX: batchInputs,
                        model.targetIxs: batchTargetIxs,
                        model.targetVals: batchTargetVals,
                        model.targetShape: batchTargetShape,
                        model.seqLengths: batchSeqLengths
                    }
                    # evaluation only: do not run the optimizer at test time
                    l, pre, y, er = sess.run([
                        model.loss, model.predictions, model.targetY,
                        model.errorRate
                    ],
                                             feed_dict=feedDict)
                    testErrors.append(er)
                    with open(os.path.join(resultdir, level + '_result.txt'),
                              'a') as result:
                        result.write(output_to_sequence(y, type=level) + '\n')
                        result.write(
                            output_to_sequence(pre, type=level) + '\n')
                        result.write('\n')
                epochER = sum(testErrors) / totalN
                print(' test error rate:', epochER)
                logging(model, logfile, epochER, mode=mode)
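
`output_to_sequence` turns a decoded tf.SparseTensorValue back into readable text. A rough sketch assuming plain id-to-symbol lookup tables; the tables below are hypothetical placeholders, the real vocabularies come from the data pipeline:

    import numpy as np

    # placeholder vocabularies (hypothetical; the source defines its own)
    CHA_TABLE = [' '] + list('abcdefghijklmnopqrstuvwxyz')
    PHN_TABLE = ['aa', 'ae', 'ah']  # truncated placeholder

    def output_to_sequence(sparse_value, type='phn'):
        """Decode the first sequence in a SparseTensorValue into a string."""
        table = PHN_TABLE if type == 'phn' else CHA_TABLE
        first_row = sparse_value.values[np.asarray(sparse_value.indices)[:, 0] == 0]
        sep = ' ' if type == 'phn' else ''
        return sep.join(table[int(i)] for i in first_row)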
Example #6
    def train(self, args):
        '''Import data, train the model, and save checkpoints.'''
        args.data_dir = args.data_dir + args.style + '/'
        args.save_dir = args.save_dir + args.style + '/'
        if not os.path.exists(args.save_dir):
            os.makedirs(args.save_dir)
        print(args)
        if args.attention:
            print('attention mode')
        text_parser = TextParser(args)
        args.vocab_size = text_parser.vocab_size
        self.word_embedding_file = os.path.join(args.data_dir,
                                                "word_embedding.pkl")

        if args.pretrained:
            raise ValueError(
                "the pretrained path is currently broken; don't set it to True!")
            # NOTE: everything below is unreachable until the bug above is fixed
            if not args.keep:
                raise ValueError(
                    'when pretrained is True, keep must be True!')
            print("pretrained and keep mode...")
            print("restoring pretrained model file")
            ckpt = tf.train.get_checkpoint_state(
                "/home/pony/github/jaylyrics_generation_tensorflow/data/pre-trained/"
            )
            if os.path.exists(os.path.join("./data/pre-trained/", 'config.pkl')) and \
                    os.path.exists(os.path.join("./data/pre-trained/", 'words_vocab.pkl')) and \
                    ckpt and ckpt.model_checkpoint_path:
                with open(os.path.join("./data/pre-trained/", 'config.pkl'),
                          'rb') as f:
                    saved_model_args = cPickle.load(f)
                with open(
                        os.path.join("./data/pre-trained/", 'words_vocab.pkl'),
                        'rb') as f:
                    saved_words, saved_vocab = cPickle.load(f)
            else:
                raise ValueError("configuration doesn't exist!")
        else:
            ckpt = tf.train.get_checkpoint_state(args.save_dir)

        if args.keep and not args.pretrained:
            # check that all necessary files exist
            if os.path.exists(os.path.join(args.save_dir, 'config.pkl')) and \
                    os.path.exists(os.path.join(args.save_dir, 'words_vocab.pkl')) and \
                    ckpt and ckpt.model_checkpoint_path:
                with open(os.path.join(args.save_dir, 'config.pkl'),
                          'rb') as f:
                    saved_model_args = cPickle.load(f)
                with open(os.path.join(args.save_dir, 'words_vocab.pkl'),
                          'rb') as f:
                    saved_words, saved_vocab = cPickle.load(f)
            else:
                raise ValueError("configuration doesn't exist!")

        if args.model == 'seq2seq_rnn':
            model = Model_rnn(args)
        else:
            # fail fast instead of leaving `model` unbound below
            raise ValueError('unknown model: {}'.format(args.model))

        trainable_num_params = count_params(model, mode='trainable')
        all_num_params = count_params(model, mode='all')
        args.num_trainable_params = trainable_num_params
        args.num_all_params = all_num_params
        with open(os.path.join(args.save_dir, 'config.pkl'), 'wb') as f:
            cPickle.dump(args, f)
        with open(os.path.join(args.save_dir, 'words_vocab.pkl'), 'wb') as f:
            cPickle.dump((text_parser.vocab_dict, text_parser.vocab_list), f)

        with tf.Session() as sess:
            if args.keep:
                print('Restoring')
                model.saver.restore(sess, ckpt.model_checkpoint_path)
            else:
                print('Initializing')
                sess.run(model.initial_op)

            sess.run(tf.assign(model.lr, args.learning_rate))
            for e in range(args.num_epochs):
                start = time.time()
                # materialize the initial state as a tensor so it can be eval'd
                model.initial_state = tf.convert_to_tensor(model.initial_state)
                state = model.initial_state.eval()
                total_loss = []
                for b in range(text_parser.num_batches):
                    x, y = text_parser.next_batch()
                    if args.attention:
                        # NOTE: this builds a new truncated_normal op on every
                        # batch (a NumPy variant is sketched after this example)
                        attention_states = sess.run(
                            tf.truncated_normal([
                                args.batch_size, model.attn_length,
                                model.attn_size
                            ],
                                                stddev=0.1,
                                                dtype=tf.float32))
                        feed = {
                            model.input_data: x,
                            model.targets: y,
                            model.initial_state: state,
                            model.attention_states: attention_states
                        }

                    else:
                        feed = {
                            model.input_data: x,
                            model.targets: y,
                            model.initial_state: state
                        }

                    train_loss, state, _, word_embedding = sess.run([
                        model.cost, model.final_state, model.train_op,
                        model.word_embedding
                    ], feed)
                    total_loss.append(train_loss)

                    print("{}/{} (epoch {}), train_loss = {:.3f}" \
                                .format(e * text_parser.num_batches + b, \
                                args.num_epochs * text_parser.num_batches, \
                                e, train_loss))

                    if (e * text_parser.num_batches +
                            b) % args.save_every == 0:
                        checkpoint_path = os.path.join(args.save_dir,
                                                       'model.ckpt')
                        model.saver.save(sess, checkpoint_path, global_step=e)
                        print("model has been saved in:" +
                              str(checkpoint_path))
                        np.save(self.word_embedding_file, word_embedding)
                        print("word embedding matrix has been saved in:" +
                              str(self.word_embedding_file))

                end = time.time()
                delta_time = end - start
                ave_loss = np.array(total_loss).mean()
                logging(model, ave_loss, e, delta_time, mode='train')
                if ave_loss < 0.1:
                    checkpoint_path = os.path.join(args.save_dir, 'model.ckpt')
                    model.saver.save(sess, checkpoint_path, global_step=e)
                    print("model has been saved in:" + str(checkpoint_path))
                    np.save(self.word_embedding_file, word_embedding)
                    print("word embedding matrix has been saved in:" +
                          str(self.word_embedding_file))
                    break
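
A side note on the attention-state sampling in Example #6: because `tf.truncated_normal` is called inside the batch loop, a fresh random op is appended to the graph on every batch. Assuming only the random values matter, a NumPy stand-in keeps the graph fixed; clipping at two standard deviations roughly imitates truncation:

    import numpy as np

    def sample_attention_states(batch_size, attn_length, attn_size, stddev=0.1):
        """NumPy stand-in for tf.truncated_normal: sample, then clip at 2*stddev."""
        states = np.random.normal(0.0, stddev,
                                  size=(batch_size, attn_length, attn_size))
        return np.clip(states, -2 * stddev, 2 * stddev).astype(np.float32)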