# NOTE(review): plain `+` concatenation inserts no separator; if these are
# filesystem paths, os.path.join / pathlib may be intended — confirm.
attribute_path = base_path + attribute_name

                # Create the experiment directory; the returned path is passed to the
                # model-params factory and to synthetic-data generation below.
                experiment_path = utils.create_experiment_directory()

                # Assemble the model configuration (it supports item assignment, see the
                # synthetic-data branch). `epoch` is passed twice and `int(epoch / 4)` is
                # an epoch-derived positional value; what each slot means depends on
                # run_utils.create_model_params' signature — TODO confirm.
                model_params = run_utils.create_model_params(experiment_path, epoch, lf, beta, int(epoch / 4), expanded_user_item, mixup,
                                                             no_generative_factors, epoch, is_hessian_penalty_activated, used_data)

                # The Trainer is built from `args`, so this sets its max_epochs to `epoch`.
                args.max_epochs = epoch

                # Weights & Biases logger: project 'recommender-xai', tags 'morpho' + train_tag.
                wandb_logger = WandbLogger(project='recommender-xai', tags=['morpho', train_tag], name=wandb_name)
                # Trainer from argparse args plus overrides: W&B logging, gpus=0 (no GPUs
                # requested), full weights summary, checkpointing disabled, and a progress
                # bar plus early stopping that monitors 'train_loss'.
                # NOTE(review): ProgressBar is deprecated since Lightning v1.5 and slated
                # for removal in v1.7 (TQDMProgressBar is its replacement). The
                # `weights_summary` / `checkpoint_callback` arguments were likewise
                # deprecated in v1.5 (enable_model_summary / enable_checkpointing) —
                # confirm against the pinned Lightning version.
                trainer = pl.Trainer.from_argparse_args(args,
                                                        logger=wandb_logger, #False
                                                        gpus=0,
                                                        weights_summary='full',
                                                        checkpoint_callback = False,
                                                        callbacks = [ProgressBar(), EarlyStopping(monitor='train_loss')]
                )

                if(train):
                    # Training path: echo the configuration, optionally generate
                    # synthetic data, then build the VAE.
                    print('<---------------------------------- VAE Training ---------------------------------->')
                    print("Running with the following configuration: \n{}".format(args))
                    if (synthetic_data):
                        # Generate synthetic data (noise disabled) and store the two
                        # returned values in model_params under 'synthetic_data' and
                        # 'syn_y' (presumably inputs and targets — TODO confirm).
                        model_params['synthetic_data'], model_params['syn_y'] = data_utils.create_synthetic_data(no_generative_factors,
                                                                                                                 experiment_path,
                                                                                                                 expanded_user_item,
                                                                                                                 continous_data,
                                                                                                                 normalvariate,
                                                                                                                 noise = False)
                        # Opaque helper; presumably summarises the generated
                        # distribution — TODO confirm.
                        generate_distribution_df()

                    # Build the VAE from the assembled configuration.
                    model = VAE(model_params)
def test_v1_7_0_progress_bar():
    """Constructing ``ProgressBar`` must raise its v1.5 -> v1.7 deprecation warning.

    ``pytest.deprecated_call`` fails the test unless a deprecation-category
    warning whose message matches ``expected_message`` is emitted. The match
    is a regex search, so the unescaped dots match any character; that is
    harmless here.
    """
    expected_message = "has been deprecated in v1.5 and will be removed in v1.7."
    with pytest.deprecated_call(match=expected_message):
        ProgressBar()
Пример #3
0
                        # Weights & Biases logger: project 'recommender-xai', tags 'vae' + train_tag.
                        wandb_logger = WandbLogger(project='recommender-xai',
                                                   tags=['vae', train_tag],
                                                   name=wandb_name)
                        # Trainer from argparse args plus overrides: W&B logging,
                        # gradient clipping at 0.5, gpus=0 (no GPUs requested), full
                        # weights summary, checkpointing disabled, and a progress bar
                        # plus early stopping that monitors 'train_loss'. The
                        # commented-out options are inactive.
                        # NOTE(review): ProgressBar is deprecated since Lightning v1.5
                        # and slated for removal in v1.7 (TQDMProgressBar replaces it);
                        # `weights_summary` / `checkpoint_callback` were also deprecated
                        # in v1.5 (enable_model_summary / enable_checkpointing) —
                        # confirm against the pinned Lightning version.
                        trainer = pl.Trainer.from_argparse_args(
                            args,
                            # limit_test_batches=0.1,
                            # precision =16,
                            logger=wandb_logger,  # False
                            gradient_clip_val=0.5,
                            # accumulate_grad_batches=0,
                            gpus=0,
                            weights_summary='full',
                            checkpoint_callback=False,
                            callbacks=[
                                ProgressBar(),
                                EarlyStopping(monitor='train_loss')
                            ])

                        if (train):
                            print(
                                '<---------------------------------- VAE Training ---------------------------------->'
                            )
                            print(
                                "Running with the following configuration: \n{}"
                                .format(args))
                            if (synthetic_data):
                                model_params['synthetic_data'], model_params[
                                    'syn_y'] = data_utils.create_synthetic_data(
                                        no_generative_factors, experiment_path,
                                        expanded_user_item, continous_data,