Example #1
0
 def initialize_encoder(self):
     """Build the recurrent trajectory encoder and store it as ``self.encoder``.

     All hyperparameters come from ``self.args``; the module is moved to the
     globally configured device (``ptu.device``).
     """
     cfg = self.args
     encoder = RNNEncoder(
         layers_before_gru=cfg.layers_before_aggregator,
         hidden_size=cfg.aggregator_hidden_size,
         layers_after_gru=cfg.layers_after_aggregator,
         task_embedding_size=cfg.task_embedding_size,
         action_size=cfg.action_dim,
         action_embed_size=cfg.action_embedding_size,
         state_size=cfg.obs_dim,
         state_embed_size=cfg.state_embedding_size,
         reward_size=1,
         reward_embed_size=cfg.reward_embedding_size,
     )
     self.encoder = encoder.to(ptu.device)
    def __init__(self, args, device, load_pretrained_bert=False, bert_config=None):
        """Extractive summarizer: a BERT sentence encoder plus a scoring head.

        The scoring head is chosen by ``args.encoder`` ('classifier',
        'transformer', 'rnn', or 'baseline'); 'baseline' replaces the
        pretrained BERT with a small randomly initialized BertModel.
        """
        super(Summarizer, self).__init__()
        self.args = args
        self.device = device
        self.bert = Bert(args.temp_dir, load_pretrained_bert, bert_config)

        kind = args.encoder
        hidden = self.bert.model.config.hidden_size
        if kind == 'classifier':
            self.encoder = Classifier(hidden)
        elif kind == 'transformer':
            self.encoder = TransformerInterEncoder(hidden, args.ff_size,
                                                   args.heads, args.dropout,
                                                   args.inter_layers)
        elif kind == 'rnn':
            self.encoder = RNNEncoder(bidirectional=True,
                                      num_layers=1,
                                      input_size=hidden,
                                      hidden_size=args.rnn_size,
                                      dropout=args.dropout)
        elif kind == 'baseline':
            # Swap the pretrained BERT for a small, randomly initialized model.
            bert_config = BertConfig(self.bert.model.config.vocab_size,
                                     hidden_size=args.hidden_size,
                                     num_hidden_layers=6,
                                     num_attention_heads=8,
                                     intermediate_size=args.ff_size)
            self.bert.model = BertModel(bert_config)
            self.encoder = Classifier(self.bert.model.config.hidden_size)

        # Optional re-initialization of the scoring head, per CLI flags.
        if args.param_init != 0.0:
            for weight in self.encoder.parameters():
                weight.data.uniform_(-args.param_init, args.param_init)
        if args.param_init_glorot:
            for weight in self.encoder.parameters():
                if weight.dim() > 1:
                    xavier_uniform_(weight)

        self.to(device)
Example #3
0
    def __init__(self, args, device, load_pretrained_bert=False, bert_config=None):
        """Extractive summarizer with optional freezing of early BERT layers.

        ``args.freeze_initial`` > 0 disables gradients for that many bottom
        BERT layers. The scoring head is selected by ``args.encoder``.
        """
        super(Summarizer, self).__init__()
        self.args = args
        self.device = device
        self.bert = Bert(args.temp_dir, load_pretrained_bert, bert_config)

        n_frozen = args.freeze_initial
        if n_frozen > 0:
            # Freeze the bottom layers so they are not updated during training.
            for param in self.bert.model.encoder.layer[0:n_frozen].parameters():
                param.requires_grad = False
            banner = "*" * 80
            print(banner)
            print(banner)
            print("Initial Layers of BERT is frozen, ie first ",
                  n_frozen, "Layers")
            print(self.bert.model.encoder.layer[0:n_frozen])
            print(banner)
            print(banner)

        kind = args.encoder
        hidden = self.bert.model.config.hidden_size
        if kind == 'classifier':
            self.encoder = Classifier(hidden)
        elif kind == 'multi_layer_classifier':
            self.encoder = MultiLayerClassifier(hidden, 32)
        elif kind == 'transformer':
            self.encoder = TransformerInterEncoder(hidden, args.ff_size,
                                                   args.heads, args.dropout,
                                                   args.inter_layers)
        elif kind == 'rnn':
            self.encoder = RNNEncoder(bidirectional=True,
                                      num_layers=1,
                                      input_size=hidden,
                                      hidden_size=args.rnn_size,
                                      dropout=args.dropout)
        elif kind == 'baseline':
            # Swap the pretrained BERT for a small, randomly initialized model.
            bert_config = BertConfig(self.bert.model.config.vocab_size,
                                     hidden_size=args.hidden_size,
                                     num_hidden_layers=6,
                                     num_attention_heads=8,
                                     intermediate_size=args.ff_size)
            self.bert.model = BertModel(bert_config)
            self.encoder = Classifier(self.bert.model.config.hidden_size)

        # Optional re-initialization of the scoring head, per CLI flags.
        if args.param_init != 0.0:
            for weight in self.encoder.parameters():
                weight.data.uniform_(-args.param_init, args.param_init)
        if args.param_init_glorot:
            for weight in self.encoder.parameters():
                if weight.dim() > 1:
                    xavier_uniform_(weight)

        self.to(device)
