Example #1
0
def validate(args):
    """Evaluate a segmentation checkpoint over one dataset split.

    Builds the dataloader for ``args.dataset``/``args.split``, restores the
    model weights from ``args.model_path``, accumulates segmentation metrics
    over the whole split and prints the aggregate scores plus per-class IoU.

    Args:
        args: namespace with ``dataset``, ``split``, ``img_rows``,
            ``img_cols``, ``batch_size`` and ``model_path`` attributes.
    """
    # Setup Dataloader
    data_loader = get_loader(args.dataset)
    data_path = get_data_path(args.dataset)
    loader = data_loader(data_path, split=args.split, is_transform=True, img_size=(args.img_rows, args.img_cols))
    n_classes = loader.n_classes
    valloader = data.DataLoader(loader, batch_size=args.batch_size, num_workers=4)
    running_metrics = runningScore(n_classes)

    # Setup Model: the architecture name is encoded as the checkpoint
    # filename prefix up to the first underscore.
    model = get_model(args.model_path[:args.model_path.find('_')], n_classes)
    state = convert_state_dict(torch.load(args.model_path)['model_state'])
    model.load_state_dict(state)
    # Move to GPU once, before the loop (previously re-issued every batch).
    model.cuda()
    model.eval()

    # torch.no_grad() replaces the removed `Variable(volatile=True)` API:
    # it disables autograd tracking for the whole inference loop.
    with torch.no_grad():
        for i, (images, labels) in tqdm(enumerate(valloader)):
            images = images.cuda()
            labels = labels.cuda()

            outputs = model(images)
            # argmax over the class dimension -> predicted label map
            pred = outputs.data.max(1)[1].cpu().numpy()
            gt = labels.data.cpu().numpy()

            running_metrics.update(gt, pred)

    score, class_iou = running_metrics.get_scores()

    for k, v in score.items():
        print(k, v)

    for i in range(n_classes):
        print(i, class_iou[i])
Example #2
0
def validate(cfg, model_nontree, model_tree, loss_fn, device, root):
    """Run one validation pass comparing a non-tree and a tree model.

    Both models are evaluated on the same validation split with the same
    ``loss_fn``; per-model loss meters and confusion-matrix metrics are
    accumulated, then reset.  Nothing is returned or logged here — see the
    "VISUALISE METRICS AND LOSSES HERE" placeholder below.

    Args:
        cfg: nested config dict (``data``/``training`` sections).
        model_nontree: baseline model to validate.
        model_tree: hierarchy-aware model to validate.
        loss_fn: loss callable; when ``cfg['training']['use_tree_loss']`` it
            also takes ``root`` and ``use_hierarchy`` keyword arguments and
            (apparently) returns an indexable structure — TODO confirm its
            exact return shape against the loss implementation.
        device: torch device the batches are moved to.
        root: hierarchy root node passed through to the tree loss.
    """

    # Overall loss meter plus (optionally) one meter per hierarchy level.
    val_loss_meter_nontree = averageMeter()
    if cfg['training']['use_hierarchy']:
        val_loss_meter_level0_nontree = averageMeter()
        val_loss_meter_level1_nontree = averageMeter()
        val_loss_meter_level2_nontree = averageMeter()
        val_loss_meter_level3_nontree = averageMeter()

    val_loss_meter_tree = averageMeter()
    if cfg['training']['use_hierarchy']:
        val_loss_meter_level0_tree = averageMeter()
        val_loss_meter_level1_tree = averageMeter()
        val_loss_meter_level2_tree = averageMeter()
        val_loss_meter_level3_tree = averageMeter()

    # CUDA availability is used as a proxy for "running on the server",
    # which selects a different dataset path.
    if torch.cuda.is_available():
        data_path = cfg['data']['server_path']
    else:
        data_path = cfg['data']['path']

    data_loader = get_loader(cfg['data']['dataset'])
    augmentations = cfg['training'].get('augmentations', None)
    data_aug = get_composed_augmentations(augmentations)

    v_loader = data_loader(data_path,
                           is_transform=True,
                           split=cfg['data']['val_split'],
                           img_size=(cfg['data']['img_rows'],
                                     cfg['data']['img_cols']),
                           augmentations=data_aug)

    n_classes = v_loader.n_classes
    valloader = data.DataLoader(v_loader,
                                batch_size=cfg['training']['batch_size'],
                                num_workers=cfg['training']['n_workers'])

    # Setup Metrics: one confusion-matrix accumulator per model.
    running_metrics_val_nontree = runningScore(n_classes)
    running_metrics_val_tree = runningScore(n_classes)

    model_nontree.eval()
    model_tree.eval()
    with torch.no_grad():
        print("validation loop")
        for i_val, (images_val, labels_val) in tqdm(enumerate(valloader)):
            images_val = images_val.to(device)
            labels_val = labels_val.to(device)

            outputs_nontree = model_nontree(images_val)
            outputs_tree = model_tree(images_val)

            # Both models are scored with the same loss configuration.
            if cfg['training']['use_tree_loss']:
                val_loss_nontree = loss_fn(
                    input=outputs_nontree,
                    target=labels_val,
                    root=root,
                    use_hierarchy=cfg['training']['use_hierarchy'])
            else:
                val_loss_nontree = loss_fn(input=outputs_nontree,
                                           target=labels_val)

            if cfg['training']['use_tree_loss']:
                val_loss_tree = loss_fn(
                    input=outputs_tree,
                    target=labels_val,
                    root=root,
                    use_hierarchy=cfg['training']['use_hierarchy'])
            else:
                val_loss_tree = loss_fn(input=outputs_tree, target=labels_val)

            # Using standard max prob based classification
            pred_nontree = outputs_nontree.data.max(1)[1].cpu().numpy()
            pred_tree = outputs_tree.data.max(1)[1].cpu().numpy()

            gt = labels_val.data.cpu().numpy()
            running_metrics_val_nontree.update(
                gt, pred_nontree)  # updates confusion matrix
            running_metrics_val_tree.update(gt, pred_tree)

            # NOTE(review): asymmetric meter updates — the nontree meter
            # records per-level loss [1][0] while the tree meter records
            # [0].item() (presumably the overall loss). Confirm this
            # difference is intentional and not a copy-paste slip.
            if cfg['training']['use_tree_loss']:
                val_loss_meter_nontree.update(
                    val_loss_nontree[1][0])  # take the 1st level
            else:
                val_loss_meter_nontree.update(val_loss_nontree.item())

            if cfg['training']['use_tree_loss']:
                val_loss_meter_tree.update(val_loss_tree[0].item())
            else:
                val_loss_meter_tree.update(val_loss_tree.item())

            # Per-level meters; assumes loss_fn returned (total, levels)
            # with at least 4 levels when use_hierarchy is set — TODO confirm.
            if cfg['training']['use_hierarchy']:
                val_loss_meter_level0_nontree.update(val_loss_nontree[1][0])
                val_loss_meter_level1_nontree.update(val_loss_nontree[1][1])
                val_loss_meter_level2_nontree.update(val_loss_nontree[1][2])
                val_loss_meter_level3_nontree.update(val_loss_nontree[1][3])

            if cfg['training']['use_hierarchy']:
                val_loss_meter_level0_tree.update(val_loss_tree[1][0])
                val_loss_meter_level1_tree.update(val_loss_tree[1][1])
                val_loss_meter_level2_tree.update(val_loss_tree[1][2])
                val_loss_meter_level3_tree.update(val_loss_tree[1][3])

            # NOTE(review): stops after only 2 batches — looks like a
            # debugging shortcut; remove to validate the full split.
            if i_val == 1:
                break

        score_nontree, class_iou_nontree = running_metrics_val_nontree.get_scores(
        )
        score_tree, class_iou_tree = running_metrics_val_tree.get_scores()

        ### VISUALISE METRICS AND LOSSES HERE

        # Reset all accumulators; the computed scores above are otherwise
        # unused, so this function currently has no observable output.
        val_loss_meter_nontree.reset()
        running_metrics_val_nontree.reset()
        val_loss_meter_tree.reset()
        running_metrics_val_tree.reset()
        if cfg['training']['use_hierarchy']:
            val_loss_meter_level0_nontree.reset()
            val_loss_meter_level1_nontree.reset()
            val_loss_meter_level2_nontree.reset()
            val_loss_meter_level3_nontree.reset()

        if cfg['training']['use_hierarchy']:
            val_loss_meter_level0_tree.reset()
            val_loss_meter_level1_tree.reset()
            val_loss_meter_level2_tree.reset()
            val_loss_meter_level3_tree.reset()
# Training dataloader for the TuSimple lane dataset.
# NOTE(review): num_workers is set from TRAIN_BATCH (the batch-size
# constant) — confirm the worker count is really meant to track batch size.
trainloader = torch.utils.data.DataLoader(train_dataset,
                                          batch_size=TRAIN_BATCH,
                                          shuffle=True,
                                          num_workers=TRAIN_BATCH,
                                          pin_memory=True)

val_dataset = tusimpleLoader('/mnt/data/tejus/train_set/',
                             split="val",
                             augmentations=None)
# NOTE(review): shuffle=True on a validation loader is unusual (metrics do
# not depend on order) — confirm this is intentional.
valloader = torch.utils.data.DataLoader(val_dataset,
                                        batch_size=VAL_BATCH,
                                        shuffle=True,
                                        num_workers=VAL_BATCH,
                                        pin_memory=True)

# Metrics and bookkeeping: binary segmentation (2 classes).
running_metrics_val = runningScore(2)
best_val_loss = math.inf   # lower is better; start at +inf
val_loss = 0
ctr = 0
best_iou = -100            # sentinel below any reachable IoU
val_loss_meter = averageMeter()
time_meter = averageMeter()
for EPOCHS in range(50):

    # Training

    net.train()
    running_loss = 0
    for i, data in enumerate(trainloader):
        start_ts = time.time()
Example #4
0
def train(cfg, writer, logger):
    """Train a two-stream change-detection model end to end.

    Builds train/val dataloaders from ``cfg.data``, a 2-class model on the
    GPU(s) in ``cfg.gpu``, the optimizer/scheduler/loss from ``cfg.train``,
    optionally resumes from a checkpoint, then runs an iteration-based
    training loop with periodic logging, validation and best/last model
    checkpointing.

    Args:
        cfg: attribute-style config with ``data``, ``train``, ``model`` and
            ``gpu`` sections.
        writer: TensorBoard ``SummaryWriter`` (also provides the log dir
            used for checkpoints).
        logger: standard logger for progress messages.
    """

    # Setup random seeds to a determinated value for reproduction
    # seed = 1337
    # torch.manual_seed(seed)
    # torch.cuda.manual_seed(seed)
    # np.random.seed(seed)
    # random.seed(seed)
    # np.random.default_rng(seed)

    # Setup Augmentations
    augmentations = cfg.train.augment
    logger.info(f'using augments: {augmentations}')
    data_aug = get_composed_augmentations(augmentations)

    # Setup Dataloader
    data_loader = get_loader(cfg.data.dataloader)
    data_path = cfg.data.path

    logger.info("Using dataset: {}".format(data_path))

    # Train split gets augmentations and a sub-sampling percentage;
    # the val split below is loaded plain.
    t_loader = data_loader(
        data_path,
        # transform=None,
        # time_shuffle = cfg.data.time_shuffle,
        # to_tensor=False,
        data_format=cfg.data.format,
        norm=cfg.data.norm,
        split=cfg.data.train_split,
        augments=data_aug,
        use_perc=cfg.data.use_perc)

    v_loader = data_loader(
        data_path,
        # transform=None,
        # time_shuffle = cfg.data.time_shuffle,
        # to_tensor=False,
        data_format=cfg.data.format,
        split=cfg.data.val_split,
    )
    train_data_len = len(t_loader)
    logger.info(
        f'num of train samples: {train_data_len} \nnum of val samples: {len(v_loader)}'
    )

    # Epoch-based config is converted to a total iteration budget; the
    # loop below is driven purely by iteration count.
    batch_size = cfg.train.batch_size
    epoch = cfg.train.epoch
    train_iter = int(np.ceil(train_data_len / batch_size) * epoch)
    logger.info(f'total train iter: {train_iter}')

    trainloader = data.DataLoader(t_loader,
                                  batch_size=batch_size,
                                  num_workers=cfg.train.n_workers,
                                  shuffle=True,
                                  persistent_workers=True,
                                  drop_last=True)

    # NOTE(review): validation batch size is hard-coded to 10, independent
    # of cfg — confirm this is intentional.
    valloader = data.DataLoader(
        v_loader,
        batch_size=10,
        # persis
        num_workers=cfg.train.n_workers,
    )

    # Setup Model: binary (2-class) output, placed on the first listed GPU.
    device = f'cuda:{cfg.gpu[0]}'
    model = get_model(cfg.model, 2).to(device)
    input_size = (cfg.model.input_nbr, 512, 512)
    logger.info(f"Using Model: {cfg.model.arch}")
    # input_size is passed twice — presumably the model takes two image
    # streams (file_a, file_b); confirm against the summary() helper.
    logger.info(
        f'model summary: {summary(model, input_size=(input_size, input_size), is_complex=False)}'
    )
    model = torch.nn.DataParallel(model, device_ids=cfg.gpu)  # automatic multi-GPU execution; works well

    # Setup optimizer, lr_scheduler and loss function
    optimizer_cls = get_optimizer(cfg)
    # Everything in cfg.train.optimizer except bookkeeping keys is passed
    # straight through as optimizer kwargs.
    optimizer_params = {
        k: v
        for k, v in vars(cfg.train.optimizer).items()
        if k not in ('name', 'wrap')
    }
    optimizer = optimizer_cls(model.parameters(), **optimizer_params)
    logger.info("Using optimizer {}".format(optimizer))
    # Optionally wrap the base optimizer with LARS.
    if hasattr(cfg.train.optimizer,
               'wrap') and cfg.train.optimizer.wrap == 'lars':
        optimizer = LARS(optimizer=optimizer)
        logger.info(f'warp optimizer with {cfg.train.optimizer.wrap}')
    scheduler = get_scheduler(optimizer, cfg.train.lr)
    loss_fn = get_loss_function(cfg)
    logger.info(f"Using loss ,{str(cfg.train.loss)}")

    if cfg.train.clip:
        logger.info(f'max grad norm: {cfg.train.clip}')

    # load checkpoints
    val_cls_1_acc = 0
    best_cls_1_acc_now = 0
    best_cls_1_acc_iter_now = 0
    val_macro_OA = 0
    best_macro_OA_now = 0
    best_macro_OA_iter_now = 0
    start_iter = 0
    if cfg.train.resume is not None:
        if os.path.isfile(cfg.train.resume):
            logger.info(
                "Loading model and optimizer from checkpoint '{}'".format(
                    cfg.train.resume))

            # load model state
            checkpoint = torch.load(cfg.train.resume)
            model.load_state_dict(checkpoint["model_state"])
            optimizer.load_state_dict(checkpoint["optimizer_state"])
            scheduler.load_state_dict(checkpoint["scheduler_state"])
            # best_cls_1_acc_now = checkpoint["best_cls_1_acc_now"]
            # best_cls_1_acc_iter_now = checkpoint["best_cls_1_acc_iter_now"]

            # "epoch" in the checkpoint actually stores the iteration count.
            start_iter = checkpoint["epoch"]
            logger.info("Loaded checkpoint '{}' (iter {})".format(
                cfg.train.resume, checkpoint["epoch"]))

            # copy tensorboard files
            resume_src_dir = osp.split(cfg.train.resume)[0]
            # shutil.copytree(resume_src_dir, writer.get_logdir())
            for file in os.listdir(resume_src_dir):
                # Skip logs, configs and the stale last-model snapshot.
                if not ('.log' in file or '.yml' in file
                        or '_last_model' in file):
                    # if 'events.out.tfevents' in file:
                    resume_dst_dir = writer.get_logdir()
                    fu.copy(
                        osp.join(resume_src_dir, file),
                        resume_dst_dir,
                    )

        else:
            logger.info("No checkpoint found at '{}'".format(cfg.train.resume))

    # Setup Metrics
    running_metrics_val = runningScore(2)
    runing_metrics_train = runningScore(2)
    val_loss_meter = averageMeter()
    train_time_meter = averageMeter()

    # train
    it = start_iter
    train_start_time = time.time()
    train_val_start_time = time.time()
    model.train()
    while it < train_iter:
        # Each sample is a pair of co-registered inputs plus label and a
        # validity mask used by both the loss and the metrics.
        for (file_a, file_b, label, mask) in trainloader:
            it += 1
            file_a = file_a.to(device)
            file_b = file_b.to(device)
            label = label.to(device)
            mask = mask.to(device)

            optimizer.zero_grad()
            # print(f'dtype: {file_a.dtype}')
            outputs = model(file_a, file_b)
            loss = loss_fn(input=outputs, target=label, mask=mask)
            loss.backward()

            # grads = []
            # for param in model.parameters():
            #     grads.append(param.grad.view(-1))
            # grads = torch.cat(grads)
            # grads = torch.abs(grads)
            # grads_total_norm = tt.get_params_norm(model.parameters(), norm_type=1)
            # writer.add_scalars('grads/unnormed', {'mean': grads.mean(), 'max':grads.max(), 'total':grads_total_norm}, it)

            # logger.info(f'max grad: {grads.max()}, mean grad: {grads.mean()}')
            # print('conv11: ', model.conv11.weight.grad, model.conv11.weight.grad.shape)
            # print('conv21: ', model.conv21.weight.grad, model.conv21.weight.grad.shape)
            # print('conv31: ', model.conv31.weight.grad, model.conv31.weight.grad.shape)

            # In PyTorch 1.1.0 and later, you should call `optimizer.step()` before `lr_scheduler.step()`

            # Optional L1 gradient clipping before the optimizer step.
            if cfg.train.clip:
                nn.utils.clip_grad_norm_(model.parameters(),
                                         max_norm=cfg.train.clip,
                                         norm_type=1)

            # grads = []
            # for param in model.parameters():
            #     grads.append(param.grad.view(-1))
            # grads = torch.cat(grads)
            # grads = torch.abs(grads)
            # grads_total_norm = tt.get_params_norm(model.parameters(), norm_type=1)
            # writer.add_scalars('grads/normed', {'mean': grads.mean(), 'max':grads.max(), 'total':grads_total_norm}, it)

            optimizer.step()
            scheduler.step()

            # record the acc of the minibatch
            pred = outputs.max(1)[1].cpu().numpy()
            runing_metrics_train.update(label.cpu().numpy(), pred,
                                        mask.cpu().numpy())

            train_time_meter.update(time.time() - train_start_time)

            if it % cfg.train.print_interval == 0:
                # acc of the samples between print_interval
                score, _ = runing_metrics_train.get_scores()
                train_cls_0_acc, train_cls_1_acc = score['Acc']
                fmt_str = "Iter [{:d}/{:d}]  train Loss: {:.4f}  Time/Image: {:.4f},\n0:{:.4f}\n1:{:.4f}"
                print_str = fmt_str.format(
                    it,
                    train_iter,
                    loss.item(),  # extracts the loss's value as a Python float
                    train_time_meter.avg / cfg.train.batch_size,
                    train_cls_0_acc,
                    train_cls_1_acc)
                # Reset so the next print window reports fresh statistics.
                runing_metrics_train.reset()
                train_time_meter.reset()
                logger.info(print_str)
                writer.add_scalar('loss/train_loss', loss.item(), it)
                writer.add_scalars('metrics/train', {
                    'cls_0': train_cls_0_acc,
                    'cls_1': train_cls_1_acc
                }, it)
                # writer.add_scalar('train_metrics/acc/cls_0', train_cls_0_acc, it)
                # writer.add_scalar('train_metrics/acc/cls_1', train_cls_1_acc, it)

            if it % cfg.train.val_interval == 0 or \
               it == train_iter:
                val_start_time = time.time()
                model.eval()  # change behavior like drop out
                with torch.no_grad():  # disable autograd, save memory usage
                    for (file_a_val, file_b_val, label_val,
                         mask_val) in valloader:
                        file_a_val = file_a_val.to(device)
                        file_b_val = file_b_val.to(device)

                        outputs = model(file_a_val, file_b_val)
                        # tensor.max() returns the maximum value and its indices
                        pred = outputs.max(1)[1].cpu().numpy()
                        # Metrics consume CPU numpy arrays before the
                        # tensors are moved to the GPU for the loss below.
                        running_metrics_val.update(label_val.numpy(), pred,
                                                   mask_val.numpy())

                        label_val = label_val.to(device)
                        mask_val = mask_val.to(device)
                        val_loss = loss_fn(input=outputs,
                                           target=label_val,
                                           mask=mask_val)
                        val_loss_meter.update(val_loss.item())

                score, _ = running_metrics_val.get_scores()
                val_cls_0_acc, val_cls_1_acc = score['Acc']

                writer.add_scalar('loss/val_loss', val_loss_meter.avg, it)
                logger.info(
                    f"Iter [{it}/{train_iter}], val Loss: {val_loss_meter.avg:.4f} Time/Image: {(time.time()-val_start_time)/len(v_loader):.4f}\n0: {val_cls_0_acc:.4f}\n1:{val_cls_1_acc:.4f}"
                )
                # lr_now = optimizer.param_groups[0]['lr']
                # logger.info(f'lr: {lr_now}')
                # writer.add_scalar('lr', lr_now, it+1)

                logger.info('0: {:.4f}\n1:{:.4f}'.format(
                    val_cls_0_acc, val_cls_1_acc))
                micro_OA = score['Overall_Acc']
                miou = score['Mean_IoU']
                logger.info(f'overall acc: {micro_OA}, mean iou: {miou}')
                writer.add_scalars('metrics/val', {
                    'cls_0': val_cls_0_acc,
                    'cls_1': val_cls_1_acc
                }, it)
                # writer.add_scalar('val_metrics/acc/cls_0', val_cls_0_acc, it)
                # writer.add_scalar('val_metrics/acc/cls_1', val_cls_1_acc, it)

                val_loss_meter.reset()
                running_metrics_val.reset()

                # Best-model selection: macro overall accuracy (mean of the
                # two per-class accuracies), only after a 200-iter warm-up.
                # OA=score["Overall_Acc"]
                val_macro_OA = (val_cls_0_acc + val_cls_1_acc) / 2
                if val_macro_OA >= best_macro_OA_now and it > 200:
                    best_macro_OA_now = val_macro_OA
                    best_macro_OA_iter_now = it
                    state = {
                        "epoch": it,
                        "model_state": model.state_dict(),
                        "optimizer_state": optimizer.state_dict(),
                        "scheduler_state": scheduler.state_dict(),
                        "best_macro_OA_now": best_macro_OA_now,
                        'best_macro_OA_iter_now': best_macro_OA_iter_now,
                    }
                    save_path = os.path.join(
                        writer.file_writer.get_logdir(),
                        "{}_{}_best_model.pkl".format(cfg.model.arch,
                                                      cfg.data.dataloader))
                    torch.save(state, save_path)

                    logger.info("best OA now =  %.8f" % (best_macro_OA_now))
                    logger.info("best OA iter now= %d" %
                                (best_macro_OA_iter_now))

                # Linear ETA estimate from elapsed wall-clock time.
                train_val_time = time.time() - train_val_start_time
                remain_time = train_val_time * (train_iter - it) / it
                m, s = divmod(remain_time, 60)
                h, m = divmod(m, 60)
                # NOTE(review): `s` is a float and is almost never exactly
                # 0, so the "train completed" branch is effectively dead.
                if s != 0:
                    train_time = "Remain train time = %d hours %d minutes %d seconds \n" % (
                        h, m, s)
                else:
                    train_time = "Remain train time : train completed.\n"
                logger.info(train_time)
                model.train()

            train_start_time = time.time()

    logger.info("best OA now =  %.8f" % (best_macro_OA_now))
    logger.info("best OA iter now= %d" % (best_macro_OA_iter_now))

    # Always save a final "last" snapshot in addition to the best one.
    state = {
        "epoch": it,
        "model_state": model.state_dict(),
        "optimizer_state": optimizer.state_dict(),
        "scheduler_state": scheduler.state_dict(),
        "best_macro_OA_now": best_macro_OA_now,
        'best_macro_OA_iter_now': best_macro_OA_iter_now,
    }
    save_path = os.path.join(
        writer.file_writer.get_logdir(),
        "{}_{}_last_model.pkl".format(cfg.model.arch, cfg.data.dataloader))
    torch.save(state, save_path)
