Example #1
def decrease_lr_in_optim_config(conf: Config, num_tasks_learnt: int) -> Config:
    """
    Creates a new optim config with a decreased LR
    """
    if num_tasks_learnt <= 0 or not conf.has('decrease_lr_coef'):
        return conf.clone()

    decrease_coef = conf.decrease_lr_coef**num_tasks_learnt

    # Updating LR in the main kwargs
    if conf.kwargs.has('lr'):
        target_lr = conf.kwargs.lr * decrease_coef
        conf = conf.overwrite({'kwargs': {'lr': target_lr}})

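    # Updating LR in the param groups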
    if conf.has('groups'):
        groups_with_lr = [
            g for g in conf.groups.keys() if conf.groups[g].has('lr')
        ]
        conf = conf.overwrite({
            'groups': {
                g: conf.groups[g].overwrite(
                    {'lr': conf.groups[g].lr * decrease_coef})
                for g in groups_with_lr
            }
        })

    return conf
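
A minimal usage sketch, assuming the same firelab-style `Config` as above (constructible from a nested dict, with `has`/`clone`/`overwrite`); the keys and values below are illustrative only:

optim_config = Config({
    'type': 'adam',
    'decrease_lr_coef': 0.5,
    'kwargs': {'lr': 1e-3},
})

# After two learnt tasks the LR in `kwargs` becomes 1e-3 * 0.5 ** 2 = 2.5e-4
optim_config = decrease_lr_in_optim_config(optim_config, num_tasks_learnt=2)
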
Example #2
def split_classes_for_tasks(config: Config,
                            random_seed: int) -> List[List[int]]:
    """
    Splits classes into `num_tasks` groups and returns these splits

    :param num_classes:
    :param num_tasks:
    :param num_classes_per_task:
    :return:
    """

    if config.has('task_sizes'):
        num_classes_to_use = sum(config.task_sizes)
    else:
        num_classes_to_use = config.num_tasks * config.num_classes_per_task

    if num_classes_to_use > config.num_classes:
        warnings.warn(
            f"We'll have duplicated classes: {num_classes_to_use} > {config.num_classes}"
        )

    classes = np.arange(config.num_classes)
    classes = np.tile(classes,
                      np.ceil(num_classes_to_use / len(classes)).astype(int))
    classes = np.random.RandomState(
        seed=random_seed).permutation(classes)[:num_classes_to_use]

    if config.has('task_sizes'):
        steps = flatten([[0], np.cumsum(config.task_sizes[:-1])])
        splits = [
            classes[c:c + size].tolist()
            for c, size in zip(steps, config.task_sizes)
        ]
    else:
        splits = classes.reshape(config.num_tasks, config.num_classes_per_task)
        splits = splits.tolist()

    return splits
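
A hedged usage sketch (same assumptions about `Config`): with 10 classes, 5 tasks and 2 classes per task, every class is used exactly once; the split shape is fixed, while the class order depends on the seed:

config = Config({
    'num_classes': 10,
    'num_tasks': 5,
    'num_classes_per_task': 2,
})

splits = split_classes_for_tasks(config, random_seed=42)
assert len(splits) == 5 and all(len(s) == 2 for s in splits)  # 5 tasks, 2 classes each
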
Example #3
def construct_optimizer(model: nn.Module, optim_config: Config):
    name_to_cls = {
        'sgd': torch.optim.SGD,
        'adam': torch.optim.Adam,
        'rms_prop': torch.optim.RMSprop
    }

    if optim_config.has('groups'):
        groups = [{
            'params': getattr(model, g).parameters(),
            **optim_config.groups.get(g)
        } for g in sorted(optim_config.groups.keys())]
    else:
        groups = [{'params': model.parameters()}]

    return name_to_cls[optim_config.type](groups, **optim_config.kwargs)
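
A hedged usage sketch; the model and config values below are made up for illustration and assume `optim_config.kwargs` can be unpacked with `**` (as the function itself relies on):

import torch.nn as nn

model = nn.Linear(16, 4)
optim_config = Config({
    'type': 'sgd',
    'kwargs': {'lr': 0.01, 'momentum': 0.9},
})
optimizer = construct_optimizer(model, optim_config)  # torch.optim.SGD over model.parameters()
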
Example #4
class BaseTrainer:
    def __init__(self, config):
        # TODO: we should somehow say more loudly that we are reserving these properties
        # Besides, some properties are vital for the user to define, but they have no idea about it :|
        # TODO: even I do not know all the options available in config :|
        if config.has('base_config'):
            self.config = Config.load(config.base_config)
            self.config = self.config.overwrite(config)
        else:
            self.config = config

        self._init_paths()

        # Reload config if we continue training
        if os.path.exists(self.paths.config_path):
            print(
                f'Detected existing config: {self.paths.config_path}. Loading it...'
            )
            # A dirty hack that ensures that multiple trainers sync
            # This is needed for a synced file system
            # For some reason, portalocker does not work on a shared FS...
            time.sleep(1)
            self.config = Config.load(self.paths.config_path)
            self.config = self.config.overwrite(Config.read_from_cli())

        self._init_logger()
        self._init_devices()

        if self.is_main_process() and not os.path.exists(
                self.paths.config_path):
            self.config.save(self.paths.config_path)

        if not self.config.get('silent') and self.is_main_process():
            self.logger.info(
                f'Experiment directory: {self.paths.experiment_dir}')

        self._init_tb_writer()
        self._init_callbacks()
        self._init_checkpointing_strategy()
        self._init_validation_strategy()
        self._init_stopping_criteria()

        self.num_iters_done = 0
        self.num_epochs_done = 0
        self.is_explicitly_stopped = False
        self.train_dataloader = None
        self.val_dataloader = None

    ############################
    ### Overwritable methods ###
    ############################
    def init_dataloaders(self):
        pass

    def init_models(self):
        pass

    def init_criterions(self):
        pass

    def init_optimizers(self):
        pass

    def train_on_batch(self, batch):
        pass

    def on_epoch_done(self):
        "Callback which is called when epoch has beed done"
        pass

    def validate(self):
        pass

    def is_main_process(self) -> bool:
        return is_main_process()

    #############
    ### Hooks ###
    #############
    def before_init_hook(self):
        pass

    def after_init_hook(self):
        pass

    def before_training_hook(self):
        pass

    def after_training_hook(self):
        pass

    def get_training_results(self) -> Dict:
        """
        Function which returns training results which
        are passed to summary generation after training is done
        """
        return {}

    ######################
    ### Public methods ###
    ######################
    def start(self):
        if len(self.gpus) > 0:
            with torch.cuda.device(self.gpus[0]):
                self._start()
        else:
            self._start()

    def stop(self, stopping_reason: str = ''):
        self.is_explicitly_stopped = True
        self._explicit_stopping_reason = stopping_reason

    def write_losses(self, losses: dict, prefix=''):
        """
        Iterates over losses and logs them with self.writer
        Arguments:
            - losses: dict of losses; each loss should be a scalar
        """
        for k in losses:
            self.writer.add_scalar(prefix + k, losses[k], self.num_iters_done)

    #######################
    ### Private methods ###
    #######################
    def init(self):
        # Initialization
        self.before_init_hook()
        self.init_dataloaders()
        self.init_models()
        self.init_criterions()
        self.init_optimizers()
        self._try_to_load_checkpoint()
        self.after_init_hook()

    def _start(self):
        self.init()

        # Training
        self.before_training_hook()
        self._run_training()
        self.after_training_hook()
        self.writer.close()

    def _run_training(self):
        try:
            while not self._should_stop():
                if self.config.get('logging.training_progress',
                                   True) and self.is_main_process():
                    batches = tqdm(self.train_dataloader)

                    self.logger.info(
                        'Running epoch #{}'.format(self.num_epochs_done + 1))
                else:
                    batches = self.train_dataloader

                for batch in batches:
                    self._set_train_mode()

                    if self.config.get('should_ignore_oom_batches', False):
                        safe_oom_call(self.train_on_batch,
                                      self.logger,
                                      batch,
                                      debug=self.config.get('debug_gpu'))
                    else:
                        self.train_on_batch(batch)

                    self.num_iters_done += 1

                    # Checkpointing the model BEFORE validation, since validation can halt :|
                    self._try_to_checkpoint()

                    if self.config.get('should_ignore_oom_batches', False):
                        safe_oom_call(self._try_to_validate,
                                      self.logger,
                                      debug=self.config.get('debug_gpu'))
                    else:
                        self._try_to_validate()

                    if self._should_stop():
                        break

                self.num_epochs_done += 1
                self.on_epoch_done()
        except Exception as e:
            self._terminate_experiment(str(e))
            raise

    def _try_to_validate(self):
        should_validate = False

        if self.val_freq_iters:
            should_validate = self.num_iters_done % self.val_freq_iters == 0
        elif self.val_freq_epochs:
            epoch_size = len(self.train_dataloader)
            was_epoch_just_finished = self.num_iters_done % epoch_size == 0
            # TODO: just use different callbacks for val_freq_epochs and val_freq_iters
            num_epochs_done = (
                self.num_epochs_done +
                1) if was_epoch_just_finished else self.num_epochs_done
            is_epoch_appropriate = num_epochs_done % self.val_freq_epochs == 0
            should_validate = was_epoch_just_finished and is_epoch_appropriate

        if should_validate:
            self._set_eval_mode()

            # Validating without grad enabled (less memory consumption)
            with torch.no_grad():
                self.validate()

    def _try_to_checkpoint(self):
        # Checkpointing in non-main processes leads to subtle errors when loading the weights
        if not self.is_main_process() or self.config.get('no_saving'): return

        should_checkpoint = False

        if self.checkpoint_freq_iters:
            should_checkpoint = self.num_iters_done % self.checkpoint_freq_iters == 0
        elif self.checkpoint_freq_epochs:
            # TODO: this looks messy
            epoch_size = len(self.train_dataloader)
            freq = self.checkpoint_freq_epochs * epoch_size
            should_checkpoint = self.num_iters_done % freq == 0

        if not should_checkpoint:
            return

        self.checkpoint()
        self._checkpoint_freq_warning()

    def checkpoint(self):
        # TODO: add max_num_checkpoints_to_store argument
        # We want to checkpoint right now!
        if not self.paths.has('checkpoints_path'):
            raise RuntimeError(
                'Tried to checkpoint, but no checkpoint path was specified. Cannot checkpoint. '
                'Provide either `paths.checkpoints_path` or `experiment_dir` in config.')

        overwrite = not self.config.checkpoint.get('separate_checkpoints')

        for module_name in self.config.get('checkpoint.modules', []):
            self._save_module_state(getattr(self, module_name),
                                    module_name,
                                    overwrite=overwrite)

        for pickle_attr in self.config.get('checkpoint.pickle', []):
            self._pickle(getattr(self, pickle_attr),
                         pickle_attr,
                         overwrite=overwrite)

        self._pickle(
            {
                'num_iters_done': self.num_iters_done,
                'num_epochs_done': self.num_epochs_done
            },
            'training_state',
            overwrite=overwrite)

    def _checkpoint_freq_warning(self):
        """
        Prints warning if we write checkpoints too often and they cost too much
        TODO: wip
        """
        pass

    def _try_to_load_checkpoint(self):
        """Loads model state from checkpoint if it is provided"""
        if not self.is_main_process():
            return  # Only the main process reads the checkpoint; it should then be broadcast to the others
        if not os.path.isdir(self.paths.checkpoints_path): return

        checkpoints = [
            c for c in os.listdir(self.paths.checkpoints_path)
            if 'training_state' in c
        ]
        if len(checkpoints) == 0:
            return
        if 'training_state.pt' in checkpoints:
            latest_iter = None  # Overwritten checkpoints carry no iteration suffix
        else:
            checkpoints_iters = [
                int(c[len('training_state-'):-len('.pt')]) for c in checkpoints
            ]
            latest_iter = max(checkpoints_iters)

        try:
            training_state = self._read_pickle_module('training_state',
                                                      latest_iter)
        except FileNotFoundError:
            print('Could not load training state')
            return

        self.num_iters_done = training_state['num_iters_done']
        self.num_epochs_done = training_state['num_epochs_done']

        print(
            f'Continuing from iteration: {self.num_iters_done} ({self.num_epochs_done} epochs)'
        )

        if self.config.checkpoint.get('separate_checkpoints'):
            continue_from_iter = self.num_iters_done
        else:
            continue_from_iter = None  # Since all of them are overwritten

        for module_name in self.config.checkpoint.modules:
            self._load_module_state(getattr(self, module_name), module_name,
                                    continue_from_iter)

        for module_name in self.config.get('checkpoint.pickle', []):
            self._unpickle(module_name, continue_from_iter)

    def _should_stop(self) -> bool:
        "Checks all stopping criteria"
        if self.max_num_iters is not None and self.num_iters_done >= self.max_num_iters:
            self._terminate_experiment('Max num iters exceeded')
            return True

        if self.max_num_epochs is not None and self.num_epochs_done >= self.max_num_epochs:
            self._terminate_experiment('Max num epochs exceeded')
            return True

        if self._should_early_stop():
            self._terminate_experiment('Early stopping')
            return True

        if self.is_explicitly_stopped:
            self._terminate_experiment(
                f'Stopped explicitly via .stop() method. Reason: {self._explicit_stopping_reason}'
            )
            return True

        return False

    def _should_early_stop(self):
        "Checks early stopping criterion"
        if self.config.get('early_stopping') is None: return False

        history = self.losses[self.config.early_stopping.loss_name]
        n_steps = self.config.early_stopping.history_length
        should_decrease = self.config.early_stopping.should_decrease

        return not is_history_improving(history, n_steps, should_decrease)

    # TODO: we can gather modules automatically (via "isinstance")
    def _set_train_mode(self, flag: bool = True):
        """Switches all models into training mode"""

        for model_name in self.config.get('modules.models', []):
            getattr(self, model_name).train(flag)

    def _set_eval_mode(self):
        "Switches all models into evaluation mode"
        self._set_train_mode(False)

    def _save_module_state(self,
                           module: nn.Module,
                           name: str,
                           overwrite: bool = True):
        suffix = '' if overwrite else f'-{self.num_iters_done}'
        file_name = f'{name}{suffix}.pt'
        module_path = os.path.join(self.paths.checkpoints_path, file_name)
        torch.save(module.state_dict(), module_path)

    def _load_module_state(self, module, name, iteration: int = None):
        suffix = '' if iteration is None else f'-{iteration}'
        file_name = f'{name}{suffix}.pt'
        module_path = os.path.join(self.paths.checkpoints_path, file_name)
        module.load_state_dict(torch.load(module_path))
        print(f'Loaded checkpoint: {module_path}')

    def _pickle(self, module, name, overwrite: bool = True):
        suffix = '' if overwrite else f'-{self.num_iters_done}'
        file_name = f'{name}{suffix}.pt'
        path = os.path.join(self.paths.checkpoints_path, file_name)
        with open(path, 'wb') as f:
            pickle.dump(module, f)

    def _unpickle(self, name, iteration):
        setattr(self, name, self._read_pickle_module(name, iteration))

    def _read_pickle_module(self, name, iteration: int = None):
        suffix = '' if iteration is None else f'-{iteration}'
        file_name = f'{name}{suffix}.pt'
        path = os.path.join(self.paths.checkpoints_path, file_name)
        with open(path, 'rb') as f:
            module = pickle.load(f)

        print(f'Loaded pickle module: {path}')

        return module

    def _terminate_experiment(self, termination_reason):
        if not self.is_main_process(): return
        self.logger.info('Terminating experiment because [%s]' %
                         termination_reason)
        self._write_summary(termination_reason)

    def _write_summary(self, termination_reason: str):
        if not self.is_main_process() or self.config.get('no_saving'): return
        if not self.paths.has('summary_path'): return

        summary = {
            'name': self.config.get('exp_name', 'unnamed'),
            'termination_reason': termination_reason,
            'num_iters_done': self.num_iters_done,
            'num_epochs_done': self.num_epochs_done,
            'config': self.config.to_dict(),
            'results': self.get_training_results()
        }

        with open(self.paths.summary_path, 'w') as f:
            yaml.safe_dump(summary, f, default_flow_style=False)

    ##############################
    ### Initialization methods ###
    ##############################
    def _init_logger(self):
        if self.config.has('exp_name'):
            self.logger = logging.getLogger(self.config.exp_name)
        else:
            # TODO: is it okay to use class name?
            self.logger = logging.getLogger(self.__class__.__name__)
            self.logger.warn('You should provide an experiment name (by setting the "exp_name" attribute in config) ' \
                             'if you want the trainer logger to have a specific name.')

        coloredlogs.install(level=self.config.get('logging.level', 'DEBUG'),
                            logger=self.logger)

    def _init_paths(self):
        experiment_dir = infer_new_experiment_path(
            self.config.get('experiment_dir'),
            self.config.get('exp_series_dir'), self.config.get('exp_name'))

        self.paths = Config({
            'experiment_dir': experiment_dir,
            'checkpoints_path': os.path.join(experiment_dir, 'checkpoints'),
            'summary_path': os.path.join(experiment_dir, 'summary.yml'),
            'config_path': os.path.join(experiment_dir, 'config.yml'),
            'logs_path': os.path.join(experiment_dir, 'logs'),
            'tb_images_path': os.path.join(experiment_dir, 'tb_images'),
            'custom_data_path': os.path.join(experiment_dir, 'custom_data'),
        })

        if self.config.get('no_saving'): return

        # Have to create all the paths by ourselves
        os.makedirs(self.paths.experiment_dir, exist_ok=True)
        os.makedirs(self.paths.checkpoints_path, exist_ok=True)
        os.makedirs(self.paths.logs_path, exist_ok=True)
        os.makedirs(self.paths.tb_images_path, exist_ok=True)
        os.makedirs(self.paths.custom_data_path, exist_ok=True)
        os.makedirs(os.path.dirname(self.paths.summary_path), exist_ok=True)

    def _init_tb_writer(self):
        if not self.is_main_process() or self.config.get(
                'no_saving') or not self.paths.has('logs_path'):
            logger = self.logger

            # TODO: maybe we should just raise an exception?
            class DummyWriter:
                def __getattribute__(self, name):
                    dummy_fn = lambda *args, **kwargs: None
                    logger.warn(
                        'Tried to use tensorboard, but tensorboard logs dir was not set. Nothing is written.'
                    )
                    return dummy_fn

            self.writer = DummyWriter()
            self.img_writer = DummyWriter()
        else:
            self.writer = SummaryWriter(self.paths.logs_path,
                                        flush_secs=self.config.get(
                                            'logging.tb_flush_secs', 10))
            self.img_writer = SummaryWriter(self.paths.tb_images_path,
                                            flush_secs=self.config.get(
                                                'logging.tb_flush_secs', 10))

    def _init_callbacks(self):
        self._on_iter_done_callbacks: List[Callable] = []
        self._on_epoch_done_callbacks: List[Callable] = []
        self._on_training_done_callbacks: List[Callable] = []

    def _init_checkpointing_strategy(self):
        if self.config.get('checkpoint'):
            self.checkpoint_freq_iters = self.config.checkpoint.get(
                'freq_iters')
            self.checkpoint_freq_epochs = self.config.checkpoint.get(
                'freq_epochs')

            if len(self.config.get('checkpoint.modules', [])) == 0:
                self.logger.warn(
                    '`checkpoint` config is specified, but no `modules` are provided. '
                    'No torch modules to checkpoint!')

            if self.config.checkpoint.get('pickle'):
                assert type(self.config.checkpoint.pickle) is tuple
                self.logger.info(
                    f'Will be checkpointing with pickle ' \
                    f'the following modules: {self.config.checkpoint.pickle}')

            assert not (self.checkpoint_freq_iters
                        and self.checkpoint_freq_epochs), """
                Can't save both on iters and epochs.
                Please, remove either freq_iters or freq_epochs
            """
        else:
            # TODO: this is messy :|
            self.checkpoint_freq_iters = None
            self.checkpoint_freq_epochs = None

    def _init_validation_strategy(self):
        self.val_freq_iters = self.config.get('val_freq_iters')
        self.val_freq_epochs = self.config.get('val_freq_epochs')

        assert not (self.val_freq_iters and self.val_freq_epochs), """
            Can't validate on both iters and epochs.
            Please, remove either val_freq_iters or val_freq_epochs
        """

    def _init_stopping_criteria(self):
        self.max_num_epochs = self.config.get('hp.max_num_epochs')
        self.max_num_iters = self.config.get('hp.max_num_iters')
        self.losses = {}

        if not (self.max_num_iters or self.max_num_epochs
                or self.config.has('early_stopping')):
            raise ValueError(
                'You should set `max_num_iters`, `max_num_epochs` or `early_stopping`')

    def _init_devices(self):
        assert not self.config.has('device_name'), \
            'FireLab detects and sets `device_name` for you. You influence it via `gpus`.'
        assert not hasattr(
            self, 'device_name'
        ), 'You should not overwrite "device_name" attribute in Trainer.'
        assert not hasattr(
            self,
            'gpus'), 'You should not overwrite "gpus" attribute in Trainer.'

        visible_gpus = list(range(torch.cuda.device_count()))
        self.is_distributed = self.config.get('distributed_training.enabled',
                                              False)

        if self.config.has('gpus'):
            self.gpus = self.config.gpus
        elif self.config.has('firelab.gpus'):
            self.gpus = self.config.firelab.gpus
        else:
            # TODO: maybe we should only take GPUs when explicitly allowed?
            self.gpus = visible_gpus
            if not self.config.get('silent'):
                self.logger.warn(
                    f'Attribute "gpus" was not set in config and '
                    f'{len(visible_gpus)} GPUs were found. Using all of them.')

        if self.is_distributed:
            import horovod.torch as hvd
            hvd.init()
            torch.cuda.set_device(hvd.local_rank())
            self.device_name = f'cuda:{hvd.local_rank()}'
            self.logger.info(f'My rank is: {hvd.local_rank()}')
        elif len(self.gpus) > 0:
            self.device_name = f'cuda:{self.gpus[0]}'
            torch.cuda.set_device(self.gpus[0])
        else:
            self.device_name = 'cpu'
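
A minimal, hypothetical subclass sketch (not part of the original code) showing how the overridable methods and the config keys read above (`exp_name`, `hp.max_num_epochs`, `modules.models`, `checkpoint.modules`, `val_freq_epochs`) fit together; the dataset, model and key values are made up:

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset


class MyTrainer(BaseTrainer):
    def init_dataloaders(self):
        # A toy dataset; the trainer only iterates over the dataloader and calls len() on it
        xs, ys = torch.randn(256, 16), torch.randint(0, 4, (256,))
        self.train_dataloader = DataLoader(TensorDataset(xs, ys), batch_size=32)

    def init_models(self):
        self.model = nn.Linear(16, 4).to(self.device_name)

    def init_criterions(self):
        self.criterion = nn.CrossEntropyLoss()

    def init_optimizers(self):
        self.optim = torch.optim.Adam(self.model.parameters(), lr=1e-3)

    def train_on_batch(self, batch):
        x, y = batch
        loss = self.criterion(self.model(x.to(self.device_name)), y.to(self.device_name))
        self.optim.zero_grad()
        loss.backward()
        self.optim.step()
        self.write_losses({'loss': loss.item()}, prefix='train/')


# Hypothetical config; only keys that BaseTrainer actually reads are shown:
# config = Config({
#     'exp_name': 'demo',
#     'experiment_dir': 'experiments/demo',
#     'hp': {'max_num_epochs': 3},
#     'modules': {'models': ['model']},
#     'checkpoint': {'freq_epochs': 1, 'modules': ['model']},
#     'val_freq_epochs': 1,
# })
# MyTrainer(config).start()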