def generate_from_synth_model(
    model_path="data/model_files/handwriting_cond_best.pt",
    sentence_list=None,
    bias=3.0,
    device=torch.device("cpu"),
):
    """Generate handwriting strokes for each sentence with a trained
    conditional (synthesis) RNN and save one PNG per sentence.

    Args:
        model_path: path to a saved HandWritingSynthRNN state_dict.
        sentence_list: sentences to render; defaults to a small demo list.
        bias: sampling bias (higher -> cleaner, less varied strokes).
        device: torch device used for generation.
    """
    # Fix: avoid a mutable default argument (one shared list across calls).
    if sentence_list is None:
        sentence_list = [
            "hello world !!",
            "this text is generated using an RNN model",
            "Welcome to Lyrebird!",
        ]
    model = HandWritingSynthRNN()
    if os.path.exists(model_path):
        model.load_state_dict(torch.load(model_path, map_location=device))
    else:
        # Keep the original best-effort behavior, but make the fallback visible
        # instead of silently generating with random weights.
        print("Warning: {} not found; using untrained weights".format(model_path))
    # Close the encoder pickle deterministically.
    with open("data/one_hot_encoder.pkl", "rb") as f:
        oh_encoder = pickle.load(f)
    sentences = [s.to(device) for s in oh_encoder.one_hot(sentence_list)]
    generated_samples, attn_vars = model.generate(
        sentences=sentences, bias=bias, device=device, use_stopping=True
    )
    model_name = model_path.split("/")[-1].replace(".pt", "")
    for i, sentence in enumerate(sentence_list):
        plot_stroke(
            generated_samples[:, i, :].cpu().numpy(),
            save_name="samples/{}_{}.png".format(model_name, i),
        )
        print(f"generated strokes for: {sentence}")
def train_all_random_batch(rnn: GeneratorRNN, optimizer: torch.optim.Optimizer, data,
                           output_directory='./output', tail=True):
    """Train `rnn` indefinitely on randomly-chosen contiguous batches of `data`.

    Every 50 iterations, sample strokes at several bias values, save the
    plots, and checkpoint the model weights.  `tail` selects full-sequence
    training vs truncated training.  Runs until interrupted.
    """
    batch_size = 100
    i = 0  # global iteration counter
    model_dir = os.path.join(output_directory, 'models_batch_uncond')
    sample_dir = os.path.join(output_directory, 'batch_uncond')
    if not os.path.exists(model_dir):
        os.makedirs(model_dir)
    if not os.path.exists(sample_dir):
        os.makedirs(sample_dir)
    pbar = tqdm()  # no total: open-ended progress counter
    while True:
        # Random contiguous window of batch_size sequences.
        index = np.random.choice(range(len(data) - batch_size), size=1)[0]
        batched_data = data[index:(index + batch_size)]
        if i % 50 == 0:
            # Visualize samples at increasing bias (0 = most varied strokes).
            for b in [0, 0.1, 1, 5]:
                generated_strokes = generate_sequence(rnn, 700, bias=b)
                file_name = os.path.join(sample_dir, '%d-%s.png' % (i, b))
                plot_stroke(generated_strokes, file_name)
                tqdm.write('Writing file: %s' % file_name)
            model_file_name = os.path.join(model_dir, '%d.pt' % i)
            torch.save(rnn.state_dict(), model_file_name)
        i += 1
        if tail:
            train_full_batch(rnn, optimizer, batched_data)
        else:
            train_truncated_batch(rnn, optimizer, batched_data)
        pbar.update(1)
    return  # unreachable: the loop above only ends via an exception/interrupt
def generate(self, text, timesteps=1000, seed=None, filepath='samples/conditional.png'):
    """Generate a handwriting stroke sequence conditioned on `text`.

    Feeds one point at a time through the synthesis model, sampling the
    next point from the predicted mixture, and stops early once the
    attention window (phi) has passed the end of the sentence.

    Returns the sampled stroke array and saves a plot to `filepath`.
    """
    self.build_model(seq_length=1, max_sentence_length=len(text))
    # Slot 0 is a zero "start" point; the model fills slots 1..timesteps.
    sample = np.zeros((1, timesteps + 1, 3), dtype='float32')
    char_index, _ = char_to_index()
    one_hot_text = tf.expand_dims(one_hot_encode(text, self.num_characters, char_index), 0)
    text_length = tf.expand_dims(tf.constant(len(text)), 0)
    input_states = self.initial_states(1)
    for i in range(timesteps):
        outputs, input_states, phi = self.model([sample[:, i:i + 1, :], input_states,
                                                 one_hot_text, text_length])
        input_states[-1] = tf.reshape(input_states[-1], (1, self.num_characters))
        sample[0, i + 1] = self.sample(outputs, seed)
        # stopping heuristic: done when the attention weight on the final
        # character exceeds the weight on every earlier character
        finished = True
        phi_last = phi[0, 0, -1]
        for phi_u in phi[0, 0, :-1]:
            if phi_u.numpy() > phi_last.numpy():
                finished = False
                break
        # prevent early stopping before anything has been written
        if i < 100:
            finished = False
        if finished:
            break
    # remove first zeros and discard unused timesteps
    # NOTE(review): the slice ends at i, so the final sampled point is
    # dropped — confirm whether `1:i + 1` was intended.
    sample = sample[0, 1:i]
    plot_stroke(sample, save_name=filepath)
    return sample
def train_all(rnn: GeneratorRNN, optimizer: torch.optim.Optimizer, data):
    """Train `rnn` on `data` indefinitely, one stroke sequence at a time.

    Every 50 steps, sample handwriting at several bias values, save the
    plots, and checkpoint the model weights.  Runs until interrupted.
    """
    step = 0
    while True:
        for strokes in tqdm(data):
            if step % 50 == 0:
                # Periodically visualize samples at increasing bias values.
                for bias in [0, 0.1, 1, 5]:
                    sampled = generate_sequence(rnn, 700, bias=bias)
                    file_name = 'output/uncond/%d-%s.png' % (step, bias)
                    plot_stroke(sampled, file_name)
                    tqdm.write('Writing file: %s' % file_name)
                torch.save(rnn.state_dict(), "output/models/%d.pt" % step)
            step += 1
            train(rnn, optimizer, strokes)
    return
