Пример #1
0
def get_model(options):
    """Build the embedding network and classification head named in ``options``.

    Args:
        options: Object exposing ``network`` ('ProtoNet', 'R2D2' or
            'ResNet'), ``head`` ('Subspace', 'ProtoNet', 'Ridge', 'R2D2' or
            'SVM') and, for the ResNet backbone only, ``dataset``.

    Returns:
        A ``(network, cls_head)`` tuple. Both are moved to CUDA; the ResNet
        backbone is additionally wrapped in ``torch.nn.DataParallel`` when
        the dataset is 'miniImageNet' or 'tieredImageNet'.

    Raises:
        ValueError: If ``options.network`` or ``options.head`` is not one of
            the supported names.
    """
    supported_networks = ('ProtoNet', 'R2D2', 'ResNet')
    # User-facing head name -> ``base_learner`` value given to
    # ClassificationHead.
    head_to_base_learner = {
        'Subspace': 'Subspace',
        'ProtoNet': 'ProtoNet',
        'Ridge': 'Ridge',
        'R2D2': 'R2D2',
        'SVM': 'SVM-CS',
    }

    # Validate both selections before building anything, so a typo in the
    # head name is reported without first constructing a network. Explicit
    # exceptions are used because ``assert`` is stripped under ``python -O``.
    if options.network not in supported_networks:
        raise ValueError(
            "Cannot recognize the network type {!r}; expected one of: {}".format(
                options.network, ', '.join(supported_networks)))
    if options.head not in head_to_base_learner:
        raise ValueError(
            "Cannot recognize the classification head type {!r}; "
            "expected one of: {}".format(
                options.head, ', '.join(head_to_base_learner)))

    # Choose the embedding network
    if options.network == 'ProtoNet':
        network = ProtoNetEmbedding().cuda()
    elif options.network == 'R2D2':
        network = R2D2Embedding().cuda()
    else:  # 'ResNet'
        on_imagenet = options.dataset in ('miniImageNet', 'tieredImageNet')
        network = resnet12(avg_pool=False, drop_rate=0.1,
                           dropblock_size=5 if on_imagenet else 2).cuda()
        if on_imagenet:
            # Only the ImageNet-derived datasets get the DataParallel wrapper.
            network = torch.nn.DataParallel(network)

    # Choose the classification head
    cls_head = ClassificationHead(
        base_learner=head_to_base_learner[options.head]).cuda()

    return (network, cls_head)
Пример #2
0
def get_model(options):
    """Build a ResNet-12 embedding network and its classification head.

    Args:
        options: Object exposing ``network`` (only 'ResNet' is supported),
            ``dataset``, ``avg``, ``drop_rate``, ``gpu_ids``, ``head``,
            ``scale``, ``scale_const`` and ``scale_schedule``.

    Returns:
        A ``(network, cls_head)`` tuple. ``network`` is a CUDA ``resnet12``
        wrapped in ``torch.nn.DataParallel``. ``cls_head`` is a
        ``ScheduledClassificationHead`` when ``options.scale_const`` or
        ``options.scale_schedule`` is truthy, otherwise a
        ``ClassificationHead``.

    Raises:
        ValueError: If ``options.network`` is not 'ResNet'. Previously this
            case only printed a message and then failed with an
            ``UnboundLocalError`` at the ``return`` statement.
    """
    # Choose the embedding network
    if options.network != 'ResNet':
        raise ValueError(
            "Cannot recognize the network type {!r}; expected 'ResNet'".format(
                options.network))

    on_imagenet = options.dataset in ('miniImageNet', 'tieredImageNet')
    network = resnet12(avg_pool=options.avg,
                       drop_rate=options.drop_rate,
                       dropblock_size=5 if on_imagenet else 2).cuda()
    # NOTE(review): this assumes options.gpu_ids is a sequence with one entry
    # per GPU; a comma-separated string would give the wrong length -- confirm
    # against the argument parser.
    network = torch.nn.DataParallel(
        network, device_ids=list(range(len(options.gpu_ids))))

    def scheduled_scale(e):
        # Piecewise-constant scale. ``e`` is whatever counter
        # ScheduledClassificationHead passes to ``fn`` (presumably a training
        # step/episode index -- verify against that class).
        if e < 14000:
            return 10.
        if e < 16000:
            return 20.
        if e < 18000:
            return 30.
        return 50.

    # Choose the classification head. Construction errors are deliberately
    # allowed to propagate: the previous blanket ``except Exception`` hid the
    # real cause and left ``cls_head`` unbound.
    if options.scale_const:
        cls_head = ScheduledClassificationHead(
            base_learner=options.head,
            enable_scale=options.scale,
            scale=options.scale_const,
            # Read from ``options`` at call time, as before.
            fn=lambda x: options.scale_const,
        ).cuda()
        print('Use const scale {}.'.format(options.scale_const))
    elif options.scale_schedule:
        cls_head = ScheduledClassificationHead(
            base_learner=options.head,
            enable_scale=options.scale,
            scale=5.0,
            fn=scheduled_scale,
        ).cuda()
        print('Use scheduled scale.')
    else:
        cls_head = ClassificationHead(
            base_learner=options.head,
            enable_scale=options.scale,
        ).cuda()
        print('Use learnable scale')

    return (network, cls_head)
