Пример #1
0
    def __init__(self, args, device):
        """Build the classifier, DeepInfoMax loss module, data loaders and optimizers.

        Args:
            args: parsed experiment options. Fields read here: ``network``,
                ``n_classes``, ``alpha``, ``beta``, ``gamma``, ``epochs``,
                ``learning_rate``, ``train_all``, ``nesterov``, ``source``
                and ``target``.
            device: device handed to ``.to()`` for both the model and the
                DeepInfoMax loss module.
        """
        self.args = args
        self.device = device

        # Only 'resnet50' selects the larger backbone; every other value
        # (including unrecognised names) silently falls back to resnet18.
        backbone = resnet50 if args.network == 'resnet50' else resnet18
        net = backbone(pretrained=True, classes=args.n_classes)
        self.model = net.to(device)
        self.D_model = DeepInfoMaxLoss(
            alpha=args.alpha, beta=args.beta, gamma=args.gamma).to(device)

        self.source_loader, self.val_loader = data_helper.get_train_dataloader(
            args, patches=net.is_patch_based())
        self.target_loader = data_helper.get_val_dataloader(
            args, patches=net.is_patch_based())
        self.test_loaders = dict(val=self.val_loader, test=self.target_loader)
        self.len_dataloader = len(self.source_loader)
        sizes = (len(self.source_loader.dataset),
                 len(self.val_loader.dataset),
                 len(self.target_loader.dataset))
        print("Dataset size: train %d, val %d, test %d" % sizes)

        # The model and the ``global_d``/``local_d`` sub-modules share one
        # optimizer; ``prior_d`` gets a separate one at 1/1000 of the base LR.
        main_modules = [self.model, self.D_model.global_d, self.D_model.local_d]
        self.optimizer, self.scheduler = get_optim_and_scheduler(
            main_modules, args.epochs, args.learning_rate, args.train_all,
            nesterov=args.nesterov)
        self.dis_optimizer, self.dis_scheduler = get_optim_and_scheduler(
            [self.D_model.prior_d], args.epochs, args.learning_rate * 1e-3,
            args.train_all, nesterov=args.nesterov)

        self.n_classes = args.n_classes
        self.target_id = None
        if args.target in args.source:
            self.target_id = args.source.index(args.target)
            print("Target in source: %d" % self.target_id)
            print(args.source)

        self.max_test_acc = 0.0
        self.logger = Logger(self.args, update_frequency=30)
        # One zero-initialised slot per epoch for each split (presumably
        # filled with accuracies elsewhere -- TODO confirm).
        self.results = {split: torch.zeros(self.args.epochs)
                        for split in ("val", "test")}
Пример #2
0
 def __init__(self, args, device):
     """Build a jigsaw-capable classifier together with its loaders and optimizer.

     Args:
         args: parsed options. Fields read here: ``network``,
             ``jigsaw_n_classes``, ``n_classes``, ``epochs``,
             ``learning_rate``, ``train_all``, ``nesterov``, ``jig_weight``,
             ``classify_only_sane``, ``source`` and ``target``.
         device: device handed to ``model.to()``.
     """
     self.args = args
     self.device = device
     # The jigsaw head gets ``jigsaw_n_classes + 1`` outputs (presumably the
     # extra class is the unpermuted image -- TODO confirm).
     build = model_factory.get_network(args.network)
     net = build(jigsaw_classes=args.jigsaw_n_classes + 1, classes=args.n_classes)
     self.model = net.to(device)

     self.source_loader, self.val_loader = data_helper.get_train_dataloader(
         args, patches=net.is_patch_based())
     self.target_loader = data_helper.get_val_dataloader(
         args, patches=net.is_patch_based())
     self.test_loaders = dict(val=self.val_loader, test=self.target_loader)
     self.len_dataloader = len(self.source_loader)
     n_train = len(self.source_loader.dataset)
     n_val = len(self.val_loader.dataset)
     n_test = len(self.target_loader.dataset)
     print("Dataset size: train %d, val %d, test %d" % (n_train, n_val, n_test))

     self.optimizer, self.scheduler = get_optim_and_scheduler(
         net, args.epochs, args.learning_rate, args.train_all,
         nesterov=args.nesterov)
     self.jig_weight = args.jig_weight
     self.only_non_scrambled = args.classify_only_sane
     self.n_classes = args.n_classes

     self.target_id = (args.source.index(args.target)
                       if args.target in args.source else None)
     if self.target_id is not None:
         print("Target in source: %d" % self.target_id)
         print(args.source)
    def __init__(self, args, device):
        """Build the classifier, the source/validation/target loaders and the optimizer.

        Args:
            args: parsed options. Fields read here: ``network``,
                ``n_classes``, ``epochs``, ``learning_rate``, ``nesterov``,
                ``adam``, ``source`` and ``target``.
            device: device handed to ``model.to()``.
        """
        self.args = args
        self.device = device

        # Logging cadence used by the logger.
        self.log_frequency = 10

        net = model_factory.get_network(args.network)(classes=args.n_classes)
        self.model = net.to(device)

        # The training data is split into a train part and a validation part.
        self.source_loader, self.val_loader = data_helper.get_train_dataloader(args)
        self.target_loader = data_helper.get_target_dataloader(args)
        self.test_loaders = dict(val=self.val_loader, test=self.target_loader)
        n_train = len(self.source_loader.dataset)
        self.init_train_dataset_size = n_train
        print("Dataset size: train %d, val %d, test %d"
              % (n_train, len(self.val_loader.dataset), len(self.target_loader.dataset)))

        # NOTE: train_all is forced to True here instead of being read from args.
        self.optimizer, self.scheduler = get_optim_and_scheduler(
            net, args.epochs, args.learning_rate,
            train_all=True, nesterov=args.nesterov, adam=args.adam)
        self.n_classes = args.n_classes

        self.target_id = None
        if args.target in args.source:
            self.target_id = args.source.index(args.target)
            print("Target in source: %d" % self.target_id)
            print(args.source)
