示例#1
0
def vit_b2_32x32(num_classes=10, input_channels=3):
    """Build a VisionTransformer for 32x32 inputs using 2x2 patches.

    Args:
        num_classes: Forwarded as ``num_classes``.
        input_channels: Forwarded as ``input_channels``.

    Returns:
        A ``VisionTransformer`` with hidden size 768, 12 layers, 12 heads,
        MLP dimension 3072, dropout 0.1 and attention dropout 0.0.
    """
    settings = {
        "num_classes": num_classes,
        "input_channels": input_channels,
        "input_size": 32,
        # Patch size is the floor of 32 / (224 / 16) = 2.286.
        "patch_size": 2,
        "hidden_size": 768,
        "num_layers": 12,
        "num_heads": 12,
        "mlp_dim": 3072,
        "drop_rate": 0.1,
        "attn_drop_rate": 0.0,
    }
    return VisionTransformer(**settings)
示例#2
0
def vit_b16_224x244(num_classes=10, input_channels=3):
    """Build a VisionTransformer for 224x224 inputs using 16x16 patches.

    Args:
        num_classes: Forwarded as ``num_classes``.
        input_channels: Forwarded as ``input_channels``.

    Returns:
        A ``VisionTransformer`` with hidden size 768, 12 layers, 12 heads,
        MLP dimension 3072, dropout 0.1 and attention dropout 0.0.
    """
    settings = {
        "num_classes": num_classes,
        "input_channels": input_channels,
        "input_size": 224,
        # Patch size is the ceil of 224 / (224 / 16) = 16.
        "patch_size": 16,
        "hidden_size": 768,
        "num_layers": 12,
        "num_heads": 12,
        "mlp_dim": 3072,
        "drop_rate": 0.1,
        "attn_drop_rate": 0.0,
    }
    return VisionTransformer(**settings)
示例#3
0
def vit_tiny_patchX_32x32(patch_size, num_classes=10, input_channels=3):
    """Build a small VisionTransformer for 32x32 inputs with a chosen patch size.

    Args:
        patch_size: Forwarded as ``patch_size``.
        num_classes: Forwarded as ``num_classes``.
        input_channels: Forwarded as ``input_channels``.

    Returns:
        A ``VisionTransformer`` with hidden size 512, 4 layers, 6 heads,
        MLP dimension 1024, dropout 0.1 and attention dropout 0.0.
    """
    # NOTE(review): 512 is not a multiple of the 6 heads requested here --
    # confirm that VisionTransformer accepts a non-divisible hidden size.
    settings = {
        "num_classes": num_classes,
        "input_channels": input_channels,
        "input_size": 32,
        "patch_size": patch_size,
        "hidden_size": 512,
        "num_layers": 4,
        "num_heads": 6,
        "mlp_dim": 1024,
        "drop_rate": 0.1,
        "attn_drop_rate": 0.0,
    }
    return VisionTransformer(**settings)
    def __init__(
        self,
        steps,
        learning_rate=1e-4,
        weight_decay=0.0001,
        image_size=32,
        num_classes=10,
        patch_size=4,
        dim=256,
        layers=12,
        heads=8,
        dropout_p=0.0,
        linear_warmup_ratio=0.05,
        mean=[0.4914, 0.4822, 0.4465],
        std=[0.2023, 0.1994, 0.2010],
        **_,
    ):
        """Store the training settings and create the model, loss and metric.

        Args:
            steps: Stored on ``self.steps``.
            learning_rate: Stored on ``self.learning_rate``.
            weight_decay: Stored on ``self.weight_decay``.
            image_size, num_classes, patch_size, dim, layers, heads,
                dropout_p: Forwarded unchanged to ``VisionTransformer``.
            linear_warmup_ratio: Stored on ``self.linear_warmup_ratio``.
            mean, std: Not referenced in this method body; presumably only
                recorded through ``save_hyperparameters()`` -- TODO confirm.
            **_: Any additional keyword arguments are accepted and ignored.
        """
        super().__init__()
        # Called before anything else is bound so it observes the raw
        # constructor arguments.
        self.save_hyperparameters()

        # Plain (non-module) training settings.
        self.steps = steps
        self.linear_warmup_ratio = linear_warmup_ratio
        self.learning_rate = learning_rate
        self.weight_decay = weight_decay

        # Sub-modules are assigned in the order: model, criterion, metric.
        vit_settings = {
            "image_size": image_size,
            "num_classes": num_classes,
            "patch_size": patch_size,
            "dim": dim,
            "layers": layers,
            "heads": heads,
            "dropout_p": dropout_p,
        }
        self.model = VisionTransformer(**vit_settings)
        self.criterion = torch.nn.CrossEntropyLoss()
        self.val_accuracy = pl.metrics.Accuracy()
 def __init__(self, img_size=32, patch_size=2):
     """Wrap a VisionTransformer configured for 3-channel, 10-class input.

     Args:
         img_size: Forwarded as ``img_size``.
         patch_size: Forwarded as ``patch_size``.
     """
     super(NN, self).__init__()
     # Everything except the image and patch size is fixed here.
     vit_settings = {
         "img_size": img_size,
         "patch_size": patch_size,
         "in_chans": 3,
         "num_classes": 10,
         "embed_dim": 80,
         "depth": 20,
         "num_heads": 20,
         "mlp_ratio": 4.,
         "qkv_bias": False,
         "qk_scale": None,
         "drop_rate": 0.,
         "attn_drop_rate": 0.,
         "drop_path_rate": 0.,
         "hybrid_backbone": None,
         "norm_layer": nn.LayerNorm,
     }
     self.model = VisionTransformer(**vit_settings)
