Esempio n. 1
0
def train(run_id: str, data_dir: str, validate_data_dir: str, models_dir: Path,
          umap_every: int, save_every: int, backup_every: int, vis_every: int,
          validate_every: int, force_restart: bool, visdom_server: str,
          port: str, no_visdom: bool):
    """Train the landmark ``Encoder`` jointly with an ``ArcFace`` head.

    Hyperparameters (``img_per_cls``, ``cls_per_batch``, ``v_img_per_cls``,
    ``v_cls_per_batch``, ``model_embedding_size``, ``num_class``,
    ``learning_rate_init``, ``num_validate``) are module-level names defined
    outside this function.

    Checkpoints are dicts with keys ``step`` (the step to resume from),
    ``model_state``, ``arc_face_state``, ``optimizer_state`` and
    ``scheduler_state``. Checkpoints written before the last two head/scheduler
    keys existed are still accepted.

    Args:
        run_id: Run name. Used for the checkpoint ``<models_dir>/<run_id>.pt``,
            the backup directory ``<models_dir>/<run_id>_backups`` and the
            visualization environment.
        data_dir: Training data root passed to ``LandmarkDataset``.
        validate_data_dir: Validation data root passed to ``LandmarkDataset``.
        models_dir: Directory that receives checkpoints.
        umap_every: Draw an embedding projection every N steps (0 disables).
        save_every: Overwrite the main checkpoint every N steps (0 disables).
        backup_every: Write a numbered backup every N steps, only once past
            step 4000 (0 disables).
        vis_every: Forwarded to ``Visualizations``.
        validate_every: Run ``num_validate`` validation batches every N steps
            (0 disables).
        force_restart: Ignore any existing checkpoint and start from scratch.
        visdom_server: Visualization server address.
        port: Visualization server port.
        no_visdom: Disable the visualization backend.
    """
    # Create the datasets and dataloaders
    train_dataset = LandmarkDataset(data_dir, img_per_cls, train=True)
    train_loader = LandmarkDataLoader(
        train_dataset,
        cls_per_batch,
        img_per_cls,
        num_workers=6,
    )

    validate_dataset = LandmarkDataset(validate_data_dir,
                                       v_img_per_cls,
                                       train=False)
    validate_loader = LandmarkDataLoader(
        validate_dataset,
        v_cls_per_batch,
        v_img_per_cls,
        num_workers=4,
    )
    # NOTE(review): next() is called on this single iterator for the whole
    # run; this assumes LandmarkDataLoader yields enough batches (e.g. cycles
    # indefinitely) -- confirm, otherwise StopIteration will end training.
    validate_iter = iter(validate_loader)

    criterion = torch.nn.CrossEntropyLoss()

    # Forward pass and loss both run on the GPU when one is available. The loss
    # used to be forced onto the CPU because its gradient came back as None on
    # CUDA; fixed by
    # https://github.com/CorentinJ/Real-Time-Voice-Cloning/issues/237
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    loss_device = device

    # The encoder and the ArcFace head both have trainable parameters, so both
    # are optimized and both must be checkpointed.
    model = Encoder(device, loss_device)
    arc_face = ArcFace(model_embedding_size,
                       num_class,
                       scale=30,
                       m=0.35,
                       device=device)

    if torch.cuda.device_count() > 1:
        # NOTE: DataParallel prefixes state_dict keys with "module.", so a
        # checkpoint from a multi-GPU run only loads into an equally wrapped
        # model.
        model = torch.nn.DataParallel(model)
        arc_face = torch.nn.DataParallel(arc_face)
    model.to(device)
    arc_face.to(device)

    optimizer = torch.optim.SGD(
        [{'params': model.parameters()}, {'params': arc_face.parameters()}],
        lr=learning_rate_init,
        momentum=0.9,
    )
    # Stepped once per batch: halves the learning rate every 25k steps.
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer=optimizer,
                                                step_size=25000,
                                                gamma=0.5)

    def _checkpoint(step):
        """Snapshot all training state; "step" is the step to resume from."""
        return {
            "step": step + 1,
            "model_state": model.state_dict(),
            "arc_face_state": arc_face.state_dict(),
            "optimizer_state": optimizer.state_dict(),
            "scheduler_state": scheduler.state_dict(),
        }

    def _l2_normalize(embeddings):
        """Detach, L2-normalize each row, and return a CPU numpy array."""
        embeddings = embeddings.detach()
        return (embeddings /
                torch.norm(embeddings, dim=1, keepdim=True)).cpu().numpy()

    init_step = 1

    # Configure file paths for the model
    state_fpath = models_dir.joinpath(run_id + ".pt")
    backup_dir = models_dir.joinpath(run_id + "_backups")

    # Load any existing model
    if not force_restart:
        if state_fpath.exists():
            print(
                "Found existing model \"%s\", loading it and resuming training."
                % run_id)
            # Load onto the CPU: load_state_dict copies into each parameter's
            # own device, so GPU-written checkpoints also load on CPU hosts.
            checkpoint = torch.load(state_fpath, map_location="cpu")
            init_step = checkpoint["step"]
            model.load_state_dict(checkpoint["model_state"])
            # Older checkpoints lack the head and scheduler state.
            if "arc_face_state" in checkpoint:
                arc_face.load_state_dict(checkpoint["arc_face_state"])
            optimizer.load_state_dict(checkpoint["optimizer_state"])
            if "scheduler_state" in checkpoint:
                scheduler.load_state_dict(checkpoint["scheduler_state"])
            else:
                # No schedule to resume: restart every parameter group (not
                # just the encoder's) from the initial learning rate.
                for group in optimizer.param_groups:
                    group["lr"] = learning_rate_init
        else:
            print("No model \"%s\" found, starting training from scratch." %
                  run_id)
    else:
        print("Starting the training from scratch.")
    model.train()
    arc_face.train()

    # Initialize the visualization environment
    vis = Visualizations(run_id,
                         vis_every,
                         server=visdom_server,
                         port=port,
                         disabled=no_visdom)
    vis.log_dataset(train_dataset)
    vis.log_params()
    device_name = str(
        torch.cuda.get_device_name(0) if torch.cuda.is_available() else "CPU")
    vis.log_implementation({"Device": device_name})

    # Training loop
    profiler = Profiler(summarize_every=500, disabled=False)
    for step, cls_batch in enumerate(train_loader, init_step):
        profiler.tick("Blocking, waiting for batch (threaded)")

        # Forward pass
        inputs = torch.from_numpy(cls_batch.data).float().to(device)
        labels = torch.from_numpy(cls_batch.labels).long().to(device)
        sync(device)
        profiler.tick("Data to %s" % device)

        embeds = model(inputs)
        sync(device)
        profiler.tick("Forward pass")

        # The ArcFace head consumes the labels as well as the embeddings.
        output = arc_face(embeds, labels)
        loss = criterion(output, labels)
        sync(device)
        profiler.tick("Loss")

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        profiler.tick("Backward pass")

        optimizer.step()
        scheduler.step()
        profiler.tick("Parameter update")

        # Update visualizations
        acc = get_acc(output, labels)
        loss_value = loss.item()
        vis.update(loss_value, acc, step)
        print("step {}, loss: {}, acc: {}".format(step, loss_value, acc))

        # Draw projections and save them to the backup folder
        if umap_every != 0 and step % umap_every == 0:
            print("Drawing and saving projections (step %d)" % step)
            projection_dir = backup_dir / 'projections'
            projection_dir.mkdir(exist_ok=True, parents=True)
            projection_fpath = projection_dir.joinpath("%s_umap_%d.png" %
                                                       (run_id, step))
            vis.draw_projections(_l2_normalize(embeds), img_per_cls, step,
                                 projection_fpath)
            vis.save()

        # Overwrite the latest version of the model
        if save_every != 0 and step % save_every == 0:
            print("Saving the model (step %d)" % step)
            state_fpath.parent.mkdir(exist_ok=True, parents=True)
            torch.save(_checkpoint(step), state_fpath)

        # Make a backup (never before 4k steps)
        if backup_every != 0 and step % backup_every == 0 and step > 4000:
            print("Making a backup (step %d)" % step)
            ckpt_dir = backup_dir / 'ckpt'
            ckpt_dir.mkdir(exist_ok=True, parents=True)
            backup_fpath = ckpt_dir.joinpath("%s_%d.pt" % (run_id, step))
            torch.save(_checkpoint(step), backup_fpath)

        # Do validation (skipped entirely when there are no batches to run,
        # since the projection below needs at least one batch of embeddings)
        if (validate_every != 0 and step % validate_every == 0
                and num_validate > 0):
            model.eval()
            arc_face.eval()
            with torch.no_grad():
                for _ in range(num_validate):
                    validate_cls_batch = next(validate_iter)
                    validate_inputs = torch.from_numpy(
                        validate_cls_batch.data).float().to(device)
                    validate_labels = torch.from_numpy(
                        validate_cls_batch.labels).long().to(device)
                    validate_embeds = model(validate_inputs)
                    validate_output = arc_face(validate_embeds,
                                               validate_labels)
                    validate_loss = criterion(validate_output,
                                              validate_labels)
                    validate_acc = get_acc(validate_output, validate_labels)
                    vis.update_validate(validate_loss.item(), validate_acc,
                                        step, num_validate)

                # Take the last batch for drawing the projection
                projection_dir = backup_dir / 'v_projections'
                projection_dir.mkdir(exist_ok=True, parents=True)
                projection_fpath = projection_dir.joinpath(
                    "%s_umap_%d.png" % (run_id, step))
                vis.draw_projections(_l2_normalize(validate_embeds),
                                     v_img_per_cls,
                                     step,
                                     projection_fpath,
                                     is_validate=True)
                vis.save()

            model.train()
            arc_face.train()

        profiler.tick("Extras (visualizations, saving)")
