if 'vae' in model_path:
        model = VAE(encoder_hidden_size=config.vae_encoder_hidden_size,
                    decoder_hidden_size=config.vae_decoder_hidden_size,
                    latent_size=config.vae_latent_size,
                    vocab_size=config.vocab_size,
                    param_wdropout_k=-1,
                    embedding_size=config.embedding_size).to(config.device)

        loss_fn = make_elbo_criterion(config.vocab_size,
                                      config.vae_latent_size, -1, 0)

        prior = torch.distributions.Normal(torch.zeros(model.latent_size),
                                           torch.ones(model.latent_size))
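        # Standard normal prior N(0, I) over the VAE latent space.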

        path_to_model = f'results/saved_models/{model_path}'
        model, _, _ = utils.load_model(path_to_model, model, config.device)
        model = model.to(config.device)

        vae_results_writer: SummaryWriter = SummaryWriter(
            comment=f"EVAL_{config.run_label}--{model_path}")
        (test_total_loss, test_total_kl_loss, test_total_nlll,
         test_total_mu_loss), test_perp = evaluate_vae(model, test_loader, -1,
                                                       config.device, loss_fn,
                                                       0, vae_results_writer,
                                                       'test')

        print(f'For model {model_path}: \n')
        print(
            f'Test Results || Elbo loss: {test_total_loss} || KL loss: {test_total_kl_loss} || NLLL: {test_total_nlll} || Perp: {test_perp} || MU loss: {test_total_mu_loss}'
        )
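# Note: for a VAE evaluated with an ELBO criterion, the reported total loss
# decomposes into the reconstruction NLL plus the KL term, and perplexity is
# conventionally exp(NLL per token).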
class MFEC:
    def __init__(self, env, args, device='cpu'):
        """
        Instantiate an MFEC Agent
        ----------
        env: gym.Env
            gym environment to train on
        args: args class from argparser
            args are from from train.py: see train.py for help with each arg
        device: string
            'cpu' or 'cuda:0' depending on use_cuda flag from train.py
        """
        self.environment_type = args.environment_type
        self.env = env
        self.actions = range(self.env.action_space.n)
        self.frames_to_stack = args.frames_to_stack
        self.Q_train_algo = args.Q_train_algo
        self.use_Q_max = args.use_Q_max
        self.force_knn = args.force_knn
        self.weight_neighbors = args.weight_neighbors
        self.delta = args.delta
        self.device = device
        self.rs = np.random.RandomState(args.seed)

        # Hyperparameters
        self.epsilon = args.initial_epsilon
        self.final_epsilon = args.final_epsilon
        self.epsilon_decay = args.epsilon_decay
        self.gamma = args.gamma
        self.lr = args.lr
        self.q_lr = args.q_lr

        # Autoencoder for state embedding network
        self.vae_batch_size = args.vae_batch_size  # batch size for training VAE
        self.vae_epochs = args.vae_epochs  # number of epochs to run VAE
        self.embedding_type = args.embedding_type
        self.SR_embedding_type = args.SR_embedding_type
        self.embedding_size = args.embedding_size
        self.in_height = args.in_height
        self.in_width = args.in_width

        if self.embedding_type == 'VAE':
            self.vae_train_frames = args.vae_train_frames
            self.vae_loss = VAELoss()
            self.vae_print_every = args.vae_print_every
            self.load_vae_from = args.load_vae_from
            self.vae_weights_file = args.vae_weights_file
            self.vae = VAE(self.frames_to_stack, self.embedding_size,
                           self.in_height, self.in_width)
            self.vae = self.vae.to(self.device)
            self.optimizer = get_optimizer(args.optimizer,
                                           self.vae.parameters(), self.lr)
        elif self.embedding_type == 'random':
            self.projection = self.rs.randn(
                self.embedding_size, self.in_height * self.in_width *
                self.frames_to_stack).astype(np.float32)
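            # Fixed Gaussian random projection: the flattened stack of frames
            # (frames_to_stack * in_height * in_width) is mapped to an
            # embedding_size-dimensional vector.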
        elif self.embedding_type == 'SR':
            self.SR_train_algo = args.SR_train_algo
            self.SR_gamma = args.SR_gamma
            self.SR_epochs = args.SR_epochs
            self.SR_batch_size = args.SR_batch_size
            self.n_hidden = args.n_hidden
            self.SR_train_frames = args.SR_train_frames
            self.SR_filename = args.SR_filename
            if self.SR_embedding_type == 'random':
                self.projection = np.random.randn(
                    self.embedding_size,
                    self.in_height * self.in_width).astype(np.float32)
                if self.SR_train_algo == 'TD':
                    self.mlp = MLP(self.embedding_size, self.n_hidden)
                    self.mlp = self.mlp.to(self.device)
                    self.loss_fn = nn.MSELoss(reduction='mean')
                    params = self.mlp.parameters()
                    self.optimizer = get_optimizer(args.optimizer, params,
                                                   self.lr)

        # QEC
        self.max_memory = args.max_memory
        self.num_neighbors = args.num_neighbors
        self.qec = QEC(self.actions, self.max_memory, self.num_neighbors,
                       self.use_Q_max, self.force_knn, self.weight_neighbors,
                       self.delta, self.q_lr)
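        # QEC holds one episodic value memory per action, bounded by
        # max_memory entries and queried with num_neighbors nearest
        # neighbours.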

        #self.state = np.empty(self.embedding_size, self.projection.dtype)
        #self.action = int
        self.memory = []
        self.print_every = args.print_every
        self.episodes = 0

    def choose_action(self, values):
        """
        Choose epsilon-greedy policy according to Q-estimates
        """
        # Exploration
        if self.rs.random_sample() < self.epsilon:
            self.action = self.rs.choice(self.actions)

        # Exploitation
        else:
            best_actions = np.argwhere(values == np.max(values)).flatten()
            self.action = self.rs.choice(best_actions)

        return self.action

    def TD_update(self, prev_embedding, prev_action, reward, values, time):
        # On-policy value estimate of the current state (epsilon-greedy)
        # Expected Sarsa
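        # Under the epsilon-greedy policy the expected next-state value is
        # V(s') = (1 - epsilon) * max_a Q(s', a) + epsilon * mean_a Q(s', a),
        # which is the Expected Sarsa backup computed below.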
        v_t = (1 -
               self.epsilon) * np.max(values) + self.epsilon * np.mean(values)
        value = reward + self.gamma * v_t
        self.qec.update(prev_embedding, prev_action, value, time - 1)

    def MC_update(self):
        value = 0.0
        for _ in range(len(self.memory)):
            experience = self.memory.pop()
            value = value * self.gamma + experience["reward"]
            self.qec.update(
                experience["state"],
                experience["action"],
                value,
                experience["time"],
            )

    def add_to_memory(self, state_embedding, action, reward, time):
        self.memory.append({
            "state": state_embedding,
            "action": action,
            "reward": reward,
            "time": time,
        })

    def run_episode(self):
        """
        Train an MFEC agent for a single episode:
            Interact with environment
            Perform update
        """
        self.episodes += 1
        RENDER_SPEED = 0.04
        RENDER = False

        episode_frames = 0
        total_reward = 0
        total_steps = 0

        # Update epsilon
        if self.epsilon > self.final_epsilon:
            self.epsilon = self.epsilon * self.epsilon_decay

        #self.env.seed(random.randint(0, 1000000))
        state = self.env.reset()
        if self.environment_type == 'fourrooms':
            fewest_steps = self.env.shortest_path_length(self.env.state)
        done = False
        time = 0
        while not done:
            time += 1
            if self.embedding_type == 'random':
                state = np.array(state).flatten()
                state_embedding = np.dot(self.projection, state)
            elif self.embedding_type == 'VAE':
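                # The VAE state embedding is the concatenation of the
                # encoder's posterior mean and log-variance for the current
                # frame stack.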
                state = torch.tensor(state).permute(2, 0, 1)  #(H,W,C)->(C,H,W)
                state = state.unsqueeze(0).to(self.device)
                with torch.no_grad():
                    mu, logvar = self.vae.encoder(state)
                    state_embedding = torch.cat([mu, logvar], 1)
                    state_embedding = state_embedding.squeeze()
                    state_embedding = state_embedding.cpu().numpy()
            elif self.embedding_type == 'SR':
                if self.SR_train_algo == 'TD':
                    state = np.array(state).flatten()
                    state_embedding = np.dot(self.projection, state)
                    with torch.no_grad():
                        state_embedding = self.mlp(
                            torch.tensor(state_embedding)).cpu().numpy()
                elif self.SR_train_algo == 'DP':
                    s = self.env.state
                    state_embedding = self.true_SR_dict[s]
            state_embedding = state_embedding / np.linalg.norm(state_embedding)
            if RENDER:
                self.env.render()
                # The local step counter `time` shadows the time module,
                # so import it under an alias for the frame delay.
                import time as _time
                _time.sleep(RENDER_SPEED)

            # Get estimated value of each action
            values = [
                self.qec.estimate(state_embedding, action)
                for action in self.actions
            ]

            action = self.choose_action(values)
            state, reward, done, _ = self.env.step(action)
            if self.Q_train_algo == 'MC':
                self.add_to_memory(state_embedding, action, reward, time)
            elif self.Q_train_algo == 'TD':
                if time > 1:
                    self.TD_update(prev_embedding, prev_action, prev_reward,
                                   values, time)
            prev_reward = reward
            prev_embedding = state_embedding
            prev_action = action
            total_reward += reward
            total_steps += 1
            episode_frames += self.env.skip

        if self.Q_train_algo == 'MC':
            self.MC_update()
        if self.episodes % self.print_every == 0:
            print("KNN usage:", np.mean(self.qec.knn_usage))
            self.qec.knn_usage = []
            print("Proportion of replace:", np.mean(self.qec.replace_usage))
            self.qec.replace_usage = []
        if self.environment_type == 'fourrooms':
            n_extra_steps = total_steps - fewest_steps
            return n_extra_steps, episode_frames, total_reward
        else:
            return episode_frames, total_reward

    def warmup(self):
        """
        Collect 1 million frames from random policy and train VAE
        """
        if self.embedding_type == 'VAE':
            if self.load_vae_from is not None:
                self.vae.load_state_dict(torch.load(self.load_vae_from))
                self.vae = self.vae.to(self.device)
            else:
                # Collect vae_train_frames frames from a random policy
                print("Generating dataset to train VAE from random policy")
                vae_data = []
                state = self.env.reset()
                total_frames = 0
                while total_frames < self.vae_train_frames:
                    action = random.randint(0, self.env.action_space.n - 1)
                    state, reward, done, _ = self.env.step(action)
                    vae_data.append(state)
                    total_frames += self.env.skip
                    if done:
                        state = self.env.reset()
                # Dataset and DataLoader over the collected frames
                vae_data = torch.tensor(
                    vae_data
                )  # (N x H x W x C), N = vae_train_frames / skip
                vae_data = vae_data.permute(0, 3, 1, 2)  # (N x C x H x W)
                vae_dataset = TensorDataset(vae_data)
                vae_dataloader = DataLoader(vae_dataset,
                                            batch_size=self.vae_batch_size,
                                            shuffle=True)
                # Training loop
                print("Training VAE")
                self.vae.train()
                for epoch in range(self.vae_epochs):
                    train_loss = 0
                    for batch_idx, batch in enumerate(vae_dataloader):
                        batch = batch[0].to(self.device)
                        self.optimizer.zero_grad()
                        recon_batch, mu, logvar = self.vae(batch)
                        loss = self.vae_loss(recon_batch, batch, mu, logvar)
                        train_loss += loss.item()
                        loss.backward()
                        self.optimizer.step()
                        if batch_idx % self.vae_print_every == 0:
                            msg = 'VAE Epoch: {} [{}/{}]\tLoss: {:.6f}'.format(
                                epoch, batch_idx * len(batch),
                                len(vae_dataloader.dataset),
                                loss.item() / len(batch))
                            print(msg)
                    print('====> Epoch {} Average loss: {:.4f}'.format(
                        epoch, train_loss / len(vae_dataloader.dataset)))
                    if self.vae_weights_file is not None:
                        torch.save(self.vae.state_dict(),
                                   self.vae_weights_file)
            self.vae.eval()
        elif self.embedding_type == 'SR':
            if self.SR_embedding_type == 'random':
                if self.SR_train_algo == 'TD':
                    total_frames = 0
                    transitions = []
                    while total_frames < self.SR_train_frames:
                        observation = self.env.reset()
                        s_t = self.env.state  # will not work on Atari
                        done = False
                        while not done:
                            action = np.random.randint(self.env.action_space.n)
                            observation, reward, done, _ = self.env.step(
                                action)
                            s_tp1 = self.env.state  # will not work on Atari
                            transitions.append((s_t, s_tp1))
                            total_frames += self.env.skip
                            s_t = s_tp1
                    # Dataset, Dataloader
                    dataset = SRDataset(self.env, self.projection, transitions)
                    dataloader = DataLoader(dataset,
                                            batch_size=self.SR_batch_size,
                                            shuffle=True)
                    train_losses = []
                    #Training loop
                    for epoch in range(self.SR_epochs):
                        for batch_idx, batch in enumerate(dataloader):
                            self.optimizer.zero_grad()
                            e_t, e_tp1 = batch
                            e_t = e_t.to(self.device)
                            e_tp1 = e_tp1.to(self.device)
                            mhat_t = self.mlp(e_t)
                            mhat_tp1 = self.mlp(e_tp1)
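                            # TD(0) target for the successor features:
                            # M(s_t) <- phi(s_t) + gamma * M(s_{t+1});
                            # the next-state prediction is detached so only
                            # M(s_t) receives gradients.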
                            target = e_t + self.gamma * mhat_tp1.detach()
                            loss = self.loss_fn(mhat_t, target)
                            loss.backward()
                            self.optimizer.step()
                            train_losses.append(loss.item())
                        print("Epoch:", epoch, "Average loss",
                              np.mean(train_losses))

                    emb_reps = np.zeros(
                        [self.env.n_states, self.embedding_size])
                    SR_reps = np.zeros(
                        [self.env.n_states, self.embedding_size])
                    labels = []
                    room_size = self.env.room_size
                    for i, (state,
                            obs) in enumerate(self.env.state_dict.items()):
                        emb = np.dot(self.projection, obs.flatten())
                        emb_reps[i, :] = emb
                        with torch.no_grad():
                            emb = torch.tensor(emb).to(self.device)
                            SR = self.mlp(emb).cpu().numpy()
                        SR_reps[i, :] = SR
                        if (state[0] < room_size + 1
                                and state[1] < room_size + 1):
                            label = 0
                        elif (state[0] > room_size + 1
                              and state[1] < room_size + 1):
                            label = 1
                        elif (state[0] < room_size + 1
                              and state[1] > room_size + 1):
                            label = 2
                        elif (state[0] > room_size + 1
                              and state[1] > room_size + 1):
                            label = 3
                        else:
                            label = 4
                        labels.append(label)
                    np.save('%s_SR_reps.npy' % (self.SR_filename), SR_reps)
                    np.save('%s_emb_reps.npy' % (self.SR_filename), emb_reps)
                    np.save('%s_labels.npy' % (self.SR_filename), labels)
                elif self.SR_train_algo == 'MC':
                    pass
                elif self.SR_train_algo == 'DP':
                    # Use this to ensure same order every time
                    idx_to_state = {
                        i: state
                        for i, state in enumerate(self.env.state_dict.keys())
                    }
                    state_to_idx = {v: k for k, v in idx_to_state.items()}
                    T = np.zeros([self.env.n_states, self.env.n_states])
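                    # Transition matrix of the uniform random policy: each of
                    # the four actions is taken with probability 1/4.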
                    for i, s in idx_to_state.items():
                        for a in range(4):
                            self.env.state = s
                            _, _, _, _ = self.env.step(a)
                            s_tp1 = self.env.state
                            T[state_to_idx[s], state_to_idx[s_tp1]] += 0.25
                    # Fixed-point iteration for the exact SR of the uniform
                    # random policy: M = I + gamma * T @ M, whose solution is
                    # (I - gamma * T)^{-1}.
                    true_SR = np.eye(self.env.n_states)
                    done = False
                    while not done:
                        new_SR = (np.eye(self.env.n_states) +
                                  self.SR_gamma * np.matmul(T, true_SR))
                        done = np.max(np.abs(true_SR - new_SR)) < 1e-10
                        true_SR = new_SR
                    self.true_SR_dict = {}
                    for s, obs in self.env.state_dict.items():
                        idx = state_to_idx[s]
                        self.true_SR_dict[s] = true_SR[idx, :]
        else:
            pass  # random projection doesn't require warmup
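# A minimal usage sketch for the agent above (hypothetical `env`, `args` and
# `num_episodes`; the real ones come from train.py):
#
#   agent = MFEC(env, args, device='cpu')
#   agent.warmup()  # train the chosen state embedding (VAE / SR)
#   for _ in range(num_episodes):
#       episode_frames, total_reward = agent.run_episode()
#       # (for environment_type == 'fourrooms', run_episode also returns the
#       #  number of extra steps taken beyond the shortest path)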
Example #3
def get_model(args, input_shape, dataset_type, output_dir=None):
    """
    Selects and initialises a PyTorch model based on the given arguments.
    Raises a NotImplementedError if the required model is not available.

    Parameters
    ----------
    args: SimpleNamespace
        Namespace containing the arguments of the experiment
    input_shape: tuple of int
        Shape of the input WITHOUT batch size
    dataset_type: str
        Which type of dataset we are using, i.e. binary or continuous
    output_dir: Path
        path to the output directory

    Returns
    -------
    nn.Module:
        The selected model, correctly initialised

    """

    model = None

    if args.type_vae == 'standard':

        if args.architecture == 'shallow':
            model = VAE(dimension_latent_space=args.latent_size,
                        input_shape=input_shape,
                        dataset_type=dataset_type,
                        device=args.device)

        elif args.architecture == 'deep':
            model = DeepVAE(dimension_latent_space=args.latent_size,
                            input_shape=input_shape,
                            dataset_type=dataset_type,
                            device=args.device)

    elif args.type_vae == 'copula':

        if args.architecture == 'shallow':

            if args.marginals == "gaussian":
                model = GaussianCopulaVAE(
                    dimension_latent_space=args.latent_size,
                    input_shape=input_shape,
                    dataset_type=dataset_type,
                    device=args.device)

            elif args.marginals == "laplace":
                model = LaplaceCopulaVAE(
                    dimension_latent_space=args.latent_size,
                    input_shape=input_shape,
                    dataset_type=dataset_type,
                    device=args.device)

            elif args.marginals == "log_norm":
                model = LogNormalCopulaVAE(
                    dimension_latent_space=args.latent_size,
                    input_shape=input_shape,
                    dataset_type=dataset_type,
                    device=args.device)

            elif args.marginals == "exp":
                model = ExponentialCopulaVAE(
                    dimension_latent_space=args.latent_size,
                    input_shape=input_shape,
                    dataset_type=dataset_type,
                    device=args.device)

            elif args.marginals == "mix":
                model = MixCopulaVAE(dimension_latent_space=args.latent_size,
                                     input_shape=input_shape,
                                     dataset_type=dataset_type,
                                     device=args.device)

    if model is None:
        error = f'We do not support {args.type_vae} with {args.architecture} as architecture'
        raise NotImplementedError(error)

    if output_dir is not None:
        with open(output_dir / 'model.txt', 'w') as f:
            print(model, file=f)

    if torch.cuda.device_count() > 1:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        model = nn.DataParallel(model)

    model = model.to(args.device)

    return model
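# A minimal usage sketch (hypothetical argument values; SimpleNamespace stands
# in for the experiment's argparse namespace):
#
#   from types import SimpleNamespace
#   args = SimpleNamespace(type_vae='copula', architecture='shallow',
#                          marginals='gaussian', latent_size=32, device='cpu')
#   model = get_model(args, input_shape=(1, 28, 28), dataset_type='binary')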
Example #4
        patchSize_padding = patchSize
        extraction_step = (21, 21, 21)
        voxel_size = (1, 1, 1)
        D = dipole_kernel(patchSize_padding, voxel_size, B0_dir)

    # network
    vae3d = VAE(
        input_channels=1, 
        output_channels=2,
        latent_dim=latent_dim,
        use_deconv=use_deconv,
        renorm=renorm,
        flag_r_train=0
    )

    vae3d.to(device)
    print(vae3d)

    # optimizer
    optimizer = optim.Adam(vae3d.parameters(), lr = lr, betas=(0.5, 0.999))
    ms = [0.3, 0.5, 0.7, 0.9]
    ms = [np.floor(m * niter).astype(int) for m in ms]
    scheduler = MultiStepLR(optimizer, milestones = ms, gamma = 0.2)
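    # The learning rate is decayed by a factor of 0.2 at 30%, 50%, 70% and
    # 90% of the niter training iterations.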

    # logger
    logger = Logger('logs', rootDir, opt['flag_rsa'], opt['case_validation'], opt['case_test'])

    # dataloader
    # dataLoader_train = COSMOS_data_loader(
    #     split='Train',
    #     patchSize=patchSize,