def train():
    # Load parallel data to train
    print 'Loading training data..'
    train_set = SegTextIterator(source=FLAGS.source_train_data,
                               time_step=FLAGS.time_step,
                               batch_size=FLAGS.batch_size, set_type=0, test_set_partition=FLAGS.test_set_partition, balance=FLAGS.data_balance, noise=FLAGS.noise)

    if FLAGS.source_valid_data:
        print 'Loading validation data..'
        valid_set = SegTextIterator(source=FLAGS.source_valid_data,
                                   time_step=FLAGS.time_step,
                                   batch_size=FLAGS.batch_size, set_type=1, test_set_partition=FLAGS.test_set_partition, balance=FLAGS.data_balance, noise=FLAGS.noise)
    else:
        valid_set = None

    # Initiate TF session
    with tf.Session(config=tf.ConfigProto(allow_soft_placement=FLAGS.allow_soft_placement, 
        log_device_placement=FLAGS.log_device_placement, gpu_options=tf.GPUOptions(allow_growth=True))) as sess:

        # Create a log writer object
        log_writer = tf.summary.FileWriter(FLAGS.model_dir, graph=sess.graph)

        # Create a new model or reload existing checkpoint
        model = create_model(sess, FLAGS)

        step_time, loss = 0.0, 0.0
        words_seen, sents_seen = 0, 0
        start_time = time.time()

        # Training loop
        print 'Training..'
        for epoch_idx in xrange(FLAGS.max_epochs):
            if model.global_epoch_step.eval() >= FLAGS.max_epochs:
                print 'Training is already complete.', \
                      'current epoch:{}, max epoch:{}'.format(model.global_epoch_step.eval(), FLAGS.max_epochs)
                break

            for source_data, source_word, source_char, source_word_mask, source_char_mask, target_seq, true_batch_size in train_set:
                source_data, source_len, target, target_len = prepare_train_batch(source_data, target_seq)
                source = source_data, source_word, source_word_mask, source_char, source_char_mask
                
                #print source_char
                #print "===="
                #print source_char_mask
                #print source_len
                #print "===="
                #print target
                #print target_len
                #print type(source_char), type(source_char_mask), type(source_len), type(target_len)
                #exit()
                if source is None or target is None:
                    print 'No samples under max_seq_length ', FLAGS.max_seq_length
                    continue

                # Execute a single training step
                step_loss, summary = model.train(sess, encoder_inputs=source, encoder_inputs_length=source_len, 
                                                 decoder_inputs=target, decoder_inputs_length=target_len)

                loss += float(step_loss) / FLAGS.display_freq
                words_seen += float(np.sum(source_len+target_len))
                sents_seen += float(FLAGS.batch_size) # batch_size

                if model.global_step.eval() % FLAGS.display_freq == 0:

                    avg_perplexity = math.exp(float(loss)) if loss < 300 else float("inf")

                    time_elapsed = time.time() - start_time
                    step_time = time_elapsed / FLAGS.display_freq

                    words_per_sec = words_seen / time_elapsed
                    sents_per_sec = sents_seen / time_elapsed

                    print 'Epoch ', model.global_epoch_step.eval(), 'Step ', model.global_step.eval(), \
                          'Perplexity {0:.5f}'.format(avg_perplexity), '    Step-time ', step_time, \
                          '{0:.2f} sents/s'.format(sents_per_sec), '{0:.2f} words/s'.format(words_per_sec)

                    loss = 0
                    words_seen = 0
                    sents_seen = 0
                    start_time = time.time()

                    # Record training summary for the current batch
                    log_writer.add_summary(summary, model.global_step.eval())

                # Execute a validation step
                if valid_set and model.global_step.eval() % FLAGS.valid_freq == 0:
                    print 'Validation step'
                    valid_loss = 0.0
                    valid_sents_seen = 0
                    for source_seq, target_seq in valid_set:
                        # Get a batch from validation parallel data
                        source, source_len, target, target_len = prepare_train_batch(source_seq, target_seq)

                        # Compute validation loss: average per word cross entropy loss
                        step_loss, summary = model.eval(sess, encoder_inputs=source, encoder_inputs_length=source_len,
                                                        decoder_inputs=target, decoder_inputs_length=target_len)
                        batch_size = source.shape[0]

                        valid_loss += step_loss * batch_size
                        valid_sents_seen += batch_size
                        print '  {} samples seen'.format(valid_sents_seen)

                    valid_loss = valid_loss / valid_sents_seen
                    print 'Valid perplexity: {0:.2f}'.format(math.exp(valid_loss))

                # Save the model checkpoint
                if model.global_step.eval() % FLAGS.save_freq == 0:
                    print 'Saving the model..'
                    checkpoint_path = os.path.join(FLAGS.model_dir, FLAGS.model_name)
                    model.save(sess, checkpoint_path, global_step=model.global_step)
                    json.dump(model.config,
                              open('%s-%d.json' % (checkpoint_path, model.global_step.eval()), 'wb'),
                              indent=2)

            # Increase the epoch index of the model
            model.global_epoch_step_op.eval()
            print 'Epoch {0:} DONE'.format(model.global_epoch_step.eval())
        
        print 'Saving the last model..'
        checkpoint_path = os.path.join(FLAGS.model_dir, FLAGS.model_name)
        model.save(sess, checkpoint_path, global_step=model.global_step)
        json.dump(model.config,
                  open('%s-%d.json' % (checkpoint_path, model.global_step.eval()), 'wb'),
                  indent=2)
        
    print 'Training Terminated'
