def forward(self, z, y):
    # If hierarchical, concatenate zs and ys
    if self.hier:
        zs = torch.split(z, self.z_chunk_size, 1)
        z = zs[0]
        ys = [torch.cat([y, item], 1) for item in zs[1:]]
    else:
        ys = [y] * len(self.blocks)
    # First linear layer
    h = torch.Tensor(self.linear(z))
    # Reshape
    h = h.view(h.size(0), -1, self.bottom_width, self.bottom_width)
    # Loop over blocks
    for index, blocklist in enumerate(self.blocks):
        # Second inner loop in case block has multiple layers
        for block in blocklist:
            h = block(h, torch.Tensor(ys[index]))
    # Apply batchnorm-relu-conv-tanh at output
    return torch.tanh(self.output_layer(h))
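# A minimal sketch of the hierarchical-latent split used above, written
# against a PyTorch-compatible `torch` API; the batch size, z dimension,
# chunk size, and embedding width below are hypothetical.
import torch

z = torch.rand(4, 120)                   # hypothetical latent batch
y = torch.rand(4, 128)                   # hypothetical class embeddings
zs = torch.split(z, 20, 1)               # six (4, 20) chunks along dim 1
ys = [torch.cat([y, item], 1) for item in zs[1:]]
print(zs[0].shape)                       # first chunk feeds the linear layer
print([t.shape for t in ys])             # five (4, 148) conditioning inputs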
def forward(self, z, gy, x=None, dy=None, train_G=False,
            return_G_z=False, split_D=False):
    # If training G, enable grad tape
    if train_G:
        self.G.train()
    else:
        self.G.eval()
    # Get Generator output given noise
    G_z = self.G(z, self.G.shared(gy))
    # Split_D means to run D once with real data and once with fake,
    # rather than concatenating along the batch dimension.
    if split_D:
        D_fake = self.D(G_z, gy)
        if x is not None:
            D_real = self.D(x, dy)
            return D_fake, D_real
        else:
            if return_G_z:
                return D_fake, G_z
            else:
                return D_fake
    # If real data is provided, concatenate it with the Generator's output
    # along the batch dimension for improved efficiency.
    else:
        if x is not None and x.shape[-1] != G_z.shape[-1]:
            x = F.interpolate(x, size=G_z.shape[-2:])
        D_input = torch.cat([G_z, x], 0) if x is not None else G_z
        D_class = torch.cat([gy, dy], 0) if dy is not None else gy
        # Get Discriminator output
        D_out = self.D(D_input, D_class)
        if x is not None:
            # Split back into D_fake, D_real
            return torch.split(D_out, [G_z.shape[0], x.shape[0]])
        else:
            if return_G_z:
                return D_out, G_z
            else:
                return D_out
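# Toy illustration of the concatenate-then-split trick in the non-split_D
# branch above, sketched against the reference PyTorch API; the discriminator
# is replaced by a stand-in reduction and all shapes are hypothetical.
import torch

G_z = torch.rand(8, 3, 32, 32)           # fake batch (hypothetical)
x = torch.rand(8, 3, 32, 32)             # real batch (hypothetical)
D_input = torch.cat([G_z, x], 0)         # one (16, 3, 32, 32) forward pass
D_out = D_input.mean(dim=(1, 2, 3))      # stand-in for self.D(D_input, D_class)
D_fake, D_real = torch.split(D_out, [G_z.shape[0], x.shape[0]])
print(D_fake.shape, D_real.shape)        # (8,) and (8,)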
# 'torch' here is the paddorch module under test; 'th' is reference PyTorch.
p_dict = {}
for p in params:
    p_dict.update(p)
M = p_dict['M']
N = p_dict['N']
parts = p_dict['parts']
inputs = {"input": torch.rand(M, N), "split_size": int(M * N / parts)}
print(p_dict)

import torch as th
th_input = th.from_numpy(inputs['input'].numpy())
torch_out = [x.shape for x in th.split(th_input, inputs['split_size'])]
print("torch:", torch_out)
paddle_out = [
    x.shape for x in torch.split(inputs['input'], inputs['split_size'])
]
print("paddorch:", paddle_out)
assert torch_out == paddle_out

# Repeat the comparison with fixed sizes.
M = 100
N = 20
parts = 40
inputs = {"input": torch.rand(M, N), "split_size": int(M * N / parts)}
th_input = th.from_numpy(inputs['input'].numpy())
torch_out = [x.shape for x in th.split(th_input, inputs['split_size'])]
print("torch:", torch_out)
paddle_out = [
    x.shape for x in torch.split(inputs['input'], inputs['split_size'])
]
print("paddorch:", paddle_out)
assert torch_out == paddle_out
def train(x, y):
    G.optim.zero_grad()
    D.optim.zero_grad()
    # How many chunks to split x and y into?
    x = torch.split(x, config['batch_size'])
    y = torch.split(y, config['batch_size'])
    counter = 0

    # Optionally toggle D and G's "requires_grad"
    if config['toggle_grads']:
        utils.toggle_grad(D, True)
        utils.toggle_grad(G, False)

    for step_index in range(config['num_D_steps']):
        # If accumulating gradients, loop multiple times before an optimizer step
        D.optim.zero_grad()
        D_loss_total = 0
        for accumulation_index in range(config['num_D_accumulations']):
            z_.sample_()
            y_.sample_()
            D_fake, D_real = GD(z_[:config['batch_size']],
                                y_[:config['batch_size']],
                                x[counter], y[counter],
                                train_G=True, split_D=config['split_D'])
            # Compute components of D's loss, average them, and divide by
            # the number of gradient accumulations
            D_loss_real, D_loss_fake = losses.discriminator_loss(D_fake, D_real)
            D_loss = (D_loss_real + D_loss_fake) / float(config['num_D_accumulations'])
            D_loss_total += D_loss
            counter += 1

        # Optionally apply ortho reg in D
        if config['D_ortho'] > 0.0:
            # Debug print to indicate we're using ortho reg in D
            print('using modified ortho reg in D')
            utils.ortho(D, config['D_ortho'])

        D_loss_total.backward()
        D.optim.minimize(D_loss_total)

    # Optionally toggle "requires_grad"
    if config['toggle_grads']:
        utils.toggle_grad(D, False)
        utils.toggle_grad(G, True)

    # Zero G's gradients by default before training G, for safety
    G.optim.zero_grad()

    # If accumulating gradients, loop multiple times
    G_loss_total = 0
    for accumulation_index in range(config['num_G_accumulations']):
        z_.sample_()
        y_.sample_()
        D_fake = GD(z_, y_, train_G=True, split_D=config['split_D'])
        G_loss = losses.generator_loss(D_fake) / float(config['num_G_accumulations'])
        G_loss_total += G_loss

    # Optionally apply modified ortho reg in G
    if config['G_ortho'] > 0.0:
        # Debug print to indicate we're using ortho reg in G
        print('using modified ortho reg in G')
        # Don't ortho reg shared, it makes no sense. Really we should
        # blacklist any embeddings for this.
        utils.ortho(G, config['G_ortho'],
                    blacklist=[param for param in G.shared.parameters()])

    G_loss_total.backward()
    G.optim.minimize(G_loss_total)

    # If we have an EMA, update it, regardless of whether we test with it
    if config['ema']:
        ema.update(state_dict['itr'])

    out = {
        'G_loss': float(torch.Tensor(G_loss).item()),
        'D_loss_real': float(torch.Tensor(D_loss_real).item()),
        'D_loss_fake': float(torch.Tensor(D_loss_fake).item())
    }
    # Return G's loss and the components of D's loss
    return out
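# Stripped-down sketch of the gradient-accumulation pattern used in train(),
# written against the reference PyTorch API with a hypothetical model and
# loss: scaled losses are summed across micro-batches, then a single
# backward pass and optimizer step are applied, mirroring the loop above.
import torch

model = torch.nn.Linear(10, 1)           # hypothetical model
optim = torch.optim.SGD(model.parameters(), lr=0.01)
num_accumulations = 4

optim.zero_grad()
loss_total = 0
for _ in range(num_accumulations):
    batch = torch.rand(8, 10)            # hypothetical micro-batch
    loss = model(batch).pow(2).mean() / float(num_accumulations)
    loss_total = loss_total + loss       # keep the graphs; backward once below
loss_total.backward()
optim.step()                             # one update covering all micro-batches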
def split(x, num_or_sections, dim=0):
    return torch.split(x, num_or_sections, dim)
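# Hedged usage sketch for the thin wrapper above. Note that despite the
# Paddle-style parameter name `num_or_sections`, an integer is treated as a
# chunk size along `dim` (PyTorch semantics, which the comparison test above
# asserts), while a list gives explicit section sizes. The tensor and sizes
# here are made up.
import torch

t = torch.rand(6, 4)
halves = split(t, 3)                     # two (3, 4) chunks along dim 0
uneven = split(t, [1, 2, 3])             # explicit section sizes along dim 0
cols = split(t, 2, dim=1)                # two (6, 2) chunks along dim 1
print([c.shape for c in halves])
print([c.shape for c in uneven])
print([c.shape for c in cols])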