Пример #1
0
    def __init__(self, args):
        """Set up the pre-training phase: run directory, loaders, model, optimizer.

        Args:
            args: Parsed experiment options. Reads ``dataset``, ``model_type``,
                ``pre_batch_size``, ``pre_lr``, ``pre_gamma``, ``pre_step_size``,
                ``pre_max_epoch``, ``pre_custom_momentum``,
                ``pre_custom_weight_decay``, ``way``, ``shot`` and ``val_query``.
                ``args.save_path`` is written as a side effect.

        Raises:
            ValueError: If ``args.dataset`` is not a supported dataset name.
        """
        # Resolve the dataset class first so an unsupported name fails before
        # any log directory is created on disk.
        if args.dataset == 'MiniImageNet':
            from dataloader.mini_imagenet import MiniImageNet as Dataset
        elif args.dataset == 'TieredImageNet':
            from dataloader.tiered_imagenet import TieredImageNet as Dataset
        elif args.dataset == 'FC100':
            from dataloader.fewshotcifar import FewshotCifar as Dataset
        else:
            raise ValueError('Please set correct dataset.')

        # Run directory: ./logs/pre/<dataset>_<model_type>_<hyper-parameters>.
        # makedirs(exist_ok=True) creates missing parents and does not fail if
        # the directory was created concurrently by another run.
        log_base_dir = './logs/'
        pre_base_dir = osp.join(log_base_dir, 'pre')
        os.makedirs(pre_base_dir, exist_ok=True)
        save_path1 = '_'.join([args.dataset, args.model_type])
        save_path2 = ('batchsize' + str(args.pre_batch_size)
                      + '_lr' + str(args.pre_lr)
                      + '_gamma' + str(args.pre_gamma)
                      + '_step' + str(args.pre_step_size)
                      + '_maxepoch' + str(args.pre_max_epoch))
        args.save_path = osp.join(pre_base_dir, save_path1 + '_' + save_path2)
        ensure_path(args.save_path)

        self.args = args

        # Pre-training set: ordinary shuffled mini-batches, with augmentation.
        self.trainset = Dataset('train', self.args, train_aug=True)
        self.train_loader = DataLoader(dataset=self.trainset,
                                       batch_size=args.pre_batch_size,
                                       shuffle=True,
                                       num_workers=8,
                                       pin_memory=True)

        # Episodic validation (600 sampled batches) on the *val* split. Using
        # the test split here would leak test data into model selection.
        self.valset = Dataset('val', self.args)
        self.val_sampler = CategoriesSampler(
            self.valset.label, 600, self.args.way,
            self.args.shot + self.args.val_query)
        self.val_loader = DataLoader(dataset=self.valset,
                                     batch_sampler=self.val_sampler,
                                     num_workers=8,
                                     pin_memory=True)

        # The pre-training classifier head covers every training class.
        num_class_pretrain = self.trainset.num_class
        self.model = MtlLearner(self.args,
                                mode='pre',
                                num_cls=num_class_pretrain)

        # Move the model to the GPU before building the optimizer, as the
        # torch.optim documentation recommends.
        if torch.cuda.is_available():
            torch.backends.cudnn.benchmark = True
            self.model = self.model.cuda()

        # Encoder and pre-training head share the same learning rate.
        self.optimizer = torch.optim.SGD(
            [{'params': self.model.encoder.parameters(), 'lr': self.args.pre_lr},
             {'params': self.model.pre_fc.parameters(), 'lr': self.args.pre_lr}],
            momentum=self.args.pre_custom_momentum,
            nesterov=True,
            weight_decay=self.args.pre_custom_weight_decay)

        self.lr_scheduler = torch.optim.lr_scheduler.StepLR(
            self.optimizer, step_size=self.args.pre_step_size, gamma=self.args.pre_gamma)
