Example 1
    def train(self, data_loader: torch.utils.data.DataLoader, device: str, **kwargs):  # pylint: disable=unused-argument
        """Train function defined for a single epoch."""

        train_loss = 0.
        steps_per_epoch = len(data_loader)
        self._set_learning_phase(train=True)

        with tqdm.tqdm(**get_tqdm_config(total=steps_per_epoch, leave=False, color='green')) as pbar:
            for i, batch in enumerate(data_loader):

                x = batch['x'].to(device)  # 4d
                y = batch['y'].to(device)  # 3d

                self.optimizer.zero_grad()
                _, decoded = self.predict(x)  # `decoded` are logits (B, C, H, W)
                loss = self.loss_function(decoded, y)
                loss.backward()
                self.optimizer.step()

                train_loss += loss.item()

                desc = f" Batch: [{i+1:>4}/{steps_per_epoch:>4}]"
                desc += f" Loss: {train_loss/(i+1):.4f} "
                pbar.set_description_str(desc)
                pbar.update(1)

        out = {'loss': train_loss / steps_per_epoch}
        if isinstance(self.metrics, dict):
            raise NotImplementedError

        return out
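
These snippets unpack the result of `get_tqdm_config(...)` into `tqdm.tqdm`, but the helper itself is never shown. A minimal sketch of what it might return, assuming it merely packs keyword arguments for the progress bar (the real project helper may differ; `colour` needs a reasonably recent tqdm):

def get_tqdm_config(total: int, leave: bool = False, color: str = 'white') -> dict:
    """Hypothetical helper: builds a kwargs dict for tqdm.tqdm."""
    return {
        'total': total,           # number of iterations (batches or epochs)
        'leave': leave,           # keep the bar on screen after completion
        'colour': color,          # tqdm spells this kwarg 'colour'
        'bar_format': '{l_bar}{bar:25}{r_bar}',
        'dynamic_ncols': True,
    }
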
Example 2
    def train(self, data_loader: torch.utils.data.DataLoader, device: str, **kwargs):  # pylint: disable=unused-argument
        """Train function defined for a single epoch."""

        train_loss = 0.
        steps_per_epoch = len(data_loader)
        self._set_learning_phase(train=True)

        with tqdm.tqdm(**get_tqdm_config(steps_per_epoch, leave=False, color='red')) as pbar:
            for i, batch in enumerate(data_loader):

                self.optimizer.zero_grad()
                x1, x2 = batch['x1'].to(device), batch['x2'].to(device)
                z1, z2 = self.predict(x1), self.predict(x2)
                loss = self.loss_function(features=torch.stack([z1, z2], dim=1))
                loss.backward()
                self.optimizer.step()

                train_loss += loss.item()
                desc = f" Batch [{i+1:>4}/{steps_per_epoch:>4}]"
                desc += f" Loss: {train_loss/(i+1):.4f} "
                pbar.set_description_str(desc)
                pbar.update(1)

            out = {
                'loss': train_loss / steps_per_epoch,
            }
            if self.metrics is not None:
                raise NotImplementedError

            return out
Example 3
    def train(self, data_loader: torch.utils.data.DataLoader, device: str, **kwargs):  # pylint: disable=unused-argument
        """Train function defined for a single epoch."""

        preds = []
        train_loss = 0.
        steps_per_epoch = len(data_loader)
        self._set_learning_phase(train=True)

        with tqdm.tqdm(**get_tqdm_config(steps_per_epoch, leave=False, color='green')) as pbar:
            for i, batch in enumerate(data_loader):

                j  = batch['idx']
                x  = batch['x'].to(device)
                x_t = batch['x_t'].to(device)
                z = self.predict(x)
                z_t = self.predict(x_t)

                m = self.memory.get_representations(j).to(device)
                negatives = self.memory.get_negatives(self.num_negatives, exclude=j)

                # Calculate loss
                loss_z, _ = self.loss_function(
                    anchors=m,
                    positives=z,
                    negatives=negatives,
                )
                loss_z_t, logits = self.loss_function(
                    anchors=m,
                    positives=z_t,
                    negatives=negatives,
                )
                loss = (1 - self.loss_weight) * loss_z + self.loss_weight * loss_z_t

                # Backpropagation & update
                loss.backward()
                self.optimizer.step()
                self.optimizer.zero_grad()
                self.memory.update(j, values=z.detach())

                train_loss += loss.detach().item()
                preds += [logits.detach().cpu()]

                desc = f" Batch: [{i+1:>4}/{steps_per_epoch:>4}]"
                desc += f" Loss: {train_loss/(i+1):.4f} "
                pbar.set_description_str(desc)
                pbar.update(1)

        out = {'loss': train_loss / steps_per_epoch}
        if self.metrics is not None:
            assert isinstance(self.metrics, dict)
            with torch.no_grad():
                preds = torch.cat(preds, dim=0)                          # (N, 1+ num_negatives)
                trues = torch.zeros(preds.size(0), device=preds.device)  # (N, )
                for metric_name, metric_function in self.metrics.items():
                    out[metric_name] = metric_function(preds, trues).item()

        return out
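
Examples 3, 9 and 10 lean on a memory bank that exposes `get_representations`, `get_negatives` and `update`; that class is not part of the snippets. A rough sketch consistent with how it is called above, where all names and details are assumptions rather than the project's actual implementation:

import torch

class FeatureMemory:
    """Hypothetical memory bank holding one representation per training sample."""

    def __init__(self, size: int, dim: int):
        self.bank = torch.zeros(size, dim)   # row i = stored feature of sample i
        self.initialized = False

    def get_representations(self, indices: torch.Tensor) -> torch.Tensor:
        """Return the stored features for the given sample indices."""
        return self.bank[indices]

    def get_negatives(self, num_negatives: int, exclude: torch.Tensor) -> torch.Tensor:
        """Randomly sample stored features, excluding the given indices."""
        mask = torch.ones(self.bank.size(0), dtype=torch.bool)
        mask[exclude] = False
        candidates = mask.nonzero(as_tuple=True)[0]
        choice = candidates[torch.randperm(candidates.numel())[:num_negatives]]
        return self.bank[choice]

    def update(self, indices: torch.Tensor, values: torch.Tensor):
        """Overwrite the stored features for the given indices."""
        self.bank[indices] = values.cpu()
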
Example 4
    def write_images(self, root: str, indices):  # `indices`: a list or tuple of row indices
        """Write wafer images to .png files."""
        os.makedirs(root, exist_ok=True)
        with tqdm.tqdm(**get_tqdm_config(total=len(indices), leave=True, color='yellow')) as pbar:
            for i, row in self.data.loc[indices].iterrows():
                pngfile = os.path.join(root, row['labelString'], f'{i:06}.png')
                os.makedirs(os.path.dirname(pngfile), exist_ok=True)
                self.save_image(row['waferMap'], pngfile)
                pbar.set_description_str(f" {root} - {i:06} ")
                pbar.update(1)
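
`save_image` is also left undefined in Example 4. One plausible implementation, written here as a standalone function and assuming `waferMap` is a small integer-coded 2-D array; the grayscale scaling is illustrative only:

import numpy as np
from PIL import Image

def save_image(wafer_map, path: str):
    """Hypothetical writer: scales an integer-coded wafer map to an 8-bit grayscale PNG."""
    arr = np.asarray(wafer_map, dtype=np.float32)
    arr = (255.0 * arr / max(float(arr.max()), 1.0)).astype(np.uint8)
    Image.fromarray(arr, mode='L').save(path)
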
Example 5
    def train(self, data_loader: torch.utils.data.DataLoader, device: str, **kwargs):  # pylint: disable=unused-argument
        """Train function defined for a single epoch."""

        train_loss = 0.
        steps_per_epoch = len(data_loader)
        self._set_learning_phase(train=True)

        with tqdm.tqdm(**get_tqdm_config(steps_per_epoch, leave=False, color='red')) as pbar:
            for i, batch in enumerate(data_loader):

                self.optimizer.zero_grad()

                x1, x2 = batch['x1'].to(device), batch['x2'].to(device)
                z, attn = self.predict(torch.cat([x1, x2], dim=0))
                z = z.view(x1.size(0), 2, -1)

                # Calculate attention-based contrastive loss
                loss, _ = self.loss_function(features=z, attention_scores=attn)
                loss.backward()

                # Clip the gradients (XXX: why is this necessary?)
                # nn.utils.clip_grad_norm_(self.backbone.parameters(), 1.)
                # nn.utils.clip_grad_norm_(self.projector.parameters(), 1.)

                # Update weights
                self.optimizer.step()

                train_loss += loss.item()
                desc = f" Batch [{i+1:>4}/{steps_per_epoch:>4}]"
                desc += f" Loss: {train_loss/(i+1):.4f} "
                pbar.set_description_str(desc)
                pbar.update(1)

            out = {
                'loss': train_loss / steps_per_epoch,
            }
            if self.metrics is not None:
                raise NotImplementedError

            return out
Example 6
    def run(self, train_set, valid_set, epochs: int, batch_size: int, num_workers: int = 0, device: str = 'cuda', **kwargs):  # pylint: disable=unused-argument

        assert isinstance(train_set, torch.utils.data.Dataset)
        assert isinstance(valid_set, torch.utils.data.Dataset)
        assert isinstance(epochs, int)
        assert isinstance(batch_size, int)
        assert isinstance(num_workers, int)
        assert device.startswith('cuda') or device == 'cpu'

        logger = kwargs.get('logger', None)

        self.backbone = self.backbone.to(device)
        self.projector = self.projector.to(device)

        train_loader = get_dataloader(train_set, batch_size, num_workers=num_workers)
        valid_loader = get_dataloader(valid_set, batch_size, num_workers=num_workers)

        with tqdm.tqdm(**get_tqdm_config(total=epochs, leave=True, color='blue')) as pbar:

            best_valid_loss = float('inf')
            best_epoch = 0

            for epoch in range(1, epochs + 1):

                # 0. Train & evaluate
                train_history = self.train(train_loader, device=device)
                valid_history = self.evaluate(valid_loader, device=device)

                # 1. Epoch history (loss)
                epoch_history = {
                    'loss': {
                        'train': train_history.get('loss'),
                        'valid': valid_history.get('loss'),
                    }
                }

                # 2. Epoch history (other metrics if provided)
                if self.metrics is not None:
                    raise NotImplementedError

                # 3. TensorBoard
                if self.writer is not None:
                    for metric_name, metric_dict in epoch_history.items():
                        self.writer.add_scalars(
                            main_tag=metric_name,
                            tag_scalar_dict=metric_dict,
                            global_step=epoch
                        )
                    if self.scheduler is not None:
                        self.writer.add_scalar(
                            tag='lr',
                            scalar_value=self.scheduler.get_last_lr()[0],
                            global_step=epoch
                        )

                # 4. Save model if it is the current best
                valid_loss = epoch_history['loss']['valid']
                if valid_loss < best_valid_loss:
                    best_valid_loss = valid_loss
                    best_epoch = epoch
                    self.save_checkpoint(self.best_ckpt, epoch=epoch, **epoch_history)
                    if kwargs.get('save_every', False):
                        new_ckpt = os.path.join(self.checkpoint_dir, f'epoch_{epoch:04d}.loss_{valid_loss:.4f}.pt')
                        self.save_checkpoint(new_ckpt, epoch=epoch, **epoch_history)

                # 5. Update learning rate scheduler
                if self.scheduler is not None:
                    self.scheduler.step()

                # 6. Logging
                desc = make_epoch_description(
                    history=epoch_history,
                    current=epoch,
                    total=epochs,
                    best=best_epoch
                )
                pbar.set_description_str(desc)
                pbar.update(1)
                if logger is not None:
                    logger.info(desc)

        # 7. Save last model
        self.save_checkpoint(self.last_ckpt, epoch=epoch, **epoch_history)

        # 8. Test model (optional)
        if 'test_set' in kwargs.keys():
            test_loader = get_dataloader(kwargs.get('test_set'), batch_size=batch_size, num_workers=num_workers)
            self.test(test_loader, device=device, logger=logger)
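
`make_epoch_description` turns the nested `epoch_history` dict into the one-line string used for the progress bar and the logger. A minimal sketch of such a formatter (the project's actual helper may format differently):

def make_epoch_description(history: dict, current: int, total: int, best: int) -> str:
    """Hypothetical formatter for {metric: {'train': value, 'valid': value}} histories."""
    parts = [f"Epoch [{current:>3}/{total:>3}] (best: {best:>3})"]
    for metric_name, splits in history.items():
        values = " / ".join(f"{split}: {value:.4f}" for split, value in splits.items())
        parts.append(f"{metric_name} - {values}")
    return " | ".join(parts)
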
Example 7
    def run(self,
            train_set,
            valid_set,
            epochs: int,
            batch_size: int,
            num_workers: int = 0,
            **kwargs):
        """Train, evaluate and optionally test."""

        logger = kwargs.get('logger', None)

        self.backbone.to(self.local_rank)
        self.classifier.to(self.local_rank)

        train_loader = balanced_loader(train_set,
                                       batch_size,
                                       num_workers=num_workers,
                                       shuffle=False,
                                       pin_memory=False)
        valid_loader = DataLoader(valid_set,
                                  batch_size,
                                  num_workers=num_workers,
                                  shuffle=True,
                                  drop_last=False,
                                  pin_memory=False)

        with tqdm.tqdm(**get_tqdm_config(
                total=epochs, leave=True, color='blue')) as pbar:

            best_valid_loss = float('inf')
            best_epoch = 0

            for epoch in range(1, epochs + 1):

                # 0. Train & evaluate
                train_history = self.train(train_loader)
                valid_history = self.evaluate(valid_loader)

                # 1. Epoch history (loss)
                epoch_history = {
                    'loss': {
                        'train': train_history.get('loss'),
                        'valid': valid_history.get('loss')
                    }
                }

                # 2. Epoch history (other metrics if provided)
                if isinstance(self.metrics, dict):
                    for metric_name, _ in self.metrics.items():
                        epoch_history[metric_name] = {
                            'train': train_history[metric_name],
                            'valid': valid_history[metric_name],
                        }

                # 3. Tensorboard
                if self.writer is not None:
                    for metric_name, metric_dict in epoch_history.items():
                        self.writer.add_scalars(main_tag=metric_name,
                                                tag_scalar_dict=metric_dict,
                                                global_step=epoch)
                    if self.scheduler is not None:
                        self.writer.add_scalar(
                            tag='lr',
                            scalar_value=self.scheduler.get_last_lr()[0],
                            global_step=epoch)

                # 4. Save model if it is the current best
                valid_loss = epoch_history['loss']['valid']
                if valid_loss <= best_valid_loss:
                    best_valid_loss = valid_loss
                    best_epoch = epoch
                    if self.local_rank == 0:
                        self.save_checkpoint(self.best_ckpt,
                                             epoch=epoch,
                                             **epoch_history)

                # 5. Update learning rate scheduler (optional)
                if self.scheduler is not None:
                    self.scheduler.step()

                # 6. Logging
                desc = make_epoch_description(
                    history=epoch_history,
                    current=epoch,
                    total=epochs,
                    best=best_epoch,
                )
                pbar.set_description_str(desc)
                pbar.update(1)
                if logger is not None:
                    logger.info(desc)

        # 7. Save last model
        self.save_checkpoint(self.last_ckpt, epoch=epoch, **epoch_history)

        # 8. Test model (optional)
        if self.local_rank == 0:
            if 'test_set' in kwargs.keys():
                test_set = kwargs['test_set']
                test_loader = DataLoader(test_set,
                                         batch_size,
                                         num_workers=num_workers,
                                         shuffle=True,
                                         drop_last=False,
                                         pin_memory=False)
                self.test(test_loader, logger=logger)
Example 8
    def run(self,
            train_set,
            valid_set,
            epochs: int,
            batch_size: int,
            num_workers: int = 0,
            device: str = 'cuda',
            **kwargs):
        """Train, evaluate and optionally test."""

        assert isinstance(train_set, torch.utils.data.Dataset)
        assert isinstance(valid_set, torch.utils.data.Dataset)
        assert isinstance(epochs, int)
        assert isinstance(batch_size, int)
        assert isinstance(num_workers, int)
        assert device.startswith('cuda') or device == 'cpu'

        logger = kwargs.get('logger', None)
        disable_mixup = kwargs.get('disable_mixup', False)

        self.backbone = self.backbone.to(device)
        self.classifier = self.classifier.to(device)

        balance = kwargs.get('balance', False)
        if logger is not None:
            logger.info(f"Class balance: {balance}")
        shuffle = not balance

        train_loader = get_dataloader(train_set,
                                      batch_size,
                                      num_workers=num_workers,
                                      shuffle=shuffle,
                                      balance=balance)
        valid_loader = get_dataloader(valid_set,
                                      batch_size,
                                      num_workers=num_workers,
                                      balance=False)

        with tqdm.tqdm(**get_tqdm_config(
                total=epochs, leave=True, color='blue')) as pbar:

            # Determine model selection metric. Defaults to 'loss'.
            eval_metric = kwargs.get('eval_metric', 'loss')
            if eval_metric == 'loss':
                best_metric_val = float('inf')
            elif eval_metric in [
                    'accuracy', 'precision', 'recall', 'f1', 'auroc', 'auprc'
            ]:
                best_metric_val = 0
            else:
                raise NotImplementedError

            best_epoch = 0
            for epoch in range(1, epochs + 1):

                # 0. Train & evaluate
                if disable_mixup:
                    train_history = self.train(train_loader, device)
                else:
                    train_history = self.train_with_mixup(train_loader, device)
                valid_history = self.evaluate(valid_loader, device)

                # 1. Epoch history (loss)
                epoch_history = {
                    'loss': {
                        'train': train_history.get('loss'),
                        'valid': valid_history.get('loss')
                    }
                }

                # 2. Epoch history (other metrics if provided)
                if isinstance(self.metrics, dict):
                    for metric_name, _ in self.metrics.items():
                        epoch_history[metric_name] = {
                            'train': train_history[metric_name],
                            'valid': valid_history[metric_name],
                        }

                # 3. Tensorboard
                if self.writer is not None:
                    for metric_name, metric_dict in epoch_history.items():
                        self.writer.add_scalars(main_tag=metric_name,
                                                tag_scalar_dict=metric_dict,
                                                global_step=epoch)
                    if self.scheduler is not None:
                        self.writer.add_scalar(
                            tag='lr',
                            scalar_value=self.scheduler.get_last_lr()[0],
                            global_step=epoch)

                # 4. Save model if it is the current best
                metric_val = epoch_history[eval_metric]['valid']
                if eval_metric == 'loss':
                    if metric_val <= best_metric_val:
                        best_metric_val = metric_val
                        best_epoch = epoch
                        self.save_checkpoint(self.best_ckpt,
                                             epoch=epoch,
                                             **epoch_history)
                elif eval_metric in ['accuracy', 'precision', 'recall', 'f1', 'auroc', 'auprc']:
                    if metric_val >= best_metric_val:
                        best_metric_val = metric_val
                        best_epoch = epoch
                        self.save_checkpoint(self.best_ckpt,
                                             epoch=epoch,
                                             **epoch_history)
                else:
                    raise NotImplementedError

                # 5. Update learning rate scheduler (optional)
                if self.scheduler is not None:
                    self.scheduler.step()

                # 6. Logging
                desc = make_epoch_description(
                    history=epoch_history,
                    current=epoch,
                    total=epochs,
                    best=best_epoch,
                )
                pbar.set_description_str(desc)
                pbar.update(1)
                if logger is not None:
                    logger.info(desc)

        # 7. Save last model
        self.save_checkpoint(self.last_ckpt, epoch=epoch, **epoch_history)

        # 8. Test model (optional)
        if 'test_set' in kwargs.keys():
            test_loader = get_dataloader(kwargs.get('test_set'),
                                         batch_size,
                                         num_workers=num_workers)
            self.test(test_loader, device=device, logger=logger)
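
Examples 6 and 8 obtain their loaders from a `get_dataloader` helper that can optionally class-balance the sampling. A sketch of one way to implement it with `WeightedRandomSampler`; it assumes the dataset exposes integer class labels through a `targets` attribute, which the snippets above do not guarantee:

from collections import Counter
from torch.utils.data import DataLoader, WeightedRandomSampler

def get_dataloader(dataset, batch_size: int, num_workers: int = 0,
                   shuffle: bool = False, balance: bool = False) -> DataLoader:
    """Hypothetical loader factory with optional class-balanced sampling."""
    if balance:
        # Weight every sample inversely to its class frequency.
        counts = Counter(int(t) for t in dataset.targets)
        weights = [1.0 / counts[int(t)] for t in dataset.targets]
        sampler = WeightedRandomSampler(weights, num_samples=len(weights), replacement=True)
        return DataLoader(dataset, batch_size, sampler=sampler, num_workers=num_workers)
    return DataLoader(dataset, batch_size, shuffle=shuffle,
                      num_workers=num_workers, drop_last=False)
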
Example 9
    def train(self, data_loader: torch.utils.data.DataLoader, **kwargs):  # pylint: disable=unused-argument
        """Train function defined for a single epoch."""

        out = {'loss': 0.}
        steps_per_epoch = len(data_loader)
        self._set_learning_phase(train=True)

        with tqdm.tqdm(**get_tqdm_config(
                steps_per_epoch, leave=False, color='green')) as pbar:
            for i, batch in enumerate(data_loader):

                j = batch['idx']
                x = batch['x'].to(self.local_rank)
                x_t = batch['x_t'].to(self.local_rank)
                z_concat = self.predict(torch.cat([x, x_t], dim=0))
                z = z_concat[:x.size(0)]
                z_t = z_concat[x.size(0):]

                m = self.memory.get_representations(j).to(self.local_rank)
                negatives = self.memory.get_negatives(self.num_negatives,
                                                      exclude=j)

                # Calculate loss
                loss_z, _ = self.loss_function(
                    anchors=m,
                    positives=z,
                    negatives=negatives,
                )
                loss_z_t, logits = self.loss_function(
                    anchors=m,
                    positives=z_t,
                    negatives=negatives,
                )
                loss = (1 - self.loss_weight) * loss_z + self.loss_weight * loss_z_t

                # Backpropagation & update
                loss.backward()
                self.optimizer.step()
                self.optimizer.zero_grad()
                self.memory.update(j, values=z.detach())

                # Accumulate loss & metrics
                out['loss'] += loss.item()
                if self.metrics is not None:
                    assert isinstance(self.metrics, dict)
                    for metric_name, metric_function in self.metrics.items():
                        if metric_name not in out.keys():
                            out[metric_name] = 0.
                        with torch.no_grad():
                            logits = logits.detach()
                            targets = torch.zeros(logits.size(0),
                                                  device=logits.device)
                            out[metric_name] += metric_function(
                                logits, targets).item()

                desc = f" Batch - [{i+1:>4}/{steps_per_epoch:>4}]: "
                desc += " | ".join(
                    [f"{k}: {v/(i+1):.4f}" for k, v in out.items()])
                pbar.set_description_str(desc)
                pbar.update(1)

        return {k: v / steps_per_epoch for k, v in out.items()}
Example 10
    def run(self,
            train_set: torch.utils.data.Dataset,
            valid_set: torch.utils.data.Dataset,
            epochs: int,
            batch_size: int,
            num_workers: int = 0,
            **kwargs):

        logger = kwargs.get('logger', None)
        save_every = kwargs.get('save_every', epochs)

        self.backbone.to(self.local_rank)
        self.projector.to(self.local_rank)

        if self.distributed:
            raise NotImplementedError
        else:
            train_loader = DataLoader(train_set,
                                      batch_size,
                                      num_workers=num_workers,
                                      shuffle=True,
                                      pin_memory=False)
            valid_loader = DataLoader(valid_set,
                                      batch_size,
                                      num_workers=num_workers,
                                      shuffle=True,
                                      pin_memory=False)

        # Initialize memory representations for the training data
        if not self.memory.initialized:
            self.memory.initialize(self.backbone, self.projector, train_loader)

        with tqdm.tqdm(**get_tqdm_config(
                total=epochs, leave=True, color='blue')) as pbar:

            best_valid_loss = float('inf')
            best_epoch = 0

            for epoch in range(1, epochs + 1):

                # 0. Train & evaluate
                train_history = self.train(train_loader)
                valid_history = self.evaluate(valid_loader)

                # 1. Epoch history (loss)
                epoch_history = {
                    'loss': {
                        'train': train_history.get('loss'),
                        'valid': valid_history.get('loss')
                    },
                }

                # 2. Epoch history (other metrics if provided)
                if self.metrics is not None:
                    assert isinstance(self.metrics, dict)
                    for metric_name, _ in self.metrics.items():
                        epoch_history[metric_name] = {
                            'train': train_history.get(metric_name),
                            'valid': valid_history.get(metric_name),
                        }

                # 3. Tensorboard
                if self.writer is not None:
                    for metric_name, metric_dict in epoch_history.items():
                        self.writer.add_scalars(main_tag=metric_name,
                                                tag_scalar_dict=metric_dict,
                                                global_step=epoch)
                    if self.scheduler is not None:
                        self.writer.add_scalar(
                            tag='lr',
                            scalar_value=self.scheduler.get_last_lr()[0],
                            global_step=epoch)

                # 4-1. Save model if it is the current best
                valid_loss = epoch_history['loss']['valid']
                if valid_loss < best_valid_loss:
                    best_valid_loss = valid_loss
                    best_epoch = epoch
                    if self.local_rank == 0:
                        self.save_checkpoint(self.best_ckpt,
                                             epoch=epoch,
                                             **epoch_history)
                        self.memory.save(
                            os.path.join(os.path.dirname(self.best_ckpt), 'best_memory.pt'),
                            epoch=epoch)

                # 4-2. Save intermediate models
                if epoch % save_every == 0:
                    if self.local_rank == 0:
                        new_ckpt = os.path.join(
                            self.checkpoint_dir,
                            f'epoch_{epoch:04d}.loss_{valid_loss:.4f}.pt')
                        self.save_checkpoint(new_ckpt,
                                             epoch=epoch,
                                             **epoch_history)

                # 5. Update learning rate scheduler
                if self.scheduler is not None:
                    self.scheduler.step()

                # 6. Logging
                desc = make_epoch_description(history=epoch_history,
                                              current=epoch,
                                              total=epochs,
                                              best=best_epoch)
                pbar.set_description_str(desc)
                pbar.update(1)

                if logger is not None:
                    logger.info(desc)

        # 7. Save last model
        if self.local_rank == 0:
            self.save_checkpoint(self.last_ckpt, epoch=epoch, **epoch_history)
            self.memory.save(os.path.join(os.path.dirname(self.last_ckpt),
                                          'last_memory.pt'),
                             epoch=epoch)
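
Example 10 additionally expects the memory bank to be filled once from the training data (`memory.initialize`) and snapshotted next to the checkpoints (`memory.save`). Continuing the hypothetical `FeatureMemory` sketch shown earlier, these two methods could look roughly like:

    @torch.no_grad()
    def initialize(self, backbone, projector, data_loader, device='cuda'):
        """Fill the bank with projected features of every training sample."""
        backbone.eval()
        projector.eval()
        for batch in data_loader:
            z = projector(backbone(batch['x'].to(device)))
            self.bank[batch['idx']] = z.detach().cpu()
        self.initialized = True

    def save(self, path: str, **metadata):
        """Persist the bank plus arbitrary metadata (e.g. the epoch) to disk."""
        torch.save({'bank': self.bank, **metadata}, path)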