def run_test(model,
             device,
             test_loader,
             writer,
             epoch,
             loss_fn,
             log_suffix=''):
    """Evaluate `model` on `test_loader`; log loss/accuracy to `writer`."""
    model.eval()
    test_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for data, target in test_loader:
            data = data.to(device)
            target = target.to(device)

            output = model(data)
            output = output.squeeze(1)

            validate_target_outupt_shapes(output, target)

            test_loss += loss_fn(output, target).item()  # sum up batch loss
            correct += count_correct(output, target)
            total += data.shape[0]

    test_loss /= len(test_loader)  # average the summed batch losses
    test_acc = correct / total
    logger.info(
        'Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)'.format(
            test_loss, correct, total, 100. * test_acc))

    writer.add_scalar(f'loss/test{log_suffix}', test_loss, epoch)
    writer.add_scalar(f'acc/test{log_suffix}', test_acc, epoch)


def train(model,
          device,
          train_loaders,
          optimizer,
          epoch,
          writer,
          scale_grad_inverse_sparsity,
          n_agreement_envs,
          loss_fn,
          l1_coef,
          method,
          agreement_threshold,
          scheduler,
          log_suffix=''):
    """n_agreement_envs is the number of envs used to compute agreements"""
    assert len(
        train_loaders
    ) % n_agreement_envs == 0  # Divisibility makes it more convenient
    model.train()

    losses = []
    correct = 0
    example_count = 0
    batch_idx = 0

    train_iterators = [iter(loader) for loader in train_loaders]
    it_groups = permutation_groups(train_iterators, n_agreement_envs)
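    # `permutation_groups` (defined elsewhere) is expected to cycle through groups
    # of `n_agreement_envs` iterators, so each step draws one batch per environment
    # and gradient agreement can be computed across those environments.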

    while True:
        train_iterator_selection = next(it_groups)
        try:
            datas = [next(iterator) for iterator in train_iterator_selection]
        except StopIteration:
            break

        assert len(datas) == n_agreement_envs

        batch_size = datas[0][0].shape[0]
        assert all(d[0].shape[0] == batch_size for d in datas)

        inputs = [d[0].to(device) for d in datas]
        target = [d[1].to(device) for d in datas]

        inputs = torch.cat(inputs, dim=0)
        target = torch.cat(target, dim=0)

        optimizer.zero_grad()

        output = model(inputs)
        output = output.squeeze(1)
        validate_target_outupt_shapes(output, target)

        mean_loss, masks = get_grads(
            agreement_threshold,
            batch_size,
            loss_fn,
            n_agreement_envs,
            params=optimizer.param_groups[0]['params'],
            output=output,
            target=target,
            method=method,
            scale_grad_inverse_sparsity=scale_grad_inverse_sparsity,
        )
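        # `get_grads` (defined elsewhere) is assumed to backpropagate the
        # per-environment losses and write the (agreement-masked) gradients into
        # the parameters' `.grad` fields; only the returned mean loss is used
        # below, and the returned masks are currently unused.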
        model.step += 1

        if l1_coef > 0.0:
            add_l1_grads(l1_coef, optimizer.param_groups)

        optimizer.step()

        losses.append(mean_loss.item())
        correct += count_correct(output, target)
        example_count += output.shape[0]
        batch_idx += 1

    scheduler.step()

    # Logging
    train_loss = np.mean(losses)
    train_acc = correct / (example_count + 1e-10)
    # Log the global L2 norm of the model weights under 'weight/norm'
    # (the original line logged `train_loss` under this tag).
    weight_norm = torch.norm(
        torch.stack([p.detach().norm() for p in model.parameters()]))
    writer.add_scalar('weight/norm', weight_norm, epoch)
    writer.add_scalar(f'mean_loss/train{log_suffix}', train_loss, epoch)
    writer.add_scalar(f'acc/train{log_suffix}', train_acc, epoch)
    logger.info(
        f'Train Epoch: {epoch}\tAcc: {train_acc:.4f}\tLoss: {train_loss:.6f}')
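

# A minimal usage sketch (illustrative only, not part of the training code):
# `net`, `train_loaders`, `val_loader`, `n_epochs`, and the hyperparameter
# values below are hypothetical placeholders, and the `method` string is an
# assumption about what `get_grads` accepts.
#
#     from torch.utils.tensorboard import SummaryWriter
#
#     writer = SummaryWriter(log_dir='runs/example')
#     optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)
#     scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30)
#     loss_fn = torch.nn.BCEWithLogitsLoss()
#     for epoch in range(1, n_epochs + 1):
#         train(net, device, train_loaders, optimizer, epoch, writer,
#               scale_grad_inverse_sparsity=True, n_agreement_envs=2,
#               loss_fn=loss_fn, l1_coef=0.0, method='and_mask',
#               agreement_threshold=1.0, scheduler=scheduler)
#         run_test(net, device, val_loader, writer, epoch, loss_fn)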