def initialize_encoder(self):
    """Construct the recurrent trajectory encoder and store it on ``self.encoder``.

    All sizes come from the parsed experiment arguments; the finished module
    is moved onto the global compute device (``ptu.device``).
    """
    args = self.args
    encoder = RNNEncoder(
        # depth of the fully-connected stacks around the GRU core
        layers_before_gru=args.layers_before_aggregator,
        layers_after_gru=args.layers_after_aggregator,
        hidden_size=args.aggregator_hidden_size,
        # width of the latent task embedding produced by the encoder
        task_embedding_size=args.task_embedding_size,
        # raw input sizes and the width each modality is embedded to
        state_size=args.obs_dim,
        state_embed_size=args.state_embedding_size,
        action_size=args.action_dim,
        action_embed_size=args.action_embedding_size,
        reward_size=1,
        reward_embed_size=args.reward_embedding_size,
    )
    self.encoder = encoder.to(ptu.device)
def __init__(self, args, device, load_pretrained_bert = False, bert_config = None):
    """Build the summarizer: a BERT sentence encoder plus a scoring head.

    Args:
        args: parsed command-line arguments; ``args.encoder`` selects the
            head ('classifier' | 'transformer' | 'rnn' | 'baseline').
        device: torch device the whole module is moved to at the end.
        load_pretrained_bert: whether ``Bert`` loads pretrained weights.
        bert_config: optional BertConfig forwarded to ``Bert``.

    Raises:
        ValueError: if ``args.encoder`` is not one of the supported values.
    """
    super(Summarizer, self).__init__()
    self.args = args
    self.device = device
    self.bert = Bert(args.temp_dir, load_pretrained_bert, bert_config)
    if (args.encoder == 'classifier'):
        self.encoder = Classifier(self.bert.model.config.hidden_size)
    elif (args.encoder == 'transformer'):
        self.encoder = TransformerInterEncoder(self.bert.model.config.hidden_size, args.ff_size, args.heads,
                                               args.dropout, args.inter_layers)
    elif (args.encoder == 'rnn'):
        self.encoder = RNNEncoder(bidirectional=True, num_layers=1,
                                  input_size=self.bert.model.config.hidden_size,
                                  hidden_size=args.rnn_size,
                                  dropout=args.dropout)
    elif (args.encoder == 'baseline'):
        # 'baseline': replace the pretrained BERT with a small randomly
        # initialised BertModel and score sentences with a linear classifier.
        bert_config = BertConfig(self.bert.model.config.vocab_size, hidden_size=args.hidden_size,
                                 num_hidden_layers=6, num_attention_heads=8,
                                 intermediate_size=args.ff_size)
        self.bert.model = BertModel(bert_config)
        self.encoder = Classifier(self.bert.model.config.hidden_size)
    else:
        # FIX: an unrecognised encoder previously fell through silently,
        # leaving self.encoder unset and crashing later with an opaque
        # AttributeError in the parameter-init loops below. Fail fast instead.
        raise ValueError("unknown encoder type: {}".format(args.encoder))
    if args.param_init != 0.0:
        # uniform init in [-param_init, param_init] for the head only
        for p in self.encoder.parameters():
            p.data.uniform_(-args.param_init, args.param_init)
    if args.param_init_glorot:
        # xavier init for every matrix-shaped parameter of the head
        for p in self.encoder.parameters():
            if p.dim() > 1:
                xavier_uniform_(p)
    self.to(device)
