예제 #1
0
    def test_sweeps(self):
        """Runs `_SweepHook` through a row sweep and then a column sweep.

        Index batches are fed through a hooked session. After every step the
        test checks the init/prep flags, the `is_row_sweep` variable, the
        hook's `_is_sweep_done` attribute and the completed-sweeps counter.
        """
        def ind_feed(row_indices, col_indices):
            # Feed dict that binds both index placeholders for a single step.
            return {
                self._input_row_indices_ph: row_indices,
                self._input_col_indices_ph: col_indices
            }

        with self.test_session() as sess:
            is_row_sweep_var = variables.Variable(True)
            completed_sweeps_var = variables.Variable(0)
            sweep_hook = wals_lib._SweepHook(
                is_row_sweep_var, self._train_op, self._num_rows,
                self._num_cols, self._input_row_indices_ph,
                self._input_col_indices_ph, self._row_prep_ops,
                self._col_prep_ops, self._init_ops, completed_sweeps_var)
            mon_sess = monitored_session._HookedSession(sess, [sweep_hook])
            sess.run([variables.global_variables_initializer()])

            # Init ops should run before the first run. Row sweep not completed.
            mon_sess.run(self._train_op, ind_feed([0, 1, 2], []))
            self.assertTrue(sess.run(self._init_done),
                            msg='init ops not run by the sweep_hook')
            self.assertTrue(sess.run(self._row_prep_done),
                            msg='row_prep not run by the sweep_hook')
            self.assertTrue(
                sess.run(is_row_sweep_var),
                msg='Row sweep is not complete but is_row_sweep is '
                'False.')
            # Row sweep completed.
            mon_sess.run(self._train_op, ind_feed([3, 4],
                                                  [0, 1, 2, 3, 4, 5, 6]))
            self.assertFalse(
                sess.run(is_row_sweep_var),
                msg='Row sweep is complete but is_row_sweep is True.')
            # assertEqual (rather than assertTrue on `==`) reports the actual
            # counter value when the check fails.
            self.assertEqual(sess.run(completed_sweeps_var), 1,
                             msg='Completed sweeps should be equal to 1.')
            self.assertTrue(
                sweep_hook._is_sweep_done,
                msg='Sweep is complete but is_sweep_done is False.')
            # Col init ops should run. Col sweep not completed.
            mon_sess.run(self._train_op, ind_feed([], [0, 1, 2, 3, 4]))
            self.assertTrue(sess.run(self._col_prep_done),
                            msg='col_prep not run by the sweep_hook')
            self.assertFalse(
                sess.run(is_row_sweep_var),
                msg='Col sweep is not complete but is_row_sweep is '
                'True.')
            self.assertFalse(
                sweep_hook._is_sweep_done,
                msg='Sweep is not complete but is_sweep_done is True.')
            # Col sweep completed.
            mon_sess.run(self._train_op, ind_feed([], [4, 5, 6]))
            self.assertTrue(
                sess.run(is_row_sweep_var),
                msg='Col sweep is complete but is_row_sweep is False')
            self.assertTrue(
                sweep_hook._is_sweep_done,
                msg='Sweep is complete but is_sweep_done is False.')
            self.assertEqual(sess.run(completed_sweeps_var), 2,
                             msg='Completed sweeps should be equal to 2.')
