Example #1
    def _set_metrics(self):
        num_classes = self.num_classes

        # Train
        self.train_acc = torchmetrics.Accuracy()
        self.train_precision = torchmetrics.Precision()
        self.train_recall = torchmetrics.Recall()
        self.train_f1 = torchmetrics.F1(
            num_classes=num_classes) if num_classes else None
        self.train_auc = torchmetrics.AUROC(
            num_classes=num_classes) if num_classes else None

        # Validation
        self.validation_acc = torchmetrics.Accuracy()
        self.validation_precision = torchmetrics.Precision()
        self.validation_recall = torchmetrics.Recall()
        self.validation_f1 = torchmetrics.F1(
            num_classes=num_classes) if num_classes else None
        self.validation_auc = torchmetrics.AUROC(
            num_classes=num_classes) if num_classes else None

        # Test
        self.test_acc = torchmetrics.Accuracy()
        self.test_precision = torchmetrics.Precision()
        self.test_recall = torchmetrics.Recall()
        self.test_f1 = torchmetrics.F1(
            num_classes=num_classes) if num_classes else None
        self.test_auc = torchmetrics.AUROC(
            num_classes=num_classes) if num_classes else None
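
A minimal sketch (not from this project) of how metrics registered this way are typically driven from a LightningModule step, assuming `import torch` and `import torch.nn.functional as F`; `self.train_acc` is the attribute created above:

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        # Metric objects accumulate state across batches; logging the metric
        # object itself lets Lightning call compute() and reset() each epoch.
        self.train_acc(torch.softmax(logits, dim=-1), y)
        self.log('train_acc', self.train_acc, on_step=True, on_epoch=True)
        return loss
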
Example #2
    def __init__(self, batch_size, lr_scheduler_milestones, lr_gamma, nclass,
                 nfeatures, length, lr=1e-2, L2_reg=1e-3, top_acc=1,
                 loss=torch.nn.CrossEntropyLoss()):
        super().__init__()

        self.batch_size = batch_size
        self.nclass = nclass
        self.nfeatures = nfeatures
        self.length = length

        self.loss = loss
        self.lr = lr
        self.lr_scheduler_milestones = lr_scheduler_milestones
        self.lr_gamma = lr_gamma
        self.L2_reg = L2_reg

        # Log hyperparams (all arguments are logged by default)
        self.save_hyperparameters(
            'length',
            'nfeatures',
            'L2_reg',
            'lr',
            'lr_gamma',
            'lr_scheduler_milestones',
            'batch_size',
            'nclass'
        )

        # Metrics to log
        if not top_acc < nclass:
            raise ValueError('`top_acc` must be strictly smaller than `nclass`.')
        self.train_acc = torchmetrics.Accuracy(top_k=top_acc)
        self.val_acc = torchmetrics.Accuracy(top_k=top_acc)
        self.train_f1 = torchmetrics.F1(nclass, average='macro')
        self.val_f1 = torchmetrics.F1(nclass, average='macro')

        self.features = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=20, kernel_size=(3,5), stride=1, padding=(1,2)),
            nn.BatchNorm2d(20),
            nn.ReLU(True),
            nn.Conv2d(in_channels=20, out_channels=20, kernel_size=(3,5), stride=1, padding=(1,2)),
            nn.BatchNorm2d(20),
            nn.ReLU(True),
            nn.Conv2d(in_channels=20, out_channels=20, kernel_size=(3,5), stride=1, padding=(1,2)),
            nn.BatchNorm2d(20),
            nn.ReLU(True),
            nn.Conv2d(in_channels=20, out_channels=20, kernel_size=(3,5), stride=1, padding=(1,2)),
            nn.BatchNorm2d(20),
            nn.ReLU(True),
            nn.Conv2d(in_channels=20, out_channels=20, kernel_size=(3,3), stride=1, padding=(1,1)),
            nn.BatchNorm2d(20),
            nn.ReLU(True),
            nn.Conv2d(in_channels=20, out_channels=nfeatures, kernel_size=(3,3), stride=1, padding=(1,1)),
            nn.BatchNorm2d(nfeatures),
            nn.ReLU(True)
        )
        self.pool = nn.AvgPool2d(kernel_size=(2, self.length))
        self.classifier = nn.Sequential(
            nn.Linear(1*nfeatures, nclass),  # in_features is 1*nfeatures because global pooling reduces the feature map to a single value per channel
            # nn.Softmax(1) is not needed: it is already included in nn.CrossEntropyLoss
        )
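
The classifier's input size works out because the average pool collapses the whole feature map, as this standalone shape check illustrates (the nfeatures and length values here are illustrative only):

    import torch
    import torch.nn as nn

    nfeatures, length = 20, 100
    x = torch.randn(8, nfeatures, 2, length)           # (batch, channels, height=2, width=length)
    pooled = nn.AvgPool2d(kernel_size=(2, length))(x)  # -> (8, nfeatures, 1, 1)
    flat = pooled.flatten(1)                           # -> (8, nfeatures), matching nn.Linear(1*nfeatures, nclass)
    print(pooled.shape, flat.shape)
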
Example #3
 def __init__(self, model, lr: float = 1e-4, augmentations: Optional[nn.Module] = None):
     super().__init__()
     self.model = model
     self.arch = self.model.arch
     self.num_classes = self.model.num_classes
     self.train_accuracy = torchmetrics.Accuracy()
     self.train_f1_score = torchmetrics.F1(self.num_classes, average='weighted')
     self.val_accuracy = torchmetrics.Accuracy()
     self.val_f1_score = torchmetrics.F1(self.num_classes, average='weighted')
     self.learn_rate = lr
     self.aug = augmentations
Example #4
File: model.py Project: Dehde/mrnet
    def __init__(self):
        super().__init__()
        self.pretrained_model = models.alexnet(pretrained=True)
        self.pooling_layer = nn.AdaptiveAvgPool2d(1)
        self.classifier = nn.Linear(256, 3)
        self.sigmoid = torch.sigmoid
        #self.save_hyperparameters()

        self.train_f1 = torchmetrics.F1(num_classes=3)
        self.valid_f1 = torchmetrics.F1(num_classes=3)
        # With compute_on_step=False, forward() only accumulates state;
        # compute() at epoch end returns the AUROC value.
        self.train_auc = torchmetrics.AUROC(num_classes=3,
                                            compute_on_step=False)
        self.valid_auc = torchmetrics.AUROC(num_classes=3,
                                            compute_on_step=False)
Example #5
 def __init__(self, cfg):
     super(LitPlantModule2, self).__init__()
     self.num_classes = cfg.num_classes
     assert cfg.model_type in cfg.supported_model_type
     self.cfg = cfg
      self.model = eval(cfg.model_type)(cfg)  # instantiate the class named by cfg.model_type (validated against cfg.supported_model_type above)
     self.acc = torchmetrics.Accuracy()
     self.f1_5 = torchmetrics.F1(num_classes=cfg.num_classes,
                                 average='weighted')
     self.f1_1 = torchmetrics.F1(num_classes=2)
     self.lr = cfg.lr
     self.loss1 = AsymmetricLoss()
     self.loss2 = AsymmetricLoss()
     self.smooth = cfg.smooth
Example #6
    def __init__(self, model_name_or_path: str, num_labels: int,
                 learning_rate: float, adam_epsilon: float,
                 weight_decay: float, max_len: int, warmup_steps: int,
                 gpus: int, max_epochs: int, accumulate_grad_batches: int):
        super().__init__()
        self.model_name_or_path = model_name_or_path
        self.num_labels = num_labels

        self.save_hyperparameters('learning_rate', 'adam_epsilon',
                                  'weight_decay', 'max_len', 'gpus',
                                  'accumulate_grad_batches', 'max_epochs',
                                  'warmup_steps')

        self.config = transformers.AutoConfig.from_pretrained(
            model_name_or_path, num_labels=self.num_labels)
        self.model = transformers.AutoModelForSequenceClassification.from_pretrained(
            model_name_or_path, config=self.config)
        # self.model = nn.Sequential(
        #     OrderedDict(
        #         [
        #          ('base',transformers.AutoModel.from_pretrained(model_name_or_path)),
        #          ('classifier',nn.Linear(in_features=768,out_features=self.num_labels)),
        #          ('softmax',nn.Softmax())
        #         ]
        #     )
        # )
        metrics = torchmetrics.MetricCollection([
            torchmetrics.Accuracy(),
            # note: num_classes is hard-coded to 3 here rather than derived from num_labels
            torchmetrics.F1(num_classes=3, average='macro')
        ])
        self.train_metrics = metrics.clone()
        self.val_metrics = metrics.clone()
Example #7
def validation(epoch, model, dataloader, criterion, device):

    running_loss = 0.0
    num_inputs = 0

    model.eval()

    metric = torchmetrics.F1(num_classes=6, threshold=0.5, average='samples')
    metric = metric.to(device)

    pbar = tqdm(dataloader)
    for idx, (inputs, labels) in enumerate(pbar):
        inputs = inputs.to(device)
        labels = labels.to(device)

        outputs = model(inputs)
        loss = criterion(outputs, labels)
        metric(torch.sigmoid(outputs), labels)  # updates the accumulated state read by compute() below

        num_inputs += inputs.size(0)
        running_loss += loss.item() * inputs.size(0)

        pbar.set_description(
            "[{:02d} epoch][Valid] Loss: {:.6f} F1 Score: {:.5f}".format(
                epoch, running_loss / num_inputs, metric.compute()))

    epoch_loss = running_loss / num_inputs
    epoch_score = metric.compute().item()

    return epoch_loss, epoch_score
Example #8
def get_metrics_collections_base(NUM_CLASS, prefix):

    metrics = MetricCollection(
        {
            "Accuracy": Accuracy(),
            "Top_3": Accuracy(top_k=3),
            "Top_5": Accuracy(top_k=5),
            "Precision_micro": Precision(num_classes=NUM_CLASS,
                                         average="micro"),
            "Precision_macro": Precision(num_classes=NUM_CLASS,
                                         average="macro"),
            "Recall_micro": Recall(num_classes=NUM_CLASS, average="micro"),
            "Recall_macro": Recall(num_classes=NUM_CLASS, average="macro"),
            "F1_micro": torchmetrics.F1(NUM_CLASS, average="micro"),
            "F1_macro": torchmetrics.F1(NUM_CLASS, average="micro"),
        },
        prefix=prefix)

    return metrics
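
A hypothetical usage of the collection above (class count, shapes, and prefix are illustrative): MetricCollection forwards update() and compute() to every member metric and applies the prefix to the result keys.

    import torch

    metrics = get_metrics_collections_base(NUM_CLASS=10, prefix="val/")
    preds = torch.randn(16, 10).softmax(dim=-1)
    target = torch.randint(0, 10, (16,))
    metrics.update(preds, target)
    print(metrics.compute())  # keys arrive prefixed, e.g. "val/Accuracy", "val/F1_macro"
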
Example #9
def get_metrics_collections_base(NUM_CLASS,
                                 # device="cuda" if torch.cuda.is_available() else "cpu",
                                 ):

    metrics = MetricCollection(
        {
            "Accuracy": Accuracy(),
            "Top_3": Accuracy(top_k=3),
            "Top_5": Accuracy(top_k=5),
            "Precision_micro": Precision(num_classes=NUM_CLASS, average="micro"),
            "Precision_macro": Precision(num_classes=NUM_CLASS, average="macro"),
            "Recall_micro": Recall(num_classes=NUM_CLASS, average="micro"),
            "Recall_macro": Recall(num_classes=NUM_CLASS, average="macro"),
            "F1_micro": torchmetrics.F1(NUM_CLASS, average="micro"),
            "F1_macro": torchmetrics.F1(NUM_CLASS, average="macro"),
        }
    )

    return metrics
Example #10
def valid_epoch(model, valid_loader, criterion, epoch):
    model.eval()

    total_loss = AverageMeter()
    
    manual_top1 = AverageMeter()
    manual_top5 = AverageMeter()
    torch_top1 = torchmetrics.Accuracy()
    torch_top5 = torchmetrics.Accuracy(top_k=5)
    torch_f1 = torchmetrics.F1(num_classes=312)

    with torch.no_grad():
        for batch in tqdm(valid_loader):
            images = batch["image"].to(device)
            elas = batch["ela"].to(device)
            target_labels = batch["label"].to(device)
            
            out_logits, _ = model(images, elas)

            loss = criterion(out_logits, target_labels)
            
            #---------------------Batch Loss Update-------------------------
            total_loss.update(loss.item(), valid_loader.batch_size)
                    
            # Metric updates (the enclosing torch.no_grad() already disables autograd)
            out_logits = out_logits.cpu().detach()
            target_labels = target_labels.cpu().detach()

            topk = topk_accuracy(out_logits, target_labels, topk=(1, 5))
            manual_top1.update(topk[0].item(), valid_loader.batch_size)
            manual_top5.update(topk[1].item(), valid_loader.batch_size)

            probs = torch.softmax(out_logits, dim=-1)
            torch_top1.update(probs, target_labels)
            torch_top5.update(probs, target_labels)
            torch_f1.update(probs, target_labels)


    valid_metrics = {
        "valid_loss": total_loss.avg,
        "valid_acc1_manual": manual_top1.avg,
        "valid_acc5_manual": manual_top5.avg,
        "valid_acc1_torch": torch_top1.compute().item(),
        "valid_acc_5_torch": torch_top5.compute().item(),
        "valid_f1": torch_f1.compute().item(),
        "epoch": epoch
    }
    wandb.log(valid_metrics)

    return valid_metrics
Example #11
 def __init__(
     self,
     model,
     lr=0.005,
     timestamps=True,
     num_classes=7,
     class_weights=None,
     log_to="training.csv",
     eps=1e-8,
 ):
     super().__init__()
     self.model = model
     self.lr = lr
     self.timestamps = timestamps
     self.num_classes = num_classes
     self.class_weights = class_weights
     self.accuracy = torchmetrics.Accuracy()
     self.f1 = torchmetrics.F1(num_classes, average="weighted")
Example #12
    def __init__(self, hparams: Namespace) -> None:
        super(Classifier, self).__init__()

        self.hparams = hparams
        self.batch_size = hparams.batch_size

        # Build Data module
        self.data = self.DataModule(self)

        # build model
        self.__build_model()

        # Loss criterion initialization.
        self.__build_loss()

        if hparams.nr_frozen_epochs > 0:
            self.freeze_encoder()
        else:
            self._frozen = False
        self.nr_frozen_epochs = hparams.nr_frozen_epochs

        self.test_conf_matrices = []

        # Set up multi label binarizer:
        self.mlb = MultiLabelBinarizer()
        self.mlb.fit([self.hparams.top_codes])

        self.acc = torchmetrics.Accuracy()
        self.f1 = torchmetrics.F1(num_classes=self.hparams.n_labels,
                                  average='micro')
        self.auroc = torchmetrics.AUROC(num_classes=self.hparams.n_labels,
                                        average='weighted')
        # NOTE could try 'global' instead of samplewise for mdmc reduce
        self.prec = torchmetrics.Precision(num_classes=self.hparams.n_labels,
                                           is_multiclass=False)
        self.recall = torchmetrics.Recall(num_classes=self.hparams.n_labels,
                                          is_multiclass=False)
        self.confusion_matrix = torchmetrics.ConfusionMatrix(
            num_classes=self.hparams.n_labels)

        self.test_predictions = None
        self.test_labels = None
Example #13
def train(epoch, model, dataloader, criterion, optimizer, device):

    running_loss = 0.0
    num_inputs = 0

    model.train()

    metric = torchmetrics.F1(num_classes=6, threshold=0.5, average='samples')
    metric = metric.to(device)

    pbar = tqdm(dataloader)
    for idx, (inputs, labels) in enumerate(pbar):
        inputs = inputs.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()

        outputs = model(inputs)
        loss = criterion(outputs, labels)
        metric(torch.sigmoid(outputs), labels)  # updates the accumulated state read by compute() below

        if APEX:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()
        optimizer.step()

        num_inputs += inputs.size(0)
        running_loss += loss.item() * inputs.size(0)

        pbar.set_description(
            "[{:02d} epoch][Train] Loss: {:.6f} F1 Score: {:.5f}".format(
                epoch, running_loss / num_inputs, metric.compute()))

    epoch_loss = running_loss / num_inputs
    epoch_score = metric.compute().item()

    return epoch_loss, epoch_score
Example #14
    def __init__(
        self,
        *args,
        num_classes: Optional[int] = None,
        loss_fn: Optional[Callable] = None,
        metrics: Union[torchmetrics.Metric, Mapping, Sequence, None] = None,
        multi_label: bool = False,
        serializer: Optional[Union[Serializer, Mapping[str, Serializer]]] = None,
        **kwargs,
    ) -> None:
        if metrics is None:
            metrics = torchmetrics.F1(num_classes) if (multi_label and num_classes) else torchmetrics.Accuracy()

        if loss_fn is None:
            loss_fn = binary_cross_entropy_with_logits if multi_label else F.cross_entropy
        super().__init__(
            *args,
            loss_fn=loss_fn,
            metrics=metrics,
            serializer=serializer or Classes(multi_label=multi_label),
            **kwargs,
        )
Example #15
    def __init__(self, settings, checkpoint_path, train_dataset, test_dataset):
        self.settings = settings
        self.checkpoint_path = checkpoint_path
        self.learning_rate = self.settings.learning_rate
        self.epochs = self.settings.epochs
        self.summary_writer = None
        if self.settings.write_statistics:
            self.summary_writer = SummaryWriter(log_dir=os.path.join(checkpoint_path, 'runs'))
        self.optimizer = None
        # self.scheduler = None
        self.f1 = 0
        self.global_train_index = 0
        self.last_image = None
        self.last_prob_map = None
        self.last_labels = None
        self.last_warped_image = None
        self.last_warped_prob_map = None
        self.last_warped_labels = None
        self.last_valid_mask = None
        print(f'Trainer is initialized with batch size = {self.settings.batch_size}')
        print(f'Gradient accumulation batch size divider = {self.settings.batch_size_divider}')
        print(f'Automatic Mixed Precision = {self.settings.use_amp}')

        batch_size = self.settings.batch_size // self.settings.batch_size_divider
        self.train_dataset = train_dataset
        self.test_dataset = test_dataset

        self.train_dataloader = DataLoader(self.train_dataset, batch_size=batch_size, shuffle=True,
                                           num_workers=self.settings.data_loader_num_workers)
        self.test_dataloader = DataLoader(self.test_dataset, batch_size=batch_size, shuffle=True,
                                          num_workers=self.settings.data_loader_num_workers)

        self.scaler = torch.cuda.amp.GradScaler()
        self.model = None
        self.softmax = torch.nn.Softmax(dim=1)
        self.f1_metric = torchmetrics.F1(num_classes=65, mdmc_average='samplewise')
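
A standalone sketch of the multidim-multiclass reduction configured above (old torchmetrics API; class count and shapes are illustrative): with mdmc_average='samplewise' the F1 is computed per sample over the extra dimension, then averaged across the batch.

    import torch
    import torchmetrics

    f1 = torchmetrics.F1(num_classes=5, mdmc_average='samplewise')
    preds = torch.randn(4, 5, 30).softmax(dim=1)  # (batch, classes, extra dimension)
    target = torch.randint(0, 5, (4, 30))         # (batch, extra dimension)
    print(f1(preds, target))
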
Example #16
 def __init__(self, model, lr: float = 1e-4, augmentations: Optional[nn.Module] = None):
     super().__init__(model, lr, augmentations)
     self.val_f1_score = torchmetrics.F1(self.num_classes, multilabel=True, average='weighted')
     self.loss = nn.BCEWithLogitsLoss()
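
A standalone sketch mirroring the multilabel call above (old torchmetrics API; the label count is illustrative): predictions are per-label probabilities and targets are multi-hot vectors.

    import torch
    import torchmetrics

    f1 = torchmetrics.F1(4, multilabel=True, average='weighted')
    preds = torch.sigmoid(torch.randn(8, 4))  # per-label probabilities
    target = torch.randint(0, 2, (8, 4))      # multi-hot targets
    print(f1(preds, target))
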
Example #17

input_size = 784
hidden_size = 32
num_classes = 1
model = MnistModel(input_size, hidden_size, num_classes)

optim = torch.optim.Adam(model.parameters(), 0.0001)

callbacks = []

metrics = {
    "acc": tm.Accuracy(),
    'precision': tm.Precision(),
    'recall': tm.Recall(),
    'f1': tm.F1(),
    # 'ss': tm.StatScores(),
}

model.compile(loss=binary_cross_entropy_weighted_focal_loss,
              optimizer=optim,
              metrics=metrics)

trainer = pl.Trainer(logger=False, max_epochs=5, callbacks=callbacks)
trainer.fit(model, train_loader, val_loader)

print(model.get_history())

df = pd.DataFrame(model.get_history())
df.to_csv('pretrained.csv', index=False)
Example #18
    def __init__(self,
                 model_type,
                 num_classes,
                 optimizer,
                 scheduler,
                 classes_weights,
                 learning_rate=0.0001):
        super().__init__()

        # log hyperparameters
        self.save_hyperparameters()

        self.learning_rate = learning_rate
        self.num_classes = num_classes
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.model_type = model_type

        self.optimizers = optimizer
        self.schedulers = scheduler

        # load network
        if self.model_type in [
                'densenet121',  # classifier
                'densenet161',
                'densenet169',
                'densenet201',
                'densenetblur121d',
                'dpn68',
                'dpn68b',
                'dpn92',
                'dpn98',
                'dpn107',
                'dpn131',
                'efficientnet_b0',
                'efficientnet_b1',
                'efficientnet_b1_pruned',
                'efficientnet_b2',
                'efficientnet_b2a',
                'efficientnet_b3',
                'efficientnet_b3_pruned',
                'efficientnet_b3a',
                'efficientnet_em',
                'efficientnet_es',
                'efficientnet_lite0',
                'fbnetc_100',
                'hrnet_w18',
                'hrnet_w18_small',
                'hrnet_w18_small_v2',
                'hrnet_w30',
                'hrnet_w32',
                'hrnet_w40',
                'hrnet_w44',
                'hrnet_w48',
                'hrnet_w64',
                'mixnet_l',
                'mixnet_m',
                'mixnet_s',
                'mixnet_xl',
                'mnasnet_100',
                'mobilenetv2_100',
                'mobilenetv2_110d',
                'mobilenetv2_120d',
                'mobilenetv2_140',
                'mobilenetv3_large_100',
                'mobilenetv3_rw',
                'semnasnet_100',
                'spnasnet_100',
                'tf_efficientnet_b0',
                'tf_efficientnet_b0_ap',
                'tf_efficientnet_b0_ns',
                'tf_efficientnet_b1',
                'tf_efficientnet_b1_ap',
                'tf_efficientnet_b1_ns',
                'tf_efficientnet_b2',
                'tf_efficientnet_b2_ap',
                'tf_efficientnet_b2_ns',
                'tf_efficientnet_b3',
                'tf_efficientnet_b3_ap',
                'tf_efficientnet_b3_ns',
                'tf_efficientnet_b4',
                'tf_efficientnet_b4_ap',
                'tf_efficientnet_b4_ns',
                'tf_efficientnet_b5',
                'tf_efficientnet_b5_ap',
                'tf_efficientnet_b5_ns',
                'tf_efficientnet_b6',
                'tf_efficientnet_b6_ap',
                'tf_efficientnet_b6_ns',
                'tf_efficientnet_b7',
                'tf_efficientnet_b7_ap',
                'tf_efficientnet_b7_ns',
                'tf_efficientnet_b8',
                'tf_efficientnet_b8_ap',
                'tf_efficientnet_cc_b0_4e',
                'tf_efficientnet_cc_b0_8e',
                'tf_efficientnet_cc_b1_8e',
                'tf_efficientnet_el',
                'tf_efficientnet_em',
                'tf_efficientnet_es',
                'tf_efficientnet_l2_ns',
                'tf_efficientnet_l2_ns_475',
                'tf_efficientnet_lite0',
                'tf_efficientnet_lite1',
                'tf_efficientnet_lite2',
                'tf_efficientnet_lite3',
                'tf_efficientnet_lite4',
                'tf_mixnet_l',
                'tf_mixnet_m',
                'tf_mixnet_s',
                'tf_mobilenetv3_large_075',
                'tf_mobilenetv3_large_100',
                'tf_mobilenetv3_large_minimal_100',
                'tf_mobilenetv3_small_075',
                'tf_mobilenetv3_small_100',
                'tf_mobilenetv3_small_minimal_100',
                'tv_densenet121',
        ]:
            model = timm.create_model(model_type, pretrained=True)
            in_features = model.classifier.in_features
            model.classifier = nn.Linear(in_features, self.num_classes)
            self.model = model

        elif self.model_type in [
                'adv_inception_v3',  # fc
                'dla34',
                'dla46_c',
                'dla46x_c',
                'dla60',
                'dla60_res2net',
                'dla60_res2next',
                'dla60x',
                'dla60x_c',
                'dla102',
                'dla102x',
                'dla102x2',
                'dla169',
                'ecaresnet26t',
                'ecaresnet50d',
                'ecaresnet50d_pruned',
                'ecaresnet50t',
                'ecaresnet101d',
                'ecaresnet101d_pruned',
                'ecaresnet269d',
                'ecaresnetlight',
                'gluon_inception_v3',
                'gluon_resnet18_v1b',
                'gluon_resnet34_v1b',
                'gluon_resnet50_v1b',
                'gluon_resnet50_v1c',
                'gluon_resnet50_v1d',
                'gluon_resnet50_v1s',
                'gluon_resnet101_v1b',
                'gluon_resnet101_v1c',
                'gluon_resnet101_v1d',
                'gluon_resnet101_v1s',
                'gluon_resnet152_v1b',
                'gluon_resnet152_v1c',
                'gluon_resnet152_v1d',
                'gluon_resnet152_v1s',
                'gluon_resnext50_32x4d',
                'gluon_resnext101_32x4d',
                'gluon_resnext101_64x4d',
                'gluon_senet154',
                'gluon_seresnext50_32x4d',
                'gluon_seresnext101_32x4d',
                'gluon_seresnext101_64x4d',
                'gluon_xception65',
                'ig_resnext101_32x8d',
                'ig_resnext101_32x16d',
                'ig_resnext101_32x32d',
                'ig_resnext101_32x48d',
                'inception_v3',
                'res2net50_14w_8s',
                'res2net50_26w_4s',
                'res2net50_26w_6s',
                'res2net50_26w_8s',
                'res2net50_48w_2s',
                'res2net101_26w_4s',
                'res2next50',
                'resnest14d',
                'resnest26d',
                'resnest50d',
                'resnest50d_1s4x24d',
                'resnest50d_4s2x40d',
                'resnest101e',
                'resnest200e',
                'resnest269e',
                'resnet18',
                'resnet18d',
                'resnet26',
                'resnet26d',
                'resnet34',
                'resnet34d',
                'resnet50',
                'resnet50d',
                'resnet101d',
                'resnet152d',
                'resnet200d',
                'resnetblur50',
                'resnext50_32x4d',
                'resnext50d_32x4d',
                'resnext101_32x8d',
                'selecsls42b',
                'selecsls60',
                'selecsls60b',
                'seresnet50',
                'seresnet152d',
                'seresnext26d_32x4d',
                'seresnext26t_32x4d',
                'seresnext50_32x4d',
                'skresnet18',
                'skresnet34',
                'skresnext50_32x4d',
                'ssl_resnet18',
                'ssl_resnet50',
                'ssl_resnext50_32x4d',
                'ssl_resnext101_32x4d',
                'ssl_resnext101_32x8d',
                'ssl_resnext101_32x16d',
                'swsl_resnet18',
                'swsl_resnet50',
                'swsl_resnext50_32x4d',
                'swsl_resnext101_32x4d',
                'swsl_resnext101_32x8d',
                'swsl_resnext101_32x16d',
                'tf_inception_v3',
                'tv_resnet34',
                'tv_resnet50',
                'tv_resnet101',
                'tv_resnet152',
                'tv_resnext50_32x4d',
                'wide_resnet50_2',
                'wide_resnet101_2',
                'xception',
        ]:
            model = timm.create_model(model_type, pretrained=True)
            in_features = model.fc.in_features
            model.fc = nn.Linear(in_features, self.num_classes)
            self.model = model
        elif self.model_type in [
                'cspdarknet53',  # head.fc
                'cspresnet50',
                'cspresnext50',
                'dm_nfnet_f0',
                'dm_nfnet_f1',
                'dm_nfnet_f2',
                'dm_nfnet_f3',
                'dm_nfnet_f4',
                'dm_nfnet_f5',
                'dm_nfnet_f6',
                'ese_vovnet19b_dw',
                'ese_vovnet39b',
                'gernet_l',
                'gernet_m',
                'gernet_s',
                'nf_regnet_b1',
                'nf_resnet50',
                'nfnet_l0c',
                'regnetx_002',
                'regnetx_004',
                'regnetx_006',
                'regnetx_008',
                'regnetx_016',
                'regnetx_032',
                'regnetx_040',
                'regnetx_064',
                'regnetx_080',
                'regnetx_120',
                'regnetx_160',
                'regnetx_320',
                'regnety_002',
                'regnety_004',
                'regnety_006',
                'regnety_008',
                'regnety_016',
                'regnety_032',
                'regnety_040',
                'regnety_064',
                'regnety_080',
                'regnety_120',
                'regnety_160',
                'regnety_320',
                'repvgg_a2',
                'repvgg_b0',
                'repvgg_b1',
                'repvgg_b1g4',
                'repvgg_b2',
                'repvgg_b2g4',
                'repvgg_b3',
                'repvgg_b3g4',
                'resnetv2_50x1_bitm',
                'resnetv2_50x1_bitm_in21k',
                'resnetv2_50x3_bitm',
                'resnetv2_50x3_bitm_in21k',
                'resnetv2_101x1_bitm',
                'resnetv2_101x1_bitm_in21k',
                'resnetv2_101x3_bitm',
                'resnetv2_101x3_bitm_in21k',
                'resnetv2_152x2_bitm',
                'resnetv2_152x2_bitm_in21k',
                'resnetv2_152x4_bitm',
                'resnetv2_152x4_bitm_in21k',
                'rexnet_100',
                'rexnet_130',
                'rexnet_150',
                'rexnet_200',
                'tresnet_l',
                'tresnet_l_448',
                'tresnet_m',
                'tresnet_m_448',
                'tresnet_xl',
                'tresnet_xl_448',
                'vgg11',
                'vgg11_bn',
                'vgg13',
                'vgg13_bn',
                'vgg16',
                'vgg16_bn',
                'vgg19',
                'vgg19_bn',
                'xception41',
                'xception65',
                'xception71',
        ]:
            model = timm.create_model(model_type, pretrained=True)
            in_features = model.head.fc.in_features
            model.head.fc = nn.Linear(in_features, self.num_classes)
            self.model = model
        elif self.model_type in [
                'ens_adv_inception_resnet_v2',  # classif
                'inception_resnet_v2',
        ]:
            model = timm.create_model(model_type, pretrained=True)
            in_features = model.classif.in_features
            model.classif = nn.Linear(in_features, self.num_classes)
            self.model = model
        elif self.model_type in [
                'inception_v4',  # last_linear
                'legacy_senet154',
                'legacy_seresnet18',
                'legacy_seresnet34',
                'legacy_seresnet50',
                'legacy_seresnet101',
                'legacy_seresnet152',
                'legacy_seresnext26_32x4d',
                'legacy_seresnext50_32x4d',
                'legacy_seresnext101_32x4d',
                'nasnetalarge',
                'pnasnet5large',
        ]:
            model = timm.create_model(model_type, pretrained=True)
            in_features = model.last_linear.in_features
            model.last_linear = nn.Linear(in_features, self.num_classes)
            self.model = model
        elif self.model_type in [
                'vit_base_patch16_224',  # head
                'vit_base_patch16_224_in21k',
                'vit_base_patch16_384',
                'vit_base_patch32_224_in21k',
                'vit_base_patch32_384',
                'vit_base_resnet50_224_in21k',
                'vit_base_resnet50_384',
                'vit_deit_base_distilled_patch16_224',
                'vit_deit_base_distilled_patch16_384',
                'vit_deit_base_patch16_224',
                'vit_deit_base_patch16_384',
                'vit_deit_small_distilled_patch16_224',
                'vit_deit_small_patch16_224',
                'vit_deit_tiny_distilled_patch16_224',
                'vit_deit_tiny_patch16_224',
                'vit_large_patch16_224',
                'vit_large_patch16_224_in21k',
                'vit_large_patch16_384',
                'vit_large_patch32_224_in21k',
                'vit_large_patch32_384',
                'vit_small_patch16_224',
        ]:
            model = timm.create_model(model_type, pretrained=True)
            in_features = model.head.in_features
            model.head = nn.Linear(in_features, self.num_classes)
            self.model = model
        elif self.model_type in ['senet154']:
            model = pretrainedmodels.__dict__[model_type](
                num_classes=1000, pretrained='imagenet')
            model.eval()
            num_features = model.last_linear.in_features
            # Replace the fully-connected layer with our linear classifier
            model.last_linear = nn.Linear(num_features, self.num_classes)
            self.model = model
        else:
            raise ValueError(
                f"model_type '{self.model_type}' not implemented. "
                f"Please choose from {MODELS}")

        self.classes_weights = torch.FloatTensor(classes_weights).cuda()
        self.loss_func = nn.CrossEntropyLoss(weight=self.classes_weights)

        self.f1 = torchmetrics.F1(num_classes=self.num_classes)
Example #19
File: badgan.py Project: inigoval/badGAN
 def __init__(self):
     super().__init__()
     self.G = gen()
     self.D = disc()
     self.acc = torchmetrics.Accuracy()
      self.f1 = torchmetrics.F1(num_classes=2)  # note: the default average='micro' makes this equal to accuracy; use average='macro' for a class-balanced F1
Example #20
def train_epoch(model, train_loader, optimizer, criterion, epoch):
    model.train()

    total_loss = AverageMeter()
    manual_top1 = AverageMeter()
    manual_top5 = AverageMeter()
    torch_top1 = torchmetrics.Accuracy()
    torch_top5 = torchmetrics.Accuracy(top_k=5)
    torch_f1 = torchmetrics.F1(num_classes=312)

    for batch in tqdm(train_loader):
        images = batch["image"].to(device)
        elas = batch["ela"].to(device)
        target_labels = batch["label"].to(device)
        
        optimizer.zero_grad()
        
        out_logits, _ = model(images, elas)

        loss = criterion(out_logits, target_labels)
        
        loss.backward()
        optimizer.step()

        ############## SRM Step ###########
        bayer_mask = torch.zeros(3,3,5,5).cuda()
        bayer_mask[:,:,5//2, 5//2] = 1
        bayer_weight = model.module.bayer_conv.weight * (1-bayer_mask)
        bayer_weight = (bayer_weight / torch.sum(bayer_weight, dim=(2,3), keepdim=True)) + 1e-7
        bayer_weight -= bayer_mask
        model.module.bayer_conv.weight = nn.Parameter(bayer_weight)
        ###################################

        #---------------------Batch Loss Update-------------------------
        total_loss.update(loss.item(), train_loader.batch_size)
        
        # Metric
        with torch.no_grad():
            out_logits = out_logits.cpu().detach()
            target_labels = target_labels.cpu().detach()

            topk = topk_accuracy(out_logits, target_labels, topk=(1,5))
            manual_top1.update(topk[0].item(), train_loader.batch_size)
            manual_top5.update(topk[1].item(), train_loader.batch_size)

            probs = torch.softmax(out_logits, dim=-1)
            torch_top1.update(probs, target_labels)
            torch_top5.update(probs, target_labels)
            torch_f1.update(probs, target_labels)
    
        
    train_metrics = {
        "train_loss": total_loss.avg,
        "train_acc1_manual": manual_top1.avg,
        "train_acc5_manual": manual_top5.avg,
        "train_acc1_torch": torch_top1.compute().item(),
        "train_acc5_torch": torch_top5.compute().item(),
        "train_f1": torch_f1.compute().item(),
        "epoch": epoch
    }
    wandb.log(train_metrics)

    return train_metrics