Example #1
    def __init__(self, model, optimizer, scheduler, device, cfg):
        self.scheduler = scheduler
        self.model = model
        self.cfg = cfg
        self.optimizer = optimizer
        self.device = device
        self.loss_function = MultiLosses(device=device)

        # Setup dataloader
        self.train_dst, self.val_dst = get_dataset(self.cfg)
        self.train_loader = data.DataLoader(self.train_dst,
                                            batch_size=self.cfg.batch_size,
                                            shuffle=True,
                                            num_workers=8,
                                            pin_memory=True)
        self.val_loader = data.DataLoader(self.val_dst,
                                          batch_size=self.cfg.val_batch_size,
                                          shuffle=True,
                                          num_workers=8,
                                          pin_memory=True)
        print("Dataset: %s, Train set: %d, Val set: %d" %
              (self.cfg.dataset, len(self.train_dst), len(self.val_dst)))

        # Visdom setup
        vis = Visualizer(port=self.cfg.vis_port,
                         env=self.cfg.vis_env) if self.cfg.enable_vis else None
        if vis is not None:  # display options
            vis.vis_table("Options", vars(self.cfg))
        self.vis = vis
        self.vis_sample_id = np.random.randint(
            0, len(self.val_loader), self.cfg.vis_num_samples, np.int32
        ) if self.cfg.enable_vis else None  # sample idxs for visualization

        # metric
        self.metrics = StreamSegMetrics(self.cfg.num_classes)
Example #2
    def __init__(self, data_loader, opts):
        #super(Trainer, self).__init__(data_loader, opts)
        #self.opts = opts
        self.train_loader = data_loader[0]
        self.val_loader = data_loader[1]

        # Set up model
        model_map = {
            'deeplabv3_resnet50': network.deeplabv3_resnet50,
            'deeplabv3plus_resnet50': network.deeplabv3plus_resnet50,
            'deeplabv3_resnet101': network.deeplabv3_resnet101,
            'deeplabv3plus_resnet101': network.deeplabv3plus_resnet101,
            'deeplabv3_mobilenet': network.deeplabv3_mobilenet,
            'deeplabv3plus_mobilenet': network.deeplabv3plus_mobilenet
        }

        self.model = model_map[opts.model](num_classes=opts.num_classes, output_stride=opts.output_stride)
        if opts.separable_conv and 'plus' in opts.model:
            network.convert_to_separable_conv(self.model.classifier)

        def set_bn_momentum(model, momentum=0.1):
            for m in model.modules():
                if isinstance(m, nn.BatchNorm2d):
                    m.momentum = momentum

        # PyTorch BN momentum weights the *new* batch statistics:
        # running = (1 - momentum) * running + momentum * batch, so 0.01 updates the
        # running estimates slowly (the opposite of the TF/Keras 0.99 "decay" convention).
        set_bn_momentum(self.model.backbone, momentum=0.01)


        # Set up metrics
        self.metrics = StreamSegMetrics(opts.num_classes)

        # Set up optimizer
        self.optimizer = torch.optim.SGD(params=[
            {'params': self.model.backbone.parameters(), 'lr': 0.1*opts.lr},
            {'params': self.model.classifier.parameters(), 'lr': opts.lr},
        ], lr=opts.lr, momentum=0.9, weight_decay=opts.weight_decay)

        if opts.lr_policy=='poly':
            self.scheduler = utils.PolyLR(self.optimizer, opts.total_itrs, power=0.9)
        elif opts.lr_policy=='step':
            self.scheduler = torch.optim.lr_scheduler.StepLR(self.optimizer, step_size=opts.step_size, gamma=0.1)

        # Set up criterion
        if opts.loss_type == 'focal_loss':
            self.criterion = utils.FocalLoss(ignore_index=255, size_average=True)
        elif opts.loss_type == 'cross_entropy':
            self.criterion = nn.CrossEntropyLoss(ignore_index=255, reduction='mean')


        self.best_mean_iu = 0
        self.iteration = 0
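
A note on the momentum question flagged in the snippet above: PyTorch's BatchNorm momentum is the weight given to the new batch statistics (running = (1 - momentum) * running + momentum * batch), the opposite of the TensorFlow/Keras "decay" convention, so momentum=0.01 means the running estimates change slowly. A minimal standalone check (not part of the trainer above):

import torch
import torch.nn as nn

bn = nn.BatchNorm2d(3, momentum=0.01)      # fresh module starts in training mode
x = torch.randn(8, 3, 16, 16) + 5.0        # batch mean is roughly 5
bn(x)                                      # one forward pass updates the running stats
print(bn.running_mean)                     # roughly 0.05 per channel: 0 + 0.01 * batch_mean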
Example #3
def main():
    opts = get_argparser().parse_args()
    opts = modify_command_options(opts)

    # Set up visualization
    vis = Visualizer(port=opts.vis_port,
                     env=opts.vis_env) if opts.enable_vis else None

    if vis is not None:  # display options
        vis.vis_table("Options", vars(opts))

    os.environ['CUDA_VISIBLE_DEVICES'] = opts.gpu_id
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print("Device: %s" % device)

    # Set up random seed
    torch.manual_seed(opts.random_seed)
    torch.cuda.manual_seed(opts.random_seed)
    np.random.seed(opts.random_seed)
    random.seed(opts.random_seed)

    # Set up dataloader
    train_dst, val_dst = get_dataset(opts)
    train_loader = data.DataLoader(train_dst,
                                   batch_size=opts.batch_size,
                                   shuffle=True,
                                   num_workers=opts.num_workers)
    val_loader = data.DataLoader(
        val_dst,
        batch_size=opts.batch_size if opts.crop_val else 1,
        shuffle=False,
        num_workers=opts.num_workers)
    print("Dataset: %s, Train set: %d, Val set: %d" %
          (opts.dataset, len(train_dst), len(val_dst)))

    # Set up model
    print("Backbone: %s" % opts.backbone)
    model = DeepLabv3(num_classes=opts.num_classes,
                      backbone=opts.backbone,
                      pretrained=True,
                      momentum=opts.bn_mom,
                      output_stride=opts.output_stride,
                      use_separable_conv=opts.use_separable_conv)
    if opts.use_gn:
        print("[!] Replace BatchNorm with GroupNorm!")
        model = utils.convert_bn2gn(model)

    if opts.fix_bn:
        model.fix_bn()

    if torch.cuda.device_count() > 1:  # Parallel
        print("%d GPU parallel" % (torch.cuda.device_count()))
        model = torch.nn.DataParallel(model)
        model_ref = model.module  # for ckpt
    else:
        model_ref = model
    model = model.to(device)

    # Set up metrics
    metrics = StreamSegMetrics(opts.num_classes)

    # Set up optimizer
    decay_1x, no_decay_1x = model_ref.group_params_1x()
    decay_10x, no_decay_10x = model_ref.group_params_10x()
    optimizer = torch.optim.SGD(params=[
        {
            "params": decay_1x,
            'lr': opts.lr,
            'weight_decay': opts.weight_decay
        },
        {
            "params": no_decay_1x,
            'lr': opts.lr
        },
        {
            "params": decay_10x,
            'lr': opts.lr * 10,
            'weight_decay': opts.weight_decay
        },
        {
            "params": no_decay_10x,
            'lr': opts.lr * 10
        },
    ],
                                lr=opts.lr,
                                momentum=opts.momentum,
                                nesterov=not opts.no_nesterov)
    del decay_1x, no_decay_1x, decay_10x, no_decay_10x

    if opts.lr_policy == 'poly':
        scheduler = utils.PolyLR(optimizer,
                                 max_iters=opts.epochs * len(train_loader),
                                 power=opts.lr_power)
    elif opts.lr_policy == 'step':
        scheduler = torch.optim.lr_scheduler.StepLR(
            optimizer,
            step_size=opts.lr_decay_step,
            gamma=opts.lr_decay_factor)
    print("Optimizer:\n%s" % (optimizer))

    utils.mkdir('checkpoints')
    # Restore
    best_score = 0.0
    cur_epoch = 0
    if opts.ckpt is not None and os.path.isfile(opts.ckpt):
        checkpoint = torch.load(opts.ckpt)
        model_ref.load_state_dict(checkpoint["model_state"])
        optimizer.load_state_dict(checkpoint["optimizer_state"])
        scheduler.load_state_dict(checkpoint["scheduler_state"])
        cur_epoch = checkpoint["epoch"] + 1
        best_score = checkpoint['best_score']
        print("Model restored from %s" % opts.ckpt)
        del checkpoint  # free memory
    else:
        print("[!] Retrain")

    def save_ckpt(path):
        """ save current model
        """
        state = {
            "epoch": cur_epoch,
            "model_state": model_ref.state_dict(),
            "optimizer_state": optimizer.state_dict(),
            "scheduler_state": scheduler.state_dict(),
            "best_score": best_score,
        }
        torch.save(state, path)
        print("Model saved as %s" % path)

    # Set up criterion
    criterion = utils.get_loss(opts.loss_type)
    #==========   Train Loop   ==========#

    vis_sample_id = np.random.randint(
        0, len(val_loader), opts.vis_sample_num,
        np.int32) if opts.enable_vis else None  # sample idxs for visualization
    label2color = utils.Label2Color(cmap=utils.color_map(
        opts.dataset))  # convert labels to images
    denorm = utils.Denormalize(mean=[0.485, 0.456, 0.406],
                               std=[0.229, 0.224,
                                    0.225])  # denormalization for ori images
    while cur_epoch < opts.epochs:
        # =====  Train  =====
        model.train()
        if opts.fix_bn:
            model_ref.fix_bn()

        epoch_loss = train(cur_epoch=cur_epoch,
                           criterion=criterion,
                           model=model,
                           optim=optimizer,
                           train_loader=train_loader,
                           device=device,
                           scheduler=scheduler,
                           vis=vis)
        print("End of Epoch %d/%d, Average Loss=%f" %
              (cur_epoch, opts.epochs, epoch_loss))
        if opts.enable_vis:
            vis.vis_scalar("Epoch Loss", cur_epoch, epoch_loss)

        # =====  Save Latest Model  =====
        if (cur_epoch + 1) % opts.ckpt_interval == 0:
            save_ckpt('checkpoints/latest_%s_%s.pkl' %
                      (opts.backbone, opts.dataset))

        # =====  Validation  =====
        if (cur_epoch + 1) % opts.val_interval == 0:
            print("validate on val set...")
            model.eval()
            val_score, ret_samples = validate(model=model,
                                              loader=val_loader,
                                              device=device,
                                              metrics=metrics,
                                              ret_samples_ids=vis_sample_id)
            print(metrics.to_str(val_score))

            # =====  Save Best Model  =====
            if val_score['Mean IoU'] > best_score:  # save best model
                best_score = val_score['Mean IoU']
                save_ckpt('checkpoints/best_%s_%s.pkl' %
                          (opts.backbone, opts.dataset))

            if vis is not None:  # visualize validation score and samples
                vis.vis_scalar("[Val] Overall Acc", cur_epoch,
                               val_score['Overall Acc'])
                vis.vis_scalar("[Val] Mean IoU", cur_epoch,
                               val_score['Mean IoU'])
                vis.vis_table("[Val] Class IoU", val_score['Class IoU'])

                for k, (img, target, lbl) in enumerate(ret_samples):
                    img = (denorm(img) * 255).astype(np.uint8)
                    target = label2color(target).transpose(2, 0,
                                                           1).astype(np.uint8)
                    lbl = label2color(lbl).transpose(2, 0, 1).astype(np.uint8)

                    concat_img = np.concatenate((img, target, lbl),
                                                axis=2)  # concat along width
                    vis.vis_image('Sample %d' % k, concat_img)

            if opts.val_on_trainset:  # validate on train set
                print("validate on train set...")
                model.eval()
                train_score, _ = validate(model=model,
                                          loader=train_loader,
                                          device=device,
                                          metrics=metrics)
                print(metrics.to_str(train_score))
                if vis is not None:
                    vis.vis_scalar("[Train] Overall Acc", cur_epoch,
                                   train_score['Overall Acc'])
                    vis.vis_scalar("[Train] Mean IoU", cur_epoch,
                                   train_score['Mean IoU'])

        cur_epoch += 1
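
utils.PolyLR above is project-specific; a minimal sketch of the polynomial decay it presumably implements (lr = base_lr * (1 - iter / max_iters) ** power), assuming the standard _LRScheduler interface and a small lr floor:

from torch.optim.lr_scheduler import _LRScheduler

class PolyLR(_LRScheduler):
    """Polynomial LR decay: lr = base_lr * (1 - last_epoch / max_iters) ** power."""
    def __init__(self, optimizer, max_iters, power=0.9, min_lr=1e-6, last_epoch=-1):
        self.max_iters = max_iters
        self.power = power
        self.min_lr = min_lr
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        factor = (1 - self.last_epoch / self.max_iters) ** self.power
        return [max(base_lr * factor, self.min_lr) for base_lr in self.base_lrs]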
Example #4
def main():
    opts = get_argparser().parse_args()
    if opts.dataset.lower() == 'voc':
        opts.num_classes = 21
    elif opts.dataset.lower() == 'cityscapes':
        opts.num_classes = 19

    # Setup visualization
    vis = Visualizer(port=opts.vis_port,
                     env=opts.vis_env) if opts.enable_vis else None
    if vis is not None:  # display options
        vis.vis_table("Options", vars(opts))

    os.environ['CUDA_VISIBLE_DEVICES'] = opts.gpu_id
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print("Device: %s" % device)

    # Setup random seed
    torch.manual_seed(opts.random_seed)
    np.random.seed(opts.random_seed)
    random.seed(opts.random_seed)

    # Setup dataloader
    if opts.dataset == 'voc' and not opts.crop_val:
        opts.val_batch_size = 1

    # Set up metrics
    # metrics = StreamSegMetrics(opts.num_classes)
    metrics = StreamSegMetrics(21)
    # Set up criterion
    # criterion = utils.get_loss(opts.loss_type)
    if opts.loss_type == 'focal_loss':
        criterion = utils.FocalLoss(ignore_index=255, size_average=True)
    elif opts.loss_type == 'cross_entropy':
        criterion = nn.CrossEntropyLoss(ignore_index=255, reduction='mean')
    elif opts.loss_type == 'logit':
        criterion = nn.BCELoss(reduction='mean')

    def save_ckpt(path):
        """ save current model
        """
        torch.save({
            "cur_itrs": cur_itrs,
            "model_state": model.module.state_dict(),
            "optimizer_state": optimizer.state_dict(),
            "scheduler_state": scheduler.state_dict(),
            "best_score": best_score,
        }, path)
        print("Model saved as %s" % path)

    utils.mkdir('checkpoints')
    # Restore
    best_score = 0.0
    cur_itrs = 0
    cur_epochs = 0
    if opts.ckpt is not None:
        print("Error --ckpt, can't read model")
        return

    _, val_dst, test_dst = get_dataset(opts)
    val_loader = data.DataLoader(
        val_dst, batch_size=opts.val_batch_size, shuffle=True, num_workers=2)
    test_loader = data.DataLoader(
        test_dst, batch_size=opts.val_batch_size, shuffle=True, num_workers=2)
    vis_sample_id = np.random.randint(0, len(test_loader), opts.vis_num_samples,
                                      np.int32) if opts.enable_vis else None  # sample idxs for visualization

    denorm = utils.Denormalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # denormalization for ori images
    # ==========   Test Loop   ==========#

    if opts.test_only:
        print("Dataset: %s,  Val set: %d, Test set: %d" %
              (opts.dataset, len(val_dst), len(test_dst)))

        metrics = StreamSegMetrics(21)
        print("val")

        test_score, ret_samples = test_single(opts=opts,
                                              loader=test_loader, device=device, metrics=metrics,
                                              ret_samples_ids=vis_sample_id)
        print("test")
        test_score, ret_samples = test_multiple(
            opts=opts, loader=test_loader, device=device, metrics=metrics, ret_samples_ids=vis_sample_id)
        print(metrics.to_str(test_score))
        return

    # ==========   Train Loop   ==========#
    utils.mkdir('checkpoints/multiple_model2')
    for class_num in range(opts.start_class, opts.num_classes):
        # ==========   Dataset   ==========#
        train_dst, val_dst, test_dst = get_dataset_multiple(opts, class_num)
        train_loader = data.DataLoader(train_dst, batch_size=opts.batch_size, shuffle=True, num_workers=2)
        val_loader = data.DataLoader(val_dst, batch_size=opts.val_batch_size, shuffle=True, num_workers=2)
        test_loader = data.DataLoader(test_dst, batch_size=opts.val_batch_size, shuffle=True, num_workers=2)
        print("Dataset: %s Class %d, Train set: %d, Val set: %d, Test set: %d" % (
            opts.dataset, class_num, len(train_dst), len(val_dst), len(test_dst)))

        # ==========   Model   ==========#
        # model_map (name -> constructor) is assumed to be defined at module level, as in the other examples
        model = model_map[opts.model](num_classes=opts.num_classes, output_stride=opts.output_stride)
        if opts.separable_conv and 'plus' in opts.model:
            network.convert_to_separable_conv(model.classifier)
        utils.set_bn_momentum(model.backbone, momentum=0.01)

        # ==========   Params and learning rate   ==========#
        params_list = [
            {'params': model.backbone.parameters(), 'lr': 0.1 * opts.lr},
            {'params': model.classifier.parameters(), 'lr': 0.1 * opts.lr}  # opts.lr
        ]
        if 'SA' in opts.model:
            params_list.append({'params': model.attention.parameters(), 'lr': 0.1 * opts.lr})
        optimizer = torch.optim.Adam(params=params_list, lr=opts.lr, weight_decay=opts.weight_decay)

        if opts.lr_policy == 'poly':
            scheduler = utils.PolyLR(optimizer, opts.total_itrs, power=0.9)
        elif opts.lr_policy == 'step':
            scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=opts.step_size, gamma=0.1)

        model = nn.DataParallel(model)
        model.to(device)

        best_score = 0.0
        cur_itrs = 0
        cur_epochs = 0

        interval_loss = 0
        while True:  # cur_itrs < opts.total_itrs:
            # =====  Train  =====
            model.train()

            cur_epochs += 1
            for (images, labels) in train_loader:
                cur_itrs += 1

                images = images.to(device, dtype=torch.float32)
                labels = labels.to(device, dtype=torch.long)
                # labels=(labels==class_num).float()
                optimizer.zero_grad()
                outputs = model(images)

                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()

                np_loss = loss.detach().cpu().numpy()
                interval_loss += np_loss
                if vis is not None:
                    vis.vis_scalar('Loss', cur_itrs, np_loss)

                if (cur_itrs) % 10 == 0:
                    interval_loss = interval_loss / 10
                    print("Epoch %d, Itrs %d/%d, Loss=%f" %
                          (cur_epochs, cur_itrs, opts.total_itrs, interval_loss))
                    interval_loss = 0.0

                if (cur_itrs) % opts.val_interval == 0:
                    save_ckpt('checkpoints/multiple_model2/latest_%s_%s_class%d_os%d.pth' %
                              (opts.model, opts.dataset, class_num, opts.output_stride,))
                    print("validation...")
                    model.eval()
                    val_score, ret_samples = validate(
                        opts=opts, model=model, loader=val_loader, device=device, metrics=metrics,
                        ret_samples_ids=vis_sample_id, class_num=class_num)
                    print(metrics.to_str(val_score))

                    if val_score['Mean IoU'] > best_score:  # save best model
                        best_score = val_score['Mean IoU']
                        save_ckpt('checkpoints/multiple_model2/best_%s_%s_class%d_os%d.pth' %
                                  (opts.model, opts.dataset, class_num, opts.output_stride))

                    if vis is not None:  # visualize validation score and samples
                        vis.vis_scalar("[Val] Overall Acc", cur_itrs, val_score['Overall Acc'])
                        vis.vis_scalar("[Val] Mean IoU", cur_itrs, val_score['Mean IoU'])
                        vis.vis_table("[Val] Class IoU", val_score['Class IoU'])

                        for k, (img, target, lbl) in enumerate(ret_samples):
                            img = (denorm(img) * 255).astype(np.uint8)
                            target = train_dst.decode_target(target).transpose(2, 0, 1).astype(np.uint8)
                            lbl = train_dst.decode_target(lbl).transpose(2, 0, 1).astype(np.uint8)
                            concat_img = np.concatenate((img, target, lbl), axis=2)  # concat along width
                            vis.vis_image('Sample %d' % k, concat_img)
                    model.train()

                scheduler.step()

                if cur_itrs >= opts.total_itrs:
                    save_ckpt('checkpoints/multiple_model2/latest_%s_%s_class%d_os%d.pth' %
                              (opts.model, opts.dataset, class_num, opts.output_stride,))
                    print("Saving..")
                    break
            if cur_itrs >= opts.total_itrs:
                cur_itrs = 0
                break

        print("Model of class %d is trained and saved " % (class_num))
Example #5
def main():
    opts = get_argparser().parse_args()
    if opts.dataset.lower() == 'voc':
        opts.num_classes = 21
    elif opts.dataset.lower() == 'cityscapes':
        opts.num_classes = 19

    # Setup visualization
    vis = Visualizer(port=opts.vis_port,
                     env=opts.vis_env) if opts.enable_vis else None
    if vis is not None:  # display options
        vis.vis_table("Options", vars(opts))

    os.environ['CUDA_VISIBLE_DEVICES'] = opts.gpu_id
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print("Device: %s" % device)

    # Setup random seed
    torch.manual_seed(opts.random_seed)
    np.random.seed(opts.random_seed)
    random.seed(opts.random_seed)

    # Setup dataloader
    if opts.dataset=='voc' and not opts.crop_val:
        opts.val_batch_size = 1
    

    pipe = create_dali_pipeline(batch_size=opts.batch_size, num_threads=8,
                                device_id=0, data_dir="/home/ubuntu/cityscapes")
    pipe.build()
    train_loader = DALIGenericIterator(pipe, output_map=['image', 'label'], last_batch_policy=LastBatchPolicy.PARTIAL)

    # Set up model
    model_map = {
        'deeplabv3_resnet50': network.deeplabv3_resnet50,
        'deeplabv3plus_resnet50': network.deeplabv3plus_resnet50,
        'deeplabv3_resnet101': network.deeplabv3_resnet101,
        'deeplabv3plus_resnet101': network.deeplabv3plus_resnet101,
        'deeplabv3_mobilenet': network.deeplabv3_mobilenet,
        'deeplabv3plus_mobilenet': network.deeplabv3plus_mobilenet
    }

    model = model_map[opts.model](num_classes=opts.num_classes, output_stride=opts.output_stride)
    if opts.separable_conv and 'plus' in opts.model:
        network.convert_to_separable_conv(model.classifier)
    utils.set_bn_momentum(model.backbone, momentum=0.01)
    
    # Set up metrics
    metrics = StreamSegMetrics(opts.num_classes)

    # Set up optimizer
    optimizer = torch.optim.SGD(params=[
        {'params': model.backbone.parameters(), 'lr': 0.1*opts.lr},
        {'params': model.classifier.parameters(), 'lr': opts.lr},
    ], lr=opts.lr, momentum=0.9, weight_decay=opts.weight_decay)
    #optimizer = torch.optim.SGD(params=model.parameters(), lr=opts.lr, momentum=0.9, weight_decay=opts.weight_decay)
    #torch.optim.lr_scheduler.StepLR(optimizer, step_size=opts.lr_decay_step, gamma=opts.lr_decay_factor)
    if opts.lr_policy=='poly':
        scheduler = utils.PolyLR(optimizer, opts.total_itrs, power=0.9)
    elif opts.lr_policy=='step':
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=opts.step_size, gamma=0.1)

    # Set up criterion
    #criterion = utils.get_loss(opts.loss_type)
    if opts.loss_type == 'focal_loss':
        criterion = utils.FocalLoss(ignore_index=255, size_average=True)
    elif opts.loss_type == 'cross_entropy':
        criterion = nn.CrossEntropyLoss(ignore_index=255, reduction='mean')

    def save_ckpt(path):
        """ save current model
        """
        torch.save({
            "cur_itrs": cur_itrs,
            "model_state": model.module.state_dict(),
            "optimizer_state": optimizer.state_dict(),
            "scheduler_state": scheduler.state_dict(),
            "best_score": best_score,
        }, path)
        print("Model saved as %s" % path)
    
    utils.mkdir('checkpoints')
    # Restore
    best_score = 0.0
    cur_itrs = 0
    cur_epochs = 0
    if opts.ckpt is not None and os.path.isfile(opts.ckpt):
        # https://github.com/VainF/DeepLabV3Plus-Pytorch/issues/8#issuecomment-605601402, @PytaichukBohdan
        checkpoint = torch.load(opts.ckpt, map_location=torch.device('cpu'))
        model.load_state_dict(checkpoint["model_state"])
        model = nn.DataParallel(model)
        model.to(device)
        if opts.continue_training:
            optimizer.load_state_dict(checkpoint["optimizer_state"])
            scheduler.load_state_dict(checkpoint["scheduler_state"])
            cur_itrs = checkpoint["cur_itrs"]
            best_score = checkpoint['best_score']
            print("Training state restored from %s" % opts.ckpt)
        print("Model restored from %s" % opts.ckpt)
        del checkpoint  # free memory
    else:
        print("[!] Retrain")
        model = nn.DataParallel(model)
        model.to(device)

    #==========   Train Loop   ==========#
    interval_loss = 0

    class_conv = [255, 255, 255, 255, 255,
                  255, 255, 255, 0, 1,
                  255, 255, 2, 3, 4,
                  255, 255, 255, 5, 255,
                  6, 7, 8, 9, 10,
                  11, 12, 13, 14, 15,
                  255, 255, 16, 17, 18]

    while True: #cur_itrs < opts.total_itrs:
        # =====  Train  =====
        model.train()
        #model = model.half()
        cur_epochs += 1
        train_iter = iter(train_loader)  # obtain the DALI iterator once per epoch
        while True:
            try:
                nvtx.range_push("Batch " + str(cur_itrs))

                nvtx.range_push("Data loading")
                data = next(train_iter)
                cur_itrs += 1

                images = data[0]['image'].to(dtype=torch.float32)
                labels = data[0]['label'][:, :, :, 0].to(dtype=torch.long)
                # NOTE: the labels above are immediately replaced with zeros, so this
                # loop only exercises the pipeline for profiling rather than real training
                labels = torch.zeros(data[0]['label'][:, :, :, 0].shape).to(device, dtype=torch.long)
                nvtx.range_pop()

                nvtx.range_push("Forward pass")
                optimizer.zero_grad()
                outputs = model(images)
                loss = criterion(outputs, labels)
                nvtx.range_pop()

                nvtx.range_push("Backward pass")
                loss.backward()
                optimizer.step()
                nvtx.range_pop()

                np_loss = loss.detach().cpu().numpy()
                interval_loss += np_loss

                nvtx.range_pop()

                if cur_itrs == 10:
                    break

                if vis is not None:
                    vis.vis_scalar('Loss', cur_itrs, np_loss)

                print("Epoch %d, Itrs %d/%d, Loss=%f" %
                      (cur_epochs, cur_itrs, opts.total_itrs, interval_loss))
                interval_loss = 0.0

                scheduler.step()  

                if cur_itrs >=  opts.total_itrs:
                    return
            except StopIteration:
                break

        break
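
The nvtx.range_push / nvtx.range_pop calls above mark named regions for profilers such as Nsight Systems; assuming they come from torch.cuda.nvtx, the pattern is simply nested named ranges:

from torch.cuda import nvtx

nvtx.range_push("Batch 0")        # open an outer range
nvtx.range_push("Data loading")   # open a nested range
# ... fetch and preprocess a batch ...
nvtx.range_pop()                  # close "Data loading"
nvtx.range_pop()                  # close "Batch 0"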
Example #6
File: eval.py  Project: qjadud1994/DRS
def main():
    opts = get_argparser().parse_args()
    os.environ['CUDA_VISIBLE_DEVICES'] = opts.gpu_id
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print("Device: %s" % device)

    # Setup random seed
    torch.manual_seed(opts.random_seed)
    np.random.seed(opts.random_seed)
    random.seed(opts.random_seed)

    os.makedirs(opts.logit_dir, exist_ok=True)

    # Setup dataloader
    if not opts.crop_val:
        opts.val_batch_size = 1

    val_dst = get_dataset(opts)
    val_loader = data.DataLoader(val_dst,
                                 batch_size=opts.val_batch_size,
                                 shuffle=False,
                                 num_workers=4)

    print("Dataset: voc, Val set: %d" % (len(val_dst)))

    # Set up model
    model_map = {
        'deeplabv3_resnet50': network.deeplabv3_resnet50,
        'deeplabv3plus_resnet50': network.deeplabv3plus_resnet50,
        'deeplabv3_resnet101': network.deeplabv3_resnet101,
        'deeplabv3plus_resnet101': network.deeplabv3plus_resnet101,
        'deeplabv3_mobilenet': network.deeplabv3_mobilenet,
        'deeplabv3plus_mobilenet': network.deeplabv3plus_mobilenet
    }

    model = model_map[opts.model](num_classes=opts.num_classes,
                                  output_stride=opts.output_stride)
    if opts.separable_conv and 'plus' in opts.model:
        network.convert_to_separable_conv(model.classifier)
    utils.set_bn_momentum(model.backbone, momentum=0.01)

    # Set up metrics
    metrics = StreamSegMetrics(opts.num_classes)

    # Restore
    if opts.ckpt is not None and os.path.isfile(opts.ckpt):
        # https://github.com/VainF/DeepLabV3Plus-Pytorch/issues/8#issuecomment-605601402, @PytaichukBohdan
        checkpoint = torch.load(opts.ckpt, map_location=torch.device('cpu'))
        model.load_state_dict(checkpoint["model_state"])
        model = nn.DataParallel(model)
        model.to(device)
        print("Model restored from %s" % opts.ckpt)
        del checkpoint  # free memory
    else:
        assert "no checkpoint"

    #==========   Eval   ==========#
    model.eval()
    val_score = validate(opts=opts,
                         model=model,
                         loader=val_loader,
                         device=device,
                         metrics=metrics)
    print(metrics.to_str(val_score))

    print("\n\n----------- crf -------------")
    crf_score = crf_inference(opts, val_dst, metrics)
    print(metrics.to_str(crf_score))

    os.system(f"rm -rf {opts.logit_dir}")
Example #7
def main(opts):
    distributed.init_process_group(backend='nccl', init_method='env://')
    device_id, device = opts.local_rank, torch.device(opts.local_rank)
    rank, world_size = distributed.get_rank(), distributed.get_world_size()
    torch.cuda.set_device(device_id)

    # Initialize logging
    task_name = f"{opts.task}-{opts.dataset}"
    logdir_full = f"{opts.logdir}/{task_name}/{opts.name}/"
    if rank == 0:
        logger = Logger(logdir_full, rank=rank, debug=opts.debug, summary=opts.visualize, step=opts.step)
    else:
        logger = Logger(logdir_full, rank=rank, debug=opts.debug, summary=False)

    logger.print(f"Device: {device}")

    # Set up random seed
    torch.manual_seed(opts.random_seed)
    torch.cuda.manual_seed(opts.random_seed)
    np.random.seed(opts.random_seed)
    random.seed(opts.random_seed)

    # xxx Set up dataloader
    train_dst, val_dst, test_dst, n_classes = get_dataset(opts)
    # reset the seed, this revert changes in random seed
    random.seed(opts.random_seed)

    train_loader = data.DataLoader(train_dst, batch_size=opts.batch_size,
                                   sampler=DistributedSampler(train_dst, num_replicas=world_size, rank=rank),
                                   num_workers=opts.num_workers, drop_last=True)
    val_loader = data.DataLoader(val_dst, batch_size=opts.batch_size if opts.crop_val else 1,
                                 sampler=DistributedSampler(val_dst, num_replicas=world_size, rank=rank),
                                 num_workers=opts.num_workers)
    logger.info(f"Dataset: {opts.dataset}, Train set: {len(train_dst)}, Val set: {len(val_dst)},"
                f" Test set: {len(test_dst)}, n_classes {n_classes}")
    logger.info(f"Total batch size is {opts.batch_size * world_size}")

    # xxx Set up model
    logger.info(f"Backbone: {opts.backbone}")

    step_checkpoint = None
    model = make_model(opts, classes=tasks.get_per_task_classes(opts.dataset, opts.task, opts.step))
    logger.info(f"[!] Model made with{'out' if opts.no_pretrained else ''} pre-trained")

    if opts.step == 0:  # if step 0, we don't need to instance the model_old
        model_old = None
    else:  # instance model_old
        model_old = make_model(opts, classes=tasks.get_per_task_classes(opts.dataset, opts.task, opts.step - 1))

    if opts.fix_bn:
        model.fix_bn()

    logger.debug(model)

    # xxx Set up optimizer
    params = []
    if not opts.freeze:
        params.append({"params": filter(lambda p: p.requires_grad, model.body.parameters()),
                       'weight_decay': opts.weight_decay})

    params.append({"params": filter(lambda p: p.requires_grad, model.head.parameters()),
                   'weight_decay': opts.weight_decay})

    params.append({"params": filter(lambda p: p.requires_grad, model.cls.parameters()),
                   'weight_decay': opts.weight_decay})

    optimizer = torch.optim.SGD(params, lr=opts.lr, momentum=0.9, nesterov=True)

    if opts.lr_policy == 'poly':
        scheduler = utils.PolyLR(optimizer, max_iters=opts.epochs * len(train_loader), power=opts.lr_power)
    elif opts.lr_policy == 'step':
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=opts.lr_decay_step, gamma=opts.lr_decay_factor)
    else:
        raise NotImplementedError
    logger.debug("Optimizer:\n%s" % optimizer)

    if model_old is not None:
        [model, model_old], optimizer = amp.initialize([model.to(device), model_old.to(device)], optimizer,
                                                       opt_level=opts.opt_level)
        model_old = DistributedDataParallel(model_old)
    else:
        model, optimizer = amp.initialize(model.to(device), optimizer, opt_level=opts.opt_level)

    # Put the model on GPU
    model = DistributedDataParallel(model, delay_allreduce=True)

    # xxx Load old model from old weights if step > 0!
    if opts.step > 0:
        # get model path
        if opts.step_ckpt is not None:
            path = opts.step_ckpt
        else:
            path = f"checkpoints/step/{task_name}_{opts.name}_{opts.step - 1}.pth"

        # generate model from path
        if os.path.exists(path):
            step_checkpoint = torch.load(path, map_location="cpu")
            model.load_state_dict(step_checkpoint['model_state'], strict=False)  # False because of incr. classifiers
            if opts.init_balanced:
                # balanced initialization (new cls gets the background weight and bias = bias_bkg - log(N+1))
                model.module.init_new_classifier(device)
            # Load state dict from the model state dict, that contains the old model parameters
            model_old.load_state_dict(step_checkpoint['model_state'], strict=True)  # Load also here old parameters
            logger.info(f"[!] Previous model loaded from {path}")
            # clean memory
            del step_checkpoint['model_state']
        elif opts.debug:
            logger.info(f"[!] WARNING: Unable to find of step {opts.step - 1}! Do you really want to do from scratch?")
        else:
            raise FileNotFoundError(path)
        # put the old model into distributed memory and freeze it
        for par in model_old.parameters():
            par.requires_grad = False
        model_old.eval()

    # xxx Set up Trainer
    trainer_state = None
    # if not first step, then instance trainer from step_checkpoint
    if opts.step > 0 and step_checkpoint is not None:
        if 'trainer_state' in step_checkpoint:
            trainer_state = step_checkpoint['trainer_state']

    # instance trainer (model must have already the previous step weights)
    trainer = Trainer(model, model_old, device=device, opts=opts, trainer_state=trainer_state,
                      classes=tasks.get_per_task_classes(opts.dataset, opts.task, opts.step))

    # xxx Handle checkpoint for current model (model old will always be as previous step or None)
    best_score = 0.0
    cur_epoch = 0
    if opts.ckpt is not None and os.path.isfile(opts.ckpt):
        checkpoint = torch.load(opts.ckpt, map_location="cpu")
        model.load_state_dict(checkpoint["model_state"], strict=True)
        optimizer.load_state_dict(checkpoint["optimizer_state"])
        scheduler.load_state_dict(checkpoint["scheduler_state"])
        cur_epoch = checkpoint["epoch"] + 1
        best_score = checkpoint['best_score']
        logger.info("[!] Model restored from %s" % opts.ckpt)
        # if we want to resume training, resume trainer from checkpoint
        if 'trainer_state' in checkpoint:
            trainer.load_state_dict(checkpoint['trainer_state'])
        del checkpoint
    else:
        if opts.step == 0:
            logger.info("[!] Train from scratch")

    # xxx Train procedure
    # print opts before starting training to log all parameters
    logger.add_table("Opts", vars(opts))

    if rank == 0 and opts.sample_num > 0:
        sample_ids = np.random.choice(len(val_loader), opts.sample_num, replace=False)  # sample idxs for visualization
        logger.info(f"The samples id are {sample_ids}")
    else:
        sample_ids = None

    label2color = utils.Label2Color(cmap=utils.color_map(opts.dataset))  # convert labels to images
    denorm = utils.Denormalize(mean=[0.485, 0.456, 0.406],
                               std=[0.229, 0.224, 0.225])  # de-normalization for original images

    TRAIN = not opts.test
    val_metrics = StreamSegMetrics(n_classes)
    results = {}

    # sanity check: the RNG state should produce the same value on every rank here
    logger.print(torch.randint(0, 100, (1, 1)))
    # train/val here
    while cur_epoch < opts.epochs and TRAIN:
        # =====  Train  =====
        model.train()

        epoch_loss = trainer.train(cur_epoch=cur_epoch, optim=optimizer,
                                   train_loader=train_loader, scheduler=scheduler, logger=logger)

        logger.info(f"End of Epoch {cur_epoch}/{opts.epochs}, Average Loss={epoch_loss[0]+epoch_loss[1]},"
                    f" Class Loss={epoch_loss[0]}, Reg Loss={epoch_loss[1]}")

        # =====  Log metrics on Tensorboard =====
        logger.add_scalar("E-Loss", epoch_loss[0]+epoch_loss[1], cur_epoch)
        logger.add_scalar("E-Loss-reg", epoch_loss[1], cur_epoch)
        logger.add_scalar("E-Loss-cls", epoch_loss[0], cur_epoch)

        # =====  Validation  =====
        if (cur_epoch + 1) % opts.val_interval == 0:
            logger.info("validate on val set...")
            model.eval()
            val_loss, val_score, ret_samples = trainer.validate(loader=val_loader, metrics=val_metrics,
                                                                ret_samples_ids=sample_ids, logger=logger)

            logger.print("Done validation")
            logger.info(f"End of Validation {cur_epoch}/{opts.epochs}, Validation Loss={val_loss[0]+val_loss[1]},"
                        f" Class Loss={val_loss[0]}, Reg Loss={val_loss[1]}")

            logger.info(val_metrics.to_str(val_score))

            # =====  Save Best Model  =====
            if rank == 0:  # save best model at the last iteration
                score = val_score['Mean IoU']
                # best model to build incremental steps
                save_ckpt(f"checkpoints/step/{task_name}_{opts.name}_{opts.step}.pth",
                          model, trainer, optimizer, scheduler, cur_epoch, score)
                logger.info("[!] Checkpoint saved.")

            # =====  Log metrics on Tensorboard =====
            # visualize validation score and samples
            logger.add_scalar("V-Loss", val_loss[0]+val_loss[1], cur_epoch)
            logger.add_scalar("V-Loss-reg", val_loss[1], cur_epoch)
            logger.add_scalar("V-Loss-cls", val_loss[0], cur_epoch)
            logger.add_scalar("Val_Overall_Acc", val_score['Overall Acc'], cur_epoch)
            logger.add_scalar("Val_MeanIoU", val_score['Mean IoU'], cur_epoch)
            logger.add_table("Val_Class_IoU", val_score['Class IoU'], cur_epoch)
            logger.add_table("Val_Acc_IoU", val_score['Class Acc'], cur_epoch)
            # logger.add_figure("Val_Confusion_Matrix", val_score['Confusion Matrix'], cur_epoch)

            # keep the metric to print them at the end of training
            results["V-IoU"] = val_score['Class IoU']
            results["V-Acc"] = val_score['Class Acc']

            for k, (img, target, lbl) in enumerate(ret_samples):
                img = (denorm(img) * 255).astype(np.uint8)
                target = label2color(target).transpose(2, 0, 1).astype(np.uint8)
                lbl = label2color(lbl).transpose(2, 0, 1).astype(np.uint8)

                concat_img = np.concatenate((img, target, lbl), axis=2)  # concat along width
                logger.add_image(f'Sample_{k}', concat_img, cur_epoch)

        cur_epoch += 1

    # =====  Save Best Model at the end of training =====
    if rank == 0 and TRAIN:  # save best model at the last iteration
        # best model to build incremental steps
        save_ckpt(f"checkpoints/step/{task_name}_{opts.name}_{opts.step}.pth",
                  model, trainer, optimizer, scheduler, cur_epoch, best_score)
        logger.info("[!] Checkpoint saved.")

    torch.distributed.barrier()

    # xxx From here starts the test code
    logger.info("*** Test the model on all seen classes...")
    # make data loader
    test_loader = data.DataLoader(test_dst, batch_size=opts.batch_size if opts.crop_val else 1,
                                  sampler=DistributedSampler(test_dst, num_replicas=world_size, rank=rank),
                                  num_workers=opts.num_workers)

    # load best model
    if TRAIN:
        model = make_model(opts, classes=tasks.get_per_task_classes(opts.dataset, opts.task, opts.step))
        # Put the model on GPU
        model = DistributedDataParallel(model.cuda(device))
        ckpt = f"checkpoints/step/{task_name}_{opts.name}_{opts.step}.pth"
        checkpoint = torch.load(ckpt, map_location="cpu")
        model.load_state_dict(checkpoint["model_state"])
        logger.info(f"*** Model restored from {ckpt}")
        del checkpoint
        trainer = Trainer(model, None, device=device, opts=opts)

    model.eval()

    val_loss, val_score, _ = trainer.validate(loader=test_loader, metrics=val_metrics, logger=logger)
    logger.print("Done test")
    logger.info(f"*** End of Test, Total Loss={val_loss[0]+val_loss[1]},"
                f" Class Loss={val_loss[0]}, Reg Loss={val_loss[1]}")
    logger.info(val_metrics.to_str(val_score))
    logger.add_table("Test_Class_IoU", val_score['Class IoU'])
    logger.add_table("Test_Class_Acc", val_score['Class Acc'])
    logger.add_figure("Test_Confusion_Matrix", val_score['Confusion Matrix'])
    results["T-IoU"] = val_score['Class IoU']
    results["T-Acc"] = val_score['Class Acc']
    logger.add_results(results)

    logger.add_scalar("T_Overall_Acc", val_score['Overall Acc'], opts.step)
    logger.add_scalar("T_MeanIoU", val_score['Mean IoU'], opts.step)
    logger.add_scalar("T_MeanAcc", val_score['Mean Acc'], opts.step)

    logger.close()
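
save_ckpt is called above but not defined in this excerpt; a sketch consistent with the keys that the restore code reads ("epoch", "model_state", "optimizer_state", "scheduler_state", "trainer_state", "best_score"); whether the trainer exposes a state_dict() is an assumption:

import torch

def save_ckpt(path, model, trainer, optimizer, scheduler, epoch, best_score):
    state = {
        "epoch": epoch,
        "model_state": model.state_dict(),
        "optimizer_state": optimizer.state_dict(),
        "scheduler_state": scheduler.state_dict(),
        "trainer_state": trainer.state_dict(),   # assumed counterpart of trainer.load_state_dict
        "best_score": best_score,
    }
    torch.save(state, path)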
Example #8
class Trainer():
    def __init__(self, data_loader, opts):
        #super(Trainer, self).__init__(data_loader, opts)
        #self.opts = opts
        self.train_loader = data_loader[0]
        self.val_loader = data_loader[1]

        # Set up model
        model_map = {
            'deeplabv3_resnet50': network.deeplabv3_resnet50,
            'deeplabv3plus_resnet50': network.deeplabv3plus_resnet50,
            'deeplabv3_resnet101': network.deeplabv3_resnet101,
            'deeplabv3plus_resnet101': network.deeplabv3plus_resnet101,
            'deeplabv3_mobilenet': network.deeplabv3_mobilenet,
            'deeplabv3plus_mobilenet': network.deeplabv3plus_mobilenet
        }

        self.model = model_map[opts.model](num_classes=opts.num_classes, output_stride=opts.output_stride)
        if opts.separable_conv and 'plus' in opts.model:
            network.convert_to_separable_conv(self.model.classifier)

        def set_bn_momentum(model, momentum=0.1):
            for m in model.modules():
                if isinstance(m, nn.BatchNorm2d):
                    m.momentum = momentum

        # PyTorch BN momentum weights the *new* batch statistics:
        # running = (1 - momentum) * running + momentum * batch, so 0.01 updates the
        # running estimates slowly (the opposite of the TF/Keras 0.99 "decay" convention).
        set_bn_momentum(self.model.backbone, momentum=0.01)


        # Set up metrics
        self.metrics = StreamSegMetrics(opts.num_classes)

        # Set up optimizer
        self.optimizer = torch.optim.SGD(params=[
            {'params': self.model.backbone.parameters(), 'lr': 0.1*opts.lr},
            {'params': self.model.classifier.parameters(), 'lr': opts.lr},
        ], lr=opts.lr, momentum=0.9, weight_decay=opts.weight_decay)

        if opts.lr_policy=='poly':
            self.scheduler = utils.PolyLR(self.optimizer, opts.total_itrs, power=0.9)
        elif opts.lr_policy=='step':
            self.scheduler = torch.optim.lr_scheduler.StepLR(self.optimizer, step_size=opts.step_size, gamma=0.1)

        # Set up criterion
        if opts.loss_type == 'focal_loss':
            self.criterion = utils.FocalLoss(ignore_index=255, size_average=True)
        elif opts.loss_type == 'cross_entropy':
            self.criterion = nn.CrossEntropyLoss(ignore_index=255, reduction='mean')


        self.best_mean_iu = 0
        self.iteration = 0


    def _label_accuracy_score(self, label_trues, label_preds, n_class):
        """Returns accuracy score evaluation result.
          - overall accuracy
          - mean accuracy
          - mean IU
          - fwavacc
        """
        def _fast_hist(label_true, label_pred, n_class):
            mask = (label_true >= 0) & (label_true < n_class)
            hist = np.bincount(n_class * label_true[mask].astype(int) +
                           label_pred[mask],
                           minlength=n_class**2).reshape(n_class, n_class)
            return hist



        hist = np.zeros((n_class, n_class))
        for lt, lp in zip(label_trues, label_preds):
            hist += _fast_hist(lt.flatten(), lp.flatten(), n_class)
        acc = np.diag(hist).sum() / hist.sum()
        acc_cls = np.diag(hist) / hist.sum(axis=1)
        acc_cls = np.nanmean(acc_cls)
        iu = np.diag(hist) / (hist.sum(axis=1) + hist.sum(axis=0) - np.diag(hist))
        mean_iu = np.nanmean(iu)
        freq = hist.sum(axis=1) / hist.sum()
        fwavacc = (freq[freq > 0] * iu[freq > 0]).sum()
        return acc, acc_cls, mean_iu, fwavacc


    def validate(self, opts):
        # import matplotlib.pyplot as plt
        training = self.model.training
        self.model.eval()

        n_class = opts.num_classes

        val_loss = 0
        visualizations = []
        label_trues, label_preds = [], []
        with torch.no_grad():
            for i, (data, target) in tqdm(enumerate(self.val_loader)):
            #for batch_idx, (data, target) in tqdm.tqdm(
            #        enumerate(self.val_loader),
            #        total=len(self.val_loader),
            #        desc='Valid iteration=%d' % self.iteration,
            #        ncols=80,
            #        leave=False):
                #print(target)
                data, target = data.to(opts.device, dtype=torch.float), target.to(opts.device)
                score = self.model(data)

                loss = self.criterion(score, target)
                if np.isnan(float(loss.item())):
                    raise ValueError('loss is nan while validating')
                val_loss += float(loss.item()) / len(data)

                imgs = data.data.cpu()
                lbl_pred = score.data.max(1)[1].cpu().numpy()[:, :, :]
                lbl_true = target.data.cpu()
                for img, lt, lp in zip(imgs, lbl_true, lbl_pred):
                    img, lt = self.val_loader.dataset.untransform(img, lt)
                    label_trues.append(lt)
                    label_preds.append(lp)
                    if len(visualizations) < 9:
                        viz = utils.fcn_utils.visualize_segmentation(lbl_pred=lp,
                                                                     lbl_true=lt,
                                                                     img=img,
                                                                     n_class=opts.num_classes)
                        visualizations.append(viz)
        acc, acc_cls, mean_iu, fwavacc = self._label_accuracy_score(label_trues, label_preds, n_class)

        out = os.path.join(opts.output, 'visualization_viz')
        if not os.path.exists(out):
            os.makedirs(out)
        out_file = os.path.join(out, 'iter%012d.jpg' % self.iteration)
        #raise Exception(len(visualizations))
        img_ = utils.fcn_utils.get_tile_image(visualizations)
        imageio.imwrite(out_file, img_)
        # plt.imshow(imageio.imread(out_file))
        # plt.show()

        val_loss /= len(self.val_loader)



        is_best = mean_iu > self.best_mean_iu
        if is_best:  # save best model
            self.best_mean_iu = mean_iu

        def save_ckpt(path):
            """ save current model
            """
            torch.save({
                "cur_itrs": self.iteration,
                "model_state": self.model.module.state_dict(),
                "optimizer_state": self.optimizer.state_dict(),
                "scheduler_state": self.scheduler.state_dict(),
                "best_score": self.best_mean_iu,
            }, path)
            print("Model saved as %s" % path)


        if is_best:
            save_ckpt('checkpoints/best_%s_%s_os%d.pth' % (opts.model, opts.dataset, opts.output_stride))
        else:
            save_ckpt('checkpoints/latest_%s_%s_os%d.pth' % (opts.model, opts.dataset, opts.output_stride))

        if training:
            self.model.train()

        """Do validation and return specified samples"""
        self.metrics.reset()




    def train(self, opts):

        print("Dataset: %s, Train set: %d, Val set: %d" % (opts.dataset, len(self.train_loader), len(self.val_loader)))


        if not os.path.exists('./checkpoints'):
            os.mkdir('./checkpoints')

        # Restore
        best_score = 0.0
        cur_itrs = 0
        cur_epochs = 0
        if opts.ckpt is not None and os.path.isfile(opts.ckpt):
            print('[!] Retrain from a checkpoint')
            # https://github.com/VainF/DeepLabV3Plus-Pytorch/issues/8#issuecomment-605601402, @PytaichukBohdan
            checkpoint = torch.load(opts.ckpt, map_location=torch.device('cpu'))
            self.model.load_state_dict(checkpoint["model_state"])
            self.model = nn.DataParallel(self.model)
            self.model.to(opts.device)
            if opts.continue_training:
                self.optimizer.load_state_dict(checkpoint["optimizer_state"])
                self.scheduler.load_state_dict(checkpoint["scheduler_state"])
                cur_itrs = checkpoint["cur_itrs"]
                best_score = checkpoint['best_score']
                print("Training state restored from %s" % opts.ckpt)
            print("Model restored from %s" % opts.ckpt)
            del checkpoint  # free memory
        else:
            print('[!] Retrain from base network')
            self.model = nn.DataParallel(self.model)
            self.model.to(opts.device)

        #==========   Train Loop   ==========#
        interval_loss = 0
        while True: #cur_itrs < opts.total_itrs:
            # =====  Train  =====
            self.model.train()
            cur_epochs += 1
            for (images, labels) in self.train_loader:
                cur_itrs += 1

                images = images.to(opts.device, dtype=torch.float32)
                labels = labels.to(opts.device, dtype=torch.long)

                self.optimizer.zero_grad()

                outputs = self.model(images)
                loss = self.criterion(outputs, labels)
                loss.backward()
                self.optimizer.step()

                np_loss = loss.detach().cpu().numpy()
                interval_loss += np_loss


                lbl_pred = outputs.data.max(1)[1].cpu().numpy()[:, :, :]
                lbl_true = labels.data.cpu().numpy()
                acc, acc_cls, mean_iu, fwavacc = self._label_accuracy_score(lbl_true, lbl_pred, n_class=opts.num_classes)




                self.iteration = cur_itrs
                '''
                if (cur_itrs) % 10 == 0:
                    interval_loss = interval_loss/10
                    print("Epoch %d, Itrs %d/%d, Loss=%f" %
                          (cur_epochs, cur_itrs, opts.total_itrs, interval_loss))
                    interval_loss = 0.0
                '''

                if cur_itrs % opts.interval_val == 0:
                    print('Epoch %d, Itrs %d/%d' % (cur_epochs, cur_itrs, opts.total_itrs))
                    self.validate(opts)

                if cur_itrs >=  opts.total_itrs:
                    return
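
For reference, a self-contained toy check of the confusion-matrix arithmetic used by _label_accuracy_score above (two classes, four pixels; the expected values are worked out by hand):

import numpy as np

gt   = np.array([0, 0, 1, 1])
pred = np.array([0, 1, 1, 1])
hist = np.bincount(2 * gt + pred, minlength=4).reshape(2, 2)   # [[1, 1], [0, 2]]
acc     = np.diag(hist).sum() / hist.sum()                     # 0.75
acc_cls = np.nanmean(np.diag(hist) / hist.sum(axis=1))         # (0.5 + 1.0) / 2 = 0.75
iu      = np.diag(hist) / (hist.sum(axis=1) + hist.sum(axis=0) - np.diag(hist))
mean_iu = np.nanmean(iu)                                       # (0.5 + 2/3) / 2 ≈ 0.583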
Example #9
def main():
    opts = get_argparser().parse_args()
    if opts.dataset.lower() == 'voc':
        opts.num_classes = 21
    elif opts.dataset.lower() == 'cityscapes':
        opts.num_classes = 19

    # Setup visualization
    vis = Visualizer(port=opts.vis_port,
                     env=opts.vis_env) if opts.enable_vis else None
    if vis is not None:  # display options
        vis.vis_table("Options", vars(opts))

    os.environ['CUDA_VISIBLE_DEVICES'] = opts.gpu_id
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print("Device: %s" % device)

    # Setup random seed
    torch.manual_seed(opts.random_seed)
    np.random.seed(opts.random_seed)
    random.seed(opts.random_seed)

    # Setup dataloader
    if opts.dataset == 'voc' and not opts.crop_val:
        opts.val_batch_size = 1

    train_dst, val_dst = get_dataset(opts)
    train_loader = data.DataLoader(train_dst,
                                   batch_size=opts.batch_size,
                                   shuffle=True,
                                   num_workers=2)
    val_loader = data.DataLoader(val_dst,
                                 batch_size=opts.val_batch_size,
                                 shuffle=True,
                                 num_workers=2)
    print("Dataset: %s, Train set: %d, Val set: %d" %
          (opts.dataset, len(train_dst), len(val_dst)))

    # Set up model
    model_map = {
        'deeplabv3_resnet50': network.deeplabv3_resnet50,
        'deeplabv3plus_resnet50': network.deeplabv3plus_resnet50,
        'deeplabv3_resnet101': network.deeplabv3_resnet101,
        'deeplabv3plus_resnet101': network.deeplabv3plus_resnet101,
        'deeplabv3_mobilenet': network.deeplabv3_mobilenet,
        'deeplabv3plus_mobilenet': network.deeplabv3plus_mobilenet,
        'doubleattention_resnet50': network.doubleattention_resnet50,
        'doubleattention_resnet101': network.doubleattention_resnet101,
        'head_resnet50': network.head_resnet50,
        'head_resnet101': network.head_resnet101
    }

    model = model_map[opts.model](num_classes=opts.num_classes,
                                  output_stride=opts.output_stride)
    if opts.separable_conv and 'plus' in opts.model:
        network.convert_to_separable_conv(model.classifier)
    utils.set_bn_momentum(model.backbone, momentum=0.01)

    # Set up metrics
    metrics = StreamSegMetrics(opts.num_classes)

    # Set up optimizer
    optimizer = torch.optim.SGD(params=[
        {
            'params': model.backbone.parameters(),
            'lr': 0.1 * opts.lr
        },
        {
            'params': model.classifier.parameters(),
            'lr': opts.lr
        },
    ],
                                lr=opts.lr,
                                momentum=0.9,
                                weight_decay=opts.weight_decay)
    # optimizer = torch.optim.SGD(params=model.parameters(), lr=opts.lr, momentum=0.9, weight_decay=opts.weight_decay)
    # torch.optim.lr_scheduler.StepLR(optimizer, step_size=opts.lr_decay_step, gamma=opts.lr_decay_factor)
    if opts.lr_policy == 'poly':
        scheduler = utils.PolyLR(optimizer, opts.total_itrs, power=0.9)
    elif opts.lr_policy == 'step':
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                    step_size=opts.step_size,
                                                    gamma=0.1)

    # Set up criterion
    # criterion = utils.get_loss(opts.loss_type)
    if opts.loss_type == 'focal_loss':
        criterion = utils.FocalLoss(ignore_index=255, size_average=True)
        coss_manifode = utils.ManifondLoss(alpha=1).to(device)
    elif opts.loss_type == 'cross_entropy':
        criterion = nn.CrossEntropyLoss(ignore_index=255, reduction='mean')
        coss_manifode = utils.ManifondLoss(alpha=1).to(device)

    def save_ckpt(path):
        """ save current model
        """
        torch.save(
            {
                "cur_itrs": cur_itrs,
                "model_state": model.module.state_dict(),
                "optimizer_state": optimizer.state_dict(),
                "scheduler_state": scheduler.state_dict(),
                "best_score": best_score,
            }, path)
        print("Model saved as %s" % path)

    utils.mkdir('checkpoints')
    # Restore
    best_score = 0.0
    cur_itrs = 0
    cur_epochs = 0
    if opts.ckpt is not None and os.path.isfile(opts.ckpt):
        # https://github.com/VainF/DeepLabV3Plus-Pytorch/issues/8#issuecomment-605601402, @PytaichukBohdan
        checkpoint = torch.load(opts.ckpt, map_location=torch.device('cpu'))
        model.load_state_dict(checkpoint["model_state"])
        model = nn.DataParallel(model)
        model.to(device)
        if opts.continue_training:
            optimizer.load_state_dict(checkpoint["optimizer_state"])
            scheduler.load_state_dict(checkpoint["scheduler_state"])
            cur_itrs = checkpoint["cur_itrs"]
            best_score = checkpoint['best_score']
            print("Training state restored from %s" % opts.ckpt)
        print("Model restored from %s" % opts.ckpt)
        del checkpoint  # free memory
    else:
        print("[!] Retrain")
        model = nn.DataParallel(model)
        model.to(device)

    # ==========   Train Loop   ==========#
    vis_sample_id = np.random.randint(
        0, len(val_loader), opts.vis_num_samples,
        np.int32) if opts.enable_vis else None  # sample idxs for visualization
    denorm = utils.Denormalize(mean=[0.485, 0.456, 0.406],
                               std=[0.229, 0.224,
                                    0.225])  # denormalization for ori images

    if opts.test_only:
        model.eval()
        val_score, ret_samples = validate(opts=opts,
                                          model=model,
                                          loader=val_loader,
                                          device=device,
                                          metrics=metrics,
                                          ret_samples_ids=vis_sample_id)
        print(metrics.to_str(val_score))
        return

    interval_loss = 0
    while True:  # cur_itrs < opts.total_itrs:
        # =====  Train  =====
        model.train()
        cur_epochs += 1
        for (images, labels) in train_loader:
            cur_itrs += 1

            images = images.to(device, dtype=torch.float32)
            labels = labels.to(device, dtype=torch.long)

            optimizer.zero_grad()
            outputs = model(images)
            # optional manifold regularization, currently disabled:
            # loss = criterion(outputs, labels) + coss_manifode(outputs, labels) * 0.01
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            np_loss = loss.detach().cpu().numpy()
            interval_loss += np_loss
            if vis is not None:
                vis.vis_scalar('Loss', cur_itrs, np_loss)

            if (cur_itrs) % 10 == 0:
                interval_loss = interval_loss / 10
                print("Epoch %d, Itrs %d/%d, Loss=%f" %
                      (cur_epochs, cur_itrs, opts.total_itrs, interval_loss))
                interval_loss = 0.0

            if (cur_itrs) % opts.val_interval == 0:
                save_ckpt('checkpoints/latest_%s_%s_os%d.pth' %
                          (opts.model, opts.dataset, opts.output_stride))
                print("validation...")
                model.eval()
                val_score, ret_samples = validate(
                    opts=opts,
                    model=model,
                    loader=val_loader,
                    device=device,
                    metrics=metrics,
                    ret_samples_ids=vis_sample_id)
                print(metrics.to_str(val_score))
                if val_score['Mean IoU'] > best_score:  # save best model
                    best_score = val_score['Mean IoU']
                    save_ckpt('checkpoints/best_%s_%s_os%d.pth' %
                              (opts.model, opts.dataset, opts.output_stride))

                if vis is not None:  # visualize validation score and samples
                    vis.vis_scalar("[Val] Overall Acc", cur_itrs,
                                   val_score['Overall Acc'])
                    vis.vis_scalar("[Val] Mean IoU", cur_itrs,
                                   val_score['Mean IoU'])
                    vis.vis_table("[Val] Class IoU", val_score['Class IoU'])

                    for k, (img, target, lbl) in enumerate(ret_samples):
                        img = (denorm(img) * 255).astype(np.uint8)
                        target = train_dst.decode_target(target).transpose(
                            2, 0, 1).astype(np.uint8)
                        lbl = train_dst.decode_target(lbl).transpose(
                            2, 0, 1).astype(np.uint8)
                        concat_img = np.concatenate(
                            (img, target, lbl), axis=2)  # concat along width
                        vis.vis_image('Sample %d' % k, concat_img)
                model.train()
            scheduler.step()

            if cur_itrs >= opts.total_itrs:
                return
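
The loop above steps utils.PolyLR once per iteration. A minimal sketch of such a polynomial decay, assuming the usual lr_t = base_lr * (1 - t / total_itrs) ** power rule (the helper name and min_scale floor below are illustrative assumptions, not the repository's API):

import torch

def make_poly_scheduler(optimizer, total_itrs, power=0.9, min_scale=0.0):
    # lr_t = base_lr * (1 - t / total_itrs) ** power, floored at min_scale * base_lr
    def poly_lambda(cur_itr):
        factor = (1.0 - min(cur_itr, total_itrs) / total_itrs) ** power
        return max(factor, min_scale)
    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=poly_lambda)

Calling scheduler.step() once per batch, as the loop above does, then decays every parameter group's learning rate from its base value toward zero over total_itrs steps.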
Example #10
def main():
    opts = get_argparser().parse_args()
    opts = modify_command_options(opts)

    os.environ['CUDA_VISIBLE_DEVICES'] = opts.gpu_id
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print("Device: %s" % device)

    # Set up random seed
    torch.manual_seed(opts.random_seed)
    torch.cuda.manual_seed(opts.random_seed)
    np.random.seed(opts.random_seed)
    random.seed(opts.random_seed)

    # Set up dataloader
    _, val_dst = get_dataset(opts)
    val_loader = data.DataLoader(val_dst, batch_size=opts.batch_size if opts.crop_val else 1 , shuffle=False, num_workers=opts.num_workers)
    print("Dataset: %s, Val set: %d"%(opts.dataset, len(val_dst)))
    
    # Set up model
    print("Backbone: %s"%opts.backbone)
    model = DeepLabv3(num_classes=opts.num_classes, backbone=opts.backbone, pretrained=True, momentum=opts.bn_mom, output_stride=opts.output_stride, use_separable_conv=opts.use_separable_conv)
    if opts.use_gn:
        print("[!] Replace BatchNorm with GroupNorm!")
        model = utils.convert_bn2gn(model)

    if torch.cuda.device_count() > 1:  # data-parallel over multiple GPUs
        print("Using %d GPUs in parallel" % torch.cuda.device_count())
        model = torch.nn.DataParallel(model)
        model_ref = model.module # for ckpt
    else:
        model_ref = model
    model = model.to(device)
    
    # Set up metrics
    metrics = StreamSegMetrics(opts.num_classes)

    if opts.save_path is not None:
        utils.mkdir(opts.save_path)

    # Restore
    if opts.ckpt is not None and os.path.isfile(opts.ckpt):
        checkpoint = torch.load(opts.ckpt)
        model_ref.load_state_dict(checkpoint["model_state"])
        print("Model restored from %s"%opts.ckpt)
    else:
        print("[!] Retrain")
    
    label2color = utils.Label2Color(cmap=utils.color_map(opts.dataset)) # convert labels to images
    denorm = utils.Denormalize(mean=[0.485, 0.456, 0.406],  
                               std=[0.229, 0.224, 0.225])  # denormalization for ori images
    model.eval()
    metrics.reset()
    idx = 0

    if opts.save_path is not None:
        import matplotlib
        matplotlib.use('Agg')
        import matplotlib.pyplot as plt
        

    with torch.no_grad():
        for images, labels in tqdm(val_loader):
            images = images.to(device, dtype=torch.float32)
            labels = labels.to(device, dtype=torch.long)

            outputs = model(images)
            preds = outputs.detach().max(dim=1)[1].cpu().numpy()
            targets = labels.cpu().numpy()
            
            metrics.update(targets, preds)
            if opts.save_path is not None:
                for i in range(len(images)):
                    image = images[i].detach().cpu().numpy()
                    target = targets[i]
                    pred = preds[i]

                    image = (denorm(image) * 255).transpose(1,2,0).astype(np.uint8)
                    target = label2color(target).astype(np.uint8)
                    pred = label2color(pred).astype(np.uint8)

                    Image.fromarray(image).save(os.path.join(opts.save_path, '%d_image.png'%idx) )
                    Image.fromarray(target).save(os.path.join(opts.save_path, '%d_target.png'%idx) )
                    Image.fromarray(pred).save(os.path.join(opts.save_path, '%d_pred.png'%idx) )
                    
                    fig = plt.figure()
                    plt.imshow(image)
                    plt.axis('off')
                    plt.imshow(pred, alpha=0.7)
                    ax = plt.gca()
                    ax.xaxis.set_major_locator(matplotlib.ticker.NullLocator())
                    ax.yaxis.set_major_locator(matplotlib.ticker.NullLocator())
                    plt.savefig(os.path.join(opts.save_path, '%d_overlay.png'%idx), bbox_inches='tight', pad_inches=0)
                    plt.close()
                    idx+=1
                
    score = metrics.get_results()
    print(metrics.to_str(score))
    if opts.save_path is not None:
        with open(os.path.join(opts.save_path, 'score.txt'), mode='w') as f:
            f.write(metrics.to_str(score))
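
utils.Denormalize above undoes a torchvision-style Normalize(mean, std) so tensors can be rendered and saved as images again. A minimal sketch of that behaviour (the class name here is an assumption for illustration, not the repository's code):

import numpy as np

class DenormalizeSketch:
    """Invert (x - mean) / std channel-wise for a C x H x W array."""
    def __init__(self, mean, std):
        self.mean = np.asarray(mean).reshape(-1, 1, 1)
        self.std = np.asarray(std).reshape(-1, 1, 1)

    def __call__(self, img_chw):
        return img_chw * self.std + self.mean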
def main():

    opts = get_argparser().parse_args()

    save_dir = os.path.join(opts.save_dir, opts.model)
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)
    print('Save directory is %s\n' % save_dir)

    # Setup visualization
    vis = Visualizer(port=opts.vis_port,
                     env=opts.vis_env) if opts.enable_vis else None
    if vis is not None:  # display options
        vis.vis_table("Options", vars(opts))

    # select the GPU
    os.environ['CUDA_VISIBLE_DEVICES'] = opts.gpu_id
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print("Device: %s,  CUDA_VISIBLE_DEVICES: %s\n" % (device, opts.gpu_id))

    # Setup random seed
    torch.manual_seed(opts.random_seed)
    np.random.seed(opts.random_seed)
    random.seed(opts.random_seed)

    train_dst, val_dst = get_dataset(opts)
    train_loader = data.DataLoader(train_dst,
                                   batch_size=opts.batch_size,
                                   shuffle=True,
                                   num_workers=8,
                                   drop_last=True,
                                   pin_memory=False)
    val_loader = data.DataLoader(val_dst,
                                 batch_size=opts.batch_size,
                                 shuffle=True,
                                 num_workers=8,
                                 drop_last=True,
                                 pin_memory=False)
    print("Dataset: %s, Train set: %d, Val set: %d" %
          (opts.dataset, len(train_dst), len(val_dst)))

    # Set up model
    model_map = {
        'self_contrast': network.self_contrast,
        'DCNet_L1': network.DCNet_L1,
        'DCNet_L12': network.DCNet_L12,
        'DCNet_L123': network.DCNet_L123,
        'FCN': network.FCN,
        'UNet': network.UNet,
        'SegNet': network.SegNet,
        'cloudSegNet': network.cloudSegNet,
        'cloudUNet': network.cloudUNet
    }

    print('Model = %s, num_classes=%d' % (opts.model, opts.num_classes))
    model = model_map[opts.model](n_classes=opts.num_classes,
                                  is_batchnorm=True,
                                  in_channels=opts.in_channels,
                                  feature_scale=opts.feature_scale,
                                  is_deconv=False)

    # Set up metrics
    metrics = StreamSegMetrics(opts.num_classes)

    # Set up optimizer
    optimizer = torch.optim.SGD(model.parameters(),
                                lr=opts.lr,
                                momentum=0.9,
                                weight_decay=opts.weight_decay)

    if opts.lr_policy == 'poly':
        scheduler = utils.PolyLR(optimizer, opts.total_itrs, power=0.9)
    elif opts.lr_policy == 'step':
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                    step_size=opts.step_size,
                                                    gamma=0.5)

    # Set up criterion
    if opts.loss_type == 'focal_loss':
        criterion = utils.FocalLoss()
    elif opts.loss_type == 'cross_entropy':
        criterion = nn.CrossEntropyLoss()

    def save_ckpt(path):
        """ save current model
        """
        torch.save(
            {
                "cur_itrs": cur_itrs,
                "model_state": model.module.state_dict(),
                "optimizer_state": optimizer.state_dict(),
                "scheduler_state": scheduler.state_dict(),
                "best_score": best_score,
            }, path)
        print("Model saved as %s\n\n" % path)

    # Restore
    best_score = 0.0
    cur_itrs = 0
    cur_epochs = 0
    if opts.ckpt is not None and os.path.isfile(opts.ckpt):
        checkpoint = torch.load(opts.ckpt, map_location=torch.device('cpu'))
        model_dict = model.state_dict()
        pretrained_dict = {
            k: v
            for k, v in checkpoint["model_state"].items() if (k in model_dict)
        }
        model_dict.update(pretrained_dict)
        model.load_state_dict(model_dict)
        model = nn.DataParallel(model)
        model.to(device)
        if opts.continue_training:
            optimizer.load_state_dict(checkpoint["optimizer_state"])
            scheduler.load_state_dict(checkpoint["scheduler_state"])
            scheduler.max_iters = opts.total_itrs
            scheduler.min_lr = opts.lr
            cur_itrs = checkpoint["cur_itrs"]
            best_score = checkpoint['best_score']
            print("Continue training state restored from %s" % opts.ckpt)
        print("Model restored from %s" % opts.ckpt)
        print("Best_score is %s" % (str(best_score)))
        # del checkpoint  # free memory
    else:
        print("[!] Retrain")
        model = nn.DataParallel(model)
        model.to(device)

    # ==========   Train Loop   ==========#
    vis_sample_id = np.random.randint(
        0, len(val_loader), opts.vis_num_samples,
        np.int32) if opts.enable_vis else None  # sample idxs for visualization

    interval_loss = 0
    train_loss = list()
    train_accuracy = list()
    best_val_itrs = list()
    while True:  # cur_itrs < opts.total_itrs:
        # =====  Train  =====
        model.train()
        cur_epochs += 1

        for (images, labels) in train_loader:
            if (cur_itrs) == 0 or (cur_itrs) % opts.print_interval == 0:
                t1 = time.time()

            cur_itrs += 1

            images = images.to(device, dtype=torch.float32)
            labels = labels.to(device, dtype=torch.long)

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)

            loss.backward()
            optimizer.step()

            np_loss = loss.detach().cpu().numpy()
            interval_loss += np_loss

            if (cur_itrs) % opts.print_interval == 0:
                interval_loss = interval_loss / opts.print_interval
                train_loss.append(interval_loss)
                t2 = time.time()
                print("Epoch %d, Itrs %d/%d, Loss=%f, Time = %f" %
                      (cur_epochs, cur_itrs, opts.total_itrs, interval_loss,
                       t2 - t1))
                interval_loss = 0.0

            # save the ckpt file per 5000 itrs
            if (cur_itrs) % opts.val_interval == 0:
                print("validation...")
                model.eval()

                save_ckpt(os.path.join(save_dir, 'latest_%s_%s_itrs%s.pth' %
                          (opts.model, opts.dataset, str(cur_itrs))))
                time_before_val = time.time()
                val_score, ret_samples = validate(
                    opts=opts,
                    model=model,
                    loader=val_loader,
                    device=device,
                    metrics=metrics,
                    ret_samples_ids=vis_sample_id)

                time_after_val = time.time()
                print('Time_val = %f' % (time_after_val - time_before_val))
                print(metrics.to_str(val_score))

                train_accuracy.append(val_score['overall_acc'])
                if val_score['overall_acc'] > best_score:  # save best model
                    best_score = val_score['overall_acc']
                    save_ckpt(os.path.join(save_dir, 'best_%s_%s_.pth' %
                              (opts.model, opts.dataset)))
                    best_val_itrs.append(cur_itrs)
                model.train()
            scheduler.step()  # update

            if cur_itrs >= opts.total_itrs:
                print(cur_itrs)
                print(opts.total_itrs)
                return
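
The training scripts above save model.module.state_dict() because nn.DataParallel wraps the network and prefixes every parameter key with "module."; the last example in this listing loads checkpoints through a state_transform helper for the same reason. A hedged sketch of stripping that prefix when the wrapping does not match at load time:

def strip_module_prefix(state_dict):
    # "module.backbone.conv1.weight" -> "backbone.conv1.weight"
    return {k[len("module."):] if k.startswith("module.") else k: v
            for k, v in state_dict.items()}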
Example #12
def main(criterion):
    # Setup random seed
    torch.manual_seed(opts.random_seed)
    np.random.seed(opts.random_seed)
    random.seed(opts.random_seed)

    # Setup dataloader
    train_dst, val_dst = get_dataset(opts)
    train_loader = data.DataLoader(train_dst,
                                   batch_size=opts.batch_size,
                                   shuffle=True,
                                   num_workers=2)
    val_loader = data.DataLoader(val_dst,
                                 batch_size=opts.val_batch_size,
                                 shuffle=False,
                                 num_workers=2)
    print("Dataset: %s, Train set: %d, Val set: %d" %
          (opts.dataset, len(train_dst), len(val_dst)))

    # Set up model
    pretrained_backbone = "ACE2P" not in opts.model
    model = network.model_map[opts.model](
        num_classes=opts.num_classes,
        output_stride=opts.output_stride,
        pretrained_backbone=pretrained_backbone,
        use_abn=opts.use_abn)
    if opts.use_schp:
        schp_model = network.model_map[opts.model](
            num_classes=opts.num_classes,
            output_stride=opts.output_stride,
            pretrained_backbone=pretrained_backbone,
            use_abn=opts.use_abn)
    if opts.separable_conv and 'plus' in opts.model:
        network.convert_to_separable_conv(model.classifier)
    utils.set_bn_momentum(model.backbone, momentum=0.01)

    # Set up metrics
    metrics = StreamSegMetrics(opts.num_classes)

    # Set up optimizer
    model_params = [
        {
            'params': model.backbone.parameters(),
            'lr': 0.01 * opts.lr
        },
        {
            'params': model.classifier.parameters(),
            'lr': opts.lr
        },
    ]
    optimizer = create_optimizer(opts, model_params=model_params)
    # optimizer = torch.optim.SGD(params=[
    #     {'params': model.backbone.parameters(), 'lr': 0.1 * opts.lr},
    #     {'params': model.classifier.parameters(), 'lr': opts.lr},
    # ], lr=opts.lr, momentum=0.9, weight_decay=opts.weight_decay)
    # optimizer = torch.optim.SGD(params=model.parameters(), lr=opts.lr, momentum=0.9, weight_decay=opts.weight_decay)
    # torch.optim.lr_scheduler.StepLR(optimizer, step_size=opts.lr_decay_step, gamma=opts.lr_decay_factor)
    if opts.lr_policy == 'poly':
        scheduler = utils.PolyLR(optimizer, opts.total_itrs, power=0.9)
    elif opts.lr_policy == 'step':
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                    step_size=opts.step_size,
                                                    gamma=0.1)

    def save_ckpt(path):
        """ save current model
        """
        torch.save(
            {
                "cur_epochs": cur_epochs,
                "cur_itrs": cur_itrs,
                "model_state": model.module.state_dict(),
                "optimizer_state": optimizer.state_dict(),
                "scheduler_state": scheduler.state_dict(),
                "best_score": best_score,
            }, path)
        print("Model saved as %s" % path)

    utils.mkdir('checkpoints')
    # Restore
    best_score = 0.0
    cur_itrs = 0
    cur_epochs = 0
    cycle_n = 0

    if opts.use_schp and opts.schp_ckpt is not None and os.path.isfile(
            opts.schp_ckpt):
        # TODO: there is a problem with this part.
        checkpoint = torch.load(opts.schp_ckpt,
                                map_location=torch.device('cpu'))
        schp_model.load_state_dict(checkpoint["model_state"])
        print("SCHP Model restored from %s" % opts.schp_ckpt)

    if opts.ckpt is not None and os.path.isfile(opts.ckpt):
        checkpoint = torch.load(opts.ckpt, map_location=torch.device('cpu'))
        model.load_state_dict(checkpoint["model_state"])
        model = nn.DataParallel(model)
        model.to(device)
        if opts.use_schp:
            schp_model = nn.DataParallel(schp_model)
            schp_model.to(device)
        if opts.continue_training:
            optimizer.load_state_dict(checkpoint["optimizer_state"])
            scheduler.load_state_dict(checkpoint["scheduler_state"])
            cur_epochs = checkpoint[
                "cur_epochs"] - 1  # to start from the last epoch for schp
            cur_itrs = checkpoint["cur_itrs"]
            best_score = checkpoint['best_score']
            print("Training state restored from %s" % opts.ckpt)
        print("Model restored from %s" % opts.ckpt)
        del checkpoint  # free memory
    else:
        print("[!] Retrain")
        model = nn.DataParallel(model)
        model.to(device)
        if opts.use_schp:
            schp_model = nn.DataParallel(schp_model)
            schp_model.to(device)

    # ==========   Train Loop   ==========#
    if opts.test_only:
        model.eval()
        val_score = validate(opts=opts,
                             model=model,
                             loader=val_loader,
                             device=device,
                             metrics=metrics)
        print(metrics.to_str(val_score))
        return
    interval_loss = 0
    while True:  # cur_itrs < opts.total_itrs:
        # =====  Train  =====
        criterion.start_log()
        model.train()
        cur_epochs += 1
        for (images, labels) in train_loader:
            cur_itrs += 1

            # images = images.to(device, dtype=torch.float32)
            # labels = labels.to(device, dtype=torch.long)
            images, labels = get_input(images, labels, opts, device, cur_itrs)
            if opts.use_mixup:
                images, main_images = images
            else:
                main_images = None
            images = images[:, [2, 1, 0]]  # reverse channel order (RGB -> BGR) for the backbone
            optimizer.zero_grad()
            outputs = model(images)

            if opts.use_schp:
                # Online Self Correction Cycle with Label Refinement
                soft_labels = []
                if cycle_n >= 1:
                    with torch.no_grad():
                        if opts.use_mixup:
                            soft_preds = [
                                schp_model(main_images[0]),
                                schp_model(main_images[1])
                            ]
                            soft_edges = [None, None]
                        else:
                            soft_preds = schp_model(images)
                            soft_edges = None
                        if 'ACE2P' in opts.model:
                            soft_edges = soft_preds[1][-1]
                            soft_preds = soft_preds[0][-1]
                            # soft_parsing = []
                            # soft_edge = []
                            # for soft_pred in soft_preds:
                            #     soft_parsing.append(soft_pred[0][-1])
                            #     soft_edge.append(soft_pred[1][-1])
                            # soft_preds = torch.cat(soft_parsing, dim=0)
                            # soft_edges = torch.cat(soft_edge, dim=0)
                else:
                    if opts.use_mixup:
                        soft_preds = [None, None]
                        soft_edges = [None, None]
                    else:
                        soft_preds = None
                        soft_edges = None
                soft_labels.append(soft_preds)
                soft_labels.append(soft_edges)
                labels = [labels, soft_labels]

            # loss = criterion(outputs, labels)
            loss = calc_loss(criterion, outputs, labels, opts, cycle_n)
            loss.backward()
            optimizer.step()

            criterion.batch_step(len(images))
            np_loss = loss.detach().cpu().numpy()
            interval_loss += np_loss
            sub_loss_text = ''
            for sub_loss, sub_prop in zip(criterion.losses, criterion.loss):
                if sub_prop['weight'] > 0:
                    sub_loss_text += f", {sub_prop['type']}: {sub_loss.item():.4f}"
            print(
                f"\rEpoch {cur_epochs}, Itrs {cur_itrs}/{opts.total_itrs}, Loss={np_loss:.4f}{sub_loss_text}",
                end='')

            if (cur_itrs) % 10 == 0:
                interval_loss = interval_loss / 10
                print(
                    f"\rEpoch {cur_epochs}, Itrs {cur_itrs}/{opts.total_itrs}, Loss={interval_loss:.4f} {criterion.display_loss().replace('][',', ')}"
                )
                interval_loss = 0.0
                torch.cuda.empty_cache()

            if (cur_itrs) % opts.save_interval == 0 and (
                    cur_itrs) % opts.val_interval != 0:
                save_ckpt('checkpoints/latest_%s_%s_os%d.pth' %
                          (opts.model, opts.dataset, opts.output_stride))

            if (cur_itrs) % opts.val_interval == 0:
                save_ckpt('checkpoints/latest_%s_%s_os%d.pth' %
                          (opts.model, opts.dataset, opts.output_stride))
                print("validation...")
                model.eval()
                val_score = validate(opts=opts,
                                     model=model,
                                     loader=val_loader,
                                     device=device,
                                     metrics=metrics)
                print(metrics.to_str(val_score))
                if val_score['Mean IoU'] > best_score:  # save best model
                    best_score = val_score['Mean IoU']
                    save_ckpt('checkpoints/best_%s_%s_os%d.pth' %
                              (opts.model, opts.dataset, opts.output_stride))
                    # save_ckpt('/content/drive/MyDrive/best_%s_%s_os%d.pth' %
                    #           (opts.model, opts.dataset, opts.output_stride))
                model.train()
            scheduler.step()

            if cur_itrs >= opts.total_itrs:
                criterion.end_log(len(train_loader))
                return

        # Self Correction Cycle with Model Aggregation
        if opts.use_schp:
            if (cur_epochs + 1) >= opts.schp_start and (
                    cur_epochs + 1 - opts.schp_start) % opts.cycle_epochs == 0:
                print(f'\nSelf-correction cycle number {cycle_n}')

                schp.moving_average(schp_model, model, 1.0 / (cycle_n + 1))
                cycle_n += 1
                schp.bn_re_estimate(train_loader, schp_model)
                schp.save_schp_checkpoint(
                    {
                        'state_dict': schp_model.state_dict(),
                        'cycle_n': cycle_n,
                    },
                    False,
                    "checkpoints",
                    filename=
                    f'schp_{opts.model}_{opts.dataset}_cycle{cycle_n}_checkpoint.pth'
                )
                # schp.save_schp_checkpoint({
                #     'state_dict': schp_model.state_dict(),
                #     'cycle_n': cycle_n,
                # }, False, '/content/drive/MyDrive/', filename=f'schp_{opts.model}_{opts.dataset}_checkpoint.pth')
        torch.cuda.empty_cache()
        criterion.end_log(len(train_loader))
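
schp.moving_average(schp_model, model, 1.0 / (cycle_n + 1)) above aggregates the online model into the self-correction model before its BatchNorm statistics are re-estimated. A plausible sketch of such a parameter moving average, assuming the usual EMA-style update (an illustration, not the schp package's actual code):

import torch

@torch.no_grad()
def moving_average_sketch(target_model, source_model, alpha):
    # target <- (1 - alpha) * target + alpha * source, parameter by parameter
    for p_t, p_s in zip(target_model.parameters(), source_model.parameters()):
        p_t.mul_(1.0 - alpha).add_(p_s, alpha=alpha)

With alpha = 1 / (cycle_n + 1) this keeps a running average of the models produced over the self-correction cycles; buffers such as BatchNorm statistics are handled separately (here via schp.bn_re_estimate).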
Example #13
def main():
    opts = get_argparser().parse_args()
    if opts.dataset.lower() == 'voc':
        opts.num_classes = 21
        ignore_index = 255
    elif opts.dataset.lower() == 'cityscapes':
        opts.num_classes = 19
        ignore_index = 255
    elif opts.dataset.lower() == 'ade20k':
        opts.num_classes = 150
        ignore_index = -1
    elif opts.dataset.lower() == 'lvis':
        opts.num_classes = 1284
        ignore_index = -1
    elif opts.dataset.lower() == 'coco':
        opts.num_classes = 182
        ignore_index = 255
    if not opts.reduce_dim:
        opts.num_channels = opts.num_classes
    if not opts.test_only:
        writer = SummaryWriter('summary/' + opts.vis_env)
    # Setup visualization
    vis = Visualizer(port=opts.vis_port,
                     env=opts.vis_env) if opts.enable_vis else None
    if vis is not None:  # display options
        vis.vis_table("Options", vars(opts))

    os.environ['CUDA_VISIBLE_DEVICES'] = opts.gpu_id
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print("Device: %s" % device)

    # Setup random seed
    torch.manual_seed(opts.random_seed)
    np.random.seed(opts.random_seed)
    random.seed(opts.random_seed)

    # Setup dataloader
    if opts.dataset == 'voc' and not opts.crop_val:
        opts.val_batch_size = 1

    train_dst, val_dst = get_dataset(opts)
    train_loader = data.DataLoader(train_dst,
                                   batch_size=opts.batch_size,
                                   shuffle=True,
                                   num_workers=2)
    val_loader = data.DataLoader(val_dst,
                                 batch_size=opts.val_batch_size,
                                 shuffle=False,
                                 num_workers=2)
    print("Dataset: %s, Train set: %d, Val set: %d" %
          (opts.dataset, len(train_dst), len(val_dst)))
    epoch_interval = len(train_dst) // opts.batch_size
    opts.val_interval = min(epoch_interval, 5000)  # validate at most every 5000 iterations
    print("Evaluation every %d iterations" % opts.val_interval)

    # Set up model
    model_map = {
        #'deeplabv3_resnet50': network.deeplabv3_resnet50,
        'deeplabv3plus_resnet50': network.deeplabv3plus_resnet50,
        #'deeplabv3_resnet101': network.deeplabv3_resnet101,
        'deeplabv3plus_resnet101': network.deeplabv3plus_resnet101,
        #'deeplabv3_mobilenet': network.deeplabv3_mobilenet,
        'deeplabv3plus_mobilenet': network.deeplabv3plus_mobilenet
    }
    if (opts.reduce_dim):
        num_classes_input = [opts.num_channels, opts.num_classes]
    else:
        num_classes_input = [opts.num_classes]
    model = model_map[opts.model](num_classes=num_classes_input,
                                  output_stride=opts.output_stride,
                                  reduce_dim=opts.reduce_dim)
    if opts.separable_conv and 'plus' in opts.model:
        network.convert_to_separable_conv(model.classifier)
    utils.set_bn_momentum(model.backbone, momentum=0.01)

    # Set up metrics
    metrics = StreamSegMetrics(opts.num_classes)
    # Set up optimizer (separate parameter group for the embedding layer when reduce_dim is on)
    if opts.reduce_dim:
        emb_layer = ['embedding.weight']
        params_classifier = [
            p for name, p in model.classifier.named_parameters()
            if name not in emb_layer
        ]
        params_embedding = [
            p for name, p in model.classifier.named_parameters()
            if name in emb_layer
        ]
        if opts.freeze_backbone:
            for param in model.backbone.parameters():
                param.requires_grad = False
            optimizer = torch.optim.SGD(
                params=[
                    # {'params': model.backbone.parameters(), 'lr': 0.1 * opts.lr},
                    {
                        'params': params_classifier,
                        'lr': opts.lr
                    },
                    {
                        'params': params_embedding,
                        'lr': opts.lr,
                        'momentum': 0.95
                    },
                ],
                lr=opts.lr,
                momentum=0.9,
                weight_decay=opts.weight_decay)
        else:
            optimizer = torch.optim.SGD(params=[
                {
                    'params': model.backbone.parameters(),
                    'lr': 0.1 * opts.lr
                },
                {
                    'params': params_classifier,
                    'lr': opts.lr
                },
                {
                    'params': params_embedding,
                    'lr': opts.lr
                },
            ],
                                        lr=opts.lr,
                                        momentum=0.9,
                                        weight_decay=opts.weight_decay)
    else:
        optimizer = torch.optim.SGD(params=[
            {
                'params': model.backbone.parameters(),
                'lr': 0.1 * opts.lr
            },
            {
                'params': model.classifier.parameters(),
                'lr': opts.lr
            },
        ],
                                    lr=opts.lr,
                                    momentum=0.9,
                                    weight_decay=opts.weight_decay)

    if opts.lr_policy == 'poly':
        scheduler = utils.PolyLR(optimizer, opts.total_itrs, power=0.9)
    elif opts.lr_policy == 'step':
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                    step_size=opts.step_size,
                                                    gamma=0.1)
    elif opts.lr_policy == 'multi_poly':
        scheduler = utils.MultiPolyLR(optimizer,
                                      opts.total_itrs,
                                      power=[0.9, 0.9, 0.95])

    # Set up criterion
    if (opts.reduce_dim):
        opts.loss_type = 'nn_cross_entropy'
    else:
        opts.loss_type = 'cross_entropy'

    if opts.loss_type == 'cross_entropy':
        criterion = nn.CrossEntropyLoss(ignore_index=ignore_index,
                                        reduction='mean')
    elif opts.loss_type == 'nn_cross_entropy':
        criterion = utils.NNCrossEntropy(ignore_index=ignore_index,
                                         reduction='mean',
                                         num_neighbours=opts.num_neighbours,
                                         temp=opts.temp,
                                         dataset=opts.dataset)

    def save_ckpt(path):
        """ save current model
        """
        torch.save(
            {
                "cur_itrs": cur_itrs,
                "model_state": model.module.state_dict(),
                "optimizer_state": optimizer.state_dict(),
                "scheduler_state": scheduler.state_dict(),
                "best_score": best_score,
            }, path)
        print("Model saved as %s" % path)

    utils.mkdir(opts.checkpoint_dir)
    # Restore
    best_score = 0.0
    cur_itrs = 0
    cur_epochs = 0
    if opts.ckpt is not None and os.path.isfile(opts.ckpt):
        checkpoint = torch.load(opts.ckpt, map_location=torch.device('cpu'))
        model.load_state_dict(checkpoint["model_state"])
        model = nn.DataParallel(model)
        model.to(device)
        increase_iters = True
        if opts.continue_training:
            optimizer.load_state_dict(checkpoint["optimizer_state"])
            scheduler.load_state_dict(checkpoint["scheduler_state"])
            cur_itrs = checkpoint["cur_itrs"]
            best_score = checkpoint['best_score']
            print("scheduler state dict :", scheduler.state_dict())
            print("Training state restored from %s" % opts.ckpt)
        print("Model restored from %s" % opts.ckpt)
        del checkpoint  # free memory
    else:
        print("[!] Retrain")
        model = nn.DataParallel(model)
        model.to(device)

    vis_sample_id = np.random.randint(
        0, len(val_loader), opts.vis_num_samples,
        np.int32) if opts.enable_vis else None  # sample idxs for visualization
    denorm = utils.Denormalize(mean=[0.485, 0.456, 0.406],
                               std=[0.229, 0.224,
                                    0.225])  # denormalization for ori images

    if opts.test_only:
        model.eval()
        val_score, ret_samples = validate(opts=opts,
                                          model=model,
                                          loader=val_loader,
                                          device=device,
                                          metrics=metrics,
                                          ret_samples_ids=vis_sample_id)
        print(metrics.to_str(val_score))
        return

    interval_loss = 0

    writer.add_text('lr', str(opts.lr))
    writer.add_text('batch_size', str(opts.batch_size))
    writer.add_text('reduce_dim', str(opts.reduce_dim))
    writer.add_text('checkpoint_dir', opts.checkpoint_dir)
    writer.add_text('dataset', opts.dataset)
    writer.add_text('num_channels', str(opts.num_channels))
    writer.add_text('num_neighbours', str(opts.num_neighbours))
    writer.add_text('loss_type', opts.loss_type)
    writer.add_text('lr_policy', opts.lr_policy)
    writer.add_text('temp', str(opts.temp))
    writer.add_text('crop_size', str(opts.crop_size))
    writer.add_text('model', opts.model)
    accumulation_steps = 1
    writer.add_text('accumulation_steps', str(accumulation_steps))
    j = 0
    updateflag = False
    while True:
        # =====  Train  =====
        model.train()
        cur_epochs += 1
        for (images, labels) in train_loader:
            cur_itrs += 1
            images = images.to(device, dtype=torch.float32)
            labels = labels.to(device, dtype=torch.long)
            if (opts.dataset == 'ade20k' or opts.dataset == 'lvis'):
                labels = labels - 1

            optimizer.zero_grad()
            if (opts.reduce_dim):
                outputs, class_emb = model(images)
                loss = criterion(outputs, labels, class_emb)
            else:
                outputs = model(images)
                loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            model.zero_grad()
            j = j + 1
            np_loss = loss.detach().cpu().numpy()
            interval_loss += np_loss

            if vis is not None:
                vis.vis_scalar('Loss', cur_itrs, np_loss)
                vis.vis_scalar('LR', cur_itrs,
                               scheduler.state_dict()['_last_lr'][0])
            torch.cuda.empty_cache()
            del images, labels, outputs, loss
            if (opts.reduce_dim):
                del class_emb
            gc.collect()
            if (cur_itrs) % 50 == 0:
                interval_loss = interval_loss / 50
                print("Epoch %d, Itrs %d/%d, Loss=%f" %
                      (cur_epochs, cur_itrs, opts.total_itrs, interval_loss))
                writer.add_scalar('Loss', interval_loss, cur_itrs)
                writer.add_scalar('lr',
                                  scheduler.state_dict()['_last_lr'][0],
                                  cur_itrs)
            if cur_itrs % opts.val_interval == 0:
                save_ckpt(opts.checkpoint_dir + '/latest_%d.pth' % cur_itrs)
                print("validation...")
                model.eval()
                val_score, ret_samples = validate(
                    opts=opts,
                    model=model,
                    loader=val_loader,
                    device=device,
                    metrics=metrics,
                    ret_samples_ids=vis_sample_id)
                print(metrics.to_str(val_score))
                if val_score['Mean IoU'] > best_score:  # save best model
                    best_score = val_score['Mean IoU']
                    save_ckpt(opts.checkpoint_dir + '/best_%s_%s_os%d.pth' %
                              (opts.model, opts.dataset, opts.output_stride))

                writer.add_scalar('[Val] Overall Acc',
                                  val_score['Overall Acc'], cur_itrs)
                writer.add_scalar('[Val] Mean IoU', val_score['Mean IoU'],
                                  cur_itrs)
                writer.add_scalar('[Val] Mean Acc', val_score['Mean Acc'],
                                  cur_itrs)
                writer.add_scalar('[Val] Freq Acc', val_score['FreqW Acc'],
                                  cur_itrs)

                if vis is not None:  # visualize validation score and samples
                    vis.vis_scalar("[Val] Overall Acc", cur_itrs,
                                   val_score['Overall Acc'])
                    vis.vis_scalar("[Val] Mean IoU", cur_itrs,
                                   val_score['Mean IoU'])
                    vis.vis_table("[Val] Class IoU", val_score['Class IoU'])

                    for k, (img, target, lbl) in enumerate(ret_samples):
                        img = (denorm(img) * 255).astype(np.uint8)
                        if (opts.dataset.lower() == 'coco'):
                            target = numpy.asarray(
                                train_dst._colorize_mask(target).convert(
                                    'RGB')).transpose(2, 0, 1).astype(np.uint8)
                            lbl = numpy.asarray(
                                train_dst._colorize_mask(lbl).convert(
                                    'RGB')).transpose(2, 0, 1).astype(np.uint8)
                        else:
                            target = train_dst.decode_target(target).transpose(
                                2, 0, 1).astype(np.uint8)
                            lbl = train_dst.decode_target(lbl).transpose(
                                2, 0, 1).astype(np.uint8)
                        concat_img = np.concatenate(
                            (img, target, lbl), axis=2)  # concat along width
                        vis.vis_image('Sample %d' % k, concat_img)
                model.train()
            scheduler.step()
            if cur_itrs >= opts.total_itrs:
                writer.close()
                return
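
In the example above, accumulation_steps is declared (and logged to TensorBoard) but never used in the training step. If gradient accumulation were actually wired in, a common pattern looks like the sketch below (an assumption about intent, not the author's code): gradients are summed over several mini-batches before each optimizer update, so the effective batch size grows without extra memory.

import torch

def train_with_accumulation(model, criterion, optimizer, train_loader, device,
                            accumulation_steps=4):
    model.train()
    optimizer.zero_grad()
    for step, (images, labels) in enumerate(train_loader):
        outputs = model(images.to(device, dtype=torch.float32))
        loss = criterion(outputs, labels.to(device, dtype=torch.long))
        (loss / accumulation_steps).backward()  # scale so the update matches one large batch
        if (step + 1) % accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()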
# Inference only
model.eval()
# create path lists
train_list, val_list = make_datapath_list_a2d2(rootpath)
# Create dataset
train_dataset = A2D2Dataset(file_list=train_list,
                            transform=transform_dict['train'],
                            phase='train',
                            seg_label=SEG_COLOR_DICT_A2D2)
# Create dataloader
train_dataloader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size,
                                               shuffle=True)

# metrics
metrics = StreamSegMetrics(cls_num)

# run inference
for image, seg in tqdm(train_dataloader):
    # feed the image into the model
    image = image.cuda(cuda)
    output = model(image)

    # set up the label
    target = seg.to(dtype=torch.long)
    # PSPNet returns a tuple; keep the main output
    if isinstance(output, tuple):
        output = output[0]
    # move back to CPU
    output = output.to('cpu')
    target = target.to('cpu')
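    # Hypothetical continuation: the snippet above ends before the metrics are
    # updated. Using the StreamSegMetrics interface seen in the other examples,
    # the scoring step would presumably look like this.
    preds = output.max(dim=1)[1].numpy()
    targets = target.numpy()
    metrics.update(targets, preds)

# After the loop, report the accumulated scores.
score = metrics.get_results()
print(metrics.to_str(score))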
def main():

    opts = get_argparser().parse_args()

    # Setup visualization
    vis = Visualizer(port=opts.vis_port,
                     env=opts.vis_env) if opts.enable_vis else None
    if vis is not None:  # display options
        vis.vis_table("Options", vars(opts))

    # select the GPU
    os.environ['CUDA_VISIBLE_DEVICES'] = opts.gpu_id
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print("Device: %s,  CUDA_VISIBLE_DEVICES: %s\n" % (device, opts.gpu_id))

    # Setup random seed
    torch.manual_seed(opts.random_seed)
    np.random.seed(opts.random_seed)
    random.seed(opts.random_seed)

    train_dst, val_dst, test_dst = get_dataset(opts)

    test_loader = data.DataLoader(test_dst,
                                  batch_size=opts.batch_size,
                                  shuffle=True,
                                  num_workers=16,
                                  drop_last=True,
                                  pin_memory=False)

    print("Dataset: %s, Train set: %d, Test set: %d" %
          (opts.dataset, len(train_dst), len(test_dst)))

    # Set up model
    model_map = {
        'self_contrast': network.self_contrast,
        'DCNet_L1': network.DCNet_L1,
        'DCNet_L12': network.DCNet_L12,
        'DCNet_L123': network.DCNet_L123,
        'FCN': network.FCN,
        'UNet': network.UNet,
        'SegNet': network.SegNet,
        'cloudSegNet': network.cloudSegNet,
        'cloudUNet': network.cloudUNet
    }

    print('Model = %s, num_classes=%d' % (opts.model, opts.num_classes))
    model = model_map[opts.model](n_classes=opts.num_classes,
                                  is_batchnorm=True,
                                  in_channels=opts.in_channels,
                                  feature_scale=opts.feature_scale,
                                  is_deconv=False)

    # Set up metrics
    metrics = StreamSegMetrics(opts.num_classes)

    # Restore
    if opts.ckpt is not None and os.path.isfile(opts.ckpt):
        checkpoint = torch.load(opts.ckpt, map_location=torch.device('cpu'))
        model_dict = model.state_dict()
        pretrained_dict = {
            k: v
            for k, v in checkpoint["model_state"].items() if (k in model_dict)
        }
        model_dict.update(pretrained_dict)
        model.load_state_dict(model_dict)
        model = nn.DataParallel(model)
        model.to(device)
        print("Model restored from %s" % opts.ckpt)

    else:
        print("Model checkpoints Error!!!!!!! %s" % opts.ckpt)

    # ==========   Train Loop   ==========#
    vis_sample_id = np.random.randint(
        0, len(test_loader), opts.vis_num_samples,
        np.int32) if opts.enable_vis else None  # sample idxs for visualization

    if opts.test_only:
        model.eval()
        time_before_val = time.time()
        val_score, ret_samples = validate(opts=opts,
                                          model=model,
                                          loader=test_loader,
                                          device=device,
                                          metrics=metrics,
                                          ret_samples_ids=vis_sample_id)
        time_after_val = time.time()
        print('Time_val = %f' % (time_after_val - time_before_val))
        print(metrics.to_str(val_score))

        return
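
StreamSegMetrics is what every example here uses to report 'Overall Acc' and 'Mean IoU'. A compact sketch of how those scores are conventionally derived from a confusion matrix (the standard formulas, not necessarily the library's exact implementation):

import numpy as np

def segmentation_scores_sketch(conf_matrix):
    # conf_matrix[i, j] counts pixels of true class i predicted as class j
    tp = np.diag(conf_matrix).astype(np.float64)
    overall_acc = tp.sum() / conf_matrix.sum()
    union = conf_matrix.sum(axis=1) + conf_matrix.sum(axis=0) - tp
    iou = np.where(union > 0, tp / np.maximum(union, 1), np.nan)  # nan for absent classes
    return {"Overall Acc": overall_acc, "Mean IoU": np.nanmean(iou)}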
Example #16
def main(opts):
    # ===== Setup distributed =====
    distributed.init_process_group(backend='nccl', init_method='env://')
    if opts.device is not None:
        device_id = opts.device
    else:
        device_id = opts.local_rank
    device = torch.device(device_id)
    rank, world_size = distributed.get_rank(), distributed.get_world_size()
    if opts.device is not None:
        torch.cuda.set_device(opts.device)
    else:
        torch.cuda.set_device(device_id)

    # ===== Initialize logging =====
    logdir_full = f"{opts.logdir}/{opts.dataset}/{opts.name}/"
    if rank == 0:
        logger = Logger(logdir_full,
                        rank=rank,
                        debug=opts.debug,
                        summary=opts.visualize)
    else:
        logger = Logger(logdir_full,
                        rank=rank,
                        debug=opts.debug,
                        summary=False)

    logger.print(f"Device: {device}")

    checkpoint_path = f"checkpoints/{opts.dataset}/{opts.name}.pth"
    os.makedirs(f"checkpoints/{opts.dataset}", exist_ok=True)

    # ===== Setup random seed for reproducibility =====
    torch.manual_seed(opts.random_seed)
    torch.cuda.manual_seed(opts.random_seed)
    np.random.seed(opts.random_seed)
    random.seed(opts.random_seed)

    # ===== Set up dataset =====
    train_dst, val_dst = get_dataset(opts, train=True)

    train_loader = data.DataLoader(train_dst,
                                   batch_size=opts.batch_size,
                                   sampler=DistributedSampler(
                                       train_dst,
                                       num_replicas=world_size,
                                       rank=rank),
                                   num_workers=opts.num_workers,
                                   drop_last=True,
                                   pin_memory=True)
    val_loader = data.DataLoader(val_dst,
                                 batch_size=opts.batch_size,
                                 sampler=DistributedSampler(
                                     val_dst,
                                     num_replicas=world_size,
                                     rank=rank),
                                 num_workers=opts.num_workers)
    logger.info(f"Dataset: {opts.dataset}, Train set: {len(train_dst)}, "
                f"Val set: {len(val_dst)}, n_classes {opts.num_classes}")
    logger.info(f"Total batch size is {opts.batch_size * world_size}")
    # This is necessary for computing the scheduler decay
    opts.max_iter = opts.epochs * len(train_loader)

    # ===== Set up model and ckpt =====
    model = Trainer(device, logger, opts)
    model.distribute()

    cur_epoch = 0
    if opts.continue_ckpt:
        opts.ckpt = checkpoint_path
    if opts.ckpt is not None:
        assert os.path.isfile(
            opts.ckpt), "Error, ckpt not found. Check the correct directory"
        checkpoint = torch.load(opts.ckpt, map_location="cpu")
        cur_epoch = checkpoint["epoch"] + 1
        model.load_state_dict(checkpoint["model_state"])
        logger.info("[!] Model restored from %s" % opts.ckpt)
        del checkpoint
    else:
        logger.info("[!] Train from scratch")

    # ===== Train procedure =====
    # print opts before starting training to log all parameters
    logger.add_table("Opts", vars(opts))

    # uncomment if you want qualitative on val
    # if rank == 0 and opts.sample_num > 0:
    #     sample_ids = np.random.choice(len(val_loader), opts.sample_num, replace=False)  # sample idxs for visualization
    #     logger.info(f"The samples id are {sample_ids}")
    # else:
    #     sample_ids = None

    label2color = utils.Label2Color(cmap=utils.color_map(
        opts.dataset))  # convert labels to images
    denorm = utils.Denormalize(mean=[0.485, 0.456, 0.406],
                               std=[0.229, 0.224, 0.225
                                    ])  # de-normalization for original images

    train_metrics = StreamSegMetrics(opts.num_classes)
    val_metrics = StreamSegMetrics(opts.num_classes)
    results = {}

    # sanity check: the random state should be identical across ranks
    logger.print(torch.randint(0, 100, (1, 1)))

    while cur_epoch < opts.epochs and not opts.test:
        # =====  Train  =====
        start = time.time()
        epoch_loss = model.train(cur_epoch=cur_epoch,
                                 train_loader=train_loader,
                                 metrics=train_metrics,
                                 print_int=opts.print_interval)
        train_score = train_metrics.get_results()
        end = time.time()

        len_ep = int(end - start)
        logger.info(
            f"End of Epoch {cur_epoch}/{opts.epochs}, Average Loss={epoch_loss[0] + epoch_loss[1]:.4f}, "
            f"Class Loss={epoch_loss[0]:.4f}, Reg Loss={epoch_loss[1]}\n"
            f"Train_Acc={train_score['Overall Acc']:.4f}, Train_Iou={train_score['Mean IoU']:.4f} "
            f"\n -- time: {len_ep // 60}:{len_ep % 60} -- ")
        logger.info(
            f"Estimated {len_ep * (opts.epochs - cur_epoch) // 60} minutes remaining"
        )

        logger.add_scalar("E-Loss", epoch_loss[0] + epoch_loss[1], cur_epoch)
        # logger.add_scalar("E-Loss-reg", epoch_loss[1], cur_epoch)
        # logger.add_scalar("E-Loss-cls", epoch_loss[0], cur_epoch)

        # =====  Validation  =====
        if (cur_epoch + 1) % opts.val_interval == 0:
            logger.info("validate on val set...")
            val_loss, _ = model.validate(loader=val_loader,
                                         metrics=val_metrics,
                                         ret_samples_ids=None)
            val_score = val_metrics.get_results()

            logger.print("Done validation")
            logger.info(
                f"End of Validation {cur_epoch}/{opts.epochs}, Validation Loss={val_loss}"
            )

            log_val(logger, val_metrics, val_score, val_loss, cur_epoch)

            # keep the metric to print them at the end of training
            results["V-IoU"] = val_score['Class IoU']
            results["V-Acc"] = val_score['Class Acc']

        # =====  Save Model  =====
        if rank == 0:
            if not opts.debug:
                save_ckpt(checkpoint_path, model, cur_epoch)
                logger.info("[!] Checkpoint saved.")

        cur_epoch += 1

    torch.distributed.barrier()

    # ==== TESTING =====
    logger.info("*** Test the model on all seen classes...")
    # make data loader
    test_dst = get_dataset(opts, train=False)
    test_loader = data.DataLoader(test_dst,
                                  batch_size=opts.batch_size_test,
                                  sampler=DistributedSampler(
                                      test_dst,
                                      num_replicas=world_size,
                                      rank=rank),
                                  num_workers=opts.num_workers)

    if rank == 0 and opts.sample_num > 0:
        sample_ids = np.random.choice(len(test_loader),
                                      opts.sample_num,
                                      replace=False)  # sample idxs for visual.
        logger.info(f"The samples id are {sample_ids}")
    else:
        sample_ids = None

    val_loss, ret_samples = model.validate(loader=test_loader,
                                           metrics=val_metrics,
                                           ret_samples_ids=sample_ids)
    val_score = val_metrics.get_results()
    conf_matrixes = val_metrics.get_conf_matrixes()
    logger.print("Done test on all")
    logger.info(f"*** End of Test on all, Total Loss={val_loss}")

    logger.info(val_metrics.to_str(val_score))
    log_samples(logger, ret_samples, denorm, label2color, 0)

    logger.add_figure("Test_Confusion_Matrix_Recall",
                      conf_matrixes['Confusion Matrix'])
    logger.add_figure("Test_Confusion_Matrix_Precision",
                      conf_matrixes["Confusion Matrix Pred"])
    results["T-IoU"] = val_score['Class IoU']
    results["T-Acc"] = val_score['Class Acc']
    results["T-Prec"] = val_score['Class Prec']
    logger.add_results(results)
    logger.add_scalar("T_Overall_Acc", val_score['Overall Acc'])
    logger.add_scalar("T_MeanIoU", val_score['Mean IoU'])
    logger.add_scalar("T_MeanAcc", val_score['Mean Acc'])
    ret = val_score['Mean IoU']

    logger.close()
    return ret
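
One thing worth noting about the distributed example above: the train loader is built with a DistributedSampler, but set_epoch is never called inside the epoch loop (the Trainer class may well handle this internally). With a plain PyTorch loop, reseeding the sampler each epoch is what makes the shuffle change between epochs and stay consistent across ranks, roughly:

# Sketch only; assumes opts.epochs and train_loader as defined above.
for cur_epoch in range(opts.epochs):
    train_loader.sampler.set_epoch(cur_epoch)
    for images, labels in train_loader:
        pass  # forward / backward / optimizer step as in the training loops above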
Example #17
class trainer():
    def __init__(self, model, optimizer, scheduler, device, cfg):
        self.scheduler = scheduler
        self.model = model
        self.cfg = cfg
        self.optimizer = optimizer
        self.device = device
        self.loss_function = MultiLosses(device=device)

        # Setup dataloader
        self.train_dst, self.val_dst = get_dataset(self.cfg)
        self.train_loader = data.DataLoader(self.train_dst,
                                            batch_size=self.cfg.batch_size,
                                            shuffle=True,
                                            num_workers=8,
                                            pin_memory=True)
        self.val_loader = data.DataLoader(self.val_dst,
                                          batch_size=self.cfg.val_batch_size,
                                          shuffle=True,
                                          num_workers=8,
                                          pin_memory=True)
        print("Dataset: %s, Train set: %d, Val set: %d" %
              (self.cfg.dataset, len(self.train_dst), len(self.val_dst)))

        # visdom setup
        vis = Visualizer(port=self.cfg.vis_port,
                         env=self.cfg.vis_env) if self.cfg.enable_vis else None
        if vis is not None:  # display options
            vis.vis_table("Options", vars(self.cfg))
        self.vis = vis
        self.vis_sample_id = np.random.randint(
            0, len(self.val_loader), self.cfg.vis_num_samples, np.int32
        ) if self.cfg.enable_vis else None  # sample idxs for visualization

        # metric
        self.metrics = StreamSegMetrics(self.cfg.num_classes)

    def save_ckpt(self, path, cur_itrs, best_score):
        """ save current model
        """
        torch.save(
            {
                "cur_itrs": cur_itrs,
                "model_state": self.model.module.state_dict(),
                "optimizer_state": self.optimizer.state_dict(),
                "scheduler_state": self.scheduler.state_dict(),
                "best_score": best_score,
            }, path)
        print("Model saved as %s" % path)

    def load_model(self, ckpt=None):
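        """Load `ckpt` if it points to an existing file, wrap the model in
        nn.DataParallel, move it to the target device, and return
        (cur_itrs, best_score) so training can be resumed; otherwise start
        from scratch."""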
        cur_itrs = 0
        best_score = 0
        print("Loading checkpoint: %s" % ckpt)
        if ckpt is not None and os.path.isfile(ckpt):
            # https://github.com/VainF/DeepLabV3Plus-Pytorch/issues/8#issuecomment-605601402, @PytaichukBohdan
            checkpoint = torch.load(ckpt, map_location=self.device)
            old_state = checkpoint["model_state"]
            new_state = state_transform(old_state)
            self.model.load_state_dict(new_state)
            self.model = nn.DataParallel(self.model)
            self.model.to(self.device)
            if self.cfg.continue_training:
                self.optimizer.load_state_dict(checkpoint["optimizer_state"])
                self.scheduler.load_state_dict(checkpoint["scheduler_state"])
                cur_itrs = checkpoint["cur_itrs"]
                best_score = checkpoint['best_score']
                print("Training state restored from %s" % ckpt)
            print("Model restored from %s" % ckpt)
            del checkpoint  # free memory
        else:
            print("[!] Retrain")
            self.model = nn.DataParallel(self.model)
            self.model.to(self.device)
        return cur_itrs, best_score

    def train(self, loss_name):
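        """Run the training loop.

        `loss_name` is either a plain loss key (e.g. 'focal') passed straight to
        MultiLosses, or a 'vN_...' variant that switches between the 'Bound' and
        'dice' losses every iteration according to the vN selection rule below.
        """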
        cur_epochs = 0

        preLoss = []
        loss_record = []
        last_5loss = {}
        lossChange = 0
        # load checkpoints from saved path
        cur_itrs, best_score = self.load_model()

        print("using loss : %s" % loss_name)
        os.makedirs("checkpoints/%s" % loss_name, exist_ok=True)
        while cur_epochs < self.cfg.total_epochs:
            # =====  Train  =====
            self.model.train()
            cur_epochs += 1
            for (images, labels, dist_maps) in self.train_loader:
                cur_itrs += 1

                images = images.to(self.device, dtype=torch.float32)
                labels = labels.to(self.device, dtype=torch.long)
                dist_maps = dist_maps.to(self.device, dtype=torch.float32)
                my_labels = {'label': labels, 'dist_map': dist_maps}
                self.optimizer.zero_grad()
                outputs = self.model(images)

                lossi = []
                maxgrad = 0
                idx = 0
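                # Dynamic loss selection for the 'v*' variants: the 'Bound' and
                # 'dice' losses are both evaluated each iteration and one of them
                # is picked for the backward pass:
                #   v1: the loss with the largest current value
                #   v2: the loss with the largest positive relative increase
                #       w.r.t. the previous iteration
                #   v3: the loss with the largest absolute relative change
                #   v4: as v3, but the change is measured against the mean of the
                #       last 5 recorded values of that loss
                #   v5: as v3, but negative changes are down-weighted by 1.5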
                if 'v' in loss_name:
                    for lit, ll in enumerate(['Bound', 'dice']):
                        if lit not in last_5loss:
                            last_5loss[lit] = np.zeros(5)
                        self.loss_function.bulid_loss(ll)
                        loss = self.loss_function.loss(outputs, my_labels)

                        if 'v2' in loss_name:
                            if len(preLoss) != 0:
                                lossChange = (loss -
                                              preLoss[lit]) / preLoss[lit]

                            if (len(preLoss) == 0 or
                                (lossChange > 0 and lossChange > maxgrad) or
                                (lossChange < 0 and maxgrad < 0
                                 and lossChange < maxgrad)):  # decide which loss to select
                                maxgrad = lossChange
                                idx = ll
                        elif 'v1' in loss_name:
                            if loss > maxgrad:
                                maxgrad = loss
                                idx = ll
                        elif 'v3' in loss_name or 'v5' in loss_name:
                            if len(preLoss) != 0:
                                lossChange = (loss -
                                              preLoss[lit]) / preLoss[lit]
                                if 'v5' in loss_name:
                                    if lossChange < 0:  # a decrease should carry less weight
                                        lossChange /= 1.5
                            if (len(preLoss) == 0
                                    or abs(lossChange) > maxgrad):
                                maxgrad = abs(lossChange)
                                idx = ll
                        elif 'v4' in loss_name:
                            preLoss = np.mean(last_5loss[lit])
                            lossChange = ((loss - preLoss) / preLoss
                                          if preLoss != 0 else 0)
                            if (preLoss == 0 or abs(lossChange) > maxgrad):
                                maxgrad = abs(lossChange)
                                idx = ll
                        lossi.append(loss.item())
                        last_5loss[lit][cur_itrs % 5] = loss.item()
                    preLoss = lossi
                    if cur_itrs % self.cfg.print_interval == 0:
                        loss_record.append([lossi, idx])
                    self.loss_function.bulid_loss(str(idx))  # switch to the selected loss
                else:
                    self.loss_function.bulid_loss(loss_name)

                loss = self.loss_function.loss(outputs, my_labels)
                loss.backward()
                self.optimizer.step()

                np_loss = loss.item()
                if self.vis is not None:
                    self.vis.vis_scalar('Loss', cur_itrs, np_loss)

                if 'v' in loss_name:
                    #print("\riter: {} : criteria: {:.2f}, using loss {} | L2: {:.2f}, ce: {:.2f}, focal: {:.2f}, dice: {:.2f}".format(
                    #    cur_itrs, maxgrad, idx, lossi[0], lossi[1], lossi[2], lossi[3]), end='', flush=True)
                    print(
                        "\riter: {} : criteria: {:.2f}, using loss {} | boundary: {:.2f}, dice: {:.2f}"
                        .format(cur_itrs, maxgrad, idx, lossi[0], lossi[1]),
                        end='',
                        flush=True)
                else:
                    print('\riter: {} | loss: {:.7f}'.format(
                        cur_itrs, np_loss),
                          end='',
                          flush=True)
            # save ckpt after every epoch
            self.save_ckpt(
                'checkpoints/%s/latest_%s_%s_epoch%02d.pth' %
                (loss_name, self.cfg.model, self.cfg.dataset, cur_epochs),
                cur_itrs, best_score)
            print("save ckpt every epoch! cur_epoch: %d, best_score_%f" %
                  (cur_epochs, best_score))
            if self.cfg.eval_every_epoch:
                print("validation...", flush=True)
                self.model.eval()
                val_score, ret_samples = self.validate()
                print(self.metrics.to_str(val_score))
                if val_score['Mean IoU'] > best_score:  # save best model
                    best_score = val_score['Mean IoU']
                    self.save_ckpt(
                        'checkpoints/%s/best_%s_%s.pth' %
                        (loss_name, self.cfg.model, self.cfg.dataset),
                        cur_itrs, best_score)
                    print("save ckpt best! cur_epoch: %d, best_score_%f" %
                          (cur_epochs, best_score))

                if self.vis is not None:  # visualize validation score and samples
                    self.vis.vis_scalar("[Val] Overall Acc", cur_itrs,
                                        val_score['Overall Acc'])
                    self.vis.vis_scalar("[Val] Mean IoU", cur_itrs,
                                        val_score['Mean IoU'])
                    self.vis.vis_table("[Val] Class IoU",
                                       val_score['Class IoU'])

                    for k, (img, target, lbl) in enumerate(ret_samples):
                        img = img.astype(np.uint8)
                        target_rgb = np.zeros_like(img)
                        lbl_rgb = np.zeros_like(img)
                        for i in range(3):
                            target_rgb[i][target > 0] = 255
                            lbl_rgb[i][lbl > 0] = 255

                        target_rgb = target_rgb.astype(np.uint8)
                        lbl_rgb = lbl_rgb.astype(np.uint8)
                        concat_img = np.concatenate(
                            (img, target_rgb, lbl_rgb),
                            axis=2)  # concat along width
                        self.vis.vis_image('Sample %d' % k, concat_img)
                else:
                    print(
                        "iter_%d : overall acc: %.4f, miou: %.4f, class_iou: [0: %.4f, 1: %.4f]"
                        % (cur_itrs, val_score['Overall Acc'],
                           val_score['Mean IoU'], val_score['Class IoU'][0],
                           val_score['Class IoU'][1]),
                        flush=True)
                self.model.train()
            self.scheduler.step()  # the learning rate may change every epoch

        loss_record = np.array(loss_record)
        scio.savemat("./lossRecord-{}.mat".format(loss_name),
                     mdict={'data': loss_record})

        print("saved loss record for %s, shape %s" % (loss_name, loss_record.shape))

    def validate(self, ckpt=None, loss_name='focal'):
        """Do validation and return specified samples"""
        opts = self.cfg
        model = self.model
        loader = self.val_loader
        device = self.device
        metrics = self.metrics
        ret_samples_ids = self.vis_sample_id
        if ckpt is None:
            print("To validate an existing checkpoint, please pass a checkpoint path!")
        else:
            self.load_model(ckpt)
        metrics.reset()
        model.eval()
        ret_samples = []
        if opts.save_val_results:
            os.makedirs('results/%s' % loss_name, exist_ok=True)
            img_id = 0

        with torch.no_grad():
            for i, (images, labels, _) in tqdm(enumerate(loader), total=len(loader)):

                images = images.to(device, dtype=torch.float32)
                labels = labels.to(device, dtype=torch.long)

                outputs = model(images)
                preds = outputs.detach().max(dim=1)[1].cpu().numpy()
                targets = labels.detach().cpu().numpy()

                metrics.update(targets, preds)
                if ret_samples_ids is not None and i in ret_samples_ids:  # get vis samples
                    ret_samples.append((images[0].detach().cpu().numpy(),
                                        targets[0], preds[0]))

                if opts.save_val_results:
                    for img_id in range(len(images)):
                        image = images[img_id].detach().cpu().numpy()
                        target = targets[img_id]
                        pred = preds[img_id]

                        image = image.transpose(1, 2, 0).astype(np.uint8)  # CHW -> HWC
                        target = target.astype(np.uint8)
                        target[target > 0] = 255
                        pred = pred.astype(np.uint8)
                        mask = np.zeros_like(image)
                        mask[pred > 0] = 255

                        cv2.imwrite(
                            'results/%s/%d_%d_image.png' %
                            (loss_name, i, img_id), image)
                        cv2.imwrite(
                            'results/%s/%d_%d_target.png' %
                            (loss_name, i, img_id), target)
                        cv2.imwrite(
                            'results/%s/%d_%d_pred.png' %
                            (loss_name, i, img_id), mask)

                        mask_img = cv2.addWeighted(image, 1, mask, 0.5, 0)
                        cv2.imwrite(
                            'results/%s/%d_%d_overlay.png' %
                            (loss_name, i, img_id), mask_img)

            score = metrics.get_results()
        print(metrics.to_str(score))
        return score, ret_samples
Example #18
0
for m in model.backbone.modules():
    if isinstance(m, nn.BatchNorm2d):
        m.momentum = 0.01

model = nn.DataParallel(model)
model.to(device)

# optimizer
optimizer = torch.optim.SGD(params=model.parameters(), lr=lr, momentum=0.9, weight_decay=wd)
scheduler = utils.PolyLR(optimizer, num_epochs, power=0.9)

# define loss
criterion = torch.nn.CrossEntropyLoss(ignore_index=255, reduction='mean')
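# Note: ignore_index=255 masks out void/unlabeled pixels in the loss; 255 is the
# usual 'ignore' label for Cityscapes-style annotations (19 classes, matching
# StreamSegMetrics(19) below).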

# train-eval loop
metrics = StreamSegMetrics(19)
train_loss_list = []
train_iou_list = []
val_loss_list = []
val_iou_list = []
best_metric = -1
best_metric_epoch = -1

print("Starting training: ")
print("Train images: ",len(train_dataset.images))
print("Val images: ",len(val_dataset.images))

start_epoch = 0

if opts.resume is not None:
  state = torch.load(opts.resume)
Example #19
0
def main():
    opts = parser.parse_args()

    vis = Visualizer(port=opts.vis_port,
                     env=opts.vis_env) if opts.enable_vis else None
    if vis is not None:  # display options
        vis.vis_table("Options", vars(opts))

    os.environ['CUDA_VISIBLE_DEVICES'] = opts.gpu_id
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    print("Device: %s" % device)

    # Setup random seed
    torch.manual_seed(opts.random_seed)
    np.random.seed(opts.random_seed)
    random.seed(opts.random_seed)

    # Set up model
    model_map = {
        'deeplabv3_resnet50': network.deeplabv3_resnet50,
        'deeplabv3plus_resnet50': network.deeplabv3plus_resnet50,
        'deeplabv3_resnet101': network.deeplabv3_resnet101,
        'deeplabv3plus_resnet101': network.deeplabv3plus_resnet101,
        'deeplabv3_mobilenet': network.deeplabv3_mobilenet,
        'deeplabv3plus_mobilenet': network.deeplabv3plus_mobilenet
    }

    model = model_map[opts.model](num_classes=opts.num_classes, output_stride=opts.output_stride)
    if opts.separable_conv and 'plus' in opts.model:
        network.convert_to_separable_conv(model.classifier)
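    # BatchNorm 'momentum' in PyTorch is the update rate of the running statistics,
    # so 0.01 lets the pretrained backbone's running stats drift only slowly.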
    utils.set_bn_momentum(model.backbone, momentum=0.01)

    # Set up metrics
    metrics = StreamSegMetrics(opts.num_classes)

    # Set up optimizer
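    # The pretrained backbone is fine-tuned with a 10x smaller learning rate
    # than the freshly initialized classifier head.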
    optimizer = torch.optim.SGD(params=[
        {'params': model.backbone.parameters(), 'lr': 0.1 * opts.lr},
        {'params': model.classifier.parameters(), 'lr': opts.lr},
    ], lr=opts.lr, momentum=0.9, weight_decay=opts.weight_decay)
    # optimizer = torch.optim.SGD(params=model.parameters(), lr=opts.lr, momentum=0.9, weight_decay=opts.weight_decay)
    # torch.optim.lr_scheduler.StepLR(optimizer, step_size=opts.lr_decay_step, gamma=opts.lr_decay_factor)
    if opts.lr_policy == 'poly':
        scheduler = utils.PolyLR(optimizer, opts.total_itrs, power=0.9)
    elif opts.lr_policy == 'step':
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=opts.step_size, gamma=0.1)
    else:
        scheduler = None
        print("please assign a scheduler!")



    utils.mkdir('checkpoints')
    mytrainer = trainer(model, optimizer, scheduler, device, cfg=opts)
    # ==========   Train Loop   ==========#
    #loss_list = ['bound_dice', 'v3_bound_dice']
    #loss_list = ['v5_bound_dice', 'v4_bound_dice']
    loss_list = ['focal']
    if opts.test_only:
        loss_i = 'v3'
        ckpt = os.path.join("checkpoints", loss_i, "latest_deeplabv3plus_mobilenet_coco_epoch01.pth")
        mytrainer.validate(ckpt, loss_i)
    else:
        for loss_i in loss_list:
            mytrainer.train(loss_i)
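
# Entry point (assumed: the original snippet defines main() but never calls it)
if __name__ == '__main__':
    main()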