Пример #2
0
    def __init__(self, args):
        """Set up the meta-training phase: run directory, episodic loaders,
        meta-learner, Adam optimizer and step LR scheduler.

        Args:
            args: Parsed experiment options. Besides the hyper-parameters that
                are encoded into the run directory name, reads ``dataset``,
                ``val_query`` and ``init_weights`` (optional checkpoint path
                whose ``'params'`` entry holds encoder weights).
                ``args.save_path`` is written as a side effect.

        Raises:
            ValueError: If ``args.dataset`` is not a supported dataset name.
        """
        # Resolve the dataset class first so an unsupported name fails before
        # any log directory is created on disk.
        if args.dataset == 'MiniImageNet':
            from dataloader.mini_imagenet import MiniImageNet as Dataset
        elif args.dataset == 'TieredImageNet':
            from dataloader.tiered_imagenet import TieredImageNet as Dataset
        elif args.dataset == 'FC100':
            from dataloader.fewshotcifar import FewshotCifar as Dataset
        else:
            raise ValueError('Non-supported Dataset.')

        # Run directory: ./logs/meta/<dataset>_<model_type>_MTL_<hyper-params>_<label>.
        # makedirs(exist_ok=True) creates missing parents and does not fail if
        # the directory was created concurrently by another run.
        log_base_dir = './logs/'
        meta_base_dir = osp.join(log_base_dir, 'meta')
        os.makedirs(meta_base_dir, exist_ok=True)
        save_path1 = '_'.join([args.dataset, args.model_type, 'MTL'])
        # NOTE: step_size is encoded twice ('_step' and '_stepsize'); kept so
        # directory names stay identical to those of earlier runs.
        name_parts = [
            ('shot', args.shot),
            ('_way', args.way),
            ('_query', args.train_query),
            ('_step', args.step_size),
            ('_gamma', args.gamma),
            ('_lr', args.lr),
            ('_lrbase', args.lr_base),
            ('_lrc', args.lr_combination),
            ('_lrch', args.lr_combination_hyperprior),
            ('_lrbs', args.lr_basestep),
            ('_lrbsh', args.lr_basestep_hyperprior),
            ('_batch', args.num_batch),
            ('_maxepoch', args.max_epoch),
            ('_csw', args.hyperprior_combination_softweight),
            ('_cbsw', args.hyperprior_basestep_softweight),
            ('_baselr', args.base_lr),
            ('_updatestep', args.update_step),
            ('_stepsize', args.step_size),
        ]
        save_path2 = ''.join(tag + str(value) for tag, value in name_parts) + '_' + args.label
        args.save_path = osp.join(meta_base_dir, save_path1 + '_' + save_path2)
        ensure_path(args.save_path)

        self.args = args

        # Episodic loaders built on CategoriesSampler (second argument is the
        # number of sampled batches: num_batch for train, 3000 for val).
        self.trainset = Dataset('train', self.args)
        self.train_sampler = CategoriesSampler(self.trainset.label, self.args.num_batch, self.args.way,
                                               self.args.shot + self.args.train_query)
        self.train_loader = DataLoader(dataset=self.trainset, batch_sampler=self.train_sampler,
                                       num_workers=8, pin_memory=True)

        self.valset = Dataset('val', self.args)
        self.val_sampler = CategoriesSampler(self.valset.label, 3000, self.args.way,
                                             self.args.shot + self.args.val_query)
        self.val_loader = DataLoader(dataset=self.valset, batch_sampler=self.val_sampler,
                                     num_workers=8, pin_memory=True)

        self.model = MtlLearner(self.args)

        # Optionally initialise the encoder from a pre-trained checkpoint. Its
        # keys lack the 'encoder.' prefix; keys the model does not have are
        # dropped. map_location='cpu' lets GPU-saved checkpoints load on a
        # CPU-only host; the weights follow the model to the GPU below.
        self.model_dict = self.model.state_dict()
        if self.args.init_weights is not None:
            pretrained_dict = torch.load(self.args.init_weights, map_location='cpu')['params']
            pretrained_dict = {'encoder.' + k: v for k, v in pretrained_dict.items()
                               if 'encoder.' + k in self.model_dict}
            print(pretrained_dict.keys())
            self.model_dict.update(pretrained_dict)
        self.model.load_state_dict(self.model_dict)

        # Move the model to the GPU before building the optimizer, as the
        # torch.optim documentation recommends.
        if torch.cuda.is_available():
            torch.backends.cudnn.benchmark = True
            self.model = self.model.cuda()

        # Trainable encoder parameters use the default lr (args.lr); every
        # other parameter group carries its own learning rate.
        new_para = filter(lambda p: p.requires_grad, self.model.encoder.parameters())
        self.optimizer = torch.optim.Adam(
            [{'params': new_para},
             {'params': self.model.base_learner.parameters(), 'lr': self.args.lr_base},
             {'params': self.model.get_hyperprior_combination_initialization_vars(),
              'lr': self.args.lr_combination},
             {'params': self.model.get_hyperprior_combination_mapping_vars(),
              'lr': self.args.lr_combination_hyperprior},
             {'params': self.model.get_hyperprior_basestep_initialization_vars(),
              'lr': self.args.lr_basestep},
             {'params': self.model.get_hyperprior_stepsize_mapping_vars(),
              'lr': self.args.lr_basestep_hyperprior}],
            lr=self.args.lr)
        self.lr_scheduler = torch.optim.lr_scheduler.StepLR(
            self.optimizer, step_size=self.args.step_size, gamma=self.args.gamma)
