Exemplo n.º 1
0
def weights_sparsity_summary(model, opt=None):
    """Return a parameter count derived from distiller's weights-sparsity table.

    If *model* exposes the wrapped network as ``model.module`` (as
    ``DataParallel``-style wrappers do), the inner network is summarised;
    otherwise *model* itself is summarised.

    Args:
        model: The network to analyse, optionally wrapped in an object that
            exposes the underlying network as ``model.module``.
        opt: Unused; kept only so existing callers keep working.

    Returns:
        Half of the sum of the ``"NNZ (dense)"`` column of the summary table
        returned by ``distiller.weights_sparsity_summary``.
    """
    # Resolve the wrapped network up front rather than wrapping the distiller
    # call in ``try/except AttributeError``. The old form also caught
    # AttributeErrors raised *inside* distiller and silently retried on the
    # wrapper object, masking the real failure.
    net = getattr(model, "module", model)
    df, _total_sparsity = distiller.weights_sparsity_summary(
        net, return_total_sparsity=True
    )
    # NOTE(review): the column sum is halved presumably because distiller
    # appends a "Total sparsity:" row whose "NNZ (dense)" already equals the
    # sum of all per-layer rows - confirm against the installed distiller.
    return df["NNZ (dense)"].sum() // 2
Exemplo n.º 2
0
def save_model_stats(model, save_path, name=None):
    """Write distiller's sparsity and performance summaries for *model* as CSV.

    Two files are written into *save_path*:
    ``<model.name>[_<name>].sparsity.csv`` and
    ``<model.name>[_<name>].performance.csv``.

    Args:
        model: Network to summarise. Its ``name`` attribute is used as the
            file-name stem.
        save_path: Output directory; created (including parents) if absent.
        name: Optional suffix appended to the stem as ``_<name>``. Falsy
            values (``None``, ``""``) mean no suffix.
    """
    # exist_ok=True replaces the racy os.path.exists() + os.makedirs() pair.
    os.makedirs(save_path, exist_ok=True)

    sparsity_df, _ = distiller.weights_sparsity_summary(model, True)
    # The performance summary is computed from a dummy input of fixed shape
    # 1x3x352x352.
    # NOTE(review): the dummy tensor is created on the CPU - confirm the model
    # is on the CPU (or that distiller moves the input) when calling this.
    performance_df = distiller.model_performance_summary(
        model, torch.rand([1, 3, 352, 352]), 1)

    # Build the shared file-name stem once instead of duplicating both
    # file names in two branches.
    stem = '{}_{}'.format(model.name, name) if name else '{}'.format(model.name)
    sparsity_df.to_csv(
        os.path.join(save_path, '{}.sparsity.csv'.format(stem)))
    performance_df.to_csv(
        os.path.join(save_path, '{}.performance.csv'.format(stem)))
Exemplo n.º 3
0
def sparsity_display(model, sparsity_file):
    """
        Analyse the sparsity of a network's weights, i.e. how many weights are zero.

        Builds distiller's weights-sparsity table, drops the 'Cols (%)' and
        'Rows (%)' columns, and writes the result to *sparsity_file* in CSV
        format (regardless of the file extension, so use ``.csv``).

        Arguments:
            model (class Net):       A trained model.
            sparsity_file (str):     Name and location of the file that
                                     receives the sparsity analysis (CSV).

        Examples:
            >>> from apputils.platform_summaries import *
            >>> sparsity_display(model, 'spars_file.csv')
    """
    df_sparsity = distiller.weights_sparsity_summary(model)
    # Remove these two columns, which contain uninteresting values.
    df_sparsity = df_sparsity.drop(columns=['Cols (%)', 'Rows (%)'])

    df_sparsity.to_csv(sparsity_file)