def main():
    """Train a Loop model; checkpoint the best and the most recent weights."""
    start_epoch = 1
    model = Loop(args)
    model.cuda()

    # Optionally resume: the epoch counter is stored at index 3 of the
    # args list pickled next to the checkpoint.
    if args.checkpoint != '':
        resume_args_path = os.path.dirname(args.checkpoint) + '/args.pth'
        resume_args = torch.load(resume_args_path)
        start_epoch = resume_args[3]
        model.load_state_dict(torch.load(args.checkpoint))

    criterion = MaskedMSE().cuda()
    optimizer = optim.Adam(model.parameters(), lr=args.lr)

    # Loss history across the whole run.
    train_losses = []
    eval_losses = []
    best_eval = float('inf')

    # Begin!
    for epoch in range(start_epoch, start_epoch + args.epochs):
        train(model, criterion, optimizer, epoch, train_losses)
        eval_loss = evaluate(model, criterion, epoch, eval_losses)

        # Keep the best-so-far model separately from the rolling last model.
        if eval_loss < best_eval:
            best_eval = eval_loss
            torch.save(model.state_dict(), '%s/bestmodel.pth' % (args.expName))

        torch.save(model.state_dict(), '%s/lastmodel.pth' % (args.expName))
        torch.save([args, train_losses, eval_losses, epoch],
                   '%s/args.pth' % (args.expName))
def test_model():
    """Smoke-test the Loop model: run the validation set through
    compute_loss_batch and print each batch loss."""
    hp = Hparams()
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = Loop(hp, device)
    optim = torch.optim.Adam(model.parameters(), lr=1e-4)
    print("model has {} million parameters".format(model.count_parameters()))

    dataset = VCTKDataSet("data/vctk/numpy_features_valid/")
    loader = DataLoader(dataset, shuffle=False, batch_size=10,
                        drop_last=False, collate_fn=my_collate_fn)

    for batch in tqdm(loader):
        text, text_list, target, target_list, spkr = batch
        loss = model.compute_loss_batch((text, text_list), spkr,
                                        (target, target_list))
        print(loss.detach().cpu().numpy())
def main():
    """Evaluate a trained Loop checkpoint on a dataset and print the loss."""
    args = init()
    checkpoint = args.checkpoint

    # Training-time arguments are pickled next to the checkpoint.
    # (The original loaded this file twice and left the first copy unused.)
    opt = torch.load(os.path.dirname(checkpoint) + '/args.pth')
    train_args = opt[0]
    train_args.noise = 0          # disable training-time input noise
    train_args.checkpoint = checkpoint

    # Evaluate with the training-time configuration, not the CLI one.
    args_to_use = train_args
    print(args_to_use)

    model = Loop(args_to_use)
    model.cuda()
    model.load_state_dict(
        torch.load(args_to_use.checkpoint,
                   map_location=lambda storage, loc: storage))

    criterion = MaskedMSE().cuda()
    loader = get_loader(args.data, args.max_seq_len, args.batch_size,
                        args.nspk)
    eval_loss = evaluate(model, loader, criterion)
    print(eval_loss)
def model_def(checkpoint, gpu=-1, valid_loader=None):
    """Build an eval-mode Loop model from a checkpoint.

    Returns (model, norm), where norm is the 'audio_norminfo' array taken
    from the first npz file of the supplied validation loader's dataset.
    """
    state = torch.load(checkpoint, map_location=lambda storage, loc: storage)
    opt = torch.load(os.path.dirname(checkpoint) + '/args.pth')
    train_args = opt[0]
    train_args.noise = 0  # no input noise at inference time

    # Normalisation statistics are read from the validation data itself.
    norm = np.load(valid_loader.dataset.npzs[0])['audio_norminfo']

    model = Loop(train_args)
    model.load_state_dict(state)
    if gpu >= 0:
        model.cuda()
    model.eval()
    return model, norm
