def __init__(self, configs):
    """Build the encoder, reconstructor and predictor sub-networks.

    Args:
        configs: packed mapping read via ``unpack`` with keys
            'encoder_configs' (passed to ConvLSTM), 'reconstruct_configs'
            and 'predict_configs' (each passed to a Generator).
    """
    super(MMNIST_ConvLSTM, self).__init__()
    sub_keys = ['encoder_configs', 'reconstruct_configs', 'predict_configs']
    encoder_cfg, recon_cfg, predict_cfg = unpack(configs, sub_keys)
    # Registration order (encoder, reconstructor, predictor) is kept stable so
    # parameter ordering in the state dict does not change.
    self.encoder = ConvLSTM(encoder_cfg)
    self.reconstructor = Generator(recon_cfg)
    self.predictor = Generator(predict_cfg)
def __init__(self, configs):
    """Build a stack of ConvLSTM cells.

    Args:
        configs: packed mapping read via ``unpack`` with keys
            'num_layers' (depth of the stack) and 'cell_configs' (indexable
            sequence of per-layer ConvLSTMCell configs; entry ``i`` configures
            layer ``i``). Each cell config holds keys such as 'h_c',
            'active_func', 'in_c', 'in_h', 'in_w', 'kernel_size', 'DEBUG'.
    """
    super(ConvLSTM, self).__init__()
    depth, per_layer = unpack(configs, ['num_layers', 'cell_configs'])
    layers = []
    for layer_idx in xrange(depth):
        layers.append(ConvLSTMCell(per_layer[layer_idx]))
    self.cell_list = nn.ModuleList(layers)
    self.num_layers = depth
    # The first layer's config describes the external input; forward() reads
    # its 'in_c'/'in_h'/'in_w' when it has to synthesize a zero input frame.
    self.cell_config = per_layer[0]
def forward(self, data, configs=None):
    """Encode observed frames, then reconstruct them and predict future frames.

    Args:
        data: packed mapping read via ``unpack`` with keys
            'x_train': observed frames, indexed as
                (batch, time, channels, height, width);
            'x_predict': future frames, only read when ``use_gt`` is true;
            'states': read but ignored -- the encoder always starts from
                ``encoder.init_hidden(batch_size)``.
        configs: packed mapping read via ``unpack`` with keys
            'use_gt' (when true, feed ground-truth frames back as the next
            decoder input instead of the decoder's own output) and
            'max_steps' (number of prediction steps when ``use_gt`` is false).

    Returns:
        Tuple ``(reconstruction, prediction)``: the per-step outputs of the
        reconstructor and of the predictor, each concatenated along dim 1.
    """
    x_train, x_predict, _ = unpack(data, ['x_train', 'x_predict', 'states'])
    use_gt, max_steps = unpack(configs, ['use_gt', 'max_steps'])
    batch_size = x_train.size(0)
    time_steps = x_train.size(1)
    encoder = self.encoder
    reconstructor = self.reconstructor
    predictor = self.predictor

    # Encoding stage: summarize the observed sequence into per-layer states.
    en_data = pack([x_train, encoder.init_hidden(batch_size)], ['x', 'states'])
    states = encoder(en_data)

    # Reconstruction stage. Each decoder gets its own shallow copy of the
    # encoder states so that a sub-network which writes updated states back
    # into the list it receives cannot leak the reconstructor's final state
    # into the predictor's starting state.
    r_res = []
    r_x = None  # the first step has no previous frame
    r_states = list(states)
    for t in xrange(time_steps):
        r_out, r_states = reconstructor(pack([r_x, r_states], ['x', 'states']))
        r_res.append(r_out)
        r_x = x_train[:, t].unsqueeze(1) if use_gt else r_out.unsqueeze(1)

    # Prediction stage: starts again from the encoder states.
    pred_steps = x_predict.size(1) if use_gt else max_steps
    p_res = []
    p_x = None
    p_states = list(states)
    for t in xrange(pred_steps):
        p_out, p_states = predictor(pack([p_x, p_states], ['x', 'states']))
        p_res.append(p_out)
        p_x = x_predict[:, t].unsqueeze(1) if use_gt else p_out.unsqueeze(1)
    return torch.cat(r_res, 1), torch.cat(p_res, 1)
def __init__(self, configs):
    """Build a convolutional vanilla-RNN cell.

    Args:
        configs: packed mapping read via ``unpack`` with the keys returned by
            ``ConvRNNCell.get_init_keys()``; the values are unpacked, in
            order, into h_c, active_func, in_c, in_h, in_w, kernel_size and a
            debug flag.
    """
    super(ConvRNNCell, self).__init__()
    init_keys = ConvRNNCell.get_init_keys()
    (h_c, active_func, in_c, in_h,
     in_w, kernel_size, debug) = unpack(configs, init_keys)
    self.h_c = h_c
    self.active_func = active_func
    self.in_c = in_c
    self.in_h = in_h
    self.in_w = in_w
    self.kernel_size = kernel_size
    if debug:
        print_dict(configs, init_keys, 'ConvRNNCell.__init__')
    # With the default stride of 1, this padding keeps the spatial size
    # unchanged for odd kernel sizes.
    same_pad = (kernel_size - 1) // 2
    # One convolution over [x, h] stacked on the channel axis.
    self.conv2d = nn.Conv2d(in_channels=in_c + h_c,
                            out_channels=h_c,
                            kernel_size=kernel_size,
                            padding=same_pad)
