def _run_test_als_transposed(self, use_factors_weights_cache):
  with self.test_session():
    col_init = np.random.rand(7, 3)
    als_model = factorization_ops.WALSModel(
        5,
        7,
        3,
        col_init=col_init,
        row_weights=None,
        col_weights=None,
        use_factors_weights_cache=use_factors_weights_cache)
    als_model.initialize_op.run()
    als_model.worker_init.run()

    wals_model = factorization_ops.WALSModel(
        5,
        7,
        3,
        col_init=col_init,
        row_weights=[0] * 5,
        col_weights=[0] * 7,
        use_factors_weights_cache=use_factors_weights_cache)
    wals_model.initialize_op.run()
    wals_model.worker_init.run()
    sp_feeder = tf.sparse_placeholder(tf.float32)
    # Here we test a partial row update with identical inputs, but with the
    # input transposed for the ALS model.
    sp_r_t = np_matrix_to_tf_sparse(INPUT_MATRIX, [3, 1], transpose=True).eval()
    sp_r = np_matrix_to_tf_sparse(INPUT_MATRIX, [3, 1]).eval()

    feed_dict = {sp_feeder: sp_r_t}
    als_model.row_update_prep_gramian_op.run()
    als_model.initialize_row_update_op.run()
    process_input_op = als_model.update_row_factors(
        sp_input=sp_feeder, transpose_input=True)[1]
    process_input_op.run(feed_dict=feed_dict)
    # Only rows 1 and 3 were updated, so compare only these rows since the
    # others still hold randomly initialized values.
    row_factors1 = [
        als_model.row_factors[0].eval()[1],
        als_model.row_factors[0].eval()[3]
    ]

    feed_dict = {sp_feeder: sp_r}
    wals_model.row_update_prep_gramian_op.run()
    wals_model.initialize_row_update_op.run()
    process_input_op = wals_model.update_row_factors(sp_input=sp_feeder)[1]
    process_input_op.run(feed_dict=feed_dict)
    # Only rows 1 and 3 were updated, so compare only these rows.
    row_factors2 = [
        wals_model.row_factors[0].eval()[1],
        wals_model.row_factors[0].eval()[3]
    ]
    for r1, r2 in zip(row_factors1, row_factors2):
      self.assertAllClose(r1, r2, atol=1e-3)
def test_als(self):
  with self.test_session():
    col_init = np.random.rand(7, 3)
    als_model = factorization_ops.WALSModel(
        5, 7, 3, col_init=col_init, row_weights=None, col_weights=None)
    als_model.initialize_op.run()
    als_model.worker_init.run()
    als_model.initialize_row_update_op.run()
    process_input_op = als_model.update_row_factors(self._wals_inputs)[1]
    process_input_op.run()
    row_factors1 = [x.eval() for x in als_model.row_factors]

    wals_model = factorization_ops.WALSModel(
        5, 7, 3, col_init=col_init, row_weights=[0] * 5, col_weights=[0] * 7)
    wals_model.initialize_op.run()
    wals_model.worker_init.run()
    wals_model.initialize_row_update_op.run()
    process_input_op = wals_model.update_row_factors(self._wals_inputs)[1]
    process_input_op.run()
    row_factors2 = [x.eval() for x in wals_model.row_factors]

    for r1, r2 in zip(row_factors1, row_factors2):
      self.assertAllClose(r1, r2, atol=1e-3)

    # Here we test partial column updates.
    sp_c = np_matrix_to_tf_sparse(
        INPUT_MATRIX, col_slices=[2, 0], shuffle=True).eval()

    sp_feeder = tf.sparse_placeholder(tf.float32)
    feed_dict = {sp_feeder: sp_c}
    als_model.initialize_col_update_op.run()
    process_input_op = als_model.update_col_factors(sp_input=sp_feeder)[1]
    process_input_op.run(feed_dict=feed_dict)
    col_factors1 = [x.eval() for x in als_model.col_factors]

    feed_dict = {sp_feeder: sp_c}
    wals_model.initialize_col_update_op.run()
    process_input_op = wals_model.update_col_factors(sp_input=sp_feeder)[1]
    process_input_op.run(feed_dict=feed_dict)
    col_factors2 = [x.eval() for x in wals_model.col_factors]

    for c1, c2 in zip(col_factors1, col_factors2):
      self.assertAllClose(c1, c2, rtol=5e-3, atol=1e-2)
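# The tests above and below rely on two module-level fixtures that are not
# shown in this section: INPUT_MATRIX (a small dense ratings matrix) and
# np_matrix_to_tf_sparse. Below is a minimal sketch of the helper, assuming
# it slices whole rows or whole columns of the non-zero entries into a
# SparseTensor, optionally shuffled and/or transposed; the original helper
# may differ in detail.
def np_matrix_to_tf_sparse(np_matrix, row_slices=None, col_slices=None,
                           shuffle=False, transpose=False):
  """Converts selected rows/columns of a dense np matrix to a SparseTensor."""
  rows, cols = np.nonzero(np_matrix)
  if row_slices is not None:
    keep = np.concatenate([np.where(rows == r)[0] for r in row_slices])
    rows, cols = rows[keep], cols[keep]
  if col_slices is not None:
    keep = np.concatenate([np.where(cols == c)[0] for c in col_slices])
    rows, cols = rows[keep], cols[keep]
  if shuffle:
    order = np.random.permutation(len(rows))
    rows, cols = rows[order], cols[order]
  values = np_matrix[rows, cols].astype(np.float32)
  if transpose:
    rows, cols = cols, rows
    shape = [np_matrix.shape[1], np_matrix.shape[0]]
  else:
    shape = list(np_matrix.shape)
  indices = np.stack([rows, cols], axis=1).astype(np.int64)
  return tf.SparseTensor(indices, values, shape)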
def factorize(self, indices, values):
    import tensorflow as tf
    from tensorflow.contrib.factorization.python.ops import factorization_ops
    from tensorflow.python.framework import sparse_tensor

    rows = self.nb_users
    cols = self.nb_works
    dims = self.nb_components
    row_wts = 0.1 + np.random.rand(rows)
    col_wts = 0.1 + np.random.rand(cols)
    inp = sparse_tensor.SparseTensor(indices, values, [rows, cols])
    use_factors_weights_cache = True
    model = factorization_ops.WALSModel(
        rows, cols, dims,
        unobserved_weight=1,  # .1,
        regularization=0.001,  # 001,
        row_weights=None,  # row_wts,
        col_weights=None,  # col_wts,
        use_factors_weights_cache=use_factors_weights_cache)
    # An InteractiveSession installs itself as the default session, so the
    # .eval() calls below work without an explicit session argument.
    tf.InteractiveSession()
    simple_train(model, inp, 25)
    row_factor = model.row_factors[0].eval()
    self.U = row_factor
    col_factor = model.col_factors[0].eval()
    self.V = col_factor
def factorize(self, indices, values):
    from tensorflow.contrib.factorization.python.ops import factorization_ops
    from tensorflow.python.framework import sparse_tensor

    rows = self.nb_users
    cols = self.nb_works
    dims = self.NB_COMPONENTS
    row_wts = 0.1 + np.random.rand(rows)
    col_wts = 0.1 + np.random.rand(cols)
    inp = sparse_tensor.SparseTensor(indices, values, [rows, cols])
    use_factors_weights_cache = True
    model = factorization_ops.WALSModel(
        rows, cols, dims,
        unobserved_weight=1,  # .1,
        regularization=0.001,  # 001,
        row_weights=None,  # row_wts,
        col_weights=None,  # col_wts,
        use_factors_weights_cache=use_factors_weights_cache)
    simple_train(model, inp, 25)
    # Note: .eval() requires an active default session here (e.g. an
    # InteractiveSession as in the variant above, or one opened inside
    # simple_train).
    row_factor = model.row_factors[0].eval()
    print('Shape', row_factor.shape)
    col_factor = model.col_factors[0].eval()
    print('Shape', col_factor.shape)
    out = np.dot(row_factor, np.transpose(col_factor))
    return out
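# simple_train is called in both factorize variants above but is not defined
# in this section. A plausible sketch, assuming it runs num_iterations
# alternating row/column sweeps against the default session (the loop mirrors
# the explicit one in train_model further below); the original helper may
# differ.
def simple_train(model, inp, num_iterations):
    """Runs alternating row/column WALS sweeps on the default session."""
    row_update_op = model.update_row_factors(sp_input=inp)[1]
    col_update_op = model.update_col_factors(sp_input=inp)[1]

    model.initialize_op.run()
    model.worker_init.run()
    for _ in range(num_iterations):
        # Row sweep: refresh the column Gramian, then solve for row factors.
        model.row_update_prep_gramian_op.run()
        model.initialize_row_update_op.run()
        row_update_op.run()
        # Column sweep: refresh the row Gramian, then solve for col factors.
        model.col_update_prep_gramian_op.run()
        model.initialize_col_update_op.run()
        col_update_op.run()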
def test_train_matrix_completion_wals(self):
  rows = 11
  cols = 9
  dims = 4

  def keep_index(x):
    return not (x[0] + x[1]) % 4

  with self.test_session():
    row_wts = 0.1 + np.random.rand(rows)
    col_wts = 0.1 + np.random.rand(cols)
    data = np.dot(np.random.rand(rows, 3),
                  np.random.rand(3, cols)).astype(np.float32) / 3.0
    indices = np.array(
        list(
            filter(keep_index,
                   [[i, j] for i in xrange(rows) for j in xrange(cols)])))
    values = data[indices[:, 0], indices[:, 1]]
    inp = tf.SparseTensor(indices, values, [rows, cols])
    model = factorization_ops.WALSModel(
        rows,
        cols,
        dims,
        unobserved_weight=0.01,
        regularization=0.001,
        row_weights=row_wts,
        col_weights=col_wts)
    self.simple_train(model, inp, 10)
    row_factor = model.row_factors[0].eval()
    col_factor = model.col_factors[0].eval()
    out = np.dot(row_factor, np.transpose(col_factor))
    for i in xrange(rows):
      for j in xrange(cols):
        if keep_index([i, j]):
          self.assertNear(
              data[i][j], out[i][j], err=0.2, msg="%d, %d" % (i, j))
        else:
          self.assertNear(0, out[i][j], err=0.5, msg="%d, %d" % (i, j))
def _run_test_train_full_low_rank_wals(self, use_factors_weights_cache):
  rows = 15
  cols = 11
  dims = 3

  with ops.Graph().as_default(), self.test_session():
    data = np.dot(np.random.rand(rows, 3),
                  np.random.rand(3, cols)).astype(np.float32) / 3.0
    indices = [[i, j] for i in xrange(rows) for j in xrange(cols)]
    values = data.reshape(-1)
    inp = sparse_tensor.SparseTensor(indices, values, [rows, cols])
    model = factorization_ops.WALSModel(
        rows,
        cols,
        dims,
        regularization=1e-5,
        row_weights=0,
        col_weights=[0] * cols,
        use_factors_weights_cache=use_factors_weights_cache)
    self.simple_train(model, inp, 25)
    row_factor = model.row_factors[0].eval()
    col_factor = model.col_factors[0].eval()
    self.assertAllClose(
        data,
        np.dot(row_factor, np.transpose(col_factor)),
        rtol=0.01,
        atol=0.01)
def wals_model(data, dim, reg, unobs, weights=False,
               wt_type=LINEAR_RATINGS, feature_wt_exp=None,
               obs_wt=LINEAR_OBS_W):
  """Create the WALSModel and input, row and col factor tensors.

  Args:
    data: scipy coo_matrix of item ratings
    dim: number of latent factors
    reg: regularization constant
    unobs: unobserved item weight
    weights: True: set obs weights, False: obs weights = unobs weights
    wt_type: feature weight type: linear (0) or log (1)
    feature_wt_exp: feature weight exponent constant
    obs_wt: feature weight linear factor constant

  Returns:
    input_tensor: tensor holding the input ratings matrix
    row_factor: tensor for row_factor
    col_factor: tensor for col_factor
    model: WALSModel instance
  """
  row_wts = 1
  col_wts = 1

  num_rows = data.shape[0]
  num_cols = data.shape[1]

  if weights:
    assert feature_wt_exp is not None
    row_wts = np.ones(num_rows)
    col_wts = make_wts(data, wt_type, obs_wt, feature_wt_exp, 0)

  row_factor = None
  col_factor = None

  with tf.Graph().as_default():
    input_tensor = tf.SparseTensor(
        indices=np.array([data.row, data.col]).T,
        values=(data.data).astype(np.float32),
        dense_shape=data.shape)

    model = factorization_ops.WALSModel(
        num_rows,
        num_cols,
        dim,
        unobserved_weight=unobs,
        regularization=reg,
        num_row_shards=1,  # number of shards to use for row factors
        num_col_shards=1,  # number of shards to use for column factors
        row_weights=row_wts,
        col_weights=col_wts)

    # retrieve the row and column factors
    row_factor = model.row_factors[0]
    col_factor = model.col_factors[0]

  return input_tensor, row_factor, col_factor, model
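# make_wts (called above) is not defined in this section. A minimal sketch
# consistent with the docstring above: weights are derived from the count of
# observed ratings along the given axis, either scaled linearly by obs_wt or
# raised to feature_wt_exp for log-distributed ratings. LINEAR_RATINGS and
# LOG_RATINGS are assumed to be module-level constants (0 and 1); the
# original helper may normalize differently.
def make_wts(data, wt_type, obs_wt, feature_wt_exp, axis):
  # Reciprocal of the number of observed ratings along `axis`.
  frac = np.array(1.0 / (data > 0.0).sum(axis))
  # Rows/cols with no ratings produce inf above; zero them out.
  frac[np.ma.masked_invalid(frac).mask] = 0.0
  if wt_type == LOG_RATINGS:
    wts = np.array(np.power(frac, feature_wt_exp)).flatten()
  else:
    wts = np.array(obs_wt * frac).flatten()
  # Guard against any remaining numerically unstable entries.
  assert np.isfinite(wts).sum() == wts.shape[0]
  return wts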
def define_graph(data, PARAMS):
    graph = tf.Graph()
    with graph.as_default():
        input_tensor = tf.SparseTensor(
            indices=np.array([data.row, data.col]).T,
            values=(data.data).astype(np.float32),
            dense_shape=data.shape)

        row_wts = None
        col_wts = None
        num_rows = data.shape[0]
        num_cols = data.shape[1]

        # initialize the weights
        if PARAMS["wt_type"] in ["LOG_RATINGS", "LINEAR_RATINGS"]:
            row_wts = np.ones(num_rows)
            col_wts = make_weights(data, PARAMS["wt_type"],
                                   PARAMS['feature_wt_factor'],
                                   PARAMS['feature_wt_exp'], axis=0)

        model = factorization_ops.WALSModel(
            num_rows,
            num_cols,
            PARAMS["latent_factors"],
            unobserved_weight=PARAMS["unobs_weight"],
            regularization=PARAMS["regularization"],
            row_weights=row_wts,
            col_weights=col_wts)
    return graph, model, input_tensor
def train_model(data):
    row_factor, col_factor = None, None
    wt_type = 0
    num_rows, num_cols = data.shape
    # row_wts = np.ones(num_rows)
    # col_wts = None
    # a = (data > 0.0).sum(0)
    # print(a.shape)
    #
    # times_rated = np.array((data > 0).sum(0))
    # print(times_rated)
    # if wt_type == implicit:
    #     frac = []
    #     for i in times_rated:
    #         if i != 0:
    #             frac.append(1.0 / i)
    #         else:
    #             frac.append(0.0)
    #     col_wts = np.array(np.power(frac, 0.08)).flatten()
    # else:
    #     col_wts = np.array(100 * times_rated).flatten()

    with tf.Graph().as_default():
        input_tensor = tf.SparseTensor(
            indices=list(zip(data.row, data.col)),
            values=(data.data).astype(np.float32),
            dense_shape=data.shape)
        model = factorization_ops.WALSModel(
            num_rows,
            num_cols,
            n_components=10,
            unobserved_weight=0.001,
            regularization=0.08,
            row_weights=None,
            col_weights=None)
        row_factor = model.row_factors[0]
        col_factor = model.col_factors[0]

    sess = tf.Session(graph=input_tensor.graph)
    with input_tensor.graph.as_default():
        row_update_op = model.update_row_factors(sp_input=input_tensor)[1]
        col_update_op = model.update_col_factors(sp_input=input_tensor)[1]

        sess.run(model.initialize_op)
        sess.run(model.worker_init)
        for _ in range(20):
            sess.run(model.row_update_prep_gramian_op)
            sess.run(model.initialize_row_update_op)
            sess.run(row_update_op)
            sess.run(model.col_update_prep_gramian_op)
            sess.run(model.initialize_col_update_op)
            sess.run(col_update_op)

    output_row = row_factor.eval(session=sess)
    output_col = col_factor.eval(session=sess)
    sess.close()
    return output_row, output_col
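# A short usage sketch for train_model above, assuming a scipy.sparse ratings
# matrix in COO format (the toy data here is illustrative only):
from scipy.sparse import coo_matrix

ratings = coo_matrix(
    ([4.0, 5.0, 3.0], ([0, 1, 2], [1, 0, 2])), shape=(3, 3))
row_factor, col_factor = train_model(ratings)
# Reconstruct the (approximate) dense ratings matrix from the factors.
approx_ratings = np.dot(row_factor, col_factor.T)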
def _run_test_sum_weights(self, test_rows):
  # test_rows: True to test row weights, False to test column weights.
  num_rows = 5
  num_cols = 5
  unobserved_weight = 0.1
  row_weights = [[8., 18., 28., 38., 48.]]
  col_weights = [[90., 91., 92., 93., 94.]]
  sparse_indices = [[0, 1], [2, 3], [4, 1]]
  sparse_values = [666., 777., 888.]

  unobserved = unobserved_weight * num_rows * num_cols
  observed = 8. * 91. + 28. * 93. + 48. * 91.
  # sparse_indices has three unique rows and two unique columns
  observed *= num_rows / 3. if test_rows else num_cols / 2.
  want_weight_sum = unobserved + observed

  with ops.Graph().as_default(), self.test_session() as sess:
    wals_model = factorization_ops.WALSModel(
        input_rows=num_rows,
        input_cols=num_cols,
        n_components=5,
        unobserved_weight=unobserved_weight,
        row_weights=row_weights,
        col_weights=col_weights,
        use_factors_weights_cache=False)

    wals_model.initialize_op.run()
    wals_model.worker_init.run()

    update_factors = (wals_model.update_row_factors
                      if test_rows else wals_model.update_col_factors)

    (_, _, _, _, sum_weights) = update_factors(
        sp_input=sparse_tensor.SparseTensor(
            indices=sparse_indices,
            values=sparse_values,
            dense_shape=[num_rows, num_cols]),
        transpose_input=False)

    got_weight_sum = sess.run(sum_weights)

    self.assertNear(
        got_weight_sum,
        want_weight_sum,
        err=.001,
        msg="got weight sum [{}], want weight sum [{}]".format(
            got_weight_sum, want_weight_sum))
def test_train_full_low_rank_als(self):
  rows = 15
  cols = 11
  dims = 3
  with self.test_session():
    data = np.dot(np.random.rand(rows, 3),
                  np.random.rand(3, cols)).astype(np.float32) / 3.0
    indices = [[i, j] for i in xrange(rows) for j in xrange(cols)]
    values = data.reshape(-1)
    inp = tf.SparseTensor(indices, values, [rows, cols])
    model = factorization_ops.WALSModel(
        rows, cols, dims, regularization=1e-5,
        row_weights=None, col_weights=None)
    self.simple_train(model, inp, 15)
    row_factor = model.row_factors[0].eval()
    col_factor = model.col_factors[0].eval()
    self.assertAllClose(
        data,
        np.dot(row_factor, np.transpose(col_factor)),
        rtol=0.01,
        atol=0.01)
def wals_model(data, latent_factor, unobserved_weight, reg, use_weight,
               weight_type, feature_weight_exp, feature_weight_lin):
  row_weights = None
  col_weights = None

  num_rows = data.shape[0]
  num_cols = data.shape[1]

  if use_weight:
    assert feature_weight_exp is not None
    row_weights = np.ones(num_rows)
    col_weights = make_weights(data, weight_type, feature_weight_lin,
                               feature_weight_exp, 0)

  row_factor = None
  col_factor = None

  with tf.Graph().as_default():
    input_tensor = tf.SparseTensor(
        indices=list(zip(data.row, data.col)),
        values=(data.data).astype(np.float32),
        dense_shape=data.shape)

    model = factorization_ops.WALSModel(
        num_rows,
        num_cols,
        latent_factor,
        unobserved_weight=unobserved_weight,
        regularization=reg,
        row_weights=row_weights,
        col_weights=col_weights)

    row_factor = model.row_factors[0]
    col_factor = model.col_factors[0]

  return input_tensor, row_factor, col_factor, model
def _run_test_process_input(self, use_factors_weights_cache):
  with self.test_session():
    sp_feeder = tf.sparse_placeholder(tf.float32)
    wals_model = factorization_ops.WALSModel(
        5,
        7,
        3,
        num_row_shards=2,
        num_col_shards=3,
        regularization=0.01,
        unobserved_weight=0.1,
        col_init=self.col_init,
        row_weights=self.row_wts,
        col_weights=self.col_wts,
        use_factors_weights_cache=use_factors_weights_cache)
    wals_model.initialize_op.run()
    wals_model.worker_init.run()

    # Split input into multiple sparse tensors with scattered rows. Note that
    # this split can be different than the factor sharding and the inputs can
    # consist of non-consecutive rows. Each row needs to include all non-zero
    # elements in that row.
    sp_r0 = np_matrix_to_tf_sparse(INPUT_MATRIX, [0, 2]).eval()
    sp_r1 = np_matrix_to_tf_sparse(INPUT_MATRIX, [1, 4], shuffle=True).eval()
    sp_r2 = np_matrix_to_tf_sparse(INPUT_MATRIX, [3], shuffle=True).eval()
    input_scattered_rows = [sp_r0, sp_r1, sp_r2]

    # Test updating row factors.
    # Here we feed in scattered rows of the input.
    wals_model.row_update_prep_gramian_op.run()
    wals_model.initialize_row_update_op.run()
    process_input_op = wals_model.update_row_factors(
        sp_input=sp_feeder, transpose_input=False)[1]
    for inp in input_scattered_rows:
      feed_dict = {sp_feeder: inp}
      process_input_op.run(feed_dict=feed_dict)
    row_factors = [x.eval() for x in wals_model.row_factors]

    self.assertAllClose(row_factors[0], self._row_factors_0, atol=1e-3)
    self.assertAllClose(row_factors[1], self._row_factors_1, atol=1e-3)

    # Split input into multiple sparse tensors with scattered columns. Note
    # that here the elements in the sparse tensors are not ordered and also
    # do not need to consist of consecutive columns. However, each column
    # needs to include all non-zero elements in that column.
    sp_c0 = np_matrix_to_tf_sparse(INPUT_MATRIX, col_slices=[2, 0]).eval()
    sp_c1 = np_matrix_to_tf_sparse(
        INPUT_MATRIX, col_slices=[5, 3, 1], shuffle=True).eval()
    sp_c2 = np_matrix_to_tf_sparse(INPUT_MATRIX, col_slices=[4, 6]).eval()
    sp_c3 = np_matrix_to_tf_sparse(
        INPUT_MATRIX, col_slices=[3, 6], shuffle=True).eval()
    input_scattered_cols = [sp_c0, sp_c1, sp_c2, sp_c3]

    # Test updating column factors.
    # Here we feed in scattered columns of the input.
    wals_model.col_update_prep_gramian_op.run()
    wals_model.initialize_col_update_op.run()
    process_input_op = wals_model.update_col_factors(
        sp_input=sp_feeder, transpose_input=False)[1]
    for inp in input_scattered_cols:
      feed_dict = {sp_feeder: inp}
      process_input_op.run(feed_dict=feed_dict)
    col_factors = [x.eval() for x in wals_model.col_factors]

    self.assertAllClose(col_factors[0], self._col_factors_0, atol=1e-3)
    self.assertAllClose(col_factors[1], self._col_factors_1, atol=1e-3)
    self.assertAllClose(col_factors[2], self._col_factors_2, atol=1e-3)
def _run_test_als_transposed(self, use_factors_weights_cache):
  with ops.Graph().as_default(), self.test_session():
    self._wals_inputs = self.sparse_input()
    col_init = np.random.rand(7, 3)
    als_model = factorization_ops.WALSModel(
        5,
        7,
        3,
        col_init=col_init,
        row_weights=None,
        col_weights=None,
        use_factors_weights_cache=use_factors_weights_cache)
    als_model.initialize_op.run()
    als_model.worker_init.run()

    wals_model = factorization_ops.WALSModel(
        5,
        7,
        3,
        col_init=col_init,
        row_weights=[0] * 5,
        col_weights=[0] * 7,
        use_factors_weights_cache=use_factors_weights_cache)
    wals_model.initialize_op.run()
    wals_model.worker_init.run()
    sp_feeder = array_ops.sparse_placeholder(dtypes.float32)
    # Here we test a partial row update with identical inputs, but with the
    # input transposed for the ALS model.
    sp_r_t = np_matrix_to_tf_sparse(INPUT_MATRIX, [3, 1], transpose=True).eval()
    sp_r = np_matrix_to_tf_sparse(INPUT_MATRIX, [3, 1]).eval()

    feed_dict = {sp_feeder: sp_r_t}
    als_model.row_update_prep_gramian_op.run()
    als_model.initialize_row_update_op.run()
    process_input_op = als_model.update_row_factors(
        sp_input=sp_feeder, transpose_input=True)[1]
    process_input_op.run(feed_dict=feed_dict)
    # Only rows 1 and 3 were updated, so compare only these rows since the
    # others still hold randomly initialized values.
    row_factors1 = [
        als_model.row_factors[0].eval()[1],
        als_model.row_factors[0].eval()[3]
    ]
    # Test row projection. The projection weight doesn't matter in this case
    # since the model is the ALS special case. Note that the ordering of the
    # returned results matches the ordering of the input feature vectors.
    als_projected_row_factors1 = als_model.project_row_factors(
        sp_input=sp_feeder, transpose_input=True).eval(feed_dict=feed_dict)

    feed_dict = {sp_feeder: sp_r}
    wals_model.row_update_prep_gramian_op.run()
    wals_model.initialize_row_update_op.run()
    process_input_op = wals_model.update_row_factors(sp_input=sp_feeder)[1]
    process_input_op.run(feed_dict=feed_dict)
    # Only rows 1 and 3 were updated, so compare only these rows.
    row_factors2 = [
        wals_model.row_factors[0].eval()[1],
        wals_model.row_factors[0].eval()[3]
    ]
    for r1, r2 in zip(row_factors1, row_factors2):
      self.assertAllClose(r1, r2, atol=1e-3)
    # The input fed rows [3, 1], so the projection results come back in that
    # order: factors for row 3 first, then row 1.
    self.assertAllClose(
        als_projected_row_factors1, [row_factors2[1], row_factors2[0]],
        atol=1e-3)
def _run_test_als(self, use_factors_weights_cache):
  with ops.Graph().as_default(), self.test_session():
    self._wals_inputs = self.sparse_input()
    col_init = np.random.rand(7, 3)
    als_model = factorization_ops.WALSModel(
        5,
        7,
        3,
        col_init=col_init,
        row_weights=None,
        col_weights=None,
        use_factors_weights_cache=use_factors_weights_cache)
    als_model.initialize_op.run()
    als_model.worker_init.run()
    als_model.row_update_prep_gramian_op.run()
    als_model.initialize_row_update_op.run()
    process_input_op = als_model.update_row_factors(self._wals_inputs)[1]
    process_input_op.run()
    row_factors1 = [x.eval() for x in als_model.row_factors]
    # Test row projection. The projection weight doesn't matter in this case
    # since the model is the ALS special case.
    als_projected_row_factors1 = als_model.project_row_factors(
        self._wals_inputs).eval()

    wals_model = factorization_ops.WALSModel(
        5,
        7,
        3,
        col_init=col_init,
        row_weights=0,
        col_weights=0,
        use_factors_weights_cache=use_factors_weights_cache)
    wals_model.initialize_op.run()
    wals_model.worker_init.run()
    wals_model.row_update_prep_gramian_op.run()
    wals_model.initialize_row_update_op.run()
    process_input_op = wals_model.update_row_factors(self._wals_inputs)[1]
    process_input_op.run()
    row_factors2 = [x.eval() for x in wals_model.row_factors]

    for r1, r2 in zip(row_factors1, row_factors2):
      self.assertAllClose(r1, r2, atol=1e-3)
    self.assertAllClose(
        als_projected_row_factors1,
        [row for shard in row_factors2 for row in shard],
        atol=1e-3)

    # Here we test partial column updates.
    sp_c = np_matrix_to_tf_sparse(
        INPUT_MATRIX, col_slices=[2, 0], shuffle=True).eval()

    sp_feeder = array_ops.sparse_placeholder(dtypes.float32)
    feed_dict = {sp_feeder: sp_c}
    als_model.col_update_prep_gramian_op.run()
    als_model.initialize_col_update_op.run()
    process_input_op = als_model.update_col_factors(sp_input=sp_feeder)[1]
    process_input_op.run(feed_dict=feed_dict)
    col_factors1 = [x.eval() for x in als_model.col_factors]
    # Test column projection. The projection weight doesn't matter in this
    # case since the model is the ALS special case.
    als_projected_col_factors1 = als_model.project_col_factors(
        np_matrix_to_tf_sparse(
            INPUT_MATRIX, col_slices=[2, 0], shuffle=False)).eval()

    feed_dict = {sp_feeder: sp_c}
    wals_model.col_update_prep_gramian_op.run()
    wals_model.initialize_col_update_op.run()
    process_input_op = wals_model.update_col_factors(sp_input=sp_feeder)[1]
    process_input_op.run(feed_dict=feed_dict)
    col_factors2 = [x.eval() for x in wals_model.col_factors]

    for c1, c2 in zip(col_factors1, col_factors2):
      self.assertAllClose(c1, c2, rtol=5e-3, atol=1e-2)
    # The projection input used columns [2, 0], so compare in that order.
    self.assertAllClose(
        als_projected_col_factors1,
        [col_factors2[0][2], col_factors2[0][0]],
        atol=1e-2)
def _run_test_process_input_transposed(self, use_factors_weights_cache,
                                       compute_loss=False):
  with ops.Graph().as_default(), self.test_session() as sess:
    self._wals_inputs = self.sparse_input()
    sp_feeder = array_ops.sparse_placeholder(dtypes.float32)
    num_rows = 5
    num_cols = 7
    factor_dim = 3
    wals_model = factorization_ops.WALSModel(
        num_rows,
        num_cols,
        factor_dim,
        num_row_shards=2,
        num_col_shards=3,
        regularization=0.01,
        unobserved_weight=0.1,
        col_init=self.col_init,
        row_weights=self.row_wts,
        col_weights=self.col_wts,
        use_factors_weights_cache=use_factors_weights_cache)
    wals_model.initialize_op.run()
    wals_model.worker_init.run()

    # Split the input into multiple SparseTensors with scattered rows.
    # Here the inputs are transposed, but the same constraints as described
    # in the previous non-transposed test case apply to these inputs (before
    # they are transposed).
    sp_r0_t = np_matrix_to_tf_sparse(INPUT_MATRIX, [0, 3],
                                     transpose=True).eval()
    sp_r1_t = np_matrix_to_tf_sparse(
        INPUT_MATRIX, [4, 1], shuffle=True, transpose=True).eval()
    sp_r2_t = np_matrix_to_tf_sparse(INPUT_MATRIX, [2], transpose=True).eval()
    sp_r3_t = sp_r1_t
    input_scattered_rows = [sp_r0_t, sp_r1_t, sp_r2_t, sp_r3_t]
    input_scattered_rows_non_duplicate = [sp_r0_t, sp_r1_t, sp_r2_t]

    # Test updating row factors.
    # Here we feed in scattered rows of the input.
    # Note that the placeholder suffixes follow the lexicographical order of
    # the test case names, and then the line order in which they appear.
    wals_model.row_update_prep_gramian_op.run()
    wals_model.initialize_row_update_op.run()
    (_, process_input_op, unregularized_loss, regularization,
     _) = wals_model.update_row_factors(
         sp_input=sp_feeder, transpose_input=True)
    factor_loss = unregularized_loss + regularization
    for inp in input_scattered_rows:
      feed_dict = {sp_feeder: inp}
      process_input_op.run(feed_dict=feed_dict)
    row_factors = [x.eval() for x in wals_model.row_factors]

    self.assertAllClose(row_factors[0], self._row_factors_0, atol=1e-3)
    self.assertAllClose(row_factors[1], self._row_factors_1, atol=1e-3)

    # Test row projection, using the specified projection weights for the
    # two row feature vectors. This is expected to reproduce the same row
    # factors in the model as the weights and feature vectors are identical
    # to those used in model training.
    projected_rows = wals_model.project_row_factors(
        sp_input=sp_feeder, transpose_input=True,
        projection_weights=[0.5, 0.2])
    # Don't specify the projection weights, so 1.0 will be used. The feature
    # weights will be those specified in the model.
    projected_rows_no_weights = wals_model.project_row_factors(
        sp_input=sp_feeder, transpose_input=True)
    feed_dict = {
        sp_feeder:
            np_matrix_to_tf_sparse(
                INPUT_MATRIX, [4, 1], shuffle=False, transpose=True).eval()
    }
    self.assertAllClose(
        projected_rows.eval(feed_dict=feed_dict),
        [self._row_factors_1[1], self._row_factors_0[1]],
        atol=1e-3)
    self.assertAllClose(
        projected_rows_no_weights.eval(feed_dict=feed_dict),
        [[1.915879, 1.992677, 1.109057], [0.569082, 0.715088, 0.31777]],
        atol=1e-3)

    if compute_loss:
      # Test loss computation after the row update.
      loss = sum(
          sess.run(
              factor_loss * self.count_cols(inp) / num_rows,
              feed_dict={sp_feeder: inp})
          for inp in input_scattered_rows_non_duplicate)
      true_loss = self.calculate_loss_from_wals_model(wals_model,
                                                      self._wals_inputs)
      self.assertNear(
          loss,
          true_loss,
          err=.001,
          msg="After row update, computed loss [{}] does not match"
          " true loss [{}]".format(loss, true_loss))

    # Split the input into multiple SparseTensors with scattered columns.
    # Here the inputs are transposed, but the same constraints as described
    # in the previous non-transposed test case apply to these inputs (before
    # they are transposed).
    sp_c0_t = np_matrix_to_tf_sparse(
        INPUT_MATRIX, col_slices=[0, 1], transpose=True).eval()
    sp_c1_t = np_matrix_to_tf_sparse(
        INPUT_MATRIX, col_slices=[4, 2], transpose=True).eval()
    sp_c2_t = np_matrix_to_tf_sparse(
        INPUT_MATRIX, col_slices=[5], transpose=True, shuffle=True).eval()
    sp_c3_t = np_matrix_to_tf_sparse(
        INPUT_MATRIX, col_slices=[3, 6], transpose=True).eval()
    sp_c4_t = sp_c2_t
    input_scattered_cols = [sp_c0_t, sp_c1_t, sp_c2_t, sp_c3_t, sp_c4_t]
    input_scattered_cols_non_duplicate = [sp_c0_t, sp_c1_t, sp_c2_t, sp_c3_t]

    # Test updating column factors.
    # Here we feed in scattered columns of the input.
    wals_model.col_update_prep_gramian_op.run()
    wals_model.initialize_col_update_op.run()
    (_, process_input_op, unregularized_loss, regularization,
     _) = wals_model.update_col_factors(
         sp_input=sp_feeder, transpose_input=True)
    factor_loss = unregularized_loss + regularization
    for inp in input_scattered_cols:
      feed_dict = {sp_feeder: inp}
      process_input_op.run(feed_dict=feed_dict)
    col_factors = [x.eval() for x in wals_model.col_factors]

    self.assertAllClose(col_factors[0], self._col_factors_0, atol=1e-3)
    self.assertAllClose(col_factors[1], self._col_factors_1, atol=1e-3)
    self.assertAllClose(col_factors[2], self._col_factors_2, atol=1e-3)

    # Test column projection, using the specified projection weights for the
    # two column feature vectors. This is expected to reproduce the same
    # column factors in the model as the weights and feature vectors are
    # identical to those used in model training.
    projected_cols = wals_model.project_col_factors(
        sp_input=sp_feeder, transpose_input=True,
        projection_weights=[0.4, 0.7])
    # Don't specify the projection weights, so 1.0 will be used. The feature
    # weights will be those specified in the model.
    projected_cols_no_weights = wals_model.project_col_factors(
        sp_input=sp_feeder, transpose_input=True)
    feed_dict = {sp_feeder: sp_c3_t}
    self.assertAllClose(
        projected_cols.eval(feed_dict=feed_dict),
        [self._col_factors_1[0], self._col_factors_2[1]],
        atol=1e-3)
    self.assertAllClose(
        projected_cols_no_weights.eval(feed_dict=feed_dict),
        [[3.585139, -0.487476, -3.852232], [0.557937, 1.813907, 1.331171]],
        atol=1e-3)

    if compute_loss:
      # Test loss computation after the column update.
      loss = sum(
          sess.run(
              factor_loss * self.count_rows(inp) / num_cols,
              feed_dict={sp_feeder: inp})
          for inp in input_scattered_cols_non_duplicate)
      true_loss = self.calculate_loss_from_wals_model(wals_model,
                                                      self._wals_inputs)
      self.assertNear(
          loss,
          true_loss,
          err=.001,
          msg="After col update, computed loss [{}] does not match"
          " true loss [{}]".format(loss, true_loss))
def _run_test_process_input(self, use_factors_weights_cache,
                            compute_loss=False):
  with ops.Graph().as_default(), self.test_session() as sess:
    self._wals_inputs = self.sparse_input()
    sp_feeder = array_ops.sparse_placeholder(dtypes.float32)
    num_rows = 5
    num_cols = 7
    factor_dim = 3
    wals_model = factorization_ops.WALSModel(
        num_rows,
        num_cols,
        factor_dim,
        num_row_shards=2,
        num_col_shards=3,
        regularization=0.01,
        unobserved_weight=0.1,
        col_init=self.col_init,
        row_weights=self.row_wts,
        col_weights=self.col_wts,
        use_factors_weights_cache=use_factors_weights_cache)
    wals_model.initialize_op.run()
    wals_model.worker_init.run()

    # Split the input into multiple sparse tensors with scattered rows. Note
    # that this split can be different from the factor sharding and the
    # inputs can consist of non-consecutive rows. Each row needs to include
    # all non-zero elements in that row.
    sp_r0 = np_matrix_to_tf_sparse(INPUT_MATRIX, [0, 2]).eval()
    sp_r1 = np_matrix_to_tf_sparse(INPUT_MATRIX, [1, 4], shuffle=True).eval()
    sp_r2 = np_matrix_to_tf_sparse(INPUT_MATRIX, [3], shuffle=True).eval()
    input_scattered_rows = [sp_r0, sp_r1, sp_r2]

    # Test updating row factors.
    # Here we feed in scattered rows of the input.
    wals_model.row_update_prep_gramian_op.run()
    wals_model.initialize_row_update_op.run()
    (_, process_input_op, unregularized_loss, regularization,
     _) = wals_model.update_row_factors(
         sp_input=sp_feeder, transpose_input=False)
    factor_loss = unregularized_loss + regularization
    for inp in input_scattered_rows:
      feed_dict = {sp_feeder: inp}
      process_input_op.run(feed_dict=feed_dict)
    row_factors = [x.eval() for x in wals_model.row_factors]

    self.assertAllClose(row_factors[0], self._row_factors_0, atol=1e-3)
    self.assertAllClose(row_factors[1], self._row_factors_1, atol=1e-3)

    # Test row projection, using the specified projection weights for the
    # two row feature vectors. This is expected to reproduce the same row
    # factors in the model as the weights and feature vectors are identical
    # to those used in model training.
    projected_rows = wals_model.project_row_factors(
        sp_input=sp_feeder, transpose_input=False,
        projection_weights=[0.2, 0.5])
    # Don't specify the projection weights, so 1.0 will be used. The feature
    # weights will be those specified in the model.
    projected_rows_no_weights = wals_model.project_row_factors(
        sp_input=sp_feeder, transpose_input=False)
    feed_dict = {
        sp_feeder:
            np_matrix_to_tf_sparse(INPUT_MATRIX, [1, 4], shuffle=False).eval()
    }
    self.assertAllClose(
        projected_rows.eval(feed_dict=feed_dict),
        [self._row_factors_0[1], self._row_factors_1[1]],
        atol=1e-3)
    self.assertAllClose(
        projected_rows_no_weights.eval(feed_dict=feed_dict),
        [[0.569082, 0.715088, 0.31777], [1.915879, 1.992677, 1.109057]],
        atol=1e-3)

    if compute_loss:
      # Test loss computation after the row update.
      loss = sum(
          sess.run(
              factor_loss * self.count_rows(inp) / num_rows,
              feed_dict={sp_feeder: inp}) for inp in input_scattered_rows)
      true_loss = self.calculate_loss_from_wals_model(wals_model,
                                                      self._wals_inputs)
      self.assertNear(
          loss,
          true_loss,
          err=.001,
          msg="After row update, computed loss [{}] does not match"
          " true loss [{}]".format(loss, true_loss))

    # Split the input into multiple sparse tensors with scattered columns.
    # Note that here the elements in the sparse tensors are not ordered and
    # also do not need to consist of consecutive columns. However, each
    # column needs to include all non-zero elements in that column.
    sp_c0 = np_matrix_to_tf_sparse(INPUT_MATRIX, col_slices=[2, 0]).eval()
    sp_c1 = np_matrix_to_tf_sparse(
        INPUT_MATRIX, col_slices=[5, 3, 1], shuffle=True).eval()
    sp_c2 = np_matrix_to_tf_sparse(INPUT_MATRIX, col_slices=[4, 6]).eval()
    sp_c3 = np_matrix_to_tf_sparse(
        INPUT_MATRIX, col_slices=[3, 6], shuffle=True).eval()
    input_scattered_cols = [sp_c0, sp_c1, sp_c2, sp_c3]
    input_scattered_cols_non_duplicate = [sp_c0, sp_c1, sp_c2]

    # Test updating column factors.
    # Here we feed in scattered columns of the input.
    wals_model.col_update_prep_gramian_op.run()
    wals_model.initialize_col_update_op.run()
    (_, process_input_op, unregularized_loss, regularization,
     _) = wals_model.update_col_factors(
         sp_input=sp_feeder, transpose_input=False)
    factor_loss = unregularized_loss + regularization
    for inp in input_scattered_cols:
      feed_dict = {sp_feeder: inp}
      process_input_op.run(feed_dict=feed_dict)
    col_factors = [x.eval() for x in wals_model.col_factors]

    self.assertAllClose(col_factors[0], self._col_factors_0, atol=1e-3)
    self.assertAllClose(col_factors[1], self._col_factors_1, atol=1e-3)
    self.assertAllClose(col_factors[2], self._col_factors_2, atol=1e-3)

    # Test column projection, using the specified projection weights for the
    # three column feature vectors. This is expected to reproduce the same
    # column factors in the model as the weights and feature vectors are
    # identical to those used in model training.
    projected_cols = wals_model.project_col_factors(
        sp_input=sp_feeder, transpose_input=False,
        projection_weights=[0.6, 0.4, 0.2])
    # Don't specify the projection weights, so 1.0 will be used. The feature
    # weights will be those specified in the model.
    projected_cols_no_weights = wals_model.project_col_factors(
        sp_input=sp_feeder, transpose_input=False)
    feed_dict = {
        sp_feeder:
            np_matrix_to_tf_sparse(
                INPUT_MATRIX, col_slices=[5, 3, 1], shuffle=False).eval()
    }
    self.assertAllClose(
        projected_cols.eval(feed_dict=feed_dict), [
            self._col_factors_2[0], self._col_factors_1[0],
            self._col_factors_0[1]
        ],
        atol=1e-3)
    self.assertAllClose(
        projected_cols_no_weights.eval(feed_dict=feed_dict),
        [[3.471045, -1.250835, -3.598917],
         [3.585139, -0.487476, -3.852232],
         [0.346433, 1.360644, 1.677121]],
        atol=1e-3)

    if compute_loss:
      # Test loss computation after the column update.
      loss = sum(
          sess.run(
              factor_loss * self.count_cols(inp) / num_cols,
              feed_dict={sp_feeder: inp})
          for inp in input_scattered_cols_non_duplicate)
      true_loss = self.calculate_loss_from_wals_model(wals_model,
                                                      self._wals_inputs)
      self.assertNear(
          loss,
          true_loss,
          err=.001,
          msg="After col update, computed loss [{}] does not match"
          " true loss [{}]".format(loss, true_loss))
def _wals_factorization_model_function(features, labels, mode, params):
  """Model function for the WALSFactorization estimator.

  Args:
    features: Dictionary of features. See WALSMatrixFactorization.
    labels: Must be None.
    mode: A model_fn.ModeKeys object.
    params: Dictionary of parameters containing arguments passed to the
      WALSMatrixFactorization constructor.

  Returns:
    A ModelFnOps object.

  Raises:
    ValueError: If `mode` is not recognized.
  """
  assert labels is None
  use_factors_weights_cache = (
      params["use_factors_weights_cache_for_training"] and
      mode == model_fn.ModeKeys.TRAIN)
  use_gramian_cache = (
      params["use_gramian_cache_for_training"] and
      mode == model_fn.ModeKeys.TRAIN)
  max_sweeps = params["max_sweeps"]
  model = factorization_ops.WALSModel(
      params["num_rows"],
      params["num_cols"],
      params["embedding_dimension"],
      unobserved_weight=params["unobserved_weight"],
      regularization=params["regularization_coeff"],
      row_init=params["row_init"],
      col_init=params["col_init"],
      num_row_shards=params["num_row_shards"],
      num_col_shards=params["num_col_shards"],
      row_weights=params["row_weights"],
      col_weights=params["col_weights"],
      use_factors_weights_cache=use_factors_weights_cache,
      use_gramian_cache=use_gramian_cache)

  # Get input rows and cols. We either update rows or columns depending on
  # the value of row_sweep, which is maintained using a session hook.
  input_rows = features[WALSMatrixFactorization.INPUT_ROWS]
  input_cols = features[WALSMatrixFactorization.INPUT_COLS]

  # TRAIN mode:
  if mode == model_fn.ModeKeys.TRAIN:
    # Training consists of the following ops (controlled using a SweepHook).
    # Before a row sweep:
    #   row_update_prep_gramian_op
    #   initialize_row_update_op
    # During a row sweep:
    #   update_row_factors_op
    # Before a col sweep:
    #   col_update_prep_gramian_op
    #   initialize_col_update_op
    # During a col sweep:
    #   update_col_factors_op

    is_row_sweep_var = variable_scope.variable(
        True,
        trainable=False,
        name="is_row_sweep",
        collections=[ops.GraphKeys.GLOBAL_VARIABLES])
    is_sweep_done_var = variable_scope.variable(
        False,
        trainable=False,
        name="is_sweep_done",
        collections=[ops.GraphKeys.GLOBAL_VARIABLES])
    completed_sweeps_var = variable_scope.variable(
        0,
        trainable=False,
        name=WALSMatrixFactorization.COMPLETED_SWEEPS,
        collections=[ops.GraphKeys.GLOBAL_VARIABLES])
    loss_var = variable_scope.variable(
        0.,
        trainable=False,
        name=WALSMatrixFactorization.LOSS,
        collections=[ops.GraphKeys.GLOBAL_VARIABLES])
    # The root weighted squared error =
    #   \sqrt( \sum_{i,j} w_ij * (a_ij - r_ij)^2 / \sum_{i,j} w_ij )
    rwse_var = variable_scope.variable(
        0.,
        trainable=False,
        name=WALSMatrixFactorization.RWSE,
        collections=[ops.GraphKeys.GLOBAL_VARIABLES])

    summary.scalar("loss", loss_var)
    summary.scalar("root_weighted_squared_error", rwse_var)
    summary.scalar("completed_sweeps", completed_sweeps_var)

    # Increments global step.
    global_step = training_util.get_global_step()
    if global_step:
      global_step_incr_op = state_ops.assign_add(
          global_step, 1, name="global_step_incr").op
    else:
      global_step_incr_op = control_flow_ops.no_op()

    def create_axis_ops(sp_input, num_items, update_fn, axis_name):
      """Creates book-keeping and training ops for a given axis.

      Args:
        sp_input: A SparseTensor corresponding to the row or column batch.
        num_items: An integer, the total number of items of this axis.
        update_fn: A function that takes one argument (`sp_input`), and that
          returns a tuple of
          * new_factors: A float Tensor of the factor values after update.
          * update_op: a TensorFlow op which updates the factors.
          * loss: A float Tensor, the unregularized loss.
          * reg_loss: A float Tensor, the regularization loss.
          * sum_weights: A float Tensor, the sum of factor weights.
        axis_name: A string that specifies the name of the axis.

      Returns:
        A tuple consisting of:
        * reset_processed_items_op: A TensorFlow op, to be run before the
          beginning of any sweep. It marks all items as not-processed.
        * axis_train_op: A Tensorflow op, to be run during this axis' sweeps.
      """
      processed_items_init = array_ops.fill(dims=[num_items], value=False)
      with ops.colocate_with(processed_items_init):
        processed_items = variable_scope.variable(
            processed_items_init,
            collections=[ops.GraphKeys.GLOBAL_VARIABLES],
            trainable=False,
            name="processed_" + axis_name)
      reset_processed_items_op = state_ops.assign(
          processed_items,
          processed_items_init,
          name="reset_processed_" + axis_name)
      _, update_op, loss, reg, sum_weights = update_fn(sp_input)
      input_indices = sp_input.indices[:, 0]
      with ops.control_dependencies([
          update_op,
          state_ops.assign(loss_var, loss + reg),
          state_ops.assign(rwse_var, math_ops.sqrt(loss / sum_weights))
      ]):
        with ops.colocate_with(processed_items):
          update_processed_items = state_ops.scatter_update(
              processed_items,
              input_indices,
              array_ops.ones_like(input_indices, dtype=dtypes.bool),
              name="update_processed_{}_indices".format(axis_name))
        with ops.control_dependencies([update_processed_items]):
          is_sweep_done = math_ops.reduce_all(processed_items)
          axis_train_op = control_flow_ops.group(
              global_step_incr_op,
              state_ops.assign(is_sweep_done_var, is_sweep_done),
              state_ops.assign_add(
                  completed_sweeps_var,
                  math_ops.cast(is_sweep_done, dtypes.int32)),
              name="{}_sweep_train_op".format(axis_name))
      return reset_processed_items_op, axis_train_op

    reset_processed_rows_op, row_train_op = create_axis_ops(
        input_rows, params["num_rows"],
        lambda x: model.update_row_factors(sp_input=x, transpose_input=False),
        "rows")
    reset_processed_cols_op, col_train_op = create_axis_ops(
        input_cols, params["num_cols"],
        lambda x: model.update_col_factors(sp_input=x, transpose_input=True),
        "cols")

    switch_op = control_flow_ops.group(
        state_ops.assign(is_row_sweep_var,
                         math_ops.logical_not(is_row_sweep_var)),
        reset_processed_rows_op,
        reset_processed_cols_op,
        name="sweep_switch_op")

    row_prep_ops = [
        model.row_update_prep_gramian_op, model.initialize_row_update_op
    ]
    col_prep_ops = [
        model.col_update_prep_gramian_op, model.initialize_col_update_op
    ]

    init_op = model.worker_init

    sweep_hook = _SweepHook(is_row_sweep_var, is_sweep_done_var, init_op,
                            row_prep_ops, col_prep_ops, row_train_op,
                            col_train_op, switch_op)
    training_hooks = [sweep_hook]
    if max_sweeps is not None:
      training_hooks.append(_StopAtSweepHook(max_sweeps))

    return model_fn.ModelFnOps(
        mode=model_fn.ModeKeys.TRAIN,
        predictions={},
        loss=loss_var,
        eval_metric_ops={},
        train_op=control_flow_ops.no_op(),
        training_hooks=training_hooks)

  # INFER mode
  elif mode == model_fn.ModeKeys.INFER:
    projection_weights = features.get(
        WALSMatrixFactorization.PROJECTION_WEIGHTS)

    def get_row_projection():
      return model.project_row_factors(
          sp_input=input_rows,
          projection_weights=projection_weights,
          transpose_input=False)

    def get_col_projection():
      return model.project_col_factors(
          sp_input=input_cols,
          projection_weights=projection_weights,
          transpose_input=True)

    predictions = {
        WALSMatrixFactorization.PROJECTION_RESULT:
            control_flow_ops.cond(
                features[WALSMatrixFactorization.PROJECT_ROW],
                get_row_projection, get_col_projection)
    }

    return model_fn.ModelFnOps(
        mode=model_fn.ModeKeys.INFER,
        predictions=predictions,
        loss=None,
        eval_metric_ops={},
        train_op=control_flow_ops.no_op(),
        training_hooks=[])

  # EVAL mode
  elif mode == model_fn.ModeKeys.EVAL:

    def get_row_loss():
      _, _, loss, reg, _ = model.update_row_factors(
          sp_input=input_rows, transpose_input=False)
      return loss + reg

    def get_col_loss():
      _, _, loss, reg, _ = model.update_col_factors(
          sp_input=input_cols, transpose_input=True)
      return loss + reg

    loss = control_flow_ops.cond(
        features[WALSMatrixFactorization.PROJECT_ROW],
        get_row_loss, get_col_loss)
    return model_fn.ModelFnOps(
        mode=model_fn.ModeKeys.EVAL,
        predictions={},
        loss=loss,
        eval_metric_ops={},
        train_op=control_flow_ops.no_op(),
        training_hooks=[])

  else:
    raise ValueError("mode=%s is not recognized." % str(mode))
def _run_test_process_input_transposed(self, use_factors_weights_cache):
  with self.test_session():
    sp_feeder = tf.sparse_placeholder(tf.float32)
    wals_model = factorization_ops.WALSModel(
        5,
        7,
        3,
        num_row_shards=2,
        num_col_shards=3,
        regularization=0.01,
        unobserved_weight=0.1,
        col_init=self.col_init,
        row_weights=self.row_wts,
        col_weights=self.col_wts,
        use_factors_weights_cache=use_factors_weights_cache)
    wals_model.initialize_op.run()
    wals_model.worker_init.run()

    # Split the input into multiple SparseTensors with scattered rows.
    # Here the inputs are transposed, but the same constraints as described
    # in the previous non-transposed test case apply to these inputs (before
    # they are transposed).
    sp_r0_t = np_matrix_to_tf_sparse(INPUT_MATRIX, [0, 3],
                                     transpose=True).eval()
    sp_r1_t = np_matrix_to_tf_sparse(
        INPUT_MATRIX, [4, 1], shuffle=True, transpose=True).eval()
    sp_r2_t = np_matrix_to_tf_sparse(INPUT_MATRIX, [2], transpose=True).eval()
    sp_r3_t = sp_r1_t
    input_scattered_rows = [sp_r0_t, sp_r1_t, sp_r2_t, sp_r3_t]

    # Test updating row factors.
    # Here we feed in scattered rows of the input.
    # Note that the placeholder suffixes follow the lexicographical order of
    # the test case names, and then the line order in which they appear.
    wals_model.row_update_prep_gramian_op.run()
    wals_model.initialize_row_update_op.run()
    process_input_op = wals_model.update_row_factors(
        sp_input=sp_feeder, transpose_input=True)[1]
    for inp in input_scattered_rows:
      feed_dict = {sp_feeder: inp}
      process_input_op.run(feed_dict=feed_dict)
    row_factors = [x.eval() for x in wals_model.row_factors]

    self.assertAllClose(row_factors[0], self._row_factors_0, atol=1e-3)
    self.assertAllClose(row_factors[1], self._row_factors_1, atol=1e-3)

    # Split the input into multiple SparseTensors with scattered columns.
    # Here the inputs are transposed, but the same constraints as described
    # in the previous non-transposed test case apply to these inputs (before
    # they are transposed).
    sp_c0_t = np_matrix_to_tf_sparse(
        INPUT_MATRIX, col_slices=[0, 1], transpose=True).eval()
    sp_c1_t = np_matrix_to_tf_sparse(
        INPUT_MATRIX, col_slices=[4, 2], transpose=True).eval()
    sp_c2_t = np_matrix_to_tf_sparse(
        INPUT_MATRIX, col_slices=[5], transpose=True, shuffle=True).eval()
    sp_c3_t = np_matrix_to_tf_sparse(
        INPUT_MATRIX, col_slices=[3, 6], transpose=True).eval()
    sp_c4_t = sp_c2_t
    input_scattered_cols = [sp_c0_t, sp_c1_t, sp_c2_t, sp_c3_t, sp_c4_t]

    # Test updating column factors.
    # Here we feed in scattered columns of the input.
    wals_model.col_update_prep_gramian_op.run()
    wals_model.initialize_col_update_op.run()
    process_input_op = wals_model.update_col_factors(
        sp_input=sp_feeder, transpose_input=True)[1]
    for inp in input_scattered_cols:
      feed_dict = {sp_feeder: inp}
      process_input_op.run(feed_dict=feed_dict)
    col_factors = [x.eval() for x in wals_model.col_factors]

    self.assertAllClose(col_factors[0], self._col_factors_0, atol=1e-3)
    self.assertAllClose(col_factors[1], self._col_factors_1, atol=1e-3)
    self.assertAllClose(col_factors[2], self._col_factors_2, atol=1e-3)
params = DEFAULT_PARAMS

# Create WALS model
row_wts = None
col_wts = None
num_rows = train_sparse.shape[0]
num_cols = train_sparse.shape[1]

sess = tf.Session()  # graph=input_tensor.graph
model = factorization_ops.WALSModel(
    num_rows,
    num_cols,
    n_components=params['latent_factors'],
    # num_row_shards=2,
    # num_col_shards=3,
    unobserved_weight=params['unobs_weight'],
    regularization=params['regularization'],
    row_weights=row_wts,
    col_weights=col_wts)

print("\nPreparation for Training....\n")
# Note: variables initialized inside this block are discarded when the
# session closes, so later update ops must re-initialize in their own session.
with tf.Session() as sess:
    sess.run(model.initialize_op)
    sess.run(model.worker_init)

# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~#
batch = 512
X_train = train_sparse.tocsr()
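# The fragment above stops right after materializing X_train. A hedged sketch
# of how a batched row sweep could continue, feeding `batch` rows at a time
# through a sparse placeholder; everything below is illustrative and not part
# of the original script. Initialization and updates happen in one session,
# since closing a session discards variable state.
sp_feeder = tf.sparse_placeholder(tf.float32)
row_update_op = model.update_row_factors(sp_input=sp_feeder)[1]

with tf.Session() as sess:
    sess.run(model.initialize_op)
    sess.run(model.worker_init)
    sess.run(model.row_update_prep_gramian_op)
    sess.run(model.initialize_row_update_op)
    for start in range(0, num_rows, batch):
        chunk = X_train[start:start + batch].tocoo()
        # Indices must use global row ids, not offsets within the chunk.
        indices = np.array([chunk.row + start, chunk.col]).T
        sp_value = tf.SparseTensorValue(indices,
                                        chunk.data.astype(np.float32),
                                        [num_rows, num_cols])
        sess.run(row_update_op, feed_dict={sp_feeder: sp_value})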
def _wals_factorization_model_function(features, labels, mode, params):
  """Model function for the WALSFactorization estimator.

  Args:
    features: Dictionary of features. See WALSMatrixFactorization.
    labels: Must be None.
    mode: A model_fn.ModeKeys object.
    params: Dictionary of parameters containing arguments passed to the
      WALSMatrixFactorization constructor.

  Returns:
    A ModelFnOps object.
  """
  assert labels is None
  use_factors_weights_cache = (
      params["use_factors_weights_cache_for_training"] and
      mode == model_fn.ModeKeys.TRAIN)
  use_gramian_cache = (
      params["use_gramian_cache_for_training"] and
      mode == model_fn.ModeKeys.TRAIN)
  model = factorization_ops.WALSModel(
      params["num_rows"],
      params["num_cols"],
      params["embedding_dimension"],
      unobserved_weight=params["unobserved_weight"],
      regularization=params["regularization_coeff"],
      row_init=params["row_init"],
      col_init=params["col_init"],
      num_row_shards=params["num_row_shards"],
      num_col_shards=params["num_col_shards"],
      row_weights=params["row_weights"],
      col_weights=params["col_weights"],
      use_factors_weights_cache=use_factors_weights_cache,
      use_gramian_cache=use_gramian_cache)

  # Get input rows and cols. We either update rows or columns depending on
  # the value of row_sweep, which is maintained using a session hook.
  input_rows = features[WALSMatrixFactorization.INPUT_ROWS]
  input_cols = features[WALSMatrixFactorization.INPUT_COLS]
  input_row_indices, _ = array_ops.unique(input_rows.indices[:, 0])
  input_col_indices, _ = array_ops.unique(input_cols.indices[:, 0])

  # Train ops, controlled using the SweepHook.
  # We need to run the following ops:
  # Before a row sweep:
  #   row_update_prep_gramian_op
  #   initialize_row_update_op
  # During a row sweep:
  #   update_row_factors_op
  # Before a col sweep:
  #   col_update_prep_gramian_op
  #   initialize_col_update_op
  # During a col sweep:
  #   update_col_factors_op
  is_row_sweep_var = variables.Variable(
      True,
      trainable=False,
      name="is_row_sweep",
      collections=[ops.GraphKeys.GLOBAL_VARIABLES])
  # The row sweep is determined by is_row_sweep_var (controlled by the
  # sweep_hook) in TRAIN mode, and manually in EVAL mode.
  is_row_sweep = (features[WALSMatrixFactorization.PROJECT_ROW]
                  if mode == model_fn.ModeKeys.EVAL else is_row_sweep_var)

  def update_row_factors():
    return model.update_row_factors(sp_input=input_rows,
                                    transpose_input=False)

  def update_col_factors():
    return model.update_col_factors(sp_input=input_cols,
                                    transpose_input=True)

  _, train_op, loss = control_flow_ops.cond(is_row_sweep, update_row_factors,
                                            update_col_factors)

  row_prep_ops = [
      model.row_update_prep_gramian_op, model.initialize_row_update_op
  ]
  col_prep_ops = [
      model.col_update_prep_gramian_op, model.initialize_col_update_op
  ]
  cache_init_ops = [model.worker_init]

  sweep_hook = _SweepHook(
      is_row_sweep_var,
      train_op,
      params["num_rows"],
      params["num_cols"],
      input_row_indices,
      input_col_indices,
      row_prep_ops,
      col_prep_ops,
      cache_init_ops,
  )

  # Prediction ops (only return predictions in INFER mode)
  predictions = {}
  if mode == model_fn.ModeKeys.INFER:
    project_row = features[WALSMatrixFactorization.PROJECT_ROW]
    projection_weights = features.get(
        WALSMatrixFactorization.PROJECTION_WEIGHTS)

    def get_row_projection():
      return model.project_row_factors(
          sp_input=input_rows,
          projection_weights=projection_weights,
          transpose_input=False)

    def get_col_projection():
      return model.project_col_factors(
          sp_input=input_cols,
          projection_weights=projection_weights,
          transpose_input=True)

    predictions[WALSMatrixFactorization.PROJECTION_RESULT] = (
        control_flow_ops.cond(project_row, get_row_projection,
                              get_col_projection))

  return model_fn.ModelFnOps(
      mode=mode,
      predictions=predictions,
      loss=loss,
      eval_metric_ops={},
      train_op=train_op,
      training_hooks=[sweep_hook])
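# A minimal usage sketch for the estimator built on top of this model_fn,
# assuming the tf.contrib.learn-style WALSMatrixFactorization constructor
# whose params are read above (num_rows, num_cols, embedding_dimension, ...).
# train_input_fn must yield the INPUT_ROWS / INPUT_COLS features the model_fn
# expects; all names below are illustrative.
estimator = WALSMatrixFactorization(
    num_rows=num_rows,
    num_cols=num_cols,
    embedding_dimension=10,
    unobserved_weight=0.1,
    regularization=0.01,
    model_dir="/tmp/wals")  # illustrative path
estimator.fit(input_fn=train_input_fn)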