# Example no. 1
class RobustClassifier(LightningModule):
    """Classify an image by matching it against class-conditional generated images.

    For every class k, a small per-class MLP maps a shared noise vector to a
    class-specific latent code; a (pre-trained, optionally frozen) generator
    renders an image from that code, and a (pre-trained, optionally frozen)
    Siamese network embeds both the generated image and the input.  The
    predicted class is the one whose generated image is most similar to the
    input under cosine similarity of the Siamese embeddings.
    """

    def __init__(self, args: argparse.Namespace = None):
        """Build the per-class converters and load pre-trained sub-networks.

        Args:
            args: parsed CLI namespace; see ``add_to_argparse`` for the
                expected attributes (``num_classes``, ``latent_dim``, ...).
        """
        super().__init__()
        self.num_classes = args.num_classes
        self.latent_dim = args.latent_dim
        self.feature_dim = args.feature_dim
        self.batch_size = args.batch_size
        self.lr = args.lr

        # One noise->latent converter per class.  ModuleDict keys must be
        # strings, hence str(k); forward() converts back with int().
        self.class_noise_convertor = nn.ModuleDict({
            str(k): nn.Sequential(nn.Linear(self.feature_dim, self.latent_dim),
                                  nn.ReLU(),
                                  nn.Linear(self.latent_dim,
                                            self.latent_dim)).to(self.device)
            for k in range(self.num_classes)
        })

        self.generator = Generator(ngpu=1)
        self.generator.load_state_dict(torch.load(args.gen_weights))
        self.class_identifier = SiameseNet()
        self.class_identifier.load_state_dict(torch.load(args.siamese_weights))
        # NOTE(review): ``freeze()`` is a LightningModule method — this
        # assumes Generator/SiameseNet provide an equivalent; confirm.
        if args.generator_pre_train:
            self.generator.to(self.device).freeze()
        if args.siamese_pre_train:
            self.class_identifier.to(self.device).freeze()

        # (sic) attribute intentionally keeps the original misspelled name
        # ``creterion`` so any external code touching it keeps working.
        self.creterion = ContrastiveLoss(1)
        self.threshold = args.threshold

        self.train_acc = Accuracy()
        self.val_acc = Accuracy()
        self.test_acc = Accuracy()

    @staticmethod
    def add_to_argparse(parser):
        """Register this model's CLI options on *parser* and return it."""
        parser.add_argument("--num_classes",
                            type=int,
                            default=10,
                            help="Number of Classes")
        parser.add_argument("--lr", type=float, default=3e-4)
        parser.add_argument("--threshold", type=float, default=0.5)
        parser.add_argument("--batch_size", type=int, default=16)
        parser.add_argument("--latent_dim", type=int, default=100)
        parser.add_argument("--feature_dim", type=int, default=100)
        parser.add_argument("--gen_weights",
                            type=str,
                            default="weights/gen_weights.pth")
        parser.add_argument("--siamese_weights",
                            type=str,
                            default="weights/siamese_weights.pth")
        # Paired on/off flags sharing one dest; set_defaults makes the
        # default explicit in one place.
        parser.add_argument("--generator_pre_train",
                            dest='generator_pre_train',
                            action='store_true')
        parser.add_argument("--no_generator_pre_train",
                            dest='generator_pre_train',
                            action='store_false')
        parser.set_defaults(generator_pre_train=True)
        parser.add_argument("--siamese_pre_train",
                            dest='siamese_pre_train',
                            action='store_true')
        parser.add_argument("--no_siamese_pre_train",
                            dest='siamese_pre_train',
                            action='store_false')
        parser.set_defaults(siamese_pre_train=True)
        return parser

    def forward(self, x):
        """Predict the class of each image in the batch *x*.

        Side effect: caches the flattened Siamese embeddings of the
        generated and input images on ``self.embeddings_1`` /
        ``self.embeddings_2`` (class-major order) for ``calculate_loss``.

        Returns:
            1-D LongTensor of predicted class indices, length ``x.size(0)``.
        """
        batch_size = x.size(0)
        embeddings1, embeddings2 = torch.Tensor([]), torch.Tensor([])
        scores = torch.ones(self.num_classes, batch_size)
        noise = torch.rand(batch_size, self.feature_dim, device=self.device)
        for class_idx, model in self.class_noise_convertor.items():
            class_noise = model(noise).view(batch_size, -1, 1, 1)
            gen_imgs = self.generator(class_noise)
            embed1, embed2 = self.class_identifier(gen_imgs, x)
            embeddings1 = torch.cat(
                (embeddings1.to(self.device), embed1.to(self.device)), dim=0)
            embeddings2 = torch.cat(
                (embeddings2.to(self.device), embed2.to(self.device)), dim=0)
            scores[int(class_idx)] = nn.functional.cosine_similarity(embed1,
                                                                     embed2,
                                                                     dim=1)
        # Plain attributes instead of register_buffer(): these are per-batch
        # activations and must not be persisted into the state_dict (the
        # original re-registered buffers on every forward pass, leaking
        # batch data into checkpoints).
        self.embeddings_1 = embeddings1.view(-1, 4096)
        self.embeddings_2 = embeddings2.view(-1, 4096)
        # NOTE(review): scores[scores > 0] assumes every cosine similarity
        # is positive — a single non-positive score makes the view() fail.
        # Preserved as-is; confirm against the Siamese embedding range.
        scores = torch.softmax(scores[scores > 0].view(batch_size, -1), dim=1)
        pred = torch.argmax(scores, dim=1).to(
            self.device
        )  #torch.Tensor([torch.argmax(img_score) if (img_score.max()-img_score.min())>0.5 else -1 for img_score in scores]).to(self.device)
        return pred

    def calculate_loss(self, y_true):
        """Contrastive loss over the embeddings cached by ``forward``.

        Builds binary match labels (1 where the sample's true class equals
        the class the embedding row was generated for) in the same
        class-major order as the cached embeddings.  Pure-torch version of
        the original numpy round-trip — same values, no per-class host sync.
        """
        siamese_labels = torch.cat([
            (y_true == class_idx).float()
            for class_idx in range(self.num_classes)
        ]).to(self.device)
        result = self.creterion(self.embeddings_1, self.embeddings_2,
                                siamese_labels)
        return result

    def training_step(self, batch, batch_idx):
        """One optimization step; logs step loss and epoch accuracy."""
        img, y_true = batch
        # .to(self.device) is redundant under Lightning but kept: harmless
        # and makes the method usable outside a Trainer too.
        img, y_true = img.to(self.device), y_true.to(self.device)
        y_pred = self(img)
        result = self.calculate_loss(y_true)
        self.train_acc(y_pred, y_true)
        self.log('train_acc', self.train_acc, on_epoch=True, on_step=False)
        self.log('train_loss', result, on_step=True)

        return result

    def validation_step(self, batch, batch_idx):
        """Log validation loss (progress bar) and epoch accuracy."""
        img, y_true = batch
        img, y_true = img.to(self.device), y_true.to(self.device)
        y_pred = self(img)
        val_loss = self.calculate_loss(y_true)
        self.log('val_loss', val_loss, prog_bar=True)
        self.val_acc(y_pred, y_true)
        self.log('val_acc', self.val_acc, on_epoch=True, on_step=False)

    def test_step(self, batch, batch_idx):
        """Log test-set accuracy."""
        img, y_true = batch
        img, y_true = img.to(self.device), y_true.to(self.device)
        y_pred = self(img)
        self.test_acc(y_pred, y_true)
        self.log('test_acc', self.test_acc, on_epoch=True, on_step=False)

    def configure_optimizers(self):
        """Adam over *trainable* parameters only.

        Fix: the original filtered on ``p.requires_grad_`` — the in-place
        *method* object, which is always truthy — so the filter was a no-op
        and frozen pre-trained weights were handed to the optimizer.
        ``p.requires_grad`` is the boolean flag.
        """
        optimizer = optim.Adam(
            (p for p in self.parameters() if p.requires_grad), lr=self.lr)
        return [optimizer], []
