예제 #1
0
파일: train.py 프로젝트: yhung119/PointNet3
def main():
    """Train and evaluate a point-cloud classifier configured by the global ``opt``.

    Steps, in order: seed the RNGs, build the ModelNet train/test data
    loaders, instantiate the network selected by ``opt.model``, optionally
    restore weights from ``opt.path``, then run up to ``opt.nepoch`` epochs
    of ``train``/``test`` with early stopping driven by the test loss.
    Results are persisted through ``saver.save_result()``.

    Raises:
        ValueError: if ``opt.model`` is not ``'lstm'``, ``'lstm_mlp'`` or
            ``'test'``.
    """
    saver = utils.Saver(opt)

    # Draw a seed for this run, store it on ``opt`` and apply it to every RNG.
    opt.manualSeed = random.randint(1, 10000)
    random.seed(opt.manualSeed)
    torch.manual_seed(opt.manualSeed)
    torch.cuda.manual_seed_all(opt.manualSeed)

    # Data location and device availability.
    root = "data/modelnet40_ply_hdf5_2048/"
    use_cuda = torch.cuda.is_available()

    # Optional training-time transformations.
    transforms_list = []
    random_permute = utils.Random_permute(opt.num_points, delta=opt.distance)
    if opt.random_input:
        print("random_input")
        transforms_list.append(random_permute)

    # Datasets / data loaders. Only the training set receives the transforms.
    train_dataset = data.ModelNetDataset(
        root,
        train=True,
        sort=opt.sort,
        transform=transforms.Compose(transforms_list),
        distance=opt.distance,
        normal=opt.normal)
    train_loader = DataLoader(train_dataset,
                              batch_size=opt.batchSize,
                              shuffle=True,
                              num_workers=opt.workers)

    test_dataset = data.ModelNetDataset(root,
                                        train=False,
                                        sort=opt.sort,
                                        distance=opt.distance,
                                        normal=opt.normal)
    test_loader = DataLoader(test_dataset,
                             batch_size=opt.batchSize,
                             shuffle=False,
                             num_workers=opt.workers)

    # Model selection. Inputs are 6-dimensional when distance or normal
    # features are enabled, otherwise 3-dimensional.
    ndim = 6 if opt.distance or opt.normal else 3
    if opt.model == 'lstm':
        model = Baseline(input_dim=ndim, maxout=opt.elem_max)
    elif opt.model == 'lstm_mlp':
        model = LSTM_mlp(input_dim=ndim,
                         maxout=opt.elem_max,
                         mlp=[64, 128, 256, 512],
                         fc=[512, 256, 40])
    elif opt.model == 'test':
        model = Test(input_dim=ndim, maxout=opt.elem_max)
    else:
        # Fail here with a clear message instead of hitting an unbound
        # ``model`` further down.
        raise ValueError("unknown model type: {!r}".format(opt.model))

    # Load the specified pre-trained weights. The model still lives on the CPU
    # at this point, so map the checkpoint to CPU; this also lets checkpoints
    # saved on a GPU load on a CPU-only machine.
    if opt.path != '':
        model.load_state_dict(torch.load(opt.path, map_location='cpu'))

    # Optimizer and loss function.
    optimizer = optim.Adam(model.parameters(),
                           lr=opt.learning_rate,
                           weight_decay=1e-5)
    criterion = nn.CrossEntropyLoss()

    # Move model and criterion to the GPU when one is available.
    if use_cuda:
        model = model.cuda()
        criterion = criterion.cuda()

    early_stopping = utils.Early_stopping(opt.early_stopping, patience=15)

    saver.log_parameters(model.parameters())

    for epoch in range(opt.nepoch):
        adjust_learning_rate(optimizer, epoch, saver)

        train(model, optimizer, criterion, saver, train_loader, epoch)

        test_loss = test(model, criterion, saver, test_loader, epoch)

        # Stop as soon as the early-stopping tracker says so.
        early_stopping.update(test_loss)
        if early_stopping.stop():
            break

    saver.save_result()
