示例#1
0
    def load_data(self, data):
        """Load the full patch and mask, and cache torch copies of both.

        Parameters:
            data -- dictionary with the keys 'image', 'mask' and 'name', as
                    created by the "data.py" file.

        Sets ``self.image_name``, ``self.img`` and ``self.mask`` to the raw
        entries. Sets ``self.img_``, ``self.mask_`` and ``self.coarse_img_``
        (= ``img_ * mask_``) to tensors cast to ``self.dtype``. In these
        tensors the last axis has been moved to the front and a leading batch
        axis of size 1 has been added.

        Returns:
            The standard deviation (Python float) of the masked image; callers
            can use it to skip all-zero patches.

        Raises:
            ValueError: if the mask shape differs from the image shape.
        """
        self.image_name = data['name']  # the name of the input patch
        self.img = data['image']
        self.mask = data['mask']

        if self.mask.shape != self.img.shape:
            raise ValueError(
                'The loaded mask shape %s has to be %s'
                % (tuple(self.mask.shape), tuple(self.img.shape)))

        # Move the last axis (presumably channels - TODO confirm) to the front,
        # e.g. (H, W, C) -> (C, H, W); [np.newaxis] then adds the batch axis.
        axes = tuple(range(self.img.ndim))
        last_axis_first = axes[-1:] + axes[:-1]

        self.img_ = u.np_to_torch(
            np.transpose(self.img, last_axis_first)[np.newaxis]).type(self.dtype)
        self.mask_ = u.np_to_torch(
            np.transpose(self.mask, last_axis_first)[np.newaxis]).type(self.dtype)
        self.coarse_img_ = self.img_ * self.mask_

        # std of the coarse (masked) data, used for skipping all-zeros patches;
        # reuse coarse_img_ instead of recomputing the product.
        return torch.std(self.coarse_img_).item()
示例#2
0
def load_data():
    """Load the train and test arrays and split each into features and target.

    Both files are read with ``np.load`` and converted with ``np_to_torch``.
    For every converted array ``xy``, the features are ``xy[:, :, :-1]``
    (everything but the last entry of the final axis). The target is
    ``xy[:, -1, -1]`` reshaped to a column of shape ``(N, 1)``.

    Returns:
        (xtrain, ytrain, xtest, ytest)
    """
    def split_features_target(xy):
        features = xy[:, :, :-1]
        target = torch.reshape(xy[:, -1, -1], (xy.shape[0], 1))
        return features, target

    raw_train = np.load(TRAIN_DATA_PATH)
    raw_test = np.load(TEST_DATA_PATH)
    train_torch = np_to_torch(raw_train)
    test_torch = np_to_torch(raw_test)

    xtrain, ytrain = split_features_target(train_torch)
    xtest, ytest = split_features_target(test_torch)
    return xtrain, ytrain, xtest, ytest
