def train(model, criterion, optimizer, n_epochs, train_loader, test_loader=None, scheduler=None, noise_rate=0.0):
    """Train ``model`` on noisy labels and optionally track evaluation metrics.

    For every epoch the model is optimised on ``train_loader`` using targets
    returned by ``Noise.symmetric_noise``. When ``test_loader`` is given, the
    model is then evaluated on it and each prediction is classified as:

    * correct   -- prediction equals the original target;
    * memorized -- prediction differs from the original target but equals the
      noisy target;
    * incorrect -- prediction differs from both the original and noisy target.

    Args:
        model: Network to train; called as ``model(inputs)``.
        criterion: Loss callable, used as ``criterion(outputs, targets)``.
            NOTE(review): the per-sample averaging below presumably assumes a
            mean-reduced loss -- confirm against the caller.
        optimizer: Optimizer whose ``zero_grad()``/``step()`` are called once
            per training batch.
        n_epochs: Number of passes over ``train_loader``.
        train_loader: Iterable of ``(inputs, targets)`` batches; must expose a
            ``dataset`` attribute supporting ``len()``.
        test_loader: Optional iterable of ``(inputs, targets)`` batches used
            for evaluation after each epoch.
        scheduler: Optional learning-rate scheduler; ``step()`` is called once
            per epoch when provided.
        noise_rate: Value forwarded to ``Noise`` for both loaders.

    Returns:
        A 5-tuple of per-epoch lists ``(train_loss, test_loss, correct,
        memorized, incorrect)``. The last four lists stay empty when
        ``test_loader`` is ``None``; ``correct``/``memorized``/``incorrect``
        are fractions of the evaluated samples.
    """
    train_noise_generator = Noise(train_loader, noise_rate=noise_rate)
    test_noise_generator = (
        Noise(test_loader, noise_rate=noise_rate) if test_loader is not None else None
    )

    train_loss_per_epoch = []
    test_loss_per_epoch = []
    correct_per_epoch = []
    incorrect_per_epoch = []
    memorized_per_epoch = []

    for _ in tqdm(range(n_epochs)):
        # activate train mode
        model.train()
        train_loss = 0
        for batch_idx, (inputs, targets) in enumerate(train_loader):
            targets_with_noise = train_noise_generator.symmetric_noise(targets, batch_idx)
            # Move the batch to the module-level `device`; training uses the
            # noisy targets only.
            inputs, targets = inputs.to(device), targets_with_noise.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()

            # Weight the batch loss by batch size so the epoch value is a
            # per-sample average even when the last batch is smaller.
            train_loss += loss.item() * targets.size(0)

        train_loss_per_epoch.append(train_loss / len(train_loader.dataset))

        if test_loader is not None:
            model.eval()
            test_loss = 0
            correct, incorrect, memorized, total = 0, 0, 0, 0
            with torch.no_grad():
                for batch_idx, (inputs, targets) in enumerate(test_loader):
                    # Keep the original labels to tell genuine hits apart
                    # from predictions that merely reproduce the noisy label.
                    original_targets = targets.to(device)
                    targets_with_noise = test_noise_generator.symmetric_noise(targets, batch_idx)
                    inputs, targets = inputs.to(device), targets_with_noise.to(device)

                    outputs = model(inputs)
                    loss = criterion(outputs, targets)
                    _, predicted = outputs.max(1)
                    total += targets.size(0)

                    wrong_idx = predicted != original_targets
                    correct_idx = predicted.eq(original_targets)
                    memorized_idx = wrong_idx & (predicted == targets)
                    incorrect_idx = wrong_idx & (predicted != targets)

                    correct += correct_idx.sum().item()
                    memorized += memorized_idx.sum().item()
                    incorrect += incorrect_idx.sum().item()
                    test_loss += loss.item() * targets.size(0)

            test_loss_per_epoch.append(test_loss / total)
            correct_per_epoch.append(correct / total)
            memorized_per_epoch.append(memorized / total)
            incorrect_per_epoch.append(incorrect / total)

        # anneal learning rate; `scheduler` is optional (defaults to None),
        # so only step it when one was actually supplied.
        if scheduler is not None:
            scheduler.step()

    return (
        train_loss_per_epoch,
        test_loss_per_epoch,
        correct_per_epoch,
        memorized_per_epoch,
        incorrect_per_epoch,
    )