def init_config(conf):
    # define the graph for the computation.
    conf.graph = topology.define_graph_topology(
        world=conf.world,
        world_conf=conf.world_conf,
        n_participated=conf.n_participated,
        on_cuda=conf.on_cuda,
    )
    conf.graph.rank = dist.get_rank()

    # init related to randomness on cpu.
    if not conf.same_seed_process:
        conf.manual_seed = 1000 * conf.manual_seed + conf.graph.rank
    conf.random_state = np.random.RandomState(conf.manual_seed)
    torch.manual_seed(conf.manual_seed)

    # configure cuda related.
    if conf.graph.on_cuda:
        assert torch.cuda.is_available()
        torch.cuda.manual_seed(conf.manual_seed)
        torch.cuda.set_device(conf.graph.primary_device)
        torch.backends.cudnn.enabled = True
        torch.backends.cudnn.benchmark = True
        torch.backends.cudnn.deterministic = bool(conf.train_fast)

    # init the model arch info.
    conf.arch_info = (
        param_parser.dict_parser(conf.complex_arch)
        if conf.complex_arch is not None
        else {"master": conf.arch, "worker": conf.arch}
    )
    conf.arch_info["worker"] = conf.arch_info["worker"].split(":")

    # parse the fl_aggregate scheme.
    conf._fl_aggregate = conf.fl_aggregate
    conf.fl_aggregate = (
        param_parser.dict_parser(conf.fl_aggregate)
        if conf.fl_aggregate is not None
        else conf.fl_aggregate
    )
    # expose each parsed aggregation option as a `fl_aggregate_<key>` attribute.
    if conf.fl_aggregate is not None:
        for k, v in conf.fl_aggregate.items():
            setattr(conf, f"fl_aggregate_{k}", v)

    # define checkpoint for logging (for federated learning server).
    checkpoint.init_checkpoint(conf, rank=str(conf.graph.rank))

    # configure logger.
    conf.logger = logging.Logger(conf.checkpoint_dir)

    # display the arguments' info.
    if conf.graph.rank == 0:
        logging.display_args(conf)

    # sync the processes.
    dist.barrier()
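
For context, here is a minimal sketch of the parsing that `param_parser.dict_parser` is assumed to perform in the example above; the comma/equals delimiter conventions and the example architecture names are illustrative assumptions, not taken from the library.

# Assumed behavior of param_parser.dict_parser (sketch, not the repo's code):
# turn a comma-separated "key=value" string into a plain dict, e.g.
# "master=resnet20,worker=resnet8:resnet14" ->
# {"master": "resnet20", "worker": "resnet8:resnet14"}.
def dict_parser(spec):
    parsed = {}
    for pair in spec.split(","):
        key, value = pair.split("=", 1)
        parsed[key.strip()] = value.strip()
    return parsed

With that assumption, the `conf.arch_info["worker"].split(":")` call above would yield the per-worker architecture list, e.g. `["resnet8", "resnet14"]`.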
Example #2
def maybe_resume_from_checkpoint(conf, model, optimizer, scheduler):
    if conf.resume:
        if conf.checkpoint_index is not None:
            # reload model from a specific checkpoint index.
            checkpoint_index = "_epoch_" + conf.checkpoint_index
        else:
            # reload model from the latest checkpoint.
            checkpoint_index = ""
        checkpoint_path = join(
            conf.resume,
            str(conf.graph.rank),
            "checkpoint{}.pth.tar".format(checkpoint_index),
        )
        print("try to load previous model from the path:{}".format(
            checkpoint_path))

        if isfile(checkpoint_path):
            print("=> loading checkpoint {} for {}".format(
                conf.resume, conf.graph.rank))

            # get checkpoint.
            checkpoint = torch.load(checkpoint_path, map_location="cpu")

            # restore some run-time info.
            scheduler.update_from_checkpoint(checkpoint)

            # reset path for log.
            try:
                remove_folder(conf.checkpoint_root)
            except RuntimeError as e:
                print(f"ignore the error={e}")
            conf.checkpoint_root = conf.resume
            conf.checkpoint_dir = join(conf.resume, str(conf.graph.rank))
            # restore model.
            model.load_state_dict(checkpoint["state_dict"])
            # restore optimizer.
            optimizer.load_state_dict(checkpoint["optimizer"])
            # logging.
            print("=> loaded model from path '{}' checkpointed at (epoch {})".
                  format(conf.resume, checkpoint["current_epoch"]))
            # configure logger.
            conf.logger = logging.Logger(conf.checkpoint_dir)

            # try to solve memory issue.
            del checkpoint
            torch.cuda.empty_cache()
            gc.collect()
            return
        else:
            print("=> no checkpoint found at '{}'".format(conf.resume))
Example #3
def init_config(conf):
    # define the graph for the computation.
    cur_rank = dist.get_rank() if conf.distributed else 0
    conf.graph = topology.define_graph_topology(
        graph_topology=conf.graph_topology,
        world=conf.world,
        n_mpi_process=conf.n_mpi_process,  # the number of main processes in total.
        n_sub_process=conf.n_sub_process,  # the number of subprocesses per main process.
        comm_device=conf.comm_device,
        on_cuda=conf.on_cuda,
        rank=cur_rank,
    )
    conf.is_centralized = conf.graph_topology == "complete"

    # re-configure batch_size if sub_process > 1.
    if conf.n_sub_process > 1:
        conf.batch_size = conf.batch_size * conf.n_sub_process

    # configure cuda related.
    if conf.graph.on_cuda:
        assert torch.cuda.is_available()
        torch.manual_seed(conf.manual_seed)
        torch.cuda.manual_seed(conf.manual_seed)
        torch.cuda.set_device(conf.graph.device[0])
        torch.backends.cudnn.enabled = True
        torch.backends.cudnn.benchmark = True
        torch.backends.cudnn.deterministic = bool(conf.train_fast)

    # define checkpoint for logging.
    checkpoint.init_checkpoint(conf)

    # configure logger.
    conf.logger = logging.Logger(conf.checkpoint_dir)

    # display the arguments' info.
    logging.display_args(conf)
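
To show how this second `init_config` variant might be driven, a hedged usage sketch follows; the configuration fields mirror what the function reads, while the concrete values, the use of `argparse.Namespace`, and the env-var based process-group setup are assumptions for illustration, not the repo's actual entry point.

# Minimal, hypothetical driver for the init_config variant above.
import argparse

import torch.distributed as dist


def main():
    conf = argparse.Namespace(
        distributed=True,
        graph_topology="complete",  # "complete" marks the run as centralized.
        world=None,                 # topology-specific world description (format is repo-defined).
        n_mpi_process=2,
        n_sub_process=1,
        comm_device="cuda",
        on_cuda=True,
        manual_seed=6,
        train_fast=False,
        batch_size=128,
    )
    # the process group must exist before init_config calls dist.get_rank();
    # assumes MASTER_ADDR/MASTER_PORT/RANK/WORLD_SIZE are set (e.g. by torchrun).
    dist.init_process_group(backend="nccl")
    init_config(conf)


if __name__ == "__main__":
    main()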