import os

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

import wandb

# Project-specific pieces; the module paths below are assumptions -- adjust
# them to the actual layout of this code base.
from models import SkipNet, SfsNetPipeline                        # assumed path
from data_loading import get_sfsnet_dataset, get_celeba_dataset   # assumed path
from utils import wandb_log_images                                # assumed path


def FirstStage_Training(syn_path=None, model_dir=None):
    """First stage: train SkipNet on the synthetic data alone."""
    learning_rate = 0.00125
    weight_decay = 0.0001
    nums_epoch = 20
    use_cuda = torch.cuda.is_available()
    device = torch.device('cuda:0' if use_cuda else 'cpu')

    train_dataset, val_dataset = get_sfsnet_dataset(syn_dir=syn_path + 'train/',
                                                    validation_split=2,
                                                    training_syn=True)
    test_dataset, _ = get_sfsnet_dataset(syn_dir=syn_path + 'test/',
                                         validation_split=0,
                                         training_syn=True)

    normal_loss = nn.L1Loss()
    albedo_loss = nn.L1Loss()
    lighting_loss = nn.MSELoss()
    recon_loss = nn.L1Loss()
    if use_cuda:
        normal_loss = normal_loss.cuda()
        albedo_loss = albedo_loss.cuda()
        lighting_loss = lighting_loss.cuda()
        recon_loss = recon_loss.cuda()

    # Loss weights.
    lambda_normal = 0.5
    lambda_albedo = 0.5
    lambda_sh = 0.1
    lambda_recon = 0.5

    # wandb.init(sync_tensorboard=True)

    for epoch in range(nums_epoch):
        syn_train_dl = DataLoader(train_dataset, batch_size=32, shuffle=True)
        syn_val_dl = DataLoader(val_dataset, batch_size=32, shuffle=True)
        syn_test_dl = DataLoader(test_dataset, batch_size=32, shuffle=True)
        print('Synthetic dataset batches -- train:', len(syn_train_dl),
              ' val:', len(syn_val_dl), ' test:', len(syn_test_dl))

        t_loss = 0
        n_loss = 0
        a_loss = 0
        sh_loss = 0
        r_loss = 0

        # The model is rebuilt every epoch and seeded from the previous
        # epoch's checkpoint.
        sfsnet_model = SkipNet()
        if epoch > 0:
            sfsnet_model.load_state_dict(torch.load(model_dir + 'Skip_First.pkl'))
        sfsnet_model.to(device)
        optimizer = torch.optim.Adam(sfsnet_model.parameters(),
                                     lr=learning_rate,
                                     weight_decay=weight_decay)

        for bix, data in enumerate(syn_train_dl):
            albedo, normal, mask, sh, face = data
            if use_cuda:
                albedo = albedo.cuda()
                normal = normal.cuda()
                mask = mask.cuda()
                sh = sh.cuda()
                face = face.cuda()

            predicted_normal, predicted_albedo, predicted_sh, produced_shading, produced_recon = sfsnet_model(face)

            current_normal_loss = normal_loss(predicted_normal, normal)
            current_albedo_loss = albedo_loss(predicted_albedo, albedo)
            current_sh_loss = lighting_loss(predicted_sh, sh)
            current_recon_loss = recon_loss(produced_recon, face)

            total_loss = lambda_normal * current_normal_loss + \
                         lambda_albedo * current_albedo_loss + \
                         lambda_sh * current_sh_loss + \
                         lambda_recon * current_recon_loss

            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()

            t_loss += total_loss.item()
            n_loss += current_normal_loss.item()
            a_loss += current_albedo_loss.item()
            sh_loss += current_sh_loss.item()
            r_loss += current_recon_loss.item()

            print('Epoch: {} - Total Loss: {}, Normal Loss: {}, Albedo Loss: {}, '
                  'SH Loss: {}, Recon Loss: {}'.format(epoch, total_loss,
                                                       current_normal_loss,
                                                       current_albedo_loss,
                                                       current_sh_loss,
                                                       current_recon_loss))
            print('This is {} / {} of the training batches'.format(bix, len(syn_train_dl) - 1))

        torch.save(sfsnet_model.state_dict(), model_dir + 'Skip_First.pkl')
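
# The training functions below rely on two small helpers from the original
# code base, `apply_mask` and `de_norm`, which are not shown in this file.
# The definitions here are only a hedged sketch of their assumed behaviour
# (elementwise masking, and undoing a [-1, 1] image normalization); if the
# project's own utils module provides them, import those instead.

def apply_mask(image, mask):
    # Assumed behaviour: zero out background pixels by elementwise multiplication.
    return image * mask


def de_norm(image):
    # Assumed behaviour: map images normalized to [-1, 1] back to [0, 1].
    return ((image + 1.0) / 2.0).clamp(0, 1)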
def thirdStageTraining_syn(syn_data, celeb_data, batch_size=16, num_epochs=20,
                           log_path=None, use_cuda=True, lr=0.0025,
                           weight_decay=0.005):
    """Third-stage training, synthetic half: fine-tune the full SfsNetPipeline
    on the synthetic ground truth and checkpoint it as SfsNet_Syn.pkl."""
    train_dataset, val_dataset = get_sfsnet_dataset(syn_dir=syn_data + 'train/',
                                                    read_from_csv=None,
                                                    validation_split=10)
    test_dataset, _ = get_sfsnet_dataset(syn_dir=syn_data + 'test/',
                                         read_from_csv=None,
                                         validation_split=0)

    model_checkpoint_dir = log_path + 'checkpoints/'
    out_images_dir = log_path + 'out_images/'
    out_syn_images_dir = out_images_dir + 'syn/'

    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    use_cuda = use_cuda and torch.cuda.is_available()

    os.makedirs(model_checkpoint_dir, exist_ok=True)
    os.makedirs(out_syn_images_dir + 'train/', exist_ok=True)
    os.makedirs(out_syn_images_dir + 'val/', exist_ok=True)
    os.makedirs(out_syn_images_dir + 'test/', exist_ok=True)

    normal_loss = nn.L1Loss()
    albedo_loss = nn.L1Loss()
    sh_loss = nn.MSELoss()
    recon_loss = nn.L1Loss()
    if use_cuda:
        normal_loss = normal_loss.cuda()
        albedo_loss = albedo_loss.cuda()
        sh_loss = sh_loss.cuda()
        recon_loss = recon_loss.cuda()

    lambda_recon = 0.5
    lambda_normal = 0.5
    lambda_albedo = 0.5
    lambda_sh = 0.1

    wandb.init(sync_tensorboard=True)

    for epoch in range(1, num_epochs + 1):
        tloss = 0
        nloss = 0
        aloss = 0
        shloss = 0
        rloss = 0

        # Keep references to the last processed batch for logging after the loop.
        predicted_normal = None
        predicted_albedo = None
        out_shading = None
        out_recon = None
        mask = None
        face = None
        normal = None
        albedo = None

        # The model is rebuilt every epoch and seeded from the previous
        # epoch's checkpoint.
        sfsnet_model = SfsNetPipeline()
        if epoch > 1:
            sfsnet_model.load_state_dict(
                torch.load(model_checkpoint_dir + 'SfsNet_Syn.pkl'))
        sfsnet_model.to(device)
        optimizer = torch.optim.Adam(sfsnet_model.parameters(), lr=lr,
                                     weight_decay=weight_decay)

        syn_train_dl = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
        syn_val_dl = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
        syn_test_dl = DataLoader(test_dataset, batch_size=batch_size, shuffle=True)
        syn_train_len = len(syn_train_dl)
        if epoch == 1:
            print('Synthetic dataset batches -- train:', len(syn_train_dl),
                  ' val:', len(syn_val_dl), ' test:', len(syn_test_dl))

        # Initiate the iterator and walk through all synthetic batches.
        syn_train_iter = iter(syn_train_dl)
        syn_count = 0
        while True:
            data = next(syn_train_iter, None)
            if data is None:
                break
            syn_count += 1

            albedo, normal, mask, sh, face = data
            if use_cuda:
                albedo = albedo.cuda()
                normal = normal.cuda()
                mask = mask.cuda()
                sh = sh.cuda()
                face = face.cuda()
            face = apply_mask(face, mask)

            predicted_normal, predicted_albedo, predicted_sh, out_shading, out_recon = sfsnet_model(face)

            current_normal_loss = normal_loss(predicted_normal, normal)
            current_albedo_loss = albedo_loss(predicted_albedo, albedo)
            current_sh_loss = sh_loss(predicted_sh, sh)
            current_recon_loss = recon_loss(out_recon, de_norm(face))

            total_loss = lambda_sh * current_sh_loss + \
                         lambda_normal * current_normal_loss + \
                         lambda_albedo * current_albedo_loss + \
                         lambda_recon * current_recon_loss

            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()

            tloss += total_loss.item()
            nloss += current_normal_loss.item()
            aloss += current_albedo_loss.item()
            shloss += current_sh_loss.item()
            rloss += current_recon_loss.item()

            print('Epoch {}/{}, synthetic data {}/{}. Synthetic total loss: {}, '
                  'normal_loss: {}, albedo_loss: {}, sh_loss: {}, recon_loss: {}'
                  .format(epoch, num_epochs, syn_count, syn_train_len,
                          total_loss, current_normal_loss, current_albedo_loss,
                          current_sh_loss, current_recon_loss))

        # Log the last training batch of this epoch.
        file_name = out_syn_images_dir + 'train/' + 'train_' + str(epoch)
        wandb_log_images(wandb, predicted_normal, mask, 'Train Predicted Normal',
                         epoch, 'Train Predicted Normal',
                         path=file_name + '_predicted_normal.png')
        wandb_log_images(wandb, predicted_albedo, mask, 'Train Predicted Albedo',
                         epoch, 'Train Predicted Albedo',
                         path=file_name + '_predicted_albedo.png')
        wandb_log_images(wandb, out_shading, mask, 'Train Predicted Shading',
                         epoch, 'Train Predicted Shading',
                         path=file_name + '_predicted_shading.png',
                         denormalize=False)
        wandb_log_images(wandb, out_recon, mask, 'Train Recon', epoch,
                         'Train Recon', path=file_name + '_predicted_face.png',
                         denormalize=False)
        wandb_log_images(wandb, face, mask, 'Train Ground Truth', epoch,
                         'Train Ground Truth', path=file_name + '_gt_face.png')
        wandb_log_images(wandb, normal, mask, 'Train Ground Truth Normal', epoch,
                         'Train Ground Truth Normal',
                         path=file_name + '_gt_normal.png')
        wandb_log_images(wandb, albedo, mask, 'Train Ground Truth Albedo', epoch,
                         'Train Ground Truth Albedo',
                         path=file_name + '_gt_albedo.png')

        torch.save(sfsnet_model.state_dict(),
                   model_checkpoint_dir + 'SfsNet_Syn.pkl')
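
# The validation loader (syn_val_dl) above is constructed but never consumed.
# Below is a minimal, hedged sketch of how the fine-tuned pipeline could be
# scored on it with the same loss terms and weights; `evaluate_sfsnet_syn` is
# a helper introduced here for illustration, not part of the original code.

def evaluate_sfsnet_syn(model, dataloader, use_cuda=True):
    # Average the weighted SfSNet loss over a held-out synthetic split.
    model.eval()
    l1 = nn.L1Loss()
    mse = nn.MSELoss()
    total, batches = 0.0, 0
    with torch.no_grad():
        for albedo, normal, mask, sh, face in dataloader:
            if use_cuda:
                albedo, normal, mask = albedo.cuda(), normal.cuda(), mask.cuda()
                sh, face = sh.cuda(), face.cuda()
            face = apply_mask(face, mask)
            pred_normal, pred_albedo, pred_sh, _, recon = model(face)
            loss = 0.5 * l1(pred_normal, normal) + 0.5 * l1(pred_albedo, albedo) \
                 + 0.1 * mse(pred_sh, sh) + 0.5 * l1(recon, de_norm(face))
            total += loss.item()
            batches += 1
    return total / max(batches, 1)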
def thirdStageTraining_celeb(syn_data, celeb_data, batch_size=8, num_epochs=20,
                             log_path=None, use_cuda=True, lr=0.0025,
                             weight_decay=0.005):
    """Third-stage training, CelebA half: supervise SfsNetPipeline with
    pseudo-labels from the first-stage SkipNet and checkpoint it as
    SfsNet_Celeb.pkl."""
    train_celeb_dataset, val_celeb_dataset = get_celeba_dataset(
        celeb_dir=celeb_data, validation_split=10)
    test_celeb_dataset, _ = get_celeba_dataset(celeb_dir=celeb_data,
                                               validation_split=0)

    model_checkpoint_dir = log_path + 'checkpoints/'
    out_images_dir = log_path + 'out_images/'
    out_celeb_images_dir = out_images_dir + 'celeb/'

    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    use_cuda = use_cuda and torch.cuda.is_available()

    # First-stage SkipNet, used only to generate pseudo-labels for CelebA.
    # Its checkpoint is expected in the same checkpoints directory that
    # FirstStage_Training writes to.
    prev_SkipNet_model = SkipNet()
    prev_SkipNet_model.load_state_dict(
        torch.load(model_checkpoint_dir + 'Skip_First.pkl'))
    prev_SkipNet_model.to(device)
    prev_SkipNet_model.eval()

    os.makedirs(model_checkpoint_dir, exist_ok=True)
    os.makedirs(out_celeb_images_dir + 'train/', exist_ok=True)
    os.makedirs(out_celeb_images_dir + 'val/', exist_ok=True)
    os.makedirs(out_celeb_images_dir + 'test/', exist_ok=True)

    c_recon_loss = nn.L1Loss()
    c_sh_loss = nn.MSELoss()
    c_albedo_loss = nn.L1Loss()
    c_normal_loss = nn.L1Loss()
    if use_cuda:
        c_recon_loss = c_recon_loss.cuda()
        c_sh_loss = c_sh_loss.cuda()
        c_albedo_loss = c_albedo_loss.cuda()
        c_normal_loss = c_normal_loss.cuda()

    lambda_recon = 0.5
    lambda_normal = 0.5
    lambda_albedo = 0.5
    lambda_sh = 0.1

    wandb.init(sync_tensorboard=True)

    for epoch in range(1, num_epochs + 1):
        c_tloss = 0
        c_nloss = 0
        c_aloss = 0
        c_shloss = 0
        c_reconloss = 0

        # Keep references to the last processed batch for logging after the loop.
        c_predicted_normal = None
        c_predicted_albedo = None
        c_out_shading = None
        c_out_recon = None
        c_face = None

        # The model is rebuilt every epoch and seeded from the previous
        # epoch's checkpoint.
        sfsnet_model = SfsNetPipeline()
        if epoch > 1:
            sfsnet_model.load_state_dict(
                torch.load(model_checkpoint_dir + 'SfsNet_Celeb.pkl'))
        sfsnet_model.to(device)
        optimizer = torch.optim.Adam(sfsnet_model.parameters(), lr=lr,
                                     weight_decay=weight_decay)

        celeb_train_dl = DataLoader(train_celeb_dataset, batch_size=batch_size, shuffle=True)
        celeb_val_dl = DataLoader(val_celeb_dataset, batch_size=batch_size, shuffle=False)
        celeb_test_dl = DataLoader(test_celeb_dataset, batch_size=batch_size, shuffle=True)
        celeb_train_len = len(celeb_train_dl)
        if epoch == 1:
            print('CelebA dataset batches -- train:', len(celeb_train_dl),
                  ' val:', len(celeb_val_dl), ' test:', len(celeb_test_dl))

        # Walk through all CelebA batches.
        celeb_train_iter = iter(celeb_train_dl)
        celeb_count = 0
        while True:
            c_data = next(celeb_train_iter, None)
            if c_data is None:
                break
            celeb_count += 1

            if use_cuda:
                c_data = c_data.cuda()
            c_face = c_data

            # Pseudo-labels from the frozen first-stage SkipNet; no gradients
            # should flow into it.
            with torch.no_grad():
                prevc_normal, prevc_albedo, prevc_sh, prevc_shading, prevc_recon = prev_SkipNet_model(c_face)

            c_predicted_normal, c_predicted_albedo, c_predicted_sh, c_out_shading, c_out_recon = sfsnet_model(c_face)

            c_current_normal_loss = c_normal_loss(c_predicted_normal, prevc_normal)
            c_current_albedo_loss = c_albedo_loss(c_predicted_albedo, prevc_albedo)
            c_current_sh_loss = c_sh_loss(c_predicted_sh, prevc_sh)
            c_current_recon_loss = c_recon_loss(c_out_recon, de_norm(c_face))

            c_total_loss = lambda_sh * c_current_sh_loss + \
                           lambda_normal * c_current_normal_loss + \
                           lambda_albedo * c_current_albedo_loss + \
                           lambda_recon * c_current_recon_loss

            optimizer.zero_grad()
            c_total_loss.backward()
            optimizer.step()

            c_tloss += c_total_loss.item()
            c_nloss += c_current_normal_loss.item()
            c_aloss += c_current_albedo_loss.item()
            c_shloss += c_current_sh_loss.item()
            c_reconloss += c_current_recon_loss.item()

            print('Epoch {}/{}, celeb data {}/{}. Celeb total loss: {}, '
                  'normal_loss: {}, albedo_loss: {}, sh_loss: {}, recon_loss: {}'
                  .format(epoch, num_epochs, celeb_count, celeb_train_len,
                          c_total_loss, c_current_normal_loss,
                          c_current_albedo_loss, c_current_sh_loss,
                          c_current_recon_loss))

        # Log the last CelebA training batch of this epoch.
        file_name = out_celeb_images_dir + 'train/' + 'train_' + str(epoch)
        wandb_log_images(wandb, c_predicted_normal, None,
                         'Train CelebA Predicted Normal', epoch,
                         'Train CelebA Predicted Normal',
                         path=file_name + '_c_predicted_normal.png')
        wandb_log_images(wandb, c_predicted_albedo, None,
                         'Train CelebA Predicted Albedo', epoch,
                         'Train CelebA Predicted Albedo',
                         path=file_name + '_c_predicted_albedo.png')
        wandb_log_images(wandb, c_out_shading, None,
                         'Train CelebA Predicted Shading', epoch,
                         'Train CelebA Predicted Shading',
                         path=file_name + '_c_predicted_shading.png',
                         denormalize=False)
        wandb_log_images(wandb, c_out_recon, None, 'Train CelebA Recon', epoch,
                         'Train CelebA Recon',
                         path=file_name + '_c_predicted_face.png',
                         denormalize=False)
        wandb_log_images(wandb, c_face, None, 'Train CelebA Ground Truth', epoch,
                         'Train CelebA Ground Truth',
                         path=file_name + '_c_gt_face.png')

        torch.save(sfsnet_model.state_dict(),
                   model_checkpoint_dir + 'SfsNet_Celeb.pkl')
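
# A hedged sketch of how the three stages above could be chained together.
# The directory paths are placeholders introduced here for illustration only;
# point them at the real dataset and output locations before running.
if __name__ == '__main__':
    SYN_DATA_DIR = '/path/to/synthetic_data/'    # placeholder
    CELEB_DATA_DIR = '/path/to/celeba_data/'     # placeholder
    LOG_DIR = '/path/to/results/metadata/'       # placeholder

    os.makedirs(LOG_DIR + 'checkpoints/', exist_ok=True)

    # Stage 1: SkipNet on synthetic data (writes Skip_First.pkl).
    FirstStage_Training(syn_path=SYN_DATA_DIR,
                        model_dir=LOG_DIR + 'checkpoints/')
    # Third stage, synthetic half: SfsNetPipeline on synthetic ground truth.
    thirdStageTraining_syn(SYN_DATA_DIR, CELEB_DATA_DIR, log_path=LOG_DIR)
    # Third stage, CelebA half: SfsNetPipeline on SkipNet pseudo-labels.
    thirdStageTraining_celeb(SYN_DATA_DIR, CELEB_DATA_DIR, log_path=LOG_DIR)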