def __init__(self, num_prop_rounds, node_hidden_size, edge_hidden_size):
    super(GraphProp, self).__init__()

    self.num_prop_rounds = num_prop_rounds
    # Setting from the paper
    self.node_activation_hidden_size = 2 * node_hidden_size

    message_funcs = []
    self.reduce_funcs = []
    node_update_funcs = []
    for t in range(num_prop_rounds):
        # input being [hv, hu, xuv]
        message_funcs.append(
            nn.Linear(2 * node_hidden_size + edge_hidden_size,
                      self.node_activation_hidden_size))
        self.reduce_funcs.append(partial(self.dgmg_reduce, round=t))
        node_update_funcs.append(
            nn.GRUCell(self.node_activation_hidden_size, node_hidden_size))

    self.message_funcs = nn.ModuleList(message_funcs)
    self.node_update_funcs = nn.ModuleList(node_update_funcs)
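# A minimal, self-contained sketch (not from the source; dimensions and the
# message construction are illustrative assumptions) of how one GraphProp
# round is typically stepped: a message is built from [hv, hu, xuv] and the
# per-round GRUCell updates the node states.
import torch
import torch.nn as nn

node_h, edge_h, n_nodes = 16, 8, 4
msg_fn = nn.Linear(2 * node_h + edge_h, 2 * node_h)
update = nn.GRUCell(2 * node_h, node_h)

hv = torch.randn(n_nodes, node_h)    # node states
hu = torch.randn(n_nodes, node_h)    # neighbor states
xuv = torch.randn(n_nodes, edge_h)   # edge features
m = msg_fn(torch.cat([hv, hu, xuv], dim=1))  # per-edge messages
hv = update(m, hv)                           # GRU node update for round t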
def __init__(self):
    super(Net, self).__init__()
    self.time_sequence = 8
    # Channel counts aligned: conv1 must output the 5 channels that
    # conv1_bn and conv2 expect, and conv2 must output the 8 channels
    # that conv2_bn expects (the original mixed 4/5 and 3/8).
    self.conv1 = nn.Conv2d(1, 5, kernel_size=1, padding=0)
    self.conv1_bn = nn.BatchNorm2d(5)
    self.conv2 = nn.Conv2d(5, 8, kernel_size=3, padding=1)
    self.conv2_bn = nn.BatchNorm2d(8)
    self.gru1 = nn.GRUCell(1000 + 96, 512)
    self.gru2 = nn.GRUCell(512, 512)
    self.gru3 = nn.GRUCell(512, 512)
    self.gru4 = nn.GRUCell(512, 512)
    self.gru5 = nn.GRUCell(512, 512)
    self.gru6 = nn.GRUCell(512, 512)
    self.nalu2 = nalu.NALU(512, 125)
    torch.nn.init.xavier_uniform_(self.conv1.weight)
    torch.nn.init.xavier_uniform_(self.conv2.weight)
def __init__(self, D_m, D_s, D_g, D_p, D_r, D_i, D_e, listener_state=False,
             context_attention='simple', D_a=100, dropout=0.5, emo_gru=True):
    super(CommonsenseRNNCell, self).__init__()

    self.D_m = D_m
    self.D_s = D_s
    self.D_g = D_g
    self.D_p = D_p
    self.D_r = D_r
    self.D_i = D_i
    self.D_e = D_e

    # print('dmsg', D_m, D_s, D_g)
    self.g_cell = nn.GRUCell(D_m + D_p + D_r, D_g)
    self.p_cell = nn.GRUCell(D_s + D_g, D_p)
    self.r_cell = nn.GRUCell(D_m + D_s + D_g, D_r)
    self.i_cell = nn.GRUCell(D_s + D_p, D_i)
    self.e_cell = nn.GRUCell(D_m + D_p + D_r + D_i, D_e)

    self.emo_gru = emo_gru
    self.listener_state = listener_state
    if listener_state:
        self.pl_cell = nn.GRUCell(D_s + D_g, D_p)
        self.rl_cell = nn.GRUCell(D_m + D_s + D_g, D_r)

    self.dropout = nn.Dropout(dropout)
    self.dropout1 = nn.Dropout(dropout)
    self.dropout2 = nn.Dropout(dropout)
    self.dropout3 = nn.Dropout(dropout)
    self.dropout4 = nn.Dropout(dropout)
    self.dropout5 = nn.Dropout(dropout)

    if context_attention == 'simple':
        self.attention = SimpleAttention(D_g)
    else:
        self.attention = MatchingAttention(D_g, D_m, D_a, context_attention)
def __init__(self, vocab_size, embedding_size=768, gru_hidden_size=768):
    super().__init__()
    self.gru_cell = nn.GRUCell(embedding_size, gru_hidden_size)
    self.l_out = nn.Linear(gru_hidden_size, vocab_size)
    self.softmax = nn.Softmax(dim=1)
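# A minimal usage sketch (an assumption, not the source's forward pass): one
# greedy decoding step with this cell. `emb` stands in for an embedding lookup
# of the previous tokens.
import torch
import torch.nn as nn

vocab, emb_dim, hid = 100, 768, 768
gru_cell = nn.GRUCell(emb_dim, hid)
l_out = nn.Linear(hid, vocab)

emb = torch.randn(2, emb_dim)          # embedded previous tokens, batch of 2
h = torch.zeros(2, hid)                # initial hidden state
h = gru_cell(emb, h)                   # one recurrent step
probs = torch.softmax(l_out(h), dim=1)
next_token = probs.argmax(dim=1)       # greedy pick per batch element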
def __init__(self, input_size, args):
    super().__init__()
    self.args = args
    self.input_size = input_size
    output_size = args.rnn_agg_size
    self.rnn = nn.GRUCell(input_size, output_size)
def _load_rnn(self):
    self.rnn = nn.GRUCell(620, 2400)
    parameters = self._load_rnn_params()
    state_dict = self._make_rnn_state_dict(parameters)
    self.rnn.load_state_dict(state_dict)
    return self.rnn
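# A self-contained sketch (shapes illustrative; small dims used here, while
# the snippet above uses input 620 / hidden 2400) of the state_dict layout
# that _make_rnn_state_dict must produce: nn.GRUCell has exactly these four
# parameters, with the three gates stacked along dim 0 (3 * hidden_size).
import torch
import torch.nn as nn

I, H = 8, 16
rnn = nn.GRUCell(I, H)
rnn.load_state_dict({
    'weight_ih': torch.randn(3 * H, I),   # input-to-hidden, (3*H, I)
    'weight_hh': torch.randn(3 * H, H),   # hidden-to-hidden, (3*H, H)
    'bias_ih': torch.zeros(3 * H),
    'bias_hh': torch.zeros(3 * H),
})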
def __init__(self, ins=2, es=8, hs=16):
    super(EncoderRNN, self).__init__()
    self.hs = hs
    self.linear1 = nn.Linear(ins, es)
    self.lstm1 = nn.LSTMCell(es, hs)
    self.gru1 = nn.GRUCell(es, hs)
def __init__(self, input_vocabulary, target_vocabulary, hidden_size=512,
             embedding_size=128, cell_type="LSTM"):
    """
    :param list input_vocabulary: list of possible inputs
    :param list target_vocabulary: list of possible targets
    """
    super(Network, self).__init__()
    self.h_input_encoder_size = hidden_size
    self.h_output_encoder_size = hidden_size
    self.h_decoder_size = hidden_size
    self.embedding_size = embedding_size
    self.input_vocabulary = input_vocabulary
    self.target_vocabulary = target_vocabulary
    # Number of tokens in input vocabulary
    self.v_input = len(input_vocabulary)
    # Number of tokens in target vocabulary
    self.v_target = len(target_vocabulary)

    self.cell_type = cell_type
    if cell_type == 'GRU':
        self.input_encoder_cell = nn.GRUCell(
            input_size=self.v_input + 1,
            hidden_size=self.h_input_encoder_size, bias=True)
        self.input_encoder_init = Parameter(
            torch.rand(1, self.h_input_encoder_size))
        self.output_encoder_cell = nn.GRUCell(
            input_size=self.v_input + 1 + self.h_input_encoder_size,
            hidden_size=self.h_output_encoder_size, bias=True)
        self.decoder_cell = nn.GRUCell(input_size=self.v_target + 1,
                                       hidden_size=self.h_decoder_size,
                                       bias=True)
    if cell_type == 'LSTM':
        self.input_encoder_cell = nn.LSTMCell(
            input_size=self.v_input + 1,
            hidden_size=self.h_input_encoder_size, bias=True)
        self.input_encoder_init = nn.ParameterList([
            Parameter(torch.rand(1, self.h_input_encoder_size)),
            Parameter(torch.rand(1, self.h_input_encoder_size))
        ])
        self.output_encoder_cell = nn.LSTMCell(
            input_size=self.v_input + 1 + self.h_input_encoder_size,
            hidden_size=self.h_output_encoder_size, bias=True)
        self.output_encoder_init_c = Parameter(
            torch.rand(1, self.h_output_encoder_size))
        self.decoder_cell = nn.LSTMCell(input_size=self.v_target + 1,
                                        hidden_size=self.h_decoder_size,
                                        bias=True)
        self.decoder_init_c = Parameter(torch.rand(1, self.h_decoder_size))

    self.W = nn.Linear(self.h_output_encoder_size + self.h_decoder_size,
                       self.embedding_size)
    self.V = nn.Linear(self.embedding_size, self.v_target + 1)
    self.input_A = nn.Bilinear(self.h_input_encoder_size,
                               self.h_output_encoder_size, 1, bias=False)
    self.output_A = nn.Bilinear(self.h_output_encoder_size,
                                self.h_decoder_size, 1, bias=False)
    self.input_EOS = torch.zeros(1, self.v_input + 1)
    self.input_EOS[:, -1] = 1
    self.input_EOS = Parameter(self.input_EOS)
    self.output_EOS = torch.zeros(1, self.v_input + 1)
    self.output_EOS[:, -1] = 1
    self.output_EOS = Parameter(self.output_EOS)
    self.target_EOS = torch.zeros(1, self.v_target + 1)
    self.target_EOS[:, -1] = 1
    self.target_EOS = Parameter(self.target_EOS)
def __init__(self, word_dict, item_dict, context_dict, output_length, args,
             device_id):
    super(DialogModel, self).__init__(device_id)

    domain = get_domain(args.domain)

    self.word_dict = word_dict
    self.item_dict = item_dict
    self.context_dict = context_dict
    self.args = args

    # embedding for words
    self.word_encoder = nn.Embedding(len(self.word_dict), args.nembed_word)

    # context encoder
    ctx_encoder_ty = modules.RnnContextEncoder if args.rnn_ctx_encoder \
        else modules.MlpContextEncoder
    self.ctx_encoder = ctx_encoder_ty(len(self.context_dict),
                                      domain.input_length(), args.nembed_ctx,
                                      args.nhid_ctx, args.init_range,
                                      device_id)

    # a reader RNN, to encode words
    self.reader = nn.GRU(input_size=args.nhid_ctx + args.nembed_word,
                         hidden_size=args.nhid_lang, bias=True)
    self.decoder = nn.Linear(args.nhid_lang, args.nembed_word)
    # a writer, an RNNCell that will be used to generate utterances
    self.writer = nn.GRUCell(input_size=args.nhid_ctx + args.nembed_word,
                             hidden_size=args.nhid_lang, bias=True)

    # tie the weights of reader and writer
    self.writer.weight_ih = self.reader.weight_ih_l0
    self.writer.weight_hh = self.reader.weight_hh_l0
    self.writer.bias_ih = self.reader.bias_ih_l0
    self.writer.bias_hh = self.reader.bias_hh_l0

    self.dropout = nn.Dropout(args.dropout)

    # a bidirectional selection RNN: it goes over the input words and the
    # hidden states produced by the reader to build a hidden representation
    self.sel_rnn = nn.GRU(input_size=args.nhid_lang + args.nembed_word,
                          hidden_size=args.nhid_attn, bias=True,
                          bidirectional=True)

    # mask for disabling special tokens when generating sentences
    self.special_token_mask = torch.FloatTensor(len(self.word_dict))

    # attention to combine selection hidden states
    self.attn = nn.Sequential(
        torch.nn.Linear(2 * args.nhid_attn, args.nhid_attn),
        nn.Tanh(),
        torch.nn.Linear(args.nhid_attn, 1))

    # selection encoder: takes the attention output and the context hidden
    # state and combines them
    self.sel_encoder = nn.Sequential(
        torch.nn.Linear(2 * args.nhid_attn + args.nhid_ctx, args.nhid_sel),
        nn.Tanh())
    # selection decoders, one per item
    self.sel_decoders = nn.ModuleList()
    for i in range(output_length):
        self.sel_decoders.append(nn.Linear(args.nhid_sel,
                                           len(self.item_dict)))

    self.init_weights()

    # fill in the mask
    for i in range(len(self.word_dict)):
        w = self.word_dict.get_word(i)
        special = domain.item_pattern.match(w) or w in ('<unk>', 'YOU:',
                                                        'THEM:', '<pad>')
        self.special_token_mask[i] = -999 if special else 0.0
    self.special_token_mask = self.to_device(self.special_token_mask)
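# A small verification sketch (not from the source): nn.GRU stores its
# layer-0 parameters in the same (3*H, ...) layout as nn.GRUCell, which is
# what makes the reader/writer weight tying above valid.
import torch.nn as nn

reader = nn.GRU(input_size=10, hidden_size=20, bias=True)
writer = nn.GRUCell(input_size=10, hidden_size=20, bias=True)
assert writer.weight_ih.shape == reader.weight_ih_l0.shape  # (60, 10)
assert writer.weight_hh.shape == reader.weight_hh_l0.shape  # (60, 20)
writer.weight_ih = reader.weight_ih_l0    # shared storage after tying
assert writer.weight_ih is reader.weight_ih_l0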
def __init__(self, input_shape, args):
    super(LatentRNNAgent, self).__init__()
    self.args = args
    self.input_shape = input_shape
    self.n_agents = args.n_agents
    self.n_actions = args.n_actions
    self.latent_dim = args.latent_dim
    self.hidden_dim = args.rnn_hidden_dim
    self.bs = 0

    # pi_param = th.rand(args.n_agents)
    # pi_param = pi_param / pi_param.sum()
    # self.pi_param = nn.Parameter(pi_param)
    # mu_param = th.randn(args.n_agents, args.latent_dim)
    # mu_param = mu_param / mu_param.norm(dim=0)
    # self.mu_param = nn.Parameter(mu_param)

    # self.embed_fc_input_size = args.own_feature_size
    self.embed_fc_input_size = input_shape
    # self.embed_fc_input_size = args.n_agents

    self.embed_fc1 = nn.Linear(self.embed_fc_input_size, args.latent_dim * 4)
    self.embed_fc2 = nn.Linear(args.latent_dim * 4, args.latent_dim * 2)
    self.inference_fc1 = nn.Linear(args.rnn_hidden_dim + input_shape,
                                   args.latent_dim * 4)
    self.inference_fc2 = nn.Linear(args.latent_dim * 4, args.latent_dim * 2)

    # snail_T = args.snail_max_t
    # snail_input_dim = input_shape
    # if args.obs_agent_id:
    #     snail_input_dim -= args.n_agents
    # snail_key_size = args.snail_key_size
    # snail_value_size = args.snail_value_size
    # snail_filters = args.snail_filters
    # layer_count = math.ceil(math.log(snail_T) / math.log(2))
    # self.infer_mod0 = snail.AttentionBlock(snail_input_dim, snail_key_size, snail_value_size)  # input_dims, key_size, value_size
    # self.infer_mod1 = snail.TCBlock(snail_input_dim + snail_value_size, snail_T, snail_filters)  # in_channels, seq_len, filters
    # self.infer_mod2 = snail.AttentionBlock(snail_input_dim + snail_value_size + snail_filters * layer_count, snail_key_size, snail_value_size)
    # self.infer_mod3 = nn.Conv1d(snail_input_dim + 2 * snail_value_size + snail_filters * layer_count, self.latent_dim * 2, 1)  # in_channels, out_channels, kernel_size

    self.latent = th.rand(args.n_agents, args.latent_dim * 2)  # (n, mu+var)
    # self.latent0 = th.rand(args.n_agents, args.latent_dim * 2)

    self.latent_fc1 = nn.Linear(args.latent_dim, args.latent_dim * 4)
    self.latent_fc2 = nn.Linear(args.latent_dim * 4, args.latent_dim * 4)

    self.fc1 = nn.Linear(input_shape, args.rnn_hidden_dim)
    self.rnn = nn.GRUCell(args.rnn_hidden_dim, args.rnn_hidden_dim)
    # self.fc2 = nn.Linear(args.rnn_hidden_dim, args.n_actions)

    # self.fc1_w_nn = nn.Linear(args.latent_dim, input_shape * args.rnn_hidden_dim)
    # self.fc1_b_nn = nn.Linear(args.latent_dim, args.rnn_hidden_dim)
    # self.rnn_ih_w_nn = nn.Linear(args.latent_dim, args.rnn_hidden_dim * args.rnn_hidden_dim)
    # self.rnn_ih_b_nn = nn.Linear(args.latent_dim, args.rnn_hidden_dim)
    # self.rnn_hh_w_nn = nn.Linear(args.latent_dim, args.rnn_hidden_dim * args.rnn_hidden_dim)
    # self.rnn_hh_b_nn = nn.Linear(args.latent_dim, args.rnn_hidden_dim)

    self.fc2_w_nn = nn.Linear(args.latent_dim * 4,
                              args.rnn_hidden_dim * args.n_actions)
    self.fc2_b_nn = nn.Linear(args.latent_dim * 4, args.n_actions)
def __init__(
    self,
    n_vocab,
    n_layers,
    n_units,
    n_embed=None,
    typ="lstm",
    dropout_rate=0.5,
    emb_dropout_rate=0.0,
    tie_weights=False,
):
    """Initialize class.

    :param int n_vocab: The size of the vocabulary
    :param int n_layers: The number of layers to create
    :param int n_units: The number of units per layer
    :param str typ: The RNN type
    """
    super(RNNLM, self).__init__()
    if n_embed is None:
        n_embed = n_units

    self.embed = nn.Embedding(n_vocab, n_embed)

    if emb_dropout_rate == 0.0:
        self.embed_drop = None
    else:
        self.embed_drop = nn.Dropout(emb_dropout_rate)

    if typ == "lstm":
        self.rnn = nn.ModuleList(
            [nn.LSTMCell(n_embed, n_units)]
            + [nn.LSTMCell(n_units, n_units) for _ in range(n_layers - 1)]
        )
    else:
        self.rnn = nn.ModuleList(
            [nn.GRUCell(n_embed, n_units)]
            + [nn.GRUCell(n_units, n_units) for _ in range(n_layers - 1)]
        )

    self.dropout = nn.ModuleList(
        [nn.Dropout(dropout_rate) for _ in range(n_layers + 1)]
    )
    self.lo = nn.Linear(n_units, n_vocab)
    self.n_layers = n_layers
    self.n_units = n_units
    self.typ = typ

    logging.info("Tie weights set to {}".format(tie_weights))
    logging.info("Dropout set to {}".format(dropout_rate))
    logging.info("Emb Dropout set to {}".format(emb_dropout_rate))

    if tie_weights:
        assert (
            n_embed == n_units
        ), "Tie Weights: True needs the embedding and final dimensions to match"
        self.lo.weight = self.embed.weight

    # initialize parameters from uniform distribution
    for param in self.parameters():
        param.data.uniform_(-0.1, 0.1)
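# A minimal sketch (an assumption, not the source's forward method) of how
# such a ModuleList of GRU cells is stepped for one token: each layer's
# output feeds the next, with per-layer dropout applied to the inputs.
import torch
import torch.nn as nn

n_layers, n_units, n_embed = 2, 32, 32
rnn = nn.ModuleList(
    [nn.GRUCell(n_embed, n_units)]
    + [nn.GRUCell(n_units, n_units) for _ in range(n_layers - 1)]
)
dropout = nn.ModuleList([nn.Dropout(0.5) for _ in range(n_layers + 1)])

x = torch.randn(4, n_embed)                         # embedded token, batch 4
h = [torch.zeros(4, n_units) for _ in range(n_layers)]
for i, cell in enumerate(rnn):
    h[i] = cell(dropout[i](x if i == 0 else h[i - 1]), h[i])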
def __init__(self, target_vocabulary, hidden_size=512, embedding_size=128,
             cell_type="LSTM", input_size=(3, 256, 256)):
    """
    :param: input_vocabularies: List containing a vocabulary list for each
        input. E.g. if learning a function f:A->B from (a,b) pairs,
        input_vocabularies has length 2
    :param: target_vocabulary: Vocabulary list for output
    """
    super(Image_RobustFill, self).__init__()
    self.n_encoders = 1

    self.hidden_size = hidden_size
    self.embedding_size = embedding_size
    self.input_vocabularies = [None]  # input_vocabularies
    self.target_vocabulary = target_vocabulary
    self._refreshVocabularyIndex()
    # Number of tokens in input vocabularies
    self.v_inputs = None  # [len(x) for x in input_vocabularies]
    # Number of tokens in target vocabulary
    self.v_target = len(target_vocabulary)

    self.no_inputs = len(self.input_vocabularies) == 0

    self.cell_type = cell_type
    if cell_type == 'GRU':
        self.encoder_init_h = Parameter(torch.rand(1, self.hidden_size))
        # self.encoder_cells = nn.ModuleList(
        #     [nn.GRUCell(input_size=self.v_inputs[0] + 1, hidden_size=self.hidden_size, bias=True)] +
        #     [nn.GRUCell(input_size=self.v_inputs[i] + 1 + self.hidden_size, hidden_size=self.hidden_size, bias=True)
        #      for i in range(1, self.n_encoders)]
        # )
        self.decoder_cell = nn.GRUCell(input_size=self.v_target + 1,
                                       hidden_size=self.hidden_size,
                                       bias=True)
    if cell_type == 'LSTM':
        # Also used for decoder if self.no_inputs == True
        self.encoder_init_h = Parameter(torch.rand(1, self.hidden_size))
        # self.encoder_init_cs = nn.ParameterList(
        #     [Parameter(torch.rand(1, self.hidden_size)) for i in range(len(self.v_inputs))]
        # )
        # self.encoder_cells = nn.ModuleList()
        # for i in range(self.n_encoders):
        #     input_size = self.v_inputs[i] + 1 + (self.hidden_size if i > 0 else 0)
        #     self.encoder_cells.append(nn.LSTMCell(input_size=input_size, hidden_size=self.hidden_size, bias=True))
        self.decoder_cell = nn.LSTMCell(input_size=self.v_target + 1,
                                        hidden_size=self.hidden_size,
                                        bias=True)
        self.decoder_init_c = Parameter(torch.rand(1, self.hidden_size))

    self.W = nn.Linear(
        self.hidden_size if self.no_inputs else 2 * self.hidden_size,
        self.embedding_size)
    self.V = nn.Linear(self.embedding_size, self.v_target + 1)
    # self.As = nn.ModuleList([nn.Bilinear(self.hidden_size, self.hidden_size, 1, bias=False)
    #                          for i in range(self.n_encoders)])

    # image encoder:
    self.conv1 = nn.Conv2d(3, 8, kernel_size=(3, 3), padding=(3, 3),
                           stride=(1, 1))
    self.conv2 = nn.Conv2d(8, 16, kernel_size=(3, 3), padding=(1, 1),
                           stride=(1, 1))
    self.conv3 = nn.Conv2d(16, 16, kernel_size=(3, 3), padding=(1, 1),
                           stride=(1, 1))
    self.conv4 = nn.Conv2d(16, 16, kernel_size=(3, 3), padding=(1, 1),
                           stride=(1, 1))
    # self.conv4 = nn.Conv2d(256, 512, kernel_size=(3, 3),
    #                        padding=(1, 1), stride=(1, 1))
    self.batch_norm1 = nn.BatchNorm2d(8)
    self.batch_norm2 = nn.BatchNorm2d(16)

    self.img_feat_to_embedding = nn.Sequential(
        nn.Linear(16 * 16 * 16, 64), nn.ReLU(),
        nn.Linear(64, 64), nn.ReLU(),
        nn.Linear(64, self.hidden_size))

    # attention params:
    self.h_to_32_linear = nn.Linear(self.hidden_size, 32)
    self.img_to_32 = nn.Linear(16 * 16 * 16, 32)
    self.fc_loc = nn.Linear(32 + 32, 3 * 2)
    self.fc_loc.weight.data.zero_()
    self.fc_loc.bias.data.copy_(
        torch.tensor([1, 0, 0, 0, 1, 0], dtype=torch.float))

    self.img_feat_to_context = nn.Sequential(
        nn.Linear(16 * 16 * 16, 128), nn.ReLU(),
        nn.Linear(128, 128), nn.ReLU(),
        nn.Linear(128, self.hidden_size))
def __init__(self, extra_layers_encode):
    super(pruning_model, self).__init__()
    self.extra_layers_encode = extra_layers_encode
    self.rnncell = nn.GRUCell(256, 4, bias=True)
    self._initialize_weights()
    self.relu = nn.ReLU()
def __init__(self, n_atom_feature, n_edge_feature):
    super(IntraNet, self).__init__()
    self.C = nn.GRUCell(n_atom_feature, n_atom_feature)
    self.cal_message = nn.Sequential(
        nn.Linear(n_atom_feature * 2 + n_edge_feature, n_atom_feature),
    )
def test_gru_cell(self):
    model = nn.GRUCell(RNN_INPUT_SIZE, RNN_HIDDEN_SIZE)
    input = torch.randn(BATCH_SIZE, RNN_INPUT_SIZE)
    h0 = torch.randn(BATCH_SIZE, RNN_HIDDEN_SIZE)
    self.run_model_test(model, train=False, batch_size=BATCH_SIZE,
                        input=(input, h0), use_gpu=False)
def __init__(self, insize, hidden_size, out_size, drop):
    super(BahdanauDecoder, self).__init__()
    # `insize` is the embedding size; the original referenced an undefined
    # `embedding_size` here instead of the parameter.
    self.cell = nn.GRUCell(insize + hidden_size, hidden_size)
    self.concat = nn.Linear(hidden_size * 2, out_size)
    self.drop = nn.Dropout(drop)
def __init__(self, node_features, edge_features, message_passes, out_features,
             edge_embedding_size,
             edge_emb_depth=3, edge_emb_hidden_dim=150, edge_emb_dropout_p=0.0,
             att_depth=3, att_hidden_dim=80, att_dropout_p=0.0,
             msg_depth=3, msg_hidden_dim=80, msg_dropout_p=0.0,
             gather_width=100,
             gather_att_depth=3, gather_att_hidden_dim=100,
             gather_att_dropout_p=0.0,
             gather_emb_depth=3, gather_emb_hidden_dim=100,
             gather_emb_dropout_p=0.0,
             out_depth=2, out_hidden_dim=100, out_dropout_p=0,
             out_layer_shrinkage=1.0):
    super(EMNImplementation, self).__init__(edge_features,
                                            edge_embedding_size,
                                            message_passes, out_features)
    self.embedding_nn = FeedForwardNetwork(
        node_features * 2 + edge_features,
        [edge_emb_hidden_dim] * edge_emb_depth,
        edge_embedding_size,
        dropout_p=edge_emb_dropout_p)

    self.emb_msg_nn = FeedForwardNetwork(edge_embedding_size,
                                         [msg_hidden_dim] * msg_depth,
                                         edge_embedding_size,
                                         dropout_p=msg_dropout_p)
    self.att_msg_nn = FeedForwardNetwork(edge_embedding_size,
                                         [att_hidden_dim] * att_depth,
                                         edge_embedding_size,
                                         dropout_p=att_dropout_p)

    # self.extra_gru_layer = nn.Linear(edge_embedding_size, edge_embedding_size, bias=False)
    self.gru = nn.GRUCell(edge_embedding_size, edge_embedding_size,
                          bias=False)

    self.gather = GraphGather(
        edge_embedding_size, gather_width,
        gather_att_depth, gather_att_hidden_dim, gather_att_dropout_p,
        gather_emb_depth, gather_emb_hidden_dim, gather_emb_dropout_p)

    # example: depth 5, dim 50, shrinkage 0.5 => out_layer_sizes [50, 42, 35, 30, 25]
    out_layer_sizes = [
        round(out_hidden_dim * (out_layer_shrinkage ** (i / (out_depth - 1 + 1e-9))))
        for i in range(out_depth)
    ]
    self.out_nn = FeedForwardNetwork(gather_width, out_layer_sizes,
                                     out_features, dropout_p=out_dropout_p)
def experiment(exp_specs):
    ptu.set_gpu_mode(exp_specs['use_gpu'])

    # Set up logging ----------------------------------------------------------
    exp_id = exp_specs['exp_id']
    exp_prefix = exp_specs['exp_name']
    seed = exp_specs['seed']
    set_seed(seed)
    setup_logger(exp_prefix=exp_prefix, exp_id=exp_id, variant=exp_specs)

    # Prep the data -----------------------------------------------------------
    replay_dict = joblib.load(exp_specs['replay_dict_path'])
    next_obs_array = replay_dict['next_observations']
    acts_array = replay_dict['actions']
    data_loader = BasicDataLoader(
        next_obs_array[:40000], acts_array[:40000],
        exp_specs['episode_length'], exp_specs['batch_size'],
        use_gpu=ptu.gpu_enabled())
    val_data_loader = BasicDataLoader(
        next_obs_array[40000:], acts_array[40000:],
        exp_specs['episode_length'], exp_specs['batch_size'],
        use_gpu=ptu.gpu_enabled())

    # Model Definition --------------------------------------------------------
    conv_channels = 64
    conv_encoder = nn.Sequential(
        nn.Conv2d(3, conv_channels, 1, stride=1, padding=0, bias=False),
        nn.BatchNorm2d(conv_channels),
        nn.ReLU(),
        nn.Conv2d(conv_channels, conv_channels, 1, stride=1, padding=0,
                  bias=False),
        nn.BatchNorm2d(conv_channels),
        nn.ReLU()
    )
    ae_dim = 128
    gru_dim = 512
    img_h = 5
    flat_inter_img_dim = img_h * img_h * conv_channels
    fc_encoder = nn.Sequential(
        nn.Linear(flat_inter_img_dim, ae_dim, bias=False),
        nn.BatchNorm1d(ae_dim),
        nn.ReLU(),
        nn.Linear(ae_dim, ae_dim, bias=False),
        nn.BatchNorm1d(ae_dim),
        nn.ReLU(),
        nn.Linear(ae_dim, ae_dim, bias=False),
        nn.BatchNorm1d(ae_dim),
        nn.ReLU(),
        nn.Linear(ae_dim, ae_dim, bias=False),
        nn.BatchNorm1d(ae_dim),
        nn.ReLU()
    )
    gru = nn.GRUCell(ae_dim, gru_dim, bias=True)
    fc_decoder = nn.Sequential(
        nn.Linear(gru_dim + 4, ae_dim, bias=False),
        nn.BatchNorm1d(ae_dim),
        nn.ReLU(),
        nn.Linear(ae_dim, ae_dim, bias=False),
        nn.BatchNorm1d(ae_dim),
        nn.ReLU(),
        nn.Linear(ae_dim, ae_dim, bias=False),
        nn.BatchNorm1d(ae_dim),
        nn.ReLU(),
        nn.Linear(ae_dim, ae_dim, bias=False),
        nn.BatchNorm1d(ae_dim),
        nn.ReLU(),
        nn.Linear(ae_dim, flat_inter_img_dim, bias=False),
        nn.BatchNorm1d(flat_inter_img_dim),
        nn.ReLU(),
    )
    conv_decoder = nn.Sequential(
        nn.ConvTranspose2d(conv_channels, conv_channels, 1, stride=1,
                           padding=0, output_padding=0, bias=False),
        nn.BatchNorm2d(conv_channels),
        nn.ReLU(),
        nn.ConvTranspose2d(conv_channels, conv_channels, 1, stride=1,
                           padding=0, output_padding=0, bias=False),
        nn.BatchNorm2d(conv_channels),
        nn.ReLU(),
        nn.Conv2d(conv_channels, 3, 1, stride=1, padding=0, bias=True),
        nn.Sigmoid()
    )
    if ptu.gpu_enabled():
        conv_encoder.cuda()
        fc_encoder.cuda()
        gru.cuda()
        fc_decoder.cuda()
        conv_decoder.cuda()

    # Optimizer ---------------------------------------------------------------
    model_optim = Adam(
        [
            item for sublist in map(
                lambda x: list(x.parameters()),
                [fc_encoder, conv_encoder, gru, fc_decoder, conv_decoder]
            ) for item in sublist
        ],
        lr=float(exp_specs['model_lr']),
        weight_decay=float(exp_specs['model_wd'])
    )

    # -------------------------------------------------------------------------
    freq_bptt = exp_specs['freq_bptt']
    episode_length = exp_specs['episode_length']
    losses = []
    for iter_num in range(int(float(exp_specs['max_iters']))):
        if iter_num % freq_bptt == 0:
            if iter_num > 0:
                # loss = loss / freq_bptt
                loss.backward()
                model_optim.step()
                prev_h_batch = prev_h_batch.detach()
            loss = 0
        if iter_num % episode_length == 0:
            prev_h_batch = Variable(torch.zeros(exp_specs['batch_size'],
                                                gru_dim))
            if ptu.gpu_enabled():
                prev_h_batch = prev_h_batch.cuda()
        if iter_num % exp_specs['freq_val'] == 0:
            train_loss_print = '\t'.join(losses)
            losses = []

        obs_batch, act_batch = data_loader.get_next_batch()
        recon = fc_decoder(torch.cat([prev_h_batch, act_batch], 1)).view(
            obs_batch.size(0), conv_channels, img_h, img_h)
        recon = conv_decoder(recon)
        enc = conv_encoder(obs_batch).view(obs_batch.size(0), -1)
        enc = fc_encoder(enc)
        prev_h_batch = gru(enc, prev_h_batch)

        losses.append('%.4f' % ((obs_batch - recon)**2).mean())
        if iter_num % episode_length != 0:
            loss = loss + ((obs_batch - recon)**2).sum() / float(
                exp_specs['batch_size'])

        if iter_num % (50 * episode_length) in range(2 * episode_length):
            save_pytorch_tensor_as_img(
                recon[0].data.cpu(),
                'junk_vis/fixed_colors_simple_maze_5_h/rnn_recon_%d.png' % iter_num)
            save_pytorch_tensor_as_img(
                obs_batch[0].data.cpu(),
                'junk_vis/fixed_colors_simple_maze_5_h/rnn_obs_%d.png' % iter_num)

        if iter_num % exp_specs['freq_val'] == 0:
            print('\nValidating Iter %d...' % iter_num)
            list(map(lambda x: x.eval(),
                     [fc_encoder, conv_encoder, gru, fc_decoder, conv_decoder]))

            val_prev_h_batch = Variable(torch.zeros(exp_specs['batch_size'],
                                                    gru_dim))
            if ptu.gpu_enabled():
                val_prev_h_batch = val_prev_h_batch.cuda()

            losses = []
            for i in range(episode_length):
                # draw validation batches from the held-out loader (the
                # original reused the training data_loader here)
                obs_batch, act_batch = val_data_loader.get_next_batch()
                recon = fc_decoder(
                    torch.cat([val_prev_h_batch, act_batch], 1)).view(
                        obs_batch.size(0), conv_channels, img_h, img_h)
                recon = conv_decoder(recon)
                enc = conv_encoder(obs_batch).view(obs_batch.size(0), -1)
                enc = fc_encoder(enc)
                val_prev_h_batch = gru(enc, val_prev_h_batch)
                losses.append('%.4f' % ((obs_batch - recon)**2).mean())

            loss_print = '\t'.join(losses)
            print('Val MSE:\t' + loss_print)
            print('Train MSE:\t' + train_loss_print)

            list(map(lambda x: x.train(),
                     [fc_encoder, conv_encoder, gru, fc_decoder, conv_decoder]))
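# A distilled sketch (illustrative, not from the source) of the truncated-BPTT
# pattern the loop above relies on: accumulate loss for `freq_bptt` steps,
# then backprop and detach the GRU hidden state at the window boundary.
import torch
import torch.nn as nn

gru = nn.GRUCell(8, 16)
opt = torch.optim.Adam(gru.parameters(), lr=1e-3)
h = torch.zeros(1, 16)
loss, freq_bptt = 0, 4
for step in range(12):
    x, target = torch.randn(1, 8), torch.randn(1, 16)
    h = gru(x, h)
    loss = loss + ((h - target) ** 2).mean()
    if (step + 1) % freq_bptt == 0:
        opt.zero_grad()
        loss.backward()
        opt.step()
        h = h.detach()   # cut the graph so gradients stop at the window edge
        loss = 0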
def __init__(self, msg_dim, node_state_dim, edge_feat_dim,
             num_prop=1, num_layer=1, has_attention=True, att_hidden_dim=128,
             has_residual=False, has_graph_output=False,
             output_hidden_dim=128, graph_output_dim=None):
    super(GNN, self).__init__()
    self.msg_dim = msg_dim
    self.node_state_dim = node_state_dim
    self.edge_feat_dim = edge_feat_dim
    self.num_prop = num_prop
    self.num_layer = num_layer
    self.has_attention = has_attention
    self.has_residual = has_residual
    self.att_hidden_dim = att_hidden_dim
    self.has_graph_output = has_graph_output
    self.output_hidden_dim = output_hidden_dim
    self.graph_output_dim = graph_output_dim

    self.update_func = nn.ModuleList([
        nn.GRUCell(input_size=self.msg_dim, hidden_size=self.node_state_dim)
        for _ in range(self.num_layer)
    ])

    self.msg_func = nn.ModuleList([
        nn.Sequential(*[
            nn.Linear(self.node_state_dim + self.edge_feat_dim, self.msg_dim),
            nn.ReLU(),
            nn.Linear(self.msg_dim, self.msg_dim)
        ]) for _ in range(self.num_layer)
    ])

    if self.has_attention:
        self.att_head = nn.ModuleList([
            nn.Sequential(*[
                nn.Linear(self.node_state_dim + self.edge_feat_dim,
                          self.att_hidden_dim),
                nn.ReLU(),
                nn.Linear(self.att_hidden_dim, self.msg_dim),
                nn.Sigmoid()
            ]) for _ in range(self.num_layer)
        ])

    if self.has_graph_output:
        self.graph_output_head_att = nn.Sequential(*[
            nn.Linear(self.node_state_dim, self.output_hidden_dim),
            nn.ReLU(),
            nn.Linear(self.output_hidden_dim, 1),
            nn.Sigmoid()
        ])
        self.graph_output_head = nn.Sequential(
            *[nn.Linear(self.node_state_dim, self.graph_output_dim)])
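# A minimal sketch (dimensions and the sum aggregation are assumptions) of
# one propagation round matching the modules above: per-edge messages from
# msg_func, gated by the sigmoid att_head, summed into destination nodes,
# then a per-node GRUCell update.
import torch
import torch.nn as nn

n_nodes, n_edges, state_dim, edge_dim, msg_dim = 5, 8, 16, 4, 16
msg_func = nn.Sequential(nn.Linear(state_dim + edge_dim, msg_dim),
                         nn.ReLU(), nn.Linear(msg_dim, msg_dim))
att_head = nn.Sequential(nn.Linear(state_dim + edge_dim, 32),
                         nn.ReLU(), nn.Linear(32, msg_dim), nn.Sigmoid())
update = nn.GRUCell(input_size=msg_dim, hidden_size=state_dim)

state = torch.randn(n_nodes, state_dim)
edge_feat = torch.randn(n_edges, edge_dim)
src = torch.randint(0, n_nodes, (n_edges,))
dst = torch.randint(0, n_nodes, (n_edges,))

inp = torch.cat([state[src], edge_feat], dim=1)
msg = msg_func(inp) * att_head(inp)                    # gated per-edge messages
agg = torch.zeros(n_nodes, msg_dim).index_add_(0, dst, msg)  # sum into nodes
state = update(agg, state)                             # per-node GRU update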
def __init__(self, dim):
    super().__init__()
    self.gru = nn.GRUCell(dim, dim)
def __init__(self, memory, message_dimension, memory_dimension, device):
    super(GRUMemoryUpdater, self).__init__(memory, message_dimension,
                                           memory_dimension, device)
    self.memory_updater = nn.GRUCell(input_size=message_dimension,
                                     hidden_size=memory_dimension)
def __init__(self, input_size, hidden_size, bias=True):
    cell = nn.GRUCell(input_size, hidden_size, bias=bias)
    super().__init__(cell)
def GRUCell(input_size, hidden_size, **kwargs):
    # forward kwargs (e.g. bias=...) to the underlying cell; the original
    # accepted them but silently dropped them
    m = nn.GRUCell(input_size, hidden_size, **kwargs)
    for name, param in m.named_parameters():
        if "weight" in name or "bias" in name:
            param.data.uniform_(-0.1, 0.1)
    return m
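# Usage sketch (illustrative) of the factory above: it returns a standard
# nn.GRUCell whose weights and biases are re-initialized uniformly in
# [-0.1, 0.1].
import torch

cell = GRUCell(16, 32)
h = cell(torch.randn(8, 16), torch.zeros(8, 32))   # -> (8, 32)
assert float(cell.weight_ih.abs().max()) <= 0.1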
def __init__(self, feat_dim: int, n_edge_types: int):
    super(Propogator, self).__init__()
    self.feat_dim = feat_dim
    self.n_edge_types = n_edge_types
    self.gru = nn.GRUCell(feat_dim * n_edge_types, feat_dim)
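# A minimal sketch (an assumption about the forward pass) of how such a
# propagator consumes messages: one aggregated message per edge type is
# concatenated along the feature axis before the GRU step, which is why the
# cell's input size is feat_dim * n_edge_types.
import torch
import torch.nn as nn

feat_dim, n_edge_types, n_nodes = 16, 3, 5
gru = nn.GRUCell(feat_dim * n_edge_types, feat_dim)

state = torch.randn(n_nodes, feat_dim)
msgs = [torch.randn(n_nodes, feat_dim) for _ in range(n_edge_types)]
state = gru(torch.cat(msgs, dim=1), state)   # -> (n_nodes, feat_dim)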
def __init__(self, **kwargs):
    super(MessageModule, self).__init__()
    h_dims = kwargs['h_dims']
    g_dims = kwargs['g_dims']
    self.to_state = nn.GRUCell(h_dims + g_dims, h_dims)
    INIT.orthogonal_(self.to_state.weight_hh)
def __init__(self):
    super().__init__()

    conv_channels = 32
    self.conv_encoder = nn.Sequential(
        nn.Conv2d(3, conv_channels, 4, stride=2, padding=1, bias=False),
        nn.BatchNorm2d(conv_channels),
        nn.ReLU(),
        nn.Conv2d(conv_channels, conv_channels, 4, stride=2, padding=1,
                  bias=False),
        nn.BatchNorm2d(conv_channels),
        nn.ReLU(),
        nn.Conv2d(conv_channels, conv_channels, 1, stride=1, padding=0,
                  bias=False),
        nn.BatchNorm2d(conv_channels),
        nn.ReLU(),
        nn.Conv2d(conv_channels, conv_channels, 1, stride=1, padding=0,
                  bias=False),
        nn.BatchNorm2d(conv_channels),
        nn.ReLU())

    ae_dim = 512
    lstm_dim = 512
    self.lstm_dim = lstm_dim
    img_h = 5
    flat_inter_img_dim = img_h * img_h * conv_channels
    act_dim = 64
    self.conv_channels = conv_channels
    self.img_h = img_h

    self.lstm_act_proc_fc = nn.Linear(4, act_dim, bias=True)
    self.recon_act_proc_fc = nn.Linear(4, act_dim, bias=True)
    self.mask_act_proc_fc = nn.Linear(4, act_dim, bias=True)

    self.attention_seq = nn.Sequential(
        nn.Linear(lstm_dim + act_dim, lstm_dim, bias=False),
        nn.BatchNorm1d(lstm_dim),
        nn.ReLU(),
        nn.Linear(lstm_dim, lstm_dim),
        # nn.Sigmoid()
        # nn.Softmax()
    )

    self.fc_encoder = nn.Sequential(
        nn.Linear(flat_inter_img_dim + act_dim, ae_dim, bias=False),
        nn.BatchNorm1d(ae_dim),
        nn.ReLU(),
        # nn.Linear(ae_dim, ae_dim, bias=False),
        # nn.BatchNorm1d(ae_dim),
        # nn.ReLU(),
        # nn.Linear(ae_dim, ae_dim, bias=False),
        # nn.BatchNorm1d(ae_dim),
        # nn.ReLU(),
        # nn.Linear(ae_dim, ae_dim, bias=False),
        # nn.BatchNorm1d(ae_dim),
        # nn.ReLU()
    )

    # despite the attribute name, this is a GRU cell
    self.lstm = nn.GRUCell(ae_dim, lstm_dim, bias=True)

    self.fc_decoder = nn.Sequential(
        nn.Linear(lstm_dim + act_dim, flat_inter_img_dim, bias=False),
        nn.BatchNorm1d(flat_inter_img_dim),
        nn.ReLU(),
        # nn.Linear(ae_dim, ae_dim, bias=False),
        # nn.BatchNorm1d(ae_dim),
        # nn.ReLU(),
        # # nn.Linear(ae_dim, ae_dim, bias=False),
        # # nn.BatchNorm1d(ae_dim),
        # # nn.ReLU(),
        # # nn.Linear(ae_dim, ae_dim, bias=False),
        # # nn.BatchNorm1d(ae_dim),
        # # nn.ReLU(),
        # nn.Linear(ae_dim, flat_inter_img_dim, bias=False),
        # nn.BatchNorm1d(flat_inter_img_dim),
        # nn.ReLU(),
    )

    self.conv_decoder = nn.Sequential(
        nn.ConvTranspose2d(conv_channels, conv_channels, 4, stride=2,
                           padding=1, output_padding=0, bias=False),
        nn.BatchNorm2d(conv_channels),
        nn.ReLU(),
        # nn.Conv2d(conv_channels, conv_channels, 1, stride=1, padding=0, bias=False),
        # nn.BatchNorm2d(conv_channels),
        # nn.ReLU(),
        nn.ConvTranspose2d(conv_channels, conv_channels, 4, stride=2,
                           padding=1, output_padding=0, bias=False),
        nn.BatchNorm2d(conv_channels),
        nn.ReLU(),
        # nn.Conv2d(conv_channels, conv_channels, 1, stride=1, padding=0, bias=False),
        # nn.BatchNorm2d(conv_channels),
        # nn.ReLU(),
    )

    self.mean_decoder = nn.Sequential(
        nn.Conv2d(conv_channels, 3, 1, stride=1, padding=0, bias=True),
        nn.Sigmoid())
    self.log_cov_decoder = nn.Sequential(
        nn.Conv2d(conv_channels, 3, 1, stride=1, padding=0, bias=True),
    )
def __init__(self, opt, padding_idx=0, embedding=None):
    super().__init__(opt, padding_idx, embedding)

    # RNN document encoder
    self.doc_rnn = layers.StackedBRNN(
        input_size=self.doc_input_size,
        hidden_size=opt['hidden_size'],
        num_layers=2,
        dropout_rate=opt['dropout_rnn'],
        # dropout_output=opt['dropout_rnn_output'],
        variational_dropout=opt['variational_dropout'],
        concat_layers=True,
        rnn_type=opt['rnn_type'],
        padding=opt['rnn_padding'],
        residual=opt['residual'],
        squeeze_excitation=opt['squeeze_excitation'],
    )

    # RNN question encoder
    self.question_rnn = layers.StackedBRNN(
        input_size=self.question_input_size,
        hidden_size=opt['hidden_size'],
        num_layers=2,
        dropout_rate=opt['dropout_rnn'],
        # dropout_output=opt['dropout_rnn_output'],
        variational_dropout=opt['variational_dropout'],
        concat_layers=True,
        rnn_type=opt['rnn_type'],
        padding=opt['rnn_padding'],
        residual=opt['residual'],
        squeeze_excitation=opt['squeeze_excitation'],
    )

    # Output sizes of rnn encoders
    doc_hidden_size = 2 * 2 * opt['hidden_size']
    question_hidden_size = doc_hidden_size

    self.question_urnn = layers.StackedBRNN(
        input_size=question_hidden_size,
        hidden_size=opt['hidden_size'],
        num_layers=opt['fusion_understanding_layers'],
        dropout_rate=opt['dropout_rnn'],
        variational_dropout=opt['variational_dropout'],
        rnn_type=opt['rnn_type'],
        padding=opt['rnn_padding'],
        residual=opt['residual'],
        squeeze_excitation=opt['squeeze_excitation'],
        concat_layers=False,
    )

    self.multi_level_fusion = layers.FullAttention(
        full_size=self.paired_input_size + doc_hidden_size,
        hidden_size=2 * 3 * opt['hidden_size'],
        num_level=3,
        dropout=opt['dropout_rnn'],
        variational_dropout=opt['variational_dropout'],
    )

    self.doc_urnn = layers.StackedBRNN(
        input_size=2 * 5 * opt['hidden_size'],
        hidden_size=opt['hidden_size'],
        num_layers=opt['fusion_understanding_layers'],
        dropout_rate=opt['dropout_rnn'],
        variational_dropout=opt['variational_dropout'],
        rnn_type=opt['rnn_type'],
        padding=opt['rnn_padding'],
        residual=opt['residual'],
        squeeze_excitation=opt['squeeze_excitation'],
        concat_layers=False,
    )

    self.self_boost_fusions = nn.ModuleList()
    self.doc_final_rnns = nn.ModuleList()
    full_size = self.paired_input_size + 4 * 3 * opt['hidden_size']
    for i in range(self.opt['fusion_self_boost_times']):
        self.self_boost_fusions.append(
            layers.FullAttention(
                full_size=full_size,
                hidden_size=2 * opt['hidden_size'],
                num_level=1,
                dropout=opt['dropout_rnn'],
                variational_dropout=opt['variational_dropout'],
            ))
        self.doc_final_rnns.append(
            layers.StackedBRNN(
                input_size=4 * opt['hidden_size'],
                hidden_size=opt['hidden_size'],
                num_layers=opt['fusion_final_layers'],
                dropout_rate=opt['dropout_rnn'],
                variational_dropout=opt['variational_dropout'],
                rnn_type=opt['rnn_type'],
                padding=opt['rnn_padding'],
                residual=opt['residual'],
                squeeze_excitation=opt['squeeze_excitation'],
                concat_layers=False,
            ))
        full_size += 2 * opt['hidden_size']

    # Question merging (the original repeated this guard twice; merged here)
    if opt['question_merge'] not in ['avg', 'self_attn']:
        raise NotImplementedError('question_merge = %s' %
                                  opt['question_merge'])
    if opt['question_merge'] == 'self_attn':
        self.quesiton_merge_attns = nn.ModuleList()  # sic; never populated
        self.self_attn = layers.LinearSeqAttn(2 * opt['hidden_size'])

    # Bilinear attention for span start/end
    self.start_attn = layers.BilinearSeqAttn(
        2 * opt['hidden_size'],
        2 * opt['hidden_size'],
    )
    if opt['end_gru']:
        self.end_gru = nn.GRUCell(2 * opt['hidden_size'],
                                  2 * opt['hidden_size'])
    self.end_attn = layers.BilinearSeqAttn(
        2 * opt['hidden_size'],
        2 * opt['hidden_size'],
    )
def __init__(self, ctx_size_dict, z_size, z_len=10, z_transform=None,
             z_in_size=256, z_merge='sum', z_init='mean_ctx',
             att_type='mlp', att_activ='tanh', att_bottleneck='ctx',
             att_temp=1.0, att_transform_ctx=False, mlp_bias=False,
             hiero_mid_dim=128):
    super().__init__()

    self.ctx_size_dict = ctx_size_dict
    self.z_size = z_size
    self.z_len = z_len
    self.z_transform = z_transform.lower() if z_transform else None
    self.z_in_size = z_in_size
    self.z_merge = z_merge

    # Other arguments
    self.att_type = att_type
    self.att_bottleneck = att_bottleneck
    self.att_activ = att_activ
    self.att_temp = att_temp
    self.att_transform_ctx = att_transform_ctx
    self.mlp_bias = mlp_bias
    self.z_init = z_init
    self.hiero_mid_dim = hiero_mid_dim

    # Safety check
    self._sanity_check()

    # Create FF layers to manage the different context sizes.
    # Each layer maps ctx_size_dict[k] to z_in_size (== ctx_size);
    # z_transform gives the kind of (non-)linearity to use.
    if self.z_transform:
        self.z_transforms = nn.ModuleDict()
        for k in self.ctx_size_dict:
            self.z_transforms[k] = FF(self.ctx_size_dict[k],
                                      self.z_in_size, activ=z_transform)
            self.ctx_size_dict[k] = self.z_in_size
        self.ctx_size = self.z_in_size
    else:
        s = set(self.ctx_size_dict.values())
        assert len(s) == 1, \
            "Incompatible encoding sizes, consider using z_transform:tanh in config."
        self.ctx_size = next(iter(s))

    # Create an attention layer for each modality
    # TODO: sharing weights between att. mechanisms is possible
    self.att = nn.ModuleDict()
    # Fetch correct attention class
    Attention = get_attention(self.att_type)
    for k in self.ctx_size_dict:
        att_in_size = (self.ctx_size if self.z_transform
                       else self.ctx_size_dict[k])
        self.att[k] = Attention(att_in_size, self.z_size,
                                transform_ctx=self.att_transform_ctx,
                                mlp_bias=self.mlp_bias,
                                att_activ=self.att_activ,
                                att_bottleneck=self.att_bottleneck,
                                temp=self.att_temp,
                                ctx2hid=False)

    # Fusion operation
    if self.z_merge == 'hierarchical':
        self.hiero_att = HierarchicalAttention(
            [self.ctx_size_dict[k] for k in self.ctx_size_dict.keys()],
            self.z_size, self.hiero_mid_dim)
        self.merge_op = self._merge_hierarchical
    else:
        self.merge_op = self._merge_sum

    # Create the decoder layer necessary for attention
    self.dec = nn.GRUCell(self.ctx_size, self.z_size)

    # Several strategies to initialize the decoder can be considered.
    # Set the decoder initializer
    self._init_func = getattr(self, '_rnn_init_{}'.format(self.z_init))
    # if init is not zero, create the FF layer
    if self.z_init != 'zero':
        self.ff_z_init = FF(self.ctx_size, self.z_size, activ='tanh')
def __init__(self, input_size, hidden_size):
    super().__init__()
    self.input_size, self.hidden_size = input_size, hidden_size
    self.dynamics = nn.GRUCell(input_size=input_size, hidden_size=hidden_size)
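# A usage sketch (illustrative) of treating the GRUCell as a latent dynamics
# model: unroll it over a sequence of inputs, carrying the hidden state.
import torch
import torch.nn as nn

dynamics = nn.GRUCell(input_size=8, hidden_size=32)
h = torch.zeros(4, 32)                 # batch of 4 latent states
for x_t in torch.randn(10, 4, 8):      # 10 timesteps of inputs
    h = dynamics(x_t, h)               # h tracks the latent trajectory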
def initGRUCell(input_size, hidden_size, **kwargs):
    model = nn.GRUCell(input_size, hidden_size, **kwargs)
    for name, param in model.named_parameters():
        if 'weight' in name or 'bias' in name:
            param.data.uniform_(-0.1, 0.1)
    return model


# Attention: attention module
class Attention(nn.Module):
    def __init__(self, hidden_size, attn_size):
        super(Attention, self).__init__()
        self.hidden_size = hidden_size
        self.attn_size = attn_size
        self.linear_layer1 = nn.Linear(self.hidden_size, self.attn_size)
        self.linear_layer2 = nn.Linear(self.hidden_size + self.attn_size,
                                       self.attn_size)

    def forward(self, hidden, encoder_outs, source_lengths):
        # hidden_size -> attn_size
        attn_hidden = self.linear_layer1(hidden)
        # get scores
        attn_score = torch.sum(
            (encoder_outs.transpose(0, 1) * attn_hidden.unsqueeze(0)), 2)
        attn_mask = torch.transpose(
            seq_mask(source_lengths, max_len=max(source_lengths).item()), 0, 1)
        masked_attn = attn_mask * attn_score
        masked_attn[masked_attn == 0] = -1e10
        # softmax over attention to get weights
        attn_scores = F.softmax(masked_attn, dim=0)
        # compute weighted sum according to attention scores
        attn_hidden = torch.sum(
            attn_scores.unsqueeze(2) * encoder_outs.transpose(0, 1), 0)
        attn_hidden = self.linear_layer2(
            torch.cat((attn_hidden, hidden), dim=1))
        attn_hidden = torch.tanh(attn_hidden)
        return attn_hidden, attn_scores


# AttnDecoderRNN
class AttnDecoderRNN(nn.Module):
    def __init__(self, vocab_size, embed_size, hidden_size, num_rnn_layers=1,
                 attention=True, dropout_percent=0.1):
        super(AttnDecoderRNN, self).__init__()
        self.hidden_size = hidden_size
        self.embed_size = embed_size
        encoder_output_size = self.hidden_size
        self.embedding = nn.Embedding(vocab_size, embed_size, PAD_IDX)
        self.dropout_f = nn.Dropout(p=dropout_percent)
        self.num_layers = num_rnn_layers
        if attention:
            self.attention = Attention(self.hidden_size, encoder_output_size)
        else:
            self.attention = None
        self.layers = nn.ModuleList([
            initLSTMCell(
                input_size=self.hidden_size + self.embed_size
                if ((layer == 0) and attention)
                else self.embed_size if layer == 0 else self.hidden_size,
                hidden_size=self.hidden_size,
            ) for layer in range(self.num_layers)
        ])
        self.linear_layer = nn.Linear(self.hidden_size, vocab_size)
        self.log_softmax = nn.LogSoftmax(dim=1)

    def forward(self, decoder_inputs, context, prev_hiddens, prev_context,
                encoder_outputs, source_lengths):
        batch_size = decoder_inputs.size(0)
        # embed
        embed_target = self.embedding(decoder_inputs)
        out = self.dropout_f(embed_target)
        if self.attention is not None:
            input_ = torch.cat([out.squeeze(1), context], dim=1)
        else:
            input_ = out.squeeze(1)
        context_ = []
        decoder_hiddens_ = []
        for layer, rnn in enumerate(self.layers):
            hidden, con = rnn(input_,
                              (prev_hiddens[layer], prev_context[layer]))
            input_ = self.dropout_f(hidden)
            decoder_hiddens_.append(hidden.unsqueeze(0))
            context_.append(con.unsqueeze(0))
        decoder_hiddens_ = torch.cat(decoder_hiddens_, dim=0)
        context_ = torch.cat(context_, dim=0)
        if self.attention is not None:
            out, attn_score = self.attention(hidden, encoder_outputs,
                                             source_lengths)
        else:
            out = hidden
            attn_score = None
        context_vec = out
        out = self.dropout_f(out)
        # linear: hidden_size -> vocab_size
        deco_out = self.linear_layer(out)
        deco_out = self.log_softmax(deco_out)
        # the original returned an undefined `out_vocab` here
        return deco_out, context_vec, decoder_hiddens_, context_, attn_score


# CNNencoder
class CNNencoder(nn.Module):
    def __init__(self, vocab_size, embed_size, hidden_size, kernel_size,
                 num_layers, percent_dropout=0.3):
        super(CNNencoder, self).__init__()
        self.num_layers = num_layers
        self.hidden_size = hidden_size
        self.kernel_size = kernel_size
        self.embed_size = embed_size
        self.vocab_size = vocab_size
        self.embedding = nn.Embedding(self.vocab_size, self.embed_size,
                                      padding_idx=0)
        self.dropout_f = nn.Dropout(percent_dropout)
        in_channels = self.embed_size
        self.conv = nn.Conv1d(in_channels, self.hidden_size, kernel_size,
                              padding=kernel_size // 2)
        # todo
        self.conv2 = nn.Conv1d(60, self.hidden_size, kernel_size,
                               padding=kernel_size // 2)
        self.ReLU = nn.ReLU()

    def forward(self, source_sentence):
        batch_size, seq_len = source_sentence.size()
        embeds_source = self.embedding(source_sentence)
        out = self.conv(embeds_source.transpose(1, 2)).transpose(1, 2)
        out = self.ReLU(out)
        out = F.max_pool1d(out, kernel_size=5, stride=5)
        out = self.conv2(out.transpose(1, 2)).transpose(1, 2)
        out = self.ReLU(out)
        out = torch.mean(out, dim=1).view(1, batch_size, self.hidden_size)
        return out


class LSTMencoder(nn.Module):
    def __init__(self, input_size, embed_size, hidden_size, num_lstm_layers):
        super(LSTMencoder, self).__init__()
        self.hidden_size = hidden_size
        self.embed_size = embed_size
        self.embedding = nn.Embedding(input_size, self.embed_size,
                                      padding_idx=0)
        self.dropout_ = nn.Dropout(p=0.1)
        self.num_layers = num_lstm_layers
        self.lstm = nn.LSTM(self.embed_size, self.hidden_size,
                            batch_first=True, bidirectional=True,
                            num_layers=self.num_layers, dropout=0.15)

    def initHidden(self, batch_size):
        return (torch.zeros(self.num_layers * 2, batch_size,
                            self.hidden_size).to(device),
                torch.zeros(self.num_layers * 2, batch_size,
                            self.hidden_size).to(device))

    def forward(self, encoder_inputs, source_lengths):
        sort_original_source = torch.sort(source_lengths, descending=True)[1]
        unsort_to_original_source = torch.sort(sort_original_source)[1]
        embeds_source = self.embedding(encoder_inputs)
        lstm_out = self.dropout_(embeds_source)
        # sizes come from the 2-D token tensor, not the 3-D embeddings
        batch_size, seq_len = encoder_inputs.size()
        hidden, context = self.initHidden(batch_size)
        sorted_output = lstm_out[sort_original_source]
        sorted_len = source_lengths[sort_original_source]
        packed_output = nn.utils.rnn.pack_padded_sequence(
            sorted_output, sorted_len, batch_first=True)
        packed_outs, (hidden, context) = self.lstm(packed_output,
                                                   (hidden, context))
        hidden = hidden[:, unsort_to_original_source, :]
        context = context[:, unsort_to_original_source, :]
        lstm_out, _ = nn.utils.rnn.pad_packed_sequence(
            packed_outs, padding_value=PAD_IDX, batch_first=True)
        # unsort output
        lstm_out = lstm_out[unsort_to_original_source]
        hidden = hidden.view(self.num_layers, 2, batch_size, -1).transpose(
            1, 2).contiguous().view(self.num_layers, batch_size, -1)
        context = context.view(self.num_layers, 2, batch_size, -1).transpose(
            1, 2).contiguous().view(self.num_layers, batch_size, -1)
        return lstm_out, hidden, context