# NOTE: these excerpts assume the usual project-level imports (torch, torch.nn as nn,
# numpy as np, os, random, network, utils, cfg, ...), which are not included in the
# snippets themselves.
def real_channel():
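    """Build a DeepLabV3+ (MobileNet backbone) model together with its loss criterion,
    SGD optimizer and learning-rate scheduler, all configured from the global `cfg`."""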
    model_map = {
        'deeplabv3_resnet50': network.deeplabv3_resnet50,
        'deeplabv3plus_resnet50': network.deeplabv3plus_resnet50,
        'deeplabv3_resnet101': network.deeplabv3_resnet101,
        'deeplabv3plus_resnet101': network.deeplabv3plus_resnet101,
        'deeplabv3_mobilenet': network.deeplabv3_mobilenet,
        'deeplabv3plus_mobilenet': network.deeplabv3plus_mobilenet
    }

    model = model_map['deeplabv3plus_mobilenet'](num_classes=cfg.n_classes,
                                                 output_stride=cfg.stride)

    # if opts.separable_conv and 'plus' in opts.model:
    #     network.convert_to_separable_conv(model.classifier)
    utils.set_bn_momentum(model.backbone, momentum=cfg.momentum)

    # Set up metrics
    # metrics = StreamSegMetrics(opts.num_classes)
    l_r = cfg.l_r
    weight_decay = cfg.weight_decay
    lr_policy = cfg.lr_policy
    total_itrs = cfg.total_iter
    step_size = cfg.step_size
    loss_type = cfg.real_loss
    # Set up optimizer
    optimizer = torch.optim.SGD(params=[
        {
            'params': model.backbone.parameters(),
            'lr': 0.1 * l_r
        },
        {
            'params': model.classifier.parameters(),
            'lr': l_r
        },
    ],
                                lr=l_r,
                                momentum=0.9,
                                weight_decay=weight_decay)
    #optimizer = torch.optim.SGD(params=model.parameters(), lr=opts.lr, momentum=0.9, weight_decay=opts.weight_decay)
    #torch.optim.lr_scheduler.StepLR(optimizer, step_size=opts.lr_decay_step, gamma=opts.lr_decay_factor)
    if lr_policy == 'poly':
        scheduler = utils.PolyLR(optimizer, total_itrs, power=0.9)
    elif lr_policy == 'step':
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                    step_size=step_size,
                                                    gamma=0.1)

    # Set up criterion
    #criterion = utils.get_loss(opts.loss_type)
    if loss_type == 'focal_loss':
        criterion = utils.FocalLoss(ignore_index=255, size_average=True)
    elif loss_type == 'cross_entropy':
        criterion = nn.CrossEntropyLoss(ignore_index=255)

    return criterion, optimizer, scheduler, model


# NOTE: the following __init__ belongs to an inference/visualization class whose name
# is not shown in this excerpt; a placeholder class name is assumed here so the
# fragment parses.
class SegmentationPredictor:
    def __init__(self, opts):

        self.denorm = utils.Denormalize(mean=[0.485, 0.456, 0.406],
                                        std=[0.229, 0.224, 0.225])

        model_map = {
            'deeplabv3_resnet50': network.deeplabv3_resnet50,
            'deeplabv3plus_resnet50': network.deeplabv3plus_resnet50,
            'deeplabv3_resnet101': network.deeplabv3_resnet101,
            'deeplabv3plus_resnet101': network.deeplabv3plus_resnet101,
            'deeplabv3_mobilenet': network.deeplabv3_mobilenet,
            'deeplabv3plus_mobilenet': network.deeplabv3plus_mobilenet
        }

        self.opts = opts
        self.device = torch.device(
            'cuda' if torch.cuda.is_available() else 'cpu')

        self.model = model_map[opts['model']](
            num_classes=opts['n_classes'], output_stride=opts['output_stride'])

        if opts['separable_conv'] == 'True' and 'plus' in opts['model']:
            network.convert_to_separable_conv(self.model.classifier)
        utils.set_bn_momentum(self.model.backbone, momentum=0.01)

        checkpoint = torch.load(opts['checkpoint'],
                                map_location=torch.device('cpu'))
        self.model.load_state_dict(checkpoint['model_state'])
        self.model = nn.DataParallel(self.model)
        self.model.to(self.device)
        self.model.eval()

        if not os.path.exists(opts['output']):
            os.makedirs(opts['output'])
        if not os.path.exists(opts['score']):
            os.makedirs(opts['score'])

        # create a color palette, selecting a color for each class
        self.palette = torch.tensor([2**25 - 1, 2**15 - 1, 2**21 - 1])
        self.colors = torch.as_tensor([i for i in range(opts['n_classes'])
                                       ])[:, None] * self.palette
        self.colors = (self.colors % 255).numpy().astype("uint8")
Example #3
def main():
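    """Train a DeepLabV3/DeepLabV3+ variant on VOC or Cityscapes: set up dataloaders,
    model, SGD optimizer, LR scheduler and loss, optionally restore a checkpoint, then
    run the iteration-based training loop with periodic validation and Visdom logging."""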
    opts = get_argparser().parse_args()
    if opts.dataset.lower() == 'voc':
        opts.num_classes = 21
    elif opts.dataset.lower() == 'cityscapes':
        opts.num_classes = 19

    # Setup visualization
    vis = Visualizer(port=opts.vis_port,
                     env=opts.vis_env) if opts.enable_vis else None
    if vis is not None:  # display options
        vis.vis_table("Options", vars(opts))

    os.environ['CUDA_VISIBLE_DEVICES'] = opts.gpu_id
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print("Device: %s" % device)

    # Setup random seed
    torch.manual_seed(opts.random_seed)
    np.random.seed(opts.random_seed)
    random.seed(opts.random_seed)

    # Setup dataloader
    if opts.dataset == 'voc' and not opts.crop_val:
        opts.val_batch_size = 1

    train_dst, val_dst = get_dataset(opts)
    train_loader = data.DataLoader(train_dst,
                                   batch_size=opts.batch_size,
                                   shuffle=True,
                                   num_workers=2)
    val_loader = data.DataLoader(val_dst,
                                 batch_size=opts.val_batch_size,
                                 shuffle=True,
                                 num_workers=2)
    print("Dataset: %s, Train set: %d, Val set: %d" %
          (opts.dataset, len(train_dst), len(val_dst)))

    # Set up model
    model_map = {
        'deeplabv3_resnet50': network.deeplabv3_resnet50,
        'deeplabv3plus_resnet50': network.deeplabv3plus_resnet50,
        'deeplabv3_resnet101': network.deeplabv3_resnet101,
        'deeplabv3plus_resnet101': network.deeplabv3plus_resnet101,
        'deeplabv3_mobilenet': network.deeplabv3_mobilenet,
        'deeplabv3plus_mobilenet': network.deeplabv3plus_mobilenet,
        'doubleattention_resnet50': network.doubleattention_resnet50,
        'doubleattention_resnet101': network.doubleattention_resnet101,
        'head_resnet50': network.head_resnet50,
        'head_resnet101': network.head_resnet101
    }

    model = model_map[opts.model](num_classes=opts.num_classes,
                                  output_stride=opts.output_stride)
    if opts.separable_conv and 'plus' in opts.model:
        network.convert_to_separable_conv(model.classifier)
    utils.set_bn_momentum(model.backbone, momentum=0.01)

    # Set up metrics
    metrics = StreamSegMetrics(opts.num_classes)

    # Set up optimizer
    optimizer = torch.optim.SGD(params=[
        {
            'params': model.backbone.parameters(),
            'lr': 0.1 * opts.lr
        },
        {
            'params': model.classifier.parameters(),
            'lr': opts.lr
        },
    ],
                                lr=opts.lr,
                                momentum=0.9,
                                weight_decay=opts.weight_decay)
    # optimizer = torch.optim.SGD(params=model.parameters(), lr=opts.lr, momentum=0.9, weight_decay=opts.weight_decay)
    # torch.optim.lr_scheduler.StepLR(optimizer, step_size=opts.lr_decay_step, gamma=opts.lr_decay_factor)
    if opts.lr_policy == 'poly':
        scheduler = utils.PolyLR(optimizer, opts.total_itrs, power=0.9)
    elif opts.lr_policy == 'step':
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                    step_size=opts.step_size,
                                                    gamma=0.1)

    # Set up criterion
    # criterion = utils.get_loss(opts.loss_type)
    if opts.loss_type == 'focal_loss':
        criterion = utils.FocalLoss(ignore_index=255, size_average=True)
        coss_manifode = utils.ManifondLoss(alpha=1).to(device)
    elif opts.loss_type == 'cross_entropy':
        criterion = nn.CrossEntropyLoss(ignore_index=255, reduction='mean')
        coss_manifode = utils.ManifondLoss(alpha=1).to(device)

    def save_ckpt(path):
        """ save current model
        """
        torch.save(
            {
                "cur_itrs": cur_itrs,
                "model_state": model.module.state_dict(),
                "optimizer_state": optimizer.state_dict(),
                "scheduler_state": scheduler.state_dict(),
                "best_score": best_score,
            }, path)
        print("Model saved as %s" % path)

    utils.mkdir('checkpoints')
    # Restore
    best_score = 0.0
    cur_itrs = 0
    cur_epochs = 0
    if opts.ckpt is not None and os.path.isfile(opts.ckpt):
        # https://github.com/VainF/DeepLabV3Plus-Pytorch/issues/8#issuecomment-605601402, @PytaichukBohdan
        checkpoint = torch.load(opts.ckpt, map_location=torch.device('cpu'))
        model.load_state_dict(checkpoint["model_state"])
        model = nn.DataParallel(model)
        model.to(device)
        if opts.continue_training:
            optimizer.load_state_dict(checkpoint["optimizer_state"])
            scheduler.load_state_dict(checkpoint["scheduler_state"])
            cur_itrs = checkpoint["cur_itrs"]
            best_score = checkpoint['best_score']
            print("Training state restored from %s" % opts.ckpt)
        print("Model restored from %s" % opts.ckpt)
        del checkpoint  # free memory
    else:
        print("[!] Retrain")
        model = nn.DataParallel(model)
        model.to(device)

    # ==========   Train Loop   ==========#
    vis_sample_id = np.random.randint(
        0, len(val_loader), opts.vis_num_samples,
        np.int32) if opts.enable_vis else None  # sample idxs for visualization
    denorm = utils.Denormalize(mean=[0.485, 0.456, 0.406],
                               std=[0.229, 0.224,
                                    0.225])  # denormalization for ori images

    if opts.test_only:
        model.eval()
        val_score, ret_samples = validate(opts=opts,
                                          model=model,
                                          loader=val_loader,
                                          device=device,
                                          metrics=metrics,
                                          ret_samples_ids=vis_sample_id)
        print(metrics.to_str(val_score))
        return

    interval_loss = 0
    while True:  # cur_itrs < opts.total_itrs:
        # =====  Train  =====
        model.train()
        cur_epochs += 1
        for (images, labels) in train_loader:
            cur_itrs += 1

            images = images.to(device, dtype=torch.float32)
            labels = labels.to(device, dtype=torch.long)

            optimizer.zero_grad()
            outputs = model(images)
            # NOTE: in the original source the combined loss below was computed and then
            # immediately overwritten, so only the plain criterion term takes effect:
            # loss = criterion(outputs, labels) + coss_manifode(outputs, labels) * 0.01
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            np_loss = loss.detach().cpu().numpy()
            interval_loss += np_loss
            if vis is not None:
                vis.vis_scalar('Loss', cur_itrs, np_loss)

            if (cur_itrs) % 10 == 0:
                interval_loss = interval_loss / 10
                print("Epoch %d, Itrs %d/%d, Loss=%f" %
                      (cur_epochs, cur_itrs, opts.total_itrs, interval_loss))
                interval_loss = 0.0

            if (cur_itrs) % opts.val_interval == 0:
                save_ckpt('checkpoints/latest_%s_%s_os%d.pth' %
                          (opts.model, opts.dataset, opts.output_stride))
                print("validation...")
                model.eval()
                val_score, ret_samples = validate(
                    opts=opts,
                    model=model,
                    loader=val_loader,
                    device=device,
                    metrics=metrics,
                    ret_samples_ids=vis_sample_id)
                print(metrics.to_str(val_score))
                if val_score['Mean IoU'] > best_score:  # save best model
                    best_score = val_score['Mean IoU']
                    save_ckpt('checkpoints/best_%s_%s_os%d.pth' %
                              (opts.model, opts.dataset, opts.output_stride))

                if vis is not None:  # visualize validation score and samples
                    vis.vis_scalar("[Val] Overall Acc", cur_itrs,
                                   val_score['Overall Acc'])
                    vis.vis_scalar("[Val] Mean IoU", cur_itrs,
                                   val_score['Mean IoU'])
                    vis.vis_table("[Val] Class IoU", val_score['Class IoU'])

                    for k, (img, target, lbl) in enumerate(ret_samples):
                        img = (denorm(img) * 255).astype(np.uint8)
                        target = train_dst.decode_target(target).transpose(
                            2, 0, 1).astype(np.uint8)
                        lbl = train_dst.decode_target(lbl).transpose(
                            2, 0, 1).astype(np.uint8)
                        concat_img = np.concatenate(
                            (img, target, lbl), axis=2)  # concat along width
                        vis.vis_image('Sample %d' % k, concat_img)
                model.train()
            scheduler.step()

            if cur_itrs >= opts.total_itrs:
                return
Example #4
def main():
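    """Run inference with a trained DeepLabV3/DeepLabV3+ checkpoint on a single image
    or a directory of images and save the colorized segmentation maps to disk."""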
    opts = get_argparser().parse_args()
    if opts.dataset.lower() == 'voc':
        opts.num_classes = 21
        decode_fn = VOCSegmentation.decode_target
    elif opts.dataset.lower() == 'cityscapes':
        opts.num_classes = 19
        decode_fn = Cityscapes.decode_target

    os.environ['CUDA_VISIBLE_DEVICES'] = opts.gpu_id
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print("Device: %s" % device)

    # Setup dataloader
    image_files = []
    if os.path.isdir(opts.input):
        for ext in ['png', 'jpeg', 'jpg', 'JPEG']:
            files = glob(os.path.join(opts.input, '**/*.%s'%(ext)), recursive=True)
            if len(files)>0:
                image_files.extend(files)
    elif os.path.isfile(opts.input):
        image_files.append(opts.input)
    
    # Set up model (all models are constructed at network.modeling)
    model = network.modeling.__dict__[opts.model](num_classes=opts.num_classes, output_stride=opts.output_stride)
    if opts.separable_conv and 'plus' in opts.model:
        network.convert_to_separable_conv(model.classifier)
    utils.set_bn_momentum(model.backbone, momentum=0.01)
    
    if opts.ckpt is not None and os.path.isfile(opts.ckpt):
        # https://github.com/VainF/DeepLabV3Plus-Pytorch/issues/8#issuecomment-605601402, @PytaichukBohdan
        checkpoint = torch.load(opts.ckpt, map_location=torch.device('cpu'))
        model.load_state_dict(checkpoint["model_state"])
        model = nn.DataParallel(model)
        model.to(device)
        print("Resume model from %s" % opts.ckpt)
        del checkpoint
    else:
        print("[!] Retrain")
        model = nn.DataParallel(model)
        model.to(device)

    #denorm = utils.Denormalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # denormalization for ori images

    if opts.crop_val:
        transform = T.Compose([
                T.Resize(opts.crop_size),
                T.CenterCrop(opts.crop_size),
                T.ToTensor(),
                T.Normalize(mean=[0.485, 0.456, 0.406],
                                std=[0.229, 0.224, 0.225]),
            ])
    else:
        transform = T.Compose([
                T.ToTensor(),
                T.Normalize(mean=[0.485, 0.456, 0.406],
                                std=[0.229, 0.224, 0.225]),
            ])
    if opts.save_val_results_to is not None:
        os.makedirs(opts.save_val_results_to, exist_ok=True)
    with torch.no_grad():
        model = model.eval()
        for img_path in tqdm(image_files):
            ext = os.path.basename(img_path).split('.')[-1]
            img_name = os.path.basename(img_path)[:-len(ext)-1]
            img = Image.open(img_path).convert('RGB')
            img = transform(img).unsqueeze(0) # To tensor of NCHW
            img = img.to(device)
            
            pred = model(img).max(1)[1].cpu().numpy()[0] # HW
            colorized_preds = decode_fn(pred).astype('uint8')
            colorized_preds = Image.fromarray(colorized_preds)
            if opts.save_val_results_to:
                colorized_preds.save(os.path.join(opts.save_val_results_to, img_name+'.png'))
Example #5
def main():
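    """Train a DeepLab variant or a UNet on the CT/DICOM data returned by get_data_dcm,
    using a combined BCE + multiclass Dice loss, with periodic validation and
    checkpointing of the best average Dice score."""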
    opts = get_argparser().parse_args()

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print("Device: %s" % device)

    # Setup random seed
    torch.manual_seed(opts.random_seed)
    np.random.seed(opts.random_seed)
    random.seed(opts.random_seed)
    img_h = 512
    img_w = 512

    torch.cuda.empty_cache()
    train_data, test_data = get_data_dcm(img_h=img_h, img_w=img_w, iscrop=True)
    train_loader = data.DataLoader(train_data,
                                   batch_size=opts.batch_size,
                                   shuffle=True,
                                   num_workers=0)
    val_loader = data.DataLoader(test_data,
                                 batch_size=opts.val_batch_size,
                                 shuffle=False,
                                 num_workers=0)

    model_map = {
        'deeplabv3_resnet50': network.deeplabv3_resnet50,
        'deeplabv3plus_resnet50': network.deeplabv3plus_resnet50,
        'deeplabv3_resnet101': network.deeplabv3_resnet101,
        'deeplabv3plus_resnet101': network.deeplabv3plus_resnet101,
        'deeplabv3_mobilenet': network.deeplabv3_mobilenet,
        'deeplabv3plus_mobilenet': network.deeplabv3plus_mobilenet
    }

    if opts.model != 'unet':
        opts.num_classes = 3
        model = model_map[opts.model](num_classes=opts.num_classes,
                                      output_stride=opts.output_stride)
        if opts.separable_conv and 'plus' in opts.model:
            network.convert_to_separable_conv(model.classifier)
        utils.set_bn_momentum(model.backbone, momentum=0.01)

        # Set up optimizer
        optimizer = torch.optim.SGD(params=[
            {
                'params': model.backbone.parameters(),
                'lr': 0.1 * opts.lr
            },
            {
                'params': model.classifier.parameters(),
                'lr': opts.lr
            },
        ],
                                    lr=opts.lr,
                                    momentum=0.9,
                                    weight_decay=opts.weight_decay)
    else:
        opts.num_classes = 3
        model = UNet(n_channels=3, n_classes=3, bilinear=True)

        # Set up optimizer
        optimizer = torch.optim.SGD(model.parameters(),
                                    lr=opts.lr,
                                    momentum=0.9,
                                    weight_decay=opts.weight_decay)

    scheduler = utils.PolyLR(optimizer, opts.total_itrs, power=0.9)

    criterion_bce = nn.BCELoss(reduction='mean')
    criterion_dice = MulticlassDiceLoss()

    def save_ckpt(path):
        """ save current model
        """
        torch.save(
            {
                "cur_itrs": cur_itrs,
                "model_state": model.module.state_dict(),
                "optimizer_state": optimizer.state_dict(),
                "scheduler_state": scheduler.state_dict(),
                "best_score": best_score,
            }, path)
        print("Model saved as %s" % path)

    utils.mkdir('checkpoints')
    # Restore
    best_score = 0.0
    cur_itrs = 0
    cur_epochs = 0

    if opts.ckpt is not None and os.path.isfile(opts.ckpt):
        checkpoint = torch.load(opts.ckpt, map_location=torch.device('cpu'))
        model.load_state_dict(checkpoint["model_state"])
        model = nn.DataParallel(model)
        model.to(device)
        if opts.continue_training:
            optimizer.load_state_dict(checkpoint["optimizer_state"])
            scheduler.load_state_dict(checkpoint["scheduler_state"])
            cur_itrs = checkpoint["cur_itrs"]
            best_score = checkpoint['best_score']
            print("Training state restored from %s" % opts.ckpt)
        print("Model restored from %s" % opts.ckpt)
        del checkpoint  # free memory
    else:
        print("[!] Retrain")
        model = nn.DataParallel(model)
        model.to(device)

    if opts.test_only:
        # model.load_state_dict()
        model.eval()
        dice_bl, dice_sb, dice_st, acc = validate2(
            model=model,
            loader=val_loader,
            device=device,
            itrs=cur_itrs,
            lr=scheduler.get_lr()[-1],
            criterion_dice=criterion_dice)
        # save_ckpt("./checkpoints/CT_" + opts.model + "_" + str(round(dice, 3)) + "__" + str(cur_itrs) + ".pkl")
        print("dice值:", dice_bl)
        return
    best_dice_bl = 0
    best_dice_sb = 0
    best_dice_st = 0
    best_dice_avg = 0
    interval_loss = 0
    train_iter = iter(train_loader)

    txt_path = './train_info.txt'
    # txtUtils.clearTxt(txt_path)

    while True:  # cur_itrs < opts.total_itrs:
        # =====  Train  =====
        model.train()
        try:
            images, labels = next(train_iter)
        except StopIteration:
            train_iter = iter(train_loader)
            images, labels = next(train_iter)

        cur_itrs += 1
        # print(images.size())
        # print(labels.size())
        images = images.to(device, dtype=torch.float32)
        # labels = labels.to(device, dtype=torch.long)
        labels = labels.to(device, dtype=torch.float32)
        # print(images.size())

        outputs = model(images)
        outputs_ = torch.sigmoid(outputs)

        loss = criterion_bce(outputs_, labels) + criterion_dice(
            outputs_, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        np_loss = loss.item()
        interval_loss += np_loss

        if (cur_itrs) % 50 == 0:
            interval_loss = interval_loss / 50
            cur_epochs = cur_itrs // len(train_loader.dataset)
            print("Epoch %d, Itrs %d/%d, Loss=%f" %
                  (cur_epochs, cur_itrs, opts.total_itrs, interval_loss))

            content = ("Epoch {}, Itrs {}/{}, Loss={}").format(
                cur_epochs, cur_itrs, opts.total_itrs, interval_loss)
            txtUtils.writeInfoToTxt(file_path=txt_path,
                                    content=content,
                                    is_add_time=True)
            interval_loss = 0.0
        # opts.val_interval=5
        if (cur_itrs) % 500 == 0:
            print("validation... lr:", scheduler.get_lr())
            content = ("validation... lr:{}").format(scheduler.get_lr())
            txtUtils.writeInfoToTxt(file_path=txt_path,
                                    content=content,
                                    is_add_time=True)
            # print(outputs)
            dice_bl, dice_sb, dice_st, acc = validate2(
                model=model,
                loader=val_loader,
                device=device,
                itrs=cur_itrs,
                lr=scheduler.get_lr()[-1],
                criterion_dice=criterion_dice)
            dice_avg = (dice_bl + dice_sb + dice_st) / 3

            content = (
                "dice_bl:{}, dice_sb:{}, dice_st:{}, acc:{}, dice_avg:{}"
            ).format(dice_bl, dice_sb, dice_st, acc, dice_avg)
            txtUtils.writeInfoToTxt(file_path=txt_path,
                                    content=content,
                                    is_add_time=True)

            if best_dice_avg < dice_avg:
                best_dice_avg = dice_avg
                save_ckpt("./checkpoints/" + opts.model + "_dice_avg_" +
                          str(round(best_dice_avg, 3)) + "_dice_bl_" +
                          str(round(dice_bl, 3)) + "_dice_sb_" +
                          str(round(dice_sb, 3)) + "_dice_st_" +
                          str(round(dice_st, 3)) + "__" + str(cur_itrs) +
                          ".pkl")
                print("best avg dice:", best_dice_avg)
                content = ("best avg dice: {}").format(best_dice_avg)
                txtUtils.writeInfoToTxt(file_path=txt_path,
                                        content=content,
                                        is_add_time=True)

            if best_dice_bl < dice_bl:
                best_dice_bl = dice_bl
                content = ("best bl dice: {}").format(best_dice_bl)
                txtUtils.writeInfoToTxt(file_path=txt_path,
                                        content=content,
                                        is_add_time=True)
            if best_dice_sb < dice_sb:
                best_dice_sb = dice_sb
                content = ("best sb dice: {}").format(best_dice_sb)
                txtUtils.writeInfoToTxt(file_path=txt_path,
                                        content=content,
                                        is_add_time=True)
            if best_dice_st < dice_st:
                best_dice_st = dice_st
                content = ("best st dice: {}").format(best_dice_st)
                txtUtils.writeInfoToTxt(file_path=txt_path,
                                        content=content,
                                        is_add_time=True)

        scheduler.step()

        if cur_itrs >= opts.total_itrs:
            return
Example #6
def main():
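    """Train a separate DeepLab model for each class of the dataset, validating and
    checkpointing every per-class model under checkpoints/multiple_model2."""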
    opts = get_argparser().parse_args()
    if opts.dataset.lower() == 'voc':
        opts.num_classes = 21
    elif opts.dataset.lower() == 'cityscapes':
        opts.num_classes = 19

    # Setup visualization
    vis = Visualizer(port=opts.vis_port,
                     env=opts.vis_env) if opts.enable_vis else None
    if vis is not None:  # display options
        vis.vis_table("Options", vars(opts))

    os.environ['CUDA_VISIBLE_DEVICES'] = opts.gpu_id
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print("Device: %s" % device)

    # Setup random seed
    torch.manual_seed(opts.random_seed)
    np.random.seed(opts.random_seed)
    random.seed(opts.random_seed)

    # Setup dataloader
    if opts.dataset == 'voc' and not opts.crop_val:
        opts.val_batch_size = 1

    # Set up metrics
    # metrics = StreamSegMetrics(opts.num_classes)
    metrics = StreamSegMetrics(21)
    # Set up optimizer
    # criterion = utils.get_loss(opts.loss_type)
    if opts.loss_type == 'focal_loss':
        criterion = utils.FocalLoss(ignore_index=255, size_average=True)
    elif opts.loss_type == 'cross_entropy':
        criterion = nn.CrossEntropyLoss(ignore_index=255, reduction='mean')
    elif opts.loss_type == 'logit':
        criterion = nn.BCELoss(reduction='mean')

    def save_ckpt(path):
        """ save current model
        """
        torch.save({
            "cur_itrs": cur_itrs,
            "model_state": model.module.state_dict(),
            "optimizer_state": optimizer.state_dict(),
            "scheduler_state": scheduler.state_dict(),
            "best_score": best_score,
        }, path)
        print("Model saved as %s" % path)

    utils.mkdir('checkpoints')
    # Restore
    best_score = 0.0
    cur_itrs = 0
    cur_epochs = 0
    if opts.ckpt is not None:
        print("Error --ckpt, can't read model")
        return

    _, val_dst, test_dst = get_dataset(opts)
    val_loader = data.DataLoader(
        val_dst, batch_size=opts.val_batch_size, shuffle=True, num_workers=2)
    test_loader = data.DataLoader(
        test_dst, batch_size=opts.val_batch_size, shuffle=True, num_workers=2)
    vis_sample_id = np.random.randint(0, len(test_loader), opts.vis_num_samples,
                                      np.int32) if opts.enable_vis else None  # sample idxs for visualization

    denorm = utils.Denormalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # denormalization for ori images
    # ==========   Test Loop   ==========#

    if opts.test_only:
        print("Dataset: %s,  Val set: %d, Test set: %d" %
              (opts.dataset, len(val_dst), len(test_dst)))

        metrics = StreamSegMetrics(21)
        print("val")

        test_score, ret_samples = test_single(opts=opts,
                                              loader=test_loader, device=device, metrics=metrics,
                                              ret_samples_ids=vis_sample_id)
        print("test")
        test_score, ret_samples = test_multiple(
            opts=opts, loader=test_loader, device=device, metrics=metrics, ret_samples_ids=vis_sample_id)
        print(metrics.to_str(test_score))
        return

    # ==========   Train Loop   ==========#
    utils.mkdir('checkpoints/multiple_model2')
    for class_num in range(opts.start_class, opts.num_classes):
        # ==========   Dataset   ==========#
        train_dst, val_dst, test_dst = get_dataset_multiple(opts, class_num)
        train_loader = data.DataLoader(train_dst, batch_size=opts.batch_size, shuffle=True, num_workers=2)
        val_loader = data.DataLoader(val_dst, batch_size=opts.val_batch_size, shuffle=True, num_workers=2)
        test_loader = data.DataLoader(test_dst, batch_size=opts.val_batch_size, shuffle=True, num_workers=2)
        print("Dataset: %s Class %d, Train set: %d, Val set: %d, Test set: %d" % (
            opts.dataset, class_num, len(train_dst), len(val_dst), len(test_dst)))

        # ==========   Model   ==========#
        # NOTE: model_map is not defined in this snippet; the same DeepLabV3/DeepLabV3+
        # model dictionary used in the earlier examples is assumed here.
        model = model_map[opts.model](num_classes=opts.num_classes, output_stride=opts.output_stride)
        if opts.separable_conv and 'plus' in opts.model:
            network.convert_to_separable_conv(model.classifier)
        utils.set_bn_momentum(model.backbone, momentum=0.01)

        # ==========   Params and learning rate   ==========#
        params_list = [
            {'params': model.backbone.parameters(), 'lr': 0.1 * opts.lr},
            {'params': model.classifier.parameters(), 'lr': 0.1 * opts.lr}  # opts.lr
        ]
        if 'SA' in opts.model:
            params_list.append({'params': model.attention.parameters(), 'lr': 0.1 * opts.lr})
        optimizer = torch.optim.Adam(params=params_list, lr=opts.lr, weight_decay=opts.weight_decay)

        if opts.lr_policy == 'poly':
            scheduler = utils.PolyLR(optimizer, opts.total_itrs, power=0.9)
        elif opts.lr_policy == 'step':
            scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=opts.step_size, gamma=0.1)

        model = nn.DataParallel(model)
        model.to(device)

        best_score = 0.0
        cur_itrs = 0
        cur_epochs = 0

        interval_loss = 0
        while True:  # cur_itrs < opts.total_itrs:
            # =====  Train  =====
            model.train()

            cur_epochs += 1
            for (images, labels) in train_loader:
                cur_itrs += 1

                images = images.to(device, dtype=torch.float32)
                labels = labels.to(device, dtype=torch.long)
                # labels=(labels==class_num).float()
                optimizer.zero_grad()
                outputs = model(images)

                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()

                np_loss = loss.detach().cpu().numpy()
                interval_loss += np_loss
                if vis is not None:
                    vis.vis_scalar('Loss', cur_itrs, np_loss)

                if (cur_itrs) % 10 == 0:
                    interval_loss = interval_loss / 10
                    print("Epoch %d, Itrs %d/%d, Loss=%f" %
                          (cur_epochs, cur_itrs, opts.total_itrs, interval_loss))
                    interval_loss = 0.0

                if (cur_itrs) % opts.val_interval == 0:
                    save_ckpt('checkpoints/multiple_model2/latest_%s_%s_class%d_os%d.pth' %
                              (opts.model, opts.dataset, class_num, opts.output_stride,))
                    print("validation...")
                    model.eval()
                    val_score, ret_samples = validate(
                        opts=opts, model=model, loader=val_loader, device=device, metrics=metrics,
                        ret_samples_ids=vis_sample_id, class_num=class_num)
                    print(metrics.to_str(val_score))

                    if val_score['Mean IoU'] > best_score:  # save best model
                        best_score = val_score['Mean IoU']
                        save_ckpt('checkpoints/multiple_model2/best_%s_%s_class%d_os%d.pth' %
                                  (opts.model, opts.dataset, class_num, opts.output_stride))

                    if vis is not None:  # visualize validation score and samples
                        vis.vis_scalar("[Val] Overall Acc", cur_itrs, val_score['Overall Acc'])
                        vis.vis_scalar("[Val] Mean IoU", cur_itrs, val_score['Mean IoU'])
                        vis.vis_table("[Val] Class IoU", val_score['Class IoU'])

                        for k, (img, target, lbl) in enumerate(ret_samples):
                            img = (denorm(img) * 255).astype(np.uint8)
                            target = train_dst.decode_target(target).transpose(2, 0, 1).astype(np.uint8)
                            lbl = train_dst.decode_target(lbl).transpose(2, 0, 1).astype(np.uint8)
                            concat_img = np.concatenate((img, target, lbl), axis=2)  # concat along width
                            vis.vis_image('Sample %d' % k, concat_img)
                    model.train()

                scheduler.step()

                if cur_itrs >= opts.total_itrs:
                    save_ckpt('checkpoints/multiple_model2/latest_%s_%s_class%d_os%d.pth' %
                              (opts.model, opts.dataset, class_num, opts.output_stride,))
                    print("Saving..")
                    break
            if cur_itrs >= opts.total_itrs:
                cur_itrs = 0
                break

        print("Model of class %d is trained and saved " % (class_num))
Example #7
def main(opts):
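    """Teacher/student training entry point: build a DeepLabV3(+) model (and, in
    'student' mode, a frozen teacher for distillation), restore checkpoints, then train
    and validate for opts.total_epochs epochs.

    NOTE: train_loader, device, metrics, validate, train_teacher and train_student are
    not defined in this excerpt and are assumed to be module-level objects in the
    original source."""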
    # Set up model
    model_map = {
        'v3_resnet50': network.deeplabv3_resnet50,
        'v3plus_resnet50': network.deeplabv3plus_resnet50,
        'v3_resnet101': network.deeplabv3_resnet101,
        'v3plus_resnet101': network.deeplabv3plus_resnet101,
        'v3_mobilenet': network.deeplabv3_mobilenet,
        'v3plus_mobilenet': network.deeplabv3plus_mobilenet
    }
    
    best_score = 0.0
    epoch      = 0
    
    if opts.ckpt is not None and os.path.isfile(opts.ckpt):
        checkpoint = torch.load(opts.ckpt, map_location=torch.device('cpu'))
        checkpoint['teacher_opts']['save_val_results'] = opts.save_val_results
        checkpoint['teacher_opts']['ckpt'] = opts.ckpt
        opts = utils.Bunch(checkpoint['teacher_opts'])
    
    model = model_map[opts.model](num_classes=opts.num_classes, output_stride=opts.output_stride, opts=opts)
    teacher = None
    utils.set_bn_momentum(model.backbone, momentum=0.01)
    
    macs, params = utils.count_flops(model, opts)
    if (opts.count_flops):
        return
    utils.create_result(opts, macs, params)
    
    # Set up optimizer and criterion
    optimizer = torch.optim.SGD(params=[
        {'params': model.backbone.parameters(), 'lr': 0.1*opts.lr},
        {'params': model.classifier.parameters(), 'lr': opts.lr},
    ], lr=opts.lr, momentum=0.9, weight_decay=opts.weight_decay)
    scheduler = utils.PolyLR(optimizer, opts.total_epochs * len(train_loader), power=0.9)
    criterion = nn.CrossEntropyLoss(ignore_index=255, reduction='mean')
    
    
    
    # Load from checkpoint
    if opts.ckpt is not None and os.path.isfile(opts.ckpt):
        checkpoint = torch.load(opts.ckpt, map_location=torch.device('cpu'))
        
        
        model.load_state_dict(checkpoint["model_state"])
        model = nn.DataParallel(model)
        model.to(device)
        optimizer.load_state_dict(checkpoint["optimizer_state"])
        scheduler.load_state_dict(checkpoint["scheduler_state"])
        epoch = checkpoint.get("epoch", 0)
        best_score = checkpoint.get('best_score', 0.0)
        print("Model restored from %s" % opts.ckpt)
        del checkpoint  # free memory 
    else:
        model = nn.DataParallel(model)
        model.to(device)
        
    if opts.save_val_results:
        score = validate(model)
        print(metrics.to_str(score)) 
        return
    
    if opts.mode == "student":
        checkpoint = torch.load(opts.teacher_ckpt, map_location=torch.device('cpu'))
        checkpoint['teacher_opts']['at_type'] = opts.at_type
        
        teacher_opts = utils.Bunch(checkpoint['teacher_opts'])
        
        teacher = model_map[teacher_opts.model](num_classes=opts.num_classes, output_stride=teacher_opts.output_stride, opts=teacher_opts)
        teacher.load_state_dict(checkpoint["model_state"])
        teacher = nn.DataParallel(teacher)
        teacher.to(device)
        for param in teacher.parameters():
            param.requires_grad = False
    
    # =====  Train  =====
    
    for epoch in tqdm(range(epoch, opts.total_epochs)):
        
        if opts.mode == "teacher":
            train_teacher(model, optimizer, criterion, scheduler)
        else:
            train_student(model, teacher, optimizer, criterion, scheduler)
        
        score = validate(model)
        print(metrics.to_str(score))
        utils.save_result(score, opts)
        
        if score['Mean IoU'] > best_score or (opts.max_epochs != opts.total_epochs and epoch+1 == opts.total_epochs):
            best_score = score['Mean IoU']
            utils.save_ckpt(opts.data_root, opts, model, optimizer, scheduler, best_score, epoch+1) 
Example #8
def main():
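    """Profiling-oriented training loop that feeds Cityscapes through an NVIDIA DALI
    pipeline and wraps data loading, forward and backward passes in NVTX ranges; it
    stops after a handful of iterations."""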
    opts = get_argparser().parse_args()
    if opts.dataset.lower() == 'voc':
        opts.num_classes = 21
    elif opts.dataset.lower() == 'cityscapes':
        opts.num_classes = 19

    # Setup visualization
    vis = Visualizer(port=opts.vis_port,
                     env=opts.vis_env) if opts.enable_vis else None
    if vis is not None:  # display options
        vis.vis_table("Options", vars(opts))

    os.environ['CUDA_VISIBLE_DEVICES'] = opts.gpu_id
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print("Device: %s" % device)

    # Setup random seed
    torch.manual_seed(opts.random_seed)
    np.random.seed(opts.random_seed)
    random.seed(opts.random_seed)

    # Setup dataloader
    if opts.dataset=='voc' and not opts.crop_val:
        opts.val_batch_size = 1
    

    pipe = create_dali_pipeline(batch_size=opts.batch_size, num_threads=8,
                                device_id=0, data_dir="/home/ubuntu/cityscapes")
    pipe.build()
    train_loader = DALIGenericIterator(pipe, output_map=['image', 'label'], last_batch_policy=LastBatchPolicy.PARTIAL)

    # Set up model
    model_map = {
        'deeplabv3_resnet50': network.deeplabv3_resnet50,
        'deeplabv3plus_resnet50': network.deeplabv3plus_resnet50,
        'deeplabv3_resnet101': network.deeplabv3_resnet101,
        'deeplabv3plus_resnet101': network.deeplabv3plus_resnet101,
        'deeplabv3_mobilenet': network.deeplabv3_mobilenet,
        'deeplabv3plus_mobilenet': network.deeplabv3plus_mobilenet
    }

    model = model_map[opts.model](num_classes=opts.num_classes, output_stride=opts.output_stride)
    if opts.separable_conv and 'plus' in opts.model:
        network.convert_to_separable_conv(model.classifier)
    utils.set_bn_momentum(model.backbone, momentum=0.01)
    
    # Set up metrics
    metrics = StreamSegMetrics(opts.num_classes)

    # Set up optimizer
    optimizer = torch.optim.SGD(params=[
        {'params': model.backbone.parameters(), 'lr': 0.1*opts.lr},
        {'params': model.classifier.parameters(), 'lr': opts.lr},
    ], lr=opts.lr, momentum=0.9, weight_decay=opts.weight_decay)
    #optimizer = torch.optim.SGD(params=model.parameters(), lr=opts.lr, momentum=0.9, weight_decay=opts.weight_decay)
    #torch.optim.lr_scheduler.StepLR(optimizer, step_size=opts.lr_decay_step, gamma=opts.lr_decay_factor)
    if opts.lr_policy=='poly':
        scheduler = utils.PolyLR(optimizer, opts.total_itrs, power=0.9)
    elif opts.lr_policy=='step':
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=opts.step_size, gamma=0.1)

    # Set up criterion
    #criterion = utils.get_loss(opts.loss_type)
    if opts.loss_type == 'focal_loss':
        criterion = utils.FocalLoss(ignore_index=255, size_average=True)
    elif opts.loss_type == 'cross_entropy':
        criterion = nn.CrossEntropyLoss(ignore_index=255, reduction='mean')

    def save_ckpt(path):
        """ save current model
        """
        torch.save({
            "cur_itrs": cur_itrs,
            "model_state": model.module.state_dict(),
            "optimizer_state": optimizer.state_dict(),
            "scheduler_state": scheduler.state_dict(),
            "best_score": best_score,
        }, path)
        print("Model saved as %s" % path)
    
    utils.mkdir('checkpoints')
    # Restore
    best_score = 0.0
    cur_itrs = 0
    cur_epochs = 0
    if opts.ckpt is not None and os.path.isfile(opts.ckpt):
        # https://github.com/VainF/DeepLabV3Plus-Pytorch/issues/8#issuecomment-605601402, @PytaichukBohdan
        checkpoint = torch.load(opts.ckpt, map_location=torch.device('cpu'))
        model.load_state_dict(checkpoint["model_state"])
        model = nn.DataParallel(model)
        model.to(device)
        if opts.continue_training:
            optimizer.load_state_dict(checkpoint["optimizer_state"])
            scheduler.load_state_dict(checkpoint["scheduler_state"])
            cur_itrs = checkpoint["cur_itrs"]
            best_score = checkpoint['best_score']
            print("Training state restored from %s" % opts.ckpt)
        print("Model restored from %s" % opts.ckpt)
        del checkpoint  # free memory
    else:
        print("[!] Retrain")
        model = nn.DataParallel(model)
        model.to(device)

    #==========   Train Loop   ==========#
    interval_loss = 0

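    # NOTE: class_conv appears to map raw Cityscapes label ids to the 19 training ids
    # (255 = ignore); it is not actually applied in the profiling loop below.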
    class_conv = [255, 255, 255, 255, 255,
                  255, 255, 255, 0, 1,
                  255, 255, 2, 3, 4,
                  255, 255, 255, 5, 255,
                  6, 7, 8, 9, 10,
                  11, 12, 13, 14, 15,
                  255, 255, 16, 17, 18]

    while True: #cur_itrs < opts.total_itrs:
        # =====  Train  =====
        model.train()
        #model = model.half()
        cur_epochs += 1
        # Obtain the DALI iterator once per epoch rather than re-creating it on every
        # inner-loop iteration.
        train_iter = iter(train_loader)
        while True:
            try:
                nvtx.range_push("Batch " + str(cur_itrs))

                nvtx.range_push("Data loading")
                data = next(train_iter)
                cur_itrs += 1

                images = data[0]['image'].to(dtype=torch.float32)
                labels = data[0]['label'][:, :, :, 0].to(dtype=torch.long)
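                # NOTE: the decoded labels are immediately overwritten with zeros on the
                # next line, so this loop only exercises the forward/backward pass.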
                labels = torch.zeros(data[0]['label'][:, :, :, 0].shape).to(device, dtype=torch.long)
                nvtx.range_pop()

                nvtx.range_push("Forward pass")
                optimizer.zero_grad()
                outputs = model(images)
                loss = criterion(outputs, labels)
                nvtx.range_pop()

                nvtx.range_push("Backward pass")
                loss.backward()
                optimizer.step()
                nvtx.range_pop()

                np_loss = loss.detach().cpu().numpy()
                interval_loss += np_loss

                nvtx.range_pop()

                if cur_itrs == 10:
                    break

                if vis is not None:
                    vis.vis_scalar('Loss', cur_itrs, np_loss)

                print("Epoch %d, Itrs %d/%d, Loss=%f" %
                      (cur_epochs, cur_itrs, opts.total_itrs, interval_loss))
                interval_loss = 0.0

                scheduler.step()  

                if cur_itrs >=  opts.total_itrs:
                    return
            except StopIteration:
                break

        break
Example #9
def main():
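    """Evaluate a trained DeepLab checkpoint on the validation set, then re-score the
    saved logits with CRF post-processing and report both sets of metrics."""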
    opts = get_argparser().parse_args()
    os.environ['CUDA_VISIBLE_DEVICES'] = opts.gpu_id
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print("Device: %s" % device)

    # Setup random seed
    torch.manual_seed(opts.random_seed)
    np.random.seed(opts.random_seed)
    random.seed(opts.random_seed)

    os.makedirs(opts.logit_dir, exist_ok=True)

    # Setup dataloader
    if not opts.crop_val:
        opts.val_batch_size = 1

    val_dst = get_dataset(opts)
    val_loader = data.DataLoader(val_dst,
                                 batch_size=opts.val_batch_size,
                                 shuffle=False,
                                 num_workers=4)

    print("Dataset: voc, Val set: %d" % (len(val_dst)))

    # Set up model
    model_map = {
        'deeplabv3_resnet50': network.deeplabv3_resnet50,
        'deeplabv3plus_resnet50': network.deeplabv3plus_resnet50,
        'deeplabv3_resnet101': network.deeplabv3_resnet101,
        'deeplabv3plus_resnet101': network.deeplabv3plus_resnet101,
        'deeplabv3_mobilenet': network.deeplabv3_mobilenet,
        'deeplabv3plus_mobilenet': network.deeplabv3plus_mobilenet
    }

    model = model_map[opts.model](num_classes=opts.num_classes,
                                  output_stride=opts.output_stride)
    if opts.separable_conv and 'plus' in opts.model:
        network.convert_to_separable_conv(model.classifier)
    utils.set_bn_momentum(model.backbone, momentum=0.01)

    # Set up metrics
    metrics = StreamSegMetrics(opts.num_classes)

    # Restore
    if opts.ckpt is not None and os.path.isfile(opts.ckpt):
        # https://github.com/VainF/DeepLabV3Plus-Pytorch/issues/8#issuecomment-605601402, @PytaichukBohdan
        checkpoint = torch.load(opts.ckpt, map_location=torch.device('cpu'))
        model.load_state_dict(checkpoint["model_state"])
        model = nn.DataParallel(model)
        model.to(device)
        print("Model restored from %s" % opts.ckpt)
        del checkpoint  # free memory
    else:
        # The original `assert "no checkpoint"` never fails (a non-empty string is
        # always truthy), so raise explicitly instead.
        raise RuntimeError("no checkpoint")

    #==========   Eval   ==========#
    model.eval()
    val_score = validate(opts=opts,
                         model=model,
                         loader=val_loader,
                         device=device,
                         metrics=metrics)
    print(metrics.to_str(val_score))

    print("\n\n----------- crf -------------")
    crf_score = crf_inference(opts, val_dst, metrics)
    print(metrics.to_str(crf_score))

    os.system(f"rm -rf {opts.logit_dir}")
Example #10
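# NOTE: this example starts part-way through an evaluation routine (presumably the
# getDNWmat function called below); variables such as Evares, eva, mk, checkpoint_path,
# foldName and lossname are defined earlier in the original source and are not shown here.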
            #eva = GetMultiThresholdEva0(res.copy(), mask.copy(), 0.5)
            Evares[i][mk] = eva
            if (mk % 100 == 0):
                print(
                    checkpoint_path,
                    mk,
                    'mIoU={:.3f}, mdice={:.3f}, IoU0={:.3f}, dice0={:.3f}, IoU1={:.3f}, dice1={:.3f}'
                    .format(eva[0], eva[1], eva[2], eva[3], eva[4], eva[5]),
                    flush=True)
            mk += 1

    Evares = np.array(Evares)
    print("end!!!!")
    scio.savemat(os.path.join(foldName, 'DNW_' + lossname + '.mat'),
                 mdict={'data': Evares})


#loss_name = ['v4', 'focal', 'v1', 'dice', 'ce', 'L2', 'mix', 'Bound', 'v5_4loss', "v5_5loss"]
#loss_name = ['dice', 'Bound']
loss_name = ['L2', 'v3']
model = network.deeplabv3plus_mobilenet(num_classes=cfg.num_classes,
                                        output_stride=cfg.output_stride)
#network.convert_to_separable_conv(model.classifier)
utils.set_bn_momentum(model.backbone, momentum=0.01)
for lossi in loss_name:
    print("====================test " + lossi + " ======================")
    ckptlittledir = os.path.join(checkpointdir, lossi)
    getDNWmat(model, ckptlittledir, lossi)

print("all loss has been evaluated!")
Example #11
def main():
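    """Entry point with three modes: 'baseline_train' and 'finetune' train a ResNet
    classifier on TinyImageNet, while 'pretrain' trains a DeepLabV3 backbone with an
    MSE (colorization-style) objective; supports resuming from checkpoints and periodic
    saving of model, optimizer and scheduler state."""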
    global args, best_prec1
    global cur_itrs
    args = parser.parse_args()
    print(args.mode)

    # STEP1: model
    if args.mode=='baseline_train':
        model = initialize_model(use_resnet=True, pretrained=False, nclasses=200)
    elif args.mode=='pretrain':
        model = deeplab_network.deeplabv3_resnet50(num_classes=args.num_classes, output_stride=args.output_stride, pretrained_backbone=False)
        set_bn_momentum(model.backbone, momentum=0.01)
    elif args.mode=='finetune':
        model = initialize_model(use_resnet=True, pretrained=False, nclasses=3)
        # load the pretrained model
        if args.pretrained_model:
            if os.path.isfile(args.pretrained_model):
                print("=> loading pretrained model '{}'".format(args.pretrained_model))
                checkpoint = torch.load(args.pretrained_model)
                args.start_epoch = checkpoint['epoch']
                best_prec1 = checkpoint['best_prec1']
                model.load_state_dict(checkpoint['state_dict'])
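                # NOTE: `optimizer` is only created in STEP2 below, so loading its state
                # here would raise a NameError as written; optimizer state is normally
                # restored through the --resume path instead.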
                optimizer.load_state_dict(checkpoint['optimizer'])
                print("=> loaded pretrained model '{}' (epoch {})".format(args.resume, checkpoint['epoch']))
    if torch.cuda.is_available():
        model = model.cuda()
    
    # STEP2: criterion and optimizer
    if args.mode in ['baseline_train', 'finetune']:
        criterion = nn.CrossEntropyLoss()
        optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
        # train_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.step_size, gamma=args.gamma) 
    elif args.mode=='pretrain':
        criterion = nn.MSELoss()
        optimizer = torch.optim.SGD(params=[
        {'params': model.backbone.parameters(), 'lr': 0.1*args.lr},
        {'params': model.classifier.parameters(), 'lr': args.lr},
    ], lr=args.lr, momentum=0.9, weight_decay=args.weight_decay)
        scheduler = PolyLR(optimizer, args.total_itrs, power=0.9)

    # STEP3: loss/prec record
    if args.mode in ['baseline_train', 'finetune']:
        train_losses = []
        train_top1s = []
        train_top5s = []

        test_losses = []
        test_top1s = []
        test_top5s = []
    elif args.mode == 'pretrain':
        train_losses = []
        test_losses = []

    # STEP4: optionlly resume from a checkpoint
    if args.resume:
        print('resume')
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            if args.mode in ['baseline_train', 'finetune']:
                checkpoint = torch.load(args.resume)
                args.start_epoch = checkpoint['epoch']
                best_prec1 = checkpoint['best_prec1']
                model.load_state_dict(checkpoint['state_dict'])
                optimizer.load_state_dict(checkpoint['optimizer'])
                datafile = args.resume.split('.pth')[0] + '.npz'
                load_data = np.load(datafile)
                train_losses = list(load_data['train_losses'])
                train_top1s = list(load_data['train_top1s'])
                train_top5s = list(load_data['train_top5s'])
                test_losses = list(load_data['test_losses'])
                test_top1s = list(load_data['test_top1s'])
                test_top5s = list(load_data['test_top5s'])
            elif args.mode=='pretrain':
                checkpoint = torch.load(args.resume)
                args.start_epoch = checkpoint['epoch']
                model.load_state_dict(checkpoint['state_dict'])
                optimizer.load_state_dict(checkpoint['optimizer'])
                scheduler.load_state_dict(checkpoint['scheduler'])
                cur_itrs = checkpoint['cur_itrs']
                datafile = args.resume.split('.pth')[0] + '.npz'
                load_data = np.load(datafile)
                train_losses = list(load_data['train_losses'])
                # test_losses = list(load_data['test_losses'])
            print("=> loaded checkpoint '{}' (epoch {})".format(args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    # STEP5: train!
    if args.mode in ['baseline_train', 'finetune']:
        # data
        from utils import TinyImageNet_data_loader
        print('color_distortion:', args.color_distortion)
        train_loader, val_loader = TinyImageNet_data_loader(args.dataset, args.batch_size,color_distortion=args.color_distortion)
        
        # if evaluate the model
        if args.evaluate:
            print('evaluate this model on validation dataset')
            validate(val_loader, model, criterion, args.print_freq)
            return
        
        for epoch in range(args.start_epoch, args.epochs):
            adjust_learning_rate(optimizer, epoch, args.lr)
            time1 = time.time() #timekeeping

            # train for one epoch
            model.train()
            loss, top1, top5 = train(train_loader, model, criterion, optimizer, epoch, args.print_freq)
            train_losses.append(loss)
            train_top1s.append(top1)
            train_top5s.append(top5)

            # evaluate on validation set
            model.eval()
            loss, prec1, prec5 = validate(val_loader, model, criterion, args.print_freq)
            test_losses.append(loss)
            test_top1s.append(prec1)
            test_top5s.append(prec5)

            # remember the best prec@1 and save checkpoint
            is_best = prec1 > best_prec1
            best_prec1 = max(prec1, best_prec1)

            save_checkpoint({
                'epoch': epoch + 1,
                'mode': args.mode,
                'state_dict': model.state_dict(),
                'best_prec1': best_prec1,
                'optimizer': optimizer.state_dict()
            }, is_best, args.mode + '_' + args.dataset +'.pth')

            np.savez(args.mode + '_' + args.dataset +'.npz', train_losses=train_losses,train_top1s=train_top1s,train_top5s=train_top5s, test_losses=test_losses,test_top1s=test_top1s, test_top5s=test_top5s)
           # np.savez(args.mode + '_' + args.dataset +'.npz', train_losses=train_losses)
            time2 = time.time() #timekeeping
            print('Elapsed time for epoch:',time2 - time1,'s')
            print('ETA of completion:',(time2 - time1)*(args.epochs - epoch - 1)/60,'minutes')
            print()
    elif args.mode=='pretrain':
        #data
        from utils import TinyImageNet_data_loader
        # args.dataset = 'tiny-imagenet-200'
        args.batch_size = 16
        train_loader, val_loader = TinyImageNet_data_loader(args.dataset, args.batch_size, col=True)
        
        # if evaluate the model, show some results
        if args.evaluate:
            print('evaluate this model on validation dataset')
            visulization(val_loader, model, args.start_epoch)
            return

        # for epoch in range(args.start_epoch, args.epochs):
        epoch = 0
        while True:
            if cur_itrs >=  args.total_itrs:
                return
            # adjust_learning_rate(optimizer, epoch, args.lr)
            time1 = time.time() #timekeeping

            model.train()
            # train for one epoch
            # loss, _, _ = train(train_loader, model, criterion, optimizer, epoch, args.print_freq, colorization=True,scheduler=scheduler)
            # train_losses.append(loss)
            

            # model.eval()
            # # evaluate on validation set
            # loss, _, _ = validate(val_loader, model, criterion, args.print_freq, colorization=True)
            # test_losses.append(loss)

            save_checkpoint({
                'epoch': epoch + 1,
                'mode': args.mode,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'scheduler':scheduler.state_dict(),
                "cur_itrs": cur_itrs
            }, True, args.mode + '_' + args.dataset +'.pth')

            np.savez(args.mode + '_' + args.dataset +'.npz', train_losses=train_losses)
            # scheduler.step()
            time2 = time.time() #timekeeping
            print('Elapsed time for epoch:',time2 - time1,'s')
            print('ETA of completion:',(time2 - time1)*(args.total_itrs - cur_itrs - 1)/60,'minutes')
            print()
            epoch += 1
Example #12
def main(criterion):
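    """Training entry point that receives the loss criterion as an argument: build the
    model (plus a second SCHP model when opts.use_schp is set), optimizer and scheduler,
    and restore checkpoints before the training loop."""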
    # Setup random seed
    torch.manual_seed(opts.random_seed)
    np.random.seed(opts.random_seed)
    random.seed(opts.random_seed)

    # Setup dataloader
    train_dst, val_dst = get_dataset(opts)
    train_loader = data.DataLoader(train_dst,
                                   batch_size=opts.batch_size,
                                   shuffle=True,
                                   num_workers=2)
    val_loader = data.DataLoader(val_dst,
                                 batch_size=opts.val_batch_size,
                                 shuffle=False,
                                 num_workers=2)
    print("Dataset: %s, Train set: %d, Val set: %d" %
          (opts.dataset, len(train_dst), len(val_dst)))

    # Set up model
    pretrained_backbone = "ACE2P" not in opts.model
    model = network.model_map[opts.model](
        num_classes=opts.num_classes,
        output_stride=opts.output_stride,
        pretrained_backbone=pretrained_backbone,
        use_abn=opts.use_abn)
    if opts.use_schp:
        schp_model = network.model_map[opts.model](
            num_classes=opts.num_classes,
            output_stride=opts.output_stride,
            pretrained_backbone=pretrained_backbone,
            use_abn=opts.use_abn)
    if opts.separable_conv and 'plus' in opts.model:
        network.convert_to_separable_conv(model.classifier)
    utils.set_bn_momentum(model.backbone, momentum=0.01)

    # Set up metrics
    metrics = StreamSegMetrics(opts.num_classes)

    # Set up optimizer
    model_params = [
        {
            'params': model.backbone.parameters(),
            'lr': 0.01 * opts.lr
        },
        {
            'params': model.classifier.parameters(),
            'lr': opts.lr
        },
    ]
    optimizer = create_optimizer(opts, model_params=model_params)
    # optimizer = torch.optim.SGD(params=[
    #     {'params': model.backbone.parameters(), 'lr': 0.1 * opts.lr},
    #     {'params': model.classifier.parameters(), 'lr': opts.lr},
    # ], lr=opts.lr, momentum=0.9, weight_decay=opts.weight_decay)
    # optimizer = torch.optim.SGD(params=model.parameters(), lr=opts.lr, momentum=0.9, weight_decay=opts.weight_decay)
    # torch.optim.lr_scheduler.StepLR(optimizer, step_size=opts.lr_decay_step, gamma=opts.lr_decay_factor)
    if opts.lr_policy == 'poly':
        scheduler = utils.PolyLR(optimizer, opts.total_itrs, power=0.9)
    elif opts.lr_policy == 'step':
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                    step_size=opts.step_size,
                                                    gamma=0.1)

    def save_ckpt(path):
        """ save current model
        """
        torch.save(
            {
                "cur_epochs": cur_epochs,
                "cur_itrs": cur_itrs,
                "model_state": model.module.state_dict(),
                "optimizer_state": optimizer.state_dict(),
                "scheduler_state": scheduler.state_dict(),
                "best_score": best_score,
            }, path)
        print("Model saved as %s" % path)

    utils.mkdir('checkpoints')
    # Restore
    best_score = 0.0
    cur_itrs = 0
    cur_epochs = 0
    cycle_n = 0

    if opts.use_schp and opts.schp_ckpt is not None and os.path.isfile(
            opts.schp_ckpt):
        # TODO: there is a problem with this part.
        checkpoint = torch.load(opts.schp_ckpt,
                                map_location=torch.device('cpu'))
        schp_model.load_state_dict(checkpoint["model_state"])
        print("SCHP Model restored from %s" % opts.schp_ckpt)

    if opts.ckpt is not None and os.path.isfile(opts.ckpt):
        checkpoint = torch.load(opts.ckpt, map_location=torch.device('cpu'))
        model.load_state_dict(checkpoint["model_state"])
        model = nn.DataParallel(model)
        model.to(device)
        if opts.use_schp:
            schp_model = nn.DataParallel(schp_model)
            schp_model.to(device)
        if opts.continue_training:
            optimizer.load_state_dict(checkpoint["optimizer_state"])
            scheduler.load_state_dict(checkpoint["scheduler_state"])
            cur_epochs = checkpoint[
                "cur_epochs"] - 1  # to start from the last epoch for schp
            cur_itrs = checkpoint["cur_itrs"]
            best_score = checkpoint['best_score']
            print("Training state restored from %s" % opts.ckpt)
        print("Model restored from %s" % opts.ckpt)
        del checkpoint  # free memory
    else:
        print("[!] Retrain")
        model = nn.DataParallel(model)
        model.to(device)
        if opts.use_schp:
            schp_model = nn.DataParallel(schp_model)
            schp_model.to(device)

    # ==========   Train Loop   ==========#
    if opts.test_only:
        model.eval()
        val_score = validate(opts=opts,
                             model=model,
                             loader=val_loader,
                             device=device,
                             metrics=metrics)
        print(metrics.to_str(val_score))
        return
    interval_loss = 0
    while True:  # cur_itrs < opts.total_itrs:
        # =====  Train  =====
        criterion.start_log()
        model.train()
        cur_epochs += 1
        for (images, labels) in train_loader:
            cur_itrs += 1

            # images = images.to(device, dtype=torch.float32)
            # labels = labels.to(device, dtype=torch.long)
            images, labels = get_input(images, labels, opts, device, cur_itrs)
            if opts.use_mixup:
                images, main_images = images
            else:
                main_images = None
            images = images[:, [2, 1, 0]]  # reverse channel order (RGB <-> BGR) for the backbone
            optimizer.zero_grad()
            outputs = model(images)

            if opts.use_schp:
                # Online Self Correction Cycle with Label Refinement
                soft_labels = []
                if cycle_n >= 1:
                    with torch.no_grad():
                        if opts.use_mixup:
                            soft_preds = [
                                schp_model(main_images[0]),
                                schp_model(main_images[1])
                            ]
                            soft_edges = [None, None]
                        else:
                            soft_preds = schp_model(images)
                            soft_edges = None
                        if 'ACE2P' in opts.model:
                            soft_edges = soft_preds[1][-1]
                            soft_preds = soft_preds[0][-1]
                            # soft_parsing = []
                            # soft_edge = []
                            # for soft_pred in soft_preds:
                            #     soft_parsing.append(soft_pred[0][-1])
                            #     soft_edge.append(soft_pred[1][-1])
                            # soft_preds = torch.cat(soft_parsing, dim=0)
                            # soft_edges = torch.cat(soft_edge, dim=0)
                else:
                    if opts.use_mixup:
                        soft_preds = [None, None]
                        soft_edges = [None, None]
                    else:
                        soft_preds = None
                        soft_edges = None
                soft_labels.append(soft_preds)
                soft_labels.append(soft_edges)
                labels = [labels, soft_labels]

            # loss = criterion(outputs, labels)
            loss = calc_loss(criterion, outputs, labels, opts, cycle_n)
            loss.backward()
            optimizer.step()

            criterion.batch_step(len(images))
            np_loss = loss.detach().cpu().numpy()
            interval_loss += np_loss
            sub_loss_text = ''
            for sub_loss, sub_prop in zip(criterion.losses, criterion.loss):
                if sub_prop['weight'] > 0:
                    sub_loss_text += f", {sub_prop['type']}: {sub_loss.item():.4f}"
            print(
                f"\rEpoch {cur_epochs}, Itrs {cur_itrs}/{opts.total_itrs}, Loss={np_loss:.4f}{sub_loss_text}",
                end='')

            if (cur_itrs) % 10 == 0:
                interval_loss = interval_loss / 10
                print(
                    f"\rEpoch {cur_epochs}, Itrs {cur_itrs}/{opts.total_itrs}, Loss={interval_loss:.4f} {criterion.display_loss().replace('][',', ')}"
                )
                interval_loss = 0.0
                torch.cuda.empty_cache()

            if (cur_itrs) % opts.save_interval == 0 and (
                    cur_itrs) % opts.val_interval != 0:
                save_ckpt('checkpoints/latest_%s_%s_os%d.pth' %
                          (opts.model, opts.dataset, opts.output_stride))

            if (cur_itrs) % opts.val_interval == 0:
                save_ckpt('checkpoints/latest_%s_%s_os%d.pth' %
                          (opts.model, opts.dataset, opts.output_stride))
                print("validation...")
                model.eval()
                val_score = validate(opts=opts,
                                     model=model,
                                     loader=val_loader,
                                     device=device,
                                     metrics=metrics)
                print(metrics.to_str(val_score))
                if val_score['Mean IoU'] > best_score:  # save best model
                    best_score = val_score['Mean IoU']
                    save_ckpt('checkpoints/best_%s_%s_os%d.pth' %
                              (opts.model, opts.dataset, opts.output_stride))
                    # save_ckpt('/content/drive/MyDrive/best_%s_%s_os%d.pth' %
                    #           (opts.model, opts.dataset, opts.output_stride))
                model.train()
            scheduler.step()

            if cur_itrs >= opts.total_itrs:
                criterion.end_log(len(train_loader))
                return

        # Self Correction Cycle with Model Aggregation
        if opts.use_schp:
            if (cur_epochs + 1) >= opts.schp_start and (
                    cur_epochs + 1 - opts.schp_start) % opts.cycle_epochs == 0:
                print(f'\nSelf-correction cycle number {cycle_n}')

                schp.moving_average(schp_model, model, 1.0 / (cycle_n + 1))
                cycle_n += 1
                schp.bn_re_estimate(train_loader, schp_model)
                schp.save_schp_checkpoint(
                    {
                        'state_dict': schp_model.state_dict(),
                        'cycle_n': cycle_n,
                    },
                    False,
                    "checkpoints",
                    filename=f'schp_{opts.model}_{opts.dataset}_cycle{cycle_n}_checkpoint.pth'
                )
                # schp.save_schp_checkpoint({
                #     'state_dict': schp_model.state_dict(),
                #     'cycle_n': cycle_n,
                # }, False, '/content/drive/MyDrive/', filename=f'schp_{opts.model}_{opts.dataset}_checkpoint.pth')
        torch.cuda.empty_cache()
        criterion.end_log(len(train_loader))
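schp.moving_average and schp.bn_re_estimate come from the project's SCHP module, which is not shown here. A minimal sketch of the weight-aggregation step, assuming the (target, source, beta) signature used above with beta = 1 / (cycle_n + 1):

import torch


def moving_average(target_model, source_model, beta):
    """Blend the online model into the SCHP model: target = (1 - beta) * target + beta * source."""
    with torch.no_grad():
        for p_tgt, p_src in zip(target_model.parameters(), source_model.parameters()):
            p_tgt.mul_(1.0 - beta).add_(p_src, alpha=beta)
    # BatchNorm running statistics are not blended here; they are re-estimated
    # afterwards by a pass over the training data (bn_re_estimate above).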
Example #13
def main():
    opts = get_argparser().parse_args()
    if opts.dataset.lower() == 'voc':
        opts.num_classes = 21
        ignore_index = 255
    elif opts.dataset.lower() == 'cityscapes':
        opts.num_classes = 19
        ignore_index = 255
    elif opts.dataset.lower() == 'ade20k':
        opts.num_classes = 150
        ignore_index = -1
    elif opts.dataset.lower() == 'lvis':
        opts.num_classes = 1284
        ignore_index = -1
    elif opts.dataset.lower() == 'coco':
        opts.num_classes = 182
        ignore_index = 255
    if not opts.reduce_dim:
        opts.num_channels = opts.num_classes
    if not opts.test_only:
        writer = SummaryWriter('summary/' + opts.vis_env)
    # Setup visualization
    vis = Visualizer(port=opts.vis_port,
                     env=opts.vis_env) if opts.enable_vis else None
    if vis is not None:  # display options
        vis.vis_table("Options", vars(opts))

    os.environ['CUDA_VISIBLE_DEVICES'] = opts.gpu_id
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print("Device: %s" % device)

    # Setup random seed
    torch.manual_seed(opts.random_seed)
    np.random.seed(opts.random_seed)
    random.seed(opts.random_seed)

    # Setup dataloader
    if opts.dataset == 'voc' and not opts.crop_val:
        opts.val_batch_size = 1

    train_dst, val_dst = get_dataset(opts)
    train_loader = data.DataLoader(train_dst,
                                   batch_size=opts.batch_size,
                                   shuffle=True,
                                   num_workers=2)
    val_loader = data.DataLoader(val_dst,
                                 batch_size=opts.val_batch_size,
                                 shuffle=False,
                                 num_workers=2)
    print("Dataset: %s, Train set: %d, Val set: %d" %
          (opts.dataset, len(train_dst), len(val_dst)))
    epoch_interval = len(train_dst) // opts.batch_size
    opts.val_interval = min(epoch_interval, 5000)
    print("Evaluation after %d iterations" % opts.val_interval)

    # Set up model
    model_map = {
        #'deeplabv3_resnet50': network.deeplabv3_resnet50,
        'deeplabv3plus_resnet50': network.deeplabv3plus_resnet50,
        #'deeplabv3_resnet101': network.deeplabv3_resnet101,
        'deeplabv3plus_resnet101': network.deeplabv3plus_resnet101,
        #'deeplabv3_mobilenet': network.deeplabv3_mobilenet,
        'deeplabv3plus_mobilenet': network.deeplabv3plus_mobilenet
    }
    if (opts.reduce_dim):
        num_classes_input = [opts.num_channels, opts.num_classes]
    else:
        num_classes_input = [opts.num_classes]
    model = model_map[opts.model](num_classes=num_classes_input,
                                  output_stride=opts.output_stride,
                                  reduce_dim=opts.reduce_dim)
    if opts.separable_conv and 'plus' in opts.model:
        network.convert_to_separable_conv(model.classifier)
    utils.set_bn_momentum(model.backbone, momentum=0.01)

    # Set up metrics
    metrics = StreamSegMetrics(opts.num_classes)
    if opts.reduce_dim:
        emb_layer = ['embedding.weight']
        params_classifier = [
            param for name, param in model.classifier.named_parameters()
            if name not in emb_layer
        ]
        params_embedding = [
            param for name, param in model.classifier.named_parameters()
            if name in emb_layer
        ]
        if opts.freeze_backbone:
            for param in model.backbone.parameters():
                param.requires_grad = False
            optimizer = torch.optim.SGD(
                params=[
                    #@{'params': model.backbone.parameters(),'lr':0.1*opts.lr},
                    {
                        'params': params_classifier,
                        'lr': opts.lr
                    },
                    {
                        'params': params_embedding,
                        'lr': opts.lr,
                        'momentum': 0.95
                    },
                ],
                lr=opts.lr,
                momentum=0.9,
                weight_decay=opts.weight_decay)
        else:
            optimizer = torch.optim.SGD(params=[
                {
                    'params': model.backbone.parameters(),
                    'lr': 0.1 * opts.lr
                },
                {
                    'params': params_classifier,
                    'lr': opts.lr
                },
                {
                    'params': params_embedding,
                    'lr': opts.lr
                },
            ],
                                        lr=opts.lr,
                                        momentum=0.9,
                                        weight_decay=opts.weight_decay)
    # Set up optimizer
    else:
        optimizer = torch.optim.SGD(params=[
            {
                'params': model.backbone.parameters(),
                'lr': 0.1 * opts.lr
            },
            {
                'params': model.classifier.parameters(),
                'lr': opts.lr
            },
        ],
                                    lr=opts.lr,
                                    momentum=0.9,
                                    weight_decay=opts.weight_decay)

    if opts.lr_policy == 'poly':
        scheduler = utils.PolyLR(optimizer, opts.total_itrs, power=0.9)
    elif opts.lr_policy == 'step':
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                    step_size=opts.step_size,
                                                    gamma=0.1)
    elif opts.lr_policy == 'multi_poly':
        scheduler = utils.MultiPolyLR(optimizer,
                                      opts.total_itrs,
                                      power=[0.9, 0.9, 0.95])

    # Set up criterion
    if (opts.reduce_dim):
        opts.loss_type = 'nn_cross_entropy'
    else:
        opts.loss_type = 'cross_entropy'

    if opts.loss_type == 'cross_entropy':
        criterion = nn.CrossEntropyLoss(ignore_index=ignore_index,
                                        reduction='mean')
    elif opts.loss_type == 'nn_cross_entropy':
        criterion = utils.NNCrossEntropy(ignore_index=ignore_index,
                                         reduction='mean',
                                         num_neighbours=opts.num_neighbours,
                                         temp=opts.temp,
                                         dataset=opts.dataset)

    def save_ckpt(path):
        """ save current model
        """
        torch.save(
            {
                "cur_itrs": cur_itrs,
                "model_state": model.module.state_dict(),
                "optimizer_state": optimizer.state_dict(),
                "scheduler_state": scheduler.state_dict(),
                "best_score": best_score,
            }, path)
        print("Model saved as %s" % path)

    utils.mkdir(opts.checkpoint_dir)
    # Restore
    best_score = 0.0
    cur_itrs = 0
    cur_epochs = 0
    if opts.ckpt is not None and os.path.isfile(opts.ckpt):
        checkpoint = torch.load(opts.ckpt, map_location=torch.device('cpu'))
        model.load_state_dict(checkpoint["model_state"])
        model = nn.DataParallel(model)
        model.to(device)
        increase_iters = True
        if opts.continue_training:
            optimizer.load_state_dict(checkpoint["optimizer_state"])
            scheduler.load_state_dict(checkpoint["scheduler_state"])
            cur_itrs = checkpoint["cur_itrs"]
            best_score = checkpoint['best_score']
            print("scheduler state dict :", scheduler.state_dict())
            print("Training state restored from %s" % opts.ckpt)
        print("Model restored from %s" % opts.ckpt)
        del checkpoint  # free memory
    else:
        print("[!] Retrain")
        model = nn.DataParallel(model)
        model.to(device)

    vis_sample_id = np.random.randint(
        0, len(val_loader), opts.vis_num_samples,
        np.int32) if opts.enable_vis else None  # sample idxs for visualization
    denorm = utils.Denormalize(mean=[0.485, 0.456, 0.406],
                               std=[0.229, 0.224,
                                    0.225])  # denormalization for ori images

    if opts.test_only:
        model.eval()
        val_score, ret_samples = validate(opts=opts,
                                          model=model,
                                          loader=val_loader,
                                          device=device,
                                          metrics=metrics,
                                          ret_samples_ids=vis_sample_id)
        print(metrics.to_str(val_score))
        return

    interval_loss = 0

    writer.add_text('lr', str(opts.lr))
    writer.add_text('batch_size', str(opts.batch_size))
    writer.add_text('reduce_dim', str(opts.reduce_dim))
    writer.add_text('checkpoint_dir', opts.checkpoint_dir)
    writer.add_text('dataset', opts.dataset)
    writer.add_text('num_channels', str(opts.num_channels))
    writer.add_text('num_neighbours', str(opts.num_neighbours))
    writer.add_text('loss_type', opts.loss_type)
    writer.add_text('lr_policy', opts.lr_policy)
    writer.add_text('temp', str(opts.temp))
    writer.add_text('crop_size', str(opts.crop_size))
    writer.add_text('model', opts.model)
    accumulation_steps = 1
    writer.add_text('accumulation_steps', str(accumulation_steps))
    j = 0
    updateflag = False
    while True:
        # =====  Train  =====
        model.train()
        cur_epochs += 1
        for (images, labels) in train_loader:
            cur_itrs += 1
            images = images.to(device, dtype=torch.float32)
            labels = labels.to(device, dtype=torch.long)
            if (opts.dataset == 'ade20k' or opts.dataset == 'lvis'):
                labels = labels - 1

            optimizer.zero_grad()
            if (opts.reduce_dim):
                outputs, class_emb = model(images)
                loss = criterion(outputs, labels, class_emb)
            else:
                outputs = model(images)
                loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            model.zero_grad()
            j = j + 1
            np_loss = loss.detach().cpu().numpy()
            interval_loss += np_loss

            if vis is not None:
                vis.vis_scalar('Loss', cur_itrs, np_loss)
                vis.vis_scalar('LR', cur_itrs,
                               scheduler.state_dict()['_last_lr'][0])
            torch.cuda.empty_cache()
            del images, labels, outputs, loss
            if (opts.reduce_dim):
                del class_emb
            gc.collect()
            if (cur_itrs) % 50 == 0:
                interval_loss = interval_loss / 50
                print("Epoch %d, Itrs %d/%d, Loss=%f" %
                      (cur_epochs, cur_itrs, opts.total_itrs, interval_loss))
                writer.add_scalar('Loss', interval_loss, cur_itrs)
                writer.add_scalar('lr',
                                  scheduler.state_dict()['_last_lr'][0],
                                  cur_itrs)
            if cur_itrs % opts.val_interval == 0:
                save_ckpt(opts.checkpoint_dir + '/latest_%d.pth' % cur_itrs)
                print("validation...")
                model.eval()
                val_score, ret_samples = validate(
                    opts=opts,
                    model=model,
                    loader=val_loader,
                    device=device,
                    metrics=metrics,
                    ret_samples_ids=vis_sample_id)
                print(metrics.to_str(val_score))
                if val_score['Mean IoU'] > best_score:  # save best model
                    best_score = val_score['Mean IoU']
                    save_ckpt(opts.checkpoint_dir + '/best_%s_%s_os%d.pth' %
                              (opts.model, opts.dataset, opts.output_stride))

                writer.add_scalar('[Val] Overall Acc',
                                  val_score['Overall Acc'], cur_itrs)
                writer.add_scalar('[Val] Mean IoU', val_score['Mean IoU'],
                                  cur_itrs)
                writer.add_scalar('[Val] Mean Acc', val_score['Mean Acc'],
                                  cur_itrs)
                writer.add_scalar('[Val] Freq Acc', val_score['FreqW Acc'],
                                  cur_itrs)

                if vis is not None:  # visualize validation score and samples
                    vis.vis_scalar("[Val] Overall Acc", cur_itrs,
                                   val_score['Overall Acc'])
                    vis.vis_scalar("[Val] Mean IoU", cur_itrs,
                                   val_score['Mean IoU'])
                    vis.vis_table("[Val] Class IoU", val_score['Class IoU'])

                    for k, (img, target, lbl) in enumerate(ret_samples):
                        img = (denorm(img) * 255).astype(np.uint8)
                        if opts.dataset.lower() == 'coco':
                            target = np.asarray(
                                train_dst._colorize_mask(target).convert(
                                    'RGB')).transpose(2, 0, 1).astype(np.uint8)
                            lbl = np.asarray(
                                train_dst._colorize_mask(lbl).convert(
                                    'RGB')).transpose(2, 0, 1).astype(np.uint8)
                        else:
                            target = train_dst.decode_target(target).transpose(
                                2, 0, 1).astype(np.uint8)
                            lbl = train_dst.decode_target(lbl).transpose(
                                2, 0, 1).astype(np.uint8)
                        concat_img = np.concatenate(
                            (img, target, lbl), axis=2)  # concat along width
                        vis.vis_image('Sample %d' % k, concat_img)
                model.train()
            scheduler.step()
            if cur_itrs >= opts.total_itrs:
                writer.close()
                return
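utils.PolyLR, used by the 'poly' learning-rate policy in these examples, is the standard polynomial decay from the DeepLab training recipe. A minimal sketch (an approximation of that utility, not its exact code):

from torch.optim.lr_scheduler import _LRScheduler


class PolyLR(_LRScheduler):
    """lr = base_lr * (1 - iter / max_iters) ** power, stepped once per iteration."""

    def __init__(self, optimizer, max_iters, power=0.9, min_lr=1e-6, last_epoch=-1):
        self.max_iters = max_iters
        self.power = power
        self.min_lr = min_lr
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        factor = (1 - self.last_epoch / self.max_iters) ** self.power
        return [max(base_lr * factor, self.min_lr) for base_lr in self.base_lrs]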
Example #14
def main():
    opts = get_argparser().parse_args()

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print("Device: %s" % device)

    # Setup random seed
    torch.manual_seed(opts.random_seed)
    np.random.seed(opts.random_seed)
    random.seed(opts.random_seed)
    img_h = 512
    img_w = 512

    torch.cuda.empty_cache()
    train_data, test_data = get_data_dcm(img_h=img_h, img_w=img_w, iscrop=True)
    train_loader = data.DataLoader(train_data,
                                   batch_size=opts.batch_size,
                                   shuffle=True,
                                   num_workers=0)
    val_loader = data.DataLoader(test_data,
                                 batch_size=opts.val_batch_size,
                                 shuffle=False,
                                 num_workers=0)

    model_map = {
        'deeplabv3_resnet50': network.deeplabv3_resnet50,
        'deeplabv3plus_resnet50': network.deeplabv3plus_resnet50,
        'deeplabv3_resnet101': network.deeplabv3_resnet101,
        'deeplabv3plus_resnet101': network.deeplabv3plus_resnet101,
        'deeplabv3_mobilenet': network.deeplabv3_mobilenet,
        'deeplabv3plus_mobilenet': network.deeplabv3plus_mobilenet
    }
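    # NOTE: base_model is assumed to be defined at module scope in the original
    # source (e.g. base_model = 'deeplab'); it is not set anywhere in this function.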
    if base_model != 'unet':
        opts.num_classes = 3
        model = model_map[opts.model](num_classes=opts.num_classes,
                                      output_stride=opts.output_stride)
        if opts.separable_conv and 'plus' in opts.model:
            network.convert_to_separable_conv(model.classifier)
        utils.set_bn_momentum(model.backbone, momentum=0.01)

        # Set up optimizer
        optimizer = torch.optim.SGD(params=[
            {
                'params': model.backbone.parameters(),
                'lr': 0.1 * opts.lr
            },
            {
                'params': model.classifier.parameters(),
                'lr': opts.lr
            },
        ],
                                    lr=opts.lr,
                                    momentum=0.9,
                                    weight_decay=opts.weight_decay)
    else:
        opts.num_classes = 3
        model = UNet(n_channels=3, n_classes=3, bilinear=True)

        optimizer = torch.optim.SGD(model.parameters(),
                                    lr=opts.lr,
                                    momentum=0.9,
                                    weight_decay=opts.weight_decay)

    scheduler = utils.PolyLR(optimizer, opts.total_itrs, power=0.9)

    criterion_bce = nn.BCELoss(reduction='mean')
    criterion_dice = MulticlassDiceLoss()

    def save_ckpt(path):
        """ save current model
        """
        torch.save(
            {
                "cur_itrs": cur_itrs,
                "model_state": model.module.state_dict(),
                "optimizer_state": optimizer.state_dict(),
                "scheduler_state": scheduler.state_dict(),
                "best_score": best_score,
            }, path)
        print("Model saved as %s" % path)

    utils.mkdir('checkpoints')
    # Restore
    best_score = 0.0
    cur_itrs = 0
    cur_epochs = 0

    if opts.ckpt is not None and os.path.isfile(opts.ckpt):
        checkpoint = torch.load(opts.ckpt, map_location=torch.device('cpu'))
        model.load_state_dict(checkpoint["model_state"])
        model = nn.DataParallel(model)
        model.to(device)
        if opts.continue_training:
            optimizer.load_state_dict(checkpoint["optimizer_state"])
            scheduler.load_state_dict(checkpoint["scheduler_state"])
            cur_itrs = checkpoint["cur_itrs"]
            best_score = checkpoint['best_score']
            print("Training state restored from %s" % opts.ckpt)
        print("Model restored from %s" % opts.ckpt)
        del checkpoint  # free memory
    else:
        print("[!] Retrain")
        model = nn.DataParallel(model)
        model.to(device)

    if opts.test_only:
        # model.load_state_dict()
        model.eval()
        dice_bl, dice_sb, dice_st, acc = validate2(
            model=model,
            loader=val_loader,
            device=device,
            itrs=cur_itrs,
            lr=scheduler.get_lr()[-1],
            criterion_dice=criterion_dice)
        # save_ckpt("./checkpoints/CT_" + opts.model + "_" + str(round(dice, 3)) + "__" + str(cur_itrs) + ".pkl")
        print("dice值:", dice_bl)
        return
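MulticlassDiceLoss and validate2 are project-specific and not included in this snippet. A minimal multi-class soft Dice loss in the same spirit (an assumption about the interface, with logits of shape (N, C, H, W) and integer label maps):

import torch
import torch.nn as nn
import torch.nn.functional as F


class MulticlassDiceLoss(nn.Module):
    """Soft Dice loss averaged over classes."""

    def __init__(self, smooth=1.0):
        super().__init__()
        self.smooth = smooth

    def forward(self, logits, target):
        num_classes = logits.shape[1]
        probs = torch.softmax(logits, dim=1)
        # one-hot encode the (N, H, W) label map to (N, C, H, W)
        target_onehot = F.one_hot(target.long(), num_classes).permute(0, 3, 1, 2).float()
        dims = (0, 2, 3)
        intersection = (probs * target_onehot).sum(dims)
        cardinality = probs.sum(dims) + target_onehot.sum(dims)
        dice = (2.0 * intersection + self.smooth) / (cardinality + self.smooth)
        return 1.0 - dice.mean()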
Example #15
def main():
    opts = parser.parse_args()

    vis = Visualizer(port=opts.vis_port,
                     env=opts.vis_env) if opts.enable_vis else None
    if vis is not None:  # display options
        vis.vis_table("Options", vars(opts))

    os.environ['CUDA_VISIBLE_DEVICES'] = opts.gpu_id
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    print("Device: %s" % device)

    # Setup random seed
    torch.manual_seed(opts.random_seed)
    np.random.seed(opts.random_seed)
    random.seed(opts.random_seed)

    # Set up model
    model_map = {
        'deeplabv3_resnet50': network.deeplabv3_resnet50,
        'deeplabv3plus_resnet50': network.deeplabv3plus_resnet50,
        'deeplabv3_resnet101': network.deeplabv3_resnet101,
        'deeplabv3plus_resnet101': network.deeplabv3plus_resnet101,
        'deeplabv3_mobilenet': network.deeplabv3_mobilenet,
        'deeplabv3plus_mobilenet': network.deeplabv3plus_mobilenet
    }

    model = model_map[opts.model](num_classes=opts.num_classes, output_stride=opts.output_stride)
    if opts.separable_conv and 'plus' in opts.model:
        network.convert_to_separable_conv(model.classifier)
    utils.set_bn_momentum(model.backbone, momentum=0.01)

    # Set up metrics
    metrics = StreamSegMetrics(opts.num_classes)

    # Set up optimizer
    optimizer = torch.optim.SGD(params=[
        {'params': model.backbone.parameters(), 'lr': 0.1 * opts.lr},
        {'params': model.classifier.parameters(), 'lr': opts.lr},
    ], lr=opts.lr, momentum=0.9, weight_decay=opts.weight_decay)
    # optimizer = torch.optim.SGD(params=model.parameters(), lr=opts.lr, momentum=0.9, weight_decay=opts.weight_decay)
    # torch.optim.lr_scheduler.StepLR(optimizer, step_size=opts.lr_decay_step, gamma=opts.lr_decay_factor)
    if opts.lr_policy == 'poly':
        scheduler = utils.PolyLR(optimizer, opts.total_itrs, power=0.9)
    elif opts.lr_policy == 'step':
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=opts.step_size, gamma=0.1)
    else:
        scheduler = None
        print("Unsupported lr_policy: please choose 'poly' or 'step'.")

    utils.mkdir('checkpoints')
    mytrainer = trainer(model, optimizer, scheduler, device, cfg=opts)
    # ==========   Train Loop   ==========#
    #loss_list = ['bound_dice', 'v3_bound_dice']
    #loss_list = ['v5_bound_dice', 'v4_bound_dice']
    loss_list = ['focal']
    if opts.test_only:
        loss_i = 'v3'
        ckpt = os.path.join("checkpoints", loss_i, "latest_deeplabv3plus_mobilenet_coco_epoch01.pth")
        mytrainer.validate(ckpt, loss_i)
    else:
        for loss_i in loss_list:
            mytrainer.train(loss_i)
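The 'focal' entry in loss_list presumably selects a focal-loss criterion inside the trainer, whose definition is not shown. A minimal per-pixel focal loss sketch (an assumption, not the trainer's own implementation):

import torch
import torch.nn as nn
import torch.nn.functional as F


class FocalLoss(nn.Module):
    """Focal loss on top of per-pixel cross-entropy: alpha * (1 - p_t) ** gamma * CE."""

    def __init__(self, alpha=1.0, gamma=2.0, ignore_index=255):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.ignore_index = ignore_index

    def forward(self, logits, target):
        ce = F.cross_entropy(logits, target,
                             ignore_index=self.ignore_index, reduction='none')
        pt = torch.exp(-ce)  # probability assigned to the true class
        return (self.alpha * (1 - pt) ** self.gamma * ce).mean()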