示例#3
0
    def build_input(self):
        """Build the network input tensor and the optional data-injection buffers.

        ``self.input_`` is noise from ``u.get_noise`` with shape
        ``(1, args.inputdepth) + self.img.shape[:-1]``. It is drawn with
        ``args.noise_dist`` and multiplied by ``args.noise_std``. When
        ``args.filter_noise_with_wavelet`` is set, the noise is passed through
        ``u.filter_noise_traces`` with the wavelet stored in
        ``<args.imgdir>/wavelet.npy``.

        When ``args.data_forgetting_factor`` is non-zero:
          * ``self.add_data_`` is the masked data ``img_ * mask_`` tiled along
            axis 1 and cropped to ``args.inputdepth`` channels, then rescaled
            so its std matches the noise std.
          * ``self.add_data_weight`` holds ``data_forgetting_factor`` weights
            log-spaced from 1 down to 1e-4.

        Also stores detached copies of the input in ``self.input_old`` and
        ``self.add_noise_``, and prints the input shape. It reads ``self.img``,
        ``self.img_`` and ``self.mask_`` (set by ``load_data``).
        """
        spatial_shape = self.img.shape[:-1]
        noise_shape = (1, self.args.inputdepth) + spatial_shape
        raw_noise = u.get_noise(shape=noise_shape,
                                noise_type=self.args.noise_dist)
        self.input_ = raw_noise.type(self.dtype)
        self.input_ *= self.args.noise_std

        if self.args.filter_noise_with_wavelet:
            noise_np = self.input_.detach().clone().cpu().numpy()
            wavelet = np.load(os.path.join(self.args.imgdir, 'wavelet.npy'))
            filtered = u.filter_noise_traces(noise_np, wavelet)
            self.input_ = u.np_to_torch(filtered).type(self.dtype)

        if self.args.data_forgetting_factor != 0:
            # decimated (masked) observations
            observed = self.img_ * self.mask_
            # enough copies along the channel axis to cover the input depth
            n_copies = int(np.ceil(self.args.inputdepth / self.args.imgchannel))
            tile_dims = [1, n_copies] + [1] * len(spatial_shape)
            observed = observed.repeat(tile_dims)[:, :self.args.inputdepth]
            # match the noise std
            # NOTE(review): divides by zero when the masked data is all zeros.
            observed *= torch.std(self.input_) / torch.std(observed)
            self.add_data_ = observed
            self.add_data_weight = np.logspace(
                0, -4, self.args.data_forgetting_factor)

        self.input_old = self.input_.detach().clone()
        self.add_noise_ = self.input_.detach().clone()
        shape_msg = 'The input shape is %s' % str(tuple(self.input_.shape))
        print(colored(shape_msg, 'cyan'))
示例#4
0
def load_data(valid_size=0.2, random_state=0):
    """Load the train/test arrays and hold out a validation split of the train set.

    ``train_test_split`` shuffles the training array and splits it into train
    and validation parts. Every part (and the test array) is then converted
    with ``np_to_torch`` and split into features ``xy[:, :, :-1]`` and a
    target ``xy[:, -1, -1]`` reshaped to ``(N, 1)``. This indexing assumes
    3-D arrays (TODO confirm the on-disk layout).

    Args:
        valid_size: forwarded to ``train_test_split(test_size=...)``. A float
            is the fraction of training rows held out, an int is an absolute
            count. Defaults to the previously hard-coded 0.2.
        random_state: seed of the shuffled split; defaults to the previously
            hard-coded 0 so results stay reproducible.

    Returns:
        (xtrain, ytrain, xvalid, yvalid, xtest, ytest)
    """
    def split_xy(xy):
        # Features: all but the last entry of the final axis.
        x = xy[:, :, :-1]
        # Target: element [-1, -1] of each sample, as a column vector.
        y = torch.reshape(xy[:, -1, -1], (xy.shape[0], 1))
        return x, y

    train_data = np.load(TRAIN_DATA_PATH)
    test_data = np.load(TEST_DATA_PATH)
    train_xy, valid_xy = train_test_split(train_data,
                                          test_size=valid_size,
                                          shuffle=True,
                                          random_state=random_state)
    xtrain, ytrain = split_xy(np_to_torch(train_xy))
    xvalid, yvalid = split_xy(np_to_torch(valid_xy))
    xtest, ytest = split_xy(np_to_torch(test_data))
    return xtrain, ytrain, xvalid, yvalid, xtest, ytest
示例#5
0
 def get_minibatches(self, batch_size=64, shuffle=True, drop_reminder=True, use_torch=False, device='cpu'):
     """Yield mini-batches of the variable-length samples stored in ``self.X``.

     Each batch concatenates the selected samples of ``self.X`` along their
     first axis. ``sbatch`` gives, for every row of that concatenation, the
     position (0 .. batch-1) of the sample it came from.

     Args:
         batch_size: number of samples per batch.
         shuffle: shuffle the sample order with ``np.random.shuffle`` (global
             NumPy RNG) before batching.
         drop_reminder: if True, drop the trailing incomplete batch;
             otherwise emit it as a shorter final batch.
         use_torch: if True, convert the outputs with ``np_to_torch``: x and
             the sample positions as ``torch.long``, targets as
             ``torch.float32``.
         device: device forwarded to ``np_to_torch``.

     Yields:
         (xbatch, sbatch, ybatch) tuples.
     """
     total = len(self.X)
     order = np.arange(total)
     if shuffle:
         np.random.shuffle(order)

     if drop_reminder:
         batch_count = total // batch_size
     else:
         batch_count = np.ceil(total / batch_size).astype(np.int32)

     for batch_no in range(batch_count):
         lo = batch_no * batch_size
         chosen = order[lo:lo + batch_size]

         xbatch = np.concatenate(self.X[chosen])
         ybatch = self.y[chosen]
         owner_runs = [[pos] * len(self.X[sample])
                       for pos, sample in enumerate(chosen)]
         sbatch = np.concatenate(owner_runs)

         if not use_torch:
             yield xbatch, sbatch, ybatch
             continue

         yield (np_to_torch(xbatch, dtype=torch.long, device=device),
                np_to_torch(sbatch, dtype=torch.long, device=device),
                np_to_torch(ybatch, dtype=torch.float32, device=device))
