def train():
    """Fine-tune and train the DAIN interpolation network.

    Reads all configuration from the module-level ``args`` namespace:
    builds the network, optionally restores pretrained weights, trains
    for ``args.numEpoch`` epochs with per-sub-network learning rates,
    validates each epoch, logs losses to ``args.log`` and keeps the
    best checkpoint (by validation loss) in ``args.save_path``.
    """
    torch.manual_seed(args.seed)

    model = networks.__dict__[args.netName](channel=args.channels,
                                            filter_size=args.filter_size,
                                            timestep=args.time_step,
                                            training=True)
    if args.use_cuda:
        print("Turn the model into CUDA")
        model = model.cuda()

    if args.SAVED_MODEL is not None:
        # BUGFIX: the original concatenation was missing the '/' between the
        # weights directory and the model name, producing e.g.
        # "/content/DAIN/model_weightsfoo/best.pth".
        args.SAVED_MODEL = '/content/DAIN/model_weights/' + args.SAVED_MODEL + "/best" + ".pth"
        print("Fine tuning on " + args.SAVED_MODEL)
        if not args.use_cuda:
            # Remap GPU-saved tensors onto the CPU.
            pretrained_dict = torch.load(args.SAVED_MODEL,
                                         map_location=lambda storage, loc: storage)
        else:
            pretrained_dict = torch.load(args.SAVED_MODEL)

        model_dict = model.state_dict()
        # 1. filter out unnecessary keys
        pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
        # 2. overwrite entries in the existing state dict
        model_dict.update(pretrained_dict)
        # 3. load the new state dict
        model.load_state_dict(model_dict)
        pretrained_dict = None  # release the checkpoint memory

    # args.datasetName may be a single name or a list of names paired with paths.
    if isinstance(args.datasetName, list):
        train_sets, test_sets = [], []
        for ii, jj in zip(args.datasetName, args.datasetPath):
            tr_s, te_s = datasets.__dict__[ii](jj, split=args.dataset_split,
                                               single=args.single_output,
                                               task=args.task)
            train_sets.append(tr_s)
            test_sets.append(te_s)
        train_set = torch.utils.data.ConcatDataset(train_sets)
        test_set = torch.utils.data.ConcatDataset(test_sets)
    else:
        train_set, test_set = datasets.__dict__[args.datasetName](args.datasetPath)

    train_loader = torch.utils.data.DataLoader(
        train_set, batch_size=args.batch_size,
        sampler=balancedsampler.RandomBalancedSampler(
            train_set, int(len(train_set) / args.batch_size)),
        num_workers=args.workers,
        pin_memory=True if args.use_cuda else False)
    val_loader = torch.utils.data.DataLoader(
        test_set, batch_size=args.batch_size,
        num_workers=args.workers,
        pin_memory=True if args.use_cuda else False)
    print('{} samples found, {} train samples and {} test samples '.format(
        len(test_set) + len(train_set), len(train_set), len(test_set)))

    print("train the interpolation net")
    # Each sub-network gets its own learning rate, scaled from args.lr.
    optimizer = torch.optim.Adamax([
        {'params': model.initScaleNets_filter.parameters(), 'lr': args.filter_lr_coe * args.lr},
        {'params': model.initScaleNets_filter1.parameters(), 'lr': args.filter_lr_coe * args.lr},
        {'params': model.initScaleNets_filter2.parameters(), 'lr': args.filter_lr_coe * args.lr},
        {'params': model.ctxNet.parameters(), 'lr': args.ctx_lr_coe * args.lr},
        {'params': model.flownets.parameters(), 'lr': args.flow_lr_coe * args.lr},
        {'params': model.depthNet.parameters(), 'lr': args.depth_lr_coe * args.lr},
        {'params': model.rectifyNet.parameters(), 'lr': args.rectify_lr}
    ], lr=args.lr, betas=(0.9, 0.999), eps=1e-8, weight_decay=args.weight_decay)

    scheduler = ReduceLROnPlateau(optimizer, 'min', factor=args.factor,
                                  patience=args.patience, verbose=True)

    print("*********Start Training********")
    print("LR is: " + str(float(optimizer.param_groups[0]['lr'])))
    print("EPOCH is: " + str(int(len(train_set) / args.batch_size)))
    print("Num of EPOCH is: " + str(args.numEpoch))

    def count_network_parameters(model):
        # Count only trainable (requires_grad) parameters.
        parameters = filter(lambda p: p.requires_grad, model.parameters())
        N = sum([numpy.prod(p.size()) for p in parameters])
        return N

    print("Num. of model parameters is :" + str(count_network_parameters(model)))
    if hasattr(model, 'flownets'):
        print("Num. of flow model parameters is :" +
              str(count_network_parameters(model.flownets)))
    if hasattr(model, 'initScaleNets_occlusion'):
        print("Num. of initScaleNets_occlusion model parameters is :" +
              str(count_network_parameters(model.initScaleNets_occlusion) +
                  count_network_parameters(model.initScaleNets_occlusion1) +
                  count_network_parameters(model.initScaleNets_occlusion2)))
    if hasattr(model, 'initScaleNets_filter'):
        print("Num. of initScaleNets_filter model parameters is :" +
              str(count_network_parameters(model.initScaleNets_filter) +
                  count_network_parameters(model.initScaleNets_filter1) +
                  count_network_parameters(model.initScaleNets_filter2)))
    if hasattr(model, 'ctxNet'):
        print("Num. of ctxNet model parameters is :" +
              str(count_network_parameters(model.ctxNet)))
    if hasattr(model, 'depthNet'):
        print("Num. of depthNet model parameters is :" +
              str(count_network_parameters(model.depthNet)))
    if hasattr(model, 'rectifyNet'):
        print("Num. of rectifyNet model parameters is :" +
              str(count_network_parameters(model.rectifyNet)))

    training_losses = AverageMeter()
    auxiliary_data = []
    saved_total_loss = 10e10
    saved_total_PSNR = -1
    # ikk: the first param group with a positive lr, used for logging only.
    ikk = 0
    for kk in optimizer.param_groups:
        if kk['lr'] > 0:
            ikk = kk
            break

    for t in range(args.numEpoch):
        print("The id of this in-training network is " + str(args.uid))
        print(args)
        # Turn into training mode
        model = model.train()

        for i, (X0_half, X1_half, y_half) in enumerate(train_loader):
            # Cap each epoch at len(train_set)/batch_size iterations.
            if i >= int(len(train_set) / args.batch_size):
                break
            X0_half = X0_half.cuda() if args.use_cuda else X0_half
            X1_half = X1_half.cuda() if args.use_cuda else X1_half
            y_half = y_half.cuda() if args.use_cuda else y_half
            X0 = Variable(X0_half, requires_grad=False)
            X1 = Variable(X1_half, requires_grad=False)
            y = Variable(y_half, requires_grad=False)

            # Model consumes the (frame0, target, frame1) triple stacked on dim 0.
            diffs, offsets, filters, occlusions = model(torch.stack((X0, y, X1), dim=0))
            pixel_loss, offset_loss, sym_loss = part_loss(
                diffs, offsets, occlusions, [X0, X1], epsilon=args.epsilon)
            # Weighted sum; zero-weight terms are skipped entirely.
            total_loss = sum(x * y if x > 0 else 0
                             for x, y in zip(args.alpha, pixel_loss))
            training_losses.update(total_loss.item(), args.batch_size)

            # Log roughly 500 times per epoch.
            if i % max(1, int(int(len(train_set) / args.batch_size) / 500.0)) == 0:
                print("Ep [" + str(t) + "/" + str(i) + "]\tl.r.: " +
                      str(round(float(ikk['lr']), 7)) +
                      "\tPix: " + str([round(x.item(), 5) for x in pixel_loss]) +
                      "\tTV: " + str([round(x.item(), 4) for x in offset_loss]) +
                      "\tSym: " + str([round(x.item(), 4) for x in sym_loss]) +
                      "\tTotal: " + str([round(x.item(), 5) for x in [total_loss]]) +
                      "\tAvg. Loss: " + str([round(training_losses.avg, 5)]))

            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()

        if t == 1:
            # delete the pre validation weights for cleaner workspace
            if os.path.exists(args.save_path + "/epoch" + str(0) + ".pth"):
                os.remove(args.save_path + "/epoch" + str(0) + ".pth")
        if os.path.exists(args.save_path + "/epoch" + str(t - 1) + ".pth"):
            os.remove(args.save_path + "/epoch" + str(t - 1) + ".pth")
        torch.save(model.state_dict(), args.save_path + "/epoch" + str(t) + ".pth")

        # print("\t\t**************Start Validation*****************")
        val_total_losses = AverageMeter()
        val_total_pixel_loss = AverageMeter()
        val_total_PSNR_loss = AverageMeter()
        val_total_tv_loss = AverageMeter()
        val_total_pws_loss = AverageMeter()
        val_total_sym_loss = AverageMeter()

        for i, (X0, X1, y) in enumerate(val_loader):
            if i >= int(len(test_set) / args.batch_size):
                break
            with torch.no_grad():
                X0 = X0.cuda() if args.use_cuda else X0
                X1 = X1.cuda() if args.use_cuda else X1
                y = y.cuda() if args.use_cuda else y

                diffs, offsets, filters, occlusions = model(torch.stack((X0, y, X1), dim=0))
                pixel_loss, offset_loss, sym_loss = part_loss(
                    diffs, offsets, occlusions, [X0, X1], epsilon=args.epsilon)
                val_total_loss = sum(x * y for x, y in zip(args.alpha, pixel_loss))

                per_sample_pix_error = torch.mean(torch.mean(torch.mean(
                    diffs[args.save_which] ** 2, dim=1), dim=1), dim=1)
                per_sample_pix_error = per_sample_pix_error.data  # extract tensor
                # PSNR from MSE: 20*log10(1/sqrt(mse)).
                psnr_loss = torch.mean(
                    20 * torch.log(1.0 / torch.sqrt(per_sample_pix_error))
                ) / torch.log(torch.Tensor([10]))

                val_total_losses.update(val_total_loss.item(), args.batch_size)
                val_total_pixel_loss.update(pixel_loss[args.save_which].item(), args.batch_size)
                val_total_tv_loss.update(offset_loss[0].item(), args.batch_size)
                val_total_sym_loss.update(sym_loss[0].item(), args.batch_size)
                val_total_PSNR_loss.update(psnr_loss[0], args.batch_size)
                print(".", end='', flush=True)

        print("\nEpoch " + str(int(t)) +
              "\tlearning rate: " + str(float(ikk['lr'])) +
              "\tAvg Training Loss: " + str(round(training_losses.avg, 5)) +
              "\tValidate Loss: " + str([round(float(val_total_losses.avg), 5)]) +
              "\tValidate PSNR: " + str([round(float(val_total_PSNR_loss.avg), 5)]) +
              "\tPixel Loss: " + str([round(float(val_total_pixel_loss.avg), 5)]) +
              "\tTV Loss: " + str([round(float(val_total_tv_loss.avg), 4)]) +
              "\tPWS Loss: " + str([round(float(val_total_pws_loss.avg), 4)]) +
              "\tSym Loss: " + str([round(float(val_total_sym_loss.avg), 4)]))

        auxiliary_data.append([t, float(ikk['lr']), training_losses.avg,
                               val_total_losses.avg, val_total_pixel_loss.avg,
                               val_total_tv_loss.avg, val_total_pws_loss.avg,
                               val_total_sym_loss.avg])
        numpy.savetxt(args.log, numpy.array(auxiliary_data), fmt='%.8f', delimiter=',')
        training_losses.reset()

        print("\t\tFinished an epoch, Check and Save the model weights")
        # we check the validation loss instead of training loss. OK~
        if saved_total_loss >= val_total_losses.avg:
            saved_total_loss = val_total_losses.avg
            torch.save(model.state_dict(), args.save_path + "/best" + ".pth")
            print("\t\tBest Weights updated for decreased validation loss\n")
            # Mirror the best weights outside the repo checkout (colab workflow).
            if os.path.exists("/content/model_weights"):
                shutil.rmtree("/content/model_weights")
            shutil.copytree("/content/DAIN/model_weights", "/content/model_weights")
        else:
            print("\t\tWeights Not updated for undecreased validation loss\n")

        # schedule the learning rate on the validation loss plateau
        scheduler.step(val_total_losses.avg)

    print("*********Finish Training********")