def __init__(self, args, device, load_pretrained_bert=False, bert_config=None):
    """Build the summarizer with optional freezing of the first BERT layers.

    Args:
        args: parsed command-line arguments; ``args.encoder`` selects the
            scoring head, ``args.freeze_initial`` freezes that many of the
            lowest BERT encoder layers.
        device: torch device the whole module is moved to at the end.
        load_pretrained_bert: whether ``Bert`` loads pretrained weights.
        bert_config: optional BertConfig forwarded to ``Bert``.

    Raises:
        ValueError: if ``args.encoder`` is not one of the supported values.
    """
    super(Summarizer, self).__init__()
    self.args = args
    self.device = device
    self.bert = Bert(args.temp_dir, load_pretrained_bert, bert_config)
    if (args.freeze_initial > 0):
        # Freeze the lowest `freeze_initial` transformer layers so only the
        # upper layers are fine-tuned.
        for param in self.bert.model.encoder.layer[0:args.freeze_initial].parameters():
            param.requires_grad = False
        print("*" * 80)
        print("*" * 80)
        print("Initial Layers of BERT is frozen, ie first ", args.freeze_initial, "Layers")
        print(self.bert.model.encoder.layer[0:args.freeze_initial])
        print("*" * 80)
        print("*" * 80)
    if (args.encoder == 'classifier'):
        self.encoder = Classifier(self.bert.model.config.hidden_size)
    elif (args.encoder == 'multi_layer_classifier'):
        # NOTE: hidden width 32 of the extra layer is hard-coded here.
        self.encoder = MultiLayerClassifier(self.bert.model.config.hidden_size, 32)
    elif (args.encoder == 'transformer'):
        self.encoder = TransformerInterEncoder(self.bert.model.config.hidden_size,
                                               args.ff_size, args.heads,
                                               args.dropout, args.inter_layers)
    elif (args.encoder == 'rnn'):
        self.encoder = RNNEncoder(bidirectional=True, num_layers=1,
                                  input_size=self.bert.model.config.hidden_size,
                                  hidden_size=args.rnn_size,
                                  dropout=args.dropout)
    elif (args.encoder == 'baseline'):
        # 'baseline': swap in a small randomly initialised BertModel.
        bert_config = BertConfig(self.bert.model.config.vocab_size,
                                 hidden_size=args.hidden_size,
                                 num_hidden_layers=6, num_attention_heads=8,
                                 intermediate_size=args.ff_size)
        self.bert.model = BertModel(bert_config)
        self.encoder = Classifier(self.bert.model.config.hidden_size)
    else:
        # FIX: an unrecognised encoder previously fell through silently,
        # leaving self.encoder unset and crashing later with an opaque
        # AttributeError in the parameter-init loops below. Fail fast instead.
        raise ValueError("unknown encoder type: {}".format(args.encoder))
    if args.param_init != 0.0:
        # uniform init in [-param_init, param_init] for the head only
        for p in self.encoder.parameters():
            p.data.uniform_(-args.param_init, args.param_init)
    if args.param_init_glorot:
        # xavier init for every matrix-shaped parameter of the head
        for p in self.encoder.parameters():
            if p.dim() > 1:
                xavier_uniform_(p)
    self.to(device)
def initialise_encoder(self):
    """Build and return an RNN encoder configured from the experiment args.

    NOTE(review): the result is moved onto the free name ``device`` —
    presumably a module-level torch device in this file; confirm it is
    defined at import time.
    """
    args = self.args
    config = dict(
        layers_before_gru=args.encoder_layers_before_gru,
        hidden_size=args.encoder_gru_hidden_size,
        layers_after_gru=args.encoder_layers_after_gru,
        latent_dim=args.latent_dim,
        action_dim=args.action_dim,
        action_embed_dim=args.action_embedding_size,
        state_dim=args.state_dim,
        state_embed_dim=args.state_embedding_size,
        reward_size=1,
        reward_embed_size=args.reward_embedding_size,
    )
    return RNNEncoder(**config).to(device)
