コード例 #1
0
    def Create_Visual_List(self, batch):
        """Start the image and attention visualisation columns for ``batch``.

        The first column of each grid is the input itself: ``single_source``
        is applied to the data (image column) and to its ``denorm``-ed
        version (attention column), and both are wrapped in a green
        ``color_frame`` with ``thick=5``.

        Args:
            batch: Input batch; converted with ``to_data`` before use.

        Returns:
            tuple: ``(fake_image_list, fake_attn_list)``, each a one-element
            list holding the framed source column moved to the CPU, so that
            generated columns can be appended to it.
        """
        def _green_frame(column):
            # Mark the source column so it stands out from generated ones.
            return color_frame(column, thick=5, color='green', first=True)

        image_column = _green_frame(single_source(to_data(batch)))
        attn_column = _green_frame(single_source(denorm(to_data(batch))))
        return [image_column.cpu()], [attn_column.cpu()]
コード例 #2
0
ファイル: test.py プロジェクト: yuzuda283/SMIT
 def save_multidomain_output(self, real_x, label, save_path, **kwargs):
     """Save multi-domain interpolation grids for ``real_x``.

     For each of ``self.config.style_debug`` random styles (one style shared
     by the whole batch), the domain embedding is interpolated with
     ``self.MMInterpolation`` from ``label`` towards an "opposite" label
     built by ``self.target_multiAttr``. One image grid and one attention
     grid are written per style under a directory derived from
     ``save_path``.

     Args:
         real_x: Input image batch.
         label: Domain labels of ``real_x``. ``1 - label`` is used to build
             the opposite target, so a binary / multi-hot encoding is
             presumably expected.
         save_path: ``.jpg`` path whose stem becomes the output directory.
         **kwargs: Accepted for interface compatibility; unused.

     G and D are switched to eval mode for the duration of the call and put
     back in train mode afterwards, even if an exception is raised.
     """
     self.G.eval()
     self.D.eval()
     # On PyTorch < 1.0 an opened file stands in as a no-op context
     # manager; torch.no_grad() is used otherwise.
     no_grad = open('/var/tmp/null.txt',
                    'w') if get_torch_version() < 1.0 else torch.no_grad()
     try:
         with no_grad:
             real_x = to_var(real_x, volatile=True)
             n_style = self.config.style_debug
             # NOTE(review): 10 extra interpolation steps are requested and
             # the first 5 are dropped below; confirm the intended offsets.
             n_interp = self.config.n_interpolation + 10
             _name = 'domain_interpolation'
             no_label = True
             dirname = save_path.replace('.jpg', '')

             # Source and target labels do not depend on the style, so they
             # are built once instead of once per style.
             label0 = to_var(label, volatile=True)
             opposite_label = self.target_multiAttr(1 - label,
                                                    2)  # 2: black hair
             opposite_label[:, 7] = 0  # Pale skin
             label1 = to_var(opposite_label, volatile=True)
             labels = [label0, label1]

             for idx in range(n_style):
                 filename = '{}_style{}.jpg'.format(_name,
                                                    str(idx + 1).zfill(2))
                 _save_path = os.path.join(dirname, filename)
                 create_dir(_save_path)
                 fake_image_list, fake_attn_list = self.Create_Visual_List(
                     real_x)
                 style = self.G.random_style(1).repeat(real_x.size(0), 1)
                 style = to_var(style, volatile=True)
                 styles = [style, style]
                 domain_interp = self.MMInterpolation(labels,
                                                      styles,
                                                      n_interp=n_interp)
                 for target_de in domain_interp[5:]:
                     target_de = to_var(target_de, volatile=True)
                     fake_x = self.G(real_x, target_de, style, DE=target_de)
                     fake_image_list.append(to_data(fake_x[0], cpu=True))
                     fake_attn_list.append(
                         to_data(fake_x[1].repeat(1, 3, 1, 1), cpu=True))
                 self._SAVE_IMAGE(_save_path,
                                  fake_image_list,
                                  no_label=no_label,
                                  arrow=False,
                                  circle=False)
                 self._SAVE_IMAGE(_save_path,
                                  fake_attn_list,
                                  Attention=True,
                                  arrow=False,
                                  no_label=no_label,
                                  circle=False)
     finally:
         # Restore training mode even when generation/saving fails.
         self.G.train()
         self.D.train()
