def load_val_dataset():
    """Build the Cityscapes validation loader and an image de-normalisation transform.

    Pipeline:
      * joint (image + mask) transform: ``FreeScale(cfg.VAL.IMG_SIZE)``;
      * image: ``ToTensor`` then ``Normalize`` with ImageNet mean/std;
      * mask: ``MaskToTensor`` then
        ``ChangeLabel(cfg.DATA.IGNORE_LABEL, cfg.DATA.NUM_CLASSES - 1)``
        (presumably remaps the ignore label to the last class id — confirm).

    Returns:
        tuple: ``(None, val_loader, restore_transform)``. No source-domain
        loader is built for validation, so the first slot is ``None``. It is
        kept so callers that unpack three values continue to work.
        ``restore_transform`` maps a normalised image tensor back to a PIL image.
    """
    mean_std = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    source_simul_transform = simul_transforms.Compose([
        simul_transforms.FreeScale(cfg.VAL.IMG_SIZE)
    ])
    img_transform = standard_transforms.Compose([
        standard_transforms.ToTensor(),
        standard_transforms.Normalize(*mean_std)
    ])

    target_transform = standard_transforms.Compose([
        expanded_transforms.MaskToTensor(),
        expanded_transforms.ChangeLabel(cfg.DATA.IGNORE_LABEL, cfg.DATA.NUM_CLASSES - 1)
    ])
    restore_transform = standard_transforms.Compose([
        expanded_transforms.DeNormalize(*mean_std),
        standard_transforms.ToPILImage()
    ])

    # Single-argument print(...) behaves the same on Python 2 and 3.
    print('=' * 50)
    print('Prepare Data...')
    val_set = CityScapes('val', list_filename='cityscapes_all.txt',
                         simul_transform=source_simul_transform,
                         transform=img_transform,
                         target_transform=target_transform)
    target_loader = DataLoader(val_set, batch_size=cfg.VAL.IMG_BATCH_SIZE,
                               num_workers=16, shuffle=True)

    # The previous version returned an undefined ``source_loader`` and raised
    # NameError. Validation has no source-domain loader.
    return None, target_loader, restore_transform
Exemple #2
0
    def __init__(self,
                 model,
                 loss,
                 resume,
                 config,
                 train_loader,
                 val_loader=None,
                 train_logger=None,
                 prefetch=True):
        """Initialise the trainer: logging cadence, class count and viz transforms.

        Args:
            model, loss, resume, config, train_loader, val_loader, train_logger:
                forwarded unchanged to the base class ``__init__``.
            prefetch: accepted for signature compatibility. This implementation
                does not use it.

        Reads ``config['trainer']['log_per_iter']`` (optional) and
        ``config['train_loader']['args']['batch_stride']`` (required).
        """
        super(Trainer, self).__init__(model, loss, resume, config,
                                      train_loader, val_loader, train_logger)

        self.wrt_mode, self.wrt_step = 'train_', 0

        batch_size = self.train_loader.batch_size
        # ``log_per_iter`` is optional. Read it with .get() so that a missing
        # key falls back to the sqrt(batch_size) default instead of raising
        # KeyError, and a falsy value no longer leaves log_step at 0/None.
        log_per_iter = config['trainer'].get('log_per_iter')
        if log_per_iter:
            self.log_step = int(log_per_iter / batch_size) + 1
        else:
            self.log_step = int(np.sqrt(batch_size))

        self.num_classes = self.train_loader.dataset.num_classes

        self.batch_stride = config['train_loader']['args']['batch_stride']

        # TRANSFORMS FOR VISUALIZATION: undo the loader's normalisation, then
        # render at a fixed 400x400 for logging.
        self.restore_transform = transforms.Compose([
            local_transforms.DeNormalize(self.train_loader.MEAN,
                                         self.train_loader.STD),
            transforms.ToPILImage()
        ])
        self.viz_transform = transforms.Compose(
            [transforms.Resize((400, 400)),
             transforms.ToTensor()])

        torch.backends.cudnn.benchmark = True
Exemple #3
0
def __main__(args):
    """Evaluate a pretrained PSPNet on the Cityscapes validation split.

    Input preprocessing:
      1. ``FlipChannels``;
      2. ``ToTensor`` (yields [0, 1] for 8-bit images);
      3. multiply by 255;
      4. per-channel mean subtraction with unit std.

    This presumably matches the convention of the pretrained weights — confirm.

    Args:
        args: dict of run options. This block reads ``args['val_batch_size']``
            and passes ``args`` through to ``validate``.
    """
    # Initialise the pretrained network.
    pspnet = PSPNet(n_classes=cityscapes.num_classes).cuda(gpu0)
    pspnet.load_pretrained_model(model_path=pspnet_path)

    # Transformation and dataset loading.
    mean_std = ([103.939, 116.779, 123.68], [1.0, 1.0, 1.0])
    val_input_transform = standard_transforms.Compose([
        extended_transforms.FlipChannels(),
        standard_transforms.ToTensor(),
        standard_transforms.Lambda(lambda x: x.mul_(255)),
        standard_transforms.Normalize(*mean_std)
    ])

    target_transform = standard_transforms.Compose(
        [extended_transforms.MaskToTensor()])

    # Inverse of val_input_transform, applied in reverse order:
    #   1. add the mean back;
    #   2. rescale [0, 255] -> [0, 1], because ToPILImage multiplies float
    #      tensors by 255 and would otherwise overflow;
    #   3. flip the channels back.
    restore_transform = standard_transforms.Compose([
        extended_transforms.DeNormalize(*mean_std),
        standard_transforms.Lambda(lambda x: x.div_(255)),
        standard_transforms.ToPILImage(),
        extended_transforms.FlipChannels(),
    ])

    visualize = standard_transforms.ToTensor()
    val_set = cityscapes.CityScapes('val',
                                    transform=val_input_transform,
                                    target_transform=target_transform)
    val_loader = DataLoader(val_set,
                            batch_size=args['val_batch_size'],
                            num_workers=8,
                            shuffle=False)
    validate(pspnet, val_loader, cityscapes.num_classes, args,
             restore_transform, visualize)
Exemple #4
0
def loading_data(root, mode, batch_size=1):
    """Create the head-count DataLoader for ``mode`` and an image-restore transform.

    Args:
        root: dataset root, passed to ``HeadCountDataset``.
        mode: split name. ``'train'`` enables random horizontal flips, shuffling
            and ``drop_last``. Any other value gives an unshuffled loader with
            batch size 1.
        batch_size: batch size, used only when ``mode == 'train'``.

    Returns:
        tuple: ``(dataloader, restore_transform)``. ``restore_transform``
        de-normalises an image tensor back into a PIL image.
    """
    mean_std = ([0.5, 0.5, 0.5], [0.25, 0.25, 0.25])
    log_para = 1
    is_train = mode == 'train'

    # Joint augmentation is applied only while training.
    main_transform = (
        own_transforms.Compose([own_transforms.RandomHorizontallyFlip()])
        if is_train else None
    )

    img_transform = standard_transforms.Compose([
        standard_transforms.ToTensor(),
        standard_transforms.Normalize(*mean_std),
    ])
    gt_transform = standard_transforms.Compose([
        own_transforms.LabelNormalize(log_para),
    ])
    restore_transform = standard_transforms.Compose([
        own_transforms.DeNormalize(*mean_std),
        standard_transforms.ToPILImage(),
    ])

    dataset = HeadCountDataset(root, mode, main_transform, img_transform,
                               gt_transform)

    loader_options = (
        {'batch_size': batch_size, 'shuffle': True, 'drop_last': True}
        if is_train else
        {'batch_size': 1, 'shuffle': False}
    )
    dataloader = DataLoader(dataset, **loader_options)
    return dataloader, restore_transform
Exemple #5
0
def main():
    """Run a trained AFENet snapshot over the Cityscapes 'fine' test split.

    The snapshot path is ``args['model_save_path']/args['snapshot']``. The
    dataset root comes from ``args['dataset_path']``, and ``args['scale']`` is
    forwarded both to the dataset (``val_scale``) and to ``test``.
    """
    net = AFENet(classes=19, pretrained_model_path=None).cuda()
    snapshot_file = os.path.join(args['model_save_path'], args['snapshot'])
    net.load_state_dict(torch.load(snapshot_file))

    # ImageNet normalisation statistics.
    mean_std = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    input_transform = standard_transforms.Compose([
        standard_transforms.ToTensor(),
        standard_transforms.Normalize(*mean_std),
    ])
    target_transform = extended_transforms.MaskToTensor()
    restore_transform = standard_transforms.Compose([
        extended_transforms.DeNormalize(*mean_std),
        standard_transforms.ToPILImage(),
    ])

    scale = args['scale']
    test_set = cityscapes.CityScapes(
        args['dataset_path'], 'fine', 'test',
        transform=input_transform,
        target_transform=target_transform,
        val_scale=scale,
    )
    test_loader = DataLoader(test_set, batch_size=1, num_workers=1,
                             shuffle=False)
    test(test_loader, net, input_transform, restore_transform, scale)
def load_dataset():
    """Build source- and target-domain training loaders plus an image-restore transform.

    Source domain:
      * selected by ``cfg.TRAIN.SOURCE_DOMAIN``, which must be ``'GTA5'`` or ``'SYN'``;
      * list file ``<DOMAIN>_<cfg.DATA.SSD_GT>.txt``.

    Target domain:
      * Cityscapes 'train' split;
      * list file ``cityscapes_<cfg.DATA.SSD_GT>.txt``.

    Each domain gets its own joint transform:
      * always ``FreeScale(cfg.TRAIN.IMG_SIZE)`` and ``RandomHorizontallyFlip``;
      * plus ``PhotometricDistort`` when ``cfg.TRAIN.DATA_AUG`` is set.

    Returns:
        tuple: ``(source_loader, target_loader, restore_transform)``.

    Raises:
        ValueError: if ``cfg.TRAIN.SOURCE_DOMAIN`` is not a supported domain.
            Previously this silently built a loader over an empty list.
    """
    mean_std = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

    def _make_simul_transform():
        # Build a fresh joint (image + mask) pipeline per domain.
        steps = [
            simul_transforms.FreeScale(cfg.TRAIN.IMG_SIZE),
            simul_transforms.RandomHorizontallyFlip(),
        ]
        if cfg.TRAIN.DATA_AUG:
            steps.append(simul_transforms.PhotometricDistort())
        return simul_transforms.Compose(steps)

    source_simul_transform = _make_simul_transform()
    target_simul_transform = _make_simul_transform()

    img_transform = standard_transforms.Compose([
        standard_transforms.ToTensor(),
        standard_transforms.Normalize(*mean_std)
    ])
    target_transform = standard_transforms.Compose([
        expanded_transforms.MaskToTensor(),
        expanded_transforms.ChangeLabel(cfg.DATA.IGNORE_LABEL, cfg.DATA.NUM_CLASSES - 1)
    ])
    restore_transform = standard_transforms.Compose([
        expanded_transforms.DeNormalize(*mean_std),
        standard_transforms.ToPILImage()
    ])

    # Single-argument print(...) behaves the same on Python 2 and 3.
    print('=' * 50)
    print('Prepare Data...')
    if cfg.TRAIN.SOURCE_DOMAIN == 'GTA5':
        source_set = GTA5('train', list_filename='GTA5_' + cfg.DATA.SSD_GT + '.txt',
                          simul_transform=source_simul_transform,
                          transform=img_transform,
                          target_transform=target_transform)
    elif cfg.TRAIN.SOURCE_DOMAIN == 'SYN':
        source_set = SYN('train', list_filename='SYN_' + cfg.DATA.SSD_GT + '.txt',
                         simul_transform=source_simul_transform,
                         transform=img_transform,
                         target_transform=target_transform)
    else:
        raise ValueError('Unsupported cfg.TRAIN.SOURCE_DOMAIN: %r'
                         % (cfg.TRAIN.SOURCE_DOMAIN,))

    source_loader = DataLoader(source_set, batch_size=cfg.TRAIN.IMG_BATCH_SIZE,
                               num_workers=16, shuffle=True, drop_last=True)

    target_set = CityScapes('train',
                            list_filename='cityscapes_' + cfg.DATA.SSD_GT + '.txt',
                            simul_transform=target_simul_transform,
                            transform=img_transform,
                            target_transform=target_transform)
    target_loader = DataLoader(target_set, batch_size=cfg.TRAIN.IMG_BATCH_SIZE,
                               num_workers=16, shuffle=True, drop_last=True)

    return source_loader, target_loader, restore_transform
Exemple #7
0
    def __init__(self,
                 model,
                 loss,
                 resume,
                 config,
                 train_loader,
                 val_loader=None,
                 train_logger=None,
                 prefetch=True):
        """Trainer class.

        __init__:
            1. transforms for visualisation;
            2. prefetching: wraps the loaders in ``DataPrefetcher`` unless
               running on CPU or ``prefetch`` is False.

        Args:
            model, loss, resume, config, train_loader, val_loader, train_logger:
                forwarded unchanged to the base class ``__init__``.
            prefetch: wrap loaders in ``DataPrefetcher``. Forced off when
                ``self.device`` is the CPU.
        """
        super(Trainer, self).__init__(model, loss, resume, config,
                                      train_loader, val_loader, train_logger)

        self.wrt_mode, self.wrt_step = 'train_', 0

        batch_size = self.train_loader.batch_size
        # ``log_per_iter`` is optional. Read it with .get() so that a missing
        # key falls back to the sqrt(batch_size) default instead of raising
        # KeyError, and a falsy value no longer leaves log_step at 0/None.
        log_per_iter = config['trainer'].get('log_per_iter')
        if log_per_iter:
            self.log_step = int(log_per_iter / batch_size) + 1
        else:
            self.log_step = int(np.sqrt(batch_size))

        self.num_classes = self.train_loader.dataset.num_classes

        # TRANSFORMS FOR VISUALIZATION: undo the loader's normalisation, then
        # render at a fixed 400x400 for logging.
        self.restore_transform = transforms.Compose([
            local_transforms.DeNormalize(self.train_loader.MEAN,
                                         self.train_loader.STD),
            transforms.ToPILImage()
        ])
        self.viz_transform = transforms.Compose(
            [transforms.Resize((400, 400)),
             transforms.ToTensor()])

        # Prefetching is disabled on CPU.
        if self.device == torch.device('cpu'):
            prefetch = False
        if prefetch:
            self.train_loader = DataPrefetcher(train_loader,
                                               device=self.device)
            # val_loader is optional (defaults to None). Only wrap a real loader.
            if val_loader is not None:
                self.val_loader = DataPrefetcher(val_loader, device=self.device)

        torch.backends.cudnn.benchmark = True