def eval_loss(checkpoint='models/vctk/bestmodel.pth', data='data/vctk',
              max_seq_len=1000, nspk=22, gpu=0, batch_size=64, seed=1):
    """Evaluate a Loop checkpoint and return (eval_loss, loss_workings).

    Also prints the checkpoint path, cwd, effective args and both loss
    figures reported by `evaluate`.
    """
    torch.cuda.set_device(gpu)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    print(checkpoint)
    print(os.getcwd())

    # Training-time arguments are pickled next to the checkpoint.
    # (The original loaded this file twice; once is enough.)
    opt = torch.load(os.path.dirname(checkpoint) + '/args.pth')
    train_args = opt[0]
    train_args.noise = 0          # disable training-time input noise
    train_args.checkpoint = checkpoint

    args_to_use = train_args
    print(args_to_use)

    model = Loop(args_to_use)
    model.cuda()
    model.load_state_dict(
        torch.load(args_to_use.checkpoint,
                   map_location=lambda storage, loc: storage))

    criterion = MaskedMSE().cuda()
    loader = get_loader(data, max_seq_len, batch_size, nspk)
    eval_loss, my_eval_loss, loss_workings = evaluate(model, loader, criterion)
    print(eval_loss)
    print(my_eval_loss)
    return eval_loss, loss_workings
def main():
    """Generate a wav from a pre-computed npz feature file with a trained
    Loop model and append the utterance's slot metadata to the dataset CSV."""
    weights = torch.load(args.checkpoint,
                         map_location=lambda storage, loc: storage)
    opt = torch.load(os.path.dirname(args.checkpoint) + '/args.pth')
    train_args = opt[0]

    # Hard-coded phoneme -> code table matching the training vocabulary.
    char2code = {'aa': 0, 'ae': 1, 'ah': 2, 'ao': 3, 'aw': 4, 'ax': 5,
                 'ay': 6, 'b': 7, 'ch': 8, 'd': 9, 'dh': 10, 'eh': 11,
                 'er': 12, 'ey': 13, 'f': 14, 'g': 15, 'hh': 16, 'i': 17,
                 'ih': 18, 'iy': 19, 'jh': 20, 'k': 21, 'l': 22, 'm': 23,
                 'n': 24, 'ng': 25, 'ow': 26, 'oy': 27, 'p': 28, 'pau': 29,
                 'r': 30, 's': 31, 'sh': 32, 'ssil': 33, 't': 34, 'th': 35,
                 'uh': 36, 'uw': 37, 'v': 38, 'w': 39, 'y': 40, 'z': 41}
    nspkr = train_args.nspk

    # Locate the feature-normalisation file: next to the data, else next to
    # the checkpoint.
    norm_path = None
    if os.path.exists(train_args.data + '/norm_info/norm.dat'):
        norm_path = train_args.data + '/norm_info/norm.dat'
    elif os.path.exists(os.path.dirname(args.checkpoint) + '/norm.dat'):
        norm_path = os.path.dirname(args.checkpoint) + '/norm.dat'
    else:
        print('ERROR: Failed to find norm file.')
        return

    train_args.noise = 0  # disable input noise at inference time
    model = Loop(train_args)
    model.load_state_dict(weights)
    if args.gpu >= 0:
        model.cuda()
    model.eval()

    if args.spkr not in range(nspkr):
        print('ERROR: Unknown speaker id: %d.' % args.spkr)
        return

    txt, feat, spkr, output_fname = None, None, None, None
    # BUG FIX: the original used "is not ''" — an identity test that only
    # works by accident of CPython string interning; use != for equality.
    if args.npz != '':
        txt, text, feat = npy_loader_phonemes(args.npz)
        txt = Variable(txt.unsqueeze(1), volatile=True)
        feat = Variable(feat.unsqueeze(1), volatile=True)
        spkr = Variable(torch.LongTensor([args.spkr]), volatile=True)
        fname = os.path.basename(args.npz)[:-4]
        output_fname = fname + '.gen_' + str(args.spkr)

        # Parse the transcription into slot labels
        # (action / number / object / location).
        words = np.char.split(text).tolist()
        words = [word.encode('utf-8') for word in words]
        action = 'none'
        number = 'none'
        objectt = 'none'
        location = 'none'
        # Remove extra word for special cases
        if len(words) == 7:
            words = words[1:]
        action = words[0]
        if len(words) == 2:
            objectt = words[1]
        elif len(words) > 3:
            number = words[1]
            objectt = words[2]
            location = words[-1]

        # Read the existing dataset CSV (if any) into a dict keyed by path,
        # so re-generating the same utterance overwrites its row.
        frames = {}
        if os.path.exists(args.dataset_file):
            df = pd.read_csv(args.dataset_file)
            cols = ['path', 'speakerId', 'transcription', 'action',
                    'number', 'object', 'location']
            for row in zip(*[df[col].values.tolist() for col in cols]):
                frames[row[0]] = {'path': row[0], 'speakerId': row[1],
                                  'transcription': row[2], 'action': row[3],
                                  'number': row[4], 'object': row[5],
                                  'location': row[6]}

        # Add the new synthetic utterance and rewrite the CSV.
        path = os.path.join('wavs/synthetic',
                            output_fname.strip("/") + '.wav')
        frames[path] = {'path': path, 'speakerId': args.spkr,
                        'transcription': text, 'action': action,
                        'number': number, 'object': objectt,
                        'location': location}
        paths = []
        speakerIds = []
        transcriptions = []
        actions = []
        numbers = []
        objects = []
        locations = []
        for key, frame in frames.items():
            paths.append(frame['path'])
            speakerIds.append(frame['speakerId'])
            transcriptions.append(frame['transcription'])
            actions.append(frame['action'])
            numbers.append(frame['number'])
            objects.append(frame['object'])
            locations.append(frame['location'])
        df = pd.DataFrame(OrderedDict([('path', paths),
                                       ('speakerId', speakerIds),
                                       ('transcription', transcriptions),
                                       ('action', actions),
                                       ('number', numbers),
                                       ('object', objects),
                                       ('location', locations)]))
        df.to_csv(args.dataset_file)
    else:
        print('ERROR: Must supply npz file path or text as source.')
        return

    if args.gpu >= 0:
        txt = txt.cuda()
        feat = feat.cuda()
        spkr = spkr.cuda()

    out, attn = model([txt, spkr], feat)
    out, attn = trim_pred(out, attn)

    output_dir = os.path.join(os.path.dirname(args.checkpoint), 'results')
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    generate_merlin_wav(out.data.cpu().numpy(), output_dir, output_fname,
                        norm_path)
