def parse_policy(self, hparams):
    """Parses the HP augmentation policy schedule from `hparams.hp_policy`.

    `hparams.hp_policy` may be a path to a `.txt` schedule log or a list. A
    flat list is used as the same policy for all epochs; a list of lists is a
    schedule with one policy per entry. Every policy is split at its midpoint,
    each half is parsed with `self._parse_policy`, and the two results are
    concatenated.

    Args:
      hparams: tf.hparams object. Reads `hp_policy`, and for `.txt` schedules
        also `num_epochs` and `hp_policy_epochs`.

    Sets:
      self.policy: the parsed policy (flat input), or a list of parsed
        policies, one per schedule entry (nested input).

    Raises:
      ValueError: if `hp_policy` is neither a `.txt` path nor a list, or if it
        yields an empty policy.
    """
    hp_policy = hparams.hp_policy
    if isinstance(hp_policy, str) and hp_policy.endswith('.txt'):
        # Factor that stretches a schedule logged over `hp_policy_epochs`
        # epochs onto the `num_epochs` of the current run.
        multiplier = float(hparams.num_epochs) / hparams.hp_policy_epochs
        if hparams.num_epochs % hparams.hp_policy_epochs != 0:
            # The schedule length is hp_policy_epochs; pass it first so the
            # message labels both numbers correctly.
            tf.logging.warning(
                "Schedule length (%s) doesn't divide evenly into epochs (%s), interpolating.",
                hparams.hp_policy_epochs, hparams.num_epochs)
        tf.logging.info(
            'schedule policy trained on {} epochs, parsing from: {}, multiplier: {}'
            .format(hparams.hp_policy_epochs, hp_policy, multiplier))
        raw_policy = parse_log_schedule(
            hp_policy, epochs=hparams.hp_policy_epochs, multiplier=multiplier)
    elif isinstance(hp_policy, list):
        # support list of hp_policy for search stage
        raw_policy = hp_policy
    else:
        raise ValueError(
            'hp_policy must be a .txt schedule path or a list during training, '
            'got: {!r}'.format(hp_policy))

    if not raw_policy:
        # Fail clearly instead of with an IndexError on raw_policy[0] below.
        raise ValueError(
            'hp_policy {!r} yields an empty policy'.format(hp_policy))

    if isinstance(raw_policy[0], list):
        # Schedule: every entry is split at the midpoint of the FIRST entry.
        split = len(raw_policy[0]) // 2
        self.policy = []
        for pol in raw_policy:
            cur_pol = self._parse_policy(pol[:split])
            cur_pol.extend(self._parse_policy(pol[split:]))
            self.policy.append(cur_pol)
        tf.logging.info('using HP policy schedule, last: {}'.format(
            self.policy[-1]))
    elif isinstance(raw_policy, list):
        # Single (non-nested) policy shared by all epochs.
        split = len(raw_policy) // 2
        self.policy = self._parse_policy(raw_policy[:split])
        self.policy.extend(self._parse_policy(raw_policy[split:]))
        tf.logging.info('using HP Policy, policy: {}'.format(self.policy))
