Exemplo n.º 1
0
def get_transformations(transform_spec):
    """Build a transform sequence and a matching transformer sequence.

    For every name in ``transform_spec``:

    * a transformer is created by calling ``TRANSFORMER_STRING_MAP[name]``
      with ``networks.EquivariantPosePredictor, 1, 32``;
    * a transform is created by calling ``TRANSFORM_STRING_MAP[name]()``.

    All transformers are built (and wrapped) before any transform.

    Returns:
        ``(transforms, transformer)``: a ``T.TransformSequence`` and a
        ``transformers.TransformerSequence``, both in ``transform_spec`` order.
    """
    transformer_modules = []
    for name in transform_spec:
        factory = TRANSFORMER_STRING_MAP[name]
        transformer_modules.append(
            factory(networks.EquivariantPosePredictor, 1, 32))
    transformer = transformers.TransformerSequence(*transformer_modules)

    transform_steps = []
    for name in transform_spec:
        transform_steps.append(TRANSFORM_STRING_MAP[name]())
    transforms = T.TransformSequence(*transform_steps)

    return transforms, transformer
Exemplo n.º 2
0
    def model(self,
              data,
              transforms=None):
        """Pyro generative model: decode a latent ``z`` and render it under a pose.

        A latent ``z`` is drawn from a standard-normal prior and decoded into a
        canonical view. That view is resampled on an identity grid warped by
        the transform returned by ``random_pose_transform(transforms)``. The
        resampled view parameterises a Bernoulli likelihood over ``data``.

        Args:
            data: observed image batch. ``data.shape[0]`` is the plate size
                and ``data.device`` is where new tensors are created.
                Presumably (batch, channels, H, W), given the 3 event dims of
                the likelihood -- TODO confirm.
            transforms: transform sequence handed to ``random_pose_transform``.
                When ``None`` (the default), a
                ``T.TransformSequence(T.Translation(), T.Rotation())`` is used.
        """
        output_size = self.encoder.insize
        decoder = pyro.module("decoder", self.decoder)
        with pyro.plate(data.shape[0]):
            # Prior for z: independent standard normals over decoder.z_dim,
            # treated as a single event per batch element.
            z = pyro.sample(
                "z",
                D.Normal(
                    torch.zeros(decoder.z_dim, device=data.device),
                    torch.ones(decoder.z_dim, device=data.device),
                ).to_event(1),
            )

            # Decode z into the canonical (untransformed) view and record it
            # as a deterministic site in the trace.
            view = decoder(z)
            pyro.deterministic("canonical_view", view)

            # Identity sampling grid, expanded to one copy per batch element.
            grid = coordinates.identity_grid(
                [output_size, output_size], device=data.device)
            grid = grid.expand(data.shape[0], *grid.shape)

            if transforms is None:
                # Only fall back to the default when the caller supplied
                # nothing. Previously this was unconditionally overwritten,
                # so the argument was silently ignored.
                transforms = T.TransformSequence(T.Translation(), T.Rotation())
            # NOTE(review): random_pose_transform presumably samples the pose
            # parameters as Pyro sites. When the model is replayed against the
            # guide's trace, those sites take the guide's values -- confirm.
            transform = random_pose_transform(transforms)
            transform_grid = transform(grid)

            # Resample the canonical view onto the warped grid, moving it into
            # the observed coordinate frame.
            transformed_view = T.broadcasting_grid_sample(view, transform_grid)

            # Bernoulli likelihood over pixels, conditioned on the data.
            pyro.sample(
                "pixels", D.Bernoulli(transformed_view).to_event(3), obs=data)
Exemplo n.º 3
0
def evaluate(svi, test_loader, use_cuda=False):
    """Return the summed ``svi.evaluate_loss`` over ``test_loader`` divided by dataset size.

    Each batch is a mapping whose ``'image'`` entry is (optionally) moved to
    the GPU. It is then passed to ``svi.evaluate_loss`` together with a
    translation + rotation ``T.TransformSequence`` built once for the pass.
    The accumulated loss is divided by ``len(test_loader.dataset)``.
    """
    running_loss = 0.
    pose_transforms = T.TransformSequence(T.Translation(), T.Rotation())
    for batch in test_loader:
        images = batch['image']
        if use_cuda:
            images = images.cuda()  # move the mini-batch to GPU memory
        running_loss += svi.evaluate_loss(images, pose_transforms)
    return running_loss / len(test_loader.dataset)
