# the goal of this script is to download and save wandb stats # stolen straight from https://docs.wandb.ai/library/public-api-guide import wandb import numpy as np import pandas as pd import os import shutil import torch from modelhandling import load_model_from_disk import contextlib from tempfile import mkdtemp __api__ = wandb.Api() __runs__ = __api__.runs("sebaseliens/explainable-asag") def get_run_ids(*groups): return [ run.id for run in __runs__ if run.config['group'] in groups or not groups ] def get_runs(*groups): return [ run for run in __runs__ if run.config['group'] in groups or not groups ] def as_run(run): if isinstance(run, str):
def fetch_history(run_id: str, project_path: str = 'sash-a/cdn_test') -> List[Dict]:
    """Fetch the logged metric history of a single W&B run.

    Args:
        run_id: Identifier of the run inside the project.
        project_path: ``entity/project`` prefix the run lives under. The
            default is the project this helper was originally hard-wired to,
            so existing single-argument calls behave as before.

    Returns:
        Whatever ``Run.history()`` returns when called with its defaults.
        NOTE(review): with wandb's default ``pandas=True`` this is a sampled
        pandas DataFrame, not a list of dicts -- confirm against callers
        before relying on the ``List[Dict]`` annotation.
    """
    return wandb.Api().run(f'{project_path}/{run_id}').history()
def _mel_to_unit_range(mel):
    """Return *mel* as a numpy array rescaled to the [0, 1] interval."""
    image = mel.squeeze().detach().cpu().numpy()
    low = np.amin(image)
    span = np.amax(image) - low
    if span == 0:
        # A constant spectrogram (e.g. digital silence) would otherwise become
        # all-NaN through a 0/0 division and render as a "bad value" image.
        return np.zeros_like(image)
    return (image - low) / span


