Exemplo n.º 1
0
    def __call__(self, dataset='', load=False):
        """Translate every test batch of ``dataset`` and save the results.

        Outputs go under ``<sample_path>/<resume_name()>_test/`` and are named
        ``<dataset>_<TimeNow_str()>_<batch index starting at 1>.jpg``.

        Args:
            dataset: Dataset to translate. An empty string selects
                ``config.dataset_fake`` and reuses ``self.data_loader``. Any
                other value builds a fresh test loader with ``get_loader``.
            load: Unused; kept so the call signature stays unchanged.
        """
        import os
        from data_loader import get_loader
        last_name = self.resume_name()
        save_folder = os.path.join(self.config.sample_path,
                                   '{}_test'.format(last_name))
        create_dir(save_folder)
        if dataset == '':
            dataset = self.config.dataset_fake
            data_loader = self.data_loader
            self.dataset_real = dataset
        else:
            data_loader = get_loader(
                self.config.mode_data,
                self.config.image_size,
                self.config.batch_size,
                shuffling=True,
                dataset=dataset,
                mode='test')

        # Style indices 1..style_label_debug used for the multimodal passes.
        _debug = range(1, self.config.style_label_debug + 1)
        # Sampled once and passed as fixed_style for every batch below.
        style_all = self.G.random_style(self.config.batch_size)

        timestamp = '{}'.format(TimeNow_str())
        for i, (real_x, org_c, _) in enumerate(data_loader):
            # Format the file name in one step. Re-formatting the whole joined
            # path (as a template) failed whenever save_folder or dataset
            # contained '{' or '}'.
            file_name = '{}_{}_{}.jpg'.format(dataset, timestamp, i + 1)
            name = os.path.abspath(os.path.join(save_folder, file_name))
            # Loader labels are forwarded only for the configured dataset_fake;
            # for any other dataset generate_SMIT receives label=None.
            label = org_c if self.config.dataset_fake == dataset else None
            self.PRINT(
                'Translated test images and saved into "{}"..!'.format(name))

            if self.config.dataset_fake in ['Image2Edges', 'Yosemite']:
                self.save_multimodal_output(real_x, 1 - org_c, name)
                self.save_multimodal_output(
                    real_x, 1 - org_c, name, interpolation=True)

            else:
                self.generate_SMIT(
                    real_x,
                    name,
                    label=label,
                    fixed_style=style_all,
                    TIME=not i)
                for k in _debug:
                    # Timing is printed only for the very first multimodal pass.
                    self.generate_SMIT(
                        real_x,
                        name,
                        label=label,
                        Multimodal=k,
                        TIME=not i and k == 1)
                    self.generate_SMIT(
                        real_x,
                        name,
                        label=label,
                        Multimodal=k,
                        fixed_style=style_all)
Exemplo n.º 2
0
def setup_logAndCheckpoints(args):
    """Create the checkpoint directory and derive this run's file paths.

    The run name is ``<env_name lower-cased>_<alg_name>_<log_id>``. Only the
    checkpoint directory is created here; the log directory is not.

    Returns:
        tuple: ``(checkpoint_prefix, log_dir, eval_csv, adapt_csv)``, where
        ``checkpoint_prefix`` is ``<check_point_dir>/<run name>``,
        ``log_dir`` is ``<log_dir>/<run name>``, and the two CSV paths are
        ``eval.csv`` and ``adapt.csv`` inside that log directory.
    """
    create_dir(args.check_point_dir)

    run_name = '_'.join([str.lower(args.env_name), args.alg_name, args.log_id])
    run_log_dir = os.path.join(args.log_dir, run_name)

    return (
        os.path.join(args.check_point_dir, run_name),
        run_log_dir,
        os.path.join(run_log_dir, 'eval.csv'),
        os.path.join(run_log_dir, 'adapt.csv'),
    )
