def initialize(self, args=None, save=True):
    """Apply optional CLI overrides, create the experiment directories, pick a device.

    Args:
        args: optional namespace-like object. Its ``name`` and ``batch_size``
            attributes, when present and not None, override ``self.name`` and
            ``self.batch_size``. ``args.config`` is read when ``save`` is true,
            so ``args`` must be supplied in that case.
        save: if true, record ``args.config`` as ``self.config_name`` and copy
            ``<config>.py`` from ``BASEPATH`` into ``self.info_dir``.

    Side effects:
        Sets ``main_dir`` (``self.expr_dir/self.name``) and its ``pth``,
        ``log``, ``info`` and ``output`` subdirectories as ``model_dir``,
        ``tb_dir``, ``info_dir`` and ``output_dir``; creates them with
        ``ensure_dirs``; sets ``self.device``.
    """
    name = getattr(args, 'name', None)
    if name is not None:
        print("args.name= ", name)
        self.name = name

    # Guard on batch_size itself. The previous check tested ``args.name``:
    # it dropped the override when no name was given, let a None batch_size
    # clobber the configured value, and raised AttributeError when ``args``
    # had batch_size but no name.
    batch_size = getattr(args, 'batch_size', None)
    if batch_size is not None:
        self.batch_size = batch_size

    self.main_dir = os.path.join(self.expr_dir, self.name)
    self.model_dir = os.path.join(self.main_dir, "pth")
    self.tb_dir = os.path.join(self.main_dir, "log")
    self.info_dir = os.path.join(self.main_dir, "info")
    self.output_dir = os.path.join(self.main_dir, "output")

    ensure_dirs([
        self.main_dir, self.model_dir, self.tb_dir, self.info_dir,
        self.output_dir
    ])

    # Use GPU ``self.cuda_id`` when CUDA is available, otherwise the CPU.
    self.device = torch.device(
        "cuda:%d" % self.cuda_id if torch.cuda.is_available() else "cpu")

    if save:
        # NOTE(review): raises AttributeError if save=True and args is None.
        self.config_name = args.config
        cfg_file = "%s.py" % self.config_name
        shutil.copy(pjoin(BASEPATH, cfg_file),
                    os.path.join(self.info_dir, cfg_file))
# Example #2 (0)
def get_all_codes(cfg, output_path):
    """Return latent codes for the train and test splits, cached on disk.

    On a cache hit, the dictionary previously saved for ``output_path`` is
    loaded and returned. Otherwise:

    - a ``Trainer`` is built from ``cfg``, moved to ``cfg.device`` and
      restored with ``trainer.resume()``;
    - it is run over the unshuffled train and test dataloaders under
      ``torch.no_grad()``;
    - the post-processed codes are written with ``np.savez_compressed`` and
      returned.

    Args:
        cfg: experiment config, passed to ``get_dataloader`` and ``Trainer``.
        output_path: cache file path. ``np.savez_compressed`` appends ``.npz``
            when the path lacks it, so both ``output_path`` and its
            ``.npz``-suffixed form are checked for an existing cache.

    Returns:
        ``{"train": codes, "test": codes}``. Each ``codes`` dict maps the keys
        produced by ``trainer.get_latent_codes`` to flattened values (see
        ``_finalize_codes``). A split whose loader yields no batches maps to
        an empty dict.
    """
    print(output_path)
    saved_path = os.fspath(output_path)
    if not saved_path.endswith(".npz"):
        saved_path += ".npz"
    for cached_path in (output_path, saved_path):
        if os.path.exists(cached_path):
            # allow_pickle is required because the payload is a pickled dict;
            # only point this at cache files written by this function.
            with np.load(cached_path, allow_pickle=True) as archive:
                return archive["data"].item()
    ensure_dirs(os.path.dirname(output_path))

    print("start over")
    # Dataloader
    train_loader = get_dataloader(cfg, 'train', shuffle=False)
    test_loader = get_dataloader(cfg, 'test', shuffle=False)

    # Trainer
    trainer = Trainer(cfg)
    trainer.to(cfg.device)
    trainer.resume()

    with torch.no_grad():
        vis_dicts = {}
        for phase, loader in (("train", train_loader), ("test", test_loader)):
            vis_dicts[phase] = _finalize_codes(_collect_codes(trainer, loader), phase)

    np.savez_compressed(output_path, data=vis_dicts)
    return vis_dicts


def _collect_codes(trainer, loader):
    """Group ``trainer.get_latent_codes`` outputs by key, one list entry per batch."""
    collected = {}
    for batch in loader:
        for key, value in trainer.get_latent_codes(batch).items():
            collected.setdefault(key, []).append(value)
    return collected


def _finalize_codes(codes, phase):
    """Flatten per-batch outputs in ``codes`` (in place) and return it.

    The ``"meta"`` entry is a list of per-batch dicts of sequences. It becomes
    one flat list per sub-key, with every item passed through ``to_float``.

    Every other entry is processed as follows:

    - concatenated along dim 0 with ``torch.cat``;
    - moved to the CPU and converted to NumPy;
    - reshaped to ``(num_rows, -1)``;
    - passed through ``to_float``.
    """
    for key, batches in list(codes.items()):
        if phase == "test" and key == "content_code":
            # NOTE(review): test content codes stay as the raw list of batch
            # outputs and are still saved; confirm whether they should be dropped.
            continue
        if key == "meta":
            codes[key] = {
                sub_key: [to_float(item) for batch in batches for item in batch[sub_key]]
                for sub_key in batches[0].keys()
            }
        else:
            stacked = torch.cat(batches, 0).cpu().numpy()
            codes[key] = to_float(stacked.reshape(stacked.shape[0], -1))
    return codes
