def testBasicEquality(self):
    """Checks FrozenConfigDict equality across construction paths.

    A FrozenConfigDict built directly from the test dict must compare equal
    both to one built from the equivalent ConfigDict and to one re-frozen
    from an existing FrozenConfigDict.
    """
    reference = _test_frozenconfigdict()
    from_configdict = config_dict.FrozenConfigDict(_test_configdict())
    from_frozen = config_dict.FrozenConfigDict(reference)
    for other in (from_configdict, from_frozen):
        self.assertEqual(reference, other)
def testInitInvalidAttributeName(self):
    """Checks that construction rejects keys that cannot be attribute names.

    A key containing a dot must raise ValueError, and a key named
    ``__hash__`` must raise AttributeError.
    """
    bad_inputs = (
        ({'dot.name': None}, ValueError),
        ({'__hash__': None}, AttributeError),
    )
    for bad_dict, expected_error in bad_inputs:
        with self.assertRaises(expected_error):
            config_dict.FrozenConfigDict(bad_dict)
def testHash(self):
    """Checks that __hash__() is sensitive to hidden mutability.

    Frozen dicts built from equal inputs must hash equally. Turning the
    'list' field into a tuple before freezing must change the hash. The
    type must also register as collections.abc.Hashable.
    """
    tupled = _test_dict_deepcopy()
    tupled['list'] = tuple(tupled['list'])

    same_input_hash = hash(
        config_dict.FrozenConfigDict(_test_dict_deepcopy()))
    self.assertEqual(hash(_test_frozenconfigdict()), same_input_hash)

    tupled_hash = hash(config_dict.FrozenConfigDict(tupled))
    self.assertNotEqual(hash(_test_frozenconfigdict()), tupled_hash)

    # Python itself must recognise FrozenConfigDict as hashable.
    self.assertIsInstance(_test_frozenconfigdict(), collections_abc.Hashable)
def testToDict(self):
    """Checks that to_dict() ignores hidden mutability (list vs. tuple)."""
    tupled = _test_dict_deepcopy()
    tupled['list'] = tuple(tupled['list'])
    expected = _test_frozenconfigdict().to_dict()
    actual = config_dict.FrozenConfigDict(tupled).to_dict()
    self.assertEqual(expected, actual)
def testFieldReferenceResolved(self):
    """Checks that freezing replaces a FieldReference with its value."""
    source = config_dict.ConfigDict({'fr': config_dict.FieldReference(1)})
    frozen = config_dict.FrozenConfigDict(source)
    stored_value = frozen._fields['fr']
    self.assertNotIsInstance(stored_value, config_dict.FieldReference)
    # With the FieldReference resolved, the frozen dict is hashable.
    hash(frozen)
def testPickle(self):
    """Checks that FrozenConfigDict survives a pickle round trip.

    Both an ordinary frozen dict and one frozen from a locked ConfigDict
    must compare equal to their unpickled copies.
    """
    candidates = (
        _test_frozenconfigdict(),
        config_dict.FrozenConfigDict(_test_configdict().lock()),
    )
    for original in candidates:
        restored = pickle.loads(pickle.dumps(original))
        self.assertEqual(original, restored)
def testFieldReferenceCycle(self):
    """Checks that FieldReferences forming a reference cycle are rejected."""
    # References to plain, acyclic values are accepted.
    frozenset_fr = {'frozenset': frozenset({1, 2})}
    frozenset_fr['fr'] = config_dict.FieldReference(frozenset_fr['frozenset'])
    list_fr = {'list': [1, 2]}
    list_fr['fr'] = config_dict.FieldReference(list_fr['list'])
    for acyclic in (frozenset_fr, list_fr):
        _ = config_dict.FrozenConfigDict(acyclic)

    # References pointing back at the dict itself, or at an enclosing
    # parent dict, form cycles and must raise ValueError.
    self_cycle = {'a': 1}
    self_cycle['fr'] = config_dict.FieldReference(self_cycle)
    parent_cycle = {'dict': {}}
    parent_cycle['dict']['fr'] = config_dict.FieldReference(parent_cycle)
    self.assertFrozenRaisesValueError([self_cycle, parent_cycle])
