Example #1
def decrease_lr_in_optim_config(conf: Config, num_tasks_learnt: int) -> Config:
    """
    Creates a new optim config with a decreased LR
    """
    if num_tasks_learnt <= 0 or not conf.has('decrease_lr_coef'):
        return conf.clone()

    decrease_coef = conf.decrease_lr_coef**num_tasks_learnt

    # Updating LR in the main kwargs
    if conf.kwargs.has('lr'):
        target_lr = conf.kwargs.lr * decrease_coef
        conf = conf.overwrite({'kwargs': {'lr': target_lr}})

    if conf.has('groups'):
        groups_with_lr = [
            g for g in conf.groups.keys() if conf.groups[g].has('lr')
        ]
        conf = conf.overwrite({
            'groups': {
                g: conf.groups[g].overwrite({'lr': conf.groups[g].lr * decrease_coef})
                for g in groups_with_lr
            }
        })

    return conf
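
A minimal usage sketch (not from the original project), assuming an optimizer config laid out the way this function reads it (kwargs.lr, decrease_lr_coef and, optionally, groups):

optim_config = Config({
    'type': 'adam',
    'kwargs': {'lr': 1e-3},
    'decrease_lr_coef': 0.5,
})
# After two learnt tasks the LR becomes 1e-3 * 0.5 ** 2 = 2.5e-4
new_optim_config = decrease_lr_in_optim_config(optim_config, num_tasks_learnt=2)
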
Example #2
    def _init_paths(self):
        experiment_dir = infer_new_experiment_path(
            self.config.get('experiment_dir'),
            self.config.get('exp_series_dir'), self.config.get('exp_name'))

        self.paths = Config({
            'experiment_dir': experiment_dir,
            'checkpoints_path': os.path.join(experiment_dir, 'checkpoints'),
            'summary_path': os.path.join(experiment_dir, 'summary.yml'),
            'config_path': os.path.join(experiment_dir, 'config.yml'),
            'logs_path': os.path.join(experiment_dir, 'logs'),
            'tb_images_path': os.path.join(experiment_dir, 'tb_images'),
            'custom_data_path': os.path.join(experiment_dir, 'custom_data'),
        })

        if self.config.get('no_saving'): return

        # Have to create all the paths by ourselves
        os.makedirs(self.paths.experiment_dir, exist_ok=True)
        os.makedirs(self.paths.checkpoints_path, exist_ok=True)
        os.makedirs(self.paths.logs_path, exist_ok=True)
        os.makedirs(self.paths.tb_images_path, exist_ok=True)
        os.makedirs(self.paths.custom_data_path, exist_ok=True)
        os.makedirs(os.path.dirname(self.paths.summary_path), exist_ok=True)
Example #3
    def __init__(self, config: Config):
        config = config.overwrite(config[config.dataset])
        config = config.overwrite(Config.read_from_cli())
        config.exp_name = f'zsl_{config.dataset}_{config.hp.compute_hash()}_{config.random_seed}'
        if not config.get('silent'):
            print(config.hp)
        self.random = np.random.RandomState(config.random_seed)
        super().__init__(config)
Example #4
    def __init__(self, config: Config):
        config = config.overwrite(config.datasets[config.dataset])  # Overwriting with the dataset-dependent hyperparams
        config = config.overwrite(Config.read_from_cli())  # Overwriting with the CLI arguments
        config = config.overwrite(Config({'datasets': None}))  # So as not to pollute logs

        super(GANTrainer, self).__init__(config)

        if self.is_distributed:
            torch.set_num_threads(4)
Example #5
    def create_res_config(self, block_idx: int) -> Config:
        increase_conf = self.config.hp.inr.res_increase_scheme
        num_blocks = self.config.hp.inr.num_blocks
        resolutions = self.generate_img_sizes(self.config.data.target_img_size)
        fourier_scale = np.linspace(increase_conf.fourier_scales.min,
                                    increase_conf.fourier_scales.max,
                                    num_blocks)[block_idx]
        dim = np.linspace(increase_conf.dims.max, increase_conf.dims.min,
                          num_blocks).astype(int)[block_idx]
        num_coord_feats = np.linspace(increase_conf.num_coord_feats.max,
                                      increase_conf.num_coord_feats.min,
                                      num_blocks).astype(int)[block_idx]

        return Config({
            'resolution': resolutions[block_idx],
            'num_learnable_coord_feats': num_coord_feats.item(),
            'use_diag_feats': resolutions[block_idx] <= increase_conf.diag_feats_threshold,
            'max_num_fixed_coord_feats': 10000 if increase_conf.use_fixed_coord_feats else 0,
            'dim': dim.item(),
            'fourier_scale': fourier_scale.item(),
            'to_rgb': resolutions[block_idx] >= increase_conf.to_rgb_res_threshold,
            'n_layers': 1,
        })
Example #6
def test_setter():
    assert Config({"a": 3}).a == 3
    assert Config({"a": 3, "b": {"c": 4}}).b.c == 4
    assert Config({"a.b": 3}).a.b == 3

    config = Config({"a.b": 3, "a.e": 4})
    config.set("a.c.d", 5)
    config.set("e.e", [1, 2, 3])

    assert config.a.b == 3
    assert config.a.e == 4
    assert config.a.c.d == 5
    assert config.e.e == tuple([1, 2, 3])
Example #7
    def __init__(self, config):
        # TODO: we should somehow say more loudly that we are reserving these properties
        # Besides, some properties are vital for the user to define, yet they have no idea about it :|
        # TODO: even I do not know all the options available in config :|
        if config.has('base_config'):
            # `overwrite` returns a new Config, so the result must be reassigned
            self.config = Config.load(config.base_config).overwrite(config)
        else:
            self.config = config

        self._init_paths()

        # Reload config if we continue training
        if os.path.exists(self.paths.config_path):
            print(
                f'Detected existing config: {self.paths.config_path}. Loading it...'
            )
            # A dirty hack that ensures that multiple trainers sync
            # This is needed for a synced file system
            # For some reason, portalocker does not work on a shared FS...
            time.sleep(1)
            self.config = Config.load(self.paths.config_path)
            self.config = self.config.overwrite(Config.read_from_cli())

        self._init_logger()
        self._init_devices()

        if self.is_main_process() and not os.path.exists(
                self.paths.config_path):
            self.config.save(self.paths.config_path)

        if not self.config.get('silent') and self.is_main_process():
            self.logger.info(
                f'Experiment directory: {self.paths.experiment_dir}')

        self._init_tb_writer()
        self._init_callbacks()
        self._init_checkpointing_strategy()
        self._init_validation_strategy()
        self._init_stopping_criteria()

        self.num_iters_done = 0
        self.num_epochs_done = 0
        self.is_explicitly_stopped = False
        self.train_dataloader = None
        self.val_dataloader = None
Example #8
def run(config_path: str, tb_port: int = None):
    config = Config.load(config_path)
    config = config.overwrite(Config.read_from_cli())

    if config.get('distributed_training.enabled'):
        import horovod.torch as hvd
        hvd.init()
        fix_random_seed(config.random_seed + hvd.rank())
    else:
        fix_random_seed(config.random_seed)

    trainer = GANTrainer(config)

    if tb_port is not None and trainer.is_main_process():
        trainer.logger.info(f'Starting tensorboard on port {tb_port}')
        run_tensorboard(trainer.paths.experiment_dir, tb_port)

    trainer.start()
Example #9
def run_validation_sequence(args: argparse.Namespace, config: Config):
    experiments_vals = generate_experiments_from_hpo_grid(
        config.validation_sequence.hpo_grid)
    experiments_vals = [{p.replace('|', '.'): v
                         for p, v in exp.items()} for exp in experiments_vals]
    configs = [config.overwrite({'hp': Config(hp)}) for hp in experiments_vals]
    scores = []

    print(f'Number of random experiments: {len(configs)}')

    for i, c in enumerate(configs):
        print('<==== Running HPs ====>')
        print(experiments_vals[i])

        c = c.overwrite(
            Config({
                'experiments_dir': f'{config.experiments_dir}-val-seqs',
                'lll_setup.num_tasks': c.validation_sequence.num_tasks,
                'logging.save_train_logits': False,
                'logging.print_accuracy_after_task': False,
                'logging.print_unseen_accuracy': False,
                'logging.print_forgetting': False,
                'exp_name': compute_experiment_name(args, config.hp)
            }))
        trainer = LLLTrainer(c)
        trainer.start()

        if config.validation_sequence.metric == 'harmonic_mean':
            score = np.mean(trainer.compute_harmonic_mean_accuracy())
        elif config.validation_sequence.metric == 'final_task_wise_acc':
            score = np.mean(trainer.compute_final_tasks_performance())
        else:
            raise NotImplementedError(f'Unknown metric: {config.validation_sequence.metric}')

        scores.append(score)

    best_config = configs[np.argmax(scores)]
    print('Best found setup:', experiments_vals[np.argmax(scores)])
    print(best_config)

    best_config = best_config.overwrite(
        Config({'start_task': config.validation_sequence.num_tasks}))
    trainer = LLLTrainer(best_config)
    trainer.start()
Example #10
def split_classes_for_tasks(config: Config,
                            random_seed: int) -> List[List[int]]:
    """
    Splits classes into `num_tasks` groups and returns these splits

    :param num_classes:
    :param num_tasks:
    :param num_classes_per_task:
    :return:
    """

    if config.has('task_sizes'):
        num_classes_to_use = sum(config.task_sizes)
    else:
        num_classes_to_use = config.num_tasks * config.num_classes_per_task

    if num_classes_to_use > config.num_classes:
        warnings.warn(
            f"We'll have duplicated classes: {num_classes_to_use} > {config.num_classes}"
        )

    classes = np.arange(config.num_classes)
    classes = np.tile(classes,
                      np.ceil(num_classes_to_use / len(classes)).astype(int))
    classes = np.random.RandomState(
        seed=random_seed).permutation(classes)[:num_classes_to_use]

    if config.has('task_sizes'):
        steps = flatten([[0], np.cumsum(config.task_sizes[:-1])])
        splits = [
            classes[c:c + size].tolist()
            for c, size in zip(steps, config.task_sizes)
        ]
    else:
        splits = classes.reshape(config.num_tasks, config.num_classes_per_task)
        splits = splits.tolist()

    return splits
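
A hedged usage sketch: the keys mirror the ones read by split_classes_for_tasks above; the concrete numbers are illustrative.

lll_config = Config({'num_classes': 10, 'num_tasks': 5, 'num_classes_per_task': 2})
splits = split_classes_for_tasks(lll_config, random_seed=42)
# -> 5 lists of 2 class indices each, e.g. [[3, 7], [0, 9], ...]
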
Example #11
def load_config(args: argparse.Namespace,
                config_cli_args: List[str]) -> Config:
    base_config = Config.load('configs/base.yml')
    curr_config = Config.load(f'configs/{args.config_name}.yml')

    # Setting properties from the base config
    config = base_config.all.clone()
    config = config.overwrite(base_config.get(args.dataset))

    # Setting properties from the current config
    config = config.overwrite(curr_config.all)
    config = config.overwrite(curr_config.get(args.dataset, Config({})))

    # Setting experiment-specific properties
    config.set('experiments_dir', args.experiments_dir)
    config.set('random_seed', args.random_seed)

    # Overwriting with CLI arguments
    config = config.overwrite(Config.read_from_cli())
    config.set('exp_name', compute_experiment_name(args, config.hp))

    return config
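
A hedged usage sketch for load_config: the argparse fields mirror the attributes read above, and the config name, directory and seed are placeholders.

args = argparse.Namespace(config_name='baseline', dataset='cub',
                          experiments_dir='experiments', random_seed=42)
config = load_config(args, config_cli_args=[])
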
Example #12
def construct_optimizer(model: nn.Module, optim_config: Config):
    name_to_cls = {
        'sgd': torch.optim.SGD,
        'adam': torch.optim.Adam,
        'rms_prop': torch.optim.RMSprop
    }

    if optim_config.has('groups'):
        groups = [{
            'params': getattr(model, g).parameters(),
            **optim_config.groups.get(g)
        } for g in sorted(optim_config.groups.keys())]
    else:
        groups = [{'params': model.parameters()}]

    return name_to_cls[optim_config.type](groups, **optim_config.kwargs)
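
A hedged usage sketch (the optimizer config layout mirrors what construct_optimizer reads):

model = nn.Linear(16, 4)
optim_config = Config({'type': 'sgd', 'kwargs': {'lr': 0.01, 'momentum': 0.9}})
optimizer = construct_optimizer(model, optim_config)
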
Example #13
def run_trainer(args: argparse.Namespace, config_args: List[str]):
    # TODO: read some stuff from the command line and overwrite the config
    config = Config.load('configs/densepose-rcnn.yml')

    if args.local_rank is not None:
        config.set('gpus', [args.local_rank])
    else:
        config.set('gpus', args.gpus)

    config.set('experiments_dir', args.experiments_dir)

    config_args = process_cli_config_args(config_args)
    config = config.overwrite(Config(config_args))  # Overwrite with CLI arguments

    trainer = DensePoseRCNNTrainer(config)

    if args.validate_only:
        print('Running validation only...')
        trainer.init()
        trainer.val_dataloader = trainer.train_dataloader
        trainer.validate()
    else:
        trainer.start()
Example #14
def extract_data(summary_path: os.PathLike,
                 logs_path: os.PathLike) -> Tuple[List[float], Dict, str, str]:
    config = Config.load(summary_path).config
    events_acc = EventAccumulator(logs_path)
    events_acc.Reload()

    _, _, val_acc_diffs = zip(*events_acc.Scalars('diff/val/acc'))
    hp = config.hp.to_dict()
    hp['n_conv_layers'] = len(config.hp.conv_model_config.conv_sizes)

    if 'Minimum_test' in events_acc.images.Keys():
        image_test = events_acc.Images('Minimum_test')[0].encoded_image_string
        image_train = events_acc.Images(
            'Minimum_train')[0].encoded_image_string
    else:
        image_test = None
        image_train = None

    return val_acc_diffs, hp, image_test, image_train
Example #15
def get_transform(config: Config) -> Callable:
    if config.name == 'mnist':
        return transforms.Compose([
            transforms.Resize(config.target_img_size),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5])
        ])
    elif config.name in {'cifar10', 'single_image'}:
        return transforms.Compose([
            transforms.Resize(
                (config.target_img_size, config.target_img_size)),
            transforms.ToTensor(),
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
        ])
    elif config.name.startswith('lsun_') or config.name in {
            'ffhq_thumbs', 'celeba_thumbs', 'ffhq_256', 'ffhq_1024'
    }:
        if config.get('concat_patches.enabled'):
            return transforms.Compose([
                CenterCropToMin(),
                transforms.RandomHorizontalFlip(),
                PatchConcatAndResize(config.target_img_size,
                                     config.concat_patches.ratio),
                transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
            ])
        else:
            return transforms.Compose([
                CenterCropToMin(),
                transforms.RandomHorizontalFlip(),
                transforms.Resize(config.target_img_size),
                transforms.ToTensor(),
                transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
            ])
    elif config.name == 'imagenet_vs':
        return transforms.Compose([
            PadToSquare(),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5),
                                 inplace=True),
        ])
    else:
        raise NotImplementedError(f'Unknown dataset: {config.name}')
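
A hedged usage sketch: building the MNIST pipeline from a dataset sub-config (the field names follow the branches above).

mnist_transform = get_transform(Config({'name': 'mnist', 'target_img_size': 32}))
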
Example #16
    def __init__(self, config: Config):
        super(ConvModel, self).__init__()

        conv_sizes = config.get('conv_sizes', [1, 8, 32, 64])
        dense_sizes = config.get('dense_sizes', [576, 128])
        use_bn = config.get('use_bn', False)
        use_dropout = config.get('use_dropout', False)
        use_maxpool = config.get('use_maxpool', False)
        use_skip_connection = config.get('use_skip_connection', False)
        activation = config.get('activation', 'relu')
        adaptive_pool_size = config.get('adaptive_pool_size', (4, 4))

        if activation == 'relu':
            self.activation = lambda: nn.ReLU(inplace=True)
        elif activation == 'selu':
            self.activation = lambda: nn.SELU(inplace=True)
        elif activation == 'tanh':
            self.activation = lambda: nn.Tanh()
        elif activation == 'sigmoid':
            self.activation = lambda: nn.Sigmoid()
        else:
            raise NotImplementedError(
                f'Unknown activation function: {activation}')

        conv_body = nn.Sequential(*[
            ConvBlock(conv_sizes[i], conv_sizes[i + 1], use_bn,
                      use_skip_connection, use_maxpool, self.activation)
            for i in range(len(conv_sizes) - 1)
        ])
        dense_head = nn.Sequential(*[
            self._create_dense_block(dense_sizes[i], dense_sizes[i + 1], use_dropout)
            for i in range(len(dense_sizes) - 1)
        ])

        self.nn = nn.Sequential(conv_body,
                                nn.AdaptiveAvgPool2d(adaptive_pool_size),
                                Flatten(), dense_head,
                                nn.Linear(dense_sizes[-1], 10))
Example #17
        type=str,
        default="stylegan2-ffhq-config-f.pt",
        help="path to the model checkpoint",
    )
    parser.add_argument(
        "--channel_multiplier",
        type=int,
        default=2,
        help="channel multiplier of the generator. config-f = 2, else = 1",
    )

    args, _ = parser.parse_known_args()

    args.latent = 512
    args.n_mlp = 8
    config = Config.load('config.yml')
    config = config.overwrite(Config.read_from_cli())

    g_ema = Generator(config).to(device)
    checkpoint = torch.load(args.ckpt)

    g_ema.load_state_dict(checkpoint["g_ema"])

    if args.truncation < 1:
        with torch.no_grad():
            mean_latent = g_ema.mean_latent(args.truncation_mean)
    else:
        mean_latent = None

    generate(args, g_ema, device, mean_latent)
Example #18
def test_overwrite():
    assert Config({}).overwrite(Config({"a": 3})).a == 3
    assert Config({"a": 2}).overwrite(Config({"a": 3})).a == 3
    assert Config({"b": 4}).overwrite(Config({"a": 3})).a == 3
    assert Config({"b": 4}).overwrite(Config({"a": 3})).b == 4
    assert Config({"a": {"b": 3}}).overwrite(Config({"b": 3})).b == 3
    assert Config({"a": {"b": 4}}).overwrite(Config({"b": 3})).a.b == 4
    assert Config({"a": {"b": 4}}).overwrite(Config({"a": {"c": 5}})).a.b == 4
    assert Config({"a": {"b": 4}}).overwrite(Config({"a": {"c": 5}})).a.c == 5
Example #19
def test_getter():
    assert Config({}).get("a") == None
    assert Config({"a": 2}).get("a") == 2
    assert Config({"a": 2}).get("b", 3) == 3
    assert Config({"a": {"b": 4}}).get("a.b") == 4
    assert Config({"a": {"b": 4}}).get("a.c", 5) == 5
Example #20
class BaseTrainer:
    def __init__(self, config):
        # TODO: we should somehow say more loudly that we are reserving these properties
        # Besides, some properties are vital for the user to define, yet they have no idea about it :|
        # TODO: even I do not know all the options available in config :|
        if config.has('base_config'):
            # `overwrite` returns a new Config, so the result must be reassigned
            self.config = Config.load(config.base_config).overwrite(config)
        else:
            self.config = config

        self._init_paths()

        # Reload config if we continue training
        if os.path.exists(self.paths.config_path):
            print(
                f'Detected existing config: {self.paths.config_path}. Loading it...'
            )
            # A dirty hack that ensures that multiple trainers sync
            # This is needed for a synced file system
            # For some reason, portalocker does not work on a shared FS...
            time.sleep(1)
            self.config = Config.load(self.paths.config_path)
            self.config = self.config.overwrite(Config.read_from_cli())

        self._init_logger()
        self._init_devices()

        if self.is_main_process() and not os.path.exists(
                self.paths.config_path):
            self.config.save(self.paths.config_path)

        if not self.config.get('silent') and self.is_main_process():
            self.logger.info(
                f'Experiment directory: {self.paths.experiment_dir}')

        self._init_tb_writer()
        self._init_callbacks()
        self._init_checkpointing_strategy()
        self._init_validation_strategy()
        self._init_stopping_criteria()

        self.num_iters_done = 0
        self.num_epochs_done = 0
        self.is_explicitly_stopped = False
        self.train_dataloader = None
        self.val_dataloader = None

    ############################
    ### Overwritable methods ###
    ############################
    def init_dataloaders(self):
        pass

    def init_models(self):
        pass

    def init_criterions(self):
        pass

    def init_optimizers(self):
        pass

    def train_on_batch(self, batch):
        pass

    def on_epoch_done(self):
        "Callback which is called when epoch has beed done"
        pass

    def validate(self):
        pass

    def is_main_process(self) -> bool:
        return is_main_process()

    #############
    ### Hooks ###
    #############
    def before_init_hook(self):
        pass

    def after_init_hook(self):
        pass

    def before_training_hook(self):
        pass

    def after_training_hook(self):
        pass

    def get_training_results(self) -> Dict:
        """
        Function which returns training results which
        are passed to summary generation after training is done
        """
        return {}

    ######################
    ### Public methods ###
    ######################
    def start(self):
        if len(self.gpus) > 0:
            with torch.cuda.device(self.gpus[0]):
                self._start()
        else:
            self._start()

    def stop(self, stopping_reason: str = ''):
        self.is_explicitly_stopped = True
        self._explicit_stopping_reason = stopping_reason

    def write_losses(self, losses: dict, prefix=''):
        """
        Iterates over losses and logs them with self.writer
        Arguments:
            - losses: dict of losses; each loss should be a scalar
        """
        for k in losses:
            self.writer.add_scalar(prefix + k, losses[k], self.num_iters_done)

    #######################
    ### Private methods ###
    #######################
    def init(self):
        # Initialization
        self.before_init_hook()
        self.init_dataloaders()
        self.init_models()
        self.init_criterions()
        self.init_optimizers()
        self._try_to_load_checkpoint()
        self.after_init_hook()

    def _start(self):
        self.init()

        # Training
        self.before_training_hook()
        self._run_training()
        self.after_training_hook()
        self.writer.close()

    def _run_training(self):
        try:
            while not self._should_stop():
                if self.config.get('logging.training_progress',
                                   True) and self.is_main_process():
                    batches = tqdm(self.train_dataloader)

                    self.logger.info(
                        'Running epoch #{}'.format(self.num_epochs_done + 1))
                else:
                    batches = self.train_dataloader

                for batch in batches:
                    self._set_train_mode()

                    if self.config.get('should_ignore_oom_batches', False):
                        safe_oom_call(self.train_on_batch,
                                      self.logger,
                                      batch,
                                      debug=self.config.get('debug_gpu'))
                    else:
                        self.train_on_batch(batch)

                    self.num_iters_done += 1

                    # Checkpointing the model BEFORE validation, since validation can halt :|
                    self._try_to_checkpoint()

                    if self.config.get('should_ignore_oom_batches', False):
                        safe_oom_call(self._try_to_validate,
                                      self.logger,
                                      debug=self.config.get('debug_gpu'))
                    else:
                        self._try_to_validate()

                    if self._should_stop():
                        break

                self.num_epochs_done += 1
                self.on_epoch_done()
        except Exception as e:
            self._terminate_experiment(str(e))
            raise

    def _try_to_validate(self):
        should_validate = False

        if self.val_freq_iters:
            should_validate = self.num_iters_done % self.val_freq_iters == 0
        elif self.val_freq_epochs:
            epoch_size = len(self.train_dataloader)
            was_epoch_just_finished = self.num_iters_done % epoch_size == 0
            # TODO: just use different callbacks for val_freq_epochs and val_freq_iters
            num_epochs_done = (
                self.num_epochs_done +
                1) if was_epoch_just_finished else self.num_epochs_done
            is_epoch_appropriate = num_epochs_done % self.val_freq_epochs == 0
            should_validate = was_epoch_just_finished and is_epoch_appropriate

        if should_validate:
            self._set_eval_mode()

            # Validating without grad enabled (less memory consumption)
            with torch.no_grad():
                self.validate()

    def _try_to_checkpoint(self):
        # Checkpointing in non-main processes leads to subtle errors when loading the weights
        if not self.is_main_process() or self.config.get('no_saving'): return

        should_checkpoint = False

        if self.checkpoint_freq_iters:
            should_checkpoint = self.num_iters_done % self.checkpoint_freq_iters == 0
        elif self.checkpoint_freq_epochs:
            # TODO: looks like govnokod
            epoch_size = len(self.train_dataloader)
            freq = self.checkpoint_freq_epochs * epoch_size
            should_checkpoint = self.num_iters_done % freq == 0

        if not should_checkpoint:
            return

        self.checkpoint()
        self._checkpoint_freq_warning()

    def checkpoint(self):
        # TODO: add max_num_checkpoints_to_store argument
        # We want to checkpoint right now!
        if not self.paths.has('checkpoints_path'):
            raise RuntimeError(
                'Tried to checkpoint, but no checkpoint path was specified. Cannot checkpoint. '
                'Provide either `paths.checkpoints_path` or `experiment_dir` in config.')

        overwrite = not self.config.checkpoint.get('separate_checkpoints')

        for module_name in self.config.get('checkpoint.modules', []):
            self._save_module_state(getattr(self, module_name),
                                    module_name,
                                    overwrite=overwrite)

        for pickle_attr in self.config.get('checkpoint.pickle', []):
            self._pickle(getattr(self, pickle_attr),
                         pickle_attr,
                         overwrite=overwrite)

        self._pickle(
            {
                'num_iters_done': self.num_iters_done,
                'num_epochs_done': self.num_epochs_done
            },
            'training_state',
            overwrite=overwrite)

    def _checkpoint_freq_warning(self):
        """
        Prints warning if we write checkpoints too often and they cost too much
        TODO: wip
        """
        pass

    def _try_to_load_checkpoint(self):
        """Loads model state from checkpoint if it is provided"""
        if not self.is_main_process():
            return  # We should read and broadcast the checkpoint
        if not os.path.isdir(self.paths.checkpoints_path): return

        checkpoints = [
            c for c in os.listdir(self.paths.checkpoints_path)
            if 'training_state' in c
        ]
        if len(checkpoints) == 0:
            return
        checkpoints_iters = [
            int(c[len('training_state-'):-len('.pt')]) for c in checkpoints
        ]
        latest_iter = sorted(checkpoints_iters)[-1]

        try:
            training_state = self._read_pickle_module('training_state',
                                                      latest_iter)
        except FileNotFoundError:
            print('Could not load training state')
            return

        self.num_iters_done = training_state['num_iters_done']
        self.num_epochs_done = training_state['num_epochs_done']

        print(
            f'Continuing from iteration: {self.num_iters_done} ({self.num_epochs_done} epochs)'
        )

        if self.config.checkpoint.get('separate_checkpoints'):
            continue_from_iter = self.num_iters_done
        else:
            continue_from_iter = None  # Since all of them are overwritten

        for module_name in self.config.checkpoint.modules:
            self._load_module_state(getattr(self, module_name), module_name,
                                    continue_from_iter)

        for module_name in self.config.get('checkpoint.pickle', []):
            self._unpickle(module_name, continue_from_iter)

    def _should_stop(self) -> bool:
        "Checks all stopping criteria"
        if (self.max_num_iters is not None) and (self.num_iters_done >= self.max_num_iters):
            self._terminate_experiment('Max num iters exceeded')
            return True

        if (self.max_num_epochs is not None) and (self.num_epochs_done >= self.max_num_epochs):
            self._terminate_experiment('Max num epochs exceeded')
            return True

        if self._should_early_stop():
            self._terminate_experiment('Early stopping')
            return True

        if self.is_explicitly_stopped:
            self._terminate_experiment(
                f'Stopped explicitly via .stop() method. Reason: {self._explicit_stopping_reason}'
            )
            return True

        return False

    def _should_early_stop(self):
        "Checks early stopping criterion"
        if self.config.get('early_stopping') is None: return False

        history = self.losses[self.config.early_stopping.loss_name]
        n_steps = self.config.early_stopping.history_length
        should_decrease = self.config.early_stopping.should_decrease

        return not is_history_improving(history, n_steps, should_decrease)

    # TODO: we can gather modules automatically (via "isinstance")
    def _set_train_mode(self, flag: bool = True):
        """Switches all models into training mode"""

        for model_name in self.config.get('modules.models', []):
            getattr(self, model_name).train(flag)

    def _set_eval_mode(self):
        "Switches all models into evaluation mode"
        self._set_train_mode(False)

    def _save_module_state(self,
                           module: nn.Module,
                           name: str,
                           overwrite: bool = True):
        suffix = '' if overwrite else f'-{self.num_iters_done}'
        file_name = f'{name}{suffix}.pt'
        module_path = os.path.join(self.paths.checkpoints_path, file_name)
        torch.save(module.state_dict(), module_path)

    def _load_module_state(self, module, name, iteration: int = None):
        suffix = '' if iteration is None else f'-{iteration}'
        file_name = f'{name}{suffix}.pt'
        module_path = os.path.join(self.paths.checkpoints_path, file_name)
        module.load_state_dict(torch.load(module_path))
        print(f'Loaded checkpoint: {module_path}')

    def _pickle(self, module, name, overwrite: bool = True):
        suffix = '' if overwrite else f'-{self.num_iters_done}'
        file_name = f'{name}{suffix}.pt'
        path = os.path.join(self.paths.checkpoints_path, file_name)
        pickle.dump(module, open(path, 'wb'))

    def _unpickle(self, name, iteration):
        setattr(self, name, self._read_pickle_module(name, iteration))

    def _read_pickle_module(self, name, iteration: int = None):
        suffix = '' if iteration is None else f'-{iteration}'
        file_name = f'{name}{suffix}.pt'
        path = os.path.join(self.paths.checkpoints_path, file_name)
        module = pickle.load(open(path, 'rb'))

        print(f'Loaded pickle module: {path}')

        return module

    def _terminate_experiment(self, termination_reason):
        if not self.is_main_process(): return
        self.logger.info('Terminating experiment because [%s]' %
                         termination_reason)
        self._write_summary(termination_reason)

    def _write_summary(self, termination_reason: str):
        if not self.is_main_process() or self.config.get('no_saving'): return
        if not self.paths.has('summary_path'): return

        summary = {
            'name': self.config.get('exp_name', 'unnamed'),
            'termination_reason': termination_reason,
            'num_iters_done': self.num_iters_done,
            'num_epochs_done': self.num_epochs_done,
            'config': self.config.to_dict(),
            'results': self.get_training_results()
        }

        with open(self.paths.summary_path, 'w') as f:
            yaml.safe_dump(summary, f, default_flow_style=False)

    ##############################
    ### Initialization methods ###
    ##############################
    def _init_logger(self):
        if self.config.has('exp_name'):
            self.logger = logging.getLogger(self.config.exp_name)
        else:
            # TODO: is it okay to use class name?
            self.logger = logging.getLogger(self.__class__.__name__)
            self.logger.warn('You should provide an experiment name (by setting the "exp_name" attribute in config) ' \
                             'if you want the trainer logger to have a specific name.')

        coloredlogs.install(level=self.config.get('logging.level', 'DEBUG'),
                            logger=self.logger)

    def _init_paths(self):
        experiment_dir = infer_new_experiment_path(
            self.config.get('experiment_dir'),
            self.config.get('exp_series_dir'), self.config.get('exp_name'))

        self.paths = Config({
            'experiment_dir': experiment_dir,
            'checkpoints_path': os.path.join(experiment_dir, 'checkpoints'),
            'summary_path': os.path.join(experiment_dir, 'summary.yml'),
            'config_path': os.path.join(experiment_dir, 'config.yml'),
            'logs_path': os.path.join(experiment_dir, 'logs'),
            'tb_images_path': os.path.join(experiment_dir, 'tb_images'),
            'custom_data_path': os.path.join(experiment_dir, 'custom_data'),
        })

        if self.config.get('no_saving'): return

        # Have to create all the paths by ourselves
        os.makedirs(self.paths.experiment_dir, exist_ok=True)
        os.makedirs(self.paths.checkpoints_path, exist_ok=True)
        os.makedirs(self.paths.logs_path, exist_ok=True)
        os.makedirs(self.paths.tb_images_path, exist_ok=True)
        os.makedirs(self.paths.custom_data_path, exist_ok=True)
        os.makedirs(os.path.dirname(self.paths.summary_path), exist_ok=True)

    def _init_tb_writer(self):
        if not self.is_main_process() or self.config.get(
                'no_saving') or not self.paths.has('logs_path'):
            logger = self.logger

            # TODO: maybe we should just raise an exception?
            class DummyWriter:
                def __getattribute__(self, name):
                    dummy_fn = lambda *args, **kwargs: None
                    logger.warn(
                        'Tried to use tensorboard, but tensorboard logs dir was not set. Nothing is written.'
                    )
                    return dummy_fn

            self.writer = DummyWriter()
            self.img_writer = DummyWriter()
        else:
            self.writer = SummaryWriter(self.paths.logs_path,
                                        flush_secs=self.config.get(
                                            'logging.tb_flush_secs', 10))
            self.img_writer = SummaryWriter(self.paths.tb_images_path,
                                            flush_secs=self.config.get(
                                                'logging.tb_flush_secs', 10))

    def _init_callbacks(self):
        self._on_iter_done_callbacks: List[Callable] = []
        self._on_epoch_done_callbacks: List[Callable] = []
        self._on_training_done_callbacks: List[Callable] = []

    def _init_checkpointing_strategy(self):
        if self.config.get('checkpoint'):
            self.checkpoint_freq_iters = self.config.checkpoint.get(
                'freq_iters')
            self.checkpoint_freq_epochs = self.config.checkpoint.get(
                'freq_epochs')

            if len(self.config.get('checkpoint.modules', [])) == 0:
                self.logger.warn(
                    '`checkpoint` config is specified, but no `modules` are provided. '
                    'No torch modules to checkpoint!')

            if self.config.checkpoint.get('pickle'):
                assert type(self.config.checkpoint.pickle) is tuple
                self.logger.info(
                    f'Will be checkpointing with pickle ' \
                    f'the following modules: {self.config.checkpoint.pickle}')

            assert not (self.checkpoint_freq_iters
                        and self.checkpoint_freq_epochs), """
                Can't save both on iters and epochs.
                Please, remove either freq_iters or freq_epochs
            """
        else:
            # TODO: govnokod :|
            self.checkpoint_freq_iters = None
            self.checkpoint_freq_epochs = None

    def _init_validation_strategy(self):
        self.val_freq_iters = self.config.get('val_freq_iters')
        self.val_freq_epochs = self.config.get('val_freq_epochs')

        assert not (self.val_freq_iters and self.val_freq_epochs), """
            Can't validate on both iters and epochs.
            Please, remove either val_freq_iters or val_freq_epochs
        """

    def _init_stopping_criteria(self):
        self.max_num_epochs = self.config.get('hp.max_num_epochs')
        self.max_num_iters = self.config.get('hp.max_num_iters')
        self.losses = {}

        if not (self.max_num_iters or self.max_num_epochs
                or self.config.has('early_stopping')):
            raise ValueError(
                'You should set `max_num_iters`, `max_num_epochs` or `early_stopping`')

    def _init_devices(self):
        assert not self.config.has('device_name'), \
            'FireLab detects and sets `device_name` for you. You influence it via `gpus`.'
        assert not hasattr(
            self, 'device_name'
        ), 'You should not overwrite "device_name" attribute in Trainer.'
        assert not hasattr(
            self,
            'gpus'), 'You should not overwrite "gpus" attribute in Trainer.'

        visible_gpus = list(range(torch.cuda.device_count()))
        self.is_distributed = self.config.get('distributed_training.enabled',
                                              False)

        if self.config.has('gpus'):
            self.gpus = self.config.gpus
        elif self.config.has('firelab.gpus'):
            self.gpus = self.config.firelab.gpus
        else:
            # TODO: maybe we should better take GPUs only when allowed?
            self.gpus = visible_gpus
            if not self.config.get('silent'):
                self.logger.warn(
                    f'Attribute "gpus" was not set in config and '
                    f'{len(visible_gpus)} GPUs were found. Going to use all of them.')

        if self.is_distributed:
            import horovod.torch as hvd
            hvd.init()
            torch.cuda.set_device(hvd.local_rank())
            self.device_name = f'cuda:{hvd.local_rank()}'
            self.logger.info(f'My rank is: {hvd.local_rank()}')
        elif len(self.gpus) > 0:
            self.device_name = f'cuda:{self.gpus[0]}'
            torch.cuda.set_device(self.gpus[0])
        else:
            self.device_name = 'cpu'
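
A minimal subclass sketch, not from the original project: BaseTrainer drives the loop and the subclass only fills in the overwritable methods; the model, the toy "dataloader" and the config values are illustrative.

class MyTrainer(BaseTrainer):
    def init_dataloaders(self):
        # Any iterable of batches works for _run_training; a real project would use a DataLoader
        self.train_dataloader = [{'x': torch.randn(8, 16)} for _ in range(10)]

    def init_models(self):
        self.model = nn.Linear(16, 4).to(self.device_name)

    def init_optimizers(self):
        self.optim = torch.optim.Adam(self.model.parameters(), lr=1e-3)

    def train_on_batch(self, batch):
        loss = self.model(batch['x'].to(self.device_name)).pow(2).mean()
        self.optim.zero_grad()
        loss.backward()
        self.optim.step()
        self.write_losses({'loss': loss.item()}, prefix='train/')

trainer = MyTrainer(Config({'exp_name': 'demo', 'no_saving': True, 'hp': {'max_num_epochs': 1}}))
trainer.start()
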
Example #21
def load_data(
    config: Config,
    img_target_shape: Tuple[int, int] = None
) -> Tuple[ImageDataset, ImageDataset, np.ndarray]:
    if config.name == 'CUB':
        ds_train = cub.load_dataset(config.dir,
                                    split='train',
                                    target_shape=img_target_shape,
                                    in_memory=config.get('in_memory', False))
        ds_test = cub.load_dataset(config.dir,
                                   split='test',
                                   target_shape=img_target_shape,
                                   in_memory=config.get('in_memory', False))
        class_attributes = cub.load_class_attributes(config.dir).astype(
            np.float32)
    elif config.name == 'CUB_EMBEDDINGS':
        ds_train = feats.load_dataset(config.dir,
                                      config.input_type,
                                      split='train')
        ds_test = feats.load_dataset(config.dir,
                                     config.input_type,
                                     split='test')
        class_attributes = cub.load_class_attributes(config.dir).astype(
            np.float32)
    elif config.name == 'AWA':
        ds_train = awa.load_dataset(config.dir,
                                    split='train',
                                    target_shape=img_target_shape)
        ds_test = awa.load_dataset(config.dir,
                                   split='test',
                                   target_shape=img_target_shape)
        class_attributes = awa.load_class_attributes(config.dir).astype(
            np.float32)
    elif config.name == 'SUN':
        ds_train = sun.load_dataset(config.dir,
                                    split='train',
                                    target_shape=img_target_shape)
        ds_test = sun.load_dataset(config.dir,
                                   split='val',
                                   target_shape=img_target_shape)
        class_attributes = sun.load_class_attributes(config.dir).astype(
            np.float32)
    elif config.name == 'TinyImageNet':
        ds_train = tiny_imagenet.load_dataset(config.dir,
                                              split='train',
                                              target_shape=img_target_shape)
        ds_test = tiny_imagenet.load_dataset(config.dir,
                                             split='val',
                                             target_shape=img_target_shape)
        class_attributes = None
    elif config.name in SIMPLE_LOADERS.keys():
        ds_train = SIMPLE_LOADERS[config.name](config.dir, split='train')
        ds_test = SIMPLE_LOADERS[config.name](config.dir, split='test')
        class_attributes = None
    elif config.name.endswith('EMBEDDINGS'):
        ds_train = feats.load_dataset(config.dir,
                                      config.input_type,
                                      split='train')
        ds_test = feats.load_dataset(config.dir,
                                     config.input_type,
                                     split='test')
        class_attributes = None
    else:
        raise NotImplementedError(f'Unknown dataset: {config.name}')

    return ds_train, ds_test, class_attributes
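
A hedged usage sketch (the dataset directory is a placeholder):

ds_train, ds_test, class_attrs = load_data(Config({'name': 'CUB', 'dir': 'data/CUB'}),
                                           img_target_shape=(256, 256))
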
Example #22
                    f"experiments/{config.exp_name}/checkpoint/{str(i).zfill(6)}.pt"
                )


if __name__ == "__main__":
    device = "cuda"

    parser = argparse.ArgumentParser(description="StyleGAN2 trainer")

    parser.add_argument("--config", type=str, help="Path to the config")
    parser.add_argument("--local_rank",
                        type=int,
                        default=0,
                        help="local rank for distributed training")
    args, _ = parser.parse_known_args()
    config = Config.load(args.config)
    config = config.overwrite(Config.read_from_cli())

    n_gpu = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1
    config.training.distributed = n_gpu > 1

    if config.training.distributed:
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend="nccl",
                                             init_method="env://")
        synchronize()

    generator = Generator(config).to(device)
    discriminator = Discriminator(config).to(device)
    g_ema = Generator(config).to(device)
    g_ema.eval()