def parse_policy(self, hparams):
    """Selects augmentation transforms and parses the policy schedule.

    When `hparams.use_hp_policy` is set, the PBA transforms are used and the
    policy comes from `hparams.hp_policy`, which may be:
      * a path ending in `.txt`: a schedule log parsed by `parse_log_schedule`
        and stretched from `hp_policy_epochs` to `num_epochs` epochs;
      * a path ending in `.p`: a pickle whose entry `schedule_num` is an
        iterable of `(num_iters, pol)` pairs; each `pol` is repeated
        `num_iters * num_epochs // hp_policy_epochs` times;
      * anything else: used directly as the raw policy (list or list of lists).
    A flat list is used as the same policy for all epochs; a list of lists is
    a per-entry schedule. Each policy is split at its midpoint and each half
    is parsed with `parse_policy(..., self.augmentation_transforms)`.

    Otherwise the AutoAugment transforms are used, and `self.good_policies`
    is chosen from `found_policies` based on `hparams.dataset`.

    Args:
      hparams: tf.hparams object. Reads `use_hp_policy`, `hp_policy`,
        `num_epochs`, `hp_policy_epochs`, `schedule_num` and `dataset`.

    Sets:
      self.augmentation_transforms, and either self.policy (HP policy) or
      self.good_policies (AutoAugment).

    Raises:
      ValueError: if a `.p` schedule is used and `num_epochs` is not a
        multiple of `hp_policy_epochs`, if the HP policy is empty or None, or
        if `dataset` is neither SVHN nor CIFAR in the AutoAugment path.
    """
    if hparams.use_hp_policy:
        self.augmentation_transforms = augmentation_transforms_pba
        hp_policy = hparams.hp_policy

        if isinstance(hp_policy, str) and hp_policy.endswith('.txt'):
            # Factor that stretches the logged schedule onto num_epochs.
            multiplier = float(hparams.num_epochs) / hparams.hp_policy_epochs
            if hparams.num_epochs % hparams.hp_policy_epochs != 0:
                # The schedule length is hp_policy_epochs; pass it first so the
                # message labels both numbers correctly.
                tf.logging.warning(
                    "Schedule length (%s) doesn't divide evenly into epochs (%s), interpolating.",
                    hparams.hp_policy_epochs, hparams.num_epochs)
            tf.logging.info(
                'schedule policy trained on {} epochs, parsing from: {}, multiplier: {}'
                .format(hparams.hp_policy_epochs, hp_policy, multiplier))
            raw_policy = parse_log_schedule(
                hp_policy, epochs=hparams.hp_policy_epochs,
                multiplier=multiplier)
        elif isinstance(hp_policy, str) and hp_policy.endswith('.p'):
            # Explicit check instead of `assert`, which is stripped under -O.
            if hparams.num_epochs % hparams.hp_policy_epochs != 0:
                raise ValueError(
                    'num_epochs ({}) must be a multiple of hp_policy_epochs ({}) '
                    'for .p schedules'.format(hparams.num_epochs,
                                              hparams.hp_policy_epochs))
            tf.logging.info('custom .p file, policy number: {}'.format(
                hparams.schedule_num))
            # NOTE(review): pickle.load can execute arbitrary code; only load
            # schedule files from trusted sources.
            with open(hp_policy, 'rb') as f:
                policy = pickle.load(f)[hparams.schedule_num]
            raw_policy = []
            for num_iters, pol in policy:
                # Repeat each policy proportionally to the run length.
                repeats = num_iters * hparams.num_epochs // hparams.hp_policy_epochs
                raw_policy.extend([pol] * repeats)
        else:
            raw_policy = hp_policy

        if not raw_policy:
            # Fail clearly instead of with IndexError/TypeError on raw_policy[0].
            raise ValueError(
                'hp_policy {!r} yields an empty policy'.format(hp_policy))

        if isinstance(raw_policy[0], list):
            # Schedule: every entry is split at the midpoint of the FIRST entry.
            split = len(raw_policy[0]) // 2
            self.policy = []
            for pol in raw_policy:
                cur_pol = parse_policy(pol[:split], self.augmentation_transforms)
                cur_pol.extend(
                    parse_policy(pol[split:], self.augmentation_transforms))
                self.policy.append(cur_pol)
            tf.logging.info('using HP policy schedule, last: {}'.format(
                self.policy[-1]))
        elif isinstance(raw_policy, list):
            # Single (non-nested) policy shared by all epochs.
            split = len(raw_policy) // 2
            self.policy = parse_policy(raw_policy[:split],
                                       self.augmentation_transforms)
            self.policy.extend(
                parse_policy(raw_policy[split:], self.augmentation_transforms))
            tf.logging.info('using HP Policy, policy: {}'.format(
                self.policy))
    else:
        self.augmentation_transforms = augmentation_transforms_autoaug
        tf.logging.info('using ENAS Policy or no augmentaton policy')
        if 'svhn' in hparams.dataset:
            self.good_policies = found_policies.good_policies_svhn()
        elif 'cifar' in hparams.dataset:
            self.good_policies = found_policies.good_policies()
        else:
            # Explicit check instead of `assert`, which is stripped under -O.
            raise ValueError(
                'unsupported dataset for AutoAugment policies: {!r}'.format(
                    hparams.dataset))
