Example #1
def print_debug(feed, layers, file=None):
    if file is not None:
        PRINT(file, feed.size())
    else:
        print(feed.size())
    for layer in layers:
        try:
            feed = layer(feed)
        except BaseException:
            raise BaseException(
                "Type of layer {} not compatible with input {}.".format(
                    layer, feed))
        if isinstance(layer, (nn.Conv2d, nn.ConvTranspose2d, nn.Linear,
                              ResidualBlock, SpectralNormalization,
                              nn.Upsample)):
            _str = '{}, {}'.format(str(layer).split('(')[0], feed.size())
            if file is not None:
                PRINT(file, _str)
            else:
                print(_str)
    if file is not None:
        PRINT(file, ' ')
    else:
        print(' ')
    return feed
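
Every example on this page calls a PRINT(file, string) helper (presumably from misc.utils) instead of the builtin print. Its definition is not shown here; a minimal sketch, assuming it simply mirrors each message to the open log file and to stdout, might look like this:

def PRINT(file, string):
    # Write the message to the log file and echo it to the console.
    print(string, file=file)
    file.flush()
    print(string)
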
Example #2
def _PRINT(config):
    string = '------------ Options -------------'
    PRINT(config.log, string)
    for k, v in sorted(vars(config).items()):
        string = '%s: %s' % (str(k), str(v))
        PRINT(config.log, string)
    string = '-------------- End ---------------'
    PRINT(config.log, string)
Example #3
 def histogram(self):
     values = {key: 0 for key in self.attr2idx.keys()}
     for line in self.lines:
         key = line.split('/')[-2].split('_')[1]
         values[key] += 1
     total = 0
     with open('datasets/{}_histogram_attributes.txt'.format(self.name),
               'w') as f:
         for key, value in sorted(values.items(),
                                  key=lambda kv: (kv[1], kv[0]),
                                  reverse=True):
             total += value
             PRINT(f, '{} {}'.format(key, value))
         PRINT(f, 'TOTAL {}'.format(total))
Example #4
 def histogram(self):
     values = np.array([int(i) for i in self.lines[1][1:]]) * 0
     for line in self.lines[1:]:
         value = np.array([int(i) for i in line[1:]]).clip(min=0)
         values += value
     dict_ = {}
     for key, value in zip(self.lines[0][1:], values):
         dict_[key] = value
     total = 0
     with open('datasets/{}_histogram_attributes.txt'.format(self.name),
               'w') as f:
         for key, value in sorted(dict_.items(),
                                  key=lambda kv: (kv[1], kv[0]),
                                  reverse=True):
             total += value
             PRINT(f, '{} {}'.format(key, value))
         PRINT(f, 'TOTAL {}'.format(total))
Example #5
    def INCEPTION_REAL(self):
        from misc.utils import load_inception
        from scipy.stats import entropy
        net = load_inception()
        net = to_cuda(net)
        net.eval()
        inception_up = nn.Upsample(size=(299, 299), mode='bilinear')
        mode = 'Real'
        data_loader = self.data_loader
        file_name = 'scores/Inception_{}.txt'.format(mode)

        PRED_IS = {i: [] for i in range(len(data_loader.dataset.labels[0]))}
        IS = {i: [] for i in range(len(data_loader.dataset.labels[0]))}

        for i, (real_x, org_c, files) in tqdm(
                enumerate(data_loader),
                desc='Calculating CIS/IS - {}'.format(file_name),
                total=len(data_loader)):
            label = torch.max(org_c, 1)[1][0]
            real_x = to_var((real_x + 1) / 2., volatile=True)
            pred = to_data(F.softmax(net(inception_up(real_x)), dim=1),
                           cpu=True).numpy()
            PRED_IS[int(label)].append(pred)

        for label in range(len(data_loader.dataset.labels[0])):
            PRED_IS[label] = np.concatenate(PRED_IS[label], 0)
            # prior is computed from all outputs
            py = np.sum(PRED_IS[label], axis=0)
            for j in range(PRED_IS[label].shape[0]):
                pyx = PRED_IS[label][j, :]
                IS[label].append(entropy(pyx, py))

        total_is = []
        file_ = open(file_name, 'w')
        for label in range(len(data_loader.dataset.labels[0])):
            _is = np.exp(np.mean(IS[label]))
            total_is.append(_is)
            PRINT(file_, "Label {}".format(label))
            PRINT(file_, "Inception Score: {:.4f}".format(_is))
        PRINT(file_, "")
        PRINT(
            file_, "[TOTAL] Inception Score: {:.4f} +/- {:.4f}".format(
                np.array(total_is).mean(),
                np.array(total_is).std()))
        file_.close()
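
The examples above also rely on to_var, to_cuda, and to_data helpers and on the volatile=True flag, which belongs to the pre-0.4 PyTorch Variable API. The helpers are not defined on this page; a hypothetical modern-PyTorch approximation could be:

import torch

def to_cuda(x):
    # Move a tensor or module onto the GPU when one is available.
    return x.cuda() if torch.cuda.is_available() else x

def to_var(x, volatile=False, no_cuda=False):
    # Prepare an input for the network; detach() stands in for the legacy
    # volatile=True (no-autograd) behaviour of old PyTorch Variables.
    x = x if no_cuda else to_cuda(x)
    return x.detach()

def to_data(x, cpu=False):
    # Take a network output off the autograd graph, optionally on the CPU.
    x = x.detach()
    return x.cpu() if cpu else x
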
Example #6
 def debug(self):
     PRINT(self.config.log, '-- Generator:')
     feed = to_var(torch.ones(1, self.color_dim, self.image_size,
                              self.image_size),
                   volatile=True,
                   no_cuda=True)
     features = self.print_debug(feed, self.main)
     self.print_debug(features, self.fake)
     self.print_debug(features, self.attn)
Example #7
 def debug(self):
     feed = to_var(
         torch.ones(1, self.color_dim, self.image_size, self.image_size),
         volatile=True,
         no_cuda=True)
     PRINT(self.config.log, '-- StyleEncoder:')
     features = self.print_debug(feed, self.main)
     fc_in = features.view(features.size(0), -1)
     self.print_debug(fc_in, self.fc)
Example #8
 def debug(self):
     feed = to_var(torch.ones(1, self.color_dim, self.image_size,
                              self.image_size),
                   volatile=True,
                   no_cuda=True)
     modelList = zip(self.cnns_main, self.cnns_src, self.cnns_aux)
     for idx, outs in enumerate(modelList):
         PRINT(self.config.log, '-- MultiDiscriminator ({}):'.format(idx))
         features = self.print_debug(feed, outs[-3])
         self.print_debug(features, outs[-2])
         self.print_debug(features, outs[-1]).view(feed.size(0), -1)
         feed = self.downsample(feed)
Example #9
    def LPIPS_REAL(self):
        from misc.utils import compute_lpips
        data_loader = self.data_loader
        model = None
        file_name = 'scores/{}_Attr_{}_LPIPS.txt'.format(
            self.config.dataset_fake, self.config.ALL_ATTR)
        if os.path.isfile(file_name):
            print(file_name)
            for line in open(file_name).readlines():
                print(line.strip())
            return

        DISTANCE = {
            i: []
            for i in range(len(data_loader.dataset.labels[0]) + 1)
        }  # 0:[], 1:[], 2:[]}
        for i, (real_x, org_c,
                files) in tqdm(enumerate(data_loader),
                               desc='Calculating LPIPS - {}'.format(file_name),
                               total=len(data_loader)):
            for label in range(len(data_loader.dataset.labels[0])):
                for j, (_real_x, _org_c, _files) in enumerate(data_loader):
                    if j <= i:
                        continue
                    _org_label = torch.max(_org_c, 1)[1][0]
                    for _label in range(len(data_loader.dataset.labels[0])):
                        if _org_label == _label:
                            continue
                        distance, model = compute_lpips(real_x,
                                                        _real_x,
                                                        model=model)
                        DISTANCE[len(data_loader.dataset.labels[0])].append(
                            distance[0])
                        if label == _label:
                            DISTANCE[_label].append(distance[0])

        file_ = open(file_name, 'w')
        DISTANCE = {k: np.array(v) for k, v in DISTANCE.items()}
        for key, values in DISTANCE.items():
            if key == len(data_loader.dataset.labels[0]):
                mode = 'All'
            else:
                mode = chr(65 + key)
            PRINT(
                file_, "LPIPS {}: {} +/- {}".format(mode, values.mean(),
                                                    values.std()))
        file_.close()
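
compute_lpips is imported from misc.utils but never defined on this page. A hypothetical sketch of such a wrapper, assuming it builds on the lpips package and lazily constructs the metric network so that callers can reuse it through the model argument, could be:

import torch
import lpips

def compute_lpips(img0, img1, model=None):
    # Build the AlexNet-based LPIPS network once and reuse it across calls.
    if model is None:
        model = lpips.LPIPS(net='alex')
    with torch.no_grad():
        # lpips expects image tensors scaled to [-1, 1].
        dist = model(img0, img1)
    # Return a flat numpy array of distances plus the (cached) model.
    return dist.view(-1).cpu().numpy(), model
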
Example #10
    def INCEPTION(self):
        from misc.utils import load_inception
        from scipy.stats import entropy
        n_styles = 20
        net = load_inception()
        net = to_cuda(net)
        net.eval()
        self.G.eval()
        inception_up = nn.Upsample(size=(299, 299), mode='bilinear')
        mode = 'SMIT'
        data_loader = self.data_loader
        file_name = 'scores/Inception_{}.txt'.format(mode)

        PRED_IS = {i: []
                   for i in range(len(data_loader.dataset.labels[0]))
                   }  # 0:[], 1:[], 2:[]}
        CIS = {i: [] for i in range(len(data_loader.dataset.labels[0]))}
        IS = {i: [] for i in range(len(data_loader.dataset.labels[0]))}

        for i, (real_x, org_c, files) in tqdm(
                enumerate(data_loader),
                desc='Calculating CIS/IS - {}'.format(file_name),
                total=len(data_loader)):
            PRED_CIS = {
                i: []
                for i in range(len(data_loader.dataset.labels[0]))
            }  # 0:[], 1:[], 2:[]}
            org_label = torch.max(org_c, 1)[1][0]
            real_x = real_x.repeat(n_styles, 1, 1, 1)  # .unsqueeze(0)
            real_x = to_var(real_x, volatile=True)

            target_c = (org_c * 0).repeat(n_styles, 1)
            target_c = to_var(target_c, volatile=True)
            for label in range(len(data_loader.dataset.labels[0])):
                if org_label == label:
                    continue
                target_c *= 0
                target_c[:, label] = 1
                style = to_var(self.G.random_style(n_styles),
                               volatile=True) if mode == 'SMIT' else None

                fake = (self.G(real_x, target_c, style)[0] + 1) / 2

                pred = to_data(F.softmax(net(inception_up(fake)), dim=1),
                               cpu=True).numpy()
                PRED_CIS[label].append(pred)
                PRED_IS[label].append(pred)

                # CIS for each image
                PRED_CIS[label] = np.concatenate(PRED_CIS[label], 0)
                py = np.sum(
                    PRED_CIS[label], axis=0
                )  # prior is computed from outputs given a specific input
                for j in range(PRED_CIS[label].shape[0]):
                    pyx = PRED_CIS[label][j, :]
                    CIS[label].append(entropy(pyx, py))

        for label in range(len(data_loader.dataset.labels[0])):
            PRED_IS[label] = np.concatenate(PRED_IS[label], 0)
            py = np.sum(PRED_IS[label],
                        axis=0)  # prior is computed from all outputs
            for j in range(PRED_IS[label].shape[0]):
                pyx = PRED_IS[label][j, :]
                IS[label].append(entropy(pyx, py))

        total_cis = []
        total_is = []
        file_ = open(file_name, 'w')
        for label in range(len(data_loader.dataset.labels[0])):
            cis = np.exp(np.mean(CIS[label]))
            total_cis.append(cis)
            _is = np.exp(np.mean(IS[label]))
            total_is.append(_is)
            PRINT(file_, "Label {}".format(label))
            PRINT(file_, "Inception Score: {:.4f}".format(_is))
            PRINT(file_, "conditional Inception Score: {:.4f}".format(cis))
        PRINT(file_, "")
        PRINT(
            file_, "[TOTAL] Inception Score: {:.4f} +/- {:.4f}".format(
                np.array(total_is).mean(),
                np.array(total_is).std()))
        PRINT(
            file_,
            "[TOTAL] conditional Inception Score: {:.4f} +/- {:.4f}".format(
                np.array(total_cis).mean(),
                np.array(total_cis).std()))
        file_.close()
Example #11
    def LPIPS_MULTIMODAL(self):
        from misc.utils import compute_lpips

        torch.manual_seed(1)
        torch.cuda.manual_seed(1)

        data_loader = self.data_loader
        model = None
        n_images = 20

        file_name = os.path.join(
            self.name.replace('{}.pth', 'LPIPS_MULTIMODAL.txt'))
        if os.path.isfile(file_name):
            print(file_name)
            for line in open(file_name).readlines():
                print(line.strip())
            return

        # DISTANCE = {0:[], 1:[], 2:[]}
        DISTANCE = {
            i: []
            for i in range(len(data_loader.dataset.labels[0]) + 1)
        }  # 0:[], 1:[], 2:[]}
        print(file_name)
        for i, (real_x, org_c, files) in tqdm(enumerate(data_loader),
                                              desc='Calculating LPIPS ',
                                              total=len(data_loader)):
            org_label = torch.max(org_c, 1)[1][0]
            for label in range(len(data_loader.dataset.labels[0])):
                if org_label == label:
                    continue
                target_c = org_c * 0
                target_c[:, label] = 1
                target_c = target_c.repeat(n_images, 1)
                real_x_var = to_var(real_x.repeat(n_images, 1, 1, 1),
                                    volatile=True)
                target_c = to_var(target_c, volatile=True)
                style = to_var(self.G.random_style(n_images), volatile=True)
                fake_x = self.G(real_x_var, target_c, stochastic=style)[0].data
                fake_x = [f.unsqueeze(0) for f in fake_x]
                _DISTANCE = []
                for ii, fake0 in enumerate(fake_x):
                    for jj, fake1 in enumerate(fake_x):
                        if jj <= ii:
                            continue
                        distance, model = compute_lpips(fake0,
                                                        fake1,
                                                        model=model)
                        _DISTANCE.append(distance[0])
                DISTANCE[len(data_loader.dataset.labels[0])].append(
                    np.array(_DISTANCE).mean())
                DISTANCE[label].append(DISTANCE[len(
                    data_loader.dataset.labels[0])][-1])

        file_ = open(file_name, 'w')
        DISTANCE = {k: np.array(v) for k, v in DISTANCE.items()}
        for key, values in DISTANCE.items():
            if key == len(data_loader.dataset.labels[0]):
                mode = 'All'
            else:
                mode = chr(65 + key)
            PRINT(
                file_, "LPIPS {}: {} +/- {}".format(mode, values.mean(),
                                                    values.std()))
        file_.close()
Example #12
    def LPIPS_UNIMODAL(self):
        from misc.utils import compute_lpips
        from shutil import copyfile
        torch.manual_seed(1)
        torch.cuda.manual_seed(1)

        data_loader = self.data_loader
        model = None
        style_fixed = True
        style_str = 'fixed' if style_fixed else 'random'
        file_name = os.path.join(
            self.name.replace('{}.pth',
                              'LPIPS_UNIMODAL_{}.txt'.format(style_str)))
        copy_name = 'scores/{}_Attr_{}_LPIPS_UNIMODAL_{}.txt'.format(
            self.config.dataset_fake, self.config.ALL_ATTR, style_str)
        if os.path.isfile(file_name):
            print(file_name)
            for line in open(file_name).readlines():
                print(line.strip())
            return

        DISTANCE = {
            i: []
            for i in range(len(data_loader.dataset.labels[0]) + 1)
        }  # 0:[], 1:[], 2:[]}
        n_images = {i: 0 for i in range(len(data_loader.dataset.labels[0]))}

        style0 = to_var(self.G.random_style(1), volatile=True)
        print(file_name)
        for i, (real_x, org_c, files) in tqdm(enumerate(data_loader),
                                              desc='Calculating LPIPS ',
                                              total=len(data_loader)):
            org_label = torch.max(org_c, 1)[1][0]
            real_x = to_var(real_x, volatile=True)
            for label in range(len(data_loader.dataset.labels[0])):
                if org_label == label:
                    continue
                target_c = to_var(org_c * 0, volatile=True)
                target_c[:, label] = 1
                if not style_fixed:
                    style0 = to_var(self.G.random_style(real_x.size(0)),
                                    volatile=True)
                real_x = self.G(real_x,
                                to_var(target_c, volatile=True),
                                stochastic=style0)[0]
                n_images[label] += 1
                for j, (_real_x, _org_c, _files) in enumerate(data_loader):
                    if j <= i:
                        continue
                    _org_label = torch.max(_org_c, 1)[1][0]
                    _real_x = to_var(_real_x, volatile=True)
                    for _label in range(len(data_loader.dataset.labels[0])):
                        if _org_label == _label:
                            continue
                        _target_c = to_var(_org_c * 0, volatile=True)
                        _target_c[:, _label] = 1
                        if not style_fixed:
                            style0 = to_var(self.G.random_style(
                                _real_x.size(0)),
                                            volatile=True)
                        _real_x = self.G(_real_x,
                                         to_var(_target_c, volatile=True),
                                         stochastic=style0)[0]
                        distance, model = compute_lpips(real_x.data,
                                                        _real_x.data,
                                                        model=model)
                        DISTANCE[len(data_loader.dataset.labels[0])].append(
                            distance[0])
                        if label == _label:
                            DISTANCE[_label].append(distance[0])

        file_ = open(file_name, 'w')
        DISTANCE = {k: np.array(v) for k, v in DISTANCE.items()}
        for key, values in DISTANCE.items():
            if key == len(data_loader.dataset.labels[0]):
                mode = 'All'
            else:
                mode = chr(65 + key)
            PRINT(
                file_, "LPIPS {}: {} +/- {}".format(mode, values.mean(),
                                                    values.std()))
        file_.close()
        copyfile(file_name, copy_name)
Example #13
 def PRINT(self, str):
     if self.verbose:
         if self.config.mode == 'train':
             PRINT(self.config.log, str)
         else:
             print(str)
Example #14
        # Horovod
        torch.cuda.set_device(hvd.local_rank())
        config.GPU = [int(i) for i in range(hvd.size())]
        config.g_lr *= hvd.size()
        config.d_lr *= hvd.size()

    else:
        if config.GPU == 'NO_CUDA':
            config.GPU = '-1'
        os.environ["CUDA_VISIBLE_DEVICES"] = config.GPU
        config.GPU = [int(i) for i in config.GPU.split(',')]
        config.batch_size *= len(config.GPU)
        config.g_lr *= len(config.GPU)
        config.d_lr *= len(config.GPU)

    torch.manual_seed(config.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(config.seed)

    config_yaml(config, 'datasets/{}.yaml'.format(config.dataset_fake))
    config = cfg.update_config(config)
    if config.mode == 'train':
        if hvd.rank() == 0:
            PRINT(config.log, ' '.join(sys.argv))
            _PRINT(config)
        main(config)
        config.log.close()

    else:
        main(config)
Example #15
 def debug(self):
     feed = to_var(torch.ones(1, self.input_dim),
                   volatile=True,
                   no_cuda=True)
     PRINT(self.config.log, '-- DE [*{}]:'.format(self.comment))
     self.print_debug(feed, self.model)