class Life:
    """Conway's Game of Life on a batch of wrap-around boards, stepped by convolution.

    ``self.grid`` is a float tensor of shape ``(n, 1, size, size)`` whose cells
    are 1.0 (alive) or 0.0 (dead); ``self.t`` counts generations advanced by
    :meth:`step`.

    NOTE(review): uses module-level names defined elsewhere in this file —
    ``pt`` (presumably ``import torch as pt``), ``device``, ``ImageViewer``,
    ``Bernoulli``, ``np``, ``make_grid`` (presumably torchvision's) and ``F``
    (presumably ``torch.nn.functional``). Confirm against the file's imports.
    """

    def __init__(self, size, n=1):
        """Create ``n`` empty ``size`` x ``size`` boards and a viewer for rendering."""
        self.grid = pt.zeros(n, 1, size, size).to(device)
        # 3x3 kernel that sums the eight neighbours and skips the centre cell.
        kernel = [[1, 1, 1],
                  [1, 0, 1],
                  [1, 1, 1]]
        self.mask = pt.tensor([[kernel]], dtype=pt.float).to(device)
        self.viewer = ImageViewer()
        self.t = 0

    def reset(self):
        """Kill every cell (in place) and rewind the generation counter."""
        self.grid.zero_()
        self.t = 0

    def rand_init(self, seed=0, p=0.2):
        """Replace the boards with random ones; each cell is alive with probability ``p``.

        Seeds the global torch RNG with ``seed`` first, so the result is
        reproducible (and the global RNG state changes as a side effect).
        ``self.t`` is left unchanged.
        """
        pt.manual_seed(seed)
        shape = self.grid.size()
        self.grid = Bernoulli(p).sample(shape).to(device)

    def render(self, invert=False):
        """Tile all boards into one image and show it, captioned with the generation.

        With the default ``invert=False`` live cells are drawn as 255 and dead
        cells as 0; ``invert=True`` swaps the two.
        """
        batch, _, _, _ = self.grid.shape
        per_row = int(np.ceil(np.sqrt(batch)))  # roughly square tiling
        tiled = make_grid(1 - self.grid, nrow=per_row)
        if not invert:
            tiled = 1 - tiled
        frame = (255 * tiled.to(pt.uint8).cpu()).numpy()
        # (C, H, W) -> (H, W, C) before handing the frame to the viewer.
        self.viewer.imshow(frame.transpose((1, 2, 0)), caption=f't={self.t}')

    def step(self):
        """Advance every board by one generation, updating ``self.grid`` in place."""
        # Circular padding makes the board wrap around (torus) before counting.
        wrapped = F.pad(self.grid, (1, 1, 1, 1), mode='circular')
        counts = F.conv2d(wrapped, self.mask)

        # Under- or over-population kills a live cell; exactly three neighbours
        # bring a dead cell to life; every other cell keeps its state. Both
        # masks are computed before any cell is written.
        alive = self.grid != 0
        dies = alive & ((counts < 2) | (counts > 3))
        born = (counts == 3) & (self.grid != 1)

        self.grid[dies] = 0
        self.grid[born] = 1
        self.t += 1

    def close(self):
        """Close the underlying image viewer."""
        self.viewer.close()
def forward(self, x):
    """Blend block-masked student and teacher feature maps during training.

    Identity branch: when ``self.training`` is false or ``self.drop_prob == 0``,
    ``x`` must be a 4-D tensor ``(bsize, channels, height, width)`` and is
    returned unchanged.

    Training branch: ``x`` is indexed as a pair. ``x[0]`` is the student feature
    map and ``x[1]`` the teacher feature map, and both must have the same shape.
    The branch then:

    1. Draws a random seed mask of shape ``(bsize, height, width)`` (threshold
       from ``self._compute_gamma``).
    2. Passes it to ``self._compute_block_mask``, which returns separate student
       and teacher masks. Each mask is broadcast over channels.
    3. Compares the two masked maps per sample with ``self.cos`` and maps the
       similarity through ``self._compress``.
    4. Uses the result as a Bernoulli probability. Samples whose draw is 1 get
       no teacher contribution; the others get ``student + teacher``.

    NOTE(review): the identity branch calls ``x.dim()``, so in training mode with
    ``drop_prob == 0`` a (student, teacher) tuple would raise there. Confirm how
    callers use that path.

    Args:
        x: a 4-D tensor (identity branch) or an indexable (student, teacher)
            pair of equally shaped tensors (training branch).

    Returns:
        ``x`` itself in the identity branch, otherwise a tensor shaped like
        ``x[0]``.
    """
    # shape: (bsize, channels, height, width)
    if not self.training or self.drop_prob == 0.:
        assert x.dim() == 4, \
            "Expected input with 4 dimensions (bsize, channels, height, width)"
        return x

    student, teacher = x[0], x[1]
    assert student.shape == teacher.shape, "Not equal"

    # Seed mask: one draw per (sample, h, w). It is drawn on the CPU and then
    # moved, as before, so the random stream is unchanged.
    gamma = self._compute_gamma(student)
    mask = (torch.rand(student.shape[0], *student.shape[2:]) < gamma).float()
    mask = mask.to(student.device)

    block_student_mask, block_teacher_mask = self._compute_block_mask(mask)

    # [:, None] inserts a channel axis so the masks broadcast over channels.
    student_out = student * block_student_mask[:, None, :, :]
    teacher_out = teacher * block_teacher_mask[:, None, :, :]

    # flatten(1) instead of view(B, -1): view raises on non-contiguous maps
    # (e.g. channels_last); flatten only copies when it has to.
    cos_sim = self.cos(student_out.flatten(1), teacher_out.flatten(1))
    cos_sim = self._compress(cos_sim)

    # One draw per sample; a draw of 1 removes the teacher contribution.
    drop_teacher = Bernoulli(cos_sim).sample()
    keep_teacher = 1 - drop_teacher.view(drop_teacher.size(0), 1, 1, 1)
    return student_out + teacher_out * keep_teacher
import torch
from torch.distributions import Bernoulli

# Quick check of Bernoulli sampling: draw three 7x7 masks in which each entry
# is 1 with probability 0.9 and 0 otherwise, then show them and their shape.
mask_sizes = [7, 7]
bbb = Bernoulli(probs=torch.tensor(0.9)).sample(torch.Size([3] + mask_sizes))
print(bbb)
print(bbb.size())