Exemplo n.º 1
0
def main():
    """Train a 3-class U-Net on the bladder dataset.

    Relies on module-level configuration: ``device``, ``crop_size``,
    ``data_path``, ``batch_size``, ``loss_name`` and ``n_epoch``.
    """
    net = U_Net(img_ch=1, num_classes=3).to(device)

    # Joint transforms are applied to image and mask together so the
    # geometric augmentation stays consistent between the two.
    train_joint_transform = joint_transforms.Compose([
        joint_transforms.Scale(384),
        joint_transforms.RandomRotate(10),
        joint_transforms.RandomHorizontallyFlip()
    ])
    center_crop = joint_transforms.CenterCrop(crop_size)
    train_input_transform = extended_transforms.ImgToTensor()

    target_transform = extended_transforms.MaskToTensor()
    make_dataset_fn = bladder.make_dataset_v2
    train_set = bladder.Bladder(data_path,
                                'train',
                                joint_transform=train_joint_transform,
                                center_crop=center_crop,
                                transform=train_input_transform,
                                target_transform=target_transform,
                                make_dataset_fn=make_dataset_fn)
    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)

    if loss_name == 'dice_':
        criterion = SoftDiceLossV2(activation='sigmoid',
                                   num_classes=3).to(device)
    elif loss_name == 'bcew_':
        criterion = nn.BCEWithLogitsLoss().to(device)
    else:
        # Fail fast: previously an unknown loss_name left `criterion`
        # undefined and crashed later with a confusing NameError.
        raise ValueError('Unsupported loss_name: {!r}'.format(loss_name))

    optimizer = optim.Adam(net.parameters(), lr=1e-4)

    train(train_loader, net, criterion, optimizer, n_epoch, 0)