Exemple #8
0
def main():
    """Run a fixed PSPNet snapshot over LSUN val scenes and save PNG results.

    For every image in the 'tower_val', 'church_outdoor_val' and 'bridge_val'
    categories, writes ``<n>_img.png`` (de-normalised input) and
    ``<n>_out.png`` (colourised prediction) into ``test_results_path``.
    """
    batch_size = 8
    snapshot = 'epoch_48_validation_loss_5.1326_mean_iu_0.3172_lr_0.00001000.pth'

    net = PSPNet(pretrained=False, num_classes=num_classes, input_size=(512, 1024)).cuda()
    net.load_state_dict(torch.load(os.path.join(ckpt_path, snapshot)))
    net.eval()

    # ImageNet normalisation statistics.
    mean_std = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    transform = transforms.Compose([
        expanded_transform.FreeScale((512, 1024)),
        transforms.ToTensor(),
        transforms.Normalize(*mean_std),
    ])
    restore = transforms.Compose([
        expanded_transform.DeNormalize(*mean_std),
        transforms.ToPILImage(),
    ])

    lsun_path = '/home/b3-542/LSUN'
    categories = ['tower_val', 'church_outdoor_val', 'bridge_val']
    dataset = LSUN(lsun_path, categories, transform=transform)
    dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=16, shuffle=True)

    if not os.path.exists(test_results_path):
        os.mkdir(test_results_path)

    for batch_idx, (inputs, _labels) in enumerate(dataloader):
        inputs = Variable(inputs, volatile=True).cuda()
        outputs = net(inputs)

        # Arg-max over the class dimension -> per-pixel class ids.
        prediction = outputs.cpu().data.max(1)[1].squeeze_(1).numpy()

        pairs = zip(inputs.cpu().data, prediction)
        for offset, (image_tensor, mask) in enumerate(pairs):
            pil_input = restore(image_tensor)
            pil_output = colorize_mask(mask)
            file_index = batch_idx * batch_size + offset
            pil_input.save(os.path.join(test_results_path, '%d_img.png' % file_index))
            pil_output.save(os.path.join(test_results_path, '%d_out.png' % file_index))
            print('save the #%d batch, %d images' % (batch_idx + 1, offset + 1))
Exemple #9
0
def main(train_args):
    """Train the VGG segmentation network on PASCAL VOC.

    Args:
        train_args: dict of hyper-parameters and run state. This block reads
            'snapshot', 'lr', 'momentum', 'weight_decay' and 'epoch_num', and
            writes 'best_record'. When 'snapshot' is non-empty, the model and
            optimizer state are restored from ``ckpt_path/exp_name``. The
            epoch and best metrics are parsed from the '_'-separated snapshot
            name (fields 1, 3, 5, 7, 9, 11 — format assumed from the indices
            used below).
    """
    def weights_init(m):
        # Gaussian init for Conv* modules. The bias may be absent (bias=False),
        # so guard against None.
        classname = m.__class__.__name__
        if classname.find('Conv') != -1:
            torch.nn.init.normal(m.weight.data, mean=0, std=0.01)
            if m.bias is not None:
                torch.nn.init.constant(m.bias.data, 0)

    net = VGG(num_classes=VOC.num_classes)
    net.apply(weights_init)

    # Copy pretrained VGG-16 weights. The pickle holds a module, and its keys
    # map onto ``features.*`` in this net. Only matching keys are copied.
    net_dict = net.state_dict()
    pretrain = torch.load('./vgg16_20M.pkl')
    pretrain_dict = {
        'features.' + k: v
        for k, v in pretrain.state_dict().items() if 'features.' + k in net_dict
    }
    net_dict.update(pretrain_dict)
    net.load_state_dict(net_dict)

    net = nn.DataParallel(net)
    net = net.cuda()

    if len(train_args['snapshot']) == 0:
        curr_epoch = 1
        train_args['best_record'] = {
            'epoch': 0,
            'val_loss': 1e10,
            'acc': 0,
            'acc_cls': 0,
            'mean_iu': 0,
            'fwavacc': 0
        }
    else:
        print('training resumes from ' + train_args['snapshot'])
        net.load_state_dict(
            torch.load(
                os.path.join(ckpt_path, exp_name, train_args['snapshot'])))
        split_snapshot = train_args['snapshot'].split('_')
        curr_epoch = int(split_snapshot[1]) + 1
        train_args['best_record'] = {
            'epoch': int(split_snapshot[1]),
            'val_loss': float(split_snapshot[3]),
            'acc': float(split_snapshot[5]),
            'acc_cls': float(split_snapshot[7]),
            'mean_iu': float(split_snapshot[9]),
            'fwavacc': float(split_snapshot[11])
        }

    net.train()

    mean_std = ([0.408, 0.457, 0.481], [1, 1, 1])

    joint_transform_train = joint_transforms.Compose(
        [joint_transforms.RandomCrop((321, 321))])
    joint_transform_test = joint_transforms.Compose(
        [joint_transforms.RandomCrop((512, 512))])

    input_transform = standard_transforms.Compose([
        standard_transforms.ToTensor(),
        standard_transforms.Normalize(*mean_std)
    ])
    target_transform = standard_transforms.Compose([
        extended_transforms.MaskToTensor()
    ])
    restore_transform = standard_transforms.Compose([
        extended_transforms.DeNormalize(*mean_std),
        standard_transforms.ToPILImage(),
    ])
    visualize = standard_transforms.Compose([
        standard_transforms.Resize(400),
        standard_transforms.CenterCrop(400),
        standard_transforms.ToTensor()
    ])

    train_set = VOC.VOC('train',
                        joint_transform=joint_transform_train,
                        transform=input_transform,
                        target_transform=target_transform)
    train_loader = DataLoader(train_set,
                              batch_size=20,
                              num_workers=4,
                              shuffle=True)
    val_set = VOC.VOC('val',
                      joint_transform=joint_transform_test,
                      transform=input_transform,
                      target_transform=target_transform)
    val_loader = DataLoader(val_set,
                            batch_size=1,
                            num_workers=4,
                            shuffle=False)

    criterion = CrossEntropyLoss2d(size_average=False,
                                   ignore_index=VOC.ignore_label).cuda()

    base_lr = train_args['lr']
    momentum = train_args['momentum']
    weight_decay = train_args['weight_decay']

    def _is_voc(name):
        return name.endswith('voc.bias') or name.endswith('voc.weight')

    def _param_group(keep, lr, decay):
        # One SGD param group holding every parameter whose name passes `keep`.
        return {
            'params': [p for n, p in net.named_parameters() if keep(n)],
            'lr': lr,
            'momentum': momentum,
            'weight_decay': decay,
        }

    # Four *disjoint* groups:
    #   0. backbone biases: 2x lr, no weight decay
    #   1. backbone weights: 1x lr
    #   2. 'voc' biases: 20x lr, no weight decay
    #   3. 'voc' weights: 10x lr
    # Previously groups 2-3 were passed positionally as SGD's `lr`/`momentum`,
    # and the name predicates overlapped, so SGD could not be constructed.
    group_lrs = (2 * base_lr, base_lr, 20 * base_lr, 10 * base_lr)
    optimizer = optim.SGD([
        _param_group(lambda n: n.endswith('bias') and not _is_voc(n),
                     group_lrs[0], 0),
        _param_group(lambda n: not n.endswith('bias') and not _is_voc(n),
                     group_lrs[1], weight_decay),
        _param_group(lambda n: n.endswith('voc.bias'), group_lrs[2], 0),
        _param_group(lambda n: n.endswith('voc.weight'), group_lrs[3],
                     weight_decay),
    ])

    if len(train_args['snapshot']) > 0:
        optimizer.load_state_dict(
            torch.load(
                os.path.join(ckpt_path, exp_name,
                             'opt_' + train_args['snapshot'])))
        # Re-apply the configured learning rates to every group after
        # restoring optimizer state.
        for group, lr in zip(optimizer.param_groups, group_lrs):
            group['lr'] = lr

    check_mkdir(ckpt_path)
    check_mkdir(os.path.join(ckpt_path, exp_name))
    log_path = os.path.join(ckpt_path, exp_name,
                            str(datetime.datetime.now()) + '.txt')
    with open(log_path, 'w') as log_file:
        log_file.write(str(train_args) + '\n\n')

    # Decay every group's lr by 10x every 13 epochs.
    scheduler = StepLR(optimizer, step_size=13, gamma=0.1)
    for epoch in range(curr_epoch, train_args['epoch_num'] + 1):
        train(train_loader, net, criterion, optimizer, epoch, train_args)
        validate(val_loader, net, criterion, optimizer, epoch,
                 train_args, restore_transform, visualize)
        scheduler.step()
Exemple #10
0
def main():
    """Create the model and start the training.

    Trains an FCN8s (wrapped in DataParallel) on the source dataset named by
    ``args.data_dir``: GTA5 LMDB if the path contains '5', otherwise Cityscapes
    LMDB. After every epoch it computes mIoU on the target validation set via
    ``test_miou``. The weights are saved to
    ``args.snapshot_dir/final_fcn.pth`` whenever that mIoU improves.
    """
    args = get_arguments()

    w, h = map(int, args.input_size.split(','))
    w_target, h_target = map(int, args.input_size_target.split(','))

    # Create network
    student_net = FCN8s(args.num_classes, args.model_path_prefix)
    student_net = torch.nn.DataParallel(student_net)
    student_net = student_net.cuda()

    # Input preprocessing:
    #   1. channel flip;
    #   2. scale to [0, 255];
    #   3. per-channel mean subtraction with unit std.
    mean_std = ([103.939, 116.779, 123.68], [1.0, 1.0, 1.0])

    train_joint_transform = joint_transforms.Compose([
        joint_transforms.FreeScale((h, w)),
    ])
    input_transform = standard_transforms.Compose([
        extended_transforms.FlipChannels(),
        standard_transforms.ToTensor(),
        standard_transforms.Lambda(lambda x: x.mul_(255)),
        standard_transforms.Normalize(*mean_std),
    ])
    val_input_transform = standard_transforms.Compose([
        extended_transforms.FreeScale((h, w)),
        extended_transforms.FlipChannels(),
        standard_transforms.ToTensor(),
        standard_transforms.Lambda(lambda x: x.mul_(255)),
        standard_transforms.Normalize(*mean_std),
    ])
    target_transform = extended_transforms.MaskToTensor()

    if '5' in args.data_dir:
        src_dataset = GTA5DataSetLMDB(
            args.data_dir,
            args.data_list,
            joint_transform=train_joint_transform,
            transform=input_transform,
            target_transform=target_transform,
        )
    else:
        src_dataset = CityscapesDataSetLMDB(
            args.data_dir,
            args.data_list,
            joint_transform=train_joint_transform,
            transform=input_transform,
            target_transform=target_transform,
        )
    src_loader = data.DataLoader(src_dataset,
                                 batch_size=args.batch_size,
                                 shuffle=True,
                                 num_workers=args.num_workers,
                                 pin_memory=True)

    # No joint resize on validation. Only the input image is FreeScale'd by
    # val_input_transform.
    tgt_val_dataset = CityscapesDataSetLMDB(
        args.data_dir_target,
        args.data_list_target,
        transform=val_input_transform,
        target_transform=target_transform,
    )
    tgt_val_loader = data.DataLoader(
        tgt_val_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
        pin_memory=True,
    )

    optimizer = optim.SGD(student_net.parameters(),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)

    # Passed to test_miou. Presumably resizes predictions to the target
    # label resolution — confirm against test_miou.
    upsample = nn.Upsample(size=(h_target, w_target), mode='bilinear')

    n_class = args.num_classes

    # Summed over pixels (size_average=False); normalised by batch size below.
    src_criterion = torch.nn.CrossEntropyLoss(ignore_index=255,
                                              size_average=False)

    num_batches = len(src_loader)
    highest = 0

    for epoch in range(args.num_epoch):

        cls_loss_rec = AverageMeter()
        miu_rec = AverageMeter()
        data_time_rec = AverageMeter()
        batch_time_rec = AverageMeter()

        tem_time = time.time()
        for batch_index, src_data in enumerate(src_loader):
            student_net.train()
            optimizer.zero_grad()

            # Train with source data: (images, labels, image names).
            src_images, src_label, _ = src_data
            src_images = src_images.cuda()
            src_label = src_label.cuda()
            data_time_rec.update(time.time() - tem_time)

            src_output = student_net(src_images)

            # Segmentation loss, averaged per image.
            cls_loss_value = src_criterion(src_output, src_label)
            cls_loss_value /= src_images.shape[0]

            cls_loss_value.backward()
            optimizer.step()

            _, predict_labels = torch.max(src_output, 1)
            lbl_pred = predict_labels.detach().cpu().numpy()
            lbl_true = src_label.detach().cpu().numpy()
            # Use the configured class count instead of a hard-coded 19.
            _, _, _, mean_iu, _ = _evaluate(lbl_pred, lbl_true, n_class)

            cls_loss_rec.update(cls_loss_value.detach_().item())
            miu_rec.update(mean_iu)

            batch_time_rec.update(time.time() - tem_time)
            tem_time = time.time()

            if (batch_index + 1) % args.print_freq == 0:
                print(
                    f'Epoch [{epoch+1:d}/{args.num_epoch:d}][{batch_index+1:d}/{num_batches:d}]\t'
                    f'Time: {batch_time_rec.avg:.2f}   '
                    f'Data: {data_time_rec.avg:.2f}   '
                    f'Mean iu: {miu_rec.avg*100:.1f}   '
                    f'CLS: {cls_loss_rec.avg:.2f}')

        miu = test_miou(student_net, tgt_val_loader, upsample,
                        './dataset/info.json')
        if miu > highest:
            # Save the unwrapped module's weights (without the DataParallel prefix).
            torch.save(student_net.module.state_dict(),
                       osp.join(args.snapshot_dir, 'final_fcn.pth'))
            highest = miu
            print('>' * 50 + f'save highest with {miu:.2%}')
