Example #1
0
def factory(engine):
    """Build the network named by the ``model.network.name`` option.

    * ``vqa_net`` -- a ``VQANet`` configured from the options plus the
      vocabulary / answer mappings of the first dataset split in ``engine``.
    * ``vrd_net`` -- a ``VRDNet`` that receives the whole option block.

    When more than one CUDA device is visible the network is wrapped in
    ``DataParallel``.

    Raises:
        ValueError: if the configured name is not one of the above.
    """
    first_split = list(engine.dataset.keys())[0]
    split_data = engine.dataset[first_split]
    net_cfg = Options()['model.network']
    kind = net_cfg['name']

    if kind == 'vrd_net':
        model = VRDNet(net_cfg)
    elif kind == 'vqa_net':
        model = VQANet(
            txt_enc=net_cfg['txt_enc'],
            self_q_att=net_cfg['self_q_att'],
            attention=net_cfg['attention'],
            classif=net_cfg['classif'],
            wid_to_word=split_data.wid_to_word,
            word_to_wid=split_data.word_to_wid,
            aid_to_ans=split_data.aid_to_ans,
            ans_to_aid=split_data.ans_to_aid,
        )
    else:
        raise ValueError(kind)

    # Only spread across devices when there is more than one to use.
    if torch.cuda.device_count() > 1:
        model = DataParallel(model)
    return model
Example #2
0
def factory(engine):
    """Instantiate the HDD network selected by ``model.network.name``.

    Supported names are ``beef_hdd``, ``driver_hdd`` and
    ``baseline_multitask_hdd``; each reads its constructor arguments from the
    ``model.network`` options. ``engine`` is accepted for a uniform factory
    signature but is not used.

    The network is wrapped in ``DataParallel`` when more than one CUDA device
    is visible.

    Raises:
        ValueError: if the configured name is not supported.
    """
    cfg = Options()['model.network']
    arch = cfg['name']

    if arch == "beef_hdd":
        model = BeefHDD(
            layers_to_fuse=cfg['layers_to_fuse'],
            label_fusion_opt=cfg['label_fusion'],
            blinkers_dim=cfg['blinkers_dim'],
            gru_opt=cfg['gru_opt'],
            n_future=cfg['n_future'],
            # Optional; defaults to False when absent from the options.
            detach_pred=cfg.get('detach_pred', False),
        )
    elif arch == "driver_hdd":
        model = DriverHDD(
            blinkers_dim=cfg['blinkers_dim'],
            gru_opt=cfg['gru_opt'],
            n_future=cfg['n_future'],
        )
    elif arch == "baseline_multitask_hdd":
        model = BaselineMultitaskHDD(
            n_classes=cfg['n_classes'],
            blinkers_dim=cfg['blinkers_dim'],
            layer_to_extract=cfg['layer_to_extract'],
            dim_features=cfg['dim_features'],
            gru_opt=cfg['gru_opt'],
            n_future=cfg['n_future'],
            # Optional; defaults to None when absent from the options.
            mlp_opt=cfg.get('mlp_opt', None),
        )
    else:
        raise ValueError(arch)

    if torch.cuda.device_count() > 1:
        model = DataParallel(model)
    return model
