def testBasicEquality(self):
    """Checks FrozenConfigDict equality across initialization sources.

    The fixture FrozenConfigDict must equal one frozen from the ConfigDict
    fixture and one re-frozen from an existing FrozenConfigDict.
    """
    reference = _test_frozenconfigdict()
    from_configdict = config_dict.FrozenConfigDict(_test_configdict())
    from_frozen = config_dict.FrozenConfigDict(reference)
    for other in (from_configdict, from_frozen):
        self.assertEqual(reference, other)
    def testInitInvalidAttributeName(self):
        """Initialization must reject keys that are not usable attribute names."""
        # A dotted key is rejected with ValueError...
        with self.assertRaises(ValueError):
            config_dict.FrozenConfigDict({'dot.name': None})

        # ...while a key naming the protected `__hash__` attribute raises
        # AttributeError.
        with self.assertRaises(AttributeError):
            config_dict.FrozenConfigDict({'__hash__': None})
    def testHash(self):
        """__hash__() must tell a list-valued field apart from a tuple one."""
        tupled = _test_dict_deepcopy()
        tupled['list'] = tuple(tupled['list'])

        baseline_hash = hash(_test_frozenconfigdict())
        self.assertEqual(
            baseline_hash,
            hash(config_dict.FrozenConfigDict(_test_dict_deepcopy())))
        self.assertNotEqual(baseline_hash,
                            hash(config_dict.FrozenConfigDict(tupled)))

        # The class should also register as Hashable with the ABC machinery.
        self.assertIsInstance(_test_frozenconfigdict(),
                              collections_abc.Hashable)
    def testToDict(self):
        """to_dict() output must not depend on list-vs-tuple hidden mutability."""
        with_tuple = _test_dict_deepcopy()
        with_tuple['list'] = tuple(with_tuple['list'])

        expected = _test_frozenconfigdict().to_dict()
        actual = config_dict.FrozenConfigDict(with_tuple).to_dict()
        self.assertEqual(expected, actual)
 def testFieldReferenceResolved(self):
     """Freezing a ConfigDict must not keep FieldReference objects around."""
     frozen_cfg = config_dict.FrozenConfigDict(
         config_dict.ConfigDict({'fr': config_dict.FieldReference(1)}))
     self.assertNotIsInstance(frozen_cfg._fields['fr'],
                              config_dict.FieldReference)
     # Hashing must succeed once the reference has been resolved.
     hash(frozen_cfg)
    def testPickle(self):
        """A FrozenConfigDict, locked or not, must survive a pickle round trip."""
        def round_trip(obj):
            return pickle.loads(pickle.dumps(obj))

        fcd = _test_frozenconfigdict()
        locked_fcd = config_dict.FrozenConfigDict(_test_configdict().lock())

        # Round-trip everything first, then compare each pair.
        pairs = [(obj, round_trip(obj)) for obj in (fcd, locked_fcd)]
        for before, after in pairs:
            self.assertEqual(before, after)
    def testFieldReferenceCycle(self):
        """FieldReferences may point at plain values but not form cycles."""
        # Acyclic: each reference targets a sibling value in the same dict.
        frozenset_fr = {'frozenset': frozenset({1, 2})}
        frozenset_fr['fr'] = config_dict.FieldReference(
            frozenset_fr['frozenset'])
        list_fr = {'list': [1, 2]}
        list_fr['fr'] = config_dict.FieldReference(list_fr['list'])

        # Cyclic: the reference targets its own container, either directly
        # or through the enclosing parent dict.
        self_cycle = {'a': 1}
        self_cycle['fr'] = config_dict.FieldReference(self_cycle)
        parent_cycle = {'dict': {}}
        parent_cycle['dict']['fr'] = config_dict.FieldReference(parent_cycle)

        for acyclic in (frozenset_fr, list_fr):
            config_dict.FrozenConfigDict(acyclic)  # must not raise
        self.assertFrozenRaisesValueError([self_cycle, parent_cycle])
    def testInitConfigDict(self):
        """Tests that ConfigDict initialization handles FrozenConfigDict.

        Initializing a ConfigDict on a dictionary with FrozenConfigDict values
        should unfreeze these values.
        """
        plain = _test_dict_deepcopy()
        plain.pop('ref')
        with_frozen_node = copy.deepcopy(plain)
        with_frozen_node['dict'] = config_dict.FrozenConfigDict(
            with_frozen_node['dict'])

        # Build all four containers up front, then compare: a FrozenConfigDict
        # node must be indistinguishable from the nested dict it came from.
        cd_plain = config_dict.ConfigDict(plain)
        cd_frozen_node = config_dict.ConfigDict(with_frozen_node)
        fcd_plain = config_dict.FrozenConfigDict(plain)
        fcd_frozen_node = config_dict.FrozenConfigDict(with_frozen_node)

        self.assertEqual(cd_plain, cd_frozen_node)
        self.assertEqual(fcd_plain, fcd_frozen_node)
