# test_im = cubes[:test]
# test_sil = sils[:test]
# test_param = params[:test]
# number_testn_im = np.shape(test_im)[0]

# Evaluate on the whole data set: every sample of `cubes` / `sils` / `params`
# is treated as test data (the sliced variant is kept commented out above).
# NOTE(review): all names assigned in this section are reassigned further down
# in the file before being read in the visible code — confirm this cell is
# still needed.
test_im, test_sil, test_param = cubes, sils, params
number_testn_im = np.shape(test_im)[0]  # length of the leading axis

#  ------------------------------------------------------------------

# Preprocessing pipeline: ToTensor, then Normalize with mean 0.5 / std 0.5.
normalize = Normalize(mean=[0.5], std=[0.5])
# Tiles its input 3x along the first axis (presumably grayscale -> RGB);
# not used within this section.
gray_to_rgb = Lambda(lambda x: x.repeat(3, 1, 1))
transforms = Compose([ToTensor(), normalize])

test_dataset = CubeDataset(test_im, test_sil, test_param, transforms)
test_dataloader = DataLoader(
    test_dataset,
    batch_size=batch_size,
    shuffle=False,  # keep test samples in dataset order
    num_workers=2,
)

# for image, sil, param in test_dataloader:
#
#     nim = image.size()[0]
#     for i in range(0,nim):
#         print(image.size(), sil.size(), param.size()) #torch.Size([batch, 3, 512, 512]) torch.Size([batch, 6])
#         im = i
#         print(param[im])  # parameter in form tensor([2.5508, 0.0000, 0.0000, 0.0000, 0.0000, 5.0000])
#
#
# Validation takes the first `split` samples; the test set takes the next
# `testlen` samples after that.
# NOTE(review): the previous comment called this the "remaining ratio for
# validation", but cubes[:split] is the *leading* slice. If train_im is also
# cut from cubes[:split] elsewhere, validation overlaps training — confirm.
val_im, val_sil, val_param = cubes[:split], sils[:split], params[:split]

test_im, test_sil, test_param = (
    cubes[split:split + testlen],
    sils[split:split + testlen],
    params[split:split + testlen],
)
number_testn_im = np.shape(test_im)[0]  # length of the leading axis

#  ------------------------------------------------------------------

# Preprocessing pipeline: ToTensor, then Normalize with mean 0.5 / std 0.5.
normalize = Normalize(mean=[0.5], std=[0.5])
# Tiles its input 3x along the first axis (presumably grayscale -> RGB);
# not used within this section.
gray_to_rgb = Lambda(lambda x: x.repeat(3, 1, 1))
transforms = Compose([ToTensor(), normalize])

train_dataset, val_dataset, test_dataset = (
    CubeDataset(train_im, train_sil, train_param, transforms),
    CubeDataset(val_im, val_sil, val_param, transforms),
    CubeDataset(test_im, test_sil, test_param, transforms),
)

# One loader per split; only the test loader preserves dataset order.
# NOTE(review): shuffling the validation loader is unusual; val_dataloader is
# rebuilt later in the file with shuffle=False anyway.
train_dataloader, val_dataloader, test_dataloader = (
    DataLoader(ds, batch_size=batch_size, shuffle=shuffle_it, num_workers=2)
    for ds, shuffle_it in (
        (train_dataset, True),
        (val_dataset, True),
        (test_dataset, False),
    )
)
# Load a separate validation set from the .npy files whose paths are set
# elsewhere in the file.
BackgroundVal, silsVal, paramsVal = (
    np.load(Background_Valfile),
    np.load(BWShaft_Valfile),
    np.load(parameters_Valfile),
)

#  ------------------------------------------------------------------

# Keep only the window [start, start + vallen) of each array.
# (An earlier note recorded this window as 100:200.)
val_im, val_sil, val_param = (
    arr[start:start + vallen] for arr in (BackgroundVal, silsVal, paramsVal)
)

#  ------------------------------------------------------------------

# Preprocessing pipeline: ToTensor, then Normalize with mean 0.5 / std 0.5.
normalize = Normalize(mean=[0.5], std=[0.5])
transforms = Compose([ToTensor(), normalize])
val_dataset = CubeDataset(val_im, val_sil, val_param, transforms)

# One sample per batch, visited in dataset order (no shuffling).
val_dataloader = DataLoader(
    val_dataset,
    batch_size=1,
    shuffle=False,
    num_workers=2,
)

#  ------------------------------------------------------------------
# Setup the model

# Directory holding this script (symlinks resolved) and its '3D_objects'
# subdirectory.
current_dir = os.path.split(os.path.realpath(__file__))[0]
data_dir = os.path.join(current_dir, '3D_objects')

noise = 0.0
parser = argparse.ArgumentParser()
parser.add_argument('-io',