Example no. 1
0
File: main.py — Project: zfjsail/ASAP
        8971, 85688, 9467, 32830, 28689, 94845, 69840, 50883, 74177, 79585,
        1055, 75631, 6825, 93188, 95426, 54514, 31467, 70597, 71149, 81994
    ]
    seeds = [42]
    counter = 0
    args.log_db = args.name
    print("log_db:", args.log_db)
    avg_val = []
    avg_test = []
    for seed in seeds:
        # set seed
        args.seed = seed
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        set_gpu(args.gpu)

        args.name = '{}_run_{}'.format(args.log_db, counter)

        # start training the model
        model = Trainer(args)
        # val_acc, test_acc = model.run()
        model.run_new()
        # print('For seed {}\t Val Accuracy: {:.3f} \t Test Accuracy: {:.3f}\n'.format(seed, val_acc, test_acc))
        # avg_val.append(val_acc)
        # avg_test.append(test_acc)
        counter += 1

    # print('Val Accuracy: {:.3f} ± {:.3f} Test Accuracy: {:.3f} ± {:.3f}'.format(np.mean(avg_val), np.std(avg_val),
    #                                                                             np.mean(avg_test), np.std(avg_test)))
Example no. 2
0
        type=int,
        help='Max length of the sentences in data.txt (default: 40)')
    parser.add_argument(
        '-maxdeplen',
        dest="max_dep_len",
        default=800,
        type=int,
        help='Max length of the dependency relations in data.txt (default: 800)'
    )

    args = parser.parse_args()

    if not args.restore:
        args.name = args.name + '_' + time.strftime(
            "%d_%m_%Y") + '_' + time.strftime("%H:%M:%S")

    tf.set_random_seed(args.seed)
    random.seed(args.seed)
    np.random.seed(args.seed)
    hp.set_gpu(args.gpu)

    model = SynGCN(args)

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    with tf.Session(config=config) as sess:
        sess.run(tf.global_variables_initializer())
        model.fit(sess)

    print('Model Trained Successfully!!')
Example no. 3
0
    parser.add_argument("-pred_mode", dest="mode", choices=["entity", "temporal"], default="entity")
    parser.add_argument(
        "-granularity",
        dest="granularity",
        choices=["day", "year"],
        default="year",
        help="Day or year level granularity. Note: day only works for a single year span!",
    )

    args = parser.parse_args()
    args.dataset = "data/" + args.data_type + "_" + args.version + "/train.txt"
    args.entity2id = "data/" + args.data_type + "_" + args.version + "/entity2id.txt"
    args.relation2id = "data/" + args.data_type + "_" + args.version + "/relation2id.txt"
    args.test_data = "data/" + args.data_type + "_" + args.version + "/test.txt"
    args.valid_data = "data/" + args.data_type + "_" + args.version + "/valid.txt"
    args.triple2id = "data/" + args.data_type + "_" + args.version + "/triple2id.txt"
    args.embed_dim = int(args.embed_init.split("_")[1])

    tf.set_random_seed(args.seed)
    random.seed(args.seed)
    np.random.seed(args.seed)
    Helper.set_gpu(args.gpu)
    model = HyTE(args)

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    with tf.Session(config=config) as sess:
        sess.run(tf.global_variables_initializer())
        print("enter fitting")
        model.fit(sess)