p_wct.d1.load_state_dict(torch.load('pth_models/feature_invertor_conv1.pth'))
p_wct.e2.load_state_dict(torch.load('pth_models/vgg_normalised_conv2.pth'))
p_wct.d2.load_state_dict(torch.load('pth_models/feature_invertor_conv2.pth'))
p_wct.e3.load_state_dict(torch.load('pth_models/vgg_normalised_conv3.pth'))
p_wct.d3.load_state_dict(torch.load('pth_models/feature_invertor_conv3.pth'))
p_wct.e4.load_state_dict(torch.load('pth_models/vgg_normalised_conv4.pth'))
p_wct.d4.load_state_dict(torch.load('pth_models/feature_invertor_conv4.pth'))


if __name__ == '__main__':
    if not os.path.exists('pth_models'):
        os.mkdir('pth_models')

    ## VGGEncoder1
    vgg1 = load_lua('models/vgg_normalised_conv1_1_mask.t7')
    e1 = VGGEncoder(1)
    weight_assign(vgg1, e1, {
        'conv0': 0,
        'conv1_1': 2,
    })
    torch.save(e1.state_dict(), 'pth_models/vgg_normalised_conv1.pth')

    ## VGGDecoder1
    inv1 = load_lua('models/feature_invertor_conv1_1_mask.t7')
    d1 = VGGDecoder(1)
    weight_assign(inv1, d1, {
        'conv1_1': 1,
    })
    torch.save(d1.state_dict(), 'pth_models/feature_invertor_conv1.pth')

    ## VGGEncoder2
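# Hypothetical end-to-end loading sketch (assumes the PhotoWCT class defined
# below and the pth_models/ files produced by this converter; the module name
# `photo_wct` and the loop are illustrative condensations of the
# load_state_dict calls above, not part of the original script).
import torch
from photo_wct import PhotoWCT  # module name assumed

p_wct = PhotoWCT()
for level in (1, 2, 3, 4):
    getattr(p_wct, 'e%d' % level).load_state_dict(
        torch.load('pth_models/vgg_normalised_conv%d.pth' % level))
    getattr(p_wct, 'd%d' % level).load_state_dict(
        torch.load('pth_models/feature_invertor_conv%d.pth' % level))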
class PhotoWCT(nn.Module):
    def __init__(self):
        super(PhotoWCT, self).__init__()
        self.e1 = VGGEncoder(1)
        self.d1 = VGGDecoder(1)
        self.e2 = VGGEncoder(2)
        self.d2 = VGGDecoder(2)
        self.e3 = VGGEncoder(3)
        self.d3 = VGGDecoder(3)
        self.e4 = VGGEncoder(4)
        self.d4 = VGGDecoder(4)

    def transform(self, cont_img, styl_img, cont_seg, styl_seg):
        self.__compute_label_info(cont_seg, styl_seg)
        sF4, sF3, sF2, sF1 = self.e4.forward_multiple(styl_img)

        cF4, cpool_idx, cpool1, cpool_idx2, cpool2, cpool_idx3, cpool3 = self.e4(cont_img)
        sF4 = sF4.data.squeeze(0)
        cF4 = cF4.data.squeeze(0)
        # print(cont_seg)
        csF4 = self.__feature_wct(cF4, sF4, cont_seg, styl_seg)
        Im4 = self.d4(csF4, cpool_idx, cpool1, cpool_idx2, cpool2, cpool_idx3, cpool3)

        cF3, cpool_idx, cpool1, cpool_idx2, cpool2 = self.e3(Im4)
        sF3 = sF3.data.squeeze(0)
        cF3 = cF3.data.squeeze(0)
        csF3 = self.__feature_wct(cF3, sF3, cont_seg, styl_seg)
        Im3 = self.d3(csF3, cpool_idx, cpool1, cpool_idx2, cpool2)

        cF2, cpool_idx, cpool = self.e2(Im3)
        sF2 = sF2.data.squeeze(0)
        cF2 = cF2.data.squeeze(0)
        csF2 = self.__feature_wct(cF2, sF2, cont_seg, styl_seg)
        Im2 = self.d2(csF2, cpool_idx, cpool)

        cF1 = self.e1(Im2)
        sF1 = sF1.data.squeeze(0)
        cF1 = cF1.data.squeeze(0)
        csF1 = self.__feature_wct(cF1, sF1, cont_seg, styl_seg)
        Im1 = self.d1(csF1)
        return Im1

    def __compute_label_info(self, cont_seg, styl_seg):
        if cont_seg.size == False or styl_seg.size == False:
            return
        max_label = np.max(cont_seg) + 1
        self.label_set = np.unique(cont_seg)
        self.label_indicator = np.zeros(max_label)
        for l in self.label_set:
            # if l == 0:
            #     continue
            is_valid = lambda a, b: a > 10 and b > 10 and a / b < 100 and b / a < 100
            o_cont_mask = np.where(cont_seg.reshape(cont_seg.shape[0] * cont_seg.shape[1]) == l)
            o_styl_mask = np.where(styl_seg.reshape(styl_seg.shape[0] * styl_seg.shape[1]) == l)
            self.label_indicator[l] = is_valid(o_cont_mask[0].size, o_styl_mask[0].size)

    def __feature_wct(self, cont_feat, styl_feat, cont_seg, styl_seg):
        cont_c, cont_h, cont_w = cont_feat.size(0), cont_feat.size(1), cont_feat.size(2)
        styl_c, styl_h, styl_w = styl_feat.size(0), styl_feat.size(1), styl_feat.size(2)
        cont_feat_view = cont_feat.view(cont_c, -1).clone()
        styl_feat_view = styl_feat.view(styl_c, -1).clone()

        if cont_seg.size == False or styl_seg.size == False:
            target_feature = self.__wct_core(cont_feat_view, styl_feat_view)
        else:
            target_feature = cont_feat.view(cont_c, -1).clone()
            if len(cont_seg.shape) == 2:
                t_cont_seg = np.asarray(Image.fromarray(cont_seg).resize((cont_w, cont_h), Image.NEAREST))
            else:
                t_cont_seg = np.asarray(Image.fromarray(cont_seg, mode='RGB').resize((cont_w, cont_h), Image.NEAREST))
            if len(styl_seg.shape) == 2:
                t_styl_seg = np.asarray(Image.fromarray(styl_seg).resize((styl_w, styl_h), Image.NEAREST))
            else:
                t_styl_seg = np.asarray(Image.fromarray(styl_seg, mode='RGB').resize((styl_w, styl_h), Image.NEAREST))

            for l in self.label_set:
                if self.label_indicator[l] == 0:
                    continue
                cont_mask = np.where(t_cont_seg.reshape(t_cont_seg.shape[0] * t_cont_seg.shape[1]) == l)
                styl_mask = np.where(t_styl_seg.reshape(t_styl_seg.shape[0] * t_styl_seg.shape[1]) == l)
                if cont_mask[0].size <= 0 or styl_mask[0].size <= 0:
                    continue

                cont_indi = torch.LongTensor(cont_mask[0])
                styl_indi = torch.LongTensor(styl_mask[0])
                if self.is_cuda:
                    cont_indi = cont_indi.cuda(0)
                    styl_indi = styl_indi.cuda(0)

                cFFG = torch.index_select(cont_feat_view, 1, cont_indi)
                sFFG = torch.index_select(styl_feat_view, 1, styl_indi)
                # print(len(cont_indi))
                # print(len(styl_indi))
                tmp_target_feature = self.__wct_core(cFFG, sFFG)
                # print(tmp_target_feature.size())
                if torch.__version__ >= "0.4.0":
                    # This seems to be a bug in PyTorch 0.4.0 to me.
                    new_target_feature = torch.transpose(target_feature, 1, 0)
                    new_target_feature.index_copy_(0, cont_indi,
                                                   torch.transpose(tmp_target_feature, 1, 0))
                    target_feature = torch.transpose(new_target_feature, 1, 0)
                else:
                    target_feature.index_copy_(1, cont_indi, tmp_target_feature)

        target_feature = target_feature.view_as(cont_feat)
        ccsF = target_feature.float().unsqueeze(0)
        return ccsF

    def __wct_core(self, cont_feat, styl_feat):
        cFSize = cont_feat.size()
        c_mean = torch.mean(cont_feat, 1)  # c x (h x w)
        c_mean = c_mean.unsqueeze(1).expand_as(cont_feat)
        cont_feat = cont_feat - c_mean

        iden = torch.eye(cFSize[0])  # .double()
        if self.is_cuda:
            iden = iden.cuda()
        contentConv = torch.mm(cont_feat, cont_feat.t()).div(cFSize[1] - 1) + iden
        # del iden
        c_u, c_e, c_v = torch.svd(contentConv, some=False)
        # c_e2, c_v = torch.eig(contentConv, True)
        # c_e = c_e2[:, 0]

        k_c = cFSize[0]
        for i in range(cFSize[0] - 1, -1, -1):
            if c_e[i] >= 0.00001:
                k_c = i + 1
                break

        sFSize = styl_feat.size()
        s_mean = torch.mean(styl_feat, 1)
        styl_feat = styl_feat - s_mean.unsqueeze(1).expand_as(styl_feat)
        styleConv = torch.mm(styl_feat, styl_feat.t()).div(sFSize[1] - 1)
        s_u, s_e, s_v = torch.svd(styleConv, some=False)

        k_s = sFSize[0]
        for i in range(sFSize[0] - 1, -1, -1):
            if s_e[i] >= 0.00001:
                k_s = i + 1
                break

        c_d = (c_e[0:k_c]).pow(-0.5)
        step1 = torch.mm(c_v[:, 0:k_c], torch.diag(c_d))
        step2 = torch.mm(step1, (c_v[:, 0:k_c].t()))
        whiten_cF = torch.mm(step2, cont_feat)

        s_d = (s_e[0:k_s]).pow(0.5)
        targetFeature = torch.mm(torch.mm(torch.mm(s_v[:, 0:k_s], torch.diag(s_d)), (s_v[:, 0:k_s].t())), whiten_cF)
        targetFeature = targetFeature + s_mean.unsqueeze(1).expand_as(targetFeature)
        return targetFeature

    @property
    def is_cuda(self):
        return next(self.parameters()).is_cuda

    def forward(self, *input):
        pass
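# Minimal usage sketch for the class above (assumptions: pretrained weights
# already loaded as in the converter script; 'content.png' / 'style.png' are
# hypothetical file names). Passing empty numpy arrays disables segmentation:
# an empty array's .size is 0, so the `cont_seg.size == False` guards fall
# back to a single global WCT.
import numpy as np
import torch
from PIL import Image
from torchvision import transforms

p_wct = PhotoWCT()
# ... load_state_dict calls as in the converter above ...
to_tensor = transforms.ToTensor()
cont_img = to_tensor(Image.open('content.png').convert('RGB')).unsqueeze(0)
styl_img = to_tensor(Image.open('style.png').convert('RGB')).unsqueeze(0)
with torch.no_grad():
    stylized = p_wct.transform(cont_img, styl_img, np.asarray([]), np.asarray([]))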
class PhotoWCT(nn.Module):
    def __init__(self):
        super(PhotoWCT, self).__init__()
        self.e1 = VGGEncoder(1)
        self.d1 = VGGDecoder(1)
        self.e2 = VGGEncoder(2)
        self.d2 = VGGDecoder(2)
        self.e3 = VGGEncoder(3)
        self.d3 = VGGDecoder(3)
        self.e4 = VGGEncoder(4)
        self.d4 = VGGDecoder(4)

    def transform(self, cont_img, styl_img, cont_seg, styl_seg, label_weight_list):
        self.__compute_label_info(cont_seg, styl_seg)
        sF4, sF3, sF2, sF1 = self.e4.forward_multiple(styl_img)

        cF4, cpool_idx, cpool1, cpool_idx2, cpool2, cpool_idx3, cpool3 = self.e4(cont_img)
        sF4 = sF4.data.squeeze(0)  # (1, channel, height, width) -> (channel, height, width)
        cF4 = cF4.data.squeeze(0)
        csF4 = self.__feature_wct(cF4, sF4, cont_seg, styl_seg, label_weight_list)
        Im4 = self.d4(csF4, cpool_idx, cpool1, cpool_idx2, cpool2, cpool_idx3, cpool3)

        cF3, cpool_idx, cpool1, cpool_idx2, cpool2 = self.e3(Im4)
        sF3 = sF3.data.squeeze(0)
        cF3 = cF3.data.squeeze(0)
        csF3 = self.__feature_wct(cF3, sF3, cont_seg, styl_seg, label_weight_list)
        Im3 = self.d3(csF3, cpool_idx, cpool1, cpool_idx2, cpool2)

        cF2, cpool_idx, cpool = self.e2(Im3)
        sF2 = sF2.data.squeeze(0)
        cF2 = cF2.data.squeeze(0)
        csF2 = self.__feature_wct(cF2, sF2, cont_seg, styl_seg, label_weight_list)
        Im2 = self.d2(csF2, cpool_idx, cpool)

        cF1 = self.e1(Im2)
        sF1 = sF1.data.squeeze(0)
        cF1 = cF1.data.squeeze(0)
        csF1 = self.__feature_wct(cF1, sF1, cont_seg, styl_seg, label_weight_list)
        Im1 = self.d1(csF1)
        return Im1

    def __compute_label_info(self, cont_seg, styl_seg):
        if cont_seg.size == False or styl_seg.size == False:
            return
        max_label = np.max(cont_seg) + 1  # number of labels
        self.label_set = np.unique(cont_seg)  # e.g. [0, 1, 2, ...]
        self.label_indicator = np.zeros(max_label)  # e.g. [0, 0, 0, ...]
        for l in self.label_set:
            # if l == 0:
            #     continue
            is_valid = lambda a, b: a > 10 and b > 10 and a / b < 100 and b / a < 100
            o_cont_mask = np.where(cont_seg.reshape(cont_seg.shape[0] * cont_seg.shape[1]) == l)
            o_styl_mask = np.where(styl_seg.reshape(styl_seg.shape[0] * styl_seg.shape[1]) == l)  # positions of this label
            self.label_indicator[l] = is_valid(o_cont_mask[0].size, o_styl_mask[0].size)  # 1.0 if both regions pass the size checks, else 0.0

    def __feature_wct(self, cont_feat, styl_feat, cont_seg, styl_seg, label_weight_list):
        cont_c, cont_h, cont_w = cont_feat.size(0), cont_feat.size(1), cont_feat.size(2)
        styl_c, styl_h, styl_w = styl_feat.size(0), styl_feat.size(1), styl_feat.size(2)
        cont_feat_view = cont_feat.view(cont_c, -1).clone()  # flatten 3-D to 2-D
        styl_feat_view = styl_feat.view(styl_c, -1).clone()  # e.g. (512, x, y) -> (512, x*y)

        if cont_seg.size == False or styl_seg.size == False:
            target_feature = self.__wct_core(cont_feat_view, styl_feat_view)
        else:
            # print(cont_feat)
            target_feature = cont_feat.view(cont_c, -1).clone()  # 512 x z
            # print(target_feature)
            if len(cont_seg.shape) == 2:  # grayscale (2-D) segmentation
                # Resize the content segmentation to the feature map's width/height,
                # then store it as a numpy array.
                t_cont_seg = np.asarray(Image.fromarray(cont_seg).resize((cont_w, cont_h), Image.NEAREST))
            else:  # RGB segmentation
                t_cont_seg = np.asarray(Image.fromarray(cont_seg, mode='RGB').resize((cont_w, cont_h), Image.NEAREST))
            if len(styl_seg.shape) == 2:  # resize the style segmentation the same way
                t_styl_seg = np.asarray(Image.fromarray(styl_seg).resize((styl_w, styl_h), Image.NEAREST))
            else:
                t_styl_seg = np.asarray(Image.fromarray(styl_seg, mode='RGB').resize((styl_w, styl_h), Image.NEAREST))

            for l in self.label_set:
                if self.label_indicator[l] == 0:
                    continue  # skip labels that do not need to be transformed
                cont_mask = np.where(t_cont_seg.reshape(t_cont_seg.shape[0] * t_cont_seg.shape[1]) == l)  # positions whose label equals l
                styl_mask = np.where(t_styl_seg.reshape(t_styl_seg.shape[0] * t_styl_seg.shape[1]) == l)  # subset of z = x*y with matching label
                if cont_mask[0].size <= 0 or styl_mask[0].size <= 0:
                    continue

                cont_indi = torch.LongTensor(cont_mask[0])  # mask indices as LongTensor, length z <= x*y
                styl_indi = torch.LongTensor(styl_mask[0])
                if self.is_cuda:
                    cont_indi = cont_indi.cuda(0)
                    styl_indi = styl_indi.cuda(0)

                cFFG = torch.index_select(cont_feat_view, 1, cont_indi)  # pick the masked pixels: c x z
                sFFG = torch.index_select(styl_feat_view, 1, styl_indi)
                # print(len(cont_indi))
                # print(len(styl_indi))
                tmp_target_feature = self.__wct_core(cFFG, sFFG)  # the actual whitening/coloring transform
                if torch.__version__ >= "0.4.0":
                    # This seems to be a bug in PyTorch 0.4.0 to me.
                    ####################### scale the amount of stylization by label_weight #######################
                    if label_weight_list[l] == 0:
                        tmp_target_feature = cFFG
                    else:
                        tmp_target_feature = label_weight_list[l] * tmp_target_feature
                    new_target_feature = torch.transpose(target_feature, 1, 0)
                    new_target_feature.index_copy_(0, cont_indi,
                                                   torch.transpose(tmp_target_feature, 1, 0))
                    target_feature = torch.transpose(new_target_feature, 1, 0)
                else:
                    # (note: this pre-0.4.0 path copies the unweighted result)
                    target_feature.index_copy_(1, cont_indi, tmp_target_feature)

        target_feature = target_feature.view_as(cont_feat)  # c x (h*w) -> c x h x w
        ccsF = target_feature.float().unsqueeze(0)  # c x h x w -> 1 x c x h x w
        return ccsF

    def __wct_core(self, cont_feat, styl_feat):
        cFSize = cont_feat.size()  # c x z (e.g. 512 x 7214)
        c_mean = torch.mean(cont_feat, 1)  # per-channel mean, shape [c]
        c_mean = c_mean.unsqueeze(1).expand_as(cont_feat)  # expand back to c x z for broadcasting,
        # e.g. [1, 2, 3] -> [[1, 1, 1], [2, 2, 2], [3, 3, 3]]
        cont_feat = cont_feat - c_mean  # subtract the channel means

        iden = torch.eye(cFSize[0])  # .double()  # c x c identity matrix
        if self.is_cuda:
            iden = iden.cuda()
        # Covariance for the eigendecomposition: (cFFG @ cFFG.T) / (h*w - 1) + iden, size c x c.
        contentConv = torch.mm(cont_feat, cont_feat.t()).div(cFSize[1] - 1) + iden
        # del iden
        c_u, c_e, c_v = torch.svd(contentConv, some=False)
        # c_e2, c_v = torch.eig(contentConv, True)
        # c_e = c_e2[:, 0]

        k_c = cFSize[0]
        for i in range(cFSize[0] - 1, -1, -1):
            if c_e[i] >= 0.00001:
                k_c = i + 1
                break

        # Same procedure for the style features.
        sFSize = styl_feat.size()  # c x (h*w)
        s_mean = torch.mean(styl_feat, 1)
        styl_feat = styl_feat - s_mean.unsqueeze(1).expand_as(styl_feat)
        styleConv = torch.mm(styl_feat, styl_feat.t()).div(sFSize[1] - 1)
        s_u, s_e, s_v = torch.svd(styleConv, some=False)

        k_s = sFSize[0]
        for i in range(sFSize[0] - 1, -1, -1):
            if s_e[i] >= 0.00001:
                k_s = i + 1
                break

        c_d = (c_e[0:k_c]).pow(-0.5)
        epsilon = 0.001  # hyperparameter guarding against division by zero
        # Was device='cuda'; using c_d.device keeps CPU runs working too.
        ep_mat = torch.ones(k_c, dtype=torch.float32, device=c_d.device) * epsilon
        c_d = c_d + ep_mat
        step1 = torch.mm(c_v[:, 0:k_c], torch.diag(c_d))
        step2 = torch.mm(step1, (c_v[:, 0:k_c].t()))  # P_c = c_v @ diag(c_e^(-1/2)) @ c_v.T
        whiten_cF = torch.mm(step2, cont_feat)  # P_c @ H_c

        s_d = (s_e[0:k_s]).pow(0.5)
        epsilon2 = 0.001  # hyperparameter guarding against division by zero
        ep_mat2 = torch.ones(k_s, dtype=torch.float32, device=s_d.device) * epsilon2
        s_d = s_d + ep_mat2
        # P_s = s_v @ diag(s_e^(1/2)) @ s_v.T
        targetFeature = torch.mm(torch.mm(torch.mm(s_v[:, 0:k_s], torch.diag(s_d)), (s_v[:, 0:k_s].t())), whiten_cF)
        targetFeature = targetFeature + s_mean.unsqueeze(1).expand_as(targetFeature)  # feature = P_s @ P_c @ H_c + s_mean
        return targetFeature

    @property
    def is_cuda(self):
        return next(self.parameters()).is_cuda

    def forward(self, *input):
        pass
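# Standalone sanity check for the whitening/coloring idea behind __wct_core
# (a sketch on random data, not part of the original code). After whitening
# with P_c the content covariance should be (near) identity; after coloring
# with P_s it should match the style covariance.
import torch

torch.manual_seed(0)
c, z = 8, 1000
H_c = torch.randn(c, z)
H_s = torch.randn(c, z) * torch.linspace(0.5, 2.0, c).unsqueeze(1)

def cov(X):
    X = X - X.mean(1, keepdim=True)
    return X @ X.t() / (X.size(1) - 1)

e_c, v_c = torch.linalg.eigh(cov(H_c))
e_s, v_s = torch.linalg.eigh(cov(H_s))
P_c = v_c @ torch.diag(e_c.clamp_min(1e-5).pow(-0.5)) @ v_c.t()  # whitening
P_s = v_s @ torch.diag(e_s.clamp_min(0.0).pow(0.5)) @ v_s.t()    # coloring

white = P_c @ (H_c - H_c.mean(1, keepdim=True))
print(torch.allclose(cov(white), torch.eye(c), atol=1e-3))    # ~identity
print(torch.allclose(cov(P_s @ white), cov(H_s), atol=1e-2))  # ~style covariance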
import torch
import torch.nn as nn
from torch.utils.serialization import load_lua

from models import VGGEncoder, VGGDecoder


def weight_assign(lua, pth, maps):
    for k, v in maps.items():
        getattr(pth, k).weight = nn.Parameter(lua.get(v).weight.float())
        getattr(pth, k).bias = nn.Parameter(lua.get(v).bias.float())


if __name__ == '__main__':
    ## VGGEncoder4
    vgg4 = load_lua('pretrained/encoder.t7', long_size=8)
    e4 = VGGEncoder()
    weight_assign(vgg4, e4, {
        'conv0': 0,
        'conv1_1': 2,
        'conv1_2': 5,
        'conv2_1': 9,
        'conv2_2': 12,
        'conv3_1': 16,
        'conv3_2': 19,
        'conv3_3': 22,
        'conv3_4': 25,
        'conv4_1': 29,
    })
    torch.save(e4.state_dict(), 'pretrained/encoder_pretrained.pth')

    ## VGGDecoder4
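# Note: torch.utils.serialization (and load_lua) was removed in PyTorch 1.0.
# A sketch of one assumed alternative, the third-party `torchfile` package
# (pip install torchfile); the attribute layout below is an assumption and may
# need adjusting to the actual .t7 contents.
import torchfile

t7 = torchfile.load('pretrained/encoder.t7')
# Sequential containers typically expose their children as a list; index them
# with the same integer keys used in the maps above.
first_conv = t7.modules[0]
print(first_conv.weight.shape, first_conv.bias.shape)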
class PhotoWCT(nn.Module):
    def __init__(self, args):
        super(PhotoWCT, self).__init__()
        self.args = args
        if "16x" in self.args.mode:
            if "JointED" not in self.args.mode:
                ### 16x model trained for pwct
                e1 = '../KD/Experiments/Small16xEncoder_pwct/e1/weights/12-20181020-1610_1SE_E25S0-2.pth'
                e2 = '../KD/Experiments/Small16xEncoder_pwct/e2/weights/12-20181020-1602_2SE_E25S0-2.pth'
                e3 = '../KD/Experiments/Small16xEncoder_pwct/e3/weights/12-20181019-0420_3SE_E25S0-2.pth'
                e4 = '../KD/Experiments/Small16xEncoder_pwct/e4/weights/12-20181019-0349_4SE_E25S0-2.pth'
                d1 = '../KD/Experiments/Small16xDecoder_pwct/e1/weights/12-20181021-0913_1SD_E25S0-3.pth'
                d2 = '../KD/Experiments/Small16xDecoder_pwct/e2/weights/12-20181021-1418_2SD_E25S0-3.pth'
                d3 = '../KD/Experiments/Small16xDecoder_pwct/e3/weights/12-20181020-1638_3SD_E25S0-3.pth'
                d4 = '../KD/Experiments/Small16xDecoder_pwct/e4/weights/12-20181020-1637_4SD_E25S0-3.pth'
            else:
                ### 16x model trained for pwct, JointED
                e1 = '../KD/Experiments/Small16xEncoder_pwct/e1_JointED/weights/12-20181026-0259_1SED_E25S0-2.pth'
                d1 = '../KD/Experiments/Small16xEncoder_pwct/e1_JointED/weights/12-20181026-0259_1SED_E25S0-3.pth'
                e2 = '../KD/Experiments/Small16xEncoder_pwct/e2_JointED/weights/12-20181026-0256_2SED_E25S0-2.pth'
                d2 = '../KD/Experiments/Small16xEncoder_pwct/e2_JointED/weights/12-20181026-0256_2SED_E25S0-3.pth'
                e3 = '../KD/Experiments/Small16xEncoder_pwct/e3_JointED/weights/12-20181026-0255_3SED_E25S0-2.pth'
                d3 = '../KD/Experiments/Small16xEncoder_pwct/e3_JointED/weights/12-20181026-0255_3SED_E25S0-3.pth'
                e4 = '../KD/Experiments/Small16xEncoder_pwct/e4_JointED/weights/12-20181026-0255_4SED_E25S0-2.pth'
                d4 = '../KD/Experiments/Small16xEncoder_pwct/e4_JointED/weights/12-20181026-0255_4SED_E25S0-3.pth'

        if self.args.mode == "" or self.args.mode == "original":
            #### original model
            self.e1 = VGGEncoder(1)
            self.d1 = VGGDecoder(1)
            self.e2 = VGGEncoder(2)
            self.d2 = VGGDecoder(2)
            self.e3 = VGGEncoder(3)
            self.d3 = VGGDecoder(3)
            self.e4 = VGGEncoder(4)
            self.d4 = VGGDecoder(4)
        elif "16x" in self.args.mode:
            self.e1 = SmallEncoder_16x_plus(1, e1)
            self.d1 = SmallDecoder_16x(1, d1)
            self.e2 = SmallEncoder_16x_plus(2, e2)
            self.d2 = SmallDecoder_16x(2, d2)
            self.e3 = SmallEncoder_16x_plus(3, e3)
            self.d3 = SmallDecoder_16x(3, d3)
            self.e4 = SmallEncoder_16x_plus(4, e4)
            self.d4 = SmallDecoder_16x(4, d4)
        else:
            print("wrong mode")
            exit(1)

    @torch.no_grad()
    def transform(self, cont_img, styl_img, cont_seg, styl_seg):
        self.__compute_label_info(cont_seg, styl_seg)

        if self.args.mode == "" or self.args.mode == "original":
            #### original model
            sF4, sF3, sF2, sF1 = self.e4.forward_multiple(styl_img)

            cF4, cpool_idx, cpool1, cpool_idx2, cpool2, cpool_idx3, cpool3 = self.e4(cont_img)
            sF4 = sF4.data.squeeze(0)
            cF4 = cF4.data.squeeze(0)
            csF4 = self.__feature_wct(cF4, sF4, cont_seg, styl_seg)
            Im4 = self.d4(csF4, cpool_idx, cpool1, cpool_idx2, cpool2, cpool_idx3, cpool3)

            cF3, cpool_idx, cpool1, cpool_idx2, cpool2 = self.e3(Im4)
            sF3 = sF3.data.squeeze(0)
            cF3 = cF3.data.squeeze(0)
            csF3 = self.__feature_wct(cF3, sF3, cont_seg, styl_seg)
            Im3 = self.d3(csF3, cpool_idx, cpool1, cpool_idx2, cpool2)

            cF2, cpool_idx, cpool = self.e2(Im3)
            sF2 = sF2.data.squeeze(0)
            cF2 = cF2.data.squeeze(0)
            csF2 = self.__feature_wct(cF2, sF2, cont_seg, styl_seg)
            Im2 = self.d2(csF2, cpool_idx, cpool)

            cF1 = self.e1(Im2)
            sF1 = sF1.data.squeeze(0)
            cF1 = cF1.data.squeeze(0)
            csF1 = self.__feature_wct(cF1, sF1, cont_seg, styl_seg)
            Im1 = self.d1(csF1)
        elif "16x" in self.args.mode:
            #### slimmed model
            sF4 = self.e4(styl_img)
            torch.cuda.empty_cache()
            cF4, cpool_idx, cpool1, cpool_idx2, cpool2, cpool_idx3, cpool3 = self.e4.forward_stem(cont_img)[3:]
            torch.cuda.empty_cache()
            sF4 = sF4.data.squeeze(0)
            cF4 = cF4.data.squeeze(0)
            csF4 = self.__feature_wct(cF4, sF4, cont_seg, styl_seg)
            Im4 = self.d4(csF4, cpool_idx, cpool1, cpool_idx2, cpool2, cpool_idx3, cpool3)
            torch.cuda.empty_cache()

            sF3 = self.e3(styl_img)
            torch.cuda.empty_cache()
            cF3, cpool_idx, cpool1, cpool_idx2, cpool2 = self.e3.forward_stem(cont_img)[2:]
            torch.cuda.empty_cache()
            sF3 = sF3.data.squeeze(0)
            cF3 = cF3.data.squeeze(0)
            csF3 = self.__feature_wct(cF3, sF3, cont_seg, styl_seg)
            Im3 = self.d3(csF3, cpool_idx, cpool1, cpool_idx2, cpool2)
            torch.cuda.empty_cache()

            sF2 = self.e2(styl_img)
            torch.cuda.empty_cache()
            cF2, cpool_idx, cpool = self.e2.forward_stem(Im3)[1:]
            torch.cuda.empty_cache()
            sF2 = sF2.data.squeeze(0)
            cF2 = cF2.data.squeeze(0)
            csF2 = self.__feature_wct(cF2, sF2, cont_seg, styl_seg)
            Im2 = self.d2(csF2, cpool_idx, cpool)
            torch.cuda.empty_cache()

            sF1 = self.e1(styl_img)
            torch.cuda.empty_cache()
            cF1 = self.e1(Im2)
            torch.cuda.empty_cache()
            sF1 = sF1.data.squeeze(0)
            cF1 = cF1.data.squeeze(0)
            csF1 = self.__feature_wct(cF1, sF1, cont_seg, styl_seg)
            Im1 = self.d1(csF1)
            torch.cuda.empty_cache()
        return Im1

    def __compute_label_info(self, cont_seg, styl_seg):
        if cont_seg.size == False or styl_seg.size == False:
            return
        max_label = np.max(cont_seg) + 1
        self.label_set = np.unique(cont_seg)
        self.label_indicator = np.zeros(max_label)
        for l in self.label_set:
            # if l == 0:
            #     continue
            is_valid = lambda a, b: a > 10 and b > 10 and a / b < 100 and b / a < 100
            o_cont_mask = np.where(cont_seg.reshape(cont_seg.shape[0] * cont_seg.shape[1]) == l)
            o_styl_mask = np.where(styl_seg.reshape(styl_seg.shape[0] * styl_seg.shape[1]) == l)
            self.label_indicator[l] = is_valid(o_cont_mask[0].size, o_styl_mask[0].size)

    def __feature_wct(self, cont_feat, styl_feat, cont_seg, styl_seg):
        cont_c, cont_h, cont_w = cont_feat.size(0), cont_feat.size(1), cont_feat.size(2)
        styl_c, styl_h, styl_w = styl_feat.size(0), styl_feat.size(1), styl_feat.size(2)
        cont_feat_view = cont_feat.view(cont_c, -1).clone()
        styl_feat_view = styl_feat.view(styl_c, -1).clone()

        if cont_seg.size == False or styl_seg.size == False:
            target_feature = self.__wct_core(cont_feat_view, styl_feat_view)
        else:
            target_feature = cont_feat.view(cont_c, -1).clone()
            if len(cont_seg.shape) == 2:
                t_cont_seg = np.asarray(Image.fromarray(cont_seg).resize((cont_w, cont_h), Image.NEAREST))
            else:
                t_cont_seg = np.asarray(Image.fromarray(cont_seg, mode='RGB').resize((cont_w, cont_h), Image.NEAREST))
            if len(styl_seg.shape) == 2:
                t_styl_seg = np.asarray(Image.fromarray(styl_seg).resize((styl_w, styl_h), Image.NEAREST))
            else:
                t_styl_seg = np.asarray(Image.fromarray(styl_seg, mode='RGB').resize((styl_w, styl_h), Image.NEAREST))

            for l in self.label_set:
                if self.label_indicator[l] == 0:
                    continue
                cont_mask = np.where(t_cont_seg.reshape(t_cont_seg.shape[0] * t_cont_seg.shape[1]) == l)
                styl_mask = np.where(t_styl_seg.reshape(t_styl_seg.shape[0] * t_styl_seg.shape[1]) == l)
                if cont_mask[0].size <= 0 or styl_mask[0].size <= 0:
                    continue

                cont_indi = torch.LongTensor(cont_mask[0])
                styl_indi = torch.LongTensor(styl_mask[0])
                if self.is_cuda:
                    cont_indi = cont_indi.cuda(0)
                    styl_indi = styl_indi.cuda(0)

                cFFG = torch.index_select(cont_feat_view, 1, cont_indi)
                sFFG = torch.index_select(styl_feat_view, 1, styl_indi)
                # print(len(cont_indi))
                # print(len(styl_indi))
                tmp_target_feature = self.__wct_core(cFFG, sFFG)
                # print(tmp_target_feature.size())
                if torch.__version__ >= "0.4.0":
                    # This seems to be a bug in PyTorch 0.4.0 to me.
                    new_target_feature = torch.transpose(target_feature, 1, 0)
                    new_target_feature.index_copy_(0, cont_indi,
                                                   torch.transpose(tmp_target_feature, 1, 0))
                    target_feature = torch.transpose(new_target_feature, 1, 0)
                else:
                    target_feature.index_copy_(1, cont_indi, tmp_target_feature)

        target_feature = target_feature.view_as(cont_feat)
        ccsF = target_feature.float().unsqueeze(0)
        return ccsF

    def __wct_core(self, cont_feat, styl_feat):
        cFSize = cont_feat.size()
        c_mean = torch.mean(cont_feat, 1)  # c x (h x w)
        c_mean = c_mean.unsqueeze(1).expand_as(cont_feat)
        cont_feat = cont_feat - c_mean

        iden = torch.eye(cFSize[0])  # .double()
        if self.is_cuda:
            iden = iden.cuda()
        contentConv = torch.mm(cont_feat, cont_feat.t()).div(cFSize[1] - 1) + iden
        # del iden
        c_u, c_e, c_v = torch.svd(contentConv, some=False)
        # c_e2, c_v = torch.eig(contentConv, True)
        # c_e = c_e2[:, 0]

        k_c = cFSize[0]
        for i in range(cFSize[0] - 1, -1, -1):
            if c_e[i] >= 0.00001:
                k_c = i + 1
                break

        sFSize = styl_feat.size()
        s_mean = torch.mean(styl_feat, 1)
        styl_feat = styl_feat - s_mean.unsqueeze(1).expand_as(styl_feat)
        styleConv = torch.mm(styl_feat, styl_feat.t()).div(sFSize[1] - 1)
        s_u, s_e, s_v = torch.svd(styleConv, some=False)

        k_s = sFSize[0]
        for i in range(sFSize[0] - 1, -1, -1):
            if s_e[i] >= 0.00001:
                k_s = i + 1
                break

        c_d = (c_e[0:k_c]).pow(-0.5)
        step1 = torch.mm(c_v[:, 0:k_c], torch.diag(c_d))
        step2 = torch.mm(step1, (c_v[:, 0:k_c].t()))
        whiten_cF = torch.mm(step2, cont_feat)

        s_d = (s_e[0:k_s]).pow(0.5)
        targetFeature = torch.mm(torch.mm(torch.mm(s_v[:, 0:k_s], torch.diag(s_d)), (s_v[:, 0:k_s].t())), whiten_cF)
        targetFeature = targetFeature + s_mean.unsqueeze(1).expand_as(targetFeature)
        return targetFeature

    @property
    def is_cuda(self):
        return next(self.parameters()).is_cuda

    def forward(self, *input):
        pass
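# Minimal sketch of driving the two modes of the class above (the flag name
# `--mode` mirrors self.args.mode; everything else here is illustrative, and
# the 16x checkpoint paths are hard-coded in __init__):
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--mode', type=str, default='original',
                    help='"original" (or "") for the full VGG model; any string '
                         'containing "16x" (optionally with "JointED") for the '
                         'slimmed, knowledge-distilled model')
args = parser.parse_args()

p_wct = PhotoWCT(args)  # picks encoders/decoders according to args.mode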
class PhotoWCT(nn.Module):
    def __init__(self):
        super(PhotoWCT, self).__init__()
        self.e1 = VGGEncoder(1)
        self.d1 = VGGDecoder(1)
        self.e2 = VGGEncoder(2)
        self.d2 = VGGDecoder(2)
        self.e3 = VGGEncoder(3)
        self.d3 = VGGDecoder(3)
        self.e4 = VGGEncoder(4)
        self.d4 = VGGDecoder(4)

    def transform(self, cont_img, styl_img, cont_seg, styl_seg):
        self.__compute_label_info(cont_seg, styl_seg)
        sF4, sF3, sF2, sF1 = self.e4.forward_multiple(styl_img)

        cF4, cpool_idx, cpool1, cpool_idx2, cpool2, cpool_idx3, cpool3 = self.e4(cont_img)
        sF4 = sF4.data.squeeze(0)
        cF4 = cF4.data.squeeze(0)
        csF4 = self.__feature_wct(cF4, sF4, cont_seg, styl_seg)
        Im4 = self.d4(csF4, cpool_idx, cpool1, cpool_idx2, cpool2, cpool_idx3, cpool3)

        cF3, cpool_idx, cpool1, cpool_idx2, cpool2 = self.e3(Im4)
        sF3 = sF3.data.squeeze(0)
        cF3 = cF3.data.squeeze(0)
        csF3 = self.__feature_wct(cF3, sF3, cont_seg, styl_seg)
        Im3 = self.d3(csF3, cpool_idx, cpool1, cpool_idx2, cpool2)

        cF2, cpool_idx, cpool = self.e2(Im3)
        sF2 = sF2.data.squeeze(0)
        cF2 = cF2.data.squeeze(0)
        csF2 = self.__feature_wct(cF2, sF2, cont_seg, styl_seg)
        Im2 = self.d2(csF2, cpool_idx, cpool)

        cF1 = self.e1(Im2)
        sF1 = sF1.data.squeeze(0)
        cF1 = cF1.data.squeeze(0)
        csF1 = self.__feature_wct(cF1, sF1, cont_seg, styl_seg)
        Im1 = self.d1(csF1)
        return Im1

    def __compute_label_info(self, cont_seg, styl_seg):
        if cont_seg.size == False or styl_seg.size == False:
            return
        max_label = np.max(cont_seg) + 1
        self.label_set = np.unique(cont_seg)
        self.label_indicator = np.zeros(max_label)
        for l in self.label_set:
            # if l == 0:
            #     continue
            is_valid = lambda a, b: a > 10 and b > 10 and a / b < 100 and b / a < 100
            o_cont_mask = np.where(cont_seg.reshape(cont_seg.shape[0] * cont_seg.shape[1]) == l)
            o_styl_mask = np.where(styl_seg.reshape(styl_seg.shape[0] * styl_seg.shape[1]) == l)
            self.label_indicator[l] = is_valid(o_cont_mask[0].size, o_styl_mask[0].size)

    def __feature_wct(self, cont_feat, styl_feat, cont_seg, styl_seg):
        cont_c, cont_h, cont_w = cont_feat.size(0), cont_feat.size(1), cont_feat.size(2)
        styl_c, styl_h, styl_w = styl_feat.size(0), styl_feat.size(1), styl_feat.size(2)
        cont_feat_view = cont_feat.view(cont_c, -1).clone()
        styl_feat_view = styl_feat.view(styl_c, -1).clone()

        if cont_seg.size == False or styl_seg.size == False:
            target_feature = self.__wct_core(cont_feat_view, styl_feat_view)
        else:
            target_feature = cont_feat.view(cont_c, -1).clone()
            t_cont_seg = np.asarray(Image.fromarray(cont_seg, mode='RGB').resize((cont_w, cont_h), Image.NEAREST))
            t_styl_seg = np.asarray(Image.fromarray(styl_seg, mode='RGB').resize((styl_w, styl_h), Image.NEAREST))

            for l in self.label_set:
                if self.label_indicator[l] == 0:
                    continue
                cont_mask = np.where(t_cont_seg.reshape(t_cont_seg.shape[0] * t_cont_seg.shape[1]) == l)
                styl_mask = np.where(t_styl_seg.reshape(t_styl_seg.shape[0] * t_styl_seg.shape[1]) == l)
                if cont_mask[0].size <= 0 or styl_mask[0].size <= 0:
                    continue

                cont_indi = torch.LongTensor(cont_mask[0])
                styl_indi = torch.LongTensor(styl_mask[0])
                if self.is_cuda:
                    cont_indi = cont_indi.cuda(0)
                    styl_indi = styl_indi.cuda(0)

                cFFG = torch.index_select(cont_feat_view, 1, cont_indi)
                sFFG = torch.index_select(styl_feat_view, 1, styl_indi)
                tmp_target_feature = self.__wct_core(cFFG, sFFG)
                target_feature.index_copy_(1, cont_indi, tmp_target_feature)

        target_feature = target_feature.view_as(cont_feat)
        ccsF = target_feature.float().unsqueeze(0)
        return ccsF

    def __wct_core(self, cont_feat, styl_feat):
        cFSize = cont_feat.size()
        c_mean = torch.mean(cont_feat, 1)  # c x (h x w)
        c_mean = c_mean.unsqueeze(1).expand_as(cont_feat)
        cont_feat = cont_feat - c_mean

        iden = torch.eye(cFSize[0])  # .double()
        if self.is_cuda:
            iden = iden.cuda()
        contentConv = torch.mm(cont_feat, cont_feat.t()).div(cFSize[1] - 1) + iden
        # del iden
        c_u, c_e, c_v = torch.svd(contentConv, some=False)
        # c_e2, c_v = torch.eig(contentConv, True)
        # c_e = c_e2[:, 0]

        k_c = cFSize[0]
        for i in range(cFSize[0] - 1, -1, -1):
            if c_e[i] >= 0.00001:
                k_c = i + 1
                break

        sFSize = styl_feat.size()
        s_mean = torch.mean(styl_feat, 1)
        styl_feat = styl_feat - s_mean.unsqueeze(1).expand_as(styl_feat)
        styleConv = torch.mm(styl_feat, styl_feat.t()).div(sFSize[1] - 1)
        s_u, s_e, s_v = torch.svd(styleConv, some=False)

        k_s = sFSize[0]
        for i in range(sFSize[0] - 1, -1, -1):
            if s_e[i] >= 0.00001:
                k_s = i + 1
                break

        c_d = (c_e[0:k_c]).pow(-0.5)
        step1 = torch.mm(c_v[:, 0:k_c], torch.diag(c_d))
        step2 = torch.mm(step1, (c_v[:, 0:k_c].t()))
        whiten_cF = torch.mm(step2, cont_feat)

        s_d = (s_e[0:k_s]).pow(0.5)
        targetFeature = torch.mm(torch.mm(torch.mm(s_v[:, 0:k_s], torch.diag(s_d)), (s_v[:, 0:k_s].t())), whiten_cF)
        targetFeature = targetFeature + s_mean.unsqueeze(1).expand_as(targetFeature)
        return targetFeature

    @property
    def is_cuda(self):
        return next(self.parameters()).is_cuda
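# The descending scan that picks k_c / k_s above keeps the leading eigenvalues
# >= 1e-5 (torch.svd returns singular values sorted in descending order).
# A vectorized equivalent, shown only as a sketch. Edge case: if every
# eigenvalue falls below the threshold, the original loop never breaks and
# leaves k at the full size, whereas this count would give 0.
import torch

e = torch.tensor([3.2, 0.9, 1e-3, 1e-6, 0.0])  # example spectrum, descending
k = int((e >= 0.00001).sum().item())            # same k_c as the loop computes
print(k)  # 3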