コード例 #1
0
 def __init__(self, configs):
     """Create the encoder and the two decoder sub-networks.

     Args:
         configs: Container read through ``unpack`` for the entries
             ``'encoder_configs'``, ``'reconstruct_configs'`` and
             ``'predict_configs'``; each one is handed unchanged to the
             constructor of the matching sub-network.
     """
     super(MMNIST_ConvLSTM, self).__init__()
     encoder_cfg, reconstruct_cfg, predict_cfg = unpack(
         configs,
         ['encoder_configs', 'reconstruct_configs', 'predict_configs'])
     # Assignment order is kept (encoder, reconstructor, predictor) so the
     # attributes are set in the same sequence as before.
     self.encoder = ConvLSTM(encoder_cfg)
     self.reconstructor = Generator(reconstruct_cfg)
     self.predictor = Generator(predict_cfg)
コード例 #2
0
 def __init__(self, configs):
     """Build a stack of ``ConvLSTMCell`` layers.

     Args:
         configs: Container read through ``unpack`` for two entries:
             ``'num_layers'`` (how many cells to stack) and
             ``'cell_configs'`` (indexable; entry ``idx`` is passed to the
             constructor of cell ``idx``).

     Raises:
         Whatever indexing ``cell_configs`` raises when it holds fewer than
         ``num_layers`` entries (or none at all).
     """
     super(ConvLSTM, self).__init__()
     num_layers, cell_configs = unpack(configs,
                                       ['num_layers', 'cell_configs'])
     # ``range`` rather than the Python-2-only ``xrange`` so construction
     # works on both Python 2 and Python 3.
     cells = [ConvLSTMCell(cell_configs[idx]) for idx in range(num_layers)]
     self.cell_list = nn.ModuleList(cells)
     self.num_layers = num_layers
     # Only the first layer's config is retained on the instance.
     self.cell_config = cell_configs[0]
コード例 #3
0
    def forward(self, data, configs=None):
        """Encode the input sequence, then reconstruct it and predict ahead.

        Args:
            data: Container read through ``unpack`` for ``'x_train'``,
                ``'x_predict'`` and ``'states'``. The ``'states'`` entry is
                looked up but not used; the encoder always starts from
                ``encoder.init_hidden(batch_size)``. ``x_train`` is indexed
                as ``[:, t]`` and sized with ``size(0)`` / ``size(1)``
                (original note: batch * time * channels * height * width).
            configs: Container read through ``unpack`` for ``'use_gt'`` and
                ``'max_steps'``. With ``use_gt`` truthy the decoders are fed
                the ground-truth frame of the previous step and the
                predictor runs for ``x_predict.size(1)`` steps; otherwise
                each decoder is fed its own previous output and the
                predictor runs for ``max_steps`` steps.
                NOTE(review): the default ``None`` is passed straight to
                ``unpack``; whether that is accepted depends on ``unpack``.

        Returns:
            Tuple ``(reconstruction, prediction)``; each is the per-step
            decoder outputs concatenated along dimension 1.
        """
        x_train, x_predict, _ = unpack(data,
                                       ['x_train', 'x_predict', 'states'])
        use_gt, max_steps = unpack(configs, ['use_gt', 'max_steps'])
        batch_size = x_train.size(0)

        # Encoding stage.
        encoder = self.encoder
        en_data = pack([x_train, encoder.init_hidden(batch_size)],
                       ['x', 'states'])
        states = encoder(en_data)

        def rollout(decoder, targets, num_steps):
            # Run `decoder` for `num_steps` steps from the encoder state.
            outputs = []
            step_input = None  # the first step has no input frame
            # Each decoder gets its own shallow copy of the encoder state
            # list, so a decoder that rewrites list entries in place cannot
            # change the starting state seen by the other decoder.
            step_states = list(states)
            for t in range(num_steps):
                step_data = pack([step_input, step_states], ['x', 'states'])
                out, step_states = decoder(step_data)
                outputs.append(out)
                # Teacher forcing with the ground-truth frame, or feed back
                # the decoder's own output; either way re-add the time axis.
                frame = targets[:, t] if use_gt else out
                step_input = frame.unsqueeze(1)
            return torch.cat(outputs, 1)

        # Reconstruction always covers the full training sequence length.
        reconstruction = rollout(self.reconstructor, x_train,
                                 x_train.size(1))

        # Prediction length depends on whether ground truth is available.
        predict_steps = x_predict.size(1) if use_gt else max_steps
        prediction = rollout(self.predictor, x_predict, predict_steps)

        return reconstruction, prediction