예제 #2
0
    def test_col_sweep(self):
        """Checks `_SweepHook` when `is_row_sweep` starts out as False.

        Column indices are fed to the hook in batches. The test expects the
        col prep flag to be set after the first (empty) batch, the sweep to
        be unfinished after the second batch, and the sweep to be finished
        with `is_row_sweep` flipped to True after the third.
        """
        with self.test_session() as session:
            row_sweep_flag = variables.Variable(False)
            hook = wals_lib._SweepHook(
                row_sweep_flag,
                self._train_op,
                self._num_rows,
                self._num_cols,
                self._input_row_indices_ph,
                self._input_col_indices_ph,
                self._row_prep_ops,
                self._col_prep_ops,
                self._init_ops)

            def feed_cols(col_indices):
                # Run the hook on a batch that carries no row indices.
                self.run_hook_with_indices(hook, [], col_indices)

            session.run([variables.global_variables_initializer()])

            # First batch is empty; col prep is expected to have run anyway.
            feed_cols([])
            col_prep_ran = session.run(self._col_prep_done)
            self.assertTrue(col_prep_ran,
                            msg='col_prep not run by the sweep_hook')

            # Second batch: the col sweep must still be in progress.
            feed_cols([0, 1, 2, 3, 4])
            self.assertFalse(
                session.run(row_sweep_flag),
                msg='Col sweep is not complete but is_row_sweep is True.')
            self.assertFalse(
                hook._is_sweep_done,
                msg='Sweep is not complete but is_sweep_done is True.')

            # Third batch: the col sweep is expected to be complete.
            feed_cols([4, 5, 6])
            self.assertTrue(
                session.run(row_sweep_flag),
                msg='Col sweep is complete but is_row_sweep is False')
            self.assertTrue(
                hook._is_sweep_done,
                msg='Sweep is complete but is_sweep_done is False.')
예제 #3
0
  def test_col_sweep(self):
    """Exercises `_SweepHook` starting from a column sweep.

    `is_row_sweep` is created as False and three batches of column indices
    are pushed through `run_hook_with_indices`. The sweep is expected to be
    reported as done, and `is_row_sweep` to read True, only after the last
    batch.
    """
    with self.test_session() as session:
      row_sweep_flag = variables.Variable(False)
      hook_args = (
          row_sweep_flag, self._train_op, self._num_rows, self._num_cols,
          self._input_row_indices_ph, self._input_col_indices_ph,
          self._row_prep_ops, self._col_prep_ops, self._init_ops)
      hook = wals_lib._SweepHook(*hook_args)

      # Initialize variables
      session.run([variables.global_variables_initializer()])

      # Empty batch: the col prep flag should already be set afterwards.
      self.run_hook_with_indices(hook, [], [])
      self.assertTrue(
          session.run(self._col_prep_done),
          msg='col_prep not run by the sweep_hook')

      # Partial batch: still sweeping columns.
      self.run_hook_with_indices(hook, [], [0, 1, 2, 3, 4])
      still_row_sweep = session.run(row_sweep_flag)
      self.assertFalse(
          still_row_sweep,
          msg='Col sweep is not complete but is_row_sweep is True.')
      self.assertFalse(
          hook._is_sweep_done,
          msg='Sweep is not complete but is_sweep_done is True.')

      # Final batch: sweep done and the hook switches back to rows.
      self.run_hook_with_indices(hook, [], [4, 5, 6])
      back_to_rows = session.run(row_sweep_flag)
      self.assertTrue(
          back_to_rows,
          msg='Col sweep is complete but is_row_sweep is False')
      self.assertTrue(
          hook._is_sweep_done,
          msg='Sweep is complete but is_sweep_done is False.')