def main():
    """Train the generator/discriminator pair and track the run in W&B.

    The run is either started from scratch, started fresh with weights taken
    from a previous run (``--load_from_run_id``), or resumed in place
    (``--resume_run_id``). When a previous run is involved its stored config
    overrides the command-line arguments, except for ``batch_size``.

    Raises:
        RuntimeError: if both run-id flags are given, if no weight file could
            be restored for a component, or if a dataset is empty.
    """
    # Local import: the visible part of the file does not import itertools.
    from itertools import islice

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    seed_everything(7)

    args = parse_args()

    Path(args.save_path).mkdir(parents=True, exist_ok=True)

    entity = "demiurge"
    project = "melgan"
    load_from_run_id = args.load_from_run_id
    resume_run_id = args.resume_run_id

    # Reject the contradictory flag combination before any network access;
    # previously the remote run was queried first and only then rejected.
    if load_from_run_id and resume_run_id:
        raise RuntimeError("Specify either --load_from_id or --resume_run_id.")

    restore_run_id = load_from_run_id or resume_run_id
    batch_size = args.batch_size

    # Getting initial run steps and epoch.
    # If a previous run is restored, its config replaces the CLI arguments.
    steps = None
    if restore_run_id:
        api = wandb.Api()
        previous_run = api.run(f"{entity}/{project}/{restore_run_id}")
        steps = previous_run.lastHistoryStep
        prev_args = argparse.Namespace(**previous_run.config)
        merged = vars(args)
        merged.update(vars(prev_args))
        args = Namespace(**merged)
        # The batch size given on the command line always wins.
        args.batch_size = batch_size

    load_initial_weights = bool(restore_run_id)
    sampling_rate = args.sampling_rate
    ratios = args.ratios
    if isinstance(ratios, str):
        # Parse a textual list such as "[8, 8, 2, 2]".
        ratios = ratios.replace(" ", "")
        ratios = ratios.strip("][").split(",")
        ratios = [int(i) for i in ratios]
        ratios = np.array(ratios)

    if resume_run_id:
        print(f"Resuming run ID {resume_run_id}.")
    elif load_from_run_id:
        print(
            f"Starting new run with initial weights from run ID {load_from_run_id}."
        )
    else:
        print("Starting new run from scratch.")

    # read 1 line in train files to log dataset location
    train_files = Path(args.data_path) / "train_files.txt"
    with open(train_files, encoding="utf-8", mode="r") as f:
        file = f.readline()
    args.train_file_sample = str(file)

    wandb.init(
        entity=entity,
        project=project,
        id=resume_run_id,
        config=args,
        resume=bool(resume_run_id),
        save_code=True,
        dir=args.save_path,
        notes=args.notes,
    )

    print("run id: " + str(wandb.run.id))
    print("run name: " + str(wandb.run.name))

    root = Path(wandb.run.dir)
    root.mkdir(parents=True, exist_ok=True)

    ####################################
    # Dump arguments and create logger #
    ####################################
    with open(root / "args.yml", "w") as f:
        yaml.dump(args, f)
    wandb.save("args.yml")

    ###############################################
    # The file modules.py is needed by the unagan #
    ###############################################
    wandb.save(mel2wav.modules.__file__, base_path=".")

    #######################
    # Load PyTorch Models #
    #######################
    netG = Generator(args.n_mel_channels, args.ngf, args.n_residual_layers,
                     ratios=ratios).to(device)
    netD = Discriminator(args.num_D, args.ndf, args.n_layers_D,
                         args.downsamp_factor).to(device)
    fft = Audio2Mel(
        n_mel_channels=args.n_mel_channels,
        pad_mode=args.pad_mode,
        sampling_rate=sampling_rate,
    ).to(device)

    for model in [netG, netD, fft]:
        wandb.watch(model)

    #####################
    # Create optimizers #
    #####################
    optG = torch.optim.Adam(netG.parameters(), lr=args.learning_rate,
                            betas=(0.5, 0.9))
    optD = torch.optim.Adam(netD.parameters(), lr=args.learning_rate,
                            betas=(0.5, 0.9))

    if load_initial_weights:
        run_path = f"{entity}/{project}/{restore_run_id}"
        for model, filenames in [
            (netG, ["netG.pt", "netG_prev.pt"]),
            (optG, ["optG.pt", "optG_prev.pt"]),
            (netD, ["netD.pt", "netD_prev.pt"]),
            (optD, ["optD.pt", "optD_prev.pt"]),
        ]:
            recover_model = False
            filepath = None
            # Try the latest checkpoint first, then the "_prev" fallback.
            for filename in filenames:
                try:
                    print(f"Restoring {filename} from run path {run_path}")
                    restored_file = wandb.restore(filename, run_path=run_path)
                    filepath = restored_file.name
                    # Only the path is needed; release the handle that
                    # wandb.restore opened so the file can be renamed below.
                    restored_file.close()
                    model = load_state_dict_handleDP(model, filepath)
                    recover_model = True
                    break
                except RuntimeError as e:
                    print("RuntimeError", e)
                    print(f"recover model weight file: '{filename}'' failed")
            if not recover_model:
                raise RuntimeError(
                    f"Cannot load model weight files for component {filenames[0]}."
                )
            # store successfully recovered model weight file ("***_prev.pt")
            path_parent = Path(filepath).parent
            newfilepath = str(path_parent / filenames[1])
            os.rename(filepath, newfilepath)
            wandb.save(newfilepath)

    if torch.cuda.device_count() > 1:
        netG = DP(netG).to(device)
        netD = DP(netD).to(device)
        fft = DP(fft).to(device)
        print(f"We have {torch.cuda.device_count()} gpus. Use data parallel.")
    else:
        print(f"We have {torch.cuda.device_count()} gpu.")

    #######################
    # Create data loaders #
    #######################
    train_set = AudioDataset(
        Path(args.data_path) / "train_files.txt",
        args.seq_len,
        sampling_rate=sampling_rate,
    )
    test_set = AudioDataset(
        Path(args.data_path) / "test_files.txt",
        sampling_rate * 4,
        sampling_rate=sampling_rate,
        augment=False,
    )
    wandb.save(str(Path(args.data_path) / "train_files.txt"))
    wandb.save(str(Path(args.data_path) / "test_files.txt"))

    train_loader = DataLoader(train_set, batch_size=args.batch_size,
                              num_workers=4)
    test_loader = DataLoader(test_set, batch_size=1)

    if len(train_loader) == 0:
        raise RuntimeError("Train dataset is empty.")

    if len(test_loader) == 0:
        raise RuntimeError("Test dataset is empty.")

    if not restore_run_id:
        steps = wandb.run.step
    start_epoch = steps // len(train_loader)
    print(f"Starting with epoch {start_epoch} and step {steps}.")

    ##########################
    # Dumping original audio #
    ##########################
    test_voc = []
    test_audio = []
    samples = []
    melImages = []
    # Half of the requested test samples (rounded up) are fixed for the whole
    # run; the remainder are re-drawn from the test loader at every save.
    num_fix_samples = args.n_test_samples - (args.n_test_samples // 2)
    cmap = cm.get_cmap("inferno")
    # islice also covers a count of zero, where the old "break on the last
    # index" sentinel never fired and the whole loader was consumed.
    for i, x_t in enumerate(islice(test_loader, max(num_fix_samples, 0))):
        x_t = x_t.to(device)
        s_t = fft(x_t).detach()

        test_voc.append(s_t.to(device))
        test_audio.append(x_t)

        audio = x_t.squeeze().cpu()
        save_sample(root / ("original_%d.wav" % i), sampling_rate, audio)
        samples.append(
            wandb.Audio(audio, caption=f"sample {i}",
                        sample_rate=sampling_rate))
        melImages.append(
            wandb.Image(cmap(_mel_to_unit_range(s_t)), caption=f"sample {i}"))

    wandb.log({"audio/original": samples}, step=start_epoch)
    wandb.log({"mel/original": melImages}, step=start_epoch)

    costs = []
    start = time.time()

    # enable cudnn autotuner to speed up training
    torch.backends.cudnn.benchmark = True

    best_mel_reconst = 1000000

    for epoch in range(start_epoch, start_epoch + args.epochs + 1):
        for iterno, x_t in enumerate(train_loader):
            x_t = x_t.to(device)
            s_t = fft(x_t).detach()
            x_pred_t = netG(s_t.to(device))

            with torch.no_grad():
                s_pred_t = fft(x_pred_t.detach())
                s_error = F.l1_loss(s_t, s_pred_t).item()

            #######################
            # Train Discriminator #
            #######################
            D_fake_det = netD(x_pred_t.to(device).detach())
            D_real = netD(x_t.to(device))

            loss_D = 0
            for scale in D_fake_det:
                loss_D += F.relu(1 + scale[-1]).mean()

            for scale in D_real:
                loss_D += F.relu(1 - scale[-1]).mean()

            netD.zero_grad()
            loss_D.backward()
            optD.step()

            ###################
            # Train Generator #
            ###################
            D_fake = netD(x_pred_t.to(device))

            loss_G = 0
            for scale in D_fake:
                loss_G += -scale[-1].mean()

            loss_feat = 0
            feat_weights = 4.0 / (args.n_layers_D + 1)
            D_weights = 1.0 / args.num_D
            wt = D_weights * feat_weights
            for i in range(args.num_D):
                for j in range(len(D_fake[i]) - 1):
                    loss_feat += wt * F.l1_loss(D_fake[i][j],
                                                D_real[i][j].detach())

            netG.zero_grad()
            (loss_G + args.lambda_feat * loss_feat).backward()
            optG.step()

            costs.append(
                [loss_D.item(), loss_G.item(), loss_feat.item(), s_error])

            wandb.log(
                {
                    "loss/discriminator": costs[-1][0],
                    "loss/generator": costs[-1][1],
                    "loss/feature_matching": costs[-1][2],
                    "loss/mel_reconstruction": costs[-1][3],
                },
                step=steps,
            )
            steps += 1

            if steps % args.save_interval == 0:
                st = time.time()
                with torch.no_grad():
                    samples = []
                    melImages = []
                    # fix samples
                    for i, (voc, _) in enumerate(zip(test_voc, test_audio)):
                        pred_audio = netG(voc)
                        pred_audio = pred_audio.squeeze().cpu()
                        save_sample(root / ("generated_%d.wav" % i),
                                    sampling_rate, pred_audio)
                        samples.append(
                            wandb.Audio(
                                pred_audio,
                                caption=f"sample {i}",
                                sample_rate=sampling_rate,
                            ))
                        melImages.append(
                            wandb.Image(cmap(_mel_to_unit_range(voc)),
                                        caption=f"sample {i}"))
                    wandb.log(
                        {
                            "audio/generated": samples,
                            "mel/generated": melImages,
                            "epoch": epoch,
                        },
                        step=steps,
                    )

                    # var samples
                    source = []
                    pred = []
                    pred_mel = []
                    num_var_samples = args.n_test_samples - num_fix_samples
                    var_batches = islice(test_loader, max(num_var_samples, 0))
                    for i, x_var in enumerate(var_batches):
                        # source
                        x_var = x_var.to(device)
                        audio = x_var.squeeze().cpu()
                        source.append(
                            wandb.Audio(audio, caption=f"sample {i}",
                                        sample_rate=sampling_rate))
                        # pred
                        voc = fft(x_var).detach().to(device)
                        pred_audio = netG(voc)
                        pred_audio = pred_audio.squeeze().cpu()
                        pred.append(
                            wandb.Audio(
                                pred_audio,
                                caption=f"sample {i}",
                                sample_rate=sampling_rate,
                            ))
                        pred_mel.append(
                            wandb.Image(cmap(_mel_to_unit_range(voc)),
                                        caption=f"sample {i}"))
                    wandb.log(
                        {
                            "audio/var_original": source,
                            "audio/var_generated": pred,
                            "mel/var_generated": pred_mel,
                        },
                        step=steps,
                    )

                print("Saving models ...")
                torch.save(netG.state_dict(), root / "netG.pt")
                torch.save(optG.state_dict(), root / "optG.pt")
                wandb.save(str(root / "netG.pt"))
                wandb.save(str(root / "optG.pt"))

                torch.save(netD.state_dict(), root / "netD.pt")
                torch.save(optD.state_dict(), root / "optD.pt")
                wandb.save(str(root / "netD.pt"))
                wandb.save(str(root / "optD.pt"))

                # Last cost column is the mel reconstruction error, averaged
                # over the iterations since the last console log.
                mean_mel_reconst = np.asarray(costs).mean(0)[-1]
                if mean_mel_reconst < best_mel_reconst:
                    best_mel_reconst = mean_mel_reconst
                    torch.save(netD.state_dict(), root / "best_netD.pt")
                    torch.save(netG.state_dict(), root / "best_netG.pt")
                    wandb.save(str(root / "best_netD.pt"))
                    wandb.save(str(root / "best_netG.pt"))

                print("Took %5.4fs to generate samples" % (time.time() - st))
                print("-" * 100)

            if steps % args.log_interval == 0:
                print("Epoch {} | Iters {} / {} | ms/batch {:5.2f} | loss {}".
                      format(
                          epoch,
                          iterno,
                          len(train_loader),
                          1000 * (time.time() - start) / args.log_interval,
                          np.asarray(costs).mean(0),
                      ))
                costs = []
                start = time.time()
def _format_metrics_row(metrics):
    """Render 'mean (std)' for auc, acc, sensitivity, specificity as one ' & '-joined row."""
    return ' & '.join(
        f'{round(np.mean(metrics[name]), 2)} ({round(np.std(metrics[name]), 3)})'
        for name in ('auc', 'acc', 'sensitivity', 'specificity'))


def print_metrics(model_name, runs_all, validate_hcp=False):
    """Print a ' & '-separated row of test metrics averaged over folds.

    For every fold in *runs_all* the corresponding W&B run is fetched, its
    ``values_test_*`` summary metrics are collected, and the saved model is
    rebuilt from the run config and loaded from disk. With *validate_hcp* the
    loaded model is additionally evaluated on the HCP dataset.

    Args:
        model_name: Label printed at the start of the row.
        runs_all: Mapping ``fold -> run_info``. ``run_info['run_id']`` is
            required; ``'lr'``, ``'dropout'``, ``'weight_d'`` and ``'model_v'``
            are optional overrides read from it.
        validate_hcp: Also evaluate on HCP and print a second row.
    """
    metric_names = ('f1', 'acc', 'auc', 'sensitivity', 'specificity')
    metrics_ukb = {name: [] for name in metric_names}
    metrics_hcp = {name: [] for name in metric_names}

    for fold_num, run_info in runs_all.items():
        run_id = run_info['run_id']
        api = wandb.Api()
        best_run = api.run(f'/st-team/spatio-temporal-brain/runs/{run_id}')
        for metric in metrics_ukb:
            metrics_ukb[metric].append(
                best_run.summary[f'values_test_{metric}'])

        # Running for HCP: rebuild the training-time config, turning the
        # stored plain values back into the project's enum types.
        w_config = best_run.config
        w_config['analysis_type'] = AnalysisType(w_config['analysis_type'])
        w_config['dataset_type'] = DatasetType(w_config['dataset_type'])
        w_config['device_run'] = DEVICE_RUN
        if 'lr' not in run_info:
            w_config['param_lr'] = w_config['lr']
        else:
            w_config['param_lr'] = float(run_info['lr'])
        w_config['model_with_sigmoid'] = True
        w_config['param_activation'] = w_config['activation']
        w_config['param_channels_conv'] = w_config['channels_conv']
        w_config['param_conn_type'] = ConnType(w_config['conn_type'])
        w_config['param_conv_strategy'] = ConvStrategy(
            w_config['conv_strategy'])
        if 'dropout' not in run_info:
            w_config['param_dropout'] = w_config['dropout']
        else:
            w_config['param_dropout'] = float(run_info['dropout'])
        w_config['param_encoding_strategy'] = EncodingStrategy(
            w_config['encoding_strategy'])
        w_config['param_normalisation'] = Normalisation(
            w_config['normalisation'])
        w_config['param_num_gnn_layers'] = w_config['num_gnn_layers']
        w_config['param_pooling'] = PoolingStrategy(w_config['pooling'])
        if 'weight_d' not in run_info:
            w_config['param_weight_decay'] = w_config['weight_decay']
        else:
            w_config['param_weight_decay'] = float(run_info['weight_d'])
        w_config['sweep_type'] = SweepType(w_config['sweep_type'])
        w_config['param_gat_heads'] = 0
        if w_config['sweep_type'] == SweepType.GAT:
            # The run config is a mapping, so it has to be indexed like every
            # other access here; attribute access raised AttributeError.
            w_config['param_gat_heads'] = w_config['gat_heads']

        if w_config['analysis_type'] == AnalysisType.ST_MULTIMODAL:
            w_config['multimodal_size'] = 10
        elif w_config['analysis_type'] == AnalysisType.ST_UNIMODAL:
            w_config['multimodal_size'] = 0

        if w_config['target_var'] in ['age', 'bmi']:
            w_config['model_with_sigmoid'] = False

        # Getting best model
        inner_fold_for_val: int = 1
        model: SpatioTemporalModel = generate_st_model(w_config,
                                                       for_test=True)
        if 'model_v' in run_info:
            model.VERSION = run_info['model_v']
        model_saving_path: str = create_name_for_model(
            target_var=w_config['target_var'],
            model=model,
            outer_split_num=w_config['fold_num'],
            inner_split_num=inner_fold_for_val,
            n_epochs=w_config['num_epochs'],
            threshold=w_config['threshold'],
            batch_size=w_config['batch_size'],
            num_nodes=w_config['num_nodes'],
            conn_type=w_config['param_conn_type'],
            normalisation=w_config['param_normalisation'],
            analysis_type=w_config['analysis_type'],
            metric_evaluated='loss',
            dataset_type=w_config['dataset_type'],
            lr=w_config['param_lr'],
            weight_decay=w_config['param_weight_decay'],
            edge_weights=w_config['edge_weights'])
        if 'model_v' in run_info:
            # We know the very specific "old" cases
            if w_config['param_pooling'] == PoolingStrategy.DIFFPOOL:
                model_saving_path = model_saving_path.replace(
                    'T_difW_F', 'GC_FGA_F')
            elif w_config['param_pooling'] == PoolingStrategy.MEAN:
                model_saving_path = model_saving_path.replace(
                    'T_no_W_F', 'GC_FGA_F')
        model.load_state_dict(
            torch.load(model_saving_path,
                       map_location=w_config['device_run']))
        model.eval()

        if not validate_hcp:
            continue

        # Getting HCP Data
        name_dataset = create_name_for_brain_dataset(
            num_nodes=68,
            time_length=1200,
            target_var='gender',
            threshold=w_config['threshold'],
            normalisation=w_config['param_normalisation'],
            connectivity_type=w_config['param_conn_type'],
            analysis_type=w_config['analysis_type'],
            encoding_strategy=w_config['param_encoding_strategy'],
            dataset_type=DatasetType('hcp'),
            edge_weights=w_config['edge_weights'])
        print('Going with', name_dataset)
        dataset = HCPDataset(
            root=name_dataset,
            target_var='gender',
            num_nodes=68,
            threshold=w_config['threshold'],
            connectivity_type=w_config['param_conn_type'],
            normalisation=w_config['param_normalisation'],
            analysis_type=w_config['analysis_type'],
            encoding_strategy=w_config['param_encoding_strategy'],
            time_length=1200,
            edge_weights=w_config['edge_weights'])
        # dataset.data is private, might change in future versions of pyg...
        dataset.data.x = dataset.data.x[:, :490]
        test_out_loader = DataLoader(dataset,
                                     batch_size=w_config['batch_size'],
                                     shuffle=False)
        test_metrics = evaluate_model(model, test_out_loader,
                                      w_config['param_pooling'],
                                      w_config['device_run'])
        for metric in metrics_hcp:
            metrics_hcp[metric].append(test_metrics[metric])

    print(model_name, end=' & ')
    print(_format_metrics_row(metrics_ukb))
    if validate_hcp:
        print('HCP:')
        print(_format_metrics_row(metrics_hcp))
def _get_run(run_id):
    """Return the W&B public-API run object for *run_id*.

    The run is looked up under the module-level ``WANDB_ENTITY`` and
    ``WANDB_PROJECT``; a new ``wandb.Api`` client is created on every call.
    """
    client = wandb.Api()
    return client.run("{}/{}/{}".format(WANDB_ENTITY, WANDB_PROJECT, run_id))
def check_run(api: Api) -> bool:
    """Verify that a run can log metrics, save a file and read both back.

    A test run is created with a known config, ten ``loss`` values, a nested
    dict, an HTML media item and a small text file. The run is then re-read
    through the public API and its config, history keys, summary and the
    downloaded file are compared with what was written. Every mismatch adds a
    message to the failure list that is handed to ``print_results``.

    Args:
        api: Client passed through to ``try_manual_save`` when downloading
            the file fails.

    Returns:
        True when no check failed, False otherwise.
    """
    print("Checking logged metrics, saving and downloading a file".ljust(
        72, "."), end="")
    failed_test_strings = []

    # set up config
    n_epochs = 4
    string_test = "A test config"
    dict_test = {"config_val": 2, "config_string": "config string"}
    list_test = [0, "one", "2"]
    config = {
        "epochs": n_epochs,
        "stringTest": string_test,
        "dictTest": dict_test,
        "listTest": list_test,
    }

    # create a file to save; the context manager closes it even if the
    # write raises
    filepath = "./test with_special-characters.txt"
    with open(filepath, "w") as f:
        f.write("test")

    with wandb.init(reinit=True, config=config, project=PROJECT_NAME) as run:
        run_id = run.id
        entity = run.entity
        logged = True
        try:
            for i in range(1, 11):
                run.log({"loss": 1.0 / i}, step=i)
            log_dict = {"val1": 1.0, "val2": 2}
            # One step past the loop, so the last history step is 11.
            run.log({"dict": log_dict}, step=i + 1)
        except Exception:
            logged = False
            failed_test_strings.append(
                "Failed to log values to run. Contact W&B for support.")

        try:
            run.log(
                {"HT%3ML ": wandb.Html('<a href="https://mysite">Link</a>')})
        except Exception:
            failed_test_strings.append(
                "Failed to log to media. Contact W&B for support.")

        wandb.save(filepath)

    public_api = wandb.Api()
    prev_run = public_api.run("{}/{}/{}".format(entity, PROJECT_NAME, run_id))
    if prev_run is None:
        failed_test_strings.append(
            "Failed to access run through API. Contact W&B for support.")
        print_results(failed_test_strings, False)
        return False

    for key, value in prev_run.config.items():
        if config[key] != value:
            failed_test_strings.append(
                "Read config values don't match run config. Contact W&B for support."
            )
            break

    if logged and (
            prev_run.history_keys["keys"]["loss"]["previousValue"] != 0.1
            or prev_run.history_keys["lastStep"] != 11
            or prev_run.history_keys["keys"]["dict.val1"]["previousValue"] != 1.0
            or prev_run.history_keys["keys"]["dict.val2"]["previousValue"] != 2):
        failed_test_strings.append(
            "History metrics don't match logged values. Check database encoding."
        )

    if logged and prev_run.summary["loss"] != 1.0 / 10:
        failed_test_strings.append(
            "Read summary values don't match expected value. Check database encoding, or contact W&B for support."
        )

    # TODO: (kdg) refactor this so it doesn't rely on an exception handler
    try:
        read_file = retry_fn(partial(prev_run.file, filepath))
        read_file = read_file.download(replace=True)
    except Exception:
        # The download failed: try a direct save from a second run to tell a
        # download problem apart from an upload problem.
        with wandb.init(reinit=True,
                        project=PROJECT_NAME,
                        config={"test": "test direct saving"}) as run:
            saved, status_code, _ = try_manual_save(api, filepath, run.id,
                                                    run.entity)
            if saved:
                failed_test_strings.append(
                    "Unable to download file. Check SQS configuration, topic configuration and bucket permissions."
                )
            else:
                failed_test_strings.append(
                    "Unable to save file with status code: {}. Check SQS configuration and bucket permissions."
                    .format(status_code))
            print_results(failed_test_strings, False)
            return False

    # Close the downloaded file once its contents have been read.
    with read_file:
        contents = read_file.read()
    if contents != "test":
        failed_test_strings.append(
            "Contents of downloaded file do not match uploaded contents. Contact W&B for support."
        )
    print_results(failed_test_strings, False)
    return len(failed_test_strings) == 0
def main():
    """Train the generator/discriminator pair and track the run in W&B.

    The run is either started from scratch, started fresh with weights taken
    from a previous run (``--load_from_run_id``), or resumed in place
    (``--resume_run_id``). Checkpoints and generated test audio are written
    to the W&B run directory every ``args.save_interval`` steps.

    Raises:
        RuntimeError: if both run-id flags are given or a dataset is empty.
    """
    # Local import: the visible part of the file does not import itertools.
    from itertools import islice

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    seed_everything(7)

    args = parse_args()

    Path(args.save_path).mkdir(parents=True, exist_ok=True)

    entity = "materialvision"
    project = "melganmv"
    load_from_run_id = args.load_from_run_id
    resume_run_id = args.resume_run_id
    restore_run_id = load_from_run_id or resume_run_id
    load_initial_weights = bool(restore_run_id)
    sampling_rate = args.sampling_rate

    if load_from_run_id and resume_run_id:
        raise RuntimeError("Specify either --load_from_id or --resume_run_id.")

    if resume_run_id:
        print(f"Resuming run ID {resume_run_id}.")
    elif load_from_run_id:
        print(
            f"Starting new run with initial weights from run ID {load_from_run_id}."
        )
    else:
        print("Starting new run from scratch.")

    wandb.init(
        entity=entity,
        project=project,
        id=resume_run_id,
        config=args,
        resume=bool(resume_run_id),
        save_code=True,
        dir=args.save_path,
        notes=args.notes,
    )

    print("run id: " + str(wandb.run.id))
    print("run name: " + str(wandb.run.name))

    root = Path(wandb.run.dir)
    root.mkdir(parents=True, exist_ok=True)

    ####################################
    # Dump arguments and create logger #
    ####################################
    with open(root / "args.yml", "w") as f:
        yaml.dump(args, f)
    wandb.save("args.yml")

    ###############################################
    # The file modules.py is needed by the unagan #
    ###############################################
    wandb.save(mel2wav.modules.__file__, base_path=".")

    #######################
    # Load PyTorch Models #
    #######################
    netG = Generator(args.n_mel_channels, args.ngf,
                     args.n_residual_layers).to(device)
    netD = Discriminator(args.num_D, args.ndf, args.n_layers_D,
                         args.downsamp_factor).to(device)
    fft = Audio2Mel(
        n_mel_channels=args.n_mel_channels,
        pad_mode=args.pad_mode,
        sampling_rate=sampling_rate,
    ).to(device)

    for model in [netG, netD, fft]:
        wandb.watch(model)

    #####################
    # Create optimizers #
    #####################
    optG = torch.optim.Adam(netG.parameters(), lr=args.learning_rate,
                            betas=(0.5, 0.9))
    optD = torch.optim.Adam(netD.parameters(), lr=args.learning_rate,
                            betas=(0.5, 0.9))

    if load_initial_weights:
        run_path = f"{entity}/{project}/{restore_run_id}"
        for obj, filename in [
            (netG, "netG.pt"),
            (optG, "optG.pt"),
            (netD, "netD.pt"),
            (optD, "optD.pt"),
        ]:
            print(f"Restoring {filename} from run path {run_path}")
            restored_file = wandb.restore(filename, run_path=run_path)
            # Only the path is needed; release the handle wandb.restore opened.
            restored_path = restored_file.name
            restored_file.close()
            # map_location lets a checkpoint written on a GPU be restored on a
            # CPU-only host instead of failing while deserialising.
            obj.load_state_dict(torch.load(restored_path, map_location=device))

    #######################
    # Create data loaders #
    #######################
    train_set = AudioDataset(
        Path(args.data_path) / "train_files.txt",
        args.seq_len,
        sampling_rate=sampling_rate,
    )
    test_set = AudioDataset(
        Path(args.data_path) / "test_files.txt",
        sampling_rate * 4,
        sampling_rate=sampling_rate,
        augment=False,
    )
    wandb.save(str(Path(args.data_path) / "train_files.txt"))
    wandb.save(str(Path(args.data_path) / "test_files.txt"))

    train_loader = DataLoader(train_set, batch_size=args.batch_size,
                              num_workers=4)
    test_loader = DataLoader(test_set, batch_size=1)

    if len(train_loader) == 0:
        raise RuntimeError("Train dataset is empty.")

    if len(test_loader) == 0:
        raise RuntimeError("Test dataset is empty.")

    # Getting initial run steps and epoch
    if load_from_run_id:
        api = wandb.Api()
        previous_run = api.run(f"{entity}/{project}/{restore_run_id}")
        steps = previous_run.lastHistoryStep
    else:
        steps = wandb.run.step

    start_epoch = steps // len(train_loader)
    print(f"Starting with epoch {start_epoch} and step {steps}.")

    ##########################
    # Dumping original audio #
    ##########################
    test_voc = []
    test_audio = []
    samples = []
    # islice also covers n_test_samples == 0, where the old "break on the
    # last index" sentinel never fired and the whole loader was consumed.
    for i, x_t in enumerate(islice(test_loader, max(args.n_test_samples, 0))):
        x_t = x_t.to(device)
        s_t = fft(x_t).detach()

        test_voc.append(s_t.to(device))
        test_audio.append(x_t)

        audio = x_t.squeeze().cpu()
        save_sample(root / ("original_%d.wav" % i), sampling_rate, audio)
        samples.append(
            wandb.Audio(audio, caption=f"sample {i}",
                        sample_rate=sampling_rate))

    if not resume_run_id:
        wandb.log({"audio/original": samples}, step=0)
    else:
        print("We are resuming, skipping logging of original audio.")

    costs = []
    start = time.time()

    # enable cudnn autotuner to speed up training
    torch.backends.cudnn.benchmark = True

    best_mel_reconst = 1000000

    for epoch in range(start_epoch, start_epoch + args.epochs + 1):
        for iterno, x_t in enumerate(train_loader):
            x_t = x_t.to(device)
            s_t = fft(x_t).detach()
            x_pred_t = netG(s_t.to(device))

            with torch.no_grad():
                s_pred_t = fft(x_pred_t.detach())
                s_error = F.l1_loss(s_t, s_pred_t).item()

            #######################
            # Train Discriminator #
            #######################
            D_fake_det = netD(x_pred_t.to(device).detach())
            D_real = netD(x_t.to(device))

            loss_D = 0
            for scale in D_fake_det:
                loss_D += F.relu(1 + scale[-1]).mean()

            for scale in D_real:
                loss_D += F.relu(1 - scale[-1]).mean()

            netD.zero_grad()
            loss_D.backward()
            optD.step()

            ###################
            # Train Generator #
            ###################
            D_fake = netD(x_pred_t.to(device))

            loss_G = 0
            for scale in D_fake:
                loss_G += -scale[-1].mean()

            loss_feat = 0
            feat_weights = 4.0 / (args.n_layers_D + 1)
            D_weights = 1.0 / args.num_D
            wt = D_weights * feat_weights
            for i in range(args.num_D):
                for j in range(len(D_fake[i]) - 1):
                    loss_feat += wt * F.l1_loss(D_fake[i][j],
                                                D_real[i][j].detach())

            netG.zero_grad()
            (loss_G + args.lambda_feat * loss_feat).backward()
            optG.step()

            costs.append(
                [loss_D.item(), loss_G.item(), loss_feat.item(), s_error])

            wandb.log(
                {
                    "loss/discriminator": costs[-1][0],
                    "loss/generator": costs[-1][1],
                    "loss/feature_matching": costs[-1][2],
                    "loss/mel_reconstruction": costs[-1][3],
                },
                step=steps,
            )
            steps += 1

            if steps % args.save_interval == 0:
                st = time.time()
                with torch.no_grad():
                    samples = []
                    for i, (voc, _) in enumerate(zip(test_voc, test_audio)):
                        pred_audio = netG(voc)
                        pred_audio = pred_audio.squeeze().cpu()
                        save_sample(root / ("generated_%d.wav" % i),
                                    sampling_rate, pred_audio)
                        samples.append(
                            wandb.Audio(
                                pred_audio,
                                caption=f"sample {i}",
                                sample_rate=sampling_rate,
                            ))
                    wandb.log(
                        {
                            "audio/generated": samples,
                            "epoch": epoch,
                        },
                        step=steps,
                    )

                print("Saving models ...")
                torch.save(netG.state_dict(), root / "netG.pt")
                torch.save(optG.state_dict(), root / "optG.pt")
                wandb.save(str(root / "netG.pt"))
                wandb.save(str(root / "optG.pt"))

                torch.save(netD.state_dict(), root / "netD.pt")
                torch.save(optD.state_dict(), root / "optD.pt")
                wandb.save(str(root / "netD.pt"))
                wandb.save(str(root / "optD.pt"))

                # Last cost column is the mel reconstruction error, averaged
                # over the iterations since the last console log.
                mean_mel_reconst = np.asarray(costs).mean(0)[-1]
                if mean_mel_reconst < best_mel_reconst:
                    best_mel_reconst = mean_mel_reconst
                    torch.save(netD.state_dict(), root / "best_netD.pt")
                    torch.save(netG.state_dict(), root / "best_netG.pt")
                    wandb.save(str(root / "best_netD.pt"))
                    wandb.save(str(root / "best_netG.pt"))

                print("Took %5.4fs to generate samples" % (time.time() - st))
                print("-" * 100)

            if steps % args.log_interval == 0:
                print("Epoch {} | Iters {} / {} | ms/batch {:5.2f} | loss {}".
                      format(
                          epoch,
                          iterno,
                          len(train_loader),
                          1000 * (time.time() - start) / args.log_interval,
                          np.asarray(costs).mean(0),
                      ))
                costs = []
                start = time.time()
def fetch_all_wandb_run_ids(wandb_project, wandb_entity, wandb_api=None):
    """Collect the ids of the runs returned for a W&B project.

    Args:
        wandb_project: Project name.
        wandb_entity: Entity (user or team) owning the project.
        wandb_api: Optional client exposing ``runs(path)``; a fresh
            ``wandb.Api()`` is created when omitted.

    Returns:
        List of ``run.id`` values, in the order the API yields the runs.
    """
    client = wandb.Api() if wandb_api is None else wandb_api
    # NOTE(review): the trailing "/*" is passed to the API verbatim -- confirm
    # that the installed wandb version accepts it as a project path.
    run_ids = []
    for run in client.runs(f'{wandb_entity}/{wandb_project}/*'):
        run_ids.append(run.id)
    return run_ids
def _log_generated_meshes(fake, config, epoch, step):
    """Export up to five generated voxel grids as .obj meshes and log them to W&B."""
    image_saved_path = Path(params.images_dir) / config.model_name
    image_saved_path.mkdir(parents=True, exist_ok=True)

    samples = fake.cpu().data[:5].squeeze().numpy()
    fnames = []
    for sample_idx, samp in enumerate(samples):
        try:
            mesh = trimesh.voxel.VoxelGrid(
                trimesh.voxel.encoding.DenseEncoding(samp >= 0.5)
            ).marching_cubes
        except ValueError as exc:
            print(f"Marching cubes failed: {exc}")
            continue
        fname = Path(image_saved_path) / f"{epoch:04}_{sample_idx}.obj"
        mesh.export(fname)
        fnames.append(fname)

    # Open each mesh only while its Object3D is being built, so that no file
    # handle is left open once the media objects exist.
    meshes = []
    for fname in fnames:
        with open(fname) as mesh_file:
            meshes.append(wandb.Object3D(mesh_file))

    wandb.log(
        {
            "generated_tree_samples": meshes,
            "epoch": epoch,
        },
        step=step,
    )


def trainer(args):
    """Train the 3D generator/discriminator pair and track the run in W&B.

    With ``args.resume_id`` the config of that previous run is reused and the
    run is resumed; otherwise the config is built from *args* plus selected
    values of the ``params`` module. Generated meshes are logged on the first
    batch of every ``config.generate_every``-th epoch and the models are saved
    every ``params.model_save_step`` epochs.

    Args:
        args: Parsed command-line namespace; ``resume_id`` is read here, the
            remaining attributes become part of the run config.
    """
    config_keys = [
        "batch_size",
        "soft_label",
        "adv_weight",
        "d_thresh",
        "z_dim",
        "z_dis",
        "model_save_step",
        "g_lr",
        "d_lr",
        "beta",
        "cube_len",
        "leak_value",
        "bias",
    ]

    # check new run or resume run
    if args.resume_id:
        api = wandb.Api()
        previous_run = api.run(f"bugan/simple-pytorch-3dgan/{args.resume_id}")
        config = previous_run.config
        pprint.pprint(config)
        run = wandb.init(
            project="simple-pytorch-3dgan",
            id=args.resume_id,
            entity="bugan",
            config=config,
            resume=True,
        )
    else:
        config = {
            **args.__dict__,
            **{k: getattr(params, k) for k in config_keys},
        }
        pprint.pprint(config)
        run = wandb.init(
            entity="bugan", project="simple-pytorch-3dgan", config=config, resume=True
        )

    # convert config dict to Namespace
    config = Namespace(**config)

    # added for output dir
    save_file_path = params.output_dir + "/" + config.model_name
    print(save_file_path)  # ../outputs/dcgan
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(save_file_path, exist_ok=True)

    # for using tensorboard
    if config.logs:
        model_uid = datetime.datetime.now().strftime("%d-%m-%Y-%H-%M-%S")
        writer = SummaryWriter(
            params.output_dir
            + "/"
            + config.model_name
            + "/logs_"
            + model_uid
            + "_"
            + config.logs
            + "/"
        )

    # dataset define
    dsets_path = config.data_dir
    print(dsets_path)  # ../volumetric_data/chair/30/train/

    if config.rotate:
        train_dsets = AugmentDataset(dsets_path, config, "train", res=config.res)
    else:
        train_dsets = ShapeNetDataset(dsets_path, config, "train", res=config.res)

    train_dset_loaders = torch.utils.data.DataLoader(
        train_dsets,
        batch_size=params.batch_size,
        shuffle=True,
        num_workers=24,
        pin_memory=True,
    )

    dset_len = {"train": len(train_dsets)}
    dset_loaders = {"train": train_dset_loaders}

    # model define
    D = net_D(config)
    G = net_G(config)

    # load state dict if resume
    if args.resume_id:
        G, D = load_model(run, G, D)

    wandb.watch(G)
    wandb.watch(D)

    D_solver = optim.Adam(D.parameters(), lr=params.d_lr, betas=params.beta)
    G_solver = optim.Adam(G.parameters(), lr=params.g_lr, betas=params.beta)

    D.to(params.device)
    G.to(params.device)

    criterion_D = nn.MSELoss()
    criterion_G = nn.L1Loss()

    itr_val = -1
    itr_train = -1

    for epoch in range(config.epochs):

        start = time.time()

        for phase in ["train"]:
            if phase == "train":
                D.train()
                G.train()
            else:
                D.eval()
                G.eval()

            running_loss_G = 0.0
            running_loss_D = 0.0
            running_loss_adv_G = 0.0

            for batch_idx, X in enumerate(tqdm(dset_loaders[phase])):

                if phase == "train":
                    itr_train += 1

                X = X.to(params.device)
                batch = X.size()[0]

                Z = generateZ(config, batch)

                # ============= Train the discriminator =============#
                d_real = D(X)

                fake = G(Z)

                # The sample export uses its own loop index, so the batch
                # index is no longer overwritten by it.
                if batch_idx == 0 and epoch % config.generate_every == 0:
                    _log_generated_meshes(fake, config, epoch, itr_train)

                d_fake = D(fake)

                real_labels = torch.ones_like(d_real).to(params.device)
                fake_labels = torch.zeros_like(d_fake).to(params.device)

                if params.soft_label:
                    real_labels = (
                        torch.Tensor(batch).uniform_(0.7, 1.2).to(params.device)
                    )
                    fake_labels = torch.Tensor(batch).uniform_(0, 0.3).to(params.device)

                d_real_loss = criterion_D(d_real, real_labels)
                d_fake_loss = criterion_D(d_fake, fake_labels)

                d_loss = d_real_loss + d_fake_loss

                # no deleted
                d_real_acu = torch.ge(d_real.squeeze(), 0.5).float()
                d_fake_acu = torch.le(d_fake.squeeze(), 0.5).float()
                d_total_acu = torch.mean(torch.cat((d_real_acu, d_fake_acu), 0))

                # The discriminator is only updated while its accuracy is
                # below the configured threshold.
                if d_total_acu < params.d_thresh:
                    D.zero_grad()
                    d_loss.backward()
                    D_solver.step()

                # =============== Train the generator ===============#
                Z = generateZ(config, batch)

                fake = G(Z)  # generated fake: 0-1, X: 0/1
                d_fake = D(fake)

                adv_g_loss = criterion_D(d_fake, real_labels)

                # The reconstruction loss is tracked for logging only; the
                # generator is optimised on the adversarial term alone.
                recon_g_loss = criterion_G(fake, X)
                g_loss = adv_g_loss

                if config.local_test:
                    print(
                        "Iteration-{} , D(x) : {:.4}, D(G(x)) : {:.4}".format(
                            itr_train, d_loss.item(), adv_g_loss.item()
                        )
                    )

                D.zero_grad()
                G.zero_grad()
                g_loss.backward()
                G_solver.step()

                # =============== logging each 10 iterations ===============#
                running_loss_G += recon_g_loss.item() * X.size(0)
                running_loss_D += d_loss.item() * X.size(0)
                running_loss_adv_G += adv_g_loss.item() * X.size(0)

                if config.logs:
                    loss_G = {
                        "adv_loss_G": adv_g_loss,
                        "recon_loss_G": recon_g_loss,
                    }

                    loss_D = {
                        "adv_real_loss_D": d_real_loss,
                        "adv_fake_loss_D": d_fake_loss,
                        "d_real_acu": d_real_acu.mean(),
                        "d_fake_acu": d_fake_acu.mean(),
                        "d_total_acu": d_total_acu,
                    }

                    if itr_train % 10 == 0 and phase == "train":
                        save_train_log(writer, loss_D, loss_G, itr_train)
                        wandb.log(
                            {"G": loss_G, "D": loss_D, "epoch": epoch}, step=itr_train
                        )

            # =============== each epoch save model or save image ===============#
            epoch_loss_G = running_loss_G / dset_len[phase]
            epoch_loss_D = running_loss_D / dset_len[phase]
            epoch_loss_adv_G = running_loss_adv_G / dset_len[phase]

            end = time.time()
            epoch_time = end - start

            print(
                "Epochs-{} ({}) , D(x) : {:.4}, D(G(x)) : {:.4}".format(
                    epoch, phase, epoch_loss_D, epoch_loss_adv_G
                )
            )
            print("Elapsed Time: {:.4} min".format(epoch_time / 60.0))

            if (epoch + 1) % params.model_save_step == 0:
                print("model_saved, images_saved...")
                save_model(run, config.model_name, G, D)
def main(
    output_folder=None,
    duration=10,
    num_samples=5,
    gid=1,
    seed=123,
    melgan_run_id=None,
    unagan_run_id=None,
    hifigan_run_id=None,
    wandb_code=None,
):
    """Generate audio samples with a trained UNAGAN generator and a vocoder.

    The UNAGAN generator produces mel-spectrograms which are then turned into
    waveforms by either a MelGAN or a HiFi-GAN vocoder. Model hyper-parameters
    are read from the configs of the referenced W&B runs and the weights are
    fetched through ``download_weights.main``.

    Args:
        output_folder: Directory for the generated wav files. Defaults to
            ``generated_samples/<data_type>.<arch_type>``.
        duration: Length of each sample in seconds.
        num_samples: Number of samples to generate.
        gid: CUDA device index; a negative value keeps tensors on the CPU.
        seed: Seed passed to ``torch.manual_seed``.
        melgan_run_id: W&B run id of the MelGAN vocoder.
        unagan_run_id: W&B run id of the UNAGAN generator (required).
        hifigan_run_id: W&B run id of the HiFi-GAN vocoder.
        wandb_code: Optional W&B API key used to log in first.

    Returns:
        A tuple ``(audio_array, sampling_rate)`` where ``audio_array`` is the
        list of written wav file paths.

    Raises:
        Exception: If both vocoder run ids are given, if none is given, or if
            ``unagan_run_id`` is empty.
    """
    if wandb_code:
        wandb.login(key=wandb_code, relogin=True)

    if melgan_run_id and hifigan_run_id:
        raise Exception("Can only set one of [melgan_run_id, hifigan_run_id], not both")
    if not unagan_run_id:
        raise Exception("unagan_run_id should not be empty")
    # Fail fast: without a vocoder neither `vocoder` nor `sampling_rate` is ever
    # defined, which previously surfaced as a NameError after all downloads.
    if not melgan_run_id and not hifigan_run_id:
        raise Exception(
            "One of [melgan_run_id, hifigan_run_id] must be set to select a vocoder"
        )

    download_weights.main(
        model_dir=Path("models/custom"),
        melgan_run_id=melgan_run_id,
        unagan_run_id=unagan_run_id,
        hifigan_run_id=hifigan_run_id,
    )

    # Data and architecture type are fixed for this entry point.
    data_type = "custom"
    arch_type = "hierarchical_with_cycle"
    model_type = f"{data_type}.{arch_type}"

    if output_folder is None:
        output_folder = Path("generated_samples") / model_type
    output_folder = Path(output_folder)
    output_folder.mkdir(parents=True, exist_ok=True)

    # Also save to the shared audio directory when it can be created.
    # (unagan_run_id is guaranteed to be set here, see the check above.)
    shared_audio_dir = Path(
        "/content/drive/My Drive/PUBLICATIONS/The Replicant/AUDIO DATABASE/UNAGAN OUTPUT/AUDIOS/"
    )
    try:
        shared_audio_dir.mkdir(parents=True, exist_ok=True)
    except OSError:
        # Best effort only: report the path that failed, not the reset value.
        all_generated_audio_dir = None
        print(
            "the path '",
            str(shared_audio_dir),
            "' not exists. Only save audio files to:",
            str(output_folder),
        )
    else:
        all_generated_audio_dir = shared_audio_dir
        print(
            "generated audio files will also saved to:",
            str(all_generated_audio_dir),
        )

    api = wandb.Api()
    previous_run = api.run(f"demiurge/unagan/{unagan_run_id}")
    unagan_config = Namespace(**previous_run.config)

    # UNAGAN config parameters
    z_dim = unagan_config.z_dim
    z_scale_factors = unagan_config.z_scale_factors
    z_total_scale_factor = np.prod(z_scale_factors)
    feat_dim = unagan_config.feat_dim

    param_fp = f"models/{data_type}/params.generator.{arch_type}.pt"
    mean_fp = f"models/{data_type}/mean.mel.npy"
    std_fp = f"models/{data_type}/std.mel.npy"

    mean = torch.from_numpy(np.load(mean_fp)).float().view(1, feat_dim, 1)
    std = torch.from_numpy(np.load(std_fp)).float().view(1, feat_dim, 1)
    if gid >= 0:
        mean = mean.cuda(gid)
        std = std.cuda(gid)

    # MelGAN vocoder
    if melgan_run_id:
        previous_run = api.run(f"demiurge/melgan/{melgan_run_id}")
        melgan_config = Namespace(**previous_run.config)

        # MelGAN-only parameters, falling back to the historical defaults.
        n_mel_channels = getattr(melgan_config, "n_mel_channels", 80)
        ngf = getattr(melgan_config, "ngf", 32)
        n_residual_layers = getattr(melgan_config, "n_residual_layers", 3)
        # Parameters that also drive the UNAGAN generation below.
        sampling_rate = getattr(melgan_config, "sampling_rate", 44100)
        hop_length = getattr(melgan_config, "hop_length", 256)

        # Compare only after the run config has been applied, otherwise the
        # warning is about the default value rather than the real one.
        if n_mel_channels != feat_dim:
            print(
                f"Warning!!! melgan n_mel_channels {n_mel_channels} != unagan feat_dim {feat_dim}"
            )

        vocoder_model_dir = Path("models") / data_type / "vocoder"
        if data_type == "speech":
            vocoder_name = "OriginalGenerator"
        else:
            vocoder_name = "GRUGenerator"

        MelGAN = getattr(melgan_models, vocoder_name)
        vocoder = MelGAN(n_mel_channels, ngf, n_residual_layers)
        vocoder.eval()

        vocoder_param_fp = vocoder_model_dir / "params.pt"
        # Load onto the CPU first so CUDA-saved weights also work on CPU-only
        # hosts; the model is moved to the GPU afterwards when requested.
        vocoder_state_dict = torch.load(vocoder_param_fp, map_location="cpu")
        try:
            vocoder.load_state_dict(vocoder_state_dict)
        except RuntimeError as e:
            print(e)
            print("Fixing model by removing .module prefix")
            vocoder_state_dict = OrderedDict(
                (k.split(".", 1)[1], v) for k, v in vocoder_state_dict.items()
            )
            vocoder.load_state_dict(vocoder_state_dict)

        if gid >= 0:
            vocoder = vocoder.cuda(gid)

    # HiFi-GAN vocoder
    if hifigan_run_id:
        previous_run = api.run(f"demiurge/hifi-gan/{hifigan_run_id}")
        hifigan_config = Namespace(**previous_run.config)

        # Parameters that also drive the UNAGAN generation below.
        sampling_rate = getattr(hifigan_config, "sampling_rate", 44100)
        hop_length = getattr(hifigan_config, "hop_size", 256)

        vocoder_model_dir = Path("models") / data_type / "vocoder"
        vocoder = hifi_models.Generator(hifigan_config)
        vocoder.eval()

        vocoder_state_dict = torch.load(vocoder_model_dir / "g", map_location="cpu")
        vocoder.load_state_dict(vocoder_state_dict["generator"])

        if gid >= 0:
            vocoder = vocoder.cuda(gid)

    # Generator
    if arch_type == "nonhierarchical":
        generator = NonHierarchicalGenerator(feat_dim, z_dim)
    elif arch_type.startswith("hierarchical"):
        generator = HierarchicalGenerator(feat_dim, z_dim, z_scale_factors)

    generator.eval()
    for p in generator.parameters():
        p.requires_grad = False

    manager.load_model(param_fp, generator, device_id="cpu")
    if gid >= 0:
        generator = generator.cuda(gid)

    # Process
    torch.manual_seed(seed)

    # Information encoded in the file names.
    filename_base = datetime.utcnow().strftime("%Y-%m-%d_%H-%M")
    if melgan_run_id:
        filename_base += "_mel-" + melgan_run_id
    if unagan_run_id:
        filename_base += "_una-" + unagan_run_id
    if hifigan_run_id:
        filename_base += "_hifi-" + hifigan_run_id

    num_frames = int(np.ceil(duration * (sampling_rate / hop_length)))

    audio_array = []
    for ii in range(num_samples):
        out_fp_wav = Path(output_folder) / f"{filename_base}_sample{ii}.wav"
        print(f"Generating {out_fp_wav}")

        if arch_type == "nonhierarchical":
            z = torch.zeros((1, z_dim, num_frames)).normal_(0, 1).float()
        elif arch_type.startswith("hierarchical"):
            z = (
                torch.zeros((1, z_dim, int(np.ceil(num_frames / z_total_scale_factor))))
                .normal_(0, 1)
                .float()
            )

        if gid >= 0:
            z = z.cuda(gid)

        with torch.set_grad_enabled(False):
            with torch.cuda.device(gid):
                # Generator: latent -> normalized mel, then de-normalize.
                melspec_voc = generator(z)
                melspec_voc = (melspec_voc * std) + mean

                # Vocoder: mel -> waveform.
                audio = vocoder(melspec_voc)
                audio = audio.squeeze().cpu().numpy()

        # Save to wav and keep the path for the caller (e.g. W&B logging).
        sf.write(out_fp_wav, audio, sampling_rate)
        audio_array.append(out_fp_wav)

        # Save also to the shared directory when available.
        if all_generated_audio_dir:
            out2_fp_wav = (
                Path(all_generated_audio_dir) / f"{filename_base}_sample{ii}.wav"
            )
            sf.write(out2_fp_wav, audio, sampling_rate)

    return audio_array, sampling_rate
from scipy import stats
from scipy.special import factorial
from scipy.stats import binom_test, wilcoxon
import os
import pickle
from datetime import datetime
import tabulate
import wandb
from collections import namedtuple, defaultdict, OrderedDict
import json
from ipypb import ipb
# from metalearning import cnnmlp

# Shared W&B public-API client, created once at import time.
API = wandb.Api()

# NOTE(review): presumably the `samples` cap used when fetching run
# histories -- confirm against the call sites.
MAX_HISTORY_SAMPLES = 4000
# NOTE(review): meaning not visible in this chunk; looks like the number of
# examples in a dataset/coreset -- confirm.
DATASET_CORESET_SIZE = 22500
# NOTE(review): presumably the accuracy level treated as "task learned" --
# confirm against the call sites.
ACCURACY_THRESHOLD = 0.95

# Metric/column names for the per-query test accuracy, queries 1 through 10.
TASK_ACC_COLS = [f'Test Accuracy, Query #{i}' for i in range(1, 11)]

# Thirty query names; by their values they fall into three groups of ten
# (colors, shapes, textures), in that order.
QUERY_NAMES = [
    'blue', 'brown', 'cyan', 'gray', 'green', 'orange', 'pink', 'purple',
    'red', 'yellow', 'cone', 'cube', 'cylinder', 'dodecahedron', 'ellipsoid',
    'octahedron', 'pyramid', 'rectangle', 'sphere', 'torus', 'chain_mail',
    'marble', 'maze', 'metal', 'metal_weave', 'polka', 'rubber', 'rug',
    'tiles', 'wood_plank'
]

# Name of the color category.
COLOR = 'color'
def list_run_files(run_id: str, project: str = "flowers", entity: str = "jeremytjordan"):
    """Return the names of all files attached to a W&B run.

    Args:
        run_id: Id of the run inside the project.
        project: W&B project name.
        entity: W&B entity (user or team) owning the project.

    Returns:
        A list with the ``name`` of every file of the run.
    """
    run_path = f"{entity}/{project}/{run_id}"
    target_run = wandb.Api().run(run_path)
    file_names = []
    for remote_file in target_run.files():
        file_names.append(remote_file.name)
    return file_names
def list_runs(project: str = "flowers", entity: str = "jeremytjordan"):
    """Return the ids of all runs in a W&B project.

    Args:
        project: W&B project name.
        entity: W&B entity (user or team) owning the project.

    Returns:
        A list with the ``id`` of every run in ``entity/project``.
    """
    project_path = f"{entity}/{project}"
    run_ids = []
    for project_run in wandb.Api().runs(project_path):
        run_ids.append(project_run.id)
    return run_ids
def __init__(self, project_name=None):
    """Create the W&B API client and optionally select a project.

    Args:
        project_name: When given, it is forwarded to ``set_project``.
    """
    self.api = wandb.Api()
    if project_name is None:
        return
    self.set_project(project_name)
def main():
    """Download finished Ising-model runs from W&B and plot their convergence.

    Runs are selected by a fixed config filter (one run per circuit depth in
    ``n_layers_list``). Two stacked panels are drawn -- the loss on a log
    scale and the ground-state fidelity -- then saved to
    ``fig/ising_optimization_ed.pdf`` and shown.
    """
    n_qubits = 8
    n_layers_list = [32, 64, 80, 96]
    project = 'IsingModel'
    target_cfgs = {
        'config.n_qubits': n_qubits,
        'config.n_layers': {
            "$in": n_layers_list
        },
        'config.g': 2,
        'config.h': 0,
        'config.lr': 0.05,
        'config.seed': 96,
        'config.scheduler_name': 'exponential_decay',
    }
    print(f'Downloading experiment results from {project}')
    print(f'| Target constraints: {target_cfgs}')

    history = {}
    for run in wandb.Api().runs(project, filters=target_cfgs):
        if run.state != 'finished':
            continue
        print(run.name)
        depth = run.config['n_layers']
        frame = run.history()
        # Theoretically E(\theta) >= E_0 and fidelity <= 1, so anything
        # outside those bounds must be a precision error.
        frame['loss'] = frame['loss'].clip(lower=0.)
        frame['fidelity/ground'] = frame['fidelity/ground'].clip(upper=1.)
        history[depth] = frame
    print('Download done')
    assert set(history.keys()) == set(n_layers_list)

    linestyles = ['-', '-.', '--', ':']
    linewidths = [1.2, 1.2, 1.3, 1.4]
    xlim = 0, 500

    def draw_curves(column):
        # One curve per circuit depth for the given history column.
        for i, (depth, style, width) in enumerate(
                zip(n_layers_list, linestyles, linewidths)):
            frame = history[depth]
            plt.plot(frame._step,
                     frame[column],
                     style,
                     color=color_list[i],
                     linewidth=width,
                     alpha=1.,
                     markersize=5,
                     label=f'L={depth}')
        plt.xlim(*xlim)

    # Top panel: energy gap on a log scale.
    plt.subplot(211)
    draw_curves('loss')
    plt.yscale('log')
    plt.ylabel(r'$E(\mathbf{\theta}) - E_0$', fontsize=13)
    plt.grid(True, c='0.5', ls=':', lw=0.5)

    # Bottom panel: fidelity with the ground state.
    plt.subplot(212)
    draw_curves('fidelity/ground')
    plt.xlabel('Optimization Steps', fontsize=13)
    plt.ylabel(
        r'$|\,\langle \psi(\mathbf{\theta^*})\, |\, \phi \rangle\, |^2$',
        fontsize=13)
    plt.grid(True, c='0.5', ls=':', lw=0.5)
    plt.legend(loc='lower right')

    plt.tight_layout()
    plt.savefig('fig/ising_optimization_ed.pdf', bbox_inches='tight')
    plt.show()
def get_wandb_dataframes(run_list=None, project=None):
    """Fetch the logged history of each given W&B run.

    Args:
        run_list: Iterable of run paths accepted by ``wandb.Api().run``.
            ``None`` (the default) or an empty iterable yields an empty list.
        project: Unused; kept for backward compatibility with callers.

    Returns:
        A list with one history object per entry of ``run_list``, in order.
    """
    # The default used to be iterated directly and raised a TypeError.
    # Returning early also avoids creating an API client for nothing.
    if not run_list:
        return []
    api = wandb.Api()
    return [api.run(run_key).history() for run_key in run_list]
def _get_model_candidates_from_wb(project, model_use_case_id):
    """Return the versions of the model-candidates artifact of a use case.

    Args:
        project: W&B project passed as an override to the API client.
        model_use_case_id: Prefix of the ``<id>_model_candidates`` artifact.

    Returns:
        Whatever ``Api.artifact_versions`` returns for the ``"model"`` type.
    """
    artifact_name = f"{model_use_case_id}_model_candidates"
    client = wandb.Api({"project": project})
    return client.artifact_versions("model", artifact_name)
def delete_wandb_run(run_name):
    """Delete a W&B run and log the deletion.

    Args:
        run_name: Run path understood by ``wandb.Api().run``.
    """
    wandb.Api().run(run_name).delete()
    logging.info(f"run {run_name} had been deleted with success")
def __init__(self, path=None, opts=None):
    """Store the path and options and create the W&B API client.

    Args:
        path: Stored unchanged on the instance.
        opts: Optional options mapping; a falsy value becomes ``{}``. Its
            ``"height"`` entry (default 420) is copied to ``self.height``.
    """
    options = opts if opts else {}
    self.path = path
    self.api = wandb.Api()
    self.opts = options
    self.displayed = False
    self.height = options.get("height", 420)
def sync_crashed(sweep_name: Optional[str]):
    """Locate crashed runs of a sweep on the hosts and ``wandb sync`` them.

    Args:
        sweep_name: Sweep name or id. A known name is translated to its id;
            an ambiguous name aborts with an error message.
    """
    wandb_key = get_wandb_env()
    assert wandb_key, "W&B API key is needed for staring a W&B swype"

    project = config.get("wandb", {}).get("project")
    api = wandb.Api()

    if sweep_name is not None:
        name_to_id, repeats = invert_sweep_id_table(get_sweep_table(api, project))
        if sweep_name in name_to_id:
            sweep_name = name_to_id[sweep_name]
        elif sweep_name in repeats:
            print(f"ERROR: ambigous sweep name: {sweep_name}")
            return

    relpath = get_relative_path()
    crashed_runs = get_runs_in_sweep(api, project, sweep_name, {"state": "crashed"})
    print(
        f"Sweep {sweep_name}: found {len(crashed_runs)} crashed runs. Trying to synchronize..."
    )

    for crashed in crashed_runs:
        hostname = get_run_host(api, project, crashed.id)
        find_cmd = f"find ./wandb -iname '*{crashed.id}'"

        # Collect (host, find output) pairs for every host where the run
        # directory was found.
        candidates = []
        if hostname is None:
            # Host unknown: search on all configured hosts.
            per_host = run_multiple_hosts(config["hosts"], find_cmd)
            for host, (output, retcode) in per_host.items():
                output = output.strip()
                if retcode == 0 and output:
                    candidates.append((host, output))
        else:
            output, retcode = run_multiple_hosts([hostname], find_cmd)[hostname]
            output = output.strip()
            if retcode == 0 and output:
                candidates = [(hostname, output)]

        if len(candidates) != 1:
            print(f"WARNING: Failed to identify run {crashed.id}")
            continue

        hostname, run_dir = candidates[0]
        # More than one matching directory on the host is ambiguous as well.
        if "\n" in run_dir:
            print(f"WARNING: Failed to identify run {crashed.id}")
            continue

        print(f"Found run {crashed.id} at {hostname} in dir {run_dir}. Syncing...")
        cd = config.get_command(hostname, "cd")
        wandb_cmd = config.get_command(hostname, "wandb", "~/.local/bin/wandb")
        sync_cmd = f"{cd} {relpath}; {wandb_key} {wandb_cmd} sync {run_dir}"
        _, errcode = remote_run(hostname, sync_cmd + " 2>/dev/null")
        if errcode != 0:
            print("Sync failed :(")
            continue
def main():
    """Bring the stored metrics of selected W&B runs up to CURRENT_VERSION.

    Runs are selected either by ``--tag`` or by ``--sweep``. Each run whose
    ``metrics_version`` config entry is behind gets every migration step it
    is missing applied, in version order. With ``--dry-run`` the updated
    config is not pushed back.

    Raises:
        ValueError: If neither ``--tag`` nor ``--sweep`` is given.
    """
    parser = ArgumentParser()
    selection = parser.add_mutually_exclusive_group()
    selection.add_argument('--sweep', help="Select runs from the given sweep.")
    selection.add_argument('--tag', help="Select runs with the given tag.")
    parser.add_argument('--project', help="Path of the project, in the form entity_id/project_id.")
    parser.add_argument('--dry-run', action='store_true', help="Describe the changes without actually performing them.")
    args = parser.parse_args()

    wandb.init(job_type='update_metrics', project=args.project)

    overrides = {'project': args.project} if args.project else {}
    api = wandb.Api(overrides)

    if args.tag:
        runs = api.runs(args.project, filters={
            'tags': args.tag
        })
        plots_dir = os.path.join('update_metrics', args.tag)
    elif args.sweep:
        sweep: wandb_api.Sweep = api.sweep(args.sweep)
        print(f"Processing sweep {sweep.url}")
        runs = sweep.runs
        plots_dir = os.path.join('update_metrics', sweep.id)
    else:
        raise ValueError("One of --tag or --sweep must be provided.")

    # (version introducing the step, progress message, migration callable)
    migrations = (
        (1, f"  - adding entropy discrimination ROC curve",
         lambda target: add_entropy_roc(target, plots_dir)),
        (2, f"  - adding accuracy / AUC combined score",
         lambda target: add_combined_score(target)),
        (3, f"  - adding default approach config key",
         lambda target: add_default_approach(target)),
        (4, f"  - adding checkpoint artifact",
         lambda target: add_checkpoint_artifact(target, api, args.dry_run)),
    )

    run: wandb_api.Run
    for run in tqdm(runs):
        version = run.config.get('metrics_version', 0)
        if version == CURRENT_VERSION:
            continue

        tqdm.write(f"Run {run.name}:")
        tqdm.write(f"  - URL {run.url}")
        tqdm.write(f"  - current metrics version v{version}")

        for introduced_in, message, migrate in migrations:
            if version < introduced_in:
                tqdm.write(message)
                migrate(run)

        run.config['metrics_version'] = CURRENT_VERSION
        if not args.dry_run:
            run.update()
def get_policy(env_name: str, pre_trained: int = 1):
    """
    Download and load a pre-trained policy for an environment.

    :param env_name: name of the environment (child environments are mapped
        to their parent through ``CHILD_PARENT_ENVS``)
    :param pre_trained: quality level of the policy, between
        ``MIN_PRE_TRAINED_LEVEL`` and ``MAX_PRE_TRAINED_LEVEL``; the lowest
        value is the best model and the highest the worst.
    :return: tuple ``(model, info)`` with the loaded network and its recorded
        performance stats (an empty dict when none are recorded).

    Example:
        >>> import policybazaar
        >>> policybazaar.get_policy('d4rl:maze2d-open-v0',pre_trained=1)
    """
    assert MIN_PRE_TRAINED_LEVEL <= pre_trained <= MAX_PRE_TRAINED_LEVEL, \
        'pre_trained marker should be between [{},{}] where {} indicates the best model' \
        ' and {} indicates the worst model'.format(MIN_PRE_TRAINED_LEVEL, MAX_PRE_TRAINED_LEVEL,
                                                   MIN_PRE_TRAINED_LEVEL, MAX_PRE_TRAINED_LEVEL)
    assert env_name in ENV_IDS or env_name in CHILD_PARENT_ENVS, \
        '`{}` not found. It should be among following: {}'.format(env_name,
                                                                  list(ENV_IDS.keys()) + list(CHILD_PARENT_ENVS.keys()))

    # Child environments share the policies of their parent.
    if env_name not in ENV_IDS:
        env_name = CHILD_PARENT_ENVS[env_name]

    has_stats = (env_name in ENV_PERFORMANCE_STATS
                 and pre_trained in ENV_PERFORMANCE_STATS[env_name])
    info = ENV_PERFORMANCE_STATS[env_name][pre_trained] if has_stats else {}

    env_entry = ENV_IDS[env_name]
    run_path = env_entry['wandb_run_path']
    run = wandb.Api().run(run_path)

    env_root = os.path.join(env_name, POLICY_BAZAAR_DIR, env_name,
                            'pre_trained_{}'.format(pre_trained), 'models')
    os.makedirs(env_root, exist_ok=True)
    cpu = torch.device('cpu')

    if 'cassie' in env_name:
        # Cassie checkpoints carry the observation normalization statistics
        # next to the network weights.
        model_name = '{}.p'.format(env_entry['model_name'])
        from .cassie_model import ActorCriticNetwork
        model = ActorCriticNetwork(**run.config['model_kwargs'])
        wandb.restore(name=model_name, run_path=run_path, replace=True, root=env_root)
        checkpoint = torch.load(os.path.join(env_root, model_name), map_location=cpu)
        model.load_state_dict(checkpoint['state_dict'])
        model.actor.obs_std = checkpoint["act_obs_std"]
        model.actor.obs_mean = checkpoint["act_obs_mean"]
        model.critic.obs_std = checkpoint["critic_obs_std"]
        model.critic.obs_mean = checkpoint["critic_obs_mean"]
    else:
        model_name = '{}_{}.0.p'.format(env_entry['model_name'],
                                        env_entry['models'][pre_trained])
        from .model import ActorCriticNetwork
        model = ActorCriticNetwork(run.config['observation_size'],
                                   run.config['action_size'],
                                   hidden_dim=64,
                                   action_std=0.5)
        wandb.restore(name=model_name, run_path=run_path, replace=True, root=env_root)
        weights = torch.load(os.path.join(env_root, model_name), map_location=cpu)
        model.load_state_dict(weights)

    return model, info
parser.add_argument('--custom_metric', action="store_true", default=False, help='Custom non Test metric') parser.add_argument('--graph_gen', action="store_true", default=False, help='Report all graph gen Test metric') parser.add_argument('--top_k', type=int, default=5, help='Return only top K runs') parser.add_argument('--config_key', nargs='*', type=str, default=[]) parser.add_argument('--config_val', nargs='*', default=[]) parser.add_argument('--dataset', type=str, default='bdp') parser.add_argument( '--eval_set', default="test", help= "Whether to evaluate model on test set (default) or validation set.") parser.add_argument('--get_step', nargs='+', default=5, type=int) args = parser.parse_args() with open('../settings.json') as f: data = json.load(f) args.wandb_apikey = data.get("wandbapikey") os.environ['WANDB_API_KEY'] = args.wandb_apikey args.api = wandb.Api() main(args)
def setup_wandb(flags):
    """Initialise W&B for this experiment, resuming a previous run if asked.

    When ``flags.wandb_resume`` is set, the project is searched for the run
    matching ``flags.xpid`` and ``flags.seed``. If that run has logged steps
    and its latest checkpoint can be downloaded, ``WANDB_RESUME`` and
    ``WANDB_RUN_ID`` are exported so ``wandb.init`` continues it, and
    ``flags.current_step`` / ``flags.resume_checkpoint_path`` are set.
    Otherwise the run starts from scratch.

    Args:
        flags: Experiment flags namespace; reads the ``wandb_*``, ``xpid``,
            ``seed`` and ``total_steps`` attributes and is updated in place.

    Raises:
        ValueError: If more than one existing run matches id and seed.
        SystemExit: If the matching run already reached ``total_steps``.
    """
    flags.wandb_name = f'{flags.xpid}-{flags.seed}'

    # Never inherit resume settings from the calling environment.
    for wandb_key in ('WANDB_RESUME', 'WANDB_RUN_ID'):
        os.environ.pop(wandb_key, None)

    if flags.wandb_resume:
        api = wandb.Api()
        original_run_id = None
        resume_step = None
        resume_checkpoint = None

        existing_runs = api.runs(f'{flags.wandb_entity}/{flags.wandb_project}',
                                 {'$and': [{'config.id': str(flags.xpid)},
                                           {'config.seed': int(flags.seed)}]})

        if len(existing_runs) > 1:
            raise ValueError(
                f'Found more than one matching run to id {flags.xpid} and seed {flags.seed}: {[r.id for r in existing_runs]}. Aborting... ')

        elif len(existing_runs) == 1:
            existing_run = existing_runs[0]
            original_run_id = existing_run.id

            history = existing_run.history(pandas=True, samples=1000)
            # Verify there's actually a run to resume
            if len(history) > 0:
                # The last non-NaN 'steps' entry is the latest logged step.
                # Dropping NaNs first (instead of scanning backwards) keeps
                # an all-NaN column from raising an IndexError.
                valid_steps = history['steps'].dropna()
                if len(valid_steps) > 0:
                    resume_step = int(valid_steps.iat[-1])

            if resume_step is not None:
                if resume_step >= flags.total_steps:
                    print(
                        f'resume_step ({resume_step}) is greater than or equal to total steps ({flags.total_steps}), nothing to do here...')
                    sys.exit(0)

                # Now that we know what resume_step is, we can load from there.
                try:
                    resume_checkpoint = existing_run.file(f'model-{resume_step}.tar')
                    resume_checkpoint.download(replace=True)
                except (AttributeError, wandb.CommError):
                    print('Failed to download most recent checkpoint, will not resume')
                    # Forget the handle: otherwise the run would still be
                    # resumed without the checkpoint being on disk.
                    resume_checkpoint = None

        if original_run_id is None:
            print(f'Failed to find run to resume for seed {flags.seed}, running from scratch')

        elif resume_step is None:
            print(f'Failed to find the correct resume timestamp for seed {flags.seed}, running from scratch')

        elif resume_checkpoint is None:
            print(f'Failed to find checkpoint to resume for seed {flags.seed}, running from scratch')

        else:
            os.environ['WANDB_RESUME'] = 'must'
            os.environ['WANDB_RUN_ID'] = original_run_id
            # Only advertise a resume point when everything needed is there.
            flags.current_step = resume_step
            flags.resume_checkpoint_path = resume_checkpoint.name

    # Show the W&B related environment that wandb.init is going to see.
    for key in os.environ:
        if 'WANDB' in key:
            print(key, os.environ[key])

    wandb.init(entity=flags.wandb_entity, project=flags.wandb_project,
               name=flags.wandb_name, dir=flags.wandb_dir, config=vars(flags))

    wandb.save(os.path.join(wandb.run.dir, '*.pth'))
    wandb.save(os.path.join(wandb.run.dir, '*.tar'))
def fetch_run(run_id: str = '', run_path: str = '') -> wandb.wandb_run:
    """Fetch a W&B run by full path or by id.

    Args:
        run_id: Run id; used only when ``run_path`` is empty, in which case
            it is looked up under ``codeepneat/cdn/``.
        run_path: Full run path; takes precedence over ``run_id``.

    Returns:
        The run object returned by ``wandb.Api().run``.
    """
    target_path = run_path if run_path else 'codeepneat/cdn/' + run_id
    return wandb.Api().run(target_path)
def misc(cfg):
    """Summarise test metrics of W&B runs as a CSV table and a bar plot.

    Runs of the ``cetosignis/viraal-rerank-full`` project matching
    ``cfg.dataset`` and ``cfg.task`` are fetched; those reporting the metric
    selected by ``cfg.test_task`` are collected. The per-configuration means
    go to ``<dataset>_<task>_<test_task>.csv`` and a bar plot over the
    labeled part goes to the ``.png`` file of the same name.

    Args:
        cfg: Config object providing ``dataset``, ``task`` and ``test_task``.
    """
    api = wandb.Api()

    test_metric = {'int': "test_int_acc", 'tag': "test_tag conlleval f1"}
    y_title = {'int': "Test int (acc)", 'tag': "Test tag (f1)"}
    title = {'atis': f"Performance on ATIS", 'snips': f"Performance on SNIPS"}
    # The two tables below are currently not used by this function.
    batch = {'atis': {'ce': 16, 'vat': 64}, 'snips': {'ce': 64, 'vat': 64}}
    xticks = {
        'atis': [0.1, 0.2, 0.4],
        # 'snips' : [0.01,0.02,0.03,0.04,0.05,0.06,0.07,0.08,0.09,0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9,1.0],
        'snips': [0.02, 0.04, 0.06, 0.08, 0.1, 0.2, 0.4]
    }

    run_filter = {
        '$and': [{
            'config.training.dataset': cfg.dataset
        }, {
            'config.training.task': cfg.task
        }]
    }
    runs = api.runs("cetosignis/viraal-rerank-full", run_filter, per_page=1000)

    configs = ometrics.Metrics()
    summaries = ometrics.Metrics()
    print(len(runs))

    # Keep only the runs that actually report the requested metric.
    metric_key = test_metric[cfg.test_task]
    for run in runs:
        if metric_key not in run.summary:
            continue
        configs.append(run.config)
        summaries.append(run.summary)

    x = "Labeled part"
    y = y_title[cfg.test_task]
    filename = f'{cfg.dataset}_{cfg.task}_{cfg.test_task}'

    df = pd.DataFrame()
    df["Task"] = configs["training/task"]
    df["Dataset"] = cfg.dataset
    df["Loss"] = configs["training/loss"]
    df["Batch Size"] = configs["training/iterator/params/batch_size"]
    df["Criteria"] = [i[0] for i in configs["rerank/criteria"]]
    df["Loss+Criteria"] = df["Loss"] + "+" + df["Criteria"]
    df[x] = configs["training/unlabeler/params/labeled_part"]
    df[y] = summaries[metric_key]

    # Y range rounded outwards to two decimals; use a coarser tick step when
    # the fine one would produce 30 or more ticks.
    y_low = np.floor(df[y].min() * 100) / 100
    y_high = np.ceil(df[y].max() * 100) / 100
    yticks = np.arange(y_low, y_high, 0.01)
    if len(yticks) >= 30:
        yticks = np.arange(y_low, y_high, 0.05)

    cols = ['Dataset', 'Task', 'Loss', 'Batch Size', 'Criteria', x]
    df.groupby(cols).mean().to_csv(f'{filename}.csv')

    plt.figure()
    sns.set()
    sns.set_context("paper")
    sns.set_style("whitegrid")

    fig = sns.barplot(x=x,
                      y=y,
                      hue="Loss+Criteria",
                      data=df,
                      palette="Blues_d")
    plt.ylim(yticks[0], yticks[-1])
    plt.title(title[cfg.dataset])
    sns.despine(left=True, bottom=True)
    plt.tight_layout()
    plt.savefig(f"{filename}.png", dpi=300)
def api(runner):
    """Return a fresh W&B public-API client.

    Args:
        runner: Not used in the body; presumably requested only for its
            setup side effects (e.g. as a test fixture) -- confirm.
    """
    client = wandb.Api()
    return client
def main(argv):
    """Run the artifacts load test: spawn workers, collect stats, verify.

    Phases, in order:
      1. Temporarily point the W&B environment at the load-test entity,
         project and base URL, and move into a temporary directory.
      2. Generate source files and start writer, reader, deleter and
         (optionally) cache-GC worker processes.
      3. Restore the environment and working directory, then aggregate the
         workers' stats into a W&B run for ``args.test_phase_seconds``.
      4. Stop and join the workers; exit with status 1 if any failed.
      5. Verification: walk the project's runs and exit with status 1 if a
         logged artifact version is neither COMMITTED nor DELETED.

    Args:
        argv: Not read by the body; arguments come from the module-level
            ``parser`` via ``parser.parse_args()``.
    """
    args = parser.parse_args()
    print('Load test starting')

    project_name = args.project
    if project_name is None:
        # Derive a unique project name from the current timestamp.
        project_name = 'artifacts-load-test-%s' % str(datetime.now()).replace(
            ' ', '-').replace(':', '-').replace('.', '-')

    # Remember and clear W&B settings of the calling environment; they are
    # restored once the worker processes have been started.
    env_project = os.environ.get('WANDB_PROJECT')

    sweep_id = os.environ.get('WANDB_SWEEP_ID')
    if sweep_id:
        del os.environ['WANDB_SWEEP_ID']
    wandb_config_paths = os.environ.get('WANDB_CONFIG_PATHS')
    if wandb_config_paths:
        del os.environ['WANDB_CONFIG_PATHS']
    wandb_run_id = os.environ.get('WANDB_RUN_ID')
    if wandb_run_id:
        del os.environ['WANDB_RUN_ID']

    # set global entity and project before chdir'ing
    from wandb.apis import InternalApi
    api = InternalApi()
    settings_entity = api.settings('entity')
    settings_base_url = api.settings('base_url')
    os.environ['WANDB_ENTITY'] = (os.environ.get('LOAD_TEST_ENTITY')
                                  or settings_entity)
    os.environ['WANDB_PROJECT'] = project_name
    os.environ['WANDB_BASE_URL'] = (os.environ.get('LOAD_TEST_BASE_URL')
                                    or settings_base_url)

    # Change dir to avoid littering the code directory
    pwd = os.getcwd()
    tempdir = tempfile.TemporaryDirectory()
    os.chdir(tempdir.name)

    artifact_name = 'load-artifact-' + ''.join(
        random.choices(string.ascii_lowercase + string.digits, k=10))

    print('Generating source data')
    source_file_names = gen_files(args.gen_n_files, args.gen_max_small_size,
                                  args.gen_max_large_size)
    print('Done generating source data')

    procs = []
    # stop_queue receives one item per worker at shutdown; stats_queue
    # carries the stat updates that are summed below.
    stop_queue = multiprocessing.Queue()
    stats_queue = multiprocessing.Queue()

    # start all processes

    # writers
    for i in range(args.num_writers):
        file_names = source_file_names
        if args.non_overlapping_writers:
            # Give every writer its own disjoint slice of the source files.
            chunk_size = int(len(source_file_names) / args.num_writers)
            file_names = source_file_names[i * chunk_size:(i + 1) *
                                           chunk_size]
        p = multiprocessing.Process(
            target=proc_version_writer,
            args=(stop_queue, stats_queue, project_name, file_names,
                  artifact_name, args.files_per_version_min,
                  args.files_per_version_max))
        p.start()
        procs.append(p)

    # readers
    for i in range(args.num_readers):
        p = multiprocessing.Process(target=proc_version_reader,
                                    args=(stop_queue, stats_queue,
                                          project_name, artifact_name, i))
        p.start()
        procs.append(p)

    # deleters
    for _ in range(args.num_deleters):
        p = multiprocessing.Process(
            target=proc_version_deleter,
            args=(stop_queue, stats_queue, artifact_name,
                  args.min_versions_before_delete, args.delete_period_max))
        p.start()
        procs.append(p)

    # cache garbage collector
    if args.cache_gc_period_max is None:
        print('Testing cache GC process not enabled!')
    else:
        p = multiprocessing.Process(target=proc_cache_garbage_collector,
                                    args=(stop_queue,
                                          args.cache_gc_period_max))
        p.start()
        procs.append(p)

    # reset environment
    os.environ['WANDB_ENTITY'] = settings_entity
    os.environ['WANDB_BASE_URL'] = settings_base_url
    # NOTE(review): the next line is a bare expression with no effect.
    os.environ
    if env_project is None:
        del os.environ['WANDB_PROJECT']
    else:
        os.environ['WANDB_PROJECT'] = env_project
    if sweep_id:
        os.environ['WANDB_SWEEP_ID'] = sweep_id
    if wandb_config_paths:
        os.environ['WANDB_CONFIG_PATHS'] = wandb_config_paths
    if wandb_run_id:
        os.environ['WANDB_RUN_ID'] = wandb_run_id
    # go back to original dir
    os.chdir(pwd)

    # test phase
    start_time = time.time()
    stats = defaultdict(int)

    run = wandb.init(job_type='main-test-phase')
    run.config.update(args)
    while time.time() - start_time < args.test_phase_seconds:
        stat_update = None
        try:
            # NOTE(review): Queue.get takes its timeout in seconds, so this
            # can block for 5000 s and overshoot test_phase_seconds when no
            # stats arrive -- confirm the intended value.
            stat_update = stats_queue.get(True, 5000)
        except queue.Empty:
            pass
        print('** Test time: %s' % (time.time() - start_time))
        if stat_update:
            for k, v in stat_update.items():
                stats[k] += v
        wandb.log(stats)

    print('Test phase time expired')
    # stop all processes and wait til all are done
    for _ in procs:
        stop_queue.put(True)
    print('Waiting for processes to stop')
    fail = False
    for proc in procs:
        proc.join()
        if proc.exitcode != 0:
            print('FAIL! Test phase failed')
            fail = True
            # Exits on the first failed worker, so the `if fail` branch
            # further down is not reached with fail set.
            sys.exit(1)

    # drain remaining stats
    while True:
        try:
            stat_update = stats_queue.get_nowait()
        except queue.Empty:
            break
        for k, v in stat_update.items():
            stats[k] += v

    print('Stats')
    import pprint
    pprint.pprint(dict(stats))

    if fail:
        print('FAIL! Test phase failed')
        sys.exit(1)
    else:
        print('Test phase successfully completed')

    print('Starting verification phase')

    # Point the environment at the load-test target again for verification.
    os.environ['WANDB_ENTITY'] = (os.environ.get('LOAD_TEST_ENTITY')
                                  or settings_entity)
    os.environ['WANDB_PROJECT'] = project_name
    os.environ['WANDB_BASE_URL'] = (os.environ.get('LOAD_TEST_BASE_URL')
                                    or settings_base_url)
    data_api = wandb.Api()
    # we need to list artifacts by walking runs; accessing them via
    # project.artifactType.artifacts only returns committed artifacts
    for run in data_api.runs('%s/%s' % (api.settings('entity'), project_name)):
        for v in run.logged_artifacts():
            # TODO: allow deleted once we build deletion support
            if v.state not in ['COMMITTED', 'DELETED']:
                print('FAIL! Artifact version not committed or deleted: %s' % v)
                sys.exit(1)

    print('Verification succeeded')
(envs.action_space.nvec.sum(), )).to(device) # TRY NOT TO MODIFY: start the game global_step = 0 start_time = time.time() # Note how `next_obs` and `next_done` are used; their usage is equivalent to # https://github.com/ikostrikov/pytorch-a2c-ppo-acktr-gail/blob/84a7582477fb0d5c82ad6d850fe476829dddd2e1/a2c_ppo_acktr/storage.py#L60 next_obs = envs.reset() next_done = torch.zeros(args.num_envs).to(device) num_updates = args.total_timesteps // args.batch_size ## CRASH AND RESUME LOGIC: starting_update = 1 if args.prod_mode and wandb.run.resumed: print("previous run.summary", run.summary) starting_update = run.summary['charts/update'] + 1 global_step = starting_update * args.batch_size api = wandb.Api() run = api.run(run.get_url()[len("https://app.wandb.ai/"):]) model = run.file('agent.pt') model.download(f"models/{experiment_name}/") agent.load_state_dict(torch.load(f"models/{experiment_name}/agent.pt")) agent.eval() print(f"resumed at update {starting_update}") for update in range(starting_update, num_updates + 1): # Annealing the rate if instructed to do so. if args.anneal_lr: frac = 1.0 - (update - 1.0) / num_updates lrnow = lr(frac) optimizer.param_groups[0]['lr'] = lrnow # TRY NOT TO MODIFY: prepare the execution of the game. for step in range(0, args.num_steps):
def _track_row_indices(track):
    """Return zero-based row indices for one track entry.

    ``track[0]`` is iterated as strings (presumably image file names --
    confirm against the dataset class); the first six characters of each are
    parsed as a one-based integer id and shifted to a zero-based index.
    """
    return np.array([int(name[:6]) - 1 for name in track[0]])


