class SegmentionTrain(BaseTrain):
    """Training driver for the segmentation task.

    Wires together the task configuration, the training logger, the model,
    the optimizer helper and a ``SegmentionTest`` evaluator, and runs the
    epoch loop in :meth:`train`.
    """

    def __init__(self, cfg_path, gpu_id, config_path=None):
        """Create the model and every helper object used during training.

        Args:
            cfg_path: Model configuration, forwarded to ``initModel`` and to
                the ``SegmentionTest`` evaluator.
            gpu_id: Device selector, forwarded together with ``cfg_path``.
            config_path: Optional task-config path, forwarded to the base
                class and to the evaluator.
        """
        super().__init__(config_path)
        self.set_task_name(TaskName.Segment_Task)
        self.train_task_config = self.config_factory.get_config(
            self.task_name, self.config_path)

        self.train_logger = TrainLogger(self.train_task_config.log_name,
                                        self.train_task_config.root_save_dir)

        self.torchModelProcess = TorchModelProcess()
        self.freeze_normalization = TorchFreezeNormalization()
        self.torchOptimizer = TorchOptimizer(
            self.train_task_config.optimizer_config)
        self.model = self.torchModelProcess.initModel(cfg_path, gpu_id)
        self.device = self.torchModelProcess.getDevice()

        self.output_process = SegmentResultProcess()

        self.segment_test = SegmentionTest(cfg_path, gpu_id, config_path)

        # Number of batches per epoch; filled in by train().
        self.total_images = 0
        # Created by load_latest_param().
        self.optimizer = None
        self.start_epoch = 0
        self.bestmIoU = 0

    def load_pretrain_model(self, weights_path):
        """Load pretrained weights from ``weights_path`` into the model."""
        self.torchModelProcess.loadPretainModel(weights_path, self.model)

    def load_latest_param(self, latest_weights_path):
        """Resume from the latest checkpoint (if any) and build the optimizer.

        Args:
            latest_weights_path: Checkpoint file. When it is empty or does
                not exist, training starts without a checkpoint.
        """
        checkpoint = None
        if latest_weights_path and os.path.exists(latest_weights_path):
            checkpoint = self.torchModelProcess.loadLatestModelWeight(
                latest_weights_path, self.model)
        # Runs whether or not a checkpoint was found. The return value is
        # discarded, exactly as before.
        # NOTE(review): confirm modelTrainInit prepares the model in place.
        self.torchModelProcess.modelTrainInit(self.model)
        self.start_epoch, self.bestmIoU = \
            self.torchModelProcess.getLatestModelValue(checkpoint)

        self.torchOptimizer.freeze_optimizer_layer(
            self.start_epoch, self.train_task_config.base_lr, self.model,
            self.train_task_config.freeze_layer_name,
            self.train_task_config.freeze_layer_type)
        self.torchOptimizer.print_freeze_layer(self.model)
        self.optimizer = self.torchOptimizer.getLatestModelOptimizer(
            checkpoint)

    def train(self, train_path, val_path):
        """Run the training loop from ``start_epoch`` to ``max_epochs``.

        After every epoch the model is saved and, when ``val_path`` exists,
        evaluated. The training logger is closed when the loop ends, whether
        it finished normally or raised.

        Args:
            train_path: Training data location for the segment dataloader.
            val_path: Validation data location; may be ``None``.
        """
        dataloader = get_segment_train_dataloader(
            train_path,
            self.train_task_config.image_size,
            self.train_task_config.train_batch_size,
            is_augment=self.train_task_config.train_data_augment)
        self.total_images = len(dataloader)
        self.load_latest_param(self.train_task_config.latest_weights_file)

        lr_factory = LrSchedulerFactory(self.train_task_config.base_lr,
                                        self.train_task_config.max_epochs,
                                        self.total_images)
        lr_scheduler = lr_factory.get_lr_scheduler(
            self.train_task_config.lr_scheduler_config)

        self.train_task_config.save_config()
        self.timer.tic()
        self.set_model_train()
        try:
            for epoch in range(self.start_epoch,
                               self.train_task_config.max_epochs):
                self.optimizer.zero_grad()
                for idx, (images, segments) in enumerate(dataloader):
                    # Global iteration counter across epochs.
                    current_idx = epoch * self.total_images + idx
                    lr = lr_scheduler.get_lr(epoch, current_idx)
                    lr_scheduler.adjust_learning_rate(self.optimizer, lr)
                    loss = self.compute_backward(images, segments, idx)
                    self.update_logger(idx, self.total_images, epoch, loss)

                save_model_path = self.save_train_model(epoch)
                self.test(val_path, epoch, save_model_path)
        finally:
            # Release the logger even when training is interrupted.
            self.train_logger.close()

    def compute_backward(self, input_datas, targets, setp_index):
        """Forward one batch, back-propagate and step the optimizer if due.

        Args:
            input_datas: Input batch; moved to ``self.device`` here.
            targets: Ground truth for the batch.
            setp_index: Zero-based batch index within the current epoch.

        Returns:
            The loss computed for this batch.
        """
        output_list = self.model(input_datas.to(self.device))
        loss = self.compute_loss(output_list, targets)
        loss.backward()

        # Gradients accumulate over `accumulated_batches` batches; the final
        # batch of the epoch always flushes whatever has accumulated.
        is_accumulation_boundary = \
            (setp_index + 1) % self.train_task_config.accumulated_batches == 0
        is_last_batch = setp_index == self.total_images - 1
        if is_accumulation_boundary or is_last_batch:
            self.optimizer.step()
            self.optimizer.zero_grad()
        return loss

    def compute_loss(self, output_list, targets):
        """Combine the model outputs with ``self.model.lossList``.

        Supported pairings: one loss / one output, one loss / several outputs
        (the loss receives the whole list), or equally many losses and
        outputs (losses are summed pairwise).

        Args:
            output_list: Sequence of model outputs.
            targets: Ground truth; moved to ``self.device`` here.

        Returns:
            The total loss.

        Raises:
            ValueError: If the number of losses and outputs cannot be paired.
                Previously this case printed a message and returned ``0``,
                which then failed at ``loss.backward()``.
        """
        loss_functions = self.model.lossList
        loss_count = len(loss_functions)
        output_count = len(output_list)
        targets = targets.to(self.device)
        resize = self.output_process.output_feature_map_resize

        if loss_count == 1 and output_count == 1:
            output, target = resize(output_list[0], targets)
            return loss_functions[0](output, target)
        if loss_count == 1 and output_count > 1:
            return loss_functions[0](output_list, targets)
        if loss_count > 1 and loss_count == output_count:
            loss = 0
            for k in range(loss_count):
                output, target = resize(output_list[k], targets)
                loss += loss_functions[k](output, target)
            return loss
        raise ValueError(
            "compute loss error: %d loss function(s) cannot be paired "
            "with %d model output(s)" % (loss_count, output_count))

    def update_logger(self, index, total, epoch, loss):
        """Log loss and learning rate for one step and print a progress line.

        Args:
            index: Batch index within the epoch.
            total: Number of batches per epoch.
            epoch: Current epoch.
            loss: Loss returned by :meth:`compute_backward`.
        """
        loss_value = loss.data.cpu().squeeze()
        step = epoch * total + index
        lr = self.optimizer.param_groups[0]['lr']
        self.train_logger.train_log(step, loss_value,
                                    self.train_task_config.display)
        self.train_logger.lr_log(step, lr, self.train_task_config.display)

        print('Epoch: {}[{}/{}]\t Loss: {}\t Rate: {} \t Time: {}\t'.format(
            epoch, index, total, '%.7f' % loss_value, '%.7f' % lr,
            '%.5f' % self.timer.toc(True)))

    def save_train_model(self, epoch):
        """Save the model for ``epoch`` and return the path that was written.

        Writes a per-epoch snapshot when ``is_save_epoch_model`` is set,
        otherwise overwrites the latest-weights file.
        """
        self.train_logger.epoch_train_log(epoch)
        if self.train_task_config.is_save_epoch_model:
            save_model_path = os.path.join(
                self.train_task_config.snapshot_path,
                "seg_model_epoch_%d.pt" % epoch)
        else:
            save_model_path = self.train_task_config.latest_weights_file
        self.torchModelProcess.saveLatestModel(save_model_path, self.model,
                                               self.optimizer, epoch,
                                               self.bestmIoU)
        return save_model_path

    def set_model_train(self):
        """Switch the model to train mode, then freeze configured norm layers."""
        self.model.train()
        self.freeze_normalization.freeze_normalization_layer(
            self.model, self.train_task_config.freeze_bn_layer_name,
            self.train_task_config.freeze_bn_type)

    def test(self, val_path, epoch, save_model_path):
        """Evaluate the saved model and update the best-mIoU bookkeeping.

        Args:
            val_path: Validation data location; evaluation is skipped when it
                is ``None`` or does not exist.
            epoch: Epoch that has just finished.
            save_model_path: Weights file written by :meth:`save_train_model`.
        """
        if val_path is not None and os.path.exists(val_path):
            self.segment_test.load_weights(save_model_path)
            score, class_score, average_loss = self.segment_test.test(val_path)
            self.segment_test.save_test_value(epoch, score, class_score)

            self.train_logger.eval_log("val epoch loss", epoch, average_loss)
            print("Val epoch loss: {}".format(average_loss))
            # save best model
            self.bestmIoU = self.torchModelProcess.saveBestModel(
                score['Mean IoU : \t'], save_model_path,
                self.train_task_config.best_weights_file)
        else:
            print("no test!")
