def main():
    """Configure an experiment from the command line and run one action.

    Copies the parsed arguments into the module-level ``cfg``, builds the
    VCTK train/val loaders and a ``DataParallel`` WaveNet on the GPU,
    restores weights (resume checkpoint, ``cfg.load_from`` and/or the file
    named by ``args.load_from_h5``), exports the network to
    ``wavenet.onnx``, fills in the pruning settings for ``cfg.sparse_mode``
    and finally performs exactly one of: pattern search, mask
    visualisation, pattern visualisation, accuracy test, C-model accuracy
    test, or training.  Every action except training ends the process via
    ``sys.exit()``.
    """
    # ``exit()`` is injected by the ``site`` module for interactive use;
    # ``sys.exit()`` is the supported way to leave a program.
    import sys

    def is_weight(param_name):
        """Return True unless the parameter is a bias or lives in bn/bn2/bn3."""
        parts = param_name.split(".")
        return parts[-2] not in ("bn", "bn2", "bn3") and parts[-1] != "bias"

    args = parse_args()

    cfg.resume = args.resume
    cfg.exp_name = args.exp
    cfg.work_root = '/zhzhao/code/wavenet_torch/torch_lyuan/exp_result/'
    cfg.workdir = cfg.work_root + args.exp + '/debug'
    cfg.sparse_mode = args.sparse_mode
    cfg.batch_size = args.batch_size
    cfg.lr = args.lr
    cfg.load_from = args.load_from
    cfg.save_excel = args.save_excel

    if args.find_pattern:
        shape_fields = args.find_pattern_shape.split('_')
        para_fields = args.find_pattern_para.split('_')
        cfg.find_pattern_num = 16
        cfg.find_pattern_shape = [int(shape_fields[0]), int(shape_fields[1])]
        cfg.find_zero_threshold = float(para_fields[0])
        cfg.find_score_threshold = int(para_fields[1])
        # Nothing to search when the score threshold is not below the block area.
        block_area = cfg.find_pattern_shape[0] * cfg.find_pattern_shape[1]
        if block_area <= cfg.find_score_threshold:
            sys.exit()

    # Optionally skip experiments whose work directory already exists.
    if args.skip_exist and os.path.exists(cfg.workdir):
        sys.exit()

    print('initial training...')
    print(f'work_dir:{cfg.workdir}, \n'
          f'            pretrained: {cfg.load_from}, \n'
          f'            batch_size: {cfg.batch_size}, \n'
          f'            lr        : {cfg.lr}, \n'
          f'            epochs    : {cfg.epochs}, \n'
          f'            sparse    : {cfg.sparse_mode}')

    writer = SummaryWriter(log_dir=cfg.workdir + '/runs')

    # build train data
    vctk_train = VCTK(cfg, 'train')
    train_loader = DataLoader(vctk_train, batch_size=cfg.batch_size, num_workers=4,
                              shuffle=True, pin_memory=True)

    # The C-model accuracy test is fed one sample at a time.
    vctk_val = VCTK(cfg, 'val')
    val_batch_size = 1 if args.test_acc_cmodel else cfg.batch_size
    val_loader = DataLoader(vctk_val, batch_size=val_batch_size, num_workers=4,
                            shuffle=False, pin_memory=True)

    # build model
    model = WaveNet(num_classes=28, channels_in=40, dilations=[1, 2, 4, 8, 16])
    model = nn.DataParallel(model)
    model.cuda()

    # Xavier-initialise every selected weight.  The initialiser works in
    # place, so no state_dict round trip is required afterwards.
    for name, para in model.named_parameters():
        if is_weight(name):
            nn.init.xavier_normal_(para, gain=1.0)

    # makedirs(exist_ok=True) avoids the exists()/mkdir() race and also
    # creates missing parent directories.
    weights_dir = os.path.join(cfg.workdir, 'weights')
    os.makedirs(weights_dir, exist_ok=True)
    os.makedirs(cfg.vis_dir, exist_ok=True)
    if args.vis_pattern or args.vis_mask:
        cfg.vis_dir = os.path.join(cfg.vis_dir, cfg.exp_name)
        os.makedirs(cfg.vis_dir, exist_ok=True)

    model.train()

    # Resuming takes precedence: the best checkpoint of this experiment
    # replaces whatever --load_from pointed at.
    best_path = cfg.workdir + '/weights/best.pth'
    if cfg.resume and os.path.exists(best_path):
        model.load_state_dict(torch.load(best_path), strict=True)
        print("loading", best_path)
        cfg.load_from = best_path

    if os.path.exists(cfg.load_from):
        model.load_state_dict(torch.load(cfg.load_from), strict=True)
        print("loading", cfg.load_from)
    elif args.test_acc:
        # An accuracy test is meaningless without trained weights.
        print("Error: model file not exists, ", cfg.load_from)
        sys.exit()

    # Export the model.
    # NOTE(review): the export is reconstructed at the top level of main(),
    # i.e. it runs on every invocation - confirm against the original layout.
    # NOTE(review): the model is left in eval mode here; within this function
    # only the HDF5 branch below switches it back to train mode.
    print("exporting onnx ...")
    model.eval()
    batch_size = 1
    x = torch.randn(batch_size, 40, 720, requires_grad=True).cuda()
    torch.onnx.export(
        model.module,                     # the network inside the DataParallel wrapper
        x,                                # model input (or a tuple for multiple inputs)
        "wavenet.onnx",                   # where to save the model
        export_params=True,               # store the trained weights inside the file
        opset_version=10,                 # the ONNX version to export the model to
        do_constant_folding=True,         # execute constant folding for optimization
        input_names=['input'],            # the model's input names
        output_names=['output'],          # the model's output names
        dynamic_axes={'input': {0: 'batch_size'},    # variable-length axes
                      'output': {0: 'batch_size'}},
    )

    if os.path.exists(args.load_from_h5):
        print("loading", args.load_from_h5)
        model.train()
        model_dict = model.state_dict()
        print(model_dict.keys())
        # Convert the stored values to tensors first ...
        pretrained_dict = dd.io.load(args.load_from_h5)
        print(pretrained_dict.keys())
        new_pre_dict = {k: torch.Tensor(v) for k, v in pretrained_dict.items()}
        # ... then overlay them on the current state and load the result.
        model_dict.update(new_pre_dict)
        model.load_state_dict(model_dict)

    if args.find_pattern:
        excel_path = os.path.join(cfg.work_root, args.save_pattern_count_excel)
        run_label = cfg.exp_name + " " + args.find_pattern_shape + " " + args.find_pattern_para
        for name, raw_w in model.named_parameters():
            if not is_weight(name):
                continue
            # Only weights whose first two dimensions are 128 x 128 are analysed.
            if raw_w.size(0) == 128 and raw_w.size(1) == 128:
                (patterns, pattern_match_num, pattern_coo_nnz, pattern_nnz,
                 pattern_inner_nnz) = find_pattern_by_similarity(
                    raw_w, cfg.find_pattern_num, cfg.find_pattern_shape,
                    cfg.find_zero_threshold, cfg.find_score_threshold)

                (pattern_num_memory_dict, pattern_num_cal_num_dict,
                 pattern_num_coo_nnz_dict) = pattern_curve_analyse(
                    raw_w.shape, cfg.find_pattern_shape, patterns, pattern_match_num,
                    pattern_coo_nnz, pattern_nnz, pattern_inner_nnz)

                write_pattern_curve_analyse(
                    excel_path, run_label,
                    patterns, pattern_match_num, pattern_coo_nnz, pattern_nnz,
                    pattern_inner_nnz,
                    pattern_num_memory_dict, pattern_num_cal_num_dict,
                    pattern_num_coo_nnz_dict)
        sys.exit()

    # Parse the '_'-separated parameter string of the selected pruning mode.
    if cfg.sparse_mode == 'sparse_pruning':
        cfg.sparsity = args.sparsity
        print(f'sparse_pruning {cfg.sparsity}')

    elif cfg.sparse_mode == 'pattern_pruning':
        print(args.pattern_para)
        fields = args.pattern_para.split('_')
        pattern_num = int(fields[0])
        pattern_shape = [int(fields[1]), int(fields[2])]
        pattern_nnz = int(fields[3])
        print(f'pattern_pruning {pattern_num} [{pattern_shape[0]}, {pattern_shape[1]}] {pattern_nnz}')
        cfg.patterns = generate_pattern(pattern_num, pattern_shape, pattern_nnz)
        cfg.pattern_mask = generate_pattern_mask(model, cfg.patterns)

    elif cfg.sparse_mode == 'coo_pruning':
        fields = args.coo_para.split('_')
        cfg.coo_shape = [int(fields[0]), int(fields[1])]
        cfg.coo_nnz = int(fields[2])
        print(f'coo_pruning [{cfg.coo_shape[0]}, {cfg.coo_shape[1]}] {cfg.coo_nnz}')

    elif cfg.sparse_mode == 'ptcoo_pruning':
        fields = args.ptcoo_para.split('_')
        # Fixed: the pattern count used to be read from args.pattern_para
        # although all remaining fields come from args.ptcoo_para.
        cfg.pattern_num = int(fields[0])
        cfg.pattern_shape = [int(fields[1]), int(fields[2])]
        cfg.pt_nnz = int(fields[3])
        cfg.coo_nnz = int(fields[4])
        cfg.patterns = generate_pattern(cfg.pattern_num, cfg.pattern_shape, cfg.pt_nnz)
        cfg.pattern_mask = generate_pattern_mask(model, cfg.patterns)
        print(f'ptcoo_pruning {cfg.pattern_num} [{cfg.pattern_shape[0]}, {cfg.pattern_shape[1]}] {cfg.pt_nnz} {cfg.coo_nnz}')

    elif cfg.sparse_mode == 'find_retrain':
        fields = args.find_retrain_para.split('_')
        cfg.pattern_num = int(fields[0])
        cfg.pattern_shape = [int(fields[1]), int(fields[2])]
        cfg.pattern_nnz = int(fields[3])
        cfg.coo_num = float(fields[4])
        cfg.layer_or_model_wise = str(fields[5])
        print(f'find_retrain {cfg.pattern_num} [{cfg.pattern_shape[0]}, {cfg.pattern_shape[1]}] {cfg.pattern_nnz} {cfg.coo_num} {cfg.layer_or_model_wise}')

    elif cfg.sparse_mode == 'hcgs_pruning':
        # Fixed: echo the parameter string this mode actually parses.
        print(args.hcgs_para)
        fields = args.hcgs_para.split('_')
        cfg.block_shape = [int(fields[0]), int(fields[1])]
        cfg.reserve_num1 = int(fields[2])
        cfg.reserve_num2 = int(fields[3])
        print(f'hcgs_pruning {cfg.reserve_num1}/8 {cfg.reserve_num2}/16')
        cfg.hcgs_mask = generate_hcgs_mask(model, cfg.block_shape,
                                           cfg.reserve_num1, cfg.reserve_num2)

    if args.vis_mask:
        # Save a 0/1 mask (1 where the weight is non-zero) for every selected weight.
        for name, raw_w in model.named_parameters():
            if is_weight(name):
                mask = torch.where(raw_w == 0,
                                   torch.zeros_like(raw_w), torch.ones_like(raw_w))
                vis.save_visualized_mask(mask, name)
        sys.exit()

    if args.vis_pattern:
        pattern_count_dict = find_pattern_model(model, [8, 8])
        patterns = list(pattern_count_dict.keys())
        counts = list(pattern_count_dict.values())
        print(len(patterns))
        print(counts)
        vis.save_visualized_pattern(patterns)
        sys.exit()

    # build loss
    loss_fn = nn.CTCLoss(blank=27)

    # NOTE: despite its name, ``scheduler`` is the Adam optimizer handed to train().
    scheduler = torch.optim.Adam(model.parameters(), lr=cfg.lr, eps=1e-4)

    if args.test_acc:
        f1, val_loss, tps, preds, poses = test_acc(val_loader, model, loss_fn)
        write_test_acc(os.path.join(cfg.work_root, args.test_acc_excel),
                       cfg.exp_name, f1, val_loss, tps, preds, poses)
        sys.exit()

    if args.test_acc_cmodel:
        f1, val_loss, tps, preds, poses = test_acc_cmodel(val_loader, model, loss_fn)
        write_test_acc(os.path.join(cfg.work_root, args.test_acc_excel),
                       cfg.exp_name, f1, val_loss, tps, preds, poses)
        sys.exit()

    # train
    train(train_loader, scheduler, model, loss_fn, val_loader, writer)