def generate(self, timesteps=400, seed=None, filepath='samples/unconditional.jpeg'):
    """Sample an unconditional stroke sequence of length `timesteps`.

    Autoregressively feeds each sampled point back into the model,
    saves a plot to `filepath`, and returns the sampled array.
    """
    self.build_model(seq_length=1)
    # Slot 0 is a zero "start" point; the model fills slots 1..timesteps.
    trajectory = np.zeros((1, timesteps + 1, 3), dtype='float32')
    states = [tf.zeros((1, self.num_cells))] * 2 * self.num_layers
    for t in range(timesteps):
        outputs, states = self.model([trajectory[:, t:t + 1, :], states])
        trajectory[0, t + 1] = self.sample(outputs, seed)
    # remove first zeros
    result = trajectory[0, 1:]
    plot_stroke(result, save_name=filepath)
    return result
def path_to_stroke(path_data, k=1, save_path="./mobile/"):
    """Convert svg path data into stroke data with offset coordinates.

    Each input point like "M12.5,30" becomes a row [pen, dx, -dy]; 'M'
    (move-to) marks the end of the previous pen-down segment.  The result
    is downsampled by `k` within each segment (segment end points are
    always kept) and saved as stroke.npy under `save_path`.

    args:
        path_data: list of svg path points
        k: downsample factor, default 1 means no downsampling
        save_path: directory path to save stroke.npy file
    """
    save_path = Path(save_path)
    stroke = np.zeros((len(path_data), 3))
    # Idiomatic for/enumerate instead of a manual while-loop.
    for i, point in enumerate(path_data):
        command = point[0]
        coord = point[1:].split(',')
        if command == 'M':
            stroke[i, 0] = 1.0
        elif command == 'L':
            stroke[i, 0] = 0.0
        # y axis is flipped: svg y grows downward, stroke format upward.
        stroke[i, 1] = float(coord[0])
        stroke[i, 2] = -float(coord[1])
    # First point starts a segment (pen down); last point always ends one.
    stroke[0, 0] = 0.0
    stroke[-1, 0] = 1.
    print("initial shape of data: ", stroke.shape)
    cuts = np.where(stroke[:, 0] == 1.)[0]
    print("EOS index:", cuts)
    start = 0
    down_sample_data = []
    for eos in cuts:
        # Downsample within the segment but always keep its end point.
        down_sample_data.append(stroke[start:eos:k])
        down_sample_data.append(stroke[eos])
        start = eos + 1
    down_sample_stroke = np.vstack(down_sample_data)
    # convert absolute coordinates into offsets between consecutive points
    down_sample_stroke[1:, 1:] = down_sample_stroke[1:, 1:] - down_sample_stroke[:-1, 1:]
    print("final shape of data: ", down_sample_stroke.shape)
    plot_stroke(down_sample_stroke, "img.png")
    # Fix: save stroke.npy inside the target directory as documented; the old
    # code passed the directory itself to np.save (producing e.g. "mobile.npy").
    np.save(save_path / "stroke.npy", down_sample_stroke, allow_pickle=True)
def main(config): logger = config.get_logger('experiments') # setup the device device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # setup data_loader instances data_loader = config.init_obj('data_loader', module_data) # build model architecture model = config.init_obj('arch', module_arch, char2idx=data_loader.dataset.char2idx, device=device) logger.info(model) # Loading the weights of the model logger.info('Loading checkpoint: {} ...'.format(config.resume)) checkpoint = torch.load(config.resume, map_location=device) state_dict = checkpoint['state_dict'] model.load_state_dict(state_dict) # prepare model for testing model = model.to(device) model.eval() with torch.no_grad(): if str(model).startswith('Unconditional'): sampled_stroke = model.generate_unconditional_sample() plot_stroke(sampled_stroke) elif str(model).startswith('Conditional'): sampled_stroke = model.generate_conditional_sample('hello world') plot_stroke(sampled_stroke) elif str(model).startswith('Seq2Seq'): sent, stroke = data_loader.dataset[21] predicted_seq = model.recognize_sample(stroke) print('real text: ', data_loader.dataset.tensor2sentence(sent)) print( 'predicted text: ', data_loader.dataset.tensor2sentence( torch.tensor(predicted_seq)))
def submit_style_data():
    """Flask endpoint: receive a drawn style (svg path string) plus its text,
    convert it to stroke data, and store both in a per-session upload
    directory for the generation step.  Returns a JSON redirect payload.
    """
    data = request.get_json()
    path = data["path"]
    text = data["text"]
    if path == "":
        # Nothing drawn: send the user back to the drawing page.
        return jsonify(
            {"redirect": url_for("draw"), "message": "Please enter some style"}
        )
    # One upload directory per browser session (renamed from `id`,
    # which shadowed the builtin).
    session_id = str(uuid.uuid4())
    session["id"] = session_id
    tmp_dir = os.path.join(flask_app.root_path, "static", "uploads", session_id)
    if not os.path.exists(tmp_dir):
        os.makedirs(tmp_dir)
        os.chmod(tmp_dir, 0o777)
    print(tmp_dir)
    # user agent info
    user_agent = request.user_agent
    print(user_agent.string)
    print(user_agent.platform)
    # NOTE(review): phone platforms skip downsampling — presumably their
    # touch input is already sparse; confirm against path_string_to_stroke.
    phones = ["android", "iphone"]
    down_sample = user_agent.platform not in phones
    text_path = os.path.join(tmp_dir, "inpText.txt")
    print(text_path)
    # `with` closes the file; the old explicit f.close() was redundant.
    with open(text_path, "w") as f:
        f.write(text)
    stroke = path_string_to_stroke(
        path, str_len=len(list(text)), down_sample=down_sample
    )
    save_path = os.path.join(tmp_dir, "style.npy")
    np.save(save_path, stroke, allow_pickle=True)
    print(save_path)
    # plot the sequence
    plot_stroke(stroke.astype(np.float32), os.path.join(tmp_dir, "original.png"))
    return jsonify({"redirect": url_for("generate"), "message": ""})
def generate_from_model(
    model_path="data/model_files/handwriting_uncond_best.pt",
    sample_length=600,
    num_sample=2,
    bias=0.5,
    device=torch.device("cpu"),
):
    """
    Generate num_sample number of samples each of length sample_length
    using a pretrained model
    """
    model = HandWritingRNN()
    model.load_state_dict(torch.load(model_path, map_location=device))
    samples = model.generate(
        device=device, length=sample_length, batch=num_sample, bias=bias
    )
    # Derive the output-file prefix from the checkpoint file name.
    model_name = model_path.split("/")[-1].replace(".pt", "")
    for idx in range(num_sample):
        save_name = "samples/{}_{}.png".format(model_name, idx)
        plot_stroke(samples[:, idx, :].cpu().numpy(), save_name=save_name)