Exemplo n.º 3
0
 def save_multidomain_output(self, real_x, label, save_path, **kwargs):
     """Save grids that interpolate between two domain labels.

     For each of ``config.style_debug`` iterations a random style is drawn.
     The generator is then run along ``self.MMInterpolation`` between
     ``label`` and an "opposite" label derived from ``1 - label`` via
     ``self.target_multiAttr``. Image and attention grids are written as
     ``domain_interpolation_styleNN.jpg`` inside a directory named after
     ``save_path`` without its ``.jpg`` suffix.

     Args:
         real_x: Input image batch.
         label: Source labels. Column 7 of the derived opposite label is
             zeroed, so at least 8 label columns are assumed.
         save_path: ``.jpg`` path whose stem becomes the output directory.
         **kwargs: Accepted for interface compatibility; unused.
     """
     self.G.eval()
     self.D.eval()
     try:
         # torch<1.0 has no torch.no_grad(); an opened file handle serves as
         # a do-nothing context manager there (inputs are wrapped with
         # volatile=True instead).
         no_grad = open('/var/tmp/null.txt',
                        'w') if get_torch_version() < 1.0 else torch.no_grad()
         with no_grad:
             real_x = to_var(real_x, volatile=True)
             n_style = self.config.style_debug
             n_interp = self.config.n_interpolation + 10
             _name = 'domain_interpolation'
             no_label = True
             dirname = save_path.replace('.jpg', '')
             for idx in range(n_style):
                 filename = '{}_style{}.jpg'.format(_name,
                                                    str(idx + 1).zfill(2))
                 _save_path = os.path.join(dirname, filename)
                 create_dir(_save_path)
                 fake_image_list, fake_attn_list = self.Create_Visual_List(
                     real_x)
                 # One random style shared by every image of the batch.
                 style = self.G.random_style(1).repeat(real_x.size(0), 1)
                 style = to_var(style, volatile=True)
                 label0 = to_var(label, volatile=True)
                 opposite_label = self.target_multiAttr(1 - label,
                                                        2)  # 2: black hair
                 opposite_label[:, 7] = 0  # Pale skin
                 label1 = to_var(opposite_label, volatile=True)
                 labels = [label0, label1]
                 styles = [style, style]
                 domain_interp = self.MMInterpolation(labels,
                                                      styles,
                                                      n_interp=n_interp)
                 # NOTE(review): the first five interpolation steps are
                 # dropped (n_interp was padded by 10 above) — confirm intent.
                 for target_de in domain_interp[5:]:
                     target_de = to_var(target_de, volatile=True)
                     fake_x = self.G(real_x, target_de, style, DE=target_de)
                     fake_image_list.append(to_data(fake_x[0], cpu=True))
                     fake_attn_list.append(
                         to_data(fake_x[1].repeat(1, 3, 1, 1), cpu=True))
                 self._SAVE_IMAGE(_save_path,
                                  fake_image_list,
                                  no_label=no_label,
                                  arrow=False,
                                  circle=False)
                 self._SAVE_IMAGE(_save_path,
                                  fake_attn_list,
                                  Attention=True,
                                  arrow=False,
                                  no_label=no_label,
                                  circle=False)
     finally:
         # Restore training mode even if generation or saving raised.
         self.G.train()
         self.D.train()
Exemplo n.º 4
0
Arquivo: main.py Projeto: rasoolfa/P3O
def setup_logAndCheckpoints(args):
    """Create checkpoint/off-policy directories and derive run file paths.

    The run name is ``<env_name lower-cased>_<alg_name>_<log_id>``. Two
    directories are created: ``check_point_dir`` and the sibling
    ``<log_dir>_off_policy``.

    Returns:
        tuple: ``(checkpoint_prefix, log_dir, off_policy_csv, stats_csv)``,
        where ``checkpoint_prefix`` is ``<check_point_dir>/<run name>``,
        ``log_dir`` is ``<log_dir>/<run name>``, ``off_policy_csv`` is
        ``<log_dir>_off_policy/<run name>_off.csv`` and ``stats_csv`` is
        ``stats.csv`` inside ``log_dir``.
    """
    create_dir(args.check_point_dir)

    # Off-policy logs are kept next to, not inside, the main log directory.
    off_policy_dir = args.log_dir + '_off_policy'
    create_dir(off_policy_dir)

    run_name = '_'.join([str.lower(args.env_name), args.alg_name, args.log_id])
    run_log_dir = os.path.join(args.log_dir, run_name)

    return (
        os.path.join(args.check_point_dir, run_name),
        run_log_dir,
        os.path.join(off_policy_dir, run_name + '_off.csv'),
        os.path.join(run_log_dir, 'stats.csv'),
    )
