Пример #1
0
def run_gridW2(name,
               allLayers,
               states,
               deviceId,
               reports=None,
               N=2048):
    """Sweep sigma_W per depth, training with a sigma_W-specific quantizer.

    For every depth ``L`` in ``allLayers`` the reference value ``W`` returned
    by ``genq.optimal_spacing(states)`` defines 16 evenly spaced sigma_W
    values in ``[0.5 * W, 1.5 * W]``. Each value ``w`` is trained through
    ``main`` using the quantizer from ``opt_quantW(states, w)``. The results
    dict is checkpointed with ``save_obj`` after every run.

    Args:
        name: Name passed to ``save_obj`` for the results dict.
        allLayers: Iterable of depths, passed to ``main`` as ``-L``.
        states: Number of quantization states, passed to
            ``genq.optimal_spacing`` and ``opt_quantW``.
        deviceId: CUDA device index passed to ``torch.cuda.set_device``.
        reports: Forwarded to ``main(reports=...)``. Defaults to
            ``[100, 200, 400, 800, 1600]``; a fresh list is built per call
            instead of sharing one mutable default across calls.
        N: Value of the ``-N`` flag passed to ``main``.

    Returns:
        None. The dict saved via ``save_obj`` maps ``(L, w)`` to ``main``'s
        return value, where ``w`` is an element of ``torch.linspace``.
    """
    if reports is None:
        reports = [100, 200, 400, 800, 1600]
    torch.cuda.set_device(deviceId)
    results = {}
    for L in allLayers:
        torch.cuda.empty_cache()
        # Only the reference sigma_W is needed; the grid spans +/-50% of it.
        _, _, W, _ = genq.optimal_spacing(states)
        Ws = torch.linspace(W * 0.5, W * 1.5, 16)
        for w in Ws:
            # perf_counter is monotonic, so it is the right clock for durations.
            start = time.perf_counter()
            print("%d states, %d layers, W = %.02f" % (states, L, w))
            # The second return value is not needed; don't clobber W with it.
            quant, _ = opt_quantW(states, w)
            lr = pick_lr(L)
            res = main('--batch-size 32 --epochs 5 --lr %f -N %d -L %d' %
                       (lr, N, L),
                       act=quant,
                       sigmaW=w,
                       reports=reports)
            duration = time.perf_counter() - start
            print("Duration: %.02f" % duration)
            results[(L, w)] = res
            # Checkpoint after every run so a crash loses at most one result.
            save_obj(results, name)
Пример #2
0
def run_gridW(name,
              allLayers,
              states,
              deviceId,
              reports=None,
              N=2048):
    """Sweep sigma_W per depth with a fixed quantizer, logging a chi estimate.

    For every depth ``L`` in ``allLayers`` the reference value ``W`` returned
    by ``genq.optimal_spacing(states)`` defines 16 evenly spaced sigma_W
    values in ``[0.5 * W, 1.5 * W]``. For each value ``w``, ``chi`` is
    estimated with ``genq.Estimate`` and printed. The network is then trained
    through ``main`` with the quantizer from ``opt_quant(states)``, which
    does not depend on ``w``. The results dict is checkpointed with
    ``save_obj`` after every run.

    Args:
        name: Name passed to ``save_obj`` for the results dict.
        allLayers: Iterable of depths, passed to ``main`` as ``-L``.
        states: Number of quantization states, passed to
            ``genq.optimal_spacing`` and ``opt_quant``.
        deviceId: CUDA device index passed to ``torch.cuda.set_device``.
        reports: Forwarded to ``main(reports=...)``. Defaults to
            ``[100, 200, 400, 800, 1600]``; a fresh list is built per call
            instead of sharing one mutable default across calls.
        N: Value of the ``-N`` flag passed to ``main``.

    Returns:
        None. The dict saved via ``save_obj`` maps ``(L, w)`` to ``main``'s
        return value, where ``w`` is an element of ``torch.linspace``.
    """
    if reports is None:
        reports = [100, 200, 400, 800, 1600]
    torch.cuda.set_device(deviceId)
    results = {}
    for L in allLayers:
        torch.cuda.empty_cache()
        offsets, scale, W, _ = genq.optimal_spacing(states)
        Ws = torch.linspace(W * 0.5, W * 1.5, 16)
        for w in Ws:
            # NOTE(review): this result was never used. The call is kept only
            # in case genq.optimize_ste has side effects (e.g. consuming RNG
            # state) that later runs depend on -- confirm, then drop it.
            genq.optimize_ste(1.0, w, scale, samples=1000)
            chi = genq.Estimate(drnn.RnnArgs(SigmaW=w,
                                             SigmaU=0.0,
                                             SigmaB=0.001),
                                offsets,
                                scale,
                                steps=100)
            # perf_counter is monotonic, so it is the right clock for durations.
            start = time.perf_counter()
            print("%d states, %d layers, chi: %.05f, W = %.02f" %
                  (states, L, chi, w))
            # The second return value is not needed; don't clobber W with it.
            quant, _ = opt_quant(states)
            lr = pick_lr(L)
            res = main('--batch-size 32 --epochs 5 --lr %f -N %d -L %d' %
                       (lr, N, L),
                       act=quant,
                       sigmaW=w,
                       reports=reports)
            duration = time.perf_counter() - start
            print("Duration: %.02f" % duration)
            results[(L, w)] = res
            # Checkpoint after every run so a crash loses at most one result.
            save_obj(results, name)
