Exemplo n.º 1
0
def prepare_data(args):
    """Build the LSTM initial-state descriptors and the input data stream.

    Options read from ``args.config`` (section/key):
    ``train/batch_size``, ``arch/num_hidden``, ``arch/num_lstm_layer``,
    ``data/train``, ``data/format`` and ``data/xdim``.

    Returns:
        A tuple ``(init_states, test_sets)``. ``init_states`` is a list of
        ``(name, (batch_size, num_hidden))`` pairs, one ``l<N>_init_c``
        entry per layer followed by one ``l<N>_init_h`` entry per layer.
        ``test_sets`` is the ``DataReadStream`` built from the reader
        settings and the feature dimension.

    NOTE(review): the list file is taken from the ``train`` key of the
    ``data`` section even though it feeds the stream named "test" --
    confirm this is intended.
    """
    cfg = args.config

    # Config lookups are kept in their original order so that a missing
    # option surfaces the same error first.
    batch_size = cfg.getint('train', 'batch_size')
    num_hidden = cfg.getint('arch', 'num_hidden')
    num_lstm_layer = cfg.getint('arch', 'num_lstm_layer')
    lst_path = cfg.get('data', 'train')
    file_format = cfg.get('data', 'format')
    feat_dim = cfg.getint('data', 'xdim')

    # Cell and hidden states share one shape in this variant.
    state_shape = (batch_size, num_hidden)
    layer_ids = range(num_lstm_layer)
    cell_states = [('l%d_init_c' % idx, state_shape) for idx in layer_ids]
    hidden_states = [('l%d_init_h' % idx, state_shape) for idx in layer_ids]

    reader_settings = {
        "gpu_chunk": 32768,
        "lst_file": lst_path,
        "file_format": file_format,
        "separate_lines": True,
        "has_labels": True,
    }
    stream = DataReadStream(reader_settings, feat_dim)

    return (cell_states + hidden_states, stream)
def prepare_data(args):
    """Build the LSTM initial-state descriptors and the train/dev streams.

    Options read from ``args.config`` (section/key):
    ``train/batch_size``, ``arch/num_hidden``, ``arch/num_hidden_proj``,
    ``arch/num_lstm_layer``, ``data/train``, ``data/dev``, ``data/format``
    and ``data/xdim``.

    Returns:
        A tuple ``(init_states, train_sets, dev_sets)``. ``init_states`` is
        a list of ``(name, shape)`` pairs: one ``l<N>_init_c`` entry per
        layer with shape ``(batch_size, num_hidden)``, followed by one
        ``l<N>_init_h`` entry per layer whose second dimension is
        ``num_hidden_proj`` when that value is positive and ``num_hidden``
        otherwise. ``train_sets`` and ``dev_sets`` are ``DataReadStream``
        objects built from the respective list files.
    """
    cfg = args.config

    # Config lookups are kept in their original order so that a missing
    # option surfaces the same error first.
    batch_size = cfg.getint('train', 'batch_size')
    num_hidden = cfg.getint('arch', 'num_hidden')
    num_hidden_proj = cfg.getint('arch', 'num_hidden_proj')
    num_lstm_layer = cfg.getint('arch', 'num_lstm_layer')
    train_list = cfg.get('data', 'train')
    dev_list = cfg.get('data', 'dev')
    file_format = cfg.get('data', 'format')
    feat_dim = cfg.getint('data', 'xdim')

    # The hidden state uses the projection width only when one is configured.
    hidden_width = num_hidden_proj if num_hidden_proj > 0 else num_hidden
    layer_ids = range(num_lstm_layer)

    init_states = [('l%d_init_c' % idx, (batch_size, num_hidden))
                   for idx in layer_ids]
    init_states += [('l%d_init_h' % idx, (batch_size, hidden_width))
                    for idx in layer_ids]

    def _reader_settings(lst_file):
        # Both streams use identical settings apart from the list file.
        return {
            "gpu_chunk": 32768,
            "lst_file": lst_file,
            "file_format": file_format,
            "separate_lines": True,
        }

    train_sets = DataReadStream(_reader_settings(train_list), feat_dim)
    dev_sets = DataReadStream(_reader_settings(dev_list), feat_dim)

    return (init_states, train_sets, dev_sets)