class WorldModel:
    """Dream-environment wrapper combining a VAE observation encoder/decoder
    with an MD-RNN latent dynamics model (world-models style).

    Checkpoints for both networks are loaded from ``params['logdir']``;
    both must exist (asserted).
    """

    def __init__(self, params):
        """Load VAE and MD-RNN checkpoints and initialise the latent state.

        params: dict with keys 'cuda', 'seed', 'logdir', 'img_size',
            'latent_size', 'action_size', 'hidden_size', 'num_gmm',
            'red_size'.
        """
        cuda = params["cuda"] and torch.cuda.is_available()
        torch.manual_seed(params["seed"])
        device = torch.device("cuda" if cuda else "cpu")

        # --- VAE checkpoint ---
        vae_file = os.path.join(params['logdir'], 'vae', 'best.tar')
        assert os.path.exists(vae_file), "VAE Checkpoint does not exist."
        state = torch.load(vae_file, map_location=torch.device('cpu'))
        print("Loading VAE at epoch {} "
              "with test error {}".format(state['epoch'], state['precision']))
        self.vae_model = VAE(nc=3,
                             ngf=params["img_size"],
                             ndf=params["img_size"],
                             latent_variable_size=params["latent_size"],
                             cuda=cuda).to(device)
        self.vae_model.load_state_dict(state['state_dict'])

        # --- MD-RNN checkpoint ---
        rnn_dir = os.path.join(params['logdir'], 'mdrnn')
        rnn_file = os.path.join(rnn_dir, 'best.tar')
        assert os.path.exists(rnn_file), "MD-RNN Checkpoint does not exist."
        state_rnn = torch.load(rnn_file)
        print("Loading MD-RNN at epoch {} "
              "with test error {}".format(state_rnn['epoch'],
                                          state_rnn['precision']))
        self.mdrnn = MDRNNCell(params['latent_size'], params['action_size'],
                               params['hidden_size'],
                               params['num_gmm']).to(device)
        # The checkpoint was trained with an nn.LSTM whose parameters carry a
        # layer suffix ('..._l0'), but it is loaded into a cell-based MDRNNCell.
        # BUGFIX: the original used k.strip('_l0'), which strips any of the
        # CHARACTERS '_', 'l', '0' from both ends of the key (e.g. it turns
        # 'fool_l0' into 'foo'), not the '_l0' substring. Use a proper
        # substring removal instead.
        rnn_state_dict = {self._clean_rnn_key(k): v
                          for k, v in state_rnn['state_dict'].items()}
        self.mdrnn.load_state_dict(rnn_state_dict)

        # Latent/hidden rollout state of the dream environment.
        self.latent = torch.randn(1, params['latent_size'])
        self.hidden = 2 * [torch.zeros(1, params['hidden_size'])]
        self.monitor = None
        self.figure = None
        # obs
        self.start_obs = None
        self.start_obs_recon = None
        self._obs = None
        self._visual_obs = None
        self.red_size = params["red_size"]
        self.params = params

    @staticmethod
    def _clean_rnn_key(key):
        """Remove the nn.LSTM layer marker '_l0' from a state-dict key.

        Keys without the marker are returned unchanged.
        """
        return key.replace('_l0', '')

    def forw(self):
        """Full VAE forward pass on the stored start observation; returns the
        reconstruction (first element of the VAE output tuple)."""
        ovo = self.vae_model.forward(self.start_obs)
        return ovo[0]

    def compute_z(self, obs):
        """Encode `obs` to a latent vector, cache it, and return it."""
        self.latent = self.vae_model.get_latent_var(obs)
        return self.latent

    def decode(self):
        """Decode the current latent into an HxWx3 uint8 numpy image.

        Stores the tensor in self.start_obs_recon and the uint8 array in
        self.start_obs_recon_np.
        """
        with torch.no_grad():
            self.start_obs_recon = self.vae_model.decode(self.latent)
        # NOTE(review): .numpy() requires a CPU tensor — assumes the decoder
        # output lives on CPU here; confirm when running with CUDA.
        np_obs = self.start_obs_recon.numpy()
        np_obs = np.clip(np_obs, 0, 1) * 255
        np_obs = np.transpose(np_obs, (0, 2, 3, 1))  # NCHW -> NHWC
        np_obs = np_obs.squeeze()
        np_obs = np_obs.astype(np.uint8)
        self.start_obs_recon_np = np_obs

    def reset(self):
        """Resetting: re-randomise the latent and zero the LSTM hidden state."""
        self.latent = torch.randn(1, self.params['latent_size'])
        self.hidden = 2 * [torch.zeros(1, self.params['hidden_size'])]

    def step(self, action):
        """One step forward in the dream: sample the next latent from the
        MD-RNN mixture and decode it to an observation."""
        with torch.no_grad():
            #action = torch.Tensor(action).unsqueeze(0)
            mu, sigma, pi, n_h = self.mdrnn(action, self.latent, self.hidden)
            pi = pi.squeeze()
            # Pick one Gaussian of the mixture; use its mean as the next latent.
            mixt = Categorical(torch.exp(pi)).sample().item()
            self.latent = mu[:, mixt, :]  # + sigma[:, mixt, :] * torch.randn_like(mu[:, mixt, :])
            self.hidden = n_h
            self._obs = self.vae_model.decode(self.latent)
            return self._obs
