class Demo:
    """Inference wrapper around a pretrained IPCGANs generator.

    NOTE(review): requires a CUDA device -- both the model and the inputs are
    moved to the GPU unconditionally in ``demo``.
    """

    def __init__(self, generator_state_pth):
        """Build the model and restore generator weights.

        :param generator_state_pth: path to a saved generator state dict.
        """
        self.model = IPCGANs()
        # map_location='cpu' so loading a GPU-saved checkpoint cannot fail on
        # a machine with a different CUDA device layout; the model is moved
        # to the GPU later, inside demo().
        state_dict = torch.load(generator_state_pth, map_location='cpu')
        self.model.load_generator_state_dict(state_dict)

    def demo(self, image, target=0):
        """Generate an age-transferred version of ``image``.

        :param image: input PIL image.
        :param target: age-group index in [0, 5).
        :returns: generated PIL image (img_size x img_size).
        :raises ValueError: if ``target`` is out of range.
        """
        img_size = 400
        # Explicit raise instead of assert: asserts are stripped under -O and
        # must not be used for input validation.
        if not 0 <= target < 5:
            raise ValueError("label should be less than 5")
        transforms = torchvision.transforms.Compose([
            torchvision.transforms.Resize((img_size, img_size)),
            torchvision.transforms.ToTensor(),
            Img_to_zero_center()
        ])
        label_transforms = torchvision.transforms.Compose([
            torchvision.transforms.ToTensor(),
        ])
        image = transforms(image).unsqueeze(0)
        # One-hot condition volume: (H, W, 5) with the target plane set to 1.
        full_one = np.ones((img_size, img_size), dtype=np.float32)
        full_zero = np.zeros((img_size, img_size, 5), dtype=np.float32)
        full_zero[:, :, target] = full_one
        label = label_transforms(full_zero).unsqueeze(0)
        img = image.cuda()
        lbl = label.cuda()
        self.model.cuda()
        res = self.model.test_generate(img, lbl)
        res = Reverse_zero_center()(res)
        res_img = res.squeeze(0).cpu().numpy().transpose(1, 2, 0)
        # Clip before the uint8 cast: values outside [0, 1] would otherwise
        # wrap around in the cast and appear as speckle artifacts.
        res_img = np.clip(res_img, 0.0, 1.0)
        return Image.fromarray((res_img * 255).astype(np.uint8))
def __init__(self, generator_state_pth):
    """Instantiate IPCGANs and restore generator weights from a checkpoint.

    :param generator_state_pth: path to the saved generator state dict.
    """
    self.model = IPCGANs()
    weights = torch.load(generator_state_pth)
    self.model.load_generator_state_dict(weights)
