    # NOTE(review): this `for`/`else` is the tail of `set_default_reference`;
    # the enclosing `def` and the `if` that this `else` pairs with are outside
    # this chunk (presumably an `if isinstance(field, list):` — confirm against
    # the full file). Recurse over a list of fields...
    for subfield in field:
      set_default_reference(
          child_config=child_config,
          field=subfield,
          parent_config=parent_config,
          parent_field=parent_field)
  else:
    # ...or install a single reference, defaulting the parent field name to
    # the child field name when no explicit parent field was given.
    if parent_field is None:
      parent_field = field
    child_config[field] = make_reference(parent_config, parent_field)


# Functions returning placeholders are marked with the _ph suffix. They are a
# device to increase code readability in this file: they reduce a large
# amount of repetition and get the type closer to the colon.
float_ph = lambda: config_dict.placeholder(float)
int_ph = lambda: config_dict.placeholder(int)
str_ph = lambda: config_dict.placeholder(str)
bool_ph = lambda: config_dict.placeholder(bool)


def get_dense_config(
    parent_config):
  """Creates a ConfigDict corresponding to aqt.flax_layers.DenseAqt.HParams.

  Args:
    parent_config: ConfigDict whose fields the new config's fields reference.

  Returns:
    A locked ConfigDict whose fields track the parent's via references.
  """
  config = ml_collections.ConfigDict()
  # Each listed field becomes a reference to the same-named parent field.
  set_default_reference(
      config, parent_config,
      ["weight_prec", "weight_quant_granularity", "quant_type", "quant_act"])
  # Lock so that typos in field names raise instead of silently adding keys.
  config.lock()
  return config
