Example #1
def dataflow_test():
    from DataFlow.sequence_folders import SequenceFolder
    import custom_transforms
    from torch.utils.data import DataLoader
    from DataFlow.validation_folders import ValidationSet
    normalize = custom_transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    train_transform = custom_transforms.Compose([custom_transforms.RandomHorizontalFlip(),
                                                 custom_transforms.RandomScaleCrop(),
                                                 custom_transforms.ArrayToTensor(), normalize])
    valid_transform = custom_transforms.Compose([custom_transforms.ArrayToTensor(), normalize])
    datapath = 'G:/data/KITTI/KittiRaw_formatted'
    seed = 8964
    train_set = SequenceFolder(datapath, transform=train_transform, seed=seed, train=True,
                               sequence_length=3)

    train_loader = DataLoader(train_set, batch_size=4, shuffle=True, num_workers=4,
                              pin_memory=True)

    val_set = ValidationSet(datapath, transform=valid_transform)
    print("length of train loader is %d" % len(train_loader))
    val_loader = DataLoader(val_set, batch_size=4, shuffle=False, num_workers=4, pin_memory=True)
    print("length of val loader is %d" % len(val_loader))

    dataiter = iter(train_loader)
    imgs, intrinsics = next(dataiter)
    print(len(imgs))
    print(intrinsics.shape)

    pass
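Note that these custom_transforms are not torchvision transforms: Example #1 unpacks imgs, intrinsics from the loader, so each transform receives a list of images together with the camera intrinsics matrix. A minimal sketch of such pair-aware transforms, modeled on the SfMLearner-style custom_transforms module (the real implementations may differ in details):

import numpy as np
import torch


class Compose(object):
    # applies each transform to the (images, intrinsics) pair in turn
    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, images, intrinsics):
        for t in self.transforms:
            images, intrinsics = t(images, intrinsics)
        return images, intrinsics


class ArrayToTensor(object):
    # converts a list of HWC numpy images to CHW float tensors in [0, 1]
    def __call__(self, images, intrinsics):
        tensors = []
        for im in images:
            im = np.transpose(im, (2, 0, 1))
            tensors.append(torch.from_numpy(im).float() / 255)
        return tensors, intrinsics


class Normalize(object):
    # channel-wise in-place normalization of each image tensor
    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def __call__(self, images, intrinsics):
        for tensor in images:
            for t, m, s in zip(tensor, self.mean, self.std):
                t.sub_(m).div_(s)
        return images, intrinsics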
Example #2
def main():
    global args, best_error, n_iter
    args = parser.parse_args()

    normalize = custom_transforms.Normalize(mean=[0.5, 0.5, 0.5],
                                            std=[0.5, 0.5, 0.5])

    train_transform = custom_transforms.Compose([
        custom_transforms.RandomRotate(),
        custom_transforms.RandomHorizontalFlip(),
        custom_transforms.RandomScaleCrop(),
        custom_transforms.ArrayToTensor(), normalize
    ])

    train_set = SequenceFolder(args.data,
                               transform=train_transform,
                               seed=args.seed,
                               train=True,
                               sequence_length=args.sequence_length)
    train_loader = torch.utils.data.DataLoader(train_set,
                                               batch_size=4,
                                               shuffle=True,
                                               num_workers=4,
                                               pin_memory=True,
                                               drop_last=True)
    print(len(train_loader))
Example #3
    def transform_tr(self, sample):
        composed_transforms = transforms.Compose([
            tr.RandomHorizontalFlip(),
            tr.FixedResize(size=self.args['crop_size']),
            tr.Normalize(mean=self.mean, std=self.std),
            tr.ToTensor()
        ])

        return composed_transforms(sample)
Example #4
def get_train_transforms(normalize):
  train_transforms = []
  train_transforms.append(transforms.Scale(160))
  train_transforms.append(transforms.RandomHorizontalFlip())
  train_transforms.append(transforms.RandomColor(0.15))
  train_transforms.append(transforms.RandomRotate(15))
  train_transforms.append(transforms.RandomSizedCrop(128))
  train_transforms.append(transforms.ToTensor())
  train_transforms.append(normalize)
  train_transforms = transforms.Compose(train_transforms)
  return train_transforms
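A short usage sketch for get_train_transforms, reusing the ImageNet-style normalize that Example #10 builds with the same transforms module (illustrative only):

normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
train_transforms = get_train_transforms(normalize)
# applying train_transforms to a PIL image yields a normalized 3x128x128 tensor crop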
Example #5
    def transforms_train_esp(self, sample):
        composed_transforms = transforms.Compose([
            tr.RandomVerticalFlip(),
            tr.RandomHorizontalFlip(),
            tr.RandomAffine(degrees=40, scale=(.9, 1.1), shear=30),
            tr.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5),
            tr.FixedResize(size=self.input_size),
            tr.Normalize(mean=[x / 255.0 for x in [125.3, 123.0, 113.9]],
                         std=[x / 255.0 for x in [63.0, 62.1, 66.7]]),
            tr.ToTensor()
        ])
        return composed_transforms(sample)
Example #6
    def transform_pair_train(self, sample):
        composed_transforms = transforms.Compose([
            tr.RandomHorizontalFlip(),
            tr.RandomScaleCrop(base_size=400, crop_size=400, fill=0),
            tr.HorizontalFlip(),
            tr.GaussianBlur(),
            tr.Normalize(mean=self.source_dist['mean'],
                         std=self.source_dist['std'],
                         if_pair=True),
            tr.ToTensor(if_pair=True),
        ])
        return composed_transforms(sample)
Example #7
    def transform_tr(self, sample):
        composed_transforms = transforms.Compose([
            tr.RandomHorizontalFlip(),  # random horizontal flip
            tr.RandomScaleCrop(base_size=self.args.base_size,
                               crop_size=self.args.crop_size),  # random scale and crop
            tr.RandomGaussianBlur(),  # random Gaussian blur
            tr.Normalize(mean=(0.485, 0.456, 0.406),
                         std=(0.229, 0.224, 0.225)),  # normalize
            tr.ToTensor()
        ])

        return composed_transforms(sample)
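transform_tr above takes a sample dict rather than a bare image, following the same 'image'/'label' convention that Example #11 uses. A hypothetical __getitem__ showing how such a dict-based transform is typically wired into a segmentation dataset (names are illustrative; assumes from PIL import Image):

    def __getitem__(self, index):
        _img = Image.open(self.images[index]).convert('RGB')
        _target = Image.open(self.categories[index])
        sample = {'image': _img, 'label': _target}
        if self.split == 'train':
            return self.transform_tr(sample)
        return self.transform_val(sample)  # hypothetical validation counterpart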
Example #8
    def transform(self, sample):
        if self.train:
            composed_transforms = transforms.Compose([
                tr.RandomHorizontalFlip(),
                tr.RandomVerticalFlip(),
                tr.RandomScaleCrop(),
                tr.ToTensor()
            ])
        else:
            composed_transforms = transforms.Compose([
                tr.ToTensor()
            ])
        return composed_transforms(sample)
Example #9
    def transform_tr(self, sample):
        if not self.random_match:
            composed_transforms = transforms.Compose([
                tr.RandomHorizontalFlip(),
                tr.RandomScaleCrop(base_size=400, crop_size=400, fill=0),
                # tr.Remap(self.building_table, self.nonbuilding_table, self.channels),
                tr.RandomGaussianBlur(),
                # tr.ConvertFromInts(),
                # tr.PhotometricDistort(),
                tr.Normalize(mean=self.source_dist['mean'],
                             std=self.source_dist['std']),
                tr.ToTensor(),
            ])
        else:
            composed_transforms = transforms.Compose([
                tr.HistogramMatching(),
                tr.RandomHorizontalFlip(),
                tr.RandomScaleCrop(base_size=400, crop_size=400, fill=0),
                tr.RandomGaussianBlur(),
                tr.Normalize(mean=self.source_dist['mean'],
                             std=self.source_dist['std']),
                tr.ToTensor(),
            ])
        return composed_transforms(sample)
Example #10
def main(args):
  # parse args
  best_acc1 = 0.0

  if args.gpu >= 0:
    print("Use GPU: {}".format(args.gpu))
  else:
    print('You are using CPU for computing!',
          'Yet we assume you are using a GPU.',
          'You will NOT be able to switch between CPU and GPU training!')

  # set up the model + loss
  if args.use_custom_conv:
    print("Using custom convolutions in the network")
    model = default_model(conv_op=CustomConv2d, num_classes=100)
  else:
    model = default_model(num_classes=100)
  model_arch = "simplenet"
  # model_arch = "simplenet_batchnorm2d"
  # model_arch = "resnet18"
  # model_arch = "resnet34"
  criterion = nn.CrossEntropyLoss()
  # put everything on the GPU
  if args.gpu >= 0:
    model = model.cuda(args.gpu)
    criterion = criterion.cuda(args.gpu)

  # setup the optimizer
  optimizer = torch.optim.SGD(model.parameters(), args.lr,
                momentum=args.momentum,
                weight_decay=args.weight_decay)

  # resume from a checkpoint?
  if args.resume:
    if os.path.isfile(args.resume):
      print("=> loading checkpoint '{}'".format(args.resume))
      checkpoint = torch.load(args.resume)
      args.start_epoch = checkpoint['epoch']
      best_acc1 = checkpoint['best_acc1']
      model.load_state_dict(checkpoint['state_dict'])
      if args.gpu < 0:
        model = model.cpu()
      else:
        model = model.cuda(args.gpu)
      # only load the optimizer if necessary
      if (not args.evaluate) and (not args.attack):
        optimizer.load_state_dict(checkpoint['optimizer'])
      print("=> loaded checkpoint '{}' (epoch {}, acc1 {})"
          .format(args.resume, checkpoint['epoch'], best_acc1))
    else:
      print("=> no checkpoint found at '{}'".format(args.resume))

  # set up transforms for data augmentation
  normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                       std=[0.229, 0.224, 0.225])
  train_transforms = []
  train_transforms.append(transforms.Scale(160))
  train_transforms.append(transforms.RandomHorizontalFlip())
  train_transforms.append(transforms.RandomColor(0.15))
  train_transforms.append(transforms.RandomRotate(15))
  train_transforms.append(transforms.RandomSizedCrop(128))
  train_transforms.append(transforms.ToTensor())
  train_transforms.append(normalize)
  train_transforms = transforms.Compose(train_transforms)
  # val transforms
  val_transforms = []
  val_transforms.append(transforms.Scale(160, interpolations=None))
  val_transforms.append(transforms.ToTensor())
  val_transforms.append(normalize)
  val_transforms = transforms.Compose(val_transforms)
  if (not args.evaluate) and (not args.attack):
    print("Training time data augmentations:")
    print(train_transforms)

  # setup dataset and dataloader
  train_dataset = MiniPlacesLoader(args.data_folder,
                  split='train', transforms=train_transforms)
  val_dataset = MiniPlacesLoader(args.data_folder,
                  split='val', transforms=val_transforms)

  train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=args.batch_size, shuffle=True,
    num_workers=args.workers, pin_memory=True, sampler=None, drop_last=True)
  val_loader = torch.utils.data.DataLoader(
    val_dataset, batch_size=100, shuffle=False,
    num_workers=args.workers, pin_memory=True, sampler=None, drop_last=False)

  # testing only
  if (args.evaluate == args.attack) and args.evaluate:
    print("Can't set evaluate and attack to True at the same time!")
    return

  # set up visualizer
  if args.vis:
    visualizer = default_attention(criterion)
  else:
    visualizer = None

  # evaluation
  if args.resume and args.evaluate:
    print("Testing the model ...")
    cudnn.deterministic = True
    validate(val_loader, model, -1, args, visualizer=visualizer)
    return

  # attack
  if args.resume and args.attack:
    print("Generating adversarial samples for the model ..")
    cudnn.deterministic = True
    validate(val_loader, model, -1, args,
             attacker=default_attack(criterion),
             visualizer=visualizer)
    return

  # enable cudnn benchmark
  cudnn.enabled = True
  cudnn.benchmark = True

  # warmup the training
  if (args.start_epoch == 0) and (args.warmup_epochs > 0):
    print("Warmup the training ...")
    for epoch in range(0, args.warmup_epochs):
      train(train_loader, model, criterion, optimizer, epoch, "warmup", args)

  # start the training
  print("Training the model ...")
  for epoch in range(args.start_epoch, args.epochs):
    # train for one epoch
    train(train_loader, model, criterion, optimizer, epoch, "train", args)

    # evaluate on validation set
    acc1 = validate(val_loader, model, epoch, args)

    # remember best acc@1 and save checkpoint
    is_best = acc1 > best_acc1
    best_acc1 = max(acc1, best_acc1)
    save_checkpoint({
      'epoch': epoch + 1,
      'model_arch': model_arch,
      'state_dict': model.state_dict(),
      'best_acc1': best_acc1,
      'optimizer' : optimizer.state_dict(),
    }, is_best)
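Example #10 calls save_checkpoint(state, is_best) without showing it. A common minimal implementation, given as a sketch (the actual helper in this codebase may differ, e.g. in file naming):

import shutil
import torch


def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
    # always save the latest state; keep a separate copy of the best one
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, 'model_best.pth.tar')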
Example #11
        _target = Image.open(self.categories[index])

        return _img, _target

    def __str__(self):
        return 'ISPRS(split=' + str(self.split) + ')'


if __name__ == '__main__':
    import custom_transforms as tr
    from utils import decode_segmap
    from torch.utils.data import DataLoader
    from torchvision import transforms
    import matplotlib.pyplot as plt

    composed_transforms_tr = transforms.Compose([tr.RandomHorizontalFlip(),
                                                 tr.RandomCrop(512),
                                                 tr.RandomRotate(15),
                                                 tr.ToTensor()]
                                                )

    isprs_train = ISPRSSegmentation(split='train', transform=composed_transforms_tr)

    dataloader = DataLoader(isprs_train, batch_size=5, shuffle=True, num_workers=2)

    for ii, sample in enumerate(dataloader):
        for jj in range(sample["image"].size()[0]):
            img = sample['image'].numpy()
            gt = sample['label'].numpy()
            tmp = np.array(gt[jj]).astype(np.uint8)
            tmp = np.squeeze(tmp, axis=0)
    plt.subplot(1, 2, 1)
    plt.imshow(np.transpose(img[jj], (1, 2, 0)))  # CHW -> HWC for display
    plt.subplot(1, 2, 2)
    plt.imshow(tmp * 255)  # show object IDs
    plt.tight_layout()
    plt.show()


Example #12
if __name__ == "__main__":
    # delete()
    # torch_VOC()

    # 1. Show the transforms' results before converting to tensors
    # executed on PIL images
    data_transforms = transforms.Compose([
        custom_transforms.RandomHorizontalFlip(1.0),
        custom_transforms.RandomRotation((-20, 20), scales=(0.75, 1.0)),
        custom_transforms.to_numpy()
    ])

    # executed on numpy images
    handcraft_transforms = custom_transforms.Compose_dict([
        custom_transforms.ObjectCenterCrop(),
        custom_transforms.Fix_size(size=512),
        custom_transforms.get_extreme_point_channel(),
        custom_transforms.concat_inputs()
    ])

    voc = VOC_Instance_Segmentation(split_sets=['train'], transform=data_transforms, \
                                    transform_handcraft=handcraft_transforms)
    for ii, dct in enumerate(voc):
        pass  # loop body truncated in the original excerpt
Example #13
    start_epoch = 0
    print('===> Start from scratch')

if cuda:
    model.cuda()
    cudnn.benchmark = True

#%%
parser = argparse.ArgumentParser()
args = parser.parse_args()

args.base_size = 513
args.crop_size = 513

composed_transforms = transforms.Compose([
    tr.RandomHorizontalFlip(),  # random horizontal flip
    tr.RandomScaleCrop(base_size=args.base_size,
                       crop_size=args.crop_size),  # random scale and crop
    tr.RandomGaussianBlur(),  # random Gaussian blur
    tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),  # normalize
    tr.ToTensor()
])

#%%
import matplotlib.pyplot as plt
import numpy as np
# from dataset.utils import decode_segmap

tbar = tqdm(dataloader)
num_img_tr = len(dataloader)
for epoch in range(0, 10):
    pass  # loop body truncated in the original excerpt
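Example #13 comments out a decode_segmap import; such a helper maps an integer label mask to an RGB image for display. A minimal sketch under the assumption of an n_classes x 3 colour table (hypothetical, not the original implementation):

import numpy as np


def decode_segmap(label_mask, label_colours):
    # label_mask: (H, W) integer class IDs; label_colours: (n_classes, 3) RGB rows
    rgb = np.zeros((label_mask.shape[0], label_mask.shape[1], 3))
    for c, colour in enumerate(label_colours):
        rgb[label_mask == c] = colour
    return rgb / 255.0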
Example #14
def main():
    global best_error, n_iter, device
    args = parser.parse_args()

    timestamp = datetime.datetime.now().strftime("%m-%d-%H:%M")
    save_path = Path(args.name)
    args.save_path = 'checkpoints' / save_path / timestamp
    print('=> will save everything to {}'.format(args.save_path))
    args.save_path.makedirs_p()

    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    cudnn.deterministic = True
    cudnn.benchmark = True

    training_writer = SummaryWriter(args.save_path)
    output_writers = []
    if args.log_output:
        for i in range(3):
            output_writers.append(
                SummaryWriter(args.save_path / 'valid' / str(i)))

    # Data loading code
    normalize = custom_transforms.Normalize(mean=[0.45, 0.45, 0.45],
                                            std=[0.225, 0.225, 0.225])

    train_transform = custom_transforms.Compose([
        custom_transforms.RandomHorizontalFlip(),
        custom_transforms.RandomScaleCrop(),
        custom_transforms.ArrayToTensor(), normalize
    ])

    valid_transform = custom_transforms.Compose(
        [custom_transforms.ArrayToTensor(), normalize])

    print("=> fetching scenes in '{}'".format(args.data))
    train_set = SequenceFolder(args.data,
                               transform=train_transform,
                               seed=args.seed,
                               train=True,
                               sequence_length=args.sequence_length)

    # if no ground truth is available, the validation set is the same type as the training set, to measure photometric loss from warping
    if args.with_gt:
        from datasets.validation_folders import ValidationSet
        val_set = ValidationSet(args.data, transform=valid_transform)
    else:
        val_set = SequenceFolder(args.data,
                                 transform=valid_transform,
                                 seed=args.seed,
                                 train=False,
                                 sequence_length=args.sequence_length)
    print('{} samples found in {} train scenes'.format(len(train_set),
                                                       len(train_set.scenes)))
    print('{} samples found in {} valid scenes'.format(len(val_set),
                                                       len(val_set.scenes)))
    train_loader = torch.utils.data.DataLoader(train_set,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True)
    val_loader = torch.utils.data.DataLoader(val_set,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    if args.epoch_size == 0:
        args.epoch_size = len(train_loader)

    # create model
    print("=> creating model")
    disp_net = models.DispResNet(args.resnet_layers,
                                 args.with_pretrain).to(device)
    pose_net = models.PoseResNet(18, args.with_pretrain).to(device)

    # load parameters
    if args.pretrained_disp:
        print("=> using pre-trained weights for DispResNet")
        weights = torch.load(args.pretrained_disp)
        disp_net.load_state_dict(weights['state_dict'], strict=False)

    if args.pretrained_pose:
        print("=> using pre-trained weights for PoseResNet")
        weights = torch.load(args.pretrained_pose)
        pose_net.load_state_dict(weights['state_dict'], strict=False)

    disp_net = torch.nn.DataParallel(disp_net)
    pose_net = torch.nn.DataParallel(pose_net)

    print('=> setting adam solver')
    optim_params = [{
        'params': disp_net.parameters(),
        'lr': args.lr
    }, {
        'params': pose_net.parameters(),
        'lr': args.lr
    }]
    optimizer = torch.optim.Adam(optim_params,
                                 betas=(args.momentum, args.beta),
                                 weight_decay=args.weight_decay)

    with open(args.save_path / args.log_summary, 'w') as csvfile:
        writer = csv.writer(csvfile, delimiter='\t')
        writer.writerow(['train_loss', 'validation_loss'])

    with open(args.save_path / args.log_full, 'w') as csvfile:
        writer = csv.writer(csvfile, delimiter='\t')
        writer.writerow([
            'train_loss', 'photo_loss', 'smooth_loss',
            'geometry_consistency_loss'
        ])

    logger = TermLogger(n_epochs=args.epochs,
                        train_size=min(len(train_loader), args.epoch_size),
                        valid_size=len(val_loader))
    logger.epoch_bar.start()

    for epoch in range(args.epochs):
        logger.epoch_bar.update(epoch)

        # train for one epoch
        logger.reset_train_bar()
        train_loss = train(args, train_loader, disp_net, pose_net, optimizer,
                           args.epoch_size, logger, training_writer)
        logger.train_writer.write(' * Avg Loss : {:.3f}'.format(train_loss))

        # evaluate on validation set
        logger.reset_valid_bar()
        if args.with_gt:
            errors, error_names = validate_with_gt(args, val_loader, disp_net,
                                                   epoch, logger,
                                                   output_writers)
        else:
            errors, error_names = validate_without_gt(args, val_loader,
                                                      disp_net, pose_net,
                                                      epoch, logger,
                                                      output_writers)
        error_string = ', '.join('{} : {:.3f}'.format(name, error)
                                 for name, error in zip(error_names, errors))
        logger.valid_writer.write(' * Avg {}'.format(error_string))

        for error, name in zip(errors, error_names):
            training_writer.add_scalar(name, error, epoch)

        # Up to you to choose the most relevant error to measure your model's performance; careful, some measures are to be maximized (such as a1, a2, a3)
        decisive_error = errors[1]
        if best_error < 0:
            best_error = decisive_error

        # remember lowest error and save checkpoint
        is_best = decisive_error < best_error
        best_error = min(best_error, decisive_error)
        save_checkpoint(args.save_path, {
            'epoch': epoch + 1,
            'state_dict': disp_net.module.state_dict()
        }, {
            'epoch': epoch + 1,
            'state_dict': pose_net.module.state_dict()
        }, is_best)

        with open(args.save_path / args.log_summary, 'a') as csvfile:
            writer = csv.writer(csvfile, delimiter='\t')
            writer.writerow([train_loss, decisive_error])
    logger.epoch_bar.finish()
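Example #14 (like #15 and #16) calls a save_checkpoint variant that takes one state dict per sub-network. A sketch of what such a helper typically looks like in SfMLearner-style code (assumed, not taken from the original repository); args.save_path supports the / operator because it is a path-like object:

from shutil import copyfile
import torch


def save_checkpoint(save_path, dispnet_state, posenet_state, is_best,
                    filename='checkpoint.pth.tar'):
    # save each network's latest checkpoint under its own prefix
    file_prefixes = ['dispnet', 'posenet']
    states = [dispnet_state, posenet_state]
    for prefix, state in zip(file_prefixes, states):
        torch.save(state, save_path / '{}_{}'.format(prefix, filename))
    # duplicate the current checkpoints as *_model_best when the error improves
    if is_best:
        for prefix in file_prefixes:
            copyfile(save_path / '{}_{}'.format(prefix, filename),
                     save_path / '{}_model_best.pth.tar'.format(prefix))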
Example #15
def main():
    global args, best_photo_loss, n_iter
    args = parser.parse_args()
    if args.dataset_format == 'stacked':
        from datasets.stacked_sequence_folders import SequenceFolder
    elif args.dataset_format == 'sequential':
        from datasets.sequence_folders import SequenceFolder
    save_path = Path('{}epochs{},seq{},b{},lr{},p{},m{},s{}'.format(
        args.epochs,
        ',epochSize' + str(args.epoch_size) if args.epoch_size > 0 else '',
        args.sequence_length, args.batch_size, args.lr, args.photo_loss_weight,
        args.mask_loss_weight, args.smooth_loss_weight))
    timestamp = datetime.datetime.now().strftime("%m-%d-%H:%M")
    args.save_path = 'checkpoints' / save_path / timestamp
    print('=> will save everything to {}'.format(args.save_path))
    args.save_path.makedirs_p()
    torch.manual_seed(args.seed)

    train_writer = SummaryWriter(args.save_path / 'train')
    valid_writer = SummaryWriter(args.save_path / 'valid')
    output_writers = []
    if args.log_output:
        for i in range(3):
            output_writers.append(
                SummaryWriter(args.save_path / 'valid' / str(i)))

    # Data loading code
    normalize = custom_transforms.Normalize(mean=[0.5, 0.5, 0.5],
                                            std=[0.5, 0.5, 0.5])
    input_transform = custom_transforms.Compose([
        custom_transforms.RandomHorizontalFlip(),
        custom_transforms.RandomScaleCrop(),
        custom_transforms.ArrayToTensor(), normalize
    ])

    print("=> fetching scenes in '{}'".format(args.data))
    train_set = SequenceFolder(args.data,
                               transform=input_transform,
                               seed=args.seed,
                               train=True,
                               sequence_length=args.sequence_length)
    val_set = SequenceFolder(args.data,
                             transform=custom_transforms.Compose([
                                 custom_transforms.ArrayToTensor(), normalize
                             ]),
                             seed=args.seed,
                             train=False,
                             sequence_length=args.sequence_length)
    print('{} samples found in {} train scenes'.format(len(train_set),
                                                       len(train_set.scenes)))
    print('{} samples found in {} valid scenes'.format(len(val_set),
                                                       len(val_set.scenes)))
    train_loader = torch.utils.data.DataLoader(train_set,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True)
    val_loader = torch.utils.data.DataLoader(val_set,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    if args.epoch_size == 0:
        args.epoch_size = len(train_loader)

    # create model
    print("=> creating model")

    disp_net = models.DispNetS().cuda()
    output_exp = args.mask_loss_weight > 0
    if not output_exp:
        print("=> no mask loss, PoseExpnet will only output pose")
    pose_exp_net = models.PoseExpNet(
        nb_ref_imgs=args.sequence_length - 1,
        output_exp=args.mask_loss_weight > 0).cuda()

    if args.pretrained_exp_pose:
        print("=> using pre-trained weights for explainabilty and pose net")
        weights = torch.load(args.pretrained_exp_pose)
        pose_exp_net.load_state_dict(weights['state_dict'], strict=False)
    else:
        pose_exp_net.init_weights()

    if args.pretrained_disp:
        print("=> using pre-trained weights for Dispnet")
        weights = torch.load(args.pretrained_disp)
        disp_net.load_state_dict(weights['state_dict'])
    else:
        disp_net.init_weights()

    cudnn.benchmark = True
    disp_net = torch.nn.DataParallel(disp_net)
    pose_exp_net = torch.nn.DataParallel(pose_exp_net)

    print('=> setting adam solver')

    parameters = chain(disp_net.parameters(), pose_exp_net.parameters())
    optimizer = torch.optim.Adam(parameters,
                                 args.lr,
                                 betas=(args.momentum, args.beta),
                                 weight_decay=args.weight_decay)

    with open(args.save_path / args.log_summary, 'w') as csvfile:
        writer = csv.writer(csvfile, delimiter='\t')
        writer.writerow(['train_loss', 'validation_loss'])

    with open(args.save_path / args.log_full, 'w') as csvfile:
        writer = csv.writer(csvfile, delimiter='\t')
        writer.writerow(
            ['train_loss', 'photo_loss', 'explainability_loss', 'smooth_loss'])

    logger = TermLogger(n_epochs=args.epochs,
                        train_size=min(len(train_loader), args.epoch_size),
                        valid_size=len(val_loader))
    logger.epoch_bar.start()

    for epoch in range(args.epochs):
        logger.epoch_bar.update(epoch)

        # train for one epoch
        logger.reset_train_bar()
        train_loss = train(train_loader, disp_net, pose_exp_net, optimizer,
                           args.epoch_size, logger, train_writer)
        logger.train_writer.write(' * Avg Loss : {:.3f}'.format(train_loss))

        # evaluate on validation set
        logger.reset_valid_bar()
        valid_photo_loss, valid_exp_loss, valid_total_loss = validate(
            val_loader, disp_net, pose_exp_net, epoch, logger, output_writers)
        logger.valid_writer.write(
            ' * Avg Photo Loss : {:.3f}, Exp Loss : {:.3f}, Total Loss : {:.3f}'
            .format(valid_photo_loss, valid_exp_loss, valid_total_loss))
        valid_writer.add_scalar(
            'photometric_error', valid_photo_loss * 4, n_iter
        )  # Loss is multiplied by 4 because it's only one scale, instead of 4 during training
        valid_writer.add_scalar('explanability_loss', valid_exp_loss * 4,
                                n_iter)
        valid_writer.add_scalar('total_loss', valid_total_loss * 4, n_iter)

        if best_photo_loss < 0:
            best_photo_loss = valid_photo_loss

        # remember lowest error and save checkpoint
        is_best = valid_photo_loss < best_photo_loss
        best_photo_loss = min(valid_photo_loss, best_photo_loss)
        save_checkpoint(args.save_path, {
            'epoch': epoch + 1,
            'state_dict': disp_net.module.state_dict()
        }, {
            'epoch': epoch + 1,
            'state_dict': pose_exp_net.module.state_dict()
        }, is_best)

        with open(args.save_path / args.log_summary, 'a') as csvfile:
            writer = csv.writer(csvfile, delimiter='\t')
            writer.writerow([train_loss, valid_total_loss])
    logger.epoch_bar.finish()
Example #16
def main():
    global args, best_error, n_iter
    args = parser.parse_args()
    if args.dataset_format == 'stacked':
        from datasets.stacked_sequence_folders import SequenceFolder
    elif args.dataset_format == 'sequential':
        from datasets.sequence_folders import SequenceFolder
    save_path = Path(args.name)
    args.save_path = 'checkpoints'/save_path #/timestamp
    print('=> will save everything to {}'.format(args.save_path))
    args.save_path.makedirs_p()
    torch.manual_seed(args.seed)
    if args.alternating:
        args.alternating_flags = np.array([False,False,True])

    training_writer = SummaryWriter(args.save_path)
    output_writers = []
    if args.log_output:
        for i in range(3):
            output_writers.append(SummaryWriter(args.save_path/'valid'/str(i)))

    # Data loading code
    flow_loader_h, flow_loader_w = 256, 832

    if args.data_normalization =='global':
        normalize = custom_transforms.Normalize(mean=[0.5, 0.5, 0.5],
                                                std=[0.5, 0.5, 0.5])
    elif args.data_normalization =='local':
        normalize = custom_transforms.NormalizeLocally()


    train_transform = custom_transforms.Compose([
        custom_transforms.RandomHorizontalFlip(),
        custom_transforms.RandomScaleCrop(),
        custom_transforms.ArrayToTensor(),
        normalize
    ])
 

    valid_transform = custom_transforms.Compose([custom_transforms.ArrayToTensor(), normalize])

    valid_flow_transform = custom_transforms.Compose([custom_transforms.Scale(h=flow_loader_h, w=flow_loader_w),
                            custom_transforms.ArrayToTensor(), normalize])

    print("=> fetching scenes in '{}'".format(args.data))
    train_set = SequenceFolder(
        args.data,
        transform=train_transform,
        seed=args.seed,
        train=True,
        sequence_length=args.sequence_length
    )

    # if no ground truth is available, the validation set is the same type as the training set, to measure photometric loss from warping
    
    val_set = SequenceFolder(
        args.data,
        transform=valid_transform,
        seed=args.seed,
        train=False,
        sequence_length=args.sequence_length,
    )

    if args.with_flow_gt:
        from datasets.validation_flow import ValidationFlow
        val_flow_set = ValidationFlow(root=args.kitti_dir,
                                        sequence_length=args.sequence_length, transform=valid_flow_transform)

    if args.DEBUG:
        train_set.__len__ = 32
        train_set.samples = train_set.samples[:32]

    print('{} samples found in {} train scenes'.format(len(train_set), len(train_set.scenes)))
    print('{} samples found in {} valid scenes'.format(len(val_set), len(val_set.scenes)))
    train_loader = torch.utils.data.DataLoader(
        train_set, batch_size=args.batch_size, shuffle=True,
        num_workers=args.workers, pin_memory=True, drop_last=True)
    val_loader = torch.utils.data.DataLoader(
        val_set, batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True, drop_last=True)

    if args.with_flow_gt:
        val_flow_loader = torch.utils.data.DataLoader(val_flow_set, batch_size=1,               # batch size is 1 since images in kitti have different sizes
                        shuffle=False, num_workers=args.workers, pin_memory=True, drop_last=True)

    if args.epoch_size == 0:
        args.epoch_size = len(train_loader)

    # create model
    print("=> creating model")
    
    if args.flownet=='SpyNet':
        flow_net = getattr(models, args.flownet)(nlevels=args.nlevels, pre_normalization=normalize).cuda()
    else:
        flow_net = getattr(models, args.flownet)(nlevels=args.nlevels).cuda()

    # load pre-trained weights

    if args.pretrained_flow:
        print("=> using pre-trained weights for FlowNet")
        weights = torch.load(args.pretrained_flow)
        flow_net.load_state_dict(weights['state_dict'])
    # else:
        #flow_net.init_weights()


    if args.resume:
        print("=> resuming from checkpoint")  
        flownet_weights = torch.load(args.save_path/'flownet_checkpoint.pth.tar')
        flow_net.load_state_dict(flownet_weights['state_dict'])


    # import ipdb; ipdb.set_trace()
    cudnn.benchmark = True
    flow_net = torch.nn.DataParallel(flow_net)

    print('=> setting adam solver')
    parameters = chain(flow_net.parameters())
    optimizer = torch.optim.Adam(parameters, args.lr,
                                 betas=(args.momentum, args.beta),
                                 weight_decay=args.weight_decay)

    milestones = [300]
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones, gamma=0.1, last_epoch=-1)

    if args.min:
        print("using min method")

    if args.resume and (args.save_path/'optimizer_checkpoint.pth.tar').exists():
        print("=> loading optimizer from checkpoint")
        optimizer_weights = torch.load(args.save_path/'optimizer_checkpoint.pth.tar')
        optimizer.load_state_dict(optimizer_weights['state_dict'])

    with open(args.save_path/args.log_summary, 'w') as csvfile:
        writer = csv.writer(csvfile, delimiter='\t')
        writer.writerow(['train_loss', 'validation_loss'])

    with open(args.save_path/args.log_full, 'w') as csvfile:
        writer = csv.writer(csvfile, delimiter='\t')
        writer.writerow(['train_loss', 'photo_cam_loss', 'photo_flow_loss', 'explainability_loss', 'smooth_loss'])

    if args.log_terminal:
        logger = TermLogger(n_epochs=args.epochs, train_size=min(len(train_loader), args.epoch_size), valid_size=len(val_loader))
        logger.epoch_bar.start()
    else:
        logger=None

    for epoch in range(args.epochs):
        scheduler.step()  # note: newer PyTorch expects scheduler.step() after optimizer.step()

        if args.fix_flownet:
            for fparams in flow_net.parameters():
                fparams.requires_grad = False

        if args.log_terminal:
            logger.epoch_bar.update(epoch)
            logger.reset_train_bar()

        # train for one epoch
        train_loss = train(train_loader, flow_net, optimizer, args.epoch_size, logger, training_writer)

        if args.log_terminal:
            logger.train_writer.write(' * Avg Loss : {:.3f}'.format(train_loss))
            logger.reset_valid_bar()


        if args.with_flow_gt:
            flow_errors, flow_error_names = validate_flow_with_gt(val_flow_loader, flow_net, epoch, logger, output_writers)

            error_string = ', '.join('{} : {:.3f}'.format(name, error) for name, error in zip(flow_error_names, flow_errors))

            if args.log_terminal:
                logger.valid_writer.write(' * Avg {}'.format(error_string))
            else:
                print('Epoch {} completed'.format(epoch))

            for error, name in zip(flow_errors, flow_error_names):
                training_writer.add_scalar(name, error, epoch)

        
        # flow_errors is only defined when args.with_flow_gt is set
        decisive_error = flow_errors[0]
        if best_error < 0:
            best_error = decisive_error

        # remember lowest error and save checkpoint
        is_best = decisive_error <= best_error
        best_error = min(best_error, decisive_error)
        save_checkpoint(
            args.save_path, {
                'epoch': epoch + 1,
                'state_dict': flow_net.module.state_dict()
            }, {
                'epoch': epoch + 1,
                'state_dict': optimizer.state_dict()
            },
            is_best)

        with open(args.save_path/args.log_summary, 'a') as csvfile:
            writer = csv.writer(csvfile, delimiter='\t')
            writer.writerow([train_loss, decisive_error])
    if args.log_terminal:
        logger.epoch_bar.finish()
Example #17
def main():
    global global_vars_dict
    args = global_vars_dict['args']
    best_error = -1  # for best-model selection

    #mkdir
    timestamp = datetime.datetime.now().strftime("%m-%d-%H:%M")

    args.save_path = Path('checkpoints') / Path(args.data_dir).stem / timestamp
    print('=> will save everything to {}'.format(args.save_path))
    args.save_path.makedirs_p()
    torch.manual_seed(args.seed)
    if args.alternating:
        args.alternating_flags = np.array([False, False, True])
    # create summary writers
    tb_writer = SummaryWriter(args.save_path)

    # Data loading code
    flow_loader_h, flow_loader_w = 256, 832

    if args.data_normalization == 'global':
        normalize = custom_transforms.Normalize(mean=[0.5, 0.5, 0.5],
                                                std=[0.5, 0.5, 0.5])
    elif args.data_normalization == 'local':
        normalize = custom_transforms.NormalizeLocally()

    if args.fix_flownet:
        train_transform = custom_transforms.Compose([
            custom_transforms.RandomHorizontalFlip(),
            custom_transforms.RandomScaleCrop(),
            custom_transforms.ArrayToTensor(), normalize
        ])
    else:
        train_transform = custom_transforms.Compose([
            custom_transforms.RandomRotate(),
            custom_transforms.RandomHorizontalFlip(),
            custom_transforms.RandomScaleCrop(),
            custom_transforms.ArrayToTensor(), normalize
        ])

    valid_transform = custom_transforms.Compose(
        [custom_transforms.ArrayToTensor(), normalize])

    valid_flow_transform = custom_transforms.Compose([
        custom_transforms.Scale(h=flow_loader_h, w=flow_loader_w),
        custom_transforms.ArrayToTensor(), normalize
    ])

    print("=> fetching scenes in '{}'".format(args.data_dir))

    # train set / loader: only one is created
    if args.dataset_format == 'stacked':
        from datasets.stacked_sequence_folders import SequenceFolder
    elif args.dataset_format == 'sequential':
        from datasets.sequence_folders import SequenceFolder
        train_set = SequenceFolder(  #mc data folder
            args.data_dir,
            transform=train_transform,
            seed=args.seed,
            train=True,
            sequence_length=args.sequence_length,  #5
            target_transform=None)
    elif args.dataset_format == 'sequential_with_gt':  # with all possible gt
        from datasets.sequence_mc import SequenceFolder
        train_set = SequenceFolder(  # mc data folder
            args.data_dir,
            transform=train_transform,
            seed=args.seed,
            train=True,
            sequence_length=args.sequence_length,  # 5
            target_transform=None)
    else:
        return

    if args.DEBUG:
        train_set.__len__ = 32
        train_set.samples = train_set.samples[:32]
    train_loader = torch.utils.data.DataLoader(train_set,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               drop_last=True)

    if args.epoch_size == 0:
        args.epoch_size = len(train_loader)

#val set / loader: created one by one

# if no ground truth is available, the validation set is the same type as the training set, to measure photometric loss from warping
    if args.val_without_gt:
        from datasets.sequence_folders2 import SequenceFolder  # just one extra directory level
        val_set_without_gt = SequenceFolder(  # images only
            args.data_dir,
            transform=valid_transform,
            seed=None,
            train=False,
            sequence_length=args.sequence_length,
            target_transform=None)
        val_loader = torch.utils.data.DataLoader(val_set_without_gt,
                                                 batch_size=args.batch_size,
                                                 shuffle=False,
                                                 num_workers=args.workers,
                                                 pin_memory=True,
                                                 drop_last=True)

    if args.val_with_depth_gt:
        from datasets.validation_folders2 import ValidationSet

        val_set_with_depth_gt = ValidationSet(args.data_dir,
                                              transform=valid_transform)

        val_loader_depth = torch.utils.data.DataLoader(
            val_set_with_depth_gt,
            batch_size=args.batch_size,
            shuffle=False,
            num_workers=args.workers,
            pin_memory=True,
            drop_last=True)

    if args.val_with_flow_gt:  # not available yet
        from datasets.validation_flow import ValidationFlow
        val_flow_set = ValidationFlow(root=args.kitti_dir,
                                      sequence_length=args.sequence_length,
                                      transform=valid_flow_transform)
        val_flow_loader = torch.utils.data.DataLoader(
            val_flow_set,
            batch_size=1,
            # batch size is 1 since images in kitti have different sizes
            shuffle=False,
            num_workers=args.workers,
            pin_memory=True,
            drop_last=True)

    print('{} samples found in {} train scenes'.format(len(train_set),
                                                       len(train_set.scenes)))
    if args.val_without_gt:
        print('{} samples found in {} valid scenes'.format(
            len(val_set_without_gt), len(val_set_without_gt.scenes)))

#1 create model
    print("=> creating model")
    #1.1 disp_net
    disp_net = getattr(models, args.dispnet)().cuda()
    output_exp = True  #args.mask_loss_weight > 0
    if not output_exp:
        print("=> no mask loss, PoseExpnet will only output pose")
    #1.2 pose_net
    pose_net = getattr(models, args.posenet)(nb_ref_imgs=args.sequence_length -
                                             1).cuda()

    #1.3.flow_net
    if args.flownet == 'SpyNet':
        flow_net = getattr(models,
                           args.flownet)(nlevels=args.nlevels,
                                         pre_normalization=normalize).cuda()
    elif args.flownet == 'FlowNetC6':  #flonwtc6
        flow_net = getattr(models, args.flownet)(nlevels=args.nlevels).cuda()
    elif args.flownet == 'FlowNetS':
        flow_net = getattr(models, args.flownet)(nlevels=args.nlevels).cuda()
    elif args.flownet == 'Back2Future':
        flow_net = getattr(models, args.flownet)(nlevels=args.nlevels).cuda()

    # 1.4 mask_net
    mask_net = getattr(models,
                       args.masknet)(nb_ref_imgs=args.sequence_length - 1,
                                     output_exp=True).cuda()

    #2 load parameters
    #2.1 pose
    if args.pretrained_pose:
        print("=> using pre-trained weights for explainabilty and pose net")
        weights = torch.load(args.pretrained_pose)
        pose_net.load_state_dict(weights['state_dict'])
    else:
        pose_net.init_weights()

    if args.pretrained_mask:
        print("=> using pre-trained weights for explainabilty and pose net")
        weights = torch.load(args.pretrained_mask)
        mask_net.load_state_dict(weights['state_dict'])
    else:
        mask_net.init_weights()

    # import ipdb; ipdb.set_trace()
    if args.pretrained_disp:
        print("=> using pre-trained weights from {}".format(
            args.pretrained_disp))
        weights = torch.load(args.pretrained_disp)
        disp_net.load_state_dict(weights['state_dict'])
    else:
        disp_net.init_weights()

    if args.pretrained_flow:
        print("=> using pre-trained weights for FlowNet")
        weights = torch.load(args.pretrained_flow)
        flow_net.load_state_dict(weights['state_dict'])
    else:
        flow_net.init_weights()

    if args.resume:
        print("=> resuming from checkpoint")
        dispnet_weights = torch.load(args.save_path /
                                     'dispnet_checkpoint.pth.tar')
        posenet_weights = torch.load(args.save_path /
                                     'posenet_checkpoint.pth.tar')
        masknet_weights = torch.load(args.save_path /
                                     'masknet_checkpoint.pth.tar')
        flownet_weights = torch.load(args.save_path /
                                     'flownet_checkpoint.pth.tar')
        disp_net.load_state_dict(dispnet_weights['state_dict'])
        pose_net.load_state_dict(posenet_weights['state_dict'])
        flow_net.load_state_dict(flownet_weights['state_dict'])
        mask_net.load_state_dict(masknet_weights['state_dict'])

    # import ipdb; ipdb.set_trace()
    cudnn.benchmark = True
    disp_net = torch.nn.DataParallel(disp_net)
    pose_net = torch.nn.DataParallel(pose_net)
    mask_net = torch.nn.DataParallel(mask_net)
    flow_net = torch.nn.DataParallel(flow_net)

    print('=> setting adam solver')

    parameters = chain(disp_net.parameters(), pose_net.parameters(),
                       mask_net.parameters(), flow_net.parameters())
    optimizer = torch.optim.Adam(parameters,
                                 args.lr,
                                 betas=(args.momentum, args.beta),
                                 weight_decay=args.weight_decay)

    if args.resume and (args.save_path /
                        'optimizer_checkpoint.pth.tar').exists():
        print("=> loading optimizer from checkpoint")
        optimizer_weights = torch.load(args.save_path /
                                       'optimizer_checkpoint.pth.tar')
        optimizer.load_state_dict(optimizer_weights['state_dict'])

    with open(args.save_path / args.log_summary, 'w') as csvfile:
        writer = csv.writer(csvfile, delimiter='\t')
        writer.writerow(['train_loss', 'validation_loss'])

    with open(args.save_path / args.log_full, 'w') as csvfile:
        writer = csv.writer(csvfile, delimiter='\t')
        writer.writerow([
            'train_loss', 'photo_cam_loss', 'photo_flow_loss',
            'explainability_loss', 'smooth_loss'
        ])

    #
    if args.log_terminal:
        logger = TermLogger(n_epochs=args.epochs,
                            train_size=min(len(train_loader), args.epoch_size),
                            valid_size=len(val_loader_depth))
        logger.epoch_bar.start()
    else:
        logger = None

# evaluate once before entering the training loop

    if args.pretrained_disp or args.evaluate:
        logger.reset_valid_bar()
        if args.val_without_gt:
            pass
            #val_loss = validate_without_gt(val_loader,disp_net,pose_net,mask_net,flow_net,epoch=0, logger=logger, tb_writer=tb_writer,nb_writers=3,global_vars_dict = global_vars_dict)
            #val_loss =0

        if args.val_with_depth_gt:
            depth_errors, depth_error_names = validate_depth_with_gt(
                val_loader_depth,
                disp_net,
                epoch=0,
                logger=logger,
                tb_writer=tb_writer,
                global_vars_dict=global_vars_dict)


#3. main cycle
    for epoch in range(1, args.epochs):  # epoch 0 was already evaluated before the loop
        #3.1 decide which of the four sub-networks to train
        if args.fix_flownet:
            for fparams in flow_net.parameters():
                fparams.requires_grad = False

        if args.fix_masknet:
            for fparams in mask_net.parameters():
                fparams.requires_grad = False

        if args.fix_posenet:
            for fparams in pose_net.parameters():
                fparams.requires_grad = False

        if args.fix_dispnet:
            for fparams in disp_net.parameters():
                fparams.requires_grad = False

        if args.log_terminal:
            logger.epoch_bar.update(epoch)
            logger.reset_train_bar()
        #validation data
        flow_error_names = ['no']
        flow_errors = [0]
        errors = [0]
        error_names = ['no error names depth']
        print('\nepoch [{}/{}]\n'.format(epoch + 1, args.epochs))
        #3.2 train for one epoch---------
        #train_loss=0
        train_loss = train_gt(train_loader, disp_net, pose_net, mask_net,
                              flow_net, optimizer, logger, tb_writer,
                              global_vars_dict)

        #3.3 evaluate on validation set-----

        if args.val_without_gt:
            val_loss = validate_without_gt(val_loader,
                                           disp_net,
                                           pose_net,
                                           mask_net,
                                           flow_net,
                                           epoch=0,
                                           logger=logger,
                                           tb_writer=tb_writer,
                                           nb_writers=3,
                                           global_vars_dict=global_vars_dict)

        if args.val_with_depth_gt:
            depth_errors, depth_error_names = validate_depth_with_gt(
                val_loader_depth,
                disp_net,
                epoch=epoch,
                logger=logger,
                tb_writer=tb_writer,
                global_vars_dict=global_vars_dict)

        if args.val_with_flow_gt:
            pass
            #flow_errors, flow_error_names = validate_flow_with_gt(val_flow_loader, disp_net, pose_net, mask_net, flow_net, epoch, logger, tb_writer)

            #for error, name in zip(flow_errors, flow_error_names):
            #    training_writer.add_scalar(name, error, epoch)

        #----------------------

        #3.4 Up to you to choose the most relevant error to measure your model's performance; careful, some measures are to be maximized (such as a1, a2, a3)

        if not args.fix_posenet:
            decisive_error = 0  # flow_errors[-2]    # epe_rigid_with_gt_mask
        elif not args.fix_dispnet:
            decisive_error = 0  # errors[0]      #depth abs_diff
        elif not args.fix_flownet:
            decisive_error = 0  # flow_errors[-1]    #epe_non_rigid_with_gt_mask
        elif not args.fix_masknet:
            decisive_error = 0  #flow_errors[3]     # percent outliers

        #3.5 log
        if args.log_terminal:
            logger.train_writer.write(
                ' * Avg Loss : {:.3f}'.format(train_loss))
            logger.reset_valid_bar()
        #eopch data log on tensorboard
        #train loss
        tb_writer.add_scalar('epoch/train_loss', train_loss, epoch)
        #val_without_gt loss
        if args.val_without_gt:
            tb_writer.add_scalar('epoch/val_loss', val_loss, epoch)

        if args.val_with_depth_gt:
            #val with depth gt
            for error, name in zip(depth_errors, depth_error_names):
                tb_writer.add_scalar('epoch/' + name, error, epoch)

        #3.6 save model and remember lowest error and save checkpoint

        if best_error < 0:
            best_error = train_loss

        is_best = train_loss <= best_error
        best_error = min(best_error, train_loss)
        save_checkpoint(args.save_path, {
            'epoch': epoch + 1,
            'state_dict': disp_net.module.state_dict()
        }, {
            'epoch': epoch + 1,
            'state_dict': pose_net.module.state_dict()
        }, {
            'epoch': epoch + 1,
            'state_dict': mask_net.module.state_dict()
        }, {
            'epoch': epoch + 1,
            'state_dict': flow_net.module.state_dict()
        }, is_best)

        with open(args.save_path / args.log_summary, 'a') as csvfile:
            writer = csv.writer(csvfile, delimiter='\t')
            writer.writerow([train_loss, decisive_error])
    if args.log_terminal:
        logger.epoch_bar.finish()
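Example #17 freezes sub-networks by setting requires_grad = False on their parameters after the optimizer has been built. An alternative, more explicit pattern is to hand the optimizer only the trainable parameters (a sketch reusing Example #17's names; assumes from itertools import chain):

trainable = [p for p in chain(disp_net.parameters(), pose_net.parameters(),
                              mask_net.parameters(), flow_net.parameters())
             if p.requires_grad]
optimizer = torch.optim.Adam(trainable, args.lr,
                             betas=(args.momentum, args.beta),
                             weight_decay=args.weight_decay)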
Example #18
        return img, gt

    def get_img_size(self):
        img = cv2.imread(os.path.join(self.db_root_dir, self.img_list[0]))

        return list(img.shape[:2])


if __name__ == '__main__':
    import custom_transforms as tr
    import torch
    from torchvision import transforms
    from matplotlib import pyplot as plt

    # renamed to avoid shadowing the torchvision.transforms module
    composed_transforms = transforms.Compose([
        tr.RandomHorizontalFlip(),
        tr.Resize(scales=[0.5, 0.8, 1]),
        tr.ToTensor()
    ])

    dataset = DAVIS2016(
        db_root_dir='/media/eec/external/Databases/Segmentation/DAVIS-2016',
        train=True,
        transform=composed_transforms)
    dataloader = torch.utils.data.DataLoader(dataset,
                                             batch_size=1,
                                             shuffle=True,
                                             num_workers=1)

    for i, data in enumerate(dataloader):
        plt.figure()
Example #19
    def train(self):
        global n_iter

        if not self.train_flow:
            self.pose_net.train()
            self.disp_net.train()

        normalize = custom_transforms.Normalize(mean=[0.5, 0.5, 0.5],
                                                std=[0.5, 0.5, 0.5])
        train_transform = custom_transforms.Compose([
            custom_transforms.RandomHorizontalFlip(),
            custom_transforms.RandomScaleCrop(),
            custom_transforms.ArrayToTensor(), normalize
        ])

        valid_transform = custom_transforms.Compose(
            [custom_transforms.ArrayToTensor(), normalize])

        self.train_set = SequenceFolder(
            self.config['data'],
            transform=train_transform,
            split='train',
            seed=self.config['seed'],
            img_height=self.config['img_height'],
            img_width=self.config['img_width'],
            sequence_length=self.config['sequence_length'])

        self.val_set = SequenceFolder(
            self.config['data'],
            transform=valid_transform,
            split='val',
            seed=self.config['seed'],
            img_height=self.config['img_height'],
            img_width=self.config['img_width'],
            sequence_length=self.config['sequence_length'])

        self.train_loader = torch.utils.data.DataLoader(
            self.train_set,
            shuffle=True,
            drop_last=True,
            num_workers=self.config['data_workers'],
            batch_size=self.config['batch_size'],
            pin_memory=False)

        self.val_loader = torch.utils.data.DataLoader(
            self.val_set,
            shuffle=True,
            batch_size=self.config['batch_size'],
            drop_last=True,
            num_workers=self.config['data_workers'],
            pin_memory=False)

        optim_params = [{
            'params': v.parameters(),
            'lr': self.config['learning_rate']
        } for v in self.nets.values()]

        self.optimizer = torch.optim.Adam(
            optim_params,
            betas=(self.config['momentum'], self.config['beta']),
            weight_decay=self.config['weight_decay'])

        self.logger = TermLogger(n_epochs=self.config['epoch'],
                                 train_size=min(len(self.train_loader),
                                                self.config['epoch_size']),
                                 valid_size=len(self.val_loader))
        self.logger.epoch_bar.start()

        for epoch in range(self.epochs):
            self.logger.epoch_bar.update(epoch)
            self.logger.reset_train_bar()
            epoch_train_loss = self.training_inside_epoch()
            self.logger.train_writer.write(
                ' training * Avg Loss : {:.3f}'.format(epoch_train_loss))

            self.logger.reset_valid_bar()
            epoch_val_loss = self.validate_inside_epoch_without_gt()
            self.logger.valid_writer.write(
                ' validation * Avg Loss : {:.3f}'.format(epoch_val_loss))
Example #20
def main():
    global args, best_error, n_iter, device
    args = parser.parse_args()
    save_path = save_path_formatter(args, parser)
    args.save_path = 'checkpoints_shifted' / save_path
    print('=> will save everything to {}'.format(args.save_path))
    args.save_path.makedirs_p()
    torch.manual_seed(args.seed)

    training_writer = SummaryWriter(args.save_path)
    output_writers = []
    if args.log_output:
        for i in range(3):
            output_writers.append(
                SummaryWriter(args.save_path / 'valid' / str(i)))

    # Data loading code
    normalize = custom_transforms.Normalize(mean=[0.5, 0.5, 0.5],
                                            std=[0.5, 0.5, 0.5])
    train_transform = custom_transforms.Compose([
        custom_transforms.RandomHorizontalFlip(),
        custom_transforms.RandomScaleCrop(),
        custom_transforms.ArrayToTensor(), normalize
    ])

    valid_transform = custom_transforms.Compose(
        [custom_transforms.ArrayToTensor(), normalize])

    print("=> fetching scenes in '{}'".format(args.data))
    train_set = ShiftedSequenceFolder(
        args.data,
        transform=train_transform,
        seed=args.seed,
        train=True,
        sequence_length=args.sequence_length,
        target_displacement=args.target_displacement)

    # if no ground truth is available, the validation set is of the same type as the training set, so the photometric loss from warping can be measured
    if args.with_gt:
        from datasets.validation_folders import ValidationSet
        val_set = ValidationSet(args.data, transform=valid_transform)
    else:
        val_set = SequenceFolder(
            args.data,
            transform=valid_transform,
            seed=args.seed,
            train=False,
            sequence_length=args.sequence_length,
        )
    print('{} samples found in {} train scenes'.format(len(train_set),
                                                       len(train_set.scenes)))
    print('{} samples found in {} valid scenes'.format(len(val_set),
                                                       len(val_set.scenes)))
    train_loader = torch.utils.data.DataLoader(train_set,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True)
    adjust_loader = torch.utils.data.DataLoader(
        train_set,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=0,
        pin_memory=True
    )  # num_workers is set to 0 to avoid multiple instances being modified at the same time
    val_loader = torch.utils.data.DataLoader(val_set,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    if args.epoch_size == 0:
        args.epoch_size = len(train_loader)

    train.args = args
    # create model
    print("=> creating model")

    disp_net = models.DispNetS().cuda()
    output_exp = args.mask_loss_weight > 0
    if not output_exp:
        print("=> no mask loss, PoseExpnet will only output pose")
    pose_exp_net = models.PoseExpNet(
        nb_ref_imgs=args.sequence_length - 1,
        output_exp=args.mask_loss_weight > 0).to(device)

    if args.pretrained_exp_pose:
        print("=> using pre-trained weights for explainabilty and pose net")
        weights = torch.load(args.pretrained_exp_pose)
        pose_exp_net.load_state_dict(weights['state_dict'], strict=False)
    else:
        pose_exp_net.init_weights()

    if args.pretrained_disp:
        print("=> using pre-trained weights for Dispnet")
        weights = torch.load(args.pretrained_disp)
        disp_net.load_state_dict(weights['state_dict'])
    else:
        disp_net.init_weights()

    cudnn.benchmark = True
    disp_net = torch.nn.DataParallel(disp_net)
    pose_exp_net = torch.nn.DataParallel(pose_exp_net)

    print('=> setting adam solver')

    parameters = chain(disp_net.parameters(), pose_exp_net.parameters())
    optimizer = torch.optim.Adam(parameters,
                                 args.lr,
                                 betas=(args.momentum, args.beta),
                                 weight_decay=args.weight_decay)

    with open(args.save_path / args.log_summary, 'w') as csvfile:
        writer = csv.writer(csvfile, delimiter='\t')
        writer.writerow(['train_loss', 'validation_loss'])

    with open(args.save_path / args.log_full, 'w') as csvfile:
        writer = csv.writer(csvfile, delimiter='\t')
        writer.writerow(
            ['train_loss', 'photo_loss', 'explainability_loss', 'smooth_loss'])

    logger = TermLogger(n_epochs=args.epochs,
                        train_size=min(len(train_loader), args.epoch_size),
                        valid_size=len(val_loader))
    logger.epoch_bar.start()

    for epoch in range(args.epochs):
        logger.epoch_bar.update(epoch)

        # train for one epoch
        logger.reset_train_bar()
        train_loss = train(args, train_loader, disp_net, pose_exp_net,
                           optimizer, args.epoch_size, logger, training_writer)
        logger.train_writer.write(' * Avg Loss : {:.3f}'.format(train_loss))

        if (epoch + 1) % 5 == 0:
            train_set.adjust = True
            logger.reset_train_bar(len(adjust_loader))
            average_shifts = adjust_shifts(args, train_set, adjust_loader,
                                           pose_exp_net, epoch, logger,
                                           training_writer)
            shifts_string = ' '.join(
                ['{:.3f}'.format(s) for s in average_shifts])
            logger.train_writer.write(
                ' * adjusted shifts, average shifts are now : {}'.format(
                    shifts_string))
            for i, shift in enumerate(average_shifts):
                training_writer.add_scalar('shifts{}'.format(i), shift, epoch)
            train_set.adjust = False

        # evaluate on validation set
        logger.reset_valid_bar()
        if args.with_gt:
            errors, error_names = validate_with_gt(args, val_loader, disp_net,
                                                   epoch, logger,
                                                   output_writers)
        else:
            errors, error_names = validate_without_gt(args, val_loader,
                                                      disp_net, pose_exp_net,
                                                      epoch, logger,
                                                      output_writers)
        error_string = ', '.join('{} : {:.3f}'.format(name, error)
                                 for name, error in zip(error_names, errors))
        logger.valid_writer.write(' * Avg {}'.format(error_string))

        for error, name in zip(errors, error_names):
            training_writer.add_scalar(name, error, epoch)

        # It is up to you to choose the most relevant error to measure your model's performance; be careful, some measures should be maximized (such as a1, a2, a3)
        decisive_error = errors[0]
        if best_error < 0:
            best_error = decisive_error

        # remember lowest error and save checkpoint
        is_best = decisive_error < best_error
        best_error = min(best_error, decisive_error)
        save_checkpoint(args.save_path, {
            'epoch': epoch + 1,
            'state_dict': disp_net.module.state_dict()
        }, {
            'epoch': epoch + 1,
            'state_dict': pose_exp_net.module.state_dict()
        }, is_best)

        with open(args.save_path / args.log_summary, 'a') as csvfile:
            writer = csv.writer(csvfile, delimiter='\t')
            writer.writerow([train_loss, decisive_error])
    logger.epoch_bar.finish()
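
The best-model bookkeeping in the loop above starts best_error at a negative sentinel and adopts the first validation error as the baseline. The same logic in isolation, with toy error values standing in for a real validation step:

best_error = -1
for epoch, decisive_error in enumerate([0.31, 0.27, 0.29]):  # toy errors
    if best_error < 0:  # first epoch: adopt the first error as the baseline
        best_error = decisive_error
    is_best = decisive_error < best_error  # strictly better than the best so far
    best_error = min(best_error, decisive_error)
    print(epoch, decisive_error, is_best)  # 0 False, 1 True, 2 False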
Exemple #21
0
def main():
    global args, best_error, n_iter
    n_iter = 0
    best_error = 0

    args = parser.parse_args()
    '''
    args = ['--name', 'deemo', '--FCCMnet', 'PatchWiseNetwork', 
                        '--dataset_dir', '/notebooks/FCCM/pre-process/pre-process', '--label_dir', '/notebooks/FCCM', 
                        '--batch-size', '4', '--epochs','100', '--lr', '1e-4' ]
    '''
    save_path = Path(args.name)
    args.save_path = 'checkpoints'/save_path 
    print('=> will save everything to {}'.format(args.save_path))
    args.save_path.makedirs_p()

    train_writer = SummaryWriter(args.save_path)
    torch.manual_seed(args.seed)

    train_transform = custom_transforms.Compose([
          custom_transforms.RandomRotate(),
          custom_transforms.RandomHorizontalFlip(),
          custom_transforms.RandomScaleCrop(),
          custom_transforms.ArrayToTensor() ])

    train_set = Generate_train_set(
        root = args.dataset_dir,
        label_root = args.label_dir,
        transform=train_transform,
        seed=args.seed,
        train=True
    )

    val_set = Generate_val_set(
        root = args.dataset_dir,
        label_root = args.label_dir,
        seed=args.seed
    )

    print('{} samples found in {} train scenes'.format(len(train_set), len(train_set.scenes)))
    print('{} samples found in {} val scenes'.format(len(val_set), len(val_set.scenes)))
    
    train_loader = torch.utils.data.DataLoader(
        train_set, batch_size=args.batch_size, shuffle=True,
        num_workers=args.workers, pin_memory=True, drop_last=True)
    val_loader = torch.utils.data.DataLoader(
        val_set, batch_size=args.batch_size, shuffle=True,
        num_workers=args.workers, pin_memory=True, drop_last=True)

    print("=> creating model")
    if args.FCCMnet == 'VGG':
        FCCM_net = models_inpytorch.vgg16(num_classes=19)
        FCCM_net.features[0] = nn.Conv2d(1, 64, kernel_size=3, padding=1)
    elif args.FCCMnet == 'ResNet':
        FCCM_net = models.resnet18(num_classes=19)
    elif args.FCCMnet == 'CatNet':
        FCCM_net = models.catnet18(num_classes=19)
    elif args.FCCMnet == 'CatNet_FCCM':
        FCCM_net = models.catnet1(num_classes=19)
    elif args.FCCMnet == 'ResNet_FCCM':
        FCCM_net = models.resnet1(num_classes=19)
    elif args.FCCMnet == 'ImageWise':
        FCCM_net = models.ImageWiseNetwork()
    elif args.FCCMnet == 'PatchWise':
        FCCM_net = models.PatchWiseNetwork()
    elif args.FCCMnet == 'Baseline':
        FCCM_net = models.Baseline()

    FCCM_net = FCCM_net.cuda()

    if args.pretrained_model:
        print("=> using pre-trained weights for net")
        weights = torch.load(args.pretrained_model)
        FCCM_net.load_state_dict(weights['state_dict'])


    cudnn.benchmark = True
    FCCM_net = torch.nn.DataParallel(FCCM_net)

    print('=> setting adam solver')

    parameters = chain(FCCM_net.parameters())
    optimizer = torch.optim.Adam(parameters, args.lr,
                                 betas=(args.momentum, args.beta),
                                 weight_decay=args.weight_decay)


    is_best = False
    best_error = float("inf") 
    FCCM_net.train()
    loss = 0
    for epoch in tqdm(range(args.epochs)):
        is_best = loss <= best_error
        best_error = min(best_error, loss)

        save_checkpoint(
                args.save_path, {
                    'epoch': epoch + 1,
                    'state_dict': FCCM_net.state_dict()
                },  {                
                    'epoch': epoch + 1,
                    'state_dict': optimizer.state_dict()
                }, is_best)
        validation(val_loader, FCCM_net, epoch, train_writer)
        loss = train(train_loader, FCCM_net, optimizer, args.epoch_size, train_writer)
Exemple #22
0
        gt_1 = np.array(label_1, dtype=np.float32)
        gt_1 = gt_1 / np.max([gt_1.max(), 1e-8])
        gt_t = np.array(label_t, dtype=np.float32)
        gt_t = gt_t / np.max([gt_t.max(), 1e-8])
        # name = self.img_list[idx].split('/')[-1].split('.')[0]
        return img_1, img_t, gt_1, gt_t

if __name__ == '__main__':
    import custom_transforms as tr
    import torch
    from torchvision import transforms
    from matplotlib import pyplot as plt
    from models import net
    from torch import nn

    transforms = transforms.Compose([tr.RandomHorizontalFlip(), tr.ToTensor()])

    dataset = DAVIS_OVER_FIT_TEST1(db_root_dir='../../../DAVIS-2016',
                        train=True, transform=transforms)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True, num_workers=1)

    for i, data in enumerate(dataloader):
        plt.figure()
        # b = tens2image(data['image_1'])
        # plt.imshow(b)
        # print(data['image_1'][:, :, 3:15, 3:15])
        x = data['image_1']
        xt = data['image_t']
        lab = data['gt']
        # x[:, 0, :, :] = x[:, 0, :, :]*lab
        # x[:, 1, :, :] = x[:, 1, :, :]*lab
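
The __getitem__ fragment at the top of this example rescales each label map to [0, 1], with a 1e-8 floor so an all-zero mask does not trigger a division by zero. The same idiom in isolation:

import numpy as np

label_1 = [[0, 128], [255, 0]]  # toy label map
gt_1 = np.array(label_1, dtype=np.float32)
gt_1 = gt_1 / np.max([gt_1.max(), 1e-8])  # safe even for an empty mask
print(gt_1.max())  # 1.0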
Exemple #23
0
def main():
    global global_vars_dict
    args = global_vars_dict['args']
    best_error = -1  # sentinel for best-model selection

    #mkdir
    timestamp = datetime.datetime.now().strftime("%m-%d-%H:%M")

    args.save_path = Path('checkpoints') / Path(args.data_dir).stem / timestamp

    args.save_path.makedirs_p()
    torch.manual_seed(args.seed)
    if args.alternating:
        args.alternating_flags = np.array([False, False, True])
    #mk writers
    tb_writer = SummaryWriter(args.save_path)

    # Data loading code and transforms

    if args.data_normalization == 'global':
        normalize = custom_transforms.Normalize(mean=[0.5, 0.5, 0.5],
                                                std=[0.5, 0.5, 0.5])
    elif args.data_normalization == 'local':
        normalize = custom_transforms.NormalizeLocally()

    valid_transform = custom_transforms.Compose(
        [custom_transforms.ArrayToTensor(), normalize])

    print("=> fetching scenes in '{}'".format(args.data_dir))

    train_transform = custom_transforms.Compose([
        #custom_transforms.RandomRotate(),
        custom_transforms.RandomHorizontalFlip(),
        custom_transforms.RandomScaleCrop(),
        custom_transforms.ArrayToTensor(),
        normalize
    ])

    # train set / loader: only one is created
    from datasets.sequence_mc import SequenceFolder
    train_set = SequenceFolder(  # mc data folder
        args.data_dir,
        transform=train_transform,
        seed=args.seed,
        train=True,
        sequence_length=args.sequence_length,  # 5
        target_transform=None,
        depth_format='png')

    train_loader = torch.utils.data.DataLoader(train_set,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               drop_last=True)

    if args.epoch_size == 0:
        args.epoch_size = len(train_loader)

    # val sets / loaders: created one at a time
    # if args.val_with_depth_gt:
    from datasets.validation_folders2 import ValidationSet

    val_set_with_depth_gt = ValidationSet(args.data_dir,
                                          transform=valid_transform,
                                          depth_format='png')

    val_loader_depth = torch.utils.data.DataLoader(val_set_with_depth_gt,
                                                   batch_size=args.batch_size,
                                                   shuffle=False,
                                                   num_workers=args.workers,
                                                   pin_memory=True,
                                                   drop_last=True)

    print('{} samples found in {} train scenes'.format(len(train_set),
                                                       len(train_set.scenes)))

    #1 create model
    print("=> creating model")
    #1.1 disp_net
    disp_net = getattr(models, args.dispnet)().cuda()
    output_exp = True  #args.mask_loss_weight > 0

    if args.pretrained_disp:
        print("=> using pre-trained weights from {}".format(
            args.pretrained_disp))
        weights = torch.load(args.pretrained_disp)
        disp_net.load_state_dict(weights['state_dict'])
    else:
        disp_net.init_weights()

    if args.resume:
        print("=> resuming from checkpoint")
        dispnet_weights = torch.load(args.save_path /
                                     'dispnet_checkpoint.pth.tar')
        disp_net.load_state_dict(dispnet_weights['state_dict'])

    cudnn.benchmark = True
    disp_net = torch.nn.DataParallel(disp_net)

    print('=> setting adam solver')

    parameters = chain(disp_net.parameters())

    optimizer = torch.optim.Adam(parameters,
                                 args.lr,
                                 betas=(args.momentum, args.beta),
                                 weight_decay=args.weight_decay)

    if args.resume and (args.save_path /
                        'optimizer_checkpoint.pth.tar').exists():
        print("=> loading optimizer from checkpoint")
        optimizer_weights = torch.load(args.save_path /
                                       'optimizer_checkpoint.pth.tar')
        optimizer.load_state_dict(optimizer_weights['state_dict'])

    #
    if args.log_terminal:
        logger = TermLogger(n_epochs=args.epochs,
                            train_size=min(len(train_loader), args.epoch_size),
                            valid_size=len(val_loader_depth))
        logger.reset_epoch_bar()
    else:
        logger = None


    # evaluate once beforehand
    criterion_train = MaskedL1Loss().to(device)  # L1 loss is easy to optimize
    criterion_val = ComputeErrors().to(device)

    #depth_error_names,depth_errors = validate_depth_with_gt(val_loader_depth, disp_net,criterion=criterion_val, epoch=0, logger=logger,tb_writer=tb_writer,global_vars_dict=global_vars_dict)

    #logger.reset_epoch_bar()
    #    logger.epoch_logger_update(epoch=0,time=0,names=depth_error_names,values=depth_errors)
    epoch_time = AverageMeter()
    end = time.time()
    #3. main cycle
    for epoch in range(1, args.epochs):  # epoch 0 was already evaluated before entering the loop

        if args.log_terminal:
            logger.reset_train_bar()
            logger.reset_valid_bar()

        errors = [0]
        error_names = ['no error names depth']

        #3.2 train for one epoch---------
        loss_names, losses = train_depth_gt(train_loader=train_loader,
                                            disp_net=disp_net,
                                            optimizer=optimizer,
                                            criterion=criterion_train,
                                            logger=logger,
                                            train_writer=tb_writer,
                                            global_vars_dict=global_vars_dict)

        #3.3 evaluate on validation set-----
        depth_error_names, depth_errors = validate_depth_with_gt(
            val_loader=val_loader_depth,
            disp_net=disp_net,
            criterion=criterion_val,
            epoch=epoch,
            logger=logger,
            tb_writer=tb_writer,
            global_vars_dict=global_vars_dict)

        epoch_time.update(time.time() - end)
        end = time.time()

        #3.5 log_terminal
        if args.log_terminal:
            logger.epoch_logger_update(epoch=epoch,
                                       time=epoch_time,
                                       names=depth_error_names,
                                       values=depth_errors)

        # tensorboard scalars
        # train loss
        for loss_name, loss in zip(loss_names, losses.avg):
            tb_writer.add_scalar('train/' + loss_name, loss, epoch)

        #val_with_gt loss
        for name, error in zip(depth_error_names, depth_errors.avg):
            tb_writer.add_scalar('val/' + name, error, epoch)

        #3.6 save model and remember lowest error and save checkpoint
        total_loss = losses.avg[0]
        if best_error < 0:
            best_error = total_loss

        is_best = total_loss <= best_error
        best_error = min(best_error, total_loss)
        save_checkpoint(args.save_path, {
            'epoch': epoch + 1,
            'state_dict': disp_net.module.state_dict()
        }, {
            'epoch': epoch + 1,
            'state_dict': None
        }, {
            'epoch': epoch + 1,
            'state_dict': None
        }, {
            'epoch': epoch + 1,
            'state_dict': None
        }, is_best)

    if args.log_terminal:
        logger.epoch_bar.finish()
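
The resume branch above reloads the DispNet weights from a *_checkpoint.pth.tar file under save_path, and the optimizer state when it exists. A minimal sketch of that pattern, assuming the path.py-style Path objects and the {'state_dict': ...} checkpoint layout these examples use throughout:

import torch

def resume_from(save_path, disp_net, optimizer):
    # network weights are stored under the 'state_dict' key
    ckpt = save_path / 'dispnet_checkpoint.pth.tar'
    disp_net.load_state_dict(torch.load(ckpt)['state_dict'])
    # the optimizer state is optional and restored only when present
    optim_ckpt = save_path / 'optimizer_checkpoint.pth.tar'
    if optim_ckpt.exists():
        optimizer.load_state_dict(torch.load(optim_ckpt)['state_dict'])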
Exemple #24
0
def main():
    global opt, best_prec1
    opt = parser.parse_args()
    print(opt)

    # Data loading
    train_transform = custom_transforms.Compose([
        custom_transforms.RandomHorizontalFlip(),
        custom_transforms.RandomScaleCrop(),
        custom_transforms.ArrayToTensor(),
        custom_transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    ])
    valid_transform = custom_transforms.Compose([
        custom_transforms.ArrayToTensor(),
        custom_transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    ])
    print('Loading scenes in', opt.data_dir)
    train_set = SequenceFolder(opt.data_dir,
                               transform=train_transform,
                               seed=opt.seed,
                               train=True,
                               sequence_length=opt.sequence_length)

    val_set = ValidationSet(opt.data_dir, transform=valid_transform)

    print(len(train_set), 'train samples found')
    print(len(val_set), 'val samples found')

    train_loader = torch.utils.data.DataLoader(train_set,
                                               batch_size=opt.batch_size,
                                               shuffle=True,
                                               num_workers=opt.workers,
                                               pin_memory=True)
    # val_loader = torch.utils.data.DataLoader(val_set, batch_size=opt.batch_size,
    #                                             shuffle=False, num_workers=opt.workers,
    #                                             pin_memory=True)
    if opt.epoch_size == 0:
        opt.epoch_size = len(train_loader)
    # Done loading

    disp_model = dispnet.DispNet().cuda()
    pose_model = posenet.PoseNet().cuda()
    disp_model, pose_model, optimizer = init.setup(disp_model, pose_model, opt)
    print(disp_model, pose_model)
    trainer = train.Trainer(disp_model, pose_model, optimizer, opt)
    if opt.resume:
        if os.path.isfile(opt.resume):
            # disp_model, pose_model, optimizer, opt, best_prec1 = init.resumer(opt, disp_model, pose_model, optimizer)
            disp_model, pose_model, optimizer, opt = init.resumer(
                opt, disp_model, pose_model, optimizer)
        else:
            print("=> no checkpoint found at '{}'".format(opt.resume))

    cudnn.benchmark = True
    for epoch in range(opt.start_epoch, opt.epochs):
        utils.adjust_learning_rate(opt, optimizer, epoch)
        print("Starting epoch number:", epoch + 1, "Learning rate:",
              optimizer.param_groups[0]["lr"])
        if not opt.testOnly:
            trainer.train(train_loader, epoch, opt)
        # init.save_checkpoint(opt, disp_model, pose_model, optimizer, best_prec1, epoch)
        init.save_checkpoint(opt, disp_model, pose_model, optimizer, epoch)
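
utils.adjust_learning_rate is called above but not shown. A common implementation is a step decay applied through optimizer.param_groups; the sketch below assumes such a schedule and is not the original helper:

def adjust_learning_rate(opt, optimizer, epoch, decay_every=30, gamma=0.1):
    # step decay: scale the base lr by gamma every `decay_every` epochs
    lr = opt.lr * (gamma ** (epoch // decay_every))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr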
Exemple #25
0
from torchvision import datasets, transforms
import torch.utils.data as data
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from PIL import Image
import os
import custom_transforms as trans

img_transform = transforms.Compose([
    trans.RandomHorizontalFlip(),
    trans.RandomGaussianBlur(),
    trans.RandomScaleCrop(700, 512),
    trans.Normalize(),
    trans.ToTensor()
])


class TrainImageFolder(data.Dataset):
    def __init__(self, data_dir):
        self.f = open(os.path.join(data_dir, 'train_id.txt'))
        self.file_list = self.f.readlines()
        self.data_dir = data_dir

    def __getitem__(self, index):
        img = Image.open(
            os.path.join(self.data_dir, 'train_images',
                         self.file_list[index][:-1] + '.jpg')).convert('RGB')
        parse = Image.open(
            os.path.join(self.data_dir, 'train_segmentations',
Exemple #26
0
def main():
    global best_error, n_iter, device
    args = parser.parse_args()
    if args.dataset_format == 'stacked':
        from datasets.stacked_sequence_folders import SequenceFolder
    elif args.dataset_format == 'sequential':
        from datasets.sequence_folders import SequenceFolder, StereoSequenceFolder
    save_path = save_path_formatter(args, parser)
    args.save_path = 'checkpoints'/save_path
    print('=> will save everything to {}'.format(args.save_path))
    args.save_path.makedirs_p()
    torch.manual_seed(args.seed)
    if args.evaluate:
        args.epochs = 0

    training_writer = SummaryWriter(args.save_path)
    output_writers = []
    if args.log_output:
        for i in range(3):
            output_writers.append(SummaryWriter(args.save_path/'valid'/str(i)))

    # Data loading code
    normalize = custom_transforms.Normalize(mean=[0.5, 0.5, 0.5],
                                            std=[0.5, 0.5, 0.5])
    train_transform = custom_transforms.Compose([
        custom_transforms.RandomHorizontalFlip(),
        custom_transforms.RandomScaleCrop(),
        custom_transforms.ArrayToTensor(),
        normalize
    ])

    valid_transform = custom_transforms.Compose([custom_transforms.ArrayToTensor(), normalize])

    print("=> fetching scenes in '{}'".format(args.data))
    train_set = StereoSequenceFolder(
        args.data,
        transform=train_transform,
        seed=args.seed,
        train=True,
        sequence_length=args.sequence_length
    )

    # if no ground truth is available, the validation set is of the same type as the training set, so the photometric loss from warping can be measured
    if args.with_gt:
        from datasets.validation_folders import ValidationSet
        val_set = ValidationSet(
            args.data,
            transform=valid_transform
        )
    else:
        val_set = StereoSequenceFolder(
            args.data,
            transform=valid_transform,
            seed=args.seed,
            train=False,
            sequence_length=args.sequence_length,
        )
    print('{} samples found in {} train scenes'.format(len(train_set), len(train_set.scenes)))
    print('{} samples found in {} valid scenes'.format(len(val_set), len(val_set.scenes)))
    train_loader = torch.utils.data.DataLoader(
        train_set, batch_size=args.batch_size, shuffle=True,
        num_workers=args.workers, pin_memory=True)
    val_loader = torch.utils.data.DataLoader(
        val_set, batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True)

    # when epoch_size is 0, each epoch trains on all samples in train_set
    # when epoch_size is set, each epoch trains on only part of train_set
    if args.epoch_size == 0:
        args.epoch_size = len(train_loader)

    # create model
    # initialize the network structures
    print("=> creating model")

    # disp_net = models.DispNetS().to(device)
    disp_net = models.DispResNet(3).to(device)
    output_exp = args.mask_loss_weight > 0
    if not output_exp:
        print("=> no mask loss, PoseExpnet will only output pose")
    # 如果有mask loss,PoseExpNet 要输出mask和pose estimation,因为两个输出共享encoder网络
    # pose_exp_net = PoseExpNet(nb_ref_imgs=args.sequence_length - 1, output_exp=args.mask_loss_weight > 0).to(device)
    pose_exp_net = models.PoseExpNet(nb_ref_imgs=args.sequence_length - 1, output_exp=args.mask_loss_weight > 0).to(device)

    if args.pretrained_exp_pose:
        print("=> using pre-trained weights for explainabilty and pose net")
        weights = torch.load(args.pretrained_exp_pose)
        pose_exp_net.load_state_dict(weights['state_dict'], strict=False)
    else:
        pose_exp_net.init_weights()

    if args.pretrained_disp:
        print("=> using pre-trained weights for Dispnet")
        weights = torch.load(args.pretrained_disp)
        disp_net.load_state_dict(weights['state_dict'])
    else:
        disp_net.init_weights()

    cudnn.benchmark = True
    # run the networks in parallel across GPUs
    disp_net = torch.nn.DataParallel(disp_net)
    pose_exp_net = torch.nn.DataParallel(pose_exp_net)

    # optimizer: Adam
    print('=> setting adam solver')
    # optimize both networks together
    optim_params = [
        {'params': disp_net.parameters(), 'lr': args.lr},
        {'params': pose_exp_net.parameters(), 'lr': args.lr}
    ]
    optimizer = torch.optim.Adam(optim_params,
                                 betas=(args.momentum, args.beta),
                                 weight_decay=args.weight_decay)

    with open(args.save_path/args.log_summary, 'w') as csvfile:
        writer = csv.writer(csvfile, delimiter='\t')
        writer.writerow(['train_loss', 'validation_loss'])

    with open(args.save_path/args.log_full, 'w') as csvfile:
        writer = csv.writer(csvfile, delimiter='\t')
        writer.writerow(['train_loss', 'photo_loss', 'explainability_loss', 'smooth_loss'])

    # evaluate the pretrained model first
    if args.pretrained_disp or args.evaluate:
        if args.with_gt:
            errors, error_names = validate_with_gt(args, val_loader, disp_net, 0, output_writers)
        else:
            errors, error_names = validate_without_gt(args, val_loader, disp_net, pose_exp_net, 0, output_writers)
        for error, name in zip(errors, error_names):
            training_writer.add_scalar(name, error, 0)
        error_string = ', '.join('{} : {:.3f}'.format(name, error) for name, error in zip(error_names[2:9], errors[2:9]))

    # main training loop
    for epoch in range(args.epochs):

        # train for one epoch
        print('\n')
        train_loss = train(args, train_loader, disp_net, pose_exp_net, optimizer, args.epoch_size, training_writer, epoch)

        # evaluate on validation set
        print('\n')
        if args.with_gt:
            errors, error_names = validate_with_gt(args, val_loader, disp_net, epoch, output_writers)
        else:
            errors, error_names = validate_without_gt(args, val_loader, disp_net, pose_exp_net, epoch, output_writers)
        error_string = ', '.join('{} : {:.3f}'.format(name, error) for name, error in zip(error_names, errors))

        for error, name in zip(errors, error_names):
            training_writer.add_scalar(name, error, epoch)

        # It is up to you to choose the most relevant error to measure your model's performance; be careful, some measures should be maximized (such as a1, a2, a3)
        # validation outputs several losses: the overall final loss, the warping loss, and the mask regularization loss
        # you can choose which of them serves as the best-model criterion
        decisive_error = errors[0]
        if best_error < 0:
            best_error = decisive_error

        # remember lowest error and save checkpoint
        # save the model that performs best on validation
        is_best = decisive_error < best_error
        best_error = min(best_error, decisive_error)
        save_checkpoint(
            args.save_path, {
                'epoch': epoch + 1,
                'state_dict': disp_net.module.state_dict()
            }, {
                'epoch': epoch + 1,
                'state_dict': pose_exp_net.module.state_dict()
            },
            is_best)

        with open(args.save_path/args.log_summary, 'a') as csvfile:
            writer = csv.writer(csvfile, delimiter='\t')
            writer.writerow([train_loss, decisive_error])
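
The two log files above are plain tab-separated tables: the header row is written once in 'w' mode before training, then one row per epoch is appended in 'a' mode. The pattern in isolation (the file name here is illustrative):

import csv

with open('log_summary.csv', 'w') as csvfile:
    writer = csv.writer(csvfile, delimiter='\t')
    writer.writerow(['train_loss', 'validation_loss'])

# once per epoch
with open('log_summary.csv', 'a') as csvfile:
    writer = csv.writer(csvfile, delimiter='\t')
    writer.writerow([0.153, 0.171])  # toy values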
Exemple #27
0
        # seq_name = self.video_list[idx].split('/')[2]
        return img, gt

    def get_img_size(self):
        img = cv2.imread(os.path.join(self.db_root_dir, self.img_list[0]))

        return list(img.shape[:2])


if __name__ == '__main__':
    import custom_transforms as tr
    import torch
    from torchvision import transforms
    from matplotlib import pyplot as plt
    from dataloader.helpers import overlay_mask, im_normalize, tens2image

    transforms = transforms.Compose([tr.RandomHorizontalFlip(), tr.Resize(scales=[0.5, 0.8, 1]), tr.ToTensor()])

    dataset = DAVIS2016(db_root_dir='/home/ty/data/davis',
                        train=True, transform=None)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=5, shuffle=True, num_workers=1)

    for i, data in enumerate(dataloader):
        # plt.figure()
        # plt.imshow(overlay_mask(im_normalize(tens2image(data['image'])), tens2image(data['gt'])))
        # if i == 10:
        #     break
        print(data['img'].size())
        print(data['img_gt'].size())

    # plt.show(block=True)
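
get_img_size above relies on OpenCV images being numpy arrays of shape (height, width, channels), so shape[:2] yields (H, W). A stand-in array makes the convention explicit:

import numpy as np

img = np.zeros((480, 854, 3), dtype=np.uint8)  # shaped like cv2.imread output
print(list(img.shape[:2]))  # [480, 854] -> [height, width]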
Exemple #28
0
def main():
    global args, best_error, n_iter
    args = parser.parse_args()
    save_path = Path(args.name)
    args.save_path = 'checkpoints' / save_path
    print('=> will save everything to {}'.format(args.save_path))
    args.save_path.makedirs_p()
    torch.manual_seed(args.seed)

    normalize = custom_transforms.Normalize(mean=[0.5, 0.5, 0.5],
                                            std=[0.5, 0.5, 0.5])
    train_transform = custom_transforms.Compose([
        custom_transforms.RandomRotate(),
        custom_transforms.RandomHorizontalFlip(),
        custom_transforms.RandomScaleCrop(),
        custom_transforms.ArrayToTensor(), normalize
    ])
    training_writer = SummaryWriter(args.save_path)

    intrinsics = np.array(
        [542.822841, 0, 315.593520, 0, 542.576870, 237.756098, 0, 0,
         1]).astype(np.float32).reshape((3, 3))
    train_set = SequenceFolder(root=args.dataset_dir,
                               intrinsics=intrinsics,
                               transform=train_transform,
                               seed=args.seed,
                               train=True,
                               sequence_length=args.sequence_length)

    print('{} samples found in {} train scenes'.format(len(train_set),
                                                       len(train_set.scenes)))
    train_loader = torch.utils.data.DataLoader(train_set,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               drop_last=True)

    print("=> creating model")
    mask_net = MaskNet6.MaskNet6().cuda()
    flow_net = back2future.Model(nlevels=args.nlevels).cuda()
    pose_net = PoseNetB6.PoseNetB6().cuda()

    if args.pretrained_mask:
        print("=> using pre-trained weights for explainabilty and pose net")
        weights = torch.load(args.pretrained_mask)
        mask_net.load_state_dict(weights['state_dict'])
    else:
        mask_net.init_weights()
    mask_net = torch.nn.DataParallel(mask_net)

    if args.pretrained_pose:
        print("=> using pre-trained weights for explainabilty and pose net")
        weights = torch.load(args.pretrained_pose)
        pose_net.load_state_dict(weights['state_dict'])
    else:
        pose_net.init_weights()

    if args.pretrained_flow:
        print("=> using pre-trained weights for explainabilty and pose net")
        weights = torch.load(args.pretrained_flow)
        flow_net.load_state_dict(weights['state_dict'])
    else:
        flow_net.init_weights()

    print('=> setting adam solver')
    parameters = chain(mask_net.parameters(), pose_net.parameters(),
                       flow_net.parameters())
    optimizer = torch.optim.Adam(parameters,
                                 args.lr,
                                 betas=(args.momentum, args.beta),
                                 weight_decay=args.weight_decay)

    # training
    best_error = 0
    train_loss = 0
    for epoch in tqdm(range(args.epochs)):
        if args.fix_flownet:
            for fparams in flow_net.parameters():
                fparams.requires_grad = False

        if args.fix_masknet:
            for fparams in mask_net.parameters():
                fparams.requires_grad = False

        if args.fix_posenet:
            for fparams in pose_net.parameters():
                fparams.requires_grad = False

        is_best = train_loss < best_error
        best_error = min(best_error, train_loss)
        save_checkpoint(args.save_path, {
            'epoch': epoch + 1,
            'state_dict': mask_net.module.state_dict()
        }, {
            'epoch': epoch + 1,
            'state_dict': pose_net.state_dict()
        }, {
            'epoch': epoch + 1,
            'state_dict': flow_net.state_dict()
        }, {
            'epoch': epoch + 1,
            'state_dict': optimizer.state_dict()
        }, is_best)
        train_loss = train(train_loader, mask_net, pose_net, flow_net,
                           optimizer, args.epoch_size, training_writer)
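
The fix_* branches above freeze a sub-network by switching off requires_grad on its parameters; no gradients are then accumulated for those tensors, so the shared optimizer leaves them untouched. A minimal sketch:

import torch.nn as nn

flow_net = nn.Linear(8, 8)  # stand-in for the real flow network
for fparams in flow_net.parameters():
    fparams.requires_grad = False
print(all(not p.requires_grad for p in flow_net.parameters()))  # True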
Exemple #29
0
def main():
    global args, best_error, n_iter
    args = parser.parse_args()
    save_path = Path(args.name)
    args.save_path = 'checkpoints' / save_path  #/timestamp
    print('=> will save everything to {}'.format(args.save_path))
    args.save_path.makedirs_p()
    torch.manual_seed(args.seed)

    training_writer = SummaryWriter(args.save_path)
    output_writer = SummaryWriter(args.save_path / 'valid')

    # Data loading code
    flow_loader_h, flow_loader_w = 384, 1280

    train_transform = custom_transforms.Compose([
        custom_transforms.RandomHorizontalFlip(),
        custom_transforms.RandomScaleCrop(h=256, w=256),
        custom_transforms.ArrayToTensor(),
    ])

    valid_transform = custom_transforms.Compose([
        custom_transforms.Scale(h=flow_loader_h, w=flow_loader_w),
        custom_transforms.ArrayToTensor()
    ])

    print("=> fetching scenes in '{}'".format(args.data))
    train_set = SequenceFolder(args.data,
                               transform=train_transform,
                               seed=args.seed,
                               train=True,
                               sequence_length=3)

    if args.valset == "kitti2015":
        from datasets.validation_flow import ValidationFlowKitti2015
        val_set = ValidationFlowKitti2015(root=args.kitti_data,
                                          transform=valid_transform)
    elif args.valset == "kitti2012":
        from datasets.validation_flow import ValidationFlowKitti2012
        val_set = ValidationFlowKitti2012(root=args.kitti_data,
                                          transform=valid_transform)

    if args.DEBUG:
        train_set.__len__ = 32
        train_set.samples = train_set.samples[:32]

    print('{} samples found in {} train scenes'.format(len(train_set),
                                                       len(train_set.scenes)))
    print('{} samples found in valid scenes'.format(len(val_set)))
    train_loader = torch.utils.data.DataLoader(train_set,
                                               batch_size=1,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               drop_last=True)

    val_loader = torch.utils.data.DataLoader(
        val_set,
        batch_size=1,  # batch size is 1 since images in KITTI have different sizes
        shuffle=False,
        num_workers=args.workers,
        pin_memory=True,
        drop_last=True)

    if args.epoch_size == 0:
        args.epoch_size = len(train_loader)

    # create model
    print("=> creating model")

    if args.flownet == 'SpyNet':
        flow_net = getattr(models, args.flownet)(nlevels=6, pretrained=True)
    elif args.flownet == 'Back2Future':
        flow_net = getattr(
            models, args.flownet)(pretrained='pretrained/b2f_rm_hard.pth.tar')
    elif args.flownet == 'PWCNet':
        flow_net = models.pwc_dc_net(
            'pretrained/pwc_net_chairs.pth.tar')  # alternative: pwc_net.pth.tar
    else:
        flow_net = getattr(models, args.flownet)()

    if args.flownet in ['SpyNet', 'Back2Future', 'PWCNet']:
        print("=> using pre-trained weights for " + args.flownet)
    elif args.flownet in ['FlowNetC']:
        print("=> using pre-trained weights for FlowNetC")
        weights = torch.load('pretrained/FlowNet2-C_checkpoint.pth.tar')
        flow_net.load_state_dict(weights['state_dict'])
    elif args.flownet in ['FlowNetS']:
        print("=> using pre-trained weights for FlowNetS")
        weights = torch.load('pretrained/flownets.pth.tar')
        flow_net.load_state_dict(weights['state_dict'])
    elif args.flownet in ['FlowNet2']:
        print("=> using pre-trained weights for FlowNet2")
        weights = torch.load('pretrained/FlowNet2_checkpoint.pth.tar')
        flow_net.load_state_dict(weights['state_dict'])
    else:
        flow_net.init_weights()

    pytorch_total_params = sum(p.numel() for p in flow_net.parameters())
    print("Number of model paramters: " + str(pytorch_total_params))

    flow_net = flow_net.cuda()

    cudnn.benchmark = True
    if args.patch_type == 'circle':
        patch, mask, patch_shape = init_patch_circle(args.image_size,
                                                     args.patch_size)
        patch_init = patch.copy()
    elif args.patch_type == 'square':
        patch, patch_shape = init_patch_square(args.image_size,
                                               args.patch_size)
        patch_init = patch.copy()
        mask = np.ones(patch_shape)
    else:
        sys.exit("Please choose a square or circle patch")

    if args.patch_path:
        patch, mask, patch_shape = init_patch_from_image(
            args.patch_path, args.mask_path, args.image_size, args.patch_size)
        patch_init = patch.copy()

    if args.log_terminal:
        logger = TermLogger(n_epochs=args.epochs,
                            train_size=min(len(train_loader), args.epoch_size),
                            valid_size=len(val_loader),
                            attack_size=args.max_count)
        logger.epoch_bar.start()
    else:
        logger = None

    for epoch in range(args.epochs):

        if args.log_terminal:
            logger.epoch_bar.update(epoch)
            logger.reset_train_bar()

        # train for one epoch
        patch, mask, patch_init, patch_shape = train(patch, mask, patch_init,
                                                     patch_shape, train_loader,
                                                     flow_net, epoch, logger,
                                                     training_writer)

        # Validate
        errors, error_names = validate_flow_with_gt(patch, mask, patch_shape,
                                                    val_loader, flow_net,
                                                    epoch, logger,
                                                    output_writer)

        error_string = ', '.join('{} : {:.3f}'.format(name, error)
                                 for name, error in zip(error_names, errors))
        if args.log_terminal:
            logger.valid_writer.write(' * Avg {}'.format(error_string))
        else:
            print('Epoch {} completed'.format(epoch))

        for error, name in zip(errors, error_names):
            training_writer.add_scalar(name, error, epoch)

        torch.save(patch, args.save_path / 'patch_epoch_{}'.format(str(epoch)))

    if args.log_terminal:
        logger.epoch_bar.finish()
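
init_patch_square above is not shown; the sketch below is a hypothetical stand-in (its name, signature, and semantics are assumptions, not the original helper). It picks a side length so the patch covers roughly a given fraction of the image area and draws initial values uniformly in [0, 1):

import numpy as np

def init_patch_square_sketch(image_size, patch_frac):
    # side**2 is approximately patch_frac * image_size**2
    side = int(image_size * np.sqrt(patch_frac))
    patch = np.random.rand(1, 3, side, side)  # NCHW, values in [0, 1)
    return patch, patch.shape

patch, patch_shape = init_patch_square_sketch(256, 0.05)
print(patch_shape)  # (1, 3, 57, 57)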
Exemple #30
0
def main():
    global best_error, n_iter, device
    args = parser.parse_args()
    if args.dataset_format == 'stacked':
        from datasets.stacked_sequence_folders import SequenceFolder
    elif args.dataset_format == 'sequential':
        from datasets.sequence_folders import SequenceFolder
    save_path = save_path_formatter(args, parser)
    args.save_path = 'checkpoints' / save_path
    print('=> will save everything to {}'.format(args.save_path))
    args.save_path.makedirs_p()
    torch.manual_seed(args.seed)
    if args.evaluate:
        args.epochs = 0

    tb_writer = SummaryWriter(args.save_path)
    # Data loading code
    normalize = custom_transforms.Normalize(mean=[0.5, 0.5, 0.5],
                                            std=[0.5, 0.5, 0.5])
    train_transform = custom_transforms.Compose([
        custom_transforms.RandomHorizontalFlip(),
        custom_transforms.RandomScaleCrop(),
        custom_transforms.ArrayToTensor(), normalize
    ])

    valid_transform = custom_transforms.Compose(
        [custom_transforms.ArrayToTensor(), normalize])

    print("=> fetching scenes in '{}'".format(args.data))
    train_set = SequenceFolder(args.data,
                               transform=train_transform,
                               seed=args.seed,
                               train=True,
                               sequence_length=args.sequence_length)

    # if no ground truth is available, the validation set is of the same type as the training set, so the photometric loss from warping can be measured
    if args.with_gt:
        from datasets.validation_folders import ValidationSet
        val_set = ValidationSet(args.data, transform=valid_transform)
    else:
        val_set = SequenceFolder(
            args.data,
            transform=valid_transform,
            seed=args.seed,
            train=False,
            sequence_length=args.sequence_length,
        )
    print('{} samples found in {} train scenes'.format(len(train_set),
                                                       len(train_set.scenes)))
    print('{} samples found in {} valid scenes'.format(len(val_set),
                                                       len(val_set.scenes)))
    train_loader = torch.utils.data.DataLoader(train_set,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               drop_last=True)
    val_loader = torch.utils.data.DataLoader(val_set,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    if args.epoch_size == 0:
        args.epoch_size = len(train_loader)

    # create model
    print("=> creating model")

    disp_net = models.DispNetS().to(device)
    seg_net = DeepLab(num_classes=args.nclass,
                      backbone=args.backbone,
                      output_stride=args.out_stride,
                      sync_bn=args.sync_bn,
                      freeze_bn=args.freeze_bn).to(device)
    if args.pretrained_seg:
        print("=> using pre-trained weights for seg net")
        weights = torch.load(args.pretrained_seg)
        seg_net.load_state_dict(weights, strict=False)
    output_exp = args.mask_loss_weight > 0
    if not output_exp:
        print("=> no mask loss, PoseExpnet will only output pose")
    pose_exp_net = models.PoseExpNet(
        nb_ref_imgs=args.sequence_length - 1,
        output_exp=args.mask_loss_weight > 0).to(device)

    if args.pretrained_exp_pose:
        print("=> using pre-trained weights for explainabilty and pose net")
        weights = torch.load(args.pretrained_exp_pose)
        pose_exp_net.load_state_dict(weights['state_dict'], strict=False)
    else:
        pose_exp_net.init_weights()

    if args.pretrained_disp:
        print("=> using pre-trained weights for Dispnet")
        weights = torch.load(args.pretrained_disp)
        disp_net.load_state_dict(weights['state_dict'])
    else:
        disp_net.init_weights()

    cudnn.benchmark = True
    disp_net = torch.nn.DataParallel(disp_net)
    pose_exp_net = torch.nn.DataParallel(pose_exp_net)
    seg_net = torch.nn.DataParallel(seg_net)

    print('=> setting adam solver')

    optim_params = [{
        'params': disp_net.parameters(),
        'lr': args.lr
    }, {
        'params': pose_exp_net.parameters(),
        'lr': args.lr
    }]
    optimizer = torch.optim.Adam(optim_params,
                                 betas=(args.momentum, args.beta),
                                 weight_decay=args.weight_decay)

    with open(args.save_path / args.log_summary, 'w') as csvfile:
        writer = csv.writer(csvfile, delimiter='\t')
        writer.writerow(['train_loss', 'validation_loss'])

    with open(args.save_path / args.log_full, 'w') as csvfile:
        writer = csv.writer(csvfile, delimiter='\t')
        writer.writerow(
            ['train_loss', 'photo_loss', 'explainability_loss', 'smooth_loss'])

    logger = TermLogger(n_epochs=args.epochs,
                        train_size=min(len(train_loader), args.epoch_size),
                        valid_size=len(val_loader))
    logger.epoch_bar.start()

    if args.pretrained_disp or args.evaluate:
        logger.reset_valid_bar()
        if args.with_gt:
            errors, error_names = validate_with_gt(args, val_loader, disp_net,
                                                   0, logger, tb_writer)
        else:
            errors, error_names = validate_without_gt(args, val_loader,
                                                      disp_net, pose_exp_net,
                                                      0, logger, tb_writer)
        for error, name in zip(errors, error_names):
            tb_writer.add_scalar(name, error, 0)
        error_string = ', '.join(
            '{} : {:.3f}'.format(name, error)
            for name, error in zip(error_names[2:9], errors[2:9]))
        logger.valid_writer.write(' * Avg {}'.format(error_string))

    for epoch in range(args.epochs):
        logger.epoch_bar.update(epoch)

        # train for one epoch
        logger.reset_train_bar()
        train_loss = train(args, train_loader, disp_net, pose_exp_net, seg_net,
                           optimizer, args.epoch_size, logger, tb_writer)
        logger.train_writer.write(' * Avg Loss : {:.3f}'.format(train_loss))

        # evaluate on validation set
        logger.reset_valid_bar()
        if args.with_gt:
            errors, error_names = validate_with_gt(args, val_loader, disp_net,
                                                   seg_net, epoch, logger,
                                                   tb_writer)
        else:
            errors, error_names = validate_without_gt(args, val_loader,
                                                      disp_net, pose_exp_net,
                                                      epoch, logger, tb_writer)
        error_string = ', '.join('{} : {:.3f}'.format(name, error)
                                 for name, error in zip(error_names, errors))
        logger.valid_writer.write(' * Avg {}'.format(error_string))

        for error, name in zip(errors, error_names):
            tb_writer.add_scalar(name, error, epoch)

        # It is up to you to choose the most relevant error to measure your model's performance; be careful, some measures should be maximized (such as a1, a2, a3)
        decisive_error = errors[1]
        if best_error < 0:
            best_error = decisive_error

        # remember lowest error and save checkpoint
        is_best = decisive_error < best_error
        best_error = min(best_error, decisive_error)
        save_checkpoint(args.save_path, {
            'epoch': epoch + 1,
            'state_dict': disp_net.module.state_dict()
        }, {
            'epoch': epoch + 1,
            'state_dict': pose_exp_net.module.state_dict()
        }, is_best)

        with open(args.save_path / args.log_summary, 'a') as csvfile:
            writer = csv.writer(csvfile, delimiter='\t')
            writer.writerow([train_loss, decisive_error])
    logger.epoch_bar.finish()
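
A closing note on the Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) used in nearly every example above: after ArrayToTensor scales images to [0, 1] (as the SfmLearner-style custom_transforms do), this normalization maps them to [-1, 1], since (x - 0.5) / 0.5 = 2x - 1. Verified on a toy tensor:

import torch

x = torch.tensor([0.0, 0.5, 1.0])  # pixel values after ArrayToTensor
print((x - 0.5) / 0.5)  # tensor([-1., 0., 1.])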