def main():
    """
    Main Function

    Train (or only evaluate) the network with HeAT's ``DASO`` optimizer and
    the ``DataParallelMultiGPU`` wrapper.  All configuration is read from the
    module-level ``args`` namespace.

    Order of operations: select the CUDA device and, with more than one
    visible GPU, initialise a torch NCCL process group; initialise logging and
    the experiment config; build data loaders and losses; optionally load a
    checkpoint; build network, optimizer and the DASO wrapper; run an
    evaluation-only pass if ``args.eval`` is set; otherwise run the epoch loop,
    writing a checkpoint and the benchmark dictionary from rank 0 after every
    epoch.

    Returns:
        ``0`` after an evaluation-only run; ``None`` after training.

    Raises:
        ValueError: if ``args.eval`` is set to something other than
            ``'val'`` or ``'folder'``.
    """
    rank = args.rank
    cfg.GLOBAL_RANK = rank
    args.gpus = torch.cuda.device_count()
    device = torch.device("cpu")
    loc_dist = args.gpus > 1
    # NOTE(review): raises ZeroDivisionError when no CUDA device is visible.
    loc_rank = rank % args.gpus
    args.gpu = loc_rank
    args.local_rank = loc_rank
    if loc_dist:
        device = "cuda:" + str(loc_rank)
        os.environ["MASTER_ADDR"] = "localhost"
        os.environ["MASTER_PORT"] = "19500"
        os.environ["NCCL_SOCKET_IFNAME"] = "ib"
        torch.cuda.set_device(device)
        # rank / world_size passed here are the local rank and the local GPU
        # count, not the global MPI values.
        torch.distributed.init_process_group(backend="nccl",
                                             rank=loc_rank,
                                             world_size=args.gpus)
    elif args.gpus == 1:
        device = "cuda:0"
        args.local_rank = 0
        torch.cuda.set_device(device)

    assert args.result_dir is not None, 'need to define result_dir arg'
    logx.initialize(logdir=args.result_dir,
                    tensorboard=True,
                    hparams=vars(args),
                    global_rank=args.global_rank)

    # Set up the Arguments, Tensorboard Writer, Dataloader, Loss Fn, Optimizer
    assert_and_infer_cfg(args)
    prep_experiment(args)

    train_loader, val_loader, train_obj = datasets.setup_loaders(args)
    criterion, criterion_val = get_loss(args)

    cwd = os.getcwd()
    sz = ht.MPI_WORLD.size
    filename = cwd + "/citys-heat-checkpoint-" + str(sz) + ".pth.tar"
    if args.resume and os.path.isfile(filename):
        checkpoint = torch.load(filename, map_location=torch.device('cpu'))
        args.arch = checkpoint['arch']
        # The checkpoint written in the epoch loop below stores ``epoch + 1``
        # (i.e. the next epoch to run), so it is used as-is.  Adding another 1
        # here would silently skip one epoch on every resume.
        args.start_epoch = int(checkpoint['epoch'])
        args.restore_net = True
        args.restore_optimizer = True
        logx.msg(f"Resuming from: checkpoint={args.resume}, "
                 f"epoch {args.start_epoch}, arch {args.arch}")
    elif args.snapshot:
        if 'ASSETS_PATH' in args.snapshot:
            args.snapshot = args.snapshot.replace('ASSETS_PATH',
                                                  cfg.ASSETS_PATH)
        checkpoint = torch.load(args.snapshot,
                                map_location=torch.device('cpu'))
        args.restore_net = True
        logx.msg(f"Loading weights from: checkpoint={args.snapshot}")

    net = network.get_net(args, criterion)
    net = net.to(device)
    optim, scheduler = get_optimizer(args, net)

    # the scheduler in this code is only run at the end of each epoch
    # todo: make heat an option not this whole file
    dp_optim = ht.optim.DASO(
        local_optimizer=optim,
        total_epochs=args.max_epoch,
        max_global_skips=4,
    )
    dp_optim.disable_cycling(global_skips=args.batch_skip,
                             batches_to_wait=args.gs)
    # Wrap the network for HeAT data-parallel training; the wrapper is what
    # gets trained and checkpointed below.
    htnet = ht.nn.DataParallelMultiGPU(net,
                                       comm=ht.MPI_WORLD,
                                       optimizer=dp_optim)

    if args.summary:
        print(str(net))
        from thop import profile
        img = torch.randn(1, 3, 1024, 2048).cuda()
        mask = torch.randn(1, 1, 1024, 2048).cuda()
        macs, params = profile(net, inputs={'images': img, 'gts': mask})
        print0(f'macs {macs} params {params}')
        sys.exit()

    if args.restore_optimizer:
        restore_opt(optim, checkpoint)
        dp_optim.stability.load_dict(checkpoint["skip_stable"])
    if args.restore_net:
        htnet.load_state_dict(checkpoint["state_dict"])

    if args.init_decoder:
        net.module.init_mods()

    torch.cuda.empty_cache()

    if args.start_epoch != 0:
        # TODO: need a loss value for the restart at a certain epoch...
        scheduler.step(args.start_epoch)

    # There are 4 options for evaluation:
    #  --eval val                           just run validation
    #  --eval val --dump_assets             dump all images and assets
    #  --eval folder                        just dump all basic images
    #  --eval folder --dump_assets          dump all images and assets
    # todo: HeAT fixes -- not urgent --
    if args.eval == 'val':
        if args.dump_topn:
            validate_topn(val_loader, net, criterion_val, optim, 0, args)
        else:
            validate(val_loader,
                     net,
                     criterion=criterion_val,
                     optim=optim,
                     epoch=0,
                     dump_assets=args.dump_assets,
                     dump_all_images=args.dump_all_images,
                     calc_metrics=not args.no_metrics)
        return 0
    elif args.eval == 'folder':
        # Using a folder for evaluation means to not calculate metrics
        validate(val_loader,
                 net,
                 criterion=None,
                 optim=None,
                 epoch=0,
                 calc_metrics=False,
                 dump_assets=args.dump_assets,
                 dump_all_images=True)
        return 0
    elif args.eval is not None:
        # Raising a plain string is itself a TypeError in Python 3 and hides
        # the message; raise a real exception instead.
        raise ValueError('unknown eval option {}'.format(args.eval))

    scaler = amp.GradScaler()
    if dp_optim.comm.rank == 0:
        print("scheduler", args.lr_schedule)
    dp_optim.add_scaler(scaler)

    # Number of nodes = communicator size / GPUs visible on this node; used
    # as a prefix for the benchmark file name and its column keys.
    nodes = str(int(dp_optim.comm.size / torch.cuda.device_count()))
    cwd = os.getcwd()
    fname = cwd + "/" + nodes + "-heat-citys-benchmark"
    if args.resume and rank == 0 and os.path.isfile(fname + ".pkl"):
        # NOTE(review): pickle.load executes arbitrary code from the file;
        # only resume from benchmark files this script wrote itself.
        with open(fname + ".pkl", "rb") as f:
            out_dict = pickle.load(f)
    else:
        out_dict = {
            "epochs": [],
            nodes + "-avg-batch-time": [],
            nodes + "-total-train-time": [],
            nodes + "-train-loss": [],
            nodes + "-val-loss": [],
            nodes + "-val-iou": [],
            nodes + "-val-time": [],
        }
        print0("Output dict:", fname)

    for epoch in range(args.start_epoch, args.max_epoch):
        # todo: HeAT fixes -- possible conflict between processes
        update_epoch(epoch)

        if args.only_coarse:  # default: false
            train_obj.only_coarse()
            train_obj.build_epoch()
        elif args.class_uniform_pct:
            if epoch >= args.max_cu_epoch:
                train_obj.disable_coarse()
            train_obj.build_epoch()

        ls, bt, btt = train(train_loader, htnet, dp_optim, epoch, scaler)
        dp_optim.epoch_loss_logic(ls, loss_globally_averaged=True)

        vls, iu, vtt = validate(val_loader, htnet, criterion_val, dp_optim,
                                epoch)
        if args.lr_schedule == "plateau":
            if dp_optim.comm.rank == 0:
                print("loss", ls, 'best:',
                      scheduler.best * (1. - scheduler.threshold),
                      scheduler.num_bad_epochs)
            # The plateau scheduler is driven by the training loss here,
            # not the validation loss.
            scheduler.step(ls)
        else:
            scheduler.step()

        if args.rank == 0:
            save_checkpoint({
                # "epoch" is the next epoch to run when resuming.
                "epoch": epoch + 1,
                "arch": args.arch,
                "state_dict": htnet.state_dict(),
                "optimizer": optim.state_dict(),
                "skip_stable": dp_optim.stability.get_dict()
            })

        out_dict["epochs"].append(epoch)
        out_dict[nodes + "-train-loss"].append(ls)
        out_dict[nodes + "-avg-batch-time"].append(bt)
        out_dict[nodes + "-total-train-time"].append(btt)
        out_dict[nodes + "-val-loss"].append(vls)
        out_dict[nodes + "-val-iou"].append(iu)
        out_dict[nodes + "-val-time"].append(vtt)

        if args.rank == 0:
            save_obj(out_dict, fname)

    if args.rank == 0:
        print("\nRESULTS\n")
        import pandas as pd
        df = pd.DataFrame.from_dict(out_dict).set_index("epochs")
        with pd.option_context("display.max_rows", None, "display.max_columns",
                               None):
            # more options can be specified also
            print(df)
        if args.benchmarking:
            # Append this run's columns to any existing benchmark CSV.
            try:
                fulldf = pd.read_csv(cwd + "/heat-bench-results.csv")
                fulldf = pd.concat([df, fulldf], axis=1)
            except FileNotFoundError:
                fulldf = df
            fulldf.to_csv(cwd + "/heat-bench-results.csv")