def generate_sample_with_loop(
        npz='', text='', spkr_id=1, gender=1,
        checkpoint='models/vctk-16khz-cmu-no-boundaries-all/bestmodel.pth',
        output_dir='./',
        npz_path='/home/ubuntu/loop/data/vctk-16khz-cmu-no-boundaries-all/numpy_features',
        output_file_override=None, ident_override=None):
    """Generate a wav with a trained Loop model and return a feature dict.

    The source is either a pre-computed `npz` feature file (relative to
    `npz_path`) or a raw `text` string. Returns a dict with the inputs,
    generated features, attention matrix and output file name. Raises if
    neither source is supplied.
    """
    gender = np.array(gender).reshape(-1)
    out_dict = dict()

    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
    gpu = 0

    # Load loop weights & training-time params from the checkpoint dir.
    weights = torch.load(checkpoint, map_location=lambda storage, loc: storage)
    opt = torch.load(os.path.dirname(checkpoint) + '/args.pth')
    train_args = opt[0]

    # Phoneme dictionary / speaker list come from the training dataset.
    train_dataset = NpzFolder(
        '/home/ubuntu/loop/data/vctk-16khz-cmu-no-boundaries-all/numpy_features'
    )
    char2code = train_dataset.dict
    spkr2code = train_dataset.speakers

    # Hard-coded norm path (the train_args-derived path was dead code in the
    # original: it was immediately overwritten by this constant).
    norm_path = '/home/ubuntu/loop/data/vctk-16khz-cmu-no-boundaries-all/norm_info/norm.dat'
    train_args.noise = 0  # disable input noise at inference time
    valid_dataset_path = npz_path + '_valid'

    # prepare loop model
    if ident_override:
        # NOTE(review): Loop_Ident is disabled, so `model` is never bound on
        # this path and load_state_dict below raises NameError — confirm
        # whether ident_override is still a supported code path.
        pass
    else:
        model = Loop_Base(train_args)
    model.load_state_dict(weights)
    if gpu >= 0:
        model.cuda()
    model.eval()

    # Check speaker id is valid (warn only; generation proceeds regardless,
    # matching the original behaviour).
    if spkr_id not in range(len(spkr2code)):
        print('ERROR: Unknown speaker id: %d.' % spkr_id)

    # Build the phone sequence from the chosen source.
    # BUG FIX: "is not ''" tests identity, not equality; use != instead.
    txt, feat, spkr, output_fname = None, None, None, None
    if npz != '':
        # use pre-calculated phonemes etc.
        txt, feat, pre_calc_feat = npy_loader_phonemes(
            os.path.join(npz_path, npz))
        txt = Variable(txt.unsqueeze(1), volatile=True)
        feat = Variable(feat.unsqueeze(1), volatile=True)
        spkr = Variable(torch.LongTensor([spkr_id]), volatile=True)
        output_file = os.path.basename(npz)[:-4] + '_' + str(spkr_id)
        out_dict['pre_calc_feat'] = pre_calc_feat
    elif text != '':
        # extract phonemes from the text string
        txt = text2phone(text, char2code)
        # Feature budget of 20 frames per phone, 63 dims each.
        feat = torch.FloatTensor(txt.size(0) * 20, 63)
        spkr = torch.LongTensor([spkr_id])
        txt = Variable(txt.unsqueeze(1), volatile=True)
        feat = Variable(feat.unsqueeze(1), volatile=True)
        spkr = Variable(spkr, volatile=True)
        output_file = text.replace(' ', '_')
    else:
        print('ERROR: Must supply npz file path or text as source.')
        raise Exception('Need source')

    if output_file_override:
        output_file = output_file_override

    # use gpu
    if gpu >= 0:
        txt = txt.cuda()
        feat = feat.cuda()
        spkr = spkr.cuda()

    # Run the loop model to generate output features.
    if ident_override:
        loop_feat, attn = model([txt, spkr, gender], feat,
                                ident_override=ident_override)
    else:
        loop_feat, attn = model([txt, spkr, gender], feat)
    loop_feat, attn = trim_pred(loop_feat, attn)

    # Collect everything the caller may want to inspect.
    out_dict['txt'] = txt[:, 0].squeeze().data.tolist()
    out_dict['spkr'] = spkr
    out_dict['feat'] = feat.data.cpu().numpy()
    out_dict['loop_feat'] = loop_feat.data.cpu().numpy()
    out_dict['attn'] = attn.squeeze().data.cpu().numpy()
    out_dict['output_file'] = output_file
    out_dict['valid_dataset_path'] = valid_dataset_path

    # Generate a .wav file from the loop output features.
    generate_merlin_wav(loop_feat.data.cpu().numpy(), output_dir, output_file,
                        norm_path)

    # Also render the original features for reference when an npz was given.
    if npz != '':
        output_orig_fname = os.path.basename(npz)[:-4] + '.orig'
        generate_merlin_wav(feat[:, 0, :].data.cpu().numpy(), output_dir,
                            output_orig_fname, norm_path)
        out_dict['output_orig_fname'] = output_orig_fname

    return out_dict