class LM(pl.LightningModule):
    """Lightning module that trains a VisionTransformer with cross-entropy.

    Optimisation uses SGD with momentum 0.9 and a per-step ``LambdaLR``
    schedule built from ``linear_warmup_cosine_decay``.
    """

    def __init__(
        self,
        steps,
        learning_rate=1e-4,
        weight_decay=0.0001,
        image_size=32,
        num_classes=10,
        patch_size=4,
        dim=256,
        layers=12,
        heads=8,
        dropout_p=0.0,
        linear_warmup_ratio=0.05,
        mean=[0.4914, 0.4822, 0.4465],
        std=[0.2023, 0.1994, 0.2010],
        **_,
    ):
        """Store the training settings and create the model, loss and metric.

        Args:
            steps: Step count handed to the learning-rate schedule.
            learning_rate: Learning rate given to the SGD optimizer.
            weight_decay: Weight decay given to the SGD optimizer.
            image_size, num_classes, patch_size, dim, layers, heads,
                dropout_p: Forwarded unchanged to ``VisionTransformer``.
            linear_warmup_ratio: Fraction of ``steps`` used as warm-up steps.
            mean, std: Not referenced in this class body; presumably only
                recorded through ``save_hyperparameters()`` -- TODO confirm.
            **_: Any additional keyword arguments are accepted and ignored.
        """
        super().__init__()
        # Called before anything else is bound so it observes the raw
        # constructor arguments.
        self.save_hyperparameters()

        # Plain (non-module) training settings.
        self.steps = steps
        self.linear_warmup_ratio = linear_warmup_ratio
        self.learning_rate = learning_rate
        self.weight_decay = weight_decay

        # Sub-modules are assigned in the order: model, criterion, metric.
        vit_settings = {
            "image_size": image_size,
            "num_classes": num_classes,
            "patch_size": patch_size,
            "dim": dim,
            "layers": layers,
            "heads": heads,
            "dropout_p": dropout_p,
        }
        self.model = VisionTransformer(**vit_settings)
        self.criterion = torch.nn.CrossEntropyLoss()
        self.val_accuracy = pl.metrics.Accuracy()

    def configure_optimizers(self):
        """Return the SGD optimizer and its per-step learning-rate schedule."""
        sgd = torch.optim.SGD(
            self.model.parameters(),
            lr=self.learning_rate,
            momentum=0.9,
            weight_decay=self.weight_decay,
        )

        # The warm-up length is a truncated fraction of the total step count.
        n_warmup = int(self.linear_warmup_ratio * self.steps)
        lr_lambda = linear_warmup_cosine_decay(n_warmup, self.steps)
        lr_schedule = torch.optim.lr_scheduler.LambdaLR(sgd, lr_lambda)

        # "interval": "step" asks for the schedule to advance every step.
        return [sgd], [{"scheduler": lr_schedule, "interval": "step"}]

    def training_step(self, batch, batch_idx):
        """Compute, log (as ``train_loss``) and return the loss for a batch."""
        inputs, targets = batch
        outputs = self.model(inputs)
        batch_loss = self.criterion(outputs, targets)
        self.log("train_loss", batch_loss)
        return batch_loss

    def validation_step(self, batch, batch_idx):
        """Update the accuracy metric and return the loss for a batch."""
        inputs, targets = batch
        outputs = self.model(inputs)
        batch_loss = self.criterion(outputs, targets)
        # Called for its side effect of updating the metric; the returned
        # value is not used.
        self.val_accuracy(outputs, targets)
        return batch_loss

    def validation_epoch_end(self, losses):
        """Log the mean of ``losses`` and the accumulated accuracy."""
        mean_loss = torch.stack(losses).mean()
        self.log("valid_loss", mean_loss, prog_bar=True)
        epoch_accuracy = self.val_accuracy.compute()
        self.log("valid_acc", epoch_accuracy, prog_bar=True)
    args = parser.parse_args()

    ds = tfds.load("cifar10", as_supervised=True)
    ds_train = (ds["train"].cache().shuffle(1024).batch(
        args.batch_size).prefetch(AUTOTUNE))
    ds_test = (ds["test"].cache().batch(args.batch_size).prefetch(AUTOTUNE))

    strategy = tf.distribute.MirroredStrategy()

    with strategy.scope():
        model = VisionTransformer(
            image_size=args.image_size,
            patch_size=args.patch_size,
            num_layers=args.num_layers,
            num_classes=10,
            d_model=args.d_model,
            num_heads=args.num_heads,
            mlp_dim=args.mlp_dim,
            channels=3,
            dropout=0.1,
        )
        model.compile(
            loss=tf.keras.losses.SparseCategoricalCrossentropy(
                from_logits=True),
            optimizer=tfa.optimizers.AdamW(learning_rate=args.lr,
                                           weight_decay=args.weight_decay),
            metrics=["accuracy"],
        )

    early_stop = tf.keras.callbacks.EarlyStopping(patience=10),
    mcp = tf.keras.callbacks.ModelCheckpoint(filepath='weights/best.h5',
示例#8
0
def _validate_config(config):
    """Reject option values main() has no branch for, before any work is done."""
    if config.dataset != 'cifar':
        raise NotImplementedError(
            "Unsupported dataset: {!r}".format(config.dataset))
    if config.model not in ('vit', 'efficient', 'resnet'):
        raise NotImplementedError(
            "Unsupported model: {!r}".format(config.model))
    # `init` is never consulted for the 'efficient' model, so it is only
    # checked for the others.
    if config.model != 'efficient' and config.init not in ('normal', 'xavier',
                                                           'he'):
        raise NotImplementedError(
            "Unsupported init: {!r}".format(config.init))
    if config.phase not in ('train', 'test'):
        raise NotImplementedError(
            "Unsupported phase: {!r}".format(config.phase))
    # An optimizer recipe exists only for the 10- and 100-class setups.
    if config.phase == 'train' and config.num_classes not in (10, 100):
        raise NotImplementedError(
            "No optimizer is defined for num_classes={!r}".format(
                config.num_classes))


def _build_model(config, input_size):
    """Create the network named by ``config.model`` on the module-level ``device``."""
    if config.model == 'vit':
        return VisionTransformer(in_channels=config.in_channels,
                                 embed_dim=config.embed_dim,
                                 patch_size=config.patch_size,
                                 num_layers=config.num_layers,
                                 num_heads=config.num_heads,
                                 mlp_dim=config.mlp_dim,
                                 dropout=config.drop_out,
                                 input_size=input_size,
                                 num_classes=config.num_classes).to(device)

    if config.model == 'efficient':
        return EfficientNet.from_name(
            'efficientnet-b0', num_classes=config.num_classes).to(device)

    if config.model == 'resnet':
        model = resnet34(pretrained=False).to(device)
        # Size the new head from the layer it replaces rather than from
        # config.mlp_dim, which is a transformer option and only works here
        # when it happens to equal the backbone's feature width.
        # Assumes model.fc is an nn.Linear, as in torchvision's ResNet.
        model.fc = nn.Linear(model.fc.in_features,
                             config.num_classes).to(device)
        return model

    raise NotImplementedError("Unsupported model: {!r}".format(config.model))


def _build_optimizer(config, model):
    """Return the ``(optimizer, scheduler)`` pair for ``config.num_classes``."""
    if config.num_classes == 10:
        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=config.lr,
                                     betas=(0.5, 0.999))
        return optimizer, get_lr_scheduler(config.lr_scheduler, optimizer)

    if config.num_classes == 100:
        optimizer = torch.optim.SGD(model.parameters(),
                                    lr=config.lr,
                                    momentum=0.9,
                                    weight_decay=5e-4)
        return optimizer, get_lr_scheduler('step', optimizer)

    raise NotImplementedError(
        "No optimizer is defined for num_classes={!r}".format(
            config.num_classes))


