def main(args):
    """Copy (or move) the tiles listed in a csv cover from args.dir into args.out.

    Tiles without exactly one matching file on disk are skipped with a warning.
    Optionally writes a Leaflet web UI over the subset afterwards.
    """
    log = Log(os.path.join(args.out, "log"), out=sys.stderr)
    tiles = set(tiles_from_csv(args.cover))

    # Remember the last seen extension so the web UI can reference the tiles.
    extension = ""

    for tile in tqdm(tiles, desc="Subset", unit="tiles", ascii=True):
        paths = glob(os.path.join(args.dir, str(tile.z), str(tile.x), "{}.*".format(tile.y)))

        # Require exactly one candidate file per tile; ambiguity or absence is skipped.
        if len(paths) != 1:
            log.log("Warning: {} skipped.".format(tile))
            continue

        src = paths[0]

        try:
            extension = os.path.splitext(src)[1][1:]
            dst_dir = os.path.join(args.out, str(tile.z), str(tile.x))
            dst = os.path.join(dst_dir, "{}.{}".format(tile.y, extension))

            # exist_ok=True already tolerates a pre-existing directory, so no
            # isdir pre-check is needed.
            os.makedirs(dst_dir, exist_ok=True)

            if args.move:
                shutil.move(src, dst)
            else:
                shutil.copyfile(src, dst)
        # Narrow catch: a bare `except:` would also swallow KeyboardInterrupt
        # and SystemExit. File-system failures surface as OSError/shutil.Error.
        except (OSError, shutil.Error):
            sys.exit("Error: Unable to process {}".format(tile))

    if args.web_ui:
        template = "leaflet.html" if not args.web_ui_template else args.web_ui_template
        web_ui(args.out, args.web_ui, tiles, tiles, extension, template)