Пример #3
0
def get_model(options):
    """Build an embedding network and a matching linear classification head.

    Args:
        options: Object exposing ``dataset`` ('miniImageNet' or 'CIFAR_FS')
            and ``network`` ('ProtoNet', 'R2D2', 'ResNet', 'WideResNet' or
            'MAML').

    Returns:
        A ``(network, cls_head)`` tuple, both moved to CUDA. ``cls_head`` is
        a ``torch.nn.Linear`` whose input width depends on the
        dataset/network pair and whose output width is 64. On miniImageNet
        the ResNet and WideResNet backbones are wrapped in
        ``torch.nn.DataParallel``.

    Raises:
        ValueError: If the dataset or the network is not supported.
            Previously an unsupported dataset fell through silently and
            failed with ``UnboundLocalError`` at the ``return`` statement.
    """
    # Input width of the linear head for every supported combination.
    feature_dims = {
        'miniImageNet': {
            'ProtoNet': 1600,
            'R2D2': 51200,
            'ResNet': 16000,
            'WideResNet': 2560,
            'MAML': 800,
        },
        'CIFAR_FS': {
            'ProtoNet': 256,
            'R2D2': 8192,
            'ResNet': 2560,
            'WideResNet': 640,
            'MAML': 800,
        },
    }
    # NOTE(review): 64 is presumably the number of training classes -- confirm.
    num_outputs = 64

    if options.dataset not in feature_dims:
        raise ValueError(
            "Cannot recognize the dataset type {!r}; expected one of: {}".format(
                options.dataset, ', '.join(feature_dims)))
    dims_for_dataset = feature_dims[options.dataset]
    if options.network not in dims_for_dataset:
        raise ValueError(
            "Cannot recognize the network type {!r}; expected one of: {}".format(
                options.network, ', '.join(dims_for_dataset)))

    # Of the two supported datasets only miniImageNet takes the larger
    # DropBlock size and the DataParallel wrapper.
    on_imagenet = options.dataset == 'miniImageNet'

    # Choose the embedding network
    if options.network == 'ProtoNet':
        network = ProtoNetEmbedding().cuda()
    elif options.network == 'R2D2':
        network = R2D2Embedding().cuda()
    elif options.network == 'ResNet':
        network = resnet12(avg_pool=False, drop_rate=0.1,
                           dropblock_size=5 if on_imagenet else 2).cuda()
    elif options.network == 'WideResNet':
        network = wrn_28_10().cuda()
    else:  # 'MAML'
        network = MAML_Embedding().cuda()

    if on_imagenet and options.network in ('ResNet', 'WideResNet'):
        network = torch.nn.DataParallel(network)

    # Corresponding linear head
    cls_head = torch.nn.Linear(dims_for_dataset[options.network],
                               num_outputs).cuda()
    return (network, cls_head)
