Example #1
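# NOTE: the train/predict listings below omit their imports. The standard-library,
# NumPy and Keras imports here are safe to assume from the calls in the code; the
# helpers read_embedding, convert_embed_2_numpy, read_data, load_model and the
# modules inputs, metric, rank_losses (plus mz_specialized_losses) come from the
# surrounding MatchZoo-style code base and are not shown in this listing.
import json
import sys
import time
from collections import OrderedDict

import numpy as np
import keras.backend as K
from keras import optimizers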
def train(config):

    print(json.dumps(config, indent=2), end='\n')
    # read basic config
    global_conf = config["global"]
    optimizer = global_conf['optimizer']
    optimizer = optimizers.get(optimizer)
    K.set_value(optimizer.lr, global_conf['learning_rate'])
    weights_file = str(global_conf['weights_file']) + '.%d'
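    # The '%d' placeholder is filled with the iteration count when checkpoints are
    # saved below, yielding names like '<weights_file>.<iteration>'.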
    display_interval = int(global_conf['display_interval'])
    num_iters = int(global_conf['num_iters'])
    save_weights_iters = int(global_conf['save_weights_iters'])

    # read input config
    input_conf = config['inputs']
    share_input_conf = input_conf['share']
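    # 'share' holds settings common to every input tag; each tag-specific config
    # below is built by copying these defaults and then overriding them with the
    # tag's own entries.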

    # build the embedding matrix: use pretrained vectors when 'embed_path' is
    # given, otherwise fall back to random initialization
    if 'embed_path' in share_input_conf:
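        # The padding token is mapped to the last vocabulary index and its row is
        # zeroed; words absent from the pretrained file presumably keep the random
        # initialization generated below (behaviour of convert_embed_2_numpy,
        # which is defined elsewhere in the project).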
        embed_dict = read_embedding(filename=share_input_conf['embed_path'])
        _PAD_ = share_input_conf['vocab_size'] - 1
        embed_dict[_PAD_] = np.zeros((share_input_conf['embed_size'], ),
                                     dtype=np.float32)
        embed = np.float32(
            np.random.uniform(-0.2, 0.2, [
                share_input_conf['vocab_size'], share_input_conf['embed_size']
            ]))
        share_input_conf['embed'] = convert_embed_2_numpy(embed_dict,
                                                          embed=embed)
    else:
        embed = np.float32(
            np.random.uniform(-0.2, 0.2, [
                share_input_conf['vocab_size'], share_input_conf['embed_size']
            ]))
        share_input_conf['embed'] = embed
    print('[Embedding] Embedding Load Done.', end='\n')

    # split the input tags by phase (TRAIN / EVAL) and build each tag's config
    # on top of the shared settings
    input_train_conf = OrderedDict()
    input_eval_conf = OrderedDict()
    for tag in input_conf.keys():
        if 'phase' not in input_conf[tag]:
            continue
        if input_conf[tag]['phase'] == 'TRAIN':
            input_train_conf[tag] = {}
            input_train_conf[tag].update(share_input_conf)
            input_train_conf[tag].update(input_conf[tag])
        elif input_conf[tag]['phase'] == 'EVAL':
            input_eval_conf[tag] = {}
            input_eval_conf[tag].update(share_input_conf)
            input_eval_conf[tag].update(input_conf[tag])
    print('[Input] Process Input Tags. %s in TRAIN, %s in EVAL.' %
          (input_train_conf.keys(), input_eval_conf.keys()),
          end='\n')

    # load each referenced corpus once, keyed by its file path, so tags that
    # share a corpus reuse the same data
    dataset = {}
    for tag in input_conf:
        if tag != 'share' and input_conf[tag]['phase'] == 'PREDICT':
            continue
        if 'text1_corpus' in input_conf[tag]:
            datapath = input_conf[tag]['text1_corpus']
            if datapath not in dataset:
                dataset[datapath], _ = read_data(datapath)
        if 'text2_corpus' in input_conf[tag]:
            datapath = input_conf[tag]['text2_corpus']
            if datapath not in dataset:
                dataset[datapath], _ = read_data(datapath)
    print('[Dataset] %s Dataset Load Done.' % len(dataset), end='\n')

    # initialize the data generators
    train_gen = OrderedDict()
    eval_gen = OrderedDict()

    for tag, conf in input_train_conf.items():
        print(conf, end='\n')
        conf['data1'] = dataset[conf['text1_corpus']]
        conf['data2'] = dataset[conf['text2_corpus']]
        generator = inputs.get(conf['input_type'])
        train_gen[tag] = generator(config=conf)

    for tag, conf in input_eval_conf.items():
        print(conf, end='\n')
        conf['data1'] = dataset[conf['text1_corpus']]
        conf['data2'] = dataset[conf['text2_corpus']]
        generator = inputs.get(conf['input_type'])
        eval_gen[tag] = generator(config=conf)

    ######### Load Model #########
    model = load_model(config)

    loss = []
    for lobj in config['losses']:
        if lobj['object_name'] in mz_specialized_losses:
            loss.append(
                rank_losses.get(lobj['object_name'])(lobj['object_params']))
        else:
            loss.append(rank_losses.get(lobj['object_name']))
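    # Parse metric specs: a name such as 'ndcg@10' splits into the metric key
    # ('ndcg') and an integer cutoff (10); names without '@' are looked up as-is.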
    eval_metrics = OrderedDict()
    for mobj in config['metrics']:
        mobj = mobj.lower()
        if '@' in mobj:
            mt_key, mt_val = mobj.split('@', 1)
            eval_metrics[mobj] = metric.get(mt_key)(int(mt_val))
        else:
            eval_metrics[mobj] = metric.get(mobj)
    model.compile(optimizer=optimizer, loss=loss)
    print('[Model] Model Compile Done.', end='\n')

    for i_e in range(num_iters):
        for tag, generator in train_gen.items():
            genfun = generator.get_batch_generator()
            print('[%s]\t[Train:%s] ' % (time.strftime(
                '%m-%d-%Y %H:%M:%S', time.localtime(time.time())), tag),
                  end='')
            history = model.fit_generator(genfun,
                                          steps_per_epoch=display_interval,
                                          epochs=1,
                                          shuffle=False,
                                          verbose=0)  #callbacks=[eval_map])
            print('Iter:%d\tloss=%.6f' % (i_e, history.history['loss'][0]),
                  end='\n')

        for tag, generator in eval_gen.items():
            genfun = generator.get_batch_generator()
            print('[%s]\t[Eval:%s] ' % (time.strftime(
                '%m-%d-%Y %H:%M:%S', time.localtime(time.time())), tag),
                  end='')
            res = {k: 0. for k in eval_metrics}
            num_valid = 0
            for input_data, y_true in genfun:
                y_pred = model.predict(input_data, batch_size=len(y_true))
                if issubclass(type(generator),
                              inputs.list_generator.ListBasicGenerator):
                    list_counts = input_data['list_counts']
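                    # list_counts holds cumulative offsets: the documents of the
                    # i-th query list occupy positions
                    # list_counts[i]:list_counts[i + 1], so each metric is
                    # computed per query list and averaged over num_valid lists.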
                    for k, eval_func in eval_metrics.items():
                        for lc_idx in range(len(list_counts) - 1):
                            pre = list_counts[lc_idx]
                            suf = list_counts[lc_idx + 1]
                            res[k] += eval_func(y_true=y_true[pre:suf],
                                                y_pred=y_pred[pre:suf])
                    num_valid += len(list_counts) - 1
                else:
                    for k, eval_func in eval_metrics.items():
                        res[k] += eval_func(y_true=y_true, y_pred=y_pred)
                    num_valid += 1
            generator.reset()
            print('Iter:%d\t%s' % (i_e, '\t'.join(
                ['%s=%f' % (k, v / num_valid) for k, v in res.items()])),
                  end='\n')
            sys.stdout.flush()
        if (i_e + 1) % save_weights_iters == 0:
            model.save_weights(weights_file % (i_e + 1))
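
# Minimal usage sketch (hypothetical config path; a real config is one of the
# JSON files the project ships for its models):
# if __name__ == '__main__':
#     with open('examples/ranking.config') as f:
#         train(json.load(f))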
Example #2
def predict(config):
    ######## Read input config ########

    print(json.dumps(config, indent=2), end='\n')
    input_conf = config['inputs']
    share_input_conf = input_conf['share']

    # build the embedding matrix: use pretrained vectors when 'embed_path' is
    # given, otherwise fall back to random initialization
    if 'embed_path' in share_input_conf:
        embed_dict = read_embedding(filename=share_input_conf['embed_path'])
        _PAD_ = share_input_conf['vocab_size'] - 1
        embed_dict[_PAD_] = np.zeros((share_input_conf['embed_size'], ),
                                     dtype=np.float32)
        embed = np.float32(
            np.random.uniform(-0.02, 0.02, [
                share_input_conf['vocab_size'], share_input_conf['embed_size']
            ]))
        share_input_conf['embed'] = convert_embed_2_numpy(embed_dict,
                                                          embed=embed)
    else:
        embed = np.float32(
            np.random.uniform(-0.2, 0.2, [
                share_input_conf['vocab_size'], share_input_conf['embed_size']
            ]))
        share_input_conf['embed'] = embed
    print('[Embedding] Embedding Load Done.', end='\n')

    # collect the PREDICT-phase tags and build each config on top of the shared
    # settings
    input_predict_conf = OrderedDict()
    for tag in input_conf.keys():
        if 'phase' not in input_conf[tag]:
            continue
        if input_conf[tag]['phase'] == 'PREDICT':
            input_predict_conf[tag] = {}
            input_predict_conf[tag].update(share_input_conf)
            input_predict_conf[tag].update(input_conf[tag])
    print('[Input] Process Input Tags. %s in PREDICT.' %
          (input_predict_conf.keys()),
          end='\n')

    # load each referenced corpus once, keyed by its file path, so tags that
    # share a corpus reuse the same data
    dataset = {}
    for tag in input_conf:
        if tag == 'share' or input_conf[tag]['phase'] == 'PREDICT':
            if 'text1_corpus' in input_conf[tag]:
                datapath = input_conf[tag]['text1_corpus']
                if datapath not in dataset:
                    dataset[datapath], _ = read_data(datapath)
            if 'text2_corpus' in input_conf[tag]:
                datapath = input_conf[tag]['text2_corpus']
                if datapath not in dataset:
                    dataset[datapath], _ = read_data(datapath)
    print('[Dataset] %s Dataset Load Done.' % len(dataset), end='\n')

    # initialize the data generators
    predict_gen = OrderedDict()

    for tag, conf in input_predict_conf.items():
        print(conf, end='\n')
        conf['data1'] = dataset[conf['text1_corpus']]
        conf['data2'] = dataset[conf['text2_corpus']]
        generator = inputs.get(conf['input_type'])
        predict_gen[tag] = generator(config=conf)

    ######## Read output config ########
    output_conf = config['outputs']

    ######## Load Model ########
    global_conf = config["global"]
    weights_file = str(global_conf['weights_file']) + '.' + str(
        global_conf['test_weights_iters'])
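    # The checkpoint name mirrors the pattern used during training, where weights
    # are written as '<weights_file>.<iteration>'; 'test_weights_iters' selects
    # which of those checkpoints to evaluate.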

    model = load_model(config)
    model.load_weights(weights_file)

    eval_metrics = OrderedDict()
    for mobj in config['metrics']:
        mobj = mobj.lower()
        if '@' in mobj:
            mt_key, mt_val = mobj.split('@', 1)
            eval_metrics[mobj] = metric.get(mt_key)(int(mt_val))
        else:
            eval_metrics[mobj] = metric.get(mobj)
    res = {k: 0. for k in eval_metrics}

    for tag, generator in predict_gen.items():
        genfun = generator.get_batch_generator()
        print('[%s]\t[Predict] @ %s ' % (time.strftime(
            '%m-%d-%Y %H:%M:%S', time.localtime(time.time())), tag),
              end='')
        num_valid = 0
        res_scores = {}
        for input_data, y_true in genfun:
            y_pred = model.predict(input_data, batch_size=len(y_true))

            if issubclass(type(generator),
                          inputs.list_generator.ListBasicGenerator):
                list_counts = input_data['list_counts']
                for k, eval_func in eval_metrics.items():
                    for lc_idx in range(len(list_counts) - 1):
                        pre = list_counts[lc_idx]
                        suf = list_counts[lc_idx + 1]
                        res[k] += eval_func(y_true=y_true[pre:suf],
                                            y_pred=y_pred[pre:suf])

                y_pred = np.squeeze(y_pred)
                for lc_idx in range(len(list_counts) - 1):
                    pre = list_counts[lc_idx]
                    suf = list_counts[lc_idx + 1]
                    for p, y, t in zip(input_data['ID'][pre:suf],
                                       y_pred[pre:suf], y_true[pre:suf]):
                        if p[0] not in res_scores:
                            res_scores[p[0]] = {}
                        res_scores[p[0]][p[1]] = (y, t)

                num_valid += len(list_counts) - 1
            else:
                for k, eval_func in eval_metrics.items():
                    res[k] += eval_func(y_true=y_true, y_pred=y_pred)
                for p, y, t in zip(input_data['ID'], y_pred, y_true):
                    if p[0] not in res_scores:
                        res_scores[p[0]] = {}
                    res_scores[p[0]][p[1]] = (y[1], t[1])
                num_valid += 1
        generator.reset()

        if tag in output_conf:
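            # Persist the ranked results for this tag. 'TREC' writes run-file
            # style lines (qid, Q0, docid, rank, score, run name, with the ground
            # truth appended); 'TEXTNET' writes plain 'label qid docid score' lines.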
            if output_conf[tag]['save_format'] == 'TREC':
                with open(output_conf[tag]['save_path'], 'w') as f:
                    for qid, dinfo in res_scores.items():
                        dinfo = sorted(dinfo.items(),
                                       key=lambda d: d[1][0],
                                       reverse=True)
                        for inum, (did, (score, gt)) in enumerate(dinfo):
                            f.write('%s\tQ0\t%s\t%d\t%f\t%s\t%s\n' %
                                    (qid, did, inum, score, config['net_name'],
                                     gt))
            elif output_conf[tag]['save_format'] == 'TEXTNET':
                with open(output_conf[tag]['save_path'], 'w') as f:
                    for qid, dinfo in res_scores.items():
                        dinfo = sorted(dinfo.items(),
                                       key=lambda d: d[1][0],
                                       reverse=True)
                        for inum, (did, (score, gt)) in enumerate(dinfo):
                            f.write('%s %s %s %s\n' % (gt, qid, did, score))

        print('[Predict] results: ',
              '\t'.join(['%s=%f' % (k, v / num_valid)
                         for k, v in res.items()]),
              end='\n')
        sys.stdout.flush()
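
# Minimal usage sketch (hypothetical config path): predict() loads the checkpoint
# named by 'test_weights_iters', scores the PREDICT-phase inputs and writes the
# configured output files.
# if __name__ == '__main__':
#     with open('examples/ranking.config') as f:
#         predict(json.load(f))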