def set_input(self, input):
    """Unpack a batch mapping onto the model, honoring the translation direction.

    Args:
        input: mapping with keys 'A', 'B', 'hint_B' and 'mask_B'; every value
            must support ``.to(device)`` (presumably torch tensors -- confirm).

    Sets ``real_A`` / ``real_B`` (taken from 'A'/'B', swapped when
    ``opt.which_direction`` is not 'AtoB'), ``hint_B``, ``mask_B``,
    ``mask_B_nc`` and ``real_B_enc``.
    """
    # Decide which stored image is the model input and which is the target.
    src_key, tgt_key = (
        ('A', 'B') if self.opt.which_direction == 'AtoB' else ('B', 'A'))
    device = self.device

    self.real_A = input[src_key].to(device)
    self.real_B = input[tgt_key].to(device)
    self.hint_B = input['hint_B'].to(device)

    mask = input['mask_B'].to(device)
    self.mask_B = mask
    # Mask shifted by the constant opt.mask_cent (what the offset means is
    # defined by the options -- confirm).
    self.mask_B_nc = mask + self.opt.mask_cent

    # Encode the target sampled at every 4th position along the last two axes
    # (presumably H and W); util.encode_ab_ind presumably maps ab values to
    # bin indices -- confirm in util.
    self.real_B_enc = util.encode_ab_ind(self.real_B[:, :, ::4, ::4], self.opt)
def set_fusion_input(self, input, box_info):
    """Store a full-image batch plus its box info for later fusion.

    Args:
        input: mapping with keys 'A', 'B', 'hint_B' and 'mask_B'; every value
            must support ``.to(device)`` (presumably torch tensors -- confirm).
        box_info: stored unchanged as ``self.box_info_list`` (its structure
            is defined by the caller -- confirm).

    Populates the ``full_*`` attributes; ``full_real_A`` / ``full_real_B``
    are swapped when ``opt.which_direction`` is not 'AtoB'.
    """
    if self.opt.which_direction == 'AtoB':
        src_key, tgt_key = 'A', 'B'
    else:
        src_key, tgt_key = 'B', 'A'
    device = self.device

    self.full_real_A = input[src_key].to(device)
    self.full_real_B = input[tgt_key].to(device)
    self.full_hint_B = input['hint_B'].to(device)

    full_mask = input['mask_B'].to(device)
    self.full_mask_B = full_mask
    # Mask shifted by the constant opt.mask_cent.
    self.full_mask_B_nc = full_mask + self.opt.mask_cent

    # Target sampled at every 4th position along the last two axes
    # (presumably H and W) before encoding via util.encode_ab_ind.
    subsampled_target = self.full_real_B[:, :, ::4, ::4]
    self.full_real_B_enc = util.encode_ab_ind(subsampled_target, self.opt)

    self.box_info_list = box_info
def set_input(self, input):
    """Unpack a batch mapping onto the model, optionally casting to half precision.

    Args:
        input: mapping with keys 'A', 'B', 'hint_B' and 'mask_B' whose values
            support ``.half()`` and ``.to(device)`` (presumably torch tensors --
            confirm).

    When ``self.half`` is truthy, each entry read here is cast with
    ``.half()`` before being moved to ``self.device``. The cast result is
    bound locally: the caller's mapping is NOT modified, and entries this
    method never reads are not touched. Non-tensor extras, such as path
    lists, therefore no longer raise.

    Sets ``real_A`` / ``real_B`` (swapped when ``opt.which_direction`` is
    not 'AtoB'), ``hint_B``, ``mask_B``, ``mask_B_nc`` and ``real_B_enc``.
    """
    AtoB = self.opt.which_direction == 'AtoB'
    use_half = bool(self.half)

    def _load(key):
        # Read one entry, cast it if requested, then move it to the device.
        # The dict entry itself is never reassigned.
        value = input[key]
        if use_half:
            value = value.half()
        return value.to(self.device)

    self.real_A = _load('A' if AtoB else 'B')
    self.real_B = _load('B' if AtoB else 'A')
    self.hint_B = _load('hint_B')
    self.mask_B = _load('mask_B')
    # Mask shifted by the constant opt.mask_cent.
    self.mask_B_nc = self.mask_B + self.opt.mask_cent
    # Target sampled at every 4th position along the last two axes
    # (presumably H and W) before encoding via util.encode_ab_ind.
    self.real_B_enc = util.encode_ab_ind(self.real_B[:, :, ::4, ::4], self.opt)