def test_no_errors():
    """Smoke test: sampling from GeneratorRNN of various sizes runs end to end."""
    for size, kwargs in [(1, {}), (20, {}), (20, {'bias': 10})]:
        rnn = GeneratorRNN(size)
        strokes = generate_sequence(rnn, 10, **kwargs)
        plot_stroke(strokes, 'strokes.png')
def test_conditioned_no_errors():
    """Smoke test: conditioned sampling runs for several model sizes."""
    d = {'a': 0, 'b': 1, 'c': 2}
    sentence_vec = sentence_to_vectors('abcabc', d)
    sentence_vec = Variable(torch.from_numpy(sentence_vec).float(),
                            requires_grad=False)
    for size, kwargs in [(1, {}), (20, {}), (20, {'bias': 10})]:
        rnn = ConditionedRNN(size, num_chars=3)
        strokes = generate_conditioned_sequence(rnn, 10, sentence_vec, **kwargs)
        plot_stroke(strokes, 'strokes.png')
def generate_handwriting(
    char_seq="hello world",
    real_text="",
    style_path="../app/static/mobile/style.npy",
    save_path="",
    app_path="",
    n_samples=1,
    bias=10.0,
):
    """Generate handwriting images for `char_seq`, primed with a user style.

    Loads the best synthesis checkpoint, primes it with the stroke sequence
    at `style_path` (whose transcription is `real_text`), then samples
    `n_samples` conditional sequences and saves each one as a PNG under
    `save_path`.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    data_path = os.path.join(app_path, "../data/")
    model_path = os.path.join(app_path, "../results/synthesis/best_model_synthesis.pt")
    # seed = 194
    # print("seed:", seed)
    # torch.manual_seed(seed)
    # np.random.seed(seed)
    # print(np.random.get_state())
    # Dataset is only needed here for its char <-> id vocabulary mappings.
    train_dataset = HandwritingDataset(data_path, split="train", text_req=True)
    prime = True
    is_map = False  # set True to visualize the attention map (phi)
    style = np.load(style_path, allow_pickle=True, encoding="bytes").astype(np.float32)
    # plot the sequence
    # plot_stroke(style, os.path.join(save_path, "original.png"))
    print("Priming text: ", real_text)
    # Normalize the priming strokes with their own mean/std.
    mean, std, style = data_normalization(style)
    style = torch.from_numpy(style).unsqueeze(0).to(device)
    # style = valid_offset_normalization(Global.train_mean, Global.train_std, style[None,:,:])
    # style = torch.from_numpy(style).to(device)
    print("Priming sequence size: ", style.shape)
    ytext = real_text + " " + char_seq + " "  # labels for the attention plot
    for i in range(n_samples):
        gen_seq, phi = generate_conditional_sequence(
            model_path,
            char_seq,
            device,
            train_dataset.char_to_id,
            train_dataset.idx_to_char,
            bias,
            prime,
            style,
            real_text,
            is_map,
        )
        if is_map:
            # Heat map of attention weights: time steps vs characters.
            plt.imshow(phi, cmap="viridis", aspect="auto")
            plt.colorbar()
            plt.xlabel("time steps")
            plt.yticks(np.arange(phi.shape[0]), list(ytext), rotation="horizontal")
            plt.margins(0.2)
            plt.subplots_adjust(bottom=0.15)
            plt.show()
        # denormalize the generated offsets using train set mean and std
        print("data denormalization...")
        end = style.shape[1]  # length of the priming prefix in the output
        # gen_seq[:,:end] = data_denormalization(mean, std, gen_seq[:, :end])
        # gen_seq[:,end:] = data_denormalization(Global.train_mean, Global.train_std, gen_seq[:,end:])
        gen_seq = data_denormalization(Global.train_mean, Global.train_std, gen_seq)
        # plot the sequence
        print(gen_seq.shape)
        # plot_stroke(gen_seq[0, :end])
        plot_stroke(gen_seq[0],
                    os.path.join(save_path, "gen_stroke_" + str(i) + ".png"))
        print(save_path)
# (fragment of a CLI generation script: `args`, `model_path`, `device`,
# `train_dataset`, `prime`, `style` and `real_text` are defined above)
gen_seq = generate_conditional_sequence(model_path, args.char_seq, device,
                                        train_dataset.char_to_id,
                                        train_dataset.idx_to_char, args.bias,
                                        prime, style, real_text)
# Denormalize the generated offsets using train-set mean and std.
gen_seq = data_denormalization(
    Global.train_mean, Global.train_std, gen_seq)
gen_seq = np.squeeze(gen_seq)

# plot the sequence
if args.save_img:
    img_path = os.path.join(str(args.save_path),
                            "gen_img" + str(time.time()) + ".png")
    plot_stroke(gen_seq, save_name=img_path)
    print(f"Image saved as: {img_path}")

if args.save_gif:
    gif_path = os.path.join(str(args.save_path),
                            "gen_img" + str(time.time()) + ".gif")
    plot_stroke_gif(gen_seq, save_name=gif_path)
    print(f"GIF saved as: {gif_path}")

# Export generated sequence as json
seq_list = gen_seq.tolist()
json_file_path = os.path.join(str(args.save_path), "generated_seq.json")
json.dump(seq_list, codecs.open(json_file_path, 'w', encoding='utf-8'),
          separators=(',', ':'), sort_keys=True, indent=4)
print(f"Sequence saved to json: {json_file_path}")
def train(device, args, data_path="data/"):
    """Train the unconditional (args.uncond) or conditional handwriting model.

    Loads strokes + sentences, builds datasets/dataloaders, trains with
    Adam + ReduceLROnPlateau, logs to TensorBoard, checkpoints both the
    best-so-far and per-epoch models, and renders sample generations at
    the end of every epoch.
    """
    random_seed = 42  # NOTE(review): defined but not applied anywhere here
    writer = SummaryWriter(log_dir=args.logdir, comment="")
    model_path = args.logdir + ("/unconditional_models/" if args.uncond
                                else "/conditional_models/")
    os.makedirs(model_path, exist_ok=True)
    strokes = np.load(data_path + "strokes.npy", encoding="latin1")
    sentences = ""
    with open(data_path + "sentences.txt") as f:
        sentences = f.readlines()
    sentences = [snt.replace("\n", "") for snt in sentences]
    # Instead of removing the newline symbols, should it be used instead?
    MAX_STROKE_LEN = 800
    strokes, sentences, MAX_SENTENCE_LEN = filter_long_strokes(
        strokes, sentences, MAX_STROKE_LEN, max_index=args.n_data)
    # print("Max sentence len after filter is: {}".format(MAX_SENTENCE_LEN))
    # dimension of one-hot representation
    N_CHAR = 57
    oh_encoder = OneHotEncoder(sentences, n_char=N_CHAR)
    # Persist the encoder so generation scripts reuse the same vocabulary.
    pickle.dump(oh_encoder, open("data/one_hot_encoder.pkl", "wb"))
    sentences_oh = [s.to(device) for s in oh_encoder.one_hot(sentences)]
    # normalize strokes data and convert to pytorch tensors
    strokes = normalize_data(strokes)
    # plot_stroke(strokes[1])
    tstrokes = [torch.from_numpy(stroke).to(device) for stroke in strokes]
    # pytorch dataset
    dataset = HandWritingData(sentences_oh, tstrokes)
    # validating the padding lengths
    assert dataset.strokes_padded_len <= MAX_STROKE_LEN
    assert dataset.sentences_padded_len == MAX_SENTENCE_LEN
    # train - validation split
    train_split = 0.95
    train_size = int(train_split * len(dataset))
    validn_size = len(dataset) - train_size
    dataset_train, dataset_validn = torch.utils.data.random_split(
        dataset, [train_size, validn_size])
    dataloader_train = DataLoader(
        dataset_train,
        batch_size=args.batch_size,
        shuffle=True,
        drop_last=False)  # last batch may be smaller than batch_size
    dataloader_validn = DataLoader(dataset_validn,
                                   batch_size=args.batch_size,
                                   shuffle=True,
                                   drop_last=False)
    common_model_structure = {
        "memory_cells": 400,
        "n_gaussians": 20,
        "num_layers": 2
    }
    model = (HandWritingRNN(**common_model_structure).to(device)
             if args.uncond else
             HandWritingSynthRNN(
                 n_char=N_CHAR,
                 n_gaussians_window=10,
                 kappa_factor=0.05,
                 **common_model_structure,
             ).to(device))
    print(model)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=0)
    # optimizer = torch.optim.RMSprop(model.parameters(), lr=1e-2,
    #                                 weight_decay=0, momentum=0)
    if args.resume is None:
        model.init_params()
    else:
        model.load_state_dict(torch.load(args.resume, map_location=device))
        print("Resuming trainig on {}".format(args.resume))
        # resume_optim_file = args.resume.split(".pt")[0] + "_optim.pt"
        # if os.path.exists(resume_optim_file):
        #     optimizer = torch.load(resume_optim_file, map_location=device)
    scheduler = ReduceLROnPlateau(optimizer, mode="min", factor=0.1**0.5,
                                  patience=10, verbose=True)
    best_batch_loss = 1e7
    for epoch in range(200):
        train_losses = []
        validation_iters = []
        validation_losses = []
        for i, (c_seq, x, masks, c_masks) in enumerate(dataloader_train):
            # make batch_first = false
            x = x.permute(1, 0, 2)
            masks = masks.permute(1, 0)
            # remove last point (prepending a dummy point (zeros) already done in data)
            inp_x = x[:-1]  # shape : (T, B, 3)
            masks = masks[:-1]  # shape: (B, T)
            # c_seq.shape: (B, MAX_SENTENCE_LEN, n_char), c_masks.shape: (B, MAX_SENTENCE_LEN)
            inputs = (inp_x, c_seq, c_masks)
            if args.uncond:
                inputs = (inp_x, )
            e, log_pi, mu, sigma, rho, *_ = model(*inputs)
            # remove first point from x to make it y
            loss = criterion(x[1:], e, log_pi, mu, sigma, rho, masks)
            train_losses.append(loss.detach().cpu().numpy())
            optimizer.zero_grad()
            loss.backward()
            # --- this may not be needed
            torch.nn.utils.clip_grad_norm_(model.parameters(), 5)
            optimizer.step()
            # do logging
            print("{},\t".format(loss))
            if i % 10 == 0:
                writer.add_scalar("Every_10th_batch_loss", loss,
                                  epoch * len(dataloader_train) + i)
            # save as best model if loss is better than previous best
            if loss < best_batch_loss:
                best_batch_loss = loss
                model_file = (
                    model_path +
                    f"handwriting_{('un' if args.uncond else '')}cond_best.pt")
                torch.save(model.state_dict(), model_file)
                optim_file = model_file.split(".pt")[0] + "_optim.pt"
                torch.save(optimizer, optim_file)
        epoch_avg_loss = np.array(train_losses).mean()
        scheduler.step(epoch_avg_loss)
        # ======================== do the per-epoch logging ========================
        writer.add_scalar("Avg_loss_for_epoch", epoch_avg_loss, epoch)
        print(f"Average training-loss for epoch {epoch} is: {epoch_avg_loss}")
        model_file = (
            model_path +
            f"handwriting_{('un' if args.uncond else '')}cond_ep{epoch}.pt")
        torch.save(model.state_dict(), model_file)
        optim_file = model_file.split(".pt")[0] + "_optim.pt"
        torch.save(optimizer, optim_file)
        # generate samples from model
        sample_count = 3
        sentences = ["welcome to lyrebird"
                     ] + ["abcd efgh vicki"] * (sample_count - 1)
        sentences = [s.to(device) for s in oh_encoder.one_hot(sentences)]
        if args.uncond:
            generated_samples = model.generate(600, batch=sample_count, device=device)
        else:
            generated_samples, attn_vars = model.generate(sentences, device=device)
        figs = []
        # save png files of the generated models
        for i in range(sample_count):
            f = plot_stroke(
                generated_samples[:, i, :].cpu().numpy(),
                save_name=args.logdir + "/training_imgs/{}cond_ep{}_{}.png".format(
                    ("un" if args.uncond else ""), epoch, i),
            )
            figs.append(f)
        for i, f in enumerate(figs):
            writer.add_figure(f"samples/image_{i}", f, epoch)
        if not args.uncond:
            # Attention diagnostics for the conditional model.
            figs_phi = plot_phi(attn_vars["phi_list"])
            figs_kappa = plot_attn_scalar(attn_vars["kappa_list"])
            for i, (f_phi, f_kappa) in enumerate(zip(figs_phi, figs_kappa)):
                writer.add_figure(f"attention/phi_{i}", f_phi, epoch)
                writer.add_figure(f"attention/kappa_{i}", f_kappa, epoch)
def save_stroke(new_stroke, name):
    """Squeeze a generated stroke tensor to a numpy array and plot/save it."""
    print()  # blank line to separate surrounding console output
    arr = np.array(new_stroke.squeeze().data)
    lb_uts.plot_stroke(arr, name)
# (fragment: tail of a training script — the enclosing `try:`, epoch loop `e`
# and inner batch loop are outside this chunk; nesting reconstructed)
            nan = True  # abort flag: loss went non-finite
            print('exiting train @epoch : {}'.format(e))
            break
        mean_loss = np.mean(train_loss)
        print("Epoch {:03d}: Loss: {:.3f}".format(e, mean_loss))
        if e % 1 == 0:
            # checkpoint weights every epoch
            hws.model.save_weights(EPOCH_MODEL_PATH.format(e))
        if nan:
            break
except KeyboardInterrupt:
    pass

if not nan:
    # Only keep the final weights if training finished cleanly.
    hws.model.save_weights(MODEL_PATH)

# ''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''
# '''''''''''''''''''''''''''''''EVALUATE'''''''''''''''''''''''''''''''
# ''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''
verbose_sentence = "".join(D.encoder.inverse_transform(sentence)[0])
strokes1, _, _, _ = hws.infer(sentence, inf_type='max',
                              verbose=verbose_sentence)
plot_stroke(strokes1)
import ipdb
ipdb.set_trace()
def main():
    """CLI entry point: train a stroke prediction or synthesis model.

    Parses arguments, loads strokes + transcriptions, builds datasets, and
    trains with AdamW, logging train/valid losses to TensorBoard and saving
    a checkpoint plus a generated sample image after every epoch.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--strokes_path", type=str, default="../data/strokes.npy")
    parser.add_argument("--texts_path", type=str, default="../data/sentences.txt")
    parser.add_argument("--batch_size", type=int, default=128)
    parser.add_argument("--num_epochs", type=int, default=100)
    parser.add_argument("--hidden_size", type=int, default=200)
    parser.add_argument("--lr", type=float, default=1e-3)
    parser.add_argument("--model_dir", type=str, required=True)
    parser.add_argument("--grad_norm", type=float, default=10.0)
    parser.add_argument("--model_type", choices=["prediction", "synthesis"],
                        required=True)
    parser.add_argument('--disable-cuda', action='store_true', help='Disable CUDA')
    args = parser.parse_args()
    args.device = None
    if not args.disable_cuda and torch.cuda.is_available():
        args.device = torch.device('cuda')
    else:
        args.device = torch.device('cpu')
    print("Device: {}".format(args.device))
    strokes = np.load(args.strokes_path, encoding="latin1", allow_pickle=True)
    with open(args.texts_path) as fin:
        texts = list(map(lambda x: x.strip(), fin))
    # Mean characters-per-stroke-point ratio, used to scale the attention
    # window's advance speed.
    attention_scale = 1. / np.mean([
        len(stroke) / len(sentence)
        for stroke, sentence in zip(strokes, texts)
    ])
    print("Attention scale: {}".format(attention_scale))
    # Build the character vocabulary from the corpus.
    alphabet = {}
    for t in texts:
        for c in t:
            if c not in alphabet:
                alphabet[c] = len(alphabet)
    inv_alphabet = {y: x for x, y in alphabet.items()}
    test_text = "Welcome to lyrebird!"
    train_strokes, valid_strokes, train_texts, valid_texts = train_test_split(
        strokes, texts, test_size=0.1)
    no_text = args.model_type == "prediction"
    train_dataset = StrokesDataset(train_strokes, train_texts, alphabet, no_text)
    valid_dataset = StrokesDataset(valid_strokes, valid_texts, alphabet, no_text)
    train_dataloader = DataLoader(train_dataset, batch_size=args.batch_size,
                                  shuffle=True, num_workers=16)
    valid_dataloader = DataLoader(valid_dataset, batch_size=args.batch_size,
                                  shuffle=False, num_workers=16)
    # Instantiate the model and smoke-test sampling before training starts.
    if args.model_type == "prediction":
        model = StrokesPrediction(hidden_size=args.hidden_size)
        model.to(args.device)
        model.sample(device=args.device)
    elif args.model_type == "synthesis":
        model = StrokesSynthesis(num_letters=len(alphabet),
                                 attention_scale=attention_scale,
                                 hidden_size=args.hidden_size)
        model.to(args.device)
        model.sample(torch.LongTensor([alphabet[x] for x in test_text
                                       ])[None, :].to(args.device),
                     device=args.device)
    else:
        raise ValueError("unknown model type")
    optmizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=1e-6)
    writer = SummaryWriter(args.model_dir)
    print("Starting to train...")
    global_step = 0
    for i in range(args.num_epochs):
        print("Epoch {} / {}".format(i + 1, args.num_epochs))
        for batch_X in tqdm.tqdm(train_dataloader):
            copy_to_device(batch_X, args.device)
            # convert to time-major layout
            batch_X["strokes_inputs"] = batch_X["strokes_inputs"].permute(
                (1, 0, 2))
            batch_X["strokes_targets"] = batch_X["strokes_targets"].permute(
                (1, 0, 2))
            mixture_loss, end_loss = model.loss(**batch_X, device=args.device)
            loss = mixture_loss + end_loss
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_norm)
            optmizer.step()
            optmizer.zero_grad()
            writer.add_scalar("Mixture loss/train",
                              mixture_loss.cpu().detach().numpy(),
                              global_step=global_step)
            writer.add_scalar("End loss/train",
                              end_loss.cpu().detach().numpy(),
                              global_step=global_step)
            loss = loss.cpu().detach().numpy()
            print("Loss: {}".format(loss))
            writer.add_scalar("Loss/train", loss, global_step=global_step)
            global_step += 1
        model_path = os.path.join(args.model_dir, "epoch_{}.pt".format(i))
        save_model(model, model_path)
        # -------- validation --------
        valid_losses = []
        valid_end_losses = []
        valid_mixture_losses = []
        for batch_X in tqdm.tqdm(valid_dataloader):
            copy_to_device(batch_X, args.device)
            batch_X["strokes_inputs"] = batch_X["strokes_inputs"].permute(
                (1, 0, 2))
            batch_X["strokes_targets"] = batch_X["strokes_targets"].permute(
                (1, 0, 2))
            mixture_loss, end_loss = model.loss(**batch_X, device=args.device)
            loss = mixture_loss + end_loss
            valid_losses.append(loss.cpu().detach().numpy())
            valid_end_losses.append(end_loss.cpu().detach().numpy())
            valid_mixture_losses.append(mixture_loss.cpu().detach().numpy())
        writer.add_scalar("Mixture loss/valid", np.mean(valid_mixture_losses),
                          global_step=i)
        writer.add_scalar("End loss/valid", np.mean(valid_end_losses),
                          global_step=i)
        writer.add_scalar("Loss/valid", np.mean(valid_losses), global_step=i)
        # -------- per-epoch sample generation --------
        if args.model_type == "synthesis":
            h_0, h_1, h_2, prev_w_t, prev_k = model.init_states(
                1, device=args.device)
            strokes_inputs = torch.FloatTensor(
                valid_dataset[0]["strokes_inputs"])[:, None, :].to(args.device)
            text_inputs = torch.LongTensor(
                valid_dataset[0]["text"])[None, :].to(args.device)
            text_lengths = torch.LongTensor([valid_dataset[0]["text_lengths"]
                                             ]).to(args.device)
            mixture_params, end_of_stroke_logits, h_0, h_1, h_2, prev_w_t, prev_k, alignments = model(
                strokes_inputs, text_inputs, text_lengths, h_0, h_1, h_2,
                prev_w_t, prev_k)
            # Attention-alignment image for the first validation sample.
            first_alignment = alignments[:valid_dataset[0]["strokes_lengths"],
                                         0, :text_lengths[0]]
            first_alignment = first_alignment.detach().cpu().numpy().T
            writer.add_image("alignment", first_alignment[None, :])
            text = "".join([
                inv_alphabet[int(x)] for x in valid_dataset[0]["text"]
            ]).strip()
            sample = model.sample(text_inputs, device=args.device)
            img_dir = os.path.join(args.model_dir, "imgs")
            os.makedirs(img_dir, exist_ok=True)
            print(text)
            plot_stroke(sample,
                        save_name=os.path.join(img_dir, "epoch_{}".format(i)))
        else:
            sample = model.sample(device=args.device)
            img_dir = os.path.join(args.model_dir, "imgs")
            os.makedirs(img_dir, exist_ok=True)
            plot_stroke(sample,
                        save_name=os.path.join(img_dir, "epoch_{}".format(i)))
