import torch
import torch.nn as nn


def train_epoch(model, optimizer, epoch, train_loader, device, model_type):
    avg_loss = 0.0
    model.train()
    for i, mini_batch in enumerate(train_loader):
        if model_type == "prediction":
            inputs, targets, mask = mini_batch
        else:
            inputs, targets, mask, text, text_mask = mini_batch
            text = text.to(device)
            text_mask = text_mask.to(device)
        inputs = inputs.to(device)
        targets = targets.to(device)
        mask = mask.to(device)
        batch_size = inputs.shape[0]

        optimizer.zero_grad()

        if model_type == "prediction":
            initial_hidden = model.init_hidden(batch_size, device)
            y_hat, state = model(inputs, initial_hidden)
        else:
            initial_hidden, window_vector, kappa = model.init_hidden(
                batch_size, device)
            y_hat, state, window_vector, kappa = model(
                inputs, text, text_mask, initial_hidden, window_vector, kappa)

        loss = compute_nll_loss(targets, y_hat, mask)

        # Output gradient clipping: clamp the gradient flowing into the
        # network output to [-100, 100] during backward().
        y_hat.register_hook(lambda grad: torch.clamp(grad, -100, 100))

        loss.backward()

        # LSTM (and window layer) parameter gradient clipping to [-10, 10].
        if model_type == "prediction":
            nn.utils.clip_grad_value_(model.parameters(), 10)
        else:
            nn.utils.clip_grad_value_(model.lstm_1.parameters(), 10)
            nn.utils.clip_grad_value_(model.lstm_2.parameters(), 10)
            nn.utils.clip_grad_value_(model.lstm_3.parameters(), 10)
            nn.utils.clip_grad_value_(model.window_layer.parameters(), 10)

        optimizer.step()

        avg_loss += loss.item()

        # Print every 10 mini-batches.
        if i % 10 == 0:
            print("\t[MiniBatch: {:3d}] loss: {:.3f}".format(
                i + 1, loss.item() / batch_size))

    avg_loss /= len(train_loader.dataset)
    return avg_loss
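
# Illustration (not part of the original code): how the y_hat.register_hook
# clamp above behaves. The hook rewrites the gradient flowing *into* y_hat
# during backward(), clamping each element to [-100, 100] before it is
# propagated further back through the network; parameter gradients are then
# clipped separately via clip_grad_value_.
def _demo_output_grad_clip():
    x = torch.tensor([1.0, 2.0], requires_grad=True)
    y = x * 3.0
    y.register_hook(lambda grad: torch.clamp(grad, -100, 100))
    loss = (y * 1000.0).sum()  # dloss/dy = 1000 per element before the hook
    loss.backward()
    # The hook clamps dloss/dy to 100, so x.grad = 100 * 3 = 300 per element.
    print(x.grad)  # tensor([300., 300.])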
def validation(model, valid_loader, device, epoch, model_type):
    avg_loss = 0.0
    model.eval()
    with torch.no_grad():
        for i, mini_batch in enumerate(valid_loader):
            if model_type == "prediction":
                inputs, targets, mask = mini_batch
            else:
                inputs, targets, mask, text, text_mask = mini_batch
                text = text.to(device)
                text_mask = text_mask.to(device)
            inputs = inputs.to(device)
            targets = targets.to(device)
            mask = mask.to(device)
            batch_size = inputs.shape[0]

            if model_type == "prediction":
                initial_hidden = model.init_hidden(batch_size, device)
                y_hat, state = model(inputs, initial_hidden)
            else:
                initial_hidden, window_vector, kappa = model.init_hidden(
                    batch_size, device)
                y_hat, state, window_vector, kappa = model(
                    inputs, text, text_mask, initial_hidden, window_vector,
                    kappa)

            loss = compute_nll_loss(targets, y_hat, mask)
            avg_loss += loss.item()

            # Print every 10 mini-batches.
            if i % 10 == 0:
                print("[{:d}, {:5d}] loss: {:.3f}".format(
                    epoch + 1, i + 1, loss.item() / batch_size))

    avg_loss /= len(valid_loader.dataset)
    return avg_loss
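
# Sketch (an assumption, not original code): a minimal driver that wires the
# two functions above into a training run. The loaders, model, and
# hyperparameters (num_epochs, lr, checkpoint path) are placeholders for
# whatever the rest of the project defines.
def run_training(model, train_loader, valid_loader, device,
                 model_type, num_epochs=100, lr=1e-4):
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    best_loss = float("inf")
    for epoch in range(num_epochs):
        train_loss = train_epoch(model, optimizer, epoch, train_loader,
                                 device, model_type)
        valid_loss = validation(model, valid_loader, device, epoch,
                                model_type)
        print("[Epoch {:3d}] train: {:.3f}  valid: {:.3f}".format(
            epoch + 1, train_loss, valid_loss))
        if valid_loss < best_loss:  # keep the best checkpoint on disk
            best_loss = valid_loss
            torch.save(model.state_dict(),
                       "best_{}_model.pt".format(model_type))
    return best_loss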