Пример #4
0
    def __init__(self, args, device):
        """Build a multi-head model, its data loaders, optimizer and task count.

        ``self.nTasks`` starts at 2 and grows by one for each of
        ``args.rotation`` / ``args.odd_one_out`` that compares equal to True.

        Args:
            args: parsed options. Fields read here: ``network``,
                ``n_classes``, ``epochs``, ``learning_rate``, ``train_all``,
                ``rotation`` and ``odd_one_out``.
            device: device handed to ``model.to()``.
        """
        self.args = args
        self.device = device

        # Auxiliary head sizes are fixed here rather than read from args.
        net = model_factory.get_network(args.network)(
            classes=args.n_classes, jigsaw_classes=31,
            rotation_classes=4, odd_classes=9)
        self.model = net.to(device)

        self.source_loader, self.val_loader = data_helper.get_train_dataloader(args)
        self.target_loader = data_helper.get_val_dataloader(args)

        self.test_loaders = dict(val=self.val_loader, test=self.target_loader)
        self.len_dataloader = len(self.source_loader)
        train_n, val_n, test_n = (
            len(loader.dataset)
            for loader in (self.source_loader, self.val_loader, self.target_loader))
        print("Dataset size: train %d, val %d, test %d" % (train_n, val_n, test_n))

        self.optimizer, self.scheduler = get_optim_and_scheduler(
            net, args.epochs, args.learning_rate, args.train_all)

        self.n_classes = args.n_classes

        # Explicit ``== True`` comparisons are kept from the original so that
        # only values equal to True enable an extra task.
        self.nTasks = (2
                       + int(args.rotation == True)  # noqa: E712
                       + int(args.odd_one_out == True))  # noqa: E712

        print("N of tasks: " + str(self.nTasks))
