def evalNetQuant(netWeights, batchSz=4, isPM=False):
    """Score the network with ``evalImgQuant`` over the ``SetType.test`` loader

        Params:
         netWeights (dict): state-dict that is loaded into a fresh ``centerEsti`` model
         batchSz (int): the size of the batch requested from ``getSetLoader``
         isPM (bool): when False the loader's input is ignored and replaced by
                      the mean of the frame stack along axis 1

        Returns:
         the two accumulated ``evalImgQuant`` scores (named PSNR and SSIM here),
         each moved to the CPU, converted to NumPy and divided by the number
         of batches in the loader
    """

    ## Build the network and restore the given weights:
    device = "cuda" if torch.cuda.is_available() else "cpu"
    net = centerEsti()
    net.load_state_dict(netWeights)
    net.to(device)

    loader = getSetLoader(SetType.test, batchSz)

    psnrSum = 0
    ssimSum = 0

    net.eval()
    progress = tqdm(loader, desc='Test', leave=False, ncols=100)
    with torch.no_grad():
        for src, frames in progress:
            if not isPM:
                src = frames.mean(axis=1)  # TODO: Verify if correct!

            # The target is the frame in the middle of the stack (axis 1):
            centre = frames.shape[1] // 2

            inputs = src.to(device)
            targets = frames[:, centre, :, :].to(device)
            outputs = net(inputs)

            ## Per-batch scores:
            batchPsnr, batchSsim = evalImgQuant(targets, outputs)

            psnrSum += batchPsnr
            ssimSum += batchSsim

    meanPsnr = psnrSum.to("cpu").numpy() / len(loader)
    meanSsim = ssimSum.to("cpu").numpy() / len(loader)
    return meanPsnr, meanSsim
Beispiel #2
0
def evalNetQuant(netWeights, batchSz=4, isPM=False):
    """Score the network with ``evalImgQuant`` and LPIPS over the ``SetType.test`` loader

        Params:
         netWeights (dict): state-dict that is loaded into a fresh ``centerEsti`` model
         batchSz (int): the size of the batch requested from ``getSetLoader``
         isPM (bool): when False the loader's input is ignored and replaced by
                      the mean of the frame stack along axis 1

        Returns:
         four values, in this order: the two ``evalImgQuant`` scores (named
         PSNR and SSIM here) followed by the ``percSim`` scores for the
         'alex' and 'vgg' LPIPS nets. Each accumulated total is moved to the
         CPU, converted to NumPy and divided by the number of batches.
    """

    ## Build the network and restore the given weights:
    device = "cuda" if torch.cuda.is_available() else "cpu"
    net = centerEsti()
    net.load_state_dict(netWeights)
    net.to(device)

    loader = getSetLoader(SetType.test, batchSz)

    psnrSum = 0
    ssimSum = 0
    alexSum = 0
    vggSum = 0

    alexMetric = lpips.LPIPS(net='alex').to(device)  # best forward scores
    vggMetric = lpips.LPIPS(net='vgg').to(device)  # closer to "traditional" perceptual loss, when used for optimization

    def perBatchMean(total):
        # Bring the accumulated total to the host and average it over the batches
        return total.to("cpu").numpy() / len(loader)

    net.eval()
    progress = tqdm(loader, desc='Test', leave=False, ncols=100)
    with torch.no_grad():
        for src, frames in progress:
            if not isPM:
                src = frames.mean(axis=1)  # TODO: Verify if correct!

            # The target is the frame in the middle of the stack (axis 1):
            centre = frames.shape[1] // 2

            inputs = src.to(device)
            targets = frames[:, centre, :, :].to(device)
            outputs = net(inputs)

            ## Per-batch scores:
            batchPsnr, batchSsim = evalImgQuant(targets, outputs)

            psnrSum += batchPsnr
            ssimSum += batchSsim

            # NOTE: only the first sample of every batch enters the perceptual scores
            firstTarget, firstOutput = targets[0], outputs[0]
            alexSum += percSim(firstTarget, firstOutput, alexMetric)
            vggSum += percSim(firstTarget, firstOutput, vggMetric)

    return (perBatchMean(psnrSum), perBatchMean(ssimSum),
            perBatchMean(alexSum), perBatchMean(vggSum))