class MFEC:
    """Model-Free Episodic Control agent with pluggable state embeddings
    (random projection, VAE, or successor-representation) and MC/TD
    Q-value updates backed by a QEC episodic memory."""

    def __init__(self, env, args, device='cpu'):
        """
        Instantiate an MFEC Agent
        ----------
        env: gym.Env
            gym environment to train on
        args: args class from argparser
            args are from train.py: see train.py for help with each arg
        device: string
            'cpu' or 'cuda:0' depending on use_cuda flag from train.py
        """
        self.environment_type = args.environment_type
        self.env = env
        self.actions = range(self.env.action_space.n)
        self.frames_to_stack = args.frames_to_stack
        self.Q_train_algo = args.Q_train_algo
        self.use_Q_max = args.use_Q_max
        self.force_knn = args.force_knn
        self.weight_neighbors = args.weight_neighbors
        self.delta = args.delta
        self.device = device
        self.rs = np.random.RandomState(args.seed)

        # Hyperparameters
        self.epsilon = args.initial_epsilon
        self.final_epsilon = args.final_epsilon
        self.epsilon_decay = args.epsilon_decay
        self.gamma = args.gamma
        self.lr = args.lr
        self.q_lr = args.q_lr

        # Autoencoder for state embedding network
        self.vae_batch_size = args.vae_batch_size  # batch size for training VAE
        self.vae_epochs = args.vae_epochs  # number of epochs to run VAE
        self.embedding_type = args.embedding_type
        self.SR_embedding_type = args.SR_embedding_type
        self.embedding_size = args.embedding_size
        self.in_height = args.in_height
        self.in_width = args.in_width
        if self.embedding_type == 'VAE':
            self.vae_train_frames = args.vae_train_frames
            self.vae_loss = VAELoss()
            self.vae_print_every = args.vae_print_every
            self.load_vae_from = args.load_vae_from
            self.vae_weights_file = args.vae_weights_file
            self.vae = VAE(self.frames_to_stack, self.embedding_size,
                           self.in_height, self.in_width)
            self.vae = self.vae.to(self.device)
            self.optimizer = get_optimizer(args.optimizer,
                                           self.vae.parameters(), self.lr)
        elif self.embedding_type == 'random':
            # Fixed Gaussian random projection of the flattened frame stack.
            self.projection = self.rs.randn(
                self.embedding_size,
                self.in_height * self.in_width * self.frames_to_stack).astype(
                    np.float32)
        elif self.embedding_type == 'SR':
            self.SR_train_algo = args.SR_train_algo
            self.SR_gamma = args.SR_gamma
            self.SR_epochs = args.SR_epochs
            self.SR_batch_size = args.SR_batch_size
            self.n_hidden = args.n_hidden
            self.SR_train_frames = args.SR_train_frames
            self.SR_filename = args.SR_filename
            if self.SR_embedding_type == 'random':
                self.projection = np.random.randn(
                    self.embedding_size,
                    self.in_height * self.in_width).astype(np.float32)
                if self.SR_train_algo == 'TD':
                    self.mlp = MLP(self.embedding_size, self.n_hidden)
                    self.mlp = self.mlp.to(self.device)
                    self.loss_fn = nn.MSELoss(reduction='mean')
                    params = self.mlp.parameters()
                    self.optimizer = get_optimizer(args.optimizer, params,
                                                   self.lr)

        # QEC: episodic memory holding per-action k-NN value estimates.
        self.max_memory = args.max_memory
        self.num_neighbors = args.num_neighbors
        self.qec = QEC(self.actions, self.max_memory, self.num_neighbors,
                       self.use_Q_max, self.force_knn, self.weight_neighbors,
                       self.delta, self.q_lr)

        #self.state = np.empty(self.embedding_size, self.projection.dtype)
        #self.action = int
        self.memory = []  # episode buffer used by the MC update
        self.print_every = args.print_every
        self.episodes = 0

    def choose_action(self, values):
        """
        Choose epsilon-greedy policy according to Q-estimates;
        stores and returns the chosen action.
        """
        # Exploration
        if self.rs.random_sample() < self.epsilon:
            self.action = self.rs.choice(self.actions)
        # Exploitation: break ties between maximal actions uniformly.
        else:
            best_actions = np.argwhere(values == np.max(values)).flatten()
            self.action = self.rs.choice(best_actions)
        return self.action

    def TD_update(self, prev_embedding, prev_action, reward, values, time):
        """Expected-Sarsa TD update of the previous (state, action) pair using
        the current state's action values."""
        # On-policy value estimate of current state (epsilon-greedy)
        # Expected Sarsa
        v_t = (1 - self.epsilon) * np.max(values) + \
            self.epsilon * np.mean(values)
        value = reward + self.gamma * v_t
        self.qec.update(prev_embedding, prev_action, value, time - 1)

    def MC_update(self):
        """Monte-Carlo update: fold the episode buffer backwards into
        discounted returns and write each into the QEC memory."""
        value = 0.0
        for _ in range(len(self.memory)):
            experience = self.memory.pop()
            value = value * self.gamma + experience["reward"]
            self.qec.update(
                experience["state"],
                experience["action"],
                value,
                experience["time"],
            )

    def add_to_memory(self, state_embedding, action, reward, time):
        """Append one transition to the episode buffer (consumed by MC_update)."""
        self.memory.append({
            "state": state_embedding,
            "action": action,
            "reward": reward,
            "time": time,
        })

    def run_episode(self):
        """
        Train an MFEC agent for a single episode:
            Interact with environment
            Perform update
        """
        import time  # local: needed for sleep() when rendering (see BUGFIX below)
        self.episodes += 1
        RENDER_SPEED = 0.04
        RENDER = False

        episode_frames = 0
        total_reward = 0
        total_steps = 0

        # Update epsilon
        if self.epsilon > self.final_epsilon:
            self.epsilon = self.epsilon * self.epsilon_decay
        #self.env.seed(random.randint(0, 1000000))
        state = self.env.reset()
        if self.environment_type == 'fourrooms':
            fewest_steps = self.env.shortest_path_length(self.env.state)
        done = False
        # BUGFIX: the step counter was named 'time', shadowing the time
        # module, so time.sleep(RENDER_SPEED) below called .sleep on an int
        # and crashed whenever RENDER was enabled. Renamed to 't'.
        t = 0
        while not done:
            t += 1
            # --- embed the current observation ---
            if self.embedding_type == 'random':
                state = np.array(state).flatten()
                state_embedding = np.dot(self.projection, state)
            elif self.embedding_type == 'VAE':
                state = torch.tensor(state).permute(2, 0, 1)  #(H,W,C)->(C,H,W)
                state = state.unsqueeze(0).to(self.device)
                with torch.no_grad():
                    mu, logvar = self.vae.encoder(state)
                    state_embedding = torch.cat([mu, logvar], 1)
                    state_embedding = state_embedding.squeeze()
                    state_embedding = state_embedding.cpu().numpy()
            elif self.embedding_type == 'SR':
                if self.SR_train_algo == 'TD':
                    state = np.array(state).flatten()
                    state_embedding = np.dot(self.projection, state)
                    with torch.no_grad():
                        state_embedding = self.mlp(
                            torch.tensor(state_embedding)).cpu().numpy()
                elif self.SR_train_algo == 'DP':
                    s = self.env.state
                    state_embedding = self.true_SR_dict[s]
                # Normalise SR embeddings to unit length.
                state_embedding = state_embedding / np.linalg.norm(
                    state_embedding)
            if RENDER:
                self.env.render()
                time.sleep(RENDER_SPEED)

            # Get estimated value of each action
            values = [
                self.qec.estimate(state_embedding, action)
                for action in self.actions
            ]
            action = self.choose_action(values)
            state, reward, done, _ = self.env.step(action)
            if self.Q_train_algo == 'MC':
                self.add_to_memory(state_embedding, action, reward, t)
            elif self.Q_train_algo == 'TD':
                # One-step-delayed update: the previous transition is updated
                # with the current state's values.
                if t > 1:
                    self.TD_update(prev_embedding, prev_action, prev_reward,
                                   values, t)
                prev_reward = reward
                prev_embedding = state_embedding
                prev_action = action
            total_reward += reward
            total_steps += 1
            episode_frames += self.env.skip

        if self.Q_train_algo == 'MC':
            self.MC_update()
        if self.episodes % self.print_every == 0:
            print("KNN usage:", np.mean(self.qec.knn_usage))
            self.qec.knn_usage = []
            print("Proportion of replace:", np.mean(self.qec.replace_usage))
            self.qec.replace_usage = []
        if self.environment_type == 'fourrooms':
            n_extra_steps = total_steps - fewest_steps
            return n_extra_steps, episode_frames, total_reward
        else:
            return episode_frames, total_reward

    def warmup(self):
        """
        Collect 1 million frames from random policy and train VAE
        (or train/compute the SR embedding, depending on embedding_type).
        """
        if self.embedding_type == 'VAE':
            if self.load_vae_from is not None:
                self.vae.load_state_dict(torch.load(self.load_vae_from))
                self.vae = self.vae.to(self.device)
            else:
                # Collect frames from random policy
                print("Generating dataset to train VAE from random policy")
                vae_data = []
                state = self.env.reset()
                total_frames = 0
                while total_frames < self.vae_train_frames:
                    action = random.randint(0, self.env.action_space.n - 1)
                    state, reward, done, _ = self.env.step(action)
                    vae_data.append(state)
                    total_frames += self.env.skip
                    if done:
                        state = self.env.reset()
                # Dataset, Dataloader for 1 million frames
                vae_data = torch.tensor(
                    vae_data
                )  # (N x H x W x C) - (1mill/skip X 84 X 84 X frames_to_stack)
                vae_data = vae_data.permute(0, 3, 1, 2)  # (N x C x H x W)
                vae_dataset = TensorDataset(vae_data)
                vae_dataloader = DataLoader(vae_dataset,
                                            batch_size=self.vae_batch_size,
                                            shuffle=True)
                # Training loop
                print("Training VAE")
                self.vae.train()
                for epoch in range(self.vae_epochs):
                    train_loss = 0
                    for batch_idx, batch in enumerate(vae_dataloader):
                        batch = batch[0].to(self.device)
                        self.optimizer.zero_grad()
                        recon_batch, mu, logvar = self.vae(batch)
                        loss = self.vae_loss(recon_batch, batch, mu, logvar)
                        train_loss += loss.item()
                        loss.backward()
                        self.optimizer.step()
                        if batch_idx % self.vae_print_every == 0:
                            msg = 'VAE Epoch: {} [{}/{}]\tLoss: {:.6f}'.format(
                                epoch, batch_idx * len(batch),
                                len(vae_dataloader.dataset),
                                loss.item() / len(batch))
                            print(msg)
                    print('====> Epoch {} Average loss: {:.4f}'.format(
                        epoch, train_loss / len(vae_dataloader.dataset)))
                if self.vae_weights_file is not None:
                    torch.save(self.vae.state_dict(), self.vae_weights_file)
                self.vae.eval()
        elif self.embedding_type == 'SR':
            if self.SR_embedding_type == 'random':
                if self.SR_train_algo == 'TD':
                    # Collect (s_t, s_t+1) transitions under a random policy.
                    total_frames = 0
                    transitions = []
                    while total_frames < self.SR_train_frames:
                        observation = self.env.reset()
                        s_t = self.env.state  # will not work on Atari
                        done = False
                        while not done:
                            action = np.random.randint(self.env.action_space.n)
                            observation, reward, done, _ = self.env.step(
                                action)
                            s_tp1 = self.env.state  # will not work on Atari
                            transitions.append((s_t, s_tp1))
                            total_frames += self.env.skip
                            s_t = s_tp1
                    # Dataset, Dataloader
                    dataset = SRDataset(self.env, self.projection, transitions)
                    dataloader = DataLoader(dataset,
                                            batch_size=self.SR_batch_size,
                                            shuffle=True)
                    train_losses = []
                    # Training loop: TD(0) on the SR, m(s) <- e(s) + g*m(s').
                    for epoch in range(self.SR_epochs):
                        for batch_idx, batch in enumerate(dataloader):
                            self.optimizer.zero_grad()
                            e_t, e_tp1 = batch
                            e_t = e_t.to(self.device)
                            e_tp1 = e_tp1.to(self.device)
                            mhat_t = self.mlp(e_t)
                            mhat_tp1 = self.mlp(e_tp1)
                            target = e_t + self.gamma * mhat_tp1.detach()
                            loss = self.loss_fn(mhat_t, target)
                            loss.backward()
                            self.optimizer.step()
                            train_losses.append(loss.item())
                        # NOTE(review): train_losses is never reset, so this
                        # prints a running mean over ALL epochs so far.
                        print("Epoch:", epoch, "Average loss",
                              np.mean(train_losses))
                    # Dump embeddings, learned SRs and room labels for analysis.
                    emb_reps = np.zeros(
                        [self.env.n_states, self.embedding_size])
                    SR_reps = np.zeros(
                        [self.env.n_states, self.embedding_size])
                    labels = []
                    room_size = self.env.room_size
                    for i, (state,
                            obs) in enumerate(self.env.state_dict.items()):
                        emb = np.dot(self.projection, obs.flatten())
                        emb_reps[i, :] = emb
                        with torch.no_grad():
                            emb = torch.tensor(emb).to(self.device)
                            SR = self.mlp(emb).cpu().numpy()
                            SR_reps[i, :] = SR
                        # Label which of the four rooms (or doorway) the
                        # state lies in, by quadrant of the grid.
                        if state[0] < room_size + 1 and state[
                                1] < room_size + 1:
                            label = 0
                        elif state[0] > room_size + 1 and state[
                                1] < room_size + 1:
                            label = 1
                        elif state[0] < room_size + 1 and state[
                                1] > room_size + 1:
                            label = 2
                        elif state[0] > room_size + 1 and state[
                                1] > room_size + 1:
                            label = 3
                        else:
                            label = 4
                        labels.append(label)
                    np.save('%s_SR_reps.npy' % (self.SR_filename), SR_reps)
                    np.save('%s_emb_reps.npy' % (self.SR_filename), emb_reps)
                    np.save('%s_labels.npy' % (self.SR_filename), labels)
                elif self.SR_train_algo == 'MC':
                    pass
                elif self.SR_train_algo == 'DP':
                    # Use this to ensure same order every time
                    idx_to_state = {
                        i: state
                        for i, state in enumerate(self.env.state_dict.keys())
                    }
                    state_to_idx = {v: k for k, v in idx_to_state.items()}
                    # Empirical transition matrix under the uniform random
                    # policy (4 actions, probability 0.25 each).
                    T = np.zeros([self.env.n_states, self.env.n_states])
                    for i, s in idx_to_state.items():
                        for a in range(4):
                            self.env.state = s
                            _, _, _, _ = self.env.step(a)
                            s_tp1 = self.env.state
                            T[state_to_idx[s], state_to_idx[s_tp1]] += 0.25
                    true_SR = np.eye(self.env.n_states)
                    done = False
                    t = 0
                    # NOTE(review): this iterates SR <- SR + gamma^t (SR @ T);
                    # verify it matches the intended Neumann series
                    # sum_t gamma^t T^t before relying on 'true' SR values.
                    while not done:
                        t += 1
                        new_SR = true_SR + (self.SR_gamma**t) * (np.matmul(
                            true_SR, T))
                        done = np.max(np.abs(true_SR - new_SR)) < 1e-10
                        true_SR = new_SR
                    self.true_SR_dict = {}
                    for s, obs in self.env.state_dict.items():
                        idx = state_to_idx[s]
                        self.true_SR_dict[s] = true_SR[idx, :]
        else:
            pass  # random projection doesn't require warmup
