import os
import time

import easydict
import torch
import torch.optim as optim
from xml.dom import minidom  # used by the XML results writer below

import ctrlfnet_model as ctrlf
import misc.datasets as datasets
import misc.h5_dataset as h5_dataset
from evaluate import mAP
from misc.dataloader import DataLoader
from train_opts import parse_args

opt = parse_args()
os.environ['CUDA_VISIBLE_DEVICES'] = str(opt.gpu)

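# Build the datasets either from pre-packed HDF5 files or directly from disk.
# The HDF5 path runs without loader workers, since open HDF5 handles generally
# don't survive being shared across worker processes.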
if opt.h5:
    trainset = h5_dataset.H5Dataset(opt, split=0)
    valset = h5_dataset.H5Dataset(opt, split=1)
    testset = h5_dataset.H5Dataset(opt, split=2)
    opt.num_workers = 0
else:
    if 'iiit_hws' in opt.dataset:
        trainset = datasets.SegmentedDataset(opt, 'train')
    else:
        trainset = datasets.Dataset(opt, 'train')

    valset = datasets.Dataset(opt, 'val')
    testset = datasets.Dataset(opt, 'test')
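
# The sampler presumably yields opt.max_iters randomly drawn training images,
# so a single pass over the loader covers the whole training schedule.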
sampler = datasets.RandomSampler(trainset, opt.max_iters)
trainloader = DataLoader(trainset,
                         batch_size=1,
                         sampler=sampler,
                         # The original call was truncated here; forwarding the
                         # configured worker count is an assumption.
                         num_workers=opt.num_workers)


def write_xml_results(boxes, dataset, mode):
    # Reconstructed wrapper: the enclosing function was truncated in the
    # source, so the signature and XML layout here are assumptions.
    newdoc = minidom.Document()
    top_element = newdoc.createElement('spots')
    newdoc.appendChild(top_element)
    for (w, h) in boxes:
        spot = newdoc.createElement('spot')
        spot.setAttribute('w', str(w))
        spot.setAttribute('h', str(h))
        top_element.appendChild(spot)

    with open("botany_konz_eval/data/%s_results_%s.xml" % (dataset, mode),
              'wb') as f:
        newdoc.writexml(f, addindent='  ', newl='\n', encoding='utf-8')

#%%
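# Evaluation cell: reload the options with augmentation disabled and build
# fixed-order val/test loaders.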
opt = parse_args()
opt.augment = 0
os.environ['CUDA_VISIBLE_DEVICES'] = str(opt.gpu)

if opt.h5:
    testset = h5_dataset.H5Dataset(opt, split=2)
    valset = h5_dataset.H5Dataset(opt, split=1)
    opt.num_workers = 0
else:
    testset = datasets.Dataset(opt, 'test')
    valset = datasets.Dataset(opt, 'val')

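# Evaluation loaders: one page image per batch, no shuffling, single process.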
loader = DataLoader(testset,
                    batch_size=1,
                    shuffle=False,
                    num_workers=0)
valloader = DataLoader(valset,
                       batch_size=1,
                       shuffle=False,
                       num_workers=0)
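
# Default to CPU float tensors; model and batches are presumably moved to the
# GPU explicitly where needed.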
torch.set_default_tensor_type('torch.FloatTensor')