class FieldReferenceTest(parameterized.TestCase):
  """Tests for config_dict.FieldReference.

  Covers lazy arithmetic/logical operators on references, placeholder
  (None-valued) references, type checking, equality/ordering, and the
  control-flow restrictions (references cannot be used as plain booleans).
  """

  def _test_binary_operator(self,
                            initial_value,
                            other_value,
                            op,
                            true_value,
                            new_initial_value,
                            new_true_value,
                            assert_fn=None):
    """Helper for testing binary operators.

    Generally speaking this checks that:
    1. `op(initial_value, other_value) COMP true_value`
    2. `op(new_initial_value, other_value) COMP new_true_value`

    where `COMP` is the comparison function defined by `assert_fn`.

    Args:
      initial_value: Initial value for the `FieldReference`, this is the first
        argument for the binary operator.
      other_value: The second argument for the binary operator.
      op: The binary operator.
      true_value: The expected output of the binary operator.
      new_initial_value: The value that the `FieldReference` is changed to.
      new_true_value: The expected output of the binary operator after the
        `FieldReference` has changed.
      assert_fn: Function used to check the output values.
    """
    if assert_fn is None:
      assert_fn = self.assertEqual

    # Direct FieldReference arithmetic.
    ref = config_dict.FieldReference(initial_value)
    new_ref = op(ref, other_value)
    assert_fn(new_ref.get(), true_value)

    # Same check through ConfigDict.get_ref: the result must track later
    # changes to the referenced field.
    config = config_dict.ConfigDict()
    config.a = initial_value
    config.b = other_value
    config.result = op(config.get_ref('a'), config.b)
    assert_fn(config.result, true_value)

    config.a = new_initial_value
    assert_fn(config.result, new_true_value)

  def _test_unary_operator(self,
                           initial_value,
                           op,
                           true_value,
                           new_initial_value,
                           new_true_value,
                           assert_fn=None):
    """Helper for testing unary operators.

    Generally speaking this checks that:
    1. `op(initial_value) COMP true_value`
    2. `op(new_initial_value) COMP new_true_value`

    where `COMP` is the comparison function defined by `assert_fn`.

    Args:
      initial_value: Initial value for the `FieldReference`, this is the first
        argument for the unary operator.
      op: The unary operator.
      true_value: The expected output of the unary operator.
      new_initial_value: The value that the `FieldReference` is changed to.
      new_true_value: The expected output of the unary operator after the
        `FieldReference` has changed.
      assert_fn: Function used to check the output values.
    """
    if assert_fn is None:
      assert_fn = self.assertEqual

    # Direct FieldReference application.
    ref = config_dict.FieldReference(initial_value)
    new_ref = op(ref)
    assert_fn(new_ref.get(), true_value)

    # Through get_ref: result must track later changes to the field.
    config = config_dict.ConfigDict()
    config.a = initial_value
    config.result = op(config.get_ref('a'))
    assert_fn(config.result, true_value)

    config.a = new_initial_value
    assert_fn(config.result, new_true_value)

  def testBasic(self):
    ref = config_dict.FieldReference(1)
    self.assertEqual(ref.get(), 1)

  def testGetRef(self):
    # References compose: c depends on b, which depends on a.
    config = config_dict.ConfigDict()
    config.a = 1.
    config.b = config.get_ref('a') + 10
    config.c = config.get_ref('b') + 10
    self.assertEqual(config.c, 21.0)

  def testFunction(self):
    # A plain function applied to a reference yields a lazy result that
    # tracks changes to the underlying field.
    def fn(x):
      return x + 5

    config = config_dict.ConfigDict()
    config.a = 1
    config.b = fn(config.get_ref('a'))
    config.c = fn(config.get_ref('b'))

    self.assertEqual(config.b, 6)
    self.assertEqual(config.c, 11)
    config.a = 2
    self.assertEqual(config.b, 7)
    self.assertEqual(config.c, 12)

  def testCycles(self):
    config = config_dict.ConfigDict()
    config.a = 1.
    config.b = config.get_ref('a') + 10
    config.c = config.get_ref('b') + 10

    self.assertEqual(config.b, 11.0)
    self.assertEqual(config.c, 21.0)

    # Introduce a cycle
    with self.assertRaisesRegex(config_dict.MutabilityError, 'cycle'):
      config.a = config.get_ref('c') - 1.0

    # Introduce a cycle on second operand
    with self.assertRaisesRegex(config_dict.MutabilityError, 'cycle'):
      config.a = config_dict.FieldReference(5.0) + config.get_ref('c')

    # We can create multiple FieldReferences that all point to the same object
    l = [0]
    config = config_dict.ConfigDict()
    config.a = l
    config.b = l
    config.c = config.get_ref('a') + ['c']
    config.d = config.get_ref('b') + ['d']

    self.assertEqual(config.c, [0, 'c'])
    self.assertEqual(config.d, [0, 'd'])

    # Make sure nothing was mutated
    self.assertEqual(l, [0])
    self.assertEqual(config.c, [0, 'c'])

    config.a = [1]
    config.b = [2]
    self.assertEqual(l, [0])
    self.assertEqual(config.c, [1, 'c'])
    self.assertEqual(config.d, [2, 'd'])

  @parameterized.parameters(
      {
          'initial_value': 1,
          'other_value': 2,
          'true_value': 3,
          'new_initial_value': 10,
          'new_true_value': 12
      }, {
          'initial_value': 2.0,
          'other_value': 2.5,
          'true_value': 4.5,
          'new_initial_value': 3.7,
          'new_true_value': 6.2
      }, {
          'initial_value': 'hello, ',
          'other_value': 'world!',
          'true_value': 'hello, world!',
          'new_initial_value': 'foo, ',
          'new_true_value': 'foo, world!'
      }, {
          'initial_value': ['hello'],
          'other_value': ['world'],
          'true_value': ['hello', 'world'],
          'new_initial_value': ['foo'],
          'new_true_value': ['foo', 'world']
      }, {
          'initial_value': config_dict.FieldReference(10),
          'other_value': config_dict.FieldReference(5.0),
          'true_value': 15.0,
          'new_initial_value': 12,
          'new_true_value': 17.0
      }, {
          'initial_value': config_dict.placeholder(float),
          'other_value': 7.0,
          'true_value': None,
          'new_initial_value': 12,
          'new_true_value': 19.0
      }, {
          'initial_value': 5.0,
          'other_value': config_dict.placeholder(float),
          'true_value': None,
          'new_initial_value': 8.0,
          'new_true_value': None
      }, {
          'initial_value': config_dict.placeholder(str),
          'other_value': 'tail',
          'true_value': None,
          'new_initial_value': 'head',
          'new_true_value': 'headtail'
      })
  def testAdd(self, initial_value, other_value, true_value, new_initial_value,
              new_true_value):
    self._test_binary_operator(initial_value, other_value, operator.add,
                               true_value, new_initial_value, new_true_value)

  @parameterized.parameters(
      {
          'initial_value': 5,
          'other_value': 3,
          'true_value': 2,
          'new_initial_value': -1,
          'new_true_value': -4
      }, {
          'initial_value': 2.0,
          'other_value': 2.5,
          'true_value': -0.5,
          'new_initial_value': 12.3,
          'new_true_value': 9.8
      }, {
          'initial_value': set(['hello', 123, 4.5]),
          'other_value': set([123]),
          'true_value': set(['hello', 4.5]),
          'new_initial_value': set([123]),
          'new_true_value': set([])
      }, {
          'initial_value': config_dict.FieldReference(10),
          'other_value': config_dict.FieldReference(5.0),
          'true_value': 5.0,
          'new_initial_value': 12,
          'new_true_value': 7.0
      }, {
          'initial_value': config_dict.placeholder(float),
          'other_value': 7.0,
          'true_value': None,
          'new_initial_value': 12,
          'new_true_value': 5.0
      })
  def testSub(self, initial_value, other_value, true_value, new_initial_value,
              new_true_value):
    self._test_binary_operator(initial_value, other_value, operator.sub,
                               true_value, new_initial_value, new_true_value)

  @parameterized.parameters(
      {
          'initial_value': 1,
          'other_value': 2,
          'true_value': 2,
          'new_initial_value': 3,
          'new_true_value': 6
      }, {
          'initial_value': 2.0,
          'other_value': 2.5,
          'true_value': 5.0,
          'new_initial_value': 3.5,
          'new_true_value': 8.75
      }, {
          'initial_value': ['hello'],
          'other_value': 3,
          'true_value': ['hello', 'hello', 'hello'],
          'new_initial_value': ['foo'],
          'new_true_value': ['foo', 'foo', 'foo']
      }, {
          'initial_value': config_dict.FieldReference(10),
          'other_value': config_dict.FieldReference(5.0),
          'true_value': 50.0,
          'new_initial_value': 1,
          'new_true_value': 5.0
      }, {
          'initial_value': config_dict.placeholder(float),
          'other_value': 7.0,
          'true_value': None,
          'new_initial_value': 12,
          'new_true_value': 84.0
      })
  def testMul(self, initial_value, other_value, true_value, new_initial_value,
              new_true_value):
    self._test_binary_operator(initial_value, other_value, operator.mul,
                               true_value, new_initial_value, new_true_value)

  @parameterized.parameters(
      {
          'initial_value': 3,
          'other_value': 2,
          'true_value': 1.5,
          'new_initial_value': 10,
          'new_true_value': 5.0
      }, {
          'initial_value': 2.0,
          'other_value': 2.5,
          'true_value': 0.8,
          'new_initial_value': 6.3,
          'new_true_value': 2.52
      }, {
          'initial_value': config_dict.FieldReference(10),
          'other_value': config_dict.FieldReference(5.0),
          'true_value': 2.0,
          'new_initial_value': 13,
          'new_true_value': 2.6
      }, {
          'initial_value': config_dict.placeholder(float),
          'other_value': 7.0,
          'true_value': None,
          'new_initial_value': 17.5,
          'new_true_value': 2.5
      })
  def testTrueDiv(self, initial_value, other_value, true_value,
                  new_initial_value, new_true_value):
    self._test_binary_operator(initial_value, other_value, operator.truediv,
                               true_value, new_initial_value, new_true_value)

  @parameterized.parameters(
      {
          'initial_value': 3,
          'other_value': 2,
          'true_value': 1,
          'new_initial_value': 7,
          'new_true_value': 3
      }, {
          'initial_value': config_dict.FieldReference(10),
          'other_value': config_dict.FieldReference(5),
          'true_value': 2,
          'new_initial_value': 28,
          'new_true_value': 5
      }, {
          'initial_value': config_dict.placeholder(int),
          'other_value': 7,
          'true_value': None,
          'new_initial_value': 25,
          'new_true_value': 3
      })
  def testFloorDiv(self, initial_value, other_value, true_value,
                   new_initial_value, new_true_value):
    self._test_binary_operator(initial_value, other_value, operator.floordiv,
                               true_value, new_initial_value, new_true_value)

  @parameterized.parameters(
      {
          'initial_value': 3,
          'other_value': 2,
          'true_value': 9,
          'new_initial_value': 10,
          'new_true_value': 100
      }, {
          'initial_value': 2.7,
          'other_value': 3.2,
          'true_value': 24.0084457245,
          'new_initial_value': 6.5,
          'new_true_value': 399.321543621
      }, {
          'initial_value': config_dict.FieldReference(10),
          'other_value': config_dict.FieldReference(5),
          'true_value': 1e5,
          'new_initial_value': 2,
          'new_true_value': 32
      }, {
          'initial_value': config_dict.placeholder(float),
          'other_value': 3.0,
          'true_value': None,
          'new_initial_value': 7.0,
          'new_true_value': 343.0
      })
  def testPow(self, initial_value, other_value, true_value, new_initial_value,
              new_true_value):
    # Float pow results are inexact; compare approximately.
    self._test_binary_operator(
        initial_value,
        other_value,
        operator.pow,
        true_value,
        new_initial_value,
        new_true_value,
        assert_fn=self.assertAlmostEqual)

  @parameterized.parameters(
      {
          'initial_value': 3,
          'other_value': 2,
          'true_value': 1,
          'new_initial_value': 10,
          'new_true_value': 0
      }, {
          'initial_value': 5.3,
          'other_value': 3.2,
          'true_value': 2.0999999999999996,
          'new_initial_value': 77,
          'new_true_value': 0.2
      }, {
          'initial_value': config_dict.FieldReference(10),
          'other_value': config_dict.FieldReference(5),
          'true_value': 0,
          'new_initial_value': 32,
          'new_true_value': 2
      }, {
          'initial_value': config_dict.placeholder(int),
          'other_value': 7,
          'true_value': None,
          'new_initial_value': 25,
          'new_true_value': 4
      })
  def testMod(self, initial_value, other_value, true_value, new_initial_value,
              new_true_value):
    # Float modulo results are inexact; compare approximately.
    self._test_binary_operator(
        initial_value,
        other_value,
        operator.mod,
        true_value,
        new_initial_value,
        new_true_value,
        assert_fn=self.assertAlmostEqual)

  @parameterized.parameters(
      {
          'initial_value': True,
          'other_value': True,
          'true_value': True,
          'new_initial_value': False,
          'new_true_value': False
      }, {
          'initial_value': config_dict.FieldReference(False),
          'other_value': config_dict.FieldReference(False),
          'true_value': False,
          'new_initial_value': True,
          'new_true_value': False
      }, {
          'initial_value': config_dict.placeholder(bool),
          'other_value': True,
          'true_value': None,
          'new_initial_value': False,
          'new_true_value': False
      })
  def testAnd(self, initial_value, other_value, true_value, new_initial_value,
              new_true_value):
    self._test_binary_operator(initial_value, other_value, operator.and_,
                               true_value, new_initial_value, new_true_value)

  @parameterized.parameters(
      {
          'initial_value': False,
          'other_value': False,
          'true_value': False,
          'new_initial_value': True,
          'new_true_value': True
      }, {
          'initial_value': config_dict.FieldReference(True),
          'other_value': config_dict.FieldReference(True),
          'true_value': True,
          'new_initial_value': False,
          'new_true_value': True
      }, {
          'initial_value': config_dict.placeholder(bool),
          'other_value': False,
          'true_value': None,
          'new_initial_value': True,
          'new_true_value': True
      })
  def testOr(self, initial_value, other_value, true_value, new_initial_value,
             new_true_value):
    self._test_binary_operator(initial_value, other_value, operator.or_,
                               true_value, new_initial_value, new_true_value)

  @parameterized.parameters(
      {
          'initial_value': False,
          'other_value': True,
          'true_value': True,
          'new_initial_value': True,
          'new_true_value': False
      }, {
          'initial_value': config_dict.FieldReference(True),
          'other_value': config_dict.FieldReference(True),
          'true_value': False,
          'new_initial_value': False,
          'new_true_value': True
      }, {
          'initial_value': config_dict.placeholder(bool),
          'other_value': True,
          'true_value': None,
          'new_initial_value': True,
          'new_true_value': False
      })
  def testXor(self, initial_value, other_value, true_value, new_initial_value,
              new_true_value):
    self._test_binary_operator(initial_value, other_value, operator.xor,
                               true_value, new_initial_value, new_true_value)

  @parameterized.parameters(
      {
          'initial_value': 3,
          'true_value': -3,
          'new_initial_value': -22,
          'new_true_value': 22
      }, {
          'initial_value': 15.3,
          'true_value': -15.3,
          'new_initial_value': -0.2,
          'new_true_value': 0.2
      }, {
          'initial_value': config_dict.FieldReference(7),
          'true_value': config_dict.FieldReference(-7),
          'new_initial_value': 123,
          'new_true_value': -123
      }, {
          'initial_value': config_dict.placeholder(int),
          'true_value': None,
          'new_initial_value': -6,
          'new_true_value': 6
      })
  def testNeg(self, initial_value, true_value, new_initial_value,
              new_true_value):
    self._test_unary_operator(initial_value, operator.neg, true_value,
                              new_initial_value, new_true_value)

  @parameterized.parameters(
      {
          'initial_value': config_dict.create(attribute=2),
          'true_value': 2,
          'new_initial_value': config_dict.create(attribute=3),
          'new_true_value': 3,
      },
      {
          'initial_value': config_dict.create(attribute={'a': 1}),
          'true_value': config_dict.create(a=1),
          'new_initial_value': config_dict.create(attribute={'b': 1}),
          'new_true_value': config_dict.create(b=1),
      },
      {
          'initial_value':
              config_dict.FieldReference(config_dict.create(attribute=2)),
          'true_value': config_dict.FieldReference(2),
          'new_initial_value': config_dict.create(attribute=3),
          'new_true_value': 3,
      },
      {
          'initial_value': config_dict.placeholder(config_dict.ConfigDict),
          'true_value': None,
          'new_initial_value': config_dict.create(attribute=3),
          'new_true_value': 3,
      },
  )
  def testAttr(self, initial_value, true_value, new_initial_value,
               new_true_value):
    self._test_unary_operator(initial_value, lambda x: x.attr('attribute'),
                              true_value, new_initial_value, new_true_value)

  @parameterized.parameters(
      {
          'initial_value': 3,
          'true_value': 3,
          'new_initial_value': -101,
          'new_true_value': 101
      }, {
          'initial_value': -15.3,
          'true_value': 15.3,
          'new_initial_value': 7.3,
          'new_true_value': 7.3
      }, {
          'initial_value': config_dict.FieldReference(-7),
          'true_value': config_dict.FieldReference(7),
          'new_initial_value': 3,
          'new_true_value': 3
      }, {
          'initial_value': config_dict.placeholder(float),
          'true_value': None,
          'new_initial_value': -6.25,
          'new_true_value': 6.25
      })
  def testAbs(self, initial_value, true_value, new_initial_value,
              new_true_value):
    self._test_unary_operator(initial_value, operator.abs, true_value,
                              new_initial_value, new_true_value)

  def testToInt(self):
    self._test_unary_operator(25.3, lambda ref: ref.to_int(), 25, 27.9, 27)

    # Conversion also retypes the reference itself.
    ref = config_dict.FieldReference(64.7)
    ref = ref.to_int()
    self.assertEqual(ref.get(), 64)
    self.assertEqual(ref._field_type, int)

  def testToFloat(self):
    self._test_unary_operator(12, lambda ref: ref.to_float(), 12.0, 0, 0.0)

    ref = config_dict.FieldReference(647)
    ref = ref.to_float()
    self.assertEqual(ref.get(), 647.0)
    self.assertEqual(ref._field_type, float)

  def testToString(self):
    self._test_unary_operator(12, lambda ref: ref.to_str(), '12', 0, '0')

    ref = config_dict.FieldReference(647)
    ref = ref.to_str()
    self.assertEqual(ref.get(), '647')
    self.assertEqual(ref._field_type, str)

  def testSetValue(self):
    ref = config_dict.FieldReference(1.0)
    other = config_dict.FieldReference(3)
    ref_plus_other = ref + other
    self.assertEqual(ref_plus_other.get(), 4.0)

    # Derived references see updates to either operand.
    ref.set(2.5)
    self.assertEqual(ref_plus_other.get(), 5.5)

    other.set(110)
    self.assertEqual(ref_plus_other.get(), 112.5)

    # Type checking
    with self.assertRaises(TypeError):
      other.set('this is a string')

    with self.assertRaises(TypeError):
      other.set(config_dict.FieldReference('this is a string'))

    with self.assertRaises(TypeError):
      other.set(config_dict.FieldReference(None, field_type=str))

  def testSetResult(self):
    ref = config_dict.FieldReference(1.0)
    result = ref + 1.0
    second_result = result + 1.0

    self.assertEqual(ref.get(), 1.0)
    self.assertEqual(result.get(), 2.0)
    self.assertEqual(second_result.get(), 3.0)

    ref.set(2.0)
    self.assertEqual(ref.get(), 2.0)
    self.assertEqual(result.get(), 3.0)
    self.assertEqual(second_result.get(), 4.0)

    # Setting an intermediate result severs its dependency on `ref`.
    result.set(4.0)
    self.assertEqual(ref.get(), 2.0)
    self.assertEqual(result.get(), 4.0)
    self.assertEqual(second_result.get(), 5.0)

    # All references are broken at this point.
    ref.set(1.0)
    self.assertEqual(ref.get(), 1.0)
    self.assertEqual(result.get(), 4.0)
    self.assertEqual(second_result.get(), 5.0)

  def testTypeChecking(self):
    ref = config_dict.FieldReference(1)
    string_ref = config_dict.FieldReference('a')

    # The type error surfaces lazily, at .get() time.
    x = ref + string_ref
    with self.assertRaises(TypeError):
      x.get()

  def testNoType(self):
    self.assertRaisesRegex(TypeError, 'field_type should be a type.*',
                           config_dict.FieldReference, None, 0)

  def testEqual(self):
    # Simple case
    ref1 = config_dict.FieldReference(1)
    ref2 = config_dict.FieldReference(1)
    ref3 = config_dict.FieldReference(2)
    self.assertEqual(ref1, 1)
    self.assertEqual(ref1, ref1)
    self.assertEqual(ref1, ref2)
    self.assertNotEqual(ref1, 2)
    self.assertNotEqual(ref1, ref3)

    # ConfigDict inside FieldReference
    ref1 = config_dict.FieldReference(config_dict.ConfigDict({'a': 1}))
    ref2 = config_dict.FieldReference(config_dict.ConfigDict({'a': 1}))
    ref3 = config_dict.FieldReference(config_dict.ConfigDict({'a': 2}))
    self.assertEqual(ref1, config_dict.ConfigDict({'a': 1}))
    self.assertEqual(ref1, ref1)
    self.assertEqual(ref1, ref2)
    self.assertNotEqual(ref1, config_dict.ConfigDict({'a': 2}))
    self.assertNotEqual(ref1, ref3)

  def testLessEqual(self):
    # Simple case
    ref1 = config_dict.FieldReference(1)
    ref2 = config_dict.FieldReference(1)
    ref3 = config_dict.FieldReference(2)
    self.assertLessEqual(ref1, 1)
    self.assertLessEqual(ref1, 2)
    self.assertLessEqual(0, ref1)
    self.assertLessEqual(1, ref1)
    self.assertGreater(ref1, 0)
    self.assertLessEqual(ref1, ref1)
    self.assertLessEqual(ref1, ref2)
    self.assertLessEqual(ref1, ref3)
    self.assertGreater(ref3, ref1)

  def testControlFlowError(self):
    # References deliberately refuse to act as plain booleans, since that
    # would silently freeze the value at trace time.
    ref1 = config_dict.FieldReference(True)
    ref2 = config_dict.FieldReference(False)

    with self.assertRaises(NotImplementedError):
      if ref1:
        pass
    with self.assertRaises(NotImplementedError):
      _ = ref1 and ref2
    with self.assertRaises(NotImplementedError):
      _ = ref1 or ref2
    with self.assertRaises(NotImplementedError):
      _ = not ref1
