tb_hist.plot_weights_hist()
            # NOTE(review): the tb_hist.plot_*_hist() calls here (weights, grads,
            # bias, outputs) presumably log histograms for the current interval
            # -- confirm against the tb_hist helper's definition.
            tb_hist.plot_grads_hist()
            tb_hist.plot_bias_hist()
            tb_hist.plot_outputs_hist()

            # Validation loss: run every batch of valLoader through the model
            # with gradient tracking disabled, average the criterion over the
            # batches, and log the mean as the "val_loss" scalar at
            # interval_count.
            # NOTE(review): model.eval() is not called in the visible code; if
            # the model uses dropout or batch-norm, confirm the mode is switched
            # elsewhere before this pass.
            with torch.no_grad():
                running_val_loss = 0.0
                running_val_steps = 0.0
                for data in tqdm(valLoader, desc='val-step'):
                    input_ids = data['input_ids'].long().to(device)
                    segments = data['segments'].long().to(device)
                    targets = data['targets'].float().to(device)

                    outputs = model(input_ids=input_ids,
                                    token_type_ids=segments)
                    targets = targets.view(-1, 1)  # match output shape

                    val_loss = criterion(outputs, targets)
                    running_val_loss += val_loss.item()
                    running_val_steps += 1.0

                # An empty valLoader would make the mean a division by zero;
                # skip logging for this interval in that case.
                if running_val_steps > 0:
                    writer.add_scalar("val_loss",
                                      running_val_loss / running_val_steps,
                                      interval_count)

            # Move on to the next logging interval and zero the running
            # loss/step accumulators.
            interval_count += 1
            running_loss = running_steps = 0.0

# Final status message, emitted once the indented code above has finished.
print("Training complete.")