def dataflow_test():
    from DataFlow.sequence_folders import SequenceFolder
    from DataFlow.validation_folders import ValidationSet
    from torch.utils.data import DataLoader
    import custom_transforms

    normalize = custom_transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    train_transform = custom_transforms.Compose([custom_transforms.RandomHorizontalFlip(),
                                                 custom_transforms.RandomScaleCrop(),
                                                 custom_transforms.ArrayToTensor(),
                                                 normalize])
    valid_transform = custom_transforms.Compose([custom_transforms.ArrayToTensor(), normalize])

    datapath = 'G:/data/KITTI/KittiRaw_formatted'
    seed = 8964
    train_set = SequenceFolder(datapath, transform=train_transform, seed=seed,
                               train=True, sequence_length=3)
    train_loader = DataLoader(train_set, batch_size=4, shuffle=True,
                              num_workers=4, pin_memory=True)
    val_set = ValidationSet(datapath, transform=valid_transform)
    print("length of train loader is %d" % len(train_loader))
    val_loader = DataLoader(val_set, batch_size=4, shuffle=False,
                            num_workers=4, pin_memory=True)
    print("length of val loader is %d" % len(val_loader))

    dataiter = iter(train_loader)
    imgs, intrinsics = next(dataiter)
    print(len(imgs))
    print(intrinsics.shape)
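# The SequenceFolder pipelines in these snippets compose transforms that receive a list of
# numpy frames together with the camera intrinsics, which is why ArrayToTensor appears before
# Normalize and why the loader above yields (imgs, intrinsics). A minimal sketch of that
# convention (hypothetical, simplified versions of Compose and ArrayToTensor; the actual
# classes in these repositories may take extra options):
import numpy as np
import torch


class ArrayToTensor(object):
    """Convert a list of HWC numpy images to CHW float tensors; pass intrinsics through."""
    def __call__(self, images, intrinsics):
        tensors = [torch.from_numpy(im.transpose(2, 0, 1)).float() for im in images]
        return tensors, intrinsics


class Compose(object):
    """Apply each transform to the (images, intrinsics) pair in turn."""
    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, images, intrinsics):
        for t in self.transforms:
            images, intrinsics = t(images, intrinsics)
        return images, intrinsics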
def main():
    global args, best_error, n_iter
    args = parser.parse_args()

    normalize = custom_transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    train_transform = custom_transforms.Compose([
        custom_transforms.RandomRotate(),
        custom_transforms.RandomHorizontalFlip(),
        custom_transforms.RandomScaleCrop(),
        custom_transforms.ArrayToTensor(),
        normalize
    ])
    train_set = SequenceFolder(args.data, transform=train_transform, seed=args.seed,
                               train=True, sequence_length=args.sequence_length)
    train_loader = torch.utils.data.DataLoader(train_set, batch_size=4, shuffle=True,
                                               num_workers=4, pin_memory=True, drop_last=True)
    print(len(train_loader))
def transform_tr(self, sample):
    composed_transforms = transforms.Compose([
        tr.RandomHorizontalFlip(),
        tr.FixedResize(size=self.args['crop_size']),
        tr.Normalize(mean=self.mean, std=self.std),
        tr.ToTensor()
    ])
    return composed_transforms(sample)
def get_train_transforms(normalize):
    train_transforms = []
    train_transforms.append(transforms.Scale(160))
    train_transforms.append(transforms.RandomHorizontalFlip())
    train_transforms.append(transforms.RandomColor(0.15))
    train_transforms.append(transforms.RandomRotate(15))
    train_transforms.append(transforms.RandomSizedCrop(128))
    train_transforms.append(transforms.ToTensor())
    train_transforms.append(normalize)
    train_transforms = transforms.Compose(train_transforms)
    return train_transforms
def transforms_train_esp(self, sample):
    composed_transforms = transforms.Compose([
        tr.RandomVerticalFlip(),
        tr.RandomHorizontalFlip(),
        tr.RandomAffine(degrees=40, scale=(.9, 1.1), shear=30),
        tr.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5),
        tr.FixedResize(size=self.input_size),
        tr.Normalize(mean=[x / 255.0 for x in [125.3, 123.0, 113.9]],
                     std=[x / 255.0 for x in [63.0, 62.1, 66.7]]),
        tr.ToTensor()
    ])
    return composed_transforms(sample)
def transform_pair_train(self, sample):
    composed_transforms = transforms.Compose([
        tr.RandomHorizontalFlip(),
        tr.RandomScaleCrop(base_size=400, crop_size=400, fill=0),
        tr.HorizontalFlip(),
        tr.GaussianBlur(),
        tr.Normalize(mean=self.source_dist['mean'], std=self.source_dist['std'], if_pair=True),
        tr.ToTensor(if_pair=True),
    ])
    return composed_transforms(sample)
def transform_tr(self, sample):
    composed_transforms = transforms.Compose([
        tr.RandomHorizontalFlip(),                                               # random horizontal flip
        tr.RandomScaleCrop(base_size=self.args.base_size,
                           crop_size=self.args.crop_size),                       # random scale and crop
        tr.RandomGaussianBlur(),                                                 # random Gaussian blur
        tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),     # normalize
        tr.ToTensor()
    ])
    return composed_transforms(sample)
def transform(self, sample):
    if self.train:
        composed_transforms = transforms.Compose([
            tr.RandomHorizontalFlip(),
            tr.RandomVerticalFlip(),
            tr.RandomScaleCrop(),
            tr.ToTensor()
        ])
    else:
        composed_transforms = transforms.Compose([
            tr.ToTensor()
        ])
    return composed_transforms(sample)
def transform_tr(self, sample):
    if not self.random_match:
        composed_transforms = transforms.Compose([
            tr.RandomHorizontalFlip(),
            tr.RandomScaleCrop(base_size=400, crop_size=400, fill=0),
            # tr.Remap(self.building_table, self.nonbuilding_table, self.channels)
            tr.RandomGaussianBlur(),
            # tr.ConvertFromInts(),
            # tr.PhotometricDistort(),
            tr.Normalize(mean=self.source_dist['mean'], std=self.source_dist['std']),
            tr.ToTensor(),
        ])
    else:
        composed_transforms = transforms.Compose([
            tr.HistogramMatching(),
            tr.RandomHorizontalFlip(),
            tr.RandomScaleCrop(base_size=400, crop_size=400, fill=0),
            tr.RandomGaussianBlur(),
            tr.Normalize(mean=self.source_dist['mean'], std=self.source_dist['std']),
            tr.ToTensor(),
        ])
    return composed_transforms(sample)
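# The transform_tr-style pipelines above pass the whole sample through transforms.Compose,
# so each tr.* class must accept and return the sample rather than a single image.
# A minimal sketch of that convention (assuming the common {'image': PIL.Image,
# 'label': PIL.Image} sample layout; the exact keys and classes in these repositories may differ):
import random
import numpy as np
import torch
from PIL import Image


class RandomHorizontalFlip(object):
    """Flip image and label together so they stay aligned."""
    def __call__(self, sample):
        img, mask = sample['image'], sample['label']
        if random.random() < 0.5:
            img = img.transpose(Image.FLIP_LEFT_RIGHT)
            mask = mask.transpose(Image.FLIP_LEFT_RIGHT)
        return {'image': img, 'label': mask}


class ToTensor(object):
    """Convert the HWC image to a CHW float tensor and the label to a float tensor."""
    def __call__(self, sample):
        img = np.array(sample['image']).astype(np.float32).transpose((2, 0, 1))
        mask = np.array(sample['label']).astype(np.float32)
        return {'image': torch.from_numpy(img), 'label': torch.from_numpy(mask)}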
def main(args):
    # parse args
    best_acc1 = 0.0

    if args.gpu >= 0:
        print("Use GPU: {}".format(args.gpu))
    else:
        print('You are using CPU for computing!',
              'Yet we assume you are using a GPU.',
              'You will NOT be able to switch between CPU and GPU training!')

    # set up the model + loss
    if args.use_custom_conv:
        print("Using custom convolutions in the network")
        model = default_model(conv_op=CustomConv2d, num_classes=100)
    else:
        model = default_model(num_classes=100)
    model_arch = "simplenet"
    # model_arch = "simplenet_batchnorm2d"
    # model_arch = "resnet18"
    # model_arch = "resnet34"
    criterion = nn.CrossEntropyLoss()
    # put everything on the GPU
    if args.gpu >= 0:
        model = model.cuda(args.gpu)
        criterion = criterion.cuda(args.gpu)

    # set up the optimizer
    optimizer = torch.optim.SGD(model.parameters(), args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # resume from a checkpoint?
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_acc1 = checkpoint['best_acc1']
            model.load_state_dict(checkpoint['state_dict'])
            if args.gpu < 0:
                model = model.cpu()
            else:
                model = model.cuda(args.gpu)
            # only load the optimizer if necessary
            if (not args.evaluate) and (not args.attack):
                optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {}, acc1 {})"
                  .format(args.resume, checkpoint['epoch'], best_acc1))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    # set up transforms for data augmentation
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    train_transforms = []
    train_transforms.append(transforms.Scale(160))
    train_transforms.append(transforms.RandomHorizontalFlip())
    train_transforms.append(transforms.RandomColor(0.15))
    train_transforms.append(transforms.RandomRotate(15))
    train_transforms.append(transforms.RandomSizedCrop(128))
    train_transforms.append(transforms.ToTensor())
    train_transforms.append(normalize)
    train_transforms = transforms.Compose(train_transforms)
    # val transforms
    val_transforms = []
    val_transforms.append(transforms.Scale(160, interpolations=None))
    val_transforms.append(transforms.ToTensor())
    val_transforms.append(normalize)
    val_transforms = transforms.Compose(val_transforms)
    if (not args.evaluate) and (not args.attack):
        print("Training time data augmentations:")
        print(train_transforms)

    # set up dataset and dataloader
    train_dataset = MiniPlacesLoader(args.data_folder, split='train', transforms=train_transforms)
    val_dataset = MiniPlacesLoader(args.data_folder, split='val', transforms=val_transforms)
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=True,
        num_workers=args.workers, pin_memory=True, sampler=None, drop_last=True)
    val_loader = torch.utils.data.DataLoader(
        val_dataset, batch_size=100, shuffle=False,
        num_workers=args.workers, pin_memory=True, sampler=None, drop_last=False)

    # testing only
    if (args.evaluate == args.attack) and args.evaluate:
        print("Can't set evaluate and attack to True at the same time!")
        return

    # set up visualizer
    if args.vis:
        visualizer = default_attention(criterion)
    else:
        visualizer = None

    # evaluation
    if args.resume and args.evaluate:
        print("Testing the model ...")
        cudnn.deterministic = True
        validate(val_loader, model, -1, args, visualizer=visualizer)
        return

    # attack
    if args.resume and args.attack:
        print("Generating adversarial samples for the model ..")
        cudnn.deterministic = True
        validate(val_loader, model, -1, args,
                 attacker=default_attack(criterion),
                 visualizer=visualizer)
        return

    # enable cudnn benchmark
    cudnn.enabled = True
    cudnn.benchmark = True

    # warm up the training
    if (args.start_epoch == 0) and (args.warmup_epochs > 0):
        print("Warmup the training ...")
        for epoch in range(0, args.warmup_epochs):
            train(train_loader, model, criterion, optimizer, epoch, "warmup", args)

    # start the training
    print("Training the model ...")
    for epoch in range(args.start_epoch, args.epochs):
        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch, "train", args)
        # evaluate on validation set
        acc1 = validate(val_loader, model, epoch, args)
        # remember best acc@1 and save checkpoint
        is_best = acc1 > best_acc1
        best_acc1 = max(acc1, best_acc1)
        save_checkpoint({
            'epoch': epoch + 1,
            'model_arch': model_arch,
            'state_dict': model.state_dict(),
            'best_acc1': best_acc1,
            'optimizer': optimizer.state_dict(),
        }, is_best)
        _target = Image.open(self.categories[index])
        return _img, _target

    def __str__(self):
        return 'ISPRS(split=' + str(self.split) + ')'


if __name__ == '__main__':
    import custom_transforms as tr
    from utils import decode_segmap
    from torch.utils.data import DataLoader
    from torchvision import transforms
    import matplotlib.pyplot as plt

    composed_transforms_tr = transforms.Compose([tr.RandomHorizontalFlip(),
                                                 tr.RandomCrop(512),
                                                 tr.RandomRotate(15),
                                                 tr.ToTensor()])

    isprs_train = ISPRSSegmentation(split='train', transform=composed_transforms_tr)
    dataloader = DataLoader(isprs_train, batch_size=5, shuffle=True, num_workers=2)

    for ii, sample in enumerate(dataloader):
        for jj in range(sample["image"].size()[0]):
            img = sample['image'].numpy()
            gt = sample['label'].numpy()
            tmp = np.array(gt[jj]).astype(np.uint8)
            tmp = np.squeeze(tmp, axis=0)
    plt.subplot(1, 2, 1)
    plt.imshow(img)
    plt.subplot(1, 2, 2)
    plt.imshow(np.array(gt) * 255)  # show object ID
    plt.tight_layout()
    plt.show()


if __name__ == "__main__":
    # delete()
    # torch_VOC()

    # 1. Show the transforms' result before converting to tensors
    # executed on PIL images
    data_transforms = transforms.Compose([
        custom_transforms.RandomHorizontalFlip(1.0),
        custom_transforms.RandomRotation((-20, 20), scales=(0.75, 1.0)),
        custom_transforms.to_numpy()
    ])
    # executed on numpy images
    handcraft_transforms = custom_transforms.Compose_dict([
        custom_transforms.ObjectCenterCrop(),
        custom_transforms.Fix_size(size=512),
        custom_transforms.get_extreme_point_channel(),
        custom_transforms.concat_inputs()
    ])
    voc = VOC_Instance_Segmentation(split_sets=['train'], transform=data_transforms,
                                    transform_handcraft=handcraft_transforms)
    for ii, dct in enumerate(voc):
    start_epoch = 0
    print('===> Start from scratch')

if cuda:
    model.cuda()
    cudnn.benchmark = True

#%%
parser = argparse.ArgumentParser()
args = parser.parse_args()
args.base_size = 513
args.crop_size = 513

composed_transforms = transforms.Compose([
    tr.RandomHorizontalFlip(),                                                # random horizontal flip
    tr.RandomScaleCrop(base_size=args.base_size, crop_size=args.crop_size),  # random scale and crop
    tr.RandomGaussianBlur(),                                                  # random Gaussian blur
    tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),      # normalize
    tr.ToTensor()
])

#%%
import matplotlib.pyplot as plt
import numpy as np
# from dataset.utils import decode_segmap

tbar = tqdm(dataloader)
num_img_tr = len(dataloader)
for epoch in range(0, 10):
def main():
    global best_error, n_iter, device
    args = parser.parse_args()

    timestamp = datetime.datetime.now().strftime("%m-%d-%H:%M")
    save_path = Path(args.name)
    args.save_path = 'checkpoints' / save_path / timestamp
    print('=> will save everything to {}'.format(args.save_path))
    args.save_path.makedirs_p()

    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    cudnn.deterministic = True
    cudnn.benchmark = True

    training_writer = SummaryWriter(args.save_path)
    output_writers = []
    if args.log_output:
        for i in range(3):
            output_writers.append(SummaryWriter(args.save_path / 'valid' / str(i)))

    # Data loading code
    normalize = custom_transforms.Normalize(mean=[0.45, 0.45, 0.45],
                                            std=[0.225, 0.225, 0.225])
    train_transform = custom_transforms.Compose([
        custom_transforms.RandomHorizontalFlip(),
        custom_transforms.RandomScaleCrop(),
        custom_transforms.ArrayToTensor(),
        normalize
    ])
    valid_transform = custom_transforms.Compose(
        [custom_transforms.ArrayToTensor(), normalize])

    print("=> fetching scenes in '{}'".format(args.data))
    train_set = SequenceFolder(args.data, transform=train_transform, seed=args.seed,
                               train=True, sequence_length=args.sequence_length)
    # If no ground truth is available, the validation set is the same type as the
    # training set, so the photometric warping loss is measured instead.
    if args.with_gt:
        from datasets.validation_folders import ValidationSet
        val_set = ValidationSet(args.data, transform=valid_transform)
    else:
        val_set = SequenceFolder(args.data, transform=valid_transform, seed=args.seed,
                                 train=False, sequence_length=args.sequence_length)
    print('{} samples found in {} train scenes'.format(len(train_set), len(train_set.scenes)))
    print('{} samples found in {} valid scenes'.format(len(val_set), len(val_set.scenes)))

    train_loader = torch.utils.data.DataLoader(train_set, batch_size=args.batch_size,
                                               shuffle=True, num_workers=args.workers,
                                               pin_memory=True)
    val_loader = torch.utils.data.DataLoader(val_set, batch_size=args.batch_size,
                                             shuffle=False, num_workers=args.workers,
                                             pin_memory=True)

    if args.epoch_size == 0:
        args.epoch_size = len(train_loader)

    # create model
    print("=> creating model")
    disp_net = models.DispResNet(args.resnet_layers, args.with_pretrain).to(device)
    pose_net = models.PoseResNet(18, args.with_pretrain).to(device)

    # load parameters
    if args.pretrained_disp:
        print("=> using pre-trained weights for DispResNet")
        weights = torch.load(args.pretrained_disp)
        disp_net.load_state_dict(weights['state_dict'], strict=False)
    if args.pretrained_pose:
        print("=> using pre-trained weights for PoseResNet")
        weights = torch.load(args.pretrained_pose)
        pose_net.load_state_dict(weights['state_dict'], strict=False)

    disp_net = torch.nn.DataParallel(disp_net)
    pose_net = torch.nn.DataParallel(pose_net)

    print('=> setting adam solver')
    optim_params = [{'params': disp_net.parameters(), 'lr': args.lr},
                    {'params': pose_net.parameters(), 'lr': args.lr}]
    optimizer = torch.optim.Adam(optim_params, betas=(args.momentum, args.beta),
                                 weight_decay=args.weight_decay)

    with open(args.save_path / args.log_summary, 'w') as csvfile:
        writer = csv.writer(csvfile, delimiter='\t')
        writer.writerow(['train_loss', 'validation_loss'])
    with open(args.save_path / args.log_full, 'w') as csvfile:
        writer = csv.writer(csvfile, delimiter='\t')
        writer.writerow(['train_loss', 'photo_loss', 'smooth_loss', 'geometry_consistency_loss'])

    logger = TermLogger(n_epochs=args.epochs,
                        train_size=min(len(train_loader), args.epoch_size),
                        valid_size=len(val_loader))
    logger.epoch_bar.start()

    for epoch in range(args.epochs):
        logger.epoch_bar.update(epoch)

        # train for one epoch
        logger.reset_train_bar()
        train_loss = train(args, train_loader, disp_net, pose_net, optimizer,
                           args.epoch_size, logger, training_writer)
        logger.train_writer.write(' * Avg Loss : {:.3f}'.format(train_loss))

        # evaluate on validation set
        logger.reset_valid_bar()
        if args.with_gt:
            errors, error_names = validate_with_gt(args, val_loader, disp_net,
                                                   epoch, logger, output_writers)
        else:
            errors, error_names = validate_without_gt(args, val_loader, disp_net, pose_net,
                                                      epoch, logger, output_writers)
        error_string = ', '.join('{} : {:.3f}'.format(name, error)
                                 for name, error in zip(error_names, errors))
        logger.valid_writer.write(' * Avg {}'.format(error_string))

        for error, name in zip(errors, error_names):
            training_writer.add_scalar(name, error, epoch)

        # It is up to you to choose the most relevant error for measuring the model's
        # performance; be careful, some measures (such as a1, a2, a3) are to be maximized.
        decisive_error = errors[1]
        if best_error < 0:
            best_error = decisive_error

        # remember lowest error and save checkpoint
        is_best = decisive_error < best_error
        best_error = min(best_error, decisive_error)
        save_checkpoint(args.save_path,
                        {'epoch': epoch + 1, 'state_dict': disp_net.module.state_dict()},
                        {'epoch': epoch + 1, 'state_dict': pose_net.module.state_dict()},
                        is_best)

        with open(args.save_path / args.log_summary, 'a') as csvfile:
            writer = csv.writer(csvfile, delimiter='\t')
            writer.writerow([train_loss, decisive_error])

    logger.epoch_bar.finish()
def main(): global args, best_photo_loss, n_iter args = parser.parse_args() if args.dataset_format == 'stacked': from datasets.stacked_sequence_folders import SequenceFolder elif args.dataset_format == 'sequential': from datasets.sequence_folders import SequenceFolder save_path = Path('{}epochs{},seq{},b{},lr{},p{},m{},s{}'.format( args.epochs, ',epochSize' + str(args.epoch_size) if args.epoch_size > 0 else '', args.sequence_length, args.batch_size, args.lr, args.photo_loss_weight, args.mask_loss_weight, args.smooth_loss_weight)) timestamp = datetime.datetime.now().strftime("%m-%d-%H:%M") args.save_path = 'checkpoints' / save_path / timestamp print('=> will save everything to {}'.format(args.save_path)) args.save_path.makedirs_p() torch.manual_seed(args.seed) train_writer = SummaryWriter(args.save_path / 'train') valid_writer = SummaryWriter(args.save_path / 'valid') output_writers = [] if args.log_output: for i in range(3): output_writers.append( SummaryWriter(args.save_path / 'valid' / str(i))) # Data loading code normalize = custom_transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) input_transform = custom_transforms.Compose([ custom_transforms.RandomHorizontalFlip(), custom_transforms.RandomScaleCrop(), custom_transforms.ArrayToTensor(), normalize ]) print("=> fetching scenes in '{}'".format(args.data)) train_set = SequenceFolder(args.data, transform=input_transform, seed=args.seed, train=True, sequence_length=args.sequence_length) val_set = SequenceFolder(args.data, transform=custom_transforms.Compose([ custom_transforms.ArrayToTensor(), normalize ]), seed=args.seed, train=False, sequence_length=args.sequence_length) print('{} samples found in {} train scenes'.format(len(train_set), len(train_set.scenes))) print('{} samples found in {} valid scenes'.format(len(val_set), len(val_set.scenes))) train_loader = torch.utils.data.DataLoader(train_set, batch_size=args.batch_size, shuffle=True, num_workers=args.workers, pin_memory=True) val_loader = torch.utils.data.DataLoader(val_set, batch_size=args.batch_size, shuffle=False, num_workers=args.workers, pin_memory=True) if args.epoch_size == 0: args.epoch_size = len(train_loader) # create model print("=> creating model") disp_net = models.DispNetS().cuda() output_exp = args.mask_loss_weight > 0 if not output_exp: print("=> no mask loss, PoseExpnet will only output pose") pose_exp_net = models.PoseExpNet( nb_ref_imgs=args.sequence_length - 1, output_exp=args.mask_loss_weight > 0).cuda() if args.pretrained_exp_pose: print("=> using pre-trained weights for explainabilty and pose net") weights = torch.load(args.pretrained_exp_pose) pose_exp_net.load_state_dict(weights['state_dict'], strict=False) else: pose_exp_net.init_weights() if args.pretrained_disp: print("=> using pre-trained weights for Dispnet") weights = torch.load(args.pretrained_disp) disp_net.load_state_dict(weights['state_dict']) else: disp_net.init_weights() cudnn.benchmark = True disp_net = torch.nn.DataParallel(disp_net) pose_exp_net = torch.nn.DataParallel(pose_exp_net) print('=> setting adam solver') parameters = chain(disp_net.parameters(), pose_exp_net.parameters()) optimizer = torch.optim.Adam(parameters, args.lr, betas=(args.momentum, args.beta), weight_decay=args.weight_decay) with open(args.save_path / args.log_summary, 'w') as csvfile: writer = csv.writer(csvfile, delimiter='\t') writer.writerow(['train_loss', 'validation_loss']) with open(args.save_path / args.log_full, 'w') as csvfile: writer = csv.writer(csvfile, delimiter='\t') writer.writerow( 
['train_loss', 'photo_loss', 'explainability_loss', 'smooth_loss']) logger = TermLogger(n_epochs=args.epochs, train_size=min(len(train_loader), args.epoch_size), valid_size=len(val_loader)) logger.epoch_bar.start() for epoch in range(args.epochs): logger.epoch_bar.update(epoch) # train for one epoch logger.reset_train_bar() train_loss = train(train_loader, disp_net, pose_exp_net, optimizer, args.epoch_size, logger, train_writer) logger.train_writer.write(' * Avg Loss : {:.3f}'.format(train_loss)) # evaluate on validation set logger.reset_valid_bar() valid_photo_loss, valid_exp_loss, valid_total_loss = validate( val_loader, disp_net, pose_exp_net, epoch, logger, output_writers) logger.valid_writer.write( ' * Avg Photo Loss : {:.3f}, Valid Loss : {:.3f}, Total Loss : {:.3f}' .format(valid_photo_loss, valid_exp_loss, valid_total_loss)) valid_writer.add_scalar( 'photometric_error', valid_photo_loss * 4, n_iter ) # Loss is multiplied by 4 because it's only one scale, instead of 4 during training valid_writer.add_scalar('explanability_loss', valid_exp_loss * 4, n_iter) valid_writer.add_scalar('total_loss', valid_total_loss * 4, n_iter) if best_photo_loss < 0: best_photo_loss = valid_photo_loss # remember lowest error and save checkpoint is_best = valid_photo_loss < best_photo_loss best_photo_loss = min(valid_photo_loss, best_photo_loss) save_checkpoint(args.save_path, { 'epoch': epoch + 1, 'state_dict': disp_net.module.state_dict() }, { 'epoch': epoch + 1, 'state_dict': pose_exp_net.module.state_dict() }, is_best) with open(args.save_path / args.log_summary, 'a') as csvfile: writer = csv.writer(csvfile, delimiter='\t') writer.writerow([train_loss, valid_total_loss]) logger.epoch_bar.finish()
def main(): global args, best_error, n_iter args = parser.parse_args() if args.dataset_format == 'stacked': from datasets.stacked_sequence_folders import SequenceFolder elif args.dataset_format == 'sequential': from datasets.sequence_folders import SequenceFolder save_path = Path(args.name) args.save_path = 'checkpoints'/save_path #/timestamp print('=> will save everything to {}'.format(args.save_path)) args.save_path.makedirs_p() torch.manual_seed(args.seed) if args.alternating: args.alternating_flags = np.array([False,False,True]) training_writer = SummaryWriter(args.save_path) output_writers = [] if args.log_output: for i in range(3): output_writers.append(SummaryWriter(args.save_path/'valid'/str(i))) # Data loading code flow_loader_h, flow_loader_w = 256, 832 if args.data_normalization =='global': normalize = custom_transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) elif args.data_normalization =='local': normalize = custom_transforms.NormalizeLocally() train_transform = custom_transforms.Compose([ custom_transforms.RandomHorizontalFlip(), custom_transforms.RandomScaleCrop(), custom_transforms.ArrayToTensor(), normalize ]) valid_transform = custom_transforms.Compose([custom_transforms.ArrayToTensor(), normalize]) valid_flow_transform = custom_transforms.Compose([custom_transforms.Scale(h=flow_loader_h, w=flow_loader_w), custom_transforms.ArrayToTensor(), normalize]) print("=> fetching scenes in '{}'".format(args.data)) train_set = SequenceFolder( args.data, transform=train_transform, seed=args.seed, train=True, sequence_length=args.sequence_length ) # if no Groundtruth is avalaible, Validation set is the same type as training set to measure photometric loss from warping val_set = SequenceFolder( args.data, transform=valid_transform, seed=args.seed, train=False, sequence_length=args.sequence_length, ) if args.with_flow_gt: from datasets.validation_flow import ValidationFlow val_flow_set = ValidationFlow(root=args.kitti_dir, sequence_length=args.sequence_length, transform=valid_flow_transform) if args.DEBUG: train_set.__len__ = 32 train_set.samples = train_set.samples[:32] print('{} samples found in {} train scenes'.format(len(train_set), len(train_set.scenes))) print('{} samples found in {} valid scenes'.format(len(val_set), len(val_set.scenes))) train_loader = torch.utils.data.DataLoader( train_set, batch_size=args.batch_size, shuffle=True, num_workers=args.workers, pin_memory=True, drop_last=True) val_loader = torch.utils.data.DataLoader( val_set, batch_size=args.batch_size, shuffle=False, num_workers=args.workers, pin_memory=True, drop_last=True) if args.with_flow_gt: val_flow_loader = torch.utils.data.DataLoader(val_flow_set, batch_size=1, # batch size is 1 since images in kitti have different sizes shuffle=False, num_workers=args.workers, pin_memory=True, drop_last=True) if args.epoch_size == 0: args.epoch_size = len(train_loader) # create model print("=> creating model") if args.flownet=='SpyNet': flow_net = getattr(models, args.flownet)(nlevels=args.nlevels, pre_normalization=normalize).cuda() else: flow_net = getattr(models, args.flownet)(nlevels=args.nlevels).cuda() # load pre-trained weights if args.pretrained_flow: print("=> using pre-trained weights for FlowNet") weights = torch.load(args.pretrained_flow) flow_net.load_state_dict(weights['state_dict']) # else: #flow_net.init_weights() if args.resume: print("=> resuming from checkpoint") flownet_weights = torch.load(args.save_path/'flownet_checkpoint.pth.tar') 
flow_net.load_state_dict(flownet_weights['state_dict']) # import ipdb; ipdb.set_trace() cudnn.benchmark = True flow_net = torch.nn.DataParallel(flow_net) print('=> setting adam solver') parameters = chain(flow_net.parameters()) optimizer = torch.optim.Adam(parameters, args.lr, betas=(args.momentum, args.beta), weight_decay=args.weight_decay) milestones = [300] scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones, gamma=0.1, last_epoch=-1) if args.min: print("using min method") if args.resume and (args.save_path/'optimizer_checkpoint.pth.tar').exists(): print("=> loading optimizer from checkpoint") optimizer_weights = torch.load(args.save_path/'optimizer_checkpoint.pth.tar') optimizer.load_state_dict(optimizer_weights['state_dict']) with open(args.save_path/args.log_summary, 'w') as csvfile: writer = csv.writer(csvfile, delimiter='\t') writer.writerow(['train_loss', 'validation_loss']) with open(args.save_path/args.log_full, 'w') as csvfile: writer = csv.writer(csvfile, delimiter='\t') writer.writerow(['train_loss', 'photo_cam_loss', 'photo_flow_loss', 'explainability_loss', 'smooth_loss']) if args.log_terminal: logger = TermLogger(n_epochs=args.epochs, train_size=min(len(train_loader), args.epoch_size), valid_size=len(val_loader)) logger.epoch_bar.start() else: logger=None for epoch in range(args.epochs): scheduler.step() if args.fix_flownet: for fparams in flow_net.parameters(): fparams.requires_grad = False if args.log_terminal: logger.epoch_bar.update(epoch) logger.reset_train_bar() # train for one epoch train_loss = train(train_loader, flow_net, optimizer, args.epoch_size, logger, training_writer) if args.log_terminal: logger.train_writer.write(' * Avg Loss : {:.3f}'.format(train_loss)) logger.reset_valid_bar() if args.with_flow_gt: flow_errors, flow_error_names = validate_flow_with_gt(val_flow_loader, flow_net, epoch, logger, output_writers) error_string = ', '.join('{} : {:.3f}'.format(name, error) for name, error in zip(flow_error_names, flow_errors)) if args.log_terminal: logger.valid_writer.write(' * Avg {}'.format(error_string)) else: print('Epoch {} completed'.format(epoch)) for error, name in zip(flow_errors, flow_error_names): training_writer.add_scalar(name, error, epoch) decisive_error = flow_errors[0] if best_error < 0: best_error = decisive_error # remember lowest error and save checkpoint is_best = decisive_error <= best_error best_error = min(best_error, decisive_error) save_checkpoint( args.save_path, { 'epoch': epoch + 1, 'state_dict': flow_net.module.state_dict() }, { 'epoch': epoch + 1, 'state_dict': optimizer.state_dict() }, is_best) with open(args.save_path/args.log_summary, 'a') as csvfile: writer = csv.writer(csvfile, delimiter='\t') writer.writerow([train_loss, decisive_error]) if args.log_terminal: logger.epoch_bar.finish()
def main():
    global global_vars_dict
    args = global_vars_dict['args']
    best_error = -1  # used for choosing the best model

    # make the output directory
    timestamp = datetime.datetime.now().strftime("%m-%d-%H:%M")
    args.save_path = Path('checkpoints') / Path(args.data_dir).stem / timestamp
    print('=> will save everything to {}'.format(args.save_path))
    args.save_path.makedirs_p()
    torch.manual_seed(args.seed)
    if args.alternating:
        args.alternating_flags = np.array([False, False, True])

    # make writers
    tb_writer = SummaryWriter(args.save_path)

    # Data loading code
    flow_loader_h, flow_loader_w = 256, 832
    if args.data_normalization == 'global':
        normalize = custom_transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    elif args.data_normalization == 'local':
        normalize = custom_transforms.NormalizeLocally()

    if args.fix_flownet:
        train_transform = custom_transforms.Compose([
            custom_transforms.RandomHorizontalFlip(),
            custom_transforms.RandomScaleCrop(),
            custom_transforms.ArrayToTensor(),
            normalize
        ])
    else:
        train_transform = custom_transforms.Compose([
            custom_transforms.RandomRotate(),
            custom_transforms.RandomHorizontalFlip(),
            custom_transforms.RandomScaleCrop(),
            custom_transforms.ArrayToTensor(),
            normalize
        ])
    valid_transform = custom_transforms.Compose(
        [custom_transforms.ArrayToTensor(), normalize])
    valid_flow_transform = custom_transforms.Compose([
        custom_transforms.Scale(h=flow_loader_h, w=flow_loader_w),
        custom_transforms.ArrayToTensor(),
        normalize
    ])

    print("=> fetching scenes in '{}'".format(args.data_dir))

    # train set: only one loader is built
    if args.dataset_format == 'stacked':
        from datasets.stacked_sequence_folders import SequenceFolder
    elif args.dataset_format == 'sequential':
        from datasets.sequence_folders import SequenceFolder
        train_set = SequenceFolder(  # mc data folder
            args.data_dir,
            transform=train_transform,
            seed=args.seed,
            train=True,
            sequence_length=args.sequence_length,  # 5
            target_transform=None)
    elif args.dataset_format == 'sequential_with_gt':  # with all possible gt
        from datasets.sequence_mc import SequenceFolder
        train_set = SequenceFolder(  # mc data folder
            args.data_dir,
            transform=train_transform,
            seed=args.seed,
            train=True,
            sequence_length=args.sequence_length,  # 5
            target_transform=None)
    else:
        return
    if args.DEBUG:
        train_set.__len__ = 32
        train_set.samples = train_set.samples[:32]
    train_loader = torch.utils.data.DataLoader(train_set, batch_size=args.batch_size,
                                               shuffle=True, num_workers=args.workers,
                                               pin_memory=True, drop_last=True)
    if args.epoch_size == 0:
        args.epoch_size = len(train_loader)

    # validation sets and loaders are built one by one.
    # If no ground truth is available, the validation set is the same type as the
    # training set, so the photometric warping loss is measured instead.
    if args.val_without_gt:
        from datasets.sequence_folders2 import SequenceFolder  # just one extra directory level
        val_set_without_gt = SequenceFolder(  # images only
            args.data_dir,
            transform=valid_transform,
            seed=None,
            train=False,
            sequence_length=args.sequence_length,
            target_transform=None)
        val_loader = torch.utils.data.DataLoader(val_set_without_gt, batch_size=args.batch_size,
                                                 shuffle=False, num_workers=args.workers,
                                                 pin_memory=True, drop_last=True)
    if args.val_with_depth_gt:
        from datasets.validation_folders2 import ValidationSet
        val_set_with_depth_gt = ValidationSet(args.data_dir, transform=valid_transform)
        val_loader_depth = torch.utils.data.DataLoader(val_set_with_depth_gt,
                                                       batch_size=args.batch_size,
                                                       shuffle=False, num_workers=args.workers,
                                                       pin_memory=True, drop_last=True)
    if args.val_with_flow_gt:  # not available yet
        from datasets.validation_flow import ValidationFlow
        val_flow_set = ValidationFlow(root=args.kitti_dir,
                                      sequence_length=args.sequence_length,
                                      transform=valid_flow_transform)
        val_flow_loader = torch.utils.data.DataLoader(
            val_flow_set,
            batch_size=1,  # batch size is 1 since images in kitti have different sizes
            shuffle=False, num_workers=args.workers, pin_memory=True, drop_last=True)

    print('{} samples found in {} train scenes'.format(len(train_set), len(train_set.scenes)))
    if args.val_without_gt:
        print('{} samples found in {} valid scenes'.format(len(val_set_without_gt),
                                                           len(val_set_without_gt.scenes)))

    # 1. create model
    print("=> creating model")
    # 1.1 disp_net
    disp_net = getattr(models, args.dispnet)().cuda()
    output_exp = True  # args.mask_loss_weight > 0
    if not output_exp:
        print("=> no mask loss, PoseExpnet will only output pose")
    # 1.2 pose_net
    pose_net = getattr(models, args.posenet)(nb_ref_imgs=args.sequence_length - 1).cuda()
    # 1.3 flow_net
    if args.flownet == 'SpyNet':
        flow_net = getattr(models, args.flownet)(nlevels=args.nlevels,
                                                 pre_normalization=normalize).cuda()
    elif args.flownet == 'FlowNetC6':
        flow_net = getattr(models, args.flownet)(nlevels=args.nlevels).cuda()
    elif args.flownet == 'FlowNetS':
        flow_net = getattr(models, args.flownet)(nlevels=args.nlevels).cuda()
    elif args.flownet == 'Back2Future':
        flow_net = getattr(models, args.flownet)(nlevels=args.nlevels).cuda()
    # 1.4 mask_net
    mask_net = getattr(models, args.masknet)(nb_ref_imgs=args.sequence_length - 1,
                                             output_exp=True).cuda()

    # 2. load parameters
    # 2.1 pose
    if args.pretrained_pose:
        print("=> using pre-trained weights for explainability and pose net")
        weights = torch.load(args.pretrained_pose)
        pose_net.load_state_dict(weights['state_dict'])
    else:
        pose_net.init_weights()
    if args.pretrained_mask:
        print("=> using pre-trained weights for explainability and pose net")
        weights = torch.load(args.pretrained_mask)
        mask_net.load_state_dict(weights['state_dict'])
    else:
        mask_net.init_weights()
    # import ipdb; ipdb.set_trace()
    if args.pretrained_disp:
        print("=> using pre-trained weights from {}".format(args.pretrained_disp))
        weights = torch.load(args.pretrained_disp)
        disp_net.load_state_dict(weights['state_dict'])
    else:
        disp_net.init_weights()
    if args.pretrained_flow:
        print("=> using pre-trained weights for FlowNet")
        weights = torch.load(args.pretrained_flow)
        flow_net.load_state_dict(weights['state_dict'])
    else:
        flow_net.init_weights()

    if args.resume:
        print("=> resuming from checkpoint")
        dispnet_weights = torch.load(args.save_path / 'dispnet_checkpoint.pth.tar')
        posenet_weights = torch.load(args.save_path / 'posenet_checkpoint.pth.tar')
        masknet_weights = torch.load(args.save_path / 'masknet_checkpoint.pth.tar')
        flownet_weights = torch.load(args.save_path / 'flownet_checkpoint.pth.tar')
        disp_net.load_state_dict(dispnet_weights['state_dict'])
        pose_net.load_state_dict(posenet_weights['state_dict'])
        flow_net.load_state_dict(flownet_weights['state_dict'])
        mask_net.load_state_dict(masknet_weights['state_dict'])

    # import ipdb; ipdb.set_trace()
    cudnn.benchmark = True
    disp_net = torch.nn.DataParallel(disp_net)
    pose_net = torch.nn.DataParallel(pose_net)
    mask_net = torch.nn.DataParallel(mask_net)
    flow_net = torch.nn.DataParallel(flow_net)

    print('=> setting adam solver')
    parameters = chain(disp_net.parameters(), pose_net.parameters(),
                       mask_net.parameters(), flow_net.parameters())
    optimizer = torch.optim.Adam(parameters, args.lr,
                                 betas=(args.momentum, args.beta),
                                 weight_decay=args.weight_decay)
    if args.resume and (args.save_path / 'optimizer_checkpoint.pth.tar').exists():
        print("=> loading optimizer from checkpoint")
        optimizer_weights = torch.load(args.save_path / 'optimizer_checkpoint.pth.tar')
        optimizer.load_state_dict(optimizer_weights['state_dict'])

    with open(args.save_path / args.log_summary, 'w') as csvfile:
        writer = csv.writer(csvfile, delimiter='\t')
        writer.writerow(['train_loss', 'validation_loss'])
    with open(args.save_path / args.log_full, 'w') as csvfile:
        writer = csv.writer(csvfile, delimiter='\t')
        writer.writerow(['train_loss', 'photo_cam_loss', 'photo_flow_loss',
                         'explainability_loss', 'smooth_loss'])

    if args.log_terminal:
        logger = TermLogger(n_epochs=args.epochs,
                            train_size=min(len(train_loader), args.epoch_size),
                            valid_size=len(val_loader_depth))
        logger.epoch_bar.start()
    else:
        logger = None

    # evaluate once before training
    if args.pretrained_disp or args.evaluate:
        logger.reset_valid_bar()
        if args.val_without_gt:
            pass
            # val_loss = validate_without_gt(val_loader, disp_net, pose_net, mask_net, flow_net, epoch=0,
            #                                logger=logger, tb_writer=tb_writer, nb_writers=3,
            #                                global_vars_dict=global_vars_dict)
            # val_loss = 0
        if args.val_with_depth_gt:
            pass
            depth_errors, depth_error_names = validate_depth_with_gt(
                val_loader_depth, disp_net, epoch=0, logger=logger,
                tb_writer=tb_writer, global_vars_dict=global_vars_dict)

    # 3. main cycle
    for epoch in range(1, args.epochs):  # epoch 0 was already evaluated before entering the loop
        # 3.1 decide which of the four sub-networks are trained
        if args.fix_flownet:
            for fparams in flow_net.parameters():
                fparams.requires_grad = False
        if args.fix_masknet:
            for fparams in mask_net.parameters():
                fparams.requires_grad = False
        if args.fix_posenet:
            for fparams in pose_net.parameters():
                fparams.requires_grad = False
        if args.fix_dispnet:
            for fparams in disp_net.parameters():
                fparams.requires_grad = False

        if args.log_terminal:
            logger.epoch_bar.update(epoch)
            logger.reset_train_bar()

        # validation data
        flow_error_names = ['no']
        flow_errors = [0]
        errors = [0]
        error_names = ['no error names depth']
        print('\nepoch [{}/{}]\n'.format(epoch + 1, args.epochs))

        # 3.2 train for one epoch ---------
        # train_loss = 0
        train_loss = train_gt(train_loader, disp_net, pose_net, mask_net, flow_net,
                              optimizer, logger, tb_writer, global_vars_dict)

        # 3.3 evaluate on validation set -----
        if args.val_without_gt:
            val_loss = validate_without_gt(val_loader, disp_net, pose_net, mask_net, flow_net,
                                           epoch=0, logger=logger, tb_writer=tb_writer,
                                           nb_writers=3, global_vars_dict=global_vars_dict)
        if args.val_with_depth_gt:
            depth_errors, depth_error_names = validate_depth_with_gt(
                val_loader_depth, disp_net, epoch=epoch, logger=logger,
                tb_writer=tb_writer, global_vars_dict=global_vars_dict)
        if args.val_with_flow_gt:
            pass
            # flow_errors, flow_error_names = validate_flow_with_gt(val_flow_loader, disp_net, pose_net, mask_net, flow_net, epoch, logger, tb_writer)
            # for error, name in zip(flow_errors, flow_error_names):
            #     training_writer.add_scalar(name, error, epoch)
        # ----------------------

        # 3.4 It is up to you to choose the most relevant error for measuring the model's
        # performance; be careful, some measures (such as a1, a2, a3) are to be maximized.
        if not args.fix_posenet:
            decisive_error = 0  # flow_errors[-2]  # epe_rigid_with_gt_mask
        elif not args.fix_dispnet:
            decisive_error = 0  # errors[0]  # depth abs_diff
        elif not args.fix_flownet:
            decisive_error = 0  # flow_errors[-1]  # epe_non_rigid_with_gt_mask
        elif not args.fix_masknet:
            decisive_error = 0  # flow_errors[3]  # percent outliers

        # 3.5 log
        if args.log_terminal:
            logger.train_writer.write(' * Avg Loss : {:.3f}'.format(train_loss))
            logger.reset_valid_bar()

        # epoch data logged to tensorboard
        # train loss
        tb_writer.add_scalar('epoch/train_loss', train_loss, epoch)
        # val_without_gt loss
        if args.val_without_gt:
            tb_writer.add_scalar('epoch/val_loss', val_loss, epoch)
        if args.val_with_depth_gt:  # val with depth gt
            for error, name in zip(depth_errors, depth_error_names):
                tb_writer.add_scalar('epoch/' + name, error, epoch)

        # 3.6 save model: remember lowest error and save checkpoint
        if best_error < 0:
            best_error = train_loss
        is_best = train_loss <= best_error
        best_error = min(best_error, train_loss)
        save_checkpoint(args.save_path,
                        {'epoch': epoch + 1, 'state_dict': disp_net.module.state_dict()},
                        {'epoch': epoch + 1, 'state_dict': pose_net.module.state_dict()},
                        {'epoch': epoch + 1, 'state_dict': mask_net.module.state_dict()},
                        {'epoch': epoch + 1, 'state_dict': flow_net.module.state_dict()},
                        is_best)

        with open(args.save_path / args.log_summary, 'a') as csvfile:
            writer = csv.writer(csvfile, delimiter='\t')
            writer.writerow([train_loss, decisive_error])

    if args.log_terminal:
        logger.epoch_bar.finish()
        return img, gt

    def get_img_size(self):
        img = cv2.imread(os.path.join(self.db_root_dir, self.img_list[0]))
        return list(img.shape[:2])


if __name__ == '__main__':
    import custom_transforms as tr
    import torch
    from torchvision import transforms
    from matplotlib import pyplot as plt

    transforms = transforms.Compose([
        tr.RandomHorizontalFlip(),
        tr.Resize(scales=[0.5, 0.8, 1]),
        tr.ToTensor()
    ])

    dataset = DAVIS2016(db_root_dir='/media/eec/external/Databases/Segmentation/DAVIS-2016',
                        train=True, transform=transforms)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True, num_workers=1)

    for i, data in enumerate(dataloader):
        plt.figure()
def train(self):
    global n_iter
    if not self.train_flow:
        self.pose_net.train()
        self.disp_net.train()

    normalize = custom_transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    train_transform = custom_transforms.Compose([
        custom_transforms.RandomHorizontalFlip(),
        custom_transforms.RandomScaleCrop(),
        custom_transforms.ArrayToTensor(),
        normalize
    ])
    valid_transform = custom_transforms.Compose(
        [custom_transforms.ArrayToTensor(), normalize])

    self.train_set = SequenceFolder(self.config['data'],
                                    transform=train_transform,
                                    split='train',
                                    seed=self.config['seed'],
                                    img_height=self.config['img_height'],
                                    img_width=self.config['img_width'],
                                    sequence_length=self.config['sequence_length'])
    self.val_set = SequenceFolder(self.config['data'],
                                  transform=valid_transform,
                                  split='val',
                                  seed=self.config['seed'],
                                  img_height=self.config['img_height'],
                                  img_width=self.config['img_width'],
                                  sequence_length=self.config['sequence_length'])
    self.train_loader = torch.utils.data.DataLoader(
        self.train_set, shuffle=True, drop_last=True,
        num_workers=self.config['data_workers'],
        batch_size=self.config['batch_size'], pin_memory=False)
    self.val_loader = torch.utils.data.DataLoader(
        self.val_set, shuffle=True, batch_size=self.config['batch_size'],
        drop_last=True, num_workers=self.config['data_workers'], pin_memory=False)

    optim_params = [{'params': v.parameters(), 'lr': self.config['learning_rate']}
                    for v in self.nets.values()]
    self.optimizer = torch.optim.Adam(optim_params,
                                      betas=(self.config['momentum'], self.config['beta']),
                                      weight_decay=self.config['weight_decay'])

    self.logger = TermLogger(n_epochs=self.config['epoch'],
                             train_size=min(len(self.train_loader), self.config['epoch_size']),
                             valid_size=len(self.val_loader))
    self.logger.epoch_bar.start()

    for epoch in range(self.epochs):
        self.logger.epoch_bar.update(epoch)
        self.logger.reset_train_bar()
        epoch_train_loss = self.training_inside_epoch()
        self.logger.train_writer.write(' training * Avg Loss : {:.3f}'.format(epoch_train_loss))
        self.logger.reset_valid_bar()
        epoch_val_loss = self.validate_inside_epoch_without_gt()
        self.logger.valid_writer.write(' validation * Avg Loss : {:.3f}'.format(epoch_val_loss))
def main(): global args, best_error, n_iter, device args = parser.parse_args() save_path = save_path_formatter(args, parser) args.save_path = 'checkpoints_shifted' / save_path print('=> will save everything to {}'.format(args.save_path)) args.save_path.makedirs_p() torch.manual_seed(args.seed) training_writer = SummaryWriter(args.save_path) output_writers = [] if args.log_output: for i in range(3): output_writers.append( SummaryWriter(args.save_path / 'valid' / str(i))) # Data loading code normalize = custom_transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) train_transform = custom_transforms.Compose([ custom_transforms.RandomHorizontalFlip(), custom_transforms.RandomScaleCrop(), custom_transforms.ArrayToTensor(), normalize ]) valid_transform = custom_transforms.Compose( [custom_transforms.ArrayToTensor(), normalize]) print("=> fetching scenes in '{}'".format(args.data)) train_set = ShiftedSequenceFolder( args.data, transform=train_transform, seed=args.seed, train=True, sequence_length=args.sequence_length, target_displacement=args.target_displacement) # if no Groundtruth is avalaible, Validation set is the same type as training set to measure photometric loss from warping if args.with_gt: from datasets.validation_folders import ValidationSet val_set = ValidationSet(args.data, transform=valid_transform) else: val_set = SequenceFolder( args.data, transform=valid_transform, seed=args.seed, train=False, sequence_length=args.sequence_length, ) print('{} samples found in {} train scenes'.format(len(train_set), len(train_set.scenes))) print('{} samples found in {} valid scenes'.format(len(val_set), len(val_set.scenes))) train_loader = torch.utils.data.DataLoader(train_set, batch_size=args.batch_size, shuffle=True, num_workers=args.workers, pin_memory=True) adjust_loader = torch.utils.data.DataLoader( train_set, batch_size=args.batch_size, shuffle=False, num_workers=0, pin_memory=True ) # workers is set to 0 to avoid multiple instances to be modified at the same time val_loader = torch.utils.data.DataLoader(val_set, batch_size=args.batch_size, shuffle=False, num_workers=args.workers, pin_memory=True) if args.epoch_size == 0: args.epoch_size = len(train_loader) train.args = args # create model print("=> creating model") disp_net = models.DispNetS().cuda() output_exp = args.mask_loss_weight > 0 if not output_exp: print("=> no mask loss, PoseExpnet will only output pose") pose_exp_net = models.PoseExpNet( nb_ref_imgs=args.sequence_length - 1, output_exp=args.mask_loss_weight > 0).to(device) if args.pretrained_exp_pose: print("=> using pre-trained weights for explainabilty and pose net") weights = torch.load(args.pretrained_exp_pose) pose_exp_net.load_state_dict(weights['state_dict'], strict=False) else: pose_exp_net.init_weights() if args.pretrained_disp: print("=> using pre-trained weights for Dispnet") weights = torch.load(args.pretrained_disp) disp_net.load_state_dict(weights['state_dict']) else: disp_net.init_weights() cudnn.benchmark = True disp_net = torch.nn.DataParallel(disp_net) pose_exp_net = torch.nn.DataParallel(pose_exp_net) print('=> setting adam solver') parameters = chain(disp_net.parameters(), pose_exp_net.parameters()) optimizer = torch.optim.Adam(parameters, args.lr, betas=(args.momentum, args.beta), weight_decay=args.weight_decay) with open(args.save_path / args.log_summary, 'w') as csvfile: writer = csv.writer(csvfile, delimiter='\t') writer.writerow(['train_loss', 'validation_loss']) with open(args.save_path / args.log_full, 'w') as csvfile: writer = 
csv.writer(csvfile, delimiter='\t') writer.writerow( ['train_loss', 'photo_loss', 'explainability_loss', 'smooth_loss']) logger = TermLogger(n_epochs=args.epochs, train_size=min(len(train_loader), args.epoch_size), valid_size=len(val_loader)) logger.epoch_bar.start() for epoch in range(args.epochs): logger.epoch_bar.update(epoch) # train for one epoch logger.reset_train_bar() train_loss = train(args, train_loader, disp_net, pose_exp_net, optimizer, args.epoch_size, logger, training_writer) logger.train_writer.write(' * Avg Loss : {:.3f}'.format(train_loss)) if (epoch + 1) % 5 == 0: train_set.adjust = True logger.reset_train_bar(len(adjust_loader)) average_shifts = adjust_shifts(args, train_set, adjust_loader, pose_exp_net, epoch, logger, training_writer) shifts_string = ' '.join( ['{:.3f}'.format(s) for s in average_shifts]) logger.train_writer.write( ' * adjusted shifts, average shifts are now : {}'.format( shifts_string)) for i, shift in enumerate(average_shifts): training_writer.add_scalar('shifts{}'.format(i), shift, epoch) train_set.adjust = False # evaluate on validation set logger.reset_valid_bar() if args.with_gt: errors, error_names = validate_with_gt(args, val_loader, disp_net, epoch, logger, output_writers) else: errors, error_names = validate_without_gt(args, val_loader, disp_net, pose_exp_net, epoch, logger, output_writers) error_string = ', '.join('{} : {:.3f}'.format(name, error) for name, error in zip(error_names, errors)) logger.valid_writer.write(' * Avg {}'.format(error_string)) for error, name in zip(errors, error_names): training_writer.add_scalar(name, error, epoch) # Up to you to chose the most relevant error to measure your model's performance, careful some measures are to maximize (such as a1,a2,a3) decisive_error = errors[0] if best_error < 0: best_error = decisive_error # remember lowest error and save checkpoint is_best = decisive_error < best_error best_error = min(best_error, decisive_error) save_checkpoint(args.save_path, { 'epoch': epoch + 1, 'state_dict': disp_net.module.state_dict() }, { 'epoch': epoch + 1, 'state_dict': pose_exp_net.module.state_dict() }, is_best) with open(args.save_path / args.log_summary, 'a') as csvfile: writer = csv.writer(csvfile, delimiter='\t') writer.writerow([train_loss, decisive_error]) logger.epoch_bar.finish()
def main():
    global args, best_error, n_iter
    n_iter = 0
    best_error = 0
    args = parser.parse_args()
    '''
    args = ['--name', 'deemo',
            '--FCCMnet', 'PatchWiseNetwork',
            '--dataset_dir', '/notebooks/FCCM/pre-process/pre-process',
            '--label_dir', '/notebooks/FCCM',
            '--batch-size', '4',
            '--epochs', '100',
            '--lr', '1e-4']
    '''
    save_path = Path(args.name)
    args.save_path = 'checkpoints' / save_path
    print('=> will save everything to {}'.format(args.save_path))
    args.save_path.makedirs_p()
    train_writer = SummaryWriter(args.save_path)
    torch.manual_seed(args.seed)

    train_transform = custom_transforms.Compose([
        custom_transforms.RandomRotate(),
        custom_transforms.RandomHorizontalFlip(),
        custom_transforms.RandomScaleCrop(),
        custom_transforms.ArrayToTensor()
    ])
    train_set = Generate_train_set(root=args.dataset_dir, label_root=args.label_dir,
                                   transform=train_transform, seed=args.seed, train=True)
    val_set = Generate_val_set(root=args.dataset_dir, label_root=args.label_dir, seed=args.seed)
    print('{} samples found in {} train scenes'.format(len(train_set), len(train_set.scenes)))
    print('{} samples found in {} val scenes'.format(len(val_set), len(val_set.scenes)))
    train_loader = torch.utils.data.DataLoader(train_set, batch_size=args.batch_size,
                                               shuffle=True, num_workers=args.workers,
                                               pin_memory=True, drop_last=True)
    val_loader = torch.utils.data.DataLoader(val_set, batch_size=args.batch_size,
                                             shuffle=True, num_workers=args.workers,
                                             pin_memory=True, drop_last=True)

    print("=> creating model")
    if args.FCCMnet == 'VGG':
        FCCM_net = models_inpytorch.vgg16(num_classes=19)
        FCCM_net.features[0] = nn.Conv2d(1, 64, kernel_size=3, padding=1)
    if args.FCCMnet == 'ResNet':
        FCCM_net = models.resnet18(num_classes=19)
    if args.FCCMnet == 'CatNet':
        FCCM_net = models.catnet18(num_classes=19)
    if args.FCCMnet == 'CatNet_FCCM':
        FCCM_net = models.catnet1(num_classes=19)
    if args.FCCMnet == 'ResNet_FCCM':
        FCCM_net = models.resnet1(num_classes=19)
    if args.FCCMnet == 'ImageWise':
        FCCM_net = models.ImageWiseNetwork()
    if args.FCCMnet == 'PatchWise':
        FCCM_net = models.PatchWiseNetwork()
    if args.FCCMnet == 'Baseline':
        FCCM_net = models.Baseline()
    FCCM_net = FCCM_net.cuda()

    if args.pretrained_model:
        print("=> using pre-trained weights for net")
        weights = torch.load(args.pretrained_model)
        FCCM_net.load_state_dict(weights['state_dict'])

    cudnn.benchmark = True
    FCCM_net = torch.nn.DataParallel(FCCM_net)

    print('=> setting adam solver')
    parameters = chain(FCCM_net.parameters())
    optimizer = torch.optim.Adam(parameters, args.lr,
                                 betas=(args.momentum, args.beta),
                                 weight_decay=args.weight_decay)

    is_best = False
    best_error = float("inf")
    FCCM_net.train()
    loss = 0
    for epoch in tqdm(range(args.epochs)):
        is_best = loss <= best_error
        best_error = min(best_error, loss)
        save_checkpoint(args.save_path,
                        {'epoch': epoch + 1, 'state_dict': FCCM_net.state_dict()},
                        {'epoch': epoch + 1, 'state_dict': optimizer.state_dict()},
                        is_best)
        validation(val_loader, FCCM_net, epoch, train_writer)
        loss = train(train_loader, FCCM_net, optimizer, args.epoch_size, train_writer)
        gt_1 = np.array(label_1, dtype=np.float32)
        gt_1 = gt_1 / np.max([gt_1.max(), 1e-8])
        gt_t = np.array(label_t, dtype=np.float32)
        gt_t = gt_t / np.max([gt_t.max(), 1e-8])
        # name = self.img_list[idx].split('/')[-1].split('.')[0]
        return img_1, img_t, gt_1, gt_t


if __name__ == '__main__':
    import custom_transforms as tr
    import torch
    from torchvision import transforms
    from matplotlib import pyplot as plt
    from models import net
    from torch import nn

    transforms = transforms.Compose([tr.RandomHorizontalFlip(), tr.ToTensor()])
    dataset = DAVIS_OVER_FIT_TEST1(db_root_dir='../../../DAVIS-2016', train=True,
                                   transform=transforms)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True, num_workers=1)

    for i, data in enumerate(dataloader):
        plt.figure()
        # b = tens2image(data['image_1'])
        # plt.imshow(b)
        # print(data['image_1'][:, :, 3:15, 3:15])
        x = data['image_1']
        xt = data['image_t']
        lab = data['gt']
        # x[:, 0, :, :] = x[:, 0, :, :]*lab
        # x[:, 1, :, :] = x[:, 1, :, :]*lab
def main():
    global global_vars_dict
    args = global_vars_dict['args']
    best_error = -1  # used for choosing the best model

    # make the output directory
    timestamp = datetime.datetime.now().strftime("%m-%d-%H:%M")
    args.save_path = Path('checkpoints') / Path(args.data_dir).stem / timestamp
    args.save_path.makedirs_p()
    torch.manual_seed(args.seed)
    if args.alternating:
        args.alternating_flags = np.array([False, False, True])

    # make writers
    tb_writer = SummaryWriter(args.save_path)

    # Data loading code and transforms
    if args.data_normalization == 'global':
        normalize = custom_transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    elif args.data_normalization == 'local':
        normalize = custom_transforms.NormalizeLocally()
    valid_transform = custom_transforms.Compose(
        [custom_transforms.ArrayToTensor(), normalize])

    print("=> fetching scenes in '{}'".format(args.data_dir))
    train_transform = custom_transforms.Compose([
        # custom_transforms.RandomRotate(),
        custom_transforms.RandomHorizontalFlip(),
        custom_transforms.RandomScaleCrop(),
        custom_transforms.ArrayToTensor(),
        normalize
    ])

    # train set: only one loader is built
    from datasets.sequence_mc import SequenceFolder
    train_set = SequenceFolder(  # mc data folder
        args.data_dir,
        transform=train_transform,
        seed=args.seed,
        train=True,
        sequence_length=args.sequence_length,  # 5
        target_transform=None,
        depth_format='png')
    train_loader = torch.utils.data.DataLoader(train_set, batch_size=args.batch_size,
                                               shuffle=True, num_workers=args.workers,
                                               pin_memory=True, drop_last=True)
    if args.epoch_size == 0:
        args.epoch_size = len(train_loader)

    # validation sets and loaders are built one by one
    # if args.val_with_depth_gt:
    from datasets.validation_folders2 import ValidationSet
    val_set_with_depth_gt = ValidationSet(args.data_dir, transform=valid_transform,
                                          depth_format='png')
    val_loader_depth = torch.utils.data.DataLoader(val_set_with_depth_gt,
                                                   batch_size=args.batch_size,
                                                   shuffle=False, num_workers=args.workers,
                                                   pin_memory=True, drop_last=True)

    print('{} samples found in {} train scenes'.format(len(train_set), len(train_set.scenes)))

    # 1. create model
    print("=> creating model")
    # 1.1 disp_net
    disp_net = getattr(models, args.dispnet)().cuda()
    output_exp = True  # args.mask_loss_weight > 0
    if args.pretrained_disp:
        print("=> using pre-trained weights from {}".format(args.pretrained_disp))
        weights = torch.load(args.pretrained_disp)
        disp_net.load_state_dict(weights['state_dict'])
    else:
        disp_net.init_weights()
    if args.resume:
        print("=> resuming from checkpoint")
        dispnet_weights = torch.load(args.save_path / 'dispnet_checkpoint.pth.tar')
        disp_net.load_state_dict(dispnet_weights['state_dict'])

    cudnn.benchmark = True
    disp_net = torch.nn.DataParallel(disp_net)

    print('=> setting adam solver')
    parameters = chain(disp_net.parameters())
    optimizer = torch.optim.Adam(parameters, args.lr,
                                 betas=(args.momentum, args.beta),
                                 weight_decay=args.weight_decay)
    if args.resume and (args.save_path / 'optimizer_checkpoint.pth.tar').exists():
        print("=> loading optimizer from checkpoint")
        optimizer_weights = torch.load(args.save_path / 'optimizer_checkpoint.pth.tar')
        optimizer.load_state_dict(optimizer_weights['state_dict'])

    if args.log_terminal:
        logger = TermLogger(n_epochs=args.epochs,
                            train_size=min(len(train_loader), args.epoch_size),
                            valid_size=len(val_loader_depth))
        logger.reset_epoch_bar()
    else:
        logger = None

    # evaluate once before training
    criterion_train = MaskedL1Loss().to(device)  # an L1 loss is easier to optimize
    criterion_val = ComputeErrors().to(device)
    # depth_error_names, depth_errors = validate_depth_with_gt(val_loader_depth, disp_net, criterion=criterion_val, epoch=0,
    #                                                          logger=logger, tb_writer=tb_writer, global_vars_dict=global_vars_dict)
    # logger.reset_epoch_bar()
    # logger.epoch_logger_update(epoch=0, time=0, names=depth_error_names, values=depth_errors)
    epoch_time = AverageMeter()
    end = time.time()

    # 3. main cycle
    for epoch in range(1, args.epochs):  # epoch 0 was already evaluated before entering the loop
        logger.reset_train_bar()
        logger.reset_valid_bar()
        errors = [0]
        error_names = ['no error names depth']

        # 3.2 train for one epoch ---------
        loss_names, losses = train_depth_gt(train_loader=train_loader, disp_net=disp_net,
                                            optimizer=optimizer, criterion=criterion_train,
                                            logger=logger, train_writer=tb_writer,
                                            global_vars_dict=global_vars_dict)

        # 3.3 evaluate on validation set -----
        depth_error_names, depth_errors = validate_depth_with_gt(
            val_loader=val_loader_depth, disp_net=disp_net, criterion=criterion_val,
            epoch=epoch, logger=logger, tb_writer=tb_writer,
            global_vars_dict=global_vars_dict)

        epoch_time.update(time.time() - end)
        end = time.time()

        # 3.5 log_terminal
        # if args.log_terminal:
        if args.log_terminal:
            logger.epoch_logger_update(epoch=epoch, time=epoch_time,
                                       names=depth_error_names, values=depth_errors)

        # tensorboard scalars
        # train losses
        for loss_name, loss in zip(loss_names, losses.avg):
            tb_writer.add_scalar('train/' + loss_name, loss, epoch)
        # validation-with-gt errors
        for name, error in zip(depth_error_names, depth_errors.avg):
            tb_writer.add_scalar('val/' + name, error, epoch)

        # 3.6 save model: remember lowest error and save checkpoint
        total_loss = losses.avg[0]
        if best_error < 0:
            best_error = total_loss
        is_best = total_loss <= best_error
        best_error = min(best_error, total_loss)
        save_checkpoint(args.save_path,
                        {'epoch': epoch + 1, 'state_dict': disp_net.module.state_dict()},
                        {'epoch': epoch + 1, 'state_dict': None},
                        {'epoch': epoch + 1, 'state_dict': None},
                        {'epoch': epoch + 1, 'state_dict': None},
                        is_best)

    if args.log_terminal:
        logger.epoch_bar.finish()
def main():
    global opt, best_prec1
    opt = parser.parse_args()
    print(opt)

    # Data loading
    train_transform = custom_transforms.Compose([
        custom_transforms.RandomHorizontalFlip(),
        custom_transforms.RandomScaleCrop(),
        custom_transforms.ArrayToTensor(),
        custom_transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    ])
    valid_transform = custom_transforms.Compose([
        custom_transforms.ArrayToTensor(),
        custom_transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    ])

    print('Loading scenes in', opt.data_dir)
    train_set = SequenceFolder(opt.data_dir,
                               transform=train_transform,
                               seed=opt.seed,
                               train=True,
                               sequence_length=opt.sequence_length)
    val_set = ValidationSet(opt.data_dir, transform=valid_transform)
    print(len(train_set), 'train samples found')
    print(len(val_set), 'validation samples found')

    train_loader = torch.utils.data.DataLoader(train_set,
                                               batch_size=opt.batch_size,
                                               shuffle=True,
                                               num_workers=opt.workers,
                                               pin_memory=True)
    # val_loader = torch.utils.data.DataLoader(val_set, batch_size=opt.batch_size,
    #                                          shuffle=False, num_workers=opt.workers,
    #                                          pin_memory=True)

    if opt.epoch_size == 0:
        opt.epoch_size = len(train_loader)
    # Done loading

    disp_model = dispnet.DispNet().cuda()
    pose_model = posenet.PoseNet().cuda()
    disp_model, pose_model, optimizer = init.setup(disp_model, pose_model, opt)
    print(disp_model, pose_model)

    trainer = train.Trainer(disp_model, pose_model, optimizer, opt)

    if opt.resume:
        if os.path.isfile(opt.resume):
            # disp_model, pose_model, optimizer, opt, best_prec1 = init.resumer(opt, disp_model, pose_model, optimizer)
            disp_model, pose_model, optimizer, opt = init.resumer(
                opt, disp_model, pose_model, optimizer)
        else:
            print("=> no checkpoint found at '{}'".format(opt.resume))

    cudnn.benchmark = True
    for epoch in range(opt.start_epoch, opt.epochs):
        utils.adjust_learning_rate(opt, optimizer, epoch)
        print("Starting epoch number:", epoch + 1,
              "Learning rate:", optimizer.param_groups[0]["lr"])

        if not opt.testOnly:
            trainer.train(train_loader, epoch, opt)

        # init.save_checkpoint(opt, disp_model, pose_model, optimizer, best_prec1, epoch)
        init.save_checkpoint(opt, disp_model, pose_model, optimizer, epoch)
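# utils.adjust_learning_rate() is called above but not shown. A minimal sketch
# under the assumption of a simple step decay (scaling the learning rate every
# fixed number of epochs) is given below; the repository's own schedule may be
# different, and decay_every / gamma are illustrative parameters only.
def adjust_learning_rate_sketch(opt, optimizer, epoch, decay_every=30, gamma=0.5):
    """Step-decay the learning rate of every parameter group (illustrative only)."""
    lr = opt.lr * (gamma ** (epoch // decay_every))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr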
from torchvision import datasets, transforms
import torch.utils.data as data
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from PIL import Image
import os
import custom_transforms as trans

img_transform = transforms.Compose([
    trans.RandomHorizontalFlip(),
    trans.RandomGaussianBlur(),
    trans.RandomScaleCrop(700, 512),
    trans.Normalize(),
    trans.ToTensor()
])


class TrainImageFolder(data.Dataset):

    def __init__(self, data_dir):
        self.f = open(os.path.join(data_dir, 'train_id.txt'))
        self.file_list = self.f.readlines()
        self.data_dir = data_dir

    def __getitem__(self, index):
        img = Image.open(
            os.path.join(self.data_dir, 'train_images',
                         self.file_list[index][:-1] + '.jpg')).convert('RGB')
        parse = Image.open(
            os.path.join(self.data_dir, 'train_segmentations',
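# The class above is cut off mid-definition. Assuming __getitem__ finishes by
# loading the matching segmentation map and that a __len__ returning
# len(self.file_list) exists, a typical way to consume it would look like the
# sketch below; the dataset path and variable names are placeholders, not part
# of the original code.
if __name__ == '__main__':
    train_dataset = TrainImageFolder('/path/to/dataset')            # hypothetical path
    train_loader = data.DataLoader(train_dataset, batch_size=4,
                                   shuffle=True, num_workers=2)
    for batch_idx, (image, parse) in enumerate(train_loader):
        print(batch_idx, image.shape, parse.shape)                  # inspect one batch
        break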
def main():
    global best_error, n_iter, device
    args = parser.parse_args()
    if args.dataset_format == 'stacked':
        from datasets.stacked_sequence_folders import SequenceFolder
    elif args.dataset_format == 'sequential':
        from datasets.sequence_folders import SequenceFolder, StereoSequenceFolder
    save_path = save_path_formatter(args, parser)
    args.save_path = 'checkpoints' / save_path
    print('=> will save everything to {}'.format(args.save_path))
    args.save_path.makedirs_p()
    torch.manual_seed(args.seed)
    if args.evaluate:
        args.epochs = 0

    training_writer = SummaryWriter(args.save_path)
    output_writers = []
    if args.log_output:
        for i in range(3):
            output_writers.append(SummaryWriter(args.save_path / 'valid' / str(i)))

    # Data loading code
    normalize = custom_transforms.Normalize(mean=[0.5, 0.5, 0.5],
                                            std=[0.5, 0.5, 0.5])
    train_transform = custom_transforms.Compose([
        custom_transforms.RandomHorizontalFlip(),
        custom_transforms.RandomScaleCrop(),
        custom_transforms.ArrayToTensor(),
        normalize
    ])
    valid_transform = custom_transforms.Compose([custom_transforms.ArrayToTensor(), normalize])

    print("=> fetching scenes in '{}'".format(args.data))
    train_set = StereoSequenceFolder(
        args.data,
        transform=train_transform,
        seed=args.seed,
        train=True,
        sequence_length=args.sequence_length
    )

    # If no ground truth is available, the validation set is the same type as the
    # training set, so the photometric loss from warping can be measured instead.
    if args.with_gt:
        from datasets.validation_folders import ValidationSet
        val_set = ValidationSet(
            args.data,
            transform=valid_transform
        )
    else:
        val_set = StereoSequenceFolder(
            args.data,
            transform=valid_transform,
            seed=args.seed,
            train=False,
            sequence_length=args.sequence_length,
        )
    print('{} samples found in {} train scenes'.format(len(train_set), len(train_set.scenes)))
    print('{} samples found in {} valid scenes'.format(len(val_set), len(val_set.scenes)))

    train_loader = torch.utils.data.DataLoader(
        train_set, batch_size=args.batch_size, shuffle=True,
        num_workers=args.workers, pin_memory=True)
    val_loader = torch.utils.data.DataLoader(
        val_set, batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True)

    # When epoch_size is 0, each epoch runs over every sample in train_set;
    # otherwise each epoch only uses epoch_size batches of it.
    if args.epoch_size == 0:
        args.epoch_size = len(train_loader)

    # create model
    print("=> creating model")
    # disp_net = models.DispNetS().to(device)
    disp_net = models.DispResNet(3).to(device)
    output_exp = args.mask_loss_weight > 0
    if not output_exp:
        print("=> no mask loss, PoseExpnet will only output pose")
    # With a mask loss, PoseExpNet outputs both the explainability mask and the
    # pose estimate, since the two outputs share the same encoder.
    # pose_exp_net = PoseExpNet(nb_ref_imgs=args.sequence_length - 1, output_exp=args.mask_loss_weight > 0).to(device)
    pose_exp_net = models.PoseExpNet(nb_ref_imgs=args.sequence_length - 1,
                                     output_exp=args.mask_loss_weight > 0).to(device)

    if args.pretrained_exp_pose:
        print("=> using pre-trained weights for explainability and pose net")
        weights = torch.load(args.pretrained_exp_pose)
        pose_exp_net.load_state_dict(weights['state_dict'], strict=False)
    else:
        pose_exp_net.init_weights()

    if args.pretrained_disp:
        print("=> using pre-trained weights for Dispnet")
        weights = torch.load(args.pretrained_disp)
        disp_net.load_state_dict(weights['state_dict'])
    else:
        disp_net.init_weights()

    cudnn.benchmark = True
    # parallelize across GPUs
    disp_net = torch.nn.DataParallel(disp_net)
    pose_exp_net = torch.nn.DataParallel(pose_exp_net)

    # optimizer: Adam, over both networks jointly
    print('=> setting adam solver')
    optim_params = [
        {'params': disp_net.parameters(), 'lr': args.lr},
        {'params': pose_exp_net.parameters(), 'lr': args.lr}
    ]
    optimizer = torch.optim.Adam(optim_params,
                                 betas=(args.momentum, args.beta),
                                 weight_decay=args.weight_decay)

    with open(args.save_path / args.log_summary, 'w') as csvfile:
        writer = csv.writer(csvfile, delimiter='\t')
        writer.writerow(['train_loss', 'validation_loss'])

    with open(args.save_path / args.log_full, 'w') as csvfile:
        writer = csv.writer(csvfile, delimiter='\t')
        writer.writerow(['train_loss', 'photo_loss', 'explainability_loss', 'smooth_loss'])

    # evaluate the pretrained model before training
    if args.pretrained_disp or args.evaluate:
        if args.with_gt:
            errors, error_names = validate_with_gt(args, val_loader, disp_net, 0, output_writers)
        else:
            errors, error_names = validate_without_gt(args, val_loader, disp_net, pose_exp_net, 0, output_writers)
        for error, name in zip(errors, error_names):
            training_writer.add_scalar(name, error, 0)
        error_string = ', '.join('{} : {:.3f}'.format(name, error)
                                 for name, error in zip(error_names[2:9], errors[2:9]))

    # main training loop
    for epoch in range(args.epochs):
        # train for one epoch
        print('\n')
        train_loss = train(args, train_loader, disp_net, pose_exp_net, optimizer,
                           args.epoch_size, training_writer, epoch)

        # evaluate on validation set
        print('\n')
        if args.with_gt:
            errors, error_names = validate_with_gt(args, val_loader, disp_net, epoch, output_writers)
        else:
            errors, error_names = validate_without_gt(args, val_loader, disp_net, pose_exp_net, epoch, output_writers)
        error_string = ', '.join('{} : {:.3f}'.format(name, error)
                                 for name, error in zip(error_names, errors))

        for error, name in zip(errors, error_names):
            training_writer.add_scalar(name, error, epoch)

        # Up to you to choose the most relevant error to measure your model's
        # performance; careful, some measures are to be maximized (such as a1, a2, a3).
        # Validation reports several losses (final loss, warping loss and mask
        # regularization loss); any of them can serve as the best-model criterion.
        decisive_error = errors[0]
        if best_error < 0:
            best_error = decisive_error

        # remember lowest error and save the checkpoint of the best validation model
        is_best = decisive_error < best_error
        best_error = min(best_error, decisive_error)
        save_checkpoint(
            args.save_path,
            {'epoch': epoch + 1, 'state_dict': disp_net.module.state_dict()},
            {'epoch': epoch + 1, 'state_dict': pose_exp_net.module.state_dict()},
            is_best)

        with open(args.save_path / args.log_summary, 'a') as csvfile:
            writer = csv.writer(csvfile, delimiter='\t')
            writer.writerow([train_loss, decisive_error])
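# save_checkpoint() above is not defined in this snippet. A minimal sketch,
# assuming it writes one "*_checkpoint.pth.tar" per network under save_path and
# copies it to "*_model_best.pth.tar" when is_best is True (the usual
# SfMLearner-style convention); the real helper may differ.
import shutil
import torch

def save_checkpoint_sketch(save_path, dispnet_state, exp_pose_state, is_best,
                           filename='checkpoint.pth.tar'):
    """Save per-network checkpoints and keep a copy of the best ones (illustrative only)."""
    prefixes = ['dispnet', 'exp_pose']
    states = [dispnet_state, exp_pose_state]
    for prefix, state in zip(prefixes, states):
        torch.save(state, '{}/{}_{}'.format(save_path, prefix, filename))
    if is_best:
        for prefix in prefixes:
            shutil.copyfile('{}/{}_{}'.format(save_path, prefix, filename),
                            '{}/{}_model_best.pth.tar'.format(save_path, prefix))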
        # seq_name = self.video_list[idx].split('/')[2]
        return img, gt

    def get_img_size(self):
        img = cv2.imread(os.path.join(self.db_root_dir, self.img_list[0]))
        return list(img.shape[:2])


if __name__ == '__main__':
    import custom_transforms as tr
    import torch
    from torchvision import transforms
    from matplotlib import pyplot as plt
    from dataloader.helpers import overlay_mask, im_normalize, tens2image

    composed_transforms = transforms.Compose([tr.RandomHorizontalFlip(),
                                              tr.Resize(scales=[0.5, 0.8, 1]),
                                              tr.ToTensor()])
    dataset = DAVIS2016(db_root_dir='/home/ty/data/davis',
                        train=True, transform=None)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=5,
                                             shuffle=True, num_workers=1)

    for i, data in enumerate(dataloader):
        # plt.figure()
        # plt.imshow(overlay_mask(im_normalize(tens2image(data['image'])), tens2image(data['gt'])))
        # if i == 10:
        #     break
        print(data['img'].size())
        print(data['img_gt'].size())

    # plt.show(block=True)
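# overlay_mask, im_normalize and tens2image are imported above from
# dataloader.helpers but not shown here. As an illustration only, tens2image is
# typically a tiny helper that drops the batch dimension and converts a CHW
# tensor to an HWC numpy array; a sketch under that assumption:
import numpy as np

def tens2image_sketch(tensor):
    """Convert a (1, C, H, W) or (C, H, W) tensor to an (H, W, C) numpy array."""
    array = tensor.detach().cpu().numpy()
    if array.ndim == 4:           # drop the batch dimension if present
        array = array[0]
    if array.ndim == 3:           # CHW -> HWC
        array = np.transpose(array, (1, 2, 0))
    return array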
def main():
    global args, best_error, n_iter
    args = parser.parse_args()
    save_path = Path(args.name)
    args.save_path = 'checkpoints' / save_path
    print('=> will save everything to {}'.format(args.save_path))
    args.save_path.makedirs_p()
    torch.manual_seed(args.seed)

    normalize = custom_transforms.Normalize(mean=[0.5, 0.5, 0.5],
                                            std=[0.5, 0.5, 0.5])
    train_transform = custom_transforms.Compose([
        custom_transforms.RandomRotate(),
        custom_transforms.RandomHorizontalFlip(),
        custom_transforms.RandomScaleCrop(),
        custom_transforms.ArrayToTensor(),
        normalize
    ])

    training_writer = SummaryWriter(args.save_path)

    intrinsics = np.array([542.822841, 0, 315.593520,
                           0, 542.576870, 237.756098,
                           0, 0, 1]).astype(np.float32).reshape((3, 3))
    train_set = SequenceFolder(root=args.dataset_dir,
                               intrinsics=intrinsics,
                               transform=train_transform,
                               seed=args.seed,
                               train=True,
                               sequence_length=args.sequence_length)

    print('{} samples found in {} train scenes'.format(len(train_set),
                                                       len(train_set.scenes)))
    train_loader = torch.utils.data.DataLoader(train_set,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               drop_last=True)

    print("=> creating model")
    mask_net = MaskNet6.MaskNet6().cuda()
    flow_net = back2future.Model(nlevels=args.nlevels).cuda()
    pose_net = PoseNetB6.PoseNetB6().cuda()

    if args.pretrained_mask:
        print("=> using pre-trained weights for mask net")
        weights = torch.load(args.pretrained_mask)
        mask_net.load_state_dict(weights['state_dict'])
    else:
        mask_net.init_weights()
    mask_net = torch.nn.DataParallel(mask_net)

    if args.pretrained_pose:
        print("=> using pre-trained weights for pose net")
        weights = torch.load(args.pretrained_pose)
        pose_net.load_state_dict(weights['state_dict'])
    else:
        pose_net.init_weights()

    if args.pretrained_flow:
        print("=> using pre-trained weights for flow net")
        weights = torch.load(args.pretrained_flow)
        flow_net.load_state_dict(weights['state_dict'])
    else:
        flow_net.init_weights()

    print('=> setting adam solver')
    parameters = chain(mask_net.parameters(), pose_net.parameters(), flow_net.parameters())
    optimizer = torch.optim.Adam(parameters,
                                 args.lr,
                                 betas=(args.momentum, args.beta),
                                 weight_decay=args.weight_decay)

    # training
    best_error = 0
    train_loss = 0
    for epoch in tqdm(range(args.epochs)):
        if args.fix_flownet:
            for fparams in flow_net.parameters():
                fparams.requires_grad = False
        if args.fix_masknet:
            for fparams in mask_net.parameters():
                fparams.requires_grad = False
        if args.fix_posenet:
            for fparams in pose_net.parameters():
                fparams.requires_grad = False

        is_best = train_loss < best_error
        best_error = min(best_error, train_loss)
        save_checkpoint(args.save_path,
                        {'epoch': epoch + 1, 'state_dict': mask_net.module.state_dict()},
                        {'epoch': epoch + 1, 'state_dict': pose_net.state_dict()},
                        {'epoch': epoch + 1, 'state_dict': flow_net.state_dict()},
                        {'epoch': epoch + 1, 'state_dict': optimizer.state_dict()},
                        is_best)

        train_loss = train(train_loader, mask_net, pose_net, flow_net, optimizer,
                           args.epoch_size, training_writer)
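# The loop above freezes sub-networks by flipping requires_grad on their
# parameters. A small helper makes this less repetitive; note that the frozen
# parameters remain registered in the Adam optimizer created earlier, which is
# harmless (parameters with no gradient are skipped) but worth being aware of.
# This helper is an illustration, not part of the original code.
def set_requires_grad_sketch(module, requires_grad):
    """Enable or disable gradients for every parameter of a module."""
    for param in module.parameters():
        param.requires_grad = requires_grad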
def main():
    global args, best_error, n_iter
    args = parser.parse_args()
    save_path = Path(args.name)
    args.save_path = 'checkpoints' / save_path  # /timestamp
    print('=> will save everything to {}'.format(args.save_path))
    args.save_path.makedirs_p()
    torch.manual_seed(args.seed)

    training_writer = SummaryWriter(args.save_path)
    output_writer = SummaryWriter(args.save_path / 'valid')

    # Data loading code
    flow_loader_h, flow_loader_w = 384, 1280
    train_transform = custom_transforms.Compose([
        custom_transforms.RandomHorizontalFlip(),
        custom_transforms.RandomScaleCrop(h=256, w=256),
        custom_transforms.ArrayToTensor(),
    ])
    valid_transform = custom_transforms.Compose([
        custom_transforms.Scale(h=flow_loader_h, w=flow_loader_w),
        custom_transforms.ArrayToTensor()
    ])

    print("=> fetching scenes in '{}'".format(args.data))
    train_set = SequenceFolder(args.data,
                               transform=train_transform,
                               seed=args.seed,
                               train=True,
                               sequence_length=3)

    if args.valset == "kitti2015":
        from datasets.validation_flow import ValidationFlowKitti2015
        val_set = ValidationFlowKitti2015(root=args.kitti_data, transform=valid_transform)
    elif args.valset == "kitti2012":
        from datasets.validation_flow import ValidationFlowKitti2012
        val_set = ValidationFlowKitti2012(root=args.kitti_data, transform=valid_transform)

    if args.DEBUG:
        train_set.__len__ = 32
        train_set.samples = train_set.samples[:32]

    print('{} samples found in {} train scenes'.format(len(train_set), len(train_set.scenes)))
    print('{} samples found in valid scenes'.format(len(val_set)))

    train_loader = torch.utils.data.DataLoader(train_set,
                                               batch_size=1,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               drop_last=True)
    val_loader = torch.utils.data.DataLoader(
        val_set,
        batch_size=1,  # batch size is 1 since images in KITTI have different sizes
        shuffle=False,
        num_workers=args.workers,
        pin_memory=True,
        drop_last=True)

    if args.epoch_size == 0:
        args.epoch_size = len(train_loader)

    # create model
    print("=> creating model")
    if args.flownet == 'SpyNet':
        flow_net = getattr(models, args.flownet)(nlevels=6, pretrained=True)
    elif args.flownet == 'Back2Future':
        flow_net = getattr(models, args.flownet)(pretrained='pretrained/b2f_rm_hard.pth.tar')
    elif args.flownet == 'PWCNet':
        flow_net = models.pwc_dc_net('pretrained/pwc_net_chairs.pth.tar')  # pwc_net.pth.tar
    else:
        flow_net = getattr(models, args.flownet)()

    if args.flownet in ['SpyNet', 'Back2Future', 'PWCNet']:
        print("=> using pre-trained weights for " + args.flownet)
    elif args.flownet in ['FlowNetC']:
        print("=> using pre-trained weights for FlowNetC")
        weights = torch.load('pretrained/FlowNet2-C_checkpoint.pth.tar')
        flow_net.load_state_dict(weights['state_dict'])
    elif args.flownet in ['FlowNetS']:
        print("=> using pre-trained weights for FlowNetS")
        weights = torch.load('pretrained/flownets.pth.tar')
        flow_net.load_state_dict(weights['state_dict'])
    elif args.flownet in ['FlowNet2']:
        print("=> using pre-trained weights for FlowNet2")
        weights = torch.load('pretrained/FlowNet2_checkpoint.pth.tar')
        flow_net.load_state_dict(weights['state_dict'])
    else:
        flow_net.init_weights()

    pytorch_total_params = sum(p.numel() for p in flow_net.parameters())
    print("Number of model parameters: " + str(pytorch_total_params))

    flow_net = flow_net.cuda()
    cudnn.benchmark = True

    if args.patch_type == 'circle':
        patch, mask, patch_shape = init_patch_circle(args.image_size, args.patch_size)
        patch_init = patch.copy()
    elif args.patch_type == 'square':
        patch, patch_shape = init_patch_square(args.image_size, args.patch_size)
        patch_init = patch.copy()
        mask = np.ones(patch_shape)
    else:
        sys.exit("Please choose a square or circle patch")

    if args.patch_path:
        patch, mask, patch_shape = init_patch_from_image(
            args.patch_path, args.mask_path, args.image_size, args.patch_size)
        patch_init = patch.copy()

    if args.log_terminal:
        logger = TermLogger(n_epochs=args.epochs,
                            train_size=min(len(train_loader), args.epoch_size),
                            valid_size=len(val_loader),
                            attack_size=args.max_count)
        logger.epoch_bar.start()
    else:
        logger = None

    for epoch in range(args.epochs):
        if args.log_terminal:
            logger.epoch_bar.update(epoch)
            logger.reset_train_bar()

        # train for one epoch
        patch, mask, patch_init, patch_shape = train(patch, mask, patch_init, patch_shape,
                                                     train_loader, flow_net, epoch,
                                                     logger, training_writer)

        # Validate
        errors, error_names = validate_flow_with_gt(patch, mask, patch_shape, val_loader,
                                                    flow_net, epoch, logger, output_writer)

        error_string = ', '.join('{} : {:.3f}'.format(name, error)
                                 for name, error in zip(error_names, errors))
        if args.log_terminal:
            logger.valid_writer.write(' * Avg {}'.format(error_string))
        else:
            print('Epoch {} completed'.format(epoch))

        for error, name in zip(errors, error_names):
            training_writer.add_scalar(name, error, epoch)

        torch.save(patch, args.save_path / 'patch_epoch_{}'.format(str(epoch)))

    if args.log_terminal:
        logger.epoch_bar.finish()
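# init_patch_circle / init_patch_square / init_patch_from_image come from
# elsewhere in the repository. A minimal sketch of the square variant, assuming
# patch_size is the fraction of the image area the patch should cover and that
# a random patch in NCHW layout is returned together with its shape (the actual
# helper may differ):
import numpy as np

def init_patch_square_sketch(image_size, patch_size):
    """Create a random square patch covering roughly `patch_size` of the image area."""
    image_area = image_size ** 2
    patch_area = image_area * patch_size
    side = int(patch_area ** 0.5)                 # side length of the square patch
    patch = np.random.rand(1, 3, side, side)      # random init, values in [0, 1)
    return patch, patch.shape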
def main():
    global best_error, n_iter, device
    args = parser.parse_args()
    if args.dataset_format == 'stacked':
        from datasets.stacked_sequence_folders import SequenceFolder
    elif args.dataset_format == 'sequential':
        from datasets.sequence_folders import SequenceFolder
    save_path = save_path_formatter(args, parser)
    args.save_path = 'checkpoints' / save_path
    print('=> will save everything to {}'.format(args.save_path))
    args.save_path.makedirs_p()
    torch.manual_seed(args.seed)
    if args.evaluate:
        args.epochs = 0

    tb_writer = SummaryWriter(args.save_path)

    # Data loading code
    normalize = custom_transforms.Normalize(mean=[0.5, 0.5, 0.5],
                                            std=[0.5, 0.5, 0.5])
    train_transform = custom_transforms.Compose([
        custom_transforms.RandomHorizontalFlip(),
        custom_transforms.RandomScaleCrop(),
        custom_transforms.ArrayToTensor(),
        normalize
    ])
    valid_transform = custom_transforms.Compose(
        [custom_transforms.ArrayToTensor(), normalize])

    print("=> fetching scenes in '{}'".format(args.data))
    train_set = SequenceFolder(args.data,
                               transform=train_transform,
                               seed=args.seed,
                               train=True,
                               sequence_length=args.sequence_length)

    # If no ground truth is available, the validation set is the same type as the
    # training set, so the photometric loss from warping can be measured instead.
    if args.with_gt:
        from datasets.validation_folders import ValidationSet
        val_set = ValidationSet(args.data, transform=valid_transform)
    else:
        val_set = SequenceFolder(
            args.data,
            transform=valid_transform,
            seed=args.seed,
            train=False,
            sequence_length=args.sequence_length,
        )
    print('{} samples found in {} train scenes'.format(len(train_set), len(train_set.scenes)))
    print('{} samples found in {} valid scenes'.format(len(val_set), len(val_set.scenes)))

    train_loader = torch.utils.data.DataLoader(train_set,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               drop_last=True)
    val_loader = torch.utils.data.DataLoader(val_set,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)
    if args.epoch_size == 0:
        args.epoch_size = len(train_loader)

    # create model
    print("=> creating model")
    disp_net = models.DispNetS().to(device)
    seg_net = DeepLab(num_classes=args.nclass,
                      backbone=args.backbone,
                      output_stride=args.out_stride,
                      sync_bn=args.sync_bn,
                      freeze_bn=args.freeze_bn).to(device)
    if args.pretrained_seg:
        print("=> using pre-trained weights for seg net")
        weights = torch.load(args.pretrained_seg)
        seg_net.load_state_dict(weights, strict=False)

    output_exp = args.mask_loss_weight > 0
    if not output_exp:
        print("=> no mask loss, PoseExpnet will only output pose")
    pose_exp_net = models.PoseExpNet(
        nb_ref_imgs=args.sequence_length - 1,
        output_exp=args.mask_loss_weight > 0).to(device)

    if args.pretrained_exp_pose:
        print("=> using pre-trained weights for explainability and pose net")
        weights = torch.load(args.pretrained_exp_pose)
        pose_exp_net.load_state_dict(weights['state_dict'], strict=False)
    else:
        pose_exp_net.init_weights()

    if args.pretrained_disp:
        print("=> using pre-trained weights for Dispnet")
        weights = torch.load(args.pretrained_disp)
        disp_net.load_state_dict(weights['state_dict'])
    else:
        disp_net.init_weights()

    cudnn.benchmark = True
    disp_net = torch.nn.DataParallel(disp_net)
    pose_exp_net = torch.nn.DataParallel(pose_exp_net)
    seg_net = torch.nn.DataParallel(seg_net)

    print('=> setting adam solver')
    optim_params = [
        {'params': disp_net.parameters(), 'lr': args.lr},
        {'params': pose_exp_net.parameters(), 'lr': args.lr}
    ]
    optimizer = torch.optim.Adam(optim_params,
                                 betas=(args.momentum, args.beta),
                                 weight_decay=args.weight_decay)

    with open(args.save_path / args.log_summary, 'w') as csvfile:
        writer = csv.writer(csvfile, delimiter='\t')
        writer.writerow(['train_loss', 'validation_loss'])

    with open(args.save_path / args.log_full, 'w') as csvfile:
        writer = csv.writer(csvfile, delimiter='\t')
        writer.writerow(['train_loss', 'photo_loss', 'explainability_loss', 'smooth_loss'])

    logger = TermLogger(n_epochs=args.epochs,
                        train_size=min(len(train_loader), args.epoch_size),
                        valid_size=len(val_loader))
    logger.epoch_bar.start()

    if args.pretrained_disp or args.evaluate:
        logger.reset_valid_bar()
        if args.with_gt:
            errors, error_names = validate_with_gt(args, val_loader, disp_net, 0, logger, tb_writer)
        else:
            errors, error_names = validate_without_gt(args, val_loader, disp_net, pose_exp_net, 0, logger, tb_writer)
        for error, name in zip(errors, error_names):
            tb_writer.add_scalar(name, error, 0)
        error_string = ', '.join('{} : {:.3f}'.format(name, error)
                                 for name, error in zip(error_names[2:9], errors[2:9]))
        logger.valid_writer.write(' * Avg {}'.format(error_string))

    for epoch in range(args.epochs):
        logger.epoch_bar.update(epoch)

        # train for one epoch
        logger.reset_train_bar()
        train_loss = train(args, train_loader, disp_net, pose_exp_net, seg_net,
                           optimizer, args.epoch_size, logger, tb_writer)
        logger.train_writer.write(' * Avg Loss : {:.3f}'.format(train_loss))

        # evaluate on validation set
        logger.reset_valid_bar()
        if args.with_gt:
            errors, error_names = validate_with_gt(args, val_loader, disp_net, seg_net,
                                                   epoch, logger, tb_writer)
        else:
            errors, error_names = validate_without_gt(args, val_loader, disp_net, pose_exp_net,
                                                      epoch, logger, tb_writer)
        error_string = ', '.join('{} : {:.3f}'.format(name, error)
                                 for name, error in zip(error_names, errors))
        logger.valid_writer.write(' * Avg {}'.format(error_string))

        for error, name in zip(errors, error_names):
            tb_writer.add_scalar(name, error, epoch)

        # Up to you to choose the most relevant error to measure your model's
        # performance; careful, some measures are to be maximized (such as a1, a2, a3).
        decisive_error = errors[1]
        if best_error < 0:
            best_error = decisive_error

        # remember lowest error and save checkpoint
        is_best = decisive_error < best_error
        best_error = min(best_error, decisive_error)
        save_checkpoint(args.save_path,
                        {'epoch': epoch + 1, 'state_dict': disp_net.module.state_dict()},
                        {'epoch': epoch + 1, 'state_dict': pose_exp_net.module.state_dict()},
                        is_best)

        with open(args.save_path / args.log_summary, 'a') as csvfile:
            writer = csv.writer(csvfile, delimiter='\t')
            writer.writerow([train_loss, decisive_error])

    logger.epoch_bar.finish()
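# The per-epoch summary written above is a two-column, tab-separated file. As a
# small illustration (not part of the original code), it can be read back for
# plotting the training curve; the path passed in mirrors args.log_summary.
import csv

def read_log_summary_sketch(path):
    """Return the train-loss and validation-error columns of a log_summary TSV."""
    train_losses, val_errors = [], []
    with open(path) as csvfile:
        reader = csv.reader(csvfile, delimiter='\t')
        next(reader)                                  # skip the header row
        for row in reader:
            train_losses.append(float(row[0]))
            val_errors.append(float(row[1]))
    return train_losses, val_errors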