def train(args):
    from time import time

    data_train, data_val, _, src_vocab, targ_vocab, inv_src_vocab, inv_targ_vocab = get_data('TN')
    print "len(src_vocab) len(targ_vocab)", len(src_vocab), len(targ_vocab)

    # Weights shared between the decoder unroll helpers and the output projection.
    attention_fc_weight = mx.sym.Variable('attention_fc_weight')
    attention_fc_bias = mx.sym.Variable('attention_fc_bias')
    fc_weight = mx.sym.Variable('fc_weight')
    fc_bias = mx.sym.Variable('fc_bias')
    targ_em_weight = mx.sym.Variable('targ_embed_weight')

    # Encoder: a fused cuDNN LSTM, or an explicitly stacked (optionally bidirectional) LSTM.
    encoder = SequentialRNNCell()
    if args.use_cudnn_cells:
        encoder.add(
            mx.rnn.FusedRNNCell(args.num_hidden,
                                num_layers=args.num_layers,
                                dropout=args.dropout,
                                mode='lstm',
                                prefix='lstm_encoder',
                                bidirectional=args.bidirectional,
                                get_next_state=True))
    else:
        for i in range(args.num_layers):
            if args.bidirectional:
                encoder.add(
                    BidirectionalCell(
                        LSTMCell(args.num_hidden // 2, prefix='rnn_encoder_f%d_' % i),
                        LSTMCell(args.num_hidden // 2, prefix='rnn_encoder_b%d_' % i)))
                if i < args.num_layers - 1 and args.dropout > 0.0:
                    encoder.add(
                        mx.rnn.DropoutCell(args.dropout, prefix='rnn_encoder%d_' % i))
            else:
                encoder.add(LSTMCell(args.num_hidden, prefix='rnn_encoder%d_' % i))
                if i < args.num_layers - 1 and args.dropout > 0.0:
                    encoder.add(
                        mx.rnn.DropoutCell(args.dropout, prefix='rnn_encoder%d_' % i))

    # Decoder: stacked LSTM; attention and the output projection are applied inside
    # the decoder unroll helpers.
    decoder = mx.rnn.SequentialRNNCell()
    if args.use_cudnn_cells:
        decoder.add(
            mx.rnn.FusedRNNCell(args.num_hidden,
                                num_layers=args.num_layers,
                                mode='lstm',
                                prefix='lstm_decoder',
                                bidirectional=args.bidirectional,
                                get_next_state=True))
    else:
        for i in range(args.num_layers):
            decoder.add(LSTMCell(args.num_hidden, prefix='rnn_decoder%d_' % i))
            if i < args.num_layers - 1 and args.dropout > 0.0:
                decoder.add(
                    mx.rnn.DropoutCell(args.dropout, prefix='rnn_decoder%d_' % i))

    def sym_gen(seq_len):
        src_data = mx.sym.Variable('src_data')
        targ_data = mx.sym.Variable('targ_data')
        label = mx.sym.Variable('softmax_label')

        src_embed = mx.sym.Embedding(data=src_data,
                                     input_dim=len(src_vocab),
                                     output_dim=args.num_embed,
                                     name='src_embed')
        targ_embed = mx.sym.Embedding(data=targ_data,
                                      weight=targ_em_weight,
                                      input_dim=len(targ_vocab),
                                      output_dim=args.num_embed,
                                      name='targ_embed')

        encoder.reset()
        decoder.reset()

        enc_seq_len, dec_seq_len = seq_len
        layout = 'TNC'
        encoder_outputs, encoder_states = encoder.unroll(enc_seq_len,
                                                         inputs=src_embed,
                                                         layout=layout)

        if args.bidirectional:
            encoder_states = [
                mx.sym.concat(encoder_states[0][0], encoder_states[0][1]),
                mx.sym.concat(encoder_states[0][1], encoder_states[1][1])
            ]

        if args.remove_state_feed:
            encoder_states = None

        # This should be based on EOS or max seq len for inference, but here we unroll
        # to the target length.
        # TODO: fix <GO> symbol
        if args.inference_unrolling_for_training:
            outputs, _ = infer_decoder_unroll(decoder, encoder_outputs, targ_embed,
                                              targ_vocab, dec_seq_len, 0,
                                              fc_weight, fc_bias,
                                              attention_fc_weight, attention_fc_bias,
                                              targ_em_weight,
                                              begin_state=encoder_states,
                                              layout='TNC',
                                              merge_outputs=True)
        else:
            outputs, _ = train_decoder_unroll(decoder, encoder_outputs, targ_embed,
                                              targ_vocab, dec_seq_len, 0,
                                              fc_weight, fc_bias,
                                              attention_fc_weight, attention_fc_bias,
                                              targ_em_weight,
                                              begin_state=encoder_states,
                                              layout='TNC',
                                              merge_outputs=True)

        # Project decoder outputs to the target vocabulary and attach the softmax loss.
        rs = mx.sym.Reshape(outputs, shape=(-1, args.num_hidden), name='sym_gen_reshape1')
        fc = mx.sym.FullyConnected(data=rs,
                                   weight=fc_weight,
                                   bias=fc_bias,
                                   num_hidden=len(targ_vocab),
                                   name='sym_gen_fc')
        label_rs = mx.sym.Reshape(data=label, shape=(-1,), name='sym_gen_reshape2')
        pred = mx.sym.SoftmaxOutput(data=fc, label=label_rs, name='sym_gen_softmax')

        return pred, ('src_data', 'targ_data'), ('softmax_label',)

    # foo, _, _ = sym_gen((1, 1))
    # print(type(foo))
    # mx.viz.plot_network(symbol=foo).save('./seq2seq.dot')

    if args.gpus:
        contexts = [mx.gpu(int(i)) for i in args.gpus.split(',')]
    else:
        contexts = mx.cpu(0)

    model = mx.mod.BucketingModule(sym_gen=sym_gen,
                                   default_bucket_key=data_train.default_bucket_key,
                                   context=contexts)

    if args.load_epoch:
        _, arg_params, aux_params = mx.rnn.load_rnn_checkpoint(
            [encoder, decoder], args.model_prefix, args.load_epoch)
    else:
        arg_params = None
        aux_params = None

    opt_params = {'learning_rate': args.lr, 'wd': args.wd}
    if args.optimizer not in ['adadelta', 'adagrad', 'adam', 'rmsprop']:
        opt_params['momentum'] = args.mom
    opt_params['clip_gradient'] = args.max_grad_norm

    start = time()

    model.fit(train_data=data_train,
              eval_data=data_val,
              eval_metric=mx.metric.Perplexity(invalid_label),
              kvstore=args.kv_store,
              optimizer=args.optimizer,
              optimizer_params=opt_params,
              initializer=mx.init.Xavier(factor_type="in", magnitude=2.34),
              arg_params=arg_params,
              aux_params=aux_params,
              begin_epoch=args.load_epoch,
              num_epoch=args.num_epochs,
              batch_end_callback=mx.callback.Speedometer(batch_size=args.batch_size,
                                                         frequent=args.disp_batches,
                                                         auto_reset=True),
              epoch_end_callback=mx.rnn.do_rnn_checkpoint([encoder, decoder],
                                                          args.model_prefix, 1)
              if args.model_prefix else None)

    train_duration = time() - start
    time_per_epoch = train_duration / args.num_epochs
    print("\n\nTime per epoch: %.2f seconds\n\n" % time_per_epoch)
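# --- Hypothetical driver sketch, not part of the original script ------------------
# train() above only consumes an argparse-style namespace; this helper sketches the
# attributes the function actually reads (gpus, kv_store, num_layers, num_hidden,
# num_embed, bidirectional, use_cudnn_cells, dropout, optimizer, lr, wd, mom,
# max_grad_norm, batch_size, disp_batches, num_epochs, load_epoch, model_prefix,
# remove_state_feed, inference_unrolling_for_training). The flag spellings and
# default values below are assumptions for illustration, not the project's real CLI.
def _build_arg_parser():
    import argparse
    parser = argparse.ArgumentParser(
        description='seq2seq with attention (hypothetical CLI sketch)')
    parser.add_argument('--gpus', type=str, default='',
                        help="comma-separated GPU ids, e.g. '0,1'; empty runs on CPU")
    parser.add_argument('--kv-store', type=str, default='device')
    parser.add_argument('--num-layers', type=int, default=2)
    parser.add_argument('--num-hidden', type=int, default=512)
    parser.add_argument('--num-embed', type=int, default=256)
    parser.add_argument('--bidirectional', action='store_true')
    parser.add_argument('--use-cudnn-cells', action='store_true')
    parser.add_argument('--dropout', type=float, default=0.0)
    parser.add_argument('--optimizer', type=str, default='adam')
    parser.add_argument('--lr', type=float, default=0.001)
    parser.add_argument('--wd', type=float, default=0.0)
    parser.add_argument('--mom', type=float, default=0.9)
    parser.add_argument('--max-grad-norm', type=float, default=5.0)
    parser.add_argument('--batch-size', type=int, default=64)
    parser.add_argument('--disp-batches', type=int, default=50)
    parser.add_argument('--num-epochs', type=int, default=10)
    parser.add_argument('--load-epoch', type=int, default=0)
    parser.add_argument('--model-prefix', type=str, default=None)
    parser.add_argument('--remove-state-feed', action='store_true')
    parser.add_argument('--inference-unrolling-for-training', action='store_true')
    return parser
# Example usage (assumed entry point):
#     args = _build_arg_parser().parse_args()
#     train(args)
# -----------------------------------------------------------------------------------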
def infer(args):
    assert args.model_prefix, "Must specify path to load from"

    data_test, src_vocab, inv_src_vocab, targ_vocab, inv_targ_vocab = get_data('TN', infer=True)
    print "len(src_vocab) len(targ_vocab)", len(src_vocab), len(targ_vocab)

    attention_fc_weight = mx.sym.Variable('attention_fc_weight')
    attention_fc_bias = mx.sym.Variable('attention_fc_bias')
    fc_weight = mx.sym.Variable('fc_weight')
    fc_bias = mx.sym.Variable('fc_bias')
    targ_em_weight = mx.sym.Variable('targ_embed_weight')

    # Encoder: unfuse the cuDNN cell so the checkpointed weights map onto per-layer cells.
    if args.use_cudnn_cells:
        encoder = mx.rnn.FusedRNNCell(args.num_hidden,
                                      num_layers=args.num_layers,
                                      dropout=args.dropout,
                                      mode='lstm',
                                      prefix='lstm_encoder',
                                      bidirectional=args.bidirectional,
                                      get_next_state=True).unfuse()
    else:
        encoder = SequentialRNNCell()
        for i in range(args.num_layers):
            if args.bidirectional:
                encoder.add(
                    BidirectionalCell(
                        LSTMCell(args.num_hidden // 2, prefix='rnn_encoder_f%d_' % i),
                        LSTMCell(args.num_hidden // 2, prefix='rnn_encoder_b%d_' % i)))
                if i < args.num_layers - 1 and args.dropout > 0.0:
                    encoder.add(
                        mx.rnn.DropoutCell(args.dropout, prefix='rnn_encoder%d_' % i))
            else:
                encoder.add(LSTMCell(args.num_hidden, prefix='rnn_encoder%d_' % i))
                if i < args.num_layers - 1 and args.dropout > 0.0:
                    encoder.add(
                        mx.rnn.DropoutCell(args.dropout, prefix='rnn_encoder%d_' % i))

    if args.use_cudnn_cells:
        decoder = mx.rnn.FusedRNNCell(args.num_hidden,
                                      num_layers=args.num_layers,
                                      mode='lstm',
                                      prefix='lstm_decoder',
                                      bidirectional=args.bidirectional,
                                      get_next_state=True).unfuse()
    else:
        decoder = mx.rnn.SequentialRNNCell()
        for i in range(args.num_layers):
            decoder.add(LSTMCell(args.num_hidden, prefix='rnn_decoder%d_' % i))
            if i < args.num_layers - 1 and args.dropout > 0.0:
                decoder.add(
                    mx.rnn.DropoutCell(args.dropout, prefix='rnn_decoder%d_' % i))

    def sym_gen(seq_len):
        src_data = mx.sym.Variable('src_data')
        targ_data = mx.sym.Variable('targ_data')
        label = mx.sym.Variable('softmax_label')

        src_embed = mx.sym.Embedding(data=src_data,
                                     input_dim=len(src_vocab),
                                     output_dim=args.num_embed,
                                     name='src_embed')
        targ_embed = mx.sym.Embedding(data=targ_data,
                                      input_dim=len(targ_vocab),
                                      weight=targ_em_weight,
                                      output_dim=args.num_embed,
                                      name='targ_embed')

        encoder.reset()
        decoder.reset()

        enc_seq_len, dec_seq_len = seq_len
        layout = 'TNC'
        encoder_outputs, encoder_states = encoder.unroll(enc_seq_len,
                                                         inputs=src_embed,
                                                         layout=layout)

        if args.bidirectional:
            encoder_states = [
                mx.sym.concat(encoder_states[0][0], encoder_states[0][1]),
                mx.sym.concat(encoder_states[0][1], encoder_states[1][1])
            ]

        # This should be based on EOS or max seq len for inference, but here we unroll
        # to the target length.
        # TODO: fix <GO> symbol
        # outputs, _ = decoder.unroll(dec_seq_len, targ_embed, begin_state=states,
        #                             layout=layout, merge_outputs=True)
        outputs, _ = infer_decoder_unroll(decoder, encoder_outputs, targ_embed,
                                          targ_vocab, dec_seq_len, 0,
                                          fc_weight, fc_bias,
                                          attention_fc_weight, attention_fc_bias,
                                          targ_em_weight,
                                          begin_state=encoder_states,
                                          layout='TNC',
                                          merge_outputs=True)

        rs = mx.sym.Reshape(outputs, shape=(-1, args.num_hidden), name='sym_gen_reshape1')
        fc = mx.sym.FullyConnected(data=rs,
                                   weight=fc_weight,
                                   bias=fc_bias,
                                   num_hidden=len(targ_vocab),
                                   name='sym_gen_fc')
        label_rs = mx.sym.Reshape(data=label, shape=(-1,), name='sym_gen_reshape2')
        pred = mx.sym.SoftmaxOutput(data=fc, label=label_rs, name='sym_gen_softmax')

        return pred, ('src_data', 'targ_data'), ('softmax_label',)

    if args.gpus:
        contexts = [mx.gpu(int(i)) for i in args.gpus.split(',')]
    else:
        contexts = mx.cpu(0)

    model = mx.mod.BucketingModule(sym_gen=sym_gen,
                                   default_bucket_key=data_test.default_bucket_key,
                                   context=contexts)
    model.bind(data_test.provide_data, data_test.provide_label, for_training=False)

    if args.load_epoch:
        _, arg_params, aux_params = mx.rnn.load_rnn_checkpoint(
            [encoder, decoder], args.model_prefix, args.load_epoch)
        # print(arg_params)
        model.set_params(arg_params, aux_params)
    else:
        arg_params = None
        aux_params = None

    opt_params = {'learning_rate': args.lr, 'wd': args.wd}
    if args.optimizer not in ['adadelta', 'adagrad', 'adam', 'rmsprop']:
        opt_params['momentum'] = args.mom
    opt_params['clip_gradient'] = args.max_grad_norm

    start = time()

    # model.score(data_test, BleuScore(invalid_label),  # mx.metric.Perplexity(invalid_label),
    #             batch_end_callback=mx.callback.Speedometer(batch_size=args.batch_size,
    #                                                         frequent=1, auto_reset=True))

    examples = []
    bleu_acc = 0.0
    num_inst = 0

    try:
        data_test.reset()
        smoothing_fn = nltk.translate.bleu_score.SmoothingFunction().method3
        while True:
            data_batch = data_test.next()
            model.forward(data_batch, is_train=None)
            source = data_batch.data[0]
            preds = model.get_outputs()[0]
            labels = data_batch.label[0]

            # Greedy decoding: take the argmax over the vocabulary at each step.
            maxed = mx.ndarray.argmax(data=preds, axis=1)
            pred_nparr = maxed.asnumpy()
            src_nparr = source.asnumpy()
            label_nparr = labels.asnumpy().astype(np.int32)

            sent_len, batch_size = np.shape(label_nparr)
            pred_nparr = pred_nparr.reshape(sent_len, batch_size).astype(np.int32)

            for i in range(batch_size):
                src_lst = list(reversed(drop_sentinels(src_nparr[:, i].tolist())))
                exp_lst = drop_sentinels(label_nparr[:, i].tolist())
                act_lst = drop_sentinels(pred_nparr[:, i].tolist())
                expected = exp_lst
                actual = act_lst
                bleu = nltk.translate.bleu_score.sentence_bleu(
                    references=[expected],
                    hypothesis=actual,
                    weights=(0.25, 0.25, 0.25, 0.25),
                    smoothing_function=smoothing_fn)
                bleu_acc += bleu
                num_inst += 1
                examples.append((src_lst, exp_lst, act_lst, bleu))
    except StopIteration:
        pass

    bleu_acc /= num_inst

    # Find the top K best translations
    examples = sorted(examples, key=itemgetter(3), reverse=True)
    num_examples = 20

    print("\nSample translations:\n")
    for i in range(min(num_examples, len(examples))):
        src_lst, exp_lst, act_lst, bleu = examples[i]
        src_txt = array_to_text(src_lst, data_test.inv_src_vocab)
        exp_txt = array_to_text(exp_lst, data_test.inv_targ_vocab)
        act_txt = array_to_text(act_lst, data_test.inv_targ_vocab)
        print("\n")
        print("Source text: %s" % src_txt)
        print("Expected translation: %s" % exp_txt)
        print("Actual translation: %s" % act_txt)

    print("\nTest set BLEU score (averaged over all examples): %.3f\n" % bleu_acc)
def infer(args):
    assert args.model_prefix, "Must specify path to load from"

    data_train, data_val, src_vocab, targ_vocab = get_data('TN')
    print "len(src_vocab) len(targ_vocab)", len(src_vocab), len(targ_vocab)

    encoder = SequentialRNNCell()
    if args.use_cudnn_cells:
        encoder.add(
            mx.rnn.FusedRNNCell(args.num_hidden,
                                num_layers=args.num_layers,
                                dropout=args.dropout,
                                mode='lstm',
                                prefix='lstm_encoder_',
                                bidirectional=args.bidirectional,
                                get_next_state=True).unfuse())
    else:
        for i in range(args.num_layers):
            if args.bidirectional:
                encoder.add(
                    mx.rnn.BidirectionalCell(
                        LSTMCell(args.num_hidden, prefix='lstm_encoder_l%d_' % i),
                        LSTMCell(args.num_hidden, prefix='lstm_encoder_r%d_' % i),
                        output_prefix='lstm_encoder_bi_l%d_' % i))
            else:
                encoder.add(LSTMCell(args.num_hidden, prefix='lstm_encoder_l%d_' % i))
            if i < args.num_layers - 1 and args.dropout > 0.0:
                encoder.add(
                    mx.rnn.DropoutCell(args.dropout, prefix='lstm_encoder__dropout%d_' % i))
    encoder.add(AttentionEncoderCell())

    decoder = mx.rnn.SequentialRNNCell()
    if args.use_cudnn_cells:
        # Note: unfuse() must be called on the cell, not on the return value of add().
        decoder.add(
            mx.rnn.FusedRNNCell(args.num_hidden,
                                num_layers=args.num_layers,
                                mode='lstm',
                                prefix='lstm_decoder_',
                                bidirectional=False,
                                get_next_state=True).unfuse())
    else:
        for i in range(args.num_layers):
            decoder.add(LSTMCell(args.num_hidden, prefix='lstm_decoder_l%d_' % i))
            if i < args.num_layers - 1 and args.dropout > 0.0:
                decoder.add(
                    mx.rnn.DropoutCell(args.dropout, prefix='lstm_decoder_l%d_' % i))
    decoder.add(DotAttentionCell())

    def sym_gen(seq_len):
        src_data = mx.sym.Variable('src_data')
        targ_data = mx.sym.Variable('targ_data')
        label = mx.sym.Variable('softmax_label')

        src_embed = mx.sym.Embedding(data=src_data,
                                     input_dim=len(src_vocab),
                                     output_dim=args.num_embed,
                                     name='src_embed')
        targ_embed = mx.sym.Embedding(data=targ_data,
                                      input_dim=len(targ_vocab),
                                      output_dim=args.num_embed,
                                      name='targ_embed')

        encoder.reset()
        decoder.reset()

        enc_seq_len, dec_seq_len = seq_len
        layout = 'TNC'
        _, states = encoder.unroll(enc_seq_len, inputs=src_embed, layout=layout)

        # This should be based on EOS or max seq len for inference, but here we unroll
        # to the target length.
        # TODO: fix <GO> symbol
        # outputs, _ = decoder.unroll(dec_seq_len, targ_embed, begin_state=states,
        #                             layout=layout, merge_outputs=True)
        outputs, _ = decoder_unroll(decoder, targ_embed, targ_vocab, dec_seq_len, 0,
                                    begin_state=states,
                                    layout='TNC',
                                    merge_outputs=True)

        rs = mx.sym.Reshape(outputs, shape=(-1, args.num_hidden), name='sym_gen_reshape1')
        fc = mx.sym.FullyConnected(data=rs, num_hidden=len(targ_vocab), name='sym_gen_fc')
        label_rs = mx.sym.Reshape(data=label, shape=(-1,), name='sym_gen_reshape2')
        pred = mx.sym.SoftmaxOutput(data=fc, label=label_rs, name='sym_gen_softmax')

        return pred, ('src_data', 'targ_data'), ('softmax_label',)

    if args.gpus:
        contexts = [mx.gpu(int(i)) for i in args.gpus.split(',')]
    else:
        contexts = mx.cpu(0)

    model = mx.mod.BucketingModule(sym_gen=sym_gen,
                                   default_bucket_key=data_train.default_bucket_key,
                                   context=contexts)
    model.bind(data_val.provide_data, data_val.provide_label, for_training=False)

    if args.load_epoch:
        _, arg_params, aux_params = mx.rnn.load_rnn_checkpoint(
            decoder, args.model_prefix, args.load_epoch)
        model.set_params(arg_params, aux_params)
    else:
        arg_params = None
        aux_params = None

    opt_params = {'learning_rate': args.lr, 'wd': args.wd}
    if args.optimizer not in ['adadelta', 'adagrad', 'adam', 'rmsprop']:
        opt_params['momentum'] = args.mom
    opt_params['clip_gradient'] = args.max_grad_norm

    start = time()

    model.score(data_val,
                BleuScore(invalid_label),  # PPL(invalid_label),
                batch_end_callback=mx.callback.Speedometer(batch_size=args.batch_size,
                                                           frequent=5,
                                                           auto_reset=True))

    infer_duration = time() - start
    time_per_epoch = infer_duration / args.num_epochs
    print("\n\nTime per epoch: %.2f seconds\n\n" % time_per_epoch)