示例#6
0
 def pre_process(self, image, target=True):
     """Degrade ``image`` according to ``self.mode``.

     The mode checks are independent, so 'hybrid' applies all three steps in
     order:
       * 'SR' / 'hybrid': downsample by ``self.factor``, then upsample back
         by the same factor with bilinear interpolation.
       * 'colorization' / 'hybrid': replace the image by the grayscale value
         0.2989 R + 0.5870 G + 0.1140 B, repeated over 3 channels.
       * 'inpainting' / 'hybrid': zero out a central square and store its row
         bounds in ``self.begin`` / ``self.end``.

     Args:
         image: tensor of shape (n, c, h, w). The code maps it with
             ``(x + 1) / 2`` and back with ``x * 2 - 1``, so values are
             presumably in [-1, 1] (TODO confirm).
         target: only used by the SR step. If True, downsample through PIL
             (as in deep image prior) and move the result to CUDA; otherwise
             use ``self.downsampler``.

     Returns:
         The degraded tensor.
     """
     if self.mode in ['SR', 'hybrid']:
         # apply downsampling, this part is the same as deep image prior
         if target:
             image_pil = utils.np_to_pil(
                 utils.torch_to_np((image.cpu() + 1) / 2))
             LR_size = [
                 image_pil.size[0] // self.factor,
                 image_pil.size[1] // self.factor
             ]
             # Image.ANTIALIAS was an alias of LANCZOS and was removed in
             # Pillow 10; LANCZOS is the same filter and works in all versions.
             img_LR_pil = image_pil.resize(LR_size, Image.LANCZOS)
             image = utils.np_to_torch(utils.pil_to_np(img_LR_pil)).cuda()
             image = image * 2 - 1
         else:
             image = self.downsampler((image + 1) / 2)
             image = image * 2 - 1
         # interpolate to the original resolution via bilinear interpolation;
         # align_corners=False is the default, spelled out explicitly.
         image = F.interpolate(image,
                               scale_factor=self.factor,
                               mode='bilinear',
                               align_corners=False)
     n, _, h, w = image.size()
     if self.mode in ['colorization', 'hybrid']:
         # transform the image to gray-scale (needs at least 3 channels)
         r = image[:, 0, :, :]
         g = image[:, 1, :, :]
         b = image[:, 2, :, :]
         gray = 0.2989 * r + 0.5870 * g + 0.1140 * b
         image = gray.view(n, 1, h, w).expand(n, 3, h, w)
     if self.mode in ['inpainting', 'hybrid']:
         # remove the center part of the image
         hole = min(h, w) // 3
         begin = (h - hole) // 2
         end = h - begin
         self.begin, self.end = begin, end
         # NOTE(review): the row bounds are reused for the columns, so the
         # hole is horizontally centred only for square images.
         # Build the mask on the image's device so CPU inputs and
         # non-default GPUs work as well.
         mask = torch.ones(1, 1, h, w, device=image.device)
         mask[0, 0, begin:end, begin:end].zero_()
         image = image * mask
     return image
