コード例 #1
0
def load_mask_detector():
    ############### Change path to path of model ###################
    # path_model='/content/drive/MyDrive/frinks/models/fastai_resnet101'
    path_model='/content/drive/MyDrive/frinks/fewShots/CloserLookFewShot/checkpoints_masks_Conv4_baseline_aug/20.tar'
    path_data='/content/drive/MyDrive/frinks/Faces/data'
    if flag is 'torch':
      if not flag_fewShots:
        model = torch.load(path_model)
      if flag_fewShots:
        # import pdb; pdb.set_trace()
        model = BaselineTrain(backbone.Conv4, 4)
        model_dict = torch.load(path_model)
        model.load_state_dict(model_dict['state'])
        model=model.cuda()
    elif flag is 'fastai':
      data = ImageDataBunch.from_folder(path_data, valid_pct=0.2, size = 120)
      model = cnn_learner(data, models.resnet101, metrics=error_rate)
      model.load(path_model)
    else:
      model = LoadModel(path_model)
    return model
コード例 #2
0
            model           = RelationNet( feature_model, loss_type = loss_type , **train_few_shot_params )
        elif params.method in ['maml' , 'maml_approx']:
            backbone.ConvBlock.maml = True
            backbone.SimpleBlock.maml = True
            backbone.BottleneckBlock.maml = True
            backbone.ResNet.maml = True
            model           = MAML(  model_dict[params.model], approx = (params.method == 'maml_approx') , **train_few_shot_params )
            if params.dataset in ['omniglot', 'cross_char']: #maml use different parameter in omniglot
                model.n_task     = 32
                model.task_update_num = 1
                model.train_lr = 0.1
    else:
       raise ValueError('Unknown method')

    model = model.cuda()

    params.checkpoint_dir = '%s/checkpoints/%s/%s_%s' %(configs.save_dir, params.dataset, params.model, params.method)
    if params.train_aug:
        params.checkpoint_dir += '_aug'
    if not params.method  in ['baseline', 'baseline++']: 
        params.checkpoint_dir += '_%dway_%dshot' %( params.train_n_way, params.n_shot)

    if not os.path.isdir(params.checkpoint_dir):
        os.makedirs(params.checkpoint_dir)

    start_epoch = params.start_epoch
    stop_epoch = params.stop_epoch
    if params.method == 'maml' or params.method == 'maml_approx' :
        stop_epoch = params.stop_epoch * model.n_task #maml use multiple tasks in one update 
