import torch
from tqdm import tqdm

# `Metrics`, `args`, and `learning_rate_scheduler` are defined elsewhere in this repo.


def train_epoch(model, device, train_loader, criterion, optimizer, k, warm_up, lr, writer, epoch):
    # training phase
    print("Training Progress:")
    metrics = Metrics(args.dataset, train=True)
    model.train()
    for batch_idx, (batch, labels) in enumerate(tqdm(train_loader)):
        iteration = epoch * len(train_loader) + batch_idx
        optimizer.zero_grad()
        batch = batch.float().to(device)
        labels = labels.to(device)
        outputs = model(batch)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        # warm up: ramp the learning rate for the first `warm_up` steps
        if k <= warm_up:
            k = learning_rate_scheduler(optimizer, k, warm_up, lr)
        # Batch metrics
        metrics.update_metrics(outputs, labels, loss)
        if iteration % 10 == 0:
            metrics.write_to_tensorboard(writer, iteration)
    # Epoch metrics
    final_metrics = metrics.get_epoch_metrics()
    return (final_metrics, k)
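
# The warm-up helper called above is not shown in this file. Below is a minimal
# sketch of what it is assumed to do, inferred from the call site: linearly scale
# the learning rate up to `lr` over `warm_up` steps and return the incremented
# step counter `k`. The repo's actual implementation may differ.
def learning_rate_scheduler(optimizer, k, warm_up, lr):
    """Linear warm-up: scale the LR from ~0 to `lr` over `warm_up` steps (assumed behavior)."""
    for param_group in optimizer.param_groups:
        param_group["lr"] = lr * k / warm_up
    return k + 1
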
def validate_epoch(model, device, validation_loader, criterion, scheduler, writer, epoch):
    with torch.no_grad():
        # validation phase
        print("Validation Progress:")
        metrics = Metrics(args.dataset, train=False)
        model.eval()
        for batch_idx, (batch, labels) in enumerate(tqdm(validation_loader)):
            batch = batch.float().to(device)
            labels = labels.to(device)
            outputs = model(batch)
            loss = criterion(outputs, labels)
            # Batch metrics
            metrics.update_metrics(outputs, labels, loss)
        # Epoch metrics
        final_metrics = metrics.get_epoch_metrics()
        metrics.write_to_tensorboard(writer, epoch)
        # Step a plateau-style scheduler (e.g. ReduceLROnPlateau) on the validation loss
        scheduler.step(final_metrics["Loss"])
        return final_metrics
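
# Sketch of how these two functions are assumed to be wired together per epoch.
# The model, the data loaders, `Metrics`, and `args` come from the rest of this
# repo; every concrete value below (the Linear model, epoch count, warm-up
# length, learning rate) is illustrative, not the repo's actual configuration.
if __name__ == "__main__":
    from torch.utils.tensorboard import SummaryWriter

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = torch.nn.Linear(10, 2).to(device)  # placeholder; the repo's model goes here
    criterion = torch.nn.CrossEntropyLoss()
    lr = 1e-3
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    # Plateau scheduler matching the `scheduler.step(val_loss)` call in validate_epoch
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", patience=5)
    writer = SummaryWriter()

    k, warm_up = 1, 1000  # warm-up step counter and length (illustrative values)
    for epoch in range(50):
        train_metrics, k = train_epoch(model, device, train_loader, criterion,
                                       optimizer, k, warm_up, lr, writer, epoch)
        val_metrics = validate_epoch(model, device, validation_loader, criterion,
                                     scheduler, writer, epoch)
    writer.close()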