def distance(y0, y1):
    """Kullback-Leibler divergence between two categorical distributions.

    Both arguments are treated as logits: each is converted to a
    probability distribution via softmax before the divergence is taken.

    Args:
        y0 (nn.Variable): logits of the reference distribution.
        y1 (nn.Variable): logits of the compared distribution.

    Returns:
        nn.Variable: KL(softmax(y0) || softmax(y1)).
    """
    p = F.softmax(y0)
    q = F.softmax(y1)
    return F.kl_multinomial(p, q)
def CrossDomainCorrespondence(_fake_var_s, _fake_var_t, _choice_num=4, _layer_fix_switch=False):
    """Cross-Domain Correspondence (CDC) loss between source and target fakes.

    For each intermediate feature map of the two generators, the within-batch
    pairwise cosine-similarity distribution is computed for both models and
    the KL divergence between those two distributions is taken. Only
    ``_choice_num`` of the layers contribute to the final loss, selected
    either by ``one_hot_combination`` (default) or fixed to the last layers.

    Args:
        _fake_var_s (nn.Variable): fake image variable from the source model.
        _fake_var_t (nn.Variable): fake image variable from the target model.
        _choice_num (int): number of feature layers contributing to the loss.
        _layer_fix_switch (bool): if True, always gate the last
            ``_choice_num`` layers instead of choosing via
            ``one_hot_combination``.

    Returns:
        nn.Variable: scalar CDC loss, shape ``()``.
    """
    # [get feature keys]
    # Each list has 12 entries; one entry's shape is
    # (batch_size, 64, 256, 256).  # NOTE(review): per original comment —
    # confirm against get_feature_list.
    feature_list_s = get_feature_list(_fake_var_s)
    feature_list_t = get_feature_list(_fake_var_t)
    num_layers = len(feature_list_s)

    if not _layer_fix_switch:
        # Gate variable selecting _choice_num layers; .shape=(12,)
        feature_gate_var = one_hot_combination(num_layers, _choice_num)
    else:
        # Deterministic gate: zeros for the first layers, ones for the
        # last _choice_num layers.
        feature_gate_var = nn.Variable.from_numpy_array(
            np.array([0] * (num_layers - _choice_num) + [1] * _choice_num))

    # [Cosine Similarity & Integrate KL divergence]
    KL_var_list = []
    for i, (feat_s, feat_t) in enumerate(zip(feature_list_s, feature_list_t)):
        # --- change shape ---
        i_vector_s = make_broadcast_matrix(feat_s)
        j_vector_s = make_symmetric_matrix(feat_s)
        i_vector_t = make_broadcast_matrix(feat_t)
        j_vector_t = make_symmetric_matrix(feat_t)

        # --- cosine similarity ---
        # .shape=(batch_size, batch_size - 1)
        CS_var_s = F.softmax(CosineSimilarity(i_vector_s, j_vector_s, _index=2), axis=1)
        CS_var_t = F.softmax(CosineSimilarity(i_vector_t, j_vector_t, _index=2), axis=1)

        # Per-layer KL divergence reduced to a single element so all
        # layers can be concatenated into one (12,) vector below.
        KL_var = F.sum(F.kl_multinomial(CS_var_s, CS_var_t, base_axis=1))
        KL_var_list.append(F.reshape(KL_var, [1, ]))

        # --- name each variable for debug ---
        feat_s.name = 'Feature/source/{}'.format(i)
        feat_t.name = 'Feature/target/{}'.format(i)
        i_vector_s.name = 'Feature/source/{}/i_matrix'.format(i)
        j_vector_s.name = 'Feature/source/{}/j_matrix'.format(i)
        i_vector_t.name = 'Feature/target/{}/i_matrix'.format(i)
        j_vector_t.name = 'Feature/target/{}/j_matrix'.format(i)
        CS_var_s.name = 'CosineSimilarity/source/{}'.format(i)
        CS_var_t.name = 'CosineSimilarity/target/{}'.format(i)
        KL_var.name = 'Kullback-Leibler_Divergence/{}'.format(i)

    KL_var_all = F.concatenate(*KL_var_list, axis=0)

    # [Calculate final loss]
    # Zero out the non-selected layers, then sum to a scalar.
    CDC_loss = F.sum(KL_var_all * feature_gate_var)
    CDC_loss.name = 'CrossDomainCorrespondence_Output'
    return CDC_loss