Code Example #1
    def __init__(self, num_prop_rounds, node_hidden_size, edge_hidden_size):
        super(GraphProp, self).__init__()

        self.num_prop_rounds = num_prop_rounds

        # Setting from the paper
        self.node_activation_hidden_size = 2 * node_hidden_size

        message_funcs = []
        self.reduce_funcs = []
        node_update_funcs = []

        for t in range(num_prop_rounds):
            # input being [hv, hu, xuv]
            message_funcs.append(nn.Linear(2 * node_hidden_size + edge_hidden_size,
                                           self.node_activation_hidden_size))

            self.reduce_funcs.append(partial(self.dgmg_reduce, round=t))
            node_update_funcs.append(
                nn.GRUCell(self.node_activation_hidden_size,
                           node_hidden_size))

        self.message_funcs = nn.ModuleList(message_funcs)
        self.node_update_funcs = nn.ModuleList(node_update_funcs)
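As a quick illustration (not part of the original GraphProp class; sizes are made up), each round's nn.GRUCell maps an aggregated node activation and the current node state to a new node state:

import torch
import torch.nn as nn

node_hidden_size, num_nodes = 16, 5
# same shapes as node_update_funcs[t] above: activation -> new node state
cell = nn.GRUCell(2 * node_hidden_size, node_hidden_size)

h = torch.zeros(num_nodes, node_hidden_size)       # current node states
a = torch.randn(num_nodes, 2 * node_hidden_size)   # aggregated activations (e.g. from dgmg_reduce)
h_new = cell(a, h)                                  # (num_nodes, node_hidden_size)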
Code Example #2
    def __init__(self):
        super(Net, self).__init__()
        self.time_sequence = 8
        self.conv1 = nn.Conv2d(1, 4, kernel_size=1, padding=0)
        self.conv1_bn = nn.BatchNorm2d(5)
        self.conv2 = nn.Conv2d(5, 3, kernel_size=3, padding=1)
        self.conv2_bn = nn.BatchNorm2d(8)

        self.gru1 = nn.GRUCell(1000 + 96, 512)
        self.gru2 = nn.GRUCell(512, 512)
        self.gru3 = nn.GRUCell(512, 512)
        self.gru4 = nn.GRUCell(512, 512)
        self.gru5 = nn.GRUCell(512, 512)
        self.gru6 = nn.GRUCell(512, 512)

        self.nalu2 = nalu.NALU(512, 125)

        torch.nn.init.xavier_uniform_(self.conv1.weight)
        torch.nn.init.xavier_uniform_(self.conv2.weight)
Code Example #3
    def __init__(self, D_m, D_s, D_g, D_p, D_r, D_i, D_e, listener_state=False,
                            context_attention='simple', D_a=100, dropout=0.5, emo_gru=True):
        super(CommonsenseRNNCell, self).__init__()

        self.D_m = D_m
        self.D_s = D_s
        self.D_g = D_g
        self.D_p = D_p
        self.D_r = D_r
        self.D_i = D_i
        self.D_e = D_e

        # print ('dmsg', D_m, D_s, D_g)
        self.g_cell = nn.GRUCell(D_m+D_p+D_r, D_g)
        self.p_cell = nn.GRUCell(D_s+D_g, D_p)
        self.r_cell = nn.GRUCell(D_m+D_s+D_g, D_r)
        self.i_cell = nn.GRUCell(D_s+D_p, D_i)
        self.e_cell = nn.GRUCell(D_m+D_p+D_r+D_i, D_e)
        
        
        self.emo_gru = emo_gru
        self.listener_state = listener_state
        if listener_state:
            self.pl_cell = nn.GRUCell(D_s+D_g, D_p)
            self.rl_cell = nn.GRUCell(D_m+D_s+D_g, D_r)

        self.dropout = nn.Dropout(dropout)
        
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)
        self.dropout4 = nn.Dropout(dropout)
        self.dropout5 = nn.Dropout(dropout)

        if context_attention=='simple':
            self.attention = SimpleAttention(D_g)
        else:
            self.attention = MatchingAttention(D_g, D_m, D_a, context_attention)
Code Example #4
    def __init__(self, vocab_size, embedding_size=768, gru_hidden_size=768):
        super().__init__()

        self.gru_cell = nn.GRUCell(embedding_size, gru_hidden_size)
        self.l_out = nn.Linear(gru_hidden_size, vocab_size)
        self.softmax = nn.Softmax(dim=1)
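A hedged sketch of how such a cell is typically stepped during decoding (illustrative tensors only, not the author's forward method):

import torch
import torch.nn as nn

vocab_size, embedding_size, gru_hidden_size = 100, 768, 768
gru_cell = nn.GRUCell(embedding_size, gru_hidden_size)
l_out = nn.Linear(gru_hidden_size, vocab_size)
softmax = nn.Softmax(dim=1)

x = torch.randn(4, embedding_size)    # a batch of token embeddings
h = torch.zeros(4, gru_hidden_size)   # initial hidden state
h = gru_cell(x, h)                    # advance the recurrent state by one step
probs = softmax(l_out(h))             # (4, vocab_size)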
Code Example #5
File: maven_module.py  Project: yuanleirl/pymarl2
 def __init__(self, input_size, args):
     super().__init__()
     self.args = args
     self.input_size = input_size
     output_size = args.rnn_agg_size
     self.rnn = nn.GRUCell(input_size, output_size)
Code Example #6
 def _load_rnn(self):
     self.rnn = nn.GRUCell(620, 2400)
     parameters = self._load_rnn_params()
     state_dict = self._make_rnn_state_dict(parameters)
     self.rnn.load_state_dict(state_dict)
     return self.rnn
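_make_rnn_state_dict is not shown here; whatever it does, the dict it returns must supply exactly the four GRUCell parameters below (a sketch with zero-filled placeholders, not the real skip-thought weights):

import torch
import torch.nn as nn

rnn = nn.GRUCell(620, 2400)
state_dict = {
    "weight_ih": torch.zeros(3 * 2400, 620),    # input-to-hidden weights for the r, z, n gates
    "weight_hh": torch.zeros(3 * 2400, 2400),   # hidden-to-hidden weights
    "bias_ih": torch.zeros(3 * 2400),
    "bias_hh": torch.zeros(3 * 2400),
}
rnn.load_state_dict(state_dict)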
Code Example #7
 def __init__(self, ins=2, es=8, hs=16):
     super(EncoderRNN, self).__init__()
     self.hs = hs
     self.linear1 = nn.Linear(ins, es)
     self.lstm1 = nn.LSTMCell(es, hs)
     self.gru1 = nn.GRUCell(es, hs)
Code Example #8
File: network.py  Project: zlapp/ec
    def __init__(self,
                 input_vocabulary,
                 target_vocabulary,
                 hidden_size=512,
                 embedding_size=128,
                 cell_type="LSTM"):
        """
        :param list input_vocabulary: list of possible inputs
        :param list target_vocabulary: list of possible targets
        """
        super(Network, self).__init__()
        self.h_input_encoder_size = hidden_size
        self.h_output_encoder_size = hidden_size
        self.h_decoder_size = hidden_size
        self.embedding_size = embedding_size
        self.input_vocabulary = input_vocabulary
        self.target_vocabulary = target_vocabulary
        # Number of tokens in input vocabulary
        self.v_input = len(input_vocabulary)
        # Number of tokens in target vocabulary
        self.v_target = len(target_vocabulary)

        self.cell_type = cell_type
        if cell_type == 'GRU':
            self.input_encoder_cell = nn.GRUCell(
                input_size=self.v_input + 1,
                hidden_size=self.h_input_encoder_size,
                bias=True)
            self.input_encoder_init = Parameter(
                torch.rand(1, self.h_input_encoder_size))
            self.output_encoder_cell = nn.GRUCell(
                input_size=self.v_input + 1 + self.h_input_encoder_size,
                hidden_size=self.h_output_encoder_size,
                bias=True)
            self.decoder_cell = nn.GRUCell(input_size=self.v_target + 1,
                                           hidden_size=self.h_decoder_size,
                                           bias=True)
        if cell_type == 'LSTM':
            self.input_encoder_cell = nn.LSTMCell(
                input_size=self.v_input + 1,
                hidden_size=self.h_input_encoder_size,
                bias=True)
            self.input_encoder_init = nn.ParameterList([
                Parameter(torch.rand(1, self.h_input_encoder_size)),
                Parameter(torch.rand(1, self.h_input_encoder_size))
            ])
            self.output_encoder_cell = nn.LSTMCell(
                input_size=self.v_input + 1 + self.h_input_encoder_size,
                hidden_size=self.h_output_encoder_size,
                bias=True)
            self.output_encoder_init_c = Parameter(
                torch.rand(1, self.h_output_encoder_size))
            self.decoder_cell = nn.LSTMCell(input_size=self.v_target + 1,
                                            hidden_size=self.h_decoder_size,
                                            bias=True)
            self.decoder_init_c = Parameter(torch.rand(1, self.h_decoder_size))

        self.W = nn.Linear(self.h_output_encoder_size + self.h_decoder_size,
                           self.embedding_size)
        self.V = nn.Linear(self.embedding_size, self.v_target + 1)
        self.input_A = nn.Bilinear(self.h_input_encoder_size,
                                   self.h_output_encoder_size,
                                   1,
                                   bias=False)
        self.output_A = nn.Bilinear(self.h_output_encoder_size,
                                    self.h_decoder_size,
                                    1,
                                    bias=False)
        self.input_EOS = torch.zeros(1, self.v_input + 1)
        self.input_EOS[:, -1] = 1
        self.input_EOS = Parameter(self.input_EOS)
        self.output_EOS = torch.zeros(1, self.v_input + 1)
        self.output_EOS[:, -1] = 1
        self.output_EOS = Parameter(self.output_EOS)
        self.target_EOS = torch.zeros(1, self.v_target + 1)
        self.target_EOS[:, -1] = 1
        self.target_EOS = Parameter(self.target_EOS)
Code Example #9
    def __init__(self, word_dict, item_dict, context_dict, output_length, args,
                 device_id):
        super(DialogModel, self).__init__(device_id)

        domain = get_domain(args.domain)

        self.word_dict = word_dict
        self.item_dict = item_dict
        self.context_dict = context_dict
        self.args = args

        # embedding for words
        self.word_encoder = nn.Embedding(len(self.word_dict), args.nembed_word)

        # context encoder
        ctx_encoder_ty = modules.RnnContextEncoder if args.rnn_ctx_encoder \
            else modules.MlpContextEncoder
        self.ctx_encoder = ctx_encoder_ty(len(self.context_dict),
                                          domain.input_length(),
                                          args.nembed_ctx, args.nhid_ctx,
                                          args.init_range, device_id)

        # a reader RNN, to encode words
        self.reader = nn.GRU(input_size=args.nhid_ctx + args.nembed_word,
                             hidden_size=args.nhid_lang,
                             bias=True)
        self.decoder = nn.Linear(args.nhid_lang, args.nembed_word)
        # a writer, a RNNCell that will be used to generate utterances
        self.writer = nn.GRUCell(input_size=args.nhid_ctx + args.nembed_word,
                                 hidden_size=args.nhid_lang,
                                 bias=True)

        # tie the weights of reader and writer
        self.writer.weight_ih = self.reader.weight_ih_l0
        self.writer.weight_hh = self.reader.weight_hh_l0
        self.writer.bias_ih = self.reader.bias_ih_l0
        self.writer.bias_hh = self.reader.bias_hh_l0

        self.dropout = nn.Dropout(args.dropout)

        # a bidirectional selection RNN
        # it will go through the input words and the hidden states generated by the reader
        # to produce a hidden representation
        self.sel_rnn = nn.GRU(input_size=args.nhid_lang + args.nembed_word,
                              hidden_size=args.nhid_attn,
                              bias=True,
                              bidirectional=True)

        # mask for disabling special tokens when generating sentences
        self.special_token_mask = torch.FloatTensor(len(self.word_dict))

        # attention to combine selection hidden states
        self.attn = nn.Sequential(
            torch.nn.Linear(2 * args.nhid_attn, args.nhid_attn), nn.Tanh(),
            torch.nn.Linear(args.nhid_attn, 1))

        # selection encoder, takes attention output and context hidden and combines them
        self.sel_encoder = nn.Sequential(
            torch.nn.Linear(2 * args.nhid_attn + args.nhid_ctx, args.nhid_sel),
            nn.Tanh())
        # selection decoders, one per each item
        self.sel_decoders = nn.ModuleList()
        for i in range(output_length):
            self.sel_decoders.append(
                nn.Linear(args.nhid_sel, len(self.item_dict)))

        self.init_weights()

        # fill in the mask
        for i in range(len(self.word_dict)):
            w = self.word_dict.get_word(i)
            special = domain.item_pattern.match(w) or w in ('<unk>', 'YOU:',
                                                            'THEM:', '<pad>')
            self.special_token_mask[i] = -999 if special else 0.0

        self.special_token_mask = self.to_device(self.special_token_mask)
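The reader/writer weight tying above works because the layer-0 parameters of nn.GRU and the parameters of nn.GRUCell have identical shapes; a small check with made-up sizes:

import torch.nn as nn

nhid_ctx, nembed_word, nhid_lang = 64, 256, 128   # illustrative sizes
reader = nn.GRU(input_size=nhid_ctx + nembed_word, hidden_size=nhid_lang, bias=True)
writer = nn.GRUCell(input_size=nhid_ctx + nembed_word, hidden_size=nhid_lang, bias=True)

assert reader.weight_ih_l0.shape == writer.weight_ih.shape   # (3*nhid_lang, input_size)
assert reader.weight_hh_l0.shape == writer.weight_hh.shape   # (3*nhid_lang, nhid_lang)
assert reader.bias_ih_l0.shape == writer.bias_ih.shape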
Code Example #10
    def __init__(self, input_shape, args):
        super(LatentRNNAgent, self).__init__()
        self.args = args
        self.input_shape = input_shape
        self.n_agents = args.n_agents
        self.n_actions = args.n_actions
        self.latent_dim = args.latent_dim
        self.hidden_dim = args.rnn_hidden_dim
        self.bs = 0

        # pi_param = th.rand(args.n_agents)
        # pi_param = pi_param / pi_param.sum()
        # self.pi_param = nn.Parameter(pi_param)

        # mu_param = th.randn(args.n_agents, args.latent_dim)
        # mu_param = mu_param / mu_param.norm(dim=0)
        # self.mu_param = nn.Parameter(mu_param)

        #self.embed_fc_input_size = args.own_feature_size
        self.embed_fc_input_size = input_shape
        #self.embed_fc_input_size=args.n_agents

        self.embed_fc1 = nn.Linear(self.embed_fc_input_size,
                                   args.latent_dim * 4)
        self.embed_fc2 = nn.Linear(args.latent_dim * 4, args.latent_dim * 2)
        self.inference_fc1 = nn.Linear(args.rnn_hidden_dim + input_shape,
                                       args.latent_dim * 4)
        self.inference_fc2 = nn.Linear(args.latent_dim * 4,
                                       args.latent_dim * 2)

        #snail_T = args.snail_max_t
        #snail_input_dim=input_shape
        #if args.obs_agent_id:
        #    snail_input_dim-=args.n_agents
        #snail_key_size = args.snail_key_size
        #snail_value_size = args.snail_value_size
        #snail_filters = args.snail_filters
        #layer_count = math.ceil(math.log(snail_T) / math.log(2))
        #self.infer_mod0=snail.AttentionBlock(snail_input_dim, snail_key_size, snail_value_size) # input_dims, key_size, value_size
        #self.infer_mod1=snail.TCBlock(snail_input_dim+snail_value_size, snail_T, snail_filters) # in_channels, seq_len, filters
        #self.infer_mod2=snail.AttentionBlock(snail_input_dim+snail_value_size+snail_filters*layer_count, snail_key_size, snail_value_size)
        # snail_input_dim+2*snail_value_size+snail_filters*layer_count
        #self.infer_mod3=nn.Conv1d(snail_input_dim+2*snail_value_size+snail_filters*layer_count,self.latent_dim*2,1) # in_channels, out_channels, kernel_size

        self.latent = th.rand(args.n_agents, args.latent_dim * 2)  # (n,mu+var)
        #self.latent0 =  th.rand(args.n_agents, args.latent_dim * 2)

        self.latent_fc1 = nn.Linear(args.latent_dim, args.latent_dim * 4)
        self.latent_fc2 = nn.Linear(args.latent_dim * 4, args.latent_dim * 4)

        self.fc1 = nn.Linear(input_shape, args.rnn_hidden_dim)
        self.rnn = nn.GRUCell(args.rnn_hidden_dim, args.rnn_hidden_dim)
        # self.fc2 = nn.Linear(args.rnn_hidden_dim, args.n_actions)

        # self.fc1_w_nn=nn.Linear(args.latent_dim,input_shape*args.rnn_hidden_dim)
        # self.fc1_b_nn=nn.Linear(args.latent_dim,args.rnn_hidden_dim)

        # self.rnn_ih_w_nn=nn.Linear(args.latent_dim,args.rnn_hidden_dim*args.rnn_hidden_dim)
        # self.rnn_ih_b_nn=nn.Linear(args.latent_dim,args.rnn_hidden_dim)
        # self.rnn_hh_w_nn=nn.Linear(args.latent_dim,args.rnn_hidden_dim*args.rnn_hidden_dim)
        # self.rnn_hh_b_nn=nn.Linear(args.latent_dim,args.rnn_hidden_dim)

        self.fc2_w_nn = nn.Linear(args.latent_dim * 4,
                                  args.rnn_hidden_dim * args.n_actions)
        self.fc2_b_nn = nn.Linear(args.latent_dim * 4, args.n_actions)
Code Example #11
    def __init__(
        self,
        n_vocab,
        n_layers,
        n_units,
        n_embed=None,
        typ="lstm",
        dropout_rate=0.5,
        emb_dropout_rate=0.0,
        tie_weights=False,
    ):
        """Initialize class.

        :param int n_vocab: The size of the vocabulary
        :param int n_layers: The number of layers to create
        :param int n_units: The number of units per layer
        :param str typ: The RNN type
        """
        super(RNNLM, self).__init__()
        if n_embed is None:
            n_embed = n_units

        self.embed = nn.Embedding(n_vocab, n_embed)

        if emb_dropout_rate == 0.0:
            self.embed_drop = None
        else:
            self.embed_drop = nn.Dropout(emb_dropout_rate)

        if typ == "lstm":
            self.rnn = nn.ModuleList(
                [nn.LSTMCell(n_embed, n_units)]
                + [nn.LSTMCell(n_units, n_units) for _ in range(n_layers - 1)]
            )
        else:
            self.rnn = nn.ModuleList(
                [nn.GRUCell(n_embed, n_units)]
                + [nn.GRUCell(n_units, n_units) for _ in range(n_layers - 1)]
            )

        self.dropout = nn.ModuleList(
            [nn.Dropout(dropout_rate) for _ in range(n_layers + 1)]
        )
        self.lo = nn.Linear(n_units, n_vocab)
        self.n_layers = n_layers
        self.n_units = n_units
        self.typ = typ

        logging.info("Tie weights set to {}".format(tie_weights))
        logging.info("Dropout set to {}".format(dropout_rate))
        logging.info("Emb Dropout set to {}".format(emb_dropout_rate))

        if tie_weights:
            assert (
                n_embed == n_units
            ), "Tie Weights: True need embedding and final dimensions to match"
            self.lo.weight = self.embed.weight

        # initialize parameters from uniform distribution
        for param in self.parameters():
            param.data.uniform_(-0.1, 0.1)
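For the GRU branch, one time step through the stacked cells would look roughly like this (a sketch with assumed sizes, not the ESPnet forward implementation):

import torch
import torch.nn as nn

n_vocab, n_layers, n_units, batch = 50, 2, 8, 4
embed = nn.Embedding(n_vocab, n_units)
rnn = nn.ModuleList([nn.GRUCell(n_units, n_units)] +
                    [nn.GRUCell(n_units, n_units) for _ in range(n_layers - 1)])
dropout = nn.ModuleList([nn.Dropout(0.5) for _ in range(n_layers + 1)])
lo = nn.Linear(n_units, n_vocab)

x = torch.randint(0, n_vocab, (batch,))                      # one token per batch element
state = [torch.zeros(batch, n_units) for _ in range(n_layers)]
inp = dropout[0](embed(x))
for i in range(n_layers):
    state[i] = rnn[i](inp, state[i])   # per-layer GRUCell update
    inp = dropout[i + 1](state[i])
logits = lo(inp)                       # (batch, n_vocab)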
Code Example #12
File: image_robustfill.py  Project: zlapp/ec
    def __init__(self,
                 target_vocabulary,
                 hidden_size=512,
                 embedding_size=128,
                 cell_type="LSTM",
                 input_size=(3, 256, 256)):
        """
        :param: input_vocabularies: List containing a vocabulary list for each input. E.g. if learning a function f:A->B from (a,b) pairs, input_vocabularies has length 2
        :param: target_vocabulary: Vocabulary list for output
        """
        super(Image_RobustFill, self).__init__()
        self.n_encoders = 1

        self.hidden_size = hidden_size
        self.embedding_size = embedding_size
        self.input_vocabularies = [None]  #input_vocabularies
        self.target_vocabulary = target_vocabulary
        self._refreshVocabularyIndex()
        self.v_inputs = None  #[len(x) for x in input_vocabularies] # Number of tokens in input vocabularies
        self.v_target = len(
            target_vocabulary)  # Number of tokens in target vocabulary

        self.no_inputs = len(self.input_vocabularies) == 0

        self.cell_type = cell_type
        if cell_type == 'GRU':
            self.encoder_init_h = Parameter(torch.rand(1, self.hidden_size))
            # self.encoder_cells = nn.ModuleList(
            #     [nn.GRUCell(input_size=self.v_inputs[0]+1, hidden_size=self.hidden_size, bias=True)] +
            #     [nn.GRUCell(input_size=self.v_inputs[i]+1+self.hidden_size, hidden_size=self.hidden_size, bias=True) for i in range(1, self.n_encoders)]
            # )
            self.decoder_cell = nn.GRUCell(input_size=self.v_target + 1,
                                           hidden_size=self.hidden_size,
                                           bias=True)
        if cell_type == 'LSTM':
            self.encoder_init_h = Parameter(
                torch.rand(1, self.hidden_size
                           ))  #Also used for decoder if self.no_inputs=True
            # self.encoder_init_cs = nn.ParameterList(
            #     [Parameter(torch.rand(1, self.hidden_size)) for i in range(len(self.v_inputs))]
            # )

            # self.encoder_cells = nn.ModuleList()
            # for i in range(self.n_encoders):
            #     input_size = self.v_inputs[i] + 1 + (self.hidden_size if i>0 else 0)
            #     self.encoder_cells.append(nn.LSTMCell(input_size=input_size, hidden_size=self.hidden_size, bias=True))
            self.decoder_cell = nn.LSTMCell(input_size=self.v_target + 1,
                                            hidden_size=self.hidden_size,
                                            bias=True)
            self.decoder_init_c = Parameter(torch.rand(1, self.hidden_size))

        self.W = nn.Linear(
            self.hidden_size if self.no_inputs else 2 * self.hidden_size,
            self.embedding_size)
        self.V = nn.Linear(self.embedding_size, self.v_target + 1)

        #self.As = nn.ModuleList([nn.Bilinear(self.hidden_size, self.hidden_size, 1, bias=False) for i in range(self.n_encoders)])

        #image encoder:
        self.conv1 = nn.Conv2d(3,
                               8,
                               kernel_size=(3, 3),
                               padding=(3, 3),
                               stride=(1, 1))
        self.conv2 = nn.Conv2d(8,
                               16,
                               kernel_size=(3, 3),
                               padding=(1, 1),
                               stride=(1, 1))
        self.conv3 = nn.Conv2d(16,
                               16,
                               kernel_size=(3, 3),
                               padding=(1, 1),
                               stride=(1, 1))
        self.conv4 = nn.Conv2d(16,
                               16,
                               kernel_size=(3, 3),
                               padding=(1, 1),
                               stride=(1, 1))
        #self.conv4 = nn.Conv2d(256, 512, kernel_size=(3, 3),
        #                        padding=(1, 1), stride=(1, 1))
        self.batch_norm1 = nn.BatchNorm2d(8)
        self.batch_norm2 = nn.BatchNorm2d(16)

        self.img_feat_to_embedding = nn.Sequential(
            nn.Linear(16 * 16 * 16, 64), nn.ReLU(), nn.Linear(64, 64),
            nn.ReLU(), nn.Linear(64, self.hidden_size))

        #attention params:
        self.h_to_32_linear = nn.Linear(self.hidden_size, 32)
        self.img_to_32 = nn.Linear(16 * 16 * 16, 32)

        self.fc_loc = nn.Linear(32 + 32, 3 * 2)
        self.fc_loc.weight.data.zero_()
        self.fc_loc.bias.data.copy_(
            torch.tensor([1, 0, 0, 0, 1, 0], dtype=torch.float))
        self.img_feat_to_context = nn.Sequential(
            nn.Linear(16 * 16 * 16, 128), nn.ReLU(), nn.Linear(128, 128),
            nn.ReLU(), nn.Linear(128, self.hidden_size))
Code Example #13
 def __init__(self, extra_layers_encode):
     super(pruning_model, self).__init__()
     self.extra_layers_encode = extra_layers_encode
     self.rnncell = nn.GRUCell(256, 4, bias=True)
     self._initialize_weights()
     self.relu = nn.ReLU()
Code Example #14
File: models.py  Project: qkqkfldis1/SeqGan_study
    def __init__(self, n_atom_feature, n_edge_feature):
        super(IntraNet, self).__init__()

        self.C = nn.GRUCell(n_atom_feature, n_atom_feature)
        self.cal_message = nn.Sequential(
            nn.Linear(n_atom_feature * 2 + n_edge_feature, n_atom_feature), )
Code Example #15
 def test_gru_cell(self):
     model = nn.GRUCell(RNN_INPUT_SIZE, RNN_HIDDEN_SIZE)
     input = torch.randn(BATCH_SIZE, RNN_INPUT_SIZE)
     h0 = torch.randn(BATCH_SIZE, RNN_HIDDEN_SIZE)
     self.run_model_test(model, train=False, batch_size=BATCH_SIZE, input=(input, h0), use_gpu=False)
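Stripped of the test harness, the shapes being exercised are simply these (hypothetical sizes):

import torch
import torch.nn as nn

BATCH_SIZE, RNN_INPUT_SIZE, RNN_HIDDEN_SIZE = 3, 5, 7
model = nn.GRUCell(RNN_INPUT_SIZE, RNN_HIDDEN_SIZE)
x = torch.randn(BATCH_SIZE, RNN_INPUT_SIZE)
h0 = torch.randn(BATCH_SIZE, RNN_HIDDEN_SIZE)
h1 = model(x, h0)
assert h1.shape == (BATCH_SIZE, RNN_HIDDEN_SIZE)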
Code Example #16
 def __init__(self, insize, hidden_size, out_size, drop):
     super(BahdanauDecoder, self).__init__()
     self.cell = nn.GRUCell(insize + hidden_size, hidden_size)  # insize is assumed to be the token embedding size
     self.concat = nn.Linear(hidden_size * 2, out_size)
     self.drop = nn.Dropout(drop)
Code Example #17
    def __init__(self,
                 node_features,
                 edge_features,
                 message_passes,
                 out_features,
                 edge_embedding_size,
                 edge_emb_depth=3,
                 edge_emb_hidden_dim=150,
                 edge_emb_dropout_p=0.0,
                 att_depth=3,
                 att_hidden_dim=80,
                 att_dropout_p=0.0,
                 msg_depth=3,
                 msg_hidden_dim=80,
                 msg_dropout_p=0.0,
                 gather_width=100,
                 gather_att_depth=3,
                 gather_att_hidden_dim=100,
                 gather_att_dropout_p=0.0,
                 gather_emb_depth=3,
                 gather_emb_hidden_dim=100,
                 gather_emb_dropout_p=0.0,
                 out_depth=2,
                 out_hidden_dim=100,
                 out_dropout_p=0,
                 out_layer_shrinkage=1.0):
        super(EMNImplementation,
              self).__init__(edge_features, edge_embedding_size,
                             message_passes, out_features)
        self.embedding_nn = FeedForwardNetwork(
            node_features * 2 + edge_features,
            [edge_emb_hidden_dim] * edge_emb_depth,
            edge_embedding_size,
            dropout_p=edge_emb_dropout_p)

        self.emb_msg_nn = FeedForwardNetwork(edge_embedding_size,
                                             [msg_hidden_dim] * msg_depth,
                                             edge_embedding_size,
                                             dropout_p=msg_dropout_p)
        self.att_msg_nn = FeedForwardNetwork(edge_embedding_size,
                                             [att_hidden_dim] * att_depth,
                                             edge_embedding_size,
                                             dropout_p=att_dropout_p)

        #self.extra_gru_layer = nn.Linear(edge_embedding_size, edge_embedding_size, bias=False)
        self.gru = nn.GRUCell(edge_embedding_size,
                              edge_embedding_size,
                              bias=False)
        self.gather = GraphGather(edge_embedding_size, gather_width,
                                  gather_att_depth, gather_att_hidden_dim,
                                  gather_att_dropout_p, gather_emb_depth,
                                  gather_emb_hidden_dim, gather_emb_dropout_p)
        out_layer_sizes = [  # example: depth 5, dim 50, shrinkage 0.5 => out_layer_sizes [50, 42, 35, 30, 25]
            round(out_hidden_dim *
                  (out_layer_shrinkage**(i / (out_depth - 1 + 1e-9))))
            for i in range(out_depth)
        ]
        self.out_nn = FeedForwardNetwork(gather_width,
                                         out_layer_sizes,
                                         out_features,
                                         dropout_p=out_dropout_p)
Code Example #18
def experiment(exp_specs):
    ptu.set_gpu_mode(exp_specs['use_gpu'])
    # Set up logging ----------------------------------------------------------
    exp_id = exp_specs['exp_id']
    exp_prefix = exp_specs['exp_name']
    seed = exp_specs['seed']
    set_seed(seed)
    setup_logger(exp_prefix=exp_prefix, exp_id=exp_id, variant=exp_specs)

    # Prep the data -----------------------------------------------------------
    replay_dict = joblib.load(exp_specs['replay_dict_path'])
    next_obs_array = replay_dict['next_observations']
    acts_array = replay_dict['actions']
    data_loader = BasicDataLoader(
        next_obs_array[:40000], acts_array[:40000], exp_specs['episode_length'], exp_specs['batch_size'], use_gpu=ptu.gpu_enabled())
    val_data_loader = BasicDataLoader(
        next_obs_array[40000:], acts_array[40000:], exp_specs['episode_length'], exp_specs['batch_size'], use_gpu=ptu.gpu_enabled())

    # Model Definition --------------------------------------------------------
    conv_channels = 64
    conv_encoder = nn.Sequential(
        nn.Conv2d(3, conv_channels, 1, stride=1, padding=0, bias=False),
        nn.BatchNorm2d(conv_channels),
        nn.ReLU(),
        nn.Conv2d(conv_channels, conv_channels, 1, stride=1, padding=0, bias=False),
        nn.BatchNorm2d(conv_channels),
        nn.ReLU()
    )
    ae_dim = 128
    gru_dim = 512
    img_h = 5
    flat_inter_img_dim = img_h * img_h * conv_channels
    fc_encoder = nn.Sequential(
        nn.Linear(flat_inter_img_dim, ae_dim, bias=False),
        nn.BatchNorm1d(ae_dim),
        nn.ReLU(),
        nn.Linear(ae_dim, ae_dim, bias=False),
        nn.BatchNorm1d(ae_dim),
        nn.ReLU(),
        nn.Linear(ae_dim, ae_dim, bias=False),
        nn.BatchNorm1d(ae_dim),
        nn.ReLU(),
        nn.Linear(ae_dim, ae_dim, bias=False),
        nn.BatchNorm1d(ae_dim),
        nn.ReLU()
    )
    gru = nn.GRUCell(
        ae_dim, gru_dim, bias=True
    )
    fc_decoder = nn.Sequential(
        nn.Linear(gru_dim + 4, ae_dim, bias=False),
        nn.BatchNorm1d(ae_dim),
        nn.ReLU(),
        nn.Linear(ae_dim, ae_dim, bias=False),
        nn.BatchNorm1d(ae_dim),
        nn.ReLU(),
        nn.Linear(ae_dim, ae_dim, bias=False),
        nn.BatchNorm1d(ae_dim),
        nn.ReLU(),
        nn.Linear(ae_dim, ae_dim, bias=False),
        nn.BatchNorm1d(ae_dim),
        nn.ReLU(),
        nn.Linear(ae_dim, flat_inter_img_dim, bias=False),
        nn.BatchNorm1d(flat_inter_img_dim),
        nn.ReLU(),
    )
    conv_decoder = nn.Sequential(
        nn.ConvTranspose2d(conv_channels, conv_channels, 1, stride=1, padding=0, output_padding=0, bias=False),
        nn.BatchNorm2d(conv_channels),
        nn.ReLU(),
        nn.ConvTranspose2d(conv_channels, conv_channels, 1, stride=1, padding=0, output_padding=0, bias=False),
        nn.BatchNorm2d(conv_channels),
        nn.ReLU(),
        nn.Conv2d(conv_channels, 3, 1, stride=1, padding=0, bias=True),
        nn.Sigmoid()
    )
    if ptu.gpu_enabled():
        conv_encoder.cuda()
        fc_encoder.cuda()
        gru.cuda()
        fc_decoder.cuda()
        conv_decoder.cuda()

    # Optimizer ---------------------------------------------------------------
    model_optim = Adam(
        [
            item for sublist in
            map(
                lambda x: list(x.parameters()),
                [fc_encoder, conv_encoder, gru, fc_decoder, conv_decoder]
            )
            for item in sublist
        ],
        lr=float(exp_specs['model_lr']), weight_decay=float(exp_specs['model_wd'])
    )

    # -------------------------------------------------------------------------
    freq_bptt = exp_specs['freq_bptt']
    episode_length = exp_specs['episode_length']
    losses = []
    for iter_num in range(int(float(exp_specs['max_iters']))):
        if iter_num % freq_bptt == 0:
            if iter_num > 0:
                # loss = loss / freq_bptt
                loss.backward()
                model_optim.step()
                prev_h_batch = prev_h_batch.detach()
            loss = 0
        if iter_num % episode_length == 0:
            prev_h_batch = Variable(torch.zeros(exp_specs['batch_size'], gru_dim))
            if ptu.gpu_enabled():
                prev_h_batch = prev_h_batch.cuda()
            
            if iter_num % exp_specs['freq_val'] == 0:
                train_loss_print = '\t'.join(losses)
            losses = []

        obs_batch, act_batch = data_loader.get_next_batch()
        recon = fc_decoder(torch.cat([prev_h_batch, act_batch], 1)).view(obs_batch.size(0), conv_channels, img_h, img_h)
        recon = conv_decoder(recon)

        enc = conv_encoder(obs_batch).view(obs_batch.size(0), -1)
        enc = fc_encoder(enc)
        prev_h_batch = gru(enc, prev_h_batch)

        losses.append('%.4f' % ((obs_batch - recon)**2).mean())
        if iter_num % episode_length != 0:
            loss = loss + ((obs_batch - recon)**2).sum()/float(exp_specs['batch_size'])

        if iter_num % (50*episode_length) in range(2*episode_length):
            save_pytorch_tensor_as_img(recon[0].data.cpu(), 'junk_vis/fixed_colors_simple_maze_5_h/rnn_recon_%d.png' % iter_num)
            save_pytorch_tensor_as_img(obs_batch[0].data.cpu(), 'junk_vis/fixed_colors_simple_maze_5_h/rnn_obs_%d.png' % iter_num)

        if iter_num % exp_specs['freq_val'] == 0:
            print('\nValidating Iter %d...' % iter_num)
            list(map(lambda x: x.eval(), [fc_encoder, conv_encoder, gru, fc_decoder, conv_decoder]))

            val_prev_h_batch = Variable(torch.zeros(exp_specs['batch_size'], gru_dim))
            if ptu.gpu_enabled():
                val_prev_h_batch = val_prev_h_batch.cuda()

            losses = []            
            for i in range(episode_length):
                obs_batch, act_batch = data_loader.get_next_batch()
                
                recon = fc_decoder(torch.cat([val_prev_h_batch, act_batch], 1)).view(obs_batch.size(0), conv_channels, img_h, img_h)
                recon = conv_decoder(recon)

                enc = conv_encoder(obs_batch).view(obs_batch.size(0), -1)
                enc = fc_encoder(enc)
                val_prev_h_batch = gru(enc, val_prev_h_batch)

                losses.append('%.4f' % ((obs_batch - recon)**2).mean())

            loss_print = '\t'.join(losses)
            print('Val MSE:\t' + loss_print)
            print('Train MSE:\t' + train_loss_print)

            list(map(lambda x: x.train(), [fc_encoder, conv_encoder, gru, fc_decoder, conv_decoder]))            
Code Example #19
    def __init__(self,
                 msg_dim,
                 node_state_dim,
                 edge_feat_dim,
                 num_prop=1,
                 num_layer=1,
                 has_attention=True,
                 att_hidden_dim=128,
                 has_residual=False,
                 has_graph_output=False,
                 output_hidden_dim=128,
                 graph_output_dim=None):
        super(GNN, self).__init__()
        self.msg_dim = msg_dim
        self.node_state_dim = node_state_dim
        self.edge_feat_dim = edge_feat_dim
        self.num_prop = num_prop
        self.num_layer = num_layer
        self.has_attention = has_attention
        self.has_residual = has_residual
        self.att_hidden_dim = att_hidden_dim
        self.has_graph_output = has_graph_output
        self.output_hidden_dim = output_hidden_dim
        self.graph_output_dim = graph_output_dim

        self.update_func = nn.ModuleList([
            nn.GRUCell(input_size=self.msg_dim,
                       hidden_size=self.node_state_dim)
            for _ in range(self.num_layer)
        ])

        self.msg_func = nn.ModuleList([
            nn.Sequential(*[
                nn.Linear(self.node_state_dim +
                          self.edge_feat_dim, self.msg_dim),
                nn.ReLU(),
                nn.Linear(self.msg_dim, self.msg_dim)
            ]) for _ in range(self.num_layer)
        ])

        if self.has_attention:
            self.att_head = nn.ModuleList([
                nn.Sequential(*[
                    nn.Linear(self.node_state_dim +
                              self.edge_feat_dim, self.att_hidden_dim),
                    nn.ReLU(),
                    nn.Linear(self.att_hidden_dim, self.msg_dim),
                    nn.Sigmoid()
                ]) for _ in range(self.num_layer)
            ])

        if self.has_graph_output:
            self.graph_output_head_att = nn.Sequential(*[
                nn.Linear(self.node_state_dim, self.output_hidden_dim),
                nn.ReLU(),
                nn.Linear(self.output_hidden_dim, 1),
                nn.Sigmoid()
            ])

            self.graph_output_head = nn.Sequential(
                *[nn.Linear(self.node_state_dim, self.graph_output_dim)])
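A hedged sketch of one propagation step with the per-layer GRUCell above (aggregation is collapsed to a pre-computed per-node message; the real class presumably scatters messages over edges):

import torch
import torch.nn as nn

msg_dim, node_state_dim, num_nodes = 16, 32, 6
update = nn.GRUCell(input_size=msg_dim, hidden_size=node_state_dim)

state = torch.zeros(num_nodes, node_state_dim)
agg_msg = torch.randn(num_nodes, msg_dim)   # aggregated incoming messages per node
state = update(agg_msg, state)              # (num_nodes, node_state_dim)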
Code Example #20
 def __init__(self, dim):
     super().__init__()
     self.gru = nn.GRUCell(dim, dim)
Code Example #21
File: memory_updater.py  Project: zstoebs/tgn
    def __init__(self, memory, message_dimension, memory_dimension, device):
        super(GRUMemoryUpdater, self).__init__(memory, message_dimension,
                                               memory_dimension, device)

        self.memory_updater = nn.GRUCell(input_size=message_dimension,
                                         hidden_size=memory_dimension)
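The update pattern implied here is one GRUCell step per node memory (tensor names below are illustrative, not the tgn API):

import torch
import torch.nn as nn

message_dimension, memory_dimension, n_nodes = 100, 172, 10
memory_updater = nn.GRUCell(input_size=message_dimension, hidden_size=memory_dimension)

messages = torch.randn(n_nodes, message_dimension)
memory = torch.zeros(n_nodes, memory_dimension)
memory = memory_updater(messages, memory)   # updated memory, (n_nodes, memory_dimension)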
Code Example #22
 def __init__(self, input_size, hidden_size, bias=True):
     cell = nn.GRUCell(input_size, hidden_size, bias=bias)
     super().__init__(cell)
Code Example #23
def GRUCell(input_size, hidden_size, **kwargs):
    m = nn.GRUCell(input_size, hidden_size)
    for name, param in m.named_parameters():
        if "weight" in name or "bias" in name:
            param.data.uniform_(-0.1, 0.1)
    return m
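Usage of the factory above (sizes are illustrative): the returned module behaves like any nn.GRUCell, just with its weights and biases re-initialized from U(-0.1, 0.1).

import torch

cell = GRUCell(input_size=32, hidden_size=64)
h = cell(torch.randn(8, 32), torch.zeros(8, 64))   # (8, 64)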
Code Example #24
File: ggnn.py  Project: knowledgeresearch/kaner
 def __init__(self, feat_dim: int, n_edge_types: int):
     super(Propogator, self).__init__()
     self.feat_dim = feat_dim
     self.n_edge_types = n_edge_types
     self.gru = nn.GRUCell(feat_dim * n_edge_types, feat_dim)
Code Example #25
 def __init__(self, **kwargs):
     super(MessageModule, self).__init__()
     h_dims = kwargs['h_dims']
     g_dims = kwargs['g_dims']
     self.to_state = nn.GRUCell(h_dims + g_dims, h_dims)
     INIT.orthogonal_(self.to_state.weight_hh)
Code Example #26
    def __init__(self):
        super().__init__()
        conv_channels = 32
        self.conv_encoder = nn.Sequential(
            nn.Conv2d(3, conv_channels, 4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(conv_channels), nn.ReLU(),
            nn.Conv2d(conv_channels,
                      conv_channels,
                      4,
                      stride=2,
                      padding=1,
                      bias=False), nn.BatchNorm2d(conv_channels), nn.ReLU(),
            nn.Conv2d(conv_channels,
                      conv_channels,
                      1,
                      stride=1,
                      padding=0,
                      bias=False), nn.BatchNorm2d(conv_channels), nn.ReLU(),
            nn.Conv2d(conv_channels,
                      conv_channels,
                      1,
                      stride=1,
                      padding=0,
                      bias=False), nn.BatchNorm2d(conv_channels), nn.ReLU())
        ae_dim = 512
        lstm_dim = 512
        self.lstm_dim = lstm_dim
        img_h = 5
        flat_inter_img_dim = img_h * img_h * conv_channels
        act_dim = 64

        self.conv_channels = conv_channels
        self.img_h = img_h

        self.lstm_act_proc_fc = nn.Linear(4, act_dim, bias=True)
        self.recon_act_proc_fc = nn.Linear(4, act_dim, bias=True)
        self.mask_act_proc_fc = nn.Linear(4, act_dim, bias=True)

        self.attention_seq = nn.Sequential(
            nn.Linear(lstm_dim + act_dim, lstm_dim, bias=False),
            nn.BatchNorm1d(lstm_dim),
            nn.ReLU(),
            nn.Linear(lstm_dim, lstm_dim),
            # nn.Sigmoid()
            # nn.Softmax()
        )

        self.fc_encoder = nn.Sequential(
            nn.Linear(flat_inter_img_dim + act_dim, ae_dim, bias=False),
            nn.BatchNorm1d(ae_dim),
            nn.ReLU(),
            # nn.Linear(ae_dim, ae_dim, bias=False),
            # nn.BatchNorm1d(ae_dim),
            # nn.ReLU(),
            # nn.Linear(ae_dim, ae_dim, bias=False),
            # nn.BatchNorm1d(ae_dim),
            # nn.ReLU(),
            # nn.Linear(ae_dim, ae_dim, bias=False),
            # nn.BatchNorm1d(ae_dim),
            # nn.ReLU()
        )
        self.lstm = nn.GRUCell(ae_dim, lstm_dim, bias=True)
        self.fc_decoder = nn.Sequential(
            nn.Linear(lstm_dim + act_dim, flat_inter_img_dim, bias=False),
            nn.BatchNorm1d(flat_inter_img_dim),
            nn.ReLU(),
            # nn.Linear(ae_dim, ae_dim, bias=False),
            # nn.BatchNorm1d(ae_dim),
            # nn.ReLU(),
            # # nn.Linear(ae_dim, ae_dim, bias=False),
            # # nn.BatchNorm1d(ae_dim),
            # # nn.ReLU(),
            # # nn.Linear(ae_dim, ae_dim, bias=False),
            # # nn.BatchNorm1d(ae_dim),
            # # nn.ReLU(),
            # nn.Linear(ae_dim, flat_inter_img_dim, bias=False),
            # nn.BatchNorm1d(flat_inter_img_dim),
            # nn.ReLU(),
        )
        self.conv_decoder = nn.Sequential(
            nn.ConvTranspose2d(conv_channels,
                               conv_channels,
                               4,
                               stride=2,
                               padding=1,
                               output_padding=0,
                               bias=False),
            nn.BatchNorm2d(conv_channels),
            nn.ReLU(),
            # nn.Conv2d(conv_channels, conv_channels, 1, stride=1, padding=0, bias=False),
            # nn.BatchNorm2d(conv_channels),
            # nn.ReLU(),
            nn.ConvTranspose2d(conv_channels,
                               conv_channels,
                               4,
                               stride=2,
                               padding=1,
                               output_padding=0,
                               bias=False),
            nn.BatchNorm2d(conv_channels),
            nn.ReLU(),
            # nn.Conv2d(conv_channels, conv_channels, 1, stride=1, padding=0, bias=False),
            # nn.BatchNorm2d(conv_channels),
            # nn.ReLU(),
        )
        self.mean_decoder = nn.Sequential(
            nn.Conv2d(conv_channels, 3, 1, stride=1, padding=0, bias=True),
            nn.Sigmoid())
        self.log_cov_decoder = nn.Sequential(
            nn.Conv2d(conv_channels, 3, 1, stride=1, padding=0, bias=True), )
Code Example #27
    def __init__(self, opt, padding_idx=0, embedding=None):
        super().__init__(opt, padding_idx, embedding)

        # RNN document encoder
        self.doc_rnn = layers.StackedBRNN(
            input_size=self.doc_input_size,
            hidden_size=opt['hidden_size'],
            num_layers=2,
            dropout_rate=opt['dropout_rnn'],
            # dropout_output=opt['dropout_rnn_output'],
            variational_dropout=opt['variational_dropout'],
            concat_layers=True,
            rnn_type=opt['rnn_type'],
            padding=opt['rnn_padding'],
            residual=opt['residual'],
            squeeze_excitation=opt['squeeze_excitation'],
        )

        # RNN question encoder
        self.question_rnn = layers.StackedBRNN(
            input_size=self.question_input_size,
            hidden_size=opt['hidden_size'],
            num_layers=2,
            dropout_rate=opt['dropout_rnn'],
            # dropout_output=opt['dropout_rnn_output'],
            variational_dropout=opt['variational_dropout'],
            concat_layers=True,
            rnn_type=opt['rnn_type'],
            padding=opt['rnn_padding'],
            residual=opt['residual'],
            squeeze_excitation=opt['squeeze_excitation'],
        )

        # Output sizes of rnn encoders
        doc_hidden_size = 2 * 2 * opt['hidden_size']
        question_hidden_size = doc_hidden_size

        self.question_urnn = layers.StackedBRNN(
            input_size=question_hidden_size,
            hidden_size=opt['hidden_size'],
            num_layers=opt['fusion_understanding_layers'],
            dropout_rate=opt['dropout_rnn'],
            variational_dropout=opt['variational_dropout'],
            rnn_type=opt['rnn_type'],
            padding=opt['rnn_padding'],
            residual=opt['residual'],
            squeeze_excitation=opt['squeeze_excitation'],
            concat_layers=False,
        )

        self.multi_level_fusion = layers.FullAttention(
            full_size=self.paired_input_size + doc_hidden_size,
            hidden_size=2 * 3 * opt['hidden_size'],
            num_level=3,
            dropout=opt['dropout_rnn'],
            variational_dropout=opt['variational_dropout'],
        )

        self.doc_urnn = layers.StackedBRNN(
            input_size=2 * 5 * opt['hidden_size'],
            hidden_size=opt['hidden_size'],
            num_layers=opt['fusion_understanding_layers'],
            dropout_rate=opt['dropout_rnn'],
            variational_dropout=opt['variational_dropout'],
            rnn_type=opt['rnn_type'],
            padding=opt['rnn_padding'],
            residual=opt['residual'],
            squeeze_excitation=opt['squeeze_excitation'],
            concat_layers=False,
        )

        self.self_boost_fusions = nn.ModuleList()
        self.doc_final_rnns = nn.ModuleList()
        full_size = self.paired_input_size + 4 * 3 * opt['hidden_size']
        for i in range(self.opt['fusion_self_boost_times']):
            self.self_boost_fusions.append(
                layers.FullAttention(
                    full_size=full_size,
                    hidden_size=2 * opt['hidden_size'],
                    num_level=1,
                    dropout=opt['dropout_rnn'],
                    variational_dropout=opt['variational_dropout'],
                ))

            self.doc_final_rnns.append(
                layers.StackedBRNN(
                    input_size=4 * opt['hidden_size'],
                    hidden_size=opt['hidden_size'],
                    num_layers=opt['fusion_final_layers'],
                    dropout_rate=opt['dropout_rnn'],
                    variational_dropout=opt['variational_dropout'],
                    rnn_type=opt['rnn_type'],
                    padding=opt['rnn_padding'],
                    residual=opt['residual'],
                    squeeze_excitation=opt['squeeze_excitation'],
                    concat_layers=False,
                ))
            full_size += 2 * opt['hidden_size']

        # Question merging
        if opt['question_merge'] not in ['avg', 'self_attn']:
            raise NotImplementedError('question_merge = %s' %
                                      opt['question_merge'])
        if opt['question_merge'] == 'self_attn':
            self.quesiton_merge_attns = nn.ModuleList()

        # Question merging
        if opt['question_merge'] not in ['avg', 'self_attn']:
            raise NotImplementedError('question_merge = %s' %
                                      opt['question_merge'])
        if opt['question_merge'] == 'self_attn':
            self.self_attn = layers.LinearSeqAttn(2 * opt['hidden_size'])

        # Bilinear attention for span start/end
        self.start_attn = layers.BilinearSeqAttn(
            2 * opt['hidden_size'],
            2 * opt['hidden_size'],
        )

        if opt['end_gru']:
            self.end_gru = nn.GRUCell(2 * opt['hidden_size'],
                                      2 * opt['hidden_size'])

        self.end_attn = layers.BilinearSeqAttn(
            2 * opt['hidden_size'],
            2 * opt['hidden_size'],
        )
Code Example #28
    def __init__(self,
                 ctx_size_dict,
                 z_size,
                 z_len=10,
                 z_transform=None,
                 z_in_size=256,
                 z_merge='sum',
                 z_init='mean_ctx',
                 att_type='mlp',
                 att_activ='tanh',
                 att_bottleneck='ctx',
                 att_temp=1.0,
                 att_transform_ctx=False,
                 mlp_bias=False,
                 hiero_mid_dim=128):
        super().__init__()

        self.ctx_size_dict = ctx_size_dict
        self.z_size = z_size
        self.z_len = z_len
        self.z_transform = z_transform.lower() if z_transform else None
        self.z_in_size = z_in_size
        self.z_merge = z_merge

        # Other arguments
        self.att_type = att_type
        self.att_bottleneck = att_bottleneck
        self.att_activ = att_activ
        self.att_temp = att_temp
        self.att_transform_ctx = att_transform_ctx
        self.mlp_bias = mlp_bias
        self.z_init = z_init
        self.hiero_mid_dim = hiero_mid_dim

        # Safety check
        self._sanity_check()

        # Create FF layers to manage different context size...
        # Each layer maps ctx_size_dict[k] to z_in_size (==ctx_size)
        # z_transform tells the kind of (non-)linearity to use
        if self.z_transform:
            self.z_transforms = nn.ModuleDict()
            for k in self.ctx_size_dict:
                self.z_transforms[k] = FF(self.ctx_size_dict[k],
                                          self.z_in_size,
                                          activ=z_transform)
                self.ctx_size_dict[k] = self.z_in_size
            self.ctx_size = self.z_in_size
        else:
            s = set([size for size in self.ctx_size_dict.values()])
            assert len(set(s)) == 1, \
                "Incompatible encoding sizes, consider using z_transform:tanh in config."
            self.ctx_size = next(iter(s))

        # Create an attention layer for each modality
        # TODO: sharing weights between att. mechanisms is possible
        self.att = nn.ModuleDict()
        # Fetch correct attention class
        Attention = get_attention(self.att_type)
        for k in self.ctx_size_dict:
            att_in_size = self.ctx_size if self.z_transform else self.ctx_size_dict[
                k]
            self.att[k] = Attention(att_in_size,
                                    self.z_size,
                                    transform_ctx=self.att_transform_ctx,
                                    mlp_bias=self.mlp_bias,
                                    att_activ=self.att_activ,
                                    att_bottleneck=self.att_bottleneck,
                                    temp=self.att_temp,
                                    ctx2hid=False)

        # Fusion operation
        if self.z_merge == 'hierarchical':
            self.hiero_att = HierarchicalAttention(
                [self.ctx_size_dict[k] for k in self.ctx_size_dict.keys()],
                self.z_size, self.hiero_mid_dim)
            self.merge_op = self._merge_hierarchical
        else:
            self.merge_op = self._merge_sum

        # Create decoder layer necessary for attention
        self.dec = nn.GRUCell(self.ctx_size, self.z_size)

        # Several strategies to initialize the decoder can be considered
        # Set decoder initializer
        self._init_func = getattr(self, '_rnn_init_{}'.format(self.z_init))

        # if init is not zero, then create FF layer
        if self.z_init != 'zero':
            self.ff_z_init = FF(self.ctx_size, self.z_size, activ='tanh')
Code Example #29
File: dynamics.py  Project: km01/myrl
 def __init__(self, input_size, hidden_size):
     super().__init__()
     self.input_size, self.hidden_size = input_size, hidden_size
     self.dynamics = nn.GRUCell(input_size=input_size, hidden_size=hidden_size)
Code Example #30
def initGRUCell(input_size,
                hidden_size,
                **kwargs):

    model = nn.GRUCell(input_size,
                       hidden_size,
                       **kwargs)

    for name, param in model.named_parameters():
        if 'weight' in name or 'bias' in name:
            param.data.uniform_(-0.1, 0.1)

    return model


# Attention: attention module
class Attention(nn.Module):
    def __init__(self, 
    			 hidden_size, 
    			 attn_size):

        super(Attention, self).__init__()

        self.hidden_size = hidden_size
        self.attn_size = attn_size

        self.linear_layer1 = nn.Linear(self.hidden_size, self.attn_size)

        self.linear_layer2 = nn.Linear(self.hidden_size + self.attn_size, self.attn_size)
        
    def forward(self, 
    			hidden, 
    			encoder_outs, 
    			source_lengths):

    	# hidden_size -> attn_size
        attn_hidden = self.linear_layer1(hidden)

        # get scores
        attn_score = torch.sum((encoder_outs.transpose(0,1) * attn_hidden.unsqueeze(0)),2)

        attn_mask = torch.transpose(seq_mask(source_lengths, 
        							max_len = max(source_lengths).item()),
        							0,1)

        masked_attn = attn_mask*attn_score
        masked_attn[masked_attn==0] = -1e10

        # softmax over attention to get weights
        attn_scores = F.softmax(masked_attn, dim=0)
        # compute weighted sum according to attention scores
        attn_hidden = torch.sum(attn_scores.unsqueeze(2)*encoder_outs.transpose(0,1), 0)

        attn_hidden = self.linear_layer2(torch.cat((attn_hidden, hidden), dim=1))
        attn_hidden = torch.tanh(attn_hidden)

        return attn_hidden, attn_scores

# AttnDecoderRNN
class AttnDecoderRNN(nn.Module):

    def __init__(self, 
    			 vocab_size, 
    			 embed_size, 
    			 hidden_size, 
    			 num_rnn_layers = 1, 
    			 attention = True,
    			 dropout_percent=0.1):

        super(AttnDecoderRNN, self).__init__()
        self.hidden_size = hidden_size
        self.embed_size = embed_size
        encoder_output_size = self.hidden_size

        self.embedding = nn.Embedding(vocab_size, 
        							  embed_size, 
        							  PAD_IDX)

        self.dropout_f = nn.Dropout(p=dropout_percent)

        self.num_layers = num_rnn_layers

        if attention:
        	self.attention = Attention(self.hidden_size, 
        						   	   encoder_output_size)
        else:
        	self.attention = None

        self.layers = nn.ModuleList([initLSTMCell(input_size=self.hidden_size+self.embed_size if ((layer == 0) and attention) \
        									  else self.embed_size if layer == 0 else self.hidden_size,
                							  hidden_size=self.hidden_size,)for layer in range(self.num_layers)])

        self.linear_layer = nn.Linear(self.hidden_size, vocab_size)
        self.log_softmax = nn.LogSoftmax(dim=1)
        
    def forward(self, 
                decoder_inputs,
                context, 
                prev_hiddens,
                prev_context,
                encoder_outputs,
                source_lengths):
        
        batch_size = decoder_inputs.size(0)

        # embed
        embed_target = self.embedding(decoder_inputs)
        out = self.dropout_f(embed_target)
        
        if self.attention is not None:
            input_ = torch.cat([out.squeeze(1), context], dim = 1)
        else:
            input_ = out.squeeze(1)

        context_ = []
        decoder_hiddens_ = []

        for layer, rnn in enumerate(self.layers):
            hidden, con = rnn(input_, (prev_hiddens[layer], 
            						   prev_context[layer]))
            input_ = self.dropout_f(hidden)
            decoder_hiddens_.append(hidden.unsqueeze(0))
            context_.append(con.unsqueeze(0))

        decoder_hiddens_ = torch.cat(decoder_hiddens_, dim = 0)
        context_ = torch.cat(context_, dim = 0)

        if self.attention is not None:
            out, attn_score = self.attention(hidden, 
            								 encoder_outputs, 
            								 source_lengths)
        else:
            out = hidden
            attn_score = None

        context_vec = out
        out = self.dropout_f(out)

        # linear: hidden_size -> vocab_size
        deco_out = self.linear_layer(out)
        deco_out = self.log_softmax(deco_out)

        return deco_out, context_vec, decoder_hiddens_, context_, attn_score


# CNNencoder
class CNNencoder(nn.Module):

    def __init__(self, 
                 vocab_size, 
                 embed_size, 
                 hidden_size, 
                 kernel_size, 
                 num_layers,
                 percent_dropout=0.3):
        
        super(CNNencoder, self).__init__()

        self.num_layers = num_layers
        self.hidden_size = hidden_size
        self.kernel_size = kernel_size
        self.embed_size = embed_size
        self.vocab_size = vocab_size

        self.embedding = nn.Embedding(self.vocab_size, 
                                      self.embed_size, 
                                      padding_idx=0)
        
        self.dropout_f = nn.Dropout(percent_dropout)
        
        in_channels = self.embed_size
        
        self.conv = nn.Conv1d(in_channels, 
        					  self.hidden_size, 
        					  kernel_size, 
                              padding=kernel_size//2)
        
        # todo
        self.conv2 = nn.Conv1d(60, self.hidden_size, kernel_size,
        					   padding=kernel_size//2)
        
        self.ReLU = nn.ReLU()

    def forward(self, source_sentence):
        
        batch_size, seq_len = source_sentence.size()
        
        embeds_source = self.embedding(source_sentence)
        
        out = self.conv(embeds_source.transpose(1, 2)).transpose(1,2)
        out = self.ReLU(out)
        out = F.max_pool1d(out, kernel_size=5, stride=5)
        
        out = self.conv2(out.transpose(1, 2)).transpose(1,2)
        out = self.ReLU(out)
        out = torch.mean(out, dim=1).view(1, batch_size, self.hidden_size)
    
        return out


class LSTMencoder(nn.Module):

    def __init__(self, 
    			 input_size, 
    			 embed_size, 
    			 hidden_size,
    			 num_lstm_layers):

        super(LSTMencoder, self).__init__()
        self.hidden_size = hidden_size
        self.embed_size = embed_size

        self.embedding = nn.Embedding(input_size,
                                      self.embed_size,
                                      padding_idx=0)

        self.dropout_ = nn.Dropout(p = 0.1)
        self.num_layers = num_lstm_layers

        self.lstm = nn.LSTM(self.embed_size, self.hidden_size,
                            batch_first=True, bidirectional=True,
                            num_layers=self.num_layers,
                            dropout=0.15)

    def initHidden(self, batch_size):
        return torch.zeros(self.num_layers*2,
                           batch_size,
                           self.hidden_size).to(device),\
               torch.zeros(self.num_layers*2,
                           batch_size,
                           self.hidden_size).to(device)

    def forward(self, 
    			encoder_inputs, 
    			source_lengths):

        sort_original_source = torch.sort(source_lengths, descending=True)[1]
        unsort_to_original_source = torch.sort(sort_original_source)[1]

        embeds_source = self.embedding(encoder_inputs)
        
        lstm_out = self.dropout_(embeds_source)

        batch_size, seq_len = embeds_source.size()

        hidden, context = self.initHidden(batch_size)
        sorted_output = lstm_out[sort_original_source]
        sorted_len = source_lengths[sort_original_source]

        packed_output = nn.utils.rnn.pack_padded_sequence(sorted_output,
                                                          sorted_len,
                                                          batch_first=True)

        packed_outs, (hidden, context) = self.lstm(packed_output, (hidden, context))
        hidden = hidden[:,unsort_to_original_source,:]
        context = context[:,unsort_to_original_source,:]

        lstm_out, _ = nn.utils.rnn.pad_packed_sequence(packed_outs, 
        												padding_value=PAD_IDX, 
        												batch_first = True)
        # UNSORT OUTPUT
        lstm_out = lstm_out[unsort_to_original_source]
        hidden = hidden.view(self.num_layers, 2, batch_size, -1).transpose(1, 2).contiguous().view(self.num_layers, batch_size, -1)
        context = context.view(self.num_layers, 2, batch_size, -1).transpose(1, 2).contiguous().view(self.num_layers, batch_size, -1)

        return lstm_out, hidden, context