def forward(net, input_data, net_config, phase='train', deploy=False):
    """Define and run a forward pass of the ReInspect character-LSTM network.

    Parameters
    ----------
    net : apollocaffe net; layers are instantiated and run via ``net.f(...)``.
    input_data : dict with keys ``"ws_i"`` (per-example start indices),
        ``"wordvec_layer"`` (word-vector tensor; comments below record the
        original author's shape note 128*38*100*1 — confirm against caller)
        and ``"target_words"`` (target symbol indices).
    net_config : dict of hyper-parameters: ``max_len``, ``batch_size``,
        ``vocab_size``, ``lstm_num_cells``, ``lstm_num_stacks``,
        ``dropout_ratio``, ``init_range``, ``whitespace_symbol``,
        ``zero_symbol``.
    phase : ``'train'`` or ``'test'``; in ``'test'`` the next input symbol is
        fed back from the network's own argmax prediction instead of the label.
    deploy : when True, additionally expose a ``"word_probs"`` softmax top.
    """
    net.clear_forward()
    # Hoist loop-invariant config lookups.
    max_len = net_config['max_len']
    batch_size = net_config['batch_size']
    num_stacks = net_config['lstm_num_stacks']
    num_cells = net_config["lstm_num_cells"]

    batch_ws_i = input_data["ws_i"]
    # First index at which each example is considered finished; initialised
    # one past the end so every position is active until whitespace is seen.
    batch_stop_i = [max_len] * batch_size
    wordvec_layer = input_data["wordvec_layer"]  # 128*38*100*1
    net.f(NumpyData("target_words",
                    data=np.array(input_data["target_words"])))  # 128*100*1*1

    # Slice the target tensor into one "label<i>" top per time step.
    tops = ['label%d' % i for i in range(max_len)]
    slice_point = list(range(1, max_len))
    net.f(Slice("label_slice_layer", slice_dim=1, bottoms=["target_words"],
                tops=tops, slice_point=slice_point))

    # Time step 0 is fed the start symbol.
    net.f(NumpyData("target_wordvec%d" % 0,
                    data=wordvec_layer[:, :, 0, 0]))  # start symbol, 128*38

    filler = Filler("uniform", net_config["init_range"])
    for i in range(max_len):
        if i == 0:
            # Zero stand-ins for the (non-existent) previous hidden state
            # and memory cell at the first time step.
            net.f(NumpyData("dummy_layer",
                            np.zeros((batch_size, num_cells))))
            net.f(NumpyData("dummy_mem_cell",
                            np.zeros((batch_size, num_cells))))
        for j in range(num_stacks):
            # Input to stack j: the word vector (bottom stack) or the
            # dropout output of the stack below, concatenated with this
            # stack's previous hidden state.
            bottoms = []
            if j == 0:
                bottoms.append('target_wordvec%d' % i)
            else:
                bottoms.append('dropout%d_%d' % (j - 1, i))
            if i == 0:
                bottoms.append("dummy_layer")
            else:
                bottoms.append('lstm%d_hidden%d' % (j, i - 1))
            net.f(Concat('concat%d_layer%d' % (j, i), bottoms=bottoms))

            # Fixed param names share the LSTM weights across time steps.
            param_names = ['lstm%d_param_%d' % (j, k) for k in range(4)]
            bottoms = ['concat%d_layer%d' % (j, i)]
            bottoms.append('dummy_mem_cell' if i == 0
                           else 'lstm%d_mem_cell%d' % (j, i - 1))
            net.f(LstmUnit('lstm%d_layer%d' % (j, i),
                           num_cells,
                           weight_filler=filler,
                           param_names=param_names,
                           bottoms=bottoms,
                           tops=['lstm%d_hidden%d' % (j, i),
                                 'lstm%d_mem_cell%d' % (j, i)]))
            net.f(Dropout('dropout%d_%d' % (j, i),
                          net_config["dropout_ratio"],
                          bottoms=['lstm%d_hidden%d' % (j, i)]))

        # Per-step classifier over the vocabulary, fed by the top stack.
        net.f(InnerProduct("ip%d" % i, net_config['vocab_size'],
                           bottoms=['dropout%d_%d' % (num_stacks - 1, i)],
                           weight_filler=filler))

        if i < max_len - 1:
            tar_wordvec = np.array(wordvec_layer[:, :, i + 1, 0])  # 128*38
            if phase == 'test':
                # Feed the network's own prediction back in: overwrite the
                # teacher-forced vector with a one-hot of the argmax for
                # every position that is still active.
                net.f(Softmax("word_probs%d" % i, bottoms=["ip%d" % i]))
                probs = net.blobs["word_probs%d" % i].data
                for bi in range(batch_size):
                    if batch_ws_i[bi] <= i < batch_stop_i[bi]:
                        vec = [0] * net_config["vocab_size"]
                        peak_index = np.argmax(probs[bi, :])
                        if peak_index == net_config['whitespace_symbol']:
                            # Whitespace predicted: stop updating this
                            # example after step i.
                            batch_stop_i[bi] = i + 1
                        vec[peak_index] = 1
                        tar_wordvec[bi, :] = vec
            net.f(NumpyData("target_wordvec%d" % (i + 1), data=tar_wordvec))

    # Concatenate the per-step predictions and labels along axis 0 so a
    # single softmax loss covers the entire sequence.
    net.f(Concat('ip_concat',
                 bottoms=["ip%d" % i for i in range(max_len)],
                 concat_dim=0))
    net.f(Concat('label_concat',
                 bottoms=['label%d' % i for i in range(max_len)],
                 concat_dim=0))
    if deploy:
        net.f(Softmax("word_probs", bottoms=["ip_concat"]))
    # Labels are concatenated unconditionally above, so the loss layer is
    # attached in every mode.
    net.f(SoftmaxWithLoss("word_loss",
                          bottoms=["ip_concat", "label_concat"],
                          ignore_label=net_config['zero_symbol']))