def main():
    """Synthesise a wav from an npz feature file or a raw text string with a
    trained Loop model (hard-coded phoneme dictionary)."""
    weights = torch.load(args.checkpoint,
                         map_location=lambda storage, loc: storage)
    opt = torch.load(os.path.dirname(args.checkpoint) + '/args.pth')
    train_args = opt[0]

    # Hard-coded phoneme -> code table matching the training vocabulary.
    char2code = {'aa': 0, 'ae': 1, 'ah': 2, 'ao': 3, 'aw': 4, 'ax': 5,
                 'ay': 6, 'b': 7, 'ch': 8, 'd': 9, 'dh': 10, 'eh': 11,
                 'er': 12, 'ey': 13, 'f': 14, 'g': 15, 'hh': 16, 'i': 17,
                 'ih': 18, 'iy': 19, 'jh': 20, 'k': 21, 'l': 22, 'm': 23,
                 'n': 24, 'ng': 25, 'ow': 26, 'oy': 27, 'p': 28, 'pau': 29,
                 'r': 30, 's': 31, 'sh': 32, 'ssil': 33, 't': 34, 'th': 35,
                 'uh': 36, 'uw': 37, 'v': 38, 'w': 39, 'y': 40, 'z': 41}
    nspkr = train_args.nspk

    # Locate the feature-normalisation file: next to the data, else next to
    # the checkpoint.
    norm_path = None
    if os.path.exists(train_args.data + '/norm_info/norm.dat'):
        norm_path = train_args.data + '/norm_info/norm.dat'
    elif os.path.exists(os.path.dirname(args.checkpoint) + '/norm.dat'):
        norm_path = os.path.dirname(args.checkpoint) + '/norm.dat'
    else:
        print('ERROR: Failed to find norm file.')
        return

    train_args.noise = 0  # disable input noise at inference time
    model = Loop(train_args)
    model.load_state_dict(weights)
    if args.gpu >= 0:
        model.cuda()
    model.eval()

    if args.spkr not in range(nspkr):
        print('ERROR: Unknown speaker id: %d.' % args.spkr)
        return

    txt, feat, spkr, output_fname = None, None, None, None
    # BUG FIX: the original used "is not ''" — an identity test that only
    # works by accident of CPython string interning; use != for equality.
    if args.npz != '':
        txt, feat = npy_loader_phonemes(args.npz)
        txt = Variable(txt.unsqueeze(1), volatile=True)
        feat = Variable(feat.unsqueeze(1), volatile=True)
        spkr = Variable(torch.LongTensor([args.spkr]), volatile=True)
        fname = os.path.basename(args.npz)[:-4]
        output_fname = fname + '.gen_' + str(args.spkr)
    elif args.text != '':
        txt = text2phone(args.text, char2code)
        # Feature budget of 20 frames per phone, 63 dims each.
        feat = torch.FloatTensor(txt.size(0) * 20, 63)
        spkr = torch.LongTensor([args.spkr])
        txt = Variable(txt.unsqueeze(1), volatile=True)
        feat = Variable(feat.unsqueeze(1), volatile=True)
        spkr = Variable(spkr, volatile=True)
        # slugify input string to file name
        fname = args.text.replace(' ', '_')
        valid_chars = "-_.() %s%s" % (string.ascii_letters, string.digits)
        fname = ''.join(c for c in fname if c in valid_chars)
        output_fname = fname + '.gen_' + str(args.spkr)
    else:
        print('ERROR: Must supply npz file path or text as source.')
        return

    if args.gpu >= 0:
        txt = txt.cuda()
        feat = feat.cuda()
        spkr = spkr.cuda()

    out, attn = model([txt, spkr], feat)
    out, attn = trim_pred(out, attn)

    output_dir = os.path.join(os.path.dirname(args.checkpoint), 'results')
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    generate_merlin_wav(out.data.cpu().numpy(), output_dir, output_fname,
                        norm_path)

    # Also render the original features for reference when an npz was given.
    if args.npz != '':
        output_orig_fname = os.path.basename(args.npz)[:-4] + '.orig'
        generate_merlin_wav(feat[:, 0, :].data.cpu().numpy(), output_dir,
                            output_orig_fname, norm_path)
