Example #1
def __init__(self, config: Config):
    # Overwrite with the dataset-specific config, then with the CLI arguments
    config = config.overwrite(config[config.dataset])
    config = config.overwrite(Config.read_from_cli())
    config.exp_name = f'zsl_{config.dataset}_{config.hp.compute_hash()}_{config.random_seed}'
    if not config.get('silent'):
        print(config.hp)
    # Dedicated RNG so the experiment is reproducible from random_seed
    self.random = np.random.RandomState(config.random_seed)
    super().__init__(config)
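All of these snippets layer configs via overwrite. The library's own source is not shown on this page, so the class below is an illustrative assumption rather than the actual implementation: a minimal sketch in which values from the argument take precedence over the receiver's and a new Config is returned.

class Config:
    def __init__(self, data: dict):
        self._data = dict(data)

    def overwrite(self, other: 'Config') -> 'Config':
        # Later layers win: start from our values, apply the other's on top
        merged = dict(self._data)
        for key, value in other._data.items():
            if isinstance(value, dict) and isinstance(merged.get(key), dict):
                merged[key] = {**merged[key], **value}  # shallow merge of nested dicts
            else:
                merged[key] = value
        return Config(merged)

    def get(self, key, default=None):
        return self._data.get(key, default)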
Example #2
    def __init__(self, config: Config):
        # Overwriting with the dataset-dependent hyperparams
        config = config.overwrite(config.datasets[config.dataset])
        # Overwriting with the CLI arguments
        config = config.overwrite(Config.read_from_cli())
        # So as not to pollute the logs
        config = config.overwrite(Config({'datasets': None}))

        super(GANTrainer, self).__init__(config)

        if self.is_distributed:
            torch.set_num_threads(4)
Example #3
    def __init__(self, config):
        # TODO: we should somehow state more loudly that we are reserving these properties.
        # Besides, some properties are vital for the user to define, yet they have no idea about it :|
        # TODO: even I do not know all the options available in the config :|
        if config.has('base_config'):
            self.config = Config.load(config.base_config)
            self.config.overwrite(config)
        else:
            self.config = config

        self._init_paths()

        # Reload config if we continue training
        if os.path.exists(self.paths.config_path):
            print(f'Detected existing config: {self.paths.config_path}. Loading it...')
            # A dirty hack that ensures that multiple trainers sync
            # This is needed for a synced file system
            # For some reason, portalocker does not work on a shared FS...
            time.sleep(1)
            self.config = Config.load(self.paths.config_path)
            self.config = self.config.overwrite(Config.read_from_cli())

        self._init_logger()
        self._init_devices()

        if self.is_main_process() and not os.path.exists(self.paths.config_path):
            self.config.save(self.paths.config_path)

        if not self.config.get('silent') and self.is_main_process():
            self.logger.info(f'Experiment directory: {self.paths.experiment_dir}')

        self._init_tb_writer()
        self._init_callbacks()
        self._init_checkpointing_strategy()
        self._init_validation_strategy()
        self._init_stopping_criteria()

        self.num_iters_done = 0
        self.num_epochs_done = 0
        self.is_explicitly_stopped = False
        self.train_dataloader = None
        self.val_dataloader = None
Example #4
from typing import Optional

def run(config_path: str, tb_port: Optional[int] = None):
    config = Config.load(config_path)
    config = config.overwrite(Config.read_from_cli())

    if config.get('distributed_training.enabled'):
        import horovod.torch as hvd
        hvd.init()
        fix_random_seed(config.random_seed + hvd.rank())
    else:
        fix_random_seed(config.random_seed)

    trainer = GANTrainer(config)

    if tb_port is not None and trainer.is_main_process():
        trainer.logger.info(f'Starting tensorboard on port {tb_port}')
        run_tensorboard(trainer.paths.experiment_dir, tb_port)

    trainer.start()
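fix_random_seed is not defined in this snippet; the sketch below is a hypothetical reconstruction of the usual pattern, not the project's actual helper. Offsetting the seed by hvd.rank(), as above, gives every distributed worker a distinct yet reproducible random stream.

import random

import numpy as np
import torch

def fix_random_seed(seed: int) -> None:
    # Hypothetical helper: seed all commonly used RNGs for reproducibility
    random.seed(seed)                 # Python's built-in RNG
    np.random.seed(seed)              # NumPy's global RNG
    torch.manual_seed(seed)           # PyTorch CPU (and current-device CUDA) RNG
    torch.cuda.manual_seed_all(seed)  # all CUDA devices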
Example #5
def load_config(args: argparse.Namespace, config_cli_args: List[str]) -> Config:
    base_config = Config.load('configs/base.yml')
    curr_config = Config.load(f'configs/{args.config_name}.yml')

    # Setting properties from the base config
    config = base_config.all.clone()
    config = config.overwrite(base_config.get(args.dataset))

    # Setting properties from the current config
    config = config.overwrite(curr_config.all)
    config = config.overwrite(curr_config.get(args.dataset, Config({})))

    # Setting experiment-specific properties
    config.set('experiments_dir', args.experiments_dir)
    config.set('random_seed', args.random_seed)

    # Overwriting with CLI arguments
    config = config.overwrite(Config.read_from_cli())
    config.set('exp_name', compute_experiment_name(args, config.hp))

    return config
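Note the precedence this loader establishes, lowest to highest: the base config, its dataset-specific section, the current config, its dataset-specific section, and finally the CLI arguments; each successive overwrite call wins over the layers before it.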
Example #6
    parser.add_argument(
        "--ckpt",
        type=str,
        default="stylegan2-ffhq-config-f.pt",
        help="path to the model checkpoint",
    )
    parser.add_argument(
        "--channel_multiplier",
        type=int,
        default=2,
        help="channel multiplier of the generator. config-f = 2, else = 1",
    )

    args, _ = parser.parse_known_args()

    args.latent = 512
    args.n_mlp = 8
    config = Config.load('config.yml')
    config = config.overwrite(Config.read_from_cli())

    g_ema = Generator(config).to(device)
    checkpoint = torch.load(args.ckpt)

    g_ema.load_state_dict(checkpoint["g_ema"])

    if args.truncation < 1:
        with torch.no_grad():
            mean_latent = g_ema.mean_latent(args.truncation_mean)
    else:
        mean_latent = None

    generate(args, g_ema, device, mean_latent)
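Example #6 calls parser.parse_known_args(), which leaves any unrecognized tokens on the command line for Config.read_from_cli() to consume. The library's actual parsing rules are not shown here; a hypothetical sketch, assuming plain key=value override tokens, could look like this.

import sys

def read_from_cli() -> Config:
    # Hypothetical: treat leftover key=value tokens as config overrides
    overrides = {}
    for token in sys.argv[1:]:
        if '=' in token and not token.startswith('-'):
            key, value = token.split('=', 1)
            overrides[key] = value
    return Config(overrides)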