# ---- Example #2 (예제 #2) — scraped-snippet separator; score: 0 ----
def train(config):
    """Train the Quasi-RNN seq2seq model described by ``config``.

    Loads the parallel training corpus (and an optional validation
    corpus) with BiTextIterator, builds or restores the model with
    create_model(), then runs the epoch/step training loop with
    periodic progress display, validation, and checkpointing.

    NOTE(review): the Variable(...) wrappers and ``.data[0]`` scalar
    access target a pre-0.4 PyTorch API — confirm the pinned torch
    version before modernising.
    """
    # Load parallel data to train
    print 'Loading training data..'
    train_set = BiTextIterator(source=config.src_train, target=config.tgt_train,
                               source_dict=config.src_vocab, target_dict=config.tgt_vocab,
                               batch_size=config.batch_size, maxlen=config.max_seq_len,
                               n_words_source=config.num_enc_symbols, n_words_target=config.num_dec_symbols,
                               shuffle_each_epoch=config.shuffle, sort_by_length=config.sort_by_len,
                               maxibatch_size=config.maxi_batches)
    valid_set = None
    if config.src_valid and config.tgt_valid:
        print 'Loading validation data..'
        # Validation data is never length-filtered (maxlen=None) or shuffled.
        valid_set = BiTextIterator(source=config.src_valid, target=config.tgt_valid,
                                   source_dict=config.src_vocab, target_dict=config.tgt_vocab,
                                   batch_size=config.batch_size, maxlen=None,
                                   n_words_source=config.num_enc_symbols, n_words_target=config.num_dec_symbols,
                                   shuffle_each_epoch=False, sort_by_length=config.sort_by_len,
                                   maxibatch_size=config.maxi_batches)
    # Create a Quasi-RNN model
    model, model_state = create_model(config)

    # Loss and Optimizer: padding positions are excluded from the loss.
    criterion = nn.CrossEntropyLoss(ignore_index=data_utils.pad_token)
    optimizer = torch.optim.Adam(model.parameters(), lr=config.lr)

    # `loss` accumulates the mean step loss over one display window.
    loss = 0.0
    words_seen, sents_seen = 0, 0
    start_time = time.time()
    # Training loop
    print 'Training..'
    for epoch_idx in xrange(config.max_epochs):
        # A checkpoint restored by create_model() may already be complete.
        if model_state['epoch'] >= config.max_epochs:
            print 'Training is already complete.', \
                  'current epoch:{}, max epoch:{}'.format(model_state['epoch'], config.max_epochs)
            break
        for source_seq, target_seq in train_set:
            # Get a batch from training parallel data
            enc_input, enc_len, dec_input, dec_target, dec_len = \
                prepare_train_batch(source_seq, target_seq, config.max_seq_len)

            # prepare_train_batch returns None fields when the batch yields
            # no usable samples under max_seq_len; skip such batches.
            if enc_input is None or dec_input is None or dec_target is None:
                print 'No samples under max_seq_length ', config.max_seq_len
                continue

            # Wrap tensors for autograd (legacy pre-0.4 PyTorch API).
            if use_cuda:
                enc_input = Variable(enc_input.cuda())
                enc_len = Variable(enc_len.cuda())
                dec_input = Variable(dec_input.cuda())
                dec_target = Variable(dec_target.cuda())
                dec_len = Variable(dec_len.cuda())
            else:
                enc_input = Variable(enc_input)
                enc_len = Variable(enc_len)
                dec_input = Variable(dec_input)
                dec_target = Variable(dec_target)
                dec_len = Variable(dec_len)

            # Execute a single training step
            optimizer.zero_grad()
            dec_logits = model(enc_input, enc_len, dec_input)
            # Targets are flattened to line up with the 2-D logits.
            step_loss = criterion(dec_logits, dec_target.view(-1))
            step_loss.backward()
            # Clip gradients in place before the optimizer update.
            nn.utils.clip_grad_norm(model.parameters(), config.max_grad_norm)
            optimizer.step()

            loss += float(step_loss.data[0]) / config.display_freq
            words_seen += torch.sum(enc_len + dec_len).data[0]
            sents_seen += enc_input.size(0)  # batch_size

            model_state['train_steps'] += 1

            # Display training status
            if model_state['train_steps'] % config.display_freq == 0:
                # `loss` already holds the window average, so exp() of it is
                # the perplexity; guard against math range overflow.
                avg_perplexity = math.exp(float(loss)) if loss < 300 else float("inf")
                time_elapsed = time.time() - start_time
                step_time = time_elapsed / config.display_freq

                words_per_sec = words_seen / time_elapsed
                sents_per_sec = sents_seen / time_elapsed

                print 'Epoch ', model_state['epoch'], 'Step ', model_state['train_steps'], \
                      'Perplexity {0:.2f}'.format(avg_perplexity), 'Step-time {0:.2f}'.format(step_time), \
                      '{0:.2f} sents/s'.format(sents_per_sec), '{0:.2f} words/s'.format(words_per_sec)

                # Reset the display-window accumulators.
                loss = 0.0
                words_seen, sents_seen = 0, 0
                start_time = time.time()

            # Execute a validation process
            if valid_set and model_state['train_steps'] % config.valid_freq == 0:
                model.eval()  # switch to inference mode for validation
                print 'Validation step'
                valid_steps = 0
                valid_loss = 0.0
                valid_sents_seen = 0
                for source_seq, target_seq in valid_set:
                    # Get a batch from validation parallel data
                    enc_input, enc_len, dec_input, dec_target, _ = \
                        prepare_train_batch(source_seq, target_seq)

                    if use_cuda:
                        enc_input = Variable(enc_input.cuda())
                        enc_len = Variable(enc_len.cuda())
                        dec_input = Variable(dec_input.cuda())
                        dec_target = Variable(dec_target.cuda())
                    else:
                        enc_input = Variable(enc_input)
                        enc_len = Variable(enc_len)
                        dec_input = Variable(dec_input)
                        dec_target = Variable(dec_target)

                    dec_logits = model(enc_input, enc_len, dec_input)
                    step_loss = criterion(dec_logits, dec_target.view(-1))
                    valid_steps += 1
                    valid_loss += float(step_loss.data[0])
                    valid_sents_seen += enc_input.size(0)
                    print '  {} samples seen'.format(valid_sents_seen)

                model.train()  # restore training mode
                print 'Valid perplexity: {0:.2f}'.format(math.exp(valid_loss / valid_steps))

            # Save the model checkpoint
            if model_state['train_steps'] % config.save_freq == 0:
                print 'Saving the model..'

                model_state['state_dict'] = model.state_dict()
