Example #1
def main():
    opts = get_argparser().parse_args()
    print(opts)

    # Setup random seed
    torch.manual_seed(opts.random_seed)
    np.random.seed(opts.random_seed)
    random.seed(opts.random_seed)
    """
    Training DeepLab by v2 protocol
    """
    # Configuration

    with open(opts.config_path) as f:
        CONFIG = Dict(yaml.safe_load(f))

    device = get_device(opts.cuda)
    torch.backends.cudnn.benchmark = True

    # Dataset
    train_dataset = get_dataset(CONFIG.DATASET.NAME)(
        root=CONFIG.DATASET.ROOT,
        split=CONFIG.DATASET.SPLIT.TRAIN,
        ignore_label=CONFIG.DATASET.IGNORE_LABEL,
        mean_bgr=(CONFIG.IMAGE.MEAN.B, CONFIG.IMAGE.MEAN.G,
                  CONFIG.IMAGE.MEAN.R),
        augment=True,
        base_size=CONFIG.IMAGE.SIZE.BASE,
        crop_size=CONFIG.IMAGE.SIZE.TRAIN,
        scales=CONFIG.DATASET.SCALES,
        flip=True,
        gt_path=opts.gt_path,
    )
    print(train_dataset)
    print()

    valid_dataset = get_dataset(CONFIG.DATASET.NAME)(
        root=CONFIG.DATASET.ROOT,
        split=CONFIG.DATASET.SPLIT.VAL,
        ignore_label=CONFIG.DATASET.IGNORE_LABEL,
        mean_bgr=(CONFIG.IMAGE.MEAN.B, CONFIG.IMAGE.MEAN.G,
                  CONFIG.IMAGE.MEAN.R),
        augment=False,
        gt_path="SegmentationClassAug",
    )
    print(valid_dataset)

    # DataLoader
    train_loader = torch.utils.data.DataLoader(
        dataset=train_dataset,
        batch_size=CONFIG.SOLVER.BATCH_SIZE.TRAIN,
        num_workers=CONFIG.DATALOADER.NUM_WORKERS,
        shuffle=True,
        pin_memory=True,
        drop_last=True,
    )
    valid_loader = torch.utils.data.DataLoader(
        dataset=valid_dataset,
        batch_size=CONFIG.SOLVER.BATCH_SIZE.TEST,
        num_workers=CONFIG.DATALOADER.NUM_WORKERS,
        shuffle=False,
        pin_memory=True,
    )

    # Model check
    print("Model:", CONFIG.MODEL.NAME)
    assert (CONFIG.MODEL.NAME == "DeepLabV2_ResNet101_MSC"
            ), 'Currently only "DeepLabV2_ResNet101_MSC" is supported'

    # Model setup
    model = eval(CONFIG.MODEL.NAME)(n_classes=CONFIG.DATASET.N_CLASSES)
    print("    Init:", CONFIG.MODEL.INIT_MODEL)
    state_dict = torch.load(CONFIG.MODEL.INIT_MODEL, map_location='cpu')

    for m in model.base.state_dict().keys():
        if m not in state_dict.keys():
            print("    Skip init:", m)

    model.base.load_state_dict(state_dict, strict=False)  # to skip ASPP
    model = nn.DataParallel(model)
    model.to(device)

    # Loss definition
    criterion = nn.CrossEntropyLoss(ignore_index=CONFIG.DATASET.IGNORE_LABEL)
    criterion.to(device)

    # Optimizer
    optimizer = torch.optim.SGD(
        # cf lr_mult and decay_mult in train.prototxt
        params=[
            {
                "params": get_params(model.module, key="1x"),
                "lr": CONFIG.SOLVER.LR,
                "weight_decay": CONFIG.SOLVER.WEIGHT_DECAY,
            },
            {
                "params": get_params(model.module, key="10x"),
                "lr": 10 * CONFIG.SOLVER.LR,
                "weight_decay": CONFIG.SOLVER.WEIGHT_DECAY,
            },
            {
                "params": get_params(model.module, key="20x"),
                "lr": 20 * CONFIG.SOLVER.LR,
                "weight_decay": 0.0,
            },
        ],
        momentum=CONFIG.SOLVER.MOMENTUM,
    )

    # Learning rate scheduler
    scheduler = PolynomialLR(
        optimizer=optimizer,
        step_size=CONFIG.SOLVER.LR_DECAY,
        iter_max=CONFIG.SOLVER.ITER_MAX,
        power=CONFIG.SOLVER.POLY_POWER,
    )

    # Path to save models
    checkpoint_dir = os.path.join(
        CONFIG.EXP.OUTPUT_DIR,
        "models",
        opts.log_dir,
        CONFIG.MODEL.NAME.lower(),
        CONFIG.DATASET.SPLIT.TRAIN,
    )
    makedirs(checkpoint_dir)
    print("Checkpoint dst:", checkpoint_dir)

    def set_train(model):
        model.train()
        model.module.base.freeze_bn()

    metrics = StreamSegMetrics(CONFIG.DATASET.N_CLASSES)

    scaler = torch.cuda.amp.GradScaler(enabled=opts.amp)
    avg_loss = AverageMeter()
    avg_time = AverageMeter()

    set_train(model)
    best_score = 0
    end_time = time.time()
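    # The training loop below follows the standard torch.cuda.amp recipe: the forward
    # pass and loss run under autocast, GradScaler scales the loss before backward so
    # small fp16 gradients are not flushed to zero, and scaler.step()/scaler.update()
    # replace a plain optimizer.step(). With opts.amp disabled, both wrappers are no-ops.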

    for iteration in range(1, CONFIG.SOLVER.ITER_MAX + 1):
        # Clear gradients (ready to accumulate)
        optimizer.zero_grad()

        loss = 0
        for _ in range(CONFIG.SOLVER.ITER_SIZE):
            try:
                _, images, labels, cls_labels = next(train_loader_iter)
            except (StopIteration, NameError):
                # Restart the loader when exhausted (or on the very first iteration)
                train_loader_iter = iter(train_loader)
                _, images, labels, cls_labels = next(train_loader_iter)
                avg_loss.reset()
                avg_time.reset()

            with torch.cuda.amp.autocast(enabled=opts.amp):
                # Propagate forward
                logits = model(images.to(device, non_blocking=True))

                # Loss
                iter_loss = 0
                for logit in logits:
                    # Resize labels for {100%, 75%, 50%, Max} logits
                    _, _, H, W = logit.shape
                    labels_ = resize_labels(labels, size=(H, W))

                    # Build pseudo labels: mask the detached logits with the
                    # image-level class labels, then take the per-pixel argmax
                    pseudo_labels = logit.detach() * cls_labels[:, :, None, None].to(device)
                    pseudo_labels = pseudo_labels.argmax(dim=1)

                    _loss = criterion(logit, labels_.to(device)) + criterion(
                        logit, pseudo_labels)

                    iter_loss += _loss

                # Propagate backward (just compute gradients wrt the loss)
                iter_loss /= CONFIG.SOLVER.ITER_SIZE

            scaler.scale(iter_loss).backward()
            loss += iter_loss.item()

        # Update weights with accumulated gradients
        scaler.step(optimizer)
        scaler.update()

        # Update learning rate
        scheduler.step(epoch=iteration)

        avg_loss.update(loss)
        avg_time.update(time.time() - end_time)
        end_time = time.time()

        # Logging
        if iteration % 100 == 0:
            print(" Itrs %d/%d, Loss=%.6f, Time=%.2f ms, LR=%.8f" %
                  (iteration, CONFIG.SOLVER.ITER_MAX, avg_loss.avg,
                   avg_time.avg * 1000, optimizer.param_groups[0]['lr']))

        # validation
        if iteration % opts.val_interval == 0:
            print("... validation")
            model.eval()
            metrics.reset()
            with torch.no_grad():
                for _, images, labels, _ in valid_loader:
                    images = images.to(device, non_blocking=True)

                    # Forward propagation
                    logits = model(images)

                    # Pixel-wise labeling
                    _, H, W = labels.shape
                    logits = F.interpolate(logits,
                                           size=(H, W),
                                           mode="bilinear",
                                           align_corners=False)
                    preds = torch.argmax(logits, dim=1).cpu().numpy()
                    targets = labels.cpu().numpy()
                    metrics.update(targets, preds)

            set_train(model)
            score = metrics.get_results()
            print(metrics.to_str(score))

            if score['Mean IoU'] > best_score:  # save best model
                best_score = score['Mean IoU']
                torch.save(model.module.state_dict(),
                           os.path.join(checkpoint_dir, "checkpoint_best.pth"))
            end_time = time.time()
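All four examples drive SGD with a PolynomialLR scheduler whose implementation is not shown here. A minimal sketch of the poly schedule it is assumed to implement (base learning rate annealed by (1 - iter/iter_max)**power):

def poly_lr(base_lr, iteration, iter_max, power):
    # Poly decay: learning rate for the given iteration
    return base_lr * (1.0 - float(iteration) / iter_max) ** power

# e.g. with base_lr=2.5e-4, iter_max=20000, power=0.9:
# poly_lr(2.5e-4, 10000, 20000, 0.9) ~= 1.34e-4 (about half-way through training)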
Example #2
def train(config, cuda):
    # Auto-tune cuDNN
    torch.backends.cudnn.benchmark = True

    # Configuration
    device = get_device(cuda)
    with open(config) as f:
        CONFIG = Dict(yaml.safe_load(f))

    # Dataset 10k or 164k
    dataset = get_dataset(CONFIG.DATASET.NAME)(
        root=CONFIG.DATASET.ROOT,
        split=CONFIG.DATASET.SPLIT.TRAIN,
        base_size=CONFIG.IMAGE.SIZE.TRAIN.BASE,
        crop_size=CONFIG.IMAGE.SIZE.TRAIN.CROP,
        mean=(CONFIG.IMAGE.MEAN.B, CONFIG.IMAGE.MEAN.G, CONFIG.IMAGE.MEAN.R),
        warp=CONFIG.DATASET.WARP_IMAGE,
        scale=CONFIG.DATASET.SCALES,
        flip=True,
    )

    # DataLoader
    loader = torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=CONFIG.SOLVER.BATCH_SIZE.TRAIN,
        num_workers=CONFIG.DATALOADER.NUM_WORKERS,
        shuffle=True,
    )
    loader_iter = iter(loader)

    # Model
    model = setup_model(CONFIG.MODEL.INIT_MODEL,
                        CONFIG.DATASET.N_CLASSES,
                        train=True)
    model.to(device)

    # Optimizer
    optimizer = torch.optim.SGD(
        # cf lr_mult and decay_mult in train.prototxt
        params=[
            {
                "params": get_params(model.module, key="1x"),
                "lr": CONFIG.SOLVER.LR,
                "weight_decay": CONFIG.SOLVER.WEIGHT_DECAY,
            },
            {
                "params": get_params(model.module, key="10x"),
                "lr": 10 * CONFIG.SOLVER.LR,
                "weight_decay": CONFIG.SOLVER.WEIGHT_DECAY,
            },
            {
                "params": get_params(model.module, key="20x"),
                "lr": 20 * CONFIG.SOLVER.LR,
                "weight_decay": 0.0,
            },
        ],
        momentum=CONFIG.SOLVER.MOMENTUM,
    )

    # Learning rate scheduler
    scheduler = PolynomialLR(
        optimizer=optimizer,
        step_size=CONFIG.SOLVER.LR_DECAY,
        iter_max=CONFIG.SOLVER.ITER_MAX,
        power=CONFIG.SOLVER.POLY_POWER,
    )

    # Loss definition
    criterion = nn.CrossEntropyLoss(ignore_index=CONFIG.DATASET.IGNORE_LABEL)
    criterion.to(device)

    # TensorBoard logger
    writer = SummaryWriter(CONFIG.SOLVER.LOG_DIR)
    average_loss = MovingAverageValueMeter(CONFIG.SOLVER.AVERAGE_LOSS)

    # Freeze the batch norm pre-trained on COCO
    model.train()
    model.module.base.freeze_bn()

    for iteration in tqdm(
            range(1, CONFIG.SOLVER.ITER_MAX + 1),
            total=CONFIG.SOLVER.ITER_MAX,
            leave=False,
            dynamic_ncols=True,
    ):

        # Clear gradients (ready to accumulate)
        optimizer.zero_grad()

        loss = 0
        for _ in range(CONFIG.SOLVER.ITER_SIZE):
            try:
                images, labels = next(loader_iter)
            except StopIteration:
                # Restart the loader when it is exhausted
                loader_iter = iter(loader)
                images, labels = next(loader_iter)

            images = images.to(device)
            labels = labels.to(device)

            # Propagate forward
            logits = model(images)

            # Loss
            iter_loss = 0
            for logit in logits:
                # Resize labels for {100%, 75%, 50%, Max} logits
                _, _, H, W = logit.shape
                labels_ = resize_labels(labels, shape=(H, W))
                iter_loss += criterion(logit, labels_)

            # Backpropagate (just compute gradients wrt the loss)
            iter_loss /= CONFIG.SOLVER.ITER_SIZE
            iter_loss.backward()

            loss += float(iter_loss)

        average_loss.add(loss)

        # Update weights with accumulated gradients
        optimizer.step()

        # Update learning rate
        scheduler.step(epoch=iteration)

        # TensorBoard
        if iteration % CONFIG.SOLVER.ITER_TB == 0:
            writer.add_scalar("loss/train", average_loss.value()[0], iteration)
            for i, o in enumerate(optimizer.param_groups):
                writer.add_scalar("lr/group{}".format(i), o["lr"], iteration)
            if False:  # This produces a large log file
                for name, param in model.named_parameters():
                    name = name.replace(".", "/")
                    # Weight/gradient distribution
                    writer.add_histogram(name, param, iteration, bins="auto")
                    if param.requires_grad:
                        writer.add_histogram(name + "/grad",
                                             param.grad,
                                             iteration,
                                             bins="auto")

        # Save a model
        if iteration % CONFIG.SOLVER.ITER_SAVE == 0:
            torch.save(
                model.module.state_dict(),
                osp.join(CONFIG.MODEL.SAVE_DIR,
                         "checkpoint_{}.pth".format(iteration)),
            )

        # To verify progress separately
        torch.save(
            model.module.state_dict(),
            osp.join(CONFIG.MODEL.SAVE_DIR, "checkpoint_current.pth"),
        )

    torch.save(
        model.module.state_dict(),
        osp.join(CONFIG.MODEL.SAVE_DIR, "checkpoint_final.pth"),
    )
Example #3
def train(config_path, cuda):
    """
    Training DeepLab by v2 protocol
    """

    # Configuration
    with open(config_path) as f:
        CONFIG = Dict(yaml.safe_load(f))
    device = get_device(cuda)
    torch.backends.cudnn.benchmark = True

    # Dataset
    dataset = get_dataset(CONFIG.DATASET.NAME)(
        root=CONFIG.DATASET.ROOT,
        split=CONFIG.DATASET.SPLIT.TRAIN,
        ignore_label=CONFIG.DATASET.IGNORE_LABEL,
        mean_bgr=(CONFIG.IMAGE.MEAN.B, CONFIG.IMAGE.MEAN.G, CONFIG.IMAGE.MEAN.R),
        augment=True,
        base_size=CONFIG.IMAGE.SIZE.BASE,
        crop_size=CONFIG.IMAGE.SIZE.TRAIN,
        scales=CONFIG.DATASET.SCALES,
        flip=True,
    )
    print(dataset)

    # DataLoader
    loader = torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=CONFIG.SOLVER.BATCH_SIZE.TRAIN,
        num_workers=CONFIG.DATALOADER.NUM_WORKERS,
        shuffle=True,
    )
    loader_iter = iter(loader)

    # Model check
    print("Model:", CONFIG.MODEL.NAME)
    assert (
        CONFIG.MODEL.NAME == "DeepLabV2_ResNet101_MSC"
    ), 'Currently only "DeepLabV2_ResNet101_MSC" is supported'

    # Model setup
    model = eval(CONFIG.MODEL.NAME)(n_classes=CONFIG.DATASET.N_CLASSES)
    state_dict = torch.load(CONFIG.MODEL.INIT_MODEL)
    print("    Init:", CONFIG.MODEL.INIT_MODEL)
    for m in model.base.state_dict().keys():
        if m not in state_dict.keys():
            print("    Skip init:", m)
    model.base.load_state_dict(state_dict, strict=False)  # to skip ASPP
    model = nn.DataParallel(model)
    model.to(device)

    # Loss definition
    criterion = nn.CrossEntropyLoss(ignore_index=CONFIG.DATASET.IGNORE_LABEL)
    criterion.to(device)

    # Optimizer
    optimizer = torch.optim.SGD(
        # cf lr_mult and decay_mult in train.prototxt
        params=[
            {
                "params": get_params(model.module, key="1x"),
                "lr": CONFIG.SOLVER.LR,
                "weight_decay": CONFIG.SOLVER.WEIGHT_DECAY,
            },
            {
                "params": get_params(model.module, key="10x"),
                "lr": 10 * CONFIG.SOLVER.LR,
                "weight_decay": CONFIG.SOLVER.WEIGHT_DECAY,
            },
            {
                "params": get_params(model.module, key="20x"),
                "lr": 20 * CONFIG.SOLVER.LR,
                "weight_decay": 0.0,
            },
        ],
        momentum=CONFIG.SOLVER.MOMENTUM,
    )

    # Learning rate scheduler
    scheduler = PolynomialLR(
        optimizer=optimizer,
        step_size=CONFIG.SOLVER.LR_DECAY,
        iter_max=CONFIG.SOLVER.ITER_MAX,
        power=CONFIG.SOLVER.POLY_POWER,
    )

    # Setup loss logger
    writer = SummaryWriter(os.path.join(CONFIG.EXP.OUTPUT_DIR, "logs", CONFIG.EXP.ID))
    average_loss = MovingAverageValueMeter(CONFIG.SOLVER.AVERAGE_LOSS)

    # Path to save models
    checkpoint_dir = os.path.join(
        CONFIG.EXP.OUTPUT_DIR,
        "models",
        CONFIG.EXP.ID,
        CONFIG.MODEL.NAME.lower(),
        CONFIG.DATASET.SPLIT.TRAIN,
    )
    makedirs(checkpoint_dir)
    print("Checkpoint dst:", checkpoint_dir)

    # Freeze the batch norm pre-trained on COCO
    model.train()
    model.module.base.freeze_bn()

    for iteration in tqdm(
        range(1, CONFIG.SOLVER.ITER_MAX + 1),
        total=CONFIG.SOLVER.ITER_MAX,
        dynamic_ncols=True,
    ):

        # Clear gradients (ready to accumulate)
        optimizer.zero_grad()

        loss = 0
        for _ in range(CONFIG.SOLVER.ITER_SIZE):
            try:
                _, images, labels = next(loader_iter)
            except StopIteration:
                # Restart the loader when it is exhausted
                loader_iter = iter(loader)
                _, images, labels = next(loader_iter)

            # Propagate forward
            logits = model(images.to(device))

            # Loss
            iter_loss = 0
            for logit in logits:
                # Resize labels for {100%, 75%, 50%, Max} logits
                _, _, H, W = logit.shape
                labels_ = resize_labels(labels, size=(H, W))
                iter_loss += criterion(logit, labels_.to(device))

            # Propagate backward (just compute gradients wrt the loss)
            iter_loss /= CONFIG.SOLVER.ITER_SIZE
            iter_loss.backward()

            loss += float(iter_loss)

        average_loss.add(loss)

        # Update weights with accumulated gradients
        optimizer.step()

        # Update learning rate
        scheduler.step(epoch=iteration)

        # TensorBoard
        if iteration % CONFIG.SOLVER.ITER_TB == 0:
            writer.add_scalar("loss/train", average_loss.value()[0], iteration)
            for i, o in enumerate(optimizer.param_groups):
                writer.add_scalar("lr/group_{}".format(i), o["lr"], iteration)
            for i in range(torch.cuda.device_count()):
                writer.add_scalar(
                    "gpu/device_{}/memory_cached".format(i),
                    torch.cuda.memory_cached(i) / 1024 ** 3,
                    iteration,
                )

            if False:
                for name, param in model.module.base.named_parameters():
                    name = name.replace(".", "/")
                    # Weight/gradient distribution
                    writer.add_histogram(name, param, iteration, bins="auto")
                    if param.requires_grad:
                        writer.add_histogram(
                            name + "/grad", param.grad, iteration, bins="auto"
                        )

        # Save a model
        if iteration % CONFIG.SOLVER.ITER_SAVE == 0:
            torch.save(
                model.module.state_dict(),
                os.path.join(checkpoint_dir, "checkpoint_{}.pth".format(iteration)),
            )

    torch.save(
        model.module.state_dict(), os.path.join(checkpoint_dir, "checkpoint_final.pth")
    )
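Every example resizes the ground-truth map with resize_labels so it matches each of the multi-scale logits. The helper itself is not shown; a typical implementation (assumed here) downsamples with nearest-neighbour interpolation so the integer class ids are preserved:

import numpy as np
import torch
from PIL import Image

def resize_labels_sketch(labels, size):
    # labels: (N, H, W) tensor of class ids; size: (H, W) target resolution
    resized = []
    for label in labels:
        image = Image.fromarray(label.cpu().float().numpy())
        image = image.resize((size[1], size[0]), resample=Image.NEAREST)
        resized.append(np.asarray(image))
    return torch.as_tensor(np.stack(resized), dtype=torch.long)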
Example #4
def train():
    """Create the model and start the training."""
    # === 1.Configuration
    print(CONFIG_PATH)
    # === select which GPU(s) to use
    # === here we assume 8 GPUs with indices 0, 1, 2, ..., 7
    os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(map(str, CONFIG.EXP.GPU_IDX))

    device = get_device(torch.cuda.is_available())
    cudnn.benchmark = True
    comment_init = ""
    writer = SummaryWriter(comment=comment_init)  # Setup loss logger
    # === MovingAverageValueMeter(self, windowsize)
    # === - add(value): record a value
    # === - reset()
    # === - value(): return the moving average and standard deviation
    average_loss = MovingAverageValueMeter(CONFIG.SOLVER.AVERAGE_LOSS)
    if not os.path.exists(CONFIG.MODEL.SAVE_PATH):
        os.makedirs(CONFIG.MODEL.SAVE_PATH)
    # Path to save models
    checkpoint_dir = os.path.join(
        CONFIG.EXP.OUTPUT_DIR,  # ./data
        "models",
        CONFIG.MODEL.NAME.lower(),  # DeepLabV2_ResNet101_MSC
        CONFIG.DATASET.SPLIT.TRAIN,  # train_aug
    )
    # === checkpoint_dir: ./data/DeepLabV2_ResNet101_MSC/train_aug
    if not os.path.exists(checkpoint_dir):
        os.makedirs(checkpoint_dir)
    print("Checkpoint dst:", checkpoint_dir)

    # === 2.Dataloader ===
    trainloader = data.DataLoader(
        VOCDataSet(
            CONFIG.DATASET.DIRECTORY,
            CONFIG.DATASET.LIST_PATH,
            max_iters=CONFIG.SOLVER.ITER_MAX * CONFIG.SOLVER.BATCH_SIZE.TRAIN,
            crop_size=(CONFIG.IMAGE.SIZE.TRAIN, CONFIG.IMAGE.SIZE.TRAIN),
            scale=CONFIG.DATASET.RANDOM.SCALE,
            mirror=CONFIG.DATASET.RANDOM.MIRROR,
            mean=IMG_MEAN,
            label_path=CONFIG.DATASET.SEG_LABEL),  # for training 
        batch_size=CONFIG.SOLVER.BATCH_SIZE.TRAIN,
        shuffle=True,
        num_workers=CONFIG.DATALOADER.NUM_WORKERS,
        pin_memory=True)

    # iter(dataloader) returns an iterator that can be advanced with next()
    # loader_iter = iter(trainloader)

    # === 3.Create network & weights ===
    print("Model:", CONFIG.MODEL.NAME)

    # model = DeepLabV2_ResNet101_MSC(n_classes=CONFIG.DATASET.N_CLASSES)
    model = DeepLabV2_DRN105_MSC(n_classes=CONFIG.DATASET.N_CLASSES)
    state_dict = torch.load(CONFIG.MODEL.INIT_MODEL)
    # model.base.load_state_dict(state_dict, strict=False)  # to skip ASPP
    print("    Init:", CONFIG.MODEL.INIT_MODEL)
    # === show the skip weight
    for m in model.base.state_dict().keys():
        if m not in state_dict.keys():
            print("    Skip init:", m)

    # === DeepLabv2 = Res101+ASPP
    # === model.base = DeepLabv2
    # === model = MSC(DeepLabv2)
    # model.base.load_state_dict(state_dict,
    #                            strict=False)  # strict=False to skip ASPP
    model = nn.DataParallel(model)  # multi-GPU
    model.to(device)  # move to GPU if available
    # === 4.Loss definition
    criterion = nn.CrossEntropyLoss(ignore_index=CONFIG.DATASET.IGNORE_LABEL)
    criterion.to(device)  # move to GPU if available

    # === 5.optimizer ===
    optimizer = torch.optim.SGD(
        # cf lr_mult and decay_mult in train.prototxt
        params=[
            {
                "params": get_params(model.module, key="1x"),
                "lr": CONFIG.SOLVER.LR,
                "weight_decay": CONFIG.SOLVER.WEIGHT_DECAY,
            },
            {
                "params": get_params(model.module, key="10x"),
                "lr": 10 * CONFIG.SOLVER.LR,
                "weight_decay": CONFIG.SOLVER.WEIGHT_DECAY,
            },
            {
                "params": get_params(model.module, key="20x"),
                "lr": 20 * CONFIG.SOLVER.LR,
                "weight_decay": 0.0,
            },
        ],
        momentum=CONFIG.SOLVER.MOMENTUM,
    )
    # Learning rate scheduler
    scheduler = PolynomialLR(
        optimizer=optimizer,
        step_size=CONFIG.SOLVER.LR_DECAY,
        iter_max=CONFIG.SOLVER.ITER_MAX,
        power=CONFIG.SOLVER.POLY_POWER,
    )

    time_start = time.time()  # set start time
    # === training iteration ===
    for i_iter, batch in enumerate(trainloader, start=1):
        torch.set_grad_enabled(True)
        model.train()
        model.module.base.freeze_bn()
        optimizer.zero_grad()
        images, labels, _, _ = batch

        logits = model(images.to(device))
        # === Loss
        # === logits = [logits] + logits_pyramid + [logits_max]
        iter_loss = 0
        loss = 0
        for logit in logits:
            # Resize labels for {100%, 75%, 50%, Max} logits
            _, _, H, W = logit.shape
            labels_ = resize_labels(labels, size=(H, W))
            iter_loss += criterion(logit, labels_.to(device))
        # iter_loss /= CONFIG.SOLVER.ITER_SIZE
        iter_loss /= 4  # average over the {100%, 75%, 50%, Max} logits
        iter_loss.backward()
        loss += float(iter_loss)

        average_loss.add(loss)
        # Update weights with accumulated gradients
        optimizer.step()

        # Update learning rate
        scheduler.step(epoch=i_iter)

        # TensorBoard
        writer.add_scalar("loss", average_loss.value()[0], global_step=i_iter)
        print(
            'iter/max_iter = [{}/{}] completed, loss = {:4.3}, time: {}'.format(
                i_iter, CONFIG.SOLVER.ITER_MAX,
                average_loss.value()[0], show_timing(time_start, time.time())))
        # print('iter = ', i_iter, 'of', args.num_steps, '',
        #       loss.data.cpu().numpy())

        # === save final model
        if i_iter >= CONFIG.SOLVER.ITER_MAX:
            print('save final model as...{}'.format(
                osp.join(CONFIG.MODEL.SAVE_PATH,
                         'VOC12_' + str(CONFIG.SOLVER.ITER_MAX) + '.pth')))
            torch.save(
                model.module.state_dict(),
                osp.join(CONFIG.MODEL.SAVE_PATH,
                         'VOC12_' + str(CONFIG.SOLVER.ITER_MAX) + '.pth'))
            break
        if i_iter % CONFIG.EXP.EVALUATE_ITER == 0:
            print("Evaluation....")
            evaluate_gpu(model, writer, i_iter)

        # === Save the model every SAVE_EVERY_ITER iterations ========
        # DataParallel prefixes each layer name with 'module.', so
        # model.module.state_dict() is saved here.
        # ============================================================
        if i_iter % CONFIG.MODEL.SAVE_EVERY_ITER == 0:
            print('saving model ...')
            torch.save(
                model.module.state_dict(),
                osp.join(CONFIG.MODEL.SAVE_PATH,
                         'VOC12_{}.pth'.format(i_iter)))
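All of the optimizers above split parameters into 1x/10x/20x groups via get_params, mirroring the lr_mult/decay_mult settings of the original DeepLab Caffe prototxt: backbone weights at the base learning rate, the final classifier (ASPP) weights at 10x, and its biases at 20x with no weight decay. get_params itself is not shown; a rough sketch, under the assumption that the ASPP head can be recognised by name:

import torch.nn as nn

def get_params_sketch(model, key):
    # Yield parameters for one of the three groups: "1x", "10x", or "20x".
    for name, module in model.named_modules():
        is_aspp = "aspp" in name.lower()   # hypothetical way to spot the ASPP head
        if not isinstance(module, nn.Conv2d):
            continue
        if key == "1x" and not is_aspp:
            yield module.weight            # backbone weights: base lr, weight decay
        elif key == "10x" and is_aspp:
            yield module.weight            # ASPP weights: 10x lr, weight decay
        elif key == "20x" and is_aspp and module.bias is not None:
            yield module.bias              # ASPP biases: 20x lr, no weight decay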