def forward(self, data, configs=None):
    """Advance the cell one step: h_next = active_func(conv2d([x, h])).

    Args:
        data: packed mapping read via ``unpack`` with keys 'x' (input frame)
            and 'states' (the previous hidden state h as a single tensor;
            unlike the LSTM cell there is no separate cell state c).
        configs: unused.

    Returns:
        The next hidden state.
    """
    x, h_prev = unpack(data, ['x', 'states'])
    # Stack input and previous hidden state along the channel axis (dim 1).
    stacked = torch.cat([x, h_prev], 1)
    return self.active_func(self.conv2d(stacked))
def __init__(self, configs):
    """Build a convolutional LSTM cell with peephole weights and batch norm.

    Args:
        configs: packed mapping read via ``unpack`` with the keys returned by
            ``ConvLSTMCell.get_init_keys()``; the values are unpacked, in
            order, into h_c, active_func, in_c, in_h, in_w, kernel_size and a
            debug flag.
    """
    super(ConvLSTMCell, self).__init__()
    init_keys = ConvLSTMCell.get_init_keys()
    (h_c, active_func, in_c, in_h,
     in_w, kernel_size, debug) = unpack(configs, init_keys)
    self.h_c = h_c
    self.active_func = active_func
    self.in_c = in_c
    self.in_h = in_h
    self.in_w = in_w
    self.kernel_size = kernel_size
    if debug:
        print_dict(configs, init_keys, 'ConvLSTMCell.__init__')
    # One convolution over [x, h] yields the pre-activations of all four
    # gates (input, forget, cell candidate, output) stacked on the channels.
    gate_channels = 4 * h_c
    self.conv2d = nn.Conv2d(in_channels=in_c + h_c,
                            out_channels=gate_channels,
                            kernel_size=kernel_size,
                            padding=(kernel_size - 1) // 2)
    # Batch norm over the gate pre-activations (optional in the original design).
    self.bn = nn.BatchNorm2d(gate_channels)
    # Peephole weights, zero-initialised here; registered in this fixed order.
    for peephole_name in ('w_ci', 'w_cf', 'w_co'):
        setattr(self, peephole_name,
                nn.Parameter(torch.zeros(h_c, in_h, in_w)))
    self.init_weights()
def forward(self, data, configs=None):
    """Run the stacked ConvLSTM over a sequence and return the final states.

    Each layer consumes all time steps, carrying its (h, c) state from one
    step to the next; the hidden outputs of a layer become the per-step
    inputs of the layer above.

    Args:
        data: packed mapping read via ``unpack`` with keys
            'x': input tensor indexed as (batch, time, channels, height,
                width), or None to feed a single all-zero frame shaped from
                ``self.cell_config``;
            'states': list with one (h, c) pair per layer.
        configs: unused.

    Returns:
        A new list holding, for every layer, the (h, c) pair after the last
        time step. The caller's ``states`` list is not modified; if there are
        zero time steps the initial states are returned unchanged.
    """
    x, states = unpack(data, ['x', 'states'])
    batch_size = states[0][0].size(0)
    if x is None:
        # Zero frame with spatial dims ordered (height, width), matching the
        # (h_c, in_h, in_w) peephole weights of the cells.
        cfg = self.cell_config
        x = to_var(torch.zeros(batch_size, 1, cfg['in_c'],
                               cfg['in_h'], cfg['in_w']))

    time_steps = x.size(1)
    layer_input = [x[:, t] for t in xrange(time_steps)]
    next_states = list(states)
    for layer_idx in xrange(self.num_layers):
        cell = self.cell_list[layer_idx]
        h, c = states[layer_idx]
        for t in xrange(time_steps):
            # Feed the state produced at step t-1 so the layer is recurrent.
            h, c = cell(pack([layer_input[t], (h, c)], ['x', 'states']))
            layer_input[t] = h
        next_states[layer_idx] = (h, c)
    return next_states
def forward(self, data, configs=None):
    """Advance the peephole ConvLSTM cell by one step.

    With ``a_i, a_f, a_c, a_o`` the four channel groups of
    ``bn(conv2d([x, h]))`` and ``*`` the element-wise product:

        i      = sigmoid(a_i + w_ci * c)
        f      = sigmoid(a_f + w_cf * c)
        c_next = f * c + i * act(a_c)
        o      = sigmoid(a_o + w_co * c_next)
        h_next = o * act(c_next)

    where ``act`` is ``self.active_func``.

    Args:
        data: packed mapping read via ``unpack`` with keys 'x' (input frame,
            channels on dim 1) and 'states' (the previous ``(h, c)`` pair).
            Presumably x, h and c share the spatial size (in_h, in_w) of the
            peephole weights -- TODO confirm against init_hidden.
        configs: unused.

    Returns:
        Tuple ``(h_next, c_next)``.
    """
    x, states = unpack(data, ['x', 'states'])
    h, c = states
    act = self.active_func
    gates = self.bn(self.conv2d(torch.cat([x, h], 1)))
    a_i, a_f, a_c, a_o = torch.split(gates, self.h_c, dim=1)
    next_i = (a_i + self.w_ci * c).sigmoid()
    next_f = (a_f + self.w_cf * c).sigmoid()
    next_c = next_f * c + next_i * act(a_c)
    # The output-gate peephole reads the *updated* cell state
    # (Shi et al. 2015, eq. 3), not the previous one.
    next_o = (a_o + self.w_co * next_c).sigmoid()
    next_h = next_o * act(next_c)
    return next_h, next_c