Exemple #11
0
def main(train_args):
    """Train FCN8s on PASCAL VOC with Adam and ReduceLROnPlateau.

    Args:
        train_args: dict of hyper-parameters and run state. This block reads
            'snapshot', 'lr', 'momentum' (used as Adam beta1), 'weight_decay',
            'lr_patience' and 'epoch_num', and writes 'best_record'. When
            'snapshot' is non-empty, the model and optimizer state are restored
            from ``ckpt_path/exp_name``. The epoch and best metrics are parsed
            from the '_'-separated snapshot name (fields 1, 3, 5, 7, 9, 11 —
            format assumed from the indices used below).
    """
    net = FCN8s(num_classes=voc.num_classes).cuda()

    if len(train_args['snapshot']) == 0:
        curr_epoch = 1
        train_args['best_record'] = {
            'epoch': 0,
            'val_loss': 1e10,
            'acc': 0,
            'acc_cls': 0,
            'mean_iu': 0,
            'fwavacc': 0
        }
    else:
        # Single-argument print(...) behaves the same on Python 2 and 3.
        print('training resumes from ' + train_args['snapshot'])
        net.load_state_dict(
            torch.load(
                os.path.join(ckpt_path, exp_name, train_args['snapshot'])))
        split_snapshot = train_args['snapshot'].split('_')
        curr_epoch = int(split_snapshot[1]) + 1
        train_args['best_record'] = {
            'epoch': int(split_snapshot[1]),
            'val_loss': float(split_snapshot[3]),
            'acc': float(split_snapshot[5]),
            'acc_cls': float(split_snapshot[7]),
            'mean_iu': float(split_snapshot[9]),
            'fwavacc': float(split_snapshot[11])
        }

    net.train()

    # ImageNet normalisation statistics.
    mean_std = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

    input_transform = standard_transforms.Compose([
        standard_transforms.ToTensor(),
        standard_transforms.Normalize(*mean_std)
    ])
    target_transform = extended_transforms.MaskToTensor()
    restore_transform = standard_transforms.Compose([
        extended_transforms.DeNormalize(*mean_std),
        standard_transforms.ToPILImage(),
    ])
    visualize = standard_transforms.Compose([
        standard_transforms.Scale(400),
        standard_transforms.CenterCrop(400),
        standard_transforms.ToTensor()
    ])

    train_set = voc.VOC('train',
                        transform=input_transform,
                        target_transform=target_transform)
    train_loader = DataLoader(train_set,
                              batch_size=1,
                              num_workers=4,
                              shuffle=True)
    val_set = voc.VOC('val',
                      transform=input_transform,
                      target_transform=target_transform)
    val_loader = DataLoader(val_set,
                            batch_size=1,
                            num_workers=4,
                            shuffle=False)

    criterion = CrossEntropyLoss2d(size_average=False,
                                   ignore_index=voc.ignore_label).cuda()

    # Group 0: biases at 2x lr with no weight decay.
    # Group 1: everything else at 1x lr with weight decay.
    optimizer = optim.Adam([{
        'params': [
            param
            for name, param in net.named_parameters() if name[-4:] == 'bias'
        ],
        'lr':
        2 * train_args['lr']
    }, {
        'params': [
            param
            for name, param in net.named_parameters() if name[-4:] != 'bias'
        ],
        'lr':
        train_args['lr'],
        'weight_decay':
        train_args['weight_decay']
    }],
                           betas=(train_args['momentum'], 0.999))

    if len(train_args['snapshot']) > 0:
        optimizer.load_state_dict(
            torch.load(
                os.path.join(ckpt_path, exp_name,
                             'opt_' + train_args['snapshot'])))
        # Re-apply the configured learning rates after restoring state.
        optimizer.param_groups[0]['lr'] = 2 * train_args['lr']
        optimizer.param_groups[1]['lr'] = train_args['lr']

    check_mkdir(ckpt_path)
    check_mkdir(os.path.join(ckpt_path, exp_name))
    # Record the run's arguments. The context manager closes the handle
    # deterministically instead of leaking it until garbage collection.
    log_path = os.path.join(ckpt_path, exp_name,
                            str(datetime.datetime.now()) + '.txt')
    with open(log_path, 'w') as log_file:
        log_file.write(str(train_args) + '\n\n')

    scheduler = ReduceLROnPlateau(optimizer,
                                  'min',
                                  patience=train_args['lr_patience'],
                                  min_lr=1e-10,
                                  verbose=True)
    for epoch in range(curr_epoch, train_args['epoch_num'] + 1):
        train(train_loader, net, criterion, optimizer, epoch, train_args)
        val_loss = validate(val_loader, net, criterion, optimizer, epoch,
                            train_args, restore_transform, visualize)
        scheduler.step(val_loss)
Exemple #12
0
def main():
    """Create the model and start the training.

    Builds a DeepLabV3 network, trains it on the source dataset with
    ``CriterionDSN`` and SGD, and every ``args.eval_freq`` iterations measures
    mIoU on the target validation loader via ``test_miou``, saving the weights
    whenever the score improves. Returns after ``args.iterations`` iterations.
    """
    args = get_arguments()

    w, h = map(int, args.input_size.split(','))

    w_target, h_target = map(int, args.input_size_target.split(','))

    # Single source of truth for the class count: the model head and the
    # per-batch training mIoU below must agree.
    n_class = args.num_classes

    # Create network
    if args.bn_sync:
        print('Using Sync BN')
        deeplabv3.BatchNorm2d = partial(InPlaceABNSync, activation='none')
    net = get_deeplabV3(n_class, args.model_path_prefix)
    if not args.bn_sync:
        net.freeze_bn()
    net = torch.nn.DataParallel(net)

    net = net.cuda()

    # Per-channel mean/std, applied after FlipChannels and rescaling to 0-255.
    mean_std = ([104.00698793, 116.66876762, 122.67891434], [1.0, 1.0, 1.0])

    train_joint_transform = joint_transforms.Compose([
        joint_transforms.FreeScale((h, w)),
    ])
    input_transform = standard_transforms.Compose([
        extended_transforms.FlipChannels(),
        standard_transforms.ToTensor(),
        # ToTensor yields [0, 1]; the mean above is expressed in 0-255 units.
        standard_transforms.Lambda(lambda x: x.mul_(255)),
        standard_transforms.Normalize(*mean_std),
    ])
    val_input_transform = standard_transforms.Compose([
        extended_transforms.FreeScale((h, w)),
        extended_transforms.FlipChannels(),
        standard_transforms.ToTensor(),
        standard_transforms.Lambda(lambda x: x.mul_(255)),
        standard_transforms.Normalize(*mean_std),
    ])
    target_transform = extended_transforms.MaskToTensor()

    # NOTE(review): the source dataset type is picked by a substring test on
    # the path ('5' as in GTA5); confirm this matches the paths in use.
    if '5' in args.data_dir:
        src_dataset_cls = GTA5DataSetLMDB
    else:
        src_dataset_cls = CityscapesDataSetLMDB
    src_dataset = src_dataset_cls(
        args.data_dir,
        args.data_list,
        joint_transform=train_joint_transform,
        transform=input_transform,
        target_transform=target_transform,
    )
    src_loader = data.DataLoader(src_dataset,
                                 batch_size=args.batch_size,
                                 shuffle=True,
                                 num_workers=args.num_workers,
                                 pin_memory=True)

    # Target validation set: no joint (image+label) resize is applied here.
    tgt_val_dataset = CityscapesDataSetLMDB(
        args.data_dir_target,
        args.data_list_target,
        transform=val_input_transform,
        target_transform=target_transform,
    )
    tgt_val_loader = data.DataLoader(
        tgt_val_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
        pin_memory=True,
    )

    # Freeze BatchNorm parameters so the optimizer never updates them.
    # NOTE(review): with --bn_sync the layers are built as InPlaceABNSync,
    # which this isinstance check may not match - confirm intended.
    for module in net.module.modules():
        if isinstance(module, torch.nn.BatchNorm2d):
            for param in module.parameters():
                param.requires_grad = False
    optimizer = optim.SGD(
        [{
            'params': [p for p in net.module.parameters() if p.requires_grad],
            'lr': args.learning_rate
        }],
        lr=args.learning_rate,
        momentum=args.momentum,
        weight_decay=args.weight_decay)

    # Resizes network output to the target resolution inside test_miou.
    upsample = nn.Upsample(size=(h_target, w_target), mode='bilinear')

    criterion = CriterionDSN(ignore_index=255)

    max_iter = args.iterations
    i_iter = 0
    highest_miu = 0

    # Keep cycling over the source loader until max_iter iterations have run.
    while True:

        cls_loss_rec = AverageMeter()
        miu_rec = AverageMeter()
        data_time_rec = AverageMeter()
        batch_time_rec = AverageMeter()

        tem_time = time.time()
        for src_images, src_label, _ in src_loader:
            i_iter += 1
            lr = adjust_learning_rate(args, optimizer, i_iter, max_iter)
            net.train()
            optimizer.zero_grad()

            # train with source
            src_images = src_images.cuda()
            src_label = src_label.cuda()
            data_time_rec.update(time.time() - tem_time)

            src_output = net(src_images)

            # Segmentation loss (CriterionDSN receives the full network output).
            cls_loss_value = criterion(src_output, src_label)
            cls_loss_value.backward()
            optimizer.step()

            # Training-batch mIoU on the first output head; this is pure
            # bookkeeping, so no autograd graph is built for it.
            with torch.no_grad():
                main_output = torch.nn.functional.interpolate(
                    input=src_output[0],
                    size=(h, w),
                    mode='bilinear',
                    align_corners=True)
                _, predict_labels = torch.max(main_output, 1)
            lbl_pred = predict_labels.detach().cpu().numpy()
            lbl_true = src_label.detach().cpu().numpy()
            _, _, _, mean_iu, _ = _evaluate(lbl_pred, lbl_true, n_class)

            cls_loss_rec.update(cls_loss_value.item())
            miu_rec.update(mean_iu)

            batch_time_rec.update(time.time() - tem_time)
            tem_time = time.time()

            if i_iter % args.print_freq == 0:
                print(
                    f'Iter: [{i_iter}/{max_iter}]\t'
                    f'Time: {batch_time_rec.avg:.2f}   '
                    f'Data: {data_time_rec.avg:.2f}   '
                    f'Mean iu: {miu_rec.avg*100:.1f}   '
                    f'CLS: {cls_loss_rec.avg:.2f}')
            if i_iter % args.eval_freq == 0:
                miu = test_miou(net, tgt_val_loader, upsample,
                                './dataset/info.json')
                # Keep only snapshots that improve the best target mIoU.
                if miu > highest_miu:
                    torch.save(
                        net.module.state_dict(),
                        osp.join(args.snapshot_dir,
                                 f'{i_iter:d}_{miu*1000:.0f}.pth'))
                    highest_miu = miu
                print(f'>>>>>>>>>Learning Rate {lr}<<<<<<<<<')
            if i_iter == max_iter:
                return
Exemple #13
0
def main():
    """Predict the Cityscapes test split with a trained ATONet and save label maps.

    Each prediction (train ids 0-18) is remapped to Cityscapes label ids via
    ``trainid_to_id`` and saved with PIL as an 8-bit image, named after the
    image name returned by the dataset, under
    ``<model_save_path>/<exp_file>``. Results are written batch by batch, so
    host memory does not grow with the size of the split.
    """
    # epoch = 100
    # info = "ATONet_final3_loss3_5_BN_batch=4_use_ohem=0_bins=8_4_2epoch=100"
    # snapshot = "epoch_98_loss_0.12540_acc_0.95847_acc-cls_0.78683_mean-iu_0.70210_fwavacc_0.92424_lr_0.0000453781.pth"
    # epoch=200
    info = "ATONet_final3_loss3_5_BN_batch=4_use_ohem=False_bins=8_4_2epoch=200"
    snapshot = "epoch_193_loss_0.11953_acc_0.96058_acc-cls_0.79683_mean-iu_0.71272_fwavacc_0.92781_lr_0.0000798490.pth"

    model_save_path = './save_models/cityscapes/{}'.format(info)
    print(model_save_path)

    net = ATONet(classes=19, bins=(8, 4, 2), use_ohem=False).cuda()
    net.load_state_dict(torch.load(os.path.join(model_save_path, snapshot)))
    net.eval()

    mean_std = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    input_transform = standard_transforms.Compose([
        standard_transforms.ToTensor(),
        standard_transforms.Normalize(*mean_std)
    ])
    target_transform = extended_transforms.MaskToTensor()

    root = '/titan_data1/caokuntao/data/cityscapes'

    test_set = cityscapes.CityScapes(root,
                                     'fine',
                                     'test',
                                     transform=input_transform,
                                     target_transform=target_transform)
    test_loader = DataLoader(test_set,
                             batch_size=args['test_batch_size'],
                             num_workers=4,
                             shuffle=False)

    trainid_to_id = {
        0: 7,
        1: 8,
        2: 11,
        3: 12,
        4: 13,
        5: 17,
        6: 19,
        7: 20,
        8: 21,
        9: 22,
        10: 23,
        11: 24,
        12: 25,
        13: 26,
        14: 27,
        15: 28,
        16: 31,
        17: 32,
        18: 33
    }
    # Lookup table: index = train id, value = label id. Values without an
    # entry map to themselves, as with the former per-class masked copy.
    id_lut = np.arange(256, dtype=np.uint8)
    for train_id, label_id in trainid_to_id.items():
        id_lut[train_id] = label_id

    # NOTE(review): exp_file is not defined in this function; presumably a
    # module-level setting - confirm.
    to_save_dir = os.path.join(model_save_path, exp_file)
    if not os.path.exists(to_save_dir):
        os.makedirs(to_save_dir)

    with torch.no_grad():
        for inputs, img_names in test_loader:
            outputs = net(inputs.cuda())
            # Arg-max over dim 1; assumes outputs is (N, C, H, W) - confirm.
            predictions = outputs.max(1)[1].cpu().numpy()
            for img_name, pred in zip(img_names, predictions):
                if img_name is None:
                    continue
                Image.fromarray(id_lut[pred]).save(
                    os.path.join(to_save_dir, img_name))
    print('done')
