Example #1
import os
import math
import time
from functools import partial

import numpy as np
import torch
import torch.nn.functional as F
from torch.autograd import Variable
from torch.utils import data
from tqdm import tqdm
from tensorboardX import SummaryWriter

# Project-local imports (module paths assumed; adjust to the repository layout):
# CityscapesLoader, RFMobileNetV2Plus, InPlaceABNWrapper, RunningScore, lovasz_softmax,
# the augmentations (Compose, RandomHorizontallyFlip, RandomSized, RandomRotate,
# RandomCrop, Scale, CenterCrop) and the schedule helpers (cosine_annealing_lr,
# set_optimizer_lr, update_aggregated_weight_average).


def train(args, data_root, save_root):
    weight_dir = "{}weights/".format(save_root)
    log_dir = "{}logs/RFMobileNetV2Plus-{}".format(save_root, time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime()))

    # +++++++++++++++++++++++++++++++++++++++++++++++++++ #
    # 1. Setup Augmentations
    # +++++++++++++++++++++++++++++++++++++++++++++++++++ #
    net_h, net_w = int(args.img_rows*args.crop_ratio), int(args.img_cols*args.crop_ratio)

    augment_train = Compose([RandomHorizontallyFlip(), RandomSized((0.5, 0.75)),
                             RandomRotate(5), RandomCrop((net_h, net_w))])
    augment_valid = Compose([RandomHorizontallyFlip(), Scale((args.img_rows, args.img_cols)),
                             CenterCrop((net_h, net_w))])

    print("> # +++++++++++++++++++++++++++++++++++++++++++++++++++++++ #")
    print("> 0. Setting up DataLoader...")
    print("> # +++++++++++++++++++++++++++++++++++++++++++++++++++++++ #")
    train_loader = CityscapesLoader(data_root, gt="gtFine", is_transform=True, split='train',
                                    img_size=(args.img_rows, args.img_cols),
                                    augmentations=augment_train)
    valid_loader = CityscapesLoader(data_root, gt="gtFine", is_transform=True, split='val',
                                    img_size=(args.img_rows, args.img_cols),
                                    augmentations=augment_valid)

    n_classes = train_loader.n_classes

    # +++++++++++++++++++++++++++++++++++++++++++++++++++ #
    # 2. Setup Metrics
    # +++++++++++++++++++++++++++++++++++++++++++++++++++ #
    running_metrics = RunningScore(n_classes)

    # +++++++++++++++++++++++++++++++++++++++++++++++++++ #
    # 3. Setup Model
    # +++++++++++++++++++++++++++++++++++++++++++++++++++ #
    print("> # +++++++++++++++++++++++++++++++++++++++++++++++++++++++ #")
    print("> 1. Setting up Model...")

    model = RFMobileNetV2Plus(n_class=n_classes, in_size=(net_h, net_w), width_mult=1.0,
                              out_sec=256, aspp_sec=(12, 24, 36),
                              norm_act=partial(InPlaceABNWrapper, activation="leaky_relu", slope=0.1))
    """

    model = MobileNetV2Plus(n_class=n_classes, in_size=(net_h, net_w), width_mult=1.0,
                            out_sec=256, aspp_sec=(12, 24, 36),
                            norm_act=partial(InPlaceABNWrapper, activation="leaky_relu", slope=0.1))
    """
    # Use device_ids=list(range(torch.cuda.device_count())) to train on all visible GPUs.
    model = torch.nn.DataParallel(model, device_ids=[0, 1]).cuda()

    # 3.1 Setup Optimizer
    # +++++++++++++++++++++++++++++++++++++++++++++++++++ #
    # Check if model has custom optimizer / loss
    if hasattr(model.module, 'optimizer'):
        optimizer = model.module.optimizer
    else:
        optimizer = torch.optim.SGD(model.parameters(), lr=args.l_rate, momentum=0.90,
                                    weight_decay=5e-4, nesterov=True)

        # for pg in optimizer.param_groups:
        #     print(pg['lr'])

        # optimizer = torch.optim.Adam(model.parameters(), lr=0.001, betas=(0.9, 0.999),
        #                             eps=1e-08, weight_decay=0, amsgrad=True)
        # optimizer = YFOptimizer(model.parameters(), lr=2.5e-3, mu=0.9, clip_thresh=10000, weight_decay=5e-4)

    # 3.2 Setup Loss
    # +++++++++++++++++++++++++++++++++++++++++++++++++++ #
    class_weight = None
    if hasattr(model.module, 'loss'):
        print('> Using custom loss')
        loss_fn = model.module.loss
    else:
        loss_fn = lovasz_softmax

    # +++++++++++++++++++++++++++++++++++++++++++++++++++ #
    # 4. Resume Model
    # +++++++++++++++++++++++++++++++++++++++++++++++++++ #
    best_iou = -100.0
    args.start_epoch = 0
    if args.resume is not None:
        full_path = "{}{}".format(weight_dir, args.resume)
        if os.path.isfile(full_path):
            print("> Loading model and optimizer from checkpoint '{}'".format(args.resume))

            checkpoint = torch.load(full_path)
            args.start_epoch = checkpoint['epoch']
            best_iou = checkpoint['best_iou']
            model.load_state_dict(checkpoint['model_state'])          # weights
            optimizer.load_state_dict(checkpoint['optimizer_state'])  # gradient state

            # for param_group in optimizer.param_groups:
            #     param_group['lr'] = 1e-5

            del checkpoint
            print("> Loaded checkpoint '{}' (epoch {}, iou {})".format(args.resume,
                                                                       args.start_epoch,
                                                                       best_iou))

        else:
            print("> No checkpoint found at '{}'".format(args.resume))
    else:
        if args.pre_trained is not None:
            print("> Loading weights from pre-trained model '{}'".format(args.pre_trained))
            full_path = "{}{}".format(weight_dir, args.pre_trained)

            pre_weight = torch.load(full_path)
            pre_weight = pre_weight["model_state"]
            # pre_weight = pre_weight["state_dict"]

            model_dict = model.state_dict()

            pretrained_dict = {k: v for k, v in pre_weight.items() if k in model_dict}
            model_dict.update(pretrained_dict)
            model.load_state_dict(model_dict)

            del pre_weight
            del model_dict
            del pretrained_dict

    # +++++++++++++++++++++++++++++++++++++++++++++++++++ #
    # 5. Setup TensorBoard for visualization
    # +++++++++++++++++++++++++++++++++++++++++++++++++++ #
    writer = None
    if args.tensor_board:
        writer = SummaryWriter(log_dir=log_dir, comment="RFMobileNetV2Plus")
        dummy_input = Variable(torch.rand(1, 3, net_h, net_w).cuda(), requires_grad=True)
        writer.add_graph(model, dummy_input)

    # +++++++++++++++++++++++++++++++++++++++++++++++++++ #
    # 6. Train Model
    # +++++++++++++++++++++++++++++++++++++++++++++++++++ #
    print("> 2. Model Training start...")
    train_loader = data.DataLoader(train_loader, batch_size=args.batch_size, num_workers=6, shuffle=True)
    valid_loader = data.DataLoader(valid_loader, batch_size=args.batch_size, num_workers=6)

    num_batches = int(math.ceil(len(train_loader.dataset.files[train_loader.dataset.split]) /
                                float(train_loader.batch_size)))

    lr_period = 20 * num_batches
    swa_weights = model.state_dict()
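    # swa_weights holds an aggregated (averaged) copy of the model weights, updated
    # once every lr_period iterations, i.e. at each cosine-restart boundary
    # (see the helper sketch after this example).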

    # scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.90)
    # scheduler = CyclicLR(optimizer, base_lr=1.0e-3, max_lr=6.0e-3, step_size=2*num_batches)
    # scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=32, gamma=0.1)
    # scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[30, 80], gamma=0.1)

    for epoch in np.arange(args.start_epoch, args.n_epoch):
        # +++++++++++++++++++++++++++++++++++++++++++++++++++ #
        # 6.1 Mini-Batch Learning
        # +++++++++++++++++++++++++++++++++++++++++++++++++++ #
        # print("> Training Epoch [%d/%d]:" % (epoch + 1, args.n_epoch))
        model.train()

        last_loss = 0.0
        pbar = tqdm(np.arange(num_batches))
        for train_i, (images, labels) in enumerate(train_loader):  # One mini-Batch data, One iteration
            full_iter = (epoch * num_batches) + train_i + 1

            # poly_lr_scheduler(optimizer, init_lr=args.l_rate, iter=full_iter,
            #                   lr_decay_iter=1, max_iter=args.n_epoch*num_batches, power=0.9)

            batch_lr = args.l_rate * cosine_annealing_lr(lr_period, full_iter)
            optimizer = set_optimizer_lr(optimizer, batch_lr)

            images = Variable(images.cuda(), requires_grad=True)   # input batch fed to the network
            labels = Variable(labels.cuda(), requires_grad=False)

            optimizer.zero_grad()
            net_out = model(images)              # segmentation logits: (B, n_classes, H, W)
            net_out = F.softmax(net_out, dim=1)  # lovasz_softmax expects class probabilities
            loss = lovasz_softmax(net_out, labels, ignore=250)  # 250 = void/ignore label

            last_loss = loss.data[0]
            pbar.update(1)
            pbar.set_description("> Epoch [%d/%d]" % (epoch + 1, args.n_epoch))
            pbar.set_postfix(Loss=last_loss, LR=batch_lr)

            loss.backward()
            optimizer.step()

            if full_iter % lr_period == 0:
                swa_weights = update_aggregated_weight_average(model, swa_weights, full_iter, lr_period)
                state = {'model_state': swa_weights}
                torch.save(state, "{}{}_rfmobilenetv2_swa_model.pkl".format(weight_dir, args.dataset))

            if (train_i + 1) % 31 == 0:
                loss_log = "Epoch [%d/%d], Iter: %d Loss: \t %.4f" % (epoch + 1, args.n_epoch,
                                                                      train_i + 1, last_loss)

                # net_out = F.softmax(net_out, dim=1)
                pred = net_out.data.max(1)[1].cpu().numpy()
                gt = labels.data.cpu().numpy()

                running_metrics.update(gt, pred)
                score, class_iou = running_metrics.get_scores()

                metric_log = ""
                for k, v in score.items():
                    metric_log += " {}: \t %.4f, ".format(k) % v
                running_metrics.reset()

                logs = loss_log + metric_log
                # print(logs)

                if args.tensor_board:
                    writer.add_scalar('Training/Losses', last_loss, full_iter)
                    writer.add_scalars('Training/Metrics', score, full_iter)
                    writer.add_text('Training/Text', logs, full_iter)

                    for name, param in model.named_parameters():
                        writer.add_histogram(name, param.clone().cpu().data.numpy(), full_iter)

        # +++++++++++++++++++++++++++++++++++++++++++++++++++ #
        # 6.2 Mini-Batch Validation
        # +++++++++++++++++++++++++++++++++++++++++++++++++++ #
        # print("> Validation for Epoch [%d/%d]:" % (epoch + 1, args.n_epoch))
        model.eval()

        mval_loss = 0.0
        vali_count = 0
        for i_val, (images, labels) in enumerate(valid_loader):
            vali_count += 1

            images = Variable(images.cuda(), volatile=True)
            labels = Variable(labels.cuda(), volatile=True)

            net_out = model(images)              # segmentation logits, as in training
            net_out = F.softmax(net_out, dim=1)
            loss = lovasz_softmax(net_out, labels, ignore=250)
            mval_loss += loss.data[0]

            pred = net_out.data.max(1)[1].cpu().numpy()
            gt = labels.data.cpu().numpy()
            running_metrics.update(gt, pred)

        mval_loss /= vali_count

        loss_log = "Epoch [%d/%d] Loss: \t %.4f" % (epoch + 1, args.n_epoch, mval_loss)
        metric_log = ""
        score, class_iou = running_metrics.get_scores()
        for k, v in score.items():
            metric_log += " {} \t %.4f, ".format(k) % v
        running_metrics.reset()

        logs = loss_log + metric_log
        # print(logs)
        pbar.set_postfix(Train_Loss=last_loss, Vali_Loss=mval_loss, Vali_mIoU=score['Mean_IoU'])

        if args.tensor_board:
            writer.add_scalar('Validation/Losses', mval_loss, epoch)
            writer.add_scalars('Validation/Metrics', score, epoch)
            writer.add_text('Validation/Text', logs, epoch)

            for name, param in model.named_parameters():
                writer.add_histogram(name, param.clone().cpu().data.numpy(), epoch)

        if score['Mean_IoU'] >= best_iou:
            best_iou = score['Mean_IoU']
            state = {'epoch': epoch + 1,
                     "best_iou": best_iou,
                     'model_state': model.state_dict(),
                     'optimizer_state': optimizer.state_dict()}
            torch.save(state, "{}{}_rfmobilenetv2_lovasz_best_model.pkl".format(weight_dir, args.dataset))

        # scheduler.step()
        # scheduler.batch_step()
        pbar.close()

    if args.tensor_board:
        # export scalar data to JSON for external processing
        # writer.export_scalars_to_json("{}/all_scalars.json".format(log_dir))
        writer.close()
    print("> # +++++++++++++++++++++++++++++++++++++++++++++++++++++++ #")
    print("> Training Done!!!")
    print("> # +++++++++++++++++++++++++++++++++++++++++++++++++++++++ #")
Example #2
import os
import math
import time
import random
from functools import partial

import numpy as np
import torch
import torch.nn.functional as F
from torch.autograd import Variable
from torch.utils import data
from tqdm import tqdm
from tensorboardX import SummaryWriter

# Project-local imports (module paths assumed; adjust to the repository layout):
# CityscapesLoader, MobileNetV2Context, InPlaceABNWrapper, RunningScore,
# bootstrapped_cross_entropy2d, SemanticEncodingLoss, the augmentations, and the
# schedule helpers (cosine_annealing_lr, set_optimizer_lr, poly_topk_scheduler,
# update_aggregated_weight_average).


def train(args, data_root, save_root):
    weight_dir = "{}weights/".format(save_root)
    log_dir = "{}logs/MobileNetV2Context-{}".format(save_root, time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime()))

    # +++++++++++++++++++++++++++++++++++++++++++++++++++ #
    # 1. Setup Augmentations
    # +++++++++++++++++++++++++++++++++++++++++++++++++++ #
    net_h, net_w = int(args.img_rows*args.crop_ratio), int(args.img_cols*args.crop_ratio)

    augment_train = Compose([RandomHorizontallyFlip(), RandomSized((0.5, 0.75)),
                             RandomRotate(5), RandomCrop((net_h, net_w))])
    augment_valid = Compose([RandomHorizontallyFlip(), Scale((args.img_rows, args.img_cols)),
                             CenterCrop((net_h, net_w))])

    print("> # +++++++++++++++++++++++++++++++++++++++++++++++++++++++ #")
    print("> 0. Setting up DataLoader...")
    print("> # +++++++++++++++++++++++++++++++++++++++++++++++++++++++ #")
    train_loader = CityscapesLoader(data_root, gt="gtFine", is_transform=True, split='train',
                                    img_size=(args.img_rows, args.img_cols),
                                    augmentations=augment_train)
    valid_loader = CityscapesLoader(data_root, gt="gtFine", is_transform=True, split='val',
                                    img_size=(args.img_rows, args.img_cols),
                                    augmentations=augment_valid)

    n_classes = train_loader.n_classes

    # +++++++++++++++++++++++++++++++++++++++++++++++++++ #
    # 2. Setup Metrics
    # +++++++++++++++++++++++++++++++++++++++++++++++++++ #
    running_metrics = RunningScore(n_classes)

    # +++++++++++++++++++++++++++++++++++++++++++++++++++ #
    # 3. Setup Model
    # +++++++++++++++++++++++++++++++++++++++++++++++++++ #
    print("> # +++++++++++++++++++++++++++++++++++++++++++++++++++++++ #")
    print("> 1. Setting up Model...")

    model = MobileNetV2Context(n_class=n_classes, in_size=(net_h, net_w), width_mult=1., out_sec=256, context=(32, 4),
                               norm_act=partial(InPlaceABNWrapper, activation="leaky_relu", slope=0.1))

    # Use device_ids=list(range(torch.cuda.device_count())) to train on all visible GPUs.
    model = torch.nn.DataParallel(model, device_ids=[0]).cuda()

    # 3.1 Setup Optimizer
    # +++++++++++++++++++++++++++++++++++++++++++++++++++ #
    # Check if model has custom optimizer / loss
    if hasattr(model.module, 'optimizer'):
        optimizer = model.module.optimizer
    else:
        optimizer = torch.optim.SGD(model.parameters(), lr=args.l_rate, momentum=0.90,
                                    weight_decay=5e-4, nesterov=True)

    # 3.2 Setup Loss
    # +++++++++++++++++++++++++++++++++++++++++++++++++++ #
    class_weight = None
    if hasattr(model.module, 'loss'):
        print('> Using custom loss')
        loss_fn = model.module.loss
    else:
        # loss_fn = cross_entropy2d

        class_weight = np.array([0.05570516, 0.32337477, 0.08998544, 1.03602707, 1.03413147, 1.68195437,
                                 5.58540548, 3.56563995, 0.12704978, 1.,         0.46783719, 1.34551528,
                                 5.29974114, 0.28342531, 0.9396095,  0.81551811, 0.42679146, 3.6399074,
                                 2.78376194], dtype=float)

        """
        class_weight = np.array([3.045384,  12.862123,   4.509889,  38.15694,  35.25279,  31.482613,
                                 45.792305,  39.694073,  6.0639296,  32.16484,  17.109228,   31.563286,
                                 47.333973,  11.610675,  44.60042,   45.23716,  45.283024,  48.14782,
                                 41.924667], dtype=float)/10.0
        """
        class_weight = torch.from_numpy(class_weight).float().cuda()
        se_loss = SemanticEncodingLoss(num_classes=n_classes, ignore_label=250, alpha=0.20)
        ce_loss = bootstrapped_cross_entropy2d

    # +++++++++++++++++++++++++++++++++++++++++++++++++++ #
    # 4. Resume Model
    # +++++++++++++++++++++++++++++++++++++++++++++++++++ #
    best_iou = -100.0
    args.start_epoch = 0
    if args.resume is not None:
        full_path = "{}{}".format(weight_dir, args.resume)
        if os.path.isfile(full_path):
            print("> Loading model and optimizer from checkpoint '{}'".format(args.resume))

            checkpoint = torch.load(full_path)
            args.start_epoch = checkpoint['epoch']
            best_iou = checkpoint['best_iou']
            model.load_state_dict(checkpoint['model_state'])          # weights
            optimizer.load_state_dict(checkpoint['optimizer_state'])  # gradient state

            # for param_group in optimizer.param_groups:
            #     param_group['lr'] = 1e-5

            del checkpoint
            print("> Loaded checkpoint '{}' (epoch {}, iou {})".format(args.resume,
                                                                       args.start_epoch,
                                                                       best_iou))

        else:
            print("> No checkpoint found at '{}'".format(args.resume))
    else:
        if args.pre_trained is not None:
            print("> Loading weights from pre-trained model '{}'".format(args.pre_trained))
            full_path = "{}{}".format(weight_dir, args.pre_trained)

            pre_weight = torch.load(full_path)
            pre_weight = pre_weight["model_state"]
            # pre_weight = pre_weight["state_dict"]

            model_dict = model.state_dict()

            pretrained_dict = {k: v for k, v in pre_weight.items() if k in model_dict}
            model_dict.update(pretrained_dict)
            model.load_state_dict(model_dict)

            del pre_weight
            del model_dict
            del pretrained_dict

    # +++++++++++++++++++++++++++++++++++++++++++++++++++ #
    # 5. Setup TensorBoard for visualization
    # +++++++++++++++++++++++++++++++++++++++++++++++++++ #
    writer = None
    if args.visdom:
        writer = SummaryWriter(log_dir=log_dir, comment="MobileNetV2Context")

        dummy_input = Variable(torch.rand(1, 3, net_h, net_w).cuda(), requires_grad=True)
        writer.add_graph(model, dummy_input)

    # +++++++++++++++++++++++++++++++++++++++++++++++++++ #
    # 6. Train Model
    # +++++++++++++++++++++++++++++++++++++++++++++++++++ #
    print("> 2. Model Training start...")
    train_loader = data.DataLoader(train_loader, batch_size=args.batch_size, num_workers=6, shuffle=True)
    valid_loader = data.DataLoader(valid_loader, batch_size=args.batch_size, num_workers=6)

    num_batches = int(math.ceil(len(train_loader.dataset.files[train_loader.dataset.split]) /
                                float(train_loader.batch_size)))

    lr_period = 20 * num_batches
    swa_weights = model.state_dict()
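    # swa_weights holds an aggregated (averaged) copy of the model weights, updated
    # once every lr_period iterations, i.e. at each cosine-restart boundary
    # (see the helper sketch after Example #1).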

    # scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.90)
    # scheduler = CyclicLR(optimizer, base_lr=1.0e-3, max_lr=6.0e-3, step_size=2*num_batches)
    # scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=32, gamma=0.1)
    # scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[30, 80], gamma=0.1)

    topk_init = 512
    # topk_multipliers = [64, 128, 256, 512]
    for epoch in np.arange(args.start_epoch, args.n_epoch):
        # +++++++++++++++++++++++++++++++++++++++++++++++++++ #
        # 6.1 Mini-Batch Learning
        # +++++++++++++++++++++++++++++++++++++++++++++++++++ #
        # print("> Training Epoch [%d/%d]:" % (epoch + 1, args.n_epoch))
        model.train()

        last_loss = 0.0
        topk_base = topk_init
        pbar = tqdm(np.arange(num_batches))
        for train_i, (images, labels) in enumerate(train_loader):  # One mini-Batch data, One iteration
            full_iter = (epoch * num_batches) + train_i + 1

            # poly_lr_scheduler(optimizer, init_lr=args.l_rate, iter=full_iter,
            #                   lr_decay_iter=1, max_iter=args.n_epoch*num_batches, power=0.9)

            batch_lr = args.l_rate * cosine_annealing_lr(lr_period, full_iter)
            optimizer = set_optimizer_lr(optimizer, batch_lr)

            topk_base = poly_topk_scheduler(init_topk=topk_init, iter=full_iter, topk_decay_iter=1,
                                            max_iter=args.n_epoch*num_batches, power=0.95)

            images = Variable(images.cuda(), requires_grad=True)   # input batch fed to the network
            se_labels = se_loss.unique_encode(labels)
            se_labels = Variable(se_labels.cuda(), requires_grad=False)
            ce_labels = Variable(labels.cuda(), requires_grad=False)

            optimizer.zero_grad()
            enc1, enc2, net_out = model(images)  # three outputs: two encoding heads + segmentation logits

            topk = topk_base * 512
            if random.random() < 0.20:
                train_ce_loss = ce_loss(input=net_out, target=ce_labels, K=topk,
                                        weight=class_weight, size_average=True)
            else:
                train_ce_loss = ce_loss(input=net_out, target=ce_labels, K=topk,
                                        weight=None, size_average=True)

            # One semantic-encoding loss per encoding head (the second term
            # presumably was meant to use enc2, which is otherwise unused).
            train_se_loss1 = se_loss(predicts=enc1, enc_cls_target=se_labels, size_average=True)
            train_se_loss2 = se_loss(predicts=enc2, enc_cls_target=se_labels, size_average=True)

            train_loss = train_ce_loss + train_se_loss1 + train_se_loss2

            last_loss = train_loss.data[0]
            last_ce_loss = train_ce_loss.data[0]
            last_se_loss1 = train_se_loss1.data[0]
            last_se_loss2 = train_se_loss2.data[0]
            pbar.update(1)
            pbar.set_description("> Epoch [%d/%d]" % (epoch + 1, args.n_epoch))
            pbar.set_postfix(Loss=last_loss, CELoss=last_ce_loss,
                             SELoss1=last_se_loss1, SELoss2=last_se_loss2,
                             TopK=topk_base, LR=batch_lr)

            train_loss.backward()
            optimizer.step()

            if full_iter % lr_period == 0:
                swa_weights = update_aggregated_weight_average(model, swa_weights, full_iter, lr_period)
                state = {'model_state': swa_weights}
                torch.save(state, "{}{}_mobilenetv2context_swa_model.pkl".format(weight_dir, args.dataset))

            if (train_i + 1) % 31 == 0:
                loss_log = "Epoch [%d/%d], Iter: %d, Loss: \t %.4f, CELoss: \t %.4f, " \
                           "SELoss1: \t %.4f, SELoss2: \t%.4f, " % (epoch + 1, args.n_epoch,
                                                                    train_i + 1, last_loss,
                                                                    last_ce_loss,
                                                                    last_se_loss1,
                                                                    last_se_loss2)

                net_out = F.softmax(net_out, dim=1)
                pred = net_out.data.max(1)[1].cpu().numpy()
                gt = ce_labels.data.cpu().numpy()

                running_metrics.update(gt, pred)
                score, class_iou = running_metrics.get_scores()

                metric_log = ""
                for k, v in score.items():
                    metric_log += " {}: \t %.4f, ".format(k) % v
                running_metrics.reset()

                logs = loss_log + metric_log
                # print(logs)

                if args.visdom:
                    writer.add_scalar('Training/Losses', last_loss, full_iter)
                    writer.add_scalars('Training/Metrics', score, full_iter)
                    writer.add_text('Training/Text', logs, full_iter)

                    for name, param in model.named_parameters():
                        writer.add_histogram(name, param.clone().cpu().data.numpy(), full_iter)

        # +++++++++++++++++++++++++++++++++++++++++++++++++++ #
        # 6.2 Mini-Batch Validation
        # +++++++++++++++++++++++++++++++++++++++++++++++++++ #
        # print("> Validation for Epoch [%d/%d]:" % (epoch + 1, args.n_epoch))
        model.eval()

        mval_loss = 0.0
        vali_count = 0
        for i_val, (images, labels) in enumerate(valid_loader):
            vali_count += 1

            images = Variable(images.cuda(), volatile=True)
            ce_labels = Variable(labels.cuda(), requires_grad=False)

            enc1, enc2, net_out = model(images)  # three outputs, as in training

            topk = topk_base * 512
            val_loss = ce_loss(input=net_out, target=ce_labels, K=topk,
                               weight=None, size_average=False)

            mval_loss += val_loss.data[0]

            net_out = F.softmax(net_out, dim=1)
            pred = net_out.data.max(1)[1].cpu().numpy()
            gt = ce_labels.data.cpu().numpy()
            running_metrics.update(gt, pred)

        mval_loss /= vali_count

        loss_log = "Epoch [%d/%d] Loss: \t %.4f" % (epoch + 1, args.n_epoch, mval_loss)
        metric_log = ""
        score, class_iou = running_metrics.get_scores()
        for k, v in score.items():
            metric_log += " {} \t %.4f, ".format(k) % v
        running_metrics.reset()

        logs = loss_log + metric_log
        # print(logs)
        pbar.set_postfix(Train_Loss=last_loss, Vali_Loss=mval_loss, Vali_mIoU=score['Mean_IoU'])

        if args.visdom:
            writer.add_scalar('Validation/Losses', mval_loss, epoch)
            writer.add_scalars('Validation/Metrics', score, epoch)
            writer.add_text('Validation/Text', logs, epoch)

            for name, param in model.named_parameters():
                writer.add_histogram(name, param.clone().cpu().data.numpy(), epoch)

        if score['Mean_IoU'] >= best_iou:
            best_iou = score['Mean_IoU']
            state = {'epoch': epoch + 1,
                     "best_iou": best_iou,
                     'model_state': model.state_dict(),
                     'optimizer_state': optimizer.state_dict()}
            torch.save(state, "{}{}_mobilenetv2context_best_model.pkl".format(weight_dir, args.dataset))

        # scheduler.step()
        # scheduler.batch_step()
        pbar.close()

    if args.visdom:
        # export scalar data to JSON for external processing
        # writer.export_scalars_to_json("{}/all_scalars.json".format(log_dir))
        writer.close()
    print("> # +++++++++++++++++++++++++++++++++++++++++++++++++++++++ #")
    print("> Training Done!!!")
    print("> # +++++++++++++++++++++++++++++++++++++++++++++++++++++++ #")