Пример #4
0
def get_model(options):
    """Build the embedding network, the head and, optionally, a mixup head.

    Args:
        options: Object exposing ``network`` ('ProtoNet', 'R2D2',
            'R2D2_mixup', 'ResNet_mixup' or 'ResNet'), ``head`` ('ProtoNet',
            'Ridge', 'R2D2' or 'SVM'), ``support_aug`` and, for the ResNet
            backbone, ``dataset``.

    Returns:
        ``(network, cls_head)`` normally, or
        ``(network, cls_head, cls_head_mixup)`` when ``options.support_aug``
        is truthy and contains 'mix'.

    Raises:
        ValueError: If the network or head name is not supported, or if
            mixup is requested with a head other than 'R2D2' or 'SVM'.
            Previously that last case printed a message and then failed with
            ``UnboundLocalError`` on ``cls_head_mixup``.
    """
    supported_networks = ('ProtoNet', 'R2D2', 'R2D2_mixup', 'ResNet_mixup',
                          'ResNet')
    supported_heads = ('ProtoNet', 'Ridge', 'R2D2', 'SVM')
    mixup_heads = ('R2D2', 'SVM')

    # Validate everything before touching the GPU. Explicit exceptions are
    # used because ``assert`` is stripped under ``python -O``.
    if options.network not in supported_networks:
        raise ValueError(
            "Cannot recognize the network type {!r}; expected one of: {}".format(
                options.network, ', '.join(supported_networks)))
    if options.head not in supported_heads:
        raise ValueError(
            "Cannot recognize the classification head type {!r}; "
            "expected one of: {}".format(
                options.head, ', '.join(supported_heads)))
    use_mixup = bool(options.support_aug and 'mix' in options.support_aug)
    if use_mixup and options.head not in mixup_heads:
        raise ValueError(
            "Classification head {!r} has no mixup variant; "
            "expected one of: {}".format(
                options.head, ', '.join(mixup_heads)))

    # Choose the embedding network
    if options.network == 'ProtoNet':
        network = ProtoNetEmbedding().cuda()
    elif options.network == 'R2D2':
        network = R2D2Embedding().cuda()
    elif options.network == 'R2D2_mixup':
        network = R2D2Embedding_mixup().cuda()
    elif options.network == 'ResNet_mixup':
        network = resnet12_mixup(avg_pool=False,
                                 drop_rate=0.1,
                                 dropblock_size=2).cuda()
    else:  # 'ResNet'
        on_imagenet = options.dataset in ('miniImageNet', 'tieredImageNet')
        network = resnet12(avg_pool=False, drop_rate=0.1,
                           dropblock_size=5 if on_imagenet else 2).cuda()
        network = torch.nn.DataParallel(network)

    # Choose the classification head
    if options.head == 'R2D2':
        cls_head = R2D2Head().cuda()
    else:
        base_learner = {'ProtoNet': 'ProtoNet',
                        'Ridge': 'Ridge',
                        'SVM': 'SVM-CS'}[options.head]
        cls_head = ClassificationHead(base_learner=base_learner).cuda()

    if not use_mixup:
        return (network, cls_head)

    # Mixup variant of the head; only 'R2D2' and 'SVM' reach this point.
    if options.head == 'R2D2':
        cls_head_mixup = R2D2Head_Mixup().cuda()
    else:  # 'SVM'
        cls_head_mixup = ClassificationHead_Mixup(
            base_learner='SVM-CS').cuda()

    return (network, cls_head, cls_head_mixup)
