for subfield in field:
      set_default_reference(
          child_config=child_config,
          field=subfield,
          parent_config=parent_config,
          parent_field=parent_field)
  else:
    if parent_field is None:
      parent_field = field
    child_config[field] = make_reference(parent_config, parent_field)


# Functions returning placeholders (marked with the _ph suffix) are a device
# to increase code readability in this file. Their intent is to reduce a large
# amount of repetition and to get the type closer to the colon.
# Plain `def`s instead of lambda assignments (PEP 8 E731): call sites are
# unchanged, but each helper now has a real __name__ and a docstring.
def float_ph():
  """Returns a float-typed config_dict placeholder."""
  return config_dict.placeholder(float)


def int_ph():
  """Returns an int-typed config_dict placeholder."""
  return config_dict.placeholder(int)


def str_ph():
  """Returns a str-typed config_dict placeholder."""
  return config_dict.placeholder(str)


def bool_ph():
  """Returns a bool-typed config_dict placeholder."""
  return config_dict.placeholder(bool)


def get_dense_config(
    parent_config):
  """Creates a ConfigDict corresponding to aqt.flax_layers.DenseAqt.HParams."""
  dense_fields = [
      "weight_prec", "weight_quant_granularity", "quant_type", "quant_act"
  ]
  dense_config = ml_collections.ConfigDict()
  # Each field defaults to a reference into the parent config.
  set_default_reference(dense_config, parent_config, dense_fields)
  dense_config.lock()
  return dense_config