def _track_means(values, tracks):
    """Return a list with the row-wise mean of ``values`` for every track."""
    return [np.mean(values[_track_row_indices(track)], axis=0)
            for track in tracks]


def _save_feats_and_scores(out_dir, suffix, feats, scores):
    """Save ``feats_<suffix>.npy`` and ``scores_<suffix>.npy`` in ``out_dir``."""
    np.save(os.path.join(out_dir, 'feats_' + suffix + '.npy'), np.array(feats))
    np.save(os.path.join(out_dir, 'scores_' + suffix + '.npy'), np.array(scores))


def _query_gallery_sq_distmat(embeddings, num_query, device):
    """Return the query-vs-gallery squared Euclidean distance matrix.

    Rows of ``embeddings`` are L2-normalised; the first ``num_query`` rows are
    the queries and the remaining rows the gallery.  The result is a NumPy
    array of shape ``(num_query_rows, num_gallery_rows)`` holding
    ``|q|^2 + |g|^2 - 2 * q.g`` for every pair.
    """
    normed = torch.nn.functional.normalize(
        torch.from_numpy(embeddings).float().to(device), dim=1, p=2)
    qf = normed[:num_query]
    gf = normed[num_query:]
    m, n = qf.shape[0], gf.shape[0]
    distmat = torch.pow(qf, 2).sum(dim=1, keepdim=True).expand(m, n) + \
        torch.pow(gf, 2).sum(dim=1, keepdim=True).expand(n, m).t()
    # Keyword form of addmm_: the positional (beta, alpha, mat1, mat2)
    # overload is deprecated.
    distmat.addmm_(qf, gf.t(), beta=1, alpha=-2)
    return distmat.cpu().numpy()


