Example #1
def main():
    print('Initializing Training Process..')

    parser = argparse.ArgumentParser()

    parser.add_argument('--group_name', default=None)
    parser.add_argument('--input_wavs_dir',
                        default='LJSpeech-1.1/wavs',
                        help='')
    parser.add_argument('--input_mels_dir', default='ft_dataset', help='')
    parser.add_argument('--input_training_file',
                        default='LJSpeech-1.1/training.txt',
                        help='')
    parser.add_argument('--input_validation_file',
                        default='LJSpeech-1.1/validation.txt',
                        help='')
    parser.add_argument('--checkpoint_path', default='cp_hifigan')
    parser.add_argument('--config', default='')
    parser.add_argument('--training_epochs', default=3100, type=int)
    parser.add_argument('--stdout_interval', default=5, type=int)
    parser.add_argument('--checkpoint_interval', default=5000, type=int)
    parser.add_argument('--summary_interval', default=100, type=int)
    parser.add_argument('--validation_interval', default=1000, type=int)
    parser.add_argument('--fine_tuning', default=False, type=bool)

    a = parser.parse_args()

    with open(a.config) as f:
        data = f.read()

    json_config = json.loads(data)
    h = AttrDict(json_config)
    build_env(a.config, 'config.json', a.checkpoint_path)

    torch.manual_seed(h.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(h.seed)
        h.num_gpus = torch.cuda.device_count()
        h.batch_size = int(h.batch_size / h.num_gpus)
        print('Batch size per GPU :', h.batch_size)
    else:
        pass  # CPU-only run; num_gpus and batch_size are taken from the config as-is

    if h.num_gpus > 1:
        mp.spawn(train, nprocs=h.num_gpus, args=(
            a,
            h,
        ))
    else:
        train(0, a, h)
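Note that argparse's type=bool converts via bool(str), so any non-empty string (including "False") parses as True and --fine_tuning False still enables fine-tuning. If only the flag's truthiness matters, a store_true switch (a sketch, not the project's actual interface) avoids the surprise:

parser.add_argument('--fine_tuning', action='store_true')  # absent -> False, passed -> True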
Example #2
def main():
    print('Initializing Inference Process..')

    parser = argparse.ArgumentParser()
    parser.add_argument('--input_mels_dir', default='test_mel_files')
    parser.add_argument('--output_dir', default='generated_files_from_mel')
    parser.add_argument('--checkpoint_file', required=True)
    a = parser.parse_args()

    config_file = os.path.join(
        os.path.split(a.checkpoint_file)[0], 'config.json')
    with open(config_file) as f:
        data = f.read()

    global h
    json_config = json.loads(data)
    h = AttrDict(json_config)

    torch.manual_seed(h.seed)
    global device
    if torch.cuda.is_available():
        torch.cuda.manual_seed(h.seed)
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')

    inference(a)
Example #3
def main():
    print("Initializing Inference Process..")

    parser = argparse.ArgumentParser()
    parser.add_argument("--input_wavs_dir", default="test_files")
    parser.add_argument("--output_dir", default="generated_files")
    parser.add_argument("--checkpoint_file", required=True)
    a = parser.parse_args()

    config_file = os.path.join(os.path.split(a.checkpoint_file)[0], "config.json")
    with open(config_file) as f:
        data = f.read()

    global h
    json_config = json.loads(data)
    h = AttrDict(json_config)

    torch.manual_seed(h.seed)
    global device
    if torch.cuda.is_available():
        torch.cuda.manual_seed(h.seed)
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")

    inference(a)
Example #4
def main():
    print('Initializing Training Process..')

    parser = argparse.ArgumentParser()

    parser.add_argument('--group_name', default=None)
    parser.add_argument('--checkpoint_path', default='cp_hifigan')
    parser.add_argument('--config', default='config_8k.json')
    parser.add_argument('--training_epochs', default=3100, type=int)
    parser.add_argument('--stdout_interval', default=5, type=int)
    parser.add_argument('--checkpoint_interval', default=5000, type=int)
    parser.add_argument('--summary_interval', default=100, type=int)
    parser.add_argument('--validation_interval', default=1000, type=int)
    parser.add_argument('--fine_tuning', default=False, type=bool)

    a = parser.parse_args()

    with open(a.config) as f:
        data = f.read()

    json_config = json.loads(data)
    h = AttrDict(json_config)
    build_env(a.config, 'config.json', a.checkpoint_path)

    model = Generator(h)

    inputs = torch.randn(10, 80, 80)
    output = model(inputs)
    print(output.shape)
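As a rough sanity check, assuming a config like the public LJSpeech one where the upsample rates multiply to 256 (the hop size), the generator should turn the (10, 80, 80) mel batch into waveforms of shape (10, 1, 80 * 256); other configs scale the last dimension by their own upsample product instead.

# assumption: prod(h.upsample_rates) == 256 in the loaded config
assert output.shape == (10, 1, 80 * 256)  # (batch, 1, frames * hop_size)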
Example #5
def main(args):
    if args.config is not None:
        with open(args.config) as f:
            data = f.read()
        global h

        json_config = json.loads(data)
        h = AttrDict(json_config)

    torch.manual_seed(h.seed)

    model = Generator(h).cuda()
    state_dict_g = load_checkpoint(args.checkpoint_path, 'cuda')
    model.load_state_dict(state_dict_g['generator'])

    model.eval()
    model.remove_weight_norm()

    with torch.no_grad():
        mel = torch.from_numpy(np.load(args.input))
        if len(mel.shape) == 2:
            mel = mel.unsqueeze(0)
        mel = mel.cuda()
        #zero = torch.full((1, 80, 10), -11.5129).to(mel.device)
        #mel = torch.cat((mel, zero), dim=2)
        hifigan_trace = torch.jit.trace(model, mel)
        print(state_dict_g.keys())
        hifigan_trace.save("{}/hifigan_{}.pt".format(args.out, args.name))
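The saved TorchScript module can later be reloaded without the Generator class, assuming the same (batch, num_mels, frames) mel layout used during tracing; the filename below is only a placeholder for whatever the script wrote:

# minimal sketch; 'hifigan_traced.pt' stands in for the file saved above
vocoder = torch.jit.load('hifigan_traced.pt').cuda().eval()
with torch.no_grad():
    audio = vocoder(mel)  # mel: (1, 80, T) tensor on the same device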
Example #6
def main():
    print('Initializing Inference Process..')

    parser = argparse.ArgumentParser()
    parser.add_argument('--input_wavs_dir', default='test_files')
    parser.add_argument('--output_dir', default='test_files_generated')
    parser.add_argument('--checkpoint_file', required=True)
    a = parser.parse_args()

    config_file = os.path.join(
        os.path.split(a.checkpoint_file)[0], 'config.json')
    with open(config_file) as f:
        data = f.read()

    global h
    json_config = json.loads(data)
    h = AttrDict(json_config)

    torch.manual_seed(h.seed)
    global device
    if torch.cuda.is_available():
        torch.cuda.manual_seed(h.seed)
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')

    global STFT
    STFT = STFT_Class(h.sampling_rate, h.num_mels, h.n_fft, h.win_size,
                      h.hop_size, h.fmin, h.fmax)

    inference(a, STFT)
Example #7
def main():
    print('Initializing Inference Process..')

    parser = argparse.ArgumentParser()
    parser.add_argument('--input_wavs_dir', default='test_files')
    parser.add_argument('--onnx_filename', default='./generator.onnx')
    parser.add_argument('--output_dir', default='generated_files')
    parser.add_argument('--checkpoint_file', required=True)
    parser.add_argument('--config_file', required=True)
    a = parser.parse_args()

    config_file = a.config_file
    with open(config_file) as f:
        data = f.read()

    global h
    json_config = json.loads(data)
    h = AttrDict(json_config)
    #print(h)

    torch.manual_seed(h.seed)
    global device
    if torch.cuda.is_available():
        torch.cuda.manual_seed(h.seed)
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')

    inference(a)
Example #8
def main():

    parser = argparse.ArgumentParser()

    parser.add_argument('--group_name', default=None)
    parser.add_argument('--input_wavs_dir', default='LJSpeech-1.1/wavs')
    parser.add_argument('--input_validation_file', default='LJSpeech-1.1/validation.txt')
    parser.add_argument('--checkpoint_path', default='cp_hifigan')
    parser.add_argument('--config', default='')
    parser.add_argument('--fine_tuning', default=False, type=bool)
    parser.add_argument('--input_mels_dir', default='ft_dataset')
    parser.add_argument('--speakers_json', default=None, type=str)
    parser.add_argument('--batch_size', default=15, type=int)
    # for train script compatibility
    parser.add_argument('--input_training_file', default='LJSpeech-1.1/training.txt')

    a = parser.parse_args()

    with open(a.config) as f:
        data = f.read()

    if a.speakers_json:
        with open(a.speakers_json) as f:
            speaker_mapping = json.load(f)
    else:
        speaker_mapping = None

    json_config = json.loads(data)
    h = AttrDict(json_config)

    torch.manual_seed(h.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(h.seed)
        h.num_gpus = torch.cuda.device_count()
        h.batch_size = int(h.batch_size / h.num_gpus)
        print('Batch size per GPU :', h.batch_size)
    else:
        pass  # CPU-only run; keep the config's batch_size unchanged

    h.segment_size = h.segment_size * 8

    eval(0, a, h, speaker_mapping)
Example #9
def main():
    print('Initializing the Training Process..')

    parser = argparse.ArgumentParser()

    parser.add_argument('--input_wavs_dir', default='data/recordings')
    parser.add_argument('--input_mels_dir', default='processed_spokenDigits_np')
    parser.add_argument('--config', default='processed_spokenDigits_np')
    parser.add_argument('--training_epochs', default=1000, type=int)
    # build_env below needs a checkpoint path; the default here is assumed
    parser.add_argument('--checkpoint_path', default='cp_hifigan')

    a = parser.parse_args()

    with open(a.config) as f:
        data = f.read()

    json_config = json.loads(data)
    h = AttrDict(json_config)

    build_env(a.config, 'config.json', a.checkpoint_path)

    torch.manual_seed(h.seed)

    if torch.cuda.is_available():
        torch.cuda.manual_seed(h.seed)
        h.num_gpus = torch.cuda.device_count()
        h.batch_size = int(h.batch_size / h.num_gpus)
    else:
        print('\nRunning on cpu')

    # train now
    g_losses, d_losses, generated_mels = train(h)

    # visualize the losses as the network trained
    plt.plot(g_losses, label='generator loss')
    plt.plot(d_losses, label='discriminator loss')
    plt.xlabel('100\'s of batches')
    plt.ylabel('loss')
    plt.legend()
    plt.grid(True)
    # plt.ylim(0, 2.5) # consistent scale
    plt.show()
Example #10
def main():
    print('Initializing Training Process..')

    parser = argparse.ArgumentParser()

    parser.add_argument('--rank', default=0, type=int)
    parser.add_argument('--group_name', default=None)
    parser.add_argument('--input_wavs_dir', default='data/LJSpeech-1.1/wavs')
    parser.add_argument('--input_train_metafile', default='data/LJSpeech-1.1/metadata_ljspeech.csv')
    parser.add_argument('--input_valid_metafile', default='data/LJSpeech-1.1/metadata_test_ljspeech.csv')
    parser.add_argument('--inference', default=False, action='store_true')
    parser.add_argument('--cps', default='cp_melgan')
    parser.add_argument('--cp_g', default='') # ex) cp_mgt_01/g_100.pth
    parser.add_argument('--cp_d', default='') # ex) cp_mgt_01/d_100.pth
    parser.add_argument('--config', default='hparams.json')
    parser.add_argument('--training_epochs', default=5000, type=int)
    parser.add_argument('--stdout_interval', default=1, type=int)
    parser.add_argument('--checkpoint_interval', default=5000, type=int)
    parser.add_argument('--summary_interval', default=100, type=int)
    parser.add_argument('--validation_interval', default=1000, type=int)

    a = parser.parse_args()

    with open(a.config) as f:
        data = f.read()

    global h
    json_config = json.loads(data)
    h = AttrDict(json_config)
    build_env(a.config, 'config.json', a.cps)

    torch.manual_seed(h.seed)
    global device
    if torch.cuda.is_available():
        torch.cuda.manual_seed(h.seed)
        device = torch.device('cuda')
        h.num_gpus = torch.cuda.device_count()
    else:
        device = torch.device('cpu')

    fit(a, a.training_epochs)
Example #11
def get_model(path, device):
    config_file = os.path.join(os.path.split(path)[0], 'config.json')
    with open(config_file) as f:
        data = f.read()

    json_config = json.loads(data)
    h = AttrDict(json_config)

    def load_checkpoint(filepath, device):
        assert os.path.isfile(filepath)
        print("Loading '{}'".format(filepath))
        checkpoint_dict = torch.load(filepath, map_location=device)
        print("Complete.")
        return checkpoint_dict

    generator = Generator(h).to(device)

    state_dict_g = load_checkpoint(path, device)
    generator.load_state_dict(state_dict_g['generator'])
    generator.eval()
    generator.remove_weight_norm()

    return generator
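A minimal usage sketch (the checkpoint path and the mel variable below are illustrative assumptions, not part of the example):

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
generator = get_model('cp_hifigan/g_02500000', device)  # hypothetical checkpoint path
with torch.no_grad():
    audio = generator(mel.to(device))  # mel: (1, num_mels, frames) float tensor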
Example #12
File: vocoder.py Project: thuhcsi/hifi-gan
def load_config(self, config_file):
    with open(config_file) as f:
        data = f.read()
        json_config = json.loads(data)
    return AttrDict(json_config)