Example #2
0
def main():
    """
    Main Function

    Train (or only evaluate) the network using the optimizer returned by
    ``get_optimizer(args, net, tau, k)``.  All configuration is read from the
    module-level ``args`` namespace.

    Checkpoint selection, in priority order: an auto-resume request (when
    ``AutoResume`` is available), ``args.resume``, then ``args.snapshot``.

    Returns:
        ``0`` after an evaluation-only run or when ``check_termination``
        requests an early stop; ``None`` when the epoch loop runs to the end.

    Raises:
        ValueError: if ``args.eval`` is set to something other than
            ``'test'``, ``'val'`` or ``'folder'``.
    """
    if AutoResume:
        AutoResume.init()

    assert args.result_dir is not None, 'need to define result_dir arg'
    logx.initialize(logdir=args.result_dir,
                    tensorboard=True,
                    hparams=vars(args),
                    global_rank=args.global_rank)

    # Set up the Arguments, Tensorboard Writer, Dataloader, Loss Fn, Optimizer
    assert_and_infer_cfg(args)
    prep_experiment(args)
    train_loader, val_loader, train_obj = \
        datasets.setup_loaders(args)
    criterion, criterion_val = get_loss(args)

    auto_resume_details = None
    if AutoResume:
        auto_resume_details = AutoResume.get_resume_details()

    if auto_resume_details:
        checkpoint_fn = auto_resume_details.get("RESUME_FILE", None)
        checkpoint = torch.load(checkpoint_fn,
                                map_location=torch.device('cpu'))
        args.result_dir = auto_resume_details.get("TENSORBOARD_DIR", None)
        args.start_epoch = int(auto_resume_details.get("EPOCH", None)) + 1
        args.restore_net = True
        args.restore_optimizer = True
        msg = ("Found details of a requested auto-resume: checkpoint={}"
               " tensorboard={} at epoch {}")
        logx.msg(msg.format(checkpoint_fn, args.result_dir, args.start_epoch))
    elif args.resume:
        checkpoint = torch.load(args.resume, map_location=torch.device('cpu'))
        args.arch = checkpoint['arch']
        args.start_epoch = int(checkpoint['epoch']) + 1
        args.restore_net = True
        args.restore_optimizer = True
        msg = "Resuming from: checkpoint={}, epoch {}, arch {}"
        logx.msg(msg.format(args.resume, args.start_epoch, args.arch))
    elif args.snapshot:
        if 'ASSETS_PATH' in args.snapshot:
            args.snapshot = args.snapshot.replace('ASSETS_PATH',
                                                  cfg.ASSETS_PATH)
        checkpoint = torch.load(args.snapshot,
                                map_location=torch.device('cpu'))
        args.restore_net = True
        msg = "Loading weights from: checkpoint={}".format(args.snapshot)
        logx.msg(msg)

    # NASA optimizer parameters.  tau was at one point derived as
    # args.tau_factor / sqrt(len(train_loader) * args.max_epoch); it is
    # currently fixed to 1.
    tau = 1
    k = 1
    net = network.get_net(args, criterion)
    optim, scheduler = get_optimizer(args, net, tau, k)

    if args.fp16:
        net, optim = amp.initialize(net, optim, opt_level=args.amp_opt_level)

    net = network.wrap_network_in_dataparallel(net, args.apex)

    if args.summary:
        # Print MACs / parameter count for a 640x640 input and exit.
        from thop import profile
        img = torch.randn(1, 3, 640, 640).cuda()
        mask = torch.randn(1, 1, 640, 640).cuda()
        macs, params = profile(net, inputs={'images': img, 'gts': mask})
        print(f'macs {macs} params {params}')
        sys.exit()

    if args.restore_optimizer:
        restore_opt(optim, checkpoint)
    if args.restore_net:
        restore_net(net, checkpoint)

    if args.init_decoder:
        net.module.init_mods()

    torch.cuda.empty_cache()

    if args.start_epoch != 0:
        scheduler.step(args.start_epoch)

    # Evaluation-only modes:
    #  --eval test                          dump all images in testing mode
    #  --eval val                           just run validation
    #  --eval val --dump_assets             dump all images and assets
    #  --eval folder                        just dump all basic images
    #  --eval folder --dump_assets          dump all images and assets

    if args.eval == 'test':
        # NOTE(review): ``city`` is not defined in this function; it must be
        # a module-level name -- confirm it exists.
        validate(val_loader,
                 net,
                 criterion=None,
                 optim=None,
                 epoch=0,
                 calc_metrics=False,
                 dump_assets=args.dump_assets,
                 dump_all_images=True,
                 testing=True,
                 grid=city)

        return 0

    if args.eval == 'val':

        if args.dump_topn:
            validate_topn(val_loader, net, criterion_val, optim, 0, args)
        else:
            validate(val_loader,
                     net,
                     criterion=criterion_val,
                     optim=optim,
                     epoch=0,
                     dump_assets=args.dump_assets,
                     dump_all_images=args.dump_all_images,
                     calc_metrics=not args.no_metrics)
        return 0
    elif args.eval == 'folder':
        # Using a folder for evaluation means to not calculate metrics
        validate(val_loader,
                 net,
                 criterion=criterion_val,
                 optim=optim,
                 epoch=0,
                 calc_metrics=False,
                 dump_assets=args.dump_assets,
                 dump_all_images=True)
        return 0
    elif args.eval is not None:
        # Raising a plain string is itself a TypeError in Python 3 and hides
        # the message; raise a real exception instead.
        raise ValueError('unknown eval option {}'.format(args.eval))

    for epoch in range(args.start_epoch, args.max_epoch):
        update_epoch(epoch)

        if args.only_coarse:
            train_obj.only_coarse()
            train_obj.build_epoch()
            if args.apex:
                train_loader.sampler.set_num_samples()

        elif args.class_uniform_pct:
            if epoch >= args.max_cu_epoch:
                train_obj.disable_coarse()
                train_obj.build_epoch()
                if args.apex:
                    train_loader.sampler.set_num_samples()
            else:
                train_obj.build_epoch()

        train(train_loader, net, optim, epoch)

        if args.apex:
            train_loader.sampler.set_epoch(epoch + 1)

        if epoch % args.val_freq == 0:
            validate(val_loader, net, criterion_val, optim, epoch)

        scheduler.step()

        if check_termination(epoch):
            return 0