Example #5
0
def train(args):
    """Train a VGG16-LargeFoV segmentation model with per-epoch validation.

    Builds augmented train / plain val dataloaders for ``args.dataset``,
    trains with SGD (poly learning-rate decay), evaluates after every epoch
    and checkpoints whenever mean IoU improves.

    Args:
        args: namespace with ``dataset``, ``img_rows``, ``img_cols``,
            ``img_norm``, ``batch_size``, ``l_rate``, ``n_epoch`` and
            ``resume`` attributes.
    """
    logger.auto_set_dir()
    os.environ['CUDA_VISIBLE_DEVICES'] = '0'

    # Setup Augmentations
    data_aug = Compose([RandomRotate(10), RandomHorizontallyFlip()])

    # Setup Dataloader
    data_loader = get_loader(args.dataset)
    data_path = get_data_path(args.dataset)
    t_loader = data_loader(data_path,
                           is_transform=True,
                           img_size=(args.img_rows, args.img_cols),
                           epoch_scale=4,
                           augmentations=data_aug,
                           img_norm=args.img_norm)
    v_loader = data_loader(data_path,
                           is_transform=True,
                           split='val',
                           img_size=(args.img_rows, args.img_cols),
                           img_norm=args.img_norm)

    n_classes = t_loader.n_classes
    trainloader = data.DataLoader(t_loader,
                                  batch_size=args.batch_size,
                                  num_workers=8,
                                  shuffle=True)
    valloader = data.DataLoader(v_loader,
                                batch_size=args.batch_size,
                                num_workers=8)

    # Setup Metrics
    running_metrics = runningScore(n_classes)

    # Setup Model
    from pytorchgo.model.deeplabv1 import VGG16_LargeFoV
    model = VGG16_LargeFoV(class_num=n_classes,
                           image_size=[args.img_cols, args.img_rows],
                           pretrained=True)
    model.cuda()

    # Check if model has custom optimizer / loss.
    # Fixes two defects in the original branch: the warning message was
    # inverted (it fired when the model DID have an optimizer), and the
    # optimizer was read via `model.module` although the model is never
    # wrapped in DataParallel here.
    if hasattr(model, 'optimizer'):
        logger.info("using model's customized optimizer")
        optimizer = model.optimizer
    else:
        logger.warn("don't have customzed optimizer, use default setting!")
        optimizer = torch.optim.SGD(model.parameters(),
                                    lr=args.l_rate,
                                    momentum=0.99,
                                    weight_decay=5e-4)

    optimizer_summary(optimizer)
    if args.resume is not None:
        if os.path.isfile(args.resume):
            logger.info(
                "Loading model and optimizer from checkpoint '{}'".format(
                    args.resume))
            checkpoint = torch.load(args.resume)
            model.load_state_dict(checkpoint['model_state'])
            optimizer.load_state_dict(checkpoint['optimizer_state'])
            logger.info("Loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            logger.info("No checkpoint found at '{}'".format(args.resume))

    best_iou = 0
    for epoch in tqdm(range(args.n_epoch), total=args.n_epoch):
        model.train()
        for i, (images, labels) in tqdm(enumerate(trainloader),
                                        total=len(trainloader),
                                        desc="training epoch {}/{}".format(
                                            epoch, args.n_epoch)):
            # Poly learning-rate decay over the full training schedule.
            cur_iter = i + epoch * len(trainloader)
            cur_lr = adjust_learning_rate(optimizer,
                                          args.l_rate,
                                          cur_iter,
                                          args.n_epoch * len(trainloader),
                                          power=0.9)
            #if i > 10:break

            images = images.cuda()
            labels = labels.cuda()

            optimizer.zero_grad()
            outputs = model(images)
            #print(np.unique(outputs.data[0].cpu().numpy()))
            loss = CrossEntropyLoss2d_Seg(input=outputs,
                                          target=labels,
                                          class_num=n_classes)

            loss.backward()
            optimizer.step()

            if (i + 1) % 100 == 0:
                # loss.item() replaces the removed 0-dim indexing
                # `loss.data[0]` (errors on PyTorch >= 0.5).
                logger.info("Epoch [%d/%d] Loss: %.4f, lr: %.7f" %
                            (epoch + 1, args.n_epoch, loss.item(), cur_lr))

        model.eval()
        # torch.no_grad() replaces the removed `Variable(volatile=True)`
        # API for inference.
        with torch.no_grad():
            for i_val, (images_val, labels_val) in tqdm(enumerate(valloader),
                                                        total=len(valloader),
                                                        desc="validation"):
                images_val = images_val.cuda()
                labels_val = labels_val.cuda()

                outputs = model(images_val)
                pred = outputs.data.max(1)[1].cpu().numpy()
                gt = labels_val.data.cpu().numpy()
                running_metrics.update(gt, pred)

        score, class_iou = running_metrics.get_scores()
        for k, v in score.items():
            logger.info("{}: {}".format(k, v))
        running_metrics.reset()

        # Checkpoint whenever mean IoU improves.
        if score['Mean IoU : \t'] >= best_iou:
            best_iou = score['Mean IoU : \t']
            state = {
                'epoch': epoch + 1,
                'mIoU': best_iou,
                'model_state': model.state_dict(),
                'optimizer_state': optimizer.state_dict(),
            }
            torch.save(state,
                       os.path.join(logger.get_logger_dir(), "best_model.pkl"))
Example #6
0
def train(cfg, writer, logger, run_id):

    # Setup random seeds
    # torch.manual_seed(cfg.get('seed', 137))
    # torch.cuda.manual_seed(cfg.get('seed', 137))
    # np.random.seed(cfg.get('seed', 137))
    # random.seed(cfg.get('seed', 137))

    torch.backends.cudnn.benchmark = True

    # Setup Augmentations
    augmentations = cfg['train'].get('augmentations', None)
    data_aug = get_composed_augmentations(augmentations)

    # Setup Dataloader
    data_loader = get_loader(cfg['data']['dataloader'])
    data_path = cfg['data']['path']

    logger.info("Using dataset: {}".format(data_path))

    tile_size = cfg['data']['tile_size']
    t_loader = data_loader(data_path,
                           transform=None,
                           split=cfg['data']['train_split'],
                           tile_size=tile_size,
                           augmentations=data_aug)

    v_loader = data_loader(
        data_path,
        transform=None,
        tile_size=tile_size,
        split=cfg['data']['val_split'],
    )
    logger.info(
        f'num of train samples: {len(t_loader)} \nnum of val samples: {len(v_loader)}'
    )

    train_data_len = len(t_loader)
    batch_size = cfg['train']['batch_size']
    epoch = cfg['train']['train_epoch']
    train_iter = int(np.ceil(train_data_len / batch_size) * epoch)
    logger.info(f'total train iter: {train_iter}')
    n_classes = t_loader.n_classes

    trainloader = data.DataLoader(t_loader,
                                  batch_size=cfg['train']['batch_size'],
                                  num_workers=cfg['train']['n_workers'],
                                  shuffle=True,
                                  persistent_workers=True,
                                  drop_last=True)

    valloader = data.DataLoader(
        v_loader,
        batch_size=cfg['train']['batch_size'],
        num_workers=cfg['train']['n_workers'],
        persistent_workers=True,
    )

    # Setup Model
    model = get_model(cfg['model'], n_classes)
    # print('model:\n', model)
    logger.info("Using Model: {}".format(cfg['model']['arch']))
    device = f'cuda:{cuda_idx[0]}'
    model = model.to(device)
    # model = torch.nn.DataParallel(model, device_ids=cuda_idx)      #自动多卡运行,这个好用

    # Setup optimizer, lr_scheduler and loss function
    optimizer_cls = get_optimizer(cfg)
    optimizer_params = {
        k: v
        for k, v in cfg['train']['optimizer'].items() if k != 'name'
    }
    optimizer = optimizer_cls(model.parameters(), **optimizer_params)
    logger.info("Using optimizer {}".format(optimizer))
    scheduler = get_scheduler(optimizer, cfg['train']['lr_schedule'])
    loss_fn = get_loss_function(cfg)
    # logger.info("Using loss {}".format(loss_fn))

    # set checkpoints
    start_iter = 0
    if cfg['train']['resume'] is not None:
        if os.path.isfile(cfg['train']['resume']):
            logger.info(
                "Loading model and optimizer from checkpoint '{}'".format(
                    cfg['train']['resume']))
            checkpoint = torch.load(cfg['train']['resume'])
            model.load_state_dict(checkpoint["model_state"])
            optimizer.load_state_dict(checkpoint["optimizer_state"])
            scheduler.load_state_dict(checkpoint["scheduler_state"])
            start_iter = checkpoint["epoch"]
            logger.info("Loaded checkpoint '{}' (iter {})".format(
                cfg['train']['resume'], checkpoint["epoch"]))
        else:
            logger.info("No checkpoint found at '{}'".format(
                cfg['train']['resume']))

    # Setup Metrics
    running_metrics_val = runningScore(n_classes)
    val_loss_meter = averageMeter()
    train_time_meter = averageMeter()
    time_meter_val = averageMeter()

    flag = True

    val_rlt_f1 = []
    val_rlt_OA = []
    best_fwIoU_now = 0
    best_fwIoU_iter_till_now = 0

    # train
    it = start_iter
    num_positive = 0
    model.train()
    while it <= train_iter and flag:
        for (file_a, file_b, label) in trainloader:

            # caculate distribution of samples
            num_positive += label.sum()

            it += 1
            # print('iteration', it)
            start_ts = time.time()
            file = torch.cat((file_a, file_b), dim=1).to(device)
            # file_a = file_a.to(device)
            # file_b = file_b.to(device)
            label = label.to(device)

            optimizer.zero_grad()
            outputs = model(file)

            loss = loss_fn(input=outputs, target=label)
            loss.backward()
            # print('conv11: ', model.conv11.weight.grad, model.conv11.weight.grad.shape)
            # print('conv21: ', model.conv21.weight.grad, model.conv21.weight.grad.shape)
            # print('conv31: ', model.conv31.weight.grad, model.conv31.weight.grad.shape)

            # In PyTorch 1.1.0 and later, you should call `optimizer.step()` before `lr_scheduler.step()`
            optimizer.step()
            scheduler.step()

            train_time_meter.update(time.time() - start_ts)
            time_meter_val.update(time.time() - start_ts)

            if (it + 1) % cfg['train']['print_interval'] == 0:
                fmt_str = "train:\nIter [{:d}/{:d}]  Loss: {:.4f}  Time/Image: {:.4f}"
                print_str = fmt_str.format(
                    it + 1,
                    train_iter,
                    loss.item(),  #extracts the loss’s value as a Python float.
                    train_time_meter.avg / cfg['train']['batch_size'])
                train_time_meter.reset()
                logger.info(print_str)
                writer.add_scalar('loss/train_loss', loss.item(), it + 1)
                writer.add_scalar(
                    'num_positve',
                    num_positive / (cfg['train']['batch_size'] *
                                    cfg['train']['print_interval']), it + 1)
                num_positive = 0

            if (it + 1) % cfg['train']['val_interval'] == 0 or \
               (it + 1) == train_iter:
                model.eval()  # change behavior like drop out
                with torch.no_grad():  # disable autograd, save memory usage
                    for (file_a_val, file_b_val, label_val) in valloader:
                        file_val = torch.cat((file_a_val, file_b_val),
                                             dim=1).to(device)
                        # file_a_val = file_a_val.to(device)
                        # file_b_val = file_b_val.to(device)

                        outputs = model(file_val)
                        label_val = label_val.to(device)
                        val_loss = loss_fn(input=outputs, target=label_val)
                        val_loss_meter.update(val_loss.item())

                        # tensor.max with return the maximum value and its indices
                        label_val = label_val.cpu().numpy()
                        pred = outputs.max(dim=1)[1].cpu().numpy()
                        running_metrics_val.update(label_val, pred)

                # lr_now = optimizer.param_groups[0]['lr']
                # logger.info(f'lr: {lr_now}')
                # writer.add_scalar('lr', lr_now, it+1)
                writer.add_scalars('loss/val_loss', {
                    'train': loss.item(),
                    'val': val_loss_meter.avg
                }, it + 1)
                logger.info("Iter %d, val Loss: %.4f" %
                            (it + 1, val_loss_meter.avg))

                score, class_iou = running_metrics_val.get_scores()

                # for k, v in score.items():
                #     logger.info('{}: {}'.format(k, v))
                #     writer.add_scalar('val_metrics/{}'.format(k), v, it+1)

                for k, v in class_iou.items():
                    logger.info('{}: {}'.format(k, v))
                    writer.add_scalar('val_metrics/cls_{}'.format(k), v,
                                      it + 1)

                val_loss_meter.reset()
                running_metrics_val.reset()

                avg_f1 = score["Mean_F1"]
                OA = score["Overall_Acc"]
                fw_IoU = score["FreqW_IoU"]
                # val_rlt_f1.append(avg_f1)
                # val_rlt_OA.append(OA)

                if fw_IoU >= best_fwIoU_now and it > 200:
                    best_fwIoU_now = fw_IoU
                    correspond_meanIou = score["Mean_IoU"]
                    best_fwIoU_iter_till_now = it + 1

                    state = {
                        "epoch": it + 1,
                        "model_state": model.state_dict(),
                        "optimizer_state": optimizer.state_dict(),
                        "scheduler_state": scheduler.state_dict(),
                        "best_fwIoU": best_fwIoU_now,
                    }
                    save_path = os.path.join(
                        writer.file_writer.get_logdir(),
                        "{}_{}_best_model.pkl".format(
                            cfg['model']['arch'], cfg['data']['dataloader']))
                    torch.save(state, save_path)

                    logger.info("best_fwIoU_now =  %.8f" % (best_fwIoU_now))
                    logger.info("Best fwIoU Iter till now= %d" %
                                (best_fwIoU_iter_till_now))

                iter_time = time_meter_val.avg
                time_meter_val.reset()
                remain_time = iter_time * (train_iter - it)
                m, s = divmod(remain_time, 60)
                h, m = divmod(m, 60)
                if s != 0:
                    train_time = "Remain train time = %d hours %d minutes %d seconds \n" % (
                        h, m, s)
                else:
                    train_time = "Remain train time : train completed.\n"
                logger.info(train_time)
                model.train()

            if (it + 1) == train_iter:
                flag = False
                logger.info("Use the Sar_seg_band3,val_interval: 30")
                break

    logger.info("best_fwIoU_now =  %.8f" % (best_fwIoU_now))
    logger.info("Best fwIoU Iter till now= %d" % (best_fwIoU_iter_till_now))

    state = {
        "epoch": it + 1,
        "model_state": model.state_dict(),
        "optimizer_state": optimizer.state_dict(),
        "scheduler_state": scheduler.state_dict(),
        "best_fwIoU": best_fwIoU_now,
    }
    save_path = os.path.join(
        writer.file_writer.get_logdir(),
        "{}_{}_last_model.pkl".format(cfg['model']['arch'],
                                      cfg['data']['dataloader']))
    torch.save(state, save_path)
Example #7
0
def train(cfg, writer, logger, args):
    """Train a segmentation model and periodically validate it.

    Args:
        cfg: dict-style config with 'data', 'model' and 'training' sections.
        writer: tensorboard SummaryWriter; also provides the checkpoint
            directory via ``writer.file_writer.get_logdir()``.
        logger: logging.Logger for progress messages.
        args: extra namespace forwarded to ``get_model`` and the loss.

    The best model (by ``score["Mean IoU : \t"]``) is checkpointed into the
    writer's log directory.
    """
    # Setup seeds for reproducibility.
    torch.manual_seed(cfg.get('seed', 1337))
    torch.cuda.manual_seed(cfg.get('seed', 1337))
    np.random.seed(cfg.get('seed', 1337))
    random.seed(cfg.get('seed', 1337))

    # Setup device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Setup Augmentations.  Previously `augmentations` was only assigned for
    # 'cityscapes', so every other dataset raised NameError below; fall back
    # to whatever the config specifies (possibly None).
    if cfg['data']['dataset'] in ['cityscapes']:
        augmentations = cfg['training'].get(
            'augmentations', {
                'brightness': 63. / 255.,
                'saturation': 0.5,
                'contrast': 0.8,
                'hflip': 0.5,
                'rotate': 10,
                'rscalecropsquare': 713,
            })
    else:
        augmentations = cfg['training'].get('augmentations', None)
    data_aug = get_composed_augmentations(augmentations)

    # Setup Dataloader
    data_loader = get_loader(cfg['data']['dataset'])
    data_path = cfg['data']['path']

    t_loader = data_loader(data_path,
                           is_transform=True,
                           split=cfg['data']['train_split'],
                           img_size=(cfg['data']['img_rows'],
                                     cfg['data']['img_cols']),
                           augmentations=data_aug)

    v_loader = data_loader(
        data_path,
        is_transform=True,
        split=cfg['data']['val_split'],
        img_size=(cfg['data']['img_rows'], cfg['data']['img_cols']),
    )

    n_classes = t_loader.n_classes
    trainloader = data.DataLoader(t_loader,
                                  batch_size=cfg['training']['batch_size'],
                                  num_workers=cfg['training']['n_workers'],
                                  shuffle=True)

    valloader = data.DataLoader(v_loader,
                                batch_size=cfg['training']['batch_size'],
                                num_workers=cfg['training']['n_workers'])

    # Setup Metrics
    running_metrics_val = runningScore(n_classes)

    # Setup Model
    model = get_model(cfg['model'], n_classes, args).to(device)
    model.apply(weights_init)
    print('sleep for 5 seconds')
    time.sleep(5)
    model = torch.nn.DataParallel(model,
                                  device_ids=range(torch.cuda.device_count()))
    print(model.device_ids)

    # Setup optimizer, lr_scheduler and loss function
    optimizer_cls = get_optimizer(cfg)
    optimizer_params = {
        k: v
        for k, v in cfg['training']['optimizer'].items() if k != 'name'
    }

    optimizer = optimizer_cls(model.parameters(), **optimizer_params)
    logger.info("Using optimizer {}".format(optimizer))

    scheduler = get_scheduler(optimizer, cfg['training']['lr_schedule'])

    loss_fn = get_loss_function(cfg)
    logger.info("Using loss {}".format(loss_fn))
    # Multi-step losses need per-step scale weights; plain losses do not.
    if 'multi_step' in cfg['training']['loss']['name']:
        my_loss_fn = loss_fn(
            scale_weight=cfg['training']['loss']['scale_weight'],
            n_inp=2,
            weight=None,
            reduction='sum',
            bkargs=args)
    else:
        my_loss_fn = loss_fn(weight=None, reduction='sum', bkargs=args)

    # Optionally resume model/optimizer/scheduler state from a checkpoint.
    start_iter = 0
    if cfg['training']['resume'] is not None:
        if os.path.isfile(cfg['training']['resume']):
            logger.info(
                "Loading model and optimizer from checkpoint '{}'".format(
                    cfg['training']['resume']))
            checkpoint = torch.load(cfg['training']['resume'])
            model.load_state_dict(checkpoint["model_state"])
            optimizer.load_state_dict(checkpoint["optimizer_state"])
            scheduler.load_state_dict(checkpoint["scheduler_state"])
            start_iter = checkpoint["epoch"]
            logger.info("Loaded checkpoint '{}' (iter {})".format(
                cfg['training']['resume'], checkpoint["epoch"]))
        else:
            logger.info("No checkpoint found at '{}'".format(
                cfg['training']['resume']))

    val_loss_meter = averageMeter()
    time_meter = averageMeter()

    best_iou = -100.0
    i = start_iter
    flag = True

    while i <= cfg['training']['train_iters'] and flag:
        for (images, labels) in trainloader:
            i += 1
            start_ts = time.time()
            model.train()
            images = images.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            outputs = model(images)

            loss = my_loss_fn(myinput=outputs, target=labels)

            loss.backward()
            # PyTorch >= 1.1 requires `optimizer.step()` before
            # `lr_scheduler.step()` (previously the order was reversed here).
            optimizer.step()
            scheduler.step()

            time_meter.update(time.time() - start_ts)

            if (i + 1) % cfg['training']['print_interval'] == 0:
                fmt_str = "Iter [{:d}/{:d}]  Loss: {:.4f}  Time/Image: {:.4f}"
                print_str = fmt_str.format(
                    i + 1, cfg['training']['train_iters'], loss.item(),
                    time_meter.avg / cfg['training']['batch_size'])

                print(print_str)
                logger.info(print_str)
                writer.add_scalar('loss/train_loss', loss.item(), i + 1)
                time_meter.reset()

            if (i + 1) % cfg['training']['val_interval'] == 0 or \
               (i + 1) == cfg['training']['train_iters']:
                model.eval()
                with torch.no_grad():
                    for i_val, (images_val,
                                labels_val) in tqdm(enumerate(valloader)):
                        images_val = images_val.to(device)
                        labels_val = labels_val.to(device)

                        outputs = model(images_val)
                        val_loss = my_loss_fn(myinput=outputs,
                                              target=labels_val)

                        pred = outputs.data.max(1)[1].cpu().numpy()
                        gt = labels_val.data.cpu().numpy()

                        running_metrics_val.update(gt, pred)
                        val_loss_meter.update(val_loss.item())

                writer.add_scalar('loss/val_loss', val_loss_meter.avg, i + 1)
                logger.info("Iter %d Loss: %.4f" % (i + 1, val_loss_meter.avg))

                score, class_iou = running_metrics_val.get_scores()
                for k, v in score.items():
                    print(k, v)
                    logger.info('{}: {}'.format(k, v))
                    writer.add_scalar('val_metrics/{}'.format(k), v, i + 1)

                for k, v in class_iou.items():
                    logger.info('{}: {}'.format(k, v))
                    writer.add_scalar('val_metrics/cls_{}'.format(k), v, i + 1)

                val_loss_meter.reset()
                running_metrics_val.reset()

                # Checkpoint whenever validation mean IoU improves.
                if score["Mean IoU : \t"] >= best_iou:
                    best_iou = score["Mean IoU : \t"]
                    state = {
                        "epoch": i + 1,
                        "model_state": model.state_dict(),
                        "optimizer_state": optimizer.state_dict(),
                        "scheduler_state": scheduler.state_dict(),
                        "best_iou": best_iou,
                    }
                    save_path = os.path.join(
                        writer.file_writer.get_logdir(),
                        "{}_{}_best_model.pkl".format(cfg['model']['arch'],
                                                      cfg['data']['dataset']))
                    torch.save(state, save_path)

            if (i + 1) == cfg['training']['train_iters']:
                flag = False
                break
Example #8
0
def validate(cfg, args):
    """Evaluate a trained FASSDNet checkpoint on the validation split.

    Runs one warm-up forward pass, times each subsequent inference with CUDA
    synchronization, optionally dumps prediction images (ID map, RGB blend,
    color labels), and prints overall and per-class scores.

    Args:
        cfg: dict-style config with a 'data' section (dataset, path, splits).
        args: namespace with `model_path` (checkpoint file) and `save_image`.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(device)

    # Setup Dataloader
    data_loader = get_loader(cfg["data"]["dataset"])
    data_path = cfg["data"]["path"]

    loader = data_loader(
        data_path,
        split=cfg["data"]["val_split"],
        # split = "test", # if test set is used
        is_transform=True,
        img_size=(1024, 2048),
    )

    n_images = len(loader.files[cfg["data"]["val_split"]])

    n_classes = loader.n_classes
    valloader = data.DataLoader(loader, batch_size=1, num_workers=1)
    running_metrics = runningScore(n_classes)

    # Setup Model -- use the loader's class count instead of a hard-coded 19
    # so label sets other than cityscapes work as well.
    model = FASSDNet(n_classes).to(device)
    state = convert_state_dict(torch.load(args.model_path)["model_state"])
    model.load_state_dict(state)
    model.eval()

    total_params = sum(p.numel() for p in model.parameters())
    print('Parameters: ', total_params)

    torch.backends.cudnn.benchmark = True

    for i, (images, labels, fname) in enumerate(valloader):
        images = images.to(device)

        # Warm-up pass so cudnn autotuning does not skew the first timing.
        if i == 0:
            with torch.no_grad():
                outputs = model(images)

        torch.cuda.synchronize()
        start_time = time.perf_counter()

        with torch.no_grad():
            outputs = model(images)

        torch.cuda.synchronize()
        elapsed_time = time.perf_counter() - start_time

        if args.save_image:
            pred = np.squeeze(outputs.data.max(1)[1].cpu().numpy(), axis=0)
            save_rgb = True

            # Class-ID map.
            decoded = loader.decode_segmap_id(pred)
            id_dir = "./out_predID/"
            os.makedirs(id_dir, exist_ok=True)
            misc.imsave(id_dir + fname[0], decoded)

            if save_rgb:
                # Blend the RGB prediction over the input image.
                decoded = loader.decode_segmap(pred)
                img_input = np.squeeze(images.cpu().numpy(), axis=0)
                img_input = img_input.transpose(1, 2, 0)
                blend = img_input * 0.2 + decoded * 0.8
                fname_new = fname[0]
                fname_new = fname_new[:-4]
                fname_new1 = fname_new + '.jpg'
                fname_new2 = fname_new + '.png'  # For Color labels

                rgb_dir = "./out_rgb/"
                os.makedirs(rgb_dir, exist_ok=True)
                misc.imsave(rgb_dir + fname_new1, blend)

                # Save labels with color
                color_dir = "./out_color/"
                os.makedirs(color_dir, exist_ok=True)
                misc.imsave(color_dir + fname_new2, decoded)

        pred = outputs.data.max(1)[1].cpu().numpy()
        gt = labels.numpy()

        running_metrics.update(gt, pred)
        print("iteration {}/{}".format(i, n_images), end='\r')

    score, class_iou = running_metrics.get_scores()

    for k, v in score.items():
        print(k, v)

    for i in range(n_classes):
        print(i, class_iou[i])
Example #9
0
def test(cfg, logger, run_id):
    """Run change-detection inference on a test split and dump result images.

    For every sample, writes a `<n>_<acc0>_<acc1>_pred.png` /
    `..._gt.png` pair into ``run_id/<dataset>/`` (per-class accuracies from
    the confusion matrix embedded in the names), then logs aggregate
    accuracy, overall accuracy and mean IoU.

    Args:
        cfg: attribute-style config with `data`, `model`, `test` and `gpu`.
        logger: logging.Logger for progress messages.
        run_id: output directory for this run.
    """
    # dataloader -- the loader consumes cfg.test.augments directly, so no
    # composed-augmentation object is needed here.
    data_loader = get_loader(cfg.data.dataloader)
    data_loader = data_loader(root=cfg.data.path,
                              data_format=cfg.data.format,
                              augments=cfg.test.augments,
                              split=cfg.test.dataset)
    # exist_ok so re-running into the same run_id does not crash.
    os.makedirs(osp.join(run_id, cfg.test.dataset), exist_ok=True)

    logger.info(f'data path: {cfg.data.path}')
    logger.info(f'num of {cfg.test.dataset} set samples: {len(data_loader)}')

    loader = data.DataLoader(data_loader,
                             batch_size=cfg.test.batch_size,
                             num_workers=cfg.test.n_workers,
                             shuffle=False,
                             persistent_workers=True,
                             drop_last=False)

    # model (binary change detection -> 2 classes)
    model = get_model(cfg.model, n_classes=2)
    logger.info(f'using model: {cfg.model.arch}')
    device = f'cuda:{cfg.gpu[0]}'
    model = model.to(device)
    model = torch.nn.DataParallel(model, device_ids=cfg.gpu)

    # load model params
    if osp.isfile(cfg.test.pth):
        logger.info("Loading model from checkpoint '{}'".format(cfg.test.pth))
        checkpoint = torch.load(cfg.test.pth)
        model.load_state_dict(checkpoint["model_state"])
    else:
        raise FileNotFoundError(f'{cfg.test.pth} file not found')

    # Setup Metrics
    metrics = runningScore(2)

    # test
    model.eval()
    img_cnt = 0
    with torch.no_grad():
        for (file_a, file_b, label, mask) in loader:
            file_a = file_a.to(device)
            file_b = file_b.to(device)
            label = label.numpy()
            mask = mask.numpy()

            outputs = model(file_a, file_b)
            pred = outputs.max(1)[1].cpu().numpy()
            confusion_matrix_now = metrics.update(label, pred, mask)

            for idx, cm in enumerate(confusion_matrix_now):
                cm *= 100  # express per-class rates as percentages
                pred_filename = osp.join(
                    run_id, cfg.test.dataset,
                    f'{img_cnt}_{cm[0, 0]:.2f}_{cm[1, 1]:.2f}_pred.png')
                gt_filename = osp.join(
                    run_id, cfg.test.dataset,
                    f'{img_cnt}_{cm[0, 0]:.2f}_{cm[1, 1]:.2f}_gt.png')
                img_cnt += 1

                # Binary masks saved as 0/255 grayscale PNGs.
                if cv2.imwrite(pred_filename, (pred[idx, :, :] * 255).astype(
                        np.uint8)) and cv2.imwrite(
                            gt_filename,
                            (label[idx, :, :] * 255).astype(np.uint8)):
                    logger.info(f'wrote {pred_filename}')
                else:
                    logger.info(f'failed to write {pred_filename}')

        score, _ = metrics.get_scores()
        acc = score['Acc']
        logger.info(f'acc : {acc}\tOA:{acc.mean()}')
        micro_OA = score['Overall_Acc']
        miou = score['Mean_IoU']
        logger.info(f'overall acc: {micro_OA}, mean iou: {miou}')
Example #10
0
def _forward_with_state(cfg, args, model, images, n_classes, device):
    """Run `model` on `images`, constructing the initial recurrent state.

    Recurrent architectures need an initial hidden state `h0` (and some also
    an initial segmentation state `s0`) sized from the input; plain models
    are called directly.  States are created on `device` directly: the
    previous code called ``h0.to(device)`` without assigning the result,
    which is a no-op (`Tensor.to` is not in-place).
    """
    arch = cfg['model']['arch']

    if arch in ['reclast']:
        # Hidden state at full input resolution.
        h0 = torch.ones([images.shape[0], args.hidden_size,
                         images.shape[2], images.shape[3]],
                        dtype=torch.float32, device=device)
        return model(images, h0)

    if arch in ['recmid']:
        # Hidden state at 1/16 resolution (four halvings).
        W, H = images.shape[2], images.shape[3]
        w = int(np.floor(np.floor(np.floor(W / 2) / 2) / 2) / 2)
        h = int(np.floor(np.floor(np.floor(H / 2) / 2) / 2) / 2)
        h0 = torch.ones([images.shape[0], args.hidden_size, w, h],
                        dtype=torch.float32, device=device)
        return model(images, h0)

    if arch in ['dru', 'sru']:
        W, H = images.shape[2], images.shape[3]
        w = int(np.floor(np.floor(np.floor(W / 2) / 2) / 2) / 2)
        h = int(np.floor(np.floor(np.floor(H / 2) / 2) / 2) / 2)
        h0 = torch.ones([images.shape[0], args.hidden_size, w, h],
                        dtype=torch.float32, device=device)
        s0 = torch.ones([images.shape[0], n_classes, W, H],
                        dtype=torch.float32, device=device)
        return model(images, h0, s0)

    if arch in ['druvgg16', 'druresnet50', 'druresnet50bn',
                'druresnet50syncedbn']:
        # VGG backbones downsample 16x, ResNet50 backbones 32x.
        W, H = images.shape[2], images.shape[3]
        w, h = int(W / 2 ** 4), int(H / 2 ** 4)
        if arch in ['druresnet50', 'druresnet50bn', 'druresnet50syncedbn']:
            w, h = int(W / 2 ** 5), int(H / 2 ** 5)
        h0 = torch.ones([images.shape[0], args.hidden_size, w, h],
                        dtype=torch.float32, device=device)
        s0 = torch.zeros([images.shape[0], n_classes, W, H],
                         dtype=torch.float32, device=device)
        return model(images, h0, s0)

    return model(images)


def train(cfg, writer, logger, args):
    """Train a (possibly recurrent) segmentation model.

    Supports recurrent architectures that need explicit initial hidden /
    segmentation states (see `_forward_with_state`), periodic validation,
    best-model and final-model checkpointing into the writer's log dir.

    Args:
        cfg: dict-style config with 'data', 'model' and 'training' sections.
        writer: tensorboard SummaryWriter (also provides the checkpoint dir).
        logger: logging.Logger.
        args: namespace with `device`, `hidden_size`, `clip`, `benchmark`, etc.
    """
    # Setup seeds
    torch.manual_seed(cfg.get('seed', RNG_SEED))
    torch.cuda.manual_seed(cfg.get('seed', RNG_SEED))
    np.random.seed(cfg.get('seed', RNG_SEED))
    random.seed(cfg.get('seed', RNG_SEED))

    # Setup device
    device = torch.device(args.device)

    # Setup Augmentations (per-dataset defaults, overridable from cfg)
    if cfg['data']['dataset'] in ['cityscapes']:
        augmentations = cfg['training'].get('augmentations',
                                            {'brightness': 63. / 255.,
                                             'saturation': 0.5,
                                             'contrast': 0.8,
                                             'hflip': 0.5,
                                             'rotate': 10,
                                             'rscalecropsquare': 704,
                                             })
    elif cfg['data']['dataset'] in ['drive']:
        augmentations = cfg['training'].get('augmentations',
                                            {'brightness': 63. / 255.,
                                             'saturation': 0.5,
                                             'contrast': 0.8,
                                             'hflip': 0.5,
                                             'rotate': 180,
                                             'rscalecropsquare': 576,
                                             })
    else:
        augmentations = cfg['training'].get('augmentations',
                                            {'rotate': 10, 'hflip': 0.5})
    data_aug = get_composed_augmentations(augmentations)

    # Setup Dataloader
    data_loader = get_loader(cfg['data']['dataset'])
    data_path = cfg['data']['path']

    t_loader = data_loader(
        data_path,
        is_transform=True,
        split=cfg['data']['train_split'],
        img_size=(cfg['data']['img_rows'], cfg['data']['img_cols']),
        augmentations=data_aug)

    v_loader = data_loader(
        data_path,
        is_transform=True,
        split=cfg['data']['val_split'],
        img_size=(cfg['data']['img_rows'], cfg['data']['img_cols']),)

    n_classes = t_loader.n_classes
    trainloader = data.DataLoader(t_loader,
                                  batch_size=cfg['training']['batch_size'],
                                  num_workers=cfg['training']['n_workers'],
                                  shuffle=True)

    valloader = data.DataLoader(v_loader,
                                batch_size=cfg['training']['batch_size'],
                                num_workers=cfg['training']['n_workers'])

    # Setup Metrics
    running_metrics_val = runningScore(n_classes, cfg['data']['void_class'] > 0)

    # Setup Model
    print('trying device {}'.format(device))
    model = get_model(cfg['model'], n_classes, args)

    # Pretrained backbones keep their own initialization scheme.
    if cfg['model']['arch'] not in ['unetvgg16', 'unetvgg16gn', 'druvgg16',
                                    'unetresnet50', 'unetresnet50bn',
                                    'druresnet50', 'druresnet50bn',
                                    'druresnet50syncedbn']:
        model.apply(weights_init)
    else:
        init_model(model)

    model = torch.nn.DataParallel(model,
                                  device_ids=range(torch.cuda.device_count()))
    model = model.cuda()

    # Setup optimizer, lr_scheduler and loss function
    optimizer_cls = get_optimizer(cfg)
    optimizer_params = {k: v for k, v in cfg['training']['optimizer'].items()
                        if k != 'name'}
    # Fine-tuned backbones (paramGroup1) train with a 10x smaller lr.
    if cfg['model']['arch'] in ['unetvgg16', 'unetvgg16gn', 'druvgg16',
                                'druresnet50', 'druresnet50bn',
                                'druresnet50syncedbn']:
        optimizer = optimizer_cls([
            {'params': model.module.paramGroup1.parameters(),
             'lr': optimizer_params['lr'] / 10},
            {'params': model.module.paramGroup2.parameters()}
        ], **optimizer_params)
    else:
        optimizer = optimizer_cls(model.parameters(), **optimizer_params)
    logger.warning(f"Model parameters in total: {sum([p.numel() for p in model.parameters()])}")
    logger.warning(f"Trainable parameters in total: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
    logger.info("Using optimizer {}".format(optimizer))

    scheduler = get_scheduler(optimizer, cfg['training']['lr_schedule'])

    loss_fn = get_loss_function(cfg)
    logger.info("Using loss {}".format(loss_fn))

    # Optionally resume model/optimizer/scheduler state from a checkpoint.
    start_iter = 0
    if cfg['training']['resume'] is not None:
        if os.path.isfile(cfg['training']['resume']):
            logger.info(
                "Loading model and optimizer from checkpoint '{}'".format(cfg['training']['resume'])
            )
            checkpoint = torch.load(cfg['training']['resume'])
            model.load_state_dict(checkpoint["model_state"])
            optimizer.load_state_dict(checkpoint["optimizer_state"])
            scheduler.load_state_dict(checkpoint["scheduler_state"])
            start_iter = checkpoint["epoch"]
            logger.info(
                "Loaded checkpoint '{}' (iter {})".format(
                    cfg['training']['resume'], checkpoint["epoch"]
                )
            )
        else:
            logger.info("No checkpoint found at '{}'".format(cfg['training']['resume']))

    val_loss_meter = averageMeter()
    time_meter = averageMeter()

    best_iou = -100.0
    i = start_iter
    flag = True

    # Zero the loss weight of the void class, if one is configured.
    weight = torch.ones(n_classes)
    if cfg['data'].get('void_class'):
        if cfg['data'].get('void_class') >= 0:
            weight[cfg['data'].get('void_class')] = 0.
    weight = weight.to(device)

    logger.info("Set the prediction weights as {}".format(weight))

    while i <= cfg['training']['train_iters'] and flag:
        for (images, labels) in trainloader:
            i += 1
            start_ts = time.time()
            model.train()
            images = images.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            outputs = _forward_with_state(cfg, args, model, images,
                                          n_classes, device)

            loss = loss_fn(input=outputs, target=labels, weight=weight,
                           bkargs=args)
            loss.backward()

            # `clip_grad_norm` helps prevent exploding gradients in RNNs.
            if use_grad_clip(cfg['model']['arch']):
                nn.utils.clip_grad_norm_(model.parameters(), args.clip)

            # PyTorch >= 1.1: `optimizer.step()` must precede
            # `lr_scheduler.step()`.
            optimizer.step()
            scheduler.step()

            time_meter.update(time.time() - start_ts)

            if (i + 1) % cfg['training']['print_interval'] == 0:
                fmt_str = "Iter [{:d}/{:d}]  Loss: {:.4f}  Time/Image: {:.4f}"
                print_str = fmt_str.format(i + 1,
                                           cfg['training']['train_iters'],
                                           loss.item(),
                                           time_meter.avg / cfg['training']['batch_size'])

                logger.info(print_str)
                writer.add_scalar('loss/train_loss', loss.item(), i+1)
                time_meter.reset()

            if (i + 1) % cfg['training']['val_interval'] == 0 or \
               (i + 1) == cfg['training']['train_iters']:
                # Disable autotuning during validation (input sizes may vary).
                torch.backends.cudnn.benchmark = False
                model.eval()
                with torch.no_grad():
                    for i_val, (images_val, labels_val) in tqdm(enumerate(valloader)):
                        # In benchmark mode, validate on a handful of batches.
                        if args.benchmark:
                            if i_val > 10:
                                break
                        images_val = images_val.to(device)
                        labels_val = labels_val.to(device)
                        outputs = _forward_with_state(cfg, args, model,
                                                      images_val, n_classes,
                                                      device)
                        val_loss = loss_fn(input=outputs, target=labels_val,
                                           bkargs=args)

                        # Multi-step losses return one output per recurrence
                        # step; score the last one.
                        if cfg['training']['loss']['name'] in ['multi_step_cross_entropy']:
                            pred = outputs[-1].data.max(1)[1].cpu().numpy()
                        else:
                            pred = outputs.data.max(1)[1].cpu().numpy()
                        gt = labels_val.data.cpu().numpy()
                        # logging uses lazy %-style formatting, not print-style
                        # varargs (the old call raised a formatting error).
                        logger.debug('pred shape: %s\t ground-truth shape: %s',
                                     pred.shape, gt.shape)
                        running_metrics_val.update(gt, pred)
                        val_loss_meter.update(val_loss.item())
                torch.backends.cudnn.benchmark = True
                writer.add_scalar('loss/val_loss', val_loss_meter.avg, i+1)
                logger.info("Iter %d Loss: %.4f" % (i + 1, val_loss_meter.avg))
                score, class_iou, _ = running_metrics_val.get_scores()
                for k, v in score.items():
                    logger.info('{}: {}'.format(k, v))
                    writer.add_scalar('val_metrics/{}'.format(k), v, i+1)

                for k, v in class_iou.items():
                    logger.info('{}: {}'.format(k, v))
                    writer.add_scalar('val_metrics/cls_{}'.format(k), v, i+1)

                val_loss_meter.reset()
                running_metrics_val.reset()

                # Checkpoint whenever validation mean IoU improves.
                if score["Mean IoU : \t"] >= best_iou:
                    best_iou = score["Mean IoU : \t"]
                    state = {
                        "epoch": i + 1,
                        "model_state": model.state_dict(),
                        "optimizer_state": optimizer.state_dict(),
                        "scheduler_state": scheduler.state_dict(),
                        "best_iou": best_iou,
                    }
                    save_path = os.path.join(writer.file_writer.get_logdir(),
                                             best_model_path(cfg))
                    torch.save(state, save_path)

            if (i + 1) == cfg['training']['train_iters']:
                flag = False
                # The validation branch above always runs on the final iter,
                # so `state` is defined here.
                save_path = os.path.join(writer.file_writer.get_logdir(),
                                         "{}_{}_final_model.pkl".format(
                                             cfg['model']['arch'],
                                             cfg['data']['dataset']))
                torch.save(state, save_path)
                break
def train(args):
    """Train a DeepLab-ResNet model on Pascal VOC (train_aug split).

    Builds BGR mean-subtracted train/val pipelines, trains with SGD and a
    poly learning-rate schedule, evaluates mIoU after every epoch and saves
    the best checkpoint to the logger directory.

    Args:
        args: parsed CLI namespace; uses gpu, batch_size, l_rate, n_epoch
            and resume.
    """
    logger.auto_set_dir()
    from pytorchgo.utils.pytorch_utils import set_gpu
    set_gpu(args.gpu)


    # Setup Dataloader
    from pytorchgo.augmentation.segmentation import SubtractMeans, PIL2NP, RGB2BGR,PIL_Scale, Value255to0, ToLabel
    from torchvision.transforms import Compose, Normalize, ToTensor
    img_transform = Compose([  # notice the order!!! scale -> numpy -> BGR -> mean-subtract -> tensor
        PIL_Scale(train_img_shape, Image.BILINEAR),
        PIL2NP(),
        RGB2BGR(),
        SubtractMeans(),
        ToTensor(),
    ])

    label_transform = Compose([
        PIL_Scale(train_img_shape, Image.NEAREST),
        PIL2NP(),
        Value255to0(),  # map the 255 "ignore" label to class 0
        ToLabel()

    ])

    val_img_transform = Compose([
        PIL_Scale(train_img_shape, Image.BILINEAR),
        PIL2NP(),
        RGB2BGR(),
        SubtractMeans(),
        ToTensor(),
    ])
    val_label_transform = Compose([PIL_Scale(train_img_shape, Image.NEAREST),
                                   PIL2NP(),
                                   ToLabel(),
                                   # notice here, training, validation size difference, this is very tricky.
                                   ])

    from pytorchgo.dataloader.pascal_voc_loader import pascalVOCLoader as common_voc_loader
    train_loader = common_voc_loader( split="train_aug", epoch_scale=1, img_transform=img_transform, label_transform=label_transform)

    n_classes = train_loader.n_classes
    trainloader = data.DataLoader(train_loader, batch_size=args.batch_size, num_workers=8, shuffle=True)

    validation_loader = common_voc_loader(split='val',  img_transform=val_img_transform, label_transform=val_label_transform)
    valloader = data.DataLoader(validation_loader, batch_size=args.batch_size, num_workers=8)

    # Setup Metrics
    running_metrics = runningScore(n_classes)


    # Setup Model
    from pytorchgo.model.deeplabv1 import VGG16_LargeFoV
    from pytorchgo.model.deeplab_resnet import Res_Deeplab

    model = Res_Deeplab(NoLabels=n_classes, pretrained=True, output_all=False)

    from pytorchgo.utils.pytorch_utils import model_summary,optimizer_summary
    model_summary(model)




    def get_validation_miou(model):
        """Run one full pass over the validation loader and return mean IoU."""
        model.eval()
        for i_val, (images_val, labels_val) in tqdm(enumerate(valloader), total=len(valloader), desc="validation"):
            # Debug modes short-circuit the loop to keep iterations cheap.
            if i_val > 5 and is_debug==1: break
            if i_val > 200 and is_debug==2:break

            output = model(Variable(images_val, volatile=True).cuda())
            pred = output.data.max(1)[1].cpu().numpy()
            gt = labels_val.numpy()

            running_metrics.update(gt, pred)

        score, class_iou = running_metrics.get_scores()
        for k, v in score.items():
            logger.info("{}: {}".format(k, v))
        running_metrics.reset()
        return score['Mean IoU : \t']


    model.cuda()

    # Check if model has custom optimizer / loss.
    # BUGFIX: the warning used to fire in the *has*-optimizer branch (and it
    # read model.module.optimizer even though the model is not wrapped in
    # DataParallel). Warn only when falling back to the default SGD setting.
    if hasattr(model, 'optimizer'):
        optimizer = model.optimizer
    else:
        logger.warn("don't have customized optimizer, use default setting!")
        optimizer = torch.optim.SGD(model.optimizer_params(args.l_rate), lr=args.l_rate, momentum=0.99, weight_decay=5e-4)

    optimizer_summary(optimizer)
    if args.resume is not None:
        if os.path.isfile(args.resume):
            logger.info("Loading model and optimizer from checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            model.load_state_dict(checkpoint['model_state'])
            optimizer.load_state_dict(checkpoint['optimizer_state'])
            logger.info("Loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
        else:
            logger.info("No checkpoint found at '{}'".format(args.resume))

    best_iou = 0
    logger.info('start!!')
    for epoch in tqdm(range(args.n_epoch),total=args.n_epoch):
        model.train()
        for i, (images, labels) in tqdm(enumerate(trainloader),total=len(trainloader), desc="training epoch {}/{}".format(epoch, args.n_epoch)):
            # Debug modes short-circuit training as well.
            if i > 10 and is_debug==1: break

            if i> 200 and is_debug==2:break

            # Poly learning-rate decay over the whole run.
            cur_iter = i + epoch*len(trainloader)
            cur_lr = adjust_learning_rate(optimizer,args.l_rate,cur_iter,args.n_epoch*len(trainloader),power=0.9)


            images = Variable(images.cuda())
            labels = Variable(labels.cuda())

            optimizer.zero_grad()
            outputs = model(images) # use fusion score
            loss = CrossEntropyLoss2d_Seg(input=outputs, target=labels, class_num=n_classes)

            loss.backward()
            optimizer.step()


            if (i+1) % 100 == 0:
                # BUGFIX: loss.data[0] was removed in PyTorch >= 0.5; use .item().
                logger.info("Epoch [%d/%d] Loss: %.4f, lr: %.7f, best mIoU: %.7f" % (epoch+1, args.n_epoch, loss.item(), cur_lr, best_iou))


        # Checkpoint whenever validation mIoU matches or beats the best so far.
        cur_miou = get_validation_miou(model)
        if cur_miou >= best_iou:
            best_iou = cur_miou
            state = {'epoch': epoch+1,
                     'mIoU': best_iou,
                     'model_state': model.state_dict(),
                     'optimizer_state' : optimizer.state_dict(),}
            torch.save(state, os.path.join(logger.get_logger_dir(), "best_model.pth"))
Example #12
0
def train(args):
    """Train a segmentation model with optional visdom loss plotting.

    Builds augmented train / plain val loaders, trains for args.n_epoch
    epochs with the model's custom optimizer/loss when available (SGD /
    cross-entropy otherwise), evaluates every epoch with runningScore and
    saves the best-mIoU checkpoint to
    "{arch}_{dataset}_best_model.pkl".

    Args:
        args: parsed CLI namespace; uses dataset, img_rows/img_cols,
            img_norm, batch_size, arch, l_rate, resume, n_epoch, visdom.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Setup Augmentations
    data_aug = Compose([RandomRotate(10), RandomHorizontallyFlip()])

    # Setup Dataloader
    data_loader = get_loader(args.dataset)
    data_path = get_data_path(args.dataset)

    t_loader = data_loader(
        data_path,
        is_transform=True,
        img_size=(args.img_rows, args.img_cols),
        augmentations=data_aug,
        img_norm=args.img_norm,
    )

    v_loader = data_loader(
        data_path,
        is_transform=True,
        split="val",
        img_size=(args.img_rows, args.img_cols),
        img_norm=args.img_norm,
    )

    n_classes = t_loader.n_classes
    trainloader = data.DataLoader(t_loader,
                                  batch_size=args.batch_size,
                                  num_workers=8,
                                  shuffle=True)

    valloader = data.DataLoader(v_loader,
                                batch_size=args.batch_size,
                                num_workers=8)

    # Setup Metrics
    running_metrics = runningScore(n_classes)

    # Setup visdom for visualization
    if args.visdom:
        vis = visdom.Visdom()

        loss_window = vis.line(
            X=torch.zeros((1, )).cpu(),
            Y=torch.zeros((1)).cpu(),
            opts=dict(
                xlabel="minibatches",
                ylabel="Loss",
                title="Training Loss",
                legend=["Loss"],
            ),
        )

    # Setup Model
    model = get_model(args.arch, n_classes).to(device)

    model = torch.nn.DataParallel(model,
                                  device_ids=range(torch.cuda.device_count()))

    # Check if model has custom optimizer / loss
    if hasattr(model.module, "optimizer"):
        optimizer = model.module.optimizer
    else:
        optimizer = torch.optim.SGD(model.parameters(),
                                    lr=args.l_rate,
                                    momentum=0.99,
                                    weight_decay=5e-4)

    if hasattr(model.module, "loss"):
        print("Using custom loss")
        loss_fn = model.module.loss
    else:
        loss_fn = cross_entropy2d

    if args.resume is not None:
        if os.path.isfile(args.resume):
            print("Loading model and optimizer from checkpoint '{}'".format(
                args.resume))
            checkpoint = torch.load(args.resume)
            model.load_state_dict(checkpoint["model_state"])
            optimizer.load_state_dict(checkpoint["optimizer_state"])
            print("Loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint["epoch"]))
        else:
            print("No checkpoint found at '{}'".format(args.resume))

    best_iou = -100.0
    for epoch in range(args.n_epoch):
        model.train()
        for i, (images, labels) in enumerate(trainloader):
            images = images.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            outputs = model(images)

            loss = loss_fn(input=outputs, target=labels)

            loss.backward()
            optimizer.step()

            if args.visdom:
                # BUGFIX: loss.data[0] was removed in PyTorch >= 0.5; use
                # .item() as the logging branch below already does.
                vis.line(
                    X=torch.ones((1, 1)).cpu() * i,
                    Y=torch.Tensor([loss.item()]).unsqueeze(0).cpu(),
                    win=loss_window,
                    update="append",
                )

            if (i + 1) % 20 == 0:
                print("Epoch [%d/%d] Loss: %.4f" %
                      (epoch + 1, args.n_epoch, loss.item()))

        # Per-epoch validation; no gradients are needed, so disable autograd
        # to avoid building graphs during evaluation.
        model.eval()
        with torch.no_grad():
            for i_val, (images_val, labels_val) in tqdm(enumerate(valloader)):
                images_val = images_val.to(device)
                labels_val = labels_val.to(device)

                outputs = model(images_val)
                pred = outputs.data.max(1)[1].cpu().numpy()
                gt = labels_val.data.cpu().numpy()
                running_metrics.update(gt, pred)

        score, class_iou = running_metrics.get_scores()
        for k, v in score.items():
            print(k, v)
        running_metrics.reset()

        if score["Mean IoU : \t"] >= best_iou:
            best_iou = score["Mean IoU : \t"]
            state = {
                "epoch": epoch + 1,
                "model_state": model.state_dict(),
                "optimizer_state": optimizer.state_dict(),
            }
            torch.save(state,
                       "{}_{}_best_model.pkl".format(args.arch, args.dataset))
Example #13
0
def validate(args):
    """Evaluate a saved segmentation checkpoint on a dataset split.

    Loads the model named by the prefix of the checkpoint filename (text
    before the first underscore), runs inference over the split, optionally
    averaging predictions with a horizontally-flipped pass, and prints the
    overall scores plus per-class IoU.

    Args:
        args: parsed CLI namespace; uses model_path, dataset, split,
            img_rows/img_cols, img_norm, batch_size, eval_flip,
            measure_time.
    """
    # Model architecture is encoded as the filename prefix, e.g. "fcn8s_...".
    model_file_name = os.path.split(args.model_path)[1]
    model_name = model_file_name[:model_file_name.find('_')]

    # Setup Dataloader
    data_loader = get_loader(args.dataset)
    data_path = get_data_path(args.dataset)
    loader = data_loader(data_path,
                         split=args.split,
                         is_transform=True,
                         img_size=(args.img_rows, args.img_cols),
                         img_norm=args.img_norm)
    n_classes = loader.n_classes
    valloader = data.DataLoader(loader,
                                batch_size=args.batch_size,
                                num_workers=4)
    running_metrics = runningScore(n_classes)

    # Setup Model
    model = get_model(model_name, n_classes, version=args.dataset)
    state = convert_state_dict(torch.load(args.model_path)['model_state'])
    model.load_state_dict(state)
    model.eval()
    model.cuda()

    for i, (images, labels) in enumerate(valloader):
        start_time = timeit.default_timer()

        # NOTE(review): Variable(..., volatile=True) is the pre-0.4 PyTorch
        # inference idiom; modern code would use torch.no_grad() instead.
        images = Variable(images.cuda(), volatile=True)
        #labels = Variable(labels.cuda(), volatile=True)

        if args.eval_flip:
            # Flip-ensemble: average the logits of the image and of its
            # horizontal mirror (mirrored back) to smooth predictions.
            outputs = model(images)

            # Flip images in numpy (not support in tensor)
            outputs = outputs.data.cpu().numpy()
            flipped_images = np.copy(images.data.cpu().numpy()[:, :, :, ::-1])
            flipped_images = Variable(
                torch.from_numpy(flipped_images).float().cuda(), volatile=True)
            outputs_flipped = model(flipped_images)
            outputs_flipped = outputs_flipped.data.cpu().numpy()
            # Un-flip the mirrored logits before averaging along width.
            outputs = (outputs + outputs_flipped[:, :, :, ::-1]) / 2.0

            pred = np.argmax(outputs, axis=1)
        else:
            outputs = model(images)
            pred = outputs.data.max(1)[1].cpu().numpy()

        #gt = labels.data.cpu().numpy()
        gt = labels.numpy()

        if args.measure_time:
            # Reported fps includes the flip pass and CPU round-trips above.
            elapsed_time = timeit.default_timer() - start_time
            print('Inference time (iter {0:5d}): {1:3.5f} fps'.format(
                i + 1, pred.shape[0] / elapsed_time))
        running_metrics.update(gt, pred)

    score, class_iou = running_metrics.get_scores()

    for k, v in score.items():
        print(k, v)

    for i in range(n_classes):
        print(i, class_iou[i])
Example #14
0
def train(cfg, writer, logger):
    """Iteration-driven trainer with weighted loss and image dumps.

    Trains a (optionally Caffe-pretrained) model under DataParallel, scaling
    the per-pixel loss by two weight maps (contour and nuclei weights),
    periodically writing probability-map / target grids to the run directory
    and checkpointing the best mean-IoU model.

    Args:
        cfg: configuration dict with "data", "model" and "training" sections.
        writer: TensorBoard SummaryWriter; its logdir is also used for
            image dumps and checkpoints.
        logger: logging.Logger for progress messages.
    """
    # Setup seeds
    torch.manual_seed(cfg.get("seed", 1337))
    torch.cuda.manual_seed(cfg.get("seed", 1337))
    np.random.seed(cfg.get("seed", 1337))
    random.seed(cfg.get("seed", 1337))

    # Setup device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # # Setup Augmentations
    # augmentations = cfg["training"].get("augmentations", None)
    # data_aug = get_composed_augmentations(augmentations)

    # Setup Dataloader
    data_loader = get_loader(cfg["data"]["dataset"])
    data_path = cfg["data"]["path"]

    t_loader = data_loader(
        data_path,
        split=cfg["data"]["train_split"],
        img_size=(cfg["data"]["img_rows"], cfg["data"]["img_cols"]),
    )

    v_loader = data_loader(
        data_path,
        split=cfg["data"]["val_split"],
        img_size=(cfg["data"]["img_rows"], cfg["data"]["img_cols"]),
    )

    n_classes = t_loader.n_classes
    trainloader = data.DataLoader(
        t_loader,
        batch_size=cfg["training"]["batch_size"],
        num_workers=cfg["training"]["n_workers"],
        shuffle=True,
        drop_last=True,
    )

    # NOTE(review): the val loader also shuffles and drops the last batch;
    # the visualization code below indexes full batches, which drop_last
    # guarantees — confirm shuffling of validation data is intended.
    valloader = data.DataLoader(
        v_loader,
        batch_size=cfg["training"]["batch_size"],
        num_workers=cfg["training"]["n_workers"],
        shuffle=True,
        drop_last=True,
    )

    # Setup Metrics
    running_metrics_val = runningScore(n_classes)

    # Setup Model
    model_orig = get_model(cfg["model"], n_classes).to(device)
    if cfg["training"]["pretrain"] == True:
        # Load a pretrained model
        model_orig.load_pretrained_model(
            model_path="pretrained/pspnet101_cityscapes.caffemodel")
        logger.info("Loaded pretrained model.")
    else:
        # No pretrained model
        logger.info("No pretraining.")

    model = torch.nn.DataParallel(model_orig,
                                  device_ids=range(torch.cuda.device_count()))

    # Setup optimizer, lr_scheduler and loss function
    optimizer_cls = get_optimizer(cfg)
    optimizer_params = {
        k: v
        for k, v in cfg["training"]["optimizer"].items() if k != "name"
    }

    optimizer = optimizer_cls(model.parameters(), **optimizer_params)
    logger.info("Using optimizer {}".format(optimizer))

    scheduler = get_scheduler(optimizer, cfg["training"]["lr_schedule"])

    loss_fn = get_loss_function(cfg)
    logger.info("Using loss {}".format(loss_fn))

    start_iter = 0
    if cfg["training"]["resume"] is not None:
        if os.path.isfile(cfg["training"]["resume"]):
            logger.info(
                "Loading model and optimizer from checkpoint '{}'".format(
                    cfg["training"]["resume"]))
            checkpoint = torch.load(cfg["training"]["resume"])
            model.load_state_dict(checkpoint["model_state"])
            optimizer.load_state_dict(checkpoint["optimizer_state"])
            scheduler.load_state_dict(checkpoint["scheduler_state"])
            start_iter = checkpoint["epoch"]
            logger.info("Loaded checkpoint '{}' (iter {})".format(
                cfg["training"]["resume"], checkpoint["epoch"]))
        else:
            logger.info("No checkpoint found at '{}'".format(
                cfg["training"]["resume"]))

    val_loss_meter = averageMeter()
    time_meter = averageMeter()

    ### Visualize model training

    # helper function to show an image
    # (used in the `plot_classes_preds` function below)
    def matplotlib_imshow(data, is_image):
        if is_image:  #for images
            data = data / 4 + 0.5  # unnormalize
            npimg = data.numpy()
            plt.imshow(npimg, cmap="gray")
        else:  # for labels
            # Color-code the label map via the loader's palette.
            nplbl = data.numpy()
            plt.imshow(t_loader.decode_segmap(nplbl))

    def plot_classes_preds(data, batch_size, iter, is_image=True):
        # One row of subplots, one panel per batch element.
        fig = plt.figure(figsize=(12, 48))
        for idx in np.arange(batch_size):
            ax = fig.add_subplot(1, batch_size, idx + 1, xticks=[], yticks=[])
            matplotlib_imshow(data[idx], is_image)

            ax.set_title("Iteration Number " + str(iter))

        return fig

    best_iou = -100.0
    #best_val_loss = -100.0
    i = start_iter
    flag = True

    #Check if params trainable
    print('CHECK PARAMETER TRAINING:')
    for name, param in model.named_parameters():
        if param.requires_grad == False:
            print(name, param.data)

    while i <= cfg["training"]["train_iters"] and flag:
        for (images_orig, labels_orig, weights_orig,
             nuc_weights_orig) in trainloader:
            i += 1
            start_ts = time.time()
            # NOTE(review): scheduler.step() is called before optimizer.step();
            # PyTorch >= 1.1 expects the opposite order — confirm intended.
            scheduler.step()
            model.train()  #convert model into training mode
            images = images_orig.to(device)
            labels = labels_orig.to(device)
            weights = weights_orig.to(device)
            nuc_weights = nuc_weights_orig.to(device)

            optimizer.zero_grad()

            outputs = model(images)

            # Transform output to calculate meaningful loss
            out = outputs[0]

            # Resize output of network to same size as labels
            target_size = (labels.size()[1], labels.size()[2])
            out = torch.nn.functional.interpolate(out,
                                                  size=target_size,
                                                  mode='bicubic')

            # Multiply weights by loss output
            loss = loss_fn(input=out, target=labels)

            loss = torch.mul(loss, weights)  # add contour weights
            loss = torch.mul(loss, nuc_weights)  # add nuclei weights
            loss = loss.mean(
            )  # average over all pixels to obtain scaler for loss

            loss.backward()  # computes gradients over network
            optimizer.step()  #updates parameters

            time_meter.update(time.time() - start_ts)

            if (i + 1) % cfg["training"][
                    "print_interval"] == 0:  # frequency with which visualize training update
                fmt_str = "Iter [{:d}/{:d}]  Loss: {:.4f}  Time/Image: {:.4f}"
                print_str = fmt_str.format(
                    i + 1,
                    cfg["training"]["train_iters"],
                    loss.item(),
                    time_meter.avg / cfg["training"]["batch_size"],
                )
                #Show mini-batches during training

                # #Visualize only DAPI
                # writer.add_figure('Inputs',
                #     plot_classes_preds(images_orig.squeeze(), cfg["training"]["batch_size"], i, True),
                #             global_step=i)

                # writer.add_figure('Targets',
                #     plot_classes_preds(labels_orig, cfg["training"]["batch_size"], i, False),
                #             global_step=i)

                #Take max across classes (of probability maps) and assign class label to visualize semantic map
                #1)
                out_orig = torch.nn.functional.softmax(
                    outputs[0], dim=1).max(1).indices.cpu()
                #out_orig = out_orig.cpu().detach()
                #2)
                #out_orig = torch.argmax(outputs[0],dim=1)
                #3)
                #out_orig = outputs[0].data.max(1)[1].cpu()

                # #Visualize predictions
                # writer.add_figure('Predictions',
                #     plot_classes_preds(out_orig, cfg["training"]["batch_size"], i, False),
                #             global_step=i)

                #Save probability map
                prob_maps_folder = os.path.join(
                    writer.file_writer.get_logdir(), "probability_maps")
                os.makedirs(prob_maps_folder, exist_ok=True)

                #Downsample original images to target size for visualization
                images = torch.nn.functional.interpolate(images,
                                                         size=target_size,
                                                         mode='bicubic')

                out = torch.nn.functional.softmax(out, dim=1)

                # Channel convention here: 0 = background, 1 = contours,
                # 2 = nuclei.
                contours = (out[:, 1, :, :]).unsqueeze(dim=1)
                nuclei = (out[:, 2, :, :]).unsqueeze(dim=1)
                background = (out[:, 0, :, :]).unsqueeze(dim=1)

                #imageTensor = torch.cat((images, contours, nuclei, background),dim=0)

                # Save images side by side: nrow is how many images per row
                #save_image(make_grid(imageTensor, nrow=2), os.path.join(prob_maps_folder,"Prob_maps_%d.tif" % i))

                # Targets visualization below
                nplbl = labels_orig.numpy()
                targets = []  #each element is RGB target label in batch
                for bs in np.arange(cfg["training"]["batch_size"]):
                    target_bs = t_loader.decode_segmap(nplbl[bs])
                    target_bs = 255 * target_bs
                    target_bs = target_bs.astype('uint8')
                    target_bs = torch.from_numpy(target_bs)
                    target_bs = target_bs.unsqueeze(dim=0)
                    targets.append(target_bs)  #uint8 labels, shape (N,N,3)

                # Stack per-sample labels into one (Batch, N, N, 3) tensor.
                target = reduce(lambda x, y: torch.cat((x, y), dim=0), targets)
                target = target.permute(0, 3, 1,
                                        2)  # size=(Batch, Channels, N, N)
                target = target.type(torch.FloatTensor)

                save_image(
                    make_grid(target, nrow=cfg["training"]["batch_size"]),
                    os.path.join(prob_maps_folder, "Target_labels_%d.tif" % i))

                # Weights visualization below:
                #wgts = weights_orig.type(torch.FloatTensor)
                #save_image(make_grid(wgts, nrow=2), os.path.join(prob_maps_folder,"Weights_%d.tif" % i))

                # Probability maps visualization below
                t1 = []
                t2 = []
                t3 = []
                t4 = []

                # Normalize individual images in batch
                for bs in np.arange(cfg["training"]["batch_size"]):
                    t1.append((images[bs][0] - images[bs][0].min()) /
                              (images[bs][0].max() - images[bs][0].min()))
                    t2.append(contours[bs])
                    t3.append(nuclei[bs])
                    t4.append(background[bs])

                t1 = [torch.unsqueeze(elem, dim=0)
                      for elem in t1]  #expand dim=0 for images in batch
                # Convert normalized batch to Tensor
                tensor1 = torch.cat((t1), dim=0)
                tensor2 = torch.cat((t2), dim=0)
                tensor3 = torch.cat((t3), dim=0)
                tensor4 = torch.cat((t4), dim=0)

                tTensor = torch.cat((tensor1, tensor2, tensor3, tensor4),
                                    dim=0)
                tTensor = tTensor.unsqueeze(dim=1)

                save_image(make_grid(tTensor,
                                     nrow=cfg["training"]["batch_size"]),
                           os.path.join(prob_maps_folder,
                                        "Prob_maps_%d.tif" % i),
                           normalize=False)

                logger.info(print_str)
                writer.add_scalar(
                    "loss/train_loss", loss.item(),
                    i + 1)  # adds value to history (title, loss, iter index)
                time_meter.reset()

            if (i + 1) % cfg["training"]["val_interval"] == 0 or (
                    i + 1
            ) == cfg["training"][
                    "train_iters"]:  # evaluate model on validation set at these intervals
                model.eval()  # evaluate mode for model
                with torch.no_grad():
                    for i_val, (images_val, labels_val, weights_val,
                                nuc_weights_val) in tqdm(enumerate(valloader)):
                        images_val = images_val.to(device)
                        labels_val = labels_val.to(device)
                        weights_val = weights_val.to(device)
                        nuc_weights_val = nuc_weights_val.to(device)

                        outputs_val = model(images_val)

                        # Resize output of network to same size as labels
                        target_val_size = (labels_val.size()[1],
                                           labels_val.size()[2])
                        outputs_val = torch.nn.functional.interpolate(
                            outputs_val, size=target_val_size, mode='bicubic')

                        # Multiply weights by loss output
                        # (mirrors the weighted training loss above)
                        val_loss = loss_fn(input=outputs_val,
                                           target=labels_val)

                        val_loss = torch.mul(val_loss, weights_val)
                        val_loss = torch.mul(val_loss, nuc_weights_val)
                        val_loss = val_loss.mean(
                        )  # average over all pixels to obtain scaler for loss

                        outputs_val = torch.nn.functional.softmax(outputs_val,
                                                                  dim=1)

                        #Save probability map
                        val_prob_maps_folder = os.path.join(
                            writer.file_writer.get_logdir(),
                            "val_probability_maps")
                        os.makedirs(val_prob_maps_folder, exist_ok=True)

                        #Downsample original images to target size for visualization
                        images_val = torch.nn.functional.interpolate(
                            images_val, size=target_val_size, mode='bicubic')

                        contours_val = (outputs_val[:,
                                                    1, :, :]).unsqueeze(dim=1)
                        nuclei_val = (outputs_val[:, 2, :, :]).unsqueeze(dim=1)
                        background_val = (outputs_val[:, 0, :, :]).unsqueeze(
                            dim=1)

                        # Targets visualization below
                        nplbl_val = labels_val.cpu().numpy()
                        targets_val = [
                        ]  #each element is RGB target label in batch
                        for bs in np.arange(cfg["training"]["batch_size"]):
                            target_bs = v_loader.decode_segmap(nplbl_val[bs])
                            target_bs = 255 * target_bs
                            target_bs = target_bs.astype('uint8')
                            target_bs = torch.from_numpy(target_bs)
                            target_bs = target_bs.unsqueeze(dim=0)
                            targets_val.append(
                                target_bs)  #uint8 labels, shape (N,N,3)

                        target_val = reduce(
                            lambda x, y: torch.cat((x, y), dim=0), targets_val)
                        target_val = target_val.permute(
                            0, 3, 1, 2)  # size=(Batch, Channels, N, N)
                        target_val = target_val.type(torch.FloatTensor)

                        save_image(
                            make_grid(target_val,
                                      nrow=cfg["training"]["batch_size"]),
                            os.path.join(
                                val_prob_maps_folder,
                                "Target_labels_%d_val_%d.tif" % (i, i_val)))

                        # Weights visualization below:
                        #wgts_val = weights_val.type(torch.FloatTensor)
                        #save_image(make_grid(wgts_val, nrow=2), os.path.join(val_prob_maps_folder,"Weights_val_%d.tif" % i_val))

                        # Probability maps visualization below
                        t1_val = []
                        t2_val = []
                        t3_val = []
                        t4_val = []
                        # Normalize individual images in batch
                        for bs in np.arange(cfg["training"]["batch_size"]):
                            t1_val.append(
                                (images_val[bs][0] - images_val[bs][0].min()) /
                                (images_val[bs][0].max() -
                                 images_val[bs][0].min()))
                            t2_val.append(contours_val[bs])
                            t3_val.append(nuclei_val[bs])
                            t4_val.append(background_val[bs])

                        t1_val = [
                            torch.unsqueeze(elem, dim=0) for elem in t1_val
                        ]  #expand dim=0 for images_val in batch
                        # Convert normalized batch to Tensor
                        tensor1_val = torch.cat((t1_val), dim=0)
                        tensor2_val = torch.cat((t2_val), dim=0)
                        tensor3_val = torch.cat((t3_val), dim=0)
                        tensor4_val = torch.cat((t4_val), dim=0)

                        tTensor_val = torch.cat((tensor1_val, tensor2_val,
                                                 tensor3_val, tensor4_val),
                                                dim=0)
                        tTensor_val = tTensor_val.unsqueeze(dim=1)

                        save_image(make_grid(
                            tTensor_val, nrow=cfg["training"]["batch_size"]),
                                   os.path.join(
                                       val_prob_maps_folder,
                                       "Prob_maps_%d_val_%d.tif" % (i, i_val)),
                                   normalize=False)

                        pred = outputs_val.data.max(1)[1].cpu().numpy()
                        gt = labels_val.data.cpu().numpy()

                        running_metrics_val.update(gt, pred)
                        val_loss_meter.update(val_loss.item())

                writer.add_scalar("loss/val_loss", val_loss_meter.avg, i + 1)
                logger.info("Iter %d Loss: %.4f" % (i + 1, val_loss_meter.avg))

                ### Save best validation loss model
                # if val_loss_meter.avg >= best_val_loss:
                #     best_val_loss = val_loss_meter.avg
                #     state = {
                #         "epoch": i + 1,
                #         "model_state": model.state_dict(),
                #         "optimizer_state": optimizer.state_dict(),
                #         "scheduler_state": scheduler.state_dict(),
                #         "best_val_loss": best_val_loss,
                #     }
                #     save_path = os.path.join(
                #         writer.file_writer.get_logdir(),
                #         "{}_{}_best_model.pkl".format(cfg["model"]["arch"], cfg["data"]["dataset"]),
                #     )
                #     torch.save(state, save_path)
                ###

                score, class_iou = running_metrics_val.get_scores(
                )  # best model chosen via IoU
                for k, v in score.items():
                    print(k, v)
                    logger.info("{}: {}".format(k, v))
                    writer.add_scalar("val_metrics/{}".format(k), v, i + 1)

                for k, v in class_iou.items():
                    logger.info("{}: {}".format(k, v))
                    writer.add_scalar("val_metrics/cls_{}".format(k), v, i + 1)

                val_loss_meter.reset()
                running_metrics_val.reset()

                ### Save best mean IoU model
                if score["Mean IoU : \t"] >= best_iou:
                    best_iou = score["Mean IoU : \t"]
                    state = {
                        "epoch": i + 1,
                        "model_state": model.state_dict(),
                        "optimizer_state": optimizer.state_dict(),
                        "scheduler_state": scheduler.state_dict(),
                        "best_iou": best_iou,
                    }
                    save_path = os.path.join(
                        writer.file_writer.get_logdir(),
                        "{}_{}_best_model.pkl".format(cfg["model"]["arch"],
                                                      cfg["data"]["dataset"]),
                    )
                    torch.save(state, save_path)
                ###

            if (i + 1) == cfg["training"]["train_iters"]:
                flag = False
                break
Example #15
0
        print("==> Loading Half-training model...'{}'".format(march))
        checkpoints = torch.load(resume_root)
        start_epoch = checkpoints['epoch']
        state = checkpoints['model_state']
        model.load_state_dict(state)
        optimizer.load_state_dict(checkpoints['optimizer_state'])
        print("==> Checkpoint '{}' Loaded, start from epoch {}".format(
            march, checkpoints["epoch"]))

if hasattr(model.module, 'loss'):
    print('Using custom loss')
    loss_fn = model.module.loss
else:
    loss_fn = cross_entropy2d
# Setup Metrics
running_metrics = runningScore(traindata.n_classes)

best_iou = -100.0
for epoch in range(start_epoch, mn_epoch):
    if (epoch % 100 == 0):
        ml_rate = ml_rate * 0.1
        optimizer = torch.optim.SGD(model.parameters(),
                                    lr=ml_rate,
                                    momentum=0.9,
                                    weight_decay=5e-4)
    storeLoss = []
    model.train()
    # print(len(trainloader))
    for i, (images, labels) in enumerate(trainloader):
        images = Variable(images.cuda())
        labels = Variable(labels.cuda())
def train(args):
    """Train a dual-head segmentation network from command-line args.

    The model returns two predictions per image (a "first" and a "second"
    head); the ground truth carries one label map per head in channels 0
    and 1 of ``labels`` (the local names ``loss_row``/``loss_col`` suggest
    row/column tasks -- confirm against the dataset loader).  After each
    epoch the model is validated, the LR scheduler steps on the average
    validation loss, and a checkpoint is written whenever BOTH heads'
    mean IoU improve simultaneously.

    Side effects: writes ``<arch>_<dataset>_best_model.pkl`` in the CWD
    and optionally streams the training loss to a visdom server.
    Requires CUDA.
    """
    # Setup Augmentations
    # data_aug= Compose([RandomRotate(10), RandomHorizontallyFlip()])
    data_aug = Compose([RandomHorizontallyFlip()])

    # Setup Dataloader
    data_loader = get_loader(args.dataset)
    data_path = get_data_path(args.dataset)
    t_loader = data_loader(data_path,
                           is_transform=True,
                           img_size=(args.img_rows, args.img_cols),
                           augmentations=data_aug,
                           img_norm=args.img_norm)
    v_loader = data_loader(data_path,
                           is_transform=True,
                           split='val',
                           img_size=(args.img_rows, args.img_cols),
                           img_norm=args.img_norm)
    # t_loader = data_loader(data_path, is_transform=True, img_size=None, augmentations=data_aug, img_norm=args.img_norm)
    # v_loader = data_loader(data_path, is_transform=True, split='val', img_size=None, img_norm=args.img_norm)

    n_classes = t_loader.n_classes
    trainloader = data.DataLoader(t_loader,
                                  batch_size=args.batch_size,
                                  num_workers=8,
                                  shuffle=True)
    valloader = data.DataLoader(v_loader,
                                batch_size=args.batch_size,
                                num_workers=8)

    # Setup Metrics -- one independent tracker per output head.
    running_metrics_first_head = runningScore(n_classes)
    running_metrics_second_head = runningScore(n_classes)

    # Setup visdom for visualization
    if args.visdom:
        vis = visdom.Visdom()

        loss_window = vis.line(X=torch.zeros((1, )).cpu(),
                               Y=torch.zeros((1)).cpu(),
                               opts=dict(xlabel='minibatches',
                                         ylabel='Loss',
                                         title='Training Loss',
                                         legend=['Loss']))

    # Setup Model
    model = get_model(args.arch, n_classes)

    model = torch.nn.DataParallel(model,
                                  device_ids=range(torch.cuda.device_count()))
    model.cuda()

    # Check if model has custom optimizer / loss
    if hasattr(model.module, 'optimizer'):
        optimizer = model.module.optimizer
    else:
        optimizer = torch.optim.SGD(model.parameters(),
                                    lr=args.l_rate,
                                    momentum=0.99,
                                    weight_decay=5e-4)
    # Shrink the LR when the average validation loss plateaus.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                           'min',
                                                           patience=20)

    if hasattr(model.module, 'loss'):
        print('Using custom loss')
        loss_fn = model.module.loss
    else:
        loss_fn = cross_entropy2d

    if args.resume is not None:
        if os.path.isfile(args.resume):
            print("Loading model and optimizer from checkpoint '{}'".format(
                args.resume))
            checkpoint = torch.load(args.resume)
            model.load_state_dict(checkpoint['model_state'])
            optimizer.load_state_dict(checkpoint['optimizer_state'])
            print("Loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("No checkpoint found at '{}'".format(args.resume))

    best_iou_first_head = -100.0
    best_iou_second_head = -100.0
    class_weights = torch.ones(n_classes).cuda()
    class_weights[
        -1] *= 5.0  # Distinguishing the border is the most important task
    print("Class weights:", class_weights)
    for epoch in range(args.n_epoch):
        model.train()
        for i, (images, labels) in enumerate(trainloader):
            # torch.autograd.Variable is deprecated since PyTorch 0.4
            # (this file already uses .item()/no_grad()); tensors are
            # autograd-aware directly.
            images = images.cuda()
            labels = labels.cuda()

            optimizer.zero_grad()
            output_first_head, output_second_head = model(images)

            # One loss per head, each trained against its own label channel.
            loss_row = loss_fn(input=output_first_head,
                               target=labels[:, 0, :, :],
                               weight=class_weights)
            loss_col = loss_fn(input=output_second_head,
                               target=labels[:, 1, :, :],
                               weight=class_weights)
            loss = loss_row + loss_col

            loss.backward()
            optimizer.step()

            if args.visdom:
                vis.line(X=torch.ones((1, 1)).cpu() * i,
                         Y=torch.Tensor([loss.item()]).unsqueeze(0).cpu(),
                         win=loss_window,
                         update='append')

            if (i + 1) % 20 == 0:
                print("Epoch [%d/%d] Loss: %.4f" %
                      (epoch + 1, args.n_epoch, loss.item()))

        model.eval()
        avg_loss_val = 0.0
        num_iter = 0
        for i_val, (images_val, labels_val) in tqdm(enumerate(valloader)):
            with torch.no_grad():
                images_val = images_val.cuda()
                labels_val = labels_val.cuda()

                # outputs = model(images_val)
                output_first_head, output_second_head = model(images_val)
                pred_first_head = output_first_head.data.max(
                    1)[1].cpu().numpy()
                pred_second_head = output_second_head.data.max(
                    1)[1].cpu().numpy()
                gt = labels_val.data.cpu().numpy()
                gt_first_head = gt[:, 0, :, :]
                gt_second_head = gt[:, 1, :, :]
                running_metrics_first_head.update(gt_first_head,
                                                  pred_first_head)
                running_metrics_second_head.update(gt_second_head,
                                                   pred_second_head)

                # Compute loss on the validation set
                loss_row = loss_fn(input=output_first_head,
                                   target=labels_val[:, 0, :, :],
                                   weight=class_weights)
                loss_col = loss_fn(input=output_second_head,
                                   target=labels_val[:, 1, :, :],
                                   weight=class_weights)
                loss_val = loss_row + loss_col

                avg_loss_val += loss_val.item()
                num_iter += 1

        # Update the learning rate
        avg_loss_val = avg_loss_val / num_iter
        print("Average validation loss: %.4f" % (avg_loss_val))
        scheduler.step(avg_loss_val)

        score_first_head, class_iou_first_head = running_metrics_first_head.get_scores(
        )
        score_second_head, class_iou_second_head = running_metrics_second_head.get_scores(
        )
        print("First head:")
        for k, v in score_first_head.items():
            print(k, v)
        print("Second head:")
        for k, v in score_second_head.items():
            print(k, v)
        running_metrics_first_head.reset()
        running_metrics_second_head.reset()

        # Save only when BOTH heads beat their previous best mean IoU
        # (a joint criterion; a single head improving is not enough).
        if score_first_head[
                'Mean IoU : \t'] >= best_iou_first_head and score_second_head[
                    'Mean IoU : \t'] >= best_iou_second_head:
            best_iou_first_head = score_first_head['Mean IoU : \t']
            best_iou_second_head = score_second_head['Mean IoU : \t']
            print(
                "Saving best model with %.5f first head mean IoU and %.5f second head mean IoU"
                % (best_iou_first_head, best_iou_second_head))
            state = {
                'epoch': epoch + 1,
                'model_state': model.state_dict(),
                'optimizer_state': optimizer.state_dict(),
            }
            torch.save(state,
                       "{}_{}_best_model.pkl".format(args.arch, args.dataset))
Example #17
0
def train(param_file):
    with open(param_file) as json_params:
        params = json.load(json_params)
    exp_identifier = '|'.join('{}={}'.format(key, val)
                              for (key, val) in params.items())
    params['exp_id'] = exp_identifier

    # Setup Augmentations
    data_aug = Compose([RandomRotate(10), RandomHorizontallyFlip()])

    # Setup Dataloader
    data_loader = get_loader(params['dataset'])
    data_path = get_data_path(params['dataset'])
    t_loader = data_loader(data_path,
                           is_transform=True,
                           split=['train'],
                           img_size=(params['img_rows'], params['img_cols']),
                           augmentations=data_aug)
    v_loader = data_loader(data_path,
                           is_transform=True,
                           split=['val'],
                           img_size=(params['img_rows'], params['img_cols']))

    n_classes = t_loader.n_classes
    trainloader = data.DataLoader(t_loader,
                                  batch_size=params['batch_size'],
                                  num_workers=8,
                                  shuffle=True)
    valloader = data.DataLoader(v_loader,
                                batch_size=params['batch_size'],
                                num_workers=8)

    # Setup Metrics
    running_metrics = runningScore(n_classes)

    writer = SummaryWriter(log_dir='runs/{}_{}'.format(
        params['exp_id'],
        datetime.datetime.now().strftime("%I:%M%p on %B %d, %Y")))

    # Setup Model
    model = get_model(params['arch'], n_classes, params['tasks'])

    model = torch.nn.DataParallel(model,
                                  device_ids=range(torch.cuda.device_count()))
    model.cuda()

    print(params)
    if 'RMSprop' in params['optimizer']:
        optimizer = torch.optim.RMSprop(model.parameters(), lr=params['lr'])
    elif 'Adam' in params['optimizer']:
        optimizer = torch.optim.Adam(model.parameters(), lr=params['lr'])
    elif 'SGD' in params['optimizer']:
        optimizer = torch.optim.SGD(model.parameters(),
                                    lr=params['lr'],
                                    momentum=0.9)

    loss_fn = {}
    if hasattr(model.module, 'loss'):
        print('Using custom loss')
        loss_fn = model.module.loss
    else:
        loss_fn['S'] = cross_entropy2d
        loss_fn['I'] = l1_loss_instance
    """
    if args.resume is not None:
        if os.path.isfile(args.resume):
            print("Loading model and optimizer from checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            model.load_state_dict(checkpoint['model_state'])
            optimizer.load_state_dict(checkpoint['optimizer_state'])
            print("Loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
        else:
            print("No checkpoint found at '{}'".format(args.resume))
    """

    tasks = []
    if 'S' in params['tasks']:
        tasks.append('S')
    if 'I' in params['tasks']:
        tasks.append('I')

    best_iou = -100.0
    best_loss = 1e8
    n_iter = 0
    for epoch in range(100):
        model.train()
        if (epoch + 1) % 30 == 0:
            # Every 10 epoch, half the LR
            for param_group in optimizer.param_groups:
                param_group['lr'] *= 0.5
            print('Half the learning rate{}'.format(n_iter))
        target = {}
        for i, (images, labels, instances, imname) in enumerate(trainloader):
            n_iter += 1
            images = Variable(images.cuda())
            labels = Variable(labels.cuda())
            instances = Variable(instances.cuda())

            target['S'] = labels
            target['I'] = instances

            optimizer.zero_grad()
            outputs = model(images)

            for task_id, task in enumerate(tasks):
                if task_id > 0:
                    loss = loss + loss_fn[task](input=outputs[task_id],
                                                target=target[task])
                else:
                    loss = loss_fn[task](input=outputs[0], target=target[task])

            if loss is None:
                print('WARN: image with no instance {}'.format(imname))
                continue

            loss.backward()
            optimizer.step()
            writer.add_scalar('training_loss', loss.data[0], n_iter)
        model.eval()
        tot_loss = 0.0
        summed = 0.0
        target_val = {}
        val_losses = {}
        for task in tasks:
            val_losses[task] = 0.0
        for i_val, (images_val, labels_val, instances_val,
                    imname_val) in enumerate(valloader):
            images_val = Variable(images_val.cuda(), volatile=True)
            labels_val = Variable(labels_val.cuda(), volatile=True)
            instances_val = Variable(instances_val.cuda(), volatile=True)

            target_val['S'] = labels_val
            target_val['I'] = instances_val

            outputs = model(images_val)

            for task_id, task in enumerate(tasks):
                if task_id > 0:
                    ll = loss_fn[task](input=outputs[task_id],
                                       target=target_val[task])
                    val_losses[task] += ll.data[0]
                else:
                    ll = loss_fn[task](input=outputs[0],
                                       target=target_val[task])
                    val_losses[task] += ll.data[0]
                if 'S' in task:
                    pred_cpu = outputs[task_id].data.max(1)[1].cpu().numpy()
                    gt_cpu = labels_val.data.cpu().numpy()
                    running_metrics.update(gt_cpu, pred_cpu)
            summed += 1
            #running_metrics.update_instance(instances_gt, outputs.data.cpu().numpy())
        for task in tasks:
            writer.add_scalar('validation_loss_{}'.format(task),
                              val_losses[task] / summed, n_iter)
        writer.add_scalar('validation_loss_total',
                          sum(val_losses.values()) / summed, n_iter)

        score, class_iou = running_metrics.get_scores()
        for k, v in score.items():
            writer.add_scalar('score_{}'.format(k), v, n_iter)
        running_metrics.reset()
        """
Example #18
0
def train(args):
    """Fine-tune / train a PSPNet segmentation model.

    When ``do_finetuning`` is set, loads pretrained PSPNet weights
    (ADE20K or Pascal VOC checkpoints, see paths below) minus the final
    classification layers, trains with ``multi_scale_cross_entropy2d``,
    validates every epoch, keeps the best model by validation mean IoU
    ("{model}_{dataset}_best_model.pth") and dumps per-epoch statistics
    to ``train_statistics.json``.  Optionally plots losses, accuracies
    and example predictions to a visdom server.
    """
    do_finetuning = True
    data_parallel = False # whether split a batch on multiple GPUs to accelerate
    if data_parallel:
        print('Using data parallel.')
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    else:
        device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu") # use GPU1 if train on one GPU
    # Setup Augmentations
    data_aug= Compose([RandomSized(args.img_rows),
                        RandomHorizontallyFlip(),
                        RandomSizedCrop(args.img_rows)])

    # Setup Dataloader
    data_loader = get_loader(args.dataset)
    data_path = get_data_path(args.dataset)
    t_loader = data_loader(data_path, is_transform=True, img_size=(args.img_rows, args.img_cols), augmentations=data_aug, img_norm=False)
    v_loader = data_loader(data_path, is_transform=True, split='val', img_size=(args.img_rows, args.img_cols), img_norm=False)

    n_classes = t_loader.n_classes
    trainloader = data.DataLoader(t_loader, batch_size=args.batch_size, num_workers=8, shuffle=True, drop_last = True)
    valloader = data.DataLoader(v_loader, batch_size=args.batch_size, num_workers=8, drop_last = True)

    # Setup Metrics
    running_metrics = runningScore(n_classes)

    # Setup visdom for visualization
    if args.visdom:
        vis = visdom.Visdom()
        # window for training loss
        loss_window = vis.line(X=torch.ones((1)),
                           Y=torch.zeros((1)),
                           opts=dict(xlabel='epoch',
                                     ylabel='Loss',
                                     title='Training Loss',
                                     legend=['Loss'],
                                     width = 400,
                                     height = 400))
        # window for example training image
        image_train_window = vis.images(torch.zeros((3, 3, args.img_rows, args.img_cols)),
                           opts=dict(nrow = 3,
                                     caption = 'input-prediction-groundtruth',
                                     title = 'Training example image'))
        # window for train and validation accuracy
        acc_window = vis.line(X=torch.ones((1,2)),
                           Y=torch.zeros((1,2)),
                           opts=dict(xlabel='epoch',
                                     ylabel='mean IoU',
                                     title='Mean IoU',
                                     legend=['train','validation'],
                                     width = 400,
                                     height = 400))

        # window for example validation image
        image_val_window = vis.images(torch.zeros((3, 3, args.img_rows, args.img_cols)),
                           opts=dict(nrow = 3,
                                     caption = 'input-prediction-groundtruth',
                                     title = 'Validation example image'))
    # Setup Model
    model_name = 'pspnet'
    model = get_model(model_name, n_classes, version = args.dataset+'_res50')
    #model = get_model(model_name, n_classes, version = args.dataset+'_res101')
    if do_finetuning:
        # pspnet pretrained on ade20k
        pretrained_model_path = '/home/interns/xuan/models/pspnet_50_ade20k.pth'
        # pspnet pretrained on pascal VOC
        #pretrained_model_path = '/home/interns/xuan/models/pspnet_101_pascalvoc.pth'
        pretrained_state = convert_state_dict(torch.load(pretrained_model_path)['model_state']) # remove 'module' in keys
        # Load parameters except for last classification layer to fine tuning
        print('Setting up for fine tuning')
        # 1. filter out unnecessary keys
        pretrained_state = {k: v for k, v in pretrained_state.items() if k not in ['classification.weight', 'classification.bias',
                                                                                    'aux_cls.weight', 'aux_cls.bias']}
        # 2. overwrite entries in the existing state dict
        model_state_dict = model.state_dict()
        model_state_dict.update(pretrained_state)
        # 3. load the new state dict
        model.load_state_dict(model_state_dict)

    # load checkpoint to continue training
    if args.resume is not None:
        if os.path.isfile(args.resume):
            print("Loading model from checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            model.load_state_dict(checkpoint)
            print("Loaded checkpoint '{}'"
                  .format(args.resume))
        else:
            print("No checkpoint found at '{}'".format(args.resume))
    '''
    # freeze all parameters except for final classification if doing fine tuning
    if do_finetuning:
        for param in model.parameters():
            param.requires_grad = False
        for param in model.classification.parameters():
            param.requires_grad = True
        for param in model.cbr_final.parameters():
            param.requires_grad = True
    '''

    # Set up optimizer
    opt_dict = {'name': 'SGD', 'learning_rate': args.l_rate, 'momentum': 0.9, 'weight_decay': 1e-3}
    # BUG FIX: torch.optim.SGD's third positional parameter is `dampening`,
    # not `weight_decay` -- the previous positional call silently set
    # dampening=1e-3 and applied no weight decay.  Use keywords instead.
    optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad, model.parameters()),
                                lr=opt_dict['learning_rate'],
                                momentum=opt_dict['momentum'],
                                weight_decay=opt_dict['weight_decay'])
    #scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda = lambda epoch: 0.9**epoch)

    # train on multiple GPU
    if data_parallel:
        model = torch.nn.DataParallel(model, device_ids=range(torch.cuda.device_count()))
    # move parameters to GPU
    model = model.to(device)

    best_iou = -100.0
    statistics = {}
    best_model_stat = {}
    # print params
    print('optimizer', opt_dict)
    print('batch size', args.batch_size)
    since = time() # start time
    # for every epoch.train then validate. Keep the best model in validation
    for epoch in range(1, args.n_epoch + 1):
        print('=>Epoch %d / %d' % (epoch, args.n_epoch))
        # -------- train --------
        model.train()
        # Freeze BatchNorm2d layers because we have small batch size
        #print('Freeze BatchNorm2d layers')
        #model.apply(freeze_batchnorm2d)
        print('  =>Training')
        loss_epoch = 0. # average loss in an epoch
        #scheduler.step()
        for i, (images, labels) in tqdm(enumerate(trainloader), total = len(trainloader)):
            images = images.to(device)
            labels = labels.to(device)
            # zero the parameter gradients
            optimizer.zero_grad()
            outputs = model(images)
            # if use aux loss, loss_fn = multi_scale_cross_entropy2d
            # if ignore aux loss, loss_fn = cross_entropy2d
            loss = multi_scale_cross_entropy2d(input=outputs, target=labels, device = device)
            loss.backward()
            optimizer.step()
            # update average loss
            loss_epoch += loss.item()
            # update train accuracy (mIoU)
            pred = outputs[0].data.max(1)[1].cpu().numpy()
            gt = labels.data.cpu().numpy()
            running_metrics.update(gt, pred)

        loss_epoch /= len(trainloader)
        print('Average training loss: %f' % loss_epoch)
        # draw train loss every epoch
        if args.visdom:
            vis.line(
                X=torch.Tensor([epoch]),
                Y=torch.Tensor([loss_epoch]).unsqueeze(0),
                win=loss_window,
                update='append')
        # get train accuracy for this epoch
        scores_train, class_iou_train = running_metrics.get_scores()
        running_metrics.reset()
        print('Training mean IoU: %f' % scores_train['Mean IoU'])

        # -------- validate --------
        model.eval()
        print('  =>Validation')
        with torch.no_grad():
            for i_val, (images_val, labels_val) in tqdm(enumerate(valloader), total = len(valloader)):
                images_val = images_val.to(device)
                labels_val = labels_val.to(device)
                outputs = model(images_val)
                pred = outputs.data.max(1)[1].cpu().numpy()
                gt = labels_val.data.cpu().numpy()
                running_metrics.update(gt, pred)

        scores_val, class_iou_val = running_metrics.get_scores()
        running_metrics.reset()
        for k, v in scores_val.items():
            print(k+': %f' % v)

        # --------save best model --------
        if scores_val['Mean IoU'] >= best_iou:
            best_iou = scores_val['Mean IoU']
            best_model = model.state_dict()
            if data_parallel:
                best_model = convert_state_dict(best_model) # remove 'module' in keys to be competible with single GPU
            torch.save(best_model, "{}_{}_best_model.pth".format(model_name, args.dataset))
            print('Best model updated!')
            print(class_iou_val)
            best_model_stat = {'epoch': epoch, 'scores_val': scores_val, 'class_iou_val': class_iou_val}

        # -------- draw --------
        if args.visdom:
            # draw accuracy for training and validation
            vis.line(
                X=torch.Tensor([epoch]),
                Y=torch.Tensor([scores_train['Mean IoU'], scores_val['Mean IoU']]).unsqueeze(0),
                win=acc_window,
                update='append')
            # show example train image
            with torch.no_grad():
                (image_train, label_train) = t_loader[0]
                gt = t_loader.decode_segmap(label_train.numpy())
                image_train = image_train.unsqueeze(0)
                image_train = image_train.to(device)
                label_train = label_train.to(device)
                outputs = model(image_train)
                pred = np.squeeze(outputs.data.max(1)[1].cpu().numpy(), axis=0)
                decoded = t_loader.decode_segmap(pred)
                vis.images([image_train.data.cpu().squeeze(0), decoded.transpose(2,0,1)*255.0, gt.transpose(2,0,1)*255.0], win = image_train_window)
            # show example validation image (was tab-indented; normalized to spaces)
            with torch.no_grad():
                (image_val, label_val) = v_loader[0]
                gt = v_loader.decode_segmap(label_val.numpy())
                image_val = image_val.unsqueeze(0)
                image_val = image_val.to(device)
                label_val = label_val.to(device)
                outputs = model(image_val)
                pred = np.squeeze(outputs.data.max(1)[1].cpu().numpy(), axis=0)
                decoded = v_loader.decode_segmap(pred)
                vis.images([image_val.data.cpu().squeeze(0), decoded.transpose(2,0,1)*255.0, gt.transpose(2,0,1)*255.0], win = image_val_window)

        # -------- save training statistics --------
        statistics['epoch %d' % epoch] = {'train_loss': loss_epoch, 'scores_train': scores_train, 'scores_val': scores_val}
        with open('train_statistics.json', 'w') as outfile:
            json.dump({
                    'optimizer': opt_dict,
                    'batch_size': args.batch_size,
                    'data_parallel': data_parallel,
                    'Training hours': (time() - since)/3600.0,
                    'best_model': best_model_stat,
                    'statistics': statistics
                    }, outfile)
Example #19
0
def train(cfg, writer, logger, run_id):

    # Setup seeds
    torch.manual_seed(cfg.get('seed', 1337))
    torch.cuda.manual_seed(cfg.get('seed', 1337))
    np.random.seed(cfg.get('seed', 1337))
    random.seed(cfg.get('seed', 1337))

    # torch.backends.cudnn.deterministic = True
    # torch.backends.cudnn.benchmark = False

    torch.backends.cudnn.benchmark = True

    os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"
    # Setup device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Setup Augmentations
    augmentations = cfg['training'].get('augmentations', None)
    data_aug = get_composed_augmentations(augmentations)

    # Setup Dataloader
    data_loader = get_loader(cfg['data']['dataset'])
    data_path = cfg['data']['path']

    logger.info("Using dataset: {}".format(data_path))

    t_loader = data_loader(data_path,
                           is_transform=True,
                           split=cfg['data']['train_split'],
                           img_size=(cfg['data']['img_rows'],
                                     cfg['data']['img_cols']),
                           augmentations=data_aug)

    v_loader = data_loader(
        data_path,
        is_transform=True,
        split=cfg['data']['val_split'],
        img_size=(cfg['data']['img_rows'], cfg['data']['img_cols']),
    )

    n_classes = t_loader.n_classes
    trainloader = data.DataLoader(t_loader,
                                  batch_size=cfg['training']['batch_size'],
                                  num_workers=cfg['training']['n_workers'],
                                  shuffle=True)

    valloader = data.DataLoader(v_loader,
                                batch_size=cfg['training']['batch_size'],
                                num_workers=cfg['training']['n_workers'])

    # Setup Metrics
    running_metrics_val = runningScore(n_classes)

    # Setup Model
    # model = get_model(cfg['model'], n_classes).to(device)
    model = get_model(cfg['model'], n_classes)
    logger.info("Using Model: {}".format(cfg['model']['arch']))

    # model=apex.parallel.convert_syncbn_model(model)
    model = model.to(device)

    # a=range(torch.cuda.device_count())
    # model = torch.nn.DataParallel(model, device_ids=range(torch.cuda.device_count()))
    model = torch.nn.DataParallel(model, device_ids=[0, 1])
    # model = encoding.parallel.DataParallelModel(model, device_ids=[0, 1])

    # Setup optimizer, lr_scheduler and loss function
    optimizer_cls = get_optimizer(cfg)
    optimizer_params = {
        k: v
        for k, v in cfg['training']['optimizer'].items() if k != 'name'
    }

    optimizer = optimizer_cls(model.parameters(), **optimizer_params)

    # optimizer = FP16_Optimizer(optimizer, static_loss_scale=128.0)

    logger.info("Using optimizer {}".format(optimizer))

    scheduler = get_scheduler(optimizer, cfg['training']['lr_schedule'])

    # optimizer = FP16_Optimizer(optimizer, static_loss_scale=128.0)

    loss_fn = get_loss_function(cfg)
    # loss_fn== encoding.parallel.DataParallelCriterion(loss_fn, device_ids=[0, 1])
    logger.info("Using loss {}".format(loss_fn))

    start_iter = 0
    if cfg['training']['resume'] is not None:
        if os.path.isfile(cfg['training']['resume']):
            logger.info(
                "Loading model and optimizer from checkpoint '{}'".format(
                    cfg['training']['resume']))
            checkpoint = torch.load(cfg['training']['resume'])
            model.load_state_dict(checkpoint["model_state"])
            optimizer.load_state_dict(checkpoint["optimizer_state"])
            scheduler.load_state_dict(checkpoint["scheduler_state"])
            # start_iter = checkpoint["epoch"]
            logger.info("Loaded checkpoint '{}' (iter {})".format(
                cfg['training']['resume'], checkpoint["epoch"]))
        else:
            logger.info("No checkpoint found at '{}'".format(
                cfg['training']['resume']))

    val_loss_meter = averageMeter()
    time_meter = averageMeter()
    time_meter_val = averageMeter()

    best_iou = -100.0
    i = start_iter
    flag = True

    train_data_len = t_loader.__len__()
    batch_size = cfg['training']['batch_size']
    epoch = cfg['training']['train_epoch']
    train_iter = int(np.ceil(train_data_len / batch_size) * epoch)

    val_rlt_f1 = []
    val_rlt_OA = []
    best_f1_till_now = 0
    best_OA_till_now = 0

    while i <= train_iter and flag:
        for (images, labels) in trainloader:
            i += 1
            start_ts = time.time()
            scheduler.step()
            model.train()
            images = images.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            outputs = model(images)

            loss = loss_fn(input=outputs, target=labels)

            loss.backward()
            # optimizer.backward(loss)

            optimizer.step()

            time_meter.update(time.time() - start_ts)

            ### add by Sprit
            time_meter_val.update(time.time() - start_ts)

            if (i + 1) % cfg['training']['print_interval'] == 0:
                fmt_str = "Iter [{:d}/{:d}]  Loss: {:.4f}  Time/Image: {:.4f}"
                print_str = fmt_str.format(
                    i + 1, train_iter, loss.item(),
                    time_meter.avg / cfg['training']['batch_size'])

                print(print_str)
                logger.info(print_str)
                writer.add_scalar('loss/train_loss', loss.item(), i + 1)
                time_meter.reset()

            if (i + 1) % cfg['training']['val_interval'] == 0 or \
               (i + 1) == train_iter:
                model.eval()
                with torch.no_grad():
                    for i_val, (images_val,
                                labels_val) in tqdm(enumerate(valloader)):
                        images_val = images_val.to(device)
                        labels_val = labels_val.to(device)

                        outputs = model(images_val)
                        # val_loss = loss_fn(input=outputs, target=labels_val)

                        pred = outputs.data.max(1)[1].cpu().numpy()
                        gt = labels_val.data.cpu().numpy()

                        running_metrics_val.update(gt, pred)
                        # val_loss_meter.update(val_loss.item())

                # writer.add_scalar('loss/val_loss', val_loss_meter.avg, i+1)
                # logger.info("Iter %d Loss: %.4f" % (i + 1, val_loss_meter.avg))

                score, class_iou = running_metrics_val.get_scores()

                for k, v in score.items():
                    print(k, v)
                    logger.info('{}: {}'.format(k, v))
                    # writer.add_scalar('val_metrics/{}'.format(k), v, i+1)

                for k, v in class_iou.items():
                    logger.info('{}: {}'.format(k, v))
                    # writer.add_scalar('val_metrics/cls_{}'.format(k), v, i+1)

                # val_loss_meter.reset()
                running_metrics_val.reset()

                ### add by Sprit
                avg_f1 = score["Mean F1 : \t"]
                OA = score["Overall Acc: \t"]
                val_rlt_f1.append(avg_f1)
                val_rlt_OA.append(score["Overall Acc: \t"])

                if avg_f1 >= best_f1_till_now:
                    best_f1_till_now = avg_f1
                    correspond_OA = score["Overall Acc: \t"]
                    best_f1_epoch_till_now = i + 1
                print("\nBest F1 till now = ", best_f1_till_now)
                print("Correspond OA= ", correspond_OA)
                print("Best F1 Iter till now= ", best_f1_epoch_till_now)

                if OA >= best_OA_till_now:
                    best_OA_till_now = OA
                    correspond_f1 = score["Mean F1 : \t"]
                    # correspond_acc=score["Overall Acc: \t"]
                    best_OA_epoch_till_now = i + 1

                    state = {
                        "epoch": i + 1,
                        "model_state": model.state_dict(),
                        "optimizer_state": optimizer.state_dict(),
                        "scheduler_state": scheduler.state_dict(),
                        "best_OA": best_OA_till_now,
                    }
                    save_path = os.path.join(
                        writer.file_writer.get_logdir(),
                        "{}_{}_best_model.pkl".format(cfg['model']['arch'],
                                                      cfg['data']['dataset']))
                    torch.save(state, save_path)

                print("Best OA till now = ", best_OA_till_now)
                print("Correspond F1= ", correspond_f1)
                # print("Correspond OA= ",correspond_acc)
                print("Best OA Iter till now= ", best_OA_epoch_till_now)

                ### add by Sprit
                iter_time = time_meter_val.avg
                time_meter_val.reset()
                remain_time = iter_time * (train_iter - i)
                m, s = divmod(remain_time, 60)
                h, m = divmod(m, 60)
                if s != 0:
                    train_time = "Remain training time = %d hours %d minutes %d seconds \n" % (
                        h, m, s)
                else:
                    train_time = "Remain training time : Training completed.\n"
                print(train_time)

                # if OA >= best_OA_till_now:
                #     best_iou = score["Mean IoU : \t"]
                #     state = {
                #         "epoch": i + 1,
                #         "model_state": model.state_dict(),
                #         "optimizer_state": optimizer.state_dict(),
                #         "scheduler_state": scheduler.state_dict(),
                #         "best_iou": best_iou,
                #     }
                #     save_path = os.path.join(writer.file_writer.get_logdir(),
                #                              "{}_{}_best_model.pkl".format(
                #                                  cfg['model']['arch'],
                #                                  cfg['data']['dataset']))
                #     torch.save(state, save_path)

            if (i + 1) == train_iter:
                flag = False
                break
    my_pt.csv_out(run_id, data_path, cfg['model']['arch'], epoch, val_rlt_f1,
                  cfg['training']['val_interval'])
    my_pt.csv_out(run_id, data_path, cfg['model']['arch'], epoch, val_rlt_OA,
                  cfg['training']['val_interval'])
Example #20
0
def train(args):
    """Train a segmentation model on the semi-supervised Cityscapes loader.

    Builds train/val dataloaders from ``args``, trains for ``args.n_epoch``
    epochs with SGD + StepLR (unless the model supplies its own optimizer /
    loss), and keeps the checkpoint with the best validation mean IoU in
    ``args.ckpt_dir/args.name``.
    """
    import shutil  # used for checkpoint-dir cleanup below

    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpus
    # NOTE(review): len(args.gpus) counts characters of the gpu-id string
    # (e.g. "0,1" has length 3), so this check looks accidental — confirm
    # the intended images-per-gpu constraint before relying on it.
    assert len(args.gpus) // args.batch_size == 3, 'Each gpu must have 3'

    # Setup Augmentations
    data_aug = Compose([RandomRotate(10), RandomHorizontallyFlip()])

    # Setup Dataloader
    data_loader = get_loader('semi_cityscapes')
    data_path = get_data_path(args.dataset)
    # Optionally subsample the training cities (names starting with a-h).
    city_names = '[a-h]*' if args.subsample else '*'
    t_loader = data_loader(data_path,
                           is_transform=True,
                           img_size=(args.img_rows, args.img_cols),
                           augmentations=data_aug,
                           gamma_augmentation=args.gamma,
                           city_names=city_names,
                           real_synthetic=args.real_synthetic)
    # Full val dataset
    v_loader = data_loader(data_path,
                           is_transform=True,
                           split='val',
                           img_size=(args.img_rows, args.img_cols),
                           city_names='*')
    print("Training dataset size: {}".format(len(t_loader)))

    n_classes = t_loader.n_classes
    trainloader = data.DataLoader(t_loader,
                                  batch_size=args.batch_size,
                                  num_workers=8,
                                  shuffle=True)
    valloader = data.DataLoader(v_loader,
                                batch_size=args.batch_size,
                                num_workers=8)

    # Checkpoint dir, TensorBoard writer, and a log of the run arguments.
    ckpt_dir = os.path.join(args.ckpt_dir, args.name)
    os.makedirs(ckpt_dir, exist_ok=True)
    tb_path = os.path.join(ckpt_dir, 'tb')
    if os.path.exists(tb_path):
        # shutil.rmtree instead of `os.system('rm -r ...')`: no shell-out,
        # and safe for paths containing spaces or shell metacharacters.
        shutil.rmtree(tb_path)
    writer = SummaryWriter(tb_path)
    log_path = os.path.join(ckpt_dir, 'train.log')
    with open(log_path, 'w+') as f:
        args_dict = vars(args)
        for k in sorted(args_dict.keys()):
            f.write('{}: {}\n'.format(k, str(args_dict[k])))

    # Setup Metrics
    running_metrics = runningScore(n_classes)

    # Setup visdom for visualization
    if args.visdom:
        vis = visdom.Visdom()

        loss_window = vis.line(X=torch.zeros((1, )).cpu(),
                               Y=torch.zeros((1)).cpu(),
                               opts=dict(xlabel='minibatches',
                                         ylabel='Loss',
                                         title='Training Loss',
                                         legend=['Loss']))

    # Setup Model (DataParallel over all visible GPUs).
    model = get_model(args.arch, n_classes)
    model = torch.nn.DataParallel(model.cuda())

    # Check if model has custom optimizer / loss
    if hasattr(model.module, 'optimizer'):
        optimizer = model.module.optimizer
    else:
        optimizer = torch.optim.SGD(model.parameters(),
                                    lr=args.l_rate,
                                    momentum=0.95,
                                    weight_decay=5e-4)
    scheduler = StepLR(optimizer, step_size=args.lr_step_size, gamma=0.1)

    if hasattr(model.module, 'loss'):
        print('Using custom loss')
        loss_fn = model.module.loss
    else:
        loss_fn = cross_entropy2d

    if args.resume is not None:
        if os.path.isfile(args.resume):
            print("Loading model and optimizer from checkpoint '{}'".format(
                args.resume))
            checkpoint = torch.load(args.resume)
            model.load_state_dict(checkpoint['model_state'])
            optimizer.load_state_dict(checkpoint['optimizer_state'])
            print("Loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("No checkpoint found at '{}'".format(args.resume))

    best_iou = -100.0
    for epoch in range(args.n_epoch):
        epoch_start_time = time.time()
        scheduler.step()
        model.train()
        for i, (images, labels) in enumerate(trainloader):
            images = images.cuda()
            labels = labels.cuda()

            optimizer.zero_grad()
            outputs = model(images)

            loss = loss_fn(input=outputs, target=labels)

            loss.backward()
            optimizer.step()

            if args.visdom:
                # loss.item() replaces the deprecated loss.data[0], which
                # raises on 0-dim tensors in PyTorch >= 0.4.
                vis.line(X=torch.ones((1, 1)).cpu() * i,
                         Y=torch.Tensor([loss.item()]).unsqueeze(0).cpu(),
                         win=loss_window,
                         update='append')

            if (i + 1) % 20 == 0:
                print("Epoch [%d/%d] Loss: %.4f" %
                      (epoch + 1, args.n_epoch, loss.item()))

        print("Epoch [{}/{}] done ({} sec)".format(
            epoch + 1, args.n_epoch, int(time.time() - epoch_start_time)))

        # Validation. torch.no_grad() replaces the deprecated
        # Variable(..., volatile=True) pattern (consistent with the other
        # training loops in this file).
        model.eval()
        with torch.no_grad():
            for i_val, (images_val, labels_val) in tqdm(enumerate(valloader)):
                images_val = images_val.cuda()
                labels_val = labels_val.cuda()

                outputs = model(images_val)
                pred = outputs.data.max(1)[1].cpu().numpy()
                gt = labels_val.cpu().numpy()
                running_metrics.update(gt, pred)

        score, class_iou = running_metrics.get_scores()
        for k, v in score.items():
            print(k, v)
        running_metrics.reset()

        mean_iou = score['Mean IoU : \t']
        writer.add_scalar('mean IoU', mean_iou, epoch)
        if mean_iou >= best_iou:
            best_iou = mean_iou
            state = {
                'epoch': epoch + 1,
                'model_state': model.state_dict(),
                'optimizer_state': optimizer.state_dict(),
            }
            torch.save(state, "{}/best_model.pkl".format(ckpt_dir))
Example #21
0
def validate(cfg, args):
    """Evaluate a saved checkpoint on the validation split.

    Supports horizontal-flip test-time averaging (``args.eval_flip``) and
    per-batch throughput reporting (``args.measure_time``); prints the
    aggregate scores and per-class IoU at the end.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Setup Dataloader
    data_loader = get_loader(cfg['data']['dataset'])
    data_path = cfg['data']['path']

    loader = data_loader(
        data_path,
        split=cfg['data']['val_split'],
        is_transform=True,
        # BUG FIX: width previously used img_rows twice; use img_cols for
        # the second dimension (matches the sibling validate() in this file).
        img_size=(cfg['data']['img_rows'], cfg['data']['img_cols']),
    )

    n_classes = loader.n_classes

    valloader = data.DataLoader(loader,
                                batch_size=cfg['training']['batch_size'],
                                num_workers=8)
    running_metrics = runningScore(n_classes)

    # Setup Model. The checkpoint here is loaded without key conversion
    # (see the commented alternative) — it was saved without DataParallel.
    model = get_model(cfg['model'], n_classes).to(device)
    # state = convert_state_dict(torch.load(args.model_path)["model_state"])
    state = torch.load(args.model_path)["model_state"]
    model.load_state_dict(state)
    model.eval()

    # Inference only: disable autograd bookkeeping to save memory/time.
    with torch.no_grad():
        for i, (images, labels) in enumerate(valloader):
            start_time = timeit.default_timer()

            images = images.to(device)

            if args.eval_flip:
                outputs = model(images)

                # Flip images in numpy (negative strides are not supported
                # by torch tensors), score the mirrored batch, and average.
                outputs = outputs.data.cpu().numpy()
                flipped_images = np.copy(
                    images.data.cpu().numpy()[:, :, :, ::-1])
                flipped_images = torch.from_numpy(flipped_images).float().to(
                    device)
                outputs_flipped = model(flipped_images)
                outputs_flipped = outputs_flipped.data.cpu().numpy()
                outputs = (outputs + outputs_flipped[:, :, :, ::-1]) / 2.0

                pred = np.argmax(outputs, axis=1)
            else:
                outputs = model(images)
                pred = outputs.data.max(1)[1].cpu().numpy()

            gt = labels.numpy()

            if args.measure_time:
                elapsed_time = timeit.default_timer() - start_time
                print("Inference time \
                  (iter {0:5d}): {1:3.5f} fps".format(
                    i + 1, pred.shape[0] / elapsed_time))
            running_metrics.update(gt, pred)

    score, class_iou = running_metrics.get_scores()

    for k, v in score.items():
        print(k, v)

    for i in range(n_classes):
        print(i, class_iou[i])
Example #22
0
def validate(cfg, args):
    """Evaluate a checkpoint on the validation split under photometric
    augmentations (hue / contrast / brightness / saturation / gamma),
    printing overall scores and per-class IoU with class names.
    """
    augmentations = {
        'hue': args.hue,
        'contrast': args.contrast,
        'brightness': args.brightness,
        'saturation': args.saturation,
        'gamma': args.gamma
    }
    data_aug = get_composed_augmentations(augmentations)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    data_loader = get_loader(cfg['data']['dataset'])
    data_path = cfg['data']['path']

    loader = data_loader(data_path,
                         split=cfg['data']['val_split'],
                         is_transform=True,
                         img_size=(cfg['data']['img_rows'],
                                   cfg['data']['img_cols']),
                         augmentations=data_aug)

    n_classes = loader.n_classes
    valloader = data.DataLoader(loader,
                                batch_size=cfg['training']['batch_size'],
                                num_workers=8)
    running_metrics = runningScore(n_classes)

    model = get_model(cfg['model'], n_classes).to(device)
    model = torch.nn.DataParallel(model,
                                  device_ids=range(torch.cuda.device_count()))
    checkpoint = torch.load(cfg['training']['resume'],
                            map_location=lambda storage, loc: storage)
    model.load_state_dict(checkpoint["model_state"])
    model.eval()
    model.to(device)

    print("hue {}".format(args.hue))
    print("contrast {}".format(args.contrast))
    print("brightness {}".format(args.brightness))
    print("saturation {}".format(args.saturation))
    print("gamma {}".format(args.gamma))

    # Hoisted out of the loop: the resize operator is identical for every
    # batch, so build it once instead of per iteration.
    # NOTE(review): align_corners is left at the framework default — confirm
    # it matches how the deeplab model was trained.
    interp = None
    if cfg['model']['arch'] == 'deeplab':
        interp = nn.Upsample(size=(cfg['data']['img_rows'],
                                   cfg['data']['img_cols']),
                             mode='bilinear')

    # Inference only: disable autograd bookkeeping.
    with torch.no_grad():
        for i, (images, labels) in enumerate(valloader):
            images = images.to(device)
            outputs = model(images)
            if interp is not None:
                outputs = interp(outputs)
            pred = outputs.data.max(1)[1].cpu().numpy()
            gt = labels.numpy()
            running_metrics.update(gt, pred)

    score, class_iou = running_metrics.get_scores()

    for k, v in score.items():
        print(k, v)

    # Loop-invariant check hoisted out of the per-class loop.
    if loader.class_names is not None:
        for i in range(n_classes):
            # NOTE(review): the +1 offset assumes class_names[0] is a
            # background/void entry — confirm against the loader.
            print(loader.class_names[i + 1], class_iou[i])
    print('\n')
Example #23
0
def validate(cfg, args):
    """Evaluate a (teacher-distilled) multi-path model on the validation
    split, cycling the input path via ``pos_id`` and printing overall scores
    and per-class IoU.
    """
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Setup Dataloader
    data_loader = get_loader(cfg["data"]["dataset"])
    data_path = cfg["data"]["path"]

    path_n = cfg["model"]["path_num"]

    val_augmentations = cfg["validating"].get("val_augmentations", None)
    v_data_aug = get_composed_augmentations(val_augmentations)

    v_loader = data_loader(
        data_path,
        split=cfg["data"]["val_split"],
        # img_size=(cfg["data"]["img_rows"], cfg["data"]["img_cols"]),
        augmentations=v_data_aug,
        path_num=path_n)

    n_classes = v_loader.n_classes
    valloader = data.DataLoader(v_loader,
                                batch_size=cfg["validating"]["batch_size"],
                                num_workers=cfg["validating"]["n_workers"])

    running_metrics = runningScore(n_classes)

    # Setup Model: student is built around the teacher and restored from the
    # validation checkpoint (non-strict: the teacher weights may be absent).
    teacher = get_model(cfg["teacher"], n_classes)
    model = get_model(cfg["model"],
                      n_classes,
                      psp_path=cfg["training"]["resume"],
                      teacher=teacher).to(device)
    state = torch.load(cfg["validating"]["resume"])  #["model_state"]
    model.load_state_dict(state, strict=False)
    model.eval()
    model.to(device)

    with torch.no_grad():
        for i, (val, labels) in enumerate(valloader):

            gt = labels.numpy()
            _val = [ele.to(device) for ele in val]

            # Synchronize around the forward pass so the timing reflects
            # finished GPU work, not just kernel launches.
            torch.cuda.synchronize()
            start_time = timeit.default_timer()
            outputs = model(_val, pos_id=i % path_n)
            torch.cuda.synchronize()
            elapsed_time = timeit.default_timer() - start_time
            pred = outputs.data.max(1)[1].cpu().numpy()
            running_metrics.update(gt, pred)

            if args.measure_time:
                # FIX: reuse the synchronized measurement above; the old code
                # re-read the clock here, folding the metrics update into the
                # reported inference time.
                print("Inference time \
                      (iter {0:5d}): {1:3.5f} fps".format(
                    i + 1, pred.shape[0] / elapsed_time))
            # (Removed an unreachable `if False:` cv2 visualization block.)

    score, class_iou = running_metrics.get_scores()

    for k, v in score.items():
        print(k, v)

    for i in range(n_classes):
        print(i, class_iou[i])
Example #24
0
def validate(cfg, args):
    """Run a saved checkpoint over the validation split and report metrics.

    Optionally averages logits over horizontally flipped inputs
    (``args.eval_flip``) and prints per-batch throughput
    (``args.measure_time``). Prints overall scores and per-class IoU.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Build the validation dataset and loader.
    data_loader = get_loader(cfg["data"]["dataset"])
    data_path = cfg["data"]["path"]
    loader = data_loader(
        data_path,
        split=cfg["data"]["val_split"],
        is_transform=True,
        img_size=(cfg["data"]["img_rows"], cfg["data"]["img_cols"]),
    )
    n_classes = loader.n_classes
    valloader = data.DataLoader(loader,
                                batch_size=cfg["training"]["batch_size"],
                                num_workers=8)
    running_metrics = runningScore(n_classes)

    # Restore the trained weights and switch to inference mode.
    model = get_model(cfg["model"], n_classes).to(device)
    state = convert_state_dict(torch.load(args.model_path)["model_state"])
    model.load_state_dict(state)
    model.eval()
    model.to(device)

    for batch_idx, (images, labels) in enumerate(valloader):
        start_time = timeit.default_timer()
        images = images.to(device)

        if not args.eval_flip:
            outputs = model(images)
            pred = outputs.data.max(1)[1].cpu().numpy()
        else:
            outputs = model(images).data.cpu().numpy()
            # Mirror the batch in numpy (torch tensors don't support the
            # negative stride), score it, and average with the original.
            flipped = np.copy(images.data.cpu().numpy()[:, :, :, ::-1])
            flipped = torch.from_numpy(flipped).float().to(device)
            outputs_flipped = model(flipped).data.cpu().numpy()
            outputs = (outputs + outputs_flipped[:, :, :, ::-1]) / 2.0
            pred = np.argmax(outputs, axis=1)

        # Predictions are shifted by one before scoring — presumably the
        # ground truth is 1-based with 0 reserved; confirm against the loader.
        pred = pred + 1
        gt = labels.numpy()

        if args.measure_time:
            elapsed_time = timeit.default_timer() - start_time
            print("Inference time \
                  (iter {0:5d}): {1:3.5f} fps".format(
                batch_idx + 1, pred.shape[0] / elapsed_time))
        running_metrics.update(gt, pred)

    score, class_iou = running_metrics.get_scores()
    for k, v in score.items():
        print(k, v)

    for cls_idx in range(n_classes - 1):
        print(cls_idx, class_iou[cls_idx])
Example #25
0
def train(cfg, writer, logger):
    """Train a segmentation model as described by the ``cfg`` dict.

    Runs an iteration-based (not epoch-based) training loop for
    ``cfg["training"]["train_iters"]`` steps, validating every
    ``val_interval`` iterations and saving the best-mean-IoU checkpoint as
    ``<logdir>/<arch>_<dataset>_best_model.pkl``.

    Args:
        cfg: parsed configuration with ``data``, ``model`` and ``training``
            sections (paths, optimizer/scheduler/loss settings, intervals).
        writer: TensorBoard ``SummaryWriter``; also determines the checkpoint
            directory via ``writer.file_writer.get_logdir()``.
        logger: ``logging.Logger`` for progress and validation messages.
    """

    # Setup seeds (torch CPU/GPU, numpy, and stdlib random) for repeatability.
    torch.manual_seed(cfg.get("seed", 1337))
    torch.cuda.manual_seed(cfg.get("seed", 1337))
    np.random.seed(cfg.get("seed", 1337))
    random.seed(cfg.get("seed", 1337))

    # Setup device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Setup Augmentations
    augmentations = cfg["training"].get("augmentations", None)
    data_aug = get_composed_augmentations(augmentations)

    # Setup Dataloader
    data_loader = get_loader(cfg["data"]["dataset"])
    data_path = cfg["data"]["path"]

    t_loader = data_loader(
        data_path,
        is_transform=True,
        split=cfg["data"]["train_split"],
        img_size=(cfg["data"]["img_rows"], cfg["data"]["img_cols"]),
        augmentations=data_aug,
    )

    v_loader = data_loader(
        data_path,
        is_transform=True,
        split=cfg["data"]["val_split"],
        img_size=(cfg["data"]["img_rows"], cfg["data"]["img_cols"]),
    )

    n_classes = t_loader.n_classes
    trainloader = data.DataLoader(
        t_loader,
        batch_size=cfg["training"]["batch_size"],
        num_workers=cfg["training"]["n_workers"],
        shuffle=True,
    )

    valloader = data.DataLoader(v_loader,
                                batch_size=cfg["training"]["batch_size"],
                                num_workers=cfg["training"]["n_workers"])

    # Setup Metrics
    running_metrics_val = runningScore(n_classes)

    # Setup Model
    model = get_model(cfg["model"], n_classes).to(device)

    model = torch.nn.DataParallel(model,
                                  device_ids=range(torch.cuda.device_count()))

    # Setup optimizer, lr_scheduler and loss function. Everything but the
    # optimizer name in cfg is forwarded as a constructor keyword argument.
    optimizer_cls = get_optimizer(cfg)
    optimizer_params = {
        k: v
        for k, v in cfg["training"]["optimizer"].items() if k != "name"
    }

    optimizer = optimizer_cls(model.parameters(), **optimizer_params)
    logger.info("Using optimizer {}".format(optimizer))

    scheduler = get_scheduler(optimizer, cfg["training"]["lr_schedule"])

    loss_fn = get_loss_function(cfg)
    logger.info("Using loss {}".format(loss_fn))

    # Optionally resume: model weights always; optimizer/scheduler/iteration
    # counter only when a full resume (not weights-only) is requested.
    start_iter = 0
    if cfg["training"]["resume"] is not None:
        if os.path.isfile(cfg["training"]["resume"]):
            logger.info(
                "Loading model and optimizer from checkpoint '{}'".format(
                    cfg["training"]["resume"]))
            checkpoint = torch.load(cfg["training"]["resume"])
            model.load_state_dict(checkpoint["model_state"])
            # NOTE(review): `args` is not a parameter of this function — this
            # relies on a module-level `args` (otherwise NameError). Confirm
            # and thread it through explicitly.
            if not args.load_weight_only:
                optimizer.load_state_dict(checkpoint["optimizer_state"])
                scheduler.load_state_dict(checkpoint["scheduler_state"])
                start_iter = checkpoint["epoch"]
                logger.info("Loaded checkpoint '{}' (iter {})".format(
                    cfg["training"]["resume"], checkpoint["epoch"]))
            else:
                logger.info("Loaded checkpoint '{}' (iter unknown)".format(
                    cfg["training"]["resume"]))

        else:
            logger.info("No checkpoint found at '{}'".format(
                cfg["training"]["resume"]))

    val_loss_meter = averageMeter()
    time_meter = averageMeter()

    best_iou = -100.0
    i = start_iter
    flag = True  # cleared when train_iters is reached, to exit the while loop

    # Iteration-driven loop: the dataloader is re-iterated (multiple passes)
    # until `train_iters` steps have been taken.
    while i <= cfg["training"]["train_iters"] and flag:
        for (images, labels) in trainloader:
            i += 1
            start_ts = time.time()
            scheduler.step()
            model.train()
            images = images.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            outputs = model(images)

            loss = loss_fn(input=outputs, target=labels)

            loss.backward()
            optimizer.step()

            time_meter.update(time.time() - start_ts)

            if (i + 1) % cfg["training"]["print_interval"] == 0:
                fmt_str = "Iter [{:d}/{:d}]  Loss: {:.4f}  Time/Image: {:.4f}"
                print_str = fmt_str.format(
                    i + 1,
                    cfg["training"]["train_iters"],
                    loss.item(),
                    time_meter.avg / cfg["training"]["batch_size"],
                )

                print(print_str)
                logger.info(print_str)
                writer.add_scalar("loss/train_loss", loss.item(), i + 1)
                time_meter.reset()

            # Periodic validation (and once more at the final iteration).
            if (i + 1) % cfg["training"]["val_interval"] == 0 or (
                    i + 1) == cfg["training"]["train_iters"]:
                model.eval()
                with torch.no_grad():
                    for i_val, (images_val,
                                labels_val) in tqdm(enumerate(valloader)):
                        images_val = images_val.to(device)
                        labels_val = labels_val.to(device)

                        outputs = model(images_val)
                        val_loss = loss_fn(input=outputs, target=labels_val)

                        pred = outputs.data.max(1)[1].cpu().numpy()
                        gt = labels_val.data.cpu().numpy()

                        running_metrics_val.update(gt, pred)
                        val_loss_meter.update(val_loss.item())

                writer.add_scalar("loss/val_loss", val_loss_meter.avg, i + 1)
                logger.info("Iter %d Loss: %.4f" % (i + 1, val_loss_meter.avg))

                score, class_iou = running_metrics_val.get_scores()
                for k, v in score.items():
                    print(k, v)
                    logger.info("{}: {}".format(k, v))
                    writer.add_scalar("val_metrics/{}".format(k), v, i + 1)

                for k, v in class_iou.items():
                    logger.info("{}: {}".format(k, v))
                    writer.add_scalar("val_metrics/cls_{}".format(k), v, i + 1)

                val_loss_meter.reset()
                running_metrics_val.reset()

                # Checkpoint whenever validation mean IoU matches or beats
                # the best seen so far.
                if score["Mean IoU : \t"] >= best_iou:
                    best_iou = score["Mean IoU : \t"]
                    state = {
                        "epoch": i + 1,
                        "model_state": model.state_dict(),
                        "optimizer_state": optimizer.state_dict(),
                        "scheduler_state": scheduler.state_dict(),
                        "best_iou": best_iou,
                    }
                    save_path = os.path.join(
                        writer.file_writer.get_logdir(),
                        "{}_{}_best_model.pkl".format(cfg["model"]["arch"],
                                                      cfg["data"]["dataset"]),
                    )
                    torch.save(state, save_path)

            if (i + 1) == cfg["training"]["train_iters"]:
                flag = False
                break
Example #26
0
def train(args):
    """Train a segmentation model for ``args.n_epoch`` epochs.

    Builds train/val dataloaders with the configured augmentations, trains
    with Adam (unless the model supplies its own optimizer / loss), and saves
    the checkpoint with the best validation mean IoU as
    ``<arch>_<dataset>_best_model.pkl``.
    """
    # Setup Augmentations
    # NOTE(review): these transforms are applied sequentially — multiple
    # rotations plus several crop/resize steps in one pipeline. Confirm this
    # chain is intentional rather than a list of alternatives.
    data_aug = Compose([
        RandomRotate(10),
        RandomHorizontallyFlip(),
        RandomRotate(-15),
        RandomRotate(15),
        RandomRotate(-10),
        RandomCrop(size=256, padding=0),
        CenterCrop(size=256),
        HorizontallyFlip(),
        RandomSized(size=128),
        RandomSized(size=256),
        RandomSizedCrop(size=256),
    ])

    # Setup Dataloader
    data_loader = get_loader(args.dataset)
    data_path = get_data_path(args.dataset)
    t_loader = data_loader(data_path,
                           is_transform=True,
                           img_size=(args.img_rows, args.img_cols),
                           augmentations=data_aug)
    v_loader = data_loader(data_path,
                           is_transform=True,
                           split='val',
                           img_size=(args.img_rows, args.img_cols))

    n_classes = t_loader.n_classes
    # NOTE(review): shuffle=False on the training loader is unusual — confirm
    # it is deliberate.
    trainloader = data.DataLoader(t_loader,
                                  batch_size=args.batch_size,
                                  num_workers=8,
                                  shuffle=False)
    valloader = data.DataLoader(v_loader,
                                batch_size=args.batch_size,
                                num_workers=8)

    # Setup Metrics
    running_metrics = runningScore(n_classes)

    # Setup visdom for visualization
    if args.visdom:
        vis = visdom.Visdom()

        loss_window = vis.line(X=torch.zeros((1, )).cpu(),
                               Y=torch.zeros((1)).cpu(),
                               opts=dict(xlabel='minibatches',
                                         ylabel='Loss',
                                         title='Training Loss',
                                         legend=['Loss']))

    # Setup Model
    print('Arguments : {}'.format(args))
    model = get_model(args.arch, n_classes)

    model = torch.nn.DataParallel(model,
                                  device_ids=range(torch.cuda.device_count()))
    model.cuda()

    # BUG FIX: print the GPU count, not the range object.
    print("Using {} GPUs...".format(torch.cuda.device_count()))

    # Check if model has custom optimizer / loss
    if hasattr(model.module, 'optimizer'):
        optimizer = model.module.optimizer
    else:
        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=args.l_rate,
                                     weight_decay=5e-4)

    if hasattr(model.module, 'loss'):
        print('Using custom loss')
        loss_fn = model.module.loss
    else:
        loss_fn = cross_entropy2d

    if args.resume is not None:
        if os.path.isfile(args.resume):
            print("Loading model and optimizer from checkpoint '{}'".format(
                args.resume))
            checkpoint = torch.load(args.resume)
            model.load_state_dict(checkpoint['model_state'])
            optimizer.load_state_dict(checkpoint['optimizer_state'])
            print("Loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("No checkpoint found at '{}'".format(args.resume))

    best_iou = -100.0
    for epoch in range(args.n_epoch):
        model.train()
        for i, (images, labels) in enumerate(trainloader):
            images = images.cuda()
            labels = labels.cuda()

            optimizer.zero_grad()
            outputs = model(images)

            loss = loss_fn(input=outputs, target=labels)

            loss.backward()
            optimizer.step()

            if args.visdom:
                # loss.item() replaces the deprecated loss.data[0], which
                # raises on 0-dim tensors in PyTorch >= 0.4.
                vis.line(X=torch.ones((1, 1)).cpu() * i,
                         Y=torch.Tensor([loss.item()]).unsqueeze(0).cpu(),
                         win=loss_window,
                         update='append')

            if (i + 1) % 20 == 0:
                print("Epoch [%d/%d] Loss: %.4f" %
                      (epoch + 1, args.n_epoch, loss.item()))

        # Validation. torch.no_grad() replaces the deprecated
        # Variable(..., volatile=True) pattern.
        model.eval()
        with torch.no_grad():
            for i_val, (images_val, labels_val) in tqdm(enumerate(valloader)):
                images_val = images_val.cuda()
                labels_val = labels_val.cuda()

                outputs = model(images_val)
                pred = outputs.data.max(1)[1].cpu().numpy()
                gt = labels_val.cpu().numpy()
                running_metrics.update(gt, pred)

        score, class_iou = running_metrics.get_scores()
        for k, v in score.items():
            print(k, v)
        running_metrics.reset()

        if score['Mean IoU : \t'] >= best_iou:
            best_iou = score['Mean IoU : \t']
            state = {
                'epoch': epoch + 1,
                'model_state': model.state_dict(),
                'optimizer_state': optimizer.state_dict(),
            }
            torch.save(state,
                       "{}_{}_best_model.pkl".format(args.arch, args.dataset))
Example #27
0
def validate(cfg, args):
    """Episodic few-shot segmentation validation via weight imprinting.

    For every episode drawn from the validation loader:
      1. imprint classifier weights for the novel class from the support set
         (optionally iteratively, also using the query image),
      2. finetune the classifiers (everything else is frozen) on the support
         images,
      3. predict on the query image,
      4. reverse the imprinting so the next episode starts from the
         original weights.

    Prints binary IoU or multi-class confusion-matrix scores depending on
    ``args.binary``, and optionally dumps qualitative results to
    ``args.out_dir``.

    Args:
        cfg: parsed config dict with 'data', 'model' and 'training' sections.
        args: CLI namespace; uses model_path, out_dir, fold, binary, rand
            and iterations_imp.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Create output folders for qualitative dumps (heatmaps, predictions,
    # ground truth, support/query inputs).
    if args.out_dir != "":
        if not os.path.exists(args.out_dir):
            os.mkdir(args.out_dir)
        if not os.path.exists(args.out_dir + 'hmaps_bg'):
            os.mkdir(args.out_dir + 'hmaps_bg')
        if not os.path.exists(args.out_dir + 'hmaps_fg'):
            os.mkdir(args.out_dir + 'hmaps_fg')
        if not os.path.exists(args.out_dir + 'pred'):
            os.mkdir(args.out_dir + 'pred')
        if not os.path.exists(args.out_dir + 'gt'):
            os.mkdir(args.out_dir + 'gt')
        if not os.path.exists(args.out_dir + 'qry_images'):
            os.mkdir(args.out_dir + 'qry_images')
        if not os.path.exists(args.out_dir + 'sprt_images'):
            os.mkdir(args.out_dir + 'sprt_images')
        if not os.path.exists(args.out_dir + 'sprt_gt'):
            os.mkdir(args.out_dir + 'sprt_gt')

    # A fold given on the command line overrides the config.
    if args.fold != -1:
        cfg['data']['fold'] = args.fold

    fold = cfg['data']['fold']

    # Setup Dataloader
    data_loader = get_loader(cfg['data']['dataset'])
    data_path = cfg['data']['path']

    loader = data_loader(
        data_path,
        split=cfg['data']['val_split'],
        is_transform=True,
        img_size=[cfg['data']['img_rows'], cfg['data']['img_cols']],
        n_classes=cfg['data']['n_classes'],
        fold=cfg['data']['fold'],
        binary=args.binary,
        k_shot=cfg['data']['k_shot'])

    n_classes = loader.n_classes

    valloader = data.DataLoader(loader,
                                batch_size=cfg['training']['batch_size'],
                                num_workers=0)
    # Binary mode scores fg/bg only; otherwise score the base classes plus
    # the imprinted novel class.
    if args.binary:
        running_metrics = runningScore(2)
        # Per-novel-class tp/fp/fn tallies for OSLSM-style binary IoU.
        fp_list = {}
        tp_list = {}
        fn_list = {}
    else:
        running_metrics = runningScore(
            n_classes + 1)  #+1 indicate the novel class thats added each time

    # Setup Model
    model = get_model(cfg['model'], n_classes).to(device)
    state = convert_state_dict(torch.load(args.model_path)["model_state"])
    model.load_state_dict(state)
    model.to(device)
    # Only the classifier layers are finetuned per episode.
    model.freeze_all_except_classifiers()

    # Setup optimizer, lr_scheduler and loss function
    optimizer_cls = get_optimizer(cfg)
    optimizer_params = {
        k: v
        for k, v in cfg['training']['optimizer'].items() if k != 'name'
    }
    # Snapshot the weights so each episode's imprinting can be undone.
    model.save_original_weights()

    # Imprinting blend factor; presumably tuned on validation — TODO confirm.
    alpha = 0.14139
    for i, (sprt_images, sprt_labels, qry_images, qry_labels,
            original_sprt_images, original_qry_images,
            cls_ind) in enumerate(valloader):

        cls_ind = int(cls_ind)
        print('Starting iteration ', i)
        start_time = timeit.default_timer()
        if args.out_dir != "":
            save_images(original_sprt_images, sprt_labels, original_qry_images,
                        i, args.out_dir)

        # Move the k support shots (lists of tensors) to the device.
        for si in range(len(sprt_images)):
            sprt_images[si] = sprt_images[si].to(device)
            sprt_labels[si] = sprt_labels[si].to(device)
        qry_images = qry_images.to(device)

        # 1- Extract embedding and add the imprinted weights
        if args.iterations_imp > 0:
            model.iterative_imprinting(sprt_images,
                                       qry_images,
                                       sprt_labels,
                                       alpha=alpha,
                                       itr=args.iterations_imp)
        else:
            model.imprint(sprt_images,
                          sprt_labels,
                          alpha=alpha,
                          random=args.rand)

        # Fresh optimizer/scheduler/loss for every episode.
        optimizer = optimizer_cls(model.parameters(), **optimizer_params)
        scheduler = get_scheduler(optimizer, cfg['training']['lr_schedule'])
        loss_fn = get_loss_function(cfg)

        model.train()
        print('Finetuning')
        for j in range(cfg['training']['train_iters']):
            scheduler.step()
            for b in range(len(sprt_images)):
                torch.cuda.empty_cache()
                optimizer.zero_grad()

                outputs = model(sprt_images[b])
                loss = loss_fn(input=outputs, target=sprt_labels[b])
                loss.backward()
                optimizer.step()

                fmt_str = "Iter [{:d}/{:d}]  Loss: {:.4f}"
                print_str = fmt_str.format(j, cfg['training']['train_iters'],
                                           loss.item())
                print(print_str)

        # 2- Infer on the query image
        model.eval()
        with torch.no_grad():
            outputs = model(qry_images)
            pred = outputs.data.max(1)[1].cpu().numpy()

        # Reverse the last imprinting (Few shot setting only not Continual Learning setup yet)
        model.reverse_imprinting()

        gt = qry_labels.numpy()
        if args.binary:
            gt, pred = post_process(gt, pred)

        if args.binary:
            if args.binary == 1:
                # OSLSM-style scoring: accumulate tp/fp/fn per novel class.
                tp, fp, fn = running_metrics.update_binary_oslsm(gt, pred)

                if cls_ind in fp_list.keys():
                    fp_list[cls_ind] += fp
                else:
                    fp_list[cls_ind] = fp

                if cls_ind in tp_list.keys():
                    tp_list[cls_ind] += tp
                else:
                    tp_list[cls_ind] = tp

                if cls_ind in fn_list.keys():
                    fn_list[cls_ind] += fn
                else:
                    fn_list[cls_ind] = fn
            else:
                running_metrics.update(gt, pred)
        else:
            running_metrics.update(gt, pred)

        if args.out_dir != "":
            if args.binary:
                save_vis(outputs, pred, gt, i, args.out_dir, fg_class=1)
            else:
                save_vis(outputs, pred, gt, i, args.out_dir)

    # Final report: per-class IoU from the tallies (binary OSLSM) or
    # confusion-matrix scores (everything else).
    if args.binary:
        if args.binary == 1:
            iou_list = [tp_list[ic]/float(max(tp_list[ic] + fp_list[ic] + fn_list[ic],1)) \
                         for ic in tp_list.keys()]
            print("Binary Mean IoU ", np.mean(iou_list))
        else:
            score, class_iou = running_metrics.get_scores()
            for k, v in score.items():
                print(k, v)
    else:
        score, class_iou = running_metrics.get_scores()
        for k, v in score.items():
            print(k, v)
        val_nclasses = model.n_classes + 1
        for i in range(val_nclasses):
            print(i, class_iou[i])
def train(cfg, writer, logger):
    """Train a segmentation network from a config dict.

    Seeds all RNGs, builds the train loader, wraps the model in
    DataParallel, sets up optimizer/scheduler/loss, optionally resumes from
    a checkpoint, then runs ``cfg['training']['train_iters']`` passes over
    the train loader.  The per-iteration average loss is the model-selection
    criterion: whenever it reaches a new minimum, a checkpoint is written
    into the tensorboard log directory.

    Args:
        cfg: parsed config dict ('data', 'model', 'training' sections).
        writer: tensorboardX SummaryWriter; also provides the save dir.
        logger: logging.Logger for progress messages.
    """
    # Setup dataset split before setting up the seed for random
    data_split_info = init_data_split(cfg['data']['path'], cfg['data'].get('split_ratio', 0), cfg['data'].get('compound', False))  # fly jenelia dataset

    # Setup seeds
    torch.manual_seed(cfg.get('seed', 1337))
    torch.cuda.manual_seed(cfg.get('seed', 1337))
    np.random.seed(cfg.get('seed', 1337))
    random.seed(cfg.get('seed', 1337))

    # Setup device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Setup Cross Entropy Weight (the regression loss has no class weights)
    if cfg['training']['loss']['name'] != 'regression_l1':
        weight = prep_class_val_weights(cfg['training']['cross_entropy_ratio'])
    else:
        weight = None
    log('Using loss : {}'.format(cfg['training']['loss']['name']))

    # Setup Augmentations
    augmentations = cfg['training'].get('augmentations', None) # if no augmentation => default None
    data_aug = get_composed_augmentations(augmentations)

    # Setup Dataloader
    data_loader = get_loader(cfg['data']['dataset'])
    data_path = cfg['data']['path']
    patch_size = [cfg['training']['patch_size'], cfg['training']['patch_size']]

    t_loader = data_loader(
        data_path,
        split=cfg['data']['train_split'],
        augmentations=data_aug,
        data_split_info=data_split_info,
        patch_size=patch_size,
        allow_empty_patch=cfg['training'].get('allow_empty_patch', False),
        n_classes=cfg['training'].get('n_classes', 2))

    n_classes = t_loader.n_classes
    log('n_classes is: {}'.format(n_classes))
    trainloader = data.DataLoader(t_loader,
                                  batch_size=cfg['training']['batch_size'],
                                  num_workers=cfg['training']['n_workers'],
                                  shuffle=False)

    print('trainloader len: ', len(trainloader))
    # Setup Metrics
    running_metrics_val = runningScore(n_classes) # a confusion matrix is created

    # Setup Model
    model = get_model(cfg['model'], n_classes).to(device)
    model = torch.nn.DataParallel(model, device_ids=range(torch.cuda.device_count()))

    # Setup optimizer, lr_scheduler and loss function
    optimizer_cls = get_optimizer(cfg)
    optimizer_params = {k: v for k, v in cfg['training']['optimizer'].items()
                        if k != 'name'}

    optimizer = optimizer_cls(model.parameters(), **optimizer_params)
    logger.info("Using optimizer {}".format(optimizer))

    scheduler = get_scheduler(optimizer, cfg['training']['lr_schedule'])

    loss_fn = get_loss_function(cfg)
    logger.info("Using loss {}".format(loss_fn))

    # Resume model/optimizer/scheduler state from a checkpoint if configured.
    min_loss = None
    start_iter = 0
    if cfg['training']['resume'] is not None:
        log('resume saved model')
        if os.path.isfile(cfg['training']['resume']):
            display(
                "Loading model and optimizer from checkpoint '{}'".format(cfg['training']['resume'])
            )
            checkpoint = torch.load(cfg['training']['resume'])
            model.load_state_dict(checkpoint["model_state"])
            optimizer.load_state_dict(checkpoint["optimizer_state"])
            scheduler.load_state_dict(checkpoint["scheduler_state"])
            start_iter = checkpoint["epoch"]
            min_loss = checkpoint["min_loss"]
            display(
                "Loaded checkpoint '{}' (iter {})".format(
                    cfg['training']['resume'], checkpoint["epoch"]
                )
            )
        else:
            display("No checkpoint found at '{}'".format(cfg['training']['resume']))
            log('no saved model found')

    val_loss_meter = averageMeter()
    time_meter = averageMeter()

    # The dice loss object replaces the configured loss function entirely.
    if cfg['training']['loss']['name'] == 'dice':
        loss_fn = dice_loss()

    i_train_iter = start_iter

    display('Training from {}th iteration\n'.format(i_train_iter))
    while i_train_iter < cfg['training']['train_iters']:
        i_batch_idx = 0
        train_iter_start_time = time.time()
        averageLoss = 0

        # training
        for (images, labels) in trainloader:
            start_ts = time.time()
            # NOTE(review): scheduler.step() is called before optimizer.step()
            # each batch; PyTorch >= 1.1 expects the opposite order — confirm
            # the intended schedule before changing it.
            scheduler.step()
            model.train()

            images = images.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            outputs = model(images)

            if cfg['training']['loss']['name'] == 'dice':
                loss = loss_fn(outputs, labels)
                averageLoss += loss.item()
            else:
                hard_loss = loss_fn(input=outputs, target=labels, weight=weight,
                                size_average=cfg['training']['loss']['size_average'])

                loss = hard_loss

                # FIX: accumulate the detached scalar.  The original did
                # `averageLoss += loss`, which kept every batch's autograd
                # graph alive for the whole iteration (GPU memory leak) and
                # made averageLoss a tensor, inconsistent with the dice
                # branch above.
                averageLoss += loss.item()
            loss.backward()
            optimizer.step()

            time_meter.update(time.time() - start_ts)
            # Either print after every batch or only on the last batch of the
            # iteration, depending on config.
            print_per_batch_check = True if cfg['training']['print_interval_per_batch'] else i_batch_idx+1 == len(trainloader)
            if (i_train_iter + 1) % cfg['training']['print_interval'] == 0 and print_per_batch_check:
                fmt_str = "Iter [{:d}/{:d}]  Loss: {:.4f}  Time/Image: {:.4f}"
                print_str = fmt_str.format(i_train_iter + 1,
                                           cfg['training']['train_iters'],
                                           loss.item(),
                                           time_meter.avg / cfg['training']['batch_size'])

                display(print_str)
                writer.add_scalar('loss/train_loss', loss.item(), i_train_iter + 1)
                time_meter.reset()
            i_batch_idx += 1
        time_for_one_iteration = time.time() - train_iter_start_time

        display('EntireTime for {}th training iteration: {}  EntireTime/Image: {}'.format(i_train_iter+1, time_converter(time_for_one_iteration),
                                                                                          time_converter(time_for_one_iteration/(len(trainloader)*cfg['training']['batch_size']))))
        averageLoss /= (len(trainloader)*cfg['training']['batch_size'])
        # validation
        validation_check = (i_train_iter + 1) % cfg['training']['val_interval'] == 0 or \
                           (i_train_iter + 1) == cfg['training']['train_iters']
        if not validation_check:
            print('no validation check')
        else:
            # Checkpoint whenever the average training loss reaches a new
            # minimum ("best model" selection criterion).
            log('Validation: average loss for current iteration is: {}'.format(averageLoss))
            if min_loss is None:
                min_loss = averageLoss

            if averageLoss <= min_loss:
                min_loss = averageLoss
                state = {
                    "epoch": i_train_iter + 1,
                    "model_state": model.state_dict(),
                    "optimizer_state": optimizer.state_dict(),
                    "scheduler_state": scheduler.state_dict(),
                    "min_loss": min_loss
                }

                save_path = os.path.join(writer.file_writer.get_logdir(),
                                         "{}_{}_model_best.pkl".format(
                                             cfg['model']['arch'],
                                             cfg['data']['dataset']))
                print('save_path is: ' + save_path)

                torch.save(state, save_path)

        i_train_iter += 1
Example #29
0
def test(args):
    """Run single-image inference and report IoU for the root dataset.

    Reads ``args.image`` (a png under a fixed dataset root), preprocesses
    it (BGR, mean-subtraction, optional /255 normalisation), runs the
    checkpoint from ``args.model_path``, optionally refines the output with
    DenseCRF, prints per-class IoU for the 5-class label set, and shows and
    saves the colour-coded prediction.
    """
    print("Starting...")
    # FIX: `start` was referenced in the timing print below but never
    # defined in this function, raising NameError (no module-level `start`
    # is visible in this file).
    start = time.time()

    base = "/home/greg/datasets/rootset"
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # The architecture name is encoded as the checkpoint filename prefix.
    model_file_name = os.path.split(args.model_path)[1]
    model_name = model_file_name[:model_file_name.find("_")]

    data_loader = get_loader(args.dataset)
    loader = data_loader(root=base,
                         mode=args.mode,
                         is_transform=True,
                         img_norm=args.img_norm,
                         test_mode=True)
    n_classes = loader.n_classes

    # Setup image
    img_path = args.image + ".png"
    print("Read Input Image from : {}".format(base + "/images/" + img_path))
    img = io.imread(base + "/images/" + img_path)
    # Ground-truth location depends on the label set.
    # NOTE(review): `gt` stays undefined for any other class count and the
    # metrics update below would then raise — confirm only 2/5-class models
    # are used here.
    if n_classes == 2:
        gt = io.imread(base + "/segmentation8/" + img_path, 0)
    elif n_classes == 5:
        gt = io.imread(base + "/combined/" + img_path, 0)

    running_metrics = runningScore(n_classes)

    orig_size = img.shape[:-1]
    # Resize so width and height are odd numbers.
    # NOTE(review): PIL's resize takes (width, height); (rows, cols) is
    # passed here — confirm this is intended for non-square inputs.
    resized_image = np.array(
        Image.fromarray(img).resize(
            (orig_size[0] // 2 * 2 + 1, orig_size[1] // 2 * 2 + 1)))

    img = img[:, :, ::-1]  # RGB -> BGR
    img = img.astype(np.float64)
    img -= loader.mean
    if args.img_norm:
        img = img.astype(float) / 255.0

    # HWC -> NCHW
    img = img.transpose(2, 0, 1)
    img = np.expand_dims(img, 0)
    img = torch.from_numpy(img).float()

    # Setup Model
    model_dict = {"arch": model_name}
    model = get_model(model_dict, n_classes, version=args.dataset)
    state = convert_state_dict(
        torch.load(args.model_path, map_location='cpu')["model_state"])
    model.load_state_dict(state)
    model.eval()
    model.to(device)

    images = img.to(device)
    print("Running network...")
    outputs = model(images)

    if args.dcrf:
        unary = outputs.data.cpu().numpy()
        unary = np.squeeze(unary, 0)
        unary = -np.log(unary)
        unary = unary.transpose(2, 1, 0)
        w, h, c = unary.shape
        unary = unary.transpose(2, 0, 1).reshape(loader.n_classes, -1)
        unary = np.ascontiguousarray(unary)

        # FIX: the original referenced `resized_img`, whose assignment was
        # commented out above -> NameError in this branch; use the
        # `resized_image` computed earlier instead.
        resized_img = np.ascontiguousarray(resized_image)

        d = dcrf.DenseCRF2D(w, h, loader.n_classes)
        d.setUnaryEnergy(unary)
        d.addPairwiseBilateral(sxy=5, srgb=3, rgbim=resized_img, compat=1)

        q = d.inference(50)
        mask = np.argmax(q, axis=0).reshape(w, h).transpose(1, 0)
        decoded_crf = loader.decode_segmap(np.array(mask, dtype=np.uint8))
        dcrf_path = args.out_path[:-4] + "_drf.png"
        misc.imsave(dcrf_path, decoded_crf)
        print("Dense CRF Processed Mask Saved at: {}".format(dcrf_path))

    pred = np.squeeze(outputs.data.max(1)[1].cpu().numpy(), axis=0)
    if model_name in ["pspnet", "icnet", "icnetBN"]:
        pred = pred.astype(np.float32)
        # float32 with F mode, resize back to orig_size
        pred = misc.imresize(pred, orig_size, "nearest", mode="F")

    running_metrics.update(gt, pred)

    decoded = loader.decode_segmap(pred)
    decoded = decoded.astype(np.float32)
    origimg = io.imread(base + "/images/" + img_path)
    origimg = origimg.astype(np.float32) / 255.0
    # NOTE(review): the module-level decode_segmap is used for the ground
    # truth (not loader.decode_segmap) — confirm both produce the same
    # palette.
    gtimg = decode_segmap(gt)

    stack = np.hstack(
        (origimg, gtimg, decoded[:origimg.shape[0], :origimg.shape[1]]))
    print("Done!")
    print("Time taken:", int((time.time() - start) * 10) / 10.0, "seconds")

    if n_classes == 5:
        score, class_iou = running_metrics.get_scores()

        print("\n===IoU===")
        # Hoisted out of the loop: the name table is loop-invariant.
        classnames = [
            "Background", "Root", "Seed", "Lateral tips", "Primary tips"
        ]
        for i in range(1, n_classes):
            print("%s:\t%.2f" % (classnames[i], class_iou[i]))

    plt.imshow(stack)
    plt.show()
    tosave = (255 * decoded).astype(np.uint8)
    io.imsave("final/output/" + img_path, tosave)
def validate(cfg, args):
    """Few-shot segmentation validation via one-step weight imprinting.

    For every episode drawn from the validation loader: imprint classifier
    weights for the novel class from the support set, predict on the query
    image (no finetuning in this variant), then reverse the imprinting —
    fully when ``args.cl`` is false, otherwise keeping the background update
    for continual learning.

    Prints binary IoU or multi-class confusion-matrix scores depending on
    ``args.binary``, and optionally dumps qualitative results to
    ``args.out_dir``.

    Args:
        cfg: parsed config dict with 'data', 'model' and 'training' sections.
        args: CLI namespace; uses model_path, out_dir, fold, binary and cl.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Create output folders for qualitative dumps (heatmaps, predictions,
    # ground truth, support/query inputs).
    if args.out_dir != "":
        if not os.path.exists(args.out_dir):
            os.mkdir(args.out_dir)
        if not os.path.exists(args.out_dir + 'hmaps_bg'):
            os.mkdir(args.out_dir + 'hmaps_bg')
        if not os.path.exists(args.out_dir + 'hmaps_fg'):
            os.mkdir(args.out_dir + 'hmaps_fg')
        if not os.path.exists(args.out_dir + 'pred'):
            os.mkdir(args.out_dir + 'pred')
        if not os.path.exists(args.out_dir + 'gt'):
            os.mkdir(args.out_dir + 'gt')
        if not os.path.exists(args.out_dir + 'qry_images'):
            os.mkdir(args.out_dir + 'qry_images')
        if not os.path.exists(args.out_dir + 'sprt_images'):
            os.mkdir(args.out_dir + 'sprt_images')
        if not os.path.exists(args.out_dir + 'sprt_gt'):
            os.mkdir(args.out_dir + 'sprt_gt')

    # A fold given on the command line overrides the config.
    if args.fold != -1:
        cfg['data']['fold'] = args.fold

    fold = cfg['data']['fold']

    # Setup Dataloader
    data_loader = get_loader(cfg['data']['dataset'])
    data_path = cfg['data']['path']

    loader = data_loader(
        data_path,
        split=cfg['data']['val_split'],
        is_transform=True,
        img_size=[cfg['data']['img_rows'], cfg['data']['img_cols']],
        n_classes=cfg['data']['n_classes'],
        fold=cfg['data']['fold'],
        binary=args.binary,
        k_shot=cfg['data']['k_shot'])

    n_classes = loader.n_classes

    valloader = data.DataLoader(loader,
                                batch_size=cfg['training']['batch_size'],
                                num_workers=0)
    # Binary mode scores fg/bg only; otherwise score the base classes plus
    # the imprinted novel class.
    if args.binary:
        running_metrics = runningScore(2)
        # Per-novel-class tp/fp/fn tallies for OSLSM-style binary IoU.
        fp_list = {}
        tp_list = {}
        fn_list = {}
    else:
        running_metrics = runningScore(
            n_classes + 1)  #+1 indicate the novel class thats added each time

    # Setup Model
    model = get_model(cfg['model'], n_classes).to(device)
    state = convert_state_dict(torch.load(args.model_path)["model_state"])
    model.load_state_dict(state)
    model.to(device)

    # Without continual learning, snapshot the weights so every episode's
    # imprinting can be fully undone.
    if not args.cl:
        print('No Continual Learning of Bg Class')
        model.save_original_weights()

    # Imprinting blend factor; presumably tuned on validation — TODO confirm.
    alpha = 0.25821
    for i, (sprt_images, sprt_labels, qry_images, qry_labels,
            original_sprt_images, original_qry_images,
            cls_ind) in enumerate(valloader):
        cls_ind = int(cls_ind)
        print('Starting iteration ', i)
        start_time = timeit.default_timer()
        if args.out_dir != "":
            save_images(original_sprt_images, sprt_labels, original_qry_images,
                        i, args.out_dir)

        # Move the k support shots (lists of tensors) to the device.
        for si in range(len(sprt_images)):
            sprt_images[si] = sprt_images[si].to(device)
            sprt_labels[si] = sprt_labels[si].to(device)
        qry_images = qry_images.to(device)

        # 1- Extract embedding and add the imprinted weights
        model.imprint(sprt_images, sprt_labels, alpha=alpha)

        # 2- Infer on the query image
        model.eval()
        with torch.no_grad():
            outputs = model(qry_images)
            pred = outputs.data.max(1)[1].cpu().numpy()

        # Reverse the last imprinting (Few shot setting only not Continual Learning setup yet)
        model.reverse_imprinting(args.cl)

        gt = qry_labels.numpy()
        if args.binary:
            gt, pred = post_process(gt, pred)

        if args.binary:
            if args.binary == 1:
                # OSLSM-style scoring: accumulate tp/fp/fn per novel class.
                tp, fp, fn = running_metrics.update_binary_oslsm(gt, pred)

                if cls_ind in fp_list.keys():
                    fp_list[cls_ind] += fp
                else:
                    fp_list[cls_ind] = fp

                if cls_ind in tp_list.keys():
                    tp_list[cls_ind] += tp
                else:
                    tp_list[cls_ind] = tp

                if cls_ind in fn_list.keys():
                    fn_list[cls_ind] += fn
                else:
                    fn_list[cls_ind] = fn

            else:
                running_metrics.update(gt, pred)
        else:
            running_metrics.update(gt, pred)

        if args.out_dir != "":
            if args.binary:
                save_vis(outputs, pred, gt, i, args.out_dir, fg_class=1)
            else:
                save_vis(outputs, pred, gt, i, args.out_dir)

    # Final report: per-class IoU from the tallies (binary OSLSM) or
    # confusion-matrix scores (everything else).
    if args.binary:
        if args.binary == 1:
            iou_list = [tp_list[ic]/float(max(tp_list[ic] + fp_list[ic] + fn_list[ic],1)) \
                         for ic in tp_list.keys()]
            print("Binary Mean IoU ", np.mean(iou_list))
        else:
            score, class_iou = running_metrics.get_scores()
            for k, v in score.items():
                print(k, v)
    else:
        score, class_iou = running_metrics.get_scores()
        for k, v in score.items():
            print(k, v)
        val_nclasses = model.n_classes + 1
        for i in range(val_nclasses):
            print(i, class_iou[i])
Example #31
0
def train(args):
    """Train a segmentation model with the legacy argparse interface.

    Builds an augmented train loader and a plain val loader for
    ``args.dataset``, wraps the model in DataParallel, optionally resumes
    from ``args.resume``, trains for ``args.n_epoch`` epochs, validates
    after every epoch, and checkpoints whenever mean IoU improves.

    Fixes: ``loss.data[0]`` was removed in torch >= 0.5 and
    ``Variable(..., volatile=True)`` is a removed no-op; this uses
    ``loss.item()`` and ``torch.no_grad()``, as done elsewhere in this file.
    """
    # Setup Augmentations
    data_aug = Compose([RandomRotate(10),
                        RandomHorizontallyFlip()])

    # Setup Dataloader
    data_loader = get_loader(args.dataset)
    data_path = get_data_path(args.dataset)
    t_loader = data_loader(data_path, is_transform=True, img_size=(args.img_rows, args.img_cols), augmentations=data_aug)
    v_loader = data_loader(data_path, is_transform=True, split='val', img_size=(args.img_rows, args.img_cols))

    n_classes = t_loader.n_classes
    trainloader = data.DataLoader(t_loader, batch_size=args.batch_size, num_workers=8, shuffle=True)
    valloader = data.DataLoader(v_loader, batch_size=args.batch_size, num_workers=8)

    # Setup Metrics
    running_metrics = runningScore(n_classes)

    # Setup visdom for visualization
    if args.visdom:
        vis = visdom.Visdom()

        loss_window = vis.line(X=torch.zeros((1,)).cpu(),
                           Y=torch.zeros((1)).cpu(),
                           opts=dict(xlabel='minibatches',
                                     ylabel='Loss',
                                     title='Training Loss',
                                     legend=['Loss']))

    # Setup Model
    model = get_model(args.arch, n_classes)
    model = torch.nn.DataParallel(model, device_ids=range(torch.cuda.device_count()))
    model.cuda()

    # Check if model has custom optimizer / loss
    if hasattr(model.module, 'optimizer'):
        optimizer = model.module.optimizer
    else:
        optimizer = torch.optim.SGD(model.parameters(), lr=args.l_rate, momentum=0.99, weight_decay=5e-4)

    if hasattr(model.module, 'loss'):
        print('Using custom loss')
        loss_fn = model.module.loss
    else:
        loss_fn = cross_entropy2d

    # Optionally resume model and optimizer state.
    if args.resume is not None:
        if os.path.isfile(args.resume):
            print("Loading model and optimizer from checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            model.load_state_dict(checkpoint['model_state'])
            optimizer.load_state_dict(checkpoint['optimizer_state'])
            print("Loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
        else:
            print("No checkpoint found at '{}'".format(args.resume))

    best_iou = -100.0
    for epoch in range(args.n_epoch):
        model.train()
        for i, (images, labels) in enumerate(trainloader):
            # Variable() has been a no-op since torch 0.4; move directly.
            images = images.cuda()
            labels = labels.cuda()

            optimizer.zero_grad()
            outputs = model(images)

            loss = loss_fn(input=outputs, target=labels)

            loss.backward()
            optimizer.step()

            if args.visdom:
                vis.line(
                    X=torch.ones((1, 1)).cpu() * i,
                    # .item() replaces the removed loss.data[0] accessor.
                    Y=torch.Tensor([loss.item()]).unsqueeze(0).cpu(),
                    win=loss_window,
                    update='append')

            if (i+1) % 20 == 0:
                print("Epoch [%d/%d] Loss: %.4f" % (epoch+1, args.n_epoch, loss.item()))

        model.eval()
        # torch.no_grad() replaces the removed volatile=True flag.
        with torch.no_grad():
            for i_val, (images_val, labels_val) in tqdm(enumerate(valloader)):
                images_val = images_val.cuda()
                labels_val = labels_val.cuda()

                outputs = model(images_val)
                pred = outputs.data.max(1)[1].cpu().numpy()
                gt = labels_val.data.cpu().numpy()
                running_metrics.update(gt, pred)

        score, class_iou = running_metrics.get_scores()
        for k, v in score.items():
            print(k, v)
        running_metrics.reset()

        # Keep the checkpoint with the best validation mean IoU.
        if score['Mean IoU : \t'] >= best_iou:
            best_iou = score['Mean IoU : \t']
            state = {'epoch': epoch+1,
                     'model_state': model.state_dict(),
                     'optimizer_state' : optimizer.state_dict(),}
            torch.save(state, "{}_{}_best_model.pkl".format(args.arch, args.dataset))
Example #32
0
def train(cfg, writer, logger):
    """Train a crop-segmentation model on multiband satellite patches.

    Args:
        cfg: run configuration mapping; this function reads the
            'training' (batch_size, n_workers, optimizer, lr_schedule,
            loss, resume, train_iters, print_interval, val_interval),
            'model' and 'data' sections.
        writer: TensorBoard-style summary writer; scalars are logged to it
            and the best checkpoint is saved under its log directory.
        logger: logging.Logger for textual progress output.

    Side effects:
        Logs train/val scalars, prints per-class pixel counts, and saves the
        best-mean-IoU checkpoint as ``<arch>_<dataset>_best_model.pkl`` in
        the writer's logdir.
    """

    # Setup seeds
    torch.manual_seed(cfg.get('seed', 1337))
    torch.cuda.manual_seed(cfg.get('seed', 1337))
    np.random.seed(cfg.get('seed', 1337))
    random.seed(cfg.get('seed', 1337))

    # Setup device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Setup Augmentations
    augmentations = cfg['training'].get('augmentations', None)
    data_aug = get_composed_augmentations(augmentations)

    # Setup Dataloader
    # NOTE: the generic cfg-driven loader below was replaced by the
    # hard-coded satellite-patch pipeline that follows.
    #    data_loader = get_loader(cfg['data']['dataset'])
    #    data_path = cfg['data']['path']
    #
    #    t_loader = data_loader(
    #        data_path,
    #        is_transform=True,
    #        split=cfg['data']['train_split'],
    #        img_size=(cfg['data']['img_rows'], cfg['data']['img_cols']),
    #        augmentations=data_aug)
    #
    #    v_loader = data_loader(
    #        data_path,
    #        is_transform=True,
    #        split=cfg['data']['val_split'],
    #        img_size=(cfg['data']['img_rows'], cfg['data']['img_cols']),)
    #
    #    n_classes = t_loader.n_classes
    #    trainloader = data.DataLoader(t_loader,
    #                                  batch_size=cfg['training']['batch_size'],
    #                                  num_workers=cfg['training']['n_workers'],
    #                                  shuffle=True)
    #
    #    valloader = data.DataLoader(v_loader,
    #                                batch_size=cfg['training']['batch_size'],
    #                                num_workers=cfg['training']['n_workers'])

    # Per-band directories for the training split: ground truth masks plus
    # RGB, NIR, SWIR, VH/VV (radar) and red-edge / NDVI channels.
    paths = {
        'masks': './satellitedata/patchvai_train/gt/',
        'images': './satellitedata/patchvai_train/rgb',
        'nirs': './satellitedata/patchvai_train/nir',
        'swirs': './satellitedata/patchvai_train/swir',
        'vhs': './satellitedata/patchvai_train/vh',
        'vvs': './satellitedata/patchvai_train/vv',
        'redes': './satellitedata/patchvai_train/rede',
        'ndvis': './satellitedata/patchvai_train/ndvi',
    }

    # Same band layout for the validation split.
    valpaths = {
        'masks': './satellitedata/patchvai_val/gt/',
        'images': './satellitedata/patchvai_val/rgb',
        'nirs': './satellitedata/patchvai_val/nir',
        'swirs': './satellitedata/patchvai_val/swir',
        'vhs': './satellitedata/patchvai_val/vh',
        'vvs': './satellitedata/patchvai_val/vv',
        'redes': './satellitedata/patchvai_val/rede',
        'ndvis': './satellitedata/patchvai_val/ndvi',
    }

    # Three classes: 0 = background, 1 = corn, 2 = soybean (cotton/rice
    # classes are commented out below).
    n_classes = 3
    # Filenames containing '_01_' or '_25_' are excluded from both splits —
    # presumably border/edge tiles; TODO confirm against the patch naming.
    train_img_paths = [
        pth for pth in os.listdir(paths['images'])
        if ('_01_' not in pth) and ('_25_' not in pth)
    ]
    val_img_paths = [
        pth for pth in os.listdir(valpaths['images'])
        if ('_01_' not in pth) and ('_25_' not in pth)
    ]
    ntrain = len(train_img_paths)
    nval = len(val_img_paths)
    # Datasets are indexed positionally, so the index lists are just 0..n-1.
    train_idx = [i for i in range(ntrain)]
    val_idx = [i for i in range(nval)]
    trainds = ImageProvider(MultibandImageType, paths, image_suffix='.png')
    valds = ImageProvider(MultibandImageType, valpaths, image_suffix='.png')

    # Load the PSPNet JSON config and point its dataset_path at the local
    # satellite data root.
    config_path = 'crop_pspnet_config.json'
    with open(config_path, 'r') as f:
        mycfg = json.load(f)
        train_data_path = './satellitedata/'
        print('train_data_path: {}'.format(train_data_path))
        dataset_path, train_dir = os.path.split(train_data_path)
        print('dataset_path: {}'.format(dataset_path) +
              ',  train_dir: {}'.format(train_dir))
        mycfg['dataset_path'] = dataset_path
    config = Config(**mycfg)

    # 12 input channels = all bands stacked; epochs fixed at 50 here.
    config = update_config(config, num_channels=12, nb_epoch=50)
    #dataset_train = TrainDataset(trainds, train_idx, config, transforms=augment_flips_color)
    dataset_train = TrainDataset(trainds, train_idx, config, 1)
    dataset_val = TrainDataset(valds, val_idx, config, 1)
    trainloader = data.DataLoader(dataset_train,
                                  batch_size=cfg['training']['batch_size'],
                                  num_workers=cfg['training']['n_workers'],
                                  shuffle=True)

    valloader = data.DataLoader(dataset_val,
                                batch_size=cfg['training']['batch_size'],
                                num_workers=cfg['training']['n_workers'],
                                shuffle=False)
    # Setup Metrics
    running_metrics_train = runningScore(n_classes)
    running_metrics_val = runningScore(n_classes)

    # Full pass over the training loader purely to count per-class pixel
    # frequencies (used to derive loss weights below). This costs one extra
    # epoch of data loading before training starts.
    k = 0  # batch counter (reused later as a dict-iteration name; harmless)
    nbackground = 0
    ncorn = 0
    #ncotton = 0
    #nrice = 0
    nsoybean = 0

    for indata in trainloader:
        k += 1
        gt = indata['seg_label'].data.cpu().numpy()
        nbackground += (gt == 0).sum()
        ncorn += (gt == 1).sum()
        #ncotton += (gt == 2).sum()
        #nrice += (gt == 3).sum()
        nsoybean += (gt == 2).sum()

    print('k = {}'.format(k))
    print('nbackgraound: {}'.format(nbackground))
    print('ncorn: {}'.format(ncorn))
    #print('ncotton: {}'.format(ncotton))
    #print('nrice: {}'.format(nrice))
    print('nsoybean: {}'.format(nsoybean))

    # Inverse-frequency class weights relative to the background count,
    # normalized to sum to 1; rarer classes get proportionally larger weight.
    wgts = [1.0, 1.0 * nbackground / ncorn, 1.0 * nbackground / nsoybean]
    total_wgts = sum(wgts)
    wgt_background = wgts[0] / total_wgts
    wgt_corn = wgts[1] / total_wgts
    #wgt_cotton = wgts[2]/total_wgts
    #wgt_rice = wgts[3]/total_wgts
    wgt_soybean = wgts[2] / total_wgts
    weights = torch.autograd.Variable(
        torch.cuda.FloatTensor([wgt_background, wgt_corn, wgt_soybean]))

    #weights = torch.autograd.Variable(torch.cuda.FloatTensor([1.0, 1.0, 1.0]))

    # Setup Model
    model = get_model(cfg['model'], n_classes).to(device)

    model = torch.nn.DataParallel(model,
                                  device_ids=range(torch.cuda.device_count()))

    # Setup optimizer, lr_scheduler and loss function
    optimizer_cls = get_optimizer(cfg)
    # Everything under cfg['training']['optimizer'] except 'name' is passed
    # through as optimizer keyword arguments (lr, momentum, ...).
    optimizer_params = {
        k: v
        for k, v in cfg['training']['optimizer'].items() if k != 'name'
    }

    optimizer = optimizer_cls(model.parameters(), **optimizer_params)
    logger.info("Using optimizer {}".format(optimizer))

    scheduler = get_scheduler(optimizer, cfg['training']['lr_schedule'])

    loss_fn = get_loss_function(cfg)
    logger.info("Using loss {}".format(loss_fn))

    # Optionally resume model/optimizer/scheduler state and the iteration
    # counter from a previous checkpoint.
    start_iter = 0
    if cfg['training']['resume'] is not None:
        if os.path.isfile(cfg['training']['resume']):
            logger.info(
                "Loading model and optimizer from checkpoint '{}'".format(
                    cfg['training']['resume']))
            checkpoint = torch.load(cfg['training']['resume'])
            model.load_state_dict(checkpoint["model_state"])
            optimizer.load_state_dict(checkpoint["optimizer_state"])
            scheduler.load_state_dict(checkpoint["scheduler_state"])
            start_iter = checkpoint["epoch"]
            logger.info("Loaded checkpoint '{}' (iter {})".format(
                cfg['training']['resume'], checkpoint["epoch"]))
        else:
            logger.info("No checkpoint found at '{}'".format(
                cfg['training']['resume']))

    val_loss_meter = averageMeter()
    time_meter = averageMeter()

    best_iou = -100.0
    i = start_iter
    flag = True  # cleared to exit the outer while once train_iters is hit

    # Iteration-driven training loop: cycles over the train loader until
    # cfg['training']['train_iters'] total iterations have run.
    while i <= cfg['training']['train_iters'] and flag:
        for inputdata in trainloader:
            i += 1
            start_ts = time.time()
            # NOTE(review): scheduler.step() is called before
            # optimizer.step(); PyTorch >= 1.1 expects the opposite order
            # (optimizer first) — confirm against the torch version in use.
            scheduler.step()
            model.train()
            images = inputdata['img_data']
            labels = inputdata['seg_label']
            #print('images.size: {}'.format(images.size()))
            #print('labels.size: {}'.format(labels.size()))
            images = images.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            outputs = model(images)

            #print('outputs.size: {}'.format(outputs[1].size()))
            #print('labels.size: {}'.format(labels.size()))

            # outputs[1] — presumably the model returns (aux_out, main_out)
            # in train mode (PSPNet-style aux head); TODO confirm.
            # Class weights from the frequency pass are applied here.
            loss = loss_fn(input=outputs[1], target=labels, weight=weights)

            loss.backward()
            optimizer.step()

            time_meter.update(time.time() - start_ts)

            if (i + 1) % cfg['training']['print_interval'] == 0:
                fmt_str = "Iter [{:d}/{:d}]  Loss: {:.4f}  Time/Image: {:.4f}"
                print_str = fmt_str.format(
                    i + 1, cfg['training']['train_iters'], loss.item(),
                    time_meter.avg / cfg['training']['batch_size'])

                print(print_str)
                logger.info(print_str)
                writer.add_scalar('loss/train_loss', loss.item(), i + 1)
                time_meter.reset()

            # Periodic validation (and once at the final iteration).
            if (i + 1) % cfg['training']['val_interval'] == 0 or \
               (i + 1) == cfg['training']['train_iters']:
                model.eval()
                with torch.no_grad():
                    for inputdata in valloader:
                        images_val = inputdata['img_data']
                        labels_val = inputdata['seg_label']
                        images_val = images_val.to(device)
                        labels_val = labels_val.to(device)

                        # In eval mode the model apparently returns a single
                        # tensor (no aux head), hence no [1] indexing here.
                        outputs = model(images_val)
                        # NOTE(review): validation loss is computed WITHOUT
                        # the class weights used for training — intentional?
                        val_loss = loss_fn(input=outputs, target=labels_val)

                        pred = outputs.data.max(1)[1].cpu().numpy()
                        gt = labels_val.data.cpu().numpy()

                        running_metrics_val.update(gt, pred)
                        val_loss_meter.update(val_loss.item())

                writer.add_scalar('loss/val_loss', val_loss_meter.avg, i + 1)
                logger.info("Iter %d Loss: %.4f" % (i + 1, val_loss_meter.avg))

                score, class_iou = running_metrics_val.get_scores()
                for k, v in score.items():
                    print(k, v)
                    logger.info('{}: {}'.format(k, v))
                    writer.add_scalar('val_metrics/{}'.format(k), v, i + 1)

                for k, v in class_iou.items():
                    logger.info('{}: {}'.format(k, v))
                    writer.add_scalar('val_metrics/cls_{}'.format(k), v, i + 1)

                val_loss_meter.reset()
                running_metrics_val.reset()

                # Checkpoint only when mean IoU improves on the best so far.
                if score["Mean IoU : \t"] >= best_iou:
                    best_iou = score["Mean IoU : \t"]
                    state = {
                        "epoch": i + 1,
                        "model_state": model.state_dict(),
                        "optimizer_state": optimizer.state_dict(),
                        "scheduler_state": scheduler.state_dict(),
                        "best_iou": best_iou,
                    }
                    save_path = os.path.join(
                        writer.file_writer.get_logdir(),
                        "{}_{}_best_model.pkl".format(cfg['model']['arch'],
                                                      cfg['data']['dataset']))
                    torch.save(state, save_path)

            # Stop mid-epoch once the iteration budget is exhausted.
            if (i + 1) == cfg['training']['train_iters']:
                flag = False
                break