class Demo:
    """Inference wrapper: MTCNN face alignment followed by IPCGANs age transfer."""

    def __init__(self, generator_state_pth):
        """Create the IPCGANs model and load generator weights from disk."""
        self.model = IPCGANs()
        state_dict = torch.load(generator_state_pth)
        self.model.load_generator_state_dict(state_dict)

    def mtcnn_align(self, image):
        """Detect a face with MTCNN and warp it to a canonical 400x400 crop.

        :param image: input PIL image (RGB assumed -- TODO confirm).
        :returns: aligned 400x400 PIL image, or the input unchanged when no
            face is detected.
        """
        dst = []
        # Canonical 5-point template (eyes, nose, mouth corners) for a
        # 120x100 aligned face crop.
        src = np.array(
            [[30.2946, 51.6963], [65.5318, 51.5014], [48.0252, 71.7366],
             [33.5493, 92.3655], [62.7299, 92.2041]],
            dtype=np.float32)
        threshold = [0.6, 0.7, 0.9]  # per-stage MTCNN confidence thresholds
        factor = 0.85                # image-pyramid scale factor
        minSize = 20                 # smallest detectable face, in pixels
        imgSize = [120, 100]         # (height, width) of the warped crop
        detector = MTCNN(steps_threshold=threshold,
                         scale_factor=factor,
                         min_face_size=minSize)
        keypoint_list = [
            'left_eye', 'right_eye', 'nose', 'mouth_left', 'mouth_right'
        ]
        npimage = np.array(image)
        dictface_list = detector.detect_faces(
            npimage
        )  # if more than one face is detected, [0] means choose the first face
        if len(dictface_list) > 1:
            # Multiple faces: keep the one whose box center is closest to the
            # image center.
            boxs = []
            for dictface in dictface_list:
                boxs.append(dictface['box'])
            center = np.array(npimage.shape[:2]) / 2
            boxs = np.array(boxs)
            # NOTE(review): the *_y/*_x names look swapped relative to MTCNN's
            # [x, y, w, h] box layout, but the column_stack pairing below is
            # consistent with the (row, col) ordering of ``center``, so the
            # distances are still computed correctly -- verify against MTCNN docs.
            face_center_y = boxs[:, 0] + boxs[:, 2] / 2
            face_center_x = boxs[:, 1] + boxs[:, 3] / 2
            face_center = np.column_stack(
                (np.array(face_center_x), np.array(face_center_y)))
            distance = np.sqrt(np.sum(np.square(face_center - center), axis=1))
            min_id = np.argmin(distance)
            dictface = dictface_list[min_id]
        else:
            if len(dictface_list) == 0:
                # No face found: fall back to returning the raw image.
                return image
            else:
                dictface = dictface_list[0]
        face_keypoint = dictface['keypoints']
        for keypoint in keypoint_list:
            dst.append(face_keypoint[keypoint])
        dst = np.array(dst).astype(np.float32)
        # Estimate the similarity transform mapping the detected keypoints
        # onto the canonical template, then warp and upscale to 400x400.
        tform = trans.SimilarityTransform()
        tform.estimate(dst, src)
        M = tform.params[0:2, :]
        warped = cv2.warpAffine(npimage,
                                M, (imgSize[1], imgSize[0]),
                                borderValue=0.0)
        warped = cv2.resize(warped, (400, 400))
        return Image.fromarray(warped.astype(np.uint8))

    def demo(self, image, target=0):
        """Align the face, then generate a 128x128 aged image for ``target``.

        :param image: input PIL image.
        :param target: age-group index in [0, 5).
        :returns: generated PIL image.
        """
        image = self.mtcnn_align(image)
        assert target < 5 and target >= 0, "label shoule be less than 5"
        transforms = torchvision.transforms.Compose([
            torchvision.transforms.Resize((128, 128)),
            torchvision.transforms.ToTensor(),
            Img_to_zero_center()
        ])
        label_transforms = torchvision.transforms.Compose([
            torchvision.transforms.ToTensor(),
        ])
        image = transforms(image).unsqueeze(0)
        # One-hot (128, 128, 5) condition volume for the target age group.
        full_one = np.ones((128, 128), dtype=np.float32)
        full_zero = np.zeros((128, 128, 5), dtype=np.float32)
        full_zero[:, :, target] = full_one
        label = label_transforms(full_zero).unsqueeze(0)
        img = image.cuda()
        lbl = label.cuda()
        self.model.cuda()
        res = self.model.test_generate(img, lbl)
        res = Reverse_zero_center()(res)
        res_img = res.squeeze(0).cpu().numpy().transpose(1, 2, 0)
        return Image.fromarray((res_img * 255).astype(np.uint8))
