def train(param):
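    # Accept either an argparse.Namespace or a plain dict of hyper-parameters.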
    if not isinstance(param, dict):
        args = vars(param)
    else:
        args = param

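    # Resolve the separation model class by name; non-magnitude (e.g. complex)
    # spectrogram inputs are fed to the network as 4 channels.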
    framework = get_class_by_name('conditioned_separation', args['model'])
    if args['spec_type'] != 'magnitude':
        args['input_channels'] = 4

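    # Seed everything for reproducibility, but only for runs starting from scratch.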
    if args['resume_from_checkpoint'] is None:
        if args['seed'] is not None:
            seed_everything(args['seed'])

    model = framework(**args)

    if args['last_activation'] != 'identity' and args['spec_est_mode'] != 'masking':
        warn('Please check whether you really want to combine a mapping-based '
             'spectrogram estimation mode with a non-identity final activation.')
    ##########################################################

    # -- checkpoint
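    # Checkpoints are stored under <ckpt_root_path>/<model>/<run_id>/.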
    ckpt_path = Path(args['ckpt_root_path'])
    mkdir_if_not_exists(ckpt_path)
    ckpt_path = ckpt_path.joinpath(args['model'])
    mkdir_if_not_exists(ckpt_path)
    run_id = args['run_id']
    ckpt_path = ckpt_path.joinpath(run_id)
    mkdir_if_not_exists(ckpt_path)
    save_top_k = args['save_top_k']

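    # Keep the best `save_top_k` checkpoints ranked by validation loss.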
    checkpoint_callback = ModelCheckpoint(
        filepath=ckpt_path,
        save_top_k=save_top_k,
        verbose=False,
        monitor='val_loss',
        save_last=False,
        save_weights_only=args['save_weights_only'])
    args['checkpoint_callback'] = checkpoint_callback

    # -- early stop
    patience = args['patience']
    early_stop_callback = EarlyStopping(monitor='val_loss',
                                        min_delta=0.0,
                                        patience=patience,
                                        verbose=False)
    args['early_stop_callback'] = early_stop_callback

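    # When resuming, rebuild the absolute checkpoint path and tag the run id accordingly.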
    if args['resume_from_checkpoint'] is not None:
        run_id = run_id + "_resume_" + args['resume_from_checkpoint']
        args['resume_from_checkpoint'] = Path(args['ckpt_root_path']).joinpath(
            args['model']).joinpath(args['run_id']).joinpath(
                args['resume_from_checkpoint'])
        args['resume_from_checkpoint'] = str(args['resume_from_checkpoint'])

    model_name = model.spec2spec.__class__.__name__

    # -- logger setting
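    #    'False' disables logging, 'wandb' uses Weights & Biases, 'tensorboard' is
    #    not implemented yet, and anything else falls back to the Lightning default.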
    log = args['log']
    if log == 'False':
        args['logger'] = False
    elif log == 'wandb':
        args['logger'] = WandbLogger(project='lasaft_exp',
                                     tags=[model_name],
                                     offline=False,
                                     name=run_id)
        args['logger'].log_hyperparams(model.hparams)
        args['logger'].watch(model, log='all')
    elif log == 'tensorboard':
        raise NotImplementedError
    else:
        args['logger'] = True  # default
        default_save_path = 'etc/lightning_logs'
        mkdir_if_not_exists(default_save_path)

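    # Forward only the entries that Trainer.__init__ actually accepts.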
    valid_kwargs = inspect.signature(Trainer.__init__).parameters
    trainer_kwargs = dict(
        (name, args[name]) for name in valid_kwargs if name in args)

    # Trainer
    trainer = Trainer(**trainer_kwargs)
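    # Dataset / loader configuration shared by the training and validation splits.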
    dataset_args = {
        'musdb_root': args['musdb_root'],
        'batch_size': args['batch_size'],
        'num_workers': args['num_workers'],
        'pin_memory': args['pin_memory'],
        'num_frame': args['num_frame'],
        'hop_length': args['hop_length'],
        'n_fft': args['n_fft']
    }

    dp = DataProvider(**dataset_args)
    train_dataset, training_dataloader = dp.get_training_dataset_and_loader()
    valid_dataset, validation_dataloader = dp.get_validation_dataset_and_loader()

    for key in sorted(args.keys()):
        print('{}:{}'.format(key, args[key]))

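    # Optional learning-rate range test: print the suggested LR and exit without training.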
    if args['auto_lr_find']:
        lr_find = trainer.tuner.lr_find(model,
                                        training_dataloader,
                                        validation_dataloader,
                                        early_stop_threshold=None,
                                        min_lr=1e-5)

        print(f"Found lr: {lr_find.suggestion()}")
        return None

    if args['resume_from_checkpoint'] is not None:
        print('Resuming training from checkpoint: ' + args['resume_from_checkpoint'])

    trainer.fit(model, training_dataloader, validation_dataloader)

    return None


def eval(param):
    if not isinstance(param, dict):
        args = vars(param)
    else:
        args = param

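    # Normalize CLI 'None' strings into actual None values.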
    for key in args.keys():
        if args[key] == 'None':
            args[key] = None

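    # Restrict evaluation to the explicitly chosen GPU if an index was given.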
    if args['gpu_index'] is not None:
        args['gpus'] = str(args['gpu_index'])

    # MODEL
    ##########################################################
    # # # get framework
    framework = get_class_by_name('conditioned_separation', args['model'])
    if args['spec_type'] != 'magnitude':
        args['input_channels'] = 4
    # # # Model instantiation
    from copy import deepcopy
    model_args = deepcopy(args)
    model = framework(**model_args)
    ##########################################################

    # Trainer Definition

    # -- checkpoint
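    # The checkpoint to evaluate is <ckpt_root_path>/<model>/<run_id>/<epoch>.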
    ckpt_path = Path(args['ckpt_root_path']).joinpath(args['model']).joinpath(
        args['run_id'])
    ckpt_path = '{}/{}'.format(str(ckpt_path), args['epoch'])

    # -- logger setting
    log = args['log']
    if log == 'False':
        args['logger'] = False
        args['checkpoint_callback'] = False
        args['early_stop_callback'] = False
    elif log == 'wandb':
        args['logger'] = WandbLogger(project='lasaft_exp',
                                     tags=[args['model']],
                                     offline=False,
                                     name=args['run_id'] + '_eval_' +
                                     args['epoch'].replace('=', '_'))
        args['logger'].log_hyperparams(model.hparams)
        args['logger'].watch(model, log='all')
    elif log == 'tensorboard':
        raise NotImplementedError
    else:
        args['logger'] = True  # default
        default_save_path = 'etc/lightning_logs'
        mkdir_if_not_exists(default_save_path)

    # Trainer
    if isinstance(args['gpus'], int):
        if args['gpus'] > 1:
            warn(
                'museval-based evaluation is not implemented for distributed parallel; '
                'the number of GPUs (and num_workers) should be 1. Falling back to a single GPU.'
            )
            args['gpus'] = 1
            args['distributed_backend'] = None

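    # Forward only the entries that Trainer.__init__ actually accepts (same filtering as in train).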
    valid_kwargs = inspect.signature(Trainer.__init__).parameters
    trainer_kwargs = dict(
        (name, args[name]) for name in valid_kwargs if name in args)

    # DATASET
    ##########################################################
    dataset_args = {
        'musdb_root': args['musdb_root'],
        'batch_size': args['batch_size'],
        'num_workers': args['num_workers'],
        'pin_memory': args['pin_memory'],
        'num_frame': args['num_frame'],
        'hop_length': args['hop_length'],
        'n_fft': args['n_fft']
    }
    dp = DataProvider(**dataset_args)
    ##########################################################

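    # Always evaluate in full 32-bit precision, regardless of the training setting.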
    trainer_kwargs['precision'] = 32
    trainer = Trainer(**trainer_kwargs)
    _, test_data_loader = dp.get_test_dataset_and_loader()
    model = model.load_from_checkpoint(ckpt_path)

    trainer.test(model, test_data_loader)

    return None


if __name__ == '__main__':

    parser = ArgumentParser()
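    # Two-stage parsing: read --problem_name / --model / --mode first so that
    # model- and dataset-specific arguments can be registered afterwards.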
    parser.add_argument('--problem_name', type=str)
    parser.add_argument('--model', type=str)
    parser.add_argument('--mode', type=str, default='train')
    temp_args, _ = parser.parse_known_args()

    # Problem
    problem_name = temp_args.problem_name
    assert problem_name in ['conditioned_separation']

    # Model
    model = get_class_by_name(problem_name, temp_args.model)
    parser = model.add_model_specific_args(parser)

    # Dataset
    parser = DataProvider.add_data_provider_args(parser)

    mode = temp_args.mode

    # Environment Setup
    mkdir_if_not_exists('etc')
    mkdir_if_not_exists('etc/checkpoints')

    parser.add_argument('--ckpt_root_path',
                        type=str,
                        default='etc/checkpoints')
    parser.add_argument('--log', type=str, default='True')