Пример #5
0
 def __init__(self, args, device):
     """Build a pretrained resnet18 classifier with its loaders and optimizer.

     Args:
         args: parsed options. Fields read here: ``n_classes``, ``epochs``,
             ``learning_rate``, ``train_all``, ``nesterov``, ``source`` and
             ``target``.
         device: device handed to ``model.to()``.
     """
     self.args = args
     self.device = device
     net = resnet18(pretrained=True, classes=args.n_classes)
     self.model = net.to(device)

     self.source_loader, self.val_loader = data_helper.get_train_dataloader(
         args, patches=net.is_patch_based())
     self.target_loader = data_helper.get_val_dataloader(
         args, patches=net.is_patch_based())
     self.test_loaders = dict(val=self.val_loader, test=self.target_loader)
     self.len_dataloader = len(self.source_loader)
     counts = (len(self.source_loader.dataset),
               len(self.val_loader.dataset),
               len(self.target_loader.dataset))
     print("Dataset size: train %d, val %d, test %d" % counts)

     self.optimizer, self.scheduler = get_optim_and_scheduler(
         net, args.epochs, args.learning_rate, args.train_all,
         nesterov=args.nesterov)
     self.n_classes = args.n_classes

     self.target_id = (args.source.index(args.target)
                       if args.target in args.source else None)
     if self.target_id is not None:
         print("Target in source: %d" % self.target_id)
         print(args.source)
    def __init__(self, args, device):
        """Set up domain-adaptation training with jigsaw, rotation and odd-one-out losses.

        Before anything else reads them, disabled auxiliary tasks have their
        source and target weights set to 0 *on the passed-in ``args`` object*.

        Args:
            args: parsed options, mutated as described above. Also reads
                ``network``, ``n_classes``, ``grid_size``, ``epochs``,
                ``learning_rate``, ``train_all``, ``permutation_number`` and
                ``target_entropy_weight``.
            device: passed to ``model.to()`` and to the loader helpers.
        """
        self.args = args
        self.device = device

        # (enable flag, weight prefix) pairs: a flag equal to 0 zeroes both
        # the ``<prefix>_weight_source`` and ``<prefix>_weight_target`` fields.
        task_switches = (("scrambled", "jig"), ("rotated", "rot"), ("odd", "odd"))
        for flag, prefix in task_switches:
            if getattr(args, flag) == 0:
                setattr(args, prefix + "_weight_source", 0)
                setattr(args, prefix + "_weight_target", 0)

        # odd_classes is grid_size squared (presumably one class per grid cell).
        net = model_factory.get_network(args.network)(
            classes=args.n_classes, odd_classes=args.grid_size ** 2)
        self.model = net.to(device)

        # Source loaders (train / validation), then target loaders (train / test).
        self.source_loader, self.val_loader = data_helper.get_train_dataloader_JiGen(args, device, "DA")
        self.target_train_loader, self.target_loader = data_helper.get_target_loader(args, device, "DA")
        self.test_loaders = dict(val=self.val_loader, test=self.target_loader)

        # The reported training size counts both source and target training samples.
        train_size = len(self.source_loader.dataset) + len(self.target_train_loader.dataset)
        print("Dataset size: train %d, val %d, test %d"
              % (train_size, len(self.val_loader.dataset), len(self.target_loader.dataset)))

        self.optimizer, self.scheduler = get_optim_and_scheduler(
            net, args.epochs, args.learning_rate, args.train_all)

        self.n_classes = args.n_classes

        # Jigsaw parameters.
        self.jig_alpha_t = args.jig_weight_target
        self.jig_alpha_s = args.jig_weight_source
        self.permutation_number = args.permutation_number

        # Rotation parameters.
        self.rot_alpha_t = args.rot_weight_target
        self.rot_alpha_s = args.rot_weight_source

        # Target loss weight.
        self.target_loss_weight = args.target_entropy_weight

        # Odd-one-out parameters.
        self.odd_alpha_t = args.odd_weight_target
        self.odd_alpha_s = args.odd_weight_source

        self.epoch_count = 0
        self.tot_epoch = args.epochs
Пример #7
0
    def __init__(self, args, device):
        """Build the backbone, optionally resume from a checkpoint, and set up data/optimizer.

        Args:
            args: parsed options. Fields read here: ``network``,
                ``pretrained``, ``n_classes``, ``resume``, ``epochs``,
                ``learning_rate``, ``train_all``, ``nesterov``, ``source``
                and ``target``. When resuming, ``args.start_epoch`` is
                overwritten with the checkpoint's ``'epoch'`` entry.
            device: device the model is moved to. Checkpoint tensors are
                mapped onto it as well.

        Raises:
            ValueError: if ``args.resume`` is set but does not name an existing file.
        """
        self.args = args
        self.device = device
        # Only 'resnet50' selects the larger backbone; every other value
        # (including unrecognised names) falls back to resnet18.
        backbone = resnet50 if args.network == 'resnet50' else resnet18
        model = backbone(pretrained=self.args.pretrained, classes=args.n_classes)
        self.model = model.to(device)

        if args.resume:
            if not isfile(args.resume):
                raise ValueError(f"Failed to find checkpoint {args.resume}")
            print(f"=> loading checkpoint '{args.resume}'")
            # map_location lets a checkpoint saved on one device (e.g. a GPU)
            # be restored on another (e.g. a CPU-only machine). Without it,
            # torch.load tries to restore tensors on the device they were saved from.
            # NOTE(review): torch.load unpickles the file; only resume from
            # trusted checkpoints.
            checkpoint = torch.load(args.resume, map_location=device)
            self.args.start_epoch = checkpoint['epoch']
            self.model.load_state_dict(checkpoint['model'])
            print(
                f"=> loaded checkpoint '{args.resume}' (epoch {checkpoint['epoch']})"
            )

        self.source_loader, self.val_loader = data_helper.get_train_dataloader(
            args, patches=model.is_patch_based())
        # The test split comes from the dedicated target loader.
        self.target_loader = data_helper.get_tgt_dataloader(
            self.args, patches=model.is_patch_based())
        self.test_loaders = {
            "val": self.val_loader,
            "test": self.target_loader
        }
        self.len_dataloader = len(self.source_loader)
        print("Dataset size: train %d, val %d, test %d" %
              (len(self.source_loader.dataset), len(
                  self.val_loader.dataset), len(self.target_loader.dataset)))
        self.optimizer, self.scheduler = get_optim_and_scheduler(
            model,
            args.epochs,
            args.learning_rate,
            args.train_all,
            nesterov=args.nesterov)
        self.n_classes = args.n_classes
        if args.target in args.source:
            self.target_id = args.source.index(args.target)
            print("Target in source: %d" % self.target_id)
            print(args.source)
        else:
            self.target_id = None
        # Three zero-initialised slots (presumably top-k accuracies -- TODO confirm).
        self.topk = [0] * 3
