num_label=label_dim, num_hidden_proj=num_hidden_proj)
            # Inputs are the feature batch plus the LSTM initial-state
            # placeholders; the label is the frame-level softmax target.
            data_names = ['data'] + state_names
            label_names = ['softmax_label']
            return (sym, data_names, label_names)

        # NOTE(review): fragment — the enclosing `if training_method == ...`
        # branch, the `def sym_gen(...)` header and the start of the
        # lstm_unroll(...) call above are outside this chunk; indentation here
        # is reconstructed and should be confirmed against the full file.
        module = mx.mod.BucketingModule(
            sym_gen, default_bucket_key=data_train.default_bucket_key,
            context=contexts)
        do_training(training_method, args, module, data_train, data_val)
    elif training_method == METHod_TBPTT if False else METHOD_TBPTT:
        # Truncated BPTT: train on fixed-length windows of `truncate_len`
        # frames instead of whole (bucketed) utterances.
        truncate_len = args.config.getint('train', 'truncate_len')
        data_train = TruncatedSentenceIter(train_sets, batch_size, init_states,
                                           truncate_len=truncate_len,
                                           feat_dim=feat_dim)
        # Validation iterator: no shuffling, and zero-pad the last partial
        # batch so evaluation covers every utterance.
        data_val = TruncatedSentenceIter(dev_sets, batch_size, init_states,
                                         truncate_len=truncate_len,
                                         feat_dim=feat_dim,
                                         do_shuffling=False, pad_zeros=True)
        # output_states=True so the final LSTM states can be carried over to
        # the next truncated window; call continues past this chunk.
        sym = lstm_unroll(num_lstm_layer, truncate_len, feat_dim,
                          num_hidden=num_hidden, num_label=label_dim,
                          output_states=True,
num_hidden_proj=num_hidden_proj)
            # No labels at decode time, hence the empty label_names list.
            data_names = ['data'] + state_names
            label_names = []
            return (sym, data_names, label_names)

        # NOTE(review): fragment — the enclosing branch header and the start
        # of sym_gen/lstm_unroll above are outside this chunk; indentation is
        # reconstructed and should be confirmed against the full file.
        module = mx.mod.BucketingModule(
            sym_gen, default_bucket_key=data_test.default_bucket_key,
            context=contexts)
    else:
        # Fallback decoding path: fixed truncation window of 20 frames.
        # NOTE(review): 20 is hard-coded here (unlike the TBPTT branch, which
        # reads truncate_len from the config) — presumably intentional for
        # decoding; verify.
        truncate_len = 20
        data_test = TruncatedSentenceIter(test_sets, batch_size, init_states,
                                          truncate_len, feat_dim=feat_dim,
                                          do_shuffling=False, pad_zeros=True,
                                          has_label=True)
        # output_states=True lets the decoder propagate LSTM state across
        # consecutive truncated windows of the same utterance.
        sym = lstm_unroll(num_lstm_layer, truncate_len, feat_dim,
                          num_hidden=num_hidden, num_label=label_dim,
                          output_states=True,
                          num_hidden_proj=num_hidden_proj)
        data_names = [x[0] for x in data_test.provide_data]
        label_names = ['softmax_label']
        # Plain (non-bucketing) Module: all batches share one fixed shape.
        # Call continues past this chunk.
        module = mx.mod.Module(sym, context=contexts,
init_h = [('l%d_init_h' % l, (batch_size, num_hidden_lstm))
          for l in range(num_lstm_layer)]
# One (name, shape) pair per layer for both cell (init_c, defined above this
# chunk) and hidden states; names feed the iterator and the module bindings.
init_states = init_c + init_h
state_names = [x[0] for x in init_states]
data_names = [data_name] + state_names
# NOTE(review): data_names is reassigned from provide_data a few lines below
# without an intervening read — the assignment above looks dead; confirm
# against the full file before removing.
data_test = TruncatedSentenceIter(test_sets, batch_size, init_states,
                                  truncate_len=truncate_len, delay=10,
                                  feat_dim=feat_dim, label_dim=label_dim,
                                  data_name='data',
                                  label_name='linear_label',
                                  do_shuffling=False, pad_zeros=False)
# Derive the actual data/label names from the iterator so the module binds
# to exactly what the iterator provides.
data_names = [x[0] for x in data_test.provide_data]
label_names = [x[0] for x in data_test.provide_label]
# Restore symbol and trained parameters from the checkpoint at epoch
# `load_epoch_num` under prefix `test_prefix`.
sym, arg_params, aux_params = mx.model.load_checkpoint(
    test_prefix, load_epoch_num)
mod = mx.mod.Module(sym, context=contexts,
                    data_names=data_names, label_names=label_names)