from pytorch_lightning.loggers.base import DummyLogger


def test_dummylogger_noop_method_calls():
    """Test that the DummyLogger methods can be called with arbitrary arguments."""
    logger = DummyLogger()
    logger.log_hyperparams("1", 2, three="three")
    logger.log_metrics("1", 2, three="three")
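
For context, a minimal sketch of the no-op pattern this test exercises (an illustrative stand-in, not PyTorch Lightning's actual DummyLogger implementation): every attribute lookup yields a method that accepts and discards arbitrary arguments.

class NoOpLogger:
    """Illustrative no-op logger: any method call succeeds and does nothing."""
    def __getattr__(self, name):
        def _noop(*args, **kwargs):
            pass
        return _noop

logger = NoOpLogger()
logger.log_hyperparams("1", 2, three="three")  # accepted and discarded
logger.log_metrics("1", 2, three="three")      # accepted and discarded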
Example #2

# Library imports used by this example; project-specific helpers (Models,
# model_to_litmodule, dataset_tokens_length, dataset_to_number_of_prototypes,
# dataset_to_separation_loss, get_n_splits, log_splits, DataVisualizer) come
# from the surrounding repository.
from configparser import ConfigParser
from copy import deepcopy
from pathlib import Path

import numpy as np
import pandas as pd
from easydict import EasyDict
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.callbacks import EarlyStopping, LearningRateMonitor, ModelCheckpoint
from pytorch_lightning.loggers import CometLogger
from pytorch_lightning.loggers.base import DummyLogger
from torchtext.vocab import GloVe

def train(**args):
    params = EasyDict(args)
    params.gpu = int(params.gpu)

    config = ConfigParser()
    config.read('config.ini')
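    # The sentinel value ['all'] expands to every supported dataset.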
    if params.datasets == ['all']:
        params.datasets = ['imdb', 'amazon', 'yelp', 'rottentomatoes', 'hotel']

    # Resolve hyperparameters that depend on the dataset: None or -1 means
    # "use the per-dataset default", looked up inside the loop below.
    is_tokenizer_length_dataset_specific = Models(params.model) == Models.distilbert and (
            params.tokenizer_length is None or params.tokenizer_length == -1)
    is_number_prototypes_dataset_specific = Models(params.model) == Models.protoconv and (
            params.pc_number_of_prototypes is None or params.pc_number_of_prototypes == -1)
    is_sep_loss_dataset_specific = Models(params.model) == Models.protoconv and (
            params.pc_sep_loss_weight is None or params.pc_sep_loss_weight == -1)
    is_ce_loss_dataset_specific = Models(params.model) == Models.protoconv and (
            params.pc_ce_loss_weight is None or params.pc_ce_loss_weight == -1)

    for dataset in params.datasets:
        params.data_set = dataset
        seed_everything(params.seed)

        if is_tokenizer_length_dataset_specific:
            params.tokenizer_length = dataset_tokens_length[params.data_set]

        if is_number_prototypes_dataset_specific:
            params.pc_number_of_prototypes = dataset_to_number_of_prototypes[params.data_set]

        if is_sep_loss_dataset_specific:
            params.pc_sep_loss_weight = dataset_to_separation_loss[params.data_set]

        if is_ce_loss_dataset_specific:
            weight = 1 - (params.pc_cls_loss_weight + params.pc_sep_loss_weight + params.pc_l1_loss_weight)
            assert weight > 0, f'Weight {weight} of the cross-entropy loss must be greater than 0'
            params.pc_ce_loss_weight = weight

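        # Default to a no-op logger; switch to Comet when params.logger is set.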
        logger = DummyLogger()
        if params.logger:
            comet_config = EasyDict(config['cometml'])
            project_name = params.project_name if params.project_name else comet_config.projectname
            logger = CometLogger(api_key=comet_config.apikey, project_name=project_name,
                                 workspace=comet_config.workspace)

        # logger.experiment.log_code(folder='src')
        logger.log_hyperparams(params)
        base_callbacks = [LearningRateMonitor(logging_interval='epoch')]

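        # Load the pre-tokenized dataset and build per-fold train/val/test index splits.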
        df_dataset = pd.read_csv(f'data/{params.data_set}/tokenized_data.csv')
        n_splits = get_n_splits(dataset=df_dataset, x_label='text', y_label='label', folds=params.fold)
        log_splits(n_splits, logger)

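        # DistilBERT ships its own learned embeddings; the other models use
        # pretrained GloVe vectors instead.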
        embeddings = GloVe('42B', cache=params.cache) if Models(params.model) != Models.distilbert else None

        best_models_scores, number_of_prototypes = [], []
        for fold_id, (train_index, val_index, test_index) in enumerate(n_splits):
            i = str(fold_id)

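            # Save the best checkpoint by per-fold validation accuracy; metric
            # names carry the fold index so folds don't overwrite each other.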
            model_checkpoint = ModelCheckpoint(
                filepath='checkpoints/fold_' + i + '_{epoch:02d}-{val_loss_' + i + ':.4f}-{val_acc_' + i + ':.4f}',
                save_weights_only=True, save_top_k=1, monitor='val_acc_' + i,
                period=params.pc_project_prototypes_every_n
            )
            early_stop = EarlyStopping(monitor=f'val_loss_{i}', patience=10, verbose=True, mode='min', min_delta=0.005)
            callbacks = deepcopy(base_callbacks) + [model_checkpoint, early_stop]

            lit_module = model_to_litmodule[params.model]
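            # Train on the concatenated train+val folds; the held-out test fold
            # acts as the validation set for checkpointing and early stopping.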
            train_df, valid_df = df_dataset.iloc[train_index + val_index], df_dataset.iloc[test_index]
            model, train_loader, val_loader, *utils = lit_module.from_params_and_dataset(train_df, valid_df, params,
                                                                                         fold_id, embeddings)
            trainer = Trainer(auto_lr_find=params.find_lr, logger=logger, max_epochs=params.epoch, callbacks=callbacks,
                              gpus=params.gpu, deterministic=True, fast_dev_run=params.fast_dev_run,
                              num_sanity_val_steps=0)

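            # tune() runs the learning-rate finder when auto_lr_find is enabled;
            # fit() then trains this fold.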
            trainer.tune(model, train_dataloader=train_loader, val_dataloaders=val_loader)
            trainer.fit(model, train_dataloader=train_loader, val_dataloaders=val_loader)

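            # Upload the saved top-k checkpoint files to the experiment tracker.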
            for absolute_path in model_checkpoint.best_k_models.keys():
                logger.experiment.log_model(Path(absolute_path).name, absolute_path)

            if model_checkpoint.best_model_score:
                best_models_scores.append(model_checkpoint.best_model_score.tolist())
                logger.log_metrics({'best_model_score_' + i: model_checkpoint.best_model_score.tolist()}, step=0)

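            # For ProtoConv, reload the best checkpoint and record how many
            # prototypes remained enabled in the saved model.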
            if Models(params.model) == Models.protoconv and model_checkpoint.best_model_path:
                best_model = lit_module.load_from_checkpoint(model_checkpoint.best_model_path)
                saved_number_of_prototypes = sum(best_model.enabled_prototypes_mask.tolist())
                number_of_prototypes.append(saved_number_of_prototypes)
                logger.log_hyperparams({
                    f'saved_prototypes_{fold_id}': saved_number_of_prototypes,
                    f'best_model_path_{fold_id}': str(Path(model_checkpoint.best_model_path).name)
                })

                if params.pc_visualize:
                    data_visualizer = DataVisualizer(best_model)
                    logger.experiment.log_html(f'<h1>Split {fold_id}</h1><br> <h3>Prototypes:</h3><br>'
                                               f'{data_visualizer.visualize_prototypes()}<br>')
                    logger.experiment.log_figure(f'Prototypes similarity_{fold_id}',
                                                 data_visualizer.visualize_similarity().figure)
                    logger.experiment.log_html(f'<h3>Random prediction explanations:</h3><br>'
                                               f'{data_visualizer.visualize_random_predictions(val_loader, n=15)}')

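        # Summarize cross-validation performance: mean/std of the best per-fold
        # scores plus a ready-to-paste LaTeX table entry.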
        if len(best_models_scores) >= 1:
            avg_best, std_best = float(np.mean(np.array(best_models_scores))), float(
                np.std(np.array(best_models_scores)))
            table_entry = rf'{avg_best:.3f} ($\pm${std_best:.3f})'

            logger.log_hyperparams({
                'avg_best_scores': avg_best,
                'std_best_scores': std_best,
                'table_entry': table_entry
            })

        if len(number_of_prototypes) >= 1:
            logger.log_hyperparams({'avg_saved_prototypes': float(np.mean(np.array(number_of_prototypes)))})

        logger.experiment.end()
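
For reference, a hypothetical invocation of train. The keyword names are inferred from the attribute accesses above; all values, including the 'protoconv' model string (which assumes the Models enum is keyed by such strings), are illustrative rather than repository defaults.

if __name__ == '__main__':
    # Every value below is illustrative; -1 requests the per-dataset default.
    train(
        model='protoconv', datasets=['all'], gpu='0', seed=0, fold=5,
        epoch=30, logger=False, project_name=None, cache='.vector_cache',
        find_lr=False, fast_dev_run=False, tokenizer_length=-1,
        pc_number_of_prototypes=-1, pc_sep_loss_weight=-1, pc_ce_loss_weight=-1,
        pc_cls_loss_weight=0.9, pc_l1_loss_weight=0.01,
        pc_project_prototypes_every_n=4, pc_visualize=False,
    )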