Example 1
def test_add_measurement():
  """Tests that measurements can be added to a MetricCallback
  object properly """

  static_state = {}

  test_obj = m.MetricCallback(static_state)

  assert len(test_obj.measurements) == 0

  INTERVAL = 10
  MEASURED_VALUE = 0
  MEASUREMENT_NAME = 'L2'

  test_obj.add_measurement({
      'name': MEASUREMENT_NAME,
      'trigger': lambda step: step % INTERVAL == 0,
      'function': lambda state: MEASURED_VALUE
  })

  assert len(test_obj.measurements) == 1
  assert MEASUREMENT_NAME in test_obj.measurements.keys()
  for i in range(3):
    assert test_obj.measurements[MEASUREMENT_NAME]['trigger'](i * INTERVAL)
  assert test_obj.measurements[MEASUREMENT_NAME]['function'](
      static_state) == MEASURED_VALUE
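The class under test is not shown in these examples, so as a mental model, here is a minimal sketch of a MetricCallback that would satisfy the assertions above. The constructor argument, the shape of the measurements dict, the state merge, and the return value of measure are all inferred from the tests, not taken from the real implementation.

class MetricCallback:
  """Minimal sketch inferred from the tests above; not the actual implementation."""

  def __init__(self, static_state):
    self.static_state = static_state
    self.measurements = {}  # name -> {'trigger': fn, 'function': fn}

  def add_measurement(self, spec):
    # Register the measurement under its 'name' key.
    self.measurements[spec['name']] = {
        'trigger': spec['trigger'],
        'function': spec['function']
    }

  def measure(self, step, dynamic_state, measurement_list=None):
    # Evaluate every measurement whose trigger fires at this step, plus any
    # measurement explicitly forced via measurement_list, on the merged state.
    state = {**self.static_state, **dynamic_state}
    results = {}
    for name, spec in self.measurements.items():
      forced = measurement_list is not None and name in measurement_list
      if forced or spec['trigger'](step):
        results[name] = spec['function'](state)
    return results or None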
Example 2
def test_setup():
  """Tests whether the MetricCallback() object is setup
  without error"""

  static_state = {}

  test_obj = m.MetricCallback(static_state)
Example 3
def test_force_measure():
  """ Tests that the measurement manager handles forced measurements, i.e.
  measurements provided as a measurement_list argument, properly """

  data_store = {}
  reporter = r.MemoryReporter(data_store).stepped()

  STATIC1 = 3.14
  MEASURE_STEP = 4

  static_state = {'STATIC1': STATIC1}
  test_obj = m.MetricCallback(static_state)

  test_obj.add_measurement({
      'name': "STATIC",
      'trigger': lambda x: False,
      'function': lambda x: x['STATIC1']
  })
  test_obj.add_measurement({
      'name': "NOTMEASURED",
      'trigger': lambda x: False,
      'function': lambda x: 10.
  })

  step_measurements = test_obj.measure(MEASURE_STEP, {}, ['STATIC'])
  if step_measurements is not None:
    reporter.report_all(MEASURE_STEP, step_measurements)

  assert data_store == {'STATIC': [{'step': MEASURE_STEP, 'value': STATIC1}]}
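The reporter is likewise only exercised through its interface here. Judging purely from the data_store assertion above, a stepped memory reporter appends one {'step': ..., 'value': ...} record per metric name; a hypothetical stand-in for r.MemoryReporter(data_store).stepped() could look like the sketch below (class name and internals are assumptions).

class SteppedMemoryReporterSketch:
  """Hypothetical stand-in inferred from the asserted shape of data_store."""

  def __init__(self, data_store):
    self.data_store = data_store

  def report_all(self, step, measurements):
    # Append one {'step': ..., 'value': ...} record per reported metric.
    for name, value in measurements.items():
      self.data_store.setdefault(name, []).append({'step': step, 'value': value})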
Example 4
def test_measurement_intervals():
  """Tests that measurements occur at the specified intervals"""

  static_state = {}

  data_store = {}
  reporter = r.MemoryReporter(data_store).stepped()
  test_obj = m.MetricCallback(static_state)

  # log the value 0 at these intervals:
  test_intervals = [10, 13, 17, 20]
  MEASURED_VALUE = 0

  def build_trigger_function(n):

    def trigger_function(step):
      return step % n == 0

    return trigger_function

  for interval in test_intervals:
    test_obj.add_measurement({
        'name': f'Int_{interval}',
        'trigger': build_trigger_function(interval),
        'function': lambda state: MEASURED_VALUE
    })

  TEST_STEPS = 100
  for step in range(TEST_STEPS):
    step_measurements = test_obj.measure(step, {})
    if step_measurements is not None:
      reporter.report_all(step, step_measurements)

  # check data_store
  assert data_store == {
      f'Int_{k}': [{
          'step': s,
          'value': MEASURED_VALUE
      } for s in range(0, TEST_STEPS, k)] for k in test_intervals
  }
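Note that every series starts at step 0, because 0 % n == 0 for any interval n. For concreteness, the expected records for the interval-13 measurement expand to:

# Expected records for the interval-13 series (the comprehension above with k = 13):
expected_int_13 = [{'step': s, 'value': 0} for s in range(0, 100, 13)]
# i.e. records at steps 0, 13, 26, 39, 52, 65, 78 and 91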
Example 5
def test_state_usage():
  """ Tests that the measurements are made on the proper state,
  i.e. the static and dynamic states are used appropriately """
  # parameters of the test
  VALUE1 = 2.
  VALUE2 = 4.
  INTERVAL = 3
  TEST_STEPS = 5

  static_state = {'Value1': VALUE1, 'Value2': VALUE2}

  data_store = {}
  reporter = r.MemoryReporter(data_store).stepped()
  test_obj = m.MetricCallback(static_state)

  test_obj.add_measurement({
      'name': 'Static1',
      'trigger': lambda step: step % INTERVAL == 0,
      'function': lambda x: x['Value1']
  })
  test_obj.add_measurement({
      'name': 'Static2',
      'trigger': lambda step: step % INTERVAL == 0,
      'function': lambda x: x['Value2']
  })
  test_obj.add_measurement({
      'name': 'StaticSum',
      'trigger': lambda step: step % INTERVAL == 0,
      'function': lambda x: x['Value2'] + x['Value1']
  })
  test_obj.add_measurement({
      'name': 'DynamicSum',
      'trigger': lambda step: step % INTERVAL == 0,
      'function': lambda x: x['Value1'] + x['Value3']
  })
  test_obj.add_measurement({
      'name': 'Dynamic3',
      'trigger': lambda step: step % INTERVAL == 0,
      'function': lambda x: x['Value3']
  })

  for step in range(TEST_STEPS):
    dynamic_state = {'Value3': step * step}
    step_measurements = test_obj.measure(step, dynamic_state)
    if step_measurements is not None:
      reporter.report_all(step, step_measurements)

  # build the desired result to check against
  steps_measured = range(0, TEST_STEPS, INTERVAL)

  assert data_store == {
      'Static1': [{
          'step': step,
          'value': VALUE1
      } for step in steps_measured],
      'Static2': [{
          'step': step,
          'value': VALUE2
      } for step in steps_measured],
      'StaticSum': [{
          'step': step,
          'value': VALUE1 + VALUE2
      } for step in steps_measured],
      'DynamicSum': [{
          'step': step,
          'value': VALUE1 + step * step
      } for step in steps_measured],
      'Dynamic3': [{
          'step': step,
          'value': step * step
      } for step in steps_measured]
  }
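The 'DynamicSum' and 'Dynamic3' expectations only hold if each measurement function sees a state containing both the static keys ('Value1', 'Value2') and the per-step dynamic key ('Value3'). A plausible merge, assuming plain dict semantics (an inference from this test, not the shown implementation), would be:

# Hypothetical merge inside measure(); dynamic keys would win on a name collision.
state = {**static_state, **dynamic_state}
# e.g. at step 3: {'Value1': 2., 'Value2': 4., 'Value3': 9}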
Example 6
def main(_):
    """Builds and trains a sentiment classification RNN."""

    # prevent tf from accessing GPU
    tf.config.experimental.set_visible_devices([], "GPU")

    # Get and save config
    config = argparser.parse_args('main')
    logging.info(json.dumps(config, indent=2))

    with uv.start_run(
            experiment_name=config['save']['mlflow_expname'],
            run_name=config['save']['mlflow_runname']), uv.active_reporter(
                MLFlowReporter()):

        reporters.save_config(config)

        uv.report_params(reporters.flatten(config))

        prng_key = random.PRNGKey(config['run']['seed'])

        # Load data.
        vocab_size, train_dset, test_dset = data.get_dataset(config['data'])

        # Build network.
        cell = model_utils.get_cell(config['model']['cell_type'],
                                    num_units=config['model']['num_units'])

        init_fun, apply_fun, _, _ = network.build_rnn(
            vocab_size, config['model']['emb_size'], cell,
            config['model']['num_outputs'])

        loss_fun, acc_fun = optim_utils.loss_and_accuracy(
            apply_fun, config['model'], config['optim'])

        _, initial_params = init_fun(
            prng_key,
            (config['data']['batch_size'], config['data']['max_pad']))

        initial_params = model_utils.initialize(initial_params,
                                                config['model'])

        # get optimizer
        opt, get_params, opt_state, step_fun = optim_utils.optimization_suite(
            initial_params, loss_fun, config['optim'])

        ## Scope setup
        # Reporter setup
        data_store = {}
        reporter = reporters.build_reporters(config['save'], data_store)
        # Static state for scope
        static_state = {
            'acc_fun': acc_fun,
            'loss_fun': loss_fun,
            'param_extractor': get_params,
            'test_set': test_dset
        }

        oscilloscope = m.MetricCallback(static_state)

        def interval_trigger(interval):
            def function_to_return(x):
                return x % interval == 0

            return function_to_return

        oscilloscope.add_measurement({
            'name': 'test_acc',
            'trigger': interval_trigger(config['save']['measure_test']),
            'function': measurements.measure_test_acc
        })
        oscilloscope.add_measurement({
            'name': 'shuffled_test_acc',
            'trigger': interval_trigger(config['save']['measure_test']),
            'function': measurements.measure_shuffled_acc
        })
        oscilloscope.add_measurement({
            'name': 'train_acc',
            'trigger': interval_trigger(config['save']['measure_train']),
            'function': measurements.measure_batch_acc
        })
        oscilloscope.add_measurement({
            'name': 'train_loss',
            'trigger': interval_trigger(config['save']['measure_train']),
            'function': measurements.measure_batch_loss
        })
        oscilloscope.add_measurement({
            'name': 'l2_norm',
            'trigger': interval_trigger(config['save']['measure_test']),
            'function': measurements.measure_l2_norm
        })
        # Train
        global_step = 0
        loss = np.nan
        for epoch in range(config['optim']['num_epochs']):

            for batch_num, batch in enumerate(tfds.as_numpy(train_dset)):
                dynamic_state = {
                    'opt_state': opt_state,
                    'batch_train_loss': loss,
                    'batch': batch
                }

                step_measurements = oscilloscope.measure(
                    int(global_step), dynamic_state)
                if step_measurements is not None:
                    reporter.report_all(int(global_step), step_measurements)

                global_step, opt_state, loss = step_fun(
                    global_step, opt_state, batch)

                if global_step % config['save']['checkpoint_interval'] == 0:
                    params = get_params(opt_state)
                    np_params = np.asarray(params, dtype=object)
                    reporters.save_dict(config, np_params,
                                        f'checkpoint_{global_step}')

        final_measurements = oscilloscope.measure(
            int(global_step),
            dynamic_state,
            measurement_list=['test_acc', 'shuffled_test_acc'])
        reporter.report_all(int(global_step), final_measurements)

        final_params = {
            'params': np.asarray(get_params(opt_state), dtype=object)
        }
        reporters.save_dict(config, final_params, 'final_params')
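Stripped of the model-specific setup, the training script uses MetricCallback exactly as the unit tests do: build it around static state, register interval-triggered measurements, then call measure and report_all once per step. A condensed sketch of that pattern, assuming the interfaces shown above (measure_every and batches are placeholders, not names from the script):

# Condensed usage pattern, assuming the interfaces shown above.
oscilloscope = m.MetricCallback(static_state)
oscilloscope.add_measurement({
    'name': 'train_loss',
    'trigger': interval_trigger(measure_every),   # placeholder interval
    'function': measurements.measure_batch_loss
})

for step, batch in enumerate(batches):            # placeholder iterable
  step_measurements = oscilloscope.measure(step, {'batch': batch})
  if step_measurements is not None:
    reporter.report_all(step, step_measurements)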