Example No. 1
 def __init__(self,
              indices,
              n_init=100,
              output_dir=None,
              train=True,
              queries_name='queries.txt'):
     self.data_path = os.path.join(DATA_ROOT, 'VOCdevkit/VOC2012')
     self.sbd_path = os.path.join(DATA_ROOT, 'benchmark_RELEASE')
     self.data_aug = get_composed_augmentations(None)
     self.data_loader = get_loader('pascal')
     self.init_dataset = self._get_initial_dataset(train)
     super().__init__(self.get_dataset(indices),
                      n_init=n_init,
                      output_dir=output_dir,
                      queries_name=queries_name)
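
Example 1 passes None to get_composed_augmentations, while later examples pass a config dict. For orientation, a minimal self-contained sketch of that composition pattern follows; the real helper lives in the source repo, and the key-to-class mapping here is an assumption, not taken from it.

# Illustrative sketch (not from the source): how a get_composed_augmentations
# helper typically turns a config dict into one joint (image, mask) transform.
def get_composed_augmentations_sketch(aug_dict, key2aug=None):
    if aug_dict is None:
        # Example 1 passes None, i.e. "no augmentation pipeline".
        return None
    key2aug = key2aug or {}  # assumed mapping: augmentation name -> augmentation class
    transforms = [key2aug[name](param) for name, param in aug_dict.items()]

    def composed(img, mask):
        for t in transforms:
            img, mask = t(img, mask)
        return img, mask

    return composed
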
Example No. 2
def train(cfg, writer, logger, args):

    # Setup seeds
    torch.manual_seed(cfg.get('seed', 1337))
    torch.cuda.manual_seed(cfg.get('seed', 1337))
    np.random.seed(cfg.get('seed', 1337))
    random.seed(cfg.get('seed', 1337))

    # Setup device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # device = torch.device('cuda')

    # Setup Augmentations
    augmentations = cfg['training'].get('augmentations', None)
    if cfg['data']['dataset'] in ['cityscapes'] and augmentations is None:
        # Fall back to a default augmentation set for Cityscapes so that
        # `augmentations` is always defined before composing it below.
        augmentations = {
            'brightness': 63. / 255.,
            'saturation': 0.5,
            'contrast': 0.8,
            'hflip': 0.5,
            'rotate': 10,
            'rscalecropsquare': 713,
        }
    data_aug = get_composed_augmentations(augmentations)

    # Setup Dataloader
    data_loader = get_loader(cfg['data']['dataset'])
    data_path = cfg['data']['path']

    t_loader = data_loader(data_path,
                           is_transform=True,
                           split=cfg['data']['train_split'],
                           img_size=(cfg['data']['img_rows'],
                                     cfg['data']['img_cols']),
                           augmentations=data_aug)

    v_loader = data_loader(
        data_path,
        is_transform=True,
        split=cfg['data']['val_split'],
        img_size=(cfg['data']['img_rows'], cfg['data']['img_cols']),
    )

    n_classes = t_loader.n_classes
    trainloader = data.DataLoader(t_loader,
                                  batch_size=cfg['training']['batch_size'],
                                  num_workers=cfg['training']['n_workers'],
                                  shuffle=True)

    valloader = data.DataLoader(v_loader,
                                batch_size=cfg['training']['batch_size'],
                                num_workers=cfg['training']['n_workers'])

    # Setup Metrics
    running_metrics_val = runningScore(n_classes)

    # Setup Model
    model = get_model(cfg['model'], n_classes, args).to(device)
    model.apply(weights_init)
    print('sleep for 5 seconds')
    time.sleep(5)
    model = torch.nn.DataParallel(model,
                                  device_ids=range(torch.cuda.device_count()))
    # model = torch.nn.DataParallel(model, device_ids=(0, 1))
    print(model.device_ids)

    # Setup optimizer, lr_scheduler and loss function
    optimizer_cls = get_optimizer(cfg)
    optimizer_params = {
        k: v
        for k, v in cfg['training']['optimizer'].items() if k != 'name'
    }

    optimizer = optimizer_cls(model.parameters(), **optimizer_params)
    logger.info("Using optimizer {}".format(optimizer))

    scheduler = get_scheduler(optimizer, cfg['training']['lr_schedule'])

    loss_fn = get_loss_function(cfg)
    logger.info("Using loss {}".format(loss_fn))
    if 'multi_step' in cfg['training']['loss']['name']:
        my_loss_fn = loss_fn(
            scale_weight=cfg['training']['loss']['scale_weight'],
            n_inp=2,
            weight=None,
            reduction='sum',
            bkargs=args)
    else:
        my_loss_fn = loss_fn(weight=None, reduction='sum', bkargs=args)

    start_iter = 0
    if cfg['training']['resume'] is not None:
        if os.path.isfile(cfg['training']['resume']):
            logger.info(
                "Loading model and optimizer from checkpoint '{}'".format(
                    cfg['training']['resume']))
            checkpoint = torch.load(cfg['training']['resume'])
            model.load_state_dict(checkpoint["model_state"])
            optimizer.load_state_dict(checkpoint["optimizer_state"])
            scheduler.load_state_dict(checkpoint["scheduler_state"])
            start_iter = checkpoint["epoch"]
            logger.info("Loaded checkpoint '{}' (iter {})".format(
                cfg['training']['resume'], checkpoint["epoch"]))
        else:
            logger.info("No checkpoint found at '{}'".format(
                cfg['training']['resume']))

    val_loss_meter = averageMeter()
    time_meter = averageMeter()

    best_iou = -100.0
    i = start_iter
    flag = True

    while i <= cfg['training']['train_iters'] and flag:
        for (images, labels) in trainloader:
            i += 1
            start_ts = time.time()
            scheduler.step()
            model.train()
            images = images.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            outputs = model(images)

            loss = my_loss_fn(myinput=outputs, target=labels)

            loss.backward()
            optimizer.step()

            # gpu_profile(frame=sys._getframe(), event='line', arg=None)

            time_meter.update(time.time() - start_ts)

            if (i + 1) % cfg['training']['print_interval'] == 0:
                fmt_str = "Iter [{:d}/{:d}]  Loss: {:.4f}  Time/Image: {:.4f}"
                print_str = fmt_str.format(
                    i + 1, cfg['training']['train_iters'], loss.item(),
                    time_meter.avg / cfg['training']['batch_size'])

                print(print_str)
                logger.info(print_str)
                writer.add_scalar('loss/train_loss', loss.item(), i + 1)
                time_meter.reset()

            if (i + 1) % cfg['training']['val_interval'] == 0 or \
               (i + 1) == cfg['training']['train_iters']:
                model.eval()
                with torch.no_grad():
                    for i_val, (images_val,
                                labels_val) in tqdm(enumerate(valloader)):
                        images_val = images_val.to(device)
                        labels_val = labels_val.to(device)

                        outputs = model(images_val)
                        val_loss = my_loss_fn(myinput=outputs,
                                              target=labels_val)

                        pred = outputs.data.max(1)[1].cpu().numpy()
                        gt = labels_val.data.cpu().numpy()

                        running_metrics_val.update(gt, pred)
                        val_loss_meter.update(val_loss.item())

                writer.add_scalar('loss/val_loss', val_loss_meter.avg, i + 1)
                logger.info("Iter %d Loss: %.4f" % (i + 1, val_loss_meter.avg))

                score, class_iou = running_metrics_val.get_scores()
                for k, v in score.items():
                    print(k, v)
                    logger.info('{}: {}'.format(k, v))
                    writer.add_scalar('val_metrics/{}'.format(k), v, i + 1)

                for k, v in class_iou.items():
                    logger.info('{}: {}'.format(k, v))
                    writer.add_scalar('val_metrics/cls_{}'.format(k), v, i + 1)

                val_loss_meter.reset()
                running_metrics_val.reset()

                if score["Mean IoU : \t"] >= best_iou:
                    best_iou = score["Mean IoU : \t"]
                    state = {
                        "epoch": i + 1,
                        "model_state": model.state_dict(),
                        "optimizer_state": optimizer.state_dict(),
                        "scheduler_state": scheduler.state_dict(),
                        "best_iou": best_iou,
                    }
                    save_path = os.path.join(
                        writer.file_writer.get_logdir(),
                        "{}_{}_best_model.pkl".format(cfg['model']['arch'],
                                                      cfg['data']['dataset']))
                    torch.save(state, save_path)

            if (i + 1) == cfg['training']['train_iters']:
                flag = False
                break
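
The loops above rely on an averageMeter helper for loss and timing statistics. A minimal sketch of that common pattern, assuming the usual update()/avg/reset() interface (the source repo's version may differ in detail):

class averageMeter(object):
    # Minimal running-average tracker matching the update()/avg/reset() usage above.
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0.0
        self.sum = 0.0
        self.count = 0
        self.avg = 0.0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
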
Example No. 3
def train(cfg, writer, logger):
    
    # Setup seeds
    torch.manual_seed(cfg.get('seed', 1337))
    torch.cuda.manual_seed(cfg.get('seed', 1337))
    np.random.seed(cfg.get('seed', 1337))
    random.seed(cfg.get('seed', 1337))

    # Setup device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Setup Augmentations
    augmentations = cfg['training'].get('augmentations', None)
    data_aug = get_composed_augmentations(augmentations)

    # Setup Dataloader

    paths = {
        'masks': './satellitedata/patchohio_train/gt/',
        'images': './satellitedata/patchohio_train/rgb',
        'nirs': './satellitedata/patchohio_train/nir',
        'swirs': './satellitedata/patchohio_train/swir',
        'vhs': './satellitedata/patchohio_train/vh',
        'vvs': './satellitedata/patchohio_train/vv',
        'redes': './satellitedata/patchohio_train/rede',
        'ndvis': './satellitedata/patchohio_train/ndvi',
        }

    valpaths = {
        'masks': './satellitedata/patchohio_val/gt/',
        'images': './satellitedata/patchohio_val/rgb',
        'nirs': './satellitedata/patchohio_val/nir',
        'swirs': './satellitedata/patchohio_val/swir',
        'vhs': './satellitedata/patchohio_val/vh',
        'vvs': './satellitedata/patchohio_val/vv',
        'redes': './satellitedata/patchohio_val/rede',
        'ndvis': './satellitedata/patchohio_val/ndvi',
        }
  
  
    n_classes = 3
    train_img_paths = [pth for pth in os.listdir(paths['images']) if ('_01_' not in pth) and ('_25_' not in pth)]
    val_img_paths = [pth for pth in os.listdir(valpaths['images']) if ('_01_' not in pth) and ('_25_' not in pth)]
    ntrain = len(train_img_paths)
    nval = len(val_img_paths)
    train_idx = [i for i in range(ntrain)]
    val_idx = [i for i in range(nval)]
    trainds = ImageProvider(MultibandImageType, paths, image_suffix='.png')
    valds = ImageProvider(MultibandImageType, valpaths, image_suffix='.png')
    
    config_path = 'crop_pspnet_config.json'
    with open(config_path, 'r') as f:
        mycfg = json.load(f)
        train_data_path = './satellitedata/'
        print('train_data_path: {}'.format(train_data_path))
        dataset_path, train_dir = os.path.split(train_data_path)
        print('dataset_path: {}'.format(dataset_path) + ',  train_dir: {}'.format(train_dir))
        mycfg['dataset_path'] = dataset_path
    config = Config(**mycfg)

    config = update_config(config, num_channels=12, nb_epoch=50)
    #dataset_train = TrainDataset(trainds, train_idx, config, transforms=augment_flips_color)
    dataset_train = TrainDataset(trainds, train_idx, config, 1)
    dataset_val = TrainDataset(valds, val_idx, config, 1)
    trainloader = data.DataLoader(dataset_train,
                                  batch_size=cfg['training']['batch_size'], 
                                  num_workers=cfg['training']['n_workers'], 
                                  shuffle=True)

    valloader = data.DataLoader(dataset_val,
                                  batch_size=cfg['training']['batch_size'], 
                                  num_workers=cfg['training']['n_workers'], 
                                  shuffle=False)
    # Setup Metrics
    running_metrics_train = runningScore(n_classes)
    running_metrics_val = runningScore(n_classes)

    k = 0
    nbackground = 0
    ncorn = 0
    #ncotton = 0
    #nrice = 0
    nsoybean = 0


    for indata in trainloader:
        k += 1
        gt = indata['seg_label'].data.cpu().numpy()
        nbackground += (gt == 0).sum()
        ncorn += (gt == 1).sum()
        #ncotton += (gt == 2).sum()
        #nrice += (gt == 3).sum()
        nsoybean += (gt == 2).sum()

    print('k = {}'.format(k))
    print('nbackground: {}'.format(nbackground))
    print('ncorn: {}'.format(ncorn))
    #print('ncotton: {}'.format(ncotton))
    #print('nrice: {}'.format(nrice))
    print('nsoybean: {}'.format(nsoybean))
    
    wgts = [1.0, 1.0*nbackground/ncorn, 1.0*nbackground/nsoybean]
    total_wgts = sum(wgts)
    wgt_background = wgts[0]/total_wgts
    wgt_corn = wgts[1]/total_wgts
    #wgt_cotton = wgts[2]/total_wgts
    #wgt_rice = wgts[3]/total_wgts
    wgt_soybean = wgts[2]/total_wgts
    weights = torch.tensor([wgt_background, wgt_corn, wgt_soybean], device=device)

    #weights = torch.autograd.Variable(torch.cuda.FloatTensor([1.0, 1.0, 1.0]))
    

    # Setup Model
    model = get_model(cfg['model'], n_classes).to(device)

    model = torch.nn.DataParallel(model, device_ids=range(torch.cuda.device_count()))

    # Setup optimizer, lr_scheduler and loss function
    optimizer_cls = get_optimizer(cfg)
    optimizer_params = {k:v for k, v in cfg['training']['optimizer'].items() 
                        if k != 'name'}

    optimizer = optimizer_cls(model.parameters(), **optimizer_params)
    logger.info("Using optimizer {}".format(optimizer))

    scheduler = get_scheduler(optimizer, cfg['training']['lr_schedule'])

    loss_fn = get_loss_function(cfg)
    logger.info("Using loss {}".format(loss_fn))

    start_iter = 0
    if cfg['training']['resume'] is not None:
        if os.path.isfile(cfg['training']['resume']):
            logger.info(
                "Loading model and optimizer from checkpoint '{}'".format(cfg['training']['resume'])
            )
            checkpoint = torch.load(cfg['training']['resume'])
            model.load_state_dict(checkpoint["model_state"])
            optimizer.load_state_dict(checkpoint["optimizer_state"])
            scheduler.load_state_dict(checkpoint["scheduler_state"])
            start_iter = checkpoint["epoch"]
            logger.info(
                "Loaded checkpoint '{}' (iter {})".format(
                    cfg['training']['resume'], checkpoint["epoch"]
                )
            )
        else:
            logger.info("No checkpoint found at '{}'".format(cfg['training']['resume']))

    val_loss_meter = averageMeter()
    time_meter = averageMeter()

    best_iou = -100.0
    i = start_iter
    flag = True

    while i <= cfg['training']['train_iters'] and flag:
        for inputdata in trainloader:
            i += 1
            start_ts = time.time()
            scheduler.step()
            model.train()
            images = inputdata['img_data']
            labels = inputdata['seg_label']
            #print('images.size: {}'.format(images.size()))
            #print('labels.size: {}'.format(labels.size()))
            images = images.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            outputs = model(images)

            #print('outputs.size: {}'.format(outputs[1].size()))
            #print('labels.size: {}'.format(labels.size()))

            loss = loss_fn(input=outputs[1], target=labels, weight=weights)

            loss.backward()
            optimizer.step()
            
            time_meter.update(time.time() - start_ts)

            if (i + 1) % cfg['training']['print_interval'] == 0:
                fmt_str = "Iter [{:d}/{:d}]  Loss: {:.4f}  Time/Image: {:.4f}"
                print_str = fmt_str.format(i + 1,
                                           cfg['training']['train_iters'], 
                                           loss.item(),
                                           time_meter.avg / cfg['training']['batch_size'])

                print(print_str)
                logger.info(print_str)
                writer.add_scalar('loss/train_loss', loss.item(), i+1)
                time_meter.reset()

            if (i + 1) % cfg['training']['val_interval'] == 0 or \
               (i + 1) == cfg['training']['train_iters']:
                model.eval()
                with torch.no_grad():
                    for inputdata in valloader:
                        images_val = inputdata['img_data']
                        labels_val = inputdata['seg_label']
                        images_val = images_val.to(device)
                        labels_val = labels_val.to(device)

                        outputs = model(images_val)
                        val_loss = loss_fn(input=outputs, target=labels_val)

                        pred = outputs.data.max(1)[1].cpu().numpy()
                        gt = labels_val.data.cpu().numpy()


                        running_metrics_val.update(gt, pred)
                        val_loss_meter.update(val_loss.item())

                writer.add_scalar('loss/val_loss', val_loss_meter.avg, i+1)
                logger.info("Iter %d Loss: %.4f" % (i + 1, val_loss_meter.avg))

                score, class_iou = running_metrics_val.get_scores()
                for k, v in score.items():
                    print(k, v)
                    logger.info('{}: {}'.format(k, v))
                    writer.add_scalar('val_metrics/{}'.format(k), v, i+1)

                for k, v in class_iou.items():
                    logger.info('{}: {}'.format(k, v))
                    writer.add_scalar('val_metrics/cls_{}'.format(k), v, i+1)

                val_loss_meter.reset()
                running_metrics_val.reset()

                if score["Mean IoU : \t"] >= best_iou:
                    best_iou = score["Mean IoU : \t"]
                    state = {
                        "epoch": i + 1,
                        "model_state": model.state_dict(),
                        "optimizer_state": optimizer.state_dict(),
                        "scheduler_state": scheduler.state_dict(),
                        "best_iou": best_iou,
                    }
                    save_path = os.path.join(writer.file_writer.get_logdir(),
                                             "{}_{}_best_model.pkl".format(
                                                 cfg['model']['arch'],
                                                 cfg['data']['dataset']))
                    torch.save(state, save_path)

            if (i + 1) == cfg['training']['train_iters']:
                flag = False
                break
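
Example 3 derives loss weights by counting pixels per class and dividing the background count by each class count, then normalising. The same idea as a small hedged helper (name and signature are illustrative, not from the source):

import numpy as np
import torch

def inverse_frequency_weights(pixel_counts, device='cpu'):
    # pixel_counts: [n_background, n_class1, n_class2, ...] as counted above.
    counts = np.asarray(pixel_counts, dtype=np.float64)
    wgts = counts[0] / counts   # background -> 1.0, rarer classes -> larger weight
    wgts = wgts / wgts.sum()    # normalise so the weights sum to 1
    return torch.tensor(wgts, dtype=torch.float32, device=device)

# e.g. weights = inverse_frequency_weights([nbackground, ncorn, nsoybean], device=device)
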
Example No. 4
def validate(cfg, args):

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Setup Augmentations
    data_aug = None
    if "validation" in cfg:
        augmentations = cfg["validation"].get("augmentations", None)
        if cfg["data"]["dataset"] == "softmax_cityscapes_convention":
            data_aug = get_composed_augmentations_softmax(augmentations)
        else:
            data_aug = get_composed_augmentations(augmentations)

    # Setup Dataloader
    data_loader = get_loader(cfg["data"]["dataset"])
    data_path = cfg["data"]["path"]

    loader = data_loader(
        data_path,
        config = cfg["data"],
        is_transform=True,
        split=cfg["data"][args.dataset_split],
        img_size=(cfg["data"]["img_rows"], cfg["data"]["img_cols"]),
        augmentations=data_aug,
    )
    n_classes = loader.n_classes
    valloader = data.DataLoader(loader, batch_size=1, num_workers=1)
    
    # Setup Metrics
    running_metrics_val = {"seg": runningScoreSeg(n_classes)}
    if "classifiers" in cfg["data"]:
        for name, classes in cfg["data"]["classifiers"].items():
            running_metrics_val[name] = runningScoreClassifier( len(classes) )
    if "bin_classifiers" in cfg["data"]:
        for name, classes in cfg["data"]["bin_classifiers"].items():
            running_metrics_val[name] = runningScoreClassifier(2)

    # Setup Model
    model = get_model(cfg["model"], n_classes).to(device)
    state = torch.load(args.model_path, map_location="cuda:0")["model_state"]
    state = convert_state_dict(state) # converts from dataParallel module to normal module
    model.load_state_dict(state, strict=False)
    
    if args.bn_fusion:
        model = fuse_bn_recursively(model)

    if args.update_bn:
        print("Reset BatchNorm and recalculate mean/var")
        model.apply(reset_batchnorm)
        model.train()
    else:
        model.eval()  # set batchnorm and dropouts to work in eval mode
    model.to(device)
    total_time = 0
    
    total_params = sum(p.numel() for p in model.parameters())
    print('Parameters: ', total_params )
    
    #stat(model, (3, 1024, 2048))
    torch.backends.cudnn.benchmark=True

    with open(args.output_csv_path, 'a') as output_csv:

        output_csv.write(create_overall_logs_header(running_metrics_val))

        for i, (images, label_dict, fname) in enumerate(valloader):
            images = images.to(device)
            
            torch.cuda.synchronize()
            start_time = time.perf_counter()
            with torch.no_grad(): # deactivates autograd engine, less mem usage
                output_dict = model(images)
            torch.cuda.synchronize()
            elapsed_time = time.perf_counter() - start_time
            
            if args.save_image:
                save_image(images, output_dict, fname, args.output_path, loader=loader)
            
            image_score = []
            
            for name, metrics in running_metrics_val.items(): # update running metrics and record imagewise metrics
                gt_array = label_dict[name].data.cpu().numpy()
                if name+'_loss' in cfg['training'] and cfg['training'][name+'_loss']['name'] == 'l1': # for binary classification
                    pred_array = output_dict[name].data.cpu().numpy()
                    pred_array = np.sign(pred_array)
                    pred_array[pred_array == -1] = 0
                    gt_array[gt_array == -1] = 0
                else:
                    pred_array = output_dict[name].data.max(1)[1].cpu().numpy()

                if name == "seg" or name == "softmax":
                    image_score.append( "%.3f" %metrics.get_image_score(gt_array, pred_array) )
                else:
                    imagewise_score = softmax(np.squeeze(
                        output_dict[name].data.cpu().numpy()
                    )).round(3)
                    image_score.append( "%.3f" %(imagewise_score[gt_array[0]]) )
                    image_score.append( str(imagewise_score) ) # append raw probability results for non-segmentation task
                    image_score.append( "pred %s label %s" %(np.argmax(imagewise_score), gt_array[0]))
                
                metrics.update(gt_array, pred_array)

            output_csv.write( '%s, %.4f, %s\n' %(fname[0], 1 / elapsed_time, ",".join(image_score)) ) # record imagewise metrics

            if args.measure_time:
                total_time += elapsed_time
                print(
                    "Iter {0:5d}: {1:3.5f} fps {2}".format(
                        i + 1, 1 / elapsed_time, " ".join(image_score)
                    )
                )

    print("Total Frame Rate = %.2f fps" %(i/total_time ))

    if args.update_bn:
        model = torch.nn.DataParallel(model, device_ids=range(torch.cuda.device_count()))
        state2 = {"model_state": model.state_dict()}
        torch.save(state2, 'hardnet_cityscapes_mod.pth')

    with open(args.miou_logs_path, 'a') as main_output_csv: # record overall metrics
        main_output_csv.write( '%s\n' %args.output_csv_path )

        for name, metrics in running_metrics_val.items():
            overall, classwise = metrics.get_scores()
            
            for k, v in overall.items():
                print("{}_{}: {}".format(name, k, v))
                main_output_csv.write("%s,%s,%s\n" %(name, k, v))

            for metric_name, metric in classwise.items():
                for k, v in metric.items():
                    print("{}_{}_{}: {}".format(name, metric_name, k, v))
                    main_output_csv.write( "%s,%s,%s,%s\n" %(name, metric_name, k, v))
            
            confusion_matrix = np.round(metrics.confusion_matrix, 3)
            print("confusion matrix:\n%s" %confusion_matrix)
            main_output_csv.write("%s\n" %(
                "\n".join(str(i) for i in confusion_matrix)
            ))
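
The per-image classifier scores above are passed through a softmax helper before logging. Assuming it is the usual numerically stable softmax over a 1-D score vector, a sketch would be:

import numpy as np

def softmax(x):
    # Numerically stable softmax for a 1-D array of raw scores.
    x = np.asarray(x, dtype=np.float64)
    e = np.exp(x - x.max())
    return e / e.sum()
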
Example No. 5
def train(cfg, logger, logdir):
    # Setup seeds
    init_seed(11733, en_cudnn=False)

    # Setup Augmentations
    train_augmentations = cfg["training"].get("train_augmentations", None)
    t_data_aug = get_composed_augmentations(train_augmentations)
    val_augmentations = cfg["validating"].get("val_augmentations", None)
    v_data_aug = get_composed_augmentations(val_augmentations)

    # Setup Dataloader

    path_n = cfg["model"]["path_num"]

    data_loader = get_loader(cfg["data"]["dataset"])
    data_path = cfg["data"]["path"]

    t_loader = data_loader(data_path,split=cfg["data"]["train_split"],augmentations=t_data_aug,path_num=path_n)
    v_loader = data_loader(data_path,split=cfg["data"]["val_split"],augmentations=v_data_aug,path_num=path_n)

    trainloader = data.DataLoader(t_loader,
                                  batch_size=cfg["training"]["batch_size"],
                                  num_workers=cfg["training"]["n_workers"],
                                  shuffle=True,
                                  drop_last=True  )
    valloader = data.DataLoader(v_loader,
                                batch_size=cfg["validating"]["batch_size"],
                                num_workers=cfg["validating"]["n_workers"] )

    logger.info("Using training seting {}".format(cfg["training"]))
    
    # Setup Metrics
    running_metrics_val = runningScore(t_loader.n_classes)

    # Setup Model and Loss
    loss_fn = get_loss_function(cfg["training"])
    teacher = get_model(cfg["teacher"], t_loader.n_classes)
    model = get_model(cfg["model"],t_loader.n_classes, loss_fn, cfg["training"]["resume"],teacher)
    logger.info("Using loss {}".format(loss_fn))

    # Setup optimizer
    optimizer = get_optimizer(cfg["training"], model)

    # Setup Multi-GPU
    model = DataParallelModel(model).cuda()

    #Initialize training param
    cnt_iter = 0
    best_iou = 0.0
    time_meter = averageMeter()

    while cnt_iter <= cfg["training"]["train_iters"]:
        for (f_img, labels) in trainloader:
            cnt_iter += 1
            model.train()
            optimizer.zero_grad()

            start_ts = time.time()
            outputs = model(f_img,labels,pos_id=cnt_iter%path_n)

            seg_loss = gather(outputs, 0)
            seg_loss = torch.mean(seg_loss)

            seg_loss.backward()
            time_meter.update(time.time() - start_ts)

            optimizer.step()

            if (cnt_iter + 1) % cfg["training"]["print_interval"] == 0:
                fmt_str = "Iter [{:d}/{:d}]  Loss: {:.4f}  Time/Image: {:.4f}"
                print_str = fmt_str.format(
                                            cnt_iter + 1,
                                            cfg["training"]["train_iters"],
                                            seg_loss.item(),
                                            time_meter.avg / cfg["training"]["batch_size"], )

                print(print_str)
                logger.info(print_str)
                time_meter.reset()

            if (cnt_iter + 1) % cfg["training"]["val_interval"] == 0 or (cnt_iter + 1) == cfg["training"]["train_iters"]:
                model.eval()
                with torch.no_grad():
                    for i_val, (f_img_val, labels_val) in tqdm(enumerate(valloader)):
                        
                        outputs = model(f_img_val,pos_id=i_val%path_n)
                        outputs = gather(outputs, 0, dim=0)
                        
                        pred = outputs.data.max(1)[1].cpu().numpy()
                        gt = labels_val.data.cpu().numpy()

                        running_metrics_val.update(gt, pred)

                score, class_iou = running_metrics_val.get_scores()
                for k, v in score.items():
                    print(k, v)
                    logger.info("{}: {}".format(k, v))

                for k, v in class_iou.items():
                    logger.info("{}: {}".format(k, v))

                running_metrics_val.reset()

                if score["Mean IoU : \t"] >= best_iou:
                    best_iou = score["Mean IoU : \t"]
                    state = {
                        "epoch": cnt_iter + 1,
                        "model_state": clean_state_dict(model.module.state_dict(),'teacher'),
                        "best_iou": best_iou,
                    }
                    save_path = os.path.join(logdir,
                        "{}_{}_best_model.pkl".format(cfg["model"]["arch"], cfg["data"]["dataset"]),
                    )
                    torch.save(state, save_path)
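
All of these training loops read score["Mean IoU : \t"] from runningScore.get_scores(). A minimal confusion-matrix sketch of that metric class, kept close to the usage above (the source repo's version reports more statistics):

import numpy as np

class runningScore(object):
    # Confusion-matrix based segmentation metrics; only the pieces used above are sketched.
    def __init__(self, n_classes):
        self.n_classes = n_classes
        self.confusion_matrix = np.zeros((n_classes, n_classes))

    def update(self, label_trues, label_preds):
        mask = (label_trues >= 0) & (label_trues < self.n_classes)
        hist = np.bincount(
            self.n_classes * label_trues[mask].astype(int) + label_preds[mask],
            minlength=self.n_classes ** 2,
        ).reshape(self.n_classes, self.n_classes)
        self.confusion_matrix += hist

    def get_scores(self):
        hist = self.confusion_matrix
        iu = np.diag(hist) / (hist.sum(axis=1) + hist.sum(axis=0) - np.diag(hist))
        score = {"Overall Acc : \t": np.diag(hist).sum() / hist.sum(),
                 "Mean IoU : \t": np.nanmean(iu)}
        class_iou = dict(zip(range(self.n_classes), iu))
        return score, class_iou

    def reset(self):
        self.confusion_matrix = np.zeros((self.n_classes, self.n_classes))
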
Example No. 6
def train(cfg, writer, logger, args):
    # cfg

    # Setup seeds
    torch.manual_seed(cfg.get("seed", 1337))
    torch.cuda.manual_seed(cfg.get("seed", 1337))
    np.random.seed(cfg.get("seed", 1337))
    random.seed(cfg.get("seed", 1337))

    # Setup device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Setup Augmentations
    augmentations = cfg["training"].get("augmentations", None)
    data_aug = get_composed_augmentations(augmentations)

    # Setup Dataloader
    data_loader = get_loader(cfg["data"]["dataset"])
    data_path = cfg["data"]["path"]

    t_loader = data_loader(
        data_path,
        is_transform=True,
        split=cfg["data"]["train_split"],
        img_size=(cfg["data"]["img_rows"], cfg["data"]["img_cols"]),
        augmentations=data_aug,
    )

    v_loader = data_loader(
        data_path,
        is_transform=True,
        split=cfg["data"]["val_split"],
        img_size=(1024, 2048),
    )

    n_classes = t_loader.n_classes
    trainloader = data.DataLoader(
        t_loader,
        batch_size=cfg["training"]["batch_size"],
        num_workers=cfg["training"]["n_workers"],
        shuffle=True,
    )

    valloader = data.DataLoader(v_loader,
                                batch_size=cfg["training"]["batch_size"],
                                num_workers=cfg["training"]["n_workers"])

    # Setup Metrics
    running_metrics_val = runningScore(n_classes)

    # Setup Model
    model = FASSDNet(n_classes=19, alpha=args.alpha).to(device)  # n_classes hard-coded to 19 (Cityscapes)

    total_params = sum(p.numel() for p in model.parameters())
    print('Parameters:', total_params)
    model.apply(weights_init)

    # Non-strict ImageNet pre-train
    pretrained_path = 'weights/imagenet_weights.pth'
    checkpoint = torch.load(pretrained_path)
    model_dict = {}
    state_dict = model.state_dict()

    # Keep only the checkpoint tensors whose names (prefixed with 'base.') and
    # shapes match the backbone of this model.
    for k, v in checkpoint.items():
        if ('base.' + k in state_dict) and (state_dict['base.' + k].shape == checkpoint[k].shape):
            model_dict['base.' + k] = v

    state_dict.update(model_dict)  # Updated weights with ImageNet pretraining
    model.load_state_dict(state_dict)

    # Multi-gpu model
    model = torch.nn.DataParallel(model,
                                  device_ids=range(torch.cuda.device_count()))

    # Setup optimizer, lr_scheduler and loss function
    optimizer_cls = get_optimizer(cfg)
    optimizer_params = {
        k: v
        for k, v in cfg["training"]["optimizer"].items() if k != "name"
    }

    optimizer = optimizer_cls(model.parameters(), **optimizer_params)
    print("Using optimizer {}".format(optimizer))

    scheduler = get_scheduler(optimizer, cfg["training"]["lr_schedule"])

    loss_fn = get_loss_function(cfg)
    print("Using loss {}".format(loss_fn))

    start_iter = 0
    if cfg["training"]["resume"] is not None:
        if os.path.isfile(cfg["training"]["resume"]):

            print_str = "Finetuning model from '{}'".format(
                cfg["training"]["finetune"])
            if logger is not None:
                logger.info(print_str)
            print(print_str)

            checkpoint = torch.load(cfg["training"]["resume"])
            model.load_state_dict(checkpoint["model_state"])
            optimizer.load_state_dict(checkpoint["optimizer_state"])
            scheduler.load_state_dict(checkpoint["scheduler_state"])
            start_iter = checkpoint["epoch"]

            print_str = "Loaded checkpoint '{}' (iter {})".format(
                cfg["training"]["resume"], checkpoint["epoch"])
            print(print_str)
            if logger is not None:
                logger.info(print_str)
        else:
            print_str = "No checkpoint found at '{}'".format(
                cfg["training"]["resume"])
            print(print_str)
            if logger is not None:
                logger.info(print_str)

    if cfg["training"]["finetune"] is not None:
        if os.path.isfile(cfg["training"]["finetune"]):
            logger.info(
                "Loading model and optimizer from checkpoint '{}'".format(
                    cfg["training"]["finetune"]))
            checkpoint = torch.load(cfg["training"]["finetune"])
            model.load_state_dict(checkpoint["model_state"])

    val_loss_meter = averageMeter()
    time_meter = averageMeter()

    best_iou = -100.0
    i = start_iter
    flag = True
    loss_all = 0
    loss_n = 0
    sys.stdout.flush()

    while i <= cfg["training"]["train_iters"] and flag:
        for (images, labels, _) in trainloader:
            i += 1
            start_ts = time.time()
            scheduler.step()
            model.train()
            images = images.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            outputs = model(images)

            loss = loss_fn(input=outputs, target=labels)
            loss.backward()
            optimizer.step()
            c_lr = scheduler.get_lr()

            time_meter.update(time.time() - start_ts)
            loss_all += loss.item()
            loss_n += 1

            if (i + 1) % cfg["training"]["print_interval"] == 0:
                fmt_str = "Iter [{:d}/{:d}]  Loss: {:.4f}  Time/Image: {:.4f}  lr={:.6f}"
                print_str = fmt_str.format(
                    i + 1,
                    cfg["training"]["train_iters"],
                    loss_all / loss_n,
                    time_meter.avg / cfg["training"]["batch_size"],
                    c_lr[0],
                )

                print(print_str)
                if logger is not None:
                    logger.info(print_str)

                writer.add_scalar("loss/train_loss", loss.item(), i + 1)
                time_meter.reset()

            if (i + 1) % cfg["training"]["val_interval"] == 0 or (
                    i + 1) == cfg["training"]["train_iters"]:
                torch.cuda.empty_cache()
                model.eval()
                loss_all = 0
                loss_n = 0
                with torch.no_grad():
                    # for i_val, (images_val, labels_val, _) in tqdm(enumerate(valloader)):
                    for i_val, (images_val, labels_val,
                                _) in enumerate(valloader):
                        images_val = images_val.to(device)
                        labels_val = labels_val.to(device)

                        outputs = model(images_val)
                        val_loss = loss_fn(input=outputs, target=labels_val)

                        pred = outputs.data.max(1)[1].cpu().numpy()
                        gt = labels_val.data.cpu().numpy()

                        running_metrics_val.update(gt, pred)
                        val_loss_meter.update(val_loss.item())

                writer.add_scalar("loss/val_loss", val_loss_meter.avg, i + 1)

                print_str = "Iter %d Val Loss: %.4f" % (i + 1,
                                                        val_loss_meter.avg)
                if logger is not None:
                    logger.info(print_str)
                print(print_str)

                score, class_iou = running_metrics_val.get_scores()
                for k, v in score.items():
                    print_str = "{}: {}".format(k, v)
                    if logger is not None:
                        logger.info(print_str)
                    print(print_str)

                    writer.add_scalar("val_metrics/{}".format(k), v, i + 1)

                for k, v in class_iou.items():
                    print_str = "{}: {}".format(k, v)
                    if logger is not None:
                        logger.info(print_str)
                    print(print_str)

                    writer.add_scalar("val_metrics/cls_{}".format(k), v, i + 1)

                val_loss_meter.reset()
                running_metrics_val.reset()

                state = {
                    "epoch": i + 1,
                    "model_state": model.state_dict(),
                    "optimizer_state": optimizer.state_dict(),
                    "scheduler_state": scheduler.state_dict(),
                }
                save_path = os.path.join(
                    writer.file_writer.get_logdir(),
                    "{}_{}_checkpoint.pkl".format(cfg["model"]["arch"],
                                                  cfg["data"]["dataset"]),
                )
                torch.save(state, save_path)

                if score["Mean IoU : \t"] >= best_iou:  # Save best model (mIoU)
                    best_iou = score["Mean IoU : \t"]
                    state = {
                        "epoch": i + 1,
                        "model_state": model.state_dict(),
                        "best_iou": best_iou,
                    }
                    save_path = os.path.join(
                        writer.file_writer.get_logdir(),
                        "{}_{}_best_model.pkl".format(cfg["model"]["arch"],
                                                      cfg["data"]["dataset"]),
                    )
                    torch.save(state, save_path)
                torch.cuda.empty_cache()

            if (i + 1) == cfg["training"]["train_iters"]:
                flag = False
                break
            sys.stdout.flush()  # Added
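
Example 6's non-strict ImageNet initialisation (the loop over checkpoint.items() above) can be summarised as a small helper. A hedged sketch, assuming the backbone parameters are prefixed with 'base.' exactly as in that loop:

import torch

def load_backbone_weights(model, pretrained_path, prefix='base.'):
    # Copy only the checkpoint tensors whose prefixed name exists in the model
    # with an identical shape; everything else keeps its current initialisation.
    checkpoint = torch.load(pretrained_path, map_location='cpu')
    state_dict = model.state_dict()
    matched = {prefix + k: v for k, v in checkpoint.items()
               if prefix + k in state_dict and state_dict[prefix + k].shape == v.shape}
    state_dict.update(matched)
    model.load_state_dict(state_dict)
    return len(matched)  # number of tensors actually transferred
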
Example No. 7
def train(cfg, writer, logger):

    # Setup seeds
    torch.manual_seed(cfg.get('seed', 1337))
    torch.cuda.manual_seed(cfg.get('seed', 1337))
    np.random.seed(cfg.get('seed', 1337))
    random.seed(cfg.get('seed', 1337))

    # Setup device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Setup Augmentations
    augmentations = cfg['training'].get('augmentations', None)
    data_aug = get_composed_augmentations(augmentations)

    # Setup Dataloader
    data_loader = get_loader(cfg['data']['dataset'])
    data_path = cfg['data']['path']

    if not 'fold' in cfg['data'].keys():
        cfg['data']['fold'] = None

    t_loader = data_loader(
        data_path,
        is_transform=True,
        split=cfg['data']['train_split'],
        img_size=[cfg['data']['img_rows'], cfg['data']['img_cols']],
        augmentations=data_aug,
        fold=cfg['data']['fold'],
        n_classes=cfg['data']['n_classes'])

    v_loader = data_loader(
        data_path,
        is_transform=True,
        split=cfg['data']['val_split'],
        img_size=[cfg['data']['img_rows'], cfg['data']['img_cols']],
        fold=cfg['data']['fold'],
        n_classes=cfg['data']['n_classes'])

    n_classes = t_loader.n_classes
    trainloader = data.DataLoader(t_loader,
                                  batch_size=cfg['training']['batch_size'],
                                  num_workers=cfg['training']['n_workers'],
                                  shuffle=True)

    valloader = data.DataLoader(v_loader,
                                batch_size=1,
                                num_workers=cfg['training']['n_workers'])

    logger.info("Training on fold {}".format(cfg['data']['fold']))
    # Setup Metrics
    running_metrics_val = runningScore(n_classes)

    # Setup Model
    model = get_model(cfg['model'], n_classes).to(device)
    # NOTE: `args` is not a parameter of train(); this relies on a module-level args object.
    if args.model_path != "fcn8s_pascal_1_26.pkl": # Default Value
        state = convert_state_dict(torch.load(args.model_path)["model_state"])
        if cfg['model']['use_scale']:
            model = load_my_state_dict(model, state)
            model.freeze_weights_extractor()
        else:
            model.load_state_dict(state)
            model.freeze_weights_extractor()

    model = torch.nn.DataParallel(model, device_ids=range(torch.cuda.device_count()))


    # Setup optimizer, lr_scheduler and loss function
    optimizer_cls = get_optimizer(cfg)
    optimizer_params = {k:v for k, v in cfg['training']['optimizer'].items()
                        if k != 'name'}

    optimizer = optimizer_cls(model.parameters(), **optimizer_params)
    logger.info("Using optimizer {}".format(optimizer))

    scheduler = get_scheduler(optimizer, cfg['training']['lr_schedule'])

    loss_fn = get_loss_function(cfg)
    logger.info("Using loss {}".format(loss_fn))

    start_iter = 0
    if cfg['training']['resume'] is not None:
        if os.path.isfile(cfg['training']['resume']):
            logger.info(
                "Loading model and optimizer from checkpoint '{}'".format(cfg['training']['resume'])
            )
            checkpoint = torch.load(cfg['training']['resume'])
            model.load_state_dict(checkpoint["model_state"])
            optimizer.load_state_dict(checkpoint["optimizer_state"])
            scheduler.load_state_dict(checkpoint["scheduler_state"])
            start_iter = checkpoint["epoch"]
            logger.info(
                "Loaded checkpoint '{}' (iter {})".format(
                    cfg['training']['resume'], checkpoint["epoch"]
                )
            )
        else:
            logger.info("No checkpoint found at '{}'".format(cfg['training']['resume']))

    val_loss_meter = averageMeter()
    time_meter = averageMeter()

    best_iou = -100.0
    i = start_iter
    flag = True

    while i <= cfg['training']['train_iters'] and flag:
        for (images, labels) in trainloader:
#            import matplotlib.pyplot as plt
#            plt.figure(1);plt.imshow(np.transpose(images[0], (1,2,0)));plt.figure(2); plt.imshow(labels[0]); plt.show()

            i += 1
            start_ts = time.time()
            scheduler.step()
            model.train()
            images = images.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            outputs = model(images)
            loss = loss_fn(input=outputs, target=labels)

            loss.backward()
            optimizer.step()

            time_meter.update(time.time() - start_ts)

            if (i + 1) % cfg['training']['print_interval'] == 0:
                fmt_str = "Iter [{:d}/{:d}]  Loss: {:.4f}  Time/Image: {:.4f}"
                print_str = fmt_str.format(i + 1,
                                           cfg['training']['train_iters'],
                                           loss.item(),
                                           time_meter.avg / cfg['training']['batch_size'])

                print(print_str)
                logger.info(print_str)
                writer.add_scalar('loss/train_loss', loss.item(), i+1)
                time_meter.reset()

            if (i + 1) == cfg['training']['train_iters']:
                flag = False
                state = {
                    "epoch": i + 1,
                    "model_state": model.state_dict(),
                    "optimizer_state": optimizer.state_dict(),
                    "scheduler_state": scheduler.state_dict(),
                    "best_iou": best_iou,
                }
                save_path = os.path.join(writer.file_writer.get_logdir(),
                                         "{}_{}_best_model.pkl".format(
                                             cfg['model']['arch'],
                                             cfg['data']['dataset']))
                torch.save(state, save_path)
                break
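
Example 7 calls model.freeze_weights_extractor() after loading a pretrained checkpoint. The method's body is not shown in the source; a plausible sketch of what such a freeze step usually does (the 'features' prefix is an assumption) is:

def freeze_weights_extractor(model, prefix='features'):
    # Stop gradients through the pretrained feature-extractor layers so that
    # only the remaining (e.g. classifier/decoder) parameters are updated.
    for name, param in model.named_parameters():
        if name.startswith(prefix):
            param.requires_grad = False
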
Example No. 8
def test(cfg, areaname):

    # Setup seeds
    torch.manual_seed(cfg.get('seed', 1337))
    torch.cuda.manual_seed(cfg.get('seed', 1337))
    np.random.seed(cfg.get('seed', 1337))
    random.seed(cfg.get('seed', 1337))

    # Setup device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Setup Augmentations
    augmentations = cfg['training'].get('augmentations', None)
    data_aug = get_composed_augmentations(augmentations)

    # Setup Dataloader
    datapath = '/home/chengjjang/Projects/deepres/SatelliteData/{}/'.format(
        areaname)
    paths = {
        'masks': '{}/patch{}_train/gt'.format(datapath, areaname),
        'images': '{}/patch{}_train/rgb'.format(datapath, areaname),
        'nirs': '{}/patch{}_train/nir'.format(datapath, areaname),
        'swirs': '{}/patch{}_train/swir'.format(datapath, areaname),
        'vhs': '{}/patch{}_train/vh'.format(datapath, areaname),
        'vvs': '{}/patch{}_train/vv'.format(datapath, areaname),
        'redes': '{}/patch{}_train/rede'.format(datapath, areaname),
        'ndvis': '{}/patch{}_train/ndvi'.format(datapath, areaname),
    }

    valpaths = {
        'masks': '{}/patch{}_val/gt'.format(datapath, areaname),
        'images': '{}/patch{}_val/rgb'.format(datapath, areaname),
        'nirs': '{}/patch{}_val/nir'.format(datapath, areaname),
        'swirs': '{}/patch{}_val/swir'.format(datapath, areaname),
        'vhs': '{}/patch{}_val/vh'.format(datapath, areaname),
        'vvs': '{}/patch{}_val/vv'.format(datapath, areaname),
        'redes': '{}/patch{}_val/rede'.format(datapath, areaname),
        'ndvis': '{}/patch{}_val/ndvi'.format(datapath, areaname),
    }

    n_classes = 3
    train_img_paths = [
        pth for pth in os.listdir(paths['images'])
        if ('_01_' not in pth) and ('_25_' not in pth)
    ]
    val_img_paths = [
        pth for pth in os.listdir(valpaths['images'])
        if ('_01_' not in pth) and ('_25_' not in pth)
    ]
    ntrain = len(train_img_paths)
    nval = len(val_img_paths)
    train_idx = list(range(ntrain))
    val_idx = list(range(nval))
    trainds = ImageProvider(MultibandImageType, paths, image_suffix='.png')
    valds = ImageProvider(MultibandImageType, valpaths, image_suffix='.png')

    print('valds.im_names: {}'.format(valds.im_names))

    config_path = 'crop_pspnet_config.json'
    with open(config_path, 'r') as f:
        mycfg = json.load(f)
        train_data_path = '{}/patch{}_train'.format(datapath, areaname)
        dataset_path, train_dir = os.path.split(train_data_path)
        mycfg['dataset_path'] = dataset_path
    config = Config(**mycfg)

    config = update_config(config, num_channels=12, nb_epoch=50)
    #dataset_train = TrainDataset(trainds, train_idx, config, transforms=augment_flips_color)
    dataset_train = TrainDataset(trainds, train_idx, config, 1)
    dataset_val = ValDataset(valds, val_idx, config, 1)
    trainloader = data.DataLoader(dataset_train,
                                  batch_size=cfg['training']['batch_size'],
                                  num_workers=cfg['training']['n_workers'],
                                  shuffle=True)

    valloader = data.DataLoader(dataset_val,
                                batch_size=cfg['training']['batch_size'],
                                num_workers=cfg['training']['n_workers'],
                                shuffle=False)
    # Setup Metrics
    running_metrics_train = runningScore(n_classes)
    running_metrics_val = runningScore(n_classes)

    nbackground = 1116403140
    ncorn = 44080178
    nsoybean = 316698122

    print('nbackground: {}'.format(nbackground))
    print('ncorn: {}'.format(ncorn))
    print('nsoybean: {}'.format(nsoybean))

    wgts = [1.0, 1.0 * nbackground / ncorn, 1.0 * nbackground / nsoybean]
    total_wgts = sum(wgts)
    wgt_background = wgts[0] / total_wgts
    wgt_corn = wgts[1] / total_wgts
    wgt_soybean = wgts[2] / total_wgts
    weights = torch.tensor([wgt_background, wgt_corn, wgt_soybean],
                           device=device)

    # Setup Model
    model = get_model(cfg['model'], n_classes).to(device)

    model = torch.nn.DataParallel(model,
                                  device_ids=range(torch.cuda.device_count()))

    # Setup optimizer, lr_scheduler and loss function
    optimizer_cls = get_optimizer(cfg)
    optimizer_params = {
        k: v
        for k, v in cfg['training']['optimizer'].items() if k != 'name'
    }

    optimizer = optimizer_cls(model.parameters(), **optimizer_params)

    scheduler = get_scheduler(optimizer, cfg['training']['lr_schedule'])

    loss_fn = get_loss_function(cfg)

    start_iter = 0
    runpath = '/home/chengjjang/arisia/CropPSPNet/runs/pspnet_crop_{}'.format(
        areaname)
    modelpath = glob.glob('{}/*/*_best_model.pkl'.format(runpath))[0]
    print('modelpath: {}'.format(modelpath))
    checkpoint = torch.load(modelpath)
    model.load_state_dict(checkpoint["model_state"])

    val_loss_meter = averageMeter()
    time_meter = averageMeter()

    best_iou = -100.0

    respath = '{}_results_val'.format(areaname)
    os.makedirs(respath, exist_ok=True)

    model.eval()
    with torch.no_grad():
        for inputdata in valloader:
            imname_val = inputdata['img_name']
            images_val = inputdata['img_data']
            labels_val = inputdata['seg_label']
            images_val = images_val.to(device)
            labels_val = labels_val.to(device)

            print('imname_val: {}'.format(imname_val))

            outputs = model(images_val)
            val_loss = loss_fn(input=outputs, target=labels_val)

            pred = outputs.data.max(1)[1].cpu().numpy()
            gt = labels_val.data.cpu().numpy()

            dname = imname_val[0].split('.png')[0]
            np.save('{}/pred'.format(respath) + dname + '.npy', pred)
            np.save('{}/gt'.format(respath) + dname + '.npy', gt)
            np.save('{}/output'.format(respath) + dname + '.npy',
                    outputs.data.cpu().numpy())

            running_metrics_val.update(gt, pred)
            val_loss_meter.update(val_loss.item())

    #writer.add_scalar('loss/val_loss', val_loss_meter.avg, i+1)
    #logger.info("Iter %d Loss: %.4f" % (i + 1, val_loss_meter.avg))
    print('Test loss: {}'.format(val_loss_meter.avg))

    score, class_iou = running_metrics_val.get_scores()
    for k, v in score.items():
        print('val_metrics, {}: {}'.format(k, v))

    for k, v in class_iou.items():
        print('val_metrics, {}: {}'.format(k, v))

    val_loss_meter.reset()
    running_metrics_val.reset()
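
The test loop above saves pred<name>.npy / gt<name>.npy pairs into the <areaname>_results_val directory. A hedged sketch of how those arrays could be re-scored offline with the same runningScore metric (the function name and file-pairing logic are assumptions based on the np.save calls above):

import glob
import os
import numpy as np

def rescore_saved_predictions(respath, n_classes=3):
    metrics = runningScore(n_classes)  # same metric class used in the loop above
    for pred_file in sorted(glob.glob(os.path.join(respath, 'pred*.npy'))):
        # matching ground-truth file: same name with the 'pred' prefix swapped for 'gt'
        gt_file = os.path.join(respath, 'gt' + os.path.basename(pred_file)[len('pred'):])
        metrics.update(np.load(gt_file), np.load(pred_file))
    return metrics.get_scores()
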
Example No. 9
def train(cfg, writer, logger):

    # Setup Augmentations
    augmentations = cfg.train.augment
    logger.info(f'using augments: {augmentations}')
    data_aug = get_composed_augmentations(augmentations)

    # Setup Dataloader
    data_loader = get_loader(cfg.data.dataloader)
    data_path = cfg.data.path
    logger.info("data path: {}".format(data_path))

    t_loader = data_loader(
        data_path,
        data_format=cfg.data.format,
        norm=cfg.data.norm,
        split='train',
        split_root=cfg.data.split,
        augments=data_aug,
        logger=logger,
        log=cfg.data.log,
        ENL=cfg.data.ENL,
    )

    v_loader = data_loader(
        data_path,
        data_format=cfg.data.format,
        split='val',
        log=cfg.data.log,
        split_root=cfg.data.split,
        logger=logger,
        ENL=cfg.data.ENL,
    )

    train_data_len = len(t_loader)
    logger.info(
        f'num of train samples: {train_data_len} \nnum of val samples: {len(v_loader)}'
    )

    batch_size = cfg.train.batch_size
    epoch = cfg.train.epoch
    train_iter = int(np.ceil(train_data_len / batch_size) * epoch)
    logger.info(f'total train iter: {train_iter}')

    trainloader = data.DataLoader(t_loader,
                                  batch_size=batch_size,
                                  num_workers=cfg.train.n_workers,
                                  shuffle=True,
                                  persistent_workers=True,
                                  drop_last=True)

    valloader = data.DataLoader(
        v_loader,
        batch_size=cfg.test.batch_size,
        # persis
        num_workers=cfg.train.n_workers,
    )

    # Setup Model
    device = f'cuda:{cfg.train.gpu[0]}'
    model = get_model(cfg.model).to(device)
    input_size = (cfg.model.in_channels, 512, 512)
    logger.info(f"Using Model: {cfg.model.arch}")
    # logger.info(f'model summary: {summary(model, input_size=(input_size, input_size), is_complex=False)}')
    model = torch.nn.DataParallel(model, device_ids=cfg.gpu)  # automatic multi-GPU execution; this works well

    # Setup optimizer, lr_scheduler and loss function
    optimizer_cls = get_optimizer(cfg)
    optimizer_params = {
        k: v
        for k, v in vars(cfg.train.optimizer).items()
        if k not in ('name', 'wrap')
    }
    optimizer = optimizer_cls(model.parameters(), **optimizer_params)
    logger.info("Using optimizer {}".format(optimizer))
    if hasattr(cfg.train.optimizer,
               'wrap') and cfg.train.optimizer.wrap == 'lars':
        optimizer = LARS(optimizer=optimizer)
        logger.info(f'warp optimizer with {cfg.train.optimizer.wrap}')
    scheduler = get_scheduler(optimizer, cfg.train.lr)
    # loss_fn = get_loss_function(cfg)
    # logger.info(f"Using loss ,{str(cfg.train.loss)}")

    # load checkpoints
    val_cls_1_acc = 0
    best_cls_1_acc_now = 0
    best_cls_1_acc_iter_now = 0
    val_macro_OA = 0
    best_macro_OA_now = 0
    best_macro_OA_iter_now = 0
    start_iter = 0
    if cfg.train.resume is not None:
        if os.path.isfile(cfg.train.resume):
            logger.info(
                "Loading model and optimizer from checkpoint '{}'".format(
                    cfg.train.resume))

            # load model state
            checkpoint = torch.load(cfg.train.resume)
            model.load_state_dict(checkpoint["model_state"])
            optimizer.load_state_dict(checkpoint["optimizer_state"])
            scheduler.load_state_dict(checkpoint["scheduler_state"])
            # best_cls_1_acc_now = checkpoint["best_cls_1_acc_now"]
            # best_cls_1_acc_iter_now = checkpoint["best_cls_1_acc_iter_now"]

            start_iter = checkpoint["epoch"]
            logger.info("Loaded checkpoint '{}' (iter {})".format(
                cfg.train.resume, checkpoint["epoch"]))

            # copy tensorboard files
            resume_src_dir = osp.split(cfg.train.resume)[0]
            # shutil.copytree(resume_src_dir, writer.get_logdir())
            for file in os.listdir(resume_src_dir):
                if not ('.log' in file or '.yml' in file
                        or '_last_model' in file):
                    # if 'events.out.tfevents' in file:
                    resume_dst_dir = writer.get_logdir()
                    fu.copy(
                        osp.join(resume_src_dir, file),
                        resume_dst_dir,
                    )

        else:
            logger.info("No checkpoint found at '{}'".format(cfg.train.resume))

    data_range = 255
    if cfg.data.log:
        data_range = np.log(data_range)
    # data_range /= 350

    # Setup Metrics
    running_metrics_val = runningScore(2)
    running_metrics_train = runningScore(2)
    val_loss_meter = averageMeter()
    train_time_meter = averageMeter()
    train_loss_meter = averageMeter()
    val_psnr_meter = averageMeter()
    val_ssim_meter = averageMeter()

    # train
    it = start_iter
    train_start_time = time.time()
    train_val_start_time = time.time()
    model.train()
    while it < train_iter:
        for clean, noisy, _ in trainloader:
            it += 1

            noisy = noisy.to(device, dtype=torch.float32)
            # noisy /= 350
            mask1, mask2 = rand_pool.generate_mask_pair(noisy)
            noisy_sub1 = rand_pool.generate_subimages(noisy, mask1)
            noisy_sub2 = rand_pool.generate_subimages(noisy, mask2)

            # preparing for the regularization term
            with torch.no_grad():
                noisy_denoised = model(noisy)
            noisy_sub1_denoised = rand_pool.generate_subimages(
                noisy_denoised, mask1)
            noisy_sub2_denoised = rand_pool.generate_subimages(
                noisy_denoised, mask2)
            # print(rand_pool.operation_seed_counter)

            # for ii, param in enumerate(model.parameters()):
            #     if torch.sum(torch.isnan(param.data)):
            #         print(f'{ii}: nan parameters')

            # calculating the loss
            noisy_output = model(noisy_sub1)
            noisy_target = noisy_sub2
            if cfg.train.loss.gamma.const:
                gamma = cfg.train.loss.gamma.base
            else:
                gamma = it / train_iter * cfg.train.loss.gamma.base

            diff = noisy_output - noisy_target
            exp_diff = noisy_sub1_denoised - noisy_sub2_denoised
            loss1 = torch.mean(diff**2)
            loss2 = gamma * torch.mean((diff - exp_diff)**2)
            loss_all = loss1 + loss2

            # loss1 = noisy_output - noisy_target
            # loss2 = torch.exp(noisy_target - noisy_output)
            # loss_all = torch.mean(loss1 + loss2)
            optimizer.zero_grad()  # clear gradients accumulated from the previous iteration
            loss_all.backward()

            # In PyTorch 1.1.0 and later, you should call `optimizer.step()` before `lr_scheduler.step()`
            optimizer.step()
            scheduler.step()

            # record the loss of the minibatch
            train_loss_meter.update(loss_all.item())
            train_time_meter.update(time.time() - train_start_time)
            writer.add_scalar('lr', optimizer.param_groups[0]['lr'], it)

            if it % 1000 == 0:
                writer.add_histogram('hist/pred', noisy_denoised, it)
                writer.add_histogram('hist/noisy', noisy, it)

                if cfg.data.simulate:
                    writer.add_histogram('hist/clean', clean, it)

            if cfg.data.simulate:
                pass

            # print interval
            if it % cfg.train.print_interval == 0:
                terminal_info = f"Iter [{it:d}/{train_iter:d}]  \
                                train Loss: {train_loss_meter.avg:.4f}  \
                                Time/Image: {train_time_meter.avg / cfg.train.batch_size:.4f}"

                logger.info(terminal_info)
                writer.add_scalar('loss/train_loss', train_loss_meter.avg, it)

                if cfg.data.simulate:
                    pass

                running_metrics_train.reset()
                train_time_meter.reset()
                train_loss_meter.reset()

            # val interval
            if it % cfg.train.val_interval == 0 or \
               it == train_iter:
                val_start_time = time.time()
                model.eval()
                with torch.no_grad():
                    for clean, noisy, _ in valloader:
                        # noisy /= 350
                        # clean /= 350
                        noisy = noisy.to(device, dtype=torch.float32)
                        noisy_denoised = model(noisy)

                        if cfg.data.simulate:
                            clean = clean.to(device, dtype=torch.float32)
                            psnr = piq.psnr(clean,
                                            noisy_denoised,
                                            data_range=data_range)
                            ssim = piq.ssim(clean,
                                            noisy_denoised,
                                            data_range=data_range)
                            val_psnr_meter.update(psnr)
                            val_ssim_meter.update(ssim)

                        val_loss = torch.mean((noisy_denoised - noisy)**2)
                        val_loss_meter.update(val_loss)

                writer.add_scalar('loss/val_loss', val_loss_meter.avg, it)
                logger.info(
                    f"Iter [{it}/{train_iter}], val Loss: {val_loss_meter.avg:.4f} Time/Image: {(time.time()-val_start_time)/len(v_loader):.4f}"
                )
                val_loss_meter.reset()
                running_metrics_val.reset()

                if cfg.data.simulate:
                    writer.add_scalars('metrics/val', {
                        'psnr': val_psnr_meter.avg,
                        'ssim': val_ssim_meter.avg
                    }, it)
                    logger.info(
                        f'psnr: {val_psnr_meter.avg},\tssim: {val_ssim_meter.avg}'
                    )
                    val_psnr_meter.reset()
                    val_ssim_meter.reset()

                train_val_time = time.time() - train_val_start_time
                remain_time = train_val_time * (train_iter - it) / it
                m, s = divmod(remain_time, 60)
                h, m = divmod(m, 60)
                if s != 0:
                    train_time = "Remain train time = %d hours %d minutes %d seconds \n" % (
                        h, m, s)
                else:
                    train_time = "Remain train time : train completed.\n"
                logger.info(train_time)
                model.train()

            # save model
            if it % (train_iter / cfg.train.epoch * 10) == 0:
                ep = int(it / (train_iter / cfg.train.epoch))
                state = {
                    "epoch": it,
                    "model_state": model.state_dict(),
                    "optimizer_state": optimizer.state_dict(),
                    "scheduler_state": scheduler.state_dict(),
                }
                save_path = osp.join(writer.file_writer.get_logdir(),
                                     f"{ep}.pkl")
                torch.save(state, save_path)
                logger.info(f'saved model state dict at {save_path}')

            train_start_time = time.time()
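The loss in the loop above combines a plain sub-image consistency term with a gamma-weighted regularizer built from the detached full-image prediction. A minimal standalone sketch of that computation, assuming mask-pair and sub-image helpers that behave like the rand_pool utilities used above:

def regularized_pair_loss(model, noisy, gen_mask_pair, gen_subimages, gamma):
    # sample two complementary sub-images from the noisy input
    mask1, mask2 = gen_mask_pair(noisy)
    sub1 = gen_subimages(noisy, mask1)
    sub2 = gen_subimages(noisy, mask2)

    # the regularization term uses the full-image prediction without gradients
    with torch.no_grad():
        denoised = model(noisy)
    sub1_denoised = gen_subimages(denoised, mask1)
    sub2_denoised = gen_subimages(denoised, mask2)

    diff = model(sub1) - sub2
    exp_diff = sub1_denoised - sub2_denoised
    return torch.mean(diff ** 2) + gamma * torch.mean((diff - exp_diff) ** 2)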
Ejemplo n.º 10
0
def get_loaders(name, cfg):

    # Setup Augmentations
    augmentations = cfg["training"].get("augmentations", None)
    data_aug = get_composed_augmentations(augmentations)

    data_loader = {
        "airsim": airsimLoader,
        "pascal": pascalVOCLoader,
        "camvid": camvidLoader,
        "ade20k": ADE20KLoader,
        "mit_sceneparsing_benchmark": MITSceneParsingBenchmarkLoader,
        "cityscapes": cityscapesLoader,
        "nyuv2": NYUv2Loader,
        "sunrgbd": SUNRGBDLoader,
        "vistas": mapillaryVistasLoader,
    }[name]
    data_path = cfg["data"]["path"]

    t_loader = data_loader(data_path,
                           is_transform=True,
                           split=cfg['data']['train_split'],
                           subsplits=cfg['data']['train_subsplit'],
                           scale_quantity=cfg['data']['train_reduction'],
                           img_size=(cfg['data']['img_rows'],
                                     cfg['data']['img_cols']),
                           augmentations=data_aug)

    r_loader = data_loader(data_path,
                           is_transform=True,
                           split=cfg['data']['recal_split'],
                           subsplits=cfg['data']['recal_subsplit'],
                           scale_quantity=cfg['data']['recal_reduction'],
                           img_size=(cfg['data']['img_rows'],
                                     cfg['data']['img_cols']),
                           augmentations=data_aug)

    tv_loader = data_loader(data_path,
                            is_transform=True,
                            split=cfg['data']['train_split'],
                            subsplits=cfg['data']['train_subsplit'],
                            scale_quantity=0.05,
                            img_size=(cfg['data']['img_rows'],
                                      cfg['data']['img_cols']),
                            augmentations=data_aug)

    v_loader = {
        env: data_loader(
            data_path,
            is_transform=True,
            split='val',
            subsplits=[env],
            scale_quantity=cfg['data']['val_reduction'],
            img_size=(cfg['data']['img_rows'], cfg['data']['img_cols']),
        )
        for env in cfg['data']['val_subsplit']
    }

    n_classes = int(t_loader.n_classes)
    valloaders = {
        key: data.DataLoader(
            v_loader[key],
            batch_size=cfg['training']['batch_size'],
            num_workers=cfg['training']['n_workers'],
        )
        for key in v_loader.keys()
    }

    # add training samples to validation sweep
    # valloaders = {**valloaders, 'train': data.DataLoader(tv_loader,
    # batch_size=cfg['training']['batch_size'],
    # num_workers=cfg['training']['n_workers'])}

    return {
        'train':
        data.DataLoader(t_loader,
                        batch_size=cfg['training']['batch_size'],
                        num_workers=cfg['training']['n_workers'],
                        shuffle=True),
        'recal':
        data.DataLoader(r_loader,
                        batch_size=cfg['training']['batch_size'],
                        num_workers=cfg['training']['n_workers'],
                        shuffle=True),
        'val':
        valloaders
    }, n_classes
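A hypothetical caller, only to illustrate the structure returned above: 'train' and 'recal' are ordinary DataLoaders, while 'val' is a dict keyed by the environments listed in cfg['data']['val_subsplit'] (cfg stands for a config dict shaped like the one consumed above):

loaders, n_classes = get_loaders('camvid', cfg)
for images, labels in loaders['train']:
    pass  # training step on the shuffled training loader
for env, env_loader in loaders['val'].items():
    for images, labels in env_loader:
        pass  # per-environment validation pass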
Ejemplo n.º 11
0
def train(cfg, writer, logger):
    # Setup seeds
    init_seed(11733, en_cudnn=False)

    # Setup Augmentations
    train_augmentations = cfg["training"].get("train_augmentations", None)
    t_data_aug = get_composed_augmentations(train_augmentations)
    val_augmentations = cfg["validating"].get("val_augmentations", None)
    v_data_aug = get_composed_augmentations(val_augmentations)

    # Setup Dataloader
    data_loader = get_loader(cfg["data"]["dataset"])

    t_loader = data_loader(cfg=cfg["data"],
                           mode='train',
                           augmentations=t_data_aug)
    v_loader = data_loader(cfg=cfg["data"],
                           mode='val',
                           augmentations=v_data_aug)

    trainloader = data.DataLoader(t_loader,
                                  batch_size=cfg["training"]["batch_size"],
                                  num_workers=cfg["training"]["n_workers"],
                                  shuffle=True,
                                  drop_last=True)
    valloader = data.DataLoader(v_loader,
                                batch_size=cfg["validating"]["batch_size"],
                                num_workers=cfg["validating"]["n_workers"])

    logger.info("Using training seting {}".format(cfg["training"]))

    # Setup Metrics
    running_metrics_val = runningScore(t_loader.n_classes,
                                       t_loader.unseen_classes)

    model_state = torch.load(
        './runs/deeplabv3p_ade_25unseen/84253/deeplabv3p_ade20k_best_model.pkl'
    )
    running_metrics_val.confusion_matrix = model_state['results']
    score, a_iou = running_metrics_val.get_scores()

    pdb.set_trace()
    # Setup Model and Loss
    loss_fn = get_loss_function(cfg["training"])
    logger.info("Using loss {}".format(loss_fn))
    model = get_model(cfg["model"], t_loader.n_classes, loss_fn=loss_fn)

    # Setup optimizer
    optimizer = get_optimizer(cfg["training"], model)

    # Initialize training param
    start_iter = 0
    best_iou = -100.0

    # Resume from checkpoint
    if cfg["training"]["resume"] is not None:
        if os.path.isfile(cfg["training"]["resume"]):
            logger.info("Resuming training from checkpoint '{}'".format(
                cfg["training"]["resume"]))
            model_state = torch.load(cfg["training"]["resume"])["model_state"]
            model.load_state_dict(model_state)
        else:
            logger.info("No checkpoint found at '{}'".format(
                cfg["training"]["resume"]))

    # Setup Multi-GPU
    if torch.cuda.is_available():
        model = model.cuda()  # DataParallelModel(model).cuda()
        logger.info("Model initialized on GPUs.")

    time_meter = averageMeter()
    i = start_iter

    embd = t_loader.embeddings
    ignr_idx = t_loader.ignore_index
    embds = embd.cuda()
    while i <= cfg["training"]["train_iters"]:
        for (images, labels) in trainloader:
            images = images.cuda()
            labels = labels.cuda()

            i += 1
            model.train()
            optimizer.zero_grad()

            start_ts = time.time()
            loss_sum = model(images, labels, embds, ignr_idx)
            if loss_sum == 0:  # Ignore samples contain unseen cat
                continue  # To enable non-transductive learning, set transductive=0 in the config

            loss_sum.backward()

            time_meter.update(time.time() - start_ts)

            optimizer.step()

            if (i + 1) % cfg["training"]["print_interval"] == 0:
                fmt_str = "Iter [{:d}/{:d}]  Loss: {:.4f}  Time/Image: {:.4f}"
                print_str = fmt_str.format(
                    i + 1,
                    cfg["training"]["train_iters"],
                    loss_sum.item(),
                    time_meter.avg / cfg["training"]["batch_size"],
                )

                print(print_str)
                logger.info(print_str)
                writer.add_scalar("loss/train_loss", loss_sum.item(), i + 1)
                time_meter.reset()

            if (i + 1) % cfg["training"]["val_interval"] == 0 or (
                    i + 1) == cfg["training"]["train_iters"]:
                model.eval()
                with torch.no_grad():
                    for i_val, (images_val,
                                labels_val) in tqdm(enumerate(valloader)):
                        images_val = images_val.cuda()
                        labels_val = labels_val.cuda()
                        outputs = model(images_val, labels_val, embds,
                                        ignr_idx)
                        # outputs = gather(outputs, 0, dim=0)

                        running_metrics_val.update(outputs)

                score, a_iou = running_metrics_val.get_scores()
                for k, v in score.items():
                    print("{}: {}".format(k, v))
                    logger.info("{}: {}".format(k, v))
                    #writer.add_scalar("val_metrics/{}".format(k), v, i + 1)

                #for k, v in class_iou.items():
                #    logger.info("{}: {}".format(k, v))
                #    writer.add_scalar("val_metrics/cls_{}".format(k), v, i + 1)

                if a_iou >= best_iou:
                    best_iou = a_iou
                    state = {
                        "epoch": i + 1,
                        "model_state": model.state_dict(),
                        "best_iou": best_iou,
                        "results": running_metrics_val.confusion_matrix
                    }
                    save_path = os.path.join(
                        writer.file_writer.get_logdir(),
                        "{}_{}_best_model.pkl".format(cfg["model"]["arch"],
                                                      cfg["data"]["dataset"]),
                    )
                    torch.save(state, save_path)

                running_metrics_val.reset()
Ejemplo n.º 12
0
def test(cfg, logger, run_id):
    # Setup Augmentations
    augmentations = cfg.test.augments
    logger.info(f'using augments: {augmentations}')
    data_aug = get_composed_augmentations(augmentations)

    # Setup Dataloader
    data_loader = get_loader(cfg.data.dataloader)
    data_path = cfg.data.path
    data_loader = data_loader(
        data_path,
        data_format=cfg.data.format,
        norm=cfg.data.norm,
        split=cfg.test.dataset,
        split_root=cfg.data.split,
        log=cfg.data.log,
        augments=data_aug,
        logger=logger,
        ENL=cfg.data.ENL,
    )
    run_id = osp.join(run_id, cfg.test.dataset)
    # os.mkdir(run_id)
    
    logger.info("data path: {}".format(data_path))
    logger.info(f'num of {cfg.test.dataset} set samples: {len(data_loader)}')

    loader = data.DataLoader(data_loader,
                            batch_size=cfg.test.batch_size, 
                            num_workers=cfg.test.n_workers, 
                            shuffle=False,
                            persistent_workers=True,
                            drop_last=False,
                            )

    # Setup Model
    device = f'cuda:{cfg.gpu[0]}'
    model = get_model(cfg.model).to(device)
    input_size = (cfg.model.in_channels, 512, 512)
    logger.info(f'using model: {cfg.model.arch}')
    
    model = torch.nn.DataParallel(model, device_ids=cfg.gpu)

    # load model params
    if osp.isfile(cfg.test.pth):
        logger.info("Loading model from checkpoint '{}'".format(cfg.test.pth))

        # load model state
        checkpoint = torch.load(cfg.test.pth)
        model.load_state_dict(checkpoint["model_state"])
    else:
        raise FileNotFoundError(f'{cfg.test.pth} file not found')

    # Setup Metrics
    running_metrics_val = runningScore(2)
    running_metrics_train = runningScore(2)
    metrics = runningScore(2)
    test_psnr_meter = averageMeter()
    test_ssim_meter = averageMeter()
    img_cnt = 0
    data_range = 255
    if cfg.data.log:
        data_range = np.log(data_range)

    # test
    model.eval()
    with torch.no_grad():
        for clean, noisy, files_path in loader:
             
            noisy = noisy.to(device, dtype=torch.float32)
            noisy_denoised = model(noisy)

            psnr = []
            ssim = []
            if cfg.data.simulate:
                clean = clean.to(device, dtype=torch.float32)
                for ii in range(9):
                    psnr.append(piq.psnr(noisy_denoised[:, ii, :, :], clean[:, ii, :, :], data_range=data_range).cpu())
                    ssim.append(piq.ssim(noisy_denoised[:, ii, :, :], clean[:, ii, :, :], data_range=data_range).cpu())
                    print(f'{ii}: PSNR: {psnr[ii]}\n\tSSIM: {ssim[ii]}')

                print('\n')
                test_psnr_meter.update(np.array(psnr).mean(), n=clean.shape[0])
                test_ssim_meter.update(np.array(ssim).mean(), n=clean.shape[0])

        if cfg.data.simulate:    
            logger.info(f'overall psnr: {test_psnr_meter.avg}, ssim: {test_ssim_meter.avg}')

        logger.info(f'\ndone')
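The data_range passed to piq above is the largest value the tensors can take, which is why it shrinks from 255 to log(255) when the loader log-transforms intensities. A small self-contained check of that convention, using stand-in tensors rather than the dataset above:

import numpy as np
import piq
import torch

log_scaled = True                                   # mirrors cfg.data.log
data_range = np.log(255) if log_scaled else 255.0   # maximum possible value
clean = torch.rand(1, 1, 64, 64) * data_range       # stand-in "clean" tensor
denoised = torch.rand(1, 1, 64, 64) * data_range    # stand-in "denoised" tensor
print(piq.psnr(clean, denoised, data_range=data_range))
print(piq.ssim(clean, denoised, data_range=data_range))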
Ejemplo n.º 13
0
def validate(cfg, model_nontree, model_tree, loss_fn, device, root):

    val_loss_meter_nontree = averageMeter()
    if cfg['training']['use_hierarchy']:
        val_loss_meter_level0_nontree = averageMeter()
        val_loss_meter_level1_nontree = averageMeter()
        val_loss_meter_level2_nontree = averageMeter()
        val_loss_meter_level3_nontree = averageMeter()

    val_loss_meter_tree = averageMeter()
    if cfg['training']['use_hierarchy']:
        val_loss_meter_level0_tree = averageMeter()
        val_loss_meter_level1_tree = averageMeter()
        val_loss_meter_level2_tree = averageMeter()
        val_loss_meter_level3_tree = averageMeter()

    if torch.cuda.is_available():
        data_path = cfg['data']['server_path']
    else:
        data_path = cfg['data']['path']

    data_loader = get_loader(cfg['data']['dataset'])
    augmentations = cfg['training'].get('augmentations', None)
    data_aug = get_composed_augmentations(augmentations)

    v_loader = data_loader(data_path,
                           is_transform=True,
                           split=cfg['data']['val_split'],
                           img_size=(cfg['data']['img_rows'],
                                     cfg['data']['img_cols']),
                           augmentations=data_aug)

    n_classes = v_loader.n_classes
    valloader = data.DataLoader(v_loader,
                                batch_size=cfg['training']['batch_size'],
                                num_workers=cfg['training']['n_workers'])

    # Setup Metrics
    running_metrics_val_nontree = runningScore(n_classes)
    running_metrics_val_tree = runningScore(n_classes)

    model_nontree.eval()
    model_tree.eval()
    with torch.no_grad():
        print("validation loop")
        for i_val, (images_val, labels_val) in tqdm(enumerate(valloader)):
            images_val = images_val.to(device)
            labels_val = labels_val.to(device)

            outputs_nontree = model_nontree(images_val)
            outputs_tree = model_tree(images_val)

            if cfg['training']['use_tree_loss']:
                val_loss_nontree = loss_fn(
                    input=outputs_nontree,
                    target=labels_val,
                    root=root,
                    use_hierarchy=cfg['training']['use_hierarchy'])
            else:
                val_loss_nontree = loss_fn(input=outputs_nontree,
                                           target=labels_val)

            if cfg['training']['use_tree_loss']:
                val_loss_tree = loss_fn(
                    input=outputs_tree,
                    target=labels_val,
                    root=root,
                    use_hierarchy=cfg['training']['use_hierarchy'])
            else:
                val_loss_tree = loss_fn(input=outputs_tree, target=labels_val)

            # Using standard max prob based classification
            pred_nontree = outputs_nontree.data.max(1)[1].cpu().numpy()
            pred_tree = outputs_tree.data.max(1)[1].cpu().numpy()

            gt = labels_val.data.cpu().numpy()
            running_metrics_val_nontree.update(
                gt, pred_nontree)  # updates confusion matrix
            running_metrics_val_tree.update(gt, pred_tree)

            if cfg['training']['use_tree_loss']:
                val_loss_meter_nontree.update(
                    val_loss_nontree[1][0])  # take the 1st level
            else:
                val_loss_meter_nontree.update(val_loss_nontree.item())

            if cfg['training']['use_tree_loss']:
                val_loss_meter_tree.update(val_loss_tree[0].item())
            else:
                val_loss_meter_tree.update(val_loss_tree.item())

            if cfg['training']['use_hierarchy']:
                val_loss_meter_level0_nontree.update(val_loss_nontree[1][0])
                val_loss_meter_level1_nontree.update(val_loss_nontree[1][1])
                val_loss_meter_level2_nontree.update(val_loss_nontree[1][2])
                val_loss_meter_level3_nontree.update(val_loss_nontree[1][3])

            if cfg['training']['use_hierarchy']:
                val_loss_meter_level0_tree.update(val_loss_tree[1][0])
                val_loss_meter_level1_tree.update(val_loss_tree[1][1])
                val_loss_meter_level2_tree.update(val_loss_tree[1][2])
                val_loss_meter_level3_tree.update(val_loss_tree[1][3])

            if i_val == 1:
                break

        score_nontree, class_iou_nontree = running_metrics_val_nontree.get_scores(
        )
        score_tree, class_iou_tree = running_metrics_val_tree.get_scores()

        ### VISUALISE METRICS AND LOSSES HERE

        val_loss_meter_nontree.reset()
        running_metrics_val_nontree.reset()
        val_loss_meter_tree.reset()
        running_metrics_val_tree.reset()
        if cfg['training']['use_hierarchy']:
            val_loss_meter_level0_nontree.reset()
            val_loss_meter_level1_nontree.reset()
            val_loss_meter_level2_nontree.reset()
            val_loss_meter_level3_nontree.reset()

        if cfg['training']['use_hierarchy']:
            val_loss_meter_level0_tree.reset()
            val_loss_meter_level1_tree.reset()
            val_loss_meter_level2_tree.reset()
            val_loss_meter_level3_tree.reset()
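The indexing above assumes the tree loss returns a pair: the total loss first and a list of per-level losses second. A tiny illustration of that assumed contract (hypothetical, since the loss function itself is not shown here):

total_loss, level_losses = loss_fn(input=outputs_tree, target=labels_val,
                                   root=root, use_hierarchy=True)
for level, level_loss in enumerate(level_losses):
    print(f'level {level} loss: {level_loss}')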
Ejemplo n.º 14
0
if __name__ == "__main__":
    import matplotlib.pyplot as plt
    crop_size = 713
    fill = 250
    # augmentations = Compose([Scale(2048), RandomRotate(10), RandomHorizontallyFlip(0.5), RandomGaussianBlur])
    # augmentations = Compose([RandomScaleCrop(crop_size, fill), RandomRotate(10), RandomHorizontallyFlip(0.5)])
    augmentations = {
        'brightness': 63. / 255.,
        'saturation': 0.5,
        'contrast': 0.8,
        'hflip': 0.5,
        'rotate': 10,
        'rscalecropsquare': 713,
    }
    data_aug = get_composed_augmentations(augmentations)
    local_path = '/cvlabdata2/home/user/data/cityscapes/'
    # local_path = '/cvlabdata1/cvlab/dataset_cityscapes/'
    # local_path_half = "/cvlabdata2/cvlab/dataset_cityscapes_downsampled/"
    dst = cityscapesLoader(local_path,
                           img_size='same',
                           is_transform=True,
                           img_norm=False,
                           augmentations=data_aug)
    bs = 16
    trainloader = data.DataLoader(dst, batch_size=bs, num_workers=0)
    for i, data_samples in enumerate(trainloader):
        imgs, labels = data_samples
        # import pdb
        # pdb.set_trace()
        imgs = imgs.numpy()[:, ::-1, :, :]
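        # A possible continuation of this smoke test (hypothetical, assuming the
        # loader returns 0-255 float images in NCHW order, flipped to RGB above):
        # matplotlib is already imported, so the first sample can be inspected.
        sample = imgs[0].transpose(1, 2, 0)       # CHW -> HWC for imshow
        plt.imshow(sample.astype('uint8'))
        plt.title(f'batch {i}, sample 0')
        plt.show()
        break                                     # one batch is enough for a visual check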
Ejemplo n.º 15
0
def validate(cfg, args):

    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Setup Dataloader
    data_loader = get_loader(cfg["data"]["dataset"])
    data_path = cfg["data"]["path"]

    path_n = cfg["model"]["path_num"]

    val_augmentations = cfg["validating"].get("val_augmentations", None)
    v_data_aug = get_composed_augmentations(val_augmentations)

    v_loader = data_loader(
        data_path,
        split=cfg["data"]["val_split"],
        # img_size=(cfg["data"]["img_rows"], cfg["data"]["img_cols"]),
        augmentations=v_data_aug,
        path_num=path_n)

    n_classes = v_loader.n_classes
    valloader = data.DataLoader(v_loader,
                                batch_size=cfg["validating"]["batch_size"],
                                num_workers=cfg["validating"]["n_workers"])

    running_metrics = runningScore(n_classes)

    # Setup Model
    teacher = get_model(cfg["teacher"], n_classes)
    model = get_model(cfg["model"],
                      n_classes,
                      psp_path=cfg["training"]["resume"],
                      teacher=teacher).to(device)
    state = torch.load(cfg["validating"]["resume"])  #["model_state"]
    model.load_state_dict(state, strict=False)
    model.eval()
    model.to(device)

    with torch.no_grad():
        for i, (val, labels) in enumerate(valloader):

            gt = labels.numpy()
            _val = [ele.to(device) for ele in val]

            torch.cuda.synchronize()
            start_time = timeit.default_timer()
            outputs = model(_val, pos_id=i % path_n)
            torch.cuda.synchronize()
            elapsed_time = timeit.default_timer() - start_time
            pred = outputs.data.max(1)[1].cpu().numpy()
            running_metrics.update(gt, pred)

            if args.measure_time:
                elapsed_time = timeit.default_timer() - start_time
                print("Inference time \
                      (iter {0:5d}): {1:3.5f} fps".format(
                    i + 1, pred.shape[0] / elapsed_time))
            if False:
                decoded = v_loader.decode_segmap(pred[0])
                import cv2
                cv2.namedWindow("Image")
                cv2.imshow("Image", decoded)
                cv2.waitKey(0)
                cv2.destroyAllWindows()

    score, class_iou = running_metrics.get_scores()

    for k, v in score.items():
        print(k, v)

    for i in range(n_classes):
        print(i, class_iou[i])
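The fps number printed above relies on synchronizing CUDA before and after the forward pass; without the synchronize calls the timer would largely measure asynchronous kernel launches. The same pattern in isolation, as a small sketch:

import timeit
import torch

def timed_forward(model, inputs):
    # return (outputs, elapsed_seconds) for one synchronized forward pass
    torch.cuda.synchronize()
    start = timeit.default_timer()
    with torch.no_grad():
        outputs = model(inputs)
    torch.cuda.synchronize()
    return outputs, timeit.default_timer() - start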
Ejemplo n.º 16
0
def train(cfg, writer, logger, run_id):

    # Setup random seeds
    torch.manual_seed(cfg.get('seed', 1337))
    torch.cuda.manual_seed(cfg.get('seed', 1337))
    np.random.seed(cfg.get('seed', 1337))
    random.seed(cfg.get('seed', 1337))

    torch.backends.cudnn.benchmark = True

    # Setup Augmentations
    augmentations = cfg['train'].get('augmentations', None)
    data_aug = get_composed_augmentations(augmentations)

    # Setup Dataloader
    data_loader = get_loader(cfg['data']['dataloader'])
    data_path = cfg['data']['path']

    logger.info("Using dataset: {}".format(data_path))

    t_loader = data_loader(data_path,
                           transform=None,
                           split=cfg['data']['train_split'],
                           augmentations=data_aug)

    v_loader = data_loader(
        data_path,
        transform=None,
        split=cfg['data']['val_split'],
    )
    logger.info(
        f'num of train samples: {len(t_loader)} \nnum of val samples: {len(v_loader)}'
    )

    train_data_len = len(t_loader)
    batch_size = cfg['train']['batch_size']
    epoch = cfg['train']['train_epoch']
    train_iter = int(np.ceil(train_data_len / batch_size) * epoch)
    logger.info(f'total train iter: {train_iter}')
    n_classes = t_loader.n_classes

    trainloader = data.DataLoader(t_loader,
                                  batch_size=cfg['train']['batch_size'],
                                  num_workers=cfg['train']['n_workers'],
                                  shuffle=True,
                                  drop_last=True)

    valloader = data.DataLoader(v_loader,
                                batch_size=cfg['train']['batch_size'],
                                num_workers=cfg['train']['n_workers'])

    # Setup Model
    model = get_model(cfg['model'], n_classes)
    logger.info("Using Model: {}".format(cfg['model']['arch']))
    device = f'cuda:{cuda_idx[0]}'
    model = model.to(device)
    model = torch.nn.DataParallel(model, device_ids=cuda_idx)  # automatic multi-GPU execution; works well

    # Setup optimizer, lr_scheduler and loss function
    optimizer_cls = get_optimizer(cfg)
    optimizer_params = {
        k: v
        for k, v in cfg['train']['optimizer'].items() if k != 'name'
    }
    optimizer = optimizer_cls(model.parameters(), **optimizer_params)
    logger.info("Using optimizer {}".format(optimizer))
    scheduler = get_scheduler(optimizer, cfg['train']['lr_schedule'])
    loss_fn = get_loss_function(cfg)
    # logger.info("Using loss {}".format(loss_fn))

    # set checkpoints
    start_iter = 0
    if cfg['train']['resume'] is not None:
        if os.path.isfile(cfg['train']['resume']):
            logger.info(
                "Loading model and optimizer from checkpoint '{}'".format(
                    cfg['train']['resume']))
            checkpoint = torch.load(cfg['train']['resume'])
            model.load_state_dict(checkpoint["model_state"])
            optimizer.load_state_dict(checkpoint["optimizer_state"])
            scheduler.load_state_dict(checkpoint["scheduler_state"])
            start_iter = checkpoint["epoch"]
            logger.info("Loaded checkpoint '{}' (iter {})".format(
                cfg['train']['resume'], checkpoint["epoch"]))
        else:
            logger.info("No checkpoint found at '{}'".format(
                cfg['train']['resume']))

    # Setup Metrics
    running_metrics_val = runningScore(n_classes)
    val_loss_meter = averageMeter()
    train_time_meter = averageMeter()
    time_meter_val = averageMeter()

    best_iou = 0
    flag = True

    val_rlt_f1 = []
    val_rlt_OA = []
    best_f1_till_now = 0
    best_OA_till_now = 0
    best_fwIoU_now = 0
    best_fwIoU_iter_till_now = 0

    # train
    it = start_iter
    model.train()
    while it <= train_iter and flag:
        for (file_a, file_b, label, mask) in trainloader:
            it += 1
            start_ts = time.time()
            file_a = file_a.to(device)
            file_b = file_b.to(device)
            label = label.to(device)
            mask = mask.to(device)

            optimizer.zero_grad()
            outputs = model(file_a, file_b)

            loss = loss_fn(input=outputs, target=label, mask=mask)
            loss.backward()
            # print('conv11: ', model.conv11.weight.grad, model.conv11.weight.grad.shape)
            # print('conv21: ', model.conv21.weight.grad, model.conv21.weight.grad.shape)
            # print('conv31: ', model.conv31.weight.grad, model.conv31.weight.grad.shape)

            # In PyTorch 1.1.0 and later, you should call `optimizer.step()` before `lr_scheduler.step()`
            optimizer.step()
            scheduler.step()

            train_time_meter.update(time.time() - start_ts)
            time_meter_val.update(time.time() - start_ts)

            if (it + 1) % cfg['train']['print_interval'] == 0:
                fmt_str = "train:\nIter [{:d}/{:d}]  Loss: {:.4f}  Time/Image: {:.4f}"
                print_str = fmt_str.format(
                    it + 1,
                    train_iter,
                    loss.item(),  #extracts the loss’s value as a Python float.
                    train_time_meter.avg / cfg['train']['batch_size'])
                train_time_meter.reset()
                logger.info(print_str)
                writer.add_scalar('loss/train_loss', loss.item(), it + 1)

            if (it + 1) % cfg['train']['val_interval'] == 0 or \
               (it + 1) == train_iter:
                model.eval()  # change behavior like drop out
                with torch.no_grad():  # disable autograd, save memory usage
                    for (file_a_val, file_b_val, label_val,
                         mask_val) in valloader:
                        file_a_val = file_a_val.to(device)
                        file_b_val = file_b_val.to(device)

                        outputs = model(file_a_val, file_b_val)
                        # tensor.max with return the maximum value and its indices
                        pred = outputs.max(1)[1].cpu().numpy()
                        gt = label_val.numpy()
                        running_metrics_val.update(gt, pred, mask_val)

                        label_val = label_val.to(device)
                        mask_val = mask_val.to(device)
                        val_loss = loss_fn(input=outputs,
                                           target=label_val,
                                           mask=mask_val)
                        val_loss_meter.update(val_loss.item())

                lr_now = optimizer.param_groups[0]['lr']
                logger.info(f'lr: {lr_now}')
                # writer.add_scalar('lr', lr_now, it+1)
                writer.add_scalar('loss/val_loss', val_loss_meter.avg, it + 1)
                logger.info("Iter %d, val Loss: %.4f" %
                            (it + 1, val_loss_meter.avg))

                score, class_iou = running_metrics_val.get_scores()

                # for k, v in score.items():
                #     logger.info('{}: {}'.format(k, v))
                #     writer.add_scalar('val_metrics/{}'.format(k), v, it+1)

                for k, v in class_iou.items():
                    logger.info('{}: {}'.format(k, v))
                    writer.add_scalar('val_metrics/cls_{}'.format(k), v,
                                      it + 1)

                val_loss_meter.reset()
                running_metrics_val.reset()

                avg_f1 = score["Mean_F1"]
                OA = score["Overall_Acc"]
                fw_IoU = score["FreqW_IoU"]
                val_rlt_f1.append(avg_f1)
                val_rlt_OA.append(OA)

                if fw_IoU >= best_fwIoU_now and it > 200:
                    best_fwIoU_now = fw_IoU
                    correspond_meanIou = score["Mean_IoU"]
                    best_fwIoU_iter_till_now = it + 1

                    state = {
                        "epoch": it + 1,
                        "model_state": model.state_dict(),
                        "optimizer_state": optimizer.state_dict(),
                        "scheduler_state": scheduler.state_dict(),
                        "best_fwIoU": best_fwIoU_now,
                    }
                    save_path = os.path.join(
                        writer.file_writer.get_logdir(),
                        "{}_{}_best_model.pkl".format(
                            cfg['model']['arch'], cfg['data']['dataloader']))
                    torch.save(state, save_path)

                    logger.info("best_fwIoU_now =  %.8f" % (best_fwIoU_now))
                    logger.info("Best fwIoU Iter till now= %d" %
                                (best_fwIoU_iter_till_now))

                iter_time = time_meter_val.avg
                time_meter_val.reset()
                remain_time = iter_time * (train_iter - it)
                m, s = divmod(remain_time, 60)
                h, m = divmod(m, 60)
                if s != 0:
                    train_time = "Remain train time = %d hours %d minutes %d seconds \n" % (
                        h, m, s)
                else:
                    train_time = "Remain train time : train completed.\n"
                print(train_time)

            model.train()
            if (it + 1) == train_iter:
                flag = False
                logger.info("Use the Sar_seg_band3,val_interval: 30")
                break
    logger.info("best_fwIoU_now =  %.8f" % (best_fwIoU_now))
    logger.info("Best fwIoU Iter till now= %d" % (best_fwIoU_iter_till_now))

    state = {
        "epoch": it + 1,
        "model_state": model.state_dict(),
        "optimizer_state": optimizer.state_dict(),
        "scheduler_state": scheduler.state_dict(),
        "best_fwIoU": best_fwIoU_now,
    }
    save_path = os.path.join(
        writer.file_writer.get_logdir(),
        "{}_{}_last_model.pkl".format(cfg['model']['arch'],
                                      cfg['data']['dataloader']))
    torch.save(state, save_path)
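The remaining-time message above (and the similar one in the denoising loop earlier) repeats the same divmod arithmetic; a small hypothetical helper that could factor it out:

def format_remaining_time(avg_iter_seconds, iters_left):
    remain = avg_iter_seconds * iters_left
    m, s = divmod(remain, 60)
    h, m = divmod(m, 60)
    return "Remaining train time = %d hours %d minutes %d seconds" % (h, m, s)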
Ejemplo n.º 17
0
def test(cfg, logger, run_id):
    # Setup Augmentations
    augmentations = cfg.test.augments
    logger.info(f'using augments: {augmentations}')
    data_aug = get_composed_augmentations(augmentations)

    # Setup Dataloader
    data_loader = get_loader(cfg.data.dataloader)
    data_path = cfg.data.path
    data_loader = data_loader(
        data_path,
        data_format=cfg.data.format,
        norm=cfg.data.norm,
        split=cfg.test.dataset,
        split_root=cfg.data.split,
        log=cfg.data.log,
        augments=data_aug,
        logger=logger,
        ENL=cfg.data.ENL,
    )
    run_id = osp.join(run_id, cfg.test.dataset)
    os.mkdir(run_id)
    
    logger.info("data path: {}".format(data_path))
    logger.info(f'num of {cfg.test.dataset} set samples: {len(data_loader)}')

    loader = data.DataLoader(data_loader,
                            batch_size=cfg.test.batch_size, 
                            num_workers=cfg.test.n_workers, 
                            shuffle=False,
                            persistent_workers=True,
                            drop_last=False,
                            )

    # Setup Model
    device = f'cuda:{cfg.gpu[0]}'
    model = get_model(cfg.model).to(device)
    input_size = (cfg.model.in_channels, 512, 512)
    logger.info(f'using model: {cfg.model.arch}')
    
    model = torch.nn.DataParallel(model, device_ids=cfg.gpu)

    # load model params
    if osp.isfile(cfg.test.pth):
        logger.info("Loading model from checkpoint '{}'".format(cfg.test.pth))

        # load model state
        checkpoint = torch.load(cfg.test.pth)
        model.load_state_dict(checkpoint["model_state"])
    else:
        raise FileNotFoundError(f'{cfg.test.pth} file not found')

    # Setup Metrics
    running_metrics_val = runningScore(2)
    running_metrics_train = runningScore(2)
    metrics = runningScore(2)
    test_psnr_meter = averageMeter()
    test_ssim_meter = averageMeter()
    img_cnt = 0
    data_range = 255
    if cfg.data.log:
        data_range = np.log(data_range)

    # test
    model.eval()
    with torch.no_grad():
        for clean, noisy, files_path in loader:
             
            noisy = noisy.to(device, dtype=torch.float32)
            noisy_denoised = model(noisy)

            psnr = []
            ssim = []
            if cfg.data.simulate:
                clean = clean.to(device, dtype=torch.float32)
                for ii in range(clean.shape[0]):
                    psnr.append(piq.psnr(noisy_denoised[ii, ...], clean[ii, ...], data_range=data_range).cpu())
                    ssim.append(piq.ssim(noisy_denoised[ii, ...], clean[ii, ...], data_range=data_range).cpu())

                test_psnr_meter.update(np.array(psnr).mean(), n=clean.shape[0])
                test_ssim_meter.update(np.array(ssim).mean(), n=clean.shape[0])

            noisy = data_loader.Hoekman_recover_to_C3(noisy)
            clean = data_loader.Hoekman_recover_to_C3(clean)
            noisy_denoised = data_loader.Hoekman_recover_to_C3(noisy_denoised)
                
            # save images
            for ii in range(clean.shape[0]):

                file_path = files_path[ii][29:]
                file_path = file_path.replace(r'/', '_')
                file_ori = noisy[ii, ...]
                file_clean = clean[ii, ...]
                file_denoise = noisy_denoised[ii, ...]
                print('clean')
                pauli_clean = (psr.rgb_by_c3(file_clean, 'sinclair', is_print=True)*255).astype(np.uint8)
                print('noisy')
                pauli_ori = (psr.rgb_by_c3(file_ori, 'sinclair', is_print=True)*255).astype(np.uint8)
                print('denoise')
                pauli_denoise = (psr.rgb_by_c3(file_denoise, 'sinclair', is_print=True)*255).astype(np.uint8)

                path_ori = osp.join(run_id, file_path)
                path_denoise = osp.join(run_id, file_path)
                path_clean = osp.join(run_id, file_path)
                if cfg.data.simulate:
                    metric_str = f'_{psnr[ii].item():.3f}_{ssim[ii].item():.3f}'
                    path_ori += metric_str
                    path_denoise += metric_str
                    path_clean += metric_str

                path_ori += '-ori.png'
                path_denoise += '-denoise.png'
                path_clean += '-clean.png'

                cv2.imwrite(path_ori, pauli_ori)
                cv2.imwrite(path_denoise, pauli_denoise)
                cv2.imwrite(path_clean, pauli_clean)

        if cfg.data.simulate:    
            logger.info(f'overall psnr: {test_psnr_meter.avg}, ssim: {test_ssim_meter.avg}')

        logger.info(f'\ndone')
Ejemplo n.º 18
0
def train(cfg, writer, logger):

    # Setup seeds
    torch.manual_seed(cfg.get("seed", 1337))
    torch.cuda.manual_seed(cfg.get("seed", 1337))
    np.random.seed(cfg.get("seed", 1337))
    random.seed(cfg.get("seed", 1337))

    # Setup device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Setup Augmentations
    augmentations = cfg["training"].get("augmentations", None)
    data_aug = get_composed_augmentations(augmentations)

    # Setup Dataloader
    data_loader = get_loader(cfg["data"]["dataloader_type"])

    data_root = cfg["data"]["data_root"]
    presentation_root = cfg["data"]["presentation_root"]

    t_loader = data_loader(
        data_root=data_root,
        presentation_root=presentation_root,
        is_transform=True,
        img_size=(cfg["data"]["img_rows"], cfg["data"]["img_cols"]),
        augmentations=data_aug,
    )

    v_loader = data_loader(data_root=data_root,
                           presentation_root=presentation_root,
                           is_transform=True,
                           img_size=(cfg["data"]["img_rows"],
                                     cfg["data"]["img_cols"]),
                           augmentations=data_aug,
                           test_mode=True)

    n_classes = t_loader.n_classes

    trainloader = data.DataLoader(
        t_loader,
        batch_size=cfg["training"]["batch_size"],
        num_workers=cfg["training"]["n_workers"],
        shuffle=False,
    )

    valloader = data.DataLoader(v_loader,
                                batch_size=cfg["training"]["batch_size"],
                                num_workers=cfg["training"]["n_workers"],
                                shuffle=False)

    # Setup Metrics
    # running_metrics_train = runningScore(n_classes)
    running_metrics_val = runningScore(n_classes)

    # Setup Model
    model = get_model(cfg["model"], n_classes, defaultParams).to(device)

    #model = torch.nn.DataParallel(model, device_ids=range(torch.cuda.device_count()))

    # Setup optimizer, lr_scheduler and loss function
    optimizer_cls = get_optimizer(cfg)
    optimizer_params = {
        k: v
        for k, v in cfg["training"]["optimizer"].items() if k != "name"
    }

    optimizer = optimizer_cls(model.parameters(), **optimizer_params)
    logger.info("Using optimizer {}".format(optimizer))

    scheduler = get_scheduler(optimizer, cfg["training"]["lr_schedule"])

    loss_fn = get_loss_function(cfg)
    logger.info("Using loss {}".format(loss_fn))

    start_iter = 0
    if cfg["training"]["resume"] is not None:
        if os.path.isfile(cfg["training"]["resume"]):
            logger.info(
                "Loading model and optimizer from checkpoint '{}'".format(
                    cfg["training"]["resume"]))
            checkpoint = torch.load(cfg["training"]["resume"])
            model.load_state_dict(checkpoint["model_state"])
            optimizer.load_state_dict(checkpoint["optimizer_state"])
            scheduler.load_state_dict(checkpoint["scheduler_state"])
            start_iter = checkpoint["epoch"]
            logger.info("Loaded checkpoint '{}' (iter {})".format(
                cfg["training"]["resume"], checkpoint["epoch"]))
        else:
            logger.info("No checkpoint found at '{}'".format(
                cfg["training"]["resume"]))

    model.load_pretrained_weights(cfg["training"]["saved_model_path"])

    # train_loss_meter = averageMeter()
    val_loss_meter = averageMeter()
    time_meter = averageMeter()

    best_iou = -100.0
    i = start_iter

    while i <= cfg["training"]["num_presentations"]:

        #                #
        # TRAINING PHASE #
        #                #
        i += 1
        start_ts = time.time()
        trainloader.dataset.random_select()

        hebb = model.initialZeroHebb().to(device)
        for idx, (images, labels) in enumerate(
                trainloader, 1):  # get a single training presentation

            images = images.to(device)
            labels = labels.to(device)

            if idx <= 5:
                model.eval()
                with torch.no_grad():
                    outputs, hebb = model(images,
                                          labels,
                                          hebb,
                                          device,
                                          test_mode=False)
            else:
                scheduler.step()
                model.train()
                optimizer.zero_grad()
                outputs, hebb = model(images,
                                      labels,
                                      hebb,
                                      device,
                                      test_mode=True)
                loss = loss_fn(input=outputs, target=labels)
                loss.backward()
                optimizer.step()

        time_meter.update(time.time() -
                          start_ts)  # -> time taken per presentation

        if (i + 1) % cfg["training"]["print_interval"] == 0:
            fmt_str = "Pres [{:d}/{:d}]  Loss: {:.4f}  Time/Pres: {:.4f}"
            print_str = fmt_str.format(
                i + 1,
                cfg["training"]["num_presentations"],
                loss.item(),
                time_meter.avg / cfg["training"]["batch_size"],
            )
            print(print_str)
            logger.info(print_str)
            writer.add_scalar("loss/test_loss", loss.item(), i + 1)
            time_meter.reset()

        #            #
        # TEST PHASE #
        #            #
        if ((i + 1) % cfg["training"]["test_interval"] == 0
                or (i + 1) == cfg["training"]["num_presentations"]):

            training_state_dict = model.state_dict(
            )  # saving the training state of the model

            valloader.dataset.random_select()
            hebb = model.initialZeroHebb().to(device)
            for idx, (images_val, labels_val) in enumerate(
                    valloader, 1):  # get a single test presentation

                images_val = images_val.to(device)
                labels_val = labels_val.to(device)

                if idx <= 5:
                    model.eval()
                    with torch.no_grad():
                        outputs, hebb = model(images_val,
                                              labels_val,
                                              hebb,
                                              device,
                                              test_mode=False)
                else:
                    model.train()
                    optimizer.zero_grad()
                    outputs, hebb = model(images_val,
                                          labels_val,
                                          hebb,
                                          device,
                                          test_mode=True)
                    loss = loss_fn(input=outputs, target=labels_val)
                    loss.backward()
                    optimizer.step()

                    pred = outputs.data.max(1)[1].cpu().numpy()
                    gt = labels_val.data.cpu().numpy()

                    running_metrics_val.update(gt, pred)
                    val_loss_meter.update(loss.item())

            model.load_state_dict(
                training_state_dict)  # revert back to training parameters

            writer.add_scalar("loss/val_loss", val_loss_meter.avg, i + 1)
            logger.info("Iter %d Loss: %.4f" % (i + 1, val_loss_meter.avg))

            score, class_iou = running_metrics_val.get_scores()
            for k, v in score.items():
                print(k, v)
                logger.info("{}: {}".format(k, v))
                writer.add_scalar("val_metrics/{}".format(k), v, i + 1)

            for k, v in class_iou.items():
                logger.info("{}: {}".format(k, v))
                writer.add_scalar("val_metrics/cls_{}".format(k), v, i + 1)

            val_loss_meter.reset()
            running_metrics_val.reset()

            if score["Mean IoU : \t"] >= best_iou:
                best_iou = score["Mean IoU : \t"]
                state = {
                    "epoch": i + 1,
                    "model_state": model.state_dict(),
                    "optimizer_state": optimizer.state_dict(),
                    "scheduler_state": scheduler.state_dict(),
                    "best_iou": best_iou,
                }
                save_path = os.path.join(
                    writer.file_writer.get_logdir(),
                    "{}_{}_best_model.pkl".format(
                        cfg["model"]["arch"], cfg["data"]["dataloader_type"]),
                )
                torch.save(state, save_path)

        if (i + 1) == cfg["training"]["num_presentations"]:
            break
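The test phase above snapshots the training parameters before the plastic updates and restores them afterwards. Since state_dict() returns references to the live tensors, a robust snapshot would deep-copy it, roughly as sketched here:

import copy

snapshot = copy.deepcopy(model.state_dict())  # detach the snapshot from the live parameters
# ... run the test-phase presentations that call optimizer.step() ...
model.load_state_dict(snapshot)               # revert to the pre-test training weights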
Ejemplo n.º 19
0
def train(cfg, writer, logger):

    # Setup random seeds
    torch.manual_seed(cfg.get('seed', 1860))
    torch.cuda.manual_seed(cfg.get('seed', 1860))
    np.random.seed(cfg.get('seed', 1860))
    random.seed(cfg.get('seed', 1860))

    # Setup device
    if cfg["device"]["use_gpu"]:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        if not torch.cuda.is_available():
            logger.warning("CUDA not available, using CPU instead!")
    else:
        device = torch.device("cpu")

    # Setup augmentations
    augmentations = cfg['training'].get('augmentations', None)
    data_aug = get_composed_augmentations(augmentations)
    if "rcrop" in augmentations.keys():
        data_aug_val = get_composed_augmentations(
            {"rcrop": augmentations["rcrop"]})

    # Setup dataloader
    data_loader = get_loader(cfg['data']['dataset'])
    data_path = cfg['data']['path']
    if 'depth_scaling' not in cfg['data'].keys():
        cfg['data']['depth_scaling'] = None
    if 'max_depth' not in cfg['data'].keys():
        logger.warning(
            "Key d_max not found in configuration file! Using default value")
        cfg['data']['max_depth'] = 256
    if 'min_depth' not in cfg['data'].keys():
        logger.warning(
            "Key d_min not found in configuration file! Using default value")
        cfg['data']['min_depth'] = 1
    t_loader = data_loader(data_path,
                           is_transform=True,
                           split=cfg['data']['train_split'],
                           img_size=(cfg['data']['img_rows'],
                                     cfg['data']['img_cols']),
                           augmentations=data_aug,
                           depth_scaling=cfg['data']['depth_scaling'],
                           n_bins=cfg['data']['depth_bins'],
                           max_depth=cfg['data']['max_depth'],
                           min_depth=cfg['data']['min_depth'])

    v_loader = data_loader(data_path,
                           is_transform=True,
                           split=cfg['data']['val_split'],
                           img_size=(cfg['data']['img_rows'],
                                     cfg['data']['img_cols']),
                           augmentations=data_aug_val,
                           depth_scaling=cfg['data']['depth_scaling'],
                           n_bins=cfg['data']['depth_bins'],
                           max_depth=cfg['data']['max_depth'],
                           min_depth=cfg['data']['min_depth'])

    trainloader = data.DataLoader(t_loader,
                                  batch_size=cfg['training']['batch_size'],
                                  num_workers=cfg['training']['n_workers'],
                                  shuffle=True,
                                  drop_last=True)

    valloader = data.DataLoader(v_loader,
                                batch_size=cfg['validation']['batch_size'],
                                num_workers=cfg['validation']['n_workers'],
                                shuffle=True,
                                drop_last=True)

    # Check selected tasks
    if sum(cfg["data"]["tasks"].values()) > 1:
        logger.info("Running multi-task training with config: {}".format(
            cfg["data"]["tasks"]))

    # Get output dimension of the network's final layer
    n_classes_d_cls = None
    if cfg["data"]["tasks"]["d_cls"]:
        n_classes_d_cls = t_loader.n_classes_d_cls

    # Setup metrics for validation
    if cfg["data"]["tasks"]["d_cls"]:
        running_metrics_val_d_cls = runningScore(n_classes_d_cls)
    if cfg["data"]["tasks"]["d_reg"]:
        running_metrics_val_d_reg = running_side_score()

    # Setup model
    model = get_model(cfg['model'],
                      cfg["data"]["tasks"],
                      n_classes_d_cls=n_classes_d_cls).to(device)
    # model = d_regResNet().to(device)

    # Setup multi-GPU support
    n_gpus = torch.cuda.device_count()
    if n_gpus > 1:
        logger.info("Running multi-gpu training on {} GPUs".format(n_gpus))
        model = torch.nn.DataParallel(model, device_ids=range(n_gpus))

    # Setup multi-task loss
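    # Each entry in task_weights holds a learnable log-variance s = log(sigma^2).
    # When task_weight_policy is 'update', these tensors require grad and are
    # handed to the optimizer together with the model parameters (see
    # objective_params below), so the task weighting is learned jointly.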
    task_weights = {}
    update_weights = (cfg["training"]["task_weight_policy"] == 'update')
    for task, weight in cfg["training"]["task_weight_init"].items():
        task_weights[task] = torch.tensor(weight).float()
        task_weights[task] = task_weights[task].to(device)
        task_weights[task] = task_weights[task].requires_grad_(update_weights)
    logger.info("Task weights were initialized with {}".format(
        cfg["training"]["task_weight_init"]))

    # Setup optimizer and lr_scheduler
    optimizer_cls = get_optimizer(cfg)
    optimizer_params = {
        k: v
        for k, v in cfg['training']['optimizer'].items() if k != 'name'
    }

    objective_params = list(model.parameters()) + list(task_weights.values())
    optimizer = optimizer_cls(objective_params, **optimizer_params)
    logger.info("Using optimizer {}".format(optimizer))

    scheduler = get_scheduler(optimizer, cfg['training']['lr_schedule'])
    logger.info("Using learning-rate scheduler {}".format(scheduler))

    # Setup task-specific loss functions
    # logger.debug("setting loss functions")
    loss_fns = {}
    for task, selected in cfg["data"]["tasks"].items():
        if selected:
            logger.info("Task " + task + " was selected for training.")
            loss_fn = get_loss_function(cfg, task)
            logger.info("Using loss function {} for task {}".format(
                loss_fn, task))
            loss_fns[task] = loss_fn

    # Load weights from old checkpoint if set
    # logger.debug("checking for resume checkpoint")
    start_iter = 0
    if cfg['training']['resume'] is not None:
        if os.path.isfile(cfg['training']['resume']):
            logger.info(
                "Loading model and optimizer from checkpoint '{}'".format(
                    cfg['training']['resume']))
            logger.info("Loading file...")
            checkpoint = torch.load(cfg['training']['resume'],
                                    map_location="cpu")
            logger.info("Loading model...")
            model.load_state_dict(checkpoint["model_state"])
            model.to("cpu")
            model.to(device)
            logger.info("Restoring task weights...")
            task_weights = checkpoint["task_weights"]
            for task, state in task_weights.items():
                # task_weights[task] = state.to(device)
                task_weights[task] = torch.tensor(state.data).float()
                task_weights[task] = task_weights[task].to(device)
                task_weights[task] = task_weights[task].requires_grad_(
                    update_weights)
            logger.info("Loading scheduler...")
            scheduler.load_state_dict(checkpoint["scheduler_state"])
            #            scheduler.to("cpu")
            start_iter = checkpoint["iteration"]

            # Add loaded parameters to optimizer
            # NOTE task_weights will not update otherwise!
            logger.info("Loading optimizer...")
            optimizer_cls = get_optimizer(cfg)
            objective_params = list(model.parameters()) + \
                list(task_weights.values())
            optimizer = optimizer_cls(objective_params, **optimizer_params)
            optimizer.load_state_dict(checkpoint["optimizer_state"])
            # for state in optimizer.state.values():
            #     for k, v in state.items():
            #         if torch.is_tensor(v):
            #             state[k] = v.to(device)

            logger.info("Loaded checkpoint '{}' (iter {})".format(
                cfg['training']['resume'], checkpoint["iteration"]))
        else:
            logger.error(
                "No checkpoint found at '{}'. Re-initializing params!".format(
                    cfg['training']['resume']))

    # Initialize meters for various metrics
    # logger.debug("initializing metrics")
    val_loss_meter = averageMeter()
    time_meter = averageMeter()

    # Setup other utility variables
    i = start_iter
    flag = True
    timer_training_start = time.time()

    logger.info("Starting training phase...")

    logger.debug("model device cuda?")
    logger.debug(next(model.parameters()).is_cuda)
    logger.debug("d_reg weight device:")
    logger.debug(task_weights["d_reg"].device)
    logger.debug("cls weight device:")
    logger.debug(task_weights["d_cls"].device)

    while i <= cfg['training']['train_iters'] and flag:
        for (images, labels) in trainloader:

            start_ts = time.time()
            scheduler.step()
            model.train()

            # Forward pass
            # logger.debug("sending images to device")
            images = images.to(device)
            optimizer.zero_grad()
            # logger.debug("forward pass")
            outputs = model(images)

            # Clip predicted depth to min/max
            # logger.debug("clamping outputs")
            if cfg["data"]["tasks"]["d_reg"]:
                if cfg["data"]["depth_scaling"] is not None:
                    if cfg["data"]["depth_scaling"] == "clip":
                        logger.warning("Using deprecated clip function!")
                        outputs["d_reg"] = torch.clamp(
                            outputs["d_reg"], 0, cfg["data"]["max_depth"])

            # Calculate single-task losses
            # logger.debug("calculate loss")
            st_loss = {}
            for task, loss_fn in loss_fns.items():
                labels[task] = labels[task].to(device)
                st_loss[task] = loss_fn(input=outputs[task],
                                        target=labels[task])

            # Calculate multi-task loss
            # logger.debug("calculate mt loss")
            mt_loss = 0
            if len(st_loss) > 1:
                for task, loss in st_loss.items():
                    s = task_weights[task]  # s := log(sigma^2)
                    r = s * 0.5  # regularization term
                    if task in ["d_cls"]:
                        w = torch.exp(-s)  # weighting (class.)
                    elif task in ["d_reg"]:
                        w = 0.5 * torch.exp(-s)  # weighting (regr.)
                    else:
                        raise ValueError("Weighting not implemented!")
                    mt_loss += loss * w + r
            else:
                mt_loss = list(st_loss.values())[0]

            # Backward pass
            # logger.debug("backward pass")
            mt_loss.backward()
            # logger.debug("update weights")
            optimizer.step()

            time_meter.update(time.time() - start_ts)

            # Output current training status
            # logger.debug("write log")
            if i == 0 or (i + 1) % cfg['training']['print_interval'] == 0:
                pad = str(len(str(cfg['training']['train_iters'])))
                print_str = ("Training Iteration: [{:>" + pad + "d}/{:d}]" +
                             "  Loss: {:>14.4f}" +
                             "  Time/Image: {:>7.4f}").format(
                                 i + 1, cfg['training']['train_iters'],
                                 mt_loss.item(), time_meter.avg /
                                 cfg['training']['batch_size'])
                logger.info(print_str)

                # Add training status to summaries
                writer.add_scalar('learning_rate',
                                  scheduler.get_lr()[0], i + 1)
                writer.add_scalar('batch_size', cfg['training']['batch_size'],
                                  i + 1)
                writer.add_scalar('loss/train_loss', mt_loss.item(), i + 1)
                for task, loss in st_loss.items():
                    writer.add_scalar("loss/single_task/" + task, loss, i + 1)
                for task, weight in task_weights.items():
                    writer.add_scalar("task_weights/" + task, weight, i + 1)
                time_meter.reset()

                # Add latest input image to summaries
                train_input = images[0].cpu().numpy()[::-1, :, :]
                writer.add_image("training/input", train_input, i + 1)

                # Add d_cls predictions and gt for latest sample to summaries
                if cfg["data"]["tasks"]["d_cls"]:
                    train_pred = outputs["d_cls"].detach().cpu().numpy().max(
                        0)[1].astype(np.uint8)
                    # train_pred = np.array(outputs["d_cls"][0].data.max(0)[1],
                    #                       dtype=np.uint8)
                    train_pred = t_loader.decode_segmap(train_pred)
                    train_pred = torch.tensor(np.rollaxis(train_pred, 2, 0))
                    writer.add_image("training/d_cls/prediction", train_pred,
                                     i + 1)

                    train_gt = t_loader.decode_segmap(
                        labels["d_cls"][0].data.cpu().numpy())
                    train_gt = torch.tensor(np.rollaxis(train_gt, 2, 0))
                    writer.add_image("training/d_cls/label", train_gt, i + 1)

                # Add d_reg predictions and gt for latest sample to summaries
                if cfg["data"]["tasks"]["d_reg"]:
                    train_pred = outputs["d_reg"][0]
                    train_pred = np.array(train_pred.data.cpu().numpy())
                    train_pred = t_loader.visualize_depths(
                        t_loader.restore_metric_depths(train_pred))
                    writer.add_image("training/d_reg/prediction", train_pred,
                                     i + 1)

                    train_gt = labels["d_reg"][0].data.cpu().numpy()
                    train_gt = t_loader.visualize_depths(
                        t_loader.restore_metric_depths(train_gt))
                    if len(train_gt.shape) < 3:
                        train_gt = np.expand_dims(train_gt, axis=0)
                    writer.add_image("training/d_reg/label", train_gt, i + 1)

            # Run mid-training validation
            if (i + 1) % cfg['training']['val_interval'] == 0:
                # or (i + 1) == cfg['training']['train_iters']:

                # Output current status
                # logger.debug("Training phase took " + str(timedelta(seconds=time.time() - timer_training_start)))
                timer_validation_start = time.time()
                logger.info("Validating model at training iteration" +
                            " {}...".format(i + 1))

                # Evaluate validation set
                model.eval()
                with torch.no_grad():
                    i_val = 0
                    pbar = tqdm(total=len(valloader), unit="batch")
                    for (images_val, labels_val) in valloader:

                        # Forward pass
                        images_val = images_val.to(device)
                        outputs_val = model(images_val)

                        # Clip predicted depth to min/max
                        if cfg["data"]["tasks"]["d_reg"]:
                            if cfg["data"]["depth_scaling"] is None:
                                logger.warning(
                                    "Using deprecated clip function!")
                                outputs_val["d_reg"] = torch.clamp(
                                    outputs_val["d_reg"], 0,
                                    cfg["data"]["max_depth"])
                            else:
                                outputs_val["d_reg"] = torch.clamp(
                                    outputs_val["d_reg"], 0, 1)

                        # Calculate single-task losses
                        st_loss_val = {}
                        for task, loss_fn in loss_fns.items():
                            labels_val[task] = labels_val[task].to(device)
                            st_loss_val[task] = loss_fn(
                                input=outputs_val[task],
                                target=labels_val[task])

                        # Calculate multi-task loss
                        mt_loss_val = 0
                        if len(st_loss_val) > 1:
                            for task, loss_val in st_loss_val.items():
                                s = task_weights[task]
                                r = s * 0.5
                                if task in ["d_cls"]:
                                    w = torch.exp(-s)
                                elif task in ["d_reg"]:
                                    w = 0.5 * torch.exp(-s)
                                else:
                                    raise ValueError(
                                        "Weighting not implemented!")
                                mt_loss_val += loss_val * w + r
                        else:
                            mt_loss_val = list(st_loss_val.values())[0]

                        # Accumulate metrics for summaries
                        val_loss_meter.update(mt_loss_val.item())

                        if cfg["data"]["tasks"]["d_cls"]:
                            running_metrics_val_d_cls.update(
                                labels_val["d_cls"].data.cpu().numpy(),
                                outputs_val["d_cls"].data.cpu().numpy().argmax(
                                    1))

                        if cfg["data"]["tasks"]["d_reg"]:
                            running_metrics_val_d_reg.update(
                                v_loader.restore_metric_depths(
                                    outputs_val["d_reg"].data.cpu().numpy()),
                                v_loader.restore_metric_depths(
                                    labels_val["d_reg"].data.cpu().numpy()))

                        # Update progressbar
                        i_val += 1
                        pbar.update()

                        # Stop validation early if max_iter key is set
                        if "max_iter" in cfg["validation"].keys() and \
                                i_val >= cfg["validation"]["max_iter"]:
                            logger.warning("Stopped validation early " +
                                           "because max_iter was reached")
                            break

                # Add sample input images from latest batch to summaries
                num_img_samples_val = min(len(images_val), NUM_IMG_SAMPLES)
                for cur_s in range(0, num_img_samples_val):
                    val_input = images_val[cur_s].cpu().numpy()[::-1, :, :]
                    writer.add_image(
                        "validation_sample_" + str(cur_s + 1) + "/input",
                        val_input, i + 1)

                    # Add predictions/ground-truth for d_cls to summaries
                    if cfg["data"]["tasks"]["d_cls"]:
                        val_pred = outputs_val["d_cls"][cur_s].data.max(0)[1]
                        val_pred = np.array(val_pred, dtype=np.uint8)
                        val_pred = t_loader.decode_segmap(val_pred)
                        val_pred = torch.tensor(np.rollaxis(val_pred, 2, 0))
                        writer.add_image(
                            "validation_sample_" + str(cur_s + 1) +
                            "/prediction_d_cls", val_pred, i + 1)
                        val_gt = t_loader.decode_segmap(
                            labels_val["d_cls"][cur_s].data.cpu().numpy())
                        val_gt = torch.tensor(np.rollaxis(val_gt, 2, 0))
                        writer.add_image(
                            "validation_sample_" + str(cur_s + 1) +
                            "/label_d_cls", val_gt, i + 1)

                    # Add predictions/ground-truth for d_reg to summaries
                    if cfg["data"]["tasks"]["d_reg"]:
                        val_pred = outputs_val["d_reg"][cur_s].cpu().numpy()
                        val_pred = v_loader.visualize_depths(
                            v_loader.restore_metric_depths(val_pred))
                        writer.add_image(
                            "validation_sample_" + str(cur_s + 1) +
                            "/prediction_d_reg", val_pred, i + 1)

                        val_gt = labels_val["d_reg"][cur_s].data.cpu().numpy()
                        val_gt = v_loader.visualize_depths(
                            v_loader.restore_metric_depths(val_gt))
                        if len(val_gt.shape) < 3:
                            val_gt = np.expand_dims(val_gt, axis=0)
                        writer.add_image(
                            "validation_sample_" + str(cur_s + 1) +
                            "/label_d_reg", val_gt, i + 1)

                # Add evaluation metrics for d_cls predictions to summaries
                if cfg["data"]["tasks"]["d_cls"]:
                    score, class_iou = running_metrics_val_d_cls.get_scores()
                    for k, v in score.items():
                        writer.add_scalar(
                            'validation/d_cls_metrics/{}'.format(k[:-3]), v,
                            i + 1)
                    for k, v in class_iou.items():
                        writer.add_scalar(
                            'validation/d_cls_metrics/class_{}'.format(k), v,
                            i + 1)
                    running_metrics_val_d_cls.reset()

                # Add evaluation metrics for d_reg predictions to summaries
                if cfg["data"]["tasks"]["d_reg"]:
                    writer.add_scalar('validation/d_reg_metrics/rel',
                                      running_metrics_val_d_reg.rel, i + 1)
                    running_metrics_val_d_reg.reset()

                # Add validation loss to summaries
                writer.add_scalar('loss/val_loss', val_loss_meter.avg, i + 1)

                # Output current status
                logger.info(
                    ("Validation Loss at Iteration {}: " + "{:>14.4f}").format(
                        i + 1, val_loss_meter.avg))
                val_loss_meter.reset()
                # logger.debug("Validation phase took {}".format(timedelta(seconds=time.time() - timer_validation_start)))
                timer_training_start = time.time()

                # Close progressbar
                pbar.close()

            # Save checkpoint
            if (i + 1) % cfg['training']['checkpoint_interval'] == 0 or \
               (i + 1) == cfg['training']['train_iters'] or \
               i == 0:
                state = {
                    "iteration": i + 1,
                    "model_state": model.state_dict(),
                    "task_weights": task_weights,
                    "optimizer_state": optimizer.state_dict(),
                    "scheduler_state": scheduler.state_dict()
                }
                save_path = os.path.join(
                    writer.file_writer.get_logdir(),
                    "{}_{}_checkpoint_iter_".format(cfg['model']['arch'],
                                                    cfg['data']['dataset']) +
                    str(i + 1) + ".pkl")
                torch.save(state, save_path)
                logger.info("Saved checkpoint at iteration {} to: {}".format(
                    i + 1, save_path))

            # Stop training if current iteration == max iterations
            if (i + 1) == cfg['training']['train_iters']:
                flag = False
                break

            i += 1
Ejemplo n.º 20
0
def train(cfg, writer, logger):
    
    # Setup random seeds to a deterministic value for reproducibility
    # seed = 1337
    # torch.manual_seed(seed)
    # torch.cuda.manual_seed(seed)
    # np.random.seed(seed)
    # random.seed(seed)
    # np.random.default_rng(seed)

    # Setup Augmentations
    augmentations = cfg.train.augment
    logger.info(f'using augments: {augmentations}')
    data_aug = get_composed_augmentations(augmentations)

    # Setup Dataloader
    data_loader = get_loader(cfg.data.dataloader)
    data_path = cfg.data.path

    logger.info("Using dataset: {}".format(data_path))

    t_loader = data_loader(
        data_path,
        # transform=None,
        # time_shuffle = cfg.data.time_shuffle,
        # to_tensor=False,
        data_format = cfg.data.format,
        split=cfg.data.train_split,
        norm = cfg.data.norm,
        augments=data_aug
        )

    v_loader = data_loader(
        data_path,
        # transform=None,
        # time_shuffle = cfg.data.time_shuffle,
        # to_tensor=False,
        data_format = cfg.data.format,
        split=cfg.data.val_split,
        )
    train_data_len = len(t_loader)
    logger.info(f'num of train samples: {train_data_len} \nnum of val samples: {len(v_loader)}')

    batch_size = cfg.train.batch_size
    epoch = cfg.train.epoch
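    # Convert the epoch budget into an iteration budget:
    # iterations per epoch = ceil(num_train_samples / batch_size).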
    train_iter = int(np.ceil(train_data_len / batch_size) * epoch)
    logger.info(f'total train iter: {train_iter}')

    trainloader = data.DataLoader(t_loader,
                                  batch_size=batch_size, 
                                  num_workers=cfg.train.n_workers, 
                                  shuffle=True,
                                  persistent_workers=True,
                                  drop_last=True)

    valloader = data.DataLoader(v_loader, 
                                batch_size=10, 
                                # persis
                                num_workers=cfg.train.n_workers,)

    # Setup Model
    device = f'cuda:{cfg.gpu[0]}'
    model = get_model(cfg.model, 2).to(device)
    input_size = (cfg.model.input_nbr, 512, 512)
    logger.info(f"Using Model: {cfg.model.arch}")
    # logger.info(f'model summary: {summary(model, input_size=(input_size, input_size), is_complex=True)}')
    model = torch.nn.DataParallel(model, device_ids=cfg.gpu)      # automatic multi-GPU execution; works well
    
    # Setup optimizer, lr_scheduler and loss function
    optimizer_cls = get_optimizer(cfg)
    optimizer_params = {k:v for k, v in vars(cfg.train.optimizer).items()
                        if k not in ('name', 'wrap')}
    optimizer = optimizer_cls(model.parameters(), **optimizer_params)
    logger.info("Using optimizer {}".format(optimizer))
    if hasattr(cfg.train.optimizer, 'wrap') and cfg.train.optimizer.wrap == 'lars':
        optimizer = LARS(optimizer=optimizer)
        logger.info(f'wrapped optimizer with {cfg.train.optimizer.wrap}')
    scheduler = get_scheduler(optimizer, cfg.train.lr)
    loss_fn = get_loss_function(cfg)
    logger.info(f"Using loss ,{str(cfg.train.loss)}")

    # load checkpoints
    val_cls_1_acc = 0
    best_cls_1_acc_now = 0
    best_cls_1_acc_iter_now = 0
    val_macro_OA = 0
    best_macro_OA_now = 0
    best_macro_OA_iter_now = 0
    start_iter = 0
    if cfg.train.resume is not None:
        if os.path.isfile(cfg.train.resume):
            logger.info(
                "Loading model and optimizer from checkpoint '{}'".format(cfg.train.resume)
            )

            # load model state
            checkpoint = torch.load(cfg.train.resume)
            model.load_state_dict(checkpoint["model_state"])
            optimizer.load_state_dict(checkpoint["optimizer_state"])
            scheduler.load_state_dict(checkpoint["scheduler_state"])
            # best_cls_1_acc_now = checkpoint["best_cls_1_acc_now"]
            # best_cls_1_acc_iter_now = checkpoint["best_cls_1_acc_iter_now"]

            start_iter = checkpoint["epoch"]
            logger.info(
                "Loaded checkpoint '{}' (iter {})".format(
                    cfg.train.resume, checkpoint["epoch"]
                )
            )

            # copy tensorboard files
            resume_src_dir = osp.split(cfg.train.resume)[0]
            # shutil.copytree(resume_src_dir, writer.get_logdir())
            for file in os.listdir(resume_src_dir):
                if not ('.log' in file or '.yml' in file or '_last_model' in file):
                # if 'events.out.tfevents' in file:
                    resume_dst_dir = writer.get_logdir()
                    fu.copy(osp.join(resume_src_dir, file), resume_dst_dir, )

        else:
            logger.info("No checkpoint found at '{}'".format(cfg.train.resume))

    # Setup Metrics
    running_metrics_val = runningScore(2)
    runing_metrics_train = runningScore(2)
    val_loss_meter = averageMeter()
    train_time_meter = averageMeter()

    # train
    it = start_iter
    train_start_time = time.time() 
    train_val_start_time = time.time()
    model.train()   
    while it < train_iter:
        for (file_a, file_b, label, mask) in trainloader:
            it += 1           
            file_a = file_a.to(device)            
            file_b = file_b.to(device)            
            label = label.to(device)            
            mask = mask.to(device)

            optimizer.zero_grad()
            # print(f'dtype: {file_a.dtype}')
            outputs = model(file_a, file_b)
            loss = loss_fn(input=outputs, target=label, mask=mask)
            loss.backward()

            # print('conv11: ', model.conv11.weight.grad, model.conv11.weight.grad.shape)
            # print('conv21: ', model.conv21.weight.grad, model.conv21.weight.grad.shape)
            # print('conv31: ', model.conv31.weight.grad, model.conv31.weight.grad.shape)

            # In PyTorch 1.1.0 and later, you should call `optimizer.step()` before `lr_scheduler.step()`
            optimizer.step()
            scheduler.step()
            
            # record the acc of the minibatch
            pred = outputs.max(1)[1].cpu().numpy()
            runing_metrics_train.update(label.cpu().numpy(), pred, mask.cpu().numpy())

            train_time_meter.update(time.time() - train_start_time)

            if it % cfg.train.print_interval == 0:
                # acc of the samples between print_interval
                score, _ = runing_metrics_train.get_scores()
                train_cls_0_acc, train_cls_1_acc = score['Acc']
                fmt_str = "Iter [{:d}/{:d}]  train Loss: {:.4f}  Time/Image: {:.4f},\n0:{:.4f}\n1:{:.4f}"
                print_str = fmt_str.format(it,
                                           train_iter,
                                           loss.item(),      #extracts the loss’s value as a Python float.
                                           train_time_meter.avg / cfg.train.batch_size,train_cls_0_acc, train_cls_1_acc)
                runing_metrics_train.reset()
                train_time_meter.reset()
                logger.info(print_str)
                writer.add_scalar('loss/train_loss', loss.item(), it)
                writer.add_scalars('metrics/train', {'cls_0':train_cls_0_acc, 'cls_1':train_cls_1_acc}, it)
                # writer.add_scalar('train_metrics/acc/cls_0', train_cls_0_acc, it)
                # writer.add_scalar('train_metrics/acc/cls_1', train_cls_1_acc, it)

            if it % cfg.train.val_interval == 0 or \
               it == train_iter:
                val_start_time = time.time()
                model.eval()            # change behavior like drop out
                with torch.no_grad():   # disable autograd, save memory usage
                    for (file_a_val, file_b_val, label_val, mask_val) in valloader:      
                        file_a_val = file_a_val.to(device)            
                        file_b_val = file_b_val.to(device)

                        outputs = model(file_a_val, file_b_val)
                        # tensor.max() returns the maximum value and its indices
                        pred = outputs.max(1)[1].cpu().numpy()
                        running_metrics_val.update(label_val.numpy(), pred, mask_val.numpy())
            
                        label_val = label_val.to(device)            
                        mask_val = mask_val.to(device)
                        val_loss = loss_fn(input=outputs, target=label_val, mask=mask_val)
                        val_loss_meter.update(val_loss.item())

                score, _ = running_metrics_val.get_scores()
                val_cls_0_acc, val_cls_1_acc = score['Acc']

                writer.add_scalar('loss/val_loss', val_loss_meter.avg, it)
                logger.info(f"Iter [{it}/{train_iter}], val Loss: {val_loss_meter.avg:.4f} Time/Image: {(time.time()-val_start_time)/len(v_loader):.4f}\n0: {val_cls_0_acc:.4f}\n1:{val_cls_1_acc:.4f}")
                # lr_now = optimizer.param_groups[0]['lr']
                # logger.info(f'lr: {lr_now}')
                # writer.add_scalar('lr', lr_now, it+1)

                logger.info('0: {:.4f}\n1:{:.4f}'.format(val_cls_0_acc, val_cls_1_acc))
                writer.add_scalars('metrics/val', {'cls_0':val_cls_0_acc, 'cls_1':val_cls_1_acc}, it)
                # writer.add_scalar('val_metrics/acc/cls_0', val_cls_0_acc, it)
                # writer.add_scalar('val_metrics/acc/cls_1', val_cls_1_acc, it)

                val_loss_meter.reset()
                running_metrics_val.reset()

                # OA=score["Overall_Acc"]
                val_macro_OA = (val_cls_0_acc+val_cls_1_acc)/2
                if val_macro_OA >= best_macro_OA_now and it>200:
                    best_macro_OA_now = val_macro_OA
                    best_macro_OA_iter_now = it
                    state = {
                        "epoch": it,
                        "model_state": model.state_dict(),
                        "optimizer_state": optimizer.state_dict(),
                        "scheduler_state": scheduler.state_dict(),
                        "best_macro_OA_now": best_macro_OA_now,
                        'best_macro_OA_iter_now':best_macro_OA_iter_now,
                    }
                    save_path = os.path.join(writer.file_writer.get_logdir(), "{}_{}_best_model.pkl".format(cfg.model.arch,cfg.data.dataloader))
                    torch.save(state, save_path)

                    logger.info("best OA now =  %.8f" % (best_macro_OA_now))
                    logger.info("best OA iter now= %d" % (best_macro_OA_iter_now))

                train_val_time = time.time() - train_val_start_time
                remain_time = train_val_time * (train_iter-it) / it
                m, s = divmod(remain_time, 60)
                h, m = divmod(m, 60)
                if remain_time > 0:
                    train_time = "Remaining train time = %d hours %d minutes %d seconds \n" % (h, m, s)
                else:
                    train_time = "Remaining train time: train completed.\n"
                logger.info(train_time)
                model.train()

            train_start_time = time.time() 

    logger.info("best OA now =  %.8f" % (best_macro_OA_now))
    logger.info("best OA iter now= %d" % (best_macro_OA_iter_now))

    state = {
            "epoch": it,
            "model_state": model.state_dict(),
            "optimizer_state": optimizer.state_dict(),
            "scheduler_state": scheduler.state_dict(),
            "best_macro_OA_now": best_macro_OA_now,
            'best_macro_OA_iter_now':best_macro_OA_iter_now,
            }
    save_path = os.path.join(writer.file_writer.get_logdir(), "{}_{}_last_model.pkl".format(cfg.model.arch, cfg.data.dataloader))
    torch.save(state, save_path)
Ejemplo n.º 21
0
def eval(cfg, writer, logger, logdir):

    # Setup seeds
    #torch.manual_seed(cfg.get("seed", 1337))
    #torch.cuda.manual_seed(cfg.get("seed", 1337))
    #np.random.seed(cfg.get("seed", 1337))
    #random.seed(cfg.get("seed", 1337))

    # Setup device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Setup Augmentations
    augmentations = cfg["training"].get("augmentations", None)
    data_aug = get_composed_augmentations(augmentations)

    # Setup Dataloader
    data_loader = get_loader(cfg["data"]["dataloader_type"])

    data_root = cfg["data"]["data_root"]
    presentation_root = cfg["data"]["presentation_root"]

    v_loader = data_loader(data_root=data_root,
                           presentation_root=presentation_root,
                           is_transform=True,
                           img_size=(cfg["data"]["img_rows"],
                                     cfg["data"]["img_cols"]),
                           augmentations=data_aug,
                           test_mode=True)

    n_classes = v_loader.n_classes

    valloader = data.DataLoader(v_loader,
                                batch_size=cfg["training"]["batch_size"],
                                num_workers=cfg["training"]["n_workers"],
                                shuffle=False)

    # Setup Metrics
    # running_metrics_train = runningScore(n_classes)
    running_metrics_val = runningScore(n_classes)

    # Setup Model
    model = get_model(cfg["model"], n_classes, defaultParams).to(device)

    #model = torch.nn.DataParallel(model, device_ids=range(torch.cuda.device_count()))

    # Setup optimizer, lr_scheduler and loss function
    optimizer_cls = get_optimizer(cfg)
    optimizer_params = {
        k: v
        for k, v in cfg["training"]["optimizer"].items() if k != "name"
    }

    optimizer = optimizer_cls(model.parameters(), **optimizer_params)
    logger.info("Using optimizer {}".format(optimizer))

    scheduler = get_scheduler(optimizer, cfg["training"]["lr_schedule"])

    loss_fn = get_loss_function(cfg)
    logger.info("Using loss {}".format(loss_fn))

    start_iter = 0
    if cfg["training"]["resume"] is not None:
        if os.path.isfile(cfg["training"]["resume"]):
            logger.info(
                "Loading model and optimizer from checkpoint '{}'".format(
                    cfg["training"]["resume"]))
            checkpoint = torch.load(cfg["training"]["resume"])
            model.load_state_dict(checkpoint["model_state"])
            optimizer.load_state_dict(checkpoint["optimizer_state"])
            scheduler.load_state_dict(checkpoint["scheduler_state"])
            start_iter = checkpoint["epoch"]
            logger.info("Loaded checkpoint '{}' (iter {})".format(
                cfg["training"]["resume"], checkpoint["epoch"]))
        else:
            logger.info("No checkpoint found at '{}'".format(
                cfg["training"]["resume"]))

    # train_loss_meter = averageMeter()
    val_loss_meter = averageMeter()
    time_meter = averageMeter()

    best_iou = -100.0
    i = 0
    pres_results = []  # a final list of all <image, label, output> of all presentations
    img_list = []

    while i < cfg["training"]["num_presentations"]:

        #                 #
        #  TESTING PHASE  #
        #                 #
        i += 1

        training_state_dict = model.state_dict()
        hebb = model.initialZeroHebb().to(device)
        valloader.dataset.random_select()
        start_ts = time.time()

        for idx, (images_val, labels_val) in enumerate(
                valloader, 1):  # get a single test presentation

            img = torchvision.utils.make_grid(images_val).numpy()
            img = np.transpose(img, (1, 2, 0))
            img = img[:, :, ::-1]
            img_list.append(img)
            pres_results.append(decode_segmap(labels_val.numpy()))
            images_val = images_val.to(device)
            labels_val = labels_val.to(device)
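            # The first five samples of each presentation are processed without
            # gradient updates (the Hebbian state `hebb` is still carried
            # forward); the remaining samples take an optimizer step and are
            # scored, as in the if/else below.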

            if idx <= 5:
                model.eval()
                with torch.no_grad():
                    outputs, hebb = model(images_val,
                                          labels_val,
                                          hebb,
                                          device,
                                          test_mode=False)
            else:
                model.train()
                optimizer.zero_grad()
                outputs, hebb = model(images_val,
                                      labels_val,
                                      hebb,
                                      device,
                                      test_mode=True)
                loss = loss_fn(input=outputs, target=labels_val)
                loss.backward()
                optimizer.step()

                pred = outputs.data.max(1)[1].cpu().numpy()
                gt = labels_val.data.cpu().numpy()

                running_metrics_val.update(gt, pred)
                val_loss_meter.update(loss.item())

                # Turning the image, label, and output into plottable formats
                '''img = torchvision.utils.make_grid(images_val.cpu()).numpy()
                img = np.transpose(img, (1, 2, 0))
                img = img[:, :, ::-1]
                print("img.shape",img.shape)
                print("gt.shape and type",gt.shape, gt.dtype)
                print("pred.shape and type",pred.shape, pred.dtype)'''

                cla, cnt = np.unique(pred, return_counts=True)
                print("Unique classes predicted = {}, counts = {}".format(
                    cla, cnt))
                #pres_results.append(img)
                #pres_results.append(decode_segmap(gt))
                pres_results.append(decode_segmap(pred))

        time_meter.update(time.time() -
                          start_ts)  # -> time taken per presentation
        model.load_state_dict(
            training_state_dict)  # revert back to training parameters

        # Display presentations stats
        fmt_str = "Pres [{:d}/{:d}]  Loss: {:.4f}  Time/Pres: {:.4f}"
        print_str = fmt_str.format(
            i,
            cfg["training"]["num_presentations"],
            loss.item(),
            time_meter.avg / cfg["training"]["batch_size"],
        )
        print(print_str)
        logger.info(print_str)
        writer.add_scalar("loss/test_loss", loss.item(), i + 1)
        time_meter.reset()

        writer.add_scalar("loss/val_loss", val_loss_meter.avg, i + 1)
        logger.info("Iter %d Loss: %.4f" % (i + 1, val_loss_meter.avg))

        # Display presentation metrics
        score, class_iou = running_metrics_val.get_scores()
        for k, v in score.items():
            print(k, v)
            logger.info("{}: {}".format(k, v))
            writer.add_scalar("val_metrics/{}".format(k), v, i + 1)

        #for k, v in class_iou.items():
        #    logger.info("{}: {}".format(k, v))
        #    writer.add_scalar("val_metrics/cls_{}".format(k), v, i + 1)

        val_loss_meter.reset()
        running_metrics_val.reset()

    # save presentations to a png image file
    save_presentations(pres_results=pres_results,
                       num_pres=cfg["training"]["num_presentations"],
                       num_col=7,
                       logdir=logdir,
                       name="pre_results.png")
    save_presentations(pres_results=img_list,
                       num_pres=cfg["training"]["num_presentations"],
                       num_col=6,
                       logdir=logdir,
                       name="img_list.png")
Ejemplo n.º 22
0
        data_loader = get_loader(cfg["data"]["dataset"])
        data_path = cfg["data"]["path"]

        # Load communication label (note that some datasets do not provide this)
        if 'commun_label' in cfg["data"]:
            if_commun_label = cfg["data"]['commun_label']
        else:
            if_commun_label = 'None'

        # dataloaders
        t_loader = data_loader(data_path,
                               is_transform=True,
                               split=cfg["data"]["train_split"],
                               img_size=(cfg["data"]["img_rows"],
                                         cfg["data"]["img_cols"]),
                               augmentations=get_composed_augmentations(
                                   cfg["training"].get("augmentations", None)),
                               target_view=cfg["data"]["target_view"],
                               commun_label=if_commun_label)

        v_loader = data_loader(data_path,
                               is_transform=True,
                               split=cfg["data"]["val_split"],
                               img_size=(cfg["data"]["img_rows"],
                                         cfg["data"]["img_cols"]),
                               target_view=cfg["data"]["target_view"],
                               commun_label=if_commun_label)

        trainloader = data.DataLoader(t_loader,
                                      batch_size=cfg["training"]["batch_size"],
                                      num_workers=cfg["training"]["n_workers"],
                                      shuffle=True,
Ejemplo n.º 23
0
def test(args, cfg):
    ''' use the trained model to test '''

    # Setup random seeds
    torch.manual_seed(cfg.get('seed', 137))
    torch.cuda.manual_seed(cfg.get('seed', 137))
    np.random.seed(cfg.get('seed', 137))
    random.seed(cfg.get('seed', 137))

    # setup augmentations
    augs = cfg['train'].get('augmentations', None)
    data_aug = get_composed_augmentations(augs)

    # setup dataloader
    data_loader = get_loader(cfg['data']['dataloader'])
    data_path = cfg['data']['path']
    print('using dataset:', data_path, ', dataloader:',
          cfg['data']['dataloader'])
    if cfg['test']['train_set']:
        test_train_loader = data_loader(root=data_path,
                                        transform=None,
                                        split='test_train',
                                        augmentations=data_aug)
    if cfg['test']['val_set']:
        test_val_loader = data_loader(root=data_path,
                                      transform=None,
                                      split='test_val',
                                      augmentations=data_aug)

    # setup model
    model = get_model(cfg['model'])
    print('using model:', cfg['model']['arch'])
    device = f'cuda:{cuda_idx[0]}'
    model = model.to(device)
    # don't need run on multiple gpus
    # model = nn.DataParallel(model, device_ids=cuda_idx)

    # load model
    pth_path = cfg['test']['pth']
    if osp.isfile(pth_path):
        print('load model from checkpoint', pth_path)
        check_point = torch.load(pth_path)
        model.load_state_dict(check_point['model_state'])
    else:
        raise FileNotFoundError('can not find the specified .pth file')

    # setup metrics
    inc_metrics = runningScore()
    current_metrics = runningScore()

    # test
    tile_size = cfg['data']['tile_size']
    if not isinstance(tile_size, (tuple, list)):
        tile_size = (tile_size, tile_size)
    if cfg['test']['train_set']:
        test_loader = test_train_loader
    else:
        test_loader = test_val_loader
    tiles_per_image = test_loader.tiles_per_image
    model.eval()
    with torch.no_grad():
        for file_a, file_b, label, mask, label_path in test_loader:
            regid_pred = np.zeros_like(label)
            final_pred = np.zeros_like(label)

            # tile-wise change detection
            for tile_idx in range(tiles_per_image):
                tile_coords = lbm.get_corrds_from_slice_idx(
                    (512, 512), tile_size, tile_idx)
                tile_a = file_a[tile_coords[0]:tile_coords[0] + tile_size[0],
                                tile_coords[1]:tile_coords[1] +
                                tile_size[1], :]
                tile_b = file_b[tile_coords[0]:tile_coords[0] + tile_size[0],
                                tile_coords[1]:tile_coords[1] +
                                tile_size[1], :]
                tile = torch.cat((tile_a, tile_b), dim=0)
                tile_outputs = model(tile)
                tile_pred = tile_outputs.max(dim=0)[1]
                regid_pred[tile_coords[0]:tile_coords[0] + tile_size[0],
                           tile_coords[1]:tile_coords[1] +
                           tile_size[1]] = tile_pred

            # use file a to make superpixel segmentation
            segs = slic(file_a,
                        n_segments=1024,
                        compactness=0.5,
                        min_size_factor=0.5,
                        enforce_connectivity=True,
                        convert2lab=False)
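            # Majority vote inside each superpixel: mark a superpixel as
            # changed when more than half of its pixels were predicted as
            # changed by the tile-wise classifier.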
            for spix_idx in range(segs.max() + 1):
                spix_region = segs == spix_idx
                final_pred[spix_region] = regid_pred[spix_region].sum(
                ) > spix_region.sum() / 2

            # evaluate and save
            inc_metrics.update(label, final_pred, mask)
            current_metrics.update(label, final_pred, mask)
            score, cls_iou = current_metrics.get_scores()
            for k, v in cls_iou.items():
                print('{}: {}'.format(k, v))
            current_metrics.reset()
            save_path = label_path[:-4] + '_pred_.png'
            cv2.imwrite(save_path, final_pred)

        # ultimate evaluate
        score, cls_iou = inc_metrics.get_scores()
        for k, v in score.items():
            print('{}: {}'.format(k, v))
        for k, v in cls_iou.items():
            print('{}: {}'.format(k, v))
Ejemplo n.º 24
0
def train(cfg, writer, logger, start_iter=0, model_only=False, gpu=-1, save_dir=None):

    # Setup seeds and config
    torch.manual_seed(cfg.get("seed", 1337))
    torch.cuda.manual_seed(cfg.get("seed", 1337))
    np.random.seed(cfg.get("seed", 1337))
    random.seed(cfg.get("seed", 1337))
    
    # Setup device
    if gpu == -1:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    else:
        device = torch.device("cuda:%d" %gpu if torch.cuda.is_available() else "cpu")

    # Setup Augmentations
    augmentations = cfg["training"].get("augmentations", None)
    if cfg["data"]["dataset"] == "softmax_cityscapes_convention":
        data_aug = get_composed_augmentations_softmax(augmentations)
    else:
        data_aug = get_composed_augmentations(augmentations)

    # Setup Dataloader
    data_loader = get_loader(cfg["data"]["dataset"])
    data_path = cfg["data"]["path"]

    t_loader = data_loader(
        data_path,
        config = cfg["data"],
        is_transform=True,
        split=cfg["data"]["train_split"],
        img_size=(cfg["data"]["img_rows"], cfg["data"]["img_cols"]),
        augmentations=data_aug,
    )
    v_loader = data_loader(
        data_path,
        config = cfg["data"],
        is_transform=True,
        split=cfg["data"]["val_split"],
        img_size=(cfg["data"]["img_rows"], cfg["data"]["img_cols"]),
    )

    sampler = None
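    # Optional class-balanced sampling: when a "sampling" entry is present in
    # the data config, per-sample weights from get_sampling_weights drive a
    # WeightedRandomSampler (with replacement) instead of plain shuffling.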
    if "sampling" in cfg["data"]:
        sampler = data.WeightedRandomSampler(
            weights = get_sampling_weights(t_loader, cfg["data"]["sampling"]),
            num_samples = len(t_loader),
            replacement = True
        )
    n_classes = t_loader.n_classes
    trainloader = data.DataLoader(
        t_loader,
        batch_size=cfg["training"]["batch_size"],
        num_workers=cfg["training"]["n_workers"],
        sampler=sampler,
        shuffle=(sampler is None),
    )
    valloader = data.DataLoader(
        v_loader, batch_size=cfg["training"]["batch_size"], num_workers=cfg["training"]["n_workers"]
    )

    # Setup Metrics
    running_metrics_val = {"seg": runningScoreSeg(n_classes)}
    if "classifiers" in cfg["data"]:
        for name, classes in cfg["data"]["classifiers"].items():
            running_metrics_val[name] = runningScoreClassifier( len(classes) )
    if "bin_classifiers" in cfg["data"]:
        for name, classes in cfg["data"]["bin_classifiers"].items():
            running_metrics_val[name] = runningScoreClassifier(2)

    # Setup Model
    model = get_model(cfg["model"], n_classes).to(device)
    
    total_params = sum(p.numel() for p in model.parameters())
    print('Parameters:', total_params)

    if gpu == -1:
        model = torch.nn.DataParallel(model, device_ids=range(torch.cuda.device_count()))
    else:
        model = torch.nn.DataParallel(model, device_ids=[gpu])
    
    model.apply(weights_init)
    pretrained_path='weights/hardnet_petite_base.pth'
    weights = torch.load(pretrained_path)
    model.module.base.load_state_dict(weights)

    # Setup optimizer, lr_scheduler and loss function
    optimizer_cls = get_optimizer(cfg)
    optimizer_params = {k: v for k, v in cfg["training"]["optimizer"].items() if k != "name"}

    optimizer = optimizer_cls(model.parameters(), **optimizer_params)
    print("Using optimizer {}".format(optimizer))

    scheduler = get_scheduler(optimizer, cfg["training"]["lr_schedule"])
    loss_dict = get_loss_function(cfg, device)

    if cfg["training"]["resume"] is not None:
        if os.path.isfile(cfg["training"]["resume"]):
            logger.info(
                "Loading model and optimizer from checkpoint '{}'".format(cfg["training"]["resume"])
            )
            checkpoint = torch.load(cfg["training"]["resume"], map_location=device)
            model.load_state_dict(checkpoint["model_state"], strict=False)
            if not model_only:
                optimizer.load_state_dict(checkpoint["optimizer_state"])
                scheduler.load_state_dict(checkpoint["scheduler_state"])
                start_iter = checkpoint["epoch"]
            logger.info(
                "Loaded checkpoint '{}' (iter {})".format(
                    cfg["training"]["resume"], checkpoint["epoch"]
                )
            )
        else:
            logger.info("No checkpoint found at '{}'".format(cfg["training"]["resume"]))

    if cfg["training"]["finetune"] is not None:
        if os.path.isfile(cfg["training"]["finetune"]):
            logger.info(
                "Loading model and optimizer from checkpoint '{}'".format(cfg["training"]["finetune"])
            )
            checkpoint = torch.load(cfg["training"]["finetune"])
            model.load_state_dict(checkpoint["model_state"])

    val_loss_meter = averageMeter()
    time_meter = averageMeter()

    best_iou = -100.0
    i = start_iter
    flag = True
    loss_all = 0
    loss_n = 0

    while i <= cfg["training"]["train_iters"] and flag:
        for (images, label_dict, _) in trainloader:
            i += 1
            start_ts = time.time()
            scheduler.step()
            model.train()

            images = images.to(device)
            optimizer.zero_grad()
            output_dict = model(images)

            loss = compute_loss(    # considers key names in loss_dict and output_dict
                loss_dict, images, label_dict, output_dict, device, t_loader
            )
            
            loss.backward()         # backprops sum of loss tensors, frozen components will have no grad_fn
            optimizer.step()
            c_lr = scheduler.get_lr()

            if i%1000 == 0:             # log images, seg ground truths, predictions
                pred_array = output_dict["seg"].data.max(1)[1].cpu().numpy()
                gt_array = label_dict["seg"].data.cpu().numpy()
                softmax_gt_array = None
                if "softmax" in label_dict:
                    softmax_gt_array = label_dict["softmax"].data.max(1)[1].cpu().numpy()
                write_images_to_board(t_loader, images, gt_array, pred_array, i, name = 'train', softmax_gt = softmax_gt_array)

                if save_dir is not None:
                    image_array = images.data.cpu().numpy().transpose(0, 2, 3, 1)
                    write_images_to_dir(t_loader, image_array, gt_array, pred_array, i, save_dir, name = 'train', softmax_gt = softmax_gt_array)

            time_meter.update(time.time() - start_ts)
            loss_all += loss.item()
            loss_n += 1
            
            if (i + 1) % cfg["training"]["print_interval"] == 0:
                fmt_str = "Iter [{:d}/{:d}]  Loss: {:.4f}  Time/Image: {:.4f}  lr={:.6f}"
                print_str = fmt_str.format(
                    i + 1,
                    cfg["training"]["train_iters"],
                    loss_all / loss_n,
                    time_meter.avg / cfg["training"]["batch_size"],
                    c_lr[0],
                )

                print(print_str)
                logger.info(print_str)
                writer.add_scalar("loss/train_loss", loss.item(), i + 1)
                time_meter.reset()

            if (i + 1) % cfg["training"]["val_interval"] == 0 or (i + 1) == cfg["training"][
                "train_iters"
            ]:
                torch.cuda.empty_cache()
                model.eval() # set batchnorm and dropouts to work in eval mode
                loss_all = 0
                loss_n = 0
                with torch.no_grad(): # Deactivate torch autograd engine, less memusage
                    for i_val, (images_val, label_dict_val, _) in tqdm(enumerate(valloader)):
                        
                        images_val = images_val.to(device)
                        output_dict = model(images_val)
                        
                        val_loss = compute_loss(
                            loss_dict, images_val, label_dict_val, output_dict, device, v_loader
                        )
                        val_loss_meter.update(val_loss.item())

                        for name, metrics in running_metrics_val.items():
                            gt_array = label_dict_val[name].data.cpu().numpy()
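                            # Heads trained with an L1 objective emit signed
                            # scores; threshold at zero and map {-1, +1} to
                            # {0, 1} before scoring, as in the sign() branch
                            # below.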
                            if name+'_loss' in cfg['training'] and cfg['training'][name+'_loss']['name'] == 'l1':  # for binary classification
                                pred_array = output_dict[name].data.cpu().numpy()
                                pred_array = np.sign(pred_array)
                                pred_array[pred_array == -1] = 0
                                gt_array[gt_array == -1] = 0
                            else:
                                pred_array = output_dict[name].data.max(1)[1].cpu().numpy()

                            metrics.update(gt_array, pred_array)

                softmax_gt_array = None # log validation images
                pred_array = output_dict["seg"].data.max(1)[1].cpu().numpy()
                gt_array = label_dict_val["seg"].data.cpu().numpy()
                if "softmax" in label_dict_val:
                    softmax_gt_array = label_dict_val["softmax"].data.max(1)[1].cpu().numpy()
                write_images_to_board(v_loader, images_val, gt_array, pred_array, i, 'validation', softmax_gt = softmax_gt_array)
                if save_dir is not None:
                    images_val = images_val.cpu().numpy().transpose(0, 2, 3, 1)
                    write_images_to_dir(v_loader, images_val, gt_array, pred_array, i, save_dir, name='validation', softmax_gt = softmax_gt_array)

                logger.info("Iter %d Val Loss: %.4f" % (i + 1, val_loss_meter.avg))
                writer.add_scalar("loss/val_loss", val_loss_meter.avg, i + 1)

                for name, metrics in running_metrics_val.items():
                    
                    overall, classwise = metrics.get_scores()
                    
                    for k, v in overall.items():
                        logger.info("{}_{}: {}".format(name, k, v))
                        writer.add_scalar("val_metrics/{}_{}".format(name, k), v, i + 1)

                        if k == cfg["training"]["save_metric"]:
                            curr_performance = v

                    for metric_name, metric in classwise.items():
                        for k, v in metric.items():
                            logger.info("{}_{}_{}: {}".format(name, metric_name, k, v))
                            writer.add_scalar("val_metrics/{}_{}_{}".format(name, metric_name, k), v, i + 1)

                    metrics.reset()
                
                state = {
                      "epoch": i + 1,
                      "model_state": model.state_dict(),
                      "optimizer_state": optimizer.state_dict(),
                      "scheduler_state": scheduler.state_dict(),
                }
                save_path = os.path.join(
                    writer.file_writer.get_logdir(),
                    "{}_{}_checkpoint.pkl".format(cfg["model"]["arch"], cfg["data"]["dataset"]),
                )
                torch.save(state, save_path)

                if curr_performance >= best_iou:
                    best_iou = curr_performance
                    state = {
                        "epoch": i + 1,
                        "model_state": model.state_dict(),
                        "best_iou": best_iou,
                    }
                    save_path = os.path.join(
                        writer.file_writer.get_logdir(),
                        "{}_{}_best_model.pkl".format(cfg["model"]["arch"], cfg["data"]["dataset"]),
                    )
                    torch.save(state, save_path)
                torch.cuda.empty_cache()

            if (i + 1) == cfg["training"]["train_iters"]:
                flag = False
                break
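
The validation block above keeps two checkpoint files per run: a rolling
checkpoint rewritten after every validation pass, and a separate best-model
file that is only overwritten when the metric named by
cfg["training"]["save_metric"] improves. A minimal self-contained sketch of
that pattern (save_if_best and its argument names are illustrative, not taken
from the repository):

import os
import torch

def save_if_best(model, optimizer, scheduler, iteration, metric, best_metric, logdir, tag):
    # Always refresh the rolling checkpoint.
    state = {
        "epoch": iteration + 1,
        "model_state": model.state_dict(),
        "optimizer_state": optimizer.state_dict(),
        "scheduler_state": scheduler.state_dict(),
    }
    torch.save(state, os.path.join(logdir, "{}_checkpoint.pkl".format(tag)))
    # Overwrite the best-model file only when the tracked metric improves.
    if metric >= best_metric:
        best_metric = metric
        state["best_metric"] = metric
        torch.save(state, os.path.join(logdir, "{}_best_model.pkl".format(tag)))
    return best_metric
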
Example #25
0
def train(cfg, writer, logger_old, args):

    # Setup seeds
    torch.manual_seed(cfg.get('seed', 1337))
    torch.cuda.manual_seed(cfg.get('seed', 1337))
    np.random.seed(cfg.get('seed', 1337))
    random.seed(cfg.get('seed', 1337))

    # Setup device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Setup Augmentations
    augmentations = cfg['training'].get('augmentations', None)
    data_aug = get_composed_augmentations(augmentations)

    # Setup Dataloader
    data_loader = get_loader(cfg['data']['dataset'])
    data_path = cfg['data']['path']

    if isinstance(cfg['training']['loss']['superpixels'], int):
        use_superpixels = True
        cfg['data']['train_split'] = 'train_super'
        cfg['data']['val_split'] = 'val_super'
        setup_superpixels(cfg['training']['loss']['superpixels'])
    elif cfg['training']['loss']['superpixels'] is not None:
        raise Exception(
            "cfg['training']['loss']['superpixels'] is of the wrong type")
    else:
        use_superpixels = False
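    # Illustrative reading of the branch above (values are hypothetical):
    #   cfg['training']['loss']['superpixels'] = 100   -> superpixel targets, *_super splits
    #   cfg['training']['loss']['superpixels'] = None  -> ordinary per-pixel training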

    t_loader = data_loader(data_path,
                           is_transform=True,
                           split=cfg['data']['train_split'],
                           superpixels=cfg['training']['loss']['superpixels'],
                           img_size=(cfg['data']['img_rows'],
                                     cfg['data']['img_cols']),
                           augmentations=data_aug)

    v_loader = data_loader(
        data_path,
        is_transform=True,
        split=cfg['data']['val_split'],
        superpixels=cfg['training']['loss']['superpixels'],
        img_size=(cfg['data']['img_rows'], cfg['data']['img_cols']),
    )

    n_classes = t_loader.n_classes
    trainloader = data.DataLoader(t_loader,
                                  batch_size=cfg['training']['batch_size'],
                                  num_workers=cfg['training']['n_workers'],
                                  shuffle=True)

    valloader = data.DataLoader(v_loader,
                                batch_size=cfg['training']['batch_size'],
                                num_workers=cfg['training']['n_workers'])

    # Setup Metrics
    running_metrics_val = runningScore(n_classes)
    running_metrics_train = runningScore(n_classes)

    # Setup Model
    model = get_model(cfg['model'], n_classes).to(device)

    model = torch.nn.DataParallel(model,
                                  device_ids=range(torch.cuda.device_count()))

    # Setup optimizer, lr_scheduler and loss function
    optimizer_cls = get_optimizer(cfg)
    optimizer_params = {
        k: v
        for k, v in cfg['training']['optimizer'].items() if k != 'name'
    }

    optimizer = optimizer_cls(model.parameters(), **optimizer_params)
    logger_old.info("Using optimizer {}".format(optimizer))

    scheduler = get_scheduler(optimizer, cfg['training']['lr_schedule'])

    loss_fn = get_loss_function(cfg)
    logger_old.info("Using loss {}".format(loss_fn))

    start_iter = 0
    if cfg['training']['resume'] is not None:
        if os.path.isfile(cfg['training']['resume']):
            logger_old.info(
                "Loading model and optimizer from checkpoint '{}'".format(
                    cfg['training']['resume']))
            checkpoint = torch.load(cfg['training']['resume'])
            model.load_state_dict(checkpoint["model_state"])
            optimizer.load_state_dict(checkpoint["optimizer_state"])
            scheduler.load_state_dict(checkpoint["scheduler_state"])
            start_iter = checkpoint["epoch"]
            logger_old.info("Loaded checkpoint '{}' (iter {})".format(
                cfg['training']['resume'], checkpoint["epoch"]))
        else:
            logger_old.info("No checkpoint found at '{}'".format(
                cfg['training']['resume']))

    val_loss_meter = averageMeter()
    train_loss_meter = averageMeter()
    time_meter = averageMeter()

    train_len = t_loader.train_len
    val_static = 0
    best_iou = -100.0
    i = start_iter
    j = 0
    flag = True

    # Prepare logging
    xp_name = cfg['model']['arch'] + '_' + \
        cfg['training']['loss']['name'] + '_' + args.name
    xp = logger.Experiment(xp_name,
                           use_visdom=True,
                           visdom_opts={
                               'server': 'http://localhost',
                               'port': 8098
                           },
                           time_indexing=False,
                           xlabel='Epoch')
    # log the hyperparameters of the experiment
    xp.log_config(flatten(cfg))
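    # `flatten` is assumed to collapse the nested cfg dict into a single level
    # so it can be logged as a flat set of hyperparameters; a rough sketch of
    # such a helper (name and behaviour assumed, not taken from the repository):
    #   def flatten(d, parent_key='', sep='.'):
    #       items = {}
    #       for k, v in d.items():
    #           key = parent_key + sep + k if parent_key else k
    #           if isinstance(v, dict):
    #               items.update(flatten(v, key, sep))
    #           else:
    #               items[key] = v
    #       return items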
    # create parent metric for training metrics (easier interface)
    xp.ParentWrapper(tag='train',
                     name='parent',
                     children=(xp.AvgMetric(name="loss"),
                               xp.AvgMetric(name='acc'),
                               xp.AvgMetric(name='acccls'),
                               xp.AvgMetric(name='fwavacc'),
                               xp.AvgMetric(name='meaniu')))
    xp.ParentWrapper(tag='val',
                     name='parent',
                     children=(xp.AvgMetric(name="loss"),
                               xp.AvgMetric(name='acc'),
                               xp.AvgMetric(name='acccls'),
                               xp.AvgMetric(name='fwavacc'),
                               xp.AvgMetric(name='meaniu')))
    best_loss = xp.BestMetric(tag='val-best', name='loss', mode='min')
    best_acc = xp.BestMetric(tag='val-best', name='acc')
    best_acccls = xp.BestMetric(tag='val-best', name='acccls')
    best_fwavacc = xp.BestMetric(tag='val-best', name='fwavacc')
    best_meaniu = xp.BestMetric(tag='val-best', name='meaniu')

    xp.plotter.set_win_opts(name="loss", opts={'title': 'Loss'})
    xp.plotter.set_win_opts(name="acc", opts={'title': 'Micro-Average'})
    xp.plotter.set_win_opts(name="acccls", opts={'title': 'Macro-Average'})
    xp.plotter.set_win_opts(name="fwavacc", opts={'title': 'FreqW Accuracy'})
    xp.plotter.set_win_opts(name="meaniu", opts={'title': 'Mean IoU'})

    it_per_step = cfg['training']['acc_batch_size']
    eff_batch_size = cfg['training']['batch_size'] * it_per_step
    while i <= train_len * (cfg['training']['epochs']) and flag:
        for (images, labels, labels_s, masks) in trainloader:
            i += 1
            j += 1
            start_ts = time.time()
            scheduler.step()
            model.train()
            images = images.to(device)
            labels = labels.to(device)
            labels_s = labels_s.to(device)
            masks = masks.to(device)

            outputs = model(images)
            if use_superpixels:
                outputs_s, labels_s, sizes = convert_to_superpixels(
                    outputs, labels_s, masks)
                loss = loss_fn(input=outputs_s, target=labels_s, size=sizes)
                outputs = convert_to_pixels(outputs_s, outputs, masks)
            else:
                loss = loss_fn(input=outputs, target=labels)

            # accumulate train metrics during train
            pred = outputs.data.max(1)[1].cpu().numpy()
            gt = labels.data.cpu().numpy()
            running_metrics_train.update(gt, pred)
            train_loss_meter.update(loss.item())

            if args.evaluate:
                decoded = t_loader.decode_segmap(np.squeeze(pred, axis=0))
                misc.imsave("./{}.png".format(i), decoded)
                image_save = np.transpose(
                    np.squeeze(images.data.cpu().numpy(), axis=0), (1, 2, 0))
                misc.imsave("./{}.jpg".format(i), image_save)
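                # Note: scipy.misc.imsave was removed in SciPy 1.2; on newer
                # environments imageio.imwrite is the usual drop-in replacement.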

            # accumulate gradients based on the accumulation batch size
            if i % it_per_step == 1 or it_per_step == 1:
                optimizer.zero_grad()

            grad_rescaling = torch.tensor(1. / it_per_step).type_as(loss)
            loss.backward(grad_rescaling)
            if (i + 1) % it_per_step == 1 or it_per_step == 1:
                optimizer.step()
                optimizer.zero_grad()
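            # With it_per_step = N this zeroes gradients when i % N == 1,
            # accumulates 1/N-scaled gradients for N iterations, and applies
            # optimizer.step() when i % N == 0, i.e. once per effective batch.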

            time_meter.update(time.time() - start_ts)
            # training logs
            if (j + 1) % (cfg['training']['print_interval'] *
                          it_per_step) == 0:
                fmt_str = "Epoch [{}/{}] Iter [{}/{:d}] Loss: {:.4f}  Time/Image: {:.4f}"
                total_iter = int(train_len / eff_batch_size)
                total_epoch = int(cfg['training']['epochs'])
                current_epoch = ceil((i + 1) / train_len)
                current_iter = int((j + 1) / it_per_step)
                print_str = fmt_str.format(
                    current_epoch, total_epoch, current_iter, total_iter,
                    loss.item(),
                    time_meter.avg / cfg['training']['batch_size'])

                print(print_str)
                logger_old.info(print_str)
                writer.add_scalar('loss/train_loss', loss.item(), i + 1)
                time_meter.reset()
            # end of epoch evaluation
            if (i + 1) % train_len == 0 or \
               (i + 1) == train_len * (cfg['training']['epochs']):
                optimizer.step()
                optimizer.zero_grad()
                model.eval()
                with torch.no_grad():
                    for i_val, (images_val, labels_val, labels_val_s,
                                masks_val) in tqdm(enumerate(valloader)):
                        images_val = images_val.to(device)
                        labels_val = labels_val.to(device)
                        labels_val_s = labels_val_s.to(device)
                        masks_val = masks_val.to(device)

                        outputs = model(images_val)
                        if use_superpixels:
                            outputs_s, labels_val_s, sizes_val = convert_to_superpixels(
                                outputs, labels_val_s, masks_val)
                            val_loss = loss_fn(input=outputs_s,
                                               target=labels_val_s,
                                               size=sizes_val)
                            outputs = convert_to_pixels(
                                outputs_s, outputs, masks_val)
                        else:
                            val_loss = loss_fn(input=outputs,
                                               target=labels_val)
                        pred = outputs.data.max(1)[1].cpu().numpy()
                        gt = labels_val.data.cpu().numpy()

                        running_metrics_val.update(gt, pred)
                        val_loss_meter.update(val_loss.item())

                writer.add_scalar('loss/val_loss', val_loss_meter.avg, i + 1)
                writer.add_scalar('loss/train_loss', train_loss_meter.avg,
                                  i + 1)
                logger_old.info("Epoch %d Val Loss: %.4f" % (int(
                    (i + 1) / train_len), val_loss_meter.avg))
                logger_old.info("Epoch %d Train Loss: %.4f" % (int(
                    (i + 1) / train_len), train_loss_meter.avg))

                score, class_iou = running_metrics_train.get_scores()
                print("Training metrics:")
                for k, v in score.items():
                    print(k, v)
                    logger_old.info('{}: {}'.format(k, v))
                    writer.add_scalar('train_metrics/{}'.format(k), v, i + 1)

                for k, v in class_iou.items():
                    logger_old.info('{}: {}'.format(k, v))
                    writer.add_scalar('train_metrics/cls_{}'.format(k), v,
                                      i + 1)

                xp.Parent_Train.update(loss=train_loss_meter.avg,
                                       acc=score['Overall Acc: \t'],
                                       acccls=score['Mean Acc : \t'],
                                       fwavacc=score['FreqW Acc : \t'],
                                       meaniu=score['Mean IoU : \t'])

                score, class_iou = running_metrics_val.get_scores()
                print("Validation metrics:")
                for k, v in score.items():
                    print(k, v)
                    logger_old.info('{}: {}'.format(k, v))
                    writer.add_scalar('val_metrics/{}'.format(k), v, i + 1)

                for k, v in class_iou.items():
                    logger_old.info('{}: {}'.format(k, v))
                    writer.add_scalar('val_metrics/cls_{}'.format(k), v, i + 1)

                xp.Parent_Val.update(loss=val_loss_meter.avg,
                                     acc=score['Overall Acc: \t'],
                                     acccls=score['Mean Acc : \t'],
                                     fwavacc=score['FreqW Acc : \t'],
                                     meaniu=score['Mean IoU : \t'])

                xp.Parent_Val.log_and_reset()
                xp.Parent_Train.log_and_reset()
                best_loss.update(xp.loss_val).log()
                best_acc.update(xp.acc_val).log()
                best_acccls.update(xp.acccls_val).log()
                best_fwavacc.update(xp.fwavacc_val).log()
                best_meaniu.update(xp.meaniu_val).log()

                visdir = os.path.join('runs', cfg['training']['loss']['name'],
                                      args.name, 'plots.json')
                xp.to_json(visdir)

                val_loss_meter.reset()
                train_loss_meter.reset()
                running_metrics_val.reset()
                running_metrics_train.reset()
                j = 0

                if score["Mean IoU : \t"] >= best_iou:
                    val_static = 0
                    best_iou = score["Mean IoU : \t"]
                    state = {
                        "epoch": i + 1,
                        "model_state": model.state_dict(),
                        "optimizer_state": optimizer.state_dict(),
                        "scheduler_state": scheduler.state_dict(),
                        "best_iou": best_iou,
                    }
                    save_path = os.path.join(
                        writer.file_writer.get_logdir(),
                        "{}_{}_best_model.pkl".format(cfg['model']['arch'],
                                                      cfg['data']['dataset']))
                    torch.save(state, save_path)
                else:
                    val_static += 1

            if (i + 1) == train_len * (
                    cfg['training']['epochs']) or val_static == 10:
                flag = False
                break
    return best_iou
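
The accumulation bookkeeping above is interleaved with logging and validation,
which makes it harder to read. A minimal standalone sketch of the same idea,
assuming nothing from the repository (model, loader, device and accum_steps are
placeholders):

import torch.nn as nn

def train_with_accumulation(model, loader, optimizer, accum_steps=4, device="cpu"):
    criterion = nn.CrossEntropyLoss()
    model.train()
    optimizer.zero_grad()
    for step, (images, labels) in enumerate(loader, start=1):
        images, labels = images.to(device), labels.to(device)
        # Scale each mini-batch loss so the accumulated gradient matches the
        # gradient of one batch of size batch_size * accum_steps.
        loss = criterion(model(images), labels) / accum_steps
        loss.backward()
        if step % accum_steps == 0:
            optimizer.step()
            optimizer.zero_grad()
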
Example #26
0
def train(cfg, writer, logger):

    # Setup seeds
    torch.manual_seed(cfg.get("seed", 1337))
    torch.cuda.manual_seed(cfg.get("seed", 1337))
    np.random.seed(cfg.get("seed", 1337))
    random.seed(cfg.get("seed", 1337))

    # Setup device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(device)

    # Setup Augmentations
    augmentations = cfg["training"].get("augmentations", None)
    data_aug = get_composed_augmentations(augmentations)

    # Setup Dataloader
    data_loader = get_loader(cfg["data"]["dataset"])
    data_path = cfg["data"]["path"]

    t_loader = data_loader(
        data_path,
        is_transform=True,
        split=cfg["data"]["train_split"],
        img_size=(cfg["data"]["img_rows"], cfg["data"]["img_cols"]),
        augmentations=data_aug,
        n_classes=20,
    )

    v_loader = data_loader(
        data_path,
        is_transform=True,
        split=cfg["data"]["val_split"],
        img_size=(cfg["data"]["img_rows"], cfg["data"]["img_cols"]),
    )

    n_classes = t_loader.n_classes
    trainloader = data.DataLoader(
        t_loader,
        batch_size=cfg["training"]["batch_size"],
        num_workers=cfg["training"]["n_workers"],
        shuffle=True,
    )

    valloader = data.DataLoader(v_loader,
                                batch_size=cfg["training"]["batch_size"],
                                num_workers=cfg["training"]["n_workers"])
    # -----------------------------------------------------------------
    # Setup Metrics (subtract one class)
    running_metrics_val = runningScore(n_classes - 1)

    # Setup Model
    model = get_model(cfg["model"], n_classes).to(device)

    model = torch.nn.DataParallel(model,
                                  device_ids=range(torch.cuda.device_count()))

    # Setup optimizer, lr_scheduler and loss function
    optimizer_cls = get_optimizer(cfg)
    optimizer_params = {
        k: v
        for k, v in cfg["training"]["optimizer"].items() if k != "name"
    }

    optimizer = optimizer_cls(model.parameters(), **optimizer_params)
    logger.info("Using optimizer {}".format(optimizer))

    scheduler = get_scheduler(optimizer, cfg["training"]["lr_schedule"])

    loss_fn = get_loss_function(cfg)
    logger.info("Using loss {}".format(loss_fn))

    start_iter = 0
    if cfg["training"]["resume"] is not None:
        if os.path.isfile(cfg["training"]["resume"]):
            logger.info(
                "Loading model and optimizer from checkpoint '{}'".format(
                    cfg["training"]["resume"]))
            checkpoint = torch.load(cfg["training"]["resume"])
            model.load_state_dict(checkpoint["model_state"])
            optimizer.load_state_dict(checkpoint["optimizer_state"])
            scheduler.load_state_dict(checkpoint["scheduler_state"])
            start_iter = checkpoint["epoch"]
            logger.info("Loaded checkpoint '{}' (iter {})".format(
                cfg["training"]["resume"], checkpoint["epoch"]))
        else:
            logger.info("No checkpoint found at '{}'".format(
                cfg["training"]["resume"]))

    val_loss_meter = averageMeter()

    # get loss_seg meter and also loss_dep meter

    loss_seg_meter = averageMeter()
    loss_dep_meter = averageMeter()
    time_meter = averageMeter()

    best_iou = -100.0
    i = start_iter
    flag = True

    while i <= cfg["training"]["train_iters"] and flag:
        for (images, labels, masks, depths) in trainloader:
            i += 1
            start_ts = time.time()
            scheduler.step()
            model.train()
            images = images.to(device)
            labels = labels.to(device)
            depths = depths.to(device)

            #print(images.shape)
            optimizer.zero_grad()
            outputs = model(images)
            #print('depths size: ', depths.size())
            #print('output shape: ', outputs.shape)

            loss_seg = loss_fn(input=outputs[:, :-1, :, :], target=labels)

            # -----------------------------------------------------------------
            # add depth loss

            # -----------------------------------------------------------------
            # MSE loss
            # loss_dep = F.mse_loss(input=outputs[:, -1,:,:], target=depths, reduction='mean')

            # -----------------------------------------------------------------
            # Berhu loss
            loss_dep = berhu_loss_function(prediction=outputs[:, -1, :, :],
                                           target=depths)
            #loss_dep = loss_dep.type(torch.cuda.ByteTensor)
            masks = masks.type(torch.cuda.ByteTensor)
            loss_dep = torch.sum(loss_dep[masks]) / torch.sum(masks)
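            # In PyTorch >= 1.2 boolean masks (masks.bool()) are preferred here;
            # indexing with ByteTensor masks emits a deprecation warning.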
            print('loss depth', loss_dep)
            loss = loss_dep + loss_seg
            # -----------------------------------------------------------------

            loss.backward()
            optimizer.step()

            time_meter.update(time.time() - start_ts)

            if (i + 1) % cfg["training"]["print_interval"] == 0:
                fmt_str = "Iter [{:d}/{:d}]  loss_seg: {:.4f}  loss_dep: {:.4f}  overall loss: {:.4f}   Time/Image: {:.4f}"
                print_str = fmt_str.format(
                    i + 1, cfg["training"]["train_iters"], loss_seg.item(),
                    loss_dep.item(), loss.item(),
                    time_meter.avg / cfg["training"]["batch_size"])

                print(print_str)
                logger.info(print_str)
                writer.add_scalar("loss/train_loss", loss.item(), i + 1)
                time_meter.reset()

            if (i + 1) % cfg["training"]["val_interval"] == 0 or (
                    i + 1) == cfg["training"]["train_iters"]:

                model.eval()
                with torch.no_grad():
                    for i_val, (images_val, labels_val, masks_val,
                                depths_val) in tqdm(enumerate(valloader)):
                        images_val = images_val.to(device)
                        labels_val = labels_val.to(device)
                        print('images_val shape', images_val.size())
                        # add depth to device
                        depths_val = depths_val.to(device)

                        outputs = model(images_val)
                        #depths_val = depths_val.data.resize_(depths_val.size(0), outputs.size(2), outputs.size(3))

                        # -----------------------------------------------------------------
                        # loss function for segmentation
                        print('output shape', outputs.size())
                        val_loss_seg = loss_fn(input=outputs[:, :-1, :, :],
                                               target=labels_val)

                        # -----------------------------------------------------------------
                        # MSE loss
                        # val_loss_dep = F.mse_loss(input=outputs[:, -1, :, :], target=depths_val, reduction='mean')

                        # -----------------------------------------------------------------
                        # berhu loss function
                        val_loss_dep = berhu_loss_function(
                            prediction=outputs[:, -1, :, :], target=depths_val)
                        # keep the depth loss as a float tensor; only the mask
                        # is converted for indexing
                        masks_val = masks_val.type(torch.cuda.ByteTensor)
                        val_loss_dep = torch.sum(
                            val_loss_dep[masks_val]) / torch.sum(masks_val)
                        val_loss = val_loss_dep + val_loss_seg
                        # -----------------------------------------------------------------

                        prediction = outputs[:, :-1, :, :]
                        prediction = prediction.data.max(1)[1].cpu().numpy()
                        gt = labels_val.data.cpu().numpy()

                        # adapt metrics to seg and dep
                        running_metrics_val.update(gt, prediction)
                        loss_seg_meter.update(val_loss_seg.item())
                        loss_dep_meter.update(val_loss_dep.item())

                        # -----------------------------------------------------------------
                        # get rid of val_loss_meter
                        # val_loss_meter.update(val_loss.item())
                        # writer.add_scalar("loss/val_loss", val_loss_meter.avg, i + 1)
                        # logger.info("Iter %d Loss: %.4f" % (i + 1, val_loss_meter.avg))
                        # -----------------------------------------------------------------

                score, class_iou = running_metrics_val.get_scores()
                for k, v in score.items():
                    print(k, v)
                    logger.info("{}: {}".format(k, v))
                    writer.add_scalar("val_metrics/{}".format(k), v, i + 1)

                for k, v in class_iou.items():
                    logger.info("{}: {}".format(k, v))
                    writer.add_scalar("val_metrics/cls_{}".format(k), v, i + 1)

                print("Segmentation loss is {}".format(loss_seg_meter.avg))
                logger.info("Segmentation loss is {}".format(
                    loss_seg_meter.avg))
                #writer.add_scalar("Segmentation loss is {}".format(loss_seg_meter.avg), i + 1)

                print("Depth loss is {}".format(loss_dep_meter.avg))
                logger.info("Depth loss is {}".format(loss_dep_meter.avg))
                #writer.add_scalar("Depth loss is {}".format(loss_dep_meter.avg), i + 1)

                val_loss_meter.reset()
                loss_seg_meter.reset()
                loss_dep_meter.reset()
                running_metrics_val.reset()

                if score["Mean IoU : \t"] >= best_iou:
                    best_iou = score["Mean IoU : \t"]
                    state = {
                        "epoch": i + 1,
                        "model_state": model.state_dict(),
                        "optimizer_state": optimizer.state_dict(),
                        "scheduler_state": scheduler.state_dict(),
                        "best_iou": best_iou,
                    }
                    save_path = os.path.join(
                        writer.file_writer.get_logdir(),
                        "{}_{}_best_model.pkl".format(cfg["model"]["arch"],
                                                      cfg["data"]["dataset"]),
                    )
                    torch.save(state, save_path)

                    # insert print function to see if the losses are correct

            if (i + 1) == cfg["training"]["train_iters"]:
                flag = False
                break
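
Example #26 adds a BerHu (reverse Huber) depth loss, evaluated only on pixels
with valid depth, to the usual cross-entropy segmentation loss. The
repository's berhu_loss_function is not shown in this snippet; a common
per-pixel formulation, roughly following Laina et al. (2016), looks like this
(the threshold rule and the clamp are assumptions of this sketch):

import torch

def berhu_loss(prediction, target):
    # Reverse Huber: L1 for small residuals, scaled L2 beyond a threshold c
    # taken from the largest residual in the batch.
    diff = torch.abs(prediction - target)
    c = torch.clamp(0.2 * diff.max(), min=1e-6)
    l2 = (diff ** 2 + c ** 2) / (2 * c)
    return torch.where(diff <= c, diff, l2)  # per-pixel; mask and reduce outside
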
Example #27
0
def train(cfg, writer, logger, run_id):

    # Setup seeds
    torch.manual_seed(cfg.get('seed', 1337))
    torch.cuda.manual_seed(cfg.get('seed', 1337))
    np.random.seed(cfg.get('seed', 1337))
    random.seed(cfg.get('seed', 1337))

    # torch.backends.cudnn.deterministic = True
    # torch.backends.cudnn.benchmark = False

    torch.backends.cudnn.benchmark = True

    os.environ["CUDA_VISIBLE_DEVICES"] = "0"
    # Setup device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Setup Augmentations
    augmentations = cfg['training'].get('augmentations', None)
    data_aug = get_composed_augmentations(augmentations)

    # Setup Dataloader
    data_loader = get_loader(cfg['data']['dataset'])
    data_path = cfg['data']['path']

    logger.info("Using dataset: {}".format(data_path))

    t_loader = data_loader(data_path,
                           is_transform=True,
                           split=cfg['data']['train_split'],
                           img_size=(cfg['data']['img_rows'],
                                     cfg['data']['img_cols']),
                           augmentations=data_aug)

    v_loader = data_loader(
        data_path,
        is_transform=True,
        split=cfg['data']['val_split'],
        img_size=(cfg['data']['img_rows'], cfg['data']['img_cols']),
    )

    n_classes = t_loader.n_classes
    trainloader = data.DataLoader(t_loader,
                                  batch_size=cfg['training']['batch_size'],
                                  num_workers=cfg['training']['n_workers'],
                                  shuffle=True)

    valloader = data.DataLoader(v_loader,
                                batch_size=cfg['training']['batch_size'],
                                num_workers=cfg['training']['n_workers'])

    # Setup Metrics
    running_metrics_val = runningScore(n_classes)

    # Setup Model
    # model = get_model(cfg['model'], n_classes).to(device)
    model = get_model(cfg['model'], n_classes)
    logger.info("Using Model: {}".format(cfg['model']['arch']))

    # model=apex.parallel.convert_syncbn_model(model)
    model = model.to(device)

    # a=range(torch.cuda.device_count())
    # model = torch.nn.DataParallel(model, device_ids=range(torch.cuda.device_count()))
    # model = torch.nn.DataParallel(model, device_ids=[0,1])
    # model = encoding.parallel.DataParallelModel(model, device_ids=[0, 1])

    # Setup optimizer, lr_scheduler and loss function
    optimizer_cls = get_optimizer(cfg)
    optimizer_params = {
        k: v
        for k, v in cfg['training']['optimizer'].items() if k != 'name'
    }

    optimizer = optimizer_cls(model.parameters(), **optimizer_params)

    # optimizer = FP16_Optimizer(optimizer, static_loss_scale=128.0)

    logger.info("Using optimizer {}".format(optimizer))

    scheduler = get_scheduler(optimizer, cfg['training']['lr_schedule'])

    # optimizer = FP16_Optimizer(optimizer, static_loss_scale=128.0)

    loss_fn = get_loss_function(cfg)
    # loss_fn== encoding.parallel.DataParallelCriterion(loss_fn, device_ids=[0, 1])
    logger.info("Using loss {}".format(loss_fn))

    start_iter = 0
    if cfg['training']['resume'] is not None:
        if os.path.isfile(cfg['training']['resume']):
            logger.info(
                "Loading model and optimizer from checkpoint '{}'".format(
                    cfg['training']['resume']))
            checkpoint = torch.load(cfg['training']['resume'])
            model.load_state_dict(checkpoint["model_state"])
            optimizer.load_state_dict(checkpoint["optimizer_state"])
            scheduler.load_state_dict(checkpoint["scheduler_state"])
            start_iter = checkpoint["epoch"]
            logger.info("Loaded checkpoint '{}' (iter {})".format(
                cfg['training']['resume'], checkpoint["epoch"]))
        else:
            logger.info("No checkpoint found at '{}'".format(
                cfg['training']['resume']))

    val_loss_meter = averageMeter()
    time_meter = averageMeter()
    time_meter_val = averageMeter()

    best_iou = -100.0
    i = start_iter
    flag = True

    train_data_len = t_loader.__len__()
    batch_size = cfg['training']['batch_size']
    epoch = cfg['training']['train_epoch']
    train_iter = int(np.ceil(train_data_len / batch_size) * epoch)
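    # e.g. 5000 training images, batch_size 8, train_epoch 50
    #   -> int(np.ceil(5000 / 8) * 50) = 31250 iterations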

    val_rlt_f1 = []
    val_rlt_IoU = []
    best_f1_till_now = 0
    best_IoU_till_now = 0

    while i <= train_iter and flag:
        for (images, labels) in trainloader:
            i += 1
            start_ts = time.time()
            scheduler.step()
            model.train()
            images = images.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            outputs = model(images)

            loss = loss_fn(input=outputs, target=labels)

            loss.backward()
            # optimizer.backward(loss)

            optimizer.step()

            time_meter.update(time.time() - start_ts)

            ### added by Sprit
            time_meter_val.update(time.time() - start_ts)

            if (i + 1) % cfg['training']['print_interval'] == 0:
                fmt_str = "Iter [{:d}/{:d}]  Loss: {:.4f}  Time/Image: {:.4f}"
                print_str = fmt_str.format(
                    i + 1, train_iter, loss.item(),
                    time_meter.avg / cfg['training']['batch_size'])

                print(print_str)
                logger.info(print_str)
                writer.add_scalar('loss/train_loss', loss.item(), i + 1)
                time_meter.reset()

            if (i + 1) % cfg['training']['val_interval'] == 0 or \
               (i + 1) == train_iter:
                model.eval()
                with torch.no_grad():
                    for i_val, (images_val,
                                labels_val) in tqdm(enumerate(valloader)):
                        images_val = images_val.to(device)
                        labels_val = labels_val.to(device)

                        outputs = model(images_val)
                        # val_loss = loss_fn(input=outputs, target=labels_val)

                        pred = outputs.data.max(1)[1].cpu().numpy()
                        gt = labels_val.data.cpu().numpy()

                        running_metrics_val.update(gt, pred)
                        # val_loss_meter.update(val_loss.item())

                # writer.add_scalar('loss/val_loss', val_loss_meter.avg, i+1)
                # logger.info("Iter %d Loss: %.4f" % (i + 1, val_loss_meter.avg))

                score, class_iou = running_metrics_val.get_scores()

                for k, v in score.items():
                    print(k, v)
                    logger.info('{}: {}'.format(k, v))
                    # writer.add_scalar('val_metrics/{}'.format(k), v, i+1)

                for k, v in class_iou.items():
                    logger.info('{}: {}'.format(k, v))
                    # writer.add_scalar('val_metrics/cls_{}'.format(k), v, i+1)

                # val_loss_meter.reset()
                running_metrics_val.reset()

                ### added by Sprit
                avg_f1 = score["Mean F1 : \t"]
                avg_IoU = score["Mean IoU : \t"]
                val_rlt_f1.append(avg_f1)
                val_rlt_IoU.append(score["Mean IoU : \t"])

                if avg_f1 >= best_f1_till_now:
                    best_f1_till_now = avg_f1
                    correspond_iou = score["Mean IoU : \t"]
                    best_epoch_till_now = i + 1
                print("\nBest F1 till now = ", best_f1_till_now)
                print("Corresponding IoU = ", correspond_iou)
                print("Best F1 Iter till now = ", best_epoch_till_now)

                if avg_IoU >= best_IoU_till_now:
                    best_IoU_till_now = avg_IoU
                    correspond_f1 = score["Mean F1 : \t"]
                    correspond_acc = score["Overall Acc: \t"]
                    best_epoch_till_now = i + 1
                print("Best IoU till now = ", best_IoU_till_now)
                print("Corresponding F1 = ", correspond_f1)
                print("Corresponding OA = ", correspond_acc)
                print("Best IoU Iter till now = ", best_epoch_till_now)

                ### added by Sprit
                iter_time = time_meter_val.avg
                time_meter_val.reset()
                remain_time = iter_time * (train_iter - i)
                m, s = divmod(remain_time, 60)
                h, m = divmod(m, 60)
                if remain_time > 0:
                    train_time = "Remaining training time = %d hours %d minutes %d seconds \n" % (
                        h, m, s)
                else:
                    train_time = "Remaining training time: Training completed.\n"
                print(train_time)

                if score["Mean IoU : \t"] >= best_iou:
                    best_iou = score["Mean IoU : \t"]
                    state = {
                        "epoch": i + 1,
                        "model_state": model.state_dict(),
                        "optimizer_state": optimizer.state_dict(),
                        "scheduler_state": scheduler.state_dict(),
                        "best_iou": best_iou,
                    }
                    save_path = os.path.join(
                        writer.file_writer.get_logdir(),
                        "{}_{}_best_model.pkl".format(cfg['model']['arch'],
                                                      cfg['data']['dataset']))
                    torch.save(state, save_path)

            if (i + 1) == train_iter:
                flag = False
                break
    my_pt.csv_out(run_id, data_path, cfg['model']['arch'], epoch, val_rlt_f1,
                  cfg['training']['val_interval'])
    my_pt.csv_out(run_id, data_path, cfg['model']['arch'], epoch, val_rlt_IoU,
                  cfg['training']['val_interval'])
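
Example #27 relies on averageMeter both for the loss display and for the
remaining-time estimate (time_meter_val.avg). The class is defined elsewhere in
the repository; a minimal sketch of the behaviour those calls assume (the
attributes .val and .avg match the usage above, the rest is illustrative):

class AverageMeter:
    # Tracks the most recent value (.val) and the running mean (.avg).
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0.0
        self.sum = 0.0
        self.count = 0
        self.avg = 0.0

    def update(self, value, n=1):
        self.val = value
        self.sum += value * n
        self.count += n
        self.avg = self.sum / self.count
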
Example #28
0
def train(cfg, writer, logger):

    # Setup seeds
    # torch.manual_seed(cfg.get("seed", 1337))
    # torch.cuda.manual_seed(cfg.get("seed", 1337))
    # np.random.seed(cfg.get("seed", 1337))
    # random.seed(cfg.get("seed", 1337))

    # Setup device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Setup Augmentations
    augmentations = cfg["training"].get("augmentations", None)
    data_aug = get_composed_augmentations(augmentations)

    # Setup Dataloader
    data_loader = get_loader(cfg["data"]["dataset"])
    data_path = cfg["data"]["path"]

    t_loader = data_loader(
        data_path,
        is_transform=True,
        split=cfg["data"]["train_split"],
        img_size=(cfg["data"]["img_rows"], cfg["data"]["img_cols"]),
        augmentations=data_aug,
    )

    v_loader = data_loader(
        data_path,
        is_transform=True,
        split=cfg["data"]["val_split"],
        img_size=(cfg["data"]["img_rows"], cfg["data"]["img_cols"]),
    )

    n_classes = t_loader.n_classes
    trainloader = data.DataLoader(
        t_loader,
        batch_size=cfg["training"]["batch_size"],
        num_workers=cfg["training"]["n_workers"],
        shuffle=True,
    )

    valloader = data.DataLoader(v_loader,
                                batch_size=cfg["training"]["batch_size"],
                                num_workers=cfg["training"]["n_workers"])

    # Setup Metrics
    running_metrics_val = runningScore(n_classes)

    # Setup Model
    model = get_model(cfg["model"], n_classes).to(device)

    # model = torch.nn.DataParallel(model, device_ids=range(torch.cuda.device_count()))

    # Setup optimizer, lr_scheduler and loss function
    optimizer_cls = get_optimizer(cfg)
    optimizer_params = {
        k: v
        for k, v in cfg["training"]["optimizer"].items() if k != "name"
    }

    optimizer = optimizer_cls(model.parameters(), **optimizer_params)
    logger.info("Using optimizer {}".format(optimizer))

    scheduler = get_scheduler(optimizer, cfg["training"]["lr_schedule"])

    loss_fn = get_loss_function(cfg)
    logger.info("Using loss {}".format(loss_fn))

    start_iter = 0
    if cfg["training"]["resume"] is not None:
        if os.path.isfile(cfg["training"]["resume"]):
            logger.info(
                "Loading model and optimizer from checkpoint '{}'".format(
                    cfg["training"]["resume"]))
            checkpoint = torch.load(cfg["training"]["resume"])
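            # NOTE: `args` below is not a parameter of this train(); it is
            # presumably an argparse namespace defined at module level in the
            # original script (with load_weight_only / not_load_optimizer flags).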

            if not args.load_weight_only:
                model = DataParallel_withLoss(model, loss_fn)
                model.load_state_dict(checkpoint["model_state"])
                if not args.not_load_optimizer:
                    optimizer.load_state_dict(checkpoint["optimizer_state"])

                # !!!
                # checkpoint["scheduler_state"]['last_epoch'] = -1
                # scheduler.load_state_dict(checkpoint["scheduler_state"])
                # start_iter = checkpoint["epoch"]
                start_iter = 0
                # import ipdb
                # ipdb.set_trace()
                logger.info("Loaded checkpoint '{}' (iter {})".format(
                    cfg["training"]["resume"], checkpoint["epoch"]))
            else:
                pretrained_dict = convert_state_dict(checkpoint["model_state"])
                model_dict = model.state_dict()
                # 1. filter out unnecessary keys
                pretrained_dict = {
                    k: v
                    for k, v in pretrained_dict.items() if k in model_dict
                }
                # 2. overwrite entries in the existing state dict
                model_dict.update(pretrained_dict)
                # 3. load the new state dict
                model.load_state_dict(model_dict)
                model = DataParallel_withLoss(model, loss_fn)
                # import ipdb
                # ipdb.set_trace()
                # start_iter = -1
                logger.info(
                    "Loaded checkpoint '{}' (iter unknown, from pretrained icnet model)"
                    .format(cfg["training"]["resume"]))

        else:
            logger.info("No checkpoint found at '{}'".format(
                cfg["training"]["resume"]))

    val_loss_meter = averageMeter()
    time_meter = averageMeter()

    best_iou = -100.0
    i = start_iter
    flag = True

    while i <= cfg["training"]["train_iters"] and flag:
        for (images, labels, inst_labels) in trainloader:

            start_ts = time.time()
            scheduler.step()
            model.train()
            images = images.to(device)
            labels = labels.to(device)
            inst_labels = inst_labels.to(device)
            optimizer.zero_grad()

            loss, _, aux_info = model(labels,
                                      inst_labels,
                                      images,
                                      return_aux_info=True)
            loss = loss.sum()
            loss_sem = aux_info[0].sum()
            loss_inst = aux_info[1].sum()

            # loss = loss_fn(input=outputs, target=labels)

            loss.backward()
            optimizer.step()

            time_meter.update(time.time() - start_ts)

            if (i + 1) % cfg["training"]["print_interval"] == 0:
                fmt_str = "Iter [{:d}/{:d}]  Loss: {:.4f} (Sem:{:.4f}/Inst:{:.4f})  LR:{:.5f}  Time/Image: {:.4f}"
                print_str = fmt_str.format(
                    i + 1,
                    cfg["training"]["train_iters"],
                    loss.item(),
                    loss_sem.item(),
                    loss_inst.item(),
                    scheduler.get_lr()[0],
                    time_meter.avg / cfg["training"]["batch_size"],
                )

                # print(print_str)
                logger.info(print_str)
                writer.add_scalar("loss/train_loss", loss.item(), i + 1)
                time_meter.reset()

            if (i + 1) % cfg["training"]["val_interval"] == 0 or (
                    i + 1) == cfg["training"]["train_iters"]:

                model.eval()

                with torch.no_grad():
                    for i_val, (images_val, labels_val,
                                inst_labels_val) in tqdm(enumerate(valloader)):
                        images_val = images_val.to(device)
                        labels_val = labels_val.to(device)
                        inst_labels_val = inst_labels_val.to(device)
                        # outputs = model(images_val)
                        # val_loss = loss_fn(input=outputs, target=labels_val)
                        val_loss, (outputs, outputs_inst) = model(
                            labels_val, inst_labels_val, images_val)
                        val_loss = val_loss.sum()

                        pred = outputs.data.max(1)[1].cpu().numpy()
                        gt = labels_val.data.cpu().numpy()

                        running_metrics_val.update(gt, pred)
                        val_loss_meter.update(val_loss.item())

                writer.add_scalar("loss/val_loss", val_loss_meter.avg, i + 1)
                logger.info("Iter %d Loss: %.4f" % (i + 1, val_loss_meter.avg))

                score, class_iou = running_metrics_val.get_scores()
                for k, v in score.items():
                    print(k, v)
                    logger.info("{}: {}".format(k, v))
                    writer.add_scalar("val_metrics/{}".format(k), v, i + 1)

                for k, v in class_iou.items():
                    logger.info("{}: {}".format(k, v))
                    writer.add_scalar("val_metrics/cls_{}".format(k), v, i + 1)

                val_loss_meter.reset()
                running_metrics_val.reset()

                if score["Mean IoU : \t"] >= best_iou:
                    best_iou = score["Mean IoU : \t"]
                    state = {
                        "epoch": i + 1,
                        "model_state": model.state_dict(),
                        "optimizer_state": optimizer.state_dict(),
                        "scheduler_state": scheduler.state_dict(),
                        "best_iou": best_iou,
                    }
                    save_path = os.path.join(
                        writer.file_writer.get_logdir(),
                        "{}_{}_best_model.pkl".format(cfg["model"]["arch"],
                                                      cfg["data"]["dataset"]),
                    )
                    torch.save(state, save_path)

            if (i + 1) % cfg["training"]["save_interval"] == 0 or (
                    i + 1) == cfg["training"]["train_iters"]:
                state = {
                    "epoch": i + 1,
                    "model_state": model.state_dict(),
                    "optimizer_state": optimizer.state_dict(),
                    "scheduler_state": scheduler.state_dict(),
                    "best_iou": best_iou,
                }
                save_path = os.path.join(
                    writer.file_writer.get_logdir(),
                    "{}_{}_{:05d}_model.pkl".format(cfg["model"]["arch"],
                                                    cfg["data"]["dataset"],
                                                    i + 1),
                )
                torch.save(state, save_path)

            if (i + 1) == cfg["training"]["train_iters"]:
                flag = False
                break
            i += 1
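
Example #28 restores a pretrained ICNet checkpoint by filtering the saved state
dict down to the keys present in the current model before calling
load_state_dict. A compact, self-contained version of that pattern (the
function name is illustrative; the repository's convert_state_dict, which
strips the DataParallel 'module.' prefix, is approximated with a plain rename):

import torch

def load_matching_weights(model, checkpoint_path):
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    # Strip a possible DataParallel prefix so keys line up with a bare model.
    pretrained = {k.replace("module.", "", 1): v
                  for k, v in checkpoint["model_state"].items()}
    model_dict = model.state_dict()
    # 1. keep only keys the current model has, with matching shapes
    compatible = {k: v for k, v in pretrained.items()
                  if k in model_dict and v.shape == model_dict[k].shape}
    # 2. overwrite those entries and 3. load the merged dict
    model_dict.update(compatible)
    model.load_state_dict(model_dict)
    return len(compatible), len(model_dict)
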
Example #29
0
def train(cfg, writer, logger):

    # Setup seeds
    torch.manual_seed(cfg.get("seed", 1337))
    torch.cuda.manual_seed(cfg.get("seed", 1337))
    np.random.seed(cfg.get("seed", 1337))
    random.seed(cfg.get("seed", 1337))

    # Setup device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Setup Augmentations
    augmentations = cfg["training"].get("augmentations", None)
    data_aug = get_composed_augmentations(augmentations)

    # Setup Dataloader
    data_loader = get_loader(cfg["data"]["dataset"])
    data_path = cfg["data"]["path"]

    t_loader = data_loader(
        data_path,
        sbd_path=cfg["data"]["sbd_path"],
        is_transform=True,
        split=cfg["data"]["train_split"],
        img_size=(cfg["data"]["img_rows"], cfg["data"]["img_cols"]),
        augmentations=data_aug,
    )

    v_loader = data_loader(
        data_path,
        sbd_path=cfg["data"]["sbd_path"],
        is_transform=True,
        split=cfg["data"]["val_split"],
        img_size=(cfg["data"]["img_rows"], cfg["data"]["img_cols"]),
    )

    n_classes = t_loader.n_classes
    trainloader = data.DataLoader(
        t_loader,
        batch_size=cfg["training"]["batch_size"],
        num_workers=cfg["training"]["n_workers"],
        shuffle=True,
    )

    valloader = data.DataLoader(v_loader,
                                batch_size=cfg["training"]["batch_size"],
                                num_workers=cfg["training"]["n_workers"])

    # Setup Metrics
    running_metrics_val = runningScore(n_classes)

    # Setup Model
    model = get_model(cfg["model"], n_classes).to(device)

    model = torch.nn.DataParallel(model,
                                  device_ids=range(torch.cuda.device_count()))

    # Setup optimizer, lr_scheduler and loss function
    optimizer_cls = get_optimizer(cfg)
    optimizer_params = {
        k: v
        for k, v in cfg["training"]["optimizer"].items() if k != "name"
    }

    optimizer = optimizer_cls(model.parameters(), **optimizer_params)
    logger.info("Using optimizer {}".format(optimizer))

    scheduler = get_scheduler(optimizer, cfg["training"]["lr_schedule"])

    loss_fn = get_loss_function(cfg)
    logger.info("Using loss {}".format(loss_fn))

    start_iter = 0
    if cfg["training"]["resume"] is not None:
        if os.path.isfile(cfg["training"]["resume"]):
            logger.info(
                "Loading model and optimizer from checkpoint '{}'".format(
                    cfg["training"]["resume"]))
            checkpoint = torch.load(cfg["training"]["resume"])
            model.load_state_dict(checkpoint["model_state"])
            optimizer.load_state_dict(checkpoint["optimizer_state"])
            scheduler.load_state_dict(checkpoint["scheduler_state"])
            start_iter = checkpoint["epoch"]
            logger.info("Loaded checkpoint '{}' (iter {})".format(
                cfg["training"]["resume"], checkpoint["epoch"]))
        else:
            logger.info("No checkpoint found at '{}'".format(
                cfg["training"]["resume"]))

    val_loss_meter = averageMeter()
    time_meter = averageMeter()

    best_iou = -100.0
    i = start_iter
    flag = True

    while i <= cfg["training"]["train_iters"] and flag:
        for (images, labels) in trainloader:
            i += 1
            start_ts = time.time()
            scheduler.step()
            model.train()
            images = images.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            outputs = model(images)

            loss = loss_fn(input=outputs, target=labels)

            loss.backward()
            optimizer.step()

            time_meter.update(time.time() - start_ts)

            if (i + 1) % cfg["training"]["print_interval"] == 0:
                fmt_str = "Iter [{:d}/{:d}]  Loss: {:.4f}  Time/Image: {:.4f}"
                print_str = fmt_str.format(
                    i + 1,
                    cfg["training"]["train_iters"],
                    loss.item(),
                    time_meter.avg / cfg["training"]["batch_size"],
                )

                print(print_str)
                logger.info(print_str)
                writer.add_scalar("loss/train_loss", loss.item(), i + 1)
                time_meter.reset()

            if (i + 1) % cfg["training"]["val_interval"] == 0 or (
                    i + 1) == cfg["training"]["train_iters"]:
                model.eval()
                with torch.no_grad():
                    for i_val, (images_val,
                                labels_val) in tqdm(enumerate(valloader)):
                        images_val = images_val.to(device)
                        labels_val = labels_val.to(device)

                        outputs = model(images_val)
                        val_loss = loss_fn(input=outputs, target=labels_val)

                        pred = outputs.data.max(1)[1].cpu().numpy()
                        gt = labels_val.data.cpu().numpy()

                        running_metrics_val.update(gt, pred)
                        val_loss_meter.update(val_loss.item())

                writer.add_scalar("loss/val_loss", val_loss_meter.avg, i + 1)
                logger.info("Iter %d Loss: %.4f" % (i + 1, val_loss_meter.avg))

                score, class_iou = running_metrics_val.get_scores()
                for k, v in score.items():
                    print(k, v)
                    logger.info("{}: {}".format(k, v))
                    writer.add_scalar("val_metrics/{}".format(k), v, i + 1)

                for k, v in class_iou.items():
                    logger.info("{}: {}".format(k, v))
                    writer.add_scalar("val_metrics/cls_{}".format(k), v, i + 1)

                val_loss_meter.reset()
                running_metrics_val.reset()

                if score["Mean IoU : \t"] >= best_iou:
                    best_iou = score["Mean IoU : \t"]
                    state = {
                        "epoch": i + 1,
                        "model_state": model.state_dict(),
                        "optimizer_state": optimizer.state_dict(),
                        "scheduler_state": scheduler.state_dict(),
                        "best_iou": best_iou,
                    }
                    save_path = os.path.join(
                        writer.file_writer.get_logdir(),
                        "{}_{}_best_model.pkl".format(cfg["model"]["arch"],
                                                      cfg["data"]["dataset"]),
                    )
                    torch.save(state, save_path)

            if (i + 1) == cfg["training"]["train_iters"]:
                flag = False
                break
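
Every example in this file reports scores through runningScore, whose
tab-suffixed keys ("Overall Acc: \t", "Mean IoU : \t", ...) appear in the
logging calls. The class itself lives elsewhere in the repository; a sketch of
the usual confusion-matrix implementation behind such metrics (class name, key
names and the exact score set here are assumptions):

import numpy as np

class RunningScore:
    def __init__(self, n_classes):
        self.n_classes = n_classes
        self.confusion = np.zeros((n_classes, n_classes), dtype=np.int64)

    def update(self, gt, pred):
        # Accumulate a confusion matrix over flattened label/prediction maps.
        valid = (gt >= 0) & (gt < self.n_classes)
        idx = self.n_classes * gt[valid].astype(int) + pred[valid].astype(int)
        self.confusion += np.bincount(
            idx, minlength=self.n_classes ** 2
        ).reshape(self.n_classes, self.n_classes)

    def get_scores(self):
        cm = self.confusion.astype(np.float64)
        overall_acc = np.diag(cm).sum() / cm.sum()
        iou = np.diag(cm) / (cm.sum(axis=1) + cm.sum(axis=0) - np.diag(cm))
        return ({"Overall Acc": overall_acc, "Mean IoU": np.nanmean(iou)},
                dict(enumerate(iou)))

    def reset(self):
        self.confusion[:] = 0
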
def train(cfg, writer, logger):

    # Setup seeds for reproducibility
    torch.manual_seed(cfg.get('seed', 1337))
    torch.cuda.manual_seed(cfg.get('seed', 1337))
    np.random.seed(cfg.get('seed', 1337))
    random.seed(cfg.get('seed', 1337))

    # Setup device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Setup Augmentations
    augmentations = cfg['training'].get('augmentations', None)
    data_aug = get_composed_augmentations(augmentations)

    # Setup Dataloader
    data_loader = get_loader(cfg['data']['dataset'], cfg['task'])
    data_path = cfg['data']['path']

    t_loader = data_loader(
        data_path,
        is_transform=True,
        split=cfg['data']['train_split'],
        img_size=(cfg['data']['img_rows'], cfg['data']['img_cols']),
        img_norm=cfg['data']['img_norm'],
        # version = cfg['data']['version'],
        augmentations=data_aug)

    v_loader = data_loader(
        data_path,
        is_transform=True,
        split=cfg['data']['val_split'],
        img_norm=cfg['data']['img_norm'],
        # version=cfg['data']['version'],
        img_size=(cfg['data']['img_rows'], cfg['data']['img_cols']),
    )

    trainloader = data.DataLoader(t_loader,
                                  batch_size=cfg['training']['batch_size'],
                                  num_workers=cfg['training']['n_workers'],
                                  shuffle=True)

    valloader = data.DataLoader(v_loader,
                                batch_size=cfg['training']['batch_size'],
                                num_workers=cfg['training']['n_workers'])

    # Setup Metrics
    if cfg['task'] == "seg":
        n_classes = t_loader.n_classes
        running_metrics_val = runningScoreSeg(n_classes)
    elif cfg['task'] == "depth":
        n_classes = 0
        running_metrics_val = runningScoreDepth()
    else:
        raise NotImplementedError('Task {} not implemented'.format(
            cfg['task']))

    # Setup Model
    model = get_model(cfg['model'], cfg['task'], n_classes).to(device)

    # Setup optimizer, lr_scheduler and loss function
    optimizer_cls = get_optimizer(cfg)
    optimizer_params = {
        k: v
        for k, v in cfg['training']['optimizer'].items() if k != 'name'
    }

    optimizer = optimizer_cls(model.parameters(), **optimizer_params)
    logger.info("Using optimizer {}".format(optimizer))

    scheduler = get_scheduler(optimizer, cfg['training']['lr_schedule'])

    loss_fn = get_loss_function(cfg)
    logger.info("Using loss {}".format(loss_fn))

    start_iter = 0
    if cfg['training']['resume'] is not None:
        if os.path.isfile(cfg['training']['resume']):
            logger.info(
                "Loading model and optimizer from checkpoint '{}'".format(
                    cfg['training']['resume']))
            checkpoint = torch.load(cfg['training']['resume'])
            # checkpoint = torch.load(cfg['training']['resume'], map_location=lambda storage, loc: storage)  # load model trained on gpu on cpu
            model.load_state_dict(checkpoint["model_state"])
            optimizer.load_state_dict(checkpoint["optimizer_state"])
            scheduler.load_state_dict(checkpoint["scheduler_state"])
            # start_iter = checkpoint["epoch"]
            logger.info("Loaded checkpoint '{}' (iter {})".format(
                cfg['training']['resume'], checkpoint["epoch"]))
        else:
            logger.info("No checkpoint found at '{}'".format(
                cfg['training']['resume']))

    val_loss_meter = averageMeter()
    time_meter = averageMeter()

    best_iou = -100.0
    best_rel = 100.0
    # i = start_iter
    i = 0
    flag = True

    while i <= cfg['training']['train_iters'] and flag:
        print(len(trainloader))
        for (images, labels, img_path) in trainloader:
            start_ts = time.time()  # record the iteration start time
            scheduler.step()
            model.train()  # set model to training mode
            images = images.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()  # clear earlier gradients
            outputs = model(images)
            if cfg['model']['arch'] == "dispnet" and cfg['task'] == "depth":
                outputs = 1 / outputs
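                # dispnet predicts disparity; inverting converts it to depth up
                # to a scale factor (assumed to be absorbed by the targets).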

            loss = loss_fn(input=outputs, target=labels)  # compute loss
            loss.backward()  # backpropagate the loss
            optimizer.step()  # optimizer parameter update

            time_meter.update(time.time() - start_ts)

            if (i + 1) % cfg['training']['print_interval'] == 0:
                fmt_str = "Iter [{:d}/{:d}]  Loss: {:.4f}  Time/Image: {:.4f}"
                print_str = fmt_str.format(
                    i + 1, cfg['training']['train_iters'], loss.item(),
                    time_meter.val / cfg['training']['batch_size'])

                print(print_str)
                logger.info(print_str)
                writer.add_scalar('loss/train_loss', loss.item(), i + 1)
                time_meter.reset()

            if (i + 1) % cfg['training']['val_interval'] == 0 or (
                    i + 1) == cfg['training']['train_iters']:
                model.eval()
                with torch.no_grad():
                    for i_val, (images_val, labels_val,
                                img_path_val) in tqdm(enumerate(valloader)):
                        images_val = images_val.to(device)
                        labels_val = labels_val.to(device)

                        outputs = model(images_val)  # [batch_size, n_classes, height, width]
                        if cfg['model']['arch'] == "dispnet" and cfg['task'] == "depth":
                            outputs = 1 / outputs

                        val_loss = loss_fn(input=outputs, target=labels_val)  # mean pixel-wise loss in a batch

                        if cfg['task'] == "seg":
                            pred = outputs.data.max(1)[1].cpu().numpy()  # [batch_size, height, width]
                            gt = labels_val.data.cpu().numpy()  # [batch_size, height, width]
                        elif cfg['task'] == "depth":
                            pred = outputs.squeeze(1).data.cpu().numpy()
                            gt = labels_val.data.squeeze(1).cpu().numpy()
                        else:
                            raise NotImplementedError(
                                'Task {} not implemented'.format(cfg['task']))

                        running_metrics_val.update(gt=gt, pred=pred)
                        val_loss_meter.update(val_loss.item())

                writer.add_scalar('loss/val_loss', val_loss_meter.avg, i + 1)
                logger.info("Iter %d val_loss: %.4f" %
                            (i + 1, val_loss_meter.avg))
                print("Iter %d val_loss: %.4f" % (i + 1, val_loss_meter.avg))

                # output scores
                if cfg['task'] == "seg":
                    score, class_iou = running_metrics_val.get_scores()
                    for k, v in score.items():
                        print(k, v)
                        sys.stdout.flush()
                        logger.info('{}: {}'.format(k, v))
                        writer.add_scalar('val_metrics/{}'.format(k), v, i + 1)
                    for k, v in class_iou.items():
                        logger.info('{}: {}'.format(k, v))
                        writer.add_scalar('val_metrics/cls_{}'.format(k), v,
                                          i + 1)

                elif cfg['task'] == "depth":
                    val_result = running_metrics_val.get_scores()
                    for k, v in val_result.items():
                        print(k, v)
                        logger.info('{}: {}'.format(k, v))
                        writer.add_scalar('val_metrics/{}'.format(k), v, i + 1)
                else:
                    raise NotImplementedError('Task {} not implemented'.format(
                        cfg['task']))

                val_loss_meter.reset()
                running_metrics_val.reset()

                # checkpoint whenever the primary validation metric improves
                # (mean IoU for segmentation, absolute relative error for depth)
                save_model = False
                if cfg['task'] == "seg":
                    if score["Mean IoU : \t"] >= best_iou:
                        best_iou = score["Mean IoU : \t"]
                        save_model = True
                        state = {
                            "epoch": i + 1,
                            "model_state": model.state_dict(),
                            "optimizer_state": optimizer.state_dict(),
                            "scheduler_state": scheduler.state_dict(),
                            "best_iou": best_iou,
                        }

                elif cfg['task'] == "depth":
                    if val_result["abs rel : \t"] <= best_rel:
                        best_rel = val_result["abs rel : \t"]
                        save_model = True
                        state = {
                            "epoch": i + 1,
                            "model_state": model.state_dict(),
                            "optimizer_state": optimizer.state_dict(),
                            "scheduler_state": scheduler.state_dict(),
                            "best_rel": best_rel,
                        }

                if save_model:
                    save_path = os.path.join(
                        writer.file_writer.get_logdir(),
                        "{}_{}_best_model.pkl".format(cfg['model']['arch'],
                                                      cfg['data']['dataset']))
                    torch.save(state, save_path)
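                    # Note: the saved dict mirrors what the 'resume' branch above
                    # expects, so the best model can later be reloaded with, e.g.:
                    #   checkpoint = torch.load(save_path)
                    #   model.load_state_dict(checkpoint["model_state"])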

            if (i + 1) == cfg['training']['train_iters']:
                flag = False
                break
            i += 1
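
    # --------------------------------------------------------------------------
    # Illustrative only: a minimal sketch of the cfg structure this training loop
    # reads. The key names are taken from the accesses above; the concrete values
    # (dataset, iteration counts, learning rate) are placeholders, not settings
    # from this repository.
    #
    # cfg = {
    #     'task': 'seg',                      # or 'depth'
    #     'model': {'arch': 'dispnet'},
    #     'data': {'dataset': 'cityscapes'},
    #     'training': {
    #         'train_iters': 90000,
    #         'batch_size': 4,
    #         'print_interval': 50,
    #         'val_interval': 1000,
    #         'resume': None,                 # or a path to a *_best_model.pkl checkpoint
    #         'optimizer': {'name': 'sgd', 'lr': 0.01},
    #         'lr_schedule': {...},           # consumed by get_scheduler()
    #         # loss settings are consumed by get_loss_function(cfg)
    #     },
    # }
    # --------------------------------------------------------------------------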