# Sanity check: sizes of the Siamese train/test datasets built earlier.
print(len(siamese_test_dataset))
print(len(siamese_train_dataset))
# a,b = siamese_test_dataset[0]
# print(b)
# print(type(a))

# Labeled-train / unlabeled-train / validation / test loaders, split per
# the CLI arguments (train_size controls the labeled fraction).
trainloader_l, trainloader_u, valloader, testloader = get_dataloaders(
    siamese_train_dataset, siamese_test_dataset, args.train_size,
    args.val_size, args.batch_size)

# NOTE(review): hard-coded to the first GPU; fails on CPU-only machines.
device = torch.device("cuda:0")

net = SiameseNet(10, False)

net = net.to(device)

criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(net.parameters(), lr=args.lr)

# Checkpoint name encodes the labeled-train size.
model_name = 'reid_pl' + str(args.train_size)
# alpha: pseudo-label loss weight, ramped from 0 to 1 between epochs T1
# and T2 in the training loop below.
alpha = 0.0
val_min_acc = 0.0  # best validation accuracy seen so far
test_acc = 0.0
flag = False
for epoch in range(args.epochs):
    if (epoch > args.T1 and epoch <= args.T2):
        alpha = (epoch - args.T1) / (args.T2 - args.T1)

    net.train()