def main(train_args):
    """Train FCN-8s on the Plant dataset with Adam and plateau LR decay.

    Args:
        train_args: hyper-parameter dict. Reads 'snapshot', 'lr',
            'weight_decay', 'momentum', 'lr_patience' and 'epoch_num', and
            sets 'best_record' (fresh, or parsed from the snapshot file name
            when resuming).
    """
    os.environ['CUDA_VISIBLE_DEVICES'] = '0'
    net = FCN8s(num_classes=plant.num_classes).cuda()

    if len(train_args['snapshot']) == 0:
        curr_epoch = 1
        train_args['best_record'] = {
            'epoch': 0,
            'val_loss': 1e10,
            'acc': 0,
            'acc_cls': 0,
            'mean_iu': 0,
            'fwavacc': 0
        }
    else:
        print('training resumes from ' + train_args['snapshot'])
        net.load_state_dict(
            torch.load(
                os.path.join(ckpt_path, exp_name, train_args['snapshot'])))
        # Snapshot names alternate key/value fields separated by '_'.
        split_snapshot = train_args['snapshot'].split('_')
        curr_epoch = int(split_snapshot[1]) + 1
        train_args['best_record'] = {
            'epoch': int(split_snapshot[1]),
            'val_loss': float(split_snapshot[3]),
            'acc': float(split_snapshot[5]),
            'acc_cls': float(split_snapshot[7]),
            'mean_iu': float(split_snapshot[9]),
            'fwavacc': float(split_snapshot[11])
        }

    net.train()

    # Dataset-specific channel statistics.
    mean_std = ([0.385, 0.431, 0.452], [0.289, 0.294, 0.285])

    input_transform = standard_transforms.Compose([
        standard_transforms.ToTensor(),
        standard_transforms.Normalize(*mean_std)
    ])
    target_transform = extended_transforms.MaskToTensor()
    restore_transform = standard_transforms.Compose([
        extended_transforms.DeNormalize(*mean_std),
        standard_transforms.ToPILImage(),
    ])
    visualize = standard_transforms.Compose([
        standard_transforms.Scale(500),
        standard_transforms.CenterCrop(500),
        standard_transforms.ToTensor()
    ])

    data_aug = Compose([RandomRotate(10), RandomHorizontallyFlip()])
    train_set = plant.Plant('train',
                            augmentations=data_aug,
                            transform=input_transform,
                            target_transform=target_transform)
    train_loader = DataLoader(train_set,
                              batch_size=4,
                              num_workers=2,
                              shuffle=True)
    val_set = plant.Plant('val',
                          transform=input_transform,
                          target_transform=target_transform)
    val_loader = DataLoader(val_set,
                            batch_size=1,
                            num_workers=2,
                            shuffle=False)

    weights = torch.FloatTensor(cfg.train_weights)
    criterion = CrossEntropyLoss2d(weight=weights,
                                   size_average=False,
                                   ignore_index=plant.ignore_label).cuda()

    # Biases get twice the base LR and no weight decay.
    optimizer = optim.Adam([{
        'params': [
            param
            for name, param in net.named_parameters() if name[-4:] == 'bias'
        ],
        'lr':
        2 * train_args['lr']
    }, {
        'params': [
            param
            for name, param in net.named_parameters() if name[-4:] != 'bias'
        ],
        'lr':
        train_args['lr'],
        'weight_decay':
        train_args['weight_decay']
    }],
                           betas=(train_args['momentum'], 0.999))

    if len(train_args['snapshot']) > 0:
        optimizer.load_state_dict(
            torch.load(
                os.path.join(ckpt_path, exp_name,
                             'opt_' + train_args['snapshot'])))
        # Override the restored LRs with the ones configured for this run.
        optimizer.param_groups[0]['lr'] = 2 * train_args['lr']
        optimizer.param_groups[1]['lr'] = train_args['lr']

    check_mkdir(ckpt_path)
    check_mkdir(os.path.join(ckpt_path, exp_name))
    log_path = os.path.join(ckpt_path, exp_name,
                            str(datetime.datetime.now()) + '.txt')
    with open(log_path, 'w') as log_file:
        log_file.write(str(train_args) + '\n\n')

    scheduler = ReduceLROnPlateau(optimizer,
                                  'min',
                                  patience=train_args['lr_patience'],
                                  min_lr=1e-10,
                                  verbose=True)
    for epoch in range(curr_epoch, train_args['epoch_num'] + 1):
        # Train first, then validate the weights this epoch produced, so the
        # scheduler (and any checkpointing done in validate) sees current
        # results and the final epoch is validated too.
        train(train_loader, net, criterion, optimizer, epoch, train_args)
        val_loss = validate(val_loader, net, criterion, optimizer, epoch,
                            train_args, restore_transform, visualize)
        scheduler.step(val_loss)
Exemple #15
0
def main(train_args):
    """Train PSPNet on PASCAL VOC with SGD and an epoch-wise LR schedule.

    Args:
        train_args: hyper-parameter dict. Reads 'snapshot', 'lr', 'momentum',
            'weight_decay' and 'epoch_num'; writes 'best_record' (fresh, or
            parsed from the snapshot file name when resuming).
    """
    net = PSPNet(num_classes=voc.num_classes).cuda()

    snapshot = train_args['snapshot']
    if len(snapshot) != 0:
        print('training resumes from ' + snapshot)
        net.load_state_dict(
            torch.load(os.path.join(ckpt_path, exp_name, snapshot)))
        fields = snapshot.split('_')
        curr_epoch = int(fields[1]) + 1
        # Metric values sit at positions 3, 5, 7, 9 and 11 of the name.
        best_record = {'epoch': int(fields[1])}
        metric_names = ('val_loss', 'acc', 'acc_cls', 'mean_iu', 'fwavacc')
        for position, metric in zip(range(3, 12, 2), metric_names):
            best_record[metric] = float(fields[position])
        train_args['best_record'] = best_record
    else:
        curr_epoch = 1
        train_args['best_record'] = {
            'epoch': 0,
            'val_loss': 1e10,
            'acc': 0,
            'acc_cls': 0,
            'mean_iu': 0,
            'fwavacc': 0,
        }

    # ImageNet channel statistics, shared by normalization and its inverse.
    mean_std = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

    # Training: joint centre crop + flip, then per-image normalization.
    joint_transform = joint_transforms.Compose([
        joint_transforms.CenterCrop(224),
        joint_transforms.RandomHorizontallyFlip(),
    ])
    input_transform = standard_transforms.Compose([
        ToTensor(),
        Normalize(*mean_std),
    ])
    target_transform = standard_transforms.Compose([
        extended_transforms.MaskToTensor(),
    ])

    # Validation: image and mask are centre-cropped independently.
    val_input_transform = standard_transforms.Compose([
        CenterCrop(224),
        ToTensor(),
        Normalize(*mean_std),
    ])
    val_target_transform = standard_transforms.Compose([
        CenterCrop(224),
        extended_transforms.MaskToTensor(),
    ])

    # Visualization helpers handed to validate().
    restore_transform = standard_transforms.Compose([
        extended_transforms.DeNormalize(*mean_std),
        standard_transforms.ToPILImage(),
    ])
    visualize = standard_transforms.Compose([
        standard_transforms.Scale(400),
        standard_transforms.CenterCrop(400),
        standard_transforms.ToTensor()
    ])

    loader_opts = {'batch_size': 4, 'num_workers': 4}
    train_set = voc.VOC('train',
                        transform=input_transform,
                        target_transform=target_transform,
                        joint_transform=joint_transform)
    train_loader = DataLoader(train_set, shuffle=True, **loader_opts)
    val_set = voc.VOC('val',
                      transform=val_input_transform,
                      target_transform=val_target_transform)
    val_loader = DataLoader(val_set, shuffle=False, **loader_opts)

    criterion = torch.nn.CrossEntropyLoss(ignore_index=voc.ignore_label).cuda()
    optimizer = optim.SGD(net.parameters(),
                          lr=train_args['lr'],
                          momentum=train_args['momentum'],
                          weight_decay=train_args['weight_decay'])

    for directory in (ckpt_path, os.path.join(ckpt_path, exp_name)):
        check_mkdir(directory)

    for epoch in range(curr_epoch, train_args['epoch_num'] + 1):
        train(train_loader, net, criterion, optimizer, epoch, train_args)
        validate(val_loader, net, criterion, optimizer, epoch, train_args,
                 restore_transform, visualize)
        # LR is adjusted once per epoch, after validation.
        adjust_learning_rate(optimizer, epoch, net, train_args)
Exemple #16
0
def main():
    """Train FCN-8s (ResNet backbone) on Cityscapes with Nesterov SGD.

    Hyper-parameters come from the module-level ``train_args`` / ``val_args``
    dicts. When ``train_args['snapshot']`` is set, model and optimizer states
    are restored and ``train_record`` is refreshed from the snapshot name.
    """
    net = FCN8ResNet(num_classes=num_classes).cuda()

    snapshot = train_args['snapshot']
    resuming = len(snapshot) > 0
    if resuming:
        print('training resumes from ' + snapshot)
        net.load_state_dict(
            torch.load(os.path.join(ckpt_path, exp_name, snapshot)))
        fields = snapshot.split('_')
        curr_epoch = int(fields[1])
        train_record['best_val_loss'] = float(fields[3])
        train_record['corr_mean_iu'] = float(fields[6])
        train_record['corr_epoch'] = curr_epoch
    else:
        curr_epoch = 0

    net.train()

    mean_std = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    input_size = train_args['input_size']
    # Both pipelines first scale to input_size[0] / 0.875, then crop.
    pre_crop_size = int(input_size[0] / 0.875)
    train_simul_transform = simul_transforms.Compose([
        simul_transforms.Scale(pre_crop_size),
        simul_transforms.RandomCrop(input_size),
        simul_transforms.RandomHorizontallyFlip()
    ])
    val_simul_transform = simul_transforms.Compose([
        simul_transforms.Scale(pre_crop_size),
        simul_transforms.CenterCrop(input_size)
    ])
    img_transform = standard_transforms.Compose([
        standard_transforms.ToTensor(),
        standard_transforms.Normalize(*mean_std)
    ])
    # The ignored label is folded into the last class, whose loss weight is 0.
    target_transform = standard_transforms.Compose([
        expanded_transforms.MaskToTensor(),
        expanded_transforms.ChangeLabel(ignored_label, num_classes - 1)
    ])
    restore_transform = standard_transforms.Compose([
        expanded_transforms.DeNormalize(*mean_std),
        standard_transforms.ToPILImage()
    ])

    train_set = CityScapes('train',
                           simul_transform=train_simul_transform,
                           transform=img_transform,
                           target_transform=target_transform)
    train_loader = DataLoader(train_set,
                              batch_size=train_args['batch_size'],
                              num_workers=16,
                              shuffle=True)
    val_set = CityScapes('val',
                         simul_transform=val_simul_transform,
                         transform=img_transform,
                         target_transform=target_transform)
    val_loader = DataLoader(val_set,
                            batch_size=val_args['batch_size'],
                            num_workers=16,
                            shuffle=False)

    weight = torch.ones(num_classes)
    weight[num_classes - 1] = 0
    criterion = CrossEntropyLoss2d(weight).cuda()

    def _select_params(is_bias, is_new):
        # is_new: parameters of the freshly added 'fconv' layers.
        return [
            param for name, param in net.named_parameters()
            if (name[-4:] == 'bias') == is_bias and ('fconv' in name) == is_new
        ]

    new_lr = train_args['new_lr']
    pretrained_lr = train_args['pretrained_lr']
    weight_decay = train_args['weight_decay']
    # Biases: doubled LR, no weight decay. Group order matters for resume.
    group_lrs = (2 * new_lr, new_lr, 2 * pretrained_lr, pretrained_lr)
    optimizer = optim.SGD([
        {'params': _select_params(True, True), 'lr': group_lrs[0]},
        {'params': _select_params(False, True), 'lr': group_lrs[1],
         'weight_decay': weight_decay},
        {'params': _select_params(True, False), 'lr': group_lrs[2]},
        {'params': _select_params(False, False), 'lr': group_lrs[3],
         'weight_decay': weight_decay},
    ], momentum=0.9, nesterov=True)

    if resuming:
        optimizer.load_state_dict(
            torch.load(os.path.join(ckpt_path, exp_name, 'opt_' + snapshot)))
        # Override the restored LRs with the ones configured for this run.
        for group, lr in zip(optimizer.param_groups, group_lrs):
            group['lr'] = lr

    for directory in (ckpt_path, os.path.join(ckpt_path, exp_name)):
        if not os.path.exists(directory):
            os.mkdir(directory)

    for epoch in range(curr_epoch, train_args['epoch_num']):
        train(train_loader, net, criterion, optimizer, epoch)
        validate(val_loader, net, criterion, optimizer, epoch,
                 restore_transform)
# Output sub-directories 'pred', 'gt' and 'mask' under exp_name.
for _sub_dir in ('pred', 'gt', 'mask'):
    _out_dir = os.path.join(exp_name, _sub_dir)
    if not os.path.isdir(_out_dir):
        # makedirs (unlike mkdir) also creates exp_name itself if missing.
        os.makedirs(_out_dir)

# Dataset-specific per-channel mean / std.
mean_std = ([0.452016860247, 0.447249650955, 0.431981861591],
            [0.23242045939, 0.224925786257, 0.221840232611])
img_transform = standard_transforms.Compose([
    standard_transforms.ToTensor(),
    standard_transforms.Normalize(*mean_std)
])
# Inverse of img_transform: normalized tensor back to a PIL image.
restore = standard_transforms.Compose([
    own_transforms.DeNormalize(*mean_std),
    standard_transforms.ToPILImage()
])
pil_to_tensor = standard_transforms.ToTensor()

# Command line: argv[1] is joined under MODEL_DIR, argv[2] under OUTPUT_DIR/full.
dataRoot = os.path.join(OUTPUT_DIR, 'full', sys.argv[2])

