def find_epoch(dims, nb_classes, train_embs, train_labels, test_embs, test_labels,
               cuda, *, max_epochs=2000, eval_every=100, patience=10,
               threshold=0.5):
    """Choose a training-epoch budget for the LogReg probe by early stopping.

    A fresh ``LogReg(dims, nb_classes)`` is trained full-batch with Adam
    (lr=0.01, no weight decay) on ``BCEWithLogitsLoss``.

    Every ``eval_every`` epochs, the model's sigmoid outputs on ``test_embs``
    are thresholded at ``threshold``. The result is scored against
    ``test_labels`` with micro-averaged F1.

    Training stops after ``patience`` consecutive evaluations that fail to
    match or beat the best score so far, or after ``max_epochs`` epochs.

    Args:
        dims: Input feature size passed to ``LogReg``.
        nb_classes: Number of output labels passed to ``LogReg``.
        train_embs, train_labels: Training inputs / targets. They are fed
            straight to the model, so they must live on the model's device.
        test_embs, test_labels: Evaluation inputs / targets (same device
            requirement for ``test_embs``).
        cuda: If true, the model is moved to the GPU before training.
        max_epochs: Upper bound on training epochs (default 2000).
        eval_every: Evaluate every this many epochs (default 100).
        patience: Number of consecutive non-improving evaluations after
            which training stops (default 10).
        threshold: Probability cutoff for a positive label (default 0.5).

    Returns:
        The epoch at which the best F1 was last reached. This is a multiple
        of ``eval_every``, or 0 if no evaluation took place.
    """
    log = LogReg(dims, nb_classes)
    if cuda:
        # Move the model before constructing the optimizer, as PyTorch advises.
        log.cuda()
    opt = torch.optim.Adam(log.parameters(), lr=0.01, weight_decay=0.0)
    xent = nn.BCEWithLogitsLoss()

    # The ground truth never changes: convert it to a CPU array once.
    y_true = test_labels.cpu().numpy()

    best_f1 = 0.0
    best_epoch = 0
    misses = 0
    for epoch in range(1, max_epochs + 1):
        log.train()
        opt.zero_grad()
        loss = xent(log(train_embs), train_labels)
        loss.backward()
        opt.step()

        if epoch % eval_every:
            continue

        log.eval()
        with torch.no_grad():  # inference only; no autograd graph needed
            probs = torch.sigmoid(log(test_embs))
        y_pred = (probs >= threshold).float().cpu().numpy()
        f1 = f1_score(y_true, y_pred, average='micro')

        # ">=" lets a later epoch with an equal score win the tie.
        if f1 >= best_f1:
            best_f1 = f1
            best_epoch = epoch
            misses = 0
        else:
            misses += 1
            if misses >= patience:
                break
    return best_epoch
def run_logreg(train_embs, train_labels, test_embs, test_labels, cuda=False,
               *, repeats=50):
    """Evaluate embeddings with a multi-label logistic-regression probe.

    Steps:
      1. ``find_epoch`` picks an epoch budget by early stopping on the test
         split.
      2. ``repeats`` freshly initialised ``LogReg`` models are each trained
         full-batch for that many epochs (Adam, lr=0.01, BCE-with-logits).
      3. The decision threshold is chosen once, by ``find_best_th`` on the
         first model's sigmoid outputs for the training set. It is then
         reused for every repeat.
      4. Each model's test predictions are scored with micro-averaged F1.

    Prints the chosen epoch and threshold, each per-repeat F1, the average
    F1, and the mean and standard deviation of the per-repeat scores.

    Args:
        train_embs, test_embs: 2-D array-likes of features. They are
            converted to float32 tensors; the second dimension is the
            feature size.
        train_labels, test_labels: 2-D label matrices with one column per
            class. They are converted to float32 tensors. Presumably 0/1
            multi-hot; confirm against the data loader.
        cuda: If true, the data and models are placed on the GPU.
        repeats: Number of independently initialised models to train and
            score (default 50).
    """
    device = 'cuda' if cuda else 'cpu'
    train_embs, train_labels, test_embs, test_labels = (
        torch.as_tensor(a, dtype=torch.float32, device=device)
        for a in (train_embs, train_labels, test_embs, test_labels)
    )

    xent = nn.BCEWithLogitsLoss()
    dims = train_embs.shape[1]
    nb_classes = train_labels.shape[1]

    best_epoch = find_epoch(dims, nb_classes, train_embs, train_labels,
                            test_embs, test_labels, cuda)
    print('best epoch', best_epoch)

    # The ground truth is fixed across repeats: move it to a CPU array once.
    y_true = test_labels.cpu().numpy()
    best_th = 0.0
    res = []
    for i in range(repeats):
        log = LogReg(dims, nb_classes)
        if cuda:
            # Move the model before constructing the optimizer, as PyTorch advises.
            log.cuda()
        opt = torch.optim.Adam(log.parameters(), lr=0.01, weight_decay=0.0)

        log.train()
        for _ in range(best_epoch):
            opt.zero_grad()
            loss = xent(log(train_embs), train_labels)
            loss.backward()
            opt.step()

        # Inference: use eval mode (as find_epoch does) and build no autograd graph.
        log.eval()
        with torch.no_grad():
            if i == 0:
                # Tune the threshold once, on the first model's training-set outputs.
                train_probs = torch.sigmoid(log(train_embs))
                best_th = find_best_th(train_probs.cpu(), train_labels.cpu())
                print('best threshold:', best_th)
            probs = torch.sigmoid(log(test_embs))

        # ppi is a multi-label classification problem: each label is thresholded independently.
        y_pred = (probs >= best_th).float().cpu().numpy()
        f1 = f1_score(y_true, y_pred, average='micro')
        res.append(f1)
        print(f1)

    res = np.asarray(res)
    mean_f1 = np.mean(res)
    print('Average f1:', mean_f1)
    print(mean_f1)
    print(np.std(res))