예제 #1
0
# Loss function: cross-entropy, moved to the target device.
criterion = nn.CrossEntropyLoss().to(device)

# Networks, all moved to the target device. Construction order is kept
# fixed (CNN-LSTM model, then the three policy networks, then the value
# network); the CNN-LSTM model is printed right after it is built.
clstm = CNN_LSTM(
    INPUT_DIM, EMBEDDING_DIM, KER_SIZE, N_FILTERS, HIDDEN_DIM
).to(device)
print(clstm)
policy_s = Policy_S(HIDDEN_DIM, HIDDEN_DIM, OUTPUT_DIM).to(device)
policy_n = Policy_N(HIDDEN_DIM, HIDDEN_DIM, MAX_K).to(device)
policy_c = Policy_C(HIDDEN_DIM, HIDDEN_DIM, LABEL_DIM).to(device)
value_net = ValueNetwork(HIDDEN_DIM, HIDDEN_DIM, OUTPUT_DIM).to(device)

# Optimisers: one Adam instance each for the CNN-LSTM model, the pooled
# policy parameters and the value network, all sharing `learning_rate`.
optim_loss = optim.Adam(clstm.parameters(), lr=learning_rate)
# Pool the policy parameters in the order policy_s, policy_c, policy_n.
params_pg = [
    param
    for net in (policy_s, policy_c, policy_n)
    for param in net.parameters()
]
optim_policy = optim.Adam(params_pg, lr=learning_rate)
optim_value = optim.Adam(value_net.parameters(), lr=learning_rate)

# Initialise the embedding layer from the vocabulary's pretrained vectors
# and keep it trainable so those initial weights are updated in training.
pretrained_embeddings = TEXT.vocab.vectors
_embedding_weight = clstm.embedding.weight
_embedding_weight.data.copy_(pretrained_embeddings)
_embedding_weight.requires_grad = True

# set the default tensor type for GPU
#torch.set_default_tensor_type('torch.cuda.FloatTensor')


def finish_episode(policy_loss_sum, encoder_loss_sum, baseline_value_batch):
    '''
    Called when a data sample has been processed.
# Hyper-parameters: LABEL_DIM is the last argument given to Policy_C,
# N_FILTERS is passed to CNN_LSTM, and learning_rate is used by Adam.
LABEL_DIM = 2
N_FILTERS = 128
learning_rate = 1e-3

# Number of training epochs.
num_of_epoch = 10

# Loss function: cross-entropy, moved to the target device.
criterion = nn.CrossEntropyLoss().to(device)

# Networks: the CNN-LSTM model followed by the Policy_C network, both
# moved to the target device.
clstm = CNN_LSTM(
    INPUT_DIM, EMBEDDING_DIM, KER_SIZE, N_FILTERS, HIDDEN_DIM
).to(device)
policy_c = Policy_C(HIDDEN_DIM, HIDDEN_DIM, LABEL_DIM).to(device)

# A single Adam optimiser over both networks' parameters (CNN-LSTM
# parameters first, then Policy_C's).
params = [*clstm.parameters(), *policy_c.parameters()]
optimizer = optim.Adam(params, lr=learning_rate)

# Initialise the embedding layer from the vocabulary's pretrained vectors
# and keep it trainable so those initial weights are updated in training.
pretrained_embeddings = TEXT.vocab.vectors
_embedding_weight = clstm.embedding.weight
_embedding_weight.data.copy_(pretrained_embeddings)
_embedding_weight.requires_grad = True


def evaluate(iterator):
    clstm.eval()
    policy_c.eval()
    true_labels = []
    pred_labels = []
    eval_loss = 0
    for index, valid in enumerate(iterator):