Exemplo n.º 4
0
def train(svi, train_loader, use_cuda=False):
    """Run one training epoch of ``svi``; return summed step losses over dataset size.

    For every batch, the ``'image'`` entry is (optionally) moved to the GPU
    and passed to ``svi.step`` together with a translation + rotation
    ``T.TransformSequence`` built once for the epoch. The accumulated loss is
    divided by ``len(train_loader.dataset)``.
    """
    accumulated = 0.
    pose_transforms = T.TransformSequence(T.Translation(), T.Rotation())
    for batch in train_loader:
        images = batch['image']
        if use_cuda:
            images = images.cuda()  # move the mini-batch to GPU memory
        accumulated += svi.step(images, pose_transforms)
    return accumulated / len(train_loader.dataset)
Exemplo n.º 5
0
def train_fs_epoch(vae,
                   vae_optim,
                   vae_loss_fn,
                   classifier,
                   classifier_optim,
                   classifier_loss_fn,
                   use_cuda,
                   train_loader=None,
                   alpha=1):
    """Fully supervised training of the pose VAE and classifier for one epoch.

    For each batch, the VAE loss ``vae_loss_fn(vae.model, vae.guide, x,
    transforms)`` and ``alpha`` times the classifier loss are summed and
    back-propagated, and then both optimisers step. The classifier loss is
    computed on the encoder's ``split`` output against ``data['data']``.

    Returns:
        ``(vae_loss, classifier_loss, accuracy, num_steps)``, where:

        * the first three values are normalised by
          ``len(train_loader.dataset)``;
        * accuracy counts samples whose predicted argmax equals the target's
          argmax;
        * ``num_steps`` is the number of batches processed.
    """
    # Ensure training mode (e.g. dropout active). An evaluation pass may have
    # left the classifier in eval mode; train_ss_vae_classifier does the same.
    classifier.train()
    # The transform sequence does not depend on the batch, so build it once.
    transforms = T.TransformSequence(T.Translation(), T.Rotation())

    epoch_loss_vae = 0.
    epoch_loss_classifier = 0.
    total_acc = 0.
    num_steps = 0
    for data in train_loader:
        x = data['image']
        y = data['data']
        if use_cuda:
            x = x.cuda()
            y = y.cuda()

        classifier_optim.zero_grad()
        vae_optim.zero_grad()
        # ELBO-style loss for the VAE on this batch.
        vae_loss = vae_loss_fn(vae.model, vae.guide, x, transforms)
        # The classifier consumes the encoder's second output.
        _, split = vae.encoder(x)
        y_out = classifier.forward(split)
        classifier_loss = classifier_loss_fn(y_out, y)

        total_loss = vae_loss + alpha * classifier_loss
        epoch_loss_vae += vae_loss.item()
        epoch_loss_classifier += classifier_loss.item()
        total_loss.backward()
        vae_optim.step()
        classifier_optim.step()

        num_steps += 1
        total_acc += torch.sum(torch.eq(y_out.argmax(dim=1), y.argmax(dim=1)))

    normalizer = len(train_loader.dataset)
    total_epoch_loss_vae = epoch_loss_vae / normalizer
    total_epoch_loss_classifier = epoch_loss_classifier / normalizer
    total_acc_norm = total_acc / normalizer
    return total_epoch_loss_vae, total_epoch_loss_classifier, total_acc_norm, num_steps
def train_ss_vae_classifier(vae, vae_optim, vae_loss_fn, classifier, classifier_optim, classifier_loss_fn, train_loader, use_cuda=True):
    """Jointly train the VAE and classifier for one epoch on labelled data.

    For every batch, the VAE loss (``vae_loss_fn(vae.model, vae.guide, x,
    transforms)``) and the classifier loss are summed and back-propagated,
    and then both optimisers take a step. The classifier loss is computed on
    the encoder's ``split`` output against ``data['data']``.

    Returns:
        ``(vae_loss, classifier_loss, accuracy, num_steps)``, where:

        * the first three values are normalised by
          ``len(train_loader.dataset)``;
        * accuracy counts samples whose predicted argmax equals the target's
          argmax.
    """
    # Classifier in training mode (e.g. dropout enabled).
    classifier.train()
    # The transform sequence does not depend on the batch, so build it once.
    transforms = T.TransformSequence(T.Translation(), T.Rotation())

    epoch_loss_vae = 0.
    epoch_loss_classifier = 0.
    total_acc = 0.
    num_steps = 0

    for data in train_loader:
        x = data['image']
        y = data['data']
        if use_cuda:
            x = x.cuda()
            y = y.cuda()

        classifier_optim.zero_grad()
        vae_optim.zero_grad()
        # Supervised step: VAE loss plus classifier loss on the same batch.
        vae_loss = vae_loss_fn(vae.model, vae.guide, x, transforms)
        _, split = vae.encoder(x)
        y_out = classifier.forward(split)
        classifier_loss = classifier_loss_fn(y_out, y)

        total_loss = vae_loss + classifier_loss
        total_loss.backward()
        # Step BOTH optimisers. Previously only the classifier stepped, so the
        # VAE parameters were never updated.
        vae_optim.step()
        classifier_optim.step()

        # Accumulate each loss exactly once per batch. Previously the VAE loss
        # was added twice.
        epoch_loss_vae += vae_loss.item()
        epoch_loss_classifier += classifier_loss.item()
        total_acc += torch.sum(torch.eq(y_out.argmax(dim=1), y.argmax(dim=1)))
        num_steps += 1

    normaliser = len(train_loader.dataset)
    # One VAE loss term per batch, so normalise by the dataset size only.
    # The old `/ 2*normaliser` parsed as `(loss / 2) * normaliser`.
    total_epoch_loss_vae = epoch_loss_vae / normaliser
    total_epoch_loss_classifier = epoch_loss_classifier / normaliser
    total_acc_norm = total_acc / normaliser
    return total_epoch_loss_vae, total_epoch_loss_classifier, total_acc_norm, num_steps
