Example 1
    def model_fn(features, labels, mode, params):

        train_op_type = kargs.get('train_op_type', 'joint')
        gen_disc_type = kargs.get('gen_disc_type', 'all_disc')
        print(train_op_type, "===train op type===", gen_disc_type,
              "===gen/disc type===")
        if kargs.get('optimization_type', 'grl') == 'grl':
            if_flip_grad = True
            train_op_type = 'joint'
        elif kargs.get('optimization_type', 'grl') == 'minmax':
            if_flip_grad = False
        else:
            # Fall back to gradient reversal so if_flip_grad is always defined.
            if_flip_grad = True
            train_op_type = 'joint'
        generator_fn = generator(
            model_config_dict['generator'],
            num_labels_dict['generator'],
            init_checkpoint_dict['generator'],
            model_reuse=None,
            load_pretrained=load_pretrained_dict['generator'],
            model_io_config=model_io_config,
            opt_config=opt_config,
            exclude_scope=exclude_scope_dict.get('generator', ""),
            not_storage_params=not_storage_params_dict.get('generator', []),
            target=target_dict['generator'],
            if_flip_grad=if_flip_grad,
            # mask_method="all_mask",
            **kargs)

        tf.logging.info("****** train_op_type:%s *******", train_op_type)
        tf.logging.info("****** optimization_type:%s *******",
                        kargs.get('optimization_type', 'grl'))
        generator_dict = generator_fn(features, labels, mode, params)

        discriminator_fn = discriminator(
            model_config_dict['discriminator'],
            num_labels_dict['discriminator'],
            init_checkpoint_dict['discriminator'],
            model_reuse=None,
            load_pretrained=load_pretrained_dict['discriminator'],
            model_io_config=model_io_config,
            opt_config=opt_config,
            exclude_scope=exclude_scope_dict.get('discriminator', ""),
            not_storage_params=not_storage_params_dict.get(
                'discriminator', []),
            target=target_dict['discriminator'],
            **kargs)

        tf.logging.info("****** true sampled_ids of discriminator *******")
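        # "True" branch: the discriminator scores the generator's
        # sampled_input_ids, which serve throughout as the original token ids.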
        true_discriminator_features = {}
        true_discriminator_features['input_ids'] = generator_dict[
            'sampled_input_ids']
        true_discriminator_features['input_mask'] = generator_dict[
            'sampled_input_mask']
        true_discriminator_features['segment_ids'] = generator_dict[
            'sampled_segment_ids']
        true_discriminator_features['input_ori_ids'] = generator_dict[
            'sampled_input_ids']
        true_discriminator_features['next_sentence_labels'] = features[
            'next_sentence_labels']
        true_discriminator_features['ori_input_ids'] = generator_dict[
            'sampled_input_ids']

        true_discriminator_dict = discriminator_fn(true_discriminator_features,
                                                   labels, mode, params)

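        # "Fake" branch: identical features, except that the generator-sampled
        # token ids (sampled_ids) are substituted as the discriminator input.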
        fake_discriminator_features = {}
        if kargs.get('minmax_mode', 'corrupted') == 'corrupted':
            tf.logging.info("****** gumbel 3-D sampled_ids *******")
        elif kargs.get('minmax_mode', 'corrupted') == 'masked':
            fake_discriminator_features['ori_sampled_ids'] = generator_dict[
                'output_ids']
            fake_discriminator_features['sampled_binary_mask'] = generator_dict[
                'sampled_binary_mask']
            tf.logging.info("****** conditional sampled_ids *******")
        fake_discriminator_features['input_ids'] = generator_dict[
            'sampled_ids']
        fake_discriminator_features['input_mask'] = generator_dict[
            'sampled_input_mask']
        fake_discriminator_features['segment_ids'] = generator_dict[
            'sampled_segment_ids']
        fake_discriminator_features['input_ori_ids'] = generator_dict[
            'sampled_input_ids']
        fake_discriminator_features['next_sentence_labels'] = features[
            'next_sentence_labels']
        fake_discriminator_features['ori_input_ids'] = generator_dict[
            'sampled_ids']

        fake_discriminator_dict = discriminator_fn(fake_discriminator_features,
                                                   labels, mode, params)

        use_tpu = 1 if kargs.get('use_tpu', False) else 0

        output_dict = get_losses(true_discriminator_dict["logits"],
                                 fake_discriminator_dict["logits"],
                                 use_tpu=use_tpu,
                                 gan_type=kargs.get('gan_type', "JS"))

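        # Bundle the adversarial losses and discriminator variables so the rest
        # of model_fn can treat them as a single discriminator output dict.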
        discriminator_dict = {}
        discriminator_dict['gen_loss'] = output_dict['gen_loss']
        discriminator_dict['disc_loss'] = output_dict['disc_loss']
        discriminator_dict['tvars'] = fake_discriminator_dict['tvars']
        discriminator_dict['fake_logits'] = fake_discriminator_dict['logits']
        discriminator_dict['true_logits'] = true_discriminator_dict['logits']

        model_io_fn = model_io.ModelIO(model_io_config)

        loss = discriminator_dict['disc_loss']
        tvars = []
        tvars.extend(discriminator_dict['tvars'])

        if kargs.get('joint_train', '1') == '1':
            tf.logging.info(
                "****** joint generator and discriminator training *******")
            tvars.extend(generator_dict['tvars'])
            loss += generator_dict['loss']
        tvars = list(set(tvars))

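        # For every sub-model whose load_pretrained flag is "yes", record which
        # variables to restore, from which checkpoint, and under which scope.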
        var_checkpoint_dict_list = []
        for key in init_checkpoint_dict:
            if load_pretrained_dict[key] == "yes":
                if key == 'generator':
                    tmp = {
                        "tvars":
                        generator_dict['tvars'],
                        "init_checkpoint":
                        init_checkpoint_dict['generator'],
                        "exclude_scope":
                        exclude_scope_dict[key],
                        "restore_var_name":
                        model_config_dict['generator'].get(
                            'restore_var_name', [])
                    }
                    if kargs.get("sharing_mode", "none") != "none":
                        tmp['exclude_scope'] = ''
                    var_checkpoint_dict_list.append(tmp)
                elif key == 'discriminator':
                    tmp = {
                        "tvars":
                        discriminator_dict['tvars'],
                        "init_checkpoint":
                        init_checkpoint_dict['discriminator'],
                        "exclude_scope":
                        exclude_scope_dict[key],
                        "restore_var_name":
                        model_config_dict['discriminator'].get(
                            'restore_var_name', [])
                    }
                    var_checkpoint_dict_list.append(tmp)

        use_tpu = 1 if kargs.get('use_tpu', False) else 0

        if len(var_checkpoint_dict_list) >= 1:
            scaffold_fn = model_io_fn.load_multi_pretrained(
                var_checkpoint_dict_list, use_tpu=use_tpu)
        else:
            scaffold_fn = None

        if mode == tf.estimator.ModeKeys.TRAIN:

            if not kargs.get('use_tpu', False):
                metric_dict = discriminator_metric_train(discriminator_dict)

                for key in metric_dict:
                    tf.summary.scalar(key, metric_dict[key])
                tf.summary.scalar("generator_loss", generator_dict['loss'])
                tf.summary.scalar("discriminator_true_loss",
                                  discriminator_dict['disc_loss'])
                tf.summary.scalar("discriminator_fake_loss",
                                  discriminator_dict['gen_loss'])

            if kargs.get('use_tpu', False):
                optimizer_fn = optimizer.Optimizer(opt_config)
                use_tpu = 1
            else:
                optimizer_fn = distributed_optimizer.Optimizer(opt_config)
                use_tpu = 0

            model_io_fn.print_params(tvars, string=", trainable params")

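            # get_train_op assembles the generator/discriminator updates
            # according to train_op_type and gen_disc_type (set at the top).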
            train_op = get_train_op(generator_dict,
                                    discriminator_dict,
                                    optimizer_fn,
                                    opt_config,
                                    model_config_dict['generator'],
                                    model_config_dict['discriminator'],
                                    use_tpu=use_tpu,
                                    train_op_type=train_op_type,
                                    gen_disc_type=gen_disc_type)

            if kargs.get('use_tpu', False):
                estimator_spec = tf.contrib.tpu.TPUEstimatorSpec(
                    mode=mode,
                    loss=loss,
                    train_op=train_op,
                    scaffold_fn=scaffold_fn
                    # training_hooks=[logging_hook]
                )
            else:
                estimator_spec = tf.estimator.EstimatorSpec(mode=mode,
                                                            loss=loss,
                                                            train_op=train_op)

            return estimator_spec

        elif mode == tf.estimator.ModeKeys.EVAL:

            if kargs.get('joint_train', '0') == '1':

                def joint_metric(masked_lm_example_loss, masked_lm_log_probs,
                                 masked_lm_ids, masked_lm_weights,
                                 next_sentence_example_loss,
                                 next_sentence_log_probs, next_sentence_labels,
                                 discriminator_dict):
                    generator_metric = generator_metric_fn_eval(
                        masked_lm_example_loss, masked_lm_log_probs,
                        masked_lm_ids, masked_lm_weights,
                        next_sentence_example_loss, next_sentence_log_probs,
                        next_sentence_labels)
                    discriminator_metric = discriminator_metric_eval(
                        discriminator_dict)
                    generator_metric.update(discriminator_metric)
                    return generator_metric

                tpu_eval_metrics = (joint_metric, [
                    generator_dict['masked_lm_example_loss'],
                    generator_dict['masked_lm_log_probs'],
                    generator_dict['masked_lm_ids'],
                    generator_dict['masked_lm_weights'],
                    generator_dict.get('next_sentence_example_loss', None),
                    generator_dict.get('next_sentence_log_probs', None),
                    generator_dict.get('next_sentence_labels',
                                       None), discriminator_dict
                ])
                gpu_eval_metrics = joint_metric(
                    generator_dict['masked_lm_example_loss'],
                    generator_dict['masked_lm_log_probs'],
                    generator_dict['masked_lm_ids'],
                    generator_dict['masked_lm_weights'],
                    generator_dict.get('next_sentence_example_loss', None),
                    generator_dict.get('next_sentence_log_probs', None),
                    generator_dict.get('next_sentence_labels',
                                       None), discriminator_dict)
            else:
                gpu_eval_metrics = discriminator_metric_eval(
                    discriminator_dict)
                tpu_eval_metrics = (discriminator_metric_eval,
                                    [discriminator_dict])

            if kargs.get('use_tpu', False):
                estimator_spec = tf.contrib.tpu.TPUEstimatorSpec(
                    mode=mode,
                    loss=loss,
                    eval_metrics=tpu_eval_metrics,
                    scaffold_fn=scaffold_fn)
            else:
                estimator_spec = tf.estimator.EstimatorSpec(
                    mode=mode, loss=loss, eval_metric_ops=gpu_eval_metrics)

            return estimator_spec
        else:
            raise NotImplementedError()
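
The adversarial objective in Example 1 is hidden behind get_losses, whose definition is not part of this listing. As a rough orientation only, the sketch below shows what a standard (Jensen-Shannon) GAN objective over the true and fake discriminator logits could look like; the function name js_gan_losses and the assumption that the logits are raw per-example (or per-token) scores are illustrative, not the repository's actual implementation.

    # Hypothetical sketch; the real get_losses lives elsewhere in the repository.
    import tensorflow as tf

    def js_gan_losses(true_logits, fake_logits):
        """Standard (JS) GAN losses from raw discriminator logits.

        The discriminator is trained to score true_logits as real (label 1)
        and fake_logits as fake (label 0); the generator loss flips the fake
        labels so the generator is rewarded for fooling the discriminator.
        """
        disc_loss_real = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(
                labels=tf.ones_like(true_logits), logits=true_logits))
        disc_loss_fake = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(
                labels=tf.zeros_like(fake_logits), logits=fake_logits))
        gen_loss = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(
                labels=tf.ones_like(fake_logits), logits=fake_logits))
        return {"disc_loss": disc_loss_real + disc_loss_fake,
                "gen_loss": gen_loss}
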
Example 2
    def model_fn(features, labels, mode, params):

        train_op_type = kargs.get('train_op_type', 'joint')
        gen_disc_type = kargs.get('gen_disc_type', 'all_disc')
        mask_method = kargs.get('mask_method', 'only_mask')
        print(train_op_type, "===train op type===", gen_disc_type,
              "===gen/disc type===")
        if mask_method == 'only_mask':
            tf.logging.info(
                "****** generator token generation mask type:%s with only masked token *******",
                mask_method)
        elif mask_method == 'all_mask':
            tf.logging.info(
                "****** generator token generation mask type:%s with all token *******",
                mask_method)
        else:
            mask_method = 'only_mask'
            tf.logging.info(
                "****** generator token generation mask type:%s with only masked token *******",
                mask_method)

        if kargs.get('optimization_type', 'grl') == 'grl':
            if_flip_grad = True
            train_op_type = 'joint'
        elif kargs.get('optimization_type', 'grl') == 'minmax':
            if_flip_grad = False
        else:
            if_flip_grad = True
            train_op_type = 'joint'
        generator_fn = generator(
            model_config_dict['generator'],
            num_labels_dict['generator'],
            init_checkpoint_dict['generator'],
            model_reuse=None,
            load_pretrained=load_pretrained_dict['generator'],
            model_io_config=model_io_config,
            opt_config=opt_config,
            exclude_scope=exclude_scope_dict.get('generator', ""),
            not_storage_params=not_storage_params_dict.get('generator', []),
            target=target_dict['generator'],
            if_flip_grad=if_flip_grad,
            resample_discriminator=False,
            # mask_method='only_mask',
            **kargs)
        # train_op_type = 'joint'
        # elif kargs.get('optimization_type', 'grl') == 'minmax':
        # 	generator_fn = generator_normal(model_config_dict['generator'],
        # 				num_labels_dict['generator'],
        # 				init_checkpoint_dict['generator'],
        # 				model_reuse=None,
        # 				load_pretrained=load_pretrained_dict['generator'],
        # 				model_io_config=model_io_config,
        # 				opt_config=opt_config,
        # 				exclude_scope=exclude_scope_dict.get('generator', ""),
        # 				not_storage_params=not_storage_params_dict.get('generator', []),
        # 				target=target_dict['generator'],
        # 				**kargs)
        # else:
        # 	generator_fn = generator(model_config_dict['generator'],
        # 				num_labels_dict['generator'],
        # 				init_checkpoint_dict['generator'],
        # 				model_reuse=None,
        # 				load_pretrained=load_pretrained_dict['generator'],
        # 				model_io_config=model_io_config,
        # 				opt_config=opt_config,
        # 				exclude_scope=exclude_scope_dict.get('generator', ""),
        # 				not_storage_params=not_storage_params_dict.get('generator', []),
        # 				target=target_dict['generator'],
        # 				**kargs)
        tf.logging.info("****** train_op_type:%s *******", train_op_type)
        tf.logging.info("****** optimization_type:%s *******",
                        kargs.get('optimization_type', 'grl'))
        generator_dict = generator_fn(features, labels, mode, params)

        # for key in generator_dict:
        # 	if isinstance(generator_dict[key], list):
        # 		for item in generator_dict[key]:
        # 			print(key, item.graph, '=====generator graph=====')
        # 	else:
        # 		try:
        # 			print(key, generator_dict[key].graph, '=====generator graph=====')
        # 		except:
        # 			print(key, type(generator_dict[key]), '=====generator graph=====')

        discriminator_fn = discriminator(
            model_config_dict['discriminator'],
            num_labels_dict['discriminator'],
            init_checkpoint_dict['discriminator'],
            model_reuse=None,
            load_pretrained=load_pretrained_dict['discriminator'],
            model_io_config=model_io_config,
            opt_config=opt_config,
            exclude_scope=exclude_scope_dict.get('discriminator', ""),
            not_storage_params=not_storage_params_dict.get(
                'discriminator', []),
            target=target_dict['discriminator'],
            loss='cross_entropy',  # cross_entropy
            **kargs)

        discriminator_features = {}
        # minmax_mode in ['masked', 'corrupted']
        minmax_mode = kargs.get('minmax_mode', 'corrupted')
        tf.logging.info("****** minmax mode for discriminator: %s *******",
                        minmax_mode)
        if minmax_mode == 'corrupted':
            tf.logging.info("****** gumbel 3-D sampled_ids *******")
        elif minmax_mode == 'masked':
            discriminator_features['ori_sampled_ids'] = generator_dict[
                'output_ids']
            discriminator_features['sampled_binary_mask'] = generator_dict[
                'sampled_binary_mask']
            tf.logging.info("****** conditional sampled_ids *******")
        discriminator_features['input_ids'] = generator_dict['sampled_ids']
        discriminator_features['input_mask'] = generator_dict[
            'sampled_input_mask']
        discriminator_features['segment_ids'] = generator_dict[
            'sampled_segment_ids']
        discriminator_features['input_ori_ids'] = generator_dict[
            'sampled_input_ids']
        discriminator_features['next_sentence_labels'] = features[
            'next_sentence_labels']
        discriminator_features['ori_input_ids'] = generator_dict['sampled_ids']

        discriminator_dict = discriminator_fn(discriminator_features, labels,
                                              mode, params)
        discriminator_dict['input_ids'] = generator_dict['sampled_ids']

        # for key in discriminator_dict:
        # 	if isinstance(discriminator_dict[key], list):
        # 		for item in discriminator_dict[key]:
        # 			print(key, item.graph, '=====discriminator graph=====')
        # 	else:
        # 		try:
        # 			print(key, discriminator_dict[key].graph, '=====discriminator graph=====')
        # 		except:
        # 			print(key, type(discriminator_dict[key]), '=====discriminator graph=====')

        model_io_fn = model_io.ModelIO(model_io_config)

        tvars = []

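        # Total loss: discriminator loss scaled by the optional dis_loss weight,
        # plus the generator loss when the two models are trained jointly.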
        loss = kargs.get('dis_loss', 1.0) * discriminator_dict['loss']

        tvars.extend(discriminator_dict['tvars'])

        if kargs.get('joint_train', '1') == '1':
            tf.logging.info(
                "****** joint generator and discriminator training *******")
            tvars.extend(generator_dict['tvars'])
            loss += generator_dict['loss']
        tvars = list(set(tvars))

        # print(loss.graph, '===total graph===')

        # logging_hook = tf.train.LoggingTensorHook({
        # 				"generator_loss" : tf.get_collection('generator_loss'),
        # 				"discriminator_loss":tf.get_collection('discriminator_loss')},
        # 				every_n_iter=1000)

        var_checkpoint_dict_list = []
        for key in init_checkpoint_dict:
            if load_pretrained_dict[key] == "yes":
                if key == 'generator':
                    tmp = {
                        "tvars":
                        generator_dict['tvars'],
                        "init_checkpoint":
                        init_checkpoint_dict['generator'],
                        "exclude_scope":
                        exclude_scope_dict[key],
                        "restore_var_name":
                        model_config_dict['generator'].get(
                            'restore_var_name', [])
                    }
                    if kargs.get("sharing_mode", "none") != "none":
                        tmp['exclude_scope'] = ''
                    var_checkpoint_dict_list.append(tmp)
                elif key == 'discriminator':
                    tmp = {
                        "tvars":
                        discriminator_dict['tvars'],
                        "init_checkpoint":
                        init_checkpoint_dict['discriminator'],
                        "exclude_scope":
                        exclude_scope_dict[key],
                        "restore_var_name":
                        model_config_dict['discriminator'].get(
                            'restore_var_name', [])
                    }
                    var_checkpoint_dict_list.append(tmp)

        use_tpu = 1 if kargs.get('use_tpu', False) else 0

        if len(var_checkpoint_dict_list) >= 1:
            scaffold_fn = model_io_fn.load_multi_pretrained(
                var_checkpoint_dict_list, use_tpu=use_tpu)
        else:
            scaffold_fn = None

        if mode == tf.estimator.ModeKeys.TRAIN:

            if not kargs.get('use_tpu', False):
                metric_dict = discriminator_metric_train(
                    discriminator_dict['per_example_loss'],
                    discriminator_dict['logits'],
                    generator_dict['sampled_input_ids'],
                    generator_dict['sampled_ids'],
                    generator_dict['sampled_input_mask'])

                for key in metric_dict:
                    tf.summary.scalar(key, metric_dict[key])
                tf.summary.scalar("generator_loss", generator_dict['loss'])
                tf.summary.scalar("discriminator_loss",
                                  discriminator_dict['loss'])

            if kargs.get('use_tpu', False):
                optimizer_fn = optimizer.Optimizer(opt_config)
                use_tpu = 1
            else:
                optimizer_fn = distributed_optimizer.Optimizer(opt_config)
                use_tpu = 0

            model_io_fn.print_params(tvars, string=", trainable params")

            train_op = get_train_op(generator_dict,
                                    discriminator_dict,
                                    optimizer_fn,
                                    opt_config,
                                    model_config_dict['generator'],
                                    model_config_dict['discriminator'],
                                    use_tpu=use_tpu,
                                    train_op_type=train_op_type,
                                    gen_disc_type=gen_disc_type)

            # update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            # with tf.control_dependencies(update_ops):
            # 	train_op = optimizer_fn.get_train_op(loss, list(set(tvars)),
            # 					opt_config.init_lr,
            # 					opt_config.num_train_steps,
            # 					use_tpu=use_tpu)

            if kargs.get('use_tpu', False):
                estimator_spec = tf.contrib.tpu.TPUEstimatorSpec(
                    mode=mode,
                    loss=loss,
                    train_op=train_op,
                    scaffold_fn=scaffold_fn
                    # training_hooks=[logging_hook]
                )
            else:
                estimator_spec = tf.estimator.EstimatorSpec(mode=mode,
                                                            loss=loss,
                                                            train_op=train_op)

            return estimator_spec

        elif mode == tf.estimator.ModeKeys.EVAL:

            if kargs.get('joint_train', '0') == '1':

                def joint_metric(masked_lm_example_loss, masked_lm_log_probs,
                                 masked_lm_ids, masked_lm_weights,
                                 next_sentence_example_loss,
                                 next_sentence_log_probs, next_sentence_labels,
                                 per_example_loss, logits, input_ori_ids,
                                 input_ids, input_mask):
                    generator_metric = generator_metric_fn_eval(
                        masked_lm_example_loss, masked_lm_log_probs,
                        masked_lm_ids, masked_lm_weights,
                        next_sentence_example_loss, next_sentence_log_probs,
                        next_sentence_labels)
                    discriminator_metric = discriminator_metric_eval(
                        per_example_loss, logits, input_ori_ids, input_ids,
                        input_mask)
                    generator_metric.update(discriminator_metric)
                    return generator_metric

                tpu_eval_metrics = (joint_metric, [
                    generator_dict['masked_lm_example_loss'],
                    generator_dict['masked_lm_log_probs'],
                    generator_dict['masked_lm_ids'],
                    generator_dict['masked_lm_weights'],
                    generator_dict.get('next_sentence_example_loss', None),
                    generator_dict.get('next_sentence_log_probs', None),
                    generator_dict.get('next_sentence_labels', None),
                    discriminator_dict['per_example_loss'],
                    discriminator_dict['logits'],
                    generator_dict['sampled_input_ids'],
                    generator_dict['sampled_ids'],
                    generator_dict['sampled_input_mask']
                ])
                gpu_eval_metrics = joint_metric(
                    generator_dict['masked_lm_example_loss'],
                    generator_dict['masked_lm_log_probs'],
                    generator_dict['masked_lm_ids'],
                    generator_dict['masked_lm_weights'],
                    generator_dict.get('next_sentence_example_loss', None),
                    generator_dict.get('next_sentence_log_probs', None),
                    generator_dict.get('next_sentence_labels', None),
                    discriminator_dict['per_example_loss'],
                    discriminator_dict['logits'],
                    generator_dict['sampled_input_ids'],
                    generator_dict['sampled_ids'],
                    generator_dict['sampled_input_mask'])
            else:
                gpu_eval_metrics = discriminator_metric_eval(
                    discriminator_dict['per_example_loss'],
                    discriminator_dict['logits'],
                    generator_dict['sampled_input_ids'],
                    generator_dict['sampled_ids'],
                    generator_dict['sampled_input_mask'])
                tpu_eval_metrics = (discriminator_metric_eval, [
                    discriminator_dict['per_example_loss'],
                    discriminator_dict['logits'],
                    generator_dict['sampled_input_ids'],
                    generator_dict['sampled_ids'],
                    generator_dict['sampled_input_mask']
                ])

            if kargs.get('use_tpu', False):
                estimator_spec = tf.contrib.tpu.TPUEstimatorSpec(
                    mode=mode,
                    loss=loss,
                    eval_metrics=tpu_eval_metrics,
                    scaffold_fn=scaffold_fn)
            else:
                estimator_spec = tf.estimator.EstimatorSpec(
                    mode=mode, loss=loss, eval_metric_ops=gpu_eval_metrics)

            return estimator_spec
        else:
            raise NotImplementedError()
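
Example 1 and Example 2 both pass if_flip_grad=True to the generator when optimization_type is 'grl', i.e. they rely on a gradient-reversal layer inside the generator rather than an explicit min-max loop ('minmax'). The generator's internals are not shown on this page; the sketch below illustrates the kind of gradient-reversal op that the flag suggests. The name flip_gradient and the weight argument are illustrative assumptions.

    # Illustrative gradient-reversal layer (GRL); not the repository's code.
    import tensorflow as tf

    def flip_gradient(x, weight=1.0):
        """Identity in the forward pass; scales the gradient by -weight in the
        backward pass, so a single joint training step pushes the upstream
        network in the adversarial direction."""

        @tf.custom_gradient
        def _flip(inputs):
            def grad(dy):
                # Reverse (and optionally rescale) the incoming gradient.
                return -weight * dy
            return tf.identity(inputs), grad

        return _flip(x)
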
Example 3
    def model_fn(features, labels, mode, params):

        train_op_type = kargs.get('train_op_type', 'joint')
        if kargs.get('optimization_type', 'grl') == 'grl':
            generator_fn = generator(
                model_config_dict['generator'],
                num_labels_dict['generator'],
                init_checkpoint_dict['generator'],
                model_reuse=None,
                load_pretrained=load_pretrained_dict['generator'],
                model_io_config=model_io_config,
                opt_config=opt_config,
                exclude_scope=exclude_scope_dict.get('generator', ""),
                not_storage_params=not_storage_params_dict.get(
                    'generator', []),
                target=target_dict['generator'],
                **kargs)
            train_op_type = 'joint'
        elif kargs.get('optimization_type', 'grl') == 'minmax':
            generator_fn = generator_normal(
                model_config_dict['generator'],
                num_labels_dict['generator'],
                init_checkpoint_dict['generator'],
                model_reuse=None,
                load_pretrained=load_pretrained_dict['generator'],
                model_io_config=model_io_config,
                opt_config=opt_config,
                exclude_scope=exclude_scope_dict.get('generator', ""),
                not_storage_params=not_storage_params_dict.get(
                    'generator', []),
                target=target_dict['generator'],
                **kargs)
        else:
            generator_fn = generator(
                model_config_dict['generator'],
                num_labels_dict['generator'],
                init_checkpoint_dict['generator'],
                model_reuse=None,
                load_pretrained=load_pretrained_dict['generator'],
                model_io_config=model_io_config,
                opt_config=opt_config,
                exclude_scope=exclude_scope_dict.get('generator', ""),
                not_storage_params=not_storage_params_dict.get(
                    'generator', []),
                target=target_dict['generator'],
                **kargs)
        tf.logging.info("****** train_op_type:%s *******", train_op_type)
        tf.logging.info("****** optimization_type:%s *******",
                        kargs.get('optimization_type', 'grl'))
        generator_dict = generator_fn(features, labels, mode, params)

        discriminator_fn = discriminator(
            model_config_dict['discriminator'],
            num_labels_dict['discriminator'],
            init_checkpoint_dict['discriminator'],
            model_reuse=None,
            load_pretrained=load_pretrained_dict['discriminator'],
            model_io_config=model_io_config,
            opt_config=opt_config,
            exclude_scope=exclude_scope_dict.get('discriminator', ""),
            not_storage_params=not_storage_params_dict.get(
                'discriminator', []),
            target=target_dict['discriminator'],
            **kargs)

        tf.logging.info("****** true sampled_ids of discriminator *******")
        true_discriminator_features = {}
        true_discriminator_features['input_ids'] = generator_dict[
            'sampled_input_ids']
        true_discriminator_features['input_mask'] = generator_dict[
            'sampled_input_mask']
        true_discriminator_features['segment_ids'] = generator_dict[
            'sampled_segment_ids']
        true_discriminator_features['input_ori_ids'] = generator_dict[
            'sampled_input_ids']
        true_discriminator_features['next_sentence_labels'] = features[
            'next_sentence_labels']
        true_discriminator_features['ori_input_ids'] = generator_dict[
            'sampled_input_ids']

        true_discriminator_dict = discriminator_fn(true_discriminator_features,
                                                   labels, mode, params)

        fake_discriminator_features = {}
        if kargs.get('minmax_mode', 'corrupted') == 'corrupted':
            tf.logging.info("****** gumbel 3-D sampled_ids *******")
        elif kargs.get('minmax_mode', 'corrupted') == 'masked':
            fake_discriminator_features['ori_sampled_ids'] = generator_dict[
                'output_ids']
            tf.logging.info("****** conditional sampled_ids *******")
        fake_discriminator_features['input_ids'] = generator_dict[
            'sampled_ids']
        fake_discriminator_features['input_mask'] = generator_dict[
            'sampled_input_mask']
        fake_discriminator_features['segment_ids'] = generator_dict[
            'sampled_segment_ids']
        fake_discriminator_features['input_ori_ids'] = generator_dict[
            'sampled_input_ids']
        fake_discriminator_features['next_sentence_labels'] = features[
            'next_sentence_labels']
        fake_discriminator_features['ori_input_ids'] = generator_dict[
            'sampled_ids']

        fake_discriminator_dict = discriminator_fn(fake_discriminator_features,
                                                   labels, mode, params)

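        # This variant replaces the per-token discriminator loss with an
        # NCE-style objective over the true (original ids) and fake
        # (generator-sampled ids) branches.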
        nce_loss = nce_loss_fn(true_discriminator_dict,
                               true_discriminator_features,
                               fake_discriminator_dict,
                               fake_discriminator_features)

        model_io_fn = model_io.ModelIO(model_io_config)

        tvars = []

        loss = kargs.get('dis_loss', 1.0) * nce_loss

        tvars.extend(fake_discriminator_dict['tvars'])

        if kargs.get('joint_train', '1') == '1':
            tf.logging.info(
                "****** joint generator and discriminator training *******")
            tvars.extend(generator_dict['tvars'])
            loss += generator_dict['loss']
        tvars = list(set(tvars))

        var_checkpoint_dict_list = []
        for key in init_checkpoint_dict:
            if load_pretrained_dict[key] == "yes":
                if key == 'generator':
                    tmp = {
                        "tvars":
                        generator_dict['tvars'],
                        "init_checkpoint":
                        init_checkpoint_dict['generator'],
                        "exclude_scope":
                        exclude_scope_dict[key],
                        "restore_var_name":
                        model_config_dict['generator'].get(
                            'restore_var_name', [])
                    }
                    if kargs.get("sharing_mode", "none") != "none":
                        tmp['exclude_scope'] = ''
                    var_checkpoint_dict_list.append(tmp)
                elif key == 'discriminator':
                    tmp = {
                        "tvars":
                        fake_discriminator_dict['tvars'],
                        "init_checkpoint":
                        init_checkpoint_dict['discriminator'],
                        "exclude_scope":
                        exclude_scope_dict[key],
                        "restore_var_name":
                        model_config_dict['discriminator'].get(
                            'restore_var_name', [])
                    }
                    var_checkpoint_dict_list.append(tmp)

        use_tpu = 1 if kargs.get('use_tpu', False) else 0

        if len(var_checkpoint_dict_list) >= 1:
            scaffold_fn = model_io_fn.load_multi_pretrained(
                var_checkpoint_dict_list, use_tpu=use_tpu)
        else:
            scaffold_fn = None

        if mode == tf.estimator.ModeKeys.TRAIN:

            if kargs.get('summary_debug', False):
                metric_dict = discriminator_metric_train(
                    fake_discriminator_dict['per_example_loss'],
                    fake_discriminator_dict['logits'],
                    generator_dict['sampled_input_ids'],
                    generator_dict['sampled_ids'],
                    generator_dict['sampled_input_mask'])

                for key in metric_dict:
                    tf.summary.scalar(key, metric_dict[key])

            if kargs.get('use_tpu', False):
                optimizer_fn = optimizer.Optimizer(opt_config)
                use_tpu = 1
            else:
                optimizer_fn = distributed_optimizer.Optimizer(opt_config)
                use_tpu = 0

            model_io_fn.print_params(tvars, string=", trainable params")

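            # Unlike the first two examples, this variant optimizes the combined
            # loss directly with a single optimizer instead of get_train_op.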
            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            with tf.control_dependencies(update_ops):
                train_op = optimizer_fn.get_train_op(
                    loss,
                    list(set(tvars)),
                    opt_config.init_lr,
                    opt_config.num_train_steps,
                    use_tpu=use_tpu)

            if kargs.get('use_tpu', False):
                estimator_spec = tf.contrib.tpu.TPUEstimatorSpec(
                    mode=mode,
                    loss=loss,
                    train_op=train_op,
                    scaffold_fn=scaffold_fn)
            else:
                estimator_spec = tf.estimator.EstimatorSpec(mode=mode,
                                                            loss=loss,
                                                            train_op=train_op)

            return estimator_spec

        elif mode == tf.estimator.ModeKeys.EVAL:

            if kargs.get('joint_train', '0') == '1':

                def joint_metric(masked_lm_example_loss, masked_lm_log_probs,
                                 masked_lm_ids, masked_lm_weights,
                                 next_sentence_example_loss,
                                 next_sentence_log_probs, next_sentence_labels,
                                 per_example_loss, logits, input_ori_ids,
                                 input_ids, input_mask):
                    generator_metric = generator_metric_fn_eval(
                        masked_lm_example_loss, masked_lm_log_probs,
                        masked_lm_ids, masked_lm_weights,
                        next_sentence_example_loss, next_sentence_log_probs,
                        next_sentence_labels)
                    discriminator_metric = discriminator_metric_eval(
                        per_example_loss, logits, input_ori_ids, input_ids,
                        input_mask)
                    generator_metric.update(discriminator_metric)
                    return generator_metric

                tpu_eval_metrics = (joint_metric, [
                    generator_dict['masked_lm_example_loss'],
                    generator_dict['masked_lm_log_probs'],
                    generator_dict['masked_lm_ids'],
                    generator_dict['masked_lm_weights'],
                    generator_dict.get('next_sentence_example_loss', None),
                    generator_dict.get('next_sentence_log_probs', None),
                    generator_dict.get('next_sentence_labels', None),
                    fake_discriminator_dict['per_example_loss'],
                    fake_discriminator_dict['logits'],
                    generator_dict['sampled_input_ids'],
                    generator_dict['sampled_ids'],
                    generator_dict['sampled_input_mask']
                ])
                gpu_eval_metrics = joint_metric(
                    generator_dict['masked_lm_example_loss'],
                    generator_dict['masked_lm_log_probs'],
                    generator_dict['masked_lm_ids'],
                    generator_dict['masked_lm_weights'],
                    generator_dict.get('next_sentence_example_loss', None),
                    generator_dict.get('next_sentence_log_probs', None),
                    generator_dict.get('next_sentence_labels', None),
                    fake_discriminator_dict['per_example_loss'],
                    fake_discriminator_dict['logits'],
                    generator_dict['sampled_input_ids'],
                    generator_dict['sampled_ids'],
                    generator_dict['sampled_input_mask'])
            else:
                gpu_eval_metrics = discriminator_metric_eval(
                    fake_discriminator_dict['per_example_loss'],
                    fake_discriminator_dict['logits'],
                    generator_dict['sampled_input_ids'],
                    generator_dict['sampled_ids'],
                    generator_dict['sampled_input_mask'])
                tpu_eval_metrics = (discriminator_metric_eval, [
                    fake_discriminator_dict['per_example_loss'],
                    fake_discriminator_dict['logits'],
                    generator_dict['sampled_input_ids'],
                    generator_dict['sampled_ids'],
                    generator_dict['sampled_input_mask']
                ])

            if kargs.get('use_tpu', False):
                estimator_spec = tf.contrib.tpu.TPUEstimatorSpec(
                    mode=mode,
                    loss=loss,
                    eval_metrics=tpu_eval_metrics,
                    scaffold_fn=scaffold_fn)
            else:
                estimator_spec = tf.estimator.EstimatorSpec(
                    mode=mode, loss=loss, eval_metric_ops=gpu_eval_metrics)

            return estimator_spec
        else:
            raise NotImplementedError()
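
The third variant swaps the per-token discriminator loss for nce_loss_fn, computed from the true branch (original ids) and the fake branch (generator-sampled ids); its implementation is not included in this listing. A loosely hedged sketch of a binary noise-contrastive objective over two such sets of logits follows; it is essentially the discriminator half of the JS sketch after Example 1 collapsed into one scalar, and the function name and raw-logit assumption are mine, not the repository's.

    # Hypothetical sketch; the repository's nce_loss_fn may differ substantially.
    import tensorflow as tf

    def nce_style_loss(true_logits, fake_logits):
        """Binary NCE-style objective: sequences with their original ids are
        treated as data (label 1), generator-sampled sequences as noise
        (label 0), and both terms are averaged into a single scalar loss."""
        data_term = tf.nn.sigmoid_cross_entropy_with_logits(
            labels=tf.ones_like(true_logits), logits=true_logits)
        noise_term = tf.nn.sigmoid_cross_entropy_with_logits(
            labels=tf.zeros_like(fake_logits), logits=fake_logits)
        return tf.reduce_mean(data_term) + tf.reduce_mean(noise_term)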