Example #1
def train(args):

    from time import time

    data_train, data_val, _, src_vocab, targ_vocab, inv_src_vocab, inv_targ_vocab = get_data(
        'TN')
    print("len(src_vocab) %d, len(targ_vocab) %d" % (len(src_vocab), len(targ_vocab)))

    attention_fc_weight = mx.sym.Variable('attention_fc_weight')
    attention_fc_bias = mx.sym.Variable('attention_fc_bias')

    fc_weight = mx.sym.Variable('fc_weight')
    fc_bias = mx.sym.Variable('fc_bias')
    targ_em_weight = mx.sym.Variable('targ_embed_weight')

    encoder = SequentialRNNCell()

    if args.use_cudnn_cells:
        encoder.add(
            mx.rnn.FusedRNNCell(args.num_hidden,
                                num_layers=args.num_layers,
                                dropout=args.dropout,
                                mode='lstm',
                                prefix='lstm_encoder',
                                bidirectional=args.bidirectional,
                                get_next_state=True))
    else:
        for i in range(args.num_layers):
            if args.bidirectional:
                encoder.add(
                    BidirectionalCell(
                        LSTMCell(args.num_hidden // 2,
                                 prefix='rnn_encoder_f%d_' % i),
                        LSTMCell(args.num_hidden // 2,
                                 prefix='rnn_encoder_b%d_' % i)))
                if i < args.num_layers - 1 and args.dropout > 0.0:
                    encoder.add(
                        mx.rnn.DropoutCell(args.dropout,
                                           prefix='rnn_encoder%d_' % i))
            else:
                encoder.add(
                    LSTMCell(args.num_hidden, prefix='rnn_encoder%d_' % i))
                if i < args.num_layers - 1 and args.dropout > 0.0:
                    encoder.add(
                        mx.rnn.DropoutCell(args.dropout,
                                           prefix='rnn_encoder%d_' % i))

    decoder = mx.rnn.SequentialRNNCell()

    if args.use_cudnn_cells:
        decoder.add(
            mx.rnn.FusedRNNCell(args.num_hidden,
                                num_layers=args.num_layers,
                                mode='lstm',
                                prefix='lstm_decoder',
                                bidirectional=args.bidirectional,
                                get_next_state=True))
    else:
        for i in range(args.num_layers):
            decoder.add(
                LSTMCell(args.num_hidden, prefix=('rnn_decoder%d_' % i)))
            if i < args.num_layers - 1 and args.dropout > 0.0:
                decoder.add(
                    mx.rnn.DropoutCell(args.dropout,
                                       prefix='rnn_decoder%d_' % i))

    def sym_gen(seq_len):
        src_data = mx.sym.Variable('src_data')
        targ_data = mx.sym.Variable('targ_data')
        label = mx.sym.Variable('softmax_label')

        src_embed = mx.sym.Embedding(data=src_data,
                                     input_dim=len(src_vocab),
                                     output_dim=args.num_embed,
                                     name='src_embed')
        targ_embed = mx.sym.Embedding(
            data=targ_data,
            weight=targ_em_weight,
            input_dim=len(targ_vocab),  # data=data
            output_dim=args.num_embed,
            name='targ_embed')

        encoder.reset()
        decoder.reset()

        enc_seq_len, dec_seq_len = seq_len

        layout = 'TNC'
        encoder_outputs, encoder_states = encoder.unroll(enc_seq_len,
                                                         inputs=src_embed,
                                                         layout=layout)

        if args.bidirectional:
            # Merge the forward and backward final states (h with h, c with c)
            # so they can seed the unidirectional decoder.
            encoder_states = [
                mx.sym.concat(encoder_states[0][0], encoder_states[1][0]),
                mx.sym.concat(encoder_states[0][1], encoder_states[1][1])
            ]

        if args.remove_state_feed:
            encoder_states = None

        # This should be based on EOS or max seq len for inference, but here we unroll to the target length
        # TODO: fix <GO> symbol
        if args.inference_unrolling_for_training:
            outputs, _ = infer_decoder_unroll(decoder,
                                              encoder_outputs,
                                              targ_embed,
                                              targ_vocab,
                                              dec_seq_len,
                                              0,
                                              fc_weight,
                                              fc_bias,
                                              attention_fc_weight,
                                              attention_fc_bias,
                                              targ_em_weight,
                                              begin_state=encoder_states,
                                              layout='TNC',
                                              merge_outputs=True)
        else:
            outputs, _ = train_decoder_unroll(decoder,
                                              encoder_outputs,
                                              targ_embed,
                                              targ_vocab,
                                              dec_seq_len,
                                              0,
                                              fc_weight,
                                              fc_bias,
                                              attention_fc_weight,
                                              attention_fc_bias,
                                              targ_em_weight,
                                              begin_state=encoder_states,
                                              layout='TNC',
                                              merge_outputs=True)

        # NEW
        rs = mx.sym.Reshape(outputs,
                            shape=(-1, args.num_hidden),
                            name='sym_gen_reshape1')
        fc = mx.sym.FullyConnected(data=rs,
                                   weight=fc_weight,
                                   bias=fc_bias,
                                   num_hidden=len(targ_vocab),
                                   name='sym_gen_fc')
        label_rs = mx.sym.Reshape(data=label,
                                  shape=(-1, ),
                                  name='sym_gen_reshape2')
        pred = mx.sym.SoftmaxOutput(data=fc,
                                    label=label_rs,
                                    name='sym_gen_softmax')

        return pred, (
            'src_data',
            'targ_data',
        ), ('softmax_label', )


#    foo, _, _ = sym_gen((1, 1))
#    print(type(foo))
#    mx.viz.plot_network(symbol=foo).save('./seq2seq.dot')

    if args.gpus:
        contexts = [mx.gpu(int(i)) for i in args.gpus.split(',')]
    else:
        contexts = mx.cpu(0)

    model = mx.mod.BucketingModule(
        sym_gen=sym_gen,
        default_bucket_key=data_train.default_bucket_key,
        context=contexts)

    if args.load_epoch:
        _, arg_params, aux_params = mx.rnn.load_rnn_checkpoint(
            [encoder, decoder], args.model_prefix, args.load_epoch)
    else:
        arg_params = None
        aux_params = None

    opt_params = {'learning_rate': args.lr, 'wd': args.wd}

    if args.optimizer not in ['adadelta', 'adagrad', 'adam', 'rmsprop']:
        opt_params['momentum'] = args.mom

    opt_params['clip_gradient'] = args.max_grad_norm

    start = time()

    model.fit(train_data=data_train,
              eval_data=data_val,
              eval_metric=mx.metric.Perplexity(invalid_label),
              kvstore=args.kv_store,
              optimizer=args.optimizer,
              optimizer_params=opt_params,
              initializer=mx.init.Xavier(factor_type="in", magnitude=2.34),
              arg_params=arg_params,
              aux_params=aux_params,
              begin_epoch=args.load_epoch,
              num_epoch=args.num_epochs,
              batch_end_callback=mx.callback.Speedometer(
                  batch_size=args.batch_size,
                  frequent=args.disp_batches,
                  auto_reset=True),
              epoch_end_callback=mx.rnn.do_rnn_checkpoint([encoder, decoder],
                                                          args.model_prefix, 1)
              if args.model_prefix else None)

    train_duration = time() - start
    time_per_epoch = train_duration / args.num_epochs
    print("\n\nTime per epoch: %.2f seconds\n\n" % time_per_epoch)
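The sym_gen closure above always ends with the same head: unroll the decoder in 'TNC' layout, reshape the merged (T, N, C) outputs to (T*N, C), project to the target vocabulary with FullyConnected, and attach a SoftmaxOutput against the flattened labels. The following self-contained sketch (not part of the original project; it only assumes the MXNet 1.x symbol and mx.rnn APIs, with made-up sizes) reproduces that head for a single LSTM cell and checks the resulting shapes with infer_shape:

import mxnet as mx

def check_decoder_head_shapes(vocab_size=50, seq_len=7, batch_size=4,
                              num_embed=16, num_hidden=32):
    data = mx.sym.Variable('targ_data')
    label = mx.sym.Variable('softmax_label')

    # (T, N) token ids -> (T, N, num_embed)
    embed = mx.sym.Embedding(data=data, input_dim=vocab_size,
                             output_dim=num_embed, name='targ_embed')

    cell = mx.rnn.LSTMCell(num_hidden, prefix='rnn_decoder0_')
    cell.reset()
    # Same layout and merging as in sym_gen: outputs come back as (T, N, num_hidden)
    outputs, _ = cell.unroll(seq_len, inputs=embed, layout='TNC',
                             merge_outputs=True)

    rs = mx.sym.Reshape(outputs, shape=(-1, num_hidden))
    fc = mx.sym.FullyConnected(data=rs, num_hidden=vocab_size)
    label_rs = mx.sym.Reshape(data=label, shape=(-1,))
    pred = mx.sym.SoftmaxOutput(data=fc, label=label_rs)

    # Expect a single output of shape (seq_len * batch_size, vocab_size)
    _, out_shapes, _ = pred.infer_shape(targ_data=(seq_len, batch_size),
                                        softmax_label=(seq_len, batch_size))
    print(out_shapes)

check_decoder_head_shapes()

This also illustrates why the BucketingModule is keyed on (enc_seq_len, dec_seq_len): each bucket key yields a symbol unrolled to a different length, while the Reshape/FullyConnected head keeps the parameter shapes identical across buckets.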
Example #2
def infer(args):
    assert args.model_prefix, "Must specify path to load from"

    data_test, src_vocab, inv_src_vocab, targ_vocab, inv_targ_vocab = get_data(
        'TN', infer=True)

    print("len(src_vocab) %d, len(targ_vocab) %d" % (len(src_vocab), len(targ_vocab)))

    attention_fc_weight = mx.sym.Variable('attention_fc_weight')
    attention_fc_bias = mx.sym.Variable('attention_fc_bias')

    fc_weight = mx.sym.Variable('fc_weight')
    fc_bias = mx.sym.Variable('fc_bias')
    targ_em_weight = mx.sym.Variable('targ_embed_weight')

    if args.use_cudnn_cells:
        encoder = mx.rnn.FusedRNNCell(args.num_hidden,
                                      num_layers=args.num_layers,
                                      dropout=args.dropout,
                                      mode='lstm',
                                      prefix='lstm_encoder',
                                      bidirectional=args.bidirectional,
                                      get_next_state=True).unfuse()

    else:
        encoder = SequentialRNNCell()

        for i in range(args.num_layers):
            if args.bidirectional:
                encoder.add(
                    BidirectionalCell(
                        LSTMCell(args.num_hidden // 2,
                                 prefix='rnn_encoder_f%d_' % i),
                        LSTMCell(args.num_hidden // 2,
                                 prefix='rnn_encoder_b%d_' % i)))
                if i < args.num_layers - 1 and args.dropout > 0.0:
                    encoder.add(
                        mx.rnn.DropoutCell(args.dropout,
                                           prefix='rnn_encoder%d_' % i))
            else:
                encoder.add(
                    LSTMCell(args.num_hidden, prefix='rnn_encoder%d_' % i))
                if i < args.num_layers - 1 and args.dropout > 0.0:
                    encoder.add(
                        mx.rnn.DropoutCell(args.dropout,
                                           prefix='rnn_encoder%d_' % i))

    if args.use_cudnn_cells:
        decoder = mx.rnn.FusedRNNCell(args.num_hidden,
                                      num_layers=args.num_layers,
                                      mode='lstm',
                                      prefix='lstm_decoder',
                                      bidirectional=args.bidirectional,
                                      get_next_state=True).unfuse()

    else:
        decoder = mx.rnn.SequentialRNNCell()

        for i in range(args.num_layers):
            decoder.add(
                LSTMCell(args.num_hidden, prefix=('rnn_decoder%d_' % i)))
            if i < args.num_layers - 1 and args.dropout > 0.0:
                decoder.add(
                    mx.rnn.DropoutCell(args.dropout,
                                       prefix='rnn_decoder%d_' % i))

    def sym_gen(seq_len):
        src_data = mx.sym.Variable('src_data')
        targ_data = mx.sym.Variable('targ_data')
        label = mx.sym.Variable('softmax_label')

        src_embed = mx.sym.Embedding(data=src_data,
                                     input_dim=len(src_vocab),
                                     output_dim=args.num_embed,
                                     name='src_embed')
        targ_embed = mx.sym.Embedding(
            data=targ_data,
            input_dim=len(targ_vocab),
            weight=targ_em_weight,  # data=data
            output_dim=args.num_embed,
            name='targ_embed')

        encoder.reset()
        decoder.reset()

        enc_seq_len, dec_seq_len = seq_len

        layout = 'TNC'
        encoder_outputs, encoder_states = encoder.unroll(enc_seq_len,
                                                         inputs=src_embed,
                                                         layout=layout)

        if args.bidirectional:
            # Merge the forward and backward final states (h with h, c with c)
            # so they can seed the unidirectional decoder.
            encoder_states = [
                mx.sym.concat(encoder_states[0][0], encoder_states[1][0]),
                mx.sym.concat(encoder_states[0][1], encoder_states[1][1])
            ]

        # This should be based on EOS or max seq len for inference, but here we unroll to the target length
        # TODO: fix <GO> symbol


#        outputs, _ = decoder.unroll(dec_seq_len, targ_embed, begin_state=states, layout=layout, merge_outputs=True)
        outputs, _ = infer_decoder_unroll(decoder,
                                          encoder_outputs,
                                          targ_embed,
                                          targ_vocab,
                                          dec_seq_len,
                                          0,
                                          fc_weight,
                                          fc_bias,
                                          attention_fc_weight,
                                          attention_fc_bias,
                                          targ_em_weight,
                                          begin_state=encoder_states,
                                          layout='TNC',
                                          merge_outputs=True)

        # NEW

        rs = mx.sym.Reshape(outputs,
                            shape=(-1, args.num_hidden),
                            name='sym_gen_reshape1')
        fc = mx.sym.FullyConnected(data=rs,
                                   weight=fc_weight,
                                   bias=fc_bias,
                                   num_hidden=len(targ_vocab),
                                   name='sym_gen_fc')
        label_rs = mx.sym.Reshape(data=label,
                                  shape=(-1, ),
                                  name='sym_gen_reshape2')
        pred = mx.sym.SoftmaxOutput(data=fc,
                                    label=label_rs,
                                    name='sym_gen_softmax')

        #        rs = mx.sym.Reshape(outputs, shape=(-1, args.num_hidden), name='sym_gen_reshape1')
        #        fc = mx.sym.FullyConnected(data=rs, num_hidden=len(targ_vocab), name='sym_gen_fc')
        #        label_rs = mx.sym.Reshape(data=label, shape=(-1,), name='sym_gen_reshape2')
        #        pred = mx.sym.SoftmaxOutput(data=fc, label=label_rs, name='sym_gen_softmax')

        return pred, (
            'src_data',
            'targ_data',
        ), ('softmax_label', )

    if args.gpus:
        contexts = [mx.gpu(int(i)) for i in args.gpus.split(',')]
    else:
        contexts = mx.cpu(0)

    model = mx.mod.BucketingModule(
        sym_gen=sym_gen,
        default_bucket_key=data_test.default_bucket_key,
        context=contexts)

    model.bind(data_test.provide_data,
               data_test.provide_label,
               for_training=False)

    if args.load_epoch:
        _, arg_params, aux_params = mx.rnn.load_rnn_checkpoint(
            [encoder, decoder], args.model_prefix, args.load_epoch)
        #        print(arg_params)
        model.set_params(arg_params, aux_params)

    else:
        arg_params = None
        aux_params = None

    opt_params = {'learning_rate': args.lr, 'wd': args.wd}

    if args.optimizer not in ['adadelta', 'adagrad', 'adam', 'rmsprop']:
        opt_params['momentum'] = args.mom

    opt_params['clip_gradient'] = args.max_grad_norm

    start = time()

    # mx.metric.Perplexity
    #    model.score(data_test, BleuScore(invalid_label), #mx.metric.Perplexity(invalid_label),
    #                batch_end_callback=mx.callback.Speedometer(batch_size=args.batch_size, frequent=1, auto_reset=True))

    examples = []
    bleu_acc = 0.0
    num_inst = 0

    try:
        data_test.reset()

        smoothing_fn = nltk.translate.bleu_score.SmoothingFunction().method3

        while True:

            data_batch = data_test.next()
            model.forward(data_batch, is_train=None)
            source = data_batch.data[0]
            preds = model.get_outputs()[0]
            labels = data_batch.label[0]

            maxed = mx.ndarray.argmax(data=preds, axis=1)
            pred_nparr = maxed.asnumpy()
            src_nparr = source.asnumpy()
            label_nparr = labels.asnumpy().astype(np.int32)
            sent_len, batch_size = np.shape(label_nparr)
            pred_nparr = pred_nparr.reshape(sent_len,
                                            batch_size).astype(np.int32)

            for i in range(batch_size):

                src_lst = list(
                    reversed(drop_sentinels(src_nparr[:, i].tolist())))
                exp_lst = drop_sentinels(label_nparr[:, i].tolist())
                act_lst = drop_sentinels(pred_nparr[:, i].tolist())

                expected = exp_lst
                actual = act_lst
                bleu = nltk.translate.bleu_score.sentence_bleu(
                    references=[expected],
                    hypothesis=actual,
                    weights=(0.25, 0.25, 0.25, 0.25),
                    smoothing_function=smoothing_fn)
                bleu_acc += bleu
                num_inst += 1
                examples.append((src_lst, exp_lst, act_lst, bleu))

    except StopIteration:
        # Reached the end of the test iterator.
        pass

    bleu_acc /= num_inst

    # Find the top K best translations
    examples = sorted(examples, key=itemgetter(3), reverse=True)

    num_examples = 20

    print("\nSample translations:\n")
    for i in range(min(num_examples, len(examples))):
        src_lst, exp_lst, act_lst, bleu = examples[i]
        src_txt = array_to_text(src_lst, data_test.inv_src_vocab)
        exp_txt = array_to_text(exp_lst, data_test.inv_targ_vocab)
        act_txt = array_to_text(act_lst, data_test.inv_targ_vocab)
        print("\n")
        print("Source text: %s" % src_txt)
        print("Expected translation: %s" % exp_txt)
        print("Actual translation: %s" % act_txt)
    print("\nTest set BLEU score (averaged over all examples): %.3f\n" %
          bleu_acc)
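The scoring loop above leans on two helpers, drop_sentinels and array_to_text, that are defined elsewhere in the project and not shown in this listing. The snippets below are hypothetical minimal sketches of what they might look like, assuming the padding / <GO> / <EOS> ids are known sentinel values and that inv_src_vocab / inv_targ_vocab are plain dicts mapping ids back to token strings; the real implementations may differ.

# Assumed sentinel ids; the real project derives these from its vocabulary.
PAD_ID, GO_ID, EOS_ID = 0, 1, 2

def drop_sentinels(ids, sentinels=(PAD_ID, GO_ID, EOS_ID)):
    """Strip padding and <GO>/<EOS> markers from a list of token ids."""
    return [i for i in ids if i not in sentinels]

def array_to_text(ids, inv_vocab, unk='<UNK>'):
    """Map a list of token ids back to a whitespace-joined sentence."""
    return ' '.join(inv_vocab.get(i, unk) for i in ids)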
Example #3
def infer(args):
    assert args.model_prefix, "Must specify path to load from"

    data_train, data_val, src_vocab, targ_vocab = get_data('TN')

    print("len(src_vocab) %d, len(targ_vocab) %d" % (len(src_vocab), len(targ_vocab)))

    encoder = SequentialRNNCell()
    if args.use_cudnn_cells:
        encoder.add(
            mx.rnn.FusedRNNCell(args.num_hidden,
                                num_layers=args.num_layers,
                                dropout=args.dropout,
                                mode='lstm',
                                prefix='lstm_encoder_',
                                bidirectional=args.bidirectional,
                                get_next_state=True).unfuse())

    else:
        for i in range(args.num_layers):
            if args.bidirectional:
                encoder.add(
                    mx.rnn.BidirectionalCell(
                        LSTMCell(args.num_hidden,
                                 prefix='lstm_encoder_l%d_' % i),
                        LSTMCell(args.num_hidden,
                                 prefix='lstm_encoder_r%d_' % i),
                        output_prefix='lstm_encoder_bi_l%d_' % i))
            else:
                encoder.add(
                    LSTMCell(args.num_hidden, prefix='lstm_encoder_l%d_' % i))
            if i < args.num_layers - 1 and args.dropout > 0.0:
                encoder.add(
                    mx.rnn.DropoutCell(args.dropout,
                                       prefix='lstm_encoder__dropout%d_' % i))

    encoder.add(AttentionEncoderCell())

    decoder = mx.rnn.SequentialRNNCell()

    if args.use_cudnn_cells:
        decoder.add(
            mx.rnn.FusedRNNCell(args.num_hidden,
                                num_layers=args.num_layers,
                                mode='lstm',
                                prefix='lstm_decoder_',
                                bidirectional=False,
                                get_next_state=True).unfuse())
    else:
        for i in range(args.num_layers):
            decoder.add(
                LSTMCell(args.num_hidden, prefix=('lstm_decoder_l%d_' % i)))
            if i < args.num_layers - 1 and args.dropout > 0.0:
                decoder.add(
                    mx.rnn.DropoutCell(args.dropout,
                                       prefix='lstm_decoder_l%d_' % i))
    decoder.add(DotAttentionCell())

    def sym_gen(seq_len):
        src_data = mx.sym.Variable('src_data')
        targ_data = mx.sym.Variable('targ_data')
        label = mx.sym.Variable('softmax_label')

        src_embed = mx.sym.Embedding(data=src_data,
                                     input_dim=len(src_vocab),
                                     output_dim=args.num_embed,
                                     name='src_embed')
        targ_embed = mx.sym.Embedding(
            data=targ_data,
            input_dim=len(targ_vocab),  # data=data
            output_dim=args.num_embed,
            name='targ_embed')

        encoder.reset()
        decoder.reset()

        enc_seq_len, dec_seq_len = seq_len

        layout = 'TNC'
        _, states = encoder.unroll(enc_seq_len,
                                   inputs=src_embed,
                                   layout=layout)

        # This should be based on EOS or max seq len for inference, but here we unroll to the target length
        # TODO: fix <GO> symbol
        #        outputs, _ = decoder.unroll(dec_seq_len, targ_embed, begin_state=states, layout=layout, merge_outputs=True)
        outputs, _ = decoder_unroll(decoder,
                                    targ_embed,
                                    targ_vocab,
                                    dec_seq_len,
                                    0,
                                    begin_state=states,
                                    layout='TNC',
                                    merge_outputs=True)

        # NEW
        rs = mx.sym.Reshape(outputs,
                            shape=(-1, args.num_hidden),
                            name='sym_gen_reshape1')
        fc = mx.sym.FullyConnected(data=rs,
                                   num_hidden=len(targ_vocab),
                                   name='sym_gen_fc')
        label_rs = mx.sym.Reshape(data=label,
                                  shape=(-1, ),
                                  name='sym_gen_reshape2')
        pred = mx.sym.SoftmaxOutput(data=fc,
                                    label=label_rs,
                                    name='sym_gen_softmax')

        return pred, (
            'src_data',
            'targ_data',
        ), ('softmax_label', )

    if args.gpus:
        contexts = [mx.gpu(int(i)) for i in args.gpus.split(',')]
    else:
        contexts = mx.cpu(0)

    model = mx.mod.BucketingModule(
        sym_gen=sym_gen,
        default_bucket_key=data_train.default_bucket_key,
        context=contexts)

    model.bind(data_val.provide_data,
               data_val.provide_label,
               for_training=False)

    if args.load_epoch:
        _, arg_params, aux_params = mx.rnn.load_rnn_checkpoint(
            decoder, args.model_prefix, args.load_epoch)
        model.set_params(arg_params, aux_params)

    else:
        arg_params = None
        aux_params = None

    opt_params = {'learning_rate': args.lr, 'wd': args.wd}

    if args.optimizer not in ['adadelta', 'adagrad', 'adam', 'rmsprop']:
        opt_params['momentum'] = args.mom

    opt_params['clip_gradient'] = args.max_grad_norm

    start = time()

    # mx.metric.Perplexity
    model.score(
        data_val,
        BleuScore(invalid_label),  #PPL(invalid_label),
        batch_end_callback=mx.callback.Speedometer(batch_size=args.batch_size,
                                                   frequent=5,
                                                   auto_reset=True))

    infer_duration = time() - start
    print("\n\nInference time: %.2f seconds\n\n" % infer_duration)
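model.score above expects BleuScore to be an mx.metric.EvalMetric, which is also not shown in the listing. The class below is a hypothetical sketch of how such a metric could be written, assuming labels arrive as (T, N) token-id arrays in 'TN' layout and predictions as (T*N, vocab_size) softmax outputs, i.e. the shapes produced by sym_gen above; the project's actual BleuScore may differ.

import mxnet as mx
import numpy as np
import nltk

class BleuScore(mx.metric.EvalMetric):
    def __init__(self, invalid_label):
        super(BleuScore, self).__init__('bleu')
        self.invalid_label = invalid_label
        self.smoothing = nltk.translate.bleu_score.SmoothingFunction().method3

    def update(self, labels, preds):
        for label, pred in zip(labels, preds):
            label_ids = label.asnumpy().astype(np.int32)            # (T, N)
            sent_len, batch_size = label_ids.shape
            pred_ids = mx.ndarray.argmax(pred, axis=1).asnumpy()    # (T*N,)
            pred_ids = pred_ids.reshape(sent_len, batch_size).astype(np.int32)
            for i in range(batch_size):
                # Drop padding before scoring each sentence in the batch.
                ref = [t for t in label_ids[:, i].tolist() if t != self.invalid_label]
                hyp = [t for t in pred_ids[:, i].tolist() if t != self.invalid_label]
                self.sum_metric += nltk.translate.bleu_score.sentence_bleu(
                    references=[ref], hypothesis=hyp,
                    smoothing_function=self.smoothing)
                self.num_inst += 1

Averaging per-sentence BLEU this way matches what the explicit loop in Example #2 computes; a corpus-level BLEU (nltk.translate.bleu_score.corpus_bleu) would require the metric to buffer all references and hypotheses instead.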