def main():
    """Configure an experiment from the command line and run one action.

    Copies the parsed arguments into the module-level ``cfg``, builds the
    VCTK train/val loaders and a ``DataParallel`` WaveNet on the GPU,
    restores weights non-strictly (resume checkpoint, ``cfg.load_from``
    and/or the file named by ``args.load_from_h5``), fills in the pruning
    settings for ``cfg.sparse_mode`` and finally performs exactly one of:
    pattern search, mask visualisation, pattern visualisation, accuracy
    test, or training.  Every action except training ends the process via
    ``sys.exit()``.
    """
    # ``exit()`` is injected by the ``site`` module for interactive use;
    # ``sys.exit()`` is the supported way to leave a program.
    import sys

    def is_weight(param_name):
        """Return True unless the parameter is a bias or lives in a ``bn`` module."""
        parts = param_name.split(".")
        return parts[-2] != "bn" and parts[-1] != "bias"

    args = parse_args()

    cfg.resume = args.resume
    cfg.exp_name = args.exp
    cfg.work_root = '/zhzhao/code/wavenet_torch/torch_lyuan/exp_result/'
    cfg.workdir = cfg.work_root + args.exp + '/debug'
    cfg.sparse_mode = args.sparse_mode
    cfg.batch_size = args.batch_size
    cfg.lr = args.lr
    cfg.load_from = args.load_from
    cfg.save_excel = args.save_excel

    # Optionally skip experiments whose work directory already exists.
    if args.skip_exist and os.path.exists(cfg.workdir):
        sys.exit()

    print('initial training...')
    print(f'work_dir:{cfg.workdir}, \n'
          f'            pretrained: {cfg.load_from}, \n'
          f'            batch_size: {cfg.batch_size}, \n'
          f'            lr        : {cfg.lr}, \n'
          f'            epochs    : {cfg.epochs}, \n'
          f'            sparse    : {cfg.sparse_mode}')

    writer = SummaryWriter(log_dir=cfg.workdir + '/runs')

    # build train data
    vctk_train = VCTK(cfg, 'train')
    train_loader = DataLoader(vctk_train, batch_size=cfg.batch_size, num_workers=8,
                              shuffle=True, pin_memory=True)

    vctk_val = VCTK(cfg, 'val')
    val_loader = DataLoader(vctk_val, batch_size=cfg.batch_size, num_workers=8,
                            shuffle=False, pin_memory=True)

    # build model
    model = WaveNet(num_classes=28, channels_in=20, dilations=[1, 2, 4, 8, 16])
    model = nn.DataParallel(model)
    model.cuda()

    # Xavier-initialise every selected weight.  The initialiser works in
    # place, so no state_dict round trip is required afterwards.
    for name, para in model.named_parameters():
        if is_weight(name):
            nn.init.xavier_normal_(para, gain=1.0)

    # makedirs(exist_ok=True) avoids the exists()/mkdir() race and also
    # creates missing parent directories.
    weights_dir = os.path.join(cfg.workdir, 'weights')
    os.makedirs(weights_dir, exist_ok=True)
    os.makedirs(cfg.vis_dir, exist_ok=True)
    if args.vis_pattern or args.vis_mask:
        cfg.vis_dir = os.path.join(cfg.vis_dir, cfg.exp_name)
        os.makedirs(cfg.vis_dir, exist_ok=True)

    model.train()

    best_path = cfg.workdir + '/weights/best.pth'
    if cfg.resume and os.path.exists(best_path):
        model.load_state_dict(torch.load(best_path), strict=False)
        print("loading", best_path)

    # An explicit --load_from checkpoint is applied on top of a resumed one.
    if os.path.exists(cfg.load_from):
        model.load_state_dict(torch.load(cfg.load_from), strict=False)
        print("loading", cfg.load_from)

    if os.path.exists(args.load_from_h5):
        print("loading", args.load_from_h5)
        model.train()
        model_dict = model.state_dict()
        print(model_dict.keys())
        # Convert the stored values to tensors first ...
        pretrained_dict = dd.io.load(args.load_from_h5)
        print(pretrained_dict.keys())
        new_pre_dict = {k: torch.Tensor(v) for k, v in pretrained_dict.items()}
        # ... then overlay them on the current state and load the result.
        model_dict.update(new_pre_dict)
        model.load_state_dict(model_dict)

    if args.find_pattern:
        shape_fields = args.find_pattern_shape.split('_')
        para_fields = args.find_pattern_para.split('_')
        cfg.find_pattern_num = 16
        cfg.find_pattern_shape = [int(shape_fields[0]), int(shape_fields[1])]
        cfg.find_zero_threshold = float(para_fields[0])
        cfg.find_score_threshold = int(para_fields[1])

        excel_path = os.path.join(cfg.work_root, args.save_pattern_count_excel)
        run_label = cfg.exp_name + " " + args.find_pattern_shape + " " + args.find_pattern_para
        for name, raw_w in model.named_parameters():
            if not is_weight(name):
                continue
            # Only weights whose first two dimensions are 128 x 128 are analysed.
            if raw_w.size(0) == 128 and raw_w.size(1) == 128:
                (patterns, pattern_match_num, pattern_coo_nnz, pattern_nnz,
                 pattern_inner_nnz) = find_pattern_by_similarity(
                    raw_w, cfg.find_pattern_num, cfg.find_pattern_shape,
                    cfg.find_zero_threshold, cfg.find_score_threshold)

                pattern_num_memory_dict, pattern_num_coo_nnz_dict = pattern_curve_analyse(
                    raw_w.shape, cfg.find_pattern_shape, patterns, pattern_match_num,
                    pattern_coo_nnz, pattern_nnz, pattern_inner_nnz)

                write_pattern_curve_analyse(
                    excel_path, run_label,
                    patterns, pattern_match_num, pattern_coo_nnz, pattern_nnz,
                    pattern_num_memory_dict, pattern_num_coo_nnz_dict)
        sys.exit()

    # Parse the '_'-separated parameter string of the selected pruning mode.
    if cfg.sparse_mode == 'sparse_pruning':
        cfg.sparsity = args.sparsity
        print(f'sparse_pruning {cfg.sparsity}')

    elif cfg.sparse_mode == 'pattern_pruning':
        print(args.pattern_para)
        fields = args.pattern_para.split('_')
        pattern_num = int(fields[0])
        pattern_shape = [int(fields[1]), int(fields[2])]
        pattern_nnz = int(fields[3])
        print(f'pattern_pruning {pattern_num} [{pattern_shape[0]}, {pattern_shape[1]}] {pattern_nnz}')
        cfg.patterns = generate_pattern(pattern_num, pattern_shape, pattern_nnz)
        cfg.pattern_mask = generate_pattern_mask(model, cfg.patterns)

    elif cfg.sparse_mode == 'coo_pruning':
        fields = args.coo_para.split('_')
        cfg.coo_shape = [int(fields[0]), int(fields[1])]
        cfg.coo_nnz = int(fields[2])
        print(f'coo_pruning [{cfg.coo_shape[0]}, {cfg.coo_shape[1]}] {cfg.coo_nnz}')

    elif cfg.sparse_mode == 'ptcoo_pruning':
        fields = args.ptcoo_para.split('_')
        # Fixed: the pattern count used to be read from args.pattern_para
        # although all remaining fields come from args.ptcoo_para.
        cfg.pattern_num = int(fields[0])
        cfg.pattern_shape = [int(fields[1]), int(fields[2])]
        cfg.pt_nnz = int(fields[3])
        cfg.coo_nnz = int(fields[4])
        cfg.patterns = generate_pattern(cfg.pattern_num, cfg.pattern_shape, cfg.pt_nnz)
        cfg.pattern_mask = generate_pattern_mask(model, cfg.patterns)
        print(f'ptcoo_pruning {cfg.pattern_num} [{cfg.pattern_shape[0]}, {cfg.pattern_shape[1]}] {cfg.pt_nnz} {cfg.coo_nnz}')

    if args.vis_mask:
        # Save a 0/1 mask (1 where the weight is non-zero) for every selected weight.
        for name, raw_w in model.named_parameters():
            if is_weight(name):
                mask = torch.where(raw_w == 0,
                                   torch.zeros_like(raw_w), torch.ones_like(raw_w))
                vis.save_visualized_mask(mask, name)
        sys.exit()

    if args.vis_pattern:
        pattern_count_dict = find_pattern_model(model, [8, 8])
        patterns = list(pattern_count_dict.keys())
        counts = list(pattern_count_dict.values())
        print(len(patterns))
        print(counts)
        vis.save_visualized_pattern(patterns)
        sys.exit()

    # build loss
    loss_fn = nn.CTCLoss(blank=27)

    # NOTE: despite its name, ``scheduler`` is the Adam optimizer handed to train().
    scheduler = torch.optim.Adam(model.parameters(), lr=cfg.lr, eps=1e-4)

    if args.test_acc:
        f1, val_loss, tps, preds, poses = test_acc(val_loader, model, loss_fn)
        write_test_acc(os.path.join(cfg.work_root, args.test_acc_excel),
                       cfg.exp_name, f1, val_loss, tps, preds, poses)
        sys.exit()

    # train
    train(train_loader, scheduler, model, loss_fn, val_loader, writer)