コード例 #4
0
    def __init__(self, configs):
        """Set up a convolutional RNN cell.

        Args:
            configs: Container read through ``unpack`` with the keys returned
                by ``ConvRNNCell.get_init_keys()``. Seven values are taken,
                in this order: hidden channels, activation function, input
                channels, input height, input width, kernel size and a debug
                flag.
        """
        super(ConvRNNCell, self).__init__()
        init_keys = ConvRNNCell.get_init_keys()
        (h_c, active_func, in_c, in_h,
         in_w, kernel_size, debug) = unpack(configs, init_keys)
        self.h_c = h_c
        self.active_func = active_func
        self.in_c = in_c
        self.in_h = in_h
        self.in_w = in_w
        self.kernel_size = kernel_size

        if debug:
            print_dict(configs, init_keys, 'ConvRNNCell.__init__')

        # (k - 1) // 2 padding: for an odd kernel size at the default stride
        # the convolution output keeps the input's spatial size.
        pad = (self.kernel_size - 1) // 2
        # The convolution sees input and hidden state stacked on channels.
        stacked_channels = self.in_c + self.h_c
        self.conv2d = nn.Conv2d(in_channels=stacked_channels,
                                out_channels=self.h_c,
                                kernel_size=self.kernel_size,
                                padding=pad)
コード例 #5
0
    def forward(self, data, configs=None):
        """Compute the next hidden state of the convolutional RNN cell.

        Args:
            data: Container read through ``unpack`` for ``'x'`` (the input)
                and ``'states'`` (the previous hidden state). Original notes
                give x as batch * in_c * in_w * in_h and the hidden state as
                batch * h_c * in_w * in_h.
            configs: Unused.

        Returns:
            The activation function applied to the convolution of the input
            and hidden state concatenated along dimension 1.
        """
        x, prev_h = unpack(data, ['x', 'states'])
        stacked = torch.cat([x, prev_h], 1)
        return self.active_func(self.conv2d(stacked))
コード例 #6
0
    def __init__(self, configs):
        """Set up a convolutional LSTM cell with peephole weights.

        Args:
            configs: Container read through ``unpack`` with the keys returned
                by ``ConvLSTMCell.get_init_keys()``. Seven values are taken,
                in this order: hidden channels, activation function, input
                channels, input height, input width, kernel size and a debug
                flag.
        """
        super(ConvLSTMCell, self).__init__()
        init_keys = ConvLSTMCell.get_init_keys()
        (h_c, active_func, in_c, in_h,
         in_w, kernel_size, debug) = unpack(configs, init_keys)
        self.h_c = h_c
        self.active_func = active_func
        self.in_c = in_c
        self.in_h = in_h
        self.in_w = in_w
        self.kernel_size = kernel_size

        if debug:
            print_dict(configs, init_keys, 'ConvLSTMCell.__init__')

        # One convolution yields all four gate pre-activations at once, hence
        # 4 * h_c output channels. (k - 1) // 2 padding keeps the spatial
        # size for an odd kernel size at the default stride.
        gate_channels = 4 * self.h_c
        pad = (self.kernel_size - 1) // 2
        self.conv2d = nn.Conv2d(in_channels=self.in_c + self.h_c,
                                out_channels=gate_channels,
                                kernel_size=self.kernel_size,
                                padding=pad)

        # Batch norm over the stacked gate pre-activations.
        self.bn = nn.BatchNorm2d(gate_channels)

        # Peephole weights for the input, forget and output gates; created as
        # zeros here and then handed to init_weights().
        peephole_shape = (self.h_c, self.in_h, self.in_w)
        self.w_ci = nn.Parameter(torch.zeros(*peephole_shape))
        self.w_cf = nn.Parameter(torch.zeros(*peephole_shape))
        self.w_co = nn.Parameter(torch.zeros(*peephole_shape))
        self.init_weights()