def main():
    """Synthesise a wav from an npz feature file or a raw text string with a
    trained Loop model (phoneme dictionary taken from the training dataset)."""
    weights = torch.load(args.checkpoint,
                         map_location=lambda storage, loc: storage)
    opt = torch.load(os.path.dirname(args.checkpoint) + '/args.pth')
    train_args = opt[0]

    # Phoneme dictionary and speaker list come from the training dataset.
    train_dataset = NpzFolder(train_args.data + '/numpy_features')
    char2code = train_dataset.dict
    spkr2code = train_dataset.speakers
    norm_path = train_args.data + '/norm_info/norm.dat.npy'
    train_args.noise = 0  # disable input noise at inference time

    model = Loop(train_args)
    model.load_state_dict(weights)
    if args.gpu >= 0:
        model.cuda()
    model.eval()

    if args.spkr not in range(len(spkr2code)):
        print('ERROR: Unknown speaker id: %d.' % args.spkr)
        return

    txt, feat, spkr, output_fname = None, None, None, None
    # BUG FIX: the original used "is not ''" — an identity test that only
    # works by accident of CPython string interning; use != for equality.
    if args.npz != '':
        txt, feat = npy_loader_phonemes(args.npz)
        txt = Variable(txt.unsqueeze(1), volatile=True)
        feat = Variable(feat.unsqueeze(1), volatile=True)
        spkr = Variable(torch.LongTensor([args.spkr]), volatile=True)
        fname = os.path.basename(args.npz)[:-4]
        output_fname = fname + '.gen_' + str(args.spkr)
    elif args.text != '':
        txt = text2phone(args.text, char2code)
        # Generous feature-length budget for free text: 1500 frames, 67 dims.
        feat = torch.FloatTensor(1500, 67)
        spkr = torch.LongTensor([args.spkr])
        txt = Variable(txt.unsqueeze(1), volatile=True)
        feat = Variable(feat.unsqueeze(1), volatile=True)
        spkr = Variable(spkr, volatile=True)
        fname = args.text.replace(' ', '_')
        output_fname = fname + '.gen_' + str(args.spkr)
    else:
        print('ERROR: Must supply npz file path or text as source.')
        return

    if args.gpu >= 0:
        txt = txt.cuda()
        feat = feat.cuda()
        spkr = spkr.cuda()

    out, attn = model([txt, spkr], feat)
    out, attn = trim_pred(out, attn)

    output_dir = os.path.join(os.path.dirname(args.checkpoint), 'results')
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    generate_merlin_wav(out.data.cpu().numpy(), output_dir, output_fname,
                        norm_path)

    # Also render the original features for reference when an npz was given.
    if args.npz != '':
        output_orig_fname = os.path.basename(args.npz)[:-4] + '.orig'
        generate_merlin_wav(feat[:, 0, :].data.cpu().numpy(), output_dir,
                            output_orig_fname, norm_path)