### /SAVED MODEL if type(args.datasetName) == list: train_sets, test_sets = [],[] for ii, jj in zip(args.datasetName, args.datasetPath): tr_s, te_s = datasets.__dict__[ii](jj, split = args.dataset_split,single = args.single_output, task = args.task) train_sets.append(tr_s) test_sets.append(te_s) train_set = torch.utils.data.ConcatDataset(train_sets) test_set = torch.utils.data.ConcatDataset(test_sets) else: train_set, test_set = datasets.__dict__[args.datasetName](args.datasetPath) train_loader = torch.utils.data.DataLoader( train_set, batch_size = args.batch_size, sampler=balancedsampler.RandomBalancedSampler(train_set, int(len(train_set) / args.batch_size )), num_workers= args.workers, pin_memory=True if use_cuda else False) val_loader = torch.utils.data.DataLoader(test_set, batch_size=args.batch_size, num_workers=args.workers, pin_memory=True if use_cuda else False) print('{} samples found, {} train samples and {} test samples '.format(len(test_set)+len(train_set), len(train_set), len(test_set))) # if not args.lr == 0: print("train the interpolation net") optimizer = torch.optim.Adamax([ {'params': model.initScaleNets_filter.parameters(), 'lr': args.filter_lr_coe * args.lr}, {'params': model.initScaleNets_filter1.parameters(), 'lr': args.filter_lr_coe * args.lr}, {'params': model.initScaleNets_filter2.parameters(), 'lr': args.filter_lr_coe * args.lr},
def main():
    """Entry point for optical-flow network training.

    Parses CLI args, builds the data pipeline and model, then either runs a
    single validation pass (--evaluate) or trains for args.epochs epochs,
    checkpointing the best EPE. Mutates the module globals ``args``,
    ``best_EPE`` and ``save_path``.
    """
    global args, best_EPE, save_path
    args = parser.parse_args()
    # Encode the main hyper-parameters into the checkpoint directory name.
    save_path = '{},{},{}epochs{},b{},lr{}'.format(
        args.arch,
        args.solver,
        args.epochs,
        ',epochSize' + str(args.epoch_size) if args.epoch_size > 0 else '',
        args.batch_size,
        args.lr)
    if not args.no_date:
        timestamp = datetime.datetime.now().strftime("%a-%b-%d-%H:%M")
        save_path = os.path.join(timestamp, save_path)
    save_path = os.path.join(args.dataset, save_path)
    print('=> will save everything to {}'.format(save_path))
    if not os.path.exists(save_path):
        os.makedirs(save_path)

    # Data loading code
    # ImageNet mean/std applied after scaling raw pixels into [0, 1].
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    input_transform = transforms.Compose([
        flow_transforms.ArrayToTensor(),
        transforms.Normalize(mean=[0, 0, 0], std=[255, 255, 255]),
        normalize
    ])
    # Flow targets are divided by args.div_flow for numerically easier regression.
    target_transform = transforms.Compose([
        flow_transforms.ArrayToTensor(),
        transforms.Normalize(mean=[0, 0], std=[args.div_flow, args.div_flow])
    ])

    # KITTI has sparse ground truth, so geometric augmentations are skipped.
    if 'KITTI' in args.dataset:
        co_transform = flow_transforms.Compose([
            flow_transforms.RandomCrop((320, 448)),
            #random flips are not supported yet for tensor conversion, but will be
            #flow_transforms.RandomVerticalFlip(),
            #flow_transforms.RandomHorizontalFlip()
        ])
    else:
        co_transform = flow_transforms.Compose([
            flow_transforms.RandomTranslate(10),
            flow_transforms.RandomRotate(10, 5),
            flow_transforms.RandomCrop((320, 448)),
            #random flips are not supported yet for tensor conversion, but will be
            #flow_transforms.RandomVerticalFlip(),
            #flow_transforms.RandomHorizontalFlip()
        ])

    print("=> fetching img pairs in '{}'".format(args.data))
    train_set, test_set = datasets.__dict__[args.dataset](
        args.data,
        transform=input_transform,
        target_transform=target_transform,
        co_transform=co_transform,
        split=args.split)
    print('{} samples found, {} train samples and {} test samples '.format(
        len(test_set) + len(train_set), len(train_set), len(test_set)))
    train_loader = torch.utils.data.DataLoader(
        train_set, batch_size=args.batch_size,
        sampler=balancedsampler.RandomBalancedSampler(train_set, args.epoch_size),
        num_workers=args.workers, pin_memory=True)
    val_loader = torch.utils.data.DataLoader(
        test_set, batch_size=args.batch_size,
        num_workers=args.workers, pin_memory=True)

    # create model
    if args.pretrained:
        print("=> using pre-trained model '{}'".format(args.arch))
    else:
        print("=> creating model '{}'".format(args.arch))

    model = models.__dict__[args.arch](args.pretrained).cuda()
    model = torch.nn.DataParallel(model).cuda()
    criterion = multiscaleloss(sparse='KITTI' in args.dataset, loss=args.loss).cuda()
    # Single-scale EPE at 1/4 resolution, for reporting.
    # NOTE(review): weights=(1) is the int 1, not a 1-tuple; presumably
    # multiscaleloss accepts a scalar — confirm, else this should be (1,).
    high_res_EPE = multiscaleloss(scales=1, downscale=4, weights=(1), loss='L1',
                                  sparse='KITTI' in args.dataset).cuda()
    cudnn.benchmark = True

    assert (args.solver in ['adam', 'sgd'])
    print('=> setting {} solver'.format(args.solver))
    if args.solver == 'adam':
        optimizer = torch.optim.Adam(model.parameters(), args.lr,
                                     betas=(args.momentum, args.beta),
                                     weight_decay=args.weight_decay)
    elif args.solver == 'sgd':
        optimizer = torch.optim.SGD(model.parameters(), args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)

    if args.evaluate:
        best_EPE = validate(val_loader, model, criterion, high_res_EPE)
        return

    # Fresh CSV logs (truncate any previous run in this save_path).
    with open(os.path.join(save_path, args.log_summary), 'w') as csvfile:
        writer = csv.writer(csvfile, delimiter='\t')
        writer.writerow(['train_loss', 'train_EPE', 'EPE'])

    with open(os.path.join(save_path, args.log_full), 'w') as csvfile:
        writer = csv.writer(csvfile, delimiter='\t')
        writer.writerow(['train_loss', 'train_EPE'])

    for epoch in range(args.start_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch)

        # train for one epoch
        train_loss, train_EPE = train(train_loader, model, criterion,
                                      high_res_EPE, optimizer, epoch)
        # evaluate o validation set
        EPE = validate(val_loader, model, criterion, high_res_EPE)
        # NOTE(review): seeding best_EPE with the first EPE makes is_best False
        # on the first epoch (EPE < best_EPE fails on equality) — presumably
        # intentional, mirroring the upstream reference implementation.
        if best_EPE < 0:
            best_EPE = EPE

        # remember best prec@1 and save checkpoint
        is_best = EPE < best_EPE
        best_EPE = min(EPE, best_EPE)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.module.state_dict(),
                'best_EPE': best_EPE,
            }, is_best)

        with open(os.path.join(save_path, args.log_summary), 'a') as csvfile:
            writer = csv.writer(csvfile, delimiter='\t')
            writer.writerow([train_loss, train_EPE, EPE])
