def decrease_lr_in_optim_config(conf: Config, num_tasks_learnt: int) -> Config:
    """Creates a new optim config with a decreased LR"""
    if num_tasks_learnt <= 0 or not conf.has('decrease_lr_coef'):
        return conf.clone()

    decrease_coef = conf.decrease_lr_coef ** num_tasks_learnt

    # Updating LR in the main kwargs
    if conf.kwargs.has('lr'):
        target_lr = conf.kwargs.lr * decrease_coef
        conf = conf.overwrite({'kwargs': {'lr': target_lr}})

    # Updating LR in each param group that defines its own LR
    if conf.has('groups'):
        groups_with_lr = [g for g in conf.groups.keys() if conf.groups[g].has('lr')]
        conf = conf.overwrite({
            'groups': {
                g: conf.groups[g].overwrite({'lr': conf.groups[g].lr * decrease_coef})
                for g in groups_with_lr
            }
        })

    return conf
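# Hypothetical usage sketch (not from the source): assumes a firelab-style optimizer Config
# that keeps `decrease_lr_coef`, `kwargs.lr` and (optionally) per-group LRs, as above.
optim_conf = Config({
    'type': 'adam',
    'decrease_lr_coef': 0.5,
    'kwargs': {'lr': 0.1},
    'groups': {'backbone': {'lr': 0.01}},
})

# After two learnt tasks the LR is scaled by 0.5 ** 2 = 0.25:
# kwargs.lr becomes 0.025 and groups.backbone.lr becomes 0.0025.
decayed_conf = decrease_lr_in_optim_config(optim_conf, num_tasks_learnt=2)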
def _init_paths(self):
    experiment_dir = infer_new_experiment_path(
        self.config.get('experiment_dir'),
        self.config.get('exp_series_dir'),
        self.config.get('exp_name'))

    self.paths = Config({
        'experiment_dir': experiment_dir,
        'checkpoints_path': os.path.join(experiment_dir, 'checkpoints'),
        'summary_path': os.path.join(experiment_dir, 'summary.yml'),
        'config_path': os.path.join(experiment_dir, 'config.yml'),
        'logs_path': os.path.join(experiment_dir, 'logs'),
        'tb_images_path': os.path.join(experiment_dir, 'tb_images'),
        'custom_data_path': os.path.join(experiment_dir, 'custom_data'),
    })

    if self.config.get('no_saving'):
        return

    # Have to create all the paths by ourselves
    os.makedirs(self.paths.experiment_dir, exist_ok=True)
    os.makedirs(self.paths.checkpoints_path, exist_ok=True)
    os.makedirs(self.paths.logs_path, exist_ok=True)
    os.makedirs(self.paths.tb_images_path, exist_ok=True)
    os.makedirs(self.paths.custom_data_path, exist_ok=True)
    os.makedirs(os.path.dirname(self.paths.summary_path), exist_ok=True)
def __init__(self, config: Config):
    config = config.overwrite(config[config.dataset])
    config = config.overwrite(Config.read_from_cli())
    config.exp_name = f'zsl_{config.dataset}_{config.hp.compute_hash()}_{config.random_seed}'

    if not config.get('silent'):
        print(config.hp)

    self.random = np.random.RandomState(config.random_seed)

    super().__init__(config)
def __init__(self, config: Config):
    config = config.overwrite(config.datasets[config.dataset])  # Overwriting with the dataset-dependent hyperparams
    config = config.overwrite(Config.read_from_cli())  # Overwriting with the CLI arguments
    config = config.overwrite(Config({'datasets': None}))  # So as not to pollute the logs

    super(GANTrainer, self).__init__(config)

    if self.is_distributed:
        torch.set_num_threads(4)
def create_res_config(self, block_idx: int) -> Config:
    increase_conf = self.config.hp.inr.res_increase_scheme
    num_blocks = self.config.hp.inr.num_blocks
    resolutions = self.generate_img_sizes(self.config.data.target_img_size)
    fourier_scale = np.linspace(increase_conf.fourier_scales.min,
                                increase_conf.fourier_scales.max,
                                num_blocks)[block_idx]
    dim = np.linspace(increase_conf.dims.max,
                      increase_conf.dims.min,
                      num_blocks).astype(int)[block_idx]
    num_coord_feats = np.linspace(increase_conf.num_coord_feats.max,
                                  increase_conf.num_coord_feats.min,
                                  num_blocks).astype(int)[block_idx]

    return Config({
        'resolution': resolutions[block_idx],
        'num_learnable_coord_feats': num_coord_feats.item(),
        'use_diag_feats': resolutions[block_idx] <= increase_conf.diag_feats_threshold,
        'max_num_fixed_coord_feats': 10000 if increase_conf.use_fixed_coord_feats else 0,
        'dim': dim.item(),
        'fourier_scale': fourier_scale.item(),
        'to_rgb': resolutions[block_idx] >= increase_conf.to_rgb_res_threshold,
        'n_layers': 1,
    })
def test_setter():
    assert Config({"a": 3}).a == 3
    assert Config({"a": 3, "b": {"c": 4}}).b.c == 4
    assert Config({"a.b": 3}).a.b == 3

    config = Config({"a.b": 3, "a.e": 4})
    config.set("a.c.d", 5)
    config.set("e.e", [1, 2, 3])

    assert config.a.b == 3
    assert config.a.e == 4
    assert config.a.c.d == 5
    assert config.e.e == tuple([1, 2, 3])
def __init__(self, config):
    # TODO: we should somehow say more loudly that we are reserving these properties.
    # Besides, some properties are vital for the user to define, but they have no idea about it :|
    # TODO: even I do not know all the options available in config :|
    if config.has('base_config'):
        self.config = Config.load(config.base_config)
        self.config.overwrite(config)
    else:
        self.config = config

    self._init_paths()

    # Reload config if we continue training
    if os.path.exists(self.paths.config_path):
        print(f'Detected existing config: {self.paths.config_path}. Loading it...')

        # A dirty hack that ensures that multiple trainers sync.
        # This is needed for a synced file system.
        # For some reason, portalocker does not work on a shared FS...
        time.sleep(1)

        self.config = Config.load(self.paths.config_path)
        self.config = self.config.overwrite(Config.read_from_cli())

    self._init_logger()
    self._init_devices()

    if self.is_main_process() and not os.path.exists(self.paths.config_path):
        self.config.save(self.paths.config_path)

    if not self.config.get('silent') and self.is_main_process():
        self.logger.info(f'Experiment directory: {self.paths.experiment_dir}')

    self._init_tb_writer()
    self._init_callbacks()
    self._init_checkpointing_strategy()
    self._init_validation_strategy()
    self._init_stopping_criteria()

    self.num_iters_done = 0
    self.num_epochs_done = 0
    self.is_explicitly_stopped = False
    self.train_dataloader = None
    self.val_dataloader = None
def run(config_path: str, tb_port: int = None):
    config = Config.load(config_path)
    config = config.overwrite(Config.read_from_cli())

    if config.get('distributed_training.enabled'):
        import horovod.torch as hvd
        hvd.init()
        fix_random_seed(config.random_seed + hvd.rank())
    else:
        fix_random_seed(config.random_seed)

    trainer = GANTrainer(config)

    if tb_port is not None and trainer.is_main_process():
        trainer.logger.info(f'Starting tensorboard on port {tb_port}')
        run_tensorboard(trainer.paths.experiment_dir, tb_port)

    trainer.start()
def run_validation_sequence(args: argparse.Namespace, config: Config):
    experiments_vals = generate_experiments_from_hpo_grid(config.validation_sequence.hpo_grid)
    experiments_vals = [{p.replace('|', '.'): v for p, v in exp.items()} for exp in experiments_vals]
    configs = [config.overwrite({'hp': Config(hp)}) for hp in experiments_vals]
    scores = []

    print(f'Number of random experiments: {len(configs)}')

    for i, c in enumerate(configs):
        print('<==== Running HPs ====>')
        print(experiments_vals[i])

        c = c.overwrite(Config({
            'experiments_dir': f'{config.experiments_dir}-val-seqs',
            'lll_setup.num_tasks': c.validation_sequence.num_tasks,
            'logging.save_train_logits': False,
            'logging.print_accuracy_after_task': False,
            'logging.print_unseen_accuracy': False,
            'logging.print_forgetting': False,
            'exp_name': compute_experiment_name(args, config.hp)
        }))

        trainer = LLLTrainer(c)
        trainer.start()

        if config.validation_sequence.metric == 'harmonic_mean':
            score = np.mean(trainer.compute_harmonic_mean_accuracy())
        elif config.validation_sequence.metric == 'final_task_wise_acc':
            score = np.mean(trainer.compute_final_tasks_performance())
        else:
            raise NotImplementedError('Unknown metric')

        scores.append(score)

    best_config = configs[np.argmax(scores)]
    print('Best found setup:', experiments_vals[np.argmax(scores)])
    print(best_config)

    best_config = best_config.overwrite(Config({'start_task': config.validation_sequence.num_tasks}))
    trainer = LLLTrainer(best_config)
    trainer.start()
def split_classes_for_tasks(config: Config, random_seed: int) -> List[List[int]]:
    """
    Splits classes into per-task groups and returns these splits.

    :param config: task-splitting config with `num_classes` and either
                   `task_sizes` or `num_tasks` + `num_classes_per_task`
    :param random_seed: seed for the random permutation of classes
    :return: a list of class-index lists, one per task
    """
    if config.has('task_sizes'):
        num_classes_to_use = sum(config.task_sizes)
    else:
        num_classes_to_use = config.num_tasks * config.num_classes_per_task

    if num_classes_to_use > config.num_classes:
        warnings.warn(f"We'll have duplicated classes: {num_classes_to_use} > {config.num_classes}")

    classes = np.arange(config.num_classes)
    classes = np.tile(classes, np.ceil(num_classes_to_use / len(classes)).astype(int))
    classes = np.random.RandomState(seed=random_seed).permutation(classes)[:num_classes_to_use]
    # classes = np.array([1,2,4,6,9,10,11,12,14,15,16,17,18,19,20,21,23,24,25,26,27,29,31,38,39,40,41,43,44,45,46,47,49,51,53,54,55,56,57,58,59,60,61,62,63,64,66,67,68,69,70,72,73,74,75,76,77,79,80,81,84,86,87,88,89,91,92,93,96,98,99,103,104,105,106,107,108,109,110,112,114,115,116,117,119,121,122,123,124,125,126,127,128,130,131,132,133,135,136,138,139,140,141,142,143,144,145,147,148,149,150,151,152,153,154,156,157,158,159,160,161,163,166,167,168,169,170,171,172,173,174,175,176,177,178,180,181,183,187,188,189,190,191,192,193,194,195,197,198,199,0,3,5,7,8,13,22,28,30,32,33,34,35,36,37,42,48,50,52,65,71,78,82,83,85,90,94,95,97,100,101,102,111,113,118,120,129,134,137,146,155,162,164,165,179,182,184,185,186,196])

    if config.has('task_sizes'):
        steps = flatten([[0], np.cumsum(config.task_sizes[:-1])])
        splits = [classes[c:c + size].tolist() for c, size in zip(steps, config.task_sizes)]
    else:
        splits = classes.reshape(config.num_tasks, config.num_classes_per_task)
        splits = splits.tolist()

    return splits
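# Hypothetical usage sketch (not from the source): with `task_sizes` the splits may be
# uneven; otherwise `num_tasks` x `num_classes_per_task` equally-sized splits are built.
lll_conf = Config({'num_classes': 10, 'task_sizes': [4, 3, 3]})
splits = split_classes_for_tasks(lll_conf, random_seed=42)
# `splits` is a list of 3 lists containing 4, 3 and 3 shuffled class indices respectively.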
def load_config(args: argparse.Namespace, config_cli_args: List[str]) -> Config:
    base_config = Config.load('configs/base.yml')
    curr_config = Config.load(f'configs/{args.config_name}.yml')

    # Setting properties from the base config
    config = base_config.all.clone()
    config = config.overwrite(base_config.get(args.dataset))

    # Setting properties from the current config
    config = config.overwrite(curr_config.all)
    config = config.overwrite(curr_config.get(args.dataset, Config({})))

    # Setting experiment-specific properties
    config.set('experiments_dir', args.experiments_dir)
    config.set('random_seed', args.random_seed)

    # Overwriting with CLI arguments
    config = config.overwrite(Config.read_from_cli())
    config.set('exp_name', compute_experiment_name(args, config.hp))

    return config
def construct_optimizer(model: nn.Module, optim_config: Config):
    name_to_cls = {
        'sgd': torch.optim.SGD,
        'adam': torch.optim.Adam,
        'rms_prop': torch.optim.RMSprop,
    }

    if optim_config.has('groups'):
        groups = [{'params': getattr(model, g).parameters(), **optim_config.groups.get(g)}
                  for g in sorted(optim_config.groups.keys())]
    else:
        groups = [{'params': model.parameters()}]

    return name_to_cls[optim_config.type](groups, **optim_config.kwargs)
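# Hypothetical usage sketch (not from the source): `type` selects the optimizer class
# and `kwargs` is forwarded to its constructor; no param groups are used here.
optim_config = Config({'type': 'adam', 'kwargs': {'lr': 3e-4}})
model = nn.Linear(128, 10)
optimizer = construct_optimizer(model, optim_config)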
def run_trainer(args: argparse.Namespace, config_args: List[str]):
    # TODO: read some stuff from the command line and overwrite the config
    config = Config.load('configs/densepose-rcnn.yml')

    if args.local_rank is not None:
        config.set('gpus', [args.local_rank])
    else:
        config.set('gpus', args.gpus)

    config.set('experiments_dir', args.experiments_dir)

    config_args = process_cli_config_args(config_args)
    config = config.overwrite(Config(config_args))  # Overwrite with CLI arguments

    trainer = DensePoseRCNNTrainer(config)

    if args.validate_only:
        print('Running validation only...')
        trainer.init()
        trainer.val_dataloader = trainer.train_dataloader
        trainer.validate()
    else:
        trainer.start()
def extract_data(summary_path: os.PathLike, logs_path: os.PathLike) -> Tuple[List[float], Dict, str, str]:
    config = Config.load(summary_path).config
    events_acc = EventAccumulator(logs_path)
    events_acc.Reload()

    _, _, val_acc_diffs = zip(*events_acc.Scalars('diff/val/acc'))

    hp = config.hp.to_dict()
    hp['n_conv_layers'] = len(config.hp.conv_model_config.conv_sizes)

    if 'Minimum_test' in events_acc.images.Keys():
        image_test = events_acc.Images('Minimum_test')[0].encoded_image_string
        image_train = events_acc.Images('Minimum_train')[0].encoded_image_string
    else:
        image_test = None
        image_train = None

    return val_acc_diffs, hp, image_test, image_train
def get_transform(config: Config) -> Callable:
    if config.name == 'mnist':
        return transforms.Compose([
            transforms.Resize(config.target_img_size),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5])
        ])
    elif config.name in {'cifar10', 'single_image'}:
        return transforms.Compose([
            transforms.Resize((config.target_img_size, config.target_img_size)),
            transforms.ToTensor(),
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
        ])
    elif config.name.startswith('lsun_') or config.name in {'ffhq_thumbs', 'celeba_thumbs', 'ffhq_256', 'ffhq_1024'}:
        if config.get('concat_patches.enabled'):
            return transforms.Compose([
                CenterCropToMin(),
                transforms.RandomHorizontalFlip(),
                PatchConcatAndResize(config.target_img_size, config.concat_patches.ratio),
                transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
            ])
        else:
            return transforms.Compose([
                CenterCropToMin(),
                transforms.RandomHorizontalFlip(),
                transforms.Resize(config.target_img_size),
                transforms.ToTensor(),
                transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
            ])
    elif config.name == 'imagenet_vs':
        return transforms.Compose([
            PadToSquare(),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True),
        ])
    else:
        raise NotImplementedError(f'Unknown dataset: {config.name}')
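# Hypothetical usage sketch (not from the source): the dataset-specific pipeline is
# selected by `name`, and `target_img_size` controls the Resize step.
from PIL import Image

mnist_transform = get_transform(Config({'name': 'mnist', 'target_img_size': 32}))
img_tensor = mnist_transform(Image.new('L', (28, 28)))  # -> normalized 1x32x32 tensor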
def __init__(self, config: Config):
    super(ConvModel, self).__init__()

    conv_sizes = config.get('conv_sizes', [1, 8, 32, 64])
    dense_sizes = config.get('dense_sizes', [576, 128])
    use_bn = config.get('use_bn', False)
    use_dropout = config.get('use_dropout', False)
    use_maxpool = config.get('use_maxpool', False)
    use_skip_connection = config.get('use_skip_connection', False)
    activation = config.get('activation', 'relu')
    adaptive_pool_size = config.get('adaptive_pool_size', (4, 4))

    if activation == 'relu':
        self.activation = lambda: nn.ReLU(inplace=True)
    elif activation == 'selu':
        self.activation = lambda: nn.SELU(inplace=True)
    elif activation == 'tanh':
        self.activation = lambda: nn.Tanh()
    elif activation == 'sigmoid':
        self.activation = lambda: nn.Sigmoid()
    else:
        raise NotImplementedError(f'Unknown activation function: {activation}')

    conv_body = nn.Sequential(*[
        ConvBlock(conv_sizes[i], conv_sizes[i + 1], use_bn, use_skip_connection,
                  use_maxpool, self.activation)
        for i in range(len(conv_sizes) - 1)
    ])
    dense_head = nn.Sequential(*[
        self._create_dense_block(dense_sizes[i], dense_sizes[i + 1], use_dropout)
        for i in range(len(dense_sizes) - 1)
    ])

    self.nn = nn.Sequential(
        conv_body,
        nn.AdaptiveAvgPool2d(adaptive_pool_size),
        Flatten(),
        dense_head,
        nn.Linear(dense_sizes[-1], 10))
type=str, default="stylegan2-ffhq-config-f.pt", help="path to the model checkpoint", ) parser.add_argument( "--channel_multiplier", type=int, default=2, help="channel multiplier of the generator. config-f = 2, else = 1", ) args, _ = parser.parse_known_args() args.latent = 512 args.n_mlp = 8 config = Config.load('config.yml') config = config.overwrite(Config.read_from_cli()) g_ema = Generator(config).to(device) checkpoint = torch.load(args.ckpt) g_ema.load_state_dict(checkpoint["g_ema"]) if args.truncation < 1: with torch.no_grad(): mean_latent = g_ema.mean_latent(args.truncation_mean) else: mean_latent = None generate(args, g_ema, device, mean_latent)
def test_overwrite():
    assert Config({}).overwrite(Config({"a": 3})).a == 3
    assert Config({"a": 2}).overwrite(Config({"a": 3})).a == 3
    assert Config({"b": 4}).overwrite(Config({"a": 3})).a == 3
    assert Config({"b": 4}).overwrite(Config({"a": 3})).b == 4
    assert Config({"a": {"b": 3}}).overwrite(Config({"b": 3})).b == 3
    assert Config({"a": {"b": 4}}).overwrite(Config({"b": 3})).a.b == 4
    assert Config({"a": {"b": 4}}).overwrite(Config({"a": {"c": 5}})).a.b == 4
    assert Config({"a": {"b": 4}}).overwrite(Config({"a": {"c": 5}})).a.c == 5
def test_getter():
    assert Config({}).get("a") == None
    assert Config({"a": 2}).get("a") == 2
    assert Config({"a": 2}).get("b", 3) == 3
    assert Config({"a": {"b": 4}}).get("a.b") == 4
    assert Config({"a": {"b": 4}}).get("a.c", 5) == 5
class BaseTrainer:
    def __init__(self, config):
        # TODO: we should somehow say more loudly that we are reserving these properties.
        # Besides, some properties are vital for the user to define, but they have no idea about it :|
        # TODO: even I do not know all the options available in config :|
        if config.has('base_config'):
            self.config = Config.load(config.base_config)
            self.config.overwrite(config)
        else:
            self.config = config

        self._init_paths()

        # Reload config if we continue training
        if os.path.exists(self.paths.config_path):
            print(f'Detected existing config: {self.paths.config_path}. Loading it...')

            # A dirty hack that ensures that multiple trainers sync.
            # This is needed for a synced file system.
            # For some reason, portalocker does not work on a shared FS...
            time.sleep(1)

            self.config = Config.load(self.paths.config_path)
            self.config = self.config.overwrite(Config.read_from_cli())

        self._init_logger()
        self._init_devices()

        if self.is_main_process() and not os.path.exists(self.paths.config_path):
            self.config.save(self.paths.config_path)

        if not self.config.get('silent') and self.is_main_process():
            self.logger.info(f'Experiment directory: {self.paths.experiment_dir}')

        self._init_tb_writer()
        self._init_callbacks()
        self._init_checkpointing_strategy()
        self._init_validation_strategy()
        self._init_stopping_criteria()

        self.num_iters_done = 0
        self.num_epochs_done = 0
        self.is_explicitly_stopped = False
        self.train_dataloader = None
        self.val_dataloader = None

    ############################
    ### Overwritable methods ###
    ############################
    def init_dataloaders(self):
        pass

    def init_models(self):
        pass

    def init_criterions(self):
        pass

    def init_optimizers(self):
        pass

    def train_on_batch(self, batch):
        pass

    def on_epoch_done(self):
        "Callback which is called when an epoch has been done"
        pass

    def validate(self):
        pass

    def is_main_process(self) -> bool:
        return is_main_process()

    #############
    ### Hooks ###
    #############
    def before_init_hook(self):
        pass

    def after_init_hook(self):
        pass

    def before_training_hook(self):
        pass

    def after_training_hook(self):
        pass

    def get_training_results(self) -> Dict:
        """
        Function which returns training results which are passed
        to summary generation after training is done
        """
        return {}

    ######################
    ### Public methods ###
    ######################
    def start(self):
        if len(self.gpus) > 0:
            with torch.cuda.device(self.gpus[0]):
                self._start()
        else:
            self._start()

    def stop(self, stopping_reason: str = ''):
        self.is_explicitly_stopped = True
        self._explicit_stopping_reason = stopping_reason

    def write_losses(self, losses: dict, prefix=''):
        """
        Iterates over losses and logs them with self.writer

        Arguments:
            - losses: dict of losses; each loss should be a scalar
        """
        for k in losses:
            self.writer.add_scalar(prefix + k, losses[k], self.num_iters_done)

    #######################
    ### Private methods ###
    #######################
    def init(self):
        # Initialization
        self.before_init_hook()
        self.init_dataloaders()
        self.init_models()
        self.init_criterions()
        self.init_optimizers()
        self._try_to_load_checkpoint()
        self.after_init_hook()

    def _start(self):
        self.init()

        # Training
        self.before_training_hook()
        self._run_training()
        self.after_training_hook()

        self.writer.close()

    def _run_training(self):
        try:
            while not self._should_stop():
                if self.config.get('logging.training_progress', True) and self.is_main_process():
                    batches = tqdm(self.train_dataloader)
                    self.logger.info('Running epoch #{}'.format(self.num_epochs_done + 1))
                else:
                    batches = self.train_dataloader

                for batch in batches:
                    self._set_train_mode()

                    if self.config.get('should_ignore_oom_batches', False):
                        safe_oom_call(self.train_on_batch, self.logger, batch,
                                      debug=self.config.get('debug_gpu'))
                    else:
                        self.train_on_batch(batch)

                    self.num_iters_done += 1

                    # Checkpointing the model BEFORE validation, since validation can halt :|
                    self._try_to_checkpoint()

                    if self.config.get('should_ignore_oom_batches', False):
                        safe_oom_call(self._try_to_validate, self.logger,
                                      debug=self.config.get('debug_gpu'))
                    else:
                        self._try_to_validate()

                    if self._should_stop():
                        break

                self.num_epochs_done += 1
                self.on_epoch_done()
        except Exception as e:
            self._terminate_experiment(str(e))
            raise

    def _try_to_validate(self):
        should_validate = False

        if self.val_freq_iters:
            should_validate = self.num_iters_done % self.val_freq_iters == 0
        elif self.val_freq_epochs:
            epoch_size = len(self.train_dataloader)
            was_epoch_just_finished = self.num_iters_done % epoch_size == 0
            # TODO: just use different callbacks for val_freq_epochs and val_freq_iters
            num_epochs_done = (self.num_epochs_done + 1) if was_epoch_just_finished else self.num_epochs_done
            is_epoch_appropriate = num_epochs_done % self.val_freq_epochs == 0
            should_validate = was_epoch_just_finished and is_epoch_appropriate

        if should_validate:
            self._set_eval_mode()

            # Validating without grad enabled (less memory consumption)
            with torch.no_grad():
                self.validate()

    def _try_to_checkpoint(self):
        # Checkpointing in non-main processes leads to subtle errors when loading the weights
        if not self.is_main_process() or self.config.get('no_saving'):
            return

        should_checkpoint = False

        if self.checkpoint_freq_iters:
            should_checkpoint = self.num_iters_done % self.checkpoint_freq_iters == 0
        elif self.checkpoint_freq_epochs:
            # TODO: this looks like crappy code
            epoch_size = len(self.train_dataloader)
            freq = self.checkpoint_freq_epochs * epoch_size
            should_checkpoint = self.num_iters_done % freq == 0

        if not should_checkpoint:
            return

        self.checkpoint()
        self._checkpoint_freq_warning()

    def checkpoint(self):
        # TODO: add max_num_checkpoints_to_store argument
        # We want to checkpoint right now!
        if not self.paths.has('checkpoints_path'):
            raise RuntimeError(
                'Tried to checkpoint, but no checkpoint path was specified. Cannot checkpoint. '
                'Provide either `paths.checkpoints_path` or `experiment_dir` in config.')

        overwrite = not self.config.checkpoint.get('separate_checkpoints')

        for module_name in self.config.get('checkpoint.modules', []):
            self._save_module_state(getattr(self, module_name), module_name, overwrite=overwrite)

        for pickle_attr in self.config.get('checkpoint.pickle', []):
            self._pickle(getattr(self, pickle_attr), pickle_attr, overwrite=overwrite)

        self._pickle({
            'num_iters_done': self.num_iters_done,
            'num_epochs_done': self.num_epochs_done
        }, 'training_state', overwrite=overwrite)

    def _checkpoint_freq_warning(self):
        """
        Prints a warning if we write checkpoints too often and they cost too much
        TODO: wip
        """
        pass

    def _try_to_load_checkpoint(self):
        """Loads model state from checkpoint if it is provided"""
        if not self.is_main_process():
            return  # We should read and broadcast the checkpoint

        if not os.path.isdir(self.paths.checkpoints_path):
            return

        checkpoints = [c for c in os.listdir(self.paths.checkpoints_path) if 'training_state' in c]

        if len(checkpoints) == 0:
            return

        checkpoints_iters = [int(c[len('training_state-'):-len('-pt')]) for c in checkpoints]
        latest_iter = sorted(checkpoints_iters)[-1]

        try:
            training_state = self._read_pickle_module('training_state', latest_iter)
        except FileNotFoundError:
            print('Could not load training state')
            return

        self.num_iters_done = training_state['num_iters_done']
        self.num_epochs_done = training_state['num_epochs_done']
        print(f'Continuing from iteration: {self.num_iters_done} ({self.num_epochs_done} epochs)')

        if self.config.checkpoint.get('separate_checkpoints'):
            continue_from_iter = self.num_iters_done
        else:
            continue_from_iter = None  # Since all of them are overwritten

        for module_name in self.config.checkpoint.modules:
            self._load_module_state(getattr(self, module_name), module_name, continue_from_iter)

        for module_name in self.config.get('checkpoint.pickle', []):
            self._unpickle(module_name, continue_from_iter)

    def _should_stop(self) -> bool:
        "Checks all stopping criteria"
        if (self.max_num_iters is not None) and (self.num_iters_done >= self.max_num_iters):
            self._terminate_experiment('Max num iters exceeded')
            return True

        if (self.max_num_epochs is not None) and (self.num_epochs_done >= self.max_num_epochs):
            self._terminate_experiment('Max num epochs exceeded')
            return True

        if self._should_early_stop():
            self._terminate_experiment('Early stopping')
            return True

        if self.is_explicitly_stopped:
            self._terminate_experiment(
                f'Stopped explicitly via the .stop() method. Reason: {self._explicit_stopping_reason}')
            return True

        return False

    def _should_early_stop(self):
        "Checks the early stopping criterion"
        if self.config.get('early_stopping') is None:
            return False

        history = self.losses[self.config.early_stopping.loss_name]
        n_steps = self.config.early_stopping.history_length
        should_decrease = self.config.early_stopping.should_decrease

        return not is_history_improving(history, n_steps, should_decrease)

    # TODO: we can gather modules automatically (via "isinstance")
    def _set_train_mode(self, flag: bool = True):
        """Switches all models into training mode"""
        for model_name in self.config.get('modules.models', []):
            getattr(self, model_name).train(flag)

    def _set_eval_mode(self):
        "Switches all models into evaluation mode"
        self._set_train_mode(False)

    def _save_module_state(self, module: nn.Module, name: str, overwrite: bool = True):
        suffix = '' if overwrite else f'-{self.num_iters_done}'
        file_name = f'{name}{suffix}.pt'
        module_path = os.path.join(self.paths.checkpoints_path, file_name)
        torch.save(module.state_dict(), module_path)

    def _load_module_state(self, module, name, iteration: int = None):
        suffix = '' if iteration is None else f'-{iteration}'
        file_name = f'{name}{suffix}.pt'
        module_path = os.path.join(self.paths.checkpoints_path, file_name)
        module.load_state_dict(torch.load(module_path))
        print(f'Loaded checkpoint: {module_path}')

    def _pickle(self, module, name, overwrite: bool = True):
        suffix = '' if overwrite else f'-{self.num_iters_done}'
        file_name = f'{name}{suffix}.pt'
        path = os.path.join(self.paths.checkpoints_path, file_name)
        pickle.dump(module, open(path, 'wb'))

    def _unpickle(self, name, iteration):
        setattr(self, name, self._read_pickle_module(name, iteration))

    def _read_pickle_module(self, name, iteration: int = None):
        suffix = '' if iteration is None else f'-{iteration}'
        file_name = f'{name}{suffix}.pt'
        path = os.path.join(self.paths.checkpoints_path, file_name)
        module = pickle.load(open(path, 'rb'))
        print(f'Loaded pickle module: {path}')
        return module

    def _terminate_experiment(self, termination_reason):
        if not self.is_main_process():
            return

        self.logger.info('Terminating experiment because [%s]' % termination_reason)
        self._write_summary(termination_reason)

    def _write_summary(self, termination_reason: str):
        if not self.is_main_process() or self.config.get('no_saving'):
            return

        if not self.paths.has('summary_path'):
            return

        summary = {
            'name': self.config.get('exp_name', 'unnamed'),
            'termination_reason': termination_reason,
            'num_iters_done': self.num_iters_done,
            'num_epochs_done': self.num_epochs_done,
            'config': self.config.to_dict(),
            'results': self.get_training_results()
        }

        with open(self.paths.summary_path, 'w') as f:
            yaml.safe_dump(summary, f, default_flow_style=False)

    ##############################
    ### Initialization methods ###
    ##############################
    def _init_logger(self):
        if self.config.has('exp_name'):
            self.logger = logging.getLogger(self.config.exp_name)
        else:
            # TODO: is it okay to use the class name?
            self.logger = logging.getLogger(self.__class__.__name__)
            self.logger.warn('You should provide an experiment name (by setting the "exp_name" attribute in config) '
                             'if you want the trainer logger to have a specific name.')

        coloredlogs.install(level=self.config.get('logging.level', 'DEBUG'), logger=self.logger)

    def _init_paths(self):
        experiment_dir = infer_new_experiment_path(
            self.config.get('experiment_dir'),
            self.config.get('exp_series_dir'),
            self.config.get('exp_name'))

        self.paths = Config({
            'experiment_dir': experiment_dir,
            'checkpoints_path': os.path.join(experiment_dir, 'checkpoints'),
            'summary_path': os.path.join(experiment_dir, 'summary.yml'),
            'config_path': os.path.join(experiment_dir, 'config.yml'),
            'logs_path': os.path.join(experiment_dir, 'logs'),
            'tb_images_path': os.path.join(experiment_dir, 'tb_images'),
            'custom_data_path': os.path.join(experiment_dir, 'custom_data'),
        })

        if self.config.get('no_saving'):
            return

        # Have to create all the paths by ourselves
        os.makedirs(self.paths.experiment_dir, exist_ok=True)
        os.makedirs(self.paths.checkpoints_path, exist_ok=True)
        os.makedirs(self.paths.logs_path, exist_ok=True)
        os.makedirs(self.paths.tb_images_path, exist_ok=True)
        os.makedirs(self.paths.custom_data_path, exist_ok=True)
        os.makedirs(os.path.dirname(self.paths.summary_path), exist_ok=True)

    def _init_tb_writer(self):
        if not self.is_main_process() or self.config.get('no_saving') or not self.paths.has('logs_path'):
            logger = self.logger

            # TODO: maybe we should just raise an exception?
            class DummyWriter:
                def __getattribute__(self, name):
                    dummy_fn = lambda *args, **kwargs: None
                    logger.warn('Tried to use tensorboard, but the tensorboard logs dir was not set. Nothing is written.')
                    return dummy_fn

            self.writer = DummyWriter()
            self.img_writer = DummyWriter()
        else:
            self.writer = SummaryWriter(self.paths.logs_path,
                                        flush_secs=self.config.get('logging.tb_flush_secs', 10))
            self.img_writer = SummaryWriter(self.paths.tb_images_path,
                                            flush_secs=self.config.get('logging.tb_flush_secs', 10))

    def _init_callbacks(self):
        self._on_iter_done_callbacks: List[Callable] = []
        self._on_epoch_done_callbacks: List[Callable] = []
        self._on_training_done_callbacks: List[Callable] = []

    def _init_checkpointing_strategy(self):
        if self.config.get('checkpoint'):
            self.checkpoint_freq_iters = self.config.checkpoint.get('freq_iters')
            self.checkpoint_freq_epochs = self.config.checkpoint.get('freq_epochs')

            if len(self.config.get('checkpoint.modules')) == 0:
                self.logger.warn('`checkpoint` config is specified, but no `modules` are provided. '
                                 'No torch modules to checkpoint!')

            if self.config.checkpoint.get('pickle'):
                assert type(self.config.checkpoint.pickle) is tuple
                self.logger.info(f'Will be checkpointing with pickle '
                                 f'the following modules: {self.config.checkpoint.pickle}')

            assert not (self.checkpoint_freq_iters and self.checkpoint_freq_epochs), """
                Can't save both on iters and epochs.
                Please, remove either freq_iters or freq_epochs
            """
        else:
            # TODO: this is crappy code :|
            self.checkpoint_freq_iters = None
            self.checkpoint_freq_epochs = None

    def _init_validation_strategy(self):
        self.val_freq_iters = self.config.get('val_freq_iters')
        self.val_freq_epochs = self.config.get('val_freq_epochs')

        assert not (self.val_freq_iters and self.val_freq_epochs), """
            Can't validate on both iters and epochs.
            Please, remove either val_freq_iters or val_freq_epochs
        """

    def _init_stopping_criteria(self):
        self.max_num_epochs = self.config.get('hp.max_num_epochs')
        self.max_num_iters = self.config.get('hp.max_num_iters')
        self.losses = {}

        if not (self.max_num_iters or self.max_num_epochs or self.config.has('early_stopping')):
            raise ValueError('You should set either `max_num_iters` or `max_num_epochs`')

    def _init_devices(self):
        assert not self.config.has('device_name'), \
            'FireLab detects and sets `device_name` for you. You influence it via `gpus`.'
        assert not hasattr(self, 'device_name'), 'You should not overwrite the "device_name" attribute in Trainer.'
        assert not hasattr(self, 'gpus'), 'You should not overwrite the "gpus" attribute in Trainer.'

        visible_gpus = list(range(torch.cuda.device_count()))
        self.is_distributed = self.config.get('distributed_training.enabled', False)

        if self.config.has('gpus'):
            self.gpus = self.config.gpus
        elif self.config.has('firelab.gpus'):
            self.gpus = self.config.firelab.gpus
        else:
            # TODO: maybe we should better take GPUs only when allowed?
            self.gpus = visible_gpus

            if not self.config.get('silent'):
                self.logger.warn(f'Attribute "gpus" was not set in config and '
                                 f'{len(visible_gpus)} GPUs were found. I am going to use them.')

        if self.is_distributed:
            import horovod.torch as hvd
            hvd.init()
            torch.cuda.device(hvd.local_rank())
            self.device_name = f'cuda:{hvd.local_rank()}'
            self.logger.info(f'My rank is: {hvd.local_rank()}')
        elif len(self.gpus) > 0:
            self.device_name = f'cuda:{self.gpus[0]}'
            torch.cuda.device(self.gpus[0])
        else:
            self.device_name = 'cpu'
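# Hypothetical usage sketch (not from the source): a minimal trainer built on top of
# BaseTrainer with a firelab-style Config as shown above. The toy dataset, the model and
# the `my_model` checkpoint module name are made up for illustration.
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset


class MyTrainer(BaseTrainer):
    def init_dataloaders(self):
        # A random toy classification dataset
        ds = TensorDataset(torch.randn(256, 8), torch.randint(0, 2, (256,)))
        self.train_dataloader = DataLoader(ds, batch_size=32)

    def init_models(self):
        self.my_model = nn.Linear(8, 2).to(self.device_name)

    def init_optimizers(self):
        self.optim = torch.optim.Adam(self.my_model.parameters(), lr=1e-3)

    def train_on_batch(self, batch):
        x, y = batch[0].to(self.device_name), batch[1].to(self.device_name)
        loss = F.cross_entropy(self.my_model(x), y)
        self.optim.zero_grad()
        loss.backward()
        self.optim.step()
        self.write_losses({'train/loss': loss.item()})


config = Config({
    'exp_name': 'toy_experiment',
    'hp': {'max_num_epochs': 3},
    'modules': {'models': ['my_model']},
    'checkpoint': {'freq_epochs': 1, 'modules': ['my_model']},
    'val_freq_epochs': 1,
})
MyTrainer(config).start()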
def load_data(config: Config, img_target_shape: Tuple[int, int] = None) -> Tuple[ImageDataset, ImageDataset, np.ndarray]:
    if config.name == 'CUB':
        ds_train = cub.load_dataset(config.dir, split='train', target_shape=img_target_shape,
                                    in_memory=config.get('in_memory', False))
        ds_test = cub.load_dataset(config.dir, split='test', target_shape=img_target_shape,
                                   in_memory=config.get('in_memory', False))
        class_attributes = cub.load_class_attributes(config.dir).astype(np.float32)
    elif config.name == 'CUB_EMBEDDINGS':
        ds_train = feats.load_dataset(config.dir, config.input_type, split='train')
        ds_test = feats.load_dataset(config.dir, config.input_type, split='test')
        class_attributes = cub.load_class_attributes(config.dir).astype(np.float32)
    elif config.name == 'AWA':
        ds_train = awa.load_dataset(config.dir, split='train', target_shape=img_target_shape)
        ds_test = awa.load_dataset(config.dir, split='test', target_shape=img_target_shape)
        class_attributes = awa.load_class_attributes(config.dir).astype(np.float32)
    elif config.name == 'SUN':
        ds_train = sun.load_dataset(config.dir, split='train', target_shape=img_target_shape)
        ds_test = sun.load_dataset(config.dir, split='val', target_shape=img_target_shape)
        class_attributes = sun.load_class_attributes(config.dir).astype(np.float32)
    elif config.name == 'TinyImageNet':
        ds_train = tiny_imagenet.load_dataset(config.dir, split='train', target_shape=img_target_shape)
        ds_test = tiny_imagenet.load_dataset(config.dir, split='val', target_shape=img_target_shape)
        class_attributes = None
    elif config.name in SIMPLE_LOADERS.keys():
        ds_train = SIMPLE_LOADERS[config.name](config.dir, split='train')
        ds_test = SIMPLE_LOADERS[config.name](config.dir, split='test')
        class_attributes = None
    elif config.name.endswith('EMBEDDINGS'):
        ds_train = feats.load_dataset(config.dir, config.input_type, split='train')
        ds_test = feats.load_dataset(config.dir, config.input_type, split='test')
        class_attributes = None
    else:
        raise NotImplementedError(f'Unknown dataset: {config.name}')

    # if embed_data:
    #     ds_train = extract_resnet_features_for_dataset(ds_train, input_type=18)
    #     ds_test = extract_resnet_features_for_dataset(ds_test, input_type=18)
    #     np.save(f'/tmp/{config.name}_train', ds_train)
    #     np.save(f'/tmp/{config.name}_test', ds_test)
    #     ds_train = np.load(f'/tmp/{config.name}_train.npy', allow_pickle=True)
    #     ds_test = np.load(f'/tmp/{config.name}_test.npy', allow_pickle=True)

    return ds_train, ds_test, class_attributes
f"experiments/{config.exp_name}/checkpoint/{str(i).zfill(6)}.pt" ) if __name__ == "__main__": device = "cuda" parser = argparse.ArgumentParser(description="StyleGAN2 trainer") parser.add_argument("--config", type=str, help="Path to the config") parser.add_argument("--local_rank", type=int, default=0, help="local rank for distributed training") args, _ = parser.parse_known_args() config = Config.load(args.config) config = config.overwrite(Config.read_from_cli()) n_gpu = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1 config.training.distributed = n_gpu > 1 if config.training.distributed: torch.cuda.set_device(args.local_rank) torch.distributed.init_process_group(backend="nccl", init_method="env://") synchronize() generator = Generator(config).to(device) discriminator = Discriminator(config).to(device) g_ema = Generator(config).to(device) g_ema.eval()