Example #1
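# Example of using RAdam to train a sparsified dual-GRU waveform model (mu-law output
# with an optional LPC layer). This snippet assumes that argparse, os, sys, time, logging,
# numpy (np), torch, torch.nn.functional (F), DataLoader, transforms, SummaryWriter,
# defaultdict, relativedelta, strtobool, and the repository's own helpers (read_hdf5,
# find_files, read_txt, padding, encode_mu_law, FeatureDatasetNeuVoco, data_generator,
# sparsify, save_checkpoint, write_to_tensorboard, GRU_WAVE_DECODER_DUALGRU_COMPACT,
# RAdam) are imported in the original file.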
def main():
    parser = argparse.ArgumentParser()
    # path setting
    parser.add_argument("--waveforms",
                        type=str, help="directory or list of wav files")
    parser.add_argument("--waveforms_eval",
                        type=str, help="directory or list of evaluation wav files")
    parser.add_argument("--feats", required=True,
                        type=str, help="directory or list of wav files")
    parser.add_argument("--feats_eval", required=True,
                        type=str, help="directory or list of evaluation feat files")
    parser.add_argument("--stats", required=True,
                        type=str, help="directory or list of evaluation wav files")
    parser.add_argument("--expdir", required=True,
                        type=str, help="directory to save the model")
    # network structure setting
    parser.add_argument("--upsampling_factor", default=120,
                        type=int, help="number of dimension of aux feats")
    parser.add_argument("--hidden_units_wave", default=384,
                        type=int, help="depth of dilation")
    parser.add_argument("--hidden_units_wave_2", default=16,
                        type=int, help="depth of dilation")
    parser.add_argument("--kernel_size_wave", default=7,
                        type=int, help="kernel size of dilated causal convolution")
    parser.add_argument("--dilation_size_wave", default=1,
                        type=int, help="kernel size of dilated causal convolution")
    parser.add_argument("--lpc", default=12,
                        type=int, help="kernel size of dilated causal convolution")
    parser.add_argument("--mcep_dim", default=50,
                        type=int, help="kernel size of dilated causal convolution")
    parser.add_argument("--right_size", default=0,
                        type=int, help="kernel size of dilated causal convolution")
    # network training setting
    parser.add_argument("--lr", default=1e-4,
                        type=float, help="learning rate")
    parser.add_argument("--batch_size", default=15,
                        type=int, help="batch size (if set 0, utterance batch will be used)")
    parser.add_argument("--epoch_count", default=4000,
                        type=int, help="number of training epochs")
    parser.add_argument("--do_prob", default=0,
                        type=float, help="dropout probability")
    parser.add_argument("--batch_size_utt", default=5,
                        type=int, help="batch size (if set 0, utterance batch will be used)")
    parser.add_argument("--batch_size_utt_eval", default=5,
                        type=int, help="batch size (if set 0, utterance batch will be used)")
    parser.add_argument("--n_workers", default=2,
                        type=int, help="batch size (if set 0, utterance batch will be used)")
    parser.add_argument("--n_quantize", default=256,
                        type=int, help="batch size (if set 0, utterance batch will be used)")
    parser.add_argument("--causal_conv_wave", default=False,
                        type=strtobool, help="batch size (if set 0, utterance batch will be used)")
    parser.add_argument("--n_stage", default=4,
                        type=int, help="number of sparsification stages")
    parser.add_argument("--t_start", default=20000,
                        type=int, help="iter idx to start sparsify")
    parser.add_argument("--t_end", default=4500000,
                        type=int, help="iter idx to finish densitiy sparsify")
    parser.add_argument("--interval", default=100,
                        type=int, help="interval in finishing densitiy sparsify")
    parser.add_argument("--densities", default="0.05-0.05-0.2",
                        type=str, help="final densitiy of reset, update, new hidden gate matrices")
    # other setting
    parser.add_argument("--pad_len", default=3000,
                        type=int, help="seed number")
    parser.add_argument("--save_interval_iter", default=5000,
                        type=int, help="interval steps to logr")
    parser.add_argument("--save_interval_epoch", default=10,
                        type=int, help="interval steps to logr")
    parser.add_argument("--log_interval_steps", default=50,
                        type=int, help="interval steps to logr")
    parser.add_argument("--seed", default=1,
                        type=int, help="seed number")
    parser.add_argument("--resume", default=None,
                        type=str, help="model path to restart training")
    parser.add_argument("--pretrained", default=None,
                        type=str, help="model path to restart training")
    parser.add_argument("--string_path", default=None,
                        type=str, help="model path to restart training")
    parser.add_argument("--GPU_device", default=None,
                        type=int, help="selection of GPU device")
    parser.add_argument("--verbose", default=1,
                        type=int, help="log level")
    args = parser.parse_args()

    if args.GPU_device is not None:
        os.environ["CUDA_DEVICE_ORDER"]     = "PCI_BUS_ID"
        os.environ["CUDA_VISIBLE_DEVICES"]  = str(args.GPU_device)

    # make experimental directory
    if not os.path.exists(args.expdir):
        os.makedirs(args.expdir)

    # set log level
    if args.verbose == 1:
        logging.basicConfig(level=logging.INFO,
                            format='%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s',
                            datefmt='%m/%d/%Y %I:%M:%S',
                            filename=args.expdir + "/train.log")
        logging.getLogger().addHandler(logging.StreamHandler())
    elif args.verbose > 1:
        logging.basicConfig(level=logging.DEBUG,
                            format='%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s',
                            datefmt='%m/%d/%Y %I:%M:%S',
                            filename=args.expdir + "/train.log")
        logging.getLogger().addHandler(logging.StreamHandler())
    else:
        logging.basicConfig(level=logging.WARN,
                            format='%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s',
                            datefmt='%m/%d/%Y %I:%M:%S',
                            filename=args.expdir + "/train.log")
        logging.getLogger().addHandler(logging.StreamHandler())
        logging.warn("logging is disabled.")

    # fix seed
    os.environ['PYTHONHASHSEED'] = str(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if str(device) == "cpu":
        raise ValueError('ERROR: training on CPU is not supported.')

    torch.backends.cudnn.benchmark = True #faster

    #if args.pretrained is None:
    if 'mel' in args.string_path:
        mean_stats = torch.FloatTensor(read_hdf5(args.stats, "/mean_melsp"))
        scale_stats = torch.FloatTensor(read_hdf5(args.stats, "/scale_melsp"))
        args.excit_dim = 0
        #mean_stats = torch.FloatTensor(np.r_[read_hdf5(args.stats, "/mean_feat_mceplf0cap")[:2], read_hdf5(args.stats, "/mean_melsp")])
        #scale_stats = torch.FloatTensor(np.r_[read_hdf5(args.stats, "/scale_feat_mceplf0cap")[:2], read_hdf5(args.stats, "/scale_melsp")])
        #args.excit_dim = 2
        #mean_stats = torch.FloatTensor(np.r_[read_hdf5(args.stats, "/mean_feat_mceplf0cap")[:6], read_hdf5(args.stats, "/mean_melsp")])
        #scale_stats = torch.FloatTensor(np.r_[read_hdf5(args.stats, "/scale_feat_mceplf0cap")[:6], read_hdf5(args.stats, "/scale_melsp")])
        #args.excit_dim = 6
    else:
        mean_stats = torch.FloatTensor(read_hdf5(args.stats, "/mean_"+args.string_path.replace("/","")))
        scale_stats = torch.FloatTensor(read_hdf5(args.stats, "/scale_"+args.string_path.replace("/","")))
        if mean_stats.shape[0] > args.mcep_dim+2:
            if 'feat_org_lf0' in args.string_path:
                args.cap_dim = mean_stats.shape[0]-(args.mcep_dim+2)
                args.excit_dim = 2+args.cap_dim
            else:
                args.cap_dim = mean_stats.shape[0]-(args.mcep_dim+3)
                args.excit_dim = 2+1+args.cap_dim
            #args.cap_dim = mean_stats.shape[0]-(args.mcep_dim+2)
            #args.excit_dim = 2+args.cap_dim
        else:
            args.cap_dim = None
            args.excit_dim = 2
    #else:
    #    if 'mel' in args.string_path:
    #        args.excit_dim = 0
    #    else:
    #        args.cap_dim = 3
    #        if 'legacy' not in args.string_path:
    #            args.excit_dim = 6
    #        else:
    #            args.excit_dim = 5

    # save args as conf
    # 14/15-8 or 14/15/16-6/7/8 [5ms]
    # 7-8 or 8-6/7/8 [10ms]
    #args.batch_size = 7
    #args.batch_size_utt = 8
    #args.batch_size = 8
    #args.batch_size_utt = 6
    #args.codeap_dim = 3
    torch.save(args, args.expdir + "/model.conf")
    #args.batch_size = 10
    #batch_sizes = [None]*3
    #batch_sizes[0] = int(args.batch_size*0.5)
    #batch_sizes[1] = int(args.batch_size)
    #batch_sizes[2] = int(args.batch_size*1.5)
    #logging.info(batch_sizes)

    # define network
    model_waveform = GRU_WAVE_DECODER_DUALGRU_COMPACT(
        feat_dim=args.mcep_dim+args.excit_dim,
        upsampling_factor=args.upsampling_factor,
        hidden_units=args.hidden_units_wave,
        hidden_units_2=args.hidden_units_wave_2,
        kernel_size=args.kernel_size_wave,
        dilation_size=args.dilation_size_wave,
        n_quantize=args.n_quantize,
        causal_conv=args.causal_conv_wave,
        lpc=args.lpc,
        right_size=args.right_size,
        do_prob=args.do_prob)
    logging.info(model_waveform)
    criterion_ce = torch.nn.CrossEntropyLoss(reduction='none')
    criterion_l1 = torch.nn.L1Loss(reduction='none')

    # send to gpu
    if torch.cuda.is_available():
        model_waveform.cuda()
        criterion_ce.cuda()
        criterion_l1.cuda()
        if args.pretrained is None:
            mean_stats = mean_stats.cuda()
            scale_stats = scale_stats.cuda()
    else:
        logging.error("gpu is not available. please check the setting.")
        sys.exit(1)

    model_waveform.train()

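    # initialize the frozen scale_in layer so it standardizes input features,
    # i.e., applies (x - mean) / scale using the training-data statistics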
    if args.pretrained is None:
        model_waveform.scale_in.weight = torch.nn.Parameter(torch.unsqueeze(torch.diag(1.0/scale_stats.data),2))
        model_waveform.scale_in.bias = torch.nn.Parameter(-(mean_stats.data/scale_stats.data))

    for param in model_waveform.parameters():
        param.requires_grad = True
    for param in model_waveform.scale_in.parameters():
        param.requires_grad = False
    if args.lpc > 0:
        for param in model_waveform.logits.parameters():
            param.requires_grad = False

    parameters = filter(lambda p: p.requires_grad, model_waveform.parameters())
    parameters = sum([np.prod(p.size()) for p in parameters]) / 1000000
    logging.info('Trainable Parameters (waveform): %.3f million' % parameters)

    module_list = list(model_waveform.conv.parameters())
    module_list += list(model_waveform.conv_s_c.parameters()) + list(model_waveform.embed_wav.parameters())
    module_list += list(model_waveform.gru.parameters()) + list(model_waveform.gru_2.parameters())
    module_list += list(model_waveform.out.parameters())

    optimizer = RAdam(module_list, lr=args.lr)
    #optimizer = torch.optim.Adam(module_list, lr=args.lr)
    #if args.pretrained is None:
    #    optimizer = RAdam(module_list, lr=args.lr)
    #else:
    #    #optimizer = RAdam(module_list, lr=args.lr)
    #    optimizer = torch.optim.Adam(module_list, lr=args.lr)

    # resume
    if args.pretrained is not None and args.resume is None:
        checkpoint = torch.load(args.pretrained)
        model_waveform.load_state_dict(checkpoint["model_waveform"])
    #    optimizer.load_state_dict(checkpoint["optimizer"])
        epoch_idx = checkpoint["iterations"]
        logging.info("pretrained from %d-iter checkpoint." % epoch_idx)
        epoch_idx = 0
    elif args.resume is not None:
        checkpoint = torch.load(args.resume)
        model_waveform.load_state_dict(checkpoint["model_waveform"])
        optimizer.load_state_dict(checkpoint["optimizer"])
        epoch_idx = checkpoint["iterations"]
        logging.info("restored from %d-iter checkpoint." % epoch_idx)
    else:
        epoch_idx = 0

    def zero_wav_pad(x): return padding(x, args.pad_len*args.upsampling_factor, value=0.0)  # noqa: E704
    def zero_feat_pad(x): return padding(x, args.pad_len, value=0.0)  # noqa: E704
    pad_wav_transform = transforms.Compose([zero_wav_pad])
    pad_feat_transform = transforms.Compose([zero_feat_pad])

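    # mu-law encode waveform samples into args.n_quantize discrete classes
    # (inputs/targets for the softmax waveform model)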
    wav_transform = transforms.Compose([lambda x: encode_mu_law(x, args.n_quantize)])

    # define generator training
    if os.path.isdir(args.waveforms):
        filenames = sorted(find_files(args.waveforms, "*.wav", use_dir_name=False))
        wav_list = [args.waveforms + "/" + filename for filename in filenames]
    elif os.path.isfile(args.waveforms):
        wav_list = read_txt(args.waveforms)
    else:
        logging.error("--waveforms should be directory or list.")
        sys.exit(1)
    if os.path.isdir(args.feats):
        feat_list = [args.feats + "/" + filename for filename in filenames]
    elif os.path.isfile(args.feats):
        feat_list = read_txt(args.feats)
    else:
        logging.error("--feats should be directory or list.")
        sys.exit(1)
    assert len(wav_list) == len(feat_list)
    logging.info("number of training data = %d." % len(feat_list))
    dataset = FeatureDatasetNeuVoco(wav_list, feat_list, pad_wav_transform, pad_feat_transform, args.upsampling_factor, 
                    args.string_path, wav_transform=wav_transform)
                    #args.string_path, wav_transform=wav_transform, with_excit=True)
                    #args.string_path, wav_transform=wav_transform, with_excit=False)
                    #args.string_path, wav_transform=wav_transform, with_excit=True, codeap_dim=args.codeap_dim)
    dataloader = DataLoader(dataset, batch_size=args.batch_size_utt, shuffle=True, num_workers=args.n_workers)
    #generator = data_generator(dataloader, device, args.batch_size, args.upsampling_factor, limit_count=1)
    generator = data_generator(dataloader, device, args.batch_size, args.upsampling_factor, limit_count=None)
    #generator = data_generator(dataloader, device, args.batch_size, args.upsampling_factor, limit_count=None, batch_sizes=batch_sizes)

    # define generator evaluation
    if os.path.isdir(args.waveforms_eval):
        filenames = sorted(find_files(args.waveforms_eval, "*.wav", use_dir_name=False))
        wav_list_eval = [args.waveforms_eval + "/" + filename for filename in filenames]
    elif os.path.isfile(args.waveforms_eval):
        wav_list_eval = read_txt(args.waveforms_eval)
    else:
        logging.error("--waveforms_eval should be directory or list.")
        sys.exit(1)
    if os.path.isdir(args.feats_eval):
        feat_list_eval = [args.feats_eval + "/" + filename for filename in filenames]
    elif os.path.isfile(args.feats_eval):
        feat_list_eval = read_txt(args.feats_eval)
    else:
        logging.error("--feats_eval should be directory or list.")
        sys.exit(1)
    assert len(wav_list_eval) == len(feat_list_eval)
    logging.info("number of evaluation data = %d." % len(feat_list_eval))
    dataset_eval = FeatureDatasetNeuVoco(wav_list_eval, feat_list_eval, pad_wav_transform, pad_feat_transform, args.upsampling_factor, 
                    args.string_path, wav_transform=wav_transform)
                    #args.string_path, wav_transform=wav_transform, with_excit=False)
                    #args.string_path, wav_transform=wav_transform, with_excit=True)
                    #args.string_path, wav_transform=wav_transform, with_excit=True, codeap_dim=args.codeap_dim)
    dataloader_eval = DataLoader(dataset_eval, batch_size=args.batch_size_utt_eval, shuffle=False, num_workers=args.n_workers)
    #generator_eval = data_generator(dataloader_eval, device, args.batch_size, args.upsampling_factor, limit_count=1)
    generator_eval = data_generator(dataloader_eval, device, args.batch_size, args.upsampling_factor, limit_count=None)

    writer = SummaryWriter(args.expdir)
    total_train_loss = defaultdict(list)
    total_eval_loss = defaultdict(list)

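    # Build the staged sparsification schedule: the span [t_start, t_end] is split into
    # n_stage stages, and the target densities of the GRU recurrent (reset/update/new)
    # gate matrices are lowered stage by stage until they reach the final values given
    # by --densities (e.g. "0.05-0.05-0.2").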
    #sparsify = lpcnet.Sparsify(2000, 40000, 400, (0.05, 0.05, 0.2))
    density_deltas_ = args.densities.split('-')
    density_deltas = [None]*len(density_deltas_)
    for i in range(len(density_deltas_)):
        density_deltas[i] = (1-float(density_deltas_[i]))/args.n_stage
    t_deltas = [None]*args.n_stage
    t_starts = [None]*args.n_stage
    t_ends = [None]*args.n_stage
    densities = [None]*args.n_stage
    t_delta = args.t_end - args.t_start + 1
    #t_deltas[0] = round((1/(args.n_stage-1))*0.6*t_delta)
    if args.n_stage > 3:
        t_deltas[0] = round((1/2)*0.2*t_delta)
    else:
        t_deltas[0] = round(0.2*t_delta)
    t_starts[0] = args.t_start
    t_ends[0] = args.t_start + t_deltas[0] - 1
    densities[0] = [None]*len(density_deltas)
    for j in range(len(density_deltas)):
        densities[0][j] = 1-density_deltas[j]
    for i in range(1,args.n_stage):
        if i < args.n_stage-1:
            #t_deltas[i] = round((1/(args.n_stage-1))*0.6*t_delta)
            if args.n_stage > 3:
                if i < 2:
                    t_deltas[i] = round((1/2)*0.2*t_delta)
                else:
                    if args.n_stage > 4:
                        t_deltas[i] = round((1/2)*0.3*t_delta)
                    else:
                        t_deltas[i] = round(0.3*t_delta)
            else:
                t_deltas[i] = round(0.3*t_delta)
        else:
            #t_deltas[i] = round(0.4*t_delta)
            t_deltas[i] = round(0.5*t_delta)
        t_starts[i] = t_ends[i-1] + 1
        t_ends[i] = t_starts[i] + t_deltas[i] - 1
        densities[i] = [None]*len(density_deltas)
        if i < args.n_stage-1:
            for j in range(len(density_deltas)):
                densities[i][j] = densities[i-1][j]-density_deltas[j]
        else:
            for j in range(len(density_deltas)):
                densities[i][j] = float(density_deltas_[j])
    logging.info(t_delta)
    logging.info(t_deltas)
    logging.info(t_starts)
    logging.info(t_ends)
    logging.info(args.interval)
    logging.info(densities)
    idx_stage = 0

    # train
    total = 0
    iter_count = 0
    loss_ce = []
    loss_err = []
    min_eval_loss_ce = 99999999.99
    min_eval_loss_ce_std = 99999999.99
    min_eval_loss_err = 99999999.99
    min_eval_loss_err_std = 99999999.99
    iter_idx = 0
    min_idx = -1
    #min_eval_loss_ce = 2.007181
    #min_eval_loss_ce_std = 0.801412
    #iter_idx = 70350
    #min_idx = 6 #resume7
    while idx_stage < args.n_stage-1 and iter_idx + 1 >= t_starts[idx_stage+1]:
        idx_stage += 1
        logging.info(idx_stage)
    change_min_flag = False
    if args.resume is not None:
        np.random.set_state(checkpoint["numpy_random_state"])
        torch.set_rng_state(checkpoint["torch_random_state"])
    logging.info("==%d EPOCH==" % (epoch_idx+1))
    logging.info("Training data")
    while epoch_idx < args.epoch_count:
        start = time.time()
        batch_x, batch_feat, c_idx, utt_idx, featfile, x_bs, f_bs, x_ss, f_ss, n_batch_utt, \
            del_index_utt, max_slen, idx_select, idx_select_full, slens_acc = next(generator)
        if c_idx < 0: # summarize epoch
            # save current epoch model
            numpy_random_state = np.random.get_state()
            torch_random_state = torch.get_rng_state()
            # report current epoch
            logging.info("(EPOCH:%d) average optimization loss = %.6f (+- %.6f) %.6f (+- %.6f) %% ;; "\
                "(%.3f min., %.3f sec / batch)" % (epoch_idx + 1, np.mean(loss_ce), np.std(loss_ce), \
                    np.mean(loss_err), np.std(loss_err), total / 60.0, total / iter_count))
            logging.info("estimated time until max. epoch = {0.days:02}:{0.hours:02}:{0.minutes:02}:"\
            "{0.seconds:02}".format(relativedelta(seconds=int((args.epoch_count - (epoch_idx + 1)) * total))))
            # compute loss in evaluation data
            total = 0
            iter_count = 0
            loss_ce = []
            loss_err = []
            model_waveform.eval()
            for param in model_waveform.parameters():
                param.requires_grad = False
            logging.info("Evaluation data")
            while True:
                with torch.no_grad():
                    start = time.time()
                    batch_x, batch_feat, c_idx, utt_idx, featfile, x_bs, f_bs, x_ss, f_ss, n_batch_utt, \
                        del_index_utt, max_slen, idx_select, idx_select_full, slens_acc = next(generator_eval)
                    if c_idx < 0:
                        break

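                    # Slice the current truncated segment: batch_x_prev holds the previous
                    # samples (teacher forcing input), batch_x_lpc holds the last args.lpc
                    # samples for the LPC layer; segments starting at the utterance head are
                    # left-padded with the mu-law zero level (args.n_quantize // 2).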
                    x_es = x_ss+x_bs
                    f_es = f_ss+f_bs
                    logging.info(f'{x_ss} {x_bs} {x_es} {f_ss} {f_bs} {f_es} {max_slen}')
                    if x_ss > 0:
                        if x_es <= max_slen:
                            batch_x_prev = batch_x[:,x_ss-1:x_es-1]
                            if args.lpc > 0:
                                if x_ss-args.lpc >= 0:
                                    batch_x_lpc = batch_x[:,x_ss-args.lpc:x_es-1]
                                else:
                                    batch_x_lpc = F.pad(batch_x[:,:x_es-1], (-(x_ss-args.lpc), 0), "constant", args.n_quantize // 2)
                            batch_feat = batch_feat[:,f_ss:f_es]
                            batch_x = batch_x[:,x_ss:x_es]
                        else:
                            batch_x_prev = batch_x[:,x_ss-1:-1]
                            if args.lpc > 0:
                                if x_ss-args.lpc >= 0:
                                    batch_x_lpc = batch_x[:,x_ss-args.lpc:-1]
                                else:
                                    batch_x_lpc = F.pad(batch_x[:,:-1], (-(x_ss-args.lpc), 0), "constant", args.n_quantize // 2)
                            batch_feat = batch_feat[:,f_ss:]
                            batch_x = batch_x[:,x_ss:]
                    else:
                        batch_x_prev = F.pad(batch_x[:,:x_es-1], (1, 0), "constant", args.n_quantize // 2)
                        if args.lpc > 0:
                            batch_x_lpc = F.pad(batch_x[:,:x_es-1], (args.lpc, 0), "constant", args.n_quantize // 2)
                        batch_feat = batch_feat[:,:f_es]
                        batch_x = batch_x[:,:x_es]
                    #assert((batch_x_prev[:,1:] == batch_x[:,:-1]).all())

                    if f_ss > 0:
                        if len(del_index_utt) > 0:
                            h_x = torch.FloatTensor(np.delete(h_x.cpu().data.numpy(), del_index_utt, axis=1)).to(device)
                            h_x_2 = torch.FloatTensor(np.delete(h_x_2.cpu().data.numpy(), del_index_utt, axis=1)).to(device)
                        if args.lpc > 0:
                            batch_x_output, h_x, h_x_2 = model_waveform(batch_feat, batch_x_prev, h=h_x, h_2=h_x_2, x_lpc=batch_x_lpc)
                        else:
                            batch_x_output, h_x, h_x_2 = model_waveform(batch_feat, batch_x_prev, h=h_x, h_2=h_x_2)
                    else:
                        if args.lpc > 0:
                            batch_x_output, h_x, h_x_2 = model_waveform(batch_feat, batch_x_prev, x_lpc=batch_x_lpc)
                        else:
                            batch_x_output, h_x, h_x_2 = model_waveform(batch_feat, batch_x_prev)

                    # samples check
                    i = np.random.randint(0, batch_x_output.shape[0])
                    logging.info("%s" % (os.path.join(os.path.basename(os.path.dirname(featfile[i])),os.path.basename(featfile[i]))))
                    #check_samples = batch_x[i,5:10].long()
                    #logging.info(torch.index_select(F.softmax(batch_x_output[i,5:10], dim=-1), 1, check_samples))
                    #logging.info(check_samples)

                    # handle short ending
                    if len(idx_select) > 0:
                        logging.info('len_idx_select: '+str(len(idx_select)))
                        batch_loss_ce_select = 0
                        batch_loss_err_select = 0
                        for j in range(len(idx_select)):
                            k = idx_select[j]
                            slens_utt = slens_acc[k]
                            logging.info('%s %d' % (featfile[k], slens_utt))
                            batch_x_output_ = batch_x_output[k,:slens_utt]
                            batch_x_ = batch_x[k,:slens_utt]
                            batch_loss_ce_select += torch.mean(criterion_ce(batch_x_output_, batch_x_))
                            batch_loss_err_select += torch.mean(torch.sum(100*criterion_l1(F.softmax(batch_x_output_, dim=-1), F.one_hot(batch_x_, num_classes=args.n_quantize).float()), -1))
                        batch_loss_ce_select /= len(idx_select)
                        batch_loss_err_select /= len(idx_select)
                        total_eval_loss["eval/loss_ce"].append(batch_loss_ce_select.item())
                        total_eval_loss["eval/loss_err"].append(batch_loss_err_select.item())
                        loss_ce.append(batch_loss_ce_select.item())
                        loss_err.append(batch_loss_err_select.item())
                        if len(idx_select_full) > 0:
                            logging.info('len_idx_select_full: '+str(len(idx_select_full)))
                            batch_x = torch.index_select(batch_x,0,idx_select_full)
                            batch_x_output = torch.index_select(batch_x_output,0,idx_select_full)
                        else:
                            logging.info("batch loss select %.3f %.3f (%.3f sec)" % (batch_loss_ce_select.item(), \
                                batch_loss_err_select.item(), time.time() - start))
                            iter_count += 1
                            total += time.time() - start
                            continue

                    # loss
                    batch_loss_ce_ = torch.mean(criterion_ce(batch_x_output.reshape(-1, args.n_quantize), batch_x.reshape(-1)).reshape(batch_x_output.shape[0], -1), -1)
                    batch_loss_err_ = torch.mean(torch.sum(100*criterion_l1(F.softmax(batch_x_output, dim=-1), F.one_hot(batch_x, num_classes=args.n_quantize).float()), -1), -1)

                    batch_loss_ce = batch_loss_ce_.mean()
                    batch_loss_err = batch_loss_err_.mean()
                    total_eval_loss["eval/loss_ce"].append(batch_loss_ce.item())
                    total_eval_loss["eval/loss_err"].append(batch_loss_err.item())
                    loss_ce.append(batch_loss_ce.item())
                    loss_err.append(batch_loss_err.item())

                    logging.info("batch eval loss [%d] %d %d %d %d %d : %.3f %.3f %% (%.3f sec)" % (c_idx+1, max_slen, x_ss, x_bs, \
                        f_ss, f_bs, batch_loss_ce.item(), batch_loss_err.item(), time.time() - start))
                    iter_count += 1
                    total += time.time() - start
            logging.info('sme')
            for key in total_eval_loss.keys():
                total_eval_loss[key] = np.mean(total_eval_loss[key])
                logging.info(f"(Steps: {iter_idx}) {key} = {total_eval_loss[key]:.4f}.")
            write_to_tensorboard(writer, iter_idx, total_eval_loss)
            total_eval_loss = defaultdict(list)
            eval_loss_ce = np.mean(loss_ce)
            eval_loss_ce_std = np.std(loss_ce)
            eval_loss_err = np.mean(loss_err)
            eval_loss_err_std = np.std(loss_err)
            logging.info("(EPOCH:%d) average evaluation loss = %.6f (+- %.6f) %.6f (+- %.6f) %% ;; (%.3f min., "\
                "%.3f sec / batch)" % (epoch_idx + 1, eval_loss_ce, eval_loss_ce_std, \
                    eval_loss_err, eval_loss_err_std, total / 60.0, total / iter_count))
            if (eval_loss_ce+eval_loss_ce_std) <= (min_eval_loss_ce+min_eval_loss_ce_std) \
                or (eval_loss_ce <= min_eval_loss_ce):
                min_eval_loss_ce = eval_loss_ce
                min_eval_loss_ce_std = eval_loss_ce_std
                min_eval_loss_err = eval_loss_err
                min_eval_loss_err_std = eval_loss_err_std
                min_idx = epoch_idx
                change_min_flag = True
            if change_min_flag:
                logging.info("min_eval_loss = %.6f (+- %.6f) %.6f (+- %.6f) %% min_idx=%d" % (min_eval_loss_ce, \
                    min_eval_loss_ce_std, min_eval_loss_err, min_eval_loss_err_std, min_idx+1))
            #if ((epoch_idx + 1) % args.save_interval_epoch == 0) or (epoch_min_flag):
            #    logging.info('save epoch:%d' % (epoch_idx+1))
            #    save_checkpoint(args.expdir, model_waveform, optimizer, numpy_random_state, torch_random_state, epoch_idx + 1)
            logging.info('save epoch:%d' % (epoch_idx+1))
            save_checkpoint(args.expdir, model_waveform, optimizer, numpy_random_state, torch_random_state, epoch_idx + 1)
            total = 0
            iter_count = 0
            loss_ce = []
            loss_err = []
            epoch_idx += 1
            np.random.set_state(numpy_random_state)
            torch.set_rng_state(torch_random_state)
            model_waveform.train()
            for param in model_waveform.parameters():
                param.requires_grad = True
            for param in model_waveform.scale_in.parameters():
                param.requires_grad = False
            if args.lpc > 0:
                for param in model_waveform.logits.parameters():
                    param.requires_grad = False
            # start next epoch
            if epoch_idx < args.epoch_count:
                start = time.time()
                logging.info("==%d EPOCH==" % (epoch_idx+1))
                logging.info("Training data")
                batch_x, batch_feat, c_idx, utt_idx, featfile, x_bs, f_bs, x_ss, f_ss, n_batch_utt, \
                    del_index_utt, max_slen, idx_select, idx_select_full, slens_acc = next(generator)
        # feedforward and backpropagate current batch
        if epoch_idx < args.epoch_count:
            logging.info("%d iteration [%d]" % (iter_idx+1, epoch_idx+1))

            x_es = x_ss+x_bs
            f_es = f_ss+f_bs
            logging.info(f'{x_ss} {x_bs} {x_es} {f_ss} {f_bs} {f_es} {max_slen}')
            if x_ss > 0:
                if x_es <= max_slen:
                    batch_x_prev = batch_x[:,x_ss-1:x_es-1]
                    if args.lpc > 0:
                        if x_ss-args.lpc >= 0:
                            batch_x_lpc = batch_x[:,x_ss-args.lpc:x_es-1]
                        else:
                            batch_x_lpc = F.pad(batch_x[:,:x_es-1], (-(x_ss-args.lpc), 0), "constant", args.n_quantize // 2)
                    batch_feat = batch_feat[:,f_ss:f_es]
                    batch_x = batch_x[:,x_ss:x_es]
                else:
                    batch_x_prev = batch_x[:,x_ss-1:-1]
                    if args.lpc > 0:
                        if x_ss-args.lpc >= 0:
                            batch_x_lpc = batch_x[:,x_ss-args.lpc:-1]
                        else:
                            batch_x_lpc = F.pad(batch_x[:,:-1], (-(x_ss-args.lpc), 0), "constant", args.n_quantize // 2)
                    batch_feat = batch_feat[:,f_ss:]
                    batch_x = batch_x[:,x_ss:]
            else:
                batch_x_prev = F.pad(batch_x[:,:x_es-1], (1, 0), "constant", args.n_quantize // 2)
                if args.lpc > 0:
                    batch_x_lpc = F.pad(batch_x[:,:x_es-1], (args.lpc, 0), "constant", args.n_quantize // 2)
                batch_feat = batch_feat[:,:f_es]
                batch_x = batch_x[:,:x_es]
            #assert((batch_x_prev[:,1:] == batch_x[:,:-1]).all())

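            # Carry the GRU hidden states (h_x, h_x_2) across truncated segments of the
            # same utterances; states of utterances that already ended (del_index_utt)
            # are dropped before the next forward pass.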
            if f_ss > 0:
                if len(del_index_utt) > 0:
                    h_x = torch.FloatTensor(np.delete(h_x.cpu().data.numpy(), del_index_utt, axis=1)).to(device)
                    h_x_2 = torch.FloatTensor(np.delete(h_x_2.cpu().data.numpy(), del_index_utt, axis=1)).to(device)
                if args.lpc > 0:
                    batch_x_output, h_x, h_x_2 = model_waveform(batch_feat, batch_x_prev, h=h_x, h_2=h_x_2, x_lpc=batch_x_lpc, do=True)
                else:
                    batch_x_output, h_x, h_x_2 = model_waveform(batch_feat, batch_x_prev, h=h_x, h_2=h_x_2, do=True)
            else:
                if args.lpc > 0:
                    batch_x_output, h_x, h_x_2 = model_waveform(batch_feat, batch_x_prev, x_lpc=batch_x_lpc, do=True)
                else:
                    batch_x_output, h_x, h_x_2 = model_waveform(batch_feat, batch_x_prev, do=True)

            # samples check
            #with torch.no_grad():
            i = np.random.randint(0, batch_x_output.shape[0])
            logging.info("%s" % (os.path.join(os.path.basename(os.path.dirname(featfile[i])),os.path.basename(featfile[i]))))
            #    check_samples = batch_x[i,5:10].long()
            #    logging.info(torch.index_select(F.softmax(batch_x_output[i,5:10], dim=-1), 1, check_samples))
            #    logging.info(check_samples)

            # handle short ending
            batch_loss = 0
            if len(idx_select) > 0:
                logging.info('len_idx_select: '+str(len(idx_select)))
                batch_loss_ce_select = 0
                batch_loss_err_select = 0
                for j in range(len(idx_select)):
                    k = idx_select[j]
                    slens_utt = slens_acc[k]
                    logging.info('%s %d' % (featfile[k], slens_utt))
                    batch_x_output_ = batch_x_output[k,:slens_utt]
                    batch_x_ = batch_x[k,:slens_utt]
                    batch_loss_ce_select += torch.mean(criterion_ce(batch_x_output_, batch_x_))
                    batch_loss_err_select += torch.mean(torch.sum(100*criterion_l1(F.softmax(batch_x_output_, dim=-1), F.one_hot(batch_x_, num_classes=args.n_quantize).float()), -1))
                batch_loss += batch_loss_ce_select
                batch_loss_ce_select /= len(idx_select)
                batch_loss_err_select /= len(idx_select)
                total_train_loss["train/loss_ce"].append(batch_loss_ce_select.item())
                total_train_loss["train/loss_err"].append(batch_loss_err_select.item())
                loss_ce.append(batch_loss_ce_select.item())
                loss_err.append(batch_loss_err_select.item())
                if len(idx_select_full) > 0:
                    logging.info('len_idx_select_full: '+str(len(idx_select_full)))
                    batch_x = torch.index_select(batch_x,0,idx_select_full)
                    batch_x_output = torch.index_select(batch_x_output,0,idx_select_full)
                #elif len(idx_select) > 1:
                else:
                    optimizer.zero_grad()
                    batch_loss.backward()
                    #for name, param in model_waveform.named_parameters():
                    #    if param.requires_grad:
                    #        logging.info(f"{name} {param.grad.norm()}")
                    flag = False
                    for name, param in model_waveform.named_parameters():
                        if param.requires_grad:
                            grad_norm = param.grad.norm()
                    #        logging.info(f"{name} {grad_norm}")
                            #if grad_norm >= 1e4 or torch.isnan(grad_norm) or torch.isinf(grad_norm):
                            if torch.isnan(grad_norm) or torch.isinf(grad_norm):
                                flag = True
                    if flag:
                        logging.info("explode grad")
                        optimizer.zero_grad()
                        continue
                    torch.nn.utils.clip_grad_norm_(model_waveform.parameters(), 10)
                    #for name, param in model_waveform.named_parameters():
                    #    if param.requires_grad:
                    #        logging.info(f"{name} {param.grad.norm()}")
                    optimizer.step()

                    with torch.no_grad():
                        #test = model_waveform.gru.weight_hh_l0.data.clone()
                        #sparsify = lpcnet.Sparsify(2000, 40000, 400, (0.05, 0.05, 0.2))
                        #t_start, t_end, interval, densities
                        if idx_stage < args.n_stage-1 and iter_idx + 1 == t_starts[idx_stage+1]:
                            idx_stage += 1
                        if idx_stage > 0:
                            sparsify(model_waveform, iter_idx + 1, t_starts[idx_stage], t_ends[idx_stage], args.interval, densities[idx_stage], densities_p=densities[idx_stage-1])
                        else:
                            sparsify(model_waveform, iter_idx + 1, t_starts[idx_stage], t_ends[idx_stage], args.interval, densities[idx_stage])
                        #logging.info((test==model_waveform.gru.weight_hh_l0).all())

                    logging.info("batch loss select %.3f %.3f (%.3f sec)" % (batch_loss_ce_select.item(), \
                        batch_loss_err_select.item(), time.time() - start))
                    iter_idx += 1
                    if iter_idx % args.save_interval_iter == 0:
                        logging.info('save iter:%d' % (iter_idx))
                        save_checkpoint(args.expdir, model_waveform, optimizer, np.random.get_state(), torch.get_rng_state(), iter_idx)
                    iter_count += 1
                    if iter_idx % args.log_interval_steps == 0:
                        logging.info('smt')
                        for key in total_train_loss.keys():
                            total_train_loss[key] = np.mean(total_train_loss[key])
                            logging.info(f"(Steps: {iter_idx}) {key} = {total_train_loss[key]:.4f}.")
                        write_to_tensorboard(writer, iter_idx, total_train_loss)
                        total_train_loss = defaultdict(list)
                    total += time.time() - start
                    continue
                #else:
                #    continue

            # loss
            batch_loss_ce_ = torch.mean(criterion_ce(batch_x_output.reshape(-1, args.n_quantize), batch_x.reshape(-1)).reshape(batch_x_output.shape[0], -1), -1)
            batch_loss_err_ = torch.mean(torch.sum(100*criterion_l1(F.softmax(batch_x_output, dim=-1), F.one_hot(batch_x, num_classes=args.n_quantize).float()), -1), -1)

            batch_loss_ce = batch_loss_ce_.mean()
            batch_loss_err = batch_loss_err_.mean()
            total_train_loss["train/loss_ce"].append(batch_loss_ce.item())
            total_train_loss["train/loss_err"].append(batch_loss_err.item())
            loss_ce.append(batch_loss_ce.item())
            loss_err.append(batch_loss_err.item())

            batch_loss += batch_loss_ce_.sum()

            optimizer.zero_grad()
            batch_loss.backward()
            #for name, param in model_waveform.named_parameters():
            #    if param.requires_grad:
            #        logging.info(f"{name} {param.grad.norm()}")
            flag = False
            for name, param in model_waveform.named_parameters():
                if param.requires_grad:
                    grad_norm = param.grad.norm()
            #        logging.info(f"{name} {grad_norm}")
                    #if grad_norm >= 1e4 or torch.isnan(grad_norm) or torch.isinf(grad_norm):
                    if torch.isnan(grad_norm) or torch.isinf(grad_norm):
                        flag = True
            if flag:
                logging.info("explode grad")
                optimizer.zero_grad()
                continue
            torch.nn.utils.clip_grad_norm_(model_waveform.parameters(), 10)
            #for name, param in model_waveform.named_parameters():
            #    if param.requires_grad:
            #        logging.info(f"{name} {param.grad.norm()}")
            optimizer.step()

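            # After each optimizer step, prune the GRU recurrent weights toward the
            # target densities of the current sparsification stage.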
            with torch.no_grad():
                #test = model_waveform.gru.weight_hh_l0.data.clone()
                #sparsify = lpcnet.Sparsify(2000, 40000, 400, (0.05, 0.05, 0.2))
                #t_start, t_end, interval, densities
                if idx_stage < args.n_stage-1 and iter_idx + 1 == t_starts[idx_stage+1]:
                    idx_stage += 1
                if idx_stage > 0:
                    sparsify(model_waveform, iter_idx + 1, t_starts[idx_stage], t_ends[idx_stage], args.interval, densities[idx_stage], densities_p=densities[idx_stage-1])
                else:
                    sparsify(model_waveform, iter_idx + 1, t_starts[idx_stage], t_ends[idx_stage], args.interval, densities[idx_stage])
                #logging.info((test==model_waveform.gru.weight_hh_l0).all())

            logging.info("batch loss [%d] %d %d %d %d %d : %.3f %.3f %% (%.3f sec)" % (c_idx+1, max_slen, x_ss, x_bs, \
                f_ss, f_bs, batch_loss_ce.item(), batch_loss_err.item(), time.time() - start))
            iter_idx += 1
            if iter_idx % args.save_interval_iter == 0:
                logging.info('save iter:%d' % (iter_idx))
                save_checkpoint(args.expdir, model_waveform, optimizer, np.random.get_state(), torch.get_rng_state(), iter_idx)
            iter_count += 1
            if iter_idx % args.log_interval_steps == 0:
                logging.info('smt')
                for key in total_train_loss.keys():
                    total_train_loss[key] = np.mean(total_train_loss[key])
                    logging.info(f"(Steps: {iter_idx}) {key} = {total_train_loss[key]:.4f}.")
                write_to_tensorboard(writer, iter_idx, total_train_loss)
                total_train_loss = defaultdict(list)
            total += time.time() - start


    # save final model
    model_waveform.cpu()
    torch.save({"model_waveform": model_waveform.state_dict()}, args.expdir + "/checkpoint-final.pkl")
    logging.info("final checkpoint created.")
Example #2
    optimizer = RAdam(model.parameters(), args.lr)

if args.loss == 'mse':
    loss_fn = torch.nn.MSELoss()
else:
    loss_fn = torch.nn.BCELoss()

# loading checkpoint
if args.load_path:
    files = os.listdir(args.load_path)
    files = sorted(files, key=lambda x: int(os.path.splitext(x)[0]))
    last_path = os.path.join(args.load_path, files[-1])

    checkpoint = torch.load(last_path)
    model.load_state_dict(checkpoint['model'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    loss = checkpoint['loss']

    
if args.grid_latent:
    walk_grid(model)
    os._exit(0)



dataset = CustomDataset(args)
dataset_loader = torch.utils.data.DataLoader(dataset=dataset, num_workers=args.num_workers, batch_size=args.batch_size, shuffle=args.shuffle, collate_fn=collate_fn)

for epoch in range(1, args.epoch):
    epoch_loss_rec = []
    epoch_loss_kl = []
Example #3
def main():
    parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
    parser.add_argument('data', metavar='DIR', help='path to dataset')
    parser.add_argument('--batch-size',
                        type=int,
                        default=128,
                        metavar='N',
                        help='input batch size for training (default: 128)')
    parser.add_argument('--batch-size-val',
                        type=int,
                        default=64,
                        metavar='N',
                        help='input batch size for validation (default: 64)')
    parser.add_argument('--epochs',
                        type=int,
                        default=1000,
                        metavar='N',
                        help='number of epochs to train (default: 1000)')
    parser.add_argument('--lr',
                        type=float,
                        default=1e-4,
                        metavar='LR',
                        help='learning rate (default: 1e-4)')
    parser.add_argument('--image-size',
                        type=float,
                        default=80,
                        metavar='IMSIZE',
                        help='input image size (default: 80)')
    parser.add_argument('--no-cuda',
                        action='store_true',
                        default=False,
                        help='disables CUDA training')
    parser.add_argument('--multi-gpu',
                        action='store_true',
                        default=False,
                        help='parallel training on multiple GPUs')
    parser.add_argument('--seed',
                        type=int,
                        default=1,
                        metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument('--save-model',
                        action='store_true',
                        default=False,
                        help='save the current model')
    parser.add_argument('--model-save-path',
                        default='',
                        type=str,
                        metavar='PATH',
                        help='path for saving the current model')
    parser.add_argument('--resume',
                        default='',
                        type=str,
                        metavar='PATH',
                        help='path to latest checkpoint (default: none)')
    args = parser.parse_args()

    torch.set_default_tensor_type('torch.FloatTensor')

    device = torch.device("cpu" if args.no_cuda else "cuda")

    train_data = dataset(args.data,
                         "train",
                         args.image_size,
                         transform=transforms.Compose([ToTensor()]),
                         shuffle=True)
    valid_data = dataset(args.data,
                         "val",
                         args.image_size,
                         transform=transforms.Compose([ToTensor()]))

    trainloader = DataLoader(train_data,
                             batch_size=args.batch_size,
                             shuffle=True,
                             num_workers=8)
    validloader = DataLoader(valid_data,
                             batch_size=args.batch_size_val,
                             shuffle=False,
                             num_workers=8)

    model = Model(args.image_size, args.image_size)

    optimizer = RAdam(model.parameters(), lr=args.lr, weight_decay=1e-8)

    if not args.no_cuda:
        model.cuda()

    if torch.cuda.device_count() > 1 and args.multi_gpu:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        model = torch.nn.DataParallel(model)

    if args.resume:
        model.load_state_dict(
            torch.load(os.path.join(args.resume, model_save_name)))
        optimizer.load_state_dict(
            torch.load(os.path.join(args.resume, optimizer_save_name)))

    train(model, optimizer, trainloader, validloader, device, args)
Example #4
def main():

    global args
    best_prec1, best_epoch = 0.0, 0

    if not os.path.exists(args.save):
        os.makedirs(args.save)

    if args.data.startswith('cifar'):
        IM_SIZE = 32
    else:
        IM_SIZE = 224

    model = getattr(models, args.arch)(args)
    n_flops, n_params = measure_model(model, IM_SIZE, IM_SIZE)
    torch.save(n_flops, os.path.join(args.save, 'flops.pth'))
    del model

    model = getattr(models, args.arch)(args)

    if args.arch.startswith('alexnet') or args.arch.startswith('vgg'):
        model.features = torch.nn.DataParallel(model.features)
        model.cuda()
    else:
        model = torch.nn.DataParallel(model).cuda()

    criterion = nn.CrossEntropyLoss().cuda()

    if args.optimizer == 'sgd':
        optimizer = torch.optim.SGD(model.parameters(), args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)
    elif args.optimizer == 'adam':
        optimizer = torch.optim.Adam(model.parameters(), args.lr,
                                     weight_decay=args.weight_decay)
    elif args.optimizer == 'radam':
        from radam import RAdam
        optimizer = RAdam(model.parameters(), args.lr,
                          weight_decay=args.weight_decay)
    else:
        raise NotImplementedError("Wrong optimizer.")
    

    if args.resume:
        checkpoint = load_checkpoint(args)
        if checkpoint is not None:
            args.start_epoch = checkpoint['epoch'] + 1
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])

    cudnn.benchmark = True

    train_loader, val_loader, test_loader = get_dataloaders(args)

    if args.evalmode is not None:
        state_dict = torch.load(args.evaluate_from)['state_dict']
        model.load_state_dict(state_dict)

        if args.evalmode == 'anytime':
            validate(test_loader, model, criterion)
        else:
            dynamic_evaluate(model, test_loader, val_loader, args)
        return

    scores = ['epoch\tlr\ttrain_loss\tval_loss\ttrain_prec1'
              '\tval_prec1\ttrain_prec5\tval_prec5']

    for epoch in range(args.start_epoch, args.epochs):

        train_loss, train_prec1, train_prec5, lr = train(train_loader, model, criterion, optimizer, epoch)

        val_loss, val_prec1, val_prec5 = validate(val_loader, model, criterion)

        scores.append(('{}\t{:.3f}' + '\t{:.4f}' * 6)
                      .format(epoch, lr, train_loss, val_loss,
                              train_prec1, val_prec1, train_prec5, val_prec5))

        is_best = val_prec1 > best_prec1
        if is_best:
            best_prec1 = val_prec1
            best_epoch = epoch
            print('Best val_prec1 {}'.format(best_prec1))

        model_filename = 'checkpoint_%03d.pth.tar' % epoch
        save_checkpoint({
            'epoch': epoch,
            'arch': args.arch,
            'state_dict': model.state_dict(),
            'best_prec1': best_prec1,
            'optimizer': optimizer.state_dict(),
        }, args, is_best, model_filename, scores)

    print('Best val_prec1: {:.4f} at epoch {}'.format(best_prec1, best_epoch))

    ### Test the final model

    print('********** Final prediction results **********')
    validate(test_loader, model, criterion)

    return 
Example #5
class Trainer():
    def __init__(self, log_dir, cfg):

        self.path = log_dir
        self.cfg = cfg

        if cfg.TRAIN.FLAG:
            self.model_dir = os.path.join(self.path, 'Model')
            self.log_dir = os.path.join(self.path, 'Log')
            mkdir_p(self.model_dir)
            mkdir_p(self.log_dir)
            self.writer = SummaryWriter(log_dir=self.log_dir)
            self.logfile = os.path.join(self.path, "logfile.log")
            sys.stdout = Logger(logfile=self.logfile)

        self.data_dir = cfg.DATASET.DATA_DIR
        self.max_epochs = cfg.TRAIN.MAX_EPOCHS
        self.snapshot_interval = cfg.TRAIN.SNAPSHOT_INTERVAL

        s_gpus = cfg.GPU_ID.split(',')
        self.gpus = [int(ix) for ix in s_gpus]
        self.num_gpus = len(self.gpus)

        self.batch_size = cfg.TRAIN.BATCH_SIZE
        self.lr = cfg.TRAIN.LEARNING_RATE

        torch.cuda.set_device(self.gpus[0])
        cudnn.benchmark = True

        sample = cfg.SAMPLE
        self.dataset = []
        self.dataloader = []
        self.use_feats = cfg.model.use_feats
        eval_split = cfg.EVAL if cfg.EVAL else 'val'
        train_split = cfg.DATASET.train_split
        if cfg.DATASET.DATASET == 'clevr':
            clevr_collate_fn = collate_fn
            cogent = cfg.DATASET.COGENT
            if cogent:
                print(f'Using CoGenT {cogent.upper()}')

            if cfg.TRAIN.FLAG:
                self.dataset = ClevrDataset(data_dir=self.data_dir,
                                            split=train_split + cogent,
                                            sample=sample,
                                            **cfg.DATASET.params)
                self.dataloader = DataLoader(dataset=self.dataset,
                                             batch_size=cfg.TRAIN.BATCH_SIZE,
                                             shuffle=True,
                                             num_workers=cfg.WORKERS,
                                             drop_last=True,
                                             collate_fn=clevr_collate_fn)

            self.dataset_val = ClevrDataset(data_dir=self.data_dir,
                                            split=eval_split + cogent,
                                            sample=sample,
                                            **cfg.DATASET.params)
            self.dataloader_val = DataLoader(dataset=self.dataset_val,
                                             batch_size=cfg.TEST_BATCH_SIZE,
                                             drop_last=False,
                                             shuffle=False,
                                             num_workers=cfg.WORKERS,
                                             collate_fn=clevr_collate_fn)

        elif cfg.DATASET.DATASET == 'gqa':
            if self.use_feats == 'spatial':
                gqa_collate_fn = collate_fn_gqa
            elif self.use_feats == 'objects':
                gqa_collate_fn = collate_fn_gqa_objs
            if cfg.TRAIN.FLAG:
                self.dataset = GQADataset(data_dir=self.data_dir,
                                          split=train_split,
                                          sample=sample,
                                          use_feats=self.use_feats,
                                          **cfg.DATASET.params)
                self.dataloader = DataLoader(dataset=self.dataset,
                                             batch_size=cfg.TRAIN.BATCH_SIZE,
                                             shuffle=True,
                                             num_workers=cfg.WORKERS,
                                             drop_last=True,
                                             collate_fn=gqa_collate_fn)

            self.dataset_val = GQADataset(data_dir=self.data_dir,
                                          split=eval_split,
                                          sample=sample,
                                          use_feats=self.use_feats,
                                          **cfg.DATASET.params)
            self.dataloader_val = DataLoader(dataset=self.dataset_val,
                                             batch_size=cfg.TEST_BATCH_SIZE,
                                             shuffle=False,
                                             num_workers=cfg.WORKERS,
                                             drop_last=False,
                                             collate_fn=gqa_collate_fn)

        # load model
        self.vocab = load_vocab(cfg)
        self.model, self.model_ema = mac.load_MAC(cfg, self.vocab)

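        # alpha=0 copies the current model weights into model_ema so both start identical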
        self.weight_moving_average(alpha=0)
        if cfg.TRAIN.RADAM:
            self.optimizer = RAdam(self.model.parameters(), lr=self.lr)
        else:
            self.optimizer = optim.Adam(self.model.parameters(), lr=self.lr)

        self.start_epoch = 0
        if cfg.resume_model:
            location = 'cuda' if cfg.CUDA else 'cpu'
            state = torch.load(cfg.resume_model, map_location=location)
            self.model.load_state_dict(state['model'])
            self.optimizer.load_state_dict(state['optim'])
            self.start_epoch = state['iter'] + 1
            state = torch.load(cfg.resume_model_ema, map_location=location)
            self.model_ema.load_state_dict(state['model'])

        if cfg.start_epoch is not None:
            self.start_epoch = cfg.start_epoch

        self.previous_best_acc = 0.0
        self.previous_best_epoch = 0
        self.previous_best_loss = 100
        self.previous_best_loss_epoch = 0

        self.total_epoch_loss = 0
        self.prior_epoch_loss = 10

        self.print_info()
        self.loss_fn = torch.nn.CrossEntropyLoss().cuda()

        self.comet_exp = Experiment(
            project_name=cfg.COMET_PROJECT_NAME,
            api_key=os.getenv('COMET_API_KEY'),
            workspace=os.getenv('COMET_WORKSPACE'),
            disabled=cfg.logcomet is False,
        )
        if cfg.logcomet:
            exp_name = cfg_to_exp_name(cfg)
            print(exp_name)
            self.comet_exp.set_name(exp_name)
            self.comet_exp.log_parameters(flatten_json_iterative_solution(cfg))
            self.comet_exp.log_asset(self.logfile)
            self.comet_exp.log_asset_data(json.dumps(cfg, indent=4),
                                          file_name='cfg.json')
            self.comet_exp.set_model_graph(str(self.model))
            if cfg.cfg_file:
                self.comet_exp.log_asset(cfg.cfg_file)

        with open(os.path.join(self.path, 'cfg.json'), 'w') as f:
            json.dump(cfg, f, indent=4)

    def print_info(self):
        print('Using config:')
        pprint.pprint(self.cfg)
        print("\n")

        pprint.pprint("Size of train dataset: {}".format(len(self.dataset)))
        # print("\n")
        pprint.pprint("Size of val dataset: {}".format(len(self.dataset_val)))
        print("\n")

        print("Using MAC-Model:")
        pprint.pprint(self.model)
        print("\n")

    def weight_moving_average(self, alpha=0.999):
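        # Exponential moving average of the weights: param_ema = alpha * param_ema + (1 - alpha) * param.
        # Called once with alpha=0 at init so the EMA model starts as an exact copy of the model.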
        for param1, param2 in zip(self.model_ema.parameters(),
                                  self.model.parameters()):
            param1.data *= alpha
            param1.data += (1.0 - alpha) * param2.data

    def set_mode(self, mode="train"):
        if mode == "train":
            self.model.train()
            self.model_ema.train()
        else:
            self.model.eval()
            self.model_ema.eval()

    def reduce_lr(self):
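        # Heuristic LR schedule: halve the learning rate when the epoch-over-epoch loss
        # improvement falls below a threshold; the thresholds and lr floors tighten as the
        # prior epoch loss gets smaller.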
        epoch_loss = self.total_epoch_loss  # / float(len(self.dataset) // self.batch_size)
        lossDiff = self.prior_epoch_loss - epoch_loss
        if ((lossDiff < 0.015 and self.prior_epoch_loss < 0.5 and self.lr > 0.00002) or \
            (lossDiff < 0.008 and self.prior_epoch_loss < 0.15 and self.lr > 0.00001) or \
            (lossDiff < 0.003 and self.prior_epoch_loss < 0.10 and self.lr > 0.000005)):
            self.lr *= 0.5
            print("Reduced learning rate to {}".format(self.lr))
            for param_group in self.optimizer.param_groups:
                param_group['lr'] = self.lr
        self.prior_epoch_loss = epoch_loss
        self.total_epoch_loss = 0

    def save_models(self, iteration):
        save_model(self.model,
                   self.optimizer,
                   iteration,
                   self.model_dir,
                   model_name="model")
        save_model(self.model_ema,
                   None,
                   iteration,
                   self.model_dir,
                   model_name="model_ema")

    def train_epoch(self, epoch):
        cfg = self.cfg
        total_loss = 0.
        total_correct = 0
        total_samples = 0

        self.labeled_data = iter(self.dataloader)
        self.set_mode("train")

        dataset = tqdm(self.labeled_data, total=len(self.dataloader), ncols=20)

        for data in dataset:
            ######################################################
            # (1) Prepare training data
            ######################################################
            image, question, question_len, answer = data['image'], data[
                'question'], data['question_length'], data['answer']
            answer = answer.long()
            question = Variable(question)
            answer = Variable(answer)

            if cfg.CUDA:
                if self.use_feats == 'spatial':
                    image = image.cuda()
                elif self.use_feats == 'objects':
                    image = [e.cuda() for e in image]
                question = question.cuda()
                answer = answer.cuda().squeeze()
            else:
                answer = answer.squeeze()

            ############################
            # (2) Train Model
            ############################
            self.optimizer.zero_grad()

            scores = self.model(image, question, question_len)
            loss = self.loss_fn(scores, answer)
            loss.backward()

            if self.cfg.TRAIN.CLIP_GRADS:
                torch.nn.utils.clip_grad_norm_(self.model.parameters(),
                                               self.cfg.TRAIN.CLIP)

            self.optimizer.step()
            self.weight_moving_average()

            ############################
            # (3) Log Progress
            ############################
            correct = scores.detach().argmax(1) == answer
            total_correct += correct.sum().cpu().item()
            total_loss += loss.item() * answer.size(0)
            total_samples += answer.size(0)

            avg_loss = total_loss / total_samples
            train_accuracy = total_correct / total_samples
            # accuracy = correct.sum().cpu().numpy() / answer.shape[0]

            # if avg_loss == 0:
            #     avg_loss = loss.item()
            #     train_accuracy = accuracy
            # else:
            #     avg_loss = 0.99 * avg_loss + 0.01 * loss.item()
            #     train_accuracy = 0.99 * train_accuracy + 0.01 * accuracy
            # self.total_epoch_loss += loss.item() * answer.size(0)

            dataset.set_description(
                'Epoch: {}; Avg Loss: {:.5f}; Avg Train Acc: {:.5f}'.format(
                    epoch + 1, avg_loss, train_accuracy))

        self.total_epoch_loss = avg_loss

        results = {
            "loss": avg_loss,
            "accuracy": train_accuracy,
            "avg_loss": avg_loss,  # for Comet
            "avg_accuracy": train_accuracy,  # for Comet
        }
        return results

    def train(self):
        cfg = self.cfg
        print("Start Training")
        for epoch in range(self.start_epoch, self.max_epochs):

            with self.comet_exp.train():
                results = self.train_epoch(epoch)
                self.reduce_lr()
                results['epoch'] = epoch + 1
                results['lr'] = self.lr
                self.comet_exp.log_metrics(
                    results,
                    epoch=epoch + 1,
                )

            with self.comet_exp.validate():
                results = self.log_results(epoch, results)
                results['epoch'] = epoch + 1
                results['lr'] = self.lr
                self.comet_exp.log_metrics(
                    results,
                    epoch=epoch + 1,
                )

            if cfg.TRAIN.EALRY_STOPPING:
                if epoch - cfg.TRAIN.PATIENCE == self.previous_best_epoch:
                    # if epoch - cfg.TRAIN.PATIENCE == self.previous_best_loss_epoch:
                    print('Early stop')
                    break

        self.comet_exp.log_asset(self.logfile)
        self.save_models(self.max_epochs)
        self.writer.close()
        print("Finished Training")
        print(
            f"Highest validation accuracy: {self.previous_best_acc} at epoch {self.previous_best_epoch}"
        )

    def log_results(self, epoch, results, max_eval_samples=None):
        epoch += 1
        self.writer.add_scalar("avg_loss", results["loss"], epoch)
        self.writer.add_scalar("train_accuracy", results["accuracy"], epoch)

        metrics = self.calc_accuracy("validation",
                                     max_samples=max_eval_samples)
        self.writer.add_scalar("val_accuracy_ema", metrics['acc_ema'], epoch)
        self.writer.add_scalar("val_accuracy", metrics['acc'], epoch)
        self.writer.add_scalar("val_loss_ema", metrics['loss_ema'], epoch)
        self.writer.add_scalar("val_loss", metrics['loss'], epoch)

        print(
            "Epoch: {epoch}\tVal Acc: {acc},\tVal Acc EMA: {acc_ema},\tAvg Loss: {loss},\tAvg Loss EMA: {loss_ema},\tLR: {lr}"
            .format(epoch=epoch, lr=self.lr, **metrics))

        if metrics['acc'] > self.previous_best_acc:
            self.previous_best_acc = metrics['acc']
            self.previous_best_epoch = epoch
        if metrics['loss'] < self.previous_best_loss:
            self.previous_best_loss = metrics['loss']
            self.previous_best_loss_epoch = epoch

        if epoch % self.snapshot_interval == 0:
            self.save_models(epoch)

        return metrics

    def calc_accuracy(self, mode="train", max_samples=None):
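        # Evaluate both the raw model and its EMA copy on the selected loader and return
        # per-sample averaged accuracy and loss for each.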
        self.set_mode("validation")

        if mode == "train":
            loader = self.dataloader
        # elif (mode == "validation") or (mode == 'test'):
        #     loader = self.dataloader_val
        else:
            loader = self.dataloader_val

        total_correct = 0
        total_correct_ema = 0
        total_samples = 0
        total_loss = 0.
        total_loss_ema = 0.
        pbar = tqdm(loader, total=len(loader), desc=mode.upper(), ncols=20)
        for data in pbar:

            image, question, question_len, answer = data['image'], data[
                'question'], data['question_length'], data['answer']
            answer = answer.long()
            question = Variable(question)
            answer = Variable(answer)

            if self.cfg.CUDA:
                if self.use_feats == 'spatial':
                    image = image.cuda()
                elif self.use_feats == 'objects':
                    image = [e.cuda() for e in image]
                question = question.cuda()
                answer = answer.cuda().squeeze()

            with torch.no_grad():
                scores = self.model(image, question, question_len)
                scores_ema = self.model_ema(image, question, question_len)

                loss = self.loss_fn(scores, answer)
                loss_ema = self.loss_fn(scores_ema, answer)

            correct = scores.detach().argmax(1) == answer
            correct_ema = scores_ema.detach().argmax(1) == answer

            total_correct += correct.sum().cpu().item()
            total_correct_ema += correct_ema.sum().cpu().item()

            total_loss += loss.item() * answer.size(0)
            total_loss_ema += loss_ema.item() * answer.size(0)

            total_samples += answer.size(0)

            avg_acc = total_correct / total_samples
            avg_acc_ema = total_correct_ema / total_samples
            avg_loss = total_loss / total_samples
            avg_loss_ema = total_loss_ema / total_samples

            pbar.set_postfix({
                'Acc': f'{avg_acc:.5f}',
                'Acc Ema': f'{avg_acc_ema:.5f}',
                'Loss': f'{avg_loss:.5f}',
                'Loss Ema': f'{avg_loss_ema:.5f}',
            })

        return dict(acc=avg_acc,
                    acc_ema=avg_acc_ema,
                    loss=avg_loss,
                    loss_ema=avg_loss_ema)
Example #6
def train(opt):
    """ dataset preparation """
    if not opt.data_filtering_off:
        print('Filtering the images containing characters which are not in opt.character')
        print('Filtering the images whose label is longer than opt.batch_max_length')
        # see https://github.com/clovaai/deep-text-recognition-benchmark/blob/6593928855fb7abb999a99f428b3e4477d4ae356/dataset.py#L130

    opt.select_data = opt.select_data.split('-')
    opt.batch_ratio = opt.batch_ratio.split('-')
    train_dataset = Batch_Balanced_Dataset(opt)

    log = open(f'./saved_models/{opt.experiment_name}/log_dataset.txt', 'a')
    AlignCollate_valid = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD)
    valid_dataset, valid_dataset_log = hierarchical_dataset(root=opt.valid_data, opt=opt)
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset, batch_size=opt.batch_size,
        shuffle=True,  # 'True' to check training progress with validation function.
        num_workers=int(opt.workers),
        collate_fn=AlignCollate_valid, pin_memory=True)
    log.write(valid_dataset_log)
    print('-' * 80)
    log.write('-' * 80 + '\n')
    log.close()
    
    """ model configuration """
    if 'CTC' in opt.Prediction:
        converter = CTCLabelConverter(opt.character)
    elif opt.Prediction == 'None':
        converter = TransformerConverter(opt.character)
    else:
        converter = AttnLabelConverter(opt.character)
    opt.num_class = len(converter.character)

    if opt.rgb:
        opt.input_channel = 3
    model = Model(opt)
    print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial, opt.input_channel, opt.output_channel,
          opt.hidden_size, opt.num_class, opt.batch_max_length, opt.Transformation, opt.FeatureExtraction,
          opt.SequenceModeling, opt.Prediction)

    # weight initialization
    for name, param in model.named_parameters():
        if 'localization_fc2' in name:
            print(f'Skip {name} as it is already initialized')
            continue
        try:
            if 'bias' in name:
                init.constant_(param, 0.0)
            elif 'weight' in name:
                init.kaiming_normal_(param)
        except Exception:  # for batchnorm.
            if 'weight' in name:
                param.data.fill_(1)
            continue

    # data parallel for multi-GPU
    # model = torch.nn.DataParallel(model).to(device)
    model = model.to(device)
    model.train()
    if opt.load_from_checkpoint:
        model.load_state_dict(torch.load(os.path.join(opt.load_from_checkpoint, 'checkpoint.pth')))
        print(f'loaded checkpoint from {opt.load_from_checkpoint}...')
    elif opt.saved_model != '':
        print(f'loading pretrained model from {opt.saved_model}')
        if opt.SequenceModeling == 'Transformer':
            fe_state = OrderedDict()
            state_dict = torch.load(opt.saved_model)
            for k, v in state_dict.items():
                if k.startswith('module.FeatureExtraction'):
                    new_k = re.sub('module.FeatureExtraction.', '', k)
                    fe_state[new_k] = state_dict[k]
            model.FeatureExtraction.load_state_dict(fe_state)
        else:
            if opt.FT:
                model.load_state_dict(torch.load(opt.saved_model), strict=False)
            else:
                model.load_state_dict(torch.load(opt.saved_model))
    if opt.freeze_fe:
        model.freeze(['FeatureExtraction'])
    print("Model:")
    print(model)

    """ setup loss """
    if 'CTC' in opt.Prediction:
        criterion = torch.nn.CTCLoss(zero_infinity=True).to(device)
    elif opt.Prediction == 'None':
        criterion = LabelSmoothingLoss(classes=converter.n_classes, padding_idx=converter.pad_idx, smoothing=0.1)
        # criterion = torch.nn.CrossEntropyLoss(ignore_index=converter.pad_idx)
    else:
        criterion = torch.nn.CrossEntropyLoss(ignore_index=0).to(device)  # ignore [GO] token = ignore index 0
    # loss averager
    loss_avg = Averager()

    # keep only the parameters that require gradient descent
    filtered_parameters = []
    params_num = []
    for p in filter(lambda p: p.requires_grad, model.parameters()):
        filtered_parameters.append(p)
        params_num.append(np.prod(p.size()))
    print('Trainable params num : ', sum(params_num))
    # [print(name, p.numel()) for name, p in filter(lambda p: p[1].requires_grad, model.named_parameters())]

    # setup optimizer
    if opt.adam:
        assert opt.adam in ['Adam', 'AdamW', 'RAdam'], 'adam optimizer must be in Adam, AdamW or RAdam'
        if opt.adam == 'Adam':
            optimizer = optim.Adam(filtered_parameters, lr=opt.lr, betas=(opt.beta1, 0.999))
        elif opt.adam == "AdamW":
            optimizer = optim.AdamW(filtered_parameters, lr=opt.lr, betas=(opt.beta1, 0.999))
        else:
            optimizer = RAdam(filtered_parameters, lr=opt.lr, betas=(opt.beta1, 0.999))
    else:
        optimizer = optim.Adadelta(filtered_parameters, lr=opt.lr, rho=opt.rho, eps=opt.eps)
    print("Optimizer:")
    print(optimizer)

    if opt.load_from_checkpoint and opt.load_optimizer_state:
        optimizer.load_state_dict(torch.load(os.path.join(opt.load_from_checkpoint, 'optimizer.pth')))
        print(f'loaded optimizer state from {os.path.join(opt.load_from_checkpoint, "optimizer.pth")}')

    """ final options """
    # print(opt)
    with open(f'./saved_models/{opt.experiment_name}/opt.txt', 'a') as opt_file:
        opt_log = '------------ Options -------------\n'
        args = vars(opt)
        for k, v in args.items():
            opt_log += f'{str(k)}: {str(v)}\n'
        opt_log += '---------------------------------------\n'
        print(opt_log)
        opt_file.write(opt_log)

    """ start training """
    start_iter = 0
    if opt.saved_model != '':
        try:
            start_iter = int(opt.saved_model.split('_')[-1].split('.')[0])
            print(f'continue to train, start_iter: {start_iter}')
        except ValueError:
            pass

    if opt.load_from_checkpoint:
        with open(os.path.join(opt.load_from_checkpoint, 'iter.json'), mode='r', encoding='utf8') as f:
            start_iter = json.load(f)
            print(f'continue to train, start_iter: {start_iter}')

    start_time = time.time()
    best_accuracy = -1
    best_norm_ED = -1
    # i = start_iter

    bar = tqdm(range(start_iter, opt.num_iter))
    # while(True):
    for i in bar:
        bar.set_description(f'Iter {i}: train_loss = {loss_avg.val():.5f}')
        # train part
        image_tensors, labels = train_dataset.get_batch()
        image = image_tensors.to(device)
        text, length = converter.encode(labels, batch_max_length=opt.batch_max_length)
        batch_size = image.size(0)

        if 'CTC' in opt.Prediction:
            preds = model(image, text).log_softmax(2)
            preds_size = torch.IntTensor([preds.size(1)] * batch_size)
            preds = preds.permute(1, 0, 2)

            # (ctc_a) For PyTorch 1.2.0 and 1.3.0. To avoid ctc_loss issue, disabled cudnn for the computation of the ctc_loss
            # https://github.com/jpuigcerver/PyLaia/issues/16
            torch.backends.cudnn.enabled = False
            cost = criterion(preds, text.to(device), preds_size.to(device), length.to(device))
            torch.backends.cudnn.enabled = True

            # # (ctc_b) To reproduce our pretrained model / paper, use our previous code (below code) instead of (ctc_a).
            # # With PyTorch 1.2.0, the below code occurs NAN, so you may use PyTorch 1.1.0.
            # # Thus, the result of CTCLoss is different in PyTorch 1.1.0 and PyTorch 1.2.0.
            # # See https://github.com/clovaai/deep-text-recognition-benchmark/issues/56#issuecomment-526490707
            # cost = criterion(preds, text, preds_size, length)

        elif opt.Prediction == 'None':
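            # Transformer-style prediction (presumably teacher forcing): tgt_input is the
            # shifted decoder input, tgt_output is the prediction target, and the padding
            # mask keeps pad positions out of the decoder attention.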
            tgt_input = text['tgt_input']
            tgt_output = text['tgt_output']
            tgt_padding_mask = text['tgt_padding_mask']
            preds = model(image, tgt_input.transpose(0, 1), tgt_key_padding_mask=tgt_padding_mask,)
            cost = criterion(preds.view(-1, preds.shape[-1]), tgt_output.contiguous().view(-1))
        else:
            preds = model(image, text[:, :-1])  # align with Attention.forward
            target = text[:, 1:]  # without [GO] Symbol
            cost = criterion(preds.view(-1, preds.shape[-1]), target.contiguous().view(-1))

        model.zero_grad()
        cost.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), opt.grad_clip)  # gradient clipping with 5 (Default)
        optimizer.step()

        loss_avg.add(cost)

        # validation part
        if (i + 1) % opt.valInterval == 0:
            elapsed_time = time.time() - start_time
            # for log
            with open(f'./saved_models/{opt.experiment_name}/log_train.txt', 'a') as log:
                model.eval()
                with torch.no_grad():
                    valid_loss, current_accuracy, current_norm_ED, preds, confidence_score, labels, infer_time, length_of_data = validation(
                        model, criterion, valid_loader, converter, opt)
                model.train()

                # training loss and validation loss
                loss_log = f'[{i}/{opt.num_iter}] Train loss: {loss_avg.val():0.5f}, Valid loss: {valid_loss:0.5f}, Elapsed_time: {elapsed_time:0.5f}'
                loss_avg.reset()

                current_model_log = f'{"Current_accuracy":17s}: {current_accuracy:0.3f}, {"Current_norm_ED":17s}: {current_norm_ED:0.2f}'

                # keep best accuracy model (on valid dataset)
                if current_accuracy > best_accuracy:
                    best_accuracy = current_accuracy
                    torch.save(model.state_dict(), f'./saved_models/{opt.experiment_name}/best_accuracy.pth')
                if current_norm_ED > best_norm_ED:
                    best_norm_ED = current_norm_ED
                    torch.save(model.state_dict(), f'./saved_models/{opt.experiment_name}/best_norm_ED.pth')

                # checkpoint
                os.makedirs(f'./checkpoints/{opt.experiment_name}/', exist_ok=True)

                torch.save(model.state_dict(), f'./checkpoints/{opt.experiment_name}/checkpoint.pth')
                torch.save(optimizer.state_dict(), f'./checkpoints/{opt.experiment_name}/optimizer.pth')
                with open(f'./checkpoints/{opt.experiment_name}/iter.json', mode='w', encoding='utf8') as f:
                    json.dump(i + 1, f)

                with open(f'./checkpoints/{opt.experiment_name}/checkpoint.log', mode='a', encoding='utf8') as f:
                    f.write(f'Saved checkpoint with iter={i}\n')
                    f.write(f'\tCheckpoint at: ./checkpoints/{opt.experiment_name}/checkpoint.pth')
                    f.write(f'\tOptimizer at: ./checkpoints/{opt.experiment_name}/optimizer.pth')

                best_model_log = f'{"Best_accuracy":17s}: {best_accuracy:0.3f}, {"Best_norm_ED":17s}: {best_norm_ED:0.2f}'

                loss_model_log = f'{loss_log}\n{current_model_log}\n{best_model_log}'
                print(loss_model_log)
                log.write(loss_model_log + '\n')

                # show some predicted results
                dashed_line = '-' * 80
                head = f'{"Ground Truth":25s} | {"Prediction":25s} | Confidence Score & T/F'
                predicted_result_log = f'{dashed_line}\n{head}\n{dashed_line}\n'
                for gt, pred, confidence in zip(labels[:5], preds[:5], confidence_score[:5]):
                    if 'Attn' in opt.Prediction:
                        gt = gt[:gt.find('[s]')]
                        pred = pred[:pred.find('[s]')]

                    predicted_result_log += f'{gt:25s} | {pred:25s} | {confidence:0.4f}\t{str(pred == gt)}\n'
                predicted_result_log += f'{dashed_line}'
                print(predicted_result_log)
                log.write(predicted_result_log + '\n')

        # save model every 100,000 iterations
        if (i + 1) % 100000 == 0:
            torch.save(
                model.state_dict(), f'./saved_models/{opt.experiment_name}/iter_{i+1}.pth')

        # if i == opt.num_iter:
        #     print('end the training')
        #     sys.exit()
        # i += 1
        # if i == 1: break
    print('end training')
Example #7
class face_learner(object):
    def __init__(self, conf):
        print(conf)
        self.model = ResNet()
        self.model.cuda()
        if conf.initial:
            self.model.load_state_dict(torch.load("models/"+conf.model))
            print('Load model_ir_se101.pth')
        self.milestones = conf.milestones
        self.loader, self.class_num = get_train_loader(conf)
        self.total_class = 16520
        self.data_num = 285356
        self.writer = SummaryWriter(conf.log_path)
        self.step = 0
        self.paras_only_bn, self.paras_wo_bn = separate_bn_paras(self.model)

        if conf.meta:
            self.head = Arcface(embedding_size=conf.embedding_size, classnum=self.total_class)
            self.head.cuda()
            if conf.initial:
                self.head.load_state_dict(torch.load("models/head_op.pth"))
                print('Load head_op.pth')
            self.optimizer = RAdam([
                {'params': self.paras_wo_bn + [self.head.kernel], 'weight_decay': 5e-4},
                {'params': self.paras_only_bn}
            ], lr=conf.lr)
            self.meta_optimizer = RAdam([
                {'params': self.paras_wo_bn + [self.head.kernel], 'weight_decay': 5e-4},
                {'params': self.paras_only_bn}
            ], lr=conf.lr)
            self.head.train()
        else:
            self.head = dict()
            self.optimizer = dict()
            for race in races:
                self.head[race] = Arcface(embedding_size=conf.embedding_size, classnum=self.class_num[race])
                self.head[race].cuda()
                if conf.initial:
                    self.head[race].load_state_dict(torch.load("models/head_op_{}.pth".format(race)))
                    print('Load head_op_{}.pth'.format(race))
                self.optimizer[race] = RAdam([
                    {'params': self.paras_wo_bn + [self.head[race].kernel], 'weight_decay': 5e-4},
                    {'params': self.paras_only_bn}
                ], lr=conf.lr, betas=(0.5, 0.999))
                self.head[race].train()
            # self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(self.optimizer, patience=40, verbose=True)

        self.board_loss_every = min(len(self.loader[race]) for race in races) // 10
        self.evaluate_every = self.data_num // 5
        self.save_every = self.data_num // 2
        self.eval, self.eval_issame = get_val_data(conf)

    def save_state(self, conf, accuracy, extra=None, model_only=False, race='All'):
        save_path = 'models/'
        torch.save(
            self.model.state_dict(), save_path +
                                     'model_{}_accuracy-{}_step-{}_{}_{}.pth'.format(get_time(), accuracy, self.step,
                                                                                  extra, race))
        if not model_only:
            if conf.meta:
                torch.save(
                    self.head.state_dict(), save_path +
                                        'head_{}_accuracy-{}_step-{}_{}_{}.pth'.format(get_time(), accuracy, self.step,
                                                                                    extra, race))
                #torch.save(
                #    self.optimizer.state_dict(), save_path +
                #                             'optimizer_{}_accuracy-{}_step-{}_{}_{}.pth'.format(get_time(), accuracy,
                #                                                                              self.step, extra, race))
            else:
                torch.save(
                    self.head[race].state_dict(), save_path +
                                            'head_{}_accuracy-{}_step-{}_{}_{}.pth'.format(get_time(), accuracy,
                                                                                           self.step,
                                                                                           extra, race))
                #torch.save(
                #    self.optimizer[race].state_dict(), save_path +
                 #                                'optimizer_{}_accuracy-{}_step-{}_{}_{}.pth'.format(get_time(),
                #                                                                                     accuracy,
                #                                                                                     self.step, extra,
                #                                                                                     race))

    def load_state(self, conf, fixed_str, model_only=False):
        save_path = 'models/'
        self.model.load_state_dict(torch.load(save_path + conf.model))
        if not model_only:
            self.head.load_state_dict(torch.load(save_path + conf.head))
            self.optimizer.load_state_dict(torch.load(save_path + conf.optim))

    def board_val(self, db_name, accuracy, best_threshold, roc_curve_tensor):
        self.writer.add_scalar('{}_accuracy'.format(db_name), accuracy, self.step)
        self.writer.add_scalar('{}_best_threshold'.format(db_name), best_threshold, self.step)
        self.writer.add_image('{}_roc_curve'.format(db_name), roc_curve_tensor, self.step)

        # self.writer.add_scalar('{}_val:true accept ratio'.format(db_name), val, self.step)
        # self.writer.add_scalar('{}_val_std'.format(db_name), val_std, self.step)
        # self.writer.add_scalar('{}_far:False Acceptance Ratio'.format(db_name), far, self.step)

    def evaluate(self, conf, carray, issame, nrof_folds=5, tta=False):
        self.model.eval()
        idx = 0
        entry_num = carray.size()[0]
        embeddings = np.zeros([entry_num, conf.embedding_size])
        with torch.no_grad():
            while idx + conf.batch_size <= entry_num:
                batch = carray[idx:idx + conf.batch_size]
                if tta:
                    fliped = hflip_batch(batch)
                    emb_batch = self.model(batch.cuda()) + self.model(fliped.cuda())
                    embeddings[idx:idx + conf.batch_size] = l2_norm(emb_batch).cpu().detach().numpy()
                else:
                    embeddings[idx:idx + conf.batch_size] = self.model(batch.cuda()).cpu().detach().numpy()
                idx += conf.batch_size
            if idx < entry_num:
                batch = carray[idx:]
                if tta:
                    fliped = hflip_batch(batch)
                    emb_batch = self.model(batch.cuda()) + self.model(fliped.cuda())
                    embeddings[idx:] = l2_norm(emb_batch).cpu().detach().numpy()
                else:
                    embeddings[idx:] = self.model(batch.cuda()).cpu().detach().numpy()
        tpr, fpr, accuracy, best_thresholds = evaluate(embeddings, issame, nrof_folds)
        buf = gen_plot(fpr, tpr)
        roc_curve = Image.open(buf)
        roc_curve_tensor = trans.ToTensor()(roc_curve)
        return accuracy.mean(), best_thresholds.mean(), roc_curve_tensor

    def train_finetuning(self, conf, epochs, race):
        self.model.train()
        running_loss = 0.
        for e in range(epochs):
            print('epoch {} started'.format(e))
            '''
            if e == self.milestones[0]:
                for ra in races:
                    for params in self.optimizer[ra].param_groups:
                        params['lr'] /= 10
            if e == self.milestones[1]:
                for ra in races:
                    for params in self.optimizer[ra].param_groups:
                        params['lr'] /= 10
            if e == self.milestones[2]:
                for ra in races:
                    for params in self.optimizer[ra].param_groups:
                        params['lr'] /= 10
            '''
            for imgs, labels in tqdm(iter(self.loader[race])):
                imgs = imgs.cuda()
                labels = labels.cuda()
                self.optimizer[race].zero_grad()
                embeddings = self.model(imgs)
                thetas = self.head[race](embeddings, labels)
                loss = conf.ce_loss(thetas, labels)
                loss.backward()
                running_loss += loss.item()
                nn.utils.clip_grad_norm_(self.model.parameters(), conf.max_grad_norm)
                nn.utils.clip_grad_norm_(self.head[race].parameters(), conf.max_grad_norm)
                self.optimizer[race].step()

                if self.step % self.board_loss_every == 0 and self.step != 0:
                    loss_board = running_loss / self.board_loss_every
                    self.writer.add_scalar('train_loss', loss_board, self.step)
                    running_loss = 0.

                if self.step % (1 * len(self.loader[race])) == 0 and self.step != 0:
                    self.save_state(conf, 'None', race=race, model_only=True)

                self.step += 1

        self.save_state(conf, 'None', extra='final', race=race)
        torch.save(self.optimizer[race].state_dict(), 'models/optimizer_{}.pth'.format(race))

    def train_maml(self, conf, epochs):
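        # MAML-style update as written: adapt model+head on a batch from one race (ra1),
        # compute the loss on a batch from a second race (ra2) using the adapted weights,
        # restore the pre-adaptation weights, then apply that second loss via meta_optimizer.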
        self.model.train()
        running_loss = 0.
        loader_iter = dict()
        for race in races:
            loader_iter[race] = iter(self.loader[race])
        for e in range(epochs):
            print('epoch {} started'.format(e))
            if e == self.milestones[0]:
                self.schedule_lr()
            if e == self.milestones[1]:
                self.schedule_lr()
            if e == self.milestones[2]:
                self.schedule_lr()
            for i in tqdm(range(self.data_num // conf.batch_size)):
                ra1, ra2 = random.sample(races, 2)
                try:
                    imgs1, labels1 = next(loader_iter[ra1])
                except StopIteration:
                    loader_iter[ra1] = iter(self.loader[ra1])
                    imgs1, labels1 = next(loader_iter[ra1])

                try:
                    imgs2, labels2 = next(loader_iter[ra2])
                except StopIteration:
                    loader_iter[ra2] = iter(self.loader[ra2])
                    imgs2, labels2 = next(loader_iter[ra2])

                ## save original weights to make the update
                weights_original_model = deepcopy(self.model.state_dict())
                weights_original_head = deepcopy(self.head.state_dict())

                # base learn
                imgs1 = imgs1.cuda()
                labels1 = labels1.cuda()
                self.optimizer.zero_grad()
                embeddings1 = self.model(imgs1)
                thetas1 = self.head(embeddings1, labels1)
                loss1 = conf.ce_loss(thetas1, labels1)
                loss1.backward()
                nn.utils.clip_grad_norm_(self.model.parameters(), conf.max_grad_norm)
                nn.utils.clip_grad_norm_(self.head.parameters(), conf.max_grad_norm)
                self.optimizer.step()

                # meta learn
                imgs2 = imgs2.cuda()
                labels2 = labels2.cuda()
                embeddings2 = self.model(imgs2)
                thetas2 = self.head(embeddings2, labels2)
                self.model.load_state_dict(weights_original_model)
                self.head.load_state_dict(weights_original_head)
                self.meta_optimizer.zero_grad()
                loss2 = conf.ce_loss(thetas2, labels2)
                loss2.backward()
                nn.utils.clip_grad_norm_(self.model.parameters(), conf.max_grad_norm)
                nn.utils.clip_grad_norm_(self.head.parameters(), conf.max_grad_norm)
                self.meta_optimizer.step()

                running_loss += loss2.item()

                if self.step % self.board_loss_every == 0 and self.step != 0:
                    loss_board = running_loss / self.board_loss_every
                    self.writer.add_scalar('train_loss', loss_board, self.step)
                    running_loss = 0.

                if self.step % self.evaluate_every == 0 and self.step != 0:
                    for race in races:
                        accuracy, best_threshold, roc_curve_tensor = self.evaluate(conf, self.eval[race], self.eval_issame[race])
                        self.board_val(race, accuracy, best_threshold, roc_curve_tensor)
                    self.model.train()

                if self.step % (self.data_num // conf.batch_size // 2) == 0 and self.step != 0:
                    self.save_state(conf, e)

                self.step += 1

        self.save_state(conf, epochs, extra='final')

    def train_meta_head(self, conf, epochs):
        self.model.train()
        running_loss = 0.
        optimizer = optim.SGD(self.head.parameters(), lr=conf.lr, momentum=conf.momentum)
        for e in range(epochs):
            print('epoch {} started'.format(e))
            if e == self.milestones[0]:
                self.schedule_lr()
            if e == self.milestones[1]:
                self.schedule_lr()
            if e == self.milestones[2]:
                self.schedule_lr()
            for race in races:
                for imgs, labels in tqdm(iter(self.loader[race])):
                    imgs = imgs.cuda()
                    labels = labels.cuda()
                    optimizer.zero_grad()
                    embeddings = self.model(imgs)
                    thetas = self.head(embeddings, labels)
                    loss = conf.ce_loss(thetas, labels)
                    loss.backward()
                    running_loss += loss.item()
                    optimizer.step()

                    if self.step % self.board_loss_every == 0 and self.step != 0:
                        loss_board = running_loss / self.board_loss_every
                        self.writer.add_scalar('train_loss', loss_board, self.step)
                        running_loss = 0.

                    self.step += 1

            torch.save(self.head.state_dict(), 'models/head_{}_meta_{}.pth'.format(get_time(), e))

    def train_race_head(self, conf, epochs, race):
        self.model.train()
        running_loss = 0.
        optimizer = optim.SGD(self.head[race].parameters(), lr=conf.lr, momentum=conf.momentum)
        for e in range(epochs):
            print('epoch {} started'.format(e))
            if e == self.milestones[0]:
                self.schedule_lr()
            if e == self.milestones[1]:
                self.schedule_lr()
            if e == self.milestones[2]:
                self.schedule_lr()
            for imgs, labels in tqdm(iter(self.loader[race])):
                imgs = imgs.cuda()
                labels = labels.cuda()
                optimizer.zero_grad()
                embeddings = self.model(imgs)
                thetas = self.head[race](embeddings, labels)
                loss = conf.ce_loss(thetas, labels)
                loss.backward()
                running_loss += loss.item()
                optimizer.step()

                if self.step % self.board_loss_every == 0 and self.step != 0:
                    loss_board = running_loss / self.board_loss_every
                    self.writer.add_scalar('train_loss', loss_board, self.step)
                    running_loss = 0.

                self.step += 1

        torch.save(self.head[race].state_dict(), 'models/head_{}_{}_{}.pth'.format(get_time(), race, epochs))

    def schedule_lr(self):
        for params in self.optimizer.param_groups:
            params['lr'] /= 10
        for params in self.meta_optimizer.param_groups:
            params['lr'] /= 10
        print(self.optimizer, self.meta_optimizer)
def main():
    parser = argparse.ArgumentParser()
    # path setting
    parser.add_argument("--waveforms",
                        type=str,
                        help="directory or list of wav files")
    parser.add_argument("--waveforms_eval",
                        type=str,
                        help="directory or list of evaluation wav files")
    parser.add_argument("--feats",
                        required=True,
                        type=str,
                        help="directory or list of wav files")
    parser.add_argument("--feats_eval",
                        required=True,
                        type=str,
                        help="directory or list of evaluation feat files")
    parser.add_argument("--stats",
                        required=True,
                        type=str,
                        help="directory or list of evaluation wav files")
    parser.add_argument("--expdir",
                        required=True,
                        type=str,
                        help="directory to save the model")
    # network structure setting
    parser.add_argument("--upsampling_factor",
                        default=120,
                        type=int,
                        help="number of dimension of aux feats")
    parser.add_argument("--hid_chn",
                        default=256,
                        type=int,
                        help="depth of dilation")
    parser.add_argument("--skip_chn",
                        default=256,
                        type=int,
                        help="depth of dilation")
    parser.add_argument("--dilation_depth",
                        default=3,
                        type=int,
                        help="depth of dilation")
    parser.add_argument("--dilation_repeat",
                        default=2,
                        type=int,
                        help="depth of dilation")
    parser.add_argument("--kernel_size",
                        default=7,
                        type=int,
                        help="kernel size of dilated causal convolution")
    parser.add_argument("--kernel_size_wave",
                        default=7,
                        type=int,
                        help="kernel size of dilated causal convolution")
    parser.add_argument("--dilation_size_wave",
                        default=1,
                        type=int,
                        help="kernel size of dilated causal convolution")
    parser.add_argument("--mcep_dim",
                        default=50,
                        type=int,
                        help="kernel size of dilated causal convolution")
    # network training setting
    parser.add_argument("--lr", default=1e-4, type=float, help="learning rate")
    parser.add_argument(
        "--batch_size",
        default=30,
        type=int,
        help="batch size (if set 0, utterance batch will be used)")
    parser.add_argument("--epoch_count",
                        default=4000,
                        type=int,
                        help="number of training epochs")
    parser.add_argument("--do_prob",
                        default=0,
                        type=float,
                        help="dropout probability")
    parser.add_argument(
        "--batch_size_utt",
        default=5,
        type=int,
        help="batch size (if set 0, utterance batch will be used)")
    parser.add_argument(
        "--batch_size_utt_eval",
        default=5,
        type=int,
        help="batch size (if set 0, utterance batch will be used)")
    parser.add_argument(
        "--n_workers",
        default=2,
        type=int,
        help="number of workers for dataset loading")
    parser.add_argument(
        "--n_quantize",
        default=256,
        type=int,
        help="number of quantization levels of waveform")
    parser.add_argument(
        "--bi_wave",
        default=True,
        type=strtobool,
        help="batch size (if set 0, utterance batch will be used)")
    parser.add_argument(
        "--causal_conv_wave",
        default=False,
        type=strtobool,
        help="whether to use causal convolution")
    # other setting
    parser.add_argument("--init",
                        default=False,
                        type=strtobool,
                        help="seed number")
    parser.add_argument("--pad_len",
                        default=3000,
                        type=int,
                        help="seed number")
    ##parser.add_argument("--save_interval_iter", default=5000,
    #parser.add_argument("--save_interval_iter", default=3000,
    #                    type=int, help="interval steps to logr")
    parser.add_argument("--save_interval_epoch",
                        default=10,
                        type=int,
                        help="interval steps to logr")
    parser.add_argument("--log_interval_steps",
                        default=50,
                        type=int,
                        help="interval steps to logr")
    parser.add_argument("--seed", default=1, type=int, help="seed number")
    parser.add_argument("--resume",
                        default=None,
                        type=str,
                        help="model path to restart training")
    parser.add_argument("--pretrained",
                        default=None,
                        type=str,
                        help="model path to restart training")
    parser.add_argument("--preconf",
                        default=None,
                        type=str,
                        help="model path to restart training")
    parser.add_argument("--string_path",
                        default=None,
                        type=str,
                        help="model path to restart training")
    parser.add_argument("--GPU_device",
                        default=None,
                        type=int,
                        help="selection of GPU device")
    parser.add_argument("--verbose", default=1, type=int, help="log level")
    args = parser.parse_args()

    if args.GPU_device is not None:
        os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
        os.environ["CUDA_VISIBLE_DEVICES"] = str(args.GPU_device)

    # make experimental directory
    if not os.path.exists(args.expdir):
        os.makedirs(args.expdir)

    # set log level
    if args.verbose == 1:
        logging.basicConfig(
            level=logging.INFO,
            format=
            '%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s',
            datefmt='%m/%d/%Y %I:%M:%S',
            filename=args.expdir + "/train.log")
        logging.getLogger().addHandler(logging.StreamHandler())
    elif args.verbose > 1:
        logging.basicConfig(
            level=logging.DEBUG,
            format=
            '%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s',
            datefmt='%m/%d/%Y %I:%M:%S',
            filename=args.expdir + "/train.log")
        logging.getLogger().addHandler(logging.StreamHandler())
    else:
        logging.basicConfig(
            level=logging.WARN,
            format=
            '%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s',
            datefmt='%m/%d/%Y %I:%M:%S',
            filename=args.expdir + "/train.log")
        logging.getLogger().addHandler(logging.StreamHandler())
        logging.warn("logging is disabled.")

    # fix seed
    os.environ['PYTHONHASHSEED'] = str(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if str(device) == "cpu":
        raise ValueError('ERROR: Training by CPU is not acceptable.')

    torch.backends.cudnn.benchmark = True  #faster

    if args.pretrained is None:
        if 'mel' in args.string_path:
            mean_stats = torch.FloatTensor(read_hdf5(args.stats,
                                                     "/mean_melsp"))
            scale_stats = torch.FloatTensor(
                read_hdf5(args.stats, "/scale_melsp"))
            args.excit_dim = 0
        else:
            mean_stats = torch.FloatTensor(
                read_hdf5(args.stats, "/mean_feat_mceplf0cap"))
            scale_stats = torch.FloatTensor(
                read_hdf5(args.stats, "/scale_feat_mceplf0cap"))
            args.cap_dim = mean_stats.shape[0] - (args.mcep_dim + 3)
            args.excit_dim = 2 + 1 + args.cap_dim
    else:
        config = torch.load(args.preconf)
        args.excit_dim = config.excit_dim
        args.cap_dim = config.cap_dim

    # save args as conf
    torch.save(args, args.expdir + "/model.conf")

    # define network
    model_waveform = DSWNV(n_aux=args.mcep_dim + args.excit_dim,
                           upsampling_factor=args.upsampling_factor,
                           hid_chn=args.hid_chn,
                           skip_chn=args.skip_chn,
                           kernel_size=args.kernel_size,
                           aux_kernel_size=args.kernel_size_wave,
                           aux_dilation_size=args.dilation_size_wave,
                           dilation_depth=args.dilation_depth,
                           dilation_repeat=args.dilation_repeat,
                           n_quantize=args.n_quantize,
                           do_prob=args.do_prob)
    logging.info(model_waveform)
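    # Round the waveform receptive field up to a whole number of feature frames so the
    # conditioning features stay aligned with the shifted waveform context.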
    shift_rec_field = model_waveform.receptive_field
    logging.info(shift_rec_field)
    if shift_rec_field % args.upsampling_factor > 0:
        shift_rec_field_frm = shift_rec_field // args.upsampling_factor + 1
    else:
        shift_rec_field_frm = shift_rec_field // args.upsampling_factor
    shift_rec_field = shift_rec_field_frm * args.upsampling_factor
    logging.info(shift_rec_field)
    logging.info(shift_rec_field_frm)
    criterion_ce = torch.nn.CrossEntropyLoss(reduction='none')
    criterion_l1 = torch.nn.L1Loss(reduction='none')

    # send to gpu
    if torch.cuda.is_available():
        model_waveform.cuda()
        criterion_ce.cuda()
        criterion_l1.cuda()
        if args.pretrained is None:
            mean_stats = mean_stats.cuda()
            scale_stats = scale_stats.cuda()
    else:
        logging.error("gpu is not available. please check the setting.")
        sys.exit(1)

    model_waveform.train()

    if args.pretrained is None:
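        # Fold feature normalization into the scale_in 1x1 conv (kept frozen below):
        # weight = diag(1 / scale), bias = -mean / scale, i.e. x -> (x - mean) / scale.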
        model_waveform.scale_in.weight = torch.nn.Parameter(
            torch.unsqueeze(torch.diag(1.0 / scale_stats.data), 2))
        model_waveform.scale_in.bias = torch.nn.Parameter(-(mean_stats.data /
                                                            scale_stats.data))

    #if args.pretrained is not None:
    #    checkpoint = torch.load(args.pretrained)
    #    #model_waveform.remove_weight_norm()
    #    #model_waveform.load_state_dict(checkpoint["model"])
    #    model_waveform.load_state_dict(checkpoint["model_waveform"])
    #    epoch_idx = checkpoint["iterations"]
    #    logging.info("pretrained from %d-iter checkpoint." % epoch_idx)
    #    epoch_idx = 0
    #    #model_waveform.apply_weight_norm()
    #    #torch.nn.utils.remove_weight_norm(model_waveform.scale_in)

    for param in model_waveform.parameters():
        param.requires_grad = True
    for param in model_waveform.scale_in.parameters():
        param.requires_grad = False

    parameters = filter(lambda p: p.requires_grad, model_waveform.parameters())
    parameters = sum([np.prod(p.size()) for p in parameters]) / 1000000
    logging.info('Trainable Parameters (waveform): %.3f million' % parameters)

    module_list = list(model_waveform.conv_aux.parameters()) + list(
        model_waveform.upsampling.parameters())
    if model_waveform.wav_conv_flag:
        module_list += list(model_waveform.wav_conv.parameters())
    module_list += list(model_waveform.causal.parameters())
    module_list += list(model_waveform.in_x.parameters()) + list(
        model_waveform.dil_h.parameters())
    module_list += list(model_waveform.out_skip.parameters())
    module_list += list(model_waveform.out_1.parameters()) + list(
        model_waveform.out_2.parameters())
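    # scale_in is deliberately left out of module_list, so the frozen normalization layer
    # is never touched by the optimizer.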

    optimizer = RAdam(module_list, lr=args.lr)
    #optimizer = torch.optim.Adam(module_list, lr=args.lr)

    # resume
    if args.pretrained is not None and args.resume is None:
        checkpoint = torch.load(args.pretrained)
        model_waveform.load_state_dict(checkpoint["model_waveform"])
        #    optimizer.load_state_dict(checkpoint["optimizer"])
        epoch_idx = checkpoint["iterations"]
        logging.info("pretrained from %d-iter checkpoint." % epoch_idx)
        epoch_idx = 0
    elif args.resume is not None:
        #if args.resume is not None:
        checkpoint = torch.load(args.resume)
        model_waveform.load_state_dict(checkpoint["model_waveform"])
        optimizer.load_state_dict(checkpoint["optimizer"])
        epoch_idx = checkpoint["iterations"]
        logging.info("restored from %d-iter checkpoint." % epoch_idx)
    #    epoch_idx = 2
    else:
        epoch_idx = 0

    def zero_wav_pad(x):
        return padding(x, args.pad_len * args.upsampling_factor,
                       value=0.0)  # noqa: E704

    def zero_feat_pad(x):
        return padding(x, args.pad_len, value=0.0)  # noqa: E704

    pad_wav_transform = transforms.Compose([zero_wav_pad])
    pad_feat_transform = transforms.Compose([zero_feat_pad])
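    # The pad transforms fix every utterance to pad_len frames (pad_len * upsampling_factor
    # samples) so utterances can be batched; wav_transform below mu-law encodes samples
    # into n_quantize classes for the categorical output.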

    wav_transform = transforms.Compose(
        [lambda x: encode_mu_law(x, args.n_quantize)])

    # define generator training
    if os.path.isdir(args.waveforms):
        filenames = sorted(
            find_files(args.waveforms, "*.wav", use_dir_name=False))
        wav_list = [args.waveforms + "/" + filename for filename in filenames]
    elif os.path.isfile(args.waveforms):
        wav_list = read_txt(args.waveforms)
    else:
        logging.error("--waveforms should be directory or list.")
        sys.exit(1)
    if os.path.isdir(args.feats):
        feat_list = [args.feats + "/" + filename for filename in filenames]
    elif os.path.isfile(args.feats):
        feat_list = read_txt(args.feats)
    else:
        logging.error("--feats should be directory or list.")
        sys.exit(1)
    assert len(wav_list) == len(feat_list)
    logging.info("number of training data = %d." % len(feat_list))
    dataset = FeatureDatasetNeuVoco(wav_list,
                                    feat_list,
                                    pad_wav_transform,
                                    pad_feat_transform,
                                    args.upsampling_factor,
                                    args.string_path,
                                    wav_transform=wav_transform)
    dataloader = DataLoader(dataset,
                            batch_size=args.batch_size_utt,
                            shuffle=True,
                            num_workers=args.n_workers)
    #generator = train_generator(dataloader, device, args.batch_size, args.upsampling_factor, limit_count=1)
    generator = train_generator(dataloader,
                                device,
                                args.batch_size,
                                args.upsampling_factor,
                                limit_count=None)
    #generator = train_generator(dataloader, device, args.batch_size, args.upsampling_factor, limit_count=1, resume_c_idx=1426, max_c_idx=(len(feat_list)//args.batch_size_utt))
    #generator = train_generator(dataloader, device, args.batch_size, args.upsampling_factor, limit_count=None, resume_c_idx=1426, max_c_idx=(len(feat_list)//args.batch_size_utt))

    # define generator evaluation
    if os.path.isdir(args.waveforms_eval):
        filenames = sorted(
            find_files(args.waveforms_eval, "*.wav", use_dir_name=False))
        wav_list_eval = [
            args.waveforms + "/" + filename for filename in filenames
        ]
    elif os.path.isfile(args.waveforms_eval):
        wav_list_eval = read_txt(args.waveforms_eval)
    else:
        logging.error("--waveforms_eval should be directory or list.")
        sys.exit(1)
    if os.path.isdir(args.feats_eval):
        feat_list_eval = [
            args.feats_eval + "/" + filename for filename in filenames
        ]
    elif os.path.isfile(args.feats_eval):
        feat_list_eval = read_txt(args.feats_eval)
    else:
        logging.error("--feats_eval should be directory or list.")
        sys.exit(1)
    assert len(wav_list_eval) == len(feat_list_eval)
    logging.info("number of evaluation data = %d." % len(feat_list_eval))
    dataset_eval = FeatureDatasetNeuVoco(wav_list_eval,
                                         feat_list_eval,
                                         pad_wav_transform,
                                         pad_feat_transform,
                                         args.upsampling_factor,
                                         args.string_path,
                                         wav_transform=wav_transform)
    dataloader_eval = DataLoader(dataset_eval,
                                 batch_size=args.batch_size_utt_eval,
                                 shuffle=False,
                                 num_workers=args.n_workers)
    ##generator_eval = eval_generator(dataloader_eval, device, args.batch_size, args.upsampling_factor, limit_count=1)
    #generator_eval = eval_generator(dataloader_eval, device, args.batch_size, args.upsampling_factor, limit_count=None)
    #generator_eval = train_generator(dataloader_eval, device, args.batch_size, args.upsampling_factor, limit_count=1)
    generator_eval = train_generator(dataloader_eval,
                                     device,
                                     args.batch_size,
                                     args.upsampling_factor,
                                     limit_count=None)

    writer = SummaryWriter(args.expdir)
    total_train_loss = defaultdict(list)
    total_eval_loss = defaultdict(list)

    # train
    logging.info(args.string_path)
    total = 0
    iter_count = 0
    loss_ce = []
    loss_err = []
    min_eval_loss_err = 99999999.99
    min_eval_loss_err_std = 99999999.99
    min_eval_loss_ce = 99999999.99
    min_eval_loss_ce_std = 99999999.99
    iter_idx = 0
    min_idx = -1
    #min_eval_loss_ce = 1.575400
    #min_eval_loss_ce_std = 0.645726
    #iter_idx = 8098898
    #min_idx = 68 #resume70
    change_min_flag = False
    if args.resume is not None:
        np.random.set_state(checkpoint["numpy_random_state"])
        torch.set_rng_state(checkpoint["torch_random_state"])
    logging.info("==%d EPOCH==" % (epoch_idx + 1))
    logging.info("Training data")
    while epoch_idx < args.epoch_count:
        start = time.time()
        batch_x, batch_feat, c_idx, utt_idx, featfile, x_bs, f_bs, x_ss, f_ss, n_batch_utt, \
            del_index_utt, max_slen, idx_select, idx_select_full, slens_acc = next(generator)
        if args.init:
            c_idx = -1
        if c_idx < 0:  # summarize epoch
            if not args.init:
                # save current epoch model
                numpy_random_state = np.random.get_state()
                torch_random_state = torch.get_rng_state()
                # report current epoch
                logging.info("(EPOCH:%d) average optimization loss = %.6f (+- %.6f) %.6f (+- %.6f) %% ;; "\
                    "(%.3f min., %.3f sec / batch)" % (epoch_idx + 1, np.mean(loss_ce), np.std(loss_ce), \
                        np.mean(loss_err), np.std(loss_err), total / 60.0, total / iter_count))
                logging.info("estimated time until max. epoch = {0.days:02}:{0.hours:02}:{0.minutes:02}:"\
                "{0.seconds:02}".format(relativedelta(seconds=int((args.epoch_count - (epoch_idx + 1)) * total))))
            # compute loss in evaluation data
            total = 0
            iter_count = 0
            loss_ce = []
            loss_err = []
            model_waveform.eval()
            for param in model_waveform.parameters():
                param.requires_grad = False
            pair_exist = False
            logging.info("Evaluation data")
            while True:
                with torch.no_grad():
                    start = time.time()
                    batch_x, batch_feat, c_idx, utt_idx, featfile, x_bs, f_bs, x_ss, f_ss, n_batch_utt, \
                        del_index_utt, max_slen, idx_select, idx_select_full, slens_acc = next(generator_eval)
                    if c_idx < 0:
                        break

                    x_es = x_ss + x_bs
                    f_es = f_ss + f_bs
                    logging.info(
                        f'{x_ss} {x_bs} {x_es} {f_ss} {f_bs} {f_es} {max_slen}'
                    )
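                    # slice out the current segment: keep shift_rec_field previous samples
                    # (and the matching feature frames) as causal context, or left-pad the
                    # very first segment with the center quantization level (n_quantize // 2)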
                    if x_ss > 0:
                        if x_es <= max_slen:
                            batch_x_prev = batch_x[:, x_ss - shift_rec_field -
                                                   1:x_es - 1]
                            batch_feat = batch_feat[:, f_ss -
                                                    shift_rec_field_frm:f_es]
                            batch_x = batch_x[:, x_ss:x_es]
                        else:
                            batch_x_prev = batch_x[:, x_ss - shift_rec_field -
                                                   1:-1]
                            batch_feat = batch_feat[:, f_ss -
                                                    shift_rec_field_frm:]
                            batch_x = batch_x[:, x_ss:]
                    #    assert((batch_x_prev[:,shift_rec_field+1:] == batch_x[:,:-1]).all())
                    else:
                        batch_x_prev = F.pad(
                            batch_x[:, :x_es - 1],
                            (model_waveform.receptive_field + 1, 0),
                            "constant", args.n_quantize // 2)
                        batch_feat = batch_feat[:, :f_es]
                        batch_x = batch_x[:, :x_es]
                    #    assert((batch_x_prev[:,model_waveform.receptive_field+1:] == batch_x[:,:-1]).all())

                    if x_ss > 0:
                        batch_x_output = model_waveform(
                            batch_feat, batch_x_prev)[:, shift_rec_field:]
                    else:
                        batch_x_output = model_waveform(
                            batch_feat, batch_x_prev,
                            first=True)[:, model_waveform.receptive_field:]

                    # samples check
                    i = np.random.randint(0, batch_x_output.shape[0])
                    logging.info("%s" % (os.path.join(
                        os.path.basename(os.path.dirname(featfile[i])),
                        os.path.basename(featfile[i]))))
                    #check_samples = batch_x[i,5:10].long()
                    #logging.info(torch.index_select(F.softmax(batch_x_output[i,5:10], dim=-1), 1, check_samples))
                    #logging.info(check_samples)

                    # handle short ending
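                    # utterances listed in idx_select end inside this segment, so their
                    # losses are computed only over the remaining valid samples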
                    batch_loss = 0
                    if len(idx_select) > 0:
                        logging.info('len_idx_select: ' + str(len(idx_select)))
                        batch_loss_ce = 0
                        batch_loss_err = 0
                        for j in range(len(idx_select)):
                            k = idx_select[j]
                            slens_utt = slens_acc[k]
                            logging.info('%s %d' % (featfile[k], slens_utt))
                            batch_x_output_k = batch_x_output[k, :slens_utt]
                            batch_x_k = batch_x[k, :slens_utt]
                            batch_loss_ce += torch.mean(
                                criterion_ce(batch_x_output_k, batch_x_k))
                            batch_loss_err += torch.mean(
                                torch.sum(
                                    100 * criterion_l1(
                                        F.softmax(batch_x_output_k, dim=-1),
                                        F.one_hot(batch_x_k,
                                                  num_classes=args.n_quantize).
                                        float()), -1))
                        batch_loss += batch_loss_ce
                        batch_loss_ce /= len(idx_select)
                        batch_loss_err /= len(idx_select)
                        total_eval_loss["eval/loss_ce"].append(
                            batch_loss_ce.item())
                        total_eval_loss["eval/loss_err"].append(
                            batch_loss_err.item())
                        loss_ce.append(batch_loss_ce.item())
                        loss_err.append(batch_loss_err.item())
                        if len(idx_select_full) > 0:
                            logging.info('len_idx_select_full: ' +
                                         str(len(idx_select_full)))
                            batch_x = torch.index_select(
                                batch_x, 0, idx_select_full)
                            batch_x_output = torch.index_select(
                                batch_x_output, 0, idx_select_full)
                        else:
                            logging.info(
                                "batch eval loss select %.3f %.3f (%.3f sec)" %
                                (batch_loss_ce.item(), batch_loss_err.item(),
                                 time.time() - start))
                            iter_count += 1
                            total += time.time() - start
                            continue

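                    # full-batch metrics: per-sample cross-entropy and an L1 distance between
                    # the softmax posterior and the one-hot target, scaled to a percentage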
                    batch_loss_ce_ = torch.mean(
                        criterion_ce(
                            batch_x_output.reshape(-1, args.n_quantize),
                            batch_x.reshape(-1)).reshape(
                                batch_x_output.shape[0], -1), -1)
                    batch_loss_err_ = torch.mean(
                        torch.sum(
                            100 * criterion_l1(
                                F.softmax(batch_x_output, dim=-1),
                                F.one_hot(
                                    batch_x,
                                    num_classes=args.n_quantize).float()), -1),
                        -1)

                    batch_loss_ce = batch_loss_ce_.mean()
                    batch_loss_err = batch_loss_err_.mean()
                    total_eval_loss["eval/loss_ce"].append(
                        batch_loss_ce.item())
                    total_eval_loss["eval/loss_err"].append(
                        batch_loss_err.item())
                    loss_ce.append(batch_loss_ce.item())
                    loss_err.append(batch_loss_err.item())

                    logging.info("batch eval loss [%d] %d %d %d %d %d : %.3f %.3f %% (%.3f sec)" % (c_idx+1, max_slen, \
                        x_ss, x_bs, f_ss, f_bs, batch_loss_ce.item(), batch_loss_err.item(), time.time() - start))
                    iter_count += 1
                    total += time.time() - start
            logging.info('sme')
            for key in total_eval_loss.keys():
                total_eval_loss[key] = np.mean(total_eval_loss[key])
                logging.info(
                    f"(Steps: {iter_idx}) {key} = {total_eval_loss[key]:.4f}.")
            write_to_tensorboard(writer, iter_idx, total_eval_loss)
            total_eval_loss = defaultdict(list)
            eval_loss_ce = np.mean(loss_ce)
            eval_loss_ce_std = np.std(loss_ce)
            eval_loss_err = np.mean(loss_err)
            eval_loss_err_std = np.std(loss_err)
            logging.info("(EPOCH:%d) average evaluation loss = %.6f (+- %.6f) %.6f (+- %.6f) %% ;; (%.3f min., "\
                "%.3f sec / batch)" % (epoch_idx + 1, eval_loss_ce, eval_loss_ce_std, \
                    eval_loss_err, eval_loss_err_std, total / 60.0, total / iter_count))
            if (eval_loss_ce+eval_loss_ce_std) <= (min_eval_loss_ce+min_eval_loss_ce_std) \
                or (eval_loss_ce <= min_eval_loss_ce):
                min_eval_loss_ce = eval_loss_ce
                min_eval_loss_ce_std = eval_loss_ce_std
                min_eval_loss_err = eval_loss_err
                min_eval_loss_err_std = eval_loss_err_std
                min_idx = epoch_idx
                change_min_flag = True
            #else:
            #    epoch_min_flag = False
            if change_min_flag:
                logging.info("min_eval_loss = %.6f (+- %.6f) %.6f (+- %.6f) %% min_idx=%d" % (min_eval_loss_ce, \
                    min_eval_loss_ce_std, min_eval_loss_err, min_eval_loss_err_std, min_idx+1))
            #if ((epoch_idx + 1) % args.save_interval_epoch == 0) or (epoch_min_flag):
            #    logging.info('save epoch:%d' % (epoch_idx+1))
            #    save_checkpoint(args.expdir, model_waveform, optimizer, numpy_random_state, torch_random_state, epoch_idx + 1)
            if args.init:
                exit()
            logging.info('save epoch:%d' % (epoch_idx + 1))
            save_checkpoint(args.expdir, model_waveform, optimizer,
                            numpy_random_state, torch_random_state,
                            epoch_idx + 1)
            total = 0
            iter_count = 0
            loss_ce = []
            loss_err = []
            epoch_idx += 1
            np.random.set_state(numpy_random_state)
            torch.set_rng_state(torch_random_state)
            model_waveform.train()
            for param in model_waveform.parameters():
                param.requires_grad = True
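            # the input normalization layer (scale_in) stays frozen during training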
            for param in model_waveform.scale_in.parameters():
                param.requires_grad = False
            # start next epoch
            if epoch_idx < args.epoch_count:
                start = time.time()
                logging.info("==%d EPOCH==" % (epoch_idx + 1))
                logging.info("Training data")
                batch_x, batch_feat, c_idx, utt_idx, featfile, x_bs, f_bs, x_ss, f_ss, n_batch_utt, \
                    del_index_utt, max_slen, idx_select, idx_select_full, slens_acc = next(generator)
        # feedforward and backpropagate current batch
        if epoch_idx < args.epoch_count:
            logging.info("%d iteration [%d]" % (iter_idx + 1, epoch_idx + 1))

            x_es = x_ss + x_bs
            f_es = f_ss + f_bs
            logging.info(
                f'{x_ss} {x_bs} {x_es} {f_ss} {f_bs} {f_es} {max_slen}')
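            # same segment slicing as in evaluation: keep causal context from the previous
            # segment, or pad the very first segment with n_quantize // 2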
            if x_ss > 0:
                if x_es <= max_slen:
                    batch_x_prev = batch_x[:,
                                           x_ss - shift_rec_field - 1:x_es - 1]
                    batch_feat = batch_feat[:, f_ss - shift_rec_field_frm:f_es]
                    batch_x = batch_x[:, x_ss:x_es]
                else:
                    batch_x_prev = batch_x[:, x_ss - shift_rec_field - 1:-1]
                    batch_feat = batch_feat[:, f_ss - shift_rec_field_frm:]
                    batch_x = batch_x[:, x_ss:]
            #    assert((batch_x_prev[:,shift_rec_field+1:] == batch_x[:,:-1]).all())
            else:
                batch_x_prev = F.pad(batch_x[:, :x_es - 1],
                                     (model_waveform.receptive_field + 1, 0),
                                     "constant", args.n_quantize // 2)
                batch_feat = batch_feat[:, :f_es]
                batch_x = batch_x[:, :x_es]
            #    assert((batch_x_prev[:,model_waveform.receptive_field+1:] == batch_x[:,:-1]).all())

            if x_ss > 0:
                batch_x_output = model_waveform(batch_feat,
                                                batch_x_prev,
                                                do=True)[:, shift_rec_field:]
            else:
                batch_x_output = model_waveform(
                    batch_feat, batch_x_prev, first=True,
                    do=True)[:, model_waveform.receptive_field:]

            # samples check
            i = np.random.randint(0, batch_x_output.shape[0])
            logging.info(
                "%s" %
                (os.path.join(os.path.basename(os.path.dirname(featfile[i])),
                              os.path.basename(featfile[i]))))
            #with torch.no_grad():
            #    i = np.random.randint(0, batch_x_output.shape[0])
            #    logging.info("%s" % (os.path.join(os.path.basename(os.path.dirname(featfile[i])),os.path.basename(featfile[i]))))
            #    check_samples = batch_x[i,5:10].long()
            #    logging.info(torch.index_select(F.softmax(batch_x_output[i,5:10], dim=-1), 1, check_samples))
            #    logging.info(check_samples)

            # handle short ending
            batch_loss = 0
            if len(idx_select) > 0:
                logging.info('len_idx_select: ' + str(len(idx_select)))
                batch_loss_ce = 0
                batch_loss_err = 0
                for j in range(len(idx_select)):
                    k = idx_select[j]
                    slens_utt = slens_acc[k]
                    logging.info('%s %d' % (featfile[k], slens_utt))
                    batch_x_output_k = batch_x_output[k, :slens_utt]
                    batch_x_k = batch_x[k, :slens_utt]
                    batch_loss_ce += torch.mean(
                        criterion_ce(batch_x_output_k, batch_x_k))
                    batch_loss_err += torch.mean(
                        torch.sum(
                            100 * criterion_l1(
                                F.softmax(batch_x_output_k, dim=-1),
                                F.one_hot(
                                    batch_x_k,
                                    num_classes=args.n_quantize).float()), -1))
                batch_loss += batch_loss_ce
                batch_loss_ce /= len(idx_select)
                batch_loss_err /= len(idx_select)
                total_train_loss["train/loss_ce"].append(batch_loss_ce.item())
                total_train_loss["train/loss_err"].append(
                    batch_loss_err.item())
                loss_ce.append(batch_loss_ce.item())
                loss_err.append(batch_loss_err.item())
                if len(idx_select_full) > 0:
                    logging.info('len_idx_select_full: ' +
                                 str(len(idx_select_full)))
                    batch_x = torch.index_select(batch_x, 0, idx_select_full)
                    batch_x_output = torch.index_select(
                        batch_x_output, 0, idx_select_full)
                else:
                    optimizer.zero_grad()
                    batch_loss.backward()
                    torch.nn.utils.clip_grad_norm_(model_waveform.parameters(),
                                                   10)
                    optimizer.step()

                    logging.info("batch loss select %.3f %.3f (%.3f sec)" %
                                 (batch_loss_ce.item(), batch_loss_err.item(),
                                  time.time() - start))
                    iter_idx += 1
                    #if iter_idx % args.save_interval_iter == 0:
                    #    logging.info('save iter:%d' % (iter_idx))
                    #    save_checkpoint(args.expdir, model_waveform, optimizer, np.random.get_state(), torch.get_rng_state(), iter_idx)
                    iter_count += 1
                    if iter_idx % args.log_interval_steps == 0:
                        logging.info('smt')
                        for key in total_train_loss.keys():
                            total_train_loss[key] = np.mean(
                                total_train_loss[key])
                            logging.info(
                                f"(Steps: {iter_idx}) {key} = {total_train_loss[key]:.4f}."
                            )
                        write_to_tensorboard(writer, iter_idx,
                                             total_train_loss)
                        total_train_loss = defaultdict(list)
                    total += time.time() - start
                    continue

            # batch losses: per-utterance cross-entropy (backpropagated) and percent L1 error (logged only)
            batch_loss_ce_ = torch.mean(
                criterion_ce(batch_x_output.reshape(-1, args.n_quantize),
                             batch_x.reshape(-1)).reshape(
                                 batch_x_output.shape[0], -1), -1)
            batch_loss_err_ = torch.mean(
                torch.sum(
                    100 * criterion_l1(
                        F.softmax(batch_x_output, dim=-1),
                        F.one_hot(batch_x,
                                  num_classes=args.n_quantize).float()), -1),
                -1)

            batch_loss_ce = batch_loss_ce_.mean()
            batch_loss_err = batch_loss_err_.mean()
            total_train_loss["train/loss_ce"].append(batch_loss_ce.item())
            total_train_loss["train/loss_err"].append(batch_loss_err.item())
            loss_ce.append(batch_loss_ce.item())
            loss_err.append(batch_loss_err.item())

            batch_loss += batch_loss_ce_.sum()

            optimizer.zero_grad()
            batch_loss.backward()
            torch.nn.utils.clip_grad_norm_(model_waveform.parameters(), 10)
            optimizer.step()

            logging.info("batch loss [%d] %d %d %d %d %d : %.3f %.3f %% (%.3f sec)" % (c_idx+1, max_slen, x_ss, x_bs, \
                f_ss, f_bs, batch_loss_ce.item(), batch_loss_err.item(), time.time() - start))
            iter_idx += 1
            #if iter_idx % args.save_interval_iter == 0:
            #    logging.info('save iter:%d' % (iter_idx))
            #    save_checkpoint(args.expdir, model_waveform, optimizer, np.random.get_state(), torch.get_rng_state(), iter_idx)
            iter_count += 1
            if iter_idx % args.log_interval_steps == 0:
                logging.info('smt')
                for key in total_train_loss.keys():
                    total_train_loss[key] = np.mean(total_train_loss[key])
                    logging.info(
                        f"(Steps: {iter_idx}) {key} = {total_train_loss[key]:.4f}."
                    )
                write_to_tensorboard(writer, iter_idx, total_train_loss)
                total_train_loss = defaultdict(list)
            total += time.time() - start

    # save final model
    model_waveform.cpu()
    torch.save({"model_waveform": model_waveform.state_dict()},
               args.expdir + "/checkpoint-final.pkl")
    logging.info("final checkpoint created.")
Example #9
# Type = 'best'
save_folder = 'model_result_multi_layer'
Type = 'trainable'
model_check_point = '%s/model_%s_%d.pk' % (save_folder, Type, version_num)
optim_check_point = '%s/optim_%s_%d.pkl' % (save_folder, Type, version_num)
loss_check_point = '%s/loss_%s_%d.pkl' % (save_folder, Type, version_num)
epoch_check_point = '%s/epoch_%s_%d.pkl' % (save_folder, Type, version_num)
bleu_check_point = '%s/bleu_%s_%d.pkl' % (save_folder, Type, version_num)
loss_values = []
epoch_values = []
bleu_values = []
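# resume model, optimizer, LR scheduler, and training history if a checkpoint for this version exists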
if os.path.isfile(model_check_point):
    print('Loading previous status (ver.%d)...' % version_num)
    model.load_state_dict(torch.load(model_check_point, map_location='cpu'))
    model = model.to(device)
    optimizer.load_state_dict(torch.load(optim_check_point))
    lr_scheduler = ReduceLROnPlateau(optimizer,
                                     mode='min',
                                     factor=0.4,
                                     patience=2,
                                     min_lr=1e-7,
                                     verbose=True)
    loss_values = torch.load(loss_check_point)
    epoch_values = torch.load(epoch_check_point)
    bleu_values = torch.load(bleu_check_point)
    print('Load successfully')
else:
    print("ver.%d doesn't exist" % version_num)

# evaluateAndShowAttention(['現在', '未來', '夢想', '科學', '文化'], method='beam_search', is_sample=True)
Example #10
class Trainer:
    def __init__(self,
                 model,
                 train_loader,
                 test_loader,
                 epochs=200,
                 batch_size=60,
                 run_id=0,
                 logs_dir='logs',
                 device='cpu',
                 saturation_device=None,
                 optimizer='None',
                 plot=True,
                 compute_top_k=False,
                 data_prallel=False,
                 conv_method='channelwise',
                 thresh=.99,
                 half_precision=False,
                 downsampling=None):
        self.saturation_device = device if saturation_device is None else saturation_device
        self.device = device
        self.model = model
        self.epochs = epochs
        self.plot = plot
        self.compute_top_k = compute_top_k

        if 'cuda' in device:
            cudnn.benchmark = True

        self.train_loader = train_loader
        self.test_loader = test_loader

        self.criterion = nn.CrossEntropyLoss()
        print('Checking for optimizer {}'.format(optimizer))
        #optimizer = str(optimizer)
        if optimizer == "adam":
            print('Using adam')
            self.optimizer = optim.Adam(model.parameters())
        elif optimizer == "adam_lr":
            print("Using adam with higher learning rate")
            self.optimizer = optim.Adam(model.parameters(), lr=0.01)
        elif optimizer == 'adam_lr2':
            print('Using adam with a lower learning rate (1e-4)')
            self.optimizer = optim.Adam(model.parameters(), lr=0.0001)
        elif optimizer == "SGD":
            print('Using SGD')
            self.optimizer = optim.SGD(model.parameters(),
                                       lr=0.01,  # assumed value; the original snippet omitted the lr argument
                                       momentum=0.9,
                                       weight_decay=5e-4)
        elif optimizer == "LRS":
            print('Using LRS')
            self.optimizer = optim.SGD(model.parameters(),
                                       lr=0.1,
                                       momentum=0.9,
                                       weight_decay=5e-4)
            self.lr_scheduler = optim.lr_scheduler.StepLR(
                self.optimizer, self.epochs // 3)
        elif optimizer == "radam":
            print('Using radam')
            self.optimizer = RAdam(model.parameters())
        else:
            raise ValueError('Unknown optimizer {}'.format(optimizer))
        self.opt_name = optimizer
        save_dir = os.path.join(logs_dir, model.name, train_loader.name)
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)

        self.savepath = os.path.join(
            save_dir,
            f'{model.name}_bs{batch_size}_e{epochs}_dspl{downsampling}_t{int(thresh*1000)}_id{run_id}.csv'
        )
        self.experiment_done = False
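        # skip training entirely if an existing CSV log already covers all requested epochs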
        if os.path.exists(self.savepath):
            trained_epochs = len(pd.read_csv(self.savepath, sep=';'))

            if trained_epochs >= epochs:
                self.experiment_done = True
                print(
                    'Logs for an identical experiment with the same run_id were found; '
                    'training will be skipped. Consider using another run_id.')
        if os.path.exists((self.savepath.replace('.csv', '.pt'))):
            self.model.load_state_dict(
                torch.load(self.savepath.replace('.csv',
                                                 '.pt'))['model_state_dict'])
            if data_prallel:
                self.model = nn.DataParallel(self.model)
            self.model = self.model.to(self.device)
            if half_precision:
                self.model = self.model.half()
            self.optimizer.load_state_dict(
                torch.load(self.savepath.replace('.csv', '.pt'))['optimizer'])
            self.start_epoch = torch.load(self.savepath.replace(
                '.csv', '.pt'))['epoch'] + 1
            initial_epoch = self._infer_initial_epoch(self.savepath)
            print('Resuming existing run, starting at epoch', self.start_epoch,
                  'from', self.savepath.replace('.csv', '.pt'))
        else:
            if half_precision:
                self.model = self.model.half()
            self.start_epoch = 0
            initial_epoch = 0
            self.parallel = data_prallel
            if data_prallel:
                self.model = nn.DataParallel(self.model)
            self.model = self.model.to(self.device)
        writer = CSVandPlottingWriter(self.savepath.replace('.csv', ''),
                                      fontsize=16,
                                      primary_metric='test_accuracy')
        writer2 = NPYWriter(self.savepath.replace('.csv', ''))
        self.pooling_strat = conv_method
        print('Setting saturation recording threshold to', thresh)
        self.half = half_precision

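        # track layer saturation ('lsat') and intrinsic dimensionality ('idim')
        # statistics for each layer through the attached writer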
        self.stats = CheckLayerSat(self.savepath.replace('.csv', ''), [writer],
                                   model,
                                   ignore_layer_names='convolution',
                                   stats=['lsat', 'idim'],
                                   sat_threshold=thresh,
                                   verbose=False,
                                   conv_method=conv_method,
                                   log_interval=1,
                                   device=self.saturation_device,
                                   reset_covariance=True,
                                   max_samples=None,
                                   initial_epoch=initial_epoch,
                                   interpolation_strategy='nearest'
                                   if downsampling is not None else None,
                                   interpolation_downsampling=4)

    def _infer_initial_epoch(self, savepath):
        if not os.path.exists(savepath):
            return 0
        else:
            df = pd.read_csv(savepath, sep=';', index_col=0)
            print(len(df) + 1)
            return len(df)

    def train(self):
        if self.experiment_done:
            return
        for epoch in range(self.start_epoch, self.epochs):
            #self.test(epoch=epoch)

            print('Start training epoch', epoch)
            print(
                "{} Epoch {}, training loss: {}, training accuracy: {}".format(
                    now(), epoch, *self.train_epoch()))
            self.test(epoch=epoch)
            if self.opt_name == "LRS":
                print('LRS step')
                self.lr_scheduler.step()
            self.stats.add_saturations()
            #self.stats.save()
            #if self.plot:
            #    plot_saturation_level_from_results(self.savepath, epoch)
        self.stats.close()
        return self.savepath  # savepath already ends in '.csv'

    def train_epoch(self):
        self.model.train()
        correct = 0
        total = 0
        running_loss = 0
        old_time = time()
        top5_accumulator = 0
        for batch, data in enumerate(self.train_loader):
            if batch % 10 == 0 and batch != 0:
                print(
                    batch, 'of', len(self.train_loader), 'processing time',
                    time() - old_time,
                    "top5_acc:" if self.compute_top_k else 'acc:',
                    round(top5_accumulator /
                          (batch), 3) if self.compute_top_k else correct /
                    total)
                old_time = time()
            inputs, labels = data
            if self.half:
                inputs, labels = inputs.to(self.device).half(), labels.to(
                    self.device)
            else:
                inputs, labels = inputs.to(self.device), labels.to(self.device)

            self.optimizer.zero_grad()
            outputs = self.model(inputs)
            if self.compute_top_k:
                top5_accumulator += accuracy(outputs, labels, (5, ))[0]
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)

            loss = self.criterion(outputs, labels)
            loss.backward()
            self.optimizer.step()
            correct += (predicted == labels.long()).sum().item()

            running_loss += loss.item()
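        # log epoch-level training loss and accuracy through the saturation tracker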
        self.stats.add_scalar('training_loss', running_loss / total)
        if self.compute_top_k:
            self.stats.add_scalar('training_accuracy',
                                  (top5_accumulator / (batch + 1)))
        else:
            self.stats.add_scalar('training_accuracy', correct / total)
        return running_loss / total, correct / total

    def test(self, epoch, save=True):
        self.model.eval()
        correct = 0
        total = 0
        test_loss = 0
        top5_accumulator = 0
        with torch.no_grad():
            for batch, data in enumerate(self.test_loader):
                if batch % 10 == 0:
                    print('Processing eval batch', batch, 'of',
                          len(self.test_loader))
                inputs, labels = data
                if self.half:
                    inputs, labels = inputs.to(self.device).half(), labels.to(
                        self.device)
                else:
                    inputs, labels = inputs.to(self.device), labels.to(
                        self.device)
                outputs = self.model(inputs)
                loss = self.criterion(outputs, labels)
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels.long()).sum().item()
                if self.compute_top_k:
                    top5_accumulator += accuracy(outputs, labels, (5, ))[0]
                test_loss += loss.item()

        self.stats.add_scalar('test_loss', test_loss / total)
        if self.compute_top_k:
            self.stats.add_scalar('test_accuracy',
                                  top5_accumulator / (batch + 1))
            print('{} Test Top5-Accuracy on {} images: {:.4f}'.format(
                now(), total, top5_accumulator / (batch + 1)))

        else:
            self.stats.add_scalar('test_accuracy', correct / total)
            print('{} Test Accuracy on {} images: {:.4f}'.format(
                now(), total, correct / total))
        if save:
            torch.save(
                {
                    'model_state_dict': self.model.state_dict(),
                    'optimizer': self.optimizer.state_dict(),
                    'epoch': epoch,
                    'test_loss': test_loss / total
                }, self.savepath.replace('.csv', '.pt'))
        return correct / total, test_loss / total