コード例 #1
0
ファイル: train.py プロジェクト: zzhang68/nn-gev
    prepare_training_data(args.chime_dir, args.data_dir)

# Load the JSON file list for each stage from args.data_dir into `flists`,
# keyed by stage name ('tr' / 'dt' -- presumably the training and development
# sets; confirm against prepare_training_data).
flists = {}
for stage in ('tr', 'dt'):
    flist_path = os.path.join(args.data_dir, 'flist_{}.json'.format(stage))
    with open(flist_path) as fid:
        flists[stage] = json.load(fid)
log.debug('Loaded file lists')

# Prepare model: build the mask estimator chosen by args.model_type and make
# sure its save directory (<data_dir>/<type>_model) exists.
# Reject unsupported types before constructing anything or touching disk.
if args.model_type not in ('BLSTM', 'FW'):
    raise ValueError('Unknown model type. Possible are "BLSTM" and "FW"')

if args.model_type == 'BLSTM':
    model = BLSTMMaskEstimator()
    model_save_dir = os.path.join(args.data_dir, 'BLSTM_model')
else:
    model = SimpleFWMaskEstimator()
    model_save_dir = os.path.join(args.data_dir, 'FW_model')
# Shared by both model types; mkdir_p presumably behaves like `mkdir -p`
# (no error if the directory already exists) -- confirm in its definition.
mkdir_p(model_save_dir)

# Device placement: a negative args.gpu leaves the model unmoved and selects
# numpy as the array module `xp`; otherwise the requested CUDA device is made
# current, the model is copied to it, and `xp` becomes cupy. `xp` is presumably
# used later for device-agnostic array code.
if args.gpu < 0:
    xp = np
else:
    cuda.get_device(args.gpu).use()
    model.to_gpu()
    xp = cuda.cupy
log.debug('Prepared model')

# Setup optimizer: Adam with its default hyperparameters, bound to the model's
# parameters. This runs after the optional to_gpu() call above.
optimizer = optimizers.Adam()
optimizer.setup(model)