def val(net, dataloader_):
    """Run one validation pass and checkpoint the model on a new best IoU.

    Puts ``net`` in eval mode and iterates ``dataloader_`` without gradient
    tracking, accumulating per-batch loss and IoU. It then plots the epoch
    means to ``board_loss`` and switches ``net`` back to training mode. If the
    mean IoU exceeds the module-level ``highest_iou``, that global is updated
    and the model's ``state_dict`` is saved under ``checkpoint/``.

    Args:
        net: the network to evaluate.
        dataloader_: iterable yielding ``(img, mask, _)`` batches.
    """
    global highest_iou
    net.eval()
    iou_meter_val = AverageValueMeter()
    loss_meter_val = AverageValueMeter()
    on_gpu = torch.cuda.is_available() and use_cuda
    # Evaluation only: without no_grad every forward pass would build an
    # autograd graph that is never back-propagated, wasting memory.
    with torch.no_grad():
        for i, (img, mask, _) in tqdm(enumerate(dataloader_)):
            if on_gpu:
                img, mask = img.cuda(), mask.cuda()
            pred_val = net(img)
            target = mask.squeeze(1)
            loss_val = criterion(pred_val, target)
            loss_meter_val.add(loss_val.item())
            seg_val = pred2segmentation(pred_val)
            # NOTE(review): element [1] of iou_loss's result is used as the
            # IoU score here — confirm against iou_loss's return value.
            iou_val = iou_loss(seg_val, target.float(), class_number)[1]
            iou_meter_val.add(iou_val)
            if i % val_print_frequncy == 0:
                showImages(board_val_image, img, mask, seg_val)

    mean_iou = iou_meter_val.value()[0]
    board_loss.plot('val_iou_per_epoch', mean_iou)
    board_loss.plot('val_loss_per_epoch', loss_meter_val.value()[0])
    net.train()
    if highest_iou < mean_iou:
        highest_iou = mean_iou
        torch.save(
            net.state_dict(), 'checkpoint/modified_ENet_%.3f_%s.pth' %
            (mean_iou, 'equal_' + str(Equalize)))
        print('The highest IOU is:%.3f' % mean_iou,
              'Model saved.')


def train():
    """Train the module-level ``net`` for ``max_epoch`` epochs.

    Each epoch iterates ``train_loader`` with ``optimiser``/``criterion``,
    records per-batch loss and IoU, and plots the epoch means to
    ``board_loss``. It then calls ``val(net, val_loader)``, which also
    restores training mode on ``net``.
    """
    net.train()
    iou_meter = AverageValueMeter()
    loss_meter = AverageValueMeter()
    on_gpu = torch.cuda.is_available() and use_cuda
    for epoch in range(max_epoch):
        iou_meter.reset()
        loss_meter.reset()
        if epoch % 5 == 0:
            # NOTE(review): this scales the *current* lr by
            # 0.95**(epoch // 10), so the decay compounds (epochs 10 and 15
            # each apply 0.95, epoch 20 applies 0.95**2, ...). If the intent
            # was lr = base_lr * 0.95**(epoch // 10), store the base lr and
            # use it here instead — confirm before changing.
            for param_group in optimiser.param_groups:
                param_group['lr'] = param_group['lr'] * (0.95**(epoch // 10))
                print('learning rate:', param_group['lr'])

        for i, (img, mask, _) in tqdm(enumerate(train_loader)):
            if on_gpu:
                img, mask = img.cuda(), mask.cuda()
            target = mask.squeeze(1)
            optimiser.zero_grad()
            pred = net(img)
            loss = criterion(pred, target)
            loss.backward()
            # torch.nn.utils.clip_grad_norm_(net.parameters(), 1e-3)
            optimiser.step()
            # Record each batch loss exactly once so the meter's sample count
            # matches the number of batches.
            loss_meter.add(loss.item())
            seg = pred2segmentation(pred)
            # NOTE(review): element [1] of iou_loss's result is used as the
            # IoU score here — confirm against iou_loss's return value.
            iou = iou_loss(seg, target.float(), class_number)[1]
            iou_meter.add(iou)

            if i % train_print_frequncy == 0:
                showImages(board_train_image, img, mask, seg)

        board_loss.plot('train_iou_per_epoch', iou_meter.value()[0])
        board_loss.plot('train_loss_per_epoch', loss_meter.value()[0])

        val(net, val_loader)