def main(args):
    """Train a UNet segmentation model driven by two toml configs (model + dataset).

    Builds the network, optimizer and loss from the configs, optionally resumes
    from a checkpoint, then runs the train/validate loop, logging metrics and
    saving a history plot plus a checkpoint every epoch.
    """
    model = load_config(args.model)
    dataset = load_config(args.dataset)

    device = torch.device("cuda" if model["common"]["cuda"] else "cpu")

    if model["common"]["cuda"] and not torch.cuda.is_available():
        sys.exit("Error: CUDA requested but not available")

    os.makedirs(model["common"]["checkpoint"], exist_ok=True)

    num_classes = len(dataset["common"]["classes"])
    net = UNet(num_classes)
    net = DataParallel(net)
    net = net.to(device)

    if model["common"]["cuda"]:
        # Let cudnn pick the fastest kernels for fixed-size inputs.
        torch.backends.cudnn.benchmark = True

    try:
        weight = torch.Tensor(dataset["weights"]["values"])
    except KeyError:
        # Only the weighted losses require class weights; Lovasz proceeds
        # without them (and `weight` stays unbound but unused in that case).
        if model["opt"]["loss"] in ("CrossEntropy", "mIoU", "Focal"):
            sys.exit("Error: The loss function used, need dataset weights values")

    optimizer = Adam(net.parameters(), lr=model["opt"]["lr"], weight_decay=model["opt"]["decay"])

    resume = 0
    if args.checkpoint:

        def map_location(storage, _):
            # Remap checkpoint tensors onto the selected device.
            return storage.cuda() if model["common"]["cuda"] else storage.cpu()

        # https://github.com/pytorch/pytorch/issues/7178
        chkpt = torch.load(args.checkpoint, map_location=map_location)
        net.load_state_dict(chkpt["state_dict"])

        # Optimizer state and epoch counter are only restored when resuming,
        # not when merely warm-starting from the weights.
        if args.resume:
            optimizer.load_state_dict(chkpt["optimizer"])
            resume = chkpt["epoch"]

    if model["opt"]["loss"] == "CrossEntropy":
        criterion = CrossEntropyLoss2d(weight=weight).to(device)
    elif model["opt"]["loss"] == "mIoU":
        criterion = mIoULoss2d(weight=weight).to(device)
    elif model["opt"]["loss"] == "Focal":
        criterion = FocalLoss2d(weight=weight).to(device)
    elif model["opt"]["loss"] == "Lovasz":
        criterion = LovaszLoss2d().to(device)
    else:
        sys.exit("Error: Unknown [opt][loss] value !")

    train_loader, val_loader = get_dataset_loaders(model, dataset, args.workers)

    num_epochs = model["opt"]["epochs"]
    if resume >= num_epochs:
        sys.exit("Error: Epoch {} set in {} already reached by the checkpoint provided".format(num_epochs, args.model))

    history = collections.defaultdict(list)
    log = Log(os.path.join(model["common"]["checkpoint"], "log"))

    log.log("--- Hyper Parameters on Dataset: {} ---".format(dataset["common"]["dataset"]))
    log.log("Batch Size:\t {}".format(model["common"]["batch_size"]))
    log.log("Image Size:\t {}".format(model["common"]["image_size"]))
    log.log("Learning Rate:\t {}".format(model["opt"]["lr"]))
    log.log("Weight Decay:\t {}".format(model["opt"]["decay"]))
    log.log("Loss function:\t {}".format(model["opt"]["loss"]))
    # `weight` is only bound when the dataset config provided weight values.
    if "weight" in locals():
        log.log("Weights :\t {}".format(dataset["weights"]["values"]))
    log.log("---")

    for epoch in range(resume, num_epochs):
        log.log("Epoch: {}/{}".format(epoch + 1, num_epochs))

        train_hist = train(train_loader, num_classes, device, net, optimizer, criterion)
        log.log(
            "Train loss: {:.4f}, mIoU: {:.3f}, {} IoU: {:.3f}, MCC: {:.3f}".format(
                train_hist["loss"],
                train_hist["miou"],
                dataset["common"]["classes"][1],
                train_hist["fg_iou"],
                train_hist["mcc"],
            ))

        for k, v in train_hist.items():
            history["train " + k].append(v)

        val_hist = validate(val_loader, num_classes, device, net, criterion)
        log.log(
            "Validate loss: {:.4f}, mIoU: {:.3f}, {} IoU: {:.3f}, MCC: {:.3f}".format(
                val_hist["loss"], val_hist["miou"], dataset["common"]["classes"][1], val_hist["fg_iou"],
                val_hist["mcc"]))

        for k, v in val_hist.items():
            history["val " + k].append(v)

        # Persist a metrics plot and a full checkpoint every epoch.
        visual = "history-{:05d}-of-{:05d}.png".format(epoch + 1, num_epochs)
        plot(os.path.join(model["common"]["checkpoint"], visual), history)

        checkpoint = "checkpoint-{:05d}-of-{:05d}.pth".format(epoch + 1, num_epochs)
        states = {"epoch": epoch + 1, "state_dict": net.state_dict(), "optimizer": optimizer.state_dict()}
        torch.save(states, os.path.join(model["common"]["checkpoint"], checkpoint))