def main():
    """Train the TensorFlow Loop model with TBPTT, validating every epoch
    and checkpointing the best and last models."""
    # load datasets
    train_dataset_path = os.path.join(args.data, 'numpy_features')
    train = NpzFolder(train_dataset_path)
    train.remove_too_long_seq(args.max_seq_len)
    train_loader = Dataset_Iter(train, batch_size=args.batch_size)
    train_loader.shuffle()

    valid_dataset_path = os.path.join(args.data, 'numpy_features_valid')
    valid = NpzFolder(valid_dataset_path)
    valid_loader = Dataset_Iter(valid, batch_size=args.batch_size)
    valid_loader.shuffle()

    # Graph inputs.
    input0 = tf.placeholder(tf.int64, [None, None])
    input1 = tf.placeholder(tf.float32, [None])   # contains length of sentence
    speaker = tf.placeholder(tf.int32, [None, 1]) # speaker identity
    target0 = tf.placeholder(tf.float32, [None, None, 63])
    target1 = tf.placeholder(tf.float32, [None])  # target lengths
    start = tf.placeholder(tf.bool, shape=(), name='start_new_batch')
    train_flag = tf.placeholder(tf.bool, shape=(), name='train_flag')

    model = Loop(args)

    # Define loss and optimizer.
    output, attns = model.forward(input0, speaker, target0, start, train_flag)
    loss_op = MaskedMSE(output, target0, target1)
    optimizer = tf.train.AdamOptimizer(learning_rate=args.lr)
    train_op, clip_flag = gradient_check_and_clip(loss_op, optimizer,
                                                  args.clip_grad,
                                                  args.ignore_grad)
    merged = tf.summary.merge_all()
    init = tf.global_variables_initializer()
    saver = tf.train.Saver(global_variable_list)

    load_model = not args.checkpoint == ''
    save_model = True
    best_eval = float('inf')
    sess_idx = 0
    train_losses = []
    valid_losses = []

    with tf.Session() as sess:
        train_writer = tf.summary.FileWriter(
            "%s/%s/train" % (args.outpath, expName), sess.graph)
        valid_writer = tf.summary.FileWriter(
            "%s/%s/valid" % (args.outpath, expName), sess.graph)

        sess.run(init)
        if load_model:
            saver.restore(sess, args.checkpoint)
            print("Model restored from file: %s" % args.checkpoint)

        for epoch in range(args.epochs):
            train_enum = tqdm(train_loader, desc='Train epoch %d' % epoch,
                              total=ceil_on_division(len(train_loader),
                                                     args.batch_size))

            # Train data
            for batch_ind in train_enum:
                batch_loss_list = []
                (srcBatch, srcLengths), (tgtBatch, tgtLengths), full_spkr = \
                    make_a_batch(train_loader.dataset, batch_ind)
                batch_iter = TBPTTIter((srcBatch, srcLengths),
                                       (tgtBatch, tgtLengths), full_spkr,
                                       args.seq_len)
                for (srcBatch, srcLenths), (tgtBatch, tgtLengths), spkr, start2 in batch_iter:
                    loss, _, clip_flag1, summary = sess.run(
                        [loss_op, train_op, clip_flag, merged],
                        feed_dict={
                            input0: srcBatch,
                            speaker: spkr,
                            target0: tgtBatch,
                            target1: tgtLengths,
                            start: start2,
                            train_flag: True
                        })
                    train_writer.add_summary(summary, sess_idx)
                    sess_idx += 1
                    if not clip_flag1:
                        batch_loss_list.append(loss)
                    else:
                        # if too many '-' appear, there are exploding gradients
                        print('-')
                train_losses.append(batch_loss_list)
                if len(batch_loss_list) != 0:
                    batch_loss = sum(batch_loss_list) / len(batch_loss_list)
                    # BUG FIX: the original appended the mean back into
                    # batch_loss_list, which is the same object just stored
                    # in train_losses — corrupting the recorded history.
                else:
                    batch_loss = -1.
                train_enum.set_description('Train (loss %.2f) epoch %d'
                                           % (batch_loss, epoch))
                train_enum.update(srcBatch.shape[0])

            # Validate data
            valid_enum = tqdm(valid_loader, desc='Validating epoch %d' % epoch,
                              total=ceil_on_division(len(valid_loader),
                                                     args.batch_size))
            batch_loss_list = []
            for batch_ind in valid_enum:
                (srcBatch, srcLengths), (tgtBatch, tgtLengths), full_spkr = \
                    make_a_batch(valid_loader.dataset, batch_ind)
                loss, summary = sess.run(
                    [loss_op, merged],
                    feed_dict={
                        input0: srcBatch,
                        speaker: full_spkr,
                        target0: tgtBatch,
                        target1: tgtLengths,
                        start: True,
                        train_flag: False
                    })
                batch_loss_list.append(loss)
                valid_writer.add_summary(summary, sess_idx)
                sess_idx += 1
                # BUG FIX: the original also called train_enum.set_description
                # here (copy-paste), updating the wrong progress bar.
                valid_enum.set_description('Validating (loss %.2f) epoch %d'
                                           % (loss, epoch))

            if len(batch_loss_list) != 0:
                valid_losses.append(batch_loss_list)
                valid_loss = sum(batch_loss_list) / len(batch_loss_list)
            else:
                valid_loss = 99999.

            if valid_loss < best_eval and save_model:
                best_eval = valid_loss
                save_path = saver.save(sess, "%s/bestmodel.ckpt" % args.expName)
                print("NEW BEST MODEL!, model saved in file: %s" % save_path)
            print('Final validation loss for epoch %d is: %.2f'
                  % (epoch, valid_loss))

            train_loader.shuffle()
            valid_loader.shuffle()

        if save_model:
            save_path = saver.save(sess, "%s/model.ckpt" % args.expName)
            print("Model saved in file: %s" % save_path)
        train_writer.close()
        valid_writer.close()
