# Example no. 1
0
def forward(net, input_data, net_config, phase='train', deploy=False):
    """Run one forward pass of the time-unrolled, stacked LSTM word model.

    The network is rebuilt layer by layer on ``net`` for ``max_len`` time
    steps.  At every step each LSTM stack consumes the output of the stack
    below (or the step's input word vector for the bottom stack) concatenated
    with its own hidden state from the previous step, and an inner-product
    layer on top of the last stack produces per-word scores.  All per-step
    scores and labels are then concatenated and fed to a softmax loss.

    Args:
        net: Network object exposing ``clear_forward()``, ``f(layer)`` and
            ``blobs``.
        input_data: Mapping with the keys ``"ws_i"`` (per-sample step from
            which predictions replace the given input), ``"wordvec_layer"``
            (input word vectors, indexed as ``[:, :, step, 0]``) and
            ``"target_words"`` (labels, sliced along dimension 1).
        net_config: Mapping with ``max_len``, ``batch_size``, ``init_range``,
            ``lstm_num_cells``, ``lstm_num_stacks``, ``dropout_ratio``,
            ``vocab_size``, ``whitespace_symbol`` and ``zero_symbol``.
        phase: When equal to ``'test'``, the arg-max prediction of a step is
            fed back (one-hot encoded) as the next step's input for samples
            that are past their ``ws_i`` step and have not yet predicted the
            whitespace symbol.  Any other value uses the provided vectors.
        deploy: When true, an extra ``"word_probs"`` softmax over the
            concatenated scores is added before the loss layer.

    Returns:
        None.  Results are left in the blobs of ``net``.
    """
    net.clear_forward()

    feedback_start = input_data["ws_i"]
    max_len = net_config['max_len']
    batch_size = net_config['batch_size']
    # Step at which prediction feedback stops for each sample; shortened
    # once a sample predicts the whitespace symbol.
    feedback_stop = [max_len] * batch_size
    word_vectors = input_data["wordvec_layer"]  # 128*38*100*1 (author's note)

    labels = np.array(input_data["target_words"])  # 128*100*1*1 (author's note)
    net.f(NumpyData("target_words", data=labels))

    # Split the labels into one blob per time step.
    net.f(
        Slice("label_slice_layer",
              slice_dim=1,
              bottoms=["target_words"],
              tops=['label%d' % t for t in range(max_len)],
              slice_point=list(range(1, max_len))))

    # Input of the first step (start symbol, 128*38 per the author's note).
    net.f(NumpyData("target_wordvec%d" % 0,
                    data=word_vectors[:, :, 0, 0]))

    filler = Filler("uniform", net_config["init_range"])
    for t in range(max_len):
        first_step = t == 0
        if first_step:
            # All-zero hidden state and memory cell used by every stack at
            # the first step.
            for state_name in ("dummy_layer", "dummy_mem_cell"):
                net.f(
                    NumpyData(
                        state_name,
                        np.zeros((net_config["batch_size"],
                                  net_config["lstm_num_cells"]))))

        for stack in range(net_config['lstm_num_stacks']):
            # The bottom stack reads the word vector, upper stacks read the
            # dropout output of the stack directly below.
            lower_input = ('target_wordvec%d' % t if stack == 0
                           else 'dropout%d_%d' % (stack - 1, t))
            if first_step:
                prev_hidden = "dummy_layer"
                prev_cell = 'dummy_mem_cell'
            else:
                prev_hidden = 'lstm%d_hidden%d' % (stack, t - 1)
                prev_cell = 'lstm%d_mem_cell%d' % (stack, t - 1)

            concat_name = 'concat%d_layer%d' % (stack, t)
            net.f(Concat(concat_name, bottoms=[lower_input, prev_hidden]))

            hidden_name = 'lstm%d_hidden%d' % (stack, t)
            # Parameter names depend only on the stack, so the same names
            # are passed at every time step.
            net.f(
                LstmUnit('lstm%d_layer%d' % (stack, t),
                         net_config["lstm_num_cells"],
                         weight_filler=filler,
                         param_names=['lstm%d_param_%d' % (stack, k)
                                      for k in range(4)],
                         bottoms=[concat_name, prev_cell],
                         tops=[hidden_name,
                               'lstm%d_mem_cell%d' % (stack, t)]))

            net.f(
                Dropout('dropout%d_%d' % (stack, t),
                        net_config["dropout_ratio"],
                        bottoms=[hidden_name]))

        scores_name = "ip%d" % t
        top_dropout = 'dropout%d_%d' % (net_config['lstm_num_stacks'] - 1, t)
        net.f(
            InnerProduct(scores_name,
                         net_config['vocab_size'],
                         bottoms=[top_dropout],
                         weight_filler=filler))

        # The last step has no successor, so no further input is prepared.
        if t >= max_len - 1:
            continue

        # Copy so the caller's word vectors are never modified in place.
        next_input = np.array(word_vectors[:, :, t + 1, 0])  # 128*38
        if phase == 'test':
            probs_name = "word_probs%d" % t
            net.f(Softmax(probs_name, bottoms=[scores_name]))
            probs = net.blobs[probs_name].data
            for sample in range(net_config['batch_size']):
                if t < feedback_start[sample] or t >= feedback_stop[sample]:
                    continue
                best = np.argmax(probs[sample, :])
                if best == net_config['whitespace_symbol']:
                    # This prediction is still fed in; feedback ends after it.
                    feedback_stop[sample] = t + 1
                one_hot = [0] * net_config["vocab_size"]
                one_hot[best] = 1
                next_input[sample, :] = one_hot
        net.f(NumpyData("target_wordvec%d" % (t + 1), data=next_input))

    net.f(Concat('ip_concat',
                 bottoms=["ip%d" % t for t in range(max_len)],
                 concat_dim=0))
    net.f(Concat('label_concat',
                 bottoms=['label%d' % t for t in range(max_len)],
                 concat_dim=0))

    if deploy:
        net.f(Softmax("word_probs", bottoms=["ip_concat"]))

    net.f(
        SoftmaxWithLoss("word_loss",
                        bottoms=["ip_concat", "label_concat"],
                        ignore_label=net_config['zero_symbol']))