Exemplo n.º 5
0
    def DEMO(self, path):
        """Translate the images provided under ``path`` in demo mode.

        Images are loaded one at a time with face detection enabled. Each
        image is translated once for every index in
        ``range(style_label_debug + 1)`` (index 0 is passed as
        ``Multimodal=0``). Each index runs twice: once with a pre-sampled
        style batch and once without ``fixed_style``. Outputs are written
        under ``<sample_path>/<resume_name()>_test/``.

        Args:
            path: Location handed to ``get_loader`` as its first argument.
        """
        from data_loader import get_loader
        out_dir = os.path.join(self.config.sample_path,
                               '{}_test'.format(self.resume_name()))
        create_dir(out_dir)
        is_binary = self.config.dataset_fake in self.Binary_Datasets
        loader = get_loader(
            path,
            self.config.image_size,
            1,
            shuffling=False,
            dataset='DEMO',
            Detect_Face=True,
            mode='test')

        # DEMO_LABEL is a comma-separated list of integers; '' means no label.
        raw_label = self.config.DEMO_LABEL
        if raw_label == '':
            label = None
        else:
            values = [int(token) for token in raw_label.split(',')]
            label = torch.FloatTensor(values).view(1, -1)

        style_ids = range(self.config.style_label_debug + 1)
        fixed_styles = self.G.random_style(max(self.config.batch_size, 50))

        stamp = TimeNow_str()
        for i, real_x in enumerate(loader):
            save_path = os.path.join(
                out_dir, 'DEMO_{}_{}.jpg'.format(stamp, i + 1))
            self.PRINT('Translated test images and saved into "{}"..!'.format(
                save_path))
            shared = dict(label=label, no_label=is_binary, circle=True)
            for k in style_ids:
                self.generate_SMIT(real_x, save_path, Multimodal=k,
                                   fixed_style=fixed_styles, TIME=not i,
                                   **shared)
                self.generate_SMIT(real_x, save_path, Multimodal=k, **shared)
Exemplo n.º 6
0
    #### Generic setups
    ##############################
    # Pick the compute device: the GPU given by args.gpu_id when CUDA is
    # available and not disabled via args.disable_cuda, otherwise the CPU.
    CUDA_AVAL = torch.cuda.is_available()

    if not args.disable_cuda and CUDA_AVAL:
        gpu_id = "cuda:" + str(args.gpu_id)
        device = torch.device(gpu_id)
        print("**** Yayy we use GPU %s ****" % gpu_id)

    else:
        device = torch.device('cpu')
        print("**** No GPU detected or GPU usage is disabled, sorry! ****")

    ####
    # Train and evaluation checkpoints, log folders, checkpoint file names.
    # NOTE(review): cleanup=True presumably clears an existing args.log_dir
    # before the run starts — confirm against create_dir.
    create_dir(args.log_dir, cleanup = True)
    # Create the checkpoint folder and unpack the paths returned by
    # setup_logAndCheckpoints (checkpoint prefix, log directory, and — per
    # the variable names — the eval and adapt CSV files).
    ck_fname_part, log_file_dir, fname_csv_eval, fname_adapt = setup_logAndCheckpoints(args)
    logger.configure(dir = log_file_dir)
    # Eval CSV writer; None until it is created later in the run.
    wrt_csv_eval = None

    ##############################
    #### Init env, model, alg, batch generator etc
    #### Step 1: build env
    #### Step 2: Build model
    #### Step 3: Initiate Alg e.g. a2c
    #### Step 4: Initiate batch/rollout generator
    ##############################

    ##### env setup #####
    env = make_env(args)
