def sym_gen(seq_len):
    """Build the unrolled LSTM symbol for a given sequence length.

    The network hyper-parameters (``num_lstm_layer``, ``feat_dim``,
    ``num_hidden``, ``label_dim``, ``num_hidden_proj``) and ``state_names``
    are not arguments; they are read from the enclosing scope.

    Parameters
    ----------
    seq_len
        Forwarded to ``lstm_unroll`` as its second positional argument
        (presumably the number of time steps to unroll -- confirm against
        ``lstm_unroll``).

    Returns
    -------
    tuple
        ``(sym, data_names, label_names)`` where ``sym`` is whatever
        ``lstm_unroll`` returns, ``data_names`` is ``'data'`` followed by
        every entry of ``state_names``, and ``label_names`` is
        ``['softmax_label']``.
    """
    # Keyword settings shared by every unrolled length.
    unroll_options = dict(
        num_hidden=num_hidden,
        num_label=label_dim,
        num_hidden_proj=num_hidden_proj,
    )
    network = lstm_unroll(num_lstm_layer, seq_len, feat_dim, **unroll_options)

    # The feature input comes first, then the state inputs.
    input_names = ['data'] + state_names
    output_names = ['softmax_label']
    return network, input_names, output_names
else: truncate_len = 20 data_test = TruncatedSentenceIter(test_sets, batch_size, init_states, truncate_len, feat_dim=feat_dim, do_shuffling=False, pad_zeros=True, has_label=True) sym = lstm_unroll(num_lstm_layer, truncate_len, feat_dim, num_hidden=num_hidden, num_label=label_dim, output_states=True, num_hidden_proj=num_hidden_proj) data_names = [x[0] for x in data_test.provide_data] label_names = ['softmax_label'] module = mx.mod.Module(sym, context=contexts, data_names=data_names, label_names=label_names) # set the parameters module.bind(data_shapes=data_test.provide_data, label_shapes=None, for_training=False) module.set_params(arg_params=arg_params, aux_params=aux_params)
num_label=label_dim, num_hidden_proj=num_hidden_proj) data_names = ['data'] + state_names label_names = ['softmax_label'] return (sym, data_names, label_names) module = mx.mod.BucketingModule(sym_gen, default_bucket_key=data_train.default_bucket_key, context=contexts) do_training(training_method, args, module, data_train, data_val) elif training_method == METHOD_TBPTT: truncate_len = args.config.getint('train', 'truncate_len') data_train = TruncatedSentenceIter(train_sets, batch_size, init_states, truncate_len=truncate_len, feat_dim=feat_dim) data_val = TruncatedSentenceIter(dev_sets, batch_size, init_states, truncate_len=truncate_len, feat_dim=feat_dim, do_shuffling=False, pad_zeros=True) sym = lstm_unroll(num_lstm_layer, truncate_len, feat_dim, num_hidden=num_hidden, num_label=label_dim, output_states=True, num_hidden_proj=num_hidden_proj) data_names = [x[0] for x in data_train.provide_data] label_names = [x[0] for x in data_train.provide_label] module = mx.mod.Module(sym, context=contexts, data_names=data_names, label_names=label_names) do_training(training_method, args, module, data_train, data_val) else: raise RuntimeError('Unknown training method: %s' % training_method) print("="*80) print("Finished Training") print("="*80) args.config.write(sys.stdout)