def main():
    """Train IPCGANs on CACD: alternating D/G updates with periodic
    checkpointing and validation-image dumps.

    Relies on module-level ``args``, ``logger`` and ``writer`` -- presumably
    set up before this is called (TODO confirm against the script preamble).
    """
    logger.info("Start to train:\n arguments: %s" % str(args))
    #step3: define transform
    transforms = torchvision.transforms.Compose(
        [torchvision.transforms.ToTensor(),
         Img_to_zero_center()])
    label_transforms = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
    ])
    #step4: define train/test dataloader
    train_dataset = CACD("train", transforms, label_transforms)
    test_dataset = CACD("test", transforms, label_transforms)
    train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=True)
    test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                              batch_size=args.batch_size,
                                              shuffle=False)
    #step5: define model,optim
    model = IPCGANs(lr=args.learning_rate,
                    age_classifier_path=args.age_classifier_path,
                    gan_loss_weight=args.gan_loss_weight,
                    feature_loss_weight=args.feature_loss_weight,
                    age_loss_weight=args.age_loss_weight)
    #,feature_extractor_path=args.feature_extractor_path)
    model.load_generator_state_dict_custom(model_path="model_last.pth")
    d_optim = model.d_optim
    g_optim = model.g_optim
    for epoch in range(args.max_epoches):
        # BUG FIX: the tqdm/enumerate iterator was previously built once,
        # before the epoch loop; it was exhausted after epoch 0, so every
        # later epoch silently trained on nothing. Rebuild it each epoch.
        samples_tqdm = tqdm(enumerate(train_loader, 1), position=0, leave=True)
        for idx, (source_img_227, source_img_128, true_label_img,
                  true_label_128, true_label_64, fake_label_64,
                  true_label) in samples_tqdm:
            running_d_loss = None
            running_g_loss = None
            n_iter = epoch * len(train_loader) + idx
            #mv to gpu
            source_img_227 = source_img_227.cuda()
            source_img_128 = source_img_128.cuda()
            true_label_img = true_label_img.cuda()
            true_label_128 = true_label_128.cuda()
            true_label_64 = true_label_64.cuda()
            fake_label_64 = fake_label_64.cuda()
            true_label = true_label.cuda()
            #train discriminator
            for d_iter in range(args.d_iter):
                #d_lr_scheduler.step()
                d_optim.zero_grad()
                model.train(source_img_227=source_img_227,
                            source_img_128=source_img_128,
                            true_label_img=true_label_img,
                            true_label_128=true_label_128,
                            true_label_64=true_label_64,
                            fake_label_64=fake_label_64,
                            age_label=true_label)
                d_loss = model.d_loss
                running_d_loss = d_loss
                d_loss.backward()
                d_optim.step()
            #visualize params
            for name, param in model.discriminator.named_parameters():
                writer.add_histogram("discriminator:%s" % name,
                                     param.clone().cpu().detach().numpy(),
                                     n_iter)
            #train generator
            for g_iter in range(args.g_iter):
                #g_lr_scheduler.step()
                g_optim.zero_grad()
                model.train(source_img_227=source_img_227,
                            source_img_128=source_img_128,
                            true_label_img=true_label_img,
                            true_label_128=true_label_128,
                            true_label_64=true_label_64,
                            fake_label_64=fake_label_64,
                            age_label=true_label)
                g_loss = model.g_loss
                running_g_loss = g_loss
                g_loss.backward()
                g_optim.step()
            for name, param in model.generator.named_parameters():
                writer.add_histogram("generator:%s" % name,
                                     param.clone().cpu().detach().numpy(),
                                     n_iter)
            format_str = ('step %d/%d, g_loss = %.3f, d_loss = %.3f')
            samples_tqdm.set_description(
                format_str %
                (idx, len(train_loader), running_g_loss, running_d_loss))
            writer.add_scalars('data/loss', {
                'G_loss': running_g_loss,
                'D_loss': running_d_loss
            }, n_iter)
            # save the parameters at the end of each save interval
            if idx % args.save_interval == 0:
                model.save_model(dir=args.saved_model_folder,
                                 filename='epoch_%d_iter_%d.pth' %
                                 (epoch, idx))
                # "model_last.pth" is the rolling checkpoint restored above.
                model.save_model(dir="", filename="model_last.pth")
                logger.info('checkpoint has been created!')
            #val step
            if idx % args.val_interval == 0:
                save_dir = os.path.join(args.saved_validation_folder,
                                        "epoch_%d" % epoch, "idx_%d" % idx)
                check_dir(save_dir)
                for val_idx, (source_img_128,
                              true_label_128) in enumerate(tqdm(test_loader)):
                    save_image(Reverse_zero_center()(source_img_128),
                               fp=os.path.join(
                                   save_dir,
                                   "batch_%d_source.jpg" % (val_idx)))
                    # Dump one generated image per age group for this batch.
                    for age in range(args.age_groups):
                        img = model.test_generate(source_img_128,
                                                  true_label_128[age])
                        save_image(Reverse_zero_center()(img),
                                   fp=os.path.join(
                                       save_dir,
                                       "batch_%d_age_group_%d.jpg" %
                                       (val_idx, age)))
                logger.info('validation image has been created!')