def testInitConfigDict(self):
    """Checks that ConfigDict construction unfreezes FrozenConfigDict values.

    A dict whose 'dict' entry has been replaced by a FrozenConfigDict must
    produce the same ConfigDict, and the same FrozenConfigDict, as the plain
    nested dict does.
    """
    plain = _test_dict_deepcopy()
    plain.pop('ref')
    with_frozen_node = copy.deepcopy(plain)
    with_frozen_node['dict'] = config_dict.FrozenConfigDict(
        with_frozen_node['dict'])

    for constructor in (config_dict.ConfigDict, config_dict.FrozenConfigDict):
        without_node = constructor(plain)
        with_node = constructor(with_frozen_node)
        self.assertEqual(without_node, with_node)
def main(_):
    """Demonstrates FrozenConfigDict field types, immutability and equality.

    Prints the runtime types of the same fields on a ConfigDict, on its
    FrozenConfigDict, and on a ConfigDict rebuilt from the frozen copy. It
    then shows that assigning a new field raises AttributeError, and
    contrasts ``==`` with ``eq_as_configdict()``.

    Args:
      _: Unused positional argument (presumably argv from the app launcher).
    """

    def print_field_types(title, cfg_like):
        # Print the types of the same four fields for any config-like object.
        print(title)
        print('list: ', type(cfg_like.list))
        print('set: ', type(cfg_like.set))
        print('nested_list: ', type(cfg_like.dict.nested_list))
        print('nested_tuple[0]: ', type(cfg_like.dict.nested_tuple[0]))

    print_section('Attribute Types.')
    cfg = config_dict.ConfigDict()
    cfg.int = 1
    cfg.list = [1, 2, 3]
    cfg.tuple = (1, 2, 3)
    cfg.set = {1, 2, 3}
    cfg.frozenset = frozenset({1, 2, 3})
    cfg.dict = {
        'nested_int': 4,
        'nested_list': [4, 5, 6],
        'nested_tuple': ([4], 5, 6),
    }

    # Expected: list, set, list, list.
    print_field_types('Types of cfg fields:', cfg)

    frozen_cfg = config_dict.FrozenConfigDict(cfg)
    # Expected: tuple, frozenset, tuple, tuple.
    print_field_types('\nTypes of FrozenConfigDict(cfg) fields:', frozen_cfg)

    cfg_from_frozen = config_dict.ConfigDict(frozen_cfg)
    # Expected: list, set, list, list.
    print_field_types('\nTypes of ConfigDict(FrozenConfigDict(cfg)) fields:',
                      cfg_from_frozen)

    print(
        '\nCan use FrozenConfigDict.as_configdict() to convert to ConfigDict:')
    print(cfg_from_frozen == frozen_cfg.as_configdict())  # True

    print_section('Immutability.')
    try:
        frozen_cfg.new_field = 1  # Raises AttributeError because of immutability.
    except AttributeError as e:
        print(e)

    print_section('"==" and eq_as_configdict().')
    # FrozenConfigDict.__eq__() is not type-invariant with respect to ConfigDict.
    print(frozen_cfg == cfg)  # False
    # FrozenConfigDict.eq_as_configdict() is type-invariant with respect to
    # ConfigDict.
    print(frozen_cfg.eq_as_configdict(cfg))  # True
    # .eq_as_configdict() is also a method of ConfigDict.
    print(cfg.eq_as_configdict(frozen_cfg))  # True
def testEqualsAsConfigDict(self):
    """Checks eq_as_configdict(): hidden mutability matters, type does not."""
    fcd = _test_frozenconfigdict()
    # True for an equal ConfigDict, False for unrelated types.
    self.assertFalse(fcd.eq_as_configdict([1, 2]))
    self.assertTrue(fcd.eq_as_configdict(fcd.as_configdict()))
    empty_frozen = config_dict.FrozenConfigDict()
    self.assertTrue(empty_frozen.eq_as_configdict(config_dict.ConfigDict()))

    def freeze_with(key, convert):
        # Freeze a copy of the test dict with one field converted.
        raw = _test_dict_deepcopy()
        raw[key] = convert(raw[key])
        return config_dict.FrozenConfigDict(raw)

    # Same hidden-mutability detection as __eq__().
    tuple_variant = freeze_with('list', tuple)
    frozenset_variant = freeze_with('set', frozenset)
    self.assertFalse(fcd.eq_as_configdict(tuple_variant))
    # set == frozenset in Python, so this variant still matches.
    self.assertTrue(fcd.eq_as_configdict(frozenset_variant))