Example #3
0
def factory(engine):
    """Build a VQA backbone and optionally wrap it.

    ``model.network.base`` selects the backbone class (``smrl``, ``updn`` or
    ``san``). ``model.network.name`` selects the wrapper:

    * ``baseline``    -- the backbone itself;
    * ``rubi``        -- ``RUBiNet`` around the backbone;
    * ``cfvqa``       -- ``CFVQA`` with ``mlp_q`` and ``mlp_v`` classifiers
      and ``is_va=True``;
    * ``cfvqasimple`` -- ``CFVQA`` with only ``mlp_q`` (``classif_v=None``)
      and ``is_va=False``.

    Vocabulary / answer mappings come from the first dataset split in
    ``engine``. The result is wrapped in ``DataParallel`` when ``misc.cuda``
    is set and more than one CUDA device is visible.

    Raises:
        ValueError: on an unknown ``base`` or ``name``.
    """
    first_split = list(engine.dataset.keys())[0]
    split_data = engine.dataset[first_split]
    net_cfg = Options()['model.network']

    # Backbone modules are imported lazily so only the selected one is loaded.
    backbone = net_cfg['base']
    if backbone == 'smrl':
        from .smrl_net import SMRLNet as BaselineNet
    elif backbone == 'updn':
        from .updn_net import UpDnNet as BaselineNet
    elif backbone == 'san':
        from .san_net import SANNet as BaselineNet
    else:
        raise ValueError(backbone)

    base_model = BaselineNet(
        txt_enc=net_cfg['txt_enc'],
        self_q_att=net_cfg['self_q_att'],
        agg=net_cfg['agg'],
        classif=net_cfg['classif'],
        wid_to_word=split_data.wid_to_word,
        word_to_wid=split_data.word_to_wid,
        aid_to_ans=split_data.aid_to_ans,
        ans_to_aid=split_data.ans_to_aid,
        fusion=net_cfg['fusion'],
        residual=net_cfg['residual'],
        q_single=net_cfg['q_single'],
    )

    variant = net_cfg['name']
    if variant == 'baseline':
        model = base_model
    elif variant == 'rubi':
        model = RUBiNet(
            model=base_model,
            output_size=len(split_data.aid_to_ans),
            classif=net_cfg['rubi_params']['mlp_q'],
        )
    elif variant in ('cfvqa', 'cfvqasimple'):
        # The two CFVQA variants differ only in classif_v and is_va.
        with_v = variant == 'cfvqa'
        model = CFVQA(
            model=base_model,
            output_size=len(split_data.aid_to_ans),
            classif_q=net_cfg['cfvqa_params']['mlp_q'],
            classif_v=net_cfg['cfvqa_params']['mlp_v'] if with_v else None,
            fusion_mode=net_cfg['fusion_mode'],
            is_va=with_v,
        )
    else:
        raise ValueError(variant)

    if Options()['misc.cuda'] and torch.cuda.device_count() > 1:
        model = DataParallel(model)
    return model
Example #4
0
def factory(engine):
    """Template factory: build the project network from ``model.network``.

    The brace-wrapped PROJECT_NAME placeholders below are presumably
    substituted when a project is generated from this template; until then
    the file is not valid Python, which is why the constructor line carries
    ``noqa: E999`` (flake8's syntax-error code).

    Every ``model.network`` option (including ``name``) is forwarded as a
    keyword argument to the network constructor. The result is wrapped in
    ``DataParallel`` when more than one CUDA device is visible.

    Raises:
        ValueError: if ``model.network.name`` is not the generated name.
    """
    opt = Options()['model.network']

    if opt['name'] == '{PROJECT_NAME_LOWER}':
        net = {PROJECT_NAME}Network(**opt)  # noqa: E999
    else:
        raise ValueError(opt['name'])

    # Multi-GPU: replicate the module across all visible devices.
    if torch.cuda.device_count() > 1:
        net = DataParallel(net)
    return net
Example #5
0
def factory(engine=None):
    """Create the MSEmbedding network selected by ``model.network.name``.

    Supported names are ``MSEmbeddingNet``, ``MSEmbedding_MLP_Net``,
    ``MSEmbeddingNormNet``, ``MSEmbeddingTransformerNet`` and
    ``MSEmbeddingTransformer2Net``; each is constructed without arguments.

    The network is wrapped in ``DataParallel`` when ``misc.cuda`` is enabled
    and ``utils.available_gpu_ids()`` reports at least two GPUs.

    Args:
        engine: unused; accepted so all factories share one signature.

    Returns:
        The constructed (possibly ``DataParallel``-wrapped) network.

    Raises:
        ValueError: if the configured network name is not supported.
    """
    Logger()('Creating MSEmbedding network...')

    networks = {
        'MSEmbeddingNet': MSEmbeddingNet,
        'MSEmbedding_MLP_Net': MSEmbedding_MLP_Net,
        'MSEmbeddingNormNet': MSEmbeddingNormNet,
        'MSEmbeddingTransformerNet': MSEmbeddingTransformerNet,
        'MSEmbeddingTransformer2Net': MSEmbeddingTransformer2Net,
    }

    name = Options()['model']['network']['name']
    if name not in networks:
        raise ValueError(f'Unknown network name: {name!r}')
    network = networks[name]()

    # The multi-GPU wrapping is identical for every architecture.
    if Options()['misc']['cuda'] and len(utils.available_gpu_ids()) >= 2:
        network = DataParallel(network)

    return network
