Example #1
import json
import os
import pickle
import random
from collections import defaultdict
from copy import deepcopy

import torch

import fine_tuning
import fine_tuning_mul
import train_ae
# DataProvider, dict_to_str, safe_make_dir and wrap_training_params are
# repo-local helpers; import them from wherever the project defines them.


def main(args, update_params_dict):
    train_fn = train_ae.train_ae
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    with open(os.path.join('model_save', 'train_params.json'), 'r') as f:
        training_params = json.load(f)

    training_params['unlabeled'].update(update_params_dict)
    param_str = dict_to_str(update_params_dict)
    training_params.update({
        'device': device,
        'es_flag': False,
        'retrain_flag': args.retrain_flag
    })

    if args.omics != 'both':
        training_params['model_save_folder'] = os.path.join(
            'model_save', 'ae', args.omics, param_str)
        safe_make_dir(training_params['model_save_folder'])

    task_save_folder = os.path.join('model_save', 'ae', args.omics, param_str)
    safe_make_dir(task_save_folder)

    random.seed(2020)  # fix the random seed so data shuffling/splits are reproducible

    data_provider = DataProvider(
        batch_size=training_params['unlabeled']['batch_size'],
        target=args.measurement)

    # for args.omics == 'both', each omics type gets its own encoder, so the
    # input dims are set per encoder right before each unlabeled training run
    # below (a single concatenated encoder would instead use the sum of the
    # gex and mut dims)
    training_params.update({'output_dim': data_provider.shape_dict['target']})
    if args.omics != 'both':
        training_params.update(
            {'input_dim': data_provider.shape_dict[args.omics]})

    # start unlabeled training
    if args.omics == 'gex':
        encoder, historys = train_fn(
            dataloader=data_provider.get_unlabeled_gex_dataloader(),
            **wrap_training_params(training_params, type='unlabeled'))
        with open(os.path.join(training_params['model_save_folder'],
                               'unlabel_train_history.pickle'), 'wb') as f:
            for history in historys:
                pickle.dump(dict(history), f)

    elif args.omics == 'mut':
        encoder, historys = train_fn(
            dataloader=data_provider.get_unlabeld_mut_dataloader(match=False),
            **wrap_training_params(training_params, type='unlabeled'))
        with open(os.path.join(training_params['model_save_folder'],
                               'unlabel_train_history.pickle'), 'wb') as f:
            for history in historys:
                pickle.dump(dict(history), f)
    else:
        training_params.update({
            'model_save_folder': os.path.join('model_save', 'ae', 'gex',
                                              param_str),
            'input_dim': data_provider.shape_dict['gex'],
        })
        safe_make_dir(training_params['model_save_folder'])

        gex_encoder, gex_historys = train_fn(
            dataloader=data_provider.get_unlabeled_gex_dataloader(),
            **wrap_training_params(training_params, type='unlabeled'))

        training_params.update({
            'model_save_folder': os.path.join('model_save', 'ae', 'mut',
                                              param_str),
            'input_dim': data_provider.shape_dict['mut'],
        })
        safe_make_dir(training_params['model_save_folder'])

        mut_encoder, mut_historys = train_fn(
            dataloader=data_provider.get_unlabeld_mut_dataloader(match=False),
            **wrap_training_params(training_params, type='unlabeled'))

    ft_evaluation_metrics = defaultdict(list)
    fold_count = 0
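    # each branch below fine-tunes the pretrained encoder(s) on labeled data,
    # one run per fold from the labeled-data generator, recording metrics at
    # the epoch with the best validation score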
    if args.omics == 'gex':
        labeled_dataloader_generator = data_provider.get_labeled_data_generator(
            omics='gex')
        for train_labeled_dataloader, val_labeled_dataloader in labeled_dataloader_generator:
            ft_encoder = deepcopy(encoder)
            target_regressor, ft_historys = fine_tuning.fine_tune_encoder(
                encoder=ft_encoder,
                train_dataloader=train_labeled_dataloader,
                val_dataloader=val_labeled_dataloader,
                test_dataloader=val_labeled_dataloader,
                seed=fold_count,
                metric_name=args.metric,
                task_save_folder=task_save_folder,
                **wrap_training_params(training_params, type='labeled'))
            for metric in [
                    'dpearsonr', 'dspearmanr', 'drmse', 'cpearsonr',
                    'cspearmanr', 'crmse'
            ]:
                ft_evaluation_metrics[metric].append(
                    ft_historys[-2][metric][ft_historys[-2]['best_index']])
            fold_count += 1
    elif args.omics == 'mut':
        labeled_dataloader_generator = data_provider.get_labeled_data_generator(
            omics='mut')
        test_ft_evaluation_metrics = defaultdict(list)
        for train_labeled_dataloader, val_labeled_dataloader, test_labeled_dataloader in labeled_dataloader_generator:
            ft_encoder = deepcopy(encoder)
            target_regressor, ft_historys = fine_tuning.fine_tune_encoder(
                encoder=ft_encoder,
                train_dataloader=train_labeled_dataloader,
                val_dataloader=val_labeled_dataloader,
                test_dataloader=test_labeled_dataloader,
                seed=fold_count,
                metric_name=args.metric,
                task_save_folder=task_save_folder,
                **wrap_training_params(training_params, type='labeled'))
            for metric in [
                    'dpearsonr', 'dspearmanr', 'drmse', 'cpearsonr',
                    'cspearmanr', 'crmse'
            ]:
                ft_evaluation_metrics[metric].append(
                    ft_historys[-2][metric][ft_historys[-2]['best_index']])
                test_ft_evaluation_metrics[metric].append(
                    ft_historys[-1][metric][ft_historys[-2]['best_index']])
            fold_count += 1
        with open(
                os.path.join(task_save_folder,
                             f'{param_str}_test_ft_evaluation_results.json'),
                'w') as f:
            json.dump(test_ft_evaluation_metrics, f)

    else:
        labeled_dataloader_generator = data_provider.get_labeled_data_generator(
            omics='both')
        for train_labeled_dataloader, val_labeled_dataloader in labeled_dataloader_generator:
            ft_gex_encoder = deepcopy(gex_encoder)
            ft_mut_encoder = deepcopy(mut_encoder)
            target_regressor, ft_historys = fine_tuning_mul.fine_tune_encoder(
                encoders=[ft_gex_encoder, ft_mut_encoder],
                train_dataloader=train_labeled_dataloader,
                val_dataloader=val_labeled_dataloader,
                test_dataloader=val_labeled_dataloader,
                seed=fold_count,
                metric_name=args.metric,
                task_save_folder=task_save_folder,
                **wrap_training_params(training_params, type='labeled'))
            for metric in [
                    'dpearsonr', 'dspearmanr', 'drmse', 'cpearsonr',
                    'cspearmanr', 'crmse'
            ]:
                ft_evaluation_metrics[metric].append(
                    ft_historys[-2][metric][ft_historys[-2]['best_index']])
            fold_count += 1

    with open(
            os.path.join(task_save_folder,
                         f'{param_str}_ft_evaluation_results.json'), 'w') as f:
        json.dump(ft_evaluation_metrics, f)
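
A minimal driver sketch for the main function above, assuming it is run as a script in the same module. The attribute names match what main actually reads (omics, measurement, metric, retrain_flag), but the concrete values and the hyperparameter keys shown are placeholders, not values taken from the repo.

from argparse import Namespace

if __name__ == '__main__':
    # hypothetical overrides; they are merged into training_params['unlabeled']
    # and also become part of the save-folder name via dict_to_str
    update_params_dict = {'batch_size': 64}
    args = Namespace(
        omics='gex',          # one of 'gex', 'mut', 'both'
        measurement='AUC',    # placeholder drug-response measurement name
        metric='cpearsonr',   # placeholder model-selection metric
        retrain_flag=True)
    main(args, update_params_dict)
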
Example #2
import json
import os
import pickle
import random
from collections import defaultdict
from copy import deepcopy

import torch

import fine_tuning
import train_cleit
import train_cleita
import train_cleitc
import train_cleitm
import train_dsn
# DataProvider, dict_to_str, safe_make_dir and wrap_training_params are
# repo-local helpers; import them from wherever the project defines them.


def main(args, update_params_dict):
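    # pick the training routine used for the unlabeled stage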
    if args.method == 'cleitm':
        train_fn = train_cleitm.train_cleitm
    elif args.method == 'cleita':
        train_fn = train_cleita.train_cleita
    elif args.method == 'cleitc':
        train_fn = train_cleitc.train_cleitc
    elif args.method == 'dsn':
        train_fn = train_dsn.train_dsn
    else:
        train_fn = train_cleit.train_cleit

    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    with open(os.path.join('model_save', 'train_params.json'), 'r') as f:
        training_params = json.load(f)

    training_params['unlabeled'].update(update_params_dict)
    param_str = dict_to_str(update_params_dict)

    training_params.update({
        'device': device,
        'model_save_folder': os.path.join('model_save', args.method,
                                          param_str),
        'es_flag': False,
        'retrain_flag': args.retrain_flag
    })
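    # model_save_folder holds checkpoints/histories for this hyperparameter
    # setting; task_save_folder collects the evaluation JSONs for the chosen
    # measurement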
    task_save_folder = os.path.join('model_save', args.method,
                                    args.measurement)

    safe_make_dir(training_params['model_save_folder'])
    safe_make_dir(task_save_folder)

    data_provider = DataProvider(
        batch_size=training_params['unlabeled']['batch_size'],
        target=args.measurement)
    training_params.update({
        'input_dim': data_provider.shape_dict['gex'],
        'output_dim': data_provider.shape_dict['target']
    })

    random.seed(2020)  # fix the random seed so data shuffling/splits are reproducible
    labeled_dataloader_generator = data_provider.get_labeled_data_generator(
        omics='mut')
    ft_evaluation_metrics = defaultdict(list)
    test_ft_evaluation_metrics = defaultdict(list)
    fold_count = 0
    # start unlabeled training
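    # methods without 'cleit' in their name (e.g. dsn) pretrain one encoder on
    # the unlabeled source/target data and fine-tune a copy of it per fold;
    # the CLEIT variants rerun the unlabeled training for every fold (seeded
    # by the fold index) before fine-tuning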
    if 'cleit' not in args.method:
        encoder, historys = train_fn(
            s_dataloaders=data_provider.get_unlabeld_mut_dataloader(
                match=True),
            t_dataloaders=data_provider.get_unlabeled_gex_dataloader(),
            **wrap_training_params(training_params, type='unlabeled'))
        with open(os.path.join(training_params['model_save_folder'],
                               'unlabel_train_history.pickle'), 'wb') as f:
            for history in historys:
                pickle.dump(dict(history), f)

        for train_labeled_dataloader, val_labeled_dataloader, test_labeled_dataloader in labeled_dataloader_generator:
            ft_encoder = deepcopy(encoder)
            target_regressor, ft_historys = fine_tuning.fine_tune_encoder(
                encoder=ft_encoder,
                train_dataloader=train_labeled_dataloader,
                val_dataloader=val_labeled_dataloader,
                test_dataloader=test_labeled_dataloader,
                seed=fold_count,
                metric_name=args.metric,
                task_save_folder=task_save_folder,
                **wrap_training_params(training_params, type='labeled'))
            with open(
                    os.path.join(training_params['model_save_folder'],
                                 f'ft_train_history_{fold_count}.pickle'),
                    'wb') as f:
                for history in ft_historys:
                    pickle.dump(dict(history), f)
            for metric in [
                    'dpearsonr', 'dspearmanr', 'drmse', 'cpearsonr',
                    'cspearmanr', 'crmse'
            ]:
                ft_evaluation_metrics[metric].append(
                    ft_historys[-2][metric][ft_historys[-2]['best_index']])
                test_ft_evaluation_metrics[metric].append(
                    ft_historys[-1][metric][ft_historys[-2]['best_index']])
            fold_count += 1
    else:
        for train_labeled_dataloader, val_labeled_dataloader, test_labeled_dataloader in labeled_dataloader_generator:
            encoder, historys = train_fn(
                dataloader=data_provider.get_unlabeld_mut_dataloader(
                    match=True),
                seed=fold_count,
                **wrap_training_params(training_params, type='unlabeled'))
            ft_encoder = deepcopy(encoder)
            target_regressor, ft_historys = fine_tuning.fine_tune_encoder(
                encoder=ft_encoder,
                train_dataloader=train_labeled_dataloader,
                val_dataloader=val_labeled_dataloader,
                test_dataloader=test_labeled_dataloader,
                seed=fold_count,
                metric_name=args.metric,
                task_save_folder=task_save_folder,
                **wrap_training_params(training_params, type='labeled'))
            for metric in [
                    'dpearsonr', 'dspearmanr', 'drmse', 'cpearsonr',
                    'cspearmanr', 'crmse'
            ]:
                ft_evaluation_metrics[metric].append(
                    ft_historys[-2][metric][ft_historys[-2]['best_index']])
                test_ft_evaluation_metrics[metric].append(
                    ft_historys[-1][metric][ft_historys[-2]['best_index']])
            with open(
                    os.path.join(training_params['model_save_folder'],
                                 f'ft_train_history_{fold_count}.pickle'),
                    'wb') as f:
                for history in ft_historys:
                    pickle.dump(dict(history), f)
            fold_count += 1
        with open(
                os.path.join(task_save_folder,
                             f'{param_str}_test_ft_evaluation_results.json'),
                'w') as f:
            json.dump(test_ft_evaluation_metrics, f)
        with open(
                os.path.join(task_save_folder,
                             f'{param_str}_ft_evaluation_results.json'),
                'w') as f:
            json.dump(ft_evaluation_metrics, f)
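
Both examples repeat the same per-fold bookkeeping: read every metric from the validation history (ft_historys[-2]) at its 'best_index', and, when a separate test history exists, read the test value (ft_historys[-1]) at that same index. A small helper along the lines of the hypothetical collect_fold_metrics below could factor that out; the name and signature are illustrative, not part of the repo.

METRICS = ['dpearsonr', 'dspearmanr', 'drmse', 'cpearsonr', 'cspearmanr', 'crmse']


def collect_fold_metrics(ft_historys, val_metrics, test_metrics=None):
    # ft_historys[-2] is the validation history and ft_historys[-1] the test
    # history; both are read at the epoch that scored best on validation
    best_index = ft_historys[-2]['best_index']
    for metric in METRICS:
        val_metrics[metric].append(ft_historys[-2][metric][best_index])
        if test_metrics is not None:
            test_metrics[metric].append(ft_historys[-1][metric][best_index])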