Пример #1
0
def test_make_diverse_seqs():
    """Smoke-run make_diverse_seqs on the retrieved n-gram and skip-gram models.

    Prints the first ten symbols of the n-gram model, then whatever
    ``make_diverse_seqs(seq_model, diversity_model, 2, 8)`` returns.
    Nothing is asserted; the output is meant for manual inspection.
    """
    from retrieve_model_tools import retrieve_NGram, retrieve_SkipGramNN
    seq_model = retrieve_NGram()
    diversity_model = retrieve_SkipGramNN()
    # Parenthesized single-argument print parses and prints identically on
    # Python 2 (print statement) and Python 3 (print function).
    print(seq_model.syms[:10])
    seqs = make_diverse_seqs(seq_model, diversity_model, 2, 8)
    print(seqs)
Пример #2
0
    def initialize(self):
        """Set up the namespace's state, models, configuration and logs.

        Called after socketio has initialized the namespace.

        Loads the n-gram and skip-gram models, reads the configuration,
        reformats every model symbol through ``self.format_sym`` and installs
        the reformatted list on both models so they share one spelling.
        """
        self.history = []
        self.parsed_seqs_notes = {}
        # presumably seconds per beat at 92 BPM -- TODO confirm
        self.unit_dur = 60/92.0

        self.ngram = retrieve_NGram()
        self.nn = retrieve_SkipGramNN()

        # Both models are handed one shared symbol list further down, so
        # their symbol lists have to agree before that happens.
        assert self.ngram.syms == self.nn.syms, \
            'n-gram and skip-gram models have different symbol lists'

        self.previous_sym = None
        self.previous_sym_ind = None
        self.n_suggestions = 5
        self.n_similar = 2

        # Constructed before the configuration is loaded, as in the original
        # ordering; SuggestionList receives this instance.
        self.suggestions = SuggestionList(self)
        self.suggestions_above = SuggestionList(self)

        self.config = get_configs()
        self.corpus = self.config['corpus']
        # %-formatting keeps the output identical on Python 2 and 3; the
        # one-element tuple guards against self.corpus itself being a tuple.
        print('...corpus %s' % (self.corpus,))

        self.symbol_type = 'letter' if self.config['use_letternames'] else 'roman'

        # need to correct some roman numerals
        print('# of syms: %d' % len(self.ngram.syms))
        self.syms = []
        for sym in self.ngram.syms:
            # format_sym returns a pair; only the formatted spelling is kept.
            formatted_sym, _ = self.format_sym(sym)
            self.syms.append(formatted_sym)

        # need to update the "spelling" of roman numerals in nn and ngram
        # (both models end up referencing the very same list object)
        self.nn.syms = self.syms
        self.ngram.syms = self.syms

        self._rn2letter, self._letter2rn = self.load_rn2letter_dict()

        self.experiment_type = EXPERIMENT_TYPE

        self.logs = Logs(EXPERIMENT_TYPE, EXPERIMENT_TYPE_STRS)
Пример #3
0
def test_plot_one_gram_all():
    """Plot the n-gram model's unigram counts, largest first, and save a PDF.

    The figure is written to ``bach_chorale-chord-dist.pdf``; its title
    reports the sum of all unigram counts.
    """
    model = retrieve_NGram()
    counts = model.unigram_counts

    # argsort is ascending, so sort the negated counts to get descending order.
    descending_order = np.argsort(-counts)
    descending_counts = [counts[idx] for idx in descending_order]

    PLOT_LOG = False
    if not PLOT_LOG:
        plt.plot(descending_counts)
    else:
        # TODO: should be log-log scale to get straightline
        plt.plot(np.log(descending_counts))
        plt.ylabel('log counts')

    total_count = np.sum(counts)
    plt.title('Bach chorale (size of data: %d)' % total_count)
    plt.savefig('bach_chorale-chord-dist.pdf')
Пример #4
0
def test_forward_backward():
    """Print up to ten (probability, symbol) pairs for the gap between 'I' and 'I'.

    Calls ``simple_foward_backward_gap_dist(ngram, 'I', 'I')`` on the retrieved
    n-gram model and prints the leading entries of the two sequences it
    returns, one pair per line. Nothing is asserted; the output is meant
    for manual inspection.
    """
    from retrieve_model_tools import retrieve_NGram
    ngram = retrieve_NGram()
    sorted_probs, sorted_syms = simple_foward_backward_gap_dist(ngram, 'I', 'I')
    # Slice instead of indexing range(10) so that fewer than ten results do
    # not raise IndexError.
    for prob, sym in zip(sorted_probs[:10], sorted_syms[:10]):
        # %-formatting reproduces the "a b" output of the Python 2 print
        # statement and also works under Python 3.
        print('%s %s' % (prob, sym))