# Dataset location and the output folder used by this script.
print("===> Loading datasets")
# Alternative dataset root: '/tmp4/hang_data/DIV2K'
root_dir = '/tmp4/hang_data/VOCdevkit/VOC2012/'
SR_dir = join(root_dir, 'VOC_SSSR4')
# Create the output folder only when it is not there yet; os.mkdir makes a
# single directory level, so root_dir itself must already exist.
if not os.path.isdir(SR_dir):
    os.mkdir(SR_dir)

# Network stored under the "model" key of the checkpoint at opt.model.
# The map_location callable returns the storage unchanged, so tensors are
# deserialized on the CPU before the explicit move to the GPU.
model = torch.load(
    opt.model, map_location=lambda storage, loc: storage)["model"]
model = model.cuda(gpuid)

# 21-class DeepLab network restored from a fixed weight file and switched to
# evaluation mode.
saved_state_dict = torch.load('model/VOC12_scenes_20000.pth')
deeplab_res = Res_Deeplab(num_classes=21)
deeplab_res.load_state_dict(saved_state_dict)
deeplab_res = deeplab_res.eval().cuda(gpuid)

# Helper layer defined elsewhere in the project.
mid = mid_layer().cuda(gpuid)

# size_average=False is the legacy spelling for a summed (not averaged)
# squared error.
criterion = torch.nn.MSELoss(size_average=False).cuda(gpuid)

testloader = data.DataLoader(VOCDataValSet(root_dir,
                                           DATA_LIST_PATH,
                                           crop_size=(321, 321),
                                           mean=IMG_MEAN,
                                           scale=False,
                                           mirror=False),
                             batch_size=1,
                             shuffle=False,
def _copy_pretrained_weights(pretrained, target, skip=2):
    """Rebind ``target``'s parameter tensors to those of ``pretrained``.

    Parameters of the two modules are paired by position (the pairing stops
    at the shorter of the two). The first ``skip`` pairs are left untouched;
    every later pair has the target's ``.data`` replaced by the source's.
    """
    pairs = zip(pretrained.parameters(), target.parameters())
    for position, (src, dst) in enumerate(pairs):
        if position >= skip:
            # ``dst`` is the very parameter object held by ``target``, so
            # assigning through it avoids rebuilding the parameter list on
            # every iteration.
            dst.data = src.data


def main():
    """Fine-tune the ``Net`` model on the VOC HDF5 shards.

    Steps, in order:

    1. Parse the command line into the global ``opt`` and seed torch with a
       freshly drawn random seed.
    2. If ``opt.vgg_loss`` is set, build the global ``netContent`` from the
       VGG19 feature layers (all but the last child).
    3. Load ``Res_Deeplab`` (21 classes) in eval mode, create ``mid_layer``,
       and build ``Net`` with weights taken from the DIV2K checkpoint for
       every parameter tensor except the first two.
    4. Move everything to GPU 0, optionally resume from ``opt.resume`` or
       load weights from ``opt.pretrained``.
    5. Train with Adam over every file in the HDF5 directory, in a new
       random order each epoch, saving a checkpoint on even epochs.

    Raises:
        Exception: if no CUDA device is available.
    """
    global opt, model, netContent
    print("SSSRNet5_deeplan training finetuning on VOC 160*160 patches.")
    opt = parser.parse_args()
    print(opt)
    gpuid = 0
    # CUDA use is hard-coded here rather than read from the command line.
    cuda = True
    if cuda and not torch.cuda.is_available():
        raise Exception("No GPU found, please run without --cuda")

    opt.seed = random.randint(1, 10000)
    print("Random Seed: ", opt.seed)
    torch.manual_seed(opt.seed)
    if cuda:
        torch.cuda.manual_seed(opt.seed)

    cudnn.benchmark = True

    if opt.vgg_loss:
        print('===> Loading VGG model')
        netVGG = models.vgg19()
        netVGG.load_state_dict(
            model_zoo.load_url(
                'https://download.pytorch.org/models/vgg19-dcbb9e9d.pth'))

        class _content_model(nn.Module):
            """VGG19 feature stack without its final child module."""

            def __init__(self):
                super(_content_model, self).__init__()
                self.feature = nn.Sequential(
                    *list(netVGG.features.children())[:-1])

            def forward(self, x):
                return self.feature(x)

        netContent = _content_model()

    print("===> Building model")
    # DeepLab network, used in evaluation mode only.
    deeplab_res = Res_Deeplab(num_classes=21)
    saved_state_dict = torch.load('model/VOC12_scenes_20000.pth')
    deeplab_res.load_state_dict(saved_state_dict)
    deeplab_res = deeplab_res.eval()
    mid = mid_layer()

    # SRResNet
    print("===> Building model")
    model = Net()
    print('Parameters: {}'.format(get_n_params(model)))
    model_pretrained = torch.load(
        'model/model_DIV2K_noBN_96_epoch_36.pth',
        map_location=lambda storage, loc: storage)["model"]
    finetune = True
    if finetune:
        # Keep the new model's first two parameter tensors as constructed
        # and take all remaining ones from the pretrained network.
        _copy_pretrained_weights(model_pretrained, model, skip=2)

    # size_average=False: legacy spelling for a summed squared error.
    criterion = nn.MSELoss(size_average=False)

    print("===> Setting GPU")
    if cuda:
        model = model.cuda(gpuid)
        mid = mid.cuda(gpuid)
        deeplab_res = deeplab_res.cuda(gpuid)
        criterion = criterion.cuda(gpuid)
        if opt.vgg_loss:
            netContent = netContent.cuda(gpuid)

    # optionally resume from a checkpoint
    if opt.resume:
        if os.path.isfile(opt.resume):
            print("=> loading checkpoint '{}'".format(opt.resume))
            checkpoint = torch.load(opt.resume)
            opt.start_epoch = checkpoint["epoch"] + 1
            model.load_state_dict(checkpoint["model"].state_dict())
        else:
            print("=> no checkpoint found at '{}'".format(opt.resume))

    # optionally copy weights from a checkpoint
    if opt.pretrained:
        if os.path.isfile(opt.pretrained):
            print("=> loading model '{}'".format(opt.pretrained))
            weights = torch.load(opt.pretrained)
            model.load_state_dict(weights['model'].state_dict())
        else:
            print("=> no model found at '{}'".format(opt.pretrained))

    print("===> Setting Optimizer")
    optimizer = optim.Adam(model.parameters(), lr=opt.lr)

    print("===> Training")
    # Alternative: '/tmp4/hang_data/DIV2K/DIV2K_train_320_HDF5'
    shard_dir = '/tmp4/hang_data/VOCdevkit/VOC2012/VOC_train_label160_HDF5'
    for epoch in range(opt.start_epoch, opt.nEpochs + 1):
        print("===> Loading datasets")
        # List once per epoch and size the sample from that same listing, so
        # the sample size can never exceed the population.
        shard_names = os.listdir(shard_dir)
        for shard_name in random.sample(shard_names, len(shard_names)):
            train_path = os.path.join(shard_dir, shard_name)
            print("===> Training datasets: '{}'".format(train_path))
            train_set = DatasetFromHdf5(train_path)
            training_data_loader = DataLoader(dataset=train_set,
                                              num_workers=opt.threads,
                                              batch_size=opt.batchSize,
                                              shuffle=True)
            train(training_data_loader, optimizer, deeplab_res,
                  model, mid, criterion, epoch, gpuid)
        # Checkpoints are written on even-numbered epochs only.
        if epoch % 2 == 0:
            save_checkpoint(model, epoch)