def set_forward_without_box(self, input):
    """Load a full-image batch and run ``netGComp`` on it, with no box info.

    Args:
        input: mapping with keys 'A', 'B', 'hint_B' and 'mask_B'; every value
            must support ``.to(device)`` (presumably torch tensors -- confirm).

    Populates the ``full_*`` attributes, then calls
    ``netGComp(full_real_A, full_hint_B, full_mask_B)``. That call must
    return a two-element result. Its second element is stored as both
    ``comp_B_reg`` and ``fake_B_reg``. The first element is discarded; it is
    presumably a classification output -- confirm.
    """
    a_key, b_key = (
        ('A', 'B') if self.opt.which_direction == 'AtoB' else ('B', 'A'))
    dev = self.device

    self.full_real_A = input[a_key].to(dev)
    self.full_real_B = input[b_key].to(dev)
    self.full_hint_B = input['hint_B'].to(dev)

    full_mask = input['mask_B'].to(dev)
    self.full_mask_B = full_mask
    # Mask shifted by the constant opt.mask_cent.
    self.full_mask_B_nc = full_mask + self.opt.mask_cent

    # Target sampled at every 4th position along the last two axes
    # (presumably H and W) before encoding via util.encode_ab_ind.
    self.full_real_B_enc = util.encode_ab_ind(
        self.full_real_B[:, :, ::4, ::4], self.opt)

    _, reg_out = self.netGComp(
        self.full_real_A, self.full_hint_B, self.full_mask_B)
    self.comp_B_reg = reg_out
    self.fake_B_reg = reg_out
def set_input(self, input):
    """Unpack a batch mapping onto the model, honoring the translation direction.

    Args:
        input: mapping with keys 'A', 'B', 'hint_B' and 'mask_B'; every value
            must support ``.to(device)`` (presumably torch tensors -- confirm).

    Sets ``real_A`` / ``real_B`` (swapped when ``opt.which_direction`` is
    not 'AtoB'), ``hint_B``, ``mask_B``, ``mask_B_nc`` and ``real_B_enc``.
    """
    opt = self.opt
    if opt.which_direction == 'AtoB':
        real_A_key, real_B_key = 'A', 'B'
    else:
        real_A_key, real_B_key = 'B', 'A'

    self.real_A = input[real_A_key].to(self.device)
    self.real_B = input[real_B_key].to(self.device)
    self.hint_B = input['hint_B'].to(self.device)
    self.mask_B = input['mask_B'].to(self.device)
    # Mask shifted by the constant opt.mask_cent.
    self.mask_B_nc = self.mask_B + opt.mask_cent

    # Every 4th position along the last two axes (presumably H and W),
    # then encoded by util.encode_ab_ind.
    target_low_res = self.real_B[:, :, ::4, ::4]
    self.real_B_enc = util.encode_ab_ind(target_low_res, opt)
])) lens = len(train_datasets) print('train datasets is [{}] '.format(lens)) dataloader = torch.utils.data.DataLoader(train_datasets, batch_size=batch_size, shuffle=True) for epoch in range(epoches): for index, data in enumerate(dataloader): data = util.get_colorization_data(data, opt, p=opt.sample_p) if (data is None): continue input = torch.cat((data['A'], data['hint_B'], data['mask_B']), dim=1) input = input.to(device) outputclass, outputreg = net(input) realclass = util.encode_ab_ind(data['B'][:, :, ::4, ::4], opt).to(device) lossreg = L1oss(outputreg, data['B'].to(device)) print(outputclass.dtype, realclass.dtype) lossclass = CEloss( outputclass.type(torch.cuda.FloatTensor), realclass[:, 0, :, :].type(torch.cuda.LongTensor)) if record: if index % loss_fre == 0: #100 writer.add_scalars('train/loss:', { 'reg': lossreg.item() * 10, 'class': lossclass.item() }, epoch * lens + index * batch_size) if index % img_fre == 0: # 2000 image_fake = util.lab2rgb( torch.cat([