Exemplo n.º 2
0
def main(args):
    """Train a style-transfer model between GTA5 (source) and Cityscapes
    (target), validating on a Cityscapes validation split.

    Args:
        args: parsed CLI namespace providing data paths/lists, input_size
            ("W,H" string), batch_size, num_workers and tensorboard_log_dir.
    """
    writer = SummaryWriter(log_dir=args.tensorboard_log_dir)
    w, h = map(int, args.input_size.split(','))

    joint_transform = joint_transforms.Compose([
        joint_transforms.FreeScale((h, w)),
        joint_transforms.RandomHorizontallyFlip(),
    ])
    normalize = ((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    # NOTE: the source input deliberately skips normalization — presumably the
    # style-transfer module expects raw [0, 1] source images; verify against
    # StyleTrans before changing.
    src_input_transform = standard_transforms.Compose([
        standard_transforms.ToTensor(),
    ])
    tgt_input_transform = standard_transforms.Compose([
        standard_transforms.ToTensor(),
        standard_transforms.Normalize(*normalize),
    ])
    val_input_transform = standard_transforms.Compose([
        extended_transforms.FreeScale((h, w)),
        standard_transforms.ToTensor(),
        standard_transforms.Normalize(*normalize),
    ])
    target_transform = extended_transforms.MaskToTensor()
    # (removed unused `restore_transform` local)

    src_dataset = GTA5DataSetLMDB(
        args.data_dir, args.data_list,
        joint_transform=joint_transform,
        transform=src_input_transform,
        target_transform=target_transform,
    )
    src_loader = data.DataLoader(
        src_dataset, batch_size=args.batch_size, shuffle=True,
        num_workers=args.num_workers, pin_memory=True, drop_last=True
    )
    tgt_dataset = CityscapesDataSetLMDB(
        args.data_dir_target, args.data_list_target,
        joint_transform=joint_transform,
        transform=tgt_input_transform,
        target_transform=target_transform,
    )
    tgt_loader = data.DataLoader(
        tgt_dataset, batch_size=args.batch_size, shuffle=True,
        num_workers=args.num_workers, pin_memory=True, drop_last=True
    )

    # Validation keeps the natural order and uses every sample.
    val_dataset = CityscapesDataSetLMDB(
        args.data_dir_val, args.data_list_val,
        transform=val_input_transform,
        target_transform=target_transform,
    )
    val_loader = data.DataLoader(
        val_dataset, batch_size=args.batch_size, shuffle=False,
        num_workers=args.num_workers, pin_memory=True, drop_last=False
    )

    style_trans = StyleTrans(args)
    style_trans.train(src_loader, tgt_loader, val_loader, writer)

    writer.close()
Exemplo n.º 3
0
    def transform_image_and_mask_tt(self, image, mask, angle=None, crop_size=None):
        """Jointly augment an image/mask pair with torch-style transforms.

        Builds a random-order pipeline of horizontal flip plus optional
        crop and rotation, applied identically to both inputs, and
        returns the pair converted back to numpy arrays.
        """
        # This code path is only valid when imgaug-based augmentation is off.
        assert self.use_iaa == False

        random_ops = [JT.RandomHorizontallyFlip()]
        if crop_size is not None:
            random_ops.append(JT.RandomCrop(size=crop_size))
        if angle is not None:
            random_ops.append(JT.RandomRotate(degree=angle))

        pipeline = JT.Compose([
            JT.ToPILImage(),
            JT.RandomOrderApply(random_ops),
            JT.ToNumpy(),
        ])
        return pipeline(image, mask)
Exemplo n.º 4
0
def train_with_correspondences(save_folder, startnet, args):
    """Train a PSPNet jointly on segmentation losses (Cityscapes, optional
    Vistas, and an "extra" set from the correspondence config) and a
    cross-sequence correspondence loss.

    Args:
        save_folder: directory for snapshots, logs and tensorboard events.
        startnet: path to the initial weights used when starting fresh.
        args: dict of hyperparameters and flags ('snapshot', 'lr',
            'corr_set', 'corr_loss_type', 'max_iter', ...).
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    check_mkdir(save_folder)
    writer = SummaryWriter(save_folder)

    # Network and weight loading
    model_config = model_configs.PspnetCityscapesConfig()
    net = model_config.init_network().to(device)

    if args['snapshot'] == 'latest':
        args['snapshot'] = get_latest_network_name(save_folder)

    if len(args['snapshot']) == 0:  # If start from beginning
        state_dict = torch.load(startnet)
        # needed since we slightly changed the structure of the network in
        # pspnet
        state_dict = rename_keys_to_match(state_dict)
        net.load_state_dict(state_dict)  # load original weights

        start_iter = 0
        args['best_record'] = {
            'iter': 0,
            'val_loss': 1e10,
            'acc': 0,
            'acc_cls': 0,
            'mean_iu': 0,
            'fwavacc': 0
        }
    else:  # If continue training
        print('training resumes from ' + args['snapshot'])
        net.load_state_dict(
            torch.load(os.path.join(save_folder,
                                    args['snapshot'])))  # load weights
        split_snapshot = args['snapshot'].split('_')

        start_iter = int(split_snapshot[1])
        with open(os.path.join(save_folder, 'bestval.txt')) as f:
            best_val_dict_str = f.read()
        # NOTE(review): eval() on file contents executes arbitrary code if
        # bestval.txt is tampered with — consider ast.literal_eval instead.
        args['best_record'] = eval(best_val_dict_str.rstrip())

    net.train()
    # BatchNorm statistics are kept frozen for the whole run.
    freeze_bn(net)

    # Data loading setup
    # NOTE(review): an unrecognized 'corr_set' leaves corr_set_config
    # undefined and fails later with a NameError.
    if args['corr_set'] == 'rc':
        corr_set_config = data_configs.RobotcarConfig()
    elif args['corr_set'] == 'cmu':
        corr_set_config = data_configs.CmuConfig()

    sliding_crop_im = joint_transforms.SlidingCropImageOnly(
        713, args['stride_rate'])

    input_transform = model_config.input_transform
    pre_validation_transform = model_config.pre_validation_transform

    target_transform = extended_transforms.MaskToTensor()

    train_joint_transform_seg = joint_transforms.Compose([
        joint_transforms.Resize(1024),
        joint_transforms.RandomRotate(10),
        joint_transforms.RandomHorizontallyFlip(),
        joint_transforms.RandomCrop(713)
    ])

    train_joint_transform_corr = corr_transforms.Compose([
        corr_transforms.CorrResize(1024),
        corr_transforms.CorrRandomCrop(713)
    ])

    # keep list of segmentation loaders and validators
    seg_loaders = list()
    validators = list()

    # Correspondences
    corr_set = correspondences.Correspondences(
        corr_set_config.correspondence_path,
        corr_set_config.correspondence_im_path,
        input_size=(713, 713),
        mean_std=model_config.mean_std,
        input_transform=input_transform,
        joint_transform=train_joint_transform_corr)
    corr_loader = DataLoader(corr_set,
                             batch_size=args['train_batch_size'],
                             num_workers=args['n_workers'],
                             shuffle=True)

    # Cityscapes Training
    c_config = data_configs.CityscapesConfig()
    seg_set_cs = cityscapes.CityScapes(
        c_config.train_im_folder,
        c_config.train_seg_folder,
        c_config.im_file_ending,
        c_config.seg_file_ending,
        id_to_trainid=c_config.id_to_trainid,
        joint_transform=train_joint_transform_seg,
        sliding_crop=None,
        transform=input_transform,
        target_transform=target_transform)
    seg_loader_cs = DataLoader(seg_set_cs,
                               batch_size=args['train_batch_size'],
                               num_workers=args['n_workers'],
                               shuffle=True)
    seg_loaders.append(seg_loader_cs)

    # Cityscapes Validation
    val_set_cs = cityscapes.CityScapes(
        c_config.val_im_folder,
        c_config.val_seg_folder,
        c_config.im_file_ending,
        c_config.seg_file_ending,
        id_to_trainid=c_config.id_to_trainid,
        sliding_crop=sliding_crop_im,
        transform=input_transform,
        target_transform=target_transform,
        transform_before_sliding=pre_validation_transform)
    val_loader_cs = DataLoader(val_set_cs,
                               batch_size=1,
                               num_workers=args['n_workers'],
                               shuffle=False)
    validator_cs = Validator(val_loader_cs,
                             n_classes=c_config.n_classes,
                             save_snapshot=False,
                             extra_name_str='Cityscapes')
    validators.append(validator_cs)

    # Vistas Training and Validation
    if args['include_vistas']:
        v_config = data_configs.VistasConfig(
            use_subsampled_validation_set=True, use_cityscapes_classes=True)

        seg_set_vis = cityscapes.CityScapes(
            v_config.train_im_folder,
            v_config.train_seg_folder,
            v_config.im_file_ending,
            v_config.seg_file_ending,
            id_to_trainid=v_config.id_to_trainid,
            joint_transform=train_joint_transform_seg,
            sliding_crop=None,
            transform=input_transform,
            target_transform=target_transform)
        seg_loader_vis = DataLoader(seg_set_vis,
                                    batch_size=args['train_batch_size'],
                                    num_workers=args['n_workers'],
                                    shuffle=True)
        seg_loaders.append(seg_loader_vis)

        val_set_vis = cityscapes.CityScapes(
            v_config.val_im_folder,
            v_config.val_seg_folder,
            v_config.im_file_ending,
            v_config.seg_file_ending,
            id_to_trainid=v_config.id_to_trainid,
            sliding_crop=sliding_crop_im,
            transform=input_transform,
            target_transform=target_transform,
            transform_before_sliding=pre_validation_transform)
        val_loader_vis = DataLoader(val_set_vis,
                                    batch_size=1,
                                    num_workers=args['n_workers'],
                                    shuffle=False)
        validator_vis = Validator(val_loader_vis,
                                  n_classes=v_config.n_classes,
                                  save_snapshot=False,
                                  extra_name_str='Vistas')
        validators.append(validator_vis)
    else:
        seg_loader_vis = None
        map_validator = None

    # Extra Training
    extra_seg_set = cityscapes.CityScapes(
        corr_set_config.train_im_folder,
        corr_set_config.train_seg_folder,
        corr_set_config.im_file_ending,
        corr_set_config.seg_file_ending,
        id_to_trainid=corr_set_config.id_to_trainid,
        joint_transform=train_joint_transform_seg,
        sliding_crop=None,
        transform=input_transform,
        target_transform=target_transform)
    extra_seg_loader = DataLoader(extra_seg_set,
                                  batch_size=args['train_batch_size'],
                                  num_workers=args['n_workers'],
                                  shuffle=True)
    seg_loaders.append(extra_seg_loader)

    # Extra Validation
    extra_val_set = cityscapes.CityScapes(
        corr_set_config.val_im_folder,
        corr_set_config.val_seg_folder,
        corr_set_config.im_file_ending,
        corr_set_config.seg_file_ending,
        id_to_trainid=corr_set_config.id_to_trainid,
        sliding_crop=sliding_crop_im,
        transform=input_transform,
        target_transform=target_transform,
        transform_before_sliding=pre_validation_transform)
    extra_val_loader = DataLoader(extra_val_set,
                                  batch_size=1,
                                  num_workers=args['n_workers'],
                                  shuffle=False)
    # Only the extra-set validator saves snapshots.
    extra_validator = Validator(extra_val_loader,
                                n_classes=corr_set_config.n_classes,
                                save_snapshot=True,
                                extra_name_str='Extra')
    validators.append(extra_validator)

    # Loss setup
    if args['corr_loss_type'] == 'class':
        corr_loss_fct = CorrClassLoss(input_size=[713, 713])
    else:
        corr_loss_fct = FeatureLoss(
            input_size=[713, 713],
            loss_type=args['corr_loss_type'],
            feat_dist_threshold_match=args['feat_dist_threshold_match'],
            feat_dist_threshold_nomatch=args['feat_dist_threshold_nomatch'],
            n_not_matching=0)

    # NOTE(review): reduction='elementwise_mean' is the pre-1.0 PyTorch
    # spelling (renamed to 'mean' later) — this pins an old torch version.
    seg_loss_fct = torch.nn.CrossEntropyLoss(
        reduction='elementwise_mean',
        ignore_index=cityscapes.ignore_label).to(device)

    # Optimizer setup: biases get double learning rate and no weight decay.
    optimizer = optim.SGD([{
        'params': [
            param for name, param in net.named_parameters()
            if name[-4:] == 'bias' and param.requires_grad
        ],
        'lr':
        2 * args['lr']
    }, {
        'params': [
            param for name, param in net.named_parameters()
            if name[-4:] != 'bias' and param.requires_grad
        ],
        'lr':
        args['lr'],
        'weight_decay':
        args['weight_decay']
    }],
                          momentum=args['momentum'],
                          nesterov=True)

    if len(args['snapshot']) > 0:
        optimizer.load_state_dict(
            torch.load(os.path.join(save_folder, 'opt_' + args['snapshot'])))
        optimizer.param_groups[0]['lr'] = 2 * args['lr']
        optimizer.param_groups[1]['lr'] = args['lr']

    # Record the run's hyperparameters in a timestamped file.
    open(os.path.join(save_folder,
                      str(datetime.datetime.now()) + '.txt'),
         'w').write(str(args) + '\n\n')

    if len(args['snapshot']) == 0:
        f_handle = open(os.path.join(save_folder, 'log.log'), 'w', buffering=1)
    else:
        clean_log_before_continuing(os.path.join(save_folder, 'log.log'),
                                    start_iter)
        f_handle = open(os.path.join(save_folder, 'log.log'), 'a', buffering=1)

    ##########################################################################
    #
    #       MAIN TRAINING CONSISTS OF ALL SEGMENTATION LOSSES AND A CORRESPONDENCE LOSS
    #
    ##########################################################################
    softm = torch.nn.Softmax2d()

    val_iter = 0
    train_corr_loss = AverageMeter()
    train_seg_cs_loss = AverageMeter()
    train_seg_extra_loss = AverageMeter()
    train_seg_vis_loss = AverageMeter()

    # Meters are appended in the same order as seg_loaders above.
    seg_loss_meters = list()
    seg_loss_meters.append(train_seg_cs_loss)
    if args['include_vistas']:
        seg_loss_meters.append(train_seg_vis_loss)
    seg_loss_meters.append(train_seg_extra_loss)

    curr_iter = start_iter

    for i in range(args['max_iter']):
        # Polynomial learning-rate decay per iteration.
        optimizer.param_groups[0]['lr'] = 2 * args['lr'] * (
            1 - float(curr_iter) / args['max_iter'])**args['lr_decay']
        optimizer.param_groups[1]['lr'] = args['lr'] * (
            1 - float(curr_iter) / args['max_iter'])**args['lr_decay']

        #######################################################################
        #       SEGMENTATION UPDATE STEP
        #######################################################################
        #
        for si, seg_loader in enumerate(seg_loaders):
            # get segmentation training sample
            # NOTE(review): iter() is recreated every step, so each loader
            # is reshuffled and only its first batch is consumed — confirm
            # this sampling is intentional.
            inputs, gts = next(iter(seg_loader))

            slice_batch_pixel_size = inputs.size(0) * inputs.size(
                2) * inputs.size(3)

            inputs = inputs.to(device)
            gts = gts.to(device)

            optimizer.zero_grad()
            outputs, aux = net(inputs)

            main_loss = args['seg_loss_weight'] * seg_loss_fct(outputs, gts)
            aux_loss = args['seg_loss_weight'] * seg_loss_fct(aux, gts)
            loss = main_loss + 0.4 * aux_loss

            loss.backward()
            optimizer.step()

            seg_loss_meters[si].update(main_loss.item(),
                                       slice_batch_pixel_size)

        #######################################################################
        #       CORRESPONDENCE UPDATE STEP
        #######################################################################
        if args['corr_loss_weight'] > 0 and args[
                'n_iterations_before_corr_loss'] < curr_iter:
            img_ref, img_other, pts_ref, pts_other, weights = next(
                iter(corr_loader))

            # Transfer data to device
            # img_ref is from the "good" sequence with generally better
            # segmentation results
            img_ref = img_ref.to(device)
            img_other = img_other.to(device)
            pts_ref = [p.to(device) for p in pts_ref]
            pts_other = [p.to(device) for p in pts_other]
            weights = [w.to(device) for w in weights]

            # Forward pass
            if args['corr_loss_type'] == 'hingeF':  # Works on features
                net.output_all = True
                # Reference image is treated as fixed target (no gradients).
                with torch.no_grad():
                    output_feat_ref, aux_feat_ref, output_ref, aux_ref = net(
                        img_ref)
                output_feat_other, aux_feat_other, output_other, aux_other = net(
                    img_other
                )  # output1 must be last to backpropagate derivative correctly
                net.output_all = False

            else:  # Works on class probs
                with torch.no_grad():
                    output_ref, aux_ref = net(img_ref)
                    if args['corr_loss_type'] != 'hingeF' and args[
                            'corr_loss_type'] != 'hingeC':
                        output_ref = softm(output_ref)
                        aux_ref = softm(aux_ref)

                # output1 must be last to backpropagate derivative correctly
                output_other, aux_other = net(img_other)
                if args['corr_loss_type'] != 'hingeF' and args[
                        'corr_loss_type'] != 'hingeC':
                    output_other = softm(output_other)
                    aux_other = softm(aux_other)

            # Correspondence filtering
            pts_ref_orig, pts_other_orig, weights_orig, batch_inds_to_keep_orig = correspondences.refine_correspondence_sample(
                output_ref,
                output_other,
                pts_ref,
                pts_other,
                weights,
                remove_same_class=args['remove_same_class'],
                remove_classes=args['classes_to_ignore'])
            pts_ref_orig = [
                p for b, p in zip(batch_inds_to_keep_orig, pts_ref_orig)
                if b.item() > 0
            ]
            pts_other_orig = [
                p for b, p in zip(batch_inds_to_keep_orig, pts_other_orig)
                if b.item() > 0
            ]
            weights_orig = [
                p for b, p in zip(batch_inds_to_keep_orig, weights_orig)
                if b.item() > 0
            ]
            if args['corr_loss_type'] == 'hingeF':
                # remove entire samples if needed
                output_vals_ref = output_feat_ref[batch_inds_to_keep_orig]
                output_vals_other = output_feat_other[batch_inds_to_keep_orig]
            else:
                # remove entire samples if needed
                output_vals_ref = output_ref[batch_inds_to_keep_orig]
                output_vals_other = output_other[batch_inds_to_keep_orig]

            pts_ref_aux, pts_other_aux, weights_aux, batch_inds_to_keep_aux = correspondences.refine_correspondence_sample(
                aux_ref,
                aux_other,
                pts_ref,
                pts_other,
                weights,
                remove_same_class=args['remove_same_class'],
                remove_classes=args['classes_to_ignore'])
            pts_ref_aux = [
                p for b, p in zip(batch_inds_to_keep_aux, pts_ref_aux)
                if b.item() > 0
            ]
            pts_other_aux = [
                p for b, p in zip(batch_inds_to_keep_aux, pts_other_aux)
                if b.item() > 0
            ]
            weights_aux = [
                p for b, p in zip(batch_inds_to_keep_aux, weights_aux)
                if b.item() > 0
            ]
            if args['corr_loss_type'] == 'hingeF':
                # remove entire samples if needed
                # NOTE(review): this indexes with batch_inds_to_keep_orig
                # while the aux points above were filtered with
                # batch_inds_to_keep_aux — possible mismatch; confirm intent.
                aux_vals_ref = aux_feat_ref[batch_inds_to_keep_orig]
                aux_vals_other = aux_feat_other[batch_inds_to_keep_orig]
            else:
                # remove entire samples if needed
                aux_vals_ref = aux_ref[batch_inds_to_keep_aux]
                aux_vals_other = aux_other[batch_inds_to_keep_aux]

            optimizer.zero_grad()

            # correspondence loss
            # When every sample was filtered out, use a zero loss that still
            # participates in the graph so backward() stays valid.
            if output_vals_ref.size(0) > 0:
                loss_corr_hr = corr_loss_fct(output_vals_ref,
                                             output_vals_other, pts_ref_orig,
                                             pts_other_orig, weights_orig)
            else:
                loss_corr_hr = 0 * output_vals_other.sum()

            if aux_vals_ref.size(0) > 0:
                loss_corr_aux = corr_loss_fct(
                    aux_vals_ref, aux_vals_other, pts_ref_aux, pts_other_aux,
                    weights_aux)  # use output from img1 as "reference"
            else:
                loss_corr_aux = 0 * aux_vals_other.sum()

            loss_corr = args['corr_loss_weight'] * \
                (loss_corr_hr + 0.4 * loss_corr_aux)
            loss_corr.backward()

            optimizer.step()
            train_corr_loss.update(loss_corr.item())

        #######################################################################
        #       LOGGING ETC
        #######################################################################
        curr_iter += 1
        val_iter += 1

        writer.add_scalar('train_seg_loss_cs', train_seg_cs_loss.avg,
                          curr_iter)
        writer.add_scalar('train_seg_loss_extra', train_seg_extra_loss.avg,
                          curr_iter)
        writer.add_scalar('train_seg_loss_vis', train_seg_vis_loss.avg,
                          curr_iter)
        writer.add_scalar('train_corr_loss', train_corr_loss.avg, curr_iter)
        writer.add_scalar('lr', optimizer.param_groups[1]['lr'], curr_iter)

        if (i + 1) % args['print_freq'] == 0:
            # NOTE(review): the second field prints len(corr_loader), not
            # max_iter — looks like a leftover; verify before relying on it.
            str2write = '[iter %d / %d], [train corr loss %.5f] , [seg cs loss %.5f], [seg vis loss %.5f], [seg extra loss %.5f]. [lr %.10f]' % (
                curr_iter, len(corr_loader), train_corr_loss.avg,
                train_seg_cs_loss.avg, train_seg_vis_loss.avg,
                train_seg_extra_loss.avg, optimizer.param_groups[1]['lr'])
            print(str2write)
            f_handle.write(str2write + "\n")

        if val_iter >= args['val_interval']:
            val_iter = 0
            for validator in validators:
                validator.run(net,
                              optimizer,
                              args,
                              curr_iter,
                              save_folder,
                              f_handle,
                              writer=writer)

    # Post training
    f_handle.close()
    writer.close()
Exemplo n.º 5
0
def main(args):
    """Train a PSPNet for semantic segmentation (GTA5 or Cityscapes source,
    picked from args.data_dir) and evaluate mIoU on a Cityscapes target set
    after every epoch, saving a checkpoint per epoch.

    Args:
        args: parsed CLI namespace with data paths/lists, input sizes
            ("W,H" strings), optimizer hyperparameters, logging dirs and
            epoch/batch settings.
    """
    writer = SummaryWriter(log_dir=args.tensorboard_log_dir)
    w, h = map(int, args.input_size.split(','))
    w_target, h_target = map(int, args.input_size_target.split(','))

    joint_transform = joint_transforms.Compose([
        joint_transforms.FreeScale((h, w)),
        joint_transforms.RandomHorizontallyFlip(),
    ])
    # ImageNet normalization statistics.
    normalize = ((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    input_transform = standard_transforms.Compose([
        standard_transforms.ToTensor(),
        standard_transforms.Normalize(*normalize),
    ])
    target_transform = extended_transforms.MaskToTensor()
    # NOTE(review): restore_transform is currently unused in this function.
    restore_transform = standard_transforms.ToPILImage()

    # Heuristic dataset switch: a '5' in the path means GTA5, else Cityscapes.
    if '5' in args.data_dir:
        dataset = GTA5DataSetLMDB(
            args.data_dir, args.data_list,
            joint_transform=joint_transform,
            transform=input_transform, target_transform=target_transform,
        )
    else:
        dataset = CityscapesDataSetLMDB(
            args.data_dir, args.data_list,
            joint_transform=joint_transform,
            transform=input_transform, target_transform=target_transform,
        )
    loader = data.DataLoader(
        dataset, batch_size=args.batch_size,
        shuffle=True, num_workers=args.num_workers, pin_memory=True
    )
    val_dataset = CityscapesDataSetLMDB(
        args.data_dir_target, args.data_list_target,
        # joint_transform=joint_transform,
        transform=input_transform, target_transform=target_transform
    )
    val_loader = data.DataLoader(
        val_dataset, batch_size=args.batch_size,
        shuffle=False, num_workers=args.num_workers, pin_memory=True
    )


    # Upsample predictions to the target resolution before scoring mIoU.
    upsample = nn.Upsample(size=(h_target, w_target),
                           mode='bilinear', align_corners=True)

    net = PSP(
        nclass = args.n_classes, backbone='resnet101', 
        root=args.model_path_prefix, norm_layer=BatchNorm2d,
    )

    # Head and aux layers get 10x the backbone learning rate.
    params_list = [
        {'params': net.pretrained.parameters(), 'lr': args.learning_rate},
        {'params': net.head.parameters(), 'lr': args.learning_rate*10},
        {'params': net.auxlayer.parameters(), 'lr': args.learning_rate*10},
    ]
    optimizer = torch.optim.SGD(params_list,
                                lr=args.learning_rate,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    criterion = SegmentationLosses(nclass=args.n_classes, aux=True, ignore_index=255)
    # criterion = SegmentationMultiLosses(nclass=args.n_classes, ignore_index=255)

    net = DataParallelModel(net).cuda()
    criterion = DataParallelCriterion(criterion).cuda()

    logger = utils.create_logger(args.tensorboard_log_dir, 'PSP_train')
    scheduler = utils.LR_Scheduler(args.lr_scheduler, args.learning_rate,
                                   args.num_epoch, len(loader), logger=logger,
                                   lr_step=args.lr_step)

    net_eval = Eval(net)

    num_batches = len(loader)
    best_pred = 0.0
    for epoch in range(args.num_epoch):

        loss_rec = AverageMeter()
        data_time_rec = AverageMeter()
        batch_time_rec = AverageMeter()

        tem_time = time.time()
        for batch_index, batch_data in enumerate(loader):
            # Scheduler adjusts per-group learning rates every batch.
            scheduler(optimizer, batch_index, epoch, best_pred)
            # show_fig is only consumed by the commented-out visualization.
            show_fig = (batch_index+1) % args.show_img_freq == 0
            iteration = batch_index+1+epoch*num_batches

            net.train()
            img, label, name = batch_data
            img = img.cuda()
            label_cuda = label.cuda()
            data_time_rec.update(time.time()-tem_time)

            output = net(img)
            loss = criterion(output, label_cuda)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            loss_rec.update(loss.item())
            writer.add_scalar('A_seg_loss', loss.item(), iteration)
            batch_time_rec.update(time.time()-tem_time)
            tem_time = time.time()

            if (batch_index+1) % args.print_freq == 0:
                print(
                    f'Epoch [{epoch+1:d}/{args.num_epoch:d}][{batch_index+1:d}/{num_batches:d}]\t'
                    f'Time: {batch_time_rec.avg:.2f}   '
                    f'Data: {data_time_rec.avg:.2f}   '
                    f'Loss: {loss_rec.avg:.2f}'
                )
            # if show_fig:
            #     # base_lr = optimizer.param_groups[0]["lr"]
            #     output = torch.argmax(output[0][0], dim=1).detach()[0, ...].cpu()
            #     # fig, axes = plt.subplots(2, 1, figsize=(12, 14))
            #     # axes = axes.flat
            #     # axes[0].imshow(colorize_mask(output.numpy()))
            #     # axes[0].set_title(name[0])
            #     # axes[1].imshow(colorize_mask(label[0, ...].numpy()))
            #     # axes[1].set_title(f'seg_true_{base_lr:.6f}')
            #     # writer.add_figure('A_seg', fig, iteration)
            #     output_mask = np.asarray(colorize_mask(output.numpy()))
            #     label = np.asarray(colorize_mask(label[0,...].numpy()))
            #     image_out = np.concatenate([output_mask, label])
            #     writer.add_image('A_seg', image_out, iteration)

        # Evaluate and checkpoint after every epoch; filename encodes mIoU.
        mean_iu = test_miou(net_eval, val_loader, upsample,
                            './style_seg/dataset/info.json')
        torch.save(
            net.module.state_dict(),
            os.path.join(args.save_path_prefix, f'{epoch:d}_{mean_iu*100:.0f}.pth')
        )

    writer.close()
Exemplo n.º 6
0
def main(train_args):
    """Train an 11-class FCN8s on the PRIMA dataset, with optional resume
    from a snapshot whose filename encodes the best-record metrics.

    Relies on module-level ``ckpt_path``/``exp_name`` and the ``train``
    helper defined elsewhere in this file.
    """
    net = FCN8s(num_classes=11)

    if len(train_args['snapshot']) == 0:
        curr_epoch = 1
        train_args['best_record'] = {'epoch': 0, 'val_loss': 1e10, 'acc': 0, 'acc_cls': 0, 'mean_iu': 0, 'fwavacc': 0}
    else:
        print('training resumes from ' + train_args['snapshot'])
        net.load_state_dict(torch.load(os.path.join(ckpt_path, exp_name, train_args['snapshot'])))
        # Snapshot filenames look like: epoch_<n>_loss_<v>_acc_<v>_..., so
        # the metrics can be recovered from the odd-indexed fields.
        split_snapshot = train_args['snapshot'].split('_')
        curr_epoch = int(split_snapshot[1]) + 1
        train_args['best_record'] = {'epoch': int(split_snapshot[1]), 'val_loss': float(split_snapshot[3]),
                                     'acc': float(split_snapshot[5]), 'acc_cls': float(split_snapshot[7]),
                                     'mean_iu': float(split_snapshot[9]), 'fwavacc': float(split_snapshot[11])}

    net.train()
    # ImageNet normalization statistics.
    mean_std = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

    train_joint_transform = joint_transforms.Compose([
        # joint_transforms.Scale(short_size),
        joint_transforms.Scale(2000),
        # joint_transforms.RandomCrop(args['input_size']),
        joint_transforms.RandomHorizontallyFlip()
    ])
    input_transform = standard_transforms.Compose([
        standard_transforms.ToTensor(),
        standard_transforms.Normalize(*mean_std)
    ])
    target_transform = extended_transforms.MaskToTensor()
    restore_transform = standard_transforms.Compose([
        extended_transforms.DeNormalize(*mean_std),
        standard_transforms.ToPILImage(),
    ])
    visualize = standard_transforms.Compose([
        standard_transforms.Scale(400),
        standard_transforms.CenterCrop(400),
        standard_transforms.ToTensor()
    ])

    train_set = PRIMA('train', joint_transform=train_joint_transform, transform=input_transform, target_transform=target_transform)
    train_loader = DataLoader(train_set, batch_size=1, num_workers=0, shuffle=True)

    # Removed leftover debugging code that pulled one batch and dropped into
    # pdb.set_trace(), which blocked training unconditionally.

    criterion = CrossEntropyLoss2d(size_average=False).cuda()
    # Biases get double learning rate and no weight decay.
    optimizer = optim.Adam([
        {'params': [param for name, param in net.named_parameters() if name[-4:] == 'bias'],
         'lr': 2 * train_args['lr']},
        {'params': [param for name, param in net.named_parameters() if name[-4:] != 'bias'],
         'lr': train_args['lr'], 'weight_decay': train_args['weight_decay']}
    ], betas=(train_args['momentum'], 0.999))

    if len(train_args['snapshot']) > 0:
        optimizer.load_state_dict(torch.load(os.path.join(ckpt_path, exp_name, 'opt_' + train_args['snapshot'])))
        optimizer.param_groups[0]['lr'] = 2 * train_args['lr']
        optimizer.param_groups[1]['lr'] = train_args['lr']

    check_mkdir(ckpt_path)
    check_mkdir(os.path.join(ckpt_path, exp_name))
    # Record the run's hyperparameters in a timestamped file.
    open(os.path.join(ckpt_path, exp_name, str(datetime.datetime.now()) + '.txt'), 'w').write(str(train_args) + '\n\n')

    scheduler = ReduceLROnPlateau(optimizer, 'min', patience=train_args['lr_patience'], min_lr=1e-10, verbose=True)
    for epoch in range(curr_epoch, train_args['epoch_num'] + 1):
        train(train_loader, net, criterion, optimizer, epoch, train_args)
Exemplo n.º 7
0
def main():
    """Train FCN8s on Cityscapes 'fine' annotations, optionally resuming from a snapshot."""
    net = FCN8s(num_classes=cityscapes.num_classes, caffe=True).cuda()

    if not args['snapshot']:
        curr_epoch = 1
        args['best_record'] = {'epoch': 0, 'val_loss': 1e10, 'acc': 0,
                               'acc_cls': 0, 'mean_iu': 0, 'fwavacc': 0}
    else:
        print('training resumes from ' + args['snapshot'])
        net.load_state_dict(torch.load(os.path.join(ckpt_path, exp_name, args['snapshot'])))
        # Snapshot filenames encode the metrics of the saved epoch:
        # epoch_<e>_loss_<l>_acc_<a>_acc-cls_<c>_mean-iu_<m>_fwavacc_<f>...
        snap = args['snapshot'].split('_')
        curr_epoch = int(snap[1]) + 1
        args['best_record'] = {'epoch': int(snap[1]), 'val_loss': float(snap[3]),
                               'acc': float(snap[5]), 'acc_cls': float(snap[7]),
                               'mean_iu': float(snap[9]), 'fwavacc': float(snap[11])}

    net.train()

    # Caffe-style preprocessing: BGR channel order, [0, 255] range, mean
    # subtraction only (std of 1.0 leaves the scale untouched).
    mean_std = ([103.939, 116.779, 123.68], [1.0, 1.0, 1.0])

    # Rescale the short side so that an input_size crop covers 87.5% of it.
    short_size = int(min(args['input_size']) / 0.875)

    train_joint_transform = joint_transforms.Compose([
        joint_transforms.Scale(short_size),
        joint_transforms.RandomCrop(args['input_size']),
        joint_transforms.RandomHorizontallyFlip(),
    ])
    val_joint_transform = joint_transforms.Compose([
        joint_transforms.Scale(short_size),
        joint_transforms.CenterCrop(args['input_size']),
    ])
    input_transform = standard_transforms.Compose([
        extended_transforms.FlipChannels(),
        standard_transforms.ToTensor(),
        standard_transforms.Lambda(lambda x: x.mul_(255)),
        standard_transforms.Normalize(*mean_std),
    ])
    target_transform = extended_transforms.MaskToTensor()
    # Inverse of input_transform, for turning tensors back into PIL images.
    restore_transform = standard_transforms.Compose([
        extended_transforms.DeNormalize(*mean_std),
        standard_transforms.Lambda(lambda x: x.div_(255)),
        standard_transforms.ToPILImage(),
        extended_transforms.FlipChannels(),
    ])
    visualize = standard_transforms.ToTensor()

    train_set = cityscapes.CityScapes('fine', 'train', joint_transform=train_joint_transform,
                                      transform=input_transform,
                                      target_transform=target_transform)
    train_loader = DataLoader(train_set, batch_size=args['train_batch_size'],
                              num_workers=8, shuffle=True)
    val_set = cityscapes.CityScapes('fine', 'val', joint_transform=val_joint_transform,
                                    transform=input_transform,
                                    target_transform=target_transform)
    val_loader = DataLoader(val_set, batch_size=args['val_batch_size'],
                            num_workers=8, shuffle=False)

    criterion = CrossEntropyLoss2d(size_average=False,
                                   ignore_index=cityscapes.ignore_label).cuda()

    # Biases train at twice the base learning rate and without weight decay.
    bias_params = [p for n, p in net.named_parameters() if n.endswith('bias')]
    weight_params = [p for n, p in net.named_parameters() if not n.endswith('bias')]
    optimizer = optim.Adam([
        {'params': bias_params, 'lr': 2 * args['lr']},
        {'params': weight_params, 'lr': args['lr'],
         'weight_decay': args['weight_decay']},
    ], betas=(args['momentum'], 0.999))

    if args['snapshot']:
        opt_state = torch.load(os.path.join(ckpt_path, exp_name, 'opt_' + args['snapshot']))
        optimizer.load_state_dict(opt_state)
        # Restore the current base learning rates after loading optimizer state.
        optimizer.param_groups[0]['lr'] = 2 * args['lr']
        optimizer.param_groups[1]['lr'] = args['lr']

    check_mkdir(ckpt_path)
    check_mkdir(os.path.join(ckpt_path, exp_name))
    log_name = str(datetime.datetime.now()) + '.txt'
    with open(os.path.join(ckpt_path, exp_name, log_name), 'w') as log_file:
        log_file.write(str(args) + '\n\n')

    scheduler = ReduceLROnPlateau(optimizer, 'min', patience=args['lr_patience'],
                                  min_lr=1e-10, verbose=True)
    for epoch in range(curr_epoch, args['epoch_num'] + 1):
        train(train_loader, net, criterion, optimizer, epoch, args)
        val_loss = validate(val_loader, net, criterion, optimizer, epoch, args,
                            restore_transform, visualize)
        scheduler.step(val_loss)
# --- Example 8 ---
def train_without_ignite(model,
                         loss,
                         batch_size,
                         img_size,
                         epochs,
                         lr,
                         num_workers,
                         optimizer,
                         logger,
                         gray_image=False,
                         scheduler=None,
                         viz=True):
    """Plain PyTorch train/eval loop over the Figaro segmentation data.

    Runs a 'train' and then a 'test' phase every epoch, logs the average
    loss per phase, and (optionally) pushes the last predicted mask of
    each phase to a local visdom server.

    Args:
        model: segmentation network; called as ``model(img)``.
        loss: criterion; called as ``loss(pred, mask)`` or, when
            ``gray_image`` is True, ``loss(pred, mask, gray)``.
        batch_size: train-phase batch size (test phase always uses 1).
        img_size: crop size for the train-time joint transforms.
        epochs: number of epochs to run.
        lr: unused here; the learning rate lives inside ``optimizer``.
        num_workers: DataLoader worker count for both phases.
        optimizer: optimizer stepped after each train batch.
        logger: logger receiving the per-epoch average-loss messages.
        gray_image: when True, each batch carries an extra grayscale
            tensor that is forwarded to ``loss``.
        scheduler: optional plateau-style scheduler stepped with the
            test-phase average loss.
        viz: when True, connect to visdom at http://localhost:8097.
    """
    import visdom
    from utils.metrics import Accuracy, IoU  # NOTE(review): imported but never used here

    DEFAULT_PORT = 8097
    DEFAULT_HOSTNAME = "http://localhost"

    if viz:
        vis = visdom.Visdom(port=DEFAULT_PORT, server=DEFAULT_HOSTNAME)

    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    data_loader = {}

    # Train-time augmentation applied jointly to image and mask.
    joint_transforms = jnt_trnsf.Compose([
        jnt_trnsf.RandomCrop(img_size),
        jnt_trnsf.RandomRotate(5),
        jnt_trnsf.RandomHorizontallyFlip()
    ])

    train_image_transforms = std_trnsf.Compose([
        std_trnsf.ColorJitter(0.05, 0.05, 0.05, 0.05),
        std_trnsf.ToTensor(),
        std_trnsf.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

    # Test images are only padded (presumably to a multiple of 32 — TODO confirm).
    test_joint_transforms = jnt_trnsf.Compose([jnt_trnsf.Safe32Padding()])

    test_image_transforms = std_trnsf.Compose([
        std_trnsf.ToTensor(),
        std_trnsf.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

    mask_transforms = std_trnsf.Compose([std_trnsf.ToTensor()])

    data_loader['train'] = get_loader(dataset='figaro',
                                      train=True,
                                      joint_transforms=joint_transforms,
                                      image_transforms=train_image_transforms,
                                      mask_transforms=mask_transforms,
                                      batch_size=batch_size,
                                      shuffle=True,
                                      num_workers=num_workers,
                                      gray_image=gray_image)

    data_loader['test'] = get_loader(dataset='figaro',
                                     train=False,
                                     joint_transforms=test_joint_transforms,
                                     image_transforms=test_image_transforms,
                                     mask_transforms=mask_transforms,
                                     batch_size=1,
                                     shuffle=True,
                                     num_workers=num_workers,
                                     gray_image=gray_image)

    for epoch in range(epochs):
        for phase in ['train', 'test']:
            if phase == 'train':
                model.train(True)
            else:
                # Disable autograd for the whole test phase; the previous
                # grad state is restored at the end of the phase below.
                prev_grad_state = torch.is_grad_enabled()
                torch.set_grad_enabled(False)
                model.train(False)

            running_loss = 0.0

            for i, data in enumerate(tqdm(data_loader[phase],
                                          file=sys.stdout)):
                # NOTE(review): the last batch is skipped, yet the average
                # below still divides by the full loader length — confirm
                # this is the intended drop-last workaround.
                if i == len(data_loader[phase]) - 1: break
                # Move tensors to the device; leave non-tensor items as-is.
                data_ = [
                    t.to(device) if isinstance(t, torch.Tensor) else t
                    for t in data
                ]

                if gray_image:
                    img, mask, gray = data_
                else:
                    img, mask = data_

                model.zero_grad()

                pred_mask = model(img)

                if gray_image:
                    l = loss(pred_mask, mask, gray)
                else:
                    l = loss(pred_mask, mask)

                # Backprop/step only during training; grads are globally
                # disabled during the test phase anyway.
                if phase == 'train':
                    l.backward()
                    optimizer.step()

                running_loss += l.item()

            epoch_loss = running_loss / len(data_loader[phase])

            if phase == 'train':
                logger.info(
                    f"Training Results - Epoch: {epoch} Avg-loss: {epoch_loss:.3f}"
                )
                if viz:
                    # Show the last prediction/target pair of this phase
                    # (pred_mask/mask deliberately leak out of the batch loop).
                    vis.images(
                        [
                            np.clip(pred_mask.detach().cpu().numpy()[0], 0, 1),
                            mask.detach().cpu().numpy()[0]
                        ],
                        opts=dict(title=f'pred img for {epoch}-th iter'))

            if phase == 'test':
                if viz:
                    vis.images(
                        [
                            np.clip(pred_mask.detach().cpu().numpy()[0], 0, 1),
                            mask.detach().cpu().numpy()[0]
                        ],
                        opts=dict(title=f'pred img for {epoch}-th iter'))
                logger.info(
                    f"Test Results - Epoch: {epoch} Avg-loss: {epoch_loss:.3f}"
                )

                # The plateau scheduler keys off the test-phase loss.
                if scheduler: scheduler.step(epoch_loss)

                # Restore the autograd state saved at the top of the phase.
                torch.set_grad_enabled(prev_grad_state)
def main(train_args):
    """Train PSPNet on Pascal VOC, optionally resuming from a stored snapshot."""
    net = PSPNet(num_classes=voc.num_classes).cuda()

    if not train_args['snapshot']:
        curr_epoch = 1
        train_args['best_record'] = {'epoch': 0, 'val_loss': 1e10, 'acc': 0,
                                     'acc_cls': 0, 'mean_iu': 0, 'fwavacc': 0}
    else:
        print('training resumes from ' + train_args['snapshot'])
        net.load_state_dict(torch.load(os.path.join(ckpt_path, exp_name,
                                                    train_args['snapshot'])))
        # Snapshot filename layout:
        # epoch_<e>_loss_<l>_acc_<a>_acc-cls_<c>_mean-iu_<m>_fwavacc_<f>...
        snap = train_args['snapshot'].split('_')
        curr_epoch = int(snap[1]) + 1
        train_args['best_record'] = {'epoch': int(snap[1]), 'val_loss': float(snap[3]),
                                     'acc': float(snap[5]), 'acc_cls': float(snap[7]),
                                     'mean_iu': float(snap[9]), 'fwavacc': float(snap[11])}

    net.train()

    # ImageNet statistics, shared by the input normaliser and DeNormalize.
    mean_std = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

    train_simul_transform = simul_transforms.Compose([
        simul_transforms.RandomSized(train_args['input_size']),
        simul_transforms.RandomRotate(10),
        simul_transforms.RandomHorizontallyFlip(),
    ])
    val_simul_transform = simul_transforms.Scale(train_args['input_size'])
    train_input_transform = standard_transforms.Compose([
        extended_transforms.RandomGaussianBlur(),
        standard_transforms.ToTensor(),
        standard_transforms.Normalize(*mean_std),
    ])
    val_input_transform = standard_transforms.Compose([
        standard_transforms.ToTensor(),
        standard_transforms.Normalize(*mean_std),
    ])
    target_transform = extended_transforms.MaskToTensor()
    restore_transform = standard_transforms.Compose([
        extended_transforms.DeNormalize(*mean_std),
        standard_transforms.ToPILImage(),
    ])
    visualize = standard_transforms.Compose([
        standard_transforms.Scale(400),
        standard_transforms.CenterCrop(400),
        standard_transforms.ToTensor(),
    ])

    train_set = voc.VOC('train', joint_transform=train_simul_transform,
                        transform=train_input_transform,
                        target_transform=target_transform)
    train_loader = DataLoader(train_set, batch_size=train_args['train_batch_size'],
                              num_workers=8, shuffle=True)
    val_set = voc.VOC('val', joint_transform=val_simul_transform,
                      transform=val_input_transform,
                      target_transform=target_transform)
    val_loader = DataLoader(val_set, batch_size=1, num_workers=8, shuffle=False)

    criterion = CrossEntropyLoss2d(size_average=True,
                                   ignore_index=voc.ignore_label).cuda()

    # Biases get twice the base learning rate and no weight decay.
    bias_params = [p for n, p in net.named_parameters() if n.endswith('bias')]
    weight_params = [p for n, p in net.named_parameters() if not n.endswith('bias')]
    optimizer = optim.SGD([
        {'params': bias_params, 'lr': 2 * train_args['lr']},
        {'params': weight_params, 'lr': train_args['lr'],
         'weight_decay': train_args['weight_decay']},
    ], momentum=train_args['momentum'])

    if train_args['snapshot']:
        optimizer.load_state_dict(torch.load(
            os.path.join(ckpt_path, exp_name, 'opt_' + train_args['snapshot'])))
        # Re-apply the current base learning rates after loading state.
        optimizer.param_groups[0]['lr'] = 2 * train_args['lr']
        optimizer.param_groups[1]['lr'] = train_args['lr']

    check_mkdir(ckpt_path)
    check_mkdir(os.path.join(ckpt_path, exp_name))
    with open(os.path.join(ckpt_path, exp_name,
                           str(datetime.datetime.now()) + '.txt'), 'w') as log_file:
        log_file.write(str(train_args) + '\n\n')

    train(train_loader, net, criterion, optimizer, curr_epoch, train_args,
          val_loader, restore_transform, visualize)
# --- Example 10 ---
def main(args):
    """Adapt a pretrained segmentation net to Cityscapes with a frozen
    image-translation generator and a frozen teacher copy of the net.

    Workflow: evaluate the pretrained net directly, then behind the
    generator; fine-tune the trainable net against pseudo-labels produced
    by the static teacher on generator-translated images; keep the best
    checkpoint by validation mIoU and re-evaluate it at the end.
    """
    writer = SummaryWriter(log_dir=args.tensorboard_log_dir)
    w, h = map(int, args.input_size.split(','))
    w_target, h_target = map(int, args.input_size_target.split(','))

    joint_transform = joint_transforms.Compose([
        joint_transforms.FreeScale((h, w)),
        joint_transforms.RandomHorizontallyFlip(),
    ])
    normalize = ((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    tgt_input_transform = standard_transforms.Compose([
        standard_transforms.ToTensor(),
        standard_transforms.Normalize(*normalize),
    ])
    target_transform = extended_transforms.MaskToTensor()
    restore_transform = standard_transforms.ToPILImage()

    if args.seg_net == 'fcn':
        # Caffe-style FCN preprocessing: BGR order, [0, 255] range,
        # mean subtraction only.
        mean_std = ([103.939, 116.779, 123.68], [1.0, 1.0, 1.0])
        val_input_transform = standard_transforms.Compose([
            extended_transforms.FreeScale((h, w)),
            extended_transforms.FlipChannels(),
            standard_transforms.ToTensor(),
            standard_transforms.Lambda(lambda x: x.mul_(255)),
            standard_transforms.Normalize(*mean_std),
        ])
    else:
        normalize = ((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
        val_input_transform = standard_transforms.Compose([
            extended_transforms.FreeScale((h, w)),
            standard_transforms.ToTensor(),
            standard_transforms.Normalize(*normalize),
        ])

    tgt_dataset = Cityscapes16DataSetLMDB(
        args.data_dir_target,
        args.data_list_target,
        joint_transform=joint_transform,
        transform=tgt_input_transform,
        target_transform=target_transform,
    )
    tgt_loader = data.DataLoader(tgt_dataset,
                                 batch_size=args.batch_size,
                                 shuffle=True,
                                 num_workers=args.num_workers,
                                 pin_memory=True,
                                 drop_last=True)
    val_dataset = Cityscapes16DataSetLMDB(
        args.data_dir_val,
        args.data_list_val,
        transform=val_input_transform,
        target_transform=target_transform,
    )
    val_loader = data.DataLoader(
        val_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_workers,
        pin_memory=True,
    )

    # Predictions are upsampled to full target resolution for evaluation.
    upsample = nn.Upsample(size=(h_target, w_target),
                           mode='bilinear',
                           align_corners=True)

    if args.seg_net == 'fcn':
        net = FCN8s(args.n_classes, pretrained=False)
        net_static = FCN8s(args.n_classes, pretrained=False)
        file_name = os.path.join(args.resume, args.fcn_name)
        # for name, param in net.named_parameters():
        #     if 'feat' not in name:
        #         param.requires_grad = False
    elif args.seg_net == 'deeplab_ibn':
        # Fix: the original assigned to an unused local `deeplab`, leaving
        # `net`/`net_static` unbound and raising NameError at
        # load_state_dict below.
        net = resnet101_ibn_a_deeplab()
        net_static = resnet101_ibn_a_deeplab()
        file_name = os.path.join(args.resume, 'deeplab_ibn.pth')
    net.load_state_dict(torch.load(file_name))
    net_static.load_state_dict(torch.load(file_name))
    # The static copy is a frozen pseudo-label teacher.
    for param in net_static.parameters():
        param.requires_grad = False

    optimizer = torch.optim.SGD(net.parameters(), args.learning_rate,
                                args.momentum)
    net = torch.nn.DataParallel(net.cuda())
    net_static = torch.nn.DataParallel(net_static.cuda())
    # Pixels labelled 255 are excluded from the loss.
    criterion = torch.nn.CrossEntropyLoss(ignore_index=255)

    # Frozen image-translation generator.
    gen_model = define_G()
    gen_model.load_state_dict(
        torch.load(os.path.join(args.resume, args.gen_name)))
    gen_model.eval()
    for param in gen_model.parameters():
        param.requires_grad = False
    gen_model = torch.nn.DataParallel(gen_model.cuda())

    # for seg net
    # NOTE: this function shadows the `normalize` tuple defined above.
    def normalize(x, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
        """Map a [0, 1] RGB batch into the seg net's expected input space."""
        if args.seg_net == 'fcn':
            mean = [103.939, 116.779, 123.68]
            # Reverse channel order: RGB -> BGR.
            flip_x = torch.cat(
                [x[:, 2 - i, :, :].unsqueeze(1) for i in range(3)],
                dim=1,
            )
            new_x = []
            for tem_x in flip_x:
                tem_new_x = []
                for c, m in zip(tem_x, mean):
                    tem_new_x.append(c.mul(255.0).sub(m).unsqueeze(0))
                new_x.append(torch.cat(tem_new_x, dim=0).unsqueeze(0))
            new_x = torch.cat(new_x, dim=0)
            return new_x
        else:
            # Fix: the original used out-of-place `c.sub(m).div(s)` and
            # discarded the result, returning x unchanged; in-place ops
            # perform the intended per-channel normalisation.
            for tem_x in x:
                for c, m, s in zip(tem_x, mean, std):
                    c.sub_(m).div_(s)
            return x

    def de_normalize(x, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)):
        """Invert Normalize(mean, std): c * std + mean, per channel."""
        new_x = []
        for tem_x in x:
            tem_new_x = []
            for c, m, s in zip(tem_x, mean, std):
                # Fix: the original added `s` instead of `m` (numerically
                # identical for the 0.5/0.5 defaults, but semantically the
                # inverse transform is c * std + mean).
                tem_new_x.append(c.mul(s).add(m).unsqueeze(0))
            new_x.append(torch.cat(tem_new_x, dim=0).unsqueeze(0))
        new_x = torch.cat(new_x, dim=0)
        return new_x

    # ###################################################
    # direct test with gen
    # ###################################################
    print('Direct Test')
    mean_iu = test_miou(net, val_loader, upsample, './dataset/info.json')
    direct_input_transform = standard_transforms.Compose([
        extended_transforms.FreeScale((h, w)),
        standard_transforms.ToTensor(),
        standard_transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    val_dataset_direct = Cityscapes16DataSetLMDB(
        args.data_dir_val,
        args.data_list_val,
        transform=direct_input_transform,
        target_transform=target_transform,
    )
    val_loader_direct = data.DataLoader(val_dataset_direct,
                                        batch_size=args.batch_size,
                                        shuffle=False,
                                        num_workers=args.num_workers,
                                        pin_memory=True,
                                        drop_last=False)

    class NewModel(object):
        """Chains generator + seg net so test_miou can call it like a model."""

        def __init__(self, gen_net, val_net):
            self.gen_net = gen_net
            self.val_net = val_net

        def __call__(self, x):
            x = de_normalize(self.gen_net(x))
            new_x = normalize(x)
            out = self.val_net(new_x)
            return out

        def eval(self):
            self.gen_net.eval()
            self.val_net.eval()

    new_model = NewModel(gen_model, net)
    print('Test with Gen')
    mean_iu = test_miou(new_model, val_loader_direct, upsample,
                        './dataset/info.json')

    num_batches = len(tgt_loader)
    highest = 0

    for epoch in range(args.num_epoch):

        loss_rec = AverageMeter()
        data_time_rec = AverageMeter()
        batch_time_rec = AverageMeter()

        tem_time = time.time()
        for batch_index, batch_data in enumerate(tgt_loader):
            iteration = batch_index + 1 + epoch * num_batches

            net.train()
            net_static.eval()  # fine-tune use eval

            img, _, name = batch_data
            img = img.cuda()
            data_time_rec.update(time.time() - tem_time)

            # Teacher logits on the translated image; no grads needed.
            with torch.no_grad():
                gen_output = gen_model(img)
                gen_seg_output_logits = net_static(
                    normalize(de_normalize(gen_output)))
            ori_seg_output_logits = net(normalize(de_normalize(img)))

            prob = torch.nn.Softmax(dim=1)
            max_value, label = torch.max(prob(gen_seg_output_logits), dim=1)
            # Fix: use a bool mask (uint8 mask indexing is deprecated, and
            # comparisons already yield bool tensors).
            label_mask = torch.zeros(label.shape, dtype=torch.bool).cuda()
            for tem_label in range(19):
                tem_mask = label == tem_label
                # Skip classes with too few pixels to threshold reliably.
                if torch.sum(tem_mask) < 5:
                    continue
                value_vec = max_value[tem_mask]
                # Fix: guard k >= 1 — int(percent * n) can round to 0 for
                # small classes, and topk(k=0) would make [0][0] raise.
                k = max(int(args.percent * value_vec.shape[0]), 1)
                # NOTE(review): [0][0] takes the LARGEST confidence, not the
                # k-th largest; confirm whether [0][-1] (the percent
                # threshold) was intended here.
                large_value = torch.topk(value_vec, k)[0][0]
                large_mask = max_value > large_value
                label_mask = label_mask | (tem_mask & large_mask)
            # Selected pixels are marked 255 and ignored by the criterion.
            label[label_mask] = 255

            # loss = criterion(ori_seg_output_logits, gen_seg_output_logits)
            loss = criterion(ori_seg_output_logits, label)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            loss_rec.update(loss.item())
            writer.add_scalar('A_seg_loss', loss.item(), iteration)
            batch_time_rec.update(time.time() - tem_time)
            tem_time = time.time()

            if (batch_index + 1) % args.print_freq == 0:
                print(
                    f'Epoch [{epoch+1:d}/{args.num_epoch:d}][{batch_index+1:d}/{num_batches:d}]\t'
                    f'Time: {batch_time_rec.avg:.2f}   '
                    f'Data: {data_time_rec.avg:.2f}   '
                    f'Loss: {loss_rec.avg:.2f}')
            if iteration % args.checkpoint_freq == 0:
                # Periodically validate and keep the best checkpoint.
                mean_iu = test_miou(net,
                                    val_loader,
                                    upsample,
                                    './dataset/info.json',
                                    print_results=False)
                if mean_iu > highest:
                    torch.save(
                        net.module.state_dict(),
                        os.path.join(args.save_path_prefix,
                                     'cityscapes_best_fcn.pth'))
                    highest = mean_iu
                    print(f'save fcn model with {mean_iu:.2%}')

    print(('-' * 100 + '\n') * 3)
    print('>' * 50 + 'Final Model')
    # Reload the best checkpoint and report its final validation mIoU.
    net.module.load_state_dict(
        torch.load(
            os.path.join(args.save_path_prefix, 'cityscapes_best_fcn.pth')))
    mean_iu = test_miou(net, val_loader, upsample, './dataset/info.json')

    writer.close()
# --- Example 11 ---
    def load_dataset(self):
        """Build the train and valid DataLoaders for the merged saliency data.

        Chooses ImageNet RGB statistics when inputs are normalised to
        [0, 1] (``args.input_normalize``), otherwise raw 0-255 statistics.
        Populates ``self.train_loader``, ``self.valid_loader`` and the
        optional weighted samplers (``self.train_sampler`` /
        ``self.valid_sampler``).
        """
        # set mean and std value from ImageNet dataset
        if self.args.input_normalize:
            rgb_mean, rgb_std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
        else:
            rgb_mean, rgb_std = [122.675, 116.669,
                                 104.008], [58.395, 57.12, 57.375]

        # train joint transforms
        train_jts = joint_transforms.Compose(
            [  # joint_transforms.ElasticTransform(),
                joint_transforms.Resize(self.args.max_size),
                # Fix: `np.float` was removed in NumPy 1.24; it was an alias
                # for the builtin float, i.e. float64, so this is equivalent.
                joint_transforms.Apply(tv_transforms.Lambda(
                    lambda img: img.astype(np.float64) / 255),
                                       th=[False, True]),
                joint_transforms.RandomCrop(self.args.crop_size,
                                            ignore_label=255),
                joint_transforms.RandomHorizontallyFlip(),
            ])
        # train source transforms
        train_sts = tv_transforms.Compose(
            [  # extend_transforms.RandomGaussianBlur(blur_prob=0.1),
                # extend_transforms.RandomBright(),
                extend_transforms.ImageToTensor(self.args.input_normalize),
                tv_transforms.Normalize(mean=rgb_mean, std=rgb_std)
            ])
        # train target transforms
        # train_tts = extend_transforms.MapToTensor()
        train_tts = extend_transforms.GrayImageToTensor(False)

        train_dataset = saliency.SaliencyMergedData(
            root=self.args.dataset_dir,
            phase='train',
            dataset_list=self.args.datasets,
            joint_transform=train_jts,
            source_transform=train_sts,
            target_transform=train_tts,
        )

        # Optional per-dataset weighted sampling; shuffle must be off when
        # a sampler is supplied.
        self.train_sampler = None
        if self.args.weighted_sampler:
            self.train_sampler = sampler.DatasetWeightedSampler(
                train_dataset.weight_list, self.args.train_num_sampler)

        self.train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=self.args.batch_size,
            drop_last=True,
            num_workers=self.args.workers,
            shuffle=(self.train_sampler is None),
            pin_memory=True,
            sampler=self.train_sampler,
        )

        # valid joint transforms
        valid_jts = joint_transforms.Compose([
            joint_transforms.Resize(self.args.crop_size),
            # joint_transforms.RandomCrop(args.crop_size),
        ])
        # valid source transforms
        valid_sts = tv_transforms.Compose([
            extend_transforms.ImageToTensor(self.args.input_normalize),
            tv_transforms.Normalize(mean=rgb_mean, std=rgb_std)
        ])
        # valid target transforms
        valid_tts = extend_transforms.GrayImageToTensor()

        valid_dataset = saliency.SaliencyMergedData(
            root=self.args.dataset_dir,
            phase='valid',
            dataset_list=self.args.datasets,
            joint_transform=valid_jts,
            source_transform=valid_sts,
            target_transform=valid_tts,
        )

        self.valid_sampler = None
        if self.args.weighted_sampler:
            self.valid_sampler = sampler.DatasetWeightedSampler(
                valid_dataset.weight_list, self.args.valid_num_sampler)

        self.valid_loader = torch.utils.data.DataLoader(
            valid_dataset,
            drop_last=True,
            batch_size=self.args.test_batch_size,
            num_workers=self.args.workers,
            shuffle=False,
            pin_memory=True,
            sampler=self.valid_sampler,
        )
def main():
    """Train a VGG-backed FCN8s on Cityscapes, validating after every epoch."""
    torch.backends.cudnn.benchmark = True
    os.environ["CUDA_VISIBLE_DEVICES"] = '0,1'
    device = torch.device('cuda:0' if torch.cuda.is_available() else "cpu")

    # Seed every RNG so runs are repeatable.
    seed = 63
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    # ImageNet statistics shared by Normalize and DeNormalize.
    mean_std = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    # Short side for the validation rescale: the crop covers 87.5% of it.
    short_size = int(min(args['input_size']) / 0.875)

    train_joint_transform = joint_transforms.Compose([
        joint_transforms.RandomCrop(args['crop_size']),
        joint_transforms.RandomHorizontallyFlip(),
        joint_transforms.RandomRotate(90),
    ])
    val_joint_transform = joint_transforms.Compose([
        joint_transforms.Scale(short_size),
        joint_transforms.CenterCrop(args['input_size']),
    ])
    input_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(*mean_std),
    ])
    target_transform = extended_transforms.MaskToTensor()
    restore_transform = transforms.Compose([
        extended_transforms.DeNormalize(*mean_std),
        transforms.ToPILImage(),
    ])
    visualize = transforms.ToTensor()

    train_set = cityscapes.CityScapes('train',
                                      joint_transform=train_joint_transform,
                                      transform=input_transform,
                                      target_transform=target_transform)
    train_loader = DataLoader(train_set,
                              batch_size=args['train_batch_size'],
                              num_workers=8,
                              shuffle=True)
    val_set = cityscapes.CityScapes('val',
                                    joint_transform=val_joint_transform,
                                    transform=input_transform,
                                    target_transform=target_transform)
    # NOTE: validation keeps shuffle=True, matching the original behaviour.
    val_loader = DataLoader(val_set,
                            batch_size=args['val_batch_size'],
                            num_workers=8,
                            shuffle=True)

    print(len(train_loader), len(val_loader))

    vgg_model = VGGNet(requires_grad=True, remove_fc=True)
    net = FCN8s(pretrained_net=vgg_model,
                n_class=cityscapes.num_classes,
                dropout_rate=0.4)
    criterion = nn.CrossEntropyLoss(ignore_index=cityscapes.ignore_label)

    optimizer = optim.Adam(net.parameters(), lr=1e-4)

    check_mkdir(ckpt_path)
    check_mkdir(os.path.join(ckpt_path, exp_name))
    with open(os.path.join(ckpt_path, exp_name,
                           str(datetime.datetime.now()) + '.txt'), 'w') as log_file:
        log_file.write(str(args) + '\n\n')

    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, 'min', patience=args['lr_patience'], min_lr=1e-10)

    vgg_model = vgg_model.to(device)
    net = net.to(device)

    if torch.cuda.device_count() > 1:
        net = nn.DataParallel(net)

    if not args['snapshot']:
        curr_epoch = 1
        args['best_record'] = {
            'epoch': 0,
            'val_loss': 1e10,
            'acc': 0,
            'acc_cls': 0,
            'mean_iu': 0
        }
    else:
        print('training resumes from ' + args['snapshot'])
        net.load_state_dict(
            torch.load(os.path.join(ckpt_path, exp_name, args['snapshot'])))
        # Snapshot name layout: epoch_<e>_loss_<l>_acc_<a>_acc-cls_<c>_mean-iu_<m>.pth
        snap = args['snapshot'].split('_')
        curr_epoch = int(snap[1]) + 1
        args['best_record'] = {
            'epoch': int(snap[1]),
            'val_loss': float(snap[3]),
            'acc': float(snap[5]),
            'acc_cls': float(snap[7]),
            # Strip the '.pth' extension from the final field.
            'mean_iu': float(snap[9][:-4])
        }

    criterion.to(device)

    for epoch in range(curr_epoch, args['epoch_num'] + 1):
        train(train_loader, net, device, criterion, optimizer, epoch, args)
        val_loss = validate(val_loader, net, device, criterion, optimizer,
                            epoch, args, restore_transform, visualize)
        scheduler.step(val_loss)
