args.train_size, 0))

test_loader = torch.utils.data.DataLoader(test_set,
                                          batch_size=args.test_batch,
                                          shuffle=True)

# configure optimizer
optimizer = None
if args.optimizer == 'Adam':
    optimizer = optim.Adam(refiner.parameters(), lr=args.lr)
elif args.optimizer == 'RMSprop':
    optimizer = optim.RMSprop(refiner.parameters(), lr=args.lr)

# initialize monitor and logger
plotter = monitor.Plotter(args.name)
logger = monitor.Logger(args.log, args.name)

# train
classifier.eval()
refiner.train()
cnt = 0
for epoch in range(args.epochs):
    for data, target in train_loader:
        if use_cuda:
            data, target = data.cuda(), target.cuda()
        data, target = Variable(data), Variable(target)
        optimizer.zero_grad()
        refined_data = refiner(data)
        # loss = similarity distance + efficacy loss
        loss = (1 - args.lbd) * delta(classifier, refined_data,
                                      data) + args.lbd * eta(
# --- "Ejemplo n.º 2" / "0" ("Example no. 2", scraped-aggregator separator) ---
# NOTE(review): the snippet above this marker is truncated mid-expression.
def run(args):
    """Train the coarse-flow / matchability networks on GPU.

    Builds the four sub-networks, selects optimizers, the loss function
    and the trainable modules according to ``args.trainMode``, then runs
    a standard epoch/batch training loop with optional CSV-driven
    validation and checkpointing of the best model.

    Parameters
    ----------
    args : argparse.Namespace
        Fields used by this function (inferred from usage here):
        outDir, kernelSize, resumePth, trainMode ('flow', 'flow+match',
        'match' or 'grad-match'), lr, batchSize, imgSize, margin,
        LrLoss ('L1' or anything else for SSIM), trainPixelShift,
        trainImgDir, valCSV, valImgDir, inPklCoarse, nEpochs,
        epochSaveModel.

    Raises
    ------
    RuntimeError
        If CUDA is not available (CPU training is unsupported).
    ValueError
        If ``args.trainMode`` is not one of the four supported modes.
    """
    if not torch.cuda.is_available():
        raise RuntimeError("Not support cpu version currently...")

    torch.backends.cudnn.benchmark = True

    ## Visdom visualization / logging
    logger = monitor.Logger(args.outDir)

    # Define networks
    network = {
        'netFeatCoarse': model.FeatureExtractor(),
        'netCorr': model.CorrNeigh(args.kernelSize),
        'netFlowCoarse': model.NetFlowCoarse(args.kernelSize),
        'netMatch': model.NetMatchability(args.kernelSize),
    }

    for key in network:
        network[key].cuda()
    # Hoisted out of the loop: it was redundantly reassigned per network.
    typeData = torch.cuda.FloatTensor

    # Network initialization from an optional checkpoint
    if args.resumePth:
        param = torch.load(args.resumePth)
        print('Loading pretrained model from {}'.format(args.resumePth))

        for key in param:
            try:
                network[key].load_state_dict(param[key])
            except Exception as e:
                # Best-effort load: skip incompatible sub-networks, but
                # report the actual error (the original bare `except:`
                # printed the same key twice and hid the cause).
                print('{} weight not compatible: {}'.format(key, e))

    # Optimizers, loss function and trainable modules per training mode
    if args.trainMode == 'flow':
        optimizer = [
            torch.optim.Adam(itertools.chain(
                network['netFeatCoarse'].parameters(),
                network['netCorr'].parameters(),
                network['netFlowCoarse'].parameters()),
                             lr=args.lr,
                             betas=(0.5, 0.999))
        ]
        LossFunction = computeLossNoMatchability
        trainModule = ['netFeatCoarse', 'netCorr', 'netFlowCoarse']

    elif args.trainMode == 'flow+match':
        optimizer = [
            torch.optim.Adam(itertools.chain(
                network['netFeatCoarse'].parameters(),
                network['netCorr'].parameters(),
                network['netFlowCoarse'].parameters()),
                             lr=args.lr,
                             betas=(0.5, 0.999)),
            torch.optim.Adam(network['netMatch'].parameters(),
                             lr=args.lr,
                             betas=(0.5, 0.999))
        ]
        LossFunction = computeLossMatchability
        trainModule = ['netFeatCoarse', 'netCorr', 'netFlowCoarse', 'netMatch']

    elif args.trainMode == 'match':
        optimizer = [
            torch.optim.Adam(network['netFlowCoarse'].parameters(),
                             lr=args.lr,
                             betas=(0.5, 0.999))
        ]
        LossFunction = computeLossMatchability
        trainModule = ['netFlowCoarse']

    elif args.trainMode == 'grad-match':
        optimizer = [
            torch.optim.Adam(itertools.chain(
                network['netFeatCoarse'].parameters(),
                network['netCorr'].parameters(),
                network['netFlowCoarse'].parameters()),
                             lr=args.lr,
                             betas=(0.5, 0.999))
        ]
        LossFunction = computeGradLossNoMatchability
        trainModule = ['netFeatCoarse', 'netCorr', 'netFlowCoarse']

    else:
        # Previously an unknown mode surfaced only as a NameError much
        # later in the training loop; fail fast with a clear message.
        raise ValueError('Unknown trainMode: {!r}'.format(args.trainMode))

    ## Mask of ones inside the image, zero on a `margin`-pixel border.
    ## Size Bs * 1 * (imgSize - 2 * margin) * (imgSize - 2 * margin),
    ## padded back to imgSize * imgSize with zeros.
    maskMargin = torch.ones(args.batchSize * 2, 1,
                            args.imgSize - 2 * args.margin,
                            args.imgSize - 2 * args.margin).type(typeData)
    maskMargin = F.pad(maskMargin,
                       (args.margin, args.margin, args.margin, args.margin),
                       "constant", 0)

    ## Reconstruction loss: plain L1 or SSIM; the pixel-shift variants
    ## tolerate small misalignments (robust to lighting changes).
    if args.LrLoss == 'L1':
        ssim = None
        LrLoss = model.L1PixelShift if args.trainPixelShift else model.L1PixelWise
    else:
        ssim = ssimLoss.SSIM()
        LrLoss = model.SSIMPixelShift if args.trainPixelShift else model.SSIM

    # makedirs is race-free and creates intermediate directories
    # (os.mkdir behind an exists() check was neither).
    os.makedirs(args.outDir, exist_ok=True)
    outNet = os.path.join(args.outDir, 'BestModel.pth')

    # Train data loader
    trainT = dataloader.trainTransform
    trainLoader = dataloader.TrainDataLoader(args.trainImgDir, trainT,
                                             args.batchSize, args.imgSize)

    # Set up for real validation (skipped entirely when no CSV is given)
    df = pd.read_csv(args.valCSV, dtype=str) if args.valCSV else None

    # Guard: inPklCoarse was previously undefined (NameError at validation
    # time) when args.inPklCoarse was empty but args.valCSV was set.
    inPklCoarse = None
    if args.inPklCoarse:
        with open(args.inPklCoarse, 'rb') as f:
            inPklCoarse = pickle.load(f)

    ## Define the sampling grid in [-1, 1] (grid_sample convention)
    gridY = torch.linspace(-1, 1, steps=args.imgSize).view(1, -1, 1, 1).expand(
        1, args.imgSize, args.imgSize, 1)
    gridX = torch.linspace(-1, 1, steps=args.imgSize).view(1, 1, -1, 1).expand(
        1, args.imgSize, args.imgSize, 1)
    grid = torch.cat((gridX, gridY), dim=3).cuda()

    ## Validation bookkeeping
    bestPrec = 0
    LastUpdate = 0

    # indexRoll pairs each image with its counterpart in the other half
    # of the doubled batch (I1 stacked on I2 along dim 0).
    index = np.arange(args.batchSize * 2)
    indexRoll = np.roll(index, args.batchSize)

    index = torch.from_numpy(index).cuda()
    indexRoll = torch.from_numpy(indexRoll).cuda()

    ###### Standard Training ######
    for epoch in range(args.nEpochs):

        trainLossLr = 0
        trainLossCycle = 0
        trainLossMatch = 0
        trainLossGrad = 0

        ## Freeze everything, then switch only the modules trained in
        ## this mode back to train().
        for key in network:
            network[key].eval()

        for key in trainModule:
            network[key].train()

        for i, batch in enumerate(trainLoader):
            # Doubled batch: both images of each pair, stacked on dim 0
            I = torch.cat((batch['I1'].cuda(), batch['I2'].cuda()), dim=0)

            for sub_optimizer in optimizer:
                sub_optimizer.zero_grad()

            lossLr, lossCycle, lossMatch, lossGrad, loss = LossFunction(
                network, I, indexRoll, grid, maskMargin, args, ssim, LrLoss)
            loss.backward()

            for sub_optimizer in optimizer:
                sub_optimizer.step()

            # Accumulate running losses for reporting
            trainLossLr += lossLr
            trainLossCycle += lossCycle
            trainLossMatch += lossMatch
            trainLossGrad += lossGrad

            # Print running averages every 50 batches
            if i % 50 == 49:
                msg = '\n{}\tEpoch {:d}, Batch {:d}, Lr Loss: {:.9f}, Cycle Loss : {:.9f}, Matchability Loss {:.9f}, Gradient Loss {:.9f}'.format(
                    time.ctime(), epoch, i + 1, trainLossLr / (i + 1),
                    trainLossCycle / (i + 1), trainLossMatch / (i + 1),
                    trainLossGrad / (i + 1))
                print(msg)

        if df is not None:
            precFine = validation.validation(df, args.valImgDir, inPklCoarse,
                                             network, args.trainMode)
        else:
            precFine = np.zeros(8)

        # Average train losses over the epoch
        trainLossLr = trainLossLr / len(trainLoader)
        trainLossCycle = trainLossCycle / len(trainLoader)
        trainLossMatch = trainLossMatch / len(trainLoader)
        trainLossGrad = trainLossGrad / len(trainLoader)

        log_loss = {
            'epoch': epoch,
            'trainLossLr': trainLossLr,
            'trainLossCycle': trainLossCycle,
            'trainLossMatch': trainLossMatch,
            'trainLossGrad': trainLossGrad,
            'valPrec@8': precFine[4]
        }

        msg = '\n{} Last Update {:d}---> Epoch {:d}, Train Lr Loss : {:.9f}, Train Cycle Loss : {:.9f}, Train Match Loss : {:.9f}, Train Grad Loss : {:.9f}, valPrec@8 : {:.9f} (Best {:.9f})-----'.format(
            time.ctime(), LastUpdate, epoch, trainLossLr, trainLossCycle,
            trainLossMatch, trainLossGrad, precFine[4], bestPrec)
        print(msg)

        # NOTE(review): no supported trainMode contains 'fine', so index 4
        # is always used here — confirm whether a 'fine' mode exists
        # elsewhere in the project.
        valPrecEpoch = precFine[3] if 'fine' in args.trainMode else precFine[4]

        if df is not None and valPrecEpoch > bestPrec:
            msg = '\n{}\t---> Epoch {:d}, VAL Prec@8 IMPROVED: {:.9f} -- > {:.9f}-----'.format(
                time.ctime(), epoch, bestPrec, valPrecEpoch)
            print(msg)
            bestPrec = valPrecEpoch
            pth = {key: net.state_dict() for key, net in network.items()}
            torch.save(pth, outNet)
            LastUpdate = epoch

        elif df is None and epoch % args.epochSaveModel == args.epochSaveModel - 1:
            # Without validation, checkpoint periodically under a name
            # embedding the epoch losses. outNet is intentionally rebound
            # so subsequent saves go to the latest checkpoint path.
            outNet = os.path.join(
                args.outDir,
                'checkPoint_Epoch{:d}_Lr{:.3f}_Lf{:.5f}_Lm{:.5f}_Lg{:.5f}'.
                format(epoch, trainLossLr, trainLossCycle, trainLossMatch,
                       trainLossGrad))
            pth = {key: net.state_dict() for key, net in network.items()}
            torch.save(pth, outNet)

            print('Save model to {}'.format(outNet))

    if df is not None:
        # Rename the best model to embed its final score. os.replace is
        # portable and error-checked (the original shelled out to `mv`
        # via os.system, silently ignoring failures); both paths live in
        # args.outDir, so a same-filesystem rename is safe.
        finalOut = os.path.join(args.outDir,
                                'BestModel@8_{:.3f}.pth'.format(bestPrec))
        os.replace(outNet, finalOut)