def get_config():
    """Build the default configuration for the text-generation LSTM experiment.

    Creates a fresh ``Config`` object and fills it with default
    hyper-parameters. Per the original note, these act as CLI flags via
    trixi and can be overwritten with e.g. ``--learning_rate=0.001``.

    Returns:
        The populated ``Config`` instance.
    """
    c = Config()

    # Model params
    c.txt_file = 'assets/001ssb.txt'  # Path to a .txt file to train on
    c.seq_length = 30  # Length of an input sequence
    c.gen_length = 250  # Length of the generated sequence
    c.lstm_num_hidden = 128  # Number of hidden units in the LSTM
    c.lstm_num_layers = 2  # Number of LSTM layers in the model

    # Training params
    c.batch_size = 64  # Number of examples to process in a batch
    c.learning_rate = 2e-3  # Learning rate

    # It is not necessary to implement the following three params,
    # but it may help training.
    c.learning_rate_decay = 0.96  # Learning rate decay fraction
    c.learning_rate_step = 5000  # Learning rate step
    c.dropout_keep_prob = 1.0  # Dropout keep probability

    # Number of training steps. Stored as an int (not the float 1e6) so it
    # can be passed straight to range() or compared against a step counter.
    c.train_steps = 1_000_000
    c.max_norm = 5.0

    # Misc params
    c.summary_path = './summaries/'  # Output path for summaries
    c.print_every = 5  # How often to print training progress
    c.sample_every = 100  # How often to sample from the model
    c.device = 'cuda:0'  # Training device 'cpu' or 'cuda:0'
    # Balances the sampling strategy between fully-greedy (near 0) and
    # fully-random (higher), e.g. 0.5, 1.0, 2.0.
    c.temperature = 0.5

    return c
def get_config():
    """Return a ``Config`` holding the defaults for the RNN/LSTM experiment.

    Per the original note, the values act as CLI flags via trixi and can be
    overwritten with e.g. ``--learning_rate=0.001``.

    Returns:
        A new ``Config`` instance with every default below set as an
        attribute, in the listed order.
    """
    # (attribute name, default value) pairs; order matches assignment order.
    defaults = (
        ('model_type', 'RNN'),      # Model type, should be 'RNN' or 'LSTM'
        ('input_length', 10),       # Length of an input sequence
        ('input_dim', 1),           # Dimensionality of input sequence
        ('num_classes', 10),        # Dimensionality of output sequence
        ('num_hidden', 128),        # Number of hidden units in the model
        ('batch_size', 128),        # Number of examples to process in a batch
        ('learning_rate', 0.001),   # Learning rate
        ('train_steps', 10000),     # Number of training steps
        ('max_norm', 10.0),
        ('device', 'cuda:0'),       # Training device 'cpu' or 'cuda:0'
    )

    config = Config()
    for option, value in defaults:
        setattr(config, option, value)
    return config