def get_config(debug: bool = False) -> config_dict.ConfigDict:
  """Get Jaxline experiment config.

  Args:
    debug: If True, shrink the network and batch sizes for fast local runs.

  Returns:
    A locked-down Jaxline ConfigDict (built on base_config) for the
    regression experiment.
  """
  config = base_config.get_base_config()

  # E.g. '/data/pretrained_models/k0_seed100' (and set k_fold_split_id=0,
  # below)
  config.restore_path = config_dict.placeholder(str)

  training_batch_size = 64
  eval_batch_size = 64

  ## Experiment config.
  loss_config_name = 'RegressionLossConfig'
  loss_kwargs = dict(
      exponent=1.,  # 2 for l2 loss, 1 for l1 loss, etc...
  )

  dataset_config = dict(
      data_root=config_dict.placeholder(str),
      augment_with_random_mirror_symmetry=True,
      k_fold_split_id=config_dict.placeholder(int),
      num_k_fold_splits=config_dict.placeholder(int),
      # Options: "in" or "out".
      # Filter=in would keep the samples with nans in the conformer features.
      # Filter=out would keep the samples with no NaNs anywhere in the
      # conformer features.
      filter_in_or_out_samples_with_nans_in_conformers=(
          config_dict.placeholder(str)),
      cached_conformers_file=config_dict.placeholder(str))

  model_config = dict(
      mlp_hidden_size=512,
      mlp_layers=2,
      latent_size=512,
      use_layer_norm=False,
      num_message_passing_steps=32,
      shared_message_passing_weights=False,
      mask_padding_graph_at_every_step=True,
      loss_config_name=loss_config_name,
      loss_kwargs=loss_kwargs,
      processor_mode='resnet',
      global_reducer='sum',
      node_reducer='sum',
      dropedge_rate=0.1,
      dropnode_rate=0.1,
      aux_multiplier=0.1,
      add_relative_distance=True,
      add_relative_displacement=True,
      add_absolute_positions=False,
      position_normalization=2.,
      relative_displacement_normalization=1.,
      ignore_globals=False,
      ignore_globals_from_final_layer_for_predictions=True,
  )

  if debug:
    # Make network smaller.
    model_config.update(dict(
        mlp_hidden_size=32,
        mlp_layers=1,
        latent_size=32,
        num_message_passing_steps=1))

  config.experiment_kwargs = config_dict.ConfigDict(
      dict(
          config=dict(
              debug=debug,
              predictions_dir=config_dict.placeholder(str),
              # Exponential moving average of model parameters.
              ema=True,
              ema_decay=0.9999,
              sample_random=0.05,
              optimizer=dict(
                  name='adam',
                  optimizer_kwargs=dict(b1=.9, b2=.95),
                  lr_schedule=dict(
                      warmup_steps=int(5e4),
                      decay_steps=int(5e5),
                      init_value=1e-5,
                      peak_value=1e-4,
                      end_value=0.,
                  ),
              ),
              model=model_config,
              dataset_config=dataset_config,
              # As a rule of thumb, use the following statistics:
              # Avg. # nodes in graph: 16.
              # Avg. # edges in graph: 40.
              training=dict(
                  dynamic_batch_size={
                      'n_node': 256 if debug else 16 * training_batch_size,
                      'n_edge': 512 if debug else 40 * training_batch_size,
                      'n_graph': 2 if debug else training_batch_size,
                  },),
              evaluation=dict(
                  split='valid',
                  dynamic_batch_size=dict(
                      n_node=256 if debug else 16 * eval_batch_size,
                      n_edge=512 if debug else 40 * eval_batch_size,
                      n_graph=2 if debug else eval_batch_size,
                  )))))

  ## Training loop config.
  config.training_steps = int(5e6)
  config.checkpoint_dir = '/tmp/checkpoint/pcq/'
  config.train_checkpoint_all_hosts = False
  config.save_checkpoint_interval = 300
  config.log_train_data_interval = 60
  config.log_tensors_interval = 60
  config.best_model_eval_metric = 'mae'
  config.best_model_eval_metric_higher_is_better = False

  return config
