# Example 1
    def search(self,
               search_method='random',
               epochs=2,
               max_trials=1,
               max_instances_at_once=1,
               augmentation_search=False):
        """
        Run a hyper-parameter search, then persist the winning trial's
        checkpoint ('best_checkpoint.pt') and its hyper-parameter values
        (self.path_to_best_trial, JSON).

        max_trials: maximum number of trials before hard stop, is not used in hyperband algorithm
        """
        active_trials = OngoingTrials()
        hp_tuner = Tuner(active_trials,
                         search_method=search_method,
                         epochs=epochs,
                         max_trials=max_trials,
                         max_instances_at_once=max_instances_at_once,
                         hp_space=self.hp_space)
        launcher = Launcher(active_trials,
                            model_fn=self.name,
                            training_configs=self.training_configs,
                            home_path=self.home_path,
                            annotation_type=self.annotation_type)

        logger.info('commencing hyper-parameter search . . . ')
        # The first batch of trials is launched unconditionally; later
        # batches run only while the tuner keeps the search open.
        hp_tuner.search_hp()
        launcher.launch_trials()
        hp_tuner.end_trial()
        # starting second set of trials
        hp_tuner.search_hp()
        while active_trials.status != 'STOPPED':
            launcher.launch_trials()
            hp_tuner.end_trial()
            # starting next set of trials
            hp_tuner.search_hp()

        trials = hp_tuner.trials
        if augmentation_search:
            self._searchaugs_retrain_push(trials, hp_tuner, launcher)

        sorted_trial_ids = hp_tuner.get_sorted_trial_ids()
        best_id = sorted_trial_ids[0]
        best = trials[best_id]
        save_best_checkpoint_location = 'best_checkpoint.pt'
        logger.info('the best trial, trial ' + best_id +
                    '\tval: ' + str(best['metrics']))
        # Merge the on-disk checkpoint into the trial's checkpoint metadata
        # (mutates best['meta_checkpoint'] in place, which is then saved).
        temp_checkpoint = torch.load(
            best['meta_checkpoint']['checkpoint_path'])
        best['meta_checkpoint'].update(temp_checkpoint)
        if os.path.exists(save_best_checkpoint_location):
            logger.info('overwriting checkpoint . . .')
            os.remove(save_best_checkpoint_location)
        torch.save(best['meta_checkpoint'], save_best_checkpoint_location)

        logger.info('best trial: ' + str(best['hp_values']) +
                    '\nbest value: ' + str(best['metrics']))

        best_trial = best['hp_values']
        if os.path.exists(self.path_to_best_trial):
            logger.info('overwriting best_trial.json . . .')
            os.remove(self.path_to_best_trial)
        with open(self.path_to_best_trial, 'w') as fp:
            json.dump(best_trial, fp)
            logger.info('results saved to best_trial.json')