VESICLE_PATH = os.path.join(MODEL_DIR, sys.argv[1])
model_path = VESICLE_PATH

def main():
    """Run ``test`` on the files directly inside ``dataRoot`` with ``model_path``.

    Only the top level of ``dataRoot`` is listed; sub-directories are not
    descended into.

    Raises:
        IOError: if ``dataRoot`` does not exist or cannot be listed.
    """
    # os.walk yields (dirpath, dirnames, filenames) tuples; the first one
    # describes dataRoot itself, so there is no need to walk the subtree.
    walker = os.walk(dataRoot)
    try:
        _, _, file_names = next(walker)
    except StopIteration:
        # os.walk yields nothing (rather than raising) for a missing directory.
        raise IOError('data directory not found: %s' % dataRoot)
    test(file_names, model_path)
   
Exemple #18
0
def main():
    """Train PSPNet on the Plant dataset with sliding-crop training.

    Reads hyper-parameters from the module-level ``args`` dict ('snapshot',
    'longer_size', 'crop_size', 'stride_rate', 'shorter_size',
    'val_img_display_size', 'train_batch_size', 'lr', 'weight_decay',
    'momentum', 'max_epoch'). It sets ``args['best_record']`` and
    ``args['max_iter']``, then hands control to ``train``.
    """
    os.environ['CUDA_VISIBLE_DEVICES'] = '0'
    net = PSPNet(num_classes=plant.num_classes).cuda()

    if len(args['snapshot']) == 0:
        curr_epoch = 1
        args['best_record'] = {
            'epoch': 0,
            'val_loss': 1e10,
            'acc': 0,
            'acc_cls': 0,
            'mean_iu': 0,
            'fwavacc': 0
        }
    else:
        print('training resumes from ' + args['snapshot'])
        net.load_state_dict(
            torch.load(os.path.join(ckpt_path, exp_name, args['snapshot'])))
        # Snapshot names alternate key/value fields separated by '_'.
        split_snapshot = args['snapshot'].split('_')
        curr_epoch = int(split_snapshot[1]) + 1
        args['best_record'] = {
            'epoch': int(split_snapshot[1]),
            'val_loss': float(split_snapshot[3]),
            'acc': float(split_snapshot[5]),
            'acc_cls': float(split_snapshot[7]),
            'mean_iu': float(split_snapshot[9]),
            'fwavacc': float(split_snapshot[11])
        }
    net.train()

    # Dataset-specific channel statistics.
    mean_std = ([0.385, 0.431, 0.452], [0.289, 0.294, 0.285])
    train_joint_transform = joint_transforms.Compose([
        joint_transforms.Scale(args['longer_size']),
        joint_transforms.RandomRotate(10),
        joint_transforms.RandomHorizontallyFlip()
    ])
    sliding_crop = joint_transforms.SlidingCrop(args['crop_size'],
                                                args['stride_rate'],
                                                plant.ignore_label)
    train_input_transform = standard_transforms.Compose([
        standard_transforms.ToTensor(),
        standard_transforms.Normalize(*mean_std)
    ])
    # NOTE(review): val_joint_transform is built but never passed to val_set,
    # so validation images are not rescaled - confirm this is intended.
    val_joint_transform = joint_transforms.Compose(
        [joint_transforms.Scale(args['shorter_size'])])
    val_input_transform = standard_transforms.Compose([
        standard_transforms.ToTensor(),
        standard_transforms.Normalize(*mean_std)
    ])
    target_transform = extended_transforms.MaskToTensor()
    restore_transform = standard_transforms.Compose([
        extended_transforms.DeNormalize(*mean_std),
        standard_transforms.ToPILImage(),
    ])
    # NOTE(review): the crop is fixed at 400 while the resize uses
    # args['val_img_display_size'] - confirm the two are meant to differ.
    visualize = standard_transforms.Compose([
        standard_transforms.Scale(args['val_img_display_size']),
        standard_transforms.CenterCrop(400),
        standard_transforms.ToTensor()
    ])

    train_set = plant.Plant('train',
                            joint_transform=train_joint_transform,
                            sliding_crop=sliding_crop,
                            transform=train_input_transform,
                            target_transform=target_transform)
    train_loader = DataLoader(train_set,
                              batch_size=args['train_batch_size'],
                              num_workers=2,
                              shuffle=True)
    val_set = plant.Plant('val',
                          transform=val_input_transform,
                          target_transform=target_transform)
    val_loader = DataLoader(val_set,
                            batch_size=1,
                            num_workers=2,
                            shuffle=False)

    weights = torch.FloatTensor(cfg.train_weights)
    criterion = CrossEntropyLoss2d(weight=weights,
                                   size_average=True,
                                   ignore_index=plant.ignore_label).cuda()

    # Biases get twice the base LR and no weight decay.
    optimizer = optim.SGD([{
        'params': [
            param
            for name, param in net.named_parameters() if name[-4:] == 'bias'
        ],
        'lr':
        2 * args['lr']
    }, {
        'params': [
            param
            for name, param in net.named_parameters() if name[-4:] != 'bias'
        ],
        'lr':
        args['lr'],
        'weight_decay':
        args['weight_decay']
    }],
                          momentum=args['momentum'],
                          nesterov=True)

    if len(args['snapshot']) > 0:
        optimizer.load_state_dict(
            torch.load(
                os.path.join(ckpt_path, exp_name, 'opt_' + args['snapshot'])))
        # Override the restored LRs with the ones configured for this run.
        optimizer.param_groups[0]['lr'] = 2 * args['lr']
        optimizer.param_groups[1]['lr'] = args['lr']

    check_mkdir(ckpt_path)
    check_mkdir(os.path.join(ckpt_path, exp_name))
    log_path = os.path.join(ckpt_path, exp_name,
                            str(datetime.datetime.now()) + '.txt')
    with open(log_path, 'w') as log_file:
        log_file.write(str(args) + '\n\n')

    # Iteration budget handed to train(): epochs x batches per epoch.
    args['max_iter'] = args['max_epoch'] * len(train_loader)
    train(train_loader, net, criterion, optimizer, curr_epoch, args,
          restore_transform, val_loader, visualize)
def main():
    """Train FCN-8s (VGG backbone) on Cityscapes with Adam and plateau LR decay.

    Hyper-parameters come from the module-level ``args`` dict.
    ``args['best_record']`` is filled in either fresh or parsed from the
    snapshot file name when resuming.
    """
    torch.backends.cudnn.benchmark = True
    os.environ["CUDA_VISIBLE_DEVICES"] = '0,1'
    if torch.cuda.is_available():
        device = torch.device('cuda:0')
    else:
        device = torch.device("cpu")

    # Seed every RNG (Python, NumPy, torch CPU, all CUDA devices) identically.
    seed = 63
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed,
                    torch.cuda.manual_seed_all):
        seed_fn(seed)

    mean_std = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    # Scale target used by the validation pipeline before centre cropping.
    short_size = int(min(args['input_size']) / 0.875)

    train_joint_transform = joint_transforms.Compose([
        joint_transforms.RandomCrop(args['crop_size']),
        joint_transforms.RandomHorizontallyFlip(),
        joint_transforms.RandomRotate(90)
    ])
    val_joint_transform = joint_transforms.Compose([
        joint_transforms.Scale(short_size),
        joint_transforms.CenterCrop(args['input_size'])
    ])
    input_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(*mean_std),
    ])
    target_transform = extended_transforms.MaskToTensor()
    restore_transform = transforms.Compose([
        extended_transforms.DeNormalize(*mean_std),
        transforms.ToPILImage(),
    ])
    visualize = transforms.ToTensor()

    loader_opts = {'num_workers': 8, 'shuffle': True}
    train_set = cityscapes.CityScapes('train',
                                      joint_transform=train_joint_transform,
                                      transform=input_transform,
                                      target_transform=target_transform)
    train_loader = DataLoader(train_set,
                              batch_size=args['train_batch_size'],
                              **loader_opts)
    val_set = cityscapes.CityScapes('val',
                                    joint_transform=val_joint_transform,
                                    transform=input_transform,
                                    target_transform=target_transform)
    val_loader = DataLoader(val_set,
                            batch_size=args['val_batch_size'],
                            **loader_opts)

    print(len(train_loader), len(val_loader))

    vgg_model = VGGNet(requires_grad=True, remove_fc=True)
    net = FCN8s(pretrained_net=vgg_model,
                n_class=cityscapes.num_classes,
                dropout_rate=0.4)
    criterion = nn.CrossEntropyLoss(ignore_index=cityscapes.ignore_label)

    optimizer = optim.Adam(net.parameters(), lr=1e-4)

    run_dir = os.path.join(ckpt_path, exp_name)
    check_mkdir(ckpt_path)
    check_mkdir(run_dir)
    run_log = os.path.join(run_dir, str(datetime.datetime.now()) + '.txt')
    open(run_log, 'w').write(str(args) + '\n\n')

    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, 'min', patience=args['lr_patience'], min_lr=1e-10)

    vgg_model = vgg_model.to(device)
    net = net.to(device)

    if torch.cuda.device_count() > 1:
        net = nn.DataParallel(net)

    snapshot = args['snapshot']
    if len(snapshot) != 0:
        print('training resumes from ' + snapshot)
        net.load_state_dict(
            torch.load(os.path.join(ckpt_path, exp_name, snapshot)))
        fields = snapshot.split('_')
        curr_epoch = int(fields[1]) + 1
        args['best_record'] = {
            'epoch': int(fields[1]),
            'val_loss': float(fields[3]),
            'acc': float(fields[5]),
            'acc_cls': float(fields[7]),
            # Last field: drop the trailing 4 characters (presumably '.pth').
            'mean_iu': float(fields[9][:-4]),
        }
    else:
        curr_epoch = 1
        args['best_record'] = {
            'epoch': 0,
            'val_loss': 1e10,
            'acc': 0,
            'acc_cls': 0,
            'mean_iu': 0,
        }

    criterion.to(device)

    for epoch in range(curr_epoch, args['epoch_num'] + 1):
        train(train_loader, net, device, criterion, optimizer, epoch, args)
        val_loss = validate(val_loader, net, device, criterion, optimizer,
                            epoch, args, restore_transform, visualize)
        scheduler.step(val_loss)
Exemple #20
0
def main():
    """Train FCN32VGG on Mapillary, optionally resuming from a snapshot.

    Configuration comes from the module-level ``args`` dict; checkpoints and
    the run log live under ``ckpt_path/exp_name``. When ``args['snapshot']``
    is non-empty, model weights are restored from
    ``ckpt_path/exp_name/<snapshot>`` and optimizer state from
    ``opt_<snapshot>``. The best record is parsed back out of the snapshot
    filename: it is split on ``'_'`` and fields 1, 3, 5, 7, 9, 11 are read as
    epoch, val_loss, acc, acc_cls, mean_iu and fwavacc.
    """
    net = FCN32VGG(num_classes=mapillary.num_classes).cuda()

    if len(args['snapshot']) == 0:
        curr_epoch = 1
        args['best_record'] = {
            'epoch': 0,
            'val_loss': 1e10,
            'acc': 0,
            'acc_cls': 0,
            'mean_iu': 0,
            'fwavacc': 0
        }
    else:
        print('training resumes from ' + args['snapshot'])
        net.load_state_dict(
            torch.load(os.path.join(ckpt_path, exp_name, args['snapshot'])))
        split_snapshot = args['snapshot'].split('_')
        curr_epoch = int(split_snapshot[1]) + 1
        args['best_record'] = {
            'epoch': int(split_snapshot[1]),
            'val_loss': float(split_snapshot[3]),
            'acc': float(split_snapshot[5]),
            'acc_cls': float(split_snapshot[7]),
            'mean_iu': float(split_snapshot[9]),
            'fwavacc': float(split_snapshot[11])
        }
    net.train()

    mean_std = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    # Scale to short_size (input_size / 0.875) before cropping to input_size.
    short_size = int(min(args['input_size']) / 0.875)
    train_joint_transform = joint_transforms.Compose([
        joint_transforms.Scale(short_size),
        joint_transforms.RandomCrop(args['input_size']),
        joint_transforms.RandomHorizontallyFlip()
    ])
    val_joint_transform = joint_transforms.Compose([
        joint_transforms.Scale(short_size),
        joint_transforms.CenterCrop(args['input_size'])
    ])
    input_transform = standard_transforms.Compose([
        standard_transforms.ToTensor(),
        standard_transforms.Normalize(*mean_std)
    ])
    target_transform = extended_transforms.MaskToTensor()
    restore_transform = standard_transforms.Compose([
        extended_transforms.DeNormalize(*mean_std),
        standard_transforms.ToPILImage()
    ])
    visualize = standard_transforms.ToTensor()

    train_set = mapillary.Mapillary('semantic',
                                    'training',
                                    joint_transform=train_joint_transform,
                                    transform=input_transform,
                                    target_transform=target_transform)
    train_loader = DataLoader(train_set,
                              batch_size=args['train_batch_size'],
                              num_workers=8,
                              shuffle=True,
                              pin_memory=True)
    val_set = mapillary.Mapillary('semantic',
                                  'validation',
                                  joint_transform=val_joint_transform,
                                  transform=input_transform,
                                  target_transform=target_transform)
    val_loader = DataLoader(val_set,
                            batch_size=args['val_batch_size'],
                            num_workers=8,
                            shuffle=False,
                            pin_memory=True)

    criterion = CrossEntropyLoss2d(size_average=False).cuda()

    # Biases get twice the base learning rate and no weight decay (SGD's
    # default weight_decay is 0); all other parameters use the base learning
    # rate plus weight decay.
    bias_params = [
        param for name, param in net.named_parameters() if name[-4:] == 'bias'
    ]
    other_params = [
        param for name, param in net.named_parameters() if name[-4:] != 'bias'
    ]
    optimizer = optim.SGD([{
        'params': bias_params,
        'lr': 2 * args['lr']
    }, {
        'params': other_params,
        'lr': args['lr'],
        'weight_decay': args['weight_decay']
    }],
                          momentum=args['momentum'])

    if len(args['snapshot']) > 0:
        optimizer.load_state_dict(
            torch.load(
                os.path.join(ckpt_path, exp_name, 'opt_' + args['snapshot'])))
        # Re-apply the configured learning rates over the checkpointed ones.
        optimizer.param_groups[0]['lr'] = 2 * args['lr']
        optimizer.param_groups[1]['lr'] = args['lr']

    check_mkdir(ckpt_path)
    check_mkdir(os.path.join(ckpt_path, exp_name))
    # Record the run configuration. ':' is replaced so the timestamp is also a
    # valid filename on Windows; the file is closed deterministically.
    log_path = os.path.join(
        ckpt_path, exp_name,
        str(datetime.datetime.now()).replace(':', '-') + '.txt')
    with open(log_path, 'w') as log_file:
        log_file.write(str(args) + '\n\n')

    scheduler = ReduceLROnPlateau(optimizer,
                                  'min',
                                  patience=args['lr_patience'],
                                  min_lr=1e-10)
    for epoch in range(curr_epoch, args['epoch_num'] + 1):
        train(train_loader, net, criterion, optimizer, epoch, args)
        val_loss = validate(val_loader, net, criterion, optimizer, epoch, args,
                            restore_transform, visualize)
        scheduler.step(val_loss)

    # NOTE(review): PATH is not defined in this snippet; confirm it exists at
    # module level, otherwise this final save raises NameError.
    torch.save(net.state_dict(), PATH)
