def test_relu(self):
    # ***************************************************************
    # Test ReLU Layer
    # ***************************************************************
    arr = np.random.randn(16, 10, 224, 224)
    check_equal(arr, jnn.ReLU(), tnn.ReLU())

    # ***************************************************************
    # Test PReLU Layer
    # ***************************************************************
    arr = np.random.randn(16, 10, 224, 224)
    check_equal(arr, jnn.PReLU(), tnn.PReLU())
    check_equal(arr, jnn.PReLU(10, 99.9), tnn.PReLU(10, 99.9))
    check_equal(arr, jnn.PReLU(10, 2), tnn.PReLU(10, 2))
    check_equal(arr, jnn.PReLU(10, -0.2), tnn.PReLU(10, -0.2))

    # ***************************************************************
    # Test ReLU6 Layer
    # ***************************************************************
    arr = np.random.randn(16, 10, 224, 224)
    check_equal(arr, jnn.ReLU6(), tnn.ReLU6())

    # ***************************************************************
    # Test LeakyReLU Layer
    # ***************************************************************
    arr = np.random.randn(16, 10, 224, 224)
    check_equal(arr, jnn.LeakyReLU(), tnn.LeakyReLU())
    check_equal(arr, jnn.LeakyReLU(2), tnn.LeakyReLU(2))
    check_equal(arr, jnn.LeakyReLU(99.9), tnn.LeakyReLU(99.9))

    # ***************************************************************
    # Test ELU Layer
    # ***************************************************************
    arr = np.random.randn(16, 10, 224, 224)
    check_equal(arr, jnn.ELU(), tnn.ELU())
    check_equal(arr, jnn.ELU(0.3), tnn.ELU(0.3))
    check_equal(arr, jnn.ELU(2), tnn.ELU(2))
    check_equal(arr, jnn.ELU(99.9), tnn.ELU(99.9))

    # ***************************************************************
    # Test GELU Layer
    # ***************************************************************
    if hasattr(tnn, "GELU"):
        arr = np.random.randn(16, 10, 224, 224)
        check_equal(arr, jnn.GELU(), tnn.GELU())

    # ***************************************************************
    # Test Softplus Layer
    # ***************************************************************
    arr = np.random.randn(16, 10, 224, 224)
    check_equal(arr, jnn.Softplus(), tnn.Softplus())
    check_equal(arr, jnn.Softplus(2), tnn.Softplus(2))
    check_equal(arr, jnn.Softplus(2, 99.9), tnn.Softplus(2, 99.9))
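# `check_equal` is a helper defined elsewhere in this test file; a minimal
# sketch of what it is assumed to do (run the same numpy input through the
# Jittor layer and the PyTorch layer, then compare outputs numerically):
#
#   def check_equal(arr, j_layer, t_layer):
#       j_out = j_layer(jt.array(arr.astype(np.float32)))
#       t_out = t_layer(torch.from_numpy(arr.astype(np.float32)))
#       np.testing.assert_allclose(j_out.numpy(), t_out.detach().numpy(),
#                                  rtol=1e-5, atol=1e-5)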
def __init__(self, cin, cout, zdim=128, nf=64):
    super(ConfNet, self).__init__()
    # Encoder: four stride-2 downsampling convs, then a 4x4 conv to a zdim code.
    network = [
        nn.Conv(cin, nf, 4, stride=2, padding=1, bias=False),
        nn.GroupNorm(16, nf),
        nn.LeakyReLU(scale=0.2),
        nn.Conv(nf, nf * 2, 4, stride=2, padding=1, bias=False),
        nn.GroupNorm(16 * 2, nf * 2),
        nn.LeakyReLU(scale=0.2),
        nn.Conv(nf * 2, nf * 4, 4, stride=2, padding=1, bias=False),
        nn.GroupNorm(16 * 4, nf * 4),
        nn.LeakyReLU(scale=0.2),
        nn.Conv(nf * 4, nf * 8, 4, stride=2, padding=1, bias=False),
        nn.LeakyReLU(scale=0.2),
        nn.Conv(nf * 8, zdim, 4, stride=1, padding=0, bias=False),
        nn.ReLU(),
    ]
    # Shared decoder: transposed convs upsample the code back up.
    network += [
        nn.ConvTranspose(zdim, nf * 8, 4, padding=0, bias=False),
        nn.ReLU(),
        nn.ConvTranspose(nf * 8, nf * 4, 4, stride=2, padding=1, bias=False),
        nn.GroupNorm(16 * 4, nf * 4),
        nn.ReLU(),
        nn.ConvTranspose(nf * 4, nf * 2, 4, stride=2, padding=1, bias=False),
        nn.GroupNorm(16 * 2, nf * 2),
        nn.ReLU(),
    ]
    self.network = nn.Sequential(*network)

    # Output head 1: two more upsampling steps to a full-resolution 2-channel map.
    out_net1 = [
        nn.ConvTranspose(nf * 2, nf, 4, stride=2, padding=1, bias=False),
        nn.GroupNorm(16, nf),
        nn.ReLU(),
        nn.ConvTranspose(nf, nf, 4, stride=2, padding=1, bias=False),
        nn.GroupNorm(16, nf),
        nn.ReLU(),
        nn.Conv(nf, 2, 5, stride=1, padding=2, bias=False),
        nn.Softplus(),
    ]
    self.out_net1 = nn.Sequential(*out_net1)

    # Output head 2: a low-resolution 2-channel map predicted directly.
    out_net2 = [
        nn.Conv(nf * 2, 2, 3, stride=1, padding=1, bias=False),
        nn.Softplus(),
    ]
    self.out_net2 = nn.Sequential(*out_net2)
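# Shape sketch (an assumption; the input resolution is not specified here). With
# cin=3, nf=64 and 64x64 inputs, each stride-2 conv halves the spatial size: the
# encoder maps (B, 3, 64, 64) -> (B, zdim, 1, 1), the shared decoder in
# `self.network` upsamples back to (B, nf*2, 16, 16), `out_net1` continues to a
# full-resolution (B, 2, 64, 64) map, and `out_net2` branches off at quarter
# resolution with (B, 2, 16, 16). The final Softplus keeps both confidence maps
# strictly positive.
#
#   net = ConfNet(cin=3, cout=2)        # hypothetical instantiation
#   feat = net.network(x)               # (B, 3, 64, 64) -> (B, 128, 16, 16)
#   conf_full = net.out_net1(feat)      # (B, 2, 64, 64)
#   conf_low = net.out_net2(feat)       # (B, 2, 16, 16)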
def dis_loss(self, real_samps, fake_samps, height, alpha, r1_gamma=10.0):
    # Obtain predictions
    r_preds = self.dis(real_samps, height, alpha)
    f_preds = self.dis(fake_samps, height, alpha)

    # loss = torch.mean(nn.Softplus()(f_preds)) + torch.mean(nn.Softplus()(-r_preds))
    loss = jt.mean(nn.Softplus()(f_preds)) + jt.mean(nn.Softplus()(-r_preds))

    if r1_gamma != 0.0:
        r1_penalty = self.R1Penalty(real_samps.detach(), height, alpha) * (r1_gamma * 0.5)
        loss += r1_penalty
    return loss
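# The softplus pair above is the non-saturating ("logistic") GAN discriminator
# loss. Since softplus(x) = log(1 + e^x):
#   L_D = E[softplus(D(fake))] + E[softplus(-D(real))]
#       = -E[log(1 - sigmoid(D(fake)))] - E[log(sigmoid(D(real)))]
# The optional R1 term penalizes the squared gradient norm of D at real samples,
# scaled by r1_gamma / 2 (hence the 0.5 factor).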
def gen_loss(self, _, fake_samps, height, alpha):
    f_preds = self.dis(fake_samps, height, alpha)
    # print(f_preds.is_stop_grad())
    # return torch.mean(nn.Softplus()(-f_preds))
    return jt.mean(nn.Softplus()(-f_preds))
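# Matching non-saturating generator loss:
#   L_G = E[softplus(-D(fake))] = -E[log(sigmoid(D(fake)))]
# which pushes D(fake) upward without the vanishing gradients of the minimax form.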