示例#2
0
class FieldReferenceTest(parameterized.TestCase):
  """Tests for config_dict.FieldReference: operators, laziness, and typing."""

  def _test_binary_operator(self,
                            initial_value,
                            other_value,
                            op,
                            true_value,
                            new_initial_value,
                            new_true_value,
                            assert_fn=None):
    """Helper for testing binary operators.

    Generally speaking this checks that:
      1. `op(initial_value, other_value) COMP true_value`
      2. `op(new_initial_value, other_value) COMP new_true_value`
    where `COMP` is the comparison function defined by `assert_fn`.

    Args:
      initial_value: Initial value for the `FieldReference`, this is the first
        argument for the binary operator.
      other_value: The second argument for the binary operator.
      op: The binary operator.
      true_value: The expected output of the binary operator.
      new_initial_value: The value that the `FieldReference` is changed to.
      new_true_value: The expected output of the binary operator after the
        `FieldReference` has changed.
      assert_fn: Function used to check the output values.
    """
    if assert_fn is None:
      assert_fn = self.assertEqual

    # Direct FieldReference arithmetic.
    ref = config_dict.FieldReference(initial_value)
    new_ref = op(ref, other_value)
    assert_fn(new_ref.get(), true_value)

    # Same operator applied through a ConfigDict reference.
    config = config_dict.ConfigDict()
    config.a = initial_value
    config.b = other_value
    config.result = op(config.get_ref('a'), config.b)
    assert_fn(config.result, true_value)

    # Changing the referenced field must lazily update the result.
    config.a = new_initial_value
    assert_fn(config.result, new_true_value)

  def _test_unary_operator(self,
                           initial_value,
                           op,
                           true_value,
                           new_initial_value,
                           new_true_value,
                           assert_fn=None):
    """Helper for testing unary operators.

    Generally speaking this checks that:
      1. `op(initial_value) COMP true_value`
      2. `op(new_initial_value) COMP new_true_value`
    where `COMP` is the comparison function defined by `assert_fn`.

    Args:
      initial_value: Initial value for the `FieldReference`, this is the first
        argument for the unary operator.
      op: The unary operator.
      true_value: The expected output of the unary operator.
      new_initial_value: The value that the `FieldReference` is changed to.
      new_true_value: The expected output of the unary operator after the
        `FieldReference` has changed.
      assert_fn: Function used to check the output values.
    """
    if assert_fn is None:
      assert_fn = self.assertEqual

    # Direct FieldReference application.
    ref = config_dict.FieldReference(initial_value)
    new_ref = op(ref)
    assert_fn(new_ref.get(), true_value)

    # Same operator applied through a ConfigDict reference.
    config = config_dict.ConfigDict()
    config.a = initial_value
    config.result = op(config.get_ref('a'))
    assert_fn(config.result, true_value)

    # Changing the referenced field must lazily update the result.
    config.a = new_initial_value
    assert_fn(config.result, new_true_value)

  def testBasic(self):
    ref = config_dict.FieldReference(1)
    self.assertEqual(ref.get(), 1)

  def testGetRef(self):
    # References can be chained: c depends on b, which depends on a.
    config = config_dict.ConfigDict()
    config.a = 1.
    config.b = config.get_ref('a') + 10
    config.c = config.get_ref('b') + 10
    self.assertEqual(config.c, 21.0)

  def testFunction(self):

    def fn(x):
      return x + 5

    config = config_dict.ConfigDict()
    config.a = 1
    config.b = fn(config.get_ref('a'))
    config.c = fn(config.get_ref('b'))

    self.assertEqual(config.b, 6)
    self.assertEqual(config.c, 11)
    # Updating `a` propagates through both function results.
    config.a = 2
    self.assertEqual(config.b, 7)
    self.assertEqual(config.c, 12)

  def testCycles(self):
    config = config_dict.ConfigDict()
    config.a = 1.
    config.b = config.get_ref('a') + 10
    config.c = config.get_ref('b') + 10

    self.assertEqual(config.b, 11.0)
    self.assertEqual(config.c, 21.0)

    # Introduce a cycle
    with self.assertRaisesRegex(config_dict.MutabilityError, 'cycle'):
      config.a = config.get_ref('c') - 1.0

    # Introduce a cycle on second operand
    with self.assertRaisesRegex(config_dict.MutabilityError, 'cycle'):
      config.a = config_dict.FieldReference(5.0) + config.get_ref('c')

    # We can create multiple FieldReferences that all point to the same object
    l = [0]
    config = config_dict.ConfigDict()
    config.a = l
    config.b = l
    config.c = config.get_ref('a') + ['c']
    config.d = config.get_ref('b') + ['d']

    self.assertEqual(config.c, [0, 'c'])
    self.assertEqual(config.d, [0, 'd'])

    # Make sure nothing was mutated
    self.assertEqual(l, [0])
    self.assertEqual(config.c, [0, 'c'])

    config.a = [1]
    config.b = [2]
    self.assertEqual(l, [0])
    self.assertEqual(config.c, [1, 'c'])
    self.assertEqual(config.d, [2, 'd'])

  @parameterized.parameters(
      {
          'initial_value': 1,
          'other_value': 2,
          'true_value': 3,
          'new_initial_value': 10,
          'new_true_value': 12
      }, {
          'initial_value': 2.0,
          'other_value': 2.5,
          'true_value': 4.5,
          'new_initial_value': 3.7,
          'new_true_value': 6.2
      }, {
          'initial_value': 'hello, ',
          'other_value': 'world!',
          'true_value': 'hello, world!',
          'new_initial_value': 'foo, ',
          'new_true_value': 'foo, world!'
      }, {
          'initial_value': ['hello'],
          'other_value': ['world'],
          'true_value': ['hello', 'world'],
          'new_initial_value': ['foo'],
          'new_true_value': ['foo', 'world']
      }, {
          'initial_value': config_dict.FieldReference(10),
          'other_value': config_dict.FieldReference(5.0),
          'true_value': 15.0,
          'new_initial_value': 12,
          'new_true_value': 17.0
      }, {
          'initial_value': config_dict.placeholder(float),
          'other_value': 7.0,
          'true_value': None,
          'new_initial_value': 12,
          'new_true_value': 19.0
      }, {
          'initial_value': 5.0,
          'other_value': config_dict.placeholder(float),
          'true_value': None,
          'new_initial_value': 8.0,
          'new_true_value': None
      }, {
          'initial_value': config_dict.placeholder(str),
          'other_value': 'tail',
          'true_value': None,
          'new_initial_value': 'head',
          'new_true_value': 'headtail'
      })
  def testAdd(self, initial_value, other_value, true_value, new_initial_value,
              new_true_value):
    self._test_binary_operator(initial_value, other_value, operator.add,
                               true_value, new_initial_value, new_true_value)

  @parameterized.parameters(
      {
          'initial_value': 5,
          'other_value': 3,
          'true_value': 2,
          'new_initial_value': -1,
          'new_true_value': -4
      }, {
          'initial_value': 2.0,
          'other_value': 2.5,
          'true_value': -0.5,
          'new_initial_value': 12.3,
          'new_true_value': 9.8
      }, {
          'initial_value': set(['hello', 123, 4.5]),
          'other_value': set([123]),
          'true_value': set(['hello', 4.5]),
          'new_initial_value': set([123]),
          'new_true_value': set([])
      }, {
          'initial_value': config_dict.FieldReference(10),
          'other_value': config_dict.FieldReference(5.0),
          'true_value': 5.0,
          'new_initial_value': 12,
          'new_true_value': 7.0
      }, {
          'initial_value': config_dict.placeholder(float),
          'other_value': 7.0,
          'true_value': None,
          'new_initial_value': 12,
          'new_true_value': 5.0
      })
  def testSub(self, initial_value, other_value, true_value, new_initial_value,
              new_true_value):
    self._test_binary_operator(initial_value, other_value, operator.sub,
                               true_value, new_initial_value, new_true_value)

  @parameterized.parameters(
      {
          'initial_value': 1,
          'other_value': 2,
          'true_value': 2,
          'new_initial_value': 3,
          'new_true_value': 6
      }, {
          'initial_value': 2.0,
          'other_value': 2.5,
          'true_value': 5.0,
          'new_initial_value': 3.5,
          'new_true_value': 8.75
      }, {
          'initial_value': ['hello'],
          'other_value': 3,
          'true_value': ['hello', 'hello', 'hello'],
          'new_initial_value': ['foo'],
          'new_true_value': ['foo', 'foo', 'foo']
      }, {
          'initial_value': config_dict.FieldReference(10),
          'other_value': config_dict.FieldReference(5.0),
          'true_value': 50.0,
          'new_initial_value': 1,
          'new_true_value': 5.0
      }, {
          'initial_value': config_dict.placeholder(float),
          'other_value': 7.0,
          'true_value': None,
          'new_initial_value': 12,
          'new_true_value': 84.0
      })
  def testMul(self, initial_value, other_value, true_value, new_initial_value,
              new_true_value):
    self._test_binary_operator(initial_value, other_value, operator.mul,
                               true_value, new_initial_value, new_true_value)

  @parameterized.parameters(
      {
          'initial_value': 3,
          'other_value': 2,
          'true_value': 1.5,
          'new_initial_value': 10,
          'new_true_value': 5.0
      }, {
          'initial_value': 2.0,
          'other_value': 2.5,
          'true_value': 0.8,
          'new_initial_value': 6.3,
          'new_true_value': 2.52
      }, {
          'initial_value': config_dict.FieldReference(10),
          'other_value': config_dict.FieldReference(5.0),
          'true_value': 2.0,
          'new_initial_value': 13,
          'new_true_value': 2.6
      }, {
          'initial_value': config_dict.placeholder(float),
          'other_value': 7.0,
          'true_value': None,
          'new_initial_value': 17.5,
          'new_true_value': 2.5
      })
  def testTrueDiv(self, initial_value, other_value, true_value,
                  new_initial_value, new_true_value):
    self._test_binary_operator(initial_value, other_value, operator.truediv,
                               true_value, new_initial_value, new_true_value)

  @parameterized.parameters(
      {
          'initial_value': 3,
          'other_value': 2,
          'true_value': 1,
          'new_initial_value': 7,
          'new_true_value': 3
      }, {
          'initial_value': config_dict.FieldReference(10),
          'other_value': config_dict.FieldReference(5),
          'true_value': 2,
          'new_initial_value': 28,
          'new_true_value': 5
      }, {
          'initial_value': config_dict.placeholder(int),
          'other_value': 7,
          'true_value': None,
          'new_initial_value': 25,
          'new_true_value': 3
      })
  def testFloorDiv(self, initial_value, other_value, true_value,
                   new_initial_value, new_true_value):
    self._test_binary_operator(initial_value, other_value, operator.floordiv,
                               true_value, new_initial_value, new_true_value)

  @parameterized.parameters(
      {
          'initial_value': 3,
          'other_value': 2,
          'true_value': 9,
          'new_initial_value': 10,
          'new_true_value': 100
      }, {
          'initial_value': 2.7,
          'other_value': 3.2,
          'true_value': 24.0084457245,
          'new_initial_value': 6.5,
          'new_true_value': 399.321543621
      }, {
          'initial_value': config_dict.FieldReference(10),
          'other_value': config_dict.FieldReference(5),
          'true_value': 1e5,
          'new_initial_value': 2,
          'new_true_value': 32
      }, {
          'initial_value': config_dict.placeholder(float),
          'other_value': 3.0,
          'true_value': None,
          'new_initial_value': 7.0,
          'new_true_value': 343.0
      })
  def testPow(self, initial_value, other_value, true_value, new_initial_value,
              new_true_value):
    # Float powers are inexact, so compare with assertAlmostEqual.
    self._test_binary_operator(
        initial_value,
        other_value,
        operator.pow,
        true_value,
        new_initial_value,
        new_true_value,
        assert_fn=self.assertAlmostEqual)

  @parameterized.parameters(
      {
          'initial_value': 3,
          'other_value': 2,
          'true_value': 1,
          'new_initial_value': 10,
          'new_true_value': 0
      }, {
          'initial_value': 5.3,
          'other_value': 3.2,
          'true_value': 2.0999999999999996,
          'new_initial_value': 77,
          'new_true_value': 0.2
      }, {
          'initial_value': config_dict.FieldReference(10),
          'other_value': config_dict.FieldReference(5),
          'true_value': 0,
          'new_initial_value': 32,
          'new_true_value': 2
      }, {
          'initial_value': config_dict.placeholder(int),
          'other_value': 7,
          'true_value': None,
          'new_initial_value': 25,
          'new_true_value': 4
      })
  def testMod(self, initial_value, other_value, true_value, new_initial_value,
              new_true_value):
    # Float modulo is inexact, so compare with assertAlmostEqual.
    self._test_binary_operator(
        initial_value,
        other_value,
        operator.mod,
        true_value,
        new_initial_value,
        new_true_value,
        assert_fn=self.assertAlmostEqual)

  @parameterized.parameters(
      {
          'initial_value': True,
          'other_value': True,
          'true_value': True,
          'new_initial_value': False,
          'new_true_value': False
      }, {
          'initial_value': config_dict.FieldReference(False),
          'other_value': config_dict.FieldReference(False),
          'true_value': False,
          'new_initial_value': True,
          'new_true_value': False
      }, {
          'initial_value': config_dict.placeholder(bool),
          'other_value': True,
          'true_value': None,
          'new_initial_value': False,
          'new_true_value': False
      })
  def testAnd(self, initial_value, other_value, true_value, new_initial_value,
              new_true_value):
    self._test_binary_operator(initial_value, other_value, operator.and_,
                               true_value, new_initial_value, new_true_value)

  @parameterized.parameters(
      {
          'initial_value': False,
          'other_value': False,
          'true_value': False,
          'new_initial_value': True,
          'new_true_value': True
      }, {
          'initial_value': config_dict.FieldReference(True),
          'other_value': config_dict.FieldReference(True),
          'true_value': True,
          'new_initial_value': False,
          'new_true_value': True
      }, {
          'initial_value': config_dict.placeholder(bool),
          'other_value': False,
          'true_value': None,
          'new_initial_value': True,
          'new_true_value': True
      })
  def testOr(self, initial_value, other_value, true_value, new_initial_value,
             new_true_value):
    self._test_binary_operator(initial_value, other_value, operator.or_,
                               true_value, new_initial_value, new_true_value)

  @parameterized.parameters(
      {
          'initial_value': False,
          'other_value': True,
          'true_value': True,
          'new_initial_value': True,
          'new_true_value': False
      }, {
          'initial_value': config_dict.FieldReference(True),
          'other_value': config_dict.FieldReference(True),
          'true_value': False,
          'new_initial_value': False,
          'new_true_value': True
      }, {
          'initial_value': config_dict.placeholder(bool),
          'other_value': True,
          'true_value': None,
          'new_initial_value': True,
          'new_true_value': False
      })
  def testXor(self, initial_value, other_value, true_value, new_initial_value,
              new_true_value):
    self._test_binary_operator(initial_value, other_value, operator.xor,
                               true_value, new_initial_value, new_true_value)

  @parameterized.parameters(
      {
          'initial_value': 3,
          'true_value': -3,
          'new_initial_value': -22,
          'new_true_value': 22
      }, {
          'initial_value': 15.3,
          'true_value': -15.3,
          'new_initial_value': -0.2,
          'new_true_value': 0.2
      }, {
          'initial_value': config_dict.FieldReference(7),
          'true_value': config_dict.FieldReference(-7),
          'new_initial_value': 123,
          'new_true_value': -123
      }, {
          'initial_value': config_dict.placeholder(int),
          'true_value': None,
          'new_initial_value': -6,
          'new_true_value': 6
      })
  def testNeg(self, initial_value, true_value, new_initial_value,
              new_true_value):
    self._test_unary_operator(initial_value, operator.neg, true_value,
                              new_initial_value, new_true_value)

  @parameterized.parameters(
      {
          'initial_value': config_dict.create(attribute=2),
          'true_value': 2,
          'new_initial_value': config_dict.create(attribute=3),
          'new_true_value': 3,
      },
      {
          'initial_value': config_dict.create(attribute={'a': 1}),
          'true_value': config_dict.create(a=1),
          'new_initial_value': config_dict.create(attribute={'b': 1}),
          'new_true_value': config_dict.create(b=1),
      },
      {
          'initial_value':
              config_dict.FieldReference(config_dict.create(attribute=2)),
          'true_value':
              config_dict.FieldReference(2),
          'new_initial_value':
              config_dict.create(attribute=3),
          'new_true_value':
              3,
      },
      {
          'initial_value': config_dict.placeholder(config_dict.ConfigDict),
          'true_value': None,
          'new_initial_value': config_dict.create(attribute=3),
          'new_true_value': 3,
      },
  )
  def testAttr(self, initial_value, true_value, new_initial_value,
               new_true_value):
    self._test_unary_operator(initial_value, lambda x: x.attr('attribute'),
                              true_value, new_initial_value, new_true_value)

  @parameterized.parameters(
      {
          'initial_value': 3,
          'true_value': 3,
          'new_initial_value': -101,
          'new_true_value': 101
      }, {
          'initial_value': -15.3,
          'true_value': 15.3,
          'new_initial_value': 7.3,
          'new_true_value': 7.3
      }, {
          'initial_value': config_dict.FieldReference(-7),
          'true_value': config_dict.FieldReference(7),
          'new_initial_value': 3,
          'new_true_value': 3
      }, {
          'initial_value': config_dict.placeholder(float),
          'true_value': None,
          'new_initial_value': -6.25,
          'new_true_value': 6.25
      })
  def testAbs(self, initial_value, true_value, new_initial_value,
              new_true_value):
    self._test_unary_operator(initial_value, operator.abs, true_value,
                              new_initial_value, new_true_value)

  def testToInt(self):
    self._test_unary_operator(25.3, lambda ref: ref.to_int(), 25, 27.9, 27)
    # to_int() also changes the reference's declared field type.
    ref = config_dict.FieldReference(64.7)
    ref = ref.to_int()
    self.assertEqual(ref.get(), 64)
    self.assertEqual(ref._field_type, int)

  def testToFloat(self):
    self._test_unary_operator(12, lambda ref: ref.to_float(), 12.0, 0, 0.0)

    # to_float() also changes the reference's declared field type.
    ref = config_dict.FieldReference(647)
    ref = ref.to_float()
    self.assertEqual(ref.get(), 647.0)
    self.assertEqual(ref._field_type, float)

  def testToString(self):
    self._test_unary_operator(12, lambda ref: ref.to_str(), '12', 0, '0')

    # to_str() also changes the reference's declared field type.
    ref = config_dict.FieldReference(647)
    ref = ref.to_str()
    self.assertEqual(ref.get(), '647')
    self.assertEqual(ref._field_type, str)

  def testSetValue(self):
    ref = config_dict.FieldReference(1.0)
    other = config_dict.FieldReference(3)
    ref_plus_other = ref + other

    self.assertEqual(ref_plus_other.get(), 4.0)

    ref.set(2.5)
    self.assertEqual(ref_plus_other.get(), 5.5)

    other.set(110)
    self.assertEqual(ref_plus_other.get(), 112.5)

    # Type checking
    with self.assertRaises(TypeError):
      other.set('this is a string')

    with self.assertRaises(TypeError):
      other.set(config_dict.FieldReference('this is a string'))

    with self.assertRaises(TypeError):
      other.set(config_dict.FieldReference(None, field_type=str))

  def testSetResult(self):
    ref = config_dict.FieldReference(1.0)
    result = ref + 1.0
    second_result = result + 1.0

    self.assertEqual(ref.get(), 1.0)
    self.assertEqual(result.get(), 2.0)
    self.assertEqual(second_result.get(), 3.0)

    ref.set(2.0)
    self.assertEqual(ref.get(), 2.0)
    self.assertEqual(result.get(), 3.0)
    self.assertEqual(second_result.get(), 4.0)

    # Setting an intermediate result detaches it from its operands.
    result.set(4.0)
    self.assertEqual(ref.get(), 2.0)
    self.assertEqual(result.get(), 4.0)
    self.assertEqual(second_result.get(), 5.0)

    # All references are broken at this point.
    ref.set(1.0)
    self.assertEqual(ref.get(), 1.0)
    self.assertEqual(result.get(), 4.0)
    self.assertEqual(second_result.get(), 5.0)

  def testTypeChecking(self):
    ref = config_dict.FieldReference(1)
    string_ref = config_dict.FieldReference('a')

    # The type error surfaces lazily, on get(), not at operator time.
    x = ref + string_ref
    with self.assertRaises(TypeError):
      x.get()

  def testNoType(self):
    self.assertRaisesRegex(TypeError, 'field_type should be a type.*',
                           config_dict.FieldReference, None, 0)

  def testEqual(self):
    # Simple case
    ref1 = config_dict.FieldReference(1)
    ref2 = config_dict.FieldReference(1)
    ref3 = config_dict.FieldReference(2)
    self.assertEqual(ref1, 1)
    self.assertEqual(ref1, ref1)
    self.assertEqual(ref1, ref2)
    self.assertNotEqual(ref1, 2)
    self.assertNotEqual(ref1, ref3)

    # ConfigDict inside FieldReference
    ref1 = config_dict.FieldReference(config_dict.ConfigDict({'a': 1}))
    ref2 = config_dict.FieldReference(config_dict.ConfigDict({'a': 1}))
    ref3 = config_dict.FieldReference(config_dict.ConfigDict({'a': 2}))
    self.assertEqual(ref1, config_dict.ConfigDict({'a': 1}))
    self.assertEqual(ref1, ref1)
    self.assertEqual(ref1, ref2)
    self.assertNotEqual(ref1, config_dict.ConfigDict({'a': 2}))
    self.assertNotEqual(ref1, ref3)

  def testLessEqual(self):
    # Simple case
    ref1 = config_dict.FieldReference(1)
    ref2 = config_dict.FieldReference(1)
    ref3 = config_dict.FieldReference(2)
    self.assertLessEqual(ref1, 1)
    self.assertLessEqual(ref1, 2)
    self.assertLessEqual(0, ref1)
    self.assertLessEqual(1, ref1)
    self.assertGreater(ref1, 0)

    self.assertLessEqual(ref1, ref1)
    self.assertLessEqual(ref1, ref2)
    self.assertLessEqual(ref1, ref3)
    self.assertGreater(ref3, ref1)

  def testControlFlowError(self):
    # FieldReferences cannot be used in boolean contexts (truthiness is
    # undefined for a lazy reference), so these must raise.
    ref1 = config_dict.FieldReference(True)
    ref2 = config_dict.FieldReference(False)

    with self.assertRaises(NotImplementedError):
      if ref1:
        pass
    with self.assertRaises(NotImplementedError):
      _ = ref1 and ref2
    with self.assertRaises(NotImplementedError):
      _ = ref1 or ref2
    with self.assertRaises(NotImplementedError):
      _ = not ref1
