def model(self, data, transforms=None):
    """Generative model: decode a latent code and render it under a pose.

    A latent ``z`` is drawn from a standard-normal prior of width
    ``decoder.z_dim``, decoded into a canonical view, resampled on a
    pose-transformed coordinate grid and scored against ``data`` with a
    Bernoulli likelihood at the ``"pixels"`` site.

    Args:
        data: Observed image batch. ``data.shape[0]`` is used as the batch
            size and ``data.device`` as the device of the prior and grid.
            Assumes a (batch, channel, height, width) layout, given the
            ``to_event(3)`` on the likelihood -- TODO confirm.
        transforms: Transform sequence passed to ``random_pose_transform``.
            When ``None``, a translation followed by a rotation is used.
            Previously this argument was overwritten unconditionally.
    """
    if transforms is None:
        transforms = T.TransformSequence(T.Translation(), T.Rotation())

    output_size = self.encoder.insize
    decoder = pyro.module("decoder", self.decoder)

    # NOTE(review): the first positional parameter of pyro.plate is the
    # plate name; the batch size is passed in that slot here -- confirm
    # this is intended.
    with pyro.plate(data.shape[0]):
        # Prior for z.
        z = pyro.sample(
            "z",
            D.Normal(
                torch.zeros(decoder.z_dim, device=data.device),
                torch.ones(decoder.z_dim, device=data.device),
            ).to_event(1),
        )

        # The decoder turns z into an image in its own self-consistent
        # ("upright") frame; it is recorded as a deterministic site.
        view = decoder(z)
        pyro.deterministic("canonical_view", view)

        # Nothing above depends on the pixel values of `data`. During
        # inference the sample sites are presumably replayed with the
        # values proposed by the guide -- verify against the guide.
        grid = coordinates.identity_grid(
            [output_size, output_size], device=data.device)
        grid = grid.expand(data.shape[0], *grid.shape)

        transform = random_pose_transform(transforms)
        transform_grid = transform(grid)

        # Resample the canonical view on the transformed grid to move it
        # into the coordinate frame of the observation.
        transformed_view = T.broadcasting_grid_sample(view, transform_grid)

        pyro.sample(
            "pixels", D.Bernoulli(transformed_view).to_event(3), obs=data)
def evaluate(svi, test_loader, use_cuda=False):
    """Return the summed evaluation loss divided by the test-set size.

    Args:
        svi: Object exposing ``evaluate_loss(batch, transforms)``
            (presumably a Pyro ``SVI`` instance -- confirm with caller).
        test_loader: Iterable of mappings with an ``'image'`` entry; must
            also expose ``.dataset`` with a length.
        use_cuda: When true, each image batch is moved with ``.cuda()``.

    Returns:
        Sum of the per-batch losses divided by ``len(test_loader.dataset)``.
    """
    pose_transforms = T.TransformSequence(T.Translation(), T.Rotation())

    running_loss = 0.
    for batch in test_loader:
        images = batch['image']
        if use_cuda:
            # Move the mini-batch to GPU memory.
            images = images.cuda()
        # Accumulate the ELBO estimate for this mini-batch.
        running_loss += svi.evaluate_loss(images, pose_transforms)

    return running_loss / len(test_loader.dataset)
def train(svi, train_loader, use_cuda=False):
    """Run one training epoch and return the loss divided by dataset size.

    Args:
        svi: Object exposing ``step(batch, transforms)`` that performs one
            optimisation step and returns its loss (presumably a Pyro
            ``SVI`` instance -- confirm with caller).
        train_loader: Iterable of mappings with an ``'image'`` entry; must
            also expose ``.dataset`` with a length.
        use_cuda: When true, each image batch is moved with ``.cuda()``.

    Returns:
        Sum of the per-step losses divided by ``len(train_loader.dataset)``.
    """
    pose_transforms = T.TransformSequence(T.Translation(), T.Rotation())

    running_loss = 0.
    for batch in train_loader:
        images = batch['image']
        if use_cuda:
            # Move the mini-batch to GPU memory.
            images = images.cuda()
        # One gradient step; accumulate the loss it reports.
        running_loss += svi.step(images, pose_transforms)

    return running_loss / len(train_loader.dataset)
def train_fs_epoch(vae, vae_optim, vae_loss_fn, classifier, classifier_optim,
                   classifier_loss_fn, use_cuda, train_loader=None, alpha=1):
    """Fully supervised epoch: jointly step the pose VAE and the classifier.

    For every batch the VAE loss and the classifier loss (computed on the
    second output of ``vae.encoder``) are combined as
    ``vae_loss + alpha * classifier_loss``, back-propagated once, and both
    optimisers are stepped.

    Args:
        vae: Object exposing ``model``, ``guide`` and ``encoder``.
        vae_optim: Optimiser stepped for the VAE.
        vae_loss_fn: Called as ``vae_loss_fn(vae.model, vae.guide, x,
            transforms)``; its result must support ``.item()``.
        classifier: Module whose ``forward`` maps the encoder's second
            output to predictions.
        classifier_optim: Optimiser stepped for the classifier.
        classifier_loss_fn: Called as ``classifier_loss_fn(y_out, y)``.
        use_cuda: When true, inputs and labels are moved with ``.cuda()``.
        train_loader: Iterable of mappings with ``'image'`` and ``'data'``
            entries, exposing ``.dataset``.
        alpha: Weight of the classifier loss in the combined loss.

    Returns:
        Tuple ``(vae_loss, classifier_loss, accuracy, num_steps)`` where the
        first three are epoch sums divided by ``len(train_loader.dataset)``
        and ``num_steps`` is the number of batches processed.
    """
    vae_loss_sum = 0.
    classifier_loss_sum = 0.
    correct = 0.
    num_steps = 0

    for batch in train_loader:
        images = batch['image']
        labels = batch['data']
        pose_transforms = T.TransformSequence(T.Translation(), T.Rotation())
        if use_cuda:
            images = images.cuda()
            labels = labels.cuda()

        classifier_optim.zero_grad()
        vae_optim.zero_grad()

        # ELBO term for the VAE.
        vae_loss = vae_loss_fn(vae.model, vae.guide, images, pose_transforms)

        # Classify from the second encoder output.
        _, split = vae.encoder(images)
        y_out = classifier.forward(split)
        classifier_loss = classifier_loss_fn(y_out, labels)

        combined = vae_loss + alpha * classifier_loss
        vae_loss_sum += vae_loss.item()
        classifier_loss_sum += classifier_loss.item()

        combined.backward()
        vae_optim.step()
        classifier_optim.step()
        num_steps += 1

        # Labels and predictions are compared by argmax along dim 1.
        hits = y_out.argmax(dim=1) == labels.argmax(dim=1)
        correct += hits.sum()

    dataset_size = len(train_loader.dataset)
    return (
        vae_loss_sum / dataset_size,
        classifier_loss_sum / dataset_size,
        correct / dataset_size,
        num_steps,
    )
def train_ss_vae_classifier(vae, vae_optim, vae_loss_fn, classifier,
                            classifier_optim, classifier_loss_fn,
                            train_loader, use_cuda=True):
    """Train the VAE and the classifier for one epoch on labelled data.

    For every batch the VAE loss and the classifier loss are summed,
    back-propagated once, and both optimisers take a step.

    Fixes relative to the previous version:
      * ``epoch_loss_vae / 2*normaliser`` parsed as
        ``(epoch_loss_vae / 2) * normaliser`` and therefore multiplied by
        the dataset size; the VAE loss is now counted once per batch and
        divided by the dataset size.
      * ``vae_optim`` had its gradients zeroed but was never stepped, so
        the VAE parameters were not updated.

    Args:
        vae: Object exposing ``model``, ``guide`` and ``encoder``.
        vae_optim: Optimiser stepped for the VAE.
        vae_loss_fn: Called as ``vae_loss_fn(vae.model, vae.guide, x,
            transforms)``; its result must support ``.item()``.
        classifier: Module whose ``forward`` maps the encoder's second
            output to predictions; put in train mode here.
        classifier_optim: Optimiser stepped for the classifier.
        classifier_loss_fn: Called as ``classifier_loss_fn(y_out, y)``.
        train_loader: Iterable of mappings with ``'image'`` and ``'data'``
            entries, exposing ``.dataset``.
        use_cuda: When true, inputs and labels are moved with ``.cuda()``.

    Returns:
        Tuple ``(vae_loss, classifier_loss, accuracy, num_steps)`` where the
        first three are epoch sums divided by ``len(train_loader.dataset)``
        and ``num_steps`` is the number of batches processed.
    """
    # Train mode so that layers such as dropout are active.
    classifier.train()

    epoch_loss_vae = 0.
    epoch_loss_classifier = 0.
    total_acc = 0.
    num_steps = 0

    for data in train_loader:
        x = data['image']
        y = data['data']
        if use_cuda:
            x = x.cuda()
            y = y.cuda()
        transforms = T.TransformSequence(T.Translation(), T.Rotation())

        classifier_optim.zero_grad()
        vae_optim.zero_grad()

        # Supervised step: ELBO term plus classification term.
        vae_loss = vae_loss_fn(vae.model, vae.guide, x, transforms)
        _, split = vae.encoder(x)
        y_out = classifier.forward(split)
        classifier_loss = classifier_loss_fn(y_out, y)
        total_loss = vae_loss + classifier_loss

        total_loss.backward()
        vae_optim.step()
        classifier_optim.step()

        epoch_loss_vae += vae_loss.item()
        epoch_loss_classifier += classifier_loss.item()
        # Labels and predictions are compared by argmax along dim 1.
        total_acc += torch.sum(
            torch.eq(y_out.argmax(dim=1), y.argmax(dim=1)))
        num_steps += 1

    normaliser = len(train_loader.dataset)
    total_epoch_loss_vae = epoch_loss_vae / normaliser
    total_epoch_loss_classifier = epoch_loss_classifier / normaliser
    total_acc_norm = total_acc / normaliser
    return (total_epoch_loss_vae, total_epoch_loss_classifier,
            total_acc_norm, num_steps)
def sample_img(self, x, use_cuda=False, encoder=False, decoder=False):
    """Encode ``x`` and decode the latent mean back into an image batch.

    No sampling is performed in latent or image space: the ``"z_mu"``
    entry of the encoder output is fed straight to the decoder.

    Args:
        x: Input batch. ``x.shape[0]`` is taken as the batch size and
            ``x.shape[-1]`` as the (square) image side length.
        use_cuda: When true, ``x`` is moved with ``.cuda()`` first.
        encoder: Callable used instead of ``self.encoder``; the default
            ``False`` selects ``self.encoder``. It must return a pair whose
            first element is a mapping containing ``"z_mu"``.
        decoder: Callable used instead of ``self.decoder``; the default
            ``False`` selects ``self.decoder``.

    Returns:
        The decoder output reshaped to
        ``[batch, 1, x.shape[-1], x.shape[-1]]``.
    """
    if use_cuda:
        x = x.cuda()

    batch_shape = x.shape[0]
    img_shape = x.shape[-1]

    # `False` is the "not supplied" sentinel; identity is used so that a
    # supplied module is never compared by value or truthiness.
    encode = self.encoder if encoder is False else encoder
    decode = self.decoder if decoder is False else decoder

    out, _ = encode(x)
    # Use the latent mean rather than a sample.
    z = out["z_mu"]
    loc_img = decode(z)
    return loc_img.reshape([batch_shape, 1, img_shape, img_shape])
def evaluate(vae, vae_loss_fn, classifier, classifier_loss_fn, test_loader,
             use_cuda=False, transform=False):
    """Evaluate the VAE and the classifier over every batch of a loader.

    The whole pass now runs under ``torch.no_grad()``: nothing is
    back-propagated here, so building autograd graphs for every batch only
    cost memory and time.

    Args:
        vae: Object exposing ``model``, ``guide`` and ``encoder``.
        vae_loss_fn: Called as ``vae_loss_fn(vae.model, vae.guide, x,
            transforms)``; its result must support ``.item()``.
        classifier: Module whose ``forward`` maps the encoder's second
            output to predictions; put in eval mode here.
        classifier_loss_fn: Called as ``classifier_loss_fn(y_out, y)``.
        test_loader: Iterable of mappings with ``'image'`` and ``'data'``
            entries, exposing ``.dataset``.
        use_cuda: When true, inputs and labels are moved with ``.cuda()``.
        transform: Optional callable applied to each image batch before it
            is moved to the GPU; the default ``False`` disables it.

    Returns:
        Tuple ``(vae_loss, classifier_loss, accuracy, rms)``, each an epoch
        sum divided by ``len(test_loader.dataset)``.
    """
    # Eval mode so that layers such as dropout are inactive.
    classifier.eval()

    epoch_loss_vae = 0.
    epoch_loss_classifier = 0.
    total_acc = 0.
    rms = 0.

    with torch.no_grad():
        for data in test_loader:
            x = data['image']
            y = data['data']
            if transform is not False:
                x = transform(x)
            if use_cuda:
                x = x.cuda()
                y = y.cuda()

            transforms = T.TransformSequence(T.Translation(), T.Rotation())
            _, split = vae.encoder(x)
            # ELBO term for the VAE.
            vae_loss = vae_loss_fn(vae.model, vae.guide, x, transforms)

            y_out = classifier.forward(split)
            classifier_loss = classifier_loss_fn(y_out, y)

            # Labels and predictions are compared by argmax along dim 1.
            total_acc += torch.sum(
                torch.eq(y_out.argmax(dim=1), y.argmax(dim=1)))
            epoch_loss_vae += vae_loss.item()
            epoch_loss_classifier += classifier_loss.item()
            rms += rms_calc(y_out, y)

    normalizer = len(test_loader.dataset)
    total_epoch_loss_vae = epoch_loss_vae / normalizer
    total_epoch_loss_classifier = epoch_loss_classifier / normalizer
    total_epoch_acc = total_acc / normalizer
    rms_epoch = rms / normalizer
    return (total_epoch_loss_vae, total_epoch_loss_classifier,
            total_epoch_acc, rms_epoch)
def train_ss_vae_classifier(vae, vae_optim, vae_loss_fn, classifier,
                            classifier_optim, classifier_loss_fn,
                            train_s_loader, train_us_loader, use_cuda=True):
    """Semi-supervised epoch over paired labelled and unlabelled batches.

    The shorter loader is cycled so that the epoch lasts as long as the
    longer one. Each iteration performs a supervised step (VAE loss plus
    classifier loss, both optimisers stepped) followed by an unsupervised
    step (VAE loss only, VAE optimiser stepped).

    Fixes relative to the previous version:
      * ``epoch_loss_vae / 2*normaliser`` parsed as
        ``(epoch_loss_vae / 2) * normaliser`` and therefore multiplied by
        the dataset size; it is now ``epoch_loss_vae / (2 * normaliser)``,
        matching the two VAE losses accumulated per iteration.
      * Unlabelled batches are no longer required to carry a ``'data'``
        entry; it was read but never used.

    Args:
        vae: Object exposing ``model``, ``guide`` and ``encoder``.
        vae_optim: Optimiser stepped for the VAE.
        vae_loss_fn: Called as ``vae_loss_fn(vae.model, vae.guide, x,
            transforms)``; its result must support ``.item()``.
        classifier: Module whose ``forward`` maps the encoder's second
            output to predictions; put in train mode here.
        classifier_optim: Optimiser stepped for the classifier.
        classifier_loss_fn: Called as ``classifier_loss_fn(y_out, y)``.
        train_s_loader: Labelled loader yielding mappings with ``'image'``
            and ``'data'`` entries, exposing ``.dataset``.
        train_us_loader: Unlabelled loader yielding mappings with an
            ``'image'`` entry, exposing ``.dataset``.
        use_cuda: When true, tensors are moved with ``.cuda()``.

    Returns:
        Tuple ``(vae_loss, classifier_loss, accuracy, num_steps)``. The
        normaliser is the dataset length of whichever loader has more
        batches; the VAE loss is divided by twice that value, the other
        two sums by the value itself. ``num_steps`` counts iterations.
    """
    # Train mode so that layers such as dropout are active.
    classifier.train()

    epoch_loss_vae = 0.
    epoch_loss_classifier = 0.
    total_acc = 0.
    num_steps = 0

    supervised_len = len(train_s_loader)
    unsupervised_len = len(train_us_loader)
    # Cycle the shorter loader; normalise by the longer loader's dataset.
    if supervised_len > unsupervised_len:
        batch_pairs = zip(train_s_loader, cycle(train_us_loader))
        normaliser = len(train_s_loader.dataset)
    else:
        batch_pairs = zip(cycle(train_s_loader), train_us_loader)
        normaliser = len(train_us_loader.dataset)

    for data_sup, data_unsup in batch_pairs:
        xs = data_sup['image']
        ys = data_sup['data']
        xus = data_unsup['image']
        if use_cuda:
            xs = xs.cuda()
            ys = ys.cuda()
            xus = xus.cuda()

        # Supervised step.
        transforms = T.TransformSequence(T.Translation(), T.Rotation())
        classifier_optim.zero_grad()
        vae_optim.zero_grad()

        vae_loss = vae_loss_fn(vae.model, vae.guide, xs, transforms)
        _, split = vae.encoder(xs)
        y_out = classifier.forward(split)
        classifier_loss = classifier_loss_fn(y_out, ys)
        total_loss = vae_loss + classifier_loss

        epoch_loss_vae += vae_loss.item()
        epoch_loss_classifier += classifier_loss.item()
        # Labels and predictions are compared by argmax along dim 1.
        total_acc += torch.sum(
            torch.eq(y_out.argmax(dim=1), ys.argmax(dim=1)))

        total_loss.backward()
        vae_optim.step()
        classifier_optim.step()

        # Unsupervised step: VAE loss only.
        vae_optim.zero_grad()
        transforms = T.TransformSequence(T.Translation(), T.Rotation())
        vae_loss = vae_loss_fn(vae.model, vae.guide, xus, transforms)
        vae_loss.backward()
        vae_optim.step()

        num_steps += 1
        epoch_loss_vae += vae_loss.item()

    # Two VAE losses are accumulated per iteration, hence the factor 2.
    total_epoch_loss_vae = epoch_loss_vae / (2 * normaliser)
    total_epoch_loss_classifier = epoch_loss_classifier / normaliser
    total_acc_norm = total_acc / normaliser
    return (total_epoch_loss_vae, total_epoch_loss_classifier,
            total_acc_norm, num_steps)