def gpu_decode(feat_list, gpu):
    # set default gpu and do not track gradient
    torch.cuda.set_device(gpu)
    torch.set_grad_enabled(False)

    # define model and load parameters
    if config.use_upsampling_layer:
        upsampling_factor = config.upsampling_factor
    else:
        upsampling_factor = 0
    model = WaveNet(n_quantize=config.n_quantize,
                    n_aux=config.n_aux,
                    n_resch=config.n_resch,
                    n_skipch=config.n_skipch,
                    dilation_depth=config.dilation_depth,
                    dilation_repeat=config.dilation_repeat,
                    kernel_size=config.kernel_size,
                    upsampling_factor=upsampling_factor)
    model.load_state_dict(
        torch.load(args.checkpoint,
                   map_location=lambda storage, loc: storage)["model"])
    model.eval()
    model.cuda()

    # define generator
    generator = decode_generator(
        feat_list,
        batch_size=args.batch_size,
        feature_type=config.feature_type,
        wav_transform=wav_transform,
        feat_transform=feat_transform,
        upsampling_factor=config.upsampling_factor,
        use_upsampling_layer=config.use_upsampling_layer,
        use_speaker_code=config.use_speaker_code)

    # decode
    if args.batch_size > 1:
        for feat_ids, (batch_x, batch_h, n_samples_list) in generator:
            logging.info("decoding start")
            samples_list = model.batch_fast_generate(
                batch_x, batch_h, n_samples_list, args.intervals)
            for feat_id, samples in zip(feat_ids, samples_list):
                wav = decode_mu_law(samples, config.n_quantize)
                sf.write(args.outdir + "/" + feat_id + ".wav",
                         wav, args.fs, "PCM_16")
                logging.info("wrote %s.wav in %s." % (feat_id, args.outdir))
    else:
        for feat_id, (x, h, n_samples) in generator:
            logging.info("decoding %s (length = %d)" % (feat_id, n_samples))
            samples = model.fast_generate(x, h, n_samples, args.intervals)
            wav = decode_mu_law(samples, config.n_quantize)
            sf.write(args.outdir + "/" + feat_id + ".wav",
                     wav, args.fs, "PCM_16")
            logging.info("wrote %s.wav in %s." % (feat_id, args.outdir))
def main():
    args = parse_args()
    cfg.resume = args.resume
    cfg.exp_name = args.exp
    cfg.work_root = '/zhzhao/code/wavenet_torch/torch_lyuan/exp_result/'
    cfg.workdir = cfg.work_root + args.exp + '/debug'
    cfg.sparse_mode = args.sparse_mode
    cfg.batch_size = args.batch_size
    cfg.lr = args.lr
    cfg.load_from = args.load_from
    cfg.save_excel = args.save_excel

    weights_dir = os.path.join(cfg.workdir, 'weights')
    check_and_mkdir(weights_dir)

    print('initial training...')
    print(f'work_dir:{cfg.workdir}, \n\
        pretrained: {cfg.load_from}, \n\
        batch_size: {cfg.batch_size}, \n\
        lr : {cfg.lr}, \n\
        epochs : {cfg.epochs}, \n\
        sparse : {cfg.sparse_mode}')

    writer = SummaryWriter(log_dir=cfg.workdir + '/runs')

    # build train data
    vctk_train = VCTK(cfg, 'train')
    train_loader = DataLoader(vctk_train, batch_size=cfg.batch_size,
                              num_workers=4, shuffle=True, pin_memory=True)
    vctk_val = VCTK(cfg, 'val')
    val_loader = DataLoader(vctk_val, batch_size=cfg.batch_size,
                            num_workers=4, shuffle=False, pin_memory=True)

    # build model
    model = WaveNet(num_classes=28, channels_in=40, dilations=[1, 2, 4, 8, 16])
    model = nn.DataParallel(model)
    model.cuda()
    model.train()

    # build loss
    loss_fn = nn.CTCLoss(blank=27)

    if cfg.resume and os.path.exists(cfg.workdir + '/weights/best.pth'):
        model.load_state_dict(torch.load(cfg.workdir + '/weights/best.pth'), strict=True)
        print("loading", cfg.workdir + '/weights/best.pth')
        cfg.load_from = cfg.workdir + '/weights/best.pth'

    # despite the name, this is the Adam optimizer that train() steps each iteration
    scheduler = optim.Adam(model.parameters(), lr=cfg.lr, eps=1e-4)

    train(train_loader, scheduler, model, loss_fn, val_loader, writer)
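# Note: loss_fn = nn.CTCLoss(blank=27) is only constructed above; the actual call lives
# in train(), which is defined elsewhere. The helper below is an illustrative sketch of
# how torch.nn.CTCLoss expects its inputs (log-probabilities of shape (T, N, C) plus
# per-sequence lengths); all variable names and sizes here are examples, not repo code.
def ctc_loss_usage_sketch():
    """Illustrative only: shows the tensor shapes nn.CTCLoss expects."""
    ctc = nn.CTCLoss(blank=27)
    log_probs = torch.randn(50, 2, 28).log_softmax(2)          # (T=50, N=2, C=28)
    targets = torch.randint(0, 27, (2, 20), dtype=torch.long)  # labels 0..26; 27 is blank
    input_lengths = torch.full((2,), 50, dtype=torch.long)
    target_lengths = torch.full((2,), 20, dtype=torch.long)
    return ctc(log_probs, targets, input_lengths, target_lengths)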
def gpu_decode(feat_list, gpu):
    with torch.cuda.device(gpu):
        # define model and load parameters
        model = WaveNet(n_quantize=config.n_quantize,
                        n_aux=config.n_aux,
                        n_resch=config.n_resch,
                        n_skipch=config.n_skipch,
                        dilation_depth=config.dilation_depth,
                        dilation_repeat=config.dilation_repeat,
                        kernel_size=config.kernel_size,
                        upsampling_factor=config.upsampling_factor)
        model.load_state_dict(
            torch.load(args.checkpoint,
                       map_location=lambda storage, loc: storage.cuda(gpu))["model"])
        model.eval()
        model.cuda()
        torch.backends.cudnn.benchmark = True

        # define generator
        generator = decode_generator(
            feat_list,
            batch_size=args.batch_size,
            wav_transform=wav_transform,
            feat_transform=feat_transform,
            use_speaker_code=config.use_speaker_code,
            upsampling_factor=config.upsampling_factor)

        # decode
        if args.batch_size > 1:
            for feat_ids, (batch_x, batch_h, n_samples_list) in generator:
                logging.info("decoding start")
                samples_list = model.batch_fast_generate(
                    batch_x, batch_h, n_samples_list, args.intervals)
                for feat_id, samples in zip(feat_ids, samples_list):
                    wav = decode_mu_law(samples, config.n_quantize)
                    sf.write(args.outdir + "/" + feat_id + ".wav",
                             wav, args.fs, "PCM_16")
                    logging.info("wrote %s.wav in %s." % (feat_id, args.outdir))
        else:
            for feat_id, (x, h, n_samples) in generator:
                logging.info("decoding %s (length = %d)" % (feat_id, n_samples))
                samples = model.fast_generate(x, h, n_samples, args.intervals)
                wav = decode_mu_law(samples, config.n_quantize)
                sf.write(args.outdir + "/" + feat_id + ".wav",
                         wav, args.fs, "PCM_16")
                logging.info("wrote %s.wav in %s." % (feat_id, args.outdir))
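# Both decode paths above convert the model's categorical output back to a waveform with
# decode_mu_law(), which is defined elsewhere in the repo. The helper below is only a
# minimal sketch of standard inverse mu-law companding, assuming `samples` is a NumPy
# array of integer class indices in [0, n_quantize); the repo's helper may differ in detail.
import numpy as np

def decode_mu_law_sketch(samples, n_quantize=256):
    """Map quantized indices back to a float waveform in [-1, 1]."""
    mu = n_quantize - 1
    y = 2.0 * samples.astype(np.float32) / mu - 1.0               # indices -> [-1, 1]
    return np.sign(y) * ((1.0 + mu) ** np.abs(y) - 1.0) / mu      # inverse companding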
def main(args):
    print('Starting')
    matplotlib.use('agg')
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu

    checkpoints = args.checkpoint.parent.glob(args.checkpoint.name + '_*.pth')
    checkpoints = [c for c in checkpoints if extract_id(c) in args.decoders]
    assert len(checkpoints) >= 1, "No checkpoints found."

    model_args = torch.load(args.model.parent / 'args.pth')[0]
    encoder = wavenet_models.Encoder(model_args)
    encoder.load_state_dict(torch.load(checkpoints[0])['encoder_state'])
    encoder.eval()
    encoder = encoder.cuda()

    decoders = []
    decoder_ids = []
    for checkpoint in checkpoints:
        decoder = WaveNet(model_args)
        decoder.load_state_dict(torch.load(checkpoint)['decoder_state'])
        decoder.eval()
        decoder = decoder.cuda()
        if args.py:
            decoder = WavenetGenerator(decoder, args.batch_size, wav_freq=args.rate)
        else:
            decoder = NVWavenetGenerator(decoder, args.rate * (args.split_size // 20),
                                         args.batch_size, 3)
        decoders += [decoder]
        decoder_ids += [extract_id(checkpoint)]

    xs = []
    assert args.output_next_to_orig ^ (args.output is not None)

    if len(args.files) == 1 and args.files[0].is_dir():
        top = args.files[0]
        file_paths = list(top.glob('**/*.wav')) + list(top.glob('**/*.h5'))
    else:
        file_paths = args.files

    if not args.skip_filter:
        file_paths = [f for f in file_paths if '_' not in str(f.name)]

    for file_path in file_paths:
        if file_path.suffix == '.wav':
            data, rate = librosa.load(file_path, sr=16000)
            assert rate == 16000
            data = utils.mu_law(data)
        elif file_path.suffix == '.h5':
            data = utils.mu_law(h5py.File(file_path, 'r')['wav'][:] / (2 ** 15))
            if data.shape[-1] % args.rate != 0:
                data = data[:-(data.shape[-1] % args.rate)]
            assert data.shape[-1] % args.rate == 0
        else:
            raise Exception(f'Unsupported filetype {file_path}')

        if args.sample_len:
            data = data[:args.sample_len]
        else:
            args.sample_len = len(data)
        xs.append(torch.tensor(data).unsqueeze(0).float().cuda())

    xs = torch.stack(xs).contiguous()
    print(f'xs size: {xs.size()}')

    def save(x, decoder_ix, filepath):
        wav = utils.inv_mu_law(x.cpu().numpy())
        print(f'X size: {x.shape}')
        print(f'X min: {x.min()}, max: {x.max()}')
        if args.output_next_to_orig:
            save_audio(wav.squeeze(),
                       filepath.parent / f'{filepath.stem}_{decoder_ix}.wav',
                       rate=args.rate)
        else:
            save_audio(wav.squeeze(),
                       args.output / str(extract_id(args.model)) / str(args.update)
                       / filepath.with_suffix('.wav').name,
                       rate=args.rate)

    yy = {}
    with torch.no_grad():
        zz = []
        for xs_batch in torch.split(xs, args.batch_size):
            zz += [encoder(xs_batch)]
        zz = torch.cat(zz, dim=0)

        with utils.timeit("Generation timer"):
            for i, decoder_id in enumerate(decoder_ids):
                yy[decoder_id] = []
                decoder = decoders[i]
                for zz_batch in torch.split(zz, args.batch_size):
                    print(zz_batch.shape)
                    splits = torch.split(zz_batch, args.split_size, -1)
                    audio_data = []
                    decoder.reset()
                    for cond in tqdm.tqdm(splits):
                        audio_data += [decoder.generate(cond).cpu()]
                    audio_data = torch.cat(audio_data, -1)
                    yy[decoder_id] += [audio_data]
                yy[decoder_id] = torch.cat(yy[decoder_id], dim=0)
                del decoder

    for decoder_ix, decoder_result in yy.items():
        for sample_result, filepath in zip(decoder_result, file_paths):
            save(sample_result, decoder_ix, filepath)
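# utils.timeit("Generation timer") above is used as a context manager around the
# generation loop; its implementation is not shown in this file. A minimal stand-in with
# the same usage (name and log format are assumptions, not the repo's utils module):
import contextlib
import time

@contextlib.contextmanager
def timeit_sketch(label):
    """Print how long the wrapped block took."""
    start = time.time()
    yield
    print(f'{label}: {time.time() - start:.2f}s elapsed')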
def main():
    parser = argparse.ArgumentParser()
    # path setting
    parser.add_argument("--waveforms", required=True, type=str,
                        help="directory or list of wav files")
    parser.add_argument("--feats", required=True, type=str,
                        help="directory or list of aux feat files")
    parser.add_argument("--stats", required=True, type=str,
                        help="hdf5 file including statistics")
    parser.add_argument("--expdir", required=True, type=str,
                        help="directory to save the model")
    # network structure setting
    parser.add_argument("--n_quantize", default=256, type=int,
                        help="number of quantization")
    parser.add_argument("--n_aux", default=28, type=int,
                        help="number of dimension of aux feats")
    parser.add_argument("--n_resch", default=512, type=int,
                        help="number of channels of residual output")
    parser.add_argument("--n_skipch", default=256, type=int,
                        help="number of channels of skip output")
    parser.add_argument("--dilation_depth", default=10, type=int,
                        help="depth of dilation")
    parser.add_argument("--dilation_repeat", default=1, type=int,
                        help="number of repeating of dilation")
    parser.add_argument("--kernel_size", default=2, type=int,
                        help="kernel size of dilated causal convolution")
    parser.add_argument("--upsampling_factor", default=0, type=int,
                        help="upsampling factor of aux features "
                             "(if set 0, do not apply)")
    parser.add_argument("--use_speaker_code", default=False, type=strtobool,
                        help="flag to use speaker code")
    # network training setting
    parser.add_argument("--lr", default=1e-4, type=float,
                        help="learning rate")
    parser.add_argument("--weight_decay", default=0.0, type=float,
                        help="weight decay coefficient")
    parser.add_argument("--batch_size", default=20000, type=int,
                        help="batch size (if set 0, utterance batch will be used)")
    parser.add_argument("--iters", default=200000, type=int,
                        help="number of iterations")
    # other setting
    parser.add_argument("--checkpoints", default=10000, type=int,
                        help="how frequent saving model")
    parser.add_argument("--intervals", default=100, type=int,
                        help="log interval")
    parser.add_argument("--seed", default=1, type=int,
                        help="seed number")
    parser.add_argument("--resume", default=None, type=str,
                        help="model path to restart training")
    parser.add_argument("--verbose", default=1, type=int,
                        help="log level")
    args = parser.parse_args()

    # make experimental directory
    if not os.path.exists(args.expdir):
        os.makedirs(args.expdir)

    # set log level
    if args.verbose == 1:
        logging.basicConfig(
            level=logging.INFO,
            format='%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s',
            datefmt='%m/%d/%Y %I:%M:%S',
            filename=args.expdir + "/train.log")
        logging.getLogger().addHandler(logging.StreamHandler())
    elif args.verbose > 1:
        logging.basicConfig(
            level=logging.DEBUG,
            format='%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s',
            datefmt='%m/%d/%Y %I:%M:%S',
            filename=args.expdir + "/train.log")
        logging.getLogger().addHandler(logging.StreamHandler())
    else:
        logging.basicConfig(
            level=logging.WARN,
            format='%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s',
            datefmt='%m/%d/%Y %I:%M:%S',
            filename=args.expdir + "/train.log")
        logging.getLogger().addHandler(logging.StreamHandler())
        logging.warning("logging is disabled.")

    # fix seed
    os.environ['PYTHONHASHSEED'] = str(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    # save args as conf
    torch.save(args, args.expdir + "/model.conf")

    # define network
    model = WaveNet(n_quantize=args.n_quantize,
                    n_aux=args.n_aux,
                    n_resch=args.n_resch,
                    n_skipch=args.n_skipch,
                    dilation_depth=args.dilation_depth,
                    dilation_repeat=args.dilation_repeat,
                    kernel_size=args.kernel_size,
                    upsampling_factor=args.upsampling_factor)
    logging.info(model)
    model.apply(initialize)
    model.train()

    # define loss and optimizer
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=args.lr,
                                 weight_decay=args.weight_decay)
    criterion = nn.CrossEntropyLoss()

    # define transforms
    scaler = StandardScaler()
    scaler.mean_ = read_hdf5(args.stats, "/mean")
    scaler.scale_ = read_hdf5(args.stats, "/scale")
    wav_transform = transforms.Compose(
        [lambda x: encode_mu_law(x, args.n_quantize)])
    feat_transform = transforms.Compose([lambda x: scaler.transform(x)])

    # define generator
    if os.path.isdir(args.waveforms):
        filenames = sorted(find_files(args.waveforms, "*.wav", use_dir_name=False))
        wav_list = [args.waveforms + "/" + filename for filename in filenames]
        feat_list = [args.feats + "/" + filename.replace(".wav", ".h5")
                     for filename in filenames]
    elif os.path.isfile(args.waveforms):
        wav_list = read_txt(args.waveforms)
        feat_list = read_txt(args.feats)
    else:
        logging.error("--waveforms should be directory or list.")
        sys.exit(1)
    assert len(wav_list) == len(feat_list)
    logging.info("number of training data = %d." % len(wav_list))
    generator = train_generator(wav_list, feat_list,
                                receptive_field=model.receptive_field,
                                batch_size=args.batch_size,
                                wav_transform=wav_transform,
                                feat_transform=feat_transform,
                                shuffle=True,
                                upsampling_factor=args.upsampling_factor,
                                use_speaker_code=args.use_speaker_code)
    while not generator.queue.full():
        time.sleep(0.1)

    # resume
    if args.resume is not None:
        checkpoint = torch.load(args.resume)
        model.load_state_dict(checkpoint["model"])
        optimizer.load_state_dict(checkpoint["optimizer"])
        iterations = checkpoint["iterations"]
        logging.info("restored from %d-iter checkpoint." % iterations)
    else:
        iterations = 0

    # send to gpu
    if torch.cuda.is_available():
        model.cuda()
        criterion.cuda()
    else:
        logging.error("gpu is not available. please check the setting.")
        sys.exit(1)

    # train
    loss = 0
    total = 0
    for i in six.moves.range(iterations, args.iters):
        start = time.time()
        (batch_x, batch_h), batch_t = generator.next()
        batch_output = model(batch_x, batch_h)[0]
        batch_loss = criterion(batch_output[model.receptive_field:],
                               batch_t[model.receptive_field:])
        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()
        loss += batch_loss.data[0]
        total += time.time() - start
        logging.debug("batch loss = %.3f (%.3f sec / batch)" % (
            batch_loss.data[0], time.time() - start))

        # report progress
        if (i + 1) % args.intervals == 0:
            logging.info("(iter:%d) average loss = %.6f (%.3f sec / batch)" % (
                i + 1, loss / args.intervals, total / args.intervals))
            logging.info("estimated required time = "
                         "{0.days:02}:{0.hours:02}:{0.minutes:02}:{0.seconds:02}".format(
                             relativedelta(seconds=int((args.iters - (i + 1)) * (total / args.intervals)))))
            loss = 0
            total = 0

        # save intermediate model
        if (i + 1) % args.checkpoints == 0:
            save_checkpoint(args.expdir, model, optimizer, i + 1)

    # save final model
    model.cpu()
    torch.save({"model": model.state_dict()},
               args.expdir + "/checkpoint-final.pkl")
    logging.info("final checkpoint created.")
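# The training loop above crops the first `model.receptive_field` samples before the
# loss, since those outputs only see zero-padded history. The attribute itself is set by
# the WaveNet class elsewhere; the helper below is only a sketch of the usual computation
# for a stack of dilated causal convolutions (names and formula are assumptions, not the
# repo's code).
def receptive_field_sketch(dilation_depth=10, dilation_repeat=1, kernel_size=2):
    dilations = [2 ** d for d in range(dilation_depth)] * dilation_repeat
    return (kernel_size - 1) * sum(dilations) + 1

# With the defaults above (depth 10, repeat 1, kernel 2) this gives (2 - 1) * 1023 + 1 = 1024.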
def main():
    args = parse_args()
    cfg.resume = args.resume
    cfg.exp_name = args.exp
    cfg.work_root = '/zhzhao/code/wavenet_torch/torch_lyuan/exp_result/'
    cfg.workdir = cfg.work_root + args.exp + '/debug'
    cfg.sparse_mode = args.sparse_mode
    cfg.batch_size = args.batch_size
    cfg.lr = args.lr
    cfg.load_from = args.load_from
    cfg.save_excel = args.save_excel

    if args.find_pattern == True:
        cfg.find_pattern_num = 16
        cfg.find_pattern_shape = [int(args.find_pattern_shape.split('_')[0]),
                                  int(args.find_pattern_shape.split('_')[1])]
        cfg.find_zero_threshold = float(args.find_pattern_para.split('_')[0])
        cfg.find_score_threshold = int(args.find_pattern_para.split('_')[1])
        if int(cfg.find_pattern_shape[0] * cfg.find_pattern_shape[1]) <= cfg.find_score_threshold:
            exit()

    if args.skip_exist == True:
        if os.path.exists(cfg.workdir):
            exit()

    print('initial training...')
    print(f'work_dir:{cfg.workdir}, \n\
        pretrained: {cfg.load_from}, \n\
        batch_size: {cfg.batch_size}, \n\
        lr : {cfg.lr}, \n\
        epochs : {cfg.epochs}, \n\
        sparse : {cfg.sparse_mode}')

    writer = SummaryWriter(log_dir=cfg.workdir + '/runs')

    # build train data
    vctk_train = VCTK(cfg, 'train')
    train_loader = DataLoader(vctk_train, batch_size=cfg.batch_size,
                              num_workers=4, shuffle=True, pin_memory=True)
    # train_loader = dataset.create("data/v28/train.record", cfg.batch_size, repeat=True)
    vctk_val = VCTK(cfg, 'val')
    if args.test_acc_cmodel == True:
        val_loader = DataLoader(vctk_val, batch_size=1,
                                num_workers=4, shuffle=False, pin_memory=True)
    else:
        val_loader = DataLoader(vctk_val, batch_size=cfg.batch_size,
                                num_workers=4, shuffle=False, pin_memory=True)

    # build model
    model = WaveNet(num_classes=28, channels_in=40, dilations=[1, 2, 4, 8, 16])
    model = nn.DataParallel(model)
    model.cuda()

    # re-initialize all weight matrices except bn and bias parameters
    name_list = list()
    para_list = list()
    for name, para in model.named_parameters():
        name_list.append(name)
        para_list.append(para)

    a = model.state_dict()
    for i, name in enumerate(name_list):
        if name.split(".")[-2] != "bn" \
                and name.split(".")[-2] != "bn2" \
                and name.split(".")[-2] != "bn3" \
                and name.split(".")[-1] != "bias":
            raw_w = para_list[i]
            nn.init.xavier_normal_(raw_w, gain=1.0)
            a[name] = raw_w
    model.load_state_dict(a)

    weights_dir = os.path.join(cfg.workdir, 'weights')
    if not os.path.exists(weights_dir):
        os.mkdir(weights_dir)
    if not os.path.exists(cfg.vis_dir):
        os.mkdir(cfg.vis_dir)
    if args.vis_pattern == True or args.vis_mask == True:
        cfg.vis_dir = os.path.join(cfg.vis_dir, cfg.exp_name)
        if not os.path.exists(cfg.vis_dir):
            os.mkdir(cfg.vis_dir)

    model.train()

    if cfg.resume and os.path.exists(cfg.workdir + '/weights/best.pth'):
        model.load_state_dict(torch.load(cfg.workdir + '/weights/best.pth'), strict=True)
        print("loading", cfg.workdir + '/weights/best.pth')
        cfg.load_from = cfg.workdir + '/weights/best.pth'

    if args.test_acc == True:
        if os.path.exists(cfg.load_from):
            model.load_state_dict(torch.load(cfg.load_from), strict=True)
            print("loading", cfg.load_from)
        else:
            print("Error: model file not exists, ", cfg.load_from)
            exit()
    else:
        if os.path.exists(cfg.load_from):
            model.load_state_dict(torch.load(cfg.load_from), strict=True)
            print("loading", cfg.load_from)

    # Export the model
    print("exporting onnx ...")
    model.eval()
    batch_size = 1
    x = torch.randn(batch_size, 40, 720, requires_grad=True).cuda()
    torch.onnx.export(model.module,             # model being run
                      x,                        # model input (or a tuple for multiple inputs)
                      "wavenet.onnx",           # where to save the model (can be a file or file-like object)
                      export_params=True,       # store the trained parameter weights inside the model file
                      opset_version=10,         # the ONNX version to export the model to
                      do_constant_folding=True,  # whether to execute constant folding for optimization
                      input_names=['input'],    # the model's input names
                      output_names=['output'],  # the model's output names
                      dynamic_axes={'input': {0: 'batch_size'},    # variable length axes
                                    'output': {0: 'batch_size'}})

    if os.path.exists(args.load_from_h5):
        # model.load_state_dict(torch.load(args.load_from_h5), strict=True)
        print("loading", args.load_from_h5)
        model.train()
        model_dict = model.state_dict()
        print(model_dict.keys())
        # first convert the numpy parameter values to tensors
        pretrained_dict = dd.io.load(args.load_from_h5)
        print(pretrained_dict.keys())
        new_pre_dict = {}
        for k, v in pretrained_dict.items():
            new_pre_dict[k] = torch.Tensor(v)
        # update
        model_dict.update(new_pre_dict)
        # load
        model.load_state_dict(model_dict)

    if args.find_pattern == True:
        # cfg.find_pattern_num = 16
        # cfg.find_pattern_shape = [int(args.find_pattern_shape.split('_')[0]),
        #                           int(args.find_pattern_shape.split('_')[1])]
        # cfg.find_zero_threshold = float(args.find_pattern_para.split('_')[0])
        # cfg.find_score_threshold = int(args.find_pattern_para.split('_')[1])
        # if cfg.find_pattern_shape[0] * cfg.find_pattern_shape[0] <= cfg.find_score_threshold:
        #     exit()
        name_list = list()
        para_list = list()
        for name, para in model.named_parameters():
            name_list.append(name)
            para_list.append(para)

        a = model.state_dict()
        for i, name in enumerate(name_list):
            if name.split(".")[-2] != "bn" \
                    and name.split(".")[-2] != "bn2" \
                    and name.split(".")[-2] != "bn3" \
                    and name.split(".")[-1] != "bias":
                raw_w = para_list[i]
                if raw_w.size(0) == 128 and raw_w.size(1) == 128:
                    patterns, pattern_match_num, pattern_coo_nnz, pattern_nnz, pattern_inner_nnz \
                        = find_pattern_by_similarity(raw_w,
                                                     cfg.find_pattern_num,
                                                     cfg.find_pattern_shape,
                                                     cfg.find_zero_threshold,
                                                     cfg.find_score_threshold)
                    pattern_num_memory_dict, pattern_num_cal_num_dict, pattern_num_coo_nnz_dict \
                        = pattern_curve_analyse(raw_w.shape,
                                                cfg.find_pattern_shape,
                                                patterns,
                                                pattern_match_num,
                                                pattern_coo_nnz,
                                                pattern_nnz,
                                                pattern_inner_nnz)
                    write_pattern_curve_analyse(os.path.join(cfg.work_root, args.save_pattern_count_excel),
                                                cfg.exp_name + " " + args.find_pattern_shape + " " + args.find_pattern_para,
                                                patterns, pattern_match_num, pattern_coo_nnz, pattern_nnz,
                                                pattern_inner_nnz,
                                                pattern_num_memory_dict, pattern_num_cal_num_dict, pattern_num_coo_nnz_dict)
                    # write_pattern_count(os.path.join(cfg.work_root, args.save_pattern_count_excel),
                    #                     cfg.exp_name + " " + args.find_pattern_shape + " " + args.find_pattern_para,
                    #                     all_nnzs.values(), all_patterns.values())
        exit()

    if cfg.sparse_mode == 'sparse_pruning':
        cfg.sparsity = args.sparsity
        print(f'sparse_pruning {cfg.sparsity}')

    elif cfg.sparse_mode == 'pattern_pruning':
        print(args.pattern_para)
        pattern_num = int(args.pattern_para.split('_')[0])
        pattern_shape = [int(args.pattern_para.split('_')[1]), int(args.pattern_para.split('_')[2])]
        pattern_nnz = int(args.pattern_para.split('_')[3])
        print(f'pattern_pruning {pattern_num} [{pattern_shape[0]}, {pattern_shape[1]}] {pattern_nnz}')
        cfg.patterns = generate_pattern(pattern_num, pattern_shape, pattern_nnz)
        cfg.pattern_mask = generate_pattern_mask(model, cfg.patterns)

    elif cfg.sparse_mode == 'coo_pruning':
        cfg.coo_shape = [int(args.coo_para.split('_')[0]), int(args.coo_para.split('_')[1])]
        cfg.coo_nnz = int(args.coo_para.split('_')[2])
        # cfg.patterns = generate_pattern(pattern_num, pattern_shape, pattern_nnz)
        print(f'coo_pruning [{cfg.coo_shape[0]}, {cfg.coo_shape[1]}] {cfg.coo_nnz}')

    elif cfg.sparse_mode == 'ptcoo_pruning':
        cfg.pattern_num = int(args.pattern_para.split('_')[0])
        cfg.pattern_shape = [int(args.ptcoo_para.split('_')[1]), int(args.ptcoo_para.split('_')[2])]
        cfg.pt_nnz = int(args.ptcoo_para.split('_')[3])
        cfg.coo_nnz = int(args.ptcoo_para.split('_')[4])
        cfg.patterns = generate_pattern(cfg.pattern_num, cfg.pattern_shape, cfg.pt_nnz)
        cfg.pattern_mask = generate_pattern_mask(model, cfg.patterns)
        print(f'ptcoo_pruning {cfg.pattern_num} [{cfg.pattern_shape[0]}, {cfg.pattern_shape[1]}] {cfg.pt_nnz} {cfg.coo_nnz}')

    elif cfg.sparse_mode == 'find_retrain':
        cfg.pattern_num = int(args.find_retrain_para.split('_')[0])
        cfg.pattern_shape = [int(args.find_retrain_para.split('_')[1]), int(args.find_retrain_para.split('_')[2])]
        cfg.pattern_nnz = int(args.find_retrain_para.split('_')[3])
        cfg.coo_num = float(args.find_retrain_para.split('_')[4])
        cfg.layer_or_model_wise = str(args.find_retrain_para.split('_')[5])
        # cfg.fd_rtn_pattern_candidates = generate_complete_pattern_set(
        #     cfg.pattern_shape, cfg.pattern_nnz)
        print(f'find_retrain {cfg.pattern_num} [{cfg.pattern_shape[0]}, {cfg.pattern_shape[1]}] {cfg.pattern_nnz} {cfg.coo_num} {cfg.layer_or_model_wise}')

    elif cfg.sparse_mode == 'hcgs_pruning':
        print(args.pattern_para)
        cfg.block_shape = [int(args.hcgs_para.split('_')[0]), int(args.hcgs_para.split('_')[1])]
        cfg.reserve_num1 = int(args.hcgs_para.split('_')[2])
        cfg.reserve_num2 = int(args.hcgs_para.split('_')[3])
        print(f'hcgs_pruning {cfg.reserve_num1}/8 {cfg.reserve_num2}/16')
        cfg.hcgs_mask = generate_hcgs_mask(model, cfg.block_shape, cfg.reserve_num1, cfg.reserve_num2)

    if args.vis_mask == True:
        name_list = list()
        para_list = list()
        for name, para in model.named_parameters():
            name_list.append(name)
            para_list.append(para)

        for i, name in enumerate(name_list):
            if name.split(".")[-2] != "bn" \
                    and name.split(".")[-2] != "bn2" \
                    and name.split(".")[-2] != "bn3" \
                    and name.split(".")[-1] != "bias":
                raw_w = para_list[i]
                zero = torch.zeros_like(raw_w)
                one = torch.ones_like(raw_w)
                mask = torch.where(raw_w == 0, zero, one)
                vis.save_visualized_mask(mask, name)
        exit()

    if args.vis_pattern == True:
        pattern_count_dict = find_pattern_model(model, [8, 8])
        patterns = list(pattern_count_dict.keys())
        counts = list(pattern_count_dict.values())
        print(len(patterns))
        print(counts)
        vis.save_visualized_pattern(patterns)
        exit()

    # build loss
    loss_fn = nn.CTCLoss(blank=27)
    # loss_fn = nn.CTCLoss()

    # build optimizer (named `scheduler` throughout this repo); without this line the
    # train() call below would raise a NameError
    scheduler = optim.Adam(model.parameters(), lr=cfg.lr, eps=1e-4)
    # scheduler = optim.lr_scheduler.MultiStepLR(train_step, milestones=[50, 150, 250], gamma=0.5)

    if args.test_acc == True:
        f1, val_loss, tps, preds, poses = test_acc(val_loader, model, loss_fn)
        write_test_acc(os.path.join(cfg.work_root, args.test_acc_excel),
                       cfg.exp_name, f1, val_loss, tps, preds, poses)
        exit()

    if args.test_acc_cmodel == True:
        f1, val_loss, tps, preds, poses = test_acc_cmodel(val_loader, model, loss_fn)
        write_test_acc(os.path.join(cfg.work_root, args.test_acc_excel),
                       cfg.exp_name, f1, val_loss, tps, preds, poses)
        exit()

    # train
    train(train_loader, scheduler, model, loss_fn, val_loader, writer)
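# The pruning branches above only prepare cfg.patterns / cfg.pattern_mask; the masks are
# applied later inside train(). The snippet below is a minimal, self-contained sketch of
# the idea behind pattern pruning (assumed semantics only, not the repo's generate_pattern
# / generate_pattern_mask implementation): every block of a weight matrix must reuse one
# of a small set of binary patterns, and the best-matching pattern is kept per block.
import torch

def pattern_mask_sketch(weight, patterns):
    """weight: (R, C) tensor; patterns: list of (bh, bw) 0/1 tensors, all the same shape."""
    bh, bw = patterns[0].shape
    mask = torch.zeros_like(weight)
    for r in range(0, weight.size(0) - bh + 1, bh):
        for c in range(0, weight.size(1) - bw + 1, bw):
            block = weight[r:r + bh, c:c + bw].abs()
            # choose the pattern that preserves the most weight magnitude in this block
            scores = [float((block * p).sum()) for p in patterns]
            mask[r:r + bh, c:c + bw] = patterns[scores.index(max(scores))]
    return mask  # multiply element-wise with the weight to prune it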
def main():
    args = parse_args()
    cfg.resume = args.resume
    cfg.exp_name = args.exp
    cfg.work_root = '/zhzhao/code/wavenet_torch/torch_lyuan/exp_result/'
    cfg.workdir = cfg.work_root + args.exp + '/debug'
    cfg.sparse_mode = args.sparse_mode
    cfg.batch_size = args.batch_size
    cfg.lr = args.lr
    cfg.load_from = args.load_from
    cfg.save_excel = args.save_excel

    if args.skip_exist == True:
        if os.path.exists(cfg.workdir):
            exit()

    print('initial training...')
    print(f'work_dir:{cfg.workdir}, \n\
        pretrained: {cfg.load_from}, \n\
        batch_size: {cfg.batch_size}, \n\
        lr : {cfg.lr}, \n\
        epochs : {cfg.epochs}, \n\
        sparse : {cfg.sparse_mode}')

    writer = SummaryWriter(log_dir=cfg.workdir + '/runs')

    # build train data
    vctk_train = VCTK(cfg, 'train')
    train_loader = DataLoader(vctk_train, batch_size=cfg.batch_size,
                              num_workers=8, shuffle=True, pin_memory=True)
    # train_loader = dataset.create("data/v28/train.record", cfg.batch_size, repeat=True)
    vctk_val = VCTK(cfg, 'val')
    val_loader = DataLoader(vctk_val, batch_size=cfg.batch_size,
                            num_workers=8, shuffle=False, pin_memory=True)

    # build model
    model = WaveNet(num_classes=28, channels_in=20, dilations=[1, 2, 4, 8, 16])
    model = nn.DataParallel(model)
    model.cuda()

    # re-initialize all weight matrices except bn and bias parameters
    name_list = list()
    para_list = list()
    for name, para in model.named_parameters():
        name_list.append(name)
        para_list.append(para)

    a = model.state_dict()
    for i, name in enumerate(name_list):
        if name.split(".")[-2] != "bn" and name.split(".")[-1] != "bias":
            raw_w = para_list[i]
            nn.init.xavier_normal_(raw_w, gain=1.0)
            a[name] = raw_w
    model.load_state_dict(a)

    weights_dir = os.path.join(cfg.workdir, 'weights')
    if not os.path.exists(weights_dir):
        os.mkdir(weights_dir)
    if not os.path.exists(cfg.vis_dir):
        os.mkdir(cfg.vis_dir)
    if args.vis_pattern == True or args.vis_mask == True:
        cfg.vis_dir = os.path.join(cfg.vis_dir, cfg.exp_name)
        if not os.path.exists(cfg.vis_dir):
            os.mkdir(cfg.vis_dir)

    model.train()

    if cfg.resume and os.path.exists(cfg.workdir + '/weights/best.pth'):
        model.load_state_dict(torch.load(cfg.workdir + '/weights/best.pth'), strict=False)
        print("loading", cfg.workdir + '/weights/best.pth')

    if os.path.exists(cfg.load_from):
        model.load_state_dict(torch.load(cfg.load_from), strict=False)
        print("loading", cfg.load_from)

    if os.path.exists(args.load_from_h5):
        # model.load_state_dict(torch.load(args.load_from_h5), strict=True)
        print("loading", args.load_from_h5)
        model.train()
        model_dict = model.state_dict()
        print(model_dict.keys())
        # first convert the numpy parameter values to tensors
        pretrained_dict = dd.io.load(args.load_from_h5)
        print(pretrained_dict.keys())
        new_pre_dict = {}
        for k, v in pretrained_dict.items():
            new_pre_dict[k] = torch.Tensor(v)
        # update
        model_dict.update(new_pre_dict)
        # load
        model.load_state_dict(model_dict)

    if args.find_pattern == True:
        cfg.find_pattern_num = 16
        cfg.find_pattern_shape = [int(args.find_pattern_shape.split('_')[0]),
                                  int(args.find_pattern_shape.split('_')[1])]
        cfg.find_zero_threshold = float(args.find_pattern_para.split('_')[0])
        cfg.find_score_threshold = int(args.find_pattern_para.split('_')[1])

        name_list = list()
        para_list = list()
        for name, para in model.named_parameters():
            name_list.append(name)
            para_list.append(para)

        a = model.state_dict()
        for i, name in enumerate(name_list):
            if name.split(".")[-2] != "bn" and name.split(".")[-1] != "bias":
                raw_w = para_list[i]
                if raw_w.size(0) == 128 and raw_w.size(1) == 128:
                    patterns, pattern_match_num, pattern_coo_nnz, pattern_nnz, pattern_inner_nnz \
                        = find_pattern_by_similarity(raw_w,
                                                     cfg.find_pattern_num,
                                                     cfg.find_pattern_shape,
                                                     cfg.find_zero_threshold,
                                                     cfg.find_score_threshold)
                    pattern_num_memory_dict, pattern_num_coo_nnz_dict \
                        = pattern_curve_analyse(raw_w.shape,
                                                cfg.find_pattern_shape,
                                                patterns,
                                                pattern_match_num,
                                                pattern_coo_nnz,
                                                pattern_nnz,
                                                pattern_inner_nnz)
                    write_pattern_curve_analyse(os.path.join(cfg.work_root, args.save_pattern_count_excel),
                                                cfg.exp_name + " " + args.find_pattern_shape + " " + args.find_pattern_para,
                                                patterns, pattern_match_num, pattern_coo_nnz, pattern_nnz,
                                                pattern_num_memory_dict, pattern_num_coo_nnz_dict)
                    # write_pattern_count(os.path.join(cfg.work_root, args.save_pattern_count_excel),
                    #                     cfg.exp_name + " " + args.find_pattern_shape + " " + args.find_pattern_para,
                    #                     all_nnzs.values(), all_patterns.values())
        exit()

    if cfg.sparse_mode == 'sparse_pruning':
        cfg.sparsity = args.sparsity
        print(f'sparse_pruning {cfg.sparsity}')

    elif cfg.sparse_mode == 'pattern_pruning':
        print(args.pattern_para)
        pattern_num = int(args.pattern_para.split('_')[0])
        pattern_shape = [int(args.pattern_para.split('_')[1]), int(args.pattern_para.split('_')[2])]
        pattern_nnz = int(args.pattern_para.split('_')[3])
        print(f'pattern_pruning {pattern_num} [{pattern_shape[0]}, {pattern_shape[1]}] {pattern_nnz}')
        cfg.patterns = generate_pattern(pattern_num, pattern_shape, pattern_nnz)
        cfg.pattern_mask = generate_pattern_mask(model, cfg.patterns)

    elif cfg.sparse_mode == 'coo_pruning':
        cfg.coo_shape = [int(args.coo_para.split('_')[0]), int(args.coo_para.split('_')[1])]
        cfg.coo_nnz = int(args.coo_para.split('_')[2])
        # cfg.patterns = generate_pattern(pattern_num, pattern_shape, pattern_nnz)
        print(f'coo_pruning [{cfg.coo_shape[0]}, {cfg.coo_shape[1]}] {cfg.coo_nnz}')

    elif cfg.sparse_mode == 'ptcoo_pruning':
        cfg.pattern_num = int(args.pattern_para.split('_')[0])
        cfg.pattern_shape = [int(args.ptcoo_para.split('_')[1]), int(args.ptcoo_para.split('_')[2])]
        cfg.pt_nnz = int(args.ptcoo_para.split('_')[3])
        cfg.coo_nnz = int(args.ptcoo_para.split('_')[4])
        cfg.patterns = generate_pattern(cfg.pattern_num, cfg.pattern_shape, cfg.pt_nnz)
        cfg.pattern_mask = generate_pattern_mask(model, cfg.patterns)
        print(f'ptcoo_pruning {cfg.pattern_num} [{cfg.pattern_shape[0]}, {cfg.pattern_shape[1]}] {cfg.pt_nnz} {cfg.coo_nnz}')

    if args.vis_mask == True:
        name_list = list()
        para_list = list()
        for name, para in model.named_parameters():
            name_list.append(name)
            para_list.append(para)

        for i, name in enumerate(name_list):
            if name.split(".")[-2] != "bn" and name.split(".")[-1] != "bias":
                raw_w = para_list[i]
                zero = torch.zeros_like(raw_w)
                one = torch.ones_like(raw_w)
                mask = torch.where(raw_w == 0, zero, one)
                vis.save_visualized_mask(mask, name)
        exit()

    if args.vis_pattern == True:
        pattern_count_dict = find_pattern_model(model, [8, 8])
        patterns = list(pattern_count_dict.keys())
        counts = list(pattern_count_dict.values())
        print(len(patterns))
        print(counts)
        vis.save_visualized_pattern(patterns)
        exit()

    # build loss
    loss_fn = nn.CTCLoss(blank=27)
    # loss_fn = nn.CTCLoss()

    # build optimizer (named `scheduler` here); needed for the train() call below
    scheduler = optim.Adam(model.parameters(), lr=cfg.lr, eps=1e-4)
    # scheduler = optim.lr_scheduler.MultiStepLR(train_step, milestones=[50, 150, 250], gamma=0.5)

    if args.test_acc == True:
        f1, val_loss, tps, preds, poses = test_acc(val_loader, model, loss_fn)
        write_test_acc(os.path.join(cfg.work_root, args.test_acc_excel),
                       cfg.exp_name, f1, val_loss, tps, preds, poses)
        exit()

    # train
    train(train_loader, scheduler, model, loss_fn, val_loader, writer)
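# vis.save_visualized_mask / vis.save_visualized_pattern are project helpers that are not
# shown in this listing. A minimal stand-in that renders a 2-D 0/1 weight mask as an image
# (purely illustrative; the output path and colormap are assumptions) could look like:
import os
import matplotlib
matplotlib.use('agg')
import matplotlib.pyplot as plt

def save_visualized_mask_sketch(mask, name, out_dir='./vis'):
    """Save a 2-D binary mask (torch tensor) as a grayscale PNG."""
    os.makedirs(out_dir, exist_ok=True)
    plt.imsave(os.path.join(out_dir, name.replace('.', '_') + '.png'),
               mask.detach().cpu().numpy(), cmap='gray')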
def main():
    parser = argparse.ArgumentParser()
    # path setting
    parser.add_argument("--waveforms", required=True, type=str,
                        help="directory or list of wav files")
    parser.add_argument("--feats", required=True, type=str,
                        help="directory or list of aux feat files")
    parser.add_argument("--stats", required=True, type=str,
                        help="hdf5 file including statistics")
    parser.add_argument("--expdir", required=True, type=str,
                        help="directory to save the model")
    parser.add_argument("--feature_type", default="world",
                        choices=["world", "melspc"], type=str,
                        help="feature type")
    # network structure setting
    parser.add_argument("--n_quantize", default=256, type=int,
                        help="number of quantization")
    parser.add_argument("--n_aux", default=28, type=int,
                        help="number of dimension of aux feats")
    parser.add_argument("--n_resch", default=512, type=int,
                        help="number of channels of residual output")
    parser.add_argument("--n_skipch", default=256, type=int,
                        help="number of channels of skip output")
    parser.add_argument("--dilation_depth", default=10, type=int,
                        help="depth of dilation")
    parser.add_argument("--dilation_repeat", default=1, type=int,
                        help="number of repeating of dilation")
    parser.add_argument("--kernel_size", default=2, type=int,
                        help="kernel size of dilated causal convolution")
    parser.add_argument("--upsampling_factor", default=80, type=int,
                        help="upsampling factor of aux features")
    parser.add_argument("--use_upsampling_layer", default=True, type=strtobool,
                        help="flag to use upsampling layer")
    parser.add_argument("--use_speaker_code", default=False, type=strtobool,
                        help="flag to use speaker code")
    # network training setting
    parser.add_argument("--lr", default=1e-4, type=float,
                        help="learning rate")
    parser.add_argument("--weight_decay", default=0.0, type=float,
                        help="weight decay coefficient")
    parser.add_argument("--batch_length", default=20000, type=int,
                        help="batch length (if set 0, utterance batch will be used)")
    parser.add_argument("--batch_size", default=1, type=int,
                        help="batch size (if use utterance batch, batch_size will be 1)")
    parser.add_argument("--iters", default=200000, type=int,
                        help="number of iterations")
    # other setting
    parser.add_argument("--checkpoints", default=10000, type=int,
                        help="how frequent saving model")
    parser.add_argument("--intervals", default=100, type=int,
                        help="log interval")
    parser.add_argument("--seed", default=1, type=int,
                        help="seed number")
    parser.add_argument("--resume", default=None, nargs="?", type=str,
                        help="model path to restart training")
    parser.add_argument("--n_gpus", default=1, type=int,
                        help="number of gpus")
    parser.add_argument("--verbose", default=1, type=int,
                        help="log level")
    args = parser.parse_args()

    # set log level
    if args.verbose == 1:
        logging.basicConfig(
            level=logging.INFO,
            format='%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s',
            datefmt='%m/%d/%Y %I:%M:%S')
    elif args.verbose > 1:
        logging.basicConfig(
            level=logging.DEBUG,
            format='%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s',
            datefmt='%m/%d/%Y %I:%M:%S')
    else:
        logging.basicConfig(
            level=logging.WARNING,
            format='%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s',
            datefmt='%m/%d/%Y %I:%M:%S')
        logging.warning("logging is disabled.")

    # show arguments
    for key, value in vars(args).items():
        logging.info("%s = %s" % (key, str(value)))

    # make experimental directory
    if not os.path.exists(args.expdir):
        os.makedirs(args.expdir)

    # fix seed
    os.environ['PYTHONHASHSEED'] = str(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    # fix slow computation of dilated conv
    # https://github.com/pytorch/pytorch/issues/15054#issuecomment-450191923
    torch.backends.cudnn.benchmark = True

    # save args as conf
    torch.save(args, args.expdir + "/model.conf")

    # define network
    if args.use_upsampling_layer:
        upsampling_factor = args.upsampling_factor
    else:
        upsampling_factor = 0
    model = WaveNet(n_quantize=args.n_quantize,
                    n_aux=args.n_aux,
                    n_resch=args.n_resch,
                    n_skipch=args.n_skipch,
                    dilation_depth=args.dilation_depth,
                    dilation_repeat=args.dilation_repeat,
                    kernel_size=args.kernel_size,
                    upsampling_factor=upsampling_factor)
    logging.info(model)
    model.apply(initialize)
    model.train()

    if args.n_gpus > 1:
        device_ids = range(args.n_gpus)
        model = torch.nn.DataParallel(model, device_ids)
        model.receptive_field = model.module.receptive_field
        if args.n_gpus > args.batch_size:
            logging.warning("batch size is less than number of gpus.")

    # define optimizer and loss
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=args.lr,
                                 weight_decay=args.weight_decay)
    criterion = nn.CrossEntropyLoss()

    # define transforms
    scaler = StandardScaler()
    scaler.mean_ = read_hdf5(args.stats, "/" + args.feature_type + "/mean")
    scaler.scale_ = read_hdf5(args.stats, "/" + args.feature_type + "/scale")
    wav_transform = transforms.Compose(
        [lambda x: encode_mu_law(x, args.n_quantize)])
    feat_transform = transforms.Compose([lambda x: scaler.transform(x)])

    # define generator
    if os.path.isdir(args.waveforms):
        filenames = sorted(find_files(args.waveforms, "*.wav", use_dir_name=False))
        wav_list = [args.waveforms + "/" + filename for filename in filenames]
        feat_list = [args.feats + "/" + filename.replace(".wav", ".h5")
                     for filename in filenames]
    elif os.path.isfile(args.waveforms):
        wav_list = read_txt(args.waveforms)
        feat_list = read_txt(args.feats)
    else:
        logging.error("--waveforms should be directory or list.")
        sys.exit(1)
    assert len(wav_list) == len(feat_list)
    logging.info("number of training data = %d." % len(wav_list))
    generator = train_generator(wav_list, feat_list,
                                receptive_field=model.receptive_field,
                                batch_length=args.batch_length,
                                batch_size=args.batch_size,
                                feature_type=args.feature_type,
                                wav_transform=wav_transform,
                                feat_transform=feat_transform,
                                shuffle=True,
                                upsampling_factor=args.upsampling_factor,
                                use_upsampling_layer=args.use_upsampling_layer,
                                use_speaker_code=args.use_speaker_code)

    # charge minibatch in queue
    while not generator.queue.full():
        time.sleep(0.1)

    # resume model and optimizer
    if args.resume is not None and len(args.resume) != 0:
        checkpoint = torch.load(args.resume,
                                map_location=lambda storage, loc: storage)
        iterations = checkpoint["iterations"]
        if args.n_gpus > 1:
            model.module.load_state_dict(checkpoint["model"])
        else:
            model.load_state_dict(checkpoint["model"])
        optimizer.load_state_dict(checkpoint["optimizer"])
        logging.info("restored from %d-iter checkpoint." % iterations)
    else:
        iterations = 0

    # check gpu and then send to gpu
    if torch.cuda.is_available():
        model.cuda()
        criterion.cuda()
        for state in optimizer.state.values():
            for key, value in state.items():
                if torch.is_tensor(value):
                    state[key] = value.cuda()
    else:
        logging.error("gpu is not available. please check the setting.")
        sys.exit(1)

    # train
    loss = 0
    total = 0
    for i in six.moves.range(iterations, args.iters):
        start = time.time()
        (batch_x, batch_h), batch_t = generator.next()
        batch_output = model(batch_x, batch_h)
        batch_loss = criterion(
            batch_output[:, model.receptive_field:].contiguous().view(-1, args.n_quantize),
            batch_t[:, model.receptive_field:].contiguous().view(-1))
        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()
        loss += batch_loss.item()
        total += time.time() - start
        logging.debug("batch loss = %.3f (%.3f sec / batch)" % (
            batch_loss.item(), time.time() - start))

        # report progress
        if (i + 1) % args.intervals == 0:
            logging.info("(iter:%d) average loss = %.6f (%.3f sec / batch)" % (
                i + 1, loss / args.intervals, total / args.intervals))
            logging.info("estimated required time = "
                         "{0.days:02}:{0.hours:02}:{0.minutes:02}:{0.seconds:02}".format(
                             relativedelta(seconds=int((args.iters - (i + 1)) * (total / args.intervals)))))
            loss = 0
            total = 0

        # save intermediate model
        if (i + 1) % args.checkpoints == 0:
            if args.n_gpus > 1:
                save_checkpoint(args.expdir, model.module, optimizer, i + 1)
            else:
                save_checkpoint(args.expdir, model, optimizer, i + 1)

    # save final model
    if args.n_gpus > 1:
        torch.save({"model": model.module.state_dict()},
                   args.expdir + "/checkpoint-final.pkl")
    else:
        torch.save({"model": model.state_dict()},
                   args.expdir + "/checkpoint-final.pkl")
    logging.info("final checkpoint created.")
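# The training loop above calls save_checkpoint(), which is defined elsewhere in the repo.
# Based on the keys read back during resume ("model", "optimizer", "iterations"), a
# compatible sketch would be the following (the file naming scheme is an assumption):
def save_checkpoint_sketch(expdir, model, optimizer, iterations):
    """Save a resumable checkpoint with the keys the resume path expects."""
    checkpoint = {
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        "iterations": iterations,
    }
    torch.save(checkpoint, expdir + "/checkpoint-%d.pkl" % iterations)
    logging.info("%d-iter checkpoint created." % iterations)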
def main():
    args = parse_args()
    cfg.resume = args.resume
    cfg.exp_name = args.exp
    cfg.workdir = '/zhzhao/code/wavenet_torch/torch_lyuan/exp_result/' + args.exp + '/debug'
    cfg.sparse_mode = args.sparse_mode
    cfg.batch_size = args.batch_size
    cfg.lr = args.lr
    cfg.load_from = args.load_from

    print('initial training...')
    print(f'work_dir:{cfg.workdir}, \n\
        pretrained: {cfg.load_from}, \n\
        batch_size: {cfg.batch_size}, \n\
        lr : {cfg.lr}, \n\
        epochs : {cfg.epochs}, \n\
        sparse : {cfg.sparse_mode}')

    writer = SummaryWriter(log_dir=cfg.workdir + '/runs')

    # build train data
    vctk_train = VCTK(cfg, 'train')
    train_loader = DataLoader(vctk_train, batch_size=cfg.batch_size,
                              num_workers=8, shuffle=True, pin_memory=True)
    vctk_val = VCTK(cfg, 'val')
    val_loader = DataLoader(vctk_val, batch_size=cfg.batch_size,
                            num_workers=8, shuffle=False, pin_memory=True)

    # build model
    model = WaveNet(num_classes=28, channels_in=20, dilations=[1, 2, 4, 8, 16])
    model = nn.DataParallel(model)
    model.cuda()

    weights_dir = os.path.join(cfg.workdir, 'weights')
    if not os.path.exists(weights_dir):
        os.mkdir(weights_dir)
    if not os.path.exists(cfg.vis_dir):
        os.mkdir(cfg.vis_dir)
    cfg.vis_dir = os.path.join(cfg.vis_dir, cfg.exp_name)
    if not os.path.exists(cfg.vis_dir):
        os.mkdir(cfg.vis_dir)

    model.train()

    if cfg.resume and os.path.exists(cfg.workdir + '/weights/best.pth'):
        model.load_state_dict(torch.load(cfg.workdir + '/weights/best.pth'))
        print("loading", cfg.workdir + '/weights/best.pth')

    if os.path.exists(cfg.load_from):
        model.load_state_dict(torch.load(cfg.load_from))
        print("loading", cfg.load_from)

    if cfg.sparse_mode == 'sparse_pruning':
        cfg.sparsity = args.sparsity
        print(f'sparse_pruning {cfg.sparsity}')

    elif cfg.sparse_mode == 'pattern_pruning':
        print(args.pattern_para)
        pattern_num = int(args.pattern_para.split('_')[0])
        pattern_shape = [int(args.pattern_para.split('_')[1]), int(args.pattern_para.split('_')[2])]
        pattern_nnz = int(args.pattern_para.split('_')[3])
        print(f'pattern_pruning {pattern_num} [{pattern_shape[0]}, {pattern_shape[1]}] {pattern_nnz}')
        cfg.patterns = generate_pattern(pattern_num, pattern_shape, pattern_nnz)
        cfg.pattern_mask = generate_pattern_mask(model, cfg.patterns)

    elif cfg.sparse_mode == 'coo_pruning':
        cfg.coo_shape = [int(args.coo_para.split('_')[0]), int(args.coo_para.split('_')[1])]
        cfg.coo_nnz = int(args.coo_para.split('_')[2])
        # cfg.patterns = generate_pattern(pattern_num, pattern_shape, pattern_nnz)
        print(f'coo_pruning [{cfg.coo_shape[0]}, {cfg.coo_shape[1]}] {cfg.coo_nnz}')

    elif cfg.sparse_mode == 'ptcoo_pruning':
        cfg.pattern_num = int(args.pattern_para.split('_')[0])
        cfg.pattern_shape = [int(args.ptcoo_para.split('_')[1]), int(args.ptcoo_para.split('_')[2])]
        cfg.pt_nnz = int(args.ptcoo_para.split('_')[3])
        cfg.coo_nnz = int(args.ptcoo_para.split('_')[4])
        cfg.patterns = generate_pattern(cfg.pattern_num, cfg.pattern_shape, cfg.pt_nnz)
        cfg.pattern_mask = generate_pattern_mask(model, cfg.patterns)
        print(f'ptcoo_pruning {cfg.pattern_num} [{cfg.pattern_shape[0]}, {cfg.pattern_shape[1]}] {cfg.pt_nnz} {cfg.coo_nnz}')

    if args.vis_mask == True:
        name_list = list()
        para_list = list()
        for name, para in model.named_parameters():
            name_list.append(name)
            para_list.append(para)

        for i, name in enumerate(name_list):
            if name.split(".")[-2] != "bn" and name.split(".")[-1] != "bias":
                raw_w = para_list[i]
                zero = torch.zeros_like(raw_w)
                one = torch.ones_like(raw_w)
                mask = torch.where(raw_w == 0, zero, one)
                vis.save_visualized_mask(mask, name)
        exit()

    if args.vis_pattern == True:
        pattern_count_dict = find_pattern_model(model, [16, 16])
        patterns = list(pattern_count_dict.keys())
        vis.save_visualized_pattern(patterns)
        exit()

    # build loss
    loss_fn = nn.CTCLoss(blank=0, reduction='none')

    # build optimizer (named `scheduler` here); without it the train() call below would
    # raise a NameError
    scheduler = optim.Adam(model.parameters(), lr=cfg.lr, eps=1e-4)
    # scheduler = optim.lr_scheduler.MultiStepLR(train_step, milestones=[50, 150, 250], gamma=0.5)

    # train
    train(train_loader, scheduler, model, loss_fn, val_loader, writer)