Example #3
0
def main():
    """
    Main Function

    Build the network, optionally restore it from a checkpoint, and run
    folder evaluation over the ``image_2/`` and ``image_3/`` sub-directories
    via ``eval_minibatch``.  All configuration is read from the module-level
    ``args`` namespace.

    Checkpoint selection, in priority order: an auto-resume request (when
    ``AutoResume`` is available), ``args.resume``, then ``args.snapshot``.

    Returns:
        ``0`` after a ``--eval folder`` run; ``None`` when ``args.eval`` is
        ``None`` (nothing is evaluated or trained in that case).

    Raises:
        ValueError: if ``args.eval`` is set to something other than
            ``'folder'``.
    """
    if AutoResume:
        AutoResume.init()

    assert args.result_dir is not None, 'need to define result_dir arg'
    logx.initialize(logdir=args.result_dir,
                    tensorboard=False,
                    hparams=vars(args),
                    global_rank=args.global_rank)

    # Set up the Arguments, Tensorboard Writer, Dataloader, Loss Fn, Optimizer
    assert_and_infer_cfg(args)
    prep_experiment(args)
    train_loader, val_loader, train_obj = datasets.setup_loaders(args)
    criterion, criterion_val = get_loss(args)

    auto_resume_details = None
    if AutoResume:
        auto_resume_details = AutoResume.get_resume_details()

    if auto_resume_details:
        checkpoint_fn = auto_resume_details.get("RESUME_FILE", None)
        checkpoint = torch.load(checkpoint_fn,
                                map_location=torch.device('cpu'))
        args.result_dir = auto_resume_details.get("TENSORBOARD_DIR", None)
        args.start_epoch = int(auto_resume_details.get("EPOCH", None)) + 1
        args.restore_net = True
        args.restore_optimizer = True
        msg = ("Found details of a requested auto-resume: checkpoint={}"
               " tensorboard={} at epoch {}")
        logx.msg(msg.format(checkpoint_fn, args.result_dir, args.start_epoch))
    elif args.resume:
        checkpoint = torch.load(args.resume, map_location=torch.device('cpu'))
        args.arch = checkpoint['arch']
        args.start_epoch = int(checkpoint['epoch']) + 1
        args.restore_net = True
        args.restore_optimizer = True
        msg = "Resuming from: checkpoint={}, epoch {}, arch {}"
        logx.msg(msg.format(args.resume, args.start_epoch, args.arch))
    elif args.snapshot:
        if 'ASSETS_PATH' in args.snapshot:
            args.snapshot = args.snapshot.replace('ASSETS_PATH',
                                                  cfg.ASSETS_PATH)
        checkpoint = torch.load(args.snapshot,
                                map_location=torch.device('cpu'))
        args.restore_net = True
        msg = "Loading weights from: checkpoint={}".format(args.snapshot)
        logx.msg(msg)

    net = network.get_net(args, criterion)
    optim, scheduler = get_optimizer(args, net)

    net = network.wrap_network_in_dataparallel(net, args.apex)

    if args.restore_optimizer:
        restore_opt(optim, checkpoint)
    if args.restore_net:
        restore_net(net, checkpoint)

    if args.init_decoder:
        net.module.init_mods()

    torch.cuda.empty_cache()

    if args.start_epoch != 0:
        scheduler.step(args.start_epoch)

    if args.eval == 'folder':
        # Using a folder for evaluation means to not calculate metrics.
        # Output directories are built by plain string concatenation, so
        # args.result_dir is expected to end with a path separator.
        # exist_ok avoids the check-then-create race when several processes
        # reach this point at the same time.
        for subdir in ('image_2/', 'image_3/'):
            os.makedirs(args.result_dir + subdir, exist_ok=True)

        # Number of samples to evaluate; file stems are zero-padded to six
        # digits.
        num_image = 7481
        for idx in tqdm(range(num_image)):
            sample_idx = "%06d" % idx
            eval_minibatch(sample_idx, "image_2/", net, args)
            eval_minibatch(sample_idx, "image_3/", net, args)

        return 0
    elif args.eval is not None:
        # Raising a plain string is itself a TypeError in Python 3 and hides
        # the message; raise a real exception instead.
        raise ValueError('unknown eval option {}'.format(args.eval))