示例#3
0
def get_config(debug: bool = False) -> config_dict.ConfigDict:
  """Get Jaxline experiment config.

  Args:
    debug: If True, shrinks the model and batch sizes for fast local runs.

  Returns:
    A locked-down Jaxline base config extended with experiment kwargs.
  """
  config = base_config.get_base_config()
  # E.g. '/data/pretrained_models/k0_seed100' (and set k_fold_split_id=0, below)
  config.restore_path = config_dict.placeholder(str)

  # Batch sizes used below to derive the dynamic batching budgets.
  training_batch_size = 64
  eval_batch_size = 64

  ## Experiment config.
  loss_config_name = 'RegressionLossConfig'
  loss_kwargs = dict(
      exponent=1.,  # 2 for l2 loss, 1 for l1 loss, etc...
  )

  dataset_config = dict(
      data_root=config_dict.placeholder(str),
      augment_with_random_mirror_symmetry=True,
      k_fold_split_id=config_dict.placeholder(int),
      num_k_fold_splits=config_dict.placeholder(int),
      # Options: "in" or "out".
      # Filter=in would keep the samples with nans in the conformer features.
      # Filter=out would keep the samples with no NaNs anywhere in the conformer
      # features.
      filter_in_or_out_samples_with_nans_in_conformers=(
          config_dict.placeholder(str)),
      cached_conformers_file=config_dict.placeholder(str))

  model_config = dict(
      mlp_hidden_size=512,
      mlp_layers=2,
      latent_size=512,
      use_layer_norm=False,
      num_message_passing_steps=32,
      shared_message_passing_weights=False,
      mask_padding_graph_at_every_step=True,
      loss_config_name=loss_config_name,
      loss_kwargs=loss_kwargs,
      processor_mode='resnet',
      global_reducer='sum',
      node_reducer='sum',
      dropedge_rate=0.1,
      dropnode_rate=0.1,
      aux_multiplier=0.1,
      add_relative_distance=True,
      add_relative_displacement=True,
      add_absolute_positions=False,
      position_normalization=2.,
      relative_displacement_normalization=1.,
      ignore_globals=False,
      ignore_globals_from_final_layer_for_predictions=True,
  )

  if debug:
    # Make network smaller.
    model_config.update(dict(
        mlp_hidden_size=32,
        mlp_layers=1,
        latent_size=32,
        num_message_passing_steps=1))

  config.experiment_kwargs = config_dict.ConfigDict(
      dict(
          config=dict(
              debug=debug,
              predictions_dir=config_dict.placeholder(str),
              ema=True,
              ema_decay=0.9999,
              sample_random=0.05,
              optimizer=dict(
                  name='adam',
                  optimizer_kwargs=dict(b1=.9, b2=.95),
                  lr_schedule=dict(
                      warmup_steps=int(5e4),
                      decay_steps=int(5e5),
                      init_value=1e-5,
                      peak_value=1e-4,
                      end_value=0.,
                  ),
              ),
              model=model_config,
              dataset_config=dataset_config,
              # As a rule of thumb, use the following statistics:
              # Avg. # nodes in graph: 16.
              # Avg. # edges in graph: 40.
              training=dict(
                  dynamic_batch_size={
                      'n_node': 256 if debug else 16 * training_batch_size,
                      'n_edge': 512 if debug else 40 * training_batch_size,
                      'n_graph': 2 if debug else training_batch_size,
                  },),
              evaluation=dict(
                  split='valid',
                  dynamic_batch_size=dict(
                      n_node=256 if debug else 16 * eval_batch_size,
                      n_edge=512 if debug else 40 * eval_batch_size,
                      n_graph=2 if debug else eval_batch_size,
                  )))))

  ## Training loop config.
  config.training_steps = int(5e6)
  config.checkpoint_dir = '/tmp/checkpoint/pcq/'
  config.train_checkpoint_all_hosts = False
  config.save_checkpoint_interval = 300
  config.log_train_data_interval = 60
  config.log_tensors_interval = 60
  config.best_model_eval_metric = 'mae'
  # Lower MAE is better.
  config.best_model_eval_metric_higher_is_better = False

  return config