def main(args, logger, writer):
    """Train IPCGANs on CACD, mirroring progress to a webhook endpoint.

    :param args: parsed CLI namespace (lr, loss weights, paths, intervals).
    :param logger: configured logger for progress messages.
    :param writer: tensorboard SummaryWriter for losses/parameter histograms.

    NOTE(review): relies on a module-level ``webhook_url`` -- confirm it is
    defined in this module before this function runs.
    """
    logger.info("Start to train:\n arguments: %s" % str(args))
    content = "[202.*.*.150.] [INFO] Start to train - IPCGANs "
    payload = {"text": content}
    requests.post(webhook_url,
                  data=json.dumps(payload),
                  headers={'Content-Type': 'application/json'})
    #step3: define transform
    transforms = torchvision.transforms.Compose(
        [torchvision.transforms.ToTensor(),
         Img_to_zero_center()])
    label_transforms = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
    ])
    #step4: define train/test dataloader
    train_dataset = CACD("train", transforms, label_transforms)
    test_dataset = CACD("test", transforms, label_transforms)
    train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=True)
    test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                              batch_size=args.batch_size,
                                              shuffle=False)
    #step5: define model,optim
    model = IPCGANs(lr=args.learning_rate,
                    age_classifier_path=args.age_classifier_path,
                    gan_loss_weight=args.gan_loss_weight,
                    feature_loss_weight=args.feature_loss_weight,
                    age_loss_weight=args.age_loss_weight,
                    generator_path=args.g_checkpoint,
                    discriminator_path=args.d_checkpoint)
    #,feature_extractor_path=args.feature_extractor_path)
    d_optim = model.d_optim
    g_optim = model.g_optim
    for epoch in range(args.max_epoches):
        avr_d_loss = 0.0
        avr_g_loss = 0.0
        count = 0
        for idx, (source_img_227, source_img_128, true_label_img,
                  true_label_128, true_label_64, fake_label_64,
                  true_label) in enumerate(train_loader, 1):
            running_d_loss = None
            running_g_loss = None
            n_iter = epoch * len(train_loader) + idx
            #mv to gpu
            source_img_227 = source_img_227.cuda()
            source_img_128 = source_img_128.cuda()
            true_label_img = true_label_img.cuda()
            true_label_128 = true_label_128.cuda()
            true_label_64 = true_label_64.cuda()
            fake_label_64 = fake_label_64.cuda()
            true_label = true_label.cuda()
            #train discriminator
            for d_iter in range(args.d_iter):
                #d_lr_scheduler.step()
                d_optim.zero_grad()
                model.train(source_img_227=source_img_227,
                            source_img_128=source_img_128,
                            true_label_img=true_label_img,
                            true_label_128=true_label_128,
                            true_label_64=true_label_64,
                            fake_label_64=fake_label_64,
                            age_label=true_label)
                d_loss = model.d_loss
                running_d_loss = d_loss
                d_loss.backward()
                d_optim.step()
            #visualize params
            for name, param in model.discriminator.named_parameters():
                writer.add_histogram("discriminator:%s" % name,
                                     param.clone().cpu().detach().numpy(),
                                     n_iter)
            #train generator
            for g_iter in range(args.g_iter):
                #g_lr_scheduler.step()
                g_optim.zero_grad()
                model.train(source_img_227=source_img_227,
                            source_img_128=source_img_128,
                            true_label_img=true_label_img,
                            true_label_128=true_label_128,
                            true_label_64=true_label_64,
                            fake_label_64=fake_label_64,
                            age_label=true_label)
                g_loss = model.g_loss
                running_g_loss = g_loss
                g_loss.backward()
                g_optim.step()
            for name, param in model.generator.named_parameters():
                writer.add_histogram("generator:%s" % name,
                                     param.clone().cpu().detach().numpy(),
                                     n_iter)
            format_str = ('step %d/%d, g_loss = %.3f, d_loss = %.3f')
            logger.info(format_str %
                        (idx, len(train_loader), running_g_loss,
                         running_d_loss))
            writer.add_scalars('data/loss', {
                'G_loss': running_g_loss,
                'D_loss': running_d_loss
            }, n_iter)
            # Accumulate python floats, not CUDA tensors: keeping the loss
            # tensors alive for a whole epoch would pin GPU memory.
            avr_g_loss += running_g_loss.item()
            avr_d_loss += running_d_loss.item()
            count += 1
            # save the parameters at the end of each save interval
            if idx % args.save_interval == 0:
                model.save_model(dir=args.saved_model_folder,
                                 filename='epoch_%d_iter_%d.pth' %
                                 (epoch, idx))
                logger.info('checkpoint has been created!')
            #val step
            if idx % args.val_interval == 0:
                save_dir = os.path.join(args.saved_validation_folder,
                                        "epoch_%d" % epoch, "idx_%d" % idx)
                check_dir(save_dir)
                for val_idx, (source_img_128,
                              true_label_128) in enumerate(tqdm(test_loader)):
                    # NOTE(review): torchvision's save_image takes ``fp`` in
                    # current releases (the ``filename`` keyword was removed);
                    # aligned with the other training entry point in this file.
                    save_image(Reverse_zero_center()(source_img_128),
                               fp=os.path.join(
                                   save_dir,
                                   "batch_%d_source.jpg" % (val_idx)))
                    for age in range(args.age_groups):
                        img = model.test_generate(source_img_128,
                                                  true_label_128[age])
                        save_image(Reverse_zero_center()(img),
                                   fp=os.path.join(
                                       save_dir,
                                       "batch_%d_age_group_%d.jpg" %
                                       (val_idx, age)))
                # (Removed a commented-out upload snippet that embedded a
                # hard-coded Slack API token -- credentials must not live in
                # source, even commented out.)
                logger.info('validation image has been created!')
        # Guard against an empty dataloader: the original divided by zero.
        if count:
            avr_d_loss = avr_d_loss / count
            avr_g_loss = avr_g_loss / count
        content = "[202.*.*.150.] [INFO] Epoch End : " + str(
            epoch) + ", d_loss : " + str(avr_d_loss) + ", g_loss : " + str(
                avr_g_loss)
        payload = {"text": content}
        requests.post(webhook_url,
                      data=json.dumps(payload),
                      headers={'Content-Type': 'application/json'})