class Solver(utils.BaseSolver):
    """Thin training wrapper around a VAE.

    Builds the model (and, in 'train' mode, the optimizer) from
    ``self.config`` and exposes a single-batch training step.
    """

    def build(self):
        """Instantiate the VAE, optionally move it to GPU, and set up
        the optimizer and loss function."""
        self.model = VAE(self.config)
        if self.config.cuda:
            self.model.cuda()
        print(self.model)
        # Build Optimizer (Training Only)
        if self.config.mode == 'train':
            self.optimizer = self.config.optimizer(
                self.model.parameters(),
                lr=self.config.learning_rate,
                betas=(self.config.beta1, 0.999))
            # NOTE(review): the collapsed source is ambiguous on whether the
            # loss is built only in train mode; it is placed inside the branch
            # here — confirm against the original formatting.
            self.loss_function = layers.VAELoss(self.config.recon_loss)

    def train_step(self, images):
        """Run one forward pass; return (reconstruction loss, KL divergence)."""
        # Reconstruct Images
        recon_images, mu, log_variance = self.model(images)
        # Calculate loss
        recon_loss, kl_div = self.loss_function(images, recon_images, mu,
                                                log_variance)
        return recon_loss, kl_div
def load_networks(isTraining=False):
    """Build the depth/color/VAE networks plus their Adam optimizers and
    restore them from the newest ``.tar`` checkpoint when available.

    :param isTraining: when True, restore from ``param.trainNet`` only if
        ``param.isContinue`` is set and a checkpoint exists; otherwise restore
        unconditionally from ``param.testNet``.
    :return: (depth_net, color_net, d_net, depth_optimizer, color_optimizer,
        d_optimizer)
    """
    depth_net = DepthNetModel()
    color_net = ColorNetModel()
    d_net = VAE()
    if param.useGPU:
        depth_net.cuda()
        color_net.cuda()
        d_net.cuda()

    shared_adam = dict(lr=param.alpha,
                       betas=(param.beta1, param.beta2),
                       eps=param.eps)
    depth_optimizer = optim.Adam(depth_net.parameters(), **shared_adam)
    color_optimizer = optim.Adam(color_net.parameters(), **shared_adam)
    d_optimizer = optim.Adam(d_net.parameters())

    def tar_checkpoints(folder):
        # All '*.tar' checkpoint names in the folder, name-sorted.
        names, _, _ = get_folder_content(folder)
        return [n for n in sorted(names) if n[-4:] == '.tar']

    def restore_all(folder, ckpt_name):
        # Load every model/optimizer state dict from one checkpoint file.
        checkpoint = torch.load(folder + '/' + ckpt_name)
        depth_net.load_state_dict(checkpoint['depth_net'])
        color_net.load_state_dict(checkpoint['color_net'])
        d_net.load_state_dict(checkpoint['d_net'])
        depth_optimizer.load_state_dict(checkpoint['depth_optimizer'])
        color_optimizer.load_state_dict(checkpoint['color_optimizer'])
        d_optimizer.load_state_dict(checkpoint['d_optimizer'])

    if isTraining:
        netFolder = param.trainNet  # 'TrainingData'
        net = tar_checkpoints(netFolder)
        if param.isContinue and net:
            # Checkpoint names look like '<prefix>-<iteration>.tar'.
            param.startIter = int(net[0].split('-')[1].split('.')[0])
            restore_all(netFolder, net[0])
        else:
            param.isContinue = False
    else:
        netFolder = param.testNet
        net = tar_checkpoints(netFolder)
        restore_all(netFolder, net[0])

    return (depth_net, color_net, d_net, depth_optimizer, color_optimizer,
            d_optimizer)
def main():
    """Train a sentence VAE on IMDB, checkpointing on best validation loss."""
    args = parse_arguments()
    hidden_size, embed_size = 300, 50
    kld_weight, temperature = 0.05, 0.9
    use_cuda = torch.cuda.is_available()

    print("[!] preparing dataset...")
    TEXT = data.Field(lower=True, fix_length=30)
    LABEL = data.Field(sequential=False)
    train_data, test_data = datasets.IMDB.splits(TEXT, LABEL)
    TEXT.build_vocab(train_data, max_size=250000)
    LABEL.build_vocab(train_data)
    train_iter, test_iter = data.BucketIterator.splits(
        (train_data, test_data), batch_size=args.batch_size, repeat=False)
    vocab_size = len(TEXT.vocab) + 2

    print("[!] Instantiating models...")
    encoder = EncoderRNN(vocab_size, hidden_size, embed_size,
                         n_layers=2, dropout=0.5, use_cuda=use_cuda)
    decoder = DecoderRNN(embed_size, hidden_size, vocab_size,
                         n_layers=2, dropout=0.5, use_cuda=use_cuda)
    vae = VAE(encoder, decoder)
    optimizer = optim.Adam(vae.parameters(), lr=args.lr)
    if use_cuda:
        print("[!] Using CUDA...")
        vae.cuda()

    best_val_loss = None
    for e in range(1, args.epochs + 1):
        train(e, vae, optimizer, train_iter, vocab_size, kld_weight,
              temperature, args.grad_clip, use_cuda, TEXT)
        val_loss = evaluate(vae, test_iter, vocab_size, kld_weight, use_cuda)
        print("[Epoch: %d] val_loss:%5.3f | val_pp:%5.2fS"
              % (e, val_loss, math.exp(val_loss)))
        # Save the model if the validation loss is the best we've seen so far.
        if not best_val_loss or val_loss < best_val_loss:
            print("[!] saving model...")
            if not os.path.isdir("snapshot"):
                os.makedirs("snapshot")
            torch.save(vae.state_dict(), './snapshot/vae_{}.pt'.format(e))
            best_val_loss = val_loss
class VAETrainer:
    """Skeleton trainer holding a VAE, its optimizer and train/test loaders.

    :param dataset: a dataset factory callable accepting a ``train`` keyword
        (called below as ``dataset(train=True/False)``).
    """

    def __init__(self, dataset):
        self.model = VAE()
        if config.USE_GPU:
            self.model.cuda()
        # BUG FIX: the original read `self, dataset = dataset`, which tries to
        # tuple-unpack `dataset` into (self, dataset) instead of storing it on
        # the instance — but `self.dataset` is required by the DataLoader
        # construction below.
        self.dataset = dataset
        self.optimizer = optim.Adam(self.model.parameters(), lr=1e-3)
        self.train_loader = DataLoader(self.dataset(train=True))
        self.test_loader = DataLoader(self.dataset(train=False))

    def train(self):
        """Not implemented yet."""
        pass

    def test(self):
        """Not implemented yet."""
        pass
def train(data_loader, model_index, x_eval_train, loaded_model):
    """Train a VAE (optionally resuming from `loaded_model`), periodically
    saving JSON records, sample pictures and model checkpoints.

    :param data_loader: iterable yielding (x_batch, y_batch) pairs
    :param model_index: integer id used in output file names
    :param x_eval_train: unused here (kept for interface compatibility)
    :param loaded_model: path to a saved state dict, or falsy to start fresh
    """
    ### Model Initiation
    if loaded_model:
        vae = VAE()
        vae.cuda()
        saved_state_dict = tor.load(loaded_model)
        vae.load_state_dict(saved_state_dict)
        vae.cuda()
    else:
        vae = VAE()
        vae = vae.cuda()

    loss_func = tor.nn.MSELoss().cuda()

    #optim = tor.optim.SGD(fcn.parameters(), lr=LR, momentum=MOMENTUM)
    optim = tor.optim.Adam(vae.parameters(), lr=LR)
    lr_step = StepLR(optim, step_size=LR_STEPSIZE, gamma=LR_GAMMA)

    ### Training
    for epoch in range(EPOCH):
        print("|Epoch: {:>4} |".format(epoch + 1))

        for step, (x_batch, y_batch) in enumerate(data_loader):
            print("Process: {}/{}".format(
                step, int(AVAILABLE_SIZE[0] / BATCHSIZE)), end="\r")

            x = Variable(x_batch).cuda()
            y = Variable(y_batch).cuda()
            out, KLD = vae(x)
            recon_loss = loss_func(out.cuda(), y)
            loss = (recon_loss + KLD_LAMBDA * KLD)

            loss.backward()
            optim.step()
            # NOTE(review): the scheduler is stepped every batch, so the LR
            # decays per step rather than per epoch — confirm intended.
            lr_step.step()
            optim.zero_grad()

            if step % RECORD_JSON_PERIOD == 0:
                save_record(model_index, epoch, optim, recon_loss, KLD)
            if step % RECORD_PIC_PERIOD == 0:
                # BUG FIX: the original passed the undefined name `vaee`,
                # which raised NameError the first time this branch ran.
                save_pic("output_{}".format(model_index), vae, 3)
            if step % RECORD_MODEL_PERIOD == 0:
                # NOTE(review): "ave_model" looks like a typo for "vae_model",
                # but renaming would change existing checkpoint paths — left
                # untouched deliberately.
                tor.save(
                    vae.state_dict(),
                    os.path.join(MODEL_ROOT,
                                 "ave_model_{}.pkl".format(model_index)))