Пример #3
0
        ])
        args.save_path = save_path1 + '_' + save_path2 + '_' + args.exp_addendum
        ensure_path(args.save_path)
    else:
        ensure_path(args.save_path)

    if args.dataset == 'MiniImageNet':
        # Handle MiniImageNet
        from dataloader.mini_imagenet import MiniImageNet as Dataset
    elif args.dataset == 'CUB':
        from dataloader.cub import CUB as Dataset
    else:
        raise ValueError('Non-supported Dataset.')

    # train n_batch is 100 by default, val n_batch is 500 by default
    trainset = Dataset('train', args)
    train_sampler = CategoriesSampler(trainset.label, 100, args.way,
                                      args.shot + args.query)
    train_loader = DataLoader(dataset=trainset,
                              batch_sampler=train_sampler,
                              num_workers=8,
                              pin_memory=True)

    valset = Dataset('val', args)
    val_sampler = CategoriesSampler(valset.label, 500, args.validation_way,
                                    args.shot + args.query)
    val_loader = DataLoader(dataset=valset,
                            batch_sampler=val_sampler,
                            num_workers=8,
                            pin_memory=True)
            ]
        )
        args.save_path = save_path1 + "_" + save_path2
        ensure_path(args.save_path)
    else:
        ensure_path(args.save_path)

    if args.dataset == "MiniImageNet":
        # Handle MiniImageNet
        from dataloader.mini_imagenet import MiniImageNet as Dataset
    elif args.dataset == "CUB":
        from dataloader.cub import CUB as Dataset
    else:
        raise ValueError("Non-supported Dataset.")

    trainset = Dataset("train", args)
    train_sampler = CategoriesSampler(
        trainset.label, 100, args.way, args.shot + args.query
    )
    train_loader = DataLoader(
        dataset=trainset, batch_sampler=train_sampler, num_workers=8, pin_memory=True
    )

    valset = Dataset("val", args)
    val_sampler = CategoriesSampler(
        valset.label, 500, args.validation_way, args.shot + args.query
    )
    val_loader = DataLoader(
        dataset=valset, batch_sampler=val_sampler, num_workers=8, pin_memory=True
    )