# Demo script: load the app's bundled default handwriting style (an offset
# stroke array) and render it to default.png.
import numpy as np
import sys

sys.path.append("../")  # make the repo-root utils module importable
from utils import plot_stroke

style = np.load(
    "./static/uploads/default_style.npy", allow_pickle=True, encoding="bytes"
).astype(np.float32)

# plot the sequence
plot_stroke(style, "default.png")
def train(args):
    """Train the conditional handwriting-synthesis model.

    Loads the dataset and char->int vocabulary, trains with Adam plus
    per-parameter gradient clamping, and every `print_freq` batches logs
    progress and renders a sample for a fixed test sentence.  Saves the
    weights to ./checkpoint/synthesis_model.
    """
    dataset = Dataset(args.data_path, args.batch_size)
    data_loader = DataLoader(dataset, batch_size=1, shuffle=True,
                             collate_fn=_collate_fn)
    outputlayer_name = ['e', 'pi', 'mu1', 'mu2', 'sig1', 'sig2', 'ro']  # for gradient cliping
    pkl_file = open('char2int.pkl', 'rb')
    char2int = pickle.load(pkl_file)
    pkl_file.close()
    # Fixed sentence used for the periodic sample generations.
    test_char = "welcome to lyrebird"
    char2array = torch.from_numpy(np.array([char2int[x] for x in test_char])).long().cuda()
    char2array = char2array.unsqueeze(0)
    epochs = args.num_epoch
    lr = args.lr
    max_len = 600  # maximum generated sequence length
    use_cuda = torch.cuda.is_available()
    print_freq = 10
    if use_cuda:
        Model = ConditionalModel(use_cuda=use_cuda).cuda()
    else:
        Model = ConditionalModel()
    #train
    optimizer = torch.optim.Adam(Model.parameters(), lr=lr)
    plot_loss = []
    for epoch in range(epochs):
        #batch_loss = 0
        total_loss = 0
        start = time.time()
        for i, (data) in enumerate(data_loader):
            ys, char, ys_mask, char_mask = data
            if use_cuda:
                ys = ys.cuda()
                char = char.long().cuda()
                ys_mask = torch.from_numpy(ys_mask).long().cuda()
                char_mask = torch.from_numpy(char_mask).long().cuda()
            #print(ys)
            prev_state, prev_offset, prev_w = None, None, None
            # Teacher forcing: inputs are points [:-1], targets are points [1:].
            e, pi, mu1, mu2, sig1, sig2, ro, _, _, _, _ = Model(
                ys.permute(1, 0, 2)[:-1], char, char_mask,
                prev_state, prev_offset, prev_w)
            loss = Model.prediction_loss(e, pi, mu1, mu2, sig1, sig2, ro,
                                         ys.permute(1, 0, 2)[1:])
            optimizer.zero_grad()
            loss.backward()
            # Clamp LSTM gradients more tightly than output-layer gradients.
            for name, param in Model.named_parameters():
                if 'lstm' in name:
                    param.grad.data.clamp_(-10, 10)
                else:
                    param.grad.data.clamp_(-100, 100)
            optimizer.step()
            plot_loss.append(loss.item())
            total_loss += loss.item()
            #print(data)
            if i % print_freq == 0:
                stroke = torch.tensor([1, 0, 0])  # initial pen-down point
                print('Epoch {0} | Iter {1} | Average Loss {2:.3f} | '
                      'Current Loss {3:.6f} | {4:.1f} ms/batch'.format(
                          epoch + 1, i + 1, total_loss / (i + 1), loss.item(),
                          1000 * (time.time() - start) / (i + 1)), flush=True)
                prediction = Model.generate_samples(stroke, char2array, max_len)
                prediction = prediction.squeeze(0).cpu().numpy()
                print(prediction.shape)
                plot_stroke(prediction, 'generation_fix.png')
        # NOTE(review): checkpoint placement reconstructed from a flattened
        # source line — confirm whether this save was per-epoch or only once
        # after all epochs.
        torch.save(Model.state_dict(), "./checkpoint/synthesis_model")
