コード例 #1
0
ファイル: junto.py プロジェクト: kimiyoung/ssl
    fout.close()

if __name__ == '__main__':
    # Two modes:
    #   python junto.py eval  -> score the existing OUTPUT_FILE and exit
    #   python junto.py       -> build the DATASET splits and pass them to print_data
    # DATASET, OUTPUT_FILE, data, nell_main and print_data are defined
    # elsewhere in this file.
    if len(sys.argv) > 1 and sys.argv[1] == 'eval':
        if DATASET != 'diel':
            pos, neg = 0, 0
            # `with` closes the file deterministically; the previous bare
            # open() left the handle open until interpreter shutdown.
            with open(OUTPUT_FILE) as fin:
                for line in fin:
                    # Only lines mentioning 'true' are scored.
                    if 'true' not in line:
                        continue
                    # The last whitespace-separated field is compared with 1.0:
                    # exactly 1.0 counts as a hit, anything else as a miss.
                    if float(line.strip().split()[-1]) == 1.0:
                        pos += 1
                    else:
                        neg += 1
            # Fraction of scored lines that were hits (0.0 if none were scored).
            # print(<single expr>) prints identically on Python 2 and 3.
            print(1.0 * pos / (pos + neg) if pos + neg > 0 else 0.0)
        # sys.exit is always available; the `quit` builtin is injected by the
        # `site` module and is missing under `python -S` or embedded interpreters.
        sys.exit()

    # Build the splits for the configured dataset. Exactly one branch can
    # match, so an elif chain replaces the former run of independent ifs.
    splits = None
    if DATASET == 'citeseer':
        splits = data.gen_dataset('../data/citeseer/citeseer.cites', '../data/citeseer/citeseer.content', 0)
    elif DATASET == 'cora':
        splits = data.gen_dataset('../data/cora/cora.cites', '../data/cora/cora.content', 0)
    elif DATASET == 'pubmed':
        splits = data.gen_pubmed_dataset('../data/pubmed/pubmed.cites', '../data/pubmed/pubmed.content', 0)
    elif DATASET == 'nell':
        # Mind the parameters configured inside nell_main.
        splits = nell_main.gen_dataset()
    # 'diel' (like any other value) generates nothing in this script.

    if splits is not None:
        # Unpack into the same module-level names as before, then emit once.
        x, y, tx, ty, graph = splits
        print_data(x, y, tx, ty, graph)

コード例 #2
0
ファイル: test_final_trans.py プロジェクト: kimiyoung/ssl
import nell_main as nell
from scipy import sparse as sp
from final.trans_model import trans_model as model
import argparse

def _build_parser():
    """Return the command-line parser for this NELL transductive test run.

    Every option is a typed scalar with a default, so running the script
    with no arguments uses the defaults listed below.
    """
    built = argparse.ArgumentParser()
    # (flag, type, default, help text), registered in this order.
    option_table = (
        ('--learning_rate', float, 1.0, 'learning rate'),
        ('--embedding_size', int, 50, 'embedding dimensions'),
        ('--window_size', int, 7, 'window size in random walk sequences'),
        ('--path_size', int, 20, 'length of random walk sequences'),
        ('--batch_size', int, 200, 'the size of batch for training instances'),
        ('--g_batch_size', int, 50, 'the batch size for graph'),
        ('--g_sample_size', int, 100, 'the sample size from label information'),
        ('--neg_samp', int, 10, 'negative sampling rate'),
        ('--g_learning_rate', float, 1e-2, 'learning rate for graph'),
        ('--embedding_file', str, 'final/saved.model', 'filename for saving models'),
    )
    for flag, value_type, fallback, description in option_table:
        built.add_argument(flag, help=description, type=value_type, default=fallback)
    return built


parser = _build_parser()
# Parsed from the real process argv at import time.
args = parser.parse_args()

# Load the NELL splits; x/y/graph feed training, tx is used for prediction.
# NOTE(review): ty is loaded but never used here, so no accuracy is computed
# in this script.
x, y, tx, ty, graph = nell.gen_dataset()

m = model(args)
m.add_data(x, y, graph)
m.build()

# Iteration schedule passed to train() as keyword arguments.
train_schedule = dict(
    init_iter_label=1,
    init_iter_graph=1,
    max_iter=10,
    iter_graph=1,
    iter_inst=1,
    iter_label=1,
)
m.train(**train_schedule)

# NOTE(review): the return value of predict() is discarded; confirm whether
# predict reports results itself.
m.predict(tx)
print('test done.')