def compile_model(self, optimizer_name, optimizer_args, rdrop_alpha=None): logger.info("compiling model...") with self.get_scope(): token_output = Input(shape=(None, ), name='token_output', dtype=tf.int32) self.train_model = Model(self.nn_model.inputs + [token_output], self.nn_model.output, name="train_model") output = self.train_model.output loss_mask = Lambda( function=lambda x: tf.cast(tf.not_equal(x, 0), tf.float32), name="pred_mask")(token_output) loss_layer = build_classify_loss_layer(multi_label=False, with_mask=True) loss = loss_layer([token_output, output, loss_mask]) self.train_model.add_loss(loss) accuracy_func = masked_sparse_categorical_accuracy metric_layer = MetricLayer(accuracy_func, name="metric_layer") accuracy = metric_layer([token_output, output, loss_mask]) self.train_model.add_metric(accuracy, aggregation="mean", name="accuracy") optimizer = OptimizerFactory.create(optimizer_name, optimizer_args) self.train_model.compile(optimizer=optimizer) logger.info("training model's summary:") self.train_model.summary(print_fn=logger.info) self._update_model_dict("train", self.train_model)
def compile_model(self, optimizer_name: str, optimizer_args: dict):
    logger.info(
        f"compile model with optimizer_name:{optimizer_name}, optimizer_args:{optimizer_args}"
    )
    with self.get_scope():
        classify_output = Input(shape=(self.label_num, None, None),
                                dtype=tf.float32,
                                name='classify_output')
        token_ids, segment_ids = self.nn_model.inputs
        output = self.nn_model([token_ids, segment_ids])
        self.train_model = Model(
            inputs=[token_ids, segment_ids, classify_output],
            outputs=[output])
        loss_layer = LossLayer(loss_func=global_pointer_crossentropy,
                               name="loss_layer")
        loss = loss_layer([classify_output, output])
        self.train_model.add_loss(loss)

        accuracy_func = global_pointer_f1_score
        metric_layer = MetricLayer(accuracy_func, name="metric_layer")
        metric = metric_layer([classify_output, output])
        self.train_model.add_metric(metric,
                                    aggregation="mean",
                                    name="global_pointer_f1_score")

        optimizer = OptimizerFactory.create(optimizer_name, optimizer_args)
        self.train_model.compile(optimizer=optimizer)
        logger.info("training model's summary:")
        self.train_model.summary(print_fn=logger.info)
        self._update_model_dict("train", self.train_model)

def compile_model(self, optimizer_name, optimizer_args, rdrop_alpha=None): logger.info("compiling model...") with self.get_scope(): classify_output = Input(shape=(self.label_num,) if self.multi_label else (), name='classify_output', dtype=tf.float32) inputs = self.nn_model.inputs output = self.nn_model.output loss_input = [classify_output, output] if rdrop_alpha: output1 = self.nn_model(inputs) loss_input.append(output1) output = Lambda(function=lambda x: sum(x) / len(x), name="avg_pool_layer")([output, output1]) self.train_model = Model(inputs + [classify_output], output, name="train_model") loss_layer = build_classify_loss_layer(multi_label=self.multi_label, rdrop_alpha=rdrop_alpha) loss = loss_layer(loss_input) self.train_model.add_loss(loss) accuracy_func = binary_accuracy if self.multi_label else sparse_categorical_accuracy metric_layer = MetricLayer(accuracy_func, name="metric_layer") accuracy = metric_layer([classify_output, output]) self.train_model.add_metric(accuracy, aggregation="mean", name="accuracy") optimizer = OptimizerFactory.create(optimizer_name, optimizer_args) self.train_model.compile(optimizer=optimizer) logger.info("training model's summary:") self.train_model.summary(print_fn=logger.info) self._update_model_dict("train", self.train_model)
def compile_model(self, optimizer_name: str, optimizer_args: dict, **kwargs):
    logger.info(
        f"compile model with optimizer_name:{optimizer_name}, optimizer_args:{optimizer_args}"
    )
    with self.get_scope():
        classify_labels = Input(
            shape=(None, self.label_num) if self.multi_label else (None, ),
            name='classify_labels',
            dtype=tf.int32)
        token_ids, segment_ids = self.nn_model.inputs
        output = self.nn_model([token_ids, segment_ids])
        self.train_model = Model(
            inputs=[token_ids, segment_ids, classify_labels],
            outputs=[output])
        loss_mask = Lambda(
            function=lambda x: tf.cast(tf.not_equal(x, 0), tf.float32),
            name="pred_mask")(token_ids)

        # when computing the loss, filter out pad tokens' loss
        loss_layer = build_classify_loss_layer(multi_label=self.multi_label,
                                               with_mask=True)
        loss = loss_layer([classify_labels, output, loss_mask])
        self.train_model.add_loss(loss)

        # when computing accuracy, filter out pad tokens as well
        masked_accuracy_func = masked_binary_accuracy if self.multi_label else masked_sparse_categorical_accuracy
        metric_layer = MetricLayer(masked_accuracy_func)
        masked_accuracy = metric_layer([classify_labels, output, loss_mask])
        self.train_model.add_metric(masked_accuracy,
                                    aggregation="mean",
                                    name="accuracy")

        optimizer = OptimizerFactory.create(optimizer_name, optimizer_args)
        self.train_model.compile(optimizer=optimizer)
        logger.info("training model's summary:")
        self.train_model.summary(print_fn=logger.info)
        self._update_model_dict("train", self.train_model)

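# The masked metric functions referenced above are repo utilities whose implementation
# is not shown here. The sketch below is only an illustration of the intended behaviour
# (an assumption, not the repo's actual code): predictions are compared with labels only
# where the mask is 1, so pad positions count towards neither numerator nor denominator.
def masked_sparse_categorical_accuracy_sketch(y_true, y_pred, mask):
    import tensorflow as tf  # local import keeps the sketch self-contained
    # y_true: (batch, seq_len) int labels; y_pred: (batch, seq_len, num_classes); mask: (batch, seq_len) 0/1
    matches = tf.cast(
        tf.equal(tf.cast(y_true, tf.int64), tf.argmax(y_pred, axis=-1)),
        tf.float32)
    mask = tf.cast(mask, tf.float32)
    # average accuracy over non-pad positions only
    return tf.reduce_sum(matches * mask) / (tf.reduce_sum(mask) + 1e-7)
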
def compile_model(self, optimizer_name: str, optimizer_args: dict, **kwargs):
    logger.info(
        f"compile model with optimizer_name:{optimizer_name}, optimizer_args:{optimizer_args}"
    )
    with self.get_scope():
        input_ids, segment_ids = self.nn_model.inputs[:2]
        prob_vec = self.nn_model(self.nn_model.inputs)
        self.train_model = Model(inputs=self.nn_model.inputs, outputs=prob_vec)

        # shift by one position: the prediction at step t is scored against the token at step t + 1
        target_token_ids = Lambda(lambda x: x[:, 1:], name="target_tokens")(input_ids)
        prob_vec = Lambda(lambda x: x[:, :-1], name="prob_vec")(prob_vec)
        loss_mask = Lambda(lambda x: x[:, 1:], name="loss_mask")(segment_ids)

        loss_layer = build_classify_loss_layer(multi_label=False, with_mask=True)
        loss = loss_layer([target_token_ids, prob_vec, loss_mask])
        self.train_model.add_loss(loss)

        accuracy_func = masked_sparse_categorical_accuracy
        metric_layer = MetricLayer(accuracy_func, name="metric_layer")
        accuracy = metric_layer([target_token_ids, prob_vec, loss_mask])
        self.train_model.add_metric(accuracy, aggregation="mean", name="accuracy")

        optimizer = OptimizerFactory.create(optimizer_name, optimizer_args)
        self.train_model.compile(optimizer=optimizer)
        logger.info("training model's summary:")
        self.train_model.summary(print_fn=logger.info)
        self._update_model_dict("train", self.train_model)
        self._build_gen_model()

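# Illustration of the shift above (assumed left-to-right LM semantics: the probability
# vector at position t is scored against the token at position t + 1). Because
# loss_mask is taken from segment_ids[:, 1:], only positions whose segment id is 1
# contribute to the loss, which looks like a UniLM-style seq2seq setup where segment 1
# marks the target sequence (an interpretation, not something stated in this code).
#
#   token ids   : [CLS]  x1   x2   [SEP]  y1   y2   [SEP]
#   segment ids :   0     0    0     0     1    1     1
#   targets     :        x1   x2   [SEP]  y1   y2   [SEP]   <- input_ids[:, 1:]
#   predictions :  p0    p1   p2    p3    p4   p5            <- prob_vec[:, :-1]
#   loss mask   :   0     0    0     1     1    1            <- segment_ids[:, 1:]
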