Example 1
def main():
    args = parse_args()
    cfg.resume      = args.resume
    cfg.exp_name    = args.exp
    cfg.work_root   = '/zhzhao/code/wavenet_torch/torch_lyuan/exp_result/'
    cfg.workdir     = cfg.work_root + args.exp + '/debug'
    cfg.sparse_mode = args.sparse_mode
    cfg.batch_size  = args.batch_size
    cfg.lr          = args.lr
    cfg.load_from   = args.load_from
    cfg.save_excel   = args.save_excel        
    
    if args.find_pattern == True:
        cfg.find_pattern_num   = 16
        cfg.find_pattern_shape = [int(args.find_pattern_shape.split('_')[0]), int(args.find_pattern_shape.split('_')[1])]
        cfg.find_zero_threshold = float(args.find_pattern_para.split('_')[0])
        cfg.find_score_threshold = int(args.find_pattern_para.split('_')[1])
        if int(cfg.find_pattern_shape[0] * cfg.find_pattern_shape[1]) <= cfg.find_score_threshold:
            exit()

    if args.skip_exist == True:
        if os.path.exists(cfg.workdir):
            exit()

    print('initial training...')
    print(f'work_dir:{cfg.workdir}, \n\
            pretrained: {cfg.load_from},  \n\
            batch_size: {cfg.batch_size}, \n\
            lr        : {cfg.lr},         \n\
            epochs    : {cfg.epochs},     \n\
            sparse    : {cfg.sparse_mode}')
    writer = SummaryWriter(log_dir=cfg.workdir+'/runs')

    # build train data
    vctk_train = VCTK(cfg, 'train')
    train_loader = DataLoader(vctk_train, batch_size=cfg.batch_size, num_workers=4, shuffle=True, pin_memory=True)

    # train_loader = dataset.create("data/v28/train.record", cfg.batch_size, repeat=True)
    vctk_val = VCTK(cfg, 'val')
    if args.test_acc_cmodel == True:
        val_loader = DataLoader(vctk_val, batch_size=1, num_workers=4, shuffle=False, pin_memory=True)
    else:
        val_loader = DataLoader(vctk_val, batch_size=cfg.batch_size, num_workers=4, shuffle=False, pin_memory=True)

    # build model
    model = WaveNet(num_classes=28, channels_in=40, dilations=[1,2,4,8,16])
    model = nn.DataParallel(model)
    model.cuda()


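    # collect parameter names and tensors so the weights can be re-initialized by name below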
    name_list = list()
    para_list = list()
    for name, para in model.named_parameters():
        name_list.append(name)
        para_list.append(para)

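    # re-initialize every weight except BatchNorm parameters and biases with Xavier normal init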
    a = model.state_dict()
    for i, name in enumerate(name_list):
        if name.split(".")[-2] != "bn" \
            and name.split(".")[-2] != "bn2" \
            and name.split(".")[-2] != "bn3" \
            and name.split(".")[-1] != "bias":
            raw_w = para_list[i]
            nn.init.xavier_normal_(raw_w, gain=1.0)
            a[name] = raw_w
    model.load_state_dict(a)
    

    weights_dir = os.path.join(cfg.workdir, 'weights')
    if not os.path.exists(weights_dir):
        os.mkdir(weights_dir)
    if not os.path.exists(cfg.vis_dir):
        os.mkdir(cfg.vis_dir)
    if args.vis_pattern == True or args.vis_mask == True:
        cfg.vis_dir = os.path.join(cfg.vis_dir, cfg.exp_name)
        if not os.path.exists(cfg.vis_dir):
            os.mkdir(cfg.vis_dir)
    model.train()

    if cfg.resume and os.path.exists(cfg.workdir + '/weights/best.pth'):
        model.load_state_dict(torch.load(cfg.workdir + '/weights/best.pth'), strict=True)
        print("loading", cfg.workdir + '/weights/best.pth')
        cfg.load_from = cfg.workdir + '/weights/best.pth'

    if args.test_acc == True:
        if os.path.exists(cfg.load_from):
            model.load_state_dict(torch.load(cfg.load_from), strict=True)
            print("loading", cfg.load_from)
        else:
            print("Error: model file does not exist:", cfg.load_from)
            exit()
    else:
        if os.path.exists(cfg.load_from):
            model.load_state_dict(torch.load(cfg.load_from), strict=True)
            print("loading", cfg.load_from)
            # Export the model
            print("exporting onnx ...")
            model.eval()
            batch_size = 1
            x = torch.randn(batch_size, 40, 720, requires_grad=True).cuda()
            torch.onnx.export(model.module,               # model being run
                            x,                         # model input (or a tuple for multiple inputs)
                            "wavenet.onnx",   # where to save the model (can be a file or file-like object)
                            export_params=True,        # store the trained parameter weights inside the model file
                            opset_version=10,          # the ONNX version to export the model to
                            do_constant_folding=True,  # whether to execute constant folding for optimization
                            input_names = ['input'],   # the model's input names
                            output_names = ['output'], # the model's output names
                            dynamic_axes={'input' : {0 : 'batch_size'},    # variable length axes
                                            'output' : {0 : 'batch_size'}})
            model.train()  # restore training mode after the ONNX export

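    # optionally load pretrained weights from an HDF5 file (via deepdish) and merge them into the model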
    if os.path.exists(args.load_from_h5):
        # model.load_state_dict(torch.load(args.load_from_h5), strict=True)
        print("loading", args.load_from_h5)
        model.train()
        model_dict = model.state_dict()
        print(model_dict.keys())
        # first convert the numpy parameter values to tensors
        pretrained_dict = dd.io.load(args.load_from_h5)
        print(pretrained_dict.keys())
        new_pre_dict = {}
        for k,v in pretrained_dict.items():
            new_pre_dict[k] = torch.Tensor(v)
        # update the state dict with the converted pretrained weights
        model_dict.update(new_pre_dict)
        # load the merged state dict into the model
        model.load_state_dict(model_dict)

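    # search 128x128 weight matrices for recurring sparsity patterns, write the statistics to the pattern-count sheet, and exit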
    if args.find_pattern == True:

        # cfg.find_pattern_num   = 16
        # cfg.find_pattern_shape = [int(args.find_pattern_shape.split('_')[0]), int(args.find_pattern_shape.split('_')[1])]
        # cfg.find_zero_threshold = float(args.find_pattern_para.split('_')[0])
        # cfg.find_score_threshold = int(args.find_pattern_para.split('_')[1])

        # if cfg.find_pattern_shape[0] * cfg.find_pattern_shape[0] <= cfg.find_score_threshold:
        #     exit()

        name_list = list()
        para_list = list()
        for name, para in model.named_parameters():
            name_list.append(name)
            para_list.append(para)

        a = model.state_dict()
        for i, name in enumerate(name_list):
            if name.split(".")[-2] != "bn" \
                and name.split(".")[-2] != "bn2" \
                and name.split(".")[-2] != "bn3" \
                and name.split(".")[-1] != "bias":
                raw_w = para_list[i]
                if raw_w.size(0) == 128 and raw_w.size(1) == 128:
                    patterns, pattern_match_num, pattern_coo_nnz, pattern_nnz, pattern_inner_nnz \
                                    = find_pattern_by_similarity(raw_w
                                        , cfg.find_pattern_num
                                        , cfg.find_pattern_shape
                                        , cfg.find_zero_threshold
                                        , cfg.find_score_threshold)

                    pattern_num_memory_dict, pattern_num_cal_num_dict, pattern_num_coo_nnz_dict \
                                    = pattern_curve_analyse(raw_w.shape
                                        , cfg.find_pattern_shape
                                        , patterns
                                        , pattern_match_num
                                        , pattern_coo_nnz
                                        , pattern_nnz
                                        , pattern_inner_nnz)
                                        
                    write_pattern_curve_analyse(os.path.join(cfg.work_root, args.save_pattern_count_excel)
                                        , cfg.exp_name + " " + args.find_pattern_shape + " " + args.find_pattern_para
                                        , patterns, pattern_match_num, pattern_coo_nnz, pattern_nnz
                                        , pattern_inner_nnz
                                        , pattern_num_memory_dict, pattern_num_cal_num_dict, pattern_num_coo_nnz_dict)

                    # write_pattern_count(os.path.join(cfg.work_root, args.save_pattern_count_excel)
                    #                     , cfg.exp_name + " " + args.find_pattern_shape +" " + args.find_pattern_para
                    #                     , all_nnzs.values(), all_patterns.values())
                    exit()



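    # configure the selected pruning scheme; each *_para argument is an underscore-separated parameter string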
    if cfg.sparse_mode == 'sparse_pruning':
        cfg.sparsity = args.sparsity
        print(f'sparse_pruning {cfg.sparsity}')

    elif cfg.sparse_mode == 'pattern_pruning':
        print(args.pattern_para)
        pattern_num   = int(args.pattern_para.split('_')[0])
        pattern_shape = [int(args.pattern_para.split('_')[1]), int(args.pattern_para.split('_')[2])]
        pattern_nnz   = int(args.pattern_para.split('_')[3])
        print(f'pattern_pruning {pattern_num} [{pattern_shape[0]}, {pattern_shape[1]}] {pattern_nnz}')
        cfg.patterns = generate_pattern(pattern_num, pattern_shape, pattern_nnz)
        cfg.pattern_mask = generate_pattern_mask(model, cfg.patterns)

    elif cfg.sparse_mode == 'coo_pruning':
        cfg.coo_shape   = [int(args.coo_para.split('_')[0]), int(args.coo_para.split('_')[1])]
        cfg.coo_nnz   = int(args.coo_para.split('_')[2])
        # cfg.patterns = generate_pattern(pattern_num, pattern_shape, pattern_nnz)
        print(f'coo_pruning [{cfg.coo_shape[0]}, {cfg.coo_shape[1]}] {cfg.coo_nnz}')
        
    elif cfg.sparse_mode == 'ptcoo_pruning':
        cfg.pattern_num   = int(args.pattern_para.split('_')[0])
        cfg.pattern_shape = [int(args.ptcoo_para.split('_')[1]), int(args.ptcoo_para.split('_')[2])]
        cfg.pt_nnz   = int(args.ptcoo_para.split('_')[3])
        cfg.coo_nnz   = int(args.ptcoo_para.split('_')[4])
        cfg.patterns = generate_pattern(cfg.pattern_num, cfg.pattern_shape, cfg.pt_nnz)
        cfg.pattern_mask = generate_pattern_mask(model, cfg.patterns)
        print(f'ptcoo_pruning {cfg.pattern_num} [{cfg.pattern_shape[0]}, {cfg.pattern_shape[1]}] {cfg.pt_nnz} {cfg.coo_nnz}')
        
    elif cfg.sparse_mode == 'find_retrain':
        cfg.pattern_num   = int(args.find_retrain_para.split('_')[0])
        cfg.pattern_shape = [int(args.find_retrain_para.split('_')[1]), int(args.find_retrain_para.split('_')[2])]
        cfg.pattern_nnz   = int(args.find_retrain_para.split('_')[3])
        cfg.coo_num       = float(args.find_retrain_para.split('_')[4])
        cfg.layer_or_model_wise   = str(args.find_retrain_para.split('_')[5])
        # cfg.fd_rtn_pattern_candidates = generate_complete_pattern_set(
        #                                 cfg.pattern_shape, cfg.pattern_nnz)
        print(f'find_retrain {cfg.pattern_num} [{cfg.pattern_shape[0]}, {cfg.pattern_shape[1]}] {cfg.pattern_nnz} {cfg.coo_num} {cfg.layer_or_model_wise}')

    elif cfg.sparse_mode == 'hcgs_pruning':
        print(args.pattern_para)
        cfg.block_shape = [int(args.hcgs_para.split('_')[0]), int(args.hcgs_para.split('_')[1])]
        cfg.reserve_num1 = int(args.hcgs_para.split('_')[2])
        cfg.reserve_num2 = int(args.hcgs_para.split('_')[3])
        print(f'hcgs_pruning {cfg.reserve_num1}/8 {cfg.reserve_num2}/16')
        cfg.hcgs_mask = generate_hcgs_mask(model, cfg.block_shape, cfg.reserve_num1, cfg.reserve_num2)

    if args.vis_mask == True:
        name_list = list()
        para_list = list()
        for name, para in model.named_parameters():
            name_list.append(name)
            para_list.append(para)

        for i, name in enumerate(name_list):
            if name.split(".")[-2] != "bn" \
                and name.split(".")[-2] != "bn2" \
                and name.split(".")[-2] != "bn3" \
                and name.split(".")[-1] != "bias":
                raw_w = para_list[i]

                zero = torch.zeros_like(raw_w)
                one = torch.ones_like(raw_w)

                mask = torch.where(raw_w == 0, zero, one)
                vis.save_visualized_mask(mask, name)
        exit()

    if args.vis_pattern == True:
        pattern_count_dict = find_pattern_model(model, [8,8])
        patterns = list(pattern_count_dict.keys())
        counts = list(pattern_count_dict.values())
        print(len(patterns))
        print(counts)
        vis.save_visualized_pattern(patterns)
        exit()
    # build loss
    loss_fn = nn.CTCLoss(blank=27)
    # loss_fn = nn.CTCLoss()

    # build the optimizer
    optimizer = optim.Adam(model.parameters(), lr=cfg.lr, eps=1e-4)
    # scheduler = optim.lr_scheduler.MultiStepLR(train_step, milestones=[50, 150, 250], gamma=0.5)
    if args.test_acc == True:
        f1, val_loss, tps, preds, poses = test_acc(val_loader, model, loss_fn)
        # f1, val_loss, tps, preds, poses = test_acc(val_loader, model, loss_fn)
        write_test_acc(os.path.join(cfg.work_root, args.test_acc_excel), 
                        cfg.exp_name, f1, val_loss, tps, preds, poses)
        exit()

    if args.test_acc_cmodel == True:
        f1, val_loss, tps, preds, poses = test_acc_cmodel(val_loader, model, loss_fn)
        # f1, val_loss, tps, preds, poses = test_acc(val_loader, model, loss_fn)
        write_test_acc(os.path.join(cfg.work_root, args.test_acc_excel), 
                        cfg.exp_name, f1, val_loss, tps, preds, poses)
        exit()
    # train
    train(train_loader, optimizer, model, loss_fn, val_loader, writer)
Example 2
def main():
    args = parse_args()
    cfg.resume      = args.resume
    cfg.exp_name    = args.exp
    cfg.work_root   = '/zhzhao/code/wavenet_torch/torch_lyuan/exp_result/'
    cfg.workdir     = cfg.work_root + args.exp + '/debug'
    cfg.sparse_mode = args.sparse_mode
    cfg.batch_size  = args.batch_size
    cfg.lr          = args.lr
    cfg.load_from   = args.load_from
    cfg.save_excel   = args.save_excel

    if args.skip_exist == True:
        if os.path.exists(cfg.workdir):
            exit()

    print('initial training...')
    print(f'work_dir:{cfg.workdir}, \n\
            pretrained: {cfg.load_from},  \n\
            batch_size: {cfg.batch_size}, \n\
            lr        : {cfg.lr},         \n\
            epochs    : {cfg.epochs},     \n\
            sparse    : {cfg.sparse_mode}')
    writer = SummaryWriter(log_dir=cfg.workdir+'/runs')

    # build train data
    vctk_train = VCTK(cfg, 'train')
    train_loader = DataLoader(vctk_train, batch_size=cfg.batch_size, num_workers=8, shuffle=True, pin_memory=True)

    # train_loader = dataset.create("data/v28/train.record", cfg.batch_size, repeat=True)
    vctk_val = VCTK(cfg, 'val')
    val_loader = DataLoader(vctk_val, batch_size=cfg.batch_size, num_workers=8, shuffle=False, pin_memory=True)

    # build model
    model = WaveNet(num_classes=28, channels_in=20, dilations=[1,2,4,8,16])
    model = nn.DataParallel(model)
    model.cuda()


    name_list = list()
    para_list = list()
    for name, para in model.named_parameters():
        name_list.append(name)
        para_list.append(para)

    a = model.state_dict()
    for i, name in enumerate(name_list):
        if name.split(".")[-2] != "bn" and name.split(".")[-1] != "bias":
            raw_w = para_list[i]
            nn.init.xavier_normal_(raw_w, gain=1.0)
            a[name] = raw_w
    model.load_state_dict(a)
    

    weights_dir = os.path.join(cfg.workdir, 'weights')
    if not os.path.exists(weights_dir):
        os.mkdir(weights_dir)
    if not os.path.exists(cfg.vis_dir):
        os.mkdir(cfg.vis_dir)
    if args.vis_pattern == True or args.vis_mask == True:
        cfg.vis_dir = os.path.join(cfg.vis_dir, cfg.exp_name)
        if not os.path.exists(cfg.vis_dir):
            os.mkdir(cfg.vis_dir)
    model.train()

    if cfg.resume and os.path.exists(cfg.workdir + '/weights/best.pth'):
        model.load_state_dict(torch.load(cfg.workdir + '/weights/best.pth'), strict=False)
        print("loading", cfg.workdir + '/weights/best.pth')

    if os.path.exists(cfg.load_from):
        model.load_state_dict(torch.load(cfg.load_from), strict=False)
        print("loading", cfg.load_from)

    if os.path.exists(args.load_from_h5):
        # model.load_state_dict(torch.load(args.load_from_h5), strict=True)
        print("loading", args.load_from_h5)
        model.train()
        model_dict = model.state_dict()
        print(model_dict.keys())
        # first convert the numpy parameter values to tensors
        pretrained_dict = dd.io.load(args.load_from_h5)
        print(pretrained_dict.keys())
        new_pre_dict = {}
        for k,v in pretrained_dict.items():
            new_pre_dict[k] = torch.Tensor(v)
        # update the state dict with the converted pretrained weights
        model_dict.update(new_pre_dict)
        # load the merged state dict into the model
        model.load_state_dict(model_dict)

    if args.find_pattern == True:

        cfg.find_pattern_num   = 16
        cfg.find_pattern_shape = [int(args.find_pattern_shape.split('_')[0]), int(args.find_pattern_shape.split('_')[1])]
        cfg.find_zero_threshold = float(args.find_pattern_para.split('_')[0])
        cfg.find_score_threshold = int(args.find_pattern_para.split('_')[1])

        name_list = list()
        para_list = list()
        for name, para in model.named_parameters():
            name_list.append(name)
            para_list.append(para)

        a = model.state_dict()
        for i, name in enumerate(name_list):
            if name.split(".")[-2] != "bn" and name.split(".")[-1] != "bias":
                raw_w = para_list[i]
                if raw_w.size(0) == 128 and raw_w.size(1) == 128:
                    patterns, pattern_match_num, pattern_coo_nnz, pattern_nnz, pattern_inner_nnz \
                                    = find_pattern_by_similarity(raw_w
                                        , cfg.find_pattern_num
                                        , cfg.find_pattern_shape
                                        , cfg.find_zero_threshold
                                        , cfg.find_score_threshold)

                    pattern_num_memory_dict, pattern_num_coo_nnz_dict \
                                    = pattern_curve_analyse(raw_w.shape
                                        , cfg.find_pattern_shape
                                        , patterns
                                        , pattern_match_num
                                        , pattern_coo_nnz
                                        , pattern_nnz
                                        , pattern_inner_nnz)
                                        
                    write_pattern_curve_analyse(os.path.join(cfg.work_root, args.save_pattern_count_excel)
                                        , cfg.exp_name + " " + args.find_pattern_shape + " " + args.find_pattern_para
                                        , patterns, pattern_match_num, pattern_coo_nnz, pattern_nnz
                                        , pattern_num_memory_dict, pattern_num_coo_nnz_dict)

                    # write_pattern_count(os.path.join(cfg.work_root, args.save_pattern_count_excel)
                    #                     , cfg.exp_name + " " + args.find_pattern_shape +" " + args.find_pattern_para
                    #                     , all_nnzs.values(), all_patterns.values())
                    exit()



    if cfg.sparse_mode == 'sparse_pruning':
        cfg.sparsity = args.sparsity
        print(f'sparse_pruning {cfg.sparsity}')

    elif cfg.sparse_mode == 'pattern_pruning':
        print(args.pattern_para)
        pattern_num   = int(args.pattern_para.split('_')[0])
        pattern_shape = [int(args.pattern_para.split('_')[1]), int(args.pattern_para.split('_')[2])]
        pattern_nnz   = int(args.pattern_para.split('_')[3])
        print(f'pattern_pruning {pattern_num} [{pattern_shape[0]}, {pattern_shape[1]}] {pattern_nnz}')
        cfg.patterns = generate_pattern(pattern_num, pattern_shape, pattern_nnz)
        cfg.pattern_mask = generate_pattern_mask(model, cfg.patterns)

    elif cfg.sparse_mode == 'coo_pruning':
        cfg.coo_shape   = [int(args.coo_para.split('_')[0]), int(args.coo_para.split('_')[1])]
        cfg.coo_nnz   = int(args.coo_para.split('_')[2])
        # cfg.patterns = generate_pattern(pattern_num, pattern_shape, pattern_nnz)
        print(f'coo_pruning [{cfg.coo_shape[0]}, {cfg.coo_shape[1]}] {cfg.coo_nnz}')
        
    elif cfg.sparse_mode == 'ptcoo_pruning':
        cfg.pattern_num   = int(args.pattern_para.split('_')[0])
        cfg.pattern_shape = [int(args.ptcoo_para.split('_')[1]), int(args.ptcoo_para.split('_')[2])]
        cfg.pt_nnz   = int(args.ptcoo_para.split('_')[3])
        cfg.coo_nnz   = int(args.ptcoo_para.split('_')[4])
        cfg.patterns = generate_pattern(cfg.pattern_num, cfg.pattern_shape, cfg.pt_nnz)
        cfg.pattern_mask = generate_pattern_mask(model, cfg.patterns)
        print(f'ptcoo_pruning {cfg.pattern_num} [{cfg.pattern_shape[0]}, {cfg.pattern_shape[1]}] {cfg.pt_nnz} {cfg.coo_nnz}')


    if args.vis_mask == True:
        name_list = list()
        para_list = list()
        for name, para in model.named_parameters():
            name_list.append(name)
            para_list.append(para)

        for i, name in enumerate(name_list):
            if name.split(".")[-2] != "bn" and name.split(".")[-1] != "bias":
                raw_w = para_list[i]

                zero = torch.zeros_like(raw_w)
                one = torch.ones_like(raw_w)

                mask = torch.where(raw_w == 0, zero, one)
                vis.save_visualized_mask(mask, name)
        exit()

    if args.vis_pattern == True:
        pattern_count_dict = find_pattern_model(model, [8,8])
        patterns = list(pattern_count_dict.keys())
        counts = list(pattern_count_dict.values())
        print(len(patterns))
        print(counts)
        vis.save_visualized_pattern(patterns)
        exit()
    # build loss
    loss_fn = nn.CTCLoss(blank=27)
    # loss_fn = nn.CTCLoss()

    # build the optimizer
    optimizer = optim.Adam(model.parameters(), lr=cfg.lr, eps=1e-4)
    # scheduler = optim.lr_scheduler.MultiStepLR(train_step, milestones=[50, 150, 250], gamma=0.5)
    if args.test_acc == True:
        f1, val_loss, tps, preds, poses = test_acc(val_loader, model, loss_fn)
        write_test_acc(os.path.join(cfg.work_root, args.test_acc_excel), 
                        cfg.exp_name, f1, val_loss, tps, preds, poses)
        exit()
    # train
    train(train_loader, optimizer, model, loss_fn, val_loader, writer)
Example 3
def train(num_gpus, rank, group_name, output_directory, epochs, learning_rate,
          iters_per_checkpoint, iters_per_eval, batch_size, seed, checkpoint_path, log_dir, ema_decay=0.9999):
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    #=====START: ADDED FOR DISTRIBUTED======
    if num_gpus > 1:
        init_distributed(rank, num_gpus, group_name, **dist_config)
    #=====END:   ADDED FOR DISTRIBUTED======

    if train_data_config["no_chunks"]:
        criterion = MaskedCrossEntropyLoss()
    else:
        criterion = CrossEntropyLoss()
    model = WaveNet(**wavenet_config).cuda()
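    # keep an exponential moving average (EMA) of all trainable parameters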
    ema = ExponentialMovingAverage(ema_decay)
    for name, param in model.named_parameters():
        if param.requires_grad:
            ema.register(name, param.data)

    #=====START: ADDED FOR DISTRIBUTED======
    if num_gpus > 1:
        model = apply_gradient_allreduce(model)
    #=====END:   ADDED FOR DISTRIBUTED======

    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    scheduler = StepLR(optimizer, step_size=200000, gamma=0.5)

    # Load checkpoint if one exists
    iteration = 0
    if checkpoint_path != "":
        model, optimizer, scheduler, iteration, ema = load_checkpoint(checkpoint_path, model,
                                                                      optimizer, scheduler, ema)
        iteration += 1  # next iteration is iteration + 1

    trainset = Mel2SampOnehot(audio_config=audio_config, verbose=True, **train_data_config)
    validset = Mel2SampOnehot(audio_config=audio_config, verbose=False, **valid_data_config)
    # =====START: ADDED FOR DISTRIBUTED======
    train_sampler = DistributedSampler(trainset) if num_gpus > 1 else None
    valid_sampler = DistributedSampler(validset) if num_gpus > 1 else None
    # =====END:   ADDED FOR DISTRIBUTED======
    print(train_data_config)
    if train_data_config["no_chunks"]:
        collate_fn = utils.collate_fn
    else:
        collate_fn = torch.utils.data.dataloader.default_collate
    train_loader = DataLoader(trainset, num_workers=1, shuffle=False,
                              collate_fn=collate_fn,
                              sampler=train_sampler,
                              batch_size=batch_size,
                              pin_memory=True,
                              drop_last=True)
    valid_loader = DataLoader(validset, num_workers=1, shuffle=False,
                              sampler=valid_sampler, batch_size=1, pin_memory=True)
    # Get shared output_directory ready
    if rank == 0:
        if not os.path.isdir(output_directory):
            os.makedirs(output_directory)
            os.chmod(output_directory, 0o775)
        print("output directory", output_directory)
    
    model.train()
    epoch_offset = max(0, int(iteration / len(train_loader)))
    writer = SummaryWriter(log_dir)
    print("Writing logs to: {}".format(log_dir))
    # ================ MAIN TRAINING LOOP! ===================
    for epoch in range(epoch_offset, epochs):
        print("Epoch: {}".format(epoch))
        for i, batch in enumerate(train_loader):
            if low_memory:
                torch.cuda.empty_cache()
            scheduler.step()
            model.zero_grad()

            if train_data_config["no_chunks"]:
                x, y, seq_lens = batch
                seq_lens = to_gpu(seq_lens)
            else:
                x, y = batch
            x = to_gpu(x).float()
            y = to_gpu(y)
            x = (x, y)  # auto-regressive takes outputs as inputs
            y_pred = model(x)
            if train_data_config["no_chunks"]:
                loss = criterion(y_pred, y, seq_lens)
            else:
                loss = criterion(y_pred, y)
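            # note: indexing loss.data[0] follows the pre-0.4 PyTorch API; newer versions use loss.item()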
            if num_gpus > 1:
                reduced_loss = reduce_tensor(loss.data, num_gpus)[0]
            else:
                reduced_loss = loss.data[0]
            loss.backward()
            optimizer.step()

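            # update the EMA shadow copy of each registered parameter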
            for name, param in model.named_parameters():
                if name in ema.shadow:
                    ema.update(name, param.data)

            print("{}:\t{:.9f}".format(iteration, reduced_loss))
            if rank == 0:
                writer.add_scalar('loss', reduced_loss, iteration)
            if (iteration % iters_per_checkpoint == 0 and iteration):
                if rank == 0:
                    checkpoint_path = "{}/wavenet_{}".format(
                        output_directory, iteration)
                    save_checkpoint(model, optimizer, scheduler, learning_rate, iteration,
                                    checkpoint_path, ema, wavenet_config)
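            # periodically synthesize validation audio with the nv-wavenet inference kernel and log it alongside the ground truth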
            if (iteration % iters_per_eval == 0 and iteration > 0 and not config["no_validation"]):
                if low_memory:
                    torch.cuda.empty_cache()
                if rank == 0:
                    model_eval = nv_wavenet.NVWaveNet(**(model.export_weights()))
                    for j, valid_batch in enumerate(valid_loader):
                        mel, audio = valid_batch
                        mel = to_gpu(mel).float()
                        cond_input = model.get_cond_input(mel)
                        predicted_audio = model_eval.infer(cond_input, nv_wavenet.Impl.AUTO)
                        predicted_audio = utils.mu_law_decode_numpy(predicted_audio[0, :].cpu().numpy(), 256)
                        writer.add_audio("valid/predicted_audio_{}".format(j),
                                         predicted_audio,
                                         iteration,
                                         22050)
                        audio = utils.mu_law_decode_numpy(audio[0, :].cpu().numpy(), 256)
                        writer.add_audio("valid_true/audio_{}".format(j),
                                         audio,
                                         iteration,
                                         22050)
                        if low_memory:
                            torch.cuda.empty_cache()
            iteration += 1
Example 4
def main():
    args = parse_args()
    cfg.resume      = args.resume
    cfg.exp_name    = args.exp
    cfg.workdir     = '/zhzhao/code/wavenet_torch/torch_lyuan/exp_result/' + args.exp + '/debug'
    cfg.sparse_mode = args.sparse_mode
    cfg.batch_size  = args.batch_size
    cfg.lr          = args.lr
    cfg.load_from   = args.load_from

    print('initial training...')
    print(f'work_dir:{cfg.workdir}, \n\
            pretrained: {cfg.load_from},  \n\
            batch_size: {cfg.batch_size}, \n\
            lr        : {cfg.lr},         \n\
            epochs    : {cfg.epochs},     \n\
            sparse    : {cfg.sparse_mode}')
    writer = SummaryWriter(log_dir=cfg.workdir+'/runs')

    # build train data
    vctk_train = VCTK(cfg, 'train')
    train_loader = DataLoader(vctk_train,batch_size=cfg.batch_size, num_workers=8, shuffle=True, pin_memory=True)

    vctk_val = VCTK(cfg, 'val')
    val_loader = DataLoader(vctk_val, batch_size=cfg.batch_size, num_workers=8, shuffle=False, pin_memory=True)

    # build model
    model = WaveNet(num_classes=28, channels_in=20, dilations=[1,2,4,8,16])
    model = nn.DataParallel(model)
    model.cuda()

    weights_dir = os.path.join(cfg.workdir, 'weights')
    if not os.path.exists(weights_dir):
        os.mkdir(weights_dir)
    if not os.path.exists(cfg.vis_dir):
        os.mkdir(cfg.vis_dir)
    cfg.vis_dir = os.path.join(cfg.vis_dir, cfg.exp_name)
    if not os.path.exists(cfg.vis_dir):
        os.mkdir(cfg.vis_dir)
    model.train()

    if cfg.resume and os.path.exists(cfg.workdir + '/weights/best.pth'):
        model.load_state_dict(torch.load(cfg.workdir + '/weights/best.pth'))
        print("loading", cfg.workdir + '/weights/best.pth')

    if os.path.exists(cfg.load_from):
        model.load_state_dict(torch.load(cfg.load_from))
        print("loading", cfg.load_from)


    if cfg.sparse_mode == 'sparse_pruning':
        cfg.sparsity = args.sparsity
        print(f'sparse_pruning {cfg.sparsity}')
    elif cfg.sparse_mode == 'pattern_pruning':
        print(args.pattern_para)
        pattern_num   = int(args.pattern_para.split('_')[0])
        pattern_shape = [int(args.pattern_para.split('_')[1]), int(args.pattern_para.split('_')[2])]
        pattern_nnz   = int(args.pattern_para.split('_')[3])
        print(f'pattern_pruning {pattern_num} [{pattern_shape[0]}, {pattern_shape[1]}] {pattern_nnz}')
        cfg.patterns = generate_pattern(pattern_num, pattern_shape, pattern_nnz)
        cfg.pattern_mask = generate_pattern_mask(model, cfg.patterns)
    elif cfg.sparse_mode == 'coo_pruning':
        cfg.coo_shape   = [int(args.coo_para.split('_')[0]), int(args.coo_para.split('_')[1])]
        cfg.coo_nnz   = int(args.coo_para.split('_')[2])
        # cfg.patterns = generate_pattern(pattern_num, pattern_shape, pattern_nnz)
        print(f'coo_pruning [{cfg.coo_shape[0]}, {cfg.coo_shape[1]}] {cfg.coo_nnz}')
    elif cfg.sparse_mode == 'ptcoo_pruning':
        cfg.pattern_num   = int(args.pattern_para.split('_')[0])
        cfg.pattern_shape = [int(args.ptcoo_para.split('_')[1]), int(args.ptcoo_para.split('_')[2])]
        cfg.pt_nnz   = int(args.ptcoo_para.split('_')[3])
        cfg.coo_nnz   = int(args.ptcoo_para.split('_')[4])
        cfg.patterns = generate_pattern(cfg.pattern_num, cfg.pattern_shape, cfg.pt_nnz)
        cfg.pattern_mask = generate_pattern_mask(model, cfg.patterns)
        print(f'ptcoo_pruning {cfg.pattern_num} [{cfg.pattern_shape[0]}, {cfg.pattern_shape[1]}] {cfg.pt_nnz} {cfg.coo_nnz}')


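    # save a visualization of the binary sparsity mask of every prunable weight tensor, then exit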
    if args.vis_mask == True:
        name_list = list()
        para_list = list()
        for name, para in model.named_parameters():
            name_list.append(name)
            para_list.append(para)

        for i, name in enumerate(name_list):
            if name.split(".")[-2] != "bn" and name.split(".")[-1] != "bias":
                raw_w = para_list[i]

                zero = torch.zeros_like(raw_w)
                one = torch.ones_like(raw_w)

                mask = torch.where(raw_w == 0, zero, one)
                vis.save_visualized_mask(mask, name)
        exit()

    if args.vis_pattern == True:
        pattern_count_dict = find_pattern_model(model, [16,16])
        patterns = list(pattern_count_dict.keys())
        vis.save_visualized_pattern(patterns)
        exit()
    # build loss
    loss_fn = nn.CTCLoss(blank=0, reduction='none')

    # build the optimizer
    optimizer = optim.Adam(model.parameters(), lr=cfg.lr, eps=1e-4)
    # scheduler = optim.lr_scheduler.MultiStepLR(train_step, milestones=[50, 150, 250], gamma=0.5)

    # train
    train(train_loader, optimizer, model, loss_fn, val_loader, writer)