Ejemplo n.º 1
0
    def update_cfg(self, cfg):
        """Layer ``cfg`` over the built-in DenseDiscriminator_v1 defaults.

        Args:
            cfg: configuration object. It is merged only when it exposes a
                truthy ``update_cfg`` attribute. Anything else is returned
                as-is, including ``None`` or a plain ``dict``, which has no
                such attribute.

        Returns:
            Either ``cfg`` unchanged, or ``update_config(defaults, cfg)``,
            where ``defaults`` is the YAML below parsed into an ``EasyDict``.
            ``update_config`` is defined elsewhere. Presumably values from
            ``cfg`` take precedence over the defaults; confirm.
        """
        if getattr(cfg, 'update_cfg', False):
            # PyYAML (YAML 1.1) treats only null/Null/NULL/~/empty as null.
            # The bare key ``None`` under ``cfg_ops`` is therefore loaded as
            # the string 'None'.
            defaults_yaml = """
            name: 'DenseDiscriminator_v1'
            ch: 512
            init_type: 'orth'
            cfg_downsample:
              name: "AvgPool2d"
            num_cells: 3
            cfg_cell:
              name: "DenseBlock"
              n_nodes: 4
              cfg_mix_layer:
                name: "MixedLayer"
            cfg_ops:
              None:
                name: "D2None"
              Identity:
                name: "Identity"
              Conv2d_3x3:
                name: "Conv2dAct"
                cfg_conv:
                  name: "SNConv2d"
                  kernel_size: 3
                  padding: 1
                cfg_act:
                  name: "ReLU"
    """
            defaults = EasyDict(yaml.safe_load(defaults_yaml))
            return update_config(defaults, cfg)
        return cfg
Ejemplo n.º 2
0
    def update_cfg(self, cfg):
        """Return ``cfg`` combined with the default DenseCell settings.

        When ``cfg`` lacks a truthy ``update_cfg`` attribute, it is returned
        untouched. Otherwise the YAML text below is parsed, wrapped in an
        ``EasyDict``, and passed with ``cfg`` to ``update_config``. That
        helper is defined elsewhere and its merge precedence is not shown
        here; presumably ``cfg`` overrides the defaults.

        The bare YAML key ``None`` is loaded by PyYAML as the string 'None',
        not as a null key.
        """
        enabled = getattr(cfg, 'update_cfg', False)
        if not enabled:
            return cfg

        text = """
      name: "DenseCell"
      n_nodes: 3
      cfg_mix_layer:
        name: "MixedLayer"
      cfg_ops:
        Identity:
          name: "Identity"
        Conv2d_3x3:
          name: "ActConv2d"
          cfg_act:
            name: "ReLU"
          cfg_conv:
            name: "Conv2d"
            kernel_size: 3
            padding: 1
        None:
          name: "D2None"
        
      """
        raw = yaml.safe_load(text)
        defaults = EasyDict(raw)
        merged = update_config(defaults, cfg)
        return merged
Ejemplo n.º 3
0
    def update_cfg(self, cfg):
        """Merge ``cfg`` onto the BigGANDisc defaults when requested.

        Only a ``cfg`` with a truthy ``update_cfg`` attribute is merged,
        via ``update_config(EasyDict(defaults), cfg)``. Any other value is
        returned unchanged.

        NOTE(review): ``img_size`` and ``n_classes`` default to the literal
        strings "kwargs['img_size']" and "kwargs['n_classes']". Nothing here
        evaluates them. Presumably the caller resolves them later; confirm.
        """
        if getattr(cfg, 'update_cfg', False):
            biggan_yaml = """
      name: 'BigGANDisc'
      img_size: "kwargs['img_size']"
      n_classes: "kwargs['n_classes']"
      ch: 8
      use_cdisc: true
    """
            return update_config(EasyDict(yaml.safe_load(biggan_yaml)), cfg)
        return cfg
Ejemplo n.º 4
0
    def update_cfg(self, cfg):
        """Fill in StyleLayer defaults underneath ``cfg``.

        Args:
            cfg: configuration object. Merging happens only if it exposes a
                truthy ``update_cfg`` attribute.

        Returns:
            ``cfg`` itself when no merge is requested. Otherwise the value
            of ``update_config(defaults, cfg)``, where ``defaults`` is an
            ``EasyDict`` built from the YAML below.
        """
        should_merge = getattr(cfg, 'update_cfg', False)
        if not should_merge:
            return cfg

        style_layer_yaml = """
        name: "StyleLayer"
        z_dim: 128
        n_mlp: 1
        num_features: 256

      """
        defaults = EasyDict(yaml.safe_load(style_layer_yaml))
        return update_config(defaults, cfg)
