Exemple #1
0
def load_network(code_path, babi_train_raw, babi_test_raw, word2vec,
                 word_vector_size, model_json):
    """Build the network described by a JSON config and restore its state.

    Loaded networks are cached in the module-level ``loaded_models`` dict,
    keyed by ``model_json``; a repeated call returns the cached instance
    without re-reading the config.

    Args:
        code_path: directory put at the front of ``sys.path`` so the
            network module (``dmn_smooth``) can be imported.
        babi_train_raw: stored in the config dict as ``babi_train_raw``.
        babi_test_raw: stored in the config dict as ``babi_test_raw``.
        word2vec: stored in the config dict as ``word2vec``.
        word_vector_size: must equal the config's ``word_vector_size``.
        model_json: path to the JSON config. It is read for the keys
            ``word_vector_size``, ``network`` and ``load_state``, and the
            whole dict (plus the three values above) is passed to the
            network constructor as keyword arguments.

    Returns:
        The constructed network after ``load_state`` has been called on it.

    Raises:
        AssertionError: if ``word_vector_size`` disagrees with the config.
        Exception: if the configured network has no ``predict()``
            ('dmn_batch', 'dmn_basic', 'dmn_qa') or is unknown.
    """
    # Parenthesised single-argument print behaves the same on Python 2 and 3.
    if model_json in loaded_models:
        print("!> model %s is already loaded" % model_json)
        return loaded_models[model_json]

    print("!> loading model %s..." % model_json)

    # Use a context manager so the config file handle is always closed.
    with open(model_json) as model_file:
        args_dict = json.load(model_file)

    assert word_vector_size == args_dict['word_vector_size'], (
        "word_vector_size %r does not match model config value %r"
        % (word_vector_size, args_dict['word_vector_size']))

    args_dict['babi_train_raw'] = babi_train_raw
    args_dict['babi_test_raw'] = babi_test_raw
    args_dict['word2vec'] = word2vec

    network = args_dict['network']
    # These variants have no predict(), so they cannot be loaded here
    # (their constructors would be dmn_batch.DMN_batch, dmn_basic.DMN_basic
    # and dmn_qa_draft.DMN_qa). Name the actual variant in the error.
    if network in ('dmn_batch', 'dmn_basic', 'dmn_qa'):
        raise Exception("%s did not implement predict()" % network)
    if network != 'dmn_smooth':
        raise Exception("No such network known: " + network)

    # Make the network module importable from code_path while it is imported
    # and constructed, then push current_dir back to the front even if
    # construction fails.
    sys.path.insert(0, code_path)
    try:
        import dmn_smooth
        dmn = dmn_smooth.DMN_smooth(**args_dict)
    finally:
        sys.path.insert(0, current_dir)

    print("!> loading state %s..." % args_dict['load_state'])
    dmn.load_state(args_dict['load_state'])

    loaded_models[model_json] = dmn
    return dmn
Exemple #2
0
def get_dmn(network, batch_size, args_dict):
    """Construct the DMN variant selected by ``network``.

    ``args_dict`` is forwarded as keyword arguments to the variant's class.
    Every variant except 'dmn_batch' runs without minibatch training, so a
    ``batch_size`` other than 1 triggers a warning and is replaced by 1.

    Returns:
        A ``(dmn, batch_size)`` tuple holding the new network and the
        batch size the caller should actually use.

    Raises:
        Exception: if ``network`` is not a recognised variant (the batch
            size warning may already have been printed at that point).
    """
    # The minibatch variant keeps the requested batch size untouched.
    if network == 'dmn_batch':
        import dmn_batch
        return dmn_batch.DMN_batch(**args_dict), batch_size

    # All other variants process one example at a time.
    if batch_size != 1:
        print("==> no minibatch training, argument batch_size is useless")
        batch_size = 1

    if network == 'dmn_basic':
        import dmn_basic
        model = dmn_basic.DMN_basic(**args_dict)
    elif network == 'dmn_smooth':
        import dmn_smooth
        model = dmn_smooth.DMN_smooth(**args_dict)
    elif network == 'dmn_spv':
        import dmn_spv
        model = dmn_spv.DMN(**args_dict)
    else:
        raise Exception("No such network known: " + network)

    return model, batch_size
Exemple #3
0
    # NOTE(review): the `if args.network == ...:` line that opens this
    # dispatch chain is missing from this snippet (the `elif` below needs
    # it); these two lines are presumably its 'dmn_batch' branch -- confirm.
    # Unlike the branches below, this one leaves args.batch_size unchanged.
    import dmn_batch
    dmn = dmn_batch.DMN_batch(**args_dict)

elif args.network == 'dmn_basic':
    import dmn_basic
    # Non-batch variant: warn and force a batch size of 1.
    if (args.batch_size != 1):
        print "==> no minibatch training, argument batch_size is useless"
        args.batch_size = 1
    dmn = dmn_basic.DMN_basic(**args_dict)

elif args.network == 'dmn_smooth':
    import dmn_smooth
    # Non-batch variant: warn and force a batch size of 1.
    if (args.batch_size != 1):
        print "==> no minibatch training, argument batch_size is useless"
        args.batch_size = 1
    dmn = dmn_smooth.DMN_smooth(**args_dict)

elif args.network == 'dmn_qa':
    # The 'dmn_qa' network lives in the dmn_qa_draft module.
    import dmn_qa_draft
    # Non-batch variant: warn and force a batch size of 1.
    if (args.batch_size != 1):
        print "==> no minibatch training, argument batch_size is useless"
        args.batch_size = 1
    dmn = dmn_qa_draft.DMN_qa(**args_dict)

else: 
    # Any other network name is rejected.
    raise Exception("No such network known: " + args.network)
    

# Restore saved parameters only when a state path was supplied; an empty
# string skips loading.
if args.load_state != "":
    dmn.load_state(args.load_state)