コード例 #3
0
def main_train(params):
    _set_seed(params)

    results_logger = ResultsLogger(params)

    if params.dataset == 'cross':
        base_file = configs.data_dir['miniImagenet'] + 'all.json'
        val_file = configs.data_dir['CUB'] + 'val.json'
    elif params.dataset == 'cross_char':
        base_file = configs.data_dir['omniglot'] + 'noLatin.json'
        val_file = configs.data_dir['emnist'] + 'val.json'
    else:
        base_file = configs.data_dir[params.dataset] + 'base.json'
        val_file = configs.data_dir[params.dataset] + 'val.json'
    if 'Conv' in params.model:
        if params.dataset in ['omniglot', 'cross_char']:
            image_size = 28
        else:
            image_size = 84
    else:
        image_size = 224
    if params.dataset in ['omniglot', 'cross_char']:
        assert params.model == 'Conv4' and not params.train_aug, 'omniglot only support Conv4 without augmentation'
        params.model = 'Conv4S'
    optimization = 'Adam'
    if params.stop_epoch == -1:
        if params.method in ['baseline', 'baseline++']:
            if params.dataset in ['omniglot', 'cross_char']:
                params.stop_epoch = 5
            elif params.dataset in ['CUB']:
                params.stop_epoch = 200  # This is different as stated in the open-review paper. However, using 400 epoch in baseline actually lead to over-fitting
            elif params.dataset in ['miniImagenet', 'cross']:
                params.stop_epoch = 400
            else:
                params.stop_epoch = 400  # default
        else:  # meta-learning methods
            if params.n_shot == 1:
                params.stop_epoch = 600
            elif params.n_shot == 5:
                params.stop_epoch = 400
            else:
                params.stop_epoch = 600  # default
    if params.method in ['baseline', 'baseline++']:
        base_datamgr = SimpleDataManager(image_size, batch_size=16)
        base_loader = base_datamgr.get_data_loader(base_file,
                                                   aug=params.train_aug)
        val_datamgr = SimpleDataManager(image_size, batch_size=64)
        val_loader = val_datamgr.get_data_loader(val_file, aug=False)

        if params.dataset == 'omniglot':
            assert params.num_classes >= 4112, 'class number need to be larger than max label id in base class'
        if params.dataset == 'cross_char':
            assert params.num_classes >= 1597, 'class number need to be larger than max label id in base class'

        if params.method == 'baseline':
            model = BaselineTrain(model_dict[params.model], params.num_classes)
        elif params.method == 'baseline++':
            model = BaselineTrain(model_dict[params.model],
                                  params.num_classes,
                                  loss_type='dist')

    elif params.method in [
            'DKT', 'protonet', 'matchingnet', 'relationnet',
            'relationnet_softmax', 'maml', 'maml_approx'
    ]:
        n_query = max(
            1, int(16 * params.test_n_way / params.train_n_way)
        )  # if test_n_way is smaller than train_n_way, reduce n_query to keep batch size small

        train_few_shot_params = dict(n_way=params.train_n_way,
                                     n_support=params.n_shot)
        base_datamgr = SetDataManager(image_size,
                                      n_query=n_query,
                                      **train_few_shot_params)  # n_eposide=100
        base_loader = base_datamgr.get_data_loader(base_file,
                                                   aug=params.train_aug)

        test_few_shot_params = dict(n_way=params.test_n_way,
                                    n_support=params.n_shot)
        val_datamgr = SetDataManager(image_size,
                                     n_query=n_query,
                                     **test_few_shot_params)
        val_loader = val_datamgr.get_data_loader(val_file, aug=False)
        # a batch for SetDataManager: a [n_way, n_support + n_query, dim, w, h] tensor

        if (params.method == 'DKT'):
            model = DKT(model_dict[params.model], **train_few_shot_params)
            model.init_summary()
        elif params.method == 'protonet':
            model = ProtoNet(model_dict[params.model], **train_few_shot_params)
        elif params.method == 'matchingnet':
            model = MatchingNet(model_dict[params.model],
                                **train_few_shot_params)
        elif params.method in ['relationnet', 'relationnet_softmax']:
            if params.model == 'Conv4':
                feature_model = backbone.Conv4NP
            elif params.model == 'Conv6':
                feature_model = backbone.Conv6NP
            elif params.model == 'Conv4S':
                feature_model = backbone.Conv4SNP
            else:
                feature_model = lambda: model_dict[params.model](flatten=False)
            loss_type = 'mse' if params.method == 'relationnet' else 'softmax'

            model = RelationNet(feature_model,
                                loss_type=loss_type,
                                **train_few_shot_params)
        elif params.method in ['maml', 'maml_approx']:
            backbone.ConvBlock.maml = True
            backbone.SimpleBlock.maml = True
            backbone.BottleneckBlock.maml = True
            backbone.ResNet.maml = True
            model = MAML(model_dict[params.model],
                         approx=(params.method == 'maml_approx'),
                         **train_few_shot_params)
            if params.dataset in [
                    'omniglot', 'cross_char'
            ]:  # maml use different parameter in omniglot
                model.n_task = 32
                model.task_update_num = 1
                model.train_lr = 0.1
    else:
        raise ValueError('Unknown method')
    model = model.cuda()
    params.checkpoint_dir = '%s/checkpoints/%s/%s_%s' % (
        configs.save_dir, params.dataset, params.model, params.method)
    if params.train_aug:
        params.checkpoint_dir += '_aug'
    if not params.method in ['baseline', 'baseline++']:
        params.checkpoint_dir += '_%dway_%dshot' % (params.train_n_way,
                                                    params.n_shot)
    if not os.path.isdir(params.checkpoint_dir):
        os.makedirs(params.checkpoint_dir)
    start_epoch = params.start_epoch
    stop_epoch = params.stop_epoch
    if params.method == 'maml' or params.method == 'maml_approx':
        stop_epoch = params.stop_epoch * model.n_task  # maml use multiple tasks in one update
    if params.resume:
        resume_file = get_resume_file(params.checkpoint_dir)
        if resume_file is not None:
            tmp = torch.load(resume_file)
            start_epoch = tmp['epoch'] + 1
            model.load_state_dict(tmp['state'])
    elif params.warmup:  # We also support warmup from pretrained baseline feature, but we never used in our paper
        baseline_checkpoint_dir = '%s/checkpoints/%s/%s_%s' % (
            configs.save_dir, params.dataset, params.model, 'baseline')
        if params.train_aug:
            baseline_checkpoint_dir += '_aug'
        warmup_resume_file = get_resume_file(baseline_checkpoint_dir)
        tmp = torch.load(warmup_resume_file)
        if tmp is not None:
            state = tmp['state']
            state_keys = list(state.keys())
            for i, key in enumerate(state_keys):
                if "feature." in key:
                    newkey = key.replace(
                        "feature.", ""
                    )  # an architecture model has attribute 'feature', load architecture feature to backbone by casting name from 'feature.trunk.xx' to 'trunk.xx'
                    state[newkey] = state.pop(key)
                else:
                    state.pop(key)
            model.feature.load_state_dict(state)
        else:
            raise ValueError('No warm_up file')

    model = train(base_loader, val_loader, model, optimization, start_epoch,
                  stop_epoch, params, results_logger)
    results_logger.save()
