Example #1
        trainloader = data.DataLoader(t_loader,
                                      batch_size=cfg["training"]["batch_size"],
                                      num_workers=cfg["training"]["n_workers"],
                                      shuffle=True,
                                      drop_last=True)

        valloader = data.DataLoader(v_loader,
                                    batch_size=cfg["training"]["batch_size"],
                                    num_workers=cfg["training"]["n_workers"])

        # Setup Model
        model = get_model(cfg, t_loader.n_classes).to(device)
        model = torch.nn.DataParallel(model,
                                      device_ids=range(
                                          torch.cuda.device_count()))

        # Setup optimizer
        optimizer_cls = get_optimizer(cfg)
        optimizer_params = {
            k: v
            for k, v in cfg["training"]["optimizer"].items() if k != "name"
        }
        optimizer = optimizer_cls(model.parameters(), **optimizer_params)
        logger.info("Using optimizer {}".format(optimizer))

        # Setup scheduler
        scheduler = get_scheduler(optimizer, cfg["training"]["lr_schedule"])

        # Setup loss
        loss_fn = get_loss_function(cfg)
        logger.info("Using loss {}".format(loss_fn))

        # ================== TRAINING ==================
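
A minimal sketch of the optimizer factory this pattern relies on: get_optimizer resolves the optimizer class from the config's optimizer name, and every remaining key in cfg["training"]["optimizer"] is passed through as a keyword argument. The name-to-class mapping below is an assumption for illustration; the real project may support more optimizers.

import torch.optim as optim

def get_optimizer(cfg):
    # 'name' selects the optimizer class; all other keys become kwargs.
    optimizers = {"sgd": optim.SGD, "adam": optim.Adam, "adamw": optim.AdamW}
    return optimizers[cfg["training"]["optimizer"]["name"]]
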
Example #2
def train(cfg, writer, logger):

    # Setup seeds
    torch.manual_seed(cfg.get('seed', 1337))
    torch.cuda.manual_seed(cfg.get('seed', 1337))
    np.random.seed(cfg.get('seed', 1337))
    random.seed(cfg.get('seed', 1337))

    # Setup device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Setup Augmentations
    augmentations = cfg['training'].get('augmentations', None)
    data_aug = get_composed_augmentations(augmentations)

    # Setup Dataloader
    # (The generic get_loader-based setup used in the other examples is
    # replaced here by a custom multiband satellite dataset below.)

    paths = {
        'masks': './satellitedata/patchark_train/gt/',
        'images': './satellitedata/patchark_train/rgb',
        'nirs': './satellitedata/patchark_train/nir',
        'swirs': './satellitedata/patchark_train/swir',
        'vhs': './satellitedata/patchark_train/vh',
        'vvs': './satellitedata/patchark_train/vv',
        'redes': './satellitedata/patchark_train/rede',
        'ndvis': './satellitedata/patchark_train/ndvi',
    }

    valpaths = {
        'masks': './satellitedata/patchark_val/gt/',
        'images': './satellitedata/patchark_val/rgb',
        'nirs': './satellitedata/patchark_val/nir',
        'swirs': './satellitedata/patchark_val/swir',
        'vhs': './satellitedata/patchark_val/vh',
        'vvs': './satellitedata/patchark_val/vv',
        'redes': './satellitedata/patchark_val/rede',
        'ndvis': './satellitedata/patchark_val/ndvi',
    }

    n_classes = 3
    train_img_paths = [
        pth for pth in os.listdir(paths['images'])
        if ('_01_' not in pth) and ('_25_' not in pth)
    ]
    val_img_paths = [
        pth for pth in os.listdir(valpaths['images'])
        if ('_01_' not in pth) and ('_25_' not in pth)
    ]
    ntrain = len(train_img_paths)
    nval = len(val_img_paths)
    train_idx = list(range(ntrain))
    val_idx = list(range(nval))
    trainds = ImageProvider(MultibandImageType, paths, image_suffix='.png')
    valds = ImageProvider(MultibandImageType, valpaths, image_suffix='.png')

    config_path = 'crop_pspnet_config.json'
    with open(config_path, 'r') as f:
        mycfg = json.load(f)
        train_data_path = './satellitedata/'
        print('train_data_path: {}'.format(train_data_path))
        dataset_path, train_dir = os.path.split(train_data_path)
        print('dataset_path: {}'.format(dataset_path) +
              ',  train_dir: {}'.format(train_dir))
        mycfg['dataset_path'] = dataset_path
    config = Config(**mycfg)

    config = update_config(config, num_channels=12, nb_epoch=50)
    #dataset_train = TrainDataset(trainds, train_idx, config, transforms=augment_flips_color)
    dataset_train = TrainDataset(trainds, train_idx, config, 1)
    dataset_val = TrainDataset(valds, val_idx, config, 1)
    trainloader = data.DataLoader(dataset_train,
                                  batch_size=cfg['training']['batch_size'],
                                  num_workers=cfg['training']['n_workers'],
                                  shuffle=True)

    valloader = data.DataLoader(dataset_val,
                                batch_size=cfg['training']['batch_size'],
                                num_workers=cfg['training']['n_workers'],
                                shuffle=False)
    # Setup Metrics
    running_metrics_train = runningScore(n_classes)
    running_metrics_val = runningScore(n_classes)

    k = 0
    nbackground = 0
    ncorn = 0
    #ncotton = 0
    #nrice = 0
    nsoybean = 0

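    # One full pass over the training set just to collect per-class pixel counts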
    for indata in trainloader:
        k += 1
        gt = indata['seg_label'].data.cpu().numpy()
        nbackground += (gt == 0).sum()
        ncorn += (gt == 1).sum()
        #ncotton += (gt == 2).sum()
        #nrice += (gt == 3).sum()
        nsoybean += (gt == 2).sum()

    print('k = {}'.format(k))
    print('nbackground: {}'.format(nbackground))
    print('ncorn: {}'.format(ncorn))
    #print('ncotton: {}'.format(ncotton))
    #print('nrice: {}'.format(nrice))
    print('nsoybean: {}'.format(nsoybean))

    wgts = [1.0, 1.0 * nbackground / ncorn, 1.0 * nbackground / nsoybean]
    total_wgts = sum(wgts)
    wgt_background = wgts[0] / total_wgts
    wgt_corn = wgts[1] / total_wgts
    #wgt_cotton = wgts[2]/total_wgts
    #wgt_rice = wgts[3]/total_wgts
    wgt_soybean = wgts[2] / total_wgts
    weights = torch.tensor([wgt_background, wgt_corn, wgt_soybean],
                           device=device)

    # weights = torch.tensor([1.0, 1.0, 1.0], device=device)  # uniform alternative
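
    # For intuition, with hypothetical counts nbackground=900k, ncorn=60k and
    # nsoybean=40k the raw weights above are [1.0, 15.0, 22.5]; normalized by
    # their sum (38.5) they become roughly [0.026, 0.390, 0.584], strongly
    # up-weighting the two rare crop classes.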

    # Setup Model
    model = get_model(cfg['model'], n_classes).to(device)

    model = torch.nn.DataParallel(model,
                                  device_ids=range(torch.cuda.device_count()))

    # Setup optimizer, lr_scheduler and loss function
    optimizer_cls = get_optimizer(cfg)
    optimizer_params = {
        k: v
        for k, v in cfg['training']['optimizer'].items() if k != 'name'
    }

    optimizer = optimizer_cls(model.parameters(), **optimizer_params)
    logger.info("Using optimizer {}".format(optimizer))

    scheduler = get_scheduler(optimizer, cfg['training']['lr_schedule'])

    loss_fn = get_loss_function(cfg)
    logger.info("Using loss {}".format(loss_fn))

    start_iter = 0
    if cfg['training']['resume'] is not None:
        if os.path.isfile(cfg['training']['resume']):
            logger.info(
                "Loading model and optimizer from checkpoint '{}'".format(
                    cfg['training']['resume']))
            checkpoint = torch.load(cfg['training']['resume'])
            model.load_state_dict(checkpoint["model_state"])
            optimizer.load_state_dict(checkpoint["optimizer_state"])
            scheduler.load_state_dict(checkpoint["scheduler_state"])
            start_iter = checkpoint["epoch"]
            logger.info("Loaded checkpoint '{}' (iter {})".format(
                cfg['training']['resume'], checkpoint["epoch"]))
        else:
            logger.info("No checkpoint found at '{}'".format(
                cfg['training']['resume']))

    val_loss_meter = averageMeter()
    time_meter = averageMeter()

    best_iou = -100.0
    i = start_iter
    flag = True

    while i <= cfg['training']['train_iters'] and flag:
        for inputdata in trainloader:
            i += 1
            start_ts = time.time()
            model.train()
            images = inputdata['img_data']
            labels = inputdata['seg_label']
            #print('images.size: {}'.format(images.size()))
            #print('labels.size: {}'.format(labels.size()))
            images = images.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            outputs = model(images)

            #print('outputs.size: {}'.format(outputs[1].size()))
            #print('labels.size: {}'.format(labels.size()))

            loss = loss_fn(input=outputs[1], target=labels, weight=weights)

            loss.backward()
            optimizer.step()
            scheduler.step()  # step the scheduler after the optimizer (PyTorch >= 1.1)

            time_meter.update(time.time() - start_ts)

            if (i + 1) % cfg['training']['print_interval'] == 0:
                fmt_str = "Iter [{:d}/{:d}]  Loss: {:.4f}  Time/Image: {:.4f}"
                print_str = fmt_str.format(
                    i + 1, cfg['training']['train_iters'], loss.item(),
                    time_meter.avg / cfg['training']['batch_size'])

                print(print_str)
                logger.info(print_str)
                writer.add_scalar('loss/train_loss', loss.item(), i + 1)
                time_meter.reset()

            if (i + 1) % cfg['training']['val_interval'] == 0 or \
               (i + 1) == cfg['training']['train_iters']:
                model.eval()
                with torch.no_grad():
                    for inputdata in valloader:
                        images_val = inputdata['img_data']
                        labels_val = inputdata['seg_label']
                        images_val = images_val.to(device)
                        labels_val = labels_val.to(device)

                        outputs = model(images_val)
                        val_loss = loss_fn(input=outputs, target=labels_val)

                        pred = outputs.data.max(1)[1].cpu().numpy()
                        gt = labels_val.data.cpu().numpy()

                        running_metrics_val.update(gt, pred)
                        val_loss_meter.update(val_loss.item())

                writer.add_scalar('loss/val_loss', val_loss_meter.avg, i + 1)
                logger.info("Iter %d Loss: %.4f" % (i + 1, val_loss_meter.avg))

                score, class_iou = running_metrics_val.get_scores()
                for k, v in score.items():
                    print(k, v)
                    logger.info('{}: {}'.format(k, v))
                    writer.add_scalar('val_metrics/{}'.format(k), v, i + 1)

                for k, v in class_iou.items():
                    logger.info('{}: {}'.format(k, v))
                    writer.add_scalar('val_metrics/cls_{}'.format(k), v, i + 1)

                val_loss_meter.reset()
                running_metrics_val.reset()

                if score["Mean IoU : \t"] >= best_iou:
                    best_iou = score["Mean IoU : \t"]
                    state = {
                        "epoch": i + 1,
                        "model_state": model.state_dict(),
                        "optimizer_state": optimizer.state_dict(),
                        "scheduler_state": scheduler.state_dict(),
                        "best_iou": best_iou,
                    }
                    save_path = os.path.join(
                        writer.file_writer.get_logdir(),
                        "{}_{}_best_model.pkl".format(cfg['model']['arch'],
                                                      cfg['data']['dataset']))
                    torch.save(state, save_path)

            if (i + 1) == cfg['training']['train_iters']:
                flag = False
                break
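
All of these training loops share the runningScore metric object. A simplified sketch of the confusion-matrix implementation it is typically based on (modeled on pytorch-semseg's metrics module; the real class reports more score keys than shown here):

import numpy as np

class runningScore:
    def __init__(self, n_classes):
        self.n_classes = n_classes
        self.confusion_matrix = np.zeros((n_classes, n_classes))

    def update(self, label_trues, label_preds):
        n = self.n_classes
        for lt, lp in zip(label_trues, label_preds):
            mask = (lt >= 0) & (lt < n)
            # joint histogram of (true, predicted) class ids
            self.confusion_matrix += np.bincount(
                n * lt[mask].astype(int) + lp[mask], minlength=n ** 2
            ).reshape(n, n)

    def get_scores(self):
        hist = self.confusion_matrix
        iu = np.diag(hist) / (hist.sum(axis=1) + hist.sum(axis=0) - np.diag(hist))
        return {"Mean IoU : \t": np.nanmean(iu)}, dict(enumerate(iu))

    def reset(self):
        self.confusion_matrix.fill(0)
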
Example #3
def train(cfg, logger):
    
    # Setup seeds (remove these to draw different random samples each run)
    torch.manual_seed(cfg.get('seed', 1337))
    torch.cuda.manual_seed(cfg.get('seed', 1337))
    np.random.seed(cfg.get('seed', 1337))
    random.seed(cfg.get('seed', 1337))

    # Setup device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("DEVICE: ",device)

    # Setup Augmentations
    augmentations = cfg['training'].get('augmentations', None)
    data_aug = get_composed_augmentations(augmentations)

    # Setup Dataloader
    data_loader = get_loader(cfg['data']['dataset'])
    
    if torch.cuda.is_available():
        data_path = cfg['data']['server_path']
    else:
        data_path = cfg['data']['path']
    
    t_loader = data_loader(
        data_path,
        is_transform=True,
        split=cfg['data']['train_split'],
        img_size=(cfg['data']['img_rows'], cfg['data']['img_cols']),
        augmentations=data_aug)
    
    n_classes = t_loader.n_classes
    trainloader = data.DataLoader(t_loader,
                                  batch_size=cfg['training']['batch_size'], 
                                  num_workers=cfg['training']['n_workers'], 
                                  shuffle=True)

    number_of_images_training = t_loader.number_of_images
    
    # Setup Hierarchy
    
    if torch.cuda.is_available():
        if cfg['data']['dataset'] == "vistas":
            if cfg['data']['viking']:
                root = create_tree_from_textfile("/users/brm512/scratch/experiments/meetshah-semseg/mapillary_tree.txt")
            else:
                root = create_tree_from_textfile("/home/userfs/b/brm512/experiments/meetshah-semseg/mapillary_tree.txt")
        elif cfg['data']['dataset'] == "faces":
            if cfg['data']['viking']:
                root = create_tree_from_textfile("/users/brm512/scratch/experiments/meetshah-semseg/faces_tree.txt")
            else:
                root = create_tree_from_textfile("/home/userfs/b/brm512/experiments/meetshah-semseg/faces_tree.txt")
    else:
        if cfg['data']['dataset'] == "vistas":
            root = create_tree_from_textfile("/home/brm512/Pytorch/meetshah-semseg/mapillary_tree.txt")
        elif cfg['data']['dataset'] == "faces":
            root = create_tree_from_textfile("/home/brm512/Pytorch/meetshah-semseg/faces_tree.txt")

    add_channels(root,0)
    add_levels(root,find_depth(root))
    
    class_lookup = [0,10,7,8,9,1,6,4,5,2,3]  # correcting for tree channel and data integer class correspondence  # HELEN
    #class_lookup = [0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,48,51,45,46,47,49,50,52,53,54,55,56,57,58,59,60,61,62,63,64,65] # VISTAS
    update_channels(root, class_lookup)

    # Setup models for Hierarchical and Standard training. Note we use Tree synonymously with hierarchy

    model_nontree = get_model(cfg['model'], n_classes).to(device)
    model_tree = get_model(cfg['model'], n_classes).to(device)
    model_nontree = torch.nn.DataParallel(model_nontree, device_ids=range(torch.cuda.device_count()))
    model_tree = torch.nn.DataParallel(model_tree, device_ids=range(torch.cuda.device_count()))

    # Setup optimizer, lr_scheduler and loss function
    optimizer_cls_nontree = get_optimizer(cfg)
    optimizer_params_nontree = {k:v for k, v in cfg['training']['optimizer'].items() if k != 'name'}
    optimizer_nontree = optimizer_cls_nontree(model_nontree.parameters(), **optimizer_params_nontree)
    logger.info("Using non tree optimizer {}".format(optimizer_nontree))

    optimizer_cls_tree = get_optimizer(cfg)
    optimizer_params_tree = {k: v for k, v in cfg['training']['optimizer'].items() if k != 'name'}
    optimizer_tree = optimizer_cls_tree(model_tree.parameters(), **optimizer_params_tree)
    logger.info("Using tree optimizer {}".format(optimizer_tree))
    
    loss_fn = get_loss_function(cfg)
    logger.info("Using loss {}".format(loss_fn))

    loss_meter_nontree = averageMeter()
    if cfg['training']['use_hierarchy']:
        loss_meter_level0_nontree = averageMeter()
        loss_meter_level1_nontree = averageMeter()
        loss_meter_level2_nontree = averageMeter()
        loss_meter_level3_nontree = averageMeter()
        
    loss_meter_tree = averageMeter()
    if cfg['training']['use_hierarchy']:
        loss_meter_level0_tree = averageMeter()
        loss_meter_level1_tree = averageMeter()
        loss_meter_level2_tree = averageMeter()
        loss_meter_level3_tree = averageMeter()
        
        
    time_meter = averageMeter()
    epoch = 0
    i = 0
    flag = True
    number_epoch_iters = number_of_images_training // cfg['training']['batch_size']
    
    # TRAINING
    start_training_time = time.time()
    
    while i < cfg['training']['train_iters'] and flag and epoch < cfg['training']['epochs']:
       
        epoch_start_time = time.time()
        epoch = epoch + 1
        for (images, labels) in trainloader:
            i = i + 1
            start_ts = time.time()
        
            model_nontree.train()
            model_tree.train()
            
            images = images.to(device)
            labels = labels.to(device)

            optimizer_nontree.zero_grad()
            optimizer_tree.zero_grad()
            
            outputs_nontree = model_nontree(images)
            outputs_tree = model_tree(images)

            # nontree loss calculation
            if cfg['training']['use_tree_loss']:
                loss_nontree = loss_fn(input=outputs_nontree, target=labels, root=root,
                                       use_hierarchy=cfg['training']['use_hierarchy'])
                level_losses_nontree = loss_nontree[1]
                mainloss_nontree = loss_fn(input=outputs_nontree, target=labels, root=root,
                                           use_hierarchy=False)[0]
            else:
                loss_nontree = loss_fn(input=outputs_nontree, target=labels)
                mainloss_nontree = loss_nontree

            # tree loss calculation
            if cfg['training']['use_tree_loss']:
                loss_tree = loss_fn(input=outputs_tree, target=labels, root=root,
                                    use_hierarchy=cfg['training']['use_hierarchy'])
                level_losses_tree = loss_tree[1]
                mainloss_tree = loss_tree[0]
            else:
                loss_tree = loss_fn(input=outputs_tree, target=labels)
                mainloss_tree = loss_tree
            
            loss_meter_nontree.update(mainloss_nontree.item())
            if cfg['training']['use_hierarchy'] and not cfg['training']['phased']:
                loss_meter_level0_nontree.update(level_losses_nontree[0])
                loss_meter_level1_nontree.update(level_losses_nontree[1])
                loss_meter_level2_nontree.update(level_losses_nontree[2])
                loss_meter_level3_nontree.update(level_losses_nontree[3])
                
            loss_meter_tree.update(mainloss_tree.item())
            if cfg['training']['use_hierarchy'] and not cfg['training']['phased']:
                loss_meter_level0_tree.update(level_losses_tree[0])
                loss_meter_level1_tree.update(level_losses_tree[1])
                loss_meter_level2_tree.update(level_losses_tree[2])
                loss_meter_level3_tree.update(level_losses_tree[3])

            # optimise nontree and tree
            mainloss_nontree.backward()
            mainloss_tree.backward()
            
            optimizer_nontree.step()
            optimizer_tree.step()

            time_meter.update(time.time() - start_ts)
            
            # For printing/logging stats
            if i % cfg['training']['print_interval'] == 0:
                fmt_str = "Epoch [{:d}/{:d}] Iter [{:d}/{:d}] IterNonTreeLoss: {:.4f}  IterTreeLoss: {:.4f}  Time/Image: {:.4f}"
                print_str = fmt_str.format(epoch, cfg['training']['epochs'],
                                           i % number_epoch_iters, number_epoch_iters,
                                           mainloss_nontree.item(), mainloss_tree.item(),
                                           time_meter.avg / cfg['training']['batch_size'])
    

                print(print_str)
                logger.info(print_str)
                time_meter.reset()
                
            # VALIDATION (every val_interval iters, at epoch boundaries, and at the end)
            if i % cfg['training']['val_interval'] == 0 or i % number_epoch_iters == 0 or i == cfg['training']['train_iters']:
                validate(cfg, model_nontree, model_tree, loss_fn, device, root)
                # reset meters after validation
                loss_meter_nontree.reset()
                if cfg['training']['use_hierarchy']:
                    loss_meter_level0_nontree.reset()
                    loss_meter_level1_nontree.reset()
                    loss_meter_level2_nontree.reset()
                    loss_meter_level3_nontree.reset()

                loss_meter_tree.reset()     
                if cfg['training']['use_hierarchy']:
                    loss_meter_level0_tree.reset()
                    loss_meter_level1_tree.reset()
                    loss_meter_level2_tree.reset()
                    loss_meter_level3_tree.reset()
            
            # Stop once the iteration budget is reached
            if i == cfg['training']['train_iters']:
                flag = False
                break
            
        print("EPOCH TIME (MIN): ", epoch, (time.time() - epoch_start_time)/60.0)
        logger.info("Epoch %d took %.4f minutes" % (int(epoch) , (time.time() - epoch_start_time)/60.0))
           
    print("TRAINING TIME: ",(time.time() - start_training_time)/3600.0)
Example #4
def train(cfg, writer, logger):
    
    # Setup random seeds to a fixed value for reproducibility
    # seed = 1337
    # torch.manual_seed(seed)
    # torch.cuda.manual_seed(seed)
    # np.random.seed(seed)
    # random.seed(seed)
    # np.random.default_rng(seed)

    # Setup Augmentations
    augmentations = cfg.train.augment
    logger.info(f'using augments: {augmentations}')
    data_aug = get_composed_augmentations(augmentations)

    # Setup Dataloader
    data_loader = get_loader(cfg.data.dataloader)
    data_path = cfg.data.path

    logger.info("Using dataset: {}".format(data_path))

    t_loader = data_loader(
        data_path,
        # transform=None,
        # time_shuffle=cfg.data.time_shuffle,
        # to_tensor=False,
        data_format=cfg.data.format,
        split=cfg.data.train_split,
        norm=cfg.data.norm,
        augments=data_aug
        )

    v_loader = data_loader(
        data_path,
        # transform=None,
        # time_shuffle=cfg.data.time_shuffle,
        # to_tensor=False,
        data_format=cfg.data.format,
        split=cfg.data.val_split,
        )
    train_data_len = len(t_loader)
    logger.info(f'num of train samples: {train_data_len} \nnum of val samples: {len(v_loader)}')

    batch_size = cfg.train.batch_size
    epoch = cfg.train.epoch
    train_iter = int(np.ceil(train_data_len / batch_size) * epoch)
    logger.info(f'total train iter: {train_iter}')
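    # e.g. with 1,000 training samples, batch_size=16 and epoch=50 this gives
    # ceil(1000 / 16) * 50 = 63 * 50 = 3150 iterations (illustrative numbers)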

    trainloader = data.DataLoader(t_loader,
                                  batch_size=batch_size, 
                                  num_workers=cfg.train.n_workers, 
                                  shuffle=True,
                                  persistent_workers=True,
                                  drop_last=True)

    valloader = data.DataLoader(v_loader,
                                batch_size=10,
                                num_workers=cfg.train.n_workers)

    # Setup Model
    device = f'cuda:{cfg.gpu[0]}'
    model = get_model(cfg.model, 2).to(device)
    input_size = (cfg.model.input_nbr, 512, 512)
    logger.info(f"Using Model: {cfg.model.arch}")
    # logger.info(f'model summary: {summary(model, input_size=(input_size, input_size), is_complex=True)}')
    model = torch.nn.DataParallel(model, device_ids=cfg.gpu)  # automatic multi-GPU execution; works well
    
    # Setup optimizer, lr_scheduler and loss function
    optimizer_cls = get_optimizer(cfg)
    optimizer_params = {k:v for k, v in vars(cfg.train.optimizer).items()
                        if k not in ('name', 'wrap')}
    optimizer = optimizer_cls(model.parameters(), **optimizer_params)
    logger.info("Using optimizer {}".format(optimizer))
    if hasattr(cfg.train.optimizer, 'wrap') and cfg.train.optimizer.wrap == 'lars':
        optimizer = LARS(optimizer=optimizer)
        logger.info(f'wrap optimizer with {cfg.train.optimizer.wrap}')
    scheduler = get_scheduler(optimizer, cfg.train.lr)
    loss_fn = get_loss_function(cfg)
    logger.info(f"Using loss ,{str(cfg.train.loss)}")

    # load checkpoints
    val_cls_1_acc = 0
    best_cls_1_acc_now = 0
    best_cls_1_acc_iter_now = 0
    val_macro_OA = 0
    best_macro_OA_now = 0
    best_macro_OA_iter_now = 0
    start_iter = 0
    if cfg.train.resume is not None:
        if os.path.isfile(cfg.train.resume):
            logger.info(
                "Loading model and optimizer from checkpoint '{}'".format(cfg.train.resume)
            )

            # load model state
            checkpoint = torch.load(cfg.train.resume)
            model.load_state_dict(checkpoint["model_state"])
            optimizer.load_state_dict(checkpoint["optimizer_state"])
            scheduler.load_state_dict(checkpoint["scheduler_state"])
            # best_cls_1_acc_now = checkpoint["best_cls_1_acc_now"]
            # best_cls_1_acc_iter_now = checkpoint["best_cls_1_acc_iter_now"]

            start_iter = checkpoint["epoch"]
            logger.info(
                "Loaded checkpoint '{}' (iter {})".format(
                    cfg.train.resume, checkpoint["epoch"]
                )
            )

            # copy tensorboard files
            resume_src_dir = osp.split(cfg.train.resume)[0]
            # shutil.copytree(resume_src_dir, writer.get_logdir())
            for file in os.listdir(resume_src_dir):
                # skip logs, configs and the last-model checkpoint
                if not ('.log' in file or '.yml' in file or '_last_model' in file):
                    resume_dst_dir = writer.get_logdir()
                    fu.copy(osp.join(resume_src_dir, file), resume_dst_dir)

        else:
            logger.info("No checkpoint found at '{}'".format(cfg.train.resume))

    # Setup Metrics
    running_metrics_val = runningScore(2)
    running_metrics_train = runningScore(2)
    val_loss_meter = averageMeter()
    train_time_meter = averageMeter()

    # train
    it = start_iter
    train_start_time = time.time() 
    train_val_start_time = time.time()
    model.train()   
    while it < train_iter:
        for (file_a, file_b, label, mask) in trainloader:
            it += 1           
            file_a = file_a.to(device)            
            file_b = file_b.to(device)            
            label = label.to(device)            
            mask = mask.to(device)

            optimizer.zero_grad()
            # print(f'dtype: {file_a.dtype}')
            outputs = model(file_a, file_b)
            loss = loss_fn(input=outputs, target=label, mask=mask)
            loss.backward()

            # print('conv11: ', model.conv11.weight.grad, model.conv11.weight.grad.shape)
            # print('conv21: ', model.conv21.weight.grad, model.conv21.weight.grad.shape)
            # print('conv31: ', model.conv31.weight.grad, model.conv31.weight.grad.shape)

            # In PyTorch 1.1.0 and later, you should call `optimizer.step()` before `lr_scheduler.step()`
            optimizer.step()
            scheduler.step()
            
            # record the acc of the minibatch
            pred = outputs.max(1)[1].cpu().numpy()
            running_metrics_train.update(label.cpu().numpy(), pred, mask.cpu().numpy())

            train_time_meter.update(time.time() - train_start_time)

            if it % cfg.train.print_interval == 0:
                # acc of the samples between print_interval
                score, _ = running_metrics_train.get_scores()
                train_cls_0_acc, train_cls_1_acc = score['Acc']
                fmt_str = "Iter [{:d}/{:d}]  train Loss: {:.4f}  Time/Image: {:.4f},\n0:{:.4f}\n1:{:.4f}"
                print_str = fmt_str.format(it,
                                           train_iter,
                                           loss.item(),      #extracts the loss’s value as a Python float.
                                           train_time_meter.avg / cfg.train.batch_size,train_cls_0_acc, train_cls_1_acc)
                runing_metrics_train.reset()
                train_time_meter.reset()
                logger.info(print_str)
                writer.add_scalar('loss/train_loss', loss.item(), it)
                writer.add_scalars('metrics/train', {'cls_0':train_cls_0_acc, 'cls_1':train_cls_1_acc}, it)
                # writer.add_scalar('train_metrics/acc/cls_0', train_cls_0_acc, it)
                # writer.add_scalar('train_metrics/acc/cls_1', train_cls_1_acc, it)

            if it % cfg.train.val_interval == 0 or \
               it == train_iter:
                val_start_time = time.time()
                model.eval()            # change behavior like drop out
                with torch.no_grad():   # disable autograd, save memory usage
                    for (file_a_val, file_b_val, label_val, mask_val) in valloader:      
                        file_a_val = file_a_val.to(device)            
                        file_b_val = file_b_val.to(device)

                        outputs = model(file_a_val, file_b_val)
                        # tensor.max() returns the maximum value and its indices
                        pred = outputs.max(1)[1].cpu().numpy()
                        running_metrics_val.update(label_val.numpy(), pred, mask_val.numpy())
            
                        label_val = label_val.to(device)            
                        mask_val = mask_val.to(device)
                        val_loss = loss_fn(input=outputs, target=label_val, mask=mask_val)
                        val_loss_meter.update(val_loss.item())

                score, _ = running_metrics_val.get_scores()
                val_cls_0_acc, val_cls_1_acc = score['Acc']

                writer.add_scalar('loss/val_loss', val_loss_meter.avg, it)
                logger.info(f"Iter [{it}/{train_iter}], val Loss: {val_loss_meter.avg:.4f} Time/Image: {(time.time()-val_start_time)/len(v_loader):.4f}\n0: {val_cls_0_acc:.4f}\n1:{val_cls_1_acc:.4f}")
                # lr_now = optimizer.param_groups[0]['lr']
                # logger.info(f'lr: {lr_now}')
                # writer.add_scalar('lr', lr_now, it+1)

                logger.info('0: {:.4f}\n1:{:.4f}'.format(val_cls_0_acc, val_cls_1_acc))
                writer.add_scalars('metrics/val', {'cls_0':val_cls_0_acc, 'cls_1':val_cls_1_acc}, it)
                # writer.add_scalar('val_metrics/acc/cls_0', val_cls_0_acc, it)
                # writer.add_scalar('val_metrics/acc/cls_1', val_cls_1_acc, it)

                val_loss_meter.reset()
                running_metrics_val.reset()

                # macro-averaged accuracy over the two classes
                val_macro_OA = (val_cls_0_acc + val_cls_1_acc) / 2
                if val_macro_OA >= best_macro_OA_now and it > 200:
                    best_macro_OA_now = val_macro_OA
                    best_macro_OA_iter_now = it
                    state = {
                        "epoch": it,
                        "model_state": model.state_dict(),
                        "optimizer_state": optimizer.state_dict(),
                        "scheduler_state": scheduler.state_dict(),
                        "best_macro_OA_now": best_macro_OA_now,
                        'best_macro_OA_iter_now':best_macro_OA_iter_now,
                    }
                    save_path = os.path.join(writer.file_writer.get_logdir(), "{}_{}_best_model.pkl".format(cfg.model.arch,cfg.data.dataloader))
                    torch.save(state, save_path)

                    logger.info("best OA now =  %.8f" % (best_macro_OA_now))
                    logger.info("best OA iter now= %d" % (best_macro_OA_iter_now))

                train_val_time = time.time() - train_val_start_time
                remain_time = train_val_time * (train_iter-it) / it
                m, s = divmod(remain_time, 60)
                h, m = divmod(m, 60)
                if remain_time > 0:
                    train_time = "Remain train time = %d hours %d minutes %d seconds \n" % (h, m, s)
                else:
                    train_time = "Remain train time : train completed.\n"
                logger.info(train_time)
                model.train()

            train_start_time = time.time() 

    logger.info("best OA now =  %.8f" % (best_macro_OA_now))
    logger.info("best OA iter now= %d" % (best_macro_OA_iter_now))

    state = {
            "epoch": it,
            "model_state": model.state_dict(),
            "optimizer_state": optimizer.state_dict(),
            "scheduler_state": scheduler.state_dict(),
            "best_macro_OA_now": best_macro_OA_now,
            'best_macro_OA_iter_now':best_macro_OA_iter_now,
            }
    save_path = os.path.join(writer.file_writer.get_logdir(), "{}_{}_last_model.pkl".format(cfg.model.arch, cfg.data.dataloader))
    torch.save(state, save_path)
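
The remaining-time message in this example comes from a linear extrapolation over the iterations completed so far. Factored out as a standalone helper (a sketch; the example inlines this logic):

def estimate_remaining_seconds(elapsed_s, done_iters, total_iters):
    # Assumes the per-iteration cost stays roughly constant over training.
    return elapsed_s * (total_iters - done_iters) / max(done_iters, 1)

For example, 600 s elapsed after 2,000 of 10,000 iterations predicts 600 * 8000 / 2000 = 2400 s, i.e. about 40 minutes left.
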
Example #5
def train(cfg, writer, logger, args):  # args supplies the load_weight_only / not_load_optimizer flags used below

    # Setup seeds
    # torch.manual_seed(cfg.get("seed", 1337))
    # torch.cuda.manual_seed(cfg.get("seed", 1337))
    # np.random.seed(cfg.get("seed", 1337))
    # random.seed(cfg.get("seed", 1337))

    # Setup device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Setup Augmentations
    augmentations = cfg["training"].get("augmentations", None)
    data_aug = get_composed_augmentations(augmentations)

    # Setup Dataloader
    data_loader = get_loader(cfg["data"]["dataset"])
    data_path = cfg["data"]["path"]

    t_loader = data_loader(
        data_path,
        is_transform=True,
        split=cfg["data"]["train_split"],
        img_size=(cfg["data"]["img_rows"], cfg["data"]["img_cols"]),
        augmentations=data_aug,
    )

    v_loader = data_loader(
        data_path,
        is_transform=True,
        split=cfg["data"]["val_split"],
        img_size=(cfg["data"]["img_rows"], cfg["data"]["img_cols"]),
    )

    n_classes = t_loader.n_classes
    trainloader = data.DataLoader(
        t_loader,
        batch_size=cfg["training"]["batch_size"],
        num_workers=cfg["training"]["n_workers"],
        shuffle=True,
    )

    valloader = data.DataLoader(v_loader,
                                batch_size=cfg["training"]["batch_size"],
                                num_workers=cfg["training"]["n_workers"])

    # Setup Metrics
    running_metrics_val = runningScore(n_classes)

    # Setup Model
    model = get_model(cfg["model"], n_classes).to(device)

    # model = torch.nn.DataParallel(model, device_ids=range(torch.cuda.device_count()))

    # Setup optimizer, lr_scheduler and loss function
    optimizer_cls = get_optimizer(cfg)
    optimizer_params = {
        k: v
        for k, v in cfg["training"]["optimizer"].items() if k != "name"
    }

    optimizer = optimizer_cls(model.parameters(), **optimizer_params)
    logger.info("Using optimizer {}".format(optimizer))

    scheduler = get_scheduler(optimizer, cfg["training"]["lr_schedule"])

    loss_fn = get_loss_function(cfg)
    logger.info("Using loss {}".format(loss_fn))

    start_iter = 0
    if cfg["training"]["resume"] is not None:
        if os.path.isfile(cfg["training"]["resume"]):
            logger.info(
                "Loading model and optimizer from checkpoint '{}'".format(
                    cfg["training"]["resume"]))
            checkpoint = torch.load(cfg["training"]["resume"])

            if not args.load_weight_only:
                model = DataParallel_withLoss(model, loss_fn)
                model.load_state_dict(checkpoint["model_state"])
                if not args.not_load_optimizer:
                    optimizer.load_state_dict(checkpoint["optimizer_state"])

                # The scheduler state is deliberately not restored and the
                # iteration counter is reset:
                # checkpoint["scheduler_state"]['last_epoch'] = -1
                # scheduler.load_state_dict(checkpoint["scheduler_state"])
                # start_iter = checkpoint["epoch"]
                start_iter = 0
                logger.info("Loaded checkpoint '{}' (iter {})".format(
                    cfg["training"]["resume"], checkpoint["epoch"]))
            else:
                pretrained_dict = convert_state_dict(checkpoint["model_state"])
                model_dict = model.state_dict()
                # 1. filter out unnecessary keys
                pretrained_dict = {
                    k: v
                    for k, v in pretrained_dict.items() if k in model_dict
                }
                # 2. overwrite entries in the existing state dict
                model_dict.update(pretrained_dict)
                # 3. load the new state dict
                model.load_state_dict(model_dict)
                model = DataParallel_withLoss(model, loss_fn)
                logger.info(
                    "Loaded checkpoint '{}' (iter unknown, from pretrained icnet model)"
                    .format(cfg["training"]["resume"]))

        else:
            logger.info("No checkpoint found at '{}'".format(
                cfg["training"]["resume"]))

    val_loss_meter = averageMeter()
    time_meter = averageMeter()

    best_iou = -100.0
    i = start_iter
    flag = True

    while i <= cfg["training"]["train_iters"] and flag:
        for (images, labels, inst_labels) in trainloader:

            start_ts = time.time()
            scheduler.step()
            model.train()
            images = images.to(device)
            labels = labels.to(device)
            inst_labels = inst_labels.to(device)
            optimizer.zero_grad()

            loss, _, aux_info = model(labels,
                                      inst_labels,
                                      images,
                                      return_aux_info=True)
            loss = loss.sum()
            loss_sem = aux_info[0].sum()
            loss_inst = aux_info[1].sum()

            # loss = loss_fn(input=outputs, target=labels)

            loss.backward()
            optimizer.step()

            time_meter.update(time.time() - start_ts)

            if (i + 1) % cfg["training"]["print_interval"] == 0:
                fmt_str = "Iter [{:d}/{:d}]  Loss: {:.4f} (Sem:{:.4f}/Inst:{:.4f})  LR:{:.5f}  Time/Image: {:.4f}"
                print_str = fmt_str.format(
                    i + 1,
                    cfg["training"]["train_iters"],
                    loss.item(),
                    loss_sem.item(),
                    loss_inst.item(),
                    scheduler.get_lr()[0],
                    time_meter.avg / cfg["training"]["batch_size"],
                )

                # print(print_str)
                logger.info(print_str)
                writer.add_scalar("loss/train_loss", loss.item(), i + 1)
                time_meter.reset()

            if (i + 1) % cfg["training"]["val_interval"] == 0 or (
                    i + 1) == cfg["training"]["train_iters"]:

                model.eval()

                with torch.no_grad():
                    for i_val, (images_val, labels_val,
                                inst_labels_val) in tqdm(enumerate(valloader)):
                        images_val = images_val.to(device)
                        labels_val = labels_val.to(device)
                        inst_labels_val = inst_labels_val.to(device)
                        # outputs = model(images_val)
                        # val_loss = loss_fn(input=outputs, target=labels_val)
                        val_loss, (outputs, outputs_inst) = model(
                            labels_val, inst_labels_val, images_val)
                        val_loss = val_loss.sum()

                        pred = outputs.data.max(1)[1].cpu().numpy()
                        gt = labels_val.data.cpu().numpy()

                        running_metrics_val.update(gt, pred)
                        val_loss_meter.update(val_loss.item())

                writer.add_scalar("loss/val_loss", val_loss_meter.avg, i + 1)
                logger.info("Iter %d Loss: %.4f" % (i + 1, val_loss_meter.avg))

                score, class_iou = running_metrics_val.get_scores()
                for k, v in score.items():
                    print(k, v)
                    logger.info("{}: {}".format(k, v))
                    writer.add_scalar("val_metrics/{}".format(k), v, i + 1)

                for k, v in class_iou.items():
                    logger.info("{}: {}".format(k, v))
                    writer.add_scalar("val_metrics/cls_{}".format(k), v, i + 1)

                val_loss_meter.reset()
                running_metrics_val.reset()

                if score["Mean IoU : \t"] >= best_iou:
                    best_iou = score["Mean IoU : \t"]
                    state = {
                        "epoch": i + 1,
                        "model_state": model.state_dict(),
                        "optimizer_state": optimizer.state_dict(),
                        "scheduler_state": scheduler.state_dict(),
                        "best_iou": best_iou,
                    }
                    save_path = os.path.join(
                        writer.file_writer.get_logdir(),
                        "{}_{}_best_model.pkl".format(cfg["model"]["arch"],
                                                      cfg["data"]["dataset"]),
                    )
                    torch.save(state, save_path)

            if (i + 1) % cfg["training"]["save_interval"] == 0 or (
                    i + 1) == cfg["training"]["train_iters"]:
                state = {
                    "epoch": i + 1,
                    "model_state": model.state_dict(),
                    "optimizer_state": optimizer.state_dict(),
                    "scheduler_state": scheduler.state_dict(),
                    "best_iou": best_iou,
                }
                save_path = os.path.join(
                    writer.file_writer.get_logdir(),
                    "{}_{}_{:05d}_model.pkl".format(cfg["model"]["arch"],
                                                    cfg["data"]["dataset"],
                                                    i + 1),
                )
                torch.save(state, save_path)

            if (i + 1) == cfg["training"]["train_iters"]:
                flag = False
                break
            i += 1
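
convert_state_dict, used when loading the pretrained weights above, typically just strips the "module." prefix that nn.DataParallel prepends to every parameter name. A sketch matching the usual pytorch-semseg helper:

from collections import OrderedDict

def convert_state_dict(state_dict):
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        # "module.layer1.weight" -> "layer1.weight"
        new_state_dict[k[7:] if k.startswith("module.") else k] = v
    return new_state_dict
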
Example #6
def validate(cfg, args):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    if args.out_dir != "":
        if not os.path.exists(args.out_dir):
            os.mkdir(args.out_dir)
        if not os.path.exists(args.out_dir+'hmaps_bg'):
            os.mkdir(args.out_dir+'hmaps_bg')
        if not os.path.exists(args.out_dir+'hmaps_fg'):
            os.mkdir(args.out_dir+'hmaps_fg')
        if not os.path.exists(args.out_dir+'pred'):
            os.mkdir(args.out_dir+'pred')
        if not os.path.exists(args.out_dir+'gt'):
            os.mkdir(args.out_dir+'gt')
        if not os.path.exists(args.out_dir+'qry_images'):
            os.mkdir(args.out_dir+'qry_images')
        if not os.path.exists(args.out_dir+'sprt_images'):
            os.mkdir(args.out_dir+'sprt_images')
        if not os.path.exists(args.out_dir+'sprt_gt'):
            os.mkdir(args.out_dir+'sprt_gt')

    if args.fold != -1:
        cfg['data']['fold'] = args.fold

    fold = cfg['data']['fold']

    # Setup Dataloader
    data_loader = get_loader(cfg['data']['dataset'])
    data_path = cfg['data']['path']

    loader = data_loader(
        data_path,
        split=cfg['data']['val_split'],
        is_transform=True,
        img_size=[cfg['data']['img_rows'],
                  cfg['data']['img_cols']],
        n_classes=cfg['data']['n_classes'],
        fold=cfg['data']['fold'],
        binary=args.binary,
        k_shot=cfg['data']['k_shot']
    )

    n_classes = loader.n_classes

    valloader = data.DataLoader(loader,
                                batch_size=cfg['training']['batch_size'],
                                num_workers=0)
    if args.binary:
        running_metrics = runningScore(2)
        fp_list = {}
        tp_list = {}
        fn_list = {}
    else:
        running_metrics = runningScore(n_classes + 1)  # +1 for the novel class added each episode

    # Setup Model
    model = get_model(cfg['model'], n_classes).to(device)
    state = convert_state_dict(torch.load(args.model_path)["model_state"])
    model.load_state_dict(state)
    model.to(device)
    model.freeze_all_except_classifiers()

    # Setup optimizer, lr_scheduler and loss function
    optimizer_cls = get_optimizer(cfg)
    optimizer_params = {k:v for k, v in cfg['training']['optimizer'].items()
                        if k != 'name'}
    model.save_original_weights()

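    # presumably the blend factor applied when imprinting new classifier weights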
    alpha = 0.14139
    for i, (sprt_images, sprt_labels, qry_images, qry_labels,
            original_sprt_images, original_qry_images, cls_ind) in enumerate(valloader):

        cls_ind = int(cls_ind)
        print('Starting iteration ', i)
        start_time = timeit.default_timer()
        if args.out_dir != "":
            save_images(original_sprt_images, sprt_labels,
                        original_qry_images, i, args.out_dir)

        for si in range(len(sprt_images)):
            sprt_images[si] = sprt_images[si].to(device)
            sprt_labels[si] = sprt_labels[si].to(device)
        qry_images = qry_images.to(device)

        # 1- Extract embedding and add the imprinted weights
        if args.iterations_imp > 0:
            model.iterative_imprinting(sprt_images, qry_images, sprt_labels,
                                       alpha=alpha, itr=args.iterations_imp)
        else:
            model.imprint(sprt_images, sprt_labels, alpha=alpha, random=args.rand)

        optimizer = optimizer_cls(model.parameters(), **optimizer_params)
        scheduler = get_scheduler(optimizer, cfg['training']['lr_schedule'])
        loss_fn = get_loss_function(cfg)
        print('Finetuning')
        for j in range(cfg['training']['train_iters']):
            for b in range(len(sprt_images)):
                torch.cuda.empty_cache()
                model.train()
                optimizer.zero_grad()

                outputs = model(sprt_images[b])
                loss = loss_fn(input=outputs, target=sprt_labels[b])
                loss.backward()
                optimizer.step()
                scheduler.step()  # step the scheduler after the optimizer (PyTorch >= 1.1)

                fmt_str = "Iter [{:d}/{:d}]  Loss: {:.4f}"
                print_str = fmt_str.format(j,
                                       cfg['training']['train_iters'],
                                       loss.item())
                print(print_str)


        # 2- Infer on the query image
        model.eval()
        with torch.no_grad():
            outputs = model(qry_images)
            pred = outputs.data.max(1)[1].cpu().numpy()

        # Reverse the last imprinting (Few shot setting only not Continual Learning setup yet)
        model.reverse_imprinting()

        gt = qry_labels.numpy()
        if args.binary:
            gt, pred = post_process(gt, pred)
            if args.binary == 1:
                tp, fp, fn = running_metrics.update_binary_oslsm(gt, pred)
                tp_list[cls_ind] = tp_list.get(cls_ind, 0) + tp
                fp_list[cls_ind] = fp_list.get(cls_ind, 0) + fp
                fn_list[cls_ind] = fn_list.get(cls_ind, 0) + fn
            else:
                running_metrics.update(gt, pred)
        else:
            running_metrics.update(gt, pred)

        if args.out_dir != "":
            if args.binary:
                save_vis(outputs, pred, gt, i, args.out_dir, fg_class=1)
            else:
                save_vis(outputs, pred, gt, i, args.out_dir)

    if args.binary:
        if args.binary == 1:
            iou_list = [tp_list[ic]/float(max(tp_list[ic] + fp_list[ic] + fn_list[ic],1)) \
                         for ic in tp_list.keys()]
            print("Binary Mean IoU ", np.mean(iou_list))
        else:
            score, class_iou = running_metrics.get_scores()
            for k, v in score.items():
                print(k, v)
    else:
        score, class_iou = running_metrics.get_scores()
        for k, v in score.items():
            print(k, v)
        val_nclasses = model.n_classes + 1
        for i in range(val_nclasses):
            print(i, class_iou[i])
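
The binary mean IoU printed above accumulates per-class true positives, false positives and false negatives across episodes and applies IoU = tp / (tp + fp + fn). With made-up counts tp=80, fp=10, fn=10 for one class, for instance, that class contributes 80 / 100 = 0.8, and the reported score is the mean over classes.
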
Example #7
def train(cfg, writer, logger, args):

    # Setup seeds
    torch.manual_seed(cfg.get('seed', 1337))
    torch.cuda.manual_seed(cfg.get('seed', 1337))
    np.random.seed(cfg.get('seed', 1337))
    random.seed(cfg.get('seed', 1337))

    # Setup device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # device = torch.device('cuda')

    # Setup Augmentations
    if cfg['data']['dataset'] in ['cityscapes']:
        augmentations = cfg['training'].get(
            'augmentations', {
                'brightness': 63. / 255.,
                'saturation': 0.5,
                'contrast': 0.8,
                'hflip': 0.5,
                'rotate': 10,
                'rscalecropsquare': 713,
            })
    else:
        # other datasets: fall back to whatever the config specifies
        augmentations = cfg['training'].get('augmentations', None)
    data_aug = get_composed_augmentations(augmentations)

    # Setup Dataloader
    data_loader = get_loader(cfg['data']['dataset'])
    data_path = cfg['data']['path']

    t_loader = data_loader(data_path,
                           is_transform=True,
                           split=cfg['data']['train_split'],
                           img_size=(cfg['data']['img_rows'],
                                     cfg['data']['img_cols']),
                           augmentations=data_aug)

    v_loader = data_loader(
        data_path,
        is_transform=True,
        split=cfg['data']['val_split'],
        img_size=(cfg['data']['img_rows'], cfg['data']['img_cols']),
    )

    n_classes = t_loader.n_classes
    trainloader = data.DataLoader(t_loader,
                                  batch_size=cfg['training']['batch_size'],
                                  num_workers=cfg['training']['n_workers'],
                                  shuffle=True)

    valloader = data.DataLoader(v_loader,
                                batch_size=cfg['training']['batch_size'],
                                num_workers=cfg['training']['n_workers'])

    # Setup Metrics
    running_metrics_val = runningScore(n_classes)

    # Setup Model
    model = get_model(cfg['model'], n_classes, args).to(device)
    model.apply(weights_init)
    print('sleep for 5 seconds')
    time.sleep(5)
    model = torch.nn.DataParallel(model,
                                  device_ids=range(torch.cuda.device_count()))
    # model = torch.nn.DataParallel(model, device_ids=(0, 1))
    print(model.device_ids)

    # Setup optimizer, lr_scheduler and loss function
    optimizer_cls = get_optimizer(cfg)
    optimizer_params = {
        k: v
        for k, v in cfg['training']['optimizer'].items() if k != 'name'
    }

    optimizer = optimizer_cls(model.parameters(), **optimizer_params)
    logger.info("Using optimizer {}".format(optimizer))

    scheduler = get_scheduler(optimizer, cfg['training']['lr_schedule'])

    loss_fn = get_loss_function(cfg)
    logger.info("Using loss {}".format(loss_fn))
    if 'multi_step' in cfg['training']['loss']['name']:
        my_loss_fn = loss_fn(
            scale_weight=cfg['training']['loss']['scale_weight'],
            n_inp=2,
            weight=None,
            reduction='sum',
            bkargs=args)
    else:
        my_loss_fn = loss_fn(weight=None, reduction='sum', bkargs=args)

    start_iter = 0
    if cfg['training']['resume'] is not None:
        if os.path.isfile(cfg['training']['resume']):
            logger.info(
                "Loading model and optimizer from checkpoint '{}'".format(
                    cfg['training']['resume']))
            checkpoint = torch.load(cfg['training']['resume'])
            model.load_state_dict(checkpoint["model_state"])
            optimizer.load_state_dict(checkpoint["optimizer_state"])
            scheduler.load_state_dict(checkpoint["scheduler_state"])
            start_iter = checkpoint["epoch"]
            logger.info("Loaded checkpoint '{}' (iter {})".format(
                cfg['training']['resume'], checkpoint["epoch"]))
        else:
            logger.info("No checkpoint found at '{}'".format(
                cfg['training']['resume']))

    val_loss_meter = averageMeter()
    time_meter = averageMeter()

    best_iou = -100.0
    i = start_iter
    flag = True

    while i <= cfg['training']['train_iters'] and flag:
        for (images, labels) in trainloader:
            i += 1
            start_ts = time.time()
            model.train()
            images = images.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            outputs = model(images)

            loss = my_loss_fn(myinput=outputs, target=labels)

            loss.backward()
            optimizer.step()
            scheduler.step()  # step the scheduler after the optimizer (PyTorch >= 1.1)

            # gpu_profile(frame=sys._getframe(), event='line', arg=None)

            time_meter.update(time.time() - start_ts)

            if (i + 1) % cfg['training']['print_interval'] == 0:
                fmt_str = "Iter [{:d}/{:d}]  Loss: {:.4f}  Time/Image: {:.4f}"
                print_str = fmt_str.format(
                    i + 1, cfg['training']['train_iters'], loss.item(),
                    time_meter.avg / cfg['training']['batch_size'])

                print(print_str)
                logger.info(print_str)
                writer.add_scalar('loss/train_loss', loss.item(), i + 1)
                time_meter.reset()

            if (i + 1) % cfg['training']['val_interval'] == 0 or \
               (i + 1) == cfg['training']['train_iters']:
                model.eval()
                with torch.no_grad():
                    for i_val, (images_val,
                                labels_val) in tqdm(enumerate(valloader)):
                        images_val = images_val.to(device)
                        labels_val = labels_val.to(device)

                        outputs = model(images_val)
                        val_loss = my_loss_fn(myinput=outputs,
                                              target=labels_val)

                        pred = outputs.data.max(1)[1].cpu().numpy()
                        gt = labels_val.data.cpu().numpy()

                        running_metrics_val.update(gt, pred)
                        val_loss_meter.update(val_loss.item())

                writer.add_scalar('loss/val_loss', val_loss_meter.avg, i + 1)
                logger.info("Iter %d Loss: %.4f" % (i + 1, val_loss_meter.avg))

                score, class_iou = running_metrics_val.get_scores()
                for k, v in score.items():
                    print(k, v)
                    logger.info('{}: {}'.format(k, v))
                    writer.add_scalar('val_metrics/{}'.format(k), v, i + 1)

                for k, v in class_iou.items():
                    logger.info('{}: {}'.format(k, v))
                    writer.add_scalar('val_metrics/cls_{}'.format(k), v, i + 1)

                val_loss_meter.reset()
                running_metrics_val.reset()

                if score["Mean IoU : \t"] >= best_iou:
                    best_iou = score["Mean IoU : \t"]
                    state = {
                        "epoch": i + 1,
                        "model_state": model.state_dict(),
                        "optimizer_state": optimizer.state_dict(),
                        "scheduler_state": scheduler.state_dict(),
                        "best_iou": best_iou,
                    }
                    save_path = os.path.join(
                        writer.file_writer.get_logdir(),
                        "{}_{}_best_model.pkl".format(cfg['model']['arch'],
                                                      cfg['data']['dataset']))
                    torch.save(state, save_path)

            if (i + 1) == cfg['training']['train_iters']:
                flag = False
                break
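
Note: averageMeter is not defined in any of these listings. A minimal sketch consistent with the .update() / .avg / .val / .reset() calls above (an assumption, not the repos' actual implementation) might look like this:

class averageMeter(object):
    """Track the most recent value and a running average (assumed interface)."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0.0    # last value passed to update()
        self.avg = 0.0    # running average over all updates
        self.sum = 0.0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count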
Example #8
File: train.py  Project: HMellor/4YP_code
def train(cfg, writer, logger_old, args):

    # Setup seeds
    torch.manual_seed(cfg.get('seed', 1337))
    torch.cuda.manual_seed(cfg.get('seed', 1337))
    np.random.seed(cfg.get('seed', 1337))
    random.seed(cfg.get('seed', 1337))

    # Setup device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Setup Augmentations
    augmentations = cfg['training'].get('augmentations', None)
    data_aug = get_composed_augmentations(augmentations)

    # Setup Dataloader
    data_loader = get_loader(cfg['data']['dataset'])
    data_path = cfg['data']['path']

    if isinstance(cfg['training']['loss']['superpixels'], int):
        use_superpixels = True
        cfg['data']['train_split'] = 'train_super'
        cfg['data']['val_split'] = 'val_super'
        setup_superpixels(cfg['training']['loss']['superpixels'])
    elif cfg['training']['loss']['superpixels'] is not None:
        raise Exception(
            "cfg['training']['loss']['superpixels'] is of the wrong type")
    else:
        use_superpixels = False

    t_loader = data_loader(data_path,
                           is_transform=True,
                           split=cfg['data']['train_split'],
                           superpixels=cfg['training']['loss']['superpixels'],
                           img_size=(cfg['data']['img_rows'],
                                     cfg['data']['img_cols']),
                           augmentations=data_aug)

    v_loader = data_loader(
        data_path,
        is_transform=True,
        split=cfg['data']['val_split'],
        superpixels=cfg['training']['loss']['superpixels'],
        img_size=(cfg['data']['img_rows'], cfg['data']['img_cols']),
    )

    n_classes = t_loader.n_classes
    trainloader = data.DataLoader(t_loader,
                                  batch_size=cfg['training']['batch_size'],
                                  num_workers=cfg['training']['n_workers'],
                                  shuffle=True)

    valloader = data.DataLoader(v_loader,
                                batch_size=cfg['training']['batch_size'],
                                num_workers=cfg['training']['n_workers'])

    # Setup Metrics
    running_metrics_val = runningScore(n_classes)
    running_metrics_train = runningScore(n_classes)

    # Setup Model
    model = get_model(cfg['model'], n_classes).to(device)

    model = torch.nn.DataParallel(model,
                                  device_ids=range(torch.cuda.device_count()))

    # Setup optimizer, lr_scheduler and loss function
    optimizer_cls = get_optimizer(cfg)
    optimizer_params = {
        k: v
        for k, v in cfg['training']['optimizer'].items() if k != 'name'
    }

    optimizer = optimizer_cls(model.parameters(), **optimizer_params)
    logger_old.info("Using optimizer {}".format(optimizer))

    scheduler = get_scheduler(optimizer, cfg['training']['lr_schedule'])

    loss_fn = get_loss_function(cfg)
    logger_old.info("Using loss {}".format(loss_fn))

    start_iter = 0
    if cfg['training']['resume'] is not None:
        if os.path.isfile(cfg['training']['resume']):
            logger_old.info(
                "Loading model and optimizer from checkpoint '{}'".format(
                    cfg['training']['resume']))
            checkpoint = torch.load(cfg['training']['resume'])
            model.load_state_dict(checkpoint["model_state"])
            optimizer.load_state_dict(checkpoint["optimizer_state"])
            scheduler.load_state_dict(checkpoint["scheduler_state"])
            start_iter = checkpoint["epoch"]
            logger_old.info("Loaded checkpoint '{}' (iter {})".format(
                cfg['training']['resume'], checkpoint["epoch"]))
        else:
            logger_old.info("No checkpoint found at '{}'".format(
                cfg['training']['resume']))

    val_loss_meter = averageMeter()
    train_loss_meter = averageMeter()
    time_meter = averageMeter()

    train_len = t_loader.train_len
    val_static = 0
    best_iou = -100.0
    i = start_iter
    j = 0
    flag = True

    # Prepare logging
    xp_name = cfg['model']['arch'] + '_' + \
        cfg['training']['loss']['name'] + '_' + args.name
    xp = logger.Experiment(xp_name,
                           use_visdom=True,
                           visdom_opts={
                               'server': 'http://localhost',
                               'port': 8098
                           },
                           time_indexing=False,
                           xlabel='Epoch')
    # log the hyperparameters of the experiment
    xp.log_config(flatten(cfg))
    # create parent metric for training metrics (easier interface)
    xp.ParentWrapper(tag='train',
                     name='parent',
                     children=(xp.AvgMetric(name="loss"),
                               xp.AvgMetric(name='acc'),
                               xp.AvgMetric(name='acccls'),
                               xp.AvgMetric(name='fwavacc'),
                               xp.AvgMetric(name='meaniu')))
    xp.ParentWrapper(tag='val',
                     name='parent',
                     children=(xp.AvgMetric(name="loss"),
                               xp.AvgMetric(name='acc'),
                               xp.AvgMetric(name='acccls'),
                               xp.AvgMetric(name='fwavacc'),
                               xp.AvgMetric(name='meaniu')))
    best_loss = xp.BestMetric(tag='val-best', name='loss', mode='min')
    best_acc = xp.BestMetric(tag='val-best', name='acc')
    best_acccls = xp.BestMetric(tag='val-best', name='acccls')
    best_fwavacc = xp.BestMetric(tag='val-best', name='fwavacc')
    best_meaniu = xp.BestMetric(tag='val-best', name='meaniu')

    xp.plotter.set_win_opts(name="loss", opts={'title': 'Loss'})
    xp.plotter.set_win_opts(name="acc", opts={'title': 'Micro-Average'})
    xp.plotter.set_win_opts(name="acccls", opts={'title': 'Macro-Average'})
    xp.plotter.set_win_opts(name="fwavacc", opts={'title': 'FreqW Accuracy'})
    xp.plotter.set_win_opts(name="meaniu", opts={'title': 'Mean IoU'})

    it_per_step = cfg['training']['acc_batch_size']
    eff_batch_size = cfg['training']['batch_size'] * it_per_step
    while i <= train_len * (cfg['training']['epochs']) and flag:
        for (images, labels, labels_s, masks) in trainloader:
            i += 1
            j += 1
            start_ts = time.time()
            scheduler.step()
            model.train()
            images = images.to(device)
            labels = labels.to(device)
            labels_s = labels_s.to(device)
            masks = masks.to(device)

            outputs = model(images)
            if use_superpixels:
                outputs_s, labels_s, sizes = convert_to_superpixels(
                    outputs, labels_s, masks)
                loss = loss_fn(input=outputs_s, target=labels_s, size=sizes)
                outputs = convert_to_pixels(outputs_s, outputs, masks)
            else:
                loss = loss_fn(input=outputs, target=labels)

            # accumulate train metrics during train
            pred = outputs.data.max(1)[1].cpu().numpy()
            gt = labels.data.cpu().numpy()
            running_metrics_train.update(gt, pred)
            train_loss_meter.update(loss.item())

            if args.evaluate:
                decoded = t_loader.decode_segmap(np.squeeze(pred, axis=0))
                misc.imsave("./{}.png".format(i), decoded)
                image_save = np.transpose(
                    np.squeeze(images.data.cpu().numpy(), axis=0), (1, 2, 0))
                misc.imsave("./{}.jpg".format(i), image_save)

            # accumulate gradients based on the accumulation batch size
            # zero the gradients at the start of each accumulation window
            if i % it_per_step == 1 or it_per_step == 1:
                optimizer.zero_grad()

            # rescale so the accumulated gradient matches one effective batch
            grad_rescaling = torch.tensor(1. / it_per_step).type_as(loss)
            loss.backward(grad_rescaling)
            # step once the accumulation window is complete
            if (i + 1) % it_per_step == 1 or it_per_step == 1:
                optimizer.step()
                optimizer.zero_grad()

            time_meter.update(time.time() - start_ts)
            # training logs
            if (j + 1) % (cfg['training']['print_interval'] *
                          it_per_step) == 0:
                fmt_str = "Epoch [{}/{}] Iter [{}/{:d}] Loss: {:.4f}  Time/Image: {:.4f}"
                total_iter = int(train_len / eff_batch_size)
                total_epoch = int(cfg['training']['epochs'])
                current_epoch = ceil((i + 1) / train_len)
                current_iter = int((j + 1) / it_per_step)
                print_str = fmt_str.format(
                    current_epoch, total_epoch, current_iter, total_iter,
                    loss.item(),
                    time_meter.avg / cfg['training']['batch_size'])

                print(print_str)
                logger_old.info(print_str)
                writer.add_scalar('loss/train_loss', loss.item(), i + 1)
                time_meter.reset()
            # end of epoch evaluation
            if (i + 1) % train_len == 0 or \
               (i + 1) == train_len * (cfg['training']['epochs']):
                optimizer.step()
                optimizer.zero_grad()
                model.eval()
                with torch.no_grad():
                    for i_val, (images_val, labels_val, labels_val_s,
                                masks_val) in tqdm(enumerate(valloader)):
                        images_val = images_val.to(device)
                        labels_val = labels_val.to(device)
                        labels_val_s = labels_val_s.to(device)
                        masks_val = masks_val.to(device)

                        outputs = model(images_val)
                        if use_superpixels:
                            outputs_s, labels_val_s, sizes_val = convert_to_superpixels(
                                outputs, labels_val_s, masks_val)
                            val_loss = loss_fn(input=outputs_s,
                                               target=labels_val_s,
                                               size=sizes_val)
                            outputs = convert_to_pixels(
                                outputs_s, outputs, masks_val)
                        else:
                            val_loss = loss_fn(input=outputs,
                                               target=labels_val)
                        pred = outputs.data.max(1)[1].cpu().numpy()
                        gt = labels_val.data.cpu().numpy()

                        running_metrics_val.update(gt, pred)
                        val_loss_meter.update(val_loss.item())

                writer.add_scalar('loss/val_loss', val_loss_meter.avg, i + 1)
                writer.add_scalar('loss/train_loss', train_loss_meter.avg,
                                  i + 1)
                logger_old.info("Epoch %d Val Loss: %.4f" % (int(
                    (i + 1) / train_len), val_loss_meter.avg))
                logger_old.info("Epoch %d Train Loss: %.4f" % (int(
                    (i + 1) / train_len), train_loss_meter.avg))

                score, class_iou = running_metrics_train.get_scores()
                print("Training metrics:")
                for k, v in score.items():
                    print(k, v)
                    logger_old.info('{}: {}'.format(k, v))
                    writer.add_scalar('train_metrics/{}'.format(k), v, i + 1)

                for k, v in class_iou.items():
                    logger_old.info('{}: {}'.format(k, v))
                    writer.add_scalar('train_metrics/cls_{}'.format(k), v,
                                      i + 1)

                xp.Parent_Train.update(loss=train_loss_meter.avg,
                                       acc=score['Overall Acc: \t'],
                                       acccls=score['Mean Acc : \t'],
                                       fwavacc=score['FreqW Acc : \t'],
                                       meaniu=score['Mean IoU : \t'])

                score, class_iou = running_metrics_val.get_scores()
                print("Validation metrics:")
                for k, v in score.items():
                    print(k, v)
                    logger_old.info('{}: {}'.format(k, v))
                    writer.add_scalar('val_metrics/{}'.format(k), v, i + 1)

                for k, v in class_iou.items():
                    logger_old.info('{}: {}'.format(k, v))
                    writer.add_scalar('val_metrics/cls_{}'.format(k), v, i + 1)

                xp.Parent_Val.update(loss=val_loss_meter.avg,
                                     acc=score['Overall Acc: \t'],
                                     acccls=score['Mean Acc : \t'],
                                     fwavacc=score['FreqW Acc : \t'],
                                     meaniu=score['Mean IoU : \t'])

                xp.Parent_Val.log_and_reset()
                xp.Parent_Train.log_and_reset()
                best_loss.update(xp.loss_val).log()
                best_acc.update(xp.acc_val).log()
                best_acccls.update(xp.acccls_val).log()
                best_fwavacc.update(xp.fwavacc_val).log()
                best_meaniu.update(xp.meaniu_val).log()

                visdir = os.path.join('runs', cfg['training']['loss']['name'],
                                      args.name, 'plots.json')
                xp.to_json(visdir)

                val_loss_meter.reset()
                train_loss_meter.reset()
                running_metrics_val.reset()
                running_metrics_train.reset()
                j = 0

                if score["Mean IoU : \t"] >= best_iou:
                    val_static = 0
                    best_iou = score["Mean IoU : \t"]
                    state = {
                        "epoch": i + 1,
                        "model_state": model.state_dict(),
                        "optimizer_state": optimizer.state_dict(),
                        "scheduler_state": scheduler.state_dict(),
                        "best_iou": best_iou,
                    }
                    save_path = os.path.join(
                        writer.file_writer.get_logdir(),
                        "{}_{}_best_model.pkl".format(cfg['model']['arch'],
                                                      cfg['data']['dataset']))
                    torch.save(state, save_path)
                else:
                    val_static += 1

            if (i + 1) == train_len * (
                    cfg['training']['epochs']) or val_static == 10:
                flag = False
                break
    return best_iou
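
The acc_batch_size logic above simulates a larger effective batch: gradients from it_per_step consecutive forward passes are summed, rescaled by 1/it_per_step, and applied in a single optimizer step. A minimal, self-contained sketch of the same pattern (the function and its arguments are illustrative, not from the repo; dividing the loss before backward() is equivalent to the grad_rescaling tensor used above):

def accumulate_step(model, optimizer, loss_fn, batches, it_per_step):
    """Run one effective optimizer step over it_per_step mini-batches."""
    optimizer.zero_grad()
    for images, labels in batches:       # exactly it_per_step mini-batches
        loss = loss_fn(model(images), labels)
        (loss / it_per_step).backward()  # rescale so the sum matches one large batch
    optimizer.step()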
Example #9
def train(cfg, writer, logger):
    # Setup dataset split before setting up the seed for random
    data_split_info = init_data_split(cfg['data']['path'],
                                      cfg['data'].get('split_ratio', 0),
                                      cfg['data'].get('compound', False))  # fly Janelia dataset

    # Setup seeds
    torch.manual_seed(cfg.get('seed', 1337))
    torch.cuda.manual_seed(cfg.get('seed', 1337))
    np.random.seed(cfg.get('seed', 1337))
    random.seed(cfg.get('seed', 1337))

    # Setup device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Setup Cross Entropy Weight
    if cfg['training']['loss']['name'] != 'regression_l1':
        weight = prep_class_val_weights(cfg['training']['cross_entropy_ratio'])
    else:
        weight = None
    log('Using loss : {}'.format(cfg['training']['loss']['name']))

    # Setup Augmentations
    augmentations = cfg['training'].get('augmentations', None) # if no augmentation => default None
    data_aug = get_composed_augmentations(augmentations)

    # Setup Dataloader
    data_loader = get_loader(cfg['data']['dataset'])
    data_path = cfg['data']['path']
    patch_size = [cfg['training']['patch_size'], cfg['training']['patch_size']]

    t_loader = data_loader(
        data_path,
        split=cfg['data']['train_split'],
        augmentations=data_aug,
        data_split_info=data_split_info,
        patch_size=patch_size,
        allow_empty_patch=cfg['training'].get('allow_empty_patch', False),
        n_classes=cfg['training'].get('n_classes', 2))

    # v_loader = data_loader(
    #     data_path,
    #     split=cfg['data']['val_split'],
    #     data_split_info=data_split_info,
    #     patch_size=patch_size,
    #     n_classe=cfg['training'].get('n_classes', 1))

    n_classes = t_loader.n_classes
    log('n_classes is: {}'.format(n_classes))
    trainloader = data.DataLoader(t_loader,
                                  batch_size=cfg['training']['batch_size'],
                                  num_workers=cfg['training']['n_workers'],
                                  shuffle=False)

    print('trainloader len: ', len(trainloader))
    # Setup Metrics
    running_metrics_val = runningScore(n_classes) # a confusion matrix is created


    # Setup Model
    model = get_model(cfg['model'], n_classes).to(device)


    model = torch.nn.DataParallel(model, device_ids=range(torch.cuda.device_count()))

    # Setup optimizer, lr_scheduler and loss function
    optimizer_cls = get_optimizer(cfg)
    optimizer_params = {k: v for k, v in cfg['training']['optimizer'].items()
                        if k != 'name'}

    optimizer = optimizer_cls(model.parameters(), **optimizer_params)
    logger.info("Using optimizer {}".format(optimizer))

    scheduler = get_scheduler(optimizer, cfg['training']['lr_schedule'])

    loss_fn = get_loss_function(cfg)
    logger.info("Using loss {}".format(loss_fn))
    softmax_function = nn.Softmax(dim=1)

    # model_count = 0
    min_loss = None
    start_iter = 0
    if cfg['training']['resume'] is not None:
        log('resume saved model')
        if os.path.isfile(cfg['training']['resume']):
            display(
                "Loading model and optimizer from checkpoint '{}'".format(cfg['training']['resume'])
            )
            checkpoint = torch.load(cfg['training']['resume'])
            model.load_state_dict(checkpoint["model_state"])
            optimizer.load_state_dict(checkpoint["optimizer_state"])
            scheduler.load_state_dict(checkpoint["scheduler_state"])
            start_iter = checkpoint["epoch"]
            min_loss = checkpoint["min_loss"]
            display(
                "Loaded checkpoint '{}' (iter {})".format(
                    cfg['training']['resume'], checkpoint["epoch"]
                )
            )
        else:
            display("No checkpoint found at '{}'".format(cfg['training']['resume']))
            log('no saved model found')

    val_loss_meter = averageMeter()
    time_meter = averageMeter()

    if cfg['training']['loss']['name'] == 'dice':
        loss_fn = dice_loss()


    i_train_iter = start_iter

    display('Training from {}th iteration\n'.format(i_train_iter))
    while i_train_iter < cfg['training']['train_iters']:
        i_batch_idx = 0
        train_iter_start_time = time.time()
        averageLoss = 0

        # if cfg['training']['loss']['name'] == 'dice':
        #     loss = dice_loss()

        # training
        for (images, labels) in trainloader:
            start_ts = time.time()
            scheduler.step()
            model.train()

            images = images.to(device)
            labels = labels.to(device)

            # images = images.cuda()
            # labels = labels.cuda()

            soft_loss = -1
            mediate_average_loss = -1

            optimizer.zero_grad()
            outputs = model(images)

            if cfg['training']['loss']['name'] == 'dice':
                loss = loss_fn(outputs, labels)
                # print('loss match: ', loss, loss.item())
                averageLoss += loss.item()
            else:
                hard_loss = loss_fn(input=outputs, target=labels, weight=weight,
                                    size_average=cfg['training']['loss']['size_average'])

                loss = hard_loss

                averageLoss += loss.item()  # use .item() so no autograd graph is retained
            loss.backward()
            optimizer.step()

            time_meter.update(time.time() - start_ts)
            print_per_batch_check = (cfg['training']['print_interval_per_batch']
                                     or i_batch_idx + 1 == len(trainloader))
            if (i_train_iter + 1) % cfg['training']['print_interval'] == 0 and print_per_batch_check:
                fmt_str = "Iter [{:d}/{:d}]  Loss: {:.4f}  Time/Image: {:.4f}"
                print_str = fmt_str.format(i_train_iter + 1,
                                           cfg['training']['train_iters'],
                                           loss.item(),
                                           time_meter.avg / cfg['training']['batch_size'])

                display(print_str)
                writer.add_scalar('loss/train_loss', loss.item(), i_train_iter + 1)
                time_meter.reset()
            i_batch_idx += 1
        time_for_one_iteration = time.time() - train_iter_start_time

        display('EntireTime for {}th training iteration: {}  EntireTime/Image: {}'.format(
            i_train_iter + 1, time_converter(time_for_one_iteration),
            time_converter(time_for_one_iteration /
                           (len(trainloader) * cfg['training']['batch_size']))))
        averageLoss /= (len(trainloader) * cfg['training']['batch_size'])
        # validation
        validation_check = (i_train_iter + 1) % cfg['training']['val_interval'] == 0 or \
                           (i_train_iter + 1) == cfg['training']['train_iters']
        if not validation_check:
            print('no validation check')
        else:

            # This check updates the best-model checkpoint whenever the
            # average training loss improves.
            log('Validation: average loss for current iteration is: {}'.format(averageLoss))
            if min_loss is None:
                min_loss = averageLoss

            if averageLoss <= min_loss:
                min_loss = averageLoss
                state = {
                    "epoch": i_train_iter + 1,
                    "model_state": model.state_dict(),
                    "optimizer_state": optimizer.state_dict(),
                    "scheduler_state": scheduler.state_dict(),
                    "min_loss": min_loss
                }

                # if cfg['training']['cp_save_path'] is None:
                save_path = os.path.join(writer.file_writer.get_logdir(),
                                         "{}_{}_model_best.pkl".format(
                                             cfg['model']['arch'],
                                             cfg['data']['dataset']))
                # else:
                #     save_path = os.path.join(cfg['training']['cp_save_path'],  writer.file_writer.get_logdir(),
                #                              "{}_{}_model_best.pkl".format(
                #                                  cfg['model']['arch'],
                #                                  cfg['data']['dataset']))
                print('save_path is: ' + save_path)

                torch.save(state, save_path)

            # model_count += 1

        i_train_iter += 1
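
dice_loss() above is the repo's own; a common soft-Dice formulation it may correspond to (an assumption, sketched here for logits of shape [B, C, H, W] and integer labels) is:

import torch
import torch.nn.functional as F

def soft_dice_loss(outputs, labels, eps=1e-6):
    # outputs: [B, C, H, W] raw logits; labels: [B, H, W] class indices (long)
    probs = F.softmax(outputs, dim=1)
    one_hot = F.one_hot(labels, num_classes=probs.shape[1])  # [B, H, W, C]
    one_hot = one_hot.permute(0, 3, 1, 2).float()            # [B, C, H, W]
    dims = (0, 2, 3)
    intersection = torch.sum(probs * one_hot, dims)
    cardinality = torch.sum(probs + one_hot, dims)
    dice = (2.0 * intersection + eps) / (cardinality + eps)
    return 1.0 - dice.mean()                                 # lower is better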
Example #10
def train(cfg, writer, logger, run_id):

    # Setup seeds
    torch.manual_seed(cfg.get('seed', 1337))
    torch.cuda.manual_seed(cfg.get('seed', 1337))
    np.random.seed(cfg.get('seed', 1337))
    random.seed(cfg.get('seed', 1337))

    # torch.backends.cudnn.deterministic = True
    # torch.backends.cudnn.benchmark = False

    torch.backends.cudnn.benchmark = True

    os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"
    # Setup device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Setup Augmentations
    augmentations = cfg['training'].get('augmentations', None)
    data_aug = get_composed_augmentations(augmentations)

    # Setup Dataloader
    data_loader = get_loader(cfg['data']['dataset'])
    data_path = cfg['data']['path']

    logger.info("Using dataset: {}".format(data_path))

    t_loader = data_loader(data_path,
                           is_transform=True,
                           split=cfg['data']['train_split'],
                           img_size=(cfg['data']['img_rows'],
                                     cfg['data']['img_cols']),
                           augmentations=data_aug)

    v_loader = data_loader(
        data_path,
        is_transform=True,
        split=cfg['data']['val_split'],
        img_size=(cfg['data']['img_rows'], cfg['data']['img_cols']),
    )

    n_classes = t_loader.n_classes
    trainloader = data.DataLoader(t_loader,
                                  batch_size=cfg['training']['batch_size'],
                                  num_workers=cfg['training']['n_workers'],
                                  shuffle=True)

    valloader = data.DataLoader(v_loader,
                                batch_size=cfg['training']['batch_size'],
                                num_workers=cfg['training']['n_workers'])

    # Setup Metrics
    running_metrics_val = runningScore(n_classes)

    # Setup Model
    # model = get_model(cfg['model'], n_classes).to(device)
    model = get_model(cfg['model'], n_classes)
    logger.info("Using Model: {}".format(cfg['model']['arch']))

    # model=apex.parallel.convert_syncbn_model(model)
    model = model.to(device)

    # a=range(torch.cuda.device_count())
    # model = torch.nn.DataParallel(model, device_ids=range(torch.cuda.device_count()))
    model = torch.nn.DataParallel(model, device_ids=[0, 1])
    # model = encoding.parallel.DataParallelModel(model, device_ids=[0, 1])

    # Setup optimizer, lr_scheduler and loss function
    optimizer_cls = get_optimizer(cfg)
    optimizer_params = {
        k: v
        for k, v in cfg['training']['optimizer'].items() if k != 'name'
    }

    optimizer = optimizer_cls(model.parameters(), **optimizer_params)

    # optimizer = FP16_Optimizer(optimizer, static_loss_scale=128.0)

    logger.info("Using optimizer {}".format(optimizer))

    scheduler = get_scheduler(optimizer, cfg['training']['lr_schedule'])

    # optimizer = FP16_Optimizer(optimizer, static_loss_scale=128.0)

    loss_fn = get_loss_function(cfg)
    # loss_fn== encoding.parallel.DataParallelCriterion(loss_fn, device_ids=[0, 1])
    logger.info("Using loss {}".format(loss_fn))

    start_iter = 0
    if cfg['training']['resume'] is not None:
        if os.path.isfile(cfg['training']['resume']):
            logger.info(
                "Loading model and optimizer from checkpoint '{}'".format(
                    cfg['training']['resume']))
            checkpoint = torch.load(cfg['training']['resume'])
            model.load_state_dict(checkpoint["model_state"])
            optimizer.load_state_dict(checkpoint["optimizer_state"])
            scheduler.load_state_dict(checkpoint["scheduler_state"])
            # start_iter = checkpoint["epoch"]
            logger.info("Loaded checkpoint '{}' (iter {})".format(
                cfg['training']['resume'], checkpoint["epoch"]))
        else:
            logger.info("No checkpoint found at '{}'".format(
                cfg['training']['resume']))

    val_loss_meter = averageMeter()
    time_meter = averageMeter()
    time_meter_val = averageMeter()

    best_iou = -100.0
    i = start_iter
    flag = True

    train_data_len = len(t_loader)
    batch_size = cfg['training']['batch_size']
    epoch = cfg['training']['train_epoch']
    train_iter = int(np.ceil(train_data_len / batch_size) * epoch)

    val_rlt_f1 = []
    val_rlt_OA = []
    best_f1_till_now = 0
    best_OA_till_now = 0

    while i <= train_iter and flag:
        for (images, labels) in trainloader:
            i += 1
            start_ts = time.time()
            scheduler.step()
            model.train()
            images = images.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            outputs = model(images)

            loss = loss_fn(input=outputs, target=labels)

            loss.backward()
            # optimizer.backward(loss)

            optimizer.step()

            time_meter.update(time.time() - start_ts)

            ### added by Sprit
            time_meter_val.update(time.time() - start_ts)

            if (i + 1) % cfg['training']['print_interval'] == 0:
                fmt_str = "Iter [{:d}/{:d}]  Loss: {:.4f}  Time/Image: {:.4f}"
                print_str = fmt_str.format(
                    i + 1, train_iter, loss.item(),
                    time_meter.avg / cfg['training']['batch_size'])

                print(print_str)
                logger.info(print_str)
                writer.add_scalar('loss/train_loss', loss.item(), i + 1)
                time_meter.reset()

            if (i + 1) % cfg['training']['val_interval'] == 0 or \
               (i + 1) == train_iter:
                model.eval()
                with torch.no_grad():
                    for i_val, (images_val,
                                labels_val) in tqdm(enumerate(valloader)):
                        images_val = images_val.to(device)
                        labels_val = labels_val.to(device)

                        outputs = model(images_val)
                        # val_loss = loss_fn(input=outputs, target=labels_val)

                        pred = outputs.data.max(1)[1].cpu().numpy()
                        gt = labels_val.data.cpu().numpy()

                        running_metrics_val.update(gt, pred)
                        # val_loss_meter.update(val_loss.item())

                # writer.add_scalar('loss/val_loss', val_loss_meter.avg, i+1)
                # logger.info("Iter %d Loss: %.4f" % (i + 1, val_loss_meter.avg))

                score, class_iou = running_metrics_val.get_scores()

                for k, v in score.items():
                    print(k, v)
                    logger.info('{}: {}'.format(k, v))
                    # writer.add_scalar('val_metrics/{}'.format(k), v, i+1)

                for k, v in class_iou.items():
                    logger.info('{}: {}'.format(k, v))
                    # writer.add_scalar('val_metrics/cls_{}'.format(k), v, i+1)

                # val_loss_meter.reset()
                running_metrics_val.reset()

                ### added by Sprit
                avg_f1 = score["Mean F1 : \t"]
                OA = score["Overall Acc: \t"]
                val_rlt_f1.append(avg_f1)
                val_rlt_OA.append(score["Overall Acc: \t"])

                if avg_f1 >= best_f1_till_now:
                    best_f1_till_now = avg_f1
                    correspond_OA = score["Overall Acc: \t"]
                    best_f1_epoch_till_now = i + 1
                print("\nBest F1 till now = ", best_f1_till_now)
                print("Correspond OA= ", correspond_OA)
                print("Best F1 Iter till now= ", best_f1_epoch_till_now)

                if OA >= best_OA_till_now:
                    best_OA_till_now = OA
                    correspond_f1 = score["Mean F1 : \t"]
                    # correspond_acc=score["Overall Acc: \t"]
                    best_OA_epoch_till_now = i + 1

                    state = {
                        "epoch": i + 1,
                        "model_state": model.state_dict(),
                        "optimizer_state": optimizer.state_dict(),
                        "scheduler_state": scheduler.state_dict(),
                        "best_OA": best_OA_till_now,
                    }
                    save_path = os.path.join(
                        writer.file_writer.get_logdir(),
                        "{}_{}_best_model.pkl".format(cfg['model']['arch'],
                                                      cfg['data']['dataset']))
                    torch.save(state, save_path)

                print("Best OA till now = ", best_OA_till_now)
                print("Correspond F1= ", correspond_f1)
                # print("Correspond OA= ",correspond_acc)
                print("Best OA Iter till now= ", best_OA_epoch_till_now)

                ### added by Sprit
                iter_time = time_meter_val.avg
                time_meter_val.reset()
                remain_time = iter_time * (train_iter - i)
                m, s = divmod(remain_time, 60)
                h, m = divmod(m, 60)
                if s != 0:
                    train_time = "Remaining training time = %d hours %d minutes %d seconds \n" % (
                        h, m, s)
                else:
                    train_time = "Remaining training time: training completed.\n"
                print(train_time)

                # if OA >= best_OA_till_now:
                #     best_iou = score["Mean IoU : \t"]
                #     state = {
                #         "epoch": i + 1,
                #         "model_state": model.state_dict(),
                #         "optimizer_state": optimizer.state_dict(),
                #         "scheduler_state": scheduler.state_dict(),
                #         "best_iou": best_iou,
                #     }
                #     save_path = os.path.join(writer.file_writer.get_logdir(),
                #                              "{}_{}_best_model.pkl".format(
                #                                  cfg['model']['arch'],
                #                                  cfg['data']['dataset']))
                #     torch.save(state, save_path)

            if (i + 1) == train_iter:
                flag = False
                break
    my_pt.csv_out(run_id, data_path, cfg['model']['arch'], epoch, val_rlt_f1,
                  cfg['training']['val_interval'])
    my_pt.csv_out(run_id, data_path, cfg['model']['arch'], epoch, val_rlt_OA,
                  cfg['training']['val_interval'])
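
runningScore is likewise external; its update(gt, pred) / get_scores() / reset() interface suggests a per-pixel confusion matrix. A minimal sketch under that assumption (the class name and returned keys are illustrative; the real class also reports Mean F1, FreqW Acc, and per-class IoU):

import numpy as np

class RunningScoreSketch(object):
    """Confusion-matrix segmentation metrics (assumed interface of runningScore)."""

    def __init__(self, n_classes):
        self.n_classes = n_classes
        self.confusion = np.zeros((n_classes, n_classes))

    def update(self, gt, pred):
        mask = (gt >= 0) & (gt < self.n_classes)   # drop ignore/void labels
        idx = self.n_classes * gt[mask].astype(int) + pred[mask]
        self.confusion += np.bincount(
            idx, minlength=self.n_classes ** 2).reshape(self.n_classes, -1)

    def get_scores(self):
        cm = self.confusion
        overall_acc = np.diag(cm).sum() / cm.sum()
        iou = np.diag(cm) / (cm.sum(axis=1) + cm.sum(axis=0) - np.diag(cm))
        return {"Overall Acc": overall_acc, "Mean IoU": np.nanmean(iou)}

    def reset(self):
        self.confusion.fill(0)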
Example #11
def train(cfg, writer, logger):

    # Setup seeds for reproducing
    torch.manual_seed(cfg.get('seed', 1337))
    torch.cuda.manual_seed(cfg.get('seed', 1337))
    np.random.seed(cfg.get('seed', 1337))
    random.seed(cfg.get('seed', 1337))

    # Setup device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Setup Augmentations
    augmentations = cfg['training'].get('augmentations', None)
    data_aug = get_composed_augmentations(augmentations)

    # Setup Dataloader
    data_loader = get_loader(cfg['data']['dataset'], cfg['task'])
    data_path = cfg['data']['path']

    t_loader = data_loader(
        data_path,
        is_transform=True,
        split=cfg['data']['train_split'],
        img_size=(cfg['data']['img_rows'], cfg['data']['img_cols']),
        img_norm=cfg['data']['img_norm'],
        # version = cfg['data']['version'],
        augmentations=data_aug)

    v_loader = data_loader(
        data_path,
        is_transform=True,
        split=cfg['data']['val_split'],
        img_norm=cfg['data']['img_norm'],
        # version=cfg['data']['version'],
        img_size=(cfg['data']['img_rows'], cfg['data']['img_cols']),
    )

    trainloader = data.DataLoader(t_loader,
                                  batch_size=cfg['training']['batch_size'],
                                  num_workers=cfg['training']['n_workers'],
                                  shuffle=True)

    valloader = data.DataLoader(v_loader,
                                batch_size=cfg['training']['batch_size'],
                                num_workers=cfg['training']['n_workers'])

    # Setup Metrics
    if cfg['task'] == "seg":
        n_classes = t_loader.n_classes
        running_metrics_val = runningScoreSeg(n_classes)
    elif cfg['task'] == "depth":
        n_classes = 0
        running_metrics_val = runningScoreDepth()
    else:
        raise NotImplementedError('Task {} not implemented'.format(
            cfg['task']))

    # Setup Model
    model = get_model(cfg['model'], cfg['task'], n_classes).to(device)

    # Setup optimizer, lr_scheduler and loss function
    optimizer_cls = get_optimizer(cfg)
    optimizer_params = {
        k: v
        for k, v in cfg['training']['optimizer'].items() if k != 'name'
    }

    optimizer = optimizer_cls(model.parameters(), **optimizer_params)
    logger.info("Using optimizer {}".format(optimizer))

    scheduler = get_scheduler(optimizer, cfg['training']['lr_schedule'])

    loss_fn = get_loss_function(cfg)
    logger.info("Using loss {}".format(loss_fn))

    start_iter = 0
    if cfg['training']['resume'] is not None:
        if os.path.isfile(cfg['training']['resume']):
            logger.info(
                "Loading model and optimizer from checkpoint '{}'".format(
                    cfg['training']['resume']))
            checkpoint = torch.load(cfg['training']['resume'])
            # checkpoint = torch.load(cfg['training']['resume'], map_location=lambda storage, loc: storage)  # load model trained on gpu on cpu
            model.load_state_dict(checkpoint["model_state"])
            optimizer.load_state_dict(checkpoint["optimizer_state"])
            scheduler.load_state_dict(checkpoint["scheduler_state"])
            # start_iter = checkpoint["epoch"]
            logger.info("Loaded checkpoint '{}' (iter {})".format(
                cfg['training']['resume'], checkpoint["epoch"]))
        else:
            logger.info("No checkpoint found at '{}'".format(
                cfg['training']['resume']))

    val_loss_meter = averageMeter()
    time_meter = averageMeter()

    best_iou = -100.0
    best_rel = 100.0
    # i = start_iter
    i = 0
    flag = True

    while i <= cfg['training']['train_iters'] and flag:
        print(len(trainloader))
        for (images, labels, img_path) in trainloader:
            start_ts = time.time()  # return current time stamp
            scheduler.step()
            model.train()  # set model to training mode
            images = images.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()  # clear earlier gradients
            outputs = model(images)
            if cfg['model']['arch'] == "dispnet" and cfg['task'] == "depth":
                outputs = 1 / outputs  # dispnet predicts disparity; invert to get depth

            loss = loss_fn(input=outputs, target=labels)  # compute loss
            loss.backward()  # backpropagate the loss
            optimizer.step()  # update parameters

            time_meter.update(time.time() - start_ts)

            if (i + 1) % cfg['training']['print_interval'] == 0:
                fmt_str = "Iter [{:d}/{:d}]  Loss: {:.4f}  Time/Image: {:.4f}"
                print_str = fmt_str.format(
                    i + 1, cfg['training']['train_iters'], loss.item(),
                    time_meter.val / cfg['training']['batch_size'])

                print(print_str)
                logger.info(print_str)
                writer.add_scalar('loss/train_loss', loss.item(), i + 1)
                time_meter.reset()

            if (i + 1) % cfg['training']['val_interval'] == 0 or (
                    i + 1) == cfg['training']['train_iters']:
                model.eval()
                with torch.no_grad():
                    for i_val, (images_val, labels_val,
                                img_path_val) in tqdm(enumerate(valloader)):
                        images_val = images_val.to(device)
                        labels_val = labels_val.to(device)

                        outputs = model(
                            images_val
                        )  # [batch_size, n_classes, height, width]
                        if cfg['model']['arch'] == "dispnet" and cfg[
                                'task'] == "depth":
                            outputs = 1 / outputs

                        val_loss = loss_fn(input=outputs, target=labels_val
                                           )  # mean pixelwise loss in a batch

                        if cfg['task'] == "seg":
                            pred = outputs.data.max(1)[1].cpu().numpy(
                            )  # [batch_size, height, width]
                            gt = labels_val.data.cpu().numpy(
                            )  # [batch_size, height, width]
                        elif cfg['task'] == "depth":
                            pred = outputs.squeeze(1).data.cpu().numpy()
                            gt = labels_val.data.squeeze(1).cpu().numpy()
                        else:
                            raise NotImplementedError(
                                'Task {} not implemented'.format(cfg['task']))

                        running_metrics_val.update(gt=gt, pred=pred)
                        val_loss_meter.update(val_loss.item())

                writer.add_scalar('loss/val_loss', val_loss_meter.avg, i + 1)
                logger.info("Iter %d val_loss: %.4f" %
                            (i + 1, val_loss_meter.avg))
                print("Iter %d val_loss: %.4f" % (i + 1, val_loss_meter.avg))

                # output scores
                if cfg['task'] == "seg":
                    score, class_iou = running_metrics_val.get_scores()
                    for k, v in score.items():
                        print(k, v)
                        sys.stdout.flush()
                        logger.info('{}: {}'.format(k, v))
                        writer.add_scalar('val_metrics/{}'.format(k), v, i + 1)
                    for k, v in class_iou.items():
                        logger.info('{}: {}'.format(k, v))
                        writer.add_scalar('val_metrics/cls_{}'.format(k), v,
                                          i + 1)

                elif cfg['task'] == "depth":
                    val_result = running_metrics_val.get_scores()
                    for k, v in val_result.items():
                        print(k, v)
                        logger.info('{}: {}'.format(k, v))
                        writer.add_scalar('val_metrics/{}'.format(k), v, i + 1)
                else:
                    raise NotImplementedError('Task {} not implemented'.format(
                        cfg['task']))

                val_loss_meter.reset()
                running_metrics_val.reset()

                save_model = False
                if cfg['task'] == "seg":
                    if score["Mean IoU : \t"] >= best_iou:
                        best_iou = score["Mean IoU : \t"]
                        save_model = True
                        state = {
                            "epoch": i + 1,
                            "model_state": model.state_dict(),
                            "optimizer_state": optimizer.state_dict(),
                            "scheduler_state": scheduler.state_dict(),
                            "best_iou": best_iou,
                        }

                if cfg['task'] == "depth":
                    if val_result["abs rel : \t"] <= best_rel:
                        best_rel = val_result["abs rel : \t"]
                        save_model = True
                        state = {
                            "epoch": i + 1,
                            "model_state": model.state_dict(),
                            "optimizer_state": optimizer.state_dict(),
                            "scheduler_state": scheduler.state_dict(),
                            "best_rel": best_rel,
                        }

                if save_model:
                    save_path = os.path.join(
                        writer.file_writer.get_logdir(),
                        "{}_{}_best_model.pkl".format(cfg['model']['arch'],
                                                      cfg['data']['dataset']))
                    torch.save(state, save_path)

            if (i + 1) == cfg['training']['train_iters']:
                flag = False
                break
            i += 1
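
For the depth task, the checkpoint criterion is the "abs rel" entry from runningScoreDepth. The standard absolute-relative error it presumably computes (an assumption; the helper below is illustrative) is:

import numpy as np

def abs_rel_error(pred, gt, eps=1e-6):
    """Mean absolute relative depth error: mean(|pred - gt| / gt)."""
    valid = gt > eps                 # skip invalid / zero-depth pixels
    return np.mean(np.abs(pred[valid] - gt[valid]) / gt[valid])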
Example #12
def train(cfg, writer, logger, args):

    # Setup seeds
    torch.manual_seed(cfg.get('seed', RNG_SEED))
    torch.cuda.manual_seed(cfg.get('seed', RNG_SEED))
    np.random.seed(cfg.get('seed', RNG_SEED))
    random.seed(cfg.get('seed', RNG_SEED))

    # Setup device
    # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    device = torch.device(args.device)

    # Setup Augmentations
    # augmentations = cfg['training'].get('augmentations', None)
    if cfg['data']['dataset'] in ['cityscapes']:
        augmentations = cfg['training'].get('augmentations',
                                            {'brightness': 63. / 255.,
                                             'saturation': 0.5,
                                             'contrast': 0.8,
                                             'hflip': 0.5,
                                             'rotate': 10,
                                             'rscalecropsquare': 704,  # 640, # 672, # 704,
                                             })
    elif cfg['data']['dataset'] in ['drive']:
        augmentations = cfg['training'].get('augmentations',
                                            {'brightness': 63. / 255.,
                                             'saturation': 0.5,
                                             'contrast': 0.8,
                                             'hflip': 0.5,
                                             'rotate': 180,
                                             'rscalecropsquare': 576,
                                             })
        # augmentations = cfg['training'].get('augmentations',
        #                                     {'rotate': 10, 'hflip': 0.5, 'rscalecrop': 512, 'gaussian': 0.5})
    else:
        augmentations = cfg['training'].get('augmentations', {'rotate': 10, 'hflip': 0.5})
    data_aug = get_composed_augmentations(augmentations)

    # Setup Dataloader
    data_loader = get_loader(cfg['data']['dataset'])
    data_path = cfg['data']['path']

    t_loader = data_loader(
        data_path,
        is_transform=True,
        split=cfg['data']['train_split'],
        img_size=(cfg['data']['img_rows'], cfg['data']['img_cols']),
        augmentations=data_aug)

    v_loader = data_loader(
        data_path,
        is_transform=True,
        split=cfg['data']['val_split'],
        img_size=(cfg['data']['img_rows'], cfg['data']['img_cols']),)

    n_classes = t_loader.n_classes
    trainloader = data.DataLoader(t_loader,
                                  batch_size=cfg['training']['batch_size'],
                                  num_workers=cfg['training']['n_workers'],
                                  shuffle=True)

    valloader = data.DataLoader(v_loader,
                                batch_size=cfg['training']['batch_size'],
                                num_workers=cfg['training']['n_workers'])

    # Setup Metrics
    running_metrics_val = runningScore(n_classes, cfg['data']['void_class'] > 0)

    # Setup Model
    print('trying device {}'.format(device))
    model = get_model(cfg['model'], n_classes, args)  # .to(device)

    if cfg['model']['arch'] not in ['unetvgg16', 'unetvgg16gn', 'druvgg16', 'unetresnet50', 'unetresnet50bn',
                                    'druresnet50', 'druresnet50bn', 'druresnet50syncedbn']:
        model.apply(weights_init)
    else:
        init_model(model)

    model = torch.nn.DataParallel(model, device_ids=range(torch.cuda.device_count()))
    # if cfg['model']['arch'] in ['druresnet50syncedbn']:
    #     print('using synchronized batch normalization')
    #     time.sleep(5)
    #     patch_replication_callback(model)

    model = model.cuda()
    # model = torch.nn.DataParallel(model, device_ids=(3, 2))

    # Setup optimizer, lr_scheduler and loss function
    optimizer_cls = get_optimizer(cfg)
    optimizer_params = {k:v for k, v in cfg['training']['optimizer'].items()
                        if k != 'name'}
    if cfg['model']['arch'] in ['unetvgg16', 'unetvgg16gn', 'druvgg16', 'druresnet50', 'druresnet50bn', 'druresnet50syncedbn']:
        optimizer = optimizer_cls([
            {'params': model.module.paramGroup1.parameters(), 'lr': optimizer_params['lr'] / 10},
            {'params': model.module.paramGroup2.parameters()}
        ], **optimizer_params)
    else:
        optimizer = optimizer_cls(model.parameters(), **optimizer_params)
    logger.warning(f"Model parameters in total: {sum([p.numel() for p in model.parameters()])}")
    logger.warning(f"Trainable parameters in total: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
    logger.info("Using optimizer {}".format(optimizer))

    scheduler = get_scheduler(optimizer, cfg['training']['lr_schedule'])

    loss_fn = get_loss_function(cfg)
    logger.info("Using loss {}".format(loss_fn))

    start_iter = 0
    if cfg['training']['resume'] is not None:
        if os.path.isfile(cfg['training']['resume']):
            logger.info(
                "Loading model and optimizer from checkpoint '{}'".format(cfg['training']['resume'])
            )
            checkpoint = torch.load(cfg['training']['resume'])
            model.load_state_dict(checkpoint["model_state"])
            optimizer.load_state_dict(checkpoint["optimizer_state"])
            scheduler.load_state_dict(checkpoint["scheduler_state"])
            start_iter = checkpoint["epoch"]
            logger.info(
                "Loaded checkpoint '{}' (iter {})".format(
                    cfg['training']['resume'], checkpoint["epoch"]
                )
            )
        else:
            logger.info("No checkpoint found at '{}'".format(cfg['training']['resume']))

    val_loss_meter = averageMeter()
    time_meter = averageMeter()

    best_iou = -100.0
    i = start_iter
    flag = True

    weight = torch.ones(n_classes)
    if cfg['data'].get('void_class'):
        if cfg['data'].get('void_class') >= 0:
            weight[cfg['data'].get('void_class')] = 0.
    weight = weight.to(device)

    logger.info("Set the prediction weights as {}".format(weight))

    while i <= cfg['training']['train_iters'] and flag:
        for (images, labels) in trainloader:
            i += 1
            start_ts = time.time()
            scheduler.step()
            model.train()
            # for param_group in optimizer.param_groups:
            #     print(param_group['lr'])
            images = images.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            if cfg['model']['arch'] in ['reclast']:
                h0 = torch.ones([images.shape[0], args.hidden_size, images.shape[2], images.shape[3]],
                                dtype=torch.float32)
                h0 = h0.to(device)  # .to() is not in-place; rebind the result
                outputs = model(images, h0)

            elif cfg['model']['arch'] in ['recmid']:
                W, H = images.shape[2], images.shape[3]
                w = int(np.floor(np.floor(np.floor(W/2)/2)/2)/2)
                h = int(np.floor(np.floor(np.floor(H/2)/2)/2)/2)
                h0 = torch.ones([images.shape[0], args.hidden_size, w, h],
                                dtype=torch.float32)
                h0 = h0.to(device)
                outputs = model(images, h0)

            elif cfg['model']['arch'] in ['dru', 'sru']:
                W, H = images.shape[2], images.shape[3]
                w = int(np.floor(np.floor(np.floor(W/2)/2)/2)/2)
                h = int(np.floor(np.floor(np.floor(H/2)/2)/2)/2)
                h0 = torch.ones([images.shape[0], args.hidden_size, w, h],
                                dtype=torch.float32)
                h0 = h0.to(device)
                s0 = torch.ones([images.shape[0], n_classes, W, H],
                                dtype=torch.float32)
                s0 = s0.to(device)
                outputs = model(images, h0, s0)

            elif cfg['model']['arch'] in ['druvgg16', 'druresnet50', 'druresnet50bn', 'druresnet50syncedbn']:
                W, H = images.shape[2], images.shape[3]
                w, h = int(W / 2 ** 4), int(H / 2 ** 4)
                if cfg['model']['arch'] in ['druresnet50', 'druresnet50bn', 'druresnet50syncedbn']:
                    w, h = int(W / 2 ** 5), int(H / 2 ** 5)
                h0 = torch.ones([images.shape[0], args.hidden_size, w, h],
                                dtype=torch.float32, device=device)
                s0 = torch.zeros([images.shape[0], n_classes, W, H],
                                 dtype=torch.float32, device=device)
                outputs = model(images, h0, s0)

            else:
                outputs = model(images)

            loss = loss_fn(input=outputs, target=labels, weight=weight, bkargs=args)
            loss.backward()

            # `clip_grad_norm` helps prevent the exploding gradient problem in RNNs / LSTMs.
            # if use_grad_clip(cfg['model']['arch']):  #
            # if cfg['model']['arch'] in ['rcnn', 'rcnn2', 'rcnn3']:  #
            if use_grad_clip(cfg['model']['arch']):
                nn.utils.clip_grad_norm_(model.parameters(), args.clip)

            optimizer.step()

            time_meter.update(time.time() - start_ts)

            if (i + 1) % cfg['training']['print_interval'] == 0:
                fmt_str = "Iter [{:d}/{:d}]  Loss: {:.4f}  Time/Image: {:.4f}"
                print_str = fmt_str.format(i + 1,
                                           cfg['training']['train_iters'], 
                                           loss.item(),
                                           time_meter.avg / cfg['training']['batch_size'])

                # print(print_str)
                logger.info(print_str)
                writer.add_scalar('loss/train_loss', loss.item(), i+1)
                time_meter.reset()

            if (i + 1) % cfg['training']['val_interval'] == 0 or \
               (i + 1) == cfg['training']['train_iters']:
                torch.backends.cudnn.benchmark = False
                model.eval()
                with torch.no_grad():
                    for i_val, (images_val, labels_val) in tqdm(enumerate(valloader)):
                        if args.benchmark:
                            if i_val > 10:
                                break
                        images_val = images_val.to(device)
                        labels_val = labels_val.to(device)
                        if cfg['model']['arch'] in ['reclast']:
                            h0 = torch.ones([images_val.shape[0], args.hidden_size, images_val.shape[2], images_val.shape[3]],
                                            dtype=torch.float32)
                            h0 = h0.to(device)
                            outputs = model(images_val, h0)

                        elif cfg['model']['arch'] in ['recmid']:
                            W, H = images_val.shape[2], images_val.shape[3]
                            w = int(np.floor(np.floor(np.floor(W / 2) / 2) / 2) / 2)
                            h = int(np.floor(np.floor(np.floor(H / 2) / 2) / 2) / 2)
                            h0 = torch.ones([images_val.shape[0], args.hidden_size, w, h],
                                            dtype=torch.float32)
                            h0 = h0.to(device)
                            outputs = model(images_val, h0)

                        elif cfg['model']['arch'] in ['dru', 'sru']:
                            W, H = images_val.shape[2], images_val.shape[3]
                            w = int(np.floor(np.floor(np.floor(W / 2) / 2) / 2) / 2)
                            h = int(np.floor(np.floor(np.floor(H / 2) / 2) / 2) / 2)
                            h0 = torch.ones([images_val.shape[0], args.hidden_size, w, h],
                                            dtype=torch.float32)
                            h0 = h0.to(device)
                            s0 = torch.ones([images_val.shape[0], n_classes, W, H],
                                            dtype=torch.float32)
                            s0 = s0.to(device)
                            outputs = model(images_val, h0, s0)

                        elif cfg['model']['arch'] in ['druvgg16', 'druresnet50', 'druresnet50bn', 'druresnet50syncedbn']:
                            W, H = images_val.shape[2], images_val.shape[3]
                            w, h = int(W / 2**4), int(H / 2**4)
                            if cfg['model']['arch'] in ['druresnet50', 'druresnet50bn', 'druresnet50syncedbn']:
                                w, h = int(W / 2 ** 5), int(H / 2 ** 5)
                            h0 = torch.ones([images_val.shape[0], args.hidden_size, w, h],
                                            dtype=torch.float32)
                            h0 = h0.to(device)
                            s0 = torch.zeros([images_val.shape[0], n_classes, W, H],
                                             dtype=torch.float32)
                            s0 = s0.to(device)
                            outputs = model(images_val, h0, s0)

                        else:
                            outputs = model(images_val)
                        val_loss = loss_fn(input=outputs, target=labels_val, bkargs=args)

                        if cfg['training']['loss']['name'] in ['multi_step_cross_entropy']:
                            pred = outputs[-1].data.max(1)[1].cpu().numpy()
                        else:
                            pred = outputs.data.max(1)[1].cpu().numpy()
                        gt = labels_val.data.cpu().numpy()
                        logger.debug('pred shape: {}\t ground-truth shape: {}'.format(pred.shape, gt.shape))
                        running_metrics_val.update(gt, pred)
                        val_loss_meter.update(val_loss.item())
                    # assert i_val > 0, "Validation dataset is empty for no reason."
                torch.backends.cudnn.benchmark = True
                writer.add_scalar('loss/val_loss', val_loss_meter.avg, i+1)
                logger.info("Iter %d Loss: %.4f" % (i + 1, val_loss_meter.avg))
                score, class_iou, _ = running_metrics_val.get_scores()
                for k, v in score.items():
                    # print(k, v)
                    logger.info('{}: {}'.format(k, v))
                    writer.add_scalar('val_metrics/{}'.format(k), v, i+1)

                for k, v in class_iou.items():
                    logger.info('{}: {}'.format(k, v))
                    writer.add_scalar('val_metrics/cls_{}'.format(k), v, i+1)

                val_loss_meter.reset()
                running_metrics_val.reset()

                if score["Mean IoU : \t"] >= best_iou:
                    best_iou = score["Mean IoU : \t"]
                    state = {
                        "epoch": i + 1,
                        "model_state": model.state_dict(),
                        "optimizer_state": optimizer.state_dict(),
                        "scheduler_state": scheduler.state_dict(),
                        "best_iou": best_iou,
                    }
                    save_path = os.path.join(writer.file_writer.get_logdir(),
                                             best_model_path(cfg))
                    torch.save(state, save_path)

            if (i + 1) == cfg['training']['train_iters']:
                flag = False
                save_path = os.path.join(writer.file_writer.get_logdir(),
                                         "{}_{}_final_model.pkl".format(
                                             cfg['model']['arch'],
                                             cfg['data']['dataset']))
                torch.save(state, save_path)
                break
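Note that `Tensor.to(device)` returns a new tensor rather than moving it in place, which is why the branches above rebind `h0`/`s0` after the call. The per-architecture state setup could also be collapsed into a single helper. A hypothetical sketch under the same assumptions (`args.hidden_size`, stride-16 encoders for the plain/VGG16 variants, stride-32 for the ResNet50 ones), not the repository's actual API:

import torch

def init_states(arch, images, hidden_size, n_classes, device):
    # Hypothetical helper mirroring the validation branches above.
    B, _, W, H = images.shape
    if arch == 'reclast':  # full-resolution hidden state
        return torch.ones(B, hidden_size, W, H, device=device), None
    factor = 32 if arch.startswith('druresnet50') else 16
    h0 = torch.ones(B, hidden_size, W // factor, H // factor, device=device)
    if arch == 'recmid':  # hidden state only, no state map
        return h0, None
    if arch in ('dru', 'sru'):  # state map initialised to ones
        return h0, torch.ones(B, n_classes, W, H, device=device)
    # druvgg16 / druresnet50 variants: state map initialised to zeros
    return h0, torch.zeros(B, n_classes, W, H, device=device)

Each branch would then shrink to `h0, s0 = init_states(...)` plus the matching `model(...)` call, and allocating with `device=` avoids the non-in-place `.to()` pitfall entirely.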
Example #13
def train(cfg, writer, logger):

    # Setup seeds
    torch.manual_seed(cfg.get("seed", 1337))
    torch.cuda.manual_seed(cfg.get("seed", 1337))
    np.random.seed(cfg.get("seed", 1337))
    random.seed(cfg.get("seed", 1337))

    # Setup device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # # Setup Augmentations
    # augmentations = cfg["training"].get("augmentations", None)
    # data_aug = get_composed_augmentations(augmentations)

    # Setup Dataloader
    data_loader = get_loader(cfg["data"]["dataset"])
    data_path = cfg["data"]["path"]

    t_loader = data_loader(
        data_path,
        split=cfg["data"]["train_split"],
        img_size=(cfg["data"]["img_rows"], cfg["data"]["img_cols"]),
    )

    v_loader = data_loader(
        data_path,
        split=cfg["data"]["val_split"],
        img_size=(cfg["data"]["img_rows"], cfg["data"]["img_cols"]),
    )

    n_classes = t_loader.n_classes
    trainloader = data.DataLoader(
        t_loader,
        batch_size=cfg["training"]["batch_size"],
        num_workers=cfg["training"]["n_workers"],
        shuffle=True,
        drop_last=True,
    )

    valloader = data.DataLoader(
        v_loader,
        batch_size=cfg["training"]["batch_size"],
        num_workers=cfg["training"]["n_workers"],
        shuffle=True,
        drop_last=True,
    )

    # Setup Metrics
    running_metrics_val = runningScore(n_classes)

    # Setup Model
    model_orig = get_model(cfg["model"], n_classes).to(device)
    if cfg["training"]["pretrain"] == True:
        # Load a pretrained model
        model_orig.load_pretrained_model(
            model_path="pretrained/pspnet101_cityscapes.caffemodel")
        logger.info("Loaded pretrained model.")
    else:
        # No pretrained model
        logger.info("No pretraining.")

    model = torch.nn.DataParallel(model_orig,
                                  device_ids=range(torch.cuda.device_count()))

    # Setup optimizer, lr_scheduler and loss function
    optimizer_cls = get_optimizer(cfg)
    optimizer_params = {
        k: v
        for k, v in cfg["training"]["optimizer"].items() if k != "name"
    }

    optimizer = optimizer_cls(model.parameters(), **optimizer_params)
    logger.info("Using optimizer {}".format(optimizer))

    scheduler = get_scheduler(optimizer, cfg["training"]["lr_schedule"])

    loss_fn = get_loss_function(cfg)
    logger.info("Using loss {}".format(loss_fn))

    start_iter = 0
    if cfg["training"]["resume"] is not None:
        if os.path.isfile(cfg["training"]["resume"]):
            logger.info(
                "Loading model and optimizer from checkpoint '{}'".format(
                    cfg["training"]["resume"]))
            checkpoint = torch.load(cfg["training"]["resume"])
            model.load_state_dict(checkpoint["model_state"])
            optimizer.load_state_dict(checkpoint["optimizer_state"])
            scheduler.load_state_dict(checkpoint["scheduler_state"])
            start_iter = checkpoint["epoch"]
            logger.info("Loaded checkpoint '{}' (iter {})".format(
                cfg["training"]["resume"], checkpoint["epoch"]))
        else:
            logger.info("No checkpoint found at '{}'".format(
                cfg["training"]["resume"]))

    val_loss_meter = averageMeter()
    time_meter = averageMeter()

    ### Visualize model training

    # helper function to show an image
    # (used in the `plot_classes_preds` function below)
    def matplotlib_imshow(data, is_image):
        if is_image:  #for images
            data = data / 4 + 0.5  # unnormalize
            npimg = data.numpy()
            plt.imshow(npimg, cmap="gray")
        else:  # for labels
            nplbl = data.numpy()
            plt.imshow(t_loader.decode_segmap(nplbl))

    def plot_classes_preds(data, batch_size, iter, is_image=True):
        fig = plt.figure(figsize=(12, 48))
        for idx in np.arange(batch_size):
            ax = fig.add_subplot(1, batch_size, idx + 1, xticks=[], yticks=[])
            matplotlib_imshow(data[idx], is_image)

            ax.set_title("Iteration Number " + str(iter))

        return fig

    best_iou = -100.0
    #best_val_loss = -100.0
    i = start_iter
    flag = True

    # Check that all parameters are trainable
    print('CHECK PARAMETER TRAINING:')
    for name, param in model.named_parameters():
        if not param.requires_grad:
            print(name, param.data)

    while i <= cfg["training"]["train_iters"] and flag:
        for (images_orig, labels_orig, weights_orig,
             nuc_weights_orig) in trainloader:
            i += 1
            start_ts = time.time()
            scheduler.step()
            model.train()  #convert model into training mode
            images = images_orig.to(device)
            labels = labels_orig.to(device)
            weights = weights_orig.to(device)
            nuc_weights = nuc_weights_orig.to(device)

            optimizer.zero_grad()

            outputs = model(images)

            # Transform output to calculate meaningful loss
            out = outputs[0]

            # Resize output of network to same size as labels
            target_size = (labels.size()[1], labels.size()[2])
            out = torch.nn.functional.interpolate(out,
                                                  size=target_size,
                                                  mode='bicubic')

            # Multiply weights by loss output
            loss = loss_fn(input=out, target=labels)

            loss = torch.mul(loss, weights)  # add contour weights
            loss = torch.mul(loss, nuc_weights)  # add nuclei weights
            loss = loss.mean()  # average over all pixels to obtain a scalar loss

            loss.backward()  # computes gradients over network
            optimizer.step()  #updates parameters

            time_meter.update(time.time() - start_ts)

            if (i + 1) % cfg["training"][
                    "print_interval"] == 0:  # frequency with which visualize training update
                fmt_str = "Iter [{:d}/{:d}]  Loss: {:.4f}  Time/Image: {:.4f}"
                print_str = fmt_str.format(
                    i + 1,
                    cfg["training"]["train_iters"],
                    loss.item(),
                    time_meter.avg / cfg["training"]["batch_size"],
                )
                #Show mini-batches during training

                # #Visualize only DAPI
                # writer.add_figure('Inputs',
                #     plot_classes_preds(images_orig.squeeze(), cfg["training"]["batch_size"], i, True),
                #             global_step=i)

                # writer.add_figure('Targets',
                #     plot_classes_preds(labels_orig, cfg["training"]["batch_size"], i, False),
                #             global_step=i)

                #Take max across classes (of probability maps) and assign class label to visualize semantic map
                #1)
                out_orig = torch.nn.functional.softmax(
                    outputs[0], dim=1).max(1).indices.cpu()
                #out_orig = out_orig.cpu().detach()
                #2)
                #out_orig = torch.argmax(outputs[0],dim=1)
                #3)
                #out_orig = outputs[0].data.max(1)[1].cpu()

                # #Visualize predictions
                # writer.add_figure('Predictions',
                #     plot_classes_preds(out_orig, cfg["training"]["batch_size"], i, False),
                #             global_step=i)

                #Save probability map
                prob_maps_folder = os.path.join(
                    writer.file_writer.get_logdir(), "probability_maps")
                os.makedirs(prob_maps_folder, exist_ok=True)

                #Downsample original images to target size for visualization
                images = torch.nn.functional.interpolate(images,
                                                         size=target_size,
                                                         mode='bicubic')

                out = torch.nn.functional.softmax(out, dim=1)

                contours = (out[:, 1, :, :]).unsqueeze(dim=1)
                nuclei = (out[:, 2, :, :]).unsqueeze(dim=1)
                background = (out[:, 0, :, :]).unsqueeze(dim=1)

                #imageTensor = torch.cat((images, contours, nuclei, background),dim=0)

                # Save images side by side: nrow is how many images per row
                #save_image(make_grid(imageTensor, nrow=2), os.path.join(prob_maps_folder,"Prob_maps_%d.tif" % i))

                # Targets visualization below
                nplbl = labels_orig.numpy()
                targets = []  #each element is RGB target label in batch
                for bs in np.arange(cfg["training"]["batch_size"]):
                    target_bs = t_loader.decode_segmap(nplbl[bs])
                    target_bs = 255 * target_bs
                    target_bs = target_bs.astype('uint8')
                    target_bs = torch.from_numpy(target_bs)
                    target_bs = target_bs.unsqueeze(dim=0)
                    targets.append(target_bs)  #uint8 labels, shape (N,N,3)

                target = reduce(lambda x, y: torch.cat((x, y), dim=0), targets)
                target = target.permute(0, 3, 1, 2)  # size = (Batch, Channels, N, N)
                target = target.type(torch.FloatTensor)

                save_image(
                    make_grid(target, nrow=cfg["training"]["batch_size"]),
                    os.path.join(prob_maps_folder, "Target_labels_%d.tif" % i))

                # Weights visualization below:
                #wgts = weights_orig.type(torch.FloatTensor)
                #save_image(make_grid(wgts, nrow=2), os.path.join(prob_maps_folder,"Weights_%d.tif" % i))

                # Probability maps visualization below
                t1 = []
                t2 = []
                t3 = []
                t4 = []

                # Normalize individual images in batch
                for bs in np.arange(cfg["training"]["batch_size"]):
                    t1.append((images[bs][0] - images[bs][0].min()) /
                              (images[bs][0].max() - images[bs][0].min()))
                    t2.append(contours[bs])
                    t3.append(nuclei[bs])
                    t4.append(background[bs])

                t1 = [torch.unsqueeze(elem, dim=0) for elem in t1]  # expand dim=0 for images in batch
                # Convert normalized batch to Tensor
                tensor1 = torch.cat((t1), dim=0)
                tensor2 = torch.cat((t2), dim=0)
                tensor3 = torch.cat((t3), dim=0)
                tensor4 = torch.cat((t4), dim=0)

                tTensor = torch.cat((tensor1, tensor2, tensor3, tensor4),
                                    dim=0)
                tTensor = tTensor.unsqueeze(dim=1)

                save_image(make_grid(tTensor,
                                     nrow=cfg["training"]["batch_size"]),
                           os.path.join(prob_maps_folder,
                                        "Prob_maps_%d.tif" % i),
                           normalize=False)

                logger.info(print_str)
                writer.add_scalar(
                    "loss/train_loss", loss.item(),
                    i + 1)  # adds value to history (title, loss, iter index)
                time_meter.reset()

            if (i + 1) % cfg["training"]["val_interval"] == 0 or (
                    i + 1
            ) == cfg["training"][
                    "train_iters"]:  # evaluate model on validation set at these intervals
                model.eval()  # evaluate mode for model
                with torch.no_grad():
                    for i_val, (images_val, labels_val, weights_val,
                                nuc_weights_val) in tqdm(enumerate(valloader)):
                        images_val = images_val.to(device)
                        labels_val = labels_val.to(device)
                        weights_val = weights_val.to(device)
                        nuc_weights_val = nuc_weights_val.to(device)

                        outputs_val = model(images_val)

                        # Resize output of network to same size as labels
                        target_val_size = (labels_val.size()[1],
                                           labels_val.size()[2])
                        outputs_val = torch.nn.functional.interpolate(
                            outputs_val, size=target_val_size, mode='bicubic')

                        # Multiply weights by loss output
                        val_loss = loss_fn(input=outputs_val,
                                           target=labels_val)

                        val_loss = torch.mul(val_loss, weights_val)
                        val_loss = torch.mul(val_loss, nuc_weights_val)
                        val_loss = val_loss.mean()  # average over all pixels to obtain a scalar loss

                        outputs_val = torch.nn.functional.softmax(outputs_val,
                                                                  dim=1)

                        #Save probability map
                        val_prob_maps_folder = os.path.join(
                            writer.file_writer.get_logdir(),
                            "val_probability_maps")
                        os.makedirs(val_prob_maps_folder, exist_ok=True)

                        #Downsample original images to target size for visualization
                        images_val = torch.nn.functional.interpolate(
                            images_val, size=target_val_size, mode='bicubic')

                        contours_val = outputs_val[:, 1, :, :].unsqueeze(dim=1)
                        nuclei_val = outputs_val[:, 2, :, :].unsqueeze(dim=1)
                        background_val = outputs_val[:, 0, :, :].unsqueeze(dim=1)

                        # Targets visualization below
                        nplbl_val = labels_val.cpu().numpy()
                        targets_val = []  # each element is an RGB target label in the batch
                        for bs in np.arange(cfg["training"]["batch_size"]):
                            target_bs = v_loader.decode_segmap(nplbl_val[bs])
                            target_bs = 255 * target_bs
                            target_bs = target_bs.astype('uint8')
                            target_bs = torch.from_numpy(target_bs)
                            target_bs = target_bs.unsqueeze(dim=0)
                            targets_val.append(target_bs)  # uint8 labels, shape (N, N, 3)

                        target_val = reduce(lambda x, y: torch.cat((x, y), dim=0), targets_val)
                        target_val = target_val.permute(0, 3, 1, 2)  # size = (Batch, Channels, N, N)
                        target_val = target_val.type(torch.FloatTensor)

                        save_image(
                            make_grid(target_val,
                                      nrow=cfg["training"]["batch_size"]),
                            os.path.join(
                                val_prob_maps_folder,
                                "Target_labels_%d_val_%d.tif" % (i, i_val)))

                        # Weights visualization below:
                        #wgts_val = weights_val.type(torch.FloatTensor)
                        #save_image(make_grid(wgts_val, nrow=2), os.path.join(val_prob_maps_folder,"Weights_val_%d.tif" % i_val))

                        # Probability maps visualization below
                        t1_val = []
                        t2_val = []
                        t3_val = []
                        t4_val = []
                        # Normalize individual images in batch
                        for bs in np.arange(cfg["training"]["batch_size"]):
                            t1_val.append(
                                (images_val[bs][0] - images_val[bs][0].min()) /
                                (images_val[bs][0].max() -
                                 images_val[bs][0].min()))
                            t2_val.append(contours_val[bs])
                            t3_val.append(nuclei_val[bs])
                            t4_val.append(background_val[bs])

                        t1_val = [torch.unsqueeze(elem, dim=0) for elem in t1_val]  # expand dim=0 for images_val in batch
                        # Convert normalized batch to Tensor
                        tensor1_val = torch.cat((t1_val), dim=0)
                        tensor2_val = torch.cat((t2_val), dim=0)
                        tensor3_val = torch.cat((t3_val), dim=0)
                        tensor4_val = torch.cat((t4_val), dim=0)

                        tTensor_val = torch.cat((tensor1_val, tensor2_val,
                                                 tensor3_val, tensor4_val),
                                                dim=0)
                        tTensor_val = tTensor_val.unsqueeze(dim=1)

                        save_image(make_grid(
                            tTensor_val, nrow=cfg["training"]["batch_size"]),
                                   os.path.join(
                                       val_prob_maps_folder,
                                       "Prob_maps_%d_val_%d.tif" % (i, i_val)),
                                   normalize=False)

                        pred = outputs_val.data.max(1)[1].cpu().numpy()
                        gt = labels_val.data.cpu().numpy()

                        running_metrics_val.update(gt, pred)
                        val_loss_meter.update(val_loss.item())

                writer.add_scalar("loss/val_loss", val_loss_meter.avg, i + 1)
                logger.info("Iter %d Loss: %.4f" % (i + 1, val_loss_meter.avg))

                ### Save best validation loss model
                # if val_loss_meter.avg >= best_val_loss:
                #     best_val_loss = val_loss_meter.avg
                #     state = {
                #         "epoch": i + 1,
                #         "model_state": model.state_dict(),
                #         "optimizer_state": optimizer.state_dict(),
                #         "scheduler_state": scheduler.state_dict(),
                #         "best_val_loss": best_val_loss,
                #     }
                #     save_path = os.path.join(
                #         writer.file_writer.get_logdir(),
                #         "{}_{}_best_model.pkl".format(cfg["model"]["arch"], cfg["data"]["dataset"]),
                #     )
                #     torch.save(state, save_path)
                ###

                score, class_iou = running_metrics_val.get_scores()  # best model chosen via IoU
                for k, v in score.items():
                    print(k, v)
                    logger.info("{}: {}".format(k, v))
                    writer.add_scalar("val_metrics/{}".format(k), v, i + 1)

                for k, v in class_iou.items():
                    logger.info("{}: {}".format(k, v))
                    writer.add_scalar("val_metrics/cls_{}".format(k), v, i + 1)

                val_loss_meter.reset()
                running_metrics_val.reset()

                ### Save best mean IoU model
                if score["Mean IoU : \t"] >= best_iou:
                    best_iou = score["Mean IoU : \t"]
                    state = {
                        "epoch": i + 1,
                        "model_state": model.state_dict(),
                        "optimizer_state": optimizer.state_dict(),
                        "scheduler_state": scheduler.state_dict(),
                        "best_iou": best_iou,
                    }
                    save_path = os.path.join(
                        writer.file_writer.get_logdir(),
                        "{}_{}_best_model.pkl".format(cfg["model"]["arch"],
                                                      cfg["data"]["dataset"]),
                    )
                    torch.save(state, save_path)
                ###

            if (i + 1) == cfg["training"]["train_iters"]:
                flag = False
                break
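All of these training loops lean on `averageMeter` for loss and timing statistics. A minimal sketch consistent with how `update()`, `avg`, `sum`, and `reset()` are used throughout (the upstream class in pytorch-semseg-style repositories may differ in detail):

class averageMeter(object):
    """Tracks the latest value, a running sum, and the running average."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val, self.sum, self.count, self.avg = 0.0, 0.0, 0, 0.0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count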
Example #14
def train(cfg, writer, logger, start_iter=0, model_only=False, gpu=-1, save_dir=None):

    # Setup seeds and config
    torch.manual_seed(cfg.get("seed", 1337))
    torch.cuda.manual_seed(cfg.get("seed", 1337))
    np.random.seed(cfg.get("seed", 1337))
    random.seed(cfg.get("seed", 1337))
    
    # Setup device
    if gpu == -1:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    else:
        device = torch.device("cuda:%d" %gpu if torch.cuda.is_available() else "cpu")

    # Setup Augmentations
    augmentations = cfg["training"].get("augmentations", None)
    if cfg["data"]["dataset"] == "softmax_cityscapes_convention":
        data_aug = get_composed_augmentations_softmax(augmentations)
    else:
        data_aug = get_composed_augmentations(augmentations)

    # Setup Dataloader
    data_loader = get_loader(cfg["data"]["dataset"])
    data_path = cfg["data"]["path"]

    t_loader = data_loader(
        data_path,
        config = cfg["data"],
        is_transform=True,
        split=cfg["data"]["train_split"],
        img_size=(cfg["data"]["img_rows"], cfg["data"]["img_cols"]),
        augmentations=data_aug,
    )
    v_loader = data_loader(
        data_path,
        config = cfg["data"],
        is_transform=True,
        split=cfg["data"]["val_split"],
        img_size=(cfg["data"]["img_rows"], cfg["data"]["img_cols"]),
    )

    sampler = None
    if "sampling" in cfg["data"]:
        sampler = data.WeightedRandomSampler(
            weights = get_sampling_weights(t_loader, cfg["data"]["sampling"]),
            num_samples = len(t_loader),
            replacement = True
        )
    n_classes = t_loader.n_classes
    trainloader = data.DataLoader(
        t_loader,
        batch_size=cfg["training"]["batch_size"],
        num_workers=cfg["training"]["n_workers"],
        sampler=sampler,
        shuffle=sampler is None,
    )
    valloader = data.DataLoader(
        v_loader, batch_size=cfg["training"]["batch_size"], num_workers=cfg["training"]["n_workers"]
    )

    # Setup Metrics
    running_metrics_val = {"seg": runningScoreSeg(n_classes)}
    if "classifiers" in cfg["data"]:
        for name, classes in cfg["data"]["classifiers"].items():
            running_metrics_val[name] = runningScoreClassifier( len(classes) )
    if "bin_classifiers" in cfg["data"]:
        for name, classes in cfg["data"]["bin_classifiers"].items():
            running_metrics_val[name] = runningScoreClassifier(2)

    # Setup Model
    model = get_model(cfg["model"], n_classes).to(device)
    
    total_params = sum(p.numel() for p in model.parameters())
    print('Parameters:', total_params)

    if gpu == -1:
        model = torch.nn.DataParallel(model, device_ids=range(torch.cuda.device_count()))
    else:
        model = torch.nn.DataParallel(model, device_ids=[gpu])
    
    model.apply(weights_init)
    pretrained_path='weights/hardnet_petite_base.pth'
    weights = torch.load(pretrained_path)
    model.module.base.load_state_dict(weights)

    # Setup optimizer, lr_scheduler and loss function
    optimizer_cls = get_optimizer(cfg)
    optimizer_params = {k: v for k, v in cfg["training"]["optimizer"].items() if k != "name"}

    optimizer = optimizer_cls(model.parameters(), **optimizer_params)
    print("Using optimizer {}".format(optimizer))

    scheduler = get_scheduler(optimizer, cfg["training"]["lr_schedule"])
    loss_dict = get_loss_function(cfg, device)

    if cfg["training"]["resume"] is not None:
        if os.path.isfile(cfg["training"]["resume"]):
            logger.info(
                "Loading model and optimizer from checkpoint '{}'".format(cfg["training"]["resume"])
            )
            checkpoint = torch.load(cfg["training"]["resume"], map_location=device)
            model.load_state_dict(checkpoint["model_state"], strict=False)
            if not model_only:
                optimizer.load_state_dict(checkpoint["optimizer_state"])
                scheduler.load_state_dict(checkpoint["scheduler_state"])
                start_iter = checkpoint["epoch"]
            logger.info(
                "Loaded checkpoint '{}' (iter {})".format(
                    cfg["training"]["resume"], checkpoint["epoch"]
                )
            )
        else:
            logger.info("No checkpoint found at '{}'".format(cfg["training"]["resume"]))

    if cfg["training"]["finetune"] is not None:
        if os.path.isfile(cfg["training"]["finetune"]):
            logger.info(
                "Loading model and optimizer from checkpoint '{}'".format(cfg["training"]["finetune"])
            )
            checkpoint = torch.load(cfg["training"]["finetune"])
            model.load_state_dict(checkpoint["model_state"])

    val_loss_meter = averageMeter()
    time_meter = averageMeter()

    best_iou = -100.0
    curr_performance = -100.0  # guards the save_metric comparison below
    i = start_iter
    flag = True
    loss_all = 0
    loss_n = 0

    while i <= cfg["training"]["train_iters"] and flag:
        for (images, label_dict, _) in trainloader:
            i += 1
            start_ts = time.time()
            scheduler.step()
            model.train()

            images = images.to(device)
            optimizer.zero_grad()
            output_dict = model(images)

            loss = compute_loss(    # considers key names in loss_dict and output_dict
                loss_dict, images, label_dict, output_dict, device, t_loader
            )
            
            loss.backward()         # backprops sum of loss tensors, frozen components will have no grad_fn
            optimizer.step()
            c_lr = scheduler.get_lr()

            if i % 1000 == 0:  # log images, seg ground truths, predictions
                pred_array = output_dict["seg"].data.max(1)[1].cpu().numpy()
                gt_array = label_dict["seg"].data.cpu().numpy()
                softmax_gt_array = None
                if "softmax" in label_dict:
                    softmax_gt_array = label_dict["softmax"].data.max(1)[1].cpu().numpy()
                write_images_to_board(t_loader, images, gt_array, pred_array, i, name = 'train', softmax_gt = softmax_gt_array)

                if save_dir is not None:
                    image_array = images.data.cpu().numpy().transpose(0, 2, 3, 1)
                    write_images_to_dir(t_loader, image_array, gt_array, pred_array, i, save_dir, name = 'train', softmax_gt = softmax_gt_array)

            time_meter.update(time.time() - start_ts)
            loss_all += loss.item()
            loss_n += 1
            
            if (i + 1) % cfg["training"]["print_interval"] == 0:
                fmt_str = "Iter [{:d}/{:d}]  Loss: {:.4f}  Time/Image: {:.4f}  lr={:.6f}"
                print_str = fmt_str.format(
                    i + 1,
                    cfg["training"]["train_iters"],
                    loss_all / loss_n,
                    time_meter.avg / cfg["training"]["batch_size"],
                    c_lr[0],
                )

                print(print_str)
                logger.info(print_str)
                writer.add_scalar("loss/train_loss", loss.item(), i + 1)
                time_meter.reset()

            if (i + 1) % cfg["training"]["val_interval"] == 0 or (i + 1) == cfg["training"][
                "train_iters"
            ]:
                torch.cuda.empty_cache()
                model.eval() # set batchnorm and dropouts to work in eval mode
                loss_all = 0
                loss_n = 0
                with torch.no_grad():  # deactivate autograd to reduce memory usage
                    for i_val, (images_val, label_dict_val, _) in tqdm(enumerate(valloader)):
                        
                        images_val = images_val.to(device)
                        output_dict = model(images_val)
                        
                        val_loss = compute_loss(
                            loss_dict, images_val, label_dict_val, output_dict, device, v_loader
                        )
                        val_loss_meter.update(val_loss.item())

                        for name, metrics in running_metrics_val.items():
                            gt_array = label_dict_val[name].data.cpu().numpy()
                            if name+'_loss' in cfg['training'] and cfg['training'][name+'_loss']['name'] == 'l1':  # for binary classification
                                pred_array = output_dict[name].data.cpu().numpy()
                                pred_array = np.sign(pred_array)
                                pred_array[pred_array == -1] = 0
                                gt_array[gt_array == -1] = 0
                            else:
                                pred_array = output_dict[name].data.max(1)[1].cpu().numpy()

                            metrics.update(gt_array, pred_array)

                softmax_gt_array = None # log validation images
                pred_array = output_dict["seg"].data.max(1)[1].cpu().numpy()
                gt_array = label_dict_val["seg"].data.cpu().numpy()
                if "softmax" in label_dict_val:
                    softmax_gt_array = label_dict_val["softmax"].data.max(1)[1].cpu().numpy()
                write_images_to_board(v_loader, images_val, gt_array, pred_array, i, 'validation', softmax_gt = softmax_gt_array)
                if save_dir is not None:
                    images_val = images_val.cpu().numpy().transpose(0, 2, 3, 1)
                    write_images_to_dir(v_loader, images_val, gt_array, pred_array, i, save_dir, name='validation', softmax_gt = softmax_gt_array)

                logger.info("Iter %d Val Loss: %.4f" % (i + 1, val_loss_meter.avg))
                writer.add_scalar("loss/val_loss", val_loss_meter.avg, i + 1)

                for name, metrics in running_metrics_val.items():
                    
                    overall, classwise = metrics.get_scores()
                    
                    for k, v in overall.items():
                        logger.info("{}_{}: {}".format(name, k, v))
                        writer.add_scalar("val_metrics/{}_{}".format(name, k), v, i + 1)

                        if k == cfg["training"]["save_metric"]:
                            curr_performance = v

                    for metric_name, metric in classwise.items():
                        for k, v in metric.items():
                            logger.info("{}_{}_{}: {}".format(name, metric_name, k, v))
                            writer.add_scalar("val_metrics/{}_{}_{}".format(name, metric_name, k), v, i + 1)

                    metrics.reset()
                
                state = {
                      "epoch": i + 1,
                      "model_state": model.state_dict(),
                      "optimizer_state": optimizer.state_dict(),
                      "scheduler_state": scheduler.state_dict(),
                }
                save_path = os.path.join(
                    writer.file_writer.get_logdir(),
                    "{}_{}_checkpoint.pkl".format(cfg["model"]["arch"], cfg["data"]["dataset"]),
                )
                torch.save(state, save_path)

                if curr_performance >= best_iou:
                    best_iou = curr_performance
                    state = {
                        "epoch": i + 1,
                        "model_state": model.state_dict(),
                        "best_iou": best_iou,
                    }
                    save_path = os.path.join(
                        writer.file_writer.get_logdir(),
                        "{}_{}_best_model.pkl".format(cfg["model"]["arch"], cfg["data"]["dataset"]),
                    )
                    torch.save(state, save_path)
                torch.cuda.empty_cache()

            if (i + 1) == cfg["training"]["train_iters"]:
                flag = False
                break
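The `runningScoreSeg` / `runningScoreClassifier` metrics above accumulate a confusion matrix across validation batches and derive overall and per-class scores from it. A rough sketch of that pattern, matching the `update(gt, pred)` / `get_scores()` / `reset()` calls in this example (the real classes use different score keys, e.g. the "Mean IoU : \t" key seen in the other examples):

import numpy as np

class runningScoreSeg(object):
    def __init__(self, n_classes):
        self.n_classes = n_classes
        self.confusion_matrix = np.zeros((n_classes, n_classes))

    def update(self, gt, pred):
        # Accumulate a confusion matrix over the valid pixels of one batch.
        mask = (gt >= 0) & (gt < self.n_classes)
        hist = np.bincount(
            self.n_classes * gt[mask].astype(int) + pred[mask],
            minlength=self.n_classes ** 2,
        ).reshape(self.n_classes, self.n_classes)
        self.confusion_matrix += hist

    def get_scores(self):
        hist = self.confusion_matrix
        iou = np.diag(hist) / (hist.sum(axis=1) + hist.sum(axis=0) - np.diag(hist))
        overall = {"Overall Acc": np.diag(hist).sum() / hist.sum(),
                   "Mean IoU": np.nanmean(iou)}
        classwise = {"IoU": dict(enumerate(iou))}
        return overall, classwise

    def reset(self):
        self.confusion_matrix = np.zeros((self.n_classes, self.n_classes))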
Example #15
def train(cfg, writer, logger):
    # Setup seeds
    init_seed(11733, en_cudnn=False)

    # Setup Augmentations
    train_augmentations = cfg["training"].get("train_augmentations", None)
    t_data_aug = get_composed_augmentations(train_augmentations)
    val_augmentations = cfg["validating"].get("val_augmentations", None)
    v_data_aug = get_composed_augmentations(val_augmentations)

    # Setup Dataloader
    data_loader = get_loader(cfg["data"]["dataset"])

    t_loader = data_loader(cfg=cfg["data"],
                           mode='train',
                           augmentations=t_data_aug)
    v_loader = data_loader(cfg=cfg["data"],
                           mode='val',
                           augmentations=v_data_aug)

    trainloader = data.DataLoader(t_loader,
                                  batch_size=cfg["training"]["batch_size"],
                                  num_workers=cfg["training"]["n_workers"],
                                  shuffle=True,
                                  drop_last=True)
    valloader = data.DataLoader(v_loader,
                                batch_size=cfg["validating"]["batch_size"],
                                num_workers=cfg["validating"]["n_workers"])

    logger.info("Using training seting {}".format(cfg["training"]))

    # Setup Metrics
    running_metrics_val = runningScore(t_loader.n_classes,
                                       t_loader.unseen_classes)

    # Leftover offline analysis: restore the confusion matrix from a previous
    # best checkpoint and recompute its scores (unused by the training below).
    model_state = torch.load(
        './runs/deeplabv3p_ade_25unseen/84253/deeplabv3p_ade20k_best_model.pkl'
    )
    running_metrics_val.confusion_matrix = model_state['results']
    score, a_iou = running_metrics_val.get_scores()

    # Setup Model and Loss
    loss_fn = get_loss_function(cfg["training"])
    logger.info("Using loss {}".format(loss_fn))
    model = get_model(cfg["model"], t_loader.n_classes, loss_fn=loss_fn)

    # Setup optimizer
    optimizer = get_optimizer(cfg["training"], model)

    # Initialize training param
    start_iter = 0
    best_iou = -100.0

    # Resume from checkpoint
    if cfg["training"]["resume"] is not None:
        if os.path.isfile(cfg["training"]["resume"]):
            logger.info("Resuming training from checkpoint '{}'".format(
                cfg["training"]["resume"]))
            model_state = torch.load(cfg["training"]["resume"])["model_state"]
            model.load_state_dict(model_state)
        else:
            logger.info("No checkpoint found at '{}'".format(
                cfg["training"]["resume"]))

    # Setup Multi-GPU
    if torch.cuda.is_available():
        model = model.cuda()  # DataParallelModel(model).cuda()
        logger.info("Model initialized on GPUs.")

    time_meter = averageMeter()
    i = start_iter

    embds = t_loader.embeddings.cuda()
    ignr_idx = t_loader.ignore_index
    while i <= cfg["training"]["train_iters"]:
        for (images, labels) in trainloader:
            images = images.cuda()
            labels = labels.cuda()

            i += 1
            model.train()
            optimizer.zero_grad()

            start_ts = time.time()
            loss_sum = model(images, labels, embds, ignr_idx)
            if loss_sum == 0:  # skip samples that contain unseen categories
                continue  # for non-transductive learning, set transductive=0 in the config

            loss_sum.backward()

            time_meter.update(time.time() - start_ts)

            optimizer.step()

            if (i + 1) % cfg["training"]["print_interval"] == 0:
                fmt_str = "Iter [{:d}/{:d}]  Loss: {:.4f}  Time/Image: {:.4f}"
                print_str = fmt_str.format(
                    i + 1,
                    cfg["training"]["train_iters"],
                    loss_sum.item(),
                    time_meter.avg / cfg["training"]["batch_size"],
                )

                print(print_str)
                logger.info(print_str)
                writer.add_scalar("loss/train_loss", loss_sum.item(), i + 1)
                time_meter.reset()

            if (i + 1) % cfg["training"]["val_interval"] == 0 or (
                    i + 1) == cfg["training"]["train_iters"]:
                model.eval()
                with torch.no_grad():
                    for i_val, (images_val,
                                labels_val) in tqdm(enumerate(valloader)):
                        images_val = images_val.cuda()
                        labels_val = labels_val.cuda()
                        outputs = model(images_val, labels_val, embds,
                                        ignr_idx)
                        # outputs = gather(outputs, 0, dim=0)

                        running_metrics_val.update(outputs)

                score, a_iou = running_metrics_val.get_scores()
                for k, v in score.items():
                    print("{}: {}".format(k, v))
                    logger.info("{}: {}".format(k, v))
                    #writer.add_scalar("val_metrics/{}".format(k), v, i + 1)

                #for k, v in class_iou.items():
                #    logger.info("{}: {}".format(k, v))
                #    writer.add_scalar("val_metrics/cls_{}".format(k), v, i + 1)

                if a_iou >= best_iou:
                    best_iou = a_iou
                    state = {
                        "epoch": i + 1,
                        "model_state": model.state_dict(),
                        "best_iou": best_iou,
                        "results": running_metrics_val.confusion_matrix
                    }
                    save_path = os.path.join(
                        writer.file_writer.get_logdir(),
                        "{}_{}_best_model.pkl".format(cfg["model"]["arch"],
                                                      cfg["data"]["dataset"]),
                    )
                    torch.save(state, save_path)

                running_metrics_val.reset()
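Example #15 seeds its RNGs through a single `init_seed(seed, en_cudnn=False)` call instead of the inline block the other examples use. A plausible sketch, assuming the `en_cudnn` flag simply toggles cuDNN (the upstream definition is not shown here):

import random
import numpy as np
import torch

def init_seed(seed, en_cudnn=False):
    # Assumed behaviour: seed every RNG and optionally disable cuDNN,
    # trading speed for reproducibility.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.enabled = en_cudnn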
Example #16
def train(cfg, writer, logger):
    # Setup seeds
    torch.manual_seed(cfg.get("seed", 1337))
    torch.cuda.manual_seed(cfg.get("seed", 1337))
    np.random.seed(cfg.get("seed", 1337))
    random.seed(cfg.get("seed", 1337))

    # Setup device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(device)

    # Setup Augmentations
    augmentations = cfg["training"].get("augmentations", None)
    data_aug = get_composed_augmentations(augmentations)

    # Setup Dataloader
    data_loader = cityscapesLoader
    data_path = cfg["data"]["path"]

    t_loader = data_loader(
        data_path,
        is_transform=True,
        split=cfg["data"]["train_split"],
        img_size=(cfg["data"]["img_rows"], cfg["data"]["img_cols"]),
        augmentations=data_aug,
    )

    v_loader = data_loader(
        data_path,
        is_transform=True,
        split=cfg["data"]["val_split"],
        img_size=(cfg["data"]["img_rows"], cfg["data"]["img_cols"]),
    )

    n_classes = t_loader.n_classes
    trainloader = data.DataLoader(
        t_loader,
        batch_size=cfg["training"]["batch_size"],
        num_workers=cfg["training"]["n_workers"],
        shuffle=True,
    )

    valloader = data.DataLoader(v_loader,
                                batch_size=cfg["training"]["batch_size"],
                                num_workers=cfg["training"]["n_workers"])

    # Setup Metrics
    running_metrics_val = runningScore(n_classes)

    # Setup Model
    model = get_model(cfg["model"], n_classes).to(device)

    model = torch.nn.DataParallel(model,
                                  device_ids=range(torch.cuda.device_count()))

    # Setup optimizer, lr_scheduler and loss function
    optimizer_cls = get_optimizer(cfg)
    optimizer_params = {
        k: v
        for k, v in cfg["training"]["optimizer"].items() if k != "name"
    }

    optimizer = optimizer_cls(model.parameters(), **optimizer_params)
    logger.info("Using optimizer {}".format(optimizer))

    scheduler = get_scheduler(optimizer, cfg["training"]["lr_schedule"])

    loss_fn = get_loss_function(cfg)
    logger.info("Using loss {}".format(loss_fn))

    start_iter = 0
    if cfg["training"]["resume"] is not None:
        if os.path.isfile(cfg["training"]["resume"]):
            logger.info(
                "Loading model and optimizer from checkpoint '{}'".format(
                    cfg["training"]["resume"]))
            checkpoint = torch.load(cfg["training"]["resume"])
            model.load_state_dict(checkpoint["model_state"])
            optimizer.load_state_dict(checkpoint["optimizer_state"])
            scheduler.load_state_dict(checkpoint["scheduler_state"])
            start_iter = checkpoint["epoch"]
            logger.info("Loaded checkpoint '{}' (iter {})".format(
                cfg["training"]["resume"], checkpoint["epoch"]))
        else:
            logger.info("No checkpoint found at '{}'".format(
                cfg["training"]["resume"]))

    val_loss_meter = averageMeter()
    # Separate loss_seg / loss_dep meters could be added here if the
    # segmentation and depth losses are ever logged individually.
    time_meter = averageMeter()
    acc_result_total = averageMeter()
    acc_result_correct = averageMeter()

    best_iou = -100.0
    i = start_iter
    flag = True

    while i <= cfg["training"]["train_iters"] and flag:
        for (images, masks, depths) in trainloader:
            i += 1
            start_ts = time.time()
            scheduler.step()
            model.train()
            images = images.to(device)
            depths = depths.to(device)

            # print(images.shape)
            optimizer.zero_grad()
            outputs = model(images).squeeze(1)

            # -----------------------------------------------------------------
            # Depth loss
            #
            # MSE alternative:
            # loss_dep = F.mse_loss(input=outputs[:, -1, :, :], target=depths, reduction='mean')
            #
            # BerHu loss (used here): mask out invalid pixels, then average
            loss = berhu_loss_function(prediction=outputs, target=depths)
            masks = masks.type(torch.cuda.ByteTensor)
            loss = torch.sum(loss[masks]) / torch.sum(masks)

            # -----------------------------------------------------------------

            loss.backward()
            optimizer.step()

            time_meter.update(time.time() - start_ts)

            if (i + 1) % cfg["training"]["print_interval"] == 0:
                fmt_str = "Iter [{:d}/{:d}] loss_dep: {:.4f} Time/Image: {:.4f}"
                print_str = fmt_str.format(
                    i + 1, cfg["training"]["train_iters"], loss.item(),
                    time_meter.avg / cfg["training"]["batch_size"])

                print(print_str)
                logger.info(print_str)
                writer.add_scalar("loss/train_loss", loss.item(), i + 1)
                time_meter.reset()

            if (i + 1) % cfg["training"]["val_interval"] == 0 or (
                    i + 1) == cfg["training"]["train_iters"]:

                model.eval()
                with torch.no_grad():
                    for i_val, (images_val, masks_val,
                                depths_val) in enumerate(valloader):
                        images_val = images_val.to(device)

                        # add depth to device
                        depths_val = depths_val.to(device)

                        outputs = model(images_val).squeeze(1)
                        # depths_val = depths_val.data.resize_(depths_val.size(0), outputs.size(2), outputs.size(3))

                        # -----------------------------------------------------------------
                        # berhu loss function
                        val_loss = berhu_loss_function(prediction=outputs,
                                                       target=depths_val)
                        masks_val = masks_val.type(torch.cuda.ByteTensor)
                        # Do not cast val_loss to ByteTensor: that would truncate
                        # the per-pixel float losses to integers before averaging.
                        val_loss = torch.sum(
                            val_loss[masks_val]) / torch.sum(masks_val)

                        # -----------------------------------------------------------------
                        # Update

                        val_loss_meter.update(val_loss.item())

                        outputs = outputs.cpu().numpy()
                        depths_val = depths_val.cpu().numpy()
                        masks_val = masks_val.cpu().numpy()

                        # depths_val = depths_val.type(torch.cuda.FloatTensor)
                        # outputs = outputs.type(torch.cuda.FloatTensor)

                        # -----------------------------------------------------------------
                        # delta < 1.25 accuracy. Pixels with zero ground-truth
                        # depth yield inf/nan ratios ("invalid value encountered
                        # in double_scalars", cf. pytorch-semseg issue #118);
                        # they are excluded below via masks_val.
                        acc_1 = outputs / depths_val
                        acc_2 = 1 / acc_1
                        acc_threshold = np.maximum(acc_1, acc_2)

                        acc_result_total.update(np.sum(masks_val))
                        acc_result_correct.update(
                            np.sum(
                                np.logical_and(acc_threshold < 1.25,
                                               masks_val)))

                print("Iter {:d}, val_loss {:.4f}".format(
                    i + 1, val_loss_meter.avg))
                writer.add_scalar("loss/val_loss", val_loss_meter.avg, i + 1)
                logger.info("Iter %d Loss: %.4f" % (i + 1, val_loss_meter.avg))

                acc_result = float(acc_result_correct.sum) / float(
                    acc_result_total.sum)
                print("Iter {:d}, acc_1.25 {:.4f}".format(i + 1, acc_result))
                logger.info("Iter %d acc_1.25: %.4f" % (i + 1, acc_result))

                # -----------------------------------------------------------------
                # NOTE: running_metrics_val is never updated in this depth-only
                # loop, so the segmentation scores below are placeholders.
                score, class_iou = running_metrics_val.get_scores()
                for k, v in score.items():
                    print(k, v)
                    logger.info("{}: {}".format(k, v))
                    writer.add_scalar("val_metrics/{}".format(k), v, i + 1)

                for k, v in class_iou.items():
                    logger.info("{}: {}".format(k, v))
                    writer.add_scalar("val_metrics/cls_{}".format(k), v, i + 1)

                val_loss_meter.reset()
                acc_result_total.reset()
                acc_result_correct.reset()

                running_metrics_val.reset()

                if score["Mean IoU : \t"] >= best_iou:
                    best_iou = score["Mean IoU : \t"]
                    state = {
                        "epoch": i + 1,
                        "model_state": model.state_dict(),
                        "optimizer_state": optimizer.state_dict(),
                        "scheduler_state": scheduler.state_dict(),
                        "best_iou": best_iou,
                    }
                    save_path = os.path.join(
                        writer.file_writer.get_logdir(),
                        "{}_{}_best_model.pkl".format(cfg["model"]["arch"],
                                                      cfg["data"]["dataset"]),
                    )
                    torch.save(state, save_path)

                    # insert print function to see if the losses are correct

            if (i + 1) == cfg["training"]["train_iters"]:
                flag = False
                break
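The depth loop in Example #16 expects `berhu_loss_function` to return per-pixel losses, which the caller then masks and averages. A sketch of the BerHu (reverse Huber) loss with the common threshold c = 0.2 · max|residual| from Laina et al. (2016); the repository's exact definition may differ in threshold and reduction:

import torch

def berhu_loss_function(prediction, target):
    # Elementwise BerHu loss, no reduction: L1 below the threshold c,
    # scaled L2 above it. The caller masks invalid pixels and averages.
    residual = (prediction - target).abs()
    c = 0.2 * residual.max().clamp(min=1e-6)  # guard against c == 0
    l2_branch = (residual ** 2 + c ** 2) / (2 * c)
    return torch.where(residual <= c, residual, l2_branch)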
Example #17
def train(cfg, writer, logger):

    # Setup seeds
    torch.manual_seed(cfg.get('seed', 1337))
    torch.cuda.manual_seed(cfg.get('seed', 1337))
    np.random.seed(cfg.get('seed', 1337))
    random.seed(cfg.get('seed', 1337))

    # Setup device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Setup Augmentations
    augmentations = cfg['training'].get('augmentations', None)
    data_aug = get_composed_augmentations(augmentations)

    # Setup Dataloader
    data_loader = get_loader(cfg['data']['dataset'])
    data_path = cfg['data']['path']

    if 'fold' not in cfg['data']:
        cfg['data']['fold'] = None

    t_loader = data_loader(
        data_path,
        is_transform=True,
        split=cfg['data']['train_split'],
        img_size=[cfg['data']['img_rows'], cfg['data']['img_cols']],
        augmentations=data_aug,
        fold=cfg['data']['fold'],
        n_classes=cfg['data']['n_classes'])

    v_loader = data_loader(
        data_path,
        is_transform=True,
        split=cfg['data']['val_split'],
        img_size=[cfg['data']['img_rows'], cfg['data']['img_cols']],
        fold=cfg['data']['fold'],
        n_classes=cfg['data']['n_classes'])

    n_classes = t_loader.n_classes
    trainloader = data.DataLoader(t_loader,
                                  batch_size=cfg['training']['batch_size'],
                                  num_workers=cfg['training']['n_workers'],
                                  shuffle=True)

    valloader = data.DataLoader(v_loader,
                                batch_size=1,
                                num_workers=cfg['training']['n_workers'])

    logger.info("Training on fold {}".format(cfg['data']['fold']))
    # Setup Metrics
    running_metrics_val = runningScore(n_classes)

    # Setup Model
    model = get_model(cfg['model'], n_classes).to(device)
    if args.model_path != "fcn8s_pascal_1_26.pkl":  # Default Value
        state = convert_state_dict(torch.load(args.model_path)["model_state"])
        if cfg['model']['use_scale']:
            model = load_my_state_dict(model, state)
            model.freeze_weights_extractor()
        else:
            model.load_state_dict(state)
            model.freeze_weights_extractor()

    model = torch.nn.DataParallel(model,
                                  device_ids=range(torch.cuda.device_count()))

    # Setup optimizer, lr_scheduler and loss function
    optimizer_cls = get_optimizer(cfg)
    optimizer_params = {
        k: v
        for k, v in cfg['training']['optimizer'].items() if k != 'name'
    }

    optimizer = optimizer_cls(model.parameters(), **optimizer_params)
    logger.info("Using optimizer {}".format(optimizer))

    scheduler = get_scheduler(optimizer, cfg['training']['lr_schedule'])

    loss_fn = get_loss_function(cfg)
    logger.info("Using loss {}".format(loss_fn))

    start_iter = 0
    if cfg['training']['resume'] is not None:
        if os.path.isfile(cfg['training']['resume']):
            logger.info(
                "Loading model and optimizer from checkpoint '{}'".format(
                    cfg['training']['resume']))
            checkpoint = torch.load(cfg['training']['resume'])
            model.load_state_dict(checkpoint["model_state"])
            optimizer.load_state_dict(checkpoint["optimizer_state"])
            scheduler.load_state_dict(checkpoint["scheduler_state"])
            start_iter = checkpoint["epoch"]
            logger.info("Loaded checkpoint '{}' (iter {})".format(
                cfg['training']['resume'], checkpoint["epoch"]))
        else:
            logger.info("No checkpoint found at '{}'".format(
                cfg['training']['resume']))

    val_loss_meter = averageMeter()
    time_meter = averageMeter()

    best_iou = -100.0
    i = start_iter
    flag = True

    torch.cuda.synchronize()

    logger.info("Start training from here!!!!!")
    while i <= cfg['training']['train_iters'] and flag:
        for (images, labels) in trainloader:

            i += 1
            start_ts = time.time()
            scheduler.step()
            model.train()
            images = images.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            outputs = model(images)
            loss = loss_fn(input=outputs, target=labels)

            loss.backward()
            optimizer.step()

            time_meter.update(time.time() - start_ts)

            if (i + 1) % cfg['training']['print_interval'] == 0:
                fmt_str = "Iter [{:d}/{:d}]  Loss: {:.4f}  Time/Image: {:.4f}"
                print_str = fmt_str.format(
                    i + 1, cfg['training']['train_iters'], loss.item(),
                    time_meter.avg / cfg['training']['batch_size'])

                print(print_str)
                logger.info(print_str)
                writer.add_scalar('loss/train_loss', loss.item(), i + 1)
                time_meter.reset()

            logger.info("Start validation from here!!!!!")
            if (i + 1) % cfg['training']['val_interval'] == 0 or \
               (i + 1) == cfg['training']['train_iters']:
                model.eval()
                with torch.no_grad():
                    for i_val, (images_val,
                                labels_val) in tqdm(enumerate(valloader)):
                        images_val = images_val.to(device)
                        labels_val = labels_val.to(device)

                        outputs = model(images_val)
                        val_loss = loss_fn(input=outputs, target=labels_val)

                        pred = outputs.data.max(1)[1].cpu().numpy()
                        gt = labels_val.data.cpu().numpy()

                        running_metrics_val.update(gt, pred)
                        val_loss_meter.update(val_loss.item())

                writer.add_scalar('loss/val_loss', val_loss_meter.avg, i + 1)
                logger.info("Iter %d Loss: %.4f" % (i + 1, val_loss_meter.avg))

                score, class_iou = running_metrics_val.get_scores()
                for k, v in score.items():
                    print(k, v)
                    logger.info('{}: {}'.format(k, v))
                    writer.add_scalar('val_metrics/{}'.format(k), v, i + 1)

                for k, v in class_iou.items():
                    logger.info('{}: {}'.format(k, v))
                    writer.add_scalar('val_metrics/cls_{}'.format(k), v, i + 1)

                val_loss_meter.reset()
                running_metrics_val.reset()

                if score["Mean IoU : \t"] >= best_iou:
                    best_iou = score["Mean IoU : \t"]
                    state = {
                        "epoch": i + 1,
                        "model_state": model.state_dict(),
                        "optimizer_state": optimizer.state_dict(),
                        "scheduler_state": scheduler.state_dict(),
                        "best_iou": best_iou,
                    }
                    save_path = os.path.join(
                        writer.file_writer.get_logdir(),
                        "{}_{}_best_model.pkl".format(cfg['model']['arch'],
                                                      cfg['data']['dataset']))
                    torch.save(state, save_path)

            if (i + 1) == cfg['training']['train_iters']:
                flag = False
                break
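These examples lean on a small averageMeter helper for their loss and timing statistics. A minimal sketch, assuming it exposes the usual reset/update/avg interface of pytorch-semseg-style codebases:

class averageMeter(object):
    """Tracks a running average (a sketch of the meter interface used above)."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        # val: latest measurement; n: number of samples it covers
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count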
Example #18
def train(cfg, writer, logger):

    # Setup seeds
    torch.manual_seed(cfg.get("seed", 1337))
    torch.cuda.manual_seed(cfg.get("seed", 1337))
    np.random.seed(cfg.get("seed", 1337))
    random.seed(cfg.get("seed", 1337))

    # Setup device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(device)

    # Setup Augmentations
    augmentations = cfg["training"].get("augmentations", None)
    data_aug = get_composed_augmentations(augmentations)

    # Setup Dataloader
    data_loader = get_loader(cfg["data"]["dataset"])
    data_path = cfg["data"]["path"]

    t_loader = data_loader(
        data_path,
        is_transform=True,
        split=cfg["data"]["train_split"],
        img_size=(cfg["data"]["img_rows"], cfg["data"]["img_cols"]),
        augmentations=data_aug,
        n_classes=20,
    )

    v_loader = data_loader(
        data_path,
        is_transform=True,
        split=cfg["data"]["val_split"],
        img_size=(cfg["data"]["img_rows"], cfg["data"]["img_cols"]),
    )

    n_classes = t_loader.n_classes
    trainloader = data.DataLoader(
        t_loader,
        batch_size=cfg["training"]["batch_size"],
        num_workers=cfg["training"]["n_workers"],
        shuffle=True,
    )

    valloader = data.DataLoader(v_loader,
                                batch_size=cfg["training"]["batch_size"],
                                num_workers=cfg["training"]["n_workers"])
    # -----------------------------------------------------------------
    # Setup Metrics (subtract one class: the last output channel is the depth head)
    running_metrics_val = runningScore(n_classes - 1)

    # Setup Model
    model = get_model(cfg["model"], n_classes).to(device)

    model = torch.nn.DataParallel(model,
                                  device_ids=range(torch.cuda.device_count()))

    # Setup optimizer, lr_scheduler and loss function
    optimizer_cls = get_optimizer(cfg)
    optimizer_params = {
        k: v
        for k, v in cfg["training"]["optimizer"].items() if k != "name"
    }

    optimizer = optimizer_cls(model.parameters(), **optimizer_params)
    logger.info("Using optimizer {}".format(optimizer))

    scheduler = get_scheduler(optimizer, cfg["training"]["lr_schedule"])

    loss_fn = get_loss_function(cfg)
    logger.info("Using loss {}".format(loss_fn))

    start_iter = 0
    if cfg["training"]["resume"] is not None:
        if os.path.isfile(cfg["training"]["resume"]):
            logger.info(
                "Loading model and optimizer from checkpoint '{}'".format(
                    cfg["training"]["resume"]))
            checkpoint = torch.load(cfg["training"]["resume"])
            model.load_state_dict(checkpoint["model_state"])
            optimizer.load_state_dict(checkpoint["optimizer_state"])
            scheduler.load_state_dict(checkpoint["scheduler_state"])
            start_iter = checkpoint["epoch"]
            logger.info("Loaded checkpoint '{}' (iter {})".format(
                cfg["training"]["resume"], checkpoint["epoch"]))
        else:
            logger.info("No checkpoint found at '{}'".format(
                cfg["training"]["resume"]))

    val_loss_meter = averageMeter()

    # meters for the segmentation and depth losses
    loss_seg_meter = averageMeter()
    loss_dep_meter = averageMeter()
    time_meter = averageMeter()

    best_iou = -100.0
    i = start_iter
    flag = True

    while i <= cfg["training"]["train_iters"] and flag:
        for (images, labels, masks, depths) in trainloader:
            i += 1
            start_ts = time.time()
            model.train()
            images = images.to(device)
            labels = labels.to(device)
            depths = depths.to(device)

            #print(images.shape)
            optimizer.zero_grad()
            outputs = model(images)
            #print('depths size: ', depths.size())
            #print('output shape: ', outputs.shape)

            loss_seg = loss_fn(input=outputs[:, :-1, :, :], target=labels)

            # -----------------------------------------------------------------
            # add depth loss

            # -----------------------------------------------------------------
            # MSE loss
            # loss_dep = F.mse_loss(input=outputs[:, -1,:,:], target=depths, reduction='mean')

            # -----------------------------------------------------------------
            # Berhu loss
            loss_dep = berhu_loss_function(prediction=outputs[:, -1, :, :],
                                           target=depths)
            masks = masks.to(device=loss_dep.device, dtype=torch.bool)  # valid depth pixels
            loss_dep = loss_dep[masks].sum() / masks.sum().clamp(min=1)
            loss = loss_dep + loss_seg
            # -----------------------------------------------------------------

            loss.backward()
            optimizer.step()
            scheduler.step()  # in PyTorch >= 1.1, step the scheduler after the optimizer

            time_meter.update(time.time() - start_ts)

            if (i + 1) % cfg["training"]["print_interval"] == 0:
                fmt_str = "Iter [{:d}/{:d}]  loss_seg: {:.4f}  loss_dep: {:.4f}  overall loss: {:.4f}   Time/Image: {:.4f}"
                print_str = fmt_str.format(
                    i + 1, cfg["training"]["train_iters"], loss_seg.item(),
                    loss_dep.item(), loss.item(),
                    time_meter.avg / cfg["training"]["batch_size"])

                print(print_str)
                logger.info(print_str)
                writer.add_scalar("loss/train_loss", loss.item(), i + 1)
                time_meter.reset()

            if (i + 1) % cfg["training"]["val_interval"] == 0 or (
                    i + 1) == cfg["training"]["train_iters"]:

                model.eval()
                with torch.no_grad():
                    for i_val, (images_val, labels_val, masks_val,
                                depths_val) in tqdm(enumerate(valloader)):
                        images_val = images_val.to(device)
                        labels_val = labels_val.to(device)
                        print('images_val shape', images_val.size())
                        # add depth to device
                        depths_val = depths_val.to(device)

                        outputs = model(images_val)
                        #depths_val = depths_val.data.resize_(depths_val.size(0), outputs.size(2), outputs.size(3))

                        # -----------------------------------------------------------------
                        # loss function for segmentation
                        print('output shape', outputs.size())
                        val_loss_seg = loss_fn(input=outputs[:, :-1, :, :],
                                               target=labels_val)

                        # -----------------------------------------------------------------
                        # MSE loss
                        # val_loss_dep = F.mse_loss(input=outputs[:, -1, :, :], target=depths_val, reduction='mean')

                        # -----------------------------------------------------------------
                        # berhu loss function
                        val_loss_dep = berhu_loss_function(
                            prediction=outputs[:, -1, :, :], target=depths_val)
                        masks_val = masks_val.to(device=val_loss_dep.device,
                                                 dtype=torch.bool)
                        val_loss_dep = val_loss_dep[masks_val].sum() \
                            / masks_val.sum().clamp(min=1)
                        val_loss = val_loss_dep + val_loss_seg
                        # -----------------------------------------------------------------

                        prediction = outputs[:, :-1, :, :]
                        prediction = prediction.data.max(1)[1].cpu().numpy()
                        gt = labels_val.data.cpu().numpy()

                        # adapt metrics to seg and dep
                        running_metrics_val.update(gt, prediction)
                        loss_seg_meter.update(val_loss_seg.item())
                        loss_dep_meter.update(val_loss_dep.item())

                        # -----------------------------------------------------------------
                        # get rid of val_loss_meter
                        # val_loss_meter.update(val_loss.item())
                        # writer.add_scalar("loss/val_loss", val_loss_meter.avg, i + 1)
                        # logger.info("Iter %d Loss: %.4f" % (i + 1, val_loss_meter.avg))
                        # -----------------------------------------------------------------

                score, class_iou = running_metrics_val.get_scores()
                for k, v in score.items():
                    print(k, v)
                    logger.info("{}: {}".format(k, v))
                    writer.add_scalar("val_metrics/{}".format(k), v, i + 1)

                for k, v in class_iou.items():
                    logger.info("{}: {}".format(k, v))
                    writer.add_scalar("val_metrics/cls_{}".format(k), v, i + 1)

                print("Segmentation loss is {}".format(loss_seg_meter.avg))
                logger.info("Segmentation loss is {}".format(
                    loss_seg_meter.avg))
                #writer.add_scalar("Segmentation loss is {}".format(loss_seg_meter.avg), i + 1)

                print("Depth loss is {}".format(loss_dep_meter.avg))
                logger.info("Depth loss is {}".format(loss_dep_meter.avg))
                #writer.add_scalar("Depth loss is {}".format(loss_dep_meter.avg), i + 1)

                val_loss_meter.reset()
                loss_seg_meter.reset()
                loss_dep_meter.reset()
                running_metrics_val.reset()

                if score["Mean IoU : \t"] >= best_iou:
                    best_iou = score["Mean IoU : \t"]
                    state = {
                        "epoch": i + 1,
                        "model_state": model.state_dict(),
                        "optimizer_state": optimizer.state_dict(),
                        "scheduler_state": scheduler.state_dict(),
                        "best_iou": best_iou,
                    }
                    save_path = os.path.join(
                        writer.file_writer.get_logdir(),
                        "{}_{}_best_model.pkl".format(cfg["model"]["arch"],
                                                      cfg["data"]["dataset"]),
                    )
                    torch.save(state, save_path)

                    # insert print function to see if the losses are correct

            if (i + 1) == cfg["training"]["train_iters"]:
                flag = False
                break
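Example #18 above relies on a berhu_loss_function that must return a per-pixel loss tensor, since the caller masks it and averages over the valid pixels. A minimal sketch of the reverse Huber (BerHu) loss under that assumption; the adaptive threshold c = 0.2 * max|error| is the common choice, not something the example itself specifies:

import torch

def berhu_loss_function(prediction, target):
    # Elementwise reverse Huber (BerHu) loss: L1 below the threshold c,
    # scaled L2 above it. Reduction is left to the caller, which masks out
    # invalid pixels and averages over the valid ones.
    diff = torch.abs(prediction - target)
    c = 0.2 * diff.max().clamp(min=1e-6)
    return torch.where(diff <= c, diff, (diff ** 2 + c ** 2) / (2 * c))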
Example #19
def train(cfg, writer, logger):

    # Setup Augmentations
    augmentations = cfg.train.augment
    logger.info(f'using augments: {augmentations}')
    data_aug = get_composed_augmentations(augmentations)

    # Setup Dataloader
    data_loader = get_loader(cfg.data.dataloader)
    data_path = cfg.data.path
    logger.info("data path: {}".format(data_path))

    t_loader = data_loader(
        data_path,
        data_format=cfg.data.format,
        norm=cfg.data.norm,
        split='train',
        split_root=cfg.data.split,
        augments=data_aug,
        logger=logger,
        log=cfg.data.log,
        ENL=cfg.data.ENL,
    )

    v_loader = data_loader(
        data_path,
        data_format=cfg.data.format,
        split='val',
        log=cfg.data.log,
        split_root=cfg.data.split,
        logger=logger,
        ENL=cfg.data.ENL,
    )

    train_data_len = len(t_loader)
    logger.info(
        f'num of train samples: {train_data_len} \nnum of val samples: {len(v_loader)}'
    )

    batch_size = cfg.train.batch_size
    epoch = cfg.train.epoch
    train_iter = int(np.ceil(train_data_len / batch_size) * epoch)
    logger.info(f'total train iter: {train_iter}')

    trainloader = data.DataLoader(t_loader,
                                  batch_size=batch_size,
                                  num_workers=cfg.train.n_workers,
                                  shuffle=True,
                                  persistent_workers=True,
                                  drop_last=True)

    valloader = data.DataLoader(
        v_loader,
        batch_size=cfg.test.batch_size,
        # persis
        num_workers=cfg.train.n_workers,
    )

    # Setup Model
    device = f'cuda:{cfg.train.gpu[0]}'
    model = get_model(cfg.model).to(device)
    input_size = (cfg.model.in_channels, 512, 512)
    logger.info(f"Using Model: {cfg.model.arch}")
    # logger.info(f'model summary: {summary(model, input_size=(input_size, input_size), is_complex=False)}')
    model = torch.nn.DataParallel(model, device_ids=cfg.gpu)  # DataParallel runs on multiple GPUs automatically

    # Setup optimizer, lr_scheduler and loss function
    optimizer_cls = get_optimizer(cfg)
    optimizer_params = {
        k: v
        for k, v in vars(cfg.train.optimizer).items()
        if k not in ('name', 'wrap')
    }
    optimizer = optimizer_cls(model.parameters(), **optimizer_params)
    logger.info("Using optimizer {}".format(optimizer))
    if hasattr(cfg.train.optimizer,
               'wrap') and cfg.train.optimizer.wrap == 'lars':
        optimizer = LARS(optimizer=optimizer)
        logger.info(f'warp optimizer with {cfg.train.optimizer.wrap}')
    scheduler = get_scheduler(optimizer, cfg.train.lr)
    # loss_fn = get_loss_function(cfg)
    # logger.info(f"Using loss ,{str(cfg.train.loss)}")

    # load checkpoints
    val_cls_1_acc = 0
    best_cls_1_acc_now = 0
    best_cls_1_acc_iter_now = 0
    val_macro_OA = 0
    best_macro_OA_now = 0
    best_macro_OA_iter_now = 0
    start_iter = 0
    if cfg.train.resume is not None:
        if os.path.isfile(cfg.train.resume):
            logger.info(
                "Loading model and optimizer from checkpoint '{}'".format(
                    cfg.train.resume))

            # load model state
            checkpoint = torch.load(cfg.train.resume)
            model.load_state_dict(checkpoint["model_state"])
            optimizer.load_state_dict(checkpoint["optimizer_state"])
            scheduler.load_state_dict(checkpoint["scheduler_state"])
            # best_cls_1_acc_now = checkpoint["best_cls_1_acc_now"]
            # best_cls_1_acc_iter_now = checkpoint["best_cls_1_acc_iter_now"]

            start_iter = checkpoint["epoch"]
            logger.info("Loaded checkpoint '{}' (iter {})".format(
                cfg.train.resume, checkpoint["epoch"]))

            # copy tensorboard files
            resume_src_dir = osp.split(cfg.train.resume)[0]
            # shutil.copytree(resume_src_dir, writer.get_logdir())
            for file in os.listdir(resume_src_dir):
                if not ('.log' in file or '.yml' in file
                        or '_last_model' in file):
                    # if 'events.out.tfevents' in file:
                    resume_dst_dir = writer.get_logdir()
                    fu.copy(
                        osp.join(resume_src_dir, file),
                        resume_dst_dir,
                    )

        else:
            logger.info("No checkpoint found at '{}'".format(cfg.train.resume))

    data_range = 255
    if cfg.data.log:
        data_range = np.log(data_range)
    # data_range /= 350

    # Setup Metrics
    running_metrics_val = runningScore(2)
    running_metrics_train = runningScore(2)
    val_loss_meter = averageMeter()
    train_time_meter = averageMeter()
    train_loss_meter = averageMeter()
    val_psnr_meter = averageMeter()
    val_ssim_meter = averageMeter()

    # train
    it = start_iter
    train_start_time = time.time()
    train_val_start_time = time.time()
    model.train()
    while it < train_iter:
        for clean, noisy, _ in trainloader:
            it += 1

            noisy = noisy.to(device, dtype=torch.float32)
            # noisy /= 350
            mask1, mask2 = rand_pool.generate_mask_pair(noisy)
            noisy_sub1 = rand_pool.generate_subimages(noisy, mask1)
            noisy_sub2 = rand_pool.generate_subimages(noisy, mask2)

            # preparing for the regularization term
            with torch.no_grad():
                noisy_denoised = model(noisy)
            noisy_sub1_denoised = rand_pool.generate_subimages(
                noisy_denoised, mask1)
            noisy_sub2_denoised = rand_pool.generate_subimages(
                noisy_denoised, mask2)
            # print(rand_pool.operation_seed_counter)

            # for ii, param in enumerate(model.parameters()):
            #     if torch.sum(torch.isnan(param.data)):
            #         print(f'{ii}: nan parameters')

            # calculating the loss
            noisy_output = model(noisy_sub1)
            noisy_target = noisy_sub2
            if cfg.train.loss.gamma.const:
                gamma = cfg.train.loss.gamma.base
            else:
                gamma = it / train_iter * cfg.train.loss.gamma.base

            diff = noisy_output - noisy_target
            exp_diff = noisy_sub1_denoised - noisy_sub2_denoised
            loss1 = torch.mean(diff**2)
            loss2 = gamma * torch.mean((diff - exp_diff)**2)
            loss_all = loss1 + loss2

            # loss1 = noisy_output - noisy_target
            # loss2 = torch.exp(noisy_target - noisy_output)
            # loss_all = torch.mean(loss1 + loss2)
            optimizer.zero_grad()  # clear gradients from the previous iteration
            loss_all.backward()

            # In PyTorch 1.1.0 and later, you should call `optimizer.step()` before `lr_scheduler.step()`
            optimizer.step()
            scheduler.step()

            # record the loss of the minibatch
            train_loss_meter.update(loss_all.item())  # .item() detaches the value from the graph
            train_time_meter.update(time.time() - train_start_time)
            writer.add_scalar('lr', optimizer.param_groups[0]['lr'], it)

            if it % 1000 == 0:
                writer.add_histogram('hist/pred', noisy_denoised, it)
                writer.add_histogram('hist/noisy', noisy, it)

                if cfg.data.simulate:
                    writer.add_histogram('hist/clean', clean, it)

            if cfg.data.simulate:
                pass

            # print interval
            if it % cfg.train.print_interval == 0:
                terminal_info = f"Iter [{it:d}/{train_iter:d}]  \
                                train Loss: {train_loss_meter.avg:.4f}  \
                                Time/Image: {train_time_meter.avg / cfg.train.batch_size:.4f}"

                logger.info(terminal_info)
                writer.add_scalar('loss/train_loss', train_loss_meter.avg, it)

                if cfg.data.simulate:
                    pass

                running_metrics_train.reset()
                train_time_meter.reset()
                train_loss_meter.reset()

            # val interval
            if it % cfg.train.val_interval == 0 or \
               it == train_iter:
                val_start_time = time.time()
                model.eval()
                with torch.no_grad():
                    for clean, noisy, _ in valloader:
                        # noisy /= 350
                        # clean /= 350
                        noisy = noisy.to(device, dtype=torch.float32)
                        noisy_denoised = model(noisy)

                        if cfg.data.simulate:
                            clean = clean.to(device, dtype=torch.float32)
                            psnr = piq.psnr(clean,
                                            noisy_denoised,
                                            data_range=data_range)
                            ssim = piq.ssim(clean,
                                            noisy_denoised,
                                            data_range=data_range)
                            val_psnr_meter.update(psnr)
                            val_ssim_meter.update(ssim)

                        val_loss = torch.mean((noisy_denoised - noisy)**2)
                        val_loss_meter.update(val_loss.item())

                writer.add_scalar('loss/val_loss', val_loss_meter.avg, it)
                logger.info(
                    f"Iter [{it}/{train_iter}], val Loss: {val_loss_meter.avg:.4f} Time/Image: {(time.time()-val_start_time)/len(v_loader):.4f}"
                )
                val_loss_meter.reset()
                running_metrics_val.reset()

                if cfg.data.simulate:
                    writer.add_scalars('metrics/val', {
                        'psnr': val_psnr_meter.avg,
                        'ssim': val_ssim_meter.avg
                    }, it)
                    logger.info(
                        f'psnr: {val_psnr_meter.avg},\tssim: {val_ssim_meter.avg}'
                    )
                    val_psnr_meter.reset()
                    val_ssim_meter.reset()

                train_val_time = time.time() - train_val_start_time
                remain_time = train_val_time * (train_iter - it) / it
                m, s = divmod(remain_time, 60)
                h, m = divmod(m, 60)
                if s != 0:
                    train_time = "Remain train time = %d hours %d minutes %d seconds \n" % (
                        h, m, s)
                else:
                    train_time = "Remain train time : train completed.\n"
                logger.info(train_time)
                model.train()

            # save model
            if it % int(train_iter / cfg.train.epoch * 10) == 0:  # every 10 epochs
                ep = int(it / (train_iter / cfg.train.epoch))
                state = {
                    "epoch": it,
                    "model_state": model.state_dict(),
                    "optimizer_state": optimizer.state_dict(),
                    "scheduler_state": scheduler.state_dict(),
                }
                save_path = osp.join(writer.file_writer.get_logdir(),
                                     f"{ep}.pkl")
                torch.save(state, save_path)
                logger.info(f'saved model state dict at {save_path}')

            train_start_time = time.time()
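Example #19 above is a Neighbor2Neighbor-style self-supervised denoiser: rand_pool.generate_mask_pair picks two distinct pixels from every 2x2 cell of the noisy image, and generate_subimages extracts the corresponding half-resolution sub-images. The real rand_pool module (with its global seed counter) is not shown; the following is a simplified sketch of the same idea, representing each "mask" as an index map rather than a flat boolean mask:

import torch

def generate_mask_pair(noisy):
    """For each 2x2 cell of `noisy` (N, C, H, W; H, W even), randomly pick an
    ordered pair of distinct sub-pixel positions. Returns two (N, H//2, W//2)
    index maps with values in {0, 1, 2, 3}."""
    n, _, h, w = noisy.shape
    # all 12 ordered pairs of distinct indices within a 2x2 cell
    pairs = torch.tensor([(a, b) for a in range(4) for b in range(4) if a != b],
                         device=noisy.device)
    choice = torch.randint(0, 12, (n, h // 2, w // 2), device=noisy.device)
    picked = pairs[choice]                      # (N, H//2, W//2, 2)
    return picked[..., 0], picked[..., 1]

def generate_subimages(img, idx_map):
    """Gather one pixel per 2x2 cell according to idx_map, producing a
    half-resolution sub-image (N, C, H//2, W//2)."""
    n, c, h, w = img.shape
    # rearrange so the 4 pixels of each cell sit on a trailing axis
    cells = (img.reshape(n, c, h // 2, 2, w // 2, 2)
                .permute(0, 1, 2, 4, 3, 5)
                .reshape(n, c, h // 2, w // 2, 4))
    idx = idx_map[:, None, :, :, None].expand(n, c, h // 2, w // 2, 1)
    return cells.gather(4, idx).squeeze(4)

With these, noisy_sub1 = generate_subimages(noisy, mask1) matches the calls in the loop above.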
Example #20
def train(cfg, writer, logger):

    # Setup seeds
    torch.manual_seed(cfg.get("seed", 1337))
    torch.cuda.manual_seed(cfg.get("seed", 1337))
    np.random.seed(cfg.get("seed", 1337))
    random.seed(cfg.get("seed", 1337))

    # Setup device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Setup Augmentations
    augmentations = cfg["training"].get("augmentations", None)
    data_aug = get_composed_augmentations(augmentations)

    # Setup Dataloader
    data_loader = get_loader(cfg["data"]["dataset"])
    data_path = cfg["data"]["path"]

    t_loader = data_loader(
        data_path,
        is_transform=True,
        split=cfg["data"]["train_split"],
        img_size=(cfg["data"]["img_rows"], cfg["data"]["img_cols"]),
        augmentations=data_aug,
    )

    v_loader = data_loader(
        data_path,
        is_transform=True,
        split=cfg["data"]["val_split"],
        img_size=(cfg["data"]["img_rows"], cfg["data"]["img_cols"]),
    )

    n_classes = t_loader.n_classes
    trainloader = data.DataLoader(
        t_loader,
        batch_size=cfg["training"]["batch_size"],
        num_workers=cfg["training"]["n_workers"],
        shuffle=True,
    )

    valloader = data.DataLoader(v_loader,
                                batch_size=cfg["training"]["batch_size"],
                                num_workers=cfg["training"]["n_workers"])

    # Setup Metrics
    running_metrics_val = runningScore(n_classes)

    # Setup Model
    model = get_model(cfg["model"], n_classes).to(device)

    model = torch.nn.DataParallel(model,
                                  device_ids=range(torch.cuda.device_count()))

    # Setup optimizer, lr_scheduler and loss function
    optimizer_cls = get_optimizer(cfg)
    optimizer_params = {
        k: v
        for k, v in cfg["training"]["optimizer"].items() if k != "name"
    }

    optimizer = optimizer_cls(model.parameters(), **optimizer_params)
    logger.info("Using optimizer {}".format(optimizer))

    scheduler = get_scheduler(optimizer, cfg["training"]["lr_schedule"])

    loss_fn = get_loss_function(cfg)
    logger.info("Using loss {}".format(loss_fn))

    start_iter = 0
    if cfg["training"]["resume"] is not None:
        if os.path.isfile(cfg["training"]["resume"]):
            logger.info(
                "Loading model and optimizer from checkpoint '{}'".format(
                    cfg["training"]["resume"]))
            checkpoint = torch.load(cfg["training"]["resume"])
            model.load_state_dict(checkpoint["model_state"])
            optimizer.load_state_dict(checkpoint["optimizer_state"])
            scheduler.load_state_dict(checkpoint["scheduler_state"])
            start_iter = checkpoint["epoch"]
            logger.info("Loaded checkpoint '{}' (iter {})".format(
                cfg["training"]["resume"], checkpoint["epoch"]))
        else:
            logger.info("No checkpoint found at '{}'".format(
                cfg["training"]["resume"]))

    val_loss_meter = averageMeter()
    time_meter = averageMeter()

    best_iou = -100.0
    i = start_iter
    flag = True

    while i <= cfg["training"]["train_iters"] and flag:
        for (images, labels) in trainloader:
            i += 1
            start_ts = time.time()
            model.train()
            images = images.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            outputs = model(images)

            loss = loss_fn(input=outputs, target=labels)

            loss.backward()
            optimizer.step()
            scheduler.step()  # in PyTorch >= 1.1, step the scheduler after the optimizer

            time_meter.update(time.time() - start_ts)

            if (i + 1) % cfg["training"]["print_interval"] == 0:
                fmt_str = "Iter [{:d}/{:d}]  Loss: {:.4f}  Time/Image: {:.4f}"
                print_str = fmt_str.format(
                    i + 1,
                    cfg["training"]["train_iters"],
                    loss.item(),
                    time_meter.avg / cfg["training"]["batch_size"],
                )

                print(print_str)
                logger.info(print_str)
                writer.add_scalar("loss/train_loss", loss.item(), i + 1)
                time_meter.reset()

            if (i + 1) % cfg["training"]["val_interval"] == 0 or (
                    i + 1) == cfg["training"]["train_iters"]:
                model.eval()
                with torch.no_grad():
                    for i_val, (images_val,
                                labels_val) in tqdm(enumerate(valloader)):
                        images_val = images_val.to(device)
                        labels_val = labels_val.to(device)

                        outputs = model(images_val)
                        val_loss = loss_fn(input=outputs, target=labels_val)

                        pred = outputs.data.max(1)[1].cpu().numpy()
                        gt = labels_val.data.cpu().numpy()

                        running_metrics_val.update(gt, pred)
                        val_loss_meter.update(val_loss.item())

                writer.add_scalar("loss/val_loss", val_loss_meter.avg, i + 1)
                logger.info("Iter %d Loss: %.4f" % (i + 1, val_loss_meter.avg))

                score, class_iou = running_metrics_val.get_scores()
                for k, v in score.items():
                    print(k, v)
                    logger.info("{}: {}".format(k, v))
                    writer.add_scalar("val_metrics/{}".format(k), v, i + 1)

                for k, v in class_iou.items():
                    logger.info("{}: {}".format(k, v))
                    writer.add_scalar("val_metrics/cls_{}".format(k), v, i + 1)

                val_loss_meter.reset()
                running_metrics_val.reset()

                if score["Mean IoU : \t"] >= best_iou:
                    best_iou = score["Mean IoU : \t"]
                    state = {
                        "epoch": i + 1,
                        "model_state": model.state_dict(),
                        "optimizer_state": optimizer.state_dict(),
                        "scheduler_state": scheduler.state_dict(),
                        "best_iou": best_iou,
                    }
                    save_path = os.path.join(
                        writer.file_writer.get_logdir(),
                        "{}_{}_best_model.pkl".format(cfg["model"]["arch"],
                                                      cfg["data"]["dataset"]),
                    )
                    torch.save(state, save_path)

            if (i + 1) == cfg["training"]["train_iters"]:
                flag = False
                break
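The runningScore object shared by these examples accumulates a per-class confusion matrix; the peculiar dictionary keys (e.g. score["Mean IoU : \t"]) match the pytorch-semseg implementation these snippets appear to derive from. A condensed sketch:

import numpy as np

class runningScore(object):
    """Confusion-matrix based segmentation metrics (condensed sketch)."""

    def __init__(self, n_classes):
        self.n_classes = n_classes
        self.confusion_matrix = np.zeros((n_classes, n_classes))

    def _fast_hist(self, label_true, label_pred, n_class):
        # histogram of (true, pred) pairs over valid labels
        mask = (label_true >= 0) & (label_true < n_class)
        return np.bincount(
            n_class * label_true[mask].astype(int) + label_pred[mask],
            minlength=n_class ** 2).reshape(n_class, n_class)

    def update(self, label_trues, label_preds):
        for lt, lp in zip(label_trues, label_preds):
            self.confusion_matrix += self._fast_hist(
                lt.flatten(), lp.flatten(), self.n_classes)

    def get_scores(self):
        hist = self.confusion_matrix
        acc = np.diag(hist).sum() / hist.sum()
        with np.errstate(divide="ignore", invalid="ignore"):
            iu = np.diag(hist) / (
                hist.sum(axis=1) + hist.sum(axis=0) - np.diag(hist))
        # classes absent from both gt and pred yield NaN; nanmean skips them
        return ({"Overall Acc : \t": acc, "Mean IoU : \t": np.nanmean(iu)},
                dict(zip(range(self.n_classes), iu)))

    def reset(self):
        self.confusion_matrix = np.zeros((self.n_classes, self.n_classes))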
Example #21
def train(cfg, writer, logger):

    # Setup seeds
    torch.manual_seed(cfg.get("seed", 1337))
    torch.cuda.manual_seed(cfg.get("seed", 1337))
    np.random.seed(cfg.get("seed", 1337))
    random.seed(cfg.get("seed", 1337))

    # Setup device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Setup Augmentations
    augmentations = cfg["training"].get("augmentations", None)
    data_aug = get_composed_augmentations(augmentations)

    # Setup Dataloader
    data_loader = get_loader(cfg["data"]["dataloader_type"])

    data_root = cfg["data"]["data_root"]
    presentation_root = cfg["data"]["presentation_root"]

    t_loader = data_loader(
        data_root=data_root,
        presentation_root=presentation_root,
        is_transform=True,
        img_size=(cfg["data"]["img_rows"], cfg["data"]["img_cols"]),
        augmentations=data_aug,
    )

    v_loader = data_loader(data_root=data_root,
                           presentation_root=presentation_root,
                           is_transform=True,
                           img_size=(cfg["data"]["img_rows"],
                                     cfg["data"]["img_cols"]),
                           augmentations=data_aug,
                           test_mode=True)

    n_classes = t_loader.n_classes

    trainloader = data.DataLoader(
        t_loader,
        batch_size=cfg["training"]["batch_size"],
        num_workers=cfg["training"]["n_workers"],
        shuffle=False,
    )

    valloader = data.DataLoader(v_loader,
                                batch_size=cfg["training"]["batch_size"],
                                num_workers=cfg["training"]["n_workers"],
                                shuffle=False)

    # Setup Metrics
    # running_metrics_train = runningScore(n_classes)
    running_metrics_val = runningScore(n_classes)

    # Setup Model
    model = get_model(cfg["model"], n_classes, defaultParams).to(device)

    #model = torch.nn.DataParallel(model, device_ids=range(torch.cuda.device_count()))

    # Setup optimizer, lr_scheduler and loss function
    optimizer_cls = get_optimizer(cfg)
    optimizer_params = {
        k: v
        for k, v in cfg["training"]["optimizer"].items() if k != "name"
    }

    optimizer = optimizer_cls(model.parameters(), **optimizer_params)
    logger.info("Using optimizer {}".format(optimizer))

    scheduler = get_scheduler(optimizer, cfg["training"]["lr_schedule"])

    loss_fn = get_loss_function(cfg)
    logger.info("Using loss {}".format(loss_fn))

    start_iter = 0
    if cfg["training"]["resume"] is not None:
        if os.path.isfile(cfg["training"]["resume"]):
            logger.info(
                "Loading model and optimizer from checkpoint '{}'".format(
                    cfg["training"]["resume"]))
            checkpoint = torch.load(cfg["training"]["resume"])
            model.load_state_dict(checkpoint["model_state"])
            optimizer.load_state_dict(checkpoint["optimizer_state"])
            scheduler.load_state_dict(checkpoint["scheduler_state"])
            start_iter = checkpoint["epoch"]
            logger.info("Loaded checkpoint '{}' (iter {})".format(
                cfg["training"]["resume"], checkpoint["epoch"]))
        else:
            logger.info("No checkpoint found at '{}'".format(
                cfg["training"]["resume"]))

    model.load_pretrained_weights(cfg["training"]["saved_model_path"])

    # train_loss_meter = averageMeter()
    val_loss_meter = averageMeter()
    time_meter = averageMeter()

    best_iou = -100.0
    i = start_iter

    while i <= cfg["training"]["num_presentations"]:

        #                #
        # TRAINING PHASE #
        #                #
        i += 1
        start_ts = time.time()
        trainloader.dataset.random_select()

        hebb = model.initialZeroHebb().to(device)
        for idx, (images, labels) in enumerate(
                trainloader, 1):  # get a single training presentation

            images = images.to(device)
            labels = labels.to(device)

            if idx <= 5:
                model.eval()
                with torch.no_grad():
                    outputs, hebb = model(images,
                                          labels,
                                          hebb,
                                          device,
                                          test_mode=False)
            else:
                model.train()
                optimizer.zero_grad()
                outputs, hebb = model(images,
                                      labels,
                                      hebb,
                                      device,
                                      test_mode=True)
                loss = loss_fn(input=outputs, target=labels)
                loss.backward()
                optimizer.step()
                scheduler.step()  # in PyTorch >= 1.1, step the scheduler after the optimizer

        time_meter.update(time.time() - start_ts)  # time taken per presentation

        if (i + 1) % cfg["training"]["print_interval"] == 0:
            fmt_str = "Pres [{:d}/{:d}]  Loss: {:.4f}  Time/Pres: {:.4f}"
            print_str = fmt_str.format(
                i + 1,
                cfg["training"]["num_presentations"],
                loss.item(),
                time_meter.avg / cfg["training"]["batch_size"],
            )
            print(print_str)
            logger.info(print_str)
            writer.add_scalar("loss/test_loss", loss.item(), i + 1)
            time_meter.reset()

        #            #
        # TEST PHASE #
        #            #
        if ((i + 1) % cfg["training"]["test_interval"] == 0
                or (i + 1) == cfg["training"]["num_presentations"]):

            training_state_dict = model.state_dict()  # save the training-phase weights

            valloader.dataset.random_select()
            hebb = model.initialZeroHebb().to(device)
            for idx, (images_val, labels_val) in enumerate(
                    valloader, 1):  # get a single test presentation

                images_val = images_val.to(device)
                labels_val = labels_val.to(device)

                if idx <= 5:
                    model.eval()
                    with torch.no_grad():
                        outputs, hebb = model(images_val,
                                              labels_val,
                                              hebb,
                                              device,
                                              test_mode=False)
                else:
                    model.train()
                    optimizer.zero_grad()
                    outputs, hebb = model(images_val,
                                          labels_val,
                                          hebb,
                                          device,
                                          test_mode=True)
                    loss = loss_fn(input=outputs, target=labels_val)
                    loss.backward()
                    optimizer.step()

                    pred = outputs.data.max(1)[1].cpu().numpy()
                    gt = labels_val.data.cpu().numpy()

                    running_metrics_val.update(gt, pred)
                    val_loss_meter.update(loss.item())

            model.load_state_dict(
                training_state_dict)  # revert back to training parameters

            writer.add_scalar("loss/val_loss", val_loss_meter.avg, i + 1)
            logger.info("Iter %d Loss: %.4f" % (i + 1, val_loss_meter.avg))

            score, class_iou = running_metrics_val.get_scores()
            for k, v in score.items():
                print(k, v)
                logger.info("{}: {}".format(k, v))
                writer.add_scalar("val_metrics/{}".format(k), v, i + 1)

            for k, v in class_iou.items():
                logger.info("{}: {}".format(k, v))
                writer.add_scalar("val_metrics/cls_{}".format(k), v, i + 1)

            val_loss_meter.reset()
            running_metrics_val.reset()

            if score["Mean IoU : \t"] >= best_iou:
                best_iou = score["Mean IoU : \t"]
                state = {
                    "epoch": i + 1,
                    "model_state": model.state_dict(),
                    "optimizer_state": optimizer.state_dict(),
                    "scheduler_state": scheduler.state_dict(),
                    "best_iou": best_iou,
                }
                save_path = os.path.join(
                    writer.file_writer.get_logdir(),
                    "{}_{}_best_model.pkl".format(
                        cfg["model"]["arch"], cfg["data"]["dataloader_type"]),
                )
                torch.save(state, save_path)

        if (i + 1) == cfg["training"]["num_presentations"]:
            break
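Example #21 above threads a hebb tensor through the model (model.initialZeroHebb()), which points to a differentiable-plasticity network in the style of Miconi et al.: the effective weight is a slow learned weight plus a learned gain times a fast Hebbian trace. The model itself is not shown; a minimal sketch of one such plastic layer, with all internals assumed:

import torch
import torch.nn as nn

class PlasticLinear(nn.Module):
    """One plastic layer: effective weight = w + alpha * hebb, where hebb is
    a fast Hebbian trace carried across steps (a sketch, not the actual model)."""

    def __init__(self, in_features, out_features):
        super().__init__()
        self.w = nn.Parameter(0.01 * torch.randn(in_features, out_features))
        self.alpha = nn.Parameter(0.01 * torch.randn(in_features, out_features))
        self.eta = nn.Parameter(torch.tensor(0.01))  # learned plasticity rate

    def initialZeroHebb(self):
        return torch.zeros_like(self.w)

    def forward(self, x, hebb):
        y = torch.tanh(x @ (self.w + self.alpha * hebb))
        # Hebbian update: decay the trace toward the (batch-averaged) outer
        # product of pre- and post-synaptic activity
        hebb = (1 - self.eta) * hebb + \
            self.eta * torch.einsum('bi,bo->io', x, y) / x.shape[0]
        return y, hebb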
Example #22
def train(cfg, writer, logger, run_id):

    # Setup random seeds
    torch.manual_seed(cfg.get('seed', 1337))
    torch.cuda.manual_seed(cfg.get('seed', 1337))
    np.random.seed(cfg.get('seed', 1337))
    random.seed(cfg.get('seed', 1337))

    torch.backends.cudnn.benchmark = True

    # Setup Augmentations
    augmentations = cfg['train'].get('augmentations', None)
    data_aug = get_composed_augmentations(augmentations)

    # Setup Dataloader
    data_loader = get_loader(cfg['data']['dataloader'])
    data_path = cfg['data']['path']

    logger.info("Using dataset: {}".format(data_path))

    t_loader = data_loader(data_path,
                           transform=None,
                           split=cfg['data']['train_split'],
                           augmentations=data_aug)

    v_loader = data_loader(
        data_path,
        transform=None,
        split=cfg['data']['val_split'],
    )
    logger.info(
        f'num of train samples: {len(t_loader)} \nnum of val samples: {len(v_loader)}'
    )

    train_data_len = len(t_loader)
    batch_size = cfg['train']['batch_size']
    epoch = cfg['train']['train_epoch']
    train_iter = int(np.ceil(train_data_len / batch_size) * epoch)
    logger.info(f'total train iter: {train_iter}')
    n_classes = t_loader.n_classes

    trainloader = data.DataLoader(t_loader,
                                  batch_size=cfg['train']['batch_size'],
                                  num_workers=cfg['train']['n_workers'],
                                  shuffle=True,
                                  drop_last=True)

    valloader = data.DataLoader(v_loader,
                                batch_size=cfg['train']['batch_size'],
                                num_workers=cfg['train']['n_workers'])

    # Setup Model
    model = get_model(cfg['model'], n_classes)
    logger.info("Using Model: {}".format(cfg['model']['arch']))
    device = f'cuda:{cuda_idx[0]}'
    model = model.to(device)
    model = torch.nn.DataParallel(model, device_ids=cuda_idx)  # DataParallel runs on multiple GPUs automatically

    # Setup optimizer, lr_scheduler and loss function
    optimizer_cls = get_optimizer(cfg)
    optimizer_params = {
        k: v
        for k, v in cfg['train']['optimizer'].items() if k != 'name'
    }
    optimizer = optimizer_cls(model.parameters(), **optimizer_params)
    logger.info("Using optimizer {}".format(optimizer))
    scheduler = get_scheduler(optimizer, cfg['train']['lr_schedule'])
    loss_fn = get_loss_function(cfg)
    # logger.info("Using loss {}".format(loss_fn))

    # set checkpoints
    start_iter = 0
    if cfg['train']['resume'] is not None:
        if os.path.isfile(cfg['train']['resume']):
            logger.info(
                "Loading model and optimizer from checkpoint '{}'".format(
                    cfg['train']['resume']))
            checkpoint = torch.load(cfg['train']['resume'])
            model.load_state_dict(checkpoint["model_state"])
            optimizer.load_state_dict(checkpoint["optimizer_state"])
            scheduler.load_state_dict(checkpoint["scheduler_state"])
            start_iter = checkpoint["epoch"]
            logger.info("Loaded checkpoint '{}' (iter {})".format(
                cfg['train']['resume'], checkpoint["epoch"]))
        else:
            logger.info("No checkpoint found at '{}'".format(
                cfg['train']['resume']))

    # Setup Metrics
    running_metrics_val = runningScore(n_classes)
    val_loss_meter = averageMeter()
    train_time_meter = averageMeter()
    time_meter_val = averageMeter()

    best_iou = 0
    flag = True

    val_rlt_f1 = []
    val_rlt_OA = []
    best_f1_till_now = 0
    best_OA_till_now = 0
    best_fwIoU_now = 0
    best_fwIoU_iter_till_now = 0

    # train
    it = start_iter
    model.train()
    while it <= train_iter and flag:
        for (file_a, file_b, label, mask) in trainloader:
            it += 1
            start_ts = time.time()
            file_a = file_a.to(device)
            file_b = file_b.to(device)
            label = label.to(device)
            mask = mask.to(device)

            optimizer.zero_grad()
            outputs = model(file_a, file_b)

            loss = loss_fn(input=outputs, target=label, mask=mask)
            loss.backward()
            # print('conv11: ', model.conv11.weight.grad, model.conv11.weight.grad.shape)
            # print('conv21: ', model.conv21.weight.grad, model.conv21.weight.grad.shape)
            # print('conv31: ', model.conv31.weight.grad, model.conv31.weight.grad.shape)

            # In PyTorch 1.1.0 and later, you should call `optimizer.step()` before `lr_scheduler.step()`
            optimizer.step()
            scheduler.step()

            train_time_meter.update(time.time() - start_ts)
            time_meter_val.update(time.time() - start_ts)

            if (it + 1) % cfg['train']['print_interval'] == 0:
                fmt_str = "train:\nIter [{:d}/{:d}]  Loss: {:.4f}  Time/Image: {:.4f}"
                print_str = fmt_str.format(
                    it + 1,
                    train_iter,
                    loss.item(),  # .item() extracts the loss value as a Python float
                    train_time_meter.avg / cfg['train']['batch_size'])
                train_time_meter.reset()
                logger.info(print_str)
                writer.add_scalar('loss/train_loss', loss.item(), it + 1)

            if (it + 1) % cfg['train']['val_interval'] == 0 or \
               (it + 1) == train_iter:
                model.eval()  # switch layers such as dropout/batchnorm to eval behavior
                with torch.no_grad():  # disable autograd, save memory usage
                    for (file_a_val, file_b_val, label_val,
                         mask_val) in valloader:
                        file_a_val = file_a_val.to(device)
                        file_b_val = file_b_val.to(device)

                        outputs = model(file_a_val, file_b_val)
                        # tensor.max returns the maximum value and its indices
                        pred = outputs.max(1)[1].cpu().numpy()
                        gt = label_val.numpy()
                        running_metrics_val.update(gt, pred, mask_val)

                        label_val = label_val.to(device)
                        mask_val = mask_val.to(device)
                        val_loss = loss_fn(input=outputs,
                                           target=label_val,
                                           mask=mask_val)
                        val_loss_meter.update(val_loss.item())

                lr_now = optimizer.param_groups[0]['lr']
                logger.info(f'lr: {lr_now}')
                # writer.add_scalar('lr', lr_now, it+1)
                writer.add_scalar('loss/val_loss', val_loss_meter.avg, it + 1)
                logger.info("Iter %d, val Loss: %.4f" %
                            (it + 1, val_loss_meter.avg))

                score, class_iou = running_metrics_val.get_scores()

                # for k, v in score.items():
                #     logger.info('{}: {}'.format(k, v))
                #     writer.add_scalar('val_metrics/{}'.format(k), v, it+1)

                for k, v in class_iou.items():
                    logger.info('{}: {}'.format(k, v))
                    writer.add_scalar('val_metrics/cls_{}'.format(k), v,
                                      it + 1)

                val_loss_meter.reset()
                running_metrics_val.reset()

                avg_f1 = score["Mean_F1"]
                OA = score["Overall_Acc"]
                fw_IoU = score["FreqW_IoU"]
                val_rlt_f1.append(avg_f1)
                val_rlt_OA.append(OA)

                if fw_IoU >= best_fwIoU_now and it > 200:
                    best_fwIoU_now = fw_IoU
                    correspond_meanIou = score["Mean_IoU"]
                    best_fwIoU_iter_till_now = it + 1

                    state = {
                        "epoch": it + 1,
                        "model_state": model.state_dict(),
                        "optimizer_state": optimizer.state_dict(),
                        "scheduler_state": scheduler.state_dict(),
                        "best_fwIoU": best_fwIoU_now,
                    }
                    save_path = os.path.join(
                        writer.file_writer.get_logdir(),
                        "{}_{}_best_model.pkl".format(
                            cfg['model']['arch'], cfg['data']['dataloader']))
                    torch.save(state, save_path)

                    logger.info("best_fwIoU_now =  %.8f" % (best_fwIoU_now))
                    logger.info("Best fwIoU Iter till now= %d" %
                                (best_fwIoU_iter_till_now))

                iter_time = time_meter_val.avg
                time_meter_val.reset()
                remain_time = iter_time * (train_iter - it)
                m, s = divmod(remain_time, 60)
                h, m = divmod(m, 60)
                if s != 0:
                    train_time = "Remain train time = %d hours %d minutes %d seconds \n" % (
                        h, m, s)
                else:
                    train_time = "Remain train time : train completed.\n"
                print(train_time)

            model.train()
            if (it + 1) == train_iter:
                flag = False
                logger.info("Use the Sar_seg_band3,val_interval: 30")
                break
    logger.info("best_fwIoU_now =  %.8f" % (best_fwIoU_now))
    logger.info("Best fwIoU Iter till now= %d" % (best_fwIoU_iter_till_now))

    state = {
        "epoch": it + 1,
        "model_state": model.state_dict(),
        "optimizer_state": optimizer.state_dict(),
        "scheduler_state": scheduler.state_dict(),
        "best_fwIoU": best_fwIoU_now,
    }
    save_path = os.path.join(
        writer.file_writer.get_logdir(),
        "{}_{}_last_model.pkl".format(cfg['model']['arch'],
                                      cfg['data']['dataloader']))
    torch.save(state, save_path)
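Example #22 above calls its loss as loss_fn(input=outputs, target=label, mask=mask), i.e. a loss that ignores masked-out pixels. The concrete loss comes from get_loss_function(cfg) and is not shown; a plausible sketch of that signature as a masked cross-entropy:

import torch
import torch.nn.functional as F

def masked_cross_entropy(input, target, mask):
    # input: (N, C, H, W) logits; target: (N, H, W) class indices;
    # mask: (N, H, W) with nonzero marking valid pixels.
    loss = F.cross_entropy(input, target, reduction='none')  # per-pixel loss
    mask = mask.float()
    return (loss * mask).sum() / mask.sum().clamp(min=1.0)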
Example #23
def test(cfg, areaname):

    # Setup seeds
    torch.manual_seed(cfg.get('seed', 1337))
    torch.cuda.manual_seed(cfg.get('seed', 1337))
    np.random.seed(cfg.get('seed', 1337))
    random.seed(cfg.get('seed', 1337))

    # Setup device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Setup Augmentations
    augmentations = cfg['training'].get('augmentations', None)
    data_aug = get_composed_augmentations(augmentations)

    # Setup Dataloader
    #    data_loader = get_loader(cfg['data']['dataset'])
    #    data_path = cfg['data']['path']
    #
    #    t_loader = data_loader(
    #        data_path,
    #        is_transform=True,
    #        split=cfg['data']['train_split'],
    #        img_size=(cfg['data']['img_rows'], cfg['data']['img_cols']),
    #        augmentations=data_aug)
    #
    #    v_loader = data_loader(
    #        data_path,
    #        is_transform=True,
    #        split=cfg['data']['val_split'],
    #        img_size=(cfg['data']['img_rows'], cfg['data']['img_cols']),)
    #
    #    n_classes = t_loader.n_classes
    #    trainloader = data.DataLoader(t_loader,
    #                                  batch_size=cfg['training']['batch_size'],
    #                                  num_workers=cfg['training']['n_workers'],
    #                                  shuffle=True)
    #
    #    valloader = data.DataLoader(v_loader,
    #                                batch_size=cfg['training']['batch_size'],
    #                                num_workers=cfg['training']['n_workers'])
    datapath = '/home/chengjjang/Projects/deepres/SatelliteData/{}/'.format(
        areaname)
    paths = {
        'masks': '{}/patch{}_train/gt'.format(datapath, areaname),
        'images': '{}/patch{}_train/rgb'.format(datapath, areaname),
        'nirs': '{}/patch{}_train/nir'.format(datapath, areaname),
        'swirs': '{}/patch{}_train/swir'.format(datapath, areaname),
        'vhs': '{}/patch{}_train/vh'.format(datapath, areaname),
        'vvs': '{}/patch{}_train/vv'.format(datapath, areaname),
        'redes': '{}/patch{}_train/rede'.format(datapath, areaname),
        'ndvis': '{}/patch{}_train/ndvi'.format(datapath, areaname),
    }

    valpaths = {
        'masks': '{}/patch{}_train/gt'.format(datapath, areaname),
        'images': '{}/patch{}_train/rgb'.format(datapath, areaname),
        'nirs': '{}/patch{}_train/nir'.format(datapath, areaname),
        'swirs': '{}/patch{}_train/swir'.format(datapath, areaname),
        'vhs': '{}/patch{}_train/vh'.format(datapath, areaname),
        'vvs': '{}/patch{}_train/vv'.format(datapath, areaname),
        'redes': '{}/patch{}_train/rede'.format(datapath, areaname),
        'ndvis': '{}/patch{}_train/ndvi'.format(datapath, areaname),
    }

    n_classes = 3
    train_img_paths = [
        pth for pth in os.listdir(paths['images'])
        if ('_01_' not in pth) and ('_25_' not in pth)
    ]
    val_img_paths = [
        pth for pth in os.listdir(valpaths['images'])
        if ('_01_' not in pth) and ('_25_' not in pth)
    ]
    ntrain = len(train_img_paths)
    nval = len(val_img_paths)
    train_idx = [i for i in range(ntrain)]
    val_idx = [i for i in range(nval)]
    trainds = ImageProvider(MultibandImageType, paths, image_suffix='.png')
    valds = ImageProvider(MultibandImageType, valpaths, image_suffix='.png')

    print('valds.im_names: {}'.format(valds.im_names))

    config_path = 'crop_pspnet_config.json'
    with open(config_path, 'r') as f:
        mycfg = json.load(f)
        train_data_path = '{}/patch{}_train'.format(datapath, areaname)
        dataset_path, train_dir = os.path.split(train_data_path)
        mycfg['dataset_path'] = dataset_path
    config = Config(**mycfg)

    config = update_config(config, num_channels=12, nb_epoch=50)
    #dataset_train = TrainDataset(trainds, train_idx, config, transforms=augment_flips_color)
    dataset_train = TrainDataset(trainds, train_idx, config, 1)
    dataset_val = ValDataset(valds, val_idx, config, 1)
    trainloader = data.DataLoader(dataset_train,
                                  batch_size=cfg['training']['batch_size'],
                                  num_workers=cfg['training']['n_workers'],
                                  shuffle=True)

    valloader = data.DataLoader(dataset_val,
                                batch_size=cfg['training']['batch_size'],
                                num_workers=cfg['training']['n_workers'],
                                shuffle=False)
    # Setup Metrics
    running_metrics_train = runningScore(n_classes)
    running_metrics_val = runningScore(n_classes)

    nbackground = 1116403140
    ncorn = 44080178
    nsoybean = 316698122

    print('nbackground: {}'.format(nbackground))
    print('ncorn: {}'.format(ncorn))
    print('nsoybean: {}'.format(nsoybean))

    wgts = [1.0, 1.0 * nbackground / ncorn, 1.0 * nbackground / nsoybean]
    total_wgts = sum(wgts)
    wgt_background = wgts[0] / total_wgts
    wgt_corn = wgts[1] / total_wgts
    wgt_soybean = wgts[2] / total_wgts
    weights = torch.tensor([wgt_background, wgt_corn, wgt_soybean],
                           device=device)  # Variable is deprecated; a plain tensor suffices

    # Setup Model
    model = get_model(cfg['model'], n_classes).to(device)

    model = torch.nn.DataParallel(model,
                                  device_ids=range(torch.cuda.device_count()))

    # Setup optimizer, lr_scheduler and loss function
    optimizer_cls = get_optimizer(cfg)
    optimizer_params = {
        k: v
        for k, v in cfg['training']['optimizer'].items() if k != 'name'
    }

    optimizer = optimizer_cls(model.parameters(), **optimizer_params)

    scheduler = get_scheduler(optimizer, cfg['training']['lr_schedule'])

    loss_fn = get_loss_function(cfg)

    start_iter = 0
    runpath = '/home/chengjjang/arisia/CropPSPNet/runs/pspnet_crop_{}'.format(
        areaname)
    modelpath = glob.glob('{}/*/*_best_model.pkl'.format(runpath))[0]
    print('modelpath: {}'.format(modelpath))
    checkpoint = torch.load(modelpath)
    model.load_state_dict(checkpoint["model_state"])

    val_loss_meter = averageMeter()
    time_meter = averageMeter()

    best_iou = -100.0

    respath = '{}_results_train'.format(areaname)
    os.makedirs(respath, exist_ok=True)

    model.eval()
    with torch.no_grad():
        for inputdata in valloader:
            imname_val = inputdata['img_name']
            images_val = inputdata['img_data']
            labels_val = inputdata['seg_label']
            images_val = images_val.to(device)
            labels_val = labels_val.to(device)

            print('imname_val: {}'.format(imname_val))

            outputs = model(images_val)
            val_loss = loss_fn(input=outputs, target=labels_val)

            pred = outputs.data.max(1)[1].cpu().numpy()
            gt = labels_val.data.cpu().numpy()

            dname = imname_val[0].split('.png')[0]
            np.save('{}/pred'.format(respath) + dname + '.npy', pred)
            np.save('{}/gt'.format(respath) + dname + '.npy', gt)
            np.save('{}/output'.format(respath) + dname + '.npy',
                    outputs.data.cpu().numpy())

            running_metrics_val.update(gt, pred)
            val_loss_meter.update(val_loss.item())

    #writer.add_scalar('loss/val_loss', val_loss_meter.avg, i+1)
    #logger.info("Iter %d Loss: %.4f" % (i + 1, val_loss_meter.avg))
    print('Test loss: {}'.format(val_loss_meter.avg))

    score, class_iou = running_metrics_val.get_scores()
    for k, v in score.items():
        print('val_metrics, {}: {}'.format(k, v))

    for k, v in class_iou.items():
        print('val_metrics, {}: {}'.format(k, v))

    val_loss_meter.reset()
    running_metrics_val.reset()
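All of these train() functions save checkpoints from a DataParallel-wrapped model, so the parameter names carry a 'module.' prefix; Example #23 works because it loads the checkpoint back into a wrapped model. To reuse such a '*_best_model.pkl' file with an unwrapped model, the prefix must be stripped. A small helper, assuming only the checkpoint layout visible above ('model_state', 'epoch', 'best_iou'):

import torch

def load_checkpoint(model, path, device='cpu'):
    # Load a checkpoint saved by the train() functions above, stripping the
    # 'module.' prefix that torch.nn.DataParallel adds to parameter names.
    checkpoint = torch.load(path, map_location=device)
    state = {k[len('module.'):] if k.startswith('module.') else k: v
             for k, v in checkpoint['model_state'].items()}
    model.load_state_dict(state)
    return checkpoint.get('epoch', 0), checkpoint.get('best_iou', None)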