def train(num_gpus, rank, group_name, output_directory, epochs, learning_rate,
          iters_per_checkpoint, iters_per_eval, batch_size, seed, checkpoint_path,
          log_dir, ema_decay=0.9999):
    """Train a WaveNet, periodically checkpointing and logging validation audio.

    Besides its arguments the function reads the module-level settings
    ``dist_config``, ``train_data_config``, ``valid_data_config``,
    ``audio_config``, ``wavenet_config``, ``config`` and ``low_memory``.

    Args:
        num_gpus: Number of processes; values above 1 enable distributed
            initialisation, gradient all-reduce, distributed samplers and
            loss reduction across processes.
        rank: Rank of this process.  Only rank 0 creates the output
            directory, logs the loss, writes checkpoints and runs validation.
        group_name: Forwarded to ``init_distributed``.
        output_directory: Directory that receives ``wavenet_<iteration>``
            checkpoints.
        epochs: Exclusive upper bound of the epoch counter.
        learning_rate: Adam learning rate; also forwarded to ``save_checkpoint``.
        iters_per_checkpoint: Checkpoint every this many iterations (never at 0).
        iters_per_eval: Validate every this many iterations (never at 0),
            unless ``config["no_validation"]`` is set.
        batch_size: Training batch size (validation uses batch size 1).
        seed: Seed for the CPU and CUDA random generators.
        checkpoint_path: Checkpoint to resume from, or ``""`` to start fresh.
        log_dir: Directory for the ``SummaryWriter`` event files.
        ema_decay: Decay passed to ``ExponentialMovingAverage``.
    """
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    distributed = num_gpus > 1
    no_chunks = train_data_config["no_chunks"]

    # =====START: ADDED FOR DISTRIBUTED======
    if distributed:
        init_distributed(rank, num_gpus, group_name, **dist_config)
    # =====END:   ADDED FOR DISTRIBUTED======

    criterion = MaskedCrossEntropyLoss() if no_chunks else CrossEntropyLoss()
    model = WaveNet(**wavenet_config).cuda()

    # Track an exponential moving average of every trainable parameter.
    ema = ExponentialMovingAverage(ema_decay)
    for name, param in model.named_parameters():
        if param.requires_grad:
            ema.register(name, param.data)

    # =====START: ADDED FOR DISTRIBUTED======
    if distributed:
        model = apply_gradient_allreduce(model)
    # =====END:   ADDED FOR DISTRIBUTED======

    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    scheduler = StepLR(optimizer, step_size=200000, gamma=0.5)

    # Load checkpoint if one exists
    iteration = 0
    if checkpoint_path != "":
        model, optimizer, scheduler, iteration, ema = load_checkpoint(
            checkpoint_path, model, optimizer, scheduler, ema)
        iteration += 1  # next iteration is iteration + 1

    trainset = Mel2SampOnehot(audio_config=audio_config, verbose=True, **train_data_config)
    validset = Mel2SampOnehot(audio_config=audio_config, verbose=False, **valid_data_config)
    # =====START: ADDED FOR DISTRIBUTED======
    train_sampler = DistributedSampler(trainset) if distributed else None
    valid_sampler = DistributedSampler(validset) if distributed else None
    # =====END:   ADDED FOR DISTRIBUTED======
    print(train_data_config)
    if no_chunks:
        collate_fn = utils.collate_fn
    else:
        collate_fn = torch.utils.data.dataloader.default_collate
    train_loader = DataLoader(trainset, num_workers=1, shuffle=False,
                              collate_fn=collate_fn,
                              sampler=train_sampler,
                              batch_size=batch_size,
                              pin_memory=True,
                              drop_last=True)
    valid_loader = DataLoader(validset, num_workers=1, shuffle=False,
                              sampler=valid_sampler, batch_size=1, pin_memory=True)

    # Get shared output_directory ready
    if rank == 0:
        if not os.path.isdir(output_directory):
            os.makedirs(output_directory)
            os.chmod(output_directory, 0o775)
        print("output directory", output_directory)

    model.train()
    # When resuming, start from the epoch the restored iteration falls into.
    epoch_offset = max(0, int(iteration / len(train_loader)))

    writer = SummaryWriter(log_dir)
    print("Checkpoints writing to: {}".format(log_dir))

    # ================ MAIN TRAINING LOOP! ===================
    for epoch in range(epoch_offset, epochs):
        print("Epoch: {}".format(epoch))
        for i, batch in enumerate(train_loader):
            if low_memory:
                torch.cuda.empty_cache()
            model.zero_grad()

            if no_chunks:
                x, y, seq_lens = batch
                seq_lens = to_gpu(seq_lens)
            else:
                x, y = batch

            x = to_gpu(x).float()
            y = to_gpu(y)
            x = (x, y)  # auto-regressive takes outputs as inputs
            y_pred = model(x)

            if no_chunks:
                loss = criterion(y_pred, y, seq_lens)
            else:
                loss = criterion(y_pred, y)

            # .item() extracts the Python float; indexing a 0-dim tensor with
            # [0] raises IndexError on PyTorch >= 0.5.
            if distributed:
                reduced_loss = reduce_tensor(loss.data, num_gpus).item()
            else:
                reduced_loss = loss.item()

            loss.backward()
            optimizer.step()
            # The LR scheduler must be stepped after the optimizer (PyTorch >= 1.1).
            scheduler.step()

            for name, param in model.named_parameters():
                if name in ema.shadow:
                    ema.update(name, param.data)

            print("{}:\t{:.9f}".format(iteration, reduced_loss))
            if rank == 0:
                writer.add_scalar('loss', reduced_loss, iteration)

            if (iteration % iters_per_checkpoint == 0 and iteration):
                if rank == 0:
                    checkpoint_path = "{}/wavenet_{}".format(
                        output_directory, iteration)
                    save_checkpoint(model, optimizer, scheduler, learning_rate, iteration,
                                    checkpoint_path, ema, wavenet_config)

            if (iteration % iters_per_eval == 0 and iteration > 0
                    and not config["no_validation"]):
                if low_memory:
                    torch.cuda.empty_cache()
                if rank == 0:
                    # Run inference through NVWaveNet built from the current weights.
                    model_eval = nv_wavenet.NVWaveNet(**(model.export_weights()))
                    for j, valid_batch in enumerate(valid_loader):
                        mel, audio = valid_batch
                        mel = to_gpu(mel).float()
                        cond_input = model.get_cond_input(mel)
                        predicted_audio = model_eval.infer(cond_input, nv_wavenet.Impl.AUTO)
                        predicted_audio = utils.mu_law_decode_numpy(
                            predicted_audio[0, :].cpu().numpy(), 256)
                        writer.add_audio("valid/predicted_audio_{}".format(j),
                                         predicted_audio,
                                         iteration,
                                         22050)
                        audio = utils.mu_law_decode_numpy(audio[0, :].cpu().numpy(), 256)
                        writer.add_audio("valid_true/audio_{}".format(j),
                                         audio,
                                         iteration,
                                         22050)
                if low_memory:
                    torch.cuda.empty_cache()

            iteration += 1
