Exemplo n.º 1
0
class HKOModel(nn.Module):
    """Sequence model combining an ``Encoder`` and a ``Forecaster``.

    The encoder is stepped over ``input_num_seqs`` input frames. Then
    ``forecaster.set_h0(encoder)`` is called, and the forecaster is stepped
    ``output_num_seqs`` times to produce the predictions.

    Args:
        inplanes: Number of input channels, forwarded to ``Encoder``.
        input_num_seqs: Number of input time steps fed to the encoder.
        output_num_seqs: Number of time steps produced by the forecaster.
        use_cuda: Whether to move both sub-models to the GPU. ``None``
            (the default) falls back to the module-level ``cuda_flag``,
            so existing callers behave as before.
    """

    def __init__(self, inplanes, input_num_seqs, output_num_seqs,
                 use_cuda=None):
        super(HKOModel, self).__init__()
        self.input_num_seqs = input_num_seqs
        self.output_num_seqs = output_num_seqs
        self.encoder = Encoder(inplanes=inplanes, num_seqs=input_num_seqs)
        self.forecaster = Forecaster(num_seqs=output_num_seqs)
        if use_cuda is None:
            use_cuda = cuda_flag
        # Test truthiness instead of comparing with ``== True`` (PEP 8).
        if use_cuda:
            self.encoder = self.encoder.cuda()
            self.forecaster = self.forecaster.cuda()

    def forward(self, data):
        """Encode ``data`` and return the forecaster's outputs.

        Args:
            data: Indexable by time step; ``data[t]`` for each
                ``t in range(self.input_num_seqs)`` is passed to the encoder.
                Presumably a tensor or list of per-step frames (TODO confirm
                the expected shape against ``Encoder``).

        Returns:
            A list of ``self.output_num_seqs`` forecaster outputs, in the
            order they were produced.
        """
        # init_h0 presumably resets the encoder's hidden state for a new
        # sequence (inferred from the name; confirm in Encoder).
        self.encoder.init_h0()
        # The encoder's return value is discarded. It is called only for its
        # effect on internal state (assumed; verify against Encoder).
        for step in range(self.input_num_seqs):
            self.encoder(data[step])
        # Presumably copies the encoder's final state into the forecaster.
        self.forecaster.set_h0(self.encoder)
        # The forecaster receives no external input (None). Each call
        # presumably emits the next predicted step from its own state.
        return [self.forecaster(None) for _ in range(self.output_num_seqs)]
Exemplo n.º 2
0
            true_img = target_image[pre_id, 0, 0, ...]
            encode_img = input_image[pre_id, 0, 0, ...]
            cv2.imwrite(os.path.join(save_path, 'a_%s.png' % pre_id),
                        encode_img)
            cv2.imwrite(os.path.join(save_path, 'c_%s.png' % pre_id), tmp_img)
            cv2.imwrite(os.path.join(save_path, 'b_%s.png' % pre_id), true_img)

    # for pre_data in pre_list:
    #     temp = pre_data.cpu().data.numpy()
    #     print temp.mean()


# Runs at import time. The four names unpacked here become module-level
# globals (presumably read by test(); confirm). The two identifiers are
# passed to load_data as a list; their meaning (likely station/dataset IDs)
# is not visible here.
(train_arr, test_arr,
 train_imgs_maps, test_imgs_maps) = load_data(['AZ9010', 'AZ9200'])

if __name__ == '__main__':
    # Build the encoder/forecaster pair directly rather than through
    # HKOModel, which wraps the same two sub-models.
    m_e = Encoder(inplanes=input_channels_img, num_seqs=input_num_seqs)
    m_f = Forecaster(num_seqs=output_num_seqs)

    # Move the models to the GPU only when CUDA is enabled. This keeps their
    # device consistent with the cuda_test flag passed to test() and lets
    # the script run on CPU-only machines, where .cuda() would raise.
    if cuda_flag:
        m_e = m_e.cuda()
        m_f = m_f.cuda()

    test(input_channels_img,
         output_channels_img,
         size_image,
         max_epoch,
         model_e=m_e,
         model_f=m_f,
         cuda_test=cuda_flag)