def main():
    """Run a trained FCN8s on the Cityscapes 'test' split and report metrics.

    Loads ``args['snapshot']`` from ``ckpt_path/args['exp_name']``. For every
    test image it saves the de-normalized input and the colorized prediction
    under ``ckpt_path/args['exp_name']/test``. It then scores all predictions
    against the ground truths with ``evaluate`` and prints acc, acc_cls and
    mean IU.
    """
    torch.backends.cudnn.benchmark = True
    os.environ["CUDA_VISIBLE_DEVICES"] = '0,1'
    device = torch.device('cuda:0' if torch.cuda.is_available() else "cpu")

    vgg_model = VGGNet(requires_grad=True, remove_fc=True)
    net = FCN8s(pretrained_net=vgg_model,
                n_class=cityscapes.num_classes,
                dropout_rate=0.4)
    print('load model ' + args['snapshot'])

    vgg_model = vgg_model.to(device)
    net = net.to(device)

    # Wrap before loading so that parameter names match a checkpoint saved
    # from a DataParallel model (presumably how it was trained; confirm).
    if torch.cuda.device_count() > 1:
        net = nn.DataParallel(net)
    net.load_state_dict(
        torch.load(os.path.join(ckpt_path, args['exp_name'],
                                args['snapshot'])))
    net.eval()

    mean_std = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

    short_size = int(min(args['input_size']) / 0.875)
    val_joint_transform = joint_transforms.Compose([
        joint_transforms.Scale(short_size),
        joint_transforms.CenterCrop(args['input_size'])
    ])
    test_transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize(*mean_std)])
    target_transform = extended_transforms.MaskToTensor()
    restore_transform = transforms.Compose(
        [extended_transforms.DeNormalize(*mean_std),
         transforms.ToPILImage()])

    test_set = cityscapes.CityScapes('test',
                                     joint_transform=val_joint_transform,
                                     transform=test_transform,
                                     target_transform=target_transform)
    test_loader = DataLoader(test_set,
                             batch_size=1,
                             num_workers=8,
                             shuffle=False)

    test_dir = os.path.join(ckpt_path, args['exp_name'], 'test')
    check_mkdir(test_dir)

    gts_all, predictions_all = [], []
    # Pure inference, so no autograd graph needs to be recorded.
    with torch.no_grad():
        for vi, (img_name, img, gts) in enumerate(test_loader):
            # batch_size=1: take the single file name and drop its directory.
            img_name = img_name[0].split('/')[-1]

            # clone() keeps restore_transform from altering the batch that is
            # fed to the network, in case DeNormalize works in place.
            restore_transform(img[0].clone()).save(
                os.path.join(test_dir, img_name))
            img_name = img_name.split('_leftImg8bit.png')[0]

            output = net(img.to(device))

            # Arg-max class per pixel; the batch axis is squeezed to (H, W).
            prediction = output.max(1)[1].squeeze_(1).squeeze_(
                0).cpu().numpy()
            cityscapes.colorize_mask(prediction).save(
                os.path.join(test_dir, img_name + '.png'))

            print('%d / %d' % (vi + 1, len(test_loader)))
            gts_all.append(gts.cpu().numpy())
            # Re-add the batch axis so predictions stack image-by-image like
            # the collated gts (presumably (1, H, W) each).
            predictions_all.append(prediction[np.newaxis])

    gts_all = np.concatenate(gts_all)
    predictions_all = np.concatenate(predictions_all)
    acc, acc_cls, mean_iu, _ = evaluate(predictions_all, gts_all,
                                        cityscapes.num_classes)

    print(
        '-----------------------------------------------------------------------------------------------------------'
    )
    print('[acc %.5f], [acc_cls %.5f], [mean_iu %.5f]' %
          (acc, acc_cls, mean_iu))
def eval_net(i_epoch, i_iter, i_tb, writer, ext_model):
    """Compute the mean IU of ``ext_model`` on the preprocessed validation set.

    Every file in ``<processed_val_path>/img`` is paired with the
    ``...gtFine_labelIds.png`` mask of the same prefix in
    ``<processed_val_path>/mask``. Both are resized to ``cfg.VAL.IMG_SIZE``
    and the model is run on each image. The mean IU over all images is
    written to ``writer`` as ``meanIU_tgt_val`` and printed together with the
    per-class IU. The model is switched to eval mode for the pass and back to
    train mode before returning. ``i_epoch`` and ``i_iter`` are not used.
    """
    ext_model.eval()
    processed_val_img_path = os.path.join(processed_val_path, 'img')
    processed_val_mask_path = os.path.join(processed_val_path, 'mask')
    valSet = []
    for file_name in os.listdir(processed_val_img_path):
        prefix = file_name.split('leftImg8bit.png')[0]
        valSet.append(
            (processed_val_img_path + '/' + prefix + 'leftImg8bit.png',
             processed_val_mask_path + '/' + prefix + 'gtFine_labelIds.png'))

    # The transforms do not depend on the image, so build them once.
    img_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.485, 0.456, 0.406),
                             std=(0.229, 0.224, 0.225))
    ])
    label_transform = standard_transforms.Compose([
        expanded_transforms.MaskToTensor(),
        expanded_transforms.ChangeLabel(ignored_label,
                                        cfg.DATA.NUM_CLASSES - 1)
    ])

    print('=' * 50)
    print('Start validating...')

    all_pred_list = []
    all_labels_list = []
    _t = {'iter': Timer()}
    for i_img, (img_path, mask_path) in enumerate(valSet):
        img = Image.open(img_path)
        img = img.resize(cfg.VAL.IMG_SIZE, Image.NEAREST)
        img = img_transform(img)

        labels = Image.open(mask_path)
        # Nearest-neighbour keeps label ids intact. It is passed explicitly
        # because newer Pillow releases default to an interpolating filter.
        labels = labels.resize(cfg.VAL.IMG_SIZE, Image.NEAREST)
        labels = label_transform(labels)
        labels = labels[None, :, :]
        # volatile=True is the pre-0.4 PyTorch way to disable autograd.
        img = Variable(img[None, :, :, :], volatile=True).cuda()

        _t['iter'].tic()
        pred_val_outputs = ext_model(img, train_flag=False)
        _t['iter'].toc(average=True)

        if i_img % 50 == 0:
            print('i_img: {:d}, net_forward: {:.3f}s'.format(
                i_img, _t['iter'].average_time))

        pred_map = pred_val_outputs.data.cpu().max(1)[1].squeeze_(1).numpy()

        # Keep numpy arrays directly; converting every full-resolution map to
        # nested Python lists and back was needlessly slow.
        all_pred_list.append(pred_map)
        all_labels_list.append(labels.numpy())

    all_pred_np = np.array(all_pred_list)
    all_labels_np = np.array(all_labels_list)

    # Merge the (image, batch) leading axes into one image axis.
    all_pred = all_pred_np.reshape(
        (-1, all_pred_np.shape[2], all_pred_np.shape[3]))
    all_labels = all_labels_np.reshape(
        (-1, all_labels_np.shape[2], all_labels_np.shape[3]))

    tgt_m_iu, tgt_class_iu = calculate_mean_iu_test(all_pred, all_labels)
    batch_size = cfg.TRAIN.IMG_BATCH_SIZE
    writer.add_scalar('meanIU_tgt_val', tgt_m_iu,
                      i_tb * batch_size * cfg.TRAIN.PRINT_FREQ)
    print(tgt_m_iu)
    print(tgt_class_iu)
    ext_model.train()
    print('=' * 50)
    print('COntinue training...')
# Example #23
# 0
def main(network,
         train_batch_size=4,
         val_batch_size=4,
         epoch_num=50,
         lr=2e-2,
         weight_decay=1e-4,
         momentum=0.9,
         factor=10,
         val_scale=None,
         model_save_path='./save_models/cityscapes',
         data_type='Cityscapes',
         snapshot='',
         accumulation_steps=1,
         dataset_path=None):
    """Train ``network`` on Cityscapes (fine) or CamVid, validating every epoch.

    Args:
        network: the segmentation model to train.
        train_batch_size: training batch size; also the number of loader
            workers.
        val_batch_size: validation batch size; also the number of loader
            workers.
        epoch_num: last epoch (inclusive); also the cosine-annealing period.
        lr, weight_decay, momentum: SGD hyper-parameters.
        factor: parameters whose name contains 'layer' and either 'conv' or
            'downsample.0' (presumably the backbone) use ``lr / factor`` and
            ``weight_decay / factor``.
        val_scale: forwarded to the validation dataset.
        model_save_path: directory for snapshots; created if missing.
        data_type: 'Cityscapes' selects Cityscapes, anything else CamVid.
        snapshot: file in ``model_save_path`` to resume weights from. Its
            '_'-separated fields 1, 3, 5, 7, 9, 11 encode the best record.
        accumulation_steps: forwarded to ``train``.
        dataset_path: dataset root. Defaults to the original hard-coded
            location for the chosen dataset.
    """
    mean_std = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    input_transform = standard_transforms.Compose([
        standard_transforms.ToTensor(),
        standard_transforms.Normalize(*mean_std)
    ])
    target_transform = extended_transforms.MaskToTensor()
    restore_transform = standard_transforms.Compose([
        extended_transforms.DeNormalize(*mean_std),
        standard_transforms.ToPILImage()
    ])

    # Loading dataset
    if data_type == 'Cityscapes':
        if dataset_path is None:
            dataset_path = '/titan_data1/caokuntao/data/cityscapes'
        train_set = cityscapes.CityScapes(dataset_path,
                                          'fine',
                                          'train',
                                          transform=input_transform,
                                          target_transform=target_transform)
        val_set = cityscapes.CityScapes(dataset_path,
                                        'fine',
                                        'val',
                                        val_scale=val_scale,
                                        transform=input_transform,
                                        target_transform=target_transform)
    else:
        if dataset_path is None:
            dataset_path = '/home/caokuntao/data/camvid'
        train_set = camvid.Camvid(dataset_path,
                                  'train',
                                  transform=input_transform,
                                  target_transform=target_transform)
        # CamVid's 'test' split is used for validation.
        val_set = camvid.Camvid(dataset_path,
                                'test',
                                val_scale=val_scale,
                                transform=input_transform,
                                target_transform=target_transform)

    train_loader = DataLoader(train_set,
                              batch_size=train_batch_size,
                              num_workers=train_batch_size,
                              shuffle=True)
    val_loader = DataLoader(val_set,
                            batch_size=val_batch_size,
                            num_workers=val_batch_size,
                            shuffle=False)

    if len(snapshot) == 0:
        curr_epoch = 1
        args['best_record'] = {
            'epoch': 0,
            'val_loss': 1e10,
            'acc': 0,
            'acc_cls': 0,
            'mean_iu': 0,
            'fwavacc': 0
        }
    else:
        logger.info('training resumes from ' + snapshot)
        network.load_state_dict(
            torch.load(os.path.join(model_save_path, snapshot)))
        split_snapshot = snapshot.split('_')
        curr_epoch = int(split_snapshot[1]) + 1
        args['best_record'] = {
            'epoch': int(split_snapshot[1]),
            'val_loss': float(split_snapshot[3]),
            'acc': float(split_snapshot[5]),
            'acc_cls': float(split_snapshot[7]),
            'mean_iu': float(split_snapshot[9]),
            'fwavacc': float(split_snapshot[11])
        }

    # NOTE(review): the Cityscapes ignore label is used for CamVid too;
    # confirm both datasets share it.
    criterion = torch.nn.CrossEntropyLoss(ignore_index=cityscapes.ignore_label)

    # One param group per tensor: the selected (presumably backbone) layers get
    # lr / factor and weight_decay / factor, everything else the full values.
    paras_new = []
    for name, param in dict(network.named_parameters()).items():
        if 'layer' in name and ('conv' in name or 'downsample.0' in name):
            paras_new.append({
                'params': [param],
                'lr': lr / factor,
                'weight_decay': weight_decay / factor
            })
        else:
            paras_new.append({
                'params': [param],
                'lr': lr,
                'weight_decay': weight_decay
            })

    optimizer = torch.optim.SGD(paras_new, momentum=momentum)
    # NOTE: on resume neither the optimizer state nor the cosine schedule is
    # restored; the schedule starts from its beginning regardless of
    # curr_epoch.
    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer,
                                                              epoch_num,
                                                              eta_min=1e-6)

    check_makedirs(model_save_path)

    all_iter = epoch_num * len(train_loader)

    for epoch in range(curr_epoch, epoch_num + 1):
        train(train_loader, network, optimizer, epoch, all_iter,
              accumulation_steps)
        validate(val_loader, network, criterion, optimizer, epoch,
                 restore_transform, model_save_path)
        lr_scheduler.step()