Exemplo n.º 7
0
 def sample_img(self, x, use_cuda=False, encoder=False, decoder=False):
     """Reconstruct ``x`` by decoding the encoder's ``z_mu`` output.

     Args:
         x: input image batch. ``x.shape[0]`` is the batch size and
             ``x.shape[-1]`` is used as the (square) output side length.
         use_cuda: if truthy, move ``x`` to the GPU first.
         encoder: encoder to use instead of ``self.encoder``. ``False`` (the
             default) or ``None`` means use ``self.encoder``. It must return
             ``(out, split)`` with ``out["z_mu"]``.
         decoder: decoder to use instead of ``self.decoder``. ``False`` (the
             default) or ``None`` means use ``self.decoder``.

     Returns:
         The decoded image reshaped to ``[batch, 1, side, side]``.
     """
     if use_cuda:
         x = x.cuda()
     batch_shape = x.shape[0]
     img_shape = x.shape[-1]
     enc = self.encoder if encoder is False or encoder is None else encoder
     dec = self.decoder if decoder is False or decoder is None else decoder
     out, _ = enc(x)
     # Decode out["z_mu"] (presumably the latent mean) directly instead of
     # drawing a latent sample. No sampling happens in image space either.
     z = out["z_mu"]
     loc_img = dec(z)
     return loc_img.reshape([batch_shape, 1, img_shape, img_shape])
Exemplo n.º 8
0
def evaluate(vae,
             vae_loss_fn,
             classifier,
             classifier_loss_fn,
             test_loader,
             use_cuda=False,
             transform=False):
    """Evaluate the VAE and classifier over every batch of ``test_loader``.

    Args:
        vae: object exposing ``encoder``, ``model`` and ``guide``.
        vae_loss_fn: called as ``vae_loss_fn(vae.model, vae.guide, x,
            transforms)``; its result must support ``.item()``.
        classifier: called on the encoder's ``split`` output.
        classifier_loss_fn: called as ``classifier_loss_fn(y_out, y)``.
        test_loader: iterable of mappings with ``'image'`` and ``'data'``
            entries; ``test_loader.dataset`` is used for normalisation.
        use_cuda: if truthy, move inputs and targets to the GPU.
        transform: optional callable applied to each image batch before it
            is moved to the device. ``False`` (the default) or ``None``
            disables it.

    Returns:
        ``(vae_loss, classifier_loss, accuracy, rms)``, each divided by
        ``len(test_loader.dataset)``.
    """
    # Classifier in eval mode (e.g. dropout disabled).
    classifier.eval()
    # The transform sequence does not depend on the batch, so build it once.
    transforms = T.TransformSequence(T.Translation(), T.Rotation())
    epoch_loss_vae = 0.
    epoch_loss_classifier = 0.
    total_acc = 0.
    rms = 0.
    # No gradients are needed for evaluation. Without this, every batch builds
    # an autograd graph, and accumulated tensors (e.g. `rms`) can keep them alive.
    with torch.no_grad():
        for data in test_loader:
            x = data['image']
            y = data['data']
            if transform is not False and transform is not None:
                x = transform(x)
            if use_cuda:
                x = x.cuda()
                y = y.cuda()
            _, split = vae.encoder(x)
            vae_loss = vae_loss_fn(vae.model, vae.guide, x, transforms)
            y_out = classifier.forward(split)
            classifier_loss = classifier_loss_fn(y_out, y)
            total_acc += torch.sum(torch.eq(y_out.argmax(dim=1), y.argmax(dim=1)))
            epoch_loss_vae += vae_loss.item()
            epoch_loss_classifier += classifier_loss.item()
            rms += rms_calc(y_out, y)
    normalizer = len(test_loader.dataset)
    total_epoch_loss_vae = epoch_loss_vae / normalizer
    total_epoch_loss_classifier = epoch_loss_classifier / normalizer
    total_epoch_acc = total_acc / normalizer
    rms_epoch = rms / normalizer
    return total_epoch_loss_vae, total_epoch_loss_classifier, total_epoch_acc, rms_epoch