def validate(vae_name, params, val_loader, vae_param):
    """Evaluate a saved VAE checkpoint on the validation loader.

    Computes the average PSNR and MS-SSIM over the validation buffer and
    saves up to ten original/reconstruction comparison grids into
    params["report_dir"].

    Returns [avg_psnr, avg_ms_ssim].
    """
    use_cuda = params["cuda"] and torch.cuda.is_available()
    torch.manual_seed(params["seed"])
    run_device = torch.device("cuda" if use_cuda else "cpu")

    # Locate and load the checkpoint for this run.
    checkpoint_path = os.path.join(params["vae_dir"], vae_name, 'vae',
                                   'best.tar')
    assert os.path.exists(checkpoint_path), "VAE Checkpoint does not exist."
    checkpoint = torch.load(checkpoint_path)
    print("Loading VAE at epoch {} "
          "with test error {}".format(checkpoint['epoch'],
                                      checkpoint['precision']))
    print(str(vae_name))

    net = VAE(nc=3,
              ngf=params["img_size"],
              ndf=params["img_size"],
              latent_variable_size=vae_param["latent_size"],
              cuda=use_cuda).to(run_device)
    net.load_state_dict(checkpoint['state_dict'])
    net.eval()

    psnr_total = 0
    ms_ssim_total = 0
    saved = 0
    with torch.no_grad():
        val_loader.dataset.load_next_buffer()
        for batch in val_loader:
            batch = batch.to(run_device)
            recon, mu, logvar = net(batch)

            # PSNR derived from the per-batch MSE.
            mse = F.mse_loss(recon, batch)
            psnr_total += 10 * log10(1 / mse.item())
            #print("PSNR:", str(psnr))
            ms_ssim_total += msssim(recon, batch).item()
            #print("MS-SSIM:", str(ms_ssim))

            # Persist side-by-side comparisons for the first ten batches only.
            if saved < 10:
                n = min(batch.size(0), 10)
                grid = torch.cat([
                    batch[:n],
                    recon.view(params["batch_size"], 3, params["img_size"],
                               params["img_size"])[:n]
                ])
                save_image(grid.cpu(),
                           os.path.join(
                               params["report_dir"],
                               str(vae_name) + "_" + str(saved) + '.png'),
                           nrow=n)
                saved += 1

    # Average over the (approximate) number of batches.
    step = len(val_loader.dataset) / params["batch_size"]
    avg_ms_ssim = ms_ssim_total / step
    avg_psnr = psnr_total / step
    print("AVG PSNR", str(avg_psnr))
    print("AVG MS-SSIM", str(avg_ms_ssim))
    print("index", str(saved))
    print("step", str(step))
    return [avg_psnr, avg_ms_ssim]