# --- Example 13 ---
def main(train_args):
    """Train PSPNet on Pascal VOC with a plain SGD train/validate loop.

    Expects ``train_args`` to carry 'snapshot', 'lr', 'momentum',
    'weight_decay' and 'epoch_num'; fills in 'best_record' here.  Resumes
    from a snapshot file whose name encodes the metrics when 'snapshot'
    is non-empty.
    """
    net = PSPNet(num_classes=voc.num_classes).cuda()
    if len(train_args['snapshot']) == 0:
        # Fresh run: start at epoch 1 with a sentinel best-record.
        curr_epoch = 1
        train_args['best_record'] = {
            'epoch': 0,
            'val_loss': 1e10,
            'acc': 0,
            'acc_cls': 0,
            'mean_iu': 0,
            'fwavacc': 0
        }
    else:
        print('training resumes from ' + train_args['snapshot'])
        net.load_state_dict(
            torch.load(
                os.path.join(ckpt_path, exp_name, train_args['snapshot'])))
        # Snapshot filename layout (underscore-separated):
        # epoch_<e>_loss_<l>_acc_<a>_acc-cls_<c>_mean-iu_<m>_fwavacc_<f>
        split_snapshot = train_args['snapshot'].split('_')
        curr_epoch = int(split_snapshot[1]) + 1
        train_args['best_record'] = {
            'epoch': int(split_snapshot[1]),
            'val_loss': float(split_snapshot[3]),
            'acc': float(split_snapshot[5]),
            'acc_cls': float(split_snapshot[7]),
            'mean_iu': float(split_snapshot[9]),
            'fwavacc': float(split_snapshot[11])
        }

    # ImageNet channel statistics.  Reused by every Normalize/DeNormalize
    # below so the constants cannot drift apart (they were previously
    # duplicated inline as [.485, .456, .406] / [.229, .224, .225]).
    mean_std = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

    input_transform = standard_transforms.Compose([
        ToTensor(),
        Normalize(*mean_std),
    ])
    joint_transform = joint_transforms.Compose([
        joint_transforms.CenterCrop(224),
        # joint_transforms.Scale(2),
        joint_transforms.RandomHorizontallyFlip(),
    ])
    target_transform = standard_transforms.Compose([
        extended_transforms.MaskToTensor(),
    ])
    # Undo normalization for visualisation/logging inside validate().
    restore_transform = standard_transforms.Compose([
        extended_transforms.DeNormalize(*mean_std),
        standard_transforms.ToPILImage(),
    ])
    visualize = standard_transforms.Compose([
        standard_transforms.Scale(400),
        standard_transforms.CenterCrop(400),
        standard_transforms.ToTensor()
    ])
    # Validation: deterministic center crop, no flip.
    val_input_transform = standard_transforms.Compose([
        CenterCrop(224),
        ToTensor(),
        Normalize(*mean_std),
    ])
    val_target_transform = standard_transforms.Compose([
        CenterCrop(224),
        extended_transforms.MaskToTensor(),
    ])
    train_set = voc.VOC('train',
                        transform=input_transform,
                        target_transform=target_transform,
                        joint_transform=joint_transform)
    train_loader = DataLoader(train_set,
                              batch_size=4,
                              num_workers=4,
                              shuffle=True)
    val_set = voc.VOC('val',
                      transform=val_input_transform,
                      target_transform=val_target_transform)
    val_loader = DataLoader(val_set,
                            batch_size=4,
                            num_workers=4,
                            shuffle=False)

    # criterion = CrossEntropyLoss2d(size_average=True, ignore_index=voc.ignore_label).cuda()
    criterion = torch.nn.CrossEntropyLoss(ignore_index=voc.ignore_label).cuda()
    optimizer = optim.SGD(net.parameters(),
                          lr=train_args['lr'],
                          momentum=train_args['momentum'],
                          weight_decay=train_args['weight_decay'])

    check_mkdir(ckpt_path)
    check_mkdir(os.path.join(ckpt_path, exp_name))
    # open(os.path.join(ckpt_path, exp_name, 'loss_001_aux_SGD_momentum_95_random_lr_001.txt'), 'w').write(str(train_args) + '\n\n')

    for epoch in range(curr_epoch, train_args['epoch_num'] + 1):
        # adjust_learning_rate(optimizer,epoch,net,train_args)
        train(train_loader, net, criterion, optimizer, epoch, train_args)
        validate(val_loader, net, criterion, optimizer, epoch, train_args,
                 restore_transform, visualize)
        adjust_learning_rate(optimizer, epoch, net, train_args)
