Example #1
  def _run_test_train_full_low_rank_wals(self, use_factors_weights_cache):
    rows = 15
    cols = 11
    dims = 3

    with ops.Graph().as_default(), self.cached_session():
      data = np.dot(np.random.rand(rows, 3), np.random.rand(
          3, cols)).astype(np.float32) / 3.0
      indices = [[i, j] for i in xrange(rows) for j in xrange(cols)]
      values = data.reshape(-1)
      inp = sparse_tensor.SparseTensor(indices, values, [rows, cols])
      model = factorization_ops.WALSModel(
          rows,
          cols,
          dims,
          regularization=1e-5,
          row_weights=0,
          col_weights=[0] * cols,
          use_factors_weights_cache=use_factors_weights_cache)
      self.simple_train(model, inp, 25)
      row_factor = model.row_factors[0].eval()
      col_factor = model.col_factors[0].eval()
      self.assertAllClose(
          data,
          np.dot(row_factor, np.transpose(col_factor)),
          rtol=0.01,
          atol=0.01)
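The simple_train helper used above is not part of these snippets. As a rough
sketch of what it plausibly does, here is a full-batch training loop assembled
only from the op sequence demonstrated in the later examples (illustrative,
not the actual helper):

def simple_train(model, inp, num_iterations):
  """Illustrative only: alternate full row sweeps and column sweeps."""
  row_update_op = model.update_row_factors(sp_input=inp)[1]
  col_update_op = model.update_col_factors(sp_input=inp)[1]
  model.initialize_op.run()
  model.worker_init.run()
  for _ in xrange(num_iterations):
    # Row sweep: refresh the Gramian of the column factors, then update rows.
    model.row_update_prep_gramian_op.run()
    model.initialize_row_update_op.run()
    row_update_op.run()
    # Column sweep: refresh the Gramian of the row factors, then update cols.
    model.col_update_prep_gramian_op.run()
    model.initialize_col_update_op.run()
    col_update_op.run()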
Example #2
  def _run_test_sum_weights(self, test_rows):
    # test_rows: True to test row weights, False to test column weights.

    num_rows = 5
    num_cols = 5
    unobserved_weight = 0.1
    row_weights = [[8., 18., 28., 38., 48.]]
    col_weights = [[90., 91., 92., 93., 94.]]
    sparse_indices = [[0, 1], [2, 3], [4, 1]]
    sparse_values = [666., 777., 888.]

    unobserved = unobserved_weight * num_rows * num_cols
    observed = 8. * 91. + 28. * 93. + 48. * 91.
    # sparse_indices has three unique rows and two unique columns
    observed *= num_rows / 3. if test_rows else num_cols / 2.
    want_weight_sum = unobserved + observed

    with ops.Graph().as_default(), self.cached_session() as sess:
      wals_model = factorization_ops.WALSModel(
          input_rows=num_rows,
          input_cols=num_cols,
          n_components=5,
          unobserved_weight=unobserved_weight,
          row_weights=row_weights,
          col_weights=col_weights,
          use_factors_weights_cache=False)

      wals_model.initialize_op.run()
      wals_model.worker_init.run()

      update_factors = (wals_model.update_row_factors
                        if test_rows else wals_model.update_col_factors)

      (_, _, _, _, sum_weights) = update_factors(
          sp_input=sparse_tensor.SparseTensor(
              indices=sparse_indices,
              values=sparse_values,
              dense_shape=[num_rows, num_cols]),
          transpose_input=False)

      got_weight_sum = sess.run(sum_weights)

      self.assertNear(
          got_weight_sum,
          want_weight_sum,
          err=.001,
          msg="got weight sum [{}], want weight sum [{}]".format(
              got_weight_sum, want_weight_sum))
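The expected weight sum above can be recomputed by hand. For the row case
(test_rows=True), using only the constants from the test:

unobserved = 0.1 * 5 * 5                     # unobserved_weight * num_rows * num_cols = 2.5
observed = 8. * 91. + 28. * 93. + 48. * 91.  # sum of row_weight * col_weight = 7700.0
observed *= 5 / 3.                           # num_rows / (3 unique observed rows)
print(unobserved + observed)                 # 12835.83..., the want_weight_sum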
Example #3
  def _run_test_train_matrix_completion_wals(self, use_factors_weights_cache):
    rows = 11
    cols = 9
    dims = 4

    def keep_index(x):
      return not (x[0] + x[1]) % 4

    with ops.Graph().as_default(), self.cached_session():
      row_wts = 0.1 + np.random.rand(rows)
      col_wts = 0.1 + np.random.rand(cols)
      data = np.dot(np.random.rand(rows, 3), np.random.rand(
          3, cols)).astype(np.float32) / 3.0
      indices = np.array(
          list(
              filter(keep_index,
                     [[i, j] for i in xrange(rows) for j in xrange(cols)])))
      values = data[indices[:, 0], indices[:, 1]]
      inp = sparse_tensor.SparseTensor(indices, values, [rows, cols])
      model = factorization_ops.WALSModel(
          rows,
          cols,
          dims,
          unobserved_weight=0.01,
          regularization=0.001,
          row_weights=row_wts,
          col_weights=col_wts,
          use_factors_weights_cache=use_factors_weights_cache)
      self.simple_train(model, inp, 25)
      row_factor = model.row_factors[0].eval()
      col_factor = model.col_factors[0].eval()
      out = np.dot(row_factor, np.transpose(col_factor))
      for i in xrange(rows):
        for j in xrange(cols):
          if keep_index([i, j]):
            self.assertNear(
                data[i][j], out[i][j], err=0.4, msg="%d, %d" % (i, j))
          else:
            self.assertNear(0, out[i][j], err=0.5, msg="%d, %d" % (i, j))
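The keep_index predicate retains exactly the entries with (i + j) % 4 == 0,
roughly a quarter of the 11 x 9 grid; the rest are unobserved and are pulled
toward zero by the small unobserved_weight, which is what the second assertion
checks. A quick standalone check of the mask density:

import numpy as np

i, j = np.meshgrid(np.arange(11), np.arange(9), indexing="ij")
mask = (i + j) % 4 == 0
print(mask.sum(), mask.size)  # 25 observed entries out of 99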
Example #4
def _wals_factorization_model_function(features, labels, mode, params):
    """Model function for the WALSMatrixFactorization estimator.

    Args:
      features: Dictionary of features. See WALSMatrixFactorization.
      labels: Must be None.
      mode: A model_fn.ModeKeys object.
      params: Dictionary of parameters containing arguments passed to the
        WALSMatrixFactorization constructor.

    Returns:
      A ModelFnOps object.

    Raises:
      ValueError: If `mode` is not recognized.
    """
    assert labels is None
    use_factors_weights_cache = (
        params["use_factors_weights_cache_for_training"]
        and mode == model_fn.ModeKeys.TRAIN)
    use_gramian_cache = (params["use_gramian_cache_for_training"]
                         and mode == model_fn.ModeKeys.TRAIN)
    max_sweeps = params["max_sweeps"]
    model = factorization_ops.WALSModel(
        params["num_rows"],
        params["num_cols"],
        params["embedding_dimension"],
        unobserved_weight=params["unobserved_weight"],
        regularization=params["regularization_coeff"],
        row_init=params["row_init"],
        col_init=params["col_init"],
        num_row_shards=params["num_row_shards"],
        num_col_shards=params["num_col_shards"],
        row_weights=params["row_weights"],
        col_weights=params["col_weights"],
        use_factors_weights_cache=use_factors_weights_cache,
        use_gramian_cache=use_gramian_cache)

    # Get input rows and cols. We either update rows or columns depending on
    # the value of is_row_sweep_var, which is maintained using a session hook.
    input_rows = features[WALSMatrixFactorization.INPUT_ROWS]
    input_cols = features[WALSMatrixFactorization.INPUT_COLS]

    # TRAIN mode:
    if mode == model_fn.ModeKeys.TRAIN:
        # Training consists of the following ops (controlled using a SweepHook).
        # Before a row sweep:
        #   row_update_prep_gramian_op
        #   initialize_row_update_op
        # During a row sweep:
        #   update_row_factors_op
        # Before a col sweep:
        #   col_update_prep_gramian_op
        #   initialize_col_update_op
        # During a col sweep:
        #   update_col_factors_op

        is_row_sweep_var = variable_scope.variable(
            True,
            trainable=False,
            name="is_row_sweep",
            collections=[ops.GraphKeys.GLOBAL_VARIABLES])
        is_sweep_done_var = variable_scope.variable(
            False,
            trainable=False,
            name="is_sweep_done",
            collections=[ops.GraphKeys.GLOBAL_VARIABLES])
        completed_sweeps_var = variable_scope.variable(
            0,
            trainable=False,
            name=WALSMatrixFactorization.COMPLETED_SWEEPS,
            collections=[ops.GraphKeys.GLOBAL_VARIABLES])
        loss_var = variable_scope.variable(
            0.,
            trainable=False,
            name=WALSMatrixFactorization.LOSS,
            collections=[ops.GraphKeys.GLOBAL_VARIABLES])
        # The root weighted squared error =
        #   \\(\sqrt{\sum_{i,j} w_{ij} (a_{ij} - r_{ij})^2 / \sum_{i,j} w_{ij}}\\)
        rwse_var = variable_scope.variable(
            0.,
            trainable=False,
            name=WALSMatrixFactorization.RWSE,
            collections=[ops.GraphKeys.GLOBAL_VARIABLES])

        summary.scalar("loss", loss_var)
        summary.scalar("root_weighted_squared_error", rwse_var)
        summary.scalar("completed_sweeps", completed_sweeps_var)

        def create_axis_ops(sp_input, num_items, update_fn, axis_name):
            """Creates book-keeping and training ops for a given axis.

      Args:
        sp_input: A SparseTensor corresponding to the row or column batch.
        num_items: An integer, the total number of items of this axis.
        update_fn: A function that takes one argument (`sp_input`), and that
        returns a tuple of
          * new_factors: A float Tensor of the factor values after update.
          * update_op: a TensorFlow op which updates the factors.
          * loss: A float Tensor, the unregularized loss.
          * reg_loss: A float Tensor, the regularization loss.
          * sum_weights: A float Tensor, the sum of factor weights.
        axis_name: A string that specifies the name of the axis.

      Returns:
        A tuple consisting of:
          * reset_processed_items_op: A TensorFlow op, to be run before the
            beginning of any sweep. It marks all items as not-processed.
          * axis_train_op: A Tensorflow op, to be run during this axis' sweeps.
      """
            processed_items_init = array_ops.fill(dims=[num_items],
                                                  value=False)
            with ops.colocate_with(processed_items_init):
                processed_items = variable_scope.variable(
                    processed_items_init,
                    collections=[ops.GraphKeys.GLOBAL_VARIABLES],
                    trainable=False,
                    name="processed_" + axis_name)
            _, update_op, loss, reg, sum_weights = update_fn(sp_input)
            input_indices = sp_input.indices[:, 0]
            with ops.control_dependencies([
                    update_op,
                    state_ops.assign(loss_var, loss + reg),
                    state_ops.assign(rwse_var,
                                     math_ops.sqrt(loss / sum_weights))
            ]):
                with ops.colocate_with(processed_items):
                    update_processed_items = state_ops.scatter_update(
                        processed_items,
                        input_indices,
                        array_ops.ones_like(input_indices, dtype=dtypes.bool),
                        name="update_processed_{}_indices".format(axis_name))
                with ops.control_dependencies([update_processed_items]):
                    is_sweep_done = math_ops.reduce_all(processed_items)
                    axis_train_op = control_flow_ops.group(
                        state_ops.assign(is_sweep_done_var, is_sweep_done),
                        state_ops.assign_add(
                            completed_sweeps_var,
                            math_ops.cast(is_sweep_done, dtypes.int32)),
                        name="{}_sweep_train_op".format(axis_name))
            return processed_items.initializer, axis_train_op

        reset_processed_rows_op, row_train_op = create_axis_ops(
            input_rows, params["num_rows"],
            lambda x: model.update_row_factors(sp_input=x,
                                               transpose_input=False), "rows")
        reset_processed_cols_op, col_train_op = create_axis_ops(
            input_cols, params["num_cols"],
            lambda x: model.update_col_factors(sp_input=x,
                                               transpose_input=True), "cols")
        switch_op = control_flow_ops.group(state_ops.assign(
            is_row_sweep_var, math_ops.logical_not(is_row_sweep_var)),
                                           reset_processed_rows_op,
                                           reset_processed_cols_op,
                                           name="sweep_switch_op")
        row_prep_ops = [
            model.row_update_prep_gramian_op, model.initialize_row_update_op
        ]
        col_prep_ops = [
            model.col_update_prep_gramian_op, model.initialize_col_update_op
        ]
        init_op = model.worker_init
        sweep_hook = _SweepHook(is_row_sweep_var, is_sweep_done_var, init_op,
                                row_prep_ops, col_prep_ops, row_train_op,
                                col_train_op, switch_op)
        global_step_hook = _IncrementGlobalStepHook()
        training_hooks = [sweep_hook, global_step_hook]
        if max_sweeps is not None:
            training_hooks.append(_StopAtSweepHook(max_sweeps))

        return model_fn.ModelFnOps(mode=model_fn.ModeKeys.TRAIN,
                                   predictions={},
                                   loss=loss_var,
                                   eval_metric_ops={},
                                   train_op=control_flow_ops.no_op(),
                                   training_hooks=training_hooks)

    # INFER mode
    elif mode == model_fn.ModeKeys.INFER:
        projection_weights = features.get(
            WALSMatrixFactorization.PROJECTION_WEIGHTS)

        def get_row_projection():
            return model.project_row_factors(
                sp_input=input_rows,
                projection_weights=projection_weights,
                transpose_input=False)

        def get_col_projection():
            return model.project_col_factors(
                sp_input=input_cols,
                projection_weights=projection_weights,
                transpose_input=True)

        predictions = {
            WALSMatrixFactorization.PROJECTION_RESULT:
            control_flow_ops.cond(
                features[WALSMatrixFactorization.PROJECT_ROW],
                get_row_projection, get_col_projection)
        }

        return model_fn.ModelFnOps(mode=model_fn.ModeKeys.INFER,
                                   predictions=predictions,
                                   loss=None,
                                   eval_metric_ops={},
                                   train_op=control_flow_ops.no_op(),
                                   training_hooks=[])

    # EVAL mode
    elif mode == model_fn.ModeKeys.EVAL:

        def get_row_loss():
            _, _, loss, reg, _ = model.update_row_factors(
                sp_input=input_rows, transpose_input=False)
            return loss + reg

        def get_col_loss():
            _, _, loss, reg, _ = model.update_col_factors(sp_input=input_cols,
                                                          transpose_input=True)
            return loss + reg

        loss = control_flow_ops.cond(
            features[WALSMatrixFactorization.PROJECT_ROW], get_row_loss,
            get_col_loss)
        return model_fn.ModelFnOps(mode=model_fn.ModeKeys.EVAL,
                                   predictions={},
                                   loss=loss,
                                   eval_metric_ops={},
                                   train_op=control_flow_ops.no_op(),
                                   training_hooks=[])

    else:
        raise ValueError("mode=%s is not recognized." % str(mode))
Example #5
  def _run_test_als_transposed(self, use_factors_weights_cache):
    with ops.Graph().as_default(), self.cached_session():
      self._wals_inputs = self.sparse_input()
      col_init = np.random.rand(7, 3)
      als_model = factorization_ops.WALSModel(
          5,
          7,
          3,
          col_init=col_init,
          row_weights=None,
          col_weights=None,
          use_factors_weights_cache=use_factors_weights_cache)

      als_model.initialize_op.run()
      als_model.worker_init.run()

      wals_model = factorization_ops.WALSModel(
          5,
          7,
          3,
          col_init=col_init,
          row_weights=[0] * 5,
          col_weights=[0] * 7,
          use_factors_weights_cache=use_factors_weights_cache)
      wals_model.initialize_op.run()
      wals_model.worker_init.run()
      sp_feeder = array_ops.sparse_placeholder(dtypes.float32)
      # Test a partial row update with identical inputs, feeding the ALS model
      # the transposed form.
      sp_r_t = np_matrix_to_tf_sparse(
          INPUT_MATRIX, [3, 1], transpose=True).eval()
      sp_r = np_matrix_to_tf_sparse(INPUT_MATRIX, [3, 1]).eval()

      feed_dict = {sp_feeder: sp_r_t}
      als_model.row_update_prep_gramian_op.run()
      als_model.initialize_row_update_op.run()
      process_input_op = als_model.update_row_factors(
          sp_input=sp_feeder, transpose_input=True)[1]
      process_input_op.run(feed_dict=feed_dict)
      # Only rows 1 and 3 were updated, so compare only these rows; the others
      # still hold their random initial values.
      row_factors1 = [
          als_model.row_factors[0].eval()[1], als_model.row_factors[0].eval()[3]
      ]
      # Testing row projection. The projection weight doesn't matter in this
      # case since the model is an ALS special case. Note that the ordering of
      # the returned results matches the ordering of the input feature
      # vectors.
      als_projected_row_factors1 = als_model.project_row_factors(
          sp_input=sp_feeder, transpose_input=True).eval(feed_dict=feed_dict)

      feed_dict = {sp_feeder: sp_r}
      wals_model.row_update_prep_gramian_op.run()
      wals_model.initialize_row_update_op.run()
      process_input_op = wals_model.update_row_factors(sp_input=sp_feeder)[1]
      process_input_op.run(feed_dict=feed_dict)
      # Only rows 1 and 3 were updated, so compare only these rows; the others
      # still hold their random initial values.
      row_factors2 = [
          wals_model.row_factors[0].eval()[1],
          wals_model.row_factors[0].eval()[3]
      ]
      for r1, r2 in zip(row_factors1, row_factors2):
        self.assertAllClose(r1, r2, atol=1e-3)
      # Note that the ordering of the returned projection results matches the
      # ordering of the input feature vectors.
      self.assertAllClose(
          als_projected_row_factors1, [row_factors2[1], row_factors2[0]],
          atol=1e-3)
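np_matrix_to_tf_sparse and INPUT_MATRIX are defined elsewhere in the test
file. Judging from the call sites above, the helper converts a NumPy matrix
into a SparseTensor, treating zeros as unobserved, optionally restricting the
output to the given row or column slices, shuffling the entry order, and
transposing the indices. A minimal sketch under those assumptions (not the
actual helper):

def np_matrix_to_tf_sparse(np_matrix, row_slices=None, col_slices=None,
                           shuffle=False, transpose=False):
  """Illustrative only: emit the non-zero entries of the selected vectors."""
  if row_slices is not None:
    indices = np.array([[i, j] for i in row_slices
                        for j in np.nonzero(np_matrix[i])[0]])
  elif col_slices is not None:
    indices = np.array([[i, j] for j in col_slices
                        for i in np.nonzero(np_matrix[:, j])[0]])
  else:
    indices = np.transpose(np.nonzero(np_matrix))
  if shuffle:
    np.random.shuffle(indices)
  values = np_matrix[indices[:, 0], indices[:, 1]].astype(np.float32)
  dense_shape = np_matrix.shape
  if transpose:
    indices = indices[:, ::-1]
    dense_shape = dense_shape[::-1]
  return sparse_tensor.SparseTensor(indices, values, dense_shape)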
Example #6
  def _run_test_als(self, use_factors_weights_cache):
    with ops.Graph().as_default(), self.cached_session():
      self._wals_inputs = self.sparse_input()
      col_init = np.random.rand(7, 3)
      als_model = factorization_ops.WALSModel(
          5,
          7,
          3,
          col_init=col_init,
          row_weights=None,
          col_weights=None,
          use_factors_weights_cache=use_factors_weights_cache)

      als_model.initialize_op.run()
      als_model.worker_init.run()
      als_model.row_update_prep_gramian_op.run()
      als_model.initialize_row_update_op.run()
      process_input_op = als_model.update_row_factors(self._wals_inputs)[1]
      process_input_op.run()
      row_factors1 = [x.eval() for x in als_model.row_factors]
      # Testing row projection. The projection weight doesn't matter in this
      # case since the model is an ALS special case.
      als_projected_row_factors1 = als_model.project_row_factors(
          self._wals_inputs).eval()

      wals_model = factorization_ops.WALSModel(
          5,
          7,
          3,
          col_init=col_init,
          row_weights=0,
          col_weights=0,
          use_factors_weights_cache=use_factors_weights_cache)
      wals_model.initialize_op.run()
      wals_model.worker_init.run()
      wals_model.row_update_prep_gramian_op.run()
      wals_model.initialize_row_update_op.run()
      process_input_op = wals_model.update_row_factors(self._wals_inputs)[1]
      process_input_op.run()
      row_factors2 = [x.eval() for x in wals_model.row_factors]

      for r1, r2 in zip(row_factors1, row_factors2):
        self.assertAllClose(r1, r2, atol=1e-3)
      self.assertAllClose(
          als_projected_row_factors1,
          [row for shard in row_factors2 for row in shard],
          atol=1e-3)

      # Here we test partial column updates.
      sp_c = np_matrix_to_tf_sparse(
          INPUT_MATRIX, col_slices=[2, 0], shuffle=True).eval()

      sp_feeder = array_ops.sparse_placeholder(dtypes.float32)
      feed_dict = {sp_feeder: sp_c}
      als_model.col_update_prep_gramian_op.run()
      als_model.initialize_col_update_op.run()
      process_input_op = als_model.update_col_factors(sp_input=sp_feeder)[1]
      process_input_op.run(feed_dict=feed_dict)
      col_factors1 = [x.eval() for x in als_model.col_factors]
      # Testing column projection. The projection weight doesn't matter in
      # this case since the model is an ALS special case.
      als_projected_col_factors1 = als_model.project_col_factors(
          np_matrix_to_tf_sparse(
              INPUT_MATRIX, col_slices=[2, 0], shuffle=False)).eval()

      feed_dict = {sp_feeder: sp_c}
      wals_model.col_update_prep_gramian_op.run()
      wals_model.initialize_col_update_op.run()
      process_input_op = wals_model.update_col_factors(sp_input=sp_feeder)[1]
      process_input_op.run(feed_dict=feed_dict)
      col_factors2 = [x.eval() for x in wals_model.col_factors]

      for c1, c2 in zip(col_factors1, col_factors2):
        self.assertAllClose(c1, c2, rtol=5e-3, atol=1e-2)
      self.assertAllClose(
          als_projected_col_factors1, [col_factors2[0][2], col_factors2[0][0]],
          atol=1e-2)
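Why the ALS factors and the zero-weight WALS factors are expected to match:
the WALSModel documentation (paraphrased here, so treat the exact form as an
assumption) gives the objective

  \min_{U,V} \big\| W^{1/2} \odot (A - U V^\top) \big\|_F^2
             + \lambda \big( \|U\|_F^2 + \|V\|_F^2 \big),
  \qquad
  W_{ij} = \begin{cases} w_0 + w^r_i \, w^c_j & (i,j)\ \text{observed} \\
                         w_0 & \text{otherwise.} \end{cases}

With all row weights or all column weights set to zero, every entry weight
collapses to the constant w_0, which is the unweighted special case that
row_weights=None, col_weights=None (ALS) selects directly.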
Example #7
  def _run_test_process_input_transposed(self,
                                         use_factors_weights_cache,
                                         compute_loss=False):
    with ops.Graph().as_default(), self.cached_session() as sess:
      self._wals_inputs = self.sparse_input()
      sp_feeder = array_ops.sparse_placeholder(dtypes.float32)
      num_rows = 5
      num_cols = 7
      factor_dim = 3
      wals_model = factorization_ops.WALSModel(
          num_rows,
          num_cols,
          factor_dim,
          num_row_shards=2,
          num_col_shards=3,
          regularization=0.01,
          unobserved_weight=0.1,
          col_init=self.col_init,
          row_weights=self.row_wts,
          col_weights=self.col_wts,
          use_factors_weights_cache=use_factors_weights_cache)

      wals_model.initialize_op.run()
      wals_model.worker_init.run()

      # Split input into multiple SparseTensors with scattered rows.
      # The inputs here are transposed, but the same constraints described in
      # the non-transposed test case (Example #8 below) apply to them (before
      # they are transposed).
      sp_r0_t = np_matrix_to_tf_sparse(
          INPUT_MATRIX, [0, 3], transpose=True).eval()
      sp_r1_t = np_matrix_to_tf_sparse(
          INPUT_MATRIX, [4, 1], shuffle=True, transpose=True).eval()
      sp_r2_t = np_matrix_to_tf_sparse(INPUT_MATRIX, [2], transpose=True).eval()
      sp_r3_t = sp_r1_t
      input_scattered_rows = [sp_r0_t, sp_r1_t, sp_r2_t, sp_r3_t]
      input_scattered_rows_non_duplicate = [sp_r0_t, sp_r1_t, sp_r2_t]
      # Test updating row factors.
      # Here we feed in scattered rows of the input.
      # Note that the numeric suffixes of the placeholder names follow the
      # lexicographical order of the test case names, and then the order in
      # which the placeholders appear within a test.
      wals_model.row_update_prep_gramian_op.run()
      wals_model.initialize_row_update_op.run()
      (_, process_input_op, unregularized_loss, regularization,
       _) = wals_model.update_row_factors(
           sp_input=sp_feeder, transpose_input=True)
      factor_loss = unregularized_loss + regularization
      for inp in input_scattered_rows:
        feed_dict = {sp_feeder: inp}
        process_input_op.run(feed_dict=feed_dict)
      row_factors = [x.eval() for x in wals_model.row_factors]

      self.assertAllClose(row_factors[0], self._row_factors_0, atol=1e-3)
      self.assertAllClose(row_factors[1], self._row_factors_1, atol=1e-3)

      # Test row projection.
      # Using the specified projection weights for the 2 row feature vectors.
      # This is expected to reproduce the same row factors in the model, as
      # the weights and feature vectors are identical to those used in model
      # training.
      projected_rows = wals_model.project_row_factors(
          sp_input=sp_feeder,
          transpose_input=True,
          projection_weights=[0.5, 0.2])
      # Don't specify the projection weights, so 1.0 will be used. The feature
      # weights will be those specified in the model.
      projected_rows_no_weights = wals_model.project_row_factors(
          sp_input=sp_feeder, transpose_input=True)
      feed_dict = {
          sp_feeder:
              np_matrix_to_tf_sparse(
                  INPUT_MATRIX, [4, 1], shuffle=False, transpose=True).eval()
      }
      self.assertAllClose(
          projected_rows.eval(feed_dict=feed_dict),
          [self._row_factors_1[1], self._row_factors_0[1]],
          atol=1e-3)
      self.assertAllClose(
          projected_rows_no_weights.eval(feed_dict=feed_dict),
          [[1.915879, 1.992677, 1.109057], [0.569082, 0.715088, 0.31777]],
          atol=1e-3)

      if compute_loss:
        # Test loss computation after the row update
        loss = sum(
            sess.run(
                factor_loss * self.count_cols(inp) / num_rows,
                feed_dict={sp_feeder: inp})
            for inp in input_scattered_rows_non_duplicate)
        true_loss = self.calculate_loss_from_wals_model(wals_model,
                                                        self._wals_inputs)
        self.assertNear(
            loss,
            true_loss,
            err=.001,
            msg="After row update, computed loss [{}] does not match"
            " true loss [{}]".format(loss, true_loss))

      # Split input into multiple SparseTensors with scattered columns.
      # The inputs here are transposed, but the same constraints described in
      # the non-transposed test case (Example #8 below) apply to them (before
      # they are transposed).
      sp_c0_t = np_matrix_to_tf_sparse(
          INPUT_MATRIX, col_slices=[0, 1], transpose=True).eval()
      sp_c1_t = np_matrix_to_tf_sparse(
          INPUT_MATRIX, col_slices=[4, 2], transpose=True).eval()
      sp_c2_t = np_matrix_to_tf_sparse(
          INPUT_MATRIX, col_slices=[5], transpose=True, shuffle=True).eval()
      sp_c3_t = np_matrix_to_tf_sparse(
          INPUT_MATRIX, col_slices=[3, 6], transpose=True).eval()

      sp_c4_t = sp_c2_t
      input_scattered_cols = [sp_c0_t, sp_c1_t, sp_c2_t, sp_c3_t, sp_c4_t]
      input_scattered_cols_non_duplicate = [sp_c0_t, sp_c1_t, sp_c2_t, sp_c3_t]

      # Test updating column factors.
      # Here we feed in scattered columns of the input.
      wals_model.col_update_prep_gramian_op.run()
      wals_model.initialize_col_update_op.run()
      (_, process_input_op, unregularized_loss, regularization,
       _) = wals_model.update_col_factors(
           sp_input=sp_feeder, transpose_input=True)
      factor_loss = unregularized_loss + regularization
      for inp in input_scattered_cols:
        feed_dict = {sp_feeder: inp}
        process_input_op.run(feed_dict=feed_dict)
      col_factors = [x.eval() for x in wals_model.col_factors]

      self.assertAllClose(col_factors[0], self._col_factors_0, atol=1e-3)
      self.assertAllClose(col_factors[1], self._col_factors_1, atol=1e-3)
      self.assertAllClose(col_factors[2], self._col_factors_2, atol=1e-3)

      # Test column projection.
      # Using the specified projection weights for the 2 column feature vectors.
      # This is expected to reproduce the same column factors in the model,
      # as the weights and feature vectors are identical to those used in
      # model training.
      projected_cols = wals_model.project_col_factors(
          sp_input=sp_feeder,
          transpose_input=True,
          projection_weights=[0.4, 0.7])
      # Don't specify the projection weights, so 1.0 will be used. The feature
      # weights will be those specified in the model.
      projected_cols_no_weights = wals_model.project_col_factors(
          sp_input=sp_feeder, transpose_input=True)
      feed_dict = {sp_feeder: sp_c3_t}
      self.assertAllClose(
          projected_cols.eval(feed_dict=feed_dict),
          [self._col_factors_1[0], self._col_factors_2[1]],
          atol=1e-3)
      self.assertAllClose(
          projected_cols_no_weights.eval(feed_dict=feed_dict),
          [[3.585139, -0.487476, -3.852232], [0.557937, 1.813907, 1.331171]],
          atol=1e-3)
      if compute_loss:
        # Test loss computation after the col update
        loss = sum(
            sess.run(
                factor_loss * self.count_rows(inp) / num_cols,
                feed_dict={sp_feeder: inp})
            for inp in input_scattered_cols_non_duplicate)
        true_loss = self.calculate_loss_from_wals_model(wals_model,
                                                        self._wals_inputs)
        self.assertNear(
            loss,
            true_loss,
            err=.001,
            msg="After col update, computed loss [{}] does not match"
            " true loss [{}]".format(loss, true_loss))
Example #8
  def _run_test_process_input(self,
                              use_factors_weights_cache,
                              compute_loss=False):
    with ops.Graph().as_default(), self.cached_session() as sess:
      self._wals_inputs = self.sparse_input()
      sp_feeder = array_ops.sparse_placeholder(dtypes.float32)
      num_rows = 5
      num_cols = 7
      factor_dim = 3
      wals_model = factorization_ops.WALSModel(
          num_rows,
          num_cols,
          factor_dim,
          num_row_shards=2,
          num_col_shards=3,
          regularization=0.01,
          unobserved_weight=0.1,
          col_init=self.col_init,
          row_weights=self.row_wts,
          col_weights=self.col_wts,
          use_factors_weights_cache=use_factors_weights_cache)

      wals_model.initialize_op.run()
      wals_model.worker_init.run()

      # Split input into multiple sparse tensors with scattered rows. Note that
      # this split can differ from the factor sharding, and the inputs can
      # consist of non-consecutive rows. Each row needs to include all non-zero
      # elements in that row.
      sp_r0 = np_matrix_to_tf_sparse(INPUT_MATRIX, [0, 2]).eval()
      sp_r1 = np_matrix_to_tf_sparse(INPUT_MATRIX, [1, 4], shuffle=True).eval()
      sp_r2 = np_matrix_to_tf_sparse(INPUT_MATRIX, [3], shuffle=True).eval()
      input_scattered_rows = [sp_r0, sp_r1, sp_r2]

      # Test updating row factors.
      # Here we feed in scattered rows of the input.
      wals_model.row_update_prep_gramian_op.run()
      wals_model.initialize_row_update_op.run()
      (_, process_input_op, unregularized_loss, regularization,
       _) = wals_model.update_row_factors(
           sp_input=sp_feeder, transpose_input=False)
      factor_loss = unregularized_loss + regularization
      for inp in input_scattered_rows:
        feed_dict = {sp_feeder: inp}
        process_input_op.run(feed_dict=feed_dict)
      row_factors = [x.eval() for x in wals_model.row_factors]

      self.assertAllClose(row_factors[0], self._row_factors_0, atol=1e-3)
      self.assertAllClose(row_factors[1], self._row_factors_1, atol=1e-3)

      # Test row projection.
      # Using the specified projection weights for the 2 row feature vectors.
      # This is expected to reproduce the same row factors in the model, as
      # the weights and feature vectors are identical to those used in model
      # training.
      projected_rows = wals_model.project_row_factors(
          sp_input=sp_feeder,
          transpose_input=False,
          projection_weights=[0.2, 0.5])
      # Don't specify the projection weights, so 1.0 will be used. The feature
      # weights will be those specified in the model.
      projected_rows_no_weights = wals_model.project_row_factors(
          sp_input=sp_feeder, transpose_input=False)
      feed_dict = {
          sp_feeder:
              np_matrix_to_tf_sparse(INPUT_MATRIX, [1, 4], shuffle=False)
              .eval()
      }
      self.assertAllClose(
          projected_rows.eval(feed_dict=feed_dict),
          [self._row_factors_0[1], self._row_factors_1[1]],
          atol=1e-3)
      self.assertAllClose(
          projected_rows_no_weights.eval(feed_dict=feed_dict),
          [[0.569082, 0.715088, 0.31777], [1.915879, 1.992677, 1.109057]],
          atol=1e-3)

      if compute_loss:
        # Test loss computation after the row update
        loss = sum(
            sess.run(
                factor_loss * self.count_rows(inp) / num_rows,
                feed_dict={sp_feeder: inp}) for inp in input_scattered_rows)
        true_loss = self.calculate_loss_from_wals_model(wals_model,
                                                        self._wals_inputs)
        self.assertNear(
            loss,
            true_loss,
            err=.001,
            msg="After row update, computed loss [{}] does not match"
            " true loss [{}]".format(loss, true_loss))

      # Split input into multiple sparse tensors with scattered columns. Note
      # that the elements in the sparse tensors are not ordered and need not
      # consist of consecutive columns. However, each column needs to include
      # all non-zero elements in that column.
      sp_c0 = np_matrix_to_tf_sparse(INPUT_MATRIX, col_slices=[2, 0]).eval()
      sp_c1 = np_matrix_to_tf_sparse(
          INPUT_MATRIX, col_slices=[5, 3, 1], shuffle=True).eval()
      sp_c2 = np_matrix_to_tf_sparse(INPUT_MATRIX, col_slices=[4, 6]).eval()
      sp_c3 = np_matrix_to_tf_sparse(
          INPUT_MATRIX, col_slices=[3, 6], shuffle=True).eval()

      input_scattered_cols = [sp_c0, sp_c1, sp_c2, sp_c3]
      input_scattered_cols_non_duplicate = [sp_c0, sp_c1, sp_c2]

      # Test updating column factors.
      # Here we feed in scattered columns of the input.
      wals_model.col_update_prep_gramian_op.run()
      wals_model.initialize_col_update_op.run()
      (_, process_input_op, unregularized_loss, regularization,
       _) = wals_model.update_col_factors(
           sp_input=sp_feeder, transpose_input=False)
      factor_loss = unregularized_loss + regularization
      for inp in input_scattered_cols:
        feed_dict = {sp_feeder: inp}
        process_input_op.run(feed_dict=feed_dict)
      col_factors = [x.eval() for x in wals_model.col_factors]

      self.assertAllClose(col_factors[0], self._col_factors_0, atol=1e-3)
      self.assertAllClose(col_factors[1], self._col_factors_1, atol=1e-3)
      self.assertAllClose(col_factors[2], self._col_factors_2, atol=1e-3)

      # Test column projection.
      # Using the specified projection weights for the 3 column feature vectors.
      # This is expected to reproduce the same column factors in the model,
      # as the weights and feature vectors are identical to those used in
      # model training.
      projected_cols = wals_model.project_col_factors(
          sp_input=sp_feeder,
          transpose_input=False,
          projection_weights=[0.6, 0.4, 0.2])
      # Don't specify the projection weights, so 1.0 will be used. The feature
      # weights will be those specified in the model.
      projected_cols_no_weights = wals_model.project_col_factors(
          sp_input=sp_feeder, transpose_input=False)
      feed_dict = {
          sp_feeder:
              np_matrix_to_tf_sparse(
                  INPUT_MATRIX, col_slices=[5, 3, 1], shuffle=False).eval()
      }
      self.assertAllClose(
          projected_cols.eval(feed_dict=feed_dict), [
              self._col_factors_2[0], self._col_factors_1[0],
              self._col_factors_0[1]
          ],
          atol=1e-3)
      self.assertAllClose(
          projected_cols_no_weights.eval(feed_dict=feed_dict),
          [[3.471045, -1.250835, -3.598917], [3.585139, -0.487476, -3.852232],
           [0.346433, 1.360644, 1.677121]],
          atol=1e-3)

      if compute_loss:
        # Test loss computation after the column update.
        loss = sum(
            sess.run(
                factor_loss * self.count_cols(inp) / num_cols,
                feed_dict={sp_feeder: inp})
            for inp in input_scattered_cols_non_duplicate)
        true_loss = self.calculate_loss_from_wals_model(wals_model,
                                                        self._wals_inputs)
        self.assertNear(
            loss,
            true_loss,
            err=.001,
            msg="After col update, computed loss [{}] does not match"
            " true loss [{}]".format(loss, true_loss))