Пример #3
0
def opt_quant(states):
    """Build a CUDA-resident ``Quant`` activation for ``states`` levels.

    The spacing parameters come from ``genq.optimal_spacing(states)``. The
    offsets, offset count and scale are moved to the GPU as float tensors.
    The threshold is fixed at ``1.0``. The slope handed to ``Quant`` is the
    reciprocal of the value found by ``genq.optimize_ste``.

    Returns:
        A ``(quant, W)`` pair: the ``Quant`` instance and the third value
        returned by ``genq.optimal_spacing`` (presumably the reference
        sigma_W that callers sweep around -- confirm in genq).
    """
    offsets, scale, W, Qu = genq.optimal_spacing(states)

    offsets_gpu = offsets.cuda()
    n_offsets_gpu = torch.Tensor([offsets_gpu.numel()]).cuda()
    scale_gpu = torch.Tensor([scale]).cuda()

    # Fixed threshold (an earlier variant derived it from ``scale``).
    threshold = 1.0
    ste_slope = genq.optimize_ste(Qu, W, threshold, samples=1000)
    inverse_slope_gpu = torch.Tensor([1.0 / ste_slope]).cuda()

    quantizer = Quant(offsets_gpu, n_offsets_gpu, scale_gpu, threshold,
                      inverse_slope_gpu)
    return quantizer, W