Example #4
0
def main():
    """
    Main Function

    Train (or only evaluate) the network.  All configuration is read from the
    module-level ``args`` namespace.

    Checkpoint selection, in priority order: an auto-resume request (when
    ``AutoResume`` is available), ``args.resume``, then ``args.snapshot``.

    Returns:
        ``0`` after an evaluation-only run or when ``check_termination``
        requests an early stop; ``None`` when the epoch loop runs to the end.

    Raises:
        ValueError: if ``args.eval`` is set to something other than
            ``'val'`` or ``'folder'``.
    """
    if AutoResume:
        AutoResume.init()

    assert args.result_dir is not None, 'need to define result_dir arg'
    logx.initialize(logdir=args.result_dir,
                    tensorboard=True, hparams=vars(args),
                    global_rank=args.global_rank)

    # Set up the Arguments, Tensorboard Writer, Dataloader, Loss Fn, Optimizer
    assert_and_infer_cfg(args)
    prep_experiment(args)
    train_loader, val_loader, train_obj = \
        datasets.setup_loaders(args)
    criterion, criterion_val = get_loss(args)

    auto_resume_details = None
    if AutoResume:
        auto_resume_details = AutoResume.get_resume_details()

    if auto_resume_details:
        checkpoint_fn = auto_resume_details.get("RESUME_FILE", None)
        checkpoint = torch.load(checkpoint_fn,
                                map_location=torch.device('cpu'))
        args.result_dir = auto_resume_details.get("TENSORBOARD_DIR", None)
        args.start_epoch = int(auto_resume_details.get("EPOCH", None)) + 1
        args.restore_net = True
        args.restore_optimizer = True
        msg = ("Found details of a requested auto-resume: checkpoint={}"
               " tensorboard={} at epoch {}")
        logx.msg(msg.format(checkpoint_fn, args.result_dir,
                            args.start_epoch))
    elif args.resume:
        checkpoint = torch.load(args.resume,
                                map_location=torch.device('cpu'))
        args.arch = checkpoint['arch']
        args.start_epoch = int(checkpoint['epoch']) + 1
        args.restore_net = True
        args.restore_optimizer = True
        msg = "Resuming from: checkpoint={}, epoch {}, arch {}"
        logx.msg(msg.format(args.resume, args.start_epoch, args.arch))
    elif args.snapshot:
        if 'ASSETS_PATH' in args.snapshot:
            args.snapshot = args.snapshot.replace('ASSETS_PATH', cfg.ASSETS_PATH)
        checkpoint = torch.load(args.snapshot,
                                map_location=torch.device('cpu'))
        args.restore_net = True
        msg = "Loading weights from: checkpoint={}".format(args.snapshot)
        logx.msg(msg)

    net = network.get_net(args, criterion)
    optim, scheduler = get_optimizer(args, net)

    if args.fp16:
        net, optim = amp.initialize(net, optim, opt_level=args.amp_opt_level)

    net = network.wrap_network_in_dataparallel(net, args.apex)

    if args.summary:
        # Print the model plus MACs / parameter count for a 1024x2048 input
        # and exit.
        print(str(net))
        from pytorchOpCounter.thop import profile
        img = torch.randn(1, 3, 1024, 2048).cuda()
        mask = torch.randn(1, 1, 1024, 2048).cuda()
        macs, params = profile(net, inputs={'images': img, 'gts': mask})
        print(f'macs {macs} params {params}')
        sys.exit()

    if args.restore_optimizer:
        restore_opt(optim, checkpoint)
    if args.restore_net:
        restore_net(net, checkpoint)

    if args.init_decoder:
        net.module.init_mods()

    torch.cuda.empty_cache()

    if args.start_epoch != 0:
        scheduler.step(args.start_epoch)

    # There are 4 options for evaluation:
    #  --eval val                           just run validation
    #  --eval val --dump_assets             dump all images and assets
    #  --eval folder                        just dump all basic images
    #  --eval folder --dump_assets          dump all images and assets
    if args.eval == 'val':

        if args.dump_topn:
            validate_topn(val_loader, net, criterion_val, optim, 0, args)
        else:
            validate(val_loader, net, criterion=criterion_val, optim=optim, epoch=0,
                     dump_assets=args.dump_assets,
                     dump_all_images=args.dump_all_images,
                     calc_metrics=not args.no_metrics)
        return 0
    elif args.eval == 'folder':
        # Using a folder for evaluation means to not calculate metrics
        validate(val_loader, net, criterion=None, optim=None, epoch=0,
                 calc_metrics=False, dump_assets=args.dump_assets,
                 dump_all_images=True)
        return 0
    elif args.eval is not None:
        # Raising a plain string is itself a TypeError in Python 3 and hides
        # the message; raise a real exception instead.
        raise ValueError('unknown eval option {}'.format(args.eval))

    for epoch in range(args.start_epoch, args.max_epoch):
        update_epoch(epoch)

        if args.only_coarse:
            train_obj.only_coarse()
            train_obj.build_epoch()
            if args.apex:
                train_loader.sampler.set_num_samples()

        elif args.class_uniform_pct:
            if epoch >= args.max_cu_epoch:
                train_obj.disable_coarse()
                train_obj.build_epoch()
                if args.apex:
                    train_loader.sampler.set_num_samples()
            else:
                train_obj.build_epoch()

        train(train_loader, net, optim, epoch)

        if args.apex:
            train_loader.sampler.set_epoch(epoch + 1)

        if epoch % args.val_freq == 0:
            validate(val_loader, net, criterion_val, optim, epoch)

        scheduler.step()

        if check_termination(epoch):
            return 0