Beispiel #3
0
def testLoss(netWeights, percWeight=3 / 4, batchSz=8):
    """Average the losses of the provided network over the ``SetType.val`` loader

        Params:
         netWeights (dict): state-dict that is loaded into a fresh ``centerEsti`` model
         percWeight (float): the weight of the perceptual loss in the total loss
         batchSz (int): the size of the batch requested from ``getSetLoader``

        Returns:
         the per-batch mean of the L2 loss, of the perceptual loss and of the
         total loss (``l2 + percWeight * perceptual``), in that order
    """

    ## Build the network and restore the given weights:
    device = "cuda" if torch.cuda.is_available() else "cpu"
    net = centerEsti()
    net.load_state_dict(netWeights)
    net.to(device)

    loader = getSetLoader(SetType.val, batchSz)

    l2Sum = 0
    percSum = 0
    lossSum = 0

    net.eval()
    progress = tqdm(loader, desc='validation', leave=False, ncols=100)
    with torch.no_grad():
        for src, frames in progress:
            # The target is the frame in the middle of the stack (axis 1):
            centre = frames.shape[1] // 2

            inputs = src.to(device)
            targets = frames[:, centre, :, :].to(device)
            outputs = net(inputs)

            ## Loss terms for this batch:
            l2Term = lossFuncs.l2Loss(outputs, targets)
            percTerm = lossFuncs.percLoss(outputs, targets)
            batchLoss = (l2Term + percWeight * percTerm).item()

            l2Sum += l2Term.item()
            percSum += percTerm.item()
            lossSum += batchLoss
            progress.set_description(f'validation loss {batchLoss:.2}')

    nBatches = len(loader)
    return l2Sum / nBatches, percSum / nBatches, lossSum / nBatches