# Example #3 (0)
def get_demo_plots(data, output_path):
    """Produce the demo embedding plots for style and content codes.

    data: {"train": dict_train, "test": dict_test}

    ``dict_train`` must provide:

    - ``"style2d_adain"``, ``"style3d_adain"``, ``"style3d_code"`` and
      ``"content_code"``;
    - ``"meta"`` with per-sample ``"style"``, ``"content"`` and ``"phase"``
      lists.

    ``dict_test`` must provide ``"style2d_adain"`` and ``"style3d_adain"``.

    Each plot helper receives ``pjoin(output_path, <title>)``, presumably as
    its output file path (TODO confirm against the plot helpers).
    """
    ensure_dirs(output_path)

    def fig_title(title):
        return pjoin(output_path, title)

    train = data["train"]
    meta = train["meta"]
    style_labels = meta["style"]

    # Order: 2D-train, 2D-test, 3D-train, 3D-test, embedded together.
    adain_raw = [data[phase][key]
                 for key in ("style2d_adain", "style3d_adain")
                 for phase in ("train", "test")]
    adain_tsne = calc_many_blas(adain_raw, calc_tsne)
    # Overlay the 2D-train (index 0) and 3D-train (index 2) embeddings.
    plot2D_overlay([adain_tsne[0], adain_tsne[2]],
                   [style_labels, style_labels], [1.0, 0.5],
                   fig_title('joint_embedding_adain_tsne'))

    for key in ("style3d_code", "style3d_adain"):
        tsne_code = calc_tsne(train[key])
        plot2D(tsne_code, style_labels, fig_title(f'{key}_tsne'))

    content_code_pca = calc_pca(train["content_code"])

    walk_indices = [i for i, content in enumerate(meta["content"]) if content == "walk"]
    # dtype=int keeps the fancy index valid when no sample is "walk":
    # np.array([]) defaults to float64, which NumPy rejects as an index.
    walk_code = content_code_pca[np.array(walk_indices, dtype=int)]
    phase_labels = [meta["phase"][i] for i in walk_indices]
    plot2D_phase(walk_code, phase_labels, fig_title('content_by_phase'))

    plot2D(content_code_pca, style_labels, fig_title('content_by_style'))
# Example #4 (0)
def get_all_plots(data,
                  output_path,
                  writers,
                  iter,
                  summary=True,
                  style_cluster_protocols=('pca'),
                  separate_compute=False):
    """
    data: {"train": dict_train, "test": dict_test}
    dict_train: {"style2d_code": blabla, etc.}
    separate_compute: compute t-SNE for 2D & 3D separately
    """
    ensure_dirs(output_path)

    def fig_title(title):
        return pjoin(output_path, title)

    def add_fig(fig, title, phase):
        if summary:
            writers[phase].add_figure(title, fig, global_step=iter)

    keys = data["train"].keys()
    has2d = "style2d_code" in keys
    has3d = "style3d_code" in keys

    # style codes & adain params
    for suffix in ["_code", "_adain"]:

        codes_raw = []
        titles = []
        phases = []

        data_keys = []
        if has2d: data_keys.append("style2d" + suffix)
        if has3d: data_keys.append("style3d" + suffix)
        for key in data_keys:
            for phase in ["train", "test"]:
                codes_raw.append(data[phase][key])
                titles.append(f'{phase}_{key}')
                phases.append(phase)

        # calc tsne with style2/3d, train/test altogether
        for name, protocol in zip(['pca', 'tsne'], [calc_pca, calc_tsne]):
            if name not in style_cluster_protocols:
                continue
            style_codes = calc_many_blas(codes_raw, protocol)
            fig = plot2D_overlay([style_codes[0], style_codes[2]], [
                data["train"]["meta"]["style"], data["train"]["meta"]["style"]
            ], [1.0, 0.5], fig_title(f'joint_embedding_{name}{suffix}'))
            add_fig(fig, f'joint_embedding_{name}{suffix}', "train")

            for i, (code, phase,
                    title) in enumerate(zip(style_codes, phases, titles)):
                if separate_compute:
                    code = protocol(codes_raw[i])
                for label_type in ["style", "content"]:
                    fig = plot2D(code, data[phase]["meta"][label_type],
                                 fig_title(f'{title}_{name}_{label_type}'))
                    add_fig(fig, f'{title}_{name}_{label_type}', phase)

    # content codes (train only)
    content_code_pca = calc_pca(data["train"]["content_code"])

    for label in ["style", "content", "phase"]:
        if label == "phase":
            indices = [
                i for i in range(len(data["train"]["meta"]["content"]))
                if data["train"]["meta"]["content"][i] == "walk"
            ]
            walk_code = content_code_pca[np.array(indices)]
            phase_labels = [data["train"]["meta"]["phase"][i] for i in indices]
            fig = plot2D_phase(walk_code, phase_labels,
                               fig_title(f'content_by_{label}'))
        else:
            fig = plot2D(content_code_pca, data["train"]["meta"][label],
                         fig_title(f'content_by_{label}'))
        add_fig(fig, f'content_by_{label}', "train")
    """