def train():
    """Train Loop on VCTK: checkpoint each epoch, roll back on NaN loss and
    validate every `valid_epoch` epochs."""
    hp = Hparams()
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = Loop(hp, device)

    # Resume from the last checkpoint if one exists.
    checkpoint_path = "checkpoints/last_model.pwf"
    if os.path.isfile(checkpoint_path):
        print("checkpoint found! loading checkpoint model...")
        model = load_from_checkpoint(model, checkpoint_path)
    else:
        print("no checkpoint found, training from scratch...")
    print("model has {} million parameters...".format(
        model.count_parameters()))

    # hyper-parameters
    optim = torch.optim.Adam(model.parameters(), lr=1e-4)
    epochs = 100
    batch_size = 25
    grad_norm = 0.5
    valid_epoch = 2

    # training data
    print('loading data...')
    train_data = VCTKDataSet("data/vctk/numpy_features/")
    val_data = VCTKDataSet("data/vctk/numpy_features_valid/")
    val_loader = DataLoader(val_data, batch_size=10, shuffle=False,
                            drop_last=False, collate_fn=my_collate_fn)

    print('initial validation...')
    validate(model, val_loader)

    # actual training loop:
    for ep in tqdm(range(epochs)):
        total_loss = 0
        loader = DataLoader(train_data, shuffle=True, drop_last=False,
                            batch_size=batch_size, collate_fn=my_collate_fn)
        for data in tqdm(loader):
            text, text_list, target, target_list, spkr = data
            loss = model.compute_loss_batch((text, text_list), spkr,
                                            (target, target_list),
                                            teacher_forcing=True)
            # update
            optim.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), grad_norm)
            optim.step()
            total_loss += float(loss.detach().cpu().numpy())

        # If the loss diverged to NaN, roll back to the last good checkpoint
        # and reset the optimizer state; otherwise checkpoint this epoch.
        # NOTE(review): the mangled original may have run this check per
        # batch rather than per epoch — confirm against version history.
        if math.isnan(total_loss):
            print('total loss is nan! loading from last checkpoint')
            model = load_from_checkpoint(model, checkpoint_path)
            optim = torch.optim.Adam(model.parameters(), lr=1e-4)
        else:
            print("loss is good, saving model...")
            torch.save(model.state_dict(), checkpoint_path)

        # print loss after every epoch
        print("epoch: {}, total loss: {}".format(ep, total_loss))

        if ep != 0 and ep % valid_epoch == 0:
            # BUG FIX: the original print string was broken across a mangled
            # line break; restored to a single message.
            print("validating model...")
            validate(model, val_loader)
            # save model after every validation
            torch.save(
                model.state_dict(),
                "checkpoints/saved_models/val_model_{0:03d}.pwf".format(ep))
