Example #1
    def model_fn(features, labels, mode):
        label_ids = features["label_ids"]
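        # Build one encoder per input stream: the first instance creates (or
        # reuses, per model_reuse) the variables; later instances always reuse.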
        model_lst = []
        for index, name in enumerate(input_name):
            if index > 0:
                reuse = True
            else:
                reuse = model_reuse
            model_lst.append(
                base_model(model_config,
                           features,
                           labels,
                           mode,
                           name,
                           reuse=reuse))

        if mode == tf.estimator.ModeKeys.TRAIN:
            hidden_dropout_prob = model_config.hidden_dropout_prob
            attention_probs_dropout_prob = model_config.attention_probs_dropout_prob
            dropout_prob = model_config.dropout_prob
        else:
            hidden_dropout_prob = 0.0
            attention_probs_dropout_prob = 0.0
            dropout_prob = 0.0

        assert len(model_lst) == len(input_name)

        if model_io_config.fix_lm:
            scope = model_config.scope + "_finetuning"
        else:
            scope = model_config.scope

        with tf.variable_scope(scope, reuse=model_reuse):
            seq_output_lst = [model.get_pooled_output() for model in model_lst]

            [loss, per_example_loss, logits
             ] = classifier.interaction_classifier(model_config,
                                                   seq_output_lst, num_labels,
                                                   label_ids, dropout_prob)

        model_io_fn = model_io.ModelIO(model_io_config)
        pretrained_tvars = model_io_fn.get_params(model_config.scope)
        if load_pretrained:
            model_io_fn.load_pretrained(pretrained_tvars, init_checkpoint)

        tvars = model_io_fn.get_params(scope)

        if mode == tf.estimator.ModeKeys.TRAIN:
            model_io_fn.print_params(tvars, string=", trainable params")
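            # Run collected UPDATE_OPS (e.g. batch-norm moving-average updates)
            # before the optimizer step.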
            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            with tf.control_dependencies(update_ops):
                optimizer_fn = optimizer.Optimizer(opt_config)
                train_op = optimizer_fn.get_train_op(
                    loss, tvars, opt_config.init_lr,
                    opt_config.num_train_steps)

                return [train_op, loss, per_example_loss, logits]
        else:
            model_io_fn.print_params(tvars, string=", trainable params")
            return [loss, loss, per_example_loss, logits]
Example #2
    def model_fn(features, labels, mode, params):

        train_op_type = kargs.get('train_op_type', 'joint')

        ebm_noise_fce = EBM_NOISE_NCE(
            model_config_dict,
            num_labels_dict,
            init_checkpoint_dict,
            load_pretrained_dict,
            model_io_config=model_io_config,
            opt_config=opt_config,
            exclude_scope_dict=exclude_scope_dict,
            not_storage_params_dict=not_storage_params_dict,
            target_dict=target_dict,
            **kargs)

        model_io_fn = model_io.ModelIO(model_io_config)
        use_tpu = 1 if kargs.get('use_tpu', False) else 0

        if mode == tf.estimator.ModeKeys.TRAIN:

            if kargs.get('use_tpu', False):
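                # TPU training needs the TPU-compatible optimizer; the GPU
                # path below falls back to the distributed wrapper.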
                optimizer_fn = optimizer.Optimizer(opt_config)
                use_tpu = 1
            else:
                optimizer_fn = distributed_optimizer.Optimizer(opt_config)
                use_tpu = 0

            train_op = get_train_op(ebm_noise_fce,
                                    optimizer_fn,
                                    opt_config,
                                    model_config_dict['ebm_dist'],
                                    model_config_dict['noise_dist'],
                                    model_config_dict['generator'],
                                    features,
                                    labels,
                                    mode,
                                    params,
                                    use_tpu=use_tpu,
                                    train_op_type=train_op_type,
                                    alternate_order=['ebm', 'generator'])

            ebm_noise_fce.load_pretrained_model(**kargs)
            var_checkpoint_dict_list = ebm_noise_fce.var_checkpoint_dict_list
            loss = ebm_noise_fce.loss
            tvars = ebm_noise_fce.tvars

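            # Checkpoint restoration goes through a scaffold function so that
            # TPUEstimator can initialize variables on the workers.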
            if len(var_checkpoint_dict_list) >= 1:
                scaffold_fn = model_io_fn.load_multi_pretrained(
                    var_checkpoint_dict_list, use_tpu=use_tpu)
            else:
                scaffold_fn = None

            metric_dict = ebm_train_metric(
                ebm_noise_fce.true_ebm_dist_dict['logits'],
                ebm_noise_fce.fake_ebm_dist_dict['logits'])

            if not kargs.get('use_tpu', False):
                for key in metric_dict:
                    tf.summary.scalar(key, metric_dict[key])
                tf.summary.scalar("ebm_loss",
                                  ebm_noise_fce.ebm_opt_dict['ebm_loss'])
                tf.summary.scalar("mlm_loss",
                                  ebm_noise_fce.ebm_opt_dict['mlm_loss'])
                tf.summary.scalar("all_loss",
                                  ebm_noise_fce.ebm_opt_dict['all_loss'])

            model_io_fn.print_params(tvars, string=", trainable params")

            if kargs.get('use_tpu', False):
                estimator_spec = tf.contrib.tpu.TPUEstimatorSpec(
                    mode=mode,
                    loss=loss,
                    train_op=train_op,
                    scaffold_fn=scaffold_fn)
            else:
                estimator_spec = tf.estimator.EstimatorSpec(mode=mode,
                                                            loss=loss,
                                                            train_op=train_op)

            return estimator_spec

        elif mode == tf.estimator.ModeKeys.EVAL:

            ebm_noise_fce.get_loss(features, labels, mode, params, **kargs)
            ebm_noise_fce.load_pretrained_model(**kargs)
            var_checkpoint_dict_list = ebm_noise_fce.var_checkpoint_dict_list
            loss = ebm_noise_fce.loss

            if len(var_checkpoint_dict_list) >= 1:
                scaffold_fn = model_io_fn.load_multi_pretrained(
                    var_checkpoint_dict_list, use_tpu=use_tpu)
            else:
                scaffold_fn = None

            tpu_eval_metrics = (ebm_eval_metric, [
                ebm_noise_fce.true_ebm_dist_dict['logits'],
                ebm_noise_fce.fake_ebm_dist_dict['logits']
            ])
            gpu_eval_metrics = ebm_eval_metric(
                ebm_noise_fce.true_ebm_dist_dict['logits'],
                ebm_noise_fce.fake_ebm_dist_dict['logits'])

            if kargs.get('use_tpu', False):
                estimator_spec = tf.contrib.tpu.TPUEstimatorSpec(
                    mode=mode,
                    loss=loss,
                    eval_metrics=tpu_eval_metrics,
                    scaffold_fn=scaffold_fn)
            else:
                estimator_spec = tf.estimator.EstimatorSpec(
                    mode=mode, loss=loss, eval_metric_ops=gpu_eval_metrics)

            return estimator_spec
        else:
            raise NotImplementedError()
Example #3

def get_file_paths(folder_path):
    # NOTE: the opening of this function was truncated in the source listing;
    # the walk below is a plausible reconstruction based on how it is called.
    file_paths = []
    for root, _, file_names in os.walk(folder_path):
        for file_name in file_names:
            file_path = os.path.join(root, file_name)
            # Keep only files whose extension is supported by UserSong.
            for ext in usersong.UserSong.EXTENSIONS:
                if fnmatch.fnmatch(file_path, "*%s" % ext):
                    file_paths.append(file_path)
                    break
    return file_paths

song_folder_path = os.path.normpath(os.path.join(project_path, "djskinnyg_songs"))
song_paths = get_file_paths(song_folder_path)
cache_path = os.path.normpath(os.path.join(song_folder_path, "cache"))
song_objects = usersong.batch_create_user_songs(song_paths)
usersong.batch_analyze_user_songs(song_objects, cache_path)

for song in song_objects:
    song.write_analysis_to_folder(cache_path)
    print("{} : {} : {} : {} : {} : {}".format(
        song.get_analysis_feature(analysis.Feature.NAME),
        song.get_analysis_feature(analysis.Feature.TEMPO),
        song.get_analysis_feature(analysis.Feature.KEY),
        song.get_analysis_feature(analysis.Feature.DANCEABILITY),
        song.get_analysis_feature(analysis.Feature.ENERGY),
        song.get_analysis_feature(analysis.Feature.VALENCE)))

# create a list with one goal (start)
first_goal = mix_goal.MixGoal(0.0, 0.0, 0.0, 0.0, 1)
goals = list([first_goal])
# initialize optimizer
dj = optimizer.Optimizer(song_objects, goals, style.Style_Lib.tempo_based.value)
mix_script = dj.generate_mixtape()
print("***MIX SCRIPT RESULT***")

for mix in mix_script:
    print(mix)

c = composer.composer_parser(mix_script)

c.compose()
Example #4
    def model_fn(features, labels, mode, params):

        train_op_type = kargs.get('train_op_type', 'joint')
        gen_disc_type = kargs.get('gen_disc_type', 'all_disc')
        mask_method = kargs.get('mask_method', 'only_mask')
        use_tpu = 1 if kargs.get('use_tpu', False) else 0
        print(train_op_type, "===train op type===", gen_disc_type,
              "===generator loss type===")
        if mask_method == 'only_mask':
            tf.logging.info(
                "****** generator token generation mask type:%s with only masked token *******",
                mask_method)
        elif mask_method == 'all_mask':
            tf.logging.info(
                "****** generator token generation mask type:%s with all token *******",
                mask_method)
        else:
            mask_method = 'only_mask'
            tf.logging.info(
                "****** generator token generation mask type:%s with only masked token *******",
                mask_method)

        if kargs.get('optimization_type', 'grl') == 'grl':
            if_flip_grad = True
            train_op_type = 'joint'
        elif kargs.get('optimization_type', 'grl') == 'minmax':
            if_flip_grad = False
        else:
            if_flip_grad = True
            train_op_type = 'joint'
        generator_fn = generator(
            model_config_dict['generator'],
            num_labels_dict['generator'],
            init_checkpoint_dict['generator'],
            model_reuse=None,
            load_pretrained=load_pretrained_dict['generator'],
            model_io_config=model_io_config,
            opt_config=opt_config,
            exclude_scope=exclude_scope_dict.get('generator', ""),
            not_storage_params=not_storage_params_dict.get('generator', []),
            target=target_dict['generator'],
            if_flip_grad=if_flip_grad,
            # mask_method='only_mask',
            **kargs)

        tf.logging.info("****** train_op_type:%s *******", train_op_type)
        tf.logging.info("****** optimization_type:%s *******",
                        kargs.get('optimization_type', 'grl'))
        generator_dict = generator_fn(features, labels, mode, params)

        discriminator_fn = discriminator_generator(
            model_config_dict['discriminator'],
            num_labels_dict['discriminator'],
            init_checkpoint_dict['discriminator'],
            model_reuse=None,
            load_pretrained=load_pretrained_dict['discriminator'],
            model_io_config=model_io_config,
            opt_config=opt_config,
            exclude_scope=exclude_scope_dict.get('discriminator', ""),
            not_storage_params=not_storage_params_dict.get(
                'discriminator', []),
            target=target_dict['discriminator'],
            loss='cross_entropy',
            **kargs)

        discriminator_features = {}
        # minmax_mode in ['masked', 'corrupted']
        minmax_mode = kargs.get('minmax_mode', 'corrupted')
        tf.logging.info("****** minmax mode for discriminator: %s *******",
                        minmax_mode)
        if minmax_mode == 'corrupted':
            tf.logging.info("****** gumbel 3-D sampled_ids *******")
        elif minmax_mode == 'masked':
            discriminator_features['ori_sampled_ids'] = generator_dict[
                'output_ids']
            discriminator_features['sampled_binary_mask'] = generator_dict[
                'sampled_binary_mask']
            tf.logging.info("****** conditional sampled_ids *******")
        discriminator_features['input_ids'] = generator_dict['sampled_ids']
        discriminator_features['input_mask'] = generator_dict[
            'sampled_input_mask']
        discriminator_features['segment_ids'] = generator_dict[
            'sampled_segment_ids']
        discriminator_features['input_ori_ids'] = generator_dict[
            'sampled_input_ids']
        discriminator_features['next_sentence_labels'] = features[
            'next_sentence_labels']
        discriminator_features['ori_input_ids'] = generator_dict['sampled_ids']

        discriminator_dict = discriminator_fn(discriminator_features, labels,
                                              mode, params)

        [disc_loss, disc_logits, disc_per_example_loss
         ] = optimal_discriminator(model_config_dict['discriminator'],
                                   generator_dict,
                                   features,
                                   discriminator_dict,
                                   discriminator_features,
                                   use_tpu=use_tpu)

        [
            equal_per_example_loss, equal_loss_all, equal_loss_self,
            not_equal_per_example_loss, not_equal_loss_all, not_equal_loss_self
        ] = modified_loss(disc_per_example_loss,
                          disc_logits,
                          discriminator_features['input_ori_ids'],
                          discriminator_features['ori_input_ids'],
                          discriminator_features['input_mask'],
                          sampled_binary_mask=discriminator_features.get(
                              'sampled_binary_mask', None),
                          **kargs)
        output_dict = {}
        output_dict['logits'] = disc_logits
        output_dict['per_example_loss'] = disc_per_example_loss
        output_dict['loss'] = disc_loss + 0.0 * discriminator_dict["loss"]
        output_dict["equal_per_example_loss"] = equal_per_example_loss,
        output_dict["equal_loss_all"] = equal_loss_all,
        output_dict["equal_loss_self"] = equal_loss_self,
        output_dict["not_equal_per_example_loss"] = not_equal_per_example_loss,
        output_dict["not_equal_loss_all"] = not_equal_loss_all,
        output_dict["not_equal_loss_self"] = not_equal_loss_self
        output_dict['tvars'] = discriminator_dict['tvars']

        model_io_fn = model_io.ModelIO(model_io_config)

        tvars = []

        loss = kargs.get('dis_loss', 1.0) * output_dict['loss']

        tvars.extend(discriminator_dict['tvars'])

        if kargs.get('joint_train', '1') == '1':
            tf.logging.info(
                "****** joint generator and discriminator training *******")
            tvars.extend(generator_dict['tvars'])
            loss += generator_dict['loss']
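        # Deduplicate variables the generator and discriminator may share
        # (e.g. tied embeddings when sharing_mode is enabled).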
        tvars = list(set(tvars))

        var_checkpoint_dict_list = []
        for key in init_checkpoint_dict:
            if load_pretrained_dict[key] == "yes":
                if key == 'generator':
                    tmp = {
                        "tvars":
                        generator_dict['tvars'],
                        "init_checkpoint":
                        init_checkpoint_dict['generator'],
                        "exclude_scope":
                        exclude_scope_dict[key],
                        "restore_var_name":
                        model_config_dict['generator'].get(
                            'restore_var_name', [])
                    }
                    if kargs.get("sharing_mode", "none") != "none":
                        tmp['exclude_scope'] = ''
                    var_checkpoint_dict_list.append(tmp)
                elif key == 'discriminator':
                    tmp = {
                        "tvars":
                        discriminator_dict['tvars'],
                        "init_checkpoint":
                        init_checkpoint_dict['discriminator'],
                        "exclude_scope":
                        exclude_scope_dict[key],
                        "restore_var_name":
                        model_config_dict['discriminator'].get(
                            'restore_var_name', [])
                    }
                    var_checkpoint_dict_list.append(tmp)

        use_tpu = 1 if kargs.get('use_tpu', False) else 0

        if len(var_checkpoint_dict_list) >= 1:
            scaffold_fn = model_io_fn.load_multi_pretrained(
                var_checkpoint_dict_list, use_tpu=use_tpu)
        else:
            scaffold_fn = None

        if mode == tf.estimator.ModeKeys.TRAIN:

            if not kargs.get('use_tpu', False):
                metric_dict = discriminator_metric_train(
                    output_dict['per_example_loss'], output_dict['logits'],
                    generator_dict['sampled_input_ids'],
                    generator_dict['sampled_ids'],
                    generator_dict['sampled_input_mask'])

                for key in metric_dict:
                    tf.summary.scalar(key, metric_dict[key])
                tf.summary.scalar("generator_loss", generator_dict['loss'])
                tf.summary.scalar("discriminator_loss",
                                  discriminator_dict['loss'])

            if kargs.get('use_tpu', False):
                optimizer_fn = optimizer.Optimizer(opt_config)
                use_tpu = 1
            else:
                optimizer_fn = distributed_optimizer.Optimizer(opt_config)
                use_tpu = 0

            model_io_fn.print_params(tvars, string=", trainable params")

            train_op = get_train_op(generator_dict,
                                    output_dict,
                                    optimizer_fn,
                                    opt_config,
                                    model_config_dict['generator'],
                                    model_config_dict['discriminator'],
                                    use_tpu=use_tpu,
                                    train_op_type=train_op_type,
                                    gen_disc_type=gen_disc_type)

            # update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            # with tf.control_dependencies(update_ops):
            # 	train_op = optimizer_fn.get_train_op(loss, list(set(tvars)),
            # 					opt_config.init_lr,
            # 					opt_config.num_train_steps,
            # 					use_tpu=use_tpu)

            if kargs.get('use_tpu', False):
                estimator_spec = tf.contrib.tpu.TPUEstimatorSpec(
                    mode=mode,
                    loss=loss,
                    train_op=train_op,
                    scaffold_fn=scaffold_fn
                    # training_hooks=[logging_hook]
                )
            else:
                estimator_spec = tf.estimator.EstimatorSpec(mode=mode,
                                                            loss=loss,
                                                            train_op=train_op)

            return estimator_spec

        elif mode == tf.estimator.ModeKeys.EVAL:

            if kargs.get('joint_train', '0') == '1':

                def joint_metric(masked_lm_example_loss, masked_lm_log_probs,
                                 masked_lm_ids, masked_lm_weights,
                                 next_sentence_example_loss,
                                 next_sentence_log_probs, next_sentence_labels,
                                 per_example_loss, logits, input_ori_ids,
                                 input_ids, input_mask):
                    generator_metric = generator_metric_fn_eval(
                        masked_lm_example_loss, masked_lm_log_probs,
                        masked_lm_ids, masked_lm_weights,
                        next_sentence_example_loss, next_sentence_log_probs,
                        next_sentence_labels)
                    discriminator_metric = discriminator_metric_eval(
                        per_example_loss, logits, input_ori_ids, input_ids,
                        input_mask)
                    generator_metric.update(discriminator_metric)
                    return generator_metric

                tpu_eval_metrics = (joint_metric, [
                    generator_dict['masked_lm_example_loss'],
                    generator_dict['masked_lm_log_probs'],
                    generator_dict['masked_lm_ids'],
                    generator_dict['masked_lm_weights'],
                    generator_dict.get('next_sentence_example_loss', None),
                    generator_dict.get('next_sentence_log_probs', None),
                    generator_dict.get('next_sentence_labels', None),
                    discriminator_dict['per_example_loss'],
                    discriminator_dict['logits'],
                    generator_dict['sampled_input_ids'],
                    generator_dict['sampled_ids'],
                    generator_dict['sampled_input_mask']
                ])
                gpu_eval_metrics = joint_metric(
                    generator_dict['masked_lm_example_loss'],
                    generator_dict['masked_lm_log_probs'],
                    generator_dict['masked_lm_ids'],
                    generator_dict['masked_lm_weights'],
                    generator_dict.get('next_sentence_example_loss', None),
                    generator_dict.get('next_sentence_log_probs', None),
                    generator_dict.get('next_sentence_labels', None),
                    discriminator_dict['per_example_loss'],
                    discriminator_dict['logits'],
                    generator_dict['sampled_input_ids'],
                    generator_dict['sampled_ids'],
                    generator_dict['sampled_input_mask'])
            else:
                gpu_eval_metrics = discriminator_metric_eval(
                    discriminator_dict['per_example_loss'],
                    discriminator_dict['logits'],
                    generator_dict['sampled_input_ids'],
                    generator_dict['sampled_ids'],
                    generator_dict['sampled_input_mask'])
                tpu_eval_metrics = (discriminator_metric_eval, [
                    discriminator_dict['per_example_loss'],
                    discriminator_dict['logits'],
                    generator_dict['sampled_input_ids'],
                    generator_dict['sampled_ids'],
                    generator_dict['sampled_input_mask']
                ])

            if kargs.get('use_tpu', False):
                estimator_spec = tf.contrib.tpu.TPUEstimatorSpec(
                    mode=mode,
                    loss=loss,
                    eval_metrics=tpu_eval_metrics,
                    scaffold_fn=scaffold_fn)
            else:
                estimator_spec = tf.estimator.EstimatorSpec(
                    mode=mode, loss=loss, eval_metric_ops=gpu_eval_metrics)

            return estimator_spec
        else:
            raise NotImplementedError()
Example #5
    def model_fn(features, labels, mode, params):

        model_api = model_zoo(model_config)

        seq_features = {}
        for key in features:
            seq_features[key] = features[key]
        seq_features['input_ids'] = features["input_ori_ids"]

        model = model_api(model_config,
                          seq_features,
                          labels,
                          mode,
                          target,
                          reuse=tf.AUTO_REUSE,
                          **kargs)

        if mode == tf.estimator.ModeKeys.TRAIN:
            dropout_prob = model_config.dropout_prob
        else:
            dropout_prob = 0.0

        if model_io_config.fix_lm:
            scope = model_config.scope + "_finetuning"
        else:
            scope = model_config.scope

        sequence_mask = tf.to_float(
            tf.not_equal(features['input_ori_ids'][:, 1:],
                         kargs.get('[PAD]', 0)))

        # batch x seq_length
        print(model.get_sequence_output_logits().get_shape(),
              "===logits shape===")
        seq_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=features['input_ori_ids'][:, 1:],
            logits=model.get_sequence_output_logits()[:, :-1])

        per_example_loss = tf.reduce_sum(seq_loss * sequence_mask, axis=-1) / (
            tf.reduce_sum(sequence_mask, axis=-1) + 1e-10)
        loss = tf.reduce_mean(per_example_loss)

        model_io_fn = model_io.ModelIO(model_io_config)

        pretrained_tvars = model_io_fn.get_params(
            model_config.scope, not_storage_params=not_storage_params)

        lm_pretrain_tvars = model_io_fn.get_params(
            "cls/predictions", not_storage_params=not_storage_params)

        pretrained_tvars.extend(lm_pretrain_tvars)

        use_tpu = 1 if kargs.get('use_tpu', False) else 0

        if load_pretrained == "yes":
            scaffold_fn = model_io_fn.load_pretrained(
                pretrained_tvars,
                init_checkpoint,
                exclude_scope=exclude_scope,
                use_tpu=use_tpu)
            tf.logging.info("***** loading pretrained model *****")
        else:
            scaffold_fn = None
            tf.logging.info("***** not loading pretrained model *****")

        if mode == tf.estimator.ModeKeys.TRAIN:

            if kargs.get('use_tpu', False):
                optimizer_fn = optimizer.Optimizer(opt_config)
                use_tpu = 1
                tf.logging.info(
                    "***** using tpu with tpu-captiable optimizer *****")
            else:
                optimizer_fn = distributed_optimizer.Optimizer(opt_config)
                use_tpu = 0
                tf.logging.info(
                    "***** using gpu with gpu-captiable optimizer *****")

            tvars = pretrained_tvars
            model_io_fn.print_params(tvars, string=", trainable params")

            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            with tf.control_dependencies(update_ops):
                train_op = optimizer_fn.get_train_op(
                    loss,
                    tvars,
                    opt_config.init_lr,
                    opt_config.num_train_steps,
                    use_tpu=use_tpu)

                train_metric_dict = train_metric(
                    features['input_ori_ids'],
                    model.get_sequence_output_logits(), **kargs)

                if not kargs.get('use_tpu', False):
                    for key in train_metric_dict:
                        tf.summary.scalar(key, train_metric_dict[key])
                    tf.summary.scalar('learning_rate',
                                      optimizer_fn.learning_rate)
                    tf.logging.info("***** logging metric *****")
                    tf.summary.scalar("causal_attenion_mask_length",
                                      tf.reduce_sum(model.attention_mask))
                    tf.summary.scalar("bi_attenion_mask_length",
                                      tf.reduce_sum(model.bi_attention_mask))

                if kargs.get('use_tpu', False):
                    estimator_spec = tf.contrib.tpu.TPUEstimatorSpec(
                        mode=mode,
                        loss=loss,
                        train_op=train_op,
                        scaffold_fn=scaffold_fn)
                else:
                    estimator_spec = tf.estimator.EstimatorSpec(
                        mode=mode, loss=loss, train_op=train_op)

                return estimator_spec

        elif mode == tf.estimator.ModeKeys.EVAL:

            gpu_eval_metrics = eval_metric(features['input_ori_ids'],
                                           model.get_sequence_output_logits())
            tpu_eval_metrics = (eval_metric, [
                features['input_ori_ids'],
                model.get_sequence_output_logits()
            ])

            if kargs.get('use_tpu', False):
                estimator_spec = tf.contrib.tpu.TPUEstimatorSpec(
                    mode=mode,
                    loss=loss,
                    eval_metrics=tpu_eval_metrics,
                    scaffold_fn=scaffold_fn)
            else:
                estimator_spec = tf.estimator.EstimatorSpec(
                    mode=mode, loss=loss, eval_metric_ops=gpu_eval_metrics)

            return estimator_spec
        else:
            raise NotImplementedError()
Example #6
    def _build_model(self):
        """Build model (initially based on CPLEX 12.8.1)"""
        self._diet = optimizer.Optimizer()
        self._var_names_x = ["x" + str(f_id) for f_id in self.ingredient_ids]

        diet = self._diet
        diet.set_sense(sense="max")

        self._remove_inf(self.cost_obj_vector)

        x_vars = list(
            diet.add_variables(
                obj=self.cost_obj_vector,
                lb=self.ds.sorted_column(self.data_feed_scenario,
                                         self.headers_feed_scenario.s_min,
                                         self.ingredient_ids,
                                         self.headers_feed_scenario.s_ID),
                ub=self.ds.sorted_column(self.data_feed_scenario,
                                         self.headers_feed_scenario.s_max,
                                         self.ingredient_ids,
                                         self.headers_feed_scenario.s_ID),
                names=self._var_names_x))
        diet.set_obj_offset(self.cst_obj)

        "Constraint: sum(x a) == CNEm"
        diet.add_constraint(names=["CNEm GE"],
                            lin_expr=[[
                                x_vars,
                                self.ds.sorted_column(
                                    self.data_feed_lib,
                                    self.headers_feed_lib.s_NEma,
                                    self.ingredient_ids,
                                    self.headers_feed_lib.s_ID)
                            ]],
                            rhs=[self._p_cnem * 0.999],
                            senses=["G"])
        diet.add_constraint(names=["CNEm LE"],
                            lin_expr=[[
                                x_vars,
                                self.ds.sorted_column(
                                    self.data_feed_lib,
                                    self.headers_feed_lib.s_NEma,
                                    self.ingredient_ids,
                                    self.headers_feed_lib.s_ID)
                            ]],
                            rhs=[self._p_cnem * 1.001],
                            senses=["L"])
        "Constraint: sum(x) == 1"
        diet.add_constraint(names=["SUM 1"],
                            lin_expr=[[x_vars, [1] * len(x_vars)]],
                            rhs=[1],
                            senses=["E"])
        "Constraint: sum(x a)>= MPm"
        mp_properties = self.ds.sorted_column(self.data_feed_lib, [
            self.headers_feed_lib.s_DM, self.headers_feed_lib.s_TDN,
            self.headers_feed_lib.s_CP, self.headers_feed_lib.s_RUP,
            self.headers_feed_lib.s_Forage, self.headers_feed_lib.s_Fat
        ], self.ingredient_ids, self.headers_feed_lib.s_ID)
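        # nrc.mp presumably maps each feed's composition row (DM, TDN, CP,
        # RUP, forage, fat) to its metabolizable-protein contribution.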
        mpm_list = [nrc.mp(*row) for row in mp_properties]

        # for i, v in enumerate(mpm_list):
        #     mpm_list[i] = v - (self._p_swg * 268 - self._p_neg * 29.4) * 0.001 / self._p_dmi

        diet.add_constraint(
            names=["MPm"],
            lin_expr=[[x_vars, mpm_list]],
            rhs=[(self._p_mpm + 268 * self._p_swg - 29.4 * self._p_neg) *
                 0.001 / self._p_dmi],
            senses=["G"])

        rdp_data = [
            (1 - self.ds.sorted_column(
                self.data_feed_lib, self.headers_feed_lib.s_RUP,
                self.ingredient_ids, self.headers_feed_lib.s_ID)[x_index]) *
            self.ds.sorted_column(
                self.data_feed_lib, self.headers_feed_lib.s_CP,
                self.ingredient_ids, self.headers_feed_lib.s_ID)[x_index]
            for x_index in range(len(x_vars))
        ]

        "Constraint: RUP: sum(x a) >= 0.125 CNEm"
        diet.add_constraint(names=["RDP"],
                            lin_expr=[[x_vars, rdp_data]],
                            rhs=[0.125 * self._p_cnem],
                            senses=["G"])

        "Constraint: Fat: sum(x a) <= 0.06 DMI"
        diet.add_constraint(names=["Fat"],
                            lin_expr=[[
                                x_vars,
                                self.ds.sorted_column(
                                    self.data_feed_lib,
                                    self.headers_feed_lib.s_Fat,
                                    self.ingredient_ids,
                                    self.headers_feed_lib.s_ID)
                            ]],
                            rhs=[0.06],
                            senses=["L"])

        "Constraint: peNDF: sum(x a) <= peNDF DMI"
        pendf_data = [
            self.ds.sorted_column(
                self.data_feed_lib, self.headers_feed_lib.s_NDF,
                self.ingredient_ids, self.headers_feed_lib.s_ID)[x_index] *
            self.ds.sorted_column(
                self.data_feed_lib, self.headers_feed_lib.s_pef,
                self.ingredient_ids, self.headers_feed_lib.s_ID)[x_index]
            for x_index in range(len(x_vars))
        ]
        diet.add_constraint(names=["peNDF"],
                            lin_expr=[[x_vars, pendf_data]],
                            rhs=[self._p_pe_ndf],
                            senses=["G"])

        self.constraints_names = diet.get_constraints_names()
        # diet.write_lp(name="file.lp")
Example #7
    def model_fn(features, labels, mode, params):

        train_op_type = kargs.get('train_op_type', 'joint')
        gen_disc_type = kargs.get('gen_disc_type', 'all_disc')
        print(train_op_type, "===train op type===", gen_disc_type,
              "===generator loss type===")
        if kargs.get('optimization_type', 'grl') == 'grl':
            if_flip_grad = True
            train_op_type = 'joint'
        elif kargs.get('optimization_type', 'grl') == 'minmax':
            if_flip_grad = False
        else:
            if_flip_grad = True
            train_op_type = 'joint'
        generator_fn = generator(
            model_config_dict['generator'],
            num_labels_dict['generator'],
            init_checkpoint_dict['generator'],
            model_reuse=None,
            load_pretrained=load_pretrained_dict['generator'],
            model_io_config=model_io_config,
            opt_config=opt_config,
            exclude_scope=exclude_scope_dict.get('generator', ""),
            not_storage_params=not_storage_params_dict.get('generator', []),
            target=target_dict['generator'],
            if_flip_grad=if_flip_grad,
            # mask_method="all_mask",
            **kargs)

        tf.logging.info("****** train_op_type:%s *******", train_op_type)
        tf.logging.info("****** optimization_type:%s *******",
                        kargs.get('optimization_type', 'grl'))
        generator_dict = generator_fn(features, labels, mode, params)

        discriminator_fn = discriminator(
            model_config_dict['discriminator'],
            num_labels_dict['discriminator'],
            init_checkpoint_dict['discriminator'],
            model_reuse=None,
            load_pretrained=load_pretrained_dict['discriminator'],
            model_io_config=model_io_config,
            opt_config=opt_config,
            exclude_scope=exclude_scope_dict.get('discriminator', ""),
            not_storage_params=not_storage_params_dict.get(
                'discriminator', []),
            target=target_dict['discriminator'],
            **kargs)

        tf.logging.info("****** true sampled_ids of discriminator *******")
        true_discriminator_features = {}
        true_discriminator_features['input_ids'] = generator_dict[
            'sampled_input_ids']
        true_discriminator_features['input_mask'] = generator_dict[
            'sampled_input_mask']
        true_discriminator_features['segment_ids'] = generator_dict[
            'sampled_segment_ids']
        true_discriminator_features['input_ori_ids'] = generator_dict[
            'sampled_input_ids']
        true_discriminator_features['next_sentence_labels'] = features[
            'next_sentence_labels']
        true_discriminator_features['ori_input_ids'] = generator_dict[
            'sampled_input_ids']

        true_discriminator_dict = discriminator_fn(
            true_discriminator_features, labels, mode, params)

        fake_discriminator_features = {}
        if kargs.get('minmax_mode', 'corrupted') == 'corrupted':
            tf.logging.info("****** gumbel 3-D sampled_ids *******")
        elif kargs.get('minmax_mode', 'corrupted') == 'masked':
            fake_discriminator_features['ori_sampled_ids'] = generator_dict[
                'output_ids']
            fake_discriminator_features['sampled_binary_mask'] = generator_dict[
                'sampled_binary_mask']
            tf.logging.info("****** conditional sampled_ids *******")
        fake_discriminator_features['input_ids'] = generator_dict[
            'sampled_ids']
        fake_discriminator_features['input_mask'] = generator_dict[
            'sampled_input_mask']
        fake_discriminator_features['segment_ids'] = generator_dict[
            'sampled_segment_ids']
        fake_discriminator_features['input_ori_ids'] = generator_dict[
            'sampled_input_ids']
        fake_discriminator_features['next_sentence_labels'] = features[
            'next_sentence_labels']
        fake_discriminator_features['ori_input_ids'] = generator_dict[
            'sampled_ids']

        fake_discriminator_dict = discriminator_fn(fake_discriminator_features,
                                                   labels, mode, params)

        use_tpu = 1 if kargs.get('use_tpu', False) else 0

        output_dict = get_losses(true_discriminator_dict["logits"],
                                 fake_discriminator_dict["logits"],
                                 use_tpu=use_tpu,
                                 gan_type=kargs.get('gan_type', "JS"))

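        # disc_loss trains the discriminator to separate true from fake
        # logits; gen_loss is the generator's adversarial objective.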
        discriminator_dict = {}
        discriminator_dict['gen_loss'] = output_dict['gen_loss']
        discriminator_dict['disc_loss'] = output_dict['disc_loss']
        discriminator_dict['tvars'] = fake_discriminator_dict['tvars']
        discriminator_dict['fake_logits'] = fake_discriminator_dict['logits']
        discriminator_dict['true_logits'] = true_discriminator_dict['logits']

        model_io_fn = model_io.ModelIO(model_io_config)

        loss = discriminator_dict['disc_loss']
        tvars = []
        tvars.extend(discriminator_dict['tvars'])

        if kargs.get('joint_train', '1') == '1':
            tf.logging.info(
                "****** joint generator and discriminator training *******")
            tvars.extend(generator_dict['tvars'])
            loss += generator_dict['loss']
        tvars = list(set(tvars))

        var_checkpoint_dict_list = []
        for key in init_checkpoint_dict:
            if load_pretrained_dict[key] == "yes":
                if key == 'generator':
                    tmp = {
                        "tvars":
                        generator_dict['tvars'],
                        "init_checkpoint":
                        init_checkpoint_dict['generator'],
                        "exclude_scope":
                        exclude_scope_dict[key],
                        "restore_var_name":
                        model_config_dict['generator'].get(
                            'restore_var_name', [])
                    }
                    if kargs.get("sharing_mode", "none") != "none":
                        tmp['exclude_scope'] = ''
                    var_checkpoint_dict_list.append(tmp)
                elif key == 'discriminator':
                    tmp = {
                        "tvars":
                        discriminator_dict['tvars'],
                        "init_checkpoint":
                        init_checkpoint_dict['discriminator'],
                        "exclude_scope":
                        exclude_scope_dict[key],
                        "restore_var_name":
                        model_config_dict['discriminator'].get(
                            'restore_var_name', [])
                    }
                    var_checkpoint_dict_list.append(tmp)

        use_tpu = 1 if kargs.get('use_tpu', False) else 0

        if len(var_checkpoint_dict_list) >= 1:
            scaffold_fn = model_io_fn.load_multi_pretrained(
                var_checkpoint_dict_list, use_tpu=use_tpu)
        else:
            scaffold_fn = None

        if mode == tf.estimator.ModeKeys.TRAIN:

            if not kargs.get('use_tpu', False):
                metric_dict = discriminator_metric_train(discriminator_dict)

                for key in metric_dict:
                    tf.summary.scalar(key, metric_dict[key])
                tf.summary.scalar("generator_loss", generator_dict['loss'])
                tf.summary.scalar("discriminator_true_loss",
                                  discriminator_dict['disc_loss'])
                tf.summary.scalar("discriminator_fake_loss",
                                  discriminator_dict['gen_loss'])

            if kargs.get('use_tpu', False):
                optimizer_fn = optimizer.Optimizer(opt_config)
                use_tpu = 1
            else:
                optimizer_fn = distributed_optimizer.Optimizer(opt_config)
                use_tpu = 0

            model_io_fn.print_params(tvars, string=", trainable params")

            train_op = get_train_op(generator_dict,
                                    discriminator_dict,
                                    optimizer_fn,
                                    opt_config,
                                    model_config_dict['generator'],
                                    model_config_dict['discriminator'],
                                    use_tpu=use_tpu,
                                    train_op_type=train_op_type,
                                    gen_disc_type=gen_disc_type)

            if kargs.get('use_tpu', False):
                estimator_spec = tf.contrib.tpu.TPUEstimatorSpec(
                    mode=mode,
                    loss=loss,
                    train_op=train_op,
                    scaffold_fn=scaffold_fn
                    # training_hooks=[logging_hook]
                )
            else:
                estimator_spec = tf.estimator.EstimatorSpec(mode=mode,
                                                            loss=loss,
                                                            train_op=train_op)

            return estimator_spec

        elif mode == tf.estimator.ModeKeys.EVAL:

            if kargs.get('joint_train', '0') == '1':

                def joint_metric(masked_lm_example_loss, masked_lm_log_probs,
                                 masked_lm_ids, masked_lm_weights,
                                 next_sentence_example_loss,
                                 next_sentence_log_probs, next_sentence_labels,
                                 discriminator_dict):
                    generator_metric = generator_metric_fn_eval(
                        masked_lm_example_loss, masked_lm_log_probs,
                        masked_lm_ids, masked_lm_weights,
                        next_sentence_example_loss, next_sentence_log_probs,
                        next_sentence_labels)
                    discriminator_metric = discriminator_metric_eval(
                        discriminator_dict)
                    generator_metric.update(discriminator_metric)
                    return generator_metric

                tpu_eval_metrics = (joint_metric, [
                    generator_dict['masked_lm_example_loss'],
                    generator_dict['masked_lm_log_probs'],
                    generator_dict['masked_lm_ids'],
                    generator_dict['masked_lm_weights'],
                    generator_dict.get('next_sentence_example_loss', None),
                    generator_dict.get('next_sentence_log_probs', None),
                    generator_dict.get('next_sentence_labels',
                                       None), discriminator_dict
                ])
                gpu_eval_metrics = joint_metric(
                    generator_dict['masked_lm_example_loss'],
                    generator_dict['masked_lm_log_probs'],
                    generator_dict['masked_lm_ids'],
                    generator_dict['masked_lm_weights'],
                    generator_dict.get('next_sentence_example_loss', None),
                    generator_dict.get('next_sentence_log_probs', None),
                    generator_dict.get('next_sentence_labels',
                                       None), discriminator_dict)
            else:
                gpu_eval_metrics = discriminator_metric_eval(
                    discriminator_dict)
                tpu_eval_metrics = (discriminator_metric_eval,
                                    [discriminator_dict])

            if kargs.get('use_tpu', False):
                estimator_spec = tf.contrib.tpu.TPUEstimatorSpec(
                    mode=mode,
                    loss=loss,
                    eval_metrics=tpu_eval_metrics,
                    scaffold_fn=scaffold_fn)
            else:
                estimator_spec = tf.estimator.EstimatorSpec(
                    mode=mode, loss=loss, eval_metric_ops=gpu_eval_metrics)

            return estimator_spec
        else:
            raise NotImplementedError()
Example #8
	def model_fn(features, labels, mode, params):

		ebm_noise_fce = EBM_NOISE_FCE(model_config_dict,
									num_labels_dict,
									init_checkpoint_dict,
									load_pretrained_dict,
									model_io_config=model_io_config,
									opt_config=opt_config,
									exclude_scope_dict=exclude_scope_dict,
									not_storage_params_dict=not_storage_params_dict,
									target_dict=target_dict,
									**kargs)

		model_io_fn = model_io.ModelIO(model_io_config)

		if mode == tf.estimator.ModeKeys.TRAIN:

			if kargs.get('use_tpu', False):
				optimizer_fn = optimizer.Optimizer(opt_config)
				use_tpu = 1
			else:
				optimizer_fn = distributed_optimizer.Optimizer(opt_config)
				use_tpu = 0

			train_op, loss, var_checkpoint_dict_list = get_train_op(
								optimizer_fn, opt_config,
								model_config_dict['ebm_dist'], 
								model_config_dict['noise_dist'],
								features, labels, mode, params,
								ebm_noise_fce,
								use_tpu=use_tpu)

			use_tpu = 1 if kargs.get('use_tpu', False) else 0
			
			if len(var_checkpoint_dict_list) >= 1:
				scaffold_fn = model_io_fn.load_multi_pretrained(
												var_checkpoint_dict_list,
												use_tpu=use_tpu)
			else:
				scaffold_fn = None

			if kargs.get('use_tpu', False):
				estimator_spec = tf.contrib.tpu.TPUEstimatorSpec(
								mode=mode,
								loss=loss,
								train_op=train_op,
								scaffold_fn=scaffold_fn)
			else:
				estimator_spec = tf.estimator.EstimatorSpec(
								mode=mode, 
								loss=loss, 
								train_op=train_op)

			return estimator_spec

		elif mode == tf.estimator.ModeKeys.EVAL:

			ebm_noise_fce.get_loss(features, labels, mode, params, **kargs)

			tpu_eval_metrics = (ebm_noise_eval_metric, 
								[
								ebm_noise_fce.true_ebm_dist_dict['logits'], 
								ebm_noise_fce.noise_dist_dict['true_logits'], 
								ebm_noise_fce.fake_ebm_dist_dict['logits'], 
								ebm_noise_fce.noise_dist_dict['fake_logits'],
								features['input_ori_ids'],
								tf.cast(features['input_mask'], tf.float32),
								ebm_noise_fce.noise_dist_dict["true_seq_logits"]
								])
			gpu_eval_metrics = ebm_noise_eval_metric(
								ebm_noise_fce.true_ebm_dist_dict['logits'], 
								ebm_noise_fce.noise_dist_dict['true_logits'], 
								ebm_noise_fce.fake_ebm_dist_dict['logits'], 
								ebm_noise_fce.noise_dist_dict['fake_logits'],
								features['input_ori_ids'],
								tf.cast(features['input_mask'], tf.float32),
								ebm_noise_fce.noise_dist_dict["true_seq_logits"]
								)

			loss = ebm_noise_fce.ebm_loss + ebm_noise_fce.noise_loss
			var_checkpoint_dict_list = ebm_noise_fce.var_checkpoint_dict_list
			use_tpu = 1 if kargs.get('use_tpu', False) else 0

			if len(var_checkpoint_dict_list) >= 1:
				scaffold_fn = model_io_fn.load_multi_pretrained(
												var_checkpoint_dict_list,
												use_tpu=use_tpu)
			else:
				scaffold_fn = None

			if kargs.get('use_tpu', False):
				estimator_spec = tf.contrib.tpu.TPUEstimatorSpec(
							  mode=mode,
							  loss=loss,
							  eval_metrics=tpu_eval_metrics,
							  scaffold_fn=scaffold_fn)
			else:
				estimator_spec = tf.estimator.EstimatorSpec(mode=mode, 
								loss=loss,
								eval_metric_ops=gpu_eval_metrics)

			return estimator_spec
		else:
			raise NotImplementedError()
Example #9
    def model_fn(features, labels, mode, params):

        generator_fn = generator(
            model_config_dict['generator'],
            num_labels_dict['generator'],
            init_checkpoint_dict['generator'],
            model_reuse=None,
            load_pretrained=load_pretrained_dict['generator'],
            model_io_config=model_io_config,
            opt_config=opt_config,
            exclude_scope=exclude_scope_dict.get('generator', ""),
            not_storage_params=not_storage_params_dict.get('generator', []),
            target=target_dict['generator'],
            **kargs)
        generator_dict = generator_fn(features, labels, mode, params)

        discriminator_fn = discriminator(
            model_config_dict['discriminator'],
            num_labels_dict['discriminator'],
            init_checkpoint_dict['discriminator'],
            model_reuse=None,
            load_pretrained=load_pretrained_dict['discriminator'],
            model_io_config=model_io_config,
            opt_config=opt_config,
            exclude_scope=exclude_scope_dict.get('discriminator', ""),
            not_storage_params=not_storage_params_dict.get(
                'discriminator', []),
            target=target_dict['discriminator'],
            **kargs)

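        # ELECTRA-style setup: the discriminator is fed the generator's
        # sampled sequence and learns to detect replaced tokens.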
        discriminator_features = {}
        discriminator_features['input_ids'] = generator_dict['sampled_ids']
        discriminator_features['input_mask'] = generator_dict[
            'sampled_input_mask']
        discriminator_features['segment_ids'] = generator_dict[
            'sampled_segment_ids']
        discriminator_features['input_ori_ids'] = generator_dict[
            'sampled_input_ids']
        discriminator_features['next_sentence_labels'] = features[
            'next_sentence_labels']
        discriminator_dict = discriminator_fn(discriminator_features, labels,
                                              mode, params)

        model_io_fn = model_io.ModelIO(model_io_config)

        tvars = []
        loss = discriminator_dict['loss']
        tvars.extend(discriminator_dict['tvars'])
        if kargs.get('joint_train', '0') == '1':
            tvars.extend(generator_dict['tvars'])
            loss += generator_dict['loss']

        var_checkpoint_dict_list = []
        for key in init_checkpoint_dict:
            if load_pretrained_dict[key] == "yes":
                if key == 'generator':
                    tmp = {
                        "tvars": generator_dict['tvars'],
                        "init_checkpoint": init_checkpoint_dict['generator'],
                        "exclude_scope": exclude_scope_dict[key]
                    }
                    var_checkpoint_dict_list.append(tmp)
                elif key == 'discriminator':
                    tmp = {
                        "tvars": discriminator_dict['tvars'],
                        "init_checkpoint":
                        init_checkpoint_dict['discriminator'],
                        "exclude_scope": exclude_scope_dict[key]
                    }
                    var_checkpoint_dict_list.append(tmp)

        use_tpu = 1 if kargs.get('use_tpu', False) else 0

        if len(var_checkpoint_dict_list) >= 1:
            scaffold_fn = model_io_fn.load_multi_pretrained(
                var_checkpoint_dict_list, use_tpu=use_tpu)
        else:
            scaffold_fn = None

        if mode == tf.estimator.ModeKeys.TRAIN:

            metric_dict = discriminator_metric_train(
                discriminator_dict['per_example_loss'],
                discriminator_dict['logits'],
                generator_dict['sampled_input_ids'],
                generator_dict['sampled_ids'],
                generator_dict['sampled_input_mask'])

            for key in metric_dict:
                tf.summary.scalar(key, metric_dict[key])

            if kargs.get('use_tpu', False):
                optimizer_fn = optimizer.Optimizer(opt_config)
                use_tpu = 1
            else:
                optimizer_fn = distributed_optimizer.Optimizer(opt_config)
                use_tpu = 0

            model_io_fn.print_params(tvars, string=", trainable params")

            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            with tf.control_dependencies(update_ops):
                train_op = optimizer_fn.get_train_op(
                    loss,
                    list(set(tvars)),
                    opt_config.init_lr,
                    opt_config.num_train_steps,
                    use_tpu=use_tpu)

            if kargs.get('use_tpu', False):
                estimator_spec = tf.contrib.tpu.TPUEstimatorSpec(
                    mode=mode,
                    loss=loss,
                    train_op=train_op,
                    scaffold_fn=scaffold_fn)
            else:
                estimator_spec = tf.estimator.EstimatorSpec(mode=mode,
                                                            loss=loss,
                                                            train_op=train_op)

            return estimator_spec

        elif mode == tf.estimator.ModeKeys.EVAL:

            if kargs.get('joint_train', '0') == '1':
                generator_metric = generator_metric_fn_eval(
                    generator_dict['masked_lm_example_loss'],
                    generator_dict['masked_lm_log_probs'],
                    generator_dict['masked_lm_ids'],
                    generator_dict['masked_lm_weights'],
                    generator_dict.get('next_sentence_example_loss', None),
                    generator_dict.get('next_sentence_log_probs', None),
                    generator_dict.get('next_sentence_labels', None))
            else:
                generator_metric = {}

            discriminator_metric = discriminator_metric_eval(
                discriminator_dict['per_example_loss'],
                discriminator_dict['logits'],
                generator_dict['sampled_input_ids'],
                generator_dict['sampled_ids'],
                generator_dict['sampled_input_mask'])

            metric_dict = discriminator_metric
            if len(generator_metric):
                metric_dict.update(generator_metric)

            if kargs.get('use_tpu', False):
                estimator_spec = tf.contrib.tpu.TPUEstimatorSpec(
                    mode=mode,
                    loss=loss,
                    eval_metrics=metric_dict,
                    scaffold_fn=scaffold_fn)
            else:
                estimator_spec = tf.estimator.EstimatorSpec(
                    mode=mode, loss=loss, eval_metric_ops=metric_dict)

            return estimator_spec
        else:
            raise NotImplementedError()
Example #10
    def model_fn(features, labels, mode, params):

        model_api = model_zoo(model_config)

        if 'input_mask' not in features:
            input_mask = tf.cast(
                tf.not_equal(features['input_ids_{}'.format(target)],
                             kargs.get('[PAD]', 0)), tf.int32)

            if target:
                features['input_mask_{}'.format(target)] = input_mask
            else:
                features['input_mask'] = input_mask
        if 'segment_ids' not in features:
            segment_ids = tf.zeros_like(input_mask)
            if target:
                features['segment_ids_{}'.format(target)] = segment_ids
            else:
                features['segment_ids'] = segment_ids

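        # When a target suffix is given, alias the suffixed tensors to the
        # canonical feature names the model expects downstream.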
        if target:
            features['input_ori_ids'] = features['input_ids_{}'.format(target)]
            features['input_mask'] = features['input_mask_{}'.format(target)]
            features['segment_ids'] = features['segment_ids_{}'.format(target)]
            features['input_ids'] = features['input_ids_{}'.format(target)]

        input_ori_ids = features.get('input_ori_ids', None)
        if mode == tf.estimator.ModeKeys.TRAIN:
            if input_ori_ids is not None:
                # [output_ids,
                # sampled_binary_mask] = random_input_ids_generation(
                # 							model_config,
                # 							input_ori_ids,
                # 							features['input_mask'],
                # 							**kargs)

                [output_ids, sampled_binary_mask] = hmm_input_ids_generation(
                    model_config,
                    features['input_ori_ids'],
                    features['input_mask'], [
                        tf.cast(tf.constant(hmm_tran_prob), tf.float32)
                        for hmm_tran_prob in hmm_tran_prob_list
                    ],
                    mask_probability=0.2,
                    replace_probability=0.1,
                    original_probability=0.1,
                    mask_prior=tf.cast(tf.constant(mask_prior), tf.float32),
                    **kargs)

                features['input_ids'] = output_ids
                tf.logging.info(
                    "***** Running random sample input generation *****")
            else:
                sampled_binary_mask = None
        else:
            sampled_binary_mask = None

        model = model_api(model_config,
                          features,
                          labels,
                          mode,
                          target,
                          reuse=tf.AUTO_REUSE,
                          **kargs)

        if mode == tf.estimator.ModeKeys.TRAIN:
            dropout_prob = model_config.dropout_prob
        else:
            dropout_prob = 0.0

        if model_io_config.fix_lm == True:
            scope = model_config.scope + "_finetuning"
        else:
            scope = model_config.scope

        #(nsp_loss,
        # nsp_per_example_loss,
        # nsp_log_prob) = pretrain.get_next_sentence_output(model_config,
        #								model.get_pooled_output(),
        #								features['next_sentence_labels'],
        #								reuse=tf.AUTO_REUSE)

        # masked_lm_positions = features["masked_lm_positions"]
        # masked_lm_ids = features["masked_lm_ids"]
        # masked_lm_weights = features["masked_lm_weights"]

        if model_config.model_type == 'bert':
            masked_lm_fn = pretrain.get_masked_lm_output
            seq_masked_lm_fn = pretrain.seq_mask_masked_lm_output
            print("==apply bert masked lm==")
        elif model_config.model_type == 'albert':
            masked_lm_fn = pretrain_albert.get_masked_lm_output
            seq_masked_lm_fn = pretrain_albert.seq_mask_masked_lm_output
            print("==apply albert masked lm==")
        else:
            masked_lm_fn = pretrain.get_masked_lm_output
            seq_masked_lm_fn = pretrain.seq_mask_masked_lm_output
            print("==apply bert masked lm==")

        if sampled_binary_mask is not None:
            (masked_lm_loss, masked_lm_example_loss, masked_lm_log_probs,
             masked_lm_mask) = seq_masked_lm_fn(
                 model_config,
                 model.get_sequence_output(),
                 model.get_embedding_table(),
                 features['input_mask'],
                 features['input_ori_ids'],
                 features['input_ids'],
                 sampled_binary_mask,
                 reuse=tf.AUTO_REUSE,
                 embedding_projection=model.get_embedding_projection_table())
            masked_lm_ids = input_ori_ids
        else:

            masked_lm_positions = features["masked_lm_positions"]
            masked_lm_ids = features["masked_lm_ids"]
            masked_lm_weights = features["masked_lm_weights"]

            (masked_lm_loss, masked_lm_example_loss, masked_lm_log_probs,
             masked_lm_mask) = masked_lm_fn(
                 model_config,
                 model.get_sequence_output(),
                 model.get_embedding_table(),
                 masked_lm_positions,
                 masked_lm_ids,
                 masked_lm_weights,
                 reuse=tf.AUTO_REUSE,
                 embedding_projection=model.get_embedding_projection_table())
        print(model_config.lm_ratio, '==mlm lm_ratio==')
        loss = model_config.lm_ratio * masked_lm_loss  #+ 0.0 * nsp_loss

        model_io_fn = model_io.ModelIO(model_io_config)

        pretrained_tvars = model_io_fn.get_params(
            model_config.scope, not_storage_params=not_storage_params)

        lm_pretrain_tvars = model_io_fn.get_params(
            "cls/predictions", not_storage_params=not_storage_params)

        pretrained_tvars.extend(lm_pretrain_tvars)

        if load_pretrained == "yes":
            scaffold_fn = model_io_fn.load_pretrained(
                pretrained_tvars,
                init_checkpoint,
                exclude_scope=exclude_scope,
                use_tpu=1)
        else:
            scaffold_fn = None

        if mode == tf.estimator.ModeKeys.TRAIN:

            optimizer_fn = optimizer.Optimizer(opt_config)

            tvars = pretrained_tvars
            model_io_fn.print_params(tvars, string=", trainable params")

            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            with tf.control_dependencies(update_ops):
                train_op = optimizer_fn.get_train_op(
                    loss,
                    tvars,
                    opt_config.init_lr,
                    opt_config.num_train_steps,
                    use_tpu=opt_config.use_tpu)

                #	train_metric_dict = train_metric_fn(
                #			masked_lm_example_loss, masked_lm_log_probs,
                #			masked_lm_ids,
                #			masked_lm_mask,
                #			nsp_per_example_loss,
                #			nsp_log_prob,
                #			features['next_sentence_labels'],
                #			masked_lm_mask=masked_lm_mask
                #		)

                # for key in train_metric_dict:
                # 	tf.summary.scalar(key, train_metric_dict[key])
                # tf.summary.scalar('learning_rate', optimizer_fn.learning_rate)

                estimator_spec = tf.contrib.tpu.TPUEstimatorSpec(
                    mode=mode,
                    loss=loss,
                    train_op=train_op,
                    scaffold_fn=scaffold_fn)

                return estimator_spec

        elif mode == tf.estimator.ModeKeys.EVAL:

            def metric_fn(masked_lm_example_loss, masked_lm_log_probs,
                          masked_lm_ids, masked_lm_weights,
                          next_sentence_example_loss, next_sentence_log_probs,
                          next_sentence_labels):
                """Computes the loss and accuracy of the model."""
                masked_lm_log_probs = tf.reshape(
                    masked_lm_log_probs, [-1, masked_lm_log_probs.shape[-1]])
                masked_lm_predictions = tf.argmax(masked_lm_log_probs,
                                                  axis=-1,
                                                  output_type=tf.int32)
                masked_lm_example_loss = tf.reshape(masked_lm_example_loss,
                                                    [-1])
                masked_lm_ids = tf.reshape(masked_lm_ids, [-1])
                masked_lm_weights = tf.reshape(masked_lm_weights, [-1])
                masked_lm_accuracy = tf.metrics.accuracy(
                    labels=masked_lm_ids,
                    predictions=masked_lm_predictions,
                    weights=masked_lm_weights)
                masked_lm_mean_loss = tf.metrics.mean(
                    values=masked_lm_example_loss, weights=masked_lm_weights)

                next_sentence_log_probs = tf.reshape(
                    next_sentence_log_probs,
                    [-1, next_sentence_log_probs.shape[-1]])
                next_sentence_predictions = tf.argmax(next_sentence_log_probs,
                                                      axis=-1,
                                                      output_type=tf.int32)
                next_sentence_labels = tf.reshape(next_sentence_labels, [-1])
                next_sentence_accuracy = tf.metrics.accuracy(
                    labels=next_sentence_labels,
                    predictions=next_sentence_predictions)
                next_sentence_mean_loss = tf.metrics.mean(
                    values=next_sentence_example_loss)

                return {
                    "masked_lm_accuracy": masked_lm_accuracy,
                    "masked_lm_loss": masked_lm_mean_loss,
                    "next_sentence_accuracy": next_sentence_accuracy,
                    "next_sentence_loss": next_sentence_mean_loss
                }

            # the NSP head was commented out above, so rebuild it here:
            # nsp_per_example_loss / nsp_log_prob must be defined in EVAL mode
            (nsp_loss, nsp_per_example_loss,
             nsp_log_prob) = pretrain.get_next_sentence_output(
                 model_config,
                 model.get_pooled_output(),
                 features['next_sentence_labels'],
                 reuse=tf.AUTO_REUSE)

            eval_metrics = (metric_fn, [
                masked_lm_example_loss, masked_lm_log_probs, masked_lm_ids,
                masked_lm_mask, nsp_per_example_loss, nsp_log_prob,
                features['next_sentence_labels']
            ])

            estimator_spec = tf.contrib.tpu.TPUEstimatorSpec(
                mode=mode,
                loss=loss,
                eval_metrics=eval_metrics,
                scaffold_fn=scaffold_fn)

            return estimator_spec
        else:
            raise NotImplementedError()
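
For context, a model_fn of this Estimator-style shape is normally handed to a TPUEstimator. A minimal hedged sketch of that wiring under TF 1.x; the model_dir, batch sizes, and train_input_fn are hypothetical:

import tensorflow as tf

run_config = tf.contrib.tpu.RunConfig(
    cluster=None,  # supply a TPUClusterResolver when actually on TPU
    model_dir="/tmp/mlm_pretrain",  # hypothetical output directory
    tpu_config=tf.contrib.tpu.TPUConfig(iterations_per_loop=1000))

estimator = tf.contrib.tpu.TPUEstimator(
    use_tpu=False,  # TPUEstimator also runs TPUEstimatorSpecs on CPU/GPU
    model_fn=model_fn,
    config=run_config,
    train_batch_size=32,
    eval_batch_size=32)

# estimator.train(input_fn=train_input_fn,
#                 max_steps=opt_config.num_train_steps)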
Example #11
    def model_fn(features, labels, mode, params):

        train_op_type = kargs.get('train_op_type', 'joint')
        if kargs.get('optimization_type', 'grl') == 'grl':
            generator_fn = generator(
                model_config_dict['generator'],
                num_labels_dict['generator'],
                init_checkpoint_dict['generator'],
                model_reuse=None,
                load_pretrained=load_pretrained_dict['generator'],
                model_io_config=model_io_config,
                opt_config=opt_config,
                exclude_scope=exclude_scope_dict.get('generator', ""),
                not_storage_params=not_storage_params_dict.get(
                    'generator', []),
                target=target_dict['generator'],
                **kargs)
            train_op_type = 'joint'
        elif kargs.get('optimization_type', 'grl') == 'minmax':
            generator_fn = generator_normal(
                model_config_dict['generator'],
                num_labels_dict['generator'],
                init_checkpoint_dict['generator'],
                model_reuse=None,
                load_pretrained=load_pretrained_dict['generator'],
                model_io_config=model_io_config,
                opt_config=opt_config,
                exclude_scope=exclude_scope_dict.get('generator', ""),
                not_storage_params=not_storage_params_dict.get(
                    'generator', []),
                target=target_dict['generator'],
                **kargs)
        else:
            generator_fn = generator(
                model_config_dict['generator'],
                num_labels_dict['generator'],
                init_checkpoint_dict['generator'],
                model_reuse=None,
                load_pretrained=load_pretrained_dict['generator'],
                model_io_config=model_io_config,
                opt_config=opt_config,
                exclude_scope=exclude_scope_dict.get('generator', ""),
                not_storage_params=not_storage_params_dict.get(
                    'generator', []),
                target=target_dict['generator'],
                **kargs)
        tf.logging.info("****** train_op_type:%s *******", train_op_type)
        tf.logging.info("****** optimization_type:%s *******",
                        kargs.get('optimization_type', 'grl'))
        generator_dict = generator_fn(features, labels, mode, params)

        discriminator_fn = discriminator(
            model_config_dict['discriminator'],
            num_labels_dict['discriminator'],
            init_checkpoint_dict['discriminator'],
            model_reuse=None,
            load_pretrained=load_pretrained_dict['discriminator'],
            model_io_config=model_io_config,
            opt_config=opt_config,
            exclude_scope=exclude_scope_dict.get('discriminator', ""),
            not_storage_params=not_storage_params_dict.get(
                'discriminator', []),
            target=target_dict['discriminator'],
            **kargs)

        tf.logging.info("****** true sampled_ids of discriminator *******")
        true_distriminator_features = {}
        true_distriminator_features['input_ids'] = generator_dict[
            'sampled_input_ids']
        true_distriminator_features['input_mask'] = generator_dict[
            'sampled_input_mask']
        true_distriminator_features['segment_ids'] = generator_dict[
            'sampled_segment_ids']
        true_distriminator_features['input_ori_ids'] = generator_dict[
            'sampled_input_ids']
        true_distriminator_features['next_sentence_labels'] = features[
            'next_sentence_labels']
        true_distriminator_features['ori_input_ids'] = generator_dict[
            'sampled_input_ids']

        true_distriminator_dict = discriminator_fn(true_distriminator_features,
                                                   labels, mode, params)

        fake_discriminator_features = {}
        if kargs.get('minmax_mode', 'corrupted') == 'corrupted':
            tf.logging.info("****** gumbel 3-D sampled_ids *******")
        elif kargs.get('minmax_mode', 'corrupted') == 'masked':
            fake_discriminator_features['ori_sampled_ids'] = generator_dict[
                'output_ids']
            tf.logging.info("****** conditioanl sampled_ids *******")
        fake_discriminator_features['input_ids'] = generator_dict[
            'sampled_ids']
        fake_discriminator_features['input_mask'] = generator_dict[
            'sampled_input_mask']
        fake_discriminator_features['segment_ids'] = generator_dict[
            'sampled_segment_ids']
        fake_discriminator_features['input_ori_ids'] = generator_dict[
            'sampled_input_ids']
        fake_discriminator_features['next_sentence_labels'] = features[
            'next_sentence_labels']
        fake_discriminator_features['ori_input_ids'] = generator_dict[
            'sampled_ids']

        fake_discriminator_dict = discriminator_fn(fake_discriminator_features,
                                                   labels, mode, params)

        nce_loss = nce_loss_fn(true_discriminator_dict,
                               true_discriminator_features,
                               fake_discriminator_dict,
                               fake_discriminator_features)

        model_io_fn = model_io.ModelIO(model_io_config)

        tvars = []

        loss = kargs.get('dis_loss', 1.0) * nce_loss

        tvars.extend(fake_discriminator_dict['tvars'])

        if kargs.get('joint_train', '1') == '1':
            tf.logging.info(
                "****** joint generator and discriminator training *******")
            tvars.extend(generator_dict['tvars'])
            loss += generator_dict['loss']
        tvars = list(set(tvars))

        var_checkpoint_dict_list = []
        for key in init_checkpoint_dict:
            if load_pretrained_dict[key] == "yes":
                if key == 'generator':
                    tmp = {
                        "tvars":
                        generator_dict['tvars'],
                        "init_checkpoint":
                        init_checkpoint_dict['generator'],
                        "exclude_scope":
                        exclude_scope_dict[key],
                        "restore_var_name":
                        model_config_dict['generator'].get(
                            'restore_var_name', [])
                    }
                    if kargs.get("sharing_mode", "none") != "none":
                        tmp['exclude_scope'] = ''
                    var_checkpoint_dict_list.append(tmp)
                elif key == 'discriminator':
                    tmp = {
                        "tvars":
                        fake_discriminator_dict['tvars'],
                        "init_checkpoint":
                        init_checkpoint_dict['discriminator'],
                        "exclude_scope":
                        exclude_scope_dict[key],
                        "restore_var_name":
                        model_config_dict['discriminator'].get(
                            'restore_var_name', [])
                    }
                    var_checkpoint_dict_list.append(tmp)

        use_tpu = 1 if kargs.get('use_tpu', False) else 0

        if len(var_checkpoint_dict_list) >= 1:
            scaffold_fn = model_io_fn.load_multi_pretrained(
                var_checkpoint_dict_list, use_tpu=use_tpu)
        else:
            scaffold_fn = None

        if mode == tf.estimator.ModeKeys.TRAIN:

            if kargs.get('summary_debug', False):
                metric_dict = discriminator_metric_train(
                    fake_discriminator_dict['per_example_loss'],
                    fake_discriminator_dict['logits'],
                    generator_dict['sampled_input_ids'],
                    generator_dict['sampled_ids'],
                    generator_dict['sampled_input_mask'])

                for key in metric_dict:
                    tf.summary.scalar(key, metric_dict[key])

            if kargs.get('use_tpu', False):
                optimizer_fn = optimizer.Optimizer(opt_config)
                use_tpu = 1
            else:
                optimizer_fn = distributed_optimizer.Optimizer(opt_config)
                use_tpu = 0

            model_io_fn.print_params(tvars, string=", trainable params")

            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            with tf.control_dependencies(update_ops):
                train_op = optimizer_fn.get_train_op(
                    loss,
                    list(set(tvars)),
                    opt_config.init_lr,
                    opt_config.num_train_steps,
                    use_tpu=use_tpu)

            if kargs.get('use_tpu', False):
                estimator_spec = tf.contrib.tpu.TPUEstimatorSpec(
                    mode=mode,
                    loss=loss,
                    train_op=train_op,
                    scaffold_fn=scaffold_fn)
            else:
                estimator_spec = tf.estimator.EstimatorSpec(mode=mode,
                                                            loss=loss,
                                                            train_op=train_op)

            return estimator_spec

        elif mode == tf.estimator.ModeKeys.EVAL:

            if kargs.get('joint_train', '0') == '1':

                def joint_metric(masked_lm_example_loss, masked_lm_log_probs,
                                 masked_lm_ids, masked_lm_weights,
                                 next_sentence_example_loss,
                                 next_sentence_log_probs, next_sentence_labels,
                                 per_example_loss, logits, input_ori_ids,
                                 input_ids, input_mask):
                    generator_metric = generator_metric_fn_eval(
                        masked_lm_example_loss, masked_lm_log_probs,
                        masked_lm_ids, masked_lm_weights,
                        next_sentence_example_loss, next_sentence_log_probs,
                        next_sentence_labels)
                    discriminator_metric = discriminator_metric_eval(
                        per_example_loss, logits, input_ori_ids, input_ids,
                        input_mask)
                    generator_metric.update(discriminator_metric)
                    return generator_metric

                tpu_eval_metrics = (joint_metric, [
                    generator_dict['masked_lm_example_loss'],
                    generator_dict['masked_lm_log_probs'],
                    generator_dict['masked_lm_ids'],
                    generator_dict['masked_lm_weights'],
                    generator_dict.get('next_sentence_example_loss', None),
                    generator_dict.get('next_sentence_log_probs', None),
                    generator_dict.get('next_sentence_labels', None),
                    fake_discriminator_dict['per_example_loss'],
                    fake_discriminator_dict['logits'],
                    generator_dict['sampled_input_ids'],
                    generator_dict['sampled_ids'],
                    generator_dict['sampled_input_mask']
                ])
                gpu_eval_metrics = joint_metric(
                    generator_dict['masked_lm_example_loss'],
                    generator_dict['masked_lm_log_probs'],
                    generator_dict['masked_lm_ids'],
                    generator_dict['masked_lm_weights'],
                    generator_dict.get('next_sentence_example_loss', None),
                    generator_dict.get('next_sentence_log_probs', None),
                    generator_dict.get('next_sentence_labels', None),
                    fake_discriminator_dict['per_example_loss'],
                    fake_discriminator_dict['logits'],
                    generator_dict['sampled_input_ids'],
                    generator_dict['sampled_ids'],
                    generator_dict['sampled_input_mask'])
            else:
                gpu_eval_metrics = discriminator_metric_eval(
                    fake_discriminator_dict['per_example_loss'],
                    fake_discriminator_dict['logits'],
                    generator_dict['sampled_input_ids'],
                    generator_dict['sampled_ids'],
                    generator_dict['sampled_input_mask'])
                tpu_eval_metrics = (discriminator_metric_eval, [
                    fake_discriminator_dict['per_example_loss'],
                    fake_discriminator_dict['logits'],
                    generator_dict['sampled_input_ids'],
                    generator_dict['sampled_ids'],
                    generator_dict['sampled_input_mask']
                ])

            if kargs.get('use_tpu', False):
                estimator_spec = tf.contrib.tpu.TPUEstimatorSpec(
                    mode=mode,
                    loss=loss,
                    eval_metrics=tpu_eval_metrics,
                    scaffold_fn=scaffold_fn)
            else:
                estimator_spec = tf.estimator.EstimatorSpec(
                    mode=mode, loss=loss, eval_metric_ops=gpu_eval_metrics)

            return estimator_spec
        else:
            raise NotImplementedError()
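
The true/fake feature dicts above implement an ELECTRA-style setup, where the discriminator's per-token objective is replaced-token detection. A hedged sketch of that criterion, as a simplification for illustration rather than the repo's own nce_loss_fn:

import tensorflow as tf

def rtd_loss(logits, sampled_ids, input_ori_ids, input_mask):
    # label is 1 where the generator replaced the original token
    labels = tf.cast(tf.not_equal(sampled_ids, input_ori_ids), tf.float32)
    per_token = tf.nn.sigmoid_cross_entropy_with_logits(labels=labels,
                                                        logits=logits)
    mask = tf.cast(input_mask, tf.float32)
    # masked mean over real (non-padding) positions
    return tf.reduce_sum(per_token * mask) / (tf.reduce_sum(mask) + 1e-10)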
Example #12
	def model_fn(features, labels, mode, params):

		train_op_type = kargs.get('train_op_type', 'joint')
		print("==input shape==", features["input_ids"].get_shape())

		ebm_dist_fn = ebm_dist(model_config_dict['ebm_dist'],
					num_labels_dict['ebm_dist'],
					init_checkpoint_dict['ebm_dist'],
					model_reuse=None,
					load_pretrained=load_pretrained_dict['ebm_dist'],
					model_io_config=model_io_config,
					opt_config=opt_config,
					exclude_scope=exclude_scope_dict.get('ebm_dist', ""),
					not_storage_params=not_storage_params_dict.get('ebm_dist', []),
					target=target_dict['ebm_dist'],
					prob_ln=False,
					transform=False,
					transformer_activation="linear",
					logz_mode='standard',
					normalized_constant="length_linear",
					energy_pooling="mi",
					softplus_features=False,
					**kargs)

		noise_prob_ln = False
		noise_sample = kargs.get("noise_sample", 'mlm')

		if kargs.get("noise_sample", 'mlm') == 'gpt':
			tf.logging.info("****** using gpt for noise dist sample *******")
			sample_noise_dist = True
		elif kargs.get("noise_sample", 'mlm') == 'mlm':
			tf.logging.info("****** using bert mlm for noise dist sample *******")
			sample_noise_dist = False
		else:
			tf.logging.info("****** using gpt for noise dist sample *******")
			sample_noise_dist = True

		noise_dist_fn = noise_dist(model_config_dict['noise_dist'],
					num_labels_dict['noise_dist'],
					init_checkpoint_dict['noise_dist'],
					model_reuse=None,
					load_pretrained=load_pretrained_dict['noise_dist'],
					model_io_config=model_io_config,
					opt_config=opt_config,
					exclude_scope=exclude_scope_dict.get('noise_dist', ""),
					not_storage_params=not_storage_params_dict.get('noise_dist', []),
					target=target_dict['noise_dist'],
					noise_true_distribution=True,
					sample_noise_dist=sample_noise_dist,
					noise_estimator_type=kargs.get("noise_estimator_type", "stop_gradient"),
					prob_ln=noise_prob_ln,
					if_bp=True,
					**kargs)

		if not sample_noise_dist:
			tf.logging.info("****** using bert mlm for noise dist sample *******")

			global_step = tf.train.get_or_create_global_step()
			noise_sample_ratio = tf.train.polynomial_decay(
													0.20,
													global_step,
													opt_config.num_train_steps,
													end_learning_rate=0.1,
													power=1.0,
													cycle=False)

			mlm_noise_dist_fn = mlm_noise_dist(model_config_dict['generator'],
						num_labels_dict['generator'],
						init_checkpoint_dict['generator'],
						model_reuse=None,
						load_pretrained=load_pretrained_dict['generator'],
						model_io_config=model_io_config,
						opt_config=opt_config,
						exclude_scope=exclude_scope_dict.get('generator', ""),
						not_storage_params=not_storage_params_dict.get('generator', []),
						target=target_dict['generator'],
						mask_probability=noise_sample_ratio,
						replace_probability=0.2,
						original_probability=0.0,
						**kargs)
		else:
			mlm_noise_dist_fn = None

		true_features = {}

		for key in features:
			if key == 'input_ori_ids':
				true_features["input_ids"] = tf.cast(features['input_ori_ids'], tf.int32)
			if key in ['input_mask', 'segment_ids']:
				true_features[key] = tf.cast(features[key], tf.int32)

		if kargs.get("dnce", False):

			if kargs.get("anneal_dnce", False):
				global_step = tf.train.get_or_create_global_step()
				noise_sample_ratio = tf.train.polynomial_decay(
														0.10,
														global_step,
														opt_config.num_train_steps,
														end_learning_rate=0.05,
														power=1.0,
														cycle=False)
				tf.logging.info("****** anneal dnce mix ratio *******")
			else:
				noise_sample_ratio = 0.10
				tf.logging.info("****** not anneal dnce mix ratio *******")

			mlm_noise_noise_dist_fn = mlm_noise_dist(model_config_dict['generator'],
						num_labels_dict['generator'],
						init_checkpoint_dict['generator'],
						model_reuse=None,
						load_pretrained=load_pretrained_dict['generator'],
						model_io_config=model_io_config,
						opt_config=opt_config,
						exclude_scope=exclude_scope_dict.get('generator', ""),
						not_storage_params=not_storage_params_dict.get('generator', []),
						target=target_dict['generator'],
						mask_probability=noise_sample_ratio,
						replace_probability=0.0,
						original_probability=0.0,
						**kargs)

			mlm_noise_dist_dict_noise = mlm_noise_noise_dist_fn(features, labels, mode, params)

			mixed_mask = mixed_sample(features, mix_ratio=noise_sample_ratio)
			tf.logging.info("****** apply dnce *******")
			mixed_mask = tf.expand_dims(mixed_mask, axis=-1) # batch_size x 1
			mixed_mask = tf.cast(mixed_mask, tf.int32)
			true_features["input_ids"] = (1-mixed_mask)*true_features["input_ids"] + mixed_mask * mlm_noise_dist_dict_noise['sampled_ids']

		if not sample_noise_dist:
			mlm_noise_dist_dict = mlm_noise_dist_fn(features, labels, mode, params)
		else:
			mlm_noise_dist_dict = {}

		# first, get the noise-distribution dict
		noise_dist_dict = noise_dist_fn(true_features, labels, mode, params)

		# second, build the fake (sampled) features for the ebm
		fake_features = {}

		if noise_sample == 'gpt':
			if kargs.get("training_mode", "stop_gradient") == 'stop_gradient':
				fake_features["input_ids"] = noise_dist_dict['fake_samples']
				tf.logging.info("****** using samples stop gradient *******")
			elif kargs.get("training_mode", "stop_gradient") == 'adv_gumbel':
				fake_features["input_ids"] = noise_dist_dict['gumbel_probs']
				tf.logging.info("****** using samples with gradient *******")
			fake_features['input_mask'] = tf.cast(noise_dist_dict['fake_mask'], tf.int32)
			fake_features['segment_ids'] = tf.zeros_like(fake_features['input_mask'])
		elif noise_sample == 'mlm':
			fake_features["input_ids"] = mlm_noise_dist_dict['sampled_ids']
			fake_features['input_mask'] = tf.cast(features['input_mask'], tf.int32)
			fake_features['segment_ids'] = tf.zeros_like(features['input_mask'])
			tf.logging.info("****** using bert mlm stop gradient *******")

		# third, score the true and fake features under the ebm
		true_ebm_dist_dict = ebm_dist_fn(true_features, labels, mode, params)
		fake_ebm_dist_dict = ebm_dist_fn(fake_features, labels, mode, params)
		if not sample_noise_dist:
			fake_noise_dist_dict = noise_dist_fn(fake_features, labels, mode, params)
			noise_dist_dict['fake_logits'] = fake_noise_dist_dict['true_logits']

		[ebm_loss, 
		ebm_all_true_loss,
		ebm_all_fake_loss] = get_ebm_loss(true_ebm_dist_dict['logits'], 
								noise_dist_dict['true_logits'], 
								fake_ebm_dist_dict['logits'], 
								noise_dist_dict['fake_logits'], 
								use_tpu=kargs.get('use_tpu', False),
								valid_mask=mlm_noise_dist_dict.get("valid_mask", None))

		logz_length_true_loss = ebm_logz_length_cond_loss(model_config_dict['ebm_dist'],
															true_features,
															ebm_all_true_loss,
															valid_mask=mlm_noise_dist_dict.get("valid_mask", None))

		logz_length_fake_loss = ebm_logz_length_cond_loss(model_config_dict['ebm_dist'],
															fake_features,
															ebm_all_fake_loss,
															valid_mask=mlm_noise_dist_dict.get("valid_mask", None))
		true_ebm_dist_dict['logz_loss'] = logz_length_true_loss + logz_length_fake_loss

		noise_loss = get_noise_loss(true_ebm_dist_dict['logits'], 
									noise_dist_dict['true_logits'], 
									fake_ebm_dist_dict['logits'], 
									noise_dist_dict['fake_logits'], 
									noise_loss_type=kargs.get('noise_loss_type', 'jsd_noise'),
									num_train_steps=opt_config.num_train_steps,
									num_warmup_steps=opt_config.num_warmup_steps,
									use_tpu=kargs.get('use_tpu', False),
									loss_mask=features['input_mask'],
									prob_ln=noise_prob_ln)

		model_io_fn = model_io.ModelIO(model_io_config)

		tvars = []
		loss = ebm_loss
		tvars.extend(true_ebm_dist_dict['tvars'])

		if kargs.get('joint_train', '1') == '1':
			tf.logging.info("****** joint generator and discriminator training *******")
			tvars.extend(noise_dist_dict['tvars'])
			loss += noise_loss
		tvars = list(set(tvars))

		ebm_opt_dict = {
			"loss":ebm_loss,
			"tvars":true_ebm_dist_dict['tvars'],
			"logz_tvars":true_ebm_dist_dict['logz_tvars'],
			"logz_loss":true_ebm_dist_dict['logz_loss']
		}

		noise_opt_dict = {
			"loss":noise_loss,
			"tvars":noise_dist_dict['tvars']
		}

		var_checkpoint_dict_list = []
		for key in init_checkpoint_dict:
			if load_pretrained_dict[key] == "yes":
				if key == 'ebm_dist':
					tmp = {
							"tvars":ebm_opt_dict['tvars']+ebm_opt_dict['logz_tvars'],
							"init_checkpoint":init_checkpoint_dict['ebm_dist'],
							"exclude_scope":exclude_scope_dict[key],
							"restore_var_name":model_config_dict['ebm_dist'].get('restore_var_name', [])
					}
					if kargs.get("sharing_mode", "none") != "none":
						tmp['exclude_scope'] = ''
					var_checkpoint_dict_list.append(tmp)
				elif key == 'noise_dist':
					tmp = {
							"tvars":noise_opt_dict['tvars'],
							"init_checkpoint":init_checkpoint_dict['noise_dist'],
							"exclude_scope":exclude_scope_dict[key],
							"restore_var_name":model_config_dict['noise_dist'].get('restore_var_name', [])
					}
					var_checkpoint_dict_list.append(tmp)
				elif key == 'generator':
					if not sample_noise_dist:
						tmp = {
								"tvars":mlm_noise_dist_dict['tvars'],
								"init_checkpoint":init_checkpoint_dict['generator'],
								"exclude_scope":exclude_scope_dict[key],
								"restore_var_name":model_config_dict['generator'].get('restore_var_name', [])
						}
						if kargs.get("sharing_mode", "none") != "none":
							tmp['exclude_scope'] = ''
						var_checkpoint_dict_list.append(tmp)

		use_tpu = 1 if kargs.get('use_tpu', False) else 0
			
		if len(var_checkpoint_dict_list) >= 1:
			scaffold_fn = model_io_fn.load_multi_pretrained(
											var_checkpoint_dict_list,
											use_tpu=use_tpu)
		else:
			scaffold_fn = None

		if mode == tf.estimator.ModeKeys.TRAIN:

			metric_dict = ebm_noise_train_metric(
										true_ebm_dist_dict['logits'], 
										noise_dist_dict['true_logits'], 
										fake_ebm_dist_dict['logits'], 
										noise_dist_dict['fake_logits'],
										features['input_ori_ids'],
										tf.cast(features['input_mask'], tf.float32),
										noise_dist_dict["true_seq_logits"],
										prob_ln=noise_prob_ln,
										)

			if not kargs.get('use_tpu', False):
				for key in metric_dict:
					tf.summary.scalar(key, metric_dict[key])
				tf.summary.scalar("ebm_loss", ebm_opt_dict['loss'])
				tf.summary.scalar("noise_loss", noise_opt_dict['loss'])
	
			if kargs.get('use_tpu', False):
				optimizer_fn = optimizer.Optimizer(opt_config)
				use_tpu = 1
			else:
				optimizer_fn = distributed_optimizer.Optimizer(opt_config)
				use_tpu = 0

			model_io_fn.print_params(tvars, string=", trainable params")

			train_op = get_train_op(ebm_opt_dict, noise_opt_dict, 
								optimizer_fn, opt_config,
								model_config_dict['ebm_dist'], 
								model_config_dict['noise_dist'],
								use_tpu=use_tpu, 
								train_op_type=train_op_type,
								fce_acc=metric_dict['all_accuracy'])
			
			# update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
			# with tf.control_dependencies(update_ops):
			# 	train_op = optimizer_fn.get_train_op(loss, list(set(tvars)),
			# 					opt_config.init_lr, 
			# 					opt_config.num_train_steps,
			# 					use_tpu=use_tpu)

			if kargs.get('use_tpu', False):
				estimator_spec = tf.contrib.tpu.TPUEstimatorSpec(
								mode=mode,
								loss=loss,
								train_op=train_op,
								scaffold_fn=scaffold_fn)
			else:
				estimator_spec = tf.estimator.EstimatorSpec(
								mode=mode, 
								loss=loss, 
								train_op=train_op)

			return estimator_spec

		elif mode == tf.estimator.ModeKeys.EVAL:

			tpu_eval_metrics = (ebm_noise_eval_metric, 
								[
								true_ebm_dist_dict['logits'], 
								noise_dist_dict['true_logits'], 
								fake_ebm_dist_dict['logits'], 
								noise_dist_dict['fake_logits'],
								features['input_ori_ids'],
								tf.cast(features['input_mask'], tf.float32),
								noise_dist_dict["true_seq_logits"]
								])
			gpu_eval_metrics = ebm_noise_eval_metric(
								true_ebm_dist_dict['logits'], 
								noise_dist_dict['true_logits'], 
								fake_ebm_dist_dict['logits'], 
								noise_dist_dict['fake_logits'],
								features['input_ori_ids'],
								tf.cast(features['input_mask'], tf.float32),
								noise_dist_dict["true_seq_logits"]
								)

			if kargs.get('use_tpu', False):
				estimator_spec = tf.contrib.tpu.TPUEstimatorSpec(
							  mode=mode,
							  loss=loss,
							  eval_metrics=tpu_eval_metrics,
							  scaffold_fn=scaffold_fn)
			else:
				estimator_spec = tf.estimator.EstimatorSpec(mode=mode, 
								loss=loss,
								eval_metric_ops=gpu_eval_metrics)

			return estimator_spec
		else:
			raise NotImplementedError()
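
The ebm/noise losses above are an instance of noise-contrastive estimation: the classifier's logit is the log-density ratio log p_ebm(x) - log p_noise(x). A hedged sketch of the plain binary NCE criterion, assuming one noise sample per data sample and no masking, unlike the repo's get_ebm_loss:

import tensorflow as tf

def nce_binary_loss(true_ebm_logits, true_noise_logits,
                    fake_ebm_logits, fake_noise_logits):
    # D(x) = sigmoid(log p_ebm(x) - log p_noise(x))
    true_logits = true_ebm_logits - true_noise_logits  # push toward 1
    fake_logits = fake_ebm_logits - fake_noise_logits  # push toward 0
    true_loss = tf.nn.sigmoid_cross_entropy_with_logits(
        labels=tf.ones_like(true_logits), logits=true_logits)
    fake_loss = tf.nn.sigmoid_cross_entropy_with_logits(
        labels=tf.zeros_like(fake_logits), logits=fake_logits)
    return tf.reduce_mean(true_loss) + tf.reduce_mean(fake_loss)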
Example #13
    def model_fn(features, labels, mode):

        input_ids = features["input_ids"]
        input_mask = features["input_mask"]
        segment_ids = features["segment_ids"]
        label_ids = features["label_ids"]

        if mode == tf.estimator.ModeKeys.TRAIN:
            hidden_dropout_prob = model_config.hidden_dropout_prob
            attention_probs_dropout_prob = model_config.attention_probs_dropout_prob
            dropout_prob = model_config.dropout_prob
        else:
            hidden_dropout_prob = 0.0
            attention_probs_dropout_prob = 0.0
            dropout_prob = 0.0

        model = bert.Bert(model_config)
        model.build_embedder(input_ids,
                             segment_ids,
                             hidden_dropout_prob,
                             attention_probs_dropout_prob,
                             reuse=reuse)
        model.build_encoder(input_ids,
                            input_mask,
                            hidden_dropout_prob,
                            attention_probs_dropout_prob,
                            reuse=reuse)
        model.build_pooler(reuse=reuse)

        if model_io_config.fix_lm == True:
            scope = model_config.scope + "_finetuning"
        else:
            scope = model_config.scope

        with tf.variable_scope(scope, reuse=reuse):
            (loss, per_example_loss,
             logits) = classifier.classifier(model_config,
                                             model.get_pooled_output(),
                                             num_labels, label_ids,
                                             dropout_prob)

        # model_io_fn = model_io.ModelIO(model_io_config)
        pretrained_tvars = model_io_fn.get_params(model_config.scope)
        if load_pretrained:
            model_io_fn.load_pretrained(pretrained_tvars, init_checkpoint)

        tvars = model_io_fn.get_params(scope)
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)

        if mode == tf.estimator.ModeKeys.TRAIN:
            model_io_fn.print_params(tvars, string=", trainable params")
            with tf.control_dependencies(update_ops):
                optimizer_fn = optimizer.Optimizer(opt_config)
                train_op = optimizer_fn.get_train_op(
                    loss, tvars, opt_config.init_lr,
                    opt_config.num_train_steps)
            # a TRAIN spec must carry loss and train_op; the
            # prediction-only spec below is rejected in TRAIN mode
            return tf.estimator.EstimatorSpec(mode=mode,
                                              loss=loss,
                                              train_op=train_op)

        print(logits.get_shape(), "===logits shape===")
        pred_label = tf.argmax(logits, axis=-1, output_type=tf.int32)
        prob = tf.nn.softmax(logits)
        max_prob = tf.reduce_max(prob, axis=-1)

        output_spec = tf.estimator.EstimatorSpec(mode=mode,
                                                 predictions={
                                                     'pred_label': pred_label,
                                                     "label_ids": label_ids,
                                                     "max_prob": max_prob
                                                 })
        return output_spec
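
A hedged usage sketch of the prediction dict this model_fn emits; estimator and predict_input_fn are hypothetical stand-ins for the surrounding pipeline:

# hedged usage sketch; `estimator` wraps the model_fn above and
# `predict_input_fn` is a hypothetical input pipeline
for prediction in estimator.predict(input_fn=predict_input_fn):
    print(prediction['pred_label'], prediction['max_prob'])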
Example #14
	def model_fn(features, labels, mode):
		print(features)
		input_ids = features["input_ids"]
		input_mask = features["input_mask"]
		segment_ids = features["segment_ids"]
		label_ids = features["label_ids"]

		if mode == tf.estimator.ModeKeys.TRAIN:
			hidden_dropout_prob = model_config.hidden_dropout_prob
			attention_probs_dropout_prob = model_config.attention_probs_dropout_prob
			dropout_prob = model_config.dropout_prob
		else:
			hidden_dropout_prob = 0.0
			attention_probs_dropout_prob = 0.0
			dropout_prob = 0.0

		model = bert.Bert(model_config)
		model.build_embedder(input_ids, segment_ids,
											hidden_dropout_prob,
											attention_probs_dropout_prob,
											reuse=reuse)
		model.build_encoder(input_ids,
											input_mask,
											hidden_dropout_prob, 
											attention_probs_dropout_prob,
											reuse=reuse)
		model.build_pooler(reuse=reuse)

		if model_io_config.fix_lm == True:
			scope = model_config.scope + "_finetuning"
		else:
			scope = model_config.scope

		with tf.variable_scope(scope, reuse=reuse):
			(loss, 
				per_example_loss, 
				logits) = classifier.classifier(model_config,
											model.get_pooled_output(),
											num_labels,
											label_ids,
											dropout_prob)

		# model_io_fn = model_io.ModelIO(model_io_config)
		pretrained_tvars = model_io_fn.get_params(model_config.scope, 
										not_storage_params=not_storage_params)
		if load_pretrained:
			model_io_fn.load_pretrained(pretrained_tvars, 
										init_checkpoint,
										exclude_scope=exclude_scope)

		tvars = pretrained_tvars
		model_io_fn.set_saver(var_lst=tvars)

		if mode == tf.estimator.ModeKeys.TRAIN:
			model_io_fn.print_params(tvars, string=", trainable params")
			update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
			with tf.control_dependencies(update_ops):
				optimizer_fn = optimizer.Optimizer(opt_config)
				train_op = optimizer_fn.get_train_op(loss, tvars, 
								opt_config.init_lr, 
								opt_config.num_train_steps)

				return [train_op, loss, per_example_loss, logits]
		else:
			model_io_fn.print_params(tvars, string=", trainable params")
			return [loss, loss, per_example_loss, logits]
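
Unlike the Estimator-style examples, this model_fn returns a plain list of ops/tensors, so the caller drives it with a raw session. A hedged sketch, assuming features and labels are already-built input tensors:

import tensorflow as tf

# hedged sketch; `features`/`labels` come from an existing input pipeline
train_op, loss, per_example_loss, logits = model_fn(
    features, labels, tf.estimator.ModeKeys.TRAIN)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for _ in range(100):  # illustrative step count
        _, loss_val = sess.run([train_op, loss])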
Example #15
    def model_fn(features, labels, mode, params):

        model_api = model_zoo(model_config)

        if target:
            features['input_ori_ids'] = features['input_ids_{}'.format(target)]
            features['input_ids'] = features['input_ids_{}'.format(target)]
        sequence_mask = tf.cast(
            tf.not_equal(features['input_ori_ids'], kargs.get('[PAD]', 0)),
            tf.int32)
        features['input_mask'] = sequence_mask

        seq_features = {}
        for key in features:
            seq_features[key] = features[key]
        if 'input_ori_ids' in features:
            seq_features['input_ids'] = features["input_ori_ids"]
        else:
            features['input_ori_ids'] = seq_features['input_ids']

        not_equal = tf.cast(
            tf.not_equal(features["input_ori_ids"],
                         tf.zeros_like(features["input_ori_ids"])), tf.int32)
        not_equal = tf.reduce_sum(not_equal, axis=-1)
        loss_mask = tf.cast(tf.not_equal(not_equal, tf.zeros_like(not_equal)),
                            tf.float32)

        if not kargs.get('use_tpu', False):
            tf.summary.scalar('loss_mask', tf.reduce_sum(loss_mask))

        casual_flag = model_config.get('is_casual', True)
        tf.logging.info("***** is casual flag *****", str(casual_flag))

        if not casual_flag:
            [output_ids, sampled_binary_mask] = hmm_input_ids_generation(
                model_config,
                features['input_ori_ids'],
                features['input_mask'], [
                    tf.cast(tf.constant(hmm_tran_prob), tf.float32)
                    for hmm_tran_prob in hmm_tran_prob_list
                ],
                mask_probability=0.02,
                replace_probability=0.01,
                original_probability=0.01,
                mask_prior=tf.cast(tf.constant(mask_prior), tf.float32),
                **kargs)
            tf.logging.info("***** apply random sampling *****")
            seq_features['input_ids'] = output_ids

        model = model_api(model_config,
                          seq_features,
                          labels,
                          mode,
                          "",
                          reuse=tf.AUTO_REUSE,
                          **kargs)

        if mode == tf.estimator.ModeKeys.TRAIN:
            dropout_prob = model_config.dropout_prob
        else:
            dropout_prob = 0.0

        if model_io_config.fix_lm == True:
            scope = model_config.scope + "_finetuning"
        else:
            scope = model_config.scope

        # if mode == tf.estimator.ModeKeys.TRAIN:
        if kargs.get('mask_type', 'left2right') == 'left2right':
            tf.logging.info("***** using left2right mask and loss *****")
            sequence_mask = tf.to_float(
                tf.not_equal(features['input_ori_ids'][:, 1:],
                             kargs.get('[PAD]', 0)))
        elif kargs.get('mask_type', 'left2right') == 'seq2seq':
            tf.logging.info("***** using seq2seq mask and loss *****")
            sequence_mask = tf.to_float(features['segment_ids'][:, 1:])
            if not kargs.get('use_tpu', False):
                tf.summary.scalar("loss mask", tf.reduce_mean(sequence_mask))

        # batch x seq_length
        if casual_flag:
            print(model.get_sequence_output_logits().get_shape(),
                  "===logits shape===")
            seq_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
                labels=features['input_ori_ids'][:, 1:],
                logits=model.get_sequence_output_logits()[:, :-1])

            per_example_loss = tf.reduce_sum(
                seq_loss * sequence_mask,
                axis=-1) / (tf.reduce_sum(sequence_mask, axis=-1) + 1e-10)
            loss = tf.reduce_mean(per_example_loss)

            if model_config.get("cnn_type",
                                "dgcnn") in ['bi_dgcnn', 'bi_light_dgcnn']:
                seq_backward_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
                    labels=features['input_ori_ids'][:, :-1],
                    logits=model.get_sequence_backward_output_logits()[:, 1:])

                per_backward_example_loss = tf.reduce_sum(
                    seq_backward_loss * sequence_mask,
                    axis=-1) / (tf.reduce_sum(sequence_mask, axis=-1) + 1e-10)
                backward_loss = tf.reduce_mean(per_backward_example_loss)
                loss += backward_loss
                tf.logging.info("***** using backward loss *****")
        else:
            (masked_lm_loss, masked_lm_example_loss, masked_lm_log_probs,
             masked_lm_mask) = pretrain.seq_mask_masked_lm_output(
                 model_config,
                 model.get_sequence_output(),
                 model.get_embedding_table(),
                 seq_features['input_mask'],
                 seq_features['input_ori_ids'],
                 seq_features['input_ids'],
                 sampled_binary_mask,
                 reuse=tf.AUTO_REUSE,
                 embedding_projection=model.get_embedding_projection_table())
            loss = masked_lm_loss
            tf.logging.info("***** using masked lm loss *****")
        model_io_fn = model_io.ModelIO(model_io_config)

        pretrained_tvars = model_io_fn.get_params(
            model_config.scope, not_storage_params=not_storage_params)

        lm_pretrain_tvars = model_io_fn.get_params(
            "cls/predictions", not_storage_params=not_storage_params)

        pretrained_tvars.extend(lm_pretrain_tvars)

        use_tpu = 1 if kargs.get('use_tpu', False) else 0

        if load_pretrained == "yes":
            use_tpu = 1 if kargs.get('use_tpu', False) else 0
            scaffold_fn = model_io_fn.load_pretrained(
                pretrained_tvars,
                init_checkpoint,
                exclude_scope=exclude_scope,
                use_tpu=use_tpu)
        else:
            scaffold_fn = None

        if mode == tf.estimator.ModeKeys.TRAIN:

            if kargs.get('use_tpu', False):
                optimizer_fn = optimizer.Optimizer(opt_config)
                use_tpu = 1
                tf.logging.info(
                    "***** using tpu with tpu-captiable optimizer *****")
            else:
                optimizer_fn = distributed_optimizer.Optimizer(opt_config)
                use_tpu = 0
                tf.logging.info(
                    "***** using gpu with gpu-captiable optimizer *****")

            tvars = pretrained_tvars
            model_io_fn.print_params(tvars, string=", trainable params")

            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            with tf.control_dependencies(update_ops):
                train_op = optimizer_fn.get_train_op(
                    loss,
                    tvars,
                    opt_config.init_lr,
                    opt_config.num_train_steps,
                    use_tpu=use_tpu)

                # train_metric_dict = train_metric(features['input_ori_ids'],
                # 								model.get_sequence_output_logits(),
                # 								seq_features,
                # 								**kargs)

                # if not kargs.get('use_tpu', False):
                # 	for key in train_metric_dict:
                # 		tf.summary.scalar(key, train_metric_dict[key])
                # 	tf.summary.scalar('learning_rate', optimizer_fn.learning_rate)
                # 	tf.logging.info("***** logging metric *****")
                # 	tf.summary.scalar("causal_attenion_mask_length", tf.reduce_sum(sequence_mask))
                # tf.summary.scalar("bi_attenion_mask_length", tf.reduce_sum(model.bi_attention_mask))

                if kargs.get('use_tpu', False):
                    estimator_spec = tf.contrib.tpu.TPUEstimatorSpec(
                        mode=mode,
                        loss=loss,
                        train_op=train_op,
                        scaffold_fn=scaffold_fn)
                else:
                    estimator_spec = tf.estimator.EstimatorSpec(
                        mode=mode, loss=loss, train_op=train_op)

                return estimator_spec

        elif mode == tf.estimator.ModeKeys.EVAL:

            gpu_eval_metrics = eval_metric(features['input_ori_ids'],
                                           model.get_sequence_output_logits(),
                                           seq_features, **kargs)
            tpu_eval_metrics = (eval_metric, [
                features['input_ori_ids'],
                model.get_sequence_output_logits(), seq_features,
                kargs.get('mask_type', 'left2right')
            ])

            if kargs.get('use_tpu', False):
                estimator_spec = tf.contrib.tpu.TPUEstimatorSpec(
                    mode=mode,
                    loss=loss,
                    eval_metrics=tpu_eval_metrics,
                    scaffold_fn=scaffold_fn)
            else:
                estimator_spec = tf.estimator.EstimatorSpec(
                    mode=mode, loss=loss, eval_metric_ops=gpu_eval_metrics)

            return estimator_spec

        elif mode == tf.estimator.ModeKeys.PREDICT:
            if kargs.get('predict_type',
                         'sample_sequence') == 'sample_sequence':
                results = bert_seq_sample_utils.sample_sequence(
                    model_api,
                    model_config,
                    mode,
                    features,
                    target="",
                    start_token=kargs.get("start_token_id", 101),
                    batch_size=None,
                    context=features.get("context", None),
                    temperature=kargs.get("sample_temp", 1.0),
                    n_samples=kargs.get("n_samples", 1),
                    top_k=0,
                    end_token=kargs.get("end_token_id", 102),
                    greedy_or_sample="greedy",
                    gumbel_temp=0.01,
                    estimator="stop_gradient",
                    back_prop=True,
                    swap_memory=True,
                    seq_type=kargs.get("seq_type", "seq2seq"),
                    mask_type=kargs.get("mask_type", "seq2seq"),
                    attention_type=kargs.get('attention_type',
                                             'normal_attention'))
                # stop_gradient output:
                # samples, mask_sequence, presents, logits, final

                sampled_token = results['samples']
                sampled_token_logits = results['logits']
                mask_sequence = results['mask_sequence']

                estimator_spec = tf.estimator.EstimatorSpec(
                    mode=mode,
                    predictions={
                        'token': sampled_token,
                        "logits": sampled_token_logits,
                        "mask_sequence": mask_sequence
                    },
                    export_outputs={
                        "output":
                        tf.estimator.export.PredictOutput({
                            'token':
                            sampled_token,
                            "logits":
                            sampled_token_logits,
                            "mask_sequence":
                            mask_sequence
                        })
                    })

                return estimator_spec

            elif kargs.get('predict_type',
                           'sample_sequence') == 'infer_inputs':

                sequence_mask = tf.to_float(
                    tf.not_equal(features['input_ids'][:, 1:],
                                 kargs.get('[PAD]', 0)))

                if kargs.get('mask_type', 'left2right') == 'left2right':
                    tf.logging.info(
                        "***** using left2right mask and loss *****")
                    sequence_mask = tf.to_float(
                        tf.not_equal(features['input_ori_ids'][:, 1:],
                                     kargs.get('[PAD]', 0)))
                elif kargs.get('mask_type', 'left2right') == 'seq2seq':
                    tf.logging.info("***** using seq2seq mask and loss *****")
                    sequence_mask = tf.to_float(features['segment_ids'][:, 1:])
                    if not kargs.get('use_tpu', False):
                        tf.summary.scalar("loss mask",
                                          tf.reduce_mean(sequence_mask))

                output_logits = model.get_sequence_output_logits()[:, :-1]
                # output_logits = tf.nn.log_softmax(output_logits, axis=-1)

                output_id_logits = tf.nn.sparse_softmax_cross_entropy_with_logits(
                    labels=features['input_ids'][:, 1:], logits=output_logits)

                per_example_perplexity = tf.reduce_sum(output_id_logits *
                                                       sequence_mask,
                                                       axis=-1)  # batch
                per_example_perplexity /= tf.reduce_sum(sequence_mask,
                                                        axis=-1)  # batch

                perplexity = tf.exp(per_example_perplexity)

                estimator_spec = tf.estimator.EstimatorSpec(
                    mode=mode,
                    predictions={
                        'token': features['input_ids'][:, 1:],
                        "logits": output_id_logits,
                        'perplexity': perplexity,
                        # "all_logits":output_logits
                    },
                    export_outputs={
                        "output":
                        tf.estimator.export.PredictOutput({
                            'token':
                            features['input_ids'][:, 1:],
                            "logits":
                            output_id_logits,
                            'perplexity':
                            perplexity,
                            # "all_logits":output_logits
                        })
                    })

                return estimator_spec
        else:
            raise NotImplementedError()
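
The 'infer_inputs' branch above computes perplexity as the exponential of the length-normalized token cross-entropy. A hedged numeric mini-check of that arithmetic:

import numpy as np

xent = np.array([[2.0, 1.0, 0.5]])  # per-token cross-entropy
mask = np.array([[1.0, 1.0, 0.0]])  # last position is padding
ppl = np.exp((xent * mask).sum(axis=-1) / mask.sum(axis=-1))
print(ppl)  # exp((2.0 + 1.0) / 2) ~= 4.48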
Example #16
	def model_fn(features, labels, mode, params):

		model_api = model_zoo(model_config)

		model = model_api(model_config, features, labels,
							mode, target, reuse=tf.AUTO_REUSE)

		if mode == tf.estimator.ModeKeys.TRAIN:
			dropout_prob = model_config.dropout_prob
		else:
			dropout_prob = 0.0

		if model_io_config.fix_lm == True:
			scope = model_config.scope + "_finetuning"
		else:
			scope = model_config.scope
		
		(nsp_loss, 
		 nsp_per_example_loss, 
		 nsp_log_prob) = pretrain.get_next_sentence_output(model_config,
										model.get_pooled_output(),
										features['next_sentence_labels'],
										reuse=tf.AUTO_REUSE)

		masked_lm_positions = features["masked_lm_positions"]
		masked_lm_ids = features["masked_lm_ids"]
		masked_lm_weights = features["masked_lm_weights"]

		if model_config.model_type == 'bert':
			masked_lm_fn = pretrain.get_masked_lm_output
			print("==apply bert masked lm==")
		elif model_config.model_type == 'albert':
			masked_lm_fn = pretrain_albert.get_masked_lm_output
			print("==apply albert masked lm==")
		else:
			masked_lm_fn = pretrain.get_masked_lm_output
			print("==apply bert masked lm==")

		(masked_lm_loss,
		masked_lm_example_loss, 
		masked_lm_log_probs,
		masked_lm_mask) = masked_lm_fn(
										model_config, 
										model.get_sequence_output(), 
										model.get_embedding_table(),
										masked_lm_positions, 
										masked_lm_ids, 
										masked_lm_weights,
										reuse=tf.AUTO_REUSE,
										embedding_projection=model.get_embedding_projection_table())
		print(model_config.lm_ratio, '==mlm lm_ratio==')
		loss = model_config.lm_ratio * masked_lm_loss #+ model_config.nsp_ratio * nsp_loss
		
		model_io_fn = model_io.ModelIO(model_io_config)

		pretrained_tvars = model_io_fn.get_params(model_config.scope, 
										not_storage_params=not_storage_params)

		lm_pretrain_tvars = model_io_fn.get_params("cls/predictions", 
									not_storage_params=not_storage_params)

		pretrained_tvars.extend(lm_pretrain_tvars)

		if load_pretrained == "yes":
			scaffold_fn = model_io_fn.load_pretrained(pretrained_tvars, 
											init_checkpoint,
											exclude_scope=exclude_scope,
											use_tpu=1)
		else:
			scaffold_fn = None
		print("******* scaffold fn *******", scaffold_fn)
		if mode == tf.estimator.ModeKeys.TRAIN:
						
			optimizer_fn = optimizer.Optimizer(opt_config)
						
			tvars = pretrained_tvars
			model_io_fn.print_params(tvars, string=", trainable params")
			
			# update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
			# with tf.control_dependencies(update_ops):
			print('==gpu count==', opt_config.get('gpu_count', 1))

			train_op = optimizer_fn.get_train_op(loss, tvars,
							opt_config.init_lr, 
							opt_config.num_train_steps,
							use_tpu=opt_config.use_tpu)

			train_metric_dict = train_metric_fn(
					masked_lm_example_loss, masked_lm_log_probs, 
					masked_lm_ids,
					masked_lm_weights, 
					nsp_per_example_loss,
					nsp_log_prob, 
					features['next_sentence_labels'],
					masked_lm_mask=masked_lm_mask
				)

			# for key in train_metric_dict:
			# 	tf.summary.scalar(key, train_metric_dict[key])
			# tf.summary.scalar('learning_rate', optimizer_fn.learning_rate)

			estimator_spec = tf.contrib.tpu.TPUEstimatorSpec(
							mode=mode,
							loss=loss,
							train_op=train_op,
							scaffold_fn=scaffold_fn)

			return estimator_spec

		elif mode == tf.estimator.ModeKeys.EVAL:

			def metric_fn(masked_lm_example_loss, masked_lm_log_probs, masked_lm_ids,
					masked_lm_weights, next_sentence_example_loss,
					next_sentence_log_probs, next_sentence_labels):
				"""Computes the loss and accuracy of the model."""
				masked_lm_log_probs = tf.reshape(masked_lm_log_probs,
												 [-1, masked_lm_log_probs.shape[-1]])
				masked_lm_predictions = tf.argmax(
					masked_lm_log_probs, axis=-1, output_type=tf.int32)
				masked_lm_example_loss = tf.reshape(masked_lm_example_loss, [-1])
				masked_lm_ids = tf.reshape(masked_lm_ids, [-1])
				masked_lm_weights = tf.reshape(masked_lm_weights, [-1])
				masked_lm_accuracy = tf.metrics.accuracy(
					labels=masked_lm_ids,
					predictions=masked_lm_predictions,
					weights=masked_lm_weights)
				masked_lm_mean_loss = tf.metrics.mean(
					values=masked_lm_example_loss, weights=masked_lm_weights)

				next_sentence_log_probs = tf.reshape(
					next_sentence_log_probs, [-1, next_sentence_log_probs.shape[-1]])
				next_sentence_predictions = tf.argmax(
					next_sentence_log_probs, axis=-1, output_type=tf.int32)
				next_sentence_labels = tf.reshape(next_sentence_labels, [-1])
				next_sentence_accuracy = tf.metrics.accuracy(
					labels=next_sentence_labels, predictions=next_sentence_predictions)
				next_sentence_mean_loss = tf.metrics.mean(
					values=next_sentence_example_loss)

				return {
					"masked_lm_accuracy": masked_lm_accuracy,
					"masked_lm_loss": masked_lm_mean_loss,
					"next_sentence_accuracy": next_sentence_accuracy,
					"next_sentence_loss": next_sentence_mean_loss
					}

			eval_metrics = (metric_fn, [
			  masked_lm_example_loss, masked_lm_log_probs, masked_lm_ids,
			  masked_lm_weights, nsp_per_example_loss,
			  nsp_log_prob, features['next_sentence_labels']
			])

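			# eval_metrics is passed as a (metric_fn, tensor_list) pair; TPUEstimator
			# rebuilds the metric ops from these tensors on the evaluation host.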
			estimator_spec = tf.contrib.tpu.TPUEstimatorSpec(
						  mode=mode,
						  loss=loss,
						  eval_metrics=eval_metrics,
						  scaffold_fn=scaffold_fn)

			return estimator_spec
		else:
			raise NotImplementedError()
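
A minimal sketch of how a model_fn like the one above is typically wired into tf.contrib.tpu.TPUEstimator under TF 1.x; model_dir, batch sizes, and input_fn here are illustrative placeholders, not values from the source:

    import tensorflow as tf

    run_config = tf.contrib.tpu.RunConfig(
        model_dir="/tmp/pretrain_model",  # hypothetical output directory
        tpu_config=tf.contrib.tpu.TPUConfig(iterations_per_loop=1000))

    estimator = tf.contrib.tpu.TPUEstimator(
        use_tpu=False,  # the TPUEstimatorSpec above also runs on CPU/GPU
        model_fn=model_fn,
        config=run_config,
        train_batch_size=32,
        eval_batch_size=32)

    # input_fn must yield the features consumed by model_fn (input_ids,
    # masked_lm_positions, masked_lm_ids, next_sentence_labels, ...):
    # estimator.train(input_fn=input_fn, max_steps=num_train_steps)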
Example #17
0
File: distillation.py Project: P79N6A/BERT
    def model_fn(features, labels, mode):
        labeled_student_model = student.model_builder_fn(
            model_config_dict["student"],
            num_labels,
            init_checkpoint_dict["student"],
            model_reuse=model_reuse,
            load_pretrained=load_pretrained,
            model_io_fn=model_io_fn,
            model_io_config=model_io_config,
            opt_config=opt_config,
            input_name=student_input_name,
            temperature=temperature,
            exclude_scope=exclude_scope_dict["student"],
            not_storage_params=not_storage_params)

        [loss, per_example_loss, logits,
         temperature_log_prob] = labeled_student_model(features, labels, mode)
        train_loss = loss * distillation_weight.get("true_label", 1.0)
        tf.logging.info(" build student model ")
        if mode == tf.estimator.ModeKeys.TRAIN:
            tf.logging.info(" build teacher model for training ")
            labeled_teacher_model = teacher.model_builder_fn(
                model_config_dict["teacher"],
                num_labels,
                init_checkpoint_dict["teacher"],
                model_reuse=None,
                load_pretrained=load_pretrained,
                model_io_fn=model_io_fn,
                model_io_config=model_io_config,
                opt_config=opt_config,
                input_name=teacher_input_name,
                temperature=temperature,
                exclude_scope=exclude_scope_dict["teacher"],
                not_storage_params=not_storage_params)

            [tloss, tper_example_loss, tlogits, ttemperature_log_prob
             ] = labeled_teacher_model(features, labels, mode)

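            # Soft-label distillation: the student's temperature-scaled log-probs
            # are weighted by the teacher's probabilities; stop_gradient keeps
            # gradients from flowing into the teacher.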
            cross_entropy = temperature_log_prob * tf.stop_gradient(
                tf.exp(ttemperature_log_prob))
            print("===size of cross entropy===", cross_entropy.get_shape())
            distillation_loss = -tf.reduce_sum(cross_entropy, axis=-1)

            distillation_loss = tf.reduce_mean(distillation_loss)
            train_loss += distillation_weight["label"] * distillation_loss

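            # Optionally distill on unlabeled data as well: run both models on
            # the unlabeled inputs and add a second soft-label term.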
            if if_distill_unlabeled:
                tf.logging.info(
                    " build unlabeled student and teacher model for training ")
                unlabeled_student_model = student.model_builder_fn(
                    model_config_dict["student"],
                    num_labels,
                    init_checkpoint_dict["student"],
                    model_reuse=True,
                    load_pretrained=load_pretrained,
                    model_io_fn=model_io_fn,
                    model_io_config=model_io_config,
                    opt_config=opt_config,
                    input_name=unlabel_input_name,
                    temperature=temperature,
                    exclude_scope=exclude_scope_dict["student"],
                    not_storage_params=not_storage_params)

                [uloss, uper_example_loss, ulogits, utemperature_log_prob
                 ] = unlabeled_student_model(features, labels, mode)

                unlabeled_teacher_model = teacher.model_builder_fn(
                    model_config_dict["teacher"],
                    num_labels,
                    init_checkpoint_dict["teacher"],
                    model_reuse=True,
                    load_pretrained=load_pretrained,
                    model_io_fn=model_io_fn,
                    model_io_config=model_io_config,
                    opt_config=opt_config,
                    input_name=unlabel_input_name,
                    temperature=temperature,
                    exclude_scope=exclude_scope_dict["teacher"],
                    not_storage_params=not_storage_params)

                [utloss, utper_example_loss, utlogits, uttemperature_log_prob
                 ] = unlabeled_teacher_model(features, labels, mode)

                cross_entropy = utemperature_log_prob * tf.stop_gradient(
                    tf.exp(uttemperature_log_prob))
                unlabeled_distillation_loss = -tf.reduce_sum(cross_entropy,
                                                             axis=-1)
                unlabeled_distillation_loss = tf.reduce_mean(
                    unlabeled_distillation_loss)

                train_loss += distillation_weight[
                    "unlabel"] * unlabeled_distillation_loss

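            # The teacher is only needed (and therefore only restored) at train time.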
            teacher_pretrained_tvars = model_io_fn.get_params(
                model_config_dict["teacher"].scope,
                not_storage_params=not_storage_params)
            if load_pretrained.get("teacher", True):
                tf.logging.info(" load pre-trained teacher model ")
                model_io_fn.load_pretrained(
                    teacher_pretrained_tvars,
                    init_checkpoint_dict["teacher"],
                    exclude_scope=exclude_scope_dict["teacher"])

        student_pretrained_tvars = model_io_fn.get_params(
            model_config_dict["student"].scope,
            not_storage_params=not_storage_params)
        if load_pretrained.get("student", True):
            tf.logging.info(" load pre-trained student model ")
            model_io_fn.load_pretrained(
                student_pretrained_tvars,
                init_checkpoint_dict["student"],
                exclude_scope=exclude_scope_dict["student"])

        tvars = student_pretrained_tvars
        model_io_fn.set_saver(var_lst=tvars)

        if mode == tf.estimator.ModeKeys.TRAIN:
            model_io_fn.print_params(tvars, string=", trainable params")
            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            with tf.control_dependencies(update_ops):
                optimizer_fn = optimizer.Optimizer(opt_config)
                train_op = optimizer_fn.get_train_op(
                    train_loss, tvars, opt_config.init_lr,
                    opt_config.num_train_steps)

                return [train_op, loss, per_example_loss, logits]
        else:
            model_io_fn.print_params(tvars, string=", trainable params")
            return [loss, loss, per_example_loss, logits]
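
A standalone sketch of the soft-label distillation term used above, assuming student_logits and teacher_logits are raw classifier logits and temperature softens both distributions; all names are illustrative:

    import tensorflow as tf

    def soft_label_distillation_loss(student_logits, teacher_logits, temperature):
        # Temperature-scaled student log-probs and (frozen) teacher probs.
        student_log_prob = tf.nn.log_softmax(student_logits / temperature)
        teacher_prob = tf.stop_gradient(tf.nn.softmax(teacher_logits / temperature))
        # Cross-entropy of student predictions against teacher soft targets,
        # matching the -sum(log_prob * exp(teacher_log_prob)) form above.
        cross_entropy = -tf.reduce_sum(teacher_prob * student_log_prob, axis=-1)
        return tf.reduce_mean(cross_entropy)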