def extract(fs, idx, N):
    """Encode image sequences through a pretrained VAE and save latents.

    :param fs: list of .npz episode files with keys 'sx','ax','rx','dx'
    :param idx: CUDA device index (also used as a process tag in logging)
    :param N: total file count, used only for progress printing
    """
    model = VAE()
    ckpt = torch.load(cfg.vae_save_ckpt,
                      map_location=lambda storage, loc: storage)
    model.load_state_dict(ckpt['model'])
    model = model.cuda(idx)

    for n, f in enumerate(fs):
        data = np.load(f)
        # Move the channel axis to position 1 (two-axis reorder).
        imgs = data['sx'].transpose(0, 3, 1, 2)
        actions = data['ax']
        rewards = data['rx']
        dones = data['dx']

        # Scale pixel values into [0, 1] before encoding.
        x = torch.from_numpy(imgs).float().cuda(idx) / 255.0
        mu, logvar, _, z = model(x)

        save_path = "{}/{}".format(cfg.seq_extract_dir, f.split('/')[-1])
        np.savez_compressed(save_path,
                            mu=mu.detach().cpu().numpy(),
                            logvar=logvar.detach().cpu().numpy(),
                            dones=dones,
                            rewards=rewards,
                            actions=actions)
        if n % 10 == 0:
            print('Process %d: %5d / %5d' % (idx, n, N))
# NOTE(review): this first `if` references `epoch`, `save_epoch`, `scan` and
# `exp` before they are defined below — it appears to belong to a training
# loop from an earlier, unseen part of the original file; confirm placement.
if (epoch % save_epoch == 0) or (epoch == training_epochs - 1):
    torch.save(scan.state_dict(), '{}/scan_epoch_{}.pth'.format(exp, epoch))

# Script: load pretrained DAE/VAE/SCAN weights, then optionally train SCAN.
data_manager = DataManager()
data_manager.prepare()

dae = DAE()
vae = VAE()
scan = SCAN()

if use_cuda:
    dae.load_state_dict(torch.load('save/dae/dae_epoch_2999.pth'))
    vae.load_state_dict(torch.load('save/vae/vae_epoch_2999.pth'))
    scan.load_state_dict(torch.load('save/scan/scan_epoch_1499.pth'))
    dae, vae, scan = dae.cuda(), vae.cuda(), scan.cuda()
else:
    # CPU fallback: remap all storages onto the CPU.
    dae.load_state_dict(
        torch.load('save/dae/dae_epoch_2999.pth',
                   map_location=lambda storage, loc: storage))
    vae.load_state_dict(
        torch.load('save/vae/vae_epoch_2999.pth',
                   map_location=lambda storage, loc: storage))
    # NOTE(review): unlike the CUDA branch, this branch loads SCAN weights
    # from `exp + '/' + opt.load` instead of the fixed epoch-1499 file —
    # confirm the asymmetry is intended.
    scan.load_state_dict(
        torch.load(exp + '/' + opt.load,
                   map_location=lambda storage, loc: storage))

if opt.train:
    scan_optimizer = optim.Adam(scan.parameters(), lr=1e-4, eps=1e-8)
    train_scan(dae, vae, scan, data_manager, scan_optimizer)
def train_vae(args):
    """Train a text VAE (RNN encoder/decoder) on IMDB with KL weighting,
    temperature annealing and periodic sample printing.

    Uses pre-0.4 PyTorch APIs (`loss.data[0]`, `nn.utils.clip_grad_norm`).
    """
    hidden_size = 300
    embed_size = 50
    kld_start_inc = 2  # NOTE(review): defined but never used below — the
                       # `epoch > 1` test may have been meant to use it.
    kld_weight = 0.05
    kld_max = 0.1
    kld_inc = 0.000002
    temperature = 0.9
    temperature_min = 0.5
    temperature_dec = 0.000002
    USE_CUDA = torch.cuda.is_available()
    print_loss_total = 0

    print("[!] preparing dataset...")
    TEXT = data.Field(lower=True, fix_length=30)
    LABEL = data.Field(sequential=False)
    train, test = datasets.IMDB.splits(TEXT, LABEL)
    TEXT.build_vocab(train, max_size=250000)
    LABEL.build_vocab(train)
    train_iter, test_iter = data.BucketIterator.splits(
        (train, test), batch_size=args.batch_size, repeat=False)
    vocab_size = len(TEXT.vocab) + 2

    print("[!] Instantiating models...")
    encoder = EncoderRNN(vocab_size, hidden_size, embed_size,
                         n_layers=1, use_cuda=USE_CUDA)
    decoder = DecoderRNN(embed_size, hidden_size, vocab_size,
                         n_layers=2, use_cuda=USE_CUDA)
    vae = VAE(encoder, decoder)
    optimizer = optim.Adam(vae.parameters(), lr=args.lr)
    vae.train()
    if USE_CUDA:
        print("[!] Using CUDA...")
        vae.cuda()

    for epoch in range(1, args.epochs + 1):
        for b, batch in enumerate(train_iter):
            x, y = batch.text, batch.label
            if USE_CUDA:
                x, y = x.cuda(), y.cuda()
            optimizer.zero_grad()
            m, l, z, decoded = vae(x, temperature)
            # Anneal the sampling temperature down to temperature_min.
            if temperature > temperature_min:
                temperature -= temperature_dec
            recon_loss = F.cross_entropy(decoded.view(-1, vocab_size),
                                         x.contiguous().view(-1))
            kl_loss = -0.5 * (2 * l - torch.pow(m, 2)
                              - torch.pow(torch.exp(l), 2) + 1)
            # Clamp ("free bits"-style floor) so KL never drops below 0.2.
            kl_loss = torch.clamp(kl_loss.mean(), min=0.2).squeeze()
            loss = recon_loss + kl_loss * kld_weight
            # Grow the KL weight after the first epoch, capped at kld_max.
            if epoch > 1 and kld_weight < kld_max:
                kld_weight += kld_inc
            loss.backward()
            ec = nn.utils.clip_grad_norm(vae.parameters(), args.grad_clip)
            optimizer.step()
            sys.stdout.write(
                '\r[%d] [loss] %.4f - recon_loss: %.4f - kl_loss: %.4f - kld-weight: %.4f - temp: %4f'
                % (b, loss.data[0], recon_loss.data[0], kl_loss.data[0],
                   kld_weight, temperature))
            print_loss_total += loss.data[0]
            if b % 200 == 0 and b != 0:
                print_loss_avg = print_loss_total / 200
                print_loss_total = 0
                # NOTE(review): this string literal was reconstructed from a
                # line wrapped across a chunk boundary — verify exact spacing.
                print("\n[avg loss] - ", print_loss_avg)
                # Show the greedy decode of the first sequence in the batch.
                _, sample = decoded.data.cpu()[:, 0, :].topk(1)
                print("[ORI]: ",
                      " ".join([TEXT.vocab.itos[i] for i in x.data[:, 0]]))
                print("[GEN]: ",
                      " ".join([TEXT.vocab.itos[i] for i in sample.squeeze()]))
        # NOTE(review): placement of this save (per epoch vs. per print
        # interval) is ambiguous in the collapsed source — confirm.
        torch.save(vae, './snapshot/vae_{}.pt'.format(epoch))
def main(args):
    """Combine the pitch of one piece with the rhythm of another through a
    trained VAE decoder and write the result as a MIDI file."""
    conf = None
    with open(args.config, 'r') as config_file:
        config = yaml.load(config_file, Loader=yaml.FullLoader)
        conf = config['combine']
        model_params = config['model']
        preprocess_params = config['preprocessor']

    date_time = time.strftime('%Y_%m_%d_%H_%M_%S', time.localtime())
    path = os.path.join(conf['save_path'], date_time)
    # NOTE(review): the timestamped path above is immediately discarded by
    # the next assignment (dead code) — confirm which destination is intended.
    path = conf['save_path']

    model = VAE(model_params['roll_dim'], model_params['hidden_dim'],
                model_params['infor_dim'], model_params['time_step'], 12)
    model.load_state_dict(torch.load(conf['model_path']))
    if torch.cuda.is_available():
        print('Using: ',
              torch.cuda.get_device_name(torch.cuda.current_device()))
        model.cuda()
    else:
        print('CPU mode')
    model.eval()

    pitch_path = conf['p_path'] + ".txt"
    rhythm_path = conf['r_path'] + ".txt"
    #chord_path = conf['chord_path'] + ".txt"
    name1 = pitch_path.split("/")[-3]
    name2 = rhythm_path.split("/")[-3]
    name = name1 + "+" + name2 + ".mid"
    # NOTE(review): name2 is rebuilt into a filename here, so the "Importing"
    # print below shows the combined .txt name, not the source piece name.
    name2 = name1 + "+" + name2 + ".txt"
    pitch = np.loadtxt(pitch_path)
    print(pitch)
    rhythm = np.loadtxt(rhythm_path)
    print(rhythm)
    print("Importing " + name1 + " pitch and " + name2 + " rhythm")
    #line_graph(pitch,rhythm)
    #bar_graph(pitch,rhythm)

    pitch = torch.from_numpy(pitch).float()
    rhythm = torch.from_numpy(rhythm).float()
    recon = model.decoder(pitch, rhythm)
    recon = torch.squeeze(recon, 0)
    recon = mf._sampling(recon)
    recon = np.array(recon.cpu().detach().numpy())
    # Truncate to the number of active rhythm steps.
    length = torch.sum(rhythm).int()
    recon = recon[:length]
    # Print the distribution of generated notes.
    note = recon[:, :-2]
    note = np.nonzero(note)[1]
    note = np.bincount(note, minlength=34).astype(float)
    recon = mf.modify_pianoroll_dimentions(recon,
                                           preprocess_params['low_crop'],
                                           preprocess_params['high_crop'],
                                           "add")
    #bar_graph(pitch,rhythm)
    mf.numpy_to_midi(recon, 120, path, name,
                     preprocess_params['smallest_note'])
    #pitch_rhythm(recon,path,name2) # write pitch information
    print("combine succeed")
    # NOTE(review): this span opens mid-call — the `datasets.MNIST(` these
    # arguments belong to lies outside the visible source.
    data_dir, train=True, download=True, transform=transforms.ToTensor()),
    batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST(data_dir, train=False, download=True,
                   transform=transforms.ToTensor()),
    batch_size=batch_size, shuffle=True)

