Exemplo n.º 1
0
  def setUp(self):
    """Creates a small sharded WALS model together with expected factors.

    The fixture describes a 5 x 7 matrix factorized into 3-dimensional
    embeddings, with the rows split over 2 shards and the columns over 3
    shards. Column factors start from the fixed `_col_init` values. The
    `_row_factors_*` / `_col_factors_*` attributes hold, per shard, the factor
    values expected after running one iteration of factor updates.

    Reads `self.max_sweeps` and `self.use_cache`, which are not set here and
    are presumably provided by the concrete test class (TODO confirm).
    """
    # Problem size, sharding and regularization hyper-parameters.
    self._num_rows, self._num_cols = 5, 7
    self._embedding_dimension = 3
    self._num_row_shards, self._num_col_shards = 2, 3
    self._unobserved_weight = 0.1
    self._regularization_coeff = 0.01

    # Initial column factors, one nested list per column shard
    # (3 + 2 + 2 columns, each a vector of length `_embedding_dimension`).
    col_shard_0 = [[-0.36444709, -0.39077035, -0.32528427],
                   [1.19056475, 0.07231052, 2.11834812],
                   [0.93468881, -0.71099287, 1.91826844]]
    col_shard_1 = [[1.18160152, 1.52490723, -0.50015002],
                   [1.82574749, -0.57515913, -1.32810032]]
    col_shard_2 = [[-0.15515432, -0.84675711, 0.13097958],
                   [-0.9246484, 0.69117504, 1.2036494]]
    self._col_init = [col_shard_0, col_shard_1, col_shard_2]

    # Per-shard weights: row shards hold 3 + 2 entries, column shards
    # hold 3 + 2 + 2 entries.
    self._row_weights = [[0.1, 0.2, 0.3], [0.4, 0.5]]
    self._col_weights = [[0.1, 0.2, 0.3], [0.4, 0.5], [0.6, 0.7]]

    # Expected per-shard row factors after one iteration of factor updates.
    self._row_factors_0 = [[0.097689, -0.219293, -0.020780],
                           [0.50842, 0.64626, 0.22364],
                           [0.401159, -0.046558, -0.192854]]
    self._row_factors_1 = [[1.20597, -0.48025, 0.35582],
                           [1.5564, 1.2528, 1.0528]]

    # Expected per-shard column factors after one iteration of factor
    # updates.
    self._col_factors_0 = [[2.4725, -1.2950, -1.9980],
                           [0.44625, 1.50771, 1.27118],
                           [1.39801, -2.10134, 0.73572]]
    self._col_factors_1 = [[3.36509, -0.66595, -3.51208],
                           [0.57191, 1.59407, 1.33020]]
    self._col_factors_2 = [[3.3459, -1.3341, -3.3008],
                           [0.57366, 1.83729, 1.26798]]

    # Keyword configuration for the model. Both training caches are toggled
    # by the same `use_cache` flag.
    model_options = dict(
        col_init=self._col_init,
        regularization_coeff=self._regularization_coeff,
        num_row_shards=self._num_row_shards,
        num_col_shards=self._num_col_shards,
        row_weights=self._row_weights,
        col_weights=self._col_weights,
        max_sweeps=self.max_sweeps,
        use_factors_weights_cache_for_training=self.use_cache,
        use_gramian_cache_for_training=self.use_cache,
    )
    self._model = wals_lib.WALSMatrixFactorization(
        self._num_rows, self._num_cols, self._embedding_dimension,
        self._unobserved_weight, **model_options)
Exemplo n.º 2
0
 def testDistributedWALSUnsupported(self):
   """Checks that building the model with a multi-worker RunConfig fails.

   A TF_CONFIG describing a cluster with 2 parameter servers and 2 workers is
   patched into the environment while `RunConfig` is constructed. The WALS
   constructor is then expected to reject that config with a ValueError.
   """
   tf_config = {
       'cluster': {
           run_config_lib.TaskType.PS: ['host1:1', 'host2:2'],
           run_config_lib.TaskType.WORKER: ['host3:3', 'host4:4']
       },
       'task': {
           'type': run_config_lib.TaskType.WORKER,
           'index': 1
       }
   }
   # The environment patch only wraps RunConfig() construction; the assertion
   # below confirms the resulting config picked up the 2-worker cluster.
   with test.mock.patch.dict('os.environ',
                             {'TF_CONFIG': json.dumps(tf_config)}):
     config = run_config.RunConfig()
   self.assertEqual(config.num_worker_replicas, 2)
   # The constructor must raise, so its result is deliberately not bound to
   # anything (binding it to an attribute would be dead code and could
   # clobber fixture state if the call unexpectedly succeeded).
   with self.assertRaises(ValueError):
     wals_lib.WALSMatrixFactorization(1, 1, 1, config=config)