def main():
    """Train a Loop model, checkpointing every epoch and logging train /
    validation losses through a TrainingMonitor."""
    start_epoch = 1
    model = Loop(args)
    model.cuda()

    # Optionally resume; the epoch counter lives at index 3 of the pickled
    # args list next to the checkpoint.
    if args.checkpoint != '':
        checkpoint_args_path = os.path.dirname(args.checkpoint) + '/args.pth'
        checkpoint_args = torch.load(checkpoint_args_path)
        start_epoch = checkpoint_args[3]
        model.load_state_dict(
            torch.load(args.checkpoint,
                       map_location=lambda storage, loc: storage))

    criterion = MaskedMSE().cuda()
    optimizer = optim.Adam(model.parameters(), lr=args.lr)

    # Keep track of losses
    train_losses = []
    eval_losses = []
    best_eval = float('inf')
    training_monitor = TrainingMonitor(file=args.expNameRaw,
                                       exp_name=args.expNameRaw,
                                       b_append=True,
                                       path='training_logs')

    # Begin!
    for epoch in range(start_epoch, start_epoch + args.epochs):
        # train model
        train(model, criterion, optimizer, epoch, train_losses)
        # evaluate on validation set
        eval_loss = evaluate(model, criterion, epoch, eval_losses)

        # Save a checkpoint for every epoch so evaluation metrics can be
        # computed across the whole training curve later on. (The original
        # also re-saved args.pth a second time in the same iteration with
        # identical contents; that redundant save is dropped.)
        torch.save(model.state_dict(),
                   '%s/epoch_%d.pth' % (args.expName, epoch))
        torch.save([args, train_losses, eval_losses, epoch],
                   '%s/args.pth' % (args.expName))

        # if this is the best model yet, save it as 'bestmodel'
        if eval_loss < best_eval:
            torch.save(model.state_dict(), '%s/bestmodel.pth' % (args.expName))
            best_eval = eval_loss

        # also keep a running copy of 'lastmodel'
        torch.save(model.state_dict(), '%s/lastmodel.pth' % (args.expName))

        # evaluate on a randomised subset of the training set, sized like
        # the validation set
        if epoch % args.eval_epochs == 0:
            train_eval_loader = ec.get_training_data_for_eval(
                data=args.data, len_valid=len(valid_loader.dataset))
            # BUG FIX: ('loss') is just the string 'loss'; a one-element
            # tuple needs a trailing comma.
            train_loss, _, _, _ = ec.evaluate(model=model,
                                              criterion=criterion,
                                              epoch=epoch,
                                              loader=train_eval_loader,
                                              metrics=('loss',))
        else:
            train_loss = None

        # store loss metrics
        training_monitor.insert(epoch=epoch, valid_loss=eval_loss,
                                train_loss=train_loss)
        training_monitor.write()