Example #1
def make_model(src_vocab,
               tgt_vocab,
               emb_size=256,
               hidden_size=512,
               num_layers=1,
               dropout=0.1):
    "Helper: Construct a model from hyperparameters."

    attention = model.BahdanauAttention(hidden_size)

    mdl = model.EncoderDecoder(
        model.Encoder(emb_size,
                      hidden_size,
                      num_layers=num_layers,
                      dropout=dropout),
        model.Decoder(emb_size,
                      hidden_size,
                      attention,
                      num_layers=num_layers,
                      dropout=dropout), nn.Embedding(src_vocab, emb_size),
        nn.Embedding(tgt_vocab, emb_size),
        model.Generator(hidden_size, tgt_vocab))

    return mdl.cuda() if USE_CUDA else mdl
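
A minimal usage sketch, assuming the surrounding module provides the same model, nn, and USE_CUDA names used above; the vocabulary sizes are hypothetical:

# hypothetical vocabulary sizes; any positive integers work
mdl = make_model(src_vocab=8000, tgt_vocab=8000,
                 emb_size=256, hidden_size=512,
                 num_layers=1, dropout=0.1)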
Example #2
def loadModel(path):
    encoderDecoder = model.EncoderDecoder()
    encoderDecoder.load_state_dict(torch.load(path))
    return encoderDecoder
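
For completeness, a saving counterpart is sketched below; like loadModel itself, it assumes model.EncoderDecoder can be constructed without arguments, and it persists only the weights:

import torch

def saveModel(encoderDecoder, path):
    # counterpart to loadModel: store just the state dict, not the whole module
    torch.save(encoderDecoder.state_dict(), path)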
Example #3
def random_cv(cv_index, cv_year, rootpath, param_grid, num_random, model_name,
              device, one_day):
    """Hyperparameter tuning through random search

    Args:
    cv_index: the month of the valiation set
    cv_year: the year of the valiation set
    rootpath: the path where training-validtion sets are saved
    param_grid: a dictionary, consisting the grid of hyperparameters
    num_randon: the number of sets of hyperparameters to evaluate(tune)
    model_name: a string representing the name of a model
    device: indicates if the model should be run on cpu or gpu
    one_day: True or False, indicating if only the most recent available day is used for training a model (XGBoost or Lasso)
    """
    # load data
    if model_name in ['CNN_LSTM', 'CNN_FNN']:
        train_X = joblib.load(
            rootpath +
            'train_X_map_{}_forecast{}.pkl'.format(cv_year, cv_index))
        valid_X = joblib.load(
            rootpath + 'val_X_map_{}_forecast{}.pkl'.format(cv_year, cv_index))
        train_y = load_results(
            rootpath +
            'train_y_pca_{}_forecast{}.pkl'.format(cv_year, cv_index))
        valid_y = load_results(
            rootpath + 'val_y_pca_{}_forecast{}.pkl'.format(cv_year, cv_index))
        output_dim = train_y.shape[-1]
    else:
        train_X = load_results(
            rootpath +
            'train_X_pca_{}_forecast{}.pkl'.format(cv_year, cv_index))
        valid_X = load_results(
            rootpath + 'val_X_pca_{}_forecast{}.pkl'.format(cv_year, cv_index))
        train_y = load_results(
            rootpath +
            'train_y_pca_{}_forecast{}.pkl'.format(cv_year, cv_index))
        valid_y = load_results(
            rootpath + 'val_y_pca_{}_forecast{}.pkl'.format(cv_year, cv_index))
        # set input and output dim
        input_dim = train_X.shape[-1]
        output_dim = train_y.shape[-1]

    if model_name == 'EncoderFNN_AllSeq_AR_CI' or model_name == 'EncoderFNN_AllSeq_AR':
        hidden_dim = param_grid['hidden_dim']
        num_layers = param_grid['num_layers']
        lr = param_grid['learning_rate']
        threshold = param_grid['threshold']
        num_epochs = param_grid['num_epochs']
        seq_len = param_grid['seq_len']
        linear_dim = param_grid['linear_dim']
        drop_out = param_grid['drop_out']
        if model_name == 'EncoderFNN_AllSeq_AR_CI':
            ci_dim = param_grid['ci_dim']

        train_y_ar = load_results(
            rootpath +
            'train_y_pca_ar_{}_forecast{}.pkl'.format(cv_year, cv_index))
        valid_y_ar = load_results(
            rootpath +
            'val_y_pca_ar_{}_forecast{}.pkl'.format(cv_year, cv_index))
        train_dataset = model.MapDataset_ar(train_X, train_y_ar, train_y)
        train_loader = DataLoader(dataset=train_dataset,
                                  batch_size=512,
                                  shuffle=False)

    elif model_name == 'EncoderDecoder' or model_name == 'EncoderFNN_AllSeq' or model_name == 'EncoderFNN':
        train_dataset = model.MapDataset(train_X, train_y)
        train_loader = DataLoader(dataset=train_dataset,
                                  batch_size=512,
                                  shuffle=False)
        hidden_dim = param_grid['hidden_dim']
        num_layers = param_grid['num_layers']
        lr = param_grid['learning_rate']
        threshold = param_grid['threshold']
        num_epochs = param_grid['num_epochs']
        if model_name == 'EncoderDecoder':
            decoder_len = param_grid['decoder_len']
        elif model_name == 'EncoderFNN':
            last_layer = param_grid['last_layer']
            seq_len = param_grid['seq_len']
        elif model_name == 'EncoderFNN_AllSeq':
            seq_len = param_grid['seq_len']
            linear_dim = param_grid['linear_dim']
            drop_out = param_grid['drop_out']
    elif model_name == 'XGBoost':
        if one_day is True:
            train_X = train_X[:, -1, :]  # one day
            valid_X = valid_X[:, -1, :]  # one day
        train_X = np.reshape(train_X, (train_X.shape[0], -1))
        valid_X = np.reshape(valid_X, (valid_X.shape[0], -1))
        max_depth = param_grid['max_depth']
        colsample_bytree = param_grid['colsample_bytree']
        gamma = param_grid['gamma']
        n_estimators = param_grid['n_estimators']
        lr = param_grid['learning_rate']
    elif model_name == 'Lasso':
        if one_day is True:
            train_X = train_X[:, -1, :]  # one day
            valid_X = valid_X[:, -1, :]  # one day
        train_X = np.reshape(train_X, (train_X.shape[0], -1))
        valid_X = np.reshape(valid_X, (valid_X.shape[0], -1))
        alphas = param_grid['alpha']
    elif model_name == 'FNN':
        if one_day is True:
            train_X = train_X[:, -1, :]  # one day
            valid_X = valid_X[:, -1, :]  # one day
        train_X = np.reshape(train_X, (train_X.shape[0], -1))
        valid_X = np.reshape(valid_X, (valid_X.shape[0], -1))
        train_dataset = model.MapDataset(train_X, train_y)
        train_loader = DataLoader(dataset=train_dataset,
                                  batch_size=512,
                                  shuffle=False)
        hidden_dim = param_grid['hidden_dim']
        num_layers = param_grid['num_layers']
    elif model_name == 'CNN_FNN':
        train_dataset = model.MapDataset_CNN(train_X, train_y)
        train_loader = DataLoader(dataset=train_dataset,
                                  batch_size=50,
                                  shuffle=False)
        stride = param_grid['stride']
        kernel_size = param_grid['kernel_size']
        hidden_dim = param_grid['hidden_dim']
        num_layers = param_grid['num_layers']
    elif model_name == 'CNN_LSTM':
        train_dataset = model.MapDataset_CNN(train_X, train_y)
        train_loader = DataLoader(dataset=train_dataset,
                                  batch_size=50,
                                  shuffle=False)
        stride = param_grid['module__stride']
        kernel_size = param_grid['module__kernel_size']
        hidden_dim = param_grid['module__hidden_dim']
        num_lstm_layers = param_grid['module__num_lstm_layers']
        lr = param_grid['lr']
        num_epochs = param_grid['module__num_epochs']
    else:
        print('the model name is not in the list')

    history_all = []
    score = []
    parameter_all = []
    for i in range(num_random):
        # set model
        if model_name == 'EncoderDecoder':
            curr_hidden_dim = hidden_dim[randint(0, len(hidden_dim) - 1)]
            curr_num_layer = num_layers[randint(0, len(num_layers) - 1)]
            curr_decoder_len = decoder_len[randint(0, len(decoder_len) - 1)]
            curr_threshold = threshold[randint(0, len(threshold) - 1)]
            curr_lr = lr[randint(0, len(lr) - 1)]
            curr_num_epochs = num_epochs[randint(0, len(num_epochs) - 1)]
            parameters = {
                'hidden_dim': curr_hidden_dim,
                'num_layers': curr_num_layer,
                'decoder_len': curr_decoder_len,
                'threshold': curr_threshold,
                'learning_rate': curr_lr,
                'num_epochs': curr_num_epochs
            }
            parameter_all.append(parameters)
            mdl = model.EncoderDecoder(input_dim=input_dim,
                                       output_dim=output_dim,
                                       hidden_dim=curr_hidden_dim,
                                       num_layers=curr_num_layer,
                                       learning_rate=curr_lr,
                                       decoder_len=curr_decoder_len,
                                       threshold=curr_threshold,
                                       num_epochs=curr_num_epochs)

            # initialize the model
            model.init_weight(mdl)

            # send model to gpu
            mdl.to(device)
            # fit the model
            history = mdl.fit_cv(train_loader, valid_X, valid_y, device)
            # compute the prediction of validation set
            pred_y = mdl.predict(valid_X, device)
        elif model_name == 'EncoderFNN':
            curr_hidden_dim = hidden_dim[randint(0, len(hidden_dim) - 1)]
            curr_num_layer = num_layers[randint(0, len(num_layers) - 1)]
            curr_seq_len = seq_len[randint(0, len(seq_len) - 1)]
            curr_threshold = threshold[randint(0, len(threshold) - 1)]
            curr_lr = lr[randint(0, len(lr) - 1)]
            curr_num_epochs = num_epochs[randint(0, len(num_epochs) - 1)]
            curr_last_layer = last_layer[randint(0, len(last_layer) - 1)]
            parameters = {
                'hidden_dim': curr_hidden_dim,
                'num_layers': curr_num_layer,
                'last_layer': curr_last_layer,
                'threshold': curr_threshold,
                'learning_rate': curr_lr,
                'num_epochs': curr_num_epochs,
                'seq_len': curr_seq_len
            }
            parameter_all.append(parameters)
            mdl = model.EncoderFNN(input_dim=input_dim,
                                   output_dim=output_dim,
                                   hidden_dim=curr_hidden_dim,
                                   num_layers=curr_num_layer,
                                   last_layer=curr_last_layer,
                                   seq_len=curr_seq_len,
                                   learning_rate=curr_lr,
                                   threshold=curr_threshold,
                                   num_epochs=curr_num_epochs)
            # initialize the model
            model.init_weight(mdl)

            # send model to gpu
            mdl.to(device)
            # fit the model
            history = mdl.fit_cv(train_loader, valid_X, valid_y, device)
            # compute the prediction of validation set
            pred_y = mdl.predict(valid_X, device)
        elif model_name == 'EncoderFNN_AllSeq':
            curr_hidden_dim = hidden_dim[randint(0, len(hidden_dim) - 1)]
            curr_num_layer = num_layers[randint(0, len(num_layers) - 1)]
            curr_seq_len = seq_len[randint(0, len(seq_len) - 1)]
            curr_threshold = threshold[randint(0, len(threshold) - 1)]
            curr_lr = lr[randint(0, len(lr) - 1)]
            curr_num_epochs = num_epochs[randint(0, len(num_epochs) - 1)]
            curr_linear_dim = linear_dim[randint(0, len(linear_dim) - 1)]
            curr_drop_out = drop_out[randint(0, len(drop_out) - 1)]
            parameters = {
                'hidden_dim': curr_hidden_dim,
                'num_layers': curr_num_layer,
                'linear_dim': curr_linear_dim,
                'threshold': curr_threshold,
                'learning_rate': curr_lr,
                'num_epochs': curr_num_epochs,
                'seq_len': curr_seq_len,
                'drop_out': curr_drop_out
            }
            parameter_all.append(parameters)

            mdl = model.EncoderFNN_AllSeq(input_dim=input_dim,
                                          output_dim=output_dim,
                                          hidden_dim=curr_hidden_dim,
                                          num_layers=curr_num_layer,
                                          seq_len=curr_seq_len,
                                          linear_dim=curr_linear_dim,
                                          learning_rate=curr_lr,
                                          dropout=curr_drop_out,
                                          threshold=curr_threshold,
                                          num_epochs=curr_num_epochs)
            # initialize the model
            model.init_weight(mdl)

            # send model to gpu
            mdl.to(device)
            # fit the model
            history = mdl.fit_cv(train_loader, valid_X, valid_y, device)
            # compute the prediction of validation set
            pred_y = mdl.predict(valid_X, device)
        elif model_name == 'EncoderFNN_AllSeq_AR':
            curr_hidden_dim = hidden_dim[randint(0, len(hidden_dim) - 1)]
            curr_num_layer = num_layers[randint(0, len(num_layers) - 1)]
            curr_seq_len = seq_len[randint(0, len(seq_len) - 1)]
            curr_threshold = threshold[randint(0, len(threshold) - 1)]
            curr_lr = lr[randint(0, len(lr) - 1)]
            curr_num_epochs = num_epochs[randint(0, len(num_epochs) - 1)]
            curr_linear_dim = linear_dim[randint(0, len(linear_dim) - 1)]
            curr_drop_out = drop_out[randint(0, len(drop_out) - 1)]
            parameters = {
                'hidden_dim': curr_hidden_dim,
                'num_layers': curr_num_layer,
                'linear_dim': curr_linear_dim,
                'threshold': curr_threshold,
                'learning_rate': curr_lr,
                'num_epochs': curr_num_epochs,
                'seq_len': curr_seq_len,
                'drop_out': curr_drop_out
            }
            parameter_all.append(parameters)

            mdl = model.EncoderFNN_AllSeq_AR(input_dim=input_dim,
                                             output_dim=output_dim,
                                             hidden_dim=curr_hidden_dim,
                                             num_layers=curr_num_layer,
                                             seq_len=curr_seq_len,
                                             linear_dim=curr_linear_dim,
                                             learning_rate=curr_lr,
                                             dropout=curr_drop_out,
                                             threshold=curr_threshold,
                                             num_epochs=curr_num_epochs)
            # initialize the model
            model.init_weight(mdl)

            # send model to gpu
            mdl.to(device)
            # fit the model
            history = mdl.fit_cv(train_loader, valid_X, valid_y_ar, valid_y,
                                 device)
            # compute the prediction of validation set
            pred_y = mdl.predict(valid_X, valid_y_ar, device)
        elif model_name == 'EncoderFNN_AllSeq_AR_CI':
            curr_hidden_dim = hidden_dim[randint(0, len(hidden_dim) - 1)]
            curr_num_layer = num_layers[randint(0, len(num_layers) - 1)]
            curr_seq_len = seq_len[randint(0, len(seq_len) - 1)]
            curr_threshold = threshold[randint(0, len(threshold) - 1)]
            curr_lr = lr[randint(0, len(lr) - 1)]
            curr_num_epochs = num_epochs[randint(0, len(num_epochs) - 1)]
            curr_linear_dim = linear_dim[randint(0, len(linear_dim) - 1)]
            curr_drop_out = drop_out[randint(0, len(drop_out) - 1)]
            parameters = {
                'hidden_dim': curr_hidden_dim,
                'num_layers': curr_num_layer,
                'linear_dim': curr_linear_dim,
                'threshold': curr_threshold,
                'learning_rate': curr_lr,
                'num_epochs': curr_num_epochs,
                'seq_len': curr_seq_len,
                'drop_out': curr_drop_out,
                'ci_dim': ci_dim
            }
            parameter_all.append(parameters)

            mdl = model.EncoderFNN_AllSeq_AR_CI(input_dim=input_dim - ci_dim,
                                                output_dim=output_dim,
                                                hidden_dim=curr_hidden_dim,
                                                num_layers=curr_num_layer,
                                                seq_len=curr_seq_len,
                                                linear_dim=curr_linear_dim,
                                                ci_dim=ci_dim,
                                                learning_rate=curr_lr,
                                                dropout=curr_drop_out,
                                                threshold=curr_threshold,
                                                num_epochs=curr_num_epochs)
            # initialize the model
            model.init_weight(mdl)

            # send model to gpu
            mdl.to(device)
            # fit the model
            history = mdl.fit_cv(train_loader, valid_X, valid_y_ar, valid_y,
                                 device)
            pred_y = mdl.predict(valid_X, valid_y_ar, device)
        elif model_name == 'XGBoost':
            curr_max_depth = max_depth[randint(0, len(max_depth) - 1)]
            curr_colsample_bytree = colsample_bytree[randint(
                0,
                len(colsample_bytree) - 1)]
            curr_gamma = gamma[randint(0, len(gamma) - 1)]
            curr_n_estimators = n_estimators[randint(0, len(n_estimators) - 1)]
            curr_lr = lr[randint(0, len(lr) - 1)]
            parameters = {
                'max_depth': curr_max_depth,
                'colsample_bytree': curr_colsample_bytree,
                'gamma': curr_gamma,
                'n_estimators': curr_n_estimators,
                'learning_rate': curr_lr
            }
            parameter_all.append(parameters)
            mdl = model.XGBMultitask(num_models=output_dim,
                                     colsample_bytree=curr_colsample_bytree,
                                     gamma=curr_gamma,
                                     learning_rate=curr_lr,
                                     max_depth=curr_max_depth,
                                     n_estimators=curr_n_estimators,
                                     objective='reg:squarederror')
            # history = mdl.fit_cv(train_X, train_y, valid_X, valid_y)
            mdl.fit(train_X, train_y)
            pred_y = mdl.predict(valid_X)
            history = None
        elif model_name == 'Lasso':
            curr_alpha = alphas[randint(0, len(alphas) - 1)]
            parameter = {'alpha': curr_alpha}
            parameter_all.append(parameter)
            mdl = model.LassoMultitask(alpha=curr_alpha, fit_intercept=False)
            mdl.fit(train_X, train_y)
            pred_y = mdl.predict(valid_X)
            history = None
        elif model_name == 'FNN':
            curr_hidden_dim = hidden_dim[randint(0, len(hidden_dim) - 1)]
            curr_num_layers = num_layers[randint(0, len(num_layers) - 1)]
            parameters = {
                'hidden_dim': curr_hidden_dim,
                'num_layers': curr_num_layers
            }
            parameter_all.append(parameters)
            mdl = model.ReluNet(input_dim=input_dim,
                                output_dim=output_dim,
                                hidden_dim=curr_hidden_dim,
                                num_layers=curr_num_layers,
                                threshold=0.1,
                                num_epochs=1000)
            model.init_weight(mdl)
            mdl.to(device)
            history = mdl.fit_cv(train_loader, valid_X, valid_y, device)
            pred_y = mdl.predict(valid_X, device)
        elif model_name == 'CNN_FNN':
            curr_stride = stride[randint(0, len(stride) - 1)]
            curr_kernel_size = kernel_size[randint(0, len(kernel_size) - 1)]
            curr_hidden_dim = hidden_dim[randint(0, len(hidden_dim) - 1)]
            curr_num_layers = num_layers[randint(0, len(num_layers) - 1)]
            parameters = {
                'stride': curr_stride,
                'kernel_size': curr_kernel_size,
                'hidden_dim': curr_hidden_dim,
                'num_layers': curr_num_layers
            }
            parameter_all.append(parameters)
            num_var = len(train_X)
            input_dim = model.get_input_dim(train_X, num_var, curr_stride,
                                            curr_kernel_size)
            mdl = model.CnnFnn(num_var,
                               input_dim,
                               output_dim,
                               kernel_size=curr_kernel_size,
                               stride=curr_stride,
                               hidden_dim=curr_hidden_dim,
                               num_layers=curr_num_layers,
                               num_epochs=100)
            mdl.to(device)
            history = mdl.fit_cv(train_loader, valid_X, valid_y, device)
            pred_y = mdl.predict(valid_X, device)
        elif model_name == 'CNN_LSTM':
            curr_stride = stride[randint(0, len(stride) - 1)]
            curr_kernel_size = kernel_size[randint(0, len(kernel_size) - 1)]
            curr_hidden_dim = hidden_dim[randint(0, len(hidden_dim) - 1)]
            curr_num_layers = num_lstm_layers[randint(0,
                                                      len(num_lstm_layers) -
                                                      1)]
            curr_lr = lr[randint(0, len(lr) - 1)]
            curr_num_epochs = num_epochs[randint(0, len(num_epochs) - 1)]
            parameters = {
                'stride': curr_stride,
                'kernel_size': curr_kernel_size,
                'hidden_dim': curr_hidden_dim,
                'num_layers': curr_num_layers,
                'learning_rate': curr_lr,
                'num_epochs': curr_num_epochs
            }
            parameter_all.append(parameters)
            num_var = len(train_X)
            input_dim = model.get_input_dim(train_X, num_var, curr_stride,
                                            curr_kernel_size)
            mdl = model.CnnLSTM(num_var,
                                input_dim,
                                output_dim,
                                kernel_size=curr_kernel_size,
                                stride=curr_stride,
                                hidden_dim=curr_hidden_dim,
                                num_lstm_layers=curr_num_layers,
                                num_epochs=curr_num_epochs,
                                learning_rate=curr_lr)
            mdl.to(device)
            history = mdl.fit_cv(train_loader, valid_X, valid_y, device)
            pred_y = mdl.predict(valid_X, device)

        history_all.append(history)
        test_rmse = np.sqrt(((valid_y - pred_y)**2).mean())
        test_cos = np.asarray([
            compute_cosine(valid_y[i, :], pred_y[i, :])
            for i in range(len(valid_y))
        ]).mean()
        score.append([test_rmse, test_cos])

    cv_results = {
        'score': score,
        'parameter_all': parameter_all,
        'history_all': history_all
    }
    save_results(
        rootpath + 'cv_results_test/cv_results_' + model_name +
        '_{}_{}.pkl'.format(cv_year, cv_index), cv_results)
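
Every entry of param_grid is expected to be a list of candidate values, from which one value per trial is drawn with randint. An illustrative grid and call for the 'Lasso' branch (the path and values here are hypothetical):

# each hyperparameter maps to a list of candidates to sample from
param_grid = {'alpha': [0.01, 0.1, 1.0, 10.0]}
random_cv(cv_index=1, cv_year=2017, rootpath='/path/to/data/',
          param_grid=param_grid, num_random=4, model_name='Lasso',
          device='cpu', one_day=True)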
Example #4
def test():

    imTransform = transforms.Compose([
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    testLoader = helper.PreprocessInFocusOnlyData(imTransform, rootDirTest,
                                                  'inputs', 'masks', batchSize,
                                                  False)
    sliceNames = ('outputSlice1', 'outputSlice10')
    numSlices = len(sliceNames)

    resultInputDir = rootDirResults + 'InputAIF/'
    resultDirInFocus = rootDirResults + 'OutputAIF/'
    resultDirSlices = rootDirResults + 'SliceOutputs/'
    resultDirSegMasks = rootDirResults + 'GroundTruthMasks/'

    extension = '.pth'

    loadRootDir = rootDirLoad
    epoch = loadEpochNum
    model1LoadPath = loadRootDir + 'model1/'
    model2LoadPath = loadRootDir + 'model2/'
    model3LoadPath = loadRootDir + 'model3/'

    print(model1LoadPath, model2LoadPath, model3LoadPath)

    model1 = model.EncoderDecoder(4, 6 * numSlices).to(device)
    model1.load_state_dict(torch.load(model1LoadPath + str(epoch) + extension))
    model1.eval()

    model2 = model.DepthToInFocus(3).to(device)
    model2.load_state_dict(torch.load(model2LoadPath + str(epoch) + extension))
    model2.eval()

    model3 = model.EncoderDecoderv2(3, 1).to(device)
    model3.load_state_dict(torch.load(model3LoadPath + str(epoch) + extension))
    model3.eval()

    count = 0

    for step, sample in enumerate(testLoader):
        print(step)
        inData = sample['Input'].to(device)
        print(inData.shape)
        segMask = sample['Mask'].to(device)
        blurOut = model1(inData, segMask)
        count += 1

        for dispIndex in range(len(blurOut)):

            dispOut = blurOut[dispIndex]

            inFocus = F.interpolate(inData,
                                    size=None,
                                    scale_factor=1 / pow(2.0, dispIndex),
                                    mode='bilinear')

            blurMapSlices = helper.BlurMapsToSlices(dispOut, numSlices, 6)
            inFocusRegionSlices = []
            nearFocusOutSlices = []
            binaryBlurMapSlices = []

            for slices in range(numSlices):

                nearFocusOutSlice = model2(inFocus, blurMapSlices[slices])
                binaryBlurMapSlice = model3(nearFocusOutSlice)
                binaryBlurMapSlice = torch.cat(
                    (binaryBlurMapSlice, binaryBlurMapSlice,
                     binaryBlurMapSlice),
                    dim=1)
                inFocusRegionSlice = torch.mul(nearFocusOutSlice,
                                               binaryBlurMapSlice)
                inFocusRegionSlices.append(inFocusRegionSlice)
                nearFocusOutSlices.append(nearFocusOutSlice)
                binaryBlurMapSlices.append(binaryBlurMapSlice)

            inFocusOutput = helper.CombineFocusRegionsToAIF(
                inFocusRegionSlices, numSlices, device)

            if dispIndex == 0:

                inFocusOutDispOut = helper.TensorToDispImage(inFocusOutput, 3)
                inDataDispOut = helper.TensorToDispImage(inFocus, 3)
                inSegMask = F.interpolate(segMask,
                                          size=None,
                                          scale_factor=1 / pow(2.0, dispIndex),
                                          mode='bilinear')

                save_image(inFocusOutDispOut,
                           resultDirInFocus + str(step) + ".png",
                           nrow=8,
                           padding=2,
                           normalize=False,
                           range=None,
                           scale_each=False,
                           pad_value=0)
                save_image(inDataDispOut,
                           resultInputDir + str(step) + ".png",
                           nrow=8,
                           padding=2,
                           normalize=False,
                           range=None,
                           scale_each=False,
                           pad_value=0)
                save_image(inSegMask,
                           resultDirSegMasks + str(step) + ".png",
                           nrow=8,
                           padding=2,
                           normalize=False,
                           range=None,
                           scale_each=False,
                           pad_value=0)

                for indices in range(numSlices):

                    dirToWrite = resultDirSlices + sliceNames[indices] + '/'
                    sliceDir = dirToWrite + 'FocalSlices/'
                    blurMapDir = dirToWrite + 'BlurMaps/'
                    binaryMapDir = dirToWrite + 'BinaryMaps/'
                    inFocusRegionDir = dirToWrite + 'InFocusRegion/'

                    focalSliceDispOut = helper.TensorToDispImage(
                        nearFocusOutSlices[indices], 3)
                    binaryMapDispOut = binaryBlurMapSlices[indices]
                    inFocusRegionDispOut = helper.TensorToDispImage(
                        inFocusRegionSlices[indices], 3)

                    displayOut = blurMapSlices[indices].mean(1)
                    displayOutDispOut = displayOut.new(*displayOut.size())
                    displayOutDispOut[
                        0, :, :] = displayOut[0, :, :] * 0.5 + 0.5

                    save_image(focalSliceDispOut,
                               sliceDir + str(step) + ".png",
                               nrow=8,
                               padding=2,
                               normalize=False,
                               range=None,
                               scale_each=False,
                               pad_value=0)
                    save_image(displayOutDispOut,
                               blurMapDir + str(step) + ".png",
                               nrow=8,
                               padding=2,
                               normalize=False,
                               range=None,
                               scale_each=False,
                               pad_value=0)
                    save_image(binaryMapDispOut,
                               binaryMapDir + str(step) + ".png",
                               nrow=8,
                               padding=2,
                               normalize=False,
                               range=None,
                               scale_each=False,
                               pad_value=0)
                    save_image(inFocusRegionDispOut,
                               inFocusRegionDir + str(step) + ".png",
                               nrow=8,
                               padding=2,
                               normalize=False,
                               range=None,
                               scale_each=False,
                               pad_value=0)
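
Note that test() puts the three models in eval() mode but never disables autograd, so a cautious caller would wrap the call (a sketch, not part of the original script):

# inference only: skip building autograd graphs to save memory
with torch.no_grad():
    test()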
Example #5
def forecast_rep(month_id, year, rootpath, param_path, device, model_name,
                 num_rep):
    """Run encoder-decoder style models with repetition - results are saved in a folder named forecast_results
    Args:
    month_id: an int indicating the month which is being forecasted
    year: an int indicating the year which is being forecasted
    rootpath: the path where the training and test sets are saved
    param_path: the path where the best hyperparameters are saved
    device: an indication if the model is runing on GPU or CPU
    model_name: a string indicating the name of a model
    num_rep: the number of repetition
    """
    results = {}
    results['prediction_train'] = []
    results['prediction_test'] = []
    train_X = load_results(
        rootpath + 'train_X_pca_{}_forecast{}.pkl'.format(year, month_id))
    test_X = load_results(
        rootpath + 'test_X_pca_{}_forecast{}.pkl'.format(year, month_id))
    train_y = load_results(
        rootpath + 'train_y_pca_{}_forecast{}.pkl'.format(year, month_id))
    test_y = load_results(
        rootpath + 'test_y_pca_{}_forecast{}.pkl'.format(year, month_id))

    if model_name == 'EncoderFNN_AllSeq_AR_CI' or model_name == 'EncoderFNN_AllSeq_AR':
        train_y_ar = load_results(
            rootpath +
            'train_y_pca_ar_{}_forecast{}.pkl'.format(year, month_id))
        test_y_ar = load_results(
            rootpath +
            'test_y_pca_ar_{}_forecast{}.pkl'.format(year, month_id))

    input_dim = train_X.shape[-1]
    output_dim = train_y.shape[-1]
    # ar_dim = train_X.shape[1]
    best_parameter = load_results(
        param_path + '{}_forecast{}.pkl'.format(model_name, month_id))
    for rep in range(num_rep):
        if model_name == 'EncoderDecoder':
            curr_hidden_dim = best_parameter['hidden_dim']
            curr_num_layer = best_parameter['num_layers']
            curr_decoder_len = best_parameter['decoder_len']
            curr_threshold = best_parameter['threshold']
            curr_lr = best_parameter['learning_rate']
            curr_num_epochs = best_parameter['num_epochs']
            mdl = model.EncoderDecoder(input_dim=input_dim,
                                       output_dim=output_dim,
                                       hidden_dim=curr_hidden_dim,
                                       num_layers=curr_num_layer,
                                       learning_rate=curr_lr,
                                       decoder_len=curr_decoder_len,
                                       threshold=curr_threshold,
                                       num_epochs=curr_num_epochs)

            # set data for training
            train_dataset = model.MapDataset(train_X, train_y)
            train_loader = DataLoader(dataset=train_dataset,
                                      batch_size=512,
                                      shuffle=False)
            model.init_weight(mdl)
            # send model to gpu
            mdl.to(device)
            mdl.fit(train_loader, device)
            state = {'state_dict': mdl.state_dict()}
            torch.save(
                state, rootpath +
                'models/{}_{}_{}.t7'.format(model_name, year, month_id))
            pred_train = mdl.predict(train_X, device)
            pred_y = mdl.predict(test_X, device)
        elif model_name == 'EncoderFNN':
            curr_hidden_dim = best_parameter['hidden_dim']
            curr_num_layer = best_parameter['num_layers']
            curr_seq_len = best_parameter['seq_len']
            curr_threshold = best_parameter['threshold']
            curr_lr = best_parameter['learning_rate']
            curr_num_epochs = best_parameter['num_epochs']
            curr_last_layer = best_parameter['last_layer']
            mdl = model.EncoderFNN(input_dim=input_dim,
                                   output_dim=output_dim,
                                   hidden_dim=curr_hidden_dim,
                                   num_layers=curr_num_layer,
                                   last_layer=curr_last_layer,
                                   seq_len=curr_seq_len,
                                   learning_rate=curr_lr,
                                   threshold=curr_threshold,
                                   num_epochs=curr_num_epochs)
            # set data for training
            train_dataset = model.MapDataset(train_X, train_y)
            train_loader = DataLoader(dataset=train_dataset,
                                      batch_size=512,
                                      shuffle=False)
            model.init_weight(mdl)
            # send model to gpu
            mdl.to(device)
            mdl.fit(train_loader, device)
            state = {'state_dict': mdl.state_dict()}
            torch.save(
                state, rootpath +
                'models/{}_{}_{}.t7'.format(model_name, year, month_id))
            pred_train = mdl.predict(train_X, device)
            pred_y = mdl.predict(test_X, device)
        elif model_name == 'EncoderFNN_AllSeq':
            curr_hidden_dim = best_parameter['hidden_dim']
            curr_num_layer = best_parameter['num_layers']
            curr_seq_len = best_parameter['seq_len']
            curr_threshold = best_parameter['threshold']
            curr_lr = best_parameter['learning_rate']
            curr_num_epochs = best_parameter['num_epochs']
            curr_linear_dim = best_parameter['linear_dim']
            curr_drop_out = best_parameter['drop_out']

            mdl = model.EncoderFNN_AllSeq(input_dim=input_dim,
                                          output_dim=output_dim,
                                          hidden_dim=curr_hidden_dim,
                                          num_layers=curr_num_layer,
                                          seq_len=curr_seq_len,
                                          linear_dim=curr_linear_dim,
                                          learning_rate=curr_lr,
                                          dropout=curr_drop_out,
                                          threshold=curr_threshold,
                                          num_epochs=curr_num_epochs)
            # set data for training
            train_dataset = model.MapDataset(train_X, train_y)
            train_loader = DataLoader(dataset=train_dataset,
                                      batch_size=512,
                                      shuffle=False)
            model.init_weight(mdl)
            # send model to gpu
            mdl.to(device)
            mdl.fit(train_loader, device)
            state = {'state_dict': mdl.state_dict()}
            torch.save(
                state, rootpath +
                'models/{}_{}_{}.t7'.format(model_name, year, month_id))
            pred_train = mdl.predict(train_X, device)
            pred_y = mdl.predict(test_X, device)
        elif model_name == 'EncoderFNN_AllSeq_AR':
            curr_hidden_dim = best_parameter['hidden_dim']
            curr_num_layer = best_parameter['num_layers']
            curr_seq_len = best_parameter['seq_len']
            curr_threshold = best_parameter['threshold']
            curr_lr = best_parameter['learning_rate']
            curr_num_epochs = best_parameter['num_epochs']
            curr_linear_dim = best_parameter['linear_dim']
            curr_drop_out = best_parameter['drop_out']
            mdl = model.EncoderFNN_AllSeq_AR(input_dim=input_dim,
                                             output_dim=output_dim,
                                             hidden_dim=curr_hidden_dim,
                                             num_layers=curr_num_layer,
                                             seq_len=curr_seq_len,
                                             linear_dim=curr_linear_dim,
                                             learning_rate=curr_lr,
                                             dropout=curr_drop_out,
                                             threshold=curr_threshold,
                                             num_epochs=curr_num_epochs)
            train_dataset = model.MapDataset_ar(train_X, train_y_ar, train_y)
            train_loader = DataLoader(dataset=train_dataset,
                                      batch_size=512,
                                      shuffle=False)
            model.init_weight(mdl)
            # send model to gpu
            mdl.to(device)
            mdl.fit(train_loader, device)
            state = {'state_dict': mdl.state_dict()}
            torch.save(
                state, rootpath +
                'models/{}_{}_{}.t7'.format(model_name, year, month_id))
            pred_train = mdl.predict(train_X, train_y_ar, device)
            pred_y = mdl.predict(test_X, test_y_ar, device)
        elif model_name == 'EncoderFNN_AllSeq_AR_CI':
            curr_hidden_dim = best_parameter['hidden_dim']
            curr_num_layer = best_parameter['num_layers']
            curr_seq_len = best_parameter['seq_len']
            curr_threshold = best_parameter['threshold']
            curr_lr = best_parameter['learning_rate']
            curr_num_epochs = best_parameter['num_epochs']
            curr_linear_dim = best_parameter['linear_dim']
            curr_drop_out = best_parameter['drop_out']
            ci_dim = best_parameter['ci_dim']
            mdl = model.EncoderFNN_AllSeq_AR_CI(input_dim=input_dim - ci_dim,
                                                output_dim=output_dim,
                                                hidden_dim=curr_hidden_dim,
                                                num_layers=curr_num_layer,
                                                seq_len=curr_seq_len,
                                                linear_dim=curr_linear_dim,
                                                ci_dim=ci_dim,
                                                learning_rate=curr_lr,
                                                dropout=curr_drop_out,
                                                threshold=curr_threshold,
                                                num_epochs=curr_num_epochs)
            train_dataset = model.MapDataset_ar(train_X, train_y_ar, train_y)
            train_loader = DataLoader(dataset=train_dataset,
                                      batch_size=512,
                                      shuffle=False)
            model.init_weight(mdl)
            # send model to gpu
            mdl.to(device)
            mdl.fit(train_loader, device)
            state = {'state_dict': mdl.state_dict()}
            torch.save(
                state, rootpath +
                'models/{}_{}_{}.t7'.format(model_name, year, month_id))
            pred_train = mdl.predict(train_X, train_y_ar, device)
            pred_y = mdl.predict(test_X, test_y_ar, device)
        results['target_train'] = train_y
        results['prediction_train'].append(pred_train)
        results['target_test'] = test_y
        results['prediction_test'].append(pred_y)

    save_results(
        rootpath + 'forecast_results/results_{}_{}_{}.pkl'.format(
            model_name, year, month_id), results)
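
An illustrative call (the paths are hypothetical); note that rootpath must already contain 'models/' and 'forecast_results/' subdirectories, since the function writes into both:

forecast_rep(month_id=1, year=2017, rootpath='/path/to/data/',
             param_path='/path/to/params/', device='cuda',
             model_name='EncoderDecoder', num_rep=10)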
Example #6
def ids2words(lang, ids):
    return [lang.index2word[idx] for idx in ids]

def greedy_decode(model, dataloader, input_lang, output_lang):
    with torch.no_grad():
        batch = next(iter(dataloader))
        input_tensor  = batch[0]
        input_mask    = batch[1]
        target_tensor = batch[2]

        decoder_outputs, decoder_hidden = model(input_tensor, input_mask)
        topv, topi = decoder_outputs.topk(1)
        decoded_ids = topi.squeeze()

        for idx in range(input_tensor.size(0)):
            input_sent = ids2words(input_lang, input_tensor[idx].cpu().numpy())
            output_sent = ids2words(output_lang, decoded_ids[idx].cpu().numpy())
            target_sent = ids2words(output_lang, target_tensor[idx].cpu().numpy())
            print('Input:  {}'.format(input_sent))
            print('Target: {}'.format(target_sent))
            print('Output: {}'.format(output_sent))


if __name__ == '__main__':
    input_lang, output_lang, train_dataloader = load_data.get_dataloader(batch_size)
    model = model.EncoderDecoder(hidden_size, input_lang.n_words, output_lang.n_words).to(device)
    train(train_dataloader, model, n_epochs=20)
    greedy_decode(model, train_dataloader, input_lang, output_lang)
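
One detail greedy_decode leaves open: the decoded ids run to the decoder's full length, so the printed output also contains whatever follows the end-of-sentence token. A small helper to trim at EOS (the surface form 'EOS' is an assumption; use whatever token the vocabulary actually defines):

def trim_at_eos(words, eos_word='EOS'):
    # keep everything up to and including the first EOS, if present
    return words[:words.index(eos_word) + 1] if eos_word in words else words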


Example #7
def train(epochs):
    imTransform = transforms.Compose([
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    trainLoader = helper.PreprocessInFocusOnlyData(imTransform, rootDirTrain,
                                                   batchSize, True)
    validLoader = helper.PreprocessInFocusOnlyData(imTransform, rootDirValid,
                                                   batchSize, True)
    slices = ('outputSlice1', 'outputSlice4', 'outputSlice7', 'outputSlice10')
    numSlices = len(slices)

    rootDirToSave = rootDirSave
    model1SaveDir = rootDirToSave + 'model1/'
    model2SaveDir = rootDirToSave + 'model2/'
    model3SaveDir = rootDirToSave + 'model3/'
    extension = '.pth'

    loadRootDir = rootDirSave
    model1LoadPath = loadRootDir + 'model1/'
    model2LoadPath = loadRootDir + 'model2/'
    model3LoadPath = loadRootDir + 'model3/'

    model1 = model.EncoderDecoder(3, 6 * numSlices).to(device)
    model2 = model.DepthToInFocus(3).to(device)
    model3 = model.EncoderDecoderv2(3, 1).to(device)

    model1.load_state_dict(
        torch.load(model1LoadPath + str(loadEpochNum) + extension))
    model2.load_state_dict(
        torch.load(model2LoadPath + str(loadEpochNum) + extension))
    model3.load_state_dict(
        torch.load(model3LoadPath + str(loadEpochNum) + extension))

    contentLoss = losses.PerceptualLoss()
    contentLoss.initialize(nn.MSELoss())

    trainLossArray = np.zeros(epochs)
    validLossArray = np.zeros(epochs)

    lr = learningRate
    minValidEpoch = 0
    minValidLoss = float('inf')

    for epoch in range(epochs):
        start = time.perf_counter()

        epochLoss = np.zeros(2)
        epochCount = epoch + 1

        print("Epoch num : " + str(epochCount))
        if epoch >= 20 and epoch % 20 == 0:
            lr = lr / 2
            print('Learning rate changed to ' + str(lr))

        # build the optimizer once per epoch so the current learning rate takes
        # effect; constructing it anew at every step would reset Adam's moment
        # estimates
        optimizer = optim.Adam(list(model1.parameters()) +
                               list(model2.parameters()) +
                               list(model3.parameters()),
                               lr=lr,
                               betas=(beta1, beta2),
                               eps=1e-08,
                               weight_decay=0,
                               amsgrad=False)

        for step, sample in enumerate(trainLoader):

            loss = torch.zeros(1).to(device)
            loss.requires_grad = False

            data = sample.to(device)

            blurOut = model1(data)

            for dispIndex in range(len(blurOut)):

                scaleLoss = torch.zeros(1).to(device)
                scaleLoss.requires_grad = False

                dispOut = blurOut[dispIndex]

                inFocus = F.interpolate(data,
                                        size=None,
                                        scale_factor=1 / pow(2.0, dispIndex),
                                        mode='bilinear')

                blurMapSlices = helper.BlurMapsToSlices(dispOut, numSlices, 6)
                inFocusRegionSlices = []

                for slices in range(numSlices):

                    nearFocusOutSlice = model2(inFocus, blurMapSlices[slices])
                    binaryBlurMapSlice = model3(nearFocusOutSlice)

                    inFocusRegionSlice = torch.mul(nearFocusOutSlice,
                                                   binaryBlurMapSlice)
                    inFocusRegionSlices.append(inFocusRegionSlice)

                inFocusOutput = helper.CombineFocusRegionsToAIF(
                    inFocusRegionSlices, numSlices, device)

                perceptLoss = contentLoss.get_loss(inFocus, inFocusOutput)
                simiLoss = losses.similarityLoss(inFocus, inFocusOutput, alpha,
                                                 winSize)
                scaleLoss += perceptLossW * perceptLoss + simiLossW * simiLoss

                loss += scaleLoss

            epochLoss[0] += loss.item()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if step % 500 == 0:
                print("Step num :" + str(step) + " Train Loss :",
                      str(loss.item()))

        with torch.no_grad():
            for validStep, validSample in enumerate(validLoader):

                validLoss = torch.zeros(1).to(device)
                validLoss.requires_grad = False

                validData = validSample.to(device)

                validBlurOut = model1(validData)

                for validDispIndex in range(len(validBlurOut)):

                    validScaleLoss = torch.zeros(1).to(device)
                    validScaleLoss.requires_grad = False

                    validDispOut = validBlurOut[validDispIndex]

                    validInFocus = F.interpolate(validData,
                                                 size=None,
                                                 scale_factor=1 /
                                                 pow(2.0, validDispIndex),
                                                 mode='bilinear')

                    validBlurMapSlices = helper.BlurMapsToSlices(
                        validDispOut, numSlices, 6)
                    validInFocusRegionSlices = []

                    for validSlices in range(numSlices):

                        validNearFocusOutSlice = model2(
                            validInFocus, validBlurMapSlices[validSlices])
                        validBinaryBlurMapSlice = model3(
                            validNearFocusOutSlice)
                        validInFocusRegionSlice = torch.mul(
                            validNearFocusOutSlice, validBinaryBlurMapSlice)
                        validInFocusRegionSlices.append(
                            validInFocusRegionSlice)

                    validInFocusOutput = helper.CombineFocusRegionsToAIF(
                        validInFocusRegionSlices, numSlices, device)

                    validPerceptLoss = contentLoss.get_loss(
                        validInFocus, validInFocusOutput)
                    validSimiLoss = losses.similarityLoss(
                        validInFocus, validInFocusOutput, alpha, winSize)
                    validScaleLoss += perceptLossW * validPerceptLoss + simiLossW * validSimiLoss

                    validLoss += validScaleLoss

                epochLoss[1] += validLoss.item()

        epochLoss[0] /= (step + 1)
        epochLoss[1] /= (validStep + 1)
        epochFreq = epochSaveFreq
        print("Time taken for epoch " + str(epoch) + " is " +
              str(time.clock() - start))

        if epochLoss[1] < minValidLoss:
            minValidLoss = epochLoss[1]
            minValidEpoch = epoch
            torch.save(model1.state_dict(),
                       model1SaveDir + str(epoch) + 'generic_' + extension)
            torch.save(model2.state_dict(),
                       model2SaveDir + str(epoch) + 'generic_' + extension)
            torch.save(model3.state_dict(),
                       model3SaveDir + str(epoch) + 'generic_' + extension)
            print("Saving latest model at epoch: " + str(epoch))

        if epochCount % epochFreq == 0:
            torch.save(model1.state_dict(),
                       model1SaveDir + str(epoch) + 'generic_' + extension)
            torch.save(model2.state_dict(),
                       model2SaveDir + str(epoch) + 'generic_' + extension)
            torch.save(model3.state_dict(),
                       model3SaveDir + str(epoch) + 'generic_' + extension)

        print("Training loss for epoch " + str(epochCount) + " : " +
              str(epochLoss[0]))
        print("Validation loss for epoch " + str(epochCount) + " : " +
              str(epochLoss[1]))

        trainLossArray[epoch] = epochLoss[0]
        validLossArray[epoch] = epochLoss[1]

    print(minValidEpoch)
    return (trainLossArray, validLossArray)
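
Since train() returns the per-epoch loss arrays, a caller can plot learning curves directly (a sketch; matplotlib is not imported by the script itself):

import matplotlib.pyplot as plt

trainLoss, validLoss = train(epochs=100)
plt.plot(trainLoss, label='train')
plt.plot(validLoss, label='valid')
plt.xlabel('epoch')
plt.ylabel('loss')
plt.legend()
plt.show()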
Example #8
def train(epochs) :
	imTransform = transforms.Compose([torchvision.transforms.ToTensor(), torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
	sliceNames =  ('outputSlice1','outputSlice10')
	numSlices = len(sliceNames)

	trainLoaderLF = helper.PreprocessData(imTransform, sliceNames, rootDirTrainLF, 'input', 'segments', batchSize, True)
	validLoaderLF = helper.PreprocessData(imTransform, sliceNames, rootDirValidLF, 'input', 'segments', batchSize, True)

	trainLoaderReal = helper.PreprocessInFocusOnlyData(imTransform, rootDirTrainReal, 'inputs', 'masks', batchSize, False)
	validLoaderReal = helper.PreprocessInFocusOnlyData(imTransform, rootDirValidReal, 'inputs', 'masks', batchSize, False)



	rootDirToSave	= rootDirSave
	model1SaveDir 	= rootDirToSave + 'model1/'
	model2SaveDir 	= rootDirToSave + 'model2/'
	model3SaveDir 	= rootDirToSave + 'model3/'
	extension		= '.pth'

	model1 = model.EncoderDecoder(4, 6 * numSlices).to(device)
	model2 = model.DepthToInFocus(3).to(device)
	model3 = model.EncoderDecoderv2(3, 1).to(device)
	
	contentLoss = losses.PerceptualLoss()
	contentLoss.initialize(nn.MSELoss())


	trainLossArray = np.zeros(epochs)
	validLossArray = np.zeros(epochs)

	lr = learningRate
	minValidEpoch = 0
	minValidLoss = float('inf')

	trainLoaderRealList = list(trainLoaderReal)
	validLoaderRealList = list(validLoaderReal)
	numRealSteps = len(trainLoaderRealList)

	for epoch in range(epochs) :
		start = time.perf_counter()

		realStep = 0

		epochLoss = np.zeros(2)
		epochCount = epoch + 1

		print ("Epoch num : " + str(epochCount))
		if epoch >= 20 and epoch % 20 == 0 :
			lr = lr / 2
			print ('Learning rate changed to ' + str(lr))
		# the binarization term is disabled for the first 40 warm-up epochs
		binLossW = binarizationLossW if epoch >= 40 else 0.0

		for step, sample in enumerate(trainLoaderLF) :

			if step % 3 == 1 and realStep < numRealSteps:
				realStepLoss = torch.zeros(1).to(device)
				realInData = trainLoaderRealList[realStep]['Input'].to(device)
				realSegMask = trainLoaderRealList[realStep]['Mask'].to(device)
				realBlurOut = model1(realInData, realSegMask)
				for realDispIndex in range(len(realBlurOut)) :
					realScaleLoss = torch.zeros(1).to(device)
					realBinarizationLoss = torch.zeros(1).to(device)
					realDispOut = realBlurOut[realDispIndex]
					realInFocus = F.interpolate(realInData, size = None, scale_factor = 1/pow(2.0, realDispIndex), mode = 'bilinear')
					realBlurMapSlices = helper.BlurMapsToSlices(realDispOut, numSlices, 6)
					realInFocusRegionSlices = []
					for realSlices in range(numSlices) :
						realNearFocusOutSlice = model2(realInFocus, realBlurMapSlices[realSlices])
						realBinaryBlurMapSlice = model3(realNearFocusOutSlice)
						realBinaryBlurMapSlice = torch.cat((realBinaryBlurMapSlice, realBinaryBlurMapSlice, realBinaryBlurMapSlice), dim = 1)
						binaryLoss = losses.binarizationLoss(realBinaryBlurMapSlice) * binLossW
						realBinarizationLoss += binaryLoss
						realInFocusRegionSlice = torch.mul(realNearFocusOutSlice, realBinaryBlurMapSlice)
						realInFocusRegionSlices.append(realInFocusRegionSlice)
					realInFocusOutput = helper.CombineFocusRegionsToAIF(realInFocusRegionSlices, numSlices, device)
					realPerceptLoss = contentLoss.get_loss(realInFocus, realInFocusOutput)
					realSimiLoss = losses.similarityLoss(realInFocus, realInFocusOutput, alpha, winSize)
					realScaleLoss += perceptLossW * realPerceptLoss + simiLossW * realSimiLoss + realBinarizationLoss
					realStepLoss += realScaleLoss
				optimizer = optim.Adam(list(model1.parameters()) + list(model2.parameters()) + list(model3.parameters()), lr=lr/5,
				betas=(beta1, beta2), eps=1e-08, weight_decay=0, amsgrad=False)
				optimizer.zero_grad()
				realStepLoss.backward(retain_graph = True)
				optimizer.step()
				# the real-data loss gets its own optimizer step; count it toward the epoch total
				epochLoss[0] += realStepLoss.item()
				realStep += 1

			stepLossStage1 = torch.zeros(1).to(device)
			stepLossStage1.requires_grad = False

			stepLossStage2 = torch.zeros(1).to(device)
			stepLossStage2.requires_grad = False

			loss = torch.zeros(1).to(device)
			loss.requires_grad = False
			
			inData 	= sample['Input'].to(device)
			outData = sample['Output'].to(device)
			inSegMask = sample['Mask'].to(device)

			blurOut = model1(inData, inSegMask)

			for dispIndex in range(len(blurOut)) :

				scaleLossStage1 = torch.zeros(1).to(device)
				scaleLossStage1.requires_grad = False

				scaleLossStage2 = torch.zeros(1).to(device)
				scaleLossStage2.requires_grad = False

				dispOut 	= blurOut[dispIndex]

				inFocus	= F.interpolate(inData, size = None, scale_factor = 1 / pow(2.0, dispIndex), mode = 'bilinear')
				nearFocusStackGT 	= F.interpolate(outData, size = None, scale_factor = 1 / pow(2.0, dispIndex) , mode = 'bilinear')

				nearFocusGTSlices = helper.FocalStackToSlices(nearFocusStackGT, numSlices)
				blurMapSlices = helper.BlurMapsToSlices(dispOut, numSlices, 6)
				inFocusRegionSlices = []

				
				for slices in range(numSlices) :

					sliceLoss = torch.zeros(1).to(device)
					sliceLoss.requires_grad = False
					#print (time.clock() - start)
				
					nearFocusOutSlice = model2(inFocus, blurMapSlices[slices])
					#print (time.clock() - start)

					perceptLoss = contentLoss.get_loss(nearFocusGTSlices[slices], nearFocusOutSlice)
					simiLoss = losses.similarityLoss(nearFocusGTSlices[slices], nearFocusOutSlice, alpha, winSize)

					sliceLoss = perceptLossW * perceptLoss + simiLossW * simiLoss

					binaryBlurMapSlice = model3(nearFocusOutSlice)
					binaryBlurMapSlice = torch.cat((binaryBlurMapSlice, binaryBlurMapSlice, binaryBlurMapSlice), dim = 1)
					binaryLoss = losses.binarizationLoss(binaryBlurMapSlice) * binLossW
					sliceLoss += binaryLoss
					inFocusRegionSlice = torch.mul(nearFocusOutSlice, binaryBlurMapSlice)
					inFocusRegionSlices.append(inFocusRegionSlice)

					scaleLossStage1 += sliceLoss

				inFocusOutput = helper.CombineFocusRegionsToAIF(inFocusRegionSlices, numSlices, device)

				perceptLossStage2 = contentLoss.get_loss(inFocus, inFocusOutput)
				simiLossStage2 = losses.similarityLoss(inFocus, inFocusOutput, alpha, winSize)
				scaleLossStage2 += perceptLossW * perceptLossStage2 + simiLossW * simiLossStage2

				stepLossStage1 += scaleLossStage1
				stepLossStage2 += scaleLossStage2

				loss += stepLossStage1 + stepLossStage2
			
			epochLoss[0] += loss.item()

			if jointTraining == True :
				optimizer = optim.Adam(list(model1.parameters()) + list(model2.parameters()) + list(model3.parameters()), lr=lr,\
				betas=(beta1, beta2), eps=1e-08, weight_decay=0, amsgrad=False)
				optimizer.zero_grad()
				loss.backward()
				optimizer.step()
			else :
				stage1Parameters = list(model1.parameters()) + list(model2.parameters())
				stage2Parameters = list(model3.parameters())

				optimizerStage1 = optim.Adam(stage1Parameters, lr=lr, betas=(beta1, beta2), eps=1e-08, weight_decay=0, amsgrad=False)
				optimizerStage2 = optim.Adam(stage2Parameters, lr=lr, betas=(beta1, beta2), eps=1e-08, weight_decay=0, amsgrad=False)

				optimizerStage1.zero_grad()
				stepLossStage1.backward(retain_graph = True)
				optimizerStage1.step()

				optimizerStage2.zero_grad()
				stepLossStage2.backward()
				optimizerStage2.step()

			if step % 500 == 0 :
					print ("Step num :" + str(step) + " Train Loss :", str(loss.item()))

		with torch.no_grad() :
			validRealStep = 0
			for validStep, validSample in enumerate(validLoaderLF) :

				if validStep % 3 == 1 and validRealStep < len(validLoaderRealList):
					validRealStepLoss = torch.zeros(1).to(device)
					validRealInData = validLoaderRealList[validRealStep]['Input'].to(device)
					validRealSegMask = validLoaderRealList[validRealStep]['Mask'].to(device)
					validRealBlurOut = model1(validRealInData, validRealSegMask)
					for validRealDispIndex in range(len(validRealBlurOut)) :
						validRealScaleLoss = torch.zeros(1).to(device)
						validRealDispOut = validRealBlurOut[validRealDispIndex]
						validRealInFocus = F.interpolate(validRealInData, size = None, scale_factor = 1/pow(2.0, validRealDispIndex), mode = 'bilinear')
						validRealBlurMapSlices = helper.BlurMapsToSlices(validRealDispOut, numSlices, 6)
						validRealInFocusRegionSlices = []
						for validRealSlices in range(numSlices) :
							validRealNearFocusOutSlice = model2(validRealInFocus, validRealBlurMapSlices[validRealSlices])
							validRealBinaryBlurMapSlice = model3(validRealNearFocusOutSlice)
							validRealBinaryBlurMapSlice = torch.cat((validRealBinaryBlurMapSlice, validRealBinaryBlurMapSlice, validRealBinaryBlurMapSlice), dim = 1)
							validRealInFocusRegionSlice = torch.mul(validRealNearFocusOutSlice, validRealBinaryBlurMapSlice)
							validRealInFocusRegionSlices.append(validRealInFocusRegionSlice)
						validRealInFocusOutput = helper.CombineFocusRegionsToAIF(validRealInFocusRegionSlices, numSlices, device)
						validRealPerceptLoss = contentLoss.get_loss(validRealInFocus, validRealInFocusOutput)
						validRealSimiLoss = losses.similarityLoss(validRealInFocus, validRealInFocusOutput, alpha, winSize)
						validRealScaleLoss += perceptLossW * validRealPerceptLoss + simiLossW * validRealSimiLoss
						validRealStepLoss += validRealScaleLoss

					validRealStep += 1
					epochLoss[1] += validRealStepLoss.item()

				validStepLossStage1 = torch.zeros(1).to(device)
				validStepLossStage1.requires_grad = False

				validStepLossStage2 = torch.zeros(1).to(device)
				validStepLossStage2.requires_grad = False

				validLoss = torch.zeros(1).to(device)
				validLoss.requires_grad = False
			
				validInData 	= validSample['Input'].to(device)
				validOutData = validSample['Output'].to(device)
				validInSegMask = validSample['Mask'].to(device)

				validBlurOut = model1(validInData, validInSegMask)


				for validDispIndex in range(len(validBlurOut)) :

					validScaleLossStage1 = torch.zeros(1).to(device)
					validScaleLossStage1.requires_grad = False

					validScaleLossStage2 = torch.zeros(1).to(device)
					validScaleLossStage2.requires_grad = False

					validDispOut 	= validBlurOut[validDispIndex]

					validInFocus	= F.interpolate(validInData, size = None, scale_factor = 1 / pow(2.0, validDispIndex), mode = 'bilinear')
					validNearFocusStackGT 	= F.interpolate(validOutData, size = None, scale_factor = 1 / pow(2.0, validDispIndex) , mode = 'bilinear')

					validNearFocusGTSlices = helper.FocalStackToSlices(validNearFocusStackGT, numSlices)
					validBlurMapSlices = helper.BlurMapsToSlices(validDispOut, numSlices, 6)
					validInFocusRegionSlices = []

				
					for validSlices in range(numSlices) :

						validSliceLoss = torch.zeros(1).to(device)
						validSliceLoss.requires_grad = False

						validNearFocusOutSlice = model2(validInFocus, validBlurMapSlices[validSlices])

						validPerceptLoss = contentLoss.get_loss(validNearFocusGTSlices[validSlices], validNearFocusOutSlice)
						validSimiLoss = losses.similarityLoss(validNearFocusGTSlices[validSlices], validNearFocusOutSlice, alpha, winSize)

						validSliceLoss = perceptLossW * validPerceptLoss + simiLossW * validSimiLoss

						validBinaryBlurMapSlice = model3(validNearFocusOutSlice)
						validBinaryBlurMapSlice = torch.cat((validBinaryBlurMapSlice, validBinaryBlurMapSlice, validBinaryBlurMapSlice), dim = 1)
						validBinaryLoss = losses.binarizationLoss(validBinaryBlurMapSlice) * binLossW
						validSliceLoss += validBinaryLoss
						validInFocusRegionSlice = torch.mul(validNearFocusOutSlice, validBinaryBlurMapSlice)
						validInFocusRegionSlices.append(validInFocusRegionSlice)

						validScaleLossStage1 += validSliceLoss

					validInFocusOutput = helper.CombineFocusRegionsToAIF(validInFocusRegionSlices, numSlices, device)

					validPerceptLossStage2 = contentLoss.get_loss(validInFocus, validInFocusOutput)
					validSimiLossStage2 = losses.similarityLoss(validInFocus, validInFocusOutput, alpha, winSize)
					validScaleLossStage2 += perceptLossW * validPerceptLossStage2 + simiLossW * validSimiLossStage2

					validStepLossStage1 += validScaleLossStage1
					validStepLossStage2 += validScaleLossStage2

					validLoss += validStepLossStage1 + validStepLossStage2
			
				epochLoss[1] += validLoss.item()

		epochLoss[0] /= (step + 1)
		epochLoss[1] /= (validStep + 1)
		epochFreq = epochSaveFreq
		print ("Time taken for epoch " + str(epoch) + " is " + str(time.clock() - start))
		
		if epochLoss[1] < minValidLoss :
			minValidLoss = epochLoss[1]
			minValidEpoch = epoch
			torch.save(model1.state_dict(), model1SaveDir + str(epoch) + extension)
			torch.save(model2.state_dict(), model2SaveDir + str(epoch) + extension)
			torch.save(model3.state_dict(), model3SaveDir + str(epoch) + extension)
			print ("Saving latest model at epoch: " + str(epoch))

		if epochCount % epochFreq == 0 :
			torch.save(model1.state_dict(), model1SaveDir + str(epoch) + extension)
			torch.save(model2.state_dict(), model2SaveDir + str(epoch) + extension)
			torch.save(model3.state_dict(), model3SaveDir + str(epoch) + extension)
					
		print ("Training loss for epoch " + str(epochCount) + " : " + str(epochLoss[0]))
		print ("Validation loss for epoch " + str(epochCount) + " : " + str(epochLoss[1]))

		trainLossArray[epoch] = epochLoss[0]
		validLossArray[epoch] = epochLoss[1]
	
	print (minValidEpoch)
	return (trainLossArray, validLossArray)