def main(args):
    """Train a UNet segmentation model from a single toml config.

    CLI flags (lr, dataset path, epochs) override their config counterparts.
    Builds the network, optimizer and loss, optionally resumes from a
    checkpoint, then runs the train/validate loop, logging metrics, plotting
    history and (optionally) saving intermediate checkpoints each epoch.
    """
    config = load_config(args.config)

    # CLI arguments take precedence over the configuration file.
    lr = args.lr if args.lr else config["model"]["lr"]
    dataset_path = args.dataset if args.dataset else config["dataset"]["path"]
    num_epochs = args.epochs if args.epochs else config["model"]["epochs"]

    log = Log(os.path.join(args.out, "log"))

    if torch.cuda.is_available():
        device = torch.device("cuda")
        # Let cudnn pick the fastest kernels for fixed-size inputs.
        torch.backends.cudnn.benchmark = True
        log.log("RoboSat - training on {} GPUs, with {} workers".format(torch.cuda.device_count(), args.workers))
    else:
        device = torch.device("cpu")
        log.log("RoboSat - training on CPU, with {} workers".format(args.workers))

    num_classes = len(config["classes"]["titles"])

    # Input depth is the total number of bands across all configured channels.
    num_channels = 0
    for channel in config["channels"]:
        num_channels += len(channel["bands"])

    pretrained = config["model"]["pretrained"]
    net = DataParallel(UNet(num_classes, num_channels=num_channels, pretrained=pretrained)).to(device)

    # Only the weighted losses require class weights; Lovasz proceeds without
    # them (and `weight` stays unbound but unused in that case).
    if config["model"]["loss"] in ("CrossEntropy", "mIoU", "Focal"):
        try:
            weight = torch.Tensor(config["classes"]["weights"])
        except KeyError:
            sys.exit("Error: The loss function used, need dataset weights values")

    optimizer = Adam(net.parameters(), lr=lr, weight_decay=config["model"]["decay"])

    resume = 0
    if args.checkpoint:

        def map_location(storage, _):
            # Remap checkpoint tensors onto the available device.
            return storage.cuda() if torch.cuda.is_available() else storage.cpu()

        # https://github.com/pytorch/pytorch/issues/7178
        chkpt = torch.load(args.checkpoint, map_location=map_location)
        net.load_state_dict(chkpt["state_dict"])
        log.log("Using checkpoint: {}".format(args.checkpoint))

        # Optimizer state and epoch counter are only restored when resuming,
        # not when merely warm-starting from the weights.
        if args.resume:
            optimizer.load_state_dict(chkpt["optimizer"])
            resume = chkpt["epoch"]

    if config["model"]["loss"] == "CrossEntropy":
        criterion = CrossEntropyLoss2d(weight=weight).to(device)
    elif config["model"]["loss"] == "mIoU":
        criterion = mIoULoss2d(weight=weight).to(device)
    elif config["model"]["loss"] == "Focal":
        criterion = FocalLoss2d(weight=weight).to(device)
    elif config["model"]["loss"] == "Lovasz":
        criterion = LovaszLoss2d().to(device)
    else:
        sys.exit("Error: Unknown [model][loss] value !")

    train_loader, val_loader = get_dataset_loaders(dataset_path, config, args.workers)

    if resume >= num_epochs:
        sys.exit("Error: Epoch {} set in {} already reached by the checkpoint provided".format(num_epochs, args.config))

    history = collections.defaultdict(list)

    log.log("")
    log.log("--- Input tensor from Dataset: {} ---".format(dataset_path))
    num_channel = 1
    for channel in config["channels"]:
        for band in channel["bands"]:
            log.log("Channel {}:\t\t {}[band: {}]".format(num_channel, channel["sub"], band))
            num_channel += 1
    log.log("")
    log.log("--- Hyper Parameters ---")
    log.log("Batch Size:\t\t {}".format(config["model"]["batch_size"]))
    log.log("Image Size:\t\t {}".format(config["model"]["image_size"]))
    log.log("Data Augmentation:\t {}".format(config["model"]["data_augmentation"]))
    log.log("Learning Rate:\t\t {}".format(lr))
    log.log("Weight Decay:\t\t {}".format(config["model"]["decay"]))
    log.log("Loss function:\t\t {}".format(config["model"]["loss"]))
    log.log("ResNet pre-trained:\t {}".format(config["model"]["pretrained"]))
    if "weight" in locals():
        # Log the same key the weight tensor was built from ([classes][weights]),
        # not the unrelated (and possibly absent) [dataset][weights].
        log.log("Weights :\t\t {}".format(config["classes"]["weights"]))
    log.log("")

    for epoch in range(resume, num_epochs):
        log.log("---")
        log.log("Epoch: {}/{}".format(epoch + 1, num_epochs))

        train_hist = train(train_loader, num_classes, device, net, optimizer, criterion)
        log.log(
            "Train loss: {:.4f}, mIoU: {:.3f}, {} IoU: {:.3f}, MCC: {:.3f}".format(
                train_hist["loss"],
                train_hist["miou"],
                config["classes"]["titles"][1],
                train_hist["fg_iou"],
                train_hist["mcc"],
            )
        )

        for k, v in train_hist.items():
            history["train " + k].append(v)

        val_hist = validate(val_loader, num_classes, device, net, criterion)
        log.log(
            "Validate loss: {:.4f}, mIoU: {:.3f}, {} IoU: {:.3f}, MCC: {:.3f}".format(
                val_hist["loss"], val_hist["miou"], config["classes"]["titles"][1], val_hist["fg_iou"], val_hist["mcc"]
            )
        )

        for k, v in val_hist.items():
            history["val " + k].append(v)

        # Persist a metrics plot every epoch; checkpoints only when requested.
        visual_path = os.path.join(args.out, "history-{:05d}-of-{:05d}.png".format(epoch + 1, num_epochs))
        plot(visual_path, history)

        if args.save_intermed:
            states = {"epoch": epoch + 1, "state_dict": net.state_dict(), "optimizer": optimizer.state_dict()}
            checkpoint_path = os.path.join(args.out, "checkpoint-{:05d}-of-{:05d}.pth".format(epoch + 1, num_epochs))
            torch.save(states, checkpoint_path)