class VAE_TRAINER():
    """Trainer for a convolutional VAE on rollout observations.

    Builds data loaders with train/val transforms, trains with a selectable
    reconstruction loss (ms-ssim / mse / mix), checkpoints the best model and
    optionally plots progress to Visdom.
    """

    def __init__(self, params):
        """Build loaders, model and optimizer from a params dict.

        params keys used here: loss, cuda, seed, img_size, path_data,
        batch_size, latent_size, visualize, env, alpha, logdir, noreload,
        vae_location, log_interval.
        """
        self.params = params
        self.loss_function = {
            'ms-ssim': ms_ssim_loss,
            'mse': mse_loss,
            'mix': mix_loss
        }[params["loss"]]

        # Choose device
        self.cuda = params["cuda"] and torch.cuda.is_available()
        torch.manual_seed(params["seed"])
        # Fix numeric divergence due to bug in Cudnn
        torch.backends.cudnn.benchmark = True
        self.device = torch.device("cuda" if self.cuda else "cpu")

        # Prepare data transformations (random flip only for training).
        red_size = params["img_size"]
        transform_train = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize((red_size, red_size)),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
        ])
        transform_val = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize((red_size, red_size)),
            transforms.ToTensor(),
        ])

        # Initialize Data loaders
        op_dataset = RolloutObservationDataset(params["path_data"],
                                               transform_train,
                                               train=True)
        val_dataset = RolloutObservationDataset(params["path_data"],
                                                transform_val,
                                                train=False)
        self.train_loader = torch.utils.data.DataLoader(
            op_dataset,
            batch_size=params["batch_size"],
            shuffle=True,
            num_workers=0)
        self.eval_loader = torch.utils.data.DataLoader(
            val_dataset,
            batch_size=params["batch_size"],
            shuffle=False,
            num_workers=0)

        # Initialize model and hyperparams
        self.model = VAE(nc=3,
                         ngf=64,
                         ndf=64,
                         latent_variable_size=params["latent_size"],
                         cuda=self.cuda).to(self.device)
        self.optimizer = optim.Adam(self.model.parameters())
        self.init_vae_model()
        self.visualize = params["visualize"]
        if self.visualize:
            self.plotter = VisdomLinePlotter(env_name=params['env'])
            self.img_plotter = VisdomImagePlotter(env_name=params['env'])
        # Loss-mixing weight; falls back to 1.0 when unset/falsy.
        self.alpha = params["alpha"] if params["alpha"] else 1.0

    def train(self, epoch):
        """Run one training epoch over train_loader, logging every
        log_interval batches. Returns None."""
        self.model.train()
        # dataset_train.load_next_buffer()
        mse_loss = 0
        ssim_loss = 0
        train_loss = 0

        # Train step
        for batch_idx, data in enumerate(self.train_loader):
            data = data.to(self.device)
            self.optimizer.zero_grad()
            recon_batch, mu, logvar = self.model(data)
            loss, mse, ssim = self.loss_function(recon_batch, data, mu,
                                                 logvar, self.alpha)
            loss.backward()
            train_loss += loss.item()
            ssim_loss += ssim
            mse_loss += mse
            self.optimizer.step()
            # CONSISTENCY FIX: was the bare module-level global 'params';
            # use the instance copy like the rest of the class.
            if batch_idx % self.params["log_interval"] == 0:
                print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    epoch, batch_idx * len(data),
                    len(self.train_loader.dataset),
                    100. * batch_idx / len(self.train_loader), loss.item()))
                print('MSE: {} , SSIM: {:.4f}'.format(mse, ssim))

        # Per-epoch means, averaged over the (approximate) number of batches.
        step = len(self.train_loader.dataset) / float(
            self.params["batch_size"])
        mean_train_loss = train_loss / step
        mean_ssim_loss = ssim_loss / step
        mean_mse_loss = mse_loss / step
        print('-- Epoch: {} Average loss: {:.4f}'.format(
            epoch, mean_train_loss))
        print('-- Average MSE: {:.5f} Average SSIM: {:.4f}'.format(
            mean_mse_loss, mean_ssim_loss))
        if self.visualize:
            self.plotter.plot('loss', 'train', 'VAE Train Loss', epoch,
                              mean_train_loss)
        return

    def eval(self):
        """Evaluate on eval_loader; returns the mean eval loss.

        NOTE(review): reads the module-level global 'epoch' for the plot
        title/x-axis — consider passing the epoch in explicitly.
        """
        self.model.eval()
        # dataset_test.load_next_buffer()
        eval_loss = 0
        mse_loss = 0
        ssim_loss = 0
        vis = True  # plot the comparison grid for the first batch only
        with torch.no_grad():
            # Eval step
            for data in self.eval_loader:
                data = data.to(self.device)
                recon_batch, mu, logvar = self.model(data)
                loss, mse, ssim = self.loss_function(recon_batch, data, mu,
                                                     logvar, self.alpha)
                eval_loss += loss.item()
                ssim_loss += ssim
                mse_loss += mse
                if vis:
                    org_title = "Epoch: " + str(epoch)
                    # CONSISTENCY FIX: use self.params (was the bare global).
                    comparison1 = torch.cat([
                        data[:4],
                        recon_batch.view(self.params["batch_size"], 3,
                                         self.params["img_size"],
                                         self.params["img_size"])[:4]
                    ])
                    if self.visualize:
                        self.img_plotter.plot(comparison1, org_title)
                    vis = False

        step = len(self.eval_loader.dataset) / float(
            self.params["batch_size"])
        mean_eval_loss = eval_loss / step
        mean_ssim_loss = ssim_loss / step
        mean_mse_loss = mse_loss / step
        print('-- Eval set loss: {:.4f}'.format(mean_eval_loss))
        print('-- Eval MSE: {:.5f} Eval SSIM: {:.4f}'.format(
            mean_mse_loss, mean_ssim_loss))
        if self.visualize:
            self.plotter.plot('loss', 'eval', 'VAE Eval Loss', epoch,
                              mean_eval_loss)
            self.plotter.plot('loss', 'mse train', 'VAE MSE Loss', epoch,
                              mean_mse_loss)
            self.plotter.plot('loss', 'ssim train', 'VAE MSE Loss', epoch,
                              mean_ssim_loss)
        return mean_eval_loss

    def init_vae_model(self):
        """Create the checkpoint directory and optionally reload weights and
        optimizer state from params['vae_location']."""
        self.vae_dir = os.path.join(self.params["logdir"], 'vae')
        check_dir(self.vae_dir, 'samples')
        if not self.params["noreload"]:  # and os.path.exists(reload_file):
            reload_file = os.path.join(self.params["vae_location"],
                                       'best.tar')
            state = torch.load(reload_file)
            print("Reloading model at epoch {}"
                  ", with eval error {}".format(state['epoch'],
                                                state['precision']))
            self.model.load_state_dict(state['state_dict'])
            self.optimizer.load_state_dict(state['optimizer'])

    def checkpoint(self, cur_best, eval_loss):
        """Save the best and last checkpoint; returns the (possibly updated)
        best eval loss.

        NOTE(review): stores the module-level global 'epoch' in the
        checkpoint dict.
        """
        best_filename = os.path.join(self.vae_dir, 'best.tar')
        filename = os.path.join(self.vae_dir, 'checkpoint.tar')
        is_best = not cur_best or eval_loss < cur_best
        if is_best:
            cur_best = eval_loss
        save_checkpoint(
            {
                'epoch': epoch,
                'state_dict': self.model.state_dict(),
                'precision': eval_loss,
                'optimizer': self.optimizer.state_dict()
            }, is_best, filename, best_filename)
        return cur_best

    def plot(self, train, eval, epochs):
        """Save a train/eval loss curve to <logdir>/vae_training_curve.png."""
        plt.plot(epochs, train, label="train loss")
        plt.plot(epochs, eval, label="eval loss")
        plt.legend()
        plt.grid()
        plt.savefig(self.params["logdir"] + "/vae_training_curve.png")
        plt.close()