#                state = dict(list(model_state.items()))
                model_path = os.path.join(config.model_dir, config.model_name)
                torch.save(model_state, model_path)

        # Increase the epoch index of the model
        model_state['epoch'] += 1
        print 'Epoch {0:} DONE'.format(model_state['epoch'])
        # NOTE(review): unlike the save_freq checkpoints above, no final
        # checkpoint is written after the epoch loop — confirm intended.
# ---- Example #3 (예제 #3) — scraped-snippet separator; score: 0 ----
def train():
    """Train the seq2seq translation model configured through FLAGS.

    Loads the parallel training corpus (and, if configured, a validation
    corpus) with BiTextIterator, creates or restores the model inside a
    TF session, and runs the epoch loop with periodic progress display,
    validation, and checkpointing.  A final checkpoint is written when
    training ends.
    """
    # Load parallel data to train
    print 'Loading training data..'
    train_set = BiTextIterator(source=FLAGS.source_train_data,
                               target=FLAGS.target_train_data,
                               source_dict=FLAGS.source_vocabulary,
                               target_dict=FLAGS.target_vocabulary,
                               batch_size=FLAGS.batch_size,
                               maxlen=FLAGS.max_seq_length,
                               n_words_source=FLAGS.num_encoder_symbols,
                               n_words_target=FLAGS.num_decoder_symbols,
                               shuffle_each_epoch=FLAGS.shuffle_each_epoch,
                               sort_by_length=FLAGS.sort_by_length,
                               maxibatch_size=FLAGS.max_load_batches)

    if FLAGS.source_valid_data and FLAGS.target_valid_data:
        print 'Loading validation data..'
        # Validation data is never length-filtered (maxlen=None).
        valid_set = BiTextIterator(source=FLAGS.source_valid_data,
                                   target=FLAGS.target_valid_data,
                                   source_dict=FLAGS.source_vocabulary,
                                   target_dict=FLAGS.target_vocabulary,
                                   batch_size=FLAGS.batch_size,
                                   maxlen=None,
                                   n_words_source=FLAGS.num_encoder_symbols,
                                   n_words_target=FLAGS.num_decoder_symbols)
    else:
        valid_set = None

    # Initiate TF session
    with tf.Session(config=tf.ConfigProto(
            allow_soft_placement=FLAGS.allow_soft_placement,
            log_device_placement=FLAGS.log_device_placement,
            gpu_options=tf.GPUOptions(allow_growth=True))) as sess:

        # Create a new model or reload existing checkpoint
        model = create_model(sess, FLAGS)

        step_time, loss = 0.0, 0.0
        words_seen, sents_seen = 0, 0
        start_time = time.time()

        # Training loop
        print 'Training..'
        for epoch_idx in xrange(FLAGS.max_epochs):
            # A restored checkpoint may already have completed training.
            if model.global_epoch_step.eval() >= FLAGS.max_epochs:
                print 'Training is already complete.', \
                      'current epoch:{}, max epoch:{}'.format(model.global_epoch_step.eval(), FLAGS.max_epochs)
                break

            for source_seq, target_seq in train_set:
                # Get a batch from training parallel data
                source, source_len, target, target_len = prepare_train_batch(
                    source_seq, target_seq, FLAGS.max_seq_length)
                # prepare_train_batch returns None fields when the batch
                # yields no usable samples under max_seq_length.
                if source is None or target is None:
                    print 'No samples under max_seq_length ', FLAGS.max_seq_length
                    continue

                # Execute a single training step
                step_loss = model.train(sess,
                                        encoder_inputs=source,
                                        encoder_inputs_length=source_len,
                                        decoder_inputs=target,
                                        decoder_inputs_length=target_len)

                # `loss` accumulates the window average for the display step.
                loss += float(step_loss) / FLAGS.display_freq
                words_seen += float(np.sum(source_len + target_len))
                sents_seen += float(source.shape[0])  # batch_size

                if model.global_step.eval() % FLAGS.display_freq == 0:

                    # exp() of the window-average loss is the perplexity;
                    # guard against math range overflow for large losses.
                    avg_perplexity = math.exp(
                        float(loss)) if loss < 300 else float("inf")

                    time_elapsed = time.time() - start_time
                    step_time = time_elapsed / FLAGS.display_freq

                    words_per_sec = words_seen / time_elapsed
                    sents_per_sec = sents_seen / time_elapsed

                    print 'Epoch ', model.global_epoch_step.eval(), 'Step ', model.global_step.eval(), \
                          'Perplexity {0:.2f}'.format(avg_perplexity), 'Step-time ', step_time, \
                          '{0:.2f} sents/s'.format(sents_per_sec), '{0:.2f} words/s'.format(words_per_sec)

                    # Reset the display-window accumulators.
                    loss = 0
                    words_seen = 0
                    sents_seen = 0
                    start_time = time.time()

                # Execute a validation step
                if valid_set and model.global_step.eval(
                ) % FLAGS.valid_freq == 0:
                    print 'Validation step'
                    valid_loss = 0.0
                    valid_sents_seen = 0
                    for source_seq, target_seq in valid_set:
                        # Get a batch from validation parallel data
                        source, source_len, target, target_len = prepare_train_batch(
                            source_seq, target_seq)

                        # Compute validation loss: average per word cross entropy loss
                        step_loss = model.eval(
                            sess,
                            encoder_inputs=source,
                            encoder_inputs_length=source_len,
                            decoder_inputs=target,
                            decoder_inputs_length=target_len)
                        batch_size = source.shape[0]

                        # Weight each batch's mean loss by its size so the
                        # final average is per-sentence across the whole set.
                        valid_loss += step_loss * batch_size
                        valid_sents_seen += batch_size
                        print '  {} samples seen'.format(valid_sents_seen)

                    valid_loss = valid_loss / valid_sents_seen
                    print 'Valid perplexity: {0:.2f}'.format(
                        math.exp(valid_loss))

                # Save the model checkpoint
                if model.global_step.eval() % FLAGS.save_freq == 0:
                    print 'Saving the model..'
                    checkpoint_path = os.path.join(FLAGS.model_dir,
                                                   FLAGS.model_name)
                    model.save(sess,
                               checkpoint_path,
                               global_step=model.global_step)
                    # Persist the model config next to the checkpoint.
                    json.dump(model.config,
                              open(
                                  '%s-%d.json' %
                                  (checkpoint_path, model.global_step.eval()),
                                  'wb'),
                              indent=2)

            # Increase the epoch index of the model
            model.global_epoch_step_op.eval()
            print 'Epoch {0:} DONE'.format(model.global_epoch_step.eval())

        # Always write a final checkpoint + config when the loop exits.
        print 'Saving the last model..'
        checkpoint_path = os.path.join(FLAGS.model_dir, FLAGS.model_name)
        model.save(sess, checkpoint_path, global_step=model.global_step)
        json.dump(
            model.config,
            open('%s-%d.json' % (checkpoint_path, model.global_step.eval()),
                 'wb'),
            indent=2)

    print 'Training Terminated'
