]

    dataloaders = {}
    dataloaders['train'] = DataLoader(dataset['train'],
                                      batch_size=batch_size['train'],
                                      sampler=ImbalancedDatasetSampler(
                                          dataset['train']),
                                      num_workers=0)
    dataloaders['val'] = DataLoader(dataset['val'],
                                    batch_size=batch_size['val'],
                                    shuffle=True,
                                    num_workers=0)

    img_model = torch.load('../model/tsx.pth')
    # spe_model = torch.load('../model/slc_spexy_cae_2.pth')
    net_joint = network.SLC_joint2(cate_num)
    net_joint = get_pretrained(img_model, net_joint)
    # net_joint.load_state_dict(torch.load('../model/slc_joint_deeper_' + str(datasetnum) + '_F_con.pth'))

    net_joint.to(device)

    epoch_num = 7000
    i = 0
    parameter_list = param_setting_jointmodel2(model=net_joint)

    optimizer = optim.SGD(parameter_list, lr=0.01, weight_decay=0.0005)
    lr_list = [param_group['lr'] for param_group in optimizer.param_groups]
    loss_weight = torch.Tensor(loss_weight).to(device)
    loss_func = nn.CrossEntropyLoss(weight=loss_weight)

    writer = SummaryWriter('../log/' + save_model_path.split('/')[-1] + 'log')
