def infer(args):
    system_info.print_system_info()

    # Prepare model
    in_size = feature.get_input_dim(
            args.frame_size,
            args.context_size,
            args.input_transform)

    if args.model_type == 'Transformer':
        model = TransformerDiarization(
                in_size,
                n_units=args.hidden_size,
                n_heads=args.transformer_encoder_n_heads,
                n_layers=args.transformer_encoder_n_layers,
                dropout=0,
                alpha=0)
    else:
        raise ValueError('Unknown model type.')

    serializers.load_npz(args.model_file, model)

    if args.gpu >= 0:
        gpuid = use_single_gpu()
        model.to_gpu()

    kaldi_obj = kaldi_data.KaldiData(args.data_dir)
    for recid in kaldi_obj.wavs:
        data, rate = kaldi_obj.load_wav(recid)
        Y = feature.stft(data, args.frame_size, args.frame_shift)
        Y = feature.transform(Y, transform_type=args.input_transform)
        Y = feature.splice(Y, context_size=args.context_size)
        Y = Y[::args.subsampling]

        out_chunks = []
        with chainer.no_backprop_mode(), chainer.using_config('train', False):
            hs = None
            for start, end in _gen_chunk_indices(len(Y), args.chunk_size):
                Y_chunked = Variable(Y[start:end])
                if args.gpu >= 0:
                    Y_chunked.to_gpu(gpuid)
                hs, ys = model.estimate_sequential(hs, [Y_chunked])
                if args.gpu >= 0:
                    ys[0].to_cpu()
                out_chunks.append(ys[0].data)
                if args.save_attention_weight == 1:
                    att_fname = f"{recid}_{start}_{end}.att.npy"
                    att_path = os.path.join(args.out_dir, att_fname)
                    model.save_attention_weight(att_path)

        outfname = recid + '.h5'
        outpath = os.path.join(args.out_dir, outfname)
        if hasattr(model, 'label_delay'):
            outdata = shift(np.vstack(out_chunks), (-model.label_delay, 0))
        else:
            outdata = np.vstack(out_chunks)
        with h5py.File(outpath, 'w') as wf:
            wf.create_dataset('T_hat', data=outdata)
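

# NOTE: _gen_chunk_indices() is called in infer() above but is not part of this
# excerpt. Below is a minimal sketch of such a helper, assuming non-overlapping
# chunks of chunk_size frames with a final shorter chunk covering the tail; the
# actual helper in the codebase may differ.
def _gen_chunk_indices(n_frames, chunk_size):
    """Yield (start, end) frame-index pairs that tile n_frames in order."""
    start = 0
    while start < n_frames:
        end = min(n_frames, start + chunk_size)
        yield start, end
        start = end
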
def train(args):
    """ Training model with chainer backend.
    This function is called from eend/bin/train.py with
    parsed command-line arguments.
    """
    np.random.seed(args.seed)
    os.environ['CHAINER_SEED'] = str(args.seed)
    chainer.global_config.cudnn_deterministic = True

    train_set = KaldiDiarizationDataset(
        args.train_data_dir,
        chunk_size=args.num_frames,
        context_size=args.context_size,
        input_transform=args.input_transform,
        frame_size=args.frame_size,
        frame_shift=args.frame_shift,
        subsampling=args.subsampling,
        rate=args.sampling_rate,
        use_last_samples=True,
        label_delay=args.label_delay,
        n_speakers=args.num_speakers,
        )
    dev_set = KaldiDiarizationDataset(
        args.valid_data_dir,
        chunk_size=args.num_frames,
        context_size=args.context_size,
        input_transform=args.input_transform,
        frame_size=args.frame_size,
        frame_shift=args.frame_shift,
        subsampling=args.subsampling,
        rate=args.sampling_rate,
        use_last_samples=True,
        label_delay=args.label_delay,
        n_speakers=args.num_speakers,
        )

    # Prepare model
    Y, T = train_set.get_example(0)

    if args.model_type == 'BLSTM':
        model = BLSTMDiarization(
            in_size=Y.shape[1],
            n_speakers=args.num_speakers,
            hidden_size=args.hidden_size,
            n_layers=args.num_lstm_layers,
            embedding_layers=args.embedding_layers,
            embedding_size=args.embedding_size,
            dc_loss_ratio=args.dc_loss_ratio,
            )
    elif args.model_type == 'Transformer':
        model = TransformerDiarization(
            args.num_speakers,
            Y.shape[1],
            n_units=args.hidden_size,
            n_heads=args.transformer_encoder_n_heads,
            n_layers=args.transformer_encoder_n_layers,
            dropout=args.transformer_encoder_dropout)
    else:
        raise ValueError('Possible model_type are "Transformer" and "BLSTM"')

    if args.gpu >= 0:
        gpuid = use_single_gpu()
        print('GPU device {} is used'.format(gpuid))
        model.to_gpu()
    else:
        gpuid = -1
    print('Prepared model')

    # Setup optimizer
    if args.optimizer == 'adam':
        optimizer = optimizers.Adam(alpha=args.lr)
    elif args.optimizer == 'sgd':
        optimizer = optimizers.SGD(lr=args.lr)
    elif args.optimizer == 'noam':
        optimizer = optimizers.Adam(alpha=0, beta1=0.9, beta2=0.98, eps=1e-9)
    else:
        raise ValueError(args.optimizer)
    optimizer.setup(model)
    if args.gradclip > 0:
        optimizer.add_hook(
            chainer.optimizer_hooks.GradientClipping(args.gradclip))

    # Init/Resume
    if args.initmodel:
        print('Load model from', args.initmodel)
        serializers.load_npz(args.initmodel, model)

    train_iter = iterators.MultiprocessIterator(
        train_set,
        batch_size=args.batchsize,
        repeat=True, shuffle=True,
        # shared_mem=64000000,
        shared_mem=None,
        n_processes=4, n_prefetch=2)

    dev_iter = iterators.MultiprocessIterator(
        dev_set,
        batch_size=args.batchsize,
        repeat=False, shuffle=False,
        # shared_mem=64000000,
        shared_mem=None,
        n_processes=4, n_prefetch=2)

    if args.gradient_accumulation_steps > 1:
        updater = GradientAccumulationUpdater(
            train_iter, optimizer, converter=_convert, device=gpuid)
    else:
        updater = training.StandardUpdater(
            train_iter, optimizer, converter=_convert, device=gpuid)

    trainer = training.Trainer(
        updater,
        (args.max_epochs, 'epoch'),
        out=os.path.join(args.model_save_dir))

    evaluator = extensions.Evaluator(
        dev_iter, model, converter=_convert, device=gpuid)
    trainer.extend(evaluator)

    if args.optimizer == 'noam':
        trainer.extend(
            NoamScheduler(args.hidden_size,
                          warmup_steps=args.noam_warmup_steps,
                          scale=args.noam_scale),
            trigger=(1, 'iteration'))

    if args.resume:
        chainer.serializers.load_npz(args.resume, trainer)

    # MICRO AVERAGE
    metrics = [
        ('diarization_error', 'speaker_scored', 'DER'),
        ('speech_miss', 'speech_scored', 'SAD_MR'),
        ('speech_falarm', 'speech_scored', 'SAD_FR'),
        ('speaker_miss', 'speaker_scored', 'MI'),
        ('speaker_falarm', 'speaker_scored', 'FA'),
        ('speaker_error', 'speaker_scored', 'CF'),
        ('correct', 'frames', 'accuracy')]
    for num, den, name in metrics:
        trainer.extend(extensions.MicroAverage(
            'main/{}'.format(num),
            'main/{}'.format(den),
            'main/{}'.format(name)))
        trainer.extend(extensions.MicroAverage(
            'validation/main/{}'.format(num),
            'validation/main/{}'.format(den),
            'validation/main/{}'.format(name)))

    trainer.extend(extensions.LogReport(
        log_name='log_iter', trigger=(1000, 'iteration')))
    trainer.extend(extensions.LogReport())
    trainer.extend(extensions.PrintReport(
        ['epoch', 'main/loss', 'validation/main/loss',
         'main/diarization_error_rate',
         'validation/main/diarization_error_rate',
         'elapsed_time']))
    trainer.extend(extensions.PlotReport(
        ['main/loss', 'validation/main/loss'],
        x_key='epoch',
        file_name='loss.png'))
    trainer.extend(extensions.PlotReport(
        ['main/diarization_error_rate',
         'validation/main/diarization_error_rate'],
        x_key='epoch',
        file_name='DER.png'))
    trainer.extend(extensions.ProgressBar(update_interval=100))
    trainer.extend(extensions.snapshot(
        filename='snapshot_epoch-{.updater.epoch}'))
    trainer.extend(extensions.dump_graph('main/loss', out_name="cg.dot"))

    trainer.run()
    print('Finished!')
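

# NOTE: _convert() is passed as the converter to the updater and evaluator in
# train() above but is not part of this excerpt. Below is a minimal sketch,
# assuming the dataset yields (feature, label) pairs and the model's forward
# pass accepts keyword arguments xs and ts; the actual converter in the
# codebase may differ.
def _convert(batch, device):
    """Unzip a minibatch of (xs, ts) pairs and move each array to the device."""
    def to_device_batch(arrays):
        if device is None:
            return arrays
        return [chainer.dataset.to_device(device, x) for x in arrays]
    return {'xs': to_device_batch([x for x, _ in batch]),
            'ts': to_device_batch([t for _, t in batch])}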