Example #1
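# Decode each input sentence with a BERT masked-LM searcher and rerank the
# n-best hypotheses. Assumed context (the snippet was extracted from a larger
# file): torch, BertTokenizer / BertForMaskedLM (e.g. from
# pytorch_pretrained_bert or transformers), DataParallelModel, and the local
# helpers Searcher, mapping_tokenize, detokenize, translate, length_norm,
# bounded_word_reward and bounded_adaptive_reward.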
def test(config):
    Best_Model = torch.load(config.test_model)
    Tokenizer = BertTokenizer.from_pretrained(config.model)

    f_in = open(config.inputFile, 'r', encoding='utf-8')

    net = BertForMaskedLM.from_pretrained(config.model)

    # When loading from a model not trained from DataParallel
    #net.load_state_dict(Best_Model['state_dict'])
    #net.eval()

    if torch.cuda.is_available():
        net = net.cuda(0)
        if config.dataParallel:
            net = DataParallelModel(net)

    # When loading from a model trained from DataParallel
    net.load_state_dict(Best_Model['state_dict'])
    net.eval()

    mySearcher = Searcher(net, config)

    f_top1 = open('summary' + config.suffix + '.txt', 'w', encoding='utf-8')
    f_topK = open('summary' + config.suffix + '.txt.' +
                  str(config.answer_size),
                  'w',
                  encoding='utf-8')

    ed = '\n------------------------\n'

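    # decode each input line, printing the source and best hypothesis, and
    # write the top-1 and top-K outputs to the summary files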
    for idx, line in enumerate(f_in):
        source_ = line.strip().split()
        source = Tokenizer.tokenize(line.strip())
        mapping = mapping_tokenize(source_, source)

        source = Tokenizer.convert_tokens_to_ids(source)

        print(idx)
        print(detokenize(translate(source, Tokenizer), mapping), end=ed)

        l_pred = mySearcher.length_Predict(source)
        Answers = mySearcher.search(source)
        baseline = sum(Answers[0][0])

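        # rerank the n-best list with the configured scoring function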
        if config.reranking_method == 'none':
            Answers = sorted(Answers, key=lambda x: sum(x[0]))
        elif config.reranking_method == 'length_norm':
            Answers = sorted(Answers, key=lambda x: length_norm(x[0]))
        elif config.reranking_method == 'bounded_word_reward':
            Answers = sorted(
                Answers,
                key=lambda x: bounded_word_reward(x[0], config.reward, l_pred))
        elif config.reranking_method == 'bounded_adaptive_reward':
            Answers = sorted(
                Answers,
                key=lambda x: bounded_adaptive_reward(x[0], x[2], l_pred))

        texts = [
            detokenize(translate(Answers[k][1], Tokenizer), mapping)
            for k in range(len(Answers))
        ]

        if baseline != sum(Answers[0][0]):
            print('Reranked!')

        print(texts[0], end=ed)
        print(texts[0], file=f_top1)
        print(len(texts), file=f_topK)
        for i in range(len(texts)):
            print(Answers[i][0], file=f_topK)
            print(texts[i], file=f_topK)

    f_in.close()
    f_top1.close()
    f_topK.close()
Example #2
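# A two-head transfer-learning model that predicts a food class plus a
# multi-label set of ingredients from one shared backbone. Assumed context
# (the snippet was extracted from a larger library): torch, torch.nn as nn,
# torchvision.models as models, numpy as np, math, time, copy,
# matplotlib.pyplot as plt, sklearn.metrics (classification_report,
# confusion_matrix, roc_auc_score), and the library's own Network, FC,
# Classifier, DAI_AvgPool, DataParallelModel, DataParallelCriterion,
# parallel.gather and rmse.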
class FoodIngredients(Network):
    def __init__(self,
                 model_name='DenseNet',
                 model_type='food',
                 lr=0.02,
                 optimizer_name='Adam',
                 criterion1=nn.CrossEntropyLoss(),
                 criterion2=nn.BCEWithLogitsLoss(),
                 dropout_p=0.45,
                 pretrained=True,
                 device=None,
                 best_accuracy=0.,
                 best_validation_loss=None,
                 best_model_file='best_model.pth',
                 head1={
                     'num_outputs': 10,
                     'layers': [],
                     'model_type': 'classifier'
                 },
                 head2={
                     'num_outputs': 10,
                     'layers': [],
                     'model_type': 'multi_label_classifier'
                 },
                 class_names=[],
                 num_classes=None,
                 ingredient_names=[],
                 num_ingredients=None,
                 add_extra=True,
                 set_params=True,
                 set_head=True):

        super().__init__(device=device)

        self.set_transfer_model(model_name,
                                pretrained=pretrained,
                                add_extra=add_extra,
                                dropout_p=dropout_p)

        if set_head:
            self.set_model_head(model_name=model_name,
                                head1=head1,
                                head2=head2,
                                dropout_p=dropout_p,
                                criterion1=criterion1,
                                criterion2=criterion2,
                                device=device)
        if set_params:
            self.set_model_params(
                optimizer_name=optimizer_name,
                lr=lr,
                dropout_p=dropout_p,
                model_name=model_name,
                model_type=model_type,
                best_accuracy=best_accuracy,
                best_validation_loss=best_validation_loss,
                best_model_file=best_model_file,
                class_names=class_names,
                num_classes=num_classes,
                ingredient_names=ingredient_names,
                num_ingredients=num_ingredients,
            )

        self.model = self.model.to(device)

    def set_model_params(self,
                         criterion1=nn.CrossEntropyLoss(),
                         criterion2=nn.BCEWithLogitsLoss(),
                         optimizer_name='Adam',
                         lr=0.1,
                         dropout_p=0.45,
                         model_name='DenseNet',
                         model_type='cv_transfer',
                         best_accuracy=0.,
                         best_validation_loss=None,
                         best_model_file='best_model_file.pth',
                         head1={
                             'num_outputs': 10,
                             'layers': [],
                             'model_type': 'classifier'
                         },
                         head2={
                             'num_outputs': 10,
                             'layers': [],
                             'model_type': 'multi_label_classifier'
                         },
                         class_names=[],
                         num_classes=None,
                         ingredient_names=[],
                         num_ingredients=None):

        print(
            'Food Names: current best accuracy = {:.3f}'.format(best_accuracy))
        if best_validation_loss is not None:
            print('Food Ingredients: current best loss = {:.3f}'.format(
                best_validation_loss))

        super(FoodIngredients,
              self).set_model_params(optimizer_name=optimizer_name,
                                     lr=lr,
                                     dropout_p=dropout_p,
                                     model_name=model_name,
                                     model_type=model_type,
                                     best_accuracy=best_accuracy,
                                     best_validation_loss=best_validation_loss,
                                     best_model_file=best_model_file)
        self.class_names = class_names
        self.num_classes = num_classes
        self.ingredient_names = ingredient_names
        self.num_ingredients = num_ingredients
        self.criterion1 = criterion1
        self.criterion2 = criterion2

    def forward(self, x):
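        # all children except the last two form the shared backbone;
        # the final two modules are the food and ingredient heads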
        l = list(self.model.children())
        for m in l[:-2]:
            x = m(x)
        food = l[-2](x)
        ingredients = l[-1](x)
        return (food, ingredients)

    def compute_loss(self, outputs, labels, w1=1., w2=1.):
        out1, out2 = outputs
        label1, label2 = labels
        loss1 = self.criterion1(out1, label1)
        loss2 = self.criterion2(out2, label2)
        return [(loss1 * w1) + (loss2 * w2)]

    def freeze(self, train_classifier=True):
        super(FoodIngredients, self).freeze()
        if train_classifier:
            for param in self.model.fc1.parameters():
                param.requires_grad = True
            for param in self.model.fc2.parameters():
                param.requires_grad = True

    def parallelize(self):
        self.parallel = True
        self.model = DataParallelModel(self.model)
        # this class keeps two criteria (one per head), so wrap both
        self.criterion1 = DataParallelCriterion(self.criterion1)
        self.criterion2 = DataParallelCriterion(self.criterion2)

    def set_transfer_model(self,
                           mname,
                           pretrained=True,
                           add_extra=True,
                           dropout_p=0.45):
        self.model = None
        models_dict = {
            'densenet': {
                'model': models.densenet121(pretrained=pretrained),
                'conv_channels': 1024
            },
            'resnet34': {
                'model': models.resnet34(pretrained=pretrained),
                'conv_channels': 512
            },
            'resnet50': {
                'model': models.resnet50(pretrained=pretrained),
                'conv_channels': 2048
            }
        }
        # the dict lookup itself is what fails for an unknown name,
        # so it belongs inside the try
        try:
            meta = models_dict[mname.lower()]
            model = meta['model']
            for param in model.parameters():
                param.requires_grad = False
            self.model = model
            print(
                'Setting transfer learning model: self.model set to {}'.format(
                    mname))
        except KeyError:
            print(
                'Setting transfer learning model: model name {} not supported'.
                format(mname))
            return

        # creating and adding extra layers to the model
        dream_model = None
        if add_extra:
            channels = meta['conv_channels']
            dream_model = nn.Sequential(
                nn.Conv2d(channels, channels, 3, 1, 1),
                # Printer(),
                nn.BatchNorm2d(channels),
                nn.ReLU(True),
                nn.Dropout2d(dropout_p),
                nn.Conv2d(channels, channels, 3, 1, 1),
                nn.BatchNorm2d(channels),
                nn.ReLU(True),
                nn.Dropout2d(dropout_p),
                nn.Conv2d(channels, channels, 3, 1, 1),
                nn.BatchNorm2d(channels),
                nn.ReLU(True),
                nn.Dropout2d(dropout_p))
        self.dream_model = dream_model

    def set_model_head(
            self,
            model_name='DenseNet',
            head1={
                'num_outputs': 10,
                'layers': [],
                'class_names': None,
                'model_type': 'classifier'
            },
            head2={
                'num_outputs': 10,
                'layers': [],
                'class_names': None,
                'model_type': 'multi_label_classifier'
            },
            criterion1=nn.CrossEntropyLoss(),
            criterion2=nn.BCEWithLogitsLoss(),
            adaptive=True,
            dropout_p=0.45,
            device=None):

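        # per-backbone metadata: conv channel count, where to cut the module
        # list, and which pooling head to attach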
        models_meta = {
            'resnet34': {
                'conv_channels': 512,
                'head_id': -2,
                'adaptive_head': [DAI_AvgPool],
                'normal_head': [nn.AvgPool2d(7, 1)]
            },
            'resnet50': {
                'conv_channels': 2048,
                'head_id': -2,
                'adaptive_head': [DAI_AvgPool],
                'normal_head': [nn.AvgPool2d(7, 1)]
            },
            'densenet': {
                'conv_channels': 1024,
                'head_id': -1,
                'adaptive_head': [nn.ReLU(inplace=True), DAI_AvgPool],
                'normal_head': [nn.ReLU(inplace=True),
                                nn.AvgPool2d(7, 1)]
            }
        }

        name = model_name.lower()
        meta = models_meta[name]
        modules = list(self.model.children())
        l = modules[:meta['head_id']]
        if self.dream_model:
            l += self.dream_model
        heads = [head1, head2]
        crits = [criterion1, criterion2]
        fcs = []
        for head, criterion in zip(heads, crits):
            head['criterion'] = criterion
            if head['model_type'].lower() == 'classifier':
                head['output_non_linearity'] = None
            # default the key for heads that do not set it, so the FC
            # construction below cannot raise a KeyError
            head.setdefault('output_non_linearity', None)
            fc = modules[-1]
            try:
                in_features = fc.in_features
            except AttributeError:
                # the last module may itself be an FC wrapper
                in_features = fc.model.out.in_features
            fc = FC(num_inputs=in_features,
                    num_outputs=head['num_outputs'],
                    layers=head['layers'],
                    model_type=head['model_type'],
                    output_non_linearity=head['output_non_linearity'],
                    dropout_p=dropout_p,
                    criterion=head['criterion'],
                    optimizer_name=None,
                    device=device)
            fcs.append(fc)
        if adaptive:
            l += meta['adaptive_head']
        else:
            l += meta['normal_head']
        model = nn.Sequential(*l)
        model.add_module('fc1', fcs[0])
        model.add_module('fc2', fcs[1])
        self.model = model
        self.head1 = head1
        self.head2 = head2

        print('Multi-head set up complete.')

    def train_(self, e, trainloader, optimizer, print_every):

        epoch, epochs = e
        self.train()
        t0 = time.time()
        t1 = time.time()
        batches = 0
        running_loss = 0.
        for data_batch in trainloader:
            inputs, label1, label2 = data_batch[0], data_batch[1], data_batch[
                2]
            batches += 1
            inputs = inputs.to(self.device)
            label1 = label1.to(self.device)
            label2 = label2.to(self.device)
            labels = (label1, label2)
            optimizer.zero_grad()
            outputs = self.forward(inputs)
            loss = self.compute_loss(outputs, labels)[0]
            if self.parallel:
                loss.sum().backward()
                loss = loss.sum()
            else:
                loss.backward()
                loss = loss.item()
            optimizer.step()
            running_loss += loss
            if batches % print_every == 0:
                elapsed = time.time() - t1
                if elapsed > 60:
                    elapsed /= 60.
                    measure = 'min'
                else:
                    measure = 'sec'
                batch_time = time.time() - t0
                if batch_time > 60:
                    batch_time /= 60.
                    measure2 = 'min'
                else:
                    measure2 = 'sec'
                print(
                    '+----------------------------------------------------------------------+\n'
                    f"{time.asctime().split()[-2]}\n"
                    f"Time elapsed: {elapsed:.3f} {measure}\n"
                    f"Epoch:{epoch+1}/{epochs}\n"
                    f"Batch: {batches+1}/{len(trainloader)}\n"
                    f"Batch training time: {batch_time:.3f} {measure2}\n"
                    f"Batch training loss: {loss:.3f}\n"
                    f"Average training loss: {running_loss/(batches):.3f}\n"
                    '+----------------------------------------------------------------------+\n'
                )
                t0 = time.time()
        return running_loss / len(trainloader)

    def evaluate(self, dataloader, metric='accuracy'):

        running_loss = 0.
        classifier = None

        if self.model_type == 'classifier':  # or self.num_classes is not None:
            classifier = Classifier(self.class_names)

        y_pred = []
        y_true = []

        self.eval()
        rmse_ = 0.
        with torch.no_grad():
            for data_batch in dataloader:
                inputs, label1, label2 = data_batch[0], data_batch[
                    1], data_batch[2]
                inputs = inputs.to(self.device)
                label1 = label1.to(self.device)
                label2 = label2.to(self.device)
                labels = (label1, label2)
                outputs = self.forward(inputs)
                loss = self.compute_loss(outputs, labels)[0]
                if self.parallel:
                    running_loss += loss.sum()
                    outputs = parallel.gather(outputs, self.device)
                else:
                    running_loss += loss.item()
                if classifier is not None and metric == 'accuracy':
                    # food classification uses the first head and first label
                    classifier.update_accuracies(outputs[0], label1)
                    y_true.extend(list(label1.cpu().numpy()))
                    _, preds = torch.max(torch.exp(outputs[0]), 1)
                    y_pred.extend(list(preds.cpu().numpy()))
                elif metric == 'rmse':
                    rmse_ += rmse(outputs, labels).cpu().numpy()

        self.train()

        ret = {}
        # print('Running_loss: {:.3f}'.format(running_loss))
        if metric == 'rmse':
            print('Total rmse: {:.3f}'.format(rmse_))
            ret['final_rmse'] = rmse_ / len(dataloader)

        ret['final_loss'] = running_loss / len(dataloader)

        if classifier is not None:
            ret['accuracy'], ret[
                'class_accuracies'] = classifier.get_final_accuracies()
            ret['report'] = classification_report(
                y_true, y_pred, target_names=self.class_names)
            ret['confusion_matrix'] = confusion_matrix(y_true, y_pred)
            try:
                ret['roc_auc_score'] = roc_auc_score(y_true, y_pred)
            except ValueError:
                # roc_auc_score is undefined for multiclass label vectors
                pass
        return ret

    def evaluate_food(self, dataloader, metric='accuracy'):

        running_loss = 0.
        classifier = Classifier(self.class_names)

        y_pred = []
        y_true = []

        self.eval()
        rmse_ = 0.
        with torch.no_grad():
            for data_batch in dataloader:
                inputs, labels = data_batch[0], data_batch[1]
                inputs = inputs.to(self.device)
                labels = labels.to(self.device)
                outputs = self.forward(inputs)[0]
                if classifier is not None and metric == 'accuracy':
                    try:
                        classifier.update_accuracies(outputs, labels)
                        y_true.extend(list(labels.squeeze(0).cpu().numpy()))
                        _, preds = torch.max(torch.exp(outputs), 1)
                        y_pred.extend(list(preds.cpu().numpy()))
                    except Exception:
                        pass
                elif metric == 'rmse':
                    rmse_ += rmse(outputs, labels).cpu().numpy()

        self.train()

        ret = {}
        # print('Running_loss: {:.3f}'.format(running_loss))
        if metric == 'rmse':
            print('Total rmse: {:.3f}'.format(rmse_))
            ret['final_rmse'] = rmse_ / len(dataloader)

        ret['final_loss'] = running_loss / len(dataloader)  # no loss is accumulated in this method, so this stays 0.

        if classifier is not None:
            ret['accuracy'], ret[
                'class_accuracies'] = classifier.get_final_accuracies()
            ret['report'] = classification_report(
                y_true, y_pred, target_names=self.class_names)
            ret['confusion_matrix'] = confusion_matrix(y_true, y_pred)
            try:
                ret['roc_auc_score'] = roc_auc_score(y_true, y_pred)
            except ValueError:
                # roc_auc_score is undefined for multiclass label vectors
                pass
        return ret

    def find_lr(self,
                trn_loader,
                init_value=1e-8,
                final_value=10.,
                beta=0.98,
                plot=False):

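        # learning-rate range test: increase the lr exponentially from
        # init_value to final_value over one pass through trn_loader, track
        # the smoothed loss, and pick an lr a bit before its minimum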
        print('\nFinding the ideal learning rate.')

        model_state = copy.deepcopy(self.model.state_dict())
        optim_state = copy.deepcopy(self.optimizer.state_dict())
        optimizer = self.optimizer
        num = len(trn_loader) - 1
        mult = (final_value / init_value)**(1 / num)
        lr = init_value
        optimizer.param_groups[0]['lr'] = lr
        avg_loss = 0.
        best_loss = 0.
        batch_num = 0
        losses = []
        log_lrs = []
        for data_batch in trn_loader:
            batch_num += 1
            inputs, label1, label2 = data_batch[0], data_batch[1], data_batch[
                2]
            inputs = inputs.to(self.device)
            label1 = label1.to(self.device)
            label2 = label2.to(self.device)
            labels = (label1, label2)
            optimizer.zero_grad()
            outputs = self.forward(inputs)
            loss = self.compute_loss(outputs, labels)[0]
            #Compute the smoothed loss
            if self.parallel:
                avg_loss = beta * avg_loss + (1 - beta) * loss.sum()
            else:
                avg_loss = beta * avg_loss + (1 - beta) * loss.item()
            smoothed_loss = avg_loss / (1 - beta**batch_num)
            #Stop if the loss is exploding
            if batch_num > 1 and smoothed_loss > 4 * best_loss:
                self.log_lrs, self.find_lr_losses = log_lrs, losses
                self.model.load_state_dict(model_state)
                self.optimizer.load_state_dict(optim_state)
                if plot:
                    self.plot_find_lr()
                temp_lr = self.log_lrs[np.argmin(self.find_lr_losses) -
                                       (len(self.log_lrs) // 8)]
                self.lr = (10**temp_lr)
                print('Found it: {}\n'.format(self.lr))
                return self.lr
            #Record the best loss
            if smoothed_loss < best_loss or batch_num == 1:
                best_loss = smoothed_loss
            #Store the values
            losses.append(smoothed_loss)
            log_lrs.append(math.log10(lr))
            #Do the SGD step
            if self.parallel:
                loss.sum().backward()
            else:
                loss.backward()
            optimizer.step()
            #Update the lr for the next step
            lr *= mult
            optimizer.param_groups[0]['lr'] = lr

        self.log_lrs, self.find_lr_losses = log_lrs, losses
        self.model.load_state_dict(model_state)
        self.optimizer.load_state_dict(optim_state)
        if plot:
            self.plot_find_lr()
        temp_lr = self.log_lrs[np.argmin(self.find_lr_losses) -
                               (len(self.log_lrs) // 10)]
        self.lr = (10**temp_lr)
        print('Found it: {}\n'.format(self.lr))
        return self.lr

    def plot_find_lr(self):
        plt.ylabel("Loss")
        plt.xlabel("Learning Rate (log scale)")
        plt.plot(self.log_lrs, self.find_lr_losses)
        plt.show()

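    # classify: food class from the first head (argmax), ingredients from the
    # second head thresholded at `thresh` after a sigmoid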
    def classify(self,
                 inputs,
                 thresh=0.4):  #,show = False,mean = None,std = None):
        outputs = self.predict(inputs)
        food, ing = outputs
        try:
            _, preds = torch.max(torch.exp(food), 1)
        except IndexError:
            # a single un-batched prediction comes back 1-D
            _, preds = torch.max(torch.exp(food.unsqueeze(0)), 1)
        ing_outs = ing.sigmoid()
        ings = (ing_outs >= thresh)
        class_preds = [str(self.class_names[p]) for p in preds]
        ing_preds = [
            self.ingredient_names[p.nonzero().squeeze(1).cpu()] for p in ings
        ]
        return class_preds, ing_preds

    def _get_dropout(self):
        return self.dropout_p

    def get_model_params(self):
        params = super(FoodIngredients, self).get_model_params()
        params['class_names'] = self.class_names
        params['num_classes'] = self.num_classes
        params['ingredient_names'] = self.ingredient_names
        params['num_ingredients'] = self.num_ingredients
        params['head1'] = self.head1
        params['head2'] = self.head2
        return params
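
# A minimal usage sketch (hypothetical loaders and name lists; the training
# entry points live in the parent Network class):
#   net = FoodIngredients(model_name='DenseNet',
#                         class_names=class_names,
#                         num_classes=len(class_names),
#                         ingredient_names=ingredient_names,
#                         num_ingredients=len(ingredient_names))
#   lr = net.find_lr(train_loader)                      # pick a learning rate
#   class_preds, ing_preds = net.classify(images, thresh=0.4)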
Example #3
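# Cross-validation training loop for MiniSeg. Assumed context (the snippet
# was extracted from a larger project): torch, torch.backends.cudnn as cudnn,
# numpy as np, os, os.path as osp, collections.OrderedDict, plus local
# modules providing ld.LoadData, net.MiniSeg, myTransforms, myDataLoader,
# CrossEntropyLoss2d, DataParallelModel, DataParallelCriterion and the
# train / val functions.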
def main_tr(args, crossVal):
    dataLoad = ld.LoadData(args.data_dir, args.classes)
    data = dataLoad.processData(crossVal, args.data_name)

    # load the model
    model = net.MiniSeg(args.classes, aux=True)
    # build the output directory tree:
    # <savedir>_mod<max_epochs>/<data_name>/<model_name>
    saveDirBase = args.savedir + '_mod' + str(args.max_epochs)
    if not osp.isdir(saveDirBase):
        os.mkdir(saveDirBase)
    if not osp.isdir(osp.join(saveDirBase, args.data_name)):
        os.mkdir(osp.join(saveDirBase, args.data_name))
    saveDir = osp.join(saveDirBase, args.data_name, args.model_name)
    # create the directory if it does not exist
    if not osp.exists(saveDir):
        os.mkdir(saveDir)

    if args.gpu and torch.cuda.device_count() > 1:
        #model = torch.nn.DataParallel(model)
        model = DataParallelModel(model)
    if args.gpu:
        model = model.cuda()

    total_parameters = sum([np.prod(p.size()) for p in model.parameters()])
    print('Total network parameters: ' + str(total_parameters))

    # define optimization criteria
    weight = torch.from_numpy(
        data['classWeights'])  # convert the numpy array to torch
    if args.gpu:
        weight = weight.cuda()

    criteria = CrossEntropyLoss2d(weight, args.ignore_label)  #weight
    if args.gpu and torch.cuda.device_count() > 1:
        criteria = DataParallelCriterion(criteria)
    if args.gpu:
        criteria = criteria.cuda()

    # compose the data with transforms
    trainDataset_main = myTransforms.Compose([
        myTransforms.Normalize(mean=data['mean'], std=data['std']),
        myTransforms.Scale(args.width, args.height),
        myTransforms.RandomCropResize(int(32. / 1024. * args.width)),
        myTransforms.RandomFlip(),
        myTransforms.ToTensor()
    ])
    trainDataset_scale1 = myTransforms.Compose([
        myTransforms.Normalize(mean=data['mean'], std=data['std']),
        myTransforms.Scale(int(args.width * 1.5), int(args.height * 1.5)),
        myTransforms.RandomCropResize(int(100. / 1024. * args.width)),
        myTransforms.RandomFlip(),
        myTransforms.ToTensor()
    ])

    trainDataset_scale2 = myTransforms.Compose([
        myTransforms.Normalize(mean=data['mean'], std=data['std']),
        myTransforms.Scale(int(args.width * 1.25), int(args.height * 1.25)),
        myTransforms.RandomCropResize(int(100. / 1024. * args.width)),
        myTransforms.RandomFlip(),
        myTransforms.ToTensor()
    ])
    trainDataset_scale3 = myTransforms.Compose([
        myTransforms.Normalize(mean=data['mean'], std=data['std']),
        myTransforms.Scale(int(args.width * 0.75), int(args.height * 0.75)),
        myTransforms.RandomCropResize(int(32. / 1024. * args.width)),
        myTransforms.RandomFlip(),
        myTransforms.ToTensor()
    ])

    valDataset = myTransforms.Compose([
        myTransforms.Normalize(mean=data['mean'], std=data['std']),
        myTransforms.Scale(args.width, args.height),
        myTransforms.ToTensor()
    ])

    # since we are training from scratch, we create data loaders at different scales
    # so that we can generate more augmented data and prevent the network from overfitting
    trainLoader = torch.utils.data.DataLoader(myDataLoader.Dataset(
        data['trainIm'], data['trainAnnot'], transform=trainDataset_main),
                                              batch_size=args.batch_size,
                                              shuffle=True,
                                              num_workers=args.num_workers,
                                              pin_memory=True,
                                              drop_last=True)

    trainLoader_scale1 = torch.utils.data.DataLoader(
        myDataLoader.Dataset(data['trainIm'],
                             data['trainAnnot'],
                             transform=trainDataset_scale1),
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
        pin_memory=True,
        drop_last=True)

    trainLoader_scale2 = torch.utils.data.DataLoader(
        myDataLoader.Dataset(data['trainIm'],
                             data['trainAnnot'],
                             transform=trainDataset_scale2),
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
        pin_memory=True,
        drop_last=True)
    trainLoader_scale3 = torch.utils.data.DataLoader(
        myDataLoader.Dataset(data['trainIm'],
                             data['trainAnnot'],
                             transform=trainDataset_scale3),
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
        pin_memory=True,
        drop_last=True)

    valLoader = torch.utils.data.DataLoader(myDataLoader.Dataset(
        data['valIm'], data['valAnnot'], transform=valDataset),
                                            batch_size=args.batch_size,
                                            shuffle=False,
                                            num_workers=args.num_workers,
                                            pin_memory=True)
    max_batches = len(trainLoader) + len(trainLoader_scale1) + len(
        trainLoader_scale2) + len(trainLoader_scale3)

    if args.gpu:
        cudnn.benchmark = True

    start_epoch = 0

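    # optionally warm-start from pretrained weights, dropping the
    # prediction-head ('pred') entries so the backbone loads even when the
    # head shapes differ (hence strict=False below)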
    if args.pretrained is not None:
        state_dict = torch.load(args.pretrained)
        new_keys = []
        new_values = []
        for idx, key in enumerate(state_dict.keys()):
            if 'pred' not in key:
                new_keys.append(key)
                new_values.append(list(state_dict.values())[idx])
        new_dict = OrderedDict(list(zip(new_keys, new_values)))
        model.load_state_dict(new_dict, strict=False)
        print('pretrained model loaded')

    if args.resume is not None:
        if osp.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            start_epoch = checkpoint['epoch']
            args.lr = checkpoint['lr']
            model.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    log_file = osp.join(saveDir, 'trainValLog_' + args.model_name + '.txt')
    if osp.isfile(log_file):
        logger = open(log_file, 'a')
    else:
        logger = open(log_file, 'w')
        logger.write("Parameters: %s" % (str(total_paramters)))
        logger.write("\n%s\t%s\t\t%s\t%s\t%s\t%s\tlr" %
                     ('CrossVal', 'Epoch', 'Loss(Tr)', 'Loss(val)',
                      'mIOU (tr)', 'mIOU (val)'))
    logger.flush()

    optimizer = torch.optim.Adam(model.parameters(),
                                 args.lr, (0.9, 0.999),
                                 eps=1e-08,
                                 weight_decay=1e-4)
    maxmIOU = 0
    maxEpoch = 0
    print(args.model_name + '-CrossVal: ' + str(crossVal + 1))
    for epoch in range(start_epoch, args.max_epochs):
        # train for one epoch
        cur_iter = 0

        train(args, trainLoader_scale1, model, criteria, optimizer, epoch,
              max_batches, cur_iter)
        cur_iter += len(trainLoader_scale1)
        train(args, trainLoader_scale2, model, criteria, optimizer, epoch,
              max_batches, cur_iter)
        cur_iter += len(trainLoader_scale2)
        train(args, trainLoader_scale3, model, criteria, optimizer, epoch,
              max_batches, cur_iter)
        cur_iter += len(trainLoader_scale3)
        lossTr, overall_acc_tr, per_class_acc_tr, per_class_iu_tr, mIOU_tr, lr = \
                train(args, trainLoader, model, criteria, optimizer, epoch, max_batches, cur_iter)

        # evaluate on validation set
        lossVal, overall_acc_val, per_class_acc_val, per_class_iu_val, mIOU_val = \
                val(args, valLoader, model, criteria)

        torch.save(
            {
                'epoch': epoch + 1,
                'arch': str(model),
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'lossTr': lossTr,
                'lossVal': lossVal,
                'iouTr': mIOU_tr,
                'iouVal': mIOU_val,
                'lr': lr
            },
            osp.join(
                saveDir, 'checkpoint_' + args.model_name + '_crossVal' +
                str(crossVal + 1) + '.pth.tar'))

        # save the model also
        model_file_name = osp.join(
            saveDir, 'model_' + args.model_name + '_crossVal' +
            str(crossVal + 1) + '_' + str(epoch + 1) + '.pth')
        torch.save(model.state_dict(), model_file_name)

        logger.write(
            "\n%d\t\t%d\t\t%.4f\t\t%.4f\t\t%.4f\t\t%.4f\t\t%.7f" %
            (crossVal + 1, epoch + 1, lossTr, lossVal, mIOU_tr, mIOU_val, lr))
        logger.flush()
        print("\nEpoch No. %d:\tTrain Loss = %.4f\tVal Loss = %.4f\t mIOU(tr) = %.4f\t mIOU(val) = %.4f\n" \
                % (epoch + 1, lossTr, lossVal, mIOU_tr, mIOU_val))

        if mIOU_val >= maxmIOU:
            maxmIOU = mIOU_val
            maxEpoch = epoch + 1
        torch.cuda.empty_cache()
    logger.flush()
    logger.close()
    return maxEpoch, maxmIOU
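
# Typical driver (hypothetical; `args` comes from argparse and `folds` is the
# number of cross-validation folds in your setup):
#   for crossVal in range(folds):
#       maxEpoch, maxmIOU = main_tr(args, crossVal)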