# Example 2
    def hp_search(self):
        """
        Run a hyper-parameter search (optionally followed by a
        fastautoaugment augmentation search), then persist the best trial's
        checkpoint ('best_checkpoint.pt') and hyper-parameter values
        (self.path_to_best_trial, JSON).
        """
        if not self.remote:
            # Each local trial instance occupies a GPU, so the concurrency
            # level cannot exceed the visible device count.
            if self.opt_model.max_instances_at_once > torch.cuda.device_count(
            ):
                print(torch.cuda.is_available())
                raise Exception(
                    ''' 'max_instances_at_once' must be smaller or equal to the number of available gpus'''
                )
        # NOTE(review): the log message mentions 'update_optimal_model' but
        # the guard checks for a 'name' attribute — confirm this is intended.
        if not hasattr(self.opt_model, 'name'):
            logger.info(
                "no 'update_optimal_model' method, checking for model.txt file . . . "
            )
            self.update_optimal_model()
        # initialize hyperparameter_tuner and gun i.e.
        ongoing_trials = OngoingTrials()
        tuner = Tuner(self.opt_model, ongoing_trials)
        gun = Launcher(self.opt_model, ongoing_trials, remote=self.remote)
        logger.info('commencing hyper-parameter search . . . ')
        tuner.search_hp()
        gun.launch_trials()
        tuner.end_trial()
        # starting second set of trials
        tuner.search_hp()
        # BUG FIX: was `is not 'STOPPED'` — identity comparison against a
        # string literal (SyntaxWarning since Python 3.8, and only works by
        # accident of interning). Use equality instead.
        while ongoing_trials.status != 'STOPPED':
            gun.launch_trials()
            tuner.end_trial()
            # starting next set of trials
            tuner.search_hp()

        trials = tuner.get_trials()
        if self.opt_model.augmentation_search_method == 'fastautoaugment':
            sorted_trial_ids = tuner.get_sorted_trial_ids()

            # Base file name derived from the configured checkpoint path
            # (text before the first '.').
            string1 = self.path_to_best_checkpoint.split('.')[0]
            paths_ls = []
            # Save checkpoints of the (up to) top-5 trials for the
            # augmentation searcher to consume.
            for i in range(len(sorted_trial_ids[:5])):
                save_checkpoint_location = string1 + str(i) + '.pt'
                logger.info('trial ' + sorted_trial_ids[i] + '\tval: ' +
                            str(trials[sorted_trial_ids[i]]['metrics']))
                save_checkpoint_location = os.path.join(
                    'augmentations_tuner', 'fastautoaugment',
                    'FastAutoAugment', 'models', save_checkpoint_location)
                if os.path.exists(save_checkpoint_location):
                    logger.info('overwriting checkpoint . . .')
                    os.remove(save_checkpoint_location)
                torch.save(trials[sorted_trial_ids[i]]['checkpoint'],
                           save_checkpoint_location)
                paths_ls.append(save_checkpoint_location)
            augsearch = AugSearch(
                paths_ls=paths_ls
            )  #TODO: calibrate between the model dictionaries
            # NOTE(review): prefixing 'final' onto a (possibly multi-segment)
            # path yields e.g. 'finalpath/to/best.pt' — verify intended.
            checkpointwithaugspath = 'final' + string1 + '.pt'
            augsearch.retrain(save_path=checkpointwithaugspath)
            tuner.add_to_oracle_trials(checkpointwithaugspath)

        sorted_trial_ids = tuner.get_sorted_trial_ids()
        save_best_checkpoint_location = 'best_checkpoint.pt'
        logger.info('the best trial, trial ' + sorted_trial_ids[0] +
                    '\tval: ' + str(trials[sorted_trial_ids[0]]['metrics']))
        if os.path.exists(save_best_checkpoint_location):
            logger.info('overwriting checkpoint . . .')
            os.remove(save_best_checkpoint_location)
        torch.save(trials[sorted_trial_ids[0]]['checkpoint'],
                   save_best_checkpoint_location)

        logger.info('best trial: ' +
                    str(trials[sorted_trial_ids[0]]['hp_values']) +
                    '\nbest value: ' +
                    str(trials[sorted_trial_ids[0]]['metrics']))

        best_trial = trials[sorted_trial_ids[0]]['hp_values']
        if os.path.exists(self.path_to_best_trial):
            logger.info('overwriting best_trial.json . . .')
            os.remove(self.path_to_best_trial)
        with open(self.path_to_best_trial, 'w') as fp:
            json.dump(best_trial, fp)
            logger.info('results saved to best_trial.json')