Esempio n. 2
0
def train(run_id: str, clean_data_root: Path, models_dir: Path,
          umap_every: int, save_every: int, backup_every: int, vis_every: int,
          force_restart: bool, visdom_server: str, no_visdom: bool):
    """Train the ``SpeakerEncoder``, resuming from a checkpoint when present.

    Hyperparameters (``speakers_per_batch``, ``utterances_per_speaker``,
    ``learning_rate_init``) are module-level names defined outside this
    function. Checkpoints are dicts with keys ``step`` (the step to resume
    from), ``model_state`` and ``optimizer_state``.

    Args:
        run_id: Run name. Used for the checkpoint ``<models_dir>/<run_id>.pt``,
            the backup directory ``<models_dir>/<run_id>_backups`` and the
            visualization environment.
        clean_data_root: Data root passed to ``SpeakerVerificationDataset``.
        models_dir: Directory that receives checkpoints. It is created if
            missing.
        umap_every: Draw an embedding projection every N steps (0 disables).
        save_every: Overwrite the main checkpoint every N steps (0 disables).
        backup_every: Write a numbered backup every N steps (0 disables).
        vis_every: Forwarded to ``Visualizations``.
        force_restart: Ignore any existing checkpoint and start from scratch.
        visdom_server: Visualization server address.
        no_visdom: Disable the visualization backend.
    """
    # Create a dataset and a dataloader
    dataset = SpeakerVerificationDataset(clean_data_root)
    loader = SpeakerVerificationDataLoader(
        dataset,
        speakers_per_batch,
        utterances_per_speaker,
        num_workers=8,
    )

    # Setup the device on which to run the forward pass and the loss. These can be different,
    # because the forward pass is faster on the GPU whereas the loss is often (depending on your
    # hyperparameters) faster on the CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # FIXME: currently, the gradient is None if loss_device is cuda
    loss_device = torch.device("cpu")

    # Create the model and the optimizer
    model = SpeakerEncoder(device, loss_device)
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate_init)
    init_step = 1

    def _checkpoint(step):
        """Snapshot training state; "step" is the step to resume from."""
        return {
            "step": step + 1,
            "model_state": model.state_dict(),
            "optimizer_state": optimizer.state_dict(),
        }

    # Configure file path for the model
    state_fpath = models_dir.joinpath(run_id + ".pt")
    backup_dir = models_dir.joinpath(run_id + "_backups")

    # Load any existing model
    if not force_restart:
        if state_fpath.exists():
            print(
                "Found existing model \"%s\", loading it and resuming training."
                % run_id)
            # Load onto the CPU: load_state_dict copies into each parameter's
            # own device, so GPU-written checkpoints also load on CPU hosts.
            checkpoint = torch.load(state_fpath, map_location="cpu")
            init_step = checkpoint["step"]
            model.load_state_dict(checkpoint["model_state"])
            optimizer.load_state_dict(checkpoint["optimizer_state"])
            # Restart from the configured learning rate rather than the saved one.
            optimizer.param_groups[0]["lr"] = learning_rate_init
        else:
            print("No model \"%s\" found, starting training from scratch." %
                  run_id)
    else:
        print("Starting the training from scratch.")
    model.train()

    # Initialize the visualization environment
    vis = Visualizations(run_id,
                         vis_every,
                         server=visdom_server,
                         disabled=no_visdom)
    vis.log_dataset(dataset)
    vis.log_params()
    device_name = str(
        torch.cuda.get_device_name(0) if torch.cuda.is_available() else "CPU")
    vis.log_implementation({"Device": device_name})

    # Training loop
    profiler = Profiler(summarize_every=10, disabled=False)
    for step, speaker_batch in enumerate(loader, init_step):
        profiler.tick("Blocking, waiting for batch (threaded)")

        # Forward pass
        inputs = torch.from_numpy(speaker_batch.data).to(device)
        sync(device)
        profiler.tick("Data to %s" % device)
        embeds = model(inputs)
        sync(device)
        profiler.tick("Forward pass")
        # Regroup the flat batch as (speakers, utterances, embedding) and move
        # it to the loss device.
        embeds_loss = embeds.view(
            (speakers_per_batch, utterances_per_speaker, -1)).to(loss_device)
        loss, eer = model.loss(embeds_loss)
        sync(loss_device)
        profiler.tick("Loss")

        # Backward pass
        model.zero_grad()
        loss.backward()
        profiler.tick("Backward pass")
        model.do_gradient_ops()
        optimizer.step()
        profiler.tick("Parameter update")

        # Update visualizations
        vis.update(loss.item(), eer, step)

        # Draw projections and save them to the backup folder
        if umap_every != 0 and step % umap_every == 0:
            print("Drawing and saving projections (step %d)" % step)
            backup_dir.mkdir(exist_ok=True, parents=True)
            projection_fpath = backup_dir.joinpath("%s_umap_%06d.png" %
                                                   (run_id, step))
            embeds = embeds.detach().cpu().numpy()
            vis.draw_projections(embeds, utterances_per_speaker, step,
                                 projection_fpath)
            vis.save()

        # Overwrite the latest version of the model
        if save_every != 0 and step % save_every == 0:
            print("Saving the model (step %d)" % step)
            state_fpath.parent.mkdir(exist_ok=True, parents=True)
            torch.save(_checkpoint(step), state_fpath)

        # Make a backup
        if backup_every != 0 and step % backup_every == 0:
            print("Making a backup (step %d)" % step)
            backup_dir.mkdir(exist_ok=True, parents=True)
            backup_fpath = backup_dir.joinpath("%s_bak_%06d.pt" %
                                               (run_id, step))
            torch.save(_checkpoint(step), backup_fpath)

        profiler.tick("Extras (visualizations, saving)")