# Command-line interface: one positional argument selecting the dataset.
parser = argparse.ArgumentParser("main.py")
parser.add_argument("dataset", help="the name of the dataset", type=str)
args = parser.parse_args()

# Load embeddings, the knowledge graph, and the train/test splits for the
# requested dataset.
node_embedding, rel_embedding, kg, train, test = load_data(
    args.dataset, WORD_EMB_DIM, "ComplEX")

# Linear projection mapping word embeddings into node-embedding space.
word2node = nn.Linear(WORD_EMB_DIM, NODE_EMB_DIM, bias=False).to(device)

# Multi-head self-attention (4 heads, scale factor sqrt(H_DIM)).
attention = Attention(4, NODE_EMB_DIM, H_DIM, math.sqrt(H_DIM)).to(device)

# Every parameter the optimizer should update.
model_param_list = list(word2node.parameters()) + list(attention.parameters())

# Build an initial environment state from the first training example; it is
# constructed here so the agent's input size can be queried from it below.
state = State((train[0][1], train[0][2]), kg, node_embedding, WORD_EMB_DIM,
              word2node, attention, rel_embedding, T, device)

# Dimensions derived from the loaded data and the state object.
input_dim = state.get_input_size()
num_rel = len(kg.rel_vocab)
num_entity = len(kg.en_vocab)
num_subgraph = len(state.subgraphs)
emb_dim = WORD_EMB_DIM + NODE_EMB_DIM

# Reactive baseline for variance reduction in the policy gradient.
baseline = ReactiveBaseline(l=0.02)

# NOTE(review): the bare literals 32, 0, 2, and 0.00005 presumably mean
# hidden size, padding/start index, layer count, and learning rate — confirm
# against Agent's definition.
agent = Agent(input_dim, 32, emb_dim, 0, 2, num_entity, num_rel,
              num_subgraph, GAMMA, 0.00005, model_param_list, baseline,
              device)

# training loop