Example #1
def train():
    """Train a en->hn transliteration model using REV_brandnames data."""
    # Prepare news_2012 data.
    print("Preparing News_2012 data in %s" % FLAGS.data_dir)
    en_train, hn_train, en_dev, hn_dev, _, _ = data_utils.prepare_rev_data(
        FLAGS.data_dir, FLAGS.en_vocab_size, FLAGS.hn_vocab_size)
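    # prepare_rev_data presumably returns paths to the tokenized train/dev
    # sets plus two vocabulary paths (inferred from the read_data calls below).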

    with tf.Session() as sess:
        # Create model.
        print("Creating %d layers of %d units." %
              (FLAGS.num_layers, FLAGS.size))
        model = create_model(sess, False)  # forward_only = False

        # Read data into buckets and compute their sizes.
        print("Reading development and training data (limit: %d)." %
              FLAGS.max_train_data_size)
        dev_set = read_data(en_dev, hn_dev)
        train_set = read_data(en_train, hn_train, FLAGS.max_train_data_size)
        train_bucket_sizes = [len(train_set[b]) for b in xrange(len(_buckets))]
        train_total_size = float(sum(train_bucket_sizes))  # total number of training pairs

        # A bucket scale is a list of increasing numbers from 0 to 1 that we'll use
        # to select a bucket. Length of [scale[i], scale[i+1]] is proportional to
        # the size of the i-th training bucket, as used later.
        train_buckets_scale = [
            sum(train_bucket_sizes[:i + 1]) / train_total_size
            for i in xrange(len(train_bucket_sizes))
        ]
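        # For example (hypothetical sizes): buckets of 100, 300, and 600 pairs
        # give train_buckets_scale = [0.1, 0.4, 1.0], so a uniform draw in
        # [0, 1) selects bucket i with probability proportional to its size.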

        # This is the training loop.
        step_time, loss = 0.0, 0.0
        current_step = 0
        previous_losses = []
        validation_loss = []
        validation_acc = []
        file_logs = open(FLAGS.train_dir + "/log_loss.txt", "w")
        acc_logs = open(FLAGS.train_dir + "/accuracy.txt", "w")
        while current_step <= FLAGS.max_steps:
            # Choose a bucket according to data distribution. We pick a random number
            # in [0, 1) and use the corresponding interval in train_buckets_scale.
            random_number_01 = np.random.random_sample()
            bucket_id = min([
                i for i in xrange(len(train_buckets_scale))
                if train_buckets_scale[i] > random_number_01
            ])

            # Get a batch and make a step.
            start_time = time.time()
            encoder_inputs, decoder_inputs, target_weights, _ = model.get_batch(
                train_set, bucket_id)
            _, step_loss, _, _ = model.step(sess, encoder_inputs,
                                            decoder_inputs, target_weights,
                                            bucket_id, False)
            step_time += (time.time() -
                          start_time) / FLAGS.steps_per_checkpoint
            loss += step_loss / FLAGS.steps_per_checkpoint
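            # Dividing by steps_per_checkpoint makes step_time and loss running
            # averages over the current checkpoint interval.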
            current_step += 1

            # Once in a while, we save checkpoint, print statistics, and run evals.
            if current_step % FLAGS.steps_per_checkpoint == 0:
                # Print statistics for the previous checkpoint interval.
                perplexity = math.exp(loss) if loss < 300 else float('inf')
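                # e.g., an average loss of 2.3 nats gives exp(2.3) ~= 10;
                # the loss < 300 guard avoids overflow in math.exp.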
                print(
                    "global step %d learning rate %.4f step-time %.2f perplexity "
                    "%.2f" %
                    (model.global_step.eval(), model.learning_rate.eval(),
                     step_time, perplexity))

                # Decrease learning rate if no improvement was seen over last 3 times.
                # (Disabled here; Example #2 runs this decay op.)
                # if len(previous_losses) > 2 and loss > max(previous_losses[-3:]):
                #     sess.run(model.learning_rate_decay_op)
                previous_losses.append(loss)

                # Save checkpoint and zero timer and loss.
                checkpoint_path = os.path.join(FLAGS.train_dir,
                                               "transliterate.ckpt")
                model.saver.save(sess,
                                 checkpoint_path,
                                 global_step=model.global_step)

                # Run evals on development set and print their perplexity.
                valid_loss = []
                correct_prediction = 0
                itera = 0
                for bucket_id in xrange(len(_buckets)):
                    if len(dev_set[bucket_id]) == 0:
                        print("  eval: empty bucket %d" % (bucket_id))
                        continue
                    encoder_inputs, decoder_inputs, target_weights, dec_input = model.get_batch(
                        dev_set, bucket_id)

                    _, eval_loss, output_logits, _ = model.step(
                        sess, encoder_inputs, decoder_inputs, target_weights,
                        bucket_id, True)

                    eval_ppx = math.exp(
                        eval_loss) if eval_loss < 300 else float('inf')
                    print("  eval: bucket %d perplexity %.2f" %
                          (bucket_id, eval_ppx))
                    valid_loss.append(eval_loss)

                    output_logits = np.asarray(output_logits)
                    # output_logits has shape (decoder_length, batch_size,
                    # vocab_size); decode each batch element greedily by taking
                    # the argmax at every time step.
                    for j in range(np.shape(output_logits)[1]):
                        output_ = [
                            int(np.argmax(output_logits[y][j][:]))
                            for y in range(np.shape(output_logits)[0])
                        ]
                        # If there is an EOS symbol in outputs, cut them at that point.
                        if data_utils.EOS_ID in output_:
                            output_ = output_[:output_.index(data_utils.EOS_ID) + 1]

                        # Compare the greedy decode against the reference output.
                        itera = itera + 1
                        if output_ == dec_input[j]:
                            correct_prediction = correct_prediction + 1

                # float() guards against integer division under Python 2
                # (correct_prediction and itera are both ints).
                acc = correct_prediction / float(itera)
                print("  eval: accuracy %.4f (%d samples)" % (acc, itera))
                # Early stopping on validation loss or accuracy (disabled):
                # if len(validation_loss) > (FLAGS.pat - 1) and valid_loss_current_epoch > max(validation_loss[-1 * FLAGS.pat:]):
                #     current_step = FLAGS.max_steps
                # if len(validation_acc) > (FLAGS.pat - 1) and acc < min(validation_acc[-1 * FLAGS.pat:]):
                #     current_step = FLAGS.max_steps
                valid_loss_current_epoch = np.mean(np.asarray(valid_loss))
                validation_loss.append(valid_loss_current_epoch)
                validation_acc.append(acc)
                file_logs.write(
                    "Step %i, train_Loss: % .5f, valid_Loss: % .5f \n" %
                    (current_step, loss, valid_loss_current_epoch))
                acc_logs.write("Step %i, accuracy: % .5f\n" %
                               (current_step, acc))
                print("validation loss % .3f" % (valid_loss_current_epoch))
                print("Train loss % .3f" % (step_loss)
                      )  #changed it to step_loss instead of running avg loss

                # Reset the step timer and running-average loss.
                step_time, loss = 0.0, 0.0
                sys.stdout.flush()
        file_logs.close()
        acc_logs.close()
        np.savetxt(os.path.join(FLAGS.train_dir, 'log_loss_train.csv'),
                   np.asarray(previous_losses),
                   delimiter=',')
        np.savetxt(os.path.join(FLAGS.train_dir, 'log_loss_valid.csv'),
                   np.asarray(validation_loss),
                   delimiter=',')
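
The bucket selection above is weighted sampling via a cumulative distribution.
A minimal, self-contained sketch of the same trick (the bucket sizes are made
up for illustration):

import numpy as np

bucket_sizes = [100, 300, 600]              # hypothetical pairs per bucket
total = float(sum(bucket_sizes))
scale = [sum(bucket_sizes[:i + 1]) / total  # cumulative fractions
         for i in range(len(bucket_sizes))]  # -> [0.1, 0.4, 1.0]

draw = np.random.random_sample()            # uniform in [0, 1)
bucket_id = min(i for i in range(len(scale)) if scale[i] > draw)
# Over many draws, bucket i is picked with probability bucket_sizes[i] / total.
print(bucket_id)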
Example #2
def train():
    """Train a en->hn transliteration model using REV_brandnames data."""
    # Prepare REV_brandnames data.
    print("Preparing REV data in %s" % FLAGS.data_dir)
    en_train, hn_train, en_dev, hn_dev, _, _ = data_utils.prepare_rev_data(
        FLAGS.data_dir, FLAGS.en_vocab_size, FLAGS.hn_vocab_size)

    with tf.Session() as sess:
        # Create model.
        print("Creating %d layers of %d units." %
              (FLAGS.num_layers, FLAGS.size))
        model = create_model(sess, False)

        # Read data into buckets and compute their sizes.
        print("Reading development and training data (limit: %d)." %
              FLAGS.max_train_data_size)
        dev_set = read_data(en_dev, hn_dev)
        train_set = read_data(en_train, hn_train, FLAGS.max_train_data_size)
        train_bucket_sizes = [len(train_set[b]) for b in xrange(len(_buckets))]
        train_total_size = float(sum(train_bucket_sizes))

        # A bucket scale is a list of increasing numbers from 0 to 1 that we'll use
        # to select a bucket. Length of [scale[i], scale[i+1]] is proportional to
        # the size of the i-th training bucket, as used later.
        train_buckets_scale = [
            sum(train_bucket_sizes[:i + 1]) / train_total_size
            for i in xrange(len(train_bucket_sizes))
        ]

        # This is the training loop.
        step_time, loss = 0.0, 0.0
        current_step = 0
        previous_losses = []
        while True:
            # Choose a bucket according to data distribution. We pick a random number
            # in [0, 1) and use the corresponding interval in train_buckets_scale.
            random_number_01 = np.random.random_sample()
            bucket_id = min([
                i for i in xrange(len(train_buckets_scale))
                if train_buckets_scale[i] > random_number_01
            ])

            # Get a batch and make a step.
            start_time = time.time()
            encoder_inputs, decoder_inputs, target_weights = model.get_batch(
                train_set, bucket_id)
            _, step_loss, _ = model.step(sess, encoder_inputs, decoder_inputs,
                                         target_weights, bucket_id, False)
            step_time += (time.time() -
                          start_time) / FLAGS.steps_per_checkpoint
            loss += step_loss / FLAGS.steps_per_checkpoint
            current_step += 1

            # Once in a while, we save checkpoint, print statistics, and run evals.
            if current_step % FLAGS.steps_per_checkpoint == 0:
                # Print statistics for the previous checkpoint interval.
                perplexity = math.exp(loss) if loss < 300 else float('inf')
                print(
                    "global step %d learning rate %.4f step-time %.2f perplexity "
                    "%.2f" %
                    (model.global_step.eval(), model.learning_rate.eval(),
                     step_time, perplexity))
                # Decrease learning rate if no improvement was seen over last 3 times.
                if len(previous_losses) > 2 and loss > max(
                        previous_losses[-3:]):
                    sess.run(model.learning_rate_decay_op)
                previous_losses.append(loss)
                # Save checkpoint and zero timer and loss.
                checkpoint_path = os.path.join(FLAGS.train_dir,
                                               "transliterate.ckpt")
                model.saver.save(sess,
                                 checkpoint_path,
                                 global_step=model.global_step)
                step_time, loss = 0.0, 0.0
                # Run evals on development set and print their perplexity.
                for bucket_id in xrange(len(_buckets)):
                    if len(dev_set[bucket_id]) == 0:
                        print("  eval: empty bucket %d" % (bucket_id))
                        continue
                    encoder_inputs, decoder_inputs, target_weights = model.get_batch(
                        dev_set, bucket_id)
                    _, eval_loss, _ = model.step(sess, encoder_inputs,
                                                 decoder_inputs,
                                                 target_weights, bucket_id,
                                                 True)
                    eval_ppx = math.exp(
                        eval_loss) if eval_loss < 300 else float('inf')
                    print("  eval: bucket %d perplexity %.2f" %
                          (bucket_id, eval_ppx))
                sys.stdout.flush()
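
The learning-rate schedule in Example #2 is a simple plateau rule: decay
whenever the current checkpoint loss is worse than all of the previous three.
A minimal sketch of that rule in isolation (the 0.5 decay factor is an
assumption; the real factor is configured inside the model):

def maybe_decay(learning_rate, previous_losses, loss, factor=0.5):
    # Decay when loss has not improved on any of the last three checkpoints.
    if len(previous_losses) > 2 and loss > max(previous_losses[-3:]):
        learning_rate *= factor
    previous_losses.append(loss)
    return learning_rate

lr, history = 0.5, []
for loss in [4.0, 3.5, 3.2, 3.3, 3.4, 3.6]:  # made-up checkpoint losses
    lr = maybe_decay(lr, history, loss)
print(lr)  # 0.25: decayed once the loss plateaued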