# --- Top-level script fragment (continues from code outside this chunk:
# 'acc', 'classifier', 'arguments', 'model_path', loaders, 'tr', 'ev', 'dr',
# 'results_path', 'results_name' are defined earlier). Saves the joint
# classifier, then trains/evaluates a VAE generator.
print('accuracy {}'.format(acc.item()))
if not os.path.exists('joint_models/'):
    os.mkdir('joint_models/')
torch.save(
    classifier.state_dict(), 'joint_models/joint_classifier_' +
    arguments.dataset_name + 'accuracy_{}'.format(acc) + '.t')
# NOTE(review): debugger breakpoint left in — halts the script here before
# generator training; remove for unattended runs.
pdb.set_trace()

# ### generator
model = VAE(arguments)
if arguments.cuda:
    model = model.cuda()

# NOTE(review): '0 &' makes this condition always false, so checkpoint
# reloading is (deliberately?) disabled and the model is always retrained.
if 0 & os.path.exists(model_path):
    print('loading model...')
    model.load_state_dict(torch.load(model_path))
    model = model.cuda()
else:
    print('training model...')
    optimizer = AdamNormGrad(model.parameters(), lr=arguments.lr)
    tr.experiment_vae(arguments, train_loader, val_loader, test_loader,
                      model, optimizer, dr, arguments.model_name)
