Example #1
# 	generate_data(game='SonicTheHedgehog-Genesis', state=level, extension_name=VAE_TRAINING_EXT, frame_jump=4, save_actions=False)
'''
	===============================================
	2 - VAE training
	===============================================
'''

vae = VAE()
# /!\ If your graphics card runs out of memory, choose a smaller batch_size
# print('\nVAE training.')
# First do one training run with dropout enabled
# vae.train(filepath=SAVED_MODELS_DIR + '/VAE_GreenHillZone.h5', batch_size=16, epochs=200)
# Then load the trained network without dropout to push the training a bit further (comment out the dropout lines in VAE.py first)
vae.load_weights(file_path=SAVED_MODELS_DIR + '/VAE_GreenHillZone.h5')
vae.train(filepath=SAVED_MODELS_DIR + '/VAE_GreenHillZone.h5',
          batch_size=16,
          epochs=200)
'''
	===============================================
	3 - Visualization of the VAE
	===============================================
'''

# print('VAE visualization')
# images_array_path = IMG_DIR + '/GreenHillZone.Act2.vae_train1.npy'
# vae.generate_render(data_path=images_array_path)
'''
	===============================================
	4 - Generation of the LSTM's training dataset
	This training teaches the LSTM the logic and physics of the game so that it can predict the next (latent) frame.
	Make sure to record a lot of different situations (Sonic stuck in front of a wall, jumping and moving in the air, ...).
	===============================================
'''
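
# Step 4's code is cut off above. Based on the generate_data call shown in
# step 1, a plausible (purely illustrative) recording call for the LSTM
# dataset could look like the commented line below; the extension name is
# made up, and save_actions=True is an assumption (the LSTM would typically
# be conditioned on the actions taken).
# generate_data(game='SonicTheHedgehog-Genesis', state=level, extension_name='lstm_train', frame_jump=4, save_actions=True)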
Example #2

class MFEC:
    def __init__(self, env, args, device='cpu'):
        """
        Instantiate an MFEC Agent
        ----------
        env: gym.Env
            gym environment to train on
        args: args class from argparser
            args are from train.py; see train.py for help with each arg
        device: string
            'cpu' or 'cuda:0' depending on use_cuda flag from train.py
        """
        self.environment_type = args.environment_type
        self.env = env
        self.actions = range(self.env.action_space.n)
        self.frames_to_stack = args.frames_to_stack
        self.Q_train_algo = args.Q_train_algo
        self.use_Q_max = args.use_Q_max
        self.force_knn = args.force_knn
        self.weight_neighbors = args.weight_neighbors
        self.delta = args.delta
        self.device = device
        self.rs = np.random.RandomState(args.seed)

        # Hyperparameters
        self.epsilon = args.initial_epsilon
        self.final_epsilon = args.final_epsilon
        self.epsilon_decay = args.epsilon_decay
        self.gamma = args.gamma
        self.lr = args.lr
        self.q_lr = args.q_lr

        # Autoencoder for state embedding network
        self.vae_batch_size = args.vae_batch_size  # batch size for training VAE
        self.vae_epochs = args.vae_epochs  # number of epochs to run VAE
        self.embedding_type = args.embedding_type
        self.SR_embedding_type = args.SR_embedding_type
        self.embedding_size = args.embedding_size
        self.in_height = args.in_height
        self.in_width = args.in_width

        if self.embedding_type == 'VAE':
            self.vae_train_frames = args.vae_train_frames
            self.vae_loss = VAELoss()
            self.vae_print_every = args.vae_print_every
            self.load_vae_from = args.load_vae_from
            self.vae_weights_file = args.vae_weights_file
            self.vae = VAE(self.frames_to_stack, self.embedding_size,
                           self.in_height, self.in_width)
            self.vae = self.vae.to(self.device)
            self.optimizer = get_optimizer(args.optimizer,
                                           self.vae.parameters(), self.lr)
        elif self.embedding_type == 'random':
            self.projection = self.rs.randn(
                self.embedding_size, self.in_height * self.in_width *
                self.frames_to_stack).astype(np.float32)
        elif self.embedding_type == 'SR':
            self.SR_train_algo = args.SR_train_algo
            self.SR_gamma = args.SR_gamma
            self.SR_epochs = args.SR_epochs
            self.SR_batch_size = args.SR_batch_size
            self.n_hidden = args.n_hidden
            self.SR_train_frames = args.SR_train_frames
            self.SR_filename = args.SR_filename
            if self.SR_embedding_type == 'random':
                self.projection = np.random.randn(
                    self.embedding_size,
                    self.in_height * self.in_width).astype(np.float32)
                if self.SR_train_algo == 'TD':
                    self.mlp = MLP(self.embedding_size, self.n_hidden)
                    self.mlp = self.mlp.to(self.device)
                    self.loss_fn = nn.MSELoss(reduction='mean')
                    params = self.mlp.parameters()
                    self.optimizer = get_optimizer(args.optimizer, params,
                                                   self.lr)

        # QEC
        self.max_memory = args.max_memory
        self.num_neighbors = args.num_neighbors
        self.qec = QEC(self.actions, self.max_memory, self.num_neighbors,
                       self.use_Q_max, self.force_knn, self.weight_neighbors,
                       self.delta, self.q_lr)

        #self.state = np.empty(self.embedding_size, self.projection.dtype)
        #self.action = int
        self.memory = []
        self.print_every = args.print_every
        self.episodes = 0

    def choose_action(self, values):
        """
        Choose epsilon-greedy policy according to Q-estimates
        """
        # Exploration
        if self.rs.random_sample() < self.epsilon:
            self.action = self.rs.choice(self.actions)

        # Exploitation
        else:
            best_actions = np.argwhere(values == np.max(values)).flatten()
            self.action = self.rs.choice(best_actions)

        return self.action

    def TD_update(self, prev_embedding, prev_action, reward, values, time):
        # On-policy value estimate of current state (epsilon-greedy)
        # Expected Sarsa
        v_t = (1 -
               self.epsilon) * np.max(values) + self.epsilon * np.mean(values)
        value = reward + self.gamma * v_t
        self.qec.update(prev_embedding, prev_action, value, time - 1)

    def MC_update(self):
        value = 0.0
        for _ in range(len(self.memory)):
            experience = self.memory.pop()
            value = value * self.gamma + experience["reward"]
            self.qec.update(
                experience["state"],
                experience["action"],
                value,
                experience["time"],
            )

    def add_to_memory(self, state_embedding, action, reward, time):
        self.memory.append({
            "state": state_embedding,
            "action": action,
            "reward": reward,
            "time": time,
        })

    def run_episode(self):
        """
        Train an MFEC agent for a single episode:
            Interact with environment
            Perform update
        """
        self.episodes += 1
        RENDER_SPEED = 0.04
        RENDER = False

        episode_frames = 0
        total_reward = 0
        total_steps = 0

        # Update epsilon
        if self.epsilon > self.final_epsilon:
            self.epsilon = self.epsilon * self.epsilon_decay

        #self.env.seed(random.randint(0, 1000000))
        state = self.env.reset()
        if self.environment_type == 'fourrooms':
            fewest_steps = self.env.shortest_path_length(self.env.state)
        done = False
        time = 0
        while not done:
            time += 1
            if self.embedding_type == 'random':
                state = np.array(state).flatten()
                state_embedding = np.dot(self.projection, state)
            elif self.embedding_type == 'VAE':
                state = torch.tensor(state).permute(2, 0, 1)  #(H,W,C)->(C,H,W)
                state = state.unsqueeze(0).to(self.device)
                with torch.no_grad():
                    mu, logvar = self.vae.encoder(state)
                    state_embedding = torch.cat([mu, logvar], 1)
                    state_embedding = state_embedding.squeeze()
                    state_embedding = state_embedding.cpu().numpy()
            elif self.embedding_type == 'SR':
                if self.SR_train_algo == 'TD':
                    state = np.array(state).flatten()
                    state_embedding = np.dot(self.projection, state)
                    with torch.no_grad():
                        state_embedding = self.mlp(
                            torch.tensor(state_embedding)).cpu().numpy()
                elif self.SR_train_algo == 'DP':
                    s = self.env.state
                    state_embedding = self.true_SR_dict[s]
            state_embedding = state_embedding / np.linalg.norm(state_embedding)
            if RENDER:
                self.env.render()
                # The integer step counter above shadows the `time` module in
                # this scope, so import the module under an alias before sleeping.
                import time as _time
                _time.sleep(RENDER_SPEED)

            # Get estimated value of each action
            values = [
                self.qec.estimate(state_embedding, action)
                for action in self.actions
            ]

            action = self.choose_action(values)
            state, reward, done, _ = self.env.step(action)
            if self.Q_train_algo == 'MC':
                self.add_to_memory(state_embedding, action, reward, time)
            elif self.Q_train_algo == 'TD':
                if time > 1:
                    self.TD_update(prev_embedding, prev_action, prev_reward,
                                   values, time)
            prev_reward = reward
            prev_embedding = state_embedding
            prev_action = action
            total_reward += reward
            total_steps += 1
            episode_frames += self.env.skip

        if self.Q_train_algo == 'MC':
            self.MC_update()
        if self.episodes % self.print_every == 0:
            print("KNN usage:", np.mean(self.qec.knn_usage))
            self.qec.knn_usage = []
            print("Proportion of replace:", np.mean(self.qec.replace_usage))
            self.qec.replace_usage = []
        if self.environment_type == 'fourrooms':
            n_extra_steps = total_steps - fewest_steps
            return n_extra_steps, episode_frames, total_reward
        else:
            return episode_frames, total_reward

    def warmup(self):
        """
        Warm up the state embedding: collect frames from a random policy and
        train the VAE, or train the SR embedding, depending on embedding_type
        """
        if self.embedding_type == 'VAE':
            if self.load_vae_from is not None:
                self.vae.load_state_dict(torch.load(self.load_vae_from))
                self.vae = self.vae.to(self.device)
            else:
                # Collect 1 million frames from random policy
                print("Generating dataset to train VAE from random policy")
                vae_data = []
                state = self.env.reset()
                total_frames = 0
                while total_frames < self.vae_train_frames:
                    action = random.randint(0, self.env.action_space.n - 1)
                    state, reward, done, _ = self.env.step(action)
                    vae_data.append(state)
                    total_frames += self.env.skip
                    if done:
                        state = self.env.reset()
                # Dataset, Dataloader for 1 million frames
                vae_data = torch.tensor(
                    vae_data
                )  # (N x H x W x C) - (1mill/skip X 84 X 84 X frames_to_stack)
                vae_data = vae_data.permute(0, 3, 1, 2)  # (N x C x H x W)
                vae_dataset = TensorDataset(vae_data)
                vae_dataloader = DataLoader(vae_dataset,
                                            batch_size=self.vae_batch_size,
                                            shuffle=True)
                # Training loop
                print("Training VAE")
                self.vae.train()
                for epoch in range(self.vae_epochs):
                    train_loss = 0
                    for batch_idx, batch in enumerate(vae_dataloader):
                        batch = batch[0].to(self.device)
                        self.optimizer.zero_grad()
                        recon_batch, mu, logvar = self.vae(batch)
                        loss = self.vae_loss(recon_batch, batch, mu, logvar)
                        train_loss += loss.item()
                        loss.backward()
                        self.optimizer.step()
                        if batch_idx % self.vae_print_every == 0:
                            msg = 'VAE Epoch: {} [{}/{}]\tLoss: {:.6f}'.format(
                                epoch, batch_idx * len(batch),
                                len(vae_dataloader.dataset),
                                loss.item() / len(batch))
                            print(msg)
                    print('====> Epoch {} Average loss: {:.4f}'.format(
                        epoch, train_loss / len(vae_dataloader.dataset)))
                    if self.vae_weights_file is not None:
                        torch.save(self.vae.state_dict(),
                                   self.vae_weights_file)
            self.vae.eval()
        elif self.embedding_type == 'SR':
            if self.SR_embedding_type == 'random':
                if self.SR_train_algo == 'TD':
                    total_frames = 0
                    transitions = []
                    while total_frames < self.SR_train_frames:
                        observation = self.env.reset()
                        s_t = self.env.state  # will not work on Atari
                        done = False
                        while not done:
                            action = np.random.randint(self.env.action_space.n)
                            observation, reward, done, _ = self.env.step(
                                action)
                            s_tp1 = self.env.state  # will not work on Atari
                            transitions.append((s_t, s_tp1))
                            total_frames += self.env.skip
                            s_t = s_tp1
                    # Dataset, Dataloader
                    dataset = SRDataset(self.env, self.projection, transitions)
                    dataloader = DataLoader(dataset,
                                            batch_size=self.SR_batch_size,
                                            shuffle=True)
                    train_losses = []
                    #Training loop
                    for epoch in range(self.SR_epochs):
                        for batch_idx, batch in enumerate(dataloader):
                            self.optimizer.zero_grad()
                            e_t, e_tp1 = batch
                            e_t = e_t.to(self.device)
                            e_tp1 = e_tp1.to(self.device)
                            mhat_t = self.mlp(e_t)
                            mhat_tp1 = self.mlp(e_tp1)
                            target = e_t + self.gamma * mhat_tp1.detach()
                            loss = self.loss_fn(mhat_t, target)
                            loss.backward()
                            self.optimizer.step()
                            train_losses.append(loss.item())
                        print("Epoch:", epoch, "Average loss",
                              np.mean(train_losses))

                    emb_reps = np.zeros(
                        [self.env.n_states, self.embedding_size])
                    SR_reps = np.zeros(
                        [self.env.n_states, self.embedding_size])
                    labels = []
                    room_size = self.env.room_size
                    for i, (state,
                            obs) in enumerate(self.env.state_dict.items()):
                        emb = np.dot(self.projection, obs.flatten())
                        emb_reps[i, :] = emb
                        with torch.no_grad():
                            emb = torch.tensor(emb).to(self.device)
                            SR = self.mlp(emb).cpu().numpy()
                        SR_reps[i, :] = SR
                        if state[0] < room_size + 1 and state[
                                1] < room_size + 1:
                            label = 0
                        elif state[0] > room_size + 1 and state[
                                1] < room_size + 1:
                            label = 1
                        elif state[0] < room_size + 1 and state[
                                1] > room_size + 1:
                            label = 2
                        elif state[0] > room_size + 1 and state[
                                1] > room_size + 1:
                            label = 3
                        else:
                            label = 4
                        labels.append(label)
                    np.save('%s_SR_reps.npy' % (self.SR_filename), SR_reps)
                    np.save('%s_emb_reps.npy' % (self.SR_filename), emb_reps)
                    np.save('%s_labels.npy' % (self.SR_filename), labels)
                elif self.SR_train_algo == 'MC':
                    pass
                elif self.SR_train_algo == 'DP':
                    # Use this to ensure same order every time
                    idx_to_state = {
                        i: state
                        for i, state in enumerate(self.env.state_dict.keys())
                    }
                    state_to_idx = {v: k for k, v in idx_to_state.items()}
                    T = np.zeros([self.env.n_states, self.env.n_states])
                    for i, s in idx_to_state.items():
                        for a in range(4):
                            self.env.state = s
                            _, _, _, _ = self.env.step(a)
                            s_tp1 = self.env.state
                            T[state_to_idx[s], state_to_idx[s_tp1]] += 0.25
                    true_SR = np.eye(self.env.n_states)
                    done = False
                    t = 0
                    while not done:
                        t += 1
                        new_SR = true_SR + (self.SR_gamma**t) * (np.matmul(
                            true_SR, T))
                        done = np.max(np.abs(true_SR - new_SR)) < 1e-10
                        true_SR = new_SR
                    self.true_SR_dict = {}
                    for s, obs in self.env.state_dict.items():
                        idx = state_to_idx[s]
                        self.true_SR_dict[s] = true_SR[idx, :]
        else:
            pass  # random projection doesn't require warmup
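
# A hypothetical usage sketch for the MFEC agent above. The attribute names on
# `args` mirror exactly what __init__ and run_episode read when
# embedding_type == 'random'; the values are illustrative, and make_env()
# stands in for however the project builds its wrapped environment (it must
# expose action_space, state, skip, reset, step and shortest_path_length, as
# the class assumes).
from types import SimpleNamespace

args = SimpleNamespace(
    environment_type='fourrooms', frames_to_stack=1,
    Q_train_algo='MC', use_Q_max=True, force_knn=False,
    weight_neighbors=True, delta=1e-3, seed=0,
    initial_epsilon=1.0, final_epsilon=0.01, epsilon_decay=0.99,
    gamma=0.99, lr=1e-3, q_lr=0.01,
    vae_batch_size=32, vae_epochs=1, embedding_type='random',
    SR_embedding_type='random', embedding_size=64,
    in_height=84, in_width=84,
    max_memory=100000, num_neighbors=11, print_every=10,
)

env = make_env(args)   # hypothetical helper; see the project's train.py
agent = MFEC(env, args, device='cpu')
agent.warmup()         # no-op for the random-projection embedding
for _ in range(100):
    n_extra_steps, episode_frames, total_reward = agent.run_episode()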
Example #3
#TestFile.py


# The two lines below are a Windows PATH workaround for Graphviz
#import os
#os.environ["PATH"] += os.pathsep + 'C:/Program Files (x86)/Graphviz2.38/bin/'

from models.VAE import VAE
from utils.loader import load_celeb_generator, load_mnist

# Positional arguments follow this repo's VAE constructor; the comments are a
# best-effort reading of what each one controls.
vae = VAE((28, 28, 1),      # input shape (H, W, C)
          [32, 64, 64],     # encoder conv filters
          [3, 3, 3],        # encoder kernel sizes
          [2, 2, 1],        # encoder strides
          [64, 64, 1],      # decoder conv filters
          [3, 3, 3],        # decoder kernel sizes
          [2, 2, 1],        # decoder strides
          200)              # latent dimension

(x_train, y_train), (x_test, y_test) = load_mnist()
vae.compile(0.005, 1000)        # learning rate and (likely) reconstruction-loss weight
vae.train(x_train, 100, 1024)   # presumably batch size and number of epochs
Example #4
from tensorflow.keras.datasets import mnist  # public Keras path (tensorflow.python.keras is internal)
from models.VAE import VAE
import tensorflow as tf
import numpy as np

(x_train, _), (x_test, _) = mnist.load_data()
x_train = x_train.astype('float32') / 255.
x_test = x_test.astype('float32') / 255.
x_train = np.reshape(x_train, (len(x_train), 28, 28, 1))
x_test = np.reshape(x_test, (len(x_test), 28, 28, 1))

vae = VAE()
vae.build()
vae.train(x_train, x_test)
Example #5
File: train_VAE.py  Project: jiaj15/SAIL
def train_vae(args, dtype=torch.float32):
    torch.set_default_dtype(dtype)
    state_dim = args.state_dim
    output_path = args.output_path
    # generate state pairs
    expert_traj_raw = list(pickle.load(open(args.expert_traj_path, "rb")))
    state_pairs = generate_pairs(expert_traj_raw,
                                 state_dim,
                                 args.size_per_traj,
                                 max_step=10,
                                 min_step=5)  # tune the step size if needed.
    # shuffle and split
    idx = np.arange(state_pairs.shape[0])
    np.random.shuffle(idx)
    state_pairs = state_pairs[idx, :]
    split = (state_pairs.shape[0] * 19) // 20
    state_tuples = state_pairs[:split, :]
    test_state_tuples = state_pairs[split:, :]
    print(state_tuples.shape)
    print(test_state_tuples.shape)

    goal_model = VAE(state_dim, latent_dim=128)
    optimizer_vae = torch.optim.Adam(goal_model.parameters(), lr=args.model_lr)
    save_path = '{}_softbc_{}_{}'.format(datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S"), args.env_name, \
                                                  args.beta)
    writer = SummaryWriter(log_dir=os.path.join(output_path, 'runs/' +
                                                save_path))

    # NOTE: both branches below convert the full state_pairs array, so the 5%
    # held out above is also seen during training; use state_tuples /
    # test_state_tuples instead to keep the split strict.
    if args.weight:
        state_dim = state_dim + 1

        state_tuples = torch.from_numpy(state_pairs).to(dtype)
        s = state_tuples[:, :state_dim - 1]
        t = state_tuples[:, state_dim:2 * state_dim]

        state_tuples_test = torch.from_numpy(test_state_tuples).to(dtype)
        s_test = state_tuples_test[:, :state_dim - 1]
        t_test = state_tuples_test[:, state_dim:2 * state_dim]
    else:
        state_tuples = torch.from_numpy(state_pairs).to(dtype)
        s = state_tuples[:, :state_dim]
        t = state_tuples[:, state_dim:2 * state_dim]

        state_tuples_test = torch.from_numpy(test_state_tuples).to(dtype)
        s_test = state_tuples_test[:, :state_dim]
        t_test = state_tuples_test[:, state_dim:2 * state_dim]

    for i in range(1, args.iter + 1):
        loss = goal_model.train(s, t, epoch=args.epoch, optimizer=optimizer_vae, \
                                        batch_size=args.optim_batch_size, beta=args.beta, use_weight=args.weight)
        next_states = goal_model.get_next_states(s_test)
        if args.weight:
            val_error = (t_test[:, -1].unsqueeze(1) *
                         (t_test[:, :-1] - next_states)**2).mean()
        else:
            val_error = ((t_test[:, :-1] - next_states)**2).mean()
        writer.add_scalar('loss/vae', loss, i)
        writer.add_scalar('valid/vae', val_error, i)
        if i % args.lr_decay_rate == 0:
            adjust_lr(optimizer_vae, 2.)
        torch.save(
            goal_model.state_dict(),
            os.path.join(output_path,
                         '{}_{}_vae.pt'.format(args.env_name, str(args.beta))))
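
# A hypothetical entry point for train_vae() above; the argument names mirror
# the fields the function actually reads, while the default values are
# illustrative rather than the project's own.
if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--expert_traj_path', type=str, required=True)
    parser.add_argument('--output_path', type=str, default='./output')
    parser.add_argument('--env_name', type=str, default='HalfCheetah-v2')
    parser.add_argument('--state_dim', type=int, default=17)
    parser.add_argument('--size_per_traj', type=int, default=1000)
    parser.add_argument('--model_lr', type=float, default=1e-3)
    parser.add_argument('--beta', type=float, default=1.0)
    parser.add_argument('--weight', action='store_true')
    parser.add_argument('--iter', type=int, default=100)
    parser.add_argument('--epoch', type=int, default=5)
    parser.add_argument('--optim_batch_size', type=int, default=256)
    parser.add_argument('--lr_decay_rate', type=int, default=50)
    train_vae(parser.parse_args())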
Example #6
    # This fragment reuses the validation split's loaders for training.
    dataLoader_train = dataLoader_val
    trainLoader = valLoader

    epoch = 0
    gen_iterations = 1
    display_iters = 5
    recon_loss_sum, kl_loss_sum = 0, 0
    Validation_loss = []
    loss_L1 = lossL1()

    while epoch < niter:
        epoch += 1

        # training phase
        vae3d.train()
        for idx, (rdfs, masks, weights, qsms) in enumerate(trainLoader):
            if gen_iterations % display_iters == 0:
                print('epochs: [%d/%d], batches: [%d/%d], time: %ds, case_validation: %f'
                    % (epoch, niter, idx, dataLoader_train.num_samples//batch_size+1, time.time()-t0, opt['case_validation']))
                print('Recon loss: %f, kl_loss: %f' % (recon_loss_sum/display_iters, kl_loss_sum/display_iters))
                if epoch > 1:
                    print('Validation loss of last epoch: %f' % (Validation_loss[-1]))

                recon_loss_sum, kl_loss_sum = 0, 0

            qsms = (qsms.to(device, dtype=torch.float) + trans) * scale
            masks = masks.to(device, dtype=torch.float)
            qsms = qsms * masks

            recon_loss, kl_loss = vae_train(model=vae3d, optimizer=optimizer, x=qsms, mask=masks)
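
# Hypothetical sketch of the `vae_train` step used above (not the project's
# implementation): one forward pass of the 3D VAE, a masked reconstruction
# term plus the KL term, and a single optimizer step. Assumes the model
# returns (reconstruction, mu, logvar).
import torch

def vae_train_sketch(model, optimizer, x, mask):
    optimizer.zero_grad()
    recon, mu, logvar = model(x)
    # Mean-squared reconstruction error, restricted to voxels inside the mask.
    recon_loss = (((recon - x) * mask) ** 2).sum() / mask.sum()
    # Analytical KL divergence between q(z|x) = N(mu, sigma^2) and N(0, I).
    kl_loss = -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp())
    loss = recon_loss + kl_loss
    loss.backward()
    optimizer.step()
    return recon_loss.item(), kl_loss.item()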
Example #7
class VAE_TRAINER():
    def __init__(self, params):

        self.params = params
        self.loss_function = {
            'ms-ssim': ms_ssim_loss,
            'mse': mse_loss,
            'mix': mix_loss
        }[params["loss"]]

        # Choose device
        self.cuda = params["cuda"] and torch.cuda.is_available()
        torch.manual_seed(params["seed"])
        # Fix numeric divergence due to bug in Cudnn
        torch.backends.cudnn.benchmark = True
        self.device = torch.device("cuda" if self.cuda else "cpu")

        # Prepare data transformations
        red_size = params["img_size"]
        transform_train = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize((red_size, red_size)),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
        ])

        transform_val = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize((red_size, red_size)),
            transforms.ToTensor(),
        ])

        # Initialize Data loaders
        op_dataset = RolloutObservationDataset(params["path_data"],
                                               transform_train,
                                               train=True)
        val_dataset = RolloutObservationDataset(params["path_data"],
                                                transform_val,
                                                train=False)

        self.train_loader = torch.utils.data.DataLoader(
            op_dataset,
            batch_size=params["batch_size"],
            shuffle=True,
            num_workers=0)
        self.eval_loader = torch.utils.data.DataLoader(
            val_dataset,
            batch_size=params["batch_size"],
            shuffle=False,
            num_workers=0)

        # Initialize model and hyperparams
        self.model = VAE(nc=3,
                         ngf=64,
                         ndf=64,
                         latent_variable_size=params["latent_size"],
                         cuda=self.cuda).to(self.device)
        self.optimizer = optim.Adam(self.model.parameters())
        self.init_vae_model()
        self.visualize = params["visualize"]
        if self.visualize:
            self.plotter = VisdomLinePlotter(env_name=params['env'])
            self.img_plotter = VisdomImagePlotter(env_name=params['env'])
        self.alpha = params["alpha"] if params["alpha"] else 1.0

    def train(self, epoch):
        self.model.train()
        # dataset_train.load_next_buffer()
        mse_loss = 0
        ssim_loss = 0
        train_loss = 0
        # Train step
        for batch_idx, data in enumerate(self.train_loader):
            data = data.to(self.device)
            self.optimizer.zero_grad()
            recon_batch, mu, logvar = self.model(data)
            loss, mse, ssim = self.loss_function(recon_batch, data, mu, logvar,
                                                 self.alpha)
            loss.backward()

            train_loss += loss.item()
            ssim_loss += ssim
            mse_loss += mse
            self.optimizer.step()

            if batch_idx % self.params["log_interval"] == 0:
                print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    epoch, batch_idx * len(data),
                    len(self.train_loader.dataset),
                    100. * batch_idx / len(self.train_loader), loss.item()))
                print('MSE: {} , SSIM: {:.4f}'.format(mse, ssim))

        step = len(self.train_loader.dataset) / float(
            self.params["batch_size"])
        mean_train_loss = train_loss / step
        mean_ssim_loss = ssim_loss / step
        mean_mse_loss = mse_loss / step
        print('-- Epoch: {} Average loss: {:.4f}'.format(
            epoch, mean_train_loss))
        print('-- Average MSE: {:.5f} Average SSIM: {:.4f}'.format(
            mean_mse_loss, mean_ssim_loss))
        if self.visualize:
            self.plotter.plot('loss', 'train', 'VAE Train Loss', epoch,
                              mean_train_loss)
        return

    def eval(self):
        self.model.eval()
        # dataset_test.load_next_buffer()
        eval_loss = 0
        mse_loss = 0
        ssim_loss = 0
        vis = True
        with torch.no_grad():
            # Eval step
            for data in self.eval_loader:
                data = data.to(self.device)
                recon_batch, mu, logvar = self.model(data)

                loss, mse, ssim = self.loss_function(recon_batch, data, mu,
                                                     logvar, self.alpha)
                eval_loss += loss.item()
                ssim_loss += ssim
                mse_loss += mse
                if vis:
                    # `epoch` is read from the surrounding training script
                    # (a module-level variable); see the usage sketch after
                    # this class.
                    org_title = "Epoch: " + str(epoch)
                    comparison1 = torch.cat([
                        data[:4],
                        recon_batch.view(self.params["batch_size"], 3,
                                         self.params["img_size"],
                                         self.params["img_size"])[:4]
                    ])
                    if self.visualize:
                        self.img_plotter.plot(comparison1, org_title)
                    vis = False

        step = len(self.eval_loader.dataset) / float(self.params["batch_size"])
        mean_eval_loss = eval_loss / step
        mean_ssim_loss = ssim_loss / step
        mean_mse_loss = mse_loss / step
        print('-- Eval set loss: {:.4f}'.format(mean_eval_loss))
        print('-- Eval MSE: {:.5f} Eval SSIM: {:.4f}'.format(
            mean_mse_loss, mean_ssim_loss))
        if self.visualize:
            self.plotter.plot('loss', 'eval', 'VAE Eval Loss', epoch,
                              mean_eval_loss)
            self.plotter.plot('loss', 'mse train', 'VAE MSE Loss', epoch,
                              mean_mse_loss)
            self.plotter.plot('loss', 'ssim train', 'VAE SSIM Loss', epoch,
                              mean_ssim_loss)

        return mean_eval_loss

    def init_vae_model(self):
        self.vae_dir = os.path.join(self.params["logdir"], 'vae')
        check_dir(self.vae_dir, 'samples')
        if not self.params["noreload"]:  # and os.path.exists(reload_file):
            reload_file = os.path.join(self.params["vae_location"], 'best.tar')
            state = torch.load(reload_file)
            print("Reloading model at epoch {}"
                  ", with eval error {}".format(state['epoch'],
                                                state['precision']))
            self.model.load_state_dict(state['state_dict'])
            self.optimizer.load_state_dict(state['optimizer'])

    def checkpoint(self, cur_best, eval_loss):
        # Save the best and last checkpoint
        best_filename = os.path.join(self.vae_dir, 'best.tar')
        filename = os.path.join(self.vae_dir, 'checkpoint.tar')
        is_best = not cur_best or eval_loss < cur_best
        if is_best:
            cur_best = eval_loss

        save_checkpoint(
            {
                'epoch': epoch,
                'state_dict': self.model.state_dict(),
                'precision': eval_loss,
                'optimizer': self.optimizer.state_dict()
            }, is_best, filename, best_filename)
        return cur_best

    def plot(self, train, eval, epochs):
        plt.plot(epochs, train, label="train loss")
        plt.plot(epochs, eval, label="eval loss")
        plt.legend()
        plt.grid()
        plt.savefig(self.params["logdir"] + "/vae_training_curve.png")
        plt.close()
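
# A hypothetical driver for VAE_TRAINER above (this snippet does not include
# one). The dictionary keys mirror the parameters the class actually reads;
# the values are illustrative, not the project's defaults. Note that eval()
# and checkpoint() also read a module-level `epoch`, which the loop below
# provides because it runs at module scope.
params = {
    "loss": "mse", "cuda": True, "seed": 123, "img_size": 64,
    "path_data": "datasets/rollouts", "batch_size": 32, "latent_size": 128,
    "visualize": False, "env": "vae", "alpha": 1.0, "logdir": "logs",
    "noreload": True, "vae_location": "logs/vae", "log_interval": 50,
}

if __name__ == '__main__':
    trainer = VAE_TRAINER(params)
    cur_best = None
    for epoch in range(1, 11):
        trainer.train(epoch)
        eval_loss = trainer.eval()
        cur_best = trainer.checkpoint(cur_best, eval_loss)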
Example #8
def train_vae(
    model: VAE,
    optimizer,
    train_loader: DataLoader,
    valid_loader: DataLoader,
    nr_epochs: int,
    device: str,
    results_writer: ResultsWriter,
    config: Config,
    decoder,
):
    """
    Trains the VAE, based on a config file
    """

    # Define highest values
    best_valid_loss = np.inf
    previous_valid_loss = np.inf

    # Create elbo loss function
    loss_fn = make_elbo_criterion(
        vocab_size=model.vocab_size,
        latent_size=model.latent_size,
        freebits_param=config.freebits_param,
        mu_force_beta_param=config.mu_force_beta_param)

    for epoch in range(nr_epochs):
        for idx, (train_batch, batch_sent_lengths) in enumerate(train_loader):
            it = epoch * len(train_loader) + idx

            batch_loss, perp, preds = train_batch_vae(
                model, optimizer, loss_fn, train_batch, device,
                config.mu_force_beta_param, batch_sent_lengths, results_writer,
                it)
            elbo_loss, kl_loss, nlll, mu_loss = batch_loss

            if it % config.print_every == 0:
                print(
                    f'Iteration: {it} || NLLL: {nlll} || Perp: {perp} || KL Loss: {kl_loss} || MuLoss: {mu_loss} || Total: {elbo_loss}'
                )

                # Store in the table
                train_vae_results = make_vae_results_dict(
                    batch_loss, perp, model, config, epoch, it)
                results_writer.add_train_batch_results(train_vae_results)

            if it % config.train_text_gen_every == 0:
                with torch.no_grad():
                    decoded_first_pred = decoder(preds.detach())
                    decoded_first_true = decoder(train_batch[:, 1:])
                    results_writer.add_sentence_predictions(
                        decoded_first_pred, decoded_first_true, it)

                    print(f'VAE is generating sentences on {it}: \n')
                    print(
                        f'\t The true sentence is: "{decoded_first_true}" \n')
                    print(
                        f'\t The predicted sentence is: "{decoded_first_pred}" \n'
                    )

            if idx % config.validate_every == 0 and it != 0:
                print('Validating model')
                valid_losses, valid_perp = evaluate_vae(
                    model,
                    valid_loader,
                    epoch,
                    device,
                    loss_fn,
                    config.mu_force_beta_param,
                    iteration=it)

                # Store validation results
                valid_vae_results = make_vae_results_dict(
                    valid_losses, valid_perp, model, config, epoch, it)
                results_writer.add_valid_results(valid_vae_results)

                valid_elbo_loss, valid_kl_loss, valid_nll_loss, valid_mu_loss = valid_losses
                print(
                    f'Validation Results || Elbo loss: {valid_elbo_loss} || KL loss: {valid_kl_loss} || NLLL {valid_nll_loss} || Perp: {valid_perp} ||MU loss {valid_mu_loss}'
                )

                # Check if the model is better and save
                previous_valid_loss = valid_elbo_loss
                if previous_valid_loss < best_valid_loss:
                    print(
                        f'New Best Validation score of {previous_valid_loss}!')
                    best_valid_loss = previous_valid_loss
                    save_model(
                        f'vae_best_mu{config.mu_force_beta_param}_wd{model.param_wdropout_k}_fb{config.freebits_param}',
                        model, optimizer, it)

                model.train()

    results_writer.save_train_results()
    results_writer.save_valid_results()

    print('Done training the VAE')
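
# The ELBO criterion used above comes from make_elbo_criterion(), whose exact
# definition is not shown here. Below is a hedged, self-contained sketch of
# one common formulation it could correspond to: a sequence NLL plus a KL term
# with a per-dimension "free bits" floor, and an illustrative penalty on the
# posterior means (our reading of mu_force_beta_param; treat it as an
# assumption rather than this repo's definition).
import torch
import torch.nn.functional as F

def elbo_with_free_bits(logits, targets, mu, logvar,
                        freebits_param, mu_force_beta_param, pad_idx=0):
    # Negative log-likelihood of the target tokens, averaged over the batch.
    nll = F.cross_entropy(logits.reshape(-1, logits.size(-1)),
                          targets.reshape(-1),
                          ignore_index=pad_idx,
                          reduction='sum') / targets.size(0)
    # Per-dimension KL between q(z|x) = N(mu, sigma^2) and the standard normal.
    kl_per_dim = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp())
    # Free bits: each latent dimension contributes at least `freebits_param` nats.
    kl = torch.clamp(kl_per_dim.mean(dim=0), min=freebits_param).sum()
    # Illustrative mu penalty, scaled by mu_force_beta_param (assumption).
    mu_loss = mu_force_beta_param * mu.pow(2).mean()
    elbo = nll + kl + mu_loss
    return elbo, kl, nll, mu_loss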