def softshrink(x, alpha=0.5, inplace=False):
    r"""Real soft shrink function.

    Computes :math:`\mathrm{sgn}(x) \cdot \max(|x| - \alpha, 0)` elementwise:
    values with magnitude below the threshold are zeroed, the rest are
    shrunk toward zero by ``alpha`` with their sign preserved.

    Parameters
    ----------
    x : tensor
        The input.
    alpha : float, optional
        The threshold (default 0.5).
    inplace : bool, optional
        If ``True``, the ReLU is applied in place on the temporary
        ``x.abs() - alpha`` tensor (default ``False``).

    Returns
    -------
    tensor
        The output.
    """
    # sgn restores the sign that abs() discarded; relu zeroes entries
    # whose magnitude does not exceed the threshold.
    return th.sgn(x) * relu(x.abs() - alpha, inplace=inplace)
def test_sgn(self):
    """sgn on real and complex arrays yields the expected signs, dtype,
    shape and device."""
    # real-valued input
    real_in = ht.array([-1, -0.5, 0, 0.5, 1])
    real_out = ht.sgn(real_in)
    expected = ht.array([-1.0, -1, 0, 1, 1])
    self.assertEqual(real_out.dtype, expected.dtype)
    self.assertEqual(real_out.shape, expected.shape)
    self.assertEqual(real_out.device, real_in.device)
    self.assertTrue(ht.equal(real_out, expected))

    # complex-valued input, distributed along split=0; compare against
    # torch.sgn on the same values
    cplx_in = ht.array([[1 - 2j, -0.5 + 1j], [0 - 3j, 4 + 6j]], split=0)
    cplx_out = ht.sgn(cplx_in)
    reference = torch.sgn(
        torch.tensor([[1 - 2j, -0.5 + 1j], [0 - 3j, 4 + 6j]]))
    reference = reference.to(cplx_in.device.torch_device)
    self.assertEqual(cplx_out.dtype, ht.heat_type_of(reference))
    self.assertEqual(cplx_out.shape, cplx_in.shape)
    self.assertEqual(cplx_out.device, cplx_in.device)
    self.assertTrue(ht.equal(cplx_out, ht.array(reference, split=0)))