コード例 #3
0
    def Create_Visual_List(self, batch, Multimodal=False):
        """Start the image and attention visualisation columns for ``batch``.

        Args:
            batch: Input batch; converted once with ``to_data``.
            Multimodal: When truthy, each source column is built with
                ``single_source`` and wrapped in a green ``color_frame``
                (``thick=5``); otherwise the converted batch is used as-is.

        Returns:
            tuple: ``(fake_image_list, fake_attn_list)``, each a one-element
            list holding the source column on the CPU. The attention column
            is always derived from ``denorm`` of the batch.
        """
        data = to_data(batch)
        if not Multimodal:
            return [data.cpu()], [denorm(data).cpu()]

        image_column, attn_column = [
            color_frame(single_source(source), thick=5, color='green',
                        first=True)
            for source in (data, denorm(data))
        ]
        return [image_column.cpu()], [attn_column.cpu()]
コード例 #4
0
ファイル: scores.py プロジェクト: zhoushiwei/SMIT
    def INCEPTION_REAL(self):
        """Compute per-label Inception Scores of the real images.

        Every image of ``self.data_loader`` is mapped with ``(x + 1) / 2``
        (presumably from [-1, 1] to [0, 1]), upsampled to 299x299 and
        classified by the network returned by ``load_inception``. The
        softmax outputs are grouped by the argmax of each image's own label
        vector. For each label the score is
        ``exp(mean_j KL(p(y|x_j) || p(y)))``, where ``p(y)`` is taken from all
        of that label's outputs (``scipy.stats.entropy`` normalises it).

        Per-label scores and their mean +/- std are written with ``PRINT``
        to ``scores/Inception_Real.txt``.
        """
        from misc.utils import load_inception
        from scipy.stats import entropy
        net = load_inception()
        net = to_cuda(net)
        net.eval()
        inception_up = nn.Upsample(size=(299, 299), mode='bilinear')
        mode = 'Real'
        data_loader = self.data_loader
        file_name = 'scores/Inception_{}.txt'.format(mode)
        n_labels = len(data_loader.dataset.labels[0])

        PRED_IS = {i: [] for i in range(n_labels)}
        IS = {i: [] for i in range(n_labels)}

        for _, (real_x, org_c, _files) in tqdm(
                enumerate(data_loader),
                desc='Calculating CIS/IS - {}'.format(file_name),
                total=len(data_loader)):
            # Label index of every image in the batch (not only the first
            # one), so batches with mixed labels are attributed correctly.
            batch_labels = torch.max(org_c, 1)[1].view(-1).tolist()
            real_x = to_var((real_x + 1) / 2., volatile=True)
            pred = to_data(F.softmax(net(inception_up(real_x)), dim=1),
                           cpu=True).numpy()
            for label, row in zip(batch_labels, pred):
                PRED_IS[int(label)].append(row[np.newaxis])

        for label in range(n_labels):
            preds = np.concatenate(PRED_IS[label], 0)
            # Prior is computed from all outputs of this label.
            py = np.sum(preds, axis=0)
            IS[label] = [entropy(pyx, py) for pyx in preds]

        total_is = []
        # `with` guarantees the score file is closed even if writing fails.
        with open(file_name, 'w') as file_:
            for label in range(n_labels):
                _is = np.exp(np.mean(IS[label]))
                total_is.append(_is)
                PRINT(file_, "Label {}".format(label))
                PRINT(file_, "Inception Score: {:.4f}".format(_is))
            PRINT(file_, "")
            PRINT(
                file_, "[TOTAL] Inception Score: {:.4f} +/- {:.4f}".format(
                    np.array(total_is).mean(),
                    np.array(total_is).std()))
