def main_metrics(args): device = utils.get_device_from_arg(args.device) print(f'Using device: {device}') save_directory = f'./outputs/{utils.get_dataset_name_from_datapath(args.data_dir)}_{utils.get_model_basename(args.model_name)}' model, tokenizer = utils.get_model_and_tokenizer( model_name=args.model_name, device=device) folder = 'ref' if args.ds_name is None: filename = args.datasplit else: filename = f'{args.ds_name}_{args.datasplit}' ds_tokens = utils.load_and_tokenize_data(tokenizer, args.data_dir, args.max_len, args.max_num_data, ds_name=args.ds_name, split=args.datasplit) savefilename = f'{save_directory}/metrics/{folder}/all_{filename}.p' if os.path.isfile(savefilename) and not args.force: print('All metrics already computed. Exiting') return all_sentences = [x[0].numpy().tolist() for x in ds_tokens] is_completed = [True for _ in all_sentences] metrics_all = {} # Distinct-n n_lst = [1, 2, 3, 4, 5, 6] unique_ngram_frac = src.metrics.get_unique_ngram_fraction( all_sentences, n_lst) metrics_all['distinct-n'] = unique_ngram_frac # PPL samples_2 = [ torch.LongTensor(x).view(1, -1).to(device) for x in all_sentences ] ppl = src.metrics.get_perplexity_from_samples(model, samples_2) metrics_all['perplexity'] = ppl # Zipf metrics_all['zipf'] = src.metrics.zipf_coeff(all_sentences) # Repetition metrics_all['repetition'] = src.metrics.get_repetition_fraction( all_sentences) # Non-termination metrics_all[ 'non-termination-ratio'] = src.metrics.get_nontermination_ratio( all_sentences, is_completed) # save with open(savefilename, 'wb') as f: pkl.dump(metrics_all, f) print(f'Done. Saved "{savefilename}". Bye!')
def main(): parser = utils.make_metrics_parser() args = parser.parse_args() print(args) torch.manual_seed(args.seed) device = utils.get_device_from_arg(args.device) save_directory = f'./outputs/{utils.get_dataset_name_from_datapath(args.data_dir)}_{utils.get_model_basename(args.model_name)}' if args.use_large_feats: feats_suffix = f'L{args.max_len}' elif args.use_bert_feats: feats_suffix = f'B{args.max_len}' else: feats_suffix = '' if args.use_large_feats: print('---------------Using features from GPT-2 Large!!!! Suffix =', feats_suffix) elif args.use_bert_feats: print('---------------Using features from Roberta Large!!!! Suffix =', feats_suffix) else: print('---------------Using features from model used for generations!!!!') if not os.path.isfile(f'{save_directory}/generations/ref/feats{feats_suffix}_{args.datasplit}.pt'): raise FileNotFoundError(f'Generations {save_directory}/generations/ref/feats{feats_suffix}_{args.datasplit}.pt do not exist') p_feats = torch.load(f'{save_directory}/generations/ref/feats{feats_suffix}_{args.datasplit}.pt') folder, filename = utils.get_save_filename_from_args(args) algo_name = mauve_metrics.get_discretization_algo_name( discretization_algo=args.discretization, kmeans_num_clusters=args.kmeans_num_clusters, kmeans_explained_var=args.kmeans_explained_var, drmm_num_epochs=args.drmm_num_epochs, drmm_n_layer=args.drmm_n_layer, drmm_n_comp_per_layer=args.drmm_n_component_per_layer, spv_num_epochs=args.spv_num_epochs, device=device, seed=args.seed+1 ) savefilename = f'{save_directory}/metrics/{folder}/mauve_{feats_suffix}_{filename}_{algo_name}.p' if os.path.isfile(savefilename) and not args.force: print('Metrics already exist. Exiting') return if not os.path.isfile(f'{save_directory}/generations/{folder}/feats{feats_suffix}_{filename}.pt'): raise FileNotFoundError(f'Generations {save_directory}/generations/{folder}/feats{feats_suffix}_{filename}.pt do not exist') q_feats = torch.load(f'{save_directory}/generations/{folder}/feats{feats_suffix}_{filename}.pt') p_quant, q_quant, metrics = mauve_metrics.compute_mauve_metrics( p_feats, q_feats, discretization_algo=args.discretization, kmeans_num_clusters=args.kmeans_num_clusters, kmeans_explained_var=args.kmeans_explained_var, drmm_num_epochs=args.drmm_num_epochs, drmm_n_layer=args.drmm_n_layer, drmm_n_comp_per_layer=args.drmm_n_component_per_layer, spv_num_epochs=args.spv_num_epochs, device=device, seed=args.seed+1 ) print('Mauve metric:', metrics) # save with open(savefilename, 'wb') as f: pkl.dump([metrics, p_quant, q_quant], f) print(f'Done. Saved "{savefilename}". Bye!')
def main_metrics(args): print(f'device: {args.device}') device = utils.get_device_from_arg(args.device) print(f'Using device: {device}') save_directory = f'./outputs/{utils.get_dataset_name_from_datapath(args.data_dir)}_{utils.get_model_basename(args.model_name)}' filename = f'{args.datasplit}_p{args.top_p}_k{args.top_k}_t{args.temp}_seed{args.generate_seed}' folder_name = f'{save_directory}/generations/basic' input_file_name = f'{folder_name}/sample_{filename}.p' if not os.path.isfile(input_file_name): print(f'File {input_file_name} does not exist. Quitting!') return with open(input_file_name, 'rb') as f: all_sentences, is_completed = pkl.load(f)[:2] savefilename = f'{save_directory}/metrics/basic/all_L_{filename}.p' if os.path.isfile(savefilename) and not args.force: print('All metrics already computed. Exiting') return model, tokenizer = utils.get_model_and_tokenizer(model_name='gpt2-large', device=device) metrics_all = {} # Distinct-n n_lst = [1, 2, 3, 4, 5, 6] unique_ngram_frac = src.metrics.get_unique_ngram_fraction( all_sentences, n_lst) metrics_all['distinct-n'] = unique_ngram_frac # PPL samples_2 = [ torch.LongTensor(x).view(1, -1).to(device) for x in all_sentences ] ppl = src.metrics.get_perplexity_from_samples(model, samples_2) metrics_all['perplexity'] = ppl # Zipf metrics_all['zipf'] = src.metrics.zipf_coeff(all_sentences) # Repetition metrics_all['repetition'] = src.metrics.get_repetition_fraction( all_sentences) # Non-termination metrics_all[ 'non-termination-ratio'] = src.metrics.get_nontermination_ratio( all_sentences, is_completed) # save with open(savefilename, 'wb') as f: pkl.dump(metrics_all, f) print(f'Done. Saved "{savefilename}". Bye!')
def main(): parser = make_parser() args = parser.parse_args() print(args) device = utils.get_device_from_arg(args.device) print(f'Using device: {device}') model, tokenizer = utils.get_model_and_tokenizer( model_name=args.model_name, device=device) save_directory = f'./outputs/{utils.get_dataset_name_from_datapath(args.data_dir)}_{utils.get_model_basename(args.model_name)}' ds_tokens = utils.load_and_tokenize_data(tokenizer, args.data_dir, args.max_len, args.max_num_data, split=args.datasplit) metric_fn_lst = src.metrics.get_probs_metric_fn_lst() metric_fn_names = src.metrics.get_metric_names() print(metric_fn_names) for p in [0.8, 0.9, 0.92, 0.95, 0.99]: # 5 param = (p, 0, 1.0) get_metrics(param, metric_fn_lst, model, ds_tokens, args.datasplit, metric_fn_names, save_directory) for k in [1, 5, 10, 50, 100, 500, 1000, 2000, 5000, 10000]: # 10 param = (1.0, k, 1.0) get_metrics(param, metric_fn_lst, model, ds_tokens, args.datasplit, metric_fn_names, save_directory) for t in [0.7, 0.8, 0.9, 0.95, 1.0]: # 5 param = (1.0, 0, t) get_metrics(param, metric_fn_lst, model, ds_tokens, args.datasplit, metric_fn_names, save_directory) for t in [0.75, 0.9]: # 4 for k in [10, 100]: param = (1.0, k, t) get_metrics(param, metric_fn_lst, model, ds_tokens, args.datasplit, metric_fn_names, save_directory)
print(args) torch.manual_seed(args.seed) if not args.use_large_feats: raise ValueError('Use large feats!') # check if have to run save_directory = f'./outputs/{utils.get_dataset_name_from_datapath(args.data_dir)}_{utils.get_model_basename(args.model_name)}' if args.ds_name is None: name = args.datasplit else: name = f'{args.ds_name}_{args.datasplit}' folder_name = f'{save_directory}/generations/ref' device = utils.get_device_from_arg(args.device) print(f'Using device: {device}') ###### OLD ## featurize samples # feats = src.model_utils.featurize_sequential(model, ds_tokens) # torch.save(feats, f'{folder_name}/feats_{name}.pt') feats_prefix = '' if args.use_large_feats: model, tokenizer = utils.get_model_and_tokenizer(model_name=args.featurize_model_name, device=device) ds_tokens = utils.load_and_tokenize_data(tokenizer, args.data_dir, args.max_len, args.max_num_generations, ds_name=args.ds_name, split=args.datasplit) for l in {128, 256, 512, args.max_len}: