Example #1
0
 def mixcod_grad_norm(mixcod, x1, x2, phase, alpha):
     """Return the per-sample L2 norm of the mixcod loss gradient w.r.t. ``x1``.

     ``x1`` is wrapped as a leaf tensor that requires grad and passed to
     ``mixcod`` together with ``x2`` (which is not differentiated).  The
     scalar loss is the summed mismatch between mixcod's second output and
     ``x1`` plus the summed mismatch between its third and fourth outputs.
     The gradient is built with ``create_graph=True`` so the returned norms
     remain differentiable (presumably for a gradient penalty — confirm at
     the call site).

     Returns:
         Tensor holding one L2 norm per entry along dim 0 of ``x1``.
     """
     x1_leaf = Variable(x1, requires_grad=True)
     _, x1_recon, z1_struct, z1_hat_struct = mixcod(x1_leaf, x2, phase, alpha)
     recon_term = utils.mismatch(x1_recon, x1_leaf, args.match_x_metric)
     struct_term = utils.mismatch(z1_struct, z1_hat_struct,
                                  args.match_z_metric)
     total = recon_term.sum() + struct_term.sum()
     # grad() returns one gradient per input; there is exactly one input here.
     (d_total_d_x1,) = grad(outputs=total, inputs=x1_leaf, create_graph=True)
     # Flatten everything but dim 0 so each sample gets a single norm.
     per_sample = d_total_d_x1.view(d_total_d_x1.size(0), -1)
     return per_sample.norm(2, dim=1)
Example #2
0
    def mix_loss(mixcod, crt1, crt2, x1, x2, phase, alpha):
        """Compute the mixcod losses for one (x1, x2) pair.

        ``mixcod`` returns, in order: the translation of ``x1`` (judged
        against real ``x2`` by ``crt2``), the cyclic reconstruction of ``x1``
        (judged against ``x1`` by ``crt1``), and two latent structures whose
        mismatch is also returned as a loss.

        Returns:
            A tuple ``(losses, cycled, translated)`` where ``losses`` is
            ``[loss_x, loss_z, alpha * loss_crt1, alpha * loss_crt2]`` and the
            two images are detached from the autograd graph.
        """
        translated, cycled, z1_struct, z1_hat_struct = mixcod(
            x1, x2, phase, alpha)
        loss_x = utils.mismatch(cycled, x1, args.match_x_metric)
        loss_z = utils.mismatch(z1_struct, z1_hat_struct, args.match_z_metric)

        # Critic 1: cyclic reconstruction of x1 against the original x1.
        loss_crt1 = real_fake_loss(crt1(x1, x1, phase, alpha),
                                   crt1(cycled, x1, phase, alpha))
        # Critic 2: translated x1 against real x2.
        # NOTE(review): the fake branch is conditioned on the translation
        # itself (second argument), whereas crt1 conditions its fake branch on
        # the real input — confirm this asymmetry is intended.
        loss_crt2 = real_fake_loss(crt2(x2, x2, phase, alpha),
                                   crt2(translated, translated, phase, alpha))

        weighted = [loss_x, loss_z, alpha * loss_crt1, alpha * loss_crt2]
        return weighted, cycled.detach(), translated.detach()
Example #3
0
 def autoenc_grad_norm(autoenc, x, phase, alpha):
     """Return the per-sample L2 norm of the reconstruction-loss gradient w.r.t. ``x``.

     ``x`` is wrapped as a leaf tensor requiring grad and reconstructed by
     ``autoenc``.  The summed mismatch between the reconstruction and ``x``
     is differentiated with ``create_graph=True``, so the returned norms can
     themselves be back-propagated.

     Returns:
         Tensor holding one L2 norm per entry along dim 0 of ``x``.
     """
     x_leaf = Variable(x, requires_grad=True)
     reconstruction = autoenc(x_leaf, phase, alpha)
     total_err = utils.mismatch(reconstruction, x_leaf,
                                args.match_x_metric).sum()
     (input_grad,) = grad(outputs=total_err, inputs=x_leaf, create_graph=True)
     # Collapse all non-batch dims before taking the norm.
     batch = input_grad.size(0)
     return input_grad.view(batch, -1).norm(2, dim=1)
Example #4
0
 def autoenc_loss(autoenc, crt, x, phase, alpha):
     """Return the weighted autoencoder losses and a detached reconstruction.

     ``loss_x`` is the mismatch between ``autoenc(x)`` and ``x``;
     ``loss_crt`` is ``real_fake_loss`` of the critic's scores for the pairs
     ``(x, x)`` and ``(x_fake, x)``.

     Returns:
         A tuple ``(losses, x_fake)`` where ``losses`` is
         ``[autoenc_weight * loss_x, alpha * autoenc_weight * loss_crt]`` and
         ``x_fake`` is detached from the autograd graph.
     """
     x_fake = autoenc(x, phase, alpha)
     loss_x = utils.mismatch(x_fake, x, args.match_x_metric)
     crt_real = crt(x, x, phase, alpha)
     crt_fake = crt(x_fake, x, phase, alpha)
     loss_crt = real_fake_loss(crt_real, crt_fake)
     weight = args.autoenc_weight
     # Both terms share the autoenc_weight multiplier; the critic term is
     # additionally scaled by alpha (multiplication, not exponentiation).
     return [weight * loss_x, alpha * weight * loss_crt], x_fake.detach()
Example #5
0
def updateImages(input, session, num_iters=None):
    """Optimise a smoothed image batch against a frozen encoder/generator/critic.

    ``input`` is Gaussian-smoothed, replaced by its absolute difference to
    sample 1 of the batch, and then treated as the only trainable tensor.
    Each step minimises the reconstruction mismatch between
    ``generator(encoder(x))`` and ``x`` plus ``-critic(fake_x, x).mean()``
    with Adam, then shows sample 0 of ``x`` and of its reconstruction with
    ``imshow`` and clears the figure.

    Args:
        input: image batch; needs at least two entries along dim 0 because
            ``x[[1]]`` is used.  Single-channel is assumed from
            ``GaussianSmoothing(1, ...)`` and ``x[idx, 0]`` — TODO confirm.
        session: provides ``encoder``, ``generator``, ``critic`` and the
            ``phase``/``alpha`` used for the critic call.
        num_iters: number of optimisation steps.  ``None`` (the default)
            loops forever, as before; pass an int to make the function
            terminate and return ``stats``.

    Returns:
        dict with the last ``'x_err'`` and ``'cls_loss'`` values (empty when
        ``num_iters`` is 0).
    """
    encoder, generator, critic = session.encoder, session.generator, session.critic
    # Encoder and generator run at a fixed phase/alpha instead of the
    # session's current values.
    phase = 5
    alpha = 1.0

    stats = {}
    utils.requires_grad(encoder, False)
    utils.requires_grad(generator, False)
    utils.requires_grad(critic, False)

    smoothing = GaussianSmoothing(1, 11.0, 10.0).cuda()
    x = smoothing(input)
    # x[[1]] keeps dim 0, so sample 1 is broadcast-subtracted from every sample.
    x = torch.abs(x - x[[1]])
    x = Variable(x, requires_grad=True)

    # Only x is registered with the optimizer, so only x is updated.
    optimizer = optim.Adam([x], 0.001, betas=(0.0, 0.99))

    step = 0
    while num_iters is None or step < num_iters:
        step += 1
        optimizer.zero_grad()

        real_z = encoder(x, phase, alpha)
        fake_x = generator(real_z, phase, alpha)

        err_x = utils.mismatch(fake_x, x, args.match_x_metric)
        stats['x_err'] = err_x.data

        # NOTE(review): the critic is called with session.phase/session.alpha
        # while encoder/generator use the fixed phase/alpha above — confirm
        # the resolutions agree.
        cls_fake = critic(fake_x, x, session.phase, session.alpha)
        # Raise the critic's score of the reconstruction (no masking here).
        cls_loss = -cls_fake.mean()
        stats['cls_loss'] = cls_loss.data

        loss = err_x + cls_loss
        loss.backward()
        optimizer.step()

        idx = 0
        imshow(x[idx, 0].cpu().data)
        imshow(fake_x[idx, 0].cpu().data)
        clf()

    return stats
import seaborn as sns
import matplotlib.gridspec as gridspec
from itertools import product

# Every (i, j) pair with i, j in 0..19 — presumably pairs of mismatch
# positions handed to mismatch(); one heatmap cell per pair.
pos = list(product(range(20), repeat=2))

# Fitted parameters, one set per (target site, Cas9 variant).
# VEGFA_SpCas9
best_x_1 = [-0.48330546, -0.27027048, 1.71309145, 1.77855314]
# VEGFA_xCas9
best_x_2 = [-1.28784322, -0.45575753, 4.94635742, 0.09800039]
# HEK_site1_SpCas9
best_x_3 = [-1.52845086, -0.54307945, 3.02779871, 1.03082501]
# HEK_site_2_xCas9
# NOTE(review): this panel is titled 'HEK site1 xCas9' below — confirm site.
best_x_4 = [-4.30187549, -0.5218599, 6.3321199, 0.49449049]

out = mismatch(pos)
pred_1, pred_2, pred_3, pred_4 = (
    predict(out, params)
    for params in (best_x_1, best_x_2, best_x_3, best_x_4))

fig = plt.figure(figsize=(10, 8))
gs = gridspec.GridSpec(2, 2)
# Separate axis on the right, presumably for a shared colorbar.
cbar_ax = fig.add_axes([.91, .3, .03, .4])
j = 0
panels = zip(
    [gs[k] for k in range(4)],
    [pred_1, pred_2, pred_3, pred_4],
    ['VEGFA SpCas9', 'VEGFA xCas9', 'HEK site1 SpCas9', 'HEK site1 xCas9'])
for g, pred, title in panels:
    # Scatter the flat predictions back onto the 20 x 20 grid keyed by pos.
    mat = np.zeros((20, 20))
    for i, (index, value) in enumerate(zip(pos, pred)):
        mat[index] = value
Example #7
0
    def update(self, batch, phase, alpha):
        """Run one training step: encoder/generator first, then the critic.

        Args:
            batch: sequence whose first element ``batch[0]`` is the real input
                ``x``; dim 0 of ``x`` is used as the batch size.
            phase, alpha: forwarded to every network call; ``alpha`` also
                scales the generator's critic loss.

        Returns:
            dict of detached diagnostics: ``'x_err'``, ``'z_err'``,
            ``'G_loss'``, ``'grad_loss'``, ``'cls_fake'``, ``'cls_real'`` and
            ``'C_loss'``.  Values are detached so the dict does not keep
            references into the autograd graph.
        """
        encoder, generator, critic = self.encoder, self.generator, self.critic
        stats, losses = {}, []

        ###### Encoder / generator ########
        utils.requires_grad(encoder, True)
        utils.requires_grad(generator, True)
        utils.requires_grad(critic, False)
        encoder.zero_grad()
        generator.zero_grad()

        x = batch[0]
        batch_size = x.shape[0]

        real_z = encoder(x, phase, alpha)
        fake_x = generator(real_z, phase, alpha)

        # Reconstruction error E_x||g(e(x)) - x||: optimised only when
        # args.use_x_metric is set, otherwise computed for logging only.
        if args.use_x_metric:
            err_x = utils.mismatch(fake_x, x, args.match_x_metric)
            losses.append(err_x)
        else:
            with torch.no_grad():
                err_x = utils.mismatch(fake_x, x, args.match_x_metric)
        stats['x_err'] = err_x.detach()

        # Cyclic latent error E_x||e(g(e(x))) - e(x)||, same switch pattern.
        if args.use_z_metric:
            fake_z = encoder(fake_x, phase, alpha)
            err_z = utils.mismatch(real_z, fake_z, args.match_z_metric)
            losses.append(err_z)
        else:
            with torch.no_grad():
                fake_z = encoder(fake_x, phase, alpha)
                err_z = utils.mismatch(real_z, fake_z, args.match_z_metric)
        stats['z_err'] = err_z.detach()

        cls_fake = critic(fake_x, x, phase, alpha)
        cls_real = critic(x, x, phase, alpha)

        # Only samples whose real score exceeds their fake score contribute;
        # the mask is built from detached scores so it carries no gradient.
        G_loss = -(cls_fake *
                   (cls_real.detach() > cls_fake.detach()).float()).mean()
        stats['G_loss'] = G_loss.detach()
        # The critic-based term is scaled by alpha (warm-up).
        losses.append(alpha * G_loss)

        loss = sum(losses)
        loss.backward()
        self.optimizerE.step()
        self.optimizerG.step()

        ###### Critic ########
        losses = []
        utils.requires_grad(critic, True)
        utils.requires_grad(encoder, False)
        utils.requires_grad(generator, False)
        critic.zero_grad()
        # Treat the generated batch as fixed data for the critic update.
        fake_x = fake_x.detach()

        cls_fake = critic(fake_x, x, phase, alpha)
        cls_real = critic(x, x, phase, alpha)

        # Lower the mean fake score and raise the mean real score; the
        # |cf + cr| term penalises the two means drifting off symmetry about 0.
        cf, cr = cls_fake.mean(), cls_real.mean()
        C_loss = cf - cr + torch.abs(cf + cr)

        # NOTE(review): called as (encoder, generator, x, phase, alpha); the
        # autoenc_grad_norm visible elsewhere in this file takes
        # (autoenc, x, phase, alpha) — confirm which version is in scope.
        grad_norm = autoenc_grad_norm(encoder, generator, x, phase,
                                      alpha).mean()
        grad_loss = critic_grad_penalty(critic, x, fake_x, batch_size, phase,
                                        alpha, grad_norm)
        stats['grad_loss'] = grad_loss.detach()
        losses.append(grad_loss)

        stats['cls_fake'] = cf.detach()
        stats['cls_real'] = cr.detach()
        stats['C_loss'] = C_loss.detach()

        losses.append(C_loss)
        loss = sum(losses)
        loss.backward()
        self.optimizerC.step()
        return stats
    # NOTE(review): fragment of a training function whose header and loop tail
    # are not visible here; steps, batch_size, device, smooth, shuffler,
    # latent_dim, G, D and D_optim come from that enclosing scope.
    gen_log = []  # not used in the visible portion; presumably a generator-loss log
    # BCELoss expects inputs in [0, 1], so D is assumed to end in a sigmoid
    # — TODO confirm.
    criterion = torch.nn.BCELoss()

    for step_i in range(1, steps + 1):
        # Create real and fake labels (0/1)
        real_label = torch.ones(batch_size).to(device)
        fake_label = torch.zeros(batch_size).to(device)
        # Labels drawn uniformly from [smooth, 1); not used in the visible
        # portion (presumably label smoothing for a later loss).
        soft_label = torch.Tensor(batch_size).uniform_(smooth, 1).to(device)

        ########## Training the Discriminator  ################################

        # One batch of real images with their hair/eye class tensors.
        real_img, hair_class, eye_class = shuffler.get_batch()
        real_img = real_img.to(device)
        hair_class, eye_class = hair_class.to(device), eye_class.to(device)
        # Conditioning vector: hair and eye classes concatenated along dim 1.
        correct_class = torch.cat((hair_class, eye_class), 1)
        # Presumably utils.mismatch(labels) returns deliberately wrong labels
        # per sample (a one-argument helper) — confirm against utils.
        wrong_hair = utils.mismatch(hair_class).to(device)
        wrong_eye = utils.mismatch(eye_class).to(device)
        wrong_class = torch.cat((wrong_hair, wrong_eye), 1)

        # Generate fakes conditioned on the correct classes.
        z = torch.randn(batch_size, latent_dim).to(device)
        fake_img = G(z, correct_class).to(device)

        # Score (real, correct), (real, wrong) and (fake, correct) pairs.
        real_img_correct_class = D(real_img, correct_class)
        real_img_wrong_class = D(real_img, wrong_class)
        fake_img_correct_class = D(fake_img, correct_class)

        # (real, correct) is the only positive target; the two negatives
        # (real+wrong, fake+correct) are summed and weighted by 0.5.
        discrim_loss = (criterion(real_img_correct_class, real_label) +
                        (criterion(real_img_wrong_class, fake_label) +
                         criterion(fake_img_correct_class, fake_label)) * 0.5)
        D_optim.zero_grad()
        # NOTE(review): fake_img is not detached, so this backward presumably
        # also writes gradients into G's parameters and frees the graph
        # through G — confirm G's grads are zeroed and fake_img is recomputed
        # before the generator update.
        discrim_loss.backward()