コード例 #4
0
        elif params.method == 'S2M2_R' or 'rotation':
            if params.model == 'WideResNet28_10':
                model = wrn_mixup_model.wrn28_10(
                    num_classes=params.num_classes,
                    dct_status=params.dct_status)
            elif params.model == 'ResNet18':
                model = res_mixup_model.resnet18(
                    num_classes=params.num_classes)

        if params.method == 'baseline++':
            if use_gpu:
                if torch.cuda.device_count() > 1:
                    model = torch.nn.DataParallel(
                        model, device_ids=range(torch.cuda.device_count()))
                model.cuda()

            if params.resume:
                resume_file = get_resume_file(params.checkpoint_dir)
                tmp = torch.load(resume_file)
                start_epoch = tmp['epoch'] + 1
                state = tmp['state']
                model.load_state_dict(state)
            model = torch.nn.DataParallel(model).cuda()
            cudnn.benchmark = True
            optimization = 'Adam'
            model = train_baseline(base_loader, base_loader_test, val_loader,
                                   model, start_epoch,
                                   start_epoch + stop_epoch, params, {})

        elif params.method == 'S2M2_R':
コード例 #5
0
def main():
    timer = Timer()
    args, writer = init()

    train_file = args.dataset_dir + 'train.json'
    val_file = args.dataset_dir + 'val.json'

    few_shot_params = dict(n_way=args.n_way, n_support=args.n_shot, n_query=args.n_query)
    n_episode = 10 if args.debug else 100
    if args.method_type is Method_type.baseline:
        train_datamgr = SimpleDataManager(train_file, args.dataset_dir, args.image_size, batch_size=64)
        train_loader = train_datamgr.get_data_loader(aug = True)
    else:
        train_datamgr = SetDataManager(train_file, args.dataset_dir, args.image_size,
                                       n_episode=n_episode, mode='train', **few_shot_params)
        train_loader = train_datamgr.get_data_loader(aug=True)

    val_datamgr = SetDataManager(val_file, args.dataset_dir, args.image_size,
                                     n_episode=n_episode, mode='val', **few_shot_params)
    val_loader = val_datamgr.get_data_loader(aug=False)

    if args.model_type is Model_type.ConvNet:
        pass
    elif args.model_type is Model_type.ResNet12:
        from methods.backbone import ResNet12
        encoder = ResNet12()
    else:
        raise ValueError('')

    if args.method_type is Method_type.baseline:
        from methods.baselinetrain import BaselineTrain
        model = BaselineTrain(encoder, args)
    elif args.method_type is Method_type.protonet:
        from methods.protonet import ProtoNet
        model = ProtoNet(encoder, args)
    else:
        raise ValueError('')

    from torch.optim import SGD,lr_scheduler
    if args.method_type is Method_type.baseline:
        optimizer = SGD(model.encoder.parameters(), lr=args.lr, momentum=0.9, weight_decay=args.weight_decay)
        scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.max_epoch, eta_min=0, last_epoch=-1)
    else:
        optimizer = torch.optim.SGD(model.encoder.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4,
                                    nesterov=True)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.step_size, gamma=0.5)

    args.ngpu = torch.cuda.device_count()
    torch.backends.cudnn.benchmark = True
    model = model.cuda()

    label = torch.from_numpy(np.repeat(range(args.n_way), args.n_query))
    label = label.cuda()

    if args.test:
        test(model, label, args, few_shot_params)
        return

    if args.resume:
        resume_OK =  resume_model(model, optimizer, args, scheduler)
    else:
        resume_OK = False
    if (not resume_OK) and  (args.warmup is not None):
        load_pretrained_weights(model, args)

    if args.debug:
        args.max_epoch = args.start_epoch + 1

    for epoch in range(args.start_epoch, args.max_epoch):
        train_one_epoch(model, optimizer, args, train_loader, label, writer, epoch)
        scheduler.step()

        vl, va = val(model, args, val_loader, label)
        if writer is not None:
            writer.add_scalar('data/val_acc', float(va), epoch)
        print('epoch {}, val, loss={:.4f} acc={:.4f}'.format(epoch, vl, va))

        if va >= args.max_acc:
            args.max_acc = va
            print('saving the best model! acc={:.4f}'.format(va))
            save_model(model, optimizer, args, epoch, args.max_acc, 'max_acc', scheduler)
        save_model(model, optimizer, args, epoch, args.max_acc, 'epoch-last', scheduler)
        if epoch != 0:
            print('ETA:{}/{}'.format(timer.measure(), timer.measure(epoch / args.max_epoch)))
    if writer is not None:
        writer.close()
    test(model, label, args, few_shot_params)