示例#7
0
    def init_params_with_data(self, dataset, config, device=None, subset=None):
        """Initialize every stage of the model from data and return the resulting loss.

        Repeats the forward pass stage by stage: lda, then the optional
        side-info extractor and shift selector, then plda, then calibration.
        Each stage is initialized from the outputs of the stages already
        initialized before it. When ``config.init_params.get("random")`` is
        truthy, stages use ``init_random`` instead, with stdev
        ``init_params['stdev']`` (default 0.1). ``si_stage2`` follows its own
        rule (see inline comments).

        Args:
            dataset: provides ``get_data_and_meta(subset)`` and is passed to
                ``ddata.TrialLoader``.
            config: reads ``balance_batches_by_domain``, ``init_params``,
                ``ptar`` and ``loss``.
            device: passed to ``utils.np_to_torch``, the trial loader and the
                dummy tensors fed to the calibration stage.
            subset: forwarded to ``get_data_and_meta`` and ``TrialLoader``.

        Returns:
            ``compute_loss`` evaluated on one batch of trials scored by the
            freshly initialized model.

        Everything runs under ``torch.no_grad()``.
        """

        balance_by_domain = config.balance_batches_by_domain
        # NOTE(review): assert is stripped under `python -O`; a raise would make this check unconditional.
        assert 'init_params' in config
        init_params = config.init_params
        
        # The code here repeats the steps in forward above, but adds the steps necessary for initialization.
        # I chose to keep these two methods separate to leave the forward small and easy to read.
        with torch.no_grad():

            # x is presumably a numpy array (it is converted with utils.np_to_torch below).
            x, meta, _ = dataset.get_data_and_meta(subset)
            speaker_ids = meta['speaker_id']
            domain_ids  = meta['domain_id']
            x_torch = utils.np_to_torch(x, device)
            
            if init_params.get("random"):
                self.lda_stage.init_random(init_params.get('stdev',0.1))
            else:
                self.lda_stage.init_with_lda(x, speaker_ids, init_params, sec_ids=domain_ids)

            # LDA output, kept both as a tensor and as numpy for the next init steps.
            x2_torch = self.lda_stage(x_torch)
            x2 = x2_torch.cpu().numpy()

            if hasattr(self,'si_stage1'):
                # The side-info extractor reads either the raw input or the LDA output.
                if self.si_input == 'main_input':
                    si_input_torch = x_torch
                    si_input = x
                else:
                    si_input_torch = x2_torch
                    si_input = x2

                if init_params.get("random"):
                    self.si_stage1.init_random(init_params.get('stdev',0.1))
                else:
                    self.si_stage1.init_with_lda(si_input, speaker_ids, init_params, sec_ids=domain_ids, complement=True)

                s2_torch = self.si_stage1(si_input_torch)
                    
                # Note: the si_stage2 choice below does not look at init_params["random"].
                if init_params.get('init_si_stage2_with_domain_gb', False):
                    # Initialize the second stage of the si-extractor to be a gaussian backend that predicts
                    # the posterior of each domain. In this case, the number of domains has to coincide with the
                    # dimension of the side info vector
                    assert self.si_dim == len(np.unique(domain_ids))
                    self.si_stage2.init_with_lda(s2_torch.cpu().numpy(), domain_ids, init_params, sec_ids=speaker_ids, gaussian_backend=True)
    
                else:
                    # This is the only component that is initialized randomly unless otherwise indicated by the variable "init_si_stage2_with_domain_gb"
                    self.si_stage2.init_random(init_params.get('w_init', 0.5), init_params.get('b_init', 0.0), init_params.get('type', 'normal'))

                if hasattr(self,'shift_selector'):
                    # Initialize the shifts as the mean of the lda outputs weighted by the si
                    si_torch = self.si_stage2(s2_torch)
                    si = si_torch.cpu().numpy()
                    if init_params.get("random"):
                        self.shift_selector.init_random(init_params.get('stdev',0.1))
                    else:
                        self.shift_selector.init_with_weighted_means(x2, si)
                    # Subtract the selected shifts in place so PLDA is initialized on shifted LDA outputs.
                    x2_torch -= self.shift_selector(si_torch)
                    x2 = x2_torch.cpu().numpy()

            if init_params.get("random"):
                self.plda_stage.init_random(init_params.get('stdev',0.1))
            else:    
                self.plda_stage.init_with_plda(x2, speaker_ids, init_params, domain_ids=domain_ids)

            # Since the training data is usually large, we cannot create all possible trials for x3.
            # So, to create a bunch of trials, we just create a trial loader with a large batch size.
            # This means we need to rerun lda again, but it is a small price to pay for the 
            # convenience of reusing the machinery of trial creation in the TrialLoader.
            # NOTE(review): these batch scores go straight from lda_stage to plda_stage, without the
            # shift_selector subtraction applied to x2 above - confirm this matches forward.
            loader = ddata.TrialLoader(dataset, device, seed=0, batch_size=2000, num_batches=1, balance_by_domain=balance_by_domain, subset=subset)
            x_torch, meta_batch = next(loader.__iter__())
            x2_torch = self.lda_stage(x_torch)
            scrs_torch = self.plda_stage(x2_torch)
            same_spk_torch, valid_torch = utils.create_scoring_masks(meta_batch)
            scrs, same_spk, valid = [v.detach().cpu().numpy() for v in [scrs_torch, same_spk_torch, valid_torch]]

            if init_params.get("random"):
                self.cal_stage.init_random(init_params.get('stdev',0.1))
            else:
                self.cal_stage.init_with_logreg(scrs, same_spk, valid, config.ptar, std_for_mats=init_params.get('std_for_cal_matrices',0))

            # Calibrate with unit durations and all-zero side-info vectors (size self.si_dim), one per row of scrs.
            dummy_durs = torch.ones(scrs.shape[0]).to(device) 
            dummy_si = torch.zeros(scrs.shape[0], self.si_dim).to(device)
            llrs_torch = self.cal_stage(scrs_torch, dummy_durs, dummy_si)
            # Trial labels for compute_loss: 1 = same speaker, -1 = different speaker, 0 = invalid trial.
            # Assumes same_spk and valid are boolean arrays (~ is a logical not only for bool dtype).
            mask = np.ones_like(same_spk, dtype=int)
            mask[~same_spk] = -1
            mask[~valid] = 0
            
            return compute_loss(llrs_torch, mask=utils.np_to_torch(mask, device), ptar=config.ptar, loss_type=config.loss)