예제 #4
0
    def test_sweeps(self):
        """Checks that `_SweepHook` runs its ops and alternates sweeps.

        Boolean variables record whether the init, prep and train ops were
        run. The test marks the sweep as done by hand between training steps
        and checks that `is_row_sweep` toggles row -> col -> row.
        """
        row_sweep_flag = variables.Variable(True)
        sweep_done_flag = variables.Variable(False)
        init_flag = variables.Variable(False)
        row_prep_flag = variables.Variable(False)
        col_prep_flag = variables.Variable(False)
        row_train_flag = variables.Variable(False)
        col_train_flag = variables.Variable(False)

        # Each op just records that it was executed.
        set_init = state_ops.assign(init_flag, True)
        set_row_prep = state_ops.assign(row_prep_flag, True)
        set_col_prep = state_ops.assign(col_prep_flag, True)
        set_row_train = state_ops.assign(row_train_flag, True)
        set_col_train = state_ops.assign(col_train_flag, True)
        step_op = control_flow_ops.no_op()

        # Switching clears the done flag and inverts the sweep direction.
        clear_done = state_ops.assign(sweep_done_flag, False)
        inverted = math_ops.logical_not(row_sweep_flag)
        flip_direction = state_ops.assign(row_sweep_flag, inverted)
        switch = control_flow_ops.group(clear_done, flip_direction)
        finish_sweep = state_ops.assign(sweep_done_flag, True)

        with self.test_session() as session:
            hook = wals_lib._SweepHook(
                row_sweep_flag,
                sweep_done_flag,
                set_init,
                [set_row_prep],
                [set_col_prep],
                set_row_train,
                set_col_train,
                switch)
            hooked = monitored_session._HookedSession(session, [hook])
            session.run([variables.global_variables_initializer()])

            # Row sweep.
            hooked.run(step_op)
            self.assertTrue(
                session.run(init_flag),
                msg='init op not run by the Sweephook')
            self.assertTrue(
                session.run(row_prep_flag),
                msg='row_prep_op not run by the SweepHook')
            self.assertTrue(
                session.run(row_train_flag),
                msg='row_train_op not run by the SweepHook')
            self.assertTrue(
                session.run(row_sweep_flag),
                msg='Row sweep is not complete but is_row_sweep_var is False.')

            # Col sweep.
            hooked.run(finish_sweep)
            hooked.run(step_op)
            self.assertTrue(
                session.run(col_prep_flag),
                msg='col_prep_op not run by the SweepHook')
            self.assertTrue(
                session.run(col_train_flag),
                msg='col_train_op not run by the SweepHook')
            self.assertFalse(
                session.run(row_sweep_flag),
                msg='Col sweep is not complete but is_row_sweep_var is True.')

            # Row sweep.
            hooked.run(finish_sweep)
            hooked.run(step_op)
            self.assertTrue(
                session.run(row_sweep_flag),
                msg='Col sweep is complete but is_row_sweep_var is False.')
예제 #5
0
  def test_sweeps(self):
    """Runs `_SweepHook` through a row sweep and then a column sweep.

    Index batches are fed through a hooked session. After every step the
    test checks the init/prep flags, the `is_row_sweep` variable, the hook's
    `_is_sweep_done_var` variable and the completed-sweeps counter.
    """
    def ind_feed(row_indices, col_indices):
      # Feed dict that binds both index placeholders for a single step.
      return {
          self._input_row_indices_ph: row_indices,
          self._input_col_indices_ph: col_indices
      }

    with self.test_session() as sess:
      is_row_sweep_var = variables.Variable(True)
      completed_sweeps_var = variables.Variable(0)
      sweep_hook = wals_lib._SweepHook(
          is_row_sweep_var,
          [self._train_op],
          self._num_rows,
          self._num_cols,
          self._input_row_indices_ph,
          self._input_col_indices_ph,
          self._row_prep_ops,
          self._col_prep_ops,
          self._init_ops,
          completed_sweeps_var)
      mon_sess = monitored_session._HookedSession(sess, [sweep_hook])
      sess.run([variables.global_variables_initializer()])

      # Init ops should run before the first run. Row sweep not completed.
      mon_sess.run(self._train_op, ind_feed([0, 1, 2], []))
      self.assertTrue(sess.run(self._init_done),
                      msg='init ops not run by the sweep_hook')
      self.assertTrue(sess.run(self._row_prep_done),
                      msg='row_prep not run by the sweep_hook')
      self.assertTrue(sess.run(is_row_sweep_var),
                      msg='Row sweep is not complete but is_row_sweep is '
                      'False.')
      # Row sweep completed.
      mon_sess.run(self._train_op, ind_feed([3, 4], [0, 1, 2, 3, 4, 5, 6]))
      # assertEqual (rather than assertTrue on `==`) reports the actual
      # counter value when the check fails.
      self.assertEqual(sess.run(completed_sweeps_var), 1,
                       msg='Completed sweeps should be equal to 1.')
      self.assertTrue(sess.run(sweep_hook._is_sweep_done_var),
                      msg='Sweep is complete but is_sweep_done is False.')
      # Col init ops should run. Col sweep not completed.
      mon_sess.run(self._train_op, ind_feed([], [0, 1, 2, 3, 4]))
      self.assertTrue(sess.run(self._col_prep_done),
                      msg='col_prep not run by the sweep_hook')
      self.assertFalse(sess.run(is_row_sweep_var),
                       msg='Col sweep is not complete but is_row_sweep is '
                       'True.')
      self.assertFalse(sess.run(sweep_hook._is_sweep_done_var),
                       msg='Sweep is not complete but is_sweep_done is True.')
      # Col sweep completed.
      mon_sess.run(self._train_op, ind_feed([], [4, 5, 6]))
      self.assertTrue(sess.run(sweep_hook._is_sweep_done_var),
                      msg='Sweep is complete but is_sweep_done is False.')
      self.assertEqual(sess.run(completed_sweeps_var), 2,
                       msg='Completed sweeps should be equal to 2.')