Exemplo n.º 7
0
    def save_multimodal_output(self,
                               real_x,
                               label,
                               save_path,
                               interpolation=False,
                               **kwargs):
        """Save, per input image, a grid of translations with random styles.

        Each image of ``real_x`` is tiled ``n_rep`` (4) times and translated
        7 times with its own label row. The style is either freshly sampled
        per row or, with ``interpolation``, slerp-interpolated between two
        random styles. Image and attention grids are written as
        ``multimodal[_interp]_NNNN.jpg`` inside a directory named after
        ``save_path`` without its ``.jpg`` suffix.

        Args:
            real_x: Input image batch.
            label: Target labels, one row per image of ``real_x``.
            save_path: ``.jpg`` path whose stem becomes the output directory.
            interpolation: Interpolate styles instead of sampling them.
            **kwargs: Accepted for interface compatibility; unused.
        """
        self.G.eval()
        self.D.eval()
        n_rep = 4
        no_label = self.config.dataset_fake in self.Binary_Datasets
        try:
            # torch<1.0 has no torch.no_grad(); an opened file handle serves
            # as a do-nothing context manager there.
            no_grad = open('/var/tmp/null.txt',
                           'w') if get_torch_version() < 1.0 else torch.no_grad()
            with no_grad:
                real_x = to_var(real_x, volatile=True)
                out_label = to_var(label, volatile=True)

                for idx, (real_x0, real_c0) in enumerate(zip(real_x, out_label)):
                    _name = 'multimodal'
                    if interpolation:
                        _name = _name + '_interp'
                    _save_path = os.path.join(
                        save_path.replace('.jpg', ''), '{}_{}.jpg'.format(
                            _name,
                            str(idx).zfill(4)))
                    create_dir(_save_path)
                    real_x0 = real_x0.repeat(n_rep, 1, 1, 1)
                    # Assuming label is (batch, c_dim), real_c0 is a 1-D row.
                    # Tile it to (n_rep, c_dim); repeat(n_rep, 1, 1, 1) on a
                    # 1-D tensor produced a 4-D (n_rep, 1, 1, c_dim) tensor.
                    real_c0 = real_c0.repeat(n_rep, 1)

                    fake_image_list = [
                        to_data(
                            color_frame(
                                single_source(real_x0),
                                thick=5,
                                color='green',
                                first=True),
                            cpu=True)
                    ]
                    fake_attn_list = [
                        to_data(
                            color_frame(
                                single_source(real_x0),
                                thick=5,
                                color='green',
                                first=True),
                            cpu=True)
                    ]

                    target_c_list = [real_c0] * 7
                    for target_c in target_c_list:
                        if not interpolation:
                            style_ = self.G.random_style(n_rep)
                        else:
                            # Spherical interpolation between two random
                            # styles, n_rep evenly spaced steps.
                            z0 = to_data(
                                self.G.random_style(1), cpu=True).numpy()[0]
                            z1 = to_data(
                                self.G.random_style(1), cpu=True).numpy()[0]
                            style_ = self.G.random_style(n_rep)
                            style_[:] = torch.FloatTensor(
                                np.array([
                                    slerp(sz, z0, z1)
                                    for sz in np.linspace(0, 1, n_rep)
                                ]))
                        style = to_var(style_, volatile=True)
                        fake_x = self.G(real_x0, target_c, stochastic=style)
                        fake_image_list.append(to_data(fake_x[0], cpu=True))
                        fake_attn_list.append(
                            to_data(fake_x[1].repeat(1, 3, 1, 1), cpu=True))
                    self._SAVE_IMAGE(
                        _save_path,
                        fake_image_list,
                        mode='style_' + chr(65 + idx),
                        no_label=no_label,
                        arrow=interpolation,
                        circle=False)
                    self._SAVE_IMAGE(
                        _save_path,
                        fake_attn_list,
                        Attention=True,
                        mode='style_' + chr(65 + idx),
                        arrow=interpolation,
                        no_label=no_label,
                        circle=False)
        finally:
            # Restore training mode even if generation or saving raised.
            self.G.train()
            self.D.train()