def main():
    """Evaluate a trained orientation model and export features and rankings.

    Steps, in order:
      1. Parse the wandb user / project / run id and copy that run's config.
      2. Build the dataset and data loader selected by the hard-coded
         ``traindata`` / ``is_synthetic`` / ``is_track`` switches.
      3. Load the epoch-10 checkpoint from ``cfg['OUTPUT_DIR']`` and run the
         evaluator to collect scores, features, pids and camids.
      4. Save the features and scores (track-averaged when ``is_track``) as
         ``.npy`` files in ``cfg['OUTPUT_DIR']``.
      5. For the features and for several column subsets of the scores,
         compute a query-vs-gallery distance matrix and hand it to
         ``generate_image_dir_and_txt``.
    """
    parser = argparse.ArgumentParser(description='Vehicle orientation')
    parser.add_argument('-u', '--user', help='username', default='corner')
    parser.add_argument('-p', '--project', help='project name',
                        default='cityai2020Orientation')
    parser.add_argument('-r', '--run_id', help='run id', default='pe5y029c')

    # Hard-coded mode switches selecting which split is processed.
    traindata = False
    is_synthetic = False
    is_track = True
    # Number of leading rows treated as queries; the remainder is the gallery.
    num_query = 1052

    args = parser.parse_args()

    print('wandb api')
    api = wandb.Api()
    run = api.run(args.user + '/' + args.project + '/' + args.run_id)

    print('copy wandb configs')
    cfg = copy.deepcopy(run.config)
    if cfg['MODEL.DEVICE'] == "cuda":
        # os.environ only accepts strings; the logged config value may be an int.
        os.environ['CUDA_VISIBLE_DEVICES'] = str(cfg['MODEL.DEVICE_ID'])
    cfg['DATASETS.TRACKS_FILE'] = '/net/merkur/storage/deeplearning/users/eckvik/WorkspaceNeu/VReID/ImageFolderStatistics/Track3Files'
    cudnn.benchmark = True

    print('load dataset')
    if is_synthetic and traindata:
        cfg['DATASETS.SYNTHETIC'] = True
        cfg['DATASETS.SYNTHETIC_LOADER'] = 0
        cfg['DATASETS.SYNTHETIC_DIR'] = 'ai_city_challenge/2020/Track2/AIC20_track2_reid_simulation/AIC20_track2/AIC20_ReID_Simulation'
        dataset_name = 'AI_CITY2020_TEST_VAL'
    elif is_track:
        dataset_name = 'AI_CITY2020_TRACKS'
    else:
        dataset_name = 'AI_CITY2020_TEST_VAL'
    dataset = init_dataset(dataset_name, cfg=cfg, fold=1, eval_mode=False)

    if not traindata:
        samples = dataset.query + dataset.gallery
    elif is_synthetic:
        # Keep only the entries from index 36935 onwards, sorted, and wrap
        # each one as [item, 0, 0].  ``dataset`` is rebound on purpose: the
        # rebuilt list is what is later passed to generate_image_dir_and_txt.
        tail = sorted([item[0] for item in dataset.train][36935:])
        dataset = [[item, 0, 0] for item in tail]
        samples = dataset
    else:
        samples = dataset.train
    val_set = ImageDatasetOrientation(samples, cfg, is_train=False, test=True)

    val_loader = DataLoader(
        val_set,
        batch_size=cfg['TEST.IMS_PER_BATCH'],
        shuffle=False,
        num_workers=cfg['DATALOADER.NUM_WORKERS'],
        collate_fn=val_collate_fn,
    )

    print('build model')
    model = build_regression_model(cfg)

    print('get last epoch')
    # Fixed checkpoint epoch; the run.summary['epoch'] lookup is disabled.
    epoch_best = 10
    weights_path = os.path.join(
        cfg['OUTPUT_DIR'],
        cfg['MODEL.NAME'] + '_model_' + str(epoch_best) + '.pth')

    print('load pretrained weights')
    model.load_param(weights_path)
    model.eval()
    evaluator = create_supervised_evaluator(
        model, metrics={'score_feat': Score_feats()},
        device=cfg['MODEL.DEVICE'])

    print('run')
    evaluator.run(val_loader)
    scores, feats, pids, camids = evaluator.state.metrics['score_feat']
    feats = np.array(feats)
    scores = np.array(scores)

    print('save')
    feats_out, scores_out = feats, scores
    if traindata:
        if is_track:
            tracks = dataset.train_tracks_vID
            feats_out = _track_means(feats, tracks)
            scores_out = _track_means(scores, tracks)
            suffix = 'train_track'
        elif is_synthetic:
            suffix = 'train_synthetic'
        else:
            suffix = 'train'
    elif is_track:
        # Query rows are kept as they are; gallery rows are replaced by one
        # mean row per test track.
        tracks = dataset.test_tracks_vID
        feats_out = (list(feats[:num_query])
                     + _track_means(feats[num_query:], tracks))
        scores_out = (list(scores[:num_query])
                      + _track_means(scores[num_query:], tracks))
        suffix = 'query_gal_track'
    else:
        suffix = 'query_gal'
    _save_feats_and_scores(cfg['OUTPUT_DIR'], suffix, feats_out, scores_out)

    print(cfg['OUTPUT_DIR'])
    print()

    txt_dir = 'dist_orient'
    # The same array is passed to every generate_image_dir_and_txt call.
    all_mAP = np.zeros(num_query)
    # Identical to the former hard-coded 'cuda' whenever CUDA is available;
    # otherwise fall back to the CPU instead of crashing.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # (statistic name, source array, column slice or None for all columns).
    # Slicing is done lazily inside the loop, one statistic at a time.
    statistics = (
        ('feats', feats, None),
        ('xyz', scores, None),
        ('xy', scores, slice(0, 2)),
        ('x', scores, slice(0, 1)),
        ('y', scores, slice(1, 2)),
        ('z', scores, slice(2, 3)),
    )
    for statistic_name, source, columns in statistics:
        embeddings = source if columns is None else source[:, columns]
        distmat = _query_gallery_sq_distmat(embeddings, num_query, device)

        # Id arrays are rebuilt for every statistic so each call gets fresh
        # inputs.
        q_pids = np.asarray(pids[:num_query])
        q_camids = np.asarray(camids[:num_query])
        # Gallery ids are overwritten with ones before being passed on.
        g_pids = np.ones_like(np.asarray(pids[num_query:]))
        g_camids = np.ones_like(np.asarray(camids[num_query:]))

        generate_image_dir_and_txt(cfg, dataset, txt_dir, distmat, g_pids,
                                   q_pids, g_camids, q_camids, all_mAP,
                                   statistic_name=statistic_name, max_rank=100)