# Evaluate on the test split and persist the metrics dict.
results = ev.evaluate_vae(arguments, model, train_loader, test_loader, 0,
                          results_path, 'test')
pickle.dump(results, open(results_path + results_name + '.pk', 'wb'))
# NOTE(review): this chunk begins INSIDE a DataLoader(...) call whose opening
# (the train-split RolloutSequenceDataset) lies outside the visible source.
batch_size=params['batch_size'],
num_workers=1,
shuffle=True)
test_loader = DataLoader(RolloutSequenceDataset(
    params["path_data"],
    params["seq_len"],
    transform,
    train=False,
    buffer_size=params["test_buffer_size"]),
                         batch_size=params['batch_size'],
                         num_workers=1)

# Load the best VAE checkpoint (required — asserted).
vae_file = os.path.join(params['logdir'], 'vae', 'best.tar')
assert os.path.exists(vae_file), "VAE Checkpoint does not exist."
state = torch.load(vae_file)
print("Loading VAE at epoch {} "
      "with test error {}".format(state['epoch'], state['precision']))
vae_model = VAE(nc=3,
                ngf=params["img_size"],
                ndf=params["img_size"],
                latent_variable_size=params["latent_size"],
                cuda=cuda).to(device)
vae_model.load_state_dict(state['state_dict'])