# Example 3
    def hp_search(self):
        """
        Run a hyper-parameter search (optionally followed by a
        fastautoaugment augmentation search over the top trials), then
        persist the best trial's checkpoint ('best_checkpoint.pt') and
        hyper-parameter values (self.path_to_best_trial, JSON).
        """
        if not self.remote:
            # Each local trial instance occupies a GPU, so the concurrency
            # level cannot exceed the visible device count.
            if self.opt_model.max_instances_at_once > torch.cuda.device_count():
                print(torch.cuda.is_available())
                raise Exception(''' 'max_instances_at_once' must be smaller or equal to the number of available gpus''')
        # initialize hyperparameter_tuner and gun i.e.
        ongoing_trials = OngoingTrials()
        tuner = Tuner(self.opt_model, ongoing_trials)
        gun = Launcher(self.opt_model, ongoing_trials, remote=self.remote)
        logger.info('commencing hyper-parameter search . . . ')
        tuner.search_hp()
        gun.launch_trials()
        tuner.end_trial()
        # starting second set of trials
        tuner.search_hp()
        # BUG FIX: was `is not 'STOPPED'` — identity comparison against a
        # string literal (SyntaxWarning since Python 3.8, and only works by
        # accident of interning). Use equality instead.
        while ongoing_trials.status != 'STOPPED':
            gun.launch_trials()
            tuner.end_trial()
            # starting next set of trials
            tuner.search_hp()

        trials = tuner.trials
        if self.opt_model.augmentation_search_method == 'fastautoaugment':
            sorted_trial_ids = tuner.get_sorted_trial_ids()

            # Base file name derived from the configured checkpoint path
            # (text before the first '.').
            string1 = self.path_to_best_checkpoint.split('.')[0]
            paths_ls = []
            # Save checkpoints of the (up to) top-5 trials for the
            # augmentation searcher to consume.
            for i in range(len(sorted_trial_ids[:5])):
                save_checkpoint_location = string1 + str(i) + '.pt'
                logger.info('trial ' + sorted_trial_ids[i] + '\tval: ' + str(trials[sorted_trial_ids[i]]['metrics']))
                save_checkpoint_location = os.path.join(os.getcwd(), 'augmentations_tuner', 'fastautoaugment',
                                                        'FastAutoAugment', 'models', save_checkpoint_location)
                if os.path.exists(save_checkpoint_location):
                    logger.info('overwriting checkpoint . . .')
                    os.remove(save_checkpoint_location)
                torch.save(trials[sorted_trial_ids[i]]['checkpoint'], save_checkpoint_location)
                paths_ls.append(save_checkpoint_location)
            aug_policy = augsearch(paths_ls=paths_ls)  # TODO: calibrate between the model dictionaries
            # Re-train the best trial with the found augmentation policy and
            # register the result as an additional trial.
            best_trial = trials[sorted_trial_ids[0]]['hp_values']
            best_trial.update({"augment_policy": aug_policy})
            metrics_and_checkpoint_dict = gun.launch_trial(hp_values=best_trial)
            # no oracle to create trial with, must generate on our own
            trial_id = generate_trial_id()
            tuner.add_trial(trial_id=trial_id,
                            hp_values=best_trial,
                            metrics=metrics_and_checkpoint_dict['metrics'],
                            meta_checkpoint=metrics_and_checkpoint_dict['meta_checkpoint'])

        sorted_trial_ids = tuner.get_sorted_trial_ids()
        save_best_checkpoint_location = 'best_checkpoint.pt'
        logger.info(
            'the best trial, trial ' + sorted_trial_ids[0] + '\tval: ' + str(trials[sorted_trial_ids[0]]['metrics']))
        # Merge the on-disk checkpoint into the trial's checkpoint metadata
        # (mutates the 'meta_checkpoint' dict in place, which is then saved).
        temp_checkpoint = torch.load(trials[sorted_trial_ids[0]]['meta_checkpoint']['checkpoint_path'])
        checkpoint = trials[sorted_trial_ids[0]]['meta_checkpoint']
        checkpoint.update(temp_checkpoint)
        if os.path.exists(save_best_checkpoint_location):
            logger.info('overwriting checkpoint . . .')
            os.remove(save_best_checkpoint_location)
        torch.save(trials[sorted_trial_ids[0]]['meta_checkpoint'], save_best_checkpoint_location)

        logger.info('best trial: ' + str(trials[sorted_trial_ids[0]]['hp_values']) + '\nbest value: ' + str(
            trials[sorted_trial_ids[0]]['metrics']))

        best_trial = trials[sorted_trial_ids[0]]['hp_values']
        if os.path.exists(self.path_to_best_trial):
            logger.info('overwriting best_trial.json . . .')
            os.remove(self.path_to_best_trial)
        with open(self.path_to_best_trial, 'w') as fp:
            json.dump(best_trial, fp)
            logger.info('results saved to best_trial.json')