Пример #5
0
    def __init__(self, args):
        """Build the model, data loaders, optimizer and LR scheduler for one phase.

        Args:
            args: Parsed experiment options. ``args.mode == 'pre_train'`` selects
                the pre-training phase; any other value selects meta-training,
                which starts from a pre-trained encoder checkpoint.
                ``args.num_class``, ``args.save_path`` and (meta mode only)
                ``args.dir`` are written as side effects.

        Raises:
            ValueError: If ``args.dataset`` is not a supported dataset name.
        """
        self.args = args
        pre_train = args.mode == 'pre_train'

        # Pick the dataset loader and the number of base (training) classes.
        if args.dataset == 'miniimagenet':
            from dataloader.mini_imagenet import MiniImageNet as Dataset
            args.num_class = 64
            print('Using dataset: miniImageNet, base class num:', args.num_class)
        elif args.dataset == 'cub':
            from dataloader.cub import CUB as Dataset
            args.num_class = 100
            print('Using dataset: CUB, base class num:', args.num_class)
        elif args.dataset == 'tieredimagenet':
            from dataloader.tiered_imagenet import tieredImageNet as Dataset
            args.num_class = 351
            print('Using dataset: tieredImageNet, base class num:', args.num_class)
        elif args.dataset == 'fc100':
            from dataloader.fc100 import DatasetLoader as Dataset
            args.num_class = 60
            print('Using dataset: FC100, base class num:', args.num_class)
        elif args.dataset == 'cifar_fs':
            from dataloader.cifar_fs import DatasetLoader as Dataset
            args.num_class = 64
            print('Using dataset: CIFAR-FS, base class num:', args.num_class)
        else:
            raise ValueError('Please set the correct dataset.')

        self.Dataset = Dataset

        if pre_train:
            print('Building pre-train model.')
        else:
            print('Building meta model.')
        meta_model_cls = importlib.import_module('model.meta_model').MetaModel
        self.model = meta_model_cls(args, dropout=args.dropout, mode='pre' if pre_train else 'meta')

        if pre_train:
            print('Initialize the model for pre-train phase.')
        else:
            # Meta-training needs the pre-trained encoder; if the checkpoint is
            # missing, run the download script (presumably fetches it).
            args.dir = 'pretrain_model/%s/%s/max_acc.pth' % (args.dataset, args.backbone)
            if not os.path.exists(args.dir):
                os.system('sh scripts/download_pretrain_model.sh')
            print('Loading pre-trained model from:\n', args.dir)
            # Checkpoint keys lack the 'encoder.' prefix used by the model's
            # state dict. map_location='cpu' allows loading on a CPU-only host
            # and avoids a transient GPU copy; the model is moved to GPU below.
            pretrained_dict = torch.load(args.dir, map_location='cpu')['params']
            model_dict = self.model.state_dict()
            model_dict.update({'encoder.' + k: v for k, v in pretrained_dict.items()})
            self.model.load_state_dict(model_dict)

        if self.args.num_gpu > 1:
            self.model = nn.DataParallel(self.model, list(range(args.num_gpu)))
        self.model = self.model.cuda()
        print('Building model finished.')
        # nn.DataParallel exposes the wrapped network only as `.module`, so the
        # sub-modules (encoder, fc, base_learner) and hyper-prior accessors used
        # to build the optimizer must be looked up on the unwrapped model.
        net = self.model.module if isinstance(self.model, nn.DataParallel) else self.model

        if pre_train:
            args.save_path = 'pre_train/%s-%s' % (args.dataset, args.backbone)
        else:
            args.save_path = 'meta_train/%s-%s-%s-%dway-%dshot' % \
                             (args.dataset, args.backbone, args.meta_update, args.way, args.shot)
        args.save_path = osp.join('logs', args.save_path)
        ensure_path(args.save_path)

        trainset = Dataset('train', args)
        if pre_train:
            # Pre-training: ordinary shuffled mini-batches.
            self.train_loader = DataLoader(dataset=trainset, batch_size=args.bs, shuffle=True,
                                           num_workers=args.num_workers, pin_memory=True)
        else:
            # Meta-training: episodic batches drawn by CategoriesSampler.
            train_sampler = CategoriesSampler(trainset.label, args.val_frequency * args.bs, args.way,
                                              args.shot + args.query)
            self.train_loader = DataLoader(dataset=trainset, batch_sampler=train_sampler,
                                           num_workers=args.num_workers, pin_memory=True)

        # Episodic evaluation on the split named by args.set. The loader is not
        # pre-materialised here: that would iterate every episode up front and
        # throw the result away.
        valset = Dataset(args.set, args)
        val_sampler = CategoriesSampler(valset.label, args.val_episode, args.way, args.shot + args.query)
        self.val_loader = DataLoader(dataset=valset, batch_sampler=val_sampler,
                                     num_workers=args.num_workers, pin_memory=True)

        if pre_train:
            self.optimizer = torch.optim.SGD([{'params': net.encoder.parameters(), 'lr': args.lr},
                                              {'params': net.fc.parameters(), 'lr': args.lr}],
                                             momentum=0.9, nesterov=True, weight_decay=0.0005)
        else:
            if args.meta_update == 'mtl':
                # With 'mtl', only encoder parameters with requires_grad=True
                # are optimised.
                new_para = filter(lambda p: p.requires_grad, net.encoder.parameters())
            else:
                new_para = net.encoder.parameters()

            self.optimizer = torch.optim.SGD([{'params': new_para, 'lr': args.lr},
                                              {'params': net.base_learner.parameters(), 'lr': args.lr},
                                              {'params': net.get_hyperprior_combination_initialization_vars(),
                                               'lr': args.lr_combination},
                                              {'params': net.get_hyperprior_combination_mapping_vars(),
                                               'lr': args.lr_combination_hyperprior},
                                              {'params': net.get_hyperprior_basestep_initialization_vars(),
                                               'lr': args.lr_basestep},
                                              {'params': net.get_hyperprior_stepsize_mapping_vars(),
                                               'lr': args.lr_basestep_hyperprior}],
                                             lr=args.lr, momentum=0.9, nesterov=True, weight_decay=0.0005)

        # Both phases use the same step schedule.
        self.lr_scheduler = torch.optim.lr_scheduler.StepLR(self.optimizer, step_size=args.step_size,
                                                            gamma=args.gamma)
Пример #6
0
    save_path1 = '-'.join([args.dataset, args.backbone_class, 'Pre'])
    save_path2 = '_'.join([str(args.lr), str(args.gamma), str(args.schedule)])
    args.save_path = osp.join(save_path1, save_path2)
    if not osp.exists(save_path1):
        os.mkdir(save_path1)
    ensure_path(args.save_path)

    if args.dataset == 'MiniImageNet':
        # Handle MiniImageNet
        from dataloader.mini_imagenet import MiniImageNet as Dataset
    elif args.dataset == 'TieredImagenet':
        from dataloader.tiered_imagenet import tieredImageNet as Dataset
    else:
        raise ValueError('Non-supported Dataset.')

    trainset = Dataset('train', args, augment=True)
    train_loader = DataLoader(dataset=trainset,
                              batch_size=args.batch_size,
                              shuffle=True,
                              num_workers=8,
                              pin_memory=True)
    args.num_class = trainset.num_class
    valset = Dataset('val', args)
    val_sampler = CategoriesSampler(valset.label, 200, valset.num_class,
                                    1 + args.query)  # test on 16-way 1-shot
    val_loader = DataLoader(dataset=valset,
                            batch_sampler=val_sampler,
                            num_workers=8,
                            pin_memory=True)
    args.way = valset.num_class
    args.shot = 1