Example #1
from collections import OrderedDict
from pathlib import Path

import torch

from encoder.model import Encoder  # project-specific import; adjust to your layout


def load_model(weights_fpath: Path, multi_gpu=False, device=None):
    """
    Loads the model in memory. If this function is not explicitly called, it will be run on the 
    first call to embed_frames() with the default weights file.
    
    :param weights_fpath: the path to saved model weights.
    :param device: either a torch device or the name of a torch device (e.g. "cpu", "cuda"). The 
    model will be loaded and run on this device. Outputs, however, will always be on the CPU. 
    If None, defaults to your GPU if one is available, otherwise your CPU.
    """
    # TODO: I think the slow loading of the encoder might have something to do with the device it
    #   was saved on. Worth investigating.
    global _model, _device

    if device is None:
        _device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    elif isinstance(device, str):
        _device = torch.device(device)
    else:
        _device = device

    checkpoint = torch.load(weights_fpath, _device)
    _model = Encoder(_device, _device)

    if multi_gpu:
        if torch.cuda.device_count() <= 1:
            raise "multi_gpu cannot be enabled"

        _model = torch.nn.DataParallel(_model)
        # load params
        _model.load_state_dict(checkpoint["model_state"])
    else:
        # https://discuss.pytorch.org/t/solved-keyerror-unexpected-key-module-encoder-embedding-weight-in-state-dict/1686/3
        state_dict = checkpoint['model_state']
        new_state_dict = OrderedDict()
        for k, v in state_dict.items():
            if k.startswith('module.'):
                name = k[7:]  # strip the `module.` prefix added by DataParallel
                new_state_dict[name] = v
            else:
                new_state_dict[k] = v

        # load params
        _model.load_state_dict(new_state_dict)

    _model = _model.to(_device)
    _model.eval()
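
A minimal usage sketch for the loader above; the checkpoint path below is a placeholder, not a value from the example:

from pathlib import Path

# Hypothetical call; point weights_fpath at your own checkpoint file.
load_model(Path("saved_models/encoder.pt"), multi_gpu=False, device="cuda")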
Example #2
def train(run_id: str, data_dir: str, validate_data_dir: str, models_dir: Path,
          umap_every: int, save_every: int, backup_every: int, vis_every: int,
          validate_every: int, force_restart: bool, visdom_server: str,
          port: str, no_visdom: bool):
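    # Assumption: names such as LandmarkDataset, LandmarkDataLoader, Encoder, ArcFace,
    # Visualizations, Profiler, sync and get_acc, plus hyperparameters like img_per_cls,
    # cls_per_batch and learning_rate_init, come from the project's own modules; the
    # excerpt omits those imports.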
    # Create a dataset and a dataloader
    train_dataset = LandmarkDataset(data_dir, img_per_cls, train=True)
    train_loader = LandmarkDataLoader(
        train_dataset,
        cls_per_batch,
        img_per_cls,
        num_workers=6,
    )

    validate_dataset = LandmarkDataset(validate_data_dir,
                                       v_img_per_cls,
                                       train=False)
    validate_loader = LandmarkDataLoader(
        validate_dataset,
        v_cls_per_batch,
        v_img_per_cls,
        num_workers=4,
    )

    validate_iter = iter(validate_loader)

    criterion = torch.nn.CrossEntropyLoss()

    # Setup the device on which to run the forward pass and the loss. These can be different,
    # because the forward pass is faster on the GPU whereas the loss is often (depending on your
    # hyperparameters) faster on the CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # The gradient used to be None when loss_device was cuda; this was fixed in
    # https://github.com/CorentinJ/Real-Time-Voice-Cloning/issues/237, so the loss
    # can now run on the GPU as well.
    # loss_device = torch.device("cpu")
    loss_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Create the model and the optimizer
    model = Encoder(device, loss_device)
    arc_face = ArcFace(model_embedding_size,
                       num_class,
                       scale=30,
                       m=0.35,
                       device=device)

    multi_gpu = False
    # device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    if torch.cuda.device_count() > 1:
        multi_gpu = True
        model = torch.nn.DataParallel(model)
        arc_face = torch.nn.DataParallel(arc_face)
    model.to(device)
    arc_face.to(device)

    optimizer = torch.optim.SGD([{
        'params': model.parameters()
    }, {
        'params': arc_face.parameters()
    }],
                                lr=learning_rate_init,
                                momentum=0.9)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer=optimizer,
                                                step_size=25000,
                                                gamma=0.5)

    init_step = 1

    # Configure file path for the model
    state_fpath = models_dir.joinpath(run_id + ".pt")
    pretrained_path = state_fpath

    backup_dir = models_dir.joinpath(run_id + "_backups")

    # Load any existing model
    if not force_restart:
        if state_fpath.exists():
            print(
                "Found existing model \"%s\", loading it and resuming training."
                % run_id)
            checkpoint = torch.load(pretrained_path)
            init_step = checkpoint["step"]
            model.load_state_dict(checkpoint["model_state"])
            optimizer.load_state_dict(checkpoint["optimizer_state"])
            optimizer.param_groups[0]["lr"] = learning_rate_init
        else:
            print("No model \"%s\" found, starting training from scratch." %
                  run_id)
    else:
        print("Starting the training from scratch.")
    model.train()

    # Initialize the visualization environment
    vis = Visualizations(run_id,
                         vis_every,
                         server=visdom_server,
                         port=port,
                         disabled=no_visdom)
    vis.log_dataset(train_dataset)
    vis.log_params()
    device_name = str(
        torch.cuda.get_device_name(0) if torch.cuda.is_available() else "CPU")
    vis.log_implementation({"Device": device_name})

    # Training loop
    profiler = Profiler(summarize_every=500, disabled=False)
    for step, cls_batch in enumerate(train_loader, init_step):
        profiler.tick("Blocking, waiting for batch (threaded)")

        # Forward pass
        inputs = torch.from_numpy(cls_batch.data).float().to(device)
        labels = torch.from_numpy(cls_batch.labels).long().to(device)
        sync(device)
        profiler.tick("Data to %s" % device)

        embeds = model(inputs)
        sync(device)
        profiler.tick("Forward pass")

        output = arc_face(embeds, labels)
        loss = criterion(output, labels)
        sync(device)
        profiler.tick("Loss")

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        profiler.tick("Backward pass")

        optimizer.step()
        scheduler.step()
        profiler.tick("Parameter update")

        acc = get_acc(output, labels)
        # Update visualizations
        # learning_rate = optimizer.param_groups[0]["lr"]
        vis.update(loss.item(), acc, step)

        print("step {}, loss: {}, acc: {}".format(step, loss.item(), acc))

        # Draw projections and save them to the backup folder
        if umap_every != 0 and step % umap_every == 0:
            print("Drawing and saving projections (step %d)" % step)
            projection_dir = backup_dir / 'projections'
            projection_dir.mkdir(exist_ok=True, parents=True)
            projection_fpath = projection_dir.joinpath("%s_umap_%d.png" %
                                                       (run_id, step))
            embeds = embeds.detach()
            embeds = (embeds /
                      torch.norm(embeds, dim=1, keepdim=True)).cpu().numpy()
            vis.draw_projections(embeds, img_per_cls, step, projection_fpath)
            vis.save()

        # Overwrite the latest version of the model
        if save_every != 0 and step % save_every == 0:
            print("Saving the model (step %d)" % step)
            torch.save(
                {
                    "step": step + 1,
                    "model_state": model.state_dict(),
                    "optimizer_state": optimizer.state_dict(),
                }, state_fpath)

        # Make a backup
        if backup_every != 0 and step % backup_every == 0:
            if step > 4000:  # don't save until 4k steps
                print("Making a backup (step %d)" % step)

                ckpt_dir = backup_dir / 'ckpt'
                ckpt_dir.mkdir(exist_ok=True, parents=True)
                backup_fpath = ckpt_dir.joinpath("%s_%d.pt" % (run_id, step))
                torch.save(
                    {
                        "step": step + 1,
                        "model_state": model.state_dict(),
                        "optimizer_state": optimizer.state_dict(),
                    }, backup_fpath)

        # Do validation
        if validate_every != 0 and step % validate_every == 0:
            # validation loss, acc
            model.eval()
            for i in range(num_validate):
                with torch.no_grad():
                    validate_cls_batch = next(validate_iter)
                    validate_inputs = torch.from_numpy(
                        validate_cls_batch.data).float().to(device)
                    validate_labels = torch.from_numpy(
                        validate_cls_batch.labels).long().to(device)
                    validate_embeds = model(validate_inputs)
                    validate_output = arc_face(validate_embeds, validate_labels)
                    validate_loss = criterion(validate_output, validate_labels)
                    validate_acc = get_acc(validate_output, validate_labels)

                vis.update_validate(validate_loss.item(), validate_acc, step,
                                    num_validate)

            # take the last one for drawing projection
            projection_dir = backup_dir / 'v_projections'
            projection_dir.mkdir(exist_ok=True, parents=True)
            projection_fpath = projection_dir.joinpath("%s_umap_%d.png" %
                                                       (run_id, step))
            validate_embeds = validate_embeds.detach()
            validate_embeds = (validate_embeds / torch.norm(
                validate_embeds, dim=1, keepdim=True)).cpu().numpy()
            vis.draw_projections(validate_embeds,
                                 v_img_per_cls,
                                 step,
                                 projection_fpath,
                                 is_validate=True)
            vis.save()

            model.train()

        profiler.tick("Extras (visualizations, saving)")
Example #3
import json
from http.server import BaseHTTPRequestHandler, HTTPServer
from io import BytesIO

import numpy as np
from PIL import Image

from encoder.model import Encoder  # project-specific import; adjust to your layout

PORT = 8000


class SimpleHTTPRequestHandler(BaseHTTPRequestHandler):
    def do_GET(self):
        self.send_response(200)
        self.end_headers()
        self.wfile.write(b'Hello, world!')

    def do_POST(self):
        content_length = int(self.headers['Content-Length'])
        body = self.rfile.read(content_length)
        self.send_response(200)
        self.send_header("Content-Type", "application/json")
        self.end_headers()

        img = Image.open(BytesIO(body))
        img = np.array(img)
        img = img / 255

        desc = encoder.predict(np.expand_dims(img, axis=0))[0]
        desc = desc.tolist()

        self.wfile.write(json.dumps(desc).encode("utf-8"))


encoder = Encoder()
encoder.load_weights("encoder.h5")

httpd = HTTPServer(('localhost', PORT), SimpleHTTPRequestHandler)
httpd.serve_forever()
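
To exercise the server above, a client can POST raw image bytes and read back the JSON-encoded embedding. A sketch using the requests library (an assumption; any HTTP client would do):

import requests

# Hypothetical client; "test.jpg" is a placeholder path to any image file.
with open("test.jpg", "rb") as f:
    resp = requests.post("http://localhost:8000", data=f.read())
embedding = resp.json()
print(len(embedding), embedding[:5])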
Example #4
from matplotlib import pyplot as plt
import tensorflow as tf
import numpy as np
import random
from encoder.model import Encoder

model = Encoder()
model.load_weights("encoder.h5")

X_train = np.load("X_train.npy")

diffs = []
for pack in X_train:
  d = model.predict(np.float32(pack))
  diffs.append((d, pack))

SAMPLES = 10
TOP_N = 7

fig = plt.figure(figsize=(20, 20))
axarr = fig.subplots(SAMPLES, TOP_N + 1)
for i in range(SAMPLES):
  i1 = random.randint(0, len(X_train) - 1)
  d1 = random.choice(model.predict(np.float32(X_train[i1])))

  diff = sorted(diffs, key=lambda x: np.average(np.sqrt(np.average(np.square(d1 - x[0]), axis=-1))))

  axarr[i, 0].imshow(X_train[i1][0])
  # advance the column index with each neighbor (the original never incremented it)
  for j, (_, pack) in enumerate(diff[:TOP_N], start=1):
    axarr[i, j].imshow(pack[-1])

plt.show()
Example #5
import random

import numpy as np
import tensorflow as tf

from encoder.model import Encoder  # project-specific import; adjust to your layout


# The head of this excerpt was truncated. The function scaffolding below is a
# reconstruction that matches the indentation of the surviving body; the custom
# loss(ds, mds) function is defined elsewhere in the project.
def train_step(batch, ep):
    ds = []
    mds = []
    with tf.GradientTape() as tape:
        for x in batch:
            d = model(x)
            ds.append(d)
            mds.append(tf.expand_dims(tf.reduce_mean(d, axis=-2), axis=0))

        # print(ds[0])
        # print(d2)

        loss_tensor, std1, std = loss(ds, mds)
        print("{} {} {} {}".format(ep, loss_tensor.numpy(), std1, std))

        gradients = tape.gradient(loss_tensor, model.trainable_variables)
        # print(gradients)
        # Stop training if the gradients have collapsed to NaN
        if np.isnan(gradients[0][0][0][0][0].numpy()):
            return

        optimizer.apply_gradients(zip(gradients, model.trainable_variables))


model = Encoder()
# The excerpt does not show which optimizer was used; Adam is a placeholder.
optimizer = tf.keras.optimizers.Adam()
X_train = np.load("X_train.npy")

try:
    model.load_weights("encoder.h5")
except Exception as e:
    print("Failed to load weights:", e)

for i in range(1000):
    train_step(X_train[random.sample(range(len(X_train)), 2)], i)
    if i % 10 == 0:
        model.save("encoder.h5")
        print("Saved model")