class VAE:
    """Variational auto-encoder over trajectories for meta-RL.

    Holds a recurrent encoder (``self.encoder``) producing a latent task
    embedding, plus optional reward / state-transition / task decoders, and
    one Adam optimizer over all of their parameters.
    """

    def __init__(self, args):
        # `args` is the parsed experiment configuration; every sub-module
        # reads its hyper-parameters from it.
        self.args = args
        self.initialize_encoder()
        self.initialize_decoder()
        self.initialize_optimizer()

    def initialize_encoder(self):
        # initialize RNN encoder -- sets self.encoder, on ptu.device
        self.encoder = RNNEncoder(
            layers_before_gru=self.args.layers_before_aggregator,
            hidden_size=self.args.aggregator_hidden_size,
            layers_after_gru=self.args.layers_after_aggregator,
            task_embedding_size=self.args.task_embedding_size,
            action_size=self.args.action_dim,
            action_embed_size=self.args.action_embedding_size,
            state_size=self.args.obs_dim,
            state_embed_size=self.args.state_embedding_size,
            reward_size=1,
            reward_embed_size=self.args.reward_embedding_size,
        ).to(ptu.device)

    def initialize_decoder(self):
        """Create reward/state/task decoders as configured; unused ones are None."""
        task_embedding_size = self.args.task_embedding_size
        if self.args.disable_stochasticity_in_latent:
            # when the latent is deterministic, mean and logvar are concatenated,
            # so the decoders receive twice the embedding width
            task_embedding_size *= 2
        # initialize model decoders -- self.reward_decoder, self.state_decoder, self.task_decoder
        if self.args.decode_reward:
            # initialise reward decoder for VAE
            self.reward_decoder = RewardDecoder(
                layers=self.args.reward_decoder_layers,
                task_embedding_size=task_embedding_size,
                state_size=self.args.obs_dim,
                state_embed_size=self.args.state_embedding_size,
                action_size=self.args.action_dim,
                action_embed_size=self.args.action_embedding_size,
                num_states=self.args.num_states,
                multi_head=self.args.multihead_for_reward,
                pred_type=self.args.rew_pred_type,
                input_prev_state=self.args.input_prev_state,
                input_next_state=self.args.input_next_state,
                input_action=self.args.input_action,
            ).to(ptu.device)
            # set reward function (alternative loss selection, kept for reference)
            # if self.args.rew_loss_fn == 'BCE':
            #     self.rew_loss_fn = lambda in_, target: F.binary_cross_entropy(in_, target, reduction='none')
            # elif self.args.rew_loss_fn == 'FL':
            #     self.rew_loss_fn = ptu.FocalLoss()
            # else:
            #     raise NotImplementedError
        else:
            self.reward_decoder = None
        if self.args.decode_state:
            # initialise state decoder for VAE
            self.state_decoder = StateTransitionDecoder(
                task_embedding_size=task_embedding_size,
                layers=self.args.state_decoder_layers,
                action_size=self.args.action_dim,
                action_embed_size=self.args.action_embedding_size,
                state_size=self.args.obs_dim,
                state_embed_size=self.args.state_embedding_size,
                pred_type=self.args.state_pred_type,
            ).to(ptu.device)
        else:
            self.state_decoder = None
        if self.args.decode_task:
            # the task decoder needs to know the target dimensionality,
            # which comes from the environment itself
            env = gym.make(self.args.env_name)
            if self.args.task_pred_type == 'task_description':
                task_dim = env.task_dim
            elif self.args.task_pred_type == 'task_id':
                task_dim = env.num_tasks
            else:
                raise NotImplementedError
            self.task_decoder = TaskDecoder(
                task_embedding_size=task_embedding_size,
                layers=self.args.task_decoder_layers,
                task_dim=task_dim,
                pred_type=self.args.task_pred_type,
            ).to(ptu.device)
        else:
            self.task_decoder = None

    def initialize_optimizer(self):
        """One Adam optimizer over the encoder and every active decoder."""
        decoder_params = []
        if not self.args.disable_decoder:
            # initialise optimiser for decoder: collect params of active decoders only
            if self.args.decode_reward:
                decoder_params.extend(self.reward_decoder.parameters())
            if self.args.decode_state:
                decoder_params.extend(self.state_decoder.parameters())
            if self.args.decode_task:
                decoder_params.extend(self.task_decoder.parameters())
        # initialize optimizer (encoder params always included)
        self.optimizer = torch.optim.Adam([*self.encoder.parameters(), *decoder_params], lr=self.args.vae_lr)

    def compute_task_reconstruction_loss(self, dec_embedding, dec_task, return_predictions=False):
        """Per-element task reconstruction loss (no reduction across the batch).

        NOTE(review): if ``task_pred_type`` is neither 'task_id' nor
        'task_description', ``loss_task`` is never assigned and the return
        raises UnboundLocalError — confirm those are the only valid values.
        """
        # make some predictions and compute individual losses
        task_pred = self.task_decoder(dec_embedding)
        if self.args.task_pred_type == 'task_id':
            env = gym.make(self.args.env_name)
            dec_task = env.task_to_id(dec_task)
            dec_task = dec_task.expand(task_pred.shape[:-1]).view(-1)
            # loss for the data we fed into encoder; flatten for cross-entropy,
            # then restore the leading dims
            task_pred_shape = task_pred.shape
            loss_task = F.cross_entropy(task_pred.view(-1, task_pred.shape[-1]), dec_task,
                                        reduction='none').reshape(task_pred_shape[:-1])
        elif self.args.task_pred_type == 'task_description':
            # mean squared error against the task description vector
            loss_task = (task_pred - dec_task).pow(2).mean(dim=1)
        if return_predictions:
            return loss_task, task_pred
        else:
            return loss_task

    def compute_state_reconstruction_loss(self, dec_embedding, dec_prev_obs, dec_next_obs, dec_actions,
                                          return_predictions=False):
        """Per-element next-state reconstruction loss (no reduction here)."""
        # make some predictions and compute individual losses
        if self.args.state_pred_type == 'deterministic':
            obs_reconstruction = self.state_decoder(dec_embedding, dec_prev_obs, dec_actions)
            # NOTE(review): only the first 2 dims of the next observation are
            # compared — looks environment-specific (e.g. 2-D position); confirm.
            loss_state = (obs_reconstruction - dec_next_obs[:, :, :2]).pow(2).mean(dim=1)
        elif self.args.state_pred_type == 'gaussian':
            # decoder outputs [mean, logvar] concatenated along dim 1
            state_pred = self.state_decoder(dec_embedding, dec_prev_obs, dec_actions)
            state_pred_mean = state_pred[:, :state_pred.shape[1] // 2]
            state_pred_std = torch.exp(0.5 * state_pred[:, state_pred.shape[1] // 2:])
            m = torch.distributions.normal.Normal(state_pred_mean, state_pred_std)
            # TODO: check if this is correctly averaged
            loss_state = -m.log_prob(dec_next_obs).mean(dim=1)
        if return_predictions:
            # NOTE(review): in the 'gaussian' branch obs_reconstruction is never
            # assigned, so return_predictions=True raises NameError there.
            return loss_state, obs_reconstruction
        else:
            return loss_state

    def compute_rew_reconstruction_loss(self, dec_embedding, dec_prev_obs, dec_next_obs, dec_actions,
                                        dec_rewards, return_predictions=False):
        """ Computed the reward reconstruction loss
        (no reduction of loss is done here; sum/avg has to be done outside) """
        # make some predictions and compute individual losses
        if self.args.multihead_for_reward:
            if self.args.rew_pred_type == 'bernoulli' or self.args.rew_pred_type == 'categorical':
                # loss for the data we fed into encoder: one reward head per
                # discrete state; gather the head for the visited state
                p_rew = self.reward_decoder(dec_embedding, None)
                env = gym.make(self.args.env_name)
                indices = env.task_to_id(dec_next_obs).to(ptu.device)
                if indices.dim() < p_rew.dim():
                    indices = indices.unsqueeze(-1)
                rew_pred = p_rew.gather(dim=-1, index=indices)
                rew_target = (dec_rewards == 1).float()
                loss_rew = F.binary_cross_entropy(rew_pred, rew_target, reduction='none').mean(dim=-1)
                # loss_rew = self.rew_loss_fn(rew_pred, rew_target).mean(dim=-1)
            elif self.args.rew_pred_type == 'deterministic':
                raise NotImplementedError
                # p_rew = self.reward_decoder(dec_embedding, None)
                # env = gym.make(self.args.env_name)
                # indices = env.task_to_id(dec_next_obs)
                # loss_rew = F.mse_loss(p_rew.gather(1, indices.reshape(-1, 1)), dec_rewards, reduction='none').mean(
                #     dim=1)
            else:
                raise NotImplementedError
        else:
            # single-head decoder: condition on (embedding, s', s, a)
            if self.args.rew_pred_type == 'bernoulli':
                rew_pred = self.reward_decoder(dec_embedding, dec_next_obs, dec_prev_obs, dec_actions)
                loss_rew = F.binary_cross_entropy(rew_pred, (dec_rewards == 1).float(), reduction='none').mean(dim=1)
            elif self.args.rew_pred_type == 'deterministic':
                rew_pred = self.reward_decoder(dec_embedding, dec_next_obs, dec_prev_obs, dec_actions)
                loss_rew = (rew_pred - dec_rewards).pow(2).mean(dim=1)
            elif self.args.rew_pred_type == 'gaussian':
                # NOTE(review): .mean(dim=1) is applied before slicing into
                # mean/std halves — this looks misplaced compared to the
                # 'gaussian' state branch; confirm intended.
                rew_pred = self.reward_decoder(dec_embedding, dec_next_obs, dec_prev_obs, dec_actions).mean(dim=1)
                rew_pred_mean = rew_pred[:, :rew_pred.shape[1] // 2]
                rew_pred_std = torch.exp(0.5 * rew_pred[:, rew_pred.shape[1] // 2:])
                m = torch.distributions.normal.Normal(rew_pred_mean, rew_pred_std)
                loss_rew = -m.log_prob(dec_rewards)
            else:
                raise NotImplementedError
        if return_predictions:
            return loss_rew, rew_pred
        else:
            return loss_rew

    def compute_kl_loss(self, latent_mean, latent_logvar, len_encoder):
        """KL term of the ELBO, one value per encoder timestep."""
        # -- KL divergence
        if self.args.kl_to_gauss_prior:
            # KL( q(z) || N(0, I) ) in closed form
            kl_divergences = (- 0.5 * (1 + latent_logvar - latent_mean.pow(2) - latent_logvar.exp()).sum(dim=1))
        else:
            gauss_dim = latent_mean.shape[-1]
            # add the gaussian prior: prepend N(0, I) so each posterior is
            # compared against the previous timestep's posterior
            all_means = torch.cat((torch.zeros(1, latent_mean.shape[1]).to(ptu.device), latent_mean))
            all_logvars = torch.cat((torch.zeros(1, latent_logvar.shape[1]).to(ptu.device), latent_logvar))
            # https://arxiv.org/pdf/1811.09975.pdf
            # KL(N(mu,E)||N(m,S)) = 0.5 * (log(|S|/|E|) - K + tr(S^-1 E) + (m-mu)^T S^-1 (m-mu)))
            mu = all_means[1:]
            m = all_means[:-1]
            logE = all_logvars[1:]
            logS = all_logvars[:-1]
            kl_divergences = 0.5 * (torch.sum(logS, dim=1) - torch.sum(logE, dim=1) - gauss_dim + torch.sum(
                1 / torch.exp(logS) * torch.exp(logE), dim=1) + ((m - mu) / torch.exp(logS) * (m - mu)).sum(dim=1))
        if self.args.learn_prior:
            # zero out the first KL term when the prior itself is learned
            # NOTE(review): mask is allocated on CPU, not ptu.device — possible
            # device mismatch if kl_divergences lives on GPU; confirm.
            mask = torch.ones(len(kl_divergences))
            mask[0] = 0
            kl_divergences = kl_divergences * mask
        # returns, for each ELBO_t term, one KL (so H+1 kl's)
        if len_encoder is not None:
            return kl_divergences[len_encoder]
        else:
            return kl_divergences

    def compute_belief_reward(self, task_means, task_logvars, obs, next_obs, actions):
        """ compute reward in the BAMDP by averaging over sampled latent embeddings - R+ = E[R(b)] """
        # sample multiple latent embeddings from posterior - (n_samples, n_processes, latent_dim)
        task_samples = self.encoder._sample_gaussian(task_means, task_logvars, self.args.num_belief_samples)
        # tile inputs so every latent sample gets a copy; obs/actions may be
        # None for decoders that do not condition on them
        if next_obs.dim() > 2:
            next_obs = next_obs.repeat(self.args.num_belief_samples, 1, 1)
            obs = obs.repeat(self.args.num_belief_samples, 1, 1) if obs is not None else None
            actions = actions.repeat(self.args.num_belief_samples, 1, 1) if actions is not None else None
        else:
            next_obs = next_obs.repeat(self.args.num_belief_samples, 1)
            obs = obs.repeat(self.args.num_belief_samples, 1) if obs is not None else None
            actions = actions.repeat(self.args.num_belief_samples, 1) if actions is not None else None
        # make some predictions and average
        if self.args.multihead_for_reward:
            if self.args.rew_pred_type == 'bernoulli':  # or self.args.rew_pred_type == 'categorical':
                p_rew = self.reward_decoder(task_samples, None).detach()
                # average over samples dimension to get R+
                p_rew = p_rew.mean(dim=0)
                env = gym.make(self.args.env_name)
                indices = env.task_to_id(next_obs).to(ptu.device)
                if indices.dim() < p_rew.dim():
                    indices = indices.unsqueeze(-1)
                rew_pred = p_rew.gather(dim=-1, index=indices)
            else:
                raise NotImplementedError
        else:
            if self.args.rew_pred_type == 'deterministic':
                rew_pred = self.reward_decoder(task_samples, next_obs, obs, actions)
                rew_pred = rew_pred.mean(dim=0)
            else:
                raise NotImplementedError
        return rew_pred

    def load_model(self, device='cpu', **kwargs):
        """Load saved state dicts for any of the sub-modules.

        Keyword Args:
            encoder_path / reward_decoder_path / state_decoder_path /
            task_decoder_path: filesystem paths to ``torch.save``d state dicts;
            each is loaded only if given (and the matching decoder exists).
        """
        if "encoder_path" in kwargs:
            self.encoder.load_state_dict(torch.load(kwargs["encoder_path"], map_location=device))
        if "reward_decoder_path" in kwargs and self.reward_decoder is not None:
            self.reward_decoder.load_state_dict(torch.load(kwargs["reward_decoder_path"], map_location=device))
        if "state_decoder_path" in kwargs and self.state_decoder is not None:
            self.state_decoder.load_state_dict(torch.load(kwargs["state_decoder_path"], map_location=device))
        if "task_decoder_path" in kwargs and self.task_decoder is not None:
            self.task_decoder.load_state_dict(torch.load(kwargs["task_decoder_path"], map_location=device))
def __init__(self, args, device, vocab, checkpoint=None):
    """Build an encoder-decoder summarization model.

    The encoder is BERT, an LSTM, or a Transformer (``args.encoder``); the
    decoder is a Transformer or LSTM (``args.decoder``) with an optional
    copy mechanism. If ``checkpoint`` is given its weights are loaded;
    otherwise parameters are freshly initialised.
    """
    super(Model, self).__init__()
    self.args = args
    self.device = device
    self.vocab = vocab
    self.vocab_size = len(vocab)
    # beam-search / generation limits
    self.beam_size = args.beam_size
    self.max_length = args.max_length
    self.min_length = args.min_length
    # special tokens (ids looked up from the vocab)
    self.start_token = vocab['[unused1]']
    self.end_token = vocab['[unused2]']
    self.pad_token = vocab['[PAD]']
    self.mask_token = vocab['[MASK]']
    self.seg_token = vocab['[SEP]']
    self.cls_token = vocab['[CLS]']
    self.agent_token = vocab['[unused3]']
    self.customer_token = vocab['[unused4]']
    if args.encoder == 'bert':
        self.encoder = Bert(args.bert_dir, args.finetune_bert)
        if (args.max_pos > 512):
            # BERT only ships 512 position embeddings; extend the table by
            # repeating the last learned position vector for positions >= 512
            my_pos_embeddings = nn.Embedding(
                args.max_pos, self.encoder.model.config.hidden_size)
            my_pos_embeddings.weight.data[:512] = self.encoder.model.embeddings.position_embeddings.weight.data
            my_pos_embeddings.weight.data[512:] = self.encoder.model.embeddings.position_embeddings.weight.data[
                -1][None, :].repeat(args.max_pos - 512, 1)
            self.encoder.model.embeddings.position_embeddings = my_pos_embeddings
        self.hidden_size = self.encoder.model.config.hidden_size
        tgt_embeddings = nn.Embedding(
            self.vocab_size, self.encoder.model.config.hidden_size, padding_idx=0)
    else:
        # non-BERT encoders share one source embedding table
        self.hidden_size = args.enc_hidden_size
        self.embeddings = nn.Embedding(self.vocab_size, self.hidden_size, padding_idx=0)
        tgt_embeddings = nn.Embedding(self.vocab_size, self.hidden_size, padding_idx=0)
        if args.encoder == 'rnn':
            self.encoder = RNNEncoder('LSTM', bidirectional=True, num_layers=args.enc_layers,
                                      hidden_size=self.hidden_size, dropout=args.enc_dropout,
                                      embeddings=self.embeddings)
        elif args.encoder == "transformer":
            self.encoder = TransformerEncoder(self.hidden_size, args.enc_ff_size, args.enc_heads,
                                              args.enc_dropout, args.enc_layers)
    if args.decoder == "transformer":
        self.decoder = TransformerDecoder(args.dec_layers, args.dec_hidden_size, heads=args.dec_heads,
                                          d_ff=args.dec_ff_size, dropout=args.dec_dropout,
                                          embeddings=tgt_embeddings)
    elif args.decoder == "rnn":
        self.decoder = RNNDecoder("LSTM", True, args.dec_layers, args.dec_hidden_size,
                                  dropout=args.dec_dropout, embeddings=tgt_embeddings,
                                  coverage_attn=args.coverage, copy_attn=args.copy_attn)
    if args.copy_attn:
        self.generator = CopyGenerator(self.vocab_size, args.dec_hidden_size, self.pad_token)
    else:
        self.generator = Generator(self.vocab_size, args.dec_hidden_size, self.pad_token)
    # tie output projection to the target embedding table
    self.generator.linear.weight = self.decoder.embeddings.weight
    if checkpoint is not None:
        self.load_state_dict(checkpoint['model'], strict=True)
        # NOTE(review): this share_emb handling runs only on the checkpoint
        # path — confirm that is intended.
        if args.share_emb:
            if args.encoder == 'bert':
                self.embeddings = self.encoder.model.embeddings.word_embeddings
            self.generator.linear.weight = self.decoder.embeddings.weight
    else:
        # initialize params.
        if args.encoder == "transformer":
            for module in self.encoder.modules():
                self._set_parameter_tf(module)
        elif args.encoder == "rnn":
            for p in self.encoder.parameters():
                self._set_parameter_linear(p)
        for module in self.decoder.modules():
            self._set_parameter_tf(module)
        for p in self.generator.parameters():
            self._set_parameter_linear(p)
        if args.share_emb:
            if args.encoder == 'bert':
                # copy BERT's word embeddings into a fresh target table and
                # re-tie the generator to it
                tgt_embeddings = nn.Embedding(
                    self.vocab_size, self.encoder.model.config.hidden_size, padding_idx=0)
                tgt_embeddings.weight = copy.deepcopy(
                    self.encoder.model.embeddings.word_embeddings.weight)
                self.decoder.embeddings = tgt_embeddings
                self.generator.linear.weight = self.decoder.embeddings.weight
    self.to(device)
def __init__(self, args, device, load_pretrained_bert=False, bert_config=None, topic_num=10):
    """Build a topic-aware summarizer: BERT encoder, memory modules, topic
    predictor, topic embeddings derived from BERT's own embedding table,
    and a sentence-scoring head selected by ``args.encoder``.
    """
    super(Summarizer, self).__init__()
    self.tokenizer = BertTokenizer.from_pretrained(
        'bert-base-uncased', do_lower_case=True,
        never_split=('[SEP]', '[CLS]', '[PAD]', '[unused0]', '[unused1]', '[unused2]', '[UNK]'),
        no_word_piece=True)
    self.args = args
    self.device = device
    self.bert = Bert(args.temp_dir, load_pretrained_bert, bert_config)
    # external memory modules sized to BERT's hidden width
    self.memory = Memory(device, 1, self.bert.model.config.hidden_size)
    self.key_memory = Key_memory(device, 1, self.bert.model.config.hidden_size, args.dropout)
    self.topic_predictor = Topic_predictor(
        self.bert.model.config.hidden_size, device, topic_num, d_ex_type=args.d_ex_type)
    # self.topic_embedding = nn.Embedding(topic_num, self.bert.model.config.hidden_size)
    # todo transform to normal weight not embedding
    # topic embeddings are derived from BERT's embedding table and kept trainable
    self.topic_embedding, self.topic_word, self.topic_word_emb = self.get_embedding(
        self.bert.model.embeddings)
    self.topic_embedding.requires_grad = True
    self.topic_word_emb.requires_grad = True
    self.topic_embedding = self.topic_embedding.to(device)
    self.topic_word_emb = self.topic_word_emb.to(device)
    if (args.encoder == 'classifier'):
        self.encoder = Classifier(self.bert.model.config.hidden_size)
    elif (args.encoder == 'transformer'):
        self.encoder = TransformerInterEncoder(
            self.bert.model.config.hidden_size, args.ff_size, args.heads,
            args.dropout, args.inter_layers)
    elif (args.encoder == 'rnn'):
        self.encoder = RNNEncoder(
            bidirectional=True, num_layers=1,
            input_size=self.bert.model.config.hidden_size,
            hidden_size=args.rnn_size, dropout=args.dropout)
    elif (args.encoder == 'baseline'):
        # 'baseline': swap in a small randomly initialised BertModel
        bert_config = BertConfig(self.bert.model.config.vocab_size, hidden_size=args.hidden_size,
                                 num_hidden_layers=6, num_attention_heads=8,
                                 intermediate_size=args.ff_size)
        self.bert.model = BertModel(bert_config)
        self.encoder = Classifier(self.bert.model.config.hidden_size)
    # NOTE(review): an unrecognised args.encoder falls through all branches,
    # leaving self.encoder unset and crashing below — confirm upstream
    # argument validation guarantees one of the four values.
    if args.param_init != 0.0:
        # uniform init in [-param_init, param_init] for the head only
        for p in self.encoder.parameters():
            p.data.uniform_(-args.param_init, args.param_init)
    if args.param_init_glorot:
        # xavier init for every matrix-shaped parameter of the head
        for p in self.encoder.parameters():
            if p.dim() > 1:
                xavier_uniform_(p)
    self.to(device)