def build_model(hps, log):
    """Construct a ``WaveNODE`` from ``hps`` and record its trainable size.

    The total element count of all parameters with ``requires_grad`` set is
    printed to stdout, then appended to ``log`` as a single JSON line of the
    form ``{"n_params": <count>}``; ``log`` is flushed right after writing.

    Args:
        hps: Hyperparameters forwarded unchanged to ``WaveNODE``.
        log: Writable text stream; only its ``write`` and ``flush`` methods
            are used.

    Returns:
        The newly constructed ``WaveNODE`` instance.
    """
    net = WaveNODE(hps)

    # Accumulate element counts of trainable tensors only (frozen ones skipped).
    trainable = 0
    for param in net.parameters():
        if param.requires_grad:
            trainable += param.numel()

    print('number of parameters:', trainable)

    record = json.dumps({'n_params': trainable})
    log.write('%s\n' % record)
    # Flush so the entry reaches the log immediately.
    log.flush()
    return net
def build_model(hps, log=None):
    """Construct a ``WaveNODE`` from ``hps`` and report its trainable size.

    The total element count of all parameters with ``requires_grad`` set is
    always printed to stdout. When ``log`` is provided, the count is also
    appended to it as one JSON line ``{"n_params": <count>}`` and the stream
    is flushed.

    NOTE(review): this file contains another ``build_model`` definition above
    that takes a required ``log`` argument. This later definition replaces it
    at import time. ``log`` is therefore optional here, so that both the
    one-argument and two-argument call forms keep working.

    Args:
        hps: Hyperparameters forwarded unchanged to ``WaveNODE``.
        log: Optional writable text stream (only ``write`` and ``flush`` are
            used). Defaults to ``None``, meaning nothing is logged.

    Returns:
        The newly constructed ``WaveNODE`` instance.
    """
    model = WaveNODE(hps)
    n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print('number of parameters:', n_params)
    if log is not None:
        # Local import keeps this block independent of the module's
        # top-level imports.
        import json
        log.write('%s\n' % json.dumps({'n_params': n_params}))
        # Flush so the entry reaches the log immediately.
        log.flush()
    return model