def get_train_features_transmat(data_dir):
    """Extract image/spectral features and fit one DCA transform per location.

    Reads the sample list ``../data/<data_dir>.txt``, pushes every sample
    through ``pre_img_features`` and ``pre_spe_features`` of an
    ``SLC_joint2`` network initialised by ``get_pretrained`` from
    ``../model/tsx.pth``, and then calls ``dcaFuse.dcaFuse`` once for each of
    the 16 * 16 spatial positions of the resulting feature maps.

    Five arrays are written to ``../data/`` (``N`` is the number of samples):

    * ``<data_dir>_img_features.npy``  -- fused image features, (N, 7, 16, 16)
    * ``<data_dir>_spe_features.npy``  -- fused spectral features, (N, 7, 16, 16)
    * ``<data_dir>_transmat_img.npy``  -- image transforms, (16, 16, 7, 128)
    * ``<data_dir>_transmat_spe.npy``  -- spectral transforms, (16, 16, 7, 128)
    * ``<data_dir>_label.npy``         -- labels in loader order, (N,)

    Args:
        data_dir: Stem of the sample list file (e.g. ``'slc_train_3'``); it is
            also used as the prefix of every output file.

    Returns:
        None. The results are only written to disk.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    txt_file = '../data/' + data_dir + '.txt'

    batch_size = 10
    cate_num = 8

    # Layout of the feature maps this function expects from the network;
    # the concatenations/reshapes below fail if the model produces anything else.
    feat_channels = 128
    feat_h, feat_w = 16, 16
    n_positions = feat_h * feat_w  # 256 flattened spatial positions
    # Number of rows of each transform returned by dcaFuse.
    # NOTE(review): presumably cate_num - 1 -- confirm against dcaFuse.
    dca_dim = 7

    spe_transform = transforms.Compose([
        transform_data.Numpy2Tensor()
    ])

    img_transform = transforms.Compose([
        transform_data.Normalize_img(),
        transform_data.Numpy2Tensor_img(3)
    ])

    dataset = slc_dataset.SLC_img_spe4D(txt_file=txt_file,
                                        img_dir='../data/slc_data/',
                                        spe_dir='../data/spexy_data_3/',
                                        img_transform=img_transform,
                                        spe_transform=spe_transform)

    # Samples are shuffled; the saved labels and features share this order.
    dataloaders = DataLoader(dataset,
                             batch_size=batch_size,
                             shuffle=True,
                             num_workers=0)

    # NOTE(review): torch.load unpickles the file -- only use trusted
    # checkpoints. No map_location is given, so a checkpoint saved on a GPU
    # may not load on a CPU-only machine; confirm before relying on the
    # CPU fallback above.
    img_model = torch.load('../model/tsx.pth')
    net_joint = network.SLC_joint2(cate_num)
    net_joint = get_pretrained(img_model, net_joint)

    net_joint.to(device)
    # NOTE(review): the network is left in its default (training) mode, as in
    # the original code. If it contains dropout/batch-norm layers,
    # net_joint.eval() is probably wanted here -- confirm before changing,
    # since it alters the extracted features.

    label_chunks = []
    img_chunks = []
    spe_chunks = []

    # Pure inference: no autograd graph is needed, so do not build one.
    with torch.no_grad():
        for data in dataloaders:
            img_data = data['img'].to(device)
            spe_data = data['spe'].to(device)
            img_features = net_joint.pre_img_features(img_data)
            spe_features = net_joint.pre_spe_features(spe_data)

            # Labels are only stored, so they never need to visit the device.
            label_chunks.append(data['label'].cpu().numpy())
            img_chunks.append(img_features.detach().cpu().numpy())
            spe_chunks.append(spe_features.detach().cpu().numpy())

    # Concatenate once instead of re-copying the accumulator every batch.
    # The empty float64 seeds keep the result dtype (float64) and the shape
    # validation identical to incremental concatenation.
    np_label = np.concatenate([np.zeros(0)] + label_chunks)
    np_img_feature = np.concatenate(
        [np.zeros([0, feat_channels, feat_h, feat_w])] + img_chunks, axis=0)
    np_spe_feature = np.concatenate(
        [np.zeros([0, feat_channels, feat_h, feat_w])] + spe_chunks, axis=0)
    del label_chunks, img_chunks, spe_chunks

    N = len(np_label)
    # Flatten the spatial grid: (N, C, H, W) -> (N, C, H*W).
    np_img_feature = np_img_feature.reshape([N, feat_channels, n_positions])
    np_spe_feature = np_spe_feature.reshape([N, feat_channels, n_positions])

    transmat_img_parts = []
    transmat_spe_parts = []
    output_img_parts = []
    output_spe_parts = []

    for pos in range(n_positions):
        # dcaFuse receives (channels, N) matrices for this spatial position.
        per_img_feature = np_img_feature[:, :, pos].T
        per_spe_feature = np_spe_feature[:, :, pos].T
        per_output_img_feature, per_output_spe_feature, per_transmat_img_feature, per_transmat_spe_feature = \
            dcaFuse.dcaFuse(per_img_feature, per_spe_feature, np_label)

        transmat_img_parts.append(
            per_transmat_img_feature.reshape([1, dca_dim, feat_channels]))
        transmat_spe_parts.append(
            per_transmat_spe_feature.reshape([1, dca_dim, feat_channels]))
        output_img_parts.append(per_output_img_feature.reshape([1, dca_dim, N]))
        output_spe_parts.append(per_output_spe_feature.reshape([1, dca_dim, N]))

    transmat_img_feature = np.concatenate(
        [np.zeros([0, dca_dim, feat_channels])] + transmat_img_parts, axis=0)
    transmat_spe_feature = np.concatenate(
        [np.zeros([0, dca_dim, feat_channels])] + transmat_spe_parts, axis=0)
    output_img_feature = np.concatenate(
        [np.zeros([0, dca_dim, N])] + output_img_parts, axis=0)
    output_spe_feature = np.concatenate(
        [np.zeros([0, dca_dim, N])] + output_spe_parts, axis=0)

    # Restore the spatial grid; fused features are stored sample-first as
    # (N, dca_dim, H, W), transforms as (H, W, dca_dim, C).
    np.save('../data/' + data_dir + '_img_features.npy',
            output_img_feature.reshape([feat_h, feat_w, dca_dim, N]).transpose(3, 2, 0, 1))
    np.save('../data/' + data_dir + '_spe_features.npy',
            output_spe_feature.reshape([feat_h, feat_w, dca_dim, N]).transpose(3, 2, 0, 1))
    np.save('../data/' + data_dir + '_transmat_img.npy',
            transmat_img_feature.reshape([feat_h, feat_w, dca_dim, feat_channels]))
    np.save('../data/' + data_dir + '_transmat_spe.npy',
            transmat_spe_feature.reshape([feat_h, feat_w, dca_dim, feat_channels]))
    np.save('../data/' + data_dir + '_label.npy', np_label)