Ejemplo n.º 5
0
  def update_cfg(cfg):
    """Layer ``cfg`` over the default ``GAN_metric`` (TF FID/IS) settings.

    ``cfg`` is returned unchanged unless it has a truthy ``update_cfg``
    attribute. In that case the result is
    ``update_config(EasyDict(defaults), cfg)``.

    NOTE(review): this takes no ``self``/``cls``. Presumably it is a
    ``@staticmethod`` or is called through the class; confirm.
    """
    if getattr(cfg, 'update_cfg', False):
      # The ``{dataset_name}`` and ``{img_size}`` placeholders in tf_fid_stat
      # are not formatted here. Presumably whoever consumes the path fills
      # them in; confirm.
      metric_yaml = """
              GAN_metric:
                name: TFFIDISScore
                tf_fid_stat: "datasets/tf_fid_stat_{dataset_name}_{img_size}.npz"
                tf_inception_model_dir: "datasets/tf_inception_model"
                num_inception_images: 50000

              """
      return update_config(EasyDict(yaml.safe_load(metric_yaml)), cfg)
    return cfg
Ejemplo n.º 6
0
def update_nni_config_file(nni_config_file, update_nni_cfg_str):
  """Write a copy of an NNI YAML config file with overrides applied.

  Args:
    nni_config_file: path to the base NNI config (YAML).
    update_nni_cfg_str: YAML text describing the overrides. It is parsed
      with ``yaml.safe_load`` and passed to ``update_config`` together with
      the loaded base config.

  Returns:
    Path of the written file. This is ``nni_config_file`` with ``_updated``
    inserted before its extension, in the same directory. For example,
    ``exp/nni.yaml`` becomes ``exp/nni_updated.yaml``.
  """
  # Local import: the file's top-level import block is not visible here.
  import os.path

  update_nni_cfg = yaml.safe_load(update_nni_cfg_str)

  # YAML is UTF-8 by spec; don't depend on the platform's locale encoding.
  with open(nni_config_file, 'r', encoding='utf-8') as f:
    nni_cfg = yaml.safe_load(f)

  nni_cfg = update_config(nni_cfg, update_nni_cfg)
  # Presumably turns EasyDict nodes back into plain dicts so yaml.dump
  # emits plain mappings; confirm against convert_easydict_to_dict.
  nni_cfg = convert_easydict_to_dict(nni_cfg)
  logging.getLogger('tl').info('\nnni config:\n %s', get_dict_str(nni_cfg))

  # splitext only splits at the last dot of the final path component.
  # The previous str.split('.') approach broke on dotted directories:
  # './configs/nni.yaml' became the absolute '/configs/nni_updated.yaml',
  # and a path without any dot raised IndexError.
  root, ext = os.path.splitext(nni_config_file)
  updated_config_file = root + '_updated' + ext
  with open(updated_config_file, 'w', encoding='utf-8') as f:
    yaml.dump(nni_cfg, f, indent=2, sort_keys=False)
  return updated_config_file
Ejemplo n.º 7
0
    def update_cfg(self, cfg):
        """Combine ``cfg`` with the StyleV2Conv defaults on request.

        A ``cfg`` without a truthy ``update_cfg`` attribute passes straight
        through. Otherwise the defaults below (including the nested
        ``cfg_modconv`` section) are parsed into an ``EasyDict`` and merged
        via ``update_config(defaults, cfg)``.
        """
        if not getattr(cfg, 'update_cfg', False):
            return cfg

        conv_defaults_yaml = """
        name: "StyleV2Conv"
        cfg_modconv:
          name: "ModulatedConv2d"
          kernel_size: 3
          style_dim: 192

      """
        parsed_defaults = yaml.safe_load(conv_defaults_yaml)
        return update_config(EasyDict(parsed_defaults), cfg)
Ejemplo n.º 8
0
    def update_cfg(self, cfg):
        """Merge ``cfg`` onto the default MixedLayerWithArc cell settings.

        Args:
            cfg: configuration object. It is merged only when it carries a
                truthy ``update_cfg`` attribute.

        Returns:
            ``cfg`` unchanged, or ``update_config(defaults, cfg)`` with
            ``defaults`` parsed from the YAML below into an ``EasyDict``.

        NOTE(review): the default ``name`` is the empty string.
        Presumably callers are expected to supply it via ``cfg``; confirm.
        """
        if getattr(cfg, 'update_cfg', False):
            cell_yaml = """
      name: ""
      n_nodes: 4
      cfg_mix_layer:
        name: "MixedLayerWithArc"
      cfg_ops:        
        Conv2d_1x1:
          name: "Conv2d"
          kernel_size: 1
          padding: 0
        Conv2d_3x3:
          name: "Conv2d"
          kernel_size: 3
          padding: 1

      """
            cell_defaults = EasyDict(yaml.safe_load(cell_yaml))
            return update_config(cell_defaults, cfg)
        return cfg