Пример #5
0
def get_model(options):
    """Build the embedding network and classification head named in ``options``.

    Args:
        options: Object exposing ``network`` ('ProtoNet', 'R2D2', 'ResNet'
            or 'ResNetRFS'), ``head`` ('ProtoNet', 'Ridge', 'R2D2', 'SVM',
            'LR' or 'SVM-BiP') and, for the ResNet backbones, ``dataset``.

    Returns:
        A ``(network, cls_head)`` tuple, both moved to CUDA. This variant
        does not wrap the network in ``torch.nn.DataParallel``.

    Raises:
        ValueError: If ``options.network`` or ``options.head`` is not one of
            the supported names.
    """
    supported_networks = ('ProtoNet', 'R2D2', 'ResNet', 'ResNetRFS')
    # User-facing head name -> ``base_learner`` value given to
    # ClassificationHead. 'LR' shares the 'SVM-CS' learner with 'SVM'.
    head_to_base_learner = {
        'ProtoNet': 'ProtoNet',
        'Ridge': 'Ridge',
        'R2D2': 'R2D2',
        'SVM': 'SVM-CS',
        'LR': 'SVM-CS',
        'SVM-BiP': 'SVM-CS-BiP',
    }

    # Only the ``options`` parameter is consulted; the earlier code read a
    # module-level ``opt`` in places, which breaks when this function is
    # called from another module. Explicit exceptions replace ``assert``,
    # which is stripped under ``python -O``.
    if options.network not in supported_networks:
        raise ValueError(
            "Cannot recognize the network type {!r}; expected one of: {}".format(
                options.network, ', '.join(supported_networks)))
    # The earlier test ``head == 'SVM' or 'LR'`` was always true, so every
    # unknown head silently became SVM-CS and 'SVM-BiP' was unreachable.
    if options.head not in head_to_base_learner:
        raise ValueError(
            "Cannot recognize the classification head type {!r}; "
            "expected one of: {}".format(
                options.head, ', '.join(head_to_base_learner)))

    # Choose the embedding network
    if options.network == 'ProtoNet':
        network = ProtoNetEmbedding().cuda()
    elif options.network == 'R2D2':
        network = R2D2Embedding().cuda()
    elif options.network == 'ResNet':
        # Exact dataset-name match for this backbone.
        on_imagenet = options.dataset in ('miniImageNet', 'tieredImageNet')
        network = resnet12(avg_pool=False, drop_rate=0.1,
                           dropblock_size=5 if on_imagenet else 2).cuda()
    else:  # 'ResNetRFS'
        # Case-insensitive substring match for this backbone.
        on_imagenet = 'imagenet' in options.dataset.lower()
        network = resnet12_rfs(avg_pool=True,
                               drop_rate=0.1,
                               dropblock_size=5 if on_imagenet else 2).cuda()

    # Choose the classification head
    cls_head = ClassificationHead(
        base_learner=head_to_base_learner[options.head]).cuda()

    return (network, cls_head)
Пример #6
0
def get_model(options):
    """Build the embedding network and classification head on ``options.device``.

    Args:
        options: Object exposing ``network`` ('ProtoNet', 'ResNet12' or
            'ResNet18') and ``device``.

    Returns:
        A ``(network, cls_head)`` tuple. Both are moved to
        ``options.device``; the two ResNet backbones are additionally wrapped
        in ``torch.nn.DataParallel``.

    Raises:
        ValueError: If ``options.network`` is not one of the supported names.
    """
    supported_networks = ('ProtoNet', 'ResNet12', 'ResNet18')
    # An explicit exception is used because ``assert`` is stripped under
    # ``python -O``, which would let an unknown name fall through.
    if options.network not in supported_networks:
        raise ValueError(
            "Cannot recognize the network type {!r}; expected one of: {}".format(
                options.network, ', '.join(supported_networks)))

    device = options.device

    # Choose the embedding network
    if options.network == 'ProtoNet':
        network = ProtoNetEmbedding().to(device)
    elif options.network == 'ResNet12':
        # resnet12 takes the device as its first positional argument here.
        backbone = resnet12(device,
                            avg_pool=False,
                            drop_rate=.1,
                            dropblock_size=2).to(device)
        network = torch.nn.DataParallel(backbone)
    else:  # 'ResNet18'
        network = torch.nn.DataParallel(
            resnet18(pretrained=False).to(device))

    # Set the classification head
    cls_head = ClassificationHead(device).to(device)

    return (network, cls_head)
