Example #1
0
def train():
    """Train the DAIN frame-interpolation network configured by ``args``.

    Relies on module-level globals: ``args``, ``networks``, ``datasets``,
    ``balancedsampler``, ``part_loss``, ``AverageMeter``, ``Variable``,
    ``ReduceLROnPlateau``, ``os``, ``shutil``, ``numpy`` and ``torch``.
    Side effects: saves a rolling per-epoch checkpoint and a ``best.pth``
    (lowest validation loss) under ``args.save_path``, rewrites the CSV
    log at ``args.log`` each epoch, and mirrors the weights directory to
    ``/content/model_weights`` whenever the best weights improve.
    """
    torch.manual_seed(args.seed)

    model = networks.__dict__[args.netName](channel=args.channels,
                                            filter_size=args.filter_size,
                                            timestep=args.time_step,
                                            training=True)
    if args.use_cuda:
        print("Turn the model into CUDA")
        model = model.cuda()

    if args.SAVED_MODEL is not None:  # identity check, not '== None'
        # BUG FIX: the original concatenated the weights dir and the model
        # name without a '/' separator; os.path.join restores it (compare
        # the sibling script that uses './model_weights/' + name).
        args.SAVED_MODEL = os.path.join('/content/DAIN/model_weights',
                                        args.SAVED_MODEL, "best.pth")
        print("Fine tuning on " + args.SAVED_MODEL)
        if not args.use_cuda:
            # Remap CUDA-saved tensors onto the CPU.
            pretrained_dict = torch.load(args.SAVED_MODEL,
                                         map_location=lambda storage, loc: storage)
        else:
            pretrained_dict = torch.load(args.SAVED_MODEL)

        model_dict = model.state_dict()
        # 1. drop checkpoint keys the current architecture does not have
        pretrained_dict = {k: v for k, v in pretrained_dict.items()
                           if k in model_dict}
        # 2. overwrite entries in the existing state dict
        model_dict.update(pretrained_dict)
        # 3. load the merged state dict
        model.load_state_dict(model_dict)
        pretrained_dict = None  # release the checkpoint tensors

    if isinstance(args.datasetName, list):
        # Several datasets: build each one, then concatenate.
        train_sets, test_sets = [], []
        for ds_name, ds_path in zip(args.datasetName, args.datasetPath):
            tr_s, te_s = datasets.__dict__[ds_name](ds_path,
                                                    split=args.dataset_split,
                                                    single=args.single_output,
                                                    task=args.task)
            train_sets.append(tr_s)
            test_sets.append(te_s)
        train_set = torch.utils.data.ConcatDataset(train_sets)
        test_set = torch.utils.data.ConcatDataset(test_sets)
    else:
        train_set, test_set = datasets.__dict__[args.datasetName](args.datasetPath)

    # Hoisted: number of optimisation / validation steps per epoch,
    # previously recomputed inline in five places.
    epoch_iters = len(train_set) // args.batch_size
    val_iters = len(test_set) // args.batch_size

    train_loader = torch.utils.data.DataLoader(
        train_set, batch_size=args.batch_size,
        sampler=balancedsampler.RandomBalancedSampler(train_set, epoch_iters),
        num_workers=args.workers, pin_memory=args.use_cuda)

    val_loader = torch.utils.data.DataLoader(
        test_set, batch_size=args.batch_size,
        num_workers=args.workers, pin_memory=args.use_cuda)
    print('{} samples found, {} train samples and {} test samples '.format(
        len(test_set) + len(train_set), len(train_set), len(test_set)))

    print("train the interpolation net")
    # Per-submodule learning rates: each sub-network gets its own coefficient
    # on the base LR; the rectify net uses its own absolute rate.
    optimizer = torch.optim.Adamax([
        {'params': model.initScaleNets_filter.parameters(), 'lr': args.filter_lr_coe * args.lr},
        {'params': model.initScaleNets_filter1.parameters(), 'lr': args.filter_lr_coe * args.lr},
        {'params': model.initScaleNets_filter2.parameters(), 'lr': args.filter_lr_coe * args.lr},
        {'params': model.ctxNet.parameters(), 'lr': args.ctx_lr_coe * args.lr},
        {'params': model.flownets.parameters(), 'lr': args.flow_lr_coe * args.lr},
        {'params': model.depthNet.parameters(), 'lr': args.depth_lr_coe * args.lr},
        {'params': model.rectifyNet.parameters(), 'lr': args.rectify_lr}
    ], lr=args.lr, betas=(0.9, 0.999), eps=1e-8, weight_decay=args.weight_decay)

    # Shrink the LR when the validation loss plateaus.
    scheduler = ReduceLROnPlateau(optimizer, 'min', factor=args.factor,
                                  patience=args.patience, verbose=True)

    print("*********Start Training********")
    print("LR is: " + str(float(optimizer.param_groups[0]['lr'])))
    print("EPOCH is: " + str(epoch_iters))
    print("Num of EPOCH is: " + str(args.numEpoch))

    def count_network_parameters(model):
        # Count only trainable (requires_grad) parameters.
        parameters = filter(lambda p: p.requires_grad, model.parameters())
        return sum(numpy.prod(p.size()) for p in parameters)

    print("Num. of model parameters is :" + str(count_network_parameters(model)))
    if hasattr(model, 'flownets'):
        print("Num. of flow model parameters is :" +
              str(count_network_parameters(model.flownets)))
    if hasattr(model, 'initScaleNets_occlusion'):
        print("Num. of initScaleNets_occlusion model parameters is :" +
              str(count_network_parameters(model.initScaleNets_occlusion) +
                  count_network_parameters(model.initScaleNets_occlusion1) +
                  count_network_parameters(model.initScaleNets_occlusion2)))
    if hasattr(model, 'initScaleNets_filter'):
        print("Num. of initScaleNets_filter model parameters is :" +
              str(count_network_parameters(model.initScaleNets_filter) +
                  count_network_parameters(model.initScaleNets_filter1) +
                  count_network_parameters(model.initScaleNets_filter2)))
    if hasattr(model, 'ctxNet'):
        print("Num. of ctxNet model parameters is :" +
              str(count_network_parameters(model.ctxNet)))
    if hasattr(model, 'depthNet'):
        print("Num. of depthNet model parameters is :" +
              str(count_network_parameters(model.depthNet)))
    if hasattr(model, 'rectifyNet'):
        print("Num. of rectifyNet model parameters is :" +
              str(count_network_parameters(model.rectifyNet)))

    training_losses = AverageMeter()
    auxiliary_data = []        # one row of stats per epoch, dumped to args.log
    saved_total_loss = 10e10   # best (lowest) validation loss seen so far
    saved_total_PSNR = -1      # kept for parity with the original script (unused)
    # First param group with a positive LR; used only for log messages.
    # NOTE(review): if every group had lr == 0, ikk would stay the int 0 and
    # the ikk['lr'] lookups below would fail — original behavior, unchanged.
    ikk = 0
    for group in optimizer.param_groups:
        if group['lr'] > 0:
            ikk = group
            break

    log_every = max(1, epoch_iters // 500)  # ~500 progress lines per epoch

    for t in range(args.numEpoch):
        print("The id of this in-training network is " + str(args.uid))
        print(args)
        # Turn into training mode
        model = model.train()

        for i, (X0_half, X1_half, y_half) in enumerate(train_loader):

            if i >= epoch_iters:
                break

            if args.use_cuda:
                X0_half = X0_half.cuda()
                X1_half = X1_half.cuda()
                y_half = y_half.cuda()

            X0 = Variable(X0_half, requires_grad=False)
            X1 = Variable(X1_half, requires_grad=False)
            y = Variable(y_half, requires_grad=False)

            # The network consumes the (first, middle, last) frame triple.
            diffs, offsets, filters, occlusions = model(torch.stack((X0, y, X1), dim=0))

            pixel_loss, offset_loss, sym_loss = part_loss(
                diffs, offsets, occlusions, [X0, X1], epsilon=args.epsilon)

            # Weighted pixel loss; terms whose weight is <= 0 are dropped.
            total_loss = sum(w * l if w > 0 else 0
                             for w, l in zip(args.alpha, pixel_loss))

            training_losses.update(total_loss.item(), args.batch_size)
            if i % log_every == 0:
                print("Ep [" + str(t) + "/" + str(i) +
                      "]\tl.r.: " + str(round(float(ikk['lr']), 7)) +
                      "\tPix: " + str([round(x.item(), 5) for x in pixel_loss]) +
                      "\tTV: " + str([round(x.item(), 4) for x in offset_loss]) +
                      "\tSym: " + str([round(x.item(), 4) for x in sym_loss]) +
                      "\tTotal: " + str([round(x.item(), 5) for x in [total_loss]]) +
                      "\tAvg. Loss: " + str([round(training_losses.avg, 5)]))

            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()

        if t == 1:
            # delete the pre validation weights for cleaner workspace
            if os.path.exists(args.save_path + "/epoch" + str(0) + ".pth"):
                os.remove(args.save_path + "/epoch" + str(0) + ".pth")

        # Keep only the latest per-epoch snapshot.
        if os.path.exists(args.save_path + "/epoch" + str(t - 1) + ".pth"):
            os.remove(args.save_path + "/epoch" + str(t - 1) + ".pth")
        torch.save(model.state_dict(), args.save_path + "/epoch" + str(t) + ".pth")

        # ---- validation ----
        val_total_losses = AverageMeter()
        val_total_pixel_loss = AverageMeter()
        val_total_PSNR_loss = AverageMeter()
        val_total_tv_loss = AverageMeter()
        # NOTE(review): never updated below; printed/logged at its initial value.
        val_total_pws_loss = AverageMeter()
        val_total_sym_loss = AverageMeter()

        for i, (X0, X1, y) in enumerate(val_loader):
            if i >= val_iters:
                break

            with torch.no_grad():
                if args.use_cuda:
                    X0 = X0.cuda()
                    X1 = X1.cuda()
                    y = y.cuda()

                diffs, offsets, filters, occlusions = model(torch.stack((X0, y, X1), dim=0))

                pixel_loss, offset_loss, sym_loss = part_loss(
                    diffs, offsets, occlusions, [X0, X1], epsilon=args.epsilon)

                # Validation keeps all weighted terms (no positivity filter),
                # matching the original script.
                val_total_loss = sum(w * l for w, l in zip(args.alpha, pixel_loss))

                # Per-sample MSE over the last three dims, then PSNR in dB
                # (natural log ratio converts to base 10).
                per_sample_pix_error = torch.mean(torch.mean(torch.mean(
                    diffs[args.save_which] ** 2, dim=1), dim=1), dim=1)
                per_sample_pix_error = per_sample_pix_error.data  # extract tensor
                psnr_loss = torch.mean(
                    20 * torch.log(1.0 / torch.sqrt(per_sample_pix_error))
                ) / torch.log(torch.Tensor([10]))

                val_total_losses.update(val_total_loss.item(), args.batch_size)
                val_total_pixel_loss.update(pixel_loss[args.save_which].item(), args.batch_size)
                val_total_tv_loss.update(offset_loss[0].item(), args.batch_size)
                val_total_sym_loss.update(sym_loss[0].item(), args.batch_size)
                # BUG FIX: store the Python float, not a 0-dim tensor, so the
                # meter's running average stays a plain number.
                val_total_PSNR_loss.update(psnr_loss.item(), args.batch_size)
                print(".", end='', flush=True)

        print("\nEpoch " + str(int(t)) +
              "\tlearning rate: " + str(float(ikk['lr'])) +
              "\tAvg Training Loss: " + str(round(training_losses.avg, 5)) +
              "\tValidate Loss: " + str([round(float(val_total_losses.avg), 5)]) +
              "\tValidate PSNR: " + str([round(float(val_total_PSNR_loss.avg), 5)]) +
              "\tPixel Loss: " + str([round(float(val_total_pixel_loss.avg), 5)]) +
              "\tTV Loss: " + str([round(float(val_total_tv_loss.avg), 4)]) +
              "\tPWS Loss: " + str([round(float(val_total_pws_loss.avg), 4)]) +
              "\tSym Loss: " + str([round(float(val_total_sym_loss.avg), 4)])
              )

        auxiliary_data.append([t, float(ikk['lr']),
                               training_losses.avg, val_total_losses.avg,
                               val_total_pixel_loss.avg, val_total_tv_loss.avg,
                               val_total_pws_loss.avg, val_total_sym_loss.avg])

        # Rewrite the whole CSV log each epoch so it is always consistent.
        numpy.savetxt(args.log, numpy.array(auxiliary_data), fmt='%.8f', delimiter=',')
        training_losses.reset()

        print("\t\tFinished an epoch, Check and Save the model weights")
        # we check the validation loss instead of training loss. OK~
        if saved_total_loss >= val_total_losses.avg:
            saved_total_loss = val_total_losses.avg
            torch.save(model.state_dict(), args.save_path + "/best" + ".pth")
            print("\t\tBest Weights updated for decreased validation loss\n")
            # Mirror the weights directory for persistence outside the repo.
            if os.path.exists("/content/model_weights"):
                shutil.rmtree("/content/model_weights")
            shutil.copytree("/content/DAIN/model_weights", "/content/model_weights")
        else:
            print("\t\tWeights Not updated for undecreased validation loss\n")

        # schedule the learning rate from the validation loss
        scheduler.step(val_total_losses.avg)

    print("*********Finish Training********")
Example #2
0
### /SAVED MODEL


# Build the train/test datasets. A list of dataset names yields one dataset
# per (name, path) pair, concatenated into a single ConcatDataset each.
if type(args.datasetName) == list:
    train_sets, test_sets = [],[]
    for ii, jj in zip(args.datasetName, args.datasetPath):
        # ii: dataset factory name looked up in `datasets`; jj: its root path.
        tr_s, te_s = datasets.__dict__[ii](jj, split = args.dataset_split,single = args.single_output, task = args.task)
        train_sets.append(tr_s)
        test_sets.append(te_s)
    train_set = torch.utils.data.ConcatDataset(train_sets)
    test_set = torch.utils.data.ConcatDataset(test_sets)
else:
    train_set, test_set = datasets.__dict__[args.datasetName](args.datasetPath)
# The balanced sampler caps an epoch at len(train_set) // batch_size draws.
# NOTE(review): this fragment reads a bare `use_cuda` while similar code in
# this project uses `args.use_cuda` — confirm which name the enclosing
# script defines.
train_loader = torch.utils.data.DataLoader(
    train_set, batch_size = args.batch_size,
    sampler=balancedsampler.RandomBalancedSampler(train_set, int(len(train_set) / args.batch_size )),
    num_workers= args.workers, pin_memory=True if use_cuda else False)

# Validation iterates sequentially (no sampler).
val_loader = torch.utils.data.DataLoader(test_set, batch_size=args.batch_size,
                                         num_workers=args.workers, pin_memory=True if use_cuda else False)
print('{} samples found, {} train samples and {} test samples '.format(len(test_set)+len(train_set),
                                                                       len(train_set),
                                                                       len(test_set)))


# if not args.lr == 0:
print("train the interpolation net")
optimizer = torch.optim.Adamax([
            {'params': model.initScaleNets_filter.parameters(), 'lr': args.filter_lr_coe * args.lr},
            {'params': model.initScaleNets_filter1.parameters(), 'lr': args.filter_lr_coe * args.lr},
            {'params': model.initScaleNets_filter2.parameters(), 'lr': args.filter_lr_coe * args.lr},
Example #3
0
def main():
    """Entry point: parse CLI args, build the optical-flow data pipeline,
    model, losses and optimizer, then run the train/validate loop with
    checkpointing and tab-separated CSV logging.

    Relies on module-level names: ``parser``, ``flow_transforms``,
    ``transforms``, ``datasets``, ``balancedsampler``, ``models``,
    ``multiscaleloss``, ``cudnn``, ``train``, ``validate``,
    ``adjust_learning_rate``, ``save_checkpoint``, and the globals
    declared below.
    """
    global args, best_EPE, save_path
    args = parser.parse_args()
    # Encode the main hyper-parameters into the checkpoint directory name,
    # e.g. "flownets,adam,300epochs,b8,lr0.0001".
    save_path = '{},{},{}epochs{},b{},lr{}'.format(
        args.arch, args.solver, args.epochs,
        ',epochSize' + str(args.epoch_size) if args.epoch_size > 0 else '',
        args.batch_size, args.lr)
    if not args.no_date:
        # Prefix with a timestamp so repeated runs do not collide.
        timestamp = datetime.datetime.now().strftime("%a-%b-%d-%H:%M")
        save_path = os.path.join(timestamp, save_path)
    save_path = os.path.join(args.dataset, save_path)
    print('=> will save everything to {}'.format(save_path))
    if not os.path.exists(save_path):
        os.makedirs(save_path)

    # Data loading code
    # Inputs: scale uint8 pixels to [0,1] via std=255, then apply ImageNet
    # channel statistics.
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    input_transform = transforms.Compose([
        flow_transforms.ArrayToTensor(),
        transforms.Normalize(mean=[0, 0, 0], std=[255, 255, 255]), normalize
    ])
    # Targets: 2-channel flow rescaled by args.div_flow.
    target_transform = transforms.Compose([
        flow_transforms.ArrayToTensor(),
        transforms.Normalize(mean=[0, 0], std=[args.div_flow, args.div_flow])
    ])

    if 'KITTI' in args.dataset:
        # KITTI gets only a random crop — no translate/rotate jitter,
        # unlike the generic branch below.
        co_transform = flow_transforms.Compose([
            flow_transforms.RandomCrop((320, 448)),
            #random flips are not supported yet for tensor conversion, but will be
            #flow_transforms.RandomVerticalFlip(),
            #flow_transforms.RandomHorizontalFlip()
        ])
    else:
        co_transform = flow_transforms.Compose([
            flow_transforms.RandomTranslate(10),
            flow_transforms.RandomRotate(10, 5),
            flow_transforms.RandomCrop((320, 448)),
            #random flips are not supported yet for tensor conversion, but will be
            #flow_transforms.RandomVerticalFlip(),
            #flow_transforms.RandomHorizontalFlip()
        ])

    print("=> fetching img pairs in '{}'".format(args.data))
    # Dataset factory returns a (train, test) split.
    train_set, test_set = datasets.__dict__[args.dataset](
        args.data,
        transform=input_transform,
        target_transform=target_transform,
        co_transform=co_transform,
        split=args.split)
    print('{} samples found, {} train samples and {} test samples '.format(
        len(test_set) + len(train_set), len(train_set), len(test_set)))
    # The balanced sampler fixes the epoch length to args.epoch_size draws.
    train_loader = torch.utils.data.DataLoader(
        train_set,
        batch_size=args.batch_size,
        sampler=balancedsampler.RandomBalancedSampler(train_set,
                                                      args.epoch_size),
        num_workers=args.workers,
        pin_memory=True)
    val_loader = torch.utils.data.DataLoader(test_set,
                                             batch_size=args.batch_size,
                                             num_workers=args.workers,
                                             pin_memory=True)

    # create model
    if args.pretrained:
        print("=> using pre-trained model '{}'".format(args.arch))
    else:
        print("=> creating model '{}'".format(args.arch))

    model = models.__dict__[args.arch](args.pretrained).cuda()

    model = torch.nn.DataParallel(model).cuda()
    # Training loss; sparse ground truth for KITTI.
    criterion = multiscaleloss(sparse='KITTI' in args.dataset,
                               loss=args.loss).cuda()
    # Single-scale L1 EPE metric at 1/4 resolution.
    # NOTE(review): weights=(1) is the int 1, not a 1-tuple — confirm
    # multiscaleloss accepts a scalar here; otherwise it should be (1,).
    high_res_EPE = multiscaleloss(scales=1,
                                  downscale=4,
                                  weights=(1),
                                  loss='L1',
                                  sparse='KITTI' in args.dataset).cuda()
    cudnn.benchmark = True

    assert (args.solver in ['adam', 'sgd'])
    print('=> setting {} solver'.format(args.solver))
    if args.solver == 'adam':
        # args.momentum is reused as Adam's beta1.
        optimizer = torch.optim.Adam(model.parameters(),
                                     args.lr,
                                     betas=(args.momentum, args.beta),
                                     weight_decay=args.weight_decay)
    elif args.solver == 'sgd':
        optimizer = torch.optim.SGD(model.parameters(),
                                    args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)

    if args.evaluate:
        # Evaluation-only mode: one validation pass, no training or logging.
        best_EPE = validate(val_loader, model, criterion, high_res_EPE)
        return

    # Create the two CSV logs with headers ('w' truncates any previous run);
    # rows are appended per epoch below.
    with open(os.path.join(save_path, args.log_summary), 'w') as csvfile:
        writer = csv.writer(csvfile, delimiter='\t')
        writer.writerow(['train_loss', 'train_EPE', 'EPE'])

    with open(os.path.join(save_path, args.log_full), 'w') as csvfile:
        writer = csv.writer(csvfile, delimiter='\t')
        writer.writerow(['train_loss', 'train_EPE'])

    for epoch in range(args.start_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch)

        # train for one epoch
        train_loss, train_EPE = train(train_loader, model, criterion,
                                      high_res_EPE, optimizer, epoch)

        # evaluate on validation set

        EPE = validate(val_loader, model, criterion, high_res_EPE)
        if best_EPE < 0:
            # presumably best_EPE is initialized negative elsewhere; the
            # first validation result seeds it — confirm against the module.
            best_EPE = EPE

        # remember best prec@1 and save checkpoint
        is_best = EPE < best_EPE
        best_EPE = min(EPE, best_EPE)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.module.state_dict(),
                'best_EPE': best_EPE,
            }, is_best)

        # Append this epoch's summary row.
        with open(os.path.join(save_path, args.log_summary), 'a') as csvfile:
            writer = csv.writer(csvfile, delimiter='\t')
            writer.writerow([train_loss, train_EPE, EPE])
Example #4
0
def train():
    torch.manual_seed(args.seed)

    model = networks.__dict__[args.netName](channel=args.channels,
                                            filter_size=args.filter_size,
                                            timestep=args.time_step,
                                            training=True)
    original_model = networks.__dict__[args.netName](
        channel=args.channels,
        filter_size=args.filter_size,
        timestep=args.time_step,
        training=True)
    if args.use_cuda:
        print("Turn the model into CUDA")
        model = model.cuda()
        original_model = original_model.cuda()

    if not args.SAVED_MODEL == None:
        args.SAVED_MODEL = './model_weights/' + args.SAVED_MODEL + "/best" + ".pth"
        print("Fine tuning on " + args.SAVED_MODEL)
        if not args.use_cuda:
            pretrained_dict = torch.load(
                args.SAVED_MODEL, map_location=lambda storage, loc: storage)
            # model.load_state_dict(torch.load(args.SAVED_MODEL, map_location=lambda storage, loc: storage))
        else:
            pretrained_dict = torch.load(args.SAVED_MODEL)
            # model.load_state_dict(torch.load(args.SAVED_MODEL))
        #print([k for k,v in      pretrained_dict.items()])

        model_dict = model.state_dict()
        # 1. filter out unnecessary keys
        pretrained_dict = {
            k: v
            for k, v in pretrained_dict.items() if k in model_dict
        }
        # 2. overwrite entries in the existing state dict
        model_dict.update(pretrained_dict)
        # 3. load the new state dict
        model.load_state_dict(model_dict)

        # For comparison in meta training
        original_model.load_state_dict(model_dict)

        pretrained_dict = None

    if type(args.datasetName) == list:
        train_sets, test_sets = [], []
        for ii, jj in zip(args.datasetName, args.datasetPath):
            tr_s, te_s = datasets.__dict__[ii](jj,
                                               split=args.dataset_split,
                                               single=args.single_output,
                                               task=args.task)
            train_sets.append(tr_s)
            test_sets.append(te_s)
        train_set = torch.utils.data.ConcatDataset(train_sets)
        test_set = torch.utils.data.ConcatDataset(test_sets)
    else:
        train_set, test_set = datasets.__dict__[args.datasetName](
            args.datasetPath)
    train_loader = torch.utils.data.DataLoader(
        train_set,
        batch_size=args.batch_size,
        sampler=balancedsampler.RandomBalancedSampler(
            train_set, int(len(train_set) / args.batch_size)),
        num_workers=args.workers,
        pin_memory=True if args.use_cuda else False)

    val_loader = torch.utils.data.DataLoader(
        test_set,
        batch_size=args.batch_size,
        num_workers=args.workers,
        pin_memory=True if args.use_cuda else False)
    print('{} samples found, {} train samples and {} test samples '.format(
        len(test_set) + len(train_set), len(train_set), len(test_set)))

    # if not args.lr == 0:
    print("train the interpolation net")
    '''optimizer = torch.optim.Adamax([
                #{'params': model.initScaleNets_filter.parameters(), 'lr': args.filter_lr_coe * args.lr},
                #{'params': model.initScaleNets_filter1.parameters(), 'lr': args.filter_lr_coe * args.lr},
                #{'params': model.initScaleNets_filter2.parameters(), 'lr': args.filter_lr_coe * args.lr},
                #{'params': model.ctxNet.parameters(), 'lr': args.ctx_lr_coe * args.lr},
                #{'params': model.flownets.parameters(), 'lr': args.flow_lr_coe * args.lr},
                #{'params': model.depthNet.parameters(), 'lr': args.depth_lr_coe * args.lr},
                {'params': model.rectifyNet.parameters(), 'lr': args.rectify_lr}
                ],
                #lr=args.lr, momentum=0, weight_decay=args.weight_decay)
                lr=args.lr, betas=(0.9, 0.999), eps=1e-8, weight_decay=args.weight_decay)'''
    optimizer = torch.optim.Adamax(model.rectifyNet.parameters(),
                                   lr=args.outer_lr,
                                   betas=(0.9, 0.999),
                                   eps=1e-8,
                                   weight_decay=args.weight_decay)

    # Fix weights for early layers
    for param in model.initScaleNets_filter.parameters():
        param.requires_grad = False
    for param in model.initScaleNets_filter1.parameters():
        param.requires_grad = False
    for param in model.initScaleNets_filter2.parameters():
        param.requires_grad = False
    for param in model.ctxNet.parameters():
        param.requires_grad = False
    for param in model.flownets.parameters():
        param.requires_grad = False
    for param in model.depthNet.parameters():
        param.requires_grad = False

    scheduler = ReduceLROnPlateau(optimizer,
                                  'min',
                                  factor=args.factor,
                                  patience=args.patience,
                                  verbose=True)

    print("*********Start Training********")
    print("LR is: " + str(float(optimizer.param_groups[0]['lr'])))
    print("EPOCH is: " + str(int(len(train_set) / args.batch_size)))
    print("Num of EPOCH is: " + str(args.numEpoch))

    def count_network_parameters(model):

        parameters = filter(lambda p: p.requires_grad, model.parameters())
        N = sum([numpy.prod(p.size()) for p in parameters])

        return N

    print("Num. of model parameters is :" +
          str(count_network_parameters(model)))
    if hasattr(model, 'flownets'):
        print("Num. of flow model parameters is :" +
              str(count_network_parameters(model.flownets)))
    if hasattr(model, 'initScaleNets_occlusion'):
        print("Num. of initScaleNets_occlusion model parameters is :" + str(
            count_network_parameters(model.initScaleNets_occlusion) +
            count_network_parameters(model.initScaleNets_occlusion1) +
            count_network_parameters(model.initScaleNets_occlusion2)))
    if hasattr(model, 'initScaleNets_filter'):
        print("Num. of initScaleNets_filter model parameters is :" + str(
            count_network_parameters(model.initScaleNets_filter) +
            count_network_parameters(model.initScaleNets_filter1) +
            count_network_parameters(model.initScaleNets_filter2)))
    if hasattr(model, 'ctxNet'):
        print("Num. of ctxNet model parameters is :" +
              str(count_network_parameters(model.ctxNet)))
    if hasattr(model, 'depthNet'):
        print("Num. of depthNet model parameters is :" +
              str(count_network_parameters(model.depthNet)))
    if hasattr(model, 'rectifyNet'):
        print("Num. of rectifyNet model parameters is :" +
              str(count_network_parameters(model.rectifyNet)))

    training_losses = AverageMeter()
    #original_training_losses = AverageMeter()
    batch_time = AverageMeter()
    auxiliary_data = []
    saved_total_loss = 10e10
    saved_total_PSNR = -1
    ikk = 0
    for kk in optimizer.param_groups:
        if kk['lr'] > 0:
            ikk = kk
            break

    for t in range(args.numEpoch):
        print("The id of this in-training network is " + str(args.uid))
        print(args)
        print("Learning rate for this epoch: %s" %
              str(round(float(ikk['lr']), 7)))

        #Turn into training mode
        model = model.train()

        #for i, (X0_half,X1_half, y_half) in enumerate(train_loader):
        _t = time.time()
        for i, images in enumerate(train_loader):

            if i >= min(TRAIN_ITER_CUT, int(len(train_set) / args.batch_size)):
                #(0 if t == 0 else EPOCH):#
                break

            if args.use_cuda:
                images = [im.cuda() for im in images]

            images = [Variable(im, requires_grad=False) for im in images]

            # For VimeoTriplet
            #X0, y, X1 = images[0], images[1], images[2]
            # For VimeoSepTuplet
            X0, y, X1 = images[2], images[3], images[4]

            outerstepsize = args.outer_lr
            k = args.num_inner_update  # inner loop update iteration

            inner_optimizer = torch.optim.Adamax(
                model.rectifyNet.parameters(),
                lr=args.inner_lr,
                betas=(0.9, 0.999),
                eps=1e-8,
                weight_decay=args.weight_decay)

            if META_ALGORITHM == "Reptile":

                # Reptile setting
                weights_before = copy.deepcopy(model.state_dict())

                for _k in range(k):
                    indices = [[0, 2, 4], [2, 4, 6], [2, 3, 4], [0, 1, 2],
                               [4, 5, 6]]
                    total_loss = 0
                    for ind in indices:
                        meta_X0, meta_y, meta_X1 = images[ind[0]].clone(
                        ), images[ind[1]].clone(), images[ind[2]].clone()

                        diffs, offsets, filters, occlusions = model(
                            torch.stack((meta_X0, meta_y, meta_X1), dim=0))
                        pixel_loss, offset_loss, sym_loss = part_loss(
                            diffs,
                            offsets,
                            occlusions, [meta_X0, meta_X1],
                            epsilon=args.epsilon)
                        _total_loss = sum(
                            x * y if x > 0 else 0
                            for x, y in zip(args.alpha, pixel_loss))
                        total_loss = total_loss + _total_loss
                    # total *= 2 / len(indices)

                    inner_optimizer.zero_grad()
                    total_loss.backward()
                    inner_optimizer.step()

                # Reptile update
                weights_after = model.state_dict()
                model.load_state_dict({
                    name: weights_before[name] +
                    (weights_after[name] - weights_before[name]) *
                    outerstepsize
                    for name in weights_before
                })

                with torch.no_grad():
                    diffs, offsets, filters, occlusions = model(
                        torch.stack((X0, y, X1), dim=0))
                    pixel_loss, offset_loss, sym_loss = part_loss(
                        diffs,
                        offsets,
                        occlusions, [X0, X1],
                        epsilon=args.epsilon)
                    total_loss = sum(x * y if x > 0 else 0
                                     for x, y in zip(args.alpha, pixel_loss))
                training_losses.update(total_loss.item(), args.batch_size)

            elif META_ALGORITHM == "MAML":

                #weights_before = copy.deepcopy(model.state_dict())
                base_model = copy.deepcopy(model)
                #fast_weights = list(filter(lambda p: p.requires_grad, model.parameters()))

                for _k in range(k):

                    indices = [[0, 2, 4], [2, 4, 6]]
                    support_loss = 0
                    for ind in indices:
                        meta_X0, meta_y, meta_X1 = images[ind[0]].clone(
                        ), images[ind[1]].clone(), images[ind[2]].clone()

                        diffs, offsets, filters, occlusions = model(
                            torch.stack((meta_X0, meta_y, meta_X1), dim=0))
                        pixel_loss, offset_loss, sym_loss = part_loss(
                            diffs,
                            offsets,
                            occlusions, [meta_X0, meta_X1],
                            epsilon=args.epsilon)
                        _total_loss = sum(
                            x * y if x > 0 else 0
                            for x, y in zip(args.alpha, pixel_loss))
                        support_loss = support_loss + _total_loss

                    #grad = torch.autograd.grad(loss, fast_weights)
                    #fast_weights = list(map(lambda p: p[1] - args.lr * p[0], zip(grad, fast_weights)))
                    inner_optimizer.zero_grad()
                    support_loss.backward()  # create_graph=True
                    inner_optimizer.step()

                # Forward on query set
                diffs, offsets, filters, occlusions = model(
                    torch.stack((X0, y, X1), dim=0))
                pixel_loss, offset_loss, sym_loss = part_loss(
                    diffs, offsets, occlusions, [X0, X1], epsilon=args.epsilon)
                total_loss = sum(x * y if x > 0 else 0
                                 for x, y in zip(args.alpha, pixel_loss))
                training_losses.update(total_loss.item(), args.batch_size)

                # copy parameters to comnnect the computational graph
                for param, base_param in zip(
                        model.rectifyNet.parameters(),
                        base_model.rectifyNet.parameters()):
                    param.data = base_param.data

                filtered_params = filter(lambda p: p.requires_grad,
                                         model.parameters())
                optimizer.zero_grad()
                grads = torch.autograd.grad(total_loss, list(
                    filtered_params))  # backward on weights_before: FO-MAML
                j = 0
                #print('[before update]')
                #print(list(model.parameters())[45][-1])
                for _i, param in enumerate(model.parameters()):
                    if param.requires_grad:
                        #param = param - outerstepsize * grads[j]
                        param.grad = grads[j]
                        j += 1
                optimizer.step()
                #print('[after optim.step]')
                #print(list(model.parameters())[45][-1])

            batch_time.update(time.time() - _t)
            _t = time.time()

            if i % 100 == 0:  #max(1, int(int(len(train_set) / args.batch_size )/500.0)) == 0:

                print(
                    "Ep[%s][%05d/%d]  Time: %.2f  Pix: %s  TV: %s  Sym: %s  Total: %s  Avg. Loss: %s"
                    % (str(t), i, int(len(train_set)) // args.batch_size,
                       batch_time.avg,
                       str([round(x.item(), 5) for x in pixel_loss
                            ]), str([round(x.item(), 4) for x in offset_loss]),
                       str([round(x.item(), 4) for x in sym_loss
                            ]), str([round(x.item(), 5) for x in [total_loss]
                                     ]), str([round(training_losses.avg, 5)])))
                batch_time.reset()

        if t == 1:
            # delete the pre validation weights for cleaner workspace
            if os.path.exists(args.save_path + "/epoch" + str(0) + ".pth"):
                os.remove(args.save_path + "/epoch" + str(0) + ".pth")

        if os.path.exists(args.save_path + "/epoch" + str(t - 1) + ".pth"):
            os.remove(args.save_path + "/epoch" + str(t - 1) + ".pth")
        torch.save(model.state_dict(),
                   args.save_path + "/epoch" + str(t) + ".pth")

        # print("\t\t**************Start Validation*****************")
        #Turn into evaluation mode

        val_total_losses = AverageMeter()
        val_total_pixel_loss = AverageMeter()
        val_total_PSNR_loss = AverageMeter()
        val_total_tv_loss = AverageMeter()
        val_total_pws_loss = AverageMeter()
        val_total_sym_loss = AverageMeter()

        for i, (images, imgpaths) in enumerate(tqdm(val_loader)):
            #if i < 50: #i < 11 or (i > 14 and i < 50):
            #    continue
            if i >= min(VAL_ITER_CUT, int(len(test_set) / args.batch_size)):
                break

            if args.use_cuda:
                images = [im.cuda() for im in images]
            #X0, y, X1 = images[0], images[1], images[2]
            #X0, y, X1 = images[2], images[3], images[4]

            # define optimizer to update the inner loop
            inner_optimizer = torch.optim.Adamax(
                model.rectifyNet.parameters(),
                lr=args.inner_lr,
                betas=(0.9, 0.999),
                eps=1e-8,
                weight_decay=args.weight_decay)

            # Reptile testing - save base model weights
            weights_base = copy.deepcopy(model.state_dict())

            k = args.num_inner_update  # 2
            model.train()
            for _k in range(k):
                indices = [[0, 2, 4], [2, 4, 6]]
                ind = indices[_k % 2]
                meta_X0, meta_y, meta_X1 = crop(images[ind[0]]), crop(
                    images[ind[1]]), crop(images[ind[2]])

                diffs, offsets, filters, occlusions, _ = model(
                    torch.stack((meta_X0, meta_y, meta_X1), dim=0))
                pixel_loss, offset_loss, sym_loss = part_loss(
                    diffs,
                    offsets,
                    occlusions, [meta_X0, meta_X1],
                    epsilon=args.epsilon)
                total_loss = sum(x * y if x > 0 else 0
                                 for x, y in zip(args.alpha, pixel_loss))

                inner_optimizer.zero_grad()
                total_loss.backward()
                inner_optimizer.step()

            # Actual target validation performance
            with torch.no_grad():
                if args.datasetName == 'Vimeo_90K_sep':
                    X0, y, X1 = images[2], images[3], images[4]
                    #diffs, offsets,filters,occlusions = model(torch.stack((X0,y,X1),dim = 0))
                    diffs, offsets, filters, occlusions, output = model(
                        torch.stack((X0, y, X1), dim=0))

                    pixel_loss, offset_loss, sym_loss = part_loss(
                        diffs,
                        offsets,
                        occlusions, [X0, X1],
                        epsilon=args.epsilon)

                    val_total_loss = sum(
                        x * y for x, y in zip(args.alpha, pixel_loss))

                    per_sample_pix_error = torch.mean(torch.mean(torch.mean(
                        diffs[args.save_which]**2, dim=1),
                                                                 dim=1),
                                                      dim=1)
                    per_sample_pix_error = per_sample_pix_error.data  # extract tensor
                    psnr_loss = torch.mean(20 * torch.log(
                        1.0 / torch.sqrt(per_sample_pix_error))) / torch.log(
                            torch.Tensor([10]))

                    val_total_losses.update(val_total_loss.item(),
                                            args.batch_size)
                    val_total_pixel_loss.update(
                        pixel_loss[args.save_which].item(), args.batch_size)
                    val_total_tv_loss.update(offset_loss[0].item(),
                                             args.batch_size)
                    val_total_sym_loss.update(sym_loss[0].item(),
                                              args.batch_size)
                    val_total_PSNR_loss.update(psnr_loss[0], args.batch_size)

                else:  # HD_dataset testing
                    for j in range(len(images) // 2):
                        mH, mW = 720, 1280
                        X0, y, X1 = crop(images[2 * j], maxH=mH,
                                         maxW=mW), crop(images[2 * j + 1],
                                                        maxH=mH,
                                                        maxW=mW), crop(
                                                            images[2 * j + 2],
                                                            maxH=mH,
                                                            maxW=mW)
                        diffs, offsets, filters, occlusions, output = model(
                            torch.stack((X0, y, X1), dim=0))

                        pixel_loss, offset_loss, sym_loss = part_loss(
                            diffs,
                            offsets,
                            occlusions, [X0, X1],
                            epsilon=args.epsilon)

                        val_total_loss = sum(
                            x * y for x, y in zip(args.alpha, pixel_loss))

                        per_sample_pix_error = torch.mean(torch.mean(
                            torch.mean(diffs[args.save_which]**2, dim=1),
                            dim=1),
                                                          dim=1)
                        per_sample_pix_error = per_sample_pix_error.data  # extract tensor
                        psnr_loss = torch.mean(
                            20 *
                            torch.log(1.0 / torch.sqrt(per_sample_pix_error))
                        ) / torch.log(torch.Tensor([10]))

                        val_total_losses.update(val_total_loss.item(),
                                                args.batch_size)
                        val_total_pixel_loss.update(
                            pixel_loss[args.save_which].item(),
                            args.batch_size)
                        val_total_tv_loss.update(offset_loss[0].item(),
                                                 args.batch_size)
                        val_total_sym_loss.update(sym_loss[0].item(),
                                                  args.batch_size)
                        val_total_PSNR_loss.update(psnr_loss[0],
                                                   args.batch_size)

            # Reset model to its base weights
            model.load_state_dict(weights_base)

            #del weights_base, inner_optimizer, meta_X0, meta_y, meta_X1, X0, y, X1, pixel_loss, offset_loss, sym_loss, total_loss, val_total_loss, diffs, offsets, filters, occlusions

            VIZ = False
            exp_name = 'meta_test'
            if VIZ:
                for b in range(images[0].size(0)):
                    imgpath = imgpaths[0][b]
                    savepath = os.path.join('checkpoint', exp_name,
                                            'vimeoSeptuplet',
                                            imgpath.split('/')[-3],
                                            imgpath.split('/')[-2])
                    if not os.path.exists(savepath):
                        os.makedirs(savepath)
                    img_pred = (output[b].data.permute(1, 2, 0).clamp_(
                        0, 1).cpu().numpy()[..., ::-1] * 255).astype(
                            numpy.uint8)
                    cv2.imwrite(os.path.join(savepath, 'im2_pred.png'),
                                img_pred)
            ''' # Original validation (not meta)
            with torch.no_grad():
                if args.use_cuda:
                    images = [im.cuda() for im in images]

                #X0, y, X1 = images[0], images[1], images[2]
                X0, y, X1 = images[2], images[3], images[4]

                    
                #diffs, offsets,filters,occlusions = model(torch.stack((X0,y,X1),dim = 0))

                pixel_loss, offset_loss,sym_loss = part_loss(diffs, offsets, occlusions, [X0,X1],epsilon=args.epsilon)

                val_total_loss = sum(x * y for x, y in zip(args.alpha, pixel_loss))

                per_sample_pix_error = torch.mean(torch.mean(torch.mean(diffs[args.save_which] ** 2,
                                                                    dim=1),dim=1),dim=1)
                per_sample_pix_error = per_sample_pix_error.data # extract tensor
                psnr_loss = torch.mean(20 * torch.log(1.0/torch.sqrt(per_sample_pix_error)))/torch.log(torch.Tensor([10]))
                #

                val_total_losses.update(val_total_loss.item(),args.batch_size)
                val_total_pixel_loss.update(pixel_loss[args.save_which].item(), args.batch_size)
                val_total_tv_loss.update(offset_loss[0].item(), args.batch_size)
                val_total_sym_loss.update(sym_loss[0].item(), args.batch_size)
                val_total_PSNR_loss.update(psnr_loss[0],args.batch_size)
                print(".",end='',flush=True)
            '''

        print("\nEpoch " + str(int(t)) + "\tlearning rate: " +
              str(float(ikk['lr'])) + "\tAvg Training Loss: " +
              str(round(training_losses.avg, 5)) + "\tValidate Loss: " +
              str([round(float(val_total_losses.avg), 5)]) +
              "\tValidate PSNR: " +
              str([round(float(val_total_PSNR_loss.avg), 5)]) +
              "\tPixel Loss: " +
              str([round(float(val_total_pixel_loss.avg), 5)]) +
              "\tTV Loss: " + str([round(float(val_total_tv_loss.avg), 4)]) +
              "\tPWS Loss: " + str([round(float(val_total_pws_loss.avg), 4)]) +
              "\tSym Loss: " + str([round(float(val_total_sym_loss.avg), 4)]))

        auxiliary_data.append([
            t,
            float(ikk['lr']), training_losses.avg, val_total_losses.avg,
            val_total_pixel_loss.avg, val_total_tv_loss.avg,
            val_total_pws_loss.avg, val_total_sym_loss.avg
        ])

        numpy.savetxt(args.log,
                      numpy.array(auxiliary_data),
                      fmt='%.8f',
                      delimiter=',')
        training_losses.reset()
        #original_training_losses.reset()

        print("\t\tFinished an epoch, Check and Save the model weights")
        # we check the validation loss instead of training loss. OK~
        if saved_total_loss >= val_total_losses.avg:
            saved_total_loss = val_total_losses.avg
            torch.save(model.state_dict(), args.save_path + "/best" + ".pth")
            print("\t\tBest Weights updated for decreased validation loss\n")

        else:
            print("\t\tWeights Not updated for undecreased validation loss\n")

        #schdule the learning rate
        scheduler.step(val_total_losses.avg)

    print("*********Finish Training********")
Example #5
0
def train():
    """Train the frame-interpolation network and validate once per epoch.

    All configuration is read from the module-level ``args`` namespace
    (network name, dataset name/path, learning rates, epoch counts,
    save paths, loss weights ``alpha``/``lambda1..4``, ...).

    Side effects:
        * streams per-iteration and per-epoch losses to a Visdom server,
        * appends human-readable log lines to ``args.save_path/all_log.txt``,
        * writes ``epoch<t>.pth`` (keeping only the latest), ``best.pth``
          (lowest validation loss) and ``bestPSNR.pth`` (highest validation
          PSNR) under ``args.save_path``,
        * writes a CSV history of the tracked metrics to ``args.log``.
    """

    # ============================================================== #
    #                       Init Visdom                              #
    # ============================================================== #
    viz_env = args.vis_env
    viz = Visdom(env=viz_env)

    # Seed both loss plots with a single zero point so the later
    # update='append' calls have an existing window to attach to.
    viz.line([[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]], [0.],
             win='train_respective_loss',
             env=viz_env,
             opts=dict(title='train_respective_loss',
                       legend=[
                           'pixel_loss_0', 'pixel_loss_1', 'offset_loss',
                           'occlusion_loss', 'sym_loss', 'total_loss',
                           'total_loss_avg'
                       ]))

    viz.line([[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]], [0.],
             win='val_respective_loss',
             env=viz_env,
             opts=dict(title='epoch_val_respective_loss',
                       legend=[
                           'training_losses', 'val_total_losses',
                           'val_total_PSNR_loss', 'val_total_pixel_loss',
                           'val_total_tv_loss', 'val_total_pws_loss',
                           'val_total_sym_loss'
                       ]))

    # viz.line([[0.0, 0.0]], [0.], win='validation_psnr', env=viz_env,
    #          opts=dict(title='val psnr', legend=['Resotred psnr', 'Blurry psnr']))
    # viz.line([[0.0, 0.0]], [0.], win='validation_ssim', env=viz_env,
    #          opts=dict(title='val ssim', legend=['Restored ssim', 'Blurry ssim']))

    # Fix the RNG seed for reproducible weight init / sampling.
    torch.manual_seed(args.seed)

    # Build the network by name; several structural options are pinned
    # (single scale, no temporal branch) while the rest come from args.
    model = networks.__dict__[args.netName](
        batch=args.batch_size,
        channel=args.channels,
        width=None,
        height=None,
        scale_num=1,
        scale_ratio=2,
        temporal=False,
        filter_size=args.filter_size,
        save_which=args.save_which,
        flowmethod=args.flowmethod,
        timestep=args.time_step,
        FlowProjection_threshhold=args.flowproj_threshhold,
        offset_scale=None,
        cuda_available=args.use_cuda,
        cuda_id=None,
        training=True)
    if args.use_cuda:
        print("Turn the model into CUDA")
        model = model.cuda()

    # Optionally warm-start from a previously saved "best" checkpoint.
    if not args.SAVED_MODEL == None:
        args.SAVED_MODEL = '../model_weights/' + args.SAVED_MODEL + "/best" + ".pth"
        print("Fine tuning on " + args.SAVED_MODEL)
        if not args.use_cuda:
            # map_location keeps CUDA-saved tensors on CPU.
            pretrained_dict = torch.load(
                args.SAVED_MODEL, map_location=lambda storage, loc: storage)
            # model.load_state_dict(torch.load(args.SAVED_MODEL, map_location=lambda storage, loc: storage))
        else:
            pretrained_dict = torch.load(args.SAVED_MODEL)
            # model.load_state_dict(torch.load(args.SAVED_MODEL))
        #print([k for k,v in      pretrained_dict.items()])

        model_dict = model.state_dict()
        # 1. filter out unnecessary keys (tolerates architecture drift
        #    between the checkpoint and the current model)
        pretrained_dict = {
            k: v
            for k, v in pretrained_dict.items() if k in model_dict
        }  # and not k[:10]== 'rectifyNet'}
        # 2. overwrite entries in the existing state dict
        model_dict.update(pretrained_dict)
        # 3. load the new state dict
        model.load_state_dict(model_dict)
        # Drop the reference so the checkpoint tensors can be freed.
        pretrained_dict = None

        # torch.save(model.depthNet.state_dict(), "8402_best_depth" + ".pth")

    # Build train/test datasets; a list of dataset names means several
    # datasets are concatenated into one.
    if type(args.datasetName) == list:
        train_sets, test_sets = [], []
        for ii, jj in zip(args.datasetName, args.datasetPath):
            tr_s, te_s = datasets.__dict__[ii](
                jj,
                split=args.dataset_split,
                single=args.single_output,
                task=args.task,
                middle=args.time_step == 0.5,
                high_fps=args.high_fps)  # if time_step = 0.5, only use middle
            train_sets.append(tr_s)
            test_sets.append(te_s)
        train_set = torch.utils.data.ConcatDataset(train_sets)
        test_set = torch.utils.data.ConcatDataset(test_sets)
    else:
        train_set, test_set = datasets.__dict__[args.datasetName](
            args.datasetPath,
            split=args.dataset_split,
            single=args.single_output,
            task=args.task,
            middle=args.time_step == 0.5,
            high_fps=args.high_fps)
    train_loader = torch.utils.data.DataLoader(
        train_set,
        batch_size=args.batch_size,
        sampler=balancedsampler.RandomBalancedSampler(
            train_set, int(len(train_set) / args.batch_size)),
        # RandomBalancedSampler(train_set,args.epoch_size),
        num_workers=args.workers,
        pin_memory=True if args.use_cuda else False)

    val_loader = torch.utils.data.DataLoader(
        test_set,
        batch_size=args.batch_size,
        # sampler=balancedsampler.SequentialBalancedSampler(test_set,)
        num_workers=args.workers,
        pin_memory=True if args.use_cuda else False)
    print('{} samples found, {} train samples and {} test samples '.format(
        len(test_set) + len(train_set), len(train_set), len(test_set)))

    # to skip the fixed parameters of vgg model, we need to filter them out...
    # for param in model.parameters():
    #     print(type(param.data), param.size())
    # for idx, m in enumerate(model.named_modules()):
    #     print(idx, '->', m)
    # optimizer = torch.optim.Adamax(filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr, betas=(0.9, 0.999), eps=1e-8)
    # optimizer = torch.optim.Adamax(model.parameters(), lr=args.lr, betas=(0.9, 0.999), eps=1e-8,args.weight_decay=args.weight_decay)
    # NOTE(review): if args.lr != 0 but args.netName != 'DAIN', neither
    # branch assigns `optimizer`, so the first use below would raise
    # NameError — confirm only DAIN is used with a nonzero lr.
    if not args.lr == 0:
        print("train the interpolation net")
        # Per-submodule learning rates: each sub-network gets its own
        # coefficient on the base lr; rectifyNet uses an absolute lr.
        if args.netName == 'DAIN':
            optimizer = torch.optim.Adamax(
                [{
                    'params': model.initScaleNets_filter.parameters(),
                    'lr': args.filter_lr_coe * args.lr
                }, {
                    'params': model.initScaleNets_filter1.parameters(),
                    'lr': args.filter_lr_coe * args.lr
                }, {
                    'params': model.initScaleNets_filter2.parameters(),
                    'lr': args.filter_lr_coe * args.lr
                }, {
                    'params': model.ctxNet.parameters(),
                    'lr': args.ctx_lr_coe * args.lr
                }, {
                    'params': model.flownets.parameters(),
                    'lr': args.flow_lr_coe * args.lr
                }, {
                    'params': model.depthNet.parameters(),
                    'lr': args.depth_lr_coe * args.lr
                }, {
                    'params': model.rectifyNet.parameters(),
                    'lr': args.rectify_lr
                }],
                lr=args.lr,
                betas=(0.9, 0.999),
                eps=1e-8,
                weight_decay=args.weight_decay)
    else:
        # lr == 0 freezes everything except rectifyNet.
        print("Only train the rectifyNet")
        optimizer = torch.optim.Adamax([{
            'params': model.rectifyNet.parameters(),
            'lr': args.rectify_lr
        }],
                                       lr=args.lr,
                                       betas=(0.9, 0.999),
                                       eps=1e-8,
                                       weight_decay=args.weight_decay)

    # Shrink the lr by `factor` when the validation loss plateaus.
    scheduler = ReduceLROnPlateau(optimizer,
                                  'min',
                                  factor=args.factor,
                                  patience=args.patience,
                                  verbose=True)

    print("*********Start Training********")
    print("LR is: " + str(float(optimizer.param_groups[0]['lr'])))
    print("EPOCH is: " + str(int(len(train_set) / args.batch_size)))
    print("Num of EPOCH is: " + str(args.numEpoch))

    def count_network_parameters(model):
        # Count only trainable parameters (requires_grad=True).

        parameters = filter(lambda p: p.requires_grad, model.parameters())
        N = sum([numpy.prod(p.size()) for p in parameters])

        return N

    # Report parameter counts per sub-network for whichever submodules
    # the chosen architecture actually has.
    print("Num. of model parameters is :" +
          str(count_network_parameters(model)))
    if hasattr(model, 'flownets'):
        print("Num. of flow model parameters is :" +
              str(count_network_parameters(model.flownets)))
    if hasattr(model, 'initScaleNets_occlusion'):
        print("Num. of initScaleNets_occlusion model parameters is :" + str(
            count_network_parameters(model.initScaleNets_occlusion) +
            count_network_parameters(model.initScaleNets_occlusion1) +
            count_network_parameters(model.initScaleNets_occlusion2)))
    if hasattr(model, 'initScaleNets_filter'):
        print("Num. of initScaleNets_filter model parameters is :" + str(
            count_network_parameters(model.initScaleNets_filter) +
            count_network_parameters(model.initScaleNets_filter1) +
            count_network_parameters(model.initScaleNets_filter2)))
    if hasattr(model, 'ctxNet'):
        print("Num. of ctxNet model parameters is :" +
              str(count_network_parameters(model.ctxNet)))
    if hasattr(model, 'depthNet'):
        print("Num. of depthNet model parameters is :" +
              str(count_network_parameters(model.depthNet)))

    if hasattr(model, 'rectifyNet'):
        print("Num. of rectifyNet model parameters is :" +
              str(count_network_parameters(model.rectifyNet)))

    if hasattr(model, 'fea_exat_net'):
        print("Num. of fea_exat_net model parameters is :" +
              str(count_network_parameters(model.fea_exat_net)))

    training_losses = AverageMeter()
    auxiliary_data = []
    saved_total_loss = 10e10   # best (lowest) validation loss so far
    saved_total_PSNR = -1      # best (highest) validation PSNR so far
    saved_total_loss_MB = 10e10
    MB_avgLoss, MB_avgPSNR = 1e5, 0
    # `ikk` references the first param group with a positive lr, used
    # below purely for reporting the current learning rate.
    ikk = 0
    for kk in optimizer.param_groups:
        if kk['lr'] > 0:
            ikk = kk
            break

    for t in range(args.numEpoch):
        print("The id of this in-training network is " + str(args.uid))
        print(args)
        #Turn into training mode
        model = model.train()

        for i, (X0_half, X1_half, y_half,
                frame_index) in enumerate(train_loader):

            # Cap the number of iterations per epoch at
            # N_iter * (dataset size / batch size).
            if i >= (args.N_iter * int(len(train_set) / args.batch_size)
                     ):  #(0 if t == 0 else EPOCH):#
                break

            X0_half = X0_half.cuda() if args.use_cuda else X0_half
            X1_half = X1_half.cuda() if args.use_cuda else X1_half  # middle
            y_half = y_half.cuda() if args.use_cuda else y_half

            X0 = Variable(X0_half, requires_grad=False)
            X1 = Variable(X1_half, requires_grad=False)
            y = Variable(y_half, requires_grad=False)

            # Model input is the (first, ground-truth, second) frame
            # stack; one architecture ignores the frame index.
            if args.netName == 'MultiScaleStructure_filt_flo_ctxS2D_depth_Modeling3':
                diffs, offsets, filters, occlusions = model(
                    torch.stack((X0, y, X1), dim=0))
            else:
                diffs, offsets, filters, occlusions = model(
                    torch.stack((X0, y, X1), dim=0), frame_index)

            pixel_loss, offset_loss, occlusion_loss, sym_loss = part_loss(
                diffs,
                offsets,
                occlusions, [X0, X1],
                epsilon=args.epsilon,
                use_negPSNR=args.use_negPSNR)

            DF_loss = df_loss_func(offsets, occlusions)

            # Weighted sum of the loss terms; `x > 0 else 0` skips terms
            # whose weight is non-positive so they contribute no gradient.
            total_loss = sum(x*y if x > 0 else 0 for x,y in zip(args.alpha, pixel_loss)) + sum(x*y for x,y in zip(args.lambda1, offset_loss) )  + \
                         sum(x*y if x > 0 else 0 for x,y in zip(args.lambda2, occlusion_loss)) + \
                         sum(x*y if x > 0 else 0 for x,y  in zip(args.lambda3, sym_loss)) +\
                         sum(x*y if x>0 else 0 for x,y in zip(args.lambda4, [DF_loss]))

            training_losses.update(total_loss.item(),
                                   args.batch_size)  #.item(),
            # Log roughly 500 times per epoch (at least every iteration
            # for small datasets).
            if i % max(1, int(
                    int(len(train_set) / args.batch_size) / 500.0)) == 0:

                pstring = "Ep [" + str(t) +"/" + str(i) + \
                                    "]\tl.r.: " + str(round(float(ikk['lr']),7))+ \
                                    "\tPix: " + str([round(x.item(),5) for x in pixel_loss]) + \
                                    "\tTV: " + str([round(x.item(),4)  for x in offset_loss]) + \
                                    "\tPWS: " + str([round(x.item(), 4) for x in occlusion_loss]) + \
                                    "\tSym: " + str([round(x.item(), 4) for x in sym_loss]) + \
                                    "\tTotal: " + str([round(x.item(),5) for x in [total_loss]]) + \
                                    "\tAvg. Loss: " + str([round(training_losses.avg, 5)])

                print(pstring)
                print(pstring,
                      file=open(os.path.join(args.save_path, "all_log.txt"),
                                "a"))

            # visdom display: global iteration index across epochs
            itr = i + t * (args.N_iter * int(len(train_set) / args.batch_size))
            viz.line([[
                pixel_loss[0].item(), pixel_loss[1].item(),
                offset_loss[0].item(), occlusion_loss[0].item(),
                sym_loss[0].item(),
                total_loss.item(), training_losses.avg
            ]], [itr],
                     win='train_respective_loss',
                     env=viz_env,
                     opts=dict(title='train_respective_loss',
                               legend=[
                                   'pixel_loss_0', 'pixel_loss_1',
                                   'offset_loss', 'occlusion_loss', 'sym_loss',
                                   'total_loss', 'total_loss_avg'
                               ]),
                     update='append')

            # Standard optimization step.
            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()

        if t == 1:
            # delete the pre validation weights for cleaner workspace
            if os.path.exists(args.save_path + "/epoch" + str(0) + ".pth"):
                os.remove(args.save_path + "/epoch" + str(0) + ".pth")

        # Keep only the most recent per-epoch checkpoint.
        if os.path.exists(args.save_path + "/epoch" + str(t - 1) + ".pth"):
            os.remove(args.save_path + "/epoch" + str(t - 1) + ".pth")
        torch.save(model.state_dict(),
                   args.save_path + "/epoch" + str(t) + ".pth")

        # print("\t\t**************Start Validation*****************")
        #Turn into evaluation mode
        # NOTE(review): model.eval() is never actually called here —
        # validation runs with the model still in train mode; confirm
        # this is intended (e.g. for BN/dropout behavior).

        val_total_losses = AverageMeter()
        val_total_pixel_loss = AverageMeter()
        val_total_PSNR_loss = AverageMeter()
        val_total_tv_loss = AverageMeter()
        val_total_pws_loss = AverageMeter()
        val_total_sym_loss = AverageMeter()

        for i, (X0, X1, y, frame_index) in enumerate(val_loader):

            if i >= int(len(test_set) / args.batch_size):
                break

            with torch.no_grad():

                X0 = X0.cuda() if args.use_cuda else X0
                X1 = X1.cuda() if args.use_cuda else X1
                y = y.cuda() if args.use_cuda else y

                if args.netName == 'MultiScaleStructure_filt_flo_ctxS2D_depth_Modeling3':
                    diffs, offsets, filters, occlusions = model(
                        torch.stack((X0, y, X1), dim=0))

                else:
                    diffs, offsets, filters, occlusions = model(
                        torch.stack((X0, y, X1), dim=0), frame_index)

                pixel_loss, offset_loss, occlusion_loss, sym_loss = part_loss(
                    diffs,
                    offsets,
                    occlusions, [X0, X1],
                    epsilon=args.epsilon,
                    use_negPSNR=args.use_negPSNR)

                # Validation total uses all terms unconditionally (no
                # positive-weight gating, and no lambda4/DF term).
                val_total_loss = sum(x * y for x, y in zip(args.alpha, pixel_loss)) + \
                            sum(x * y for x, y in zip(args.lambda1, offset_loss)) + \
                            sum(x * y for x, y in zip(args.lambda2, occlusion_loss)) + \
                             sum(x * y for x, y in zip(args.lambda3, sym_loss))

                # Mean squared pixel error per sample (averaged over
                # the three non-batch dims), then PSNR in dB:
                # 20*log10(1/sqrt(MSE)) via natural logs.
                per_sample_pix_error = torch.mean(torch.mean(torch.mean(
                    diffs[args.save_which]**2, dim=1),
                                                             dim=1),
                                                  dim=1)
                per_sample_pix_error = per_sample_pix_error.data  # extract tensor
                # print(per_sample_pix_error.size())
                # print(per_sample_pix_error.type())
                psnr_loss = torch.mean(20 * torch.log(
                    1.0 / torch.sqrt(per_sample_pix_error))) / torch.log(
                        torch.Tensor([10]))

                val_total_losses.update(val_total_loss.item(), args.batch_size)
                val_total_pixel_loss.update(pixel_loss[args.save_which].item(),
                                            args.batch_size)
                val_total_tv_loss.update(offset_loss[0].item(),
                                         args.batch_size)
                val_total_pws_loss.update(occlusion_loss[0].item(),
                                          args.batch_size)
                val_total_sym_loss.update(sym_loss[0].item(), args.batch_size)
                val_total_PSNR_loss.update(psnr_loss[0], args.batch_size)
                # Progress dot per validation batch.
                print(".", end='', flush=True)

        pstring = "\nEpoch " + str(int(t)) + \
                  "\tlearning rate: " + str(float(ikk['lr'])) + \
                  "\tAvg Training Loss: " + str(round(training_losses.avg, 5)) + \
                  "\tValidate Loss: " + str([round(float(val_total_losses.avg), 5)]) + \
                  "\tValidate PSNR: " + str([round(float(val_total_PSNR_loss.avg), 5)]) + \
                  "\tPixel Loss: " + str([round(float(val_total_pixel_loss.avg), 5)]) + \
                  "\tTV Loss: " + str([round(float(val_total_tv_loss.avg), 4)]) + \
                  "\tPWS Loss: " + str([round(float(val_total_pws_loss.avg), 4)]) + \
                  "\tSym Loss: " + str([round(float(val_total_sym_loss.avg), 4)])

        print(pstring)
        print(pstring,
              file=open(os.path.join(args.save_path, "all_log.txt"), "a"))
        # visdom: per-epoch validation summary plot
        viz.line([[
            training_losses.avg, val_total_losses.avg, val_total_PSNR_loss.avg,
            val_total_pixel_loss.avg, val_total_tv_loss.avg,
            val_total_pws_loss.avg, val_total_sym_loss.avg
        ]], [int(t)],
                 win='val_respective_loss',
                 env=viz_env,
                 opts=dict(title='epoch_val_respective_loss',
                           legend=[
                               'training_losses', 'val_total_losses',
                               'val_total_PSNR_loss', 'val_total_pixel_loss',
                               'val_total_tv_loss', 'val_total_pws_loss',
                               'val_total_sym_loss'
                           ]),
                 update='append')

        # todo: middlebury benchmark numbers not computed yet; logged as 0
        MB_avgLoss = 0
        MB_avgPSNR = 0

        auxiliary_data.append([
            t,
            float(ikk['lr']), training_losses.avg, val_total_losses.avg,
            val_total_pixel_loss.avg, val_total_tv_loss.avg,
            val_total_pws_loss.avg, val_total_sym_loss.avg, MB_avgLoss,
            MB_avgPSNR
        ])

        # Rewrite the full CSV history each epoch.
        numpy.savetxt(args.log,
                      numpy.array(auxiliary_data),
                      fmt='%.8f',
                      delimiter=',')
        training_losses.reset()

        print("\t\tFinished an epoch, Check and Save the model weights")
        # we check the validation loss instead of training loss. OK~
        if saved_total_loss >= val_total_losses.avg:
            saved_total_loss = val_total_losses.avg
            torch.save(model.state_dict(), args.save_path + "/best" + ".pth")
            print("\t\tBest Weights updated for decreased validation loss\n")

        else:
            print("\t\tWeights Not updated for undecreased validation loss\n")
        # Separately track the checkpoint with the best validation PSNR.
        if saved_total_PSNR <= val_total_PSNR_loss.avg:
            saved_total_PSNR = val_total_PSNR_loss.avg
            # torch.save(model,MODEL_PATH)
            # model.save_state_dict(MODEL_PATH)
            torch.save(model.state_dict(),
                       args.save_path + "/bestPSNR" + ".pth")
            print(
                "\t\tBest Weights updated for increased validation PSNR \n\n")

        else:
            print(
                "\t\tWeights Not updated for unincreased validation PSNR\n\n")

        # schedule the learning rate based on validation loss
        scheduler.step(val_total_losses.avg)

    print("*********Finish Training********")
Example #6
0
def train():
    """Fine-tune a pre-trained DAIN interpolation network on pixel-art
    triplets, jointly with an adversarial discriminator (RaLSGAN loss).

    Hyperparameters not hard-coded below (learning rates, epoch count,
    log/save paths, epsilon) come from the module-level ``args`` namespace;
    ``unique_id`` is also read from module scope.  Requires CUDA.

    Side effects: writes per-epoch checkpoints and a running ``best.pth``
    under ``args.save_path``, and appends loss statistics to ``args.log``.
    """
    SAVED_MODEL_PATH = "./model_weights/pretrained.pth"
    DATA_PATH = "./pixel_triplets/"
    BATCH_SIZE = 1

    # Fixed seeds for reproducibility of sampling and initialization.
    torch.manual_seed(1337)
    random.seed(1337)

    # -------------------------------------
    #  load pre-trained model
    # -------------------------------------
    model = networks.DAIN(channel=3,
                          filter_size=4,
                          timestep=0.5,
                          training=False,
                          pixel_model=True)
    model = model.cuda()

    print("Fine tuning on " + SAVED_MODEL_PATH)

    pretrained_dict = torch.load(SAVED_MODEL_PATH)

    model_dict = model.state_dict()
    # 1. filter out unnecessary keys (checkpoint may carry extra entries)
    pretrained_dict = {
        k: v
        for k, v in pretrained_dict.items() if k in model_dict
    }
    # 2. overwrite entries in the existing state dict
    model_dict.update(pretrained_dict)
    # 3. load the new state dict
    model.load_state_dict(model_dict)
    # Drop the reference so checkpoint tensors can be garbage-collected.
    pretrained_dict = None

    # -------------------------------------
    #  create discriminator
    # -------------------------------------
    discrim = Discriminator()
    discrim = discrim.cuda()

    # discriminator optimizer (its RaLSGAN loss is built inline below)
    optimizer_discrim = torch.optim.Adam(discrim.parameters(), lr=0.0005)

    # -------------------------------------
    #  create dataset loaders
    # -------------------------------------
    train_set, test_set = datasets.pixel_triplets(DATA_PATH)
    train_loader = torch.utils.data.DataLoader(
        train_set,
        batch_size=BATCH_SIZE,
        sampler=balancedsampler.RandomBalancedSampler(
            train_set, int(len(train_set) / BATCH_SIZE)),
        num_workers=8,
        pin_memory=True)

    val_loader = torch.utils.data.DataLoader(test_set,
                                             batch_size=BATCH_SIZE,
                                             num_workers=8,
                                             pin_memory=True)

    print('{} samples found, {} train samples and {} test samples '.format(
        len(test_set) + len(train_set), len(train_set), len(test_set)))

    # -------------------------------------
    #  create optimizer / LR scheduler
    # -------------------------------------
    # Per-submodule learning rates: each sub-network gets its own
    # coefficient on the base LR (rectifyNet uses an absolute LR).
    print("train the interpolation net")
    optimizer = torch.optim.Adamax(
        [{
            'params': model.initScaleNets_filter.parameters(),
            'lr': args.filter_lr_coe * args.lr
        }, {
            'params': model.initScaleNets_filter1.parameters(),
            'lr': args.filter_lr_coe * args.lr
        }, {
            'params': model.initScaleNets_filter2.parameters(),
            'lr': args.filter_lr_coe * args.lr
        }, {
            'params': model.ctxNet.parameters(),
            'lr': args.ctx_lr_coe * args.lr
        }, {
            'params': model.flownets.parameters(),
            'lr': args.flow_lr_coe * args.lr
        }, {
            'params': model.depthNet.parameters(),
            'lr': args.depth_lr_coe * args.lr
        }, {
            'params': model.rectifyNet.parameters(),
            'lr': args.rectify_lr
        }],
        lr=args.lr,
        betas=(0.9, 0.999),
        eps=1e-8,
        weight_decay=args.weight_decay)

    # NOTE(review): the scheduler is created but its step() call at the end
    # of each epoch is intentionally disabled below, so the LR stays fixed.
    scheduler = ReduceLROnPlateau(optimizer,
                                  'min',
                                  factor=args.factor,
                                  patience=args.patience,
                                  verbose=True)

    # -------------------------------------
    #  print out some info before we start
    # -------------------------------------
    print("*********Start Training********")
    print("LR is: " + str(float(optimizer.param_groups[0]['lr'])))
    print("EPOCH is: " + str(int(len(train_set) / BATCH_SIZE)))
    print("Num of EPOCH is: " + str(args.numEpoch))

    def count_network_parameters(model):
        # Count only trainable (requires_grad) parameters.
        parameters = filter(lambda p: p.requires_grad, model.parameters())
        N = sum([numpy.prod(p.size()) for p in parameters])

        return N

    print("Num. of model parameters is:", count_network_parameters(model))
    if hasattr(model, 'flownets'):
        print("Num. of flow model parameters is:",
              count_network_parameters(model.flownets))
    if hasattr(model, 'initScaleNets_occlusion'):
        print(
            "Num. of initScaleNets_occlusion model parameters is:",
            count_network_parameters(model.initScaleNets_occlusion) +
            count_network_parameters(model.initScaleNets_occlusion1) +
            count_network_parameters(model.initScaleNets_occlusion2))
    if hasattr(model, 'initScaleNets_filter'):
        print(
            "Num. of initScaleNets_filter model parameters is:",
            count_network_parameters(model.initScaleNets_filter) +
            count_network_parameters(model.initScaleNets_filter1) +
            count_network_parameters(model.initScaleNets_filter2))
    if hasattr(model, 'ctxNet'):
        print("Num. of ctxNet model parameters is:",
              count_network_parameters(model.ctxNet))
    if hasattr(model, 'depthNet'):
        print("Num. of depthNet model parameters is:",
              count_network_parameters(model.depthNet))
    if hasattr(model, 'rectifyNet'):
        print("Num. of rectifyNet model parameters is:",
              count_network_parameters(model.rectifyNet))
    print("Num. of discriminator model parameters is:",
          count_network_parameters(discrim))

    # -------------------------------------
    #  and heeere we go
    # -------------------------------------

    # discriminator pretrains alone for this many epochs (0 = disabled)
    PRETRAINING_EPOCHS = 0

    training_losses = AverageMeter()
    auxiliary_data = []
    saved_total_loss = 10e10

    # ikk: the first parameter group with a non-zero LR, used only for
    # displaying the current learning rate in the logs below.
    ikk = 0
    for kk in optimizer.param_groups:
        if kk['lr'] > 0:
            ikk = kk
            break

    for t in range(args.numEpoch):
        print("The id of this in-training network is " + unique_id)

        if (t < PRETRAINING_EPOCHS):
            print("-- Discriminator pre-training epoch --")
        elif (t == PRETRAINING_EPOCHS):
            print("-- End discriminator pre-training --")

        # turn into training mode
        model = model.train()
        discrim = discrim.train()

        for i, (X0_half, X1_half, y_half) in enumerate(train_loader):
            loss_function = charbonnier_loss

            if i >= int(len(train_set) / BATCH_SIZE):
                break

            # During discriminator pre-training only a short slice of the
            # epoch is consumed.
            if (t < PRETRAINING_EPOCHS and i >= 100):
                break

            X0_half = X0_half.cuda()
            X1_half = X1_half.cuda()
            y_half = y_half.cuda()

            X0 = Variable(X0_half, requires_grad=False)
            X1 = Variable(X1_half, requires_grad=False)
            y = Variable(y_half, requires_grad=False)

            # Placeholder zero losses so the logging below works even in
            # discriminator pre-training epochs where the generator skips
            # its update.
            model_pixel_loss = torch.zeros(1, )
            model_dsc_loss = torch.zeros(1, )
            total_loss = torch.zeros(1, )

            # --------------------------------------------
            #  train the interpolation network
            #      (using the cycle consistency method from Reda, et al.)
            # --------------------------------------------
            optimizer.zero_grad()

            # predict the frame between X0 and X1
            model_input = torch.stack((X0, X1), dim=0)
            model_output = model(model_input)

            y_est = model_output[0:1]

            # concatenate real and fake so we can do everything in two forward passes
            discrim_batch = torch.cat((y.detach(), y_est.detach()), dim=0)

            if (t >= PRETRAINING_EPOCHS):
                # pixel losses
                model_pixel_loss = charbonnier_loss(y_est, y)

                # discriminator losses
                C_y, C_y_est = discrim(discrim_batch)

                # Generator-side RaLSGAN (relativistic average least-squares)
                # loss: pushes the critic of the fake above the critic of the
                # real and vice versa.
                C_diff = ((C_y - C_y_est - 1.0)**2 +
                          (C_y_est - C_y + 1.0)**2) / 2.0
                model_dsc_loss = C_diff

                total_loss = model_pixel_loss + 0.01 * model_dsc_loss

                total_loss.backward()
                optimizer.step()

            # --------------------------------------------
            #  train the discriminator
            # --------------------------------------------
            optimizer_discrim.zero_grad()

            # Fresh forward pass: discrim_batch is detached, so only the
            # discriminator receives gradients here.
            C_y, C_y_est = discrim(discrim_batch)

            # discriminator's RaLSGAN loss is the reverse of the generator's
            C_diff = ((C_y_est - C_y - 1.0)**2 +
                      (C_y - C_y_est + 1.0)**2) / 2.0
            discrim_loss = C_diff

            discrim_loss.backward()
            optimizer_discrim.step()

            # --------------------------------------------
            #  finally, output some stuff
            # --------------------------------------------
            training_losses.update(total_loss.item(), BATCH_SIZE)
            # Log roughly 500 times per epoch.
            if i % max(1, int(int(len(train_set) / BATCH_SIZE) / 500.0)) == 0:

                print(
                    "Ep [" + str(t) + "/" + str(i) + "]\tl.r.: " +
                    str(round(float(ikk['lr']), 7)) + "\tPix: " +
                    str([round(model_pixel_loss.item(), 5)]) +
                    "\tFool: " + str(round(model_dsc_loss.item(), 5)) +
                    "\tTotal: " +
                    str([round(x.item(), 5) for x in [total_loss]]) +
                    "\tDiscrim: " + str(round(discrim_loss.item(), 5)) +
                    "\tAvg. Loss: " + str([round(training_losses.avg, 5)]))

        # No validation / checkpointing during discriminator pre-training.
        if (t < PRETRAINING_EPOCHS):
            continue

        torch.save(model.state_dict(),
                   args.save_path + "/epoch" + str(t) + ".pth")

        # print("\t\t**************Start Validation*****************")

        # Turn into evaluation mode.
        # FIX(review): the original only had this comment with no call, so
        # BatchNorm/Dropout submodules stayed in training mode during
        # validation; model.train() at the top of the next epoch restores
        # training behavior.
        model = model.eval()

        val_total_losses = AverageMeter()
        val_total_pixel_loss = AverageMeter()
        val_total_PSNR_loss = AverageMeter()

        for i, (X0, X1, y) in enumerate(val_loader):
            if i >= int(len(test_set) / BATCH_SIZE):
                break

            with torch.no_grad():
                X0 = X0.cuda()
                X1 = X1.cuda()
                y = y.cuda()

                y_est = model(torch.stack((X0, X1), dim=0))

                # Charbonnier-style robust pixel loss.
                y_diff = y_est - y
                pixel_loss = torch.mean(
                    torch.sqrt(y_diff * y_diff + args.epsilon * args.epsilon))

                val_total_loss = pixel_loss

                # Mean squared error per sample (reduced over C, H, W),
                # then converted to PSNR in dB.
                per_sample_pix_error = torch.mean(torch.mean(torch.mean(
                    y_diff**2, dim=1),
                                                             dim=1),
                                                  dim=1)
                per_sample_pix_error = per_sample_pix_error.data  # extract tensor
                psnr_loss = torch.mean(20 * torch.log(
                    1.0 / torch.sqrt(per_sample_pix_error))) / torch.log(
                        torch.Tensor([10]))

                val_total_losses.update(val_total_loss.item(), BATCH_SIZE)
                val_total_pixel_loss.update(pixel_loss.item(), BATCH_SIZE)
                val_total_PSNR_loss.update(psnr_loss[0], BATCH_SIZE)
                print(".", end='', flush=True)

        print("\nEpoch " + str(int(t)) + "\tlearning rate: " +
              str(float(ikk['lr'])) + "\tAvg Training Loss: " +
              str(round(training_losses.avg, 5)) + "\tValidate Loss: " +
              str([round(float(val_total_losses.avg), 5)]) +
              "\tValidate PSNR: " +
              str([round(float(val_total_PSNR_loss.avg), 5)]) +
              "\tPixel Loss: " +
              str([round(float(val_total_pixel_loss.avg), 5)]))

        auxiliary_data.append([
            t,
            float(ikk['lr']), training_losses.avg, val_total_losses.avg,
            val_total_pixel_loss.avg
        ])

        # Rewrite the full CSV log each epoch (auxiliary_data accumulates).
        numpy.savetxt(args.log,
                      numpy.array(auxiliary_data),
                      fmt='%.8f',
                      delimiter=',')
        training_losses.reset()

        print("\t\tFinished an epoch, Check and Save the model weights")
        # we check the validation loss instead of training loss. OK~
        if saved_total_loss >= val_total_losses.avg:
            saved_total_loss = val_total_losses.avg
            torch.save(model.state_dict(), args.save_path + "/best" + ".pth")
            print("\t\tBest Weights updated for decreased validation loss\n")

        else:
            print("\t\tWeights Not updated for undecreased validation loss\n")

        # LR scheduling deliberately disabled for this fine-tuning run:
        # scheduler.step(val_total_losses.avg)

    print("*********Finish Training********")