Пример #4
0
def showContour(name="grid_results",
                steps=100,
                test=0,
                fig=plt.figure(),
                subplot=111,
                levels=7,
                scan=0):
    """Draw a filled-contour map of grid-search results with depth-scale guides.

    ``examine`` reduces the saved results to a grid that is contoured against
    ``np.meshgrid(states, layers)``. It is therefore expected to have shape
    ``(len(layers), len(states))``. Dashed guide curves at 4, 6 and 9 times a
    depth scale are overlaid and labelled ``a xi``.

    Args:
        name: Name passed to ``load_obj`` to load the saved results.
        steps: Forwarded to ``examine``. Also shown in the panel title for
            any layout other than a single 1x1 panel.
        test: Forwarded to ``examine``.
        fig: Figure to draw into. NOTE: this default is evaluated once, when
            the function is defined, so every call that omits ``fig`` draws
            into that same shared figure.
        subplot: Three-digit ``add_subplot`` spec, e.g. 111 or 121.
        levels: Forwarded to ``contourf`` as its levels argument.
        scan: If 0, the x axis is the number of states and the guides use
            ``genq.DepthScale``. Otherwise ``scan`` is a state count. Its
            optimal spacing is used to estimate chi for each sigma_W on the
            x axis, and the guides use ``-1 / log(chi)``.
    """
    results, layers, states = examine(load_obj(name), steps, test)

    l = len(layers)
    s = len(states)
    S, L = np.meshgrid(states, layers)
    ax = fig.add_subplot(subplot)
    # Resize the figure we draw into, not whatever pyplot considers current.
    fig.set_size_inches(5, 5)

    ax.set_yscale("log")
    cmap = plt.get_cmap("hot")
    CS = ax.contourf(S,
                     L,
                     results,
                     levels,
                     colors=['black', 'darkred', 'firebrick', 'red'])
    # Everything below targets ``ax`` explicitly, so labels and curves land
    # on this subplot even when ``fig`` is not pyplot's current figure.
    ax.set_xlabel("# States", fontsize=18)
    if subplot % 10 == 1:
        # Only the first panel (index 1) carries the y label.
        ax.set_ylabel("Layers", fontsize=18)
    DS = np.asarray([genq.DepthScale(s) for s in states])
    ax.set_ylim((layers[0], layers[-1]))
    strlist = [str(x) for x in states]

    # Tick.label1 replaces the deprecated Tick.label alias.
    for tick in ax.xaxis.get_major_ticks():
        tick.label1.set_fontsize(14)
    for tick in ax.yaxis.get_major_ticks():
        tick.label1.set_fontsize(14)

    if subplot // 10 > 11:
        # Any grid other than 1x1 (e.g. 121, 122) gets a per-panel title.
        ax.set_title("%d Training Steps" % steps)

    if scan == 0:
        idx = int(len(states) / 2)
        midstate = states[idx]
        if subplot == 111:
            fig.colorbar(CS)
        for a, c in [(4, 'springgreen'), (6, 'y'), (9, 'w')]:
            # Lower-case keyword names: matplotlib no longer accepts
            # case-insensitive aliases such as lineStyle / fontSize.
            ax.semilogy(states, a * DS, linestyle='--', color=c)
            point = (midstate, a * DS[idx])
            # Walk left until the label sits below 80% of the deepest layer.
            while idx > 0 and a * DS[idx] > 0.8 * layers[-1]:
                idx = idx - 1
                midstate = states[idx]

            ax.annotate("%d$\\xi$" % a, (midstate, a * DS[idx]),
                        color=c,
                        fontsize=18,
                        textcoords="offset points",
                        xytext=(-7, 7),
                        ha='center')
            # Stagger the next label one state further left.
            # NOTE(review): idx can reach -1 here, which wraps to the last state.
            idx = idx - 1
            midstate = states[idx]

    else:

        results_misc, layers, sigmaWs = examine(load_obj(name),
                                                steps=steps,
                                                test=test)
        idx = int(len(sigmaWs) * 3 / 4)
        midstate = sigmaWs[idx]
        colors = ['k', 'k', 'r', 'w']

        offsets, scale, W, Qu = genq.optimal_spacing(scan)

        Chis_for_DS = np.asarray([
            genq.Estimate(drnn.RnnArgs(SigmaW=w, SigmaU=0.0, SigmaB=0.001),
                          offsets,
                          scale,
                          steps=100) for w in sigmaWs
        ])

        # Guide scale for the sigma_W axis: -1 / log(chi) per sigma_W value.
        DS = np.asarray([-1.0 / np.log(chi) for chi in Chis_for_DS])
        # Raw string: "\s" is not a valid Python escape sequence.
        ax.set_xlabel(r"$\sigma_w$")

        # Each successive label is allowed higher: 40%, 60%, 80% of the
        # deepest layer.
        perct = 0.4
        for a, c in [(4, 'c'), (6, 'y'), (9, 'w')]:
            ax.semilogy(sigmaWs, a * DS, linestyle='--', color=c)
            point = (midstate, a * DS[idx])
            while idx > 0 and a * DS[idx] > perct * layers[-1]:
                idx = idx - 1
                midstate = sigmaWs[idx]
            perct += 0.2
            ax.annotate("%d$\\xi$" % a, (midstate, a * DS[idx]),
                        color=c,
                        fontsize=16,
                        textcoords="offset points",
                        xytext=(0, 10),
                        ha='center')