Example #1
def decode():
    # Load model config
    config = load_config(FLAGS)

    # Load source data to decode
    test_set = TextIterator(source=config['decode_input'],
                            batch_size=config['decode_batch_size'],
                            source_dict=config['source_vocabulary'],
                            maxlen=None,
                            n_words_source=config['num_encoder_symbols'])

    # Load inverse dictionary used in decoding
    target_inverse_dict = data_utils.load_inverse_dict(
        config['target_vocabulary'])

    # Initiate TF session
    with tf.Session(config=tf.ConfigProto(
            allow_soft_placement=FLAGS.allow_soft_placement,
            log_device_placement=FLAGS.log_device_placement,
            gpu_options=tf.GPUOptions(allow_growth=True))) as sess:

        # Reload existing checkpoint
        model = load_model(sess, config)
        try:
            print('Decoding {}..'.format(FLAGS.decode_input))
            if FLAGS.write_n_best:
                fout = [data_utils.fopen(("%s_%d" % (FLAGS.decode_output, k)), 'w') \
                        for k in range(FLAGS.beam_width)]
            else:
                fout = [data_utils.fopen(FLAGS.decode_output, 'w')]

            for idx, source_seq in enumerate(test_set.next()):
                print('Source', source_seq)
                source, source_len = prepare_batch(source_seq)
                print('Source', source, 'Source Len', source_len)
                # predicted_ids: GreedyDecoder; [batch_size, max_time_step, 1]
                # BeamSearchDecoder; [batch_size, max_time_step, beam_width]
                predicted_ids = model.predict(sess,
                                              encoder_inputs=source,
                                              encoder_inputs_length=source_len)
                print(predicted_ids)
                # Write decoding results
                for k, f in reversed(list(enumerate(fout))):
                    for seq in predicted_ids:
                        f.write(
                            str(
                                data_utils.seq2words(
                                    seq[:, k], target_inverse_dict)) + '\n')
                    if not FLAGS.write_n_best:
                        break
                print('{}th line decoded'.format(idx *
                                                 FLAGS.decode_batch_size))

            print('Decoding terminated')
        except IOError:
            pass
        finally:
            for f in fout:
                f.close()
Example #2
def decode(config):

    model, config = load_model(config)
    
    # Load source data to decode
    test_set = TextIterator(source=config['decode_input'],
                            source_dict=config['src_vocab'],
                            batch_size=config['batch_size'],
                            n_words_source=config['num_enc_symbols'],
                            maxlen=None)
    target_inv_dict = load_inv_dict(config['tgt_vocab'])

    if use_cuda:
        print 'Using gpu..'
        model = model.cuda()

    try:
        print 'Decoding starts..'
        fout = fopen(config['decode_output'], 'w')
        for idx, source_seq in enumerate(test_set):
            source, source_len = prepare_batch(source_seq)

            preds_prev = torch.zeros(config['batch_size'], config['max_decode_step']).long()
            preds_prev[:,0] += data_utils.start_token
            preds = torch.zeros(config['batch_size'], config['max_decode_step']).long()

            if use_cuda:
                source = Variable(source.cuda())
                source_len = Variable(source_len.cuda())
                preds_prev = Variable(preds_prev.cuda())
                preds = preds.cuda()
            else:
                source = Variable(source)
                source_len = Variable(source_len)
                preds_prev = Variable(preds_prev)

            states, memories = model.encode(source, source_len)
            
            for t in xrange(config['max_decode_step']):
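                # Greedy decoding: at each step take the argmax over the target
                # vocabulary and feed it back in as the decoder input for the next step.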
                # logits: [batch_size x max_decode_step, tgt_vocab_size]
                _, logits = model.decode(preds_prev, None, memories, keep_len=True)
                # outputs: [batch_size, max_decode_step]
                outputs = torch.max(logits, dim=1)[1].view(config['batch_size'], -1)
                preds[:,t] = outputs[:,t].data
                if t < config['max_decode_step'] - 1:
                    preds_prev[:,t+1] = outputs[:,t]

            for i in xrange(len(preds)):
                fout.write(str(seq2words(preds[i], target_inv_dict)) + '\n')
                fout.flush()

            print '  {}th line decoded'.format(idx * config['batch_size'])
        print 'Decoding terminated'

    except IOError:
        pass
    finally:
        fout.close()
Example #3
def decode(config):
    model, config = load_model(config)
    # Load source data to decode
    test_set = TextIterator(
        source=config['decode_input'],
        source_dict=config['src_vocab'],
        batch_size=config['batch_size'],
        maxlen=None,
        n_words_source=config['num_enc_symbols'],
        shuffle_each_epoch=False,
        sort_by_length=False,
    )
    target_inv_dict = load_inv_dict(config['tgt_vocab'])

    lines = 0
    max_decode_step = config['max_decode_step']
    print 'Decoding starts..'
    with fopen(config['decode_output'], 'w') as fout:
        for idx, source_seq in enumerate(test_set):
            source, source_len = prepare_batch(source_seq)

            preds_prev = torch.zeros(len(source), max_decode_step).long()
            preds_prev[:, 0] += data_utils.start_token
            preds = torch.zeros(len(source), max_decode_step).long()

            if use_cuda:
                source = Variable(source.cuda())
                source_len = Variable(source_len.cuda())
                preds_prev = Variable(preds_prev.cuda())
                preds = preds.cuda()
            else:
                source = Variable(source)
                source_len = Variable(source_len)
                preds_prev = Variable(preds_prev)

            states, memories = model.encode(source, source_len)

            for t in xrange(max_decode_step):
                # logits: [batch_size x max_decode_step, tgt_vocab_size]
                _, logits = model.decode(preds_prev[:, :t + 1], states,
                                         memories)
                # outputs: [batch_size, max_decode_step]
                outputs = torch.max(logits, dim=1)[1].view(len(source), -1)
                preds[:, t] = outputs[:, t].data
                if t < max_decode_step - 1:
                    preds_prev[:, t + 1] = outputs[:, t]
            for i in xrange(len(preds)):
                fout.write(str(seq2words(preds[i], target_inv_dict)) + '\n')
                fout.flush()

            lines += source.size(0)
            print '  {}th line decoded'.format(lines)
        print 'Decoding terminated'
Example #4
def predict(config):
    tf.reset_default_graph()
    from data.data_iterator import TextIterator, Butian_TextIterator
    valid_set = TextIterator(source=config['valid'],
                             batch_size=config['batch_size'],
                             source_dict=config['source_vocabulary'])

    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True,
                                          log_device_placement=False,
                                          gpu_options=tf.GPUOptions(
                                              allow_growth=True))) as sess:
        model = Detector(config, 'test')
        model.restore(sess, config['save_path'])
        #model.restore_specific(sess, config['save_path'])
        _acc = 0
        _loss = 0
        _num = 0
        prediction = []
        all_labels = []
        all_pred = []
        for idx, sources in enumerate(valid_set):
            source_seq = sources[0]
            label = sources[1]
            sources, labels = prepare_batch(source_seq,
                                            label,
                                            max_batch=config['max_batch'],
                                            maxlen=config['maxlen'],
                                            stride=config['stride'],
                                            batch_size=config['batch_size'])
            for source, label in zip(sources, labels):
                pred, logit, acc, loss = model.predict(sess, source, label)
                prediction.extend(logit)
                #all_labels.extend(label_binarize(label,classes=["0","1","2"]))
                all_labels.extend(list(map(int, label)))
                all_pred.extend(list(map(int, pred)))
                #print("step {}, size {}, acc {:g}, softmax_loss {:g}".format(model.global_step.eval(), pred.shape, acc, loss))
                _acc += acc * pred.shape[0]
                _loss += loss * pred.shape[0]
                _num += pred.shape[0]
        print(config['save_path'])
        print("step {}, acc {:g}, softmax_loss {:g}".format(
            model.global_step.eval(), _acc / _num, _loss))
        prediction = np.stack(prediction)
        all_labels = np.stack(all_labels)
        all_pred = np.stack(all_pred)
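        # Return the stacked logits, the one-hot (binarized) labels, and the
        # overall accuracy/precision/recall over the whole validation set.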
        return prediction, label_binarize(all_labels, classes=[0, 1, 2]), [
            accuracy_score(all_labels, all_pred),
            precision_score(all_labels, all_pred),
            recall_score(all_labels, all_pred)
        ]
Example #5
File: decode.py  Project: Moirai7/Malware
def decode():
    # Load model config
    config = load_config(FLAGS)
    print(config['source_vocabulary'])
    # Load source data to decode
    test_set = TextIterator(source=config['decode_input'],
                            batch_size=config['decode_batch_size'],
                            source_dict=config['source_vocabulary'],)
    # Load inverse dictionary used in decoding
    
    # Initiate TF session
    with tf.Session(config=tf.ConfigProto(
            allow_soft_placement=FLAGS.allow_soft_placement,
            log_device_placement=FLAGS.log_device_placement,
            gpu_options=tf.GPUOptions(allow_growth=True))) as sess:

        # Reload existing checkpoint
        model = load_model(sess, config)
        try:
            if FLAGS.write_n_best:
                fout = [data_utils.fopen(("%s_%d" % (FLAGS.decode_output, k)), 'w') \
                        for k in range(FLAGS.beam_width)]
            else:
                fout = [data_utils.fopen(FLAGS.decode_output, 'w')]
            for source_seq, label in test_set:
                # label = test_labels[idx]
                source, source_len = prepare_batch(
                    source_seq,
                    batch_size=config['decode_batch_size'],
                    stride=config['max_seq_length'],
                    maxlen=config['max_seq_length'])
                # predicted_ids: GreedyDecoder; [batch_size, max_time_step, 1]
                # BeamSearchDecoder; [batch_size, max_time_step, beam_width]
                predicted_ids = model.predict(sess, encoder_inputs=source, 
                                              encoder_inputs_length=source_len)
                   
                # Write decoding results
                for k, f in reversed(list(enumerate(fout))):
                    f.write(str(source_seq)+'\t\t\t')
                    res = []
                    for seq in predicted_ids:
                        res.append(list(seq[:,k]))
                    f.write(str(res)+'\n')
                    if not FLAGS.write_n_best:
                        break
                
            print('Decoding terminated')
        except IOError:
            pass
        finally:
            for f in fout:
                f.close()
Example #6
def predict_without_window(config):
    config['max_epochs'] = 15
    config['batch_size'] = 6
    tf.reset_default_graph()
    from data.data_iterator import TextIterator
    valid_set = TextIterator(source=config['valid'],
                             batch_size=config['batch_size'],
                             source_dict=config['source_vocabulary'],
                             shuffle_each_epoch=False)
    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True,
                                          log_device_placement=False,
                                          gpu_options=tf.GPUOptions(
                                              allow_growth=True))) as sess:
        model = Detector(config, 'test')
        model.restore(sess, config['save_path'])
        _acc = 0
        _loss = 0
        _num = 0
        prediction = []
        all_labels = []
        all_pred = []
        for idx, test_sources in enumerate(valid_set):
            sub_source_seq = test_sources[0]
            sub_label = test_sources[1]
            sub_sources, _ = prepare_batch_without_window(
                sub_source_seq, batch_size=config['batch_size'])
            pred, logit, acc, loss = model.predict(sess, sub_sources,
                                                   sub_label)
            prediction.extend(logit)
            all_labels.extend(list(map(int, sub_label)))
            all_pred.extend(list(map(int, pred)))
            _acc += acc * pred.shape[0]
            _loss += loss * pred.shape[0]
            _num += pred.shape[0]
        print("step {}, acc {:g}, softmax_loss {:g}".format(
            model.global_step.eval(), _acc / _num, _loss))
        prediction = np.stack(prediction)
        all_labels = np.stack(all_labels)
        all_pred = np.stack(all_pred)
        return prediction, label_binarize(all_labels,
                                          classes=[0, 1, 2]), accuracy_score(
                                              all_labels, all_pred)
Example #7
def pretrain():
    # Load parallel data to train
    mwlist = data_utils.readMalware(FLAGS.source_train_data, 0)

    print('Loading training data..')
    train_set = TextIterator(source=FLAGS.source_train_data,
                             source_dict=FLAGS.source_vocabulary,
                             batch_size=FLAGS.batch_size,
                             shuffle_each_epoch=FLAGS.shuffle_each_epoch,
                             maxibatch_size=FLAGS.max_load_batches)

    if FLAGS.source_valid_data:
        print('Loading validation data..')
        valid_set = TextIterator(source=FLAGS.source_valid_data,
                                 source_dict=FLAGS.source_vocabulary,
                                 batch_size=FLAGS.batch_size,
                                 shuffle_each_epoch=FLAGS.shuffle_each_epoch,
                                 maxibatch_size=FLAGS.max_load_batches)
    else:
        valid_set = None
    # Initiate TF session
    with tf.Session(config=tf.ConfigProto(
            allow_soft_placement=FLAGS.allow_soft_placement,
            log_device_placement=FLAGS.log_device_placement,
            gpu_options=tf.GPUOptions(allow_growth=True))) as sess:

        # Create a new model or reload existing checkpoint
        model = create_model(sess, FLAGS, mwlist)
        # Create a log writer object
        log_writer = tf.summary.FileWriter(FLAGS.model_dir, graph=sess.graph)

        step_time, loss = 0.0, 0.0
        words_seen, sents_seen = 0, 0
        start_time = time.time()

        # Training loop
        print('Training..')
        for epoch_idx in range(FLAGS.max_epochs):
            if model.global_epoch_step.eval() >= FLAGS.max_epochs:
                print('Training is already complete.', \
                      'current epoch:{}, max epoch:{}'.format(model.global_epoch_step.eval(), FLAGS.max_epochs))
                break

            for source_seq, label in train_set:
                # Get a batch from training parallel data
                source, source_len = prepare_batch_without_window(
                    source_seq, batch_size=FLAGS.batch_size)
                _, step_loss, _, decoder, target, _, acc, summary = model.pretrain(
                    sess,
                    encoder_inputs=source,
                    encoder_inputs_length=source_len)
                loss += float(step_loss) / FLAGS.display_freq
                words_seen += float(np.sum(source_len + source_len))
                sents_seen += float(source.shape[0])  # batch_size
                print(decoder, target)
                if model.global_step.eval() % FLAGS.display_freq == 0:

                    avg_perplexity = math.exp(
                        float(loss)) if loss < 300 else float("inf")

                    time_elapsed = time.time() - start_time
                    step_time = time_elapsed / FLAGS.display_freq

                    words_per_sec = words_seen / time_elapsed
                    sents_per_sec = sents_seen / time_elapsed

                    print('Epoch ', model.global_epoch_step.eval(), 'Step ', model.global_step.eval(), \
                          'Perplexity {0:.2f}'.format(avg_perplexity), 'acc {0:.2f}'.format(acc), 'Step-time ', step_time, \
                          '{0:.2f} sents/s'.format(sents_per_sec), '{0:.2f} words/s'.format(words_per_sec))
                    print(' {} decoder'.format(decoder), decoder.shape)
                    print(' {} target'.format(target), target.shape)

                    loss = 0
                    words_seen = 0
                    sents_seen = 0
                    start_time = time.time()

                    # Record training summary for the current batch
                    log_writer.add_summary(summary, model.global_step.eval())

                # Execute a validation step
                if valid_set and model.global_step.eval(
                ) % FLAGS.valid_freq == 0:
                    print('Validation step')
                    valid_loss = 0.0
                    valid_sents_seen = 0
                    for val_source_seq, val_label in valid_set:
                        source, source_len = prepare_batch_without_window(
                            val_source_seq, batch_size=FLAGS.batch_size)
                        step_loss, _, decoder, target, _, acc, summary = model.eval(
                            sess,
                            encoder_inputs=source,
                            encoder_inputs_length=source_len)
                        batch_size = source.shape[0]

                        valid_loss += step_loss * batch_size
                        valid_sents_seen += batch_size
                        #print('  {} logit'.format(logit))
                        print(' {} decoder'.format(decoder), decoder.shape)
                        print(' {} target'.format(target), target.shape)

                    valid_loss = valid_loss / valid_sents_seen
                    print(
                        'Valid perplexity: {0:.2f}'.format(
                            math.exp(valid_loss)), 'acc {0:.2f}'.format(acc))

                # Save the model checkpoint
                if model.global_step.eval() % FLAGS.save_freq == 0:
                    print('Saving the model..')
                    checkpoint_path = os.path.join(FLAGS.model_dir,
                                                   FLAGS.model_name)
                    model.save(sess,
                               checkpoint_path,
                               global_step=model.global_step)
                    json.dump(model.config,
                              open(
                                  '%s-%d.json' %
                                  (checkpoint_path, model.global_step.eval()),
                                  'w'),
                              indent=2)

            # Increase the epoch index of the model
            model.global_epoch_step_op.eval()
            print('Epoch {0:} DONE'.format(model.global_epoch_step.eval()))

        print('Saving the last model..')
        checkpoint_path = os.path.join(FLAGS.model_dir, FLAGS.model_name)
        model.save(sess, checkpoint_path, global_step=model.global_step)
        json.dump(
            model.config,
            open('%s-%d.json' % (checkpoint_path, model.global_step.eval()),
                 'w'),
            indent=2)

    print('Training Terminated')
Example #8
File: train.py  Project: Moirai7/Malware
def train():
    # Load model config
    config = load_config(FLAGS)

    # Load source data to decode
    test_set = TextIterator(source=config['source_train_data'],
                            batch_size=config['decode_batch_size'],
                            source_dict=config['source_vocabulary'],
                            maxlen=None,
                            n_words_source=config['num_encoder_symbols'])
    #test_set, test_labels = data_utils.load_data('test')
    #valid_set, valid_labels = data_utils.load_data('valid')
    # No validation data is loaded above; keep valid_set defined so the
    # validation check inside the training loop is skipped safely.
    valid_set = None

    # Initiate TF session
    with tf.Session(config=tf.ConfigProto(
            allow_soft_placement=FLAGS.allow_soft_placement,
            log_device_placement=FLAGS.log_device_placement,
            gpu_options=tf.GPUOptions(allow_growth=True))) as sess:

        # Reload existing checkpoint
        model = load_model(sess, config)

        # Create a log writer object
        log_writer = tf.summary.FileWriter(FLAGS.model_path, graph=sess.graph)

        step_time, loss = 0.0, 0.0
        words_seen, sents_seen = 0, 0
        start_time = time.time()

        for epoch_idx in range(FLAGS.max_epochs):
            if model.global_epoch_step.eval() >= FLAGS.max_epochs:
                print('Training is already complete.', \
                      'current epoch:{}, max epoch:{}'.format(model.global_epoch_step.eval(), FLAGS.max_epochs))
                break
            for idx, source_seq in enumerate(test_set):
                source, source_len = prepare_batch(source_seq)
                # predicted_ids: GreedyDecoder; [batch_size, max_time_step, 1]
                # BeamSearchDecoder; [batch_size, max_time_step, beam_width]
                step_loss, summary, predicted_ids = model.train(
                    sess,
                    encoder_inputs=source,
                    encoder_inputs_length=source_len)
                loss += float(step_loss) / FLAGS.display_freq
                words_seen += float(np.sum(source_len))
                sents_seen += float(source.shape[0])  # batch_size
                if model.global_step.eval() % FLAGS.display_freq == 0:

                    avg_perplexity = math.exp(
                        float(loss)) if loss < 300 else float("inf")

                    time_elapsed = time.time() - start_time
                    step_time = time_elapsed / FLAGS.display_freq

                    words_per_sec = words_seen / time_elapsed
                    sents_per_sec = sents_seen / time_elapsed

                    print('Epoch ', model.global_epoch_step.eval(), 'Step ', model.global_step.eval(), \
                          'Perplexity {0:.2f}'.format(avg_perplexity), 'Step-time ', step_time, \
                          '{0:.2f} sents/s'.format(sents_per_sec), '{0:.2f} words/s'.format(words_per_sec))

                    loss = 0
                    words_seen = 0
                    sents_seen = 0
                    start_time = time.time()

                    # Record training summary for the current batch
                    log_writer.add_summary(summary, model.global_step.eval())

                    # Write decoding results
                    print(
                        str(source_seq),
                        '\t',
                    )
                    for i in range(len(source_seq)):
                        res = []
                        for seq in predicted_ids:
                            res.append(list(seq[:, i]))
                        print(str(res))
                    print('  {}th line decoded'.format(
                        idx * FLAGS.decode_batch_size))
                if valid_set and model.global_step.eval(
                ) % FLAGS.valid_freq == 0:
                    print('Validation step')
                    valid_loss = 0.0
                    valid_sents_seen = 0
                    for idx, source_seq in enumerate(valid_set):
                        # label = test_labels[idx]
                        source, source_len = prepare_batch(source_seq)
                        step_loss, summary, predicted_ids = model.train(
                            sess,
                            encoder_inputs=source,
                            encoder_inputs_length=source_len)
                        batch_size = source.shape[0]
                        valid_loss += step_loss * batch_size
                        valid_sents_seen += batch_size
                    valid_loss = valid_loss / valid_sents_seen
                    print('Valid perplexity: {0:.2f}'.format(
                        math.exp(valid_loss)))

                # Save the model checkpoint
                if model.global_step.eval() % FLAGS.save_freq == 0:
                    print('Saving the model..')
                    checkpoint_path = os.path.join(FLAGS.model_dir,
                                                   FLAGS.model_name)
                    model.save(sess,
                               checkpoint_path,
                               global_step=model.global_step)
                    json.dump(model.config,
                              open(
                                  '%s-%d.json' %
                                  (checkpoint_path, model.global_step.eval()),
                                  'w'),
                              indent=2)
Example #9
def train_without_window(config, maxs):
    config['max_epochs'] = 15
    config['batch_size'] = 6
    tf.reset_default_graph()
    from data.data_iterator import TextIterator
    test_set = TextIterator(source=config['input'],
                            batch_size=config['batch_size'],
                            source_dict=config['source_vocabulary'],
                            shuffle_each_epoch=True)
    valid_set = TextIterator(source=config['valid'],
                             batch_size=config['batch_size'],
                             source_dict=config['source_vocabulary'],
                             shuffle_each_epoch=False)
    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True,
                                          log_device_placement=False,
                                          gpu_options=tf.GPUOptions(
                                              allow_growth=True))) as sess:
        model = Detector(config, 'pretrain')
        model.restore(sess, config['save_path'])
        for epoch_idx in range(config['max_epochs']):
            print(epoch_idx)
            oldtime = datetime.datetime.now()
            for idx, train_sources in enumerate(test_set):
                source_seq = train_sources[0]
                label = train_sources[1]
                sources, _ = prepare_batch_without_window(
                    source_seq, batch_size=config['batch_size'])
                loss, acc, pred, logit, summary, _, _labels, _1 = model.pretrain(
                    sess, sources, label)
                print("step {}, size {}, acc {:g}, softmax_loss {:g}".format(
                    model.global_step.eval(), pred.shape, acc, loss))
            newtime = datetime.datetime.now()
            print('%s seconds' % (newtime - oldtime).seconds)
            _acc = 0
            _loss = 0
            _num = 0
            for idx, test_sources in enumerate(valid_set):
                sub_source_seq = test_sources[0]
                sub_label = test_sources[1]
                sub_sources, _ = prepare_batch_without_window(
                    sub_source_seq, batch_size=config['batch_size'])
                pred, logit, acc, loss = model.predict(sess, sub_sources,
                                                       sub_label)
                print("step {}, size {}, acc {:g}, softmax_loss {:g}".format(
                    model.global_step.eval(), pred.shape, acc, loss))
                _acc += acc * pred.shape[0]
                _loss += loss * pred.shape[0]
                _num += pred.shape[0]
            print("step {}, acc {:g}, softmax_loss {:g}".format(
                model.global_step.eval(), _acc / _num, _loss))
            if _acc / _num > maxs:
                checkpoint_path = os.path.join(config['save_path'], 'detector')
                model.save(sess,
                           checkpoint_path,
                           global_step=model.global_step)
                json.dump(model.config,
                          open(
                              '%s-%d.json' %
                              (checkpoint_path, model.global_step.eval()),
                              'w'),
                          indent=2)
                maxs = _acc / _num
Example #10
def train(config, maxs):
    check = False
    tf.reset_default_graph()
    from data.data_iterator import TextIterator, Butian_TextIterator
    test_set = TextIterator(source=config['input'],
                            batch_size=config['batch_size'],
                            source_dict=config['source_vocabulary'],
                            shuffle_each_epoch=True)
    valid_set = TextIterator(source=config['valid'],
                             batch_size=config['batch_size'],
                             source_dict=config['source_vocabulary'],
                             shuffle_each_epoch=False)
    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True,
                                          log_device_placement=False,
                                          gpu_options=tf.GPUOptions(
                                              allow_growth=True))) as sess:
        model = Detector(config, 'pretrain')
        #model.restore(sess, config['save_path'])
        model.restore_specific(
            sess,
            '/home/dbtest/lan/Malware/model/detector/gru_False_2att/detector-9650'
        )
        for epoch_idx in range(config['max_epochs']):
            #print(epoch_idx)
            for idx, train_sources in enumerate(test_set):
                source_seq = train_sources[0]
                label = train_sources[1]
                sources, labels = prepare_batch(
                    source_seq,
                    label,
                    max_batch=config['max_batch'],
                    maxlen=config['maxlen'],
                    stride=config['stride'],
                    batch_size=config['batch_size'])
                for source, label in zip(sources, labels):
                    loss, acc, pred, logit, summary, _, _labels, _1 = model.pretrain(
                        sess, source, label)
                    #print("step {}, size {}, acc {:g}, softmax_loss {:g}".format(model.global_step.eval(), pred.shape, acc, loss))
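                    # Every 50 training steps, evaluate on the validation set and
                    # checkpoint the model whenever validation accuracy improves.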
                    if (model.global_step.eval() % 50 == 0):
                        _acc = 0
                        _loss = 0
                        _num = 0
                        for idx, test_sources in enumerate(valid_set):
                            sub_source_seq = test_sources[0]
                            sub_label = test_sources[1]
                            sub_sources, sub_labels = prepare_batch(
                                sub_source_seq,
                                sub_label,
                                max_batch=config['max_batch'],
                                maxlen=config['maxlen'],
                                stride=config['stride'],
                                batch_size=config['batch_size'])
                            for sub_source, sub_label in zip(
                                    sub_sources, sub_labels):
                                pred, logit, acc, loss = model.predict(
                                    sess, sub_source, sub_label)
                                #print("step {}, size {}, acc {:g}, softmax_loss {:g}".format(model.global_step.eval(), pred.shape, acc, loss))
                                _acc += acc * pred.shape[0]
                                _loss += loss * pred.shape[0]
                                _num += pred.shape[0]
                        print("acc {:g}, softmax_loss {:g}".format(
                            _acc / _num, _loss / _num))
                        #print("step {}, acc {:g}, softmax_loss {:g}".format(model.global_step.eval(), _acc/_num, _loss/_num))
                        if _acc / _num > maxs:
                            save(sess, config, model)
                            maxs = _acc / _num
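All of the examples above share the same skeleton: build a TextIterator over a tokenized source file and its vocabulary, convert each raw batch to padded arrays with prepare_batch (or prepare_batch_without_window), and feed the result to the model. The following is only a minimal sketch of that decoding pattern, assuming the TextIterator, prepare_batch, data_utils and model.predict interfaces used in the examples; the function name decode_minimal and the target_inverse_dict argument are placeholders rather than part of any of the projects above.

def decode_minimal(sess, model, config, target_inverse_dict):
    # Iterate the source file in batches of token-id sequences.
    test_set = TextIterator(source=config['decode_input'],
                            batch_size=config['decode_batch_size'],
                            source_dict=config['source_vocabulary'])
    with data_utils.fopen(config['decode_output'], 'w') as fout:
        for source_seq in test_set:
            # Pad the batch and compute per-sequence lengths.
            source, source_len = prepare_batch(source_seq)
            # predicted_ids: [batch_size, max_time_step, beam_width]
            # (beam_width is 1 when decoding greedily)
            predicted_ids = model.predict(sess,
                                          encoder_inputs=source,
                                          encoder_inputs_length=source_len)
            # Write the first beam of every sequence as plain words.
            for seq in predicted_ids:
                fout.write(str(data_utils.seq2words(seq[:, 0],
                                                    target_inverse_dict)) + '\n')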