def main(config):
    """Train or test a classifier according to ``config``.

    Args:
        config: Object exposing the attributes read here: ``seed``,
            ``weights_path``, ``plots_path``, ``dataset``, ``num_classes``,
            ``batch_size``, ``model``, ``init``, ``phase`` and, depending on
            the chosen model and phase, ``in_channels``, ``embed_dim``,
            ``patch_size``, ``num_layers``, ``num_heads``, ``mlp_dim``,
            ``drop_out``, ``lr``, ``lr_scheduler`` and ``num_epochs``.

    Raises:
        NotImplementedError: If ``dataset``, ``model``, ``init``, ``phase``
            or (when training) ``num_classes`` has an unsupported value.
            This is raised before any directory, data loader or model is
            created.
    """
    _validate_config(config)

    # For Reproducibility #
    random.seed(config.seed)
    np.random.seed(config.seed)
    torch.manual_seed(config.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(config.seed)

    # Weights and Plots Path #
    for path in (config.weights_path, config.plots_path):
        make_dirs(path)

    # Prepare Data Loader # ('cifar' is the only dataset that passes validation)
    train_loader, val_loader, test_loader = cifar_loader(
        config.num_classes, config.batch_size)
    input_size = 32

    # Prepare Networks #
    model = _build_model(config, input_size)

    # Weight Initialization # (skipped for the 'efficient' model)
    if config.model != 'efficient':
        if config.init == 'normal':
            model.apply(init_weights_normal)
        elif config.init == 'xavier':
            model.apply(init_weights_xavier)
        else:  # 'he', guaranteed by _validate_config
            model.apply(init_weights_kaiming)

    # Train #
    if config.phase == 'train':

        # Loss Function #
        criterion = nn.CrossEntropyLoss()

        # Optimizers #
        optimizer, optimizer_scheduler = _build_optimizer(config, model)

        # Constants #
        best_top1_acc = 0

        # Lists #
        train_losses, val_losses = [], []
        train_top1_accs, train_top5_accs = [], []
        val_top1_accs, val_top5_accs = [], []

        # The checkpoint location does not change between epochs.
        best_weights_file = os.path.join(
            config.weights_path,
            'BEST_{}_{}_{}.pkl'.format(model.__class__.__name__,
                                       str(config.dataset).upper(),
                                       config.num_classes))

        # Train and Validation #
        print("Training {} has started.".format(model.__class__.__name__))
        for epoch in range(config.num_epochs):

            # Train #
            train_loss, train_top1_acc, train_top5_acc = train(
                train_loader, model, optimizer, criterion, epoch, config)

            # Validation #
            val_loss, val_top1_acc, val_top5_acc = validate(
                val_loader, model, criterion, epoch, config)

            # Add items to Lists #
            train_losses.append(train_loss)
            val_losses.append(val_loss)

            train_top1_accs.append(train_top1_acc)
            train_top5_accs.append(train_top5_acc)

            val_top1_accs.append(val_top1_acc)
            val_top5_accs.append(val_top5_acc)

            # If Best Top 1 Accuracy #
            if val_top1_acc > best_top1_acc:
                best_top1_acc = val_top1_acc

                # Save Models #
                print("The best model is saved!")
                torch.save(model.state_dict(), best_weights_file)

            print("Best Top 1 Accuracy {:.2f}%\n".format(best_top1_acc))

            # Optimizer Scheduler #
            optimizer_scheduler.step()

        # Plot Losses and Accuracies #
        losses = (train_losses, val_losses)
        accs = (train_top1_accs, train_top5_accs, val_top1_accs, val_top5_accs)
        plot_metrics(losses, accs, config.plots_path, model, config.dataset,
                     config.num_classes)

        print("Training {} using {} {} finished.".format(
            model.__class__.__name__,
            str(config.dataset).upper(), config.num_classes))

    # Test #
    else:  # 'test', guaranteed by _validate_config
        test(test_loader, model, config)