예제 #2
0
파일: main.py 프로젝트: prismformore/expAT
class Session:
    """Hold the models, losses, optimizer, scheduler and checkpoint helpers for training.

    All configuration is read from the module-level ``settings`` object.
    """

    def __init__(self):
        """Create output directories, models, losses, the optimizer and the LR scheduler."""
        self.log_dir = settings.log_dir
        self.model_dir = settings.model_dir
        ensure_dir(settings.log_dir)
        ensure_dir(settings.model_dir)
        logger.info('set log dir as %s' % settings.log_dir)
        logger.info('set model dir as %s' % settings.model_dir)

        ##################################### Import models ###########################
        self.feature_generator = Baseline(
            last_stride=1, model_path=settings.pretrained_model_path)

        # One embedder per modality, plus an identity classifier.
        self.feature_embedder_rgb = FeatureEmbedder(2048)
        self.feature_embedder_ir = FeatureEmbedder(2048)
        self.id_classifier = IdClassifier()

        if torch.cuda.is_available():
            self.feature_generator.cuda()
            self.feature_embedder_rgb.cuda()
            self.feature_embedder_ir.cuda()
            self.id_classifier.cuda()

        # Wrap every sub-model for multi-GPU execution on the first
        # ``settings.num_gpu`` devices.
        device_ids = range(settings.num_gpu)
        self.feature_generator = nn.DataParallel(self.feature_generator,
                                                 device_ids=device_ids)
        self.feature_embedder_rgb = nn.DataParallel(self.feature_embedder_rgb,
                                                    device_ids=device_ids)
        self.feature_embedder_ir = nn.DataParallel(self.feature_embedder_ir,
                                                   device_ids=device_ids)
        self.id_classifier = nn.DataParallel(self.id_classifier,
                                             device_ids=device_ids)

        ############################# Get Losses & Optimizers #########################
        self.criterion_at = expATLoss()
        self.criterion_identity = CrossEntropyLabelSmoothLoss(
            settings.num_classes, epsilon=0.1)

        opt_models = [
            self.feature_generator, self.feature_embedder_rgb,
            self.feature_embedder_ir, self.id_classifier
        ]

        def make_optimizer(opt_models):
            """Build one Adam optimizer with a parameter group per trainable tensor.

            Parameters whose name contains "bias" get their own learning
            rate factor and weight decay from ``settings``.
            """
            train_params = []

            for opt_model in opt_models:
                for key, value in opt_model.named_parameters():
                    if not value.requires_grad:
                        continue
                    lr = settings.BASE_LR
                    weight_decay = settings.WEIGHT_DECAY
                    if "bias" in key:
                        lr = settings.BASE_LR * settings.BIAS_LR_FACTOR
                        weight_decay = settings.WEIGHT_DECAY_BIAS
                    train_params.append({
                        "params": [value],
                        "lr": lr,
                        "weight_decay": weight_decay
                    })

            return torch.optim.Adam(train_params)

        self.optimizer_G = make_optimizer(opt_models)

        # Bookkeeping state.
        self.epoch_count = 0
        self.step = 0
        self.save_steps = settings.save_steps
        self.num_workers = settings.num_workers
        self.writers = {}      # name -> SummaryWriter, filled by tensorboard()
        self.dataloaders = {}

        self.sche_G = solver.WarmupMultiStepLR(self.optimizer_G,
                                               milestones=settings.iter_sche,
                                               gamma=0.1)  # default setting

    def tensorboard(self, name):
        """Create, register and return a ``SummaryWriter`` under ``<log_dir>/<name>.events``."""
        self.writers[name] = SummaryWriter(
            os.path.join(self.log_dir, name + '.events'))
        return self.writers[name]

    def write(self, name, out):
        """Send the scalars in ``out`` to the writer ``name`` and log them as one line.

        Args:
            name: key of a writer previously created with ``tensorboard(name)``.
            out: mapping of scalar name to value. It is not modified.

        Raises:
            KeyError: if no writer was registered under ``name``.
        """
        writer = self.writers[name]
        for k, v in out.items():
            writer.add_scalar(name + '/' + k, v, self.step)

        # Work on a copy so the bookkeeping entries added for the log line do
        # not leak into the caller's dict.
        record = dict(out)
        record['G_lr'] = self.optimizer_G.param_groups[0]['lr']
        record['step'] = self.step
        record['epoch_count'] = self.epoch_count
        outputs = ["{}:{:.4g}".format(k, v) for k, v in record.items()]
        logger.info(name + '--' + ' '.join(outputs))

    def save_checkpoints(self, name):
        """Save model weights, optimizer state and counters to ``<model_dir>/<name>``."""
        ckp_path = os.path.join(self.model_dir, name)
        obj = {
            'feature_generator': self.feature_generator.state_dict(),
            'feature_embedder_rgb': self.feature_embedder_rgb.state_dict(),
            'feature_embedder_ir': self.feature_embedder_ir.state_dict(),
            'id_classifier': self.id_classifier.state_dict(),
            'clock': self.step,
            'epoch_count': self.epoch_count,
            'opt_G': self.optimizer_G.state_dict(),
        }
        torch.save(obj, ckp_path)

    def load_checkpoints(self, name):
        """Restore the state written by ``save_checkpoints`` from ``<model_dir>/<name>``.

        A missing checkpoint file is not an error: the method returns without
        changing anything.
        """
        ckp_path = os.path.join(self.model_dir, name)
        # Without CUDA, remap tensors to the CPU so that checkpoints saved on
        # a GPU machine can still be loaded.
        map_location = None if torch.cuda.is_available() else 'cpu'
        try:
            obj = torch.load(ckp_path, map_location=map_location)
        except FileNotFoundError:
            return
        print('load checkpoint: %s' % ckp_path)

        self.feature_generator.load_state_dict(obj['feature_generator'])
        self.feature_embedder_rgb.load_state_dict(obj['feature_embedder_rgb'])
        self.feature_embedder_ir.load_state_dict(obj['feature_embedder_ir'])
        self.id_classifier.load_state_dict(obj['id_classifier'])
        self.optimizer_G.load_state_dict(obj['opt_G'])
        self.step = obj['clock']
        self.epoch_count = obj['epoch_count']
        # Re-align the scheduler with the restored step counter.
        self.sche_G.last_epoch = self.step

    def load_checkpoints_delf_init(self, name):
        """Load the ``'backbone'`` entry of ``<model_dir>/<name>`` into ``self.backbone``."""
        # NOTE(review): ``self.backbone`` is never assigned in this class, so
        # this raises AttributeError unless the attribute is set elsewhere.
        # Confirm the intended target module before relying on this method.
        ckp_path = os.path.join(self.model_dir, name)
        obj = torch.load(ckp_path)
        self.backbone.load_state_dict(obj['backbone'])

    def cal_fea(self, x, domain_mode):
        """Return the embedded features of ``x`` for the given modality.

        Args:
            x: input batch passed to the shared feature generator.
            domain_mode: ``'rgb'`` or ``'ir'``; selects the embedder.

        Raises:
            ValueError: if ``domain_mode`` is neither ``'rgb'`` nor ``'ir'``.
        """
        if domain_mode == 'rgb':
            embedder = self.feature_embedder_rgb
        elif domain_mode == 'ir':
            embedder = self.feature_embedder_ir
        else:
            # Previously this case fell through and returned None, which only
            # failed later inside the loss computation.
            raise ValueError(
                "unknown domain_mode {!r}; expected 'rgb' or 'ir'".format(
                    domain_mode))
        return embedder(self.feature_generator(x))

    def inf_batch(self, batch):
        """Run one optimization step on a batch of RGB/IR triplets and log the losses.

        ``batch`` unpacks into anchor/positive/negative tensors for each
        modality, the anchor labels and two modality entries (unused here).
        This method does not advance ``self.step`` or the scheduler;
        presumably the caller does - TODO confirm.
        """
        alpha = settings.alpha
        beta = settings.beta

        (anchor_rgb, positive_rgb, negative_rgb, anchor_ir, positive_ir,
         negative_ir, anchor_label, modality_rgb, modality_ir) = batch

        if torch.cuda.is_available():
            anchor_rgb = anchor_rgb.cuda()
            positive_rgb = positive_rgb.cuda()
            negative_rgb = negative_rgb.cuda()
            anchor_ir = anchor_ir.cuda()
            positive_ir = positive_ir.cuda()
            negative_ir = negative_ir.cuda()
            anchor_label = anchor_label.cuda()

        anchor_rgb_features = self.cal_fea(anchor_rgb, 'rgb')
        positive_rgb_features = self.cal_fea(positive_rgb, 'rgb')
        negative_rgb_features = self.cal_fea(negative_rgb, 'rgb')

        anchor_ir_features = self.cal_fea(anchor_ir, 'ir')
        positive_ir_features = self.cal_fea(positive_ir, 'ir')
        negative_ir_features = self.cal_fea(negative_ir, 'ir')

        # Cross-modality terms: each anchor is compared against the positive
        # and negative samples of the *other* modality.
        at_loss_rgb = self.criterion_at.forward(anchor_rgb_features,
                                                positive_ir_features,
                                                negative_ir_features)

        at_loss_ir = self.criterion_at.forward(anchor_ir_features,
                                               positive_rgb_features,
                                               negative_rgb_features)

        at_loss = at_loss_rgb + at_loss_ir

        # Identity classification on both anchors, sharing the anchor label.
        predicted_id_rgb = self.id_classifier(anchor_rgb_features)
        predicted_id_ir = self.id_classifier(anchor_ir_features)

        identity_loss = self.criterion_identity(predicted_id_rgb, anchor_label) + \
                        self.criterion_identity(predicted_id_ir, anchor_label)

        loss_G = alpha * at_loss + beta * identity_loss

        self.optimizer_G.zero_grad()
        loss_G.backward()
        self.optimizer_G.step()

        self.write('train_stats', {
            'loss_G': loss_G,
            'at_loss': at_loss,
            'identity_loss': identity_loss
        })