class PointCloudClassifyTrain(BaseTrain):
    """Training driver for the point-cloud classification task.

    Owns the model, optimizer helper, logger and a ``PointCloudClassifyTest``
    evaluator, and runs the epoch loop in :meth:`train`.
    """

    def __init__(self, cfg_path, gpu_id, config_path=None):
        """Create the model and every helper object used during training.

        Args:
            cfg_path: Model configuration, forwarded to ``initModel`` and to
                the ``PointCloudClassifyTest`` evaluator.
            gpu_id: Device selector, forwarded together with ``cfg_path``.
            config_path: Optional task-config path, forwarded to the base
                class and to the evaluator.
        """
        super().__init__(config_path)
        self.set_task_name(TaskName.PC_Classify_Task)
        self.train_task_config = self.config_factory.get_config(
            self.task_name, self.config_path)

        self.train_logger = TrainLogger(self.train_task_config.log_name,
                                        self.train_task_config.root_save_dir)
        self.torchModelProcess = TorchModelProcess()
        self.freeze_normalization = TorchFreezeNormalization()
        self.torchOptimizer = TorchOptimizer(
            self.train_task_config.optimizer_config)

        self.model = self.torchModelProcess.initModel(cfg_path, gpu_id)
        self.device = self.torchModelProcess.getDevice()

        self.classify_test = PointCloudClassifyTest(cfg_path, gpu_id,
                                                    config_path)

        # Number of batches per epoch; filled in by train().
        self.total_clouds = 0
        self.start_epoch = 0
        self.best_precision = 0
        # Created by load_latest_param().
        self.optimizer = None

    def load_pretrain_model(self, weights_path):
        """Load pretrained weights from ``weights_path`` into the model."""
        self.torchModelProcess.loadPretainModel(weights_path, self.model)

    def load_latest_param(self, latest_weights_path):
        """Resume from the latest checkpoint (if any) and build the optimizer.

        Args:
            latest_weights_path: Checkpoint file. When it is ``None`` or does
                not exist, training starts without a checkpoint.
        """
        checkpoint = None
        if latest_weights_path is not None and os.path.exists(
                latest_weights_path):
            checkpoint = self.torchModelProcess.loadLatestModelWeight(
                latest_weights_path, self.model)
        # Runs whether or not a checkpoint was found.
        self.model = self.torchModelProcess.modelTrainInit(self.model)

        self.start_epoch, self.best_precision = \
            self.torchModelProcess.getLatestModelValue(checkpoint)

        self.torchOptimizer.freeze_optimizer_layer(
            self.start_epoch, self.train_task_config.base_lr, self.model,
            self.train_task_config.freeze_layer_name,
            self.train_task_config.freeze_layer_type)
        self.optimizer = self.torchOptimizer.getLatestModelOptimizer(
            checkpoint)

    def train(self, train_path, val_path):
        """Run the training loop from ``start_epoch`` to ``max_epochs``.

        After every epoch the model is saved and, when ``val_path`` exists,
        evaluated. The training logger is always closed at the end.

        Args:
            train_path: Training data location for the classify dataloader.
            val_path: Validation data location; may be ``None``.
        """
        dataloader = get_classify_train_dataloader(
            train_path, self.train_task_config.number_point_features,
            self.train_task_config.train_batch_size,
            self.train_task_config.train_data_augment)

        self.total_clouds = len(dataloader)

        lr_factory = LrSchedulerFactory(self.train_task_config.base_lr,
                                        self.train_task_config.max_epochs,
                                        self.total_clouds)
        lr_scheduler = lr_factory.get_lr_scheduler(
            self.train_task_config.lr_scheduler_config)

        self.load_latest_param(self.train_task_config.latest_weights_file)

        self.train_task_config.save_config()
        self.timer.tic()
        self.model.train()
        self.freeze_normalization.freeze_normalization_layer(
            self.model, self.train_task_config.freeze_bn_layer_name,
            self.train_task_config.freeze_bn_type)
        # Exceptions propagate unchanged; `finally` only guarantees that the
        # logger is closed.
        try:
            for epoch in range(self.start_epoch,
                               self.train_task_config.max_epochs):
                self.optimizer.zero_grad()
                for idx, (clouds, targets) in enumerate(dataloader):
                    # Global iteration counter across epochs.
                    current_iter = epoch * self.total_clouds + idx
                    lr = lr_scheduler.get_lr(epoch, current_iter)
                    lr_scheduler.adjust_learning_rate(self.optimizer, lr)
                    loss = self.compute_backward(clouds, targets, idx)
                    self.update_logger(idx, self.total_clouds, epoch, loss)

                save_model_path = self.save_train_model(epoch)
                self.test(val_path, epoch, save_model_path)
        finally:
            self.train_logger.close()

    def compute_backward(self, input_datas, targets, setp_index):
        """Forward one batch, back-propagate and step the optimizer if due.

        Args:
            input_datas: Input batch; moved to ``self.device`` here.
            targets: Ground truth for the batch.
            setp_index: Zero-based batch index within the current epoch.

        Returns:
            The loss computed for this batch.
        """
        output_list = self.model(input_datas.to(self.device))
        loss = self.compute_loss(output_list, targets)
        loss.backward()
        # Gradients accumulate over `accumulated_batches` batches; the final
        # batch of the epoch always flushes whatever has accumulated.
        is_accumulation_boundary = \
            (setp_index + 1) % self.train_task_config.accumulated_batches == 0
        is_last_batch = setp_index == self.total_clouds - 1
        if is_accumulation_boundary or is_last_batch:
            self.optimizer.step()
            self.optimizer.zero_grad()
        return loss

    def compute_loss(self, output_list, targets):
        """Combine the model outputs with ``self.model.lossList``.

        Supported pairings: one loss / one output, one loss / several outputs
        (the loss receives the whole list), or equally many losses and
        outputs (losses are summed pairwise).

        Args:
            output_list: Sequence of model outputs.
            targets: Ground truth; moved to ``self.device`` here.

        Returns:
            The total loss.

        Raises:
            ValueError: If the number of losses and outputs cannot be paired.
                Previously this case printed a message and returned ``0``,
                which then failed at ``loss.backward()``.
        """
        loss_functions = self.model.lossList
        loss_count = len(loss_functions)
        output_count = len(output_list)
        targets = targets.to(self.device)

        if loss_count == 1 and output_count == 1:
            return loss_functions[0](output_list[0], targets)
        if loss_count == 1 and output_count > 1:
            return loss_functions[0](output_list, targets)
        if loss_count > 1 and loss_count == output_count:
            loss = 0
            for k in range(loss_count):
                loss += loss_functions[k](output_list[k], targets)
            return loss
        raise ValueError(
            "compute loss error: %d loss function(s) cannot be paired "
            "with %d model output(s)" % (loss_count, output_count))

    def update_logger(self, index, total, epoch, loss):
        """Log loss and learning rate for one step and print a progress line.

        Args:
            index: Batch index within the epoch.
            total: Number of batches per epoch.
            epoch: Current epoch.
            loss: Loss returned by :meth:`compute_backward`.
        """
        step = epoch * total + index
        lr = self.optimizer.param_groups[0]['lr']
        loss_value = loss.data.cpu().squeeze()
        self.train_logger.train_log(step, loss_value,
                                    self.train_task_config.display)
        self.train_logger.lr_log(step, lr, self.train_task_config.display)

        print('Epoch: {}[{}/{}]\t Loss: {}\t Rate: {} \t Time: {}\t'.format(
            epoch, index, total, '%.7f' % loss_value, '%.7f' % lr,
            '%.5f' % self.timer.toc(True)))

    def save_train_model(self, epoch):
        """Save the model for ``epoch`` and return the path that was written.

        Writes a per-epoch snapshot when ``is_save_epoch_model`` is set,
        otherwise overwrites the latest-weights file. The whole save runs
        inside ``DelayedKeyboardInterrupt``.
        """
        with DelayedKeyboardInterrupt():
            self.train_logger.epoch_train_log(epoch)
            if self.train_task_config.is_save_epoch_model:
                save_model_path = os.path.join(
                    self.train_task_config.snapshot_path,
                    "pc_cls_model_epoch_%d.pt" % epoch)
            else:
                save_model_path = self.train_task_config.latest_weights_file
            self.torchModelProcess.saveLatestModel(save_model_path, self.model,
                                                   self.optimizer, epoch,
                                                   self.best_precision)
        return save_model_path

    def test(self, val_path, epoch, save_model_path):
        """Evaluate the saved model and update the best-precision bookkeeping.

        Args:
            val_path: Validation data location; evaluation is skipped when it
                is ``None`` or does not exist.
            epoch: Epoch that has just finished.
            save_model_path: Weights file written by :meth:`save_train_model`.
        """
        if val_path is not None and os.path.exists(val_path):
            self.classify_test.load_weights(save_model_path)
            precision = self.classify_test.test(val_path)
            self.classify_test.save_test_value(epoch)

            # save best model
            self.best_precision = self.torchModelProcess.saveBestModel(
                precision, save_model_path,
                self.train_task_config.best_weights_file)
        else:
            print("no test!")