def main():
    """
    Main Function

    Train the network with Horovod's ``DistributedOptimizer``.  All
    configuration is read from the module-level ``args`` namespace.

    Order of operations: initialise Horovod and pin the process to its local
    GPU; initialise logging and the experiment config; build data loaders and
    losses; optionally load a checkpoint; build network and optimizer and wrap
    the optimizer with Horovod; run the epoch loop, writing a checkpoint and
    the benchmark dictionary from rank 0 after every epoch.

    Returns:
        ``None``.  The evaluation-only modes (``--eval``) are not wired up in
        this variant.
    """
    rank = args.rank
    cfg.GLOBAL_RANK = rank
    args.gpus = torch.cuda.device_count()
    hvd.init()

    torch.manual_seed(999999)
    args.cuda = True
    # Horovod: pin GPU to local rank.
    torch.cuda.set_device(hvd.local_rank())

    assert args.result_dir is not None, 'need to define result_dir arg'
    logx.initialize(logdir=args.result_dir,
                    tensorboard=True,
                    hparams=vars(args),
                    global_rank=args.global_rank)
    # Set up the Arguments, Tensorboard Writer, Dataloader, Loss Fn, Optimizer
    assert_and_infer_cfg(args)
    prep_experiment(args)
    train_loader, val_loader, train_obj = datasets.setup_loaders(args)
    criterion, criterion_val = get_loss(args)

    cwd = os.getcwd()
    # NOTE(review): the world size is taken from HeAT's MPI communicator even
    # though this is the Horovod path -- confirm this matches hvd.size().
    sz = ht.MPI_WORLD.size
    filename = cwd + "/citys-hvd-checkpoint-" + str(sz) + ".pth.tar"
    if args.resume and os.path.isfile(filename):
        checkpoint = torch.load(filename, map_location=torch.device('cpu'))
        args.arch = checkpoint['arch']
        # The checkpoint written in the epoch loop below stores ``epoch + 1``
        # (i.e. the next epoch to run), so it is used as-is.  Adding another 1
        # here would silently skip one epoch on every resume.
        args.start_epoch = int(checkpoint['epoch'])
        args.restore_net = True
        args.restore_optimizer = True
        logx.msg(f"Resuming from: checkpoint={args.resume}, "
                 f"epoch {args.start_epoch}, arch {args.arch}")
    elif args.snapshot:
        if 'ASSETS_PATH' in args.snapshot:
            args.snapshot = args.snapshot.replace('ASSETS_PATH',
                                                  cfg.ASSETS_PATH)
        checkpoint = torch.load(args.snapshot,
                                map_location=torch.device('cpu'))
        args.restore_net = True
        logx.msg(f"Loading weights from: checkpoint={args.snapshot}")

    # todo: HeAT fixes -- urgent -- DDDP / optim / scheduler
    net = network.get_net(args, criterion)

    # todo: optim -> direct wrap after this, scheduler stays the same?
    optim, scheduler = get_optimizer(args, net)

    # fp16 compression is always requested for the gradient all-reduce.
    compression = hvd.Compression.fp16

    optim = hvd.DistributedOptimizer(
        optim,
        named_parameters=net.named_parameters(),
        compression=compression,
        backward_passes_per_step=1,
        op=hvd.Average,
        gradient_predivide_factor=1.0,
    )

    if args.summary:
        print(str(net))
        from thop import profile
        img = torch.randn(1, 3, 1024, 2048).cuda()
        mask = torch.randn(1, 1, 1024, 2048).cuda()
        macs, params = profile(net, inputs={'images': img, 'gts': mask})
        print0(f'macs {macs} params {params}')
        sys.exit()

    if args.restore_optimizer:
        restore_opt(optim, checkpoint)
    if args.restore_net:
        restore_net(net, checkpoint)

    if args.init_decoder:
        net.module.init_mods()

    torch.cuda.empty_cache()
    # NOTE(review): no hvd.broadcast_parameters / broadcast_optimizer_state
    # call is made here, so nothing in this function synchronises the initial
    # state across workers -- confirm that is intended.

    if args.start_epoch != 0:
        # TODO: need a loss value for the restart at a certain epoch...
        scheduler.step(args.start_epoch)

    # todo: HeAT fixes -- not urgent -- the evaluation-only modes
    # (--eval val / --eval folder) are not implemented for this variant.

    # Mixed precision is disabled: no GradScaler is handed to train().
    scaler = None
    args.amp = False

    # Number of nodes = Horovod world size / GPUs visible on this node; used
    # as a prefix for the benchmark file name and its column keys.
    nodes = str(int(hvd.size() / torch.cuda.device_count()))
    cwd = os.getcwd()
    fname = cwd + "/" + nodes + "-hvd-citys-benchmark"
    if args.resume and rank == 0 and os.path.isfile(fname + ".pkl"):
        # NOTE(review): pickle.load executes arbitrary code from the file;
        # only resume from benchmark files this script wrote itself.
        with open(fname + ".pkl", "rb") as f:
            out_dict = pickle.load(f)
    else:
        out_dict = {
            "epochs": [],
            nodes + "-avg-batch-time": [],
            nodes + "-total-train-time": [],
            nodes + "-train-loss": [],
            nodes + "-val-loss": [],
            nodes + "-val-iou": [],
            nodes + "-val-time": [],
        }
        print0("Output dict:", fname)

    for epoch in range(args.start_epoch, args.max_epoch):
        # todo: HeAT fixes -- possible conflict between processes
        update_epoch(epoch)

        if args.only_coarse:  # default: false
            train_obj.only_coarse()
            train_obj.build_epoch()
        elif args.class_uniform_pct:
            if epoch >= args.max_cu_epoch:
                train_obj.disable_coarse()
            train_obj.build_epoch()

        ls, bt, btt = train(train_loader, net, optim, epoch, scaler)

        vls, iu, vtt = validate(val_loader, net, criterion_val, optim, epoch)
        if args.lr_schedule == "plateau":
            # The plateau scheduler is driven by the training loss here,
            # not the validation loss.
            scheduler.step(ls)
        else:
            scheduler.step()

        if args.rank == 0:
            save_checkpoint({
                # "epoch" is the next epoch to run when resuming.
                "epoch": epoch + 1,
                "arch": args.arch,
                "state_dict": net.state_dict(),
                "optimizer": optim.state_dict(),
            })

        out_dict["epochs"].append(epoch)
        out_dict[nodes + "-train-loss"].append(ls)
        out_dict[nodes + "-avg-batch-time"].append(bt)
        out_dict[nodes + "-total-train-time"].append(btt)
        out_dict[nodes + "-val-loss"].append(vls)
        out_dict[nodes + "-val-iou"].append(iu)
        out_dict[nodes + "-val-time"].append(vtt)
        if args.rank == 0:
            save_obj(out_dict, fname)

    if args.rank == 0:
        print("\nRESULTS\n")
        import pandas as pd
        df = pd.DataFrame.from_dict(out_dict).set_index("epochs")
        with pd.option_context("display.max_rows", None, "display.max_columns",
                               None):
            # more options can be specified also
            print(df)
        if args.benchmarking:
            # Append this run's columns to any existing benchmark CSV.
            try:
                fulldf = pd.read_csv(cwd + "/hvd-bench-results.csv")
                fulldf = pd.concat([df, fulldf], axis=1)
            except FileNotFoundError:
                fulldf = df
            fulldf.to_csv(cwd + "/hvd-bench-results.csv")