# ---- Example #4 (예제 #4) — scraped-snippet separator; score: 0 ----
def decode():
    # Load model config
    _data_word = None
    _alignment = None
    _predict = np.array([])
    _ground_truth = np.array([])
    config = load_config(FLAGS)

    # Load source data to decode
    #test_set = TextIterator(source=config['decode_input'],
    #batch_size=config['decode_batch_size'],
    #source_dict=config['source_vocabulary'],
    #maxlen=None,
    #n_words_source=config['num_encoder_symbols'])

    test_set = SegTextIterator(source=config['decode_input'],
                               time_step=config['time_step'],
                               batch_size=config['batch_size'],
                               set_type=1,
                               test_set_partition=config["test_set_partition"],
                               balance=config["data_balance"],
                               noise=config["noise"])

    # Load inverse dictionary used in decoding
    #target_inverse_dict = data_utils.load_inverse_dict(config['target_vocabulary'])

    # Initiate TF session
    with tf.Session(config=tf.ConfigProto(
            allow_soft_placement=FLAGS.allow_soft_placement,
            log_device_placement=FLAGS.log_device_placement,
            gpu_options=tf.GPUOptions(allow_growth=True))) as sess:

        # Reload existing checkpoint
        model = load_model(sess, config)
        try:
            print 'Decoding {}..'.format(FLAGS.decode_input)
            #if FLAGS.write_n_best:
            #    fout = [data_utils.fopen(("%s_%d" % (FLAGS.decode_output, k)), 'w') \
            #            for k in range(FLAGS.beam_width)]
            #else:
            #    fout = [data_utils.fopen(FLAGS.decode_output, 'w')]
            step = 0
            set_length = len(test_set)
            for source_data, source_word, source_char, source_word_mask, source_char_mask, target_seq, true_batch_size in test_set:
                source_data, source_len, target, target_len = prepare_train_batch(
                    source_data, target_seq)
                source = source_data, source_word, source_word_mask, source_char, source_char_mask

                #source, source_len = prepare_batch(source_seq)
                # predicted_ids: GreedyDecoder; [batch_size, max_time_step, 1]
                # BeamSearchDecoder; [batch_size, max_time_step, beam_width]
                predicted_ids, alignment, _ = model.predict(
                    sess,
                    encoder_inputs=source,
                    encoder_inputs_length=source_len)
                a = _[10:1010, :]
                for i in a:
                    t = []
                    for j in i:
                        t.append(str(float(j)))
                    print '\t'.join(t)
                exit()
                if config['use_attention']:
                    try:
                        _data_word = np.concatenate((_data_word, source_word))
                    except:
                        _data_word = source_word
                    try:
                        _alignment = np.concatenate((_alignment, alignment))
                    except:
                        _alignment = alignment
                _predict = np.concatenate((_predict, predicted_ids[:, 0, 0]))
                _ground_truth = np.concatenate((_ground_truth, target[:, 0]))

                progress = 100 * float(step) / set_length
                print "progress %.1f%%" % progress
                step += 1

                # Write decoding results
                #for k, f in reversed(list(enumerate(fout))):
                #    for seq in predicted_ids:
                #        f.write(str(data_utils.seq2words(seq[:,k], target_inverse_dict)) + '\n')
                #    if not FLAGS.write_n_best:
                #        break
                #print '  {}th line decoded'.format(idx * FLAGS.decode_batch_size)

            if config['use_attention']:
                idxs = [j for j in range(len(_alignment))]
                for idx in idxs:
                    print ">>>>"
                    print_group_idx(
                        [_alignment[idx],
                         word_seq_list(_data_word[idx])])
                    #print_idx(word_seq_list(_data_word[idx]))
                    #print_idx(_alignment[idx])
                    print _predict[idx], _ground_truth[idx]

            evaluate(_predict, _ground_truth)
            print 'Decoding terminated'
        except IOError:
            pass