コード例 #1
0
def countex(epochs=5000, nbits=15, ncases=500, lrate=0.5, showint=500,
            mbs=20, vfrac=0.1, tfrac=0.1, vint=200, sm=True, bestk=1):
    """Train a network on cases from ``TFT.gen_vector_count_cases``.

    A ``Caseman`` is built around a zero-argument case generator that calls
    ``TFT.gen_vector_count_cases(ncases, nbits)``.  A ``Gann`` with layer
    sizes ``[nbits, nbits * 3, nbits + 1]`` is then created and run, after
    which ``TFT.fireup_tensorboard('probeview')`` is invoked.

    Args:
        epochs: First positional argument handed to ``Gann.run``.
        nbits: Second argument to the case generator; also determines the
            three layer sizes.
        ncases: First argument to the case generator.
        lrate: Forwarded to ``Gann`` as ``lrate``.
        showint: Forwarded to ``Gann`` as ``showint``.
        mbs: Forwarded to ``Gann`` as ``mbs``.
        vfrac: Forwarded to ``Caseman`` as ``vfrac``.
        tfrac: Forwarded to ``Caseman`` as ``tfrac``.
        vint: Forwarded to ``Gann`` as ``vint``.
        sm: Forwarded to ``Gann`` as ``softmax``.
        bestk: Forwarded to ``Gann.run`` as ``bestk``.

    Returns:
        The ``Gann`` instance, after ``run`` has been called on it.
    """
    def make_cases():
        # Deferred so that Caseman decides when the cases are generated.
        return TFT.gen_vector_count_cases(ncases, nbits)

    case_manager = Caseman(cfunc=make_cases, vfrac=vfrac, tfrac=tfrac)
    layer_sizes = [nbits, nbits * 3, nbits + 1]
    network = Gann(
        dims=layer_sizes,
        cman=case_manager,
        lrate=lrate,
        showint=showint,
        mbs=mbs,
        vint=vint,
        softmax=sm,
    )
    network.run(epochs, bestk=bestk)
    # NOTE(review): 'probeview' is presumably the TensorBoard log directory -- confirm in TFT.
    TFT.fireup_tensorboard('probeview')
    return network
コード例 #2
0
def autoex(epochs=300, nbits=4, lrate=0.03, showint=100, mbs=None,
           vfrac=0.1, tfrac=0.1, vint=100, sm=False, bestk=None):
    """Train a ``size -> nbits -> size`` network on one-hot cases.

    ``size`` is ``2**nbits``.  Cases come from
    ``TFT.gen_all_one_hot_cases(size)``.  Two probes and one grabvar are
    registered before training; the network is run for ``epochs`` and then
    continued via ``runmore`` for ``epochs * 2``.  Finally
    ``TFT.fireup_tensorboard('probeview')`` is invoked.

    Args:
        epochs: Passed to ``Gann.run``; twice this value goes to ``runmore``.
        nbits: Size of the middle layer; ``2**nbits`` is the size of the
            first and last layers.
        lrate: Forwarded to ``Gann`` as ``lrate``.
        showint: Forwarded to ``Gann`` as ``showint``.
        mbs: Forwarded to ``Gann`` as ``mbs``; any falsy value (``None``,
            ``0``) is replaced by ``2**nbits``.
        vfrac: Forwarded to ``Caseman`` as ``vfrac``.
        tfrac: Forwarded to ``Caseman`` as ``tfrac``.
        vint: Forwarded to ``Gann`` as ``vint``.
        sm: Forwarded to ``Gann`` as ``softmax``.
        bestk: Forwarded to both ``run`` and ``runmore`` as ``bestk``.

    Returns:
        The ``Gann`` instance after both training phases.
    """
    size = 2**nbits
    mbs = mbs or size  # falsy mbs falls back to `size`

    def make_cases():
        return TFT.gen_all_one_hot_cases(size)

    case_manager = Caseman(cfunc=make_cases, vfrac=vfrac, tfrac=tfrac)
    network = Gann(
        dims=[size, nbits, size],
        cman=case_manager,
        lrate=lrate,
        showint=showint,
        mbs=mbs,
        vint=vint,
        softmax=sm,
    )

    # Probes, in registration order:
    #   module 0 -- histogram and average of its incoming weights
    #   module 1 -- average and max value of its output vector
    probe_specs = (
        (0, 'wgt', ('hist', 'avg')),
        (1, 'out', ('avg', 'max')),
    )
    for module_index, variable, measures in probe_specs:
        network.gen_probe(module_index, variable, measures)

    # Grabvar on module 0's weights (to be displayed in its own matplotlib window).
    network.add_grabvar(0, 'wgt')

    network.run(epochs, bestk=bestk)
    network.runmore(epochs * 2, bestk=bestk)
    TFT.fireup_tensorboard('probeview')
    return network