# Train/Generate if args.train: optimizer = create_optimizer(rnn) if args.unconditioned: if args.batch: unconditioned.train_all_random_batch(rnn, optimizer, normalized_data) else: unconditioned.train_all(rnn, optimizer, normalized_data) elif args.conditioned: if args.batch: conditioned.train_all_random_batch(rnn, optimizer, normalized_data) else: conditioned.train_all(rnn, optimizer, normalized_data) elif args.generate: if args.unconditioned: print("Generating a random handwriting sample.") strokes = generator.generate_sequence(rnn, 700, 1) elif args.conditioned: target_sentence = args.output_text print("Generating handwriting for text: %s" % target_sentence) target_sentence = Variable(torch.from_numpy( sentence_to_vectors(target_sentence, alphabet_dict)).float().cuda(), requires_grad=False) strokes = generator.generate_conditioned_sequence( rnn, 2000, target_sentence, 3) plot_stroke(strokes, "output.png")
def train(model, train_loader, valid_loader, batch_size, n_epochs, lr, patience,
          step_size, device, model_type, save_path):
    """Train a prediction or synthesis model with early stopping.

    Resumes from an existing checkpoint when one is found.  On every
    validation improvement, saves the weights and renders a sample
    generation; stops once more than `patience` epochs pass without
    improvement.
    """
    model_path = save_path + "model_" + model_type + ".pt"
    model = model.to(device)
    # Resume from an existing checkpoint when available.
    if os.path.isfile(model_path):
        model.load_state_dict(torch.load(model_path))
        print(f"[ACTION] Loaded model weights from '{model_path}'")
    else:
        print("[INFO] No saved weights found, training from scratch.")
    optimizer = optim.Adam(model.parameters(), lr=lr)
    scheduler = StepLR(optimizer, step_size=step_size, gamma=0.1)
    train_losses = []
    valid_losses = []
    best_loss = math.inf
    best_epoch = 0
    k = 0  # epochs since the last validation improvement
    for epoch in range(n_epochs):
        start_time = time.time()
        print(f"[Epoch {epoch + 1}/{n_epochs}]")
        print("[INFO] Training Model.....")
        train_loss = train_epoch(model, optimizer, epoch, train_loader,
                                 device, model_type)
        print("[INFO] Validating Model....")
        valid_loss = validation(model, valid_loader, device, epoch, model_type)
        train_losses.append(train_loss)
        valid_losses.append(valid_loss)
        print(f"[RESULT] Epoch {epoch + 1}/{n_epochs}"
              f"\tTrain loss: {train_loss:.3f}\tVal loss: {valid_loss:.3f}")
        # step_size == -1 disables LR decay
        if step_size != -1:
            scheduler.step()
        if valid_loss < best_loss:
            best_loss = valid_loss
            best_epoch = epoch + 1
            print("[SAVE] Saving weights at epoch: {}".format(epoch + 1))
            torch.save(model.state_dict(), model_path)
            # Render a sample from the freshly-saved best checkpoint.
            if model_type == "prediction":
                gen_seq = generate_unconditional_seq(model_path, 700, device,
                                                     bias=10.0, style=None,
                                                     prime=False)
            else:
                gen_seq = generate_conditional_sequence(
                    model_path, "Hello world!", device,
                    train_loader.dataset.char_to_id,
                    train_loader.dataset.idx_to_char, bias=10.0,
                    prime=False, prime_seq=None, real_text=None)
            # denormalize the generated offsets using train set mean and std
            gen_seq = data_denormalization(Global.train_mean, Global.train_std,
                                           gen_seq)
            # plot the sequence
            plot_stroke(
                gen_seq[0],
                save_name=save_path + model_type + "_seq_" + str(best_epoch) + ".png",
            )
            k = 0
        elif k > patience:
            print("Best model was saved at epoch: {}".format(best_epoch))
            print("Early stopping at epoch {}".format(epoch))
            break
        else:
            k += 1
        total_time_taken = time.time() - start_time
        print('Time taken per epoch: {:.2f}s\n'.format(total_time_taken))
