def main():
    logger = logging.getLogger(__name__)
    start_time = datetime.datetime.now()
    model_args, training_args = load_or_parse_args((ModelArgs, TrainingArgs),
                                                   verbose=True,
                                                   json_path=CONFIG_PATH)
    train_orig_df, label_enc = load_train_dataframe(
        training_args.data_train,
        min_class_samples=training_args.min_class_samples)

    # assert training_args.test_size % training_args.batch_size == 0, "Test size should be multiple of batch size"

    # TODO: split DFs once and keep those on the disk. Reload label_enc from disk on resume (see the sketch after this function).
    train_df, valid_df = split_dataframe_train_test(
        train_orig_df,
        test_size=training_args.test_size,
        stratify=train_orig_df.landmark_id,
        random_state=SEED)
    num_classes = (train_df.landmark_id.nunique()
                   if training_args.min_class_samples is None
                   else len(label_enc.classes_))
    logger.info(f'Num classes train: {num_classes}')
    logger.info(f'Num classes valid: {valid_df.landmark_id.nunique()}')

    logger.info('Initializing the model')
    model = LandmarkModel(model_name=model_args.model_name,
                          n_classes=num_classes,
                          loss_module=model_args.loss_module,
                          pooling_name=model_args.pooling_name,
                          args_pooling=model_args.args_pooling,
                          normalize=model_args.normalize,
                          use_fc=model_args.use_fc,
                          fc_dim=model_args.fc_dim,
                          dropout=model_args.dropout)
    logger.info("Model params:")
    logger.info(pformat(model_args))

    # save checkpoints
    training_args.checkpoints_dir.mkdir(exist_ok=True, parents=True)
    joblib.dump(label_enc,
                filename=training_args.checkpoints_dir /
                training_args.label_encoder_filename)
    logger.info(
        f'Persisted LabelEncoder to {training_args.label_encoder_filename}')
    save_config_checkpoint(training_args.checkpoints_dir,
                           json_path=CONFIG_PATH)

    # Stage 1 - train full model with low resolution
    stage1_start_time = datetime.datetime.now()

    lit_module = LandmarksPLBaseModule(
        hparams={**model_args.__dict__, **training_args.__dict__},
        model=model,
        optimizer=training_args.optimizer,
        loss=model_args.loss_module)
    # init data
    dm = LandmarksDataModule(
        train_df,
        valid_df,
        hparams=training_args,
        image_dir=training_args.data_path,
        batch_size=training_args.batch_size,
        num_workers=training_args.num_workers,
        use_weighted_sampler=training_args.use_weighted_sampler)
    # train
    dt_str = datetime.datetime.now().strftime("%y%m%d_%H-%M")
    wandb_logger = WandbLogger(
        name=f'{model_args.model_name.capitalize()}_GeM_ArcFace_{dt_str}',
        save_dir='logs/',
        project='landmarks',
        tags=['TPU'],
    )
    checkpoint_callback = ModelCheckpoint(monitor='val_acc',
                                          mode='max',
                                          save_top_k=2,
                                          save_last=True,
                                          verbose=True)
    # hack around to change only filename, not provide the full path (which is generated by W&B)
    checkpoint_callback.filename = '{epoch}-{val_acc:.3f}'
    early_stopping_callback = EarlyStopping('val_acc',
                                            verbose=True,
                                            mode='max')
    trainer = pl.Trainer(
        gpus=training_args.gpus,
        tpu_cores=training_args.tpu_cores,
        logger=wandb_logger,
        max_epochs=training_args.n_epochs,
        val_check_interval=training_args.val_check_interval,
        checkpoint_callback=checkpoint_callback,
        progress_bar_refresh_rate=100,
        resume_from_checkpoint=training_args.resume_checkpoint,
        gradient_clip_val=training_args.gradient_clip_val,
        accumulate_grad_batches=training_args.accumulate_grad_batches,
        early_stop_callback=early_stopping_callback,
        fast_dev_run=DEBUG_ENABLED,
        limit_train_batches=3,  # debug cap: only 3 training batches per epoch
        limit_val_batches=2)  # debug cap: only 2 validation batches per epoch
    trainer.fit(lit_module, datamodule=dm)
    try:
        training_args.checkpoints_dir = get_wandb_logger_checkpoints_path(
            wandb_logger)
        logger.info(
            f'Saving checkpoints to the current directory: {training_args.checkpoints_dir}'
        )
    except (NotADirectoryError, FileNotFoundError) as e:
        logger.warning(
            f'Unable to get current checkpoints directory, using default one: '
            f'{training_args.checkpoints_dir}')
        logger.debug('Could not resolve W&B checkpoints directory', exc_info=e)
    except Exception as e:
        logger.warning('Unknown error', exc_info=e)
    # save checkpoints (saved twice - in default directory above and in wandb current run folder)
    training_args.checkpoints_dir.mkdir(exist_ok=True, parents=True)
    joblib.dump(label_enc,
                filename=training_args.checkpoints_dir /
                training_args.label_encoder_filename)
    logger.info(
        f'Persisted LabelEncoder to {training_args.label_encoder_filename}')
    save_config_checkpoint(training_args.checkpoints_dir,
                           json_path=CONFIG_PATH)
    stage1_end_time = datetime.datetime.now()
    logger.info('Stage 1 duration: {}'.format(stage1_end_time -
                                              stage1_start_time))

    # Stage 2: Fine-tuning with frozen backbone on higher resolution
    # Change:
    # lr=0.01, image_size=512/crop_size=448 (-> DataLoader), margin=0.3, freeze_backbone

    model_args.margin = 0.3
    model_args.freeze_backbone = True
    training_args.data_path = "data/orig"
    training_args.lr = 0.01
    training_args.image_size = 512
    training_args.crop_size = 448
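    # NOTE: `freeze_backbone` is assumed to be consumed by LandmarksPLBaseModule / LandmarkModel.
    # If it is not, a minimal sketch of freezing the backbone at this point (the attribute name
    # `backbone` is an assumption, not confirmed by this script) would be:
    #     for param in model.backbone.parameters():
    #         param.requires_grad = False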

    lit_module = LandmarksPLBaseModule(
        hparams={**model_args.__dict__, **training_args.__dict__},
        model=model,
        optimizer=training_args.optimizer,
        loss=model_args.loss_module)

    dm = LandmarksDataModule(
        train_df,
        valid_df,
        hparams=training_args,
        image_dir=training_args.data_path,
        batch_size=training_args.batch_size,
        num_workers=training_args.num_workers,
        use_weighted_sampler=training_args.use_weighted_sampler)

    # Stage 2 run: reuses the Stage 1 trainer instance with the updated hparams and data module
    trainer.fit(lit_module, datamodule=dm)

    # Wrap-up
    end_time = datetime.datetime.now()
    logger.info('Training duration: {}'.format(end_time - start_time))
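

# Hedged sketch for the TODO above: persist the stratified split and the fitted LabelEncoder
# once, then reload them on resume instead of re-splitting. The function name, cache file names
# and the module-level imports of `pandas as pd` and `pathlib.Path` are assumptions, not part
# of the original project.
def save_or_load_split(train_df, valid_df, label_enc, cache_dir):
    """Cache the train/valid split and label encoder; reload them if the cache already exists."""
    cache_dir.mkdir(exist_ok=True, parents=True)
    train_path = cache_dir / 'train_split.csv'
    valid_path = cache_dir / 'valid_split.csv'
    enc_path = cache_dir / 'label_encoder.joblib'
    if train_path.exists() and valid_path.exists() and enc_path.exists():
        # resume: reuse the previously persisted artifacts
        return pd.read_csv(train_path), pd.read_csv(valid_path), joblib.load(enc_path)
    # first run: persist the freshly computed artifacts
    train_df.to_csv(train_path, index=False)
    valid_df.to_csv(valid_path, index=False)
    joblib.dump(label_enc, filename=enc_path)
    return train_df, valid_df, label_enc
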
Example #2
def main():
    """
    Use this class if anything in trainer checkpoint changed and only model weights are required to be preloaded.
    """
    logger = logging.getLogger(__name__)
    start_time = datetime.datetime.now()
    model_args, training_args = load_or_parse_args((ModelArgs, TrainingArgs), verbose=True)
    train_orig_df, label_enc = load_train_dataframe(training_args.data_train,
                                                    min_class_samples=training_args.min_class_samples)

    # assert training_args.test_size % training_args.batch_size == 0, "Test size should be multiple of batch size"

    # TODO: split DFs once and keep those on the disk. Reload label_enc from disk on resume.
    train_df, valid_df = split_dataframe_train_test(train_orig_df, test_size=training_args.test_size,
                                                    stratify=train_orig_df.landmark_id, random_state=SEED)
    num_classes = train_df.landmark_id.nunique() if training_args.min_class_samples is None else len(label_enc.classes_)
    logger.info(f'Num classes train: {num_classes}')
    logger.info(f'Num classes valid: {valid_df.landmark_id.nunique()}')

    # save checkpoints
    training_args.checkpoints_dir.mkdir(exist_ok=True, parents=True)
    joblib.dump(label_enc, filename=training_args.checkpoints_dir / training_args.label_encoder_filename)
    logger.info(f'Persisted LabelEncoder to {training_args.label_encoder_filename}')
    save_config_checkpoint(training_args.checkpoints_dir)

    logger.info('Initializing the model')
    model = LandmarkModel(model_name=model_args.model_name,
                          n_classes=num_classes,
                          loss_module=model_args.loss_module,
                          pooling_name=model_args.pooling_name,
                          args_pooling=model_args.args_pooling,
                          normalize=model_args.normalize,
                          use_fc=model_args.use_fc,
                          fc_dim=model_args.fc_dim,
                          dropout=model_args.dropout
                          )
    logger.info("Model params:")
    logger.info(pformat(model_args))
    model = load_model_state_from_checkpoint(net=model, checkpoint_path=training_args.resume_checkpoint)

    lit_module = LandmarksPLBaseModule(hparams=training_args.__dict__,
                                       model=model,
                                       optimizer=training_args.optimizer,
                                       loss=model_args.loss_module)
    # init data
    dm = LandmarksDataModule(train_df, valid_df,
                             hparams=training_args,
                             image_dir=training_args.data_path,
                             batch_size=training_args.batch_size,
                             num_workers=training_args.num_workers,
                             use_weighted_sampler=training_args.use_weighted_sampler
                             )
    # train
    dt_str = datetime.datetime.now().strftime("%y%m%d_%H-%M")
    wandb_logger = WandbLogger(name=f'{model_args.model_name.capitalize()}_GeM_ArcFace_{dt_str}',
                               save_dir='logs/',
                               project='landmarks')
    checkpoint_callback = ModelCheckpoint(monitor='val_acc',
                                          mode='max',
                                          save_top_k=2,
                                          save_last=True,
                                          verbose=True)
    # hack around to change only filename, not provide the full path (which is generated by W&B)
    checkpoint_callback.filename = '{epoch}-{val_acc:.3f}'

    early_stopping_callback = EarlyStopping('val_acc', verbose=True, mode='max')

    trainer = pl.Trainer(gpus=training_args.gpus,
                         logger=wandb_logger,
                         max_epochs=training_args.n_epochs,
                         val_check_interval=training_args.val_check_interval,
                         checkpoint_callback=checkpoint_callback,
                         progress_bar_refresh_rate=100,
                         gradient_clip_val=training_args.gradient_clip_val,
                         accumulate_grad_batches=training_args.accumulate_grad_batches,
                         early_stop_callback=early_stopping_callback
                         # fast_dev_run=True,
                         # limit_train_batches=5,
                         # limit_val_batches=5
                         )
    trainer.fit(lit_module, datamodule=dm)

    try:
        training_args.checkpoints_dir = get_wandb_logger_checkpoints_path(wandb_logger)
        logger.info(f'Saving checkpoints to the current directory: {training_args.checkpoints_dir}')
    except Exception as e:
        logger.warning(f'Unable to get current checkpoints directory, using default one: '
                       f'{training_args.checkpoints_dir}', exc_info=e)
    # save checkpoints (saved twice - in default directory above and in wandb current run folder)
    training_args.checkpoints_dir.mkdir(exist_ok=True, parents=True)
    joblib.dump(label_enc, filename=training_args.checkpoints_dir / training_args.label_encoder_filename)
    logger.info(f'Persisted LabelEncoder to {training_args.label_encoder_filename}')
    save_config_checkpoint(training_args.checkpoints_dir)

    end_time = datetime.datetime.now()
    logger.info('Duration: {}'.format(end_time - start_time))
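

# Hedged sketch of what `load_model_state_from_checkpoint` (used above) is assumed to do:
# read a PyTorch Lightning .ckpt, strip the LightningModule attribute prefix (assumed to be
# 'model.'), and load only the matching weights into the bare network. The name, the prefix
# and the module-level `torch` import are assumptions, not part of the original project.
def load_model_state_from_checkpoint_sketch(net, checkpoint_path, prefix='model.'):
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    # Lightning checkpoints store weights under 'state_dict'; fall back to a plain state dict
    state_dict = checkpoint.get('state_dict', checkpoint)
    # drop the LightningModule prefix so the keys match the bare nn.Module
    stripped = {k[len(prefix):]: v for k, v in state_dict.items() if k.startswith(prefix)}
    net.load_state_dict(stripped, strict=False)
    return net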