Example #1
class Trainer:
    def __init__(self, args, device):
        self.alpha_jigsaw_weight = 0.5
        self.alpha_odd_weight = 0.5
        self.alpha_rotation_weight = 0.5
        self.args = args
        self.device = device
        self.betaJigen = args.betaJigen

        model = model_factory.get_network(args.network)(classes=args.n_classes,
                                                        jigsaw_classes=31,
                                                        odd_classes=10,
                                                        rotation_classes=4)
        # if args.rotation:
        #     model = model_factory.get_network(args.network)(classes=args.n_classes, jigsaw_classes=4)
        # elif args.oddOneOut:
        #     model = model_factory.get_network(args.network)(classes=args.n_classes, jigsaw_classes=10)
        # else:
        #     model = model_factory.get_network(args.network)(classes=args.n_classes, jigsaw_classes=31)
        self.model = model.to(device)

        self.source_loader, self.val_loader = data_helper.get_train_dataloader(
            args)
        self.target_loader = data_helper.get_val_dataloader(args)

        self.test_loaders = {
            "val": self.val_loader,
            "test": self.target_loader
        }
        self.len_dataloader = len(self.source_loader)
        print("Dataset size: train %d, val %d, test %d" %
              (len(self.source_loader.dataset), len(
                  self.val_loader.dataset), len(self.target_loader.dataset)))

        self.optimizer, self.scheduler = get_optim_and_scheduler(
            model, args.epochs, args.learning_rate, args.train_all)

        self.n_classes = args.n_classes
        if args.oddOneOut and args.rotation:
            self.nTasks = 4
        elif args.oddOneOut or args.rotation:
            self.nTasks = 3
        else:
            self.nTasks = 2

    def _do_epoch(self):
        criterion = nn.CrossEntropyLoss()
        self.model.train()
        for it, (data, class_l, jigsaw_l,
                 self_sup_task) in enumerate(self.source_loader):
            # source_loader yields only the training data
            data, class_l, jigsaw_l, self_sup_task = data.to(
                self.device), class_l.to(self.device), jigsaw_l.to(
                    self.device), self_sup_task.to(self.device)

            self.optimizer.zero_grad()

            class_logit, jigsaw_logit, odd_logit, rotation_logit = self.model(
                data)  # logits from the model
            # jigsaw loss over the samples assigned to the jigsaw task
            jigsaw_loss = criterion(
                jigsaw_logit[(self_sup_task == 0) | (self_sup_task == 3)],
                jigsaw_l[(self_sup_task == 0) | (self_sup_task == 3)])

            # note: jigsaw_l carries the label for every self-supervised head
            if self.args.oddOneOut:
                odd_loss = criterion(
                    odd_logit[(self_sup_task == 1) | (self_sup_task == 3)],
                    jigsaw_l[(self_sup_task == 1) | (self_sup_task == 3)])
            else:
                odd_loss = 0

            if self.args.rotation:
                rotation_loss = criterion(
                    rotation_logit[(self_sup_task == 2) | (self_sup_task == 3)],
                    jigsaw_l[(self_sup_task == 2) | (self_sup_task == 3)])
            else:
                rotation_loss = 0

            # for classification, the loss is evaluated only on the non-scrambled images
            class_loss = criterion(class_logit[jigsaw_l == 0],
                                   class_l[jigsaw_l == 0])

            _, jigsaw_pred = jigsaw_logit[(self_sup_task == 0) |
                                          (self_sup_task == 3)].max(dim=1)

            if self.args.oddOneOut:
                _, odd_pred = odd_logit[(self_sup_task == 1) |
                                        (self_sup_task == 3)].max(dim=1)

            if self.args.rotation:
                _, rotation_pred = rotation_logit[(self_sup_task == 2) |
                                                  (self_sup_task == 3)].max(dim=1)

            _, cls_pred = class_logit.max(dim=1)

            loss = class_loss + self.alpha_jigsaw_weight * jigsaw_loss + self.alpha_odd_weight * odd_loss + self.alpha_rotation_weight * rotation_loss

            loss.backward()

            self.optimizer.step()

            if self.args.oddOneOut and self.args.rotation:
                self.logger.log(
                    it, len(self.source_loader), {
                        "Class Loss ": class_loss.item(),
                        "Jigsaw Loss": jigsaw_loss.item(),
                        "Odd Loss": odd_loss.item(),
                        "Rotation Loss": rotation_loss.item()
                    }, {
                        "Class Accuracy ":
                        torch.sum(cls_pred == class_l.data).item(),
                        "Jigsaw Accuracy ":
                        torch.sum(jigsaw_pred == jigsaw_l[
                            (self_sup_task == 0)
                            | (self_sup_task == 3)].data).item(),
                        "Odd Accuracy ":
                        torch.sum(odd_pred == jigsaw_l[(self_sup_task == 1) | (
                            self_sup_task == 3)].data).item(),
                        "Rotation Accuracy ":
                        torch.sum(rotation_pred == jigsaw_l[
                            (self_sup_task == 2)
                            | (self_sup_task == 3)].data).item()
                    }, data.shape[0])
            elif self.args.oddOneOut and not self.args.rotation:
                self.logger.log(
                    it, len(self.source_loader), {
                        "Class Loss ": class_loss.item(),
                        "Jigsaw Loss": jigsaw_loss.item(),
                        "Odd Loss": odd_loss.item()
                    }, {
                        "Class Accuracy ":
                        torch.sum(cls_pred == class_l.data).item(),
                        "Jigsaw Accuracy ":
                        torch.sum(jigsaw_pred == jigsaw_l[
                            (self_sup_task == 0)
                            | (self_sup_task == 3)].data).item(),
                        "Odd Accuracy ":
                        torch.sum(odd_pred == jigsaw_l[(self_sup_task == 1) | (
                            self_sup_task == 3)].data).item()
                    }, data.shape[0])
            elif not self.args.oddOneOut and self.args.rotation:
                self.logger.log(
                    it, len(self.source_loader), {
                        "Class Loss ": class_loss.item(),
                        "Jigsaw Loss": jigsaw_loss.item(),
                        "Rotation Loss": rotation_loss.item()
                    }, {
                        "Class Accuracy ":
                        torch.sum(cls_pred == class_l.data).item(),
                        "Jigsaw Accuracy ":
                        torch.sum(jigsaw_pred == jigsaw_l[
                            (self_sup_task == 0)
                            | (self_sup_task == 3)].data).item(),
                        "Rotation Accuracy ":
                        torch.sum(rotation_pred == jigsaw_l[
                            (self_sup_task == 2)
                            | (self_sup_task == 3)].data).item()
                    }, data.shape[0])
            else:
                self.logger.log(
                    it, len(self.source_loader), {
                        "Class Loss ": class_loss.item(),
                        "Jigsaw Loss": jigsaw_loss.item()
                    }, {
                        "Class Accuracy ":
                        torch.sum(cls_pred == class_l.data).item(),
                        "Jigsaw Accuracy ":
                        torch.sum(jigsaw_pred == jigsaw_l[
                            (self_sup_task == 0)
                            | (self_sup_task == 3)].data).item()
                    }, data.shape[0])

            del loss, class_loss, jigsaw_loss, jigsaw_logit, class_logit, odd_loss, rotation_loss, odd_logit, rotation_logit

        self.model.eval()
        with torch.no_grad():
            for phase, loader in self.test_loaders.items():
                total = len(loader.dataset)
                class_correct, jigsaw_correct, odd_correct, rotation_correct = self.do_test(
                    loader)
                class_acc = float(class_correct) / total
                jigsaw_acc = float(jigsaw_correct) / total
                odd_acc = float(odd_correct) / total
                rotation_acc = float(rotation_correct) / total
                acc = (class_acc + jigsaw_acc + odd_acc +
                       rotation_acc) / self.nTasks
                self.logger.log_test(phase, {"Classification Accuracy": acc})
                self.results[phase][self.current_epoch] = acc

    def do_test(self, loader):
        class_correct = 0
        jigsaw_correct = 0
        odd_correct = 0
        rotation_correct = 0
        for it, (data, class_l, jigsaw_l, self_sup_task) in enumerate(loader):
            data, class_l, jigsaw_l, self_sup_task = data.to(
                self.device), class_l.to(self.device), jigsaw_l.to(
                    self.device), self_sup_task.to(self.device)
            class_logit, jigsaw_logit, odd_logit, rotation_logit = self.model(
                data)
            _, jigsaw_pred = jigsaw_logit.max(dim=1)

            if self.args.oddOneOut:
                _, odd_pred = odd_logit.max(dim=1)
                odd_correct += torch.sum(odd_pred == jigsaw_l.data)
            if self.args.rotation:
                _, rotation_pred = rotation_logit.max(dim=1)
                rotation_correct += torch.sum(rotation_pred == jigsaw_l.data)

            _, cls_pred = class_logit.max(dim=1)

            jigsaw_correct += torch.sum(jigsaw_pred == jigsaw_l.data)

            class_correct += torch.sum(cls_pred == class_l.data)
        return class_correct, jigsaw_correct, odd_correct, rotation_correct

    def do_training(self):
        self.logger = Logger(self.args, update_frequency=30)
        self.results = {
            "val": torch.zeros(self.args.epochs),
            "test": torch.zeros(self.args.epochs)
        }

        for self.current_epoch in range(self.args.epochs):
            self.logger.new_epoch(self.scheduler.get_lr())
            self._do_epoch()
            self.scheduler.step()

        val_res = self.results["val"]
        test_res = self.results["test"]
        idx_best = val_res.argmax()
        print("Best val %g, corresponding test %g - best test: %g" %
              (val_res.max(), test_res[idx_best], test_res.max()))
        self.logger.save_best(test_res[idx_best], test_res.max())
        return self.logger, self.model
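The loss masking above relies on a per-sample task id: each head's loss is computed only over the samples assigned to that head. A minimal, self-contained sketch of the pattern (the encoding 0 = jigsaw, 1 = odd-one-out, 2 = rotation, 3 = all tasks is an assumption inferred from the indexing above):

# Toy illustration of the boolean-mask loss pattern used in _do_epoch.
import torch
import torch.nn as nn

criterion = nn.CrossEntropyLoss()
batch = 8
jigsaw_logit = torch.randn(batch, 31)          # 30 permutations + 1 "unscrambled" class
jigsaw_l = torch.randint(0, 31, (batch,))      # per-sample jigsaw label
self_sup_task = torch.randint(0, 4, (batch,))  # assumed: 0=jigsaw, 1=odd, 2=rotation, 3=all

mask = (self_sup_task == 0) | (self_sup_task == 3)
if mask.any():  # guard: cross-entropy over an empty selection yields NaN
    jigsaw_loss = criterion(jigsaw_logit[mask], jigsaw_l[mask])
    print(jigsaw_loss.item())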
Example #2
class Trainer:
    def __init__(self, args, device):
        self.args = args
        self.device = device

        model = resnet18(pretrained=True, classes=args.n_classes)
        self.model = model.to(device)
        # print(self.model)
        self.source_loader, self.val_loader = data_helper.get_train_dataloader(args, patches=model.is_patch_based())
        self.target_loader = data_helper.get_val_dataloader(args, patches=model.is_patch_based())
        self.test_loaders = {"val": self.val_loader, "test": self.target_loader}
        self.len_dataloader = len(self.source_loader)
        print("Dataset size: train %d, val %d, test %d" % (
        len(self.source_loader.dataset), len(self.val_loader.dataset), len(self.target_loader.dataset)))
        self.optimizer, self.scheduler = get_optim_and_scheduler(model, args.epochs, args.learning_rate, args.train_all,
                                                                 nesterov=args.nesterov)
        self.jig_weight = args.jig_weight
        self.only_non_scrambled = args.classify_only_sane
        self.n_classes = args.n_classes
        if args.target in args.source:
            self.target_id = args.source.index(args.target)
            print("Target in source: %d" % self.target_id)
            print(args.source)
        else:
            self.target_id = None

    def _do_epoch(self):
        criterion = nn.CrossEntropyLoss()
        self.model.train()
        for it, ((data, jig_l, class_l), d_idx) in enumerate(self.source_loader):
            data, jig_l, class_l, d_idx = data.to(self.device), jig_l.to(self.device), class_l.to(
                self.device), d_idx.to(self.device)

            self.optimizer.zero_grad()

            class_logit = self.model(data, class_l, True)  # True presumably enables train-time behavior (cf. RSC_flag in Example #5)

            class_loss = criterion(class_logit, class_l)
            _, cls_pred = class_logit.max(dim=1)
            loss = class_loss

            loss.backward()
            self.optimizer.step()

            self.logger.log(it, len(self.source_loader),
                            {"class": class_loss.item()},
                            {"class": torch.sum(cls_pred == class_l.data).item(), }, data.shape[0])
            del loss, class_loss, class_logit

        self.model.eval()
        with torch.no_grad():
            for phase, loader in self.test_loaders.items():
                total = len(loader.dataset)

                class_correct = self.do_test(loader)

                class_acc = float(class_correct) / total
                self.logger.log_test(phase, {"class": class_acc})
                self.results[phase][self.current_epoch] = class_acc

    def do_test(self, loader):
        class_correct = 0
        for it, ((data, nouse, class_l), _) in enumerate(loader):
            data, nouse, class_l = data.to(self.device), nouse.to(self.device), class_l.to(self.device)

            class_logit = self.model(data, class_l, False)
            _, cls_pred = class_logit.max(dim=1)

            class_correct += torch.sum(cls_pred == class_l.data)

        return class_correct


    def do_training(self):
        self.logger = Logger(self.args, update_frequency=30)
        self.results = {"val": torch.zeros(self.args.epochs), "test": torch.zeros(self.args.epochs)}
        for self.current_epoch in range(self.args.epochs):
            self.scheduler.step()
            self.logger.new_epoch(self.scheduler.get_lr())
            self._do_epoch()
        val_res = self.results["val"]
        test_res = self.results["test"]
        idx_best = val_res.argmax()
        print("Best val %g, corresponding test %g - best test: %g, best epoch: %g" % (
        val_res.max(), test_res[idx_best], test_res.max(), idx_best))
        self.logger.save_best(test_res[idx_best], test_res.max())
        return self.logger, self.model
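get_optim_and_scheduler is imported from elsewhere in the repository and not shown in these examples. A plausible minimal sketch, assuming SGD with momentum and a single step decay late in training (the hyperparameters and the train_all handling here are assumptions, not the repository's actual code):

from torch import optim

def get_optim_and_scheduler(model, epochs, lr, train_all, nesterov=False):
    # Assumed behavior: train either all parameters or only those left
    # with requires_grad=True, using SGD plus one step decay at 80%.
    if train_all:
        params = model.parameters()
    else:
        params = (p for p in model.parameters() if p.requires_grad)
    optimizer = optim.SGD(params, lr=lr, momentum=0.9,
                          weight_decay=5e-4, nesterov=nesterov)
    step = max(1, int(epochs * 0.8))
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=step)
    return optimizer, scheduler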
Example #3
class Trainer:
    def __init__(self, args, device):
        self.args = args
        self.device = device
        model = model_factory.get_network(args.network)(
            jigsaw_classes=args.jigsaw_n_classes + 1, classes=args.n_classes)
        self.model = model.to(device)
        # print(self.model)
        self.source_loader, self.val_loader = data_helper.get_train_dataloader(
            args, patches=model.is_patch_based())
        self.target_loader = data_helper.get_val_dataloader(
            args, patches=model.is_patch_based())
        self.test_loaders = {
            "val": self.val_loader,
            "test": self.target_loader
        }
        self.len_dataloader = len(self.source_loader)
        print("Dataset size: train %d, val %d, test %d" %
              (len(self.source_loader.dataset), len(
                  self.val_loader.dataset), len(self.target_loader.dataset)))
        self.optimizer, self.scheduler = get_optim_and_scheduler(
            model,
            args.epochs,
            args.learning_rate,
            args.train_all,
            nesterov=args.nesterov)
        self.jig_weight = args.jig_weight
        self.only_non_scrambled = args.classify_only_sane
        self.n_classes = args.n_classes
        if args.target in args.source:
            self.target_id = args.source.index(args.target)
            print("Target in source: %d" % self.target_id)
            print(args.source)
        else:
            self.target_id = None

    def _do_epoch(self):
        criterion = nn.CrossEntropyLoss()
        self.model.train()
        for it, ((data, jig_l, class_l),
                 d_idx) in enumerate(self.source_loader):
            data, jig_l, class_l, d_idx = data.to(self.device), jig_l.to(
                self.device), class_l.to(self.device), d_idx.to(self.device)
            # absolute_iter_count = it + self.current_epoch * self.len_dataloader
            # p = float(absolute_iter_count) / self.args.epochs / self.len_dataloader
            # lambda_val = 2. / (1. + np.exp(-10 * p)) - 1
            # if domain_error > 2.0:
            #     lambda_val  = 0
            # print("Shutting down LAMBDA to prevent implosion")

            self.optimizer.zero_grad()

            jigsaw_logit, class_logit = self.model(
                data)  # , lambda_val=lambda_val)
            jigsaw_loss = criterion(jigsaw_logit, jig_l)
            # domain_loss = criterion(domain_logit, d_idx)
            # domain_error = domain_loss.item()
            if self.only_non_scrambled:
                if self.target_id is not None:
                    idx = (jig_l == 0) & (d_idx != self.target_id)
                    class_loss = criterion(class_logit[idx], class_l[idx])
                else:
                    class_loss = criterion(class_logit[jig_l == 0],
                                           class_l[jig_l == 0])

            elif self.target_id is not None:
                class_loss = criterion(class_logit[d_idx != self.target_id],
                                       class_l[d_idx != self.target_id])
            else:
                class_loss = criterion(class_logit, class_l)
            _, cls_pred = class_logit.max(dim=1)
            _, jig_pred = jigsaw_logit.max(dim=1)
            # _, domain_pred = domain_logit.max(dim=1)
            loss = class_loss + jigsaw_loss * self.jig_weight  # + 0.1 * domain_loss

            loss.backward()
            self.optimizer.step()

            self.logger.log(
                it,
                len(self.source_loader),
                {
                    "jigsaw": jigsaw_loss.item(),
                    "class":
                    class_loss.item()  # , "domain": domain_loss.item()
                },
                # ,"lambda": lambda_val},
                {
                    "jigsaw": torch.sum(jig_pred == jig_l.data).item(),
                    "class": torch.sum(cls_pred == class_l.data).item(),
                    # "domain": torch.sum(domain_pred == d_idx.data).item()
                },
                data.shape[0])
            del loss, class_loss, jigsaw_loss, jigsaw_logit, class_logit

        self.model.eval()
        with torch.no_grad():
            for phase, loader in self.test_loaders.items():
                total = len(loader.dataset)
                if loader.dataset.isMulti():
                    jigsaw_correct, class_correct, single_acc = self.do_test_multi(
                        loader)
                    print("Single vs multi: %g %g" %
                          (float(single_acc) / total,
                           float(class_correct) / total))
                else:
                    jigsaw_correct, class_correct = self.do_test(loader)
                jigsaw_acc = float(jigsaw_correct) / total
                class_acc = float(class_correct) / total
                self.logger.log_test(phase, {
                    "jigsaw": jigsaw_acc,
                    "class": class_acc
                })
                self.results[phase][self.current_epoch] = class_acc

    def do_test(self, loader):
        jigsaw_correct = 0
        class_correct = 0
        domain_correct = 0
        for it, ((data, jig_l, class_l), _) in enumerate(loader):
            data, jig_l, class_l = data.to(self.device), jig_l.to(
                self.device), class_l.to(self.device)
            jigsaw_logit, class_logit = self.model(data)
            _, cls_pred = class_logit.max(dim=1)
            _, jig_pred = jigsaw_logit.max(dim=1)
            class_correct += torch.sum(cls_pred == class_l.data)
            jigsaw_correct += torch.sum(jig_pred == jig_l.data)
        return jigsaw_correct, class_correct

    def do_test_multi(self, loader):
        jigsaw_correct = 0
        class_correct = 0
        single_correct = 0
        for it, ((data, jig_l, class_l), d_idx) in enumerate(loader):
            data, jig_l, class_l = data.to(self.device), jig_l.to(
                self.device), class_l.to(self.device)
            n_permutations = data.shape[1]
            class_logits = torch.zeros(n_permutations, data.shape[0],
                                       self.n_classes).to(self.device)
            for k in range(n_permutations):
                class_logits[k] = F.softmax(self.model(data[:, k])[1], dim=1)
            class_logits[0] *= 4 * n_permutations  # weight the original image more heavily
            class_logit = class_logits.mean(0)
            _, cls_pred = class_logit.max(dim=1)
            jigsaw_logit, single_logit = self.model(data[:, 0])
            _, jig_pred = jigsaw_logit.max(dim=1)
            _, single_pred = single_logit.max(dim=1)
            single_correct += torch.sum(single_pred == class_l.data)
            class_correct += torch.sum(cls_pred == class_l.data)
            jigsaw_correct += torch.sum(jig_pred == jig_l.data[:, 0])
        return jigsaw_correct, class_correct, single_correct

    def do_training(self):
        self.logger = Logger(self.args,
                             update_frequency=30)  # , "domain", "lambda"
        self.results = {
            "val": torch.zeros(self.args.epochs),
            "test": torch.zeros(self.args.epochs)
        }
        for self.current_epoch in range(self.args.epochs):
            self.scheduler.step()
            self.logger.new_epoch(self.scheduler.get_lr())
            self._do_epoch()
        val_res = self.results["val"]
        test_res = self.results["test"]
        idx_best = val_res.argmax()
        #print("Best val %g, corresponding test %g - best test: %g" % (val_res.max(), test_res[idx_best], test_res.max()))
        self.logger.save_best(test_res[idx_best], test_res.max())
        return self.logger, self.model
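In do_test_multi, class predictions are ensembled over the jigsaw variants of each image, with the softmax of the unscrambled view (index 0) up-weighted by 4 * n_permutations before averaging. A toy sketch of that weighting, stripped of the data loading (shapes are illustrative):

import torch
import torch.nn.functional as F

def ensemble_predict(logits_per_view):
    # logits_per_view: (n_views, batch, n_classes); view 0 is the original image.
    probs = F.softmax(logits_per_view, dim=2)
    probs[0] = probs[0] * 4 * probs.shape[0]  # bias toward the unscrambled view
    return probs.mean(0).argmax(dim=1)

preds = ensemble_predict(torch.randn(5, 8, 7))  # 5 views, batch of 8, 7 classes
print(preds.shape)  # torch.Size([8])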
Example #4
class Trainer:
    def __init__(self, args, device):
        self.args = args
        self.device = device
        model = model_factory.get_network(args.network)(
            jigsaw_classes=args.jigsaw_n_classes + 1, classes=args.n_classes)
        self.model = model.to(device)
        # print(self.model)
        if args.target in args.source:
            print(
                "No need to include target in source, it is automatically done by this script"
            )
            k = args.source.index(args.target)
            args.source = args.source[:k] + args.source[k + 1:]
            print("Source: %s" % args.source)
        self.source_loader, self.val_loader = data_helper.get_train_dataloader(
            args, patches=model.is_patch_based())
        self.target_jig_loader = data_helper.get_target_jigsaw_loader(args)
        self.target_loader = data_helper.get_val_dataloader(
            args, patches=model.is_patch_based())
        self.test_loaders = {
            "val": self.val_loader,
            "test": self.target_loader
        }
        self.len_dataloader = len(self.source_loader)
        print("Dataset size: train %d, target jig: %d, val %d, test %d" %
              (len(self.source_loader.dataset),
               len(self.target_jig_loader.dataset), len(
                   self.val_loader.dataset), len(self.target_loader.dataset)))
        self.optimizer, self.scheduler = get_optim_and_scheduler(
            model,
            args.epochs,
            args.learning_rate,
            args.train_all,
            nesterov=args.nesterov)
        self.jig_weight = args.jig_weight
        self.target_weight = args.target_weight
        self.target_entropy = args.entropy_weight
        self.only_non_scrambled = args.classify_only_sane
        self.n_classes = args.n_classes

    def _do_epoch(self):
        criterion = nn.CrossEntropyLoss()
        self.model.train()
        for it, (source_batch, target_batch) in enumerate(
                zip(self.source_loader,
                    itertools.cycle(self.target_jig_loader))):
            (data, jig_l, class_l), d_idx = source_batch
            data, jig_l, class_l, d_idx = data.to(self.device), jig_l.to(
                self.device), class_l.to(self.device), d_idx.to(self.device)
            tdata, tjig_l, _ = target_batch
            tdata, tjig_l = tdata.to(self.device), tjig_l.to(self.device)

            self.optimizer.zero_grad()

            jigsaw_logit, class_logit = self.model(data)
            jigsaw_loss = criterion(jigsaw_logit, jig_l)
            target_jigsaw_logit, target_class_logit = self.model(tdata)
            target_jigsaw_loss = criterion(target_jigsaw_logit, tjig_l)
            target_entropy_loss = entropy_loss(target_class_logit[tjig_l == 0])
            if self.only_non_scrambled:
                class_loss = criterion(class_logit[jig_l == 0],
                                       class_l[jig_l == 0])
            else:
                class_loss = criterion(class_logit, class_l)
            _, cls_pred = class_logit.max(dim=1)
            _, jig_pred = jigsaw_logit.max(dim=1)

            loss = class_loss + jigsaw_loss * self.jig_weight + target_jigsaw_loss * self.target_weight + target_entropy_loss * self.target_entropy

            loss.backward()
            self.optimizer.step()

            self.logger.log(
                it, len(self.source_loader), {
                    "jigsaw": jigsaw_loss.item(),
                    "class": class_loss.item(),
                    "t_jigsaw": target_jigsaw_loss.item(),
                    "entropy": target_entropy_loss.item()
                }, {
                    "jigsaw": torch.sum(jig_pred == jig_l.data).item(),
                    "class": torch.sum(cls_pred == class_l.data).item(),
                }, data.shape[0])
            del loss, class_loss, jigsaw_loss, jigsaw_logit, class_logit, target_jigsaw_logit, target_jigsaw_loss

        self.model.eval()
        with torch.no_grad():
            for phase, loader in self.test_loaders.items():
                total = len(loader.dataset)
                if loader.dataset.isMulti():
                    jigsaw_correct, class_correct, single_acc = self.do_test_multi(
                        loader)
                    print("Single vs multi: %g %g" %
                          (float(single_acc) / total,
                           float(class_correct) / total))
                else:
                    jigsaw_correct, class_correct = self.do_test(loader)
                jigsaw_acc = float(jigsaw_correct) / total
                class_acc = float(class_correct) / total
                self.logger.log_test(phase, {
                    "jigsaw": jigsaw_acc,
                    "class": class_acc
                })
                self.results[phase][self.current_epoch] = class_acc

    def do_test(self, loader):
        jigsaw_correct = 0
        class_correct = 0
        domain_correct = 0
        for it, ((data, jig_l, class_l), _) in enumerate(loader):
            data, jig_l, class_l = data.to(self.device), jig_l.to(
                self.device), class_l.to(self.device)
            jigsaw_logit, class_logit = self.model(data)
            _, cls_pred = class_logit.max(dim=1)
            _, jig_pred = jigsaw_logit.max(dim=1)
            class_correct += torch.sum(cls_pred == class_l.data)
            jigsaw_correct += torch.sum(jig_pred == jig_l.data)
        return jigsaw_correct, class_correct

    def do_training(self):
        self.logger = Logger(self.args,
                             update_frequency=30)  # , "domain", "lambda"
        self.results = {
            "val": torch.zeros(self.args.epochs),
            "test": torch.zeros(self.args.epochs)
        }
        for self.current_epoch in range(self.args.epochs):
            self.scheduler.step()
            self.logger.new_epoch(self.scheduler.get_lr())
            self._do_epoch()
        val_res = self.results["val"]
        test_res = self.results["test"]
        idx_best = val_res.argmax()
        print("Best val %g, corresponding test %g - best test: %g" %
              (val_res.max(), test_res[idx_best], test_res.max()))
        self.logger.save_best(test_res[idx_best], test_res.max())
        return self.logger, self.model
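entropy_loss is imported from the repository and not defined in these examples. Its use on target_class_logit[tjig_l == 0] (unlabeled, unscrambled target images) suggests a standard entropy-minimization term; a minimal sketch under that assumption:

import torch
import torch.nn.functional as F

def entropy_loss(logits, eps=1e-8):
    # Assumed definition: mean Shannon entropy of the predicted class
    # distribution, pushing the model toward confident target predictions.
    p = F.softmax(logits, dim=1)
    return -(p * torch.log(p + eps)).sum(dim=1).mean()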
Example #5
class Trainer:
    def __init__(self, args, device):
        self.args = args
        self.device = device
        if args.network == 'resnet18':
            model = resnet18(pretrained=self.args.pretrained,
                             classes=args.n_classes)
        elif args.network == 'resnet50':
            model = resnet50(pretrained=self.args.pretrained,
                             classes=args.n_classes)
        else:
            model = resnet18(pretrained=self.args.pretrained,
                             classes=args.n_classes)
        self.model = model.to(device)

        if args.resume:
            if isfile(args.resume):
                print(f"=> loading checkpoint '{args.resume}'")
                checkpoint = torch.load(args.resume)
                self.args.start_epoch = checkpoint['epoch']
                self.model.load_state_dict(checkpoint['model'])
                print(
                    f"=> loaded checkpoint '{args.resume}' (epoch {checkpoint['epoch']})"
                )
            else:
                raise ValueError(f"Failed to find checkpoint {args.resume}")

        self.source_loader, self.val_loader = data_helper.get_train_dataloader(
            args, patches=model.is_patch_based())
        # self.target_loader = data_helper.get_val_dataloader(args, patches=model.is_patch_based())
        self.target_loader = data_helper.get_tgt_dataloader(
            self.args, patches=model.is_patch_based())
        self.test_loaders = {
            "val": self.val_loader,
            "test": self.target_loader
        }
        self.len_dataloader = len(self.source_loader)
        print("Dataset size: train %d, val %d, test %d" %
              (len(self.source_loader.dataset), len(
                  self.val_loader.dataset), len(self.target_loader.dataset)))
        self.optimizer, self.scheduler = get_optim_and_scheduler(
            model,
            args.epochs,
            args.learning_rate,
            args.train_all,
            nesterov=args.nesterov)
        self.n_classes = args.n_classes
        if args.target in args.source:
            self.target_id = args.source.index(args.target)
            print("Target in source: %d" % self.target_id)
            print(args.source)
        else:
            self.target_id = None
        self.topk = [0 for _ in range(3)]

    def _do_epoch(self, epoch=None):
        if self.args.loss == 'ce':
            criterion = nn.CrossEntropyLoss()
        elif self.args.loss == 'fl':
            criterion = FocalLoss(class_num=self.args.n_classes)
        else:
            raise ValueError(f"Unsupported loss: {self.args.loss}")
        self.model.train()
        for it, ((data, jig_l, class_l),
                 d_idx) in enumerate(self.source_loader):
            data, jig_l, class_l, d_idx = data.to(self.device), jig_l.to(
                self.device), class_l.to(self.device), d_idx.to(self.device)
            self.optimizer.zero_grad()

            # horizontal-flip augmentation: double the batch with mirrored images
            data_flip = torch.flip(data, (3, )).detach().clone()
            data = torch.cat((data, data_flip))
            class_l = torch.clamp(class_l, 0, 9)  # keep labels in the 10-class range
            class_l = torch.cat((class_l, class_l))

            class_logit = self.model(data, class_l, self.args.RSC_flag, epoch)
            class_loss = criterion(class_logit, class_l)
            _, cls_pred = class_logit.max(dim=1)
            loss = class_loss

            loss.backward()
            self.optimizer.step()

            self.logger.log(
                it, len(self.source_loader), {"loss": class_loss.item()}, {
                    "class": torch.sum(cls_pred == class_l.data).item(),
                }, data.shape[0])
            del loss, class_loss, class_logit

        self.model.eval()
        with torch.no_grad():
            for phase, loader in self.test_loaders.items():
                total = len(loader.dataset)
                class_correct, auc_dict = self.do_test(loader)
                class_acc = float(class_correct) / total
                self.logger.log_test(phase, {"class_acc": class_acc})
                self.logger.log_test(phase, {"auc": auc_dict['auc']})
                self.logger.log_test(phase, {"fpr_980": auc_dict['fpr_980']})
                self.logger.log_test(phase, {"fpr_991": auc_dict['fpr_991']})
                self.results[phase][self.current_epoch] = class_acc

                # save the best & latest model parameters
                if phase == 'val':
                    self.save_model(epoch, auc_dict)
                del auc_dict

    def do_test(self, loader):
        class_correct = 0
        auc_meter = AUCMeter()
        for it, ((data, nouse, class_l), _) in enumerate(loader):
            data, nouse, class_l = data.to(self.device), nouse.to(
                self.device), class_l.to(self.device)

            class_logit = self.model(data, class_l, False)
            _, cls_pred = class_logit.max(dim=1)

            class_correct += torch.sum(cls_pred == class_l.data)

            cls_score = F.softmax(class_logit, dim=1)
            auc_meter.update(class_l.cpu(), cls_score.cpu())

        (auc, fpr_980, fpr_991, fpr_993, fpr_995, fpr_997, fpr_999, fpr_1,
         thresholds) = auc_meter.calculate()
        auc_dict = {
            'auc': auc,
            'fpr_980': fpr_980,
            'fpr_991': fpr_991,
            'fpr_993': fpr_993,
            'fpr_995': fpr_995,
            'fpr_997': fpr_997,
            'fpr_999': fpr_999,
            'fpr_1': fpr_1,
            'thresholds': thresholds
        }
        return class_correct, auc_dict

    def do_training(self):
        self.logger = Logger(self.args, update_frequency=50)
        self.results = {
            "val": torch.zeros(self.args.epochs),
            "test": torch.zeros(self.args.epochs)
        }
        for self.current_epoch in range(self.args.start_epoch,
                                        self.args.epochs):
            self._do_epoch(self.current_epoch)
            self.scheduler.step()
            self.logger.new_epoch(self.scheduler.get_last_lr())
        val_res = self.results["val"]
        test_res = self.results["test"]
        idx_best = val_res.argmax()
        print(
            "Best val %g, corresponding test %g - best test: %g, best epoch: %g"
            % (val_res.max(), test_res[idx_best], test_res.max(), idx_best))
        self.logger.save_best(test_res[idx_best], test_res.max())
        return self.logger, self.model

    # def save_model(self,epoch, auc_dict):
    #     if not exists(_save_models_dir): os.mkdir(_save_models_dir)
    #     state_to_save = {'model':self.model.state_dict(), 'auc_dict':auc_dict, 'epoch':epoch}
    #     tmp_auc, tmp_fpr_980 = auc_dict['auc'], auc_dict['fpr_980']
    #     best1,best2,best3 = self.moving_record['best1'],self.moving_record['best2'],self.moving_record['best3']
    #     best1_path, best2_path, best3_path = (join(_save_models_dir, f"tgt_{self.args.target}_src_{'-'.join(self.args.source)}_RSC_{self.args.RSC_flag}_best{_}.pth") for _ in [1,2,3])
    #     #resort top3
    #     update_pos = -1
    #     if tmp_auc>best1['auc']:
    #         best3['auc'], best3['fpr_980'] = best2['auc'], best2['fpr_980']
    #         best2['auc'], best2['fpr_980'] = best1['auc'], best1['fpr_980']
    #         best1['auc'], best1['fpr_980'] = tmp_auc, tmp_fpr_980
    #         if exists(best2_path) and exists(best3_path):
    #             os.rename(best2_path, best3_path)
    #         if exists(best1_path) and exists(best2_path):
    #             os.rename(best1_path, best2_path)
    #         update_pos = 1
    #     elif best2['auc']< tmp_auc < best1['auc']:
    #         best3['auc'], best3['fpr_980'] = best2['auc'], best2['fpr_980']
    #         best2['auc'], best2['fpr_980'] = tmp_auc, tmp_fpr_980
    #         if exists(best2_path) and exists(best3_path):
    #             os.rename(best2_path, best3_path)
    #         update_pos = 2
    #     elif best3['auc']< tmp_auc < best2['auc']:
    #         best3['auc'], best3['fpr_980'] = tmp_auc, tmp_fpr_980
    #         update_pos = 3

    #     if update_pos in [1,2,3]:
    #         model_saved_path = join(_save_models_dir, f"tgt_{self.args.target}_src_{'-'.join(self.args.source)}_RSC_{self.args.RSC_flag}_best{update_pos}.pth")
    #         torch.save(state_to_save, model_saved_path)
    #         print(f'=>Best{update_pos} model updated and saved in path {model_saved_path}')
    #     if epoch in range(self.args.epochs-3, self.args.epochs):
    #         model_saved_path = join(_save_models_dir, f"tgt_{self.args.target}_src_{'-'.join(self.args.source)}_RSC_{self.args.RSC_flag}_epochs{epoch}.pth")
    #         torch.save(state_to_save, model_saved_path)
    #         print(f'=>Last{self.args.epochs - epoch} model updated and saved in path {model_saved_path}')
    def save_model(self, epoch, auc_dict):
        if not exists(_save_models_dir):
            os.mkdir(_save_models_dir)
        tmp_auc, tmp_fpr_980 = auc_dict['auc'], auc_dict['fpr_980']
        # build the checkpoint once, up front: the last-epochs save below
        # needs it even when no top-k slot is updated
        state_to_save = {
            'model': self.model.state_dict(),
            'auc_dict': auc_dict,
            'epoch': epoch
        }
        for i, rec in enumerate(self.topk):
            if tmp_auc > rec:
                for j in range(len(self.topk) - 1, i, -1):
                    self.topk[j] = self.topk[j - 1]
                    _j = join(
                        _save_models_dir,
                        f"tgt_{self.args.target}_src_{'-'.join(self.args.source)}_RSC_{self.args.RSC_flag}_best{j+1}.pth")
                    _jm1 = join(
                        _save_models_dir,
                        f"tgt_{self.args.target}_src_{'-'.join(self.args.source)}_RSC_{self.args.RSC_flag}_best{j}.pth")
                    if exists(_jm1):
                        os.rename(_jm1, _j)
                self.topk[i] = tmp_auc
                model_saved_path = join(
                    _save_models_dir,
                    f"tgt_{self.args.target}_src_{'-'.join(self.args.source)}_RSC_{self.args.RSC_flag}_best{i+1}.pth"
                )
                torch.save(state_to_save, model_saved_path)
                print(
                    f'=>Best{i+1} model updated and saved in path {model_saved_path}'
                )
                break

        if epoch in range(self.args.epochs - 3, self.args.epochs):
            model_saved_path = join(
                _save_models_dir,
                f"tgt_{self.args.target}_src_{'-'.join(self.args.source)}_RSC_{self.args.RSC_flag}_epochs{epoch}.pth"
            )
            torch.save(state_to_save, model_saved_path)
            print(
                f'=>Last{self.args.epochs - epoch} model updated and saved in path {model_saved_path}'
            )
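FocalLoss (selected with args.loss == 'fl') is also defined elsewhere. A standard sketch after Lin et al. (2017), assuming class_num only records the number of classes and gamma defaults to 2; both details are assumptions:

import torch.nn as nn
import torch.nn.functional as F

class FocalLoss(nn.Module):
    # Assumed implementation: down-weight easy examples by (1 - p_t)^gamma
    # before the usual cross-entropy term.
    def __init__(self, class_num, gamma=2.0):
        super().__init__()
        self.class_num = class_num
        self.gamma = gamma

    def forward(self, logits, targets):
        log_p = F.log_softmax(logits, dim=1)
        log_pt = log_p.gather(1, targets.unsqueeze(1)).squeeze(1)
        pt = log_pt.exp()
        loss = -((1 - pt) ** self.gamma) * log_pt
        return loss.mean()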
Example #6
class Trainer:
    def __init__(self, args, device):
        self.args = args
        self.device = device
        model = model_factory.get_network(args.network)(
            jigsaw_classes=args.jigsaw_n_classes + 1, classes=args.n_classes)
        self.model = model.to(device)
        # print(self.model)
        self.source_loader, self.val_loader = data_helper.get_train_dataloader(
            args, patches=model.is_patch_based())
        self.target_loader = data_helper.get_val_dataloader(
            args, patches=model.is_patch_based())
        self.test_loaders = {
            "val": self.val_loader,
            "test": self.target_loader
        }
        self.len_dataloader = len(self.source_loader)
        print("Dataset size: train %d, val %d, test %d" %
              (len(self.source_loader.dataset), len(
                  self.val_loader.dataset), len(self.target_loader.dataset)))
        self.optimizer, self.scheduler = get_optim_and_scheduler(
            model,
            args.epochs,
            args.learning_rate,
            args.train_all,
            nesterov=args.nesterov)
        self.jig_weight = args.jig_weight
        self.only_non_scrambled = args.classify_only_sane
        self.n_classes = args.n_classes
        if args.target in args.source:
            self.target_id = args.source.index(args.target)
            print("Target in source: %d" % self.target_id)
            print(args.source)
        else:
            self.target_id = None

        self.best_val_jigsaw = 0.0
        self.best_class_acc = 0.0

        _, logname = Logger.get_name_from_args(args)

        self.folder_name = "%s/%s_to_%s/%s" % (args.folder_name, "-".join(
            sorted(args.source)), args.target, logname)

    def _do_epoch(self):
        criterion = nn.CrossEntropyLoss()
        self.model.train()
        epoch_loss = 0
        pbar = pkbar.Pbar(name='Epoch Progress',
                          target=len(self.source_loader))
        for it, ((data, jig_l, class_l),
                 d_idx) in enumerate(self.source_loader):
            pbar.update(it)
            data, jig_l, class_l, d_idx = data.to(self.device), jig_l.to(
                self.device), class_l.to(self.device), d_idx.to(self.device)
            # absolute_iter_count = it + self.current_epoch * self.len_dataloader
            # p = float(absolute_iter_count) / self.args.epochs / self.len_dataloader
            # lambda_val = 2. / (1. + np.exp(-10 * p)) - 1
            # if domain_error > 2.0:
            #     lambda_val  = 0
            # print("Shutting down LAMBDA to prevent implosion")

            self.optimizer.zero_grad()

            jigsaw_logit, class_logit = self.model(
                data)  # , lambda_val=lambda_val)
            jigsaw_loss = criterion(jigsaw_logit, jig_l)
            # domain_loss = criterion(domain_logit, d_idx)
            # domain_error = domain_loss.item()
            if self.only_non_scrambled:
                if self.target_id is not None:
                    idx = (jig_l == 0) & (d_idx != self.target_id)
                    class_loss = criterion(class_logit[idx], class_l[idx])
                else:
                    class_loss = criterion(class_logit[jig_l == 0],
                                           class_l[jig_l == 0])

            elif self.target_id is not None:
                class_loss = criterion(class_logit[d_idx != self.target_id],
                                       class_l[d_idx != self.target_id])
            else:
                class_loss = criterion(class_logit, class_l)
            _, cls_pred = class_logit.max(dim=1)
            _, jig_pred = jigsaw_logit.max(dim=1)

            if self.args.deep_all:
                jigsaw_loss = torch.Tensor([0.0])
                loss = class_loss
            else:
                loss = class_loss + jigsaw_loss * self.jig_weight  # + 0.1 * domain_loss
                # _, domain_pred = domain_logit.max(dim=1)

            epoch_loss = epoch_loss + loss.item()  # accumulate a float, not the graph
            loss.backward()
            self.optimizer.step()

            self.logger.log(
                it,
                len(self.source_loader),
                {
                    "jigsaw": jigsaw_loss.item(),
                    "class":
                    class_loss.item()  # , "domain": domain_loss.item()
                },
                # ,"lambda": lambda_val},
                {
                    "jigsaw": torch.sum(jig_pred == jig_l.data).item(),
                    "class": torch.sum(cls_pred == class_l.data).item(),
                    # "domain": torch.sum(domain_pred == d_idx.data).item()
                },
                data.shape[0])
            del loss, class_loss, jigsaw_loss, jigsaw_logit, class_logit

        self.model.eval()
        with torch.no_grad():
            for phase, loader in self.test_loaders.items():
                total = len(loader.dataset)
                if loader.dataset.isMulti():
                    jigsaw_correct, class_correct, single_acc = self.do_test_multi(
                        loader)
                    print("Single vs multi: %g %g" %
                          (float(single_acc) / total,
                           float(class_correct) / total))
                else:
                    jigsaw_correct, class_correct = self.do_test(loader)
                jigsaw_acc = float(jigsaw_correct) / total
                class_acc = float(class_correct) / total

                self.logger.log_test(phase, {
                    "jigsaw": jigsaw_acc,
                    "class": class_acc
                })
                self.results[phase][self.current_epoch] = class_acc

        if self.results['val'][self.current_epoch] > self.best_class_acc:
            self.best_class_acc = self.results['val'][self.current_epoch]
            print("Saving new best at epoch: {}".format(self.current_epoch))
            self.save_model(
                os.path.join("logs", self.folder_name, 'best_model.pth'))

        print("Saving latest at epoch: {}".format(self.current_epoch))
        self.save_model(
            os.path.join("logs", self.folder_name, 'latest_model.pth'))

    def save_model(self, file_path):
        torch.save(
            {
                'epoch': self.current_epoch,
                'model_state_dict': self.model.state_dict(),
                'optimizer_state_dict': self.optimizer.state_dict(),
                'best_val_acc': self.results['val'][self.current_epoch],
                'test_acc': self.results['test'][self.current_epoch]
            }, file_path)

    def do_test(self, loader):
        jigsaw_correct = 0
        class_correct = 0
        domain_correct = 0
        for it, ((data, jig_l, class_l), _) in enumerate(loader):
            data, jig_l, class_l = data.to(self.device), jig_l.to(
                self.device), class_l.to(self.device)
            jigsaw_logit, class_logit = self.model(data)
            _, cls_pred = class_logit.max(dim=1)
            _, jig_pred = jigsaw_logit.max(dim=1)
            class_correct += torch.sum(cls_pred == class_l.data)
            jigsaw_correct += torch.sum(jig_pred == jig_l.data)
        return jigsaw_correct, class_correct

    def do_test_multi(self, loader):
        jigsaw_correct = 0
        class_correct = 0
        single_correct = 0
        for it, ((data, jig_l, class_l), d_idx) in enumerate(loader):
            data, jig_l, class_l = data.to(self.device), jig_l.to(
                self.device), class_l.to(self.device)
            n_permutations = data.shape[1]
            class_logits = torch.zeros(n_permutations, data.shape[0],
                                       self.n_classes).to(self.device)
            for k in range(n_permutations):
                class_logits[k] = F.softmax(self.model(data[:, k])[1], dim=1)
            class_logits[0] *= 4 * n_permutations  # weight the original image more heavily
            class_logit = class_logits.mean(0)
            _, cls_pred = class_logit.max(dim=1)
            jigsaw_logit, single_logit = self.model(data[:, 0])
            _, jig_pred = jigsaw_logit.max(dim=1)
            _, single_pred = single_logit.max(dim=1)
            single_correct += torch.sum(single_pred == class_l.data)
            class_correct += torch.sum(cls_pred == class_l.data)
            jigsaw_correct += torch.sum(jig_pred == jig_l.data[:, 0])
        return jigsaw_correct, class_correct, single_correct

    def do_training(self):
        self.logger = Logger(self.args,
                             update_frequency=30)  # , "domain", "lambda"
        self.results = {
            "val": torch.zeros(self.args.epochs),
            "test": torch.zeros(self.args.epochs)
        }

        for self.current_epoch in range(self.args.epochs):
            start_time = time.time()
            self.scheduler.step()
            self.logger.new_epoch(self.scheduler.get_lr())
            self._do_epoch()
            end_time = time.time()
            print(f"Runtime of the epoch is {end_time - start_time}")

        val_res = self.results["val"]
        test_res = self.results["test"]
        idx_best = val_res.argmax()
        print("Best val %g, corresponding test %g - best test: %g" %
              (val_res.max(), test_res[idx_best], test_res.max()))
        self.logger.save_best(test_res[idx_best], test_res.max())

        # Save Arguments
        with open(osp.join('logs', self.folder_name, 'args.txt'), 'w') as f:
            json.dump(self.args.__dict__, f, indent=2)

        # Save results
        with open(osp.join('logs', self.folder_name, 'results.txt'), 'w') as f:
            f.write("Best val %g, corresponding test %g - best test: %g" %
                    (val_res.max(), test_res[idx_best], test_res.max()))

        return self.logger, self.model
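save_model above writes a plain dict; a matching (hypothetical) helper for restoring the best checkpoint would read the same keys back:

import os
import torch

def load_best(model, optimizer, folder_name, device='cpu'):
    # Hypothetical restore: the keys mirror those written by save_model above.
    path = os.path.join('logs', folder_name, 'best_model.pth')
    checkpoint = torch.load(path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    print('best val acc:', float(checkpoint['best_val_acc']))
    return checkpoint['epoch']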
Example #7
class Trainer:
    def __init__(self, args, device):
        self.args = args
        self.device = device
        model = model_factory.get_network(args.network)(
            jigsaw_classes=args.jigsaw_n_classes + 1, classes=args.n_classes)
        self.model = model.to(device)

        self.source_loader, self.val_loader = data_helper.get_train_dataloader(
            args, patches=model.is_patch_based())
        self.target_test_loaders = data_helper.get_jigsaw_test_dataloaders(
            args, patches=model.is_patch_based())
        # Evaluate on Validation & Test datasets
        self.evaluation_loaders = {
            "val": self.val_loader,
            "test": self.target_test_loaders
        }

        print("Dataset size: train %d, val %d" %
              (len(self.source_loader.dataset), len(self.val_loader.dataset)))

        self.optimizer, self.scheduler = get_optim_and_scheduler(
            model,
            args.epochs,
            args.learning_rate,
            args.train_all,
            nesterov=args.nesterov)
        self.jig_weight = args.jig_weight
        self.only_non_scrambled = args.classify_only_sane
        self.n_classes = args.n_classes
        if args.target in args.source:
            self.target_id = args.source.index(args.target)
            print("Target in source: %d" % self.target_id)
            print(args.source)
        else:
            self.target_id = None

    def _do_epoch(self):
        criterion = nn.CrossEntropyLoss()
        self.model.train()
        for it, ((data, jig_l, class_l),
                 d_idx) in enumerate(self.source_loader):
            data, jig_l, class_l, d_idx = data.to(self.device), jig_l.to(
                self.device), class_l.to(self.device), d_idx.to(self.device)

            self.optimizer.zero_grad()
            jigsaw_logit, class_logit = self.model(data)
            jigsaw_loss = criterion(jigsaw_logit, jig_l)

            if self.only_non_scrambled:  # classify only the unscrambled images
                if self.target_id is not None:
                    # image is unscrambled AND its domain is not the target domain
                    # (the target domain is never trained on; its images are
                    # only used for prediction)
                    idx = (jig_l == 0) & (d_idx != self.target_id)
                    class_loss = criterion(class_logit[idx], class_l[idx])
                else:
                    class_loss = criterion(class_logit[jig_l == 0],
                                           class_l[jig_l == 0])

            elif self.target_id is not None:
                # classify all images (scrambled included); the target domain
                # is only used for prediction
                class_loss = criterion(class_logit[d_idx != self.target_id],
                                       class_l[d_idx != self.target_id])
            else:  # classify all images (scrambled included)
                class_loss = criterion(class_logit, class_l)
            _, cls_pred = class_logit.max(dim=1)
            _, jig_pred = jigsaw_logit.max(dim=1)
            loss = class_loss + jigsaw_loss * self.jig_weight  # + 0.1 * domain_loss

            loss.backward()
            self.optimizer.step()

            self.logger.log(
                it, len(self.source_loader), {
                    "jigsaw": jigsaw_loss.item(),
                    "class": class_loss.item()
                }, {
                    "jigsaw": torch.sum(jig_pred == jig_l.data).item(),
                    "class": torch.sum(cls_pred == class_l.data).item(),
                }, data.shape[0])
            # drop the references so the tensors can be freed
            del loss, class_loss, jigsaw_loss, jigsaw_logit, class_logit

        self.model.eval()
        with torch.no_grad():
            for phase, loader in self.evaluation_loaders.items():
                if phase == 'test':
                    belonged_dataset = data_helper.get_belonged_dataset(
                        self.args.source[0])
                    target_domains = [
                        item for item in belonged_dataset
                        if item not in self.args.source
                    ]

                    acc_sum = 0.0
                    for didx in range(len(loader)):
                        dkey = phase + '-' + target_domains[didx]

                        test_loader = loader[didx]
                        test_total = len(test_loader.dataset)
                        jigsaw_correct, class_correct = self.do_test(
                            test_loader)

                        jigsaw_acc = float(jigsaw_correct) / test_total
                        class_acc = float(class_correct) / test_total

                        self.logger.log_test(dkey, {"class": class_acc})
                        if dkey not in self.results.keys():
                            self.results[dkey] = torch.zeros(self.args.epochs)
                        self.results[dkey][self.current_epoch] = class_acc
                        acc_sum += class_acc
                    self.logger.log_test(phase,
                                         {"class": acc_sum / len(loader)})
                    self.results[phase][
                        self.current_epoch] = acc_sum / len(loader)
                else:
                    total = len(loader.dataset)
                    if loader.dataset.isMulti():
                        jigsaw_correct, class_correct, single_acc = self.do_test_multi(
                            loader)
                        print("Single vs multi: %g %g" %
                              (float(single_acc) / total,
                               float(class_correct) / total))
                    else:
                        jigsaw_correct, class_correct = self.do_test(loader)

                    jigsaw_acc = float(jigsaw_correct) / total
                    class_acc = float(class_correct) / total
                    self.logger.log_test(phase, {
                        "jigsaw": jigsaw_acc,
                        "class": class_acc
                    })
                    self.results[phase][self.current_epoch] = class_acc

    def do_test(self, loader):
        jigsaw_correct = 0
        class_correct = 0
        domain_correct = 0
        for it, ((data, jig_l, class_l), _) in enumerate(loader):
            data, jig_l, class_l = data.to(self.device), jig_l.to(
                self.device), class_l.to(self.device)
            jigsaw_logit, class_logit = self.model(data)
            _, cls_pred = class_logit.max(dim=1)
            _, jig_pred = jigsaw_logit.max(dim=1)
            class_correct += torch.sum(cls_pred == class_l.data)
            jigsaw_correct += torch.sum(jig_pred == jig_l.data)
        return jigsaw_correct, class_correct

    def do_test_multi(self, loader):
        jigsaw_correct = 0
        class_correct = 0
        single_correct = 0
        for it, ((data, jig_l, class_l), d_idx) in enumerate(loader):
            data, jig_l, class_l = data.to(self.device), jig_l.to(
                self.device), class_l.to(self.device)
            n_permutations = data.shape[1]
            class_logits = torch.zeros(n_permutations, data.shape[0],
                                       self.n_classes).to(self.device)
            for k in range(n_permutations):
                class_logits[k] = F.softmax(self.model(data[:, k])[1], dim=1)
            class_logits[0] *= 4 * n_permutations  # weight the original image more heavily
            class_logit = class_logits.mean(0)
            _, cls_pred = class_logit.max(dim=1)
            jigsaw_logit, single_logit = self.model(data[:, 0])
            _, jig_pred = jigsaw_logit.max(dim=1)
            _, single_pred = single_logit.max(dim=1)
            single_correct += torch.sum(single_pred == class_l.data)
            class_correct += torch.sum(cls_pred == class_l.data)
            jigsaw_correct += torch.sum(jig_pred == jig_l.data[:, 0])
        return jigsaw_correct, class_correct, single_correct

    def do_training(self):
        self.logger = Logger(self.args,
                             update_frequency=30)  # , "domain", "lambda"
        self.results = {
            "val": torch.zeros(self.args.epochs),
            "test": torch.zeros(self.args.epochs)
        }
        for self.current_epoch in range(self.args.epochs):
            self.scheduler.step()
            self.logger.new_epoch(self.scheduler.get_lr())
            self._do_epoch()
        val_res = self.results["val"]
        test_res = self.results["test"]
        idx_val_best = val_res.argmax()
        idx_test_best = test_res.argmax()
        print("Best test acc: %g in epoch: %d" %
              (test_res.max(), idx_test_best + 1))
        # save the test accuracy at the best-validation epoch plus the best test accuracy
        self.logger.save_best(test_res[idx_val_best].item(),
                              test_res.max().item())
        return self.logger, self.model
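Every Trainer in this listing delegates optimizer construction to get_optim_and_scheduler, which is defined elsewhere in the repository (some variants also pass extra flags such as adam). A minimal sketch of what it plausibly returns, assuming the usual SGD + StepLR recipe; the 80%-of-epochs step size and the list handling are assumptions, not code from this listing:

from itertools import chain

from torch import optim


def get_optim_and_scheduler(network, epochs, lr, train_all, nesterov=False):
    # `network` may be a single module or a list of modules (Ejemplo 8 passes a list);
    # train_all (fine-tune backbone vs. heads only) is ignored in this sketch
    modules = network if isinstance(network, (list, tuple)) else [network]
    params = chain.from_iterable(m.parameters() for m in modules)
    optimizer = optim.SGD(params, lr=lr, momentum=0.9,
                          weight_decay=5e-4, nesterov=nesterov)
    # a single learning-rate decay step late in training
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=int(epochs * 0.8))
    return optimizer, scheduler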
Ejemplo n.º 8
class Trainer:
    def __init__(self, args, device):
        self.args = args
        self.device = device
        if args.network == 'resnet18':
            model = resnet18(pretrained=True, classes=args.n_classes)
        elif args.network == 'resnet50':
            model = resnet50(pretrained=True, classes=args.n_classes)
        else:
            model = resnet18(pretrained=True, classes=args.n_classes)
        self.model = model.to(device)
        self.D_model = IntraClsInfoMax(alpha=args.alpha,
                                       beta=args.beta,
                                       gamma=args.gamma).to(device)
        # print(self.model)
        # print(self.D_model)

        self.source_loader, self.val_loader = data_helper.get_train_dataloader(
            args, patches=model.is_patch_based())
        self.target_loader = data_helper.get_val_dataloader(
            args, patches=model.is_patch_based())
        self.test_loaders = {
            "val": self.val_loader,
            "test": self.target_loader
        }
        self.len_dataloader = len(self.source_loader)
        print("Dataset size: train %d, val %d, test %d" %
              (len(self.source_loader.dataset), len(
                  self.val_loader.dataset), len(self.target_loader.dataset)))

        self.optimizer, self.scheduler = get_optim_and_scheduler(
            [self.model, self.D_model.global_d, self.D_model.local_d],
            args.epochs,
            args.learning_rate,
            args.train_all,
            nesterov=args.nesterov)
        self.dis_optimizer, self.dis_scheduler = get_optim_and_scheduler(
            [self.D_model.prior_d],
            args.epochs,
            args.learning_rate,
            args.train_all,
            nesterov=args.nesterov)  # the discriminator could use a smaller lr, e.g. args.learning_rate * 1e-3
        self.n_classes = args.n_classes
        if args.target in args.source:
            self.target_id = args.source.index(args.target)
            print("Target in source: %d" % self.target_id)
            print(args.source)
        else:
            self.target_id = None
        self.max_test_acc = 0.0
        self.logger = Logger(self.args, update_frequency=30)
        self.results = {
            "val": torch.zeros(self.args.epochs),
            "test": torch.zeros(self.args.epochs)
        }

    def _do_epoch(self):

        criterion = nn.CrossEntropyLoss()
        self.model.train()
        self.D_model.train()
        for it, ((data, jig_l, class_l),
                 d_idx) in enumerate(self.source_loader):
            data, jig_l, class_l, d_idx = data.to(self.device), jig_l.to(
                self.device), class_l.to(self.device), d_idx.to(self.device)

            self.optimizer.zero_grad()

            data_flip = torch.flip(data, (3, )).detach().clone()
            data = torch.cat((data, data_flip))
            class_l = torch.cat((class_l, class_l))

            y, M = self.model(data, feature_flag=True)

            # Classification Loss
            class_logit = self.model.class_classifier(y)
            class_loss = criterion(class_logit, class_l)

            # G loss - DIM Loss - P_loss
            M_prime = torch.cat(
                (M[1:], M[0].unsqueeze(0)),
                dim=0)  # roll the batch by one to build mismatched (negative) pairs
            class_prime = torch.cat((class_l[1:], class_l[0].unsqueeze(0)),
                                    dim=0)
            class_ll = (class_l, class_prime)

            DIM_loss = self.D_model(y, M, M_prime, class_ll)
            P_loss = self.D_model.prior_loss(y)

            DIM_loss = DIM_loss - P_loss
            # DIM_loss=self.beta*(DIM_loss-P_loss)
            loss = class_loss + DIM_loss
            loss.backward()
            self.optimizer.step()

            self.dis_optimizer.zero_grad()
            P_loss = self.D_model.prior_loss(y.detach())
            P_loss.backward()
            self.dis_optimizer.step()

            # Prediction
            _, cls_pred = class_logit.max(dim=1)

            losses = {
                'class': class_loss.detach().item(),
                'DIM': DIM_loss.detach().item(),
                'P_loss': P_loss.detach().item()
            }
            self.logger.log(
                it, len(self.source_loader), losses, {
                    "class": torch.sum(cls_pred == class_l.data).item(),
                }, data.shape[0])
            del loss, class_loss, class_logit, DIM_loss

        self.model.eval()
        with torch.no_grad():
            for phase, loader in self.test_loaders.items():
                total = len(loader.dataset)

                class_correct = self.do_test(loader)

                class_acc = float(class_correct) / total
                self.logger.log_test(phase, {"class": class_acc})
                self.results[phase][self.current_epoch] = class_acc
                if phase == 'test' and class_acc > self.max_test_acc:
                    self.max_test_acc = class_acc  # track the best test accuracy so far
                    torch.save(
                        self.model.state_dict(),
                        os.path.join(self.logger.log_path,
                                     'best_{}.pth'.format(phase)))

    def do_test(self, loader):
        class_correct = 0
        for it, ((data, nouse, class_l), _) in enumerate(loader):
            data, nouse, class_l = data.to(self.device), nouse.to(
                self.device), class_l.to(self.device)

            class_logit = self.model(data, feature_flag=False)
            _, cls_pred = class_logit.max(dim=1)

            class_correct += torch.sum(cls_pred == class_l.data)

        return class_correct

    def do_training(self):
        for self.current_epoch in range(self.args.epochs):
            self.scheduler.step()
            self.dis_scheduler.step()
            self.logger.new_epoch(
                [*self.scheduler.get_lr(), *self.dis_scheduler.get_lr()])
            self._do_epoch()  # use self.current_epoch
        val_res = self.results["val"]
        test_res = self.results["test"]
        idx_best = val_res.argmax()
        print(
            "Best val %g, corresponding test %g - best test: %g, best epoch: %g"
            % (val_res.max(), test_res[idx_best], test_res.max(), idx_best))
        self.logger.save_best(test_res[idx_best], test_res.max())
        return self.logger, self.model
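IntraClsInfoMax and its prior_loss are defined outside this listing. The loop above subtracts P_loss from the encoder objective (so the encoder tries to fool the prior discriminator) and separately minimizes it for prior_d on detached features. A rough sketch of a Deep InfoMax-style prior-matching loss under those assumptions; the uniform prior and the exact form are guesses, not the repository's code:

import torch


def prior_loss_sketch(prior_d, y, eps=1e-6):
    # discriminator loss for prior matching: samples drawn from the prior
    # should score high, encoder outputs y should score low
    prior = torch.rand_like(y)  # uniform prior on [0, 1), as in Deep InfoMax
    term_a = torch.log(prior_d(prior) + eps).mean()
    term_b = torch.log(1.0 - prior_d(y) + eps).mean()
    return -(term_a + term_b)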
Ejemplo n.º 9
class Trainer:
    def __init__(self, args, device):
        self.args = args
        self.device = device

        model = model_factory.get_network(args.network)(classes=args.n_classes)
        self.model = model.to(device)

        self.source_loader, self.val_loader = data_helper.get_train_dataloader(
            args)
        self.target_loader = data_helper.get_val_dataloader(args)

        self.test_loaders = {
            "val": self.val_loader,
            "test": self.target_loader
        }
        self.len_dataloader = len(self.source_loader)
        print("Dataset size: train %d, val %d, test %d" %
              (len(self.source_loader.dataset), len(
                  self.val_loader.dataset), len(self.target_loader.dataset)))

        self.optimizer, self.scheduler = get_optim_and_scheduler(
            model, args.epochs, args.learning_rate, args.train_all)

        self.n_classes = args.n_classes

    def _do_epoch(self):
        criterion = nn.CrossEntropyLoss()
        self.model.train()
        for it, (data, class_l) in enumerate(self.source_loader):

            data, class_l = data.to(self.device), class_l.to(self.device)

            self.optimizer.zero_grad()

            class_logit = self.model(data)
            class_loss = criterion(class_logit, class_l)
            _, cls_pred = class_logit.max(dim=1)

            loss = class_loss

            loss.backward()

            self.optimizer.step()

            self.logger.log(it, len(self.source_loader), {
                "Class Loss": class_loss.item()
            }, {"Class Accuracy":
                torch.sum(cls_pred == class_l.data).item()}, data.shape[0])
            del loss, class_loss, class_logit

        self.model.eval()
        with torch.no_grad():
            for phase, loader in self.test_loaders.items():
                total = len(loader.dataset)
                class_correct = self.do_test(loader)
                class_acc = float(class_correct) / total
                self.logger.log_test(phase,
                                     {"Classification Accuracy": class_acc})
                self.results[phase][self.current_epoch] = class_acc

    def do_test(self, loader):
        class_correct = 0
        for it, (data, class_l) in enumerate(loader):
            data, class_l = data.to(self.device), class_l.to(self.device)
            class_logit = self.model(data)
            _, cls_pred = class_logit.max(dim=1)
            class_correct += torch.sum(cls_pred == class_l.data)
        return class_correct

    def do_training(self):
        self.logger = Logger(self.args, update_frequency=30)
        self.results = {
            "val": torch.zeros(self.args.epochs),
            "test": torch.zeros(self.args.epochs)
        }

        for self.current_epoch in range(self.args.epochs):
            self.logger.new_epoch(self.scheduler.get_lr())
            self._do_epoch()
            self.scheduler.step()

        val_res = self.results["val"]
        test_res = self.results["test"]
        idx_best = val_res.argmax()
        print("Best val %g, corresponding test %g - best test: %g" %
              (val_res.max(), test_res[idx_best], test_res.max()))
        self.logger.save_best(test_res[idx_best], test_res.max())
        return self.logger, self.model
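All of these Trainer variants share the same entry point: build args, pick a device, call do_training. A hypothetical driver under that assumption; only argument names that actually appear in the constructors above are used, and data_helper will expect further fields (dataset paths, source/target domain lists) that are omitted here:

import argparse

import torch


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--network", default="resnet18")
    parser.add_argument("--n_classes", type=int, default=7)
    parser.add_argument("--epochs", type=int, default=30)
    parser.add_argument("--learning_rate", type=float, default=1e-3)
    parser.add_argument("--train_all", action="store_true")
    args = parser.parse_args()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    trainer = Trainer(args, device)
    logger, model = trainer.do_training()


if __name__ == "__main__":
    main()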
Ejemplo n.º 10
class Trainer:
    def __init__(self, args, device):
        self.args = args
        self.device = device

        model = model_factory.get_network(args.network)(classes=args.n_classes, jigsaw_classes=31, rotation_classes=4, odd_classes=9)
        self.model = model.to(device)

        self.source_loader, self.val_loader = data_helper.get_train_dataloader(args)
        self.target_loader = data_helper.get_val_dataloader(args)

        self.test_loaders = {"val": self.val_loader, "test": self.target_loader}
        self.len_dataloader = len(self.source_loader)
        print("Dataset size: train %d, val %d, test %d" % (len(self.source_loader.dataset), len(self.val_loader.dataset), len(self.target_loader.dataset)))

        self.optimizer, self.scheduler = get_optim_and_scheduler(model, args.epochs, args.learning_rate, args.train_all)

        self.n_classes = args.n_classes

        self.nTasks = 2
        if args.rotation == True:
            self.nTasks += 1
        if args.odd_one_out == True:
            self.nTasks += 1

        print("N of tasks: " + str(self.nTasks))

    def _do_epoch(self):
        criterion = nn.CrossEntropyLoss()
        self.model.train()
        for it, (data, class_l, jigsaw_label, task_type) in enumerate(self.source_loader):
            
            rotation_loss = 0
            odd_one_out_loss = 0
            rotation_pred = 0
            odd_pred = 0

            data, class_l, jigsaw_label, task_type = data.to(self.device), class_l.to(self.device), jigsaw_label.to(self.device), task_type.to(self.device)

            self.optimizer.zero_grad()

            class_logit, jigsaw_logit, rotation_logit, odd_logit = self.model(data)
            class_loss = criterion(class_logit[task_type==0], class_l[task_type==0])
            jigsaw_loss = criterion(jigsaw_logit[(task_type==0) | (task_type==1)], jigsaw_label[(task_type==0) | (task_type==1)])

            _, cls_pred = class_logit.max(dim=1)
            _, jigsaw_pred = jigsaw_logit.max(dim=1)

            if self.args.rotation == True:
                #Rotation loss if the task is classification of "rotation"
                rotation_loss = criterion(rotation_logit[(task_type==0) | (task_type==2)], jigsaw_label[(task_type==0) | (task_type==2)])

                _, rotation_pred = rotation_logit.max(dim=1)

            if self.args.odd_one_out == True:
                #Odd-one-out loss if the task is classification of "odd one out"
                odd_one_out_loss = criterion(odd_logit[(task_type==0) | (task_type==3)], jigsaw_label[(task_type==0) | (task_type==3)])

                _, odd_pred = odd_logit.max(dim=1)

            jig_loss = jigsaw_loss * self.args.jigsaw_alpha
            rot_loss = rotation_loss * self.args.beta_rotated
            odd_loss = odd_one_out_loss * self.args.beta_odd
            loss = class_loss + jig_loss + rot_loss + odd_loss

            loss.backward()

            self.optimizer.step()

            self.logger.log(it, len(self.source_loader),
                            {"Class Loss ": class_loss.item()},
                            {"Class Accuracy ": torch.sum(cls_pred == class_l.data).item()},
                            data.shape[0])

            self.logger.log(it, len(self.source_loader),
                            {"Jigsaw Loss ": jigsaw_loss.item()},
                            {"Jigsaw Accuracy ": torch.sum(jigsaw_pred == jigsaw_label.data).item()},
                            data.shape[0])

            if self.args.rotation == True:
                self.logger.log(it, len(self.source_loader),
                                {"Rotation Loss ": rotation_loss.item()},
                                {"Rotation Accuracy ": torch.sum(rotation_pred == jigsaw_label.data).item()},
                                data.shape[0])

            if self.args.odd_one_out == True:
                self.logger.log(it, len(self.source_loader),
                                {"Odd one out Loss ": odd_loss.item()},
                                {"Odd one out Accuracy ": torch.sum(odd_pred == jigsaw_label.data).item()},
                                data.shape[0])

            del loss, class_loss, class_logit, jigsaw_loss, jigsaw_logit
            del rotation_loss, odd_one_out_loss, jig_loss, rot_loss, odd_loss, rotation_pred, odd_pred

        self.model.eval()
        with torch.no_grad():
            for phase, loader in self.test_loaders.items():
                total = len(loader.dataset)

                class_correct, jigsaw_correct, rotation_correct, odd_correct = self.do_test(loader)
                
                class_acc = float(class_correct) / total
                jigsaw_acc = float(jigsaw_correct) / total

                test_log = {"Classification Accuracy": class_acc,
                            "Jigsaw Accuracy": jigsaw_acc}
                if self.args.rotation == True:
                    test_log["Rotation Accuracy"] = float(rotation_correct) / total
                if self.args.odd_one_out == True:
                    test_log["Odd one out Accuracy"] = float(odd_correct) / total

                self.logger.log_test(phase, test_log)
                self.results[phase][self.current_epoch] = class_acc

    def do_test(self, loader):
        class_correct = 0
        jigsaw_correct = 0
        rotation_correct = 0
        odd_correct = 0

        for it, (data, class_l, jigsaw_label, task_type) in enumerate(loader):
            data, class_l, jigsaw_label, task_type = data.to(self.device), class_l.to(self.device), jigsaw_label.to(self.device), task_type.to(self.device)
            
            class_logit, jigsaw_logit, rotation_logit, odd_logit = self.model(data)

            _, cls_pred = class_logit.max(dim=1)
            _, jigsaw_pred = jigsaw_logit.max(dim=1)

            if self.args.rotation == True:
                _, rotation_pred = rotation_logit.max(dim=1)
                rotation_correct += torch.sum(rotation_pred == jigsaw_label.data)
    
            if self.args.odd_one_out == True:
                _, odd_pred = odd_logit.max(dim=1)
                odd_correct += torch.sum(odd_pred == jigsaw_label.data)

            class_correct += torch.sum(cls_pred == class_l.data)
            jigsaw_correct += torch.sum(jigsaw_pred == jigsaw_label.data)

        return class_correct, jigsaw_correct, rotation_correct, odd_correct

    def do_training(self):
        self.logger = Logger(self.args, update_frequency=30)
        self.results = {"val": torch.zeros(self.args.epochs), "test": torch.zeros(self.args.epochs)}

        for self.current_epoch in range(self.args.epochs):
            self.logger.new_epoch(self.scheduler.get_lr())
            self._do_epoch()
            self.scheduler.step()

        val_res = self.results["val"]
        test_res = self.results["test"]
        idx_best = val_res.argmax()
        print("Best val %g, corresponding test %g - best test: %g" % (val_res.max(), test_res[idx_best], test_res.max()))
        self.logger.save_best(test_res[idx_best], test_res.max())
        return self.logger, self.model
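The loaders here yield (data, class_l, jigsaw_label, task_type), with task_type encoding which self-supervised transformation was applied (0 = untouched, 1 = jigsaw, 2 = rotation, 3 = odd-one-out, judging from the masks above) and jigsaw_label carrying the label of whichever task is active. As an illustration of how the dataset side might produce the rotation variant; this helper is a sketch, not part of the listing:

import random

import torch


def make_rotation_sample(img):
    # img: (C, H, W) tensor; rotate by a random multiple of 90 degrees
    label = random.randint(0, 3)  # 0 -> 0, 1 -> 90, 2 -> 180, 3 -> 270 degrees
    rotated = torch.rot90(img, k=label, dims=(1, 2))
    return rotated, label  # label would be returned as jigsaw_label, with task_type 2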
Ejemplo n.º 11
class Trainer:
    def __init__(self, args, device):
        self.args = args
        self.device = device
        model = model_factory.get_network(args.network)(jigsaw_classes=args.jigsaw_n_classes + 1, classes=args.n_classes)
        self.model = model.to(device)
        # print(self.model)
        self.source_loader, self.val_loader = data_helper.get_train_dataloader(args, patches=model.is_patch_based())
        self.target_loader = data_helper.get_val_dataloader(args, patches=model.is_patch_based())
        self.test_loaders = {"val": self.val_loader, "test": self.target_loader}
        self.len_dataloader = len(self.source_loader)
        print("Dataset size: train %d, val %d, test %d" % (len(self.source_loader.dataset), len(self.val_loader.dataset), len(self.target_loader.dataset)))
        self.optimizer, self.scheduler = get_optim_and_scheduler(model, args.epochs, args.learning_rate, args.train_all, nesterov=args.nesterov, adam=args.adam)
        self.jig_weight = args.jig_weight
        self.rex_weight_class = args.rex_weight_class
        self.irm_weight_class = args.irm_weight_class
        self.rex_weight_jigsaw = args.rex_weight_jigsaw
        self.irm_weight_jigsaw = args.irm_weight_jigsaw
        self.only_non_scrambled = args.classify_only_sane
        self.n_classes = args.n_classes
        if args.target in args.source:
            self.target_id = args.source.index(args.target)
            print("Target in source: %d" % self.target_id)
            print(args.source)
        else:
            self.target_id = None

    def _do_epoch(self):
        criterion = nn.CrossEntropyLoss()
        self.model.train()
        for it, ((data, jig_l, class_l), d_idx) in enumerate(self.source_loader):
            data, jig_l, class_l, d_idx = data.to(self.device), jig_l.to(self.device), class_l.to(self.device), d_idx.to(self.device)
            self.optimizer.zero_grad()

            jigsaw_logit, class_logit = self.model(data)  # , lambda_val=lambda_val)
            # per-domain jigsaw losses and IRM penalties (assumes exactly three source domains)
            j1 = criterion(jigsaw_logit[d_idx == 0], jig_l[d_idx == 0])
            j1_irm = compute_irm_penalty(jigsaw_logit[d_idx == 0], jig_l[d_idx == 0], criterion)
            j2 = criterion(jigsaw_logit[d_idx == 1], jig_l[d_idx == 1])
            j2_irm = compute_irm_penalty(jigsaw_logit[d_idx == 1], jig_l[d_idx == 1], criterion)
            j3 = criterion(jigsaw_logit[d_idx == 2], jig_l[d_idx == 2])
            j3_irm = compute_irm_penalty(jigsaw_logit[d_idx == 2], jig_l[d_idx == 2], criterion)
            rex_jigsaw = compute_rex_penalty(j1, j2, j3)
            jigsaw_loss = j1 + j2 + j3
            irm_jigsaw = (j1_irm + j2_irm + j3_irm) / 3
            if self.only_non_scrambled:
                if self.target_id is not None:
                    idx = (jig_l == 0) & (d_idx != self.target_id)
                    class_loss = criterion(class_logit[idx], class_l[idx])
                    rex_class = torch.Tensor([0.]).cuda()
                    irm_class = torch.Tensor([0.]).cuda()
                else:
                    class_loss_1 = criterion(class_logit[(jig_l == 0) & (d_idx == 0)], class_l[(jig_l == 0) & (d_idx == 0)])
                    class_irm_1 = compute_irm_penalty(class_logit[(jig_l == 0) & (d_idx == 0)], class_l[(jig_l == 0) & (d_idx == 0)], criterion)
                    class_loss_2 = criterion(class_logit[(jig_l == 0) & (d_idx == 1)], class_l[(jig_l == 0) & (d_idx == 1)])
                    class_irm_2 = compute_irm_penalty(class_logit[(jig_l == 0) & (d_idx == 1)], class_l[(jig_l == 0) & (d_idx == 1)], criterion)
                    class_loss_3 = criterion(class_logit[(jig_l == 0) & (d_idx == 2)], class_l[(jig_l == 0) & (d_idx == 2)])
                    class_irm_3 = compute_irm_penalty(class_logit[(jig_l == 0) & (d_idx == 2)], class_l[(jig_l == 0) & (d_idx == 2)], criterion)
                    class_loss = class_loss_1 + class_loss_2 + class_loss_3
                    irm_class = (class_irm_1 + class_irm_2 + class_irm_3)/3
                    rex_class = compute_rex_penalty(class_loss_1, class_loss_2, class_loss_3)

            elif self.target_id is not None:
                class_loss = criterion(class_logit[d_idx != self.target_id], class_l[d_idx != self.target_id])
                rex_class = torch.Tensor([0.]).cuda()
                irm_class = torch.Tensor([0.]).cuda()
            else:
                class_loss_1 = criterion(class_logit[(d_idx == 0)], class_l[(d_idx == 0)])
                class_irm_1 = compute_irm_penalty(class_logit[(d_idx == 0)], class_l[(d_idx == 0)], criterion)
                class_loss_2 = criterion(class_logit[(d_idx == 1)], class_l[(d_idx == 1)])
                class_irm_2 = compute_irm_penalty(class_logit[(d_idx == 1)], class_l[(d_idx == 1)], criterion)
                class_loss_3 = criterion(class_logit[(d_idx == 2)], class_l[(d_idx == 2)])
                class_irm_3 = compute_irm_penalty(class_logit[(d_idx == 2)], class_l[(d_idx == 2)], criterion)
                class_loss = class_loss_1 + class_loss_2 + class_loss_3
                irm_class = (class_irm_1 + class_irm_2 + class_irm_3)/3
                rex_class = compute_rex_penalty(class_loss_1, class_loss_2, class_loss_3)
            _, cls_pred = class_logit.max(dim=1)
            _, jig_pred = jigsaw_logit.max(dim=1)
            # _, domain_pred = domain_logit.max(dim=1)


            rex_loss = self.rex_weight_class * rex_class + self.rex_weight_jigsaw * self.jig_weight * rex_jigsaw
            irm_loss = self.irm_weight_class * irm_class + self.irm_weight_jigsaw * self.jig_weight * irm_jigsaw
            if self.rex_weight_class == 0. and self.rex_weight_jigsaw == 0. and self.irm_weight_jigsaw == 0. and self.irm_weight_class == 0.:
                loss = class_loss + jigsaw_loss * self.jig_weight
            elif self.irm_weight_jigsaw == 0. and self.irm_weight_class == 0.:
                loss = class_loss + jigsaw_loss * self.jig_weight + rex_loss
            elif self.rex_weight_class == 0. and self.rex_weight_jigsaw == 0.:
                loss = class_loss + jigsaw_loss * self.jig_weight + irm_loss
            else:
                # both penalty families active
                loss = class_loss + jigsaw_loss * self.jig_weight + rex_loss + irm_loss

            loss.backward()
            self.optimizer.step()

            self.logger.log(it, len(self.source_loader),
                            {"jigsaw": jigsaw_loss.item(), "class": class_loss.item(),
                             "rex loss class": rex_class.item(), "rex loss jigsaw": rex_jigsaw.item(),
                             "rex total": rex_loss.item(), "irm loss class": irm_class.item(),
                             "irm loss jigsaw": irm_jigsaw.item(), "irm total": irm_loss.item()},
                            # ,"lambda": lambda_val},
                            {"jigsaw": torch.sum(jig_pred == jig_l.data).item(),
                             "class": torch.sum(cls_pred == class_l.data).item(),
                             # "domain": torch.sum(domain_pred == d_idx.data).item()
                             },
                            data.shape[0])
            del loss, class_loss, jigsaw_loss, rex_loss, jigsaw_logit, class_logit

        self.model.eval()
        with torch.no_grad():
            for phase, loader in self.test_loaders.items():
                total = len(loader.dataset)
                if loader.dataset.isMulti():
                    jigsaw_correct, class_correct, single_acc = self.do_test_multi(loader)
                    print("Single vs multi: %g %g" % (float(single_acc) / total, float(class_correct) / total))
                else:
                    jigsaw_correct, class_correct = self.do_test(loader)
                jigsaw_acc = float(jigsaw_correct) / total
                class_acc = float(class_correct) / total
                self.logger.log_test(phase, {"jigsaw": jigsaw_acc, "class": class_acc})
                self.results[phase][self.current_epoch] = class_acc

    def do_test(self, loader):
        jigsaw_correct = 0
        class_correct = 0
        for it, ((data, jig_l, class_l), _) in enumerate(loader):
            data, jig_l, class_l = data.to(self.device), jig_l.to(self.device), class_l.to(self.device)
            jigsaw_logit, class_logit = self.model(data)
            _, cls_pred = class_logit.max(dim=1)
            _, jig_pred = jigsaw_logit.max(dim=1)
            class_correct += torch.sum(cls_pred == class_l.data)
            jigsaw_correct += torch.sum(jig_pred == jig_l.data)
        return jigsaw_correct, class_correct

    def do_test_multi(self, loader):
        jigsaw_correct = 0
        class_correct = 0
        single_correct = 0
        for it, ((data, jig_l, class_l), d_idx) in enumerate(loader):
            data, jig_l, class_l = data.to(self.device), jig_l.to(self.device), class_l.to(self.device)
            n_permutations = data.shape[1]
            class_logits = torch.zeros(n_permutations, data.shape[0], self.n_classes).to(self.device)
            for k in range(n_permutations):
                class_logits[k] = F.softmax(self.model(data[:, k])[1], dim=1)
            class_logits[0] *= 4 * n_permutations  # weight the original (unshuffled) image more heavily
            class_logit = class_logits.mean(0)
            _, cls_pred = class_logit.max(dim=1)
            jigsaw_logit, single_logit = self.model(data[:, 0])
            _, jig_pred = jigsaw_logit.max(dim=1)
            _, single_pred = single_logit.max(dim=1)
            single_correct += torch.sum(single_pred == class_l.data)
            class_correct += torch.sum(cls_pred == class_l.data)
            jigsaw_correct += torch.sum(jig_pred == jig_l.data[:, 0])
        return jigsaw_correct, class_correct, single_correct

    def do_training(self):
        self.logger = Logger(self.args, update_frequency=30)  # , "domain", "lambda"
        self.results = {"val": torch.zeros(self.args.epochs), "test": torch.zeros(self.args.epochs)}
        for self.current_epoch in range(self.args.epochs):
            self.scheduler.step()
            self.logger.new_epoch(self.scheduler.get_lr())
            self._do_epoch()
        val_res = self.results["val"]
        test_res = self.results["test"]
        idx_best = val_res.argmax()
        #print("Best val %g, corresponding test %g - best test: %g" % (val_res.max(), test_res[idx_best], test_res.max()))
        name = (self.args.prefix + "_" +
                str(self.args.source[0]) + str(self.args.source[1]) + str(self.args.source[2]) +
                "_" + str(self.args.target) +
                "_eps%d_bs%d_lr%g_class%d_jigClass%d_rexWeightClass%g_rexWeightJig%g_irmWeightClass%g_irmWeightJig%g_jigWeight%g" %
                (self.args.epochs, self.args.batch_size, self.args.learning_rate,
                 self.args.n_classes, self.args.jigsaw_n_classes,
                 self.args.rex_weight_class, self.args.rex_weight_jigsaw,
                 self.args.irm_weight_class, self.args.irm_weight_jigsaw,
                 self.args.jig_weight))
        with open('./result_summary_txt/' + name + '.txt', 'a+') as f:
            f.write('best validation accuracy: ' + str(val_res.max().item()) +
                    ' test acc at best val acc: ' + str(test_res[idx_best].item()) +
                    ' max test: ' + str(test_res.max().item()))
        self.logger.save_best(test_res[idx_best], test_res.max())
        return self.logger, self.model
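compute_irm_penalty and compute_rex_penalty are imported from elsewhere in the repository. Minimal sketches consistent with the IRMv1 penalty (Arjovsky et al., "Invariant Risk Minimization") and the V-REx penalty (Krueger et al., "Out-of-Distribution Generalization via Risk Extrapolation"), which is what the per-domain bookkeeping above suggests; the repository's exact definitions may differ:

import torch


def compute_irm_penalty(logits, labels, criterion):
    # IRMv1: squared gradient of the environment risk w.r.t. a dummy scale of 1.0
    scale = torch.ones(1, device=logits.device, requires_grad=True)
    loss = criterion(logits * scale, labels)
    grad = torch.autograd.grad(loss, [scale], create_graph=True)[0]
    return (grad ** 2).sum()


def compute_rex_penalty(*env_losses):
    # V-REx: variance of the per-environment risks
    losses = torch.stack(env_losses)
    return ((losses - losses.mean()) ** 2).mean()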
Ejemplo n.º 12
class Trainer:
    def __init__(self, args, device):
        self.args = args
        self.device = device
        model = model_factory.get_network(args.network)(classes=args.n_classes)
        self.model = model.to(device)
        # print(self.model)
        self.source_loader, self.val_loader = data_helper.get_train_dataloader(args, patches=model.is_patch_based())
        self.target_loader = data_helper.get_val_dataloader(args, patches=model.is_patch_based())
        self.test_loaders = {"val": self.val_loader, "test": self.target_loader}
        self.len_dataloader = len(self.source_loader)
        print("Dataset size: train %d, val %d, test %d" % (len(self.source_loader.dataset), len(self.val_loader.dataset), len(self.target_loader.dataset)))
        self.optimizer, self.scheduler = get_optim_and_scheduler(model, args.epochs, args.learning_rate, args.train_all, nesterov=args.nesterov)
        self.n_classes = args.n_classes
        if args.target in args.source:
            self.target_id = args.source.index(args.target)
            print("Target in source: %d" % self.target_id)
            print(args.source)
        else:
            self.target_id = None

    def _do_epoch(self):
        criterion = nn.CrossEntropyLoss()
        self.model.train()
        for it, ((data, jig_l, class_l), d_idx) in enumerate(self.source_loader):
            NUM_DOMAINS = 3  # assumes three source domains
            oh_dids = torch.tensor(one_hot(d_idx, NUM_DOMAINS),
                                   dtype=torch.float, device=self.device)
            
            data, jig_l, class_l, d_idx = data.to(self.device), jig_l.to(self.device), class_l.to(self.device), d_idx.to(self.device)

            self.optimizer.zero_grad()

            specific_logit, class_logit = self.model(data, oh_dids) 
            specific_loss = criterion(specific_logit, class_l)

            class_loss = criterion(class_logit, class_l)
            _, cls_pred = class_logit.max(dim=1)

            sms = self.model.sms
            K = 2
            diag_tensor = torch.stack([torch.eye(K) for _ in range(self.n_classes)],
                                      dim=0).to(self.device)
            cps = torch.stack([torch.matmul(sms[:, :, c], torch.transpose(sms[:, :, c], 0, 1))
                               for c in range(self.n_classes)], dim=0)
            # orthogonality penalty pushing the class-specific components apart
            if self.args.network.startswith('caffenet'):
                orth_loss = torch.mean((1 - diag_tensor) * (cps - diag_tensor) ** 2)
            else:
                orth_loss = torch.mean((cps - diag_tensor) ** 2)

            loss = class_loss + specific_loss + orth_loss
            
            loss.backward()
            self.optimizer.step()

            self.logger.log(it, len(self.source_loader),
                            {"specific": specific_loss.item(), "class": class_loss.item() 
                             },
                            {"class": torch.sum(cls_pred == class_l.data).item()
                            },
                            data.shape[0])
            del loss, class_loss, specific_loss, specific_logit, class_logit

        self.model.eval()
        with torch.no_grad():
            for phase, loader in self.test_loaders.items():
                total = len(loader.dataset)
                specific_correct, class_correct = self.do_test(loader)
                specific_acc = float(specific_correct) / total
                class_acc = float(class_correct) / total
                self.logger.log_test(phase, {"specific": specific_acc, "class": class_acc})
                self.results[phase][self.current_epoch] = class_acc

    def do_test(self, loader):
        specific_correct = 0
        class_correct = 0
        for it, ((data, jig_l, class_l), _) in enumerate(loader):
            data, jig_l, class_l = data.to(self.device), jig_l.to(self.device), class_l.to(self.device)
            dummy_ids = one_hot(np.zeros(len(data), dtype=np.int32), 3)
            specific_logit, class_logit = self.model(
                data, torch.tensor(dummy_ids, dtype=torch.float, device=self.device))
            _, cls_pred = class_logit.max(dim=1)
            _, specific_pred = specific_logit.max(dim=1)
            class_correct += torch.sum(cls_pred == class_l.data)
            specific_correct += torch.sum(specific_pred == class_l.data)
        # print(self.model.embs, self.model.cs_wt)
        return specific_correct, class_correct

    def do_training(self):
        self.logger = Logger(self.args, update_frequency=30)  # , "domain", "lambda"
        self.results = {"val": torch.zeros(self.args.epochs), "test": torch.zeros(self.args.epochs)}
        for self.current_epoch in range(self.args.epochs):
            self.scheduler.step()
            self.logger.new_epoch(self.scheduler.get_lr())
            self._do_epoch()
        val_res = self.results["val"]
        test_res = self.results["test"]
        idx_best = val_res.argmax()
        self.logger.save_best(test_res[idx_best], test_res.max())
        return self.logger, self.model
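one_hot is another helper defined outside this listing; since it is fed both a CPU tensor of domain indices (in _do_epoch) and a NumPy array (in do_test), a NumPy-based sketch like the following would cover both call sites. Illustrative only:

import numpy as np


def one_hot(ids, depth):
    # ids: 1-D array-like of integer indices -> (len(ids), depth) one-hot float array
    ids = np.asarray(ids, dtype=np.int64)
    out = np.zeros((ids.shape[0], depth), dtype=np.float32)
    out[np.arange(ids.shape[0]), ids] = 1.0
    return out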
Ejemplo n.º 13
class Trainer:
    def __init__(self, args, device):
        self.args = args
        self.device = device
        if args.network == 'resnet18':
            model = resnet18(pretrained=True, classes=args.n_classes)
        elif args.network == 'resnet50':
            model = resnet50(pretrained=True, classes=args.n_classes)
        else:
            model = resnet18(pretrained=True, classes=args.n_classes)
        self.model = model.to(device)
        # print(self.model)
        self.source_loader, self.val_loader = data_helper.get_train_dataloader(
            args, patches=model.is_patch_based())
        self.target_loader = data_helper.get_val_dataloader(
            args, patches=model.is_patch_based())
        self.test_loaders = {
            "val": self.val_loader,
            "test": self.target_loader
        }
        #length of source_loader
        self.len_dataloader = len(self.source_loader)
        print("Dataset size: train %d, val %d, test %d" %
              (len(self.source_loader.dataset), len(
                  self.val_loader.dataset), len(self.target_loader.dataset)))
        #optimizer : SGD
        self.optimizer, self.scheduler = get_optim_and_scheduler(
            model,
            args.epochs,
            args.learning_rate,
            args.train_all,
            nesterov=args.nesterov)
        self.n_classes = args.n_classes
        if args.target in args.source:
            self.target_id = args.source.index(args.target)
            print("Target in source: %d" % self.target_id)
            print(args.source)
        else:
            self.target_id = None

    def _do_epoch(self, epoch=None):
        #=============================
        #train
        #=============================
        criterion = nn.CrossEntropyLoss()
        #train mode
        self.model.train()
        #it : index of the batch within the epoch
        #data : input image
        #jig_l : ? (always fixed to 0 here)
        #class_l : class label index
        #d_idx : ?
        for it, ((data, jig_l, class_l),
                 d_idx) in enumerate(self.source_loader):
            #move the tensors onto the device
            data, jig_l, class_l, d_idx = data.to(self.device), jig_l.to(
                self.device), class_l.to(self.device), d_idx.to(self.device)
            #reset the gradients right before the update
            self.optimizer.zero_grad()

            #flip along axis 3 (the width dimension)
            #detach prevents operations on the original data tensor from being tracked
            #clone creates a tensor with no autograd relationship to the original
            #either way, data_flip sits outside the computational graph, so it cannot affect the data tensor
            data_flip = torch.flip(data, (3, )).detach().clone()

            #the two concatenations below double the batch (64 + 64 = 128)
            #concatenate data with its flipped copy
            #data.shape = (128,3,222,222)
            data = torch.cat((data, data_flip))
            #concatenate the class labels to match
            class_l = torch.cat((class_l, class_l))

            #compute the class score vector
            class_logit = self.model(data, class_l, True, epoch)
            #compute the loss
            class_loss = criterion(class_logit, class_l)
            #class prediction
            _, cls_pred = class_logit.max(dim=1)
            loss = class_loss

            #backpropagate from the loss to obtain the gradient of the loss for every parameter
            loss.backward()
            #update the model parameters
            self.optimizer.step()

            self.logger.log(
                it, len(self.source_loader), {"class": class_loss.item()}, {
                    "class": torch.sum(cls_pred == class_l.data).item(),
                }, data.shape[0])
            del loss, class_loss, class_logit

        #=============================
        #test
        #=============================
        self.model.eval()
        with torch.no_grad():
            for phase, loader in self.test_loaders.items():
                total = len(loader.dataset)

                class_correct = self.do_test(loader)

                class_acc = float(class_correct) / total
                self.logger.log_test(phase, {"class": class_acc})
                self.results[phase][self.current_epoch] = class_acc

    def do_test(self, loader):
        class_correct = 0
        for it, ((data, nouse, class_l), _) in enumerate(loader):
            data, nouse, class_l = data.to(self.device), nouse.to(
                self.device), class_l.to(self.device)

            class_logit = self.model(data, class_l, False)
            _, cls_pred = class_logit.max(dim=1)

            class_correct += torch.sum(cls_pred == class_l.data)

        return class_correct

    #training entry point
    def do_training(self):
        self.logger = Logger(self.args, update_frequency=30)
        self.results = {
            "val": torch.zeros(self.args.epochs),
            "test": torch.zeros(self.args.epochs)
        }
        #train for the configured number of epochs
        for self.current_epoch in range(self.args.epochs):
            #update the learning rate according to the scheduler
            self.scheduler.step()
            self.logger.new_epoch(self.scheduler.get_lr())
            #run the code that contains the actual RSC algorithm
            self._do_epoch(self.current_epoch)
        val_res = self.results["val"]
        test_res = self.results["test"]
        idx_best = val_res.argmax()
        print(
            "Best val %g, corresponding test %g - best test: %g, best epoch: %g"
            % (val_res.max(), test_res[idx_best], test_res.max(), idx_best))
        self.logger.save_best(test_res[idx_best], test_res.max())
        return self.logger, self.model
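The RSC-specific behavior is hidden inside the model's forward pass (the class_l, True, epoch arguments), so it does not appear in this listing. A minimal sketch of Representation Self-Challenging's central step, muting the most predictive activations so the classifier is forced to rely on the remaining evidence; the function name and the 33% drop rate are assumptions taken from the RSC paper, not from this code:

import torch


def rsc_mute(features, logits, labels, drop_pct=33.0):
    # features: intermediate activations participating in the graph; labels: int64
    score = logits.gather(1, labels.unsqueeze(1)).sum()
    grads = torch.autograd.grad(score, features, retain_graph=True)[0]
    flat = grads.flatten(1)
    k = max(1, int(flat.shape[1] * drop_pct / 100.0))
    # per-sample threshold at the k-th largest gradient; mute everything at or above it
    thresh = flat.topk(k, dim=1).values[:, -1:]
    mask = (flat < thresh).float().view_as(features)
    return features * mask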
Ejemplo n.º 14
class Trainer:
    def __init__(self, args, device):
        self.args = args
        self.device = device
        model = model_factory.get_network(args.network)(classes=args.n_classes)
        self.model = model.to(device)
        # print(self.model)
        self.source_loader, self.val_loader = data_helper.get_train_dataloader(
            args, patches=model.is_patch_based())
        self.target_loader = data_helper.get_val_dataloader(
            args, patches=model.is_patch_based())
        self.test_loaders = {
            "val": self.val_loader,
            "test": self.target_loader
        }
        self.len_dataloader = len(self.source_loader)
        print("Dataset size: train %d, val %d, test %d" %
              (len(self.source_loader.dataset), len(
                  self.val_loader.dataset), len(self.target_loader.dataset)))
        self.optimizer, self.scheduler, self.optimizer_par, self.scheduler_par = get_optim_and_scheduler_PAR(
            model,
            args.epochs,
            args.learning_rate,
            args.par_learning_rate,
            args.train_all,
            nesterov=args.nesterov)
        self.par_weight = args.par_weight
        self.only_non_scrambled = args.classify_only_sane
        self.n_classes = args.n_classes
        if args.target in args.source:
            self.target_id = args.source.index(args.target)
            print("Target in source: %d" % self.target_id)
            print(args.source)
        else:
            self.target_id = None
            # import ipdb;ipdb.set_trace()

    def accuracy(self, output, target, topk=(1, )):
        """Computes the accuracy over the k top predictions for the specified values of k"""
        with torch.no_grad():
            maxk = max(topk)
            batch_size = target.size(0)

            _, pred = output.topk(maxk, 1, True, True)
            pred = pred.t()
            correct = pred.eq(target.view(1, -1).expand_as(pred))

            res = []
            for k in topk:
                correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
                res.append(correct_k)
            return res

    def _do_epoch(self):
        criterion = nn.CrossEntropyLoss()
        self.model.train()
        for it, ((data, _, class_l), _) in enumerate(self.source_loader):
            data, class_l = data.to(self.device), class_l.to(self.device)

            # update par classifier
            self.optimizer_par.zero_grad()
            class_logit, par_logit = self.model(data)
            m, n = par_logit.shape[1], par_logit.shape[2]
            par_class_l = class_l.view(-1, 1, 1, 1).repeat(1, m, n, 1).view(-1)
            par_loss = criterion(par_logit.view(-1, self.n_classes),
                                 par_class_l)
            _, par_pred = par_logit.view(-1, self.n_classes).max(dim=1)
            par_loss.backward()
            self.optimizer_par.step()

            # update main classifier
            self.optimizer.zero_grad()
            class_logit, par_logit = self.model(data)
            class_loss = criterion(class_logit, class_l)
            # import ipdb;ipdb.set_trace()
            par_loss2 = criterion(par_logit.view(-1, self.n_classes),
                                  par_class_l)
            # top1_correct_pred, top5_correct_pred = self.accuracy(class_logit, class_l, topk=[1,5])
            _, cls_pred = class_logit.max(dim=1)
            loss = class_loss - par_loss2 * self.par_weight
            # loss = class_loss
            loss.backward()
            self.optimizer.step()

            self.logger.log(
                it,
                len(self.source_loader),
                {
                    "par": par_loss.item(),
                    "class": class_loss.item()
                },
                # ,"lambda": lambda_val},
                {
                    "par":
                    torch.sum(par_pred == par_class_l.data).item() / (m * n),
                    "class":
                    torch.sum(cls_pred == class_l.data).item(),
                    # "top5 class": top5_correct_pred.item(),
                },
                data.shape[0])
            # print(time()-begin)
            del loss, class_loss, par_loss, par_logit, class_logit

        self.model.eval()
        with torch.no_grad():
            for phase, loader in self.test_loaders.items():
                total = len(loader.dataset)
                par_correct, top1_correct_pred, top5_correct_pred = self.do_test(
                    loader)
                par_acc = float(par_correct) / total
                class_top1_acc = float(top1_correct_pred) / total
                class_top5_acc = float(top5_correct_pred) / total
                self.logger.log_test(
                    phase, {
                        "par": par_acc,
                        "class top1": class_top1_acc,
                        "class top5": class_top5_acc
                    })
                self.results[phase +
                             'top1'][self.current_epoch] = class_top1_acc
                self.results[phase +
                             'top5'][self.current_epoch] = class_top5_acc

    def do_test(self, loader):
        par_correct = 0
        class_correct_top1 = 0
        class_correct_top5 = 0
        for it, ((data, _, class_l), _) in enumerate(loader):
            data, class_l = data.to(self.device), class_l.to(self.device)
            class_logit, par_logit = self.model(data)
            m, n = par_logit.shape[1], par_logit.shape[2]
            par_class_l = class_l.view(-1, 1, 1, 1).repeat(1, m, n, 1).view(-1)
            _, cls_pred = class_logit.max(dim=1)
            _, par_pred = par_logit.view(-1, self.n_classes).max(dim=1)
            top1_correct_pred, top5_correct_pred = self.accuracy(class_logit,
                                                                 class_l,
                                                                 topk=[1, 5])
            # class_correct += torch.sum(cls_pred == class_l.data)
            class_correct_top1 += top1_correct_pred
            class_correct_top5 += top5_correct_pred
            # import ipdb;ipdb.set_trace()
            par_correct += torch.sum(par_pred == par_class_l.data).item() / (m * n)
        return par_correct, class_correct_top1, class_correct_top5

    def do_training(self):
        self.logger = Logger(self.args,
                             update_frequency=30)  # , "domain", "lambda"
        self.results = {
            "valtop1": torch.zeros(self.args.epochs),
            "valtop5": torch.zeros(self.args.epochs),
            "testtop1": torch.zeros(self.args.epochs),
            "testtop5": torch.zeros(self.args.epochs)
        }
        for self.current_epoch in range(self.args.epochs):
            self.scheduler.step()
            self.scheduler_par.step()
            self.logger.new_epoch(self.scheduler.get_lr())
            self._do_epoch()
        val_res = self.results["valtop1"]
        testtop1_res = self.results["testtop1"]
        testtop5_res = self.results["testtop5"]
        idx_best = val_res.argmax()
        print(
            "Best val %g, corresponding test top1 acc %g top5 acc %g - best test top1: %g, top5: %g"
            % (val_res.max(), testtop1_res[idx_best], testtop5_res[idx_best],
               testtop1_res.max(), testtop5_res.max()))
        self.logger.save_best(testtop1_res[idx_best], testtop1_res.max())
        return self.logger, self.model
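The model's second output par_logit is indexed as (batch, m, n, n_classes), i.e. per-patch class scores, and the main objective subtracts the patch loss (loss = class_loss - par_loss2 * par_weight): PAR penalizes the backbone when local patches alone can predict the class. A sketch of the kind of 1x1-convolution head that would produce such an output; the class name and wiring are assumptions, not this repository's code:

import torch.nn as nn


class PatchClassifierHead(nn.Module):
    # hypothetical per-patch head producing logits shaped (B, H, W, n_classes)
    def __init__(self, in_channels, n_classes):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, n_classes, kernel_size=1)

    def forward(self, feature_map):        # feature_map: (B, C, H, W)
        logits = self.conv(feature_map)    # (B, n_classes, H, W)
        return logits.permute(0, 2, 3, 1)  # (B, H, W, n_classes)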