def main(network=1):
    """Train Text2Mel (network=1) or SSRN (network=2).

    Builds the model, optimizer and LR scheduler, resumes from the
    lowest-loss checkpoint when the checkpoint directory already exists,
    wires up train/valid data loaders and runs the training loop.

    Args:
        network: 1 selects Text2Mel, 2 selects SSRN.

    Raises:
        ValueError: if ``network`` is neither 1 nor 2.
    """
    if network == 1:
        model = Text2Mel().to(DEVICE)
    elif network == 2:
        model = SSRN().to(DEVICE)
    else:
        # Fail fast with a clear message instead of hitting a NameError
        # on `model` a few lines below.
        raise ValueError('network must be 1 (Text2Mel) or 2 (SSRN), got {!r}'.format(network))
    print('Model {} is working...'.format(model.name))
    print('{} threads are used...'.format(torch.get_num_threads()))
    ckpt_dir = os.path.join(args.logdir, model.name)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
    scheduler = StepLR(optimizer, step_size=args.lr_decay_step//10, gamma=0.933)  # around 1/2 per decay step
    if not os.path.exists(ckpt_dir):
        os.makedirs(os.path.join(ckpt_dir, 'A', 'train'))
    else:
        print('Already exists. Retrain the model.')
        # ckpt.csv rows: (checkpoint filename, validation loss).
        ckpt = pd.read_csv(os.path.join(ckpt_dir, 'ckpt.csv'), sep=',', header=None)
        ckpt.columns = ['models', 'loss']
        ckpt = ckpt.sort_values(by='loss', ascending=True)
        # BUG FIX: use positional .iloc[0]. After sort_values, .loc[0]
        # selects the row whose *label* is 0 (i.e. the CSV's first row),
        # not the lowest-loss row.
        state = torch.load(os.path.join(ckpt_dir, ckpt.models.iloc[0]))
        model.load_state_dict(state['model'])
        args.global_step = state['global_step']
        optimizer.load_state_dict(state['optimizer'])
        scheduler.load_state_dict(state['scheduler'])
    # model = torch.nn.DataParallel(model, device_ids=list(range(args.no_gpu))).to(DEVICE)
    if model.name == 'Text2Mel':
        if args.ga_mode:
            # Guided-attention collate only for training batches.
            cfn_train, cfn_eval = t2m_ga_collate_fn, t2m_collate_fn
        else:
            cfn_train, cfn_eval = t2m_collate_fn, t2m_collate_fn
    else:
        cfn_train, cfn_eval = collate_fn, collate_fn
    dataset = SpeechDataset(args.data_path, args.meta_train, model.name,
                            mem_mode=args.mem_mode, ga_mode=args.ga_mode)
    validset = SpeechDataset(args.data_path, args.meta_eval, model.name,
                             mem_mode=args.mem_mode)
    data_loader = DataLoader(dataset=dataset, batch_size=args.batch_size,
                             shuffle=True, collate_fn=cfn_train,
                             drop_last=True, pin_memory=True)
    valid_loader = DataLoader(dataset=validset, batch_size=args.test_batch,
                              shuffle=False, collate_fn=cfn_eval, pin_memory=True)
    writer = SummaryWriter(ckpt_dir)
    train(model, data_loader, valid_loader, optimizer, scheduler,
          batch_size=args.batch_size, ckpt_dir=ckpt_dir, writer=writer)
    return None
def main():
    """Synthesize speech for the test set with the latest checkpoints.

    Loads the most recent Text2Mel and SSRN checkpoints (highest training
    step) from ``args.logdir`` and runs synthesis over ``args.testset``.
    """
    import re  # local import: only needed for checkpoint-step parsing

    testset = TextDataset(args.testset)
    test_loader = DataLoader(dataset=testset, batch_size=args.test_batch,
                             drop_last=False, shuffle=False,
                             collate_fn=synth_collate_fn, pin_memory=True)
    t2m = Text2Mel().to(DEVICE)
    ssrn = SSRN().to(DEVICE)

    def _step(path):
        # Extract the numeric step from '<Name>-<N>k.pth' so checkpoints
        # sort numerically. BUG FIX: plain sorted() is lexicographic, so
        # e.g. '...-100k.pth' would wrongly sort before '...-20k.pth' and
        # the latest checkpoint would not be the one loaded.
        m = re.search(r'-(\d+)k\.pth$', path)
        return int(m.group(1)) if m else -1

    mname = type(t2m).__name__
    ckpt = sorted(glob.glob(os.path.join(args.logdir, mname,
                                         '{}-*k.pth'.format(mname))), key=_step)
    state = torch.load(ckpt[-1])  # latest (highest-step) checkpoint
    t2m.load_state_dict(state['model'])
    args.global_step = state['global_step']

    mname = type(ssrn).__name__
    ckpt = sorted(glob.glob(os.path.join(args.logdir, mname,
                                         '{}-*k.pth'.format(mname))), key=_step)
    state = torch.load(ckpt[-1])
    ssrn.load_state_dict(state['model'])
    print('All of models are loaded.')

    t2m.eval()
    ssrn.eval()

    if not os.path.exists(os.path.join(args.sampledir, 'A')):
        os.makedirs(os.path.join(args.sampledir, 'A'))
    synthesize(t2m, ssrn, test_loader, args.test_batch)
def main():
    """Synthesize speech for the test set with the best checkpoints.

    Picks the lowest-validation-loss checkpoint for each of Text2Mel and
    SSRN (as recorded in each model's ``ckpt.csv``) and runs synthesis
    over ``args.testset``.
    """
    testset = TextDataset(args.testset)
    test_loader = DataLoader(dataset=testset, batch_size=args.test_batch,
                             drop_last=False, shuffle=False,
                             collate_fn=synth_collate_fn, pin_memory=True)
    t2m = Text2Mel().to(DEVICE)
    ssrn = SSRN().to(DEVICE)

    def _load_best(model):
        # ckpt.csv rows: (checkpoint filename, validation loss).
        ckpt = pd.read_csv(os.path.join(args.logdir, model.name, 'ckpt.csv'),
                           sep=',', header=None)
        ckpt.columns = ['models', 'loss']
        ckpt = ckpt.sort_values(by='loss', ascending=True)
        # BUG FIX: use positional .iloc[0]. After sort_values, .loc[0]
        # selects by index *label* and returns the CSV's first row
        # regardless of the sort, not the lowest-loss checkpoint.
        state = torch.load(os.path.join(args.logdir, model.name,
                                        ckpt.models.iloc[0]))
        model.load_state_dict(state['model'])
        return state

    state = _load_best(t2m)
    args.global_step = state['global_step']
    _load_best(ssrn)
    print('All of models are loaded.')

    t2m.eval()
    ssrn.eval()

    if not os.path.exists(os.path.join(args.sampledir, 'A')):
        os.makedirs(os.path.join(args.sampledir, 'A'))
    return synthesize(t2m=t2m, ssrn=ssrn, data_loader=test_loader,
                      batch_size=args.test_batch)
def main(mode):
    """Compute MSE of the Text2Mel+SSRN pipeline over a dataset split.

    Loads the lowest-validation-loss checkpoints for both models and
    evaluates them on the split named by ``mode``.

    Args:
        mode: one of ``"train"``, ``"test"`` or ``"eval"`` — selects the
            metadata file used to build the dataset.

    Raises:
        SystemExit: if ``mode`` is not one of the accepted values.
    """
    t2m = Text2Mel().to(DEVICE)
    ssrn = SSRN().to(DEVICE)
    if mode == "train":
        dataset = SpeechDataset(args.data_path, args.meta_train, "Text2Mel",
                                mem_mode=args.mem_mode)
    elif mode == "test":
        dataset = SpeechDataset(args.data_path, args.meta_test, "Text2Mel",
                                mem_mode=args.mem_mode)
    elif mode == "eval":
        dataset = SpeechDataset(args.data_path, args.meta_eval, "Text2Mel",
                                mem_mode=args.mem_mode)
    else:
        # BUG FIX: the original did print(...) + exit(0), signalling
        # success on invalid input; SystemExit with a message prints to
        # stderr and exits with status 1. Message also corrected to list
        # all accepted modes (the original omitted EVAL).
        raise SystemExit('[ERROR] Please set correct type: TRAIN, TEST or EVAL!')
    data_loader = DataLoader(dataset=dataset, batch_size=args.mse_batch,
                             shuffle=False, collate_fn=t2m_collate_fn,
                             pin_memory=True)

    def _load_best(model):
        # ckpt.csv rows: (checkpoint filename, validation loss).
        ckpt = pd.read_csv(os.path.join(args.logdir, model.name, 'ckpt.csv'),
                           sep=',', header=None)
        ckpt.columns = ['models', 'loss']
        ckpt = ckpt.sort_values(by='loss', ascending=True)
        # BUG FIX: use positional .iloc[0]. After sort_values, .loc[0]
        # selects by index *label* and returns the CSV's first row,
        # not the lowest-loss checkpoint.
        state = torch.load(os.path.join(args.logdir, model.name,
                                        ckpt.models.iloc[0]))
        model.load_state_dict(state['model'])
        return state

    state = _load_best(t2m)
    args.global_step = state['global_step']
    _load_best(ssrn)
    print('All of models are loaded.')

    t2m.eval()
    ssrn.eval()

    if not os.path.exists(os.path.join(args.sampledir, 'A')):
        os.makedirs(os.path.join(args.sampledir, 'A'))
    return calculate_MSE(t2m=t2m, ssrn=ssrn, data_loader=data_loader,
                         batch_size=args.mse_batch)