示例#1
0
def find_epoch(dims, nb_classes, train_embs, train_labels, test_embs,
               test_labels, cuda, max_epochs=2000, eval_every=100,
               patience=10, threshold=0.5):
    """Search for the number of training epochs that maximises micro-F1.

    A fresh ``LogReg(dims, nb_classes)`` is trained full-batch with Adam
    (lr=0.01, no weight decay) on ``BCEWithLogitsLoss``. Every ``eval_every``
    epochs the sigmoid outputs on ``test_embs`` are binarised at
    ``threshold`` and scored against ``test_labels`` with micro-averaged F1.

    Args:
        dims: Input feature size handed to ``LogReg``.
        nb_classes: Number of outputs handed to ``LogReg``.
        train_embs: Training inputs (tensor, already on the target device).
        train_labels: Training targets for ``BCEWithLogitsLoss``.
            Presumably multi-hot float labels -- confirm against the caller.
        test_embs: Inputs used for the periodic evaluation.
        test_labels: Targets used for the periodic evaluation.
        cuda: If truthy, the classifier is moved to the GPU.
        max_epochs: Upper bound on the number of optimisation steps.
        eval_every: Evaluate once every this many epochs.
        patience: Stop after this many consecutive evaluations that did not
            reach the best F1 seen so far.
        threshold: Probability cut-off used to binarise predictions.

    Returns:
        The 1-based epoch of the most recent evaluation whose F1 was greater
        than or equal to the best F1 so far (ties favour the later epoch),
        or 0 if no evaluation took place.
    """
    # NOTE(review): the epoch count is selected on the same split that is
    # passed in as "test" data -- confirm this is intended by the caller.
    log = LogReg(dims, nb_classes)
    opt = torch.optim.Adam(log.parameters(), lr=0.01, weight_decay=0.0)
    xent = nn.BCEWithLogitsLoss()
    if cuda:
        log.cuda()

    # The evaluation targets never change: convert them for sklearn once
    # instead of at every checkpoint.
    test_labels_np = test_labels.detach().cpu().numpy()

    epoch_flag = 0
    epoch_win = 0
    # Plain float so the comparison below never mixes tensors and scalars.
    best_f1 = 0.0

    for e in range(max_epochs):
        log.train()
        opt.zero_grad()
        loss = xent(log(train_embs), train_labels)
        loss.backward()
        opt.step()

        if (e + 1) % eval_every != 0:
            continue

        log.eval()
        # No gradients are needed for scoring; skip building the graph.
        with torch.no_grad():
            probs = torch.sigmoid(log(test_embs))
        preds = (probs >= threshold).to(probs.dtype)
        f1 = f1_score(test_labels_np, preds.cpu().numpy(), average='micro')

        if f1 >= best_f1:
            epoch_flag = e + 1
            best_f1 = f1
            epoch_win = 0
        else:
            epoch_win += 1
            if epoch_win >= patience:
                break
    return epoch_flag
示例#2
0
def run_logreg(train_embs, train_labels, test_embs, test_labels, cuda=False,
               repeats=50):
    """Train a logistic-regression probe repeatedly and report micro-F1.

    The number of epochs is chosen once via ``find_epoch``. Then ``repeats``
    freshly initialised ``LogReg`` models are trained for that many
    full-batch Adam steps (lr=0.01) on ``BCEWithLogitsLoss``. The decision
    threshold is taken from ``find_best_th`` on the training predictions of
    the first model and reused for all later repeats. Each repeat's
    micro-averaged F1 on the test data is printed, followed by the average,
    mean and standard deviation over all repeats.

    Args:
        train_embs: Training inputs; converted to a float32 tensor. The
            second dimension is used as the input size of ``LogReg``.
        train_labels: Training targets; converted to a float32 tensor. The
            second dimension is used as the number of classes.
        test_embs: Test inputs; converted to a float32 tensor.
        test_labels: Test targets; converted to a float32 tensor.
        cuda: If truthy, tensors and models are moved to the GPU.
        repeats: Number of independently initialised models to train.

    Raises:
        ValueError: If ``repeats`` is smaller than 1.
    """
    if repeats < 1:
        raise ValueError('repeats must be a positive integer')

    train_embs = torch.as_tensor(train_embs, dtype=torch.float32)
    train_labels = torch.as_tensor(train_labels, dtype=torch.float32)
    test_embs = torch.as_tensor(test_embs, dtype=torch.float32)
    test_labels = torch.as_tensor(test_labels, dtype=torch.float32)

    if cuda:
        train_embs = train_embs.cuda()
        test_embs = test_embs.cuda()
        train_labels = train_labels.cuda()
        test_labels = test_labels.cuda()

    # Kept as a tensor so the 'Average f1:' line prints as it always did.
    tot = torch.zeros(1)
    if cuda:
        tot = tot.cuda()
    res = []
    xent = nn.BCEWithLogitsLoss()

    dims = train_embs.shape[1]
    nb_classes = train_labels.shape[1]

    best_epoch = find_epoch(dims, nb_classes, train_embs, train_labels,
                            test_embs, test_labels, cuda)
    print('best epoch', best_epoch)

    # The targets are constant across repeats: convert for sklearn once.
    test_labels_np = test_labels.detach().cpu().numpy()

    best_th = 0.0
    for i in range(repeats):
        log = LogReg(dims, nb_classes)
        opt = torch.optim.Adam(log.parameters(), lr=0.01, weight_decay=0.0)
        if cuda:
            log.cuda()

        log.train()
        for _ in range(best_epoch):
            opt.zero_grad()
            loss = xent(log(train_embs), train_labels)
            loss.backward()
            opt.step()

        # Score in eval mode without autograd, matching find_epoch.
        log.eval()

        if i == 0:
            # The threshold is tuned on the training data of the first
            # model only and then shared by every subsequent repeat.
            with torch.no_grad():
                train_probs = torch.sigmoid(log(train_embs))
            best_th = find_best_th(train_probs.cpu(), train_labels.cpu())
            print('best threshold:', best_th)

        with torch.no_grad():
            probs = torch.sigmoid(log(test_embs))
        # Each output is thresholded independently (multi-label problem).
        preds = (probs >= best_th).to(probs.dtype)
        f1 = f1_score(test_labels_np, preds.cpu().numpy(), average='micro')
        res.append(f1)
        print(f1)
        tot += f1

    print('Average f1:', tot / repeats)

    res = np.stack(res)
    print(np.mean(res))
    print(np.std(res))