def pointwise_ops(self):
    """Exercise a broad sweep of torch pointwise operators and return all
    results as one tuple.

    Several inputs are random (``randn``/``rand``/``randint``), so the
    returned values are not deterministic across calls.
    """
    a = torch.randn(4)
    b = torch.randn(4)
    t = torch.tensor([-1, -2, 3], dtype=torch.int8)
    r = torch.tensor([0, 1, 10, 0], dtype=torch.int8)
    # NOTE(review): duplicate assignment — `t` was already set to the same
    # value two lines above; the second assignment is redundant.
    t = torch.tensor([-1, -2, 3], dtype=torch.int8)
    s = torch.tensor([4, 0, 1, 0], dtype=torch.int8)
    f = torch.zeros(3)
    g = torch.tensor([-1, 0, 1])
    w = torch.tensor([0.3810, 1.2774, -0.2972, -0.3719, 0.4637])
    return (
        torch.abs(torch.tensor([-1, -2, 3])),
        torch.absolute(torch.tensor([-1, -2, 3])),
        torch.acos(a),
        torch.arccos(a),
        # NOTE(review): uniform_ mutates `a` in place — every later use of
        # `a` sees values drawn from [1, 2) (and later from [-1, 1)).
        torch.acosh(a.uniform_(1.0, 2.0)),
        torch.add(a, 20),
        torch.add(a, torch.randn(4, 1), alpha=10),  # broadcasts to 4x4
        torch.addcdiv(torch.randn(1, 3), torch.randn(3, 1),
                      torch.randn(1, 3), value=0.1),
        torch.addcmul(torch.randn(1, 3), torch.randn(3, 1),
                      torch.randn(1, 3), value=0.1),
        torch.angle(a),
        torch.asin(a),
        torch.arcsin(a),
        torch.asinh(a),
        torch.arcsinh(a),
        torch.atan(a),
        torch.arctan(a),
        # atanh/arctanh need inputs in (-1, 1); uniform_ re-fills `a`
        torch.atanh(a.uniform_(-1.0, 1.0)),
        torch.arctanh(a.uniform_(-1.0, 1.0)),
        torch.atan2(a, a),
        torch.bitwise_not(t),
        torch.bitwise_and(t, torch.tensor([1, 0, 3], dtype=torch.int8)),
        torch.bitwise_or(t, torch.tensor([1, 0, 3], dtype=torch.int8)),
        torch.bitwise_xor(t, torch.tensor([1, 0, 3], dtype=torch.int8)),
        torch.ceil(a),
        torch.clamp(a, min=-0.5, max=0.5),
        torch.clamp(a, min=0.5),
        torch.clamp(a, max=0.5),
        torch.clip(a, min=-0.5, max=0.5),
        torch.conj(a),
        torch.copysign(a, 1),
        torch.copysign(a, b),
        torch.cos(a),
        torch.cosh(a),
        torch.deg2rad(
            torch.tensor([[180.0, -180.0], [360.0, -360.0], [90.0, -90.0]])),
        torch.div(a, b),
        torch.divide(a, b, rounding_mode="trunc"),
        torch.divide(a, b, rounding_mode="floor"),
        torch.digamma(torch.tensor([1.0, 0.5])),
        torch.erf(torch.tensor([0.0, -1.0, 10.0])),
        torch.erfc(torch.tensor([0.0, -1.0, 10.0])),
        torch.erfinv(torch.tensor([0.0, 0.5, -1.0])),
        torch.exp(torch.tensor([0.0, math.log(2.0)])),
        torch.exp2(torch.tensor([0.0, math.log(2.0), 3.0, 4.0])),
        torch.expm1(torch.tensor([0.0, math.log(2.0)])),
        torch.fake_quantize_per_channel_affine(
            torch.randn(2, 2, 2),
            (torch.randn(2) + 1) * 0.05,
            torch.zeros(2),
            1,
            0,
            255,
        ),
        torch.fake_quantize_per_tensor_affine(a, 0.1, 0, 0, 255),
        torch.float_power(torch.randint(10, (4, )), 2),
        torch.float_power(torch.arange(1, 5), torch.tensor([2, -3, 4, -5])),
        torch.floor(a),
        # torch.floor_divide(torch.tensor([4.0, 3.0]), torch.tensor([2.0, 2.0])),
        # torch.floor_divide(torch.tensor([4.0, 3.0]), 1.4),
        torch.fmod(torch.tensor([-3, -2, -1, 1, 2, 3]), 2),
        torch.fmod(torch.tensor([1, 2, 3, 4, 5]), 1.5),
        torch.frac(torch.tensor([1.0, 2.5, -3.2])),
        torch.randn(4, dtype=torch.cfloat).imag,
        torch.ldexp(torch.tensor([1.0]), torch.tensor([1])),
        torch.ldexp(torch.tensor([1.0]), torch.tensor([1, 2, 3, 4])),
        torch.lerp(torch.arange(1.0, 5.0), torch.empty(4).fill_(10), 0.5),
        torch.lerp(
            torch.arange(1.0, 5.0),
            torch.empty(4).fill_(10),
            torch.full_like(torch.arange(1.0, 5.0), 0.5),
        ),
        torch.lgamma(torch.arange(0.5, 2, 0.5)),
        torch.log(torch.arange(5) + 10),
        torch.log10(torch.rand(5)),
        torch.log1p(torch.randn(5)),
        torch.log2(torch.rand(5)),
        torch.logaddexp(torch.tensor([-1.0]), torch.tensor([-1, -2, -3])),
        torch.logaddexp(torch.tensor([-100.0, -200.0, -300.0]),
                        torch.tensor([-1, -2, -3])),
        torch.logaddexp(torch.tensor([1.0, 2000.0, 30000.0]),
                        torch.tensor([-1, -2, -3])),
        torch.logaddexp2(torch.tensor([-1.0]), torch.tensor([-1, -2, -3])),
        torch.logaddexp2(torch.tensor([-100.0, -200.0, -300.0]),
                         torch.tensor([-1, -2, -3])),
        torch.logaddexp2(torch.tensor([1.0, 2000.0, 30000.0]),
                         torch.tensor([-1, -2, -3])),
        torch.logical_and(r, s),
        torch.logical_and(r.double(), s.double()),
        torch.logical_and(r.double(), s),
        torch.logical_and(r, s, out=torch.empty(4, dtype=torch.bool)),
        torch.logical_not(torch.tensor([0, 1, -10], dtype=torch.int8)),
        torch.logical_not(
            torch.tensor([0.0, 1.5, -10.0], dtype=torch.double)),
        torch.logical_not(
            torch.tensor([0.0, 1.0, -10.0], dtype=torch.double),
            out=torch.empty(3, dtype=torch.int16),
        ),
        torch.logical_or(r, s),
        torch.logical_or(r.double(), s.double()),
        torch.logical_or(r.double(), s),
        torch.logical_or(r, s, out=torch.empty(4, dtype=torch.bool)),
        torch.logical_xor(r, s),
        torch.logical_xor(r.double(), s.double()),
        torch.logical_xor(r.double(), s),
        torch.logical_xor(r, s, out=torch.empty(4, dtype=torch.bool)),
        torch.logit(torch.rand(5), eps=1e-6),
        torch.hypot(torch.tensor([4.0]), torch.tensor([3.0, 4.0, 5.0])),
        torch.i0(torch.arange(5, dtype=torch.float32)),
        torch.igamma(a, b),
        torch.igammac(a, b),
        torch.mul(torch.randn(3), 100),
        torch.multiply(torch.randn(4, 1), torch.randn(1, 4)),
        torch.mvlgamma(torch.empty(2, 3).uniform_(1.0, 2.0), 2),
        torch.tensor([float("nan"), float("inf"), -float("inf"), 3.14]),
        torch.nan_to_num(w),
        torch.nan_to_num(w, nan=2.0),
        torch.nan_to_num(w, nan=2.0, posinf=1.0),
        torch.neg(torch.randn(5)),
        # torch.nextafter(torch.tensor([1, 2]), torch.tensor([2, 1])) == torch.tensor([eps + 1, 2 - eps]),
        torch.polygamma(1, torch.tensor([1.0, 0.5])),
        torch.polygamma(2, torch.tensor([1.0, 0.5])),
        torch.polygamma(3, torch.tensor([1.0, 0.5])),
        torch.polygamma(4, torch.tensor([1.0, 0.5])),
        torch.pow(a, 2),
        torch.pow(torch.arange(1.0, 5.0), torch.arange(1.0, 5.0)),
        torch.rad2deg(
            torch.tensor([[3.142, -3.142], [6.283, -6.283],
                          [1.570, -1.570]])),
        torch.randn(4, dtype=torch.cfloat).real,
        torch.reciprocal(a),
        torch.remainder(torch.tensor([-3.0, -2.0]), 2),
        torch.remainder(torch.tensor([1, 2, 3, 4, 5]), 1.5),
        torch.round(a),
        torch.rsqrt(a),
        torch.sigmoid(a),
        torch.sign(torch.tensor([0.7, -1.2, 0.0, 2.3])),
        torch.sgn(a),
        torch.signbit(torch.tensor([0.7, -1.2, 0.0, 2.3])),
        torch.sin(a),
        torch.sinc(a),
        torch.sinh(a),
        torch.sqrt(a),
        torch.square(a),
        torch.sub(torch.tensor((1, 2)), torch.tensor((0, 1)), alpha=2),
        torch.tan(a),
        torch.tanh(a),
        torch.trunc(a),
        # NOTE(review): torch.xlogy(f, g) appears twice in a row —
        # presumably one of them was meant to use different arguments.
        torch.xlogy(f, g),
        torch.xlogy(f, g),
        torch.xlogy(f, 4),
        torch.xlogy(2, g),
    )