예제 #6
0
  def test_sweeps(self):
    """Verifies the ops `_SweepHook` runs and that it alternates sweeps.

    Every init/prep/train op assigns True to its own boolean variable so the
    test can read back whether it was run. Between training steps the sweep
    is marked as done by hand, and `is_row_sweep` is expected to toggle
    row -> col -> row.
    """
    row_sweep_flag = variables.Variable(True)
    sweep_done_flag = variables.Variable(False)
    init_flag = variables.Variable(False)
    row_prep_flag = variables.Variable(False)
    col_prep_flag = variables.Variable(False)
    row_train_flag = variables.Variable(False)
    col_train_flag = variables.Variable(False)

    set_init = state_ops.assign(init_flag, True)
    set_row_prep = state_ops.assign(row_prep_flag, True)
    set_col_prep = state_ops.assign(col_prep_flag, True)
    set_row_train = state_ops.assign(row_train_flag, True)
    set_col_train = state_ops.assign(col_train_flag, True)
    step_op = control_flow_ops.no_op()

    # Switching clears the done flag and inverts the sweep direction.
    clear_done = state_ops.assign(sweep_done_flag, False)
    flip_direction = state_ops.assign(
        row_sweep_flag, math_ops.logical_not(row_sweep_flag))
    switch = control_flow_ops.group(clear_done, flip_direction)
    finish_sweep = state_ops.assign(sweep_done_flag, True)

    with self.test_session() as session:
      hook = wals_lib._SweepHook(
          row_sweep_flag, sweep_done_flag, set_init,
          [set_row_prep], [set_col_prep],
          set_row_train, set_col_train, switch)
      hooked = monitored_session._HookedSession(session, [hook])
      session.run([variables.global_variables_initializer()])

      def value_of(flag):
        # Read a flag through the raw session, bypassing the hook.
        return session.run(flag)

      # Row sweep.
      hooked.run(step_op)
      self.assertTrue(
          value_of(init_flag), msg='init op not run by the Sweephook')
      self.assertTrue(
          value_of(row_prep_flag),
          msg='row_prep_op not run by the SweepHook')
      self.assertTrue(
          value_of(row_train_flag),
          msg='row_train_op not run by the SweepHook')
      self.assertTrue(
          value_of(row_sweep_flag),
          msg='Row sweep is not complete but is_row_sweep_var is False.')

      # Col sweep.
      hooked.run(finish_sweep)
      hooked.run(step_op)
      self.assertTrue(
          value_of(col_prep_flag),
          msg='col_prep_op not run by the SweepHook')
      self.assertTrue(
          value_of(col_train_flag),
          msg='col_train_op not run by the SweepHook')
      self.assertFalse(
          value_of(row_sweep_flag),
          msg='Col sweep is not complete but is_row_sweep_var is True.')

      # Row sweep.
      hooked.run(finish_sweep)
      hooked.run(step_op)
      self.assertTrue(
          value_of(row_sweep_flag),
          msg='Col sweep is complete but is_row_sweep_var is False.')