Example #9
0
def main(_):
    """Demonstrates how ConfigDict and FrozenConfigDict differ.

    Prints the container types seen through a ConfigDict, a FrozenConfigDict
    built from it, and a ConfigDict rebuilt from that FrozenConfigDict. Then
    shows that assignment on the frozen config fails and compares `==` with
    eq_as_configdict().
    """
    print_section('Attribute Types.')
    cfg = config_dict.ConfigDict()
    cfg.int = 1
    cfg.list = [1, 2, 3]
    cfg.tuple = (1, 2, 3)
    cfg.set = {1, 2, 3}
    cfg.frozenset = frozenset({1, 2, 3})
    cfg.dict = {
        'nested_int': 4,
        'nested_list': [4, 5, 6],
        'nested_tuple': ([4], 5, 6),
    }

    # (label, accessor) pairs for the fields whose container type changes
    # when a config is frozen and later unfrozen.
    probes = (
        ('list: ', lambda c: c.list),
        ('set: ', lambda c: c.set),
        ('nested_list: ', lambda c: c.dict.nested_list),
        ('nested_tuple[0]: ', lambda c: c.dict.nested_tuple[0]),
    )

    def show_types(header, some_cfg):
        print(header)
        for label, accessor in probes:
            print(label, type(accessor(some_cfg)))

    # Expected: list, set, list, list.
    show_types('Types of cfg fields:', cfg)

    frozen_cfg = config_dict.FrozenConfigDict(cfg)
    # Expected: tuple, frozenset, tuple, tuple.
    show_types('\nTypes of FrozenConfigDict(cfg) fields:', frozen_cfg)

    cfg_from_frozen = config_dict.ConfigDict(frozen_cfg)
    # Expected: list, set, list, list.
    show_types('\nTypes of ConfigDict(FrozenConfigDict(cfg)) fields:',
               cfg_from_frozen)

    print(
        '\nCan use FrozenConfigDict.as_configdict() to convert to ConfigDict:')
    print(cfg_from_frozen == frozen_cfg.as_configdict())  # expected: True

    print_section('Immutability.')
    try:
        # Assigning a new attribute on a frozen config raises AttributeError.
        frozen_cfg.new_field = 1
    except AttributeError as err:
        print(err)

    print_section('"==" and eq_as_configdict().')
    # `==` takes the concrete type into account, so a FrozenConfigDict is not
    # equal to a ConfigDict with the same content (expected: False).
    print(frozen_cfg == cfg)
    # eq_as_configdict() ignores the ConfigDict/FrozenConfigDict distinction
    # and is available on both classes (expected: True, True).
    print(frozen_cfg.eq_as_configdict(cfg))
    print(cfg.eq_as_configdict(frozen_cfg))
    def testEqualsAsConfigDict(self):
        """Checks eq_as_configdict(): type-agnostic but mutability-aware.

        An unrelated type (a list) compares unequal, and a ConfigDict with the
        same content compares equal. A list->tuple swap is still detected,
        exactly as with __eq__().
        """
        fcd = _test_frozenconfigdict()

        # Type handling.
        self.assertFalse(fcd.eq_as_configdict([1, 2]))
        self.assertTrue(fcd.eq_as_configdict(fcd.as_configdict()))
        self.assertTrue(config_dict.FrozenConfigDict().eq_as_configdict(
            config_dict.ConfigDict()))

        def frozen_with(key, convert):
            variant = _test_dict_deepcopy()
            variant[key] = convert(variant[key])
            return config_dict.FrozenConfigDict(variant)

        # Hidden-mutability handling.
        fcd_list_to_tuple = frozen_with('list', tuple)
        fcd_set_to_frozenset = frozen_with('set', frozenset)

        self.assertFalse(fcd.eq_as_configdict(fcd_list_to_tuple))
        # set and frozenset compare equal in Python, so this still matches.
        self.assertTrue(fcd.eq_as_configdict(fcd_set_to_frozenset))
    def testEquals(self):
        """Checks __eq__(): type-sensitive and aware of hidden mutability."""
        fcd = _test_frozenconfigdict()

        # Other types, including an equivalent ConfigDict, are not equal.
        self.assertNotEqual(fcd, (1, 2))
        self.assertNotEqual(fcd, fcd.as_configdict())

        def frozen_with(key, convert):
            variant = _test_dict_deepcopy()
            variant[key] = convert(variant[key])
            return config_dict.FrozenConfigDict(variant)

        fcd_list_to_tuple = frozen_with('list', tuple)
        fcd_set_to_frozenset = frozen_with('set', frozenset)

        # A list -> tuple swap is detected...
        self.assertNotEqual(fcd, fcd_list_to_tuple)
        # ...but set -> frozenset is not, since set == frozenset in Python.
        self.assertEqual(fcd, fcd_set_to_frozenset)

        # Hidden mutability does not show up in items().
        for variant in (fcd_list_to_tuple, fcd_set_to_frozenset):
            self.assertCountEqual(fcd.items(), variant.items())
    def testAsConfigDict(self):
        """Round-trips ConfigDict -> FrozenConfigDict -> ConfigDict.

        The round trip should reproduce the original ConfigDict, including
        when it is locked or was built with type_safe=False.
        """
        def round_trip(cd):
            return config_dict.ConfigDict(config_dict.FrozenConfigDict(cd))

        # An empty FrozenConfigDict converts to an empty ConfigDict.
        self.assertEqual(
            config_dict.ConfigDict(config_dict.FrozenConfigDict()),
            config_dict.ConfigDict())

        cd = _test_configdict()
        self.assertEqual(cd, round_trip(cd))

        # Locked source.
        cd.lock()
        self.assertEqual(cd, round_trip(cd))

        # Source with type checking disabled.
        unsafe_cd = config_dict.ConfigDict(_TEST_DICT, type_safe=False)
        self.assertEqual(unsafe_cd, round_trip(unsafe_cd))
