import mxnet as mx


def sym_gen(seq_len):
    """Build the unrolled LSTM + NCE-loss symbol for one bucket length.

    `vocab`, `args`, `cell`, and `nce_loss` are defined in the enclosing
    script.
    """
    # input ids: [batch_size, seq_len]
    data = mx.sym.Variable('data')

    # map each input id to an embedding vector
    embed_in = mx.sym.Embedding(data=data, input_dim=len(vocab),
                                output_dim=args.num_embed, name='input_embed')

    # pass the embeddings through the LSTM;
    # output: [batch_size, seq_len, num_hidden]
    output, _ = cell.unroll(seq_len, inputs=embed_in, layout='NTC',
                            merge_outputs=True)
    # output = output.reshape(-1, num_embed)

    # labels are mapped to embeddings inside nce_loss
    label = mx.sym.Variable('label')
    label_weight = mx.sym.Variable('label_weight')

    # output embedding matrix
    # TODO: change to adapter binding
    embed_weight = mx.sym.Variable(name='output_embed_weight',
                                   shape=(len(vocab), args.num_hidden))

    pred = nce_loss(output, label, label_weight, embed_weight, len(vocab),
                    args.num_hidden, args.num_label, seq_len)
    return pred, ('data',), ('label', 'label_weight')
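# The tuple returned by sym_gen -- (symbol, data_names, label_names) -- is
# exactly what mx.mod.BucketingModule expects, so variable-length training
# might look like the minimal sketch below. This is a hypothetical sketch:
# `buckets` (a list of sequence lengths) and the bucketing iterator
# `data_train` are assumed to exist in the surrounding script.
module = mx.mod.BucketingModule(sym_gen=sym_gen,
                                default_bucket_key=max(buckets),
                                context=mx.cpu())
module.fit(train_data=data_train,
           optimizer='sgd',
           optimizer_params={'learning_rate': 0.1},
           num_epoch=5)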
def get_lstm_net(vocab_size, seq_len, num_lstm_layer, num_hidden):
    """Explicitly unrolled LSTM language model trained with per-step NCE loss."""
    param_cells = []
    last_states = []
    for i in range(num_lstm_layer):
        # weights are shared across time steps, so declare them once per layer
        param_cells.append(LSTMParam(
            i2h_weight=mx.sym.Variable("l%d_i2h_weight" % i),
            i2h_bias=mx.sym.Variable("l%d_i2h_bias" % i),
            h2h_weight=mx.sym.Variable("l%d_h2h_weight" % i),
            h2h_bias=mx.sym.Variable("l%d_h2h_bias" % i)))
        state = LSTMState(c=mx.sym.Variable("l%d_init_c" % i),
                          h=mx.sym.Variable("l%d_init_h" % i))
        last_states.append(state)

    data = mx.sym.Variable('data')
    label = mx.sym.Variable('label')
    label_weight = mx.sym.Variable('label_weight')
    embed_weight = mx.sym.Variable('embed_weight')
    label_embed_weight = mx.sym.Variable('label_embed_weight')

    data_embed = mx.sym.Embedding(data=data, input_dim=vocab_size,
                                  weight=embed_weight, output_dim=100,
                                  name='data_embed')

    # split along the time axis into seq_len per-step tensors
    datavec = mx.sym.SliceChannel(data=data_embed, num_outputs=seq_len,
                                  squeeze_axis=True, name='data_slice')
    labelvec = mx.sym.SliceChannel(data=label, num_outputs=seq_len,
                                   squeeze_axis=True, name='label_slice')
    labelweightvec = mx.sym.SliceChannel(data=label_weight,
                                         num_outputs=seq_len,
                                         squeeze_axis=True,
                                         name='label_weight_slice')

    probs = []
    for seqidx in range(seq_len):
        hidden = datavec[seqidx]
        for i in range(num_lstm_layer):
            next_state = _lstm(num_hidden, indata=hidden,
                               prev_state=last_states[i],
                               param=param_cells[i],
                               seqidx=seqidx, layeridx=i)
            hidden = next_state.h
            last_states[i] = next_state
        probs.append(nce_loss(data=hidden,
                              label=labelvec[seqidx],
                              label_weight=labelweightvec[seqidx],
                              embed_weight=label_embed_weight,
                              vocab_size=vocab_size,
                              num_hidden=100))
    return mx.sym.Group(probs)
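# A quick way to see what inputs the unrolled net above expects: besides
# data/label/label_weight, the initial LSTM states (l*_init_c, l*_init_h)
# must be provided for every layer. A small sanity check with hypothetical
# hyperparameters:
net = get_lstm_net(vocab_size=10000, seq_len=25, num_lstm_layer=2,
                   num_hidden=200)
# lists data, l0_init_c, l0_init_h, ..., the learnable weights, and the
# label/label_weight inputs consumed by the per-step NCE losses
print(net.list_arguments())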
def get_net(num_vocab):
    data = mx.sym.Variable('data')
    label = mx.sym.Variable('label')
    label_weight = mx.sym.Variable('label_weight')
    embed_weight = mx.sym.Variable('embed_weight')
    pred = mx.sym.FullyConnected(data=data, num_hidden=100)
    return nce_loss(data=pred,
                    label=label,
                    label_weight=label_weight,
                    embed_weight=embed_weight,
                    vocab_size=num_vocab,
                    num_hidden=100)
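# `nce_loss` itself is not shown in these snippets. Below is a minimal sketch
# of one plausible implementation with the signature used above: embed the
# candidate labels (one true label plus k sampled noise labels), take their
# dot products with the hidden vector, and train a binary classifier per
# candidate against label_weight (1.0 for the true label, 0.0 for noise).
# The actual helper may differ -- the sym_gen variant above, for instance,
# also takes num_label and seq_len.
def nce_loss(data, label, label_weight, embed_weight, vocab_size, num_hidden):
    # data:         [batch, num_hidden]  hidden representation of the context
    # label:        [batch, num_label]   true label id + sampled noise ids
    # label_weight: [batch, num_label]   binary targets for the classifiers
    label_embed = mx.sym.Embedding(data=label, input_dim=vocab_size,
                                   weight=embed_weight, output_dim=num_hidden,
                                   name='label_embed')  # [batch, num_label, num_hidden]
    data = mx.sym.Reshape(data=data, shape=(0, 1, num_hidden))
    pred = mx.sym.broadcast_mul(data, label_embed)
    pred = mx.sym.sum(pred, axis=2)  # dot products: [batch, num_label]
    return mx.sym.LogisticRegressionOutput(data=pred, label=label_weight)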
def get_word_net(vocab_size, num_input):
    data = mx.sym.Variable('data')
    label = mx.sym.Variable('label')
    label_weight = mx.sym.Variable('label_weight')
    embed_weight = mx.sym.Variable('embed_weight')

    data_embed = mx.sym.Embedding(data=data, input_dim=vocab_size,
                                  weight=embed_weight, output_dim=100,
                                  name='data_embed')
    datavec = mx.sym.SliceChannel(data=data_embed, num_outputs=num_input,
                                  squeeze_axis=1, name='data_slice')

    # sum the context-word embeddings (CBOW-style context representation)
    pred = datavec[0]
    for i in range(1, num_input):
        pred = pred + datavec[i]

    return nce_loss(data=pred,
                    label=label,
                    label_weight=label_weight,
                    embed_weight=embed_weight,
                    vocab_size=vocab_size,
                    num_hidden=100)
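# A minimal sketch of training the net above with the Module API. The vocab
# size, context-window width, and the iterator `train_iter` (yielding `data`
# plus `label`/`label_weight` batches) are assumptions, not part of the
# original snippet.
net = get_word_net(vocab_size=10000, num_input=4)
mod = mx.mod.Module(symbol=net,
                    data_names=('data',),
                    label_names=('label', 'label_weight'),
                    context=mx.cpu())
mod.fit(train_iter,
        optimizer='sgd',
        optimizer_params={'learning_rate': 0.3},
        num_epoch=5)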
import torch.nn as nn

###############################################################################
# Build the model
###############################################################################
ntokens = len(dictionary)
if args.loss == 'nce':
    model = model.RNNModel(args.model, ntokens, args.emsize, args.nhid,
                           args.nlayers, args.dropout, args.tied, args.loss,
                           corpus.dictionary.unigram, args.noise_ratio,
                           reset=args.reset)
    criterion = nce_loss()
    # evaluate with the exact softmax cross-entropy, since NCE is not a
    # normalized likelihood; this keeps reported perplexities comparable
    criterion_test = nn.CrossEntropyLoss()
else:
    model = model.RNNModel(args.model, ntokens, args.emsize, args.nhid,
                           args.nlayers, args.dropout, args.tied,
                           reset=args.reset)
    criterion = nn.CrossEntropyLoss()
    interpCrit = nn.CrossEntropyLoss(reduction='none')

if args.cuda:
    model.cuda()
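# A minimal, hypothetical sketch of how the two criteria are typically used:
# optimize the cheap NCE objective during training, but report exact
# cross-entropy (and hence comparable perplexity) at evaluation time. The
# batch iterators, `args.lr`, and the stateless `model(data)` call are all
# simplifying assumptions about the surrounding script.
import math

import torch
import torch.optim as optim

optimizer = optim.SGD(model.parameters(), lr=args.lr)

model.train()
for data, targets in train_batches:         # assumed iterator
    optimizer.zero_grad()
    loss = criterion(model(data), targets)  # NCE objective during training
    loss.backward()
    optimizer.step()

model.eval()
total_loss, n = 0.0, 0
with torch.no_grad():
    for data, targets in val_batches:       # assumed iterator
        total_loss += criterion_test(model(data), targets).item()
        n += 1
print('val ppl %.2f' % math.exp(total_loss / n))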