def train():
    """Meta-train the DAIN rectification sub-network (Reptile or FO-MAML).

    Only model.rectifyNet is trainable; all other sub-networks are frozen.
    Each training step adapts rectifyNet on auxiliary frame triplets sampled
    from a septuplet (the inner loop), then applies the meta update selected
    by the module-level META_ALGORITHM constant. Validation performs the same
    inner-loop adaptation before measuring target performance, then restores
    the pre-adaptation weights. Configuration comes from the module-level
    ``args`` namespace.
    """
    torch.manual_seed(args.seed)

    model = networks.__dict__[args.netName](channel=args.channels,
                                            filter_size=args.filter_size,
                                            timestep=args.time_step,
                                            training=True)
    # A pristine copy of the network (same init / same pretrained weights),
    # kept for comparison during meta training.
    original_model = networks.__dict__[args.netName](
        channel=args.channels,
        filter_size=args.filter_size,
        timestep=args.time_step,
        training=True)
    if args.use_cuda:
        print("Turn the model into CUDA")
        model = model.cuda()
        original_model = original_model.cuda()

    if not args.SAVED_MODEL == None:
        args.SAVED_MODEL = './model_weights/' + args.SAVED_MODEL + "/best" + ".pth"
        print("Fine tuning on " + args.SAVED_MODEL)
        if not args.use_cuda:
            # Remap GPU-saved tensors onto the CPU.
            pretrained_dict = torch.load(
                args.SAVED_MODEL, map_location=lambda storage, loc: storage)
            # model.load_state_dict(torch.load(args.SAVED_MODEL, map_location=lambda storage, loc: storage))
        else:
            pretrained_dict = torch.load(args.SAVED_MODEL)
            # model.load_state_dict(torch.load(args.SAVED_MODEL))
        #print([k for k,v in pretrained_dict.items()])

        model_dict = model.state_dict()
        # 1. filter out unnecessary keys
        pretrained_dict = {
            k: v
            for k, v in pretrained_dict.items() if k in model_dict
        }
        # 2. overwrite entries in the existing state dict
        model_dict.update(pretrained_dict)
        # 3. load the new state dict
        model.load_state_dict(model_dict)
        # For comparison in meta training
        original_model.load_state_dict(model_dict)
        pretrained_dict = None  # release the checkpoint memory

    # args.datasetName may be one name or a list of names paired with paths.
    if type(args.datasetName) == list:
        train_sets, test_sets = [], []
        for ii, jj in zip(args.datasetName, args.datasetPath):
            tr_s, te_s = datasets.__dict__[ii](jj,
                                               split=args.dataset_split,
                                               single=args.single_output,
                                               task=args.task)
            train_sets.append(tr_s)
            test_sets.append(te_s)
        train_set = torch.utils.data.ConcatDataset(train_sets)
        test_set = torch.utils.data.ConcatDataset(test_sets)
    else:
        train_set, test_set = datasets.__dict__[args.datasetName](
            args.datasetPath)
    train_loader = torch.utils.data.DataLoader(
        train_set,
        batch_size=args.batch_size,
        sampler=balancedsampler.RandomBalancedSampler(
            train_set, int(len(train_set) / args.batch_size)),
        num_workers=args.workers,
        pin_memory=True if args.use_cuda else False)
    val_loader = torch.utils.data.DataLoader(
        test_set,
        batch_size=args.batch_size,
        num_workers=args.workers,
        pin_memory=True if args.use_cuda else False)
    print('{} samples found, {} train samples and {} test samples '.format(
        len(test_set) + len(train_set), len(train_set), len(test_set)))

    # if not args.lr == 0:
    print("train the interpolation net")
    # Disabled per-sub-network optimizer kept for reference:
    '''optimizer = torch.optim.Adamax([
        #{'params': model.initScaleNets_filter.parameters(), 'lr': args.filter_lr_coe * args.lr},
        #{'params': model.initScaleNets_filter1.parameters(), 'lr': args.filter_lr_coe * args.lr},
        #{'params': model.initScaleNets_filter2.parameters(), 'lr': args.filter_lr_coe * args.lr},
        #{'params': model.ctxNet.parameters(), 'lr': args.ctx_lr_coe * args.lr},
        #{'params': model.flownets.parameters(), 'lr': args.flow_lr_coe * args.lr},
        #{'params': model.depthNet.parameters(), 'lr': args.depth_lr_coe * args.lr},
        {'params': model.rectifyNet.parameters(), 'lr': args.rectify_lr}
        ],
        #lr=args.lr, momentum=0, weight_decay=args.weight_decay)
        lr=args.lr, betas=(0.9, 0.999), eps=1e-8, weight_decay=args.weight_decay)'''
    # Outer-loop (meta) optimizer: only rectifyNet is updated.
    optimizer = torch.optim.Adamax(model.rectifyNet.parameters(),
                                   lr=args.outer_lr,
                                   betas=(0.9, 0.999),
                                   eps=1e-8,
                                   weight_decay=args.weight_decay)

    # Fix weights for early layers
    for param in model.initScaleNets_filter.parameters():
        param.requires_grad = False
    for param in model.initScaleNets_filter1.parameters():
        param.requires_grad = False
    for param in model.initScaleNets_filter2.parameters():
        param.requires_grad = False
    for param in model.ctxNet.parameters():
        param.requires_grad = False
    for param in model.flownets.parameters():
        param.requires_grad = False
    for param in model.depthNet.parameters():
        param.requires_grad = False

    scheduler = ReduceLROnPlateau(optimizer, 'min', factor=args.factor,
                                  patience=args.patience, verbose=True)

    print("*********Start Training********")
    print("LR is: " + str(float(optimizer.param_groups[0]['lr'])))
    print("EPOCH is: " + str(int(len(train_set) / args.batch_size)))
    print("Num of EPOCH is: " + str(args.numEpoch))

    def count_network_parameters(model):
        # Count only trainable (requires_grad) parameters.
        parameters = filter(lambda p: p.requires_grad, model.parameters())
        N = sum([numpy.prod(p.size()) for p in parameters])
        return N

    print("Num. of model parameters is :" + str(count_network_parameters(model)))
    if hasattr(model, 'flownets'):
        print("Num. of flow model parameters is :" +
              str(count_network_parameters(model.flownets)))
    if hasattr(model, 'initScaleNets_occlusion'):
        print("Num. of initScaleNets_occlusion model parameters is :" + str(
            count_network_parameters(model.initScaleNets_occlusion) +
            count_network_parameters(model.initScaleNets_occlusion1) +
            count_network_parameters(model.initScaleNets_occlusion2)))
    if hasattr(model, 'initScaleNets_filter'):
        print("Num. of initScaleNets_filter model parameters is :" + str(
            count_network_parameters(model.initScaleNets_filter) +
            count_network_parameters(model.initScaleNets_filter1) +
            count_network_parameters(model.initScaleNets_filter2)))
    if hasattr(model, 'ctxNet'):
        print("Num. of ctxNet model parameters is :" +
              str(count_network_parameters(model.ctxNet)))
    if hasattr(model, 'depthNet'):
        print("Num. of depthNet model parameters is :" +
              str(count_network_parameters(model.depthNet)))
    if hasattr(model, 'rectifyNet'):
        print("Num. of rectifyNet model parameters is :" +
              str(count_network_parameters(model.rectifyNet)))

    training_losses = AverageMeter()
    #original_training_losses = AverageMeter()
    batch_time = AverageMeter()
    auxiliary_data = []
    saved_total_loss = 10e10
    saved_total_PSNR = -1
    # ikk: first param group with positive lr, used for logging only.
    ikk = 0
    for kk in optimizer.param_groups:
        if kk['lr'] > 0:
            ikk = kk
            break

    for t in range(args.numEpoch):
        print("The id of this in-training network is " + str(args.uid))
        print(args)
        print("Learning rate for this epoch: %s" % str(round(float(ikk['lr']), 7)))
        #Turn into training mode
        model = model.train()

        #for i, (X0_half,X1_half, y_half) in enumerate(train_loader):
        _t = time.time()
        for i, images in enumerate(train_loader):
            # TRAIN_ITER_CUT is a module-level cap on iterations per epoch.
            if i >= min(TRAIN_ITER_CUT, int(len(train_set) / args.batch_size)):
                #(0 if t == 0 else EPOCH):#
                break
            if args.use_cuda:
                images = [im.cuda() for im in images]
            images = [Variable(im, requires_grad=False) for im in images]

            # For VimeoTriplet
            #X0, y, X1 = images[0], images[1], images[2]
            # For VimeoSepTuplet: query task is the middle triplet (frames 2,3,4).
            X0, y, X1 = images[2], images[3], images[4]

            outerstepsize = args.outer_lr
            k = args.num_inner_update  # inner loop update iteration
            # Fresh inner-loop optimizer per batch (adaptation only).
            inner_optimizer = torch.optim.Adamax(model.rectifyNet.parameters(),
                                                 lr=args.inner_lr,
                                                 betas=(0.9, 0.999),
                                                 eps=1e-8,
                                                 weight_decay=args.weight_decay)

            if META_ALGORITHM == "Reptile":
                # Reptile setting
                weights_before = copy.deepcopy(model.state_dict())
                for _k in range(k):
                    # Support tasks: several frame triplets drawn from the septuplet.
                    indices = [[0, 2, 4], [2, 4, 6], [2, 3, 4], [0, 1, 2], [4, 5, 6]]
                    total_loss = 0
                    for ind in indices:
                        meta_X0, meta_y, meta_X1 = images[ind[0]].clone(
                        ), images[ind[1]].clone(), images[ind[2]].clone()
                        diffs, offsets, filters, occlusions = model(
                            torch.stack((meta_X0, meta_y, meta_X1), dim=0))
                        pixel_loss, offset_loss, sym_loss = part_loss(
                            diffs, offsets, occlusions, [meta_X0, meta_X1],
                            epsilon=args.epsilon)
                        _total_loss = sum(
                            x * y if x > 0 else 0
                            for x, y in zip(args.alpha, pixel_loss))
                        total_loss = total_loss + _total_loss
                    # total *= 2 / len(indices)
                    inner_optimizer.zero_grad()
                    total_loss.backward()
                    inner_optimizer.step()
                # Reptile update: move the saved weights toward the adapted
                # weights by outerstepsize (linear interpolation in weight space).
                weights_after = model.state_dict()
                model.load_state_dict({
                    name: weights_before[name] +
                    (weights_after[name] - weights_before[name]) * outerstepsize
                    for name in weights_before
                })
                # Report loss on the query triplet after the meta update.
                with torch.no_grad():
                    diffs, offsets, filters, occlusions = model(
                        torch.stack((X0, y, X1), dim=0))
                    pixel_loss, offset_loss, sym_loss = part_loss(
                        diffs, offsets, occlusions, [X0, X1], epsilon=args.epsilon)
                    total_loss = sum(x * y if x > 0 else 0
                                     for x, y in zip(args.alpha, pixel_loss))
                    training_losses.update(total_loss.item(), args.batch_size)

            elif META_ALGORITHM == "MAML":
                #weights_before = copy.deepcopy(model.state_dict())
                base_model = copy.deepcopy(model)
                #fast_weights = list(filter(lambda p: p.requires_grad, model.parameters()))
                for _k in range(k):
                    indices = [[0, 2, 4], [2, 4, 6]]
                    support_loss = 0
                    for ind in indices:
                        meta_X0, meta_y, meta_X1 = images[ind[0]].clone(
                        ), images[ind[1]].clone(), images[ind[2]].clone()
                        diffs, offsets, filters, occlusions = model(
                            torch.stack((meta_X0, meta_y, meta_X1), dim=0))
                        pixel_loss, offset_loss, sym_loss = part_loss(
                            diffs, offsets, occlusions, [meta_X0, meta_X1],
                            epsilon=args.epsilon)
                        _total_loss = sum(
                            x * y if x > 0 else 0
                            for x, y in zip(args.alpha, pixel_loss))
                        support_loss = support_loss + _total_loss
                    #grad = torch.autograd.grad(loss, fast_weights)
                    #fast_weights = list(map(lambda p: p[1] - args.lr * p[0], zip(grad, fast_weights)))
                    inner_optimizer.zero_grad()
                    support_loss.backward()  # create_graph=True
                    inner_optimizer.step()

                # Forward on query set
                diffs, offsets, filters, occlusions = model(
                    torch.stack((X0, y, X1), dim=0))
                pixel_loss, offset_loss, sym_loss = part_loss(
                    diffs, offsets, occlusions, [X0, X1], epsilon=args.epsilon)
                total_loss = sum(x * y if x > 0 else 0
                                 for x, y in zip(args.alpha, pixel_loss))
                training_losses.update(total_loss.item(), args.batch_size)

                # copy parameters to comnnect the computational graph
                # (restore the pre-adaptation rectifyNet data so the meta
                # optimizer steps from the base weights — first-order MAML).
                for param, base_param in zip(
                        model.rectifyNet.parameters(),
                        base_model.rectifyNet.parameters()):
                    param.data = base_param.data

                filtered_params = filter(lambda p: p.requires_grad,
                                         model.parameters())
                optimizer.zero_grad()
                grads = torch.autograd.grad(total_loss, list(
                    filtered_params))  # backward on weights_before: FO-MAML
                # Manually assign the query gradients to the trainable params,
                # then take a meta step with the outer optimizer.
                j = 0
                #print('[before update]')
                #print(list(model.parameters())[45][-1])
                for _i, param in enumerate(model.parameters()):
                    if param.requires_grad:
                        #param = param - outerstepsize * grads[j]
                        param.grad = grads[j]
                        j += 1
                optimizer.step()
                #print('[after optim.step]')
                #print(list(model.parameters())[45][-1])

            batch_time.update(time.time() - _t)
            _t = time.time()

            if i % 100 == 0:
                #max(1, int(int(len(train_set) / args.batch_size )/500.0)) == 0:
                print(
                    "Ep[%s][%05d/%d] Time: %.2f Pix: %s TV: %s Sym: %s Total: %s Avg. Loss: %s"
                    % (str(t), i, int(len(train_set)) // args.batch_size,
                       batch_time.avg,
                       str([round(x.item(), 5) for x in pixel_loss]),
                       str([round(x.item(), 4) for x in offset_loss]),
                       str([round(x.item(), 4) for x in sym_loss]),
                       str([round(x.item(), 5) for x in [total_loss]]),
                       str([round(training_losses.avg, 5)])))
                batch_time.reset()

        if t == 1:
            # delete the pre validation weights for cleaner workspace
            if os.path.exists(args.save_path + "/epoch" + str(0) + ".pth"):
                os.remove(args.save_path + "/epoch" + str(0) + ".pth")
        if os.path.exists(args.save_path + "/epoch" + str(t - 1) + ".pth"):
            os.remove(args.save_path + "/epoch" + str(t - 1) + ".pth")
        torch.save(model.state_dict(),
                   args.save_path + "/epoch" + str(t) + ".pth")

        # print("\t\t**************Start Validation*****************")
        #Turn into evaluation mode
        val_total_losses = AverageMeter()
        val_total_pixel_loss = AverageMeter()
        val_total_PSNR_loss = AverageMeter()
        val_total_tv_loss = AverageMeter()
        val_total_pws_loss = AverageMeter()
        val_total_sym_loss = AverageMeter()

        for i, (images, imgpaths) in enumerate(tqdm(val_loader)):
            #if i < 50:  #i < 11 or (i > 14 and i < 50):
            #    continue
            if i >= min(VAL_ITER_CUT, int(len(test_set) / args.batch_size)):
                break
            if args.use_cuda:
                images = [im.cuda() for im in images]
            #X0, y, X1 = images[0], images[1], images[2]
            #X0, y, X1 = images[2], images[3], images[4]

            # define optimizer to update the inner loop
            inner_optimizer = torch.optim.Adamax(model.rectifyNet.parameters(),
                                                 lr=args.inner_lr,
                                                 betas=(0.9, 0.999),
                                                 eps=1e-8,
                                                 weight_decay=args.weight_decay)
            # Reptile testing - save base model weights
            weights_base = copy.deepcopy(model.state_dict())
            k = args.num_inner_update  # 2
            # Adaptation requires training mode even during validation.
            model.train()
            for _k in range(k):
                indices = [[0, 2, 4], [2, 4, 6]]
                ind = indices[_k % 2]
                meta_X0, meta_y, meta_X1 = crop(images[ind[0]]), crop(
                    images[ind[1]]), crop(images[ind[2]])
                # NOTE(review): here the model returns 5 values (extra output
                # image) versus 4 in the training loop — presumably the model's
                # return signature differs by mode/configuration; confirm.
                diffs, offsets, filters, occlusions, _ = model(
                    torch.stack((meta_X0, meta_y, meta_X1), dim=0))
                pixel_loss, offset_loss, sym_loss = part_loss(
                    diffs, offsets, occlusions, [meta_X0, meta_X1],
                    epsilon=args.epsilon)
                total_loss = sum(x * y if x > 0 else 0
                                 for x, y in zip(args.alpha, pixel_loss))
                inner_optimizer.zero_grad()
                total_loss.backward()
                inner_optimizer.step()

            # Actual target validation performance
            with torch.no_grad():
                if args.datasetName == 'Vimeo_90K_sep':
                    X0, y, X1 = images[2], images[3], images[4]
                    #diffs, offsets,filters,occlusions = model(torch.stack((X0,y,X1),dim = 0))
                    diffs, offsets, filters, occlusions, output = model(
                        torch.stack((X0, y, X1), dim=0))
                    pixel_loss, offset_loss, sym_loss = part_loss(
                        diffs, offsets, occlusions, [X0, X1],
                        epsilon=args.epsilon)
                    val_total_loss = sum(
                        x * y for x, y in zip(args.alpha, pixel_loss))
                    per_sample_pix_error = torch.mean(torch.mean(torch.mean(
                        diffs[args.save_which]**2, dim=1), dim=1), dim=1)
                    per_sample_pix_error = per_sample_pix_error.data  # extract tensor
                    # PSNR from MSE: 20*log10(1/sqrt(mse)).
                    psnr_loss = torch.mean(20 * torch.log(
                        1.0 / torch.sqrt(per_sample_pix_error))) / torch.log(
                            torch.Tensor([10]))
                    val_total_losses.update(val_total_loss.item(),
                                            args.batch_size)
                    val_total_pixel_loss.update(
                        pixel_loss[args.save_which].item(), args.batch_size)
                    val_total_tv_loss.update(offset_loss[0].item(),
                                             args.batch_size)
                    val_total_sym_loss.update(sym_loss[0].item(),
                                              args.batch_size)
                    val_total_PSNR_loss.update(psnr_loss[0], args.batch_size)
                else:
                    # HD_dataset testing: slide over consecutive frame pairs,
                    # cropping to at most 720x1280.
                    for j in range(len(images) // 2):
                        mH, mW = 720, 1280
                        X0, y, X1 = crop(images[2 * j], maxH=mH,
                                         maxW=mW), crop(images[2 * j + 1],
                                                        maxH=mH,
                                                        maxW=mW), crop(
                                                            images[2 * j + 2],
                                                            maxH=mH,
                                                            maxW=mW)
                        diffs, offsets, filters, occlusions, output = model(
                            torch.stack((X0, y, X1), dim=0))
                        pixel_loss, offset_loss, sym_loss = part_loss(
                            diffs, offsets, occlusions, [X0, X1],
                            epsilon=args.epsilon)
                        val_total_loss = sum(
                            x * y for x, y in zip(args.alpha, pixel_loss))
                        per_sample_pix_error = torch.mean(torch.mean(
                            torch.mean(diffs[args.save_which]**2, dim=1),
                            dim=1), dim=1)
                        per_sample_pix_error = per_sample_pix_error.data  # extract tensor
                        psnr_loss = torch.mean(
                            20 *
                            torch.log(1.0 / torch.sqrt(per_sample_pix_error))
                        ) / torch.log(torch.Tensor([10]))
                        val_total_losses.update(val_total_loss.item(),
                                                args.batch_size)
                        val_total_pixel_loss.update(
                            pixel_loss[args.save_which].item(),
                            args.batch_size)
                        val_total_tv_loss.update(offset_loss[0].item(),
                                                 args.batch_size)
                        val_total_sym_loss.update(sym_loss[0].item(),
                                                  args.batch_size)
                        val_total_PSNR_loss.update(psnr_loss[0],
                                                   args.batch_size)

            # Reset model to its base weights
            model.load_state_dict(weights_base)
            #del weights_base, inner_optimizer, meta_X0, meta_y, meta_X1, X0, y, X1, pixel_loss, offset_loss, sym_loss, total_loss, val_total_loss, diffs, offsets, filters, occlusions

            # Optional dump of predicted middle frames for inspection.
            VIZ = False
            exp_name = 'meta_test'
            if VIZ:
                for b in range(images[0].size(0)):
                    imgpath = imgpaths[0][b]
                    savepath = os.path.join('checkpoint', exp_name,
                                            'vimeoSeptuplet',
                                            imgpath.split('/')[-3],
                                            imgpath.split('/')[-2])
                    if not os.path.exists(savepath):
                        os.makedirs(savepath)
                    # CHW float [0,1] -> HWC uint8 BGR for cv2.
                    img_pred = (output[b].data.permute(1, 2, 0).clamp_(
                        0, 1).cpu().numpy()[..., ::-1] * 255).astype(
                            numpy.uint8)
                    cv2.imwrite(os.path.join(savepath, 'im2_pred.png'),
                                img_pred)

        '''
        # Original validation (not meta)
        with torch.no_grad():
            if args.use_cuda:
                images = [im.cuda() for im in images]
            #X0, y, X1 = images[0], images[1], images[2]
            X0, y, X1 = images[2], images[3], images[4]
            #diffs, offsets,filters,occlusions = model(torch.stack((X0,y,X1),dim = 0))
            pixel_loss, offset_loss,sym_loss = part_loss(diffs, offsets, occlusions, [X0,X1],epsilon=args.epsilon)
            val_total_loss = sum(x * y for x, y in zip(args.alpha, pixel_loss))
            per_sample_pix_error = torch.mean(torch.mean(torch.mean(diffs[args.save_which] ** 2, dim=1),dim=1),dim=1)
            per_sample_pix_error = per_sample_pix_error.data # extract tensor
            psnr_loss = torch.mean(20 * torch.log(1.0/torch.sqrt(per_sample_pix_error)))/torch.log(torch.Tensor([10]))
            #
            val_total_losses.update(val_total_loss.item(),args.batch_size)
            val_total_pixel_loss.update(pixel_loss[args.save_which].item(), args.batch_size)
            val_total_tv_loss.update(offset_loss[0].item(), args.batch_size)
            val_total_sym_loss.update(sym_loss[0].item(), args.batch_size)
            val_total_PSNR_loss.update(psnr_loss[0],args.batch_size)
            print(".",end='',flush=True)
        '''

        print("\nEpoch " + str(int(t)) +
              "\tlearning rate: " + str(float(ikk['lr'])) +
              "\tAvg Training Loss: " + str(round(training_losses.avg, 5)) +
              "\tValidate Loss: " + str([round(float(val_total_losses.avg), 5)]) +
              "\tValidate PSNR: " + str([round(float(val_total_PSNR_loss.avg), 5)]) +
              "\tPixel Loss: " + str([round(float(val_total_pixel_loss.avg), 5)]) +
              "\tTV Loss: " + str([round(float(val_total_tv_loss.avg), 4)]) +
              "\tPWS Loss: " + str([round(float(val_total_pws_loss.avg), 4)]) +
              "\tSym Loss: " + str([round(float(val_total_sym_loss.avg), 4)]))

        auxiliary_data.append([
            t, float(ikk['lr']), training_losses.avg, val_total_losses.avg,
            val_total_pixel_loss.avg, val_total_tv_loss.avg,
            val_total_pws_loss.avg, val_total_sym_loss.avg
        ])
        numpy.savetxt(args.log,
                      numpy.array(auxiliary_data),
                      fmt='%.8f',
                      delimiter=',')
        training_losses.reset()
        #original_training_losses.reset()

        print("\t\tFinished an epoch, Check and Save the model weights")
        # we check the validation loss instead of training loss. OK~
        if saved_total_loss >= val_total_losses.avg:
            saved_total_loss = val_total_losses.avg
            torch.save(model.state_dict(), args.save_path + "/best" + ".pth")
            print("\t\tBest Weights updated for decreased validation loss\n")
        else:
            print("\t\tWeights Not updated for undecreased validation loss\n")

        #schdule the learning rate
        scheduler.step(val_total_losses.avg)

    print("*********Finish Training********")
def train():
    """Train a frame-interpolation network with visdom logging.

    NOTE(review): this module defines ``train`` more than once; this
    definition shadows any earlier one with the same name.

    Reads module-level globals: ``args`` (CLI config), ``networks``,
    ``datasets``, ``balancedsampler``, ``part_loss``, ``df_loss_func``,
    ``AverageMeter``, ``Visdom``, ``ReduceLROnPlateau``, ``Variable``,
    ``torch``, ``numpy``, ``os``.

    Side effects: creates/updates visdom windows, writes per-iteration log
    lines to ``<save_path>/all_log.txt``, saves per-epoch / best / bestPSNR
    checkpoints under ``args.save_path``, and appends loss rows to
    ``args.log`` via ``numpy.savetxt``.
    """
    # ============================================================== #
    #                          Init Visdom                           #
    # ============================================================== #
    viz_env = args.vis_env
    viz = Visdom(env=viz_env)
    # Seed both line-plot windows with a dummy point so that the later
    # update='append' calls have an existing window to extend.
    viz.line([[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]], [0.],
             win='train_respective_loss',
             env=viz_env,
             opts=dict(title='train_respective_loss',
                       legend=[
                           'pixel_loss_0', 'pixel_loss_1', 'offset_loss',
                           'occlusion_loss', 'sym_loss', 'total_loss',
                           'total_loss_avg'
                       ]))
    viz.line([[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]], [0.],
             win='val_respective_loss',
             env=viz_env,
             opts=dict(title='epoch_val_respective_loss',
                       legend=[
                           'training_losses', 'val_total_losses',
                           'val_total_PSNR_loss', 'val_total_pixel_loss',
                           'val_total_tv_loss', 'val_total_pws_loss',
                           'val_total_sym_loss'
                       ]))
    # viz.line([[0.0, 0.0]], [0.], win='validation_psnr', env=viz_env,
    #          opts=dict(title='val psnr', legend=['Resotred psnr', 'Blurry psnr']))
    # viz.line([[0.0, 0.0]], [0.], win='validation_ssim', env=viz_env,
    #          opts=dict(title='val ssim', legend=['Restored ssim', 'Blurry ssim']))

    torch.manual_seed(args.seed)

    # Instantiate the network class chosen by name from the networks module.
    model = networks.__dict__[args.netName](
        batch=args.batch_size,
        channel=args.channels,
        width=None,
        height=None,
        scale_num=1,
        scale_ratio=2,
        temporal=False,
        filter_size=args.filter_size,
        save_which=args.save_which,
        flowmethod=args.flowmethod,
        timestep=args.time_step,
        FlowProjection_threshhold=args.flowproj_threshhold,
        offset_scale=None,
        cuda_available=args.use_cuda,
        cuda_id=None,
        training=True)
    if args.use_cuda:
        print("Turn the model into CUDA")
        model = model.cuda()

    # Optionally warm-start from a saved "best" checkpoint; only keys that
    # exist in the freshly built model's state dict are copied over.
    if not args.SAVED_MODEL == None:
        args.SAVED_MODEL = '../model_weights/' + args.SAVED_MODEL + "/best" + ".pth"
        print("Fine tuning on " + args.SAVED_MODEL)
        if not args.use_cuda:
            # Remap GPU-saved tensors onto the CPU.
            pretrained_dict = torch.load(
                args.SAVED_MODEL, map_location=lambda storage, loc: storage)
            # model.load_state_dict(torch.load(args.SAVED_MODEL, map_location=lambda storage, loc: storage))
        else:
            pretrained_dict = torch.load(args.SAVED_MODEL)
            # model.load_state_dict(torch.load(args.SAVED_MODEL))
        #print([k for k,v in pretrained_dict.items()])
        model_dict = model.state_dict()
        # 1. filter out unnecessary keys
        pretrained_dict = {
            k: v
            for k, v in pretrained_dict.items() if k in model_dict
        }  # and not k[:10]== 'rectifyNet'}
        # 2. overwrite entries in the existing state dict
        model_dict.update(pretrained_dict)
        # 3. load the new state dict
        model.load_state_dict(model_dict)
        # Release the loaded dict so it can be garbage-collected.
        pretrained_dict = None
    # torch.save(model.depthNet.state_dict(), "8402_best_depth" + ".pth")

    # Build datasets: either a concatenation of several named datasets or a
    # single one, selected by name from the datasets module.
    if type(args.datasetName) == list:
        train_sets, test_sets = [], []
        for ii, jj in zip(args.datasetName, args.datasetPath):
            tr_s, te_s = datasets.__dict__[ii](
                jj,
                split=args.dataset_split,
                single=args.single_output,
                task=args.task,
                middle=args.time_step == 0.5,
                high_fps=args.high_fps)  # if time_step = 0.5, only use middle
            train_sets.append(tr_s)
            test_sets.append(te_s)
        train_set = torch.utils.data.ConcatDataset(train_sets)
        test_set = torch.utils.data.ConcatDataset(test_sets)
    else:
        train_set, test_set = datasets.__dict__[args.datasetName](
            args.datasetPath,
            split=args.dataset_split,
            single=args.single_output,
            task=args.task,
            middle=args.time_step == 0.5,
            high_fps=args.high_fps)
    train_loader = torch.utils.data.DataLoader(
        train_set,
        batch_size=args.batch_size,
        sampler=balancedsampler.RandomBalancedSampler(
            train_set, int(len(train_set) / args.batch_size)),
        # RandomBalancedSampler(train_set,args.epoch_size),
        num_workers=args.workers,
        pin_memory=True if args.use_cuda else False)
    val_loader = torch.utils.data.DataLoader(
        test_set,
        batch_size=args.batch_size,
        # sampler=balancedsampler.SequentialBalancedSampler(test_set,)
        num_workers=args.workers,
        pin_memory=True if args.use_cuda else False)
    print('{} samples found, {} train samples and {} test samples '.format(
        len(test_set) + len(train_set), len(train_set), len(test_set)))

    # to skip the fixed parameters of vgg model, we need to filter them out...
    # for param in model.parameters():
    #     print(type(param.data), param.size())
    # for idx, m in enumerate(model.named_modules()):
    #     print(idx, '->', m)
    # optimizer = torch.optim.Adamax(filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr, betas=(0.9, 0.999), eps=1e-8)
    # optimizer = torch.optim.Adamax(model.parameters(), lr=args.lr, betas=(0.9, 0.999), eps=1e-8,args.weight_decay=args.weight_decay)
    # Per-submodule learning rates: each sub-network gets its own coefficient
    # times the base lr; the rectify net uses an independent absolute lr.
    # NOTE(review): if args.lr == 0, `optimizer` is never assigned and the
    # ReduceLROnPlateau construction below would raise NameError — confirm
    # callers never pass lr == 0.
    if not args.lr == 0:
        print("train the interpolation net")
        if args.netName == 'DAIN':
            optimizer = torch.optim.Adamax(
                [{
                    'params': model.initScaleNets_filter.parameters(),
                    'lr': args.filter_lr_coe * args.lr
                }, {
                    'params': model.initScaleNets_filter1.parameters(),
                    'lr': args.filter_lr_coe * args.lr
                }, {
                    'params': model.initScaleNets_filter2.parameters(),
                    'lr': args.filter_lr_coe * args.lr
                }, {
                    'params': model.ctxNet.parameters(),
                    'lr': args.ctx_lr_coe * args.lr
                }, {
                    'params': model.flownets.parameters(),
                    'lr': args.flow_lr_coe * args.lr
                }, {
                    'params': model.depthNet.parameters(),
                    'lr': args.depth_lr_coe * args.lr
                }, {
                    'params': model.rectifyNet.parameters(),
                    'lr': args.rectify_lr
                }],
                lr=args.lr,
                betas=(0.9, 0.999),
                eps=1e-8,
                weight_decay=args.weight_decay)
        else:
            print("Only train the rectifyNet")
            optimizer = torch.optim.Adamax([{
                'params': model.rectifyNet.parameters(),
                'lr': args.rectify_lr
            }],
                                           lr=args.lr,
                                           betas=(0.9, 0.999),
                                           eps=1e-8,
                                           weight_decay=args.weight_decay)
    # Shrink learning rates by `factor` when validation loss plateaus.
    scheduler = ReduceLROnPlateau(optimizer,
                                  'min',
                                  factor=args.factor,
                                  patience=args.patience,
                                  verbose=True)
    print("*********Start Training********")
    print("LR is: " + str(float(optimizer.param_groups[0]['lr'])))
    print("EPOCH is: " + str(int(len(train_set) / args.batch_size)))
    print("Num of EPOCH is: " + str(args.numEpoch))

    def count_network_parameters(model):
        # Count only trainable (requires_grad) parameters.
        parameters = filter(lambda p: p.requires_grad, model.parameters())
        N = sum([numpy.prod(p.size()) for p in parameters])
        return N

    print("Num. of model parameters is :" +
          str(count_network_parameters(model)))
    if hasattr(model, 'flownets'):
        print("Num. of flow model parameters is :" +
              str(count_network_parameters(model.flownets)))
    if hasattr(model, 'initScaleNets_occlusion'):
        print("Num. of initScaleNets_occlusion model parameters is :" + str(
            count_network_parameters(model.initScaleNets_occlusion) +
            count_network_parameters(model.initScaleNets_occlusion1) +
            count_network_parameters(model.initScaleNets_occlusion2)))
    if hasattr(model, 'initScaleNets_filter'):
        print("Num. of initScaleNets_filter model parameters is :" + str(
            count_network_parameters(model.initScaleNets_filter) +
            count_network_parameters(model.initScaleNets_filter1) +
            count_network_parameters(model.initScaleNets_filter2)))
    if hasattr(model, 'ctxNet'):
        print("Num. of ctxNet model parameters is :" +
              str(count_network_parameters(model.ctxNet)))
    if hasattr(model, 'depthNet'):
        print("Num. of depthNet model parameters is :" +
              str(count_network_parameters(model.depthNet)))
    if hasattr(model, 'rectifyNet'):
        print("Num. of rectifyNet model parameters is :" +
              str(count_network_parameters(model.rectifyNet)))
    if hasattr(model, 'fea_exat_net'):
        print("Num. of fea_exat_net model parameters is :" +
              str(count_network_parameters(model.fea_exat_net)))

    training_losses = AverageMeter()
    auxiliary_data = []          # one row per epoch, dumped to args.log
    saved_total_loss = 10e10     # best (lowest) validation loss so far
    saved_total_PSNR = -1        # best (highest) validation PSNR so far
    saved_total_loss_MB = 10e10  # NOTE(review): assigned but never used here
    MB_avgLoss, MB_avgPSNR = 1e5, 0  # placeholders, reset to 0 each epoch ("todo" below)
    # ikk: first param group with a non-zero lr, used only for log printouts.
    ikk = 0
    for kk in optimizer.param_groups:
        if kk['lr'] > 0:
            ikk = kk
            break
    for t in range(args.numEpoch):
        print("The id of this in-training network is " + str(args.uid))
        print(args)
        #Turn into training mode
        model = model.train()
        for i, (X0_half, X1_half, y_half,
                frame_index) in enumerate(train_loader):
            # Cap the number of iterations per epoch at N_iter * epoch size.
            if i >= (args.N_iter * int(len(train_set) / args.batch_size)
                     ):  #(0 if t == 0 else EPOCH):#
                break
            X0_half = X0_half.cuda() if args.use_cuda else X0_half
            X1_half = X1_half.cuda() if args.use_cuda else X1_half  # middle
            y_half = y_half.cuda() if args.use_cuda else y_half
            X0 = Variable(X0_half, requires_grad=False)
            X1 = Variable(X1_half, requires_grad=False)
            y = Variable(y_half, requires_grad=False)
            # Inputs are stacked (X0, ground-truth y, X1) along a new dim 0;
            # one model variant does not take the frame index.
            if args.netName == 'MultiScaleStructure_filt_flo_ctxS2D_depth_Modeling3':
                diffs, offsets, filters, occlusions = model(
                    torch.stack((X0, y, X1), dim=0))
            else:
                diffs, offsets, filters, occlusions = model(
                    torch.stack((X0, y, X1), dim=0), frame_index)
            pixel_loss, offset_loss, occlusion_loss, sym_loss = part_loss(
                diffs,
                offsets,
                occlusions, [X0, X1],
                epsilon=args.epsilon,
                use_negPSNR=args.use_negPSNR)
            DF_loss = df_loss_func(offsets, occlusions)
            # Weighted sum of the loss terms; components with non-positive
            # weights are dropped (except offset_loss, always included).
            total_loss = sum(x*y if x > 0 else 0 for x,y in zip(args.alpha, pixel_loss)) + sum(x*y for x,y in zip(args.lambda1, offset_loss) ) + \
                sum(x*y if x > 0 else 0 for x,y in zip(args.lambda2, occlusion_loss)) + \
                sum(x*y if x > 0 else 0 for x,y in zip(args.lambda3, sym_loss)) +\
                sum(x*y if x>0 else 0 for x,y in zip(args.lambda4, [DF_loss]))
            training_losses.update(total_loss.item(),
                                   args.batch_size)  #.item(),
            # Log roughly 500 times per epoch.
            if i % max(1,
                       int(int(len(train_set) / args.batch_size) /
                           500.0)) == 0:
                pstring = "Ep [" + str(t) +"/" + str(i) + \
                    "]\tl.r.: " + str(round(float(ikk['lr']),7))+ \
                    "\tPix: " + str([round(x.item(),5) for x in pixel_loss]) + \
                    "\tTV: " + str([round(x.item(),4) for x in offset_loss]) + \
                    "\tPWS: " + str([round(x.item(), 4) for x in occlusion_loss]) + \
                    "\tSym: " + str([round(x.item(), 4) for x in sym_loss]) + \
                    "\tTotal: " + str([round(x.item(),5) for x in [total_loss]]) + \
                    "\tAvg. Loss: " + str([round(training_losses.avg, 5)])
                print(pstring)
                # NOTE(review): the file handle from open() is never closed
                # explicitly; it is left to garbage collection.
                print(pstring,
                      file=open(os.path.join(args.save_path, "all_log.txt"),
                                "a"))
                # visdom display
                itr = i + t * (args.N_iter *
                               int(len(train_set) / args.batch_size))
                viz.line([[
                    pixel_loss[0].item(), pixel_loss[1].item(),
                    offset_loss[0].item(), occlusion_loss[0].item(),
                    sym_loss[0].item(),
                    total_loss.item(), training_losses.avg
                ]], [itr],
                         win='train_respective_loss',
                         env=viz_env,
                         opts=dict(title='train_respective_loss',
                                   legend=[
                                       'pixel_loss_0', 'pixel_loss_1',
                                       'offset_loss', 'occlusion_loss',
                                       'sym_loss', 'total_loss',
                                       'total_loss_avg'
                                   ]),
                         update='append')
            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()
        if t == 1:
            # delete the pre validation weights for cleaner workspace
            if os.path.exists(args.save_path + "/epoch" + str(0) + ".pth"):
                os.remove(args.save_path + "/epoch" + str(0) + ".pth")
        # Keep only the latest per-epoch checkpoint.
        # NOTE(review): nesting reconstructed from whitespace-mangled source;
        # this removal is taken to run every epoch, not only when t == 1.
        if os.path.exists(args.save_path + "/epoch" + str(t - 1) + ".pth"):
            os.remove(args.save_path + "/epoch" + str(t - 1) + ".pth")
        torch.save(model.state_dict(),
                   args.save_path + "/epoch" + str(t) + ".pth")
        # print("\t\t**************Start Validation*****************")
        #Turn into evaluation mode
        # NOTE(review): model.eval() is not called here; only no_grad is used.
        val_total_losses = AverageMeter()
        val_total_pixel_loss = AverageMeter()
        val_total_PSNR_loss = AverageMeter()
        val_total_tv_loss = AverageMeter()
        val_total_pws_loss = AverageMeter()
        val_total_sym_loss = AverageMeter()
        for i, (X0, X1, y, frame_index) in enumerate(val_loader):
            if i >= int(len(test_set) / args.batch_size):
                break
            with torch.no_grad():
                X0 = X0.cuda() if args.use_cuda else X0
                X1 = X1.cuda() if args.use_cuda else X1
                y = y.cuda() if args.use_cuda else y
                if args.netName == 'MultiScaleStructure_filt_flo_ctxS2D_depth_Modeling3':
                    diffs, offsets, filters, occlusions = model(
                        torch.stack((X0, y, X1), dim=0))
                else:
                    diffs, offsets, filters, occlusions = model(
                        torch.stack((X0, y, X1), dim=0), frame_index)
                pixel_loss, offset_loss, occlusion_loss, sym_loss = part_loss(
                    diffs,
                    offsets,
                    occlusions, [X0, X1],
                    epsilon=args.epsilon,
                    use_negPSNR=args.use_negPSNR)
                # Validation loss uses plain weighted sums (no positivity
                # gating, no DF term), unlike the training loss above.
                val_total_loss = sum(x * y for x, y in zip(args.alpha, pixel_loss)) + \
                    sum(x * y for x, y in zip(args.lambda1, offset_loss)) + \
                    sum(x * y for x, y in zip(args.lambda2, occlusion_loss)) + \
                    sum(x * y for x, y in zip(args.lambda3, sym_loss))
                per_sample_pix_error = torch.mean(torch.mean(torch.mean(
                    diffs[args.save_which]**2, dim=1),
                                                             dim=1),
                                                  dim=1)
                per_sample_pix_error = per_sample_pix_error.data  # extract tensor
                # print(per_sample_pix_error.size())
                # print(per_sample_pix_error.type())
                # PSNR from per-sample MSE; torch.log is natural log, hence
                # the division by log(10).
                psnr_loss = torch.mean(20 * torch.log(
                    1.0 / torch.sqrt(per_sample_pix_error))) / torch.log(
                        torch.Tensor([10]))
                val_total_losses.update(val_total_loss.item(),
                                        args.batch_size)
                val_total_pixel_loss.update(
                    pixel_loss[args.save_which].item(), args.batch_size)
                val_total_tv_loss.update(offset_loss[0].item(),
                                         args.batch_size)
                val_total_pws_loss.update(occlusion_loss[0].item(),
                                          args.batch_size)
                val_total_sym_loss.update(sym_loss[0].item(),
                                          args.batch_size)
                val_total_PSNR_loss.update(psnr_loss[0], args.batch_size)
                print(".", end='', flush=True)
        pstring = "\nEpoch " + str(int(t)) + \
            "\tlearning rate: " + str(float(ikk['lr'])) + \
            "\tAvg Training Loss: " + str(round(training_losses.avg, 5)) + \
            "\tValidate Loss: " + str([round(float(val_total_losses.avg), 5)]) + \
            "\tValidate PSNR: " + str([round(float(val_total_PSNR_loss.avg), 5)]) + \
            "\tPixel Loss: " + str([round(float(val_total_pixel_loss.avg), 5)]) + \
            "\tTV Loss: " + str([round(float(val_total_tv_loss.avg), 4)]) + \
            "\tPWS Loss: " + str([round(float(val_total_pws_loss.avg), 4)]) + \
            "\tSym Loss: " + str([round(float(val_total_sym_loss.avg), 4)])
        print(pstring)
        print(pstring,
              file=open(os.path.join(args.save_path, "all_log.txt"), "a"))
        # visdom
        viz.line([[
            training_losses.avg, val_total_losses.avg,
            val_total_PSNR_loss.avg, val_total_pixel_loss.avg,
            val_total_tv_loss.avg, val_total_pws_loss.avg,
            val_total_sym_loss.avg
        ]], [int(t)],
                 win='val_respective_loss',
                 env=viz_env,
                 opts=dict(title='epoch_val_respective_loss',
                           legend=[
                               'training_losses', 'val_total_losses',
                               'val_total_PSNR_loss', 'val_total_pixel_loss',
                               'val_total_tv_loss', 'val_total_pws_loss',
                               'val_total_sym_loss'
                           ]),
                 update='append')
        # todo
        MB_avgLoss = 0
        MB_avgPSNR = 0
        auxiliary_data.append([
            t,
            float(ikk['lr']), training_losses.avg, val_total_losses.avg,
            val_total_pixel_loss.avg, val_total_tv_loss.avg,
            val_total_pws_loss.avg, val_total_sym_loss.avg, MB_avgLoss,
            MB_avgPSNR
        ])
        # Rewrite the full history each epoch (savetxt overwrites the file).
        numpy.savetxt(args.log,
                      numpy.array(auxiliary_data),
                      fmt='%.8f',
                      delimiter=',')
        training_losses.reset()
        print("\t\tFinished an epoch, Check and Save the model weights")
        # we check the validation loss instead of training loss. OK~
        if saved_total_loss >= val_total_losses.avg:
            saved_total_loss = val_total_losses.avg
            torch.save(model.state_dict(), args.save_path + "/best" + ".pth")
            print("\t\tBest Weights updated for decreased validation loss\n")
        else:
            print("\t\tWeights Not updated for undecreased validation loss\n")
        # Separately track the best-PSNR checkpoint.
        if saved_total_PSNR <= val_total_PSNR_loss.avg:
            saved_total_PSNR = val_total_PSNR_loss.avg
            # torch.save(model,MODEL_PATH)
            # model.save_state_dict(MODEL_PATH)
            torch.save(model.state_dict(),
                       args.save_path + "/bestPSNR" + ".pth")
            print(
                "\t\tBest Weights updated for increased validation PSNR \n\n")
        else:
            print(
                "\t\tWeights Not updated for unincreased validation PSNR\n\n")
        #schdule the learning rate
        scheduler.step(val_total_losses.avg)
    print("*********Finish Training********")
def train():
    """Fine-tune a pre-trained DAIN model on pixel triplets with a GAN loss.

    NOTE(review): this module defines ``train`` more than once; this
    definition shadows the earlier ones.

    Reads module-level globals: ``args`` (for lr coefficients, epsilon,
    numEpoch, save_path, log), ``networks``, ``datasets``,
    ``balancedsampler``, ``Discriminator``, ``charbonnier_loss``,
    ``AverageMeter``, ``ReduceLROnPlateau``, ``Variable``, ``unique_id``,
    ``torch``, ``numpy``, ``random``, ``os``. Requires CUDA (unconditional
    ``.cuda()`` calls).

    Side effects: saves per-epoch and best checkpoints under
    ``args.save_path`` and appends per-epoch loss rows to ``args.log``.
    """
    SAVED_MODEL_PATH = "./model_weights/pretrained.pth"
    DATA_PATH = "./pixel_triplets/"
    BATCH_SIZE = 1
    torch.manual_seed(1337)
    random.seed(1337)

    # -------------------------------------
    # load pre-trained model
    # -------------------------------------
    model = networks.DAIN(channel=3,
                          filter_size=4,
                          timestep=0.5,
                          training=False,
                          pixel_model=True)
    model = model.cuda()
    print("Fine tuning on " + SAVED_MODEL_PATH)
    pretrained_dict = torch.load(SAVED_MODEL_PATH)
    model_dict = model.state_dict()
    # 1. filter out unnecessary keys
    pretrained_dict = {
        k: v
        for k, v in pretrained_dict.items() if k in model_dict
    }
    # 2. overwrite entries in the existing state dict
    model_dict.update(pretrained_dict)
    # 3. load the new state dict
    model.load_state_dict(model_dict)
    # Release the loaded dict so it can be garbage-collected.
    pretrained_dict = []

    # -------------------------------------
    # create discriminator
    # -------------------------------------
    discrim = Discriminator()
    discrim = discrim.cuda()
    # discriminator optimizer and loss
    optimizer_discrim = torch.optim.Adam(discrim.parameters(), lr=0.0005)

    # -------------------------------------
    # create dataset loaders
    # -------------------------------------
    train_set, test_set = datasets.pixel_triplets(DATA_PATH)
    train_loader = torch.utils.data.DataLoader(
        train_set,
        batch_size=BATCH_SIZE,
        sampler=balancedsampler.RandomBalancedSampler(
            train_set, int(len(train_set) / BATCH_SIZE)),
        num_workers=8,
        pin_memory=True)
    val_loader = torch.utils.data.DataLoader(test_set,
                                             batch_size=BATCH_SIZE,
                                             num_workers=8,
                                             pin_memory=True)
    print('{} samples found, {} train samples and {} test samples '.format(
        len(test_set) + len(train_set), len(train_set), len(test_set)))

    # -------------------------------------
    # create optimizer / LR scheduler
    # -------------------------------------
    # Per-submodule learning rates: each sub-network gets its own coefficient
    # times the base lr; the rectify net uses an independent absolute lr.
    print("train the interpolation net")
    optimizer = torch.optim.Adamax(
        [{
            'params': model.initScaleNets_filter.parameters(),
            'lr': args.filter_lr_coe * args.lr
        }, {
            'params': model.initScaleNets_filter1.parameters(),
            'lr': args.filter_lr_coe * args.lr
        }, {
            'params': model.initScaleNets_filter2.parameters(),
            'lr': args.filter_lr_coe * args.lr
        }, {
            'params': model.ctxNet.parameters(),
            'lr': args.ctx_lr_coe * args.lr
        }, {
            'params': model.flownets.parameters(),
            'lr': args.flow_lr_coe * args.lr
        }, {
            'params': model.depthNet.parameters(),
            'lr': args.depth_lr_coe * args.lr
        }, {
            'params': model.rectifyNet.parameters(),
            'lr': args.rectify_lr
        }],
        lr=args.lr,
        betas=(0.9, 0.999),
        eps=1e-8,
        weight_decay=args.weight_decay)
    # NOTE(review): the scheduler is constructed but its step() call at the
    # bottom of the epoch loop is commented out, so it never fires.
    scheduler = ReduceLROnPlateau(optimizer,
                                  'min',
                                  factor=args.factor,
                                  patience=args.patience,
                                  verbose=True)

    # -------------------------------------
    # print out some info before we start
    # -------------------------------------
    print("*********Start Training********")
    print("LR is: " + str(float(optimizer.param_groups[0]['lr'])))
    print("EPOCH is: " + str(int(len(train_set) / BATCH_SIZE)))
    print("Num of EPOCH is: " + str(args.numEpoch))

    def count_network_parameters(model):
        # Count only trainable (requires_grad) parameters.
        parameters = filter(lambda p: p.requires_grad, model.parameters())
        N = sum([numpy.prod(p.size()) for p in parameters])
        return N

    print("Num. of model parameters is:", count_network_parameters(model))
    if hasattr(model, 'flownets'):
        print("Num. of flow model parameters is:",
              count_network_parameters(model.flownets))
    if hasattr(model, 'initScaleNets_occlusion'):
        print(
            "Num. of initScaleNets_occlusion model parameters is:",
            count_network_parameters(model.initScaleNets_occlusion) +
            count_network_parameters(model.initScaleNets_occlusion1) +
            count_network_parameters(model.initScaleNets_occlusion2))
    if hasattr(model, 'initScaleNets_filter'):
        print(
            "Num. of initScaleNets_filter model parameters is:",
            count_network_parameters(model.initScaleNets_filter) +
            count_network_parameters(model.initScaleNets_filter1) +
            count_network_parameters(model.initScaleNets_filter2))
    if hasattr(model, 'ctxNet'):
        print("Num. of ctxNet model parameters is:",
              count_network_parameters(model.ctxNet))
    if hasattr(model, 'depthNet'):
        print("Num. of depthNet model parameters is:",
              count_network_parameters(model.depthNet))
    if hasattr(model, 'rectifyNet'):
        print("Num. of rectifyNet model parameters is:",
              count_network_parameters(model.rectifyNet))
    print("Num. of discriminator model parameters is:",
          count_network_parameters(discrim))

    # -------------------------------------
    # and heeere we go
    # -------------------------------------
    # discriminator pretrains for a certain # of epochs
    PRETRAINING_EPOCHS = 0
    training_losses = AverageMeter()
    auxiliary_data = []       # one row per epoch, dumped to args.log
    saved_total_loss = 10e10  # best (lowest) validation loss so far
    saved_total_PSNR = -1     # NOTE(review): tracked but never updated below
    # ikk: first param group with a non-zero lr, used only for log printouts.
    ikk = 0
    for kk in optimizer.param_groups:
        if kk['lr'] > 0:
            ikk = kk
            break

    # Soft real/fake labels (+0.5 / -0.5).
    # NOTE(review): g_label_target / d_label_target are built but not used in
    # the visible loss computation — the RaLSGAN losses below are label-free.
    d_real_label = Variable(torch.ones(1, ), requires_grad=False).cuda() * 0.5
    d_fake_label = Variable(torch.ones(1, ),
                            requires_grad=False).cuda() * -0.5
    g_label_target = Variable(torch.stack([
        d_real_label,
        d_real_label,
        d_real_label,
        d_real_label,
    ],
                                          dim=0),
                              requires_grad=False).cuda()
    d_label_target = Variable(torch.stack((
        d_real_label,
        d_real_label,
        d_real_label,
        d_real_label,
        d_fake_label,
        d_fake_label,
        d_fake_label,
        d_fake_label,
    ),
                                          dim=0),
                              requires_grad=False).cuda()

    for t in range(args.numEpoch):
        print("The id of this in-training network is " + unique_id)
        if (t < PRETRAINING_EPOCHS):
            print("-- Discriminator pre-training epoch --")
        elif (t == PRETRAINING_EPOCHS):
            print("-- End discriminator pre-training --")
        # turn into training mode
        model = model.train()
        discrim = discrim.train()
        for i, (X0_half, X1_half, y_half) in enumerate(train_loader):
            # NOTE(review): assigned but unused in this loop body.
            loss_function = charbonnier_loss
            #if i >= 100:#
            if i >= int(len(train_set) / BATCH_SIZE):
                break
            # Pre-training epochs are capped at 100 iterations.
            if (t < PRETRAINING_EPOCHS and i >= 100):
                break
            #before_mod = sum([torch.mean(p) for p in model.parameters()]).item()
            #before_dsc = sum([torch.mean(p) for p in discrim.parameters()]).item()
            X0_half = X0_half.cuda()
            X1_half = X1_half.cuda()
            y_half = y_half.cuda()
            X0 = Variable(X0_half, requires_grad=False)
            X1 = Variable(X1_half, requires_grad=False)
            y = Variable(y_half, requires_grad=False)
            # placeholder variables (used when skipping the generator step)
            discrim_total_loss = Variable(torch.zeros(1, 1)).cuda()
            model_pixel_loss = torch.zeros(1, )
            model_dsc_loss = torch.zeros(1, )
            total_loss = torch.zeros(1, )

            # --------------------------------------------
            # train the interpolation network
            # (using the cycle consistency method from Reda, et al.)
            # --------------------------------------------
            optimizer.zero_grad()

            # NOTE(review): defined but never called in this loop.
            def create_model_input(before, after):
                return torch.cat((torch.stack((before, after), dim=0), ),
                                 dim=1)

            # predict the frame between X0 and y
            model_input = torch.stack((X0, X1), dim=0)
            model_output = model(model_input)
            y_est = model_output[0:1]
            # concatenate real and fake so we can do everything in two forward passes
            # (.detach() keeps discriminator gradients out of the generator)
            discrim_batch = torch.cat((y.detach(), y_est.detach()), dim=0)
            if (t >= PRETRAINING_EPOCHS):
                # pixel losses
                model_pixel_loss = charbonnier_loss(y_est, y)
                # discriminator losses
                C_y, C_y_est = discrim(discrim_batch)
                # RaLSGAN loss. what is it minimizing? uhhh i don't f*****g know man
                C_diff = ((C_y - C_y_est - 1.0)**2 +
                          (C_y_est - C_y + 1.0)**2) / 2.0
                model_dsc_loss = C_diff
                # Adversarial term weighted at 1% of the pixel term.
                total_loss = model_pixel_loss + 0.01 * model_dsc_loss
                total_loss.backward()
                optimizer.step()

            # --------------------------------------------
            # train the discriminator
            # --------------------------------------------
            optimizer_discrim.zero_grad()
            C_y, C_y_est = discrim(discrim_batch)
            # discriminator's RaLSGAN loss is the reverse of the generator's
            C_diff = ((C_y_est - C_y - 1.0)**2 +
                      (C_y - C_y_est + 1.0)**2) / 2.0
            discrim_loss = C_diff
            discrim_loss.backward()
            optimizer_discrim.step()

            # --------------------------------------------
            # finally, output some stuff
            # --------------------------------------------
            training_losses.update(total_loss.item(), BATCH_SIZE)
            # Log roughly 500 times per epoch.
            if i % max(1, int(int(len(train_set) / BATCH_SIZE) /
                              500.0)) == 0:
                print(
                    "Ep [" + str(t) + "/" + str(i) + "]\tl.r.: " +
                    str(round(float(ikk['lr']), 7)) + "\tPix: " +
                    str([round(model_pixel_loss.item(), 5)]) +
                    #"\tFool: " + str(100 - round(np.sqrt(model_dsc_loss.item()) * 100, 5)) + "%" +
                    "\tFool: " + str(round(model_dsc_loss.item(), 5)) +
                    "\tTotal: " + str([round(x.item(), 5)
                                       for x in [total_loss]]) +
                    #"\tDiscrim: " + str(100 - round(np.sqrt(discrim_loss.item()) * 100, 5)) + "%" +
                    "\tDiscrim: " + str(round(discrim_loss.item(), 5)) +
                    "\tAvg. Loss: " + str([round(training_losses.avg, 5)]))
        # Skip validation/checkpointing during discriminator pre-training.
        if (t < PRETRAINING_EPOCHS):
            continue
        torch.save(model.state_dict(),
                   args.save_path + "/epoch" + str(t) + ".pth")
        # print("\t\t**************Start Validation*****************")
        #Turn into evaluation mode
        # NOTE(review): model.eval() is not called here; only no_grad is used.
        val_total_losses = AverageMeter()
        val_total_pixel_loss = AverageMeter()
        val_total_PSNR_loss = AverageMeter()
        for i, (X0, X1, y) in enumerate(val_loader):
            if i >= int(len(test_set) / BATCH_SIZE):
                break
            with torch.no_grad():
                X0 = X0.cuda()
                X1 = X1.cuda()
                y = y.cuda()
                y_est = model(torch.stack((X0, X1), dim=0))
                y_diff = y_est - y
                # Charbonnier-style pixel loss with args.epsilon smoothing.
                pixel_loss = torch.mean(
                    torch.sqrt(y_diff * y_diff +
                               args.epsilon * args.epsilon))
                val_total_loss = pixel_loss
                per_sample_pix_error = torch.mean(torch.mean(torch.mean(
                    y_diff**2, dim=1),
                                                             dim=1),
                                                  dim=1)
                per_sample_pix_error = per_sample_pix_error.data  # extract tensor
                # PSNR from per-sample MSE; torch.log is natural log, hence
                # the division by log(10).
                psnr_loss = torch.mean(20 * torch.log(
                    1.0 / torch.sqrt(per_sample_pix_error))) / torch.log(
                        torch.Tensor([10]))
                val_total_losses.update(val_total_loss.item(), BATCH_SIZE)
                val_total_pixel_loss.update(pixel_loss.item(), BATCH_SIZE)
                val_total_PSNR_loss.update(psnr_loss[0], BATCH_SIZE)
                print(".", end='', flush=True)
        print("\nEpoch " + str(int(t)) + "\tlearning rate: " +
              str(float(ikk['lr'])) + "\tAvg Training Loss: " +
              str(round(training_losses.avg, 5)) + "\tValidate Loss: " +
              str([round(float(val_total_losses.avg), 5)]) +
              "\tValidate PSNR: " +
              str([round(float(val_total_PSNR_loss.avg), 5)]) +
              "\tPixel Loss: " +
              str([round(float(val_total_pixel_loss.avg), 5)]))
        auxiliary_data.append([
            t,
            float(ikk['lr']), training_losses.avg, val_total_losses.avg,
            val_total_pixel_loss.avg
        ])
        # Rewrite the full history each epoch (savetxt overwrites the file).
        numpy.savetxt(args.log,
                      numpy.array(auxiliary_data),
                      fmt='%.8f',
                      delimiter=',')
        training_losses.reset()
        print("\t\tFinished an epoch, Check and Save the model weights")
        # we check the validation loss instead of training loss. OK~
        if saved_total_loss >= val_total_losses.avg:
            saved_total_loss = val_total_losses.avg
            torch.save(model.state_dict(), args.save_path + "/best" + ".pth")
            print("\t\tBest Weights updated for decreased validation loss\n")
        else:
            print("\t\tWeights Not updated for undecreased validation loss\n")
        #schdule the learning rate
        #scheduler.step(val_total_losses.avg)
    print("*********Finish Training********")