Example #1
def test(cfg, writer, logger):
    torch.manual_seed(cfg.get('seed', 1337))
    torch.cuda.manual_seed(cfg.get('seed', 1337))
    np.random.seed(cfg.get('seed', 1337))
    random.seed(cfg.get('seed', 1337))
    ## create dataset
    default_gpu = cfg['model']['default_gpu']
    device = torch.device(
        "cuda:{}".format(default_gpu) if torch.cuda.is_available() else 'cpu')
    datasets = create_dataset(
        cfg, writer, logger
    )  # source_train / target_train / source_valid / target_valid (+ matching _loader attributes)

    model = CustomModel(cfg, writer, logger)
    running_metrics_val = runningScore(cfg['data']['target']['n_class'])
    source_running_metrics_val = runningScore(cfg['data']['target']['n_class'])
    val_loss_meter = averageMeter()
    source_val_loss_meter = averageMeter()
    time_meter = averageMeter()
    loss_fn = get_loss_function(cfg)
    path = cfg['test']['path']
    checkpoint = torch.load(path)
    model.adaptive_load_nets(model.BaseNet,
                             checkpoint['DeepLab']['model_state'])

    validation(
                model, logger, writer, datasets, device, running_metrics_val, val_loss_meter, loss_fn,\
                source_val_loss_meter, source_running_metrics_val, iters = model.iter
                )
Example #2
def test(data_loader, model_fe, model_cls, thres):

    # setup average meter
    ACC = averageMeter()
    CPB = averageMeter()

    # set evaluation mode
    model_fe.eval()
    model_cls.eval()

    for (step, value) in enumerate(data_loader):

        image = value[0].cuda()
        target = value[1].cuda(non_blocking=True)  # 'async' was renamed to 'non_blocking' ('async' is a reserved keyword in Python 3.7+)

        # forward
        _, imfeat = model_fe(x=image, feat=True)
        output, nout = model_cls(x=imfeat, gate=None, pred=True, thres=thres)

        # get predictions
        max_z = torch.max(output, dim=1)[0]
        preds = torch.eq(output, max_z.view(-1, 1))

        # get boolean of correct prediction (correct: 1, incorrect: 0) and measure accuracy
        iscorrect = torch.gather(preds, 1, target.view(-1, 1)).flatten().float().cpu().data.numpy()
        ACC.update(np.mean(iscorrect), image.size(0))

        # compute PB & CPB
        Y_v = torch.sum(preds.float(), dim=1).data.cpu().numpy()
        Y_v[Y_v == 1] = 0
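        # NOTE: num_class is assumed to be a module-level global (it is not defined in this snippet)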
        pb = (num_class - Y_v) / num_class
        cpb = pb * iscorrect
        CPB.update(np.mean(cpb), image.size(0))

    return ACC.avg, CPB.avg
Example #3
def train(data_loader, model_fe, model_cls, model_pivot, opt_main, epoch,
          criterion):
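    # NOTE: n_nodes, node_labels, lmbda, n_step, cfg, logger and writer are
    # assumed to be module-level globals that are not shown in this snippet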

    # setup average meters
    batch_time = averageMeter()
    data_time = averageMeter()
    nlosses = averageMeter()
    stslosses = averageMeter()
    losses = averageMeter()
    acc = averageMeter()

    # setting training mode
    model_fe.train()
    model_cls.train()
    model_pivot.train()

    end = time.time()
    for (step, value) in enumerate(data_loader):

        # measure data loading time
        data_time.update(time.time() - end)

        image = value[0].cuda()
        target = value[1].cuda(non_blocking=True)  # 'async' was renamed to 'non_blocking' ('async' is a reserved keyword in Python 3.7+)

        # forward
        _, imfeat = model_fe(x=image, feat=True)
        gate = model_pivot(torch.ones([image.size(0), n_nodes]))
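        # presumably keeps the gate for the root (first) node always open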
        gate[:, 0] = 1
        output, nout, sfmx_base = model_cls(x=imfeat, gate=gate)

        # compute node-conditional consistency loss at each node for each sample
        nloss = []
        for idx in range(image.size(0)):
            for n_id, n_l in node_labels[value[1].numpy()[idx]]:
                nloss.append(
                    criterion(nout[n_id][idx, :].view(1, -1),
                              torch.tensor([n_l]).cuda()))

        nloss = torch.mean(torch.stack(nloss))
        nlosses.update(nloss.item(), image.size(0))

        # compute stochastic tree sampling loss
        gt_z = torch.gather(output, 1, target.view(-1, 1))
        stsloss = torch.mean(
            -gt_z + torch.log(torch.clamp(sfmx_base.view(-1, 1), 1e-17, 1e17)))
        stslosses.update(stsloss.item(), image.size(0))

        loss = nloss + stsloss * lmbda
        losses.update(loss.item(), image.size(0))

        # measure accuracy
        max_z = torch.max(output, dim=1)[0]
        preds = torch.eq(output, max_z.view(-1, 1))
        iscorrect = torch.gather(preds, 1, target.view(-1,
                                                       1)).flatten().float()
        acc.update(torch.mean(iscorrect).item(), image.size(0))

        # back propagation
        opt_main.zero_grad()
        loss.backward()
        opt_main.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if (step + 1) % 10 == 0:
            curr_lr_main = opt_main.param_groups[0]['lr']
            print_str = 'Epoch [{0}/{1}]\t' \
                'Step: [{2}/{3}]\t' \
                'LR: [{4}]\t' \
                'Time {batch_time.avg:.3f}\t' \
                'Data {data_time.avg:.3f}\t' \
                'Loss {loss.avg:.4f}\t' \
                'Acc {acc.avg:.3f}'.format(
                    epoch + 1, cfg['training']['epoch'], step + 1, n_step, curr_lr_main, batch_time=batch_time,
                    data_time=data_time, loss=losses, acc=acc
                )

            print(print_str)
            logger.info(print_str)

    if (epoch + 1) % cfg['training']['print_interval'] == 0:
        curr_lr_main = opt_main.param_groups[0]['lr']
        print_str = 'Epoch: [{0}/{1}]\t' \
            'LR: [{2}]\t' \
            'Time {batch_time.avg:.3f}\t' \
            'Data {data_time.avg:.3f}\t' \
            'Loss {loss.avg:.4f}\t' \
            'Acc {acc.avg:.3f}'.format(
                epoch + 1, cfg['training']['epoch'], curr_lr_main, batch_time=batch_time,
                data_time=data_time, loss=losses, acc=acc
            )

        print(print_str)
        logger.info(print_str)
        writer.add_scalar('train/lr', curr_lr_main, epoch + 1)
        writer.add_scalar('train/nloss', nlosses.avg, epoch + 1)
        writer.add_scalar('train/stsloss', stslosses.avg, epoch + 1)
        writer.add_scalar('train/loss', losses.avg, epoch + 1)
        writer.add_scalar('train/acc', acc.avg, epoch + 1)
Example #4
def train(cfg):
    
    # Setup seeds
    torch.manual_seed(cfg.get('seed', 1337))
    torch.cuda.manual_seed(cfg.get('seed', 1337))
    np.random.seed(cfg.get('seed', 1337))
    random.seed(cfg.get('seed', 1337))

    # Setup device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Setup Augmentations
    augmentations = cfg['training'].get('augmentations', None)
    data_aug = get_composed_augmentations(augmentations)

    # Setup Dataloader
    data_loader = get_loader(cfg['data']['dataset'])
    data_path = cfg['data']['path']

    t_loader = data_loader(
        data_path,
        is_transform=True,
        split=cfg['data']['train_split'],
        #img_size=(cfg['data']['img_rows'], cfg['data']['img_cols']),
        augmentations=data_aug)

    v_loader = data_loader(
        data_path,
        is_transform=True,
        split=cfg['data']['val_split'],
        #img_size=(cfg['data']['img_rows'], cfg['data']['img_cols']),
        )

    n_classes = t_loader.n_classes
    trainloader = data.DataLoader(t_loader,
                                  batch_size=cfg['training']['batch_size'], 
                                  num_workers=cfg['training']['n_workers'], 
                                  shuffle=True)

    valloader = data.DataLoader(v_loader, 
                                batch_size=cfg['training']['batch_size'], 
                                num_workers=cfg['training']['n_workers'])

    # Setup Metrics
    running_metrics_val = runningScore(n_classes)

    # Setup Model
    model = get_model(cfg['model'], n_classes).to(device)

    model = torch.nn.DataParallel(model, device_ids=range(torch.cuda.device_count()))

    # Setup optimizer, lr_scheduler and loss function
    optimizer_cls = get_optimizer(cfg)
    optimizer_params = {k:v for k, v in cfg['training']['optimizer'].items() 
                        if k != 'name'}

    optimizer = optimizer_cls(model.parameters(), **optimizer_params)
    scheduler = get_scheduler(optimizer, cfg['training']['lr_schedule'])

    loss_fn = get_loss_function(cfg)

    start_iter = 0
    if cfg['training']['resume'] is not None:
        if os.path.isfile(cfg['training']['resume']):
 
            checkpoint = torch.load(cfg['training']['resume'])
            model.load_state_dict(checkpoint["model_state"])
            optimizer.load_state_dict(checkpoint["optimizer_state"])
            scheduler.load_state_dict(checkpoint["scheduler_state"])
            start_iter = checkpoint["epoch"]
            print("=====>",
                "Loaded checkpoint '{}' (iter {})".format(
                    cfg['training']['resume'], checkpoint["epoch"]
                )
            )
        else:
            print("=====>","No checkpoint found at '{}'".format(cfg['training']['resume']))

    val_loss_meter = averageMeter()
    time_meter = averageMeter()

    best_iou = -100.0
    i = start_iter
    flag = True

    while i <= cfg['training']['train_iters'] and flag:
        for (images, labels) in trainloader:
            i += 1
            start_ts = time.time()
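            # note: PyTorch >= 1.1 expects scheduler.step() to be called after optimizer.step()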
            scheduler.step()
            model.train()
            images = images.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            outputs = model(images)

            loss = loss_fn(input=outputs, target=labels)

            loss.backward()
            optimizer.step()
            
            time_meter.update(time.time() - start_ts)

            if (i + 1) % cfg['training']['print_interval'] == 0:
                fmt_str = "Iter [{:d}/{:d}]  Loss: {:.4f}  Time/Image: {:.4f}"
                print_str = fmt_str.format(i + 1,
                                           cfg['training']['train_iters'], 
                                           loss.item(),
                                           time_meter.avg / cfg['training']['batch_size'])

                print(print_str)
                time_meter.reset()

            if (i + 1) % cfg['training']['val_interval'] == 0 or \
               (i + 1) == cfg['training']['train_iters']:
                model.eval()
                with torch.no_grad():
                    for i_val, (images_val, labels_val) in tqdm(enumerate(valloader)):
                        images_val = images_val.to(device)
                        labels_val = labels_val.to(device)

                        outputs = model(images_val)
                        val_loss = loss_fn(input=outputs, target=labels_val)

                        pred = outputs.data.max(1)[1].cpu().numpy()
                        gt = labels_val.data.cpu().numpy()


                        running_metrics_val.update(gt, pred)
                        val_loss_meter.update(val_loss.item())


                print("Iter %d Loss: %.4f" % (i + 1, val_loss_meter.avg))

                score, class_iou = running_metrics_val.get_scores()
                for k, v in score.items():
                    print(k,':',v)

                for k, v in class_iou.items():
                    print('{}: {}'.format(k, v))

                val_loss_meter.reset()
                running_metrics_val.reset()

                if score["Mean IoU : \t"] >= best_iou:
                    best_iou = score["Mean IoU : \t"]
                    state = {
                        "epoch": i + 1,
                        "model_state": model.state_dict(),
                        "optimizer_state": optimizer.state_dict(),
                        "scheduler_state": scheduler.state_dict(),
                        "best_iou": best_iou,
                    }
                    save_path = os.path.join('./checkpoint',
                                             "{}_{}_best_model.pkl".format(
                                                 cfg['model']['arch'],
                                                 cfg['data']['dataset']))
                    print("saving···")
                    torch.save(state, save_path)

            if (i + 1) == cfg['training']['train_iters']:
                flag = False
                break
Example #5
def train(cfg, logger):

    # Setup Seeds
    torch.manual_seed(cfg.get("seed", 1337))
    torch.cuda.manual_seed(cfg.get("seed", 1337))
    np.random.seed(cfg.get("seed", 1337))
    random.seed(cfg.get("seed", 1337))

    # Setup Device
    device = torch.device("cuda:{}".format(cfg["training"]["gpu_idx"])
                          if torch.cuda.is_available() else "cpu")

    # Setup Augmentations
    augmentations = cfg["training"].get("augmentations", None)

    # Setup Dataloader
    data_loader = get_loader(cfg["data"]["dataset"])
    data_path = cfg["data"]["path"]

    t_loader = data_loader(
        data_path,
        split=cfg["data"]["train_split"],
    )

    v_loader = data_loader(
        data_path,
        split=cfg["data"]["val_split"],
    )

    n_classes = t_loader.n_classes
    n_val = len(v_loader.files['val'])

    trainloader = data.DataLoader(
        t_loader,
        batch_size=cfg["training"]["batch_size"],
        num_workers=cfg["training"]["n_workers"],
        shuffle=True,
    )

    valloader = data.DataLoader(v_loader,
                                batch_size=cfg["training"]["batch_size"],
                                num_workers=cfg["training"]["n_workers"])

    # Setup Metrics
    running_metrics_val = runningScore(n_classes, n_val)

    # Setup Model
    model = get_model(cfg["model"], n_classes).to(device)
    model = torch.nn.DataParallel(model,
                                  device_ids=[cfg["training"]["gpu_idx"]])

    # Setup Optimizer, lr_scheduler and Loss Function
    optimizer_cls = get_optimizer(cfg)
    optimizer_params = {
        k: v
        for k, v in cfg["training"]["optimizer"].items() if k != "name"
    }

    optimizer = optimizer_cls(model.parameters(), **optimizer_params)
    logger.info("Using optimizer {}".format(optimizer))

    scheduler = get_scheduler(optimizer, cfg["training"]["lr_schedule"])

    loss_fn = get_loss_function(cfg)
    logger.info("Using loss {}".format(loss_fn))

    # Resume Trained Model
    start_iter = 0
    if cfg["training"]["resume"] is not None:
        if os.path.isfile(cfg["training"]["resume"]):
            logger.info(
                "Loading model and optimizer from checkpoint '{}'".format(
                    cfg["training"]["resume"]))
            checkpoint = torch.load(cfg["training"]["resume"])
            model.load_state_dict(checkpoint["model_state"])
            optimizer.load_state_dict(checkpoint["optimizer_state"])
            scheduler.load_state_dict(checkpoint["scheduler_state"])
            start_iter = checkpoint["epoch"]
            logger.info("Loaded checkpoint '{}' (iter {})".format(
                cfg["training"]["resume"], checkpoint["epoch"]))
        else:
            logger.info("No checkpoint found at '{}'".format(
                cfg["training"]["resume"]))

    # Start Training
    val_loss_meter = averageMeter()
    time_meter = averageMeter()

    best_dice = -100.0
    i = start_iter
    flag = True

    while i <= cfg["training"]["train_iters"] and flag:
        for (images, labels, img_name) in trainloader:
            i += 1
            start_ts = time.time()
            scheduler.step()
            model.train()
            images = images.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            outputs = model(images)

            loss = loss_fn(input=outputs, target=labels)

            loss.backward()
            optimizer.step()

            time_meter.update(time.time() - start_ts)

            # print train loss
            if (i + 1) % cfg["training"]["print_interval"] == 0:
                fmt_str = "Iter [{:d}/{:d}]  Loss: {:.4f}  Time/Image: {:.4f}"
                print_str = fmt_str.format(
                    i + 1,
                    cfg["training"]["train_iters"],
                    loss.item(),
                    time_meter.avg / cfg["training"]["batch_size"],
                )

                print(print_str)
                logger.info(print_str)
                time_meter.reset()

            # validation
            if (i + 1) % cfg["training"]["val_interval"] == 0 or (
                    i + 1) == cfg["training"]["train_iters"]:
                model.eval()
                with torch.no_grad():
                    for i_val, (images_val, labels_val,
                                img_name_val) in tqdm(enumerate(valloader)):
                        images_val = images_val.to(device)
                        labels_val = labels_val.to(device)

                        outputs = model(images_val)
                        val_loss = loss_fn(input=outputs, target=labels_val)

                        pred = outputs.data.max(1)[1].cpu().numpy()
                        gt = labels_val.data.cpu().numpy()

                        running_metrics_val.update(gt, pred, i_val)
                        val_loss_meter.update(val_loss.item())

                logger.info("Iter %d Loss: %.4f" % (i + 1, val_loss_meter.avg))

                # print val metrics
                score, class_dice = running_metrics_val.get_scores()
                for k, v in score.items():
                    print(k, v)
                    logger.info("{}: {}".format(k, v))

                for k, v in class_dice.items():
                    logger.info("{}: {}".format(k, v))

                val_loss_meter.reset()
                running_metrics_val.reset()

                # save model
                if score["Dice : \t"] >= best_dice:
                    best_dice = score["Dice : \t"]
                    state = {
                        "epoch": i + 1,
                        "model_state": model.state_dict(),
                        "optimizer_state": optimizer.state_dict(),
                        "scheduler_state": scheduler.state_dict(),
                        "best_dice": best_dice,
                    }
                    save_path = os.path.join(
                        cfg["training"]["model_dir"],
                        "{}_{}.pkl".format(cfg["model"]["arch"],
                                           cfg["data"]["dataset"]),
                    )
                    torch.save(state, save_path)

            if (i + 1) == cfg["training"]["train_iters"]:
                flag = False
                break
    def train_epoch(self):
        if self.epoch % self.val_epoch == 0 or self.epoch == 1:
            self.validate()

        self.model.train()
        train_metrics = runningScore(self.n_classes)
        train_loss_meter = averageMeter()

        self.optim.zero_grad()

        for rgb, ir, target in tqdm.tqdm(
                self.train_loader, total=len(self.train_loader),
                desc=f'Train epoch={self.epoch}', ncols=80, leave=False):

            self.iter += 1
            assert self.model.training

            rgb, ir, target = rgb.to(self.device), ir.to(self.device), target.to(self.device)
            score = self.model(rgb, ir)
            # score = self.model(rgb)

            weight = self.train_loader.dataset.class_weight
            if weight:
                weight = torch.Tensor(weight).to(self.device)

            loss = CrossEntropyLoss(score, target, weight=weight, ignore_index=-1, reduction='mean')
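            # CrossEntropyLoss above is presumably a functional helper (e.g. wrapping
            # F.cross_entropy), rather than the torch.nn.CrossEntropyLoss module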

            loss_data = loss.data.item()
            train_loss_meter.update(loss_data)

            if np.isnan(loss_data):
                raise ValueError('loss is nan while training')

            # loss.backward(retain_graph=True)
            loss.backward()

            self.optim.step()
            self.optim.zero_grad()

            if isinstance(score, (tuple, list)):
                lbl_pred = score[0].data.max(1)[1].cpu().numpy()
            else:
                lbl_pred = score.data.max(1)[1].cpu().numpy()
            lbl_true = target.data.cpu().numpy()
            train_metrics.update(lbl_true, lbl_pred)

        acc, acc_cls, mean_iou, fwavacc, _ = train_metrics.get_scores()
        metrics = [acc, acc_cls, mean_iou, fwavacc]

        with open(osp.join(self.out, 'log.csv'), 'a') as f:
            elapsed_time = (
                datetime.datetime.now(pytz.timezone('UTC')) -
                self.timestamp_start).total_seconds()
            log = [self.epoch] + [train_loss_meter.avg] + \
                metrics + [''] * 5 + [elapsed_time]
            log = map(str, log)
            f.write(','.join(log) + '\n')

        if self.scheduler:
            self.scheduler.step()
        if self.epoch % self.val_epoch == 0 or self.epoch == 1:
            lr = self.optim.param_groups[0]['lr']
            print(f'\nCurrent base learning rate of epoch {self.epoch}: {lr:.7f}')

        train_loss_meter.reset()
        train_metrics.reset()
    def validate(self):

        visualizations = []
        val_metrics = runningScore(self.n_classes)
        val_loss_meter = averageMeter()

        with torch.no_grad():
            self.model.eval()
            for rgb, ir, target in tqdm.tqdm(
                    self.val_loader, total=len(self.val_loader),
                    desc=f'Valid epoch={self.epoch}', ncols=80, leave=False):

                rgb, ir, target = rgb.to(self.device), ir.to(self.device), target.to(self.device)

                score = self.model(rgb, ir)
                # score = self.model(rgb)

                weight = self.val_loader.dataset.class_weight
                if weight:
                    weight = torch.Tensor(weight).to(self.device)

                loss = CrossEntropyLoss(score, target, weight=weight, reduction='mean', ignore_index=-1)
                loss_data = loss.data.item()
                if np.isnan(loss_data):
                    raise ValueError('loss is nan while validating')

                val_loss_meter.update(loss_data)

                rgbs = rgb.data.cpu()
                irs = ir.data.cpu()

                if isinstance(score, (tuple, list)):
                    lbl_pred = score[0].data.max(1)[1].cpu().numpy()
                else:
                    lbl_pred = score.data.max(1)[1].cpu().numpy()
                lbl_true = target.data.cpu()

                for rgb, ir, lt, lp in zip(rgbs, irs, lbl_true, lbl_pred):
                    rgb, ir, lt = self.val_loader.dataset.untransform(rgb, ir, lt)
                    val_metrics.update(lt, lp)
                    if len(visualizations) < 9:
                        viz = visualize_segmentation(
                            lbl_pred=lp, lbl_true=lt, img=rgb, ir=ir,
                            n_classes=self.n_classes, dataloader=self.train_loader)
                        visualizations.append(viz)

        acc, acc_cls, mean_iou, fwavacc, cls_iu = val_metrics.get_scores()
        metrics = [acc, acc_cls, mean_iou, fwavacc]

        print(f'\nEpoch: {self.epoch}', f'loss: {val_loss_meter.avg}, mIoU: {mean_iou}')

        out = osp.join(self.out, 'visualization_viz')
        if not osp.exists(out):
            os.makedirs(out)
        out_file = osp.join(out, 'epoch{:0>5d}.jpg'.format(self.epoch))
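        # note: scipy.misc.imsave was removed in SciPy >= 1.3; imageio.imwrite is the usual replacement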
        scipy.misc.imsave(out_file, get_tile_image(visualizations))

        with open(osp.join(self.out, 'log.csv'), 'a') as f:
            elapsed_time = (
                datetime.datetime.now(pytz.timezone('UTC')) -
                self.timestamp_start).total_seconds()
            log = [self.epoch] + [''] * 5 + \
                  [val_loss_meter.avg] + metrics + [elapsed_time]
            log = map(str, log)
            f.write(','.join(log) + '\n')

        mean_iu = metrics[2]
        is_best = mean_iu > self.best_mean_iu
        if is_best:
            self.best_mean_iu = mean_iu
        torch.save({
            'epoch': self.epoch,
            'arch': self.model.__class__.__name__,
            'optim_state_dict': self.optim.state_dict(),
            'model_state_dict': self.model.state_dict(),
            'best_mean_iu': self.best_mean_iu,
        }, osp.join(self.out, 'checkpoint.pth.tar'))
        if is_best:
            shutil.copy(osp.join(self.out, 'checkpoint.pth.tar'),
                        osp.join(self.out, 'model_best.pth.tar'))

        val_loss_meter.reset()
        val_metrics.reset()

        class_name = self.val_loader.dataset.class_names
        if class_name is not None:
            for index, value in enumerate(cls_iu.values()):
                offset = 20 - len(class_name[index])
                print(class_name[index] + ' ' * offset + f'{value * 100:>.2f}')
        else:
            print("\nyou don't specify class_names, use number instead")
            for key, value in cls_iu.items():
                print(key, f'{value * 100:>.2f}')
Example #8
def train(cfg, writer, logger):
    torch.manual_seed(cfg.get('seed', 1337))
    torch.cuda.manual_seed(cfg.get('seed', 1337))
    np.random.seed(cfg.get('seed', 1337))
    random.seed(cfg.get('seed', 1337))
    ## create dataset
    default_gpu = cfg['model']['default_gpu']
    device = torch.device(
        "cuda:{}".format(default_gpu) if torch.cuda.is_available() else 'cpu')
    datasets = create_dataset(cfg, writer, logger)

    use_pseudo_label = False
    model = CustomModel(cfg, writer, logger, use_pseudo_label, modal_num=3)

    # Setup Metrics
    running_metrics_val = runningScore(cfg['data']['target']['n_class'])
    source_running_metrics_val = runningScore(cfg['data']['target']['n_class'])
    val_loss_meter = averageMeter()
    source_val_loss_meter = averageMeter()
    time_meter = averageMeter()
    loss_fn = get_loss_function(cfg)
    flag_train = True

    epoches = cfg['training']['epoches']

    source_train_loader = datasets.source_train_loader
    target_train_loader = datasets.target_train_loader

    logger.info('source train batchsize is {}'.format(
        source_train_loader.args.get('batch_size')))
    print('source train batchsize is {}'.format(
        source_train_loader.args.get('batch_size')))
    logger.info('target train batchsize is {}'.format(
        target_train_loader.batch_size))
    print('target train batchsize is {}'.format(
        target_train_loader.batch_size))

    val_loader = None
    if cfg.get('valset') == 'gta5':
        val_loader = datasets.source_valid_loader
        logger.info('valset is gta5')
        print('valset is gta5')
    else:
        val_loader = datasets.target_valid_loader
        logger.info('valset is cityscapes')
        print('valset is cityscapes')
    logger.info('val batchsize is {}'.format(val_loader.batch_size))
    print('val batchsize is {}'.format(val_loader.batch_size))

    # load category anchors
    """
    objective_vectors = torch.load('category_anchors')
    model.objective_vectors = objective_vectors['objective_vectors']
    model.objective_vectors_num = objective_vectors['objective_num']
    """

    # begin training
    model.iter = 0
    for epoch in range(epoches):
        if not flag_train:
            break
        if model.iter > cfg['training']['train_iters']:
            break

        if use_pseudo_label:
            # monitoring the accuracy and recall of CAG-based PLA and probability-based PLA
            score_cl, _ = model.metrics.running_metrics_val_clusters.get_scores(
            )
            print('clus_IoU: {}'.format(score_cl["Mean IoU : \t"]))

            logger.info('clus_IoU: {}'.format(score_cl["Mean IoU : \t"]))
            logger.info('clus_Recall: {}'.format(
                model.metrics.calc_mean_Clu_recall()))
            logger.info(model.metrics.classes_recall_clu[:, 0] /
                        model.metrics.classes_recall_clu[:, 1])
            logger.info('clus_Acc: {}'.format(
                np.mean(model.metrics.classes_recall_clu[:, 0] /
                        model.metrics.classes_recall_clu[:, 1])))
            logger.info(model.metrics.classes_recall_clu[:, 0] /
                        model.metrics.classes_recall_clu[:, 2])

            score_cl, _ = model.metrics.running_metrics_val_threshold.get_scores(
            )
            logger.info('thr_IoU: {}'.format(score_cl["Mean IoU : \t"]))
            logger.info('thr_Recall: {}'.format(
                model.metrics.calc_mean_Thr_recall()))
            logger.info(model.metrics.classes_recall_thr[:, 0] /
                        model.metrics.classes_recall_thr[:, 1])
            logger.info('thr_Acc: {}'.format(
                np.mean(model.metrics.classes_recall_thr[:, 0] /
                        model.metrics.classes_recall_thr[:, 1])))
            logger.info(model.metrics.classes_recall_thr[:, 0] /
                        model.metrics.classes_recall_thr[:, 2])
        model.metrics.reset()

        for (target_image, target_label,
             target_img_name) in datasets.target_train_loader:
            model.iter += 1
            i = model.iter
            if i > cfg['training']['train_iters']:
                break
            source_batchsize = cfg['data']['source']['batch_size']
            # load source data
            images, labels, source_img_name = datasets.source_train_loader.next(
            )
            start_ts = time.time()
            images = images.to(device)
            labels = labels.to(device)
            # load target data
            target_image = target_image.to(device)
            target_label = target_label.to(device)
            #model.scheduler_step()
            model.train(logger=logger)
            if cfg['training'].get('freeze_bn') == True:
                model.freeze_bn_apply()
            model.optimizer_zerograd()
            # determine the modality id for each source image
            source_modal_ids = []
            for _img_name in source_img_name:
                if 'gtav2cityscapes' in _img_name:
                    source_modal_ids.append(0)
                elif 'gtav2cityfoggy' in _img_name:
                    source_modal_ids.append(1)
                elif 'gtav2cityrain' in _img_name:
                    source_modal_ids.append(2)
                else:
                    assert False, "[ERROR] unknown image source: expected gtav2cityscapes, gtav2cityfoggy or gtav2cityrain"

            target_modal_ids = []
            for _img_name in target_img_name:
                if 'Cityscapes_foggy' in _img_name:
                    target_modal_ids.append(1)
                elif 'Cityscapes_rain' in _img_name:
                    target_modal_ids.append(2)
                else:
                    target_modal_ids.append(0)

            loss, loss_cls_L2, loss_pseudo = model.step(
                images, labels, source_modal_ids, target_image, target_label,
                target_modal_ids, use_pseudo_label)
            # scheduler step
            model.scheduler_step()
            if loss_cls_L2 > 10:
                logger.info('loss_cls_l2 abnormal!!')

            time_meter.update(time.time() - start_ts)
            if (i + 1) % cfg['training']['print_interval'] == 0:
                unchanged_cls_num = 0
                if use_pseudo_label:
                    fmt_str = "Epoches [{:d}/{:d}] Iter [{:d}/{:d}]  Loss: {:.4f}  Loss_L2: {:.4f}  Loss_pseudo: {:.4f}  Time/Image: {:.4f}"
                else:
                    fmt_str = "Epoches [{:d}/{:d}] Iter [{:d}/{:d}]  Loss_GTA: {:.4f}  Loss_adv: {:.4f}  Loss_D: {:.4f}  Time/Image: {:.4f}"
                print_str = fmt_str.format(
                    epoch + 1, epoches, i + 1, cfg['training']['train_iters'],
                    loss.item(), loss_cls_L2.item(), loss_pseudo.item(),
                    time_meter.avg / cfg['data']['source']['batch_size'])

                print(print_str)
                logger.info(print_str)
                logger.info(
                    'unchanged number of objective class vector: {}'.format(
                        unchanged_cls_num))
                if use_pseudo_label:
                    loss_names = [
                        'train_loss', 'train_L2Loss', 'train_pseudoLoss'
                    ]
                else:
                    loss_names = [
                        'train_loss_GTA', 'train_loss_adv', 'train_loss_D'
                    ]
                writer.add_scalar('loss/{}'.format(loss_names[0]), loss.item(),
                                  i + 1)
                writer.add_scalar('loss/{}'.format(loss_names[1]),
                                  loss_cls_L2.item(), i + 1)
                writer.add_scalar('loss/{}'.format(loss_names[2]),
                                  loss_pseudo.item(), i + 1)
                time_meter.reset()

                if use_pseudo_label:
                    score_cl, _ = model.metrics.running_metrics_val_clusters.get_scores(
                    )
                    logger.info('clus_IoU: {}'.format(
                        score_cl["Mean IoU : \t"]))
                    logger.info('clus_Recall: {}'.format(
                        model.metrics.calc_mean_Clu_recall()))
                    logger.info('clus_Acc: {}'.format(
                        np.mean(model.metrics.classes_recall_clu[:, 0] /
                                model.metrics.classes_recall_clu[:, 2])))

                    score_cl, _ = model.metrics.running_metrics_val_threshold.get_scores(
                    )
                    logger.info('thr_IoU: {}'.format(
                        score_cl["Mean IoU : \t"]))
                    logger.info('thr_Recall: {}'.format(
                        model.metrics.calc_mean_Thr_recall()))
                    logger.info('thr_Acc: {}'.format(
                        np.mean(model.metrics.classes_recall_thr[:, 0] /
                                model.metrics.classes_recall_thr[:, 2])))

            # evaluation
            if (i + 1) % cfg['training']['val_interval'] == 0 or \
                (i + 1) == cfg['training']['train_iters']:
                validation(
                    model, logger, writer, datasets, device, running_metrics_val, val_loss_meter, loss_fn,\
                    source_val_loss_meter, source_running_metrics_val, iters = model.iter
                    )
                torch.cuda.empty_cache()
                logger.info('Best iou until now is {}'.format(model.best_iou))
            if (i + 1) == cfg['training']['train_iters']:
                flag_train = False
                break
    target = target.view(-1)
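    # note: size_average is deprecated in recent PyTorch; the reduction argument is preferred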
    loss = F.cross_entropy(input,
                           target,
                           weight=weight,
                           size_average=size_average,
                           ignore_index=250)
    return loss


# optimizer
optimizer = SGD(model.parameters(), lr=0.01, momentum=0.9)
#optimizer = Adam(model.parameters(), lr = 0.01)

# Setup Metrics
running_metrics_val = runningScore(n_classes)
val_loss_meter = averageMeter()

num_epochs = 30
step = 0
epoch = 0
if load_model_file is not None:
    step = start_step
    epoch = start_epoch

score_list = []
while epoch <= num_epochs:
    epoch += 1
    print("Starting epoch %s" % epoch)
    for (images, labels) in trainloader:
        step += 1
        start_ts = time.time()
Example #10
def train(cycle_num,
          dirs,
          path_to_net,
          plotter,
          batch_size=12,
          test_split=0.3,
          random_state=666,
          epochs=100,
          learning_rate=0.0001,
          momentum=0.9,
          num_folds=5,
          num_slices=155,
          n_classes=4):
    """
    Trains the network for one cycle of the n-fold cross-validation.
        Args: 
            cycle_num (int): number of cycle in n-fold (num_folds) cross validation
            dirs (string): path to dataset subject directories 
            path_to_net (string): path to directory where to save network
            plotter (callable): visdom plotter
            batch_size - default (int): batch size
            test_split - default (float): percentage of test split 
            random_state - default (int): seed for k-fold cross validation
            epochs - default (int): number of epochs
            learning_rate - default (float): learning rate 
            momentum - default (float): momentum
            num_folds - default (int): number of folds in cross validation
            num_slices - default (int): number of slices per volume
            n_classes - default (int): number of classes (regions)
    """
    print('Setting started', flush=True)

    # Creating data indices
    # index array over the list of subject directories (np.arange over its length)
    indices = np.arange(len(glob.glob(dirs + '*')))
    test_indices, trainset_indices = get_test_indices(indices, test_split)
    # kfold index generator
    for cv_num, (train_indices, val_indices) in enumerate(
            get_train_cv_indices(trainset_indices, num_folds, random_state)):
        # split the 5-fold cross-validation into 5 separate jobs
        if cv_num != int(cycle_num):
            continue

        net = U_Net()
        device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        num_GPU = torch.cuda.device_count()
        if num_GPU > 1:
            print('Let us use {} GPUs!'.format(num_GPU), flush=True)
            net = nn.DataParallel(net)
        net.to(device)
        criterion = nn.CrossEntropyLoss()
        if cycle_num % 2 == 0:
            optimizer = optim.SGD(net.parameters(),
                                  lr=learning_rate,
                                  momentum=momentum)
        else:
            optimizer = optim.Adam(net.parameters(), lr=learning_rate)

        scheduler = ReduceLROnPlateau(optimizer, threshold=1e-6, patience=0)

        print('cv cycle number: ', cycle_num, flush=True)
        start = time.time()
        print('Start Train and Val loading', flush=True)

        MRIDataset_train = dataset.MRIDataset(dirs, train_indices)

        MRIDataset_val = dataset.MRIDataset(dirs, val_indices)

        datalengths = {
            'train': len(MRIDataset_train),
            'val': len(MRIDataset_val)
        }
        dataloaders = {
            'train': get_dataloader(MRIDataset_train, batch_size, num_GPU),
            'val': get_dataloader(MRIDataset_val, batch_size, num_GPU)
        }
        print('Train and Val loading took: ', time.time() - start, flush=True)
        # keep loss and accuracy history for train and val separately
        # Setup Metrics
        running_metrics_val = runningScore(n_classes)
        running_metrics_train = runningScore(n_classes)
        val_loss_meter = averageMeter()
        train_loss_meter = averageMeter()
        itr = 0
        iou_best = 0.
        for epoch in tqdm(range(epochs), desc='Epochs'):
            print('Epoch: ', epoch + 1, flush=True)
            phase = 'train'
            print('Phase: ', phase, flush=True)
            start = time.time()
            # Set model to training mode
            net.train()
            # Iterate over data.
            for i, data in tqdm(enumerate(dataloaders[phase]),
                                desc='Data Iteration ' + phase):
                if (i + 1) % 100 == 0:
                    print('Number of Iteration [{}/{}]'.format(
                        i + 1, int(datalengths[phase] / batch_size)),
                          flush=True)
                # get the inputs
                inputs = data['mri_data'].to(device)
                GT = data['seg'].to(device)
                subject_slice_path = data['subject_slice_path']
                # Clear all accumulated gradients
                optimizer.zero_grad()
                # Predict classes using inputs from the train set
                SR = net(inputs)
                # Compute the loss based on the predictions and
                # actual segmentation
                loss = criterion(SR, GT)
                # Backpropagate the loss
                loss.backward()
                # Adjust parameters according to the computed
                # gradients
                # -- weight update
                optimizer.step()
                # Track and plot metrics and loss, and save the network
                predictions = SR.data.max(1)[1].cpu().numpy()
                GT_cpu = GT.data.cpu().numpy()
                running_metrics_train.update(GT_cpu, predictions)
                train_loss_meter.update(loss.item(), n=1)
                if (i + 1) % 100 == 0:
                    itr += 1
                    score, class_iou = running_metrics_train.get_scores()
                    for k, v in score.items():
                        plotter.plot(k, 'itr', phase, k, itr, v)
                    for k, v in class_iou.items():
                        print('Class {} IoU: {}'.format(k, v), flush=True)
                        plotter.plot(
                            str(k) + ' Class IoU', 'itr', phase,
                            str(k) + ' Class IoU', itr, v)
                    print('Loss Train', train_loss_meter.avg, flush=True)
                    plotter.plot('Loss', 'itr', phase, 'Loss Train', itr,
                                 train_loss_meter.avg)
            print('Phase {} took {} s for whole {}set!'.format(
                phase,
                time.time() - start, phase),
                  flush=True)

            # Validation Phase
            phase = 'val'
            print('Phase: ', phase, flush=True)
            start = time.time()
            # Set model to evaluation mode
            net.eval()
            start = time.time()
            with torch.no_grad():
                # Iterate over data.
                for i, data in tqdm(enumerate(dataloaders[phase]),
                                    desc='Data Iteration ' + phase):
                    if (i + 1) % 100 == 0:
                        print('Number of Iteration [{}/{}]'.format(
                            i + 1, int(datalengths[phase] / batch_size)),
                              flush=True)
                    # get the inputs
                    inputs = data['mri_data'].to(device)
                    GT = data['seg'].to(device)
                    subject_slice_path = data['subject_slice_path']
                    # Clear all accumulated gradients
                    optimizer.zero_grad()
                    # Predict classes using inputs from the validation set
                    SR = net(inputs)
                    # Compute the loss based on the predictions and
                    # actual segmentation
                    loss = criterion(SR, GT)
                    # Track and plot metrics and loss
                    predictions = SR.data.max(1)[1].cpu().numpy()
                    GT_cpu = GT.data.cpu().numpy()
                    running_metrics_val.update(GT_cpu, predictions)
                    val_loss_meter.update(loss.item(), n=1)
                    if (i + 1) % 100 == 0:
                        itr += 1
                        score, class_iou = running_metrics_val.get_scores()
                        for k, v in score.items():
                            plotter.plot(k, 'itr', phase, k, itr, v)
                        for k, v in class_iou.items():
                            print('Class {} IoU: {}'.format(k, v), flush=True)
                            plotter.plot(
                                str(k) + ' Class IoU', 'itr', phase,
                                str(k) + ' Class IoU', itr, v)
                        print('Loss Val', val_loss_meter.avg, flush=True)
                        plotter.plot('Loss ', 'itr', phase, 'Loss Val', itr,
                                     val_loss_meter.avg)
                if (epoch + 1) % 10 == 0:
                    if score['Mean IoU'] > iou_best:
                        save_net(path_to_net, batch_size, epoch, cycle_num,
                                 train_indices, val_indices, test_indices, net,
                                 optimizer)
                        iou_best = score['Mean IoU']
                    save_output(epoch, path_to_net, subject_slice_path,
                                SR.data.cpu().numpy(), GT_cpu)
                print('Phase {} took {} s for whole {}set!'.format(
                    phase,
                    time.time() - start, phase),
                      flush=True)
            # Call the learning rate adjustment function after every epoch
            scheduler.step(val_loss_meter.avg)
    # save network after training
    save_net(path_to_net,
             batch_size,
             epochs,
             cycle_num,
             train_indices,
             val_indices,
             test_indices,
             net,
             optimizer,
             iter_num=None)
Example #11
def train():

    setup_seeds(1337)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    data_loader = NYUv2Loader
    data_path = '/scratch_ssd/rkwitt/NYUv2/'

    t_loader = data_loader(data_path,
                           is_transform=True,
                           split='training',
                           img_size=(224, 320))

    v_loader = data_loader(data_path,
                           is_transform=True,
                           split='val',
                           img_size=(224, 320))

    n_classes = t_loader.n_classes

    trainloader = data.DataLoader(t_loader,
                                  batch_size=2,
                                  num_workers=16,
                                  shuffle=True)

    valloader = data.DataLoader(v_loader,
                                batch_size=2,
                                num_workers=16,
                                shuffle=False)

    model = unet(num_classes=n_classes, is_deconv=True,
                 pretrained=True).to(device)
    model = torch.nn.DataParallel(model,
                                  device_ids=range(torch.cuda.device_count()))

    optimizer = SGD(model.parameters(),
                    lr=1e-3,
                    weight_decay=0.0005,
                    momentum=0.99)

    scheduler = StepLR(optimizer, step_size=30, gamma=0.1)

    val_loss_meter = averageMeter()
    running_metrics_val = runningScore(n_classes)

    for epoch_i in range(100):

        scheduler.step()

        lrs = []
        for param_group in optimizer.param_groups:
            lrs.append(param_group['lr'])

        print("=>Training epoch {} [lr={}]".format(epoch_i, lrs))

        # Training
        for i_trn, (images, labels) in tqdm(enumerate(trainloader)):

            images = images.to(device)
            labels = labels.to(device)

            model.train()
            optimizer.zero_grad()
            outputs = model(images)
            loss = cross_entropy2d(input=outputs, target=labels)
            loss.backward()
            optimizer.step()

        # Validation
        model.eval()
        with torch.no_grad():
            for i_val, (images_val, labels_val) in tqdm(enumerate(valloader)):

                images_val = images_val.to(device)
                labels_val = labels_val.to(device)

                outputs = model(images_val)
                val_loss = cross_entropy2d(input=outputs, target=labels_val)

                pred = outputs.data.max(1)[1].cpu().numpy()
                gt = labels_val.data.cpu().numpy()

                running_metrics_val.update(gt, pred)
                val_loss_meter.update(val_loss.item())

            print('Validation loss (avg): ', val_loss_meter.avg)
            score, class_iou = running_metrics_val.get_scores()
            for k, v in score.items():
                print(k, v)
            for k, v in class_iou.items():
                print(k, v_loader.cls_idx_to_name[k], v)

            val_loss_meter.reset()
            running_metrics_val.reset()

        torch.save(model.state_dict(),
                   os.path.join('/tmp', 'unet_epoch_{}.pkl'.format(epoch_i)))
Example #12
def CAC(cfg, writer, logger):
    torch.manual_seed(cfg.get('seed', 1337))
    torch.cuda.manual_seed(cfg.get('seed', 1337))
    np.random.seed(cfg.get('seed', 1337))
    random.seed(cfg.get('seed', 1337))
    ## create dataset
    default_gpu = cfg['model']['default_gpu']
    device = torch.device("cuda:{}".format(default_gpu) if torch.cuda.is_available() else 'cpu')
    datasets = create_dataset(cfg, writer, logger)  # source_train / target_train / source_valid / target_valid (+ matching _loader attributes)

    model = CustomModel(cfg, writer, logger)

    # Setup Metrics
    running_metrics_val = runningScore(cfg['data']['target']['n_class'])
    source_running_metrics_val = runningScore(cfg['data']['target']['n_class'])
    val_loss_meter = averageMeter()
    source_val_loss_meter = averageMeter()
    time_meter = averageMeter()
    loss_fn = get_loss_function(cfg)
    flag_train = True

    epoches = cfg['training']['epoches']

    source_train_loader = datasets.source_train_loader
    target_train_loader = datasets.target_train_loader
    logger.info('source train batchsize is {}'.format(source_train_loader.args.get('batch_size')))
    print('source train batchsize is {}'.format(source_train_loader.args.get('batch_size')))
    logger.info('target train batchsize is {}'.format(target_train_loader.batch_size))
    print('target train batchsize is {}'.format(target_train_loader.batch_size))

    val_loader = None
    if cfg.get('valset') == 'gta5':
        val_loader = datasets.source_valid_loader
        logger.info('valset is gta5')
        print('valset is gta5')
    else:
        val_loader = datasets.target_valid_loader
        logger.info('valset is cityscapes')
        print('valset is cityscapes')
    logger.info('val batchsize is {}'.format(val_loader.batch_size))
    print('val batchsize is {}'.format(val_loader.batch_size))

    # load category anchors
    # objective_vectors = torch.load('category_anchors')
    # model.objective_vectors = objective_vectors['objective_vectors']
    # model.objective_vectors_num = objective_vectors['objective_num']
    class_features = Class_Features(numbers=19)

    # begin training
    model.iter = 0
    for epoch in range(epoches):
        if not flag_train:
            break
        if model.iter > cfg['training']['train_iters']:
            break

        # monitoring the accuracy and recall of CAG-based PLA and probability-based PLA

        for (target_image, target_label, target_img_name) in datasets.target_train_loader:
            model.iter += 1
            i = model.iter
            if i > cfg['training']['train_iters']:
                break
            source_batchsize = cfg['data']['source']['batch_size']
            images, labels, source_img_name = datasets.source_train_loader.next()
            start_ts = time.time()

            images = images.to(device)
            labels = labels.to(device)
            target_image = target_image.to(device)
            target_label = target_label.to(device)
            model.scheduler_step()
            model.train(logger=logger)
            if cfg['training'].get('freeze_bn') == True:
                model.freeze_bn_apply()
            model.optimizer_zerograd()
            if model.PredNet.training:
                model.PredNet.eval()
            with torch.no_grad():
                _, _, feat_cls, output = model.PredNet_Forward(images)
                batch, w, h = labels.size()
                newlabels = labels.reshape([batch, 1, w, h]).float()
                newlabels = F.interpolate(newlabels, size=feat_cls.size()[2:], mode='nearest')
                vectors, ids = class_features.calculate_mean_vector(feat_cls, output, newlabels, model)
                for t in range(len(ids)):
                    model.update_objective_SingleVector(ids[t], vectors[t].detach().cpu().numpy(), 'mean')

            time_meter.update(time.time() - start_ts)
            if model.iter % 20 == 0:
                print("Iter [{:d}] Time {:.4f}".format(model.iter, time_meter.avg))

            if (i + 1) == cfg['training']['train_iters']:
                flag_train = False
                break
    save_path = os.path.join(writer.file_writer.get_logdir(),
                                "anchors_on_{}_from_{}".format(
                                    cfg['data']['source']['name'],
                                    cfg['model']['arch'],))
    torch.save(model.objective_vectors, save_path)
Example #13
def train(opt, logger):
    torch.manual_seed(opt.seed)
    torch.cuda.manual_seed(opt.seed)
    np.random.seed(opt.seed)
    random.seed(opt.seed)
    ## create dataset
    device = torch.device("cuda:0" if torch.cuda.is_available() else 'cpu')
    datasets = create_dataset(opt, logger)

    if opt.model_name == 'deeplabv2':
        model = adaptation_modelv2.CustomModel(opt, logger)

    # Setup Metrics
    running_metrics_val = runningScore(opt.n_class)
    time_meter = averageMeter()

    # load category anchors
    if opt.stage == 'stage1':
        objective_vectors = torch.load(
            os.path.join(
                os.path.dirname(opt.resume_path),
                'prototypes_on_{}_from_{}'.format(opt.tgt_dataset,
                                                  opt.model_name)))
        model.objective_vectors = torch.Tensor(objective_vectors).to(0)
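        # .to(0) above places the loaded prototypes on GPU 0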

    # begin training
    save_path = os.path.join(
        opt.logdir,
        "from_{}_to_{}_on_{}_current_model.pkl".format(opt.src_dataset,
                                                       opt.tgt_dataset,
                                                       opt.model_name))
    model.iter = 0
    start_epoch = 0
    for epoch in range(start_epoch, opt.epochs):
        for data_i in datasets.target_train_loader:
            target_image = data_i['img'].to(device)
            target_imageS = data_i['img_strong'].to(device)
            target_params = data_i['params']
            target_image_full = data_i['img_full'].to(device)
            target_weak_params = data_i['weak_params']

            target_lp = data_i['lp'].to(
                device) if 'lp' in data_i.keys() else None
            target_lpsoft = data_i['lpsoft'].to(
                device) if 'lpsoft' in data_i.keys() else None
            source_data = datasets.source_train_loader.next()

            model.iter += 1
            i = model.iter
            images = source_data['img'].to(device)
            labels = source_data['label'].to(device)
            source_imageS = source_data['img_strong'].to(device)
            source_params = source_data['params']

            start_ts = time.time()

            model.train(logger=logger)
            if opt.freeze_bn:
                model.freeze_bn_apply()
            model.optimizer_zerograd()

            if opt.stage == 'warm_up':
                loss_GTA, loss_G, loss_D = model.step_adv(
                    images, labels, target_image, source_imageS, source_params)
            elif opt.stage == 'stage1':
                loss, loss_CTS, loss_consist = model.step(
                    images, labels, target_image, target_imageS, target_params,
                    target_lp, target_lpsoft, target_image_full,
                    target_weak_params)
            else:
                loss_GTA, loss = model.step_distillation(
                    images, labels, target_image, target_imageS, target_params,
                    target_lp)

            time_meter.update(time.time() - start_ts)

            #print(i)
            if (i + 1) % opt.print_interval == 0:
                if opt.stage == 'warm_up':
                    fmt_str = "Epochs [{:d}/{:d}] Iter [{:d}/{:d}]  loss_GTA: {:.4f}  loss_G: {:.4f}  loss_D: {:.4f} Time/Image: {:.4f}"
                    print_str = fmt_str.format(epoch + 1, opt.epochs, i + 1,
                                               opt.train_iters, loss_GTA,
                                               loss_G, loss_D,
                                               time_meter.avg / opt.bs)
                elif opt.stage == 'stage1':
                    fmt_str = "Epochs [{:d}/{:d}] Iter [{:d}/{:d}]  loss: {:.4f}  loss_CTS: {:.4f}  loss_consist: {:.4f} Time/Image: {:.4f}"
                    print_str = fmt_str.format(epoch + 1, opt.epochs, i + 1,
                                               opt.train_iters, loss, loss_CTS,
                                               loss_consist,
                                               time_meter.avg / opt.bs)
                else:
                    fmt_str = "Epochs [{:d}/{:d}] Iter [{:d}/{:d}]  loss_GTA: {:.4f}  loss: {:.4f} Time/Image: {:.4f}"
                    print_str = fmt_str.format(epoch + 1, opt.epochs, i + 1,
                                               opt.train_iters, loss_GTA, loss,
                                               time_meter.avg / opt.bs)
                print(print_str)
                logger.info(print_str)
                time_meter.reset()

            # evaluation
            if (i + 1) % opt.val_interval == 0:
                validation(model,
                           logger,
                           datasets,
                           device,
                           running_metrics_val,
                           iters=model.iter,
                           opt=opt)
                torch.cuda.empty_cache()
                logger.info('Best iou until now is {}'.format(model.best_iou))

            model.scheduler_step()
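
# The loop above logs time_meter.avg / opt.bs and then resets the meter, so the
# reported Time/Image is the per-image average over the last print interval only.
# Below is a minimal, hedged sketch of the update/avg/reset pattern that such an
# averageMeter typically implements (the exact class ships with each repository;
# the class name and fields here are illustrative, not the repository's code):
class AverageMeterSketch:
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0.0    # most recent value passed to update()
        self.sum = 0.0    # running sum of val * n
        self.count = 0    # number of samples accumulated since reset()
        self.avg = 0.0    # running mean: sum / count

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
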
def train(cfg, writer, logger):

    # Setup seeds
    torch.manual_seed(cfg.get("seed", 1337))
    torch.cuda.manual_seed(cfg.get("seed", 1337))
    np.random.seed(cfg.get("seed", 1337))
    random.seed(cfg.get("seed", 1337))

    # Setup device
    device = torch.device(cfg['device'])

    # Setup Metrics
    seg_scores = SegmentationScore()
    depth_scores = DepthEstimateScore()

    augmentations = Compose([
        RandomRotate(cfg['training']['argumentation']['random_rotate']),
        RandomCrop(cfg['training']['img_size']),
        RandomHorizonFlip(cfg['training']['argumentation']['random_hflip']),
    ])

    traindata = cityscapesLoader(cfg['data']['path'],
                                 img_size=cfg['training']['img_size'],
                                 split=cfg['data']['train_split'],
                                 is_transform=True,
                                 augmentations=augmentations)

    valdata = cityscapesLoader(cfg['data']['path'],
                               img_size=cfg['training']['img_size'],
                               split=cfg['data']['val_split'],
                               is_transform=True)

    trainloader = data.DataLoader(traindata,
                                  batch_size=cfg['training']['batch_size'])
    valloader = data.DataLoader(valdata,
                                batch_size=cfg['training']['batch_size'])

    # Setup Model
    model = networks[cfg['arch']](**cfg['model'])

    model.to(device)
    loss_fn = Loss(**cfg['training']['loss']).to(device)
    # Setup optimizer, lr_scheduler and loss function
    optimizer = optim.SGD(model.parameters(), **cfg['training']['optimizer'])
    # TODO
    # optimizer_loss = optim.SGD(loss_fn.parameters(), **cfg['training']['optimizer_loss'])

    scheduler = PolynomialLR(optimizer,
                             max_iter=cfg['training']['train_iters'],
                             **cfg['training']['schedule'])
    # TODO
    # scheduler_loss = ConstantLR(optimizer_loss)

    start_iter = 0
    if cfg["training"]["resume"] is not None:
        if os.path.isfile(cfg["training"]["resume"]):
            logger.info(
                "Loading model and optimizer from checkpoint '{}'".format(
                    cfg["training"]["resume"]))
            checkpoint = torch.load(cfg["training"]["resume"])
            model.load_state_dict(checkpoint["model_state"])
            optimizer.load_state_dict(checkpoint["optimizer_state"])
            scheduler.load_state_dict(checkpoint["scheduler_state"])
            start_iter = checkpoint["epoch"]
            logger.info("Loaded checkpoint '{}' (iter {})".format(
                cfg["training"]["resume"], checkpoint["epoch"]))
        else:
            logger.info("No checkpoint found at '{}'".format(
                cfg["training"]["resume"]))

    val_loss_meter = averageMeter()
    time_meter = averageMeter()

    best_miou = -100.0
    best_abs_rel = float('inf')
    i = start_iter
    flag = True
    optimizer.zero_grad()
    while i <= cfg["training"]["train_iters"] and flag:
        for sample in trainloader:
            i += 1
            start_ts = time.time()
            scheduler.step()
            # TODO
            # scheduler_loss.step()
            model.train()
            images = sample['image'].to(device)
            labels = sample['label'].to(device)
            depths = sample['depth'].to(device)

            # TODO
            # optimizer_loss.zero_grad()
            outputs = model(images)
            loss = loss_fn(outputs, depths,
                           labels) / cfg['training']['accu_steps']
            loss.backward()
            if i % cfg['training']['accu_steps'] == 0:
                optimizer.step()
                optimizer.zero_grad()
            # TODO
            # optimizer_loss.step()

            time_meter.update(time.time() - start_ts)

            if i % cfg["training"]["print_interval"] == 0:
                fmt_str = "Iter [{:d}/{:d}]  Loss: {:.4f}  Time/Image: {:.4f}"
                print_str = fmt_str.format(
                    i,
                    cfg["training"]["train_iters"],
                    loss.item() * cfg['training']['accu_steps'],
                    time_meter.avg / cfg["training"]["batch_size"],
                )

                logger.info(print_str)
                writer.add_scalar("loss/train_loss",
                                  loss.item() * cfg['training']['accu_steps'],
                                  i)
                writer.add_scalar('param/delta1', loss_fn.delta1, i)
                writer.add_scalar('param/delta2', loss_fn.delta2, i)
                writer.add_scalar('param/learning-rate',
                                  scheduler.get_lr()[0], i)
                time_meter.reset()

            if i % cfg["training"]["val_interval"] == 0 or i == cfg[
                    "training"]["train_iters"]:
                model.eval()
                with torch.no_grad():
                    for i_val, sample in tqdm(enumerate(valloader)):
                        images_val = sample['image'].to(device)
                        labels_val = sample['label'].to(device)
                        depths_val = sample['depth'].to(device)
                        outputs = model(images_val)
                        val_loss = loss_fn(outputs, depths_val, labels_val)

                        depth_scores.update(depths_val.cpu().numpy(),
                                            outputs[-1][0].data.cpu().numpy())
                        seg_scores.update(
                            labels_val.cpu().numpy(),
                            outputs[-1][1].data.max(1)[1].cpu().numpy())

                        val_loss_meter.update(val_loss.item())

                writer.add_scalar("loss/val_loss", val_loss_meter.avg, i)
                logger.info("Iter %d Loss: %.4f" % (i, val_loss_meter.avg))

                seg_score, class_iou = seg_scores.get_scores()
                for k, v in seg_score.items():
                    logger.info("{}: {}".format(k, v))
                    writer.add_scalar("seg_val_metrics/{}".format(k), v, i)

                for k, v in class_iou.items():
                    # logger.info("{}: {}".format(k, v))
                    writer.add_scalar("seg_val_metrics/cls_{}".format(k), v, i)

                depth_score = depth_scores.get_scores()
                for k, v in depth_score.items():
                    logger.info("{}: {}".format(k, v))
                    writer.add_scalar("depth_val_metrics/{}".format(k), v, i)

                val_loss_meter.reset()
                seg_scores.reset()
                depth_scores.reset()

                # if seg_score["Mean IoU : \t"] >= best_miou and depth_score['abs_rel'] <= best_abs_rel:
                if seg_score["Mean IoU : \t"] >= best_miou:
                    best_miou = seg_score["Mean IoU : \t"]
                    best_abs_rel = depth_score['abs_rel']
                    state = {
                        "epoch": i + 1,
                        "model_state": model.state_dict(),
                        "optimizer_state": optimizer.state_dict(),
                        "scheduler_state": scheduler.state_dict(),
                        "best_iou": best_iou,
                        "best_abs_rel": best_abs_rel
                    }
                    save_path = os.path.join(
                        writer.file_writer.get_logdir(),
                        "{}_{}_best_model.pth".format(cfg['arch'],
                                                      cfg["data"]["dataset"]),
                    )
                    torch.save(state, save_path)

                state = {
                    "epoch": i + 1,
                    "model_state": model.state_dict(),
                    "optimizer_state": optimizer.state_dict(),
                    "scheduler_state": scheduler.state_dict(),
                }
                save_path = os.path.join(
                    writer.file_writer.get_logdir(),
                    "{}_{}_{}_model.pth".format(i, cfg['arch'],
                                                cfg["data"]["dataset"]),
                )
                torch.save(state, save_path)

            if i == cfg["training"]["train_iters"]:
                flag = False
                break
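
The train() function above accumulates gradients: the loss is divided by cfg['training']['accu_steps'], backward() runs every iteration, and the optimizer steps and zeroes its gradients only every accu_steps iterations, which emulates an effective batch of batch_size * accu_steps. A minimal, self-contained sketch of that pattern, with illustrative names and values rather than the repository's code:

import torch
import torch.nn as nn

model = nn.Linear(8, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
criterion = nn.CrossEntropyLoss()
accu_steps = 4  # illustrative; the real value comes from the config

optimizer.zero_grad()
for step in range(1, 13):
    images = torch.randn(4, 8)             # small fake batch
    targets = torch.randint(0, 2, (4,))
    # scale the per-step loss so the summed gradients match a full-batch average
    loss = criterion(model(images), targets) / accu_steps
    loss.backward()                        # gradients accumulate across steps
    if step % accu_steps == 0:
        optimizer.step()                   # one update per accu_steps mini-batches
        optimizer.zero_grad()
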
Beispiel #15
0
# list of parameters to learn
# learning_list = list(filter(lambda p: p.requires_grad, model.parameters()))
learning_list = model.parameters()
# optimizer and learning-rate schedule setup
optimizer = SGD(learning_list, lr=learning_rate)
# scheduler = lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.1)
scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=[int(0.5 * end_iter), int(0.75 * end_iter)], gamma=0.1)
# loss_class = nn.CrossEntropyLoss()#reduction='elementwise_mean').to(device)


flag = True
best_acc = 0.0
i = start_iter
# meters
loss_meter = averageMeter()
time_meter = averageMeter()
class_acc_meter = averageMeter()

while i <= end_iter and flag:
    for (images, gt) in trainloader:
        i += 1

        start_ts = time.time()
        model.train()
        images, gt = images.to(device), gt.to(device)
        # gt_class = gt[:,:1].long()
        # zero the gradients
        optimizer.zero_grad()
        out_class = model(images)
        loss = F.cross_entropy(out_class, gt)
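
Example #15 cuts the learning rate by 10x at 50% and 75% of end_iter via MultiStepLR. A small, hedged sketch of how that schedule behaves, using illustrative values for end_iter and learning_rate (the scheduler.step() call itself falls outside the excerpt above):

import torch
from torch.optim import SGD, lr_scheduler

end_iter, learning_rate = 8, 0.1                 # illustrative values
params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = SGD(params, lr=learning_rate)
scheduler = lr_scheduler.MultiStepLR(
    optimizer,
    milestones=[int(0.5 * end_iter), int(0.75 * end_iter)],  # steps 4 and 6
    gamma=0.1)

for it in range(end_iter):
    optimizer.step()        # update first (recommended order in recent PyTorch)
    scheduler.step()        # then advance the schedule
    print(it + 1, optimizer.param_groups[0]['lr'])
# prints lr = 0.1 for steps 1-3, 0.01 for steps 4-5, and 0.001 from step 6 on
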
Beispiel #16
0
    def validate(self):

        visualizations = []
        val_metrics = runningScore(self.n_classes)
        val_loss_meter = averageMeter()

        with torch.no_grad():
            self.model.eval()
            for data, target in tqdm.tqdm(self.val_loader,
                                          total=len(self.val_loader),
                                          desc=f'Valid epoch={self.epoch}',
                                          ncols=80,
                                          leave=False):

                data, target = data.to(self.device), target.to(self.device)

                score = self.model(data)

                weight = self.val_loader.dataset.class_weight
                if weight:
                    weight = torch.Tensor(weight).to(self.device)

                # target = resize_labels(target, (score.size()[2], score.size()[3]))
                # target = target.to(self.device)
                loss = CrossEntropyLoss(score,
                                        target,
                                        weight=weight,
                                        reduction='mean',
                                        ignore_index=-1)
                loss_data = loss.data.item()
                if np.isnan(loss_data):
                    raise ValueError('loss is nan while validating')

                val_loss_meter.update(loss_data)

                # if not isinstance(score, tuple):
                #     lbl_pred = score.data.max(1)[1].cpu().numpy()
                # else:
                #     lbl_pred = score[-1].data.max(1)[1].cpu().numpy()

                # lbl_pred, lbl_true = get_multiscale_results(score, target, upsample_logits=False)
                imgs = data.data.cpu()
                if isinstance(score, tuple):
                    lbl_pred = score[-1].data.max(1)[1].cpu().numpy()
                else:
                    lbl_pred = score.data.max(1)[1].cpu().numpy()
                lbl_true = target.data.cpu()
                for img, lt, lp in zip(imgs, lbl_true, lbl_pred):
                    img, lt = self.val_loader.dataset.untransform(img, lt)
                    val_metrics.update(lt, lp)
                    # img = Image.fromarray(img).resize((lt.shape[1], lt.shape[0]), Image.BILINEAR)
                    # img = np.array(img)
                    if len(visualizations) < 9:
                        viz = visualize_segmentation(
                            lbl_pred=lp,
                            lbl_true=lt,
                            img=img,
                            n_classes=self.n_classes,
                            dataloader=self.train_loader)
                        visualizations.append(viz)

        acc, acc_cls, mean_iou, fwavacc, _ = val_metrics.get_scores()
        metrics = [acc, acc_cls, mean_iou, fwavacc]

        print(f'\nEpoch: {self.epoch}',
              f'loss: {val_loss_meter.avg}, mIoU: {mean_iou}')

        out = osp.join(self.out, 'visualization_viz')
        if not osp.exists(out):
            os.makedirs(out)
        out_file = osp.join(out, 'epoch{:0>5d}.jpg'.format(self.epoch))
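        # note: scipy.misc.imsave has been removed from recent SciPy releases;
        # imageio.imwrite is a commonly used replacement for this call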
        scipy.misc.imsave(out_file, get_tile_image(visualizations))

        with open(osp.join(self.out, 'log.csv'), 'a') as f:
            elapsed_time = (datetime.datetime.now(pytz.timezone('UTC')) -
                            self.timestamp_start).total_seconds()
            log = [self.epoch] + [''] * 5 + \
                  [val_loss_meter.avg] + metrics + [elapsed_time]
            log = map(str, log)
            f.write(','.join(log) + '\n')

        mean_iu = metrics[2]
        is_best = mean_iu > self.best_mean_iu
        if is_best:
            self.best_mean_iu = mean_iu
        torch.save(
            {
                'epoch': self.epoch,
                'arch': self.model.__class__.__name__,
                'optim_state_dict': self.optim.state_dict(),
                'model_state_dict': self.model.state_dict(),
                'best_mean_iu': self.best_mean_iu,
            }, osp.join(self.out, 'checkpoint.pth.tar'))
        if is_best:
            shutil.copy(osp.join(self.out, 'checkpoint.pth.tar'),
                        osp.join(self.out, 'model_best.pth.tar'))

        val_loss_meter.reset()
        val_metrics.reset()
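
Example #16 drives its best-checkpoint logic from the mean IoU returned by runningScore.get_scores(). A minimal, hedged sketch of the confusion-matrix bookkeeping such a metric class usually performs (the real class ships with the repository; the class name and return format below are illustrative):

import numpy as np

class RunningScoreSketch:
    def __init__(self, n_classes):
        self.n_classes = n_classes
        self.confusion = np.zeros((n_classes, n_classes), dtype=np.int64)

    def update(self, label_true, label_pred):
        # accumulate a confusion matrix over all valid (non-ignored) pixels
        mask = (label_true >= 0) & (label_true < self.n_classes)
        idx = self.n_classes * label_true[mask].astype(int) + label_pred[mask]
        self.confusion += np.bincount(
            idx, minlength=self.n_classes ** 2
        ).reshape(self.n_classes, self.n_classes)

    def get_scores(self):
        hist = self.confusion.astype(np.float64)
        acc = np.diag(hist).sum() / hist.sum()                       # pixel accuracy
        iou = np.diag(hist) / (hist.sum(1) + hist.sum(0) - np.diag(hist))
        return acc, float(np.nanmean(iou))                           # acc, mean IoU

    def reset(self):
        self.confusion.fill(0)
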