Code example #1
    def define_loss(self, dataset):
        if self.output_type == "regression":
            self.criterion = nnloss.MSELoss()
        else:
            if self.weighted_loss:
                num_samples = len(dataset)

                # distribution of classes in the dataset
                label_to_count = {n: 0 for n in range(self.number_classes)}
                for idx in range(num_samples):
                    label = dataset.load_datapoint(idx)[-1]
                    label_to_count[label] += 1

                label_percentage = {
                    l: label_to_count[l] / num_samples for l in label_to_count.keys()
                }
                median_perc = median(list(label_percentage.values()))
                class_weights = [
                    median_perc / label_percentage[c] if label_percentage[c] != 0 else 0
                    for c in range(self.number_classes)
                ]
                weights = torch.FloatTensor(class_weights).to(self.device)

            else:
                weights = None

            if self.classification_loss_type == 'cross-entropy':
                self.criterion = nnloss.CrossEntropyLoss(weight=weights)
            else:
                if weights is not None:
                    # class weighting is only implemented for cross-entropy
                    raise NotImplementedError
                self.criterion = log_f1_micro_loss
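The weighted branch above is median-frequency balancing: each class gets the weight median(frequencies) / its own frequency, so rare classes are upweighted. A minimal standalone sketch of the same computation (the helper name and example labels are illustrative, not from the project):

from statistics import median

import torch

def median_frequency_weights(labels, number_classes):
    """Weight each class by median class frequency / its own frequency."""
    counts = [0] * number_classes
    for label in labels:
        counts[label] += 1
    freqs = [count / len(labels) for count in counts]
    med = median(freqs)
    return torch.tensor([med / f if f > 0 else 0.0 for f in freqs])

# e.g. median_frequency_weights([0, 0, 0, 1, 2], 3) -> tensor([0.3333, 1.0000, 1.0000])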
Code example #2
File: utils.py  Project: smahliivaza/hackathon
def get_criterion(name, **kwargs):
    """
    Returns criterion by name.

    :param name: criterion name (str)
    :param kwargs: kwargs passed to criterion constructor.
    :return: corresponding criterion from torch.nn module.
    """
    return {
        'bce': loss.BCELoss(**kwargs),
        'bcewithlogits': loss.BCEWithLogitsLoss(**kwargs),
        'cosineembedding': loss.CosineEmbeddingLoss(**kwargs),
        'crossentropy': loss.CrossEntropyLoss(**kwargs),
        'hingeembedding': loss.HingeEmbeddingLoss(**kwargs),
        'kldiv': loss.KLDivLoss(**kwargs),
        'l1': loss.L1Loss(**kwargs),
        'mse': loss.MSELoss(**kwargs),
        'marginranking': loss.MarginRankingLoss(**kwargs),
        'multilabelmargin': loss.MultiLabelMarginLoss(**kwargs),
        'multilabelsoftmargin': loss.MultiLabelSoftMarginLoss(**kwargs),
        'multimargin': loss.MultiMarginLoss(**kwargs),
        'nll': loss.NLLLoss(**kwargs),
        'nll2d': loss.NLLLoss2d(**kwargs),
        'poissonnll': loss.PoissonNLLLoss(**kwargs),
        'smoothl1': loss.SmoothL1Loss(**kwargs),
        'softmargin': loss.SoftMarginLoss(**kwargs),
        'tripletmargin': loss.TripletMarginLoss(**kwargs)
    }[name.strip().lower()]
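A quick usage sketch of this lookup; note that the name matching is case-insensitive and whitespace-tolerant:

criterion = get_criterion('  MSE ')    # -> nn.MSELoss()
bce = get_criterion('bcewithlogits')   # -> nn.BCEWithLogitsLoss()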
Code example #3
    def __init__(self, args):
        input_size = (args.inputSize, args.inputSize)

        self.run_name = args.runName
        self.input_size = input_size
        self.lr = args.learningRate

        self.criterion = nnloss.MSELoss()

        self.transforms = {}

        self.model = SiameseNetwork()

        if torch.cuda.device_count() > 1:
            logger.info('Using {} GPUs'.format(torch.cuda.device_count()))
            self.model = nn.DataParallel(self.model)

        for s in ('train', 'validation', 'test'):
            self.transforms[s] = get_pretrained_iv3_transforms(s)

        logger.debug('Num params: {}'.format(
            len([_ for _ in self.model.parameters()])))

        self.optimizer = Adam(self.model.parameters(), lr=self.lr)
        self.lr_scheduler = ReduceLROnPlateau(self.optimizer,
                                              factor=0.1,
                                              patience=10,
                                              min_lr=1e-5,
                                              verbose=True)
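ReduceLROnPlateau only adjusts the rate when it is fed the monitored metric each epoch; a hedged sketch of the call this setup presumably pairs with (evaluate is a hypothetical helper):

# inside the epoch loop, after a validation pass:
val_loss = evaluate(self.model, validation_loader)  # hypothetical helper
self.lr_scheduler.step(val_loss)  # lr *= 0.1 after 10 epochs without improvement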
Code example #4
File: loss.py  Project: SpicyYeol/Pytorch_rppgs
def loss_fn(loss_fn: str = "mse"):
    """
    :param loss_fn: implement loss function for training
    :return: loss function module(class)
    """
    if loss_fn == "mse":
        return loss.MSELoss()
    elif loss_fn == "L1":
        return loss.L1Loss()
    elif loss_fn == "neg_pearson":
        return NegPearsonLoss()
    elif loss_fn == "multi_margin":
        return loss.MultiMarginLoss()
    elif loss_fn == "bce":
        return loss.BCELoss()
    elif loss_fn == "huber":
        return loss.HuberLoss()
    elif loss_fn == "cosine_embedding":
        return loss.CosineEmbeddingLoss()
    elif loss_fn == "cross_entropy":
        return loss.CrossEntropyLoss()
    elif loss_fn == "ctc":
        return loss.CTCLoss()
    elif loss_fn == "bce_with_logits":
        return loss.BCEWithLogitsLoss()
    elif loss_fn == "gaussian_nll":
        return loss.GaussianNLLLoss()
    elif loss_fn == "hinge_embedding":
        return loss.HingeEmbeddingLoss()
    elif loss_fn == "KLDiv":
        return loss.KLDivLoss()
    elif loss_fn == "margin_ranking":
        return loss.MarginRankingLoss()
    elif loss_fn == "multi_label_margin":
        return loss.MultiLabelMarginLoss()
    elif loss_fn == "multi_label_soft_margin":
        return loss.MultiLabelSoftMarginLoss()
    elif loss_fn == "nll":
        return loss.NLLLoss()
    elif loss_fn == "nll2d":
        return loss.NLLLoss2d()
    elif loss_fn == "pairwise":
        return loss.PairwiseDistance()
    elif loss_fn == "poisson_nll":
        return loss.PoissonNLLLoss()
    elif loss_fn == "smooth_l1":
        return loss.SmoothL1Loss()
    elif loss_fn == "soft_margin":
        return loss.SoftMarginLoss()
    elif loss_fn == "triplet_margin":
        return loss.TripletMarginLoss()
    elif loss_fn == "triplet_margin_distance":
        return loss.TripletMarginWithDistanceLoss()
    else:
        log_warning("use an implemented loss function")
        raise NotImplementedError(
            "implement a custom loss function (%s) in loss.py" % loss_fn)
Code example #5
    def __init__(self, scale: float = 1.0, size_average=None, reduce=None, reduction='mean'):
        """
        reconstruction loss.
        :return L2(x, x')

        """
        super(ReconstructionLoss, self).__init__(size_average, reduce, reduction)
        # sample-wise & element-wise mean
        self._mse_loss = L.MSELoss(reduction=reduction)
        self._scale = scale
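Only the constructor is shown here; a plausible forward for these fields (an assumption, since the original method is not part of the snippet) would scale the MSE:

    def forward(self, x, x_reconstructed):
        # scaled sample- and element-wise mean squared error
        return self._scale * self._mse_loss(x_reconstructed, x)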
Code example #6
File: trainer.py  Project: p-phung/caladrius
    def __init__(self, args):
        input_size = (args.input_size, args.input_size)

        self.run_name = args.run_name
        self.input_size = input_size
        self.lr = args.learning_rate
        self.output_type = args.output_type

        network_architecture_class = InceptionSiameseNetwork
        network_architecture_transforms = get_pretrained_iv3_transforms
        if args.model_type == "light":
            network_architecture_class = LightSiameseNetwork
            network_architecture_transforms = get_light_siamese_transforms

        # define the loss measure
        if self.output_type == "regression":
            self.criterion = nnloss.MSELoss()
            self.model = network_architecture_class()
        elif self.output_type == "classification":
            self.criterion = nnloss.CrossEntropyLoss()
            self.n_classes = 4  # replace by args
            self.model = network_architecture_class(
                output_type=self.output_type, n_classes=self.n_classes
            )

        self.transforms = {}

        if torch.cuda.device_count() > 1:
            logger.info("Using {} GPUs".format(torch.cuda.device_count()))
            self.model = torch.nn.DataParallel(self.model)

        for s in ("train", "validation", "test", "inference"):
            self.transforms[s] = network_architecture_transforms(s)

        logger.debug("Num params: {}".format(len([_ for _ in self.model.parameters()])))

        self.optimizer = Adam(self.model.parameters(), lr=self.lr)
        # reduces the learning rate when loss plateaus, i.e. doesn't improve
        self.lr_scheduler = ReduceLROnPlateau(
            self.optimizer, factor=0.1, patience=10, min_lr=1e-5, verbose=True
        )
        # creates tracking file for tensorboard
        self.writer = SummaryWriter(args.checkpoint_path)

        self.device = args.device
        self.model_path = args.model_path
        self.prediction_path = args.prediction_path
        self.model_type = args.model_type
        self.is_statistical_model = args.statistical_model
        self.is_neural_model = args.neural_model
        self.log_step = args.log_step
Code example #7
File: losses.py  Project: ZWLori/pixelda_pytorch
def transfer_similarity_loss(reconstructions, source_images, weight):
    """
    Computes a loss encouraging similarity between source and transferred.
    """

    if weight == 0:
        return 0

    # TODO: check pairwise MSE
    reconstruction_similarity_criterion = loss.MSELoss()
    reconstruction_similarity_loss = reconstruction_similarity_criterion(
        reconstructions, source_images)

    # note: `weight` only gates the computation here; any scaling by `weight`
    # presumably happens at the call site
    return reconstruction_similarity_loss
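A hedged usage sketch (shapes illustrative): a zero weight short-circuits to a plain 0, anything else returns an MSE tensor:

reconstructions = torch.randn(8, 3, 64, 64)
source_images = torch.randn(8, 3, 64, 64)
sim = transfer_similarity_loss(reconstructions, source_images, weight=0.5)  # 0-dim MSE tensor
zero = transfer_similarity_loss(reconstructions, source_images, weight=0)   # plain int 0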
Code example #8
File: main.py  Project: piravp/XAI-CBR
def test_network():
    wine_attributes = ["alch","malic","ash","alcash","mag","phen","flav","nfphens","proant","color","hue","dil","prol"]
    columns = ["class","alch","malic","ash","alcash","mag","phen","flav","nfphens","proant","color","hue","dil","prol"]
    """
        0) class
    	1) Alcohol
        2) Malic acid
        3) Ash
        4) Alcalinity of ash  
        5) Magnesium
        6) Total phenols
        7) Flavanoids
        8) Nonflavanoid phenols
        9) Proanthocyanins
        10)Color intensity
        11)Hue
        12)OD280/OD315 of diluted wines
        13)Proline    
    """

    from Induction.IntGrad.integratedGradients import random_baseline_integrated_gradients, integrated_gradients

    df = read_data_pd("../../Data/wine.csv", columns=columns)

    df.columns = columns  # add column names to the dataframe
    #Cov.columns = ["Sequence", "Start", "End", "Coverage"]
    dataman = Datamanager.Datamanager(dataframe_train=df, classes=3, dataset="wine")

    model = network.NN_3_25("wine", in_dim=13, out_dim=3)
    print(model.input_type)
    optimizer = optim.Adam(model.parameters(), lr=0.01, betas=(0.9, 0.999), eps=1e-6)
    #loss = network.RootMeanSquareLoss()
    #loss = t_loss.L1Loss()
    loss = t_loss.MSELoss()
    network.train(model, dataman, validation=True, optimizer=optimizer, loss_function=loss, batch=20, iterations=50)
Code example #9
def get_criterion(name, **kwargs):
    """
    Returns criterion instance given the name.

    Args:
        name (str): criterion name
        kwargs (dict): keyword arguments passed to criterion constructor

    Returns:
        Corresponding criterion from torch.nn module

    Available criteria:
        BCE, BCEWithLogits, CosineEmbedding, CrossEntropy, HingeEmbedding, KLDiv,
        L1, MSE, MarginRanking, MultiLabelMargin, MultiLabelSoftMargin, MultiMargin,
        NLL, PoissonNLL, SmoothL1, SoftMargin, TripletMargin

    """
    return {
        'bce': loss.BCELoss(**kwargs),
        'bcewithlogits': loss.BCEWithLogitsLoss(**kwargs),
        'cosineembedding': loss.CosineEmbeddingLoss(**kwargs),
        'crossentropy': loss.CrossEntropyLoss(**kwargs),
        'hingeembedding': loss.HingeEmbeddingLoss(**kwargs),
        'kldiv': loss.KLDivLoss(**kwargs),
        'l1': loss.L1Loss(**kwargs),
        'mse': loss.MSELoss(**kwargs),
        'marginranking': loss.MarginRankingLoss(**kwargs),
        'multilabelmargin': loss.MultiLabelMarginLoss(**kwargs),
        'multilabelsoftmargin': loss.MultiLabelSoftMarginLoss(**kwargs),
        'multimargin': loss.MultiMarginLoss(**kwargs),
        'nll': loss.NLLLoss(**kwargs),
        'poissonnll': loss.PoissonNLLLoss(**kwargs),
        'smoothl1': loss.SmoothL1Loss(**kwargs),
        'softmargin': loss.SoftMarginLoss(**kwargs),
        'tripletmargin': loss.TripletMarginLoss(**kwargs)
    }[name.strip().lower()]
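Because the dict above is built eagerly, the same kwargs reach every constructor, so only arguments that all listed losses accept (such as reduction) are safe to pass:

criterion = get_criterion('crossentropy', reduction='sum')  # fine: every listed loss takes `reduction`
# get_criterion('nll', weight=w) would raise a TypeError while the dict is built,
# because e.g. MarginRankingLoss has no `weight` argument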
Code example #10
    def __init__(self, args):
        input_size = (args.input_size, args.input_size)

        self.run_name = args.run_name
        self.input_size = input_size
        self.lr = args.learning_rate
        self.output_type = args.output_type
        self.test_epoch = args.test_epoch
        self.freeze = args.freeze
        self.no_augment = args.no_augment
        self.augment_type = args.augment_type
        self.weighted_loss = args.weighted_loss
        self.save_all = args.save_all
        self.probability = args.probability
        self.classification_loss_type = args.classification_loss_type
        self.disable_cuda = args.disable_cuda
        network_architecture_class = InceptionSiameseNetwork
        network_architecture_transforms = get_pretrained_iv3_transforms
        if args.model_type == "shared":
            network_architecture_class = InceptionSiameseShared
            network_architecture_transforms = get_pretrained_iv3_transforms
        elif args.model_type == "light":
            network_architecture_class = LightSiameseNetwork
            network_architecture_transforms = get_light_siamese_transforms
        elif args.model_type == "after":
            network_architecture_class = InceptionCNNNetwork
            network_architecture_transforms = get_pretrained_iv3_transforms
        elif args.model_type == "vgg":
            network_architecture_class = VggSiameseNetwork
            network_architecture_transforms = get_pretrained_vgg_transforms
        elif args.model_type == "attentive":
            network_architecture_class = AttentiveNetwork
            network_architecture_transforms = get_pretrained_attentive_transforms

        # define the loss measure
        if self.output_type == "regression":
            self.criterion = nnloss.MSELoss()
            self.model = network_architecture_class()
        elif self.output_type == "classification":
            self.number_classes = args.number_classes
            self.model = network_architecture_class(
                output_type=self.output_type,
                n_classes=self.number_classes,
                freeze=self.freeze,
            )

            if self.classification_loss_type == 'cross-entropy':
                self.criterion = nnloss.CrossEntropyLoss()
            else:
                self.criterion = log_f1_micro_loss

        self.transforms = {}

        if torch.cuda.device_count() > 1:
            logger.info("Using {} GPUs".format(torch.cuda.device_count()))
            self.model = torch.nn.DataParallel(self.model)

        for s in ("train", "validation", "test", "inference"):
            self.transforms[s] = network_architecture_transforms(
                s, self.no_augment, self.augment_type
            )

        logger.debug("Num params: {}".format(len([_ for _ in self.model.parameters()])))

        self.optimizer = Adam(self.model.parameters(), lr=self.lr)
        # reduces the learning rate when loss plateaus, i.e. doesn't improve
        self.lr_scheduler = ReduceLROnPlateau(
            self.optimizer, factor=0.1, patience=10, min_lr=1e-5, verbose=True
        )

        if not self.disable_cuda:
            self.scaler = torch.cuda.amp.GradScaler()
        else:
            self.scaler = None

        # creates tracking file for tensorboard
        self.writer = SummaryWriter(args.checkpoint_path)

        self.device = args.device
        self.model_path = args.model_path
        self.trained_model_path = args.trained_model_path
        self.prediction_path = args.prediction_path
        self.model_type = args.model_type
        self.is_statistical_model = args.statistical_model
        self.is_neural_model = args.neural_model
        self.log_step = args.log_step
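The GradScaler created above implies a mixed-precision training step; the part of the loop that presumably consumes it is not in this snippet, but the standard torch.cuda.amp pattern looks like this (inputs/targets are illustrative):

self.optimizer.zero_grad()
with torch.cuda.amp.autocast():
    outputs = self.model(inputs)
    loss = self.criterion(outputs, targets)
self.scaler.scale(loss).backward()  # scale the loss to avoid fp16 gradient underflow
self.scaler.step(self.optimizer)    # unscales gradients, then runs optimizer.step()
self.scaler.update()                # adjusts the scale factor for the next step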
Code example #11
def train():
    # INITIALIZATIONS
    tr_loss = list()
    tr_ssim_loss = list()
    ssimCriterion = SSIM()
    mseCriterion = Loss.MSELoss()

    lossCrit = mseCriterion
    vld_mse_loss = list()
    vld_ssim_loss = list()

    vld_mse_loss_in = list()
    vld_ssim_loss_in = list()
    vi = 0
    i = 0

    # LOAD LATEST (or SPECIFIC) MODEL
    s_epoch = load_model(-1)

    for epoch in range(s_epoch, params.epochs):
        print('epoch {}/{}...'.format(epoch + 1, params.epochs))

        adjust_learning_rate(epoch)

        ###########################################
        # Training
        l = 0
        itt = 0
        TAG = 'Training'
        if not params.Validation_Only:
            for local_batch, local_labels, sliceID, orig_size, usr in training_DG:

                X = Variable(torch.FloatTensor(local_batch.float())).to(params.device)
                y = Variable(torch.FloatTensor(local_labels.float())).to(params.device)

                input_mag = normalize_batch_torch(get_magnitude(X))
                LOST_mag = normalize_batch_torch(y[:, :, :, :, 0])

                if params.complex_net:
                    X = normalize_complex_batch_by_magnitude_only(X, False)
                    y = normalize_complex_batch_by_magnitude_only(y, True)
                else:
                    X = get_magnitude(X)
                    y = y[:, :, :, :, 0]
                    X = normalize_batch_torch(X)
                    y = normalize_batch_torch(y)

                y_pred = net(X)

                if params.complex_net:
                    loss = lossCrit(magnitude(y_pred).squeeze(1), y[:, :, :, :, 0].squeeze(1))
                    simloss = ssimCriterion(magnitude(y_pred), y[:, :, :, :, 0])
                else:
                    loss = lossCrit(y_pred, y)
                    simloss = ssimCriterion(y_pred, y)

                tr_loss.append(loss.cpu().data.numpy())
                tr_ssim_loss.append(simloss.cpu().data.numpy())

                l += loss.data[0]

                optimizer.zero_grad()
                loss.backward()
                i += 1
                optimizer.step()

                inloss = mseCriterion(input_mag, LOST_mag)

                print('Epoch: {0} - {1:.3f}%'.format(epoch + 1, 100 * (itt * params.batch_size) / len(
                    training_DG.dataset.input_IDs))
                      + ' \tIter: ' + str(i)
                      + '\tLoss: {0:.6f}'.format(loss.data[0])
                      + '\tInputLoss: {0:.6f}'.format(inloss.data[0]))
                itt += 1

                if itt % 100 == 0:
                    is_best = 0
                    save_checkpoint({
                        'epoch': epoch + 1,
                        'loss': tr_loss,
                        'arch': 'recoNet_Model1',
                        'state_dict': net.state_dict(),
                        'optimizer': optimizer.state_dict(),
                    }, is_best, filename=params.model_save_dir + 'MODEL_EPOCH{}.pth'.format(epoch + 1))

                    print('Model Saved!')
            avg_loss = params.batch_size * l / len(training_DG.dataset.input_IDs)
            print('Total Loss : {0:.6f} \t Avg. Loss {1:.6f}'.format(l, avg_loss))

        else:
            load_model(epoch + 1)

        #####################################
        # Validation

        vitt = 0
        vld_mse = 0
        vld_ssim = 0
        vld_psnr = 0
        vld_mse_in = 0
        vld_ssim_in = 0
        vld_psnr_in = 0

        TAG = 'Validation'
        with torch.no_grad():
            for local_batch, local_labels, sliceID, orig_size, usr in validation_DG:
                X = Variable(torch.FloatTensor(local_batch.float())).to(params.device)
                y = Variable(torch.FloatTensor(local_labels.float())).to(params.device)

                input_mag = normalize_batch_torch(get_magnitude(X))
                LOST_mag = normalize_batch_torch(y[:, :, :, :, 0])

                if params.complex_net:
                    X = normalize_complex_batch_by_magnitude_only(X, False)
                    y = normalize_complex_batch_by_magnitude_only(y, True)
                else:
                    X = get_magnitude(X)
                    y = y[:, :, :, :, 0]
                    X = normalize_batch_torch(X)
                    y = normalize_batch_torch(y)

                y_pred = net(X)

                if params.complex_net:
                    mseloss = mseCriterion(magnitude(y_pred).squeeze(1), y[:, :, :, :, 0].squeeze(1))
                    ssimloss = ssimCriterion(magnitude(y_pred), y[:, :, :, :, 0])

                else:
                    mseloss = mseCriterion(y_pred, y)
                    ssimloss = ssimCriterion(y_pred, y)

                mseloss_in = mseCriterion(input_mag, LOST_mag)
                ssimloss_in = ssimCriterion(input_mag, LOST_mag)

                vld_mse_loss.append(mseloss.cpu().data.numpy())
                vld_ssim_loss.append(ssimloss.cpu().data.numpy())

                vld_mse_loss_in.append(mseloss_in.cpu().data.numpy())
                vld_ssim_loss_in.append(ssimloss_in.cpu().data.numpy())

                vld_mse += mseloss.data[0]
                vld_ssim += ssimloss.data[0]

                vld_mse_in += mseloss_in.data[0]
                vld_ssim_in += ssimloss_in.data[0]

                vi += 1
                vitt += 1

                if params.complex_net:
                    inloss = mseCriterion(magnitude(X).squeeze(1), y[:, :, :, :, 0].squeeze(1))
                else:
                    inloss = mseCriterion(X, y)

                print('Epoch: {0} - {1:.3f}%'.format(epoch + 1, 100 * (vitt * params.batch_size) / len(
                    validation_DG.dataset.input_IDs))
                      + ' \tIter: ' + str(vi)
                      + '\tMSE: {0:.6f}'.format(mseloss.data[0])
                      + '\tSSIM: {0:.6f}'.format(ssimloss.data[0])
                      + '\tInputLoss: {0:.6f}'.format(inloss.data[0]))

            avg_factor = params.batch_size / len(validation_DG.dataset.input_IDs)
            print('Avg. MSE : {0:.6f}'.format(vld_mse * avg_factor)
                  + '\tAvg. SSIM : {0:.6f}'.format(vld_ssim * avg_factor)
                  + '\tAvg. PSNR : {0:.6f}'.format(vld_psnr * avg_factor)
                  + '\tAvg. Input_MSE : {0:.6f}'.format(vld_mse_in * avg_factor)
                  + '\tAvg. Input_SSIM : {0:.6f}'.format(vld_ssim_in * avg_factor)
                  + '\tAvg. Input_PSNR : {0:.6f}'.format(vld_psnr_in * avg_factor)
                  )

    writer.close()
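This loop (and the near-duplicate in code example #13) uses pre-0.4 PyTorch idioms: Variable(...) wrapping and loss.data[0] for scalar extraction. On PyTorch 0.4 and later a Variable is just a Tensor, and the scalar is read with .item():

l += loss.item()  # modern replacement for loss.data[0]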
Code example #12
    pytest.param(
        "nn.modules.loss",
        "MarginRankingLoss",
        {},
        [],
        {},
        loss.MarginRankingLoss(),
        id="MarginRankingLossConf",
    ),
    pytest.param(
        "nn.modules.loss",
        "MSELoss",
        {},
        [],
        {},
        loss.MSELoss(),
        id="MSELossConf",
    ),
    pytest.param(
        "nn.modules.loss",
        "MultiLabelMarginLoss",
        {},
        [],
        {},
        loss.MultiLabelMarginLoss(),
        id="MultiLabelMarginLossConf",
    ),
    pytest.param(
        "nn.modules.loss",
        "MultiLabelSoftMarginLoss",
        {},
Code example #13
def train(net):
    ###########################################
    #
    # INITIALIZATIONS
    #
    ############################################
    optimizer = optim.Adam(net.parameters(), lr=params.args.lr)
    #     optimizer = optim.SGD(net.parameters(), lr=0.1, momentum=0.9)

    LOSS = list()
    SSIMLOSS = list()
    ssimCriterion = SSIM()
    mseCriterion = Loss.MSELoss()
    kscCriterion = KspaceConsistency()
    tvCriterion = TotalVariations()

    lossCrit = mseCriterion
    vld_MSE_LOSS = list()
    vld_SSIM_LOSS = list()

    vld_MSE_LOSS_in = list()
    vld_SSIM_LOSS_in = list()
    vi = 0
    i = 0
    bt = 0

    ###########################################
    #
    # LOAD LATEST (or SPECIFIC) MODEL
    #
    ############################################
    models = os.listdir(params.model_save_dir)
    s_epoch = -1  # -1 selects the latest saved checkpoint below

    def load_model(epoch):
        print('loading model at epoch ' + str(epoch))
        checkpoint = torch.load(params.model_save_dir + models[0][0:11] + str(epoch) + '.pth')
        net.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        # extend in place: rebinding LOSS inside this function would only create a local name
        LOSS.extend(checkpoint['loss'])

    print(len(models))
    if len(models) > 0:
        if s_epoch == -1:
            s_epoch = max([int(epo[11:-4]) for epo in models])
        print("Loading model ...")
        load_model(s_epoch)
        s_epoch = s_epoch - 1
        print("Model loaded !")
    else:
        s_epoch = 0  # no checkpoint found: train from scratch

    for epoch in range(s_epoch, params.epochs):
        print('epoch {}/{}...'.format(epoch + 1, params.epochs))

        adjust_learning_rate(optimizer, epoch)

        ###########################################
        #
        # Training
        #
        ############################################
        l = 0
        itt = 0
        TAG = 'Training'
        MAX = list()
        if not params.Validation_Only:
            for local_batch, local_labels, sliceID, orig_size, usr in training_DG:

                X = Variable(torch.FloatTensor(local_batch.float())).to(params.device)
                y = Variable(torch.FloatTensor(local_labels.float())).to(params.device)

                input_mag = normalizeBatch_torch(get_magnitude(X))
                LOST_mag = normalizeBatch_torch(y[:, :, :, :, 0])

                if params.complex_net:
                    X = normalizeComplexBatch_byMagnitudeOnly(X, False)
                    y = normalizeComplexBatch_byMagnitudeOnly(y, True)
                else:
                    X = get_magnitude(X)
                    y = y[:, :, :, :, 0]
                    X = normalizeBatch_torch(X)
                    y = normalizeBatch_torch(y)

                y_pred = net(X)

                if params.complex_net:
                    loss = lossCrit(magnitude(y_pred).squeeze(1), y[:, :, :, :, 0].squeeze(1))
                    simloss = ssimCriterion(magnitude(y_pred), y[:, :, :, :, 0])
                else:
                    loss = lossCrit(y_pred, y)
                    simloss = ssimCriterion(y_pred, y)

                LOSS.append(loss.cpu().data.numpy())
                SSIMLOSS.append(simloss.cpu().data.numpy())

                l += loss.data[0]

                optimizer.zero_grad()
                loss.backward()
                i += 1
                optimizer.step()

                inloss = mseCriterion(input_mag, LOST_mag)

                print('Epoch: {0} - {1:.3f}%'.format(epoch + 1, 100 * (itt * params.batch_size) / len(
                    training_DG.dataset.input_IDs))
                      + ' \tIter: ' + str(i)
                      + '\tLoss: {0:.6f}'.format(loss.data[0])
                      + '\tInputLoss: {0:.6f}'.format(inloss.data[0]))
                itt += 1

                if itt % 100 == 0:
                    is_best = 0
                    save_checkpoint({
                        'epoch': epoch + 1,
                        'loss': LOSS,
                        'arch': 'recoNet_Model1',
                        'state_dict': net.state_dict(),
                        'optimizer': optimizer.state_dict(),
                    }, is_best, filename=params.model_save_dir + 'MODEL_EPOCH{}.pth'.format(epoch + 1))

                    print('Model Saved!')
            avg_loss = params.batch_size * l / len(training_DG.dataset.input_IDs)
            print('Total Loss : {0:.6f} \t Avg. Loss {1:.6f}'.format(l, avg_loss))

        else:
            load_model(epoch + 1)
        #####################################
        #
        # Validation
        #
        #####################################

        vl = 0
        vitt = 0
        vld_mse = 0
        vld_ssim = 0
        vld_psnr = 0
        vld_mse_in = 0
        vld_ssim_in = 0
        vld_psnr_in = 0

        TAG = 'Validation'
        with torch.no_grad():
            for local_batch, local_labels, sliceID, orig_size, usr in validation_DG:
                X = Variable(torch.FloatTensor(local_batch.float())).to(params.device)
                y = Variable(torch.FloatTensor(local_labels.float())).to(params.device)

                input_mag = normalizeBatch_torch(get_magnitude(X))
                LOST_mag = normalizeBatch_torch(y[:, :, :, :, 0])

                if params.complex_net:
                    X = normalizeComplexBatch_byMagnitudeOnly(X, False)
                    y = normalizeComplexBatch_byMagnitudeOnly(y, True)
                else:
                    X = get_magnitude(X)
                    y = y[:, :, :, :, 0]
                    X = normalizeBatch_torch(X)
                    y = normalizeBatch_torch(y)

                y_pred = net(X)

                if params.complex_net:
                    mseloss = mseCriterion(magnitude(y_pred).squeeze(1), y[:, :, :, :, 0].squeeze(1))
                    ssimloss = ssimCriterion(magnitude(y_pred), y[:, :, :, :, 0])

                else:
                    mseloss = mseCriterion(y_pred, y)
                    ssimloss = ssimCriterion(y_pred, y)


                mseloss_in = mseCriterion(input_mag, LOST_mag)
                ssimloss_in = ssimCriterion(input_mag, LOST_mag)


                vld_MSE_LOSS.append(mseloss.cpu().data.numpy())
                vld_SSIM_LOSS.append(ssimloss.cpu().data.numpy())

                vld_MSE_LOSS_in.append(mseloss_in.cpu().data.numpy())
                vld_SSIM_LOSS_in.append(ssimloss_in.cpu().data.numpy())

                vld_mse += mseloss.data[0]
                vld_ssim += ssimloss.data[0]

                vld_mse_in += mseloss_in.data[0]
                vld_ssim_in += ssimloss_in.data[0]

                vi += 1
                vitt += 1

                if params.complex_net:
                    inloss = mseCriterion(magnitude(X).squeeze(1), y[:, :, :, :, 0].squeeze(1))
                else:
                    inloss = mseCriterion(X, y)

                print('Epoch: {0} - {1:.3f}%'.format(epoch + 1, 100 * (vitt * params.batch_size) / len(
                    validation_DG.dataset.input_IDs))
                      + ' \tIter: ' + str(vi)
                      + '\tMSE: {0:.6f}'.format(mseloss.data[0])
                      + '\tSSIM: {0:.6f}'.format(ssimloss.data[0])
                      + '\tInputLoss: {0:.6f}'.format(inloss.data[0]))

            avg_factor = params.batch_size / len(validation_DG.dataset.input_IDs)
            print('Avg. MSE : {0:.6f}'.format(vld_mse * avg_factor)
                  + '\tAvg. SSIM : {0:.6f}'.format(vld_ssim * avg_factor)
                  + '\tAvg. PSNR : {0:.6f}'.format(vld_psnr * avg_factor)
                  + '\tAvg. Input_MSE : {0:.6f}'.format(vld_mse_in * avg_factor)
                  + '\tAvg. Input_SSIM : {0:.6f}'.format(vld_ssim_in * avg_factor)
                  + '\tAvg. Input_PSNR : {0:.6f}'.format(vld_psnr_in * avg_factor)
                  )

    writer.close()