Exemplo n.º 8
0
    def save_multimodal_output(self,
                               real_x,
                               label,
                               save_path,
                               interpolation=False,
                               **kwargs):
        """Save, per input image, a grid of style/domain translations.

        Each image of ``real_x`` is tiled ``n_rep`` (4) times and translated
        7 times with its own label row. How the generator embeddings are
        built depends on ``interpolation``:

        * ``0``/``False``: random styles via ``self.label2embedding``.
        * ``1``/``True``: ``MMInterpolation`` between two random styles
          (file prefix ``multimodal_interp``).
        * ``2``: ``MMInterpolation`` between ``1 - label`` and ``label``
          with one shared style (file prefix ``multidomain_interp``).

        Image and attention grids are written as ``<prefix>_NNNN.jpg``
        inside a directory named after ``save_path`` without ``.jpg``.

        Args:
            real_x: Input image batch.
            label: Labels, one row per image of ``real_x``.
            save_path: ``.jpg`` path whose stem becomes the output directory.
            interpolation: 0, 1 or 2 as described above.
            **kwargs: Accepted for interface compatibility; unused.

        Raises:
            ValueError: If ``interpolation`` is not 0, 1 or 2. The check runs
                before any directory is created or network mode is changed.
        """
        # False/True compare equal to 0/1, so booleans are accepted too.
        if interpolation not in (0, 1, 2):
            raise ValueError(
                'interpolation must be 0 (random styles), 1 (multimodal '
                'interpolation) or 2 (multi-domain interpolation); '
                'got {!r}'.format(interpolation))

        self.G.eval()
        self.D.eval()
        n_rep = 4
        no_label = self.config.dataset_fake in self.Binary_Datasets
        try:
            # torch<1.0 has no torch.no_grad(); an opened file handle serves
            # as a do-nothing context manager there.
            no_grad = open('/var/tmp/null.txt',
                           'w') if get_torch_version() < 1.0 else torch.no_grad()
            with no_grad:
                real_x = to_var(real_x, volatile=True)
                out_label = to_var(label, volatile=True)

                for idx, (real_x0, real_c0) in enumerate(zip(real_x, out_label)):
                    _name = 'multimodal'
                    if interpolation == 1:
                        _name += '_interp'
                    elif interpolation == 2:
                        _name = 'multidomain_interp'

                    _save_path = os.path.join(
                        save_path.replace('.jpg', ''),
                        '{}_{}.jpg'.format(_name,
                                           str(idx).zfill(4)))
                    create_dir(_save_path)
                    real_x0 = real_x0.repeat(n_rep, 1, 1, 1)
                    real_c0 = real_c0.repeat(n_rep, 1)

                    fake_image_list, fake_attn_list = self.Create_Visual_List(
                        real_x0, Multimodal=True)

                    # Seven grid rows, all using this image's label row.
                    target_c_list = [real_c0] * 7
                    for target_c in target_c_list:
                        if interpolation == 0:
                            style_ = to_var(self.G.random_style(n_rep),
                                            volatile=True)
                            embeddings = self.label2embedding(target_c,
                                                              style_,
                                                              _torch=True)
                        elif interpolation == 1:
                            style_ = to_var(self.G.random_style(1),
                                            volatile=True)
                            style1 = to_var(self.G.random_style(1),
                                            volatile=True)
                            _target_c = target_c[0].unsqueeze(0)
                            styles = [style_, style1]
                            targets = [_target_c, _target_c]
                            embeddings = self.MMInterpolation(
                                targets, styles, n_interp=n_rep)[:, 0]
                        else:  # interpolation == 2 (validated above)
                            style_ = to_var(self.G.random_style(1),
                                            volatile=True)
                            target0 = 1 - target_c[0].unsqueeze(0)
                            target1 = target_c[0].unsqueeze(0)
                            styles = [style_, style_]
                            targets = [target0, target1]
                            embeddings = self.MMInterpolation(
                                targets, styles, n_interp=n_rep)[:, 0]
                        fake_x = self.G(real_x0, target_c, style_, DE=embeddings)
                        fake_image_list.append(to_data(fake_x[0], cpu=True))
                        fake_attn_list.append(
                            to_data(fake_x[1].repeat(1, 3, 1, 1), cpu=True))
                    self._SAVE_IMAGE(_save_path,
                                     fake_image_list,
                                     mode='style_' + chr(65 + idx),
                                     no_label=no_label,
                                     arrow=interpolation,
                                     circle=False)
                    self._SAVE_IMAGE(_save_path,
                                     fake_attn_list,
                                     Attention=True,
                                     mode='style_' + chr(65 + idx),
                                     arrow=interpolation,
                                     no_label=no_label,
                                     circle=False)
        finally:
            # Restore training mode even if generation or saving raised.
            self.G.train()
            self.D.train()
