def setUp(self):
  """Build the small sharded WALS problem shared by the tests.

  The fixture describes a 5 x 7 matrix factorized with embedding dimension 3,
  split into 2 row shards and 3 column shards. It records:
    - the initial column factors,
    - per-shard row and column weights,
    - the expected row and column factors after one iteration of factor
      updates.
  It then constructs `self._model`.

  NOTE(review): `self.max_sweeps` and `self.use_cache` are read here but not
  defined in this method. They are presumably class attributes supplied by the
  concrete test class; confirm against the class definition.
  """
  # Problem dimensions and hyperparameters.
  self._num_rows, self._num_cols = 5, 7
  self._embedding_dimension = 3
  self._unobserved_weight = 0.1
  self._num_row_shards, self._num_col_shards = 2, 3
  self._regularization_coeff = 0.01

  # Initial column factors, one list of embeddings per column shard
  # (3 + 2 + 2 = 7 columns).
  col_shard_0 = [[-0.36444709, -0.39077035, -0.32528427],
                 [1.19056475, 0.07231052, 2.11834812],
                 [0.93468881, -0.71099287, 1.91826844]]
  col_shard_1 = [[1.18160152, 1.52490723, -0.50015002],
                 [1.82574749, -0.57515913, -1.32810032]]
  col_shard_2 = [[-0.15515432, -0.84675711, 0.13097958],
                 [-0.9246484, 0.69117504, 1.2036494]]
  self._col_init = [col_shard_0, col_shard_1, col_shard_2]

  # Per-shard weights: rows are split 3 + 2, columns 3 + 2 + 2.
  self._row_weights = [[0.1, 0.2, 0.3], [0.4, 0.5]]
  self._col_weights = [[0.1, 0.2, 0.3], [0.4, 0.5], [0.6, 0.7]]

  # Expected values of the row and column factors (per shard) after running
  # one iteration of factor updates.
  self._row_factors_0 = [[0.097689, -0.219293, -0.020780],
                         [0.50842, 0.64626, 0.22364],
                         [0.401159, -0.046558, -0.192854]]
  self._row_factors_1 = [[1.20597, -0.48025, 0.35582],
                         [1.5564, 1.2528, 1.0528]]
  self._col_factors_0 = [[2.4725, -1.2950, -1.9980],
                         [0.44625, 1.50771, 1.27118],
                         [1.39801, -2.10134, 0.73572]]
  self._col_factors_1 = [[3.36509, -0.66595, -3.51208],
                         [0.57191, 1.59407, 1.33020]]
  self._col_factors_2 = [[3.3459, -1.3341, -3.3008],
                         [0.57366, 1.83729, 1.26798]]

  # A single `use_cache` flag drives both training caches.
  model_options = dict(
      col_init=self._col_init,
      regularization_coeff=self._regularization_coeff,
      num_row_shards=self._num_row_shards,
      num_col_shards=self._num_col_shards,
      row_weights=self._row_weights,
      col_weights=self._col_weights,
      max_sweeps=self.max_sweeps,
      use_factors_weights_cache_for_training=self.use_cache,
      use_gramian_cache_for_training=self.use_cache)
  self._model = wals_lib.WALSMatrixFactorization(
      self._num_rows, self._num_cols, self._embedding_dimension,
      self._unobserved_weight, **model_options)
def testDistributedWALSUnsupported(self):
  """Check that building the model with a multi-worker RunConfig raises.

  A TF_CONFIG describing two parameter servers and two workers is patched into
  the environment while a RunConfig is created. That config is then passed to
  the WALSMatrixFactorization constructor, which is expected to raise
  ValueError.
  """
  cluster_spec = {
      run_config_lib.TaskType.PS: ['host1:1', 'host2:2'],
      run_config_lib.TaskType.WORKER: ['host3:3', 'host4:4'],
  }
  this_task = {'type': run_config_lib.TaskType.WORKER, 'index': 1}
  patched_env = {
      'TF_CONFIG': json.dumps({'cluster': cluster_spec, 'task': this_task})
  }

  with test.mock.patch.dict('os.environ', patched_env):
    config = run_config.RunConfig()
    self.assertEqual(config.num_worker_replicas, 2)

  # The model is built after the environment patch is undone. This presumably
  # relies on RunConfig capturing TF_CONFIG when it is constructed.
  with self.assertRaises(ValueError):
    self._model = wals_lib.WALSMatrixFactorization(1, 1, 1, config=config)