예제 #3
0
class Trainer(BaseTrainer):
    def __init__(self, config):
        """Assemble everything needed for training from ``config``.

        Creates the data manager, the model (and prints its summary), the
        two losses with their optimizers, the warm-up LR scheduler, the
        metric trackers and the gradient scaler. If ``config["resume"]`` is
        a non-empty string, the checkpoint at that path is restored last.
        """
        super().__init__(config)
        self.datamanager = DataManger(config["data"])
        # Bound method used to query the class count for the "train" split.
        count_classes = self.datamanager.datasource.get_num_classes

        # Network.
        self.model = Baseline(num_classes=count_classes("train"))

        # Print a layer summary of the network (computed on the CPU).
        summary(
            self.model,
            input_size=(3, 256, 128),
            batch_size=config["data"]["batch_size"],
            device="cpu",
        )

        # Loss functions.
        loss_cfg = config["losses"]
        self.criterion = Softmax_Triplet_loss(
            num_class=count_classes("train"),
            margin=loss_cfg["margin"],
            epsilon=loss_cfg["epsilon"],
            use_gpu=self.use_gpu,
        )
        self.center_loss = CenterLoss(
            num_classes=count_classes("train"),
            feature_dim=2048,
            use_gpu=self.use_gpu,
        )

        # Optimizers: Adam for the network, SGD for the center-loss parameters.
        optim_cfg = config["optimizer"]
        self.optimizer = torch.optim.Adam(
            self.model.parameters(),
            lr=optim_cfg["lr"],
            weight_decay=optim_cfg["weight_decay"],
        )
        self.optimizer_centerloss = torch.optim.SGD(
            self.center_loss.parameters(), lr=0.5
        )

        # Learning-rate schedule for the network optimizer.
        sched_cfg = config["lr_scheduler"]
        self.lr_scheduler = WarmupMultiStepLR(
            self.optimizer,
            milestones=sched_cfg["steps"],
            gamma=sched_cfg["gamma"],
            warmup_factor=sched_cfg["factor"],
            warmup_iters=sched_cfg["iters"],
            warmup_method=sched_cfg["method"],
        )

        # Running metrics for each phase.
        self.train_metrics = MetricTracker("loss", "accuracy")
        self.valid_metrics = MetricTracker("loss", "accuracy")

        # Best validation accuracy seen so far; None until the first epoch.
        self.best_accuracy = None

        # Move the network to the target device.
        self.model.to(self.device)

        self.scaler = GradScaler()

        # Optionally continue from an earlier checkpoint.
        resume_path = config["resume"]
        if resume_path != "":
            self._resume_checkpoint(resume_path)

    def train(self):
        """Run the full training loop from ``self.start_epoch`` to ``self.epochs`` inclusive.

        Each epoch: train, step the LR scheduler (if any), validate, push the
        averaged loss/accuracy to the tensorboard writer, log the validation
        result, save a checkpoint (flagged as best when the validation
        accuracy improved) and copy the logs.
        """
        for epoch in range(self.start_epoch, self.epochs + 1):
            # The per-epoch training result is not used here; training
            # metrics are read from ``self.train_metrics`` below.
            self._train_epoch(epoch)

            if self.lr_scheduler is not None:
                self.lr_scheduler.step()

            result = self._valid_epoch(epoch)

            # add scalars to tensorboard
            self.writer.add_scalars(
                "Loss",
                {
                    "Train": self.train_metrics.avg("loss"),
                    "Val": self.valid_metrics.avg("loss"),
                },
                global_step=epoch,
            )
            self.writer.add_scalars(
                "Accuracy",
                {
                    "Train": self.train_metrics.avg("accuracy"),
                    "Val": self.valid_metrics.avg("accuracy"),
                },
                global_step=epoch,
            )

            # logging the validation result to console
            log = {"epoch": epoch}
            log.update(result)
            for key, value in log.items():
                self.logger.info("    {:15s}: {}".format(str(key), value))

            # save model; the first epoch always counts as the best so far
            val_accuracy = self.valid_metrics.avg("accuracy")
            is_best = self.best_accuracy is None or self.best_accuracy < val_accuracy
            if is_best:
                self.best_accuracy = val_accuracy
            self._save_checkpoint(epoch, save_best=is_best)

            # save logs
            self._save_logs(epoch)

    def _train_epoch(self, epoch):
        """Train the model for one epoch under mixed precision.

        Args:
            epoch: epoch number, used only for the progress-bar label.

        Returns:
            The value of ``self.train_metrics.result()`` after the epoch.
        """
        self.model.train()
        self.train_metrics.reset()
        beta = self.config["losses"]["beta"]
        # Fetch the loader once and reuse it for both the length and the loop.
        train_loader = self.datamanager.get_dataloader("train")
        with tqdm(total=len(train_loader)) as epoch_pbar:
            epoch_pbar.set_description(f"Epoch {epoch}")
            for data, labels, _ in train_loader:
                # push data to device
                data, labels = data.to(self.device), labels.to(self.device)

                # zero gradient
                self.optimizer.zero_grad()
                self.optimizer_centerloss.zero_grad()

                with autocast():
                    # forward batch
                    score, feat = self.model(data)

                    # calculate loss and accuracy
                    loss = (
                        self.criterion(score, feat, labels)
                        + self.center_loss(feat, labels) * beta
                    )
                    _, preds = torch.max(score.data, dim=1)

                # Backward pass on the scaled loss: every gradient, including
                # those of the center-loss parameters, now carries the
                # scaler's loss-scale factor.
                self.scaler.scale(loss).backward()

                # Remove the loss scale from the center-loss gradients before
                # touching them; otherwise the SGD step below would apply
                # gradients multiplied by the loss scale.
                self.scaler.unscale_(self.optimizer_centerloss)

                # Undo the ``beta`` weighting so the centers are updated with
                # the unweighted center-loss gradient.
                for param in self.center_loss.parameters():
                    if param.grad is not None:
                        param.grad.data *= 1.0 / beta

                # Step both optimizers through the scaler so that steps with
                # inf/NaN gradients are skipped, then update the scale once.
                self.scaler.step(self.optimizer)
                self.scaler.step(self.optimizer_centerloss)
                self.scaler.update()

                # update loss and accuracy in MetricTracker
                self.train_metrics.update("loss", loss.item())
                self.train_metrics.update(
                    "accuracy",
                    torch.sum(preds == labels.data).double().item() / data.size(0),
                )

                # update process bar
                epoch_pbar.set_postfix(
                    {
                        "train_loss": self.train_metrics.avg("loss"),
                        "train_acc": self.train_metrics.avg("accuracy"),
                    }
                )
                epoch_pbar.update(1)
        return self.train_metrics.result()

    def _valid_epoch(self, epoch):
        """Evaluate the model on the "val" loader without gradient tracking.

        Args:
            epoch: epoch number, used only for the progress-bar label.

        Returns:
            The value of ``self.valid_metrics.result()`` after the pass.
        """
        self.model.eval()
        self.valid_metrics.reset()
        with torch.no_grad(), tqdm(
            total=len(self.datamanager.get_dataloader("val"))
        ) as progress:
            progress.set_description(f"Epoch {epoch}")
            for inputs, targets, _ in self.datamanager.get_dataloader("val"):
                # Move the batch to the evaluation device.
                inputs = inputs.to(self.device)
                targets = targets.to(self.device)

                with autocast():
                    # Forward pass.
                    logits, embedding = self.model(inputs)

                    # Same combined loss as in training.
                    main_term = self.criterion(logits, embedding, targets)
                    center_term = self.center_loss(embedding, targets)
                    batch_loss = main_term + center_term * self.config["losses"]["beta"]
                    predicted = torch.max(logits.data, dim=1)[1]

                # Record the batch loss and accuracy.
                n_correct = torch.sum(predicted == targets.data).double().item()
                self.valid_metrics.update("loss", batch_loss.item())
                self.valid_metrics.update("accuracy", n_correct / inputs.size(0))

                # Refresh the progress bar with the running averages.
                progress.set_postfix(
                    {
                        "val_loss": self.valid_metrics.avg("loss"),
                        "val_acc": self.valid_metrics.avg("accuracy"),
                    }
                )
                progress.update(1)
        return self.valid_metrics.result()

    def _save_checkpoint(self, epoch, save_best=True):
        """Write the training state to ``model_last.pth`` and optionally ``model_best.pth``.

        Args:
            epoch: epoch number stored inside the checkpoint.
            save_best: when true, the same state is also written to
                ``model_best.pth``.
        """
        snapshot = dict(
            epoch=epoch,
            state_dict=self.model.state_dict(),
            center_loss=self.center_loss.state_dict(),
            optimizer=self.optimizer.state_dict(),
            optimizer_centerloss=self.optimizer_centerloss.state_dict(),
            lr_scheduler=self.lr_scheduler.state_dict(),
            best_accuracy=self.best_accuracy,
        )

        # The most recent state is always written.
        last_path = os.path.join(self.checkpoint_dir, "model_last.pth")
        self.logger.info("Saving last model: model_last.pth ...")
        torch.save(snapshot, last_path)

        if not save_best:
            return

        # Keep a separate copy of the best state.
        best_path = os.path.join(self.checkpoint_dir, "model_best.pth")
        self.logger.info("Saving current best: model_best.pth ...")
        torch.save(snapshot, best_path)

    def _resume_checkpoint(self, resume_path):
        """Restore the training state from the checkpoint at ``resume_path``.

        Loads the weights of the model and the center loss, both optimizer
        states, the scheduler state and the best accuracy, and sets
        ``self.start_epoch`` to the epoch after the saved one.

        Args:
            resume_path: path of a checkpoint file written by
                ``_save_checkpoint``.

        Raises:
            FileNotFoundError: if ``resume_path`` does not exist.
        """
        if not os.path.exists(resume_path):
            # A missing file is "not found"; the previous FileExistsError
            # signalled the opposite condition.
            raise FileNotFoundError("Resume path not exist!")
        self.logger.info("Loading checkpoint: {} ...".format(resume_path))
        checkpoint = torch.load(resume_path, map_location=self.map_location)
        # Training continues with the epoch following the saved one.
        self.start_epoch = checkpoint["epoch"] + 1
        self.model.load_state_dict(checkpoint["state_dict"])
        self.center_loss.load_state_dict(checkpoint["center_loss"])
        self.optimizer.load_state_dict(checkpoint["optimizer"])
        self.optimizer_centerloss.load_state_dict(checkpoint["optimizer_centerloss"])
        self.lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
        self.best_accuracy = checkpoint["best_accuracy"]
        self.logger.info(
            "Checkpoint loaded. Resume training from epoch {}".format(self.start_epoch)
        )

    def _save_logs(self, epoch):
        """Replace ``self.logs_dir_saved`` with a fresh copy of ``self.logs_dir``.

        According to the original note this is meant for saving logs from
        Google Colab to Google Drive. ``epoch`` is not used in the lines
        visible here.
        """
        # copytree requires that the destination does not exist yet, so any
        # previous copy is removed first.
        if os.path.isdir(self.logs_dir_saved):
            shutil.rmtree(self.logs_dir_saved)
        # NOTE(review): ``destination`` is not read in the visible code.
        destination = shutil.copytree(self.logs_dir, self.logs_dir_saved)