Пример #7
0
def get_model(options):
    """Build the embedding network and classification head named in ``options``.

    Args:
        options: Object exposing ``network`` ('ProtoNet', 'R2D2',
            'WideResNet_DC' or 'ResNet_DC'), ``head`` ('ProtoNet', 'Ridge',
            'R2D2', 'SVM' or 'FIML'), ``dataset`` (for the two ``_DC``
            backbones), ``aws`` (for 'ResNet_DC') and, for the 'FIML' head,
            the hyper-parameters forwarded to
            ``dimp_norm_init_shannon_hingeL2Loss``.

    Returns:
        A ``(network, cls_head)`` tuple.

    Raises:
        ValueError: If ``options.network`` or ``options.head`` is not one of
            the supported names.
    """
    supported_networks = ('ProtoNet', 'R2D2', 'WideResNet_DC', 'ResNet_DC')
    supported_heads = ('ProtoNet', 'Ridge', 'R2D2', 'SVM', 'FIML')
    # NOTE(review): hard-coded device list; presumably assumes four visible
    # CUDA devices -- confirm before running on smaller machines.
    four_gpus = [0, 1, 2, 3]

    # Explicit exceptions are used because ``assert`` is stripped under
    # ``python -O``.
    if options.network not in supported_networks:
        raise ValueError(
            "Cannot recognize the network type {!r}; expected one of: {}".format(
                options.network, ', '.join(supported_networks)))
    if options.head not in supported_heads:
        raise ValueError(
            "Cannot recognize the classification head type {!r}; "
            "expected one of: {}".format(
                options.head, ', '.join(supported_heads)))

    # Choose the embedding network
    if options.network == 'ProtoNet':
        # 256 dimensional
        network = ProtoNetEmbedding().cuda()
    elif options.network == 'R2D2':
        # 8192 dimensional
        network = R2D2Embedding().cuda()
    elif options.network == 'WideResNet_DC':
        on_imagenet = options.dataset in ('miniImageNet', 'tieredImageNet')
        network = WideResNet(28,
                             widen_factor=10,
                             dropRate=0.1,
                             avgpool_param=21 if on_imagenet else 8).cuda()
        # The former ``options.aws`` branches were identical for this
        # backbone, so the flag is no longer consulted here.
        network = torch.nn.DataParallel(network, device_ids=four_gpus)
    else:  # 'ResNet_DC'
        # 2560 dimensional
        on_imagenet = options.dataset in ('miniImageNet', 'tieredImageNet')
        # DropBlock size and average-pool parameter share the same value.
        block_param = 5 if on_imagenet else 2
        network = resnet12(avg_pool=True,
                           drop_rate=0.1,
                           dropblock_size=block_param,
                           avgpool_param=block_param).cuda()
        if options.aws == 1:
            network = torch.nn.DataParallel(network, device_ids=four_gpus)
        else:
            network = torch.nn.DataParallel(network)

    # Choose the classification head
    if options.head == 'FIML':
        # NOTE(review): unlike the other heads, this one is not moved to CUDA
        # here -- confirm whether the factory handles device placement.
        cls_head = dimp_norm_init_shannon_hingeL2Loss(
            num_iter=options.steepest_descent_iter,
            norm_feat=options.norm_feat,
            entropy_weight=options.entropy_weight,
            entropy_temp=options.entropy_temp,
            learn_entropy_weights=options.learn_entropy_weights,
            learn_entropy_temp=options.learn_entropy_temp,
            learn_weights=options.learn_weights,
            pos_weight=options.pos_weight,
            neg_weight=options.neg_weight,
            learn_slope=options.learn_slope,
            pos_lrelu_slope=options.pos_lrelu_slope,
            neg_lrelu_slope=options.neg_lrelu_slope,
            learn_spatial_weight=options.learn_inner_spatial_weight,
            dc_factor=options.dc_factor)
    else:
        base_learner = {'ProtoNet': 'ProtoNet',
                        'Ridge': 'Ridge',
                        'R2D2': 'R2D2',
                        'SVM': 'SVM-CS'}[options.head]
        cls_head = ClassificationHead(base_learner=base_learner).cuda()

    return (network, cls_head)