# Load the MD-RNN checkpoint if present; otherwise start from scratch.
rnn_dir = os.path.join(params['logdir'], 'mdrnn')
rnn_file = os.path.join(rnn_dir, 'best.tar')
if os.path.exists(rnn_file):
    state_rnn = torch.load(rnn_file)
    print("Loading MD-RNN at epoch {} "
          "with test error {}".format(state_rnn['epoch'],
                                      state_rnn['precision']))
    mdrnn = MDRNN(params['latent_size'], params['action_size'],
                  params['hidden_size'], params['num_gmm']).to(device)
    # Shallow copy of the state dict (keys unchanged here, unlike the
    # MDRNNCell loading path elsewhere in this file).
    rnn_state_dict = {k: v for k, v in state_rnn['state_dict'].items()}
    mdrnn.load_state_dict(rnn_state_dict)
else:
    mdrnn = MDRNN(params['latent_size'], params['action_size'],
                  params['hidden_size'], params['num_gmm'])
    mdrnn.to(device)
def main():
    """Evaluate a trained VAE plus a frozen VGG16 "teacher" as an anomaly
    detector on an MVTec-style dataset: compute per-pixel anomaly scores,
    pick the F1-optimal threshold, report pixel ROC-AUC and save figures.

    NOTE(review): relies on module-level globals 'device' and 'use_cuda'
    defined outside this chunk.
    """
    parser = argparse.ArgumentParser(description='Testing')
    parser.add_argument('--obj', type=str, default='.')
    parser.add_argument('--data_type', type=str, default='mvtec')
    parser.add_argument('--data_path', type=str, default='.')
    parser.add_argument('--checkpoint_dir', type=str, default='.')
    parser.add_argument("--grayscale",
                        action='store_true',
                        help='color or grayscale input image')
    parser.add_argument('--batch_size', type=int, default=16)
    parser.add_argument('--img_resize', type=int, default=128)
    parser.add_argument('--crop_size', type=int, default=128)
    parser.add_argument('--seed', type=int, default=None)
    args = parser.parse_args()

    # Output directory encodes data type, object class and seed.
    args.save_dir = './' + args.data_type + '/' + args.obj + '/vgg_feature' + '/seed_{}/'.format(
        args.seed)
    if not os.path.exists(args.save_dir):
        os.makedirs(args.save_dir)

    # load model and dataset
    args.input_channel = 1 if args.grayscale else 3
    model = VAE(input_channel=args.input_channel, z_dim=100).to(device)
    checkpoint = torch.load(args.checkpoint_dir)
    model.load_state_dict(checkpoint['model'])
    # Frozen pretrained VGG16 provides the feature space for scoring.
    teacher = models.vgg16(pretrained=True).to(device)
    for param in teacher.parameters():
        param.requires_grad = False

    # NOTE(review): both branches reduce to crop_size when resize == crop and
    # to crop_size when they differ — img_size always equals crop_size;
    # confirm the intended expression.
    img_size = args.crop_size if args.img_resize != args.crop_size else args.img_resize
    kwargs = {'num_workers': 4, 'pin_memory': True} if use_cuda else {}
    test_dataset = MVTecDataset(args.data_path,
                                class_name=args.obj,
                                is_train=False,
                                resize=img_size)
    test_loader = torch.utils.data.DataLoader(test_dataset,
                                              batch_size=args.batch_size,
                                              shuffle=True,
                                              **kwargs)

    scores, test_imgs, recon_imgs, gt_list, gt_mask_list = test(
        model, teacher, test_loader)

    # Min-max normalise the anomaly scores to [0, 1] before thresholding.
    scores = np.asarray(scores)
    max_anomaly_score = scores.max()
    min_anomaly_score = scores.min()
    scores = (scores - min_anomaly_score) / (max_anomaly_score -
                                             min_anomaly_score)

    gt_mask = np.asarray(gt_mask_list)
    # Choose the threshold maximising pixel-level F1 (safe-divide where
    # precision + recall == 0).
    precision, recall, thresholds = precision_recall_curve(
        gt_mask.flatten(), scores.flatten())
    a = 2 * precision * recall
    b = precision + recall
    f1 = np.divide(a, b, out=np.zeros_like(a), where=b != 0)
    threshold = thresholds[np.argmax(f1)]

    fpr, tpr, _ = roc_curve(gt_mask.flatten(), scores.flatten())
    per_pixel_rocauc = roc_auc_score(gt_mask.flatten(), scores.flatten())
    print('pixel ROCAUC: %.3f' % (per_pixel_rocauc))

    plt.plot(fpr, tpr, label='%s ROCAUC: %.3f' % (args.obj, per_pixel_rocauc))
    plt.legend(loc="lower right")
    # NOTE(review): save_dir nests a second 'seed_{}' level under a path that
    # already ends in 'seed_{}/'; confirm the doubled seed directory is
    # intended.
    save_dir = args.save_dir + '/' + f'seed_{args.seed}' + '/' + 'pictures_{:.4f}'.format(
        threshold)
    os.makedirs(save_dir, exist_ok=True)
    plt.savefig(os.path.join(save_dir, args.obj + '_roc_curve.png'), dpi=100)
    plot_fig(args, test_imgs, recon_imgs, scores, gt_mask_list, threshold,
             save_dir)