def main(train_args):
    """Train Decoder34 on ``wp.Wp`` together with a discriminator ``D``.

    Both networks and their optimizers are handed to ``train``. When
    ``train_args['snapshot']`` is set, the segmentation net and
    ``optimizer_AE`` are restored from ``ckpt_path/exp_name``. The learning
    rate of ``optimizer_AE`` is reduced when the validation loss plateaus.
    """
    backbone = ResNet()
    # strict=False: checkpoint keys that do not match the backbone are ignored.
    backbone.load_state_dict(torch.load(
        './weight/resnet34-333f7ec4.pth'), strict=False)
    net = Decoder34(num_classes=13, backbone=backbone).cuda()
    D = discriminator(input_channels=16).cuda()
    if len(train_args['snapshot']) == 0:
        curr_epoch = 1
        train_args['best_record'] = {
            'epoch': 0, 'val_loss': 1e10, 'acc': 0, 'acc_cls': 0, 'mean_iu': 0, 'fwavacc': 0}
    else:
        print('training resumes from ' + train_args['snapshot'])
        net.load_state_dict(torch.load(os.path.join(
            ckpt_path, exp_name, train_args['snapshot'])))
        # Best record is encoded in the '_'-separated snapshot filename.
        split_snapshot = train_args['snapshot'].split('_')
        curr_epoch = int(split_snapshot[1]) + 1
        train_args['best_record'] = {'epoch': int(split_snapshot[1]), 'val_loss': float(split_snapshot[3]),
                                     'acc': float(split_snapshot[5]), 'acc_cls': float(split_snapshot[7]),
                                     'mean_iu': float(split_snapshot[9]), 'fwavacc': float(split_snapshot[11])}

    net.train()
    D.train()

    mean_std = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    input_transform = standard_transforms.Compose([
        standard_transforms.ToTensor(),
        standard_transforms.Normalize(*mean_std)
    ])
    target_transform = extended_transforms.MaskToTensor()
    restore_transform = standard_transforms.Compose([
        extended_transforms.DeNormalize(*mean_std),
        standard_transforms.ToPILImage(),
    ])
    # Resize is the current name for the deprecated Scale alias (same
    # behaviour); Scale is no longer available in recent torchvision.
    visualize = standard_transforms.Compose([
        standard_transforms.Resize(400),
        standard_transforms.CenterCrop(400),
        standard_transforms.ToTensor()
    ])

    train_set = wp.Wp('train', transform=input_transform,
                      target_transform=target_transform)
    train_loader = DataLoader(train_set, batch_size=4,
                              num_workers=4, shuffle=True)
    # NOTE: validation runs on the *training* set because the dedicated split
    # is disabled:
    # val_set = wp.Wp('val', transform=input_transform,
    #                 target_transform=target_transform)
    # Original author's note (translated): "So val can't be used here at all?
    # Why not use a val dataset here?" As a result, the val_loss that drives
    # the scheduler is measured on training data.
    val_loader = DataLoader(train_set, batch_size=1,
                            num_workers=4, shuffle=False)
    criterion = DiceLoss().cuda()
    criterion_D = nn.BCELoss().cuda()
    # Biases: 2x lr, no weight decay. Other params: lr plus weight decay.
    # Adam's beta1 is taken from train_args['momentum'].
    optimizer_AE = optim.Adam([
        {'params': [param for name, param in net.named_parameters() if name[-4:] == 'bias'],
         'lr': 2 * train_args['lr']},
        {'params': [param for name, param in net.named_parameters() if name[-4:] != 'bias'],
         'lr': train_args['lr'], 'weight_decay': train_args['weight_decay']}
    ], betas=(train_args['momentum'], 0.999))
    optimizer_D = optim.Adam([
        {'params': [param for name, param in D.named_parameters() if name[-4:] == 'bias'],
         'lr': 2 * train_args['lr']},
        {'params': [param for name, param in D.named_parameters() if name[-4:] != 'bias'],
         'lr': train_args['lr'], 'weight_decay': train_args['weight_decay']}
    ], betas=(train_args['momentum'], 0.999))

    # NOTE(review): only the segmentation net and optimizer_AE are restored on
    # resume; D and optimizer_D start from scratch.
    if len(train_args['snapshot']) > 0:
        optimizer_AE.load_state_dict(torch.load(os.path.join(
            ckpt_path, exp_name, 'opt_' + train_args['snapshot'])))
        optimizer_AE.param_groups[0]['lr'] = 2 * train_args['lr']
        optimizer_AE.param_groups[1]['lr'] = train_args['lr']

    check_mkdir(ckpt_path)
    check_mkdir(os.path.join(ckpt_path, exp_name))
    # Record the run configuration; `with` closes the file deterministically.
    log_path = os.path.join(ckpt_path, exp_name,
                            str(datetime.datetime.now()) + '.txt')
    with open(log_path, 'w') as log_file:
        log_file.write(str(train_args) + '\n\n')

    scheduler = ReduceLROnPlateau(
        optimizer_AE, 'min', patience=train_args['lr_patience'], min_lr=1e-10, verbose=True)
    for epoch in range(curr_epoch, train_args['epoch_num'] + 1):
        train(train_loader, net, D, criterion, criterion_D, optimizer_AE,
              optimizer_D, epoch, train_args)
        val_loss = validate(val_loader, net, criterion, optimizer_AE,
                            epoch, train_args, restore_transform, visualize)
        scheduler.step(val_loss)
def main(train_args):
    """Fine-tune PSPNet on coarse Cityscapes annotations.

    Without a snapshot, the net is initialised from
    ``ckpt_path/'cityscapes (coarse-extra)-psp_net'/xx.pth``. Otherwise the
    weights come from ``ckpt_path/exp_name/<snapshot>`` and the optimizer
    state from ``opt_<snapshot>``. The epoch loop, including validation, is
    delegated to ``train``.
    """
    net = PSPNet(num_classes=cityscapes.num_classes).cuda()

    if len(train_args['snapshot']) == 0:
        # NOTE(review): 'xx.pth' looks like a placeholder filename; confirm.
        net.load_state_dict(torch.load(os.path.join(ckpt_path, 'cityscapes (coarse-extra)-psp_net', 'xx.pth')))
        curr_epoch = 1
        train_args['best_record'] = {'epoch': 0, 'val_loss': 1e10, 'acc': 0, 'acc_cls': 0, 'mean_iu': 0, 'fwavacc': 0}
    else:
        print('training resumes from ' + train_args['snapshot'])
        net.load_state_dict(torch.load(os.path.join(ckpt_path, exp_name, train_args['snapshot'])))
        # Best record is encoded in the '_'-separated snapshot filename.
        split_snapshot = train_args['snapshot'].split('_')
        curr_epoch = int(split_snapshot[1]) + 1
        train_args['best_record'] = {'epoch': int(split_snapshot[1]), 'val_loss': float(split_snapshot[3]),
                                     'acc': float(split_snapshot[5]), 'acc_cls': float(split_snapshot[7]),
                                     'mean_iu': float(split_snapshot[9]), 'fwavacc': float(split_snapshot[11])}

    net.train()

    mean_std = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

    train_simul_transform = simul_transforms.Compose([
        simul_transforms.RandomSized(train_args['input_size']),
        simul_transforms.RandomRotate(10),
        simul_transforms.RandomHorizontallyFlip()
    ])
    val_simul_transform = simul_transforms.Scale(train_args['input_size'])
    train_input_transform = standard_transforms.Compose([
        standard_transforms.ToTensor(),
        standard_transforms.Normalize(*mean_std)
    ])
    val_input_transform = standard_transforms.Compose([
        standard_transforms.ToTensor(),
        standard_transforms.Normalize(*mean_std)
    ])
    target_transform = extended_transforms.MaskToTensor()
    restore_transform = standard_transforms.Compose([
        extended_transforms.DeNormalize(*mean_std),
        standard_transforms.ToPILImage(),
    ])
    visualize = standard_transforms.ToTensor()

    train_set = cityscapes.CityScapes('coarse', 'train', simul_transform=train_simul_transform,
                                      transform=train_input_transform,
                                      target_transform=target_transform)
    train_loader = DataLoader(train_set, batch_size=train_args['train_batch_size'], num_workers=8, shuffle=True)
    val_set = cityscapes.CityScapes('coarse', 'val', simul_transform=val_simul_transform, transform=val_input_transform,
                                    target_transform=target_transform)
    val_loader = DataLoader(val_set, batch_size=train_args['val_batch_size'], num_workers=8, shuffle=False)

    criterion = CrossEntropyLoss2d(size_average=True, ignore_index=cityscapes.ignore_label).cuda()

    # Biases: 2x lr, no weight decay. Other params: lr plus weight decay.
    optimizer = optim.SGD([
        {'params': [param for name, param in net.named_parameters() if name[-4:] == 'bias'],
         'lr': 2 * train_args['lr']},
        {'params': [param for name, param in net.named_parameters() if name[-4:] != 'bias'],
         'lr': train_args['lr'], 'weight_decay': train_args['weight_decay']}
    ], momentum=train_args['momentum'])

    if len(train_args['snapshot']) > 0:
        optimizer.load_state_dict(torch.load(os.path.join(ckpt_path, exp_name, 'opt_' + train_args['snapshot'])))
        # Re-apply the configured learning rates over the checkpointed ones.
        optimizer.param_groups[0]['lr'] = 2 * train_args['lr']
        optimizer.param_groups[1]['lr'] = train_args['lr']

    check_mkdir(ckpt_path)
    check_mkdir(os.path.join(ckpt_path, exp_name))
    # Record the run configuration; `with` closes the file deterministically.
    log_path = os.path.join(ckpt_path, exp_name, str(datetime.datetime.now()) + '.txt')
    with open(log_path, 'w') as log_file:
        log_file.write(str(train_args) + '\n\n')

    train(train_loader, net, criterion, optimizer, curr_epoch, train_args, val_loader, restore_transform, visualize)
# Example #26
# 0
def main():
    """Train AFENet on Cityscapes (fine), validating after every epoch.

    Configuration comes from the module-level ``args`` dict. Parameters of the
    pretrained layers (``layer0``..``layer4``) use ``args['lr']``; those of the
    newly added heads use ``10 * args['lr']``. When ``args['snapshot']`` is
    set, model and optimizer state are restored from
    ``args['model_save_path']``.
    """
    net = AFENet(classes=19,
                 pretrained_model_path=args['pretrained_model_path']).cuda()
    net_ori = [net.layer0, net.layer1, net.layer2, net.layer3, net.layer4]
    net_new = [
        net.ppm, net.cls, net.aux, net.ppm_reduce, net.aff1, net.aff2, net.aff3
    ]

    mean_std = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    input_transform = standard_transforms.Compose([
        standard_transforms.ToTensor(),
        standard_transforms.Normalize(*mean_std)
    ])
    target_transform = extended_transforms.MaskToTensor()
    restore_transform = standard_transforms.Compose([
        extended_transforms.DeNormalize(*mean_std),
        standard_transforms.ToPILImage()
    ])

    dataset_path = args['dataset_path']

    # Loading dataset
    train_set = cityscapes.CityScapes(dataset_path,
                                      'fine',
                                      'train',
                                      transform=input_transform,
                                      target_transform=target_transform)
    train_loader = DataLoader(train_set,
                              batch_size=args['train_batch_size'],
                              num_workers=2,
                              shuffle=True)
    val_set = cityscapes.CityScapes(dataset_path,
                                    'fine',
                                    'val',
                                    transform=input_transform,
                                    target_transform=target_transform)
    val_loader = DataLoader(val_set,
                            batch_size=args['val_batch_size'],
                            num_workers=2,
                            shuffle=False)

    if len(args['snapshot']) == 0:
        curr_epoch = 1
        args['best_record'] = {
            'epoch': 0,
            'val_loss': 1e10,
            'acc': 0,
            'acc_cls': 0,
            'mean_iu': 0,
            'fwavacc': 0
        }
    else:
        print('training resumes from ' + args['snapshot'])
        net.load_state_dict(
            torch.load(os.path.join(args['model_save_path'],
                                    args['snapshot'])))
        # Best record is encoded in the '_'-separated snapshot filename.
        split_snapshot = args['snapshot'].split('_')
        curr_epoch = int(split_snapshot[1]) + 1
        args['best_record'] = {
            'epoch': int(split_snapshot[1]),
            'val_loss': float(split_snapshot[3]),
            'acc': float(split_snapshot[5]),
            'acc_cls': float(split_snapshot[7]),
            'mean_iu': float(split_snapshot[9]),
            'fwavacc': float(split_snapshot[11])
        }

    # Pretrained layers first (base lr), then new heads (10x lr).
    params_list = [
        dict(params=module.parameters(), lr=args['lr']) for module in net_ori
    ]
    params_list += [
        dict(params=module.parameters(), lr=args['lr'] * 10)
        for module in net_new
    ]
    # Number of leading param groups that belong to the pretrained layers
    # (presumably read by train() to tell the two LR groups apart; confirm).
    # Derived from net_ori so it cannot drift out of sync with it.
    args['index_split'] = len(net_ori)

    criterion = torch.nn.CrossEntropyLoss(ignore_index=cityscapes.ignore_label)

    optimizer = torch.optim.SGD(params_list,
                                lr=args['lr'],
                                momentum=args['momentum'],
                                weight_decay=args['weight_decay'])
    if len(args['snapshot']) > 0:
        optimizer.load_state_dict(
            torch.load(
                os.path.join(args['model_save_path'],
                             'opt_' + args['snapshot'])))

    check_makedirs(args['model_save_path'])

    all_iter = args['epoch_num'] * len(train_loader)

    for epoch in range(curr_epoch, args['epoch_num'] + 1):
        train(train_loader, net, optimizer, epoch, all_iter)
        validate(val_loader, net, criterion, optimizer, epoch,
                 restore_transform)