def testEquals(self):
    """Checks that __eq__() is sensitive to hidden mutability.

    Converting the 'list' field to a tuple before freezing makes the result
    unequal. Converting the 'set' field to a frozenset does not, because
    Python treats equal sets and frozensets as equal. items() is unaffected
    by hidden mutability.
    """
    fcd = _test_frozenconfigdict()
    # Other types never compare equal, not even the ConfigDict view.
    self.assertNotEqual(fcd, (1, 2))
    self.assertNotEqual(fcd, fcd.as_configdict())

    def freeze_with(key, convert):
        # Freeze a copy of the test dict with one field converted.
        raw = _test_dict_deepcopy()
        raw[key] = convert(raw[key])
        return config_dict.FrozenConfigDict(raw)

    tuple_variant = freeze_with('list', tuple)
    frozenset_variant = freeze_with('set', frozenset)

    self.assertNotEqual(fcd, tuple_variant)
    # set == frozenset in Python, so this variant is still equal.
    self.assertEqual(fcd, frozenset_variant)

    # items() is not affected by hidden mutability.
    for variant in (tuple_variant, frozenset_variant):
        self.assertCountEqual(fcd.items(), variant.items())
def testAsConfigDict(self):
    """Checks that FrozenConfigDict converts back to ConfigDict faithfully.

    A round trip ConfigDict -> FrozenConfigDict -> ConfigDict must preserve
    contents as well as the type_safe and lock settings. It also covers the
    empty case.
    """
    # An empty FrozenConfigDict converts to an empty ConfigDict.
    self.assertEqual(
        config_dict.ConfigDict(config_dict.FrozenConfigDict()),
        config_dict.ConfigDict())

    def round_trip(source):
        return config_dict.ConfigDict(config_dict.FrozenConfigDict(source))

    cd = _test_configdict()
    self.assertEqual(cd, round_trip(cd))

    # The locked state must survive the round trip.
    cd.lock()
    self.assertEqual(cd, round_trip(cd))

    # type_safe=False must survive the round trip.
    unsafe_cd = config_dict.ConfigDict(_TEST_DICT, type_safe=False)
    self.assertEqual(unsafe_cd, round_trip(unsafe_cd))