# NOTE(review): this chunk starts mid-function — the enclosing loop header
# ('for i, ...') and the function signature (defining 'writer', 'loss',
# 'losses', 'likelihood_x', 'batch_size', 'number_samples') are outside the
# visible source.
writer.add_scalar(tag='loss/test', scalar_value=loss, global_step=i)
# Importance-sampling estimate of log p(x) per example:
# logsumexp over samples minus log(number of samples).
likelihood_x[i * batch_size:(i + 1) * batch_size] = logsumexp(
    losses, axis=1) - np.log(number_samples)
return np.mean(likelihood_x)


if __name__ == '__main__':
    from loaders.load_funtions import load_MNIST
    from models.VAE import VAE
    import pathlib

    # Load binary MNIST and evaluate a pretrained 50-dim VAE's likelihood.
    _, loader, _, dataset_type = load_MNIST('../datasets/')
    # NOTE(review): 'output_dit' looks like a typo for 'output_dir' and is
    # not used below in this chunk.
    output_dit = pathlib.Path('../outputs/')
    input_shape = (1, 28, 28)

    model = VAE(dimension_latent_space=50,
                input_shape=input_shape,
                dataset_type=dataset_type)
    model.load_state_dict(
        torch.load('../outputs/trained/mnist_bin_standard_50/model',
                   map_location='cpu'))
    model.eval()
    # 100 importance samples per datapoint.
    print(model.calculate_likelihood(loader, 100))