def main():
    """Train the sparse WaveRNN-style neural vocoder.

    Parses command-line arguments, loads feature statistics, builds the
    training/evaluation data generators, constructs the dual-GRU waveform
    decoder, and runs the epoch loop with staged GRU sparsification,
    TensorBoard logging and periodic checkpointing.
    """
    parser = argparse.ArgumentParser()
    # path setting
    parser.add_argument("--waveforms", type=str,
                        help="directory or list of wav files")
    parser.add_argument("--waveforms_eval", type=str,
                        help="directory or list of evaluation wav files")
    parser.add_argument("--feats", required=True, type=str,
                        help="directory or list of feat files")
    parser.add_argument("--feats_eval", required=True, type=str,
                        help="directory or list of evaluation feat files")
    parser.add_argument("--stats", required=True, type=str,
                        help="hdf5 file including feature statistics")
    parser.add_argument("--expdir", required=True, type=str,
                        help="directory to save the model")
    # network structure setting
    parser.add_argument("--upsampling_factor", default=120, type=int,
                        help="number of waveform samples per feature frame")
    parser.add_argument("--hidden_units_wave", default=384, type=int,
                        help="hidden units of the main GRU")
    parser.add_argument("--hidden_units_wave_2", default=16, type=int,
                        help="hidden units of the second GRU")
    parser.add_argument("--kernel_size_wave", default=7, type=int,
                        help="kernel size of dilated causal convolution")
    parser.add_argument("--dilation_size_wave", default=1, type=int,
                        help="dilation size of dilated causal convolution")
    parser.add_argument("--lpc", default=12, type=int,
                        help="number of LPC coefficients (0 disables LPC)")
    parser.add_argument("--mcep_dim", default=50, type=int,
                        help="number of mel-cepstrum dimensions")
    parser.add_argument("--right_size", default=0, type=int,
                        help="lookahead (right-side) frame size")
    # network training setting
    parser.add_argument("--lr", default=1e-4, type=float,
                        help="learning rate")
    parser.add_argument("--batch_size", default=15, type=int,
                        help="batch size (if set 0, utterance batch will be used)")
    parser.add_argument("--epoch_count", default=4000, type=int,
                        help="number of training epochs")
    parser.add_argument("--do_prob", default=0, type=float,
                        help="dropout probability")
    parser.add_argument("--batch_size_utt", default=5, type=int,
                        help="utterance batch size for training")
    parser.add_argument("--batch_size_utt_eval", default=5, type=int,
                        help="utterance batch size for evaluation")
    parser.add_argument("--n_workers", default=2, type=int,
                        help="number of dataloader workers")
    parser.add_argument("--n_quantize", default=256, type=int,
                        help="number of mu-law quantization levels")
    parser.add_argument("--causal_conv_wave", default=False, type=strtobool,
                        help="whether to use causal convolution")
    parser.add_argument("--n_stage", default=4, type=int,
                        help="number of sparsification stages")
    parser.add_argument("--t_start", default=20000, type=int,
                        help="iter idx to start sparsify")
    parser.add_argument("--t_end", default=4500000, type=int,
                        help="iter idx to finish densitiy sparsify")
    parser.add_argument("--interval", default=100, type=int,
                        help="interval in finishing densitiy sparsify")
    parser.add_argument("--densities", default="0.05-0.05-0.2", type=str,
                        help="final densitiy of reset, update, new hidden gate matrices")
    # other setting
    parser.add_argument("--pad_len", default=3000, type=int,
                        help="padding length in frames")
    parser.add_argument("--save_interval_iter", default=5000, type=int,
                        help="interval steps to save an iteration checkpoint")
    parser.add_argument("--save_interval_epoch", default=10, type=int,
                        help="interval epochs to save an epoch checkpoint")
    parser.add_argument("--log_interval_steps", default=50, type=int,
                        help="interval steps to log averaged training losses")
    parser.add_argument("--seed", default=1, type=int,
                        help="seed number")
    parser.add_argument("--resume", default=None, type=str,
                        help="model path to restart training")
    parser.add_argument("--pretrained", default=None, type=str,
                        help="pretrained model path to initialize from")
    parser.add_argument("--string_path", default=None, type=str,
                        help="h5 path of the input features")
    parser.add_argument("--GPU_device", default=None, type=int,
                        help="selection of GPU device")
    parser.add_argument("--verbose", default=1, type=int,
                        help="log level")
    args = parser.parse_args()

    if args.GPU_device is not None:
        os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
        os.environ["CUDA_VISIBLE_DEVICES"] = str(args.GPU_device)

    # make experimental directory
    if not os.path.exists(args.expdir):
        os.makedirs(args.expdir)

    # set log level
    if args.verbose == 1:
        logging.basicConfig(level=logging.INFO,
                            format='%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s',
                            datefmt='%m/%d/%Y %I:%M:%S',
                            filename=args.expdir + "/train.log")
        logging.getLogger().addHandler(logging.StreamHandler())
    elif args.verbose > 1:
        logging.basicConfig(level=logging.DEBUG,
                            format='%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s',
                            datefmt='%m/%d/%Y %I:%M:%S',
                            filename=args.expdir + "/train.log")
        logging.getLogger().addHandler(logging.StreamHandler())
    else:
        logging.basicConfig(level=logging.WARN,
                            format='%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s',
                            datefmt='%m/%d/%Y %I:%M:%S',
                            filename=args.expdir + "/train.log")
        logging.getLogger().addHandler(logging.StreamHandler())
        # FIX: logging.warn is a deprecated alias of logging.warning
        logging.warning("logging is disabled.")

    # fix seed
    os.environ['PYTHONHASHSEED'] = str(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if str(device) == "cpu":
        raise ValueError('ERROR: Training by CPU is not acceptable.')
    torch.backends.cudnn.benchmark = True  # faster

    # load normalization statistics for the selected feature kind
    # NOTE(review): --string_path defaults to None, which would raise a
    # TypeError here — callers are expected to always pass it; confirm.
    if 'mel' in args.string_path:
        mean_stats = torch.FloatTensor(read_hdf5(args.stats, "/mean_melsp"))
        scale_stats = torch.FloatTensor(read_hdf5(args.stats, "/scale_melsp"))
        args.excit_dim = 0
    else:
        mean_stats = torch.FloatTensor(read_hdf5(args.stats, "/mean_" + args.string_path.replace("/", "")))
        scale_stats = torch.FloatTensor(read_hdf5(args.stats, "/scale_" + args.string_path.replace("/", "")))
        # infer excitation/code-aperiodicity dims from the stats vector length
        if mean_stats.shape[0] > args.mcep_dim + 2:
            if 'feat_org_lf0' in args.string_path:
                args.cap_dim = mean_stats.shape[0] - (args.mcep_dim + 2)
                args.excit_dim = 2 + args.cap_dim
            else:
                args.cap_dim = mean_stats.shape[0] - (args.mcep_dim + 3)
                args.excit_dim = 2 + 1 + args.cap_dim
        else:
            args.cap_dim = None
            args.excit_dim = 2

    # save args as conf
    torch.save(args, args.expdir + "/model.conf")

    # define network
    model_waveform = GRU_WAVE_DECODER_DUALGRU_COMPACT(
        feat_dim=args.mcep_dim + args.excit_dim,
        upsampling_factor=args.upsampling_factor,
        hidden_units=args.hidden_units_wave,
        hidden_units_2=args.hidden_units_wave_2,
        kernel_size=args.kernel_size_wave,
        dilation_size=args.dilation_size_wave,
        n_quantize=args.n_quantize,
        causal_conv=args.causal_conv_wave,
        lpc=args.lpc,
        right_size=args.right_size,
        do_prob=args.do_prob)
    logging.info(model_waveform)
    criterion_ce = torch.nn.CrossEntropyLoss(reduction='none')
    criterion_l1 = torch.nn.L1Loss(reduction='none')

    # send to gpu
    if torch.cuda.is_available():
        model_waveform.cuda()
        criterion_ce.cuda()
        criterion_l1.cuda()
        if args.pretrained is None:
            mean_stats = mean_stats.cuda()
            scale_stats = scale_stats.cuda()
    else:
        logging.error("gpu is not available. please check the setting.")
        sys.exit(1)

    model_waveform.train()

    # initialize the input-normalization layer from the statistics
    # (skipped when starting from a pretrained model, which carries its own)
    if args.pretrained is None:
        model_waveform.scale_in.weight = torch.nn.Parameter(
            torch.unsqueeze(torch.diag(1.0 / scale_stats.data), 2))
        model_waveform.scale_in.bias = torch.nn.Parameter(
            -(mean_stats.data / scale_stats.data))

    # freeze the normalization layer (and LPC logits) while training the rest
    for param in model_waveform.parameters():
        param.requires_grad = True
    for param in model_waveform.scale_in.parameters():
        param.requires_grad = False
    if args.lpc > 0:
        for param in model_waveform.logits.parameters():
            param.requires_grad = False
    parameters = filter(lambda p: p.requires_grad, model_waveform.parameters())
    parameters = sum([np.prod(p.size()) for p in parameters]) / 1000000
    logging.info('Trainable Parameters (waveform): %.3f million' % parameters)

    # optimizer over the trainable submodules only
    module_list = list(model_waveform.conv.parameters())
    module_list += list(model_waveform.conv_s_c.parameters()) + list(model_waveform.embed_wav.parameters())
    module_list += list(model_waveform.gru.parameters()) + list(model_waveform.gru_2.parameters())
    module_list += list(model_waveform.out.parameters())
    optimizer = RAdam(module_list, lr=args.lr)

    # resume / pretrained initialization
    if args.pretrained is not None and args.resume is None:
        checkpoint = torch.load(args.pretrained)
        model_waveform.load_state_dict(checkpoint["model_waveform"])
        epoch_idx = checkpoint["iterations"]
        logging.info("pretrained from %d-iter checkpoint." % epoch_idx)
        epoch_idx = 0  # restart epoch counting; only the weights are kept
    elif args.resume is not None:
        checkpoint = torch.load(args.resume)
        model_waveform.load_state_dict(checkpoint["model_waveform"])
        optimizer.load_state_dict(checkpoint["optimizer"])
        epoch_idx = checkpoint["iterations"]
        logging.info("restored from %d-iter checkpoint." % epoch_idx)
    else:
        epoch_idx = 0

    def zero_wav_pad(x):
        # zero-pad waveforms to a fixed sample length
        return padding(x, args.pad_len * args.upsampling_factor, value=0.0)

    def zero_feat_pad(x):
        # zero-pad features to a fixed frame length
        return padding(x, args.pad_len, value=0.0)

    pad_wav_transform = transforms.Compose([zero_wav_pad])
    pad_feat_transform = transforms.Compose([zero_feat_pad])
    wav_transform = transforms.Compose([lambda x: encode_mu_law(x, args.n_quantize)])

    # define generator training
    if os.path.isdir(args.waveforms):
        filenames = sorted(find_files(args.waveforms, "*.wav", use_dir_name=False))
        wav_list = [args.waveforms + "/" + filename for filename in filenames]
    elif os.path.isfile(args.waveforms):
        wav_list = read_txt(args.waveforms)
    else:
        logging.error("--waveforms should be directory or list.")
        sys.exit(1)
    if os.path.isdir(args.feats):
        feat_list = [args.feats + "/" + filename for filename in filenames]
    elif os.path.isfile(args.feats):
        feat_list = read_txt(args.feats)
    else:
        logging.error("--feats should be directory or list.")
        sys.exit(1)
    assert len(wav_list) == len(feat_list)
    logging.info("number of training data = %d." % len(feat_list))
    dataset = FeatureDatasetNeuVoco(wav_list, feat_list, pad_wav_transform, pad_feat_transform,
                                    args.upsampling_factor, args.string_path,
                                    wav_transform=wav_transform)
    dataloader = DataLoader(dataset, batch_size=args.batch_size_utt, shuffle=True,
                            num_workers=args.n_workers)
    generator = data_generator(dataloader, device, args.batch_size, args.upsampling_factor,
                               limit_count=None)

    # define generator evaluation
    if os.path.isdir(args.waveforms_eval):
        filenames = sorted(find_files(args.waveforms_eval, "*.wav", use_dir_name=False))
        # BUG FIX: eval paths were built from args.waveforms (training dir)
        wav_list_eval = [args.waveforms_eval + "/" + filename for filename in filenames]
    elif os.path.isfile(args.waveforms_eval):
        wav_list_eval = read_txt(args.waveforms_eval)
    else:
        logging.error("--waveforms_eval should be directory or list.")
        sys.exit(1)
    if os.path.isdir(args.feats_eval):
        feat_list_eval = [args.feats_eval + "/" + filename for filename in filenames]
    # BUG FIX: branch tested os.path.isfile(args.feats) instead of args.feats_eval
    elif os.path.isfile(args.feats_eval):
        feat_list_eval = read_txt(args.feats_eval)
    else:
        logging.error("--feats_eval should be directory or list.")
        sys.exit(1)
    assert len(wav_list_eval) == len(feat_list_eval)
    logging.info("number of evaluation data = %d." % len(feat_list_eval))
    dataset_eval = FeatureDatasetNeuVoco(wav_list_eval, feat_list_eval, pad_wav_transform,
                                         pad_feat_transform, args.upsampling_factor,
                                         args.string_path, wav_transform=wav_transform)
    dataloader_eval = DataLoader(dataset_eval, batch_size=args.batch_size_utt_eval,
                                 shuffle=False, num_workers=args.n_workers)
    generator_eval = data_generator(dataloader_eval, device, args.batch_size,
                                    args.upsampling_factor, limit_count=None)

    writer = SummaryWriter(args.expdir)
    total_train_loss = defaultdict(list)
    total_eval_loss = defaultdict(list)

    # sparsification schedule: split [t_start, t_end] into n_stage windows,
    # each lowering the GRU gate-matrix densities toward args.densities
    density_deltas_ = args.densities.split('-')
    density_deltas = [None] * len(density_deltas_)
    for i in range(len(density_deltas_)):
        density_deltas[i] = (1 - float(density_deltas_[i])) / args.n_stage
    t_deltas = [None] * args.n_stage
    t_starts = [None] * args.n_stage
    t_ends = [None] * args.n_stage
    densities = [None] * args.n_stage
    t_delta = args.t_end - args.t_start + 1
    if args.n_stage > 3:
        t_deltas[0] = round((1 / 2) * 0.2 * t_delta)
    else:
        t_deltas[0] = round(0.2 * t_delta)
    t_starts[0] = args.t_start
    t_ends[0] = args.t_start + t_deltas[0] - 1
    densities[0] = [None] * len(density_deltas)
    for j in range(len(density_deltas)):
        densities[0][j] = 1 - density_deltas[j]
    for i in range(1, args.n_stage):
        if i < args.n_stage - 1:
            if args.n_stage > 3:
                if i < 2:
                    t_deltas[i] = round((1 / 2) * 0.2 * t_delta)
                else:
                    if args.n_stage > 4:
                        t_deltas[i] = round((1 / 2) * 0.3 * t_delta)
                    else:
                        t_deltas[i] = round(0.3 * t_delta)
            else:
                t_deltas[i] = round(0.3 * t_delta)
        else:
            t_deltas[i] = round(0.5 * t_delta)
        t_starts[i] = t_ends[i - 1] + 1
        t_ends[i] = t_starts[i] + t_deltas[i] - 1
        densities[i] = [None] * len(density_deltas)
        if i < args.n_stage - 1:
            for j in range(len(density_deltas)):
                densities[i][j] = densities[i - 1][j] - density_deltas[j]
        else:
            # last stage ends exactly at the requested final densities
            for j in range(len(density_deltas)):
                densities[i][j] = float(density_deltas_[j])
    logging.info(t_delta)
    logging.info(t_deltas)
    logging.info(t_starts)
    logging.info(t_ends)
    logging.info(args.interval)
    logging.info(densities)
    idx_stage = 0

    # train
    total = 0
    iter_count = 0
    loss_ce = []
    loss_err = []
    min_eval_loss_ce = 99999999.99
    min_eval_loss_ce_std = 99999999.99
    min_eval_loss_err = 99999999.99
    min_eval_loss_err_std = 99999999.99
    iter_idx = 0
    min_idx = -1
    # fast-forward the sparsification stage when resuming mid-schedule
    while idx_stage < args.n_stage - 1 and iter_idx + 1 >= t_starts[idx_stage + 1]:
        idx_stage += 1
        logging.info(idx_stage)
    change_min_flag = False
    if args.resume is not None:
        np.random.set_state(checkpoint["numpy_random_state"])
        torch.set_rng_state(checkpoint["torch_random_state"])
    logging.info("==%d EPOCH==" % (epoch_idx + 1))
    logging.info("Training data")
    while epoch_idx < args.epoch_count:
        start = time.time()
        batch_x, batch_feat, c_idx, utt_idx, featfile, x_bs, f_bs, x_ss, f_ss, n_batch_utt, \
            del_index_utt, max_slen, idx_select, idx_select_full, slens_acc = next(generator)
        if c_idx < 0:  # summarize epoch
            # save RNG state so the epoch boundary is reproducible
            numpy_random_state = np.random.get_state()
            torch_random_state = torch.get_rng_state()
            # report current epoch
            logging.info("(EPOCH:%d) average optimization loss = %.6f (+- %.6f) %.6f (+- %.6f) %% ;; "
                         "(%.3f min., %.3f sec / batch)" % (epoch_idx + 1, np.mean(loss_ce), np.std(loss_ce),
                         np.mean(loss_err), np.std(loss_err), total / 60.0, total / iter_count))
            logging.info("estimated time until max. epoch = {0.days:02}:{0.hours:02}:{0.minutes:02}:"
                         "{0.seconds:02}".format(relativedelta(seconds=int((args.epoch_count - (epoch_idx + 1)) * total))))
            # compute loss in evaluation data
            total = 0
            iter_count = 0
            loss_ce = []
            loss_err = []
            model_waveform.eval()
            for param in model_waveform.parameters():
                param.requires_grad = False
            logging.info("Evaluation data")
            while True:
                with torch.no_grad():
                    start = time.time()
                    batch_x, batch_feat, c_idx, utt_idx, featfile, x_bs, f_bs, x_ss, f_ss, n_batch_utt, \
                        del_index_utt, max_slen, idx_select, idx_select_full, slens_acc = next(generator_eval)
                    if c_idx < 0:
                        break
                    x_es = x_ss + x_bs
                    f_es = f_ss + f_bs
                    logging.info(f'{x_ss} {x_bs} {x_es} {f_ss} {f_bs} {f_es} {max_slen}')
                    # slice the current sub-batch window; x_ss>0 means we are
                    # continuing an utterance, so the previous sample(s) come
                    # from the tensor itself instead of silence padding
                    if x_ss > 0:
                        if x_es <= max_slen:
                            batch_x_prev = batch_x[:, x_ss - 1:x_es - 1]
                            if args.lpc > 0:
                                if x_ss - args.lpc >= 0:
                                    batch_x_lpc = batch_x[:, x_ss - args.lpc:x_es - 1]
                                else:
                                    batch_x_lpc = F.pad(batch_x[:, :x_es - 1], (-(x_ss - args.lpc), 0),
                                                        "constant", args.n_quantize // 2)
                            batch_feat = batch_feat[:, f_ss:f_es]
                            batch_x = batch_x[:, x_ss:x_es]
                        else:
                            batch_x_prev = batch_x[:, x_ss - 1:-1]
                            if args.lpc > 0:
                                if x_ss - args.lpc >= 0:
                                    batch_x_lpc = batch_x[:, x_ss - args.lpc:-1]
                                else:
                                    batch_x_lpc = F.pad(batch_x[:, :-1], (-(x_ss - args.lpc), 0),
                                                        "constant", args.n_quantize // 2)
                            batch_feat = batch_feat[:, f_ss:]
                            batch_x = batch_x[:, x_ss:]
                    else:
                        # utterance start: pad previous samples with mu-law "zero"
                        batch_x_prev = F.pad(batch_x[:, :x_es - 1], (1, 0), "constant", args.n_quantize // 2)
                        if args.lpc > 0:
                            batch_x_lpc = F.pad(batch_x[:, :x_es - 1], (args.lpc, 0),
                                                "constant", args.n_quantize // 2)
                        batch_feat = batch_feat[:, :f_es]
                        batch_x = batch_x[:, :x_es]
                    if f_ss > 0:
                        # carry GRU hidden states across sub-batches; drop rows
                        # of utterances that ended in the previous sub-batch
                        if len(del_index_utt) > 0:
                            h_x = torch.FloatTensor(np.delete(h_x.cpu().data.numpy(),
                                                              del_index_utt, axis=1)).to(device)
                            h_x_2 = torch.FloatTensor(np.delete(h_x_2.cpu().data.numpy(),
                                                                del_index_utt, axis=1)).to(device)
                        if args.lpc > 0:
                            batch_x_output, h_x, h_x_2 = model_waveform(batch_feat, batch_x_prev,
                                                                        h=h_x, h_2=h_x_2, x_lpc=batch_x_lpc)
                        else:
                            batch_x_output, h_x, h_x_2 = model_waveform(batch_feat, batch_x_prev,
                                                                        h=h_x, h_2=h_x_2)
                    else:
                        if args.lpc > 0:
                            batch_x_output, h_x, h_x_2 = model_waveform(batch_feat, batch_x_prev,
                                                                        x_lpc=batch_x_lpc)
                        else:
                            batch_x_output, h_x, h_x_2 = model_waveform(batch_feat, batch_x_prev)
                    # samples check
                    i = np.random.randint(0, batch_x_output.shape[0])
                    logging.info("%s" % (os.path.join(os.path.basename(os.path.dirname(featfile[i])),
                                                      os.path.basename(featfile[i]))))
                    # handle short ending (utterances shorter than the window)
                    if len(idx_select) > 0:
                        logging.info('len_idx_select: ' + str(len(idx_select)))
                        batch_loss_ce_select = 0
                        batch_loss_err_select = 0
                        for j in range(len(idx_select)):
                            k = idx_select[j]
                            slens_utt = slens_acc[k]
                            logging.info('%s %d' % (featfile[k], slens_utt))
                            batch_x_output_ = batch_x_output[k, :slens_utt]
                            batch_x_ = batch_x[k, :slens_utt]
                            batch_loss_ce_select += torch.mean(criterion_ce(batch_x_output_, batch_x_))
                            batch_loss_err_select += torch.mean(torch.sum(
                                100 * criterion_l1(F.softmax(batch_x_output_, dim=-1),
                                                   F.one_hot(batch_x_, num_classes=args.n_quantize).float()), -1))
                        # BUG FIX: removed stray `batch_loss += batch_loss_ce_select`
                        # here — batch_loss is a training-only accumulator and no
                        # backprop happens under torch.no_grad()
                        batch_loss_ce_select /= len(idx_select)
                        batch_loss_err_select /= len(idx_select)
                        total_eval_loss["eval/loss_ce"].append(batch_loss_ce_select.item())
                        total_eval_loss["eval/loss_err"].append(batch_loss_err_select.item())
                        loss_ce.append(batch_loss_ce_select.item())
                        loss_err.append(batch_loss_err_select.item())
                        if len(idx_select_full) > 0:
                            logging.info('len_idx_select_full: ' + str(len(idx_select_full)))
                            batch_x = torch.index_select(batch_x, 0, idx_select_full)
                            batch_x_output = torch.index_select(batch_x_output, 0, idx_select_full)
                        else:
                            logging.info("batch loss select %.3f %.3f (%.3f sec)" % (batch_loss_ce_select.item(),
                                         batch_loss_err_select.item(), time.time() - start))
                            iter_count += 1
                            total += time.time() - start
                            continue
                    # loss over the full-length utterances of this sub-batch
                    batch_loss_ce_ = torch.mean(criterion_ce(batch_x_output.reshape(-1, args.n_quantize),
                                                             batch_x.reshape(-1)).reshape(batch_x_output.shape[0], -1), -1)
                    batch_loss_err_ = torch.mean(torch.sum(
                        100 * criterion_l1(F.softmax(batch_x_output, dim=-1),
                                           F.one_hot(batch_x, num_classes=args.n_quantize).float()), -1), -1)
                    batch_loss_ce = batch_loss_ce_.mean()
                    batch_loss_err = batch_loss_err_.mean()
                    total_eval_loss["eval/loss_ce"].append(batch_loss_ce.item())
                    total_eval_loss["eval/loss_err"].append(batch_loss_err.item())
                    loss_ce.append(batch_loss_ce.item())
                    loss_err.append(batch_loss_err.item())
                    logging.info("batch eval loss [%d] %d %d %d %d %d : %.3f %.3f %% (%.3f sec)" % (c_idx + 1,
                                 max_slen, x_ss, x_bs, f_ss, f_bs, batch_loss_ce.item(),
                                 batch_loss_err.item(), time.time() - start))
                    iter_count += 1
                    total += time.time() - start
            logging.info('sme')
            for key in total_eval_loss.keys():
                total_eval_loss[key] = np.mean(total_eval_loss[key])
                logging.info(f"(Steps: {iter_idx}) {key} = {total_eval_loss[key]:.4f}.")
            write_to_tensorboard(writer, iter_idx, total_eval_loss)
            total_eval_loss = defaultdict(list)
            eval_loss_ce = np.mean(loss_ce)
            eval_loss_ce_std = np.std(loss_ce)
            eval_loss_err = np.mean(loss_err)
            eval_loss_err_std = np.std(loss_err)
            logging.info("(EPOCH:%d) average evaluation loss = %.6f (+- %.6f) %.6f (+- %.6f) %% ;; (%.3f min., "
                         "%.3f sec / batch)" % (epoch_idx + 1, eval_loss_ce, eval_loss_ce_std,
                         eval_loss_err, eval_loss_err_std, total / 60.0, total / iter_count))
            # track best model by mean+std of eval cross-entropy
            if (eval_loss_ce + eval_loss_ce_std) <= (min_eval_loss_ce + min_eval_loss_ce_std) \
                    or (eval_loss_ce <= min_eval_loss_ce):
                min_eval_loss_ce = eval_loss_ce
                min_eval_loss_ce_std = eval_loss_ce_std
                min_eval_loss_err = eval_loss_err
                min_eval_loss_err_std = eval_loss_err_std
                min_idx = epoch_idx
                change_min_flag = True
            if change_min_flag:
                logging.info("min_eval_loss = %.6f (+- %.6f) %.6f (+- %.6f) %% min_idx=%d" % (min_eval_loss_ce,
                             min_eval_loss_ce_std, min_eval_loss_err, min_eval_loss_err_std, min_idx + 1))
            logging.info('save epoch:%d' % (epoch_idx + 1))
            save_checkpoint(args.expdir, model_waveform, optimizer, numpy_random_state,
                            torch_random_state, epoch_idx + 1)
            total = 0
            iter_count = 0
            loss_ce = []
            loss_err = []
            epoch_idx += 1
            # restore RNG state so evaluation did not perturb training order
            np.random.set_state(numpy_random_state)
            torch.set_rng_state(torch_random_state)
            model_waveform.train()
            for param in model_waveform.parameters():
                param.requires_grad = True
            for param in model_waveform.scale_in.parameters():
                param.requires_grad = False
            if args.lpc > 0:
                for param in model_waveform.logits.parameters():
                    param.requires_grad = False
            # start next epoch
            if epoch_idx < args.epoch_count:
                start = time.time()
                logging.info("==%d EPOCH==" % (epoch_idx + 1))
                logging.info("Training data")
                batch_x, batch_feat, c_idx, utt_idx, featfile, x_bs, f_bs, x_ss, f_ss, n_batch_utt, \
                    del_index_utt, max_slen, idx_select, idx_select_full, slens_acc = next(generator)
        # feedforward and backpropagate current batch
        if epoch_idx < args.epoch_count:
            logging.info("%d iteration [%d]" % (iter_idx + 1, epoch_idx + 1))
            x_es = x_ss + x_bs
            f_es = f_ss + f_bs
            logging.info(f'{x_ss} {x_bs} {x_es} {f_ss} {f_bs} {f_es} {max_slen}')
            # same windowing as in evaluation (see comments there)
            if x_ss > 0:
                if x_es <= max_slen:
                    batch_x_prev = batch_x[:, x_ss - 1:x_es - 1]
                    if args.lpc > 0:
                        if x_ss - args.lpc >= 0:
                            batch_x_lpc = batch_x[:, x_ss - args.lpc:x_es - 1]
                        else:
                            batch_x_lpc = F.pad(batch_x[:, :x_es - 1], (-(x_ss - args.lpc), 0),
                                                "constant", args.n_quantize // 2)
                    batch_feat = batch_feat[:, f_ss:f_es]
                    batch_x = batch_x[:, x_ss:x_es]
                else:
                    batch_x_prev = batch_x[:, x_ss - 1:-1]
                    if args.lpc > 0:
                        if x_ss - args.lpc >= 0:
                            batch_x_lpc = batch_x[:, x_ss - args.lpc:-1]
                        else:
                            batch_x_lpc = F.pad(batch_x[:, :-1], (-(x_ss - args.lpc), 0),
                                                "constant", args.n_quantize // 2)
                    batch_feat = batch_feat[:, f_ss:]
                    batch_x = batch_x[:, x_ss:]
            else:
                batch_x_prev = F.pad(batch_x[:, :x_es - 1], (1, 0), "constant", args.n_quantize // 2)
                if args.lpc > 0:
                    batch_x_lpc = F.pad(batch_x[:, :x_es - 1], (args.lpc, 0),
                                        "constant", args.n_quantize // 2)
                batch_feat = batch_feat[:, :f_es]
                batch_x = batch_x[:, :x_es]
            if f_ss > 0:
                if len(del_index_utt) > 0:
                    h_x = torch.FloatTensor(np.delete(h_x.cpu().data.numpy(),
                                                      del_index_utt, axis=1)).to(device)
                    h_x_2 = torch.FloatTensor(np.delete(h_x_2.cpu().data.numpy(),
                                                        del_index_utt, axis=1)).to(device)
                if args.lpc > 0:
                    batch_x_output, h_x, h_x_2 = model_waveform(batch_feat, batch_x_prev,
                                                                h=h_x, h_2=h_x_2, x_lpc=batch_x_lpc, do=True)
                else:
                    batch_x_output, h_x, h_x_2 = model_waveform(batch_feat, batch_x_prev,
                                                                h=h_x, h_2=h_x_2, do=True)
            else:
                if args.lpc > 0:
                    batch_x_output, h_x, h_x_2 = model_waveform(batch_feat, batch_x_prev,
                                                                x_lpc=batch_x_lpc, do=True)
                else:
                    batch_x_output, h_x, h_x_2 = model_waveform(batch_feat, batch_x_prev, do=True)
            # samples check (also advances np RNG, matching original behavior)
            i = np.random.randint(0, batch_x_output.shape[0])
            logging.info("%s" % (os.path.join(os.path.basename(os.path.dirname(featfile[i])),
                                              os.path.basename(featfile[i]))))
            # handle short ending
            batch_loss = 0
            if len(idx_select) > 0:
                logging.info('len_idx_select: ' + str(len(idx_select)))
                batch_loss_ce_select = 0
                batch_loss_err_select = 0
                for j in range(len(idx_select)):
                    k = idx_select[j]
                    slens_utt = slens_acc[k]
                    logging.info('%s %d' % (featfile[k], slens_utt))
                    batch_x_output_ = batch_x_output[k, :slens_utt]
                    batch_x_ = batch_x[k, :slens_utt]
                    batch_loss_ce_select += torch.mean(criterion_ce(batch_x_output_, batch_x_))
                    batch_loss_err_select += torch.mean(torch.sum(
                        100 * criterion_l1(F.softmax(batch_x_output_, dim=-1),
                                           F.one_hot(batch_x_, num_classes=args.n_quantize).float()), -1))
                # accumulate the summed CE before averaging for the log values
                batch_loss += batch_loss_ce_select
                batch_loss_ce_select /= len(idx_select)
                batch_loss_err_select /= len(idx_select)
                total_train_loss["train/loss_ce"].append(batch_loss_ce_select.item())
                total_train_loss["train/loss_err"].append(batch_loss_err_select.item())
                loss_ce.append(batch_loss_ce_select.item())
                loss_err.append(batch_loss_err_select.item())
                if len(idx_select_full) > 0:
                    logging.info('len_idx_select_full: ' + str(len(idx_select_full)))
                    batch_x = torch.index_select(batch_x, 0, idx_select_full)
                    batch_x_output = torch.index_select(batch_x_output, 0, idx_select_full)
                else:
                    # all utterances were short: optimize on the select loss only
                    optimizer.zero_grad()
                    batch_loss.backward()
                    flag = False
                    for name, param in model_waveform.named_parameters():
                        if param.requires_grad:
                            grad_norm = param.grad.norm()
                            if torch.isnan(grad_norm) or torch.isinf(grad_norm):
                                flag = True
                    if flag:
                        logging.info("explode grad")
                        optimizer.zero_grad()
                        continue
                    torch.nn.utils.clip_grad_norm_(model_waveform.parameters(), 10)
                    optimizer.step()
                    with torch.no_grad():
                        # advance sparsification stage, then apply sparsify step
                        if idx_stage < args.n_stage - 1 and iter_idx + 1 == t_starts[idx_stage + 1]:
                            idx_stage += 1
                        if idx_stage > 0:
                            sparsify(model_waveform, iter_idx + 1, t_starts[idx_stage], t_ends[idx_stage],
                                     args.interval, densities[idx_stage], densities_p=densities[idx_stage - 1])
                        else:
                            sparsify(model_waveform, iter_idx + 1, t_starts[idx_stage], t_ends[idx_stage],
                                     args.interval, densities[idx_stage])
                    logging.info("batch loss select %.3f %.3f (%.3f sec)" % (batch_loss_ce_select.item(),
                                 batch_loss_err_select.item(), time.time() - start))
                    iter_idx += 1
                    if iter_idx % args.save_interval_iter == 0:
                        logging.info('save iter:%d' % (iter_idx))
                        save_checkpoint(args.expdir, model_waveform, optimizer,
                                        np.random.get_state(), torch.get_rng_state(), iter_idx)
                    iter_count += 1
                    if iter_idx % args.log_interval_steps == 0:
                        logging.info('smt')
                        for key in total_train_loss.keys():
                            total_train_loss[key] = np.mean(total_train_loss[key])
                            logging.info(f"(Steps: {iter_idx}) {key} = {total_train_loss[key]:.4f}.")
                        write_to_tensorboard(writer, iter_idx, total_train_loss)
                        total_train_loss = defaultdict(list)
                    total += time.time() - start
                    continue
            # loss over full-length utterances
            batch_loss_ce_ = torch.mean(criterion_ce(batch_x_output.reshape(-1, args.n_quantize),
                                                     batch_x.reshape(-1)).reshape(batch_x_output.shape[0], -1), -1)
            batch_loss_err_ = torch.mean(torch.sum(
                100 * criterion_l1(F.softmax(batch_x_output, dim=-1),
                                   F.one_hot(batch_x, num_classes=args.n_quantize).float()), -1), -1)
            batch_loss_ce = batch_loss_ce_.mean()
            batch_loss_err = batch_loss_err_.mean()
            total_train_loss["train/loss_ce"].append(batch_loss_ce.item())
            total_train_loss["train/loss_err"].append(batch_loss_err.item())
            loss_ce.append(batch_loss_ce.item())
            loss_err.append(batch_loss_err.item())
            batch_loss += batch_loss_ce_.sum()
            optimizer.zero_grad()
            batch_loss.backward()
            flag = False
            for name, param in model_waveform.named_parameters():
                if param.requires_grad:
                    grad_norm = param.grad.norm()
                    if torch.isnan(grad_norm) or torch.isinf(grad_norm):
                        flag = True
            if flag:
                # skip the update on NaN/Inf gradients
                logging.info("explode grad")
                optimizer.zero_grad()
                continue
            torch.nn.utils.clip_grad_norm_(model_waveform.parameters(), 10)
            optimizer.step()
            with torch.no_grad():
                if idx_stage < args.n_stage - 1 and iter_idx + 1 == t_starts[idx_stage + 1]:
                    idx_stage += 1
                if idx_stage > 0:
                    sparsify(model_waveform, iter_idx + 1, t_starts[idx_stage], t_ends[idx_stage],
                             args.interval, densities[idx_stage], densities_p=densities[idx_stage - 1])
                else:
                    sparsify(model_waveform, iter_idx + 1, t_starts[idx_stage], t_ends[idx_stage],
                             args.interval, densities[idx_stage])
            logging.info("batch loss [%d] %d %d %d %d %d : %.3f %.3f %% (%.3f sec)" % (c_idx + 1,
                         max_slen, x_ss, x_bs, f_ss, f_bs, batch_loss_ce.item(),
                         batch_loss_err.item(), time.time() - start))
            iter_idx += 1
            if iter_idx % args.save_interval_iter == 0:
                logging.info('save iter:%d' % (iter_idx))
                save_checkpoint(args.expdir, model_waveform, optimizer,
                                np.random.get_state(), torch.get_rng_state(), iter_idx)
            iter_count += 1
            if iter_idx % args.log_interval_steps == 0:
                logging.info('smt')
                for key in total_train_loss.keys():
                    total_train_loss[key] = np.mean(total_train_loss[key])
                    logging.info(f"(Steps: {iter_idx}) {key} = {total_train_loss[key]:.4f}.")
                write_to_tensorboard(writer, iter_idx, total_train_loss)
                total_train_loss = defaultdict(list)
            total += time.time() - start

    # save final model
    model_waveform.cpu()
    torch.save({"model_waveform": model_waveform.state_dict()}, args.expdir + "/checkpoint-final.pkl")
    logging.info("final checkpoint created.")
# NOTE(review): this is a mid-function fragment — the enclosing `def` (which
# binds `model`, `args`, `RAdam`, `walk_grid`, `CustomDataset`, `collate_fn`)
# and the remainder of the epoch loop lie outside the visible region.
optimizer = RAdam(model.parameters(), args.lr)
# Select the reconstruction criterion from the CLI flag.
if args.loss == 'mse':
    loss_fn = torch.nn.MSELoss()
else:
    loss_fn = torch.nn.BCELoss()
# loading checkpoint: pick the newest file, assuming names are "<int>.<ext>"
if args.load_path:
    files = os.listdir(args.load_path)
    files = sorted(files, key=lambda x: int(os.path.splitext(x)[0]))
    last_path = os.path.join(args.load_path, files[-1])
    checkpoint = torch.load(last_path)
    model.load_state_dict(checkpoint['model'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    loss = checkpoint['loss']
if args.grid_latent:
    # Latent-grid walk only, then hard-exit (os._exit skips normal cleanup).
    walk_grid(model)
    os._exit(0)
dataset = CustomDataset(args)
dataset_loader = torch.utils.data.DataLoader(dataset=dataset,
                                             num_workers=args.num_workers,
                                             batch_size=args.batch_size,
                                             shuffle=args.shuffle,
                                             collate_fn=collate_fn)
# NOTE(review): range(1, args.epoch) runs args.epoch - 1 epochs — confirm
# intended. Loop body continues beyond this visible chunk.
for epoch in range(1, args.epoch):
    epoch_loss_rec = []
    epoch_loss_kl = []
def main():
    """Entry point: parse CLI options, build datasets/loaders and the model,
    create the RAdam optimizer, optionally resume from a checkpoint, and run
    training via ``train(...)``.

    Side effects: reads the dataset directory given as the positional ``data``
    argument, may load checkpoint files from ``--resume``, and moves the model
    to GPU unless ``--no-cuda`` is given.
    """
    parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
    parser.add_argument('data', metavar='DIR', help='path to dataset')
    parser.add_argument('--batch-size', type=int, default=128, metavar='N',
                        help='input batch size for training (default: 128)')
    # FIX: help text previously said "for training"; this flag sizes the
    # validation loader below.
    parser.add_argument('--batch-size-val', type=int, default=64, metavar='N',
                        help='input batch size for validation (default: 64)')
    parser.add_argument('--epochs', type=int, default=1000, metavar='N',
                        help='number of epochs to train (default: 1000)')
    parser.add_argument('--lr', type=float, default=1e-4, metavar='LR',
                        help='learning rate (default: 1e-4)')
    # FIX: image size is a pixel count fed to dataset() and Model(); it was
    # declared type=float, which propagated a float where an int is expected.
    parser.add_argument('--image-size', type=int, default=80, metavar='IMSIZE',
                        help='input image size (default: 80)')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA training')
    parser.add_argument('--multi-gpu', action='store_true', default=False,
                        help='parallel training on multiple GPUs')
    parser.add_argument('--seed', type=int, default=1, metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument('--save-model', action='store_true', default=False,
                        help='For Saving the current Model')
    parser.add_argument('--model-save-path', default='', type=str, metavar='PATH',
                        help='For Saving the current Model')
    parser.add_argument('--resume', default='', type=str, metavar='PATH',
                        help='path to latest checkpoint (default: none)')
    args = parser.parse_args()

    torch.set_default_tensor_type('torch.FloatTensor')
    device = torch.device("cpu" if args.no_cuda else "cuda")

    # Build train/val datasets and loaders. `dataset`, `ToTensor`, `transforms`
    # and `DataLoader` come from this file's module-level imports.
    train_data = dataset(args.data, "train", args.image_size,
                         transform=transforms.Compose([ToTensor()]), shuffle=True)
    valid_data = dataset(args.data, "val", args.image_size,
                         transform=transforms.Compose([ToTensor()]))
    trainloader = DataLoader(train_data, batch_size=args.batch_size,
                             shuffle=True, num_workers=8)
    validloader = DataLoader(valid_data, batch_size=args.batch_size_val,
                             shuffle=False, num_workers=8)

    model = Model(args.image_size, args.image_size)
    # Small weight decay acts as L2 regularization.
    optimizer = RAdam(model.parameters(), lr=args.lr, weight_decay=1e-8)
    if not args.no_cuda:
        model.cuda()
        if torch.cuda.device_count() > 1 and args.multi_gpu:
            print("Let's use", torch.cuda.device_count(), "GPUs!")
            model = torch.nn.DataParallel(model)

    if args.resume:
        # NOTE(review): model_save_name / optimizer_save_name are assumed to be
        # module-level constants defined elsewhere in this file — confirm.
        model.load_state_dict(
            torch.load(os.path.join(args.resume, model_save_name)))
        optimizer.load_state_dict(
            torch.load(os.path.join(args.resume, optimizer_save_name)))

    train(model, optimizer, trainloader, validloader, device, args)
def main():
    """Entry point for (MSDNet-style) image-classification training.

    Flow: measure model FLOPs, rebuild the model, wrap it in DataParallel,
    pick an optimizer from ``args.optimizer``, optionally resume, then either
    run an evaluation-only mode or the full train/validate epoch loop,
    checkpointing the best top-1 model.

    Side effects: creates ``args.save``, writes ``flops.pth`` and checkpoint
    files there, and prints progress to stdout.
    """
    global args
    best_prec1, best_epoch = 0.0, 0

    if not os.path.exists(args.save):
        os.makedirs(args.save)

    # CIFAR images are 32x32; everything else is assumed ImageNet-sized.
    if args.data.startswith('cifar'):
        IM_SIZE = 32
    else:
        IM_SIZE = 224

    # Build a throwaway model just to measure FLOPs/params, then rebuild a
    # fresh one for training (measure_model may mutate the module).
    model = getattr(models, args.arch)(args)
    n_flops, n_params = measure_model(model, IM_SIZE, IM_SIZE)
    torch.save(n_flops, os.path.join(args.save, 'flops.pth'))
    del model  # idiom: `del` is a statement, not a function
    model = getattr(models, args.arch)(args)

    # AlexNet/VGG keep the classifier on a single GPU (standard practice:
    # only the convolutional features are data-parallel).
    if args.arch.startswith('alexnet') or args.arch.startswith('vgg'):
        model.features = torch.nn.DataParallel(model.features)
        model.cuda()
    else:
        model = torch.nn.DataParallel(model).cuda()

    criterion = nn.CrossEntropyLoss().cuda()

    # Optimizer selection; unknown values fail loudly.
    if args.optimizer == 'sgd':
        optimizer = torch.optim.SGD(model.parameters(), args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)
    elif args.optimizer == 'adam':
        optimizer = torch.optim.Adam(model.parameters(), args.lr,
                                     weight_decay=args.weight_decay)
    elif args.optimizer == 'radam':
        from radam import RAdam
        optimizer = RAdam(model.parameters(), args.lr,
                          weight_decay=args.weight_decay)
    else:
        raise NotImplementedError("Wrong optimizer.")

    if args.resume:
        checkpoint = load_checkpoint(args)
        if checkpoint is not None:
            args.start_epoch = checkpoint['epoch'] + 1
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])

    cudnn.benchmark = True

    train_loader, val_loader, test_loader = get_dataloaders(args)

    # Evaluation-only modes: load weights and either run anytime validation
    # or budgeted dynamic evaluation, then exit.
    if args.evalmode is not None:
        state_dict = torch.load(args.evaluate_from)['state_dict']
        model.load_state_dict(state_dict)
        if args.evalmode == 'anytime':
            validate(test_loader, model, criterion)
        else:
            dynamic_evaluate(model, test_loader, val_loader, args)
        return

    # Tab-separated score log; first entry is the header row.
    scores = ['epoch\tlr\ttrain_loss\tval_loss\ttrain_prec1'
              '\tval_prec1\ttrain_prec5\tval_prec5']

    for epoch in range(args.start_epoch, args.epochs):
        train_loss, train_prec1, train_prec5, lr = train(train_loader, model,
                                                         criterion, optimizer,
                                                         epoch)
        val_loss, val_prec1, val_prec5 = validate(val_loader, model, criterion)

        scores.append(('{}\t{:.3f}' + '\t{:.4f}' * 6)
                      .format(epoch, lr, train_loss, val_loss,
                              train_prec1, val_prec1,
                              train_prec5, val_prec5))

        is_best = val_prec1 > best_prec1
        if is_best:
            best_prec1 = val_prec1
            best_epoch = epoch
            # FIX: message previously read 'Best var_prec1', inconsistent with
            # the metric name and the final summary print below.
            print('Best val_prec1 {}'.format(best_prec1))

        model_filename = 'checkpoint_%03d.pth.tar' % epoch
        save_checkpoint({
            'epoch': epoch,
            'arch': args.arch,
            'state_dict': model.state_dict(),
            'best_prec1': best_prec1,
            'optimizer': optimizer.state_dict(),
        }, args, is_best, model_filename, scores)

    print('Best val_prec1: {:.4f} at epoch {}'.format(best_prec1, best_epoch))

    ### Test the final model
    print('********** Final prediction results **********')
    validate(test_loader, model, criterion)
    return
class Trainer():
    """Training harness for a MAC-style VQA model on CLEVR or GQA.

    Owns the datasets/dataloaders, the model plus an EMA (exponential moving
    average) shadow copy, the optimizer, TensorBoard writer, and an optional
    Comet.ml experiment. Training runs single-process on the first GPU listed
    in ``cfg.GPU_ID``.
    """

    def __init__(self, log_dir, cfg):
        # `cfg` is a project config object (attribute access); its exact
        # schema is defined elsewhere — fields used here are listed inline.
        self.path = log_dir
        self.cfg = cfg

        if cfg.TRAIN.FLAG:
            # Training mode: set up model/log dirs, TensorBoard, and tee
            # stdout into a logfile via the project's Logger.
            self.model_dir = os.path.join(self.path, 'Model')
            self.log_dir = os.path.join(self.path, 'Log')
            mkdir_p(self.model_dir)
            mkdir_p(self.log_dir)
            self.writer = SummaryWriter(log_dir=self.log_dir)
            self.logfile = os.path.join(self.path, "logfile.log")
            sys.stdout = Logger(logfile=self.logfile)

        self.data_dir = cfg.DATASET.DATA_DIR
        self.max_epochs = cfg.TRAIN.MAX_EPOCHS
        self.snapshot_interval = cfg.TRAIN.SNAPSHOT_INTERVAL

        # GPU_ID is a comma-separated string, e.g. "0,1".
        s_gpus = cfg.GPU_ID.split(',')
        self.gpus = [int(ix) for ix in s_gpus]
        self.num_gpus = len(self.gpus)

        self.batch_size = cfg.TRAIN.BATCH_SIZE
        self.lr = cfg.TRAIN.LEARNING_RATE

        # All work happens on the first listed GPU.
        torch.cuda.set_device(self.gpus[0])
        cudnn.benchmark = True
        sample = cfg.SAMPLE

        self.dataset = []
        self.dataloader = []
        self.use_feats = cfg.model.use_feats
        eval_split = cfg.EVAL if cfg.EVAL else 'val'
        train_split = cfg.DATASET.train_split

        if cfg.DATASET.DATASET == 'clevr':
            clevr_collate_fn = collate_fn
            # CoGenT condition suffix ('A'/'B'); empty string when unused.
            cogent = cfg.DATASET.COGENT
            if cogent:
                print(f'Using CoGenT {cogent.upper()}')

            if cfg.TRAIN.FLAG:
                self.dataset = ClevrDataset(data_dir=self.data_dir,
                                            split=train_split + cogent,
                                            sample=sample,
                                            **cfg.DATASET.params)
                self.dataloader = DataLoader(dataset=self.dataset,
                                             batch_size=cfg.TRAIN.BATCH_SIZE,
                                             shuffle=True,
                                             num_workers=cfg.WORKERS,
                                             drop_last=True,
                                             collate_fn=clevr_collate_fn)

            self.dataset_val = ClevrDataset(data_dir=self.data_dir,
                                            split=eval_split + cogent,
                                            sample=sample,
                                            **cfg.DATASET.params)
            self.dataloader_val = DataLoader(dataset=self.dataset_val,
                                             batch_size=cfg.TEST_BATCH_SIZE,
                                             drop_last=False,
                                             shuffle=False,
                                             num_workers=cfg.WORKERS,
                                             collate_fn=clevr_collate_fn)

        elif cfg.DATASET.DATASET == 'gqa':
            # GQA has two feature flavours with different collate functions.
            if self.use_feats == 'spatial':
                gqa_collate_fn = collate_fn_gqa
            elif self.use_feats == 'objects':
                gqa_collate_fn = collate_fn_gqa_objs

            if cfg.TRAIN.FLAG:
                self.dataset = GQADataset(data_dir=self.data_dir,
                                          split=train_split,
                                          sample=sample,
                                          use_feats=self.use_feats,
                                          **cfg.DATASET.params)
                self.dataloader = DataLoader(dataset=self.dataset,
                                             batch_size=cfg.TRAIN.BATCH_SIZE,
                                             shuffle=True,
                                             num_workers=cfg.WORKERS,
                                             drop_last=True,
                                             collate_fn=gqa_collate_fn)

            self.dataset_val = GQADataset(data_dir=self.data_dir,
                                          split=eval_split,
                                          sample=sample,
                                          use_feats=self.use_feats,
                                          **cfg.DATASET.params)
            self.dataloader_val = DataLoader(dataset=self.dataset_val,
                                             batch_size=cfg.TEST_BATCH_SIZE,
                                             shuffle=False,
                                             num_workers=cfg.WORKERS,
                                             drop_last=False,
                                             collate_fn=gqa_collate_fn)

        # load model
        self.vocab = load_vocab(cfg)
        self.model, self.model_ema = mac.load_MAC(cfg, self.vocab)
        # alpha=0 copies the live weights into the EMA model verbatim.
        self.weight_moving_average(alpha=0)
        if cfg.TRAIN.RADAM:
            self.optimizer = RAdam(self.model.parameters(), lr=self.lr)
        else:
            self.optimizer = optim.Adam(self.model.parameters(), lr=self.lr)

        self.start_epoch = 0
        if cfg.resume_model:
            # Resume live model + optimizer, then the EMA model from its own
            # checkpoint file.
            location = 'cuda' if cfg.CUDA else 'cpu'
            state = torch.load(cfg.resume_model, map_location=location)
            self.model.load_state_dict(state['model'])
            self.optimizer.load_state_dict(state['optim'])
            self.start_epoch = state['iter'] + 1
            state = torch.load(cfg.resume_model_ema, map_location=location)
            self.model_ema.load_state_dict(state['model'])

        if cfg.start_epoch is not None:
            # Explicit config override wins over the checkpoint's epoch.
            self.start_epoch = cfg.start_epoch

        # Best-so-far trackers for early stopping / snapshotting.
        self.previous_best_acc = 0.0
        self.previous_best_epoch = 0
        self.previous_best_loss = 100
        self.previous_best_loss_epoch = 0

        # State for reduce_lr()'s plateau heuristic.
        self.total_epoch_loss = 0
        self.prior_epoch_loss = 10

        self.print_info()
        self.loss_fn = torch.nn.CrossEntropyLoss().cuda()

        # Comet experiment is created even when logging is disabled
        # (disabled=True makes every call a no-op).
        self.comet_exp = Experiment(
            project_name=cfg.COMET_PROJECT_NAME,
            api_key=os.getenv('COMET_API_KEY'),
            workspace=os.getenv('COMET_WORKSPACE'),
            disabled=cfg.logcomet is False,
        )
        if cfg.logcomet:
            exp_name = cfg_to_exp_name(cfg)
            print(exp_name)
            self.comet_exp.set_name(exp_name)
            self.comet_exp.log_parameters(flatten_json_iterative_solution(cfg))
            self.comet_exp.log_asset(self.logfile)
            self.comet_exp.log_asset_data(json.dumps(cfg, indent=4),
                                          file_name='cfg.json')
            self.comet_exp.set_model_graph(str(self.model))
        if cfg.cfg_file:
            self.comet_exp.log_asset(cfg.cfg_file)

        # Persist the resolved config next to the logs for reproducibility.
        with open(os.path.join(self.path, 'cfg.json'), 'w') as f:
            json.dump(cfg, f, indent=4)

    def print_info(self):
        """Print config, dataset sizes, and the model architecture."""
        print('Using config:')
        pprint.pprint(self.cfg)
        print("\n")

        pprint.pprint("Size of train dataset: {}".format(len(self.dataset)))
        # print("\n")
        pprint.pprint("Size of val dataset: {}".format(len(self.dataset_val)))
        print("\n")

        print("Using MAC-Model:")
        pprint.pprint(self.model)
        print("\n")

    def weight_moving_average(self, alpha=0.999):
        """In-place EMA update: ema = alpha * ema + (1 - alpha) * live.

        alpha=0 makes the EMA model an exact copy of the live model.
        """
        for param1, param2 in zip(self.model_ema.parameters(),
                                  self.model.parameters()):
            param1.data *= alpha
            param1.data += (1.0 - alpha) * param2.data

    def set_mode(self, mode="train"):
        """Switch both live and EMA models between train and eval mode.

        Any value other than "train" selects eval mode.
        """
        if mode == "train":
            self.model.train()
            self.model_ema.train()
        else:
            self.model.eval()
            self.model_ema.eval()

    def reduce_lr(self):
        """Halve the LR when the epoch loss has plateaued.

        Three (loss-delta, loss-level, lr-floor) tiers trigger the halving;
        each tier has its own minimum LR so the rate never collapses to zero.
        Also rolls ``total_epoch_loss`` into ``prior_epoch_loss`` and resets it
        for the next epoch.
        """
        epoch_loss = self.total_epoch_loss  # / float(len(self.dataset) // self.batch_size)
        lossDiff = self.prior_epoch_loss - epoch_loss
        if ((lossDiff < 0.015 and self.prior_epoch_loss < 0.5 and self.lr > 0.00002) or \
                (lossDiff < 0.008 and self.prior_epoch_loss < 0.15 and self.lr > 0.00001) or \
                (lossDiff < 0.003 and self.prior_epoch_loss < 0.10 and self.lr > 0.000005)):
            self.lr *= 0.5
            print("Reduced learning rate to {}".format(self.lr))
            for param_group in self.optimizer.param_groups:
                param_group['lr'] = self.lr
        self.prior_epoch_loss = epoch_loss
        self.total_epoch_loss = 0

    def save_models(self, iteration):
        """Checkpoint live model (+optimizer) and EMA model (no optimizer)."""
        save_model(self.model, self.optimizer, iteration, self.model_dir,
                   model_name="model")
        save_model(self.model_ema, None, iteration, self.model_dir,
                   model_name="model_ema")

    def train_epoch(self, epoch):
        """Run one training epoch; returns a metrics dict.

        Returns keys: "loss", "accuracy" (and duplicated "avg_loss" /
        "avg_accuracy" for Comet). Also updates ``total_epoch_loss`` used by
        ``reduce_lr``.
        """
        cfg = self.cfg
        total_loss = 0.
        total_correct = 0
        total_samples = 0
        self.labeled_data = iter(self.dataloader)
        self.set_mode("train")

        # NOTE: `dataset` here is a tqdm progress wrapper, not a Dataset.
        dataset = tqdm(self.labeled_data, total=len(self.dataloader), ncols=20)

        for data in dataset:
            ######################################################
            # (1) Prepare training data
            ######################################################
            image, question, question_len, answer = data['image'], data[
                'question'], data['question_length'], data['answer']
            answer = answer.long()
            question = Variable(question)
            answer = Variable(answer)
            if cfg.CUDA:
                # 'objects' features arrive as a list of tensors, so each
                # element is moved to GPU individually.
                if self.use_feats == 'spatial':
                    image = image.cuda()
                elif self.use_feats == 'objects':
                    image = [e.cuda() for e in image]
                question = question.cuda()
                answer = answer.cuda().squeeze()
            else:
                question = question
                image = image
                answer = answer.squeeze()

            ############################
            # (2) Train Model
            ############################
            self.optimizer.zero_grad()
            scores = self.model(image, question, question_len)
            loss = self.loss_fn(scores, answer)
            loss.backward()
            if self.cfg.TRAIN.CLIP_GRADS:
                torch.nn.utils.clip_grad_norm_(self.model.parameters(),
                                               self.cfg.TRAIN.CLIP)
            self.optimizer.step()
            # Update the EMA shadow weights after every optimizer step.
            self.weight_moving_average()

            ############################
            # (3) Log Progress
            ############################
            correct = scores.detach().argmax(1) == answer
            total_correct += correct.sum().cpu().item()
            # Loss is re-weighted by batch size so the average is per-sample.
            total_loss += loss.item() * answer.size(0)
            total_samples += answer.size(0)

            avg_loss = total_loss / total_samples
            train_accuracy = total_correct / total_samples
            # accuracy = correct.sum().cpu().numpy() / answer.shape[0]
            # if avg_loss == 0:
            #     avg_loss = loss.item()
            #     train_accuracy = accuracy
            # else:
            #     avg_loss = 0.99 * avg_loss + 0.01 * loss.item()
            #     train_accuracy = 0.99 * train_accuracy + 0.01 * accuracy
            # self.total_epoch_loss += loss.item() * answer.size(0)

            dataset.set_description(
                'Epoch: {}; Avg Loss: {:.5f}; Avg Train Acc: {:.5f}'.format(
                    epoch + 1, avg_loss, train_accuracy))

        self.total_epoch_loss = avg_loss

        # NOTE(review): `dict` shadows the builtin; left as-is (doc-only pass).
        dict = {
            "loss": avg_loss,
            "accuracy": train_accuracy,
            "avg_loss": avg_loss,  # For commet
            "avg_accuracy": train_accuracy,  # For commet
        }
        return dict

    def train(self):
        """Full training loop: epochs, LR reduction, validation, early stop.

        Early stopping fires when no new best validation accuracy has been
        seen for ``cfg.TRAIN.PATIENCE`` epochs (note the config key is spelled
        ``EALRY_STOPPING`` — a typo baked into the config schema).
        """
        cfg = self.cfg
        print("Start Training")
        for epoch in range(self.start_epoch, self.max_epochs):

            with self.comet_exp.train():
                dict = self.train_epoch(epoch)
                self.reduce_lr()
                dict['epoch'] = epoch + 1
                dict['lr'] = self.lr
                self.comet_exp.log_metrics(
                    dict,
                    epoch=epoch + 1,
                )

            with self.comet_exp.validate():
                dict = self.log_results(epoch, dict)
                dict['epoch'] = epoch + 1
                dict['lr'] = self.lr
                self.comet_exp.log_metrics(
                    dict,
                    epoch=epoch + 1,
                )

            if cfg.TRAIN.EALRY_STOPPING:
                if epoch - cfg.TRAIN.PATIENCE == self.previous_best_epoch:
                    # if epoch - cfg.TRAIN.PATIENCE == self.previous_best_loss_epoch:
                    print('Early stop')
                    break

        self.comet_exp.log_asset(self.logfile)
        self.save_models(self.max_epochs)
        self.writer.close()
        print("Finished Training")
        print(
            f"Highest validation accuracy: {self.previous_best_acc} at epoch {self.previous_best_epoch}"
        )

    def log_results(self, epoch, dict, max_eval_samples=None):
        """Validate, log metrics to TensorBoard, track bests, snapshot.

        Returns the validation metrics dict from ``calc_accuracy``
        (keys: acc, acc_ema, loss, loss_ema).
        """
        epoch += 1  # switch to 1-based epoch numbering for logging
        self.writer.add_scalar("avg_loss", dict["loss"], epoch)
        self.writer.add_scalar("train_accuracy", dict["accuracy"], epoch)

        metrics = self.calc_accuracy("validation",
                                     max_samples=max_eval_samples)
        self.writer.add_scalar("val_accuracy_ema", metrics['acc_ema'], epoch)
        self.writer.add_scalar("val_accuracy", metrics['acc'], epoch)
        self.writer.add_scalar("val_loss_ema", metrics['loss_ema'], epoch)
        self.writer.add_scalar("val_loss", metrics['loss'], epoch)

        print(
            "Epoch: {epoch}\tVal Acc: {acc},\tVal Acc EMA: {acc_ema},\tAvg Loss: {loss},\tAvg Loss EMA: {loss_ema},\tLR: {lr}"
            .format(epoch=epoch, lr=self.lr, **metrics))

        if metrics['acc'] > self.previous_best_acc:
            self.previous_best_acc = metrics['acc']
            self.previous_best_epoch = epoch
        if metrics['loss'] < self.previous_best_loss:
            self.previous_best_loss = metrics['loss']
            self.previous_best_loss_epoch = epoch

        if epoch % self.snapshot_interval == 0:
            self.save_models(epoch)

        return metrics

    def calc_accuracy(self,
                      mode="train", max_samples=None):
        """Evaluate live and EMA models over a full loader pass.

        mode="train" evaluates on the training loader; anything else uses the
        validation loader. Returns dict(acc, acc_ema, loss, loss_ema) with
        per-sample averages.

        NOTE(review): ``max_samples`` is currently unused — the whole loader
        is always consumed.
        """
        self.set_mode("validation")  # any non-"train" string -> eval mode
        if mode == "train":
            loader = self.dataloader
        # elif (mode == "validation") or (mode == 'test'):
        #     loader = self.dataloader_val
        else:
            loader = self.dataloader_val

        total_correct = 0
        total_correct_ema = 0
        total_samples = 0
        total_loss = 0.
        total_loss_ema = 0.
        pbar = tqdm(loader, total=len(loader), desc=mode.upper(), ncols=20)

        for data in pbar:
            image, question, question_len, answer = data['image'], data[
                'question'], data['question_length'], data['answer']
            answer = answer.long()
            question = Variable(question)
            answer = Variable(answer)
            if self.cfg.CUDA:
                if self.use_feats == 'spatial':
                    image = image.cuda()
                elif self.use_feats == 'objects':
                    image = [e.cuda() for e in image]
                question = question.cuda()
                answer = answer.cuda().squeeze()

            with torch.no_grad():
                scores = self.model(image, question, question_len)
                scores_ema = self.model_ema(image, question, question_len)
                loss = self.loss_fn(scores, answer)
                loss_ema = self.loss_fn(scores_ema, answer)

            correct = scores.detach().argmax(1) == answer
            correct_ema = scores_ema.detach().argmax(1) == answer

            total_correct += correct.sum().cpu().item()
            total_correct_ema += correct_ema.sum().cpu().item()
            total_loss += loss.item() * answer.size(0)
            total_loss_ema += loss_ema.item() * answer.size(0)
            total_samples += answer.size(0)

            avg_acc = total_correct / total_samples
            avg_acc_ema = total_correct_ema / total_samples
            avg_loss = total_loss / total_samples
            avg_loss_ema = total_loss_ema / total_samples

            pbar.set_postfix({
                'Acc': f'{avg_acc:.5f}',
                'Acc Ema': f'{avg_acc_ema:.5f}',
                'Loss': f'{avg_loss:.5f}',
                'Loss Ema': f'{avg_loss_ema:.5f}',
            })

        return dict(acc=avg_acc, acc_ema=avg_acc_ema, loss=avg_loss,
                    loss_ema=avg_loss_ema)
def train(opt):
    """Train a scene-text recognition model (CTC / attention / Transformer).

    Builds the balanced training dataset and validation loader, configures the
    model, label converter, loss and optimizer from ``opt``, then runs the
    iteration loop with periodic validation, best-model saving, and
    checkpointing under ./saved_models/ and ./checkpoints/.

    Side effects: writes log files, checkpoints, and model weights to disk;
    toggles cudnn during CTC loss computation.
    """
    """ dataset preparation """
    if not opt.data_filtering_off:
        print('Filtering the images containing characters which are not in opt.character')
        print('Filtering the images whose label is longer than opt.batch_max_length')
        # see https://github.com/clovaai/deep-text-recognition-benchmark/blob/6593928855fb7abb999a99f428b3e4477d4ae356/dataset.py#L130

    # select_data / batch_ratio arrive as dash-separated strings.
    opt.select_data = opt.select_data.split('-')
    opt.batch_ratio = opt.batch_ratio.split('-')
    train_dataset = Batch_Balanced_Dataset(opt)

    log = open(f'./saved_models/{opt.experiment_name}/log_dataset.txt', 'a')
    AlignCollate_valid = AlignCollate(imgH=opt.imgH, imgW=opt.imgW,
                                      keep_ratio_with_pad=opt.PAD)
    valid_dataset, valid_dataset_log = hierarchical_dataset(root=opt.valid_data, opt=opt)
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset, batch_size=opt.batch_size,
        shuffle=True,  # 'True' to check training progress with validation function.
        num_workers=int(opt.workers),
        collate_fn=AlignCollate_valid, pin_memory=True)
    log.write(valid_dataset_log)
    print('-' * 80)
    log.write('-' * 80 + '\n')
    log.close()

    """ model configuration """
    # Pick the label converter matching the prediction head;
    # opt.Prediction == 'None' means the Transformer decoder path.
    if 'CTC' in opt.Prediction:
        converter = CTCLabelConverter(opt.character)
    elif opt.Prediction == 'None':
        converter = TransformerConverter(opt.character)
    else:
        converter = AttnLabelConverter(opt.character)
    opt.num_class = len(converter.character)

    if opt.rgb:
        opt.input_channel = 3
    model = Model(opt)
    print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial,
          opt.input_channel, opt.output_channel, opt.hidden_size,
          opt.num_class, opt.batch_max_length, opt.Transformation,
          opt.FeatureExtraction, opt.SequenceModeling, opt.Prediction)

    # weight initialization: zero biases, Kaiming-normal weights.
    for name, param in model.named_parameters():
        if 'localization_fc2' in name:
            print(f'Skip {name} as it is already initialized')
            continue
        try:
            if 'bias' in name:
                init.constant_(param, 0.0)
            elif 'weight' in name:
                init.kaiming_normal_(param)
        except Exception as e:  # for batchnorm.
            # kaiming_normal_ raises on 1-D weights (e.g. BatchNorm);
            # fall back to filling them with 1.
            if 'weight' in name:
                param.data.fill_(1)
            continue

    # data parallel for multi-GPU
    # model = torch.nn.DataParallel(model).to(device)
    model = model.to(device)
    model.train()

    # Weight loading priority: explicit checkpoint dir, then pretrained model.
    if opt.load_from_checkpoint:
        model.load_state_dict(torch.load(os.path.join(opt.load_from_checkpoint, 'checkpoint.pth')))
        print(f'loaded checkpoint from {opt.load_from_checkpoint}...')
    elif opt.saved_model != '':
        print(f'loading pretrained model from {opt.saved_model}')
        if opt.SequenceModeling == 'Transformer':
            # Transfer only the feature extractor: strip the DataParallel
            # 'module.FeatureExtraction.' prefix from matching keys.
            fe_state = OrderedDict()
            state_dict = torch.load(opt.saved_model)
            for k, v in state_dict.items():
                if k.startswith('module.FeatureExtraction'):
                    new_k = re.sub('module.FeatureExtraction.', '', k)
                    fe_state[new_k] = state_dict[k]
            model.FeatureExtraction.load_state_dict(fe_state)
        else:
            if opt.FT:
                # Fine-tuning: tolerate missing/extra keys.
                model.load_state_dict(torch.load(opt.saved_model), strict=False)
            else:
                model.load_state_dict(torch.load(opt.saved_model))
    if opt.freeze_fe:
        model.freeze(['FeatureExtraction'])
    print("Model:")
    print(model)

    """ setup loss """
    if 'CTC' in opt.Prediction:
        criterion = torch.nn.CTCLoss(zero_infinity=True).to(device)
    elif opt.Prediction == 'None':
        criterion = LabelSmoothingLoss(classes=converter.n_classes,
                                       padding_idx=converter.pad_idx,
                                       smoothing=0.1)
        # criterion = torch.nn.CrossEntropyLoss(ignore_index=converter.pad_idx)
    else:
        criterion = torch.nn.CrossEntropyLoss(ignore_index=0).to(device)  # ignore [GO] token = ignore index 0
    # loss averager
    loss_avg = Averager()

    # filter that only require gradient decent
    filtered_parameters = []
    params_num = []
    for p in filter(lambda p: p.requires_grad, model.parameters()):
        filtered_parameters.append(p)
        params_num.append(np.prod(p.size()))
    print('Trainable params num : ', sum(params_num))
    # [print(name, p.numel()) for name, p in filter(lambda p: p[1].requires_grad, model.named_parameters())]

    # setup optimizer
    if opt.adam:
        assert opt.adam in ['Adam', 'AdamW', 'RAdam'], 'adam optimizer must be in Adam, AdamW or RAdam'
        if opt.adam == 'Adam':
            optimizer = optim.Adam(filtered_parameters, lr=opt.lr, betas=(opt.beta1, 0.999))
        elif opt.adam == "AdamW":
            optimizer = optim.AdamW(filtered_parameters, lr=opt.lr, betas=(opt.beta1, 0.999))
        else:
            optimizer = RAdam(filtered_parameters, lr=opt.lr, betas=(opt.beta1, 0.999))
    else:
        optimizer = optim.Adadelta(filtered_parameters, lr=opt.lr, rho=opt.rho, eps=opt.eps)
    print("Optimizer:")
    print(optimizer)

    if opt.load_from_checkpoint and opt.load_optimizer_state:
        optimizer.load_state_dict(torch.load(os.path.join(opt.load_from_checkpoint, 'optimizer.pth')))
        print(f'loaded optimizer state from {os.path.join(opt.load_from_checkpoint, "optimizer.pth")}')

    """ final options """
    # print(opt)
    with open(f'./saved_models/{opt.experiment_name}/opt.txt', 'a') as opt_file:
        opt_log = '------------ Options -------------\n'
        args = vars(opt)
        for k, v in args.items():
            opt_log += f'{str(k)}: {str(v)}\n'
        opt_log += '---------------------------------------\n'
        print(opt_log)
        opt_file.write(opt_log)

    """ start training """
    start_iter = 0
    if opt.saved_model != '':
        try:
            # Recover iteration number from filenames like 'iter_300000.pth'.
            start_iter = int(opt.saved_model.split('_')[-1].split('.')[0])
            print(f'continue to train, start_iter: {start_iter}')
        except:
            # NOTE(review): bare except deliberately tolerates any filename
            # that does not encode an iteration number.
            pass
    if opt.load_from_checkpoint:
        # Checkpoint dir wins: its iter.json stores the exact resume point.
        with open(os.path.join(opt.load_from_checkpoint, 'iter.json'),
                  mode='r', encoding='utf8') as f:
            start_iter = json.load(f)
            print(f'continue to train, start_iter: {start_iter}')
            f.close()  # redundant inside `with`; harmless

    start_time = time.time()
    best_accuracy = -1
    best_norm_ED = -1
    # i = start_iter
    bar = tqdm(range(start_iter, opt.num_iter))
    # while(True):
    for i in bar:
        bar.set_description(f'Iter {i}: train_loss = {loss_avg.val():.5f}')
        # train part
        image_tensors, labels = train_dataset.get_batch()
        image = image_tensors.to(device)
        text, length = converter.encode(labels, batch_max_length=opt.batch_max_length)
        batch_size = image.size(0)

        if 'CTC' in opt.Prediction:
            preds = model(image, text).log_softmax(2)
            preds_size = torch.IntTensor([preds.size(1)] * batch_size)
            # CTCLoss expects (T, N, C).
            preds = preds.permute(1, 0, 2)

            # (ctc_a) For PyTorch 1.2.0 and 1.3.0. To avoid ctc_loss issue, disabled cudnn for the computation of the ctc_loss
            # https://github.com/jpuigcerver/PyLaia/issues/16
            torch.backends.cudnn.enabled = False
            cost = criterion(preds, text.to(device), preds_size.to(device), length.to(device))
            torch.backends.cudnn.enabled = True

            # # (ctc_b) To reproduce our pretrained model / paper, use our previous code (below code) instead of (ctc_a).
            # # With PyTorch 1.2.0, the below code occurs NAN, so you may use PyTorch 1.1.0.
            # # Thus, the result of CTCLoss is different in PyTorch 1.1.0 and PyTorch 1.2.0.
            # # See https://github.com/clovaai/deep-text-recognition-benchmark/issues/56#issuecomment-526490707
            # cost = criterion(preds, text, preds_size, length)
        elif opt.Prediction == 'None':
            # Transformer path: converter.encode returned a dict of tensors.
            tgt_input = text['tgt_input']
            tgt_output = text['tgt_output']
            tgt_padding_mask = text['tgt_padding_mask']
            preds = model(image, tgt_input.transpose(0, 1),
                          tgt_key_padding_mask=tgt_padding_mask,)
            cost = criterion(preds.view(-1, preds.shape[-1]),
                             tgt_output.contiguous().view(-1))
        else:
            preds = model(image, text[:, :-1])  # align with Attention.forward
            target = text[:, 1:]  # without [GO] Symbol
            cost = criterion(preds.view(-1, preds.shape[-1]), target.contiguous().view(-1))

        model.zero_grad()
        cost.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), opt.grad_clip)  # gradient clipping with 5 (Default)
        optimizer.step()

        loss_avg.add(cost)

        # validation part
        if (i + 1) % opt.valInterval == 0:
            elapsed_time = time.time() - start_time
            # for log
            with open(f'./saved_models/{opt.experiment_name}/log_train.txt', 'a') as log:
                model.eval()
                with torch.no_grad():
                    valid_loss, current_accuracy, current_norm_ED, preds, confidence_score, labels, infer_time, length_of_data = validation(
                        model, criterion, valid_loader, converter, opt)
                model.train()

                # training loss and validation loss
                loss_log = f'[{i}/{opt.num_iter}] Train loss: {loss_avg.val():0.5f}, Valid loss: {valid_loss:0.5f}, Elapsed_time: {elapsed_time:0.5f}'
                loss_avg.reset()

                current_model_log = f'{"Current_accuracy":17s}: {current_accuracy:0.3f}, {"Current_norm_ED":17s}: {current_norm_ED:0.2f}'

                # keep best accuracy model (on valid dataset)
                if current_accuracy > best_accuracy:
                    best_accuracy = current_accuracy
                    torch.save(model.state_dict(), f'./saved_models/{opt.experiment_name}/best_accuracy.pth')
                if current_norm_ED > best_norm_ED:
                    best_norm_ED = current_norm_ED
                    torch.save(model.state_dict(), f'./saved_models/{opt.experiment_name}/best_norm_ED.pth')

                # checkpoint: latest weights + optimizer + iteration counter,
                # so training can resume exactly where it stopped.
                os.makedirs(f'./checkpoints/{opt.experiment_name}/', exist_ok=True)
                torch.save(model.state_dict(), f'./checkpoints/{opt.experiment_name}/checkpoint.pth')
                torch.save(optimizer.state_dict(), f'./checkpoints/{opt.experiment_name}/optimizer.pth')
                with open(f'./checkpoints/{opt.experiment_name}/iter.json',
                          mode='w', encoding='utf8') as f:
                    json.dump(i + 1, f)
                    f.close()  # redundant inside `with`; harmless
                with open(f'./checkpoints/{opt.experiment_name}/checkpoint.log',
                          mode='a', encoding='utf8') as f:
                    f.write(f'Saved checkpoint with iter={i}\n')
                    f.write(f'\tCheckpoint at: ./checkpoints/{opt.experiment_name}/checkpoint.pth')
                    f.write(f'\tOptimizer at: ./checkpoints/{opt.experiment_name}/optimizer.pth')

                best_model_log = f'{"Best_accuracy":17s}: {best_accuracy:0.3f}, {"Best_norm_ED":17s}: {best_norm_ED:0.2f}'

                loss_model_log = f'{loss_log}\n{current_model_log}\n{best_model_log}'
                print(loss_model_log)
                log.write(loss_model_log + '\n')

                # show some predicted results
                dashed_line = '-' * 80
                head = f'{"Ground Truth":25s} | {"Prediction":25s} | Confidence Score & T/F'
                predicted_result_log = f'{dashed_line}\n{head}\n{dashed_line}\n'
                for gt, pred, confidence in zip(labels[:5], preds[:5], confidence_score[:5]):
                    if 'Attn' in opt.Prediction:
                        # Truncate at the end-of-sequence marker for display.
                        gt = gt[:gt.find('[s]')]
                        pred = pred[:pred.find('[s]')]

                    predicted_result_log += f'{gt:25s} | {pred:25s} | {confidence:0.4f}\t{str(pred == gt)}\n'
                predicted_result_log += f'{dashed_line}'
                print(predicted_result_log)
                log.write(predicted_result_log + '\n')

        # save model per 1e+5 iter.
        if (i + 1) % 1e+5 == 0:
            torch.save(
                model.state_dict(),
                f'./saved_models/{opt.experiment_name}/iter_{i+1}.pth')

        # if i == opt.num_iter:
        #     print('end the training')
        #     sys.exit()
        # i += 1
        # if i == 1: break
    print('end training')
class face_learner(object):
    """Training/evaluation harness for a race-balanced face-recognition model.

    Two configurations, selected by ``conf.meta``:
      * meta mode  — one shared Arcface head over all identities, trained
        MAML-style with a base optimizer and a meta optimizer;
      * per-race mode — one Arcface head and one optimizer per race,
        fine-tuned independently.

    NOTE(review): depends on module-level names defined elsewhere in the
    project (ResNet, Arcface, RAdam, races, get_train_loader, get_val_data,
    separate_bn_paras, hflip_batch, l2_norm, evaluate, gen_plot, get_time,
    trans, SummaryWriter, Image, tqdm, deepcopy).
    """

    def __init__(self, conf):
        print(conf)
        self.model = ResNet()
        self.model.cuda()
        if conf.initial:
            # Warm-start the backbone from a pre-trained checkpoint.
            self.model.load_state_dict(torch.load("models/" + conf.model))
            print('Load model_ir_se101.pth')
        self.milestones = conf.milestones  # epochs at which lr is divided by 10
        self.loader, self.class_num = get_train_loader(conf)
        self.total_class = 16520   # total identities across all races
        self.data_num = 285356     # total number of training images
        self.writer = SummaryWriter(conf.log_path)
        self.step = 0
        # Split batch-norm parameters out so weight decay is applied only to
        # the remaining (conv/fc) parameters.
        self.paras_only_bn, self.paras_wo_bn = separate_bn_paras(self.model)
        if conf.meta:
            # Single shared head; two optimizers for the base and meta steps.
            self.head = Arcface(embedding_size=conf.embedding_size,
                                classnum=self.total_class)
            self.head.cuda()
            if conf.initial:
                self.head.load_state_dict(torch.load("models/head_op.pth"))
                print('Load head_op.pth')
            self.optimizer = self._new_optimizer(conf, self.head.kernel)
            self.meta_optimizer = self._new_optimizer(conf, self.head.kernel)
            self.head.train()
        else:
            self.head = dict()
            self.optimizer = dict()
            for race in races:
                self.head[race] = Arcface(embedding_size=conf.embedding_size,
                                          classnum=self.class_num[race])
                self.head[race].cuda()
                if conf.initial:
                    self.head[race].load_state_dict(
                        torch.load("models/head_op_{}.pth".format(race)))
                    print('Load head_op_{}.pth'.format(race))
                self.optimizer[race] = self._new_optimizer(
                    conf, self.head[race].kernel, betas=(0.5, 0.999))
                self.head[race].train()
        # Bookkeeping intervals, in iterations.
        self.board_loss_every = min(len(self.loader[race]) for race in races) // 10
        self.evaluate_every = self.data_num // 5
        self.save_every = self.data_num // 2
        self.eval, self.eval_issame = get_val_data(conf)

    def _new_optimizer(self, conf, kernel, betas=None):
        """Build a RAdam optimizer: weight decay on conv/fc weights and the
        given Arcface kernel, none on batch-norm parameters.

        Built fresh per call — optimizers mutate their param-group dicts, so
        the groups must not be shared between optimizer instances.
        """
        groups = [
            {'params': self.paras_wo_bn + [kernel], 'weight_decay': 5e-4},
            {'params': self.paras_only_bn},
        ]
        if betas is None:
            return RAdam(groups, lr=conf.lr)
        return RAdam(groups, lr=conf.lr, betas=betas)

    def save_state(self, conf, accuracy, extra=None, model_only=False, race='All'):
        """Save backbone weights (and, unless model_only, the head) under models/."""
        save_path = 'models/'
        torch.save(
            self.model.state_dict(),
            save_path + 'model_{}_accuracy-{}_step-{}_{}_{}.pth'.format(
                get_time(), accuracy, self.step, extra, race))
        if not model_only:
            # Meta mode has a single shared head; per-race mode one per race.
            head = self.head if conf.meta else self.head[race]
            torch.save(
                head.state_dict(),
                save_path + 'head_{}_accuracy-{}_step-{}_{}_{}.pth'.format(
                    get_time(), accuracy, self.step, extra, race))

    def load_state(self, conf, fixed_str, model_only=False):
        """Restore backbone (and optionally head + optimizer) from models/.

        NOTE(review): only valid in meta mode — in per-race mode self.head and
        self.optimizer are dicts and have no load_state_dict.
        """
        save_path = 'models/'
        self.model.load_state_dict(torch.load(save_path + conf.model))
        if not model_only:
            self.head.load_state_dict(torch.load(save_path + conf.head))
            self.optimizer.load_state_dict(torch.load(save_path + conf.optim))

    def board_val(self, db_name, accuracy, best_threshold, roc_curve_tensor):
        """Log one validation set's metrics and ROC curve to TensorBoard."""
        self.writer.add_scalar('{}_accuracy'.format(db_name), accuracy, self.step)
        self.writer.add_scalar('{}_best_threshold'.format(db_name), best_threshold, self.step)
        self.writer.add_image('{}_roc_curve'.format(db_name), roc_curve_tensor, self.step)

    def evaluate(self, conf, carray, issame, nrof_folds=5, tta=False):
        """Run verification evaluation on a tensor of face crops.

        Args:
            carray: image tensor, first dimension is the number of entries.
            issame: pair labels for the verification protocol.
            nrof_folds: cross-validation folds for threshold selection.
            tta: if True, sum embeddings of each image and its horizontal flip
                and L2-normalize the sum.

        Returns:
            (mean accuracy, mean best threshold, ROC-curve image tensor)
        """
        self.model.eval()
        idx = 0
        entry_num = carray.size()[0]
        embeddings = np.zeros([entry_num, conf.embedding_size])
        with torch.no_grad():
            while idx + conf.batch_size <= entry_num:
                batch = carray[idx:idx + conf.batch_size]
                embeddings[idx:idx + conf.batch_size] = self._embed(batch, tta)
                idx += conf.batch_size
            if idx < entry_num:
                # Trailing partial batch.
                embeddings[idx:] = self._embed(carray[idx:], tta)
        # `evaluate` here is the module-level verification helper, not this method.
        tpr, fpr, accuracy, best_thresholds = evaluate(embeddings, issame, nrof_folds)
        buf = gen_plot(fpr, tpr)
        roc_curve = Image.open(buf)
        roc_curve_tensor = trans.ToTensor()(roc_curve)
        return accuracy.mean(), best_thresholds.mean(), roc_curve_tensor

    def _embed(self, batch, tta):
        """Embed one batch on the GPU; with tta, average in the flipped view."""
        if tta:
            fliped = hflip_batch(batch)
            emb_batch = self.model(batch.cuda()) + self.model(fliped.cuda())
            return l2_norm(emb_batch).cpu().detach().numpy()
        return self.model(batch.cuda()).cpu().detach().numpy()

    def train_finetuning(self, conf, epochs, race):
        """Fine-tune backbone + one race's head on that race's loader."""
        self.model.train()
        running_loss = 0.
        for e in range(epochs):
            print('epoch {} started'.format(e))
            for imgs, labels in tqdm(iter(self.loader[race])):
                imgs = imgs.cuda()
                labels = labels.cuda()
                self.optimizer[race].zero_grad()
                embeddings = self.model(imgs)
                thetas = self.head[race](embeddings, labels)
                loss = conf.ce_loss(thetas, labels)
                loss.backward()
                running_loss += loss.item()
                nn.utils.clip_grad_norm_(self.model.parameters(), conf.max_grad_norm)
                nn.utils.clip_grad_norm_(self.head[race].parameters(), conf.max_grad_norm)
                self.optimizer[race].step()
                if self.step % self.board_loss_every == 0 and self.step != 0:
                    self.writer.add_scalar(
                        'train_loss', running_loss / self.board_loss_every, self.step)
                    running_loss = 0.
                # Checkpoint the backbone once per (race-)epoch worth of steps.
                if self.step % (1 * len(self.loader[race])) == 0 and self.step != 0:
                    self.save_state(conf, 'None', race=race, model_only=True)
                self.step += 1
        self.save_state(conf, 'None', extra='final', race=race)
        torch.save(self.optimizer[race].state_dict(),
                   'models/optimizer_{}.pth'.format(race))

    def _next_batch(self, loader_iter, race):
        """Draw the next batch for `race`, restarting its iterator when exhausted.

        BUGFIX: the original called the Python-2-only `.next()` method, which
        raises AttributeError on Python 3; use the builtin next() instead.
        """
        try:
            return next(loader_iter[race])
        except StopIteration:
            loader_iter[race] = iter(self.loader[race])
            return next(loader_iter[race])

    def train_maml(self, conf, epochs):
        """MAML-style training: base update on one race, meta update on another."""
        self.model.train()
        running_loss = 0.
        loader_iter = {race: iter(self.loader[race]) for race in races}
        for e in range(epochs):
            print('epoch {} started'.format(e))
            if e in (self.milestones[0], self.milestones[1], self.milestones[2]):
                self.schedule_lr()
            for i in tqdm(range(self.data_num // conf.batch_size)):
                ra1, ra2 = random.sample(races, 2)
                imgs1, labels1 = self._next_batch(loader_iter, ra1)
                imgs2, labels2 = self._next_batch(loader_iter, ra2)
                # Snapshot weights so the meta step can be applied to them.
                weights_original_model = deepcopy(self.model.state_dict())
                weights_original_head = deepcopy(self.head.state_dict())
                # Base step on race 1.
                imgs1 = imgs1.cuda()
                labels1 = labels1.cuda()
                self.optimizer.zero_grad()
                embeddings1 = self.model(imgs1)
                thetas1 = self.head(embeddings1, labels1)
                loss1 = conf.ce_loss(thetas1, labels1)
                loss1.backward()
                nn.utils.clip_grad_norm_(self.model.parameters(), conf.max_grad_norm)
                nn.utils.clip_grad_norm_(self.head.parameters(), conf.max_grad_norm)
                self.optimizer.step()
                # Meta step: forward race 2 through the *updated* weights, then
                # restore the snapshot and apply the resulting gradients to it
                # (first-order MAML style).
                imgs2 = imgs2.cuda()
                labels2 = labels2.cuda()
                embeddings2 = self.model(imgs2)
                thetas2 = self.head(embeddings2, labels2)
                self.model.load_state_dict(weights_original_model)
                self.head.load_state_dict(weights_original_head)
                self.meta_optimizer.zero_grad()
                loss2 = conf.ce_loss(thetas2, labels2)
                loss2.backward()
                nn.utils.clip_grad_norm_(self.model.parameters(), conf.max_grad_norm)
                nn.utils.clip_grad_norm_(self.head.parameters(), conf.max_grad_norm)
                self.meta_optimizer.step()
                running_loss += loss2.item()
                if self.step % self.board_loss_every == 0 and self.step != 0:
                    self.writer.add_scalar(
                        'train_loss', running_loss / self.board_loss_every, self.step)
                    running_loss = 0.
                if self.step % self.evaluate_every == 0 and self.step != 0:
                    for race in races:
                        accuracy, best_threshold, roc_curve_tensor = self.evaluate(
                            conf, self.eval[race], self.eval_issame[race])
                        self.board_val(race, accuracy, best_threshold, roc_curve_tensor)
                    self.model.train()
                if self.step % (self.data_num // conf.batch_size // 2) == 0 and self.step != 0:
                    self.save_state(conf, e)
                self.step += 1
        self.save_state(conf, epochs, extra='final')

    def train_meta_head(self, conf, epochs):
        """Train only the shared (meta) head with SGD over all races' loaders.

        The backbone is run in train mode but only head parameters receive
        optimizer updates.
        """
        self.model.train()
        running_loss = 0.
        optimizer = optim.SGD(self.head.parameters(), lr=conf.lr, momentum=conf.momentum)
        for e in range(epochs):
            print('epoch {} started'.format(e))
            if e in (self.milestones[0], self.milestones[1], self.milestones[2]):
                # NOTE(review): schedule_lr decays self.optimizer(s), not the
                # local SGD optimizer above — confirm this is intended.
                self.schedule_lr()
            for race in races:
                for imgs, labels in tqdm(iter(self.loader[race])):
                    imgs = imgs.cuda()
                    labels = labels.cuda()
                    optimizer.zero_grad()
                    embeddings = self.model(imgs)
                    thetas = self.head(embeddings, labels)
                    loss = conf.ce_loss(thetas, labels)
                    loss.backward()
                    running_loss += loss.item()
                    optimizer.step()
                    if self.step % self.board_loss_every == 0 and self.step != 0:
                        self.writer.add_scalar(
                            'train_loss', running_loss / self.board_loss_every, self.step)
                        running_loss = 0.
                    self.step += 1
            # Saved per epoch (filename carries the epoch index).
            torch.save(self.head.state_dict(),
                       'models/head_{}_meta_{}.pth'.format(get_time(), e))

    def train_race_head(self, conf, epochs, race):
        """Train only one race's head with SGD on that race's loader."""
        self.model.train()
        running_loss = 0.
        optimizer = optim.SGD(self.head[race].parameters(), lr=conf.lr,
                              momentum=conf.momentum)
        for e in range(epochs):
            print('epoch {} started'.format(e))
            if e in (self.milestones[0], self.milestones[1], self.milestones[2]):
                # NOTE(review): decays self.optimizer(s), not the local SGD
                # optimizer — confirm this is intended.
                self.schedule_lr()
            for imgs, labels in tqdm(iter(self.loader[race])):
                imgs = imgs.cuda()
                labels = labels.cuda()
                optimizer.zero_grad()
                embeddings = self.model(imgs)
                thetas = self.head[race](embeddings, labels)
                loss = conf.ce_loss(thetas, labels)
                loss.backward()
                running_loss += loss.item()
                optimizer.step()
                if self.step % self.board_loss_every == 0 and self.step != 0:
                    self.writer.add_scalar(
                        'train_loss', running_loss / self.board_loss_every, self.step)
                    running_loss = 0.
                self.step += 1
        torch.save(self.head[race].state_dict(),
                   'models/head_{}_{}_{}.pth'.format(get_time(), race, epochs))

    def schedule_lr(self):
        """Divide the learning rate of the milestone-scheduled optimizers by 10.

        BUGFIX: the original unconditionally read self.optimizer.param_groups
        and self.meta_optimizer, which crashes in per-race mode where
        self.optimizer is a dict and meta_optimizer does not exist.
        """
        if isinstance(self.optimizer, dict):
            optimizers = list(self.optimizer.values())
        else:
            optimizers = [self.optimizer]
        if hasattr(self, 'meta_optimizer'):
            optimizers.append(self.meta_optimizer)
        for opt in optimizers:
            for params in opt.param_groups:
                params['lr'] /= 10
        print(optimizers)
def main(): parser = argparse.ArgumentParser() # path setting parser.add_argument("--waveforms", type=str, help="directory or list of wav files") parser.add_argument("--waveforms_eval", type=str, help="directory or list of evaluation wav files") parser.add_argument("--feats", required=True, type=str, help="directory or list of wav files") parser.add_argument("--feats_eval", required=True, type=str, help="directory or list of evaluation feat files") parser.add_argument("--stats", required=True, type=str, help="directory or list of evaluation wav files") parser.add_argument("--expdir", required=True, type=str, help="directory to save the model") # network structure setting parser.add_argument("--upsampling_factor", default=120, type=int, help="number of dimension of aux feats") parser.add_argument("--hid_chn", default=256, type=int, help="depth of dilation") parser.add_argument("--skip_chn", default=256, type=int, help="depth of dilation") parser.add_argument("--dilation_depth", default=3, type=int, help="depth of dilation") parser.add_argument("--dilation_repeat", default=2, type=int, help="depth of dilation") parser.add_argument("--kernel_size", default=7, type=int, help="kernel size of dilated causal convolution") parser.add_argument("--kernel_size_wave", default=7, type=int, help="kernel size of dilated causal convolution") parser.add_argument("--dilation_size_wave", default=1, type=int, help="kernel size of dilated causal convolution") parser.add_argument("--mcep_dim", default=50, type=int, help="kernel size of dilated causal convolution") # network training setting parser.add_argument("--lr", default=1e-4, type=float, help="learning rate") parser.add_argument( "--batch_size", default=30, type=int, help="batch size (if set 0, utterance batch will be used)") parser.add_argument("--epoch_count", default=4000, type=int, help="number of training epochs") parser.add_argument("--do_prob", default=0, type=float, help="dropout probability") parser.add_argument( 
"--batch_size_utt", default=5, type=int, help="batch size (if set 0, utterance batch will be used)") parser.add_argument( "--batch_size_utt_eval", default=5, type=int, help="batch size (if set 0, utterance batch will be used)") parser.add_argument( "--n_workers", default=2, type=int, help="batch size (if set 0, utterance batch will be used)") parser.add_argument( "--n_quantize", default=256, type=int, help="batch size (if set 0, utterance batch will be used)") parser.add_argument( "--bi_wave", default=True, type=strtobool, help="batch size (if set 0, utterance batch will be used)") parser.add_argument( "--causal_conv_wave", default=False, type=strtobool, help="batch size (if set 0, utterance batch will be used)") # other setting parser.add_argument("--init", default=False, type=strtobool, help="seed number") parser.add_argument("--pad_len", default=3000, type=int, help="seed number") ##parser.add_argument("--save_interval_iter", default=5000, #parser.add_argument("--save_interval_iter", default=3000, # type=int, help="interval steps to logr") parser.add_argument("--save_interval_epoch", default=10, type=int, help="interval steps to logr") parser.add_argument("--log_interval_steps", default=50, type=int, help="interval steps to logr") parser.add_argument("--seed", default=1, type=int, help="seed number") parser.add_argument("--resume", default=None, type=str, help="model path to restart training") parser.add_argument("--pretrained", default=None, type=str, help="model path to restart training") parser.add_argument("--preconf", default=None, type=str, help="model path to restart training") parser.add_argument("--string_path", default=None, type=str, help="model path to restart training") parser.add_argument("--GPU_device", default=None, type=int, help="selection of GPU device") parser.add_argument("--verbose", default=1, type=int, help="log level") args = parser.parse_args() if args.GPU_device is not None: os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" 
os.environ["CUDA_VISIBLE_DEVICES"] = str(args.GPU_device) # make experimental directory if not os.path.exists(args.expdir): os.makedirs(args.expdir) # set log level if args.verbose == 1: logging.basicConfig( level=logging.INFO, format= '%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s', datefmt='%m/%d/%Y %I:%M:%S', filename=args.expdir + "/train.log") logging.getLogger().addHandler(logging.StreamHandler()) elif args.verbose > 1: logging.basicConfig( level=logging.DEBUG, format= '%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s', datefmt='%m/%d/%Y %I:%M:%S', filename=args.expdir + "/train.log") logging.getLogger().addHandler(logging.StreamHandler()) else: logging.basicConfig( level=logging.WARN, format= '%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s', datefmt='%m/%d/%Y %I:%M:%S', filename=args.expdir + "/train.log") logging.getLogger().addHandler(logging.StreamHandler()) logging.warn("logging is disabled.") # fix seed os.environ['PYTHONHASHSEED'] = str(args.seed) np.random.seed(args.seed) torch.manual_seed(args.seed) device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if str(device) == "cpu": raise ValueError('ERROR: Training by CPU is not acceptable.') torch.backends.cudnn.benchmark = True #faster if args.pretrained is None: if 'mel' in args.string_path: mean_stats = torch.FloatTensor(read_hdf5(args.stats, "/mean_melsp")) scale_stats = torch.FloatTensor( read_hdf5(args.stats, "/scale_melsp")) args.excit_dim = 0 else: mean_stats = torch.FloatTensor( read_hdf5(args.stats, "/mean_feat_mceplf0cap")) scale_stats = torch.FloatTensor( read_hdf5(args.stats, "/scale_feat_mceplf0cap")) args.cap_dim = mean_stats.shape[0] - (args.mcep_dim + 3) args.excit_dim = 2 + 1 + args.cap_dim else: config = torch.load(args.preconf) args.excit_dim = config.excit_dim args.cap_dim = config.cap_dim # save args as conf torch.save(args, args.expdir + "/model.conf") # define network model_waveform = DSWNV(n_aux=args.mcep_dim + 
args.excit_dim, upsampling_factor=args.upsampling_factor, hid_chn=args.hid_chn, skip_chn=args.skip_chn, kernel_size=args.kernel_size, aux_kernel_size=args.kernel_size_wave, aux_dilation_size=args.dilation_size_wave, dilation_depth=args.dilation_depth, dilation_repeat=args.dilation_repeat, n_quantize=args.n_quantize, do_prob=args.do_prob) logging.info(model_waveform) shift_rec_field = model_waveform.receptive_field logging.info(shift_rec_field) if shift_rec_field % args.upsampling_factor > 0: shift_rec_field_frm = shift_rec_field // args.upsampling_factor + 1 else: shift_rec_field_frm = shift_rec_field // args.upsampling_factor shift_rec_field = shift_rec_field_frm * args.upsampling_factor logging.info(shift_rec_field) logging.info(shift_rec_field_frm) criterion_ce = torch.nn.CrossEntropyLoss(reduction='none') criterion_l1 = torch.nn.L1Loss(reduction='none') # send to gpu if torch.cuda.is_available(): model_waveform.cuda() criterion_ce.cuda() criterion_l1.cuda() if args.pretrained is None: mean_stats = mean_stats.cuda() scale_stats = scale_stats.cuda() else: logging.error("gpu is not available. please check the setting.") sys.exit(1) model_waveform.train() if args.pretrained is None: model_waveform.scale_in.weight = torch.nn.Parameter( torch.unsqueeze(torch.diag(1.0 / scale_stats.data), 2)) model_waveform.scale_in.bias = torch.nn.Parameter(-(mean_stats.data / scale_stats.data)) #if args.pretrained is not None: # checkpoint = torch.load(args.pretrained) # #model_waveform.remove_weight_norm() # #model_waveform.load_state_dict(checkpoint["model"]) # model_waveform.load_state_dict(checkpoint["model_waveform"]) # epoch_idx = checkpoint["iterations"] # logging.info("pretrained from %d-iter checkpoint." 
% epoch_idx) # epoch_idx = 0 # #model_waveform.apply_weight_norm() # #torch.nn.utils.remove_weight_norm(model_waveform.scale_in) for param in model_waveform.parameters(): param.requires_grad = True for param in model_waveform.scale_in.parameters(): param.requires_grad = False parameters = filter(lambda p: p.requires_grad, model_waveform.parameters()) parameters = sum([np.prod(p.size()) for p in parameters]) / 1000000 logging.info('Trainable Parameters (waveform): %.3f million' % parameters) module_list = list(model_waveform.conv_aux.parameters()) + list( model_waveform.upsampling.parameters()) if model_waveform.wav_conv_flag: module_list += list(model_waveform.wav_conv.parameters()) module_list += list(model_waveform.causal.parameters()) module_list += list(model_waveform.in_x.parameters()) + list( model_waveform.dil_h.parameters()) module_list += list(model_waveform.out_skip.parameters()) module_list += list(model_waveform.out_1.parameters()) + list( model_waveform.out_2.parameters()) optimizer = RAdam(module_list, lr=args.lr) #optimizer = torch.optim.Adam(module_list, lr=args.lr) # resume if args.pretrained is not None and args.resume is None: checkpoint = torch.load(args.pretrained) model_waveform.load_state_dict(checkpoint["model_waveform"]) # optimizer.load_state_dict(checkpoint["optimizer"]) epoch_idx = checkpoint["iterations"] logging.info("pretrained from %d-iter checkpoint." % epoch_idx) epoch_idx = 0 elif args.resume is not None: #if args.resume is not None: checkpoint = torch.load(args.resume) model_waveform.load_state_dict(checkpoint["model_waveform"]) optimizer.load_state_dict(checkpoint["optimizer"]) epoch_idx = checkpoint["iterations"] logging.info("restored from %d-iter checkpoint." 
% epoch_idx) # epoch_idx = 2 else: epoch_idx = 0 def zero_wav_pad(x): return padding(x, args.pad_len * args.upsampling_factor, value=0.0) # noqa: E704 def zero_feat_pad(x): return padding(x, args.pad_len, value=0.0) # noqa: E704 pad_wav_transform = transforms.Compose([zero_wav_pad]) pad_feat_transform = transforms.Compose([zero_feat_pad]) wav_transform = transforms.Compose( [lambda x: encode_mu_law(x, args.n_quantize)]) # define generator training if os.path.isdir(args.waveforms): filenames = sorted( find_files(args.waveforms, "*.wav", use_dir_name=False)) wav_list = [args.waveforms + "/" + filename for filename in filenames] elif os.path.isfile(args.waveforms): wav_list = read_txt(args.waveforms) else: logging.error("--waveforms should be directory or list.") sys.exit(1) if os.path.isdir(args.feats): feat_list = [args.feats + "/" + filename for filename in filenames] elif os.path.isfile(args.feats): feat_list = read_txt(args.feats) else: logging.error("--feats should be directory or list.") sys.exit(1) assert len(wav_list) == len(feat_list) logging.info("number of training data = %d." 
% len(feat_list)) dataset = FeatureDatasetNeuVoco(wav_list, feat_list, pad_wav_transform, pad_feat_transform, args.upsampling_factor, args.string_path, wav_transform=wav_transform) dataloader = DataLoader(dataset, batch_size=args.batch_size_utt, shuffle=True, num_workers=args.n_workers) #generator = train_generator(dataloader, device, args.batch_size, args.upsampling_factor, limit_count=1) generator = train_generator(dataloader, device, args.batch_size, args.upsampling_factor, limit_count=None) #generator = train_generator(dataloader, device, args.batch_size, args.upsampling_factor, limit_count=1, resume_c_idx=1426, max_c_idx=(len(feat_list)//args.batch_size_utt)) #generator = train_generator(dataloader, device, args.batch_size, args.upsampling_factor, limit_count=None, resume_c_idx=1426, max_c_idx=(len(feat_list)//args.batch_size_utt)) # define generator evaluation if os.path.isdir(args.waveforms_eval): filenames = sorted( find_files(args.waveforms_eval, "*.wav", use_dir_name=False)) wav_list_eval = [ args.waveforms + "/" + filename for filename in filenames ] elif os.path.isfile(args.waveforms_eval): wav_list_eval = read_txt(args.waveforms_eval) else: logging.error("--waveforms_eval should be directory or list.") sys.exit(1) if os.path.isdir(args.feats_eval): feat_list_eval = [ args.feats_eval + "/" + filename for filename in filenames ] elif os.path.isfile(args.feats): feat_list_eval = read_txt(args.feats_eval) else: logging.error("--feats_eval should be directory or list.") sys.exit(1) assert len(wav_list_eval) == len(feat_list_eval) logging.info("number of evaluation data = %d." 
% len(feat_list_eval)) dataset_eval = FeatureDatasetNeuVoco(wav_list_eval, feat_list_eval, pad_wav_transform, pad_feat_transform, args.upsampling_factor, args.string_path, wav_transform=wav_transform) dataloader_eval = DataLoader(dataset_eval, batch_size=args.batch_size_utt_eval, shuffle=False, num_workers=args.n_workers) ##generator_eval = eval_generator(dataloader_eval, device, args.batch_size, args.upsampling_factor, limit_count=1) #generator_eval = eval_generator(dataloader_eval, device, args.batch_size, args.upsampling_factor, limit_count=None) #generator_eval = train_generator(dataloader_eval, device, args.batch_size, args.upsampling_factor, limit_count=1) generator_eval = train_generator(dataloader_eval, device, args.batch_size, args.upsampling_factor, limit_count=None) writer = SummaryWriter(args.expdir) total_train_loss = defaultdict(list) total_eval_loss = defaultdict(list) # train logging.info(args.string_path) total = 0 iter_count = 0 loss_ce = [] loss_err = [] min_eval_loss_err = 99999999.99 min_eval_loss_err_std = 99999999.99 min_eval_loss_ce = 99999999.99 min_eval_loss_ce_std = 99999999.99 iter_idx = 0 min_idx = -1 #min_eval_loss_ce = 1.575400 #min_eval_loss_ce_std = 0.645726 #iter_idx = 8098898 #min_idx = 68 #resume70 change_min_flag = False if args.resume is not None: np.random.set_state(checkpoint["numpy_random_state"]) torch.set_rng_state(checkpoint["torch_random_state"]) logging.info("==%d EPOCH==" % (epoch_idx + 1)) logging.info("Training data") while epoch_idx < args.epoch_count: start = time.time() batch_x, batch_feat, c_idx, utt_idx, featfile, x_bs, f_bs, x_ss, f_ss, n_batch_utt, \ del_index_utt, max_slen, idx_select, idx_select_full, slens_acc = next(generator) if args.init: c_idx = -1 if c_idx < 0: # summarize epoch if not args.init: # save current epoch model numpy_random_state = np.random.get_state() torch_random_state = torch.get_rng_state() # report current epoch logging.info("(EPOCH:%d) average optimization loss = %.6f (+- %.6f) %.6f 
(+- %.6f) %% ;; "\ "(%.3f min., %.3f sec / batch)" % (epoch_idx + 1, np.mean(loss_ce), np.std(loss_ce), \ np.mean(loss_err), np.std(loss_err), total / 60.0, total / iter_count)) logging.info("estimated time until max. epoch = {0.days:02}:{0.hours:02}:{0.minutes:02}:"\ "{0.seconds:02}".format(relativedelta(seconds=int((args.epoch_count - (epoch_idx + 1)) * total)))) # compute loss in evaluation data total = 0 iter_count = 0 loss_ce = [] loss_err = [] model_waveform.eval() for param in model_waveform.parameters(): param.requires_grad = False pair_exist = False logging.info("Evaluation data") while True: with torch.no_grad(): start = time.time() batch_x, batch_feat, c_idx, utt_idx, featfile, x_bs, f_bs, x_ss, f_ss, n_batch_utt, \ del_index_utt, max_slen, idx_select, idx_select_full, slens_acc = next(generator_eval) if c_idx < 0: break x_es = x_ss + x_bs f_es = f_ss + f_bs logging.info( f'{x_ss} {x_bs} {x_es} {f_ss} {f_bs} {f_es} {max_slen}' ) if x_ss > 0: if x_es <= max_slen: batch_x_prev = batch_x[:, x_ss - shift_rec_field - 1:x_es - 1] batch_feat = batch_feat[:, f_ss - shift_rec_field_frm:f_es] batch_x = batch_x[:, x_ss:x_es] else: batch_x_prev = batch_x[:, x_ss - shift_rec_field - 1:-1] batch_feat = batch_feat[:, f_ss - shift_rec_field_frm:] batch_x = batch_x[:, x_ss:] # assert((batch_x_prev[:,shift_rec_field+1:] == batch_x[:,:-1]).all()) else: batch_x_prev = F.pad( batch_x[:, :x_es - 1], (model_waveform.receptive_field + 1, 0), "constant", args.n_quantize // 2) batch_feat = batch_feat[:, :f_es] batch_x = batch_x[:, :x_es] # assert((batch_x_prev[:,model_waveform.receptive_field+1:] == batch_x[:,:-1]).all()) if x_ss > 0: batch_x_output = model_waveform( batch_feat, batch_x_prev)[:, shift_rec_field:] else: batch_x_output = model_waveform( batch_feat, batch_x_prev, first=True)[:, model_waveform.receptive_field:] # samples check i = np.random.randint(0, batch_x_output.shape[0]) logging.info("%s" % (os.path.join( os.path.basename(os.path.dirname(featfile[i])), 
os.path.basename(featfile[i])))) #check_samples = batch_x[i,5:10].long() #logging.info(torch.index_select(F.softmax(batch_x_output[i,5:10], dim=-1), 1, check_samples)) #logging.info(check_samples) # handle short ending batch_loss = 0 if len(idx_select) > 0: logging.info('len_idx_select: ' + str(len(idx_select))) batch_loss_ce = 0 batch_loss_err = 0 for j in range(len(idx_select)): k = idx_select[j] slens_utt = slens_acc[k] logging.info('%s %d' % (featfile[k], slens_utt)) batch_x_output_k = batch_x_output[k, :slens_utt] batch_x_k = batch_x[k, :slens_utt] batch_loss_ce += torch.mean( criterion_ce(batch_x_output_k, batch_x_k)) batch_loss_err += torch.mean( torch.sum( 100 * criterion_l1( F.softmax(batch_x_output_k, dim=-1), F.one_hot(batch_x_k, num_classes=args.n_quantize). float()), -1)) batch_loss += batch_loss_ce batch_loss_ce /= len(idx_select) batch_loss_err /= len(idx_select) total_eval_loss["eval/loss_ce"].append( batch_loss_ce.item()) total_eval_loss["eval/loss_err"].append( batch_loss_err.item()) loss_ce.append(batch_loss_ce.item()) loss_err.append(batch_loss_err.item()) if len(idx_select_full) > 0: logging.info('len_idx_select_full: ' + str(len(idx_select_full))) batch_x = torch.index_select( batch_x, 0, idx_select_full) batch_x_output = torch.index_select( batch_x_output, 0, idx_select_full) else: logging.info( "batch eval loss select %.3f %.3f (%.3f sec)" % (batch_loss_ce.item(), batch_loss_err.item(), time.time() - start)) iter_count += 1 total += time.time() - start continue batch_loss_ce_ = torch.mean( criterion_ce( batch_x_output.reshape(-1, args.n_quantize), batch_x.reshape(-1)).reshape( batch_x_output.shape[0], -1), -1) batch_loss_err_ = torch.mean( torch.sum( 100 * criterion_l1( F.softmax(batch_x_output, dim=-1), F.one_hot( batch_x, num_classes=args.n_quantize).float()), -1), -1) batch_loss_ce = batch_loss_ce_.mean() batch_loss_err = batch_loss_err_.mean() total_eval_loss["eval/loss_ce"].append( batch_loss_ce.item()) 
total_eval_loss["eval/loss_err"].append( batch_loss_err.item()) loss_ce.append(batch_loss_ce.item()) loss_err.append(batch_loss_err.item()) logging.info("batch eval loss [%d] %d %d %d %d %d : %.3f %.3f %% (%.3f sec)" % (c_idx+1, max_slen, \ x_ss, x_bs, f_ss, f_bs, batch_loss_ce.item(), batch_loss_err.item(), time.time() - start)) iter_count += 1 total += time.time() - start logging.info('sme') for key in total_eval_loss.keys(): total_eval_loss[key] = np.mean(total_eval_loss[key]) logging.info( f"(Steps: {iter_idx}) {key} = {total_eval_loss[key]:.4f}.") write_to_tensorboard(writer, iter_idx, total_eval_loss) total_eval_loss = defaultdict(list) eval_loss_ce = np.mean(loss_ce) eval_loss_ce_std = np.std(loss_ce) eval_loss_err = np.mean(loss_err) eval_loss_err_std = np.std(loss_err) logging.info("(EPOCH:%d) average evaluation loss = %.6f (+- %.6f) %.6f (+- %.6f) %% ;; (%.3f min., "\ "%.3f sec / batch)" % (epoch_idx + 1, eval_loss_ce, eval_loss_ce_std, \ eval_loss_err, eval_loss_err_std, total / 60.0, total / iter_count)) if (eval_loss_ce+eval_loss_ce_std) <= (min_eval_loss_ce+min_eval_loss_ce_std) \ or (eval_loss_ce <= min_eval_loss_ce): min_eval_loss_ce = eval_loss_ce min_eval_loss_ce_std = eval_loss_ce_std min_eval_loss_err = eval_loss_err min_eval_loss_err_std = eval_loss_err_std min_idx = epoch_idx change_min_flag = True #else: # epoch_min_flag = False if change_min_flag: logging.info("min_eval_loss = %.6f (+- %.6f) %.6f (+- %.6f) %% min_idx=%d" % (min_eval_loss_ce, \ min_eval_loss_ce_std, min_eval_loss_err, min_eval_loss_err_std, min_idx+1)) #if ((epoch_idx + 1) % args.save_interval_epoch == 0) or (epoch_min_flag): # logging.info('save epoch:%d' % (epoch_idx+1)) # save_checkpoint(args.expdir, model_waveform, optimizer, numpy_random_state, torch_random_state, epoch_idx + 1) if args.init: exit() logging.info('save epoch:%d' % (epoch_idx + 1)) save_checkpoint(args.expdir, model_waveform, optimizer, numpy_random_state, torch_random_state, epoch_idx + 1) total = 0 
iter_count = 0 loss_ce = [] loss_err = [] epoch_idx += 1 np.random.set_state(numpy_random_state) torch.set_rng_state(torch_random_state) model_waveform.train() for param in model_waveform.parameters(): param.requires_grad = True for param in model_waveform.scale_in.parameters(): param.requires_grad = False # start next epoch if epoch_idx < args.epoch_count: start = time.time() logging.info("==%d EPOCH==" % (epoch_idx + 1)) logging.info("Training data") batch_x, batch_feat, c_idx, utt_idx, featfile, x_bs, f_bs, x_ss, f_ss, n_batch_utt, \ del_index_utt, max_slen, idx_select, idx_select_full, slens_acc = next(generator) # feedforward and backpropagate current batch if epoch_idx < args.epoch_count: logging.info("%d iteration [%d]" % (iter_idx + 1, epoch_idx + 1)) x_es = x_ss + x_bs f_es = f_ss + f_bs logging.info( f'{x_ss} {x_bs} {x_es} {f_ss} {f_bs} {f_es} {max_slen}') if x_ss > 0: if x_es <= max_slen: batch_x_prev = batch_x[:, x_ss - shift_rec_field - 1:x_es - 1] batch_feat = batch_feat[:, f_ss - shift_rec_field_frm:f_es] batch_x = batch_x[:, x_ss:x_es] else: batch_x_prev = batch_x[:, x_ss - shift_rec_field - 1:-1] batch_feat = batch_feat[:, f_ss - shift_rec_field_frm:] batch_x = batch_x[:, x_ss:] # assert((batch_x_prev[:,shift_rec_field+1:] == batch_x[:,:-1]).all()) else: batch_x_prev = F.pad(batch_x[:, :x_es - 1], (model_waveform.receptive_field + 1, 0), "constant", args.n_quantize // 2) batch_feat = batch_feat[:, :f_es] batch_x = batch_x[:, :x_es] # assert((batch_x_prev[:,model_waveform.receptive_field+1:] == batch_x[:,:-1]).all()) if x_ss > 0: batch_x_output = model_waveform(batch_feat, batch_x_prev, do=True)[:, shift_rec_field:] else: batch_x_output = model_waveform( batch_feat, batch_x_prev, first=True, do=True)[:, model_waveform.receptive_field:] # samples check i = np.random.randint(0, batch_x_output.shape[0]) logging.info( "%s" % (os.path.join(os.path.basename(os.path.dirname(featfile[i])), os.path.basename(featfile[i])))) #with torch.no_grad(): # i = 
np.random.randint(0, batch_x_output.shape[0]) # logging.info("%s" % (os.path.join(os.path.basename(os.path.dirname(featfile[i])),os.path.basename(featfile[i])))) # check_samples = batch_x[i,5:10].long() # logging.info(torch.index_select(F.softmax(batch_x_output[i,5:10], dim=-1), 1, check_samples)) # logging.info(check_samples) # handle short ending batch_loss = 0 if len(idx_select) > 0: logging.info('len_idx_select: ' + str(len(idx_select))) batch_loss_ce = 0 batch_loss_err = 0 for j in range(len(idx_select)): k = idx_select[j] slens_utt = slens_acc[k] logging.info('%s %d' % (featfile[k], slens_utt)) batch_x_output_k = batch_x_output[k, :slens_utt] batch_x_k = batch_x[k, :slens_utt] batch_loss_ce += torch.mean( criterion_ce(batch_x_output_k, batch_x_k)) batch_loss_err += torch.mean( torch.sum( 100 * criterion_l1( F.softmax(batch_x_output_k, dim=-1), F.one_hot( batch_x_k, num_classes=args.n_quantize).float()), -1)) batch_loss += batch_loss_ce batch_loss_ce /= len(idx_select) batch_loss_err /= len(idx_select) total_train_loss["train/loss_ce"].append(batch_loss_ce.item()) total_train_loss["train/loss_err"].append( batch_loss_err.item()) loss_ce.append(batch_loss_ce.item()) loss_err.append(batch_loss_err.item()) if len(idx_select_full) > 0: logging.info('len_idx_select_full: ' + str(len(idx_select_full))) batch_x = torch.index_select(batch_x, 0, idx_select_full) batch_x_output = torch.index_select( batch_x_output, 0, idx_select_full) else: optimizer.zero_grad() batch_loss.backward() torch.nn.utils.clip_grad_norm_(model_waveform.parameters(), 10) optimizer.step() logging.info("batch loss select %.3f %.3f (%.3f sec)" % (batch_loss_ce.item(), batch_loss_err.item(), time.time() - start)) iter_idx += 1 #if iter_idx % args.save_interval_iter == 0: # logging.info('save iter:%d' % (iter_idx)) # save_checkpoint(args.expdir, model_waveform, optimizer, np.random.get_state(), torch.get_rng_state(), iter_idx) iter_count += 1 if iter_idx % args.log_interval_steps == 0: 
logging.info('smt') for key in total_train_loss.keys(): total_train_loss[key] = np.mean( total_train_loss[key]) logging.info( f"(Steps: {iter_idx}) {key} = {total_train_loss[key]:.4f}." ) write_to_tensorboard(writer, iter_idx, total_train_loss) total_train_loss = defaultdict(list) total += time.time() - start continue # loss batch_loss_ce_ = torch.mean( criterion_ce(batch_x_output.reshape(-1, args.n_quantize), batch_x.reshape(-1)).reshape( batch_x_output.shape[0], -1), -1) batch_loss_err_ = torch.mean( torch.sum( 100 * criterion_l1( F.softmax(batch_x_output, dim=-1), F.one_hot(batch_x, num_classes=args.n_quantize).float()), -1), -1) batch_loss_ce = batch_loss_ce_.mean() batch_loss_err = batch_loss_err_.mean() total_train_loss["train/loss_ce"].append(batch_loss_ce.item()) total_train_loss["train/loss_err"].append(batch_loss_err.item()) loss_ce.append(batch_loss_ce.item()) loss_err.append(batch_loss_err.item()) batch_loss += batch_loss_ce_.sum() optimizer.zero_grad() batch_loss.backward() torch.nn.utils.clip_grad_norm_(model_waveform.parameters(), 10) optimizer.step() logging.info("batch loss [%d] %d %d %d %d %d : %.3f %.3f %% (%.3f sec)" % (c_idx+1, max_slen, x_ss, x_bs, \ f_ss, f_bs, batch_loss_ce.item(), batch_loss_err.item(), time.time() - start)) iter_idx += 1 #if iter_idx % args.save_interval_iter == 0: # logging.info('save iter:%d' % (iter_idx)) # save_checkpoint(args.expdir, model_waveform, optimizer, np.random.get_state(), torch.get_rng_state(), iter_idx) iter_count += 1 if iter_idx % args.log_interval_steps == 0: logging.info('smt') for key in total_train_loss.keys(): total_train_loss[key] = np.mean(total_train_loss[key]) logging.info( f"(Steps: {iter_idx}) {key} = {total_train_loss[key]:.4f}." 
) write_to_tensorboard(writer, iter_idx, total_train_loss) total_train_loss = defaultdict(list) total += time.time() - start # save final model model_waveform.cpu() torch.save({"model_waveform": model_waveform.state_dict()}, args.expdir + "/checkpoint-final.pkl") logging.info("final checkpoint created.")
# Type = 'best' save_folder = 'model_result_multi_layer' Type = 'trainable' model_check_point = '%s/model_%s_%d.pk' % (save_folder, Type, version_num) optim_check_point = '%s/optim_%s_%d.pkl' % (save_folder, Type, version_num) loss_check_point = '%s/loss_%s_%d.pkl' % (save_folder, Type, version_num) epoch_check_point = '%s/epoch_%s_%d.pkl' % (save_folder, Type, version_num) bleu_check_point = '%s/bleu_%s_%d.pkl' % (save_folder, Type, version_num) loss_values = [] epoch_values = [] bleu_values = [] if os.path.isfile(model_check_point): print('Loading previous status (ver.%d)...' % version_num) model.load_state_dict(torch.load(model_check_point, map_location='cpu')) model = model.to(device) optimizer.load_state_dict(torch.load(optim_check_point)) lr_scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.4, patience=2, min_lr=1e-7, verbose=True) loss_values = torch.load(loss_check_point) epoch_values = torch.load(epoch_check_point) bleu_values = torch.load(bleu_check_point) print('Load successfully') else: print("ver.%d doesn't exist" % version_num) # evaluateAndShowAttention(['現在', '未來', '夢想', '科學', '文化'], method='beam_search', is_sample=True)
class Trainer:
    """Trains and evaluates a model while recording layer-saturation stats.

    Progress (model/optimizer state, epoch, metrics) is persisted next to a
    CSV log so that an interrupted run with the same ``run_id`` can be
    resumed, and a fully finished run is skipped entirely.
    """

    def __init__(self, model, train_loader, test_loader, epochs=200,
                 batch_size=60, run_id=0, logs_dir='logs', device='cpu',
                 saturation_device=None, optimizer='None', plot=True,
                 compute_top_k=False, data_prallel=False,
                 conv_method='channelwise', thresh=.99, half_precision=False,
                 downsampling=None):
        """Set up the optimizer, resume state (if any) and saturation tracker.

        Args:
            model: network to train; must expose ``.name`` (used in log path).
            train_loader: training data loader; must expose ``.name``.
            test_loader: evaluation data loader.
            epochs: total number of epochs to train.
            batch_size: only used to tag the log file name.
            run_id: distinguishes repeated runs of the same configuration.
            logs_dir: root directory for CSV logs and checkpoints.
            device: compute device string, e.g. ``'cpu'`` or ``'cuda:0'``.
            saturation_device: device for saturation computation; defaults
                to ``device`` when ``None``.
            optimizer: one of ``adam``, ``adam_lr``, ``adam_lr2``, ``SGD``,
                ``LRS``, ``radam``; anything else raises ``ValueError``.
            plot: stored for plotting hooks (not used by ``train`` itself).
            compute_top_k: report top-5 accuracy instead of top-1.
            data_prallel: wrap the model in ``nn.DataParallel`` (parameter
                name kept misspelled for backward compatibility).
            conv_method: aggregation method for conv-layer saturation.
            thresh: saturation recording threshold.
            half_precision: run the model in fp16.
            downsampling: if not None, feature maps are interpolated down
                before covariance accumulation.

        Raises:
            ValueError: if ``optimizer`` names an unknown optimizer.
        """
        self.saturation_device = device if saturation_device is None else saturation_device
        self.device = device
        self.model = model
        self.epochs = epochs
        self.plot = plot
        self.compute_top_k = compute_top_k
        if 'cuda' in device:
            # Input sizes are fixed, so let cuDNN benchmark kernels once.
            cudnn.benchmark = True
        self.train_loader = train_loader
        self.test_loader = test_loader
        self.criterion = nn.CrossEntropyLoss()
        self.optimizer = self._build_optimizer(model, optimizer)
        self.opt_name = optimizer

        save_dir = os.path.join(logs_dir, model.name, train_loader.name)
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)
        self.savepath = os.path.join(
            save_dir,
            f'{model.name}_bs{batch_size}_e{epochs}_dspl{downsampling}_t{int(thresh*1000)}_id{run_id}.csv'
        )

        # A run is "done" when its CSV log already holds >= `epochs` rows.
        self.experiment_done = False
        if os.path.exists(self.savepath):
            trained_epochs = len(pd.read_csv(self.savepath, sep=';'))
            if trained_epochs >= epochs:
                self.experiment_done = True
                print(
                    'Experiment Logs for the exact same experiment with identical run_id was detected, training will be skipped, consider using another run_id'
                )

        # FIX: previously `self.parallel` was assigned only on the fresh-start
        # path, leaving it undefined after a resume.
        self.parallel = data_prallel
        checkpoint_path = self.savepath.replace('.csv', '.pt')
        if os.path.exists(checkpoint_path):
            # FIX: load the checkpoint once instead of deserializing the same
            # file three times (model, optimizer, epoch).
            checkpoint = torch.load(checkpoint_path)
            self.model.load_state_dict(checkpoint['model_state_dict'])
            if data_prallel:
                self.model = nn.DataParallel(self.model)
            self.model = self.model.to(self.device)
            if half_precision:
                self.model = self.model.half()
            self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.start_epoch = checkpoint['epoch'] + 1
            initial_epoch = self._infer_initial_epoch(self.savepath)
            print('Resuming existing run, starting at epoch',
                  self.start_epoch, 'from', checkpoint_path)
        else:
            if half_precision:
                self.model = self.model.half()
            self.start_epoch = 0
            initial_epoch = 0
            if data_prallel:
                self.model = nn.DataParallel(self.model)
            self.model = self.model.to(self.device)

        writer = CSVandPlottingWriter(self.savepath.replace('.csv', ''),
                                      fontsize=16,
                                      primary_metric='test_accuracy')
        # NOTE(review): writer2 is constructed but never registered with
        # CheckLayerSat; kept in case NPYWriter() has construction side
        # effects — confirm whether it should be in the writer list.
        writer2 = NPYWriter(self.savepath.replace('.csv', ''))
        self.pooling_strat = conv_method
        print('Setting saturation recording threshold to', thresh)
        self.half = half_precision
        self.stats = CheckLayerSat(
            self.savepath.replace('.csv', ''),
            [writer],
            model,
            ignore_layer_names='convolution',
            stats=['lsat', 'idim'],
            # FIX: was hard-coded to .99, silently ignoring the `thresh`
            # parameter (which was already baked into the log filename).
            sat_threshold=thresh,
            verbose=False,
            conv_method=conv_method,
            log_interval=1,
            device=self.saturation_device,
            reset_covariance=True,
            max_samples=None,
            initial_epoch=initial_epoch,
            interpolation_strategy='nearest' if downsampling is not None else None,
            # NOTE(review): hard-coded 4 rather than `downsampling` — confirm.
            interpolation_downsampling=4)

    def _build_optimizer(self, model, optimizer):
        """Return the optimizer named by `optimizer`.

        Also sets ``self.lr_scheduler`` for the ``LRS`` schedule (relies on
        ``self.epochs`` being assigned beforehand). Raises ``ValueError`` for
        an unknown name.
        """
        print('Checking for optimizer for {}'.format(optimizer))
        if optimizer == "adam":
            print('Using adam')
            return optim.Adam(model.parameters())
        if optimizer == "adam_lr":
            print("Using adam with higher learning rate")
            return optim.Adam(model.parameters(), lr=0.01)
        if optimizer == 'adam_lr2':
            print('Using adam with to large learning rate')
            return optim.Adam(model.parameters(), lr=0.0001)
        if optimizer == "SGD":
            print('Using SGD')
            # NOTE(review): no explicit lr; older torch versions require one
            # for SGD — confirm against the pinned torch version.
            return optim.SGD(model.parameters(), momentum=0.9,
                             weight_decay=5e-4)
        if optimizer == "LRS":
            print('Using LRS')
            opt = optim.SGD(model.parameters(), lr=0.1, momentum=0.9,
                            weight_decay=5e-4)
            # Step the LR down every third of the schedule.
            self.lr_scheduler = optim.lr_scheduler.StepLR(
                opt, self.epochs // 3)
            return opt
        if optimizer == "radam":
            print('Using radam')
            return RAdam(model.parameters())
        raise ValueError('Unknown optimizer {}'.format(optimizer))

    def _infer_initial_epoch(self, savepath):
        """Return the number of epochs already logged in `savepath` (0 if the
        log does not exist)."""
        if not os.path.exists(savepath):
            return 0
        else:
            df = pd.read_csv(savepath, sep=';', index_col=0)
            print(len(df) + 1)
            return len(df)

    def train(self):
        """Run the remaining epochs (train + test each), stepping the LRS
        scheduler and recording saturations per epoch.

        Returns the log path, or None when the experiment was already done.
        """
        if self.experiment_done:
            return
        for epoch in range(self.start_epoch, self.epochs):
            print('Start training epoch', epoch)
            print("{} Epoch {}, training loss: {}, training accuracy: {}".format(
                now(), epoch, *self.train_epoch()))
            self.test(epoch=epoch)
            if self.opt_name == "LRS":
                print('LRS step')
                self.lr_scheduler.step()
            self.stats.add_saturations()
        self.stats.close()
        # NOTE(review): self.savepath already ends in '.csv', so this returns
        # '...csv.csv' — confirm callers expect that before changing it.
        return self.savepath + '.csv'

    def train_epoch(self):
        """Train one epoch.

        Returns:
            (per-sample training loss, top-1 training accuracy).
        """
        self.model.train()
        correct = 0
        total = 0
        running_loss = 0
        old_time = time()
        top5_accumulator = 0
        for batch, data in enumerate(self.train_loader):
            # Lightweight progress report every 10 batches.
            if batch % 10 == 0 and batch != 0:
                print(batch, 'of', len(self.train_loader), 'processing time',
                      time() - old_time,
                      "top5_acc:" if self.compute_top_k else 'acc:',
                      round(top5_accumulator / (batch), 3)
                      if self.compute_top_k else correct / total)
                old_time = time()
            inputs, labels = data
            if self.half:
                # Only inputs are cast to fp16; integer labels stay as-is.
                inputs, labels = inputs.to(self.device).half(), labels.to(
                    self.device)
            else:
                inputs, labels = inputs.to(self.device), labels.to(self.device)
            self.optimizer.zero_grad()
            outputs = self.model(inputs)
            if self.compute_top_k:
                top5_accumulator += accuracy(outputs, labels, (5, ))[0]
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            loss = self.criterion(outputs, labels)
            loss.backward()
            self.optimizer.step()
            correct += (predicted == labels.long()).sum().item()
            running_loss += loss.item()
        # Loss is averaged per sample; top-5 accuracy per batch.
        self.stats.add_scalar('training_loss', running_loss / total)
        if self.compute_top_k:
            self.stats.add_scalar('training_accuracy',
                                  (top5_accumulator / (batch + 1)))
        else:
            self.stats.add_scalar('training_accuracy', correct / total)
        return running_loss / total, correct / total

    def test(self, epoch, save=True):
        """Evaluate on the test set and optionally checkpoint the run.

        Args:
            epoch: epoch index stored inside the checkpoint.
            save: when True, persist model/optimizer/epoch/test_loss to the
                '.pt' file next to the CSV log.

        Returns:
            (top-1 test accuracy, per-sample test loss).
        """
        self.model.eval()
        correct = 0
        total = 0
        test_loss = 0
        top5_accumulator = 0
        with torch.no_grad():
            for batch, data in enumerate(self.test_loader):
                if batch % 10 == 0:
                    print('Processing eval batch', batch, 'of',
                          len(self.test_loader))
                inputs, labels = data
                if self.half:
                    inputs, labels = inputs.to(self.device).half(), labels.to(
                        self.device)
                else:
                    inputs, labels = inputs.to(self.device), labels.to(
                        self.device)
                outputs = self.model(inputs)
                loss = self.criterion(outputs, labels)
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels.long()).sum().item()
                if self.compute_top_k:
                    top5_accumulator += accuracy(outputs, labels, (5, ))[0]
                test_loss += loss.item()
        self.stats.add_scalar('test_loss', test_loss / total)
        if self.compute_top_k:
            self.stats.add_scalar('test_accuracy',
                                  top5_accumulator / (batch + 1))
            print('{} Test Top5-Accuracy on {} images: {:.4f}'.format(
                now(), total, top5_accumulator / (batch + 1)))
        else:
            self.stats.add_scalar('test_accuracy', correct / total)
            print('{} Test Accuracy on {} images: {:.4f}'.format(
                now(), total, correct / total))
        if save:
            torch.save(
                {
                    'model_state_dict': self.model.state_dict(),
                    'optimizer': self.optimizer.state_dict(),
                    'epoch': epoch,
                    'test_loss': test_loss / total
                }, self.savepath.replace('.csv', '.pt'))
        return correct / total, test_loss / total