Example #1
    def __init__(self, args, cfg):
        super(Tester, self).__init__(args, cfg)

        args = self.args

        if self.batch_size != 1:
            self.logger.info(
                "batch size in the testing mode should be set to one.")
            self.logger.info("setting batch size (batch-size = 1).")
            self.batch_size = 1

        if self.seq_size != 1:
            self.logger.error("sequence size in test mode must be set to one.")
            raise ValueError("Sequence size must be equal to 1 in test mode.")

        # create the output folder
        self.checkpoint_dir = self.out_dir
        Path(self.checkpoint_dir).mkdir(parents=True, exist_ok=True)

        # prepare dataset and dataloaders
        transform = None

        self.model = nets.get_model(input_shape=(self.n_channels,
                                                 self.im_height_model,
                                                 self.im_width_model),
                                    cfg=self.cfg,
                                    device=self.device)
        self.criterion = get_loss_function(self.cfg, args.device)

        self.has_lidar = self.model.lidar_feat_net is not None
        self.has_imu = self.model.imu_feat_net is not None

        self.test_dataset = ds.Kitti(config=self.cfg,
                                     transform=transform,
                                     ds_type='test',
                                     has_imu=self.has_imu,
                                     has_lidar=self.has_lidar)

        self.test_dataloader = torch.utils.data.DataLoader(
            self.test_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=False,
            worker_init_fn=worker_init_fn,
            collate_fn=ds.deeplio_collate)

        self.data_permuter = DataCombiCreater(combinations=self.combinations,
                                              device=self.device)

        self.tensor_writer = tensorboard.SummaryWriter(log_dir=self.runs_dir)

        # debugging and visualizing
        self.logger.print("System Training Configurations:")
        self.logger.print("args: {}".format(self.args))

        self.logger.print(yaml.dump(self.cfg))
        self.logger.print(self.test_dataset)
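
The DataLoader above is given a worker_init_fn that is not defined anywhere in these snippets. A minimal sketch of a common implementation (an assumption, not the project's actual code) seeds each worker's NumPy RNG from the PyTorch worker seed so per-worker data processing stays reproducible:

import numpy as np
import torch

def worker_init_fn(worker_id):
    # PyTorch gives every DataLoader worker a distinct seed (base_seed + worker_id);
    # torch.initial_seed() returns it inside the worker, so propagate it to NumPy.
    np.random.seed(torch.initial_seed() % 2 ** 32)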
Example #2
    def __init__(self, args, cfg):
        super(TestKittiGt, self).__init__(args, cfg)
        dataset = ds.Kitti(config=self.cfg, transform=None)
        self.train_dataloader = torch.utils.data.DataLoader(
            dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=False,
            worker_init_fn=worker_init_fn,
            collate_fn=ds.deeplio_collate)

        self.data_permuter = DataCombiCreater(combinations=self.combinations,
                                              device=self.device)
Example #3
    def __init__(self, args, cfg):
        super(TestTraj, self).__init__(args, cfg)
        args = self.args

        if self.seq_size > 1:
            self.logger.info("sequence size in the testing mode should be set to two.")
            #raise ValueError("sequence size in the testing mode should be set to two.")

        if self.batch_size != 1:
            self.logger.info("batch size in the testing mode should be set to one.")
            self.logger.info("setting batch size (batch-size = 1).")
            self.batch_size = 1

        # create the output folder
        self.checkpoint_dir = self.out_dir
        Path(self.checkpoint_dir).mkdir(parents=True, exist_ok=True)

        # prepare dataset and dataloaders
        transform = None

        self.test_dataset = ds.Kitti(config=self.cfg, transform=transform, ds_type='test')
        self.test_dataloader = torch.utils.data.DataLoader(self.test_dataset, batch_size=self.batch_size,
                                                           num_workers=self.num_workers,
                                                           shuffle=False,
                                                           worker_init_fn=worker_init_fn,
                                                           collate_fn=ds.deeplio_collate)

        self.data_permuter = DataCombiCreater(combinations=self.combinations,
                                              device=self.device)

        # debugging and visualizing
        self.logger.print("System Training Configurations:")
        self.logger.print("args: {}".
                          format(self.args))

        self.logger.print(yaml.dump(self.cfg))
        self.logger.print(self.test_dataset)
Example #4
    def __init__(self, args, cfg):
        super(Trainer, self).__init__(args, cfg)

        if self.args.resume and self.args.evaluate:
            raise ValueError("Error: can either resume training or evaluate, not both at the same time!")

        args = self.args

        self.start_epoch = args.start_epoch
        self.last_epoch = self.start_epoch
        self.epochs = args.epochs
        self.best_acc = float('inf')
        self.step_val = 0.
        self.max_glob_seq = 2

        # create the folder for saving training checkpoints
        self.checkpoint_dir = self.out_dir
        Path(self.checkpoint_dir).mkdir(parents=True, exist_ok=True)

        # prepare dataset and dataloaders
        transform = None
        self.data_last = None

        self.model = nets.get_model(input_shape=(self.n_channels, self.im_height_model, self.im_width_model),
                                    cfg=self.cfg, device=self.device)

        self.criterion = get_loss_function(self.cfg, args.device)
        self.optimizer = create_optimizer([{'params': self.model.parameters()},
                                           {'params': self.criterion.parameters()}],
                                          self.cfg, args)

        self.has_lidar = self.model.lidar_feat_net is not None
        self.has_imu = self.model.imu_feat_net is not None

        self.train_dataset = ds.Kitti(config=self.cfg, transform=transform,
                                      has_imu=self.has_imu, has_lidar=self.has_lidar)
        self.train_dataloader = torch.utils.data.DataLoader(self.train_dataset, batch_size=self.batch_size,
                                                            num_workers=self.num_workers,
                                                            shuffle=True,
                                                            worker_init_fn=worker_init_fn,
                                                            collate_fn=ds.deeplio_collate,
                                                            drop_last=True)

        self.val_dataset = ds.Kitti(config=self.cfg, transform=transform, ds_type='validation',
                                    has_imu=self.has_imu, has_lidar=self.has_lidar)
        self.val_dataloader = torch.utils.data.DataLoader(self.val_dataset, batch_size=self.batch_size,
                                                          num_workers=self.num_workers,
                                                          shuffle=True,
                                                          worker_init_fn=worker_init_fn,
                                                          collate_fn=ds.deeplio_collate,
                                                          drop_last=True)

        self.data_permuter = DataCombiCreater(combinations=self.combinations,
                                              device=self.device)

        # debugging and visualizing
        self.logger.print("System Training Configurations:")
        self.logger.print("args: {}".
                          format(self.args))

        # optionally resume from a checkpoint
        last_epoch = -1
        if args.resume or args.evaluate:
            model_cfg = self.cfg['deeplio']
            pretrained = self.model.pretrained
            if not pretrained:
                self.logger.error("no model checkpoint loaded!")
                raise ValueError("no model checkpoint loaded!")

            ckp_path = model_cfg['model-path']
            if not os.path.isfile(ckp_path):
                self.logger.error("no checkpoint found at '{}'".format(ckp_path))
                raise ValueError("no checkpoint found at '{}'".format(ckp_path))

            self.logger.info("loading from checkpoint '{}'".format(ckp_path))
            checkpoint = torch.load(ckp_path, map_location=self.device)
            self.start_epoch = checkpoint['epoch'] + 1
            last_epoch = checkpoint['epoch']
            self.best_acc = checkpoint['best_acc']
            self.optimizer.load_state_dict(checkpoint['optimizer'])
            if isinstance(self.criterion, HWSLoss):
                self.criterion.load_state_dict(checkpoint['criterion'])
            self.logger.info("loaded checkpoint '{}' (epoch {})".format(ckp_path, checkpoint['epoch']))

        self.lr_scheduler = PolynomialLRDecay(self.optimizer, max_decay_steps=self.epochs,
                                              end_learning_rate=1e-6, power=2.0, last_epoch=last_epoch)
        #self.lr_scheduler = torch.optim.lr_scheduler.StepLR(self.optimizer, step_size=3, gamma=0.7, last_epoch=-1)

        self.logger.print(yaml.dump(self.cfg))
        self.logger.print(self.train_dataset)
        self.logger.print(self.val_dataset)
        self.logger.info(self.criterion)
        self.post_init()
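
The resume branch above reads the keys 'epoch', 'best_acc', 'optimizer', and (for HWSLoss) 'criterion' from the checkpoint file, but none of the examples show the writing side. A minimal counterpart sketch, assuming a free-standing helper and a 'state_dict' key for the model weights (both assumptions, not shown in the source):

import os
import torch

def save_checkpoint(trainer, epoch, filename='checkpoint.pth.tar'):
    # Mirror the keys the resume branch expects to load.
    state = {
        'epoch': epoch,
        'best_acc': trainer.best_acc,
        'state_dict': trainer.model.state_dict(),  # assumed key, never read above
        'optimizer': trainer.optimizer.state_dict(),
    }
    if isinstance(trainer.criterion, HWSLoss):  # HWSLoss as referenced in the snippet
        state['criterion'] = trainer.criterion.state_dict()
    torch.save(state, os.path.join(trainer.checkpoint_dir, filename))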