import json
import os
from collections import defaultdict

import torch

# Project-local imports; the module paths below are assumed from the identifiers
# used in these scripts (train_* modules, DataProvider, safe_make_dir,
# wrap_training_params, fine_tune_encoder).
import train_adda
import train_cleitcs
import train_coral
import train_dann
import train_dcc
from data import DataProvider
from train_mlp import fine_tune_encoder
from utils import safe_make_dir, wrap_training_params


def main(args):
    # Select the training routine for the requested domain-adaptation method;
    # anything other than the named methods falls back to CORAL.
    if args.method == 'dann':
        train_fn = train_dann.train_dann
    elif args.method == 'adda':
        train_fn = train_adda.train_adda
    elif args.method == 'dcc':
        train_fn = train_dcc.train_dcc
    elif args.method == 'cleitcs':
        train_fn = train_cleitcs.train_cleitcs
    else:
        train_fn = train_coral.train_coral

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    with open(os.path.join('model_save', 'train_params.json'), 'r') as f:
        training_params = json.load(f)
    training_params.update({
        'device': device,
        'model_save_folder': os.path.join('model_save', args.method, 'labeled'),
    })
    safe_make_dir(training_params['model_save_folder'])

    data_provider = DataProvider(batch_size=training_params['labeled']['batch_size'],
                                 target=args.measurement)
    training_params.update({
        'input_dim': data_provider.shape_dict['gex'],
        'output_dim': data_provider.shape_dict['target']
    })

    # Source-domain (gene expression) loader plus per-fold labeled target (mutation) loaders.
    s_labeled_dataloader = data_provider.get_labeled_gex_dataloader()
    labeled_dataloader_generator = data_provider.get_drug_labeled_mut_dataloader()

    val_ft_evaluation_metrics = defaultdict(list)
    test_ft_evaluation_metrics = defaultdict(list)

    fold_count = 0
    for train_labeled_dataloader, val_labeled_dataloader, test_labeled_dataloader in labeled_dataloader_generator:
        target_regressor, train_historys = train_fn(
            s_dataloaders=s_labeled_dataloader,
            t_dataloaders=train_labeled_dataloader,
            val_dataloader=val_labeled_dataloader,
            test_dataloader=test_labeled_dataloader,
            metric_name=args.metric,
            seed=fold_count,
            **wrap_training_params(training_params, type='labeled'))
        for metric in ['dpearsonr', 'dspearmanr', 'drmse', 'cpearsonr', 'cspearmanr', 'crmse']:
            # Report validation and test metrics at the epoch selected on the validation history.
            best_index = train_historys[-2]['best_index']
            val_ft_evaluation_metrics[metric].append(train_historys[-2][metric][best_index])
            test_ft_evaluation_metrics[metric].append(train_historys[-1][metric][best_index])
        fold_count += 1

    with open(os.path.join(training_params['model_save_folder'], 'test_ft_evaluation_results.json'), 'w') as f:
        json.dump(test_ft_evaluation_metrics, f)
    with open(os.path.join(training_params['model_save_folder'], 'ft_evaluation_results.json'), 'w') as f:
        json.dump(val_ft_evaluation_metrics, f)
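
# Minimal CLI entry-point sketch for the domain-adaptation script above (assumption: not
# part of the original file). Only the argument names read in main() — method, measurement,
# metric — are grounded in the code; choices, defaults, and help text are illustrative.
if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser('Domain-adaptation fine-tuning on labeled drug-response data')
    parser.add_argument('--method', default='coral',
                        choices=['dann', 'adda', 'dcc', 'cleitcs', 'coral'],
                        help='domain-adaptation training routine to dispatch to')
    parser.add_argument('--measurement', default='AUC',
                        help='drug-response measurement used as the regression target (assumed default)')
    parser.add_argument('--metric', default='cpearsonr',
                        help='history key used for model selection (assumed default)')
    main(args=parser.parse_args())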
def main(args):
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    with open(os.path.join('model_save', 'train_params.json'), 'r') as f:
        training_params = json.load(f)
    training_params.update({
        'device': device,
        'model_save_folder': os.path.join('model_save', 'mlp', args.omics),
    })
    # Ensure the output folder exists before results are dumped (mirrors the script above).
    safe_make_dir(training_params['model_save_folder'])

    data_provider = DataProvider(batch_size=training_params['unlabeled']['batch_size'],
                                 target=args.measurement)
    training_params.update({
        'input_dim': data_provider.shape_dict[args.omics],
        'output_dim': data_provider.shape_dict['target']
    })

    ft_evaluation_metrics = defaultdict(list)
    if args.omics == 'gex':
        # Gene-expression folds yield (train, val) pairs; the validation split doubles as the test split.
        labeled_dataloader_generator = data_provider.get_drug_labeled_gex_dataloader()
        fold_count = 0
        for train_labeled_dataloader, val_labeled_dataloader in labeled_dataloader_generator:
            target_regressor, ft_historys = fine_tune_encoder(
                train_dataloader=train_labeled_dataloader,
                val_dataloader=val_labeled_dataloader,
                test_dataloader=val_labeled_dataloader,
                seed=fold_count,
                metric_name=args.metric,
                **wrap_training_params(training_params, type='labeled')
            )
            for metric in ['dpearsonr', 'dspearmanr', 'drmse', 'cpearsonr', 'cspearmanr', 'crmse']:
                ft_evaluation_metrics[metric].append(ft_historys[-2][metric][ft_historys[-2]['best_index']])
            fold_count += 1
    else:
        # Mutation folds yield (train, val, test) triples; track test metrics separately.
        labeled_dataloader_generator = data_provider.get_drug_labeled_mut_dataloader()
        fold_count = 0
        test_ft_evaluation_metrics = defaultdict(list)
        for train_labeled_dataloader, val_labeled_dataloader, test_labeled_dataloader in labeled_dataloader_generator:
            target_regressor, ft_historys = fine_tune_encoder(
                train_dataloader=train_labeled_dataloader,
                val_dataloader=val_labeled_dataloader,
                test_dataloader=test_labeled_dataloader,
                seed=fold_count,
                metric_name=args.metric,
                **wrap_training_params(training_params, type='labeled')
            )
            for metric in ['dpearsonr', 'dspearmanr', 'drmse', 'cpearsonr', 'cspearmanr', 'crmse']:
                # Validation and test metrics are both read at the validation-best epoch.
                best_index = ft_historys[-2]['best_index']
                ft_evaluation_metrics[metric].append(ft_historys[-2][metric][best_index])
                test_ft_evaluation_metrics[metric].append(ft_historys[-1][metric][best_index])
            fold_count += 1
        # test_ft_evaluation_metrics only exists on this branch, so dump it here.
        with open(os.path.join(training_params['model_save_folder'], 'test_ft_evaluation_results.json'), 'w') as f:
            json.dump(test_ft_evaluation_metrics, f)

    with open(os.path.join(training_params['model_save_folder'], 'ft_evaluation_results.json'), 'w') as f:
        json.dump(ft_evaluation_metrics, f)
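
# Minimal CLI entry-point sketch for the MLP fine-tuning script above (assumption: not part
# of the original file). Only the argument names read in main() — omics, measurement,
# metric — are grounded in the code; choices and defaults are illustrative. 'mut' is assumed
# as the non-gex modality because the else branch uses the mutation dataloader.
if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser('MLP fine-tuning baseline on labeled drug-response data')
    parser.add_argument('--omics', default='gex', choices=['gex', 'mut'],
                        help='input omics modality used to pick dataloaders and input_dim')
    parser.add_argument('--measurement', default='AUC',
                        help='drug-response measurement used as the regression target (assumed default)')
    parser.add_argument('--metric', default='cpearsonr',
                        help='history key used for model selection (assumed default)')
    main(args=parser.parse_args())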