Пример #8
0
    def __init__(self, args, device):
        """Build the multi-head model, the source/target loaders, the optimizer and the task count.

        Args:
            args: parsed options. Fields read here: ``betaJigen``,
                ``network``, ``n_classes``, ``epochs``, ``learning_rate``,
                ``train_all``, ``oddOneOut`` and ``rotation``.
            device: device handed to ``model.to()``.
        """
        # Hard-coded auxiliary-loss weights (not read from ``args``).
        self.alpha_jigsaw_weight = 0.5
        self.alpha_jigsaw_weight_target = 0.5
        self.alpha_rotation_weight = 0.5
        self.alpha_rotation_weight_target = 0.5
        self.alpha_odd_weight = 0.5
        self.alpha_odd_weight_target = 0.5

        self.entropi_ni = 0.1
        self.args = args
        self.device = device
        self.betaJigen = args.betaJigen
        # Every auxiliary head is always built. ``args.rotation`` and
        # ``args.oddOneOut`` only affect ``self.nTasks`` below.
        model = model_factory.get_network(args.network)(classes=args.n_classes,
                                                        jigsaw_classes=31,
                                                        odd_classes=10,
                                                        rotation_classes=4)
        self.model = model.to(device)

        self.source_loader, self.val_loader = data_helper.get_train_dataloader(
            args)
        self.target_loader = data_helper.get_val_dataloader(args)
        self.targetAsSource_loader = data_helper.get_trainTargetAsSource_dataloader(
            args)

        self.test_loaders = {
            "val": self.val_loader,
            "test": self.target_loader
        }
        self.len_dataloader = len(self.source_loader)
        # Count samples for both training sources. For a torch DataLoader,
        # ``len(loader)`` is the number of batches, so use the dataset length.
        n_train = (len(self.source_loader.dataset) +
                   len(self.targetAsSource_loader.dataset))
        print("Dataset size: train %d, val %d, test %d" %
              (n_train, len(self.val_loader.dataset),
               len(self.target_loader.dataset)))

        self.optimizer, self.scheduler = get_optim_and_scheduler(
            model, args.epochs, args.learning_rate, args.train_all)

        self.n_classes = args.n_classes
        # Two base tasks, plus one for each of odd-one-out / rotation that is
        # enabled (explicit ``== True`` comparisons kept from the original).
        self.nTasks = (2
                       + int(args.oddOneOut == True)  # noqa: E712
                       + int(args.rotation == True))  # noqa: E712
Пример #9
0
    def __init__(self, args, device):
        """Build a plain classifier together with its data loaders and optimizer.

        Args:
            args: parsed options. Fields read here: ``network``,
                ``n_classes``, ``epochs``, ``learning_rate`` and ``train_all``.
            device: device handed to ``model.to()``.
        """
        self.args, self.device = args, device

        net = model_factory.get_network(args.network)(classes=args.n_classes)
        self.model = net.to(device)

        self.source_loader, self.val_loader = data_helper.get_train_dataloader(args)
        self.target_loader = data_helper.get_val_dataloader(args)

        self.test_loaders = dict(val=self.val_loader, test=self.target_loader)
        self.len_dataloader = len(self.source_loader)
        counts = tuple(len(loader.dataset) for loader in
                       (self.source_loader, self.val_loader, self.target_loader))
        print("Dataset size: train %d, val %d, test %d" % counts)

        self.optimizer, self.scheduler = get_optim_and_scheduler(
            net, args.epochs, args.learning_rate, args.train_all)

        self.n_classes = args.n_classes