def _save_plant_prediction(rgbI, prediction, save_dir, img_name, cmp_suffix):
    """Write the side-by-side comparison and the palettized label map for one image.

    ``<img_name><cmp_suffix>`` shows ``rgbI`` on the left and the colorized
    ``prediction`` on the right. ``<img_name>.png`` is a palette ('P') image
    the size of ``rgbI`` with ``prediction`` pasted at (0, 0). Both files are
    written to ``save_dir``.
    """
    height, width = rgbI.shape[0], rgbI.shape[1]

    save_img = Image.new('RGB', (2 * width, height))
    save_img.paste(Image.fromarray(rgbI), (0, 0))
    save_img.paste(plant.colorize_mask(prediction), (width, 0))
    save_img.save(os.path.join(save_dir, img_name + cmp_suffix))

    save_labelI = Image.new('P', (width, height))
    save_labelI.putpalette(plant.palette)
    save_labelI.paste(Image.fromarray(prediction), (0, 0))
    save_labelI.save(os.path.join(save_dir, img_name + '.png'))


def main(test_args):
    """Predict plant segmentation masks for the ``plant`` test split.

    Loads ``test_args['snapshot']`` from ``ckpt_path/test_args['exp_name']``
    and runs FCN8s on every test image. The results are saved under
    ``cfg.save_dir/<sub_path>``.

    For 'stem' experiments (``'stem' in cfg.exp_name``), only roughly circular
    connected components of the prediction are kept. Otherwise the crop
    prediction is pasted back into a canvas of the original image size.
    """
    os.environ['CUDA_VISIBLE_DEVICES'] = '0'
    with torch.no_grad():
        net = FCN8s(num_classes=plant.num_classes).cuda()

        print('loading model from ' + test_args['snapshot'])
        net.load_state_dict(torch.load(os.path.join(ckpt_path, test_args['exp_name'], test_args['snapshot'])))
        net.eval()

        mean_std = ([0.385, 0.431, 0.452], [0.289, 0.294, 0.285])
        input_transform = standard_transforms.Compose([
            standard_transforms.ToTensor(),
            standard_transforms.Normalize(*mean_std)
        ])

        test_set = plant.Plant('test', transform=input_transform)
        test_loader = DataLoader(test_set, batch_size=1, num_workers=1, shuffle=False)
        for vi, (img, img_info) in enumerate(test_loader):
            # batch_size=1: unwrap the collated per-sample metadata.
            sub_path = img_info['sub_path'][0]
            img_name = img_info['img_name'][0]

            save_dir = os.path.join(cfg.save_dir, sub_path)
            check_mkdir(save_dir)

            output = net(img.cuda())
            raw_prediction = output.max(1)[1].squeeze_(1).squeeze_(0).cpu().numpy()

            rgbI = plant.readRGBImage(sub_path, img_name)

            if 'stem' in cfg.exp_name:  # stem binary classification
                prediction = raw_prediction.astype(np.uint8)
                # Morphological opening (7x7) removes foreground structures
                # smaller than the kernel before connected-component labelling.
                labelI = smeasure.label(cv2.morphologyEx(prediction, cv2.MORPH_OPEN, np.ones([7, 7])))
                props = smeasure.regionprops(labelI)
                if len(props) > 1:
                    # Keep only roughly circular components: for a disc that
                    # fills its bounding box, 4 * area / bbox_area == pi.
                    for prop in props:
                        bbox_area = (prop.bbox[2] - prop.bbox[0]) * (prop.bbox[3] - prop.bbox[1])
                        cir_to_sqr = 4 * prop.area / bbox_area
                        if cir_to_sqr < np.pi * 0.98 or cir_to_sqr > np.pi * 1.02:
                            labelI[labelI == prop.label] = 0
                    # Only prune the prediction if some component survived.
                    if np.sum(labelI) > 0:
                        prediction[labelI == 0] = 0
                _save_plant_prediction(rgbI, prediction, save_dir, img_name, '_cmp.jpg')
            else:
                # Paste the crop prediction back into an original-size canvas.
                prediction = np.zeros((img_info['ori_size'][1], img_info['ori_size'][0]), dtype=np.uint8)
                x0, y0, x1, y1 = img_info['crop_box']
                prediction[y0:y1, x0:x1] = raw_prediction.astype(np.uint8)
                _save_plant_prediction(rgbI, prediction, save_dir, img_name, '_cmp.png')

            print('%d / %d, %s' % (vi + 1, len(test_loader), img_name))
# Example #28
# 0
# Module-level test-time setup.
# NOTE(review): `args` and `mean_std` are not defined in this snippet and must
# be provided earlier in the script.
short_size = int(min(args['input_size']) / 0.875)

# Deterministic preprocessing for the test split: scale, then center-crop to
# input_size. Random crops and flips are training-time augmentation and would
# make test predictions non-reproducible.
joint_transform = joint_transforms.Compose([
    joint_transforms.Scale(short_size),
    joint_transforms.CenterCrop(args['input_size'])
])

input_transform = standard_transforms.Compose(
    [standard_transforms.ToTensor(),
     standard_transforms.Normalize(*mean_std)])

target_transform = extended_transforms.MaskToTensor()

# Inverse of input_transform, for turning network inputs back into images.
restore_transform = standard_transforms.Compose([
    extended_transforms.DeNormalize(*mean_std),
    standard_transforms.ToPILImage()
])

visualize = standard_transforms.ToTensor()

## Loading the test dataset
test_set = cityscapes.CityScapes('fine',
                                 'test',
                                 joint_transform=joint_transform,
                                 transform=input_transform,
                                 target_transform=target_transform)
# Iterate the test set defined above, in a fixed order.
test_loader = DataLoader(test_set,
                         batch_size=args['test_batch_size'],
                         shuffle=False)
def main():
    """Train FCN-8s on the Cityscapes 'fine' split, optionally resuming a run.

    Configuration comes from the module-level ``args`` dict. Checkpoints and the
    run's argument dump live under ``os.path.join(ckpt_path, exp_name)``.

    When ``args['snapshot']`` is non-empty:
    - the network weights are restored from that file;
    - the optimizer state is restored from ``'opt_' + args['snapshot']``;
    - the start epoch and ``args['best_record']`` are parsed from the
      ``'_'``-separated snapshot filename.

    Each epoch calls ``train`` and then ``validate``. The validation loss drives
    a ``ReduceLROnPlateau`` scheduler.
    """
    net = FCN8s(num_classes=cityscapes.num_classes, caffe=True).cuda()
    exp_dir = os.path.join(ckpt_path, exp_name)
    resume = bool(args['snapshot'])

    if not resume:
        curr_epoch = 1
        args['best_record'] = {'epoch': 0, 'val_loss': 1e10, 'acc': 0, 'acc_cls': 0, 'mean_iu': 0, 'fwavacc': 0}
    else:
        print('training resumes from ' + args['snapshot'])
        net.load_state_dict(torch.load(os.path.join(exp_dir, args['snapshot'])))
        # NOTE(review): assumes the snapshot name alternates labels and values,
        # e.g. 'epoch_<n>_loss_<x>_acc_<x>_acc-cls_<x>_mean-iu_<x>_fwavacc_<x>...';
        # verify against the code that writes snapshots.
        split_snapshot = args['snapshot'].split('_')
        curr_epoch = int(split_snapshot[1]) + 1
        args['best_record'] = {'epoch': int(split_snapshot[1]), 'val_loss': float(split_snapshot[3]),
                               'acc': float(split_snapshot[5]), 'acc_cls': float(split_snapshot[7]),
                               'mean_iu': float(split_snapshot[9]), 'fwavacc': float(split_snapshot[11])}

    net.train()

    # Per-channel means on a 0-255 scale with unit std. Together with
    # FlipChannels below, this presumably reproduces caffe-style BGR
    # preprocessing for the caffe=True weights -- confirm against FCN8s.
    mean_std = ([103.939, 116.779, 123.68], [1.0, 1.0, 1.0])

    short_size = int(min(args['input_size']) / 0.875)
    train_joint_transform = joint_transforms.Compose([
        joint_transforms.Scale(short_size),
        joint_transforms.RandomCrop(args['input_size']),
        joint_transforms.RandomHorizontallyFlip()
    ])
    # Validation uses a deterministic center crop instead of random augmentation.
    val_joint_transform = joint_transforms.Compose([
        joint_transforms.Scale(short_size),
        joint_transforms.CenterCrop(args['input_size'])
    ])
    input_transform = standard_transforms.Compose([
        extended_transforms.FlipChannels(),
        standard_transforms.ToTensor(),
        # ToTensor scales pixels to [0, 1]; undo that so the 0-255 means apply.
        standard_transforms.Lambda(lambda x: x.mul_(255)),
        standard_transforms.Normalize(*mean_std)
    ])
    target_transform = extended_transforms.MaskToTensor()
    # Inverse of input_transform, applied step by step in reverse order.
    restore_transform = standard_transforms.Compose([
        extended_transforms.DeNormalize(*mean_std),
        standard_transforms.Lambda(lambda x: x.div_(255)),
        standard_transforms.ToPILImage(),
        extended_transforms.FlipChannels()
    ])
    visualize = standard_transforms.ToTensor()

    train_set = cityscapes.CityScapes('fine', 'train', joint_transform=train_joint_transform,
                                      transform=input_transform, target_transform=target_transform)
    train_loader = DataLoader(train_set, batch_size=args['train_batch_size'], num_workers=8, shuffle=True)
    val_set = cityscapes.CityScapes('fine', 'val', joint_transform=val_joint_transform, transform=input_transform,
                                    target_transform=target_transform)
    val_loader = DataLoader(val_set, batch_size=args['val_batch_size'], num_workers=8, shuffle=False)

    criterion = CrossEntropyLoss2d(size_average=False, ignore_index=cityscapes.ignore_label).cuda()

    # Bias parameters get twice the base learning rate and no weight decay.
    # All other parameters use the base learning rate with weight decay.
    optimizer = optim.Adam([
        {'params': [param for name, param in net.named_parameters() if name[-4:] == 'bias'],
         'lr': 2 * args['lr']},
        {'params': [param for name, param in net.named_parameters() if name[-4:] != 'bias'],
         'lr': args['lr'], 'weight_decay': args['weight_decay']}
    ], betas=(args['momentum'], 0.999))

    if resume:
        optimizer.load_state_dict(torch.load(os.path.join(exp_dir, 'opt_' + args['snapshot'])))
        # Override the learning rates stored in the checkpoint with the current args.
        optimizer.param_groups[0]['lr'] = 2 * args['lr']
        optimizer.param_groups[1]['lr'] = args['lr']

    check_mkdir(ckpt_path)
    check_mkdir(exp_dir)
    # Record the run configuration. The context manager guarantees the handle is closed.
    with open(os.path.join(exp_dir, str(datetime.datetime.now()) + '.txt'), 'w') as args_file:
        args_file.write(str(args) + '\n\n')

    scheduler = ReduceLROnPlateau(optimizer, 'min', patience=args['lr_patience'], min_lr=1e-10, verbose=True)
    for epoch in range(curr_epoch, args['epoch_num'] + 1):
        train(train_loader, net, criterion, optimizer, epoch, args)
        val_loss = validate(val_loader, net, criterion, optimizer, epoch, args, restore_transform, visualize)
        scheduler.step(val_loss)
def main(train_args):
    if cuda.is_available():
        net = fcn8s.FCN8s(num_classes=voc.num_classes, pretrained=False).cuda()
        #net = MBO.MBO().cuda()
        #net = deeplab_resnet.Res_Deeplab().cuda()
    else:
        print('cuda is not available')
        net = fcn8s.FCN8s(num_classes=voc.num_classes, pretrained=True)

    net.train()

    mean_std = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

    input_transform = standard_transforms.Compose([
        standard_transforms.ToTensor(),
        standard_transforms.Normalize(*mean_std)
    ])
    target_transform = extended_transforms.MaskToTensor()
    restore_transform = standard_transforms.Compose([
        extended_transforms.DeNormalize(*mean_std),
        standard_transforms.ToPILImage(),
    ])
    visualize = standard_transforms.Compose([
        standard_transforms.Scale(400),
        standard_transforms.CenterCrop(400),
        standard_transforms.ToTensor()
    ])

    train_set = voc.VOC('train',
                        set='benchmark',
                        transform=input_transform,
                        target_transform=target_transform)
    train_loader = DataLoader(train_set,
                              batch_size=bsz,
                              num_workers=8,
                              shuffle=True)

    val_set = voc.VOC('val',
                      set='voc',
                      transform=input_transform,
                      target_transform=target_transform)
    val_loader = DataLoader(val_set,
                            batch_size=1,
                            num_workers=4,
                            shuffle=False)

    criterion = CrossEntropyLoss2d(size_average=False,
                                   ignore_index=voc.ignore_label).cuda()
    optimizer = optim.Adam([{
        'params': [
            param
            for name, param in net.named_parameters() if name[-4:] == 'bias'
        ],
        'lr':
        train_args['lr']
    }, {
        'params': [
            param
            for name, param in net.named_parameters() if name[-4:] != 'bias'
        ],
        'lr':
        train_args['lr']
    }],
                           betas=(train_args['momentum'], 0.999))
    scheduler = ReduceLROnPlateau(optimizer,
                                  'min',
                                  patience=2,
                                  min_lr=1e-10,
                                  verbose=True)

    lr0 = 1e-7
    max_epoch = 50
    max_iter = max_epoch * len(train_loader)
    #optimizer = optim.SGD(net.parameters(),lr = lr0, momentum = 0.9, weight_decay = 0.0005)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.5)

    log_dir = os.path.join(root, 'logs', 'voc-fcn')
    time = datetime.datetime.now().strftime('%d-%m-%H-%M')
    train_file = 'train_log' + time + '.txt'
    val_file = 'val_log' + time + '.txt'
    #os.makedirs(log_dir,exist_ok=True)

    training_log = open(os.path.join(log_dir, train_file), 'w')
    val_log = open(os.path.join(log_dir, val_file), 'w')

    curr_epoch = 1
    for epoch in range(curr_epoch, train_args['epoch_num'] + 1):
        train(train_loader, net, criterion, optimizer, epoch, train_args,
              training_log, max_iter, lr0)
        val_loss = validate(val_loader, net, criterion, optimizer, epoch,
                            train_args, restore_transform, visualize, val_log)

        scheduler.step(val_loss)

        lr_tmp = 0.0
        k = 0
        for param_group in optimizer.param_groups:
            lr_tmp += param_group['lr']
            k += 1
        val_log.write('learning rate = {}'.format(str(lr_tmp / k)) + '\n')