def main(args):
    """Download the tiles listed in a csv cover from an XYZ/TMS/WMS endpoint.

    Downloads run concurrently with args.rate workers, each throttled so the
    aggregate request rate stays near args.rate requests per second. Tiles
    already present on disk are skipped. Optionally writes a Leaflet web UI.
    """
    tiles = list(tiles_from_csv(args.tiles))
    already_dl = 0
    dl = 0

    with requests.Session() as session:
        num_workers = args.rate

        os.makedirs(os.path.join(args.out), exist_ok=True)
        log = Log(os.path.join(args.out, "log"), out=sys.stderr)
        log.log("Begin download from {}".format(args.url))

        # tqdm has problems with concurrent.futures.ThreadPoolExecutor; explicitly call `.update`
        # https://github.com/tqdm/tqdm/issues/97
        progress = tqdm(total=len(tiles), ascii=True, unit="image")

        with futures.ThreadPoolExecutor(num_workers) as executor:

            def worker(tile):
                # Returns (tile, url, ok): url is None when the tile was
                # already on disk; ok is False on fetch/decode failure.
                tick = time.monotonic()

                x, y, z = map(str, [tile.x, tile.y, tile.z])

                os.makedirs(os.path.join(args.out, z, x), exist_ok=True)
                path = os.path.join(args.out, z, x, "{}.{}".format(y, args.ext))

                if os.path.isfile(path):
                    return tile, None, True

                if args.type == "XYZ":
                    url = args.url.format(x=tile.x, y=tile.y, z=tile.z)
                elif args.type == "TMS":
                    # TMS flips the y axis relative to XYZ.
                    # NOTE(review): this mutates `tile` in place — assumes the
                    # tile objects are mutable; a mercantile namedtuple would
                    # raise AttributeError here. TODO confirm tile type.
                    tile.y = (2 ** tile.z) - tile.y - 1
                    url = args.url.format(x=tile.x, y=tile.y, z=tile.z)
                elif args.type == "WMS":
                    xmin, ymin, xmax, ymax = xy_bounds(tile)
                    url = args.url.format(xmin=xmin, ymin=ymin, xmax=xmax, ymax=ymax)
                # NOTE(review): `url` is unbound if args.type is none of the
                # above — presumably argparse constrains the choices; verify.

                res = fetch_image(session, url, args.timeout)
                if not res:
                    return tile, url, False

                try:
                    image = Image.open(res)
                    image.save(path, optimize=True)
                except OSError:
                    return tile, url, False

                # Throttle: since num_workers == args.rate, each worker is held
                # to at most one request per second, giving ~args.rate req/s total.
                tock = time.monotonic()
                time_for_req = tock - tick
                time_per_worker = num_workers / args.rate

                if time_for_req < time_per_worker:
                    time.sleep(time_per_worker - time_for_req)

                progress.update()

                return tile, url, True

            for tile, url, ok in executor.map(worker, tiles):
                if url and ok:
                    dl += 1
                elif not url and ok:
                    already_dl += 1
                else:
                    log.log("Warning:\n {} failed, skipping.\n {}\n".format(tile, url))

        if already_dl:
            log.log("Notice:\n {} tiles were already downloaded previously, and so skipped now.".format(already_dl))
        if already_dl + dl == len(tiles):
            log.log(" Coverage is fully downloaded.")

    if args.web_ui:
        template = "leaflet.html" if not args.web_ui_template else args.web_ui_template
        web_ui(args.out, args.web_ui, tiles, tiles, args.ext, template)