import tqdm
import torch

# Note: `Metrics` is assumed to be imported from this project's metrics module;
# it must provide add(mask, prediction) plus get_precision/get_recall/get_f_score/get_oa.


def train(loader, num_classes, device, net, optimizer, criterion):
    """Runs one training epoch over `loader` and returns averaged loss and metrics."""
    num_samples = 0
    running_loss = 0

    metrics = Metrics(range(num_classes))

    net.train()
    for images1, images2, masks in tqdm.tqdm(loader):
        images1 = images1.to(device)
        images2 = images2.to(device)
        masks = masks.to(device)

        assert images1.size()[2:] == images2.size()[2:] == masks.size()[2:], "resolutions for images and masks are in sync"

        num_samples += int(images1.size(0))

        optimizer.zero_grad()
        outputs = net(images1, images2)
        
        assert outputs.size()[2:] == masks.size()[2:], "resolutions for predictions and masks are in sync"
        assert outputs.size()[1] == num_classes, "classes for predictions and dataset are in sync"

        # BCE-style criteria expect float targets; for a CrossEntropyLoss-style
        # criterion, pass masks.long() instead.
        loss = criterion(outputs, masks.float())
        loss.backward()

        optimizer.step()

        running_loss += loss.item()

        # Accumulate per-sample statistics; detach predictions from the graph first.
        for mask, output in zip(masks, outputs):
            metrics.add(mask, output.detach())

    assert num_samples > 0, "dataset contains training images and labels"

    return {
        "loss": running_loss / num_samples,
        "precision": metrics.get_precision(),
        "recall": metrics.get_recall(),
        "f_score": metrics.get_f_score(),
        "oa":metrics.get_oa()
    }


@torch.no_grad()  # validation should not build a computation graph
def validate(loader, num_classes, device, net, criterion):
    """Runs one validation pass over `loader` and returns averaged loss and metrics."""
    num_samples = 0
    running_loss = 0

    metrics = Metrics(range(num_classes))

    net.eval()

    for images1, images2, masks in tqdm.tqdm(loader):
        images1 = images1.to(device)
        images2 = images2.to(device)
        masks = masks.to(device)

        assert images1.size()[2:] == images2.size()[2:] == masks.size()[2:], "resolutions for images and masks are in sync"

        num_samples += int(images1.size(0))

        outputs = net(images1, images2)

        assert outputs.size()[2:] == masks.size()[2:], "resolutions for predictions and masks are in sync"
        assert outputs.size()[1] == num_classes, "classes for predictions and dataset are in sync"

        # BCE-style criteria expect float targets; for a CrossEntropyLoss-style
        # criterion, pass masks.long() instead.
        loss = criterion(outputs, masks.float())
        running_loss += loss.item()

        # Accumulate per-sample statistics (no detach needed under no_grad).
        for mask, output in zip(masks, outputs):
            metrics.add(mask, output)

    assert num_samples > 0, "dataset contains validation images and labels"

    return {
        "loss": running_loss / num_samples,
        "precision": metrics.get_precision(),
        "recall": metrics.get_recall(),
        "f_score": metrics.get_f_score(),
        "oa":metrics.get_oa()
    }
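

# ---------------------------------------------------------------------------
# Usage sketch (illustration only, not part of the original training script).
# Everything below is an assumption: the toy two-branch network, the random
# tensors standing in for a real change-detection dataset, and the chosen
# hyper-parameters are hypothetical placeholders showing how train() and
# validate() are wired into an epoch loop with a BCE-style criterion. It also
# assumes `Metrics` has been imported from the project's metrics module as
# noted at the top of the file.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import torch.nn as nn
    from torch.utils.data import DataLoader, TensorDataset

    num_classes = 1
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    class ToyChangeNet(nn.Module):
        """Hypothetical stand-in: concatenates both epochs and predicts per-class logits."""

        def __init__(self, classes):
            super().__init__()
            self.head = nn.Conv2d(6, classes, kernel_size=1)

        def forward(self, x1, x2):
            return self.head(torch.cat([x1, x2], dim=1))

    net = ToyChangeNet(num_classes).to(device)

    # Random stand-ins for (image at t1, image at t2, change mask).
    images_t1 = torch.randn(8, 3, 32, 32)
    images_t2 = torch.randn(8, 3, 32, 32)
    masks = torch.randint(0, 2, (8, num_classes, 32, 32)).float()

    train_loader = DataLoader(TensorDataset(images_t1, images_t2, masks), batch_size=4, shuffle=True)
    val_loader = DataLoader(TensorDataset(images_t1, images_t2, masks), batch_size=4)

    criterion = nn.BCEWithLogitsLoss()
    optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)

    for epoch in range(2):
        train_hist = train(train_loader, num_classes, device, net, optimizer, criterion)
        val_hist = validate(val_loader, num_classes, device, net, criterion)
        print("epoch", epoch, "train", train_hist, "val", val_hist)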