コード例 #3
0
ファイル: main.py プロジェクト: nummer1/General-FFnetwork
def main():
    """Build, train and inspect a network from parsed arguments.

    Steps, in order:
      1. Create the argument parser and call ``parse()`` then ``organize()``.
      2. Construct a ``gann_base.Caseman`` and a ``gann_base.Gann`` from the
         parser's ``*_v`` attributes.
      3. Register weight/bias grabvars and histogram probes for the layers
         listed in ``dispw_v`` / ``dispb_v``.
      4. Run training for ``steps_v`` steps.
      5. Replace the grabvars with input/output grabvars for the layers in
         ``maplayers_v`` and call ``do_mapping()``.
      6. Draw a Hinton plot per mapped entry and, when ``best1_v`` is
         truthy, a dendrogram per mapped entry.
      7. Show the matplotlib figures and invoke
         ``TFT.fireup_tensorboard('probeview')``.
    """
    cli = argument_parser.argument_parser()
    cli.parse()
    cli.organize()

    # Caseman signature per original note: (self, cases, vfrac, tfrac, casefrac, mapsep)
    case_manager = gann_base.Caseman(
        cli.data_set_v,
        cli.vfrac_v,
        cli.tfrac_v,
        cli.casefrac_v,
        cli.mapbs_v,
    )

    # Gann signature per original note:
    # (self, dims, cman, afunc, ofunc, cfunc, optimizer, lrate, wrange, vint,
    #  mbs, usevsi, showint=None)
    net = gann_base.Gann(
        cli.dims_v,
        case_manager,
        cli.afunc_v,
        cli.ofunc_v,
        cli.cfunc_v,
        cli.optimizer_v,
        cli.lrate_v,
        cli.wrange_v,
        cli.vint_v,
        cli.mbs_v,
        cli.usevsi_v,
        showint=cli.steps_v - 1,
    )

    def watch(layers, kind):
        # Register a grabvar and a histogram probe of `kind` for each layer.
        for layer in layers:
            net.add_grabvar(layer, type=kind)
            net.gen_probe(layer, kind, 'hist')

    watch(cli.dispw_v, 'wgt')
    watch(cli.dispb_v, 'bias')

    # Run, then map.
    net.run(
        steps=cli.steps_v,
        sess=None,
        continued=False,
        bestk=cli.best1_v,
    )

    net.remove_grabvars()
    for layer in cli.maplayers_v:
        # Layer 0 grabs the 'in' variable of module 0; any other layer k
        # grabs the 'out' variable of module k - 1.
        target, kind = (layer, 'in') if layer == 0 else (layer - 1, 'out')
        net.add_grabvar(target, type=kind, add_figure=False)

    res, labs = net.do_mapping()

    results = []
    for idx in range(len(res[0])):
        # Collect the idx-th entry of every element of `res` into one array.
        stacked = np.array([entry[idx] for entry in res])
        # Keep axes 0 and 2 only.
        # NOTE(review): assumes `stacked` is 3-D with a length-1 middle axis -- confirm.
        matrix = stacked.reshape(stacked.shape[0], stacked.shape[2])
        heading = "mapping test output of layer " + str(cli.maplayers_v[idx])
        TFT.hinton_plot(matrix, title=heading)
        results.append(matrix)

    for idx, matrix in enumerate(results):
        # DENDROGRAM
        # if cli.maplayers_v[idx] in cli.mapdend_v:
        if not cli.best1_v:
            continue
        int_labels = [TFT.one_hot_to_int(label) for label in labs]
        heading = "Dendrogram " + str(cli.maplayers_v[idx])
        TFT.dendrogram(matrix, int_labels, title=heading)

    gann_base.PLT.show()
    TFT.fireup_tensorboard('probeview')