# Reconstruction losses, moved to the training device.
MSE_Loss = nn.MSELoss().to(device)
MSE_Loss_sum = nn.MSELoss(reduction='sum').to(device)
L1_Loss = nn.L1Loss().to(device)
# Fade-in speed: alpha ramps 0 -> 1 over (roughly) the first half of a
# step's iterations, since alpha_weight * (step_iteration / 2) == 1.
alpha_weight = float(2. / args.step_iteration)
# Progressive-growing training loop: `step` selects the target resolution
# (1 -> x2, 2 -> x4, 3 -> x8) and `alpha` blends newly-grown layers in.
# NOTE(review): `iteration` and `step` are assumed to be initialised earlier
# in the file (e.g. restored from a checkpoint) — not visible in this chunk.
for i in range(args.max_iter - iteration):
    if iteration >= args.step_iteration:
        # Current resolution finished: advance to the next step (capped at 3)
        # and restart the fade-in counter.
        if step < 3:
            alpha = 0  # NOTE(review): dead store — alpha is recomputed below
            iteration = 0
            step += 1
    alpha = min(1, alpha_weight * iteration)
    try:
        dat = dataloader.__next__()
        x2_target_image, x4_target_image, x8_target_image, input_image = dat
    except (OSError, StopIteration):
        # The custom loader raises StopIteration at epoch end (and resets
        # itself internally); retry once to pull the first batch of the
        # next epoch. OSError presumably covers transient file-read errors
        # from workers — TODO confirm.
        dat = dataloader.__next__()
        x2_target_image, x4_target_image, x8_target_image, input_image = dat
    iteration += 1
    input_image = input_image.to(device)
    # Pick the ground-truth image matching the current output resolution.
    if step == 1:
        target_image = x2_target_image.to(device)
    elif step == 2:
        target_image = x4_target_image.to(device)
    elif step == 3:
        target_image = x8_target_image.to(device)
    train(generator, discriminator, face_align_net, g_optim, d_optim,
          input_image, target_image, step, iteration, alpha)
class RecDataLoader:
    """Length-bucketing loader for text-recognition training.

    Pulls samples one at a time from an inner torch ``DataLoader``
    (batch_size=1) and routes each into one of two queues by label length
    (``<= len_thresh`` vs. longer, where ``len_thresh`` is half of the
    dataset's maximum label length), so every emitted batch groups samples
    of similar length. Images in a batch are resized to a common height,
    right-padded to a shared width that is a multiple of 8, and stacked.

    ``dataset`` must expose ``process`` (image pre-processing helpers:
    ``resize_with_specific_height``, ``width_pad_img``, ``normalize_img``)
    and ``_find_max_length()``.
    """

    def __init__(self, dataset, batch_size, shuffle, num_workers, **kwargs):
        self.dataset = dataset
        self.process = dataset.process
        self.len_thresh = self.dataset._find_max_length() // 2
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.num_workers = num_workers
        self.iteration = 0    # samples consumed from the inner loader this epoch
        self.dataiter = None  # lazily-built inner DataLoader iterator
        self.queue_1 = list() # short-label samples awaiting a full batch
        self.queue_2 = list() # long-label samples awaiting a full batch

    def __len__(self):
        # Number of batches, rounding up for a partial final batch.
        return (len(self.dataset) + self.batch_size - 1) // self.batch_size

    def __iter__(self):
        return self

    def pack(self, batch_data):
        """Collate a list of single-sample dicts into one padded batch.

        Returns ``{'img': FloatTensor(B, C, H, W), 'label': list}``.
        """
        batch = {'img': [], 'label': []}
        # img tensor current shape: B,H,W,C
        all_same_height_images = [
            self.process.resize_with_specific_height(_['img'][0].numpy())
            for _ in batch_data
        ]
        max_img_w = max({m_img.shape[1] for m_img in all_same_height_images})
        # make sure max_img_w is an integral multiple of 8
        max_img_w = int(np.ceil(max_img_w / 8) * 8)
        for i in range(len(batch_data)):
            _label = batch_data[i]['label'][0]
            img = self.process.normalize_img(
                self.process.width_pad_img(all_same_height_images[i], max_img_w))
            img = img.transpose([2, 0, 1])  # HWC -> CHW
            batch['img'].append(torch.FloatTensor(img))
            batch['label'].append(_label)
        batch['img'] = torch.stack(batch['img'])
        return batch

    def build(self):
        """(Re)create the inner single-sample DataLoader iterator."""
        self.dataiter = iter(DataLoader(self.dataset, batch_size=1,
                                        shuffle=self.shuffle,
                                        num_workers=self.num_workers))

    def __next__(self):
        if self.dataiter is None:
            self.build()
        # Flush a leftover long-label batch held over from the epoch's end.
        if self.iteration == len(self.dataset) and len(self.queue_2):
            batch_data = self.queue_2
            self.queue_2 = list()
            return self.pack(batch_data)
        # Epoch fully consumed and nothing queued: reset for the next epoch.
        if not len(self.queue_2) and not len(
                self.queue_1) and self.iteration == len(self.dataset):
            self.iteration = 0
            self.dataiter = None
            raise StopIteration
        try:
            while True:
                # get data from the origin dataloader
                temp = self.dataiter.__next__()
                self.iteration += 1
                # route into a queue by label length
                if len(temp['label'][0]) <= self.len_thresh:
                    self.queue_1.append(temp)
                else:
                    self.queue_2.append(temp)
                batch_data = None
                # queue_1 full, push to batch_data
                if len(self.queue_1) == self.batch_size:
                    batch_data = self.queue_1
                    self.queue_1 = list()
                # or queue_2 full, push to batch_data
                elif len(self.queue_2) == self.batch_size:
                    batch_data = self.queue_2
                    self.queue_2 = list()
                if batch_data is not None:
                    return self.pack(batch_data)
        except StopIteration:
            # Inner loader exhausted: emit whichever partial queue is
            # non-empty. (The original packed queue_1 unconditionally and
            # crashed on max()-of-empty-set when only queue_2 had leftovers.)
            if self.queue_1:
                batch_data = self.queue_1
                self.queue_1 = list()
            elif self.queue_2:
                batch_data = self.queue_2
                self.queue_2 = list()
            else:
                # Nothing left at all: end the epoch cleanly.
                self.iteration = 0
                self.dataiter = None
                raise StopIteration from None
            return self.pack(batch_data)
