Example #1
# Imports assumed by this snippet. Seq2SeqModel, load_or_create_model,
# get_word_embedding, gen_batch_train_data and FLAGS come from the
# surrounding project and are not shown here.
import json
import math
import os
import time
from collections import OrderedDict

import tensorflow as tf


def train():
    config_proto = tf.ConfigProto(
        allow_soft_placement=FLAGS.allow_soft_placement,
        log_device_placement=FLAGS.log_device_placement,
        gpu_options=tf.GPUOptions(allow_growth=True))

    with tf.Session(config=config_proto) as sess:
        # Build the model
        config = OrderedDict(sorted(FLAGS.__flags.items()))
        model = Seq2SeqModel(config, 'train')

        # Create a log writer object
        log_writer = tf.summary.FileWriter(FLAGS.model_dir, graph=sess.graph)

        # Create a saver
        # With var_list=None the saver handles all saveable variables
        saver = tf.train.Saver(var_list=None)

        # Initialize global variables or reload an existing checkpoint
        load_or_create_model(sess, model, saver, FLAGS)

        # Load word2vec embedding
        embedding = get_word_embedding(FLAGS.hidden_units,
                                       alignment=FLAGS.align_word2vec)
        print(embedding.shape)
        model.init_vars(sess, embedding=embedding)

        step_time, loss = 0.0, 0.0
        sents_seen = 0

        start_time = time.time()

        print 'Training...'
        for epoch_idx in xrange(FLAGS.max_epochs):
            if model.global_epoch_step.eval() >= FLAGS.max_epochs:
                print 'Training is already complete.', \
                      'Current epoch: {}, Max epoch: {}'.format(model.global_epoch_step.eval(), FLAGS.max_epochs)
                break

            # Prepare batch training data
            # TODO(sdsuo): Make corresponding changes in data_utils
            for source, source_len, target, target_len in gen_batch_train_data(
                    FLAGS.batch_size,
                    prev=FLAGS.prev_data,
                    rev=FLAGS.rev_data,
                    align=FLAGS.align_data,
                    cangtou=FLAGS.cangtou_data):
                step_loss, summary = model.train(
                    sess,
                    encoder_inputs=source,
                    encoder_inputs_length=source_len,
                    decoder_inputs=target,
                    decoder_inputs_length=target_len)

                loss += float(step_loss) / FLAGS.display_freq
                sents_seen += float(source.shape[0])  # batch_size

                # Display information
                if model.global_step.eval() % FLAGS.display_freq == 0:
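                    # loss has accumulated step_loss / display_freq over the last
                    # display_freq steps, so exp(loss) is the perplexity for that
                    # window; the < 300 guard keeps math.exp from overflowing.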

                    avg_perplexity = math.exp(
                        float(loss)) if loss < 300 else float("inf")

                    time_elapsed = time.time() - start_time
                    step_time = time_elapsed / FLAGS.display_freq

                    sents_per_sec = sents_seen / time_elapsed

                    print 'Epoch ', model.global_epoch_step.eval(), 'Step ', model.global_step.eval(), \
                          'Perplexity {0:.2f}'.format(avg_perplexity), 'Step-time ', step_time, \
                          '{0:.2f} sents/s'.format(sents_per_sec)

                    loss = 0
                    sents_seen = 0
                    start_time = time.time()

                    # Record training summary for the current batch
                    log_writer.add_summary(summary, model.global_step.eval())

                # Save the model checkpoint
                if model.global_step.eval() % FLAGS.save_freq == 0:
                    print 'Saving the model..'
                    checkpoint_path = os.path.join(FLAGS.model_dir,
                                                   FLAGS.model_name)
                    model.save(sess,
                               saver,
                               checkpoint_path,
                               global_step=model.global_step)
                    json.dump(model.config,
                              open(
                                  '%s-%d.json' %
                                  (checkpoint_path, model.global_step.eval()),
                                  'wb'),
                              indent=2)

            # Increase the epoch index of the model
            model.increment_global_epoch_step_op.eval()
            print 'Epoch {0:} DONE'.format(model.global_epoch_step.eval())

        print 'Saving the last model'
        checkpoint_path = os.path.join(FLAGS.model_dir, FLAGS.model_name)
        model.save(sess, saver, checkpoint_path, global_step=model.global_step)
        json.dump(
            model.config,
            open('%s-%d.json' % (checkpoint_path, model.global_step.eval()),
                 'wb'),
            indent=2)

    print 'Training terminated'
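
The function reads every hyperparameter from a module-level FLAGS object that is not part of the snippet. Below is a minimal sketch of how such flags could be declared with TensorFlow 1.x's tf.app.flags: the flag names are taken from the calls above, but the defaults and help strings are illustrative only (they mirror the config dict shown in Example #2 below).

import tensorflow as tf

# Hypothetical flag declarations matching the names used in train() above.
# Defaults and help strings are illustrative, not the project's actual settings.
tf.app.flags.DEFINE_boolean('allow_soft_placement', True, 'Allow soft device placement')
tf.app.flags.DEFINE_boolean('log_device_placement', False, 'Log device placement of ops')
tf.app.flags.DEFINE_integer('hidden_units', 128, 'Hidden units per RNN cell')
tf.app.flags.DEFINE_boolean('align_word2vec', True, 'Align the word2vec embedding to the vocabulary')
tf.app.flags.DEFINE_boolean('prev_data', True, 'Data option passed to gen_batch_train_data')
tf.app.flags.DEFINE_boolean('rev_data', True, 'Data option passed to gen_batch_train_data')
tf.app.flags.DEFINE_boolean('align_data', True, 'Data option passed to gen_batch_train_data')
tf.app.flags.DEFINE_boolean('cangtou_data', False, 'Data option passed to gen_batch_train_data')
tf.app.flags.DEFINE_integer('batch_size', 64, 'Sentences per training batch')
tf.app.flags.DEFINE_integer('max_epochs', 10000, 'Maximum number of training epochs')
tf.app.flags.DEFINE_integer('display_freq', 100, 'Print training statistics every N steps')
tf.app.flags.DEFINE_integer('save_freq', 100, 'Save a checkpoint every N steps')
tf.app.flags.DEFINE_string('model_dir', 'model', 'Directory for checkpoints and summaries')
tf.app.flags.DEFINE_string('model_name', 'translate.ckpt', 'Checkpoint file name')

FLAGS = tf.app.flags.FLAGS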

Example #2
def train():
    config_proto = tf.ConfigProto(
        allow_soft_placement=FLAGS.allow_soft_placement,
        log_device_placement=FLAGS.log_device_placement,
        gpu_options=tf.GPUOptions(allow_growth=True)
        # device_count = {'GPU': 0}
    )

    with tf.Session(config=config_proto) as sess:
        # Build the model

        config = {
            'cangtou_data': False,
            'rev_data': True,
            'align_data': True,
            'prev_data': True,
            'align_word2vec': True,
            'cell_type': 'lstm',
            'attention_type': 'bahdanau',
            'hidden_units': 128,
            'depth': 4,
            'embedding_size': 128,
            'num_encoder_symbols': 30000,
            'num_decoder_symbols': 30000,
            'vocab_size': 6000,
            'use_residual': True,
            'attn_input_feeding': False,
            'use_dropout': True,
            'dropout_rate': 0.3,
            'learning_rate': 0.0002,
            'max_gradient_norm': 1.0,
            'batch_size': 64,
            'max_epochs': 10000,
            'max_load_batches': 20,
            'max_seq_length': 50,
            'display_freq': 100,
            'save_freq': 100,
            'valid_freq': 1150000,
            'optimizer': 'adam',
            'model_dir': 'model',
            'summary_dir': 'model/summary',
            'model_name': 'translate.ckpt',
            'shuffle_each_epoch': True,
            'sort_by_length': True,
            'use_fp16': False,
            'bidirectional': True,
            'train_mode': 'ground_truth',
            'sampling_probability': 0.1,
            'start_token': 0,
            'end_token': 5999,
            'allow_soft_placement': True,
            'log_device_placement': False
        }
        model = Seq2SeqModel(config, 'train')

        # Create a log writer object
        log_writer = tf.summary.FileWriter(FLAGS.model_dir, graph=sess.graph)

        # Create a saver
        # With var_list=None the saver handles all saveable variables
        saver = tf.train.Saver(var_list=None)

        # Initialize global variables or reload an existing checkpoint
        load_or_create_model(sess, model, saver, FLAGS)

        # Load word2vec embedding
        embedding = get_word_embedding(FLAGS.hidden_units,
                                       alignment=FLAGS.align_word2vec)
        model.init_vars(sess, embedding=embedding)

        step_time, loss = 0.0, 0.0
        sents_seen = 0

        start_time = time.time()

        print('Training...')
        for epoch_idx in range(FLAGS.max_epochs):
            if model.global_epoch_step.eval() >= FLAGS.max_epochs:
                print(
                    'Training is already complete.',
                    'Current epoch: {}, Max epoch: {}'.format(
                        model.global_epoch_step.eval(), FLAGS.max_epochs))
                break

            # Prepare batch training data
            # TODO(sdsuo): Make corresponding changes in data_utils
            for source, source_len, target, target_len in gen_batch_train_data(
                    FLAGS.batch_size,
                    prev=FLAGS.prev_data,
                    rev=FLAGS.rev_data,
                    align=FLAGS.align_data,
                    cangtou=FLAGS.cangtou_data):
                step_loss, summary = model.train(
                    sess,
                    encoder_inputs=source,
                    encoder_inputs_length=source_len,
                    decoder_inputs=target,
                    decoder_inputs_length=target_len)

                loss += float(step_loss) / FLAGS.display_freq
                sents_seen += float(source.shape[0])  # batch_size

                # Display information
                if model.global_step.eval() % FLAGS.display_freq == 0:

                    avg_perplexity = math.exp(
                        float(loss)) if loss < 300 else float("inf")

                    time_elapsed = time.time() - start_time
                    step_time = time_elapsed / FLAGS.display_freq

                    sents_per_sec = sents_seen / time_elapsed

                    print('Epoch ', model.global_epoch_step.eval(), 'Step ',
                          model.global_step.eval(),
                          'Perplexity {0:.2f}'.format(avg_perplexity),
                          'Step-time ', step_time,
                          '{0:.2f} sents/s'.format(sents_per_sec))

                    loss = 0
                    sents_seen = 0
                    start_time = time.time()

                    # Record training summary for the current batch
                    log_writer.add_summary(summary, model.global_step.eval())

                # Save the model checkpoint
                if model.global_step.eval() % FLAGS.save_freq == 0:
                    print('Saving the model..')
                    checkpoint_path = os.path.join(FLAGS.model_dir,
                                                   FLAGS.model_name)
                    model.save(sess,
                               saver,
                               checkpoint_path,
                               global_step=model.global_step)
                    json.dump(model.config,
                              open(
                                  '%s-%d.json' %
                                  (checkpoint_path, model.global_step.eval()),
                                  'w'),
                              indent=2)

            # Increase the epoch index of the model
            model.increment_global_epoch_step_op.eval()
            print('Epoch {0:} DONE'.format(model.global_epoch_step.eval()))

        print('Saving the last model')
        checkpoint_path = os.path.join(FLAGS.model_dir, FLAGS.model_name)
        model.save(sess, saver, checkpoint_path, global_step=model.global_step)
        json.dump(
            model.config,
            open('%s-%d.json' % (checkpoint_path, model.global_step.eval()),
                 'w'),
            indent=2)

    print('Training terminated')
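
load_or_create_model is called in both examples but never shown. A plausible minimal implementation under TensorFlow 1.x, assuming it should restore the newest checkpoint in FLAGS.model_dir when one exists and otherwise initialize all variables; the signature is taken from the call sites above, and the body is purely illustrative rather than the project's actual code.

import tensorflow as tf


def load_or_create_model(sess, model, saver, FLAGS):
    # Illustrative sketch only; the project's real helper may differ.
    # model is accepted to match the call sites but is not needed here.
    ckpt = tf.train.latest_checkpoint(FLAGS.model_dir)
    if ckpt is not None:
        print('Restoring model parameters from {}'.format(ckpt))
        saver.restore(sess, ckpt)
    else:
        print('Creating a new model with fresh parameters')
        sess.run(tf.global_variables_initializer())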