import matplotlib.pyplot as plt
import numpy as np
from utils import mismatch, predict_progress
import mpl_toolkits.axisartist as axisartist

# 'seaborn-white' was renamed 'seaborn-v0_8-white' in Matplotlib 3.6 and the
# old name was removed in 3.8; use whichever name this installation provides.
_WHITE_STYLE = ('seaborn-v0_8-white'
                if 'seaborn-v0_8-white' in plt.style.available
                else 'seaborn-white')
plt.style.use(_WHITE_STYLE)

# Fitted parameters passed to predict_progress.
# HEK_site1_SpCas9
best_x_3 = [-1.52845086, -0.54307945, 3.02779871, 1.03082501]
# HEK_site_2_xCas9
# NOTE(review): plotted below with the label 'HEK site1 xCas9' — confirm
# which site these parameters belong to.
best_x_4 = [-4.30187549, -0.5218599, 6.3321199, 0.49449049]

data = mismatch([[], [1, 12]])
energy_3 = predict_progress(data, parms=best_x_3)
energy_4 = predict_progress(data, parms=best_x_4)

fig = plt.figure(figsize=(8, 4))
ax = axisartist.Subplot(fig, 111)
fig.add_axes(ax)
# White vertical line at every integer position 0..22.
for i in range(23):
    ax.axvline(x=i, c="white", ls="-", lw=1)
# Element [1] of each predict_progress result is plotted over 23 positions.
ax.plot(np.arange(23), energy_3[1], 'o-', label='HEK site1 SpCas9',
        c='darkcyan')
ax.plot(np.arange(23), energy_4[1], 'o-.', label='HEK site1 xCas9',
        c='tab:purple')

# Shaded spans: [0, 1] red, [21, 22] green, [1, 21] gray.
ax.axvspan(0, 1, facecolor='red', alpha=0.15)
ax.axvspan(21, 22, facecolor='green', alpha=0.15)
ax.axvspan(1, 21, facecolor='gray', alpha=0.15)
# Draw the bottom axis line with an arrow head.
ax.axis["bottom"].set_axisline_style("-|>", size=1.5)