Exemplo n.º 14
0
# Run on GPU when one is available; everything downstream uses this device.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Encoder channel widths.  NOTE(review): presumably consumed by the U-Net /
# probabilistic-U-Net constructors outside this chunk — confirm.
num_filters = [16, 32]
batch_size = 32
lr = 1e-4
weight_decay = 1e-5
epochs = 600

# unet setting
num_conv_blocks = 1
# prob unet only
num_convs_per_block = num_conv_blocks
num_convs_fcomb = 1
partial_data = False
# Dimensionality of the probabilistic U-Net latent space.
latent_dim = 6
# KL-divergence weight in the ELBO loss (prob U-Net).
beta = 10
# isotropic = False

# Weight/bias init schemes; alternatives: kaiming_normal and orthogonal.
initializers = {'w': 'kaiming_normal', 'b': 'normal'}
# initializers = {'w':None, 'b':None}

# Transforms applied jointly to image and mask (augmentation).
joint_transfm = joint_transforms.Compose([joint_transforms.RandomHorizontallyFlip(),
                                          joint_transforms.RandomSizedCrop(128),
                                          joint_transforms.RandomRotate(60)])
# joint_transfm = None
input_transfm = transforms.Compose([transforms.ToPILImage()])
target_transfm = transforms.Compose([transforms.ToTensor()])
# joint_transfm=None
# input_transfm=None
Exemplo n.º 15
0
def get_transforms(scale_size, input_size, region_size, supervised, test,
                   al_algorithm, full_res, dataset):
    """Build the transform pipelines for (active-learning) segmentation.

    Returns a 5-tuple ``(input_transform, target_transform,
    train_joint_transform, val_joint_transform, al_train_joint_transform)``.
    ``scale_size == 0`` disables rescaling before the crops.
    """
    mean_std = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    rescale = scale_size != 0

    def prefix():
        # Fresh rescale step (or nothing) to prepend to each joint pipeline.
        return [joint_transforms.Scale(scale_size)] if rescale else []

    if rescale:
        print('(Data loading) Scaling training data: ' + str(scale_size) +
              ' width dimension')
        print('(Data loading) Random crops of ' + str(input_size) +
              ' in training')
        print('(Data loading) No crops nor scale_size in validation')
    else:
        print('(Data loading) Not scaling the data')
        print('(Data loading) Random crops of ' + str(input_size) +
              ' in training')
        print('(Data loading) No crops in validation')

    # Training: supervised runs use plain crops, otherwise region-aware ones.
    if supervised:
        train_joint_transform = joint_transforms.Compose(prefix() + [
            joint_transforms.RandomCrop(input_size),
            joint_transforms.RandomHorizontallyFlip()
        ])
    else:
        train_joint_transform = joint_transforms.ComposeRegion(prefix() + [
            joint_transforms.RandomCropRegion(input_size,
                                              region_size=region_size),
            joint_transforms.RandomHorizontallyFlip()
        ])

    # Active-learning training always crops whole regions.
    al_train_joint_transform = joint_transforms.ComposeRegion(prefix() + [
        joint_transforms.CropRegion(region_size, region_size=region_size),
        joint_transforms.RandomHorizontallyFlip()
    ])

    # Validation transform: only two special cases, otherwise untouched input.
    if not rescale and (not test and al_algorithm == 'ralis') and not full_res:
        val_joint_transform = joint_transforms.Scale(1024)
    elif rescale and dataset == 'gta_for_camvid':
        val_joint_transform = joint_transforms.ComposeRegion(
            [joint_transforms.Scale(scale_size)])
    else:
        val_joint_transform = None

    input_transform = standard_transforms.Compose([
        standard_transforms.ToTensor(),
        standard_transforms.Normalize(*mean_std)
    ])
    target_transform = extended_transforms.MaskToTensor()

    return input_transform, target_transform, train_joint_transform, val_joint_transform, al_train_joint_transform
