コード例 #1
0
def test_saved_model_folder(dirname,feats,output,filt=False):
    """
    Test a saved model by loading it and applying it to features.

    INPUT
      dirname  - directory of the saved model; 'model.p' inside it is
                 loaded with ANALYZE.unpickle, and ANALYZE.traceback_stats
                 is read from it
      feats    - features passed to model.predicts()
      output   - output file, used with print_write()
      filt     - True if the model uses filtering: nPatterns is then
                 read from model._nPatternUsed instead of traceback stats
    RETURN
      average dist
      nPatterns
      nIters
      totaltime
    """
    print_write('*** MODEL SAVED IN: '+dirname+' ***',output)
    if filt:
        print_write('model uses FILTERING',output)
    # load model
    model = ANALYZE.unpickle(os.path.join(dirname,'model.p'))
    print_write('model loaded',output)
    # find nIters (#tracks), nPatterns, totaltime
    nIters, nPatterns, totalTime = ANALYZE.traceback_stats(dirname)
    if filt:
        # filtering model: report the pattern count stored on the model
        nPatterns = model._nPatternUsed
    print_write('nIters (=nTracks): '+str(nIters),output)
    print_write('nPatterns: '+str(nPatterns),output)
    print_write('total time ran: '+str(totalTime),output)
    # predict; only the distances are needed here
    _, dists = model.predicts(feats)
    avg_dist = np.average(dists)  # computed once, printed and returned
    print_write('prediction done, avg. dist: '+str(avg_dist),output)
    return avg_dist,nPatterns,nIters,totalTime
コード例 #2
0
def test_saved_model_folder(dirname, feats, output):
    """
    Test a saved model by loading it and applying it to features.

    INPUT
      dirname  - directory of the saved model; 'model.p' inside it is
                 loaded with ANALYZE.unpickle, and ANALYZE.traceback_stats
                 is read from it
      feats    - features passed to model.predicts()
      output   - output file, used with print_write()
    RETURN
      average dist
      nPatterns
      nIters
      totaltime
    """
    print_write('*** MODEL SAVED IN: ' + dirname + ' ***', output)
    # load model
    model = ANALYZE.unpickle(os.path.join(dirname, 'model.p'))
    print_write('model loaded', output)
    # find nIters (#tracks), nPatterns, totaltime
    nIters, nPatterns, totalTime = ANALYZE.traceback_stats(dirname)
    print_write('nIters (=nTracks): ' + str(nIters), output)
    print_write('nPatterns: ' + str(nPatterns), output)
    print_write('total time ran: ' + str(totalTime), output)
    # predict; only the distances are needed here
    _, dists = model.predicts(feats)
    avg_dist = np.average(dists)  # computed once, printed and returned
    print_write('prediction done, avg. dist: ' + str(avg_dist), output)
    return avg_dist, nPatterns, nIters, totalTime