# Example #3
# 0
class Detection2dTrain(BaseTrain):
    """Training driver for the 2D detection task.

    Owns the model, optimizer helper, logger and a ``Detection2dTest``
    evaluator, and runs the epoch loop in :meth:`train`.
    """

    def __init__(self, cfg_path, gpu_id, config_path=None):
        """Create the model and every helper object used during training.

        Args:
            cfg_path: Model configuration, forwarded to ``initModel`` and to
                the ``Detection2dTest`` evaluator.
            gpu_id: Device selector, forwarded together with ``cfg_path``.
            config_path: Optional task-config path, forwarded to the base
                class and to the evaluator.
        """
        super().__init__(config_path)
        self.set_task_name(TaskName.Detect2d_Task)
        self.train_task_config = self.config_factory.get_config(
            self.task_name, self.config_path)

        self.train_logger = TrainLogger(self.train_task_config.log_name,
                                        self.train_task_config.root_save_dir)

        self.torchModelProcess = TorchModelProcess()
        self.freeze_normalization = TorchFreezeNormalization()
        self.torchOptimizer = TorchOptimizer(
            self.train_task_config.optimizer_config)

        self.model = self.torchModelProcess.initModel(cfg_path, gpu_id)
        self.device = self.torchModelProcess.getDevice()

        self.detect_test = Detection2dTest(cfg_path, gpu_id, config_path)

        # Number of batches per epoch; filled in by train().
        self.total_images = 0
        # Created by load_latest_param().
        self.optimizer = None
        # Smoothed per-sample loss; a negative value means "not seeded yet".
        self.avg_loss = -1
        self.start_epoch = 0
        self.best_mAP = 0

    def load_pretrain_model(self, weights_path):
        """Load pretrained weights from ``weights_path`` into the model."""
        self.torchModelProcess.loadPretainModel(weights_path, self.model)

    def load_latest_param(self, latest_weights_path):
        """Resume from the latest checkpoint (if any) and build the optimizer.

        Args:
            latest_weights_path: Checkpoint file. When it is empty or does
                not exist, training starts without a checkpoint.
        """
        checkpoint = None
        if latest_weights_path and os.path.exists(latest_weights_path):
            checkpoint = self.torchModelProcess.loadLatestModelWeight(
                latest_weights_path, self.model)
        # Runs whether or not a checkpoint was found.
        self.model = self.torchModelProcess.modelTrainInit(self.model)
        self.start_epoch, self.best_mAP = \
            self.torchModelProcess.getLatestModelValue(checkpoint)

        self.torchOptimizer.freeze_optimizer_layer(
            self.start_epoch, self.train_task_config.base_lr, self.model,
            self.train_task_config.freeze_layer_name,
            self.train_task_config.freeze_layer_type)
        self.optimizer = self.torchOptimizer.getLatestModelOptimizer(
            checkpoint)

    def train(self, train_path, val_path):
        """Run the training loop from ``start_epoch`` to ``max_epochs``.

        Batches that contain no targets are skipped. After every epoch the
        model is saved and, when ``val_path`` exists, evaluated. The training
        logger is closed when the loop ends, whether it finished normally or
        raised.

        Args:
            train_path: Training data location for the detection dataloader.
            val_path: Validation data location; may be ``None``.
        """
        dataloader = DetectionTrainDataloader(
            train_path,
            self.train_task_config.class_name,
            self.train_task_config.train_batch_size,
            self.train_task_config.image_size,
            multi_scale=self.train_task_config.train_multi_scale,
            is_augment=self.train_task_config.train_data_augment,
            balanced_sample=self.train_task_config.balanced_sample)
        self.total_images = len(dataloader)

        lr_factory = LrSchedulerFactory(self.train_task_config.base_lr,
                                        self.train_task_config.max_epochs,
                                        self.total_images)
        lr_scheduler = lr_factory.get_lr_scheduler(
            self.train_task_config.lr_scheduler_config)

        self.load_latest_param(self.train_task_config.latest_weights_file)

        self.train_task_config.save_config()
        self.timer.tic()
        self.model.train()
        self.freeze_normalization.freeze_normalization_layer(
            self.model, self.train_task_config.freeze_bn_layer_name,
            self.train_task_config.freeze_bn_type)
        try:
            for epoch in range(self.start_epoch,
                               self.train_task_config.max_epochs):
                self.optimizer.zero_grad()
                for i, (images, targets) in enumerate(dataloader):
                    # Global iteration counter across epochs.
                    current_iter = epoch * self.total_images + i
                    lr = lr_scheduler.get_lr(epoch, current_iter)
                    lr_scheduler.adjust_learning_rate(self.optimizer, lr)
                    # NOTE(review): if the skipped batch is the last one of
                    # the epoch, gradients accumulated so far are never
                    # applied (the next epoch starts with zero_grad()).
                    if sum(len(x) for x in targets) < 1:
                        continue
                    loss = self.compute_backward(images, targets, i)
                    self.update_logger(i, self.total_images, epoch, loss)

                save_model_path = self.save_train_model(epoch)
                self.test(val_path, epoch, save_model_path)
        finally:
            # Release the logger even when training is interrupted.
            self.train_logger.close()

    def compute_backward(self, input_datas, targets, setp_index):
        """Forward one batch, back-propagate and step the optimizer if due.

        Args:
            input_datas: Input batch; moved to ``self.device`` here.
            targets: Ground truth for the batch.
            setp_index: Zero-based batch index within the current epoch.

        Returns:
            The loss computed for this batch.
        """
        output_list = self.model(input_datas.to(self.device))
        loss = self.compute_loss(output_list, targets)
        loss.backward()

        # Gradients accumulate over `accumulated_batches` batches; the final
        # batch of the epoch always flushes whatever has accumulated.
        is_accumulation_boundary = \
            (setp_index + 1) % self.train_task_config.accumulated_batches == 0
        is_last_batch = setp_index == self.total_images - 1
        if is_accumulation_boundary or is_last_batch:
            self.optimizer.step()
            self.optimizer.zero_grad()
        return loss

    def compute_loss(self, output_list, targets):
        """Combine the model outputs with ``self.model.lossList``.

        Supported pairings: one loss / one output, one loss / several outputs
        (the loss receives the whole list), or equally many losses and
        outputs (losses are summed pairwise). Targets are passed to the loss
        functions as given.

        Args:
            output_list: Sequence of model outputs.
            targets: Ground truth for the batch.

        Returns:
            The total loss.

        Raises:
            ValueError: If the number of losses and outputs cannot be paired.
                Previously this case printed a message and returned ``0``,
                which then failed at ``loss.backward()``.
        """
        loss_functions = self.model.lossList
        loss_count = len(loss_functions)
        output_count = len(output_list)

        if loss_count == 1 and output_count == 1:
            return loss_functions[0](output_list[0], targets)
        if loss_count == 1 and output_count > 1:
            return loss_functions[0](output_list, targets)
        if loss_count > 1 and loss_count == output_count:
            loss = 0
            for k in range(loss_count):
                loss += loss_functions[k](output_list[k], targets)
            return loss
        raise ValueError(
            "compute loss error: %d loss function(s) cannot be paired "
            "with %d model output(s)" % (loss_count, output_count))

    def update_logger(self, index, total, epoch, loss):
        """Log the step, update the smoothed loss and print a progress line.

        The raw loss goes to the logger; the printed value is ``avg_loss``,
        a blend of 0.9 x the current per-sample loss and 0.1 x the previous
        average.

        Args:
            index: Batch index within the epoch.
            total: Number of batches per epoch.
            epoch: Current epoch.
            loss: Loss returned by :meth:`compute_backward`.
        """
        step = epoch * total + index
        lr = self.optimizer.param_groups[0]['lr']
        loss_value = loss.data.cpu().squeeze()

        # Loss divided by the configured batch size, computed once.
        batch_loss = (loss.cpu().detach().numpy() /
                      self.train_task_config.train_batch_size)
        if self.avg_loss < 0:
            # First logged step: seed the average with the current value.
            self.avg_loss = batch_loss
        self.avg_loss = 0.9 * batch_loss + 0.1 * self.avg_loss

        self.train_logger.train_log(step, loss_value,
                                    self.train_task_config.display)
        self.train_logger.lr_log(step, lr, self.train_task_config.display)
        print('Epoch: {}[{}/{}]\t Loss: {}\t Rate: {} \t Time: {}\t'.format(
            epoch, index, total, '%.7f' % self.avg_loss, '%.7f' % lr,
            '%.5f' % self.timer.toc(True)))

    def save_train_model(self, epoch):
        """Save the model for ``epoch`` and return the path that was written.

        Writes a per-epoch snapshot when ``is_save_epoch_model`` is set,
        otherwise overwrites the latest-weights file.
        """
        self.train_logger.epoch_train_log(epoch)
        if self.train_task_config.is_save_epoch_model:
            save_model_path = os.path.join(
                self.train_task_config.snapshot_path,
                "det2d_model_epoch_%d.pt" % epoch)
        else:
            save_model_path = self.train_task_config.latest_weights_file
        self.torchModelProcess.saveLatestModel(save_model_path, self.model,
                                               self.optimizer, epoch,
                                               self.best_mAP)
        return save_model_path

    def test(self, val_path, epoch, save_model_path):
        """Evaluate the saved model and update the best-mAP bookkeeping.

        Args:
            val_path: Validation data location; evaluation is skipped when it
                is ``None`` or does not exist.
            epoch: Epoch that has just finished.
            save_model_path: Weights file written by :meth:`save_train_model`.
        """
        if val_path is not None and os.path.exists(val_path):
            self.detect_test.load_weights(save_model_path)
            mAP, aps = self.detect_test.test(val_path)
            self.detect_test.save_test_value(epoch, mAP, aps)
            # save best model
            self.best_mAP = self.torchModelProcess.saveBestModel(
                mAP, save_model_path, self.train_task_config.best_weights_file)
        else:
            print("no test!")
