Code example #1
    def _compute_batch():
        trainer = Trainer(fast_dev_run=True, strategy="horovod", logger=False)

        assert isinstance(trainer.accelerator, CPUAccelerator)
        # TODO: test that we selected the correct training_type_plugin based on horovod flags

        metric = Accuracy(
            compute_on_step=True,
            dist_sync_on_step=True,
            dist_sync_fn=trainer.training_type_plugin.all_gather,
            threshold=threshold,
        )

        # each rank evaluates a strided subset of the batches
        for i in range(hvd.rank(), num_batches, hvd.size()):
            batch_result = metric(preds[i], target[i])
            if hvd.rank() == 0:
                dist_preds = torch.stack(
                    [preds[i + r] for r in range(hvd.size())])
                dist_target = torch.stack(
                    [target[i + r] for r in range(hvd.size())])
                sk_batch_result = sk_metric(dist_preds, dist_target)
                assert np.allclose(batch_result.numpy(), sk_batch_result)

        # check the aggregated result over all batches on all ranks
        result = metric.compute()
        assert isinstance(result, torch.Tensor)

        total_preds = torch.stack([preds[i] for i in range(num_batches)])
        total_target = torch.stack([target[i] for i in range(num_batches)])
        sk_result = sk_metric(total_preds, total_target)

        assert np.allclose(result.numpy(), sk_result)
Code example #2
    def __init__(self,
                 input_size: int = 784,
                 lin1_size: int = 256,
                 lin2_size: int = 256,
                 lin3_size: int = 256,
                 output_size: int = 10,
                 lr: float = 0.001,
                 weight_decay: float = 0.0005,
                 **kwargs):
        super().__init__()

        # this line ensures that params passed to the LightningModule are saved to the checkpoint
        # it also makes the params accessible via the 'self.hparams' attribute
        self.save_hyperparameters()

        self.model = SimpleDenseNet(hparams=self.hparams)

        # loss function
        self.criterion = torch.nn.CrossEntropyLoss()

        # use separate metric instances for the train, val and test steps
        # to ensure a proper reduction over the epoch
        self.train_accuracy = Accuracy()
        self.val_accuracy = Accuracy()
        self.test_accuracy = Accuracy()
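
The constructor above is only part of the module. As a hedged illustration (not taken from the original source), a minimal training_step showing how self.criterion and the dedicated self.train_accuracy instance would typically be used; the (x, y) batch layout and the logged metric names are assumptions:

    def training_step(self, batch, batch_idx):
        # hypothetical step, not part of the original snippet;
        # assumes each batch is an (inputs, targets) pair
        x, y = batch
        logits = self.model(x)
        loss = self.criterion(logits, y)
        preds = torch.argmax(logits, dim=1)

        # the train-specific metric accumulates its state across the epoch
        self.train_accuracy(preds, y)
        self.log("train/loss", loss, on_step=False, on_epoch=True)
        self.log("train/acc", self.train_accuracy, on_step=False, on_epoch=True)
        return loss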
Code example #3
    def __init__(self, cfg):
        super().__init__()
        self.save_hyperparameters(cfg)

        self.model = ConvClassifier(in_channels=cfg.in_channels,
                                    out_features=cfg.out_features)

        # loss function (NLLLoss expects log-probabilities from the model)
        self.criterion = torch.nn.NLLLoss()

        # separate metric instances for train, val and test
        # to ensure a proper reduction over the epoch
        self.train_accuracy = Accuracy()
        self.val_accuracy = Accuracy()
        self.test_accuracy = Accuracy()
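
Since torch.nn.NLLLoss expects log-probabilities, ConvClassifier is presumably meant to end its forward pass with log_softmax. A minimal training_step sketch under that assumption (not part of the original snippet; batch layout and metric names are hypothetical):

    def training_step(self, batch, batch_idx):
        # hypothetical step, not part of the original snippet;
        # assumes the model outputs log-probabilities (log_softmax), as NLLLoss requires
        x, y = batch
        log_probs = self.model(x)
        loss = self.criterion(log_probs, y)
        self.train_accuracy(log_probs.argmax(dim=1), y)
        self.log("train/loss", loss, on_epoch=True)
        self.log("train/acc", self.train_accuracy, on_epoch=True)
        return loss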
Code example #4
    def __init__(self, model_cls: Type[LightningModule],
                 checkpoint_paths: List[str]) -> None:
        super().__init__()
        # Create `num_folds` models with their associated fold weights
        self.models = torch.nn.ModuleList(
            [model_cls.load_from_checkpoint(p) for p in checkpoint_paths])
        self.test_acc = Accuracy()
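
A matching test_step is not shown in the original snippet; a minimal sketch, assuming the fold models return logits and each batch is an (inputs, targets) pair, could average the per-fold outputs as a simple ensemble:

    def test_step(self, batch, batch_idx):
        # hypothetical step, not part of the original snippet:
        # average the per-fold outputs (simple ensemble voting)
        x, y = batch
        logits = torch.stack([m(x) for m in self.models]).mean(dim=0)
        self.test_acc(logits.argmax(dim=1), y)
        self.log("test_acc", self.test_acc, on_epoch=True)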
Code example #5
    def __init__(self,
                 architecture: str = "GCN",
                 num_node_features: int = 3,
                 activation: str = "prelu",
                 num_conv_layers: int = 3,
                 conv_size: int = 256,
                 pool_method: str = "add",
                 lin1_size: int = 128,
                 lin2_size: int = 64,
                 output_size: int = 10,
                 lr: float = 0.001,
                 weight_decay: float = 0,
                 **kwargs):
        super().__init__()

        # this line ensures that params passed to the LightningModule are saved to the checkpoint
        # it also makes the params accessible via the 'self.hparams' attribute
        self.save_hyperparameters(logger=False)

        # init network architecture
        if self.hparams.architecture == "GCN":
            self.model = gcn.GCN(hparams=self.hparams)
        elif self.hparams.architecture == "GAT":
            self.model = gat.GAT(hparams=self.hparams)
        elif self.hparams.architecture == "GraphSAGE":
            self.model = graph_sage.GraphSAGE(hparams=self.hparams)
        elif self.hparams.architecture == "GIN":
            self.model = gin.GIN(hparams=self.hparams)
        else:
            raise Exception("Incorrect architecture name!")

        # loss function
        self.criterion = torch.nn.CrossEntropyLoss()

        # use separate metric instances for the train, val and test steps
        # to ensure a proper reduction over the epoch
        self.train_accuracy = Accuracy()
        self.val_accuracy = Accuracy()
        self.test_accuracy = Accuracy()

        self.metric_hist = {
            "train/acc": [],
            "val/acc": [],
            "train/loss": [],
            "val/loss": [],
        }
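
The metric_hist dictionary suggests that epoch-level metrics are collected elsewhere in the module. One plausible validation_epoch_end that appends to it and logs the best values so far is sketched below; it is an assumption, not part of the original snippet, and presumes the validation step logs under the keys val/acc and val/loss:

    def validation_epoch_end(self, outputs):
        # hypothetical hook, not part of the original snippet;
        # assumes val/acc and val/loss were logged during validation
        self.metric_hist["val/acc"].append(self.trainer.callback_metrics["val/acc"])
        self.metric_hist["val/loss"].append(self.trainer.callback_metrics["val/loss"])
        self.log("val/acc_best", max(self.metric_hist["val/acc"]), prog_bar=False)
        self.log("val/loss_best", min(self.metric_hist["val/loss"]), prog_bar=False)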
Code example #6
    def __init__(self, show_train=False, **kwargs):
        super().__init__()

        # this line ensures that params passed to the LightningModule are saved to the checkpoint
        # it also makes the params accessible via the 'self.hparams' attribute
        # self.save_hyperparameters()

        # use separate metric instances for the train, val and test steps
        # to ensure a proper reduction over the epoch
        self.train_accuracy = Accuracy()
        self.val_accuracy = Accuracy()
        self.test_accuracy = Accuracy()

        self.train_iou = IoU(num_classes=2)
        self.val_iou = IoU(num_classes=2)
        self.test_iou = IoU(num_classes=2)
        self.show_train = show_train
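
The IoU(num_classes=2) metrics indicate a binary segmentation task. A hedged validation_step sketch (not part of the original snippet; the forward pass, batch layout, and metric names are assumptions) that feeds both the accuracy and IoU instances:

    def validation_step(self, batch, batch_idx):
        # hypothetical step, not part of the original snippet;
        # assumes (images, masks) batches and a forward pass returning per-pixel logits
        images, masks = batch
        logits = self(images)
        preds = torch.argmax(logits, dim=1)
        self.val_accuracy(preds, masks)
        self.val_iou(preds, masks)
        self.log("val/acc", self.val_accuracy, on_epoch=True)
        self.log("val/iou", self.val_iou, on_epoch=True)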
Code example #8
    def __init__(self) -> None:
        super().__init__()
        self.val_acc = Accuracy()
Code example #9
class SESEMI(pl.LightningModule):
    def __init__(self, hparams):
        super().__init__()
        self.save_hyperparameters(hparams)

        assert self.hparams.backbone in SUPPORTED_BACKBONES, f'--backbone must be one of {SUPPORTED_BACKBONES}'

        self.feature_extractor = PyTorchImageModels(self.hparams.backbone,
                                                    self.hparams.pretrained,
                                                    self.hparams.global_pool)

        if self.hparams.pretrained:
            logging.info(
                f'Initialized with pretrained {self.hparams.backbone} backbone'
            )

        if self.hparams.freeze_backbone:
            logging.info(f'Freezing {self.hparams.backbone} backbone')
            for m in self.feature_extractor.modules():
                m.eval()
                for param in m.parameters():
                    param.requires_grad = False

        self.in_features = self.feature_extractor.in_features
        self.dropout = nn.Dropout(self.hparams.dropout_rate)
        self.fc_labeled = nn.Linear(self.in_features,
                                    self.hparams.num_labeled_classes)
        self.fc_unlabeled = nn.Linear(self.in_features,
                                      self.hparams.num_unlabeled_classes)
        self.register_buffer(
            'current_learning_rate',
            torch.tensor(self.hparams.warmup_lr,
                         dtype=torch.float32,
                         device=self.device))
        self.register_buffer(
            'best_validation_top1_accuracy',
            torch.tensor(0., dtype=torch.float32, device=self.device))

        self.training_accuracy = Accuracy(top_k=1, dist_sync_on_step=True)
        self.validation_top1_accuracy = Accuracy(top_k=1)
        self.validation_average_loss = AverageMeter()

    def forward(self, x):
        features = self.feature_extractor(x)
        logits = self.fc_labeled(features)
        return F.softmax(logits, dim=-1)

    def forward_train(self, x_labeled, x_unlabeled=None):
        # Compute output for labeled input
        x_labeled = self.feature_extractor(x_labeled)
        if self.hparams.dropout_rate > 0.0:
            x_labeled = self.dropout(x_labeled)
        output_labeled = self.fc_labeled(x_labeled)

        if x_unlabeled is not None:
            # Compute output for unlabeled input and return both outputs
            x_unlabeled = self.feature_extractor(x_unlabeled)
            output_unlabeled = self.fc_unlabeled(x_unlabeled)
            return output_labeled, output_unlabeled

        return output_labeled, None

    def optimizer_step(
        self,
        epoch,
        batch_idx,
        optimizer,
        optimizer_idx,
        optimizer_closure,
        **kwargs,
    ):
        optimizer.step(closure=optimizer_closure)
        self.current_learning_rate = torch.tensor(
            adjust_polynomial_lr(optimizer.optimizer,
                                 self.global_step,
                                 warmup_iters=self.hparams.warmup_iters,
                                 warmup_lr=self.hparams.warmup_lr,
                                 lr=self.hparams.lr,
                                 lr_pow=self.hparams.lr_pow,
                                 max_iters=self.hparams.max_iters),
            dtype=self.current_learning_rate.dtype,
            device=self.current_learning_rate.device)

    def configure_optimizers(self):
        if self.hparams.optimizer.lower() == 'sgd':
            optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad,
                                               self.parameters()),
                                        lr=self.hparams.lr,
                                        momentum=self.hparams.momentum,
                                        nesterov=True,
                                        weight_decay=self.hparams.weight_decay)
        elif self.hparams.optimizer.lower() == 'adam':
            optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad,
                                                self.parameters()),
                                         lr=self.hparams.lr,
                                         betas=(self.hparams.momentum, 0.999),
                                         weight_decay=0.0)
        else:
            raise NotImplementedError()

        return optimizer

    def training_step(self, batch, batch_index):
        inputs_t, targets_t = batch['supervised']
        inputs_u, targets_u = batch.get('unsupervised_rotation', (None, None))

        # Forward pass
        outputs_t, outputs_u = self.forward_train(inputs_t, inputs_u)
        loss_t = F.cross_entropy(outputs_t, targets_t, reduction='mean')
        if outputs_u is not None:
            loss_u = F.cross_entropy(outputs_u, targets_u, reduction='mean')
        else:
            loss_u = 0.
        loss_weight = self.hparams.initial_loss_weight * sigmoid_rampup(
            self.global_step, self.hparams.stop_rampup)

        loss = loss_t + loss_u * loss_weight

        self.log('train/loss_labeled', loss_t)
        self.log('train/loss_unlabeled', loss_u)
        self.log('train/loss_unlabeled_weight', loss_weight)
        self.log('train/loss', loss)
        self.log('train/learning_rate', self.current_learning_rate)

        return {
            'loss': loss,
            'probs': F.softmax(outputs_t, dim=-1),
            'targets': targets_t
        }

    def training_step_end(self, outputs):
        self.training_accuracy(outputs['probs'], outputs['targets'])
        self.log('acc',
                 self.training_accuracy,
                 on_step=False,
                 on_epoch=True,
                 prog_bar=True,
                 logger=False)
        self.log('lr',
                 self.current_learning_rate,
                 on_step=True,
                 on_epoch=False,
                 prog_bar=True,
                 logger=False)
        loss = outputs['loss'].mean()
        return loss

    def validation_step(self, batch, batch_index):
        inputs_t, targets_t = batch
        outputs_t = self.fc_labeled(self.feature_extractor(inputs_t))
        probs_t = F.softmax(outputs_t, dim=-1)
        loss_t = F.cross_entropy(outputs_t, targets_t, reduction='none')
        return probs_t, targets_t, loss_t

    def validation_step_end(self, outputs):
        outputs_t, targets_t, loss_t = outputs
        self.validation_top1_accuracy.update(outputs_t, targets_t)
        self.validation_average_loss.update(loss_t)

    def validation_epoch_end(self, outputs):
        top1 = self.validation_top1_accuracy.compute()
        loss = self.validation_average_loss.compute()
        self.validation_top1_accuracy.reset()
        self.validation_average_loss.reset()

        if self.trainer.state.stage != RunningStage.SANITY_CHECKING:
            if top1 > self.best_validation_top1_accuracy:
                self.best_validation_top1_accuracy = torch.tensor(
                    float(top1),
                    dtype=self.best_validation_top1_accuracy.dtype,
                    device=self.best_validation_top1_accuracy.device)

            self.log('val/top1', top1)
            self.log('val/loss', loss)

            if self.global_rank == 0:
                print()
                logging.info('Epoch {:03d} =====> '
                             'Valid Loss: {:.4f}  '
                             'Valid Acc: {:.4f}  [Best {:.4f}]'.format(
                                 self.trainer.current_epoch, loss, top1,
                                 self.best_validation_top1_accuracy))
Code example #10
    def __init__(self, model, lr=1.0, gamma=0.7, batch_size=32):
        super().__init__()
        # save lr, gamma and batch_size, but keep the nn.Module itself out of the hparams
        self.save_hyperparameters(ignore="model")
        self.model = model or Net()
        self.test_acc = Accuracy()
        self.val_acc = Accuracy()
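
The saved lr and gamma hyperparameters imply an optimizer and a learning-rate scheduler configured elsewhere. One plausible configure_optimizers is sketched below; the Adadelta + StepLR combination is only illustrative and is not part of the original snippet:

    def configure_optimizers(self):
        # hypothetical method, not part of the original snippet;
        # the Adadelta + StepLR combination is only illustrative
        optimizer = torch.optim.Adadelta(self.model.parameters(), lr=self.hparams.lr)
        scheduler = torch.optim.lr_scheduler.StepLR(
            optimizer, step_size=1, gamma=self.hparams.gamma)
        return [optimizer], [scheduler]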