コード例 #3
0
def train(savedmodel, expdir='', pSize=8, usebars=2, keyInv=True,
          songKeyInv=False, positive=True, do_resample=True, partialbar=0,
          lrate=1e-5, nThreads=4, oracle='EN', artistsdb='', matdir='',
          nIterations=1e7, useModel='VQ', autobar=False, randoffset=False):
    """
    Performs training
    Grab track data from oracle
    Pass them to model that updates itself

    INPUT
      savedmodel    - previously saved model directory, to restart it
                      or matlab file, to start from a codebook
      expdir        - experiment directory, where to save experiments
      pSize         - pattern size
      usebars       - how many bars per pattern
      keyInv        - perform 'key invariance' on patterns
      positive      - replace negative values by 0
      do_resample   - if false, crop
      partialbar    - actual size, divides pSize
      lrate         - learning rate
      nThreads      - number of threads for the oracle, default=4
      oracle        - EN (EchoNest) or MAT (matfiles)
      artistdb      - SQLlite database containing artist names
      matdir        - matfiles directory, for oracle MAT
      nIterations   - maximum number of iterations
      useModel      - which model to use: 'VQ', 'VQFILT'

    Saves everything when done.
    """

    # create the StatLog object
    statlog = StatLog()

    # count iteration
    #       main   - iteration since this instance started
    #     global   - iteration adding those from the savedmodel
    # last_printed - for verbose purposes
    main_iterations = 0
    global_iterations = 0
    last_printed_iter = 1

    # start from saved model
    if os.path.isdir(savedmodel):
        # load model
        assert os.path.exists(os.path.join(savedmodel,'model.p')),'loading saved model, model.p does not exist? %s' % (os.path.join(savedmodel,'model.p'),)
        f = open(os.path.join(savedmodel,'model.p'),'r')
        model_unp = pickle.Unpickler(f)
        model = model_unp.load()
        f.close()
        # load params, savedmodel will be modified
        assert os.path.exists(os.path.join(savedmodel,'params.p')),'loading saved model, params.p does not exist? %s' % (os.path.join(savedmodel,'params.p'),)
        f = open(os.path.join(savedmodel,'params.p'),'r')
        param_unp = pickle.Unpickler(f)
        oldparams = param_unp.load()
        f.close()
        for k in oldparams.keys():
            if k == 'savedmodel': # special case
                continue
            exec_str = k + ' = oldparams["'+k+'"]'
            exec( exec_str )
            print 'from saved model,',k,'=',eval(k)
        # get global_iterations
        global_iterations,tmp1,tmp2 = ANALYZE.traceback_stats(savedmodel)
        
    # initialized model from codebook
    elif os.path.isfile(savedmodel):
        codebook = load_codebook(savedmodel)
        assert codebook != None,'Could not load codebook in: %s.'%savedmodel
        if useModel == 'VQ':
            model = MODEL.Model(codebook)
        elif useModel == 'VQFILT':
            model = MODEL.ModelFilter(codebook)
        else:
            assert False, 'wrong model codename: %s.'%useModel
        statlog.startFromScratch()
    # problem
    else:
        assert False,'saved model does not exist: %s.'%savedmodel

    # creates a dictionary with all parameters
    params = {'savedmodel':savedmodel, 'expdir':expdir,
              'pSize':pSize, 'usebars':usebars,
              'keyInv':keyInv, 'songKeyInv':songKeyInv,
              'positive':positive, 'do_resample':do_resample,
              'partialbar':partialbar, 'lrate':lrate,
              'nThreads':nThreads, 'oracle':oracle,
              'artistsdb':artistsdb, 'matdir':matdir,
              'nIterations':nIterations, 'useModel':useModel,
              'autobar':autobar,'randoffset':randoffset}

    # creates the experiment folder
    if not os.path.isdir(expdir):
        print 'creating experiment directory:',expdir
        os.mkdir(expdir)

    # create oracle
    if autobar:
        assert oracle=='MAT','autobar implemented only for matfiles oracle'
    if oracle == 'EN':
        oracle = oracle_en.OracleEN(params,artistsdb)
    elif oracle == 'MAT':
        oracle = oracle_matfiles.OracleMatfiles(params,matdir)
    else:
        assert False, 'wrong oracle codename: %s.'%oracle

    # starttime and save time
    starttime = time.time()
    last_save = starttime

    # for estimate of distance
    dist_estimate = deque()
    dist_estimate_len = 2000
    
    # main algorithm
    try:
        while True:
            # increment iterations
            main_iterations += 1
            global_iterations += 1
            if main_iterations == int(np.ceil(last_printed_iter * 1.1)) and len(dist_estimate) > 0:
                print main_iterations,'/',global_iterations,'iterations (local/global), approx. avg dist:',np.average(dist_estimate)
                last_printed_iter = main_iterations
            statlog.iteration()
            if global_iterations > nIterations:
                raise StopIteration
            # get features from the oracle
            if not autobar:
                feats = oracle.next_track()
            else:
                feats = oracle.next_track(auto_bar=model)
            # check features, remove empty patterns            
            if feats == None:
                continue
            feats = feats[np.nonzero(np.sum(feats,axis=1))]
            if feats.shape[0] == 0:
                continue
            assert not np.isnan(feats).any(),'features have NaN???'
            # stats
            statlog.patternsSeen(feats.shape[0])
            # update model
            avg_dist = model.update(feats,lrate=lrate)
            assert not np.isnan(avg_dist)
            # add to dist_estimate
            dist_estimate.append(avg_dist)
            if len(dist_estimate) > dist_estimate_len:
                dist_estimate.popleft()
            # save
            if should_save(starttime,last_save):
                savedir = save_experiment(expdir,model,starttime,statlog,params)
                last_save = time.time()

    # error, save and quit
    except:
        print ''
        exc_type, exc_value, exc_traceback = sys.exc_info()
        # normal stop or not? normal = StopIteration
        exit_code = 1
        if str(exc_type) == "exceptions.StopIteration" or str(exc_type) == "<type 'exceptions.StopIteration'>":
            exit_code = 0 # normal exit
        print "ERROR:", exc_type
        if str(exc_type) != "<type 'exceptions.KeyboardInterrupt'>":
            print '********** DEBUGGING INFO *******************'
            formatted_lines = traceback.format_exc().splitlines()
            if len(formatted_lines) > 2:
                print formatted_lines[-3]
            if len(formatted_lines) > 1:
                print formatted_lines[-2]
            print formatted_lines[-1]
            print '*********************************************'
            print 'Stoping after', main_iterations - 1, 'iterations.'
        # EN oracle, try to stop/slow down threads
        try:
            oracle_en._en_queue_size = 0
        except NameError:
            pass # oracle_en not loaded on hog; I hate myself for doing that
        # save
        print 'saving...'
        savedir = save_experiment(expdir,model,starttime,statlog,params,
                                  crash=True)
        print 'saved to: ',savedir
        #quit
        return exit_code