from sklearn.metrics import mean_squared_error

from data import load_train_test_splits
from definitions import ARTIFACT_DIR
from model_dispatcher import load_model
from utils import set_seed

# Configure experiment runner
FLAGS = flags.FLAGS
flags.DEFINE_bool('debug', False, "Show debugging information.")
flags.DEFINE_bool('log', False, "Log this experiment to wandb.")

# Configure experiment tracking
config_wandb = ml_collections.ConfigDict()
config_wandb.project = "hparam-src"
config_wandb.job_type = placeholder(str)
config_wandb.notes = placeholder(str)
config_flags.DEFINE_config_dict(
    'wandb',
    config_wandb,
    "Configuration for W&B experiment tracking.",
)


def main(_):
  """Script entry point: optionally start W&B logging, then run the pipeline.

  NOTE(review): this function appears to end at the chunk boundary right after
  `set_seed()` — confirm against the full file that no further pipeline steps
  were cut off.
  """
  if FLAGS.log:
    # Record both absl flags and the wandb sub-config in the W&B run.
    wandb.init(config=FLAGS, **FLAGS.wandb)

  # Pipeline
  ## Setup
  set_seed()
def get_config():
  """Build a minimal config with one defaulted reference and one placeholder.

  Returns:
    A ConfigDict where `ref` is a FieldReference with default 123 and
    `ref_nodefault` is an int placeholder (initially None).
  """
  config = config_dict.ConfigDict()
  # A reference carrying a concrete default value.
  config.ref = config_dict.FieldReference(123)
  # A typed placeholder: holds None until the caller fills it in.
  config.ref_nodefault = config_dict.placeholder(int)
  return config