# Example #4
# 0
class SuperResolutionTrain():
    """Training driver for the MSRResNet super-resolution model.

    All settings are read from the module-level ``super_resolution_config``.
    """

    def __init__(self):
        """Create the model, optimizer helper, LR schedule and transforms."""
        if not os.path.exists(super_resolution_config.snapshotPath):
            os.makedirs(super_resolution_config.snapshotPath, exist_ok=True)

        self.torchModelProcess = TorchModelProcess()
        self.torchOptimizer = TorchOptimizer(super_resolution_config.optimizerConfig)
        self.multiLR = MultiStageLR(super_resolution_config.base_lr,
                                    [[50, 1], [70, 0.1], [100, 0.01]])
        self.device = self.torchModelProcess.getDevice()
        self.model = MSRResNet(
            super_resolution_config.in_nc,
            upscale_factor=super_resolution_config.upscale_factor).to(self.device)

        vision_process = TorchVisionProcess()
        self.input_transform = vision_process.input_transform(
            super_resolution_config.crop_size,
            super_resolution_config.upscale_factor)
        self.target_transform = vision_process.target_transform(
            super_resolution_config.crop_size)

        self.start_epoch = 0
        # Average PSNR of the latest evaluation; also stored in checkpoints.
        self.psnr = 0
        # Created by load_param().
        self.optimizer = None

    def load_param(self, latest_weights_path):
        """Resume from the latest checkpoint (if any) and build the optimizer.

        Args:
            latest_weights_path: Checkpoint file. When it is empty or does
                not exist, training starts without a checkpoint.
        """
        checkpoint = None
        if latest_weights_path and os.path.exists(latest_weights_path):
            checkpoint = self.torchModelProcess.loadLatestModelWeight(
                latest_weights_path, self.model)
        # Runs whether or not a checkpoint was found. The return value is
        # discarded, exactly as before.
        # NOTE(review): confirm modelTrainInit prepares the model in place.
        self.torchModelProcess.modelTrainInit(self.model)
        # The metric slot of a checkpoint holds the PSNR written by test();
        # it used to be restored into `best_mAP` only, leaving `psnr` at 0.
        self.start_epoch, self.psnr = \
            self.torchModelProcess.getLatestModelValue(checkpoint)
        # Alias kept for code that still reads the old attribute name.
        self.best_mAP = self.psnr

        self.torchOptimizer.createOptimizer(self.start_epoch, self.model,
                                            super_resolution_config.base_lr)
        self.optimizer = self.torchOptimizer.getLatestModelOptimizer(checkpoint)

    def update_logger(self, step, value):
        """Logging hook; intentionally a no-op."""
        pass

    def compute_loss(self, output, targets):
        """Return the mean-squared error between ``output`` and ``targets``."""
        criterion = nn.MSELoss()
        return criterion(output, targets)

    def test(self, epoch):
        """Save the epoch snapshot, then measure average PSNR on the val set.

        The snapshot is written before evaluation, so it records the PSNR of
        the previous evaluation. ``self.psnr`` is then updated and passed to
        ``saveBestModel``.

        Args:
            epoch: Epoch that has just finished.
        """
        save_model_path = os.path.join(super_resolution_config.snapshotPath,
                                       "model_epoch_%d.pt" % epoch)
        self.torchModelProcess.saveLatestModel(save_model_path, self.model,
                                               self.optimizer, epoch, self.psnr)
        testing_data_loader = get_sr_dataloader(
            super_resolution_config.val_set,
            super_resolution_config.val_set,
            super_resolution_config.test_batch_size,
            num_workers=8,
            input_transform=self.input_transform,
            target_transform=self.target_transform)
        avg_psnr = 0
        with torch.no_grad():
            for batch in testing_data_loader:
                inputs = batch[0].to(self.device)
                target = batch[1].to(self.device)

                prediction = self.model(inputs)
                mse = self.compute_loss(prediction, target)
                # PSNR in dB, with the peak signal value taken as 1.
                avg_psnr += 10 * log10(1 / mse.item())

        self.psnr = avg_psnr / len(testing_data_loader)

        self.torchModelProcess.saveBestModel(self.psnr,
                                             save_model_path,
                                             super_resolution_config.best_weights_file)

        print("===> Avg. PSNR: {:.4f} dB".format(self.psnr))

    def train(self):
        """Run the training loop from ``start_epoch`` to ``maxEpochs``.

        Gradients are accumulated over ``accumulated_batches`` batches, and
        :meth:`test` is called after every epoch.
        """
        training_data_loader = get_sr_dataloader(
            super_resolution_config.train_set,
            super_resolution_config.train_set,
            super_resolution_config.train_batch_size,
            num_workers=8,
            shuffle=True,
            input_transform=self.input_transform,
            target_transform=self.target_transform)
        total_images = len(training_data_loader)
        self.load_param(super_resolution_config.latest_weights_file)

        accumulated_batches = super_resolution_config.accumulated_batches
        t0 = time.time()
        for epoch in range(self.start_epoch, super_resolution_config.maxEpochs):
            self.optimizer.zero_grad()
            # `i` is 1-based, which is also what the progress line shows.
            for i, batch in enumerate(training_data_loader, 1):
                imgs, targets = batch[0].to(self.device), batch[1].to(self.device)

                # NOTE(review): because `i` is 1-based, this counter starts
                # at 1 rather than 0; left unchanged.
                current_iter = epoch * total_images + i
                lr = self.multiLR.get_lr(epoch, current_iter)
                self.multiLR.adjust_learning_rate(self.optimizer, lr)

                # Compute loss, compute gradient, update parameters
                output = self.model(imgs)
                loss = self.compute_loss(output, targets)
                loss.backward()

                # Step every `accumulated_batches` batches and always on the
                # last batch. With a 1-based index that is `i % n == 0` and
                # `i == total_images`; the former 0-based test could drop the
                # final batch's gradients.
                if i % accumulated_batches == 0 or i == total_images:
                    self.optimizer.step()
                    self.optimizer.zero_grad()

                print('Epoch: {}[{}/{}]\t Loss: {}\t Rate: {} \t Time: {}\t'.format(
                    epoch, i, total_images,
                    '%.3f' % loss,
                    '%.7f' % self.optimizer.param_groups[0]['lr'],
                    time.time() - t0))
                self.update_logger(current_iter, loss.data)

                t0 = time.time()

            self.test(epoch)