def parse_policy(self, hparams):
    """Selects augmentation transforms and parses the policy schedule.

    If `hparams.no_aug_policy` is set, nothing is configured and the method
    returns early (only log messages are emitted).

    When `hparams.use_hp_policy` is set, the PBA transforms are used and the
    policy comes from `hparams.hp_policy`: either a path ending in `.txt`
    (a schedule log parsed by `parse_log_schedule` and stretched from
    `hp_policy_epochs` to `num_epochs` epochs) or a list used directly. A flat
    list is used as the same policy for all epochs; a list of lists is a
    per-entry schedule. Each policy is split at its midpoint and each half is
    parsed with `parse_policy(..., self.augmentation_transforms)`. The result
    is also logged to `self.comet_exp` when it is not None.

    Otherwise the AutoAugment transforms are used and `self.good_policies` is
    taken from `found_policies` (SVHN if `policy_dataset == 'svhn'`, else
    CIFAR).

    Args:
      hparams: tf.hparams object. Reads `no_aug_policy`, `use_kitti_aug`,
        `use_hp_policy`, `hp_policy`, `num_epochs`, `hp_policy_epochs` and
        `policy_dataset`.

    Raises:
      ValueError: if the HP policy is empty or None.
    """
    if hparams.no_aug_policy:
        tf.logging.info("no augmentation policy will be used")
        if hparams.use_kitti_aug:
            tf.logging.info("using augmentations from SIGNet")
        # NOTE(review): nesting reconstructed from flattened source; confirm
        # this return applies to every no_aug_policy run.
        return

    # Parse policy
    if hparams.use_hp_policy:
        self.augmentation_transforms = augmentation_transforms_pba
        tf.logging.info('hp policy is selected')
        hp_policy = hparams.hp_policy

        if isinstance(hp_policy, str) and hp_policy.endswith('.txt'):
            # Factor that stretches the logged schedule onto num_epochs.
            multiplier = float(hparams.num_epochs) / hparams.hp_policy_epochs
            if hparams.num_epochs % hparams.hp_policy_epochs != 0:
                # The schedule length is hp_policy_epochs; pass it first so the
                # message labels both numbers correctly.
                tf.logging.warning(
                    "Schedule length (%s) doesn't divide evenly into epochs (%s), interpolating.",
                    hparams.hp_policy_epochs, hparams.num_epochs
                )
            tf.logging.info(
                'schedule policy trained on {} epochs, parsing from: {}, multiplier: {}'.format(
                    hparams.hp_policy_epochs, hp_policy, multiplier
                )
            )
            raw_policy = parse_log_schedule(
                hp_policy,
                epochs=hparams.hp_policy_epochs,
                multiplier=multiplier
            )
        else:
            raw_policy = hp_policy

        if not raw_policy:
            # Fail clearly instead of with IndexError/TypeError on raw_policy[0].
            raise ValueError(
                'hp_policy {!r} yields an empty policy'.format(hp_policy))

        if isinstance(raw_policy[0], list):
            # Schedule: every entry is split at the midpoint of the FIRST entry.
            split = len(raw_policy[0]) // 2
            self.policy = []
            for pol in raw_policy:
                cur_pol = parse_policy(pol[:split], self.augmentation_transforms)
                cur_pol.extend(parse_policy(pol[split:], self.augmentation_transforms))
                self.policy.append(cur_pol)
            tf.logging.info('using HP policy schedule, last: {}'.format(self.policy[-1]))
            if self.comet_exp is not None:
                self.comet_exp.log_parameter('hp_policy_schedule_last', self.policy[-1])
        elif isinstance(raw_policy, list):
            # Single (non-nested) policy shared by all epochs.
            split = len(raw_policy) // 2
            self.policy = parse_policy(raw_policy[:split], self.augmentation_transforms)
            self.policy.extend(parse_policy(raw_policy[split:], self.augmentation_transforms))
            tf.logging.info('using HP Policy, policy: {}'.format(self.policy))
            if self.comet_exp is not None:
                self.comet_exp.log_parameter('hp_policy', self.policy)
    else:
        # use autoaugment policies modified for KITTI
        self.augmentation_transforms = augmentation_transforms_autoaug
        tf.logging.info('using autoaument policy: {}'.format(hparams.policy_dataset))
        if hparams.policy_dataset == 'svhn':
            self.good_policies = found_policies.good_policies_svhn()
        else:
            # use cifar10 good policies
            self.good_policies = found_policies.good_policies_cifar()