示例#4
0
from sklearn.metrics import mean_squared_error

from data import load_train_test_splits
from definitions import ARTIFACT_DIR
from model_dispatcher import load_model
from utils import set_seed

# Configure experiment runner
FLAGS = flags.FLAGS
flags.DEFINE_bool('debug', False, "Show debugging information.")
flags.DEFINE_bool('log', False, "Log this experiment to wandb.")

# Configure experiment tracking
config_wandb = ml_collections.ConfigDict()
config_wandb.project = "hparam-src"
# job_type/notes default to None placeholders; NOTE(review): presumably meant
# to be overridden on the command line via --wandb.* flags — confirm.
config_wandb.job_type = placeholder(str)
config_wandb.notes = placeholder(str)
config_flags.DEFINE_config_dict(
    'wandb',
    config_wandb,
    "Configuration for W&B experiment tracking.",
)


def main(_):
    """Runs the experiment pipeline; `_` is the unused argv from absl.app."""
    if FLAGS.log:
        # Record the full flag state under the W&B-specific settings.
        wandb.init(config=FLAGS, **FLAGS.wandb)

    # Pipeline
    ## Setup
    set_seed()
示例#5
0
def get_config():
    """Builds a config with a defaulted reference and an int placeholder."""
    config = config_dict.ConfigDict()
    # `ref` starts at 123; `ref_nodefault` is None until explicitly set.
    config.ref = config_dict.FieldReference(123)
    config.ref_nodefault = config_dict.placeholder(int)
    return config