Exemplo n.º 16
0
def main(args):
    """Supervised segmentation training on GTA5 or Cityscapes LMDB data.

    Chooses the training set from ``args.data_dir`` (a '5' in the path
    selects GTA5), trains a ResNet-101-IBN DeepLab for
    ``args.num_epoch`` epochs, evaluates mIoU on the Cityscapes target
    split after every epoch, and checkpoints each epoch's weights.
    """
    writer = SummaryWriter(log_dir=args.tensorboard_log_dir)
    w, h = map(int, args.input_size.split(','))
    w_target, h_target = map(int, args.input_size_target.split(','))

    joint_transform = joint_transforms.Compose([
        joint_transforms.FreeScale((h, w)),
        joint_transforms.RandomHorizontallyFlip(),
    ])
    # ImageNet channel statistics.
    normalize = ((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    input_transform = standard_transforms.Compose([
        standard_transforms.ToTensor(),
        standard_transforms.Normalize(*normalize),
    ])
    target_transform = extended_transforms.MaskToTensor()
    restore_transform = standard_transforms.ToPILImage()

    if '5' in args.data_dir:
        dataset = GTA5DataSetLMDB(
            args.data_dir,
            args.data_list,
            joint_transform=joint_transform,
            transform=input_transform,
            target_transform=target_transform,
        )
    else:
        dataset = CityscapesDataSetLMDB(
            args.data_dir,
            args.data_list,
            joint_transform=joint_transform,
            transform=input_transform,
            target_transform=target_transform,
        )
    loader = data.DataLoader(dataset,
                             batch_size=args.batch_size,
                             shuffle=True,
                             num_workers=args.num_workers,
                             pin_memory=True)
    # NOTE(review): shuffle=True on the validation loader is unusual for
    # evaluation — mIoU is order-independent, but confirm it is intentional.
    val_dataset = CityscapesDataSetLMDB(
        args.data_dir_target,
        args.data_list_target,
        # joint_transform=joint_transform,
        transform=input_transform,
        target_transform=target_transform)
    val_loader = data.DataLoader(val_dataset,
                                 batch_size=args.batch_size,
                                 shuffle=True,
                                 num_workers=args.num_workers,
                                 pin_memory=True)

    # Upsample logits to the target resolution before scoring mIoU.
    upsample = nn.Upsample(size=(h_target, w_target),
                           mode='bilinear',
                           align_corners=True)

    net = resnet101_ibn_a_deeplab(args.model_path_prefix,
                                  n_classes=args.n_classes)
    # optimizer = get_seg_optimizer(net, args)
    optimizer = torch.optim.SGD(net.parameters(), args.learning_rate,
                                args.momentum)
    net = torch.nn.DataParallel(net)
    # `size_average=False` is deprecated; reduction='sum' is its documented
    # equivalent (loss summed over all non-ignored pixels).
    criterion = torch.nn.CrossEntropyLoss(reduction='sum',
                                          ignore_index=args.ignore_index)

    num_batches = len(loader)
    for epoch in range(args.num_epoch):

        loss_rec = AverageMeter()
        data_time_rec = AverageMeter()
        batch_time_rec = AverageMeter()

        tem_time = time.time()
        for batch_index, batch_data in enumerate(loader):
            show_fig = (batch_index + 1) % args.show_img_freq == 0
            iteration = batch_index + 1 + epoch * num_batches

            # poly_lr_scheduler(
            #     optimizer=optimizer,
            #     init_lr=args.learning_rate,
            #     iter=iteration - 1,
            #     lr_decay_iter=args.lr_decay,
            #     max_iter=args.num_epoch*num_batches,
            #     power=args.poly_power,
            # )

            net.train()
            # net.module.freeze_bn()
            img, label, name = batch_data
            img = img.cuda()
            label_cuda = label.cuda()
            data_time_rec.update(time.time() - tem_time)

            output = net(img)
            loss = criterion(output, label_cuda)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            loss_rec.update(loss.item())
            writer.add_scalar('A_seg_loss', loss.item(), iteration)
            batch_time_rec.update(time.time() - tem_time)
            tem_time = time.time()

            if (batch_index + 1) % args.print_freq == 0:
                print(
                    f'Epoch [{epoch+1:d}/{args.num_epoch:d}][{batch_index+1:d}/{num_batches:d}]\t'
                    f'Time: {batch_time_rec.avg:.2f}   '
                    f'Data: {data_time_rec.avg:.2f}   '
                    f'Loss: {loss_rec.avg:.2f}')
            if show_fig:
                # Log a predicted-vs-ground-truth figure to TensorBoard.
                base_lr = optimizer.param_groups[0]["lr"]
                output = torch.argmax(output, dim=1).detach()[0, ...].cpu()
                fig, axes = plt.subplots(2, 1, figsize=(12, 14))
                axes = axes.flat
                axes[0].imshow(colorize_mask(output.numpy()))
                axes[0].set_title(name[0])
                axes[1].imshow(colorize_mask(label[0, ...].numpy()))
                axes[1].set_title(f'seg_true_{base_lr:.6f}')
                writer.add_figure('A_seg', fig, iteration)

        # End-of-epoch evaluation and checkpoint (mIoU encoded in filename).
        mean_iu = test_miou(net, val_loader, upsample,
                            './ae_seg/dataset/info.json')
        torch.save(
            net.module.state_dict(),
            os.path.join(args.save_path_prefix,
                         f'{epoch:d}_{mean_iu*100:.0f}.pth'))

    writer.close()
Exemplo n.º 17
0
def main():
    """Train an FCN32-VGG segmentation net on Mapillary Vistas.

    Reads hyper-parameters from the module-level ``args`` dict, optionally
    resumes model and optimizer state from a snapshot, and runs the
    train/validate loop with a reduce-on-plateau LR schedule.
    """
    net = FCN32VGG(num_classes=mapillary.num_classes).cuda()

    snapshot = args['snapshot']
    if len(snapshot) == 0:
        # Fresh run: epoch 1 with a sentinel best-record.
        curr_epoch = 1
        args['best_record'] = {
            'epoch': 0,
            'val_loss': 1e10,
            'acc': 0,
            'acc_cls': 0,
            'mean_iu': 0,
            'fwavacc': 0
        }
    else:
        # Resume: restore weights and parse metrics encoded in the filename
        # (epoch_<e>_loss_<l>_acc_<a>_acc-cls_<c>_mean-iu_<m>_fwavacc_<f>).
        print('training resumes from ' + snapshot)
        net.load_state_dict(
            torch.load(os.path.join(ckpt_path, exp_name, snapshot)))
        parts = snapshot.split('_')
        curr_epoch = int(parts[1]) + 1
        args['best_record'] = {
            'epoch': int(parts[1]),
            'val_loss': float(parts[3]),
            'acc': float(parts[5]),
            'acc_cls': float(parts[7]),
            'mean_iu': float(parts[9]),
            'fwavacc': float(parts[11])
        }
    net.train()

    mean_std = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    # Scale the shorter side so the crop covers 87.5% of it (the classic
    # ImageNet crop ratio).
    short_size = int(min(args['input_size']) / 0.875)

    train_joint_transform = joint_transforms.Compose([
        joint_transforms.Scale(short_size),
        joint_transforms.RandomCrop(args['input_size']),
        joint_transforms.RandomHorizontallyFlip()
    ])
    val_joint_transform = joint_transforms.Compose([
        joint_transforms.Scale(short_size),
        joint_transforms.CenterCrop(args['input_size'])
    ])
    input_transform = standard_transforms.Compose([
        standard_transforms.ToTensor(),
        standard_transforms.Normalize(*mean_std)
    ])
    target_transform = extended_transforms.MaskToTensor()
    restore_transform = standard_transforms.Compose([
        extended_transforms.DeNormalize(*mean_std),
        standard_transforms.ToPILImage()
    ])
    visualize = standard_transforms.ToTensor()

    # Build both loaders from one parameter table.
    loaders = {}
    for split, joint_t, bs, shuffle in (
            ('training', train_joint_transform, args['train_batch_size'],
             True),
            ('validation', val_joint_transform, args['val_batch_size'],
             False)):
        split_set = mapillary.Mapillary('semantic',
                                        split,
                                        joint_transform=joint_t,
                                        transform=input_transform,
                                        target_transform=target_transform)
        loaders[split] = DataLoader(split_set,
                                    batch_size=bs,
                                    num_workers=8,
                                    shuffle=shuffle,
                                    pin_memory=True)
    train_loader = loaders['training']
    val_loader = loaders['validation']

    criterion = CrossEntropyLoss2d(size_average=False).cuda()

    # Biases get twice the base LR and no weight decay (a common FCN recipe).
    bias_params = [
        param for name, param in net.named_parameters() if name[-4:] == 'bias'
    ]
    weight_params = [
        param for name, param in net.named_parameters() if name[-4:] != 'bias'
    ]
    optimizer = optim.SGD([{
        'params': bias_params,
        'lr': 2 * args['lr']
    }, {
        'params': weight_params,
        'lr': args['lr'],
        'weight_decay': args['weight_decay']
    }],
                          momentum=args['momentum'])

    if len(snapshot) > 0:
        # Restore optimizer state, then re-pin the LRs from args.
        optimizer.load_state_dict(
            torch.load(os.path.join(ckpt_path, exp_name, 'opt_' + snapshot)))
        optimizer.param_groups[0]['lr'] = 2 * args['lr']
        optimizer.param_groups[1]['lr'] = args['lr']

    check_mkdir(ckpt_path)
    check_mkdir(os.path.join(ckpt_path, exp_name))
    # Log the run configuration; ':' is not allowed in Windows filenames.
    log_name = str(datetime.datetime.now()).replace(':', '-') + '.txt'
    with open(os.path.join(ckpt_path, exp_name, log_name), 'w') as log_file:
        log_file.write(str(args) + '\n\n')

    scheduler = ReduceLROnPlateau(optimizer,
                                  'min',
                                  patience=args['lr_patience'],
                                  min_lr=1e-10)
    for epoch in range(curr_epoch, args['epoch_num'] + 1):
        train(train_loader, net, criterion, optimizer, epoch, args)
        val_loss = validate(val_loader, net, criterion, optimizer, epoch, args,
                            restore_transform, visualize)
        scheduler.step(val_loss)

    torch.save(net.state_dict(), PATH)
Exemplo n.º 18
0
def train_with_ignite(networks, dataset, data_dir, batch_size, img_size,
                      epochs, lr, momentum, num_workers, optimizer, logger):
    """Train a binary-segmentation model with pytorch-ignite.

    Builds the model, BCE-with-logits loss, optimizer and LR scheduler,
    constructs the train/test loaders, wires ignite event handlers that
    log metrics and checkpoint after each epoch, and runs training for
    ``epochs`` epochs.
    """

    # Local imports keep the ignite dependency optional for callers that
    # never use this training path.
    from ignite.engine import Events, create_supervised_trainer, create_supervised_evaluator
    from ignite.metrics import Loss
    from utils.metrics import MultiThresholdMeasures, Accuracy, IoU, F1score

    # device
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # build model
    model = get_network(networks)

    # log model summary
    input_size = (3, img_size, img_size)
    summarize_model(model.to(device), input_size, logger, batch_size, device)

    # build loss (binary segmentation: logits vs. 0/1 mask)
    loss = torch.nn.BCEWithLogitsLoss()

    # build optimizer and scheduler
    model_optimizer = get_optimizer(optimizer, model, lr, momentum)
    lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(model_optimizer)

    # transforms on both image and mask (training augmentation)
    train_joint_transforms = jnt_trnsf.Compose([
        jnt_trnsf.RandomCrop(img_size),
        jnt_trnsf.RandomRotate(5),
        jnt_trnsf.RandomHorizontallyFlip()
    ])

    # transforms only on images
    train_image_transforms = std_trnsf.Compose([
        std_trnsf.ColorJitter(0.05, 0.05, 0.05, 0.05),
        std_trnsf.ToTensor(),
        std_trnsf.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

    # Test images are only padded (no crop) so full resolution is evaluated.
    test_joint_transforms = jnt_trnsf.Compose([jnt_trnsf.Safe32Padding()])

    test_image_transforms = std_trnsf.Compose([
        std_trnsf.ToTensor(),
        std_trnsf.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

    # transforms only on mask
    mask_transforms = std_trnsf.Compose([std_trnsf.ToTensor()])

    # build train / test loader
    train_loader = get_loader(dataset=dataset,
                              data_dir=data_dir,
                              train=True,
                              joint_transforms=train_joint_transforms,
                              image_transforms=train_image_transforms,
                              mask_transforms=mask_transforms,
                              batch_size=batch_size,
                              shuffle=False,
                              num_workers=num_workers)

    test_loader = get_loader(dataset=dataset,
                             data_dir=data_dir,
                             train=False,
                             joint_transforms=test_joint_transforms,
                             image_transforms=test_image_transforms,
                             mask_transforms=mask_transforms,
                             batch_size=1,
                             shuffle=False,
                             num_workers=num_workers)

    # build trainer / evaluator with ignite
    trainer = create_supervised_trainer(model,
                                        model_optimizer,
                                        loss,
                                        device=device)
    measure = MultiThresholdMeasures()
    evaluator = create_supervised_evaluator(model,
                                            metrics={
                                                '': measure,
                                                'pix-acc': Accuracy(measure),
                                                'iou': IoU(measure),
                                                'loss': Loss(loss),
                                                'f1': F1score(measure),
                                            },
                                            device=device)

    # initialize state variable for checkpoint
    # NOTE(review): the handlers below call update_state() but discard its
    # return value while this captured `state` is read and saved — presumably
    # update_state mutates shared state in place; confirm, otherwise the
    # checkpoints persist stale metrics.
    state = update_state(model.state_dict(), 0, 0, 0, 0, 0)

    # make ckpt path
    ckpt_root = './ckpt/'
    filename = '{network}_{optimizer}_lr_{lr}_epoch_{epoch}.pth'
    ckpt_path = os.path.join(ckpt_root, filename)

    # execution after every training iteration: log the running loss
    @trainer.on(Events.ITERATION_COMPLETED)
    def log_training_loss(trainer):
        num_iter = (trainer.state.iteration - 1) % len(train_loader) + 1
        if num_iter % 20 == 0:
            logger.info("Epoch[{}] Iter[{:03d}] Loss: {:.2f}".format(
                trainer.state.epoch, num_iter, trainer.state.output))

    # execution after every training epoch: evaluate on the training set
    @trainer.on(Events.EPOCH_COMPLETED)
    def log_training_results(trainer):
        # evaluate on training set
        evaluator.run(train_loader)
        metrics = evaluator.state.metrics
        logger.info(
            "Training Results - Epoch: {} Avg-loss: {:.3f}\n Pix-acc: {}\n IoU: {}\n F1: {}\n"
            .format(trainer.state.epoch, metrics['loss'],
                    str(metrics['pix-acc']), str(metrics['iou']),
                    str(metrics['f1'])))

        # update state
        update_state(weight=model.state_dict(),
                     train_loss=metrics['loss'],
                     val_loss=state['val_loss'],
                     val_pix_acc=state['val_pix_acc'],
                     val_iou=state['val_iou'],
                     val_f1=state['val_f1'])

    # execution after every epoch: evaluate on the held-out set and checkpoint
    @trainer.on(Events.EPOCH_COMPLETED)
    def log_validation_results(trainer):
        # evaluate test(validation) set
        evaluator.run(test_loader)
        metrics = evaluator.state.metrics
        logger.info(
            "Validation Results - Epoch: {} Avg-loss: {:.3f}\n Pix-acc: {}\n IoU: {}\n F1: {}\n"
            .format(trainer.state.epoch, metrics['loss'],
                    str(metrics['pix-acc']), str(metrics['iou']),
                    str(metrics['f1'])))

        # update scheduler (plateau on validation loss)
        lr_scheduler.step(metrics['loss'])

        # update and save state
        update_state(weight=model.state_dict(),
                     train_loss=state['train_loss'],
                     val_loss=metrics['loss'],
                     val_pix_acc=metrics['pix-acc'],
                     val_iou=metrics['iou'],
                     val_f1=metrics['f1'])

        path = ckpt_path.format(network=networks,
                                optimizer=optimizer,
                                lr=lr,
                                epoch=trainer.state.epoch)
        save_ckpt_file(path, state)

    trainer.run(train_loader, max_epochs=epochs)
Exemplo n.º 19
0
}

# Paths to trained models & epoch counts

# NOTE(review): checkpoint files for the three segmentation models compared
# below this chunk — confirm against where they are loaded.
DUCHDC_trainedModelPath = './ducModelFinal.pth'
FCN8_trainedModelPath = './fcnModelFinal.pth'
Unet_trainedModelPath = './unetModelFinal.pth'

# Transforms
# ImageNet channel mean/std, shared by Normalize and DeNormalize below.
mean_std = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
# Scale the shorter side so the crop covers 87.5% of it.
short_size = int(min(args['input_size']) / 0.875)

# Joint image+mask augmentation: scale, random crop, random horizontal flip.
joint_transform = joint_transforms.Compose([
    joint_transforms.Scale(short_size),
    joint_transforms.RandomCrop(args['input_size']),
    joint_transforms.RandomHorizontallyFlip()
])

input_transform = standard_transforms.Compose(
    [standard_transforms.ToTensor(),
     standard_transforms.Normalize(*mean_std)])

target_transform = extended_transforms.MaskToTensor()

# Undo normalization so predictions can be rendered as PIL images.
restore_transform = standard_transforms.Compose([
    extended_transforms.DeNormalize(*mean_std),
    standard_transforms.ToPILImage()
])

visualize = standard_transforms.ToTensor()
Exemplo n.º 20
0
def main():
    """Train PSPNet on Pascal VOC using sliding-crop patches.

    Optionally resumes model and optimizer state from a snapshot whose
    filename encodes the best metrics, then hands the full loop (including
    validation/visualisation) to train().
    """
    net = PSPNet(num_classes=voc.num_classes).cuda()

    if len(args['snapshot']) == 0:
        # net.load_state_dict(torch.load(os.path.join(ckpt_path, 'cityscapes (coarse)-psp_net', 'xx.pth')))
        # Fresh run: epoch 1 with a sentinel best-record.
        curr_epoch = 1
        args['best_record'] = {
            'epoch': 0,
            'val_loss': 1e10,
            'acc': 0,
            'acc_cls': 0,
            'mean_iu': 0,
            'fwavacc': 0
        }
    else:
        # Resume: restore weights and parse metrics from the snapshot name
        # (epoch_<e>_loss_<l>_acc_<a>_acc-cls_<c>_mean-iu_<m>_fwavacc_<f>).
        print('training resumes from ' + args['snapshot'])
        net.load_state_dict(
            torch.load(os.path.join(ckpt_path, exp_name, args['snapshot'])))
        split_snapshot = args['snapshot'].split('_')
        curr_epoch = int(split_snapshot[1]) + 1
        args['best_record'] = {
            'epoch': int(split_snapshot[1]),
            'val_loss': float(split_snapshot[3]),
            'acc': float(split_snapshot[5]),
            'acc_cls': float(split_snapshot[7]),
            'mean_iu': float(split_snapshot[9]),
            'fwavacc': float(split_snapshot[11])
        }
    net.train()

    # ImageNet channel statistics for normalization.
    mean_std = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

    train_joint_transform = joint_transforms.Compose([
        joint_transforms.Scale(args['longer_size']),
        joint_transforms.RandomRotate(10),
        joint_transforms.RandomHorizontallyFlip()
    ])
    # Tile each image into overlapping crops; padding uses the ignore label.
    sliding_crop = joint_transforms.SlidingCrop(args['crop_size'],
                                                args['stride_rate'],
                                                voc.ignore_label)
    train_input_transform = standard_transforms.Compose([
        standard_transforms.ToTensor(),
        standard_transforms.Normalize(*mean_std)
    ])
    val_input_transform = standard_transforms.Compose([
        standard_transforms.ToTensor(),
        standard_transforms.Normalize(*mean_std)
    ])
    target_transform = extended_transforms.MaskToTensor()
    visualize = standard_transforms.Compose([
        standard_transforms.Resize(args['val_img_display_size']),
        standard_transforms.ToTensor()
    ])

    train_set = voc.VOC('train',
                        joint_transform=train_joint_transform,
                        sliding_crop=sliding_crop,
                        transform=train_input_transform,
                        target_transform=target_transform)
    train_loader = DataLoader(train_set,
                              batch_size=args['train_batch_size'],
                              num_workers=1,
                              shuffle=True,
                              drop_last=True)
    val_set = voc.VOC('val',
                      transform=val_input_transform,
                      sliding_crop=sliding_crop,
                      target_transform=target_transform)
    val_loader = DataLoader(val_set,
                            batch_size=1,
                            num_workers=1,
                            shuffle=False,
                            drop_last=True)

    criterion = CrossEntropyLoss2d(size_average=True,
                                   ignore_index=voc.ignore_label).cuda()

    # Biases get twice the base LR and no weight decay (a common PSPNet/FCN
    # recipe); all other parameters use the base LR with weight decay.
    optimizer = optim.SGD([{
        'params': [
            param
            for name, param in net.named_parameters() if name[-4:] == 'bias'
        ],
        'lr':
        2 * args['lr']
    }, {
        'params': [
            param
            for name, param in net.named_parameters() if name[-4:] != 'bias'
        ],
        'lr':
        args['lr'],
        'weight_decay':
        args['weight_decay']
    }],
                          momentum=args['momentum'],
                          nesterov=True)

    if len(args['snapshot']) > 0:
        # Restore optimizer state, then re-pin the LRs from args.
        optimizer.load_state_dict(
            torch.load(
                os.path.join(ckpt_path, exp_name, 'opt_' + args['snapshot'])))
        optimizer.param_groups[0]['lr'] = 2 * args['lr']
        optimizer.param_groups[1]['lr'] = args['lr']

    check_mkdir(ckpt_path)
    check_mkdir(os.path.join(ckpt_path, exp_name))
    # Record the run configuration in a timestamped log file.
    open(
        os.path.join(ckpt_path, exp_name,
                     str(datetime.datetime.now()) + '.txt'),
        'w').write(str(args) + '\n\n')

    # train() here also drives validation/visualisation via val_loader.
    train(train_loader, net, criterion, optimizer, curr_epoch, args,
          val_loader, visualize)
Exemplo n.º 21
0
def main():
    """Fine-tune a Caffe-pretrained PSPNet backbone for binary KITTI segmentation.

    The Cityscapes-pretrained backbone is frozen; only a freshly attached
    head (cbr_final / dropout / classification) is trained.  Relies on
    module-level globals: ``args``, ``ckpt_path``, ``exp_name``,
    ``check_mkdir`` and ``train``.
    """
    net = PSPNet(19)
    net.load_pretrained_model(
        model_path='./Caffe-PSPNet/pspnet101_cityscapes.caffemodel')
    # Freeze the pretrained backbone; only the replacement head below trains.
    for param in net.parameters():
        param.requires_grad = False
    net.cbr_final = conv2DBatchNormRelu(4096, 128, 3, 1, 1, False)
    net.dropout = nn.Dropout2d(p=0.1, inplace=True)
    net.classification = nn.Conv2d(128, kitti_binary.num_classes, 1, 1, 0)

    # Report total vs. trainable parameter counts.
    total_params = sum(p.numel() for p in net.parameters())
    print(f'{total_params:,} total parameters.')
    total_trainable_params = sum(p.numel() for p in net.parameters()
                                 if p.requires_grad)
    print(f'{total_trainable_params:,} training parameters.')

    if len(args['snapshot']) == 0:
        # Fresh run: seed best_record with a sentinel "worst possible" entry.
        args['best_record'] = {
            'epoch': 0,
            'iter': 0,
            'val_loss': 1e10,
            'accu': 0
        }
    else:
        # Resume: the snapshot filename is assumed to encode
        # epoch/iter/val_loss/accu at odd '_'-separated positions —
        # NOTE(review): confirm against the checkpoint-saving code.
        print('training resumes from ' + args['snapshot'])
        net.load_state_dict(
            torch.load(os.path.join(ckpt_path, exp_name, args['snapshot'])))
        split_snapshot = args['snapshot'].split('_')
        args['best_record'] = {
            'epoch': int(split_snapshot[1]),
            'iter': int(split_snapshot[3]),
            'val_loss': float(split_snapshot[5]),
            'accu': float(split_snapshot[7])
        }
    net.cuda(args['gpu']).train()

    # Caffe-style preprocessing: BGR channel order, 0-255 pixel range,
    # mean subtraction only (std of 1.0) to match the pretrained weights.
    mean_std = ([103.939, 116.779, 123.68], [1.0, 1.0, 1.0])

    train_joint_transform = joint_transforms.Compose([
        joint_transforms.Scale(args['longer_size']),
        joint_transforms.RandomRotate(10),
        joint_transforms.RandomHorizontallyFlip()
    ])
    # The original built two byte-identical, stateless input pipelines for
    # train and val; build it once and share it.
    input_transform = standard_transforms.Compose([
        extended_transforms.FlipChannels(),  # RGB -> BGR for Caffe weights
        standard_transforms.ToTensor(),
        standard_transforms.Lambda(lambda x: x.mul_(255)),
        standard_transforms.Normalize(*mean_std)
    ])
    target_transform = extended_transforms.MaskToTensor()
    train_set = kitti_binary.KITTI(mode='train',
                                   joint_transform=train_joint_transform,
                                   transform=input_transform,
                                   target_transform=target_transform)
    train_loader = DataLoader(train_set,
                              batch_size=args['train_batch_size'],
                              num_workers=8,
                              shuffle=True)
    val_set = kitti_binary.KITTI(mode='val',
                                 transform=input_transform,
                                 target_transform=target_transform)
    val_loader = DataLoader(val_set,
                            batch_size=1,
                            num_workers=8,
                            shuffle=False)

    # Slight positive-class up-weighting for the binary segmentation task.
    criterion = nn.BCEWithLogitsLoss(pos_weight=torch.full([1], 1.05)).cuda(
        args['gpu'])

    # Classic segmentation recipe: biases get 2x LR and no weight decay;
    # all other parameters get the base LR with weight decay.
    optimizer = optim.SGD([{
        'params': [
            param
            for name, param in net.named_parameters() if name[-4:] == 'bias'
        ],
        'lr': 2 * args['lr']
    }, {
        'params': [
            param
            for name, param in net.named_parameters() if name[-4:] != 'bias'
        ],
        'lr': args['lr'],
        'weight_decay': args['weight_decay']
    }],
                          momentum=args['momentum'],
                          nesterov=True)

    if len(args['snapshot']) > 0:
        # Restore optimizer state, then re-apply the per-group learning
        # rates (load_state_dict overwrites them with the saved values).
        optimizer.load_state_dict(
            torch.load(
                os.path.join(ckpt_path, exp_name, 'opt_' + args['snapshot'])))
        optimizer.param_groups[0]['lr'] = 2 * args['lr']
        optimizer.param_groups[1]['lr'] = args['lr']

    check_mkdir(ckpt_path)
    check_mkdir(os.path.join(ckpt_path, exp_name))
    # Log the run configuration; 'with' guarantees the handle is closed
    # (the original left an unclosed file object behind).
    with open(
            os.path.join(ckpt_path, exp_name,
                         str(datetime.datetime.now()) + '.txt'), 'w') as f:
        f.write(str(args) + '\n\n')

    train(train_loader, net, criterion, optimizer, args, val_loader)