def main():
    """Configure an experiment from the command line and train or visualise.

    Copies the parsed arguments into the module-level ``cfg``, builds the
    VCTK train/val loaders and a ``DataParallel`` WaveNet on the GPU,
    restores weights (resume checkpoint and/or ``cfg.load_from``), fills in
    the pruning settings for ``cfg.sparse_mode`` and then either saves the
    mask / pattern visualisation and exits, or starts training.
    """
    # ``exit()`` is injected by the ``site`` module for interactive use;
    # ``sys.exit()`` is the supported way to leave a program.
    import sys

    args = parse_args()

    cfg.resume = args.resume
    cfg.exp_name = args.exp
    cfg.workdir = '/zhzhao/code/wavenet_torch/torch_lyuan/exp_result/' + args.exp + '/debug'
    cfg.sparse_mode = args.sparse_mode
    cfg.batch_size = args.batch_size
    cfg.lr = args.lr
    cfg.load_from = args.load_from

    print('initial training...')
    print(f'work_dir:{cfg.workdir}, \n'
          f'            pretrained: {cfg.load_from}, \n'
          f'            batch_size: {cfg.batch_size}, \n'
          f'            lr        : {cfg.lr}, \n'
          f'            epochs    : {cfg.epochs}, \n'
          f'            sparse    : {cfg.sparse_mode}')

    writer = SummaryWriter(log_dir=cfg.workdir + '/runs')

    # build train data
    vctk_train = VCTK(cfg, 'train')
    train_loader = DataLoader(vctk_train, batch_size=cfg.batch_size, num_workers=8,
                              shuffle=True, pin_memory=True)

    vctk_val = VCTK(cfg, 'val')
    val_loader = DataLoader(vctk_val, batch_size=cfg.batch_size, num_workers=8,
                            shuffle=False, pin_memory=True)

    # build model
    model = WaveNet(num_classes=28, channels_in=20, dilations=[1, 2, 4, 8, 16])
    model = nn.DataParallel(model)
    model.cuda()

    # makedirs(exist_ok=True) avoids the exists()/mkdir() race and also
    # creates missing parent directories.
    weights_dir = os.path.join(cfg.workdir, 'weights')
    os.makedirs(weights_dir, exist_ok=True)
    os.makedirs(cfg.vis_dir, exist_ok=True)
    cfg.vis_dir = os.path.join(cfg.vis_dir, cfg.exp_name)
    os.makedirs(cfg.vis_dir, exist_ok=True)

    model.train()

    best_path = cfg.workdir + '/weights/best.pth'
    if cfg.resume and os.path.exists(best_path):
        model.load_state_dict(torch.load(best_path))
        print("loading", best_path)

    # An explicit --load_from checkpoint is applied on top of a resumed one.
    if os.path.exists(cfg.load_from):
        model.load_state_dict(torch.load(cfg.load_from))
        print("loading", cfg.load_from)

    # Parse the '_'-separated parameter string of the selected pruning mode.
    if cfg.sparse_mode == 'sparse_pruning':
        cfg.sparsity = args.sparsity
        print(f'sparse_pruning {cfg.sparsity}')

    elif cfg.sparse_mode == 'pattern_pruning':
        print(args.pattern_para)
        fields = args.pattern_para.split('_')
        pattern_num = int(fields[0])
        pattern_shape = [int(fields[1]), int(fields[2])]
        pattern_nnz = int(fields[3])
        print(f'pattern_pruning {pattern_num} [{pattern_shape[0]}, {pattern_shape[1]}] {pattern_nnz}')
        cfg.patterns = generate_pattern(pattern_num, pattern_shape, pattern_nnz)
        cfg.pattern_mask = generate_pattern_mask(model, cfg.patterns)

    elif cfg.sparse_mode == 'coo_pruning':
        fields = args.coo_para.split('_')
        cfg.coo_shape = [int(fields[0]), int(fields[1])]
        cfg.coo_nnz = int(fields[2])
        print(f'coo_pruning [{cfg.coo_shape[0]}, {cfg.coo_shape[1]}] {cfg.coo_nnz}')

    elif cfg.sparse_mode == 'ptcoo_pruning':
        fields = args.ptcoo_para.split('_')
        # Fixed: the pattern count used to be read from args.pattern_para
        # although all remaining fields come from args.ptcoo_para.
        cfg.pattern_num = int(fields[0])
        cfg.pattern_shape = [int(fields[1]), int(fields[2])]
        cfg.pt_nnz = int(fields[3])
        cfg.coo_nnz = int(fields[4])
        cfg.patterns = generate_pattern(cfg.pattern_num, cfg.pattern_shape, cfg.pt_nnz)
        cfg.pattern_mask = generate_pattern_mask(model, cfg.patterns)
        print(f'ptcoo_pruning {cfg.pattern_num} [{cfg.pattern_shape[0]}, {cfg.pattern_shape[1]}] {cfg.pt_nnz} {cfg.coo_nnz}')

    if args.vis_mask:
        # Save a 0/1 mask (1 where the weight is non-zero) for every
        # parameter that is neither a bias nor inside a ``bn`` module.
        for name, raw_w in model.named_parameters():
            parts = name.split(".")
            if parts[-2] != "bn" and parts[-1] != "bias":
                mask = torch.where(raw_w == 0,
                                   torch.zeros_like(raw_w), torch.ones_like(raw_w))
                vis.save_visualized_mask(mask, name)
        sys.exit()

    if args.vis_pattern:
        pattern_count_dict = find_pattern_model(model, [16, 16])
        patterns = list(pattern_count_dict.keys())
        vis.save_visualized_pattern(patterns)
        sys.exit()

    # build loss
    loss_fn = nn.CTCLoss(blank=0, reduction='none')

    # NOTE: despite its name, ``scheduler`` is the Adam optimizer handed to train().
    scheduler = torch.optim.Adam(model.parameters(), lr=cfg.lr, eps=1e-4)

    # train
    train(train_loader, scheduler, model, loss_fn, val_loader, writer)