def create_model(sess, data, args, embed):
    with tf.variable_scope(args.name):
        model = VAEModel(data, args, embed)
        model.print_parameters()
        latest_dir = '%s/checkpoint_latest' % args.model_dir
        best_dir = '%s/checkpoint_best' % args.model_dir
        if tf.train.get_checkpoint_state(latest_dir) and args.restore == "last":
            print("Reading model parameters from %s" % latest_dir)
            model.latest_saver.restore(sess, tf.train.latest_checkpoint(latest_dir))
        else:
            if tf.train.get_checkpoint_state(best_dir) and args.restore == "best":
                print('Reading model parameters from %s' % best_dir)
                model.best_saver.restore(sess, tf.train.latest_checkpoint(best_dir))
            else:
                print("Created model with fresh parameters.")
                global_variable = [gv for gv in tf.global_variables() if args.name in gv.name]
                sess.run(tf.variables_initializer(global_variable))
    return model
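A hedged usage sketch for `create_model` (the fields on `args`, and the `data`/`embed` objects, are assumptions based on how they are used above, not taken from the original script):

import tensorflow as tf
from argparse import Namespace

# Hypothetical setup: `data` and `embed` are assumed to be prepared by the
# surrounding training script; only the fields read by create_model are set.
args = Namespace(name='vae', model_dir='./train', restore='last')

with tf.Session() as sess:
    model = create_model(sess, data, args, embed)  # restores 'last', 'best', or initializes fresh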
class Trainer(object):
    def __init__(self, lr=1e-4, **kwargs):
        self.lr = lr
        self.vae = VAEModel(**kwargs)
        self.data_path = "/datasets/Img/img_align_celeba/img_align_celeba"

    def train(self):
        step = 0
        with self.vae.g_train.as_default():
            kl, mse = self.vae.train_graph(self.data_path)
            loss = kl + mse
            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            with tf.control_dependencies(update_ops):
                optim = tf.train.AdamOptimizer(learning_rate=self.lr).minimize(loss)
            tf.summary.scalar('KLLoss', kl)
            tf.summary.scalar('MSELoss', mse)
            tf.summary.scalar('TotalLoss', loss)
            merged = tf.summary.merge_all()
            writer = tf.summary.FileWriter("/tmp/log", self.vae.g_train)
            saver = tf.train.Saver()
            with tf.Session() as s:
                s.run(tf.global_variables_initializer())
                while True:
                    try:
                        _, kl_d, mse_d, loss_d, mgd = s.run([optim, kl, mse, loss, merged])
                        step += 1
                        writer.add_summary(mgd, step)
                        if step % 1000 == 0:
                            print("Saving model...")
                            saver.save(s, "./saved_models/model.ckpt")
                            print("Iter: {:10d}, KLLoss: {:12.6f}, MSELoss: {:12.6f}, TotalLoss: {:12.6f}"
                                  .format(step, kl_d, mse_d, loss_d), end="\n")
                    except tf.errors.OutOfRangeError:
                        print("Training finished...")
                        break
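The `while True` loop above exits on `tf.errors.OutOfRangeError`, which implies `self.vae.train_graph` builds a finite `tf.data` pipeline. A minimal sketch of such an input graph (the dataset layout and `load_and_preprocess` are assumptions, not the original code):

# Finite pipeline: once repeat(epochs) is exhausted, get_next() raises
# tf.errors.OutOfRangeError, which the training loop catches to stop.
files = tf.data.Dataset.list_files(data_path + "/*.jpg")
dataset = (files.map(load_and_preprocess)  # hypothetical decode/resize fn
                .batch(64)
                .repeat(5))
images = dataset.make_one_shot_iterator().get_next()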
if args.task == 'depth':
    train_ds, test_ds = import_data.load_dataset_for_training_decoder(
        args.dset_file, args.num_imgs, args.batch_size, 'depths')
elif args.task == 'segmentation':
    train_ds, test_ds = import_data.load_dataset_for_training_decoder(
        args.dset_file, args.num_imgs, args.batch_size, 'segs')
else:  # args.task == 'con2rgb'
    train_ds, test_ds = import_data.load_dataset_for_training_decoder(
        args.dset_file, args.num_imgs, args.batch_size, 'rgbs')

# prepare boolean vars for decoding layers
trainable_layers = [True] * 5
frozen_layers = 5 - args.trainable_layers
for i in range(frozen_layers):
    trainable_layers[i] = False

# create model, loss and optimizer
if args.task == 'depth':
    model = VAEModel(n_z=args.n_z, trainable_encoder=False,
                     trainable_decoder=trainable_layers, out_channels=1)
    model.load_weights(args.model_path)
elif args.task == 'segmentation':
    model = VAEModel(n_z=args.n_z, trainable_encoder=False,
                     trainable_decoder=trainable_layers, out_channels=3)
    model.load_weights(args.model_path)
else:  # args.task == 'con2rgb'
    model = VAEModel(n_z=args.n_z, trainable_encoder=True,
                     trainable_decoder=trainable_layers, out_channels=3)
    model.load_decoder_weights_only(args.model_path)
optimizer = tf.keras.optimizers.Adam()

# create text log file
log_file = open(os.path.join(args.output_dir, "log.txt"), "w")
log_file.write("Epoch\tTrainLoss\tTestLoss\n")

# define metrics
train_loss = tf.keras.metrics.Mean(name='train_loss')
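A minimal sketch of how a per-layer boolean mask like `trainable_layers` can freeze the first decoder stages in Keras (this `Decoder` is an illustration under assumed layer sizes, not the original `VAEModel`):

import tensorflow as tf

class Decoder(tf.keras.Model):
    def __init__(self, trainable_layers, out_channels=3):
        super().__init__()
        filters = [256, 128, 64, 32, out_channels]  # assumed widths
        self.blocks = []
        for trainable, f in zip(trainable_layers, filters):
            block = tf.keras.layers.Conv2DTranspose(f, 3, strides=2, padding='same')
            block.trainable = trainable  # False keeps the pretrained weights fixed
            self.blocks.append(block)

    def call(self, x):
        for block in self.blocks:
            x = block(x)
        return x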
# add special tokens
special_tokens_dict = {
    'pad_token': '<|startoftext|>',
    'cls_token': '<|startofcond|>',
    'sep_token': '<|sepofcond|>',
    'mask_token': '<|endofcond|>'
}
num_added_toks = tokenizer.add_special_tokens(special_tokens_dict)
print('We have added', num_added_toks, 'special tokens')
# Notice: resize_token_embeddings expects to receive the full size of the new vocab
gpt2_model.resize_token_embeddings(len(tokenizer))
assert tokenizer.pad_token == '<|startoftext|>'

VAE = VAEModel(config, add_input=add_input, add_attn=add_attn, add_softmax=add_softmax)
init_para_frompretrained(VAE.transformer, gpt2_model.transformer, share_para=True)
init_para_frompretrained(VAE.encoder, gpt2_model.transformer, share_para=False)
VAE.lm_head.weight = gpt2_model.lm_head.weight
if VAE.add_softmax:
    VAE.lm_head_rep = Conv1D(*gpt2_model.lm_head.weight.size())
print('VAE_params:', num_params(VAE))  # 286694400
print('Done.')

print('Loading model weights...')
state = torch.load(os.path.join(args.load, 'model_latest.pt'), map_location='cpu')
if 'module' in list(state.keys())[0]:  # checkpoint came from a DataParallel model; strip the 'module.' prefix
    state_copy = copy.copy(state)
    for k in list(state_copy.keys()):
        state[k.replace('module.', '')] = state.pop(k)
VAE.load_state_dict(state)
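`init_para_frompretrained` is referenced here but not defined in this fragment. A plausible minimal sketch of its contract (an assumption: `share_para=True` ties the tensors, `share_para=False` copies their values):

import torch

def init_para_frompretrained(target, source, share_para=False):
    # Hypothetical implementation, not the original helper.
    src = dict(source.named_parameters())
    for name, param in list(target.named_parameters()):
        if name not in src:
            continue
        if share_para:
            # Tie: make the owning submodule point at the pretrained Parameter.
            mod_name, _, attr = name.rpartition('.')
            module = target.get_submodule(mod_name) if mod_name else target
            setattr(module, attr, src[name])
        else:
            with torch.no_grad():
                param.copy_(src[name])  # copy values; gradients stay independent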
def main_worker(gpu, ngpus_per_node, args):
    if args.model_type == 'cvae':
        args.learn_prior = True
    else:
        args.learn_prior = False

    # GPU
    args.gpu = gpu
    print("There are ", torch.cuda.device_count(), " available GPUs!")
    # print('Setting GPUs {}'.format(args.device))
    print('Using GPU device {}'.format(args.gpu))
    device = torch.device('cuda', args.gpu)
    torch.cuda.set_device(device)
    print('Current single GPU: {}'.format(torch.cuda.current_device()))

    # randomness
    np.random.seed(args.seed)
    prng = np.random.RandomState()
    torch.random.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    # For multiprocessing distributed training, rank needs to be the global rank among all the processes
    args.rank = args.rank * ngpus_per_node + gpu
    print('Setting rank', args.rank)
    recon_attempt = 1
    connected = False
    if args.rank != 0:
        # Stall to have rank 0 node go first
        time.sleep(3)
    while not connected:
        try:
            dist.init_process_group(backend=args.dist_backend,
                                    init_method=args.dist_url,
                                    world_size=args.world_size,
                                    rank=args.rank)
            connected = True
            print('Established connection. Rank:', args.rank)
        except Exception as e:
            # Sometimes the head node launches after the worker, which would cause an issue
            print('Failed to init process group. Retrying...', recon_attempt, e)
            recon_attempt += 1
            time.sleep(10)

    # logging
    if args.rank == 0:
        save_folder = os.path.join(args.out_dir, args.experiment)
        os.makedirs(save_folder, exist_ok=True)
        t_writer = SummaryWriter(os.path.join(save_folder, 'train'), flush_secs=5)
        v_writer = SummaryWriter(os.path.join(save_folder, 'val'), flush_secs=5)
        importlib.reload(logging)
        logging.basicConfig(filename=os.path.join(save_folder, 'train.log'),
                            level=logging.INFO,
                            format='%(asctime)s--- %(message)s')
        logging.info('\n*******************************************************************************\n')
        logging.info("the configuration:")
        logging.info(str(args).replace(',', '\n'))

    print('Loading models...')
    cache_dir = os.path.join(args.out_dir, 'model_cache')
    os.makedirs(cache_dir, exist_ok=True)
    # Load pre-trained teacher tokenizer (vocabulary)
    tokenizer = GPT2Tokenizer.from_pretrained('gpt2', cache_dir=cache_dir)
    # Hack to allow tokenizing longer sequences.
    tokenizer.max_len = int(1e12)
    gpt2_model = GPT2LMHeadModel.from_pretrained('gpt2', cache_dir=cache_dir)
    print('gpt2_params:', num_params(gpt2_model))  # gpt2: 124439808
    config = GPT2Config()

    # add special tokens
    # special_tokens_dict = {
    #     'pad_token': '<|startoftext|>',
    #     'cls_token': '<|startofcond|>',
    #     'sep_token': '<|sepofcond|>',
    #     'mask_token': '<|endofcond|>'
    # }
    # num_added_toks = tokenizer.add_special_tokens(special_tokens_dict)
    # print('We have added', num_added_toks, 'special tokens')
    # # Notice: resize_token_embeddings expects to receive the full size of the new vocab
    # gpt2_model.resize_token_embeddings(len(tokenizer))
    # assert tokenizer.pad_token == '<|startoftext|>'

    VAE = VAEModel(config,
                   add_input=args.add_input,
                   add_attn=args.add_attn,
                   add_softmax=args.add_softmax,
                   attn_proj_vary=args.attn_proj_vary,
                   learn_prior=args.learn_prior)
    init_para_frompretrained(VAE.transformer, gpt2_model.transformer, share_para=True)
    init_para_frompretrained(VAE.encoder, gpt2_model.transformer, share_para=False)
    if args.learn_prior:
        init_para_frompretrained(VAE.encoder_prior, VAE.encoder, share_para=True)
        VAE.encoder_prior.averageSelfAttention.attention_weights = VAE.encoder.averageSelfAttention.attention_weights
    VAE.lm_head.weight = gpt2_model.lm_head.weight
    if VAE.add_softmax:
        VAE.lm_head_rep = Conv1D(*gpt2_model.lm_head.weight.size())
        # VAE.lm_head_rep = LM_head_rep(*gpt2_model.lm_head.weight.size()[::-1])
    print('VAE_params:', num_params(VAE))  # 286694400

    if args.load:
        print('Loading model weights...')
        state = torch.load(os.path.join(args.load, 'model_latest.pt'))  # , map_location='cpu'
        if 'module' in list(state.keys())[0]:  # model_path is a DataParallel model with attr 'module'
            state_copy = copy.copy(state)
            for k in list(state_copy.keys()):
                state[k.replace('module.', '')] = state.pop(k)
        VAE.load_state_dict(state)
        gc.collect()
    print('Done.')

    # fix pre-trained parameters before certain iterations
    tuning_all_after_iters = 10000
    tuning_all = False
    for name, parameter in VAE.named_parameters():
        # print((name, parameter.requires_grad))
        new_pars = ['c_z', 'attention_weights', 'mean', 'logvar', 'input_proj',
                    'attn_proj', 'Nu_fc1', 'Nu_fc2', 'lm_head_rep']
        if not any(n in name for n in new_pars):
            parameter.requires_grad = False

    print('Setup data...')
    # Batch and sequence length schedule
    assert len(args.batch_sizes) == len(args.seq_lens)
    batch_schedule = list(zip(map(int, args.batch_sizes), map(int, args.seq_lens)))
    assert len(batch_schedule) <= 2, 'Currently not supporting multiple schedules'
    cur_b_schedule = len(batch_schedule) - 1 if args.switch_time == 0 else 0
    print('Batch schedule', batch_schedule)
    train_loader, val_loader, test_loader = prepare_dataset(
        args.data_dir, args.dataset, tokenizer,
        batch_schedule[cur_b_schedule][0], batch_schedule[cur_b_schedule][1],
        batch_schedule[-1][0], batch_schedule[-1][1],
        batch_schedule[-1][0], batch_schedule[-1][1],
        make_test=True,
        num_workers=args.workers,
        data_type=args.data_type)
    print('Done.')

    ###
    val_loader = test_loader
    ###

    print('Wrapping models and optimizers...')
    # Apply linear scaling rule to increase batch size for short sequence training.
    lr_schedule = switch_schedule(
        linear_schedule(args),
        batch_schedule[cur_b_schedule][0] / batch_schedule[-1][0],
        int(args.iterations * args.switch_time))
    VAE = VAE.to(device)
    VAE = VAE.train()

    optimizer = AdamW(VAE.parameters(), lr=args.lr, correct_bias=True)
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_schedule)
    VAE, optimizer = amp.initialize(VAE, optimizer, opt_level=args.fp16_opt_level)
    loss_model = DDP(VAE)  # , delay_allreduce=True
    loss_fn = nn.CrossEntropyLoss(reduction='none')
    print('Done.')

    print('Begin training iterations')
    logging.info("Begin training iterations")
    max_val_batches = 20000  # max num. of val batches
    logging.info("Total iteration: %d" % args.iterations)
    e = 0  # number of epochs
    num_iters = 0
    optimizer.zero_grad()
    beta = args.beta_0
    endoftext = tokenizer.convert_tokens_to_ids("<|endoftext|>")

    def val_step(val_loader):
        VAE.eval()

        n_words_bpe = 0
        n_words = 0
        logp_sum = 0.0
        kl_loss_sum = 0.0

        logging.info("Validation loop. Batches: %d" % len(val_loader))
        logging.info("Validation loop. max_val_batches: %d" % max_val_batches)

        # val_iter = iter(val_loader); x_mask, x_tokens, y_mask, y_tokens, input_tokens, target_tokens, mask = next(val_iter)
        with tqdm(total=min(len(val_loader), max_val_batches)) as pbar:
            for i, (x_mask, x_tokens, y_mask, y_tokens, input_tokens,
                    target_tokens, mask) in enumerate(val_loader):
                with torch.no_grad():
                    if args.model_type == 'cvae':
                        loss, ce_loss, kl_loss = compute_loss(
                            device, VAE, x_mask, x_tokens, y_mask, y_tokens,
                            input_tokens, target_tokens, mask, loss_fn, 1.0)
                    else:
                        loss, ce_loss, kl_loss = compute_loss_ae(
                            device, VAE, x_mask, x_tokens, y_mask, y_tokens,
                            input_tokens, target_tokens, mask, loss_fn, 1.0)

                if len(target_tokens.size()) == 1:
                    target_tokens = target_tokens.unsqueeze(0)
                n, l = target_tokens.size()

                text = target_tokens[0, :].tolist()
                logprob = ce_loss.tolist()
                assert len(text) == len(logprob)

                # only for story
                idx = text.index(endoftext)
                text = text[idx + 1:]
                logprob = logprob[idx + 1:]

                if endoftext in text:
                    idx = text.index(endoftext)
                    text = text[:idx]
                    logprob = logprob[:idx]

                logp_sum += sum(logprob)
                n_words_bpe += len(text)

                story = [tokenizer.decode(target_tokens[i, :]) for i in range(n)]
                story = [s[s.find("<|endoftext|>") + len("<|endoftext|>"):] for s in story]
                story = [s[:s.find("<|endoftext|>") + len("<|endoftext|>")]
                         if "<|endoftext|>" in s else s for s in story]
                words = sum([
                    len([t for t in re.split(
                        '("|\'|!|\?|\.|,|:| |\n|’|“|”|;|\(|\)|`)', s)
                         if t != ' ' and t != '']) for s in story
                ])
                n_words += words

                kl_loss_sum += kl_loss.item()

                if i > max_val_batches:
                    break
                pbar.update(1)

        loss_bpe = logp_sum / n_words_bpe
        ppl_bpe = round(math.exp(min(logp_sum / n_words_bpe, 100)), 3)
        ppl_word = round(math.exp(min(logp_sum / n_words, 100)), 3)
        kl = kl_loss_sum / len(val_loader)

        v_writer.add_scalar('loss', loss_bpe, num_iters)
        v_writer.add_scalar('ppl_bpe', ppl_bpe, num_iters)
        v_writer.add_scalar('ppl_word', ppl_word, num_iters)
        v_writer.add_scalar('kl', kl, num_iters)
        logging.info('val loss    : %.4f' % loss_bpe)
        logging.info('val ppl_bpe : %.4f' % ppl_bpe)
        logging.info('val ppl_word: %.4f' % ppl_word)
        logging.info('val kl      : %.4f' % kl)
        VAE.train()

    def test_plot(test_loader, num_iters):
        VAE.eval()

        # get embedding
        X_emb = None
        y = None

        # test_iter = iter(test_loader); x_mask, x_tokens, y_mask, y_tokens, input_tokens, target_tokens, mask = next(test_iter)
        with tqdm(total=len(test_loader)) as pbar:
            for i, (x_mask, x_tokens, y_mask, y_tokens, input_tokens,
                    target_tokens, mask) in enumerate(test_loader):
                y_mask = y_mask.to(device)
                y_tokens = y_tokens.to(device)
                x_mask = x_mask.to(device)
                x_tokens = x_tokens.to(device)

                with torch.no_grad():
                    if args.model_type == 'cvae':
                        latent_mean, _ = VAE.encoder_prior(input_ids=x_tokens,
                                                           attention_mask=x_mask)[:2]
                    else:
                        latent_mean, _ = VAE.encoder(input_ids=x_tokens,
                                                     attention_mask=x_mask)[:2]

                if args.dataset == 'ax' or args.dataset == 'yp':
                    label = [tokenizer.decode(l)[:2] for l in x_tokens.tolist()]
                elif args.dataset == 'wp':
                    label = []
                    prompts = [tokenizer.decode(l)[:6].lower() for l in x_tokens.tolist()]
                    for prom in prompts:
                        if prom[0] in ['[', '('] and prom[5] in [']', ')']:
                            label.append(prom[2:4])
                        else:
                            label.append(None)
                elif args.dataset == 'wi':
                    # 0. TV, play, miniseries, telenovela; 1. film; 2. music;
                    # 3. manga, comic; 4. book, novel, story; 5. game
                    label = []
                    prompts = [tokenizer.decode(l) for l in x_tokens.tolist()]
                    for prom in prompts:
                        if 'TV' in prom or 'play' in prom or 'miniseries' in prom or 'telenovela' in prom:
                            label.append(0)
                        elif 'film' in prom:
                            label.append(1)
                        elif 'music' in prom:
                            label.append(2)
                        elif 'manga' in prom or 'comic' in prom:
                            label.append(3)
                        elif 'book' in prom or 'novel' in prom or 'story' in prom:
                            label.append(4)
                        elif 'game' in prom:
                            label.append(5)
                        else:
                            label.append(None)
                else:
                    raise Exception

                if i == 0:
                    X_emb = latent_mean.data
                    y = label
                else:
                    X_emb = torch.cat((X_emb, latent_mean.data), dim=0)
                    y.extend(label)
                pbar.update(1)
        X_emb = X_emb.cpu().numpy()

        try:
            if args.dataset == 'yp':
                y = ['0' if l in ['0', '1'] else l for l in y]
                y = ['4' if l in ['3', '4'] else l for l in y]
                X_emb = X_emb[[l != '2' for l in y], :]
                y = [l for l in y if l != '2']
            if args.dataset == 'wp':
                topics = [['wp', 'sp', 'tt'], ['eu'], ['cw'], ['pm'],
                          ['mp', 'ip'], ['pi', 'cc'], ['ot'], ['rf']]
                match = [[True if l in t else False for t in topics] for l in y]
                y = [m.index(True) if True in m else None for m in match]
                X_emb = X_emb[[l is not None for l in y], :]
                y = [l for l in y if l is not None]
            if args.dataset == 'wi':
                X_emb = X_emb[[l is not None for l in y], :]
                y = [l for l in y if l is not None]

            # to 2D
            # X_emb_2d = TSNE(n_components=2, init='pca', verbose=1).fit_transform(X_emb)
            X_emb_2d = TSNE(n_components=2, verbose=1, perplexity=40).fit_transform(X_emb)

            def remove_outliers(data, r=2.0):
                outliers_data = abs(data - np.mean(data, axis=0)) >= r * np.std(data, axis=0)
                outliers = np.any(outliers_data, axis=1)
                keep = np.logical_not(outliers)
                return outliers, keep

            outliers, keep = remove_outliers(X_emb_2d)
            X_emb_2d = X_emb_2d[keep, :]
            y = [l for l, k in zip(y, keep.tolist()) if k]

            # plot
            fig = plt.figure(figsize=(4, 4))
            ax = fig.add_axes([0, 0, 1, 1])
            cc = ['r', 'b', 'g', 'y', 'k', 'c', 'm', 'tab:blue']
            for i, l in enumerate(sorted(set(y))):
                idx = [yl == l for yl in y]
                plt.scatter(X_emb_2d[idx, 0], X_emb_2d[idx, 1],
                            c=cc[i], s=10, edgecolor='none', alpha=0.5)
            ax.axis('off')  # adding it will get no axis
            plt.savefig(os.path.join(save_folder,
                                     'tSNE_' + '{:07d}'.format(num_iters) + '.png'))
            plt.close(fig)
        except:
            pass

        VAE.train()

    def generate(test_loader, num_iters):
        VAE.eval()

        n_samples = 0
        bleu4_sum = 0.0
        rouge_scores_values_sum = [0.0] * 9

        args.nsamples = 1
        args.batch_size = 1
        args.temperature = 0.95
        args.top_k = 100
        args.top_p = 0.95
        model_type = args.model_type

        # write samples to file
        samples_file = open(os.path.join(save_folder,
                                         'generate-' + '%07d' % num_iters + '.txt'),
                            'w', encoding='utf8')

        # test_iter = iter(test_loader); x_mask, x_tokens, y_mask, y_tokens, input_tokens, target_tokens, mask = next(test_iter)
        with tqdm(total=len(test_loader)) as pbar:
            for i_test, (x_mask, x_tokens, y_mask, y_tokens, input_tokens,
                         target_tokens, mask) in enumerate(test_loader):
                if i_test >= 10:
                    break

                length = -1
                if length == -1:
                    length = VAE.config.n_ctx - x_tokens.size(1) - 1
                elif length > VAE.config.n_ctx - x_tokens.size(1) - 1:
                    raise ValueError("Can't get samples longer than window size: %s" % VAE.config.n_ctx)

                eff_samples = []
                n, l = target_tokens.size()
                storys = [tokenizer.decode(target_tokens[i, :]) for i in range(n)]
                storys = [s[s.find("<|endoftext|>") + len("<|endoftext|>"):] for s in storys]
                storys_str = [s[:s.find("<|endoftext|>") + len("<|endoftext|>")]
                              if "<|endoftext|>" in s else s for s in storys]

                for _ in range(args.nsamples // args.batch_size):
                    # model, batch_size, temperature, top_k, top_p, eos_token, sample = VAE, args.batch_size, args.temperature, args.top_k, args.top_p, tokenizer.encoder['<|endoftext|>'], True
                    out, _ = sample_sequence(
                        model=VAE,
                        tokenizer=tokenizer,
                        length=length,
                        batch_size=args.batch_size,
                        x_mask=x_mask,
                        x_tokens=x_tokens,
                        y_mask=y_mask,
                        y_tokens=y_tokens,
                        temperature=args.temperature,
                        top_k=args.top_k,
                        top_p=args.top_p,
                        device=device,
                        eos_token=tokenizer.encoder['<|endoftext|>'],
                        model_type=model_type)
                    out = out.tolist()

                    # extract story, check metrics
                    for i in range(len(out)):
                        text = out[i]
                        text = text[text.index(endoftext) + 1:]

                        if endoftext in text:
                            idx = text.index(endoftext)
                            text = text[:idx]

                        text = tokenizer.decode(text).strip()

                        # score for one long text, higher than 0.075 usually means repetition
                        # rep_score = repeat_score(text.split(), ngram=[3, 4, 5, 6, 7, 8])
                        # if rep_score > 0.075:
                        #     # print(rep_score)
                        #     continue

                        try:
                            # check bleu
                            bleu4 = sentence_bleu(
                                [storys_str[i].split()], text,
                                smoothing_function=SmoothingFunction().method7)

                            # check rouge
                            rouge = Rouge()
                            rouge_scores = rouge.get_scores(text, storys_str[i])
                            rouge_scores_values = [v for k in rouge_scores[0].keys()
                                                   for v in rouge_scores[0][k].values()]

                            bleu4_sum += bleu4
                            rouge_scores_values_sum = [
                                v1 + v2 for v1, v2 in zip(rouge_scores_values_sum,
                                                          rouge_scores_values)
                            ]
                            n_samples += 1
                        except:
                            bleu4 = 0.0
                            rouge_scores = [{
                                'rouge-1': {'f': 0.0, 'p': 0.0, 'r': 0.0},
                                'rouge-2': {'f': 0.0, 'p': 0.0, 'r': 0.0},
                                'rouge-l': {'f': 0.0, 'p': 0.0, 'r': 0.0}
                            }]

                        eff_samples.append((text, bleu4, rouge_scores))

                    pbar.update(1)

                for i in range(len(eff_samples)):
                    samples_file.write("=" * 50 + " SAMPLE " + str(i_test) + " " + "=" * 50)
                    samples_file.write('\n' * 2)

                    samples_file.write("=" * 40 + " Outlines " + "=" * 40)
                    samples_file.write('\n' * 2)
                    samples_file.write(tokenizer.decode(x_tokens[i, :][x_mask[i, :] == 1].tolist()))
                    samples_file.write('\n' * 2)
                    samples_file.write("=" * 40 + " Story " + "=" * 40)
                    samples_file.write('\n' * 2)
                    samples_file.write(storys_str[i])
                    samples_file.write('\n' * 2)

                    samples_file.write("=" * 40 + " Generated " + "=" * 40)
                    samples_file.write('\n' * 2)
                    samples_file.write(eff_samples[i][0])
                    samples_file.write('\n' * 4)
                    samples_file.flush()

        print('Test complete with %05d samples.' % n_samples)
        logging.info("Test complete with %05d samples.", n_samples)
        logging.info("Iteration completed: %d" % num_iters)

        bleu4 = round(bleu4_sum / n_samples, 3)
        rouge_scores_values = [round(r / n_samples, 3) for r in rouge_scores_values_sum]
        print(' bleu-4:', bleu4)
        print(' rouge :', rouge_scores_values)
        logging.info(' bleu-4: %f', bleu4)
        logging.info(' rouge : %s', str(rouge_scores_values))

        VAE.train()

    if args.rank == 0:
        test_plot(test_loader, num_iters)
        val_step(val_loader)
        generate(test_loader, num_iters)
        torch.save(VAE.state_dict(),
                   os.path.join(save_folder, 'model_' + '{:07d}'.format(num_iters) + '.pt'))

    while num_iters < args.iterations:
        # Run epoch
        st = time.time()

        # Training
        print('Training loop. Batches:', len(train_loader))
        logging.info('\n----------------------------------------------------------------------')
        logging.info("Training loop. Batches: %d" % len(train_loader))

        # train_iter = iter(train_loader); x_mask, x_tokens, y_mask, y_tokens, input_tokens, target_tokens, mask = next(train_iter)
        with tqdm(total=len(train_loader)) as pbar:
            for i, (x_mask, x_tokens, y_mask, y_tokens, input_tokens,
                    target_tokens, mask) in enumerate(train_loader):

                if num_iters % args.cycle >= args.cycle - args.beta_warmup - 25000:
                    beta = min(1.0, beta + (1. - args.beta_0) / args.beta_warmup)

                if not tuning_all and num_iters >= tuning_all_after_iters:
                    for name, parameter in VAE.named_parameters():
                        # print((name, parameter.requires_grad))
                        parameter.requires_grad = True
                    tuning_all = True

                output = train_step(device, loss_model, optimizer, x_mask, x_tokens,
                                    y_mask, y_tokens, input_tokens, target_tokens,
                                    mask, loss_fn, beta, args.model_type)

                if args.rank == 0:
                    loss, ce_loss, kl_loss = output[-1]
                    lr = scheduler.get_last_lr()[0]
                    # Log to Tensorboard
                    t_writer.add_scalar('loss', loss, num_iters)
                    t_writer.add_scalar('ppl', math.exp(min(ce_loss, 10)), num_iters)
                    t_writer.add_scalar('lr', lr, num_iters)
                    t_writer.add_scalar('iter_time', time.time() - st, num_iters)
                    t_writer.add_scalar('kl', kl_loss, num_iters)
                    t_writer.add_scalar('beta', beta, num_iters)

                    if args.model_type == 'ae_vae_fusion':
                        loss, ce_loss, kl_loss = output[0]
                        # Log to Tensorboard
                        t_writer.add_scalar('ae_loss', loss, num_iters)
                        t_writer.add_scalar('ae_kl', kl_loss, num_iters)

                st = time.time()
                end = num_iters >= args.iterations
                if args.warmup != -1:
                    scheduler.step()

                if end:
                    break
                num_iters += 1
                pbar.update(1)

                if num_iters % args.cycle == 0:
                    beta = args.beta_0
                    logging.info('KL annealing restart')

                if args.rank == 0 and num_iters % 10000 == 0:
                    test_plot(test_loader, num_iters)
                    val_step(val_loader)
                    generate(test_loader, num_iters)

                if args.rank == 0 and num_iters % 10000 == 0:
                    print('Saving model...')
                    logging.info("Iteration completed: %d, remained %d" %
                                 (num_iters, args.iterations - num_iters))
                    logging.info("Saving model...")
                    logging.info('\n------------------------------------------------------')
                    torch.save(VAE.state_dict(),
                               os.path.join(save_folder,
                                            'model_' + '{:07d}'.format(num_iters) + '.pt'))

                if args.switch_time > 0 and num_iters == int(args.iterations * args.switch_time):
                    print('Switch to long sequence training')
                    logging.info("Switch to long sequence training")
                    cur_b_schedule += 1
                    train_loader, val_loader, test_loader = prepare_dataset(
                        args.data_dir, args.dataset, tokenizer,
                        batch_schedule[cur_b_schedule][0], batch_schedule[cur_b_schedule][1],
                        batch_schedule[-1][0], batch_schedule[-1][1],
                        batch_schedule[-1][0], batch_schedule[-1][1],
                        make_test=True,
                        num_workers=args.workers,
                        data_type=args.data_type)
        if not end:
            e += 1
logging.info("Training loop. The ith epoch completed: %d" % e) if args.rank == 0: torch.save(VAE.state_dict(), os.path.join(save_folder, 'model_latest.pt')) print('Training complete.') logging.info("Training complete.")
parser.add_argument('--show', '-show',
                    help='choose what to do from [predict, inter]',
                    default='predict', type=str)
parser.add_argument('--method', '-method',
                    help='choose what to do from [restoration, depth, con2rgb]',
                    default='restoration', type=str)
parser.add_argument('--grayscale', '-grayscale', dest='grayscale',
                    help='choose for training on grayscale images',
                    action='store_true')
args = parser.parse_args()

# tf function for prediction
@tf.function
def predict_image(image, inter=None):
    return model(image, inter)

# in TF 2.0, allow-growth can be enabled via an environment variable
os.environ['TF_FORCE_GPU_ALLOW_GROWTH'] = 'true'

# Load the model
model = VAEModel(n_z=args.n_z)
model.load_weights(args.path)

# Load image
im = Image.open(args.img)
if args.res > 0:
    # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the same filter
    im = im.resize((args.res, args.res), Image.LANCZOS)
input_img = (np.expand_dims(np.array(im), axis=0) / 255.0).astype(np.float32)
if len(input_img.shape) < 4:
    input_img = np.expand_dims(np.array(input_img), axis=-1)
    input_img = np.repeat(input_img, 3, axis=-1)

if args.show == 'inter':
    # interpolate over the latent space
    latent_array = np.array([-0.02, -0.01, 0.0, 0.01, 0.02])
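The snippet breaks off right after `latent_array` is defined; a plausible continuation (an assumption, since the original loop is cut off: `inter` is taken to be the latent offset that `VAEModel` applies internally) would decode one image per offset:

# Hypothetical continuation of the 'inter' branch.
for offset in latent_array:
    recon = predict_image(input_img, inter=offset)
    out = (np.clip(recon[0].numpy(), 0.0, 1.0) * 255.0).astype(np.uint8)
    Image.fromarray(out).save('inter_{:+.2f}.png'.format(offset))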
    if y_test[i] == 9:
        temp = x_test[i, :, :, :]
        x_test_9.append(temp.reshape((x_train_shape[1], x_train_shape[2], x_train_shape[3])))
x_test_1 = np.array(x_test_1)
x_test_9 = np.array(x_test_9)

# network parameters
input_shape = (8, 8, 1)
batch_size = 128
epochs = 5

# loading model
model = VAEModel(input_shape)
vae = model.load_model(input_shape)

# train the autoencoder
vae.fit(x_train_1, epochs=epochs, batch_size=batch_size)  # validation_data=(x_test, None)
vae.save('vae_mlp_mnist.h5', include_optimizer=False)

# normal / anomalous test samples
idx1 = np.random.randint(len(x_test_1))
idx2 = np.random.randint(len(x_test_9))
test_normal = x_test_1[idx1, :, :, :]
test_anomaly = x_test_9[idx2, :, :, :]
test_normal = test_normal.reshape(1, test_normal.shape[0], test_normal.shape[1], test_normal.shape[2])
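A natural next step for the normal/anomalous pair prepared above (an assumption, since the original script is cut off here) is to compare reconstruction errors: a VAE trained only on "1"s should reconstruct "9"s poorly, so reconstruction MSE works as an anomaly score.

# Hypothetical scoring step, not from the original source.
test_anomaly = test_anomaly.reshape(1, *test_anomaly.shape)
recon_normal = vae.predict(test_normal)
recon_anomaly = vae.predict(test_anomaly)
score_normal = np.mean((recon_normal - test_normal) ** 2)
score_anomaly = np.mean((recon_anomaly - test_anomaly) ** 2)
print('normal score: {:.5f}, anomaly score: {:.5f}'.format(score_normal, score_anomaly))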
    test_kl_loss(kl_loss)

# in TF 2.0, allow-growth can be enabled via an environment variable
os.environ['TF_FORCE_GPU_ALLOW_GROWTH'] = 'true'
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu

# get train and test datasets
train_ds, test_ds = import_data.load_dataset_for_frame_restoration(
    args.data_file, args.num_imgs, args.batch_size, random_sample=args.random_sample)

# create model, loss and optimizer
model = VAEModel(n_z=args.n_z, out_channels=3)
optimizer = tf.keras.optimizers.Adam()

# define metrics
train_rec_loss = tf.keras.metrics.Mean(name='train_rec_loss')
train_kl_loss = tf.keras.metrics.Mean(name='train_kl_loss')
test_rec_loss = tf.keras.metrics.Mean(name='test_rec_loss')
test_kl_loss = tf.keras.metrics.Mean(name='test_kl_loss')

metrics_writer = tf.summary.create_file_writer(args.output_dir)

# check if output folder exists
if not os.path.isdir(args.output_dir):
    os.makedirs(args.output_dir)

# train
print('Start training...')
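The step functions themselves are not part of this fragment (only the dangling `test_kl_loss(kl_loss)` metric update survives above). A minimal sketch of a matching train step, assuming `model(...)` returns the reconstruction plus the latent mean and log-variance (an assumption, not the original code):

@tf.function
def train_step(images, targets):
    with tf.GradientTape() as tape:
        recon, mu, logvar = model(images)                      # assumed outputs
        rec_loss = tf.reduce_mean(tf.square(targets - recon))  # reconstruction (MSE)
        kl_loss = -0.5 * tf.reduce_mean(
            1.0 + logvar - tf.square(mu) - tf.exp(logvar))     # KL(q(z|x) || N(0, I))
        loss = rec_loss + kl_loss
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    train_rec_loss(rec_loss)
    train_kl_loss(kl_loss)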