コード例 #7
0
    def forward(self, data, configs=None):
        """Run the stacked ConvLSTM over a sequence and return final states.

        Args:
            data: Container read through ``unpack`` for ``'x'`` and
                ``'states'``.
                ``x`` is indexed as batch, time_steps, channels, height,
                width; when it is ``None`` a single all-zero frame is
                substituted, sized from ``self.cell_config``.
                ``states`` holds one ``(h, c)`` pair per layer and is not
                modified by this call.
            configs: Unused.

        Returns:
            A new list with the ``(h, c)`` pair of every layer after the
            last time step.
        """
        x, states = unpack(data, ['x', 'states'])
        batch_size = states[0][0].size(0)
        if x is None:
            cfg = self.cell_config
            # Height before width, matching the (h_c, in_h, in_w) peephole
            # weights of ConvLSTMCell; the two were swapped before, which
            # only went unnoticed for square inputs.
            x = to_var(torch.zeros(batch_size, 1, cfg['in_c'],
                                   cfg['in_h'], cfg['in_w']))
        time_steps = x.size(1)
        # Per-step inputs of the current layer; each layer overwrites them
        # with its hidden outputs, which then feed the layer above.
        current_input = [x[:, t] for t in range(time_steps)]
        next_states = []
        for layer in range(self.num_layers):
            cell = self.cell_list[layer]
            h, c = states[layer]
            for t in range(time_steps):
                # Carry (h, c) forward from the previous step; feeding the
                # initial state at every step would drop the recurrence.
                step_data = pack([current_input[t], (h, c)],
                                 ['x', 'states'])
                h, c = cell(step_data)
                current_input[t] = h
            # Collect into a fresh list rather than writing into `states`,
            # so callers that reuse the input list keep their start state.
            next_states.append((h, c))

        return next_states
コード例 #8
0
    def forward(self, data, configs=None):
        """Advance the ConvLSTM cell by one step.

        Args:
            data: Container read through ``unpack`` for ``'x'`` (the input)
                and ``'states'`` (the pair ``(h, c)`` of previous hidden and
                cell state). Original notes give x as
                batch * in_c * in_w * in_h and h, c as
                batch * h_c * in_w * in_h.
            configs: Unused.

        Returns:
            Tuple ``(next_h, next_c)``.
        """
        x, (h, c) = unpack(data, ['x', 'states'])
        squash = self.active_func

        # A single convolution + batch norm produces all gate
        # pre-activations, which are then cut into h_c-channel chunks.
        gates = self.bn(self.conv2d(torch.cat([x, h], 1)))
        gate_i, gate_f, gate_c, gate_o = torch.split(gates, self.h_c, dim=1)

        # Input and forget gates, each with a peephole term on the previous
        # cell state.
        next_i = torch.sigmoid(gate_i + self.w_ci * c)
        next_f = torch.sigmoid(gate_f + self.w_cf * c)
        next_c = next_f * c + next_i * squash(gate_c)

        # NOTE(review): the output-gate peephole also uses the previous cell
        # state `c`, not `next_c`; confirm this variant is intended.
        next_o = torch.sigmoid(gate_o + self.w_co * c)
        next_h = next_o * squash(next_c)

        return next_h, next_c