def get_config(debug: bool = False) -> config_dict.ConfigDict:
  """Get Jaxline experiment config.

  Args:
    debug: If True, shrink the model, batches and subsampled graphs for fast
      local runs (and use dummy adjacencies).

  Returns:
    A Jaxline ConfigDict (built on base_config) for the MAG node
    classification experiment.
  """
  config = base_config.get_base_config()
  config.random_seed = 42
  # E.g. '/data/pretrained_models/k0_seed100' (and set k_fold_split_id=0,
  # below)
  config.restore_path = config_dict.placeholder(str)
  config.experiment_kwargs = config_dict.ConfigDict(
      dict(config=dict(
          debug=debug,
          predictions_dir=config_dict.placeholder(str),
          # 5 for model selection and early stopping, 50 for final eval.
          num_eval_iterations_to_ensemble=5,
          dataset_kwargs=dict(
              data_root='/data/',
              online_subsampling_kwargs=dict(
                  # Per-hop, per-edge-type neighbour caps used when
                  # subsampling the graph online.
                  max_nb_neighbours_per_type=[
                      [[40, 20, 0, 40], [0, 0, 0, 0], [0, 0, 0, 0]],
                      [[40, 20, 0, 40], [40, 0, 10, 0], [0, 0, 0, 0]],
                  ],
                  remove_future_nodes=True,
                  deduplicate_nodes=True,
              ),
              ratio_unlabeled_data_to_labeled_data=10.0,
              k_fold_split_id=config_dict.placeholder(int),
              use_all_labels_when_not_training=False,
              use_dummy_adjacencies=debug,
          ),
          optimizer=dict(
              name='adamw',
              kwargs=dict(weight_decay=1e-5, b1=0.9, b2=0.999),
              learning_rate_schedule=dict(
                  use_schedule=True,
                  base_learning_rate=1e-2,
                  warmup_steps=50000,
                  # Reference so the schedule tracks config.training_steps.
                  total_steps=config.get_ref('training_steps'),
              ),
          ),
          model_config=dict(
              mlp_hidden_sizes=[32] if debug else [512],
              latent_size=32 if debug else 256,
              num_message_passing_steps=2 if debug else 4,
              activation='relu',
              dropout_rate=0.3,
              dropedge_rate=0.25,
              disable_edge_updates=True,
              use_sent_edges=True,
              normalization_type='layer_norm',
              aggregation_function='sum',
          ),
          training=dict(
              loss_config=dict(bgrl_loss_config=dict(
                  stop_gradient_for_supervised_loss=False,
                  bgrl_loss_scale=1.0,
                  symmetrize=True,
                  first_graph_corruption_config=dict(
                      feature_drop_prob=0.4,
                      edge_drop_prob=0.2,
                  ),
                  second_graph_corruption_config=dict(
                      feature_drop_prob=0.4,
                      edge_drop_prob=0.2,
                  ),
              ),),
              # GPU memory may require reducing the `256`s below to `48`.
              dynamic_batch_size_config=dict(
                  n_node=256 if debug else 340 * 256,
                  n_edge=512 if debug else 720 * 256,
                  n_graph=4 if debug else 256,
              ),
          ),
          eval=dict(
              split='valid',
              ema_annealing_schedule=dict(
                  use_schedule=True,
                  base_rate=0.999,
                  total_steps=config.get_ref('training_steps')),
              dynamic_batch_size_config=dict(
                  n_node=256 if debug else 340 * 128,
                  n_edge=512 if debug else 720 * 128,
                  n_graph=4 if debug else 128,
              ),
          ))))

  ## Training loop config.
  config.training_steps = 500000
  config.checkpoint_dir = '/tmp/checkpoint/mag/'
  config.train_checkpoint_all_hosts = False
  config.log_train_data_interval = 10
  config.log_tensors_interval = 10
  config.save_checkpoint_interval = 30
  config.best_model_eval_metric = 'accuracy'
  config.best_model_eval_metric_higher_is_better = True

  return config