def _tree_train_op_fn(loss):
  """Builds the op that trains the tree ensemble on `loss`.

  When `dnn_to_tree_distillation_param` is set, a weighted distillation term
  comparing the gradient-stopped DNN logits with the tree logits is added to
  `loss` before training. After the tree model's training op has run, the
  global step is incremented by one.

  This function takes only `loss` as an argument. It reads the following
  names from its enclosing scope: `dnn_to_tree_distillation_param`, `head`,
  `features`, `dnn_logits`, `tree_logits`, `gbdt_model`, `predictions_dict`,
  `labels` and `global_step`.

  Args:
    loss: The loss tensor to optimize. It is presumably a scalar; confirm
      with the caller.

  Returns:
    The `Operation` of the `assign_add` that increments `global_step`. It has
    a control dependency on the op returned by `gbdt_model.train`.
  """
  if dnn_to_tree_distillation_param:
    # The param unpacks to a (weight, loss_fn) pair; loss_fn may be None.
    distill_weight, distill_loss_fn = dnn_to_tree_distillation_param
    example_weights = head_lib._weight_tensor(  # pylint: disable=protected-access
        features, head.weight_column_name)
    # The DNN logits act as a fixed teacher signal, so no gradient flows back
    # into the DNN through the distillation term.
    teacher_logits = array_ops.stop_gradient(dnn_logits)
    if distill_loss_fn is None:
      # Default to a cross-entropy distillation loss, similar to the loss of
      # the multi-class head that was previously the default. A single logit
      # dimension is treated as two classes.
      if head.logits_dimension == 1:
        num_classes = 2
      else:
        num_classes = head.logits_dimension
      distill_loss_fn = (
          distillation_loss.create_dnn_to_tree_cross_entropy_loss_fn(
              num_classes))
    distill_term = distill_weight * distill_loss_fn(
        teacher_logits, tree_logits, example_weights)
    summary.scalar("dnn_to_tree_distillation_loss", distill_term)
    loss += distill_term

  tree_train_op = gbdt_model.train(loss, predictions_dict, labels)
  # Increment the global step only after the tree update, and colocate the
  # increment with the step variable.
  with ops.control_dependencies([tree_train_op]):
    with ops.colocate_with(global_step):
      return state_ops.assign_add(global_step, 1).op