Example No. 1
def train(args, data_root, save_root):
    weight_dir = "{}weights/".format(save_root)
    log_dir = "{}logs/RFMobileNetV2Plus-{}".format(save_root, time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime()))

    # +++++++++++++++++++++++++++++++++++++++++++++++++++ #
    # 1. Setup Augmentations
    # +++++++++++++++++++++++++++++++++++++++++++++++++++ #
    net_h, net_w = int(args.img_rows*args.crop_ratio), int(args.img_cols*args.crop_ratio)

    augment_train = Compose([RandomHorizontallyFlip(), RandomSized((0.5, 0.75)),
                             RandomRotate(5), RandomCrop((net_h, net_w))])
    augment_valid = Compose([RandomHorizontallyFlip(), Scale((args.img_rows, args.img_cols)),
                             CenterCrop((net_h, net_w))])

    print("> # +++++++++++++++++++++++++++++++++++++++++++++++++++++++ #")
    print("> 0. Setting up DataLoader...")
    print("> # +++++++++++++++++++++++++++++++++++++++++++++++++++++++ #")
    train_loader = CityscapesLoader(data_root, gt="gtFine", is_transform=True, split='train',
                                    img_size=(args.img_rows, args.img_cols),
                                    augmentations=augment_train)
    valid_loader = CityscapesLoader(data_root, gt="gtFine", is_transform=True, split='val',
                                    img_size=(args.img_rows, args.img_cols),
                                    augmentations=augment_valid)

    n_classes = train_loader.n_classes

    # +++++++++++++++++++++++++++++++++++++++++++++++++++ #
    # 2. Setup Metrics
    # +++++++++++++++++++++++++++++++++++++++++++++++++++ #
    running_metrics = RunningScore(n_classes)

    # +++++++++++++++++++++++++++++++++++++++++++++++++++ #
    # 4. Setup Model
    # +++++++++++++++++++++++++++++++++++++++++++++++++++ #
    print("> # +++++++++++++++++++++++++++++++++++++++++++++++++++++++ #")
    print("> 1. Setting up Model...")

    model = RFMobileNetV2Plus(n_class=n_classes, in_size=(net_h, net_w), width_mult=1.0,
                              out_sec=256, aspp_sec=(12, 24, 36),
                              norm_act=partial(InPlaceABNWrapper, activation="leaky_relu", slope=0.1))
    """

    model = MobileNetV2Plus(n_class=n_classes, in_size=(net_h, net_w), width_mult=1.0,
                            out_sec=256, aspp_sec=(12, 24, 36),
                            norm_act=partial(InPlaceABNWrapper, activation="leaky_relu", slope=0.1))
    """
    # np.arange(torch.cuda.device_count())
    model = torch.nn.DataParallel(model, device_ids=[0, 1]).cuda()

    # 4.1 Setup Optimizer
    # +++++++++++++++++++++++++++++++++++++++++++++++++++ #
    # Check if model has custom optimizer / loss
    if hasattr(model.module, 'optimizer'):
        optimizer = model.module.optimizer
    else:
        optimizer = torch.optim.SGD(model.parameters(), lr=args.l_rate, momentum=0.90,
                                    weight_decay=5e-4, nesterov=True)

        # for pg in optimizer.param_groups:
        #     print(pg['lr'])

        # optimizer = torch.optim.Adam(model.parameters(), lr=0.001, betas=(0.9, 0.999),
        #                             eps=1e-08, weight_decay=0, amsgrad=True)
        # optimizer = YFOptimizer(model.parameters(), lr=2.5e-3, mu=0.9, clip_thresh=10000, weight_decay=5e-4)

    # 4.2 Setup Loss
    # +++++++++++++++++++++++++++++++++++++++++++++++++++ #
    class_weight = None
    if hasattr(model.module, 'loss'):
        print('> Using custom loss')
        loss_fn = model.module.loss
    else:
        loss_fn = lovasz_softmax

    # +++++++++++++++++++++++++++++++++++++++++++++++++++ #
    # 5. Resume Model
    # +++++++++++++++++++++++++++++++++++++++++++++++++++ #
    best_iou = -100.0
    args.start_epoch = 0
    if args.resume is not None:
        full_path = "{}{}".format(weight_dir, args.resume)
        if os.path.isfile(full_path):
            print("> Loading model and optimizer from checkpoint '{}'".format(args.resume))

            checkpoint = torch.load(full_path)
            args.start_epoch = checkpoint['epoch']
            best_iou = checkpoint['best_iou']
            model.load_state_dict(checkpoint['model_state'])          # weights
            optimizer.load_state_dict(checkpoint['optimizer_state'])  # gradient state

            # for param_group in optimizer.param_groups:
            #     param_group['lr'] = 1e-5

            del checkpoint
            print("> Loaded checkpoint '{}' (epoch {}, iou {})".format(args.resume,
                                                                       args.start_epoch,
                                                                       best_iou))

        else:
            print("> No checkpoint found at '{}'".format(args.resume))
    else:
        if args.pre_trained is not None:
            print("> Loading weights from pre-trained model '{}'".format(args.pre_trained))
            full_path = "{}{}".format(weight_dir, args.pre_trained)

            pre_weight = torch.load(full_path)
            pre_weight = pre_weight["model_state"]
            # pre_weight = pre_weight["state_dict"]

            model_dict = model.state_dict()

            pretrained_dict = {k: v for k, v in pre_weight.items() if k in model_dict}
            model_dict.update(pretrained_dict)
            model.load_state_dict(model_dict)

            del pre_weight
            del model_dict
            del pretrained_dict

    # +++++++++++++++++++++++++++++++++++++++++++++++++++ #
    # 3. Setup tensor_board for visualization
    # +++++++++++++++++++++++++++++++++++++++++++++++++++ #
    writer = None
    if args.tensor_board:
        writer = SummaryWriter(log_dir=log_dir, comment="RFMobileNetV2Plus")
        dummy_input = Variable(torch.rand(1, 3, net_h, net_w).cuda(), requires_grad=True)
        writer.add_graph(model, dummy_input)

    # +++++++++++++++++++++++++++++++++++++++++++++++++++ #
    # 6. Train Model
    # +++++++++++++++++++++++++++++++++++++++++++++++++++ #
    print("> 2. Model Training start...")
    train_loader = data.DataLoader(train_loader, batch_size=args.batch_size, num_workers=6, shuffle=True)
    valid_loader = data.DataLoader(valid_loader, batch_size=args.batch_size, num_workers=6)

    num_batches = int(math.ceil(len(train_loader.dataset.files[train_loader.dataset.split]) /
                                float(train_loader.batch_size)))

    lr_period = 20 * num_batches
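    # One cosine-annealing / SWA cycle spans 20 epochs' worth of iterations;
    # swa_weights below keeps a running weight-average snapshot, refreshed once per cycle.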
    swa_weights = model.state_dict()

    # scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.90)
    # scheduler = CyclicLR(optimizer, base_lr=1.0e-3, max_lr=6.0e-3, step_size=2*num_batches)
    # scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=32, gamma=0.1)
    # scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[30, 80], gamma=0.1)

    for epoch in np.arange(args.start_epoch, args.n_epoch):
        # +++++++++++++++++++++++++++++++++++++++++++++++++++ #
        # 7.1 Mini-Batch Learning
        # +++++++++++++++++++++++++++++++++++++++++++++++++++ #
        # print("> Training Epoch [%d/%d]:" % (epoch + 1, args.n_epoch))
        model.train()

        last_loss = 0.0
        pbar = tqdm(np.arange(num_batches))
        for train_i, (images, labels) in enumerate(train_loader):  # one mini-batch per iteration
            full_iter = (epoch * num_batches) + train_i + 1

            # poly_lr_scheduler(optimizer, init_lr=args.l_rate, iter=full_iter,
            #                   lr_decay_iter=1, max_iter=args.n_epoch*num_batches, power=0.9)

            batch_lr = args.l_rate * cosine_annealing_lr(lr_period, full_iter)
            optimizer = set_optimizer_lr(optimizer, batch_lr)

            images = Variable(images.cuda(), requires_grad=True)   # images fed to the network
            labels = Variable(labels.cuda(), requires_grad=False)

            optimizer.zero_grad()
            net_out = model(images)  # single segmentation output
            net_out = F.softmax(net_out, dim=1)
            loss = lovasz_softmax(net_out, labels, ignore=250)

            last_loss = loss.data[0]
            pbar.update(1)
            pbar.set_description("> Epoch [%d/%d]" % (epoch + 1, args.n_epoch))
            pbar.set_postfix(Loss=last_loss, LR=batch_lr)

            loss.backward()
            optimizer.step()

            if full_iter % lr_period == 0:
                swa_weights = update_aggregated_weight_average(model, swa_weights, full_iter, lr_period)
                state = {'model_state': swa_weights}
                torch.save(state, "{}{}_rfmobilenetv2_swa_model.pkl".format(weight_dir, args.dataset))

            if (train_i + 1) % 31 == 0:
                loss_log = "Epoch [%d/%d], Iter: %d Loss: \t %.4f" % (epoch + 1, args.n_epoch,
                                                                      train_i + 1, last_loss)

                # net_out = F.softmax(net_out, dim=1)
                pred = net_out.data.max(1)[1].cpu().numpy()
                gt = labels.data.cpu().numpy()

                running_metrics.update(gt, pred)
                score, class_iou = running_metrics.get_scores()

                metric_log = ""
                for k, v in score.items():
                    metric_log += " {}: \t %.4f, ".format(k) % v
                running_metrics.reset()

                logs = loss_log + metric_log
                # print(logs)

                if args.tensor_board:
                    writer.add_scalar('Training/Losses', last_loss, full_iter)
                    writer.add_scalars('Training/Metrics', score, full_iter)
                    writer.add_text('Training/Text', logs, full_iter)

                    for name, param in model.named_parameters():
                        writer.add_histogram(name, param.clone().cpu().data.numpy(), full_iter)

        # +++++++++++++++++++++++++++++++++++++++++++++++++++ #
        # 7.2 Mini-Batch Validation
        # +++++++++++++++++++++++++++++++++++++++++++++++++++ #
        # print("> Validation for Epoch [%d/%d]:" % (epoch + 1, args.n_epoch))
        model.eval()

        mval_loss = 0.0
        vali_count = 0
        for i_val, (images, labels) in enumerate(valid_loader):
            vali_count += 1

            images = Variable(images.cuda(), volatile=True)
            labels = Variable(labels.cuda(), volatile=True)

            net_out = model(images)  # single segmentation output
            net_out = F.softmax(net_out, dim=1)
            loss = lovasz_softmax(net_out, labels, ignore=250)
            mval_loss += loss.data[0]

            pred = net_out.data.max(1)[1].cpu().numpy()
            gt = labels.data.cpu().numpy()
            running_metrics.update(gt, pred)

        mval_loss /= vali_count

        loss_log = "Epoch [%d/%d] Loss: \t %.4f" % (epoch + 1, args.n_epoch, mval_loss)
        metric_log = ""
        score, class_iou = running_metrics.get_scores()
        for k, v in score.items():
            metric_log += " {} \t %.4f, ".format(k) % v
        running_metrics.reset()

        logs = loss_log + metric_log
        # print(logs)
        pbar.set_postfix(Train_Loss=last_loss, Vali_Loss=mval_loss, Vali_mIoU=score['Mean_IoU'])

        if args.tensor_board:
            writer.add_scalar('Validation/Losses', mval_loss, epoch)
            writer.add_scalars('Validation/Metrics', score, epoch)
            writer.add_text('Validation/Text', logs, epoch)

            for name, param in model.named_parameters():
                writer.add_histogram(name, param.clone().cpu().data.numpy(), epoch)

        if score['Mean_IoU'] >= best_iou:
            best_iou = score['Mean_IoU']
            state = {'epoch': epoch + 1,
                     "best_iou": best_iou,
                     'model_state': model.state_dict(),
                     'optimizer_state': optimizer.state_dict()}
            torch.save(state, "{}{}_rfmobilenetv2_lovasz_best_model.pkl".format(weight_dir, args.dataset))

        # scheduler.step()
        # scheduler.batch_step()
        pbar.close()

    if args.tensor_board:
        # export scalar data to JSON for external processing
        # writer.export_scalars_to_json("{}/all_scalars.json".format(log_dir))
        writer.close()
    print("> # +++++++++++++++++++++++++++++++++++++++++++++++++++++++ #")
    print("> Training Done!!!")
    print("> # +++++++++++++++++++++++++++++++++++++++++++++++++++++++ #")
Example No. 2
def train(args, data_root, save_root):
    weight_dir = "{}weights/".format(save_root)
    log_dir = "{}logs/MobileNetV2Context-{}".format(save_root, time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime()))

    # +++++++++++++++++++++++++++++++++++++++++++++++++++ #
    # 1. Setup Augmentations
    # +++++++++++++++++++++++++++++++++++++++++++++++++++ #
    net_h, net_w = int(args.img_rows*args.crop_ratio), int(args.img_cols*args.crop_ratio)

    augment_train = Compose([RandomHorizontallyFlip(), RandomSized((0.5, 0.75)),
                             RandomRotate(5), RandomCrop((net_h, net_w))])
    augment_valid = Compose([RandomHorizontallyFlip(), Scale((args.img_rows, args.img_cols)),
                             CenterCrop((net_h, net_w))])

    print("> # +++++++++++++++++++++++++++++++++++++++++++++++++++++++ #")
    print("> 0. Setting up DataLoader...")
    print("> # +++++++++++++++++++++++++++++++++++++++++++++++++++++++ #")
    train_loader = CityscapesLoader(data_root, gt="gtFine", is_transform=True, split='train',
                                    img_size=(args.img_rows, args.img_cols),
                                    augmentations=augment_train)
    valid_loader = CityscapesLoader(data_root, gt="gtFine", is_transform=True, split='val',
                                    img_size=(args.img_rows, args.img_cols),
                                    augmentations=augment_valid)

    n_classes = train_loader.n_classes

    # +++++++++++++++++++++++++++++++++++++++++++++++++++ #
    # 2. Setup Metrics
    # +++++++++++++++++++++++++++++++++++++++++++++++++++ #
    running_metrics = RunningScore(n_classes)

    # +++++++++++++++++++++++++++++++++++++++++++++++++++ #
    # 4. Setup Model
    # +++++++++++++++++++++++++++++++++++++++++++++++++++ #
    print("> # +++++++++++++++++++++++++++++++++++++++++++++++++++++++ #")
    print("> 1. Setting up Model...")

    model = MobileNetV2Context(n_class=n_classes, in_size=(net_h, net_w), width_mult=1.0, out_sec=256, context=(32, 4),
                               norm_act=partial(InPlaceABNWrapper, activation="leaky_relu", slope=0.1))

    # np.arange(torch.cuda.device_count())
    model = torch.nn.DataParallel(model, device_ids=[0]).cuda()

    # 4.1 Setup Optimizer
    # +++++++++++++++++++++++++++++++++++++++++++++++++++ #
    # Check if model has custom optimizer / loss
    if hasattr(model.module, 'optimizer'):
        optimizer = model.module.optimizer
    else:
        optimizer = torch.optim.SGD(model.parameters(), lr=args.l_rate, momentum=0.90,
                                    weight_decay=5e-4, nesterov=True)

    # 4.2 Setup Loss
    # +++++++++++++++++++++++++++++++++++++++++++++++++++ #
    class_weight = None
    if hasattr(model.module, 'loss'):
        print('> Using custom loss')
        loss_fn = model.module.loss
    else:
        # loss_fn = cross_entropy2d

        class_weight = np.array([0.05570516, 0.32337477, 0.08998544, 1.03602707, 1.03413147, 1.68195437,
                                 5.58540548, 3.56563995, 0.12704978, 1.,         0.46783719, 1.34551528,
                                 5.29974114, 0.28342531, 0.9396095,  0.81551811, 0.42679146, 3.6399074,
                                 2.78376194], dtype=float)

        """
        class_weight = np.array([3.045384,  12.862123,   4.509889,  38.15694,  35.25279,  31.482613,
                                 45.792305,  39.694073,  6.0639296,  32.16484,  17.109228,   31.563286,
                                 47.333973,  11.610675,  44.60042,   45.23716,  45.283024,  48.14782,
                                 41.924667], dtype=float)/10.0
        """
        class_weight = torch.from_numpy(class_weight).float().cuda()
        se_loss = SemanticEncodingLoss(num_classes=19, ignore_label=250, alpha=0.20)
        ce_loss = bootstrapped_cross_entropy2d

    # +++++++++++++++++++++++++++++++++++++++++++++++++++ #
    # 5. Resume Model
    # +++++++++++++++++++++++++++++++++++++++++++++++++++ #
    best_iou = -100.0
    args.start_epoch = 0
    if args.resume is not None:
        full_path = "{}{}".format(weight_dir, args.resume)
        if os.path.isfile(full_path):
            print("> Loading model and optimizer from checkpoint '{}'".format(args.resume))

            checkpoint = torch.load(full_path)
            args.start_epoch = checkpoint['epoch']
            best_iou = checkpoint['best_iou']
            model.load_state_dict(checkpoint['model_state'])          # weights
            optimizer.load_state_dict(checkpoint['optimizer_state'])  # gradient state

            # for param_group in optimizer.param_groups:
            #     param_group['lr'] = 1e-5

            del checkpoint
            print("> Loaded checkpoint '{}' (epoch {}, iou {})".format(args.resume,
                                                                       args.start_epoch,
                                                                       best_iou))

        else:
            print("> No checkpoint found at '{}'".format(args.resume))
    else:
        if args.pre_trained is not None:
            print("> Loading weights from pre-trained model '{}'".format(args.pre_trained))
            full_path = "{}{}".format(weight_dir, args.pre_trained)

            pre_weight = torch.load(full_path)
            pre_weight = pre_weight["model_state"]
            # pre_weight = pre_weight["state_dict"]

            model_dict = model.state_dict()

            pretrained_dict = {k: v for k, v in pre_weight.items() if k in model_dict}
            model_dict.update(pretrained_dict)
            model.load_state_dict(model_dict)

            del pre_weight
            del model_dict
            del pretrained_dict

    # +++++++++++++++++++++++++++++++++++++++++++++++++++ #
    # 3. Setup tensor_board for visualization (gated by args.visdom)
    # +++++++++++++++++++++++++++++++++++++++++++++++++++ #
    writer = None
    if args.visdom:
        writer = SummaryWriter(log_dir=log_dir, comment="MobileNetV2Context")
        dummy_input = Variable(torch.rand(1, 3, net_h, net_w).cuda(), requires_grad=True)
        writer.add_graph(model, dummy_input)

    # +++++++++++++++++++++++++++++++++++++++++++++++++++ #
    # 6. Train Model
    # +++++++++++++++++++++++++++++++++++++++++++++++++++ #
    print("> 2. Model Training start...")
    train_loader = data.DataLoader(train_loader, batch_size=args.batch_size, num_workers=6, shuffle=True)
    valid_loader = data.DataLoader(valid_loader, batch_size=args.batch_size, num_workers=6)

    num_batches = int(math.ceil(len(train_loader.dataset.files[train_loader.dataset.split]) /
                                float(train_loader.batch_size)))

    lr_period = 20 * num_batches
    swa_weights = model.state_dict()

    # scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.90)
    # scheduler = CyclicLR(optimizer, base_lr=1.0e-3, max_lr=6.0e-3, step_size=2*num_batches)
    # scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=32, gamma=0.1)
    # scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[30, 80], gamma=0.1)

    topk_init = 512
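    # topk_init is shrunk polynomially each iteration by poly_topk_scheduler();
    # the pixel budget handed to the bootstrapped CE loss is topk_base * 512.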
    # topk_multipliers = [64, 128, 256, 512]
    for epoch in np.arange(args.start_epoch, args.n_epoch):
        # +++++++++++++++++++++++++++++++++++++++++++++++++++ #
        # 7.1 Mini-Batch Learning
        # +++++++++++++++++++++++++++++++++++++++++++++++++++ #
        # print("> Training Epoch [%d/%d]:" % (epoch + 1, args.n_epoch))
        model.train()

        last_loss = 0.0
        topk_base = topk_init
        pbar = tqdm(np.arange(num_batches))
        for train_i, (images, labels) in enumerate(train_loader):  # one mini-batch per iteration
            full_iter = (epoch * num_batches) + train_i + 1

            # poly_lr_scheduler(optimizer, init_lr=args.l_rate, iter=full_iter,
            #                   lr_decay_iter=1, max_iter=args.n_epoch*num_batches, power=0.9)

            batch_lr = args.l_rate * cosine_annealing_lr(lr_period, full_iter)
            optimizer = set_optimizer_lr(optimizer, batch_lr)

            topk_base = poly_topk_scheduler(init_topk=topk_init, iter=full_iter, topk_decay_iter=1,
                                            max_iter=args.n_epoch*num_batches, power=0.95)

            images = Variable(images.cuda(), requires_grad=True)   # images fed to the network
            se_labels = se_loss.unique_encode(labels)
            se_labels = Variable(se_labels.cuda(), requires_grad=False)
            ce_labels = Variable(labels.cuda(), requires_grad=False)

            optimizer.zero_grad()
            enc1, enc2, net_out = model(images)  # two encoding outputs + the segmentation output, three losses below

            topk = topk_base * 512
            if random.random() < 0.20:
                train_ce_loss = ce_loss(input=net_out, target=ce_labels, K=topk,
                                        weight=class_weight, size_average=True)
                train_se_loss1 = se_loss(predicts=enc1, enc_cls_target=se_labels, size_average=True)

                train_se_loss2 = se_loss(predicts=enc2, enc_cls_target=se_labels, size_average=True)
            else:
                train_ce_loss = ce_loss(input=net_out, target=ce_labels, K=topk,
                                        weight=None, size_average=True)
                train_se_loss1 = se_loss(predicts=enc1, enc_cls_target=se_labels, size_average=True)

                train_se_loss2 = se_loss(predicts=enc2, enc_cls_target=se_labels, size_average=True)

            train_loss = train_ce_loss + train_se_loss1 + train_se_loss2

            last_loss = train_loss.data[0]
            last_ce_loss = train_ce_loss.data[0]
            last_se_loss1 = train_se_loss1.data[0]
            last_se_loss2 = train_se_loss2.data[0]
            pbar.update(1)
            pbar.set_description("> Epoch [%d/%d]" % (epoch + 1, args.n_epoch))
            pbar.set_postfix(Loss=last_loss, CELoss=last_ce_loss,
                             SELoss1=last_se_loss1, SELoss2=last_se_loss2,
                             TopK=topk_base, LR=batch_lr)

            train_loss.backward()
            optimizer.step()

            if full_iter % lr_period == 0:
                swa_weights = update_aggregated_weight_average(model, swa_weights, full_iter, lr_period)
                state = {'model_state': swa_weights}
                torch.save(state, "{}{}_mobilenetv2context_swa_model.pkl".format(weight_dir, args.dataset))

            if (train_i + 1) % 31 == 0:
                loss_log = "Epoch [%d/%d], Iter: %d, Loss: \t %.4f, CELoss: \t %.4f, " \
                           "SELoss1: \t %.4f, SELoss2: \t%.4f, " % (epoch + 1, args.n_epoch,
                                                                    train_i + 1, last_loss,
                                                                    last_ce_loss,
                                                                    last_se_loss1,
                                                                    last_se_loss2)

                net_out = F.softmax(net_out, dim=1)
                pred = net_out.data.max(1)[1].cpu().numpy()
                gt = ce_labels.data.cpu().numpy()

                running_metrics.update(gt, pred)
                score, class_iou = running_metrics.get_scores()

                metric_log = ""
                for k, v in score.items():
                    metric_log += " {}: \t %.4f, ".format(k) % v
                running_metrics.reset()

                logs = loss_log + metric_log
                # print(logs)

                if args.visdom:
                    writer.add_scalar('Training/Losses', last_loss, full_iter)
                    writer.add_scalars('Training/Metrics', score, full_iter)
                    writer.add_text('Training/Text', logs, full_iter)

                    for name, param in model.named_parameters():
                        writer.add_histogram(name, param.clone().cpu().data.numpy(), full_iter)

        # +++++++++++++++++++++++++++++++++++++++++++++++++++ #
        # 7.2 Mini-Batch Validation
        # +++++++++++++++++++++++++++++++++++++++++++++++++++ #
        # print("> Validation for Epoch [%d/%d]:" % (epoch + 1, args.n_epoch))
        model.eval()

        mval_loss = 0.0
        vali_count = 0
        for i_val, (images, labels) in enumerate(valid_loader):
            vali_count += 1

            images = Variable(images.cuda(), volatile=True)
            ce_labels = Variable(labels.cuda(), requires_grad=False)

            enc1, enc2, net_out = model(images)  # only the segmentation output is scored during validation

            topk = topk_base * 512
            val_loss = ce_loss(input=net_out, target=ce_labels, K=topk,
                               weight=None, size_average=False)

            mval_loss += val_loss.data[0]

            net_out = F.softmax(net_out, dim=1)
            pred = net_out.data.max(1)[1].cpu().numpy()
            gt = ce_labels.data.cpu().numpy()
            running_metrics.update(gt, pred)

        mval_loss /= vali_count

        loss_log = "Epoch [%d/%d] Loss: \t %.4f" % (epoch + 1, args.n_epoch, mval_loss)
        metric_log = ""
        score, class_iou = running_metrics.get_scores()
        for k, v in score.items():
            metric_log += " {} \t %.4f, ".format(k) % v
        running_metrics.reset()

        logs = loss_log + metric_log
        # print(logs)
        pbar.set_postfix(Train_Loss=last_loss, Vali_Loss=mval_loss, Vali_mIoU=score['Mean_IoU'])

        if args.visdom:
            writer.add_scalar('Validation/Losses', mval_loss, epoch)
            writer.add_scalars('Validation/Metrics', score, epoch)
            writer.add_text('Validation/Text', logs, epoch)

            for name, param in model.named_parameters():
                writer.add_histogram(name, param.clone().cpu().data.numpy(), epoch)

            # export scalar data to JSON for external processing
            # writer.export_scalars_to_json("{}/all_scalars.json".format(log_dir))

        if score['Mean_IoU'] >= best_iou:
            best_iou = score['Mean_IoU']
            state = {'epoch': epoch + 1,
                     "best_iou": best_iou,
                     'model_state': model.state_dict(),
                     'optimizer_state': optimizer.state_dict()}
            torch.save(state, "{}{}_mobilenetv2context_best_model.pkl".format(weight_dir, args.dataset))

        # scheduler.step()
        # scheduler.batch_step()
        pbar.close()

    if args.visdom:
        # export scalar data to JSON for external processing
        # writer.export_scalars_to_json("{}/all_scalars.json".format(log_dir))
        writer.close()
    print("> # +++++++++++++++++++++++++++++++++++++++++++++++++++++++ #")
    print("> Training Done!!!")
    print("> # +++++++++++++++++++++++++++++++++++++++++++++++++++++++ #")
Example No. 3
def train(args, data_root, save_root):
    weight_dir = "{}weights/".format(save_root)
    log_dir = "{}logs/MobileNetV2Plus-{}".format(
        save_root, time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime()))

    # +++++++++++++++++++++++++++++++++++++++++++++++++++ #
    # 1. Setup Augmentations
    # +++++++++++++++++++++++++++++++++++++++++++++++++++ #
    net_h, net_w = int(args.img_rows * args.crop_ratio), int(args.img_cols *
                                                             args.crop_ratio)

    augment_train = Compose([
        RandomHorizontallyFlip(),
        RandomSized((0.625, 0.75)),
        RandomRotate(6),
        RandomCrop((net_h, net_w))
    ])
    augment_valid = Compose([
        RandomHorizontallyFlip(),
        RandomSized((0.625, 0.75)),
        CenterCrop((net_h, net_w))
    ])

    print("> # +++++++++++++++++++++++++++++++++++++++++++++++++++++++ #")
    print("> 0. Setting up DataLoader...")
    print("> # +++++++++++++++++++++++++++++++++++++++++++++++++++++++ #")
    train_loader = MapillaryVistasLoader(data_root,
                                         split="training",
                                         is_transform=True,
                                         img_size=(args.img_rows,
                                                   args.img_cols),
                                         augmentations=augment_train)
    valid_loader = MapillaryVistasLoader(data_root,
                                         split="validation",
                                         is_transform=True,
                                         img_size=(args.img_rows,
                                                   args.img_cols),
                                         augmentations=augment_valid)

    n_classes = train_loader.n_classes

    # +++++++++++++++++++++++++++++++++++++++++++++++++++ #
    # 2. Setup Metrics
    # +++++++++++++++++++++++++++++++++++++++++++++++++++ #
    running_metrics = RunningScore(n_classes)

    # +++++++++++++++++++++++++++++++++++++++++++++++++++ #
    # 4. Setup Model
    # +++++++++++++++++++++++++++++++++++++++++++++++++++ #
    print("> # +++++++++++++++++++++++++++++++++++++++++++++++++++++++ #")
    print("> 1. Setting up Model...")

    model = MobileNetV2Plus(n_class=n_classes,
                            in_size=(net_h, net_w),
                            width_mult=1.0,
                            out_sec=(252, 86),
                            aspp_sec=(12, 24, 36))

    # np.arange(torch.cuda.device_count())
    model = torch.nn.DataParallel(model, device_ids=[0, 1]).cuda()

    # 4.1 Setup Optimizer
    # +++++++++++++++++++++++++++++++++++++++++++++++++++ #
    # Check if model has custom optimizer / loss
    if hasattr(model.module, 'optimizer'):
        optimizer = model.module.optimizer
    else:
        optimizer = torch.optim.SGD(model.parameters(),
                                    lr=args.l_rate,
                                    momentum=0.90,
                                    weight_decay=5e-4,
                                    nesterov=True)

        # for pg in optimizer.param_groups:
        #     print(pg['lr'])

        # optimizer = torch.optim.Adam(model.parameters(), lr=0.001, betas=(0.9, 0.999),
        #                             eps=1e-08, weight_decay=0, amsgrad=True)
        # optimizer = YFOptimizer(model.parameters(), lr=2.5e-3, mu=0.9, clip_thresh=10000, weight_decay=5e-4)

    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.90)
    # scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)
    # scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[30, 80], gamma=0.1)

    # 4.2 Setup Loss
    # +++++++++++++++++++++++++++++++++++++++++++++++++++ #
    if hasattr(model.module, 'loss'):
        print('> Using custom loss')
        loss_fn = model.module.loss
    else:
        # loss_fn = cross_entropy2d
        loss_fn = bootstrapped_cross_entropy2d

    # +++++++++++++++++++++++++++++++++++++++++++++++++++ #
    # 5. Resume Model
    # +++++++++++++++++++++++++++++++++++++++++++++++++++ #
    args.start_epoch = 0
    best_iou = -100.0
    if args.resume is not None:
        full_path = "{}{}".format(weight_dir, args.resume)
        if os.path.isfile(full_path):
            print("> Loading model and optimizer from checkpoint '{}'".format(
                args.resume))

            checkpoint = torch.load(full_path)
            args.start_epoch = checkpoint['epoch']
            best_iou = checkpoint['best_iou']
            model.load_state_dict(checkpoint['model_state'])  # weights

            optimizer.load_state_dict(
                checkpoint['optimizer_state'])  # gradient state

            # optimizer = YFOptimizer(model.parameters(), lr=2.5e-3, mu=0.9, clip_thresh=10000, weight_decay=5e-4)
            del checkpoint

            print("> Loaded checkpoint '{}' (epoch {})".format(
                args.resume, args.start_epoch))

        else:
            print("> No checkpoint found at '{}'".format(args.resume))
    else:
        if args.pre_trained is not None:
            print("> Loading weights from pre-trained model '{}'".format(
                args.pre_trained))
            full_path = "{}{}".format(weight_dir, args.pre_trained)

            pre_weight = torch.load(full_path)
            pre_weight = pre_weight["model_state"]
            # pre_weight = pre_weight["state_dict"]

            model_dict = model.state_dict()

            pretrained_dict = {
                k: v
                for k, v in pre_weight.items() if k in model_dict
            }
            model_dict.update(pretrained_dict)
            model.load_state_dict(model_dict)

            del pre_weight
            del model_dict
            del pretrained_dict

    # +++++++++++++++++++++++++++++++++++++++++++++++++++ #
    # 3. Setup tensor_board for visualization
    # +++++++++++++++++++++++++++++++++++++++++++++++++++ #
    writer = None
    if args.tensor_board:
        writer = SummaryWriter(log_dir=log_dir, comment="MobileNetV2")
        dummy_input = Variable(torch.rand(1, 3, net_h, net_w).cuda(),
                               requires_grad=True)
        writer.add_graph(model, dummy_input)

    # +++++++++++++++++++++++++++++++++++++++++++++++++++ #
    # 6. Train Model
    # +++++++++++++++++++++++++++++++++++++++++++++++++++ #
    print("> 2. Model Training start...")
    train_loader = data.DataLoader(train_loader,
                                   batch_size=args.batch_size,
                                   num_workers=6,
                                   shuffle=True)
    valid_loader = data.DataLoader(valid_loader,
                                   batch_size=args.batch_size,
                                   num_workers=6)

    num_batches = int(
        math.ceil(
            len(train_loader.dataset.files[train_loader.dataset.split]) /
            float(train_loader.batch_size)))

    loss_wgt1 = 1.0
    loss_wgt2 = 1.0
    loss_wgt3 = 1.0

    topk_base = 512
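    # Fixed bootstrapping budget: the loss keeps the K = topk_base * 256 hardest pixels per prediction.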
    # topk_multipliers = [64, 128, 256, 512]
    for epoch in np.arange(args.start_epoch, args.n_epoch):
        pbar = tqdm(np.arange(num_batches))
        last_loss = [0.0, 0.0, 0.0]

        # +++++++++++++++++++++++++++++++++++++++++++++++++++ #
        # 7.1 Mini-Batch Learning
        # +++++++++++++++++++++++++++++++++++++++++++++++++++ #
        # print("> Training Epoch [%d/%d]:" % (epoch + 1, args.n_epoch))
        model.train()

        for train_i, (images, labels) in enumerate(
                train_loader):  # one mini-batch per iteration
            full_iter = (epoch * num_batches) + train_i + 1

            # poly_lr_scheduler(optimizer, init_lr=args.l_rate, iter=full_iter,
            #                   lr_decay_iter=1, max_iter=args.n_epoch*num_batches, power=0.9)

            pbar.update(1)
            pbar.set_description("> Epoch [%d/%d]" % (epoch + 1, args.n_epoch))

            images = Variable(
                images.cuda(),
                requires_grad=True)  # images fed to the network
            labels = Variable(labels.cuda(), requires_grad=False)

            optimizer.zero_grad()
            out_stg1, out_stg2, out_stg3 = model(
                images)  # three stage outputs, one loss per stage

            topk = topk_base * 256
            stg1_loss = loss_wgt1 * loss_fn(
                input=out_stg1, target=labels, K=topk)
            stg2_loss = loss_wgt2 * loss_fn(
                input=out_stg2, target=labels, K=topk)
            stg3_loss = loss_wgt3 * loss_fn(
                input=out_stg3, target=labels, K=topk)

            last_loss = [
                stg1_loss.data[0], stg2_loss.data[0], stg3_loss.data[0]
            ]
            loss = [stg1_loss, stg2_loss, stg3_loss]
            torch.autograd.backward(loss)
            optimizer.step()

            pbar.set_postfix(Loss1=stg1_loss.data[0],
                             Loss2=stg2_loss.data[0],
                             Loss3=stg3_loss.data[0])

            if (train_i + 1) % 62 == 0:

                loss_log = "Epoch [%d/%d], Iter: %d Loss1: \t %.4f, Loss2: \t %.4f, " \
                           "Loss3: \t %.4f" % (epoch + 1, args.n_epoch, train_i + 1,
                                               last_loss[0], last_loss[1], last_loss[2])

                out_stg3 = F.softmax(out_stg3, dim=1)
                pred = out_stg3.data.max(1)[1].cpu().numpy()
                gt = labels.data.cpu().numpy()

                running_metrics.update(gt, pred)

                score, class_iou = running_metrics.get_scores()

                metric_log = ""
                for k, v in score.items():
                    metric_log += " {}: \t %.4f, ".format(k) % v
                running_metrics.reset()

                logs = loss_log + metric_log
                # print(logs)

                if args.tensor_board:
                    writer.add_scalars(
                        'Training/Losses', {
                            'Loss_Stage1': last_loss[0],
                            'Loss_Stage2': last_loss[1],
                            'Loss_Stage3': last_loss[2]
                        }, full_iter)
                    writer.add_scalars('Training/Metrics', score, full_iter)

                    writer.add_text('Training/Text', logs, full_iter)

                    for name, param in model.named_parameters():
                        writer.add_histogram(name,
                                             param.clone().cpu().data.numpy(),
                                             full_iter)

        # +++++++++++++++++++++++++++++++++++++++++++++++++++ #
        # 7.2 Mini-Batch Validation
        # +++++++++++++++++++++++++++++++++++++++++++++++++++ #
        # print("> Validation for Epoch [%d/%d]:" % (epoch + 1, args.n_epoch))
        model.eval()
        val_loss = [0.0, 0.0, 0.0]
        vali_count = 0
        for i_val, (images_val, labels_val) in enumerate(valid_loader):
            vali_count += 1

            images_val = Variable(images_val.cuda(), volatile=True)
            labels_val = Variable(labels_val.cuda(), volatile=True)

            out_stg1, out_stg2, out_stg3 = model(
                images_val)  # three stage outputs, one loss per stage

            topk = topk_base * 256
            stg1_val_loss = loss_wgt1 * loss_fn(
                input=out_stg1, target=labels_val, K=topk)
            stg2_val_loss = loss_wgt2 * loss_fn(
                input=out_stg2, target=labels_val, K=topk)
            stg3_val_loss = loss_wgt3 * loss_fn(
                input=out_stg3, target=labels_val, K=topk)

            val_loss = [
                val_loss[0] + stg1_val_loss.data[0],
                val_loss[1] + stg2_val_loss.data[0],
                val_loss[2] + stg3_val_loss.data[0]
            ]

            out_stg3 = F.softmax(out_stg3, dim=1)
            pred = out_stg3.data.max(1)[1].cpu().numpy()
            gt = labels_val.data.cpu().numpy()
            running_metrics.update(gt, pred)

        val_loss = [
            val_loss[0] / vali_count, val_loss[1] / vali_count,
            val_loss[2] / vali_count
        ]

        loss_log = "Epoch [%d/%d] Loss1: \t %.4f, Loss2: \t %.4f, " \
                   "Loss3: \t %.4f" % (epoch + 1, args.n_epoch,
                                       val_loss[0], val_loss[1],  val_loss[2])
        metric_log = ""
        score, class_iou = running_metrics.get_scores()
        for k, v in score.items():
            metric_log += " {} \t %.4f, ".format(k) % v
        running_metrics.reset()

        logs = loss_log + metric_log
        # print(logs)
        pbar.set_postfix(Train_Loss=last_loss[1],
                         Vali_Loss=val_loss[1] / loss_wgt2,
                         Vali_mIoU=score['Mean_IoU'])

        if args.tensor_board:
            writer.add_scalars(
                'Validation/Losses', {
                    'Loss_Stage1': val_loss[0],
                    'Loss_Stage2': val_loss[1],
                    'Loss_Stage3': val_loss[2]
                }, epoch)
            writer.add_scalars('Validation/Metrics', score, epoch)

            writer.add_text('Validation/Text', logs, epoch)

            for name, param in model.named_parameters():
                writer.add_histogram(name,
                                     param.clone().cpu().data.numpy(), epoch)

            # export scalar data to JSON for external processing
            # writer.export_scalars_to_json("{}/all_scalars.json".format(log_dir))

        if score['Mean_IoU'] >= best_iou:
            best_iou = score['Mean_IoU']
            state = {
                'epoch': epoch + 1,
                "best_iou": best_iou,
                'model_state': model.state_dict(),
                'optimizer_state': optimizer.state_dict()
            }
            torch.save(
                state, "{}{}_mobilenetv2_gtfine_best_model.pkl".format(
                    weight_dir, args.dataset))

        # Note that step should be called after validate()
        scheduler.step()
        pbar.close()

    if args.tensor_board:
        # export scalar data to JSON for external processing
        # writer.export_scalars_to_json("{}/all_scalars.json".format(log_dir))
        writer.close()
    print("> # +++++++++++++++++++++++++++++++++++++++++++++++++++++++ #")
    print("> Training Done!!!")
    print("> # +++++++++++++++++++++++++++++++++++++++++++++++++++++++ #")
Example No. 4
def train(args, data_root, save_root):
    weight_dir = "{}weights/".format(save_root)
    log_dir = "{}logs/SE-DPShuffleNet-{}".format(save_root, time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime()))

    # +++++++++++++++++++++++++++++++++++++++++++++++++++ #
    # 1. Setup Augmentations
    # +++++++++++++++++++++++++++++++++++++++++++++++++++ #
    net_h, net_w = int(args.img_rows*args.crop_ratio), int(args.img_cols*args.crop_ratio)
    augment_train = Compose([RandomHorizontallyFlip(), RandomRotate(6), RandomCrop((net_h, net_w))])
    augment_valid = Compose([CenterCrop((net_h, net_w))])

    # +++++++++++++++++++++++++++++++++++++++++++++++++++ #
    # 2. Setup Dataloader
    # +++++++++++++++++++++++++++++++++++++++++++++++++++ #
    print("> # +++++++++++++++++++++++++++++++++++++++++++++++++++++++ #")
    print("> 0. Setting up DataLoader...")
    print("> # +++++++++++++++++++++++++++++++++++++++++++++++++++++++ #")
    train_loader = CityscapesLoader(data_root, is_transform=True, gt="gtCoarse", split='train_extra',
                                    img_size=(args.img_rows, args.img_cols),
                                    augmentations=augment_train)
    valid_loader = CityscapesLoader(data_root, is_transform=True, gt="gtCoarse", split='val',
                                    img_size=(args.img_rows, args.img_cols),
                                    augmentations=augment_valid)

    n_classes = train_loader.n_classes
    train_loader = data.DataLoader(train_loader, batch_size=args.batch_size, num_workers=8, shuffle=True)
    valid_loader = data.DataLoader(valid_loader, batch_size=args.batch_size, num_workers=8)

    # +++++++++++++++++++++++++++++++++++++++++++++++++++ #
    # 3. Setup Metrics
    # +++++++++++++++++++++++++++++++++++++++++++++++++++ #
    running_metrics = RunningScore(n_classes)

    # +++++++++++++++++++++++++++++++++++++++++++++++++++ #
    # 4. Setup tensor_board for visualization
    # +++++++++++++++++++++++++++++++++++++++++++++++++++ #
    writer = None
    if args.tensor_board:
        writer = SummaryWriter(log_dir=log_dir, comment="SE-DPShuffleNet")

    # +++++++++++++++++++++++++++++++++++++++++++++++++++ #
    # 5. Setup Model
    # +++++++++++++++++++++++++++++++++++++++++++++++++++ #
    print("> # +++++++++++++++++++++++++++++++++++++++++++++++++++++++ #")
    print("> 1. Setting up Model...")
    model = SEDPNShuffleNet(small=False, classes=n_classes, in_size=(net_h, net_w), num_init_features=64,
                            k_r=96, groups=4, k_sec=(3, 4, 20, 3), inc_sec=(16, 32, 24, 128),
                            out_sec=(512, 256, 128), dil_sec=(1, 1, 1, 2, 4), aspp_sec=(6, 12, 18),
                            norm_act=partial(InPlaceABNWrapper, activation="leaky_relu", slope=0.1))

    # np.arange(torch.cuda.device_count())
    model = torch.nn.DataParallel(model, device_ids=[0]).cuda()

    # if args.tensor_board:
    #     dummy_input = Variable(torch.rand(1, 3, net_h, net_w).cuda())
    #     writer.add_graph(model, dummy_input)

    # 5.1 Setup Optimizer
    # +++++++++++++++++++++++++++++++++++++++++++++++++++ #
    # Check if model has custom optimizer / loss
    if hasattr(model.module, 'optimizer'):
        optimizer = model.module.optimizer
    else:
        # optimizer = torch.optim.Adam(model.parameters(), lr=args.l_rate, weight_decay=5e-4)
        optimizer = torch.optim.SGD(model.parameters(), lr=args.l_rate, momentum=0.90, weight_decay=5e-4)

    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.1)

    # 5.2 Setup Loss
    # +++++++++++++++++++++++++++++++++++++++++++++++++++ #
    if hasattr(model.module, 'loss'):
        print('Using custom loss')
        loss_fn = model.module.loss
    else:
        loss_fn = bootstrapped_cross_entropy2d

    # +++++++++++++++++++++++++++++++++++++++++++++++++++ #
    # 6. Resume Model
    # +++++++++++++++++++++++++++++++++++++++++++++++++++ #
    if args.resume is not None:
        full_path = "{}{}".format(weight_dir, args.resume)
        if os.path.isfile(full_path):
            print("> Loading model and optimizer from checkpoint '{}'".format(args.resume))

            checkpoint = torch.load(full_path)
            model.load_state_dict(checkpoint['model_state'])          # weights
            optimizer.load_state_dict(checkpoint['optimizer_state'])  # gradient state

            print("Loaded checkpoint '{}' (epoch {})".format(args.resume, checkpoint['epoch']))
        else:
            print("No checkpoint found at '{}'".format(args.resume))
    else:
        init_weights(model)
        if args.pre_trained is not None:
            print("> Loading weights from pre-trained model '{}'".format(args.pre_trained))
            full_path = "{}{}".format(weight_dir, args.pre_trained)

            pre_weight = torch.load(full_path)
            pre_weight = pre_weight["model_state"]
            # pre_weight = pre_weight["state_dict"]

            model_dict = model.state_dict()
            pretrained_dict = {k: v for k, v in pre_weight.items() if k in model_dict}
            model_dict.update(pretrained_dict)
            model.load_state_dict(model_dict)

            del pre_weight
            del model_dict
            del pretrained_dict

    # +++++++++++++++++++++++++++++++++++++++++++++++++++ #
    # 7. Train Model
    # +++++++++++++++++++++++++++++++++++++++++++++++++++ #
    print("> 2. Model Training start...")
    num_batches = int(math.ceil(len(train_loader.dataset.files[train_loader.dataset.split]) /
                                float(train_loader.batch_size)))

    loss_wgt1 = 1.0
    loss_wgt2 = 1.0
    loss_wgt3 = 1.0
    loss_wgt4 = 1.0

    best_iou = -100.0
    for epoch in np.arange(args.n_epoch):
        pbar = tqdm(np.arange(num_batches))
        last_loss = [0.0, 0.0, 0.0, 0.0]
        # +++++++++++++++++++++++++++++++++++++++++++++++++++ #
        # 7.1 Mini-Batch Learning
        # +++++++++++++++++++++++++++++++++++++++++++++++++++ #
        # print("> Training Epoch [%d/%d]:" % (epoch + 1, args.n_epoch))
        model.train()

        for train_i, (images, labels) in enumerate(train_loader):  # one mini-batch per iteration
            pbar.update(1)
            pbar.set_description("> Epoch [%d/%d]" % (epoch + 1, args.n_epoch))

            images = Variable(images.cuda(), requires_grad=True)   # images fed to the network
            labels = Variable(labels.cuda(non_blocking=True), requires_grad=False)

            optimizer.zero_grad()
            out_stg1, out_stg2, out_stg3, out_stg4 = model(images)  # four stage outputs, one loss per stage

            stg1_loss = loss_wgt1 * loss_fn(input=out_stg1, target=labels, K=512*256)
            stg2_loss = loss_wgt2 * loss_fn(input=out_stg2, target=labels, K=512*256)
            stg3_loss = loss_wgt3 * loss_fn(input=out_stg3, target=labels, K=512*256)
            stg4_loss = loss_wgt4 * loss_fn(input=out_stg4, target=labels, K=512*256)

            last_loss = [stg1_loss.data[0], stg2_loss.data[0],
                         stg3_loss.data[0], stg4_loss.data[0]]

            loss = [stg1_loss, stg2_loss, stg3_loss, stg4_loss]
            torch.autograd.backward(loss)
            optimizer.step()

            pbar.set_postfix(Loss1=last_loss[0], Loss2=last_loss[1], Loss3=last_loss[2], Loss4=last_loss[3])

            if (train_i + 1) % 31 == 0:
                full_iter = (epoch*num_batches) + train_i + 1
                loss_log = "Epoch [%d/%d], Iter: %d Loss1: \t %.4f, Loss2: \t %.4f, " \
                           "Loss3: \t %.4f, Loss: \t %.4f," % (epoch + 1, args.n_epoch, train_i + 1,
                                                               last_loss[0], last_loss[1],
                                                               last_loss[2], last_loss[3])

                pred = out_stg4.data.max(1)[1].cpu().numpy()
                gt = labels.data.cpu().numpy()
                running_metrics.update(gt, pred)
                score, class_iou = running_metrics.get_scores()

                metric_log = ""
                for k, v in score.items():
                    metric_log += " {}: \t %.4f, ".format(k) % v
                running_metrics.reset()

                logs = loss_log + metric_log
                # print(logs)

                if args.tensor_board:
                    writer.add_scalars('Training/Losses',
                                       {'Loss_Stage1': last_loss[0], 'Loss_Stage2': last_loss[1],
                                        'Loss_Stage3': last_loss[2], 'Loss_Stage4': last_loss[3]},
                                       full_iter)
                    writer.add_scalars('Training/Metrics', score, full_iter)

                    writer.add_text('Training/Text', logs, full_iter)

                    for name, param in model.named_parameters():
                        writer.add_histogram(name, param.clone().cpu().data.numpy(), full_iter)

        # +++++++++++++++++++++++++++++++++++++++++++++++++++ #
        # 7.2 Mini-Batch Validation
        # +++++++++++++++++++++++++++++++++++++++++++++++++++ #
        # print("> Validation for Epoch [%d/%d]:" % (epoch + 1, args.n_epoch))
        model.eval()
        val_loss = [0.0, 0.0, 0.0, 0.0]

        vali_count = 0
        for i_val, (images_val, labels_val) in enumerate(valid_loader):
            vali_count += 1

            images_val = Variable(images_val.cuda(), volatile=True)
            labels_val = Variable(labels_val.cuda(), volatile=True)

            out_stg1, out_stg2, out_stg3, out_stg4 = model(images_val)  # four stage outputs, one loss per stage
            stg1_val_loss = loss_wgt1 * loss_fn(input=out_stg1, target=labels_val, K=512*256)
            stg2_val_loss = loss_wgt2 * loss_fn(input=out_stg2, target=labels_val, K=512*256)
            stg3_val_loss = loss_wgt3 * loss_fn(input=out_stg3, target=labels_val, K=512*256)
            stg4_val_loss = loss_wgt4 * loss_fn(input=out_stg4, target=labels_val, K=512*256)

            val_loss = [val_loss[0] + stg1_val_loss.data[0], val_loss[1] + stg2_val_loss.data[0],
                        val_loss[2] + stg3_val_loss.data[0], val_loss[3] + stg4_val_loss.data[0]]

            pred = out_stg4.data.max(1)[1].cpu().numpy()
            gt = labels_val.data.cpu().numpy()
            running_metrics.update(gt, pred)

        val_loss = [val_loss[0]/vali_count, val_loss[1]/vali_count,
                    val_loss[2]/vali_count, val_loss[3]/vali_count]

        loss_log = "Epoch [%d/%d] Loss1: \t %.4f, Loss2: \t %.4f, " \
                   "Loss3: \t %.4f, Loss: \t %.4f," % (epoch + 1, args.n_epoch,
                                                       val_loss[0], val_loss[1],
                                                       val_loss[2], val_loss[3])
        metric_log = ""
        score, class_iou = running_metrics.get_scores()
        for k, v in score.items():
            metric_log += " {} \t %.4f, ".format(k) % v
        running_metrics.reset()

        logs = loss_log + metric_log
        # print(logs)
        pbar.set_postfix(Train_Loss=last_loss[3], Vali_Loss=val_loss[3]/loss_wgt4, Vali_mIoU=score['Mean_IoU'])

        if args.tensor_board:
            writer.add_scalars('Validation/Losses',
                               {'Loss_Stage1': val_loss[0], 'Loss_Stage2': val_loss[1],
                                'Loss_Stage3': val_loss[2], 'Loss_Stage4': val_loss[3]}, epoch)
            writer.add_scalars('Validation/Metrics', score, epoch)

            writer.add_text('Validation/Text', logs, epoch)

            for name, param in model.named_parameters():
                writer.add_histogram(name, param.clone().cpu().data.numpy(), epoch)

            # export scalar data to JSON for external processing
            # writer.export_scalars_to_json("{}/all_scalars.json".format(log_dir))

        if score['Mean_IoU'] >= best_iou:
            best_iou = score['Mean_IoU']
            state = {'epoch': epoch + 1,
                     "best_iou": best_iou,
                     'model_state': model.state_dict(),
                     'optimizer_state': optimizer.state_dict()}
            torch.save(state, "{}{}_sedpshufflenet_best_model.pkl".format(weight_dir, args.dataset))

        # Note that step should be called after validate()
        scheduler.step()
        pbar.close()

    if args.tensor_board:
        # export scalar data to JSON for external processing
        # writer.export_scalars_to_json("{}/all_scalars.json".format(log_dir))
        writer.close()
    print("> # +++++++++++++++++++++++++++++++++++++++++++++++++++++++ #")
    print("> Training Done!!!")
    print("> # +++++++++++++++++++++++++++++++++++++++++++++++++++++++ #")