Exemple #1
0
def main():
    """Build the training data loader from command-line arguments and print its length.

    Parses ``parser``'s arguments into the module-level ``args``, wraps a
    ``SequenceFolder`` training set in a ``DataLoader`` (batch size 4, 4 workers,
    shuffled, last incomplete batch dropped) and prints the number of batches.
    """
    global args, best_error, n_iter
    args = parser.parse_args()

    # Per-channel normalisation with mean 0.5 / std 0.5.
    channel_norm = custom_transforms.Normalize(mean=[0.5, 0.5, 0.5],
                                               std=[0.5, 0.5, 0.5])
    augmentation_steps = [
        custom_transforms.RandomRotate(),
        custom_transforms.RandomHorizontalFlip(),
        custom_transforms.RandomScaleCrop(),
        custom_transforms.ArrayToTensor(),
        channel_norm,
    ]
    sequence_data = SequenceFolder(
        args.data,
        transform=custom_transforms.Compose(augmentation_steps),
        seed=args.seed,
        train=True,
        sequence_length=args.sequence_length,
    )
    loader_options = dict(batch_size=4, shuffle=True, num_workers=4,
                          pin_memory=True, drop_last=True)
    train_loader = torch.utils.data.DataLoader(sequence_data, **loader_options)
    print(len(train_loader))
Exemple #2
0
def dataflow_test():
    """Smoke-test the training and validation data pipelines.

    Builds a ``SequenceFolder`` training set and a ``ValidationSet`` from a
    hard-coded local dataset path. Prints how many batches each ``DataLoader``
    yields, then fetches one training batch and prints its length and the
    shape of its second element.
    """
    from DataFlow.sequence_folders import SequenceFolder
    import custom_transforms
    from torch.utils.data import DataLoader
    from DataFlow.validation_folders import ValidationSet

    norm = custom_transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    train_pipeline = custom_transforms.Compose([
        custom_transforms.RandomHorizontalFlip(),
        custom_transforms.RandomScaleCrop(),
        custom_transforms.ArrayToTensor(),
        norm,
    ])
    valid_pipeline = custom_transforms.Compose([custom_transforms.ArrayToTensor(), norm])

    dataset_root = 'G:/data/KITTI/KittiRaw_formatted'
    rng_seed = 8964
    train_data = SequenceFolder(dataset_root, transform=train_pipeline, seed=rng_seed,
                                train=True, sequence_length=3)
    train_batches = DataLoader(train_data, batch_size=4, shuffle=True, num_workers=4,
                               pin_memory=True)

    valid_data = ValidationSet(dataset_root, transform=valid_pipeline)
    print("length of train loader is %d" % len(train_batches))
    valid_batches = DataLoader(valid_data, batch_size=4, shuffle=False, num_workers=4,
                               pin_memory=True)
    print("length of val loader is %d" % len(valid_batches))

    # Pull a single training batch to inspect what the dataset yields.
    first_batch = next(iter(train_batches))
    imgs, intrinsics = first_batch
    print(len(imgs))
    print(intrinsics.shape)
Exemple #3
0
 def transform_pair_train(self, sample):
     """Apply the paired training augmentation pipeline to ``sample``.

     Steps, in order: random horizontal flip, ``RandomScaleCrop`` with
     base/crop size 400 and fill 0, ``HorizontalFlip``, ``GaussianBlur``,
     normalisation with ``self.source_dist['mean']`` / ``['std']`` in pair
     mode, and tensor conversion in pair mode.
     """
     steps = [
         tr.RandomHorizontalFlip(),
         tr.RandomScaleCrop(base_size=400, crop_size=400, fill=0),
         # NOTE(review): an unconditional HorizontalFlip right after
         # RandomHorizontalFlip looks redundant — confirm this is intended.
         tr.HorizontalFlip(),
         tr.GaussianBlur(),
         tr.Normalize(mean=self.source_dist['mean'],
                      std=self.source_dist['std'],
                      if_pair=True),
         tr.ToTensor(if_pair=True),
     ]
     pipeline = transforms.Compose(steps)
     return pipeline(sample)
Exemple #4
0
    def transform_tr(self, sample):
        """Apply the training-time augmentation pipeline to ``sample``.

        Random horizontal flip, random scale-crop sized by
        ``self.args.base_size`` / ``self.args.crop_size``, random Gaussian blur,
        normalisation with the fixed per-channel mean/std below, and finally
        tensor conversion.
        """
        channel_mean = (0.485, 0.456, 0.406)
        channel_std = (0.229, 0.224, 0.225)
        steps = [
            tr.RandomHorizontalFlip(),  # random horizontal flip
            tr.RandomScaleCrop(base_size=self.args.base_size,
                               crop_size=self.args.crop_size),  # random scale + crop
            tr.RandomGaussianBlur(),  # random Gaussian blur
            tr.Normalize(mean=channel_mean, std=channel_std),  # normalisation
            tr.ToTensor(),
        ]
        return transforms.Compose(steps)(sample)
 def transform(self, sample):
     """Transform ``sample`` for training or evaluation.

     When ``self.train`` is truthy, a random horizontal flip, a random
     vertical flip and a random scale-crop run before tensor conversion.
     Otherwise the sample is only converted to tensors.
     """
     steps = []
     if self.train:
         steps.extend([
             tr.RandomHorizontalFlip(),
             tr.RandomVerticalFlip(),
             tr.RandomScaleCrop(),
         ])
     steps.append(tr.ToTensor())
     return transforms.Compose(steps)(sample)
Exemple #6
0
    def transform_tr(self, sample):
        """Apply the training-time pipeline to ``sample``.

        Only a random scale-crop (sized by ``self.args.base_size`` and
        ``self.args.crop_size``) and tensor conversion are applied. Horizontal
        flipping, Gaussian blur and normalisation are disabled in this variant.
        """
        # Cropping to a fixed size makes every image in a batch the same shape.
        crop = tr.RandomScaleCrop(base_size=self.args.base_size,
                                  crop_size=self.args.crop_size)
        pipeline = transforms.Compose([crop, tr.ToTensor()])
        return pipeline(sample)
Exemple #7
0
 def transform_tr(self, sample):
     """Apply the training augmentation pipeline to ``sample``.

     Pipeline: random horizontal flip, random scale-crop (base/crop 400,
     fill 0), random Gaussian blur, normalisation with the statistics in
     ``self.source_dist``, and tensor conversion. When ``self.random_match``
     is truthy, ``HistogramMatching`` is prepended as the first step.
     """
     steps = [tr.HistogramMatching()] if self.random_match else []
     steps += [
         tr.RandomHorizontalFlip(),
         tr.RandomScaleCrop(base_size=400, crop_size=400, fill=0),
         tr.RandomGaussianBlur(),
         tr.Normalize(mean=self.source_dist['mean'],
                      std=self.source_dist['std']),
         tr.ToTensor(),
     ]
     return transforms.Compose(steps)(sample)
Exemple #8
0
                                                 num_workers=4)

        for item in dataloader:
            train = item['image']
            # train = np.array(train)      #?
            print(train.shape)
            print('sample {} images to calculate'.format(train.shape[0]))
            mean = np.mean(train.numpy(), axis=(0, 2, 3))
            std = np.std(train.numpy(), axis=(0, 2, 3))
        return mean, std


if __name__ == '__main__':
    # Sanity check: build the training set with the full augmentation
    # pipeline and print the min/max of every image and label.
    augment = tf.Compose([
        tr.RandomHorizontalFlip(),
        tr.RandomScaleCrop(base_size=512, crop_size=512),
        tr.RandomGaussianBlur(),  # Gaussian blur
        tr.Normalize(mean=LungDataset.mean, std=LungDataset.std),
        tr.ToTensor(),
    ])
    dataset = LungDataset(root_dir=r'D:\code\U-net', transforms=augment, train=True)
    for item in dataset:
        image, label = item['image'], item['label']
        print(image.min(), image.max(), label.min(), label.max())
Exemple #9
0
def main():
    """Train and validate the DISVO model.

    Parses the command-line arguments, prepares the checkpoint directory and
    TensorBoard writers, and builds the train/validation ``KITTIDataset``
    loaders. It then creates DISVO from a Mask R-CNN config and optionally
    loads pre-trained weights. After an optional initial evaluation, it runs
    ``args.epochs`` epochs of training plus validation, logging to CSV and
    TensorBoard and saving a checkpoint after every epoch.

    Reads and updates the module-level ``best_error``; a negative value means
    no epoch has been evaluated yet.
    """
    global best_error, n_iter, device
    args = parser.parse_args()
    save_path = save_path_formatter(args, parser)
    args.save_path = 'checkpoints'/save_path
    print('=> will save everything to {}'.format(args.save_path))
    args.save_path.makedirs_p()
    torch.manual_seed(args.seed)
    if args.evaluate:
        # Evaluation-only run: the training loop below executes zero times.
        args.epochs = 0

    training_writer = SummaryWriter(args.save_path)
    output_writers = []
    if args.log_output:
        for i in range(3):
            output_writers.append(SummaryWriter(args.save_path/'valid'/str(i)))

    # Data loading code
    # Defined locally so valid_transform does not depend on a name this
    # function never defines (previously a NameError).
    normalize = custom_transforms.Normalize(mean=[0.5, 0.5, 0.5],
                                            std=[0.5, 0.5, 0.5])
    # NOTE(review): neither transform is passed to KITTIDataset below
    # (transform=None), so no augmentation or normalisation is applied yet.
    train_transform = custom_transforms.Compose([
        custom_transforms.RandomHorizontalFlip(),
        custom_transforms.RandomScaleCrop(),
        custom_transforms.ArrayToTensor()
    ])

    valid_transform = custom_transforms.Compose([custom_transforms.ArrayToTensor(), normalize])

    print("=> fetching scenes in '{}'".format(args.data))
    # NOTE(review): `root_dir` and `sequences` are not defined in this function
    # (presumably module-level). The train and validation sets are built from
    # identical arguments, so validation runs on the training sequences —
    # confirm this is intended.
    train_set = KITTIDataset(
        root_dir,
        sequences,
        max_distance=args.max_distance,
        transform=None
    )

    val_set = KITTIDataset(
        root_dir,
        sequences,
        max_distance=args.max_distance,
        transform=None
    )
    print('{} samples found in {} train scenes'.format(len(train_set), len(train_set)))
    print('{} samples found in {} valid scenes'.format(len(val_set), len(val_set)))

    train_loader = torch.utils.data.DataLoader(
        train_set, batch_size=args.batch_size, shuffle=True,
        num_workers=args.workers, pin_memory=True)
    val_loader = torch.utils.data.DataLoader(
        val_set, batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True)

    if args.epoch_size == 0:
        args.epoch_size = len(train_loader)

    # create model
    print("=> creating model")

    config_file = "./configs/e2e_mask_rcnn_R_50_FPN_1x.yaml"
    cfg.merge_from_file(config_file)
    cfg.freeze()
    pretrained_model_path = "./e2e_mask_rcnn_R_50_FPN_1x.pth"
    disvo = DISVO(cfg, pretrained_model_path).cuda()

    if args.pretrained_disvo:
        print("=> using pre-trained weights for Dispnet")
        weights = torch.load(args.pretrained_disvo)
        disvo.load_state_dict(weights['state_dict'])
    else:
        disvo.init_weights()

    cudnn.benchmark = True

    print('=> setting adam solver')

    optim_params = [
        {'params': disvo.parameters(), 'lr': args.lr}
    ]
    optimizer = torch.optim.Adam(optim_params,
                                 betas=(args.momentum, args.beta),
                                 weight_decay=args.weight_decay)

    with open(args.save_path/args.log_summary, 'w') as csvfile:
        writer = csv.writer(csvfile, delimiter='\t')
        writer.writerow(['train_loss', 'validation_loss'])

    with open(args.save_path/args.log_full, 'w') as csvfile:
        writer = csv.writer(csvfile, delimiter='\t')
        writer.writerow(['train_loss'])

    logger = TermLogger(n_epochs=args.epochs, train_size=min(len(train_loader), args.epoch_size), valid_size=len(val_loader))
    logger.epoch_bar.start()

    if args.pretrained_disvo or args.evaluate:
        logger.reset_valid_bar()
        errors, error_names = validate(args, val_loader, disvo, 0, logger, output_writers)
        for error, name in zip(errors, error_names):
            training_writer.add_scalar(name, error, 0)
        error_string = ', '.join('{} : {:.3f}'.format(name, error) for name, error in zip(error_names[2:9], errors[2:9]))
        logger.valid_writer.write(' * Avg {}'.format(error_string))

    for epoch in range(args.epochs):
        logger.epoch_bar.update(epoch)

        # train for one epoch
        logger.reset_train_bar()
        train_loss = train(args, train_loader, disvo, optimizer, args.epoch_size, logger, training_writer)
        logger.train_writer.write(' * Avg Loss : {:.3f}'.format(train_loss))

        # evaluate on validation set; the 4th argument is the epoch index
        # (the pre-loop evaluation above passes 0), so pass the current epoch.
        logger.reset_valid_bar()
        errors, error_names = validate(args, val_loader, disvo, epoch, logger, output_writers)
        error_string = ', '.join('{} : {:.3f}'.format(name, error) for name, error in zip(error_names, errors))
        logger.valid_writer.write(' * Avg {}'.format(error_string))

        for error, name in zip(errors, error_names):
            training_writer.add_scalar(name, error, epoch)

        # Up to you to choose the most relevant error to measure your model's performance, careful some measures are to maximize (such as a1,a2,a3)
        decisive_error = errors[1]

        # Remember the lowest error and save a checkpoint. A negative
        # best_error means nothing has been evaluated yet, so the first
        # evaluated epoch is the best so far (previously it was never
        # flagged, and model_best was never written if epoch 0 was best).
        is_best = best_error < 0 or decisive_error < best_error
        best_error = decisive_error if best_error < 0 else min(best_error, decisive_error)

        # `disvo` is never wrapped in DataParallel in this function, so
        # `.module` may not exist; unwrap only when it does.
        model_to_save = getattr(disvo, 'module', disvo)
        save_checkpoint(
            args.save_path, {
                'epoch': epoch + 1,
                'state_dict': model_to_save.state_dict()
            },
            is_best)

        with open(args.save_path/args.log_summary, 'a') as csvfile:
            writer = csv.writer(csvfile, delimiter='\t')
            writer.writerow([train_loss, decisive_error])
    logger.epoch_bar.finish()
