def main():
    global best_error, n_iter, device
    args = parser.parse_args()
    if args.dataset_format == 'stacked':
        from datasets.stacked_sequence_folders import SequenceFolder
    elif args.dataset_format == 'sequential':
        from datasets.sequence_folders import SequenceFolder
    timestamp = datetime.datetime.now().strftime("%m-%d-%H:%M")
    args.save_path = 'checkpoints' / Path(args.name) / timestamp
    print('=> will save everything to {}'.format(args.save_path))
    args.save_path.makedirs_p()
    torch.manual_seed(args.seed)

    training_writer = SummaryWriter(args.save_path)
    output_writers = []  # required by the validation functions below (was undefined in the original)

    # Data loading
    normalize = custom_transforms.Normalize(mean=[0.5, 0.5, 0.5],
                                            std=[0.5, 0.5, 0.5])
    train_transform = custom_transforms.Compose([
        custom_transforms.RandomHorizontalFlip(),
        custom_transforms.RandomScaleCrop(),
        custom_transforms.ArrayToTensor(),
        normalize
    ])
    valid_transform = custom_transforms.Compose(
        [custom_transforms.ArrayToTensor(), normalize])

    print("=> fetching scenes in '{}'".format(args.data))
    train_set = SequenceFolder(args.data,
                               transform=train_transform,
                               seed=args.seed,
                               train=True,
                               sequence_length=args.sequence_length)

    # if no ground truth is available, the validation set is the same type as the
    # training set, so the photometric loss from warping can be measured
    if args.with_gt:
        from datasets.validation_folders import ValidationSet
        val_set = ValidationSet(args.data, transform=valid_transform)
    else:
        val_set = SequenceFolder(
            args.data,
            transform=valid_transform,
            seed=args.seed,
            train=False,
            sequence_length=args.sequence_length,
        )
    print('{} samples found in {} train scenes'.format(len(train_set), len(train_set.scenes)))
    print('{} samples found in {} valid scenes'.format(len(val_set), len(val_set.scenes)))

    train_loader = torch.utils.data.DataLoader(train_set,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True)
    val_loader = torch.utils.data.DataLoader(val_set,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    if args.epoch_size == 0:
        args.epoch_size = len(train_loader)

    # create model
    print("=> creating model")
    disp_net = getattr(models, args.dispnet)().to(device)
    pose_net = models.PoseNet().to(device)

    if args.pretrained_pose:
        print("=> using pre-trained weights for PoseNet")
        weights = torch.load(args.pretrained_pose)
        pose_net.load_state_dict(weights['state_dict'], strict=False)
    else:
        pose_net.init_weights()

    if args.pretrained_disp:
        print("=> using pre-trained weights for DispNet")
        weights = torch.load(args.pretrained_disp)
        disp_net.load_state_dict(weights['state_dict'])
    else:
        disp_net.init_weights()

    cudnn.benchmark = True
    disp_net = torch.nn.DataParallel(disp_net)
    pose_net = torch.nn.DataParallel(pose_net)

    print('=> setting adam solver')
    optim_params = [{
        'params': disp_net.parameters(), 'lr': args.lr
    }, {
        'params': pose_net.parameters(), 'lr': args.lr
    }]
    optimizer = torch.optim.Adam(optim_params,
                                 betas=(args.momentum, args.beta),
                                 weight_decay=args.weight_decay)

    with open(args.save_path / args.log_summary, 'w') as csvfile:
        writer = csv.writer(csvfile, delimiter='\t')
        writer.writerow(['train_loss', 'validation_loss'])

    with open(args.save_path / args.log_full, 'w') as csvfile:
        writer = csv.writer(csvfile, delimiter='\t')
        writer.writerow(['photo_loss', 'smooth_loss', 'geometry_loss', 'train_loss'])

    logger = TermLogger(n_epochs=args.epochs,
                        train_size=min(len(train_loader), args.epoch_size),
                        valid_size=len(val_loader))
    logger.epoch_bar.start()

    if args.pretrained_disp:
        logger.reset_valid_bar()
        if args.with_gt:
            errors, error_names = validate_with_gt(args, val_loader, disp_net, 0,
                                                   logger, output_writers)
        else:
            errors, error_names = validate_without_gt(args, val_loader, disp_net,
                                                      pose_net, 0, logger, output_writers)
        for error, name in zip(errors, error_names):
            training_writer.add_scalar(name, error, 0)
        error_string = ', '.join('{} : {:.3f}'.format(name, error)
                                 for name, error in zip(error_names[2:9], errors[2:9]))
        logger.valid_writer.write(' * Avg {}'.format(error_string))

    for epoch in range(args.epochs):
        logger.epoch_bar.update(epoch)

        # train for one epoch
        logger.reset_train_bar()
        train_loss = train(args, train_loader, disp_net, pose_net, optimizer,
                           args.epoch_size, logger, training_writer)
        logger.train_writer.write(' * Avg Loss : {:.3f}'.format(train_loss))

        # evaluate on validation set
        logger.reset_valid_bar()
        if args.with_gt:
            errors, error_names = validate_with_gt(args, val_loader, disp_net,
                                                   epoch, logger)
        else:
            errors, error_names = validate_without_gt(args, val_loader, disp_net,
                                                      pose_net, epoch, logger)
        error_string = ', '.join('{} : {:.3f}'.format(name, error)
                                 for name, error in zip(error_names, errors))
        logger.valid_writer.write(' * Avg {}'.format(error_string))

        for error, name in zip(errors, error_names):
            training_writer.add_scalar(name, error, epoch)

        # Up to you to choose the most relevant error to measure your model's
        # performance; careful, some measures are to be maximized (such as a1, a2, a3)
        decisive_error = errors[1]
        if best_error < 0:
            best_error = decisive_error

        # remember lowest error and save checkpoint
        is_best = decisive_error < best_error
        best_error = min(best_error, decisive_error)
        save_checkpoint(args.save_path, {
            'epoch': epoch + 1,
            'state_dict': disp_net.module.state_dict()
        }, {
            'epoch': epoch + 1,
            'state_dict': pose_net.module.state_dict()
        }, is_best)

        with open(args.save_path / args.log_summary, 'a') as csvfile:
            writer = csv.writer(csvfile, delimiter='\t')
            writer.writerow([train_loss, decisive_error])

    logger.epoch_bar.finish()
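# --- Hedged example (not part of the original script) ------------------------
# The loop above assumes a save_checkpoint helper defined elsewhere in the repo.
# The sketch below is an assumption about its behaviour: the file prefixes,
# checkpoint file name and the "best model" copy are illustrative only.
def save_checkpoint_sketch(save_path, dispnet_state, posenet_state, is_best,
                           filename='checkpoint.pth.tar'):
    import shutil
    prefixes = ['dispnet', 'posenet']
    states = [dispnet_state, posenet_state]
    # always keep the latest state of each network
    for prefix, state in zip(prefixes, states):
        torch.save(state, save_path / '{}_{}'.format(prefix, filename))
    # additionally keep a copy of the best-performing epoch
    if is_best:
        for prefix in prefixes:
            shutil.copyfile(save_path / '{}_{}'.format(prefix, filename),
                            save_path / '{}_model_best.pth.tar'.format(prefix))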
def main():
    global args, best_error, n_iter
    args = parser.parse_args()
    save_path = Path(args.name)
    args.data = Path(args.data)
    # build the save path with path.py so that makedirs_p() below works
    args.save_path = Path(args.log_path) / 'checkpoints' / save_path  # /timestamp
    print('=> will save everything to {}'.format(args.save_path))
    args.save_path.makedirs_p()
    torch.manual_seed(args.seed)

    training_writer = SummaryWriter(args.save_path)
    output_writer = SummaryWriter(args.save_path / 'valid')

    print("=> fetching dataset")
    mnist_transform = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.1307,), (0.3081,))
    ])
    trainset_mnist = torchvision.datasets.MNIST(args.data / 'mnist', train=True,
                                                transform=mnist_transform,
                                                target_transform=None, download=True)
    valset_mnist = torchvision.datasets.MNIST(args.data / 'mnist', train=False,
                                              transform=mnist_transform,
                                              target_transform=None, download=True)

    svhn_transform = torchvision.transforms.Compose([
        torchvision.transforms.Resize(size=(28, 28)),
        torchvision.transforms.Grayscale(),
        torchvision.transforms.ToTensor()
    ])
    trainset_svhn = torchvision.datasets.SVHN(args.data / 'svhn', split='train',
                                              transform=svhn_transform,
                                              target_transform=None, download=True)
    valset_svhn = torchvision.datasets.SVHN(args.data / 'svhn', split='test',
                                            transform=svhn_transform,
                                            target_transform=None, download=True)

    if args.dataset == 'mnist':
        print("Training only on MNIST")
        train_set, val_set = trainset_mnist, valset_mnist
    elif args.dataset == 'svhn':
        print("Training only on SVHN")
        train_set, val_set = trainset_svhn, valset_svhn
    else:
        print("Training on both MNIST and SVHN")
        train_set = torch.utils.data.ConcatDataset([trainset_mnist, trainset_svhn])
        val_set = torch.utils.data.ConcatDataset([valset_mnist, valset_svhn])

    print('{} train samples and {} test samples found in MNIST'.format(len(trainset_mnist), len(valset_mnist)))
    print('{} train samples and {} test samples found in SVHN'.format(len(trainset_svhn), len(valset_svhn)))

    train_loader = torch.utils.data.DataLoader(
        train_set, batch_size=args.batch_size, shuffle=True,
        num_workers=args.workers, pin_memory=True, drop_last=True)
    val_loader = torch.utils.data.DataLoader(
        val_set, batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True, drop_last=False)

    if args.epoch_size == 0:
        args.epoch_size = len(train_loader)

    # create model
    print("=> creating model")
    # define 10 autoencoders, one per MNIST class
    autoencoder_list = nn.ModuleList([AutoEncoder() for i in range(10)])
    mod_net = LeNet()

    if args.pretrained_autoencoder:
        print("=> using pre-trained weights from {}".format(args.pretrained_autoencoder))
        weights = torch.load(args.pretrained_autoencoder)
        autoencoder_list.load_state_dict(weights['state_dict'], strict=False)

    if args.pretrained_mod:
        print("=> using pre-trained weights from {}".format(args.pretrained_mod))
        weights = torch.load(args.pretrained_mod)
        mod_net.load_state_dict(weights['state_dict'], strict=False)

    if args.resume:
        print("=> resuming from checkpoint")
        weights = torch.load(args.save_path / 'autoencoder_checkpoint.pth.tar')
        autoencoder_list.load_state_dict(weights['state_dict'])
        mod_weights = torch.load(args.save_path / 'mod_checkpoint.pth.tar')
        mod_net.load_state_dict(mod_weights['state_dict'])

    cudnn.benchmark = True
    mod_net = mod_net.cuda()
    autoencoder_list = autoencoder_list.cuda()

    print('=> setting adam solver')
    parameters = chain(chain.from_iterable([x.parameters() for x in autoencoder_list]),
                       mod_net.parameters())
    optimizer_compete = torch.optim.Adam(parameters, args.lr,
                                         betas=(args.momentum, args.beta),
                                         weight_decay=args.weight_decay)
    optimizer_collaborate = torch.optim.Adam(mod_net.parameters(), args.lr,
                                             betas=(args.momentum, args.beta),
                                             weight_decay=args.weight_decay)

    with open(args.save_path / args.log_summary, 'w') as csvfile:
        writer = csv.writer(csvfile, delimiter='\t')
        writer.writerow(['val_loss_full', 'val_loss_alice', 'val_loss_bob'])

    with open(args.save_path / args.log_full, 'w') as csvfile:
        writer = csv.writer(csvfile, delimiter='\t')
        writer.writerow(['train_loss_full', 'train_loss_alice', 'train_loss_bob'])

    if args.log_terminal:
        logger = TermLogger(n_epochs=args.epochs,
                            train_size=min(len(train_loader), args.epoch_size),
                            valid_size=len(val_loader))
        logger.epoch_bar.start()
    else:
        logger = None

    for epoch in range(args.epochs):
        mode = 'compete' if (epoch % 2) == 0 else 'collaborate'

        if args.fix_mod:
            mode = 'compete'
            for fparams in mod_net.parameters():
                fparams.requires_grad = False

        if args.log_terminal:
            logger.epoch_bar.update(epoch)
            logger.reset_train_bar()

        # train for one epoch
        if mode == 'compete':
            train_loss = train(train_loader, autoencoder_list, mod_net,
                               optimizer_compete, args.epoch_size, logger,
                               training_writer, mode=mode)
        elif mode == 'collaborate':
            train_loss = train(train_loader, autoencoder_list, mod_net,
                               optimizer_collaborate, args.epoch_size, logger,
                               training_writer, mode=mode)

        if args.log_terminal:
            logger.train_writer.write(' * Avg Loss : {:.3f}'.format(train_loss))
            logger.reset_valid_bar()

        if epoch % args.validate_interval == 0:
            # evaluate on validation set
            errors, error_names = validate(val_loader, autoencoder_list, mod_net,
                                           epoch, logger, output_writer)
            error_string = ', '.join('{} : {:.3f}'.format(name, error)
                                     for name, error in zip(error_names, errors))

            if args.log_terminal:
                logger.valid_writer.write(' * Avg {}'.format(error_string))
            else:
                print('Epoch {} completed'.format(epoch))

            for error, name in zip(errors, error_names):
                training_writer.add_scalar(name, error, epoch)

            # Up to you to choose the most relevant error to measure your model's
            # performance; careful, some measures are to be maximized (such as a1, a2, a3)
            decisive_error = errors[0]  # epe_total
            if best_error < 0:
                best_error = decisive_error

            # remember lowest error and save checkpoint
            is_best = decisive_error <= best_error
            best_error = min(best_error, decisive_error)
            # save_model_state(args.save_path, autoencoder_list, mod_net, is_best, epoch=epoch+1)
            save_model_state(
                args.save_path, {
                    'epoch': epoch + 1,
                    'state_dict': autoencoder_list.state_dict()
                }, {
                    'epoch': epoch + 1,
                    'state_dict': mod_net.state_dict()
                }, is_best)

            with open(args.save_path / args.log_summary, 'a') as csvfile:
                writer = csv.writer(csvfile, delimiter='\t')
                writer.writerow([train_loss, decisive_error])

    if args.log_terminal:
        logger.epoch_bar.finish()
def main():
    global args, best_error, n_iter
    if args.dataset_format == 'stacked':
        from datasets.stacked_sequence_folders import SequenceFolder
    elif args.dataset_format == 'sequential':
        from datasets.sequence_folders import SequenceFolder
    timestamp = datetime.datetime.now().strftime("%m-%d-%H:%M")
    args.save_path = Path('checkpoints') / args.name / timestamp
    print('=> will save everything to {}'.format(args.save_path))
    args.save_path.makedirs_p()
    torch.manual_seed(args.seed)
    if args.alternating:
        args.alternating_flags = np.array([False, False, True])

    training_writer = SummaryWriter(args.save_path)
    output_writers = []
    if args.log_output:
        for i in range(3):
            output_writers.append(SummaryWriter(args.save_path / 'valid' / str(i)))

    # Data loading code
    flow_loader_h, flow_loader_w = 256, 832

    if args.data_normalization == 'global':
        normalize = custom_transforms.Normalize(mean=[0.5, 0.5, 0.5],
                                                std=[0.5, 0.5, 0.5])
    elif args.data_normalization == 'local':
        normalize = custom_transforms.NormalizeLocally()

    if args.fix_flownet:
        train_transform = custom_transforms.Compose([
            custom_transforms.RandomHorizontalFlip(),
            custom_transforms.RandomScaleCrop(),
            custom_transforms.ArrayToTensor(),
            normalize
        ])
    else:
        train_transform = custom_transforms.Compose([
            custom_transforms.RandomRotate(),
            custom_transforms.RandomHorizontalFlip(),
            custom_transforms.RandomScaleCrop(),
            custom_transforms.ArrayToTensor(),
            normalize
        ])

    valid_transform = custom_transforms.Compose(
        [custom_transforms.ArrayToTensor(), normalize])

    valid_flow_transform = custom_transforms.Compose([
        custom_transforms.Scale(h=flow_loader_h, w=flow_loader_w),
        custom_transforms.ArrayToTensor(),
        normalize
    ])

    print("=> fetching scenes in '{}'".format(args.data))
    train_set = SequenceFolder(args.data,
                               transform=train_transform,
                               seed=args.seed,
                               train=True,
                               sequence_length=args.sequence_length)

    # if no ground truth is available, the validation set is the same type as the
    # training set, so the photometric loss from warping can be measured
    if args.with_depth_gt:
        from datasets.validation_folders import ValidationSet
        val_set = ValidationSet(args.data.replace('cityscapes', 'kitti'),
                                transform=valid_transform)
    else:
        val_set = SequenceFolder(  # images only
            args.data,
            transform=valid_transform,
            seed=args.seed,
            train=False,
            sequence_length=args.sequence_length,
        )

    if args.with_flow_gt:
        from datasets.validation_flow import ValidationFlow
        val_flow_set = ValidationFlow(root=args.kitti_dir,
                                      sequence_length=args.sequence_length,
                                      transform=valid_flow_transform)
        val_flow_loader = torch.utils.data.DataLoader(
            val_flow_set,
            batch_size=1,  # batch size is 1 since images in KITTI have different sizes
            shuffle=False,
            num_workers=args.workers,
            pin_memory=True,
            drop_last=True)

    if args.DEBUG:
        train_set.__len__ = 32
        train_set.samples = train_set.samples[:32]

    print('{} samples found in {} train scenes'.format(len(train_set), len(train_set.scenes)))
    print('{} samples found in {} valid scenes'.format(len(val_set), len(val_set.scenes)))

    train_loader = torch.utils.data.DataLoader(train_set,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               drop_last=True)
    val_loader = torch.utils.data.DataLoader(val_set,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True,
                                             drop_last=True)

    if args.epoch_size == 0:
        args.epoch_size = len(train_loader)

    # 1. create model
    print("=> creating model")

    # 1.1 disp_net
    disp_net = getattr(models, args.dispnet)().cuda()
    output_exp = True  # args.mask_loss_weight > 0
    if not output_exp:
        print("=> no mask loss, PoseExpnet will only output pose")

    # 1.2 pose_net
    pose_net = getattr(models, args.posenet)(nb_ref_imgs=args.sequence_length - 1).cuda()

    # 1.3 flow_net
    if args.flownet == 'SpyNet':
        flow_net = getattr(models, args.flownet)(nlevels=args.nlevels,
                                                 pre_normalization=normalize).cuda()
    elif args.flownet == 'FlowNetC6':
        flow_net = getattr(models, args.flownet)(nlevels=args.nlevels).cuda()
    elif args.flownet == 'FlowNetS':
        flow_net = getattr(models, args.flownet)(nlevels=args.nlevels).cuda()
    elif args.flownet == 'Back2Future':
        flow_net = getattr(models, args.flownet)(nlevels=args.nlevels).cuda()

    # 1.4 mask_net
    mask_net = getattr(models, args.masknet)(nb_ref_imgs=args.sequence_length - 1,
                                             output_exp=True).cuda()

    # 2. load pretrained weights
    # 2.1 pose
    if args.pretrained_pose:
        print("=> using pre-trained weights for PoseNet")
        weights = torch.load(args.pretrained_pose)
        pose_net.load_state_dict(weights['state_dict'])
    else:
        pose_net.init_weights()

    if args.pretrained_mask:
        print("=> using pre-trained weights for MaskNet")
        weights = torch.load(args.pretrained_mask)
        mask_net.load_state_dict(weights['state_dict'])
    else:
        mask_net.init_weights()

    if args.pretrained_disp:
        print("=> using pre-trained weights from {}".format(args.pretrained_disp))
        weights = torch.load(args.pretrained_disp)
        disp_net.load_state_dict(weights['state_dict'])
    else:
        disp_net.init_weights()

    if args.pretrained_flow:
        print("=> using pre-trained weights for FlowNet")
        weights = torch.load(args.pretrained_flow)
        flow_net.load_state_dict(weights['state_dict'])
    else:
        flow_net.init_weights()

    if args.resume:
        print("=> resuming from checkpoint")
        dispnet_weights = torch.load(args.save_path / 'dispnet_checkpoint.pth.tar')
        posenet_weights = torch.load(args.save_path / 'posenet_checkpoint.pth.tar')
        masknet_weights = torch.load(args.save_path / 'masknet_checkpoint.pth.tar')
        flownet_weights = torch.load(args.save_path / 'flownet_checkpoint.pth.tar')
        disp_net.load_state_dict(dispnet_weights['state_dict'])
        pose_net.load_state_dict(posenet_weights['state_dict'])
        flow_net.load_state_dict(flownet_weights['state_dict'])
        mask_net.load_state_dict(masknet_weights['state_dict'])

    cudnn.benchmark = True
    disp_net = torch.nn.DataParallel(disp_net)
    pose_net = torch.nn.DataParallel(pose_net)
    mask_net = torch.nn.DataParallel(mask_net)
    flow_net = torch.nn.DataParallel(flow_net)

    print('=> setting adam solver')
    parameters = chain(disp_net.parameters(), pose_net.parameters(),
                       mask_net.parameters(), flow_net.parameters())
    optimizer = torch.optim.Adam(parameters, args.lr,
                                 betas=(args.momentum, args.beta),
                                 weight_decay=args.weight_decay)

    if args.resume and (args.save_path / 'optimizer_checkpoint.pth.tar').exists():
        print("=> loading optimizer from checkpoint")
        optimizer_weights = torch.load(args.save_path / 'optimizer_checkpoint.pth.tar')
        optimizer.load_state_dict(optimizer_weights['state_dict'])

    with open(args.save_path / args.log_summary, 'w') as csvfile:
        writer = csv.writer(csvfile, delimiter='\t')
        writer.writerow(['train_loss', 'validation_loss'])

    with open(args.save_path / args.log_full, 'w') as csvfile:
        writer = csv.writer(csvfile, delimiter='\t')
        writer.writerow(['train_loss', 'photo_cam_loss', 'photo_flow_loss',
                         'explainability_loss', 'smooth_loss'])

    if args.log_terminal:
        logger = TermLogger(n_epochs=args.epochs,
                            train_size=min(len(train_loader), args.epoch_size),
                            valid_size=len(val_loader))
        logger.epoch_bar.start()
    else:
        logger = None

    # 3. main cycle
    for epoch in range(args.epochs):
        # 3.1 decide which of the four sub-networks are trained
        if args.fix_flownet:
            for fparams in flow_net.parameters():
                fparams.requires_grad = False
        if args.fix_masknet:
            for fparams in mask_net.parameters():
                fparams.requires_grad = False
        if args.fix_posenet:
            for fparams in pose_net.parameters():
                fparams.requires_grad = False
        if args.fix_dispnet:
            for fparams in disp_net.parameters():
                fparams.requires_grad = False

        if args.log_terminal:
            logger.epoch_bar.update(epoch)
            logger.reset_train_bar()

        # validation data
        flow_error_names = ['no']
        flow_errors = [0]
        errors = [0]
        error_names = ['no error names depth']
        print('\nepoch [{}/{}]\n'.format(epoch + 1, args.epochs))

        # 3.2 train for one epoch
        train_loss = 0
        train_loss = train(train_loader, disp_net, pose_net, mask_net, flow_net,
                           optimizer, logger, training_writer)

        # 3.3 evaluate on validation set
        # if args.without_gt:
        #     validate_without_gt(val_loader, disp_net, pose_net, mask_net, epoch, logger, output_writers)
        if args.with_flow_gt:
            flow_errors, flow_error_names = validate_flow_with_gt(
                val_flow_loader, disp_net, pose_net, mask_net, flow_net, epoch,
                logger, output_writers)
            for error, name in zip(flow_errors, flow_error_names):
                training_writer.add_scalar(name, error, epoch)

        if args.with_depth_gt:
            depth_errors, depth_error_names = validate_depth_with_gt(
                val_loader, disp_net, epoch, logger, output_writers)
            error_string = ', '.join('{} : {:.3f}'.format(name, error)
                                     for name, error in zip(depth_error_names, depth_errors))
            if args.log_terminal:
                logger.valid_writer.write(' * Avg {}'.format(error_string))
            else:
                print('Epoch {} completed'.format(epoch))
            for error, name in zip(depth_errors, depth_error_names):
                training_writer.add_scalar(name, error, epoch)

        if args.without_gt:
            val_loss = 1

        # 3.4 Up to you to choose the most relevant error to measure your model's
        # performance; careful, some measures are to be maximized (such as a1, a2, a3)
        if not args.fix_posenet:
            decisive_error = 0  # flow_errors[-2]  # epe_rigid_with_gt_mask
        elif not args.fix_dispnet:
            decisive_error = 0  # errors[0]  # depth abs_diff
        elif not args.fix_flownet:
            decisive_error = 0  # flow_errors[-1]  # epe_non_rigid_with_gt_mask
        elif not args.fix_masknet:
            decisive_error = 0  # flow_errors[3]  # percent outliers
        else:
            decisive_error = 0  # all sub-networks frozen

        # 3.5 log
        if args.log_terminal:
            logger.train_writer.write(' * Avg Loss : {:.3f}'.format(train_loss))
            logger.reset_valid_bar()

        # 3.6 save model, remember lowest error and save checkpoint
        if best_error < 0:
            best_error = decisive_error
        is_best = decisive_error <= best_error
        best_error = min(best_error, decisive_error)
        save_checkpoint(args.save_path, {
            'epoch': epoch + 1,
            'state_dict': disp_net.module.state_dict()
        }, {
            'epoch': epoch + 1,
            'state_dict': pose_net.module.state_dict()
        }, {
            'epoch': epoch + 1,
            'state_dict': mask_net.module.state_dict()
        }, {
            'epoch': epoch + 1,
            'state_dict': flow_net.module.state_dict()
        }, {
            'epoch': epoch + 1,
            'state_dict': optimizer.state_dict()
        }, is_best)

        with open(args.save_path / args.log_summary, 'a') as csvfile:
            writer = csv.writer(csvfile, delimiter='\t')
            writer.writerow([train_loss, decisive_error])

    if args.log_terminal:
        logger.epoch_bar.finish()
def main():
    global best_error, n_iter, device
    args = parser.parse_args()
    if args.dataset_format == 'stacked':
        from datasets.stacked_sequence_folders import SequenceFolder
    elif args.dataset_format == 'sequential':
        from datasets.sequence_folders import SequenceFolder
    save_path = save_path_formatter(args, parser)
    args.save_path = 'checkpoints' / save_path
    print('=> will save everything to {}'.format(args.save_path))
    args.save_path.makedirs_p()  # create the folder if missing, do nothing otherwise (path.py helper)
    torch.manual_seed(args.seed)
    if args.evaluate:
        args.epochs = 0

    # TensorBoard SummaryWriters
    training_writer = SummaryWriter(args.save_path)
    output_writers = []
    if args.log_output:
        for i in range(3):
            output_writers.append(SummaryWriter(args.save_path / 'valid' / str(i)))

    # Data loading code
    normalize = custom_transforms.Normalize(mean=[0.5, 0.5, 0.5],
                                            std=[0.5, 0.5, 0.5])
    train_transform = custom_transforms.Compose([
        custom_transforms.RandomHorizontalFlip(),
        custom_transforms.RandomScaleCrop(),
        custom_transforms.ArrayToTensor(),
        normalize
    ])
    valid_transform = custom_transforms.Compose(
        [custom_transforms.ArrayToTensor(), normalize])

    print("=> fetching scenes in '{}'".format(args.data))
    train_set = SequenceFolder(
        args.data,                  # preprocessed training data
        transform=train_transform,  # composed transforms applied to each sample
        seed=args.seed,
        train=True,
        sequence_length=args.sequence_length)

    # if no ground truth is available, the validation set is
    # the same type as the training set, so the photometric loss from warping can be measured
    if args.with_gt:
        from datasets.validation_folders import ValidationSet
        val_set = ValidationSet(args.data, transform=valid_transform)
    else:
        val_set = SequenceFolder(
            args.data,
            transform=valid_transform,
            seed=args.seed,
            train=False,
            sequence_length=args.sequence_length,
        )
    # both splits are monocular sequences; no stereo pairs are needed
    print('{} samples found in {} train scenes'.format(len(train_set), len(train_set.scenes)))
    print('{} samples found in {} valid scenes'.format(len(val_set), len(val_set.scenes)))

    train_loader = torch.utils.data.DataLoader(
        dataset=train_set,            # each batch: [tensor(B,3,H,W), list(B), (B,H,W), (b,h,w)]
        batch_size=args.batch_size,
        shuffle=True,                 # shuffle training samples
        num_workers=args.workers,     # multi-process data loading
        pin_memory=True)
    val_loader = torch.utils.data.DataLoader(
        dataset=val_set,
        batch_size=args.batch_size,
        shuffle=False,                # keep validation order fixed
        num_workers=args.workers,
        pin_memory=True)

    if args.epoch_size == 0:
        args.epoch_size = len(train_loader)

    # create model
    print("=> creating model")

    # disp
    disp_net = models.DispNetS().to(device)
    output_exp = args.mask_loss_weight > 0
    if not output_exp:
        print("=> no mask loss, PoseExpnet will only output pose")

    # pose
    pose_exp_net = models.PoseExpNet(
        nb_ref_imgs=args.sequence_length - 1,
        output_exp=args.mask_loss_weight > 0).to(device)

    # init PoseExpNet
    if args.pretrained_exp_pose:
        print("=> using pre-trained weights for explainability and pose net")
        weights = torch.load(args.pretrained_exp_pose)
        pose_exp_net.load_state_dict(weights['state_dict'], strict=False)
    else:
        pose_exp_net.init_weights()

    # init DispNet
    if args.pretrained_disp:
        print("=> using pre-trained weights for DispNet")
        weights = torch.load(args.pretrained_disp)
        disp_net.load_state_dict(weights['state_dict'])
    else:
        disp_net.init_weights()

    cudnn.benchmark = True
    disp_net = torch.nn.DataParallel(disp_net)
    pose_exp_net = torch.nn.DataParallel(pose_exp_net)

    print('=> setting adam solver')
    # both networks are optimized jointly
    optim_params = [{
        'params': disp_net.parameters(), 'lr': args.lr
    }, {
        'params': pose_exp_net.parameters(), 'lr': args.lr
    }]
    optimizer = torch.optim.Adam(optim_params,
                                 betas=(args.momentum, args.beta),
                                 weight_decay=args.weight_decay)

    # write training results to csv
    with open(args.save_path / args.log_summary, 'w') as csvfile:
        writer = csv.writer(csvfile, delimiter='\t')
        writer.writerow(['train_loss', 'validation_loss'])

    with open(args.save_path / args.log_full, 'w') as csvfile:
        writer = csv.writer(csvfile, delimiter='\t')
        writer.writerow(['train_loss', 'photo_loss', 'explainability_loss', 'smooth_loss'])

    n_epochs = args.epochs
    train_size = min(len(train_loader), args.epoch_size)
    valid_size = len(val_loader)
    logger = TermLogger(n_epochs=n_epochs,
                        train_size=train_size,
                        valid_size=valid_size)
    logger.epoch_bar.start()

    if args.pretrained_disp or args.evaluate:
        logger.reset_valid_bar()
        if args.with_gt:
            errors, error_names = validate_with_gt(args, val_loader, disp_net, 0,
                                                   logger, output_writers)
        else:
            errors, error_names = validate_without_gt(args, val_loader, disp_net,
                                                      pose_exp_net, 0, logger,
                                                      output_writers)
        # record the per-epoch validation metrics (['Total loss', 'Photo loss', 'Exp loss'])
        for error, name in zip(errors, error_names):
            training_writer.add_scalar(name, error, 0)
        error_string = ', '.join('{} : {:.3f}'.format(name, error)
                                 for name, error in zip(error_names[2:9], errors[2:9]))
        logger.valid_writer.write(' * Avg {}'.format(error_string))

    # main cycle
    for epoch in range(args.epochs):
        logger.epoch_bar.update(epoch)
        logger.reset_train_bar()

        # 1. train for one epoch (logger is the custom TermLogger defined above)
        train_loss = train(args, train_loader, disp_net, pose_exp_net, optimizer,
                           args.epoch_size, logger, training_writer)
        logger.train_writer.write(' * Avg Loss : {:.3f}'.format(train_loss))
        logger.reset_valid_bar()

        # 2. validate on validation set
        if args.with_gt:
            # returns ['Total loss', 'Photo loss', 'Exp loss']
            errors, error_names = validate_with_gt(args, val_loader, disp_net,
                                                   epoch, logger, output_writers)
        else:
            errors, error_names = validate_without_gt(args, val_loader, disp_net,
                                                      pose_exp_net, epoch, logger,
                                                      output_writers)
        error_string = ', '.join('{} : {:.3f}'.format(name, error)
                                 for name, error in zip(error_names, errors))
        logger.valid_writer.write(' * Avg {}'.format(error_string))

        # record the per-epoch validation metrics in TensorBoard
        for error, name in zip(errors, error_names):
            training_writer.add_scalar(name, error, epoch)

        # Up to you to choose the most relevant error to measure
        # your model's performance; careful, some measures are to be maximized (such as a1, a2, a3)

        # 3. remember lowest error and save checkpoint
        decisive_error = errors[1]
        if best_error < 0:
            best_error = decisive_error
        is_best = decisive_error < best_error
        best_error = min(best_error, decisive_error)
        # save the models
        save_checkpoint(args.save_path, {
            'epoch': epoch + 1,
            'state_dict': disp_net.module.state_dict()
        }, {
            'epoch': epoch + 1,
            'state_dict': pose_exp_net.module.state_dict()
        }, is_best)

        # append one row per epoch; the second value is the decisive validation error
        with open(args.save_path / args.log_summary, 'a') as csvfile:
            writer = csv.writer(csvfile, delimiter='\t')
            writer.writerow([train_loss, decisive_error])

    logger.epoch_bar.finish()
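# --- Hedged example (not part of the original script) ------------------------
# Sketch of reloading a DispNet checkpoint written by the training loop above
# for inference. It relies on the torch / models imports already present at
# module level; the checkpoint file name passed in is up to the caller and the
# 'state_dict' key matches what save_checkpoint is given above.
def load_dispnet_for_inference(checkpoint_path):
    weights = torch.load(checkpoint_path)
    disp_net = models.DispNetS()
    disp_net.load_state_dict(weights['state_dict'])
    return disp_net.eval()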
def main():
    global args, best_error, n_iter
    args = parser.parse_args()
    save_path = Path(args.name)
    args.save_path = 'checkpoints' / save_path  # /timestamp
    print('=> will save everything to {}'.format(args.save_path))
    args.save_path.makedirs_p()
    torch.manual_seed(args.seed)

    training_writer = SummaryWriter(args.save_path)
    output_writer = SummaryWriter(args.save_path / 'valid')

    # Data loading code
    flow_loader_h, flow_loader_w = 384, 1280
    train_transform = custom_transforms.Compose([
        custom_transforms.RandomHorizontalFlip(),
        custom_transforms.RandomScaleCrop(h=256, w=256),
        custom_transforms.ArrayToTensor(),
    ])
    valid_transform = custom_transforms.Compose([
        custom_transforms.Scale(h=flow_loader_h, w=flow_loader_w),
        custom_transforms.ArrayToTensor()
    ])

    print("=> fetching scenes in '{}'".format(args.data))
    train_set = SequenceFolder(args.data,
                               transform=train_transform,
                               seed=args.seed,
                               train=True,
                               sequence_length=3)

    if args.valset == "kitti2015":
        from datasets.validation_flow import ValidationFlowKitti2015
        val_set = ValidationFlowKitti2015(root=args.kitti_data,
                                          transform=valid_transform)
    elif args.valset == "kitti2012":
        from datasets.validation_flow import ValidationFlowKitti2012
        val_set = ValidationFlowKitti2012(root=args.kitti_data,
                                          transform=valid_transform)

    if args.DEBUG:
        train_set.__len__ = 32
        train_set.samples = train_set.samples[:32]

    print('{} samples found in {} train scenes'.format(len(train_set), len(train_set.scenes)))
    print('{} samples found in valid scenes'.format(len(val_set)))

    train_loader = torch.utils.data.DataLoader(train_set,
                                               batch_size=1,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               drop_last=True)
    val_loader = torch.utils.data.DataLoader(
        val_set,
        batch_size=1,  # batch size is 1 since images in KITTI have different sizes
        shuffle=False,
        num_workers=args.workers,
        pin_memory=True,
        drop_last=True)

    if args.epoch_size == 0:
        args.epoch_size = len(train_loader)

    # create model
    print("=> creating model")
    if args.flownet == 'SpyNet':
        flow_net = getattr(models, args.flownet)(nlevels=6, pretrained=True)
    elif args.flownet == 'Back2Future':
        flow_net = getattr(models, args.flownet)(pretrained='pretrained/b2f_rm_hard.pth.tar')
    elif args.flownet == 'PWCNet':
        flow_net = models.pwc_dc_net('pretrained/pwc_net_chairs.pth.tar')  # pwc_net.pth.tar
    else:
        flow_net = getattr(models, args.flownet)()

    if args.flownet in ['SpyNet', 'Back2Future', 'PWCNet']:
        print("=> using pre-trained weights for " + args.flownet)
    elif args.flownet in ['FlowNetC']:
        print("=> using pre-trained weights for FlowNetC")
        weights = torch.load('pretrained/FlowNet2-C_checkpoint.pth.tar')
        flow_net.load_state_dict(weights['state_dict'])
    elif args.flownet in ['FlowNetS']:
        print("=> using pre-trained weights for FlowNetS")
        weights = torch.load('pretrained/flownets.pth.tar')
        flow_net.load_state_dict(weights['state_dict'])
    elif args.flownet in ['FlowNet2']:
        print("=> using pre-trained weights for FlowNet2")
        weights = torch.load('pretrained/FlowNet2_checkpoint.pth.tar')
        flow_net.load_state_dict(weights['state_dict'])
    else:
        flow_net.init_weights()

    pytorch_total_params = sum(p.numel() for p in flow_net.parameters())
    print("Number of model parameters: " + str(pytorch_total_params))

    flow_net = flow_net.cuda()
    cudnn.benchmark = True

    if args.patch_type == 'circle':
        patch, mask, patch_shape = init_patch_circle(args.image_size, args.patch_size)
        patch_init = patch.copy()
    elif args.patch_type == 'square':
        patch, patch_shape = init_patch_square(args.image_size, args.patch_size)
        patch_init = patch.copy()
        mask = np.ones(patch_shape)
    else:
        sys.exit("Please choose a square or circle patch")

    if args.patch_path:
        patch, mask, patch_shape = init_patch_from_image(
            args.patch_path, args.mask_path, args.image_size, args.patch_size)
        patch_init = patch.copy()

    if args.log_terminal:
        logger = TermLogger(n_epochs=args.epochs,
                            train_size=min(len(train_loader), args.epoch_size),
                            valid_size=len(val_loader),
                            attack_size=args.max_count)
        logger.epoch_bar.start()
    else:
        logger = None

    for epoch in range(args.epochs):
        if args.log_terminal:
            logger.epoch_bar.update(epoch)
            logger.reset_train_bar()

        # train for one epoch
        patch, mask, patch_init, patch_shape = train(patch, mask, patch_init,
                                                     patch_shape, train_loader,
                                                     flow_net, epoch, logger,
                                                     training_writer)

        # validate
        errors, error_names = validate_flow_with_gt(patch, mask, patch_shape,
                                                    val_loader, flow_net, epoch,
                                                    logger, output_writer)
        error_string = ', '.join('{} : {:.3f}'.format(name, error)
                                 for name, error in zip(error_names, errors))

        if args.log_terminal:
            logger.valid_writer.write(' * Avg {}'.format(error_string))
        else:
            print('Epoch {} completed'.format(epoch))

        for error, name in zip(errors, error_names):
            training_writer.add_scalar(name, error, epoch)

        torch.save(patch, args.save_path / 'patch_epoch_{}'.format(str(epoch)))

    if args.log_terminal:
        logger.epoch_bar.finish()
def main():
    global best_error, n_iter, device
    args = parser.parse_args()
    save_path = save_path_formatter(args, parser)
    args.save_path = 'checkpoints' / save_path
    print('=> will save everything to {}'.format(args.save_path))
    args.save_path.makedirs_p()
    torch.manual_seed(args.seed)
    if args.evaluate:
        args.epochs = 0

    training_writer = SummaryWriter(args.save_path)
    output_writers = []
    if args.log_output:
        for i in range(3):
            output_writers.append(SummaryWriter(args.save_path / 'valid' / str(i)))

    # Data loading code
    # normalization matching the other training scripts (was undefined here)
    normalize = custom_transforms.Normalize(mean=[0.5, 0.5, 0.5],
                                            std=[0.5, 0.5, 0.5])
    train_transform = custom_transforms.Compose([
        custom_transforms.RandomHorizontalFlip(),
        custom_transforms.RandomScaleCrop(),
        custom_transforms.ArrayToTensor()
    ])
    valid_transform = custom_transforms.Compose(
        [custom_transforms.ArrayToTensor(), normalize])

    print("=> fetching scenes in '{}'".format(args.data))
    # NOTE: root_dir and sequences are expected to be defined elsewhere
    # (e.g. derived from args.data and a train/val split)
    train_set = KITTIDataset(root_dir,
                             sequences,
                             max_distance=args.max_distance,
                             transform=None)
    val_set = KITTIDataset(root_dir,
                           sequences,
                           max_distance=args.max_distance,
                           transform=None)
    print('{} samples found in train scenes'.format(len(train_set)))
    print('{} samples found in valid scenes'.format(len(val_set)))

    train_loader = torch.utils.data.DataLoader(train_set,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True)
    val_loader = torch.utils.data.DataLoader(val_set,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    if args.epoch_size == 0:
        args.epoch_size = len(train_loader)

    # create model
    print("=> creating model")
    config_file = "./configs/e2e_mask_rcnn_R_50_FPN_1x.yaml"
    cfg.merge_from_file(config_file)
    cfg.freeze()
    pretrained_model_path = "./e2e_mask_rcnn_R_50_FPN_1x.pth"
    disvo = DISVO(cfg, pretrained_model_path).cuda()

    if args.pretrained_disvo:
        print("=> using pre-trained weights for DISVO")
        weights = torch.load(args.pretrained_disvo)
        disvo.load_state_dict(weights['state_dict'])
    else:
        disvo.init_weights()

    cudnn.benchmark = True

    print('=> setting adam solver')
    optim_params = [
        {'params': disvo.parameters(), 'lr': args.lr}
    ]
    optimizer = torch.optim.Adam(optim_params,
                                 betas=(args.momentum, args.beta),
                                 weight_decay=args.weight_decay)

    with open(args.save_path / args.log_summary, 'w') as csvfile:
        writer = csv.writer(csvfile, delimiter='\t')
        writer.writerow(['train_loss', 'validation_loss'])

    with open(args.save_path / args.log_full, 'w') as csvfile:
        writer = csv.writer(csvfile, delimiter='\t')
        writer.writerow(['train_loss'])

    logger = TermLogger(n_epochs=args.epochs,
                        train_size=min(len(train_loader), args.epoch_size),
                        valid_size=len(val_loader))
    logger.epoch_bar.start()

    if args.pretrained_disvo or args.evaluate:
        logger.reset_valid_bar()
        errors, error_names = validate(args, val_loader, disvo, 0, logger, output_writers)
        for error, name in zip(errors, error_names):
            training_writer.add_scalar(name, error, 0)
        error_string = ', '.join('{} : {:.3f}'.format(name, error)
                                 for name, error in zip(error_names[2:9], errors[2:9]))
        logger.valid_writer.write(' * Avg {}'.format(error_string))

    for epoch in range(args.epochs):
        logger.epoch_bar.update(epoch)

        # train for one epoch
        logger.reset_train_bar()
        train_loss = train(args, train_loader, disvo, optimizer, args.epoch_size,
                           logger, training_writer)
        logger.train_writer.write(' * Avg Loss : {:.3f}'.format(train_loss))

        # evaluate on validation set
        logger.reset_valid_bar()
        errors, error_names = validate(args, val_loader, disvo, epoch, logger, output_writers)
        error_string = ', '.join('{} : {:.3f}'.format(name, error)
                                 for name, error in zip(error_names, errors))
        logger.valid_writer.write(' * Avg {}'.format(error_string))

        for error, name in zip(errors, error_names):
            training_writer.add_scalar(name, error, epoch)

        # Up to you to choose the most relevant error to measure your model's
        # performance; careful, some measures are to be maximized (such as a1, a2, a3)
        decisive_error = errors[1]
        if best_error < 0:
            best_error = decisive_error

        # remember lowest error and save checkpoint
        is_best = decisive_error < best_error
        best_error = min(best_error, decisive_error)
        save_checkpoint(
            args.save_path, {
                'epoch': epoch + 1,
                'state_dict': disvo.state_dict()  # disvo is not wrapped in DataParallel, so no .module
            }, is_best)

        with open(args.save_path / args.log_summary, 'a') as csvfile:
            writer = csv.writer(csvfile, delimiter='\t')
            writer.writerow([train_loss, decisive_error])

    logger.epoch_bar.finish()
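# --- Hedged example (not part of the original scripts) -----------------------
# All the training loops above append one tab-separated row per epoch to
# args.log_summary. A minimal sketch of reading that file back, e.g. for
# plotting train loss against the decisive validation error; the header row
# written at the start of training is kept separately.
def read_log_summary(log_path):
    import csv
    with open(log_path) as csvfile:
        reader = csv.reader(csvfile, delimiter='\t')
        header = next(reader)
        rows = [[float(value) for value in row] for row in reader]
    return header, rows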