Example #13
0
def main(_):
    """Trains a SAC agent on FLAGS.env_name using the settings in FLAGS.config.

    Sets up the experiment directory, compute device, RNG seeds and the
    train/eval environments. Fills in observation/action-space values, dumps
    and freezes the config, then runs the interaction/update loop with periodic
    logging, evaluation and checkpointing. If at least one training step
    started, a final checkpoint is written on exit, including on
    KeyboardInterrupt.
    """
    # Make sure we have a valid config that inherits all the keys defined in the
    # base config.
    validate_config(FLAGS.config, mode="rl")

    config = FLAGS.config
    exp_dir = osp.join(
        config.save_dir,
        FLAGS.experiment_name,
        str(FLAGS.seed),
    )
    utils.setup_experiment(exp_dir, config, FLAGS.resume)

    # Setup compute device.
    if torch.cuda.is_available():
        device = torch.device(FLAGS.device)
    else:
        logging.info("No GPU device found. Falling back to CPU.")
        device = torch.device("cpu")
    logging.info("Using device: %s", device)

    # Set RNG seeds.
    if FLAGS.seed is not None:
        logging.info("RL experiment seed: %d", FLAGS.seed)
        experiment.seed_rngs(FLAGS.seed)
        experiment.set_cudnn(config.cudnn_deterministic,
                             config.cudnn_benchmark)
    else:
        logging.info("No RNG seed has been set for this RL experiment.")

    # The eval env uses a seed offset from the training seed. FLAGS.seed may be
    # None (handled above), and `None + 42` would raise TypeError. In that case
    # pass None through, just as the training env already receives it.
    eval_seed = None if FLAGS.seed is None else FLAGS.seed + 42

    # Load env.
    env = utils.make_env(
        FLAGS.env_name,
        FLAGS.seed,
        action_repeat=config.action_repeat,
        frame_stack=config.frame_stack,
    )
    eval_env = utils.make_env(
        FLAGS.env_name,
        eval_seed,
        action_repeat=config.action_repeat,
        frame_stack=config.frame_stack,
        save_dir=osp.join(exp_dir, "video", "eval"),
    )

    # Dynamically set observation and action space values.
    config.sac.obs_dim = env.observation_space.shape[0]
    config.sac.action_dim = env.action_space.shape[0]
    config.sac.action_range = [
        float(env.action_space.low.min()),
        float(env.action_space.high.max()),
    ]

    # Resave the config since the dynamic values have been updated at this point
    # and make it immutable for safety :)
    utils.dump_config(exp_dir, config)
    config = config_dict.FrozenConfigDict(config)

    policy = agent.SAC(device, config.sac)

    buffer = utils.make_buffer(env, device, config)

    # Create checkpoint manager.
    checkpoint_dir = osp.join(exp_dir, "checkpoints")
    checkpoint_manager = CheckpointManager(
        checkpoint_dir,
        policy=policy,
        **policy.optim_dict(),
    )

    logger = Logger(osp.join(exp_dir, "tb"), FLAGS.resume)

    # Bound before `try` so the `finally` clause can tell whether any training
    # step started. It stays None if restore/reset raises or the step range is
    # empty.
    i = None
    try:
        start = checkpoint_manager.restore_or_initialize()
        observation, done = env.reset(), False
        for i in tqdm(range(start, config.num_train_steps), initial=start):
            # Random actions during the seed phase, policy samples afterwards.
            if i < config.num_seed_steps:
                action = env.action_space.sample()
            else:
                policy.eval()
                action = policy.act(observation, sample=True)
            next_observation, reward, done, info = env.step(action)

            # mask is 0.0 only when the episode ended without a
            # "TimeLimit.truncated" entry in info; presumably this keeps
            # bootstrapping through time-limit truncations.
            if not done or "TimeLimit.truncated" in info:
                mask = 1.0
            else:
                mask = 0.0

            if not config.reward_wrapper.pretrained_path:
                buffer.insert(observation, action, reward, next_observation,
                              mask)
            else:
                # With a pretrained_path configured, a rendered RGB frame is
                # stored alongside the transition.
                buffer.insert(
                    observation,
                    action,
                    reward,
                    next_observation,
                    mask,
                    env.render(mode="rgb_array"),
                )
            observation = next_observation

            if done:
                observation, done = env.reset(), False
                for k, v in info["episode"].items():
                    logger.log_scalar(v, info["total"]["timesteps"], k,
                                      "training")

            if i >= config.num_seed_steps:
                policy.train()
                train_info = policy.update(buffer, i)

                if (i + 1) % config.log_frequency == 0:
                    for k, v in train_info.items():
                        logger.log_scalar(v, info["total"]["timesteps"], k,
                                          "training")
                    logger.flush()

            if (i + 1) % config.eval_frequency == 0:
                eval_stats = evaluate(policy, eval_env,
                                      config.num_eval_episodes)
                for k, v in eval_stats.items():
                    logger.log_scalar(
                        v,
                        info["total"]["timesteps"],
                        f"average_{k}s",
                        "evaluation",
                    )
                logger.flush()

            if (i + 1) % config.checkpoint_frequency == 0:
                checkpoint_manager.save(i)

    except KeyboardInterrupt:
        print("Caught keyboard interrupt. Saving before quitting.")

    finally:
        # Only checkpoint if a step actually started. Otherwise the original
        # bare `save(i)` would raise NameError here and mask the real exception.
        if i is not None:
            checkpoint_manager.save(i)
        logger.close()
 def assertFrozenRaisesValueError(self, input_list):
     """Asserts FrozenConfigDict(x) raises ValueError for every x in input_list."""
     for candidate in input_list:
         with self.assertRaises(ValueError):
             config_dict.FrozenConfigDict(candidate)
def _test_frozenconfigdict():
    """Builds a new FrozenConfigDict from the shared _TEST_DICT fixture."""
    frozen = config_dict.FrozenConfigDict(_TEST_DICT)
    return frozen
 def testUnhashableType(self):
     """hash() must raise TypeError when a stored value is unhashable."""
     fcd_with_bytearray = config_dict.FrozenConfigDict(
         {'unhashable': bytearray()})
     with self.assertRaises(TypeError):
         hash(fcd_with_bytearray)