Code Example #1
File: kd_model.py  Project: suyanzhou626/IFVD
    def __init__(self, args):
        self.args = args
        student = Res_pspnet(BasicBlock, [2, 2, 2, 2],
                             num_classes=args.num_classes)
        load_S_model(args, student)
        print_model_parm_nums(student, 'student_model')
        student.cuda()
        self.student = student

        teacher = Res_pspnet(Bottleneck, [3, 4, 23, 3],
                             num_classes=args.num_classes)
        load_T_model(args, teacher)
        print_model_parm_nums(teacher, 'teacher_model')
        teacher.cuda()
        self.teacher = teacher

        D_model = Discriminator(args.preprocess_GAN_mode, args.num_classes,
                                args.batch_size, args.imsize_for_adv,
                                args.adv_conv_dim)
        load_D_model(args, D_model)
        print_model_parm_nums(D_model, 'D_model')
        logging.info("------------")
        D_model.cuda()
        self.D_model = D_model

        self.G_solver = optim.SGD(
            [{
                'params': filter(lambda p: p.requires_grad,
                                 student.parameters()),
                'initial_lr': args.lr_g
            }],
            args.lr_g,
            momentum=args.momentum,
            weight_decay=args.weight_decay)
        self.D_solver = optim.Adam(
            filter(lambda p: p.requires_grad, D_model.parameters()), args.lr_d,
            [0.9, 0.99])
        # self.D_solver = optim.SGD([{'params': filter(lambda p: p.requires_grad, D_model.parameters()), 'initial_lr': args.lr_d}], args.lr_d, momentum=args.momentum, weight_decay=args.weight_decay)

        self.criterion_dsn = CriterionDSN().cuda()
        if args.kd:
            self.criterion_kd = CriterionKD().cuda()
        if args.adv:
            self.criterion_adv = CriterionAdv(args.adv_loss_type).cuda()
            if args.adv_loss_type == 'wgan-gp':
                self.criterion_AdditionalGP = CriterionAdditionalGP(
                    D_model, args.lambda_gp).cuda()
            self.criterion_adv_for_G = CriterionAdvForG(
                args.adv_loss_type).cuda()
        if args.ifv:
            self.criterion_ifv = CriterionIFV(classes=args.num_classes).cuda()

        self.G_loss, self.D_loss = 0.0, 0.0
        self.mc_G_loss, self.kd_G_loss, self.adv_G_loss, self.ifv_G_loss = 0.0, 0.0, 0.0, 0.0

        cudnn.deterministic = True
        cudnn.benchmark = False
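
`print_model_parm_nums` is called in this constructor but not defined in the excerpt. A minimal sketch of what such a parameter-counting helper typically does (the exact output format in IFVD may differ) is:

import torch.nn as nn

def print_model_parm_nums(model: nn.Module, name: str) -> None:
    # Sum the element counts of all parameters and report them in millions.
    total = sum(p.numel() for p in model.parameters())
    print('{}: {:.2f}M parameters'.format(name, total / 1e6))

# Example: print_model_parm_nums(nn.Linear(10, 10), 'toy_model')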
Code Example #2
File: train.py  Project: zuqiutxy/OCNet.pytorch
def main():
    print("Input arguments:")
    for key, val in vars(args).items():
        print("{:16} {}".format(key, val))

    random.seed(args.seed)
    torch.manual_seed(args.seed)

    writer = SummaryWriter(args.snapshot_dir)
    os.environ["CUDA_VISIBLE_DEVICES"]=args.gpu
    h, w = map(int, args.input_size.split(','))
    input_size = (h, w)
    cudnn.enabled = True

    deeplab = get_segmentation_model("_".join([args.network, args.method]), num_classes=args.num_classes)

    saved_state_dict = torch.load(args.restore_from)
    new_params = deeplab.state_dict().copy()

    if 'wide' in args.network:
        saved_state_dict = saved_state_dict['state_dict']
        if 'vistas' in args.method:
            saved_state_dict = saved_state_dict['body']
            for i in saved_state_dict:
                new_params[i] = saved_state_dict[i]
        else:     
            for i in saved_state_dict:
                i_parts = i.split('.')
                if 'classifier' not in i_parts:
                    new_params['.'.join(i_parts[1:])] = saved_state_dict[i]
    elif 'mobilenet' in args.network:
        for i in saved_state_dict:
            i_parts = i.split('.')
            if not (i_parts[0] == 'features' and i_parts[1] == '18') and i_parts[0] != 'classifier':
                new_params['.'.join(i_parts)] = saved_state_dict[i]
    else:
        for i in saved_state_dict:
            i_parts = i.split('.')
            if i_parts[0] not in ('fc', 'last_linear', 'classifier'):
                new_params['.'.join(i_parts)] = saved_state_dict[i]

    if args.start_iters > 0:
        deeplab.load_state_dict(saved_state_dict)
    else:
        deeplab.load_state_dict(new_params)

    model = DataParallelModel(deeplab)
    # model = nn.DataParallel(deeplab)
    model.train()     
    model.float()
    model.cuda()    

    criterion = CriterionCrossEntropy()
    if "dsn" in args.method:
        if args.ohem:
            if args.ohem_single:
                print('use ohem only for the second prediction map.')
                criterion = CriterionOhemDSN_single(thres=args.ohem_thres, min_kept=args.ohem_keep, dsn_weight=float(args.dsn_weight))
            else:
                criterion = CriterionOhemDSN(thres=args.ohem_thres, min_kept=args.ohem_keep, dsn_weight=float(args.dsn_weight), use_weight=True)
        else:
            criterion = CriterionDSN(dsn_weight=float(args.dsn_weight), use_weight=True)


    criterion = DataParallelCriterion(criterion)
    criterion.cuda()
    cudnn.benchmark = True


    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    trainloader = data.DataLoader(get_segmentation_dataset(args.dataset, root=args.data_dir, list_path=args.data_list,
                    max_iters=args.num_steps*args.batch_size, crop_size=input_size, 
                    scale=args.random_scale, mirror=args.random_mirror, network=args.network), 
                    batch_size=args.batch_size, shuffle=True, num_workers=1, pin_memory=True)

    optimizer = optim.SGD([{'params': filter(lambda p: p.requires_grad, deeplab.parameters()), 'lr': args.learning_rate }], 
                lr=args.learning_rate, momentum=args.momentum, weight_decay=args.weight_decay)


    optimizer.zero_grad()

    for i_iter, batch in enumerate(trainloader):
        sys.stdout.flush()
        i_iter += args.start_iters
        images, labels, _, _ = batch
        images = Variable(images.cuda())
        labels = Variable(labels.long().cuda())
        optimizer.zero_grad()
        lr = adjust_learning_rate(optimizer, i_iter)
        if args.fix_lr:
            lr = args.learning_rate
        print('learning_rate: {}'.format(lr))

        if 'gt' in args.method:
            preds = model(images, labels)
        else:
            preds = model(images)
        loss = criterion(preds, labels)
        loss.backward()
        optimizer.step()

        if i_iter % 100 == 0:
            writer.add_scalar('learning_rate', lr, i_iter)
            writer.add_scalar('loss', loss.data.cpu().numpy(), i_iter)
        print('iter = {} of {} completed, loss = {}'.format(i_iter, args.num_steps, loss.data.cpu().numpy()))

        if i_iter >= args.num_steps-1:
            print('save model ...')
            torch.save(deeplab.state_dict(),osp.join(args.snapshot_dir, 'CS_scenes_'+str(args.num_steps)+'.pth'))
            break

        if i_iter % args.save_pred_every == 0:
            print('taking snapshot ...')
            torch.save(deeplab.state_dict(),osp.join(args.snapshot_dir, 'CS_scenes_'+str(i_iter)+'.pth'))     

    end = timeit.default_timer()
    print(end-start,'seconds')
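
`adjust_learning_rate` is not shown in this excerpt. Training scripts in this family usually apply polynomial learning-rate decay; the sketch below assumes that behaviour and passes `base_lr` and `max_iter` explicitly, whereas the project version presumably reads them from `args`:

# Sketch of a poly learning-rate schedule, assumed to match what
# adjust_learning_rate does in the training loop above.
def lr_poly(base_lr, cur_iter, max_iter, power=0.9):
    return base_lr * ((1.0 - float(cur_iter) / max_iter) ** power)

def adjust_learning_rate(optimizer, cur_iter, base_lr, max_iter):
    lr = lr_poly(base_lr, cur_iter, max_iter)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr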
Code Example #3
    def __init__(self, args):
        cudnn.enabled = True
        self.args = args
        device = args.device
        student = Res_pspnet(BasicBlock, [2, 2, 2, 2],
                             num_classes=args.classes_num)
        load_S_model(args, student, False)
        print_model_parm_nums(student, 'student_model')
        self.parallel_student = self.DataParallelModelProcess(
            student, 2, 'train', device)
        self.student = student

        teacher = Res_pspnet(Bottleneck, [3, 4, 23, 3],
                             num_classes=args.classes_num)
        load_T_model(teacher, args.T_ckpt_path)
        print_model_parm_nums(teacher, 'teacher_model')
        self.parallel_teacher = self.DataParallelModelProcess(
            teacher, 2, 'eval', device)
        self.teacher = teacher

        D_model = Discriminator(args.preprocess_GAN_mode, args.classes_num,
                                args.batch_size, args.imsize_for_adv,
                                args.adv_conv_dim)
        load_D_model(args, D_model, False)
        print_model_parm_nums(D_model, 'D_model')
        self.parallel_D = self.DataParallelModelProcess(
            D_model, 2, 'train', device)

        self.G_solver = optim.SGD(
            [{
                'params': filter(lambda p: p.requires_grad,
                                 self.student.parameters()),
                'initial_lr': args.lr_g
            }],
            args.lr_g,
            momentum=args.momentum,
            weight_decay=args.weight_decay)
        self.D_solver = optim.SGD(
            [{
                'params': filter(lambda p: p.requires_grad,
                                 D_model.parameters()),
                'initial_lr': args.lr_d
            }],
            args.lr_d,
            momentum=args.momentum,
            weight_decay=args.weight_decay)

        self.best_mean_IU = args.best_mean_IU

        self.criterion = self.DataParallelCriterionProcess(
            CriterionDSN())  #CriterionCrossEntropy()
        self.criterion_pixel_wise = self.DataParallelCriterionProcess(
            CriterionPixelWise())
        #self.criterion_pair_wise_for_interfeat = [self.DataParallelCriterionProcess(CriterionPairWiseforWholeFeatAfterPool(scale=args.pool_scale[ind], feat_ind=-(ind+1))) for ind in range(len(args.lambda_pa))]
        self.criterion_pair_wise_for_interfeat = self.DataParallelCriterionProcess(
            CriterionPairWiseforWholeFeatAfterPool(scale=args.pool_scale,
                                                   feat_ind=-5))
        self.criterion_adv = self.DataParallelCriterionProcess(
            CriterionAdv(args.adv_loss_type))
        if args.adv_loss_type == 'wgan-gp':
            self.criterion_AdditionalGP = self.DataParallelCriterionProcess(
                CriterionAdditionalGP(self.parallel_D, args.lambda_gp))
        self.criterion_adv_for_G = self.DataParallelCriterionProcess(
            CriterionAdvForG(args.adv_loss_type))

        self.mc_G_loss = 0.0
        self.pi_G_loss = 0.0
        self.pa_G_loss = 0.0
        self.D_loss = 0.0

        cudnn.benchmark = True
        if not os.path.exists(args.snapshot_dir):
            os.makedirs(args.snapshot_dir)
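
`DataParallelModelProcess` and `DataParallelCriterionProcess` are methods of this class that the excerpt does not show. Judging from how they are called, they wrap a module for multi-GPU use, move it to the given device, and switch it to train or eval mode. The function below is a rough, assumed stand-in that uses plain `torch.nn.DataParallel` rather than the encoding-style `DataParallelModel` seen in the other examples:

# Rough, assumed stand-in for DataParallelModelProcess; the project's helper may
# use encoding.parallel.DataParallelModel and different device/GPU handling.
import torch.nn as nn

def data_parallel_model_process(model, num_gpus, mode, device):
    model = nn.DataParallel(model, device_ids=list(range(num_gpus)))
    model.to(device)
    if mode == 'train':
        model.train()
    elif mode == 'eval':
        model.eval()
    else:
        raise ValueError('unknown mode: {}'.format(mode))
    return model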
Code Example #4
def main():
    args.time = get_currect_time()

    visualizer = Visualizer(args)
    log = Log(args)
    log.record_sys_param()
    log.record_file()

    """Set GPU Environment"""
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu

    trainloader = data.DataLoader(NYUDataset_crop_fast(args.data_list, args.random_scale, args.random_mirror, args.random_crop,
                args.batch_size, args.colorjitter),batch_size=args.batch_size,
                shuffle=True, num_workers=4, pin_memory=True)
    valloader = data.DataLoader(NYUDataset_val_full(args.data_val_list, args.random_scale, args.random_mirror, args.random_crop,
                       1), batch_size=8, shuffle=False, pin_memory=True)

    """Create Network"""
    deeplab = Res_Deeplab(num_classes=args.num_classes)
    print(deeplab)

    """Load pretrained Network"""
    saved_state_dict = torch.load(args.restore_from)
    print(args.restore_from)
    new_params = deeplab.state_dict().copy()
    for i in saved_state_dict:
        # Scale.layer5.conv2d_list.3.weight
        i_parts = i.split('.')
        # print i_parts
        # if not i_parts[1]=='layer5':
        if not i_parts[0] == 'fc':
            new_params['.'.join(i_parts[0:])] = saved_state_dict[i]

    deeplab.load_state_dict(new_params)

    model = deeplab
    model.cuda()
    model.train()
    model = model.float()
    model = DataParallelModel(model, device_ids=[0, 1])

    criterion = CriterionDSN()
    criterion = DataParallelCriterion(criterion)

    optimizer = optim.SGD([{'params': filter(lambda p: p.requires_grad, model.parameters()), 'lr': args.learning_rate }],
                lr=args.learning_rate, momentum=args.momentum,weight_decay=args.weight_decay)

    optimizer.zero_grad()

    i_iter = 0
    args.num_steps = len(trainloader) * args.epoch
    best_iou = 0.0
    total = sum([param.nelement() for param in model.parameters()])
    print('  + Number of params: %.2fM' % (total / 1e6))

    for epoch in range(args.epoch):
        ## Train one epoch
        model.train()
        for batch in trainloader:
            start = timeit.default_timer()
            i_iter = i_iter + 1
            images = batch['image'].cuda()
            labels = batch['seg'].cuda()
            HHAs = batch['HHA'].cuda()
            depths = batch['depth'].cuda()
            labels = torch.squeeze(labels,1).long()
            if (images.size(0) != args.batch_size):
                break
            optimizer.zero_grad()
            preds = model(images, HHAs, depths)
            loss = criterion(preds, labels)
            loss.backward()
            optimizer.step()
            if i_iter % 100 == 0:
                visualizer.add_scalar('learning_rate', args.learning_rate, i_iter)
                visualizer.add_scalar('loss', loss.data.cpu().numpy(), i_iter)

            current_lr = optimizer.param_groups[0]['lr']
            end = timeit.default_timer()
            log.log_string(
                '====================> epoch=%03d/%d, iter=%05d/%05d, loss=%.3f, %.3fs/iter, %02d:%02d:%02d, lr=%.6f' % (
                    epoch, args.epoch, i_iter, len(trainloader)*args.epoch, loss.data.cpu().numpy(), (end - start),
                    (int((end - start) * (args.num_steps - i_iter)) // 3600),
                    (int((end - start) * (args.num_steps - i_iter)) % 3600 // 60),
                    (int((end - start) * (args.num_steps - i_iter)) % 3600 % 60), current_lr))
        if (epoch+1) % 40 == 0:
            adjust_learning_rate(optimizer, i_iter, args)

        if epoch % 5 == 0:
            model.eval()
            confusion_matrix = np.zeros((args.num_classes, args.num_classes))
            loss_val = 0
            log.log_string("====================> evaluating")
            for batch_val in valloader:
                images_val = batch_val['image'].cuda()
                labels_val = batch_val['seg'].cuda()
                labels_val = torch.squeeze(labels_val,1).long()
                HHAs_val = batch_val['HHA'].cuda()
                depths_val = batch_val['depth'].cuda()

                with torch.no_grad():
                    preds_val = model(images_val, HHAs_val, depths_val)
                    loss_val += criterion(preds_val, labels_val)
                    preds_val = torch.cat([preds_val[i][0] for i in range(len(preds_val))], 0)
                    preds_val = F.interpolate(input=preds_val, size=(480, 640), mode='bilinear', align_corners=True)

                    preds_val = np.asarray(np.argmax(preds_val.cpu().numpy(), axis=1), dtype=np.uint8)

                    labels_val = np.asarray(labels_val.cpu().numpy(), dtype=np.int64)
                    valid_mask = labels_val != 255  # keep pixels that are not the ignore label

                    labels_val = labels_val[valid_mask]
                    preds_val = preds_val[valid_mask]

                    confusion_matrix += get_confusion_matrix(labels_val, preds_val, args.num_classes)
            loss_val = loss_val / len(valloader)
            pos = confusion_matrix.sum(1)
            res = confusion_matrix.sum(0)
            tp = np.diag(confusion_matrix)

            IU_array = (tp / np.maximum(1.0, pos + res - tp))
            mean_IU = IU_array.mean()

            # getConfusionMatrixPlot(confusion_matrix)
            log.log_string('val loss ' + str(loss_val.cpu().numpy()) + ' meanIU ' + str(mean_IU) + ' IU_array ' + str(IU_array))

            visualizer.add_scalar('val loss', loss_val.cpu().numpy(), epoch)
            visualizer.add_scalar('meanIU', mean_IU, epoch)

            if mean_IU > best_iou:
                best_iou = mean_IU
                log.log_string('save best model ...')
                torch.save(deeplab.state_dict(),
                           osp.join(args.snapshot_dir, 'model', args.dataset + NAME + 'best_iu' + '.pth'))

        if epoch % 5 == 0:
            log.log_string('save model ...')
            torch.save(deeplab.state_dict(),osp.join(args.snapshot_dir,'model', args.dataset+ NAME + str(epoch)+'.pth'))
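
`get_confusion_matrix` drives the mIoU computation above but is not defined in the excerpt. A standard implementation, which may differ in detail from the project's, accumulates a num_classes x num_classes matrix with ground truth along the rows and predictions along the columns, consistent with the `pos`/`res`/`tp` usage above:

# Standard segmentation confusion matrix: rows are ground-truth classes,
# columns are predicted classes. The project's version may differ in masking.
import numpy as np

def get_confusion_matrix(labels, preds, num_classes):
    labels = np.asarray(labels).reshape(-1)
    preds = np.asarray(preds).reshape(-1)
    valid = (labels >= 0) & (labels < num_classes)
    index = labels[valid].astype(np.int64) * num_classes + preds[valid].astype(np.int64)
    counts = np.bincount(index, minlength=num_classes * num_classes)
    return counts.reshape(num_classes, num_classes)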
Code Example #5
    def __init__(self, args):
        cudnn.enabled = True
        self.args = args
        self.mode_name = 'kd_mmseg'
        device = args.device

        ######## skd
        # self.S_device = 'cuda:0'
        # self.T_device = 'cuda:1'
        # student = Res_pspnet(BasicBlock, [2, 2, 2, 2], num_classes = args.classes_num, deep_base=False)
        # load_S_model(args, student, False)
        # # print(student)
        #
        # # print(student.device)
        # print_model_parm_nums(student, 'student_model')
        # self.student = self.DataParallelModelProcess(student, 2, 'train', device=self.S_device)
        # self.student = student
        # # self.student.cuda()
        # # self.student.to('cuda:0')
        # # self.student.train()

        # teacher = Res_pspnet(Bottleneck, [3, 4, 23, 3], num_classes = args.classes_num)
        # load_T_model(teacher, args.T_ckpt_path)
        # print_model_parm_nums(teacher, 'teacher_model')
        # self.teacher = self.DataParallelModelProcess(teacher, 2, 'eval', device=self.T_device)
        # self.teacher = teacher
        # # self.teacher.to('cuda:1')
        # # self.teacher.eval()


        ##########################  mmseg

        self.S_device = 'cuda:0'
        S_config = 'configs/pspnet/pspnet_r18-d8_512x512_40k_cityscapes_1gpu.py'
        S_cfg = Config.fromfile(S_config)
        # print(S_cfg)
        # S_cfg.model.pretrained = args.student_pretrain_model_imgnet
        self.student = build_segmentor(S_cfg.model, train_cfg=S_cfg.train_cfg, test_cfg=S_cfg.test_cfg)
        # self.student = build_segmentor(S_cfg.model, train_cfg=None, test_cfg=None)

        # checkpoint = args.student_pretrain_model_imgnet
        # print(checkpoint)
        # checkpoint = load_checkpoint(self.student, checkpoint)

        # load_S_model(args, self.student, False)
        # self.student = self.DataParallelModelProcess(self.student, 2, 'train', device=self.S_device)
        self.student.train()
        self.student.to(self.S_device)

        # print(self.student)
        # for name, parameters in self.student.named_parameters():
        #     print(name, ':', parameters.size(), parameters.requires_grad)
        #
        # # print(self.student.parameters())
        #
        # # for parameters in self.student.parameters():
        # #     print(parameters)

        if self.args.pi or self.args.pa or self.args.ho:
            self.T_device = 'cuda:1'
            T_config = 'configs/pspnet/pspnet_r101-d8_512x512_80k_cityscapes_1gpu.py'
            T_cfg = Config.fromfile(T_config)
            # print(T_cfg)
            # self.teacher = build_segmentor(T_cfg.model, train_cfg=T_cfg.train_cfg, test_cfg=T_cfg.test_cfg)
            self.teacher = build_segmentor(T_cfg.model, train_cfg=None, test_cfg=None)
            checkpoint = 'work_dirs/models_zoo/pspnet_r101-d8_512x512_80k_cityscapes.2.2_iter_80000.pth'
            checkpoint = load_checkpoint(self.teacher, checkpoint)
            self.teacher = self.DataParallelModelProcess(self.teacher, 2, 'eval', device=self.T_device)

        ####################################################

        D_model = Discriminator(args.preprocess_GAN_mode, args.classes_num, args.batch_size, args.imsize_for_adv,
                                args.adv_conv_dim)
        load_D_model(args, D_model, False)
        print_model_parm_nums(D_model, 'D_model')
        # self.parallel_D = self.DataParallelModelProcess(D_model, 2, 'train', device)
        self.parallel_D = self.DataParallelModelProcess(D_model, 2, 'train', device='cuda:0')
        self.D_model = D_model

        self.G_solver = optim.SGD(
            [{'params': filter(lambda p: p.requires_grad, self.student.parameters()), 'initial_lr': args.lr_g}],
            lr=args.lr_g, momentum=args.momentum, weight_decay=args.weight_decay)
        # self.G_solver  = optim.SGD(self.student.parameters(),
        #                            lr=args.lr_g, momentum=args.momentum, weight_decay=args.weight_decay)
        self.D_solver = optim.SGD(
            [{'params': filter(lambda p: p.requires_grad, D_model.parameters()), 'initial_lr': args.lr_d}],
            lr=args.lr_d, momentum=args.momentum, weight_decay=args.weight_decay)

        self.best_mean_IU = args.best_mean_IU

        self.criterion = self.DataParallelCriterionProcess(CriterionDSN())  # CriterionCrossEntropy()
        self.criterion_ce = self.DataParallelCriterionProcess(CriterionCE())  # CriterionCrossEntropy()
        self.criterion_pixel_wise = self.DataParallelCriterionProcess(CriterionPixelWise())
        # self.criterion_pair_wise_for_interfeat = [self.DataParallelCriterionProcess(CriterionPairWiseforWholeFeatAfterPool(scale=args.pool_scale[ind], feat_ind=-(ind+1))) for ind in range(len(args.lambda_pa))]
        self.criterion_pair_wise_for_interfeat = self.DataParallelCriterionProcess(
            CriterionPairWiseforWholeFeatAfterPool(scale=args.pool_scale, feat_ind=-5))
        self.criterion_adv = self.DataParallelCriterionProcess(CriterionAdv(args.adv_loss_type))
        if args.adv_loss_type == 'wgan-gp':
            self.criterion_AdditionalGP = self.DataParallelCriterionProcess(
                CriterionAdditionalGP(self.parallel_D, args.lambda_gp))
        self.criterion_adv_for_G = self.DataParallelCriterionProcess(CriterionAdvForG(args.adv_loss_type))

        self.mc_G_loss = 0.0
        self.pi_G_loss = 0.0
        self.pa_G_loss = 0.0
        self.D_loss = 0.0

        self.criterion_AT = self.DataParallelCriterionProcess(AT(p=2))

        cudnn.benchmark = True
        if not os.path.exists(args.snapshot_dir):
            os.makedirs(args.snapshot_dir)

        print('init finish')
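
The `AT(p=2)` criterion above is not defined in the excerpt; given the name and the `p` argument, it is presumably the attention-transfer loss of Zagoruyko and Komodakis, which compares normalized spatial attention maps of student and teacher feature maps. A minimal sketch under that assumption (the project's class may differ in normalization and reduction):

# Minimal sketch of an attention-transfer (AT) loss, assumed to be what the
# AT(p=2) criterion above implements; details such as the reduction may differ.
import torch.nn as nn
import torch.nn.functional as F

class AT(nn.Module):
    def __init__(self, p=2):
        super().__init__()
        self.p = p

    def attention_map(self, feat):
        # Collapse channels into a spatial attention map, then L2-normalize it.
        am = feat.abs().pow(self.p).mean(dim=1)              # (N, H, W)
        return F.normalize(am.view(am.size(0), -1), dim=1)   # (N, H*W)

    def forward(self, feat_s, feat_t):
        return (self.attention_map(feat_s) - self.attention_map(feat_t)).pow(2).mean()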
Code Example #6
def main():
    writer = SummaryWriter(args.snapshot_dir)

    if args.gpu != 'None':
        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu

    cudnn.enabled = True

    deeplab = Res_Deeplab(num_classes=args.num_classes)
    print(deeplab)

    saved_state_dict = torch.load(args.restore_from)
    new_params = deeplab.state_dict().copy()
    for i in saved_state_dict:
        i_parts = i.split('.')
        if not i_parts[0] == 'fc':
            new_params['.'.join(i_parts[0:])] = saved_state_dict[i]

    deeplab.load_state_dict(new_params)

    model = DataParallelModel(deeplab)
    model.train()
    model.float()
    # model.apply(set_bn_momentum)
    model.cuda()

    criterion = CriterionDSN()  # CriterionCrossEntropy()
    criterion = DataParallelCriterion(criterion)
    criterion.cuda()

    cudnn.benchmark = True

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    trainloader = data.DataLoader(
        ModaDataset(
            args.data_dir,
            args.list_path,
            max_iters=args.num_steps * args.batch_size,
            # mirror=args.random_mirror,
            mirror=True,
            rotate=True,
            mean=IMG_MEAN),
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=4,
        pin_memory=True)

    optimizer = optim.SGD(
        [{
            'params': filter(lambda p: p.requires_grad, deeplab.parameters()),
            'lr': args.learning_rate
        }],
        lr=args.learning_rate,
        momentum=args.momentum,
        weight_decay=args.weight_decay)
    optimizer.zero_grad()

    print('start training!')
    for i_iter, batch in enumerate(trainloader):
        i_iter += args.start_iters
        images, labels, _, _ = batch
        images = images.cuda()
        labels = labels.long().cuda()

        optimizer.zero_grad()
        lr = adjust_learning_rate(optimizer, i_iter)
        # preds = model(images, args.recurrence)
        preds = model(images)

        loss = criterion(preds, labels)
        loss.backward()
        optimizer.step()

        if i_iter % 100 == 0:
            writer.add_scalar('learning_rate', lr, i_iter)
            writer.add_scalar('loss', loss.data.cpu().numpy(), i_iter)

        # if i_iter % 5000 == 0:
        #     images_inv = inv_preprocess(images, args.save_num_images, IMG_MEAN)
        #     labels_colors = decode_labels(labels, args.save_num_images, args.num_classes)
        #     if isinstance(preds, list):
        #         preds = preds[0]
        #     preds_colors = decode_predictions(preds, args.save_num_images, args.num_classes)
        #     for index, (img, lab) in enumerate(zip(images_inv, labels_colors)):
        #         writer.add_image('Images/'+str(index), img, i_iter)
        #         writer.add_image('Labels/'+str(index), lab, i_iter)
        #         writer.add_image('preds/'+str(index), preds_colors[index], i_iter)

        print('iter = {} of {} completed, loss = {}'.format(
            i_iter, args.num_steps,
            loss.data.cpu().numpy()))

        if i_iter >= args.num_steps - 1:
            print('save model ...')
            torch.save(
                deeplab.state_dict(),
                osp.join(args.snapshot_dir,
                         'CS_scenes_' + str(args.num_steps) + '.pth'))
            break

        if i_iter % args.save_pred_every == 0:
            print('taking snapshot ...')
            torch.save(
                deeplab.state_dict(),
                osp.join(args.snapshot_dir,
                         'CS_scenes_' + str(i_iter) + '.pth'))

    end = timeit.default_timer()
    print(end - start, 'seconds')
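
`CriterionDSN` appears in several of the examples above but is never defined in them. In PSPNet/OCNet-style codebases it is the deeply-supervised loss: cross-entropy on the main prediction plus a down-weighted cross-entropy on the auxiliary head. A minimal sketch under that assumption (the real classes may resize differently, use class weights, or accept `use_weight` as in Example #2):

# Minimal sketch of a deeply-supervised (DSN) segmentation loss, assumed to
# approximate CriterionDSN; the project versions may differ in upsampling,
# class weighting, and the ignore label.
import torch.nn as nn
import torch.nn.functional as F

class CriterionDSN(nn.Module):
    def __init__(self, dsn_weight=0.4, ignore_index=255):
        super().__init__()
        self.dsn_weight = dsn_weight
        self.ce = nn.CrossEntropyLoss(ignore_index=ignore_index)

    def forward(self, preds, target):
        # preds: [main_logits, aux_logits], each (N, C, h, w); target: (N, H, W) long
        h, w = target.size(1), target.size(2)
        main = F.interpolate(preds[0], size=(h, w), mode='bilinear', align_corners=True)
        aux = F.interpolate(preds[1], size=(h, w), mode='bilinear', align_corners=True)
        return self.ce(main, target) + self.dsn_weight * self.ce(aux, target)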