コード例 #5
0
ファイル: RafD.py プロジェクト: zhoushiwei/SMIT
def train_inception(batch_size, shuffling=False, num_workers=4, **kwargs):
    """Fine-tune an ImageNet-pretrained Inception-v3 classifier on RafD.

    Each of the 100 epochs first evaluates on the RafD ``'test'`` split, then
    trains one pass over the ``'train'`` split with RMSprop (lr=1e-5) and
    cross-entropy. It prints mean loss and accuracy for both, and saves the
    weights to ``data/RafD/normal/inception_v3/<epoch:05d>.pth``.

    Args:
        batch_size: Mini-batch size for both loaders.
        shuffling: Currently unused. The train split is always built with
            ``shuffling=True`` and reshuffled with ``dataset.shuffle(epoch)``
            after every epoch.
        num_workers: Number of ``DataLoader`` worker processes.
        **kwargs: Forwarded to both ``RafD`` dataset constructors.
    """
    from torchvision.models import inception_v3
    from misc.utils import to_var, to_cuda, to_data
    from torchvision import transforms
    from torch.utils.data import DataLoader
    import torch.nn.functional as F
    import torch
    import torch.nn as nn
    import tqdm

    def _loss_value(loss):
        # PyTorch < 0.4 yields a 1-element loss tensor, newer releases a
        # 0-dim one that cannot be indexed with [0]; flattening first works
        # for both.
        return float(to_data(loss, cpu=True).view(-1)[0])

    metadata_path = os.path.join('data', 'RafD', 'normal')
    image_size = 299  # Inception-v3 input resolution

    # Resize slightly larger than the target, then random crop and flip.
    # NOTE(review): the same augmenting transform is also applied to the
    # test split.
    window = int(image_size / 10)
    transform = transforms.Compose([
        transforms.Resize((image_size + window, image_size + window),
                          interpolation=Image.ANTIALIAS),
        transforms.RandomResizedCrop(image_size,
                                     scale=(0.7, 1.0),
                                     ratio=(0.8, 1.2)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ])

    dataset_train = RafD(image_size,
                         metadata_path,
                         transform,
                         'train',
                         shuffling=True,
                         **kwargs)
    dataset_test = RafD(image_size,
                        metadata_path,
                        transform,
                        'test',
                        shuffling=False,
                        **kwargs)

    train_loader = DataLoader(dataset=dataset_train,
                              batch_size=batch_size,
                              shuffle=False,
                              num_workers=num_workers)
    test_loader = DataLoader(dataset=dataset_test,
                             batch_size=batch_size,
                             shuffle=False,
                             num_workers=num_workers)

    num_labels = len(train_loader.dataset.labels[0])
    n_epochs = 100
    net = inception_v3(pretrained=True, transform_input=True)
    net.aux_logits = False
    # Replace the ImageNet head with one sized for the RafD labels.
    num_ftrs = net.fc.in_features
    net.fc = nn.Linear(num_ftrs, num_labels)

    net_save = metadata_path + '/inception_v3/{}.pth'
    if not os.path.isdir(os.path.dirname(net_save)):
        os.makedirs(os.path.dirname(net_save))
    print("Model will be saved at: " + net_save)
    optimizer = torch.optim.RMSprop(net.parameters(), lr=1e-5)
    to_cuda(net)

    for epoch in range(n_epochs):
        LOSS = {'train': [], 'test': []}
        OUTPUT = {'train': [], 'test': []}
        LABEL = {'train': [], 'test': []}

        # Validation runs before each training pass, so epoch 0 measures
        # the pretrained backbone with the freshly initialised head.
        net.eval()
        for data, label, _files in tqdm.tqdm(
                test_loader,
                total=len(test_loader),
                desc='Validating Inception V3 | RafD'):
            data = to_var(data, volatile=True)
            # One-hot label rows -> class indices for cross_entropy.
            label = to_var(torch.max(label, dim=1)[1], volatile=True)
            out = net(data)
            loss = F.cross_entropy(out, label)
            LOSS['test'].append(_loss_value(loss))
            OUTPUT['test'].extend(
                to_data(F.softmax(out, dim=1).max(1)[1], cpu=True).tolist())
            LABEL['test'].extend(to_data(label, cpu=True).tolist())
        acc_test = (np.array(OUTPUT['test']) == np.array(LABEL['test'])).mean()
        print('[Test] Loss: {:.4f} Acc: {:.4f}'.format(
            np.array(LOSS['test']).mean(), acc_test))

        net.train()
        for data, label, _files in tqdm.tqdm(
                train_loader,
                total=len(train_loader),
                desc='[{}/{}] Train Inception V3 | RafD'.format(
                    str(epoch).zfill(5),
                    str(n_epochs).zfill(5))):
            data = to_var(data)
            label = to_var(torch.max(label, dim=1)[1])
            out = net(data)
            loss = F.cross_entropy(out, label)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            LOSS['train'].append(_loss_value(loss))
            OUTPUT['train'].extend(
                to_data(F.softmax(out, dim=1).max(1)[1], cpu=True).tolist())
            LABEL['train'].extend(to_data(label, cpu=True).tolist())

        acc_train = (np.array(OUTPUT['train']) == np.array(
            LABEL['train'])).mean()
        print('[Train] Loss: {:.4f} Acc: {:.4f}'.format(
            np.array(LOSS['train']).mean(), acc_train))
        torch.save(net.state_dict(), net_save.format(str(epoch).zfill(5)))
        train_loader.dataset.shuffle(epoch)
コード例 #6
0
    def save_multimodal_output(self,
                               real_x,
                               label,
                               save_path,
                               interpolation=False,
                               **kwargs):
        """Save multimodal (random-style) translation grids, one per image.

        Each image of ``real_x`` is repeated ``n_rep`` (4) times and
        translated to its own label 7 times (one grid column each), with new
        styles from ``self.G.random_style`` for every column. When
        ``interpolation`` is truthy, the ``n_rep`` styles of a column are
        instead ``slerp`` interpolations between two random styles. One
        image grid and one attention grid are saved per input image.

        Args:
            real_x: Input image batch.
            label: Labels of ``real_x``, one row per image.
            save_path: ``.jpg`` path whose stem becomes the output directory.
            interpolation: Use slerp-interpolated styles instead of
                independent random ones.
            **kwargs: Accepted for interface compatibility; unused.

        G and D are set to eval mode during the call and restored to train
        mode afterwards, even if an exception is raised.
        """
        self.G.eval()
        self.D.eval()
        n_rep = 4
        n_columns = 7
        no_label = self.config.dataset_fake in self.Binary_Datasets
        # On PyTorch < 1.0 an opened file stands in as a no-op context
        # manager; torch.no_grad() is used otherwise.
        no_grad = open('/var/tmp/null.txt',
                       'w') if get_torch_version() < 1.0 else torch.no_grad()
        try:
            with no_grad:
                real_x = to_var(real_x, volatile=True)
                out_label = to_var(label, volatile=True)
                _name = 'multimodal_interp' if interpolation else 'multimodal'

                pairs = zip(real_x, out_label)
                for idx, (real_x0, real_c0) in enumerate(pairs):
                    _save_path = os.path.join(
                        save_path.replace('.jpg', ''), '{}_{}.jpg'.format(
                            _name,
                            str(idx).zfill(4)))
                    create_dir(_save_path)
                    real_x0 = real_x0.repeat(n_rep, 1, 1, 1)
                    # real_c0 is a single label vector: repeat it into one
                    # row per image copy. Repeating it with four factors
                    # (like the image) yielded an (n_rep, 1, 1, len) tensor.
                    target_c = real_c0.repeat(n_rep, 1)

                    # NOTE(review): unlike Create_Visual_List elsewhere in
                    # the project, the attention source column is not
                    # passed through denorm here; confirm this is intended.
                    fake_image_list = [
                        to_data(
                            color_frame(
                                single_source(real_x0),
                                thick=5,
                                color='green',
                                first=True),
                            cpu=True)
                    ]
                    fake_attn_list = [
                        to_data(
                            color_frame(
                                single_source(real_x0),
                                thick=5,
                                color='green',
                                first=True),
                            cpu=True)
                    ]

                    for _ in range(n_columns):
                        if not interpolation:
                            style_ = self.G.random_style(n_rep)
                        else:
                            z0 = to_data(
                                self.G.random_style(1), cpu=True).numpy()[0]
                            z1 = to_data(
                                self.G.random_style(1), cpu=True).numpy()[0]
                            style_ = self.G.random_style(n_rep)
                            style_[:] = torch.FloatTensor(
                                np.array([
                                    slerp(sz, z0, z1)
                                    for sz in np.linspace(0, 1, n_rep)
                                ]))
                        style = to_var(style_, volatile=True)
                        fake_x = self.G(real_x0, target_c, stochastic=style)
                        fake_image_list.append(to_data(fake_x[0], cpu=True))
                        fake_attn_list.append(
                            to_data(fake_x[1].repeat(1, 3, 1, 1), cpu=True))
                    self._SAVE_IMAGE(
                        _save_path,
                        fake_image_list,
                        mode='style_' + chr(65 + idx),
                        no_label=no_label,
                        arrow=interpolation,
                        circle=False)
                    self._SAVE_IMAGE(
                        _save_path,
                        fake_attn_list,
                        Attention=True,
                        mode='style_' + chr(65 + idx),
                        arrow=interpolation,
                        no_label=no_label,
                        circle=False)
        finally:
            # Restore training mode even when generation/saving fails.
            self.G.train()
            self.D.train()
コード例 #7
0
ファイル: test.py プロジェクト: yuzuda283/SMIT
    def save_multimodal_output(self,
                               real_x,
                               label,
                               save_path,
                               interpolation=False,
                               **kwargs):
        """Save multimodal / interpolation translation grids, one per image.

        Each image of ``real_x`` is repeated ``n_rep`` (4) times and
        translated to its own label 7 times (one grid column each). The
        domain embeddings passed to ``self.G`` depend on ``interpolation``:

        * ``0`` / ``False`` -- ``self.label2embedding`` with ``n_rep`` new
          random styles per column;
        * ``1`` / ``True`` -- ``self.MMInterpolation`` between two random
          styles for the image's label;
        * ``2`` -- ``self.MMInterpolation`` from the inverted label
          (``1 - label``) to the label, with one shared random style.

        Args:
            real_x: Input image batch.
            label: Labels of ``real_x``, one row per image.
            save_path: ``.jpg`` path whose stem becomes the output directory.
            interpolation: Mode selector, 0, 1 or 2 (see above).
            **kwargs: Accepted for interface compatibility; unused.

        Raises:
            ValueError: If ``interpolation`` is not 0, 1 or 2. This is checked
                before any directory is created or network mode changed.

        G and D are set to eval mode during the call and restored to train
        mode afterwards, even if an exception is raised.
        """
        if interpolation not in (0, 1, 2):
            raise ValueError(
                'interpolation must be 0 (multimodal), 1 (multimodal '
                'interpolation) or 2 (multi-domain interpolation); '
                'got {!r}'.format(interpolation))
        # bool values map onto 0/1 because True == 1 and False == 0.
        _name = {
            0: 'multimodal',
            1: 'multimodal_interp',
            2: 'multidomain_interp',
        }[interpolation]

        self.G.eval()
        self.D.eval()
        n_rep = 4
        n_columns = 7
        no_label = self.config.dataset_fake in self.Binary_Datasets
        # On PyTorch < 1.0 an opened file stands in as a no-op context
        # manager; torch.no_grad() is used otherwise.
        no_grad = open('/var/tmp/null.txt',
                       'w') if get_torch_version() < 1.0 else torch.no_grad()
        try:
            with no_grad:
                real_x = to_var(real_x, volatile=True)
                out_label = to_var(label, volatile=True)

                pairs = zip(real_x, out_label)
                for idx, (real_x0, real_c0) in enumerate(pairs):
                    _save_path = os.path.join(
                        save_path.replace('.jpg', ''),
                        '{}_{}.jpg'.format(_name,
                                           str(idx).zfill(4)))
                    create_dir(_save_path)
                    real_x0 = real_x0.repeat(n_rep, 1, 1, 1)
                    target_c = real_c0.repeat(n_rep, 1)

                    fake_image_list, fake_attn_list = self.Create_Visual_List(
                        real_x0, Multimodal=True)

                    for _ in range(n_columns):
                        if interpolation == 0:
                            style_ = to_var(self.G.random_style(n_rep),
                                            volatile=True)
                            embeddings = self.label2embedding(target_c,
                                                              style_,
                                                              _torch=True)
                        elif interpolation == 1:
                            # Same label, interpolate between two styles.
                            style_ = to_var(self.G.random_style(1),
                                            volatile=True)
                            style1 = to_var(self.G.random_style(1),
                                            volatile=True)
                            _target_c = target_c[0].unsqueeze(0)
                            styles = [style_, style1]
                            targets = [_target_c, _target_c]
                            embeddings = self.MMInterpolation(
                                targets, styles, n_interp=n_rep)[:, 0]
                        else:
                            # Same style, interpolate from the inverted
                            # label to the original one.
                            style_ = to_var(self.G.random_style(1),
                                            volatile=True)
                            target0 = 1 - target_c[0].unsqueeze(0)
                            target1 = target_c[0].unsqueeze(0)
                            styles = [style_, style_]
                            targets = [target0, target1]
                            embeddings = self.MMInterpolation(
                                targets, styles, n_interp=n_rep)[:, 0]
                        fake_x = self.G(real_x0, target_c, style_,
                                        DE=embeddings)
                        fake_image_list.append(to_data(fake_x[0], cpu=True))
                        fake_attn_list.append(
                            to_data(fake_x[1].repeat(1, 3, 1, 1), cpu=True))
                    self._SAVE_IMAGE(_save_path,
                                     fake_image_list,
                                     mode='style_' + chr(65 + idx),
                                     no_label=no_label,
                                     arrow=interpolation,
                                     circle=False)
                    self._SAVE_IMAGE(_save_path,
                                     fake_attn_list,
                                     Attention=True,
                                     mode='style_' + chr(65 + idx),
                                     arrow=interpolation,
                                     no_label=no_label,
                                     circle=False)
        finally:
            # Restore training mode even when generation/saving fails.
            self.G.train()
            self.D.train()
コード例 #8
0
ファイル: scores.py プロジェクト: zhoushiwei/SMIT
    def INCEPTION(self):
        """Compute Inception Score (IS) and conditional IS (CIS) of SMIT outputs.

        Every batch of ``self.data_loader`` is repeated ``n_styles`` (20)
        times and translated by ``self.G`` to every label other than the
        source label, each time with fresh random styles. Generated images
        are mapped with ``(x + 1) / 2``, upsampled to 299x299 and classified
        by the network returned by ``load_inception``.

        * IS of a label uses all outputs translated to that label, with the
          prior ``p(y)`` taken over all of them.
        * CIS of a label uses, per input, the prior over the ``n_styles``
          outputs generated from that single input.

        Both are ``exp(mean KL(p(y|x) || p(y)))`` (``scipy.stats.entropy``
        normalises the prior). Per-label values and their mean +/- std are
        written with ``PRINT`` to ``scores/Inception_SMIT.txt``.
        """
        from misc.utils import load_inception
        from scipy.stats import entropy
        n_styles = 20
        net = load_inception()
        net = to_cuda(net)
        net.eval()
        # NOTE(review): G is left in eval mode on return; callers that keep
        # training must switch it back themselves.
        self.G.eval()
        inception_up = nn.Upsample(size=(299, 299), mode='bilinear')
        mode = 'SMIT'
        data_loader = self.data_loader
        file_name = 'scores/Inception_{}.txt'.format(mode)
        n_labels = len(data_loader.dataset.labels[0])

        PRED_IS = {i: [] for i in range(n_labels)}
        CIS = {i: [] for i in range(n_labels)}
        IS = {i: [] for i in range(n_labels)}

        for _, (real_x, org_c, _files) in tqdm(
                enumerate(data_loader),
                desc='Calculating CIS/IS - {}'.format(file_name),
                total=len(data_loader)):
            # NOTE(review): only the first image's label decides which
            # translation is skipped, so batch size 1 is presumably expected.
            org_label = torch.max(org_c, 1)[1][0]
            real_x = real_x.repeat(n_styles, 1, 1, 1)
            real_x = to_var(real_x, volatile=True)

            target_c = (org_c * 0).repeat(n_styles, 1)
            target_c = to_var(target_c, volatile=True)
            for label in range(n_labels):
                if org_label == label:
                    continue
                # One-hot target for `label`, reusing the same tensor.
                target_c *= 0
                target_c[:, label] = 1
                style = to_var(self.G.random_style(n_styles),
                               volatile=True) if mode == 'SMIT' else None

                fake = (self.G(real_x, target_c, style)[0] + 1) / 2

                pred = to_data(F.softmax(net(inception_up(fake)), dim=1),
                               cpu=True).numpy()
                PRED_IS[label].append(pred)

                # CIS for this input: prior computed from the outputs
                # generated from this single input.
                py = np.sum(pred, axis=0)
                CIS[label].extend(entropy(pyx, py) for pyx in pred)

        for label in range(n_labels):
            preds = np.concatenate(PRED_IS[label], 0)
            # IS: prior computed from all outputs for this label.
            py = np.sum(preds, axis=0)
            IS[label] = [entropy(pyx, py) for pyx in preds]

        total_cis = []
        total_is = []
        # `with` guarantees the score file is closed even if writing fails.
        with open(file_name, 'w') as file_:
            for label in range(n_labels):
                cis = np.exp(np.mean(CIS[label]))
                total_cis.append(cis)
                _is = np.exp(np.mean(IS[label]))
                total_is.append(_is)
                PRINT(file_, "Label {}".format(label))
                PRINT(file_, "Inception Score: {:.4f}".format(_is))
                PRINT(file_,
                      "conditional Inception Score: {:.4f}".format(cis))
            PRINT(file_, "")
            PRINT(
                file_, "[TOTAL] Inception Score: {:.4f} +/- {:.4f}".format(
                    np.array(total_is).mean(),
                    np.array(total_is).std()))
            PRINT(
                file_,
                "[TOTAL] conditional Inception Score: {:.4f} +/- {:.4f}".format(
                    np.array(total_cis).mean(),
                    np.array(total_cis).std()))
コード例 #9
0
 def imshow(self, img):
     """Display one image with matplotlib after ``denorm``.

     ``img`` is passed through ``denorm`` and ``to_data(..., cpu=True)``,
     converted to NumPy and transposed with axes ``(1, 2, 0)`` (presumably
     channels-first to the channels-last layout ``plt.imshow`` expects),
     then shown with ``plt.show()``.
     """
     import matplotlib.pyplot as plt
     pixels = to_data(denorm(img), cpu=True).numpy().transpose(1, 2, 0)
     plt.imshow(pixels)
     plt.show()
コード例 #10
0
    def generate_SMIT(self,
                      batch,
                      save_path,
                      Multimodal=0,
                      label=None,
                      output=False,
                      training=False,
                      fixed_style=None,
                      TIME=False,
                      **kwargs):
        """Translate ``batch`` to every debug target and save the result grids.

        ``batch`` and ``label`` are split with ``self.get_batch_inference``.
        For each resulting chunk ``real_x`` the method:

        1. builds the source columns with ``self.Create_Visual_List``;
        2. sets ``self.org_label`` (side effect): computed with ``self._CLS``
           for multi-label datasets when no label is available, taken from
           the chunk's label when given, otherwise all zeros;
        3. picks a style: ``self.random_style`` for the chunk, or the first
           ``real_x.size(0)`` rows of ``fixed_style``;
        4. runs ``self.G`` once per target of ``target_debug_list``, using the
           embedding from ``self.Modality``, and collects image/attention
           columns;
        5. saves an image grid and an attention grid via ``self._SAVE_IMAGE``.

        Args:
            batch: Input images.
            save_path: ``.jpg`` path. In training mode a suffix is inserted
                before ``.jpg``; otherwise its stem becomes a directory that
                receives one file per chunk.
            Multimodal: Passed to the chunking, visualisation and embedding
                helpers. Truthy values switch the grid ``mode`` to
                ``'style_<letter>'``; the value is also used as the file-name
                prefix outside training.
            label: Optional labels, chunked like ``batch``.
            output: When true, return the collected ``_SAVE_IMAGE`` results.
            training: In multimodal mode, stop after
                ``self.config.style_train_debug`` chunks; also selects the
                flat training file naming.
            fixed_style: Optional style tensor to use instead of random ones.
            TIME: Print the duration of the first generator forward pass.
            **kwargs: Forwarded to ``self._SAVE_IMAGE``.

        Returns:
            list or None: The concatenated ``_SAVE_IMAGE`` outputs when
            ``output`` is true, otherwise ``None``.

        G and D are put in eval mode for the call and back in train mode at
        the end. NOTE(review): they are not restored if an exception is raised
        in between.
        """
        self.G.eval()
        self.D.eval()
        # Only used in the timing message below.
        modal = 'Multimodal' if Multimodal else 'Unimodal'
        Output = []
        # Ensures the timing line is printed at most once per call.
        flag_time = True
        # On PyTorch < 1.0 an opened file stands in as a no-op context
        # manager; torch.no_grad() is used otherwise.
        no_grad = open('/var/tmp/null.txt',
                       'w') if get_torch_version() < 1.0 else torch.no_grad()
        with no_grad:
            batch = self.get_batch_inference(batch, Multimodal)
            _label = self.get_batch_inference(label, Multimodal)
            for idx, real_x in enumerate(batch):
                # During training, multimodal debugging is limited to the
                # configured number of chunks.
                if training and Multimodal and \
                        idx == self.config.style_train_debug:
                    break
                real_x = to_var(real_x, volatile=True)
                # NOTE: rebinds the `label` argument to this chunk's labels;
                # the `label is None` checks below use the per-chunk value.
                label = _label[idx]
                target_list = target_debug_list(
                    real_x.size(0), self.config.c_dim, config=self.config)

                # Start translations
                fake_image_list, fake_attn_list = self.Create_Visual_List(
                    real_x, Multimodal=Multimodal)
                # Source labels stored on the instance (side effect).
                if self.config.dataset_fake in self.MultiLabel_Datasets \
                        and label is None:
                    self.org_label = self._CLS(real_x)
                elif label is not None:
                    self.org_label = to_var(label.squeeze(), volatile=True)
                else:
                    self.org_label = torch.zeros(
                        real_x.size(0), self.config.c_dim)
                    self.org_label = to_var(self.org_label, volatile=True)

                if fixed_style is None:
                    style = self.random_style(real_x.size(0))
                    style = to_var(style, volatile=True)
                else:
                    style = to_var(fixed_style[:real_x.size(0)], volatile=True)

                for k, target in enumerate(target_list):
                    # NOTE(review): wall-clock timing without a device
                    # synchronisation; if G runs asynchronously on a GPU
                    # this may under-report the forward time.
                    start_time = time.time()
                    embeddings = self.Modality(
                        target, style, Multimodal, idx=k)
                    fake_x = self.G(real_x, target, style, DE=embeddings)
                    elapsed = time.time() - start_time
                    elapsed = str(datetime.timedelta(seconds=elapsed))
                    if TIME and flag_time:
                        print("[{}] Time/batch x forward (bs:{}): {}".format(
                            modal, real_x.size(0), elapsed))
                        flag_time = False

                    fake_image_list.append(to_data(fake_x[0], cpu=True))
                    # Attention map repeated to 3 channels for saving.
                    fake_attn_list.append(
                        to_data(fake_x[1].repeat(1, 3, 1, 1), cpu=True))

                # Create Folder
                # Outputs produced with random styles get a '_Random' suffix.
                if training:
                    _name = '' if fixed_style is not None \
                            and Multimodal else '_Random'
                    _save_path = save_path.replace('.jpg', _name + '.jpg')
                else:
                    _name = '' if fixed_style is not None else '_Random'
                    _save_path = os.path.join(
                        save_path.replace('.jpg', ''), '{}_{}{}.jpg'.format(
                            Multimodal,
                            str(idx).zfill(4), _name))
                    create_dir(_save_path)

                mode = 'fake' if not Multimodal else 'style_' + chr(65 + idx)
                Output.extend(
                    self._SAVE_IMAGE(
                        _save_path, fake_image_list, mode=mode, **kwargs))
                Output.extend(
                    self._SAVE_IMAGE(
                        _save_path,
                        fake_attn_list,
                        Attention=True,
                        mode=mode,
                        **kwargs))

        self.G.train()
        self.D.train()
        if output:
            return Output