Example #6
0
def factory(engine=None):
    """Create the mnist network (``model.network.name == 'net'``).

    The network is wrapped in ``DataParallel`` when ``misc.cuda`` is enabled
    and ``utils.available_gpu_ids()`` reports at least two GPUs.

    Args:
        engine: unused; accepted so all factories share one signature.

    Returns:
        The constructed (possibly ``DataParallel``-wrapped) network.

    Raises:
        ValueError: if the configured network name is not ``'net'``.
    """
    Logger()('Creating mnist network...')

    name = Options()['model']['network']['name']
    if name != 'net':
        # Include the offending name so misconfigurations are diagnosable.
        raise ValueError(f'Unknown network name: {name!r}')

    network = Net()
    if Options()['misc']['cuda'] and len(utils.available_gpu_ids()) >= 2:
        network = DataParallel(network)

    return network
Example #7
0
def factory(engine):
    logger = Logger()
    net_opt = Options()["model"]["network"]
    logger("Creating Network...")

    if net_opt["name"] == "{PROJECT_NAME_LOWER}network":
        # You can use any param to create your network
        # You just have to write them in your option file from options/ folder
        net = {PROJECT_NAME}Network(net_opt["param1"], net_opt["param2"])
    else:
        raise ValueError(opt["name"])
    logger("Network was created")
    if torch.cuda.device_count() > 1:
        net = DataParallel(net)
    return net
def factory(engine):
    """Create a ``BaselineNet`` and optionally wrap it in ``RUBiNet``.

    ``model.network.name`` must be ``baseline`` (return the backbone) or
    ``rubi`` (wrap it in ``RUBiNet`` using ``rubi_params.mlp_q``). Vocabulary
    and answer mappings come from the first dataset split in ``engine``.

    The result is wrapped in ``DataParallel`` when ``misc.cuda`` is set and
    more than one CUDA device is visible.

    Raises:
        ValueError: if the configured name is neither ``baseline`` nor
            ``rubi``.
    """
    mode = list(engine.dataset.keys())[0]
    dataset = engine.dataset[mode]
    opt = Options()['model.network']

    # Validate first so nothing is constructed for an unknown name.
    if opt['name'] not in ('baseline', 'rubi'):
        raise ValueError(opt['name'])

    # Both variants share the same backbone; build it once instead of
    # duplicating the constructor call per branch.
    net = BaselineNet(
        txt_enc=opt['txt_enc'],
        self_q_att=opt['self_q_att'],
        agg=opt['agg'],
        classif=opt['classif'],
        wid_to_word=dataset.wid_to_word,
        word_to_wid=dataset.word_to_wid,
        aid_to_ans=dataset.aid_to_ans,
        ans_to_aid=dataset.ans_to_aid,
        fusion=opt['fusion'],
        residual=opt['residual'],
    )

    if opt['name'] == 'rubi':
        net = RUBiNet(
            model=net,
            output_size=len(dataset.aid_to_ans),
            classif=opt['rubi_params']['mlp_q'],
        )

    if Options()['misc.cuda'] and torch.cuda.device_count() > 1:
        net = DataParallel(net)

    return net
