Example #1
import os
from math import log10

import torch
import torch.nn.functional as F
from torchvision.utils import save_image

# VAE (the model class) and msssim (an MS-SSIM metric) are project-local.


def validate(vae_name, params, val_loader, vae_param):
    cuda = params["cuda"] and torch.cuda.is_available()
    torch.manual_seed(params["seed"])
    device = torch.device("cuda" if cuda else "cpu")
    model_path = os.path.join(params["vae_dir"], vae_name)
    vae_file = os.path.join(model_path, 'vae', 'best.tar')
    assert os.path.exists(vae_file), "VAE Checkpoint does not exist."
    state = torch.load(vae_file, map_location=device)
    print("Loading VAE at epoch {} "
          "with test error {}".format(state['epoch'], state['precision']))
    print(vae_name)
    model = VAE(nc=3,
                ngf=params["img_size"],
                ndf=params["img_size"],
                latent_variable_size=vae_param["latent_size"],
                cuda=cuda).to(device)
    model.load_state_dict(state['state_dict'])
    model.eval()

    avg_psnr = 0
    avg_ms_ssim = 0
    index = 0
    with torch.no_grad():
        val_loader.dataset.load_next_buffer()

        for i, data in enumerate(val_loader):
            data = data.to(device)
            recon_batch, mu, logvar = model(data)
            mse = F.mse_loss(recon_batch, data)
            # PSNR = 10 * log10(MAX^2 / MSE); pixels are assumed in [0, 1], so MAX = 1.
            psnr = 10 * log10(1 / mse.item())
            avg_psnr += psnr
            ms_ssim = msssim(recon_batch, data).item()
            avg_ms_ssim += ms_ssim

            if index < 10:
                n = min(data.size(0), 10)
                comparison = torch.cat([
                    data[:n],
                    # view(-1, ...) also handles a final batch smaller than batch_size
                    recon_batch.view(-1, 3, params["img_size"],
                                     params["img_size"])[:n]
                ])
                save_image(comparison.cpu(),
                           os.path.join(params["report_dir"],
                                        "{}_{}.png".format(vae_name, index)),
                           nrow=n)
            index += 1

    # Average the per-batch metrics over the number of batches; len(val_loader)
    # is correct even when the final batch is smaller than batch_size.
    num_batches = len(val_loader)
    avg_psnr /= num_batches
    avg_ms_ssim /= num_batches

    print("AVG PSNR:", avg_psnr)
    print("AVG MS-SSIM:", avg_ms_ssim)
    print("batches:", num_batches)

    return [avg_psnr, avg_ms_ssim]
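
# A minimal usage sketch (hypothetical values; the keys mirror what validate()
# reads above, and val_loader comes from the same project):
#
#     params = {"cuda": True, "seed": 123, "img_size": 64, "batch_size": 32,
#               "vae_dir": "logs/vae_runs", "report_dir": "reports"}
#     vae_param = {"latent_size": 128}
#     avg_psnr, avg_ms_ssim = validate("run_0", params, val_loader, vae_param)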
Example #2
import random
from time import sleep

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

# VAE, VAELoss, MLP, QEC, SRDataset, and get_optimizer are assumed to be
# project-local modules.


class MFEC:
    def __init__(self, env, args, device='cpu'):
        """
        Instantiate an MFEC Agent
        ----------
        env: gym.Env
            gym environment to train on
        args: args class from argparser
            args are from train.py: see train.py for help with each arg
        device: string
            'cpu' or 'cuda:0' depending on use_cuda flag from train.py
        """
        self.environment_type = args.environment_type
        self.env = env
        self.actions = range(self.env.action_space.n)
        self.frames_to_stack = args.frames_to_stack
        self.Q_train_algo = args.Q_train_algo
        self.use_Q_max = args.use_Q_max
        self.force_knn = args.force_knn
        self.weight_neighbors = args.weight_neighbors
        self.delta = args.delta
        self.device = device
        self.rs = np.random.RandomState(args.seed)

        # Hyperparameters
        self.epsilon = args.initial_epsilon
        self.final_epsilon = args.final_epsilon
        self.epsilon_decay = args.epsilon_decay
        self.gamma = args.gamma
        self.lr = args.lr
        self.q_lr = args.q_lr

        # Autoencoder for state embedding network
        self.vae_batch_size = args.vae_batch_size  # batch size for training VAE
        self.vae_epochs = args.vae_epochs  # number of epochs to run VAE
        self.embedding_type = args.embedding_type
        self.SR_embedding_type = args.SR_embedding_type
        self.embedding_size = args.embedding_size
        self.in_height = args.in_height
        self.in_width = args.in_width

        if self.embedding_type == 'VAE':
            self.vae_train_frames = args.vae_train_frames
            self.vae_loss = VAELoss()
            self.vae_print_every = args.vae_print_every
            self.load_vae_from = args.load_vae_from
            self.vae_weights_file = args.vae_weights_file
            self.vae = VAE(self.frames_to_stack, self.embedding_size,
                           self.in_height, self.in_width)
            self.vae = self.vae.to(self.device)
            self.optimizer = get_optimizer(args.optimizer,
                                           self.vae.parameters(), self.lr)
        elif self.embedding_type == 'random':
            self.projection = self.rs.randn(
                self.embedding_size, self.in_height * self.in_width *
                self.frames_to_stack).astype(np.float32)
        elif self.embedding_type == 'SR':
            self.SR_train_algo = args.SR_train_algo
            self.SR_gamma = args.SR_gamma
            self.SR_epochs = args.SR_epochs
            self.SR_batch_size = args.SR_batch_size
            self.n_hidden = args.n_hidden
            self.SR_train_frames = args.SR_train_frames
            self.SR_filename = args.SR_filename
            if self.SR_embedding_type == 'random':
                self.projection = np.random.randn(
                    self.embedding_size,
                    self.in_height * self.in_width).astype(np.float32)
                if self.SR_train_algo == 'TD':
                    self.mlp = MLP(self.embedding_size, self.n_hidden)
                    self.mlp = self.mlp.to(self.device)
                    self.loss_fn = nn.MSELoss(reduction='mean')
                    params = self.mlp.parameters()
                    self.optimizer = get_optimizer(args.optimizer, params,
                                                   self.lr)

        # QEC
        self.max_memory = args.max_memory
        self.num_neighbors = args.num_neighbors
        self.qec = QEC(self.actions, self.max_memory, self.num_neighbors,
                       self.use_Q_max, self.force_knn, self.weight_neighbors,
                       self.delta, self.q_lr)

        #self.state = np.empty(self.embedding_size, self.projection.dtype)
        #self.action = int
        self.memory = []
        self.print_every = args.print_every
        self.episodes = 0

    def choose_action(self, values):
        """
        Choose an action epsilon-greedily according to the Q-estimates
        """
        # Exploration
        if self.rs.random_sample() < self.epsilon:
            self.action = self.rs.choice(self.actions)

        # Exploitation
        else:
            best_actions = np.argwhere(values == np.max(values)).flatten()
            self.action = self.rs.choice(best_actions)

        return self.action

    def TD_update(self, prev_embedding, prev_action, reward, values, time):
        # Expected SARSA: the on-policy value of the current state under the
        # epsilon-greedy policy is (1 - eps) * max(Q) + eps * mean(Q).
        v_t = (1 -
               self.epsilon) * np.max(values) + self.epsilon * np.mean(values)
        value = reward + self.gamma * v_t
        self.qec.update(prev_embedding, prev_action, value, time - 1)

    def MC_update(self):
        # Walk the episode backwards, accumulating the discounted return
        # G_t = r_t + gamma * G_{t+1}; e.g. for rewards [r1, r2, r3] the pops
        # yield G3 = r3, then G2 = r2 + gamma*G3, then G1 = r1 + gamma*G2.
        value = 0.0
        for _ in range(len(self.memory)):
            experience = self.memory.pop()
            value = value * self.gamma + experience["reward"]
            self.qec.update(
                experience["state"],
                experience["action"],
                value,
                experience["time"],
            )

    def add_to_memory(self, state_embedding, action, reward, time):
        self.memory.append({
            "state": state_embedding,
            "action": action,
            "reward": reward,
            "time": time,
        })

    def run_episode(self):
        """
        Train an MFEC agent for a single episode:
            Interact with environment
            Perform update
        """
        self.episodes += 1
        RENDER_SPEED = 0.04
        RENDER = False

        episode_frames = 0
        total_reward = 0
        total_steps = 0

        # Update epsilon
        if self.epsilon > self.final_epsilon:
            self.epsilon = self.epsilon * self.epsilon_decay

        #self.env.seed(random.randint(0, 1000000))
        state = self.env.reset()
        if self.environment_type == 'fourrooms':
            fewest_steps = self.env.shortest_path_length(self.env.state)
        done = False
        time = 0
        while not done:
            time += 1
            if self.embedding_type == 'random':
                state = np.array(state).flatten()
                state_embedding = np.dot(self.projection, state)
            elif self.embedding_type == 'VAE':
                state = torch.tensor(state).permute(2, 0, 1)  #(H,W,C)->(C,H,W)
                state = state.unsqueeze(0).to(self.device)
                with torch.no_grad():
                    mu, logvar = self.vae.encoder(state)
                    state_embedding = torch.cat([mu, logvar], 1)
                    state_embedding = state_embedding.squeeze()
                    state_embedding = state_embedding.cpu().numpy()
            elif self.embedding_type == 'SR':
                if self.SR_train_algo == 'TD':
                    state = np.array(state).flatten()
                    state_embedding = np.dot(self.projection, state)
                    with torch.no_grad():
                        state_embedding = self.mlp(
                            torch.tensor(state_embedding).to(
                                self.device)).cpu().numpy()
                elif self.SR_train_algo == 'DP':
                    s = self.env.state
                    state_embedding = self.true_SR_dict[s]
            state_embedding = state_embedding / np.linalg.norm(state_embedding)
            if RENDER:
                self.env.render()
                # The local step counter `time` shadows the time module, so
                # the function form imported above (`from time import sleep`)
                # is used here.
                sleep(RENDER_SPEED)

            # Get estimated value of each action
            values = [
                self.qec.estimate(state_embedding, action)
                for action in self.actions
            ]

            action = self.choose_action(values)
            state, reward, done, _ = self.env.step(action)
            if self.Q_train_algo == 'MC':
                self.add_to_memory(state_embedding, action, reward, time)
            elif self.Q_train_algo == 'TD':
                if time > 1:
                    self.TD_update(prev_embedding, prev_action, prev_reward,
                                   values, time)
            prev_reward = reward
            prev_embedding = state_embedding
            prev_action = action
            total_reward += reward
            total_steps += 1
            episode_frames += self.env.skip

        if self.Q_train_algo == 'MC':
            self.MC_update()
        if self.episodes % self.print_every == 0:
            print("KNN usage:", np.mean(self.qec.knn_usage))
            self.qec.knn_usage = []
            print("Proportion of replace:", np.mean(self.qec.replace_usage))
            self.qec.replace_usage = []
        if self.environment_type == 'fourrooms':
            n_extra_steps = total_steps - fewest_steps
            return n_extra_steps, episode_frames, total_reward
        else:
            return episode_frames, total_reward

    def warmup(self):
        """
        Collect 1 million frames from random policy and train VAE
        """
        if self.embedding_type == 'VAE':
            if self.load_vae_from is not None:
                self.vae.load_state_dict(torch.load(self.load_vae_from))
                self.vae = self.vae.to(self.device)
            else:
                # Collect self.vae_train_frames frames from a random policy
                print("Generating dataset to train VAE from random policy")
                vae_data = []
                state = self.env.reset()
                total_frames = 0
                while total_frames < self.vae_train_frames:
                    action = random.randint(0, self.env.action_space.n - 1)
                    state, reward, done, _ = self.env.step(action)
                    vae_data.append(state)
                    total_frames += self.env.skip
                    if done:
                        state = self.env.reset()
                # Build a Dataset/DataLoader over the collected frames
                vae_data = torch.tensor(
                    vae_data)  # (N, H, W, C): N ~ vae_train_frames / skip
                vae_data = vae_data.permute(0, 3, 1, 2)  # -> (N, C, H, W)
                vae_dataset = TensorDataset(vae_data)
                vae_dataloader = DataLoader(vae_dataset,
                                            batch_size=self.vae_batch_size,
                                            shuffle=True)
                # Training loop
                print("Training VAE")
                self.vae.train()
                for epoch in range(self.vae_epochs):
                    train_loss = 0
                    for batch_idx, batch in enumerate(vae_dataloader):
                        batch = batch[0].to(self.device)
                        self.optimizer.zero_grad()
                        recon_batch, mu, logvar = self.vae(batch)
                        loss = self.vae_loss(recon_batch, batch, mu, logvar)
                        train_loss += loss.item()
                        loss.backward()
                        self.optimizer.step()
                        if batch_idx % self.vae_print_every == 0:
                            msg = 'VAE Epoch: {} [{}/{}]\tLoss: {:.6f}'.format(
                                epoch, batch_idx * len(batch),
                                len(vae_dataloader.dataset),
                                loss.item() / len(batch))
                            print(msg)
                    print('====> Epoch {} Average loss: {:.4f}'.format(
                        epoch, train_loss / len(vae_dataloader.dataset)))
                    if self.vae_weights_file is not None:
                        torch.save(self.vae.state_dict(),
                                   self.vae_weights_file)
            self.vae.eval()
        elif self.embedding_type == 'SR':
            if self.SR_embedding_type == 'random':
                if self.SR_train_algo == 'TD':
                    total_frames = 0
                    transitions = []
                    while total_frames < self.SR_train_frames:
                        observation = self.env.reset()
                        s_t = self.env.state  # will not work on Atari
                        done = False
                        while not done:
                            action = np.random.randint(self.env.action_space.n)
                            observation, reward, done, _ = self.env.step(
                                action)
                            s_tp1 = self.env.state  # will not work on Atari
                            transitions.append((s_t, s_tp1))
                            total_frames += self.env.skip
                            s_t = s_tp1
                    # Dataset, Dataloader
                    dataset = SRDataset(self.env, self.projection, transitions)
                    dataloader = DataLoader(dataset,
                                            batch_size=self.SR_batch_size,
                                            shuffle=True)
                    # Training loop
                    for epoch in range(self.SR_epochs):
                        train_losses = []  # reset so the printed mean is per-epoch
                        for batch_idx, batch in enumerate(dataloader):
                            self.optimizer.zero_grad()
                            e_t, e_tp1 = batch
                            e_t = e_t.to(self.device)
                            e_tp1 = e_tp1.to(self.device)
                            mhat_t = self.mlp(e_t)
                            mhat_tp1 = self.mlp(e_tp1)
                            target = e_t + self.gamma * mhat_tp1.detach()
                            loss = self.loss_fn(mhat_t, target)
                            loss.backward()
                            self.optimizer.step()
                            train_losses.append(loss.item())
                        print("Epoch:", epoch, "Average loss",
                              np.mean(train_losses))

                    emb_reps = np.zeros(
                        [self.env.n_states, self.embedding_size])
                    SR_reps = np.zeros(
                        [self.env.n_states, self.embedding_size])
                    labels = []
                    room_size = self.env.room_size
                    for i, (state,
                            obs) in enumerate(self.env.state_dict.items()):
                        emb = np.dot(self.projection, obs.flatten())
                        emb_reps[i, :] = emb
                        with torch.no_grad():
                            emb = torch.tensor(emb).to(self.device)
                            SR = self.mlp(emb).cpu().numpy()
                        SR_reps[i, :] = SR
                        # Label each state by which of the four rooms it is
                        # in; the else branch (label 4) catches cells on the
                        # hallway/wall lines at room_size + 1.
                        row, col = state[0], state[1]
                        if row < room_size + 1 and col < room_size + 1:
                            label = 0
                        elif row > room_size + 1 and col < room_size + 1:
                            label = 1
                        elif row < room_size + 1 and col > room_size + 1:
                            label = 2
                        elif row > room_size + 1 and col > room_size + 1:
                            label = 3
                        else:
                            label = 4
                        labels.append(label)
                    np.save('%s_SR_reps.npy' % (self.SR_filename), SR_reps)
                    np.save('%s_emb_reps.npy' % (self.SR_filename), emb_reps)
                    np.save('%s_labels.npy' % (self.SR_filename), labels)
                elif self.SR_train_algo == 'MC':
                    pass
                elif self.SR_train_algo == 'DP':
                    # Use this to ensure same order every time
                    idx_to_state = {
                        i: state
                        for i, state in enumerate(self.env.state_dict.keys())
                    }
                    state_to_idx = {v: k for k, v in idx_to_state.items()}
                    T = np.zeros([self.env.n_states, self.env.n_states])
                    # Build the transition matrix T under the uniform random
                    # policy: each action contributes probability 1/n_actions.
                    n_actions = self.env.action_space.n
                    for i, s in idx_to_state.items():
                        for a in range(n_actions):
                            self.env.state = s
                            self.env.step(a)
                            s_tp1 = self.env.state
                            T[state_to_idx[s],
                              state_to_idx[s_tp1]] += 1.0 / n_actions
                    # The SR is the fixed point M = I + gamma * T @ M, i.e.
                    # M = (I - gamma*T)^(-1) = sum_t gamma^t T^t; iterate the
                    # fixed-point update until convergence.
                    true_SR = np.eye(self.env.n_states)
                    done = False
                    while not done:
                        new_SR = np.eye(self.env.n_states) + \
                            self.SR_gamma * np.matmul(T, true_SR)
                        done = np.max(np.abs(true_SR - new_SR)) < 1e-10
                        true_SR = new_SR
                    self.true_SR_dict = {}
                    for s, obs in self.env.state_dict.items():
                        idx = state_to_idx[s]
                        self.true_SR_dict[s] = true_SR[idx, :]
        else:
            pass  # random projection doesn't require warmup
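
# A minimal training-loop sketch (hypothetical; `env` and `args` come from the
# project's train.py, and `args.episodes` is an assumed attribute):
#
#     agent = MFEC(env, args, device='cpu')
#     agent.warmup()  # trains the VAE / SR embedding when one is configured
#     for _ in range(args.episodes):
#         frames, reward = agent.run_episode()
#     # (for environment_type == 'fourrooms', run_episode also returns
#     #  n_extra_steps as its first value)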
Example #3
import os

import matplotlib.pyplot as plt
import torch
import torch.optim as optim
from torchvision import transforms

# VAE, RolloutObservationDataset, the ms-ssim/mse/mix loss functions, the
# Visdom plotters, check_dir, and save_checkpoint are project-local modules.


class VAE_TRAINER:
    def __init__(self, params):
        self.params = params
        self.loss_function = {
            'ms-ssim': ms_ssim_loss,
            'mse': mse_loss,
            'mix': mix_loss
        }[params["loss"]]

        # Choose device
        self.cuda = params["cuda"] and torch.cuda.is_available()
        torch.manual_seed(params["seed"])
        # Fix numeric divergence due to bug in Cudnn
        torch.backends.cudnn.benchmark = True
        self.device = torch.device("cuda" if self.cuda else "cpu")

        # Prepare data transformations
        red_size = params["img_size"]
        transform_train = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize((red_size, red_size)),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
        ])

        transform_val = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize((red_size, red_size)),
            transforms.ToTensor(),
        ])

        # Initialize Data loaders
        op_dataset = RolloutObservationDataset(params["path_data"],
                                               transform_train,
                                               train=True)
        val_dataset = RolloutObservationDataset(params["path_data"],
                                                transform_val,
                                                train=False)

        self.train_loader = torch.utils.data.DataLoader(
            op_dataset,
            batch_size=params["batch_size"],
            shuffle=True,
            num_workers=0)
        self.eval_loader = torch.utils.data.DataLoader(
            val_dataset,
            batch_size=params["batch_size"],
            shuffle=False,
            num_workers=0)

        # Initialize model and hyperparams
        self.model = VAE(nc=3,
                         ngf=64,
                         ndf=64,
                         latent_variable_size=params["latent_size"],
                         cuda=self.cuda).to(self.device)
        self.optimizer = optim.Adam(self.model.parameters())
        self.init_vae_model()
        self.visualize = params["visualize"]
        if self.visualize:
            self.plotter = VisdomLinePlotter(env_name=params['env'])
            self.img_plotter = VisdomImagePlotter(env_name=params['env'])
        self.alpha = params.get("alpha") or 1.0  # fall back to 1.0 if unset

    def train(self, epoch):
        self.model.train()
        # dataset_train.load_next_buffer()
        mse_loss = 0
        ssim_loss = 0
        train_loss = 0
        # Train step
        for batch_idx, data in enumerate(self.train_loader):
            data = data.to(self.device)
            self.optimizer.zero_grad()
            recon_batch, mu, logvar = self.model(data)
            loss, mse, ssim = self.loss_function(recon_batch, data, mu, logvar,
                                                 self.alpha)
            loss.backward()

            train_loss += loss.item()
            ssim_loss += ssim
            mse_loss += mse
            self.optimizer.step()

            if batch_idx % self.params["log_interval"] == 0:
                print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    epoch, batch_idx * len(data),
                    len(self.train_loader.dataset),
                    100. * batch_idx / len(self.train_loader), loss.item()))
                print('MSE: {} , SSIM: {:.4f}'.format(mse, ssim))

        step = len(self.train_loader.dataset) / float(
            self.params["batch_size"])
        mean_train_loss = train_loss / step
        mean_ssim_loss = ssim_loss / step
        mean_mse_loss = mse_loss / step
        print('-- Epoch: {} Average loss: {:.4f}'.format(
            epoch, mean_train_loss))
        print('-- Average MSE: {:.5f} Average SSIM: {:.4f}'.format(
            mean_mse_loss, mean_ssim_loss))
        if self.visualize:
            self.plotter.plot('loss', 'train', 'VAE Train Loss', epoch,
                              mean_train_loss)
        return mean_train_loss

    def eval(self, epoch):
        # `epoch` is required here for logging and the Visdom plots.
        self.model.eval()
        # dataset_test.load_next_buffer()
        eval_loss = 0
        mse_loss = 0
        ssim_loss = 0
        vis = True
        with torch.no_grad():
            # Eval step
            for data in self.eval_loader:
                data = data.to(self.device)
                recon_batch, mu, logvar = self.model(data)

                loss, mse, ssim = self.loss_function(recon_batch, data, mu,
                                                     logvar, self.alpha)
                eval_loss += loss.item()
                ssim_loss += ssim
                mse_loss += mse
                if vis:
                    org_title = "Epoch: " + str(epoch)
                    comparison1 = torch.cat([
                        data[:4],
                        recon_batch.view(-1, 3, self.params["img_size"],
                                         self.params["img_size"])[:4]
                    ])
                    if self.visualize:
                        self.img_plotter.plot(comparison1, org_title)
                    vis = False

        step = len(self.eval_loader.dataset) / float(
            self.params["batch_size"])
        mean_eval_loss = eval_loss / step
        mean_ssim_loss = ssim_loss / step
        mean_mse_loss = mse_loss / step
        print('-- Eval set loss: {:.4f}'.format(mean_eval_loss))
        print('-- Eval MSE: {:.5f} Eval SSIM: {:.4f}'.format(
            mean_mse_loss, mean_ssim_loss))
        if self.visualize:
            self.plotter.plot('loss', 'eval', 'VAE Eval Loss', epoch,
                              mean_eval_loss)
            self.plotter.plot('loss', 'mse eval', 'VAE MSE Loss', epoch,
                              mean_mse_loss)
            self.plotter.plot('loss', 'ssim eval', 'VAE SSIM Loss', epoch,
                              mean_ssim_loss)

        return mean_eval_loss

    def init_vae_model(self):
        self.vae_dir = os.path.join(self.params["logdir"], 'vae')
        check_dir(self.vae_dir, 'samples')
        reload_file = os.path.join(self.params["vae_location"], 'best.tar')
        if not self.params["noreload"] and os.path.exists(reload_file):
            state = torch.load(reload_file, map_location=self.device)
            print("Reloading model at epoch {}"
                  ", with eval error {}".format(state['epoch'],
                                                state['precision']))
            self.model.load_state_dict(state['state_dict'])
            self.optimizer.load_state_dict(state['optimizer'])

    def checkpoint(self, cur_best, eval_loss, epoch):
        # Save the best and the latest checkpoint; `epoch` is stored in the
        # checkpoint dict below.
        best_filename = os.path.join(self.vae_dir, 'best.tar')
        filename = os.path.join(self.vae_dir, 'checkpoint.tar')
        is_best = cur_best is None or eval_loss < cur_best
        if is_best:
            cur_best = eval_loss

        save_checkpoint(
            {
                'epoch': epoch,
                'state_dict': self.model.state_dict(),
                'precision': eval_loss,
                'optimizer': self.optimizer.state_dict()
            }, is_best, filename, best_filename)
        return cur_best

    def plot(self, train_losses, eval_losses, epochs):
        plt.plot(epochs, train_losses, label="train loss")
        plt.plot(epochs, eval_losses, label="eval loss")
        plt.legend()
        plt.grid()
        plt.savefig(os.path.join(self.params["logdir"],
                                 "vae_training_curve.png"))
        plt.close()
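
# A minimal driver sketch (hypothetical `params` dict and "epochs" key;
# `epoch` is passed to eval() and checkpoint() explicitly, matching the
# signatures above):
#
#     trainer = VAE_TRAINER(params)
#     cur_best, train_hist, eval_hist = None, [], []
#     for epoch in range(1, params["epochs"] + 1):
#         train_hist.append(trainer.train(epoch))
#         eval_loss = trainer.eval(epoch)
#         eval_hist.append(eval_loss)
#         cur_best = trainer.checkpoint(cur_best, eval_loss, epoch)
#     trainer.plot(train_hist, eval_hist, list(range(1, params["epochs"] + 1)))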
Example #4
            qsms = (qsms.to(device, dtype=torch.float) + trans) * scale
            masks = masks.to(device, dtype=torch.float)
            qsms = qsms * masks

            recon_loss, kl_loss = vae_train(model=vae3d, optimizer=optimizer, x=qsms, mask=masks)
            recon_loss_sum += recon_loss
            kl_loss_sum += kl_loss
            gen_iterations += 1

            time.sleep(1)

        scheduler.step(epoch)

        # validation phase
        vae3d.eval()
        loss_total = 0
        idx = 0
        with torch.no_grad():  # skip autograd graphs during validation to avoid a memory blow-up
            for idx, (rdfs, masks, weights,
                      qsms) in enumerate(valLoader, start=1):
                qsms = (qsms.to(device, dtype=torch.float) + trans) * scale
                masks = masks.to(device, dtype=torch.float)
                qsms = qsms * masks
                
                x_mu, x_var, z_mu, z_logvar = vae3d(qsms)
                x_factor = torch.prod(torch.tensor(x_mu.size()))
                z_factor = torch.prod(torch.tensor(z_mu.size()))    
                # Per-element Gaussian negative log-likelihood of the
                # reconstruction: 0.5 * [(x - mu)^2 / var + log(var)],
                # averaged over the volume. A masked variant was also tried:
                # recon_loss = 0.5 * torch.sum((x_mu*masks - qsms*masks)**2 / x_var + torch.log(x_var)*masks)
                recon_loss = 0.5 * torch.sum((x_mu - qsms)**2 / x_var + torch.log(x_var)) / x_factor
                # recon_loss = torch.sum((x_mu - qsms)**2) / x_factor  # plain-MSE variant
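                # A sketch of how the loop presumably continues (an assumption,
                # not from the original source): the diagonal-Gaussian KL term,
                # normalized like recon_loss, accumulated into loss_total:
                # kl_loss = -0.5 * torch.sum(1 + z_logvar - z_mu**2 - torch.exp(z_logvar)) / z_factor
                # loss_total += (recon_loss + kl_loss).item()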
Example #5
                writer.add_scalar(tag='loss/test',
                                  scalar_value=loss,
                                  global_step=i)

            # Importance-sampling estimate: treating each row of `losses` as
            # per-sample log-weights, log p(x) ~= logsumexp(...) - log(S).
            likelihood_x[i * batch_size:(i + 1) * batch_size] = logsumexp(
                losses, axis=1) - np.log(number_samples)

        return np.mean(likelihood_x)


if __name__ == '__main__':
    from loaders.load_funtions import load_MNIST
    from models.VAE import VAE

    import pathlib

    _, loader, _, dataset_type = load_MNIST('../datasets/')

    output_dir = pathlib.Path('../outputs/')
    input_shape = (1, 28, 28)

    model = VAE(dimension_latent_space=50,
                input_shape=input_shape,
                dataset_type=dataset_type)
    model.load_state_dict(
        torch.load('../outputs/trained/mnist_bin_standard_50/model',
                   map_location='cpu'))
    model.eval()

    print(model.calculate_likelihood(loader, 100))