class RecDataLoader:
    """Length-bucketing loader for text-recognition training (tuple samples).

    Pulls ``(img, label, length)`` tuples one at a time from an inner torch
    ``DataLoader`` (batch_size=1) and routes them into two queues by label
    length (``<= len_thresh`` vs. longer), so each emitted batch groups
    samples of similar length. Images are padded to a shared width that is a
    multiple of 8; labels are zero-padded to the batch's maximum length.
    """

    def __init__(self, dataset, config):
        self.dataset = dataset
        self.process = RecDataProcess(config)
        self.len_thresh = self.dataset._find_max_length() // 2
        self.batch_size = config.batch_size
        self.shuffle = config.shuffle
        self.num_workers = config.num_workers
        self.iteration = 0    # samples consumed from the inner loader this epoch
        self.dataiter = None  # lazily-built inner DataLoader iterator
        self.queue_1 = list() # short-label samples awaiting a full batch
        self.queue_2 = list() # long-label samples awaiting a full batch

    def __len__(self):
        # Number of batches, rounding up for a partial final batch.
        return (len(self.dataset) + self.batch_size - 1) // self.batch_size

    def __iter__(self):
        return self

    def pack(self, batch_data):
        """Collate single-sample tuples into ``[imgs, labels, lengths]``."""
        batch = [[], [], []]
        max_length = max({it[2].item() for it in batch_data})
        # img tensor current shape: C, H, W
        max_img_w = max({it[0].shape[-1] for it in batch_data})
        # make sure max_img_w is an integral multiple of 8
        max_img_w = max_img_w + (8 - max_img_w % 8) if max_img_w % 8 != 0 else max_img_w
        for i in range(len(batch_data)):
            _img, _label, _length = batch_data[i]
            # trans to np array, roll axis back to HWC for resize/normalize
            _img = _img.numpy().transpose([1, 2, 0])
            img = self.process.resize_normalize(_img, max_img_w)
            label = _label.tolist()[0] + [0] * (max_length - len(_label.tolist()[0]))
            batch[0].append(torch.FloatTensor(img))
            batch[1].append(torch.IntTensor(label))
            # NOTE(review): this records the padded length (max_length) for
            # every sample instead of the per-sample _length — confirm that
            # downstream consumers (e.g. CTC target lengths) expect this.
            batch[2].append(torch.IntTensor([max_length]))
        return [torch.stack(batch[0]), torch.stack(batch[1]), torch.cat(batch[2])]

    def build(self):
        """(Re)create the inner single-sample DataLoader iterator."""
        self.dataiter = iter(DataLoader(self.dataset, batch_size=1,
                                        shuffle=self.shuffle,
                                        num_workers=self.num_workers))

    def __next__(self):
        if self.dataiter is None:
            self.build()
        # Flush a leftover long-label batch held over from the epoch's end.
        if self.iteration == len(self.dataset) and len(self.queue_2):
            batch_data = self.queue_2
            self.queue_2 = list()
            return self.pack(batch_data)
        # Epoch fully consumed and nothing queued: reset for the next epoch.
        if not len(self.queue_2) and not len(self.queue_1) and self.iteration == len(self.dataset):
            self.iteration = 0
            self.dataiter = None
            raise StopIteration
        try:
            while True:
                # get data from the origin dataloader
                temp = self.dataiter.__next__()
                self.iteration += 1
                # route into a queue by label length
                if temp[2].item() <= self.len_thresh:
                    self.queue_1.append(temp)
                else:
                    self.queue_2.append(temp)
                batch_data = None
                # queue_1 full, push to batch_data
                if len(self.queue_1) == self.batch_size:
                    batch_data = self.queue_1
                    self.queue_1 = list()
                # or queue_2 full, push to batch_data
                elif len(self.queue_2) == self.batch_size:
                    batch_data = self.queue_2
                    self.queue_2 = list()
                if batch_data is not None:
                    return self.pack(batch_data)
        except StopIteration:
            # Inner loader exhausted: emit whichever partial queue is
            # non-empty. (The original packed queue_1 unconditionally and
            # crashed on max()-of-empty-set when only queue_2 had leftovers.)
            if self.queue_1:
                batch_data = self.queue_1
                self.queue_1 = list()
            elif self.queue_2:
                batch_data = self.queue_2
                self.queue_2 = list()
            else:
                # Nothing left at all: end the epoch cleanly.
                self.iteration = 0
                self.dataiter = None
                raise StopIteration from None
            return self.pack(batch_data)