def main():
    """Train DispNetS and PoseExpNet with periodically re-adjusted frame shifts.

    Parses the command-line arguments into the module-level ``args`` and builds
    a ``ShiftedSequenceFolder`` training set. The validation set is a
    ``ValidationSet`` when ``args.with_gt`` is set, otherwise a
    ``SequenceFolder``. The two networks are created and optionally
    initialised from pre-trained weights, then trained for ``args.epochs``
    epochs. Every 5th epoch the training set's shifts are re-estimated via
    ``adjust_shifts``. After each epoch the networks are validated, metrics go
    to TensorBoard and CSV, and checkpoints are saved.

    Reads and updates the module-level ``best_error``; a negative value means
    no epoch has been evaluated yet.
    """
    global args, best_error, n_iter, device
    args = parser.parse_args()
    save_path = save_path_formatter(args, parser)
    args.save_path = 'checkpoints_shifted' / save_path
    print('=> will save everything to {}'.format(args.save_path))
    args.save_path.makedirs_p()
    torch.manual_seed(args.seed)

    training_writer = SummaryWriter(args.save_path)
    output_writers = []
    if args.log_output:
        for i in range(3):
            output_writers.append(
                SummaryWriter(args.save_path / 'valid' / str(i)))

    # Data loading code
    normalize = custom_transforms.Normalize(mean=[0.5, 0.5, 0.5],
                                            std=[0.5, 0.5, 0.5])
    train_transform = custom_transforms.Compose([
        custom_transforms.RandomHorizontalFlip(),
        custom_transforms.RandomScaleCrop(),
        custom_transforms.ArrayToTensor(), normalize
    ])

    valid_transform = custom_transforms.Compose(
        [custom_transforms.ArrayToTensor(), normalize])

    print("=> fetching scenes in '{}'".format(args.data))
    train_set = ShiftedSequenceFolder(
        args.data,
        transform=train_transform,
        seed=args.seed,
        train=True,
        sequence_length=args.sequence_length,
        target_displacement=args.target_displacement)

    # if no Groundtruth is available, Validation set is the same type as training set to measure photometric loss from warping
    if args.with_gt:
        from datasets.validation_folders import ValidationSet
        val_set = ValidationSet(args.data, transform=valid_transform)
    else:
        val_set = SequenceFolder(
            args.data,
            transform=valid_transform,
            seed=args.seed,
            train=False,
            sequence_length=args.sequence_length,
        )
    print('{} samples found in {} train scenes'.format(len(train_set),
                                                       len(train_set.scenes)))
    print('{} samples found in {} valid scenes'.format(len(val_set),
                                                       len(val_set.scenes)))
    train_loader = torch.utils.data.DataLoader(train_set,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True)
    adjust_loader = torch.utils.data.DataLoader(
        train_set,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=0,
        pin_memory=True
    )  # workers is set to 0 to avoid multiple instances to be modified at the same time
    val_loader = torch.utils.data.DataLoader(val_set,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    if args.epoch_size == 0:
        args.epoch_size = len(train_loader)

    # Attach the parsed arguments to the `train` function object.
    train.args = args
    # create model
    print("=> creating model")

    disp_net = models.DispNetS().cuda()
    output_exp = args.mask_loss_weight > 0
    if not output_exp:
        print("=> no mask loss, PoseExpnet will only output pose")
    pose_exp_net = models.PoseExpNet(
        nb_ref_imgs=args.sequence_length - 1,
        output_exp=output_exp).to(device)

    if args.pretrained_exp_pose:
        print("=> using pre-trained weights for explainabilty and pose net")
        weights = torch.load(args.pretrained_exp_pose)
        pose_exp_net.load_state_dict(weights['state_dict'], strict=False)
    else:
        pose_exp_net.init_weights()

    if args.pretrained_disp:
        print("=> using pre-trained weights for Dispnet")
        weights = torch.load(args.pretrained_disp)
        disp_net.load_state_dict(weights['state_dict'])
    else:
        disp_net.init_weights()

    cudnn.benchmark = True
    disp_net = torch.nn.DataParallel(disp_net)
    pose_exp_net = torch.nn.DataParallel(pose_exp_net)

    print('=> setting adam solver')

    parameters = chain(disp_net.parameters(), pose_exp_net.parameters())
    optimizer = torch.optim.Adam(parameters,
                                 args.lr,
                                 betas=(args.momentum, args.beta),
                                 weight_decay=args.weight_decay)

    with open(args.save_path / args.log_summary, 'w') as csvfile:
        writer = csv.writer(csvfile, delimiter='\t')
        writer.writerow(['train_loss', 'validation_loss'])

    with open(args.save_path / args.log_full, 'w') as csvfile:
        writer = csv.writer(csvfile, delimiter='\t')
        writer.writerow(
            ['train_loss', 'photo_loss', 'explainability_loss', 'smooth_loss'])

    logger = TermLogger(n_epochs=args.epochs,
                        train_size=min(len(train_loader), args.epoch_size),
                        valid_size=len(val_loader))
    logger.epoch_bar.start()

    for epoch in range(args.epochs):
        logger.epoch_bar.update(epoch)

        # train for one epoch
        logger.reset_train_bar()
        train_loss = train(args, train_loader, disp_net, pose_exp_net,
                           optimizer, args.epoch_size, logger, training_writer)
        logger.train_writer.write(' * Avg Loss : {:.3f}'.format(train_loss))

        # Every 5th epoch, re-estimate the shifts of the training set; the
        # `adjust` flag is only set for the duration of this pass.
        if (epoch + 1) % 5 == 0:
            train_set.adjust = True
            logger.reset_train_bar(len(adjust_loader))
            average_shifts = adjust_shifts(args, train_set, adjust_loader,
                                           pose_exp_net, epoch, logger,
                                           training_writer)
            shifts_string = ' '.join(
                ['{:.3f}'.format(s) for s in average_shifts])
            logger.train_writer.write(
                ' * adjusted shifts, average shifts are now : {}'.format(
                    shifts_string))
            for i, shift in enumerate(average_shifts):
                training_writer.add_scalar('shifts{}'.format(i), shift, epoch)
            train_set.adjust = False

        # evaluate on validation set
        logger.reset_valid_bar()
        if args.with_gt:
            errors, error_names = validate_with_gt(args, val_loader, disp_net,
                                                   epoch, logger,
                                                   output_writers)
        else:
            errors, error_names = validate_without_gt(args, val_loader,
                                                      disp_net, pose_exp_net,
                                                      epoch, logger,
                                                      output_writers)
        error_string = ', '.join('{} : {:.3f}'.format(name, error)
                                 for name, error in zip(error_names, errors))
        logger.valid_writer.write(' * Avg {}'.format(error_string))

        for error, name in zip(errors, error_names):
            training_writer.add_scalar(name, error, epoch)

        # Up to you to choose the most relevant error to measure your model's performance, careful some measures are to maximize (such as a1,a2,a3)
        decisive_error = errors[0]

        # Remember the lowest error and save a checkpoint. A negative
        # best_error means nothing has been evaluated yet, so the first
        # evaluated epoch is the best so far (previously it was never flagged).
        is_best = best_error < 0 or decisive_error < best_error
        best_error = decisive_error if best_error < 0 else min(best_error, decisive_error)
        save_checkpoint(args.save_path, {
            'epoch': epoch + 1,
            'state_dict': disp_net.module.state_dict()
        }, {
            'epoch': epoch + 1,
            'state_dict': pose_exp_net.module.state_dict()
        }, is_best)

        with open(args.save_path / args.log_summary, 'a') as csvfile:
            writer = csv.writer(csvfile, delimiter='\t')
            writer.writerow([train_loss, decisive_error])
    logger.epoch_bar.finish()
Exemple #11
0
def main():
    """Train a depth network with ground-truth depth supervision.

    Reads ``args`` from the module-level ``global_vars_dict`` and creates a
    timestamped checkpoint directory. It builds the training loader and the
    depth-ground-truth validation loader, then creates the network named by
    ``args.dispnet`` (optionally from pre-trained or resumed weights). Each
    loop iteration runs one training epoch and one validation pass, logs to
    TensorBoard and saves checkpoints.

    Raises:
        ValueError: if ``args.data_normalization`` is neither ``'global'``
            nor ``'local'``.
    """
    global global_vars_dict
    args = global_vars_dict['args']
    best_error = -1  # best model choosing; negative means "not set yet"

    # mkdir
    timestamp = datetime.datetime.now().strftime("%m-%d-%H:%M")

    args.save_path = Path('checkpoints') / Path(args.data_dir).stem / timestamp

    args.save_path.makedirs_p()
    torch.manual_seed(args.seed)
    if args.alternating:
        args.alternating_flags = np.array([False, False, True])
    # mk writers
    tb_writer = SummaryWriter(args.save_path)

    # Data loading code and transpose

    if args.data_normalization == 'global':
        normalize = custom_transforms.Normalize(mean=[0.5, 0.5, 0.5],
                                                std=[0.5, 0.5, 0.5])
    elif args.data_normalization == 'local':
        normalize = custom_transforms.NormalizeLocally()
    else:
        # Without this branch `normalize` would be unbound below and fail
        # with an obscure UnboundLocalError.
        raise ValueError('unknown data_normalization: {!r}'.format(
            args.data_normalization))

    valid_transform = custom_transforms.Compose(
        [custom_transforms.ArrayToTensor(), normalize])

    print("=> fetching scenes in '{}'".format(args.data_dir))

    train_transform = custom_transforms.Compose([
        #custom_transforms.RandomRotate(),
        custom_transforms.RandomHorizontalFlip(),
        custom_transforms.RandomScaleCrop(),
        custom_transforms.ArrayToTensor(),
        normalize
    ])

    # train set / loader: only one of each is built
    from datasets.sequence_mc import SequenceFolder
    train_set = SequenceFolder(  # mc data folder
        args.data_dir,
        transform=train_transform,
        seed=args.seed,
        train=True,
        sequence_length=args.sequence_length,  # 5
        target_transform=None,
        depth_format='png')

    train_loader = torch.utils.data.DataLoader(train_set,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               drop_last=True)

    if args.epoch_size == 0:
        args.epoch_size = len(train_loader)

    # val sets / loaders are built one by one
    #if args.val_with_depth_gt:
    from datasets.validation_folders2 import ValidationSet

    val_set_with_depth_gt = ValidationSet(args.data_dir,
                                          transform=valid_transform,
                                          depth_format='png')

    val_loader_depth = torch.utils.data.DataLoader(val_set_with_depth_gt,
                                                   batch_size=args.batch_size,
                                                   shuffle=False,
                                                   num_workers=args.workers,
                                                   pin_memory=True,
                                                   drop_last=True)

    print('{} samples found in {} train scenes'.format(len(train_set),
                                                       len(train_set.scenes)))

    # 1 create model
    print("=> creating model")
    # 1.1 disp_net
    disp_net = getattr(models, args.dispnet)().cuda()

    if args.pretrained_disp:
        print("=> using pre-trained weights from {}".format(
            args.pretrained_disp))
        weights = torch.load(args.pretrained_disp)
        disp_net.load_state_dict(weights['state_dict'])
    else:
        disp_net.init_weights()

    if args.resume:
        print("=> resuming from checkpoint")
        dispnet_weights = torch.load(args.save_path /
                                     'dispnet_checkpoint.pth.tar')
        disp_net.load_state_dict(dispnet_weights['state_dict'])

    cudnn.benchmark = True
    disp_net = torch.nn.DataParallel(disp_net)

    print('=> setting adam solver')

    parameters = chain(disp_net.parameters())

    optimizer = torch.optim.Adam(parameters,
                                 args.lr,
                                 betas=(args.momentum, args.beta),
                                 weight_decay=args.weight_decay)

    if args.resume and (args.save_path /
                        'optimizer_checkpoint.pth.tar').exists():
        print("=> loading optimizer from checkpoint")
        optimizer_weights = torch.load(args.save_path /
                                       'optimizer_checkpoint.pth.tar')
        optimizer.load_state_dict(optimizer_weights['state_dict'])

    # Terminal logger is optional; `logger` is None when it is disabled.
    if args.log_terminal:
        logger = TermLogger(n_epochs=args.epochs,
                            train_size=min(len(train_loader), args.epoch_size),
                            valid_size=len(val_loader_depth))
        logger.reset_epoch_bar()
    else:
        logger = None

    # Pre-evaluation before training (the call itself is currently disabled).
    criterion_train = MaskedL1Loss().to(device)  # L1 loss is easy to optimize
    criterion_val = ComputeErrors().to(device)

    #depth_error_names,depth_errors = validate_depth_with_gt(val_loader_depth, disp_net,criterion=criterion_val, epoch=0, logger=logger,tb_writer=tb_writer,global_vars_dict=global_vars_dict)

    #logger.reset_epoch_bar()
    #    logger.epoch_logger_update(epoch=0,time=0,names=depth_error_names,values=depth_errors)
    epoch_time = AverageMeter()
    end = time.time()
    # 3. main cycle
    # NOTE(review): the loop starts at 1 because epoch 0 was meant to be
    # evaluated before the loop, but that evaluation is commented out above,
    # so only args.epochs - 1 epochs are trained.
    for epoch in range(1, args.epochs):

        # Guarded: previously this crashed with AttributeError when
        # args.log_terminal was false (logger is None).
        # NOTE(review): train_depth_gt / validate_depth_with_gt also receive
        # `logger` — confirm they tolerate None.
        if logger is not None:
            logger.reset_train_bar()
            logger.reset_valid_bar()

        # 3.2 train for one epoch ---------
        loss_names, losses = train_depth_gt(train_loader=train_loader,
                                            disp_net=disp_net,
                                            optimizer=optimizer,
                                            criterion=criterion_train,
                                            logger=logger,
                                            train_writer=tb_writer,
                                            global_vars_dict=global_vars_dict)

        # 3.3 evaluate on validation set -----
        depth_error_names, depth_errors = validate_depth_with_gt(
            val_loader=val_loader_depth,
            disp_net=disp_net,
            criterion=criterion_val,
            epoch=epoch,
            logger=logger,
            tb_writer=tb_writer,
            global_vars_dict=global_vars_dict)

        epoch_time.update(time.time() - end)
        end = time.time()

        # 3.5 log_terminal
        if args.log_terminal:
            logger.epoch_logger_update(epoch=epoch,
                                       time=epoch_time,
                                       names=depth_error_names,
                                       values=depth_errors)

        # tensorboard scalars
        # train loss
        for loss_name, loss in zip(loss_names, losses.avg):
            tb_writer.add_scalar('train/' + loss_name, loss, epoch)

        # val_with_gt loss
        for name, error in zip(depth_error_names, depth_errors.avg):
            tb_writer.add_scalar('val/' + name, error, epoch)

        # 3.6 remember lowest error and save checkpoint.
        # NOTE(review): the model is selected on the average *training* loss
        # (losses.avg[0]), not on a validation metric.
        total_loss = losses.avg[0]
        if best_error < 0:
            best_error = total_loss

        is_best = total_loss <= best_error
        best_error = min(best_error, total_loss)
        save_checkpoint(args.save_path, {
            'epoch': epoch + 1,
            'state_dict': disp_net.module.state_dict()
        }, {
            'epoch': epoch + 1,
            'state_dict': None
        }, {
            'epoch': epoch + 1,
            'state_dict': None
        }, {
            'epoch': epoch + 1,
            'state_dict': None
        }, is_best)

    if args.log_terminal:
        logger.epoch_bar.finish()
Exemple #12
0
def main():
    """Train DispResNet and PoseResNet and keep the best checkpoint.

    Parses the command-line arguments and creates a timestamped checkpoint
    directory. It seeds torch/numpy and builds the training ``SequenceFolder``
    loader. The validation set is a ``ValidationSet`` when ``args.with_gt`` is
    set, otherwise a ``SequenceFolder``. After creating both networks
    (optionally from pre-trained weights), it trains for ``args.epochs``
    epochs. Each epoch ends with validation, TensorBoard/CSV logging and
    checkpoint saving.

    Reads and updates the module-level ``best_error``; a negative value means
    no epoch has been evaluated yet.
    """
    global best_error, n_iter, device
    args = parser.parse_args()

    timestamp = datetime.datetime.now().strftime("%m-%d-%H:%M")
    save_path = Path(args.name)
    args.save_path = 'checkpoints' / save_path / timestamp
    print('=> will save everything to {}'.format(args.save_path))
    args.save_path.makedirs_p()

    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    cudnn.deterministic = True
    # NOTE(review): benchmark=True lets cuDNN pick algorithms by timing, which
    # can vary between runs and undermines `deterministic = True` above. Set
    # it to False if bit-exact reproducibility matters.
    cudnn.benchmark = True

    training_writer = SummaryWriter(args.save_path)
    output_writers = []
    if args.log_output:
        for i in range(3):
            output_writers.append(
                SummaryWriter(args.save_path / 'valid' / str(i)))

    # Data loading code
    normalize = custom_transforms.Normalize(mean=[0.45, 0.45, 0.45],
                                            std=[0.225, 0.225, 0.225])

    train_transform = custom_transforms.Compose([
        custom_transforms.RandomHorizontalFlip(),
        custom_transforms.RandomScaleCrop(),
        custom_transforms.ArrayToTensor(), normalize
    ])

    valid_transform = custom_transforms.Compose(
        [custom_transforms.ArrayToTensor(), normalize])

    print("=> fetching scenes in '{}'".format(args.data))
    train_set = SequenceFolder(args.data,
                               transform=train_transform,
                               seed=args.seed,
                               train=True,
                               sequence_length=args.sequence_length)

    # if no Groundtruth is available, Validation set is the same type as training set to measure photometric loss from warping
    if args.with_gt:
        from datasets.validation_folders import ValidationSet
        val_set = ValidationSet(args.data, transform=valid_transform)
    else:
        val_set = SequenceFolder(args.data,
                                 transform=valid_transform,
                                 seed=args.seed,
                                 train=False,
                                 sequence_length=args.sequence_length)
    print('{} samples found in {} train scenes'.format(len(train_set),
                                                       len(train_set.scenes)))
    print('{} samples found in {} valid scenes'.format(len(val_set),
                                                       len(val_set.scenes)))
    train_loader = torch.utils.data.DataLoader(train_set,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True)
    val_loader = torch.utils.data.DataLoader(val_set,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    if args.epoch_size == 0:
        args.epoch_size = len(train_loader)

    # create model
    print("=> creating model")
    disp_net = models.DispResNet(args.resnet_layers,
                                 args.with_pretrain).to(device)
    pose_net = models.PoseResNet(18, args.with_pretrain).to(device)

    # load parameters
    if args.pretrained_disp:
        print("=> using pre-trained weights for DispResNet")
        weights = torch.load(args.pretrained_disp)
        disp_net.load_state_dict(weights['state_dict'], strict=False)

    if args.pretrained_pose:
        print("=> using pre-trained weights for PoseResNet")
        weights = torch.load(args.pretrained_pose)
        pose_net.load_state_dict(weights['state_dict'], strict=False)

    disp_net = torch.nn.DataParallel(disp_net)
    pose_net = torch.nn.DataParallel(pose_net)

    print('=> setting adam solver')
    optim_params = [{
        'params': disp_net.parameters(),
        'lr': args.lr
    }, {
        'params': pose_net.parameters(),
        'lr': args.lr
    }]
    optimizer = torch.optim.Adam(optim_params,
                                 betas=(args.momentum, args.beta),
                                 weight_decay=args.weight_decay)

    with open(args.save_path / args.log_summary, 'w') as csvfile:
        writer = csv.writer(csvfile, delimiter='\t')
        writer.writerow(['train_loss', 'validation_loss'])

    with open(args.save_path / args.log_full, 'w') as csvfile:
        writer = csv.writer(csvfile, delimiter='\t')
        writer.writerow([
            'train_loss', 'photo_loss', 'smooth_loss',
            'geometry_consistency_loss'
        ])

    logger = TermLogger(n_epochs=args.epochs,
                        train_size=min(len(train_loader), args.epoch_size),
                        valid_size=len(val_loader))
    logger.epoch_bar.start()

    for epoch in range(args.epochs):
        logger.epoch_bar.update(epoch)

        # train for one epoch
        logger.reset_train_bar()
        train_loss = train(args, train_loader, disp_net, pose_net, optimizer,
                           args.epoch_size, logger, training_writer)
        logger.train_writer.write(' * Avg Loss : {:.3f}'.format(train_loss))

        # evaluate on validation set
        logger.reset_valid_bar()
        if args.with_gt:
            errors, error_names = validate_with_gt(args, val_loader, disp_net,
                                                   epoch, logger,
                                                   output_writers)
        else:
            errors, error_names = validate_without_gt(args, val_loader,
                                                      disp_net, pose_net,
                                                      epoch, logger,
                                                      output_writers)
        error_string = ', '.join('{} : {:.3f}'.format(name, error)
                                 for name, error in zip(error_names, errors))
        logger.valid_writer.write(' * Avg {}'.format(error_string))

        for error, name in zip(errors, error_names):
            training_writer.add_scalar(name, error, epoch)

        # Up to you to choose the most relevant error to measure your model's performance, careful some measures are to maximize (such as a1,a2,a3)
        decisive_error = errors[1]

        # Remember the lowest error and save a checkpoint. A negative
        # best_error means nothing has been evaluated yet, so the first
        # evaluated epoch is the best so far (previously it was never flagged).
        is_best = best_error < 0 or decisive_error < best_error
        best_error = decisive_error if best_error < 0 else min(best_error, decisive_error)
        save_checkpoint(args.save_path, {
            'epoch': epoch + 1,
            'state_dict': disp_net.module.state_dict()
        }, {
            'epoch': epoch + 1,
            'state_dict': pose_net.module.state_dict()
        }, is_best)

        with open(args.save_path / args.log_summary, 'a') as csvfile:
            writer = csv.writer(csvfile, delimiter='\t')
            writer.writerow([train_loss, decisive_error])
    logger.epoch_bar.finish()
Exemple #13
0
def main():
    """Jointly train the mask and pose networks and checkpoint them every epoch.

    Configuration comes from the module-level ``parser``. TensorBoard logs and
    checkpoints are written under ``checkpoints/<args.name>``. An epoch's
    checkpoint is flagged as best when its training loss is the lowest seen
    so far (tracked in the global ``best_error``).
    """
    global args, best_error, n_iter
    args = parser.parse_args()
    save_path = Path(args.name)
    args.save_path = 'checkpoints' / save_path
    print('=> will save everything to {}'.format(args.save_path))
    args.save_path.makedirs_p()
    torch.manual_seed(args.seed)

    normalize = custom_transforms.Normalize(mean=[0.5, 0.5, 0.5],
                                            std=[0.5, 0.5, 0.5])
    # Rotation and horizontal-flip augmentation are disabled for this run:
    #   custom_transforms.RandomRotate(), custom_transforms.RandomHorizontalFlip()
    train_transform = custom_transforms.Compose([
        custom_transforms.RandomScaleCrop(),
        custom_transforms.ArrayToTensor(),
        normalize
    ])
    training_writer = SummaryWriter(args.save_path)

    # 3x3 camera intrinsics matrix, presumably [[fx, 0, cx], [0, fy, cy],
    # [0, 0, 1]] -- confirm against the dataset's calibration.
    intrinsics = np.array(
        [542.822841, 0, 315.593520, 0, 542.576870, 237.756098, 0, 0,
         1]).astype(np.float32).reshape((3, 3))

    train_set = SequenceFolder(root=args.dataset_dir,
                               intrinsics=intrinsics,
                               transform=train_transform,
                               seed=args.seed,
                               train=True,
                               sequence_length=args.sequence_length)

    print('{} samples found in {} train scenes'.format(len(train_set),
                                                       len(train_set.scenes)))
    train_loader = torch.utils.data.DataLoader(train_set,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               drop_last=True)

    print("=> creating model")
    mask_net = torch.nn.DataParallel(MaskResNet6.MaskResNet6().cuda())
    pose_net = torch.nn.DataParallel(PoseNetB6.PoseNetB6().cuda())

    # Weights are loaded into (and later saved from) the DataParallel
    # wrappers, so state-dict keys carry the "module." prefix.
    # strict=False silently ignores missing/unexpected keys.
    if args.pretrained_mask:
        print("=> using pre-trained weights for explainabilty and pose net")
        weights = torch.load(args.pretrained_mask)
        mask_net.load_state_dict(weights['state_dict'], strict=False)
    else:
        # nn.DataParallel does not forward custom methods of the wrapped
        # model, so init_weights must be called on .module.
        mask_net.module.init_weights()

    if args.pretrained_pose:
        print("=> using pre-trained weights for explainabilty and pose net")
        weights = torch.load(args.pretrained_pose)
        pose_net.load_state_dict(weights['state_dict'], strict=False)
    else:
        pose_net.module.init_weights()

    print('=> setting adam solver')
    parameters = chain(mask_net.parameters(), pose_net.parameters())
    optimizer = torch.optim.Adam(parameters,
                                 args.lr,
                                 betas=(args.momentum, args.beta),
                                 weight_decay=args.weight_decay)

    # training
    best_error = 100001
    for epoch in tqdm(range(args.epochs)):
        train_loss = train(train_loader, mask_net, pose_net, optimizer,
                           args.epoch_size, training_writer)
        is_best = train_loss < best_error
        best_error = min(best_error, train_loss)
        save_checkpoint(args.save_path, {
            'epoch': epoch + 1,
            'state_dict': mask_net.state_dict()
        }, {
            'epoch': epoch + 1,
            'state_dict': optimizer.state_dict()
        }, {
            'epoch': epoch + 1,
            'state_dict': pose_net.state_dict()
        }, is_best)
Exemple #14
0
def main():
    """Train the network selected by ``args.FCCMnet`` and checkpoint it each epoch.

    Configuration comes from the module-level ``parser``. Each epoch trains
    on the training split, runs ``validation`` on the validation split, then
    saves model and optimizer state under ``checkpoints/<args.name>``. The
    checkpoint is flagged as best when that epoch's training loss is the
    lowest seen so far.

    Raises:
        ValueError: if ``args.FCCMnet`` names an unknown architecture.
    """
    global args, best_error, n_iter
    n_iter = 0

    args = parser.parse_args()
    # Example argument list used during development:
    # ['--name', 'deemo', '--FCCMnet', 'PatchWiseNetwork',
    #  '--dataset_dir', '/notebooks/FCCM/pre-process/pre-process',
    #  '--label_dir', '/notebooks/FCCM',
    #  '--batch-size', '4', '--epochs', '100', '--lr', '1e-4']
    save_path = Path(args.name)
    args.save_path = 'checkpoints' / save_path
    print('=> will save everything to {}'.format(args.save_path))
    args.save_path.makedirs_p()

    train_writer = SummaryWriter(args.save_path)
    torch.manual_seed(args.seed)

    train_transform = custom_transforms.Compose([
        custom_transforms.RandomRotate(),
        custom_transforms.RandomHorizontalFlip(),
        custom_transforms.RandomScaleCrop(),
        custom_transforms.ArrayToTensor()])

    train_set = Generate_train_set(
        root=args.dataset_dir,
        label_root=args.label_dir,
        transform=train_transform,
        seed=args.seed,
        train=True)
    val_set = Generate_val_set(
        root=args.dataset_dir,
        label_root=args.label_dir,
        seed=args.seed)

    print('{} samples found in {} train scenes'.format(len(train_set), len(train_set.scenes)))
    print('{} samples found in {} val scenes'.format(len(val_set), len(val_set.scenes)))

    train_loader = torch.utils.data.DataLoader(
        train_set, batch_size=args.batch_size, shuffle=True,
        num_workers=args.workers, pin_memory=True, drop_last=True)
    # Validation order does not need to be random. drop_last is kept in case
    # validation() assumes full batches (NOTE(review): it discards up to
    # batch_size - 1 validation samples).
    val_loader = torch.utils.data.DataLoader(
        val_set, batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True, drop_last=True)

    print("=> creating model")
    if args.FCCMnet == 'VGG':
        FCCM_net = models_inpytorch.vgg16(num_classes=19)
        # Replace the first conv so the network accepts 1-channel input.
        FCCM_net.features[0] = nn.Conv2d(1, 64, kernel_size=3, padding=1)
    elif args.FCCMnet == 'ResNet':
        FCCM_net = models.resnet18(num_classes=19)
    elif args.FCCMnet == 'CatNet':
        FCCM_net = models.catnet18(num_classes=19)
    elif args.FCCMnet == 'CatNet_FCCM':
        FCCM_net = models.catnet1(num_classes=19)
    elif args.FCCMnet == 'ResNet_FCCM':
        FCCM_net = models.resnet1(num_classes=19)
    elif args.FCCMnet == 'ImageWise':
        FCCM_net = models.ImageWiseNetwork()
    elif args.FCCMnet == 'PatchWise':
        FCCM_net = models.PatchWiseNetwork()
    elif args.FCCMnet == 'Baseline':
        FCCM_net = models.Baseline()
    else:
        raise ValueError('Unknown FCCMnet architecture: {}'.format(args.FCCMnet))

    FCCM_net = FCCM_net.cuda()

    # Pretrained weights are loaded into the *unwrapped* model, so checkpoints
    # below are saved from FCCM_net.module to keep keys free of the
    # DataParallel "module." prefix and reloadable here.
    if args.pretrained_model:
        print("=> using pre-trained weights for net")
        weights = torch.load(args.pretrained_model)
        FCCM_net.load_state_dict(weights['state_dict'])

    cudnn.benchmark = True
    FCCM_net = torch.nn.DataParallel(FCCM_net)

    print('=> setting adam solver')
    optimizer = torch.optim.Adam(FCCM_net.parameters(), args.lr,
                                 betas=(args.momentum, args.beta),
                                 weight_decay=args.weight_decay)

    best_error = float("inf")
    for epoch in tqdm(range(args.epochs)):
        # validation() may leave the model in eval mode; restore train mode.
        FCCM_net.train()
        loss = train(train_loader, FCCM_net, optimizer, args.epoch_size, train_writer)
        validation(val_loader, FCCM_net, epoch, train_writer)

        # Compare this epoch's own loss (computed above) to the best so far.
        is_best = loss <= best_error
        best_error = min(best_error, loss)
        save_checkpoint(
                args.save_path, {
                    'epoch': epoch + 1,
                    'state_dict': FCCM_net.module.state_dict()
                }, {
                    'epoch': epoch + 1,
                    'state_dict': optimizer.state_dict()
                }, is_best)
Exemple #15
0
    print('===> Start from scratch')

if cuda:
    model.cuda()
    cudnn.benchmark = True

#%%
parser = argparse.ArgumentParser()
args = parser.parse_args()

# Fixed sizes used by the random scale/crop augmentation below.
args.base_size = 513
args.crop_size = 513

composed_transforms = transforms.Compose([
    tr.RandomHorizontalFlip(),  # random horizontal flip
    tr.RandomScaleCrop(base_size=args.base_size,
                       crop_size=args.crop_size),  # random scale + crop
    tr.RandomGaussianBlur(),  # random Gaussian blur
    # Mean/std appear to be the common ImageNet channel statistics. Normalize
    # runs before ToTensor, so presumably tr.Normalize works on the
    # pre-tensor sample -- confirm in the `tr` module.
    tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),  # normalization
    tr.ToTensor()
])

#%%
import matplotlib.pyplot as plt
import numpy as np
# dataset.utils import decode_segmap

num_img_tr = len(dataloader)
for epoch in range(0, 10):
    # Build a fresh progress bar every epoch: a tqdm instance closes and
    # disables itself once it has been fully iterated, so a single shared bar
    # would only display progress for the first epoch.
    tbar = tqdm(dataloader)
    for i, sample in enumerate(tbar):
        image, target = sample['image'], sample['label']
Exemple #16
0
def main():
    """Jointly train DispNetS and PoseExpNet, validating and checkpointing each epoch.

    Configuration comes from the module-level ``parser``. Uses the globals
    ``best_error`` (a negative value means no validation result has been
    recorded yet), ``n_iter`` and ``device``. With ``args.evaluate`` set, the
    training loop is skipped and only the initial validation pass runs.

    Raises:
        ValueError: if ``args.dataset_format`` is neither 'stacked' nor
            'sequential'.
    """
    global best_error, n_iter, device
    args = parser.parse_args()
    if args.dataset_format == 'stacked':
        from datasets.stacked_sequence_folders import SequenceFolder
    elif args.dataset_format == 'sequential':
        from datasets.sequence_folders import SequenceFolder
    else:
        raise ValueError(
            'Unknown dataset_format: {}'.format(args.dataset_format))
    save_path = save_path_formatter(args, parser)
    args.save_path = 'checkpoints' / save_path
    print('=> will save everything to {}'.format(args.save_path))
    args.save_path.makedirs_p()  # creates the directory only if it is missing
    torch.manual_seed(args.seed)
    if args.evaluate:
        args.epochs = 0

    # TensorBoard writers: one for scalars, plus three optional writers under
    # valid/ that are handed to the validation functions.
    training_writer = SummaryWriter(args.save_path)
    output_writers = []
    if args.log_output:
        for i in range(3):
            output_writers.append(
                SummaryWriter(args.save_path / 'valid' / str(i)))

    # Data loading code
    normalize = custom_transforms.Normalize(mean=[0.5, 0.5, 0.5],
                                            std=[0.5, 0.5, 0.5])
    train_transform = custom_transforms.Compose([
        custom_transforms.RandomHorizontalFlip(),
        custom_transforms.RandomScaleCrop(),
        custom_transforms.ArrayToTensor(), normalize
    ])
    valid_transform = custom_transforms.Compose(
        [custom_transforms.ArrayToTensor(), normalize])

    print("=> fetching scenes in '{}'".format(args.data))
    train_set = SequenceFolder(
        args.data,
        transform=train_transform,
        seed=args.seed,
        train=True,
        sequence_length=args.sequence_length)
    # If no ground truth is available, the validation set is the same type as
    # the training set, so the photometric loss from warping can be measured.
    if args.with_gt:
        from datasets.validation_folders import ValidationSet
        val_set = ValidationSet(args.data, transform=valid_transform)
    else:
        val_set = SequenceFolder(
            args.data,
            transform=valid_transform,
            seed=args.seed,
            train=False,
            sequence_length=args.sequence_length,
        )
    # Author note: both splits are sequence-based (no left/right stereo pairs).
    print('{} samples found in {} train scenes'.format(
        len(train_set), len(train_set.scenes)))
    print('{} samples found in {} valid scenes'.format(
        len(val_set), len(val_set.scenes)))
    train_loader = torch.utils.data.DataLoader(
        dataset=train_set,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.workers,
        pin_memory=True)
    val_loader = torch.utils.data.DataLoader(
        dataset=val_set,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.workers,
        pin_memory=True)

    if args.epoch_size == 0:
        args.epoch_size = len(train_loader)

    # create model
    print("=> creating model")
    disp_net = models.DispNetS().to(device)
    output_exp = args.mask_loss_weight > 0
    if not output_exp:
        print("=> no mask loss, PoseExpnet will only output pose")
    pose_exp_net = models.PoseExpNet(
        nb_ref_imgs=args.sequence_length - 1,
        output_exp=output_exp).to(device)

    if args.pretrained_exp_pose:
        print("=> using pre-trained weights for explainabilty and pose net")
        weights = torch.load(args.pretrained_exp_pose)
        pose_exp_net.load_state_dict(weights['state_dict'], strict=False)
    else:
        pose_exp_net.init_weights()

    if args.pretrained_disp:
        print("=> using pre-trained weights for Dispnet")
        weights = torch.load(args.pretrained_disp)
        disp_net.load_state_dict(weights['state_dict'])
    else:
        disp_net.init_weights()

    cudnn.benchmark = True
    disp_net = torch.nn.DataParallel(disp_net)
    pose_exp_net = torch.nn.DataParallel(pose_exp_net)

    # Both networks are optimized jointly by a single Adam optimizer.
    print('=> setting adam solver')
    optim_params = [{
        'params': disp_net.parameters(),
        'lr': args.lr
    }, {
        'params': pose_exp_net.parameters(),
        'lr': args.lr
    }]
    optimizer = torch.optim.Adam(optim_params,
                                 betas=(args.momentum, args.beta),
                                 weight_decay=args.weight_decay)

    # Per-epoch summary and per-iteration CSV logs (headers written here).
    with open(args.save_path / args.log_summary, 'w') as csvfile:
        writer = csv.writer(csvfile, delimiter='\t')
        writer.writerow(['train_loss', 'validation_loss'])

    with open(args.save_path / args.log_full, 'w') as csvfile:
        writer = csv.writer(csvfile, delimiter='\t')
        writer.writerow(
            ['train_loss', 'photo_loss', 'explainability_loss', 'smooth_loss'])

    logger = TermLogger(n_epochs=args.epochs,
                        train_size=min(len(train_loader), args.epoch_size),
                        valid_size=len(val_loader))
    logger.epoch_bar.start()

    if args.pretrained_disp or args.evaluate:
        logger.reset_valid_bar()
        if args.with_gt:
            errors, error_names = validate_with_gt(args, val_loader, disp_net,
                                                   0, logger, output_writers)
        else:
            errors, error_names = validate_without_gt(args, val_loader,
                                                      disp_net, pose_exp_net,
                                                      0, logger,
                                                      output_writers)

        for error, name in zip(errors, error_names):
            training_writer.add_scalar(name, error, 0)
        error_string = ', '.join(
            '{} : {:.3f}'.format(name, error)
            for name, error in zip(error_names[2:9], errors[2:9]))
        logger.valid_writer.write(' * Avg {}'.format(error_string))

    # main training loop
    for epoch in range(args.epochs):
        logger.epoch_bar.update(epoch)

        # 1. train for one epoch
        logger.reset_train_bar()
        train_loss = train(args, train_loader, disp_net, pose_exp_net,
                           optimizer, args.epoch_size, logger, training_writer)
        logger.train_writer.write(' * Avg Loss : {:.3f}'.format(train_loss))

        # 2. evaluate on the validation set
        logger.reset_valid_bar()
        if args.with_gt:
            errors, error_names = validate_with_gt(args, val_loader, disp_net,
                                                   epoch, logger,
                                                   output_writers)
        else:
            errors, error_names = validate_without_gt(args, val_loader,
                                                      disp_net, pose_exp_net,
                                                      epoch, logger,
                                                      output_writers)

        error_string = ', '.join('{} : {:.3f}'.format(name, error)
                                 for name, error in zip(error_names, errors))
        logger.valid_writer.write(' * Avg {}'.format(error_string))

        for error, name in zip(errors, error_names):
            training_writer.add_scalar(name, error, epoch)

        # 3. remember lowest error and save checkpoint.
        # Choose the metric that best reflects model performance; careful,
        # some measures are to be maximized (such as a1, a2, a3).
        # NOTE(review): an earlier author note lists the without-gt metrics as
        # ['Total loss', 'Photo loss', 'Exp loss'] -- confirm.
        decisive_error = errors[1]
        # A negative best_error means nothing has been recorded yet, so the
        # first validated epoch is always the best so far.
        is_best = best_error < 0 or decisive_error < best_error
        if is_best:
            best_error = decisive_error

        save_checkpoint(args.save_path, {
            'epoch': epoch + 1,
            'state_dict': disp_net.module.state_dict()
        }, {
            'epoch': epoch + 1,
            'state_dict': pose_exp_net.module.state_dict()
        }, is_best)

        # Append this epoch's train loss and decisive validation error.
        with open(args.save_path / args.log_summary, 'a') as csvfile:
            writer = csv.writer(csvfile, delimiter='\t')
            writer.writerow([train_loss, decisive_error])
    logger.epoch_bar.finish()
Exemple #17
0
def main():
    """Train DPSNet, validating against ground truth and checkpointing every epoch.

    Configuration comes from the module-level ``parser``. Logs, CSV summaries
    and checkpoints are written under ``checkpoints/<save_path_formatter(...)>``.
    The learning rate is adjusted at the start of every epoch via
    ``adjust_learning_rate``.
    """
    global n_iter
    args = parser.parse_args()
    save_path = save_path_formatter(args, parser)
    args.save_path = 'checkpoints' / save_path
    print('=> will save everything to {}'.format(args.save_path))
    args.save_path.makedirs_p()
    torch.manual_seed(args.seed)

    training_writer = SummaryWriter(args.save_path)
    output_writers = []
    if args.log_output:
        for i in range(3):
            output_writers.append(
                SummaryWriter(args.save_path / 'valid' / str(i)))

    # Data loading code
    normalize = custom_transforms.Normalize(mean=[0.5, 0.5, 0.5],
                                            std=[0.5, 0.5, 0.5])
    train_transform = custom_transforms.Compose([
        custom_transforms.RandomScaleCrop(),
        custom_transforms.ArrayToTensor(), normalize
    ])

    valid_transform = custom_transforms.Compose(
        [custom_transforms.ArrayToTensor(), normalize])

    print("=> fetching scenes in '{}'".format(args.data))
    # Train and validation splits come from the same root, selected by ttype.
    train_set = SequenceFolder(args.data,
                               transform=train_transform,
                               seed=args.seed,
                               ttype=args.ttype)
    val_set = SequenceFolder(args.data,
                             transform=valid_transform,
                             seed=args.seed,
                             ttype=args.ttype2)

    print('{} samples found in {} train scenes'.format(len(train_set),
                                                       len(train_set.scenes)))
    print('{} samples found in {} valid scenes'.format(len(val_set),
                                                       len(val_set.scenes)))
    train_loader = torch.utils.data.DataLoader(train_set,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True)
    val_loader = torch.utils.data.DataLoader(val_set,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    if args.epoch_size == 0:
        args.epoch_size = len(train_loader)

    # create model
    print("=> creating model")

    dpsnet = PSNet(args.nlabel, args.mindepth).cuda()

    if args.pretrained_dps:
        print("=> using pre-trained weights for DPSNet")
        weights = torch.load(args.pretrained_dps)
        dpsnet.load_state_dict(weights['state_dict'])
    else:
        dpsnet.init_weights()

    cudnn.benchmark = True
    dpsnet = torch.nn.DataParallel(dpsnet)

    print('=> setting adam solver')
    optimizer = torch.optim.Adam(dpsnet.parameters(),
                                 args.lr,
                                 betas=(args.momentum, args.beta),
                                 weight_decay=args.weight_decay)

    with open(args.save_path / args.log_summary, 'w') as csvfile:
        writer = csv.writer(csvfile, delimiter='\t')
        writer.writerow(['train_loss', 'validation_loss'])

    with open(args.save_path / args.log_full, 'w') as csvfile:
        writer = csv.writer(csvfile, delimiter='\t')
        writer.writerow(['train_loss'])

    for epoch in range(args.epochs):
        adjust_learning_rate(args, optimizer, epoch)

        # train for one epoch, then validate
        train_loss = train(args, train_loader, dpsnet, optimizer,
                           args.epoch_size, training_writer)
        errors, error_names = validate_with_gt(args, val_loader, dpsnet, epoch,
                                               output_writers)

        # Report the validation summary (it was previously built but unused).
        error_string = ', '.join('{} : {:.3f}'.format(name, error)
                                 for name, error in zip(error_names, errors))
        print(' * Avg {}'.format(error_string))

        for error, name in zip(errors, error_names):
            training_writer.add_scalar(name, error, epoch)

        # Choose the metric that best reflects model performance; careful,
        # some measures are to be maximized (such as a1, a2, a3).
        decisive_error = errors[0]
        # Saves the unwrapped module's weights; the epoch index is passed too.
        save_checkpoint(args.save_path, {
            'epoch': epoch + 1,
            'state_dict': dpsnet.module.state_dict()
        }, epoch)

        with open(args.save_path / args.log_summary, 'a') as csvfile:
            writer = csv.writer(csvfile, delimiter='\t')
            writer.writerow([train_loss, decisive_error])
Exemple #18
0
def main():
    """Train DispNetS and PoseExpNet with an auxiliary DeepLab segmentation net.

    Configuration comes from the module-level ``parser``. Uses the globals
    ``best_error`` (a negative value means no validation result has been
    recorded yet), ``n_iter`` and ``device``. ``seg_net`` is not added to the
    optimizer below, so this optimizer does not update its weights
    (presumably it serves as a fixed segmentation prior -- confirm). With
    ``args.evaluate`` set, only the initial validation pass runs.

    Raises:
        ValueError: if ``args.dataset_format`` is neither 'stacked' nor
            'sequential'.
    """
    global best_error, n_iter, device
    args = parser.parse_args()
    if args.dataset_format == 'stacked':
        from datasets.stacked_sequence_folders import SequenceFolder
    elif args.dataset_format == 'sequential':
        from datasets.sequence_folders import SequenceFolder
    else:
        raise ValueError(
            'Unknown dataset_format: {}'.format(args.dataset_format))
    save_path = save_path_formatter(args, parser)
    args.save_path = 'checkpoints' / save_path
    print('=> will save everything to {}'.format(args.save_path))
    args.save_path.makedirs_p()
    torch.manual_seed(args.seed)
    if args.evaluate:
        args.epochs = 0

    tb_writer = SummaryWriter(args.save_path)
    # Data loading code
    normalize = custom_transforms.Normalize(mean=[0.5, 0.5, 0.5],
                                            std=[0.5, 0.5, 0.5])
    train_transform = custom_transforms.Compose([
        custom_transforms.RandomHorizontalFlip(),
        custom_transforms.RandomScaleCrop(),
        custom_transforms.ArrayToTensor(), normalize
    ])

    valid_transform = custom_transforms.Compose(
        [custom_transforms.ArrayToTensor(), normalize])

    print("=> fetching scenes in '{}'".format(args.data))
    train_set = SequenceFolder(args.data,
                               transform=train_transform,
                               seed=args.seed,
                               train=True,
                               sequence_length=args.sequence_length)

    # If no ground truth is available, the validation set is the same type as
    # the training set, so the photometric loss from warping can be measured.
    if args.with_gt:
        from datasets.validation_folders import ValidationSet
        val_set = ValidationSet(args.data, transform=valid_transform)
    else:
        val_set = SequenceFolder(
            args.data,
            transform=valid_transform,
            seed=args.seed,
            train=False,
            sequence_length=args.sequence_length,
        )
    print('{} samples found in {} train scenes'.format(len(train_set),
                                                       len(train_set.scenes)))
    print('{} samples found in {} valid scenes'.format(len(val_set),
                                                       len(val_set.scenes)))
    train_loader = torch.utils.data.DataLoader(train_set,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               drop_last=True)
    val_loader = torch.utils.data.DataLoader(val_set,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    if args.epoch_size == 0:
        args.epoch_size = len(train_loader)

    # create model
    print("=> creating model")

    disp_net = models.DispNetS().to(device)
    seg_net = DeepLab(num_classes=args.nclass,
                      backbone=args.backbone,
                      output_stride=args.out_stride,
                      sync_bn=args.sync_bn,
                      freeze_bn=args.freeze_bn).to(device)
    if args.pretrained_seg:
        print("=> using pre-trained weights for seg net")
        weights = torch.load(args.pretrained_seg)
        # The seg checkpoint is a bare state dict (no 'state_dict' key);
        # strict=False silently ignores missing/unexpected keys.
        seg_net.load_state_dict(weights, strict=False)
    output_exp = args.mask_loss_weight > 0
    if not output_exp:
        print("=> no mask loss, PoseExpnet will only output pose")
    pose_exp_net = models.PoseExpNet(
        nb_ref_imgs=args.sequence_length - 1,
        output_exp=output_exp).to(device)

    if args.pretrained_exp_pose:
        print("=> using pre-trained weights for explainabilty and pose net")
        weights = torch.load(args.pretrained_exp_pose)
        pose_exp_net.load_state_dict(weights['state_dict'], strict=False)
    else:
        pose_exp_net.init_weights()

    if args.pretrained_disp:
        print("=> using pre-trained weights for Dispnet")
        weights = torch.load(args.pretrained_disp)
        disp_net.load_state_dict(weights['state_dict'])
    else:
        disp_net.init_weights()

    cudnn.benchmark = True
    disp_net = torch.nn.DataParallel(disp_net)
    pose_exp_net = torch.nn.DataParallel(pose_exp_net)
    seg_net = torch.nn.DataParallel(seg_net)

    print('=> setting adam solver')

    optim_params = [{
        'params': disp_net.parameters(),
        'lr': args.lr
    }, {
        'params': pose_exp_net.parameters(),
        'lr': args.lr
    }]
    optimizer = torch.optim.Adam(optim_params,
                                 betas=(args.momentum, args.beta),
                                 weight_decay=args.weight_decay)

    with open(args.save_path / args.log_summary, 'w') as csvfile:
        writer = csv.writer(csvfile, delimiter='\t')
        writer.writerow(['train_loss', 'validation_loss'])

    with open(args.save_path / args.log_full, 'w') as csvfile:
        writer = csv.writer(csvfile, delimiter='\t')
        writer.writerow(
            ['train_loss', 'photo_loss', 'explainability_loss', 'smooth_loss'])

    logger = TermLogger(n_epochs=args.epochs,
                        train_size=min(len(train_loader), args.epoch_size),
                        valid_size=len(val_loader))
    logger.epoch_bar.start()

    if args.pretrained_disp or args.evaluate:
        logger.reset_valid_bar()
        if args.with_gt:
            # Same argument list as the per-epoch call below, including
            # seg_net. NOTE(review): confirm against validate_with_gt's
            # signature.
            errors, error_names = validate_with_gt(args, val_loader, disp_net,
                                                   seg_net, 0, logger,
                                                   tb_writer)
        else:
            errors, error_names = validate_without_gt(args, val_loader,
                                                      disp_net, pose_exp_net,
                                                      0, logger, tb_writer)
        for error, name in zip(errors, error_names):
            tb_writer.add_scalar(name, error, 0)
        error_string = ', '.join(
            '{} : {:.3f}'.format(name, error)
            for name, error in zip(error_names[2:9], errors[2:9]))
        logger.valid_writer.write(' * Avg {}'.format(error_string))

    for epoch in range(args.epochs):
        logger.epoch_bar.update(epoch)

        # train for one epoch
        logger.reset_train_bar()
        train_loss = train(args, train_loader, disp_net, pose_exp_net, seg_net,
                           optimizer, args.epoch_size, logger, tb_writer)
        logger.train_writer.write(' * Avg Loss : {:.3f}'.format(train_loss))

        # evaluate on validation set
        logger.reset_valid_bar()
        if args.with_gt:
            errors, error_names = validate_with_gt(args, val_loader, disp_net,
                                                   seg_net, epoch, logger,
                                                   tb_writer)
        else:
            errors, error_names = validate_without_gt(args, val_loader,
                                                      disp_net, pose_exp_net,
                                                      epoch, logger, tb_writer)
        error_string = ', '.join('{} : {:.3f}'.format(name, error)
                                 for name, error in zip(error_names, errors))
        logger.valid_writer.write(' * Avg {}'.format(error_string))

        for error, name in zip(errors, error_names):
            tb_writer.add_scalar(name, error, epoch)

        # Choose the metric that best reflects model performance; careful,
        # some measures are to be maximized (such as a1, a2, a3).
        decisive_error = errors[1]

        # Remember lowest error and save checkpoint. A negative best_error
        # means nothing has been recorded yet, so the first validated epoch is
        # always the best so far.
        is_best = best_error < 0 or decisive_error < best_error
        if is_best:
            best_error = decisive_error
        save_checkpoint(args.save_path, {
            'epoch': epoch + 1,
            'state_dict': disp_net.module.state_dict()
        }, {
            'epoch': epoch + 1,
            'state_dict': pose_exp_net.module.state_dict()
        }, is_best)

        with open(args.save_path / args.log_summary, 'a') as csvfile:
            writer = csv.writer(csvfile, delimiter='\t')
            writer.writerow([train_loss, decisive_error])
    logger.epoch_bar.finish()
Exemple #19
0
def main():
    """Train a standalone optical-flow network, validating and checkpointing each epoch.

    All settings come from the module-level ``parser``; the parsed namespace is
    stored in the global ``args``. Each epoch trains ``flow_net`` on the
    sequence dataset, optionally evaluates it on KITTI flow ground truth
    (``args.with_flow_gt``), logs scalars to TensorBoard, appends a row to the
    tab-separated summary file and saves flow-net and optimizer checkpoints.

    Raises:
        ValueError: if ``args.dataset_format`` or ``args.data_normalization``
            is not a supported value (previously this surfaced later as an
            unbound-name error).
    """
    global args, best_error, n_iter
    args = parser.parse_args()
    if args.dataset_format == 'stacked':
        from datasets.stacked_sequence_folders import SequenceFolder
    elif args.dataset_format == 'sequential':
        from datasets.sequence_folders import SequenceFolder
    else:
        raise ValueError('Unknown dataset_format: {}'.format(args.dataset_format))
    save_path = Path(args.name)
    args.save_path = 'checkpoints' / save_path
    print('=> will save everything to {}'.format(args.save_path))
    args.save_path.makedirs_p()
    torch.manual_seed(args.seed)
    if args.alternating:
        args.alternating_flags = np.array([False, False, True])

    training_writer = SummaryWriter(args.save_path)
    output_writers = []
    if args.log_output:
        for i in range(3):
            output_writers.append(SummaryWriter(args.save_path / 'valid' / str(i)))

    # Data loading code
    # Validation flow samples are rescaled to this fixed resolution.
    flow_loader_h, flow_loader_w = 256, 832

    if args.data_normalization == 'global':
        normalize = custom_transforms.Normalize(mean=[0.5, 0.5, 0.5],
                                                std=[0.5, 0.5, 0.5])
    elif args.data_normalization == 'local':
        normalize = custom_transforms.NormalizeLocally()
    else:
        raise ValueError('Unknown data_normalization: {}'.format(args.data_normalization))

    train_transform = custom_transforms.Compose([
        custom_transforms.RandomHorizontalFlip(),
        custom_transforms.RandomScaleCrop(),
        custom_transforms.ArrayToTensor(),
        normalize
    ])

    valid_transform = custom_transforms.Compose([custom_transforms.ArrayToTensor(), normalize])

    valid_flow_transform = custom_transforms.Compose([
        custom_transforms.Scale(h=flow_loader_h, w=flow_loader_w),
        custom_transforms.ArrayToTensor(),
        normalize
    ])

    print("=> fetching scenes in '{}'".format(args.data))
    train_set = SequenceFolder(
        args.data,
        transform=train_transform,
        seed=args.seed,
        train=True,
        sequence_length=args.sequence_length
    )

    # if no ground truth is available, the validation set is the same type as the
    # training set, to measure the photometric loss from warping
    val_set = SequenceFolder(
        args.data,
        transform=valid_transform,
        seed=args.seed,
        train=False,
        sequence_length=args.sequence_length,
    )

    if args.with_flow_gt:
        from datasets.validation_flow import ValidationFlow
        val_flow_set = ValidationFlow(root=args.kitti_dir,
                                      sequence_length=args.sequence_length,
                                      transform=valid_flow_transform)

    if args.DEBUG:
        # Only the samples list is truncated: len() looks __len__ up on the
        # type, so assigning an instance attribute named __len__ has no effect.
        train_set.samples = train_set.samples[:32]

    print('{} samples found in {} train scenes'.format(len(train_set), len(train_set.scenes)))
    print('{} samples found in {} valid scenes'.format(len(val_set), len(val_set.scenes)))
    train_loader = torch.utils.data.DataLoader(
        train_set, batch_size=args.batch_size, shuffle=True,
        num_workers=args.workers, pin_memory=True, drop_last=True)
    val_loader = torch.utils.data.DataLoader(
        val_set, batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True, drop_last=True)

    if args.with_flow_gt:
        # batch size is 1 since images in KITTI have different sizes
        val_flow_loader = torch.utils.data.DataLoader(
            val_flow_set, batch_size=1, shuffle=False,
            num_workers=args.workers, pin_memory=True, drop_last=True)

    if args.epoch_size == 0:
        args.epoch_size = len(train_loader)

    # create model
    print("=> creating model")

    if args.flownet == 'SpyNet':
        flow_net = getattr(models, args.flownet)(nlevels=args.nlevels, pre_normalization=normalize).cuda()
    else:
        flow_net = getattr(models, args.flownet)(nlevels=args.nlevels).cuda()

    # load pre-trained weights
    if args.pretrained_flow:
        print("=> using pre-trained weights for FlowNet")
        weights = torch.load(args.pretrained_flow)
        flow_net.load_state_dict(weights['state_dict'])

    if args.resume:
        print("=> resuming from checkpoint")
        flownet_weights = torch.load(args.save_path / 'flownet_checkpoint.pth.tar')
        flow_net.load_state_dict(flownet_weights['state_dict'])

    cudnn.benchmark = True
    flow_net = torch.nn.DataParallel(flow_net)

    print('=> setting adam solver')
    optimizer = torch.optim.Adam(flow_net.parameters(), args.lr,
                                 betas=(args.momentum, args.beta),
                                 weight_decay=args.weight_decay)

    # Decay the learning rate by 10x once 300 epochs have completed.
    milestones = [300]
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones, gamma=0.1, last_epoch=-1)

    if args.min:
        print("using min method")

    if args.resume and (args.save_path / 'optimizer_checkpoint.pth.tar').exists():
        print("=> loading optimizer from checkpoint")
        optimizer_weights = torch.load(args.save_path / 'optimizer_checkpoint.pth.tar')
        optimizer.load_state_dict(optimizer_weights['state_dict'])

    with open(args.save_path / args.log_summary, 'w') as csvfile:
        writer = csv.writer(csvfile, delimiter='\t')
        writer.writerow(['train_loss', 'validation_loss'])

    with open(args.save_path / args.log_full, 'w') as csvfile:
        writer = csv.writer(csvfile, delimiter='\t')
        writer.writerow(['train_loss', 'photo_cam_loss', 'photo_flow_loss', 'explainability_loss', 'smooth_loss'])

    if args.log_terminal:
        logger = TermLogger(n_epochs=args.epochs, train_size=min(len(train_loader), args.epoch_size), valid_size=len(val_loader))
        logger.epoch_bar.start()
    else:
        logger = None

    for epoch in range(args.epochs):
        if args.fix_flownet:
            for fparams in flow_net.parameters():
                fparams.requires_grad = False

        if args.log_terminal:
            logger.epoch_bar.update(epoch)
            logger.reset_train_bar()

        # train for one epoch
        train_loss = train(train_loader, flow_net, optimizer, args.epoch_size, logger, training_writer)
        # Since PyTorch 1.1 the scheduler must be stepped after the epoch's
        # optimizer updates; stepping first skips the schedule's first value.
        scheduler.step()

        if args.log_terminal:
            logger.train_writer.write(' * Avg Loss : {:.3f}'.format(train_loss))
            logger.reset_valid_bar()

        if args.with_flow_gt:
            flow_errors, flow_error_names = validate_flow_with_gt(val_flow_loader, flow_net, epoch, logger, output_writers)

            error_string = ', '.join('{} : {:.3f}'.format(name, error) for name, error in zip(flow_error_names, flow_errors))

            if args.log_terminal:
                logger.valid_writer.write(' * Avg {}'.format(error_string))
            else:
                print('Epoch {} completed'.format(epoch))

            for error, name in zip(flow_errors, flow_error_names):
                training_writer.add_scalar(name, error, epoch)

            decisive_error = flow_errors[0]
        else:
            # Without flow ground truth there is no validation metric. Fall back
            # to the training loss so best-model tracking and checkpointing
            # still run; the original referenced the undefined flow_errors here.
            decisive_error = train_loss

        if best_error < 0:
            best_error = decisive_error

        # remember lowest error and save checkpoint
        is_best = decisive_error <= best_error
        best_error = min(best_error, decisive_error)
        save_checkpoint(
            args.save_path, {
                'epoch': epoch + 1,
                'state_dict': flow_net.module.state_dict()
            }, {
                'epoch': epoch + 1,
                'state_dict': optimizer.state_dict()
            },
            is_best)

        with open(args.save_path / args.log_summary, 'a') as csvfile:
            writer = csv.writer(csvfile, delimiter='\t')
            writer.writerow([train_loss, decisive_error])
    if args.log_terminal:
        logger.epoch_bar.finish()
def main():
    """Train DispNet and PoseNet through ``train.Trainer``, checkpointing every epoch.

    Options are parsed from the module-level ``parser`` into the global ``opt``.
    When ``opt.resume`` points at an existing file, model/optimizer state is
    restored with ``init.resumer`` before the trainer is constructed, so the
    trainer is built from the objects that call returns. With ``opt.testOnly``
    set, the training step is skipped but a checkpoint is still written each
    epoch.
    """
    global opt, best_prec1
    opt = parser.parse_args()
    print(opt)

    # Data loading
    normalize = custom_transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    train_transform = custom_transforms.Compose([
        custom_transforms.RandomHorizontalFlip(),
        custom_transforms.RandomScaleCrop(),
        custom_transforms.ArrayToTensor(),
        normalize
    ])
    valid_transform = custom_transforms.Compose([
        custom_transforms.ArrayToTensor(),
        normalize
    ])
    print('Loading scenes in', opt.data_dir)
    train_set = SequenceFolder(opt.data_dir,
                               transform=train_transform,
                               seed=opt.seed,
                               train=True,
                               sequence_length=opt.sequence_length)

    # NOTE: no validation loader is built; val_set is only used for the count below.
    val_set = ValidationSet(opt.data_dir, transform=valid_transform)

    print(len(train_set), 'samples found')
    print(len(val_set), 'samples found')

    train_loader = torch.utils.data.DataLoader(train_set,
                                               batch_size=opt.batch_size,
                                               shuffle=True,
                                               num_workers=opt.workers,
                                               pin_memory=True)

    # Use the whole training set per epoch unless an explicit epoch size was
    # given. The original tested the unrelated ``opt.epoch`` here; getattr keeps
    # this working if the parser defines no ``epoch_size`` option.
    if getattr(opt, 'epoch_size', 0) == 0:
        opt.epoch_size = len(train_loader)
    # Done loading

    disp_model = dispnet.DispNet().cuda()
    pose_model = posenet.PoseNet().cuda()
    disp_model, pose_model, optimizer = init.setup(disp_model, pose_model, opt)
    print(disp_model, pose_model)

    if opt.resume:
        if os.path.isfile(opt.resume):
            disp_model, pose_model, optimizer, opt = init.resumer(
                opt, disp_model, pose_model, optimizer)
        else:
            print("=> no checkpoint found at '{}'".format(opt.resume))

    # Construct the trainer only after resuming, so it holds the (possibly
    # replaced) models, optimizer and options returned by init.resumer.
    trainer = train.Trainer(disp_model, pose_model, optimizer, opt)

    cudnn.benchmark = True
    for epoch in range(opt.start_epoch, opt.epochs):
        utils.adjust_learning_rate(opt, optimizer, epoch)
        print("Starting epoch number:", epoch + 1, "Learning rate:",
              optimizer.param_groups[0]["lr"])
        if not opt.testOnly:
            trainer.train(train_loader, epoch, opt)
        init.save_checkpoint(opt, disp_model, pose_model, optimizer, epoch)
Exemple #21
0
    def train(self):
        """Build the data pipeline and optimizer, then run the epoch loop.

        Creates train/val ``SequenceFolder`` datasets and loaders from
        ``self.config``, an Adam optimizer with one parameter group per entry of
        ``self.nets``, and a ``TermLogger``. Each of the ``self.epochs`` epochs
        calls ``self.training_inside_epoch()`` and
        ``self.validate_inside_epoch_without_gt()`` and logs their average losses.
        """
        if not self.train_flow:
            self.pose_net.train()
            self.disp_net.train()

        normalize = custom_transforms.Normalize(mean=[0.5, 0.5, 0.5],
                                                std=[0.5, 0.5, 0.5])
        train_transform = custom_transforms.Compose([
            custom_transforms.RandomHorizontalFlip(),
            custom_transforms.RandomScaleCrop(),
            custom_transforms.ArrayToTensor(), normalize
        ])

        valid_transform = custom_transforms.Compose(
            [custom_transforms.ArrayToTensor(), normalize])

        # Both splits share every dataset option except the split name and transform.
        dataset_kwargs = dict(
            seed=self.config['seed'],
            img_height=self.config['img_height'],
            img_width=self.config['img_width'],
            sequence_length=self.config['sequence_length'])

        self.train_set = SequenceFolder(
            self.config['data'],
            transform=train_transform,
            split='train',
            **dataset_kwargs)

        self.val_set = SequenceFolder(
            self.config['data'],
            transform=valid_transform,
            split='val',
            **dataset_kwargs)

        loader_kwargs = dict(
            batch_size=self.config['batch_size'],
            drop_last=True,
            num_workers=self.config['data_workers'],
            pin_memory=False)

        self.train_loader = torch.utils.data.DataLoader(
            self.train_set, shuffle=True, **loader_kwargs)

        # Keep validation order fixed: with drop_last=True, a shuffled loader may
        # drop different samples each epoch, making validation losses incomparable.
        self.val_loader = torch.utils.data.DataLoader(
            self.val_set, shuffle=False, **loader_kwargs)

        optim_params = [{
            'params': v.parameters(),
            'lr': self.config['learning_rate']
        } for v in self.nets.values()]

        self.optimizer = torch.optim.Adam(
            optim_params,
            betas=(self.config['momentum'], self.config['beta']),
            weight_decay=self.config['weight_decay'])

        # NOTE(review): the progress bar is sized from config['epoch'] while the
        # loop below runs self.epochs times; confirm these agree.
        self.logger = TermLogger(n_epochs=self.config['epoch'],
                                 train_size=min(len(self.train_loader),
                                                self.config['epoch_size']),
                                 valid_size=len(self.val_loader))
        self.logger.epoch_bar.start()

        for epoch in range(self.epochs):
            self.logger.epoch_bar.update(epoch)
            self.logger.reset_train_bar()
            epoch_train_loss = self.training_inside_epoch()
            self.logger.train_writer.write(
                ' training * Avg Loss : {:.3f}'.format(epoch_train_loss))

            self.logger.reset_valid_bar()
            epoch_val_loss = self.validate_inside_epoch_without_gt()
            self.logger.valid_writer.write(
                ' validation * Avg Loss : {:.3f}'.format(epoch_val_loss))
        self.logger.epoch_bar.finish()
Exemple #22
0
def main():
    """Jointly train DispNetS and PoseExpNet with photometric supervision.

    Settings come from the module-level ``parser`` and are stored in the global
    ``args``. Logs and checkpoints go to a timestamped directory under
    ``checkpoints/`` whose name encodes the main hyper-parameters. Each epoch
    trains, validates, and flags the model with the lowest validation
    photometric loss as best.

    Raises:
        ValueError: if ``args.dataset_format`` is not 'stacked' or 'sequential'
            (previously this surfaced later as an unbound-name error).
    """
    global args, best_photo_loss, n_iter
    args = parser.parse_args()
    if args.dataset_format == 'stacked':
        from datasets.stacked_sequence_folders import SequenceFolder
    elif args.dataset_format == 'sequential':
        from datasets.sequence_folders import SequenceFolder
    else:
        raise ValueError('Unknown dataset_format: {}'.format(args.dataset_format))
    save_path = Path('{}epochs{},seq{},b{},lr{},p{},m{},s{}'.format(
        args.epochs,
        ',epochSize' + str(args.epoch_size) if args.epoch_size > 0 else '',
        args.sequence_length, args.batch_size, args.lr, args.photo_loss_weight,
        args.mask_loss_weight, args.smooth_loss_weight))
    timestamp = datetime.datetime.now().strftime("%m-%d-%H:%M")
    args.save_path = 'checkpoints' / save_path / timestamp
    print('=> will save everything to {}'.format(args.save_path))
    args.save_path.makedirs_p()
    torch.manual_seed(args.seed)

    train_writer = SummaryWriter(args.save_path / 'train')
    valid_writer = SummaryWriter(args.save_path / 'valid')
    output_writers = []
    if args.log_output:
        for i in range(3):
            output_writers.append(
                SummaryWriter(args.save_path / 'valid' / str(i)))

    # Data loading code
    normalize = custom_transforms.Normalize(mean=[0.5, 0.5, 0.5],
                                            std=[0.5, 0.5, 0.5])
    input_transform = custom_transforms.Compose([
        custom_transforms.RandomHorizontalFlip(),
        custom_transforms.RandomScaleCrop(),
        custom_transforms.ArrayToTensor(), normalize
    ])
    valid_transform = custom_transforms.Compose([
        custom_transforms.ArrayToTensor(), normalize
    ])

    print("=> fetching scenes in '{}'".format(args.data))
    train_set = SequenceFolder(args.data,
                               transform=input_transform,
                               seed=args.seed,
                               train=True,
                               sequence_length=args.sequence_length)
    val_set = SequenceFolder(args.data,
                             transform=valid_transform,
                             seed=args.seed,
                             train=False,
                             sequence_length=args.sequence_length)
    print('{} samples found in {} train scenes'.format(len(train_set),
                                                       len(train_set.scenes)))
    print('{} samples found in {} valid scenes'.format(len(val_set),
                                                       len(val_set.scenes)))
    train_loader = torch.utils.data.DataLoader(train_set,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True)
    val_loader = torch.utils.data.DataLoader(val_set,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    if args.epoch_size == 0:
        args.epoch_size = len(train_loader)

    # create model
    print("=> creating model")

    disp_net = models.DispNetS().cuda()
    # The explainability mask is only requested when it has a non-zero loss weight.
    output_exp = args.mask_loss_weight > 0
    if not output_exp:
        print("=> no mask loss, PoseExpnet will only output pose")
    pose_exp_net = models.PoseExpNet(
        nb_ref_imgs=args.sequence_length - 1,
        output_exp=output_exp).cuda()

    if args.pretrained_exp_pose:
        print("=> using pre-trained weights for explainabilty and pose net")
        weights = torch.load(args.pretrained_exp_pose)
        pose_exp_net.load_state_dict(weights['state_dict'], strict=False)
    else:
        pose_exp_net.init_weights()

    if args.pretrained_disp:
        print("=> using pre-trained weights for Dispnet")
        weights = torch.load(args.pretrained_disp)
        disp_net.load_state_dict(weights['state_dict'])
    else:
        disp_net.init_weights()

    cudnn.benchmark = True
    disp_net = torch.nn.DataParallel(disp_net)
    pose_exp_net = torch.nn.DataParallel(pose_exp_net)

    print('=> setting adam solver')

    parameters = chain(disp_net.parameters(), pose_exp_net.parameters())
    optimizer = torch.optim.Adam(parameters,
                                 args.lr,
                                 betas=(args.momentum, args.beta),
                                 weight_decay=args.weight_decay)

    with open(args.save_path / args.log_summary, 'w') as csvfile:
        writer = csv.writer(csvfile, delimiter='\t')
        writer.writerow(['train_loss', 'validation_loss'])

    with open(args.save_path / args.log_full, 'w') as csvfile:
        writer = csv.writer(csvfile, delimiter='\t')
        writer.writerow(
            ['train_loss', 'photo_loss', 'explainability_loss', 'smooth_loss'])

    logger = TermLogger(n_epochs=args.epochs,
                        train_size=min(len(train_loader), args.epoch_size),
                        valid_size=len(val_loader))
    logger.epoch_bar.start()

    for epoch in range(args.epochs):
        logger.epoch_bar.update(epoch)

        # train for one epoch
        logger.reset_train_bar()
        train_loss = train(train_loader, disp_net, pose_exp_net, optimizer,
                           args.epoch_size, logger, train_writer)
        logger.train_writer.write(' * Avg Loss : {:.3f}'.format(train_loss))

        # evaluate on validation set
        logger.reset_valid_bar()
        valid_photo_loss, valid_exp_loss, valid_total_loss = validate(
            val_loader, disp_net, pose_exp_net, epoch, logger, output_writers)
        # The middle value is the explainability loss; it was mislabelled "Valid Loss".
        logger.valid_writer.write(
            ' * Avg Photo Loss : {:.3f}, Exp Loss : {:.3f}, Total Loss : {:.3f}'
            .format(valid_photo_loss, valid_exp_loss, valid_total_loss))
        # Loss is multiplied by 4 because it's only one scale, instead of 4 during training
        valid_writer.add_scalar('photometric_error', valid_photo_loss * 4, n_iter)
        valid_writer.add_scalar('explanability_loss', valid_exp_loss * 4, n_iter)
        valid_writer.add_scalar('total_loss', valid_total_loss * 4, n_iter)

        if best_photo_loss < 0:
            best_photo_loss = valid_photo_loss

        # remember lowest error and save checkpoint. ``<=`` because on the first
        # epoch best has just been set to the current loss; a strict ``<`` would
        # never flag that epoch as best.
        is_best = valid_photo_loss <= best_photo_loss
        best_photo_loss = min(valid_photo_loss, best_photo_loss)
        save_checkpoint(args.save_path, {
            'epoch': epoch + 1,
            'state_dict': disp_net.module.state_dict()
        }, {
            'epoch': epoch + 1,
            'state_dict': pose_exp_net.module.state_dict()
        }, is_best)

        with open(args.save_path / args.log_summary, 'a') as csvfile:
            writer = csv.writer(csvfile, delimiter='\t')
            writer.writerow([train_loss, valid_total_loss])
    logger.epoch_bar.finish()
Exemple #23
0
def main():
    """Train DispResNet and PoseExpNet on stereo sequences, keeping the best checkpoint.

    Settings come from the module-level ``parser``. ``args.evaluate`` sets the
    number of epochs to 0. In that case only the initial evaluation runs; that
    evaluation also happens when ``args.pretrained_disp`` is given.

    Raises:
        ValueError: if ``args.dataset_format`` is not 'sequential'. Only the
            sequential dataset module provides ``StereoSequenceFolder``; any
            other value previously failed later with an unbound-name error.
    """
    global best_error, n_iter, device
    args = parser.parse_args()
    if args.dataset_format != 'sequential':
        raise ValueError(
            "dataset_format must be 'sequential' for stereo training, got {!r}".format(
                args.dataset_format))
    from datasets.sequence_folders import StereoSequenceFolder
    save_path = save_path_formatter(args, parser)
    args.save_path = 'checkpoints' / save_path
    print('=> will save everything to {}'.format(args.save_path))
    args.save_path.makedirs_p()
    torch.manual_seed(args.seed)
    if args.evaluate:
        args.epochs = 0

    training_writer = SummaryWriter(args.save_path)
    output_writers = []
    if args.log_output:
        for i in range(3):
            output_writers.append(SummaryWriter(args.save_path / 'valid' / str(i)))

    # Data loading code
    normalize = custom_transforms.Normalize(mean=[0.5, 0.5, 0.5],
                                            std=[0.5, 0.5, 0.5])
    train_transform = custom_transforms.Compose([
        custom_transforms.RandomHorizontalFlip(),
        custom_transforms.RandomScaleCrop(),
        custom_transforms.ArrayToTensor(),
        normalize
    ])

    valid_transform = custom_transforms.Compose([custom_transforms.ArrayToTensor(), normalize])

    print("=> fetching scenes in '{}'".format(args.data))
    train_set = StereoSequenceFolder(
        args.data,
        transform=train_transform,
        seed=args.seed,
        train=True,
        sequence_length=args.sequence_length
    )

    # if no ground truth is available, the validation set is the same type as the
    # training set, to measure the photometric loss from warping
    if args.with_gt:
        from datasets.validation_folders import ValidationSet
        val_set = ValidationSet(
            args.data,
            transform=valid_transform
        )
    else:
        val_set = StereoSequenceFolder(
            args.data,
            transform=valid_transform,
            seed=args.seed,
            train=False,
            sequence_length=args.sequence_length,
        )
    print('{} samples found in {} train scenes'.format(len(train_set), len(train_set.scenes)))
    print('{} samples found in {} valid scenes'.format(len(val_set), len(val_set.scenes)))
    train_loader = torch.utils.data.DataLoader(
        train_set, batch_size=args.batch_size, shuffle=True,
        num_workers=args.workers, pin_memory=True)
    val_loader = torch.utils.data.DataLoader(
        val_set, batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True)

    # With epoch_size == 0, every epoch trains on all samples in train_set;
    # otherwise each epoch trains on only part of train_set.
    if args.epoch_size == 0:
        args.epoch_size = len(train_loader)

    # create model (initialise the network architectures)
    print("=> creating model")

    disp_net = models.DispResNet(3).to(device)
    output_exp = args.mask_loss_weight > 0
    if not output_exp:
        print("=> no mask loss, PoseExpnet will only output pose")
    # With a mask loss, PoseExpNet outputs both the mask and the pose estimate,
    # because the two outputs share the encoder network.
    pose_exp_net = models.PoseExpNet(nb_ref_imgs=args.sequence_length - 1, output_exp=output_exp).to(device)

    if args.pretrained_exp_pose:
        print("=> using pre-trained weights for explainabilty and pose net")
        weights = torch.load(args.pretrained_exp_pose)
        pose_exp_net.load_state_dict(weights['state_dict'], strict=False)
    else:
        pose_exp_net.init_weights()

    if args.pretrained_disp:
        print("=> using pre-trained weights for Dispnet")
        weights = torch.load(args.pretrained_disp)
        disp_net.load_state_dict(weights['state_dict'])
    else:
        disp_net.init_weights()

    cudnn.benchmark = True
    # Wrap both networks in DataParallel.
    disp_net = torch.nn.DataParallel(disp_net)
    pose_exp_net = torch.nn.DataParallel(pose_exp_net)

    # Optimise both networks together with Adam.
    print('=> setting adam solver')
    optim_params = [
        {'params': disp_net.parameters(), 'lr': args.lr},
        {'params': pose_exp_net.parameters(), 'lr': args.lr}
    ]
    optimizer = torch.optim.Adam(optim_params,
                                 betas=(args.momentum, args.beta),
                                 weight_decay=args.weight_decay)

    with open(args.save_path / args.log_summary, 'w') as csvfile:
        writer = csv.writer(csvfile, delimiter='\t')
        writer.writerow(['train_loss', 'validation_loss'])

    with open(args.save_path / args.log_full, 'w') as csvfile:
        writer = csv.writer(csvfile, delimiter='\t')
        writer.writerow(['train_loss', 'photo_loss', 'explainability_loss', 'smooth_loss'])

    # Evaluate the pretrained model before training.
    if args.pretrained_disp or args.evaluate:
        if args.with_gt:
            errors, error_names = validate_with_gt(args, val_loader, disp_net, 0, output_writers)
        else:
            errors, error_names = validate_without_gt(args, val_loader, disp_net, pose_exp_net, 0, output_writers)
        for error, name in zip(errors, error_names):
            training_writer.add_scalar(name, error, 0)
        error_string = ', '.join('{} : {:.3f}'.format(name, error) for name, error in zip(error_names[2:9], errors[2:9]))
        # The summary was built but never shown; print it so evaluation-only runs report results.
        print(' * Avg {}'.format(error_string))

    # Main training loop
    for epoch in range(args.epochs):

        # train for one epoch
        print('\n')
        train_loss = train(args, train_loader, disp_net, pose_exp_net, optimizer, args.epoch_size, training_writer, epoch)

        # evaluate on validation set
        print('\n')
        if args.with_gt:
            errors, error_names = validate_with_gt(args, val_loader, disp_net, epoch, output_writers)
        else:
            errors, error_names = validate_without_gt(args, val_loader, disp_net, pose_exp_net, epoch, output_writers)
        error_string = ', '.join('{} : {:.3f}'.format(name, error) for name, error in zip(error_names, errors))
        print(' * Avg {}'.format(error_string))

        for error, name in zip(errors, error_names):
            training_writer.add_scalar(name, error, epoch)

        # Up to you to choose the most relevant error to measure your model's performance, careful some measures are to maximize (such as a1,a2,a3)
        # Validation reports several losses (total final loss, warping loss,
        # mask regularisation loss); choose which one defines the best model.
        decisive_error = errors[0]
        if best_error < 0:
            best_error = decisive_error

        # remember lowest error and save the best validation checkpoint. ``<=``
        # because on the first epoch best has just been set to the current error.
        is_best = decisive_error <= best_error
        best_error = min(best_error, decisive_error)
        save_checkpoint(
            args.save_path, {
                'epoch': epoch + 1,
                'state_dict': disp_net.module.state_dict()
            }, {
                'epoch': epoch + 1,
                'state_dict': pose_exp_net.module.state_dict()
            },
            is_best)

        with open(args.save_path / args.log_summary, 'a') as csvfile:
            writer = csv.writer(csvfile, delimiter='\t')
            writer.writerow([train_loss, decisive_error])
Exemple #24
0
def main():
    """Train or evaluate MVDNet, or its DepthCons refinement head, checkpointing every epoch.

    Settings come from the module-level ``parser``. When ``args.train_cons`` is
    set, only ``depth_cons`` is optimised and ``mvdnet`` is put in eval mode;
    otherwise ``mvdnet`` is optimised. ``args.evaluate`` runs a single
    validation pass and stops. ``args.skip_v`` replaces validation with zeros.
    """
    global n_iter
    args = parser.parse_args()
    save_path = save_path_formatter(args, parser)
    args.save_path = 'checkpoints' / (args.exp + '_' + save_path)
    print('=> will save everything to {}'.format(args.save_path))
    args.save_path.makedirs_p()
    torch.manual_seed(args.seed)

    training_writer = SummaryWriter(args.save_path)
    output_writers = [SummaryWriter(args.save_path / 'valid' / str(i)) for i in range(3)]

    # Data loading code
    normalize = custom_transforms.Normalize(mean=[0.5, 0.5, 0.5],
                                            std=[0.5, 0.5, 0.5])

    train_transform = custom_transforms.Compose([
        custom_transforms.RandomScaleCrop(),
        custom_transforms.ArrayToTensor(), normalize
    ])

    valid_transform = custom_transforms.Compose(
        [custom_transforms.ArrayToTensor(), normalize])

    print("=> fetching scenes in '{}'".format(args.data))
    train_set = SequenceFolder(args.data,
                               transform=train_transform,
                               seed=args.seed,
                               ttype=args.ttype,
                               dataset=args.dataset)
    val_set = SequenceFolder(args.data,
                             transform=valid_transform,
                             seed=args.seed,
                             ttype=args.ttype2,
                             dataset=args.dataset)

    # Trim the training samples to a multiple of batch_size
    # (assumes len(train_set) counts train_set.samples — TODO confirm).
    train_set.samples = train_set.samples[:len(train_set) -
                                          len(train_set) % args.batch_size]

    print('{} samples found in {} train scenes'.format(len(train_set),
                                                       len(train_set.scenes)))
    print('{} samples found in {} valid scenes'.format(len(val_set),
                                                       len(val_set.scenes)))
    train_loader = torch.utils.data.DataLoader(train_set,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True)
    val_loader = torch.utils.data.DataLoader(val_set,
                                             batch_size=1,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    if args.epoch_size == 0:
        args.epoch_size = len(train_loader)

    # create model
    print("=> creating model")

    mvdnet = MVDNet(args.nlabel, args.mindepth, no_pool=args.no_pool).cuda()
    mvdnet.init_weights()
    if args.pretrained_mvdn:
        print("=> using pre-trained weights for MVDNet")
        weights = torch.load(args.pretrained_mvdn)
        mvdnet.load_state_dict(weights['state_dict'])

    depth_cons = DepthCons().cuda()
    depth_cons.init_weights()

    if args.pretrained_cons:
        print("=> using pre-trained weights for ConsNet")
        weights = torch.load(args.pretrained_cons)
        depth_cons.load_state_dict(weights['state_dict'])

    cons_loss_ = ConsLoss().cuda()
    print('=> setting adam solver')

    # Exactly one network is optimised: the consistency head or MVDNet itself.
    trained_net = depth_cons if args.train_cons else mvdnet
    optimizer = torch.optim.Adam(trained_net.parameters(),
                                 args.lr,
                                 betas=(args.momentum, args.beta),
                                 weight_decay=args.weight_decay)
    if args.train_cons:
        mvdnet.eval()

    cudnn.benchmark = True
    mvdnet = torch.nn.DataParallel(mvdnet)
    depth_cons = torch.nn.DataParallel(depth_cons)

    print(' ==> setting log files')
    with open(args.save_path / args.log_summary, 'w') as csvfile:
        writer = csv.writer(csvfile, delimiter='\t')
        writer.writerow([
            'train_loss', 'validation_abs_rel', 'validation_abs_diff',
            'validation_sq_rel', 'validation_a1', 'validation_a2',
            'validation_a3', 'mean_angle_error'
        ])

    with open(args.save_path / args.log_full, 'w') as csvfile:
        writer = csv.writer(csvfile, delimiter='\t')
        writer.writerow(['train_loss'])

    print(' ==> main Loop')
    for epoch in range(args.epochs):
        adjust_learning_rate(args, optimizer, epoch)

        # train for one epoch
        if args.evaluate:
            train_loss = 0
        else:
            train_loss = train(args, train_loader, mvdnet, depth_cons,
                               cons_loss_, optimizer, args.epoch_size,
                               training_writer, epoch)
        if not args.evaluate and args.skip_v:
            # Placeholder metrics when validation is skipped.
            error_names = [
                'abs_rel', 'abs_diff', 'sq_rel', 'a1', 'a2', 'a3', 'angle'
            ]
            errors = [0] * 7
        else:
            errors, error_names = validate_with_gt(args, val_loader, mvdnet,
                                                   depth_cons, epoch,
                                                   output_writers)

        error_string = ', '.join('{} : {:.3f}'.format(name, error)
                                 for name, error in zip(error_names, errors))
        # The summary was built but never shown; report it each epoch.
        print(' * Avg {}'.format(error_string))

        for error, name in zip(errors, error_names):
            training_writer.add_scalar(name, error, epoch)

        # Up to you to choose the most relevant error to measure your model's performance, careful some measures are to maximize (such as a1,a2,a3)
        decisive_error = errors[0]
        with open(args.save_path / args.log_summary, 'a') as csvfile:
            writer = csv.writer(csvfile, delimiter='\t')
            writer.writerow([
                train_loss, decisive_error, errors[1], errors[2], errors[3],
                errors[4], errors[5], errors[6]
            ])
        if args.evaluate:
            break
        if args.train_cons:
            save_checkpoint(args.save_path, {
                'epoch': epoch + 1,
                'state_dict': depth_cons.module.state_dict()
            }, epoch, file_prefixes=['cons'])
        else:
            save_checkpoint(args.save_path, {
                'epoch': epoch + 1,
                'state_dict': mvdnet.module.state_dict()
            }, epoch, file_prefixes=['mvdnet'])
Exemple #25
0
def main():
    """Train PSNet (DPSNet) for depth estimation from command-line arguments.

    Steps:
    - Parse ``parser`` and create ``checkpoints/<formatted name>``.
    - Seed torch and open the TensorBoard writer(s).
    - Build the training and validation ``SequenceFolder`` datasets. Only the
      training one gets a ``DataLoader``; the validation set is only counted.
    - Build ``PSNet``, either warm-started from ``args.pretrained_dps``
      (keeping only keys the model knows) or freshly initialised.
    - Train with Adam for ``args.epochs`` epochs. After each epoch, save a
      checkpoint and append the epoch's training loss to the summary CSV.

    NOTE(review): the summary CSV header lists a ``validation_loss`` column,
    but only ``train_loss`` is ever appended.
    """
    global n_iter
    args = parser.parse_args()
    args.save_path = 'checkpoints' / save_path_formatter(args, parser)
    print('=> will save everything to {}'.format(args.save_path))
    args.save_path.makedirs_p()
    torch.manual_seed(args.seed)

    training_writer = SummaryWriter(args.save_path)
    # One extra writer per logged validation sample, only when requested.
    output_writers = [
        SummaryWriter(args.save_path / 'valid' / str(idx)) for idx in range(3)
    ] if args.log_output else []

    # ---- data ----
    normalize = custom_transforms.Normalize(mean=[0.5, 0.5, 0.5],
                                            std=[0.5, 0.5, 0.5])
    train_transform = custom_transforms.Compose(
        [custom_transforms.RandomScaleCrop(),
         custom_transforms.ArrayToTensor(),
         normalize])
    valid_transform = custom_transforms.Compose(
        [custom_transforms.ArrayToTensor(), normalize])

    print("=> fetching scenes in '{}'".format(args.data))
    if args.pose_init:
        pose_file = '%s_poses.txt' % args.pose_init
    else:
        pose_file = 'poses.txt'
    # Keyword arguments shared by the training and validation folders.
    folder_kwargs = dict(seed=args.seed,
                         add_geo=args.geo,
                         depth_source=args.depth_init,
                         pose_source=pose_file,
                         scale=False)
    train_set = SequenceFolder(args.data,
                               transform=train_transform,
                               ttype=args.ttype,
                               **folder_kwargs)
    val_set = SequenceFolder(args.data,
                             transform=valid_transform,
                             ttype=args.ttype2,
                             **folder_kwargs)

    print('{} samples found in {} train scenes'.format(
        len(train_set), len(train_set.scenes)))
    print('{} samples found in {} valid scenes'.format(
        len(val_set), len(val_set.scenes)))
    train_loader = torch.utils.data.DataLoader(train_set,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True)
    if args.epoch_size == 0:
        args.epoch_size = len(train_loader)

    # ---- model ----
    print("=> creating model")
    depth_net = PSNet(args.nlabel,
                      args.mindepth,
                      add_geo_cost=args.geo,
                      depth_augment=False).cuda()
    if not args.pretrained_dps:
        depth_net.init_weights()
    else:
        print("=> using pre-trained weights for DPSNet")
        own_state = depth_net.state_dict()
        saved_state = torch.load(args.pretrained_dps)['state_dict']
        # Partial load: only entries whose key also exists in this model.
        own_state.update({key: tensor
                          for key, tensor in saved_state.items()
                          if key in own_state})
        depth_net.load_state_dict(own_state)

    cudnn.benchmark = True
    depth_net = torch.nn.DataParallel(depth_net)

    print('=> setting adam solver')
    trainable = (param for param in depth_net.parameters()
                 if param.requires_grad)
    optimizer = torch.optim.Adam(trainable,
                                 args.lr,
                                 betas=(args.momentum, args.beta),
                                 weight_decay=args.weight_decay)

    # Start both CSV logs afresh with their header rows.
    for log_name, header in ((args.log_summary, ['train_loss', 'validation_loss']),
                             (args.log_full, ['train_loss'])):
        with open(args.save_path / log_name, 'w') as log_file:
            csv.writer(log_file, delimiter='\t').writerow(header)

    for epoch in range(args.epochs):
        adjust_learning_rate(args, optimizer, epoch)

        train_loss = train(args, train_loader, depth_net, optimizer,
                           args.epoch_size, training_writer)

        checkpoint = {
            'epoch': epoch + 1,
            'state_dict': depth_net.module.state_dict()
        }
        save_checkpoint(args.save_path, checkpoint, epoch)

        with open(args.save_path / args.log_summary, 'a') as log_file:
            csv.writer(log_file, delimiter='\t').writerow([train_loss])
Exemple #26
0
from torchvision import datasets, transforms
import torch.utils.data as data
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from PIL import Image
import os
import custom_transforms as trans

# Training augmentation pipeline, applied in this order:
# random horizontal flip -> random Gaussian blur -> random scale + crop
# (arguments 700, 512; their exact meaning is defined by
# custom_transforms.RandomScaleCrop) -> normalisation -> tensor conversion.
img_transform = transforms.Compose(
    [
        trans.RandomHorizontalFlip(),
        trans.RandomGaussianBlur(),
        trans.RandomScaleCrop(700, 512),
        trans.Normalize(),
        trans.ToTensor(),
    ]
)


class TrainImageFolder(data.Dataset):
    def __init__(self, data_dir):
        self.f = open(os.path.join(data_dir, 'train_id.txt'))
        self.file_list = self.f.readlines()
        self.data_dir = data_dir

    def __getitem__(self, index):
        img = Image.open(
            os.path.join(self.data_dir, 'train_images',
                         self.file_list[index][:-1] + '.jpg')).convert('RGB')
        parse = Image.open(
            os.path.join(self.data_dir, 'train_segmentations',
Exemple #27
0
def main():
    """Jointly train the disparity, pose, mask and flow networks.

    Configuration is read from ``global_vars_dict['args']``. The function:
    - creates a timestamped checkpoint directory;
    - builds the training set and the optional validation sets (without
      ground truth, with depth ground truth, with flow ground truth);
    - instantiates the four networks, from pre-trained weights, fresh
      initialisation, or (with ``args.resume``) the checkpoints in
      ``args.save_path``;
    - optimises all four with a single Adam optimizer;
    - after every epoch, logs to TensorBoard and the summary CSV and saves
      checkpoints.

    Returns early, without training, when ``args.dataset_format`` is not
    one of ``'stacked'``, ``'sequential'`` or ``'sequential_with_gt'``.

    Raises:
        ValueError: if ``args.data_normalization`` or ``args.flownet`` is
            not a supported value. Previously both surfaced later as a
            ``NameError``.
    """
    global global_vars_dict
    args = global_vars_dict['args']
    best_error = -1  # sentinel: replaced by the first epoch's training loss

    # Output directory: checkpoints/<dataset dir name>/<month-day-hour:minute>.
    timestamp = datetime.datetime.now().strftime("%m-%d-%H:%M")

    args.save_path = Path('checkpoints') / Path(args.data_dir).stem / timestamp
    print('=> will save everything to {}'.format(args.save_path))
    args.save_path.makedirs_p()
    torch.manual_seed(args.seed)
    if args.alternating:
        args.alternating_flags = np.array([False, False, True])
    # TensorBoard writer
    tb_writer = SummaryWriter(args.save_path)

    # ------------------------------------------------------------------
    # Data loading
    # ------------------------------------------------------------------
    flow_loader_h, flow_loader_w = 256, 832

    if args.data_normalization == 'global':
        normalize = custom_transforms.Normalize(mean=[0.5, 0.5, 0.5],
                                                std=[0.5, 0.5, 0.5])
    elif args.data_normalization == 'local':
        normalize = custom_transforms.NormalizeLocally()
    else:
        raise ValueError('unsupported data_normalization: {!r}'.format(
            args.data_normalization))

    # Random rotation is only applied when the flow network is not frozen.
    if args.fix_flownet:
        train_transform = custom_transforms.Compose([
            custom_transforms.RandomHorizontalFlip(),
            custom_transforms.RandomScaleCrop(),
            custom_transforms.ArrayToTensor(), normalize
        ])
    else:
        train_transform = custom_transforms.Compose([
            custom_transforms.RandomRotate(),
            custom_transforms.RandomHorizontalFlip(),
            custom_transforms.RandomScaleCrop(),
            custom_transforms.ArrayToTensor(), normalize
        ])

    valid_transform = custom_transforms.Compose(
        [custom_transforms.ArrayToTensor(), normalize])

    valid_flow_transform = custom_transforms.Compose([
        custom_transforms.Scale(h=flow_loader_h, w=flow_loader_w),
        custom_transforms.ArrayToTensor(), normalize
    ])

    print("=> fetching scenes in '{}'".format(args.data_dir))

    # Training set/loader: exactly one is built.
    if args.dataset_format == 'stacked':
        from datasets.stacked_sequence_folders import SequenceFolder
    elif args.dataset_format == 'sequential':
        from datasets.sequence_folders import SequenceFolder
    elif args.dataset_format == 'sequential_with_gt':  # with all possible gt
        from datasets.sequence_mc import SequenceFolder
    else:
        print("=> unsupported dataset_format '{}', aborting".format(
            args.dataset_format))
        return
    # Constructed once for every format. The 'stacked' branch used to import
    # the class without ever building train_set, which crashed below.
    # NOTE(review): assumes the stacked SequenceFolder accepts the same
    # keyword arguments as the sequential variants -- confirm.
    train_set = SequenceFolder(
        args.data_dir,
        transform=train_transform,
        seed=args.seed,
        train=True,
        sequence_length=args.sequence_length,  # e.g. 5
        target_transform=None)

    if args.DEBUG:
        # Truncating `samples` is what shrinks the dataset (presumably
        # SequenceFolder.__len__ returns len(self.samples) -- confirm).
        # An instance attribute named __len__ would be ignored, because
        # len() looks the method up on the class.
        train_set.samples = train_set.samples[:32]
    train_loader = torch.utils.data.DataLoader(train_set,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               drop_last=True)

    if args.epoch_size == 0:
        args.epoch_size = len(train_loader)

    # Validation sets/loaders, built one by one. Each stays None when disabled.
    val_loader = None
    val_loader_depth = None
    val_flow_loader = None

    # If no ground truth is available, the validation set is the same type as
    # the training set, to measure the photometric loss from warping.
    if args.val_without_gt:
        # Same layout as the training folders plus one extra directory level;
        # images only.
        from datasets.sequence_folders2 import SequenceFolder
        val_set_without_gt = SequenceFolder(
            args.data_dir,
            transform=valid_transform,
            seed=None,
            train=False,
            sequence_length=args.sequence_length,
            target_transform=None)
        val_loader = torch.utils.data.DataLoader(val_set_without_gt,
                                                 batch_size=args.batch_size,
                                                 shuffle=False,
                                                 num_workers=args.workers,
                                                 pin_memory=True,
                                                 drop_last=True)

    if args.val_with_depth_gt:
        from datasets.validation_folders2 import ValidationSet

        val_set_with_depth_gt = ValidationSet(args.data_dir,
                                              transform=valid_transform)

        val_loader_depth = torch.utils.data.DataLoader(
            val_set_with_depth_gt,
            batch_size=args.batch_size,
            shuffle=False,
            num_workers=args.workers,
            pin_memory=True,
            drop_last=True)

    if args.val_with_flow_gt:  # not available yet: loader is built but unused
        from datasets.validation_flow import ValidationFlow
        val_flow_set = ValidationFlow(root=args.kitti_dir,
                                      sequence_length=args.sequence_length,
                                      transform=valid_flow_transform)
        val_flow_loader = torch.utils.data.DataLoader(
            val_flow_set,
            batch_size=1,
            # batch size is 1 since images in kitti have different sizes
            shuffle=False,
            num_workers=args.workers,
            pin_memory=True,
            drop_last=True)

    print('{} samples found in {} train scenes'.format(len(train_set),
                                                       len(train_set.scenes)))
    if args.val_without_gt:
        print('{} samples found in {} valid scenes'.format(
            len(val_set_without_gt), len(val_set_without_gt.scenes)))

    # ------------------------------------------------------------------
    # 1. Create the models
    # ------------------------------------------------------------------
    print("=> creating model")
    # 1.1 disparity network
    disp_net = getattr(models, args.dispnet)().cuda()
    output_exp = True  # args.mask_loss_weight > 0
    if not output_exp:
        print("=> no mask loss, PoseExpnet will only output pose")
    # 1.2 pose network
    pose_net = getattr(models, args.posenet)(nb_ref_imgs=args.sequence_length -
                                             1).cuda()

    # 1.3 flow network
    if args.flownet == 'SpyNet':
        flow_net = getattr(models,
                           args.flownet)(nlevels=args.nlevels,
                                         pre_normalization=normalize).cuda()
    elif args.flownet in ('FlowNetC6', 'FlowNetS', 'Back2Future'):
        flow_net = getattr(models, args.flownet)(nlevels=args.nlevels).cuda()
    else:
        raise ValueError('unsupported flownet: {!r}'.format(args.flownet))

    # 1.4 mask network
    mask_net = getattr(models,
                       args.masknet)(nb_ref_imgs=args.sequence_length - 1,
                                     output_exp=True).cuda()

    # ------------------------------------------------------------------
    # 2. Load parameters
    # ------------------------------------------------------------------
    # 2.1 pose
    if args.pretrained_pose:
        print("=> using pre-trained weights for explainabilty and pose net")
        weights = torch.load(args.pretrained_pose)
        pose_net.load_state_dict(weights['state_dict'])
    else:
        pose_net.init_weights()

    # 2.2 mask
    if args.pretrained_mask:
        print("=> using pre-trained weights for mask net")
        weights = torch.load(args.pretrained_mask)
        mask_net.load_state_dict(weights['state_dict'])
    else:
        mask_net.init_weights()

    # 2.3 disparity
    if args.pretrained_disp:
        print("=> using pre-trained weights from {}".format(
            args.pretrained_disp))
        weights = torch.load(args.pretrained_disp)
        disp_net.load_state_dict(weights['state_dict'])
    else:
        disp_net.init_weights()

    # 2.4 flow
    if args.pretrained_flow:
        print("=> using pre-trained weights for FlowNet")
        weights = torch.load(args.pretrained_flow)
        flow_net.load_state_dict(weights['state_dict'])
    else:
        flow_net.init_weights()

    if args.resume:
        # NOTE(review): save_path was freshly timestamped above, so these
        # checkpoints only exist if that directory was already populated --
        # confirm the intended resume location.
        print("=> resuming from checkpoint")
        dispnet_weights = torch.load(args.save_path /
                                     'dispnet_checkpoint.pth.tar')
        posenet_weights = torch.load(args.save_path /
                                     'posenet_checkpoint.pth.tar')
        masknet_weights = torch.load(args.save_path /
                                     'masknet_checkpoint.pth.tar')
        flownet_weights = torch.load(args.save_path /
                                     'flownet_checkpoint.pth.tar')
        disp_net.load_state_dict(dispnet_weights['state_dict'])
        pose_net.load_state_dict(posenet_weights['state_dict'])
        flow_net.load_state_dict(flownet_weights['state_dict'])
        mask_net.load_state_dict(masknet_weights['state_dict'])

    cudnn.benchmark = True
    disp_net = torch.nn.DataParallel(disp_net)
    pose_net = torch.nn.DataParallel(pose_net)
    mask_net = torch.nn.DataParallel(mask_net)
    flow_net = torch.nn.DataParallel(flow_net)

    print('=> setting adam solver')

    # One optimizer over all four networks; frozen networks are handled by
    # setting requires_grad=False in the epoch loop below.
    parameters = chain(disp_net.parameters(), pose_net.parameters(),
                       mask_net.parameters(), flow_net.parameters())
    optimizer = torch.optim.Adam(parameters,
                                 args.lr,
                                 betas=(args.momentum, args.beta),
                                 weight_decay=args.weight_decay)

    if args.resume and (args.save_path /
                        'optimizer_checkpoint.pth.tar').exists():
        print("=> loading optimizer from checkpoint")
        optimizer_weights = torch.load(args.save_path /
                                       'optimizer_checkpoint.pth.tar')
        optimizer.load_state_dict(optimizer_weights['state_dict'])

    with open(args.save_path / args.log_summary, 'w') as csvfile:
        writer = csv.writer(csvfile, delimiter='\t')
        writer.writerow(['train_loss', 'validation_loss'])

    with open(args.save_path / args.log_full, 'w') as csvfile:
        writer = csv.writer(csvfile, delimiter='\t')
        writer.writerow([
            'train_loss', 'photo_cam_loss', 'photo_flow_loss',
            'explainability_loss', 'smooth_loss'
        ])

    if args.log_terminal:
        # Size the validation bar from whichever validation loader exists.
        # Previously this read val_loader_depth unconditionally and crashed
        # when depth validation was disabled.
        if val_loader_depth is not None:
            valid_size = len(val_loader_depth)
        elif val_loader is not None:
            valid_size = len(val_loader)
        else:
            valid_size = 0
        logger = TermLogger(n_epochs=args.epochs,
                            train_size=min(len(train_loader), args.epoch_size),
                            valid_size=valid_size)
        logger.epoch_bar.start()
    else:
        logger = None

    # Evaluate once before training (this is the "epoch 0" evaluation).
    if args.pretrained_disp or args.evaluate:
        if logger is not None:  # logger is None when log_terminal is off
            logger.reset_valid_bar()
        # Validation without ground truth is intentionally skipped here.
        if args.val_with_depth_gt:
            depth_errors, depth_error_names = validate_depth_with_gt(
                val_loader_depth,
                disp_net,
                epoch=0,
                logger=logger,
                tb_writer=tb_writer,
                global_vars_dict=global_vars_dict)

    # ------------------------------------------------------------------
    # 3. Main loop. Epoch 0 was already evaluated before entering it.
    # ------------------------------------------------------------------
    for epoch in range(1, args.epochs):
        # 3.1 Decide which of the four sub-networks are trained. Re-applied
        # every epoch in case train_gt toggles requires_grad.
        for frozen, net in ((args.fix_flownet, flow_net),
                            (args.fix_masknet, mask_net),
                            (args.fix_posenet, pose_net),
                            (args.fix_dispnet, disp_net)):
            if frozen:
                for fparams in net.parameters():
                    fparams.requires_grad = False

        if args.log_terminal:
            logger.epoch_bar.update(epoch)
            logger.reset_train_bar()
        print('\nepoch [{}/{}]\n'.format(epoch + 1, args.epochs))

        # 3.2 Train for one epoch.
        train_loss = train_gt(train_loader, disp_net, pose_net, mask_net,
                              flow_net, optimizer, logger, tb_writer,
                              global_vars_dict)

        # 3.3 Evaluate on the validation set(s).
        if args.val_without_gt:
            # NOTE(review): epoch is hard-coded to 0 here, unlike the depth
            # validation below -- confirm whether validate_without_gt relies
            # on epoch == 0 (e.g. to log inputs) before changing it.
            val_loss = validate_without_gt(val_loader,
                                           disp_net,
                                           pose_net,
                                           mask_net,
                                           flow_net,
                                           epoch=0,
                                           logger=logger,
                                           tb_writer=tb_writer,
                                           nb_writers=3,
                                           global_vars_dict=global_vars_dict)

        if args.val_with_depth_gt:
            depth_errors, depth_error_names = validate_depth_with_gt(
                val_loader_depth,
                disp_net,
                epoch=epoch,
                logger=logger,
                tb_writer=tb_writer,
                global_vars_dict=global_vars_dict)

        # Flow validation (val_flow_loader) is currently disabled.

        # 3.4 Decisive error written to the summary CSV. Every option is
        # currently disabled, so it is always 0. Candidates, depending on
        # which network is trained:
        #   posenet  -> flow_errors[-2] (epe_rigid_with_gt_mask)
        #   dispnet  -> depth errors[0] (abs_diff)
        #   flownet  -> flow_errors[-1] (epe_non_rigid_with_gt_mask)
        #   masknet  -> flow_errors[3]  (percent outliers)
        # Careful: some measures are to be maximised (such as a1, a2, a3).
        # Previously it was left unbound when all four networks were frozen.
        decisive_error = 0

        # 3.5 Logging.
        if args.log_terminal:
            logger.train_writer.write(
                ' * Avg Loss : {:.3f}'.format(train_loss))
            logger.reset_valid_bar()
        # Per-epoch TensorBoard scalars.
        tb_writer.add_scalar('epoch/train_loss', train_loss, epoch)
        if args.val_without_gt:
            tb_writer.add_scalar('epoch/val_loss', val_loss, epoch)

        if args.val_with_depth_gt:
            for error, name in zip(depth_errors, depth_error_names):
                tb_writer.add_scalar('epoch/' + name, error, epoch)

        # 3.6 Track the lowest training loss and save checkpoints.
        if best_error < 0:
            best_error = train_loss

        is_best = train_loss <= best_error
        best_error = min(best_error, train_loss)
        save_checkpoint(args.save_path, {
            'epoch': epoch + 1,
            'state_dict': disp_net.module.state_dict()
        }, {
            'epoch': epoch + 1,
            'state_dict': pose_net.module.state_dict()
        }, {
            'epoch': epoch + 1,
            'state_dict': mask_net.module.state_dict()
        }, {
            'epoch': epoch + 1,
            'state_dict': flow_net.module.state_dict()
        }, is_best)

        with open(args.save_path / args.log_summary, 'a') as csvfile:
            writer = csv.writer(csvfile, delimiter='\t')
            writer.writerow([train_loss, decisive_error])
    if args.log_terminal:
        logger.epoch_bar.finish()
Exemple #28
0
def main():
    """Optimise a patch against an optical-flow network and validate it on KITTI.

    Steps:
    - Parse ``parser`` arguments.
    - Build the training sequence loader and a KITTI 2015/2012 flow
      validation loader.
    - Load the flow network selected by ``args.flownet`` with its bundled
      pre-trained weights.
    - Initialise a circular or square patch, optionally from an image.
    - For ``args.epochs`` epochs: update the patch via ``train``, evaluate
      it with ``validate_flow_with_gt``, log the errors to TensorBoard and
      save the patch to ``<save_path>/patch_epoch_<epoch>``.

    Exits via ``sys.exit`` with a message when ``args.valset`` or
    ``args.patch_type`` is not a supported value.
    """
    global args, best_error, n_iter
    args = parser.parse_args()
    save_path = Path(args.name)
    args.save_path = 'checkpoints' / save_path  # /timestamp
    print('=> will save everything to {}'.format(args.save_path))
    args.save_path.makedirs_p()
    torch.manual_seed(args.seed)

    training_writer = SummaryWriter(args.save_path)
    output_writer = SummaryWriter(args.save_path / 'valid')

    # ---- data ----
    flow_loader_h, flow_loader_w = 384, 1280

    train_transform = custom_transforms.Compose([
        custom_transforms.RandomHorizontalFlip(),
        custom_transforms.RandomScaleCrop(h=256, w=256),
        custom_transforms.ArrayToTensor(),
    ])

    valid_transform = custom_transforms.Compose([
        custom_transforms.Scale(h=flow_loader_h, w=flow_loader_w),
        custom_transforms.ArrayToTensor()
    ])

    print("=> fetching scenes in '{}'".format(args.data))
    train_set = SequenceFolder(args.data,
                               transform=train_transform,
                               seed=args.seed,
                               train=True,
                               sequence_length=3)

    if args.valset == "kitti2015":
        from datasets.validation_flow import ValidationFlowKitti2015
        val_set = ValidationFlowKitti2015(root=args.kitti_data,
                                          transform=valid_transform)
    elif args.valset == "kitti2012":
        from datasets.validation_flow import ValidationFlowKitti2012
        val_set = ValidationFlowKitti2012(root=args.kitti_data,
                                          transform=valid_transform)
    else:
        # Previously val_set was left unbound and len(val_set) below raised
        # a NameError; fail fast like the patch_type check does.
        sys.exit("Please choose kitti2015 or kitti2012 as validation set")

    if args.DEBUG:
        # Truncating `samples` is what shrinks the dataset (presumably
        # SequenceFolder.__len__ returns len(self.samples) -- confirm).
        # An instance-level __len__ would be ignored by len().
        train_set.samples = train_set.samples[:32]

    print('{} samples found in {} train scenes'.format(len(train_set),
                                                       len(train_set.scenes)))
    print('{} samples found in valid scenes'.format(len(val_set)))
    train_loader = torch.utils.data.DataLoader(train_set,
                                               batch_size=1,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               drop_last=True)

    val_loader = torch.utils.data.DataLoader(
        val_set,
        # batch size is 1 since images in kitti have different sizes
        batch_size=1,
        shuffle=False,
        num_workers=args.workers,
        pin_memory=True,
        drop_last=True)

    if args.epoch_size == 0:
        args.epoch_size = len(train_loader)

    # ---- model ----
    print("=> creating model")

    if args.flownet == 'SpyNet':
        flow_net = getattr(models, args.flownet)(nlevels=6, pretrained=True)
    elif args.flownet == 'Back2Future':
        flow_net = getattr(
            models, args.flownet)(pretrained='pretrained/b2f_rm_hard.pth.tar')
    elif args.flownet == 'PWCNet':
        flow_net = models.pwc_dc_net(
            'pretrained/pwc_net_chairs.pth.tar')  # pwc_net.pth.tar')
    else:
        flow_net = getattr(models, args.flownet)()

    # SpyNet / Back2Future / PWCNet already loaded their weights in the
    # constructor above; the FlowNet variants load them from disk here.
    if args.flownet in ('SpyNet', 'Back2Future', 'PWCNet'):
        print("=> using pre-trained weights for " + args.flownet)
    elif args.flownet == 'FlowNetC':
        print("=> using pre-trained weights for FlowNetC")
        weights = torch.load('pretrained/FlowNet2-C_checkpoint.pth.tar')
        flow_net.load_state_dict(weights['state_dict'])
    elif args.flownet == 'FlowNetS':
        print("=> using pre-trained weights for FlowNetS")
        weights = torch.load('pretrained/flownets.pth.tar')
        flow_net.load_state_dict(weights['state_dict'])
    elif args.flownet == 'FlowNet2':
        print("=> using pre-trained weights for FlowNet2")
        weights = torch.load('pretrained/FlowNet2_checkpoint.pth.tar')
        flow_net.load_state_dict(weights['state_dict'])
    else:
        flow_net.init_weights()

    pytorch_total_params = sum(p.numel() for p in flow_net.parameters())
    print("Number of model paramters: " + str(pytorch_total_params))

    flow_net = flow_net.cuda()

    cudnn.benchmark = True

    # ---- patch initialisation ----
    if args.patch_type == 'circle':
        patch, mask, patch_shape = init_patch_circle(args.image_size,
                                                     args.patch_size)
        patch_init = patch.copy()
    elif args.patch_type == 'square':
        patch, patch_shape = init_patch_square(args.image_size,
                                               args.patch_size)
        patch_init = patch.copy()
        mask = np.ones(patch_shape)
    else:
        sys.exit("Please choose a square or circle patch")

    # An image-based patch overrides the shape-based one built above.
    if args.patch_path:
        patch, mask, patch_shape = init_patch_from_image(
            args.patch_path, args.mask_path, args.image_size, args.patch_size)
        patch_init = patch.copy()

    if args.log_terminal:
        logger = TermLogger(n_epochs=args.epochs,
                            train_size=min(len(train_loader), args.epoch_size),
                            valid_size=len(val_loader),
                            attack_size=args.max_count)
        logger.epoch_bar.start()
    else:
        logger = None

    for epoch in range(args.epochs):

        if args.log_terminal:
            logger.epoch_bar.update(epoch)
            logger.reset_train_bar()

        # Train for one epoch; `train` returns the updated patch state.
        patch, mask, patch_init, patch_shape = train(patch, mask, patch_init,
                                                     patch_shape, train_loader,
                                                     flow_net, epoch, logger,
                                                     training_writer)

        # Validate
        errors, error_names = validate_flow_with_gt(patch, mask, patch_shape,
                                                    val_loader, flow_net,
                                                    epoch, logger,
                                                    output_writer)

        error_string = ', '.join('{} : {:.3f}'.format(name, error)
                                 for name, error in zip(error_names, errors))

        if args.log_terminal:
            logger.valid_writer.write(' * Avg {}'.format(error_string))
        else:
            print('Epoch {} completed'.format(epoch))

        for error, name in zip(errors, error_names):
            training_writer.add_scalar(name, error, epoch)

        torch.save(patch, args.save_path / 'patch_epoch_{}'.format(str(epoch)))

    if args.log_terminal:
        logger.epoch_bar.finish()
Exemple #29
0
def main():
    """Train DepthNet and MotionNet jointly, driven by command-line arguments.

    Steps: parse arguments, select the dataset pipeline, create the output
    directory and TensorBoard writer, build the data loaders, the two models,
    the Adam optimizer and the loss calculator, then alternate one training
    epoch with one validation pass. After every epoch a checkpoint of both
    networks is written via ``save_checkpoint``. It is flagged as best when
    the validation ``a3`` accuracy exceeds the best value seen so far.

    Reads and updates the module-level global ``best_a3`` and reads the
    module-level global ``device``.

    Raises:
        ValueError: if ``args.dataset_format`` is neither ``'stacked'`` nor
            ``'sequential'``.
    """
    global best_a3, n_iter, device
    args = parser.parse_args()
    # Enable autograd anomaly detection to locate where gradients break
    # (e.g. NaN/inf produced in backward).
    # NOTE(review): this is a debugging aid; consider gating it behind a flag
    # for full-speed training runs.
    torch.autograd.set_detect_anomaly(True)

    # ====== step 1: load the data pipeline matching the dataset format ======
    if args.dataset_format == 'stacked':
        from DataFlow.stacked_sequence_folders import SequenceFolder
    elif args.dataset_format == 'sequential':
        from DataFlow.sequence_folders import SequenceFolder
    else:
        # Without this branch SequenceFolder stays unbound and the script
        # fails much later with an UnboundLocalError; fail early instead.
        raise ValueError(
            "unknown dataset_format {!r}; expected 'stacked' or "
            "'sequential'".format(args.dataset_format))

    # ====== step 2: prepare the output directory ======
    save_path = save_path_formatter(args, parser)
    # Compare strings with ==; `is` tests object identity, which is not
    # guaranteed for equal strings. The raw string keeps the literal
    # backslash without relying on the invalid escape sequence '\c'.
    # NOTE(review): `str / save_path` and `.makedirs_p()` below assume
    # save_path_formatter returns a path.py-style Path; confirm.
    if sys.platform == 'win32':
        args.save_path = r'.\checkpoints' / save_path
    else:  # non-Windows platforms (e.g. linux)
        args.save_path = 'checkpoints' / save_path
    print('=> will save everything to {}'.format(args.save_path))
    args.save_path.makedirs_p()
    tb_writer = SummaryWriter(args.save_path)  # TensorBoard writer

    # ====== step 3: fix the random seed for reproducible experiments ======
    torch.manual_seed(args.seed)

    # ========= step 4: data preparation =========
    # Data augmentation: random flip + random scale/crop for training;
    # validation only converts to tensors and normalizes.
    normalize = custom_transforms.Normalize(mean=[0.5, 0.5, 0.5],
                                            std=[0.5, 0.5, 0.5])
    train_transform = custom_transforms.Compose([
        custom_transforms.RandomHorizontalFlip(),
        custom_transforms.RandomScaleCrop(),
        custom_transforms.ArrayToTensor(), normalize
    ])
    valid_transform = custom_transforms.Compose(
        [custom_transforms.ArrayToTensor(), normalize])
    # Training set
    print("=> fetching data from '{}'".format(args.data))
    train_set = SequenceFolder(args.data,
                               transform=train_transform,
                               seed=args.seed,
                               train=True,
                               sequence_length=args.SEQ_LENGTH)
    # Validation set
    val_set = ValidationSet(args.data, transform=valid_transform)

    print('{} samples found in {} train scenes'.format(len(train_set),
                                                       len(train_set.scenes)))
    print('{} samples found in {} valid scenes'.format(len(val_set),
                                                       len(val_set.scenes)))

    train_loader = DataLoader(train_set,
                              batch_size=args.batch_size,
                              shuffle=True,
                              num_workers=args.workers,
                              pin_memory=True)
    val_loader = DataLoader(val_set,
                            batch_size=args.batch_size,
                            shuffle=False,
                            num_workers=args.workers,
                            pin_memory=True)

    # ========= step 5: build models =========
    print("=> creating models")

    depth_net = models.DepthNet().to(device)
    motion_net = models.MotionNet(intrinsic_pred=args.intri_pred).to(device)

    if args.pretrained_depth:
        print("=> using pre-trained weights for DepthNet")
        weights = torch.load(args.pretrained_depth)
        # strict=False: the checkpoint may not match every DepthNet key.
        depth_net.load_state_dict(weights['state_dict'], strict=False)

    if args.pretrained_motion:
        print("=> using pre-trained weights for MotionNet")
        weights = torch.load(args.pretrained_motion)
        motion_net.load_state_dict(weights['state_dict'])

    cudnn.benchmark = True
    depth_net = torch.nn.DataParallel(depth_net)
    motion_net = torch.nn.DataParallel(motion_net)

    # ========= step 6: set up the solver =========
    print('=> setting adam solver')

    optim_params = [{
        'params': depth_net.parameters(),
        'lr': args.lr
    }, {
        'params': motion_net.parameters(),
        'lr': args.lr
    }]

    optimizer = torch.optim.Adam(optim_params,
                                 betas=(args.momentum, args.beta),
                                 weight_decay=args.weight_decay)

    # ====== step 7: initialize the loss calculator ======
    total_loss_calculator = LossFactory(
        SEQ_LENGTH=args.SEQ_LENGTH,
        rgb_weight=args.rgb_weight,
        depth_smoothing_weight=args.depth_smoothing_weight,
        ssim_weight=args.ssim_weight,
        motion_smoothing_weight=args.motion_smoothing_weight,
        rotation_consistency_weight=args.rotation_consistency_weight,
        translation_consistency_weight=args.translation_consistency_weight,
        depth_consistency_loss_weight=args.depth_consistency_loss_weight)

    # ========= step 8: training loop =========
    if args.epoch_size == 0:
        # No epoch_size given: each epoch covers the whole training loader.
        args.epoch_size = len(train_loader)

    # TensorBoard tags for errors[0..6]. The order is assumed to be
    # abs_rel, sq_rel, rms, log_rms, a1, a2, a3.
    # NOTE(review): confirm against validate_with_gt.
    metric_tags = (
        "Relative Errors/abs_rel",
        "Relative Errors/sq_rel",
        "Root Mean Squared Error/rms",
        "Root Mean Squared Error/log_rms",
        "Thresholding accuracy/a1",
        "Thresholding accuracy/a2",
        "Thresholding accuracy/a3",
    )

    try:
        for epoch in range(args.epochs):
            tqdm.write("\n===========TRAIN EPOCH [{}/{}]===========".format(
                epoch + 1, args.epochs))

            # ====== step 8.1: train for one epoch ======
            train_loss = train(args, train_loader, depth_net, motion_net,
                               optimizer, args.epoch_size,
                               total_loss_calculator, tb_writer)
            tqdm.write('* Avg Loss : {:.3f}'.format(train_loss))

            # ======= step 8.2: validate ========
            errors, error_names = validate_with_gt(args, val_loader,
                                                   depth_net, motion_net,
                                                   epoch, tb_writer)
            error_string = ', '.join(
                '{} : {:.3f}'.format(name, error)
                for name, error in zip(error_names, errors))
            tqdm.write(error_string)
            # TODO: also report trajectory metrics on the validation set.

            for tag, value in zip(metric_tags, errors):
                tb_writer.add_scalar(tag, value, epoch)

            # ======= step 8.3: save checkpoints, tracking the best model ====
            decisive_a3 = errors[6]  # a3 is the model-selection metric
            # The current state is the best one if a3 improved.
            is_best = decisive_a3 > best_a3
            best_a3 = max(best_a3, decisive_a3)
            save_checkpoint(args.save_path, {
                'epoch': epoch + 1,
                'state_dict': depth_net.module.state_dict()
            }, {
                'epoch': epoch + 1,
                'state_dict': motion_net.module.state_dict()
            }, is_best)
    finally:
        # Flush pending TensorBoard events even if training is interrupted.
        tb_writer.close()