Example #4
0
File: vae.py  Project: orilinial/RSPP
 def initialise_encoder(self):
     """Construct the RNN encoder and return it (moved to ``device``)."""
     cfg = self.args
     encoder = RNNEncoder(
         layers_before_gru=cfg.encoder_layers_before_gru,
         hidden_size=cfg.encoder_gru_hidden_size,
         layers_after_gru=cfg.encoder_layers_after_gru,
         latent_dim=cfg.latent_dim,
         action_dim=cfg.action_dim,
         action_embed_dim=cfg.action_embedding_size,
         state_dim=cfg.state_dim,
         state_embed_dim=cfg.state_embedding_size,
         reward_size=1,
         reward_embed_size=cfg.reward_embedding_size,
     )
     return encoder.to(device)
Example #5
0
class VAE:
    """Trajectory VAE for meta-RL.

    An RNN encoder (``self.encoder``) maps experience sequences
    (state, action, reward) to a latent task embedding; optional decoders
    reconstruct rewards, state transitions, and the task description/id.
    Decoders disabled via ``args`` are set to ``None``.

    Fixes over the previous revision:
      * ``compute_state_reconstruction_loss`` no longer raises ``NameError``
        when ``state_pred_type == 'gaussian'`` and ``return_predictions=True``.
      * ``compute_kl_loss`` builds its learned-prior mask on the same device
        as the KL tensor (previously CPU-only, crashing on CUDA).
    """

    def __init__(self, args):
        self.args = args

        self.initialize_encoder()
        self.initialize_decoder()
        self.initialize_optimizer()

    def initialize_encoder(self):
        """Initialize the RNN trajectory encoder -- ``self.encoder``."""
        self.encoder = RNNEncoder(
            layers_before_gru=self.args.layers_before_aggregator,
            hidden_size=self.args.aggregator_hidden_size,
            layers_after_gru=self.args.layers_after_aggregator,
            task_embedding_size=self.args.task_embedding_size,
            action_size=self.args.action_dim,
            action_embed_size=self.args.action_embedding_size,
            state_size=self.args.obs_dim,
            state_embed_size=self.args.state_embedding_size,
            reward_size=1,
            reward_embed_size=self.args.reward_embedding_size,
        ).to(ptu.device)

    def initialize_decoder(self):
        """Initialize ``self.reward_decoder``, ``self.state_decoder`` and
        ``self.task_decoder``; each is ``None`` when its flag is off."""
        task_embedding_size = self.args.task_embedding_size
        # Without stochasticity the decoders consume the concatenated
        # (mean, logvar) vector, i.e. twice the embedding size.
        if self.args.disable_stochasticity_in_latent:
            task_embedding_size *= 2

        if self.args.decode_reward:
            # initialise reward decoder for VAE
            self.reward_decoder = RewardDecoder(
                layers=self.args.reward_decoder_layers,
                task_embedding_size=task_embedding_size,
                state_size=self.args.obs_dim,
                state_embed_size=self.args.state_embedding_size,
                action_size=self.args.action_dim,
                action_embed_size=self.args.action_embedding_size,
                num_states=self.args.num_states,
                multi_head=self.args.multihead_for_reward,
                pred_type=self.args.rew_pred_type,
                input_prev_state=self.args.input_prev_state,
                input_next_state=self.args.input_next_state,
                input_action=self.args.input_action,
            ).to(ptu.device)
        else:
            self.reward_decoder = None

        if self.args.decode_state:
            # initialise state decoder for VAE
            self.state_decoder = StateTransitionDecoder(
                task_embedding_size=task_embedding_size,
                layers=self.args.state_decoder_layers,
                action_size=self.args.action_dim,
                action_embed_size=self.args.action_embedding_size,
                state_size=self.args.obs_dim,
                state_embed_size=self.args.state_embedding_size,
                pred_type=self.args.state_pred_type,
            ).to(ptu.device)
        else:
            self.state_decoder = None

        if self.args.decode_task:
            # The task decoder's output width depends on the environment.
            env = gym.make(self.args.env_name)
            if self.args.task_pred_type == 'task_description':
                task_dim = env.task_dim
            elif self.args.task_pred_type == 'task_id':
                task_dim = env.num_tasks
            else:
                raise NotImplementedError
            self.task_decoder = TaskDecoder(
                task_embedding_size=task_embedding_size,
                layers=self.args.task_decoder_layers,
                task_dim=task_dim,
                pred_type=self.args.task_pred_type,
            ).to(ptu.device)
        else:
            self.task_decoder = None

    def initialize_optimizer(self):
        """Create a single Adam optimizer over the encoder and all
        enabled decoders."""
        decoder_params = []
        if not self.args.disable_decoder:
            if self.args.decode_reward:
                decoder_params.extend(self.reward_decoder.parameters())
            if self.args.decode_state:
                decoder_params.extend(self.state_decoder.parameters())
            if self.args.decode_task:
                decoder_params.extend(self.task_decoder.parameters())
        self.optimizer = torch.optim.Adam([*self.encoder.parameters(), *decoder_params], lr=self.args.vae_lr)

    def compute_task_reconstruction_loss(self, dec_embedding, dec_task, return_predictions=False):
        """Per-sample task reconstruction loss (no reduction here).

        Cross-entropy for 'task_id', squared error for 'task_description'.
        Returns ``(loss, prediction)`` when ``return_predictions`` is True.
        """
        task_pred = self.task_decoder(dec_embedding)

        if self.args.task_pred_type == 'task_id':
            env = gym.make(self.args.env_name)
            dec_task = env.task_to_id(dec_task)
            dec_task = dec_task.expand(task_pred.shape[:-1]).view(-1)
            # loss for the data we fed into encoder
            task_pred_shape = task_pred.shape
            loss_task = F.cross_entropy(task_pred.view(-1, task_pred.shape[-1]), dec_task, reduction='none').reshape(
                task_pred_shape[:-1])
        elif self.args.task_pred_type == 'task_description':
            loss_task = (task_pred - dec_task).pow(2).mean(dim=1)
        # NOTE: other task_pred_type values were already rejected in
        # initialize_decoder, so no else branch is needed here.

        if return_predictions:
            return loss_task, task_pred
        else:
            return loss_task

    def compute_state_reconstruction_loss(self, dec_embedding, dec_prev_obs, dec_next_obs, dec_actions,
                                          return_predictions=False):
        """Per-sample state-transition reconstruction loss (no reduction here).

        Returns ``(loss, prediction)`` when ``return_predictions`` is True,
        where the prediction is the decoder output (deterministic) or the
        predicted mean (gaussian).
        """
        if self.args.state_pred_type == 'deterministic':
            state_prediction = self.state_decoder(dec_embedding, dec_prev_obs, dec_actions)
            # NOTE(review): only the first two state dims are reconstructed
            # here ([:, :, :2]) -- presumably position-only targets; confirm
            # against the decoder's output size.
            loss_state = (state_prediction - dec_next_obs[:, :, :2]).pow(2).mean(dim=1)
        elif self.args.state_pred_type == 'gaussian':
            # Decoder output packs mean and log-variance halves.
            state_pred = self.state_decoder(dec_embedding, dec_prev_obs, dec_actions)
            state_pred_mean = state_pred[:, :state_pred.shape[1] // 2]
            state_pred_std = torch.exp(0.5 * state_pred[:, state_pred.shape[1] // 2:])
            m = torch.distributions.normal.Normal(state_pred_mean, state_pred_std)
            # TODO: check if this is correctly averaged
            loss_state = -m.log_prob(dec_next_obs).mean(dim=1)
            # Bug fix: the gaussian branch previously never defined the
            # prediction variable, so return_predictions=True raised NameError.
            state_prediction = state_pred_mean

        if return_predictions:
            return loss_state, state_prediction
        else:
            return loss_state

    def compute_rew_reconstruction_loss(self, dec_embedding, dec_prev_obs, dec_next_obs, dec_actions,
                                        dec_rewards, return_predictions=False):
        """
        Computed the reward reconstruction loss
        (no reduction of loss is done here; sum/avg has to be done outside)
        """
        if self.args.multihead_for_reward:
            if self.args.rew_pred_type == 'bernoulli' or self.args.rew_pred_type == 'categorical':
                # One reward head per discrete state: predict all heads, then
                # gather the head that matches each next state.
                p_rew = self.reward_decoder(dec_embedding, None)
                env = gym.make(self.args.env_name)
                indices = env.task_to_id(dec_next_obs).to(ptu.device)
                if indices.dim() < p_rew.dim():
                    indices = indices.unsqueeze(-1)
                rew_pred = p_rew.gather(dim=-1, index=indices)
                rew_target = (dec_rewards == 1).float()
                loss_rew = F.binary_cross_entropy(rew_pred, rew_target, reduction='none').mean(dim=-1)
            elif self.args.rew_pred_type == 'deterministic':
                # Deterministic multi-head reward prediction is not supported.
                raise NotImplementedError
            else:
                raise NotImplementedError
        else:
            if self.args.rew_pred_type == 'bernoulli':
                rew_pred = self.reward_decoder(dec_embedding, dec_next_obs, dec_prev_obs, dec_actions)
                loss_rew = F.binary_cross_entropy(rew_pred, (dec_rewards == 1).float(), reduction='none').mean(dim=1)
            elif self.args.rew_pred_type == 'deterministic':
                rew_pred = self.reward_decoder(dec_embedding, dec_next_obs, dec_prev_obs, dec_actions)
                loss_rew = (rew_pred - dec_rewards).pow(2).mean(dim=1)
            elif self.args.rew_pred_type == 'gaussian':
                # NOTE(review): the decoder output is averaged over dim 1
                # *before* being split into mean/std halves -- looks
                # suspicious; confirm against the decoder's output layout.
                rew_pred = self.reward_decoder(dec_embedding, dec_next_obs, dec_prev_obs, dec_actions).mean(dim=1)
                rew_pred_mean = rew_pred[:, :rew_pred.shape[1] // 2]
                rew_pred_std = torch.exp(0.5 * rew_pred[:, rew_pred.shape[1] // 2:])
                m = torch.distributions.normal.Normal(rew_pred_mean, rew_pred_std)
                loss_rew = -m.log_prob(dec_rewards)
            else:
                raise NotImplementedError

        if return_predictions:
            return loss_rew, rew_pred
        else:
            return loss_rew

    def compute_kl_loss(self, latent_mean, latent_logvar, len_encoder):
        """KL term of the ELBO, one value per timestep.

        With ``kl_to_gauss_prior`` the KL is taken against N(0, I); otherwise
        each posterior is compared to the previous timestep's posterior, with
        N(0, I) prepended as the initial prior. If ``len_encoder`` is given,
        only that index is returned.
        """
        # -- KL divergence
        if self.args.kl_to_gauss_prior:
            kl_divergences = (- 0.5 * (1 + latent_logvar - latent_mean.pow(2) - latent_logvar.exp()).sum(dim=1))
        else:
            gauss_dim = latent_mean.shape[-1]
            # add the standard-normal prior as the first "posterior"
            all_means = torch.cat((torch.zeros(1, latent_mean.shape[1]).to(ptu.device), latent_mean))
            all_logvars = torch.cat((torch.zeros(1, latent_logvar.shape[1]).to(ptu.device), latent_logvar))
            # https://arxiv.org/pdf/1811.09975.pdf
            # KL(N(mu,E)||N(m,S)) = 0.5 * (log(|S|/|E|) - K + tr(S^-1 E) + (m-mu)^T S^-1 (m-mu)))
            mu = all_means[1:]
            m = all_means[:-1]
            logE = all_logvars[1:]
            logS = all_logvars[:-1]
            kl_divergences = 0.5 * (torch.sum(logS, dim=1) - torch.sum(logE, dim=1) - gauss_dim + torch.sum(
                1 / torch.exp(logS) * torch.exp(logE), dim=1) + ((m - mu) / torch.exp(logS) * (m - mu)).sum(dim=1))

        if self.args.learn_prior:
            # Zero out the first KL term so the learned prior is not penalized.
            # Fix: build the mask on the KL tensor's device (was CPU-only,
            # which crashed the multiplication when training on CUDA).
            mask = torch.ones(len(kl_divergences), device=kl_divergences.device)
            mask[0] = 0
            kl_divergences = kl_divergences * mask

        # returns, for each ELBO_t term, one KL (so H+1 kl's)
        if len_encoder is not None:
            return kl_divergences[len_encoder]
        else:
            return kl_divergences

    def compute_belief_reward(self, task_means, task_logvars, obs, next_obs, actions):
        """
        compute reward in the BAMDP by averaging over sampled latent embeddings - R+ = E[R(b)]
        """
        # sample multiple latent embeddings from posterior - (n_samples, n_processes, latent_dim)
        task_samples = self.encoder._sample_gaussian(task_means, task_logvars, self.args.num_belief_samples)
        # Tile observations/actions so each latent sample sees the same batch.
        if next_obs.dim() > 2:
            next_obs = next_obs.repeat(self.args.num_belief_samples, 1, 1)
            obs = obs.repeat(self.args.num_belief_samples, 1, 1) if obs is not None else None
            actions = actions.repeat(self.args.num_belief_samples, 1, 1) if actions is not None else None
        else:
            next_obs = next_obs.repeat(self.args.num_belief_samples, 1)
            obs = obs.repeat(self.args.num_belief_samples, 1) if obs is not None else None
            actions = actions.repeat(self.args.num_belief_samples, 1) if actions is not None else None
        # make some predictions and average over the samples dimension
        if self.args.multihead_for_reward:
            if self.args.rew_pred_type == 'bernoulli':  # or self.args.rew_pred_type == 'categorical':
                p_rew = self.reward_decoder(task_samples, None).detach()
                # average over samples dimension to get R+
                p_rew = p_rew.mean(dim=0)
                env = gym.make(self.args.env_name)
                indices = env.task_to_id(next_obs).to(ptu.device)
                if indices.dim() < p_rew.dim():
                    indices = indices.unsqueeze(-1)
                rew_pred = p_rew.gather(dim=-1, index=indices)
            else:
                raise NotImplementedError
        else:
            if self.args.rew_pred_type == 'deterministic':
                rew_pred = self.reward_decoder(task_samples,
                                               next_obs,
                                               obs,
                                               actions)
                rew_pred = rew_pred.mean(dim=0)
            else:
                raise NotImplementedError
        return rew_pred

    def load_model(self, device='cpu', **kwargs):
        """Load saved state dicts; each component is optional and only loaded
        when its ``*_path`` kwarg is given (and the module exists)."""
        if "encoder_path" in kwargs:
            self.encoder.load_state_dict(torch.load(kwargs["encoder_path"], map_location=device))
        if "reward_decoder_path" in kwargs and self.reward_decoder is not None:
            self.reward_decoder.load_state_dict(torch.load(kwargs["reward_decoder_path"], map_location=device))
        if "state_decoder_path" in kwargs and self.state_decoder is not None:
            self.state_decoder.load_state_dict(torch.load(kwargs["state_decoder_path"], map_location=device))
        if "task_decoder_path" in kwargs and self.task_decoder is not None:
            self.task_decoder.load_state_dict(torch.load(kwargs["task_decoder_path"], map_location=device))
Example #6
0
    def __init__(self, args, device, vocab, checkpoint=None):
        """Abstractive summarization model: encoder + decoder + generator.

        Encoder is BERT, an RNN, or a transformer (``args.encoder``); decoder
        is a transformer or an RNN (``args.decoder``); the generator maps
        decoder states to vocab logits (with optional copy attention). If
        ``checkpoint`` is given, weights are loaded from it; otherwise
        parameters are freshly initialized.
        """
        super(Model, self).__init__()
        self.args = args
        self.device = device
        self.vocab = vocab
        self.vocab_size = len(vocab)
        self.beam_size = args.beam_size
        self.max_length = args.max_length
        self.min_length = args.min_length

        # special tokens
        self.start_token = vocab['[unused1]']
        self.end_token = vocab['[unused2]']
        self.pad_token = vocab['[PAD]']
        self.mask_token = vocab['[MASK]']
        self.seg_token = vocab['[SEP]']
        self.cls_token = vocab['[CLS]']
        self.agent_token = vocab['[unused3]']
        self.customer_token = vocab['[unused4]']

        if args.encoder == 'bert':
            self.encoder = Bert(args.bert_dir, args.finetune_bert)
            if (args.max_pos > 512):
                # Extend BERT's 512-position embedding table: copy the
                # pretrained rows, then fill the extra positions with copies
                # of the last pretrained row.
                my_pos_embeddings = nn.Embedding(
                    args.max_pos, self.encoder.model.config.hidden_size)
                my_pos_embeddings.weight.data[:
                                              512] = self.encoder.model.embeddings.position_embeddings.weight.data
                my_pos_embeddings.weight.data[
                    512:] = self.encoder.model.embeddings.position_embeddings.weight.data[
                        -1][None, :].repeat(args.max_pos - 512, 1)
                self.encoder.model.embeddings.position_embeddings = my_pos_embeddings
            self.hidden_size = self.encoder.model.config.hidden_size
            tgt_embeddings = nn.Embedding(
                self.vocab_size,
                self.encoder.model.config.hidden_size,
                padding_idx=0)
        else:
            # Non-BERT path: build our own source/target embedding tables.
            self.hidden_size = args.enc_hidden_size
            self.embeddings = nn.Embedding(self.vocab_size,
                                           self.hidden_size,
                                           padding_idx=0)
            tgt_embeddings = nn.Embedding(self.vocab_size,
                                          self.hidden_size,
                                          padding_idx=0)
            if args.encoder == 'rnn':
                self.encoder = RNNEncoder('LSTM',
                                          bidirectional=True,
                                          num_layers=args.enc_layers,
                                          hidden_size=self.hidden_size,
                                          dropout=args.enc_dropout,
                                          embeddings=self.embeddings)
            elif args.encoder == "transformer":
                self.encoder = TransformerEncoder(self.hidden_size,
                                                  args.enc_ff_size,
                                                  args.enc_heads,
                                                  args.enc_dropout,
                                                  args.enc_layers)

        if args.decoder == "transformer":
            self.decoder = TransformerDecoder(args.dec_layers,
                                              args.dec_hidden_size,
                                              heads=args.dec_heads,
                                              d_ff=args.dec_ff_size,
                                              dropout=args.dec_dropout,
                                              embeddings=tgt_embeddings)
        elif args.decoder == "rnn":
            self.decoder = RNNDecoder("LSTM",
                                      True,
                                      args.dec_layers,
                                      args.dec_hidden_size,
                                      dropout=args.dec_dropout,
                                      embeddings=tgt_embeddings,
                                      coverage_attn=args.coverage,
                                      copy_attn=args.copy_attn)

        if args.copy_attn:
            self.generator = CopyGenerator(self.vocab_size,
                                           args.dec_hidden_size,
                                           self.pad_token)
        else:
            self.generator = Generator(self.vocab_size, args.dec_hidden_size,
                                       self.pad_token)

        # Tie the output projection to the target embedding table.
        self.generator.linear.weight = self.decoder.embeddings.weight

        if checkpoint is not None:
            # Restore trained weights, then (optionally) re-tie shared
            # embeddings, since load_state_dict breaks parameter sharing.
            self.load_state_dict(checkpoint['model'], strict=True)
            if args.share_emb:
                if args.encoder == 'bert':
                    self.embeddings = self.encoder.model.embeddings.word_embeddings
                self.generator.linear.weight = self.decoder.embeddings.weight
        else:
            # initialize params.
            if args.encoder == "transformer":
                for module in self.encoder.modules():
                    self._set_parameter_tf(module)
            elif args.encoder == "rnn":
                for p in self.encoder.parameters():
                    self._set_parameter_linear(p)
            for module in self.decoder.modules():
                self._set_parameter_tf(module)
            for p in self.generator.parameters():
                self._set_parameter_linear(p)
            if args.share_emb:
                if args.encoder == 'bert':
                    # Seed the target embeddings from BERT's word embeddings
                    # (deep copy), then tie the generator to them.
                    tgt_embeddings = nn.Embedding(
                        self.vocab_size,
                        self.encoder.model.config.hidden_size,
                        padding_idx=0)
                    tgt_embeddings.weight = copy.deepcopy(
                        self.encoder.model.embeddings.word_embeddings.weight)
                    self.decoder.embeddings = tgt_embeddings
                self.generator.linear.weight = self.decoder.embeddings.weight

        self.to(device)
Example #7
0
    def __init__(self,
                 args,
                 device,
                 load_pretrained_bert=False,
                 bert_config=None,
                 topic_num=10):
        """Topic-aware extractive summarizer.

        Combines a BERT sentence encoder with memory modules and a topic
        predictor; topic embeddings are derived from BERT's word embeddings
        and trained jointly. The sentence-scoring head is selected via
        ``args.encoder``.
        """
        super(Summarizer, self).__init__()
        # Tokenizer with word-piece splitting disabled and the model's
        # special tokens protected from splitting.
        self.tokenizer = BertTokenizer.from_pretrained(
            'bert-base-uncased',
            do_lower_case=True,
            never_split=('[SEP]', '[CLS]', '[PAD]', '[unused0]', '[unused1]',
                         '[unused2]', '[UNK]'),
            no_word_piece=True)
        self.args = args
        self.device = device
        self.bert = Bert(args.temp_dir, load_pretrained_bert, bert_config)
        self.memory = Memory(device, 1, self.bert.model.config.hidden_size)
        self.key_memory = Key_memory(device, 1,
                                     self.bert.model.config.hidden_size,
                                     args.dropout)
        self.topic_predictor = Topic_predictor(
            self.bert.model.config.hidden_size,
            device,
            topic_num,
            d_ex_type=args.d_ex_type)
        # Topic embeddings/words are derived from BERT's embedding table
        # rather than a fresh nn.Embedding.
        # TODO: transform to normal weight, not embedding.
        self.topic_embedding, self.topic_word, self.topic_word_emb = self.get_embedding(
            self.bert.model.embeddings)
        # Train the derived embeddings jointly with the rest of the model.
        self.topic_embedding.requires_grad = True
        self.topic_word_emb.requires_grad = True
        self.topic_embedding = self.topic_embedding.to(device)
        self.topic_word_emb = self.topic_word_emb.to(device)
        if (args.encoder == 'classifier'):
            self.encoder = Classifier(self.bert.model.config.hidden_size)
        elif (args.encoder == 'transformer'):
            self.encoder = TransformerInterEncoder(
                self.bert.model.config.hidden_size, args.ff_size, args.heads,
                args.dropout, args.inter_layers)
        elif (args.encoder == 'rnn'):
            self.encoder = RNNEncoder(
                bidirectional=True,
                num_layers=1,
                input_size=self.bert.model.config.hidden_size,
                hidden_size=args.rnn_size,
                dropout=args.dropout)
        elif (args.encoder == 'baseline'):
            # 'baseline': replace pretrained BERT with a small randomly
            # initialized model of the same family.
            bert_config = BertConfig(self.bert.model.config.vocab_size,
                                     hidden_size=args.hidden_size,
                                     num_hidden_layers=6,
                                     num_attention_heads=8,
                                     intermediate_size=args.ff_size)
            self.bert.model = BertModel(bert_config)
            self.encoder = Classifier(self.bert.model.config.hidden_size)

        # Optional re-initialization of the scoring head, per CLI flags.
        if args.param_init != 0.0:
            for p in self.encoder.parameters():
                p.data.uniform_(-args.param_init, args.param_init)
        if args.param_init_glorot:
            for p in self.encoder.parameters():
                if p.dim() > 1:
                    xavier_uniform_(p)

        self.to(device)