def _endless_loader(loader):
    """Yield batches from ``loader`` indefinitely, re-iterating it on each pass.

    Unlike ``itertools.cycle``, batches from the first pass are not cached:
    * each pass iterates ``loader`` afresh (so a shuffling loader reshuffles);
    * a whole epoch of tensors is not kept alive.

    Stops if a full pass yields nothing, so an empty loader cannot spin forever.
    """
    while True:
        produced_any = False
        for batch in loader:
            produced_any = True
            yield batch
        if not produced_any:
            return


def train_ss_vae_classifier(vae, vae_optim, vae_loss_fn, classifier, classifier_optim, classifier_loss_fn,
                         train_s_loader, train_us_loader, use_cuda=True):
    """Semi-supervised training of the VAE and classifier for one epoch.

    Labelled (``train_s_loader``) and unlabelled (``train_us_loader``)
    batches are paired. The loader with fewer batches is repeated, so the
    epoch runs for as many batches as the longer loader. For each pair:

    * supervised step: the VAE loss plus the classifier loss on the labelled
      batch is back-propagated, then both optimisers step;
    * unsupervised step: the VAE loss alone on the unlabelled batch is
      back-propagated, then the VAE optimiser steps.

    Returns:
        ``(vae_loss, classifier_loss, accuracy, num_steps)``. The normaliser
        is the dataset length of ``train_s_loader`` if it has more batches,
        otherwise of ``train_us_loader``. The VAE loss sums two terms per
        step, so it is divided by ``2 * normaliser``.
    """
    # Classifier in training mode (e.g. dropout enabled).
    classifier.train()

    epoch_loss_vae = 0.
    epoch_loss_classifier = 0.
    total_acc = 0.
    num_steps = 0
    supervised_len = len(train_s_loader)
    unsupervised_len = len(train_us_loader)
    if supervised_len > unsupervised_len:
        pairs = zip(train_s_loader, _endless_loader(train_us_loader))
        normaliser = len(train_s_loader.dataset)
    else:
        pairs = zip(_endless_loader(train_s_loader), train_us_loader)
        normaliser = len(train_us_loader.dataset)

    # The transform sequence does not depend on the batch, so build it once.
    transforms = T.TransformSequence(T.Translation(), T.Rotation())
    for data_sup, data_unsup in pairs:
        xs = data_sup['image']
        ys = data_sup['data']
        xus = data_unsup['image']
        if use_cuda:
            xs = xs.cuda()
            ys = ys.cuda()
            xus = xus.cuda()

        # Supervised step.
        classifier_optim.zero_grad()
        vae_optim.zero_grad()
        vae_loss = vae_loss_fn(vae.model, vae.guide, xs, transforms)
        _, split = vae.encoder(xs)
        y_out = classifier.forward(split)
        classifier_loss = classifier_loss_fn(y_out, ys)

        total_loss = vae_loss + classifier_loss
        epoch_loss_vae += vae_loss.item()
        epoch_loss_classifier += classifier_loss.item()
        total_acc += torch.sum(torch.eq(y_out.argmax(dim=1), ys.argmax(dim=1)))
        total_loss.backward()
        vae_optim.step()
        classifier_optim.step()

        # Unsupervised step: VAE only, on the unlabelled batch.
        vae_optim.zero_grad()
        vae_loss = vae_loss_fn(vae.model, vae.guide, xus, transforms)
        vae_loss.backward()
        vae_optim.step()
        epoch_loss_vae += vae_loss.item()
        num_steps += 1

    # Parenthesised: the old `/ 2*normaliser` parsed as `(loss / 2) * normaliser`.
    total_epoch_loss_vae = epoch_loss_vae / (2 * normaliser)
    total_epoch_loss_classifier = epoch_loss_classifier / normaliser
    total_acc_norm = total_acc / normaliser
    return total_epoch_loss_vae, total_epoch_loss_classifier, total_acc_norm, num_steps