Пример #8
0
def get_model(options):
    """Build the embedding network and classification head named in ``options``.

    Args:
        options: Object exposing ``network`` ('ProtoNet', 'R2D2', 'ResNet',
            'ResNetRFS', 'ResNet_FiLM' or 'ResNetRFS_FiLM') and ``head``
            ('ProtoNet', 'Ridge', 'R2D2', 'SVM', 'LR' or 'SVM-BiP'). The
            ResNet-family backbones also read ``dataset`` and ``gpu``; the
            FiLM backbones additionally read ``no_film_activation``,
            ``no_final_relu``, ``film_normalize`` and ``dual_BN``, and
            'ResNet_FiLM' reads ``film_indim``.

    Returns:
        A ``(network, cls_head)`` tuple. ProtoNet and R2D2 embeddings are
        plain CUDA modules; every ResNet-family backbone is wrapped in
        ``torch.nn.DataParallel`` with one device index per comma-separated
        entry of ``options.gpu``.

    Side effects:
        For 'ResNet_FiLM' on a dataset whose lower-cased name contains
        'imagenet', sets ``options.film_preprocess_input_dim = 16000``.

    Raises:
        ValueError: If ``options.network`` or ``options.head`` is not one of
            the supported names.
    """
    supported_networks = ('ProtoNet', 'R2D2', 'ResNet', 'ResNetRFS',
                          'ResNet_FiLM', 'ResNetRFS_FiLM')
    # User-facing head name -> ``base_learner`` value given to
    # ClassificationHead. 'LR' shares the 'SVM-CS' learner with 'SVM'.
    head_to_base_learner = {
        'ProtoNet': 'ProtoNet',
        'Ridge': 'Ridge',
        'R2D2': 'R2D2',
        'SVM': 'SVM-CS',
        'LR': 'SVM-CS',
        'SVM-BiP': 'SVM-CS-BiP',
    }

    # Only the ``options`` parameter is consulted; the earlier code mixed in
    # reads of a module-level ``opt``, which breaks when this function is
    # called from another module. Explicit exceptions replace ``assert``,
    # which is stripped under ``python -O``.
    if options.network not in supported_networks:
        raise ValueError(
            "Cannot recognize the network type {!r}; expected one of: {}".format(
                options.network, ', '.join(supported_networks)))
    if options.head not in head_to_base_learner:
        raise ValueError(
            "Cannot recognize the classification head type {!r}; "
            "expected one of: {}".format(
                options.head, ', '.join(head_to_base_learner)))

    # Choose the embedding network
    if options.network == 'ProtoNet':
        network = ProtoNetEmbedding().cuda()
    elif options.network == 'R2D2':
        network = R2D2Embedding().cuda()
    else:
        # All ResNet-family backbones share the dataset-dependent DropBlock
        # size and the DataParallel wrapping below.
        on_imagenet = 'imagenet' in options.dataset.lower()
        dropblock_size = 5 if on_imagenet else 2

        if options.network == 'ResNet':
            network = resnet12(avg_pool=False,
                               drop_rate=0.1,
                               dropblock_size=dropblock_size).cuda()
        elif options.network == 'ResNetRFS':
            network = resnet12_rfs(avg_pool=True,
                                   drop_rate=0.1,
                                   dropblock_size=dropblock_size).cuda()
        else:
            film_act = None if options.no_film_activation else F.leaky_relu
            film_kwargs = dict(
                film_alpha=1.0,
                film_act=film_act,
                final_relu=(not options.no_final_relu),
                film_normalize=options.film_normalize,
                dual_BN=options.dual_BN,
            )
            if options.network == 'ResNet_FiLM':
                network = resnet12_film(
                    avg_pool=False, drop_rate=0.1,
                    dropblock_size=dropblock_size,
                    film_indim=options.film_indim, **film_kwargs).cuda()
                if on_imagenet:
                    options.film_preprocess_input_dim = 16000
            else:  # 'ResNetRFS_FiLM'
                network = resnet12_rfs_film(
                    avg_pool=True, drop_rate=0.1,
                    dropblock_size=dropblock_size,
                    film_indim=640, **film_kwargs).cuda()

        # One device index per comma-separated entry of ``options.gpu``.
        device_ids = list(range(len(options.gpu.split(','))))
        network = torch.nn.DataParallel(network, device_ids=device_ids)

    # Choose the classification head
    cls_head = ClassificationHead(
        base_learner=head_to_base_learner[options.head]).cuda()

    return network, cls_head