from argparse import ArgumentParser

if __name__ == '__main__':
    # CLI: render handwriting for a piece of text with a trained model.
    parser = ArgumentParser()
    parser.add_argument('text')
    parser.add_argument('-m', '--model', dest='model', default='conditional')
    parser.add_argument('-e', '--epoch', dest='epoch', type=int, required=True)
    parser.add_argument('-b', '--sample-bias', dest='sample_bias',
                        type=float, default=0.0)
    parser.add_argument('-l', '--stroke-length', dest='stroke_length',
                        type=int, default=None)
    parser.add_argument('-o', '--output', dest='output', default=None)
    args = parser.parse_args()

    # Fallback stroke length: ~30 points per non-space character.
    fallback_length = len(args.text.replace(' ', '')) * 30
    stroke = generate_conditionally(
        text=args.text,
        model=args.model,
        epoch=args.epoch,
        sample_bias=args.sample_bias,
        stroke_length=args.stroke_length or fallback_length)
    print('Stroke length: {}'.format(len(stroke)))
    plot_stroke(stroke, save_name=args.output)
def train(model, train_loader, valid_loader, batch_size, n_epochs, lr, patience,
          step_size, device, model_type, save_path):
    """Train a prediction or synthesis model, keeping the best checkpoint.

    On each validation improvement: save weights, render a sample, and
    (for synthesis models) save an attention heat map.  Early-stops after
    more than `patience` epochs without improvement.
    """
    model_path = save_path + "best_model_" + model_type + ".pt"
    model = model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    scheduler = StepLR(optimizer, step_size=step_size, gamma=0.1)
    train_losses = []
    valid_losses = []
    best_loss = math.inf
    best_epoch = 0
    k = 0  # epochs since the last validation improvement
    for epoch in range(n_epochs):
        print("training.....")
        train_loss = train_epoch(model, optimizer, epoch, train_loader,
                                 device, model_type)
        print("validation....")
        valid_loss = validation(model, valid_loader, device, epoch, model_type)
        train_losses.append(train_loss)
        valid_losses.append(valid_loss)
        print('Epoch {}: Train: avg. loss: {:.3f}'.format(
            epoch + 1, train_loss))
        print('Epoch {}: Valid: avg. loss: {:.3f}'.format(
            epoch + 1, valid_loss))
        # step_size == -1 disables LR decay
        if step_size != -1:
            scheduler.step()
        if valid_loss < best_loss:
            best_loss = valid_loss
            best_epoch = epoch + 1
            print('Saving best model at epoch {}'.format(epoch + 1))
            torch.save(model.state_dict(), model_path)
            if model_type == "prediction":
                gen_seq = generate_unconditional_seq(model_path, 700, device,
                                                     bias=10.0, style=None,
                                                     prime=False)
            else:
                gen_seq, phi = generate_conditional_sequence(
                    model_path, "Hello world!", device,
                    train_loader.dataset.char_to_id,
                    train_loader.dataset.idx_to_char, bias=10.,
                    prime=False, prime_seq=None, real_text=None)
                # Attention heat map: time steps vs characters.
                plt.imshow(phi, cmap='viridis', aspect='auto')
                plt.colorbar()
                plt.xlabel("time steps")
                # NOTE(review): the label string was reconstructed from a
                # garbled source line — confirm the exact trailing whitespace.
                plt.yticks(np.arange(phi.shape[1]), list("Hello world!  "),
                           rotation='horizontal')
                plt.margins(0.2)
                plt.subplots_adjust(bottom=0.15)
                plt.savefig(save_path + "heat_map" + str(best_epoch) + ".png")
                plt.close()
            # denormalize the generated offsets using train set mean and std
            gen_seq = data_denormalization(Global.train_mean, Global.train_std,
                                           gen_seq)
            # plot the sequence
            plot_stroke(gen_seq[0],
                        save_name=save_path + model_type + "_seq_" +
                        str(best_epoch) + ".png")
            k = 0
        elif k > patience:
            print("Best model was saved at epoch: {}".format(best_epoch))
            print("Early stopping at epoch {}".format(epoch))
            break
        else:
            k += 1