def factory(engine):
    """Instantiate a network class located dynamically from the options.

    ``model.network.name`` has the form ``"<module>.<ClassName>"``. The module
    is imported relative to ``counting.models.networks``.

    Entries of ``model.network.parameters`` whose value is ``"@engine"``,
    ``"@aid_to_ans"`` or ``"@ans_to_aid"`` are replaced by ``engine`` or by
    the matching attribute of the first dataset split. The resolved
    parameters are passed as keyword arguments, together with that split's
    ``wid_to_word``, ``word_to_wid``, ``aid_to_ans`` and ``ans_to_aid``.

    The network is wrapped in ``DataParallel`` when ``misc.cuda`` is set and
    more than one CUDA device is visible.

    If the class cannot be imported, the traceback is printed, an error is
    logged, and the process exits with status 1.
    """
    mode = list(engine.dataset.keys())[0]
    dataset = engine.dataset[mode]
    opt = Options()["model.network"]

    module, class_name = opt["name"].rsplit(".", 1)
    try:
        cls = getattr(
            import_module("." + module, "counting.models.networks"),
            class_name)
    except Exception:
        # Not a bare ``except``: KeyboardInterrupt / SystemExit raised while
        # importing user code must still propagate.
        traceback.print_exc()
        Logger()(f"Error importing class {module}, {class_name}")
        sys.exit(1)

    # ``parameters`` may be absent or null in the options; treat both as empty
    # (previously the print below raised KeyError when it was absent).
    raw_parameters = opt.get("parameters", {}) or {}
    print("Network parameters", raw_parameters)
    print("checking if @ in parameters")

    # Resolve "@..." placeholders into a copy rather than mutating ``opt`` in
    # place, so the options object does not end up holding runtime objects
    # such as the engine.
    parameters = dict(raw_parameters)
    for key, value in raw_parameters.items():  # TODO: integrate this into bootstrap
        if value == "@dataset":
            # NOTE(review): this branch only logs; the literal "@dataset"
            # string is still passed to the network. Confirm whether
            # ``dataset`` should be injected here.
            print("loading dataset")
        elif value == "@engine":
            parameters[key] = engine
        elif value == "@aid_to_ans":
            parameters[key] = dataset.aid_to_ans
        elif value == "@ans_to_aid":
            parameters[key] = dataset.ans_to_aid

    net = cls(
        **parameters,
        wid_to_word=dataset.wid_to_word,
        word_to_wid=dataset.word_to_wid,
        aid_to_ans=dataset.aid_to_ans,
        ans_to_aid=dataset.ans_to_aid,
    )

    if Options()["misc.cuda"] and torch.cuda.device_count() > 1:
        net = DataParallel(net)

    return net
Example #10
0
def factory(engine):
    """Build a ``BanNet`` (``model.network.name == 'ban_net'``).

    Architecture settings come from the ``model.network`` options. Vocabulary
    and answer mappings come from the first dataset split in ``engine``. The
    network is wrapped in ``DataParallel`` when more than one CUDA device is
    visible.

    Raises:
        ValueError: if the configured name is not ``ban_net``.
    """
    first_split = list(engine.dataset.keys())[0]
    split_data = engine.dataset[first_split]
    cfg = Options()['model.network']

    if cfg['name'] != 'ban_net':
        raise ValueError(cfg['name'])

    model = BanNet(
        txt_enc=cfg['txt_enc'],
        glimpse=cfg['glimpse'],
        objects=cfg['objects'],
        feat_dims=cfg['feat_dims'],
        q_max_length=cfg['q_max_length'],
        wid_to_word=split_data.wid_to_word,
        word_to_wid=split_data.word_to_wid,
        aid_to_ans=split_data.aid_to_ans,
        ans_to_aid=split_data.ans_to_aid,
    )

    if torch.cuda.device_count() > 1:
        model = DataParallel(model)
    return model
def factory(engine):
    """Build an ``AttentionNet`` or a ``MuRelNet`` from the options.

    ``model.network.name`` selects ``attention_net`` or ``murel_net``. Both
    receive the vocabulary and answer mappings of the first dataset split in
    ``engine`` in addition to their architecture options. The network is
    wrapped in ``DataParallel`` when more than one CUDA device is visible.

    Raises:
        ValueError: if the configured name is not supported.
    """
    first_split = list(engine.dataset.keys())[0]
    split_data = engine.dataset[first_split]
    cfg = Options()['model.network']

    def vocab_kwargs():
        # Mapping arguments shared by both architectures. They are passed
        # last, after the architecture-specific keyword arguments.
        return {
            'wid_to_word': split_data.wid_to_word,
            'word_to_wid': split_data.word_to_wid,
            'aid_to_ans': split_data.aid_to_ans,
            'ans_to_aid': split_data.ans_to_aid,
        }

    arch = cfg['name']
    if arch == 'attention_net':
        model = AttentionNet(
            txt_enc=cfg['txt_enc'],
            self_q_att=cfg['self_q_att'],
            attention=cfg['attention'],
            classif=cfg['classif'],
            **vocab_kwargs(),
        )
    elif arch == 'murel_net':
        model = MuRelNet(
            txt_enc=cfg['txt_enc'],
            self_q_att=cfg['self_q_att'],
            n_step=cfg['n_step'],
            shared=cfg['shared'],
            cell=cfg['cell'],
            agg=cfg['agg'],
            classif=cfg['classif'],
            **vocab_kwargs(),
        )
    else:
        raise ValueError(arch)

    if torch.cuda.device_count() > 1:
        model = DataParallel(model)
    return model