All of the examples below assume the following imports, plus the repository's own utils helper module:

import random
import numpy as np
import torch
import torch.nn.functional as F
import utils

Example 1
def train_loop(train_loaders, model, opt, ep, args):
    stats = {}
    for k in ['acc', 'loss', 'regret', 'loss_train']:
        stats[k] = []

    step = ep * args.num_batches
    for batch_0, batch_1 in zip(train_loaders[0], train_loaders[1]):
        # work on each batch
        model['ebd'].train()
        model['clf_all'].train()

        batch_0 = utils.to_cuda(utils.squeeze_batch(batch_0))
        batch_1 = utils.to_cuda(utils.squeeze_batch(batch_1))

        x_0 = model['ebd'](batch_0['X'])
        y_0 = batch_0['Y']
        x_1 = model['ebd'](batch_1['X'])
        y_1 = batch_1['Y']

        acc_0, loss_0, grad_0 = model['clf_all'](x_0,
                                                 y_0,
                                                 return_pred=False,
                                                 grad_penalty=True)

        acc_1, loss_1, grad_1 = model['clf_all'](x_1,
                                                 y_1,
                                                 return_pred=False,
                                                 grad_penalty=True)

        loss_ce = (loss_0 + loss_1) / 2.0
        regret = (grad_0 + grad_1) / 2.0

        acc = (acc_0 + acc_1) / 2.0

        # anneal the penalty: weight 1.0 for the first anneal_iters steps,
        # then switch to the (typically large) l_regret coefficient
        weight = args.l_regret if step > args.anneal_iters else 1.0

        loss_total = loss_ce + weight * regret

        # rescale so the overall loss magnitude stays comparable
        if weight > 1.0:
            loss_total /= weight

        opt.zero_grad()
        loss_total.backward()
        opt.step()

        stats['acc'].append(acc)
        stats['loss'].append(loss_total.item())
        stats['loss_train'].append(loss_ce.item())
        stats['regret'].append(regret.item())
        step += 1

    for k, v in stats.items():
        stats[k] = float(np.mean(np.array(v)))

    return stats
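
For context, a minimal wiring sketch for this loop. Everything here is a hypothetical stand-in (the real project supplies its own ebd / clf_all modules, dict-style batches, and utils helpers), so the actual call is left commented out:

import argparse

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

def make_env_loader(n=64, d=16):
    # one synthetic "environment": random features, binary labels
    x, y = torch.randn(n, d), torch.randint(0, 2, (n,))
    return DataLoader(TensorDataset(x, y), batch_size=8)

train_loaders = [make_env_loader(), make_env_loader()]

model = {
    'ebd': nn.Linear(16, 32),      # stand-in feature extractor
    'clf_all': nn.Linear(32, 2),   # stand-in shared classifier head
}
opt = torch.optim.Adam(
    list(model['ebd'].parameters()) + list(model['clf_all'].parameters()),
    lr=1e-3)

# fields the loop reads: num_batches, l_regret, anneal_iters
args = argparse.Namespace(num_batches=8, l_regret=100.0, anneal_iters=500)

# stats = train_loop(train_loaders, model, opt, ep=0, args=args)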
Example 2
def train_loop(train_loader, model, opt, ep, args):
    stats = {}
    for k in ['acc', 'loss']:
        stats[k] = []

    for batch in train_loader:
        # work on each batch
        model['ebd'].train()
        model['clf_all'].train()

        batch = utils.to_cuda(utils.squeeze_batch(batch))

        x = model['ebd'](batch['X'])
        y = batch['Y']

        acc, loss = model['clf_all'](x,
                                     y,
                                     return_pred=False,
                                     grad_penalty=False)

        opt.zero_grad()
        loss.backward()
        opt.step()

        stats['acc'].append(acc)
        stats['loss'].append(loss.item())

    for k, v in stats.items():
        stats[k] = float(np.mean(np.array(v)))

    return stats
Example 3
def train_loop(train_loaders, model, opt, ep, args):
    stats = {}
    # only 'acc' and 'loss' are collected in this loop; extra keys with
    # empty lists would make the np.mean at the end return NaN
    for k in ['acc', 'loss']:
        stats[k] = []

    for batch_0, batch_1 in zip(train_loaders[0], train_loaders[1]):
        # work on each batch
        # sample from the two env equally
        model['ebd'].train()
        model['clf_all'].train()

        batch_0 = utils.to_cuda(utils.squeeze_batch(batch_0))
        batch_1 = utils.to_cuda(utils.squeeze_batch(batch_1))

        x_0 = model['ebd'](batch_0['X'])
        y_0 = batch_0['Y']
        x_1 = model['ebd'](batch_1['X'])
        y_1 = batch_1['Y']

        acc_0, loss_0 = model['clf_all'](x_0, y_0, return_pred=False,
                                         grad_penalty=False)

        acc_1, loss_1 = model['clf_all'](x_1, y_1, return_pred=False,
                                         grad_penalty=False)

        loss = (loss_0 + loss_1) / 2.0
        acc = (acc_0 + acc_1) / 2.0

        opt.zero_grad()
        loss.backward()
        opt.step()

        stats['acc'].append(acc)
        stats['loss'].append(loss.item())

    for k, v in stats.items():
        stats[k] = float(np.mean(np.array(v)))

    return stats
Example 4
def test_loop(test_loader, model, ep, args, att_idx_dict=None):
    loss_list = []
    true, pred = [], []

    if att_idx_dict is not None:
        # track example indices so utils.get_worst_acc can evaluate
        # accuracy per attribute group
        idx = []

    for batch in test_loader:
        # work on each batch
        model['ebd'].eval()
        model['clf_all'].eval()

        batch = utils.to_cuda(utils.squeeze_batch(batch))

        x = model['ebd'](batch['X'])
        y = batch['Y']

        y_hat, loss = model['clf_all'](x, y, return_pred=True)

        true.append(y)
        pred.append(y_hat)

        if att_idx_dict is not None:
            idx.append(batch['idx'])

        loss_list.append(loss.item())

    true = torch.cat(true)
    pred = torch.cat(pred)

    acc = torch.mean((true == pred).float()).item()
    loss = np.mean(np.array(loss_list))

    if att_idx_dict is not None:
        return utils.get_worst_acc(true, pred, idx, loss, att_idx_dict)

    return {
        'acc': acc,
        'loss': loss,
    }
Example 5
def test_loop(test_loader, model, ep, args, return_idx=False):
    loss_list = []
    true, pred, cor = [], [], []
    if return_idx:
        idx = []

    for batch in test_loader:
        # work on each batch
        model['ebd'].eval()
        model['clf_all'].eval()

        batch = utils.to_cuda(utils.squeeze_batch(batch))

        x = model['ebd'](batch['X'])
        y = batch['Y']
        c = batch['C']

        y_hat, loss = model['clf_all'](x, y, return_pred=True)

        true.append(y)
        pred.append(y_hat)
        cor.append(c)
        if return_idx:
            idx.append(batch['idx'])

        loss_list.append(loss.item())

    true = torch.cat(true)
    pred = torch.cat(pred)

    acc = torch.mean((true == pred).float()).item()
    loss = np.mean(np.array(loss_list))

    if not return_idx:
        return {
            'acc': acc,
            'loss': loss,
        }
    else:
        cor = torch.cat(cor).tolist()
        true = true.tolist()
        pred = pred.tolist()
        idx = torch.cat(idx).tolist()

        # split correct and wrong idx
        correct_idx, wrong_idx = [], []

        # compute correlation between cor and y for analysis
        correct_cor, wrong_cor = [], []
        correct_y, wrong_y = [], []

        for i, y, y_hat, c in zip(idx, true, pred, cor):
            if y == y_hat:
                correct_idx.append(i)
                correct_cor.append(c)
                correct_y.append(y)
            else:
                wrong_idx.append(i)
                wrong_cor.append(c)
                wrong_y.append(y)

        return {
            'acc': acc,
            'loss': loss,
            'correct_idx': correct_idx,
            'correct_cor': correct_cor,
            'correct_y': correct_y,
            'wrong_idx': wrong_idx,
            'wrong_cor': wrong_cor,
            'wrong_y': wrong_y,
        }
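
The return_idx branch above enables error analysis against the spurious attribute C. A hypothetical follow-up, assuming res is the dict returned by test_loop(..., return_idx=True): measure how often C agrees with the label among the model's wrong predictions.

# fraction of wrong predictions whose spurious attribute matches the label
wrong_match = sum(int(c) == int(y)
                  for c, y in zip(res['wrong_cor'], res['wrong_y']))
wrong_rate = wrong_match / max(len(res['wrong_y']), 1)
print(f'label/attribute agreement on errors: {wrong_rate:.2f}')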
Example 6
def train_dro_loop(train_loaders, model, opt, ep, args):
    stats = {}
    for k in ['worst_loss', 'avg_loss', 'worst_acc', 'avg_acc']:
        stats[k] = []

    for batches in zip(*train_loaders):
        # work on each batch
        model['ebd'].train()
        model['clf_all'].train()

        x, y = [], []

        for batch in batches:
            batch = utils.to_cuda(utils.squeeze_batch(batch))
            x.append(batch['X'])
            y.append(batch['Y'])

        if args.dataset in ['beer_0', 'beer_1', 'beer_2']:
            # text models have varying length between batches
            pred = []
            for cur_x in x:
                pred.append(model['clf_all'](model['ebd'](cur_x)))
            pred = torch.cat(pred, dim=0)
        else:
            pred = model['clf_all'](model['ebd'](torch.cat(x, dim=0)))

        cur_idx = 0

        avg_loss = 0
        avg_acc = 0
        # track the worst (largest-loss) group; the loss tensor itself is
        # kept so the backward pass below goes through it
        worst_loss = None
        worst_acc = 0

        for cur_true in y:
            cur_pred = pred[cur_idx:cur_idx + len(cur_true)]
            cur_idx += len(cur_true)

            loss = F.cross_entropy(cur_pred, cur_true)
            acc = torch.mean((torch.argmax(cur_pred,
                                           dim=1) == cur_true).float()).item()

            avg_loss += loss.item()
            avg_acc += acc

            if worst_loss is None or loss.item() > worst_loss.item():
                worst_loss = loss
                worst_acc = acc

        opt.zero_grad()
        worst_loss.backward()
        opt.step()

        avg_loss /= len(y)
        avg_acc /= len(y)

        stats['avg_acc'].append(avg_acc)
        stats['avg_loss'].append(avg_loss)
        stats['worst_acc'].append(worst_acc)
        stats['worst_loss'].append(worst_loss.item())

    for k, v in stats.items():
        stats[k] = float(np.mean(np.array(v)))

    return stats
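
A self-contained toy sketch of the worst-group (group DRO) objective used above: compute one cross-entropy per group and backprop only through the largest, so shared parameters are updated for the hardest environment.

import torch
import torch.nn.functional as F

logits = [torch.randn(8, 2, requires_grad=True) for _ in range(3)]
labels = [torch.randint(0, 2, (8,)) for _ in range(3)]

# per-group losses; keep tensors so the selected max stays differentiable
losses = [F.cross_entropy(p, t) for p, t in zip(logits, labels)]
worst = max(losses, key=lambda l: l.item())
worst.backward()  # only the worst group's logits receive gradient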
Example 7
def train_loop(train_loaders, model, opt_all, opt_0, opt_1, ep, args):
    stats = {}
    for k in ['acc', 'loss', 'regret', 'loss_train']:
        stats[k] = []

    step = ep * args.num_batches
    for batch_0, batch_1 in zip(train_loaders[0], train_loaders[1]):
        # work on each batch
        model['ebd'].train()
        model['clf_all'].train()
        model['clf_0'].train()
        model['clf_1'].train()

        batch_0 = utils.to_cuda(utils.squeeze_batch(batch_0))
        batch_1 = utils.to_cuda(utils.squeeze_batch(batch_1))

        x_0 = model['ebd'](batch_0['X'])
        y_0 = batch_0['Y']
        x_1 = model['ebd'](batch_1['X'])
        y_1 = batch_1['Y']

        # train clf_0 on x_0
        _, loss_0_0 = model['clf_0'](x_0.detach(), y_0)
        opt_0.zero_grad()
        loss_0_0.backward()
        opt_0.step()

        # train clf_1 on x_1
        _, loss_1_1 = model['clf_1'](x_1.detach(), y_1)
        opt_1.zero_grad()
        loss_1_1.backward()
        opt_1.step()

        # train clf_all on both, backprop to representation
        x = torch.cat([x_0, x_1], dim=0)
        y = torch.cat([y_0, y_1], dim=0)
        acc, loss_ce = model['clf_all'](x, y)

        # randomly sample a group, evaluate the validation loss
        # do not detach feature representation at this time
        if random.random() > 0.5:
            # choose env 1
            # apply clf 1 on env 1
            _, loss_1_1 = model['clf_1'](x_1, y_1)

            # apply clf 0 on env 1
            _, loss_0_1 = model['clf_0'](x_1, y_1)

            regret = loss_0_1 - loss_1_1
        else:
            # choose env 0
            # apply clf 1 on env 0
            _, loss_1_0 = model['clf_1'](x_0, y_0)

            # apply clf 0 on env 0
            _, loss_0_0 = model['clf_0'](x_0, y_0)

            regret = loss_1_0 - loss_0_0

        weight = args.l_regret if step > args.anneal_iters else 1.0
        loss = loss_ce + weight * regret
        step += 1

        if weight > 1.0:
            loss /= weight

        opt_all.zero_grad()
        # clear clf_0 / clf_1 as well: backward() writes regret gradients
        # into them, but they are never applied since only opt_all steps
        opt_0.zero_grad()
        opt_1.zero_grad()
        loss.backward()
        opt_all.step()

        stats['acc'].append(acc)
        stats['loss'].append(loss.item())
        stats['loss_train'].append(loss_ce.item())
        stats['regret'].append(regret.item())

    for k, v in stats.items():
        stats[k] = float(np.mean(np.array(v)))

    return stats
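
Finally, a hypothetical parameter grouping for the three optimizers this variant expects. The linear layers are shape-only stand-ins for the project's ebd and classifier heads (the real heads also compute their own loss):

import torch
import torch.nn as nn

model = {
    'ebd': nn.Linear(16, 32),     # shared feature extractor
    'clf_all': nn.Linear(32, 2),  # shared head, updated by opt_all
    'clf_0': nn.Linear(32, 2),    # env-0 head, updated by opt_0
    'clf_1': nn.Linear(32, 2),    # env-1 head, updated by opt_1
}
opt_all = torch.optim.Adam(
    list(model['ebd'].parameters()) + list(model['clf_all'].parameters()))
opt_0 = torch.optim.Adam(model['clf_0'].parameters())
opt_1 = torch.optim.Adam(model['clf_1'].parameters())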