Exemplo n.º 9
0
    def generate_SMIT(self,
                      batch,
                      save_path,
                      Multimodal=0,
                      label=None,
                      output=False,
                      training=False,
                      fixed_style=None,
                      TIME=False,
                      **kwargs):
        """Translate ``batch`` to every debug target and save image grids.

        ``batch`` and ``label`` are split by ``self.get_batch_inference``.
        Each chunk is translated once per entry of ``target_debug_list``, and
        the image and attention grids are saved with ``self._SAVE_IMAGE``.

        Args:
            batch: Input images.
            save_path: ``.jpg`` path. With ``training`` it is used as the file
                name (``_Random`` is appended unless a fixed style is used in
                multimodal mode). Otherwise its stem becomes a directory
                holding ``<Multimodal>_<NNNN>[_Random].jpg`` per chunk.
            Multimodal: Style index; 0 means unimodal.
            label: Optional source labels. For a chunk without a label,
                ``self._CLS`` is used on multi-label datasets, zeros otherwise.
            output: Return the collected ``_SAVE_IMAGE`` results.
            training: In multimodal mode, stop after
                ``config.style_train_debug`` chunks.
            fixed_style: Optional pre-sampled styles; the first
                ``real_x.size(0)`` rows are used. Otherwise styles come from
                ``self.random_style``.
            TIME: Print the duration of the first generator call.
            **kwargs: Forwarded to ``_SAVE_IMAGE``.

        Returns:
            The concatenated ``_SAVE_IMAGE`` results if ``output`` is true,
            otherwise ``None``.
        """
        self.G.eval()
        self.D.eval()
        modal = 'Multimodal' if Multimodal else 'Unimodal'
        Output = []
        flag_time = True  # only the first forward pass gets timed
        try:
            # torch<1.0 has no torch.no_grad(); an opened file handle serves
            # as a do-nothing context manager there.
            no_grad = open('/var/tmp/null.txt',
                           'w') if get_torch_version() < 1.0 else torch.no_grad()
            with no_grad:
                batch = self.get_batch_inference(batch, Multimodal)
                _label = self.get_batch_inference(label, Multimodal)
                for idx, real_x in enumerate(batch):
                    if training and Multimodal and \
                            idx == self.config.style_train_debug:
                        break
                    real_x = to_var(real_x, volatile=True)
                    label = _label[idx]
                    target_list = target_debug_list(
                        real_x.size(0), self.config.c_dim, config=self.config)

                    # Start translations
                    fake_image_list, fake_attn_list = self.Create_Visual_List(
                        real_x, Multimodal=Multimodal)
                    if self.config.dataset_fake in self.MultiLabel_Datasets \
                            and label is None:
                        # No label supplied: infer one with self._CLS.
                        self.org_label = self._CLS(real_x)
                    elif label is not None:
                        # NOTE(review): squeeze() drops every size-1 axis,
                        # including the batch axis for a single-image chunk.
                        self.org_label = to_var(label.squeeze(), volatile=True)
                    else:
                        self.org_label = torch.zeros(
                            real_x.size(0), self.config.c_dim)
                        self.org_label = to_var(self.org_label, volatile=True)

                    if fixed_style is None:
                        style = self.random_style(real_x.size(0))
                        style = to_var(style, volatile=True)
                    else:
                        style = to_var(fixed_style[:real_x.size(0)],
                                       volatile=True)

                    for k, target in enumerate(target_list):
                        # NOTE(review): wall-clock timing; on asynchronous
                        # devices it may not include kernel completion.
                        start_time = time.time()
                        embeddings = self.Modality(
                            target, style, Multimodal, idx=k)
                        fake_x = self.G(real_x, target, style, DE=embeddings)
                        elapsed = time.time() - start_time
                        elapsed = str(datetime.timedelta(seconds=elapsed))
                        if TIME and flag_time:
                            print("[{}] Time/batch x forward (bs:{}): {}".format(
                                modal, real_x.size(0), elapsed))
                            flag_time = False

                        fake_image_list.append(to_data(fake_x[0], cpu=True))
                        fake_attn_list.append(
                            to_data(fake_x[1].repeat(1, 3, 1, 1), cpu=True))

                    # Create Folder
                    if training:
                        _name = '' if fixed_style is not None \
                                and Multimodal else '_Random'
                        _save_path = save_path.replace('.jpg', _name + '.jpg')
                    else:
                        _name = '' if fixed_style is not None else '_Random'
                        _save_path = os.path.join(
                            save_path.replace('.jpg', ''), '{}_{}{}.jpg'.format(
                                Multimodal,
                                str(idx).zfill(4), _name))
                        create_dir(_save_path)

                    mode = 'fake' if not Multimodal else 'style_' + chr(65 + idx)
                    Output.extend(
                        self._SAVE_IMAGE(
                            _save_path, fake_image_list, mode=mode, **kwargs))
                    Output.extend(
                        self._SAVE_IMAGE(
                            _save_path,
                            fake_attn_list,
                            Attention=True,
                            mode=mode,
                            **kwargs))
        finally:
            # Restore training mode even when translation or saving raises,
            # so a caught error cannot leave G/D stuck in eval mode.
            self.G.train()
            self.D.train()
        if output:
            return Output