# (fragment of a generation script: choose a priming style either from a
# user-uploaded file or from a fixed training sample; `args` and `device`
# are defined above)
model_path = args.model_path
model = args.model
# Dataset is needed for its char <-> id vocabulary mappings.
train_dataset = HandwritingDataset(args.data_path, split="train",
                                   text_req=args.text_req)
if args.prime and args.file_path:
    # User-supplied style: stroke array plus its transcription.
    style = np.load(args.file_path + "style.npy", allow_pickle=True,
                    encoding="bytes").astype(np.float32)
    with open(args.file_path + "inpText.txt") as file:
        texts = file.read().splitlines()
    real_text = texts[0]
    # plot the sequence
    plot_stroke(style, save_name=args.save_path / "style.png")
    print(real_text)
    mean, std, _ = data_normalization(style)
    style = torch.from_numpy(style).unsqueeze(0).to(device)
    print(style.shape)
    # Labels for a later attention plot: prime text + generated text.
    ytext = real_text + " " + args.char_seq + " "
elif args.prime:
    # Prime from a fixed sample in the training data instead.
    strokes = np.load(args.data_path + "strokes.npy", allow_pickle=True,
                      encoding="bytes")
    with open(args.data_path + "sentences.txt") as file:
        texts = file.read().splitlines()
    idx = 3949  # np.random.randint(0, len(strokes))
    print("Prime style index: ", idx)
    real_text = texts[idx]
    style = strokes[idx]
