Example #1
def process_function(unused_engine, batch):
    # Per-batch update step; `model`, `optimizer`, `criterion`, `device`,
    # `config`, `args`, and `_prepare_batch` come from the enclosing scope.
    x, y = _prepare_batch(batch, device=device, non_blocking=True)
    model.train()
    optimizer.zero_grad()
    y_pred = model(x)

    if config['agreement_threshold'] > 0.0:
        # The "batch_size" in this function refers to the batch size per env
        # Since we treat every example as one env, we should set the parameter
        # n_agreement_envs equal to batch size
        mean_loss, masks = and_mask_utils.get_grads(
            agreement_threshold=config['agreement_threshold'],
            batch_size=1,
            loss_fn=criterion,
            n_agreement_envs=config['batch_size'],
            params=optimizer.param_groups[0]['params'],
            output=y_pred,
            target=y,
            method=args.method,
            scale_grad_inverse_sparsity=config['scale_grad_inverse_sparsity'],
        )
    else:
        mean_loss = criterion(y_pred, y)
        mean_loss.backward()

    optimizer.step()

    return {}
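
For orientation, here is a minimal, self-contained sketch of the sign-agreement masking that an AND-mask style `get_grads` performs: gradients are computed per environment, and components whose signs do not agree strongly enough across environments are zeroed out. The `agreement_mask` helper below and its exact thresholding rule are illustrative assumptions, not this project's actual API; in the example above, passing `batch_size=1` with `n_agreement_envs=config['batch_size']` simply treats every example in the batch as its own environment.

import torch

def agreement_mask(env_grads, agreement_threshold):
    """Toy sign-agreement mask (illustrative only).

    env_grads: tensor of shape (n_envs, n_params), one gradient row per environment.
    Returns a {0, 1} mask keeping components whose mean gradient sign has magnitude
    at least `agreement_threshold` (1.0 means all environments must agree on the sign).
    """
    signs = torch.sign(env_grads)              # (n_envs, n_params)
    mean_sign = signs.mean(dim=0)              # values in [-1, 1]
    return (mean_sign.abs() >= agreement_threshold).float()

# Toy usage: 4 "environments", 3 parameters.
grads = torch.tensor([[0.3,  0.1, -0.2],
                      [0.5, -0.4, -0.1],
                      [0.2,  0.3, -0.3],
                      [0.4,  0.2, -0.2]])
mask = agreement_mask(grads, agreement_threshold=1.0)
masked_grad = mask * grads.mean(dim=0)   # averaged gradient, zeroed where signs disagree
print(mask)          # tensor([1., 0., 1.])
print(masked_grad)   # tensor([ 0.3500,  0.0000, -0.2000])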
Example #2
def train(model,
          device,
          train_loaders,
          optimizer,
          epoch,
          writer,
          scale_grad_inverse_sparsity,
          n_agreement_envs,
          loss_fn,
          l1_coef,
          method,
          agreement_threshold,
          scheduler,
          log_suffix=''):
    """n_agreement_envs is the number of envs used to compute agreements"""
    # Divisibility makes grouping the environment loaders more convenient.
    assert len(train_loaders) % n_agreement_envs == 0
    model.train()

    losses = []
    correct = 0
    example_count = 0
    batch_idx = 0

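    # One iterator per environment loader; permutation_groups yields groups of
    # n_agreement_envs of them, one group per training step.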
    train_iterators = [iter(loader) for loader in train_loaders]
    it_groups = permutation_groups(train_iterators, n_agreement_envs)

    while True:
        train_iterator_selection = next(it_groups)
        try:
            datas = [next(iterator) for iterator in train_iterator_selection]
        except StopIteration:
            break

        assert len(datas) == n_agreement_envs

        batch_size = datas[0][0].shape[0]
        assert all(d[0].shape[0] == batch_size for d in datas)

        inputs = [d[0].to(device) for d in datas]
        target = [d[1].to(device) for d in datas]

        inputs = torch.cat(inputs, dim=0)
        target = torch.cat(target, dim=0)

        optimizer.zero_grad()

        output = model(inputs)
        output = output.squeeze(1)
        validate_target_outupt_shapes(output, target)

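        # get_grads computes per-environment gradients and applies the agreement
        # mask; the resulting gradients are left in the params' .grad fields, so
        # optimizer.step() can be called without an explicit backward().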
        mean_loss, masks = get_grads(
            agreement_threshold,
            batch_size,
            loss_fn,
            n_agreement_envs,
            params=optimizer.param_groups[0]['params'],
            output=output,
            target=target,
            method=method,
            scale_grad_inverse_sparsity=scale_grad_inverse_sparsity,
        )
        model.step += 1

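        # Optionally add an L1 penalty directly to the gradients before stepping.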
        if l1_coef > 0.0:
            add_l1_grads(l1_coef, optimizer.param_groups)

        optimizer.step()

        losses.append(mean_loss.item())
        correct += count_correct(output, target)
        example_count += output.shape[0]
        batch_idx += 1

    scheduler.step()

    # Logging
    train_loss = np.mean(losses)
    train_acc = correct / (example_count + 1e-10)
    writer.add_scalar('weight/norm', train_loss, epoch)
    writer.add_scalar(f'mean_loss/train{log_suffix}', train_loss, epoch)
    writer.add_scalar(f'acc/train{log_suffix}', train_acc, epoch)
    logger.info(
        f'Train Epoch: {epoch}\t Acc: {train_acc:.4} \tLoss: {train_loss:.6f}')
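
`permutation_groups` is used above but not defined in this listing. A plausible minimal version, assuming it simply keeps yielding groups of `n` iterators from reshuffled passes over the environment iterators (the shuffling policy is a guess, not the project's actual helper), is:

import random

def permutation_groups(items, n):
    """Yield successive groups of n items, reshuffling after each full pass.

    Assumes len(items) is divisible by n, matching the assert in train().
    """
    items = list(items)
    while True:
        random.shuffle(items)
        for i in range(0, len(items), n):
            yield items[i:i + n]

Because this generator never terminates on its own, the `while True` loop in `train` ends only when one of the per-environment iterators raises `StopIteration`.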
Example #3
def train(model,
          args,
          device,
          train_loader,
          optimizer,
          epoch,
          writer,
          scale_grad_inverse_sparsity,
          loss_fn,
          method,
          agreement_threshold,
          scheduler,
          log_suffix=''):
    model.train()
    losses = []
    correct = 0
    example_count = 0
    batch_idx = 0

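    # Each batch of images is flattened to (batch, args.n_dims) feature vectors;
    # with agreement_threshold > 0, every example is then treated as its own environment.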
    for i, (images, labels) in enumerate(train_loader):
        images = images.reshape(-1, args.n_dims)
        optimizer.zero_grad()
        y_pred = model(images)
        if agreement_threshold > 0.0:
            # The "batch_size" in this function refers to the batch size per env
            # Since we treat every example as one env, we should set the parameter
            # n_agreement_envs equal to batch size
            mean_loss, masks = and_mask_utils.get_grads(
                agreement_threshold=agreement_threshold,
                batch_size=1,
                loss_fn=loss_fn,
                n_agreement_envs=args.batch_size,
                params=optimizer.param_groups[0]['params'],
                output=y_pred,
                target=labels,
                method=args.method,
                scale_grad_inverse_sparsity=scale_grad_inverse_sparsity,
            )
        else:
            mean_loss = loss_fn(y_pred, labels)
            mean_loss.backward()

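        # Accumulate the full objective (task loss plus any L1/L2 penalties) purely for logging.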
        mean_total_loss = 0

        if args.l1_coef > 0.0:
            add_l1_grads(args.l1_coef, optimizer.param_groups)
            mean_total_loss += add_l1(args.l1_coef, optimizer.param_groups)
        if args.l2_coef > 0.0:
            add_l2_grads(args.l2_coef, optimizer.param_groups)
            mean_total_loss += add_l2(args.l2_coef, optimizer.param_groups)

        mean_total_loss += mean_loss.item()

        optimizer.step()

        losses.append(mean_total_loss)
        correct += count_correct(y_pred, labels)
        example_count += y_pred.shape[0]
        batch_idx += 1

    scheduler.step()

    # Logging
    train_loss = np.mean(losses)
    train_acc = correct / (example_count + 1e-10)
    writer.add_scalar('weight/norm', train_loss, epoch)
    writer.add_scalar(f'mean_loss/train{log_suffix}', train_loss, epoch)
    writer.add_scalar(f'acc/train{log_suffix}', train_acc, epoch)
    logger.info(
        f'Train Epoch: {epoch}\t Acc: {train_acc:.4} \tLoss: {train_loss:.6f}')
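
The `add_l1_grads` / `add_l1` and `add_l2_grads` / `add_l2` helpers are not shown in this listing. One reasonable reading, sketched below purely as an assumption, is that the `*_grads` variants add the penalty's (sub)gradient to each parameter's `.grad` while the plain variants return the scalar penalty that is folded into `mean_total_loss` for logging:

import torch

def add_l1_grads(l1_coef, param_groups):
    # Add the L1 subgradient, l1_coef * sign(w), on top of the existing gradients.
    for group in param_groups:
        for p in group['params']:
            if p.grad is not None:
                p.grad.add_(l1_coef * torch.sign(p.data))

def add_l1(l1_coef, param_groups):
    # Scalar L1 penalty, used here only for logging.
    return sum(l1_coef * p.data.abs().sum().item()
               for group in param_groups for p in group['params'])

def add_l2_grads(l2_coef, param_groups):
    # Add the L2 gradient, 2 * l2_coef * w, on top of the existing gradients.
    for group in param_groups:
        for p in group['params']:
            if p.grad is not None:
                p.grad.add_(2.0 * l2_coef * p.data)

def add_l2(l2_coef, param_groups):
    # Scalar L2 penalty, used here only for logging.
    return sum(l2_coef * (p.data ** 2).sum().item()
               for group in param_groups for p in group['params'])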