# Example #1
0
def train(model, criterion, optimizer, n_epochs, train_loader, test_loader=None, scheduler=None, noise_rate=0.0):
    """Train ``model`` on noise-corrupted labels, optionally evaluating each epoch.

    Every training batch has its targets replaced by the output of
    ``Noise.symmetric_noise`` before the loss is computed. When ``test_loader``
    is given, the model is evaluated after each epoch against both the clean
    and the noise-corrupted test targets.

    Args:
        model: Module called as ``model(inputs)``; its output is reduced with
            ``outputs.max(1)`` to obtain predicted class indices.
        criterion: Loss callable invoked as ``criterion(outputs, targets)``.
            The per-batch value is weighted by batch size, so it is presumably
            a mean-reduced loss -- confirm against the caller.
        optimizer: Optimizer updating ``model``'s parameters.
        n_epochs: Number of passes over ``train_loader``.
        train_loader: Iterable of ``(inputs, targets)`` batches exposing a
            ``dataset`` attribute with a length.
        test_loader: Optional iterable of ``(inputs, targets)`` batches used
            for per-epoch evaluation.
        scheduler: Optional learning-rate scheduler; stepped once per epoch
            when provided.
        noise_rate: Forwarded to ``Noise`` for both loaders.

    Returns:
        A 5-tuple of per-epoch lists
        ``(train_loss, test_loss, correct, memorized, incorrect)``.
        The last four stay empty when ``test_loader`` is ``None``.
        ``correct`` is the fraction of predictions equal to the clean label,
        ``memorized`` the fraction that miss the clean label but match the
        noisy label, and ``incorrect`` the fraction matching neither.

    Raises:
        ZeroDivisionError: If ``test_loader`` is given but yields no samples.
    """
    train_noise_generator = Noise(train_loader, noise_rate=noise_rate)
    test_noise_generator = Noise(test_loader, noise_rate=noise_rate) if test_loader is not None else None

    train_loss_per_epoch = []
    test_loss_per_epoch = []
    correct_per_epoch = []
    incorrect_per_epoch = []
    memorized_per_epoch = []

    for _ in tqdm(range(n_epochs)):
        # ---- training pass ----
        model.train()
        train_loss = 0
        for batch_idx, (inputs, targets) in enumerate(train_loader):
            targets_with_noise = train_noise_generator.symmetric_noise(targets, batch_idx)
            # Move the batch to the module-level ``device``; the model is
            # trained against the noisy labels, not the clean ones.
            inputs, targets = inputs.to(device), targets_with_noise.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()
            # Weight by batch size so that the epoch value is a per-sample average.
            train_loss += loss.item() * targets.size(0)
        train_loss_per_epoch.append(train_loss / len(train_loader.dataset))

        # ---- evaluation pass ----
        if test_loader is not None:
            model.eval()
            test_loss = 0
            correct, incorrect, memorized, total = 0, 0, 0, 0
            with torch.no_grad():
                for batch_idx, (inputs, targets) in enumerate(test_loader):
                    # Clone before injecting noise: ``.to(device)`` may return
                    # the very same tensor, so an in-place noise implementation
                    # would otherwise overwrite the clean labels as well.
                    original_targets = targets.clone().to(device)
                    targets_with_noise = test_noise_generator.symmetric_noise(targets, batch_idx)
                    inputs, targets = inputs.to(device), targets_with_noise.to(device)
                    outputs = model(inputs)
                    # Test loss is measured against the noisy labels.
                    loss = criterion(outputs, targets)

                    _, predicted = outputs.max(1)
                    total += targets.size(0)
                    wrong_idx = predicted != original_targets
                    correct_idx = predicted.eq(original_targets)
                    # "Memorized": misses the clean label but reproduces the noisy one.
                    memorized_idx = wrong_idx & (predicted == targets)
                    incorrect_idx = wrong_idx & (predicted != targets)
                    correct += correct_idx.sum().item()
                    memorized += memorized_idx.sum().item()
                    incorrect += incorrect_idx.sum().item()
                    test_loss += loss.item() * targets.size(0)

            test_loss_per_epoch.append(test_loss / total)
            correct_per_epoch.append(correct / total)
            memorized_per_epoch.append(memorized / total)
            incorrect_per_epoch.append(incorrect / total)

        # Anneal the learning rate; the scheduler is optional (defaults to None),
        # so only step it when one was actually supplied.
        if scheduler is not None:
            scheduler.step()

    return (train_loss_per_epoch, test_loss_per_epoch,
            correct_per_epoch, memorized_per_epoch, incorrect_per_epoch,)