if __name__ == '__main__':
    # Script entry point: build the validation image loader and restore the
    # pretrained weights of the image-only joint network.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    txt_file = '../data/slc_val_3.txt'
    batch_size = 30
    cate_num = 8

    # Only the tensor conversion is active here; the Normalize_spe_xy step is
    # commented out.
    spe_transform = transforms.Compose([
        # transform_data.Normalize_spe_xy(),
        transform_data.Numpy2Tensor()
    ])
    img_transform = transforms.Compose(
        [transform_data.Normalize_img(), transform_data.Numpy2Tensor_img(3)])

    dataset = slc_dataset.SLC_img(txt_file=txt_file,
                                  root_dir='../data/slc_data/',
                                  transform=img_transform)
    # shuffle=False keeps the samples in the order the dataset yields them.
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False,
                            num_workers=0)

    net_joint = network.SLC_joint2_img(cate_num)
    pretrained_model = '../model/slc_joint_deeper_img_3_F.pth'
    print(pretrained_model)
    # map_location remaps the stored tensors onto the device selected above,
    # so a checkpoint written from a GPU still loads on a CPU-only machine.
    net_joint.load_state_dict(
        torch.load(pretrained_model, map_location=device))
def get_train_features_transmat(data_dir):
    """Extract img/spe feature maps for one data split and save their fusion.

    The samples listed in ``'../data/' + data_dir + '.txt'`` are passed
    through ``net_joint.pre_img_features`` and ``net_joint.pre_spe_features``.
    For each of the 16 x 16 spatial positions, ``dcaFuse.dcaFuse`` is called
    on the per-position ``(128, N)`` image and spe feature matrices together
    with the labels (presumably discriminant correlation analysis -- confirm
    against the ``dcaFuse`` module).

    Five ``.npy`` files are written, all prefixed ``'../data/' + data_dir``:

    * ``_img_features.npy`` / ``_spe_features.npy`` -- fused features,
      shape ``(N, 7, 16, 16)``.
    * ``_transmat_img.npy`` / ``_transmat_spe.npy`` -- per-position
      transformation matrices, shape ``(16, 16, 7, 128)``.
    * ``_label.npy`` -- the labels, in the (shuffled) order in which the
      features were collected.

    Args:
        data_dir: Name of the split, e.g. ``'slc_train_3'``; used to build
            both the input list path and the output file names.

    Returns:
        None. Results are written to disk.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # device = torch.device('cpu')
    data_prefix = '../data/' + data_dir
    txt_file = data_prefix + '.txt'
    batch_size = 10
    cate_num = 8

    # Geometry the reshapes below rely on: the pre-feature maps are expected
    # to be (channels, height, width) and dcaFuse is expected to return
    # `fused_dim` rows per position.
    channels, height, width = 128, 16, 16
    positions = height * width
    fused_dim = 7

    spe_transform = transforms.Compose([
        # transform_data.Normalize_spe_xy(),
        transform_data.Numpy2Tensor()
    ])
    img_transform = transforms.Compose([
        transform_data.Normalize_img(),
        transform_data.Numpy2Tensor_img(3)
    ])
    dataset = slc_dataset.SLC_img_spe4D(txt_file=txt_file,
                                        img_dir='../data/slc_data/',
                                        spe_dir='../data/spexy_data_3/',
                                        img_transform=img_transform,
                                        spe_transform=spe_transform)
    dataloaders = DataLoader(dataset, batch_size=batch_size, shuffle=True,
                             num_workers=0)

    # map_location keeps the load working on CPU-only hosts even when the
    # checkpoint was written from a GPU.
    img_model = torch.load('../model/tsx.pth', map_location=device)
    net_joint = network.SLC_joint2(cate_num)
    net_joint = get_pretrained(img_model, net_joint)
    net_joint.to(device)
    # NOTE(review): the network is never switched with .eval(); if it contains
    # BatchNorm/Dropout layers the extracted features depend on batch
    # composition -- confirm whether that is intended.

    # Each list is seeded with an empty float64 array so that the single
    # concatenation below promotes dtypes and validates shapes exactly as
    # concatenating onto np.zeros(...) batch by batch would.
    label_chunks = [np.zeros(0)]
    img_chunks = [np.zeros([0, channels, height, width])]
    spe_chunks = [np.zeros([0, channels, height, width])]

    # Pure inference: no autograd graph is needed for feature extraction.
    with torch.no_grad():
        for data in dataloaders:
            img_data = data['img'].to(device)
            spe_data = data['spe'].to(device)
            labels = data['label'].to(device)

            img_features = net_joint.pre_img_features(img_data)
            spe_features = net_joint.pre_spe_features(spe_data)

            label_chunks.append(labels.detach().cpu().numpy())
            img_chunks.append(img_features.detach().cpu().numpy())
            spe_chunks.append(spe_features.detach().cpu().numpy())

    # One concatenation per array instead of re-copying the growing result
    # on every batch.
    np_label = np.concatenate(label_chunks)
    np_img_feature = np.concatenate(img_chunks, axis=0)
    np_spe_feature = np.concatenate(spe_chunks, axis=0)

    N = len(np_label)
    # Flatten the spatial grid so position i can be sliced on the last axis.
    np_img_feature = np_img_feature.reshape([N, channels, positions])
    np_spe_feature = np_spe_feature.reshape([N, channels, positions])

    transmat_img_parts = [np.zeros([0, fused_dim, channels])]
    transmat_spe_parts = [np.zeros([0, fused_dim, channels])]
    output_img_parts = [np.zeros([0, fused_dim, N])]
    output_spe_parts = [np.zeros([0, fused_dim, N])]

    # _, _, x, y = img_features.shape
    for i in range(positions):
        # (channels, N) matrices for this spatial position.
        per_img_feature = np_img_feature[:, :, i].T
        per_spe_feature = np_spe_feature[:, :, i].T

        (per_output_img_feature, per_output_spe_feature,
         per_transmat_img_feature, per_transmat_spe_feature) = \
            dcaFuse.dcaFuse(per_img_feature, per_spe_feature, np_label)

        transmat_img_parts.append(
            per_transmat_img_feature.reshape([1, fused_dim, channels]))
        transmat_spe_parts.append(
            per_transmat_spe_feature.reshape([1, fused_dim, channels]))
        output_img_parts.append(
            per_output_img_feature.reshape([1, fused_dim, N]))
        output_spe_parts.append(
            per_output_spe_feature.reshape([1, fused_dim, N]))

    transmat_img_feature = np.concatenate(transmat_img_parts, axis=0)
    transmat_spe_feature = np.concatenate(transmat_spe_parts, axis=0)
    output_img_feature = np.concatenate(output_img_parts, axis=0)
    output_spe_feature = np.concatenate(output_spe_parts, axis=0)

    # Restore the spatial grid; fused features are stored sample-major as
    # (N, fused_dim, height, width).
    np.save(data_prefix + '_img_features.npy',
            output_img_feature.reshape([height, width, fused_dim, N])
            .transpose(3, 2, 0, 1))
    np.save(data_prefix + '_spe_features.npy',
            output_spe_feature.reshape([height, width, fused_dim, N])
            .transpose(3, 2, 0, 1))
    np.save(data_prefix + '_transmat_img.npy',
            transmat_img_feature.reshape([height, width, fused_dim, channels]))
    np.save(data_prefix + '_transmat_spe.npy',
            transmat_spe_feature.reshape([height, width, fused_dim, channels]))
    np.save(data_prefix + '_label.npy', np_label)