Пример #10
0
    def __init__(self, args, device):
        """Build a jigsaw-capable model, its loaders and optimizer, and the output folder name.

        Args:
            args: parsed options. Fields read here: ``network``,
                ``imagenet``, ``jigsaw_n_classes``, ``n_classes``,
                ``epochs``, ``learning_rate``, ``train_all``, ``nesterov``,
                ``jig_weight``, ``classify_only_sane``, ``source``,
                ``target`` and ``folder_name``.
            device: device handed to ``model.to()``.
        """
        self.args = args
        self.device = device
        build = model_factory.get_network(args.network)
        net = build(pretrained=args.imagenet,
                    jigsaw_classes=args.jigsaw_n_classes + 1,
                    classes=args.n_classes)
        self.model = net.to(device)

        self.source_loader, self.val_loader = data_helper.get_train_dataloader(
            args, patches=net.is_patch_based())
        # The test split comes from the jigsaw-specific validation loader.
        self.target_loader = data_helper.get_jigsaw_val_dataloader(
            args, patches=net.is_patch_based())
        self.test_loaders = dict(val=self.val_loader, test=self.target_loader)
        self.len_dataloader = len(self.source_loader)
        n_train = len(self.source_loader.dataset)
        n_val = len(self.val_loader.dataset)
        n_test = len(self.target_loader.dataset)
        print("Dataset size: train %d, val %d, test %d" % (n_train, n_val, n_test))

        self.optimizer, self.scheduler = get_optim_and_scheduler(
            net, args.epochs, args.learning_rate, args.train_all,
            nesterov=args.nesterov)
        self.jig_weight = args.jig_weight
        self.only_non_scrambled = args.classify_only_sane
        self.n_classes = args.n_classes

        self.target_id = (args.source.index(args.target)
                          if args.target in args.source else None)
        if self.target_id is not None:
            print("Target in source: %d" % self.target_id)
            print(args.source)

        # Best-so-far jigsaw metrics, both starting at 0.0.
        self.best_val_jigsaw = 0.0
        self.best_jigsaw_acc = 0.0
        _, logname = Logger.get_name_from_args(args)

        # Layout: <folder_name>/<sorted sources joined by '-'>_to_<target>/<logname>
        sources_tag = "-".join(sorted(args.source))
        self.folder_name = "%s/%s_to_%s/%s" % (
            args.folder_name, sources_tag, args.target, logname)
Пример #11
0
 def __init__(self, args, device):
     """Build a jigsaw-capable model plus source, target-jigsaw and test loaders.

     If ``args.target`` also appears in ``args.source``, its first occurrence
     is removed and ``args.source`` is rebound on the passed-in ``args`` object.

     Args:
         args: parsed options. Fields read here: ``network``,
             ``jigsaw_n_classes``, ``n_classes``, ``source``, ``target``,
             ``epochs``, ``learning_rate``, ``train_all``, ``nesterov``,
             ``jig_weight``, ``target_weight``, ``entropy_weight`` and
             ``classify_only_sane``.
         device: device handed to ``model.to()``.
     """
     self.args = args
     self.device = device
     # The jigsaw head gets jigsaw_n_classes + 1 outputs (presumably the extra
     # class is the unpermuted image -- TODO confirm).
     model = model_factory.get_network(args.network)(
         jigsaw_classes=args.jigsaw_n_classes + 1, classes=args.n_classes)
     self.model = model.to(device)
     # print(self.model)
     if args.target in args.source:
         print(
             "No need to include target in source, it is automatically done by this script"
         )
         # Drop only the first occurrence of the target. Slicing builds a new
         # list, which is then assigned back onto args.source.
         k = args.source.index(args.target)
         args.source = args.source[:k] + args.source[k + 1:]
         print("Source: %s" % args.source)
     self.source_loader, self.val_loader = data_helper.get_train_dataloader(
         args, patches=model.is_patch_based())
     # Separate loader over the target domain used for the jigsaw task.
     self.target_jig_loader = data_helper.get_target_jigsaw_loader(args)
     self.target_loader = data_helper.get_val_dataloader(
         args, patches=model.is_patch_based())
     self.test_loaders = {
         "val": self.val_loader,
         "test": self.target_loader
     }
     self.len_dataloader = len(self.source_loader)
     print("Dataset size: train %d, target jig: %d, val %d, test %d" %
           (len(self.source_loader.dataset),
            len(self.target_jig_loader.dataset), len(
                self.val_loader.dataset), len(self.target_loader.dataset)))
     self.optimizer, self.scheduler = get_optim_and_scheduler(
         model,
         args.epochs,
         args.learning_rate,
         args.train_all,
         nesterov=args.nesterov)
     # Loss weights copied from args.
     self.jig_weight = args.jig_weight
     self.target_weight = args.target_weight
     self.target_entropy = args.entropy_weight
     self.only_non_scrambled = args.classify_only_sane
     self.n_classes = args.n_classes