Example #1
def train():
    # Load parallel data to train
    print 'Loading training data..'
    train_set = BiTextIterator(source=FLAGS.source_train_data,
                               target=FLAGS.target_train_data,
                               source_dict=FLAGS.source_vocabulary,
                               target_dict=FLAGS.target_vocabulary,
                               batch_size=FLAGS.batch_size,
                               maxlen=FLAGS.max_seq_length,
                               n_words_source=FLAGS.num_encoder_symbols,
                               n_words_target=FLAGS.num_decoder_symbols,
                               shuffle_each_epoch=FLAGS.shuffle_each_epoch,
                               sort_by_length=FLAGS.sort_by_length,
                               maxibatch_size=FLAGS.max_load_batches)

    if FLAGS.source_valid_data and FLAGS.target_valid_data:
        print 'Loading validation data..'
        valid_set = BiTextIterator(source=FLAGS.source_valid_data,
                                   target=FLAGS.target_valid_data,
                                   source_dict=FLAGS.source_vocabulary,
                                   target_dict=FLAGS.target_vocabulary,
                                   batch_size=FLAGS.batch_size,
                                   maxlen=None,
                                   n_words_source=FLAGS.num_encoder_symbols,
                                   n_words_target=FLAGS.num_decoder_symbols)
    else:
        valid_set = None

    # Initiate TF session
    with tf.Session(config=tf.ConfigProto(
            allow_soft_placement=FLAGS.allow_soft_placement,
            log_device_placement=FLAGS.log_device_placement,
            gpu_options=tf.GPUOptions(allow_growth=True))) as sess:

        # Create a new model or reload existing checkpoint
        model = create_model(sess, FLAGS)

        step_time, loss = 0.0, 0.0
        words_seen, sents_seen = 0, 0
        start_time = time.time()

        # Training loop
        print 'Training..'
        for epoch_idx in xrange(FLAGS.max_epochs):
            if model.global_epoch_step.eval() >= FLAGS.max_epochs:
                print 'Training is already complete.', \
                      'current epoch:{}, max epoch:{}'.format(model.global_epoch_step.eval(), FLAGS.max_epochs)
                break

            for source_seq, target_seq in train_set:
                # Get a batch from training parallel data
                source, source_len, target, target_len = prepare_train_batch(
                    source_seq, target_seq, FLAGS.max_seq_length)
                if source is None or target is None:
                    print 'No samples under max_seq_length ', FLAGS.max_seq_length
                    continue

                # Execute a single training step
                step_loss = model.train(sess,
                                        encoder_inputs=source,
                                        encoder_inputs_length=source_len,
                                        decoder_inputs=target,
                                        decoder_inputs_length=target_len)

                loss += float(step_loss) / FLAGS.display_freq
                words_seen += float(np.sum(source_len + target_len))
                sents_seen += float(source.shape[0])  # batch_size

                if model.global_step.eval() % FLAGS.display_freq == 0:

                    avg_perplexity = math.exp(
                        float(loss)) if loss < 300 else float("inf")

                    time_elapsed = time.time() - start_time
                    step_time = time_elapsed / FLAGS.display_freq

                    words_per_sec = words_seen / time_elapsed
                    sents_per_sec = sents_seen / time_elapsed

                    print 'Epoch ', model.global_epoch_step.eval(), 'Step ', model.global_step.eval(), \
                          'Perplexity {0:.2f}'.format(avg_perplexity), 'Step-time ', step_time, \
                          '{0:.2f} sents/s'.format(sents_per_sec), '{0:.2f} words/s'.format(words_per_sec)

                    loss = 0
                    words_seen = 0
                    sents_seen = 0
                    start_time = time.time()

                # Execute a validation step
                if valid_set and model.global_step.eval() % FLAGS.valid_freq == 0:
                    print 'Validation step'
                    valid_loss = 0.0
                    valid_sents_seen = 0
                    for source_seq, target_seq in valid_set:
                        # Get a batch from validation parallel data
                        source, source_len, target, target_len = prepare_train_batch(
                            source_seq, target_seq)

                        # Compute validation loss: average per word cross entropy loss
                        step_loss = model.eval(
                            sess,
                            encoder_inputs=source,
                            encoder_inputs_length=source_len,
                            decoder_inputs=target,
                            decoder_inputs_length=target_len)
                        batch_size = source.shape[0]

                        valid_loss += step_loss * batch_size
                        valid_sents_seen += batch_size
                        print '  {} samples seen'.format(valid_sents_seen)

                    valid_loss = valid_loss / valid_sents_seen
                    print 'Valid perplexity: {0:.2f}'.format(
                        math.exp(valid_loss))

                # Save the model checkpoint
                if model.global_step.eval() % FLAGS.save_freq == 0:
                    print 'Saving the model..'
                    checkpoint_path = os.path.join(FLAGS.model_dir,
                                                   FLAGS.model_name)
                    model.save(sess,
                               checkpoint_path,
                               global_step=model.global_step)
                    config_path = '%s-%d.json' % (checkpoint_path, model.global_step.eval())
                    with open(config_path, 'wb') as config_file:
                        json.dump(model.config, config_file, indent=2)

            # Increase the epoch index of the model
            model.global_epoch_step_op.eval()
            print 'Epoch {0:} DONE'.format(model.global_epoch_step.eval())

        print 'Saving the last model..'
        checkpoint_path = os.path.join(FLAGS.model_dir, FLAGS.model_name)
        model.save(sess, checkpoint_path, global_step=model.global_step)
        config_path = '%s-%d.json' % (checkpoint_path, model.global_step.eval())
        with open(config_path, 'wb') as config_file:
            json.dump(model.config, config_file, indent=2)

    print 'Training Terminated'
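
Neither example includes the prepare_train_batch helper it calls. Below is a minimal sketch of what such a helper might look like for Example #1: it drops pairs longer than maxlen, pads the rest to the longest sequence in the batch, and returns None placeholders when nothing survives the filter. The padding id 0 and the exact return layout are assumptions, not taken from the original code; Example #2 expects a slightly different signature (an extra decoder target and torch tensors), so this sketch applies to Example #1 only.

import numpy as np

def prepare_train_batch(seqs_x, seqs_y, maxlen=None, pad_id=0):
    # Hypothetical helper: drop pairs whose source or target exceeds maxlen.
    if maxlen is not None:
        pairs = [(x, y) for x, y in zip(seqs_x, seqs_y)
                 if len(x) <= maxlen and len(y) <= maxlen]
        if not pairs:
            return None, None, None, None
        seqs_x = [p[0] for p in pairs]
        seqs_y = [p[1] for p in pairs]

    x_len = np.array([len(s) for s in seqs_x], dtype=np.int32)
    y_len = np.array([len(s) for s in seqs_y], dtype=np.int32)

    # Pad every sequence to the longest one in the batch with pad_id.
    x = np.full((len(seqs_x), int(x_len.max())), pad_id, dtype=np.int32)
    y = np.full((len(seqs_y), int(y_len.max())), pad_id, dtype=np.int32)
    for i, s in enumerate(seqs_x):
        x[i, :len(s)] = s
    for i, s in enumerate(seqs_y):
        y[i, :len(s)] = s
    return x, x_len, y, y_len
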
Example #2
def train(config):
    # Load parallel data to train
    print 'Loading training data..'
    train_set = BiTextIterator(source=config.src_train, target=config.tgt_train,
                               source_dict=config.src_vocab, target_dict=config.tgt_vocab,
                               batch_size=config.batch_size, maxlen=config.max_seq_len,
                               n_words_source=config.num_enc_symbols, n_words_target=config.num_dec_symbols,
                               shuffle_each_epoch=config.shuffle, sort_by_length=config.sort_by_len,
                               maxibatch_size=config.maxi_batches)
    valid_set = None
    if config.src_valid and config.tgt_valid:
        print 'Loading validation data..'
        valid_set = BiTextIterator(source=config.src_valid, target=config.tgt_valid,
                                   source_dict=config.src_vocab, target_dict=config.tgt_vocab,
                                   batch_size=config.batch_size, maxlen=None,
                                   n_words_source=config.num_enc_symbols, n_words_target=config.num_dec_symbols,
                                   shuffle_each_epoch=False, sort_by_length=config.sort_by_len,
                                   maxibatch_size=config.maxi_batches)
    # Create a Quasi-RNN model
    model, model_state = create_model(config)

    # Loss and Optimizer
    criterion = nn.CrossEntropyLoss(ignore_index=data_utils.pad_token)
    optimizer = torch.optim.Adam(model.parameters(), lr=config.lr)

    loss = 0.0
    words_seen, sents_seen = 0, 0
    start_time = time.time()
    # Training loop
    print 'Training..'
    for epoch_idx in xrange(config.max_epochs):
        if model_state['epoch'] >= config.max_epochs:
            print 'Training is already complete.', \
                  'current epoch:{}, max epoch:{}'.format(model_state['epoch'], config.max_epochs)
            break
        for source_seq, target_seq in train_set:
            # Get a batch from training parallel data
            enc_input, enc_len, dec_input, dec_target, dec_len = \
                prepare_train_batch(source_seq, target_seq, config.max_seq_len)
 
            if enc_input is None or dec_input is None or dec_target is None:
                print 'No samples under max_seq_length ', config.max_seq_len
                continue
           
            if use_cuda:
                enc_input = Variable(enc_input.cuda())
                enc_len = Variable(enc_len.cuda())
                dec_input = Variable(dec_input.cuda())
                dec_target = Variable(dec_target.cuda())
                dec_len = Variable(dec_len.cuda())
            else:
                enc_input = Variable(enc_input)
                enc_len = Variable(enc_len)
                dec_input = Variable(dec_input)
                dec_target = Variable(dec_target)
                dec_len = Variable(dec_len)

            # Execute a single training step
            optimizer.zero_grad()
            dec_logits = model(enc_input, enc_len, dec_input)
            step_loss = criterion(dec_logits, dec_target.view(-1))
            step_loss.backward()
            nn.utils.clip_grad_norm(model.parameters(), config.max_grad_norm)
            optimizer.step()

            loss += float(step_loss.data[0]) / config.display_freq
            words_seen += torch.sum(enc_len + dec_len).data[0]
            sents_seen += enc_input.size(0)  # batch_size

            model_state['train_steps'] += 1

            # Display training status
            if model_state['train_steps'] % config.display_freq == 0:
                avg_perplexity = math.exp(float(loss)) if loss < 300 else float("inf")
                time_elapsed = time.time() - start_time
                step_time = time_elapsed / config.display_freq

                words_per_sec = words_seen / time_elapsed
                sents_per_sec = sents_seen / time_elapsed

                print 'Epoch ', model_state['epoch'], 'Step ', model_state['train_steps'], \
                      'Perplexity {0:.2f}'.format(avg_perplexity), 'Step-time {0:.2f}'.format(step_time), \
                      '{0:.2f} sents/s'.format(sents_per_sec), '{0:.2f} words/s'.format(words_per_sec)

                loss = 0.0
                words_seen, sents_seen = 0, 0
                start_time = time.time()

            # Execute a validation process
            if valid_set and model_state['train_steps'] % config.valid_freq == 0:
                model.eval()
                print 'Validation step'
                valid_steps = 0
                valid_loss = 0.0
                valid_sents_seen = 0
                for source_seq, target_seq in valid_set:
                    # Get a batch from validation parallel data
                    enc_input, enc_len, dec_input, dec_target, _ = \
                        prepare_train_batch(source_seq, target_seq)

                    if use_cuda:
                        enc_input = Variable(enc_input.cuda())
                        enc_len = Variable(enc_len.cuda())
                        dec_input = Variable(dec_input.cuda())
                        dec_target = Variable(dec_target.cuda())
                    else:
                        enc_input = Variable(enc_input)
                        enc_len = Variable(enc_len)
                        dec_input = Variable(dec_input)
                        dec_target = Variable(dec_target)

                    dec_logits = model(enc_input, enc_len, dec_input)
                    step_loss = criterion(dec_logits, dec_target.view(-1))
                    valid_steps += 1 
                    valid_loss += float(step_loss.data[0])
                    valid_sents_seen += enc_input.size(0)
                    print '  {} samples seen'.format(valid_sents_seen)

                model.train()
                print 'Valid perplexity: {0:.2f}'.format(math.exp(valid_loss / valid_steps))

            # Save the model checkpoint
            if model_state['train_steps'] % config.save_freq == 0:
                print 'Saving the model..'

                model_state['state_dict'] = model.state_dict()
                model_path = os.path.join(config.model_dir, config.model_name)
                torch.save(model_state, model_path)

        # Increase the epoch index of the model
        model_state['epoch'] += 1
        print 'Epoch {0:} DONE'.format(model_state['epoch'])
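
The checkpoint written by Example #2 is the whole model_state dict (weights under 'state_dict', plus 'epoch' and 'train_steps'), so restoring it is a single torch.load call. A minimal sketch of how that might look, assuming the same model object returned by create_model(config); the restore_checkpoint name is hypothetical:

import os
import torch

def restore_checkpoint(config, model):
    # Hypothetical sketch: the file saved above holds the full model_state dict.
    model_path = os.path.join(config.model_dir, config.model_name)
    model_state = torch.load(model_path, map_location=lambda storage, loc: storage)
    model.load_state_dict(model_state['state_dict'])
    return model_state  # still carries 'epoch' and 'train_steps'
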