def main(_):
    """Trains a SAC policy on `FLAGS.env_name` as configured by `FLAGS.config`.

    Sets up the experiment directory, compute device, RNG seeds, training and
    evaluation environments, policy, replay buffer, checkpoint manager and
    logger, then runs the training loop. A checkpoint is saved every
    `config.checkpoint_frequency` steps. Another is saved on exit (normal
    completion, KeyboardInterrupt or error) if at least one step was started.

    Args:
      _: Unused positional argument (presumably argv from the app launcher).
    """
    # Make sure we have a valid config that inherits all the keys defined in
    # the base config.
    validate_config(FLAGS.config, mode="rl")

    config = FLAGS.config
    exp_dir = osp.join(
        config.save_dir,
        FLAGS.experiment_name,
        str(FLAGS.seed),
    )
    utils.setup_experiment(exp_dir, config, FLAGS.resume)

    # Setup compute device.
    if torch.cuda.is_available():
        device = torch.device(FLAGS.device)
    else:
        logging.info("No GPU device found. Falling back to CPU.")
        device = torch.device("cpu")
    logging.info("Using device: %s", device)

    # Set RNG seeds.
    if FLAGS.seed is not None:
        logging.info("RL experiment seed: %d", FLAGS.seed)
        experiment.seed_rngs(FLAGS.seed)
        experiment.set_cudnn(config.cudnn_deterministic, config.cudnn_benchmark)
    else:
        logging.info("No RNG seed has been set for this RL experiment.")

    # The eval env is offset from the train seed. FLAGS.seed may be None
    # (handled above), and `None + 42` would raise TypeError. In that case,
    # pass None through, just as the training env below receives it.
    eval_seed = None if FLAGS.seed is None else FLAGS.seed + 42

    # Load env.
    env = utils.make_env(
        FLAGS.env_name,
        FLAGS.seed,
        action_repeat=config.action_repeat,
        frame_stack=config.frame_stack,
    )
    eval_env = utils.make_env(
        FLAGS.env_name,
        eval_seed,
        action_repeat=config.action_repeat,
        frame_stack=config.frame_stack,
        save_dir=osp.join(exp_dir, "video", "eval"),
    )

    # Dynamically set observation and action space values.
    config.sac.obs_dim = env.observation_space.shape[0]
    config.sac.action_dim = env.action_space.shape[0]
    config.sac.action_range = [
        float(env.action_space.low.min()),
        float(env.action_space.high.max()),
    ]

    # Resave the config since the dynamic values have been updated at this
    # point and make it immutable for safety :)
    utils.dump_config(exp_dir, config)
    config = config_dict.FrozenConfigDict(config)

    policy = agent.SAC(device, config.sac)

    buffer = utils.make_buffer(env, device, config)

    # Create checkpoint manager.
    checkpoint_dir = osp.join(exp_dir, "checkpoints")
    checkpoint_manager = CheckpointManager(
        checkpoint_dir,
        policy=policy,
        **policy.optim_dict(),
    )

    logger = Logger(osp.join(exp_dir, "tb"), FLAGS.resume)

    # Last step index started by the loop. It stays None if the loop never
    # runs (restore fails, or start >= num_train_steps). In that case there
    # is nothing to save in `finally`. Previously this raised NameError and
    # masked the real error.
    i = None
    try:
        start = checkpoint_manager.restore_or_initialize()
        observation, done = env.reset(), False
        for i in tqdm(range(start, config.num_train_steps), initial=start):
            if i < config.num_seed_steps:
                action = env.action_space.sample()
            else:
                policy.eval()
                action = policy.act(observation, sample=True)
            next_observation, reward, done, info = env.step(action)

            # Episodes ended only by a time limit ("TimeLimit.truncated") keep
            # mask=1.0, i.e. they are not stored as terminal transitions.
            mask = 1.0 if not done or "TimeLimit.truncated" in info else 0.0

            if not config.reward_wrapper.pretrained_path:
                buffer.insert(observation, action, reward, next_observation,
                              mask)
            else:
                buffer.insert(
                    observation,
                    action,
                    reward,
                    next_observation,
                    mask,
                    env.render(mode="rgb_array"),
                )
            observation = next_observation

            if done:
                observation, done = env.reset(), False
                for k, v in info["episode"].items():
                    logger.log_scalar(v, info["total"]["timesteps"], k,
                                      "training")

            if i >= config.num_seed_steps:
                policy.train()
                train_info = policy.update(buffer, i)

                # `train_info` only exists once updates have started, so the
                # periodic training log must live inside this branch.
                if (i + 1) % config.log_frequency == 0:
                    for k, v in train_info.items():
                        logger.log_scalar(v, info["total"]["timesteps"], k,
                                          "training")
                    logger.flush()

            if (i + 1) % config.eval_frequency == 0:
                eval_stats = evaluate(policy, eval_env,
                                      config.num_eval_episodes)
                for k, v in eval_stats.items():
                    logger.log_scalar(
                        v,
                        info["total"]["timesteps"],
                        f"average_{k}s",
                        "evaluation",
                    )
                logger.flush()

            if (i + 1) % config.checkpoint_frequency == 0:
                checkpoint_manager.save(i)

    except KeyboardInterrupt:
        print("Caught keyboard interrupt. Saving before quitting.")

    finally:
        # Close the logger even if the final save itself fails.
        try:
            if i is not None:
                checkpoint_manager.save(i)
        finally:
            logger.close()
def assertFrozenRaisesValueError(self, input_list):
    """Asserts that freezing every element of `input_list` raises ValueError.

    Args:
      input_list: Iterable of candidate initializers for FrozenConfigDict.
    """
    for candidate in input_list:
        self.assertRaises(ValueError, config_dict.FrozenConfigDict, candidate)
def _test_frozenconfigdict():
    """Returns a new FrozenConfigDict built from the module-level _TEST_DICT."""
    frozen = config_dict.FrozenConfigDict(_TEST_DICT)
    return frozen
def testUnhashableType(self):
    """Checks that hashing fails when a stored value is itself unhashable."""
    unhashable_fcd = config_dict.FrozenConfigDict({'unhashable': bytearray()})
    self.assertRaises(TypeError, hash, unhashable_fcd)