示例#8
0
 def _np_to_torch(self, x):
     """Convert ``x`` with ``utils.np_to_torch``, targeting ``self.device``."""
     convert = utils.np_to_torch
     return convert(x, self.device)
示例#9
0
    def __init__(self,
                 emb_file,
                 meta_file=None,
                 meta_is_dur_only=False,
                 device=None):
        """
        Args:
            emb_file (string):  File with embeddings and sample ids, in npz or h5 format
                                (entries 'ids' and 'data').
            meta_file (string): File with metadata for each id we want to store from the file above.
                                Should contain: sample_id speaker_id session_id domain_id duration
                                *  The speaker id is a unique string identifying the speaker
                                *  The session id is a unique string identifying the recording session from which
                                   the audio sample was extracted (ie, if a waveform is split into chunks as a pre-
                                   processing step, the chunks all belong to the same session, or if several mics
                                   were used to record a person speaking all these recordings belong to the same
                                   session). This information is used by the loader to avoid creating same-session
                                   trials which would mess up calibration.
                                *  The domain id is a unique string identifying the domain. Domains should correspond
                                   to disjoint speaker sets. This information is also used by the loader. Only same-
                                   domain trials are created since cross-domain trials would never include target
                                   trials and would likely result in very easy impostor trials.
                                *  The duration is read as float32.
            meta_is_dur_only (bool): if True, the metadata file has only two columns: sample_id duration.
            device: if not None, the embeddings (and the durations, if present) are converted with
                    utils.np_to_torch onto this device.

        String metadata columns are converted into integer indices stored in self.meta.
        The index-to-string maps are kept in self.idx_to_str[field], and the inverse
        maps in self.idx_to_str[field + "_inv"].

        Raises:
            Exception: unrecognized embeddings file extension, ids that are not strings or an empty
                       ids array, repeated sample ids, or metadata sample ids missing from the
                       embeddings file.
        """
        if meta_file is not None:
            print("Loading data from %s\n  with metadata file %s" %
                  (emb_file, meta_file))
        else:
            print("Loading data from %s without metadata" % emb_file)

        if emb_file.endswith(".npz"):
            # Read both arrays inside the context manager so the archive is
            # closed instead of being left open until garbage collection.
            with np.load(emb_file) as npz:
                data_dict = {'ids': npz['ids'], 'data': npz['data']}
        elif emb_file.endswith(".h5"):
            with h5py.File(emb_file, 'r') as f:
                data_dict = {'ids': f['ids'][()], 'data': f['data'][()]}
        else:
            raise Exception("Unrecognized format for embeddings file %s" %
                            emb_file)

        embeddings_all = data_dict['data']
        raw_ids = data_dict['ids']
        if len(raw_ids) == 0:
            raise Exception("No sample ids found in embeddings file %s" %
                            emb_file)
        # isinstance accepts both NumPy scalars (np.bytes_ / np.str_ subclass
        # bytes / str) and the plain Python bytes objects that h5py >= 3
        # returns for variable-length string datasets.
        if isinstance(raw_ids[0], bytes):
            ids_all = [i.decode('UTF-8') for i in raw_ids]
        elif isinstance(raw_ids[0], str):
            ids_all = raw_ids
        else:
            raise Exception(
                "Bad format for ids in embeddings file %s (should be strings)"
                % emb_file)

        self.idx_to_str = dict()
        self.meta = dict()

        if meta_file is None:
            fields = ('sample_id', )
            formats = ('O', )
            self.meta_raw = np.array(
                ids_all, np.dtype({
                    'names': fields,
                    'formats': formats
                }))
        else:
            if meta_is_dur_only:
                fields, formats = zip(*[('sample_id',
                                         'O'), ('duration', 'float32')])
            else:
                fields, formats = zip(
                    *[('sample_id',
                       'O'), ('speaker_id',
                              'O'), ('session_id',
                                     'O'), ('domain_id',
                                            'O'), ('duration', 'float32')])
            self.meta_raw = np.loadtxt(
                meta_file, np.dtype({
                    'names': fields,
                    'formats': formats
                }))

        # Convert the metadata strings into indices
        print("  Converting metadata strings into indices")
        for field, fmt in zip(fields, formats):
            if fmt == 'O':
                # Convert all the string fields into indices
                names, nmap = np.unique(self.meta_raw[field],
                                        return_inverse=True)
                if field == 'sample_id' and len(names) != len(self.meta_raw):
                    # Without a metadata file the ids come from the embeddings file.
                    if meta_file is not None:
                        msg = "Metadata file %s has repeated sample ids" % meta_file
                    else:
                        msg = "Embeddings file %s has repeated sample ids" % emb_file
                    raise Exception(msg)

                # Index to string and string to index maps
                self.idx_to_str[field] = dict(zip(np.arange(len(names)),
                                                  names))
                self.idx_to_str[field + "_inv"] = dict(
                    zip(names, np.arange(len(names))))
                self.meta[field] = np.array(nmap, dtype=np.int32)
            else:
                self.meta[field] = self.meta_raw[field]

        self.meta = utils.AttrDict(self.meta)

        # Subset the embeddings to only those in the metadata file
        name_to_idx = dict(zip(ids_all, np.arange(len(ids_all))))
        keep_idxs = np.array(
            [name_to_idx.get(n, -1) for n in self.meta_raw['sample_id']])
        if np.any(keep_idxs == -1):
            raise Exception(
                "There are %d sample ids (out of %d in the metadata file %s) that are missing from the embeddings file %s.\nPlease, remove those files from the metadata file and try again"
                % (np.sum(keep_idxs == -1), len(
                    self.meta_raw), meta_file, emb_file))
        self.embeddings = embeddings_all[keep_idxs]
        self.ids = np.array(ids_all)[keep_idxs]

        if device is not None:
            # Move the embeddings and the durations to the device
            self.embeddings = utils.np_to_torch(self.embeddings, device)
            if 'duration' in self.meta:
                self.meta['duration'] = utils.np_to_torch(
                    self.meta['duration'], device)

        print("Done. Loaded %d embeddings from %s" %
              (len(self.embeddings), emb_file),
              flush=True)