Beispiel #1
0
 def parse_policy(self, hparams):
     """Parses the HP policy schedule from ``hparams.hp_policy``.

     ``hparams.hp_policy`` may be a path to a ``.txt`` schedule log (read
     with ``parse_log_schedule``) or an in-memory list (e.g. a list of
     policies supplied during the search stage). If the policy is nested
     (its first element is a list), each inner policy is parsed and
     ``self.policy`` becomes a list of parsed policies. If it is not nested,
     the same policy is used for all epochs.

     Every policy is split into two halves that are parsed independently
     with ``self._parse_policy`` and then concatenated.

     Args:
       hparams: tf.hparams object; reads ``hp_policy``, ``hp_policy_epochs``
         and ``num_epochs``.

     Raises:
       ValueError: if ``hp_policy`` is neither a ``.txt`` path nor a list,
         or if the resulting policy is empty or not a list.
     """
     hp_policy = hparams.hp_policy

     # Parse policy
     if isinstance(hp_policy, str) and hp_policy.endswith('.txt'):
         if hparams.num_epochs % hparams.hp_policy_epochs != 0:
             tf.logging.warning(
                 "Schedule length (%s) doesn't divide evenly into epochs (%s), interpolating.",
                 hparams.num_epochs, hparams.hp_policy_epochs)
         # Ratio of the training length to the length the schedule was
         # recorded over; forwarded to parse_log_schedule.
         multiplier = float(hparams.num_epochs) / hparams.hp_policy_epochs
         tf.logging.info(
             'schedule policy trained on {} epochs, parsing from: {}, multiplier: {}'
             .format(hparams.hp_policy_epochs, hp_policy, multiplier))
         raw_policy = parse_log_schedule(
             hp_policy, epochs=hparams.hp_policy_epochs, multiplier=multiplier)
     elif isinstance(hp_policy, list):
         # support list of hp_policy for search stage
         raw_policy = hp_policy
     else:
         # None is rejected here as well, so only list the accepted forms.
         raise ValueError(
             'hp_policy must be a path to a .txt schedule or a list during '
             'training, got: {!r}'.format(hp_policy))

     if not raw_policy:
         # Fail clearly instead of an IndexError on raw_policy[0] below.
         raise ValueError(
             'hp_policy yielded an empty policy: {!r}'.format(hp_policy))

     def parse_halves(pol, split):
         # Parse pol[:split] and pol[split:] separately, then concatenate.
         parsed = self._parse_policy(pol[:split])
         parsed.extend(self._parse_policy(pol[split:]))
         return parsed

     if isinstance(raw_policy[0], list):
         # Nested: one policy per schedule entry. The split point is taken
         # from the first entry and reused for every entry.
         split = len(raw_policy[0]) // 2
         self.policy = [parse_halves(pol, split) for pol in raw_policy]
         tf.logging.info('using HP policy schedule, last: {}'.format(
             self.policy[-1]))
     elif isinstance(raw_policy, list):
         # Flat: a single policy shared by all epochs.
         self.policy = parse_halves(raw_policy, len(raw_policy) // 2)
         tf.logging.info('using HP Policy, policy: {}'.format(self.policy))
     else:
         raise ValueError(
             'Unsupported policy type {!r}; expected a list.'.format(
                 type(raw_policy).__name__))
Beispiel #2
0
    def parse_policy(self, hparams):
        """Parses the augmentation policy configuration from ``hparams``.

        If ``hparams.use_hp_policy`` is set, the PBA transforms are selected
        and ``hparams.hp_policy`` is parsed into ``self.policy``. It may be:
        a path to a ``.txt`` schedule log, a path to a ``.p`` pickle file
        holding a list of schedules (entry ``hparams.schedule_num`` is used),
        or an in-memory list. If the policy is nested (its first element is a
        list), ``self.policy`` becomes a list of parsed policies. If it is not
        nested, the same policy is used for all epochs.

        Otherwise the AutoAugment transforms are selected and
        ``self.good_policies`` is loaded for SVHN (if ``'svhn'`` is in
        ``hparams.dataset``) or CIFAR (if ``'cifar'`` is).

        Args:
          hparams: tf.hparams object.

        Raises:
          ValueError: if a ``.p`` schedule is used and ``num_epochs`` is not
            a multiple of ``hp_policy_epochs``; if ``hp_policy`` yields an
            empty or unsupported policy; or if ``dataset`` names neither
            SVHN nor CIFAR.
        """
        # Parse policy
        if hparams.use_hp_policy:
            self.augmentation_transforms = augmentation_transforms_pba
            hp_policy = hparams.hp_policy

            if isinstance(hp_policy, str) and hp_policy.endswith('.txt'):
                if hparams.num_epochs % hparams.hp_policy_epochs != 0:
                    tf.logging.warning(
                        "Schedule length (%s) doesn't divide evenly into epochs (%s), interpolating.",
                        hparams.num_epochs, hparams.hp_policy_epochs)
                # Ratio of the training length to the length the schedule
                # was recorded over; forwarded to parse_log_schedule.
                multiplier = float(hparams.num_epochs) / hparams.hp_policy_epochs
                tf.logging.info(
                    'schedule policy trained on {} epochs, parsing from: {}, multiplier: {}'
                    .format(hparams.hp_policy_epochs, hp_policy, multiplier))
                raw_policy = parse_log_schedule(
                    hp_policy,
                    epochs=hparams.hp_policy_epochs,
                    multiplier=multiplier)
            elif isinstance(hp_policy, str) and hp_policy.endswith('.p'):
                # Explicit check rather than ``assert`` so it is not stripped
                # when Python runs with -O.
                if hparams.num_epochs % hparams.hp_policy_epochs != 0:
                    raise ValueError(
                        'num_epochs ({}) must be a multiple of hp_policy_epochs '
                        '({}) when using a .p schedule'.format(
                            hparams.num_epochs, hparams.hp_policy_epochs))
                tf.logging.info('custom .p file, policy number: {}'.format(
                    hparams.schedule_num))
                # NOTE(review): pickle.load can execute arbitrary code; only
                # load schedule files from a trusted source.
                with open(hp_policy, 'rb') as f:
                    policy = pickle.load(f)[hparams.schedule_num]
                # Each entry is (num_iters, pol); expand it into
                # num_iters * num_epochs // hp_policy_epochs copies of pol.
                raw_policy = []
                for num_iters, pol in policy:
                    repeats = (num_iters * hparams.num_epochs //
                               hparams.hp_policy_epochs)
                    raw_policy.extend([pol] * repeats)
            else:
                raw_policy = hp_policy

            if not raw_policy:
                # Fail clearly instead of an IndexError/TypeError on
                # raw_policy[0] below.
                raise ValueError(
                    'hp_policy yielded an empty policy: {!r}'.format(hp_policy))

            transforms = self.augmentation_transforms

            def parse_halves(pol, split):
                # Parse pol[:split] and pol[split:] with the module-level
                # parse_policy helper, then concatenate the results.
                parsed = parse_policy(pol[:split], transforms)
                parsed.extend(parse_policy(pol[split:], transforms))
                return parsed

            if isinstance(raw_policy[0], list):
                # Nested: one policy per schedule entry. The split point is
                # taken from the first entry and reused for every entry.
                split = len(raw_policy[0]) // 2
                self.policy = [parse_halves(pol, split) for pol in raw_policy]
                tf.logging.info('using HP policy schedule, last: {}'.format(
                    self.policy[-1]))
            elif isinstance(raw_policy, list):
                # Flat: a single policy shared by all epochs.
                self.policy = parse_halves(raw_policy, len(raw_policy) // 2)
                tf.logging.info('using HP Policy, policy: {}'.format(
                    self.policy))
            else:
                # Previously this case silently left self.policy unset.
                raise ValueError(
                    'Unsupported hp_policy {!r}: expected a .txt path, a .p '
                    'path or a list of policies.'.format(hp_policy))

        else:
            self.augmentation_transforms = augmentation_transforms_autoaug
            tf.logging.info('using ENAS Policy or no augmentaton policy')
            if 'svhn' in hparams.dataset:
                self.good_policies = found_policies.good_policies_svhn()
            elif 'cifar' in hparams.dataset:
                self.good_policies = found_policies.good_policies()
            else:
                # Explicit check rather than ``assert`` (stripped under -O).
                raise ValueError(
                    'dataset must contain "svhn" or "cifar" when no HP policy '
                    'is used, got: {!r}'.format(hparams.dataset))
Beispiel #3
0
    def parse_policy(self, hparams):
        """Parses the augmentation policy configuration from ``hparams``.

        If ``hparams.no_aug_policy`` is set, nothing is parsed and the method
        returns early. If ``hparams.use_hp_policy`` is set, the PBA transforms
        are selected and ``hparams.hp_policy`` is parsed into ``self.policy``;
        it may be a path to a ``.txt`` schedule log or an in-memory list. If
        the policy is nested (its first element is a list), ``self.policy``
        becomes a list of parsed policies. If it is not nested, the same
        policy is used for all epochs. When ``self.comet_exp`` is set, the
        parsed policy is also logged to it.

        Otherwise the AutoAugment transforms are selected and
        ``self.good_policies`` is loaded for SVHN (when
        ``hparams.policy_dataset == 'svhn'``) or CIFAR (any other value).

        Args:
          hparams: tf.hparams object.

        Raises:
          ValueError: if ``hp_policy`` yields an empty policy or is of an
            unsupported form (e.g. a path that does not end in ``.txt``).
        """
        if hparams.no_aug_policy:
            tf.logging.info("no augmentation policy will be used")
            if hparams.use_kitti_aug:
                tf.logging.info("using augmentations from SIGNet")
            return

        # Parse policy
        if hparams.use_hp_policy:
            self.augmentation_transforms = augmentation_transforms_pba
            tf.logging.info('hp policy is selected')
            hp_policy = hparams.hp_policy
            if isinstance(hp_policy, str) and hp_policy.endswith('.txt'):
                if hparams.num_epochs % hparams.hp_policy_epochs != 0:
                    tf.logging.warning(
                        "Schedule length (%s) doesn't divide evenly into epochs (%s), interpolating.",
                        hparams.num_epochs, hparams.hp_policy_epochs
                    )
                # Ratio of the training length to the length the schedule
                # was recorded over; forwarded to parse_log_schedule.
                multiplier = float(hparams.num_epochs) / hparams.hp_policy_epochs
                tf.logging.info(
                    'schedule policy trained on {} epochs, parsing from: {}, multiplier: {}'.format(
                        hparams.hp_policy_epochs, hp_policy, multiplier
                    )
                )
                raw_policy = parse_log_schedule(
                    hp_policy,
                    epochs=hparams.hp_policy_epochs,
                    multiplier=multiplier
                )
            else:
                raw_policy = hp_policy

            if not raw_policy:
                # Fail clearly instead of an IndexError/TypeError on
                # raw_policy[0] below.
                raise ValueError('hp_policy yielded an empty policy: {!r}'.format(hp_policy))

            transforms = self.augmentation_transforms

            def parse_halves(pol, split):
                # Parse pol[:split] and pol[split:] with the module-level
                # parse_policy helper, then concatenate the results.
                parsed = parse_policy(pol[:split], transforms)
                parsed.extend(parse_policy(pol[split:], transforms))
                return parsed

            if isinstance(raw_policy[0], list):
                # Nested: one policy per schedule entry. The split point is
                # taken from the first entry and reused for every entry.
                split = len(raw_policy[0]) // 2
                self.policy = [parse_halves(pol, split) for pol in raw_policy]
                tf.logging.info('using HP policy schedule, last: {}'.format(self.policy[-1]))
                if self.comet_exp is not None:
                    self.comet_exp.log_parameter('hp_policy_schedule_last', self.policy[-1])
            elif isinstance(raw_policy, list):
                # Flat: a single policy shared by all epochs.
                self.policy = parse_halves(raw_policy, len(raw_policy) // 2)
                tf.logging.info('using HP Policy, policy: {}'.format(self.policy))
                if self.comet_exp is not None:
                    self.comet_exp.log_parameter('hp_policy', self.policy)
            else:
                # Previously this case (e.g. a .p path) silently left
                # self.policy unset.
                raise ValueError(
                    'Unsupported hp_policy {!r}: expected a .txt path or a '
                    'list of policies.'.format(hp_policy)
                )
        else:
            # use autoaugment policies modified for KITTI
            self.augmentation_transforms = augmentation_transforms_autoaug
            tf.logging.info('using autoaument policy: {}'.format(hparams.policy_dataset))
            if hparams.policy_dataset == 'svhn':
                self.good_policies = found_policies.good_policies_svhn()
            else:
                # use cifar10 good policies
                self.good_policies = found_policies.good_policies_cifar()