num_workers=1, collate_fn=pad_collate):
        # (fragment: tail of a DataLoader subclass __init__ whose signature
        # starts outside this chunk)
        self.data_dir = data_dir
        self.dataset = HandWritingDataset(data_dir)
        super().__init__(self.dataset, batch_size, shuffle, validation_split,
                         num_workers, collate_fn)


if __name__ == '__main__':
    # Smoke-test the dataset and dataloader output shapes.
    data_directory = '../data'
    dataset = HandWritingDataset(data_directory)
    dataloader = HandWritingDataLoader(data_directory, batch_size=32,
                                       shuffle=False)

    # Test dataset
    print('Test dataset')
    print('size of the dataset: ', len(dataset))
    sent, stroke = dataset[0]
    print('sentence 1: ', dataset.tensor2sentence(sent))
    plot_stroke(stroke.numpy())

    # Test dataloader
    print('\nTest dataloader')
    batch = next(iter(dataloader))
    (sentences, sentences_mask, strokes, strokes_mask) = batch
    print('shape sentences: ', sentences.shape)
    print('shape sentences_mask: ', sentences_mask.shape)
    print('shape strokes: ', strokes.shape)
    print('shape strokes_mask: ', strokes_mask.shape)
# -*- coding: utf-8 -*- """ Created on Mon Apr 26 23:01:13 2021 @author: prajw """ import numpy as np import tensorflow as tf import matplotlib.pyplot as plt from data import DataSynthesis from models import HandWritingPrediction, HandWritingSynthesis from utils import plot_stroke d = DataSynthesis() generator = d.batch_generator(shuffle=False) _, sentence, target = next(generator) hws = HandWritingSynthesis() hws.make_model(load_weights='models/trained/model_synthesis_overfit.h5') sentence = tf.dtypes.cast(d.prepare_text('independent expert body'), float) strokes, windows, phis, kappas = hws.infer(sentence, seed=18) plot_stroke(strokes)
# (fragment: tail of an unconditional sampling loop — the function header and
# the loop over `k` are outside this chunk; nesting reconstructed)
        model.eval()
        point, hidden = model(point, hidden)
        # First channel is the pen-lift probability; draw a Bernoulli sample
        # from it instead of thresholding.
        point[:, 0] = torch.sigmoid(point[:, 0])
        # point[:, 0] = torch.Tensor([1 if x > 0.5 else 0 for x in point[:, 0]])
        point[:, 0] = torch.Tensor(
            np.random.binomial(1, point[:, 0].data.cpu().numpy())).to(device)
        stroke[k] = point
        point = point.unsqueeze(0)  # restore the (1, ...) batch dimension
    return np.array(stroke.squeeze().data.cpu())


if __name__ == "__main__":
    import argparse
    arg_parser = argparse.ArgumentParser(
        description="Generate a random stroke")
    arg_parser.add_argument("--random_seed", "-r", dest="random_seed",
                            required=False, help="The random seed")
    args = arg_parser.parse_args()
    random_seed = args.random_seed if args.random_seed else 1
    stroke = generate_unconditionally(random_seed)
    print('Random seed:', random_seed)
    plot_stroke(stroke)