Example #1
  def model_fn(features, labels, mode, params):
    # log feature names and shapes
    tf.logging.info('*** Features ***')
    for name in sorted(features.keys()):
      tf.logging.info(' name = {}, shape = {}'.format(name, features[name].shape))

    is_training = (mode == tf.estimator.ModeKeys.TRAIN)

    # get data
    input_x = features['input_x']
    input_mask = features['input_mask']
    if is_training:
      input_y = features['input_y']
      seq_length = features['seq_length']
    else:
      input_y = None
      seq_length = None

    # build encoder
    model = BertEncoder(
      config=cg.BertEncoderConfig,
      is_training=is_training,
      input_ids=input_x,
      input_mask=input_mask)
    embedding_table = model.get_embedding_table()
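    # sum the token-level encoder outputs over the sequence axis to get a fixed-size [batch, hidden] state for the decoder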
    encoder_output = tf.reduce_sum(model.get_sequence_output(), axis=1)

    # build decoder
    decoder_model = Decoder(
      config=cg.DecoderConfig,
      is_training=is_training,
      encoder_state=encoder_output,
      embedding_table=embedding_table,
      decoder_intput_data=input_y,
      seq_length_decoder_input_data=seq_length)
    logits, sample_id, ppl_seq, ppl = decoder_model.get_decoder_output()

    if mode == tf.estimator.ModeKeys.PREDICT:
      predictions = {'sample_id': sample_id, 'ppls': ppl_seq}
      output_spec = tf.estimator.EstimatorSpec(mode, predictions=predictions)
    else:
      if mode == tf.estimator.ModeKeys.TRAIN:
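        # mask out padded positions so they do not contribute to the sequence loss, then normalize by the batch size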
        max_time = ft.get_shape_list(labels, expected_rank=2)[1]
        target_weights = tf.sequence_mask(seq_length, max_time, dtype=logits.dtype)
        batch_size = tf.cast(ft.get_shape_list(labels, expected_rank=2)[0], tf.float32)

        loss = tf.reduce_sum(
          tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels, logits=logits) * target_weights) / batch_size

        learning_rate = tf.train.polynomial_decay(cg.learning_rate,
                                          tf.train.get_or_create_global_step(),
                                          cg.train_steps / 100,
                                          end_learning_rate=1e-4,
                                          power=1.0,
                                          cycle=False)

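        # keep the decayed learning rate from falling below cg.lr_limit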
        lr = tf.maximum(tf.constant(cg.lr_limit), learning_rate)
        optimizer = tf.train.AdamOptimizer(lr, name='optimizer')
        tvars = tf.trainable_variables()
        gradients = tf.gradients(loss, tvars, colocate_gradients_with_ops=cg.colocate_gradients_with_ops)
        clipped_gradients, _ = tf.clip_by_global_norm(gradients, 5.0)
        train_op = optimizer.apply_gradients(zip(clipped_gradients, tvars), global_step=tf.train.get_global_step())


        # unlike the default Estimator logging, which only reports when a checkpoint is saved,
        # this hook prints the loss, perplexity and learning rate every print_info_interval steps
        # (each step processes one batch).
        logging_hook = tf.train.LoggingTensorHook({'loss' : loss, 'ppl': ppl, 'lr': lr}, every_n_iter=cg.print_info_interval)

        output_spec = tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op, training_hooks=[logging_hook])
      elif mode == tf.estimator.ModeKeys.EVAL:
        # TODO
        raise NotImplementedError
    
    return output_spec
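A model_fn like the one above is handed to a tf.estimator.Estimator, which calls it with the appropriate mode during training and prediction. Below is a minimal, self-contained sketch of that wiring in the same TF 1.x style; the toy linear model, synthetic input_fn and hyperparameters are illustrative stand-ins for the project's BertEncoder/Decoder stack and cg config, not the repository's actual training script.

import numpy as np
import tensorflow as tf  # TF 1.x API, as in the examples above

def toy_model_fn(features, labels, mode, params):
  # a single dense layer stands in for the encoder/decoder stack
  logits = tf.layers.dense(features['input_x'], params['num_classes'])

  if mode == tf.estimator.ModeKeys.PREDICT:
    predictions = {'class_id': tf.argmax(logits, axis=-1)}
    return tf.estimator.EstimatorSpec(mode, predictions=predictions)

  loss = tf.reduce_mean(
    tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels, logits=logits))
  train_op = tf.train.AdamOptimizer(params['learning_rate']).minimize(
    loss, global_step=tf.train.get_or_create_global_step())
  return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)

def toy_input_fn():
  # synthetic data in place of a real input pipeline; it must yield the feature dict the model_fn expects
  x = np.random.rand(64, 8).astype(np.float32)
  y = np.random.randint(0, 3, size=(64,)).astype(np.int32)
  dataset = tf.data.Dataset.from_tensor_slices(({'input_x': x}, y))
  return dataset.shuffle(64).repeat().batch(16)

estimator = tf.estimator.Estimator(
  model_fn=toy_model_fn,
  model_dir='/tmp/toy_estimator',
  params={'num_classes': 3, 'learning_rate': 1e-3})
estimator.train(input_fn=toy_input_fn, max_steps=100)

Prediction runs through the same model_fn: estimator.predict(...) invokes it with ModeKeys.PREDICT, so only the predictions branch of the EstimatorSpec is built.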
Example #2
  def model_fn(features, labels, mode, params):
    # log feature names and shapes
    for name in sorted(features.keys()):
      tf.logging.info(' name = {}, shape = {}'.format(name, features[name].shape))

    is_training = (mode == tf.estimator.ModeKeys.TRAIN)

    # get data
    input_data = features['input_data']
    input_mask = features['input_mask']
    if mode == tf.estimator.ModeKeys.TRAIN:
      sentiment_labels = features['sentiment_labels']
      sentiment_mask_indices = features['sentiment_mask_indices']
      true_length_from_data = features['true_length']

    # build model
    model = BertEncoder(
      config=cg.BertEncoderConfig,
      is_training=is_training,
      input_ids=input_data,
      input_mask=input_mask)
    
    tvars = tf.trainable_variables()
    initialized_variable_names = {}
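    # init_checkpoint is not defined inside model_fn; it is presumably captured from an enclosing
    # builder's closure (see the sketch after this example)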
    if init_checkpoint:
      (assignment_map, initialized_variable_names
      ) = get_assignment_map_from_checkpoint(tvars, init_checkpoint)
      tf.train.init_from_checkpoint(init_checkpoint, assignment_map)

    tf.logging.info("**** Trainable Variables ****")
    for var in tvars:
      init_string = ""
      if var.name in initialized_variable_names:
        init_string = ", *INIT_FROM_CKPT*"
      tf.logging.info("  name = %s, shape = %s%s", var.name, var.shape,
                      init_string)

    # [cls] output -> [b, h]
    cls_output = model.get_cls_output()
    # sequence_output -> [b, s, h]; the [CLS] position is dropped because the mask indices are not shifted by one
    sequence_output = model.get_sequence_output()[:, 1:, :]

    # project the hidden size to the num_classes
    with tf.variable_scope('final_output'):
      # [b, num_classes]
      output_logits = tf.layers.dense(
        cls_output,
        cg.BertEncoderConfig.num_classes,
        name='final_output',
        kernel_initializer=ft.create_initializer(initializer_range=cg.BertEncoderConfig.initializer_range))

    if mode == tf.estimator.ModeKeys.PREDICT:
      output_softmax = tf.nn.softmax(output_logits, axis=-1)
      output_result = tf.argmax(output_softmax, axis=-1)
      predictions = {'predict': output_result}
      output_spec = tf.estimator.EstimatorSpec(mode, predictions=predictions)
    else:
      if mode == tf.estimator.ModeKeys.TRAIN:
        # masked_output -> [b * x, h]
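        # gather the encoder hidden states at the positions given by sentiment_mask_indices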
        masked_output = gather_indexs(sequence_output, sentiment_mask_indices)
        
        # get output for word polarity prediction
        with tf.variable_scope('sentiment_project'):
          # [b * x, 2]
          output_sentiment = tf.layers.dense(
            masked_output,
            2,
            name='final_output',
            kernel_initializer=ft.create_initializer(initializer_range=cg.BertEncoderConfig.initializer_range))
        # output_sentiment_probs = tf.nn.softmax(output_sentiment, axis=-1)

        batch_size = tf.cast(ft.get_shape_list(labels, expected_rank=1)[0], dtype=tf.float32)
        # cross-entropy loss
        cls_loss = tf.reduce_sum(tf.nn.sparse_softmax_cross_entropy_with_logits(
          labels=labels,
          logits=output_logits)) / batch_size

        # token-level sentiment loss: the regression (MSE) variant is commented out below,
        # and a masked classification loss is used instead (the variable keeps the name mse_loss).
        # # Regression Model
        true_sequence = get_true_sequence(true_length_from_data)
        # mse_loss = calculate_mse_loss(
        #   output_sentiment, sentiment_labels, true_sequence)

        # # Classification Model
        true_label_flatten = tf.reshape(sentiment_labels, [-1])
        mse_loss = tf.reduce_sum(tf.nn.sparse_softmax_cross_entropy_with_logits(
          labels=true_label_flatten,
          logits=output_sentiment) * true_sequence) / tf.reduce_sum(true_sequence)

        loss = cls_loss + mse_loss
        # loss = cls_loss

        learning_rate = tf.train.polynomial_decay(cg.learning_rate,
                                  tf.train.get_or_create_global_step(),
                                  cg.train_steps,
                                  end_learning_rate=cg.lr_limit,
                                  power=1.0,
                                  cycle=False)

        lr = tf.maximum(tf.constant(cg.lr_limit), learning_rate)
        optimizer = tf.train.AdamOptimizer(lr, name='optimizer')
        tvars = tf.trainable_variables()
        gradients = tf.gradients(loss, tvars, colocate_gradients_with_ops=cg.colocate_gradients_with_ops)
        clipped_gradients, _ = tf.clip_by_global_norm(gradients, 5.0)
        train_op = optimizer.apply_gradients(zip(clipped_gradients, tvars), global_step=tf.train.get_global_step())

        current_steps = tf.train.get_or_create_global_step()
        logging_hook = tf.train.LoggingTensorHook(
          {'step' : current_steps, 'loss' : loss, 'cls_loss' : cls_loss, 'mse_loss': mse_loss, 'lr' : lr}, 
          every_n_iter=cg.print_info_interval)

        output_spec = tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op, training_hooks=[logging_hook])
      elif mode == tf.estimator.ModeKeys.EVAL:
        # TODO
        raise NotImplementedError
    
    return output_spec
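Example #2 reads init_checkpoint and calls get_assignment_map_from_checkpoint without defining either inside model_fn, which suggests the function is produced by a builder and picks those names up from its closure (the usual BERT-style pattern). The skeleton below sketches that pattern; the builder name and its arguments are assumptions, not the project's actual API.

import tensorflow as tf  # TF 1.x API

def model_fn_builder(init_checkpoint):
  # init_checkpoint is captured by the inner model_fn as a closure variable,
  # which is why Example #2 can use it without defining it locally

  def model_fn(features, labels, mode, params):
    # ... build the BertEncoder, output heads and losses exactly as in Example #2 ...

    tvars = tf.trainable_variables()
    initialized_variable_names = {}
    if init_checkpoint:  # visible here through the enclosing scope
      (assignment_map, initialized_variable_names
      ) = get_assignment_map_from_checkpoint(tvars, init_checkpoint)
      tf.train.init_from_checkpoint(init_checkpoint, assignment_map)

    # ... return a tf.estimator.EstimatorSpec for the current mode ...

  return model_fn

# the built closure is then handed to the Estimator, for example:
# estimator = tf.estimator.Estimator(model_fn=model_fn_builder('path/to/bert_model.ckpt'), ...)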