# Example #1
    def sym_gen(seq_len):
        """Construct the NCE language-model symbol for a bucket of `seq_len` steps.

        Returns the prediction symbol plus the data and label variable names
        expected by the bucketing module.
        """
        # token ids; assumed shape [batch_size, seq_len] -- TODO confirm
        data = mx.sym.Variable('data')

        # project token ids into dense embedding vectors
        token_embed = mx.sym.Embedding(data=data,
                                       input_dim=len(vocab),
                                       output_dim=args.num_embed,
                                       name='input_embed')

        # run the recurrent cell over the full sequence; with merge_outputs
        # the result is one symbol, [batch_size, seq_len, num_hidden] (NTC)
        hidden_seq, _ = cell.unroll(seq_len,
                                    inputs=token_embed,
                                    layout='NTC',
                                    merge_outputs=True)

        # target ids and their per-position weights for the NCE objective
        label = mx.sym.Variable('label')
        label_wgt = mx.sym.Variable('label_weight')

        # output embedding matrix, declared explicitly so its shape is fixed
        # TODO: change to adapter binding
        out_embed_wgt = mx.sym.Variable(name='output_embed_weight',
                                        shape=(len(vocab), args.num_hidden))

        pred = nce_loss(hidden_seq, label, label_wgt, out_embed_wgt, len(vocab),
                        args.num_hidden, args.num_label, seq_len)

        return pred, ('data', ), ('label', 'label_weight')
# Example #2
def get_lstm_net(vocab_size, seq_len, num_lstm_layer, num_hidden):
    """Build an explicitly-unrolled multi-layer LSTM language model with NCE loss.

    One NCE loss symbol is produced per time step; the function returns them
    grouped into a single ``mx.sym.Group``.
    """
    layer_params = []
    states = []
    for layer in range(num_lstm_layer):
        # per-layer i2h/h2h parameters, shared across all time steps
        layer_params.append(
            LSTMParam(i2h_weight=mx.sym.Variable("l%d_i2h_weight" % layer),
                      i2h_bias=mx.sym.Variable("l%d_i2h_bias" % layer),
                      h2h_weight=mx.sym.Variable("l%d_h2h_weight" % layer),
                      h2h_bias=mx.sym.Variable("l%d_h2h_bias" % layer)))
        # initial (c, h) state for this layer
        states.append(LSTMState(c=mx.sym.Variable("l%d_init_c" % layer),
                                h=mx.sym.Variable("l%d_init_h" % layer)))

    data = mx.sym.Variable('data')
    label = mx.sym.Variable('label')
    label_weight = mx.sym.Variable('label_weight')
    embed_weight = mx.sym.Variable('embed_weight')
    label_embed_weight = mx.sym.Variable('label_embed_weight')

    # embed the input tokens, then slice inputs/labels/weights per time step
    embedded = mx.sym.Embedding(data=data,
                                input_dim=vocab_size,
                                weight=embed_weight,
                                output_dim=100,
                                name='data_embed')
    step_inputs = mx.sym.SliceChannel(data=embedded,
                                      num_outputs=seq_len,
                                      squeeze_axis=True,
                                      name='data_slice')
    step_labels = mx.sym.SliceChannel(data=label,
                                      num_outputs=seq_len,
                                      squeeze_axis=True,
                                      name='label_slice')
    step_label_weights = mx.sym.SliceChannel(data=label_weight,
                                             num_outputs=seq_len,
                                             squeeze_axis=True,
                                             name='label_weight_slice')

    losses = []
    for t in range(seq_len):
        hidden = step_inputs[t]
        # stack the LSTM layers, threading each layer's state forward in time
        for layer in range(num_lstm_layer):
            new_state = _lstm(num_hidden,
                              indata=hidden,
                              prev_state=states[layer],
                              param=layer_params[layer],
                              seqidx=t,
                              layeridx=layer)
            hidden = new_state.h
            states[layer] = new_state

        # NCE loss on this step's top-layer hidden state
        losses.append(
            nce_loss(data=hidden,
                     label=step_labels[t],
                     label_weight=step_label_weights[t],
                     embed_weight=label_embed_weight,
                     vocab_size=vocab_size,
                     num_hidden=100))
    return mx.sym.Group(losses)