# Seed CPU (and GPU when available) RNGs for reproducibility.
torch.manual_seed(seed)
if use_gpu:
    torch.cuda.manual_seed(seed)

model = VAE()
if use_gpu:
    model.cuda()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

loss_list = []
test_loss_list = []
for epoch in range(num_epochs + 1):
    # training
    train_loss = train(epoch, model, optimizer, train_loader)
    loss_list.append(train_loss)
    # test
    test_loss = test(epoch, model, test_loader)
    test_loss_list.append(test_loss)
    # NOTE(review): the format call below is cut off at the chunk boundary.
    print('epoch [{}/{}], loss: {:.4f}, test_loss: {:.4f}'.format(
class TrainingModel(object):
    """Joint trainer for a VAE topic model and a neural ranker ('IERM').

    NOTE(review): written for Python 2 (print statements, xrange, cPickle)
    and an older PyTorch API; several string literals below were
    reconstructed across chunk boundaries — verify exact text.
    """

    def __init__(self, args, config):
        # Expose every config entry as an attribute while keeping the dict.
        self.__dict__.update(config)
        self.config = config

        # Seed all RNGs for reproducibility.
        random.seed(self.seed)
        torch.manual_seed(self.seed)
        np.random.seed(self.seed)
        if use_cuda:
            torch.cuda.manual_seed(self.seed)
            torch.cuda.manual_seed_all(self.seed)
            torch.cuda.set_device(args.gpu)
            #torch.backends.cudnn.benchmark = False
            #torch.backends.cudnn.deterministic = True

        self.message = args.m
        self.data_generator = DataGenerator(self.config)
        self.vocab_size = self.data_generator.vocab_size
        self.ent_size = self.data_generator.ent_size
        self.model_name = 'IERM'
        if args.m != "":
            self.saveModeladdr = './trainModel/checkpoint_%s_%s.pkl' % (
                self.model_name, args.m)
        else:
            self.saveModeladdr = './trainModel/' + args.save

        self.model = Ranker(self.vocab_size, self.ent_size, self.config)
        # The VAE shares the ranker's word/entity embedding modules.
        self.VAE_model = VAE(self.vocab_size, self.ent_size,
                             self.model.word_emb, self.model.ent_emb,
                             self.config)
        if use_cuda:
            self.model.cuda()
            self.VAE_model.cuda()

        # Use the pretraining LR while pretraining steps remain.
        vae_lr = self.config[
            'pretrain_lr'] if config['pretrain_step'] > 0 else config['vae_lr']
        self.vae_optimizer = getOptimizer(config['vae_optim'],
                                          self.VAE_model.parameters(),
                                          lr=vae_lr,
                                          betas=(0.99, 0.99))
        self.ranker_optimizer = getOptimizer(
            config['ranker_optim'],
            self.model.parameters(),
            lr=config['ranker_lr'],
            weight_decay=config['weight_decay'])

        vae_model_size = sum(p.numel() for p in self.VAE_model.parameters())
        ranker_size = sum(p.numel() for p in self.model.parameters())
        #print 'Model size: ', vae_model_size, ranker_size
        #exit(-1)

        if args.resume and os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            #print checkpoint.keys()
            self.model.load_state_dict(checkpoint['rank_state_dict'])
            self.VAE_model.load_state_dict(checkpoint['vae_state_dict'])
            self.vae_optimizer.load_state_dict(checkpoint['vae_optimizer'])
            self.ranker_optimizer.load_state_dict(
                checkpoint['rank_optimizer'])
        else:
            print("Creating a new model")

        self.timings = defaultdict(list)
        #record the loss iterations
        self.evaluator = rank_eval()
        self.epoch = 0
        self.step = 0
        self.kl_weight = 1
        if args.visual:
            self.config['visual'] = True
            self.writer = SummaryWriter('runs/' + args.m)
        else:
            self.config['visual'] = False
        self.reconstr_loss = nn.MSELoss()

    def add_values(self, iter, value_dict):
        # Push a dict of scalar values to TensorBoard at the given iteration.
        for key in value_dict:
            self.writer.add_scalar(key, value_dict[key], iter)

    def adjust_learning_rate(self, optimizer, lr, decay_rate=.5):
        # Set every param group's LR to lr * decay_rate.
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr * decay_rate

    def kl_anneal_function(self, anneal_function, step, k=0.0025, x0=2500):
        # KL-annealing schedules: logistic or linear ramp toward 1.
        # (Returns None for any other anneal_function value.)
        if anneal_function == 'logistic':
            return float(1 / (1 + np.exp(-k * (step - x0))))
        elif anneal_function == 'linear':
            return min(1, step / x0)

    def vae_loss(self, input_qw, reconstr_w, input_qe, reconstr_e, prior_mean,
                 prior_var, posterior_mean, posterior_var, posterior_log_var):
        """Word/entity reconstruction terms plus Gaussian KL for the VAE.

        Returns (loss.sum(), KL.sum(), RL_w.sum(), RL_e.sum()).
        """
        # Reconstruction term
        if self.config['reconstruct'] != 'entity':
            input_qw_bow = to_bow(input_qw, self.vocab_size)
            input_qw_bow = Tensor2Varible(torch.tensor(input_qw_bow).float())
            #reconstr_w = torch.log_softmax(reconstr_w + 1e-10,dim=1)
            #RL_w = -torch.sum(input_qw_bow * reconstr_w , dim=1)
            #RL_w = self.reconstr_loss(reconstr_w,input_qw_bow)
            # Bernoulli-style term; assumes reconstr_w holds log-probs —
            # TODO(review) confirm against the VAE's output activation.
            RL_w = -torch.sum(
                input_qw_bow * reconstr_w +
                (1 - input_qw_bow) * torch.log(1 - torch.exp(reconstr_w)),
                dim=1)
        else:
            RL_w = Tensor2Varible(torch.tensor([0]).float())

        if self.config['reconstruct'] != 'word':
            input_qe_bow = to_bow(input_qe, self.ent_size)
            input_qe_bow = Tensor2Varible(torch.tensor(input_qe_bow).float())
            #RL_e = -torch.sum(input_qe_bow * reconstr_e, dim=1)
            #RL_e = self.reconstr_loss(reconstr_e,input_qe_bow)
            RL_e = -torch.sum(
                input_qe_bow * reconstr_e +
                (1 - input_qe_bow) * torch.log(1 - torch.exp(reconstr_e)),
                dim=1)
        else:
            RL_e = Tensor2Varible(torch.tensor([0]).float())

        # KL term
        # var division term
        var_division = torch.sum(posterior_var / prior_var, dim=1)
        # diff means term
        diff_means = prior_mean - posterior_mean
        diff_term = torch.sum((diff_means * diff_means) / prior_var, dim=1)
        # logvar det division term
        logvar_det_division = \
            prior_var.log().sum() - posterior_log_var.sum(dim=1)
        # combine terms
        KL = 0.5 * (var_division + diff_term - self.model.intent_num +
                    logvar_det_division)

        loss = self.kl_weight * KL + RL_w + RL_e
        #loss = 0.001 * KL + RL_w + RL_e
        return loss.sum(), KL.sum(), RL_w.sum(), RL_e.sum()

    def pretraining(self):
        """Pretrain only the VAE for self.pretrain_step batches, then save."""
        if self.pretrain_step <= 0:
            return
        train_start_time = time.time()
        data_reader = self.data_generator.pretrain_reader(self.pretrain_bs)
        total_loss = 0.
        total_KL_loss = 0.
        total_RLw_loss = 0.
        total_RLe_loss = 0.
        for step in xrange(self.pretrain_step):
            input_qw, input_qe = next(data_reader)
            #self.kl_weight = self.kl_anneal_function('logistic', step)
            # train_VAE computes the loss but does not step — done here.
            topic_e, vae_loss, kl_loss, rl_w_loss, rl_e_loss = self.train_VAE(
                input_qw, input_qe)
            vae_loss.backward()
            torch.nn.utils.clip_grad_value_(
                self.VAE_model.parameters(),
                self.clip_grad)  # clip_grad_norm(, )
            self.vae_optimizer.step()
            vae_loss = vae_loss.data
            #print ('VAE loss: %.3f\tKL: %.3f\tRL_w:%.3f\tRL_e:%.3f' % (vae_loss, kl_loss, rl_w_loss, rl_e_loss))
            if torch.isnan(vae_loss):
                print("Got NaN cost .. skipping")
                # NOTE(review): the `continue` after exit(-1) is unreachable.
                exit(-1)
                continue
            #if self.config['visual']:
            #    self.add_values(step, {'vae_loss': vae_loss, 'kl_loss': kl_loss, 'rl_w_loss': rl_w_loss,
            #                           'rl_e_loss': rl_e_loss, 'kl_weight': self.kl_weight})
            total_loss += vae_loss
            total_KL_loss += kl_loss
            total_RLw_loss += rl_w_loss
            total_RLe_loss += rl_e_loss
            # Periodically report (and reset) running averages.
            if step != 0 and step % self.pretrain_freq == 0:
                total_loss /= self.pretrain_freq
                total_KL_loss /= self.pretrain_freq
                total_RLw_loss /= self.pretrain_freq
                total_RLe_loss /= self.pretrain_freq
                print('Step: %d\t Elapsed:%.2f' %
                      (step, time.time() - train_start_time))
                print(
                    'Pretrain VAE loss: %.3f\tKL: %.3f\tRL_w:%.3f\tRL_e:%.3f'
                    % (total_loss, total_KL_loss, total_RLw_loss,
                       total_RLe_loss))
                if self.config['visual']:
                    self.add_values(
                        step, {
                            'vae_loss': total_loss,
                            'kl_loss': total_KL_loss,
                            'rl_w_loss': total_RLw_loss,
                            'rl_e_loss': total_RLe_loss,
                            'kl_weight': self.kl_weight
                        })
                total_loss = 0.
                total_KL_loss = 0.
                total_RLw_loss = 0.
                total_RLe_loss = 0.
        print '=============================================='
        #self.generate_beta_phi_3(show_topic_limit=5)
        self.save_checkpoint(message=self.message + '-pretraining')
        print('Pretraining end')
        #recovering the learning rate
        self.adjust_learning_rate(self.vae_optimizer, self.config['vae_lr'],
                                  1)

    def trainIters(self, ):
        """Main joint training loop with periodic validation, checkpointing
        and patience-based early stopping."""
        self.step = 0
        train_start_time = time.time()
        patience = self.patience
        best_ndcg10 = 0.0
        last_ndcg10 = 0.0
        data_reader = self.data_generator.pair_reader(self.batch_size)
        total_loss = 0.0
        total_rank_loss = 0.
        total_vae_loss = 0.
        total_KL_loss = 0.
        total_RLw_loss = 0.
        total_RLe_loss = 0.
        for step in xrange(self.steps):
            out = next(data_reader)
            input_qw, input_qe, input_dw_pos, input_de_pos, input_dw_neg, input_de_neg = out
            rank_loss, vae_total_loss, KL_loss, RL_w_loss, RL_e_loss \
                = self.train(input_qw,input_qe,input_dw_pos,input_de_pos,input_dw_neg,input_de_neg)
            cur_total_loss = rank_loss + vae_total_loss
            if torch.isnan(cur_total_loss):
                print("Got NaN cost .. skipping")
                continue
            self.step += 1
            total_loss += cur_total_loss
            total_rank_loss += rank_loss
            total_vae_loss += vae_total_loss
            total_KL_loss += KL_loss
            total_RLw_loss += RL_w_loss
            total_RLe_loss += RL_e_loss
            # Periodic validation; save on improvement, lose patience on
            # sufficiently-degraded ndcg@10.
            if self.eval_freq != -1 and self.step % self.eval_freq == 0:
                with torch.no_grad():
                    valid_performance = self.test(
                        valid_or_test='valid',
                        source=self.config['click_model'])
                current_ndcg10 = valid_performance['ndcg@10']
                if current_ndcg10 > best_ndcg10:
                    print 'Got better result, save to %s' % self.saveModeladdr
                    best_ndcg10 = current_ndcg10
                    patience = self.patience
                    self.save_checkpoint(message=self.message)
                    #self.generate_beta_phi_3(show_topic_limit=5)
                elif current_ndcg10 <= last_ndcg10 * self.cost_threshold:
                    patience -= 1
                last_ndcg10 = current_ndcg10
            # Periodic training-loss report (running averages, then reset).
            if self.step % self.train_freq == 0:
                total_loss /= self.train_freq
                total_rank_loss /= self.train_freq
                total_vae_loss /= self.train_freq
                total_KL_loss /= self.train_freq
                total_RLw_loss /= self.train_freq
                total_RLe_loss /= self.train_freq
                self.timings['train'].append(total_loss)
                print('Step: %d\t Elapsed:%.2f' %
                      (step, time.time() - train_start_time))
                print(
                    'Train total loss: %.3f\tRank loss: %.3f\tVAE loss: %.3f'
                    % (total_loss, total_rank_loss, total_vae_loss))
                print('KL loss: %.3f\tRL W: %.3f\tRL E: %.3f' %
                      (total_KL_loss, total_RLw_loss, total_RLe_loss))
                print('Patience left: %d' % patience)
                if self.config['visual']:
                    self.add_values(
                        step, {
                            'Train vae_loss': total_loss,
                            'Train kl_loss': total_KL_loss,
                            'Train rl_w_loss': total_RLw_loss,
                            'Train rl_e_loss': total_RLe_loss,
                            'Train Rank loss': total_rank_loss
                        })
                total_loss = 0
                total_rank_loss = 0.
                total_vae_loss = 0.
                total_KL_loss = 0.
                total_RLw_loss = 0.
                total_RLe_loss = 0.
            if patience < 0:
                print 'patience runs out...'
                break
        print 'Patience___: ', patience
        print("All done, exiting...")

    def test(self, valid_or_test, source):
        """Evaluate ranking on valid/test/NTCIR data; return averaged metrics
        and, for test runs, dump per-document predictions to ./results/."""
        predicted = []
        results = defaultdict(list)
        if valid_or_test == 'valid':
            is_test = False
            data_addr = self.valid_rank_addr
            data_source = self.data_generator.pointwise_reader_evaluation(
                data_addr, is_test=is_test, label_type=source)
        elif valid_or_test == 'ntcir13' or valid_or_test == 'ntcir14':
            is_test = True
            data_source = self.data_generator.pointwise_ntcir_generator(
                valid_or_test)
            source = 'HUMAN'
        else:
            is_test = True
            data_addr = self.test_rank_addr
            data_source = self.data_generator.pointwise_reader_evaluation(
                data_addr, is_test=is_test, label_type=source)
        start = time.clock()
        count = 0
        for out in data_source:
            (qid, dids, input_qw, input_qe, input_dw, input_de,
             gt_rels) = out
            gt_rels = map(lambda t: score2cutoff(source, t), gt_rels)
            rels_predicted = self.predict(input_qw, input_qe, input_dw,
                                          input_de).view(-1).cpu().numpy()
            result = self.evaluator.eval(gt_rels, rels_predicted)
            for did, gt, pred in zip(dids, gt_rels, rels_predicted):
                predicted.append((qid, did, pred, gt))
            for k, v in result.items():
                results[k].append(v)
            count += 1
        elapsed = (time.clock() - start)
        print('Elapsed:%.3f\tAvg:%.3f' % (elapsed, elapsed / count))
        # Average each metric over all evaluated queries.
        performances = {}
        for k, v in results.items():
            performances[k] = np.mean(v)
        print '------Source: %s\tPerformance-------:' % source
        # NOTE(review): line reconstructed across a chunk break.
        print 'Validating...' if valid_or_test == 'valid' else 'Testing'
        print 'Message: %s' % self.message
        print 'Source: %s' % source
        print performances
        if valid_or_test != 'valid':
            path = './results/' + self.message + '_' + valid_or_test + '_' + source
            if not os.path.exists(path):
                os.makedirs(path)
            out_file = open('%s/%s.predicted.txt' % (path, self.model_name),
                            'w')
            for qid, did, pred, gt in predicted:
                print >> out_file, '\t'.join([qid, did, str(pred), str(gt)])
        return performances

    def get_text(self, input, map_fun):
        # Map ids through map_fun until the first 0 id (treated as padding).
        text_list = []
        for element in input:
            if element == 0:
                break
            text_list.append(map_fun(element))
        return ' '.join(text_list)

    def generate_beta_phi_3(self, topK=10, show_topic_limit=-1):
        """Dump the topK words/entities per topic to
        ./topic/<args.m>/topic-words.txt; return (topics, topics_ents)."""
        beta, phi = self.VAE_model.infer_topic_dis(topK)
        topics = defaultdict(list)
        topics_ents = defaultdict(list)
        show_topic_num = self.config[
            'intent_num'] if show_topic_limit == -1 else show_topic_limit
        for i in range(show_topic_num):
            idxs = beta[i]
            eidxs = phi[i]
            component_words = [
                self.data_generator.id2word[idx]
                for idx in idxs.cpu().numpy()
            ]
            component_ents = [
                self.data_generator.id2ent[self.data_generator.new2old[idx]]
                for idx in eidxs.cpu().numpy()
            ]
            topics[i] = component_words
            topics_ents[i] = component_ents
        print '--------Topic-Word-------'
        prefix = ('./topic/%s/' % args.m)
        if not os.path.exists(prefix):
            os.makedirs(prefix)
        outfile = open(prefix + 'topic-words.txt', 'w')
        for k in topics:
            print >> outfile, (str(k) + ' : ' + ' '.join(topics[k]))
            print >> outfile, (str(k) + ' : ' + ' '.join(topics_ents[k]))
        return topics, topics_ents

    def run_test_topic(self, out_file_name, topK, topicNum):
        """Write, for each test query, its top-3 topics' words and entities."""
        topics_words, topics_ents = self.generate_beta_phi_3(topK)
        data_addr = self.test_rank_addr
        data_source = self.data_generator.pointwise_reader_evaluation(
            data_addr, is_test=True, label_type=self.config['click_model'])
        out_file = open(out_file_name, 'w')
        with torch.no_grad():
            self.VAE_model.eval()
            self.model.eval()
            for i, out in enumerate(data_source):
                (qid, dids, input_qw, input_qe, input_dw, input_de,
                 gt_rels) = out
                theta = self.VAE_model.get_theta(input_qw, input_qe)
                # Only the first query of the batch is inspected.
                input_qw = input_qw[0]
                input_qe = input_qe[0]
                input_w = self.get_text(
                    input_qw, lambda w: self.data_generator.id2word[w])
                input_e = self.get_text(
                    input_qe, lambda e: self.data_generator.id2ent[
                        self.data_generator.new2old[e]])
                theta = theta[0].data.cpu().numpy()
                # Indices of the 3 highest-probability topics.
                top_indices = np.argsort(theta)[::-1][:3]
                #print '========================='
                print >> out_file, 'Query: ', input_w
                print >> out_file, 'Entity: ', input_e
                for j, k in enumerate(top_indices):
                    ws = topics_words[k]
                    es = topics_ents[k]
                    print >> out_file, '%d Word Topic %d: %s' % (
                        j, k, ' '.join(ws))
                    print >> out_file, '%d Entity Topic %d: %s' % (
                        j, k, ' '.join(es))

    def generate_topic_word_ent(self, out_file, topK=10):
        """Write per-query original vs. reconstructed top words/entities."""
        print 'Visualizing ...'
        data_addr = self.test_rank_addr
        data_source = self.data_generator.pointwise_reader_evaluation(
            data_addr, is_test=True, label_type=self.config['click_model'])
        out_file = open(out_file, 'w')
        with torch.no_grad():
            self.VAE_model.eval()
            self.model.eval()
            for i, out in enumerate(data_source):
                (input_qw, input_qe, input_dw, input_de, gt_rels) = out
                _, word_indices, ent_indices = self.VAE_model.get_topic_words(
                    input_qw, input_qe, topK=topK)
                word_indices = word_indices[0].data.cpu().numpy()
                ent_indices = ent_indices[0].data.cpu().numpy()
                #print 'ent_indices: ', ent_indices
                #print 'word_indices: ', word_indices
                input_qw = input_qw[0]
                input_qe = input_qe[0]
                input_w = self.get_text(
                    input_qw, lambda w: self.data_generator.id2word[w])
                input_e = self.get_text(
                    input_qe, lambda e: self.data_generator.id2ent[
                        self.data_generator.new2old[e]])
                reconstuct_w = self.get_text(
                    word_indices, lambda w: self.data_generator.id2word[w])
                reconstuct_e = self.get_text(
                    ent_indices, lambda e: self.data_generator.id2ent[
                        self.data_generator.new2old[e]])
                print >> out_file, ('%d: Word: %s\tRecons: %s' %
                                    (i + 1, input_w, reconstuct_w))
                print >> out_file, ('%d: Ent: %s\tRecons: %s' %
                                    (i + 1, input_e, reconstuct_e))

    def train_VAE(self, input_qw, input_qe):
        """One VAE forward pass + loss; gradients are NOT stepped here —
        callers (pretraining / train) backprop and step themselves."""
        self.VAE_model.train()
        self.VAE_model.zero_grad()
        self.vae_optimizer.zero_grad()
        topic_embeddings, logPw, logPe, prior_mean, prior_variance,\
            poster_mu, poster_sigma, poster_log_sigma = self.VAE_model(input_qw,input_qe)
        vae_total_loss, KL, RL_w, RL_e = self.vae_loss(
            input_qw, logPw, input_qe, logPe, prior_mean, prior_variance,
            poster_mu, poster_sigma, poster_log_sigma)
        #vae_total_loss.backward(retain_graph=True)
        # `clip_grad_norm` helps prevent the exploding gradient problem in RNNs / LSTMs.
        #torch.nn.utils.clip_grad_value_(self.VAE_model.parameters(), self.clip_grad)  # clip_grad_norm(, )
        #self.vae_optimizer.step()
        return topic_embeddings, vae_total_loss, KL.data, RL_w.data, RL_e.data

    def train(self, input_qw, input_qe, input_dw_pos, input_de_pos,
              input_dw_neg, input_de_neg):
        """One joint step: VAE loss + pairwise hinge rank loss + orthogonality
        penalty; backprops once and steps both optimizers."""
        # Turn on training mode which enables dropout.
        self.model.train()
        self.model.zero_grad()
        self.ranker_optimizer.zero_grad()
        topic_embeddings, vae_total_loss, KL_loss, RL_w_loss, RL_e_loss = self.train_VAE(
            input_qw, input_qe)
        score_pos, orth_loss_1 = self.model(input_qw, input_qe, input_dw_pos,
                                            input_de_pos, topic_embeddings)
        score_neg, orth_loss_2 = self.model(input_qw, input_qe, input_dw_neg,
                                            input_de_neg, topic_embeddings)
        # Pairwise hinge loss with margin 1.
        rank_loss = torch.sum(torch.clamp(1.0 - score_pos + score_neg,
                                          min=0))
        vae_weight = self.config['intent_lambda']
        orth_loss = (orth_loss_1 + orth_loss_2) / 2
        total_loss = rank_loss + vae_weight * vae_total_loss + orth_loss
        total_loss.backward()
        ## update parameters
        # `clip_grad_norm` helps prevent the exploding gradient problem in RNNs / LSTMs.
        torch.nn.utils.clip_grad_value_(self.VAE_model.parameters(),
                                        self.clip_grad)  # clip_grad_norm(, )
        torch.nn.utils.clip_grad_value_(self.model.parameters(),
                                        self.clip_grad)  #clip_grad_norm(, )
        self.ranker_optimizer.step()
        self.vae_optimizer.step()
        return rank_loss.data, vae_total_loss.data, KL_loss, RL_w_loss, RL_e_loss

    def predict(self, input_qw, input_qe, input_dw, input_de):
        """Score documents for a query in eval mode (no dropout, no grads)."""
        # Turn on evaluation mode which disables dropout.
        with torch.no_grad():
            self.VAE_model.eval()
            self.model.eval()
            topic_embeddings = self.VAE_model(input_qw, input_qe)
            rels_predicted, _ = self.model(input_qw, input_qe, input_dw,
                                           input_de, topic_embeddings)
        return rels_predicted

    def save_checkpoint(self, message):
        # Persist both models and both optimizers to self.saveModeladdr.
        # NOTE(review): `message` is unused here — confirm intended.
        filePath = os.path.join(self.saveModeladdr)
        #if not os.path.exists(filePath):
        #    os.makedirs(filePath)
        torch.save(
            {
                'vae_state_dict': self.VAE_model.state_dict(),
                'rank_state_dict': self.model.state_dict(),
                'vae_optimizer': self.vae_optimizer.state_dict(),
                'rank_optimizer': self.ranker_optimizer.state_dict()
            }, filePath)

    def get_embeddings(self):
        # Dump word/entity/topic embeddings to a pickle for offline analysis.
        word_embeddings = self.model.word_emb.weight.detach().cpu().numpy()
        ent_embeddings = self.model.ent_emb.weight.detach().cpu().numpy()
        topic_embeddings = self.model.topic_embedding.detach().cpu().numpy()
        print 'Topic size: ', topic_embeddings.shape[0]
        cPickle.dump((word_embeddings, ent_embeddings, topic_embeddings),
                     open('./topic_analysis/w_e_t_embedding.pkl', 'w'))
        print 'saved'
        return
class Runner(object):
    """Wraps a VAE with its optimizer, LR scheduler, device setup, and a
    combined train/valid/test loop that logs to TensorBoard."""

    def __init__(self, hparams, train_size: int,
                 class_weight: Optional[Tensor] = None):
        # NOTE(review): train_size and class_weight are unused in the visible
        # body — confirm they are consumed elsewhere.
        # model, criterion
        self.model = VAE()
        # optimizer and scheduler
        self.optimizer = torch.optim.Adam(self.model.parameters(),
                                          lr=hparams.learning_rate,
                                          eps=hparams.eps,
                                          weight_decay=hparams.weight_decay)
        self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            self.optimizer, **hparams.scheduler)
        # Element-wise BCE so per-sample reconstruction error can be kept.
        self.bce = nn.BCEWithLogitsLoss(reduction='none')
        # self.kld = nn.KLDivLoss(reduction='sum')

        # device
        device_for_summary = self.__init_device(hparams.device,
                                                hparams.out_device)

        # summary
        self.writer = SummaryWriter(logdir=hparams.logdir)
        # TODO: fill in ~~DUMMY~~INPUT~~SIZE~~
        path_summary = Path(self.writer.logdir, 'summary.txt')
        if not path_summary.exists():
            print_to_file(path_summary, summary, (self.model, (40, 11)),
                          dict(device=device_for_summary))

        # save hyperparameters
        path_hparam = Path(self.writer.logdir, 'hparams.txt')
        if not path_hparam.exists():
            print_to_file(path_hparam, hparams.print_params)

    def __init_device(self, device, out_device):
        """Resolve device specs (int, str, or sequence) into in/out torch
        devices, wrapping in DataParallel for multi-GPU; returns 'cpu' or
        'cuda' for the model summary."""
        if device == 'cpu':
            self.in_device = torch.device('cpu')
            self.out_device = torch.device('cpu')
            self.str_device = 'cpu'
            return 'cpu'

        # device type: List[int]
        if type(device) == int:
            device = [device]
        elif type(device) == str:
            # e.g. 'cuda:1' -> [1]; takes only the last character.
            device = [int(device[-1])]
        else:  # sequence of devices
            if type(device[0]) != int:
                device = [int(d[-1]) for d in device]

        self.in_device = torch.device(f'cuda:{device[0]}')
        if len(device) > 1:
            if type(out_device) == int:
                self.out_device = torch.device(f'cuda:{out_device}')
            else:
                self.out_device = torch.device(out_device)
            self.str_device = ', '.join([f'cuda:{d}' for d in device])
            self.model = nn.DataParallel(self.model,
                                         device_ids=device,
                                         output_device=self.out_device)
        else:
            self.out_device = self.in_device
            self.str_device = str(self.in_device)

        self.model.cuda(self.in_device)
        self.bce.cuda(self.out_device)
        ## torch.cuda.set_device(self.in_device)
        return 'cuda'

    # Running model for train, test and validation.
    def run(self, dataloader, mode: str, epoch: int):
        """Run one epoch in 'train', 'valid' or 'test' mode.

        Returns (avg_loss, (y, y_est, pred_prob)); the prediction arrays are
        populated only in test mode.
        """
        self.model.train() if mode == 'train' else self.model.eval()
        if mode == 'test':
            # Restore the checkpoint for the requested epoch before testing.
            state_dict = torch.load(Path(self.writer.logdir, f'{epoch}.pt'),
                                    map_location='cpu')
            if isinstance(self.model, nn.DataParallel):
                self.model.module.load_state_dict(state_dict)
            else:
                self.model.load_state_dict(state_dict)
            path_test_result = Path(self.writer.logdir, f'test_{epoch}')
            os.makedirs(path_test_result, exist_ok=True)
        else:
            path_test_result = None

        avg_loss = 0.
        y = []
        y_est = []
        pred_prob = []

        pbar = tqdm(dataloader,
                    desc=f'{mode} {epoch:3d}',
                    postfix='-',
                    dynamic_ncols=True)
        for i_batch, batch in enumerate(pbar):
            # data
            x = batch['batch_x']
            x = x.to(self.in_device)  # B, F, T

            # forward
            reconstruct_x, mu, logvar = self.model(x)

            # loss
            BCE = self.bce(reconstruct_x, x.view(-1, 440)).mean(dim=1)  # (B,)
            if mode != 'test':
                # VAE objective: reconstruction + KL divergence.
                loss = torch.mean(BCE - 0.5 * torch.sum(
                    1 + logvar - mu.pow(2) - logvar.exp(), dim=1))
            else:
                loss = 0.

            if mode == 'train':
                # backward
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
                loss = loss.item()
            elif mode == 'valid':
                loss = loss.item()
            else:
                y += batch['batch_y']
                # NOTE(review): BCE is a reconstruction error here, so the
                # 0.5 threshold classifies low-error samples as the positive
                # class — confirm the intended label convention.
                y_est += (BCE < 0.5).int().tolist()
                pred_prob += BCE.tolist()
                pbar.set_postfix_str('')

            avg_loss += loss

        avg_loss = avg_loss / len(dataloader.dataset)
        y = np.array(y)
        y_est = np.array(y_est)
        pred_prob = np.array(pred_prob, dtype=np.float32)
        return avg_loss, (y, y_est, pred_prob)

    def step(self, valid_loss: float, epoch: int):
        """
        :param valid_loss:
        :param epoch:
        :return: test epoch or 0
        """
        # self.scheduler.step()
        self.scheduler.step(valid_loss)

        # print learning rate
        for param_group in self.optimizer.param_groups:
            self.writer.add_scalar('learning rate', param_group['lr'], epoch)

        # Checkpoint every 5 epochs.
        # NOTE(review): the state dict is wrapped in a 1-tuple before saving,
        # and module-level `hparams` (not instance state) provides logdir —
        # confirm both are intended.
        if epoch % 5 == 0:
            torch.save((self.model.module.state_dict() if isinstance(
                self.model, nn.DataParallel) else self.model.state_dict(), ),
                       Path(hparams.logdir) / f'VAE_{epoch}.pt')
        return 0
#if epoch % 100 == 99: #disentangle_check(session, vae, data_manager) data_manager = DataManager() data_manager.prepare() dae = DAE() vae = VAE() if use_cuda: dae.load_state_dict('save/dae/dae_epoch_2999.pth') else: dae.load_state_dict(torch.load('save/dae/dae_epoch_2999.pth', map_location=lambda storage, loc: storage)) if opt.load != '': print('loading {}'.format(opt.load)) if use_cuda: vae.load_state_dict(torch.load(exp+'/'+opt.load)) else: vae.load_state_dict(torch.load(exp+'/'+opt.load, map_location=lambda storage, loc: storage)) if use_cuda: dae, vae = dae.cuda(), vae.cuda() if opt.train: vae_optimizer = optim.Adam(vae.parameters(), lr=1e-4, eps=1e-8) train_vae(dae, vae, data_manager, vae_optimizer)
z_mean2[ri][i] = z_m[i] z_mean2 = Variable(z_mean2) if use_cuda: z_mean2 = z_mean2.cuda() generated_xs_v = vae.decode(z_mean2) generated_xs = dae(generated_xs_v) file_name = "disentangle_img/check_z{0}.png".format(target_z_index) generated_xs = torch.transpose(generated_xs,2,3) generated_xs = torch.transpose(generated_xs,1,3) if use_cuda: hsv_image = generated_xs.data.cpu().numpy() else: hsv_image = generated_xs.data.numpy() print(hsv_image[0].shape) save_10_images(hsv_image, file_name) data_manager = DataManager() data_manager.prepare() vae = VAE() dae = DAE() if use_cuda: dae.load_state_dict(torch.load('save/dae/dae_epoch_2999.pth')) vae = vae.cuda() dae = dae.cuda() vae.load_state_dict(torch.load('save/vae/vae_epoch_2900.pth')) else: dae.load_state_dict(torch.load('save/dae/dae_epoch_2999.pth', map_location=lambda storage, loc: storage)) vae.load_state_dict(torch.load('save/vae/vae_epoch_2900.pth', map_location=lambda storage, loc: storage)) disentangle_check(dae, vae, data_manager)
def main():
    """Train a music VAE on .npy note matrices with manual zero-padded batching."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--hidden", '-hid', type=int, default=768, help="hidden state dimension")
    parser.add_argument('--epochs', '-e', type=int, default=5, help="number of epochs")
    parser.add_argument('--learning_rate', '-lr', type=float, default=1e-4, help="learning rate")
    parser.add_argument('--grudim', '-gd', type=int, default=1024, help='dimension for gru layer')
    parser.add_argument('--batch_size', '-b', type=int, default=64, help='input batch size for training')
    parser.add_argument('--name', '-n', type=str, default='embedded', help='tensorboard visual name')
    parser.add_argument('--decay', '-d', type=float, default=-1, help='learning rate decay: Gamma')
    parser.add_argument('--beta', type=float, default=0.1, help='beta for kld')
    parser.add_argument('--data', type=int, default=1000, help='how many pieces of music to use')
    args = parser.parse_args()

    hidden_dim = args.hidden
    epochs = args.epochs
    gru_dim = args.grudim
    learning_rate = args.learning_rate
    batch_size = args.batch_size
    decay = args.decay
    beta = args.beta
    data_num = args.data

    folder_name = "hid%d_e%d_gru%d_lr%.4f_batch%d_decay%.4f_beta%.2f_data%d" % (
        hidden_dim, epochs, gru_dim, learning_rate, batch_size, decay, beta, data_num)
    writer = SummaryWriter('../logs/{}'.format(folder_name))

    # load data; the note dimension is taken from the first file
    file_list = find('*.npy', data_dir)
    f = np.load(data_dir + file_list[0])
    note_dim = f.shape[1]

    model = VAE(note_dim, gru_dim, hidden_dim, batch_size)
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    if decay > 0:
        scheduler = MinExponentialLR(optimizer, gamma=decay, minimum=1e-5)

    step = 0
    if torch.cuda.is_available():
        print('Using: ', torch.cuda.get_device_name(torch.cuda.current_device()))
        model.cuda()
    else:
        print('CPU mode')

    # BUGFIX: the original iterated range(1, epochs), training one epoch fewer
    # than requested (--epochs 5 ran only 4 epochs).
    for epoch in range(1, epochs + 1):
        print("#" * 5, epoch, "#" * 5)
        batch_data = []
        batch_num = 0
        max_len = 0
        for i in range(len(file_list)):
            # flush a batch every batch_size-th file, at the end of the list,
            # or when the data budget (data_num) is reached
            if i != 0 and i % batch_size == 0 or i == len(
                    file_list) - 1 or i == data_num:
                # create a batch by zero padding
                print("#" * 5, "batch", batch_num)
                # NOTE(review): this permanently shrinks batch_size for all
                # later epochs and yields 0 when batch_size divides
                # len(file_list) — confirm intent before changing.
                if (i == len(file_list) - 1):
                    batch_size = len(file_list) % batch_size
                seq_lengths = LongTensor(list(map(len, batch_data)))
                print(seq_lengths.size())
                max_len = torch.max(seq_lengths).item()
                print("max_len:", max_len)
                batch = np.zeros((max_len, batch_size, note_dim))
                for j in range(len(batch_data)):
                    batch[:batch_data[j].shape[0], j, :] = batch_data[j]
                batch = torch.from_numpy(batch)
                # sort by length (descending) for packed-sequence processing
                seq_lengths, perm_idx = seq_lengths.sort(0, descending=True)
                batch = batch[:, perm_idx, :]
                step = train(model, batch, seq_lengths, step, optimizer, beta, writer)
                # reset accumulators for the next batch
                max_len = 0
                batch_data = []
                if decay > 0:
                    scheduler.step()
                batch_num += 1
            data = np.load(data_dir + file_list[i])
            batch_data.append(data)
            if i == data_num:
                break

        print("# saving params")
        param_name = "hid%d_e%d_gru%d_lr%.4f_batch%d_decay%.4f_beta%.2f_data%d_epoch%d" % (
            hidden_dim, epochs, gru_dim, learning_rate, batch_size, decay, beta,
            data_num, epoch)
        save_path = '../params/{}.pt'.format(param_name)
        # BUGFIX: the original checked/created './params' while saving into
        # '../params', so torch.save failed whenever '../params' was missing.
        os.makedirs(os.path.dirname(save_path), exist_ok=True)
        if torch.cuda.is_available():
            # checkpoint from CPU for portability, then move back to GPU
            torch.save(model.cpu().state_dict(), save_path)
            model.cuda()
        else:
            torch.save(model.state_dict(), save_path)
        print('# Model saved!')

    writer.close()
if args.resume is not None: checkpoint = torch.load(args.resume, map_location=lambda storage, loc: storage) print("checkpoint loaded!") print("val loss: {}\tepoch: {}\t".format(checkpoint['val_loss'], checkpoint['epoch'])) # model model = VAE(args.image_size) if args.resume is not None: model.load_state_dict(checkpoint['state_dict']) # criterion criterion = VAELoss(size_average=True, kl_weight=args.kl_weight) if args.cuda is True: model = model.cuda() criterion = criterion.cuda() # load data train_loader, val_loader = load_vae_train_datasets(input_size=args.image_size, data=args.data, batch_size=args.batch_size) # load optimizer and scheduler opt = torch.optim.Adam(params=model.parameters(), lr=args.lr, betas=(0.9, 0.999)) if args.resume is not None and not args.reset_opt: opt.load_state_dict(checkpoint['optimizer']) scheduler = torch.optim.lr_scheduler.MultiStepLR(opt,
'{}/recomb_epoch_{}.pth'.format(exp, epoch)) data_manager = DataManager() data_manager.prepare() dae = DAE() vae = VAE() scan = SCAN() recomb = Recombinator() if use_cuda: dae.load_state_dict(torch.load('save/dae/dae_epoch_2999.pth')) vae.load_state_dict(torch.load('save/vae/vae_epoch_2999.pth')) scan.load_state_dict(torch.load('save/scan/scan_epoch_1499.pth')) dae, vae, scan, recomb = dae.cuda(), vae.cuda(), scan.cuda(), recomb.cuda() else: dae.load_state_dict( torch.load('save/dae/dae_epoch_2999.pth', map_location=lambda storage, loc: storage)) vae.load_state_dict( torch.load('save/vae/vae_epoch_2999.pth', map_location=lambda storage, loc: storage)) scan.load_state_dict( torch.load('save/scan/scan_epoch_1499.pth', map_location=lambda storage, loc: storage)) recomb.load_state_dict( torch.load(exp + '/' + opt.load, map_location=lambda storage, loc: storage)) if opt.train:
netImage = VAE(latent_variable_size=args.latent_dims, batchnorm=True) netImage.load_state_dict(torch.load(args.pretrained_file)) print("Pre-trained model loaded from %s" % args.pretrained_file) if args.conditional_adv: netClf = FC_Classifier(nz=args.latent_dims + 10) assert not args.conditional else: netClf = FC_Classifier(nz=args.latent_dims) if args.conditional: netCondClf = Simple_Classifier(nz=args.latent_dims) if args.use_gpu: netRNA.cuda() netImage.cuda() netClf.cuda() if args.conditional: netCondClf.cuda() # load data genomics_dataset = RNA_Dataset(datadir="data/nCD4_gene_exp_matrices/") image_dataset = NucleiDataset(datadir="data/nuclear_crops_all_experiments", mode="test") image_loader = torch.utils.data.DataLoader(image_dataset, batch_size=args.batch_size, drop_last=True, shuffle=True) genomics_loader = torch.utils.data.DataLoader(genomics_dataset, batch_size=args.batch_size,
class Trainer(object):
    """MNIST VAE trainer (legacy pre-0.4 PyTorch: Variable, volatile, .data[0])."""

    def __init__(self, args):
        # Seed CPU (and GPU, when used) RNGs for reproducibility.
        self.args = args
        torch.manual_seed(self.args.seed)
        if self.args.cuda:
            torch.cuda.manual_seed(self.args.seed)
        # pin_memory/workers only matter for the CUDA path
        kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {}
        train_loader = torch.utils.data.DataLoader(
            datasets.MNIST('./data', train=True, download=True,
                           transform=transforms.ToTensor()),
            batch_size=self.args.batch_size, shuffle=True, **kwargs)
        test_loader = torch.utils.data.DataLoader(
            datasets.MNIST('./data', train=False,
                           transform=transforms.ToTensor()),
            batch_size=self.args.batch_size, shuffle=True, **kwargs)
        self.train_loader = train_loader
        self.test_loader = test_loader
        self.model = VAE()
        if self.args.cuda:
            self.model.cuda()
        self.optimizer = optim.Adam(self.model.parameters(), lr=1e-3)

    def loss_function(self, recon_x, x, mu, logvar):
        """VAE objective: pixel-wise BCE plus KL(q(z|x) || N(0, I)).

        KLD is normalized by batch_size * 784 to match BCE's mean reduction.
        """
        BCE = F.binary_cross_entropy(recon_x, x.view(-1, 784))
        KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        KLD /= self.args.batch_size * 784
        return BCE + KLD

    def train_one_epoch(self, epoch):
        """Run one optimization pass over the training set and log progress."""
        train_loader = self.train_loader
        args = self.args
        self.model.train()
        train_loss = 0
        for batch_idx, (data, _) in enumerate(train_loader):
            data = Variable(data)
            if args.cuda:
                data = data.cuda()
            self.optimizer.zero_grad()
            recon_batch, mu, logvar = self.model(data)
            loss = self.loss_function(recon_batch, data, mu, logvar)
            loss.backward()
            train_loss += loss.data[0]
            self.optimizer.step()
            if batch_idx % args.log_interval == 0:
                # NOTE(review): '{:0f}' is likely a typo for '{:.0f}' — it
                # prints the full float rather than a rounded percentage.
                print('Train Epoch: {} [{}/{} ({:0f}%)]\tLoss: {:.6f}'.format(
                    epoch, batch_idx * len(data), len(train_loader.dataset),
                    100. * batch_idx / len(train_loader),
                    loss.data[0] / len(data)))
        print('=====> Epoch: {} Average loss: {:.4f}'.format(
            epoch, train_loss / len(train_loader.dataset)))

    def test(self, epoch):
        """Evaluate on the test set; save a reconstruction grid for batch 0."""
        test_loader = self.test_loader
        args = self.args
        self.model.eval()
        test_loss = 0
        for i, (data, _) in enumerate(test_loader):
            if args.cuda:
                data = data.cuda()
            # volatile=True: legacy no-grad inference mode
            data = Variable(data, volatile=True)
            recon_batch, mu, logvar = self.model(data)
            test_loss += self.loss_function(recon_batch, data, mu, logvar).data[0]
            if i == 0:
                n = min(data.size(0), 8)
                # NOTE(review): view(args.batch_size, ...) assumes a full
                # batch; a smaller final batch would make this view fail.
                comparison = torch.cat([
                    data[:n],
                    recon_batch.view(args.batch_size, 1, 28, 28)[:n]
                ])
                fname = 'results/reconstruction_' + str(epoch) + '.png'
                save_image(comparison.data.cpu(), fname, nrow=n)
        test_loss /= len(test_loader.dataset)
        print('=====> Test set loss: {:.4f}'.format(test_loss))

    def train(self):
        """Full loop: per epoch, train, test, and save a 64-sample decode grid."""
        args = self.args
        for epoch in range(1, args.epochs + 1):
            self.train_one_epoch(epoch)
            self.test(epoch)
            # sample from the prior (z dim presumed 20 — TODO confirm vs VAE)
            sample = Variable(torch.randn(64, 20))
            if args.cuda:
                sample = sample.cuda()
            sample = self.model.decode(sample).cpu()
            save_image(sample.data.view(64, 1, 28, 28),
                       './results/sample_' + str(epoch) + '.png')
train_loader = torch.utils.data.DataLoader(datasets.MNIST( '../data', train=True, download=True, transform=transforms.ToTensor()), batch_size=args.batch_size, shuffle=True) test_loader = torch.utils.data.DataLoader( datasets.MNIST('../data', train=False, download=True, transform=transforms.ToTensor()), batch_size=args.batch_size, shuffle=True) vae = VAE(n_latents=args.n_latents) if args.cuda: vae.cuda() optimizer = optim.Adam(vae.parameters(), lr=args.lr) def train(epoch): vae.train() loss_meter = AverageMeter() for batch_idx, (data, _) in enumerate(train_loader): data = Variable(data) if args.cuda: data = data.cuda() optimizer.zero_grad() recon_batch, mu, logvar = vae(data) loss = loss_function(recon_batch, data, mu, logvar) loss.backward()
def main():
    """Train an actor-critic policy on a gym env, optionally on top of a
    VAE/CVAE latent state (the 'a|z(...)' experiments)."""
    time_str = time.strftime("%Y%m%d-%H%M%S")
    print('time_str: ', time_str)
    exp_count = 0
    # Experiment directory: env + experiment (+ bp2VAE flag for latent runs).
    if args.experiment == 'a|s':
        direc_name_ = '_'.join([args.env, args.experiment])
    else:
        direc_name_ = '_'.join(
            [args.env, args.experiment, 'bp2VAE', str(args.bp2VAE)])
    # Find the first unused numbered subdirectory.
    direc_name_exist = True
    while direc_name_exist:
        exp_count += 1
        direc_name = '/'.join([direc_name_, str(exp_count)])
        direc_name_exist = os.path.exists(direc_name)
    try:
        os.makedirs(direc_name)
    except OSError as e:
        if e.errno != errno.EEXIST:
            raise
    if args.tensorboard_dir is None:
        logger = Logger('/'.join([direc_name, time_str]))
    else:
        logger = Logger(args.tensorboard_dir)
    env = gym.make(args.env)
    if args.wrapper:
        if args.video_dir is None:
            args.video_dir = '/'.join([direc_name, 'videos'])
        env = gym.wrappers.Monitor(env, args.video_dir, force=True)
    print('observation_space: ', env.observation_space)
    print('action_space: ', env.action_space)
    env.seed(args.seed)
    torch.manual_seed(args.seed)
    # Policy input: raw observation for 'a|s', otherwise the VAE latent dim.
    if args.experiment == 'a|s':
        dim_x = env.observation_space.shape[0]
    elif args.experiment == 'a|z(s)' or args.experiment == 'a|z(s, s_next)' or \
            args.experiment == 'a|z(a_prev, s, s_next)':
        dim_x = args.z_dim
    policy = ActorCritic(input_size=dim_x,
                         hidden1_size=3 * dim_x,
                         hidden2_size=6 * dim_x,
                         action_size=env.action_space.n)
    if args.use_cuda:
        Tensor = torch.cuda.FloatTensor
        torch.cuda.manual_seed_all(args.seed)
        policy.cuda()
    else:
        Tensor = torch.FloatTensor
    policy_optimizer = optim.Adam(policy.parameters(), lr=args.policy_lr)
    # Latent experiments additionally need a VAE/CVAE and a replay buffer.
    if args.experiment != 'a|s':
        from util import ReplayBuffer, vae_loss_function
        dim_s = env.observation_space.shape[0]
        if args.experiment == 'a|z(s)' or args.experiment == 'a|z(s, s_next)':
            from model import VAE
            vae = VAE(input_size=dim_s,
                      hidden1_size=3 * args.z_dim,
                      hidden2_size=args.z_dim)
        elif args.experiment == 'a|z(a_prev, s, s_next)':
            from model import CVAE
            vae = CVAE(input_size=dim_s,
                       class_size=1,
                       hidden1_size=3 * args.z_dim,
                       hidden2_size=args.z_dim)
        if args.use_cuda:
            vae.cuda()
        vae_optimizer = optim.Adam(vae.parameters(), lr=args.vae_lr)
        if args.experiment == 'a|z(s)':
            from util import Transition_S2S as Transition
        elif args.experiment == 'a|z(s, s_next)' or args.experiment == 'a|z(a_prev, s, s_next)':
            from util import Transition_S2SNext as Transition
        buffer = ReplayBuffer(args.buffer_capacity, Transition)
        update_vae = True
    if args.experiment == 'a|s':
        from util import Record_S
    elif args.experiment == 'a|z(s)':
        from util import Record_S2S
    elif args.experiment == 'a|z(s, s_next)' or args.experiment == 'a|z(a_prev, s, s_next)':
        from util import Record_S2SNext

    def train_actor_critic(n):
        # One policy/value update from the rewards and saved (log_prob, value)
        # pairs accumulated over the finished episode n.
        saved_info = policy.saved_info
        # Discounted returns, computed backwards over the episode.
        R = 0
        cum_returns_ = []
        for r in policy.rewards[::-1]:
            R = r + args.gamma * R
            cum_returns_.insert(0, R)
        cum_returns = Tensor(cum_returns_)
        # Standardize returns for stability.
        cum_returns = (cum_returns - cum_returns.mean()) \
            / (cum_returns.std() + np.finfo(np.float32).eps)
        cum_returns = Variable(cum_returns, requires_grad=False).unsqueeze(1)
        batch_info = SavedInfo(*zip(*saved_info))
        batch_log_prob = torch.cat(batch_info.log_prob)
        batch_value = torch.cat(batch_info.value)
        batch_adv = cum_returns - batch_value
        policy_loss = -torch.sum(batch_log_prob * batch_adv)
        value_loss = F.smooth_l1_loss(batch_value, cum_returns,
                                      size_average=False)
        policy_optimizer.zero_grad()
        total_loss = policy_loss + value_loss
        total_loss.backward()
        policy_optimizer.step()
        # legacy (pre-0.4) PyTorch: scalar extraction via .data[0]
        if args.use_cuda:
            logger.scalar_summary('value_loss', value_loss.data.cpu()[0], n)
            logger.scalar_summary('policy_loss', policy_loss.data.cpu()[0], n)
            all_value_loss.append(value_loss.data.cpu()[0])
            all_policy_loss.append(policy_loss.data.cpu()[0])
        else:
            logger.scalar_summary('value_loss', value_loss.data[0], n)
            logger.scalar_summary('policy_loss', policy_loss.data[0], n)
            all_value_loss.append(value_loss.data[0])
            all_policy_loss.append(policy_loss.data[0])
        # Clear episode buffers in place.
        del policy.rewards[:]
        del policy.saved_info[:]

    if args.experiment != 'a|s':
        def train_vae(n):
            # Several VAE updates from replay-buffer samples; n is the episode
            # index, used to derive a monotone global step for logging.
            train_times = (n // args.vae_update_frequency - 1) * args.vae_update_times
            for i in range(args.vae_update_times):
                train_times += 1
                sample = buffer.sample(args.batch_size)
                batch = Transition(*zip(*sample))
                state_batch = torch.cat(batch.state)
                if args.experiment == 'a|z(s)':
                    # reconstruct the state itself
                    recon_batch, mu, log_var = vae.forward(state_batch)
                    mse_loss, kl_loss = vae_loss_function(
                        recon_batch, state_batch, mu, log_var, logger,
                        train_times, kl_discount=args.kl_weight,
                        mode=args.experiment)
                elif args.experiment == 'a|z(s, s_next)' or args.experiment == 'a|z(a_prev, s, s_next)':
                    # predict the next state from the current one
                    next_state_batch = Variable(torch.cat(batch.next_state),
                                                requires_grad=False)
                    predicted_batch, mu, log_var = vae.forward(state_batch)
                    mse_loss, kl_loss = vae_loss_function(
                        predicted_batch, next_state_batch, mu, log_var, logger,
                        train_times, kl_discount=args.kl_weight,
                        mode=args.experiment)
                vae_loss = mse_loss + kl_loss
                vae_optimizer.zero_grad()
                vae_loss.backward()
                vae_optimizer.step()
                logger.scalar_summary('vae_loss', vae_loss.data[0], train_times)
                all_vae_loss.append(vae_loss.data[0])
                all_mse_loss.append(mse_loss.data[0])
                all_kl_loss.append(kl_loss.data[0])

    # To store cum_reward, value_loss and policy_loss from each episode
    all_cum_reward = []
    all_last_hundred_average = []
    all_value_loss = []
    all_policy_loss = []
    if args.experiment != 'a|s':
        # Store each vae_loss calculated
        all_vae_loss = []
        all_mse_loss = []
        all_kl_loss = []
    for episode in count(1):
        done = False
        state_ = torch.Tensor([env.reset()])
        cum_reward = 0
        if args.experiment == 'a|z(a_prev, s, s_next)':
            # bootstrap with one random action so a_prev exists; the action is
            # appended to the state vector
            action = random.randint(0, 2)
            state_, reward, done, info = env.step(action)
            cum_reward += reward
            state_ = torch.Tensor([np.append(state_, action)])
        while not done:
            if args.experiment == 'a|s':
                state = Variable(state_, requires_grad=False)
            elif args.experiment == 'a|z(s)' or args.experiment == 'a|z(s, s_next)' \
                    or args.experiment == 'a|z(a_prev, s, s_next)':
                # encode the observation; optionally let policy gradients flow
                # back into the VAE while it is still being updated
                state_ = Variable(state_, requires_grad=False)
                mu, log_var = vae.encode(state_)
                if args.bp2VAE and update_vae:
                    state = vae.reparametrize(mu, log_var)
                else:
                    state = vae.reparametrize(mu, log_var).detach()
            action_ = policy.select_action(state)
            if args.use_cuda:
                action = action_.cpu()[0, 0]
            else:
                action = action_[0, 0]
            next_state_, reward, done, info = env.step(action)
            next_state_ = torch.Tensor([next_state_])
            cum_reward += reward
            if args.render:
                env.render()
            policy.rewards.append(reward)
            # Push transitions for the VAE's replay buffer.
            if args.experiment == 'a|z(s)':
                buffer.push(state_)
            elif args.experiment == 'a|z(s, s_next)' or args.experiment == 'a|z(a_prev, s, s_next)':
                if not done:
                    buffer.push(state_, next_state_)
                if args.experiment == 'a|z(a_prev, s, s_next)':
                    next_state_ = torch.cat(
                        [next_state_, torch.Tensor([action])], 1)
            state_ = next_state_
        train_actor_critic(episode)
        last_hundred_average = sum(all_cum_reward[-100:]) / 100
        logger.scalar_summary('cum_reward', cum_reward, episode)
        logger.scalar_summary('last_hundred_average', last_hundred_average,
                              episode)
        all_cum_reward.append(cum_reward)
        all_last_hundred_average.append(last_hundred_average)
        # NOTE(review): for experiment 'a|s', update_vae is never assigned, so
        # this line raises NameError on the first episode — confirm.
        if update_vae:
            if args.experiment != 'a|s' and episode % args.vae_update_frequency == 0:
                assert len(buffer) >= args.batch_size
                train_vae(episode)
                # Stop updating the VAE once its loss plateaus.
                if len(all_vae_loss) > 1000:
                    if abs(
                            sum(all_vae_loss[-500:]) / 500 -
                            sum(all_vae_loss[-1000:-500]) / 500) < args.vae_update_threshold:
                        update_vae = False
        if episode % args.log_interval == 0:
            print(
                'Episode {}\tLast cum return: {:5f}\t100-episodes average cum return: {:.2f}'
                .format(episode, cum_reward, last_hundred_average))
        if episode > args.num_episodes:
            # Finished: persist the policy and the per-episode records.
            print("100-episodes average cum return is now {} and "
                  "the last episode runs to {} time steps!".format(
                      last_hundred_average, cum_reward))
            env.close()
            torch.save(policy, '/'.join([direc_name, 'model']))
            if args.experiment == 'a|s':
                record = Record_S(
                    policy_loss=all_policy_loss,
                    value_loss=all_value_loss,
                    cum_reward=all_cum_reward,
                    last_hundred_average=all_last_hundred_average)
            elif args.experiment == 'a|z(s)':
                record = Record_S2S(
                    policy_loss=all_policy_loss,
                    value_loss=all_value_loss,
                    cum_reward=all_cum_reward,
                    last_hundred_average=all_last_hundred_average,
                    mse_recon_loss=all_mse_loss,
                    kl_loss=all_kl_loss,
                    vae_loss=all_vae_loss)
            elif args.experiment == 'a|z(s, s_next)' or args.experiment == 'a|z(a_prev, s, s_next)':
                record = Record_S2SNext(
                    policy_loss=all_policy_loss,
                    value_loss=all_value_loss,
                    cum_reward=all_cum_reward,
                    last_hundred_average=all_last_hundred_average,
                    mse_pred_loss=all_mse_loss,
                    kl_loss=all_kl_loss,
                    vae_loss=all_vae_loss)
            pickle.dump(record, open('/'.join([direc_name, 'record']), 'wb'))
            break