class HKOModel(nn.Module):
    """Encoder-forecaster sequence model.

    The encoder is fed ``input_num_seqs`` inputs one time step at a time; the
    forecaster is then seeded from the encoder and stepped
    ``output_num_seqs`` times to produce the predictions.
    """

    def __init__(self, inplanes, input_num_seqs, output_num_seqs,
                 use_cuda=None):
        """Build the encoder and forecaster sub-modules.

        Args:
            inplanes: Forwarded to ``Encoder`` as ``inplanes`` (presumably
                the number of input channels -- confirm against Encoder).
            input_num_seqs: Number of input time steps fed to the encoder.
            output_num_seqs: Number of time steps the forecaster produces.
            use_cuda: When truthy, move both sub-modules to the GPU via
                ``.cuda()``. ``None`` (the default) falls back to the
                module-level ``cuda_flag``, matching the previous behaviour.
        """
        super(HKOModel, self).__init__()
        self.input_num_seqs = input_num_seqs
        self.output_num_seqs = output_num_seqs
        self.encoder = Encoder(inplanes=inplanes, num_seqs=input_num_seqs)
        self.forecaster = Forecaster(num_seqs=output_num_seqs)
        if use_cuda is None:
            # Backward compatibility: the device choice used to come solely
            # from this global.
            use_cuda = cuda_flag
        if use_cuda:
            self.encoder = self.encoder.cuda()
            self.forecaster = self.forecaster.cuda()

    def forward(self, data):
        """Encode ``data`` step by step, then return the forecaster outputs.

        Args:
            data: Indexable by time step; ``data[t]`` for
                ``t in range(self.input_num_seqs)`` is passed to the encoder.
                (Presumably a sequence/tensor with time as the first axis --
                confirm against callers.)

        Returns:
            A list of ``self.output_num_seqs`` forecaster outputs in time
            order.
        """
        # Presumably resets the encoder's recurrent state for a new sequence.
        self.encoder.init_h0()
        # The encoder's return value is not used here; it is called only so it
        # can update whatever internal state it keeps.
        for step in range(self.input_num_seqs):
            self.encoder(data[step])
        # Hand the encoder to the forecaster so it can seed its own state.
        self.forecaster.set_h0(self.encoder)
        # The forecaster gets no external input (None) at each step.
        return [self.forecaster(None) for _ in range(self.output_num_seqs)]
true_img = target_image[pre_id, 0, 0, ...] encode_img = input_image[pre_id, 0, 0, ...] cv2.imwrite(os.path.join(save_path, 'a_%s.png' % pre_id), encode_img) cv2.imwrite(os.path.join(save_path, 'c_%s.png' % pre_id), tmp_img) cv2.imwrite(os.path.join(save_path, 'b_%s.png' % pre_id), true_img) # for pre_data in pre_list: # temp = pre_data.cpu().data.numpy() # print temp.mean() train_arr, test_arr, train_imgs_maps, test_imgs_maps = load_data( ['AZ9010', 'AZ9200']) if __name__ == '__main__': # m = HKOModel(inplanes=1, input_num_seqs=input_num_seqs, output_num_seqs=output_num_seqs) m_e = Encoder(inplanes=input_channels_img, num_seqs=input_num_seqs) m_e = m_e.cuda() m_f = Forecaster(num_seqs=output_num_seqs) m_f = m_f.cuda() test(input_channels_img, output_channels_img, size_image, max_epoch, model_e=m_e, model_f=m_f, cuda_test=cuda_flag)