def update_weights(self, train_op):
    """Builds the op that refreshes the user-visible model weights.

    At least one worker has to run the returned op after `minimize`. In
    distributed training, non-chief workers may skip this call to speed up
    training.

    Args:
      train_op: The operation returned by the `minimize` call.

    Returns:
      An Operation that updates the model weights.
    """
    weight_groups = ('sparse_features_weights', 'dense_features_weights')

    # Step 1: gated on `train_op`, copy each unshrinked slot value into the
    # matching user-provided variable.
    with ops.control_dependencies([train_op]):
        copy_ops = [
            weight.assign(unshrinked)
            for group in weight_groups
            for weight, unshrinked in zip(self._variables[group],
                                          self._slots['unshrinked_' + group])
        ]

    # Step 2: gated on every copy above, apply the proximal step. Each
    # shrink op is created on the device of the variable it is given.
    with ops.control_dependencies(copy_ops):
        shrink_ops = []
        for group in weight_groups:
            for weight in self._variables[group]:
                with ops.device(weight.device):
                    # pylint: disable=protected-access
                    # The variable is handed over as a reference
                    # (as_ref=True) rather than as a value.
                    weight_refs = self._convert_n_to_tensor(
                        [weight], as_ref=True)
                    shrink_ops.append(
                        gen_sdca_ops.sdca_shrink_l1(
                            weight_refs,
                            l1=self._symmetric_l1_regularization(),
                            l2=self._symmetric_l2_regularization()))
        # Grouped while still under the copy dependencies.
        return control_flow_ops.group(*shrink_ops)
def update_weights(self, train_op):
    """Builds the op that refreshes the user-visible model weights.

    At least one worker has to run the returned op after `minimize`. In
    distributed training, non-chief workers may skip this call to speed up
    training.

    Args:
      train_op: The operation returned by the `minimize` call.

    Returns:
      An Operation that updates the model weights.
    """
    weight_groups = ('sparse_features_weights', 'dense_features_weights')

    # Step 1: gated on `train_op`, copy each unshrinked slot value into the
    # matching user-provided variable.
    with ops.control_dependencies([train_op]):
        copy_ops = [
            weight.assign(unshrinked)
            for group in weight_groups
            for weight, unshrinked in zip(self._variables[group],
                                          self._slots['unshrinked_' + group])
        ]

    # Step 2: gated on every copy above, apply the proximal step. Each
    # shrink op is created on the device of the variable it is given.
    with ops.control_dependencies(copy_ops):
        shrink_ops = []
        for group in weight_groups:
            for weight in self._variables[group]:
                with ops.device(weight.device):
                    # pylint: disable=protected-access
                    # The variable is handed over as a reference
                    # (as_ref=True) rather than as a value.
                    weight_refs = self._convert_n_to_tensor(
                        [weight], as_ref=True)
                    shrink_ops.append(
                        gen_sdca_ops.sdca_shrink_l1(
                            weight_refs,
                            l1=self._symmetric_l1_regularization(),
                            l2=self._symmetric_l2_regularization()))
        # Grouped while still under the copy dependencies.
        return control_flow_ops.group(*shrink_ops)