示例#6
0
def get_config(debug: bool = False) -> config_dict.ConfigDict:
    """Get Jaxline experiment config.

    Args:
        debug: If True, shrinks the model and batch sizes for fast local runs.

    Returns:
        A Jaxline base config extended with experiment kwargs.
    """
    config = base_config.get_base_config()
    config.random_seed = 42
    # E.g. '/data/pretrained_models/k0_seed100' (and set k_fold_split_id=0, below)
    config.restore_path = config_dict.placeholder(str)
    config.experiment_kwargs = config_dict.ConfigDict(
        dict(config=dict(
            debug=debug,
            predictions_dir=config_dict.placeholder(str),
            # 5 for model selection and early stopping, 50 for final eval.
            num_eval_iterations_to_ensemble=5,
            dataset_kwargs=dict(
                data_root='/data/',
                online_subsampling_kwargs=dict(
                    max_nb_neighbours_per_type=[
                        [[40, 20, 0, 40], [0, 0, 0, 0], [0, 0, 0, 0]],
                        [[40, 20, 0, 40], [40, 0, 10, 0], [0, 0, 0, 0]],
                    ],
                    remove_future_nodes=True,
                    deduplicate_nodes=True,
                ),
                ratio_unlabeled_data_to_labeled_data=10.0,
                k_fold_split_id=config_dict.placeholder(int),
                use_all_labels_when_not_training=False,
                use_dummy_adjacencies=debug,
            ),
            optimizer=dict(
                name='adamw',
                kwargs=dict(weight_decay=1e-5, b1=0.9, b2=0.999),
                learning_rate_schedule=dict(
                    use_schedule=True,
                    base_learning_rate=1e-2,
                    warmup_steps=50000,
                    # Lazy reference: tracks config.training_steps set below.
                    total_steps=config.get_ref('training_steps'),
                ),
            ),
            model_config=dict(
                mlp_hidden_sizes=[32] if debug else [512],
                latent_size=32 if debug else 256,
                num_message_passing_steps=2 if debug else 4,
                activation='relu',
                dropout_rate=0.3,
                dropedge_rate=0.25,
                disable_edge_updates=True,
                use_sent_edges=True,
                normalization_type='layer_norm',
                aggregation_function='sum',
            ),
            training=dict(
                loss_config=dict(bgrl_loss_config=dict(
                    stop_gradient_for_supervised_loss=False,
                    bgrl_loss_scale=1.0,
                    symmetrize=True,
                    first_graph_corruption_config=dict(
                        feature_drop_prob=0.4,
                        edge_drop_prob=0.2,
                    ),
                    second_graph_corruption_config=dict(
                        feature_drop_prob=0.4,
                        edge_drop_prob=0.2,
                    ),
                ), ),
                # GPU memory may require reducing the `256`s below to `48`.
                dynamic_batch_size_config=dict(
                    n_node=256 if debug else 340 * 256,
                    n_edge=512 if debug else 720 * 256,
                    n_graph=4 if debug else 256,
                ),
            ),
            eval=dict(
                split='valid',
                ema_annealing_schedule=dict(use_schedule=True,
                                            base_rate=0.999,
                                            # Lazy reference to training_steps.
                                            total_steps=config.get_ref(
                                                'training_steps')),
                dynamic_batch_size_config=dict(
                    n_node=256 if debug else 340 * 128,
                    n_edge=512 if debug else 720 * 128,
                    n_graph=4 if debug else 128,
                ),
            ))))

    ## Training loop config.
    config.training_steps = 500000
    config.checkpoint_dir = '/tmp/checkpoint/mag/'
    config.train_checkpoint_all_hosts = False
    config.log_train_data_interval = 10
    config.log_tensors_interval = 10
    config.save_checkpoint_interval = 30
    config.best_model_eval_metric = 'accuracy'
    config.best_model_eval_metric_higher_is_better = True

    return config