Example #1
0
def init_checkpoint(conf):
    """Set up the per-rank checkpoint directory for this run.

    Derives ``conf.checkpoint_root`` from the run configuration, appends the
    process rank to obtain ``conf.checkpoint_dir``, and creates that directory
    on disk. Also normalizes the comma-separated ``conf.save_some_models``
    string into a list, when it is set.
    """
    # An unset experiment contributes an empty path segment.
    experiment_segment = "" if conf.experiment is None else conf.experiment
    conf.checkpoint_root = join(
        conf.checkpoint, conf.data, conf.arch, experiment_segment, conf.timestamp
    )
    # Each distributed rank checkpoints into its own subdirectory.
    conf.checkpoint_dir = join(conf.checkpoint_root, str(conf.graph.rank))

    if conf.save_some_models is not None:
        conf.save_some_models = conf.save_some_models.split(",")

    # Create the directory tree if it does not exist yet.
    build_dirs(conf.checkpoint_dir)
def init_checkpoint(conf, rank=None):
    """Initialize checkpoint directories for a training run.

    Builds ``conf.checkpoint_root`` from the run configuration. When ``rank``
    is None (main process) only the root directory is created; otherwise a
    per-rank subdirectory is created and stored in ``conf.checkpoint_dir``.
    Also splits the comma-separated ``conf.save_some_models`` into a list.

    Args:
        conf: run configuration namespace; mutated in place.
        rank: distributed rank of this process, or None for the main process.
    """
    # init checkpoint_root for the main process.
    conf.checkpoint_root = join(
        conf.checkpoint,
        conf.data,
        conf.arch,
        # Guard against an unset experiment (consistent with the sibling
        # init_checkpoint variant); join() would raise on None.
        conf.experiment if conf.experiment is not None else "",
        conf.timestamp + get_checkpoint_folder_name(conf),
    )
    if conf.save_some_models is not None:
        conf.save_some_models = conf.save_some_models.split(",")

    if rank is None:
        # Main process: only the root directory is needed.
        build_dirs(conf.checkpoint_root)
    else:
        # str(rank): join() raises TypeError when handed an int rank.
        conf.checkpoint_dir = join(conf.checkpoint_root, str(rank))
        build_dirs(conf.checkpoint_dir)
Example #3
0
def init_distributed_world(conf, backend):
    """Initialize the torch.distributed process group for ``backend``.

    MPI initializes directly; NCCL/Gloo rendezvous through a shared file
    created under the checkpoint tmp folder for this run's timestamp.

    Raises:
        NotImplementedError: for any backend other than mpi/nccl/gloo.
    """
    if backend == "mpi":
        dist.init_process_group("mpi")
    elif backend in ("nccl", "gloo"):
        # Shared-file rendezvous: every process must see the same path.
        tmp_dir = os.path.join(conf.checkpoint, "tmp", conf.timestamp)
        op_paths.build_dirs(tmp_dir)
        rendezvous_file = os.path.join(tmp_dir, "dist_init")

        torch.distributed.init_process_group(
            backend=backend,
            init_method="file://" + os.path.abspath(rendezvous_file),
            timeout=datetime.timedelta(seconds=120),
            world_size=conf.n_mpi_process,
            rank=conf.local_rank,
        )
    else:
        raise NotImplementedError
Example #4
0
def init_distributed_world(conf, backend):
    """Initialize the torch.distributed process group for ``backend``.

    Variant that rendezvouses over TCP instead of a shared file: the master
    address is read from the ``MASTER_PORT_29500_TCP_ADDR`` environment
    variable and the port is hard-coded to 60000 (presumably matching a
    cluster/launcher convention — confirm against the deployment setup).

    Raises:
        NotImplementedError: for any backend other than mpi/nccl/gloo.
        KeyError: if ``MASTER_PORT_29500_TCP_ADDR`` is not set.
    """
    if backend == "mpi":
        dist.init_process_group("mpi")
    elif backend == "nccl" or backend == "gloo":
        # init the process group.
        # NOTE: _tmp_path/dist_init are created but unused by the TCP
        # init_method below; kept as-is to preserve on-disk side effects.
        _tmp_path = os.path.join(conf.checkpoint, "tmp", conf.timestamp)
        op_paths.build_dirs(_tmp_path)

        dist_init_file = os.path.join(_tmp_path, "dist_init")
        
        torch.distributed.init_process_group(
            backend=backend,
            init_method='tcp://{}:60000'.format(os.environ['MASTER_PORT_29500_TCP_ADDR']),
            timeout=datetime.timedelta(seconds=120),
            world_size=conf.n_mpi_process,
            # local_rank may arrive as a string (e.g. from argparse/env).
            rank=int(conf.local_rank),
        )
    else:
        raise NotImplementedError
def generate_data(random_state, batch_size, num_images_per_classes, device,
                  output_path):
    """Generate a synthetic ImageNet-style dataset with a pretrained BigGAN.

    For each of the 1000 ImageNet classes, samples truncated-normal noise,
    generates ``num_images_per_classes`` images (rounded down to a multiple
    of ``batch_size``), and saves them as JPEG files under
    ``{output_path}/{class_idx}/{image_id}``.

    Args:
        random_state: unused here; kept for interface compatibility.
        batch_size: number of images generated per forward pass.
        num_images_per_classes: target number of images per class.
        device: torch device the generator runs on.
        output_path: root directory for the generated dataset.
    """
    generator_model = BigGAN.from_pretrained("biggan-deep-128",
                                             cache_dir=os.path.join(
                                                 "./data/checkpoint",
                                                 "cached_model"))
    generator_model = generator_model.to(device)
    # Inference only: switch off train-mode behavior (e.g. batch-norm
    # statistic updates) so generation is deterministic given the inputs.
    generator_model.eval()

    # BigGAN truncation trick: lower values trade diversity for fidelity.
    truncation = 0.4
    op_paths.build_dirs(f"{output_path}")

    # Loop-invariant: floor division avoids the float-precision hazard of
    # int(a / b) for large integer counts.
    num_batches = num_images_per_classes // batch_size

    for class_idx in range(1000):
        _id = 0
        op_paths.build_dirs(f"{output_path}/{class_idx}")
        for _ in range(num_batches):
            class_vector = one_hot_from_int(class_idx, batch_size=batch_size)
            noise_vector = truncated_noise_sample(truncation=truncation,
                                                  batch_size=batch_size)
            noise_vector = torch.from_numpy(noise_vector).to(device)
            class_vector = torch.from_numpy(class_vector).to(device)

            # Generate images; clamp into [-1, 1], the range save_image
            # rescales from when normalize=True.
            with torch.no_grad():
                generated_images = generator_model(noise_vector, class_vector,
                                                   truncation).clamp(min=-1,
                                                                     max=1)

            for image in generated_images:
                torchvision.utils.save_image(
                    image,
                    fp=f"{output_path}/{class_idx}/{_id}",
                    format="JPEG",
                    scale_each=True,
                    normalize=True,
                )
                _id += 1
        print(f"finished {class_idx + 1}/1000.")