# the goal of this script is to download and save wandb stats
# stolen straight from https://docs.wandb.ai/library/public-api-guide
import wandb
import numpy as np
import pandas as pd
import os
import shutil
import torch
from modelhandling import load_model_from_disk
import contextlib
from tempfile import mkdtemp

__api__ = wandb.Api()
__runs__ = __api__.runs("sebaseliens/explainable-asag")


def get_run_ids(*groups):
    return [
        run.id for run in __runs__
        if run.config['group'] in groups or not groups
    ]


def get_runs(*groups):
    return [
        run for run in __runs__ if run.config['group'] in groups or not groups
    ]
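

# A minimal usage sketch of the stated goal (download and save wandb
# stats); assumes each run has logged history rows. This helper is not
# part of the original script.
def save_run_histories(out_dir, *groups):
    os.makedirs(out_dir, exist_ok=True)
    for run in get_runs(*groups):
        pd.DataFrame(run.history()).to_csv(
            os.path.join(out_dir, run.id + '.csv'), index=False)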


def as_run(run):
    # completed from a truncated snippet (assumed intent): resolve a run
    # id string to its run object
    if isinstance(run, str):
        run = __api__.run('sebaseliens/explainable-asag/' + run)
    return run
Example #2
from typing import Dict, List


def fetch_history(run_id: str) -> List[Dict]:
    return wandb.Api().run('sash-a/cdn_test/' + run_id).history()
Example #3
def main():

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    seed_everything(7)

    args = parse_args()

    Path(args.save_path).mkdir(parents=True, exist_ok=True)
    entity = "demiurge"
    project = "melgan"
    load_from_run_id = args.load_from_run_id
    resume_run_id = args.resume_run_id
    restore_run_id = load_from_run_id or resume_run_id
    batch_size = args.batch_size

    # Getting initial run steps and epoch
    # if restore run, replace args
    steps = None
    if restore_run_id:
        api = wandb.Api()
        previous_run = api.run(f"{entity}/{project}/{restore_run_id}")
        steps = previous_run.lastHistoryStep
        prev_args = argparse.Namespace(**previous_run.config)
        args = vars(args)
        args.update(vars(prev_args))
        args = argparse.Namespace(**args)
        args.batch_size = batch_size

    load_initial_weights = bool(restore_run_id)
    sampling_rate = args.sampling_rate
    ratios = args.ratios
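    # ratios may come back from a restored W&B config as a string such as
    # "[4, 4, 2, 2]"; parse it back into a numpy array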
    if isinstance(ratios, str):
        ratios = ratios.replace(" ", "")
        ratios = ratios.strip("][").split(",")
        ratios = [int(i) for i in ratios]
        ratios = np.array(ratios)

    if load_from_run_id and resume_run_id:
        raise RuntimeError("Specify either --load_from_run_id or --resume_run_id.")

    if resume_run_id:
        print(f"Resuming run ID {resume_run_id}.")
    elif load_from_run_id:
        print(
            f"Starting new run with initial weights from run ID {load_from_run_id}."
        )
    else:
        print("Starting new run from scratch.")

    # read one line from the train files to log the dataset location
    train_files = Path(args.data_path) / "train_files.txt"
    with open(train_files, encoding="utf-8", mode="r") as f:
        file = f.readline()
    args.train_file_sample = str(file)

    wandb.init(
        entity=entity,
        project=project,
        id=resume_run_id,
        config=args,
        resume=bool(resume_run_id),
        save_code=True,
        dir=args.save_path,
        notes=args.notes,
    )

    print("run id: " + str(wandb.run.id))
    print("run name: " + str(wandb.run.name))

    root = Path(wandb.run.dir)
    root.mkdir(parents=True, exist_ok=True)

    ####################################
    # Dump arguments and create logger #
    ####################################
    with open(root / "args.yml", "w") as f:
        yaml.dump(args, f)
    wandb.save("args.yml")

    ###############################################
    # The file modules.py is needed by the unagan #
    ###############################################
    wandb.save(mel2wav.modules.__file__, base_path=".")

    #######################
    # Load PyTorch Models #
    #######################

    netG = Generator(args.n_mel_channels,
                     args.ngf,
                     args.n_residual_layers,
                     ratios=ratios).to(device)
    netD = Discriminator(args.num_D, args.ndf, args.n_layers_D,
                         args.downsamp_factor).to(device)
    fft = Audio2Mel(
        n_mel_channels=args.n_mel_channels,
        pad_mode=args.pad_mode,
        sampling_rate=sampling_rate,
    ).to(device)

    for model in [netG, netD, fft]:
        wandb.watch(model)

    #####################
    # Create optimizers #
    #####################
    optG = torch.optim.Adam(netG.parameters(),
                            lr=args.learning_rate,
                            betas=(0.5, 0.9))
    optD = torch.optim.Adam(netD.parameters(),
                            lr=args.learning_rate,
                            betas=(0.5, 0.9))

    if load_initial_weights:

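        # try the live checkpoint name first, then the "*_prev.pt" copy
        # kept from an earlier successful restore (see the rename below)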
        for model, filenames in [
            (netG, ["netG.pt", "netG_prev.pt"]),
            (optG, ["optG.pt", "optG_prev.pt"]),
            (netD, ["netD.pt", "netD_prev.pt"]),
            (optD, ["optD.pt", "optD_prev.pt"]),
        ]:
            recover_model = False
            filepath = None
            for filename in filenames:
                try:
                    run_path = f"{entity}/{project}/{restore_run_id}"
                    print(f"Restoring {filename} from run path {run_path}")
                    restored_file = wandb.restore(filename, run_path=run_path)
                    filepath = restored_file.name
                    model = load_state_dict_handleDP(model, filepath)
                    recover_model = True
                    break
                except RuntimeError as e:
                    print("RuntimeError", e)
                    print(f"recover model weight file: '{filename}'' failed")
            if not recover_model:
                raise RuntimeError(
                    f"Cannot load model weight files for component {filenames[0]}."
                )
            else:
                # store successfully recovered model weight file ("***_prev.pt")
                path_parent = Path(filepath).parent
                newfilepath = str(path_parent / filenames[1])
                os.rename(filepath, newfilepath)
                wandb.save(newfilepath)
    if torch.cuda.device_count() > 1:
        netG = DP(netG).to(device)
        netD = DP(netD).to(device)
        fft = DP(fft).to(device)
        print(f"We have {torch.cuda.device_count()} gpus. Use data parallel.")
    else:
        print(f"We have {torch.cuda.device_count()} gpu.")

    #######################
    # Create data loaders #
    #######################
    train_set = AudioDataset(
        Path(args.data_path) / "train_files.txt",
        args.seq_len,
        sampling_rate=sampling_rate,
    )
    test_set = AudioDataset(
        Path(args.data_path) / "test_files.txt",
        sampling_rate * 4,
        sampling_rate=sampling_rate,
        augment=False,
    )
    wandb.save(str(Path(args.data_path) / "train_files.txt"))
    wandb.save(str(Path(args.data_path) / "test_files.txt"))

    train_loader = DataLoader(train_set,
                              batch_size=args.batch_size,
                              num_workers=4)
    test_loader = DataLoader(test_set, batch_size=1)

    if len(train_loader) == 0:
        raise RuntimeError("Train dataset is empty.")

    if len(test_loader) == 0:
        raise RuntimeError("Test dataset is empty.")

    if not restore_run_id:
        steps = wandb.run.step
    start_epoch = steps // len(train_loader)
    print(f"Starting with epoch {start_epoch} and step {steps}.")

    ##########################
    # Dumping original audio #
    ##########################
    test_voc = []
    test_audio = []
    samples = []
    melImages = []
    num_fix_samples = args.n_test_samples - (args.n_test_samples // 2)
    cmap = cm.get_cmap("inferno")
    for i, x_t in enumerate(test_loader):
        x_t = x_t.to(device)
        s_t = fft(x_t).detach()

        test_voc.append(s_t.to(device))
        test_audio.append(x_t)

        audio = x_t.squeeze().cpu()
        save_sample(root / ("original_%d.wav" % i), sampling_rate, audio)
        samples.append(
            wandb.Audio(audio,
                        caption=f"sample {i}",
                        sample_rate=sampling_rate))
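        # min-max normalise the mel spectrogram to [0, 1] before applying
        # the colormap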
        melImage = s_t.squeeze().detach().cpu().numpy()
        melImage = (melImage - np.amin(melImage)) / (np.amax(melImage) -
                                                     np.amin(melImage))
        # melImage = Image.fromarray(np.uint8(cmap(melImage)) * 255)
        # melImage = melImage.resize((melImage.width * 4, melImage.height * 4))
        melImages.append(wandb.Image(cmap(melImage), caption=f"sample {i}"))

        if i == num_fix_samples - 1:
            break

    # if not resume_run_id:
    wandb.log({"audio/original": samples}, step=start_epoch)
    wandb.log({"mel/original": melImages}, step=start_epoch)
    # else:
    #     print("We are resuming, skipping logging of original audio.")

    costs = []
    start = time.time()

    # enable cudnn autotuner to speed up training
    torch.backends.cudnn.benchmark = True

    best_mel_reconst = 1000000

    for epoch in range(start_epoch, start_epoch + args.epochs + 1):
        for iterno, x_t in enumerate(train_loader):
            x_t = x_t.to(device)
            s_t = fft(x_t).detach()
            x_pred_t = netG(s_t.to(device))

            with torch.no_grad():
                s_pred_t = fft(x_pred_t.detach())
                s_error = F.l1_loss(s_t, s_pred_t).item()

            #######################
            # Train Discriminator #
            #######################
            D_fake_det = netD(x_pred_t.to(device).detach())
            D_real = netD(x_t.to(device))

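            # hinge loss: push D's output above +1 on real audio and below
            # -1 on generated audio at every scale (scale[-1] is the final
            # output of each sub-discriminator)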
            loss_D = 0
            for scale in D_fake_det:
                loss_D += F.relu(1 + scale[-1]).mean()

            for scale in D_real:
                loss_D += F.relu(1 - scale[-1]).mean()

            netD.zero_grad()
            loss_D.backward()
            optD.step()

            ###################
            # Train Generator #
            ###################
            D_fake = netD(x_pred_t.to(device))

            loss_G = 0
            for scale in D_fake:
                loss_G += -scale[-1].mean()

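            # feature-matching loss: L1 distance between D's intermediate
            # feature maps on generated and real audio, weighted so the sum
            # is comparable across discriminators and layers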
            loss_feat = 0
            feat_weights = 4.0 / (args.n_layers_D + 1)
            D_weights = 1.0 / args.num_D
            wt = D_weights * feat_weights
            for i in range(args.num_D):
                for j in range(len(D_fake[i]) - 1):
                    loss_feat += wt * F.l1_loss(D_fake[i][j],
                                                D_real[i][j].detach())

            netG.zero_grad()
            (loss_G + args.lambda_feat * loss_feat).backward()
            optG.step()

            costs.append(
                [loss_D.item(),
                 loss_G.item(),
                 loss_feat.item(), s_error])

            wandb.log(
                {
                    "loss/discriminator": costs[-1][0],
                    "loss/generator": costs[-1][1],
                    "loss/feature_matching": costs[-1][2],
                    "loss/mel_reconstruction": costs[-1][3],
                },
                step=steps,
            )
            steps += 1

            if steps % args.save_interval == 0:
                st = time.time()
                with torch.no_grad():
                    samples = []
                    melImages = []
                    # fix samples
                    for i, (voc, _) in enumerate(zip(test_voc, test_audio)):
                        pred_audio = netG(voc)
                        pred_audio = pred_audio.squeeze().cpu()
                        save_sample(root / ("generated_%d.wav" % i),
                                    sampling_rate, pred_audio)
                        samples.append(
                            wandb.Audio(
                                pred_audio,
                                caption=f"sample {i}",
                                sample_rate=sampling_rate,
                            ))
                        melImage = voc.squeeze().detach().cpu().numpy()
                        melImage = (melImage - np.amin(melImage)) / (
                            np.amax(melImage) - np.amin(melImage))
                        # melImage = Image.fromarray(np.uint8(cmap(melImage)) * 255)
                        # melImage = melImage.resize(
                        #     (melImage.width * 4, melImage.height * 4)
                        # )
                        melImages.append(
                            wandb.Image(cmap(melImage), caption=f"sample {i}"))
                    wandb.log(
                        {
                            "audio/generated": samples,
                            "mel/generated": melImages,
                            "epoch": epoch,
                        },
                        step=steps,
                    )

                    # var samples
                    source = []
                    pred = []
                    pred_mel = []
                    num_var_samples = args.n_test_samples - num_fix_samples
                    for i, x_t in enumerate(test_loader):
                        # source
                        x_t = x_t.to(device)
                        audio = x_t.squeeze().cpu()
                        source.append(
                            wandb.Audio(audio,
                                        caption=f"sample {i}",
                                        sample_rate=sampling_rate))
                        # pred
                        s_t = fft(x_t).detach()
                        voc = s_t.to(device)
                        pred_audio = netG(voc)
                        pred_audio = pred_audio.squeeze().cpu()
                        pred.append(
                            wandb.Audio(
                                pred_audio,
                                caption=f"sample {i}",
                                sample_rate=sampling_rate,
                            ))
                        melImage = voc.squeeze().detach().cpu().numpy()
                        melImage = (melImage - np.amin(melImage)) / (
                            np.amax(melImage) - np.amin(melImage))
                        # melImage = Image.fromarray(np.uint8(cmap(melImage)) * 255)
                        # melImage = melImage.resize(
                        #     (melImage.width * 4, melImage.height * 4)
                        # )
                        pred_mel.append(
                            wandb.Image(cmap(melImage), caption=f"sample {i}"))

                        # stop when reach log sample
                        if i == num_var_samples - 1:
                            break

                    wandb.log(
                        {
                            "audio/var_original": source,
                            "audio/var_generated": pred,
                            "mel/var_generated": pred_mel,
                        },
                        step=steps,
                    )

                print("Saving models ...")
                torch.save(netG.state_dict(), root / "netG.pt")
                torch.save(optG.state_dict(), root / "optG.pt")
                wandb.save(str(root / "netG.pt"))
                wandb.save(str(root / "optG.pt"))

                torch.save(netD.state_dict(), root / "netD.pt")
                torch.save(optD.state_dict(), root / "optD.pt")
                wandb.save(str(root / "netD.pt"))
                wandb.save(str(root / "optD.pt"))

                if np.asarray(costs).mean(0)[-1] < best_mel_reconst:
                    best_mel_reconst = np.asarray(costs).mean(0)[-1]
                    torch.save(netD.state_dict(), root / "best_netD.pt")
                    torch.save(netG.state_dict(), root / "best_netG.pt")
                    wandb.save(str(root / "best_netD.pt"))
                    wandb.save(str(root / "best_netG.pt"))

                print("Took %5.4fs to generate samples" % (time.time() - st))
                print("-" * 100)

            if steps % args.log_interval == 0:
                print("Epoch {} | Iters {} / {} | ms/batch {:5.2f} | loss {}".
                      format(
                          epoch,
                          iterno,
                          len(train_loader),
                          1000 * (time.time() - start) / args.log_interval,
                          np.asarray(costs).mean(0),
                      ))
                costs = []
                start = time.time()
Example #4
def print_metrics(model_name, runs_all, validate_hcp=False):
    metrics_ukb = {
        'f1': [],
        'acc': [],
        'auc': [],
        'sensitivity': [],
        'specificity': []
    }
    metrics_hcp = {
        'f1': [],
        'acc': [],
        'auc': [],
        'sensitivity': [],
        'specificity': []
    }
    for fold_num, run_info in runs_all.items():
        run_id = run_info['run_id']
        # print('Args are', run_id, device_run, dropout, weight_d)

        api = wandb.Api()
        best_run = api.run(f'/st-team/spatio-temporal-brain/runs/{run_id}')
        # w_config = best_run.config
        for metric in metrics_ukb.keys():
            metrics_ukb[metric].append(
                best_run.summary[f'values_test_{metric}'])

        # Running for HCP
        w_config = best_run.config

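        # the public API returns config entries as plain strings/numbers;
        # wrap them back into the project's typed enums before rebuilding
        # the model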
        w_config['analysis_type'] = AnalysisType(w_config['analysis_type'])
        w_config['dataset_type'] = DatasetType(w_config['dataset_type'])
        w_config['device_run'] = DEVICE_RUN
        if 'lr' not in run_info.keys():
            w_config['param_lr'] = w_config['lr']
        else:
            w_config['param_lr'] = float(run_info['lr'])
        w_config['model_with_sigmoid'] = True
        w_config['param_activation'] = w_config['activation']
        w_config['param_channels_conv'] = w_config['channels_conv']
        w_config['param_conn_type'] = ConnType(w_config['conn_type'])
        w_config['param_conv_strategy'] = ConvStrategy(
            w_config['conv_strategy'])
        if 'dropout' not in run_info.keys():
            w_config['param_dropout'] = w_config['dropout']
        else:
            w_config['param_dropout'] = float(run_info['dropout'])
        w_config['param_encoding_strategy'] = EncodingStrategy(
            w_config['encoding_strategy'])
        w_config['param_normalisation'] = Normalisation(
            w_config['normalisation'])
        w_config['param_num_gnn_layers'] = w_config['num_gnn_layers']
        w_config['param_pooling'] = PoolingStrategy(w_config['pooling'])
        if 'weight_d' not in run_info.keys():
            w_config['param_weight_decay'] = w_config['weight_decay']
        else:
            w_config['param_weight_decay'] = float(run_info['weight_d'])

        w_config['sweep_type'] = SweepType(w_config['sweep_type'])
        w_config['param_gat_heads'] = 0
        if w_config['sweep_type'] == SweepType.GAT:
            w_config['param_gat_heads'] = w_config['gat_heads']

        if w_config['analysis_type'] == AnalysisType.ST_MULTIMODAL:
            w_config['multimodal_size'] = 10
        elif w_config['analysis_type'] == AnalysisType.ST_UNIMODAL:
            w_config['multimodal_size'] = 0

        if w_config['target_var'] in ['age', 'bmi']:
            w_config['model_with_sigmoid'] = False

        # Getting best model
        inner_fold_for_val: int = 1
        model: SpatioTemporalModel = generate_st_model(w_config, for_test=True)
        if 'model_v' in run_info.keys():
            model.VERSION = run_info['model_v']
        model_saving_path: str = create_name_for_model(
            target_var=w_config['target_var'],
            model=model,
            outer_split_num=w_config['fold_num'],
            inner_split_num=inner_fold_for_val,
            n_epochs=w_config['num_epochs'],
            threshold=w_config['threshold'],
            batch_size=w_config['batch_size'],
            num_nodes=w_config['num_nodes'],
            conn_type=w_config['param_conn_type'],
            normalisation=w_config['param_normalisation'],
            analysis_type=w_config['analysis_type'],
            metric_evaluated='loss',
            dataset_type=w_config['dataset_type'],
            lr=w_config['param_lr'],
            weight_decay=w_config['param_weight_decay'],
            edge_weights=w_config['edge_weights'])
        if 'model_v' in run_info.keys():
            # We know the very specific "old" cases
            if w_config['param_pooling'] == PoolingStrategy.DIFFPOOL:
                model_saving_path = model_saving_path.replace(
                    'T_difW_F', 'GC_FGA_F')
            elif w_config['param_pooling'] == PoolingStrategy.MEAN:
                model_saving_path = model_saving_path.replace(
                    'T_no_W_F', 'GC_FGA_F')
        model.load_state_dict(
            torch.load(model_saving_path, map_location=w_config['device_run']))
        model.eval()
        if not validate_hcp:
            continue
        else:
            # Getting HCP Data
            name_dataset = create_name_for_brain_dataset(
                num_nodes=68,
                time_length=1200,
                target_var='gender',
                threshold=w_config['threshold'],
                normalisation=w_config['param_normalisation'],
                connectivity_type=w_config['param_conn_type'],
                analysis_type=w_config['analysis_type'],
                encoding_strategy=w_config['param_encoding_strategy'],
                dataset_type=DatasetType('hcp'),
                edge_weights=w_config['edge_weights'])
            print('Going with', name_dataset)
            dataset = HCPDataset(
                root=name_dataset,
                target_var='gender',
                num_nodes=68,
                threshold=w_config['threshold'],
                connectivity_type=w_config['param_conn_type'],
                normalisation=w_config['param_normalisation'],
                analysis_type=w_config['analysis_type'],
                encoding_strategy=w_config['param_encoding_strategy'],
                time_length=1200,
                edge_weights=w_config['edge_weights'])

            # dataset.data is private, might change in future versions of pyg...
            dataset.data.x = dataset.data.x[:, :490]

            test_out_loader = DataLoader(dataset,
                                         batch_size=w_config['batch_size'],
                                         shuffle=False)
            test_metrics = evaluate_model(model, test_out_loader,
                                          w_config['param_pooling'],
                                          w_config['device_run'])
            for metric in metrics_hcp.keys():
                metrics_hcp[metric].append(test_metrics[metric])

    # print('UKB:')
    print(model_name, end=' & ')
    print(
        f'{round(np.mean(metrics_ukb["auc"]), 2)} ({round(np.std(metrics_ukb["auc"]), 3)}) & '
        f'{round(np.mean(metrics_ukb["acc"]), 2)} ({round(np.std(metrics_ukb["acc"]), 3)}) & '
        f'{round(np.mean(metrics_ukb["sensitivity"]), 2)} ({round(np.std(metrics_ukb["sensitivity"]), 3)}) & '
        f'{round(np.mean(metrics_ukb["specificity"]), 2)} ({round(np.std(metrics_ukb["specificity"]), 3)})'
    )

    if validate_hcp:
        print('HCP:')
        print(
            f'{round(np.mean(metrics_hcp["auc"]), 2)} ({round(np.std(metrics_hcp["auc"]), 3)}) & '
            f'{round(np.mean(metrics_hcp["acc"]), 2)} ({round(np.std(metrics_hcp["acc"]), 3)}) & '
            f'{round(np.mean(metrics_hcp["sensitivity"]), 2)} ({round(np.std(metrics_hcp["sensitivity"]), 3)}) & '
            f'{round(np.mean(metrics_hcp["specificity"]), 2)} ({round(np.std(metrics_hcp["specificity"]), 3)})'
        )
Example #5
def _get_run(run_id):
    run_path = f"{WANDB_ENTITY}/{WANDB_PROJECT}/{run_id}"
    api = wandb.Api()
    return api.run(run_path)
Example #6
def check_run(api: Api) -> bool:
    print("Checking logged metrics, saving and downloading a file".ljust(
        72, "."),
          end="")
    failed_test_strings = []

    # set up config
    n_epochs = 4
    string_test = "A test config"
    dict_test = {"config_val": 2, "config_string": "config string"}
    list_test = [0, "one", "2"]
    config = {
        "epochs": n_epochs,
        "stringTest": string_test,
        "dictTest": dict_test,
        "listTest": list_test,
    }
    # create a file to save
    filepath = "./test with_special-characters.txt"
    with open(filepath, "w") as f:
        f.write("test")

    with wandb.init(reinit=True, config=config, project=PROJECT_NAME) as run:
        run_id = run.id
        entity = run.entity
        logged = True
        try:
            for i in range(1, 11):
                run.log({"loss": 1.0 / i}, step=i)
            log_dict = {"val1": 1.0, "val2": 2}
            run.log({"dict": log_dict}, step=i + 1)
        except Exception:
            logged = False
            failed_test_strings.append(
                "Failed to log values to run. Contact W&B for support.")

        try:
            run.log(
                {"HT%3ML ": wandb.Html('<a href="https://mysite">Link</a>')})
        except Exception:
            failed_test_strings.append(
                "Failed to log to media. Contact W&B for support.")

        wandb.save(filepath)
    public_api = wandb.Api()
    prev_run = public_api.run("{}/{}/{}".format(entity, PROJECT_NAME, run_id))
    if prev_run is None:
        failed_test_strings.append(
            "Failed to access run through API. Contact W&B for support.")
        print_results(failed_test_strings, False)
        return False
    for key, value in prev_run.config.items():
        if config[key] != value:
            failed_test_strings.append(
                "Read config values don't match run config. Contact W&B for support."
            )
            break
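    # context for the checks below: the loop above logged loss = 1.0 / i at
    # steps 1..10 (so the latest loss is 0.1) and the dict at step 11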
    if logged and (
            prev_run.history_keys["keys"]["loss"]["previousValue"] != 0.1
            or prev_run.history_keys["lastStep"] != 11 or
            prev_run.history_keys["keys"]["dict.val1"]["previousValue"] != 1.0
            or
            prev_run.history_keys["keys"]["dict.val2"]["previousValue"] != 2):
        failed_test_strings.append(
            "History metrics don't match logged values. Check database encoding."
        )

    if logged and prev_run.summary["loss"] != 1.0 / 10:
        failed_test_strings.append(
            "Read summary values don't match expected value. Check database encoding, or contact W&B for support."
        )
    # TODO: (kdg) refactor this so it doesn't rely on an exception handler
    try:
        read_file = retry_fn(partial(prev_run.file, filepath))
        read_file = read_file.download(replace=True)
    except Exception:
        with wandb.init(reinit=True,
                        project=PROJECT_NAME,
                        config={"test": "test direct saving"}) as run:
            saved, status_code, _ = try_manual_save(api, filepath, run.id,
                                                    run.entity)
            if saved:
                failed_test_strings.append(
                    "Unable to download file. Check SQS configuration, topic configuration and bucket permissions."
                )
            else:
                failed_test_strings.append(
                    "Unable to save file with status code: {}. Check SQS configuration and bucket permissions."
                    .format(status_code))

            print_results(failed_test_strings, False)
        return False
    contents = read_file.read()
    if contents != "test":
        failed_test_strings.append(
            "Contents of downloaded file do not match uploaded contents. Contact W&B for support."
        )
    print_results(failed_test_strings, False)
    return len(failed_test_strings) == 0
Example #7
def main():

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    seed_everything(7)

    args = parse_args()

    Path(args.save_path).mkdir(parents=True, exist_ok=True)
    entity = "materialvision"
    project = "melganmv"
    load_from_run_id = args.load_from_run_id
    resume_run_id = args.resume_run_id
    restore_run_id = load_from_run_id or resume_run_id
    load_initial_weights = bool(restore_run_id)
    sampling_rate = args.sampling_rate

    if load_from_run_id and resume_run_id:
        raise RuntimeError("Specify either --load_from_run_id or --resume_run_id.")

    if resume_run_id:
        print(f"Resuming run ID {resume_run_id}.")
    elif load_from_run_id:
        print(
            f"Starting new run with initial weights from run ID {load_from_run_id}."
        )
    else:
        print("Starting new run from scratch.")

    wandb.init(
        entity=entity,
        project=project,
        id=resume_run_id,
        config=args,
        resume=bool(resume_run_id),
        save_code=True,
        dir=args.save_path,
        notes=args.notes,
    )

    print("run id: " + str(wandb.run.id))
    print("run name: " + str(wandb.run.name))

    root = Path(wandb.run.dir)
    root.mkdir(parents=True, exist_ok=True)

    ####################################
    # Dump arguments and create logger #
    ####################################
    with open(root / "args.yml", "w") as f:
        yaml.dump(args, f)
    wandb.save("args.yml")

    ###############################################
    # The file modules.py is needed by the unagan #
    ###############################################
    wandb.save(mel2wav.modules.__file__, base_path=".")

    #######################
    # Load PyTorch Models #
    #######################
    netG = Generator(args.n_mel_channels, args.ngf,
                     args.n_residual_layers).to(device)
    netD = Discriminator(args.num_D, args.ndf, args.n_layers_D,
                         args.downsamp_factor).to(device)
    fft = Audio2Mel(
        n_mel_channels=args.n_mel_channels,
        pad_mode=args.pad_mode,
        sampling_rate=sampling_rate,
    ).to(device)

    for model in [netG, netD, fft]:
        wandb.watch(model)

    #####################
    # Create optimizers #
    #####################
    optG = torch.optim.Adam(netG.parameters(),
                            lr=args.learning_rate,
                            betas=(0.5, 0.9))
    optD = torch.optim.Adam(netD.parameters(),
                            lr=args.learning_rate,
                            betas=(0.5, 0.9))

    if load_initial_weights:

        for obj, filename in [
            (netG, "netG.pt"),
            (optG, "optG.pt"),
            (netD, "netD.pt"),
            (optD, "optD.pt"),
        ]:
            run_path = f"{entity}/{project}/{restore_run_id}"
            print(f"Restoring {filename} from run path {run_path}")
            restored_file = wandb.restore(filename, run_path=run_path)
            obj.load_state_dict(torch.load(restored_file.name))

    #######################
    # Create data loaders #
    #######################
    train_set = AudioDataset(
        Path(args.data_path) / "train_files.txt",
        args.seq_len,
        sampling_rate=sampling_rate,
    )
    test_set = AudioDataset(
        Path(args.data_path) / "test_files.txt",
        sampling_rate * 4,
        sampling_rate=sampling_rate,
        augment=False,
    )
    wandb.save(str(Path(args.data_path) / "train_files.txt"))
    wandb.save(str(Path(args.data_path) / "test_files.txt"))

    train_loader = DataLoader(train_set,
                              batch_size=args.batch_size,
                              num_workers=4)
    test_loader = DataLoader(test_set, batch_size=1)

    if len(train_loader) == 0:
        raise RuntimeError("Train dataset is empty.")

    if len(test_loader) == 0:
        raise RuntimeError("Test dataset is empty.")

    # Getting initial run steps and epoch

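    # a new run seeded from another run's weights continues that run's step
    # count; a resumed run already has its last step in wandb.run.step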
    if load_from_run_id:
        api = wandb.Api()
        previous_run = api.run(f"{entity}/{project}/{restore_run_id}")
        steps = previous_run.lastHistoryStep
    else:
        steps = wandb.run.step

    start_epoch = steps // len(train_loader)
    print(f"Starting with epoch {start_epoch} and step {steps}.")

    ##########################
    # Dumping original audio #
    ##########################
    test_voc = []
    test_audio = []
    samples = []
    for i, x_t in enumerate(test_loader):
        x_t = x_t.to(device)
        s_t = fft(x_t).detach()

        test_voc.append(s_t.to(device))
        test_audio.append(x_t)

        audio = x_t.squeeze().cpu()
        save_sample(root / ("original_%d.wav" % i), sampling_rate, audio)
        samples.append(
            wandb.Audio(audio,
                        caption=f"sample {i}",
                        sample_rate=sampling_rate))

        if i == args.n_test_samples - 1:
            break

    if not resume_run_id:
        wandb.log({"audio/original": samples}, step=0)
    else:
        print("We are resuming, skipping logging of original audio.")

    costs = []
    start = time.time()

    # enable cudnn autotuner to speed up training
    torch.backends.cudnn.benchmark = True

    best_mel_reconst = 1000000

    for epoch in range(start_epoch, start_epoch + args.epochs + 1):
        for iterno, x_t in enumerate(train_loader):
            x_t = x_t.to(device)
            s_t = fft(x_t).detach()
            x_pred_t = netG(s_t.to(device))

            with torch.no_grad():
                s_pred_t = fft(x_pred_t.detach())
                s_error = F.l1_loss(s_t, s_pred_t).item()

            #######################
            # Train Discriminator #
            #######################
            D_fake_det = netD(x_pred_t.to(device).detach())
            D_real = netD(x_t.to(device))

            loss_D = 0
            for scale in D_fake_det:
                loss_D += F.relu(1 + scale[-1]).mean()

            for scale in D_real:
                loss_D += F.relu(1 - scale[-1]).mean()

            netD.zero_grad()
            loss_D.backward()
            optD.step()

            ###################
            # Train Generator #
            ###################
            D_fake = netD(x_pred_t.to(device))

            loss_G = 0
            for scale in D_fake:
                loss_G += -scale[-1].mean()

            loss_feat = 0
            feat_weights = 4.0 / (args.n_layers_D + 1)
            D_weights = 1.0 / args.num_D
            wt = D_weights * feat_weights
            for i in range(args.num_D):
                for j in range(len(D_fake[i]) - 1):
                    loss_feat += wt * F.l1_loss(D_fake[i][j],
                                                D_real[i][j].detach())

            netG.zero_grad()
            (loss_G + args.lambda_feat * loss_feat).backward()
            optG.step()

            costs.append(
                [loss_D.item(),
                 loss_G.item(),
                 loss_feat.item(), s_error])

            wandb.log(
                {
                    "loss/discriminator": costs[-1][0],
                    "loss/generator": costs[-1][1],
                    "loss/feature_matching": costs[-1][2],
                    "loss/mel_reconstruction": costs[-1][3],
                },
                step=steps,
            )
            steps += 1

            if steps % args.save_interval == 0:
                st = time.time()
                with torch.no_grad():
                    samples = []
                    for i, (voc, _) in enumerate(zip(test_voc, test_audio)):
                        pred_audio = netG(voc)
                        pred_audio = pred_audio.squeeze().cpu()
                        save_sample(root / ("generated_%d.wav" % i),
                                    sampling_rate, pred_audio)
                        samples.append(
                            wandb.Audio(
                                pred_audio,
                                caption=f"sample {i}",
                                sample_rate=sampling_rate,
                            ))
                    wandb.log(
                        {
                            "audio/generated": samples,
                            "epoch": epoch,
                        },
                        step=steps,
                    )

                print("Saving models ...")
                torch.save(netG.state_dict(), root / "netG.pt")
                torch.save(optG.state_dict(), root / "optG.pt")
                wandb.save(str(root / "netG.pt"))
                wandb.save(str(root / "optG.pt"))

                torch.save(netD.state_dict(), root / "netD.pt")
                torch.save(optD.state_dict(), root / "optD.pt")
                wandb.save(str(root / "netD.pt"))
                wandb.save(str(root / "optD.pt"))

                if np.asarray(costs).mean(0)[-1] < best_mel_reconst:
                    best_mel_reconst = np.asarray(costs).mean(0)[-1]
                    torch.save(netD.state_dict(), root / "best_netD.pt")
                    torch.save(netG.state_dict(), root / "best_netG.pt")
                    wandb.save(str(root / "best_netD.pt"))
                    wandb.save(str(root / "best_netG.pt"))

                print("Took %5.4fs to generate samples" % (time.time() - st))
                print("-" * 100)

            if steps % args.log_interval == 0:
                print("Epoch {} | Iters {} / {} | ms/batch {:5.2f} | loss {}".
                      format(
                          epoch,
                          iterno,
                          len(train_loader),
                          1000 * (time.time() - start) / args.log_interval,
                          np.asarray(costs).mean(0),
                      ))
                costs = []
                start = time.time()
Example #8
def fetch_all_wandb_run_ids(wandb_project, wandb_entity, wandb_api=None):
    if wandb_api is None:
        wandb_api = wandb.Api()
    wandb_path = f'{wandb_entity}/{wandb_project}/*'
    runs = wandb_api.runs(wandb_path)
    return [run.id for run in runs]
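# Usage sketch (hypothetical project/entity names, not from the original):
# run_ids = fetch_all_wandb_run_ids('my-project', 'my-entity')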
Example #9
def trainer(args):

    config_keys = [
        "batch_size",
        "soft_label",
        "adv_weight",
        "d_thresh",
        "z_dim",
        "z_dis",
        "model_save_step",
        "g_lr",
        "d_lr",
        "beta",
        "cube_len",
        "leak_value",
        "bias",
    ]

    # check new run or resume run
    if args.resume_id:
        api = wandb.Api()
        previous_run = api.run(f"bugan/simple-pytorch-3dgan/{args.resume_id}")
        config = previous_run.config
        pprint.pprint(config)

        run = wandb.init(
            project="simple-pytorch-3dgan",
            id=args.resume_id,
            entity="bugan",
            config=config,
            resume=True,
        )
    else:
        config = {
            **args.__dict__,
            **{k: getattr(params, k) for k in config_keys},
        }
        pprint.pprint(config)

        run = wandb.init(
            entity="bugan", project="simple-pytorch-3dgan", config=config, resume=True
        )

    # convert config dict to Namespace
    config = Namespace(**config)
    # added for output dir
    save_file_path = params.output_dir + "/" + config.model_name
    print(save_file_path)  # ../outputs/dcgan
    if not os.path.exists(save_file_path):
        os.makedirs(save_file_path)

    # for using tensorboard
    if config.logs:
        model_uid = datetime.datetime.now().strftime("%d-%m-%Y-%H-%M-%S")
        writer = SummaryWriter(
            params.output_dir
            + "/"
            + config.model_name
            + "/logs_"
            + model_uid
            + "_"
            + config.logs
            + "/"
        )

    # dataset definition
    # dsets_path = args.input_dir + args.data_dir + "train/"
    dsets_path = config.data_dir
    # if params.cube_len == 64:
    #     dsets_path = params.data_dir + params.model_dir + "30/train64/"

    print(dsets_path)  # ../volumetric_data/chair/30/train/

    if config.rotate:
        train_dsets = AugmentDataset(dsets_path, config, "train", res=config.res)
    else:
        train_dsets = ShapeNetDataset(dsets_path, config, "train", res=config.res)
    # val_dsets = ShapeNetDataset(dsets_path, args, "val")

    train_dset_loaders = torch.utils.data.DataLoader(
        train_dsets,
        batch_size=params.batch_size,
        shuffle=True,
        num_workers=24,
        pin_memory=True,
    )
    # val_dset_loaders = torch.utils.data.DataLoader(val_dsets, batch_size=args.batch_size, shuffle=True, num_workers=1)

    dset_len = {"train": len(train_dsets)}
    dset_loaders = {"train": train_dset_loaders}
    # print (dset_len["train"])

    # model definition
    D = net_D(config)
    # summary(net_D, input_size=(32, 32, 32))

    G = net_G(config)
    # print(G)
    # print(D)

    # load state dict if resume
    if args.resume_id:
        G, D = load_model(run, G, D)

    wandb.watch(G)
    wandb.watch(D)

    # summary(net_G, input_size=(params.z_dim,))

    # print total number of parameters in a model
    # x = sum(p.numel() for p in G.parameters() if p.requires_grad)
    # print (x)
    # x = sum(p.numel() for p in D.parameters() if p.requires_grad)
    # print (x)

    D_solver = optim.Adam(D.parameters(), lr=params.d_lr, betas=params.beta)
    # D_solver = optim.SGD(D.parameters(), lr=params.d_lr * 100, momentum=0.9)
    G_solver = optim.Adam(G.parameters(), lr=params.g_lr, betas=params.beta)

    D.to(params.device)
    G.to(params.device)

    # criterion_D = nn.BCELoss()
    criterion_D = nn.MSELoss()
    criterion_G = nn.L1Loss()

    itr_val = -1
    itr_train = -1

    for epoch in range(config.epochs):

        start = time.time()

        for phase in ["train"]:
            if phase == "train":
                # if args.lrsh:
                #     D_scheduler.step()
                D.train()
                G.train()
            else:
                D.eval()
                G.eval()

            running_loss_G = 0.0
            running_loss_D = 0.0
            running_loss_adv_G = 0.0

            for i, X in enumerate(tqdm(dset_loaders[phase])):

                # if phase == 'val':
                #     itr_val += 1

                if phase == "train":
                    itr_train += 1

                X = X.to(params.device)
                # print (X)
                # print (X.size())

                batch = X.size()[0]
                # print (batch)

                Z = generateZ(config, batch)
                # print (Z.size())

                # ============= Train the discriminator =============#
                d_real = D(X)

                fake = G(Z)

                if i == 0 and epoch % config.generate_every == 0:
                    image_saved_path = Path(params.images_dir) / config.model_name
                    image_saved_path.mkdir(parents=True, exist_ok=True)

                    samples = fake.cpu().data[:5].squeeze().numpy()

                    fnames = []
                    for si, samp in enumerate(samples):
                        # print(si, samp)
                        try:
                            mesh = trimesh.voxel.VoxelGrid(
                                trimesh.voxel.encoding.DenseEncoding(samp >= 0.5)
                            ).marching_cubes
                        except ValueError as exc:
                            print(f"Marching cubes failed: {exc}")
                            continue
                        fname = Path(image_saved_path) / f"{epoch:04}_{si}.obj"
                        mesh.export(fname)
                        fnames.append(fname)

                    wandb.log(
                        {
                            "generated_tree_samples": [
                                wandb.Object3D(open(fname)) for fname in fnames
                            ],
                            "epoch": epoch,
                        },
                        step=itr_train,
                    )

                d_fake = D(fake)

                real_labels = torch.ones_like(d_real).to(params.device)
                fake_labels = torch.zeros_like(d_fake).to(params.device)
                # print (d_fake.size(), fake_labels.size())

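                # label smoothing: sample real/fake targets from ranges
                # around 1 and 0 instead of hard labels to stabilise D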
                if params.soft_label:
                    real_labels = (
                        torch.Tensor(batch).uniform_(0.7, 1.2).to(params.device)
                    )
                    fake_labels = torch.Tensor(batch).uniform_(0, 0.3).to(params.device)

                # print (d_real.size(), real_labels.size())
                d_real_loss = criterion_D(d_real, real_labels)

                d_fake_loss = criterion_D(d_fake, fake_labels)

                d_loss = d_real_loss + d_fake_loss

                # discriminator accuracy on the real and fake batches
                d_real_acu = torch.ge(d_real.squeeze(), 0.5).float()
                d_fake_acu = torch.le(d_fake.squeeze(), 0.5).float()
                d_total_acu = torch.mean(torch.cat((d_real_acu, d_fake_acu), 0))

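                # only update D while its accuracy is below d_thresh so it
                # does not overpower the generator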
                if d_total_acu < params.d_thresh:
                    D.zero_grad()
                    d_loss.backward()
                    D_solver.step()

                # =============== Train the generator ===============#

                Z = generateZ(config, batch)

                # print (X)
                fake = G(Z)  # generated fake: 0-1, X: 0/1
                d_fake = D(fake)

                adv_g_loss = criterion_D(d_fake, real_labels)
                # print (fake.size(), X.size())

                # recon_g_loss = criterion_D(fake, X)
                recon_g_loss = criterion_G(fake, X)
                # g_loss = recon_g_loss + params.adv_weight * adv_g_loss
                g_loss = adv_g_loss

                if config.local_test:
                    # print('Iteration-{} , D(x) : {:.4} , G(x) : {:.4} , D(G(x)) : {:.4}'.format(itr_train, d_loss.item(), recon_g_loss.item(), adv_g_loss.item()))
                    print(
                        "Iteration-{} , D(x) : {:.4}, D(G(x)) : {:.4}".format(
                            itr_train, d_loss.item(), adv_g_loss.item()
                        )
                    )

                D.zero_grad()
                G.zero_grad()
                g_loss.backward()
                G_solver.step()

                # =============== logging each 10 iterations ===============#

                running_loss_G += recon_g_loss.item() * X.size(0)
                running_loss_D += d_loss.item() * X.size(0)
                running_loss_adv_G += adv_g_loss.item() * X.size(0)

                if config.logs:
                    loss_G = {
                        "adv_loss_G": adv_g_loss,
                        "recon_loss_G": recon_g_loss,
                    }

                    loss_D = {
                        "adv_real_loss_D": d_real_loss,
                        "adv_fake_loss_D": d_fake_loss,
                        "d_real_acu": d_real_acu.mean(),
                        "d_fake_acu": d_fake_acu.mean(),
                        "d_total_acu": d_total_acu,
                    }

                    # if itr_val % 10 == 0 and phase == 'val':
                    #     save_val_log(writer, loss_D, loss_G, itr_val)

                    if itr_train % 10 == 0 and phase == "train":
                        save_train_log(writer, loss_D, loss_G, itr_train)

                        wandb.log(
                            {"G": loss_G, "D": loss_D, "epoch": epoch}, step=itr_train
                        )

            # =============== each epoch save model or save image ===============#
            epoch_loss_G = running_loss_G / dset_len[phase]
            epoch_loss_D = running_loss_D / dset_len[phase]
            epoch_loss_adv_G = running_loss_adv_G / dset_len[phase]

            end = time.time()
            epoch_time = end - start

            print(
                "Epochs-{} ({}) , D(x) : {:.4}, D(G(x)) : {:.4}".format(
                    epoch, phase, epoch_loss_D, epoch_loss_adv_G
                )
            )
            print("Elapsed Time: {:.4} min".format(epoch_time / 60.0))

            if (epoch + 1) % params.model_save_step == 0:

                print("model_saved, images_saved...")
                save_model(run, config.model_name, G, D)
Example #10
def main(
    output_folder=None,
    duration=10,
    num_samples=5,
    gid=1,
    seed=123,
    melgan_run_id=None,
    unagan_run_id=None,
    hifigan_run_id=None,
    wandb_code=None,
):
    if wandb_code:
        wandb.login(key=wandb_code, relogin=True)

    if melgan_run_id and hifigan_run_id:
        raise Exception("Can only set one of [melgan_run_id, hifigan_run_id], not both")

    # added guard: the vocoder setup below needs exactly one of the two ids
    if not (melgan_run_id or hifigan_run_id):
        raise Exception("One of [melgan_run_id, hifigan_run_id] is required")

    if not unagan_run_id:
        raise Exception("unagan_run_id should not be empty")

    download_weights.main(
        model_dir=Path("models/custom"),
        melgan_run_id=melgan_run_id,
        unagan_run_id=unagan_run_id,
        hifigan_run_id=hifigan_run_id,
    )
    # ### Data type ###
    # assert data_type in ["singing", "speech", "piano", "violin"]

    # ### Architecture type ###
    # if data_type == "singing":
    #     assert arch_type in ["nh", "h", "hc"]
    # elif data_type == "speech":
    #     assert arch_type in ["h", "hc"]
    # elif data_type == "piano":
    #     assert arch_type in ["hc"]
    # elif data_type == "violin":
    #     assert arch_type in ["hc"]

    # if arch_type == "nh":
    #     arch_type = "nonhierarchical"
    # elif arch_type == "h":
    #     arch_type = "hierarchical"
    # elif arch_type == "hc":
    #     arch_type = "hierarchical_with_cycle"

    data_type = "custom"
    arch_type = "hierarchical_with_cycle"

    # ### Model type ###
    model_type = f"{data_type}.{arch_type}"

    # ### Model info ###
    if output_folder is None:
        output_folder = Path("generated_samples") / model_type
    output_folder = Path(output_folder)

    output_folder.mkdir(parents=True, exist_ok=True)

    # also save to all_generated_audio_dir if that folder exists, but do
    # not save during unagan training (run id is None)
    if not unagan_run_id:
        all_generated_audio_dir = None
    else:
        all_generated_audio_dir = Path(
            "/content/drive/My Drive/PUBLICATIONS/The Replicant/AUDIO DATABASE/UNAGAN OUTPUT/AUDIOS/"
        )
        try:
            all_generated_audio_dir.mkdir(parents=True, exist_ok=True)
            print(
                "generated audio files will also be saved to:",
                str(all_generated_audio_dir),
            )
        except OSError:
            print(
                "the path '",
                str(all_generated_audio_dir),
                "' could not be created. Only saving audio files to:",
                str(output_folder),
            )
            all_generated_audio_dir = None

    api = wandb.Api()
    previous_run = api.run(f"demiurge/unagan/{unagan_run_id}")
    unagan_config = Namespace(**previous_run.config)

    ################# unagan config parameters ##################
    z_dim = unagan_config.z_dim
    z_scale_factors = unagan_config.z_scale_factors
    z_total_scale_factor = np.prod(z_scale_factors)
    feat_dim = unagan_config.feat_dim
    ##################

    param_fp = f"models/{data_type}/params.generator.{arch_type}.pt"

    mean_fp = f"models/{data_type}/mean.mel.npy"
    std_fp = f"models/{data_type}/std.mel.npy"

    mean = torch.from_numpy(np.load(mean_fp)).float().view(1, feat_dim, 1)
    std = torch.from_numpy(np.load(std_fp)).float().view(1, feat_dim, 1)
    if gid >= 0:
        mean = mean.cuda(gid)
        std = std.cuda(gid)

    ###############################################################
    ### Vocoder info ###

    ### MELGAN ###
    if melgan_run_id:
        api = wandb.Api()
        previous_run = api.run(f"demiurge/melgan/{melgan_run_id}")
        melgan_config = Namespace(**previous_run.config)

        ################# melgan config parameters ##################
        # melgan only parameters
        n_mel_channels = 80
        ngf = 32
        n_residual_layers = 3

        # also applied to unagan generate
        sampling_rate = 44100
        hop_length = 256

        if n_mel_channels != feat_dim:
            print(
                f"Warning!!! melgan n_mel_channels {n_mel_channels} != unagan feat_dim {feat_dim}"
            )

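        # prefer values stored in the melgan run config; the assignments
        # above are only fallback defaults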
        if hasattr(melgan_config, "hop_length"):
            hop_length = melgan_config.hop_length
        if hasattr(melgan_config, "sampling_rate"):
            sampling_rate = melgan_config.sampling_rate
        if hasattr(melgan_config, "n_mel_channels"):
            n_mel_channels = melgan_config.n_mel_channels
        if hasattr(melgan_config, "ngf"):
            ngf = melgan_config.ngf
        if hasattr(melgan_config, "n_residual_layers"):
            n_residual_layers = melgan_config.n_residual_layers

        ########################

        # ### Vocoder Model ###
        vocoder_model_dir = Path("models") / data_type / "vocoder"

        if data_type == "speech":
            vocoder_name = "OriginalGenerator"
        else:
            vocoder_name = "GRUGenerator"
        MelGAN = getattr(melgan_models, vocoder_name)
        vocoder = MelGAN(n_mel_channels, ngf, n_residual_layers)
        vocoder.eval()

        vocoder_param_fp = vocoder_model_dir / "params.pt"
        vocoder_state_dict = torch.load(vocoder_param_fp)
        try:
            vocoder.load_state_dict(vocoder_state_dict)
        except RuntimeError as e:
            print(e)
            print("Fixing model by removing .module prefix")
            vocoder_state_dict = OrderedDict(
                (k.split(".", 1)[1], v) for k, v in vocoder_state_dict.items()
            )
            vocoder.load_state_dict(vocoder_state_dict)

        if gid >= 0:
            vocoder = vocoder.cuda(gid)

    ### HIFI-GAN ###

    if hifigan_run_id:
        api = wandb.Api()
        previous_run = api.run(f"demiurge/hifi-gan/{hifigan_run_id}")
        hifigan_config = Namespace(**previous_run.config)

        # parameters applied to unagan generate
        sampling_rate = 44100
        hop_length = 256
        if hasattr(hifigan_config, "hop_size"):
            hop_length = hifigan_config.hop_size
        if hasattr(hifigan_config, "sampling_rate"):
            sampling_rate = hifigan_config.sampling_rate

        vocoder_model_dir = Path("models") / data_type / "vocoder"
        vocoder = hifi_models.Generator(hifigan_config)
        vocoder.eval()

        vocoder_state_dict = torch.load(vocoder_model_dir / "g")
        vocoder.load_state_dict(vocoder_state_dict["generator"])

        if gid >= 0:
            vocoder = vocoder.cuda(gid)

    ###################################################################

    # ### Generator ###
    if arch_type == "nonhierarchical":
        generator = NonHierarchicalGenerator(feat_dim, z_dim)
    elif arch_type.startswith("hierarchical"):
        generator = HierarchicalGenerator(feat_dim, z_dim, z_scale_factors)

    generator.eval()
    for p in generator.parameters():
        p.requires_grad = False

    manager.load_model(param_fp, generator, device_id="cpu")

    if gid >= 0:
        generator = generator.cuda(gid)

    # ### Process ###
    torch.manual_seed(seed)
    # information for filename
    filename_base = datetime.utcnow().strftime("%Y-%m-%d_%H-%M")

    if melgan_run_id:
        filename_base += "_mel-" + melgan_run_id

    if unagan_run_id:
        filename_base += "_una-" + unagan_run_id

    if hifigan_run_id:
        filename_base += "_hifi-" + hifigan_run_id

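    # frames needed = duration (s) * frames per second, where each
    # spectrogram frame advances hop_length samples at sampling_rate Hz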
    num_frames = int(np.ceil(duration * (sampling_rate / hop_length)))

    audio_array = []
    for ii in range(num_samples):
        out_fp_wav = Path(output_folder) / f"{filename_base}_sample{ii}.wav"
        print(f"Generating {out_fp_wav}")

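        # the hierarchical generator upsamples its input by
        # prod(z_scale_factors), so the latent sequence only needs
        # num_frames / z_total_scale_factor steps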
        if arch_type == "nonhierarchical":
            z = torch.zeros((1, z_dim, num_frames)).normal_(0, 1).float()
        elif arch_type.startswith("hierarchical"):
            z = (
                torch.zeros((1, z_dim, int(np.ceil(num_frames / z_total_scale_factor))))
                .normal_(0, 1)
                .float()
            )

        if gid >= 0:
            z = z.cuda(gid)

        with torch.set_grad_enabled(False):
            with torch.cuda.device(gid):
                # Generator
                melspec_voc = generator(z)
                melspec_voc = (melspec_voc * std) + mean

                # Vocoder
                audio = vocoder(melspec_voc)
                audio = audio.squeeze().cpu().numpy()

        # keep generated audio as array to log to wandb
        if not unagan_run_id:
            audio_array.append(audio)
        else:
            # Save to wav
            sf.write(out_fp_wav, audio, sampling_rate)
            audio_array.append(out_fp_wav)
            # Save also to all_generated_audio_dir
            if all_generated_audio_dir:
                out2_fp_wav = (
                    Path(all_generated_audio_dir) / f"{filename_base}_sample{ii}.wav"
                )
                sf.write(out2_fp_wav, audio, sampling_rate)
    return audio_array, sampling_rate
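# Note: audio_array holds raw numpy waveforms when unagan_run_id is unset,
# but .wav file paths otherwise, so callers must handle both cases.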
Example #11
from scipy import stats
from scipy.special import factorial
from scipy.stats import binom_test, wilcoxon

import os
import pickle
from datetime import datetime
import tabulate
import wandb
from collections import namedtuple, defaultdict, OrderedDict
import json
from ipypb import ipb

# from metalearning import cnnmlp

API = wandb.Api()
MAX_HISTORY_SAMPLES = 4000

DATASET_CORESET_SIZE = 22500
ACCURACY_THRESHOLD = 0.95
TASK_ACC_COLS = [f'Test Accuracy, Query #{i}' for i in range(1, 11)]
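# i.e. 'Test Accuracy, Query #1' through 'Test Accuracy, Query #10'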

QUERY_NAMES = [
    'blue', 'brown', 'cyan', 'gray', 'green', 'orange', 'pink', 'purple',
    'red', 'yellow', 'cone', 'cube', 'cylinder', 'dodecahedron', 'ellipsoid',
    'octahedron', 'pyramid', 'rectangle', 'sphere', 'torus', 'chain_mail',
    'marble', 'maze', 'metal', 'metal_weave', 'polka', 'rubber', 'rug',
    'tiles', 'wood_plank'
]

COLOR = 'color'
def list_run_files(run_id: str,
                   project: str = "flowers",
                   entity: str = "jeremytjordan"):
    api = wandb.Api()
    run = api.run(f"{entity}/{project}/{run_id}")
    return [f.name for f in run.files()]
def list_runs(project: str = "flowers", entity: str = "jeremytjordan"):
    api = wandb.Api()
    runs = api.runs(f"{entity}/{project}")
    return [r.id for r in runs]
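
# A minimal usage sketch for the two helpers above, assuming the default
# "jeremytjordan/flowers" project is readable with the configured API key:
if __name__ == '__main__':
    for run_id in list_runs():
        print(run_id, list_run_files(run_id))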
Example #14
 def __init__(self, project_name=None):
     self.api = wandb.Api()
     if project_name is not None:
         self.set_project(project_name)
def main():
    n_qubits = 8
    n_layers_list = [32, 64, 80, 96]
    project = 'IsingModel'
    target_cfgs = {
        'config.n_qubits': n_qubits,
        'config.n_layers': {
            "$in": n_layers_list
        },
        'config.g': 2,
        'config.h': 0,
        'config.lr': 0.05,
        'config.seed': 96,
        'config.scheduler_name': 'exponential_decay',
    }
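    # api.runs() filters use MongoDB-style query operators, so the "$in"
    # above matches any run whose config.n_layers is in n_layers_list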
    print(f'Downloading experiment results from {project}')
    print(f'| Target constraints: {target_cfgs}')

    api = wandb.Api()
    runs = api.runs(project, filters=target_cfgs)

    history = {}
    for run in runs:
        if run.state == 'finished':
            print(run.name)
            n_layers = run.config['n_layers']
            h = run.history()
            # Theoretically E(\theta) >= E_0 and fidelity <= 1.
            # If it is negative, it must be a precision error.
            h['loss'] = h['loss'].clip(lower=0.)
            h['fidelity/ground'] = h['fidelity/ground'].clip(upper=1.)
            history[n_layers] = h
    print('Download done')
    assert set(history.keys()) == set(n_layers_list)

    linestyles = ['-', '-.', '--', ':']
    linewidths = [1.2, 1.2, 1.3, 1.4]

    xlim = 0, 500

    plt.subplot(211)
    for i, n_layers in enumerate(n_layers_list):
        h = history[n_layers]
        plt.plot(h._step,
                 h.loss,
                 linestyles[i],
                 color=color_list[i],
                 linewidth=linewidths[i],
                 alpha=1.,
                 markersize=5,
                 label=f'L={n_layers}')
    plt.xlim(*xlim)
    plt.yscale('log')
    plt.ylabel(r'$E(\mathbf{\theta}) - E_0$', fontsize=13)
    plt.grid(True, c='0.5', ls=':', lw=0.5)
    # plt.legend(loc='upper right')

    plt.subplot(212)
    for i, n_layers in enumerate(n_layers_list):
        h = history[n_layers]
        plt.plot(h._step,
                 h['fidelity/ground'],
                 linestyles[i],
                 color=color_list[i],
                 linewidth=linewidths[i],
                 alpha=1.,
                 markersize=5,
                 label=f'L={n_layers}')

    plt.xlim(*xlim)
    plt.xlabel('Optimization Steps', fontsize=13)
    plt.ylabel(
        r'$|\,\langle \psi(\mathbf{\theta^*})\, |\, \phi \rangle\, |^2$',
        fontsize=13)
    plt.grid(True, c='0.5', ls=':', lw=0.5)
    plt.legend(loc='lower right')

    plt.tight_layout()
    plt.savefig('fig/ising_optimization_ed.pdf', bbox_inches='tight')
    plt.show()
Example #16
def get_wandb_dataframes(run_list=None, project=None):
    api = wandb.Api()
    delta_dataframes = []
    for run_key in run_list:
        delta_dataframes.append(api.run(run_key).history())
    return delta_dataframes
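# Note: run keys must be full "entity/project/run_id" paths, and .history()
# returns a sampled pandas DataFrame by default; use run.scan_history() when
# every logged row is needed.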
Example #17
def _get_model_candidates_from_wb(project, model_use_case_id):
    api = wandb.Api({"project": project})
    versions = api.artifact_versions(
        "model", "{}_model_candidates".format(model_use_case_id))
    return versions
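# Api.artifact_versions() returns a paginated collection of artifact
# versions; iterating it (e.g. `for v in versions: print(v.version)`)
# triggers the actual fetches.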
Example #18
def delete_wandb_run(run_name):
    api = wandb.Api()
    run = api.run(run_name)
    run.delete()
    logging.info(f"run {run_name} had been deleted with success")
Example #19
 def __init__(self, path=None, opts=None):
     self.path = path
     self.api = wandb.Api()
     self.opts = opts or {}
     self.displayed = False
     self.height = self.opts.get("height", 420)
def sync_crashed(sweep_name: Optional[str]):
    wandb_key = get_wandb_env()
    assert wandb_key, "W&B API key is needed for starting a W&B sweep"

    project = config.get("wandb", {}).get("project")
    api = wandb.Api()

    if sweep_name is not None:
        sweep_map = get_sweep_table(api, project)
        name_to_id, repeats = invert_sweep_id_table(sweep_map)

        if sweep_name in name_to_id:
            sweep_name = name_to_id[sweep_name]
        elif sweep_name in repeats:
            print(f"ERROR: ambigous sweep name: {sweep_name}")
            return

    relpath = get_relative_path()

    runs = get_runs_in_sweep(api, project, sweep_name, {"state": "crashed"})
    print(
        f"Sweep {sweep_name}: found {len(runs)} crashed runs. Trying to synchronize..."
    )
    for r in runs:
        hostname = get_run_host(api, project, r.id)
        dir = None
        found = []

        cmd = f"find ./wandb -iname '*{r.id}'"
        if hostname is None:
            res = run_multiple_hosts(config["hosts"], cmd)
            for hn, (res, retcode) in res.items():
                res = res.strip()
                if retcode == 0 and res:
                    found.append((hn, res))
        else:
            res, retcode = run_multiple_hosts([hostname], cmd)[hostname]
            res = res.strip()
            if retcode == 0 and res:
                found = [(hostname, res)]

        if len(found) != 1:
            print(f"WARNING: Failed to identify run {r.id}")
            continue

        hostname, dir = found[0]

        if len(dir.split("\n")) != 1:
            print(f"WARNING: Failed to identify run {r.id}")
            continue

        print(f"Found run {r.id} at {hostname} in dir {dir}. Syncing...")

        cd = config.get_command(hostname, "cd")
        wandb_cmd = config.get_command(hostname, "wandb", "~/.local/bin/wandb")

        cmd = f"{cd} {relpath}; {wandb_key} {wandb_cmd} sync {dir}"
        _, errcode = remote_run(hostname, cmd + " 2>/dev/null")

        if errcode != 0:
            print("Sync failed :(")
            continue
Example #21
def main():
    parser = ArgumentParser()

    group = parser.add_mutually_exclusive_group()
    group.add_argument('--sweep', help="Select runs from the given sweep.")
    group.add_argument('--tag', help="Select runs with the given tag.")

    parser.add_argument('--project', help="Path of the project, in the form entity_id/project_id.")
    parser.add_argument('--dry-run', action='store_true',
                        help="Describe the changes without actually performing them.")
    args = parser.parse_args()

    wandb.init(job_type='update_metrics', project=args.project)

    overrides = {}

    if args.project:
        overrides['project'] = args.project

    api = wandb.Api(overrides)

    if args.tag:
        runs = api.runs(args.project, filters={
            'tags': args.tag
        })

        plots_dir = os.path.join('update_metrics', args.tag)
    elif args.sweep:
        sweep: wandb_api.Sweep = api.sweep(args.sweep)

        print(f"Processing sweep {sweep.url}")

        runs = sweep.runs
        plots_dir = os.path.join('update_metrics', sweep.id)
    else:
        raise ValueError("One of --tag or --sweep must be provided.")

    run: wandb_api.Run
    for run in tqdm(runs):
        version = run.config.get('metrics_version', 0)

        if version == CURRENT_VERSION:
            continue

        tqdm.write(f"Run {run.name}:")
        tqdm.write(f" - URL {run.url}")
        tqdm.write(f" - current metrics version v{version}")

        if version < 1:
            tqdm.write(f" - adding entropy discrimination ROC curve")
            add_entropy_roc(run, plots_dir)

        if version < 2:
            tqdm.write(f" - adding accuracy / AUC combined score")
            add_combined_score(run)

        if version < 3:
            tqdm.write(f" - adding default approach config key")
            add_default_approach(run)

        if version < 4:
            tqdm.write(f" - adding checkpoint artifact")
            add_checkpoint_artifact(run, api, args.dry_run)

        run.config['metrics_version'] = CURRENT_VERSION

        if not args.dry_run:
            run.update()
Example #22
def get_policy(env_name: str, pre_trained: int = 1):
    """
    Retrieves policies for the environment with the pre-trained quality marker.

    :param env_name:  name of the environment
    :param pre_trained: pre_trained level . It should be between 1 and 5 ,
                        where 1 indicates best model and 5 indicates worst level.

    Example:
        >>> import policybazaar
        >>> policybazaar.get_policy('d4rl:maze2d-open-v0',pre_trained=1)
    """

    assert MIN_PRE_TRAINED_LEVEL <= pre_trained <= MAX_PRE_TRAINED_LEVEL, \
        'pre_trained marker should be between [{},{}] where {} indicates the best model' \
        ' and {} indicates the worst model'.format(MIN_PRE_TRAINED_LEVEL, MAX_PRE_TRAINED_LEVEL, MIN_PRE_TRAINED_LEVEL,
                                                   MAX_PRE_TRAINED_LEVEL)
    assert env_name in ENV_IDS or env_name in CHILD_PARENT_ENVS, \
        '`{}` not found. It should be among following: {}'.format(env_name,
                                                                  list(ENV_IDS.keys()) + list(CHILD_PARENT_ENVS.keys()))

    if env_name not in ENV_IDS:
        env_name = CHILD_PARENT_ENVS[env_name]
    if env_name in ENV_PERFORMANCE_STATS and pre_trained in ENV_PERFORMANCE_STATS[
            env_name]:
        info = ENV_PERFORMANCE_STATS[env_name][pre_trained]
    else:
        info = {}

    run_path = ENV_IDS[env_name]['wandb_run_path']
    run = wandb.Api().run(run_path)
    env_root = os.path.join(env_name, POLICY_BAZAAR_DIR, env_name,
                            'pre_trained_{}'.format(pre_trained), 'models')
    os.makedirs(env_root, exist_ok=True)

    if 'cassie' in env_name:
        # retrieve model
        model_name = '{}.p'.format(ENV_IDS[env_name]['model_name'])
        from .cassie_model import ActorCriticNetwork
        model = ActorCriticNetwork(**run.config['model_kwargs'])
        wandb.restore(name=model_name,
                      run_path=run_path,
                      replace=True,
                      root=env_root)

        model_data = torch.load(os.path.join(env_root, model_name),
                                map_location=torch.device('cpu'))
        model.load_state_dict(model_data['state_dict'])
        model.actor.obs_std = model_data["act_obs_std"]
        model.actor.obs_mean = model_data["act_obs_mean"]
        model.critic.obs_std = model_data["critic_obs_std"]
        model.critic.obs_mean = model_data["critic_obs_mean"]

    else:
        # retrieve model
        model_name = '{}_{}.0.p'.format(
            ENV_IDS[env_name]['model_name'],
            ENV_IDS[env_name]['models'][pre_trained])

        from .model import ActorCriticNetwork
        model = ActorCriticNetwork(run.config['observation_size'],
                                   run.config['action_size'],
                                   hidden_dim=64,
                                   action_std=0.5)

        wandb.restore(name=model_name,
                      run_path=run_path,
                      replace=True,
                      root=env_root)
        model.load_state_dict(
            torch.load(os.path.join(env_root, model_name),
                       map_location=torch.device('cpu')))
    return model, info
Example #23
    parser.add_argument('--custom_metric',
                        action="store_true",
                        default=False,
                        help='Custom non Test metric')
    parser.add_argument('--graph_gen',
                        action="store_true",
                        default=False,
                        help='Report all graph gen Test metric')
    parser.add_argument('--top_k',
                        type=int,
                        default=5,
                        help='Return only top K runs')
    parser.add_argument('--config_key', nargs='*', type=str, default=[])
    parser.add_argument('--config_val', nargs='*', default=[])
    parser.add_argument('--dataset', type=str, default='bdp')
    parser.add_argument(
        '--eval_set',
        default="test",
        help=
        "Whether to evaluate model on test set (default) or validation set.")
    parser.add_argument('--get_step', nargs='+', default=[5], type=int)
    args = parser.parse_args()

    with open('../settings.json') as f:
        data = json.load(f)
    args.wandb_apikey = data.get("wandbapikey")
    os.environ['WANDB_API_KEY'] = args.wandb_apikey
    args.api = wandb.Api()

    main(args)
Example #24
def setup_wandb(flags):
    flags.wandb_name = f'{flags.xpid}-{flags.seed}'

    for wandb_key in ('WANDB_RESUME', 'WANDB_RUN_ID'):
        if wandb_key in os.environ:
            del os.environ[wandb_key]

    if flags.wandb_resume:
        api = wandb.Api()

        original_run_id = None
        resume_step = None
        resume_checkpoint = None

        existing_runs = api.runs(
            f'{flags.wandb_entity}/{flags.wandb_project}',
            {'$and': [{'config.id': str(flags.xpid)},
                      {'config.seed': int(flags.seed)}]})

        if len(existing_runs) > 1:
            raise ValueError(
                f'Found more than one matching run to id {flags.xpid} and seed {flags.seed}: {[r.id for r in existing_runs]}. Aborting... ')

        elif len(existing_runs) == 1:
            existing_run = existing_runs[0]
            original_run_id = existing_run.id

            history = existing_run.history(pandas=True, samples=1000)

            # Verify there's actually a run to resume
            if len(history) > 0:
                checkpoint_index = -1
                while np.isnan(history['steps'].iat[checkpoint_index]):
                    checkpoint_index -= 1

                resume_step = int(history['steps'].iat[checkpoint_index])

                if resume_step >= flags.total_steps:
                    print(
                        f'resume_step ({resume_step}) is greater than or equal to total steps ({flags.total_steps}), nothing to do here...')
                    sys.exit(0)

                # Now that we know what resume_step is, we can load from there.
                try:
                    resume_checkpoint = existing_run.file(f'model-{resume_step}.tar')
                    resume_checkpoint.download(replace=True)
                except (AttributeError, wandb.CommError):
                    print('Failed to download most recent checkpoint, will not resume')

        if original_run_id is None:
            print(f'Failed to find run to resume for seed {flags.seed}, running from scratch')

        elif resume_step is None:
            print(f'Failed to find the correct resume timestamp for seed {flags.seed}, running from scratch')

        elif resume_checkpoint is None:
            print(f'Failed to find checkpoint to resume for seed {flags.seed}, running from scratch')

        else:
            os.environ['WANDB_RESUME'] = 'must'
            os.environ['WANDB_RUN_ID'] = original_run_id

            if resume_step is not None:
                flags.current_step = resume_step

            flags.resume_checkpoint_path = resume_checkpoint.name

    for key in os.environ:
        if 'WANDB' in key:
            print(key, os.environ[key])

    wandb.init(entity=flags.wandb_entity, project=flags.wandb_project, name=flags.wandb_name,
               dir=flags.wandb_dir, config=vars(flags))
    wandb.save(os.path.join(wandb.run.dir, '*.pth'))
    wandb.save(os.path.join(wandb.run.dir, '*.tar'))
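# Note: wandb.save() with a glob pattern also live-syncs matching files
# created later in the run directory, which is why it is safe to register
# '*.pth' and '*.tar' before any checkpoints exist.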
Example #25
def fetch_run(run_id: str = '', run_path: str = '') -> "wandb.apis.public.Run":
    if not run_path:
        run_path = 'codeepneat/cdn/' + run_id
    return wandb.Api().run(run_path)
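# e.g. fetch_run('abc123') resolves to the run at 'codeepneat/cdn/abc123'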
Example #26
def misc(cfg):
    api = wandb.Api()
    test_metric = {'int': "test_int_acc", 'tag': "test_tag conlleval f1"}
    y_title = {'int': "Test int (acc)", 'tag': "Test tag (f1)"}

    title = {'atis': f"Performance on ATIS", 'snips': f"Performance on SNIPS"}

    batch = {'atis': {'ce': 16, 'vat': 64}, 'snips': {'ce': 64, 'vat': 64}}
    xticks = {
        'atis': [0.1, 0.2, 0.4],
        # 'snips' : [0.01,0.02,0.03,0.04,0.05,0.06,0.07,0.08,0.09,0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9,1.0],
        'snips': [0.02, 0.04, 0.06, 0.08, 0.1, 0.2, 0.4]
    }

    runs = api.runs("cetosignis/viraal-rerank-full", {
        '$and': [{
            'config.training.dataset': cfg.dataset
        }, {
            'config.training.task': cfg.task
        }]
    },
                    per_page=1000)
    configs = ometrics.Metrics()
    summaries = ometrics.Metrics()
    print(len(runs))
    for run in runs:
        if test_metric[cfg.test_task] in run.summary:
            configs.append(run.config)
            summaries.append(run.summary)
    df = pd.DataFrame()

    x = "Labeled part"
    y = y_title[cfg.test_task]
    filename = f'{cfg.dataset}_{cfg.task}_{cfg.test_task}'

    df["Task"] = configs["training/task"]
    df["Dataset"] = cfg.dataset
    df["Loss"] = configs["training/loss"]
    df["Batch Size"] = configs["training/iterator/params/batch_size"]
    df["Criteria"] = [i[0] for i in configs["rerank/criteria"]]
    df["Loss+Criteria"] = df["Loss"] + "+" + df["Criteria"]
    df[x] = configs["training/unlabeler/params/labeled_part"]
    df[y] = summaries[test_metric[cfg.test_task]]
    yticks = np.arange(
        np.floor(df[y].min() * 100) / 100,
        np.ceil(df[y].max() * 100) / 100, 0.01)
    yticks = yticks if len(yticks) < 30 else np.arange(
        np.floor(df[y].min() * 100) / 100,
        np.ceil(df[y].max() * 100) / 100, 0.05)
    cols = ['Dataset', 'Task', 'Loss', 'Batch Size', 'Criteria', x]
    df.groupby(cols).mean().to_csv(f'{filename}.csv')
    # df[y] = summaries["test_int_acc"]

    plt.figure()
    sns.set()
    sns.set_context("paper")
    sns.set_style("whitegrid")
    fig = sns.barplot(x=x,
                      y=y,
                      hue="Loss+Criteria",
                      data=df,
                      palette="Blues_d")
    plt.ylim(yticks[0], yticks[-1])
    plt.title(title[cfg.dataset])
    sns.despine(left=True, bottom=True)
    plt.tight_layout()
    plt.savefig(f"{filename}.png", dpi=300)
def api(runner):
    return wandb.Api()
Example #28
def main(argv):
    args = parser.parse_args()
    print('Load test starting')

    project_name = args.project
    if project_name is None:
        project_name = 'artifacts-load-test-%s' % str(datetime.now()).replace(
            ' ', '-').replace(':', '-').replace('.', '-')

    env_project = os.environ.get('WANDB_PROJECT')

    sweep_id = os.environ.get('WANDB_SWEEP_ID')
    if sweep_id:
        del os.environ['WANDB_SWEEP_ID']
    wandb_config_paths = os.environ.get('WANDB_CONFIG_PATHS')
    if wandb_config_paths:
        del os.environ['WANDB_CONFIG_PATHS']
    wandb_run_id = os.environ.get('WANDB_RUN_ID')
    if wandb_run_id:
        del os.environ['WANDB_RUN_ID']

    # set global entity and project before chdir'ing
    from wandb.apis import InternalApi
    api = InternalApi()
    settings_entity = api.settings('entity')
    settings_base_url = api.settings('base_url')
    os.environ['WANDB_ENTITY'] = (os.environ.get('LOAD_TEST_ENTITY')
                                  or settings_entity)
    os.environ['WANDB_PROJECT'] = project_name
    os.environ['WANDB_BASE_URL'] = (os.environ.get('LOAD_TEST_BASE_URL')
                                    or settings_base_url)

    # Change dir to avoid littering the code directory
    pwd = os.getcwd()
    tempdir = tempfile.TemporaryDirectory()
    os.chdir(tempdir.name)

    artifact_name = 'load-artifact-' + ''.join(
        random.choices(string.ascii_lowercase + string.digits, k=10))

    print('Generating source data')
    source_file_names = gen_files(args.gen_n_files, args.gen_max_small_size,
                                  args.gen_max_large_size)
    print('Done generating source data')

    procs = []
    stop_queue = multiprocessing.Queue()
    stats_queue = multiprocessing.Queue()

    # start all processes

    # writers
    for i in range(args.num_writers):
        file_names = source_file_names
        if args.non_overlapping_writers:
            chunk_size = int(len(source_file_names) / args.num_writers)
            file_names = source_file_names[i * chunk_size:(i + 1) * chunk_size]
        p = multiprocessing.Process(
            target=proc_version_writer,
            args=(stop_queue, stats_queue, project_name, file_names,
                  artifact_name, args.files_per_version_min,
                  args.files_per_version_max))
        p.start()
        procs.append(p)

    # readers
    for i in range(args.num_readers):
        p = multiprocessing.Process(target=proc_version_reader,
                                    args=(stop_queue, stats_queue,
                                          project_name, artifact_name, i))
        p.start()
        procs.append(p)

    # deleters
    for _ in range(args.num_deleters):
        p = multiprocessing.Process(
            target=proc_version_deleter,
            args=(stop_queue, stats_queue, artifact_name,
                  args.min_versions_before_delete, args.delete_period_max))
        p.start()
        procs.append(p)

    # cache garbage collector
    if args.cache_gc_period_max is None:
        print('Testing cache GC process not enabled!')
    else:
        p = multiprocessing.Process(target=proc_cache_garbage_collector,
                                    args=(stop_queue,
                                          args.cache_gc_period_max))
        p.start()
        procs.append(p)

    # reset environment
    os.environ['WANDB_ENTITY'] = settings_entity
    os.environ['WANDB_BASE_URL'] = settings_base_url
    if env_project is None:
        del os.environ['WANDB_PROJECT']
    else:
        os.environ['WANDB_PROJECT'] = env_project
    if sweep_id:
        os.environ['WANDB_SWEEP_ID'] = sweep_id
    if wandb_config_paths:
        os.environ['WANDB_CONFIG_PATHS'] = wandb_config_paths
    if wandb_run_id:
        os.environ['WANDB_RUN_ID'] = wandb_run_id
    # go back to original dir
    os.chdir(pwd)

    # test phase
    start_time = time.time()
    stats = defaultdict(int)

    run = wandb.init(job_type='main-test-phase')
    run.config.update(args)
    while time.time() - start_time < args.test_phase_seconds:
        stat_update = None
        try:
            stat_update = stats_queue.get(True, 5000)
        except queue.Empty:
            pass
        print('** Test time: %s' % (time.time() - start_time))
        if stat_update:
            for k, v in stat_update.items():
                stats[k] += v
        wandb.log(stats)

    print('Test phase time expired')
    # stop all processes and wait til all are done
    for _ in procs:
        stop_queue.put(True)
    print('Waiting for processes to stop')
    fail = False
    for proc in procs:
        proc.join()
        if proc.exitcode != 0:
            print('FAIL! Test phase failed')
            fail = True

    # drain remaining stats
    while True:
        try:
            stat_update = stats_queue.get_nowait()
        except queue.Empty:
            break
        for k, v in stat_update.items():
            stats[k] += v

    print('Stats')
    import pprint
    pprint.pprint(dict(stats))

    if fail:
        print('FAIL! Test phase failed')
        sys.exit(1)
    else:
        print('Test phase successfully completed')

    print('Starting verification phase')

    os.environ['WANDB_ENTITY'] = (os.environ.get('LOAD_TEST_ENTITY')
                                  or settings_entity)
    os.environ['WANDB_PROJECT'] = project_name
    os.environ['WANDB_BASE_URL'] = (os.environ.get('LOAD_TEST_BASE_URL')
                                    or settings_base_url)
    data_api = wandb.Api()
    # we need list artifacts by walking runs, accessing via
    # project.artifactType.artifacts only returns committed artifacts
    for run in data_api.runs('%s/%s' % (api.settings('entity'), project_name)):
        for v in run.logged_artifacts():
            # TODO: allow deleted once we build deletion support
            if v.state not in ['COMMITTED', 'DELETED']:
                print('FAIL! Artifact version not committed or deleted: %s' %
                      v)
                sys.exit(1)

    print('Verification succeeded')
# TRY NOT TO MODIFY: start the game
global_step = 0
start_time = time.time()
# Note how `next_obs` and `next_done` are used; their usage is equivalent to
# https://github.com/ikostrikov/pytorch-a2c-ppo-acktr-gail/blob/84a7582477fb0d5c82ad6d850fe476829dddd2e1/a2c_ppo_acktr/storage.py#L60
next_obs = envs.reset()
next_done = torch.zeros(args.num_envs).to(device)
num_updates = args.total_timesteps // args.batch_size
## CRASH AND RESUME LOGIC:
starting_update = 1
if args.prod_mode and wandb.run.resumed:
    print("previous run.summary", run.summary)
    starting_update = run.summary['charts/update'] + 1
    global_step = starting_update * args.batch_size
    api = wandb.Api()
    run = api.run(run.get_url()[len("https://app.wandb.ai/"):])
    model = run.file('agent.pt')
    model.download(f"models/{experiment_name}/")
    agent.load_state_dict(torch.load(f"models/{experiment_name}/agent.pt"))
    agent.eval()
    print(f"resumed at update {starting_update}")
for update in range(starting_update, num_updates + 1):
    # Annealing the rate if instructed to do so.
    if args.anneal_lr:
        frac = 1.0 - (update - 1.0) / num_updates
        lrnow = lr(frac)
        optimizer.param_groups[0]['lr'] = lrnow

    # TRY NOT TO MODIFY: prepare the execution of the game.
    for step in range(0, args.num_steps):
Example #30
def main():
    parser = argparse.ArgumentParser(description='Vehicle orientation')
    parser.add_argument('-u','--user', help='username', default='corner')
    parser.add_argument('-p','--project', help='project name', default='cityai2020Orientation')
    parser.add_argument('-r','--run_id', help='run id', default='pe5y029c')
    traindata = False
    is_synthetic = False
    is_track = True

    args = parser.parse_args()

    num_gpus = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1
    print('wandb api')
    api = wandb.Api()
    run = api.run(args.user + '/' + args.project + '/' + args.run_id)
    print('copy wandb configs')
    cfg = copy.deepcopy(run.config)

    if cfg['MODEL.DEVICE'] == "cuda":
        os.environ['CUDA_VISIBLE_DEVICES'] = cfg['MODEL.DEVICE_ID']
    cfg['DATASETS.TRACKS_FILE'] = '/net/merkur/storage/deeplearning/users/eckvik/WorkspaceNeu/VReID/ImageFolderStatistics/Track3Files'
    cudnn.benchmark = True
    print('load dataset')
    if is_synthetic and traindata:
        cfg['DATASETS.SYNTHETIC'] = True
        cfg['DATASETS.SYNTHETIC_LOADER'] = 0
        cfg['DATASETS.SYNTHETIC_DIR'] = 'ai_city_challenge/2020/Track2/AIC20_track2_reid_simulation/AIC20_track2/AIC20_ReID_Simulation'
        dataset = init_dataset('AI_CITY2020_TEST_VAL', cfg=cfg,fold=1,eval_mode=False)
    else:
        if is_track:
            dataset = init_dataset('AI_CITY2020_TRACKS', cfg=cfg,fold=1,eval_mode=False)
        else:
            dataset = init_dataset('AI_CITY2020_TEST_VAL', cfg=cfg,fold=1,eval_mode=False)

    if traindata:
        if is_synthetic and traindata:
            dataset = [item[0] for item in dataset.train]
            dataset = dataset[36935:]
            dataset.sort()
            dataset = [[item, 0,0] for item in dataset]
            val_set = ImageDatasetOrientation(dataset, cfg, is_train=False, test=True)
        else:
            val_set = ImageDatasetOrientation(dataset.train, cfg, is_train=False, test=True)
    else:
        val_set = ImageDatasetOrientation(dataset.query+dataset.gallery, cfg, is_train=False, test=True)
    #
    val_loader = DataLoader(
        val_set, batch_size=cfg['TEST.IMS_PER_BATCH'], shuffle=False, num_workers = cfg['DATALOADER.NUM_WORKERS'],
        collate_fn=val_collate_fn
    )
    print('build model')
    model = build_regression_model(cfg)
    print('get last epoch')
    epoch_best = 10  # run.summary['epoch']
    weights_path = os.path.join(cfg['OUTPUT_DIR'],cfg['MODEL.NAME']+'_model_'+str(epoch_best)+'.pth')
    print('load pretrained weights')
    model.load_param(weights_path)
    model.eval()

    evaluator = create_supervised_evaluator(model, metrics={'score_feat': Score_feats()}, device = cfg['MODEL.DEVICE'])
    print('run')
    evaluator.run(val_loader)
    scores,feats,pids,camids = evaluator.state.metrics['score_feat']
    feats = np.array(feats)
    scores = np.array(scores)
    print('save')
    if traindata:

        if is_track:
            feats_mean = []
            for item in dataset.train_tracks_vID:
                indis = np.array([int(jitem[:6]) - 1 for jitem in item[0]])
                feats_mean.append(np.mean(feats[indis], axis=0))
            feats_mean = np.array(feats_mean)

            scores_mean = []
            for item in dataset.train_tracks_vID:
                indis = np.array([int(jitem[:6]) - 1 for jitem in item[0]])
                scores_mean.append(np.mean(scores[indis], axis=0))

            np.save(os.path.join(cfg['OUTPUT_DIR'], 'feats_train_track.npy'), np.array(feats_mean))  # .npy extension is added if not given
            np.save(os.path.join(cfg['OUTPUT_DIR'], 'scores_train_track.npy'), np.array(scores_mean))  # .npy extension is added if not given
        else:
            if is_synthetic and traindata:
                np.save(os.path.join(cfg['OUTPUT_DIR'], 'feats_train_synthetic.npy'), np.array(feats))  # .npy extension is added if not given
                np.save(os.path.join(cfg['OUTPUT_DIR'], 'scores_train_synthetic.npy'), np.array(scores))  # .npy extension is added if not given
            else:
                np.save(os.path.join(cfg['OUTPUT_DIR'], 'feats_train.npy'), np.array(feats))  # .npy extension is added if not given
                np.save(os.path.join(cfg['OUTPUT_DIR'], 'scores_train.npy'), np.array(scores))  # .npy extension is added if not given
    else:
        if is_track:
            feats_mean = []
            for feat in feats[:1052]:
                feats_mean.append(feat)
            for item in dataset.test_tracks_vID:
                indis = np.array([int(jitem[:6]) - 1 for jitem in item[0]])
                feats_mean.append(np.mean(feats[1052:][indis], axis=0))
            feats_mean = np.array(feats_mean)

            scores_mean = []
            for score in scores[:1052]:
                scores_mean.append(score)
            for item in dataset.test_tracks_vID:
                indis = np.array([int(jitem[:6]) - 1 for jitem in item[0]])
                scores_mean.append(np.mean(scores[1052:][indis], axis=0))
            np.save(os.path.join(cfg['OUTPUT_DIR'], 'feats_query_gal_track.npy'), np.array(feats_mean))  # .npy extension is added if not given
            np.save(os.path.join(cfg['OUTPUT_DIR'], 'scores_query_gal_track.npy'), np.array(scores_mean))  # .npy extension is added if not given
        else:
            np.save(os.path.join(cfg['OUTPUT_DIR'], 'feats_query_gal.npy'), np.array(feats))  # .npy extension is added if not given
            np.save(os.path.join(cfg['OUTPUT_DIR'], 'scores_query_gal.npy'), np.array(scores))  # .npy extension is added if not given
        print(cfg['OUTPUT_DIR'])
        print()
        txt_dir='dist_orient'
        num_query = 1052
        all_mAP = np.zeros(num_query)


        # the same normalized-feature distance computation was repeated
        # verbatim for six feature variants; run it once per (name, array)
        # pair instead
        statistics = [
            ('feats', feats),           # full embedding
            ('xyz', scores),            # all three orientation scores
            ('xy', scores[:, 0:2]),
            ('x', scores[:, 0:1]),
            ('y', scores[:, 1:2]),
            ('z', scores[:, 2:3]),
        ]
        for statistic_name, array in statistics:
            t = torch.from_numpy(array).float().to('cuda')
            feats_normed = torch.nn.functional.normalize(t, dim=1, p=2)
            # query
            qf = feats_normed[:num_query]
            q_pids = np.asarray(pids[:num_query])
            q_camids = np.asarray(camids[:num_query])
            # gallery
            gf = feats_normed[num_query:]
            g_pids = np.asarray(pids[num_query:])
            g_camids = np.asarray(camids[num_query:])
            # squared Euclidean distances via ||q||^2 + ||g||^2 - 2 q.g
            m, n = qf.shape[0], gf.shape[0]
            distmat = torch.pow(qf, 2).sum(dim=1, keepdim=True).expand(m, n) + \
                      torch.pow(gf, 2).sum(dim=1, keepdim=True).expand(n, m).t()
            distmat.addmm_(qf, gf.t(), beta=1, alpha=-2)
            distmat = distmat.cpu().numpy()
            g_camids = np.ones_like(g_camids)
            g_pids = np.ones_like(g_pids)
            generate_image_dir_and_txt(cfg, dataset, txt_dir, distmat, g_pids,
                                       q_pids, g_camids, q_camids, all_mAP,
                                       statistic_name=statistic_name,
                                       max_rank=100)