def main():
    logger = logging.getLogger(__name__)
    start_time = datetime.datetime.now()

    model_args, training_args = load_or_parse_args((ModelArgs, TrainingArgs), verbose=True, json_path=CONFIG_PATH)
    train_orig_df, label_enc = load_train_dataframe(training_args.data_train,
                                                    min_class_samples=training_args.min_class_samples)

    # assert training_args.test_size % training_args.batch_size == 0, "Test size should be multiple of batch size"
    # TODO: split DFs once and keep those on the disk. Reload label_enc from disk on resume.
    train_df, valid_df = split_dataframe_train_test(train_orig_df,
                                                    test_size=training_args.test_size,
                                                    stratify=train_orig_df.landmark_id,
                                                    random_state=SEED)
    num_classes = (train_df.landmark_id.nunique()
                   if training_args.min_class_samples is None
                   else len(label_enc.classes_))
    logger.info(f'Num classes train: {num_classes}')
    logger.info(f'Num classes valid: {valid_df.landmark_id.nunique()}')

    logger.info('Initializing the model')
    model = LandmarkModel(model_name=model_args.model_name,
                          n_classes=num_classes,
                          loss_module=model_args.loss_module,
                          pooling_name=model_args.pooling_name,
                          args_pooling=model_args.args_pooling,
                          normalize=model_args.normalize,
                          use_fc=model_args.use_fc,
                          fc_dim=model_args.fc_dim,
                          dropout=model_args.dropout)
    logger.info("Model params:")
    logger.info(pformat(model_args))

    # save checkpoints
    training_args.checkpoints_dir.mkdir(exist_ok=True, parents=True)
    joblib.dump(label_enc, filename=training_args.checkpoints_dir / training_args.label_encoder_filename)
    logger.info(f'Persisted LabelEncoder to {training_args.label_encoder_filename}')
    save_config_checkpoint(training_args.checkpoints_dir, json_path=CONFIG_PATH)

    # Stage 1 - train full model with low resolution
    stage1_start_time = datetime.datetime.now()
    lit_module = LandmarksPLBaseModule(hparams={**model_args.__dict__, **training_args.__dict__},
                                       model=model,
                                       optimizer=training_args.optimizer,
                                       loss=model_args.loss_module)
    # init data
    dm = LandmarksDataModule(train_df, valid_df,
                             hparams=training_args,
                             image_dir=training_args.data_path,
                             batch_size=training_args.batch_size,
                             num_workers=training_args.num_workers,
                             use_weighted_sampler=training_args.use_weighted_sampler)
    # train
    dt_str = datetime.datetime.now().strftime("%y%m%d_%H-%M")
    wandb_logger = WandbLogger(name=f'{model_args.model_name.capitalize()}_GeM_ArcFace_{dt_str}',
                               save_dir='logs/',
                               project='landmarks',
                               tags=['TPU'])
    checkpoint_callback = ModelCheckpoint(monitor='val_acc', mode='max', save_top_k=2, save_last=True, verbose=True)
    # hack around to change only the filename, not provide the full path (which is generated by W&B)
    checkpoint_callback.filename = '{epoch}-{val_acc:.3f}'
    early_stopping_callback = EarlyStopping('val_acc', verbose=True, mode='max')

    trainer = pl.Trainer(gpus=training_args.gpus,
                         tpu_cores=training_args.tpu_cores,
                         logger=wandb_logger,
                         max_epochs=training_args.n_epochs,
                         val_check_interval=training_args.val_check_interval,
                         checkpoint_callback=checkpoint_callback,
                         progress_bar_refresh_rate=100,
                         resume_from_checkpoint=training_args.resume_checkpoint,
                         gradient_clip_val=training_args.gradient_clip_val,
                         accumulate_grad_batches=training_args.accumulate_grad_batches,
                         early_stop_callback=early_stopping_callback,
                         fast_dev_run=DEBUG_ENABLED,
                         # caps on train/val batches per epoch (handy for quick debug runs)
                         limit_train_batches=3,
                         limit_val_batches=2)
    trainer.fit(lit_module, datamodule=dm)

    try:
        training_args.checkpoints_dir = get_wandb_logger_checkpoints_path(wandb_logger)
        logger.info(f'Saving checkpoints to the current directory: {training_args.checkpoints_dir}')
    except (NotADirectoryError, FileNotFoundError) as e:
        logger.warning(f'Unable to get current checkpoints directory, using default one: '
                       f'{training_args.checkpoints_dir}')
        logger.debug('Failed to resolve W&B checkpoints directory', exc_info=e)
    except Exception as e:
        logger.warning('Unknown error', exc_info=e)

    # save checkpoints (saved twice - in default directory above and in wandb current run folder)
    training_args.checkpoints_dir.mkdir(exist_ok=True, parents=True)
    joblib.dump(label_enc, filename=training_args.checkpoints_dir / training_args.label_encoder_filename)
    logger.info(f'Persisted LabelEncoder to {training_args.label_encoder_filename}')
    save_config_checkpoint(training_args.checkpoints_dir, json_path=CONFIG_PATH)

    stage1_end_time = datetime.datetime.now()
    logger.info('Stage 1 duration: {}'.format(stage1_end_time - stage1_start_time))

    # Stage 2: fine-tuning with a frozen backbone on higher resolution.
    # Changes: lr=0.01, image_size=512 / crop_size=448 (-> DataLoader), margin=0.3, freeze_backbone
    model_args.margin = 0.3
    model_args.freeze_backbone = True
    training_args.data_path = "data/orig"
    training_args.lr = 0.01
    training_args.image_size = 512
    training_args.crop_size = 448

    lit_module = LandmarksPLBaseModule(hparams={**model_args.__dict__, **training_args.__dict__},
                                       model=model,
                                       optimizer=training_args.optimizer,
                                       loss=model_args.loss_module)
    dm = LandmarksDataModule(train_df, valid_df,
                             hparams=training_args,
                             image_dir=training_args.data_path,
                             batch_size=training_args.batch_size,
                             num_workers=training_args.num_workers,
                             use_weighted_sampler=training_args.use_weighted_sampler)
    trainer.fit(lit_module, datamodule=dm)

    # Wrap-up
    end_time = datetime.datetime.now()
    logger.info('Training duration: {}'.format(end_time - start_time))
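
# A minimal sketch (illustrative, not part of the original pipeline) of how the Stage 2
# backbone freeze could be applied explicitly. It assumes LandmarkModel exposes its CNN
# trunk as `model.backbone` (hypothetical attribute name); in the pipeline above the
# freeze is expected to be driven by the `freeze_backbone` hparam inside LandmarksPLBaseModule.
def freeze_backbone_params(model):
    """Disable gradients for the backbone so only the pooling/FC/margin head is fine-tuned."""
    for param in model.backbone.parameters():
        param.requires_grad = False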
def main(): """ Use this class if anything in trainer checkpoint changed and only model weights are required to be preloaded. """ logger = logging.getLogger(__name__) start_time = datetime.datetime.now() model_args, training_args = load_or_parse_args((ModelArgs, TrainingArgs), verbose=True) train_orig_df, label_enc = load_train_dataframe(training_args.data_train, min_class_samples=training_args.min_class_samples) # assert training_args.test_size % training_args.batch_size == 0, "Test size should be multiple of batch size" # TODO: split DFs once and keep those on the disk. Reload label_enc from disk on resume. train_df, valid_df = split_dataframe_train_test(train_orig_df, test_size=training_args.test_size, stratify=train_orig_df.landmark_id, random_state=SEED) num_classes = train_df.landmark_id.nunique() if training_args.min_class_samples is None else len(label_enc.classes_) logger.info(f'Num classes train: {num_classes}') logger.info(f'Num classes valid: {valid_df.landmark_id.nunique()}') # save checkpoints training_args.checkpoints_dir.mkdir(exist_ok=True, parents=True) joblib.dump(label_enc, filename=training_args.checkpoints_dir / training_args.label_encoder_filename) logger.info(f'Persisted LabelEncoder to {training_args.label_encoder_filename}') save_config_checkpoint(training_args.checkpoints_dir) logger.info('Initializing the model') model = LandmarkModel(model_name=model_args.model_name, n_classes=num_classes, loss_module=model_args.loss_module, pooling_name=model_args.pooling_name, args_pooling=model_args.args_pooling, normalize=model_args.normalize, use_fc=model_args.use_fc, fc_dim=model_args.fc_dim, dropout=model_args.dropout ) logger.info("Model params:") logger.info(pformat(model_args)) model = load_model_state_from_checkpoint(net=model, checkpoint_path=training_args.resume_checkpoint) lit_module = LandmarksPLBaseModule(hparams=training_args.__dict__, model=model, optimizer=training_args.optimizer, loss=model_args.loss_module) # init data dm = LandmarksDataModule(train_df, valid_df, hparams=training_args, image_dir=training_args.data_path, batch_size=training_args.batch_size, num_workers=training_args.num_workers, use_weighted_sampler=training_args.use_weighted_sampler ) # train dt_str = datetime.datetime.now().strftime("%y%m%d_%H-%M") wandb_logger = WandbLogger(name=f'{model_args.model_name.capitalize()}_GeM_ArcFace_{dt_str}', save_dir='logs/', project='landmarks') checkpoint_callback = ModelCheckpoint(monitor='val_acc', mode='max', save_top_k=2, save_last=True, verbose=True) # hack around to change only filename, not provide the full path (which is generated by W&B) checkpoint_callback.filename = '{epoch}-{val_acc:.3f}' early_stopping_callback = EarlyStopping('val_acc', verbose=True, mode='max') trainer = pl.Trainer(gpus=training_args.gpus, logger=wandb_logger, max_epochs=training_args.n_epochs, val_check_interval=training_args.val_check_interval, checkpoint_callback=checkpoint_callback, progress_bar_refresh_rate=100, gradient_clip_val=training_args.gradient_clip_val, accumulate_grad_batches=training_args.accumulate_grad_batches, early_stop_callback=early_stopping_callback # fast_dev_run=True, # limit_train_batches=5, # limit_val_batches=5 ) trainer.fit(lit_module, datamodule=dm) try: training_args.checkpoints_dir = get_wandb_logger_checkpoints_path(wandb_logger) logger.info(f'Saving checkpoints to the current directory: {training_args.checkpoints_dir}') except: logger.warning(f'Unable to get current checkpoints directory, using default one: ' 
                       f'{training_args.checkpoints_dir}')

    # save checkpoints (saved twice - in default directory above and in wandb current run folder)
    training_args.checkpoints_dir.mkdir(exist_ok=True, parents=True)
    joblib.dump(label_enc, filename=training_args.checkpoints_dir / training_args.label_encoder_filename)
    logger.info(f'Persisted LabelEncoder to {training_args.label_encoder_filename}')
    save_config_checkpoint(training_args.checkpoints_dir)

    end_time = datetime.datetime.now()
    logger.info('Duration: {}'.format(end_time - start_time))
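
# Minimal entry-point sketch, assuming this module is run directly as a script.
# The basic logging configuration here is an assumption; the project may configure
# logging elsewhere (e.g. via a shared config).
if __name__ == '__main__':
    logging.basicConfig(level=logging.INFO)
    main()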