# Example #1
def distance(y0, y1):
    """Return the KL divergence between two categorical distributions.

    Each input is treated as raw logits and normalized with softmax before
    the Kullback-Leibler divergence is computed.
    """
    p = F.softmax(y0)
    q = F.softmax(y1)
    return F.kl_multinomial(p, q)
def CrossDomainCorrespondence(_fake_var_s,
                              _fake_var_t,
                              _choice_num=4,
                              _layer_fix_switch=False):
    """Build the Cross-Domain Correspondence (CDC) loss graph.

    Args:
        _fake_var_s (nn.Variable): fake image variable from the source model.
        _fake_var_t (nn.Variable): fake image variable from the target model.
        _choice_num (int): number of feature layers gated into the loss.
        _layer_fix_switch (bool): if True, always select the last
            ``_choice_num`` layers; otherwise sample a combination via
            ``one_hot_combination``.

    Returns:
        nn.Variable: scalar CDC loss, shape=().
    """
    # Per-layer intermediate features of each generator.
    # Per the original notes: list of 12 entries, each e.g.
    # (batch_size, 64, 256, 256) — TODO confirm against get_feature_list.
    features_s = get_feature_list(_fake_var_s)
    features_t = get_feature_list(_fake_var_t)
    num_layers = len(features_s)

    # Gate vector (shape=(12,)) selecting which layers contribute.
    if _layer_fix_switch:
        # Fixed selection: zeros for the early layers, ones for the last
        # _choice_num layers.
        gate = [0] * (num_layers - _choice_num) + [1] * _choice_num
        feature_gate_var = nn.Variable.from_numpy_array(np.array(gate))
    else:
        feature_gate_var = one_hot_combination(num_layers, _choice_num)

    # Per-layer KL divergence between the softmaxed pairwise
    # cosine-similarity matrices of source and target features.
    kl_terms = []
    for idx, (feat_s, feat_t) in enumerate(zip(features_s, features_t)):
        # Reshape for pairwise comparison across the batch.
        row_s = make_broadcast_matrix(feat_s)
        col_s = make_symmetric_matrix(feat_s)
        row_t = make_broadcast_matrix(feat_t)
        col_t = make_symmetric_matrix(feat_t)

        # Softmax over cosine similarities,
        # shape=(batch_size, batch_size - 1).
        sim_s = F.softmax(CosineSimilarity(row_s, col_s, _index=2), axis=1)
        sim_t = F.softmax(CosineSimilarity(row_t, col_t, _index=2), axis=1)

        kl = F.sum(F.kl_multinomial(sim_s, sim_t, base_axis=1))
        kl_terms.append(F.reshape(kl, [1]))

        # Name the graph nodes so they are identifiable when debugging.
        feat_s.name = 'Feature/source/{}'.format(idx)
        feat_t.name = 'Feature/target/{}'.format(idx)
        row_s.name = 'Feature/source/{}/i_matrix'.format(idx)
        col_s.name = 'Feature/source/{}/j_matrix'.format(idx)
        row_t.name = 'Feature/target/{}/i_matrix'.format(idx)
        col_t.name = 'Feature/target/{}/j_matrix'.format(idx)
        sim_s.name = 'CosineSimilarity/source/{}'.format(idx)
        sim_t.name = 'CosineSimilarity/target/{}'.format(idx)
        kl.name = 'Kullback-Leibler_Divergence/{}'.format(idx)

    kl_all = F.concatenate(*kl_terms, axis=0)

    # Gated sum over layers -> scalar loss.
    CDC_loss = F.sum(kl_all * feature_gate_var)
    CDC_loss.name = 'CrossDomainCorrespondence_Output'

    return CDC_loss