コード例 #1
0
 def calculate_loss_from_wals_model(self, wals_model, sp_inputs):
   """Returns the loss of `wals_model` evaluated against `sp_inputs`.

   Looks up every row factor, column factor, row weight and column weight held
   by the model, then passes them (together with the model's regularization
   and unobserved weight) to `factorization_ops_test_utils.calculate_loss`.

   Args:
     wals_model: Model object whose `row_factors`, `col_factors`,
       `_row_weights`, `_col_weights`, `_input_rows`, `_input_cols`,
       `_regularization` and `_unobserved_weight` attributes are read.
     sp_inputs: Input matrix forwarded unchanged to `calculate_loss`
       (presumably a sparse tensor, per the name — confirm against callers).

   Returns:
     Whatever `factorization_ops_test_utils.calculate_loss` returns.
   """
   def gather_all(params, count):
     # Ids 0..count-1 select every entry in order. "div" is the partition
     # strategy used if `params` is a list of shards (assumed to match how the
     # model partitioned them — TODO confirm).
     return embedding_ops.embedding_lookup(
         params, math_ops.range(count), partition_strategy="div")

   row_factors = gather_all(wals_model.row_factors, wals_model._input_rows)
   col_factors = gather_all(wals_model.col_factors, wals_model._input_cols)
   row_weights = gather_all(wals_model._row_weights, wals_model._input_rows)
   col_weights = gather_all(wals_model._col_weights, wals_model._input_cols)
   return factorization_ops_test_utils.calculate_loss(
       sp_inputs, row_factors, col_factors, wals_model._regularization,
       wals_model._unobserved_weight, row_weights, col_weights)
コード例 #2
0
ファイル: wals_test.py プロジェクト: 568xiaoma/WeChatProgram
 def calculate_loss(self):
   """Calculates the loss of the current (trained) model.

   Reads all row/column factors from `self._model` and all row/column weights
   stored on the test, converts `self.INPUT_MATRIX` with
   `self.np_array_to_sparse`, and evaluates everything with
   `factorization_ops_test_utils.calculate_loss`.

   Returns:
     The value returned by `factorization_ops_test_utils.calculate_loss`.
   """
   def gather_all(params, count):
     # Ids 0..count-1 cover every entry. 'div' is the partition strategy used
     # if `params` is sharded (assumed to match the model — TODO confirm).
     return embedding_ops.embedding_lookup(
         params, math_ops.range(count), partition_strategy='div')

   row_factors = gather_all(self._model.get_row_factors(), self._num_rows)
   col_factors = gather_all(self._model.get_col_factors(), self._num_cols)
   row_weights = gather_all(self._row_weights, self._num_rows)
   col_weights = gather_all(self._col_weights, self._num_cols)
   # Built after the lookups, matching the original call order.
   sparse_input = self.np_array_to_sparse(self.INPUT_MATRIX)
   return factorization_ops_test_utils.calculate_loss(
       sparse_input, row_factors, col_factors, self._regularization_coeff,
       self._unobserved_weight, row_weights, col_weights)
コード例 #3
0
 def calculate_loss_from_wals_model(self, wals_model, sp_inputs):
     """Evaluates `factorization_ops_test_utils.calculate_loss` for a model.

     All rows of the model's row factors and row weights (count
     `wals_model._input_rows`) and all rows of its column factors and column
     weights (count `wals_model._input_cols`) are looked up. They are then
     combined with the model's `_regularization` and `_unobserved_weight`.

     Args:
       wals_model: Model object providing the factors, weights, sizes and
         loss coefficients read below.
       sp_inputs: Input matrix passed straight through to `calculate_loss`
         (presumably sparse, per the name — confirm against callers).

     Returns:
       The result of `factorization_ops_test_utils.calculate_loss`.
     """
     def lookup_first_n(table, n):
         # Select ids 0..n-1 in order; "div" only matters if `table` is a
         # list of shards (assumed to match the model's partitioning).
         ids = math_ops.range(n)
         return embedding_ops.embedding_lookup(
             table, ids, partition_strategy="div")

     rows = lookup_first_n(wals_model.row_factors, wals_model._input_rows)
     cols = lookup_first_n(wals_model.col_factors, wals_model._input_cols)
     row_w = lookup_first_n(wals_model._row_weights, wals_model._input_rows)
     col_w = lookup_first_n(wals_model._col_weights, wals_model._input_cols)
     return factorization_ops_test_utils.calculate_loss(
         sp_inputs, rows, cols, wals_model._regularization,
         wals_model._unobserved_weight, row_w, col_w)