def train(config, current_epoch: int, loader, train: bool = True):
    r"""Run one epoch over ``loader``.

    Parameters
    ----------
    config
        Run configuration forwarded to the model's forward pass.
    current_epoch : int
        Epoch index; used for periodic console reporting
        (every ``__C.REPORT_ITERATIONS`` epochs).
    loader
        Iterable of batches; consumed once.
    train : bool, optional
        ``True`` runs backward pass + optimizer/scheduler steps,
        ``False`` evaluates only (gradients disabled for the forward pass).

    Returns
    -------
    tuple
        Epoch-mean forward metrics (loss, noisy/clean accuracy, confidences,
        entropies, LID), epoch-mean adversarial metrics, and the last
        learning rate (``None`` when ``train`` is ``False``).
    """
    if train:
        __C.MODEL.train()
    else:
        __C.MODEL.eval()
    __C.OPTIMIZER.zero_grad()

    # forward-pass accumulators; Counter is the project's running-mean
    # accumulator (provides .add and .mean — confirmed by usage below)
    (
        total_loss,
        total_acc,
        total_oacc,
        total_real_confidence,
        total_noise_confidence,
        total_real_entropy,
        total_noise_entropy,
        total_lid,
    ) = [Counter() for _ in range(8)]
    # adversarial-pass accumulators
    (
        adv_total_loss,
        adv_total_acc,
        adv_total_oacc,
        adv_total_real_confidence,
        adv_total_noise_confidence,
        adv_total_real_entropy,
        adv_total_noise_entropy,
    ) = [Counter() for _ in range(7)]

    # BUGFIX: keep the final return valid even when loader yields no batches
    current_lr = None

    with torch.set_grad_enabled(train):
        # for i, batch in tqdm(enumerate(loader), total=len(loader)):
        for i, batch in enumerate(loader):
            # torch.cuda.empty_cache()
            batch = process_batch(batch, __C.DEVICE, __C.PARALLEL,
                                  __C.DATASET_TYPE)
            __C.OPTIMIZER.zero_grad()
            data_size = batch.x[batch.mask].shape[0]

            # forward pass
            # with torch.autograd.set_detect_anomaly(True):
            (
                loss,
                x,
                real_conf,
                noise_conf,
                correct,
                original_correct,
                real_ent,
                noise_ent,
                # lid
            ) = __C.MODEL(batch, config)
            lid = __C.MODEL.lid
            loss = loss.mean()
            real_conf = real_conf.mean()
            noise_conf = noise_conf.mean()
            real_ent = real_ent.mean()
            noise_ent = noise_ent.mean()
            acc = correct.sum() / data_size
            oacc = original_correct.sum() / data_size

            if train:
                # backward pass with value clipping to stabilize updates
                loss.backward()
                nn.utils.clip_grad_value_(__C.MODEL.parameters(), 1e2)
                __C.OPTIMIZER.step()
                __C.SCHEDULER.step()
                current_lr = __C.OPTIMIZER.param_groups[0]["lr"]
                __C.OPTIMIZER.zero_grad()

            if current_epoch % __C.REPORT_ITERATIONS == 0:
                if train:
                    ic(current_lr)
                print(
                    colorama.Fore.MAGENTA +
                    "[%d%s/%d]LOSS: %.2e, NoisyACC: %.2f%%, CleanACC: %.2f%%, LID: %.2f, NoisyCONF: %.2f, CleanCONF: %.2f, NoisyENT: %.2f, CleanENT:%.2f"
                    % (
                        # BUGFIX: was `epoch`, which is undefined here
                        current_epoch,
                        "t" if train else "e",
                        i,
                        loss.detach().item(),
                        100.0 * acc.detach().item(),  # vs. noisy labels
                        100.0 * oacc.detach().item(),  # vs. real labels
                        lid.detach().item(),
                        noise_conf.detach().item(),
                        real_conf.detach().item(),
                        noise_ent.detach().item(),
                        real_ent.detach().item(),
                    ))

            # adversarial pass, grad required even during evaluation
            with torch.set_grad_enabled(True):
                if __C.ADVERSARIAL_NOISE_RATE > 1e-5:
                    if __C.ADVERSARIAL_METHOD == "FGSM":
                        adv_batch = copy_batch(batch)
                        adv_batch.x.requires_grad = True
                        # BUGFIX: use a distinct name. The original rebound
                        # `loss`, so the *adversarial* loss leaked into the
                        # forward-pass `total_loss` accumulator below.
                        fgsm_loss, *_ = __C.MODEL(adv_batch, config)
                        fgsm_loss = fgsm_loss.mean()
                        fgsm_loss.backward()
                        adv = adv_batch.x
                        norm = adv.abs().max(dim=-1)[0].mean()  # mean max
                        # ic(norm, adv.abs().max())
                        # FGSM step: perturb along the sign of the input grad
                        vadv = (torch.sgn(adv.grad) * norm *
                                __C.ADVERSARIAL_NOISE_RATE)
                        adv = vadv + adv
                        # ic(adv_batch.y, adv_batch.y.shape)
                        adv_batch.x = adv
                        with torch.no_grad():
                            (
                                adv_loss,
                                _,
                                adv_real_conf,
                                adv_noise_conf,
                                adv_correct,
                                adv_original_correct,
                                adv_real_ent,
                                adv_noise_ent,
                            ) = __C.MODEL(adv_batch, config)
                            adv_loss = adv_loss.mean()
                            adv_real_conf = adv_real_conf.mean()
                            adv_noise_conf = adv_noise_conf.mean()
                            adv_real_ent = adv_real_ent.mean()
                            adv_noise_ent = adv_noise_ent.mean()
                            adv_acc = adv_correct.sum() / data_size
                            adv_oacc = adv_original_correct.sum() / data_size
                    else:
                        raise NotImplementedError
                    # clean grads after adversarial pass
                    __C.OPTIMIZER.zero_grad()
                    if current_epoch % __C.REPORT_ITERATIONS == 0:
                        print(
                            colorama.Fore.MAGENTA +
                            "[%d%s/%d]LOSS: %.2e, NoisyACC: %.2f%%, CleanACC: %.2f%%, NoisyCONF: %.2f, CleanCONF: %.2f, NoisyENT: %.2f, CleanENT:%.2f"
                            % (
                                # BUGFIX: was `epoch`, undefined here
                                current_epoch,
                                "t-adv" if train else "e-adv",
                                i,
                                adv_loss.detach().item(),
                                100.0 * adv_acc.detach().item(),
                                100.0 * adv_oacc.detach().item(),
                                adv_noise_conf.detach().item(),
                                adv_real_conf.detach().item(),
                                adv_noise_ent.detach().item(),
                                adv_real_ent.detach().item(),
                            ))
                else:
                    # adversarial pass disabled: fill stats with 0.
                    (
                        adv_loss,
                        adv_real_conf,
                        adv_noise_conf,
                        adv_acc,
                        adv_oacc,
                        adv_real_ent,
                        adv_noise_ent,
                    ) = torch.zeros([7])

                # record these adversarial metrics
                with torch.no_grad():
                    adv_total_loss.add(adv_loss)
                    adv_total_acc.add(adv_acc)
                    adv_total_oacc.add(adv_oacc)
                    adv_total_real_confidence.add(adv_real_conf)
                    adv_total_noise_confidence.add(adv_noise_conf)
                    adv_total_real_entropy.add(adv_real_ent)
                    adv_total_noise_entropy.add(adv_noise_ent)

            # record forward-pass metrics
            with torch.no_grad():
                total_loss.add(loss)
                total_acc.add(acc)
                total_oacc.add(oacc)
                total_real_confidence.add(real_conf)
                total_noise_confidence.add(noise_conf)
                total_real_entropy.add(real_ent)
                total_noise_entropy.add(noise_ent)
                total_lid.add(lid)

    return (
        total_loss.mean,
        total_acc.mean,
        total_oacc.mean,
        total_real_confidence.mean,
        total_noise_confidence.mean,
        total_real_entropy.mean,
        total_noise_entropy.mean,
        total_lid.mean,
        adv_total_loss.mean,
        adv_total_acc.mean,
        adv_total_oacc.mean,
        adv_total_real_confidence.mean,
        adv_total_noise_confidence.mean,
        adv_total_real_entropy.mean,
        adv_total_noise_entropy.mean,
        None if not train else current_lr,
    )