Example no. 1
import numpy as np
import torch
from collections import defaultdict
from pathlib import Path
from torch.utils.data import TensorDataset
from tqdm import tqdm

# Assumed to be provided by the surrounding package: Net, ReID,
# PreprocessedDataset, FocalLoss, RegLoader, and DataLoader (presumably the
# PyTorch Geometric DataLoader, since the batches expose .x, .y and .to()).


class GraphNNMOTracker:

    def __init__(self, config, writer):
        self.writer = writer
        self.config = config
        self.device = torch.device('cuda' if config.cuda else 'cpu')

        self.net = Net().to(self.device)
        if self.config.train_cnn:
            self.re_id_net = ReID(out_feats=32, pretrained=True).to(self.device)

        log_dir = Path(
            self.writer.get_data_path(self.writer.name, self.writer.version))
        self.model_save_dir = log_dir / 'checkpoints'
        self.model_save_dir.mkdir(exist_ok=True)

        self.epoch = 0

    def train_dataloader(self):
        ds = PreprocessedDataset(Path(self.config.dataset_path),
                                 sequences=self.config.train_sequences,
                                 load_imgs=self.config.train_cnn)
        train = DataLoader(ds, batch_size=self.config.batch_size,
                           num_workers=self.config.workers, shuffle=True)
        return train

    def val_dataloader(self):
        ds = PreprocessedDataset(Path(self.config.dataset_path),
                                 sequences=self.config.val_sequences,
                                 load_imgs=self.config.train_cnn)
        val = DataLoader(ds, batch_size=self.config.batch_size,
                         num_workers=self.config.workers,
                         shuffle=True)
        return val

    def train(self):
        train_loader = self.train_dataloader()
        val_loader = self.val_dataloader()

        # setup optimizer
        opt = torch.optim.Adam(self.net.parameters(),
                               lr=3e-4,
                               weight_decay=1e-4,
                               betas=(0.9, 0.999))

        if self.config.train_cnn:
            opt_re_id = torch.optim.Adam(self.re_id_net.parameters(),
                                         lr=3e-6,
                                         weight_decay=1e-4,
                                         betas=(0.9, 0.999))

        # Edge-classification loss: focal loss for class imbalance, else BCE
        # (the network outputs are assumed to be sigmoid probabilities).
        if self.config.use_focal:
            criterion = FocalLoss()
        else:
            criterion = torch.nn.BCELoss()

        for epoch in range(self.config.epochs):
            self.epoch += 1
            self.net.train()
            if self.config.train_cnn:
                self.re_id_net.train()
            metrics = defaultdict(list)
            pbar = tqdm(train_loader)
            for i, data in enumerate(pbar):
                if self.config.train_cnn:
                    # Run the re-ID CNN over the detection crops in small
                    # sub-batches and use the embeddings as node features.
                    img_tensor = data.imgs
                    img_ds = TensorDataset(img_tensor)
                    img_dl = RegLoader(img_ds, batch_size=2)

                    x_feats = []
                    for imgs in img_dl:
                        if imgs[0].shape[0] == 1:
                            # BatchNorm fails on a size-1 batch in train mode
                            self.re_id_net.eval()
                        x_feats.append(self.re_id_net(imgs[0].to(self.device)))
                        self.re_id_net.train()

                    x_feats = torch.cat(x_feats)
                    data.x = x_feats
                    del data.imgs

                data = data.to(self.device)
                gt = data.y.float()
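                # Keep an untouched copy of the node features; Net presumably
                # re-injects them during message passing (skip connection).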
                initial_x = data.x.clone()
                out = self.net(data, initial_x).squeeze(1)
                loss = criterion(out, gt)

                with torch.no_grad():
                    acc = ((out > 0.5) == gt).float().mean().item()

                metrics['train/loss'].append(loss.item())
                metrics['train/acc'].append(acc)
                pbar.set_description(f"Loss: {loss.item():.4f}, Acc: {acc:.2f}")

                opt.zero_grad()
                loss.backward()
                opt.step()
                if self.config.train_cnn:
                    opt_re_id.step()
                    opt_re_id.zero_grad()

            with torch.no_grad():
                self.net.eval()
                if self.config.train_cnn:
                    self.re_id_net.eval()
                pbar = tqdm(val_loader)
                for i, data in enumerate(pbar):
                    data = data.to(self.device)
                    gt = data.y.float()

                    if self.config.train_cnn:
                        img_tensor = data.imgs
                        img_ds = TensorDataset(img_tensor)
                        img_dl = RegLoader(img_ds, batch_size=2)

                        x_feats = []
                        for imgs in img_dl:
                            x_feats.append(self.re_id_net(imgs[0].to(self.device)))

                        x_feats = torch.cat(x_feats)
                        data.x = x_feats
                        del data.imgs

                    initial_x = data.x.clone()
                    out = self.net(data, initial_x).squeeze(1)
                    loss = criterion(out, gt)

                    with torch.no_grad():
                        acc = ((out > 0.5) == gt).float().mean().item()

                    metrics['val/loss'].append(loss.item())
                    metrics['val/acc'].append(acc)
                    pbar.set_description(f"Validation epoch {self.epoch}: "
                                         f"Loss: {loss.item():.4f}, "
                                         f"Acc: {acc:.2f}")

            metrics = {k: np.mean(v) for k, v in metrics.items()}
            self.writer.log(metrics, epoch)
            # Periodic checkpoint every 10 epochs (skipping the first few)
            if epoch % 10 == 0 and epoch > 5:
                self.save(self.model_save_dir / f'checkpoints_{epoch}.pth')

    def save(self, path: Path):
        torch.save(self.net.state_dict(), path)

    def load(self, path: Path):
        self.net.load_state_dict(torch.load(path))
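
A minimal usage sketch for Example no. 1, assuming a plain namespace config with the fields the tracker reads and a writer object exposing name, version, get_data_path() and log(); all names and values below are illustrative placeholders, not part of the original code:

from argparse import Namespace

config = Namespace(
    cuda=torch.cuda.is_available(),
    train_cnn=False,                            # set True to also fine-tune the re-ID CNN
    use_focal=True,
    dataset_path="data/preprocessed",           # placeholder path
    train_sequences=["MOT17-02", "MOT17-04"],   # placeholder sequence names
    val_sequences=["MOT17-09"],
    batch_size=8,
    workers=4,
    epochs=50,
)

tracker = GraphNNMOTracker(config, writer)  # writer: any logger with the API above
tracker.train()
tracker.save(tracker.model_save_dir / "final.pth")
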
Example no. 2
# Same imports as Example no. 1; osnet_x0_5 is presumably the OSNet factory
# from the torchreid model zoo, and _compute_metrics is a project helper.
class GraphNNMOTracker:
    def __init__(self, config, writer):
        self.writer = writer
        self.config = config
        self.device = torch.device("cuda" if config.cuda else "cpu")

        self.net = Net().to(self.device)
        if self.config.train_cnn:
            self.re_id_net = osnet_x0_5(pretrained=True).to(self.device)

        log_dir = Path(
            self.writer.get_data_path(self.writer.name, self.writer.version))
        self.model_save_dir = log_dir / "checkpoints"
        self.model_save_dir.mkdir(exist_ok=True)

        self.epoch = 0
        self.best = float("-inf")

    def train_dataloader(self):
        ds = PreprocessedDataset(
            Path(self.config.dataset_path),
            sequences=self.config.train_sequences,
            load_imgs=self.config.train_cnn,
        )
        train = DataLoader(
            ds,
            batch_size=self.config.batch_size,
            num_workers=self.config.workers,
            shuffle=True,
        )
        return train

    def val_dataloader(self):
        ds = PreprocessedDataset(
            Path(self.config.dataset_path),
            sequences=self.config.val_sequences,
            load_imgs=self.config.train_cnn,
        )
        val = DataLoader(
            ds,
            batch_size=self.config.batch_size,
            num_workers=self.config.workers,
            shuffle=True,
        )
        return val

    def train(self):
        train_loader = self.train_dataloader()
        val_loader = self.val_dataloader()

        # setup optimizer
        opt = torch.optim.Adam(
            self.net.parameters(),
            lr=self.config.base_lr,
            weight_decay=1e-4,
            betas=(0.9, 0.999),
        )

        if self.config.train_cnn:
            opt_re_id = torch.optim.Adam(
                self.re_id_net.parameters(),
                lr=3e-6,
                weight_decay=1e-4,
                betas=(0.9, 0.999),
            )

        if self.config.use_focal:
            criterion = FocalLoss()
        else:
            criterion = torch.nn.BCELoss()

        for epoch in range(self.config.epochs):
            self.epoch += 1
            self.net.train()
            if self.config.train_cnn:
                self.re_id_net.train()
            metrics = defaultdict(list)
            # Reset the prediction buffers at the start of every epoch so the
            # training metrics are not mixed with the previous validation pass.
            preds, gts = [], []
            pbar = tqdm(train_loader)
            for i, data in enumerate(pbar):
                if self.config.train_cnn:
                    # Run the re-ID CNN over the detection crops in small
                    # sub-batches and use the embeddings as node features.
                    img_tensor = data.imgs
                    img_ds = TensorDataset(img_tensor)
                    img_dl = RegLoader(img_ds, batch_size=2)

                    x_feats = []
                    for imgs in img_dl:
                        if imgs[0].shape[0] == 1:
                            # BatchNorm fails on a size-1 batch in train mode
                            self.re_id_net.eval()
                        x_feats.append(self.re_id_net(imgs[0].to(self.device)))
                        self.re_id_net.train()

                    x_feats = torch.cat(x_feats)
                    data.x = x_feats
                    del data.imgs

                data = data.to(self.device)
                gt = data.y.float()
                initial_x = data.x.clone()
                out = self.net(data, initial_x).squeeze(1)
                loss = criterion(out, gt)

                gts.append(gt.detach().cpu())
                preds.append(out.detach().cpu())

                metrics["train/Loss"].append(loss.item())
                pbar.set_description(f"Loss: {loss.item():.04f}")

                opt.zero_grad()
                loss.backward()
                opt.step()
                if self.config.train_cnn:
                    opt_re_id.step()
                    opt_re_id.zero_grad()

            train_metrics = _compute_metrics(torch.cat(preds),
                                             torch.cat(gts),
                                             prefix="train/")

            preds, gts = [], []
            with torch.no_grad():
                self.net.eval()
                if self.config.train_cnn:
                    self.re_id_net.eval()
                pbar = tqdm(val_loader)
                for _, data in enumerate(pbar):
                    data = data.to(self.device)
                    gt = data.y.float()

                    if self.config.train_cnn:
                        img_tensor = data.imgs
                        img_ds = TensorDataset(img_tensor)
                        img_dl = RegLoader(img_ds, batch_size=2)

                        x_feats = []
                        for imgs in img_dl:
                            x_feats.append(
                                self.re_id_net(imgs[0].to(self.device)))

                        x_feats = torch.cat(x_feats)
                        data.x = x_feats
                        del data.imgs

                    initial_x = data.x.clone()
                    out = self.net(data, initial_x).squeeze(1)
                    loss = criterion(out, gt)

                    gts.append(gt.detach().cpu())
                    preds.append(out.detach().cpu())

                    metrics["val/Loss"].append(loss.item())
                    pbar.set_description(f"Loss: {loss.item():.04f}")

            val_metrics = _compute_metrics(torch.cat(preds),
                                           torch.cat(gts),
                                           prefix="val/")
            metrics = {k: np.mean(v) for k, v in metrics.items()}
            metrics.update({**train_metrics, **val_metrics})

            self.writer.log(metrics, epoch)
            # Periodic checkpoint every 10 epochs (skipping the first few)
            if epoch % 10 == 0 and epoch > 5:
                self.save(self.model_save_dir / f"checkpoints_{epoch}.pth")

            # Keep a separate checkpoint of the best model so far, selected by
            # validation F1.
            if metrics["val/F1"] > self.best:
                self.best = metrics["val/F1"]
                self.save(self.model_save_dir / "checkpoints_best.pth")

    def save(self, path: Path):
        torch.save(self.net.state_dict(), path)

    def load(self, path: Path):
        self.net.load_state_dict(torch.load(path))
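
_compute_metrics is referenced in Example no. 2 but not shown; a minimal sketch of a plausible implementation, assuming it thresholds the sigmoid edge scores at 0.5 and reports accuracy, precision, recall and F1 under the given key prefix (the exact metric set is an assumption, except for F1, which the best-checkpoint logic relies on):

def _compute_metrics(preds, gts, prefix=""):
    # Hypothetical helper: threshold sigmoid outputs to binary edge labels.
    pred_labels = (preds > 0.5).float()
    tp = ((pred_labels == 1) & (gts == 1)).sum().item()
    fp = ((pred_labels == 1) & (gts == 0)).sum().item()
    fn = ((pred_labels == 0) & (gts == 1)).sum().item()

    accuracy = (pred_labels == gts).float().mean().item()
    precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
    f1 = (2 * precision * recall / (precision + recall)
          if (precision + recall) > 0 else 0.0)

    return {
        prefix + "Accuracy": accuracy,
        prefix + "Precision": precision,
        prefix + "Recall": recall,
        prefix + "F1": f1,
    }
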