Beispiel #4
0
def train(nEpochs,
          optParams,
          batchSz=8,
          valEpochFact=1,
          percWeight=3 / 4,
          checkpoint=None,
          outpath=None):
    """The training routine for the video-from-image network

        Params:
         nEpochs (int): the number of desired epochs to run the training
                        (used as an array dimension, so it must be an integer)
         optParams (dict): a dictionary containing the required optimizer arguments
         batchSz (int): the size of the batch
         valEpochFact (int): validation runs on every epoch whose absolute index
                             is divisible by this value
         percWeight (float): the weight of the perceptual loss from the total-loss
         checkpoint (str): path of a checkpoint file (as written by this routine)
                           to initialize the network from; epochs continue from
                           the stored epoch + 1
         outpath (str): directory for the saved checkpoints; the current working
                        directory is used when omitted. Created if missing.

        Returns:
         lossTable (pd.DataFrame): per-epoch mean L2 / Perceptual / Total loss for
                                   'Training' and 'Validation'. Validation entries
                                   of epochs without a validation pass stay 0.
         checkpoint: the dict (epoch, lr, weights) of the epoch with the lowest
                     validation loss; if no checkpoint was saved, the `checkpoint`
                     argument is returned unchanged.
    """
    import copy  # local import: only needed to snapshot the best weights

    ## Load Model for training:
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = centerEsti()
    if checkpoint is not None:
        # map_location allows a checkpoint written on a GPU host to be resumed
        # on a CPU-only one (and vice versa)
        checkDict = torch.load(checkpoint, map_location=device)
        model.load_state_dict(checkDict['weights'])
        e0 = checkDict['epoch'] + 1
    else:
        e0 = 0

    torch.cuda.empty_cache()
    model.to(device)

    # Fail early / prepare the output directory before any training time is spent
    if outpath:
        os.makedirs(outpath, exist_ok=True)

    ## Set Training:
    cols = pd.MultiIndex.from_product([['Training', 'Validation'],
                                       ['L2', 'Perceptual', 'Total']])
    idx = np.arange(start=e0, stop=e0 + nEpochs, dtype=int)
    lossTable = pd.DataFrame(np.zeros((nEpochs, 6)), columns=cols, index=idx)

    optimizer = torch.optim.Adam(model.parameters(), **optParams)
    trainSet = getSetLoader(SetType.train, batchSz)
    valSet = getSetLoader(SetType.val, batchSz)

    # Returned as-is when no validation pass ever improves on valMinLoss
    bestCheckpoint = checkpoint
    valMinLoss = np.inf
    for epoch in tqdm(idx, desc='epochs', ncols=100):
        totLoss = 0
        totL2Loss = 0
        totPercLoss = 0
        model.train()
        trainProg = tqdm(trainSet, desc='training', leave=False, ncols=100)
        for y, x in trainProg:
            # The target is the frame in the middle of the stack (axis 1)
            midFrame = x.shape[1] // 2

            ## Forward Pass:
            optimizer.zero_grad()
            inputs, targets = y.to(device), x[:, midFrame, :, :].to(device)
            outputs = model(inputs)

            ## Loss Calculation:
            l2Loss = lossFuncs.l2Loss(outputs, targets)
            percLoss = lossFuncs.percLoss(outputs, targets)
            loss = l2Loss + percWeight * percLoss

            ## Back Propagation
            loss.backward()
            optimizer.step()

            ## Accumulate Loss:
            totL2Loss += l2Loss.item()
            totPercLoss += percLoss.item()
            totLoss += loss.item()
            trainProg.set_description(f'train loss {loss.item():.2}')

        lossTable.loc[epoch, ('Training', 'L2')] = totL2Loss / len(trainSet)
        lossTable.loc[epoch,
                      ('Training', 'Perceptual')] = totPercLoss / len(trainSet)
        lossTable.loc[epoch, ('Training', 'Total')] = totLoss / len(trainSet)

        if epoch % valEpochFact == 0:
            torch.cuda.empty_cache()

            totLoss = 0
            totL2Loss = 0
            totPercLoss = 0
            model.eval()
            valProg = tqdm(valSet, desc='validation', leave=False, ncols=100)
            with torch.no_grad():
                for y, x in valProg:
                    midFrame = x.shape[1] // 2

                    inputs, targets = y.to(device), x[:, midFrame, :, :].to(
                        device)
                    outputs = model(inputs)

                    ## Loss Calculation:
                    l2Loss = lossFuncs.l2Loss(outputs, targets)
                    percLoss = lossFuncs.percLoss(outputs, targets)
                    loss = l2Loss + percWeight * percLoss

                    totL2Loss += l2Loss.item()
                    totPercLoss += percLoss.item()
                    totLoss += loss.item()
                    valProg.set_description(
                        f'validation loss {loss.item():.2}')

            lossTable.loc[epoch,
                          ('Validation', 'L2')] = totL2Loss / len(valSet)
            lossTable.loc[epoch, ('Validation',
                                  'Perceptual')] = totPercLoss / len(valSet)
            lossTable.loc[epoch,
                          ('Validation', 'Total')] = totLoss / len(valSet)

            if totLoss < valMinLoss:  # Save Network's weights:
                bestCheckpoint = {
                    # `epoch` comes from a NumPy array; store a builtin int so
                    # the checkpoint holds only plain Python values and tensors
                    "epoch": int(epoch),
                    "lr": optimizer.param_groups[0]['lr'],
                    # state_dict() references the live parameters, which later
                    # optimizer steps modify in place; copy them so the
                    # returned checkpoint really holds the best weights
                    "weights": copy.deepcopy(model.state_dict())
                }
                ckptName = f'trainData_e{epoch}.pth'
                savePath = os.path.join(outpath,
                                        ckptName) if outpath else ckptName
                torch.save(bestCheckpoint, savePath)
                valMinLoss = totLoss

    return lossTable, bestCheckpoint