def _set_metrics(self):
    num_classes = self.num_classes
    # Train
    self.train_acc = torchmetrics.Accuracy()
    self.train_precision = torchmetrics.Precision()
    self.train_recall = torchmetrics.Recall()
    self.train_f1 = torchmetrics.F1(num_classes=num_classes) if num_classes else None
    self.train_auc = torchmetrics.AUROC(num_classes=num_classes) if num_classes else None
    # Validation
    self.validation_acc = torchmetrics.Accuracy()
    self.validation_precision = torchmetrics.Precision()
    self.validation_recall = torchmetrics.Recall()
    self.validation_f1 = torchmetrics.F1(num_classes=num_classes) if num_classes else None
    self.validation_auc = torchmetrics.AUROC(num_classes=num_classes) if num_classes else None
    # Test
    self.test_acc = torchmetrics.Accuracy()
    self.test_precision = torchmetrics.Precision()
    self.test_recall = torchmetrics.Recall()
    self.test_f1 = torchmetrics.F1(num_classes=num_classes) if num_classes else None
    self.test_auc = torchmetrics.AUROC(num_classes=num_classes) if num_classes else None
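# --- Added sketch (not from the original module): a minimal training_step showing
# how the train-stage metrics above could be updated and logged in a LightningModule.
# The (x, y) batch layout, self.log usage, and the functional import F are assumptions.
def training_step(self, batch, batch_idx):
    x, y = batch
    logits = self(x)
    loss = F.cross_entropy(logits, y)
    preds = torch.softmax(logits, dim=-1)
    self.train_acc.update(preds, y)
    if self.train_f1 is not None:  # F1/AUROC are only built when num_classes was given
        self.train_f1.update(preds, y)
    self.log('train_loss', loss)
    self.log('train_acc', self.train_acc, on_step=False, on_epoch=True)
    return loss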
def __init__(self, batch_size, lr_scheduler_milestones, lr_gamma, nclass, nfeatures,
             length, lr=1e-2, L2_reg=1e-3, top_acc=1, loss=torch.nn.CrossEntropyLoss()):
    super().__init__()
    self.batch_size = batch_size
    self.nclass = nclass
    self.nfeatures = nfeatures
    self.length = length
    self.loss = loss
    self.lr = lr
    self.lr_scheduler_milestones = lr_scheduler_milestones
    self.lr_gamma = lr_gamma
    self.L2_reg = L2_reg
    # Log hyperparams (all arguments are logged by default)
    self.save_hyperparameters(
        'length', 'nfeatures', 'L2_reg', 'lr', 'lr_gamma',
        'lr_scheduler_milestones', 'batch_size', 'nclass'
    )
    # Metrics to log
    if not top_acc < nclass:
        raise ValueError('`top_acc` must be strictly smaller than `nclass`.')
    self.train_acc = torchmetrics.Accuracy(top_k=top_acc)
    self.val_acc = torchmetrics.Accuracy(top_k=top_acc)
    self.train_f1 = torchmetrics.F1(nclass, average='macro')
    self.val_f1 = torchmetrics.F1(nclass, average='macro')
    self.features = nn.Sequential(
        nn.Conv2d(in_channels=1, out_channels=20, kernel_size=(3, 5), stride=1, padding=(1, 2)),
        nn.BatchNorm2d(20),
        nn.ReLU(True),
        nn.Conv2d(in_channels=20, out_channels=20, kernel_size=(3, 5), stride=1, padding=(1, 2)),
        nn.BatchNorm2d(20),
        nn.ReLU(True),
        nn.Conv2d(in_channels=20, out_channels=20, kernel_size=(3, 5), stride=1, padding=(1, 2)),
        nn.BatchNorm2d(20),
        nn.ReLU(True),
        nn.Conv2d(in_channels=20, out_channels=20, kernel_size=(3, 5), stride=1, padding=(1, 2)),
        nn.BatchNorm2d(20),
        nn.ReLU(True),
        nn.Conv2d(in_channels=20, out_channels=20, kernel_size=(3, 3), stride=1, padding=(1, 1)),
        nn.BatchNorm2d(20),
        nn.ReLU(True),
        nn.Conv2d(in_channels=20, out_channels=nfeatures, kernel_size=(3, 3), stride=1, padding=(1, 1)),
        nn.BatchNorm2d(nfeatures),
        nn.ReLU(True)
    )
    self.pool = nn.AvgPool2d(kernel_size=(2, self.length))
    self.classifier = nn.Sequential(
        nn.Linear(1 * nfeatures, nclass),  # 1 because global pooling reduces the feature length to 1
        # nn.Softmax(1)  # Not needed: nn.CrossEntropyLoss applies log-softmax internally
    )
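# --- Added sketch (assumed, not part of the original class): the forward pass
# implied by the layers above. The AvgPool2d collapses the feature map to 1x1
# only if it is (2, length) at that point, which is an assumption here.
def forward(self, x):
    x = self.features(x)       # (N, nfeatures, H, W)
    x = self.pool(x)           # global average pool -> (N, nfeatures, 1, 1)
    x = torch.flatten(x, 1)    # (N, nfeatures)
    return self.classifier(x)  # raw logits; CrossEntropyLoss handles the softmax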
def __init__(self, model, lr: float = 1e-4, augmentations: Optional[nn.Module] = None):
    super().__init__()
    self.model = model
    self.arch = self.model.arch
    self.num_classes = self.model.num_classes
    self.train_accuracy = torchmetrics.Accuracy()
    self.train_f1_score = torchmetrics.F1(self.num_classes, average='weighted')
    self.val_accuracy = torchmetrics.Accuracy()
    self.val_f1_score = torchmetrics.F1(self.num_classes, average='weighted')
    self.learn_rate = lr
    self.aug = augmentations
def __init__(self):
    super().__init__()
    self.pretrained_model = models.alexnet(pretrained=True)
    self.pooling_layer = nn.AdaptiveAvgPool2d(1)
    self.classifier = nn.Linear(256, 3)  # AlexNet's feature extractor outputs 256 channels
    self.sigmoid = torch.sigmoid
    # self.save_hyperparameters()
    self.train_f1 = torchmetrics.F1(num_classes=3)
    self.valid_f1 = torchmetrics.F1(num_classes=3)
    self.train_auc = torchmetrics.AUROC(num_classes=3, compute_on_step=False)
    self.valid_auc = torchmetrics.AUROC(num_classes=3, compute_on_step=False)
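# --- Added sketch (hypothetical, not in the original snippet): a forward pass
# matching the layers above. AlexNet's convolutional backbone yields 256 channels,
# which the adaptive average pool reduces to the 256 features the head expects.
def forward(self, x):
    features = self.pretrained_model.features(x)  # (N, 256, H', W')
    pooled = self.pooling_layer(features)         # (N, 256, 1, 1)
    return self.classifier(pooled.flatten(1))     # logits for 3 classes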
def __init__(self, cfg):
    super(LitPlantModule2, self).__init__()
    self.num_classes = cfg.num_classes
    assert cfg.model_type in cfg.supported_model_type
    self.cfg = cfg
    self.model = eval(cfg.model_type)(cfg)  # instantiate the model class named in the config
    self.acc = torchmetrics.Accuracy()
    self.f1_5 = torchmetrics.F1(num_classes=cfg.num_classes, average='weighted')
    self.f1_1 = torchmetrics.F1(num_classes=2)
    self.lr = cfg.lr
    self.loss1 = AsymmetricLoss()
    self.loss2 = AsymmetricLoss()
    self.smooth = cfg.smooth
def __init__(self, model_name_or_path: str, num_labels: int, learning_rate: float,
             adam_epsilon: float, weight_decay: float, max_len: int, warmup_steps: int,
             gpus: int, max_epochs: int, accumulate_grad_batches: int):
    super().__init__()
    self.model_name_or_path = model_name_or_path
    self.num_labels = num_labels
    self.save_hyperparameters('learning_rate', 'adam_epsilon', 'weight_decay', 'max_len',
                              'gpus', 'accumulate_grad_batches', 'max_epochs', 'warmup_steps')
    self.config = transformers.AutoConfig.from_pretrained(
        model_name_or_path, num_labels=self.num_labels)
    self.model = transformers.AutoModelForSequenceClassification.from_pretrained(
        model_name_or_path, config=self.config)
    # self.model = nn.Sequential(
    #     OrderedDict([
    #         ('base', transformers.AutoModel.from_pretrained(model_name_or_path)),
    #         ('classifier', nn.Linear(in_features=768, out_features=self.num_labels)),
    #         ('softmax', nn.Softmax())
    #     ])
    # )
    metrics = torchmetrics.MetricCollection([
        torchmetrics.Accuracy(),
        torchmetrics.F1(num_classes=self.num_labels, average='macro')
    ])
    self.train_metrics = metrics.clone()
    self.val_metrics = metrics.clone()
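# --- Added usage note (assumed, not from the original class): a MetricCollection
# can be called like a single metric and logged in one call; cloning with a
# prefix keeps train/val keys distinct in the logger.
metrics = torchmetrics.MetricCollection([
    torchmetrics.Accuracy(),
    torchmetrics.F1(num_classes=3, average='macro'),
])
train_metrics = metrics.clone(prefix='train_')
val_metrics = metrics.clone(prefix='val_')
# inside a step: self.log_dict(self.train_metrics(preds, targets))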
def validation(epoch, model, dataloader, criterion, device):
    running_loss = 0.0
    num_inputs = 0
    model.eval()
    metric = torchmetrics.F1(num_classes=6, threshold=0.5, average='samples')
    metric = metric.to(device)
    pbar = tqdm(dataloader)
    with torch.no_grad():  # no gradients needed during validation
        for idx, (inputs, labels) in enumerate(pbar):
            inputs = inputs.to(device)
            labels = labels.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            metric(torch.sigmoid(outputs), labels)  # accumulates state for compute()
            num_inputs += inputs.size(0)
            running_loss += loss.item() * inputs.size(0)
            pbar.set_description(
                "[{:02d} epoch][Valid] Loss: {:.6f} F1 Score: {:.5f}".format(
                    epoch, running_loss / num_inputs, metric.compute()))
    epoch_loss = running_loss / num_inputs
    epoch_score = metric.compute().item()
    return epoch_loss, epoch_score
def get_metrics_collections_base(NUM_CLASS, prefix):
    metrics = MetricCollection(
        {
            "Accuracy": Accuracy(),
            "Top_3": Accuracy(top_k=3),
            "Top_5": Accuracy(top_k=5),
            "Precision_micro": Precision(num_classes=NUM_CLASS, average="micro"),
            "Precision_macro": Precision(num_classes=NUM_CLASS, average="macro"),
            "Recall_micro": Recall(num_classes=NUM_CLASS, average="micro"),
            "Recall_macro": Recall(num_classes=NUM_CLASS, average="macro"),
            "F1_micro": torchmetrics.F1(NUM_CLASS, average="micro"),
            "F1_macro": torchmetrics.F1(NUM_CLASS, average="macro"),
        },
        prefix=prefix)
    return metrics
def get_metrics_collections_base(
    NUM_CLASS,
    # device="cuda" if torch.cuda.is_available() else "cpu",
):
    metrics = MetricCollection(
        {
            "Accuracy": Accuracy(),
            "Top_3": Accuracy(top_k=3),
            "Top_5": Accuracy(top_k=5),
            "Precision_micro": Precision(num_classes=NUM_CLASS, average="micro"),
            "Precision_macro": Precision(num_classes=NUM_CLASS, average="macro"),
            "Recall_micro": Recall(num_classes=NUM_CLASS, average="micro"),
            "Recall_macro": Recall(num_classes=NUM_CLASS, average="macro"),
            "F1_micro": torchmetrics.F1(NUM_CLASS, average="micro"),
            "F1_macro": torchmetrics.F1(NUM_CLASS, average="macro"),
        }
    )
    return metrics
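# --- Added lifecycle example (assumed; `batches` is a placeholder iterable):
# torchmetrics objects are stateful, so a typical epoch is update() per batch,
# compute() once at the end, then reset() before the next epoch.
metrics = get_metrics_collections_base(NUM_CLASS=10)
for probs, targets in batches:
    metrics.update(probs, targets)
epoch_results = metrics.compute()  # dict keyed by the names above
metrics.reset()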
def valid_epoch(model, valid_loader, criterion, epoch):
    model.eval()
    total_loss = AverageMeter()
    manual_top1 = AverageMeter()
    manual_top5 = AverageMeter()
    torch_top1 = torchmetrics.Accuracy()
    torch_top5 = torchmetrics.Accuracy(top_k=5)
    torch_f1 = torchmetrics.F1(num_classes=312)
    with torch.no_grad():
        for batch in tqdm(valid_loader):
            images = batch["image"].to(device)
            elas = batch["ela"].to(device)
            target_labels = batch["label"].to(device)
            out_logits, _ = model(images, elas)
            loss = criterion(out_logits, target_labels)
            # ---------------------Batch Loss Update-------------------------
            total_loss.update(loss.item(), valid_loader.batch_size)
            # Metric
            out_logits = out_logits.cpu().detach()
            target_labels = target_labels.cpu().detach()
            topk = topk_accuracy(out_logits, target_labels, topk=(1, 5))
            manual_top1.update(topk[0].item(), valid_loader.batch_size)
            manual_top5.update(topk[1].item(), valid_loader.batch_size)
            torch_top1.update(torch.softmax(out_logits, dim=-1), target_labels)
            torch_top5.update(torch.softmax(out_logits, dim=-1), target_labels)
            torch_f1.update(torch.softmax(out_logits, dim=-1), target_labels)
    valid_metrics = {
        "valid_loss": total_loss.avg,
        "valid_acc1_manual": manual_top1.avg,
        "valid_acc5_manual": manual_top5.avg,
        "valid_acc1_torch": torch_top1.compute().item(),
        "valid_acc5_torch": torch_top5.compute().item(),
        "valid_f1": torch_f1.compute().item(),
        "epoch": epoch,
    }
    wandb.log(valid_metrics)
    return valid_metrics
def __init__(
    self,
    model,
    lr=0.005,
    timestamps=True,
    num_classes=7,
    class_weights=None,
    log_to="training.csv",
    eps=1e-8,
):
    super().__init__()
    self.model = model
    self.lr = lr
    self.timestamps = timestamps
    self.num_classes = num_classes
    self.class_weights = class_weights
    self.accuracy = torchmetrics.Accuracy()
    self.f1 = torchmetrics.F1(num_classes, average="weighted")
def __init__(self, hparams: Namespace) -> None:
    super(Classifier, self).__init__()
    self.hparams = hparams
    self.batch_size = hparams.batch_size
    # Build Data module
    self.data = self.DataModule(self)
    # Build model
    self.__build_model()
    # Loss criterion initialization
    self.__build_loss()
    if hparams.nr_frozen_epochs > 0:
        self.freeze_encoder()
    else:
        self._frozen = False
    self.nr_frozen_epochs = hparams.nr_frozen_epochs
    self.test_conf_matrices = []
    # Set up multi-label binarizer
    self.mlb = MultiLabelBinarizer()
    self.mlb.fit([self.hparams.top_codes])
    self.acc = torchmetrics.Accuracy()
    self.f1 = torchmetrics.F1(num_classes=self.hparams.n_labels, average='micro')
    self.auroc = torchmetrics.AUROC(num_classes=self.hparams.n_labels, average='weighted')
    # NOTE: could try 'global' instead of samplewise for mdmc reduction
    self.prec = torchmetrics.Precision(num_classes=self.hparams.n_labels, is_multiclass=False)
    self.recall = torchmetrics.Recall(num_classes=self.hparams.n_labels, is_multiclass=False)
    self.confusion_matrix = torchmetrics.ConfusionMatrix(num_classes=self.hparams.n_labels)
    self.test_predictions = None
    self.test_labels = None
def train(epoch, model, dataloader, criterion, optimizer, device):
    running_loss = 0.0
    num_inputs = 0
    model.train()
    metric = torchmetrics.F1(num_classes=6, threshold=0.5, average='samples')
    metric = metric.to(device)
    pbar = tqdm(dataloader)
    for idx, (inputs, labels) in enumerate(pbar):
        inputs = inputs.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        metric(torch.sigmoid(outputs), labels)  # accumulates state for compute()
        if APEX:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()
        optimizer.step()
        num_inputs += inputs.size(0)
        running_loss += loss.item() * inputs.size(0)
        pbar.set_description(
            "[{:02d} epoch][Train] Loss: {:.6f} F1 Score: {:.5f}".format(
                epoch, running_loss / num_inputs, metric.compute()))
    epoch_loss = running_loss / num_inputs
    epoch_score = metric.compute().item()
    return epoch_loss, epoch_score
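# --- Added driver sketch (placeholder names such as NUM_EPOCHS; not from the
# original script): wiring the train/validation helpers above into an epoch loop.
for epoch in range(1, NUM_EPOCHS + 1):
    train_loss, train_f1 = train(epoch, model, train_loader, criterion, optimizer, device)
    valid_loss, valid_f1 = validation(epoch, model, valid_loader, criterion, device)
    print(f"epoch {epoch}: train F1 {train_f1:.5f}, valid F1 {valid_f1:.5f}")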
def __init__(
    self,
    *args,
    num_classes: Optional[int] = None,
    loss_fn: Optional[Callable] = None,
    metrics: Union[torchmetrics.Metric, Mapping, Sequence, None] = None,
    multi_label: bool = False,
    serializer: Optional[Union[Serializer, Mapping[str, Serializer]]] = None,
    **kwargs,
) -> None:
    if metrics is None:
        metrics = torchmetrics.F1(num_classes) if (multi_label and num_classes) else torchmetrics.Accuracy()
    if loss_fn is None:
        loss_fn = binary_cross_entropy_with_logits if multi_label else F.cross_entropy
    super().__init__(
        *args,
        loss_fn=loss_fn,
        metrics=metrics,
        serializer=serializer or Classes(multi_label=multi_label),
        **kwargs,
    )
def __init__(self, settings, checkpoint_path, train_dataset, test_dataset):
    self.settings = settings
    self.checkpoint_path = checkpoint_path
    self.learning_rate = self.settings.learning_rate
    self.epochs = self.settings.epochs
    self.summary_writer = None
    if self.settings.write_statistics:
        self.summary_writer = SummaryWriter(log_dir=os.path.join(checkpoint_path, 'runs'))
    self.optimizer = None
    # self.scheduler = None
    self.f1 = 0
    self.global_train_index = 0
    self.last_image = None
    self.last_prob_map = None
    self.last_labels = None
    self.last_warped_image = None
    self.last_warped_prob_map = None
    self.last_warped_labels = None
    self.last_valid_mask = None
    print(f'Trainer is initialized with batch size = {self.settings.batch_size}')
    print(f'Gradient accumulation batch size divider = {self.settings.batch_size_divider}')
    print(f'Automatic Mixed Precision = {self.settings.use_amp}')
    batch_size = self.settings.batch_size // self.settings.batch_size_divider
    self.train_dataset = train_dataset
    self.test_dataset = test_dataset
    self.train_dataloader = DataLoader(self.train_dataset, batch_size=batch_size, shuffle=True,
                                       num_workers=self.settings.data_loader_num_workers)
    self.test_dataloader = DataLoader(self.test_dataset, batch_size=batch_size, shuffle=True,
                                      num_workers=self.settings.data_loader_num_workers)
    self.scaler = torch.cuda.amp.GradScaler()
    self.model = None
    self.softmax = torch.nn.Softmax(dim=1)
    self.f1_metric = torchmetrics.F1(num_classes=65, mdmc_average='samplewise')
def __init__(self, model, lr: float = 1e-4, augmentations: Optional[nn.Module] = None):
    super().__init__(model, lr, augmentations)
    self.val_f1_score = torchmetrics.F1(self.num_classes, multilabel=True, average='weighted')
    self.loss = nn.BCEWithLogitsLoss()
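# --- Added sketch (assumed, not in the original subclass): a validation step
# consistent with the multilabel setup above. BCEWithLogitsLoss takes raw logits,
# while the F1 metric receives sigmoid probabilities against multi-hot targets.
def validation_step(self, batch, batch_idx):
    x, y = batch                                   # y: multi-hot (N, num_classes)
    logits = self.model(x)
    loss = self.loss(logits, y.float())
    self.val_f1_score.update(torch.sigmoid(logits), y.int())
    self.log('val_loss', loss)
    return loss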
input_size = 784
hidden_size = 32
num_classes = 1

model = MnistModel(input_size, hidden_size, num_classes)
optim = torch.optim.Adam(model.parameters(), 0.0001)
callbacks = []
metrics = {
    "acc": tm.Accuracy(),
    "precision": tm.Precision(),
    "recall": tm.Recall(),
    "f1": tm.F1(),
    # "ss": tm.StatScores(),
}
model.compile(loss=binary_cross_entropy_weighted_focal_loss, optimizer=optim, metrics=metrics)

trainer = pl.Trainer(logger=False, max_epochs=5, callbacks=callbacks)
trainer.fit(model, train_loader, val_loader)
print(model.get_history())

df = pd.DataFrame(model.get_history())
df.to_csv('pretrained.csv', index=False)
def __init__(self, model_type, num_classes, optimizer, scheduler, classes_weights, learning_rate=0.0001):
    super().__init__()
    # log hyperparameters
    self.save_hyperparameters()
    self.learning_rate = learning_rate
    self.num_classes = num_classes
    self.optimizer = optimizer
    self.scheduler = scheduler
    self.model_type = model_type
    self.optimizers = optimizer
    self.schedulers = scheduler
    # load network
    if self.model_type in [
        # models whose head is `classifier`
        'densenet121', 'densenet161', 'densenet169', 'densenet201', 'densenetblur121d',
        'dpn68', 'dpn68b', 'dpn92', 'dpn98', 'dpn107', 'dpn131',
        'efficientnet_b0', 'efficientnet_b1', 'efficientnet_b1_pruned', 'efficientnet_b2',
        'efficientnet_b2a', 'efficientnet_b3', 'efficientnet_b3_pruned', 'efficientnet_b3a',
        'efficientnet_em', 'efficientnet_es', 'efficientnet_lite0', 'fbnetc_100',
        'hrnet_w18', 'hrnet_w18_small', 'hrnet_w18_small_v2', 'hrnet_w30', 'hrnet_w32',
        'hrnet_w40', 'hrnet_w44', 'hrnet_w48', 'hrnet_w64',
        'mixnet_l', 'mixnet_m', 'mixnet_s', 'mixnet_xl', 'mnasnet_100',
        'mobilenetv2_100', 'mobilenetv2_110d', 'mobilenetv2_120d', 'mobilenetv2_140',
        'mobilenetv3_large_100', 'mobilenetv3_rw', 'semnasnet_100', 'spnasnet_100',
        'tf_efficientnet_b0', 'tf_efficientnet_b0_ap', 'tf_efficientnet_b0_ns',
        'tf_efficientnet_b1', 'tf_efficientnet_b1_ap', 'tf_efficientnet_b1_ns',
        'tf_efficientnet_b2', 'tf_efficientnet_b2_ap', 'tf_efficientnet_b2_ns',
        'tf_efficientnet_b3', 'tf_efficientnet_b3_ap', 'tf_efficientnet_b3_ns',
        'tf_efficientnet_b4', 'tf_efficientnet_b4_ap', 'tf_efficientnet_b4_ns',
        'tf_efficientnet_b5', 'tf_efficientnet_b5_ap', 'tf_efficientnet_b5_ns',
        'tf_efficientnet_b6', 'tf_efficientnet_b6_ap', 'tf_efficientnet_b6_ns',
        'tf_efficientnet_b7', 'tf_efficientnet_b7_ap', 'tf_efficientnet_b7_ns',
        'tf_efficientnet_b8', 'tf_efficientnet_b8_ap',
        'tf_efficientnet_cc_b0_4e', 'tf_efficientnet_cc_b0_8e', 'tf_efficientnet_cc_b1_8e',
        'tf_efficientnet_el', 'tf_efficientnet_em', 'tf_efficientnet_es',
        'tf_efficientnet_l2_ns', 'tf_efficientnet_l2_ns_475',
        'tf_efficientnet_lite0', 'tf_efficientnet_lite1', 'tf_efficientnet_lite2',
        'tf_efficientnet_lite3', 'tf_efficientnet_lite4',
        'tf_mixnet_l', 'tf_mixnet_m', 'tf_mixnet_s',
        'tf_mobilenetv3_large_075', 'tf_mobilenetv3_large_100', 'tf_mobilenetv3_large_minimal_100',
        'tf_mobilenetv3_small_075', 'tf_mobilenetv3_small_100', 'tf_mobilenetv3_small_minimal_100',
        'tv_densenet121',
    ]:
        model = timm.create_model(model_type, pretrained=True)
        in_features = model.classifier.in_features
        model.classifier = nn.Linear(in_features, self.num_classes)
        self.model = model
    elif self.model_type in [
        # models whose head is `fc`
        'adv_inception_v3', 'dla34', 'dla46_c', 'dla46x_c', 'dla60', 'dla60_res2net',
        'dla60_res2next', 'dla60x', 'dla60x_c', 'dla102', 'dla102x', 'dla102x2', 'dla169',
        'ecaresnet26t', 'ecaresnet50d', 'ecaresnet50d_pruned', 'ecaresnet50t',
        'ecaresnet101d', 'ecaresnet101d_pruned', 'ecaresnet269d', 'ecaresnetlight',
        'gluon_inception_v3', 'gluon_resnet18_v1b', 'gluon_resnet34_v1b',
        'gluon_resnet50_v1b', 'gluon_resnet50_v1c', 'gluon_resnet50_v1d', 'gluon_resnet50_v1s',
        'gluon_resnet101_v1b', 'gluon_resnet101_v1c', 'gluon_resnet101_v1d', 'gluon_resnet101_v1s',
        'gluon_resnet152_v1b', 'gluon_resnet152_v1c', 'gluon_resnet152_v1d', 'gluon_resnet152_v1s',
        'gluon_resnext50_32x4d', 'gluon_resnext101_32x4d', 'gluon_resnext101_64x4d',
        'gluon_senet154', 'gluon_seresnext50_32x4d', 'gluon_seresnext101_32x4d',
        'gluon_seresnext101_64x4d', 'gluon_xception65',
        'ig_resnext101_32x8d', 'ig_resnext101_32x16d', 'ig_resnext101_32x32d', 'ig_resnext101_32x48d',
        'inception_v3', 'res2net50_14w_8s', 'res2net50_26w_4s', 'res2net50_26w_6s',
        'res2net50_26w_8s', 'res2net50_48w_2s', 'res2net101_26w_4s', 'res2next50',
        'resnest14d', 'resnest26d', 'resnest50d', 'resnest50d_1s4x24d', 'resnest50d_4s2x40d',
        'resnest101e', 'resnest200e', 'resnest269e',
        'resnet18', 'resnet18d', 'resnet26', 'resnet26d', 'resnet34', 'resnet34d',
        'resnet50', 'resnet50d', 'resnet101d', 'resnet152d', 'resnet200d', 'resnetblur50',
        'resnext50_32x4d', 'resnext50d_32x4d', 'resnext101_32x8d',
        'selecsls42b', 'selecsls60', 'selecsls60b', 'seresnet50', 'seresnet152d',
        'seresnext26d_32x4d', 'seresnext26t_32x4d', 'seresnext50_32x4d',
        'skresnet18', 'skresnet34', 'skresnext50_32x4d',
        'ssl_resnet18', 'ssl_resnet50', 'ssl_resnext50_32x4d', 'ssl_resnext101_32x4d',
        'ssl_resnext101_32x8d', 'ssl_resnext101_32x16d',
        'swsl_resnet18', 'swsl_resnet50', 'swsl_resnext50_32x4d', 'swsl_resnext101_32x4d',
        'swsl_resnext101_32x8d', 'swsl_resnext101_32x16d',
        'tf_inception_v3', 'tv_resnet34', 'tv_resnet50', 'tv_resnet101', 'tv_resnet152',
        'tv_resnext50_32x4d', 'wide_resnet50_2', 'wide_resnet101_2', 'xception',
    ]:
        model = timm.create_model(model_type, pretrained=True)
        in_features = model.fc.in_features
        model.fc = nn.Linear(in_features, self.num_classes)
        self.model = model
    elif self.model_type in [
        # models whose head is `head.fc`
        'cspdarknet53', 'cspresnet50', 'cspresnext50',
        'dm_nfnet_f0', 'dm_nfnet_f1', 'dm_nfnet_f2', 'dm_nfnet_f3', 'dm_nfnet_f4',
        'dm_nfnet_f5', 'dm_nfnet_f6',
        'ese_vovnet19b_dw', 'ese_vovnet39b', 'gernet_l', 'gernet_m', 'gernet_s',
        'nf_regnet_b1', 'nf_resnet50', 'nfnet_l0c',
        'regnetx_002', 'regnetx_004', 'regnetx_006', 'regnetx_008', 'regnetx_016',
        'regnetx_032', 'regnetx_040', 'regnetx_064', 'regnetx_080', 'regnetx_120',
        'regnetx_160', 'regnetx_320',
        'regnety_002', 'regnety_004', 'regnety_006', 'regnety_008', 'regnety_016',
        'regnety_032', 'regnety_040', 'regnety_064', 'regnety_080', 'regnety_120',
        'regnety_160', 'regnety_320',
        'repvgg_a2', 'repvgg_b0', 'repvgg_b1', 'repvgg_b1g4', 'repvgg_b2', 'repvgg_b2g4',
        'repvgg_b3', 'repvgg_b3g4',
        'resnetv2_50x1_bitm', 'resnetv2_50x1_bitm_in21k', 'resnetv2_50x3_bitm',
        'resnetv2_50x3_bitm_in21k', 'resnetv2_101x1_bitm', 'resnetv2_101x1_bitm_in21k',
        'resnetv2_101x3_bitm', 'resnetv2_101x3_bitm_in21k', 'resnetv2_152x2_bitm',
        'resnetv2_152x2_bitm_in21k', 'resnetv2_152x4_bitm', 'resnetv2_152x4_bitm_in21k',
        'rexnet_100', 'rexnet_130', 'rexnet_150', 'rexnet_200',
        'tresnet_l', 'tresnet_l_448', 'tresnet_m', 'tresnet_m_448', 'tresnet_xl', 'tresnet_xl_448',
        'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn', 'vgg19', 'vgg19_bn',
        'xception41', 'xception65', 'xception71',
    ]:
        model = timm.create_model(model_type, pretrained=True)
        in_features = model.head.fc.in_features
        model.head.fc = nn.Linear(in_features, self.num_classes)
        self.model = model
    elif self.model_type in [
        # models whose head is `classif`
        'ens_adv_inception_resnet_v2', 'inception_resnet_v2',
    ]:
        model = timm.create_model(model_type, pretrained=True)
        in_features = model.classif.in_features
        model.classif = nn.Linear(in_features, self.num_classes)
        self.model = model
    elif self.model_type in [
        # models whose head is `last_linear`
        'inception_v4', 'legacy_senet154', 'legacy_seresnet18', 'legacy_seresnet34',
        'legacy_seresnet50', 'legacy_seresnet101', 'legacy_seresnet152',
        'legacy_seresnext26_32x4d', 'legacy_seresnext50_32x4d', 'legacy_seresnext101_32x4d',
        'nasnetalarge', 'pnasnet5large',
    ]:
        model = timm.create_model(model_type, pretrained=True)
        in_features = model.last_linear.in_features
        model.last_linear = nn.Linear(in_features, self.num_classes)
        self.model = model
    elif self.model_type in [
        # models whose head is `head`
        'vit_base_patch16_224', 'vit_base_patch16_224_in21k', 'vit_base_patch16_384',
        'vit_base_patch32_224_in21k', 'vit_base_patch32_384',
        'vit_base_resnet50_224_in21k', 'vit_base_resnet50_384',
        'vit_deit_base_distilled_patch16_224', 'vit_deit_base_distilled_patch16_384',
        'vit_deit_base_patch16_224', 'vit_deit_base_patch16_384',
        'vit_deit_small_distilled_patch16_224', 'vit_deit_small_patch16_224',
        'vit_deit_tiny_distilled_patch16_224', 'vit_deit_tiny_patch16_224',
        'vit_large_patch16_224', 'vit_large_patch16_224_in21k', 'vit_large_patch16_384',
        'vit_large_patch32_224_in21k', 'vit_large_patch32_384', 'vit_small_patch16_224',
    ]:
        model = timm.create_model(model_type, pretrained=True)
        in_features = model.head.in_features
        model.head = nn.Linear(in_features, self.num_classes)
        self.model = model
    elif self.model_type in ['senet154']:
        model = pretrainedmodels.__dict__[model_type](num_classes=1000, pretrained='imagenet')
        model.eval()
        num_features = model.last_linear.in_features
        # Replace the fully-connected layer with our linear classifier
        model.last_linear = nn.Linear(num_features, self.num_classes)
        self.model = model
    else:
        raise NotImplementedError(
            f"model_type '{self.model_type}' not implemented. Please, choose from {MODELS}")
    self.classes_weights = torch.FloatTensor(classes_weights).cuda()
    self.loss_func = nn.CrossEntropyLoss(weight=self.classes_weights)
    self.f1 = torchmetrics.F1(num_classes=self.num_classes)
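# --- Design note (an alternative, not the original approach): for the timm
# branches above, timm can rebuild the classification head itself, which avoids
# tracking per-architecture attribute names:
# model = timm.create_model(model_type, pretrained=True, num_classes=num_classes)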
def __init__(self):
    super().__init__()
    self.G = gen()
    self.D = disc()
    self.acc = torchmetrics.Accuracy()
    self.f1 = torchmetrics.F1(num_classes=2)
def train_epoch(model, train_loader, optimizer, criterion, epoch):
    model.train()
    total_loss = AverageMeter()
    manual_top1 = AverageMeter()
    manual_top5 = AverageMeter()
    torch_top1 = torchmetrics.Accuracy()
    torch_top5 = torchmetrics.Accuracy(top_k=5)
    torch_f1 = torchmetrics.F1(num_classes=312)
    for batch in tqdm(train_loader):
        images = batch["image"].to(device)
        elas = batch["ela"].to(device)
        target_labels = batch["label"].to(device)
        optimizer.zero_grad()
        out_logits, _ = model(images, elas)
        loss = criterion(out_logits, target_labels)
        loss.backward()
        optimizer.step()

        ############## SRM Step ###########
        # Re-project the constrained (Bayar-style) filter after the update:
        # zero the centre tap, normalise the off-centre weights to sum to 1,
        # then fix the centre tap to -1.
        bayer_mask = torch.zeros(3, 3, 5, 5).cuda()
        bayer_mask[:, :, 5 // 2, 5 // 2] = 1
        bayer_weight = model.module.bayer_conv.weight * (1 - bayer_mask)
        bayer_weight = (bayer_weight / torch.sum(bayer_weight, dim=(2, 3), keepdim=True)) + 1e-7
        bayer_weight -= bayer_mask
        model.module.bayer_conv.weight = nn.Parameter(bayer_weight)
        ###################################

        # ---------------------Batch Loss Update-------------------------
        total_loss.update(loss.item(), train_loader.batch_size)
        # Metric
        with torch.no_grad():
            out_logits = out_logits.cpu().detach()
            target_labels = target_labels.cpu().detach()
            topk = topk_accuracy(out_logits, target_labels, topk=(1, 5))
            manual_top1.update(topk[0].item(), train_loader.batch_size)
            manual_top5.update(topk[1].item(), train_loader.batch_size)
            torch_top1.update(torch.softmax(out_logits, dim=-1), target_labels)
            torch_top5.update(torch.softmax(out_logits, dim=-1), target_labels)
            torch_f1.update(torch.softmax(out_logits, dim=-1), target_labels)
    train_metrics = {
        "train_loss": total_loss.avg,
        "train_acc1_manual": manual_top1.avg,
        "train_acc5_manual": manual_top5.avg,
        "train_acc1_torch": torch_top1.compute().item(),
        "train_acc5_torch": torch_top5.compute().item(),
        "train_f1": torch_f1.compute().item(),
        "epoch": epoch,
    }
    wandb.log(train_metrics)
    return train_metrics