# Example #3
def get_net(num_vocab):
    """Minimal NCE toy network: a single 100-unit FC layer over the raw input."""
    hidden = mx.sym.FullyConnected(data=mx.sym.Variable('data'), num_hidden=100)
    # NCE loss reads the label, its weights, and the output embedding matrix
    return nce_loss(data=hidden,
                    label=mx.sym.Variable('label'),
                    label_weight=mx.sym.Variable('label_weight'),
                    embed_weight=mx.sym.Variable('embed_weight'),
                    vocab_size=num_vocab,
                    num_hidden=100)
# Example #4
def get_net(num_vocab):
    """Toy NCE network: fully-connected projection followed by NCE loss."""
    in_sym = mx.sym.Variable('data')
    lbl_sym = mx.sym.Variable('label')
    lbl_wgt_sym = mx.sym.Variable('label_weight')
    emb_wgt_sym = mx.sym.Variable('embed_weight')
    # 100-unit hidden projection of the input
    hidden = mx.sym.FullyConnected(data=in_sym, num_hidden=100)
    loss = nce_loss(data=hidden,
                    label=lbl_sym,
                    label_weight=lbl_wgt_sym,
                    embed_weight=emb_wgt_sym,
                    vocab_size=num_vocab,
                    num_hidden=100)
    return loss
# Example #5
def get_lstm_net(vocab_size, seq_len, num_lstm_layer, num_hidden):
    """Unrolled stacked-LSTM language model returning per-step NCE losses.

    The per-layer weights and initial states are declared up front, the input
    sequence is embedded and sliced per time step, and an NCE loss symbol is
    emitted for every step; all losses are grouped into one symbol.
    """
    cells = []
    init_states = []
    for idx in range(num_lstm_layer):
        # weight/bias variables for layer `idx`, reused at every time step
        cells.append(LSTMParam(i2h_weight=mx.sym.Variable("l%d_i2h_weight" % idx),
                               i2h_bias=mx.sym.Variable("l%d_i2h_bias" % idx),
                               h2h_weight=mx.sym.Variable("l%d_h2h_weight" % idx),
                               h2h_bias=mx.sym.Variable("l%d_h2h_bias" % idx)))
        init_states.append(LSTMState(c=mx.sym.Variable("l%d_init_c" % idx),
                                     h=mx.sym.Variable("l%d_init_h" % idx)))

    data = mx.sym.Variable('data')
    label = mx.sym.Variable('label')
    label_weight = mx.sym.Variable('label_weight')
    embed_weight = mx.sym.Variable('embed_weight')
    label_embed_weight = mx.sym.Variable('label_embed_weight')

    token_vectors = mx.sym.Embedding(data=data, input_dim=vocab_size,
                                     weight=embed_weight,
                                     output_dim=100, name='data_embed')
    # per-time-step views of the inputs, labels and label weights
    inputs_by_step = mx.sym.SliceChannel(data=token_vectors,
                                         num_outputs=seq_len,
                                         squeeze_axis=True, name='data_slice')
    labels_by_step = mx.sym.SliceChannel(data=label,
                                         num_outputs=seq_len,
                                         squeeze_axis=True, name='label_slice')
    weights_by_step = mx.sym.SliceChannel(data=label_weight,
                                          num_outputs=seq_len,
                                          squeeze_axis=True, name='label_weight_slice')

    step_losses = []
    for step in range(seq_len):
        h = inputs_by_step[step]
        for idx in range(num_lstm_layer):
            # advance layer `idx` by one step; its new state feeds the next step
            nxt = _lstm(num_hidden, indata=h,
                        prev_state=init_states[idx],
                        param=cells[idx],
                        seqidx=step, layeridx=idx)
            h = nxt.h
            init_states[idx] = nxt

        step_losses.append(nce_loss(data=h,
                                    label=labels_by_step[step],
                                    label_weight=weights_by_step[step],
                                    embed_weight=label_embed_weight,
                                    vocab_size=vocab_size,
                                    num_hidden=100))
    return mx.sym.Group(step_losses)
# Example #6
def get_word_net(vocab_size, num_input):
    """CBOW-style word net: sum `num_input` context embeddings, train with NCE."""
    data = mx.sym.Variable('data')
    label = mx.sym.Variable('label')
    label_weight = mx.sym.Variable('label_weight')
    embed_weight = mx.sym.Variable('embed_weight')
    embedded = mx.sym.Embedding(data=data, input_dim=vocab_size,
                                weight=embed_weight,
                                output_dim=100, name='data_embed')
    # one slice per context word
    word_slices = mx.sym.SliceChannel(data=embedded,
                                      num_outputs=num_input,
                                      squeeze_axis=1, name='data_slice')
    # accumulate the context-word embeddings into a single hidden vector
    summed = word_slices[0]
    for idx in range(1, num_input):
        summed = summed + word_slices[idx]
    return nce_loss(data=summed,
                    label=label,
                    label_weight=label_weight,
                    embed_weight=embed_weight,
                    vocab_size=vocab_size,
                    num_hidden=100)
# Example #7
def get_word_net(vocab_size, num_input):
    """Build a bag-of-words net: embedded context words are summed, then scored
    against the vocabulary with NCE loss (input and output embeddings are
    separate variables)."""
    tokens = mx.sym.Variable('data')
    target = mx.sym.Variable('label')
    target_weight = mx.sym.Variable('label_weight')
    input_embed_wgt = mx.sym.Variable('embed_weight')
    embedded = mx.sym.Embedding(data=tokens,
                                input_dim=vocab_size,
                                weight=input_embed_wgt,
                                output_dim=100,
                                name='data_embed')
    # split the embedded batch into one symbol per context position
    per_word = mx.sym.SliceChannel(data=embedded,
                                   num_outputs=num_input,
                                   squeeze_axis=1,
                                   name='data_slice')
    combined = per_word[0]
    for pos in range(1, num_input):
        combined = combined + per_word[pos]
    return nce_loss(data=combined,
                    label=target,
                    label_weight=target_weight,
                    embed_weight=input_embed_wgt,
                    vocab_size=vocab_size,
                    num_hidden=100)
# Build the model
###############################################################################
# Vocabulary size: one output unit per distinct token in the corpus.
ntokens = len(dictionary)
if args.loss == 'nce':
    # NCE training: the model additionally receives the loss name, the corpus
    # unigram distribution (used to sample noise words), the noise ratio, and
    # the reset flag.
    model = model.RNNModel(args.model,
                           ntokens,
                           args.emsize,
                           args.nhid,
                           args.nlayers,
                           args.dropout,
                           args.tied,
                           args.loss,
                           corpus.dictionary.unigram,
                           args.noise_ratio,
                           reset=args.reset)
    # NOTE(review): nce_loss is instantiated with no arguments here, while the
    # other snippets in this file use it as a function taking symbols —
    # presumably this is a different, class-based nce_loss; confirm.
    criterion = nce_loss()
    # Evaluation still uses full-softmax cross entropy for comparable perplexity.
    criterion_test = nn.CrossEntropyLoss()
else:
    # Standard full-softmax training path.
    model = model.RNNModel(args.model,
                           ntokens,
                           args.emsize,
                           args.nhid,
                           args.nlayers,
                           args.dropout,
                           args.tied,
                           reset=args.reset)
    criterion = nn.CrossEntropyLoss()
    # Unreduced variant: keeps the per-token losses (e.g. for interpolation).
    interpCrit = nn.CrossEntropyLoss(reduction='none')

# Move parameters to GPU when requested on the command line.
if args.cuda:
    model.cuda()