Example #1
def iou_match(yx_min, yx_max, data):
    batch_size, cells, num_anchors, _ = yx_min.size()
    iou_matrix = utils.iou.torch.batch_iou_matrix(yx_min.view(batch_size, -1, 2), yx_max.view(batch_size, -1, 2), data['yx_min'], data['yx_max'])
    iou_matrix = iou_matrix.view(batch_size, cells, num_anchors, -1)
    iou, index = iou_matrix.max(-1)
    _index = torch.unbind(index.view(batch_size, -1))
    _data = {}
    for key in 'yx_min, yx_max, cls'.split(', '):
        t = data[key]
        if len(t.size()) == 2:
            t = torch.stack([d[i] for d, i in zip(torch.unbind(t, 0), _index)]).view(batch_size, cells, num_anchors)
        elif len(t.size()) == 3:
            t = torch.stack([d[i] for d, i in zip(torch.unbind(t, 0), _index)]).view(batch_size, cells, num_anchors, -1)
        _data[key] = t
    return iou_matrix, iou, index, _data
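Example #1 leans on the project-specific utils.iou.torch helpers, but the per-sample gather it performs with unbind can be shown in isolation; a minimal sketch with illustrative (assumed) shapes:

import torch

# Gather one matched target per cell: unbind the batch, index each sample, re-stack.
batch_size, cells, num_targets = 2, 3, 4
targets = torch.randn(batch_size, num_targets)            # per-sample target values
index = torch.randint(num_targets, (batch_size, cells))   # best-match index per cell

matched = torch.stack([d[i] for d, i in zip(torch.unbind(targets, 0),
                                            torch.unbind(index, 0))])
assert matched.shape == (batch_size, cells)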
Example #2
	def func(module, x):
		""" The actual wrapped operation.
		"""
		return torch.stack(
			tuple(Layer.resolve(layer)(module, X) for X in torch.unbind(x, 0)),
			0
		)
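A self-contained version of the same pattern, with the project's Layer.resolve(layer)(module, ...) call replaced by a plain callable (map_over_dim0 is a hypothetical stand-in, not the source API):

import torch
import torch.nn.functional as F

def map_over_dim0(op, x):
    # Apply op to every slice along dim 0 and re-stack the results.
    return torch.stack(tuple(op(X) for X in torch.unbind(x, 0)), 0)

out = map_over_dim0(F.relu, torch.randn(4, 3))
assert out.shape == (4, 3)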
Example #3
    def output_batch(self, ner_model, documents, fout):
        """
        decode the whole corpus in the specific format by calling apply_model to fit specific models

        args:
            ner_model: sequence labeling model
            feature (list): list of words list
            fout: output file
        """
        ner_model.eval()

        d_len = len(documents)
        for d_ind in tqdm(range(d_len), mininterval=1,
                desc=' - Process', leave=False, file=sys.stdout):
            fout.write('-DOCSTART- -DOCSTART- -DOCSTART-\n\n')
            features = documents[d_ind]
            f_len = len(features)
            for ind in range(0, f_len, self.batch_size):
                eind = min(f_len, ind + self.batch_size)
                labels = self.apply_model(ner_model, features[ind: eind])
                labels = torch.unbind(labels, 1)

                for ind2 in range(ind, eind):
                    f = features[ind2]
                    l = labels[ind2 - ind][0:len(f)]
                    fout.write(self.decode_str(features[ind2], l) + '\n\n')
Example #4
def fit_positive(rows, cols, yx_min, yx_max, anchors):
    device_id = anchors.get_device() if torch.cuda.is_available() else None
    batch_size, num, _ = yx_min.size()
    num_anchors, _ = anchors.size()
    valid = torch.prod(yx_min < yx_max, -1)
    center = (yx_min + yx_max) / 2
    ij = torch.floor(center)
    i, j = torch.unbind(ij.long(), -1)
    index = i * cols + j
    anchors2 = anchors / 2
    iou_matrix = utils.iou.torch.iou_matrix((yx_min - center).view(-1, 2), (yx_max - center).view(-1, 2), -anchors2, anchors2).view(batch_size, -1, num_anchors)
    iou, index_anchor = iou_matrix.max(-1)
    _positive = []
    cells = rows * cols
    for valid, index, index_anchor in zip(torch.unbind(valid), torch.unbind(index), torch.unbind(index_anchor)):
        index, index_anchor = (t[valid] for t in (index, index_anchor))
        t = utils.ensure_device(torch.ByteTensor(cells, num_anchors).zero_(), device_id)
        t[index, index_anchor] = 1
        _positive.append(t)
    return torch.stack(_positive)
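The i, j = torch.unbind(ij.long(), -1) line above splits per-cell (y, x) coordinates along the last dimension; the behavior in isolation:

import torch

ij = torch.tensor([[1, 2], [3, 4], [5, 6]])
i, j = torch.unbind(ij, -1)
assert i.tolist() == [1, 3, 5] and j.tolist() == [2, 4, 6]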
Example #5
    def calc_acc_batch(self, decoded_data, target_data):
        """
        update statics for accuracy

        args:
            decoded_data (batch_size, seq_len): prediction sequence
            target_data (batch_size, seq_len): ground-truth
        """
        batch_decoded = torch.unbind(decoded_data, 1)
        batch_targets = torch.unbind(target_data, 0)

        for decoded, target in zip(batch_decoded, batch_targets):
            gold = self.packer.convert_for_eval(target)
            # remove padding
            length = utils.find_length_from_labels(gold, self.l_map)
            gold = gold[:length].numpy()
            best_path = decoded[:length].numpy()

            self.total_labels += length
            self.correct_labels += np.sum(np.equal(best_path, gold))
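Note the two unbind calls use different dimensions, which implies decoded_data is laid out batch-second and target_data batch-first; a shape check of that pairing, with assumed illustrative sizes:

import torch

decoded_data = torch.zeros(7, 4)   # (seq_len, batch_size)
target_data = torch.zeros(4, 7)    # (batch_size, seq_len)
batch_decoded = torch.unbind(decoded_data, 1)
batch_targets = torch.unbind(target_data, 0)
assert len(batch_decoded) == len(batch_targets) == 4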
Example #6
    def calc_f1_batch(self, decoded_data, target_data):
        """
        update statics for f1 score

        args:
            decoded_data (batch_size, seq_len): prediction sequence
            target_data (batch_size, seq_len): ground-truth
        """
        batch_decoded = torch.unbind(decoded_data, 1)
        batch_targets = torch.unbind(target_data, 0)

        for decoded, target in zip(batch_decoded, batch_targets):
            gold = self.packer.convert_for_eval(target)
            # remove padding
            length = utils.find_length_from_labels(gold, self.l_map)
            gold = gold[:length]
            best_path = decoded[:length]

            correct_labels_i, total_labels_i, gold_count_i, guess_count_i, overlap_count_i = self.eval_instance(best_path.numpy(), gold.numpy())
            self.correct_labels += correct_labels_i
            self.total_labels += total_labels_i
            self.gold_count += gold_count_i
            self.guess_count += guess_count_i
            self.overlap_count += overlap_count_i
Example #7
def plot_samples():
    try:
        z_samples = model.z.data.numpy().T
    except AttributeError:
        z_samples = model.sample()
    plt.clf()
    plot_generative()
    plt.scatter(z_samples[0], z_samples[1], alpha=0.6)
    plt.xlabel('mu')
    plt.ylabel('logvar')
    plt.title('qz')
    # plt.xlim(-8,8)
    # plt.ylim(-4,4)
    mu, logvar = torch.unbind(model.qz.mu, 1)
    mu = mu.data.numpy()
    logvar = logvar.data.numpy()
    plt.scatter(mu, logvar, color='black', alpha=0.3)
    plt.pause(0.01)
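The mu, logvar = torch.unbind(model.qz.mu, 1) line splits an (N, 2) parameter tensor into its two columns; checked in isolation:

import torch

params = torch.randn(5, 2)
mu, logvar = torch.unbind(params, 1)   # split the two columns
assert mu.shape == logvar.shape == (5,)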
Example #8
    def forward(self, x):
        B, T = x.shape
        # build the padding mask
        mask = x.gt(0)
        # get sequence indices sorted by length, descending
        lens, indices = torch.sort(mask.sum(dim=1), descending=True)
        # get the inverse permutation
        _, inverse_indices = indices.sort()
        # maximum word length in the batch
        max_len = lens[0]
        # reorder sequences from longest to shortest
        x = x[indices, :max_len]
        # look up character embeddings
        x = self.embed(x)
        # pack the padded batch
        x = pack_padded_sequence(x, lens, True)

        x, (hidden, _) = self.lstm(x)
        # build character-level word representations
        reprs = torch.cat(torch.unbind(hidden), dim=1)
        # restore the original order
        reprs = reprs[inverse_indices]

        return reprs
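The torch.cat(torch.unbind(hidden), dim=1) idiom concatenates the LSTM's final hidden states across directions into one per-word vector; a sketch assuming a single-layer bidirectional LSTM:

import torch

lstm = torch.nn.LSTM(8, 16, bidirectional=True)
x = torch.randn(7, 4, 8)                      # (seq_len, batch, input_size)
_, (hidden, _) = lstm(x)                      # hidden: (2, batch, 16)
reprs = torch.cat(torch.unbind(hidden), dim=1)
assert reprs.shape == (4, 32)                 # forward and backward states side by side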
Example #9
    def beam_search(self, init_state, init_logprobs, *args, **kwargs):

        # function computes the similarity score to be augmented
        def add_diversity(beam_seq_table, logprobsf, t, divm, diversity_lambda,
                          bdash):
            local_time = t - divm
            unaug_logprobsf = logprobsf.clone()
            for prev_choice in range(divm):
                prev_decisions = beam_seq_table[prev_choice][local_time]
                for sub_beam in range(bdash):
                    for prev_labels in range(bdash):
                        logprobsf[sub_beam][
                            prev_decisions[prev_labels]] = logprobsf[sub_beam][
                                prev_decisions[prev_labels]] - diversity_lambda
            return unaug_logprobsf

        # does one step of classical beam search

        def beam_step(logprobsf, unaug_logprobsf, beam_size, t, beam_seq,
                      beam_seq_logprobs, beam_logprobs_sum, state):
            #INPUTS:
            #logprobsf: probabilities augmented after diversity
            #beam_size: number of beams
            #t        : time instant
            #beam_seq : tensor containing the beams
            #beam_seq_logprobs: tensor containing the beam logprobs
            #beam_logprobs_sum: tensor containing joint logprobs
            #OUTPUTS:
            #beam_seq : tensor containing the word indices of the decoded captions
            #beam_seq_logprobs : log-probability of each decision made, same size as beam_seq
            #beam_logprobs_sum : joint log-probability of each beam

            ys, ix = torch.sort(logprobsf, 1, True)
            candidates = []
            cols = min(beam_size, ys.size(1))
            rows = beam_size
            if t == 0:
                rows = 1
            for c in range(cols):  # for each column (word, essentially)
                for q in range(rows):  # for each beam expansion
                    #compute logprob of expanding beam q with word in (sorted) position c
                    local_logprob = ys[q, c].item()
                    candidate_logprob = beam_logprobs_sum[q] + local_logprob
                    local_unaug_logprob = unaug_logprobsf[q, ix[q, c]]
                    candidates.append({
                        'c': ix[q, c],
                        'q': q,
                        'p': candidate_logprob,
                        'r': local_unaug_logprob
                    })
            candidates = sorted(candidates, key=lambda x: -x['p'])

            new_state = [_.clone() for _ in state]
            #beam_seq_prev, beam_seq_logprobs_prev
            if t >= 1:
                #we'll need these as reference when we fork beams around
                beam_seq_prev = beam_seq[:t].clone()
                beam_seq_logprobs_prev = beam_seq_logprobs[:t].clone()
            for vix in range(beam_size):
                v = candidates[vix]
                #fork beam index q into index vix
                if t >= 1:
                    beam_seq[:t, vix] = beam_seq_prev[:, v['q']]
                    beam_seq_logprobs[:t, vix] = beam_seq_logprobs_prev[:,
                                                                        v['q']]
                #rearrange recurrent states
                for state_ix in range(len(new_state)):
                    #  copy over state in previous beam q to new beam at vix
                    new_state[state_ix][:, vix] = state[state_ix][:, v[
                        'q']]  # dimension one is time step
                #append new end terminal at the end of this beam
                beam_seq[t, vix] = v['c']  # c'th word is the continuation
                beam_seq_logprobs[t, vix] = v['r']  # the raw logprob here
                beam_logprobs_sum[vix] = v[
                    'p']  # the new (sum) logprob along this beam
            state = new_state
            return beam_seq, beam_seq_logprobs, beam_logprobs_sum, state, candidates

        # Start diverse_beam_search
        opt = kwargs['opt']
        beam_size = opt.get('beam_size', 10)
        group_size = opt.get('group_size', 1)
        diversity_lambda = opt.get('diversity_lambda', 0.5)
        decoding_constraint = opt.get('decoding_constraint', 0)
        max_ppl = opt.get('max_ppl', 0)
        length_penalty = utils.penalty_builder(opt.get('length_penalty', ''))
        bdash = beam_size // group_size  # beam per group

        # INITIALIZATIONS
        beam_seq_table = [
            torch.LongTensor(self.seq_length, bdash).zero_()
            for _ in range(group_size)
        ]
        beam_seq_logprobs_table = [
            torch.FloatTensor(self.seq_length, bdash).zero_()
            for _ in range(group_size)
        ]
        beam_logprobs_sum_table = [
            torch.zeros(bdash) for _ in range(group_size)
        ]

        # logprobs # logprobs predicted in last time step, shape (beam_size, vocab_size+1)
        done_beams_table = [[] for _ in range(group_size)]
        state_table = [
            list(torch.unbind(_))
            for _ in torch.stack(init_state).chunk(group_size, 2)
        ]
        logprobs_table = list(init_logprobs.chunk(group_size, 0))
        # END INIT

        # Chunk elements in the args
        args = list(args)
        args = [
            _.chunk(group_size) if _ is not None else [None] * group_size
            for _ in args
        ]
        args = [[args[i][j] for i in range(len(args))]
                for j in range(group_size)]

        for t in range(self.seq_length + group_size - 1):
            for divm in range(group_size):
                if t >= divm and t <= self.seq_length + divm - 1:
                    # add diversity
                    logprobsf = logprobs_table[divm].data.float()
                    # suppress previous word
                    if decoding_constraint and t - divm > 0:
                        logprobsf.scatter_(
                            1, beam_seq_table[divm][t - divm -
                                                    1].unsqueeze(1).cuda(),
                            float('-inf'))
                    # suppress UNK tokens in the decoding
                    logprobsf[:, logprobsf.size(1) -
                              1] = logprobsf[:, logprobsf.size(1) - 1] - 1000
                    # diversity is added here
                    # the function directly modifies the logprobsf values and hence, we need to return
                    # the unaugmented ones for sorting the candidates in the end. # for historical
                    # reasons :-)
                    unaug_logprobsf = add_diversity(beam_seq_table, logprobsf,
                                                    t, divm, diversity_lambda,
                                                    bdash)

                    # infer new beams
                    beam_seq_table[divm],\
                    beam_seq_logprobs_table[divm],\
                    beam_logprobs_sum_table[divm],\
                    state_table[divm],\
                    candidates_divm = beam_step(logprobsf,
                                                unaug_logprobsf,
                                                bdash,
                                                t-divm,
                                                beam_seq_table[divm],
                                                beam_seq_logprobs_table[divm],
                                                beam_logprobs_sum_table[divm],
                                                state_table[divm])

                    # if time's up... or if end token is reached then copy beams
                    for vix in range(bdash):
                        if beam_seq_table[divm][
                                t - divm,
                                vix] == 0 or t == self.seq_length + divm - 1:
                            final_beam = {
                                'seq':
                                beam_seq_table[divm][:, vix].clone(),
                                'logps':
                                beam_seq_logprobs_table[divm][:, vix].clone(),
                                'unaug_p':
                                beam_seq_logprobs_table[divm]
                                [:, vix].sum().item(),
                                'p':
                                beam_logprobs_sum_table[divm][vix].item()
                            }
                            final_beam['p'] = length_penalty(
                                t - divm + 1, final_beam['p'])
                            # if max_ppl:
                            #     final_beam['p'] = final_beam['p'] / (t-divm+1)
                            done_beams_table[divm].append(final_beam)
                            # don't continue beams from finished sequences
                            beam_logprobs_sum_table[divm][vix] = -1000

                    # move the current group one step forward in time

                    it = beam_seq_table[divm][t - divm]
                    logprobs_table[divm], state_table[
                        divm] = self.get_logprobs_state(
                            it.cuda(), *(args[divm] + [state_table[divm]]))

        # all beams are sorted by their log-probabilities
        done_beams_table = [
            sorted(done_beams_table[i], key=lambda x: -x['p'])[:bdash]
            for i in range(group_size)
        ]
        done_beams = reduce(lambda a, b: a + b, done_beams_table)
        return done_beams
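The state_table initialization above stacks the recurrent state, chunks it into group_size pieces along dim 2 (the beam dimension), and unbinds each chunk back into a per-tensor list; the reshaping alone, with assumed illustrative shapes:

import torch

init_state = [torch.zeros(1, 6, 5) for _ in range(2)]  # e.g. (layers, beam, hidden)
group_size = 3
state_table = [list(torch.unbind(s))
               for s in torch.stack(init_state).chunk(group_size, 2)]
assert len(state_table) == group_size and state_table[0][0].shape == (1, 2, 5)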
Example #10
    def run_SAM(self, df_data, skeleton=None, **kwargs):
        """Execute the SAM model.
        :param df_data:
        """
        gpu = kwargs.get('gpu', False)
        gpu_no = kwargs.get('gpu_no', 0)
        categorical_variables = kwargs.get('categorical_variables', None)

        verbose = kwargs.get('verbose', True)
        plot = kwargs.get("plot", False)
        plot_generated_pair = kwargs.get("plot_generated_pair", False)

        d_str = "Epoch: {} -- Disc: {} -- Gen: {} -- L1: {}"

        if categorical_variables is None:
            warnings.warn("Dataset considered as numerical")
            categorical_variables = [
                False for i in range(len(df_data.columns))
            ]
        # list_nodes = list(df_data.columns)
        onehotdata = []
        for i, var_is_categorical in enumerate(categorical_variables):
            if var_is_categorical:
                onehotdata.append(
                    pd.get_dummies(df_data.iloc[:, i]).as_matrix())
            else:
                onehotdata.append(df_data.iloc[:, [i]].as_matrix())

        cat_size = [i.shape[1] for i in onehotdata]
        # cat_size.append(1)  # Noise

        df_data = np.concatenate(onehotdata, 1)

        data = df_data.astype('float32')
        self.data = df_data
        data = th.from_numpy(data)
        if self.batchsize == -1:
            self.batchsize = data.shape[0]
        rows, cols = data.size()
        # CAT data: cols override
        cols = len(cat_size)
        # Get the list of indexes to ignore
        if skeleton is not None:
            zero_components = [[] for i in range(cols)]
            for i, j in zip(*((1 - skeleton).nonzero())):
                zero_components[j].append(i)
        else:
            zero_components = [[i] for i in range(cols)]
        self.sam = SAM_generators((rows, cols),
                                  cat_size,
                                  zero_components,
                                  batch_norm=True,
                                  nh=self.nh,
                                  batch_size=self.batchsize,
                                  **kwargs)

        # Begin UGLY
        activation_function = kwargs.get('activation_function', th.nn.Tanh)
        try:
            del kwargs["activation_function"]
        except KeyError:
            pass
        self.discriminator_sam = SAM_discriminator(
            [sum(cat_size), self.dnh, self.dnh, 1],
            batch_norm=True,
            activation_function=th.nn.LeakyReLU,
            activation_argument=0.2,
            **kwargs)
        kwargs["activation_function"] = activation_function
        # End of UGLY

        if gpu:
            self.sam = self.sam.cuda(gpu_no)
            self.discriminator_sam = self.discriminator_sam.cuda(gpu_no)
            data = data.cuda(gpu_no)

        # Select parameters to optimize : ignore the non connected nodes
        criterion = th.nn.BCEWithLogitsLoss()
        g_optimizer = th.optim.Adam(self.sam.parameters(), lr=self.lr)
        d_optimizer = th.optim.Adam(self.discriminator_sam.parameters(),
                                    lr=self.dlr)

        true_variable = Variable(th.ones(self.batchsize, 1),
                                 requires_grad=False)
        false_variable = Variable(th.zeros(self.batchsize, 1),
                                  requires_grad=False)
        causal_filters = th.zeros(cols, cols)

        if gpu:
            true_variable = true_variable.cuda(gpu_no)
            false_variable = false_variable.cuda(gpu_no)
            causal_filters = causal_filters.cuda(gpu_no)

        data_iterator = DataLoader(data,
                                   batch_size=self.batchsize,
                                   shuffle=True,
                                   drop_last=True)

        # TRAIN
        for epoch in range(self.train + self.test):
            for i_batch, batch in enumerate(data_iterator):
                batch = Variable(batch)
                # print(batch.size())
                unbind_vectors = th.unbind(batch, 1)
                # print(cat_size)
                batch_vectors = [
                    th.stack(
                        unbind_vectors[sum(cat_size[:idx]
                                           ):sum(cat_size[:idx]) + i], 1) if
                    i > 1 else unbind_vectors[sum(cat_size[:idx])].unsqueeze(1)
                    for idx, i in enumerate(cat_size)
                ]
                g_optimizer.zero_grad()
                d_optimizer.zero_grad()

                # Train the discriminator
                generated_variables = self.sam(batch)
                # for i in generated_variables:
                #     print(i.size())
                # print(batch.size())
                disc_losses = []
                gen_losses = []
                # print([j.size() for j in batch_vectors])
                # print([j.size() for j in generated_variables])

                for i in range(cols):
                    generator_output = th.cat([
                        v for c in [
                            batch_vectors[:i], [generated_variables[i]],
                            batch_vectors[i + 1:]
                        ] for v in c
                    ], 1)

                    # 1. Train discriminator on fake
                    # print(i, generator_output.size())
                    disc_output_detached = self.discriminator_sam(
                        generator_output.detach())
                    disc_output = self.discriminator_sam(generator_output)
                    disc_losses.append(
                        criterion(disc_output_detached, false_variable))

                    # 2. Train the generator :
                    gen_losses.append(criterion(disc_output, true_variable))

                true_output = self.discriminator_sam(batch)
                adv_loss = sum(disc_losses)/cols + \
                    criterion(true_output, true_variable)
                gen_loss = sum(gen_losses)

                adv_loss.backward()
                d_optimizer.step()

                # 3. Compute filter regularization
                filters = th.stack(
                    [i.filter._filter[0, :-1].abs() for i in self.sam.blocks],
                    1)
                l1_reg = self.l1 * filters.sum()
                loss = gen_loss + l1_reg

                if verbose and not epoch % 20:

                    print(
                        str(i) + " " +
                        d_str.format(epoch,
                                     adv_loss.cpu().item(),
                                     gen_loss.cpu().item() / cols,
                                     l1_reg.cpu().item()))
                loss.backward()
                # STORE ASYMMETRY values for output
                if epoch >= self.train:
                    causal_filters.add_(filters.data)
                g_optimizer.step()

                if plot and i_batch == 0:
                    try:
                        ax.clear()
                        ax.plot(range(len(adv_plt)),
                                adv_plt,
                                "r-",
                                linewidth=1.5,
                                markersize=4,
                                label="Discriminator")
                        ax.plot(range(len(adv_plt)),
                                gen_plt,
                                "g-",
                                linewidth=1.5,
                                markersize=4,
                                label="Generators")
                        ax.plot(range(len(adv_plt)),
                                l1_plt,
                                "b-",
                                linewidth=1.5,
                                markersize=4,
                                label="L1-Regularization")
                        ax.plot(range(len(adv_plt)),
                                asym_plt,
                                "c-",
                                linewidth=1.5,
                                markersize=4,
                                label="Assym penalization")

                        plt.legend()

                        adv_plt.append(adv_loss.cpu().data[0])
                        gen_plt.append(gen_loss.cpu().data[0] / cols)
                        l1_plt.append(l1_reg.cpu().data[0])
                        asym_plt.append(asymmetry_reg.cpu().data[0])
                        plt.pause(0.0001)

                    except NameError:
                        plt.ion()
                        plt.figure()
                        plt.xlabel("Epoch")
                        plt.ylabel("Losses")

                        plt.pause(0.0001)

                        adv_plt = [adv_loss.cpu().data[0]]
                        gen_plt = [gen_loss.cpu().data[0] / cols]
                        l1_plt = [l1_reg.cpu().data[0]]

                elif plot:
                    adv_plt.append(adv_loss.cpu().data[0])
                    gen_plt.append(gen_loss.cpu().data[0] / cols)
                    l1_plt.append(l1_reg.cpu().data[0])

                if plot_generated_pair and i_batch == 0:
                    if epoch == 0:
                        plt.ion()
                        to_print = [[0,
                                     1]]  # , [1, 0]]  # [2, 3]]  # , [11, 17]]
                        plt.clf()
                    for (i, j) in to_print:

                        plt.scatter(generated_variables[i].data.cpu().numpy(),
                                    batch.data.cpu().numpy()[:, j],
                                    label="Y -> X")
                        plt.scatter(batch.data.cpu().numpy()[:, i],
                                    generated_variables[j].data.cpu().numpy(),
                                    label="X -> Y")

                        plt.scatter(batch.data.cpu().numpy()[:, i],
                                    batch.data.cpu().numpy()[:, j],
                                    label="original data")
                        plt.legend()

                    plt.pause(0.01)

        return causal_filters.div_(self.test).cpu().numpy()
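The batch_vectors construction above regroups unbound columns into per-variable blocks of one-hot widths; a reduced sketch with a hypothetical cat_size:

import torch as th

batch = th.randn(8, 5)                     # five concatenated one-hot / numeric columns
cat_size = [2, 1, 2]                       # hypothetical per-variable widths
cols = th.unbind(batch, 1)
batch_vectors = [
    th.stack(cols[sum(cat_size[:idx]):sum(cat_size[:idx]) + i], 1)
    if i > 1 else cols[sum(cat_size[:idx])].unsqueeze(1)
    for idx, i in enumerate(cat_size)
]
assert [v.shape[1] for v in batch_vectors] == cat_size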
Example #11
def apply(func, apply_dimension):
    ''' Applies a function along a given dimension '''
    output_list = [func(m) for m in torch.unbind(apply_dimension, dim=0)]
    return torch.stack(output_list, dim=0)
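A quick usage check for apply as defined above, summing each dim-0 slice:

import torch

x = torch.arange(6.).view(3, 2)
row_sums = apply(lambda row: row.sum(), x)
assert torch.equal(row_sums, torch.tensor([1., 5., 9.]))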
Example #12
    def _take_step(self, indices, context1, context1_):

        num_tasks = len(indices)

        # data is (task, batch, feat)
        obs, actions, rewards, next_obs, terms = self.sample_data(indices)
        t, b, _ = obs.size()
        # run inference in networks
        policy_outputs, task_z = self.agent(obs, context1)
        new_actions, policy_mean, policy_log_std, log_pi = policy_outputs[:4]
        policy_mean = policy_mean.view(t, b, -1)
        policy_log_std = policy_log_std.view(t, b, -1)

        #positive samples
        policy_outputs_1, task_z_1 = self.agent(obs, context1_)
        new_actions_1, policy_mean_1, policy_log_std_1, log_pi_1 = policy_outputs_1[:4]
        policy_mean_1 = policy_mean_1.view(t, b, -1)  #(task, batch, feat)
        policy_log_std_1 = policy_log_std_1.view(t, b,
                                                 -1)  #(task, batch, feat)

        # compute policy KL divergence
        policy1 = [[
            torch.distributions.Normal(mu, torch.exp(log_std))
            for mu, log_std in zip(torch.unbind(policy_mean[i]),
                                   torch.unbind(policy_log_std[i]))
        ] for i in range(t)]
        policy2 = [[
            torch.distributions.Normal(mu, torch.exp(log_std))
            for mu, log_std in zip(torch.unbind(policy_mean_1[i]),
                                   torch.unbind(policy_log_std_1[i]))
        ] for i in range(t)]
        kl_divs = [[
            torch.distributions.kl.kl_divergence(policy1[i][j], policy2[i][j])
            for j in range(b)
        ] for i in range(t)]
        kl_div_sum = torch.mean(torch.stack(kl_divs))
        kl_divs_final = []
        for i in range(len(kl_divs)):
            kl_divs_final.append(torch.mean(torch.stack(kl_divs[i])))
        kl_policy_final = torch.mean(torch.stack(kl_divs_final))
        kl_policy_loss = 1 * kl_policy_final

        self.encoder_optimizer.zero_grad()
        self.policy_optimizer.zero_grad()
        kl_policy_loss.backward(retain_graph=True)
        # flattens out the task dimension
        t, b, _ = obs.size()
        obs = obs.view(t * b, -1)
        actions = actions.view(t * b, -1)
        next_obs = next_obs.view(t * b, -1)

        # Q and V networks
        # encoder will only get gradients from Q nets
        q1_pred = self.qf1(obs, actions, task_z)
        q2_pred = self.qf2(obs, actions, task_z)
        v_pred = self.vf(obs, task_z.detach())
        # get targets for use in V and Q updates
        with torch.no_grad():
            target_v_values = self.target_vf(next_obs, task_z)

        # KL constraint on z if probabilistic
        if self.use_information_bottleneck:
            kl_div = self.agent.compute_kl_div()
            kl_loss = self.kl_lambda * kl_div
            kl_loss.backward(retain_graph=True)

        # qf and encoder update (note encoder does not get grads from policy or vf)
        self.qf1_optimizer.zero_grad()
        self.qf2_optimizer.zero_grad()
        rewards_flat = rewards.view(self.batch_size * num_tasks, -1)
        # scale rewards for Bellman update
        rewards_flat = rewards_flat * self.reward_scale
        terms_flat = terms.view(self.batch_size * num_tasks, -1)
        q_target = rewards_flat + (
            1. - terms_flat) * self.discount * target_v_values
        qf_loss = torch.mean((q1_pred - q_target)**2) + torch.mean(
            (q2_pred - q_target)**2)
        qf_loss.backward()
        self.qf1_optimizer.step()
        self.qf2_optimizer.step()
        self.encoder_optimizer.step()

        # compute min Q on the new actions
        min_q_new_actions = self._min_q(obs, new_actions, task_z)

        # vf update
        v_target = min_q_new_actions - log_pi
        vf_loss = self.vf_criterion(v_pred, v_target.detach())
        self.vf_optimizer.zero_grad()
        vf_loss.backward()
        self.vf_optimizer.step()
        self._update_target_network()

        # policy update
        # n.b. policy update includes dQ/da

        log_policy_target = min_q_new_actions

        policy_loss = (log_pi - log_policy_target).mean()

        mean_reg_loss = self.policy_mean_reg_weight * (policy_mean**2).mean()
        std_reg_loss = self.policy_std_reg_weight * (policy_log_std**2).mean()
        pre_tanh_value = policy_outputs[-1]
        pre_activation_reg_loss = self.policy_pre_activation_weight * (
            (pre_tanh_value**2).sum(dim=1).mean())
        policy_reg_loss = mean_reg_loss + std_reg_loss + pre_activation_reg_loss
        policy_loss = policy_loss + policy_reg_loss
        #self.policy_optimizer.zero_grad()
        policy_loss.backward()
        self.policy_optimizer.step()

        # save some statistics for eval
        if self.eval_statistics is None:
            # eval should set this to None.
            # this way, these statistics are only computed for one batch.
            self.eval_statistics = OrderedDict()
            if self.use_information_bottleneck:
                z_mean = np.mean(np.abs(ptu.get_numpy(self.agent.z_means[0])))
                z_sig = np.mean(ptu.get_numpy(self.agent.z_vars[0]))
                self.eval_statistics['Z mean train'] = z_mean
                self.eval_statistics['Z variance train'] = z_sig
                #self.eval_statistics['KL Divergence'] = ptu.get_numpy(kl_div)
                #self.eval_statistics['KL Loss'] = ptu.get_numpy(kl_loss)
            self.eval_statistics['Contrastive Loss'] = np.mean(
                ptu.get_numpy(loss))
            self.eval_statistics['QF Loss'] = np.mean(ptu.get_numpy(qf_loss))
            self.eval_statistics['VF Loss'] = np.mean(ptu.get_numpy(vf_loss))
            self.eval_statistics['Policy Loss'] = np.mean(
                ptu.get_numpy(policy_loss))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Q Predictions',
                    ptu.get_numpy(q1_pred),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'V Predictions',
                    ptu.get_numpy(v_pred),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Log Pis',
                    ptu.get_numpy(log_pi),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Policy mu',
                    ptu.get_numpy(policy_mean),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Policy log std',
                    ptu.get_numpy(policy_log_std),
                ))
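The KL computation above builds one Normal per batch row by unbinding means and log-stds in parallel; the core of that construction, reduced to a single task:

import torch

policy_mean = torch.zeros(4, 2)
policy_log_std = torch.zeros(4, 2)
dists = [torch.distributions.Normal(mu, torch.exp(log_std))
         for mu, log_std in zip(torch.unbind(policy_mean),
                                torch.unbind(policy_log_std))]
assert len(dists) == 4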
Example #13
    def gradient3DForBboxFace(self, emb3D_scenes, bbox, scores):
        # emb3D_scenes should be B x C x D x H x W
        dz_batch, dy_batch, dx_batch = utils_basic.gradient3D(emb3D_scenes,
                                                              absolute=False,
                                                              square=False)

        bbox = torch.clamp(bbox, min=0)
        sizes_val = [hyp.Z2 - 1, hyp.Y2 - 1, hyp.X2 - 1]

        gs_loss_list = []  # gradient smoothness loss

        for index_batch, emb_scene in enumerate(emb3D_scenes):
            gsloss = 0
            dz, dy, dx = dz_batch[index_batch:index_batch + 1], dy_batch[
                index_batch:index_batch +
                1], dx_batch[index_batch:index_batch + 1]
            for index_box, box in enumerate(bbox[index_batch]):
                if scores[index_batch][index_box] > 0:

                    lower, upper = torch.unbind(box)
                    lower = [torch.floor(i).to(torch.int32) for i in lower]
                    upper = [torch.ceil(i).to(torch.int32) for i in upper]
                    xmin, ymin, zmin = [max(i, 0) for i in lower]

                    xmax, ymax, zmax = [
                        min(i, sizes_val[index])
                        for index, i in enumerate(upper)
                    ]

                    #zmin face
                    gsloss += self.get_gradient_loss_on_bbox_surface(
                        dz, zmin, zmin + 1, ymin, ymax, xmin, xmax)
                    if zmin < sizes_val[0]:
                        gsloss += self.get_gradient_loss_on_bbox_surface(
                            dz, zmin + 1, zmin + 2, ymin, ymax, xmin, xmax)

                    #zmax face
                    gsloss += self.get_gradient_loss_on_bbox_surface(
                        dz, zmax, zmax + 1, ymin, ymax, xmin, xmax)
                    if zmax < sizes_val[0]:
                        gsloss += self.get_gradient_loss_on_bbox_surface(
                            dz, zmax + 1, zmax + 2, ymin, ymax, xmin, xmax)

                    #ymin face
                    gsloss += self.get_gradient_loss_on_bbox_surface(
                        dy, zmin, zmax, ymin, ymin + 1, xmin, xmax)
                    if ymin < sizes_val[1]:
                        gsloss += self.get_gradient_loss_on_bbox_surface(
                            dy, zmin, zmax, ymin + 1, ymin + 2, xmin, xmax)

                    #ymax face
                    gsloss += self.get_gradient_loss_on_bbox_surface(
                        dy, zmin, zmax, ymax, ymax + 1, xmin, xmax)
                    if ymax < sizes_val[1]:
                        gsloss += self.get_gradient_loss_on_bbox_surface(
                            dy, zmin, zmax, ymax + 1, ymax + 2, xmin, xmax)

                    #xmin face
                    gsloss += self.get_gradient_loss_on_bbox_surface(
                        dx, zmin, zmax, ymin, ymax, xmin, xmin + 1)
                    if xmin < sizes_val[2]:
                        gsloss += self.get_gradient_loss_on_bbox_surface(
                            dx, zmin, zmax, ymin, ymax, xmin + 1, xmin + 2)

                    #xmax face
                    gsloss += self.get_gradient_loss_on_bbox_surface(
                        dx, zmin, zmax, ymin, ymax, xmax, xmax + 1)
                    if xmax < sizes_val[2]:
                        gsloss += self.get_gradient_loss_on_bbox_surface(
                            dx, zmin, zmax, ymin, ymax, xmax + 1, xmax + 2)

            gs_loss_list.append(gsloss)

        gsloss = torch.mean(torch.tensor(gs_loss_list))
        return gsloss
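lower, upper = torch.unbind(box) splits a (2, 3) box tensor into its lower and upper corner rows; in isolation:

import torch

box = torch.tensor([[0., 1., 2.], [4., 5., 6.]])   # (lower, upper) corner rows
lower, upper = torch.unbind(box)
assert lower.tolist() == [0.0, 1.0, 2.0] and upper.tolist() == [4.0, 5.0, 6.0]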
Example #14
def train_point(model,
                data_loader,
                val_data_loader,
                config,
                transform_data_fn=None):

    device = get_torch_device(config.is_cuda)
    # Set up the train flag for batch normalization
    model.train()

    # Configuration
    data_timer, iter_timer = Timer(), Timer()
    data_time_avg, iter_time_avg = AverageMeter(), AverageMeter()
    losses, scores = AverageMeter(), AverageMeter()

    optimizer = initialize_optimizer(model.parameters(), config)
    scheduler = initialize_scheduler(optimizer, config)
    criterion = nn.CrossEntropyLoss(ignore_index=-1)

    # Train the network
    logging.info('===> Start training')
    best_val_miou, best_val_iter, curr_iter, epoch, is_training = 0, 0, 1, 1, True

    if config.resume:
        checkpoint_fn = config.resume + '/weights.pth'
        if osp.isfile(checkpoint_fn):
            logging.info("=> loading checkpoint '{}'".format(checkpoint_fn))
            state = torch.load(checkpoint_fn)
            curr_iter = state['iteration'] + 1
            epoch = state['epoch']
            d = {
                k: v
                for k, v in state['state_dict'].items() if 'map' not in k
            }
            model.load_state_dict(d)
            if config.resume_optimizer:
                scheduler = initialize_scheduler(optimizer,
                                                 config,
                                                 last_step=curr_iter)
                optimizer.load_state_dict(state['optimizer'])
            if 'best_val' in state:
                best_val_miou = state['best_val']
                best_val_iter = state['best_val_iter']
            logging.info("=> loaded checkpoint '{}' (epoch {})".format(
                checkpoint_fn, state['epoch']))
        else:
            raise ValueError(
                "=> no checkpoint found at '{}'".format(checkpoint_fn))

    data_iter = iter(data_loader)
    while is_training:

        num_class = 20
        total_correct_class = torch.zeros(num_class, device=device)
        total_iou_deno_class = torch.zeros(num_class, device=device)

        for iteration in range(len(data_loader) // config.iter_size):
            optimizer.zero_grad()
            data_time, batch_loss = 0, 0
            iter_timer.tic()
            for sub_iter in range(config.iter_size):
                # Get training data
                data = next(data_iter)
                points, target, sample_weight = data
                if config.pure_point:

                    sinput = points.transpose(1, 2).cuda().float()

                    # DEBUG: use the discrete coord for point-based
                    '''

                        feats = torch.unbind(points[:,:,:], dim=0)
                        voxel_size = config.voxel_size
                        coords = torch.unbind(points[:,:,:3]/voxel_size, dim=0)  # 0.05 is the voxel-size
                        coords, feats= ME.utils.sparse_collate(coords, feats)
                        # assert feats.reshape([16, 4096, -1]) == points[:,:,3:]
                        points_ = ME.TensorField(features=feats.float(), coordinates=coords, device=device)
                        tmp_voxel = points_.sparse()
                        sinput_ = tmp_voxel.slice(points_)
                        sinput = torch.cat([sinput_.C[:,1:]*config.voxel_size, sinput_.F[:,3:]],dim=1).reshape([config.batch_size, config.num_points, 6])
                        # sinput = sinput_.F.reshape([config.batch_size, config.num_points, 6])
                        sinput = sinput.transpose(1,2).cuda().float()

                        # sinput = torch.cat([coords[:,1:], feats],dim=1).reshape([config.batch_size, config.num_points, 6])
                        # sinput = sinput.transpose(1,2).cuda().float()
                        '''

                    # For some networks, making the network invariant to even, odd coords is important
                    # coords[:, 1:] += (torch.rand(3) * 100).type_as(coords)

                    # Preprocess input
                    # if config.normalize_color:
                    # feats = feats / 255. - 0.5

                    # torch.save(points[:,:,:3], './sandbox/tensorfield-c.pth')
                    # torch.save(points_.C, './sandbox/points-c.pth')

                else:
                    # feats = torch.unbind(points[:,:,3:], dim=0) # WRONG: should also feed in xyz as input feature
                    voxel_size = config.voxel_size
                    coords = torch.unbind(points[:, :, :3] / voxel_size,
                                          dim=0)  # 0.05 is the voxel-size
                    # Normalize the xyz in feature
                    # points[:,:,:3] = points[:,:,:3] / points[:,:,:3].mean()
                    feats = torch.unbind(points[:, :, :], dim=0)
                    coords, feats = ME.utils.sparse_collate(coords, feats)

                    # For some networks, making the network invariant to even, odd coords is important
                    coords[:, 1:] += (torch.rand(3) * 100).type_as(coords)

                    # Preprocess input
                    # if config.normalize_color:
                    # feats = feats / 255. - 0.5

                    # they are the same
                    points_ = ME.TensorField(features=feats.float(),
                                             coordinates=coords,
                                             device=device)
                    # points_1 = ME.TensorField(features=feats.float(), coordinates=coords, device=device, quantization_mode=ME.SparseTensorQuantizationMode.UNWEIGHTED_AVERAGE)
                    # points_2 = ME.TensorField(features=feats.float(), coordinates=coords, device=device, quantization_mode=ME.SparseTensorQuantizationMode.RANDOM_SUBSAMPLE)
                    sinput = points_.sparse()

                data_time += data_timer.toc(False)
                B, npoint = target.shape

                # model.initialize_coords(*init_args)
                soutput = model(sinput)
                if config.pure_point:
                    soutput = soutput.reshape([B * npoint, -1])
                else:
                    soutput = soutput.slice(points_).F
                    # s1 = soutput.slice(points_)
                    # print(soutput.quantization_mode)
                    # soutput.quantization_mode = ME.SparseTensorQuantizationMode.RANDOM_SUBSAMPLE
                    # s2 = soutput.slice(points_)

                # The output of the network is not sorted
                target = (target - 1).view(-1).long().to(device)

                # catch NAN
                if torch.isnan(soutput).sum() > 0:
                    import ipdb
                    ipdb.set_trace()

                loss = criterion(soutput, target)

                if torch.isnan(loss).sum() > 0:
                    import ipdb
                    ipdb.set_trace()

                loss = (loss * sample_weight.to(device)).mean()

                # Compute and accumulate gradient
                loss /= config.iter_size
                batch_loss += loss.item()
                loss.backward()
                # print(model.input_mlp[0].weight.max())
                # print(model.input_mlp[0].weight.grad.max())

            # Update number of steps
            optimizer.step()
            scheduler.step()

            # CLEAR CACHE!
            torch.cuda.empty_cache()

            data_time_avg.update(data_time)
            iter_time_avg.update(iter_timer.toc(False))

            pred = get_prediction(data_loader.dataset, soutput, target)
            score = precision_at_one(pred, target, ignore_label=-1)
            losses.update(batch_loss, target.size(0))
            scores.update(score, target.size(0))

            # Calc the iou
            for l in range(num_class):
                total_correct_class[l] += ((pred == l) & (target == l)).sum()
                total_iou_deno_class[l] += (((pred == l) & (target >= 0)) |
                                            (target == l)).sum()

            if curr_iter >= config.max_iter:
                is_training = False
                break

            if curr_iter % config.stat_freq == 0 or curr_iter == 1:
                lrs = ', '.join(
                    ['{:.3e}'.format(x) for x in scheduler.get_lr()])
                debug_str = "===> Epoch[{}]({}/{}): Loss {:.4f}\tLR: {}\t".format(
                    epoch, curr_iter,
                    len(data_loader) // config.iter_size, losses.avg, lrs)
                debug_str += "Score {:.3f}\tData time: {:.4f}, Iter time: {:.4f}".format(
                    scores.avg, data_time_avg.avg, iter_time_avg.avg)
                logging.info(debug_str)
                # Reset timers
                data_time_avg.reset()
                iter_time_avg.reset()
                # Write logs
                losses.reset()
                scores.reset()

            # Save current status, save before val to prevent occasional mem overflow
            if curr_iter % config.save_freq == 0:
                checkpoint(model,
                           optimizer,
                           epoch,
                           curr_iter,
                           config,
                           best_val_miou,
                           best_val_iter,
                           save_inter=True)

            # Validation:
            # for point-based should use alternate dataloader for eval
            # if curr_iter % config.val_freq == 0:
            # val_miou = test_points(model, val_data_loader, None, curr_iter, config, transform_data_fn)
            # if val_miou > best_val_miou:
            # best_val_miou = val_miou
            # best_val_iter = curr_iter
            # checkpoint(model, optimizer, epoch, curr_iter, config, best_val_miou, best_val_iter,
            # "best_val")
            # logging.info("Current best mIoU: {:.3f} at iter {}".format(best_val_miou, best_val_iter))

            # # Recover back
            # model.train()

            # End of iteration
            curr_iter += 1

        IoU = (total_correct_class) / (total_iou_deno_class + 1e-6)
        logging.info('train point avg class IoU: %f' % ((IoU).mean() * 100.))

        epoch += 1

    # Explicit memory cleanup
    if hasattr(data_iter, 'cleanup'):
        data_iter.cleanup()

    # Save the final model
    checkpoint(model, optimizer, epoch, curr_iter, config, best_val_miou,
               best_val_iter)

    val_miou = test_points(model, val_data_loader, config)
    if val_miou > best_val_miou:
        best_val_miou = val_miou
        best_val_iter = curr_iter
        checkpoint(model, optimizer, epoch, curr_iter, config, best_val_miou,
                   best_val_iter, "best_val")
    logging.info("Current best mIoU: {:.3f} at iter {}".format(
        best_val_miou, best_val_iter))
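ME.utils.sparse_collate consumes per-sample lists, which is why the batched points are unbound along dim 0 before collation; the unbind step by itself, with assumed shapes:

import torch

points = torch.randn(2, 4096, 6)                        # (batch, num_points, xyz+feat)
voxel_size = 0.05
coords = torch.unbind(points[:, :, :3] / voxel_size, dim=0)
feats = torch.unbind(points, dim=0)
assert len(coords) == 2 and coords[0].shape == (4096, 3)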
Example #15
    def __call__(self, boxes, scores, indices, stage):
        """Perform NMS-based selection of detections

        Parameters
        ----------
        boxes : sequence of torch.Tensor
            Sequence of N tensors of class-specific bounding boxes with shapes M_i x C x 4, entries can be None
        scores : sequence of torch.Tensor
            Sequence of N tensors of class probabilities with shapes M_i x (C + 1), entries can be None
        indices : sequence of torch.Tensor
            Sequence of N tensors of candidate indices matching the boxes, entries can be None
        stage : int
            Index of the NMS threshold to apply at this stage

        Returns
        -------
        bbx_pred : PackedSequence
            A sequence of N tensors of bounding boxes with shapes S_i x 4, entries are None for images in which no
            detection can be kept according to the selection parameters
        cls_pred : PackedSequence
            A sequence of N tensors of thing class predictions with shapes S_i, entries are None for images in which no
            detection can be kept according to the selection parameters
        obj_pred : PackedSequence
            A sequence of N tensors of detection confidences with shapes S_i, entries are None for images in which no
            detection can be kept according to the selection parameters
        """

        bbx_pred, cls_pred, obj_pred, indices_pred = [], [], [], []

        for bbx_i, obj_i, index_i in zip(boxes, scores, indices):
            try:
                if bbx_i is None or obj_i is None:
                    raise Empty

                # Do NMS separately for each class
                bbx_pred_i, cls_pred_i, obj_pred_i, indices_pred_i = [], [], [], []
                for cls_id, (bbx_cls_i, obj_cls_i, index_cls_i) in enumerate(
                        zip(torch.unbind(bbx_i, dim=1),
                            torch.unbind(obj_i, dim=1)[1:],
                            torch.unbind(index_i, dim=1))):
                    # Filter out low-scoring predictions
                    idx = obj_cls_i > self.score_threshold
                    if not idx.any().item():
                        continue
                    bbx_cls_i = bbx_cls_i[idx]
                    obj_cls_i = obj_cls_i[idx]
                    index_cls_i = index_cls_i[idx]

                    # Filter out empty predictions
                    idx = (bbx_cls_i[:, 2] > bbx_cls_i[:, 0]) & (
                        bbx_cls_i[:, 3] > bbx_cls_i[:, 1])
                    if not idx.any().item():
                        continue
                    bbx_cls_i = bbx_cls_i[idx]
                    obj_cls_i = obj_cls_i[idx]
                    index_cls_i = index_cls_i[idx]

                    # Do NMS
                    idx = nms(bbx_cls_i.contiguous(),
                              obj_cls_i.contiguous(),
                              threshold=self.nms_threshold[stage],
                              n_max=-1)
                    if idx.numel() == 0:
                        continue
                    bbx_cls_i = bbx_cls_i[idx]
                    obj_cls_i = obj_cls_i[idx]
                    index_cls_i = index_cls_i[idx]

                    # Save remaining outputs
                    bbx_pred_i.append(bbx_cls_i)
                    cls_pred_i.append(
                        bbx_cls_i.new_full((bbx_cls_i.size(0), ),
                                           cls_id,
                                           dtype=torch.long))
                    obj_pred_i.append(obj_cls_i)
                    indices_pred_i.append(index_cls_i)

                # Compact predictions from the classes
                if len(bbx_pred_i) == 0:
                    raise Empty
                bbx_pred_i = torch.cat(bbx_pred_i, dim=0)
                cls_pred_i = torch.cat(cls_pred_i, dim=0)
                obj_pred_i = torch.cat(obj_pred_i, dim=0)
                indices_pred_i = torch.cat(indices_pred_i, dim=0)

                # Do post-NMS selection (if needed)
                if bbx_pred_i.size(0) > self.max_predictions:
                    _, idx = obj_pred_i.topk(self.max_predictions)
                    bbx_pred_i = bbx_pred_i[idx]
                    cls_pred_i = cls_pred_i[idx]
                    obj_pred_i = obj_pred_i[idx]
                    indices_pred_i = indices_pred_i[idx]

                # Save results
                bbx_pred.append(bbx_pred_i)
                cls_pred.append(cls_pred_i)
                obj_pred.append(obj_pred_i)
                indices_pred.append(indices_pred_i)

            except Empty:
                bbx_pred.append(None)
                cls_pred.append(None)
                obj_pred.append(None)
                indices_pred.append(None)

        return PackedSequence(bbx_pred), PackedSequence(
            cls_pred), PackedSequence(obj_pred), PackedSequence(indices_pred)
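The per-class loop unbinds boxes and scores along the class dimension, dropping the background column of the scores with [1:]; the slicing in isolation:

import torch

boxes = torch.randn(10, 5, 4)    # M x C x 4 class-specific boxes
scores = torch.rand(10, 6)       # M x (C + 1); column 0 is the background class
for bbx_cls_i, obj_cls_i in zip(torch.unbind(boxes, dim=1),
                                torch.unbind(scores, dim=1)[1:]):
    assert bbx_cls_i.shape == (10, 4) and obj_cls_i.shape == (10,)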
Example #16
 def unbind(self, dim=0):
     results = torch.unbind(self._tensor, dim)
     results = tuple(CUDALongTensor(t) for t in results)
     return results
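A minimal stand-in showing why the wrapper re-wraps each unbound slice (WrappedTensor is a hypothetical thin wrapper, not the real CUDALongTensor class from the source project):

import torch

class WrappedTensor:
    """Hypothetical wrapper mirroring the re-wrapping pattern above."""
    def __init__(self, tensor):
        self._tensor = tensor

    def unbind(self, dim=0):
        # torch.unbind returns raw tensors, so each slice is wrapped again.
        return tuple(WrappedTensor(t) for t in torch.unbind(self._tensor, dim))

slices = WrappedTensor(torch.zeros(3, 2)).unbind(0)
assert all(isinstance(s, WrappedTensor) for s in slices)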
Example #17
def apply_along_axis(func, M, dim):
    tList = [func(m) for m in torch.unbind(M, dim)]
    res = torch.stack(tList, dim).to(device=M.device)
    return res
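Usage sketch for apply_along_axis, normalizing each dim-0 slice:

import torch

M = torch.arange(6.).view(3, 2)
normed = apply_along_axis(lambda m: m / m.sum(), M, 0)  # normalize each row
assert normed.shape == M.shape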
Example #18
    import numpy as np
    import cv2

    MVSDataset = find_dataset_def("dtu_yao")
    dataset = MVSDataset("/data1/Dataset/mvs_training/dtu/",
                         '../lists/dtu/train.txt', 'train', 3, 256)
    dataloader = DataLoader(dataset, batch_size=2)
    item = next(iter(dataloader))

    imgs = item["imgs"][:, :, :, ::4, ::4].cuda()
    proj_matrices = item["proj_matrices"].cuda()
    mask = item["mask"].cuda()
    depth = item["depth"].cuda()
    depth_values = item["depth_values"].cuda()

    imgs = torch.unbind(imgs, 1)
    proj_matrices = torch.unbind(proj_matrices, 1)
    ref_img, src_imgs = imgs[0], imgs[1:]
    ref_proj, src_projs = proj_matrices[0], proj_matrices[1:]

    warped_imgs = homo_warping2(src_imgs[0], src_projs[0], ref_proj,
                                depth_values)

    cv2.imwrite(
        '../tmp/ref.png',
        ref_img.permute([0, 2, 3, 1])[0].detach().cpu().numpy()[:, :, ::-1] *
        255)
    cv2.imwrite(
        '../tmp/src.png', src_imgs[0].permute(
            [0, 2, 3, 1])[0].detach().cpu().numpy()[:, :, ::-1] * 255)
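Unbinding along dim 1 turns the (B, V, C, H, W) stack of views into V per-view tensors, which is how ref_img and src_imgs are separated above; a minimal check:

import torch

imgs = torch.randn(2, 3, 3, 8, 8)        # (batch, views, channels, H, W)
views = torch.unbind(imgs, 1)
ref_img, src_imgs = views[0], views[1:]
assert ref_img.shape == (2, 3, 8, 8) and len(src_imgs) == 2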
Example #19
utils.add_tests(ShapeTestCase,
    basename = 'test_split_size',
    fn = lambda x: bf.split(x, 3, 0),
    torch_fn = lambda x: torch.split(x, 3, 0),
    inputs = utils.random_inputs([
        [(9)],
        [(9,2)]
    ])
)

utils.add_tests(ShapeTestCase,
    basename = 'test_split_indices',
    fn = lambda x: bf.split(x, [2, 7], 0),
    torch_fn = lambda x: torch.split(x, [2, 5, 2], 0),
    inputs = utils.random_inputs([
        [(9)],
        [(9,2)]
    ])
)

utils.add_tests(ShapeTestCase,
    basename = 'test_unstack',
    fn = lambda x: bf.unstack(x, 0),
    torch_fn = lambda x: torch.unbind(x, 0),
    inputs = utils.random_inputs([
        [(5,3)],
        [(5,2,3)]
    ])
)
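bf.unstack is tested against torch.unbind here; the relationship between unbind and split that the suite exercises can be stated directly:

import torch

x = torch.randn(5, 3)
# unbind(x, 0) matches splitting into size-1 chunks and squeezing the split dim
assert all(torch.equal(u, s.squeeze(0))
           for u, s in zip(torch.unbind(x, 0), torch.split(x, 1, 0)))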


Example #20
    def forward(self, x, h0, x_mask):
        start_scores_list = []
        end_scores_list = []
        for turn in range(self.num_turn):
            st_scores = self.attn_b(x, h0, x_mask)
            start_scores_list.append(st_scores)
            if self.answer_opt == 3:
                ptr_net_b = torch.bmm(F.softmax(st_scores, 1).unsqueeze(1),
                                      x).squeeze(1)
                ptr_net_b = self.dropout(ptr_net_b)
                xb = ptr_net_b if self.proj is None else self.proj(ptr_net_b)
                end_scores = self.attn_e(x, h0 + xb, x_mask)
                ptr_net_e = torch.bmm(
                    F.softmax(end_scores, 1).unsqueeze(1), x).squeeze(1)
                ptr_net_in = (ptr_net_b + ptr_net_e) / 2.0
                h0 = self.dropout(h0)
                h0 = self.rnn(ptr_net_in, h0)
            elif self.answer_opt == 2:
                ptr_net_b = torch.bmm(F.softmax(st_scores, 1).unsqueeze(1),
                                      x).squeeze(1)
                ptr_net_b = self.dropout(ptr_net_b)
                xb = ptr_net_b if self.proj is None else self.proj(ptr_net_b)
                end_scores = self.attn_e(x, xb, x_mask)
                ptr_net_e = torch.bmm(
                    F.softmax(end_scores, 1).unsqueeze(1), x).squeeze(1)
                ptr_net_in = ptr_net_e
                h0 = self.dropout(h0)
                h0 = self.rnn(ptr_net_in, h0)
            elif self.answer_opt == 1:
                ptr_net_b = torch.bmm(F.softmax(st_scores, 1).unsqueeze(1),
                                      x).squeeze(1)
                ptr_net_b = self.dropout(ptr_net_b)
                h0 = self.dropout(h0)
                ptr_net_in = ptr_net_b
                h1 = self.rnn(ptr_net_in, h0)
                end_scores = self.attn_e(x, h1, x_mask)
                h0 = h1
            else:
                end_scores = self.attn_e(x, h0, x_mask)
                ptr_net_e = torch.bmm(
                    F.softmax(end_scores, 1).unsqueeze(1), x).squeeze(1)
                ptr_net_in = ptr_net_e
                h0 = self.dropout(h0)
                h0 = self.rnn(ptr_net_in, h0)
            end_scores_list.append(end_scores)

        if self.mem_type == 1:
            mask = generate_mask(self.alpha.data.new(x.size(0), self.num_turn),
                                 self.mem_random_drop, self.training)
            mask = [m.contiguous() for m in torch.unbind(mask, 1)]
            start_scores_list = [
                mask[idx].view(x.size(0), 1).expand_as(inp) *
                F.softmax(inp, 1) for idx, inp in enumerate(start_scores_list)
            ]
            end_scores_list = [
                mask[idx].view(x.size(0), 1).expand_as(inp) *
                F.softmax(inp, 1) for idx, inp in enumerate(end_scores_list)
            ]
            start_scores = torch.stack(start_scores_list, 2)
            end_scores = torch.stack(end_scores_list, 2)
            start_scores = torch.mean(start_scores, 2)
            end_scores = torch.mean(end_scores, 2)
            start_scores.data.masked_fill_(x_mask.data, SMALL_POS_NUM)
            end_scores.data.masked_fill_(x_mask.data, SMALL_POS_NUM)
            start_scores = torch.log(start_scores)
            end_scores = torch.log(end_scores)
        else:
            start_scores = start_scores_list[-1]
            end_scores = end_scores_list[-1]

        return start_scores, end_scores
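The mem_type == 1 branch averages the softmaxed scores over turns under a per-turn drop mask. A minimal sketch of that unbind-then-weight pattern, with a hypothetical Bernoulli draw standing in for generate_mask:

import torch
import torch.nn.functional as F

batch, seq_len, num_turn = 4, 10, 3
scores_list = [torch.randn(batch, seq_len) for _ in range(num_turn)]

# one mask column per turn; a Bernoulli draw replaces generate_mask here
mask = torch.bernoulli(torch.full((batch, num_turn), 0.8))
mask = [m.contiguous() for m in torch.unbind(mask, 1)]  # num_turn tensors of shape (batch,)

weighted = [
    mask[i].view(batch, 1).expand_as(s) * F.softmax(s, 1)
    for i, s in enumerate(scores_list)
]
avg = torch.mean(torch.stack(weighted, 2), 2)  # (batch, seq_len)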
Example #21
    def predict(self, loader):
        """
        Predict for an input.

        Args
        ----
            loader : PyTorch DataLoader.

        """
        self.model.eval()
        all_preds = []
        all_ys = []
        all_cs = []
        all_ts = []
        all_ms = []
        all_idx = []

        if isinstance(loader.dataset, torch.utils.data.Subset):
            n_frames = loader.dataset.dataset.n_frames
        elif isinstance(loader.dataset, torch.utils.data.ConcatDataset):
            n_frames = loader.dataset.datasets[0].n_frames
        else:
            n_frames = loader.dataset.n_frames

        with torch.no_grad():
            for batch_samples in tqdm(loader):

                # prepare training sample
                X = batch_samples['X']
                if X.dim() == 4:
                    full_track = False
                    # batch_size x in_channels x 1025 x 129
                else:
                    bs = X.size(0)
                    ns = X.size(1)
                    full_track = True
                    # batch_size * splits x in_channels x 1025 x 129
                    X = X.view(bs * ns, self.in_channels, self.n_fft, n_frames)

                # batch_size x in_channels x 1025 x 129 x 2
                X_complex = batch_samples['X_complex']
                if X_complex.dim() != 5:
                    # batch_size * splits x in_channels x 1025 x 129 x 2
                    X_complex = X_complex.view(
                        bs * ns, self.out_channels, self.n_fft, n_frames, 2)

                # batch_size x nclasses x in_channels x 1025 x time samples x 2
                y = batch_samples['y_complex']
                # batch_size x nclasses
                cs = batch_samples['c']
                # batch_size x 1
                ts = batch_samples['t']
                track_idx = batch_samples['track_idx']

                if self.USE_CUDA:
                    X = X.cuda()
                    X_complex = X_complex.cuda()
                    y = y.cuda()

                if X.size(0) > 4:
                    X_list = torch.split(X, 4, dim=0)
                else:
                    X_list = [X]

                masks_list = []
                pred_list = []
                for X in X_list:
                    # detach hidden state
                    self.model.detach_hidden(X.size(0))
                    # forward pass
                    preds, mask = self.model(X)
                    masks_list += [mask]
                    pred_list += [preds]
                mask = torch.cat(masks_list, dim=0)
                preds = torch.cat(pred_list, dim=0)

                if full_track:
                    # batch size x nclasses x in_channels x 1025 x time samples
                    if self.regression:
                        preds = preds.view(
                            bs, ns, self.n_classes, self.out_channels,
                            self.n_fft, n_frames)
                        preds = torch.unbind(preds, dim=1)
                        preds = torch.cat(preds, dim=4)
                    else:
                        mask = mask.view(
                            bs, ns, self.n_classes, self.out_channels,
                            self.n_fft, n_frames)
                        mask = torch.unbind(mask, dim=1)
                        mask = torch.cat(mask, dim=4)
                    # batch_size x in_channels x 1025 x time samples x 2
                    X_complex = X_complex.view(
                        bs, ns, self.out_channels, self.n_fft, n_frames, 2)
                    X_complex = torch.unbind(X_complex, dim=1)
                    X_complex = torch.cat(X_complex, dim=3)

                # convert to complex
                # batch size x nclasses x in_channels x 1025 x time samples x 2
                X_complex = X_complex.unsqueeze(1).repeat(
                    1, self.n_classes, 1, 1, 1, 1)
                X_complex = self._to_complex(X_complex)
                if self.regression:
                    _, X_phase = magphase(X_complex)
                    preds = preds.cpu().numpy() * X_phase
                else:
                    preds = mask.cpu().numpy() * X_complex
                # batch size x nclasses x in_channels x 1025 x time samples
                ys = self._to_complex(y)

                all_preds += [preds]
                all_ys += [ys]
                all_cs += [cs]
                all_ts += [ts]
                all_ms += [mask.cpu().numpy()]
                all_idx += [track_idx]

        return all_preds, all_ys, all_cs, all_ts, all_ms, all_idx
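The full_track branch relies on unbind followed by cat to stitch fixed-length splits back into one track along the time axis. A minimal sketch of that reshaping step alone:

import torch

bs, ns, n_fft, n_frames = 2, 3, 5, 4
chunks = torch.randn(bs, ns, n_fft, n_frames)  # splits stacked on dim 1

# unbind the split dimension, then concatenate along time
full = torch.cat(torch.unbind(chunks, dim=1), dim=-1)
assert full.shape == (bs, n_fft, ns * n_frames)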
Example #22
    def _conductance_grads(self, forward_fn, input, target_ind=None):
        with torch.autograd.set_grad_enabled(True):
            # Set a forward hook on specified module and run forward pass to
            # get output tensor size.
            saved_tensor = None

            def forward_hook(module, inp, out):
                nonlocal saved_tensor
                saved_tensor = out

            hook = self.layer.register_forward_hook(forward_hook)
            output = forward_fn(input)

            # Compute layer output tensor dimensions and total number of units.
            # The hidden layer tensor is assumed to have dimension (num_hidden, ...)
            # where the product of the dimensions >= 1 correspond to the total
            # number of hidden neurons in the layer.
            layer_size = tuple(saved_tensor.size())[1:]
            layer_units = int(np.prod(layer_size))

            # Remove unnecessary forward hook.
            hook.remove()

            # Backward hook function to override gradients in order to obtain
            # just the gradient of each hidden unit with respect to input.
            saved_grads = None

            def backward_hook(grads):
                nonlocal saved_grads
                saved_grads = grads
                zero_mat = torch.zeros((1, ) + layer_size)
                scatter_indices = torch.arange(0,
                                               layer_units).view_as(zero_mat)
                # Creates matrix with each layer containing a single unit with
                # value 1 and remaining zeros, which will provide gradients
                # with respect to each unit independently.
                to_return = torch.zeros((layer_units, ) + layer_size).scatter(
                    0, scatter_indices, 1)
                to_repeat = [1] * len(to_return.shape)
                to_repeat[0] = grads.shape[0] // to_return.shape[0]
                expanded = to_return.repeat(to_repeat)
                return expanded

            # Create a forward hook in order to attach backward hook to appropriate
            # tensor. Save backward hook in order to remove hook appropriately.
            back_hook = None

            def forward_hook_register_back(module, inp, out):
                nonlocal back_hook
                back_hook = out.register_hook(backward_hook)

            hook = self.layer.register_forward_hook(forward_hook_register_back)

            # Expand input to include layer_units copies of each input.
            # This allows obtaining gradient with respect to each hidden unit
            # in one pass.
            expanded_input = torch.repeat_interleave(input, layer_units, dim=0)
            output = forward_fn(expanded_input)
            hook.remove()
            output = output[:,
                            target_ind] if target_ind is not None else output
            input_grads = torch.autograd.grad(torch.unbind(output),
                                              expanded_input)

            # Remove backwards hook
            back_hook.remove()

            # Remove duplicates in gradient with respect to hidden layer,
            # choose one for each layer_units indices.
            output_mid_grads = torch.index_select(
                saved_grads,
                0,
                torch.tensor(range(0, input_grads[0].shape[0], layer_units)),
            )
        return input_grads[0], output_mid_grads, layer_units
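Passing torch.unbind(output) to torch.autograd.grad treats the resulting tuple of scalars as objectives whose gradients are summed, which is what allows a single backward pass over the expanded input here. A minimal sketch of that mechanism on a toy function:

import torch

x = torch.randn(4, 3, requires_grad=True)
y = (x ** 2).sum(dim=1)  # one scalar per row

# unbind yields a tuple of scalars; grad sums their gradients into one tensor
grads = torch.autograd.grad(torch.unbind(y), x)
assert torch.allclose(grads[0], 2 * x)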
Example #23
    def build(self, im_batch):
        ''' Decomposes a batch of images into a complex steerable pyramid. 
        The pyramid typically has ~4 levels and 4-8 orientations. 
        
        Args:
            im_batch (torch.Tensor): Batch of images of shape [N,C,H,W]
        
        Returns:
            pyramid: list containing torch.Tensor objects storing the pyramid
        '''

        assert im_batch.device == self.device, 'Devices invalid (pyr = {}, batch = {})'.format(
            self.device, im_batch.device)
        assert im_batch.dtype == torch.float32, 'Image batch must be torch.float32'
        assert im_batch.dim() == 4, 'Image batch must be of shape [N,C,H,W]'
        assert im_batch.shape[
            1] == 1, 'Second dimension must be 1 encoding grayscale image'

        im_batch = im_batch.squeeze(1)  # flatten channels dim
        height, width = im_batch.shape[1], im_batch.shape[2]  # [N, H, W] after squeeze

        # Check whether image size is sufficient for number of levels
        if self.height > int(np.floor(np.log2(min(width, height))) - 2):
            raise RuntimeError(
                'Cannot build {} levels, image too small.'.format(self.height))

        # Prepare a grid
        log_rad, angle = math_utils.prepare_grid(height, width)

        # Radial transition function (a raised cosine in log-frequency):
        Xrcos, Yrcos = math_utils.rcosFn(1, -0.5)
        Yrcos = np.sqrt(Yrcos)

        YIrcos = np.sqrt(1 - Yrcos**2)

        lo0mask = pointOp(log_rad, YIrcos, Xrcos)
        hi0mask = pointOp(log_rad, Yrcos, Xrcos)

        # Note that we expand dims to support broadcasting later
        lo0mask = torch.from_numpy(lo0mask).float()[None, :, :,
                                                    None].to(self.device)
        hi0mask = torch.from_numpy(hi0mask).float()[None, :, :,
                                                    None].to(self.device)

        # Fourier transform (2D) and shifting
        batch_dft = torch.rfft(im_batch, signal_ndim=2, onesided=False)
        batch_dft = math_utils.batch_fftshift2d(batch_dft)

        # Low-pass
        lo0dft = batch_dft * lo0mask

        # Start recursively building the pyramids
        coeff = self._build_levels(lo0dft, log_rad, angle, Xrcos, Yrcos,
                                   self.height - 1)

        # High-pass
        hi0dft = batch_dft * hi0mask
        hi0 = math_utils.batch_ifftshift2d(hi0dft)
        hi0 = torch.ifft(hi0, signal_ndim=2)
        hi0_real = torch.unbind(hi0, -1)[0]
        coeff.insert(0, hi0_real)
        return coeff
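torch.rfft/torch.ifft and their (..., 2) real/imag layout were removed in torch 1.8; unbinding the last dimension is how this code extracts the real part. A minimal sketch of the same pattern on the modern torch.fft API, assuming torch >= 1.8:

import torch

z = torch.fft.fft2(torch.randn(1, 8, 8))  # native complex tensor
z2 = torch.view_as_real(z)                # legacy (..., 2) layout, shape (1, 8, 8, 2)
real, imag = torch.unbind(z2, -1)         # the pattern used above
assert torch.allclose(real, z.real) and torch.allclose(imag, z.imag)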
Example #24
def compute_pr_curves(class_hist, total_hist):
    """
    Computes precision recall curves from the true sample / total
    sample histogram tensors. The histogram tensors are num_bins x num_classes
    and each column represents a histogram over
    prediction_probabilities.

    The two tensors should have the same dimensions.
    The two tensors should have nonnegative integer values.

    Returns map of precision / recall values from highest precision to lowest
    and the calculated AUPRC (i.e. the average precision).
    """
    assert torch.is_tensor(class_hist) and torch.is_tensor(
        total_hist), "Both arguments must be tensors"
    assert (class_hist.dtype == torch.int64 and total_hist.dtype
            == torch.int64), "Both arguments must contain int64 values"
    assert (len(class_hist.size()) == 2 and len(total_hist.size())
            == 2), "Both arguments must have 2 dimensions, (score_bin, class)"
    assert (class_hist.size() == total_hist.size()), """
        For compute_pr_curves, arguments must be of same size.
        class_hist.size(): %s
        total_hist.size(): %s
        """ % (
        str(class_hist.size()),
        str(total_hist.size()),
    )
    assert (class_hist > total_hist).sum() == 0, (
        "Invalid. Class histogram must be less than or equal to total histogram"
    )

    num_bins = class_hist.size()[0]
    # Cumsum from highest bucket to lowest
    cum_class_hist = torch.cumsum(torch.flip(class_hist, dims=(0, )),
                                  dim=0).double()
    cum_total_hist = torch.cumsum(torch.flip(total_hist, dims=(0, )),
                                  dim=0).double()
    class_totals = cum_class_hist[-1, :]

    prec_t = cum_class_hist / cum_total_hist
    recall_t = cum_class_hist / class_totals

    prec = torch.unbind(prec_t, dim=1)
    recall = torch.unbind(recall_t, dim=1)
    assert len(prec) == len(
        recall
    ), "The number of precision curves does not match the number of recall curves"

    final_prec = []
    final_recall = []
    final_ap = []
    for c, prec_curve in enumerate(prec):
        recall_curve = recall[c]
        assert (
            recall_curve.size()[0] == num_bins
            and prec_curve.size()[0] == num_bins
        ), "Precision and recall curves do not have the correct number of entries"

        # Check if any samples from class were seen
        if class_totals[c] == 0:
            continue

        # Remove duplicate entries
        prev_r = torch.tensor(-1.0).double()
        prev_p = torch.tensor(1.1).double()
        new_recall_curve = torch.tensor([], dtype=torch.double)
        new_prec_curve = torch.tensor([], dtype=torch.double)
        for idx, r in enumerate(recall_curve):
            p = prec_curve[idx]
            # Remove points on PR curve that are invalid
            if r.item() <= 0:
                continue

            # Remove duplicates (due to empty buckets):
            if r.item() == prev_r.item() and p.item() == prev_p.item():
                continue

            # Add points to curve
            new_recall_curve = torch.cat((new_recall_curve, r.unsqueeze(0)),
                                         dim=0)
            new_prec_curve = torch.cat((new_prec_curve, p.unsqueeze(0)), dim=0)
            prev_r = r
            prev_p = p

        ap = calc_ap(new_prec_curve, new_recall_curve)
        final_prec.append(new_prec_curve)
        final_recall.append(new_recall_curve)
        final_ap.append(ap)

    return {"prec": final_prec, "recall": final_recall, "ap": final_ap}
Example #25
    def forward(
        self,
        input_ids,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        start_positions=None,
        end_positions=None,
        answer_masks=None,
    ):
        r"""
        start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
            Labels for position (index) of the start of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`).
            Positions outside of the sequence are not taken into account for computing the loss.
        end_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
            Labels for position (index) of the end of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`).
            Positions outside of the sequence are not taken into account for computing the loss.
    Returns:
        :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.RobertaConfig`) and inputs:
        loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`labels` is provided):
            Total span extraction loss is the sum of a Cross-Entropy for the start and end positions.
        start_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length,)`):
            Span-start scores (before SoftMax).
        end_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length,)`):
            Span-end scores (before SoftMax).
        hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
            Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
            of shape :obj:`(batch_size, sequence_length, hidden_size)`.
            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``):
            Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
            :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    Examples::
        # The checkpoint roberta-large is not fine-tuned for question answering. Please see the
        # examples/run_squad.py example to see how to fine-tune a model to a question answering task.
        from transformers import RobertaTokenizer, RobertaForQuestionAnswering
        import torch
        tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
        model = RobertaForQuestionAnswering.from_pretrained('roberta-base')
        question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet"
        input_ids = tokenizer.encode(question, text)
        start_scores, end_scores = model(torch.tensor([input_ids]))
        all_tokens = tokenizer.convert_ids_to_tokens(input_ids)
        answer = ' '.join(all_tokens[torch.argmax(start_scores) : torch.argmax(end_scores)+1])
        """

        outputs = self.roberta(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
        )

        sequence_output = outputs[0]

        logits = self.qa_outputs(sequence_output)
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1)
        end_logits = end_logits.squeeze(-1)

        outputs = (
            start_logits,
            end_logits,
        ) + outputs[2:]
        if start_positions is not None and end_positions is not None:
            # If we are on multi-GPU, split add a dimension
            if len(start_positions.size()) > 1:
                start_positions = start_positions.squeeze(-1)
            if len(end_positions.size()) > 1:
                end_positions = end_positions.squeeze(-1)
            # sometimes the start/end positions are outside our model inputs, we ignore these terms
            ignored_index = start_logits.size(1)
            start_positions.clamp_(0, ignored_index)
            end_positions.clamp_(0, ignored_index)

            loss_fct = CrossEntropyLoss(ignore_index=ignored_index,
                                        reduction='none')

            start_losses = [(loss_fct(start_logits, _start_positions) * _span_mask) \
                            for (_start_positions, _span_mask) \
                            in zip(torch.unbind(start_positions, dim=1), torch.unbind(answer_masks, dim=1))]
            end_losses = [(loss_fct(end_logits, _end_positions) * _span_mask) \
                            for (_end_positions, _span_mask) \
                          in zip(torch.unbind(end_positions, dim=1), torch.unbind(answer_masks, dim=1))]

            total_loss = sum(start_losses + end_losses)
            total_loss = torch.mean(total_loss) / 2

            outputs = (total_loss, ) + outputs

        return outputs  # (loss), start_logits, end_logits, (hidden_states), (attentions)
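The loss iterates over several annotated answer spans per example by unbinding the answer dimension and masking out padded spans. A minimal sketch of the start-position half, with hypothetical shapes:

import torch
from torch.nn import CrossEntropyLoss

batch, seq_len, num_answers = 2, 16, 3
start_logits = torch.randn(batch, seq_len)
start_positions = torch.randint(0, seq_len, (batch, num_answers))
answer_masks = torch.tensor([[1, 1, 0], [1, 0, 0]]).float()  # padded spans masked out

loss_fct = CrossEntropyLoss(reduction='none')
start_losses = [
    loss_fct(start_logits, pos) * mask
    for pos, mask in zip(torch.unbind(start_positions, dim=1),
                         torch.unbind(answer_masks, dim=1))
]
total = torch.mean(sum(start_losses)) / 2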
Example #26
    def __init__(
        self,
        actions,
        actions_logp,
        actions_entropy,
        dones,
        behaviour_logits,
        target_logits,
        discount,
        rewards,
        values,
        bootstrap_value,
        dist_class,
        valid_mask,
        vf_loss_coeff=0.5,
        entropy_coeff=0.01,
        clip_rho_threshold=1.0,
        clip_pg_rho_threshold=1.0,
    ):
        """Policy gradient loss with vtrace importance weighting.

                VTraceLoss takes tensors of shape [T, B, ...], where `B` is the
                batch_size. The reason we need to know `B` is for V-trace to properly
                handle episode cut boundaries.

                Args:
                    actions: An int|float32 tensor of shape [T, B, ACTION_SPACE].
                    actions_logp: A float32 tensor of shape [T, B].
                    actions_entropy: A float32 tensor of shape [T, B].
                    dones: A bool tensor of shape [T, B].
                    behaviour_logits: A list with length of ACTION_SPACE of float32
                        tensors of shapes
                        [T, B, ACTION_SPACE[0]],
                        ...,
                        [T, B, ACTION_SPACE[-1]]
                    target_logits: A list with length of ACTION_SPACE of float32
                        tensors of shapes
                        [T, B, ACTION_SPACE[0]],
                        ...,
                        [T, B, ACTION_SPACE[-1]]
                    discount: A float32 scalar.
                    rewards: A float32 tensor of shape [T, B].
                    values: A float32 tensor of shape [T, B].
                    bootstrap_value: A float32 tensor of shape [B].
                    dist_class: action distribution class for logits.
                    valid_mask: A bool tensor of valid RNN input elements (#2992).
                """
        # Compute vtrace on the CPU for better performance
        device = behaviour_logits[0].get_device()
        if device >= 0:
            device = torch.device("cuda:" + str(device))
        else:
            device = torch.device("cpu")
        for i in range(len(behaviour_logits)):
            behaviour_logits[i] = behaviour_logits[i].data.cpu()
            target_logits[i] = target_logits[i].data.cpu()

            # Make sure tensor ranks are as expected
            # The rest will be checked by from_action_log_probs.
            assert len(behaviour_logits[i].size()) == 3
            assert len(target_logits[i].size()) == 3
        reverse_dones = (dones.float() == torch.zeros_like(dones.float()))
        self.vtrace_returns = vtrace.multi_from_logits(
            behaviour_policy_logits=behaviour_logits,
            target_policy_logits=target_logits,
            actions=torch.unbind(actions.data.cpu(), dim=2),
            discounts=reverse_dones.float() * discount,
            rewards=rewards.data.cpu(),
            values=values.data.cpu(),
            bootstrap_value=bootstrap_value.data.cpu(),
            dist_class=dist_class,
            clip_rho_threshold=clip_rho_threshold,
            clip_pg_rho_threshold=clip_pg_rho_threshold)

        self.value_targets = self.vtrace_returns.vs
        valid_mask = (valid_mask == torch.ones_like(valid_mask).to(device))
        self.pi_loss = -torch.sum(
            (actions_logp *
             self.vtrace_returns.pg_advantages.to(device))[valid_mask])
        self.valid_mask = valid_mask
        self.values = values
        self.logp_min = actions_logp.min()
        self.logp_mean = actions_logp.mean()
        self.logp_max = actions_logp.max()
        self.ad_min = self.vtrace_returns.pg_advantages.min()
        self.ad_mean = self.vtrace_returns.pg_advantages.mean()
        self.ad_max = self.vtrace_returns.pg_advantages.max()
        delta = (values - self.vtrace_returns.vs)[valid_mask]
        self.vf_loss = 0.5 * (delta**2).sum()
        self.entropy = torch.sum(actions_entropy[valid_mask])
        self.total_loss = (self.pi_loss + self.vf_loss * vf_loss_coeff -
                           self.entropy * entropy_coeff)
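vtrace.multi_from_logits expects one [T, B] tensor per action component, which is why the [T, B, ACTION_SPACE] actions tensor is unbound along dim=2. A minimal sketch of that unpacking, with hypothetical sizes:

import torch

T, B = 5, 4
action_space = [3, 2]  # two discrete action components

actions = torch.randint(0, 2, (T, B, len(action_space)))
per_component = torch.unbind(actions, dim=2)  # one [T, B] tensor per component
assert len(per_component) == len(action_space)
assert per_component[0].shape == (T, B)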
Example #27
def main():
    args = parse_arguments()
    config = json.load(open(args.config))
    # dataset_type = config['dataset']['type']

    # data loader params
    loader = get_instance(datasets, 'dataset', 'val_args', config)
    # to_tensor = transforms.ToTensor()
    to_tensor = transforms.Compose([
        transforms.Resize((config['dataset']['train_args']['crop_size'],
                           config['dataset']['train_args']['crop_size'])),
        transforms.ToTensor()
    ])
    restore_transform = transforms.Compose([
        DeNormalize(config['dataset']['mean'], config['dataset']['std']),
    ])
    palette = loader.dataset.palette

    model = get_instance(models, 'model', 'args', config)

    checkpoint = torch.load(args.weight)
    if isinstance(checkpoint, dict) and 'state_dict' in checkpoint.keys():
        checkpoint = checkpoint['state_dict']
    if 'module' in list(checkpoint.keys())[0] and not isinstance(
            model, torch.nn.DataParallel):
        model = torch.nn.DataParallel(model)
    model.load_state_dict(checkpoint)
    model.to(device)
    model.eval()
    print(model)

    if not os.path.exists(args.output):
        os.makedirs(args.output)

    with torch.no_grad():
        image_path = args.image
        if image_path is not None:
            image_name = str(image_path).split('/')[-1].split('.')[0]
            image = Image.open(image_path).convert('RGB')
            input = normalize(to_tensor(image)).unsqueeze(0)
            mask = inference(model, input, image, palette)
            mask_path = os.path.join(args.output,
                                     image_name + config['name'] + '.png')
            save_mask(mask, mask_path)
        else:
            dataiter = iter(loader)
            while True:
                batch = next(dataiter)
                images = batch['img']
                label = batch['label']
                if config['dataset']['name'] == "SeqLane":
                    images = torch.unbind(images, dim=0)
                    image = images[0][-1]
                    image_name = batch['images_name'][-1][0]
                else:
                    image = images[0]
                    image_name = batch['img_name'][0]
                # images = torch.unbind(images, dim=0)
                # print(image.size())
                mask = inference(model, batch['img'], restore_transform(image),
                                 label[0], palette)

                # image_name = batch['images_name'][-1]
                print(image_name)
                mask_path = os.path.join(args.output, image_name[-1])
Example #28
    save_ply(pc, "./input.ply", colors=None, normals=None)
    pc = torch.from_numpy(pc).requires_grad_().to(cuda0).unsqueeze(0)
    pc = pc.transpose(2, 1)

    # test furthest point
    idx, sampled_pc = furthest_point_sample(pc, 1250)
    output = sampled_pc.transpose(2, 1).cpu().squeeze()
    save_ply(output.detach(), "./output.ply", colors=None, normals=None)

    # test KNN
    knn_points, _, _ = group_knn(10, sampled_pc, pc, NCHW=True)  # B, C, M, K
    labels = torch.arange(
        0, knn_points.size(2)).unsqueeze_(0).unsqueeze_(0).unsqueeze_(
            -1)  # 1, 1, M, 1
    labels = labels.expand(knn_points.size(0), -1, -1,
                           knn_points.size(3))  # B, 1, M, K
    # B, C, P
    labels = torch.cat(torch.unbind(labels, dim=-1),
                       dim=-1).squeeze().detach().cpu().numpy()
    knn_points = torch.cat(torch.unbind(knn_points, dim=-1), dim=-1).transpose(
        2, 1).squeeze(0).detach().cpu().numpy()
    save_ply_property(knn_points, labels, "./knn_output.ply", cmap_name='jet')

    from torch.autograd import gradcheck
    # test = gradcheck(furthest_point_sample, [pc, 1250], eps=1e-6, atol=1e-4)
    # print(test)
    test = gradcheck(gather_points, [pc.to(dtype=torch.float64), idx],
                     eps=1e-6,
                     atol=1e-4)
    print(test)
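The unbind/cat combination above flattens the K neighbours of each query point into the point dimension. A minimal sketch of just that reshaping:

import torch

B, C, M, K = 1, 3, 10, 4
knn_points = torch.randn(B, C, M, K)

# unbind the K neighbours, then concatenate them along the point dim
flat = torch.cat(torch.unbind(knn_points, dim=-1), dim=-1)
assert flat.shape == (B, C, M * K)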
Example #29
def apply(func, M):
    # applies a function to each slice along dim 0 and restacks the results
    tList = [func(m) for m in torch.unbind(M, dim=0)]
    res = torch.stack(tList, dim=0)
    return res
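A usage sketch for the apply helper above, inverting each 3x3 slice of a batch (random Gaussian matrices are almost surely invertible):

import torch

M = torch.randn(4, 3, 3)
inverses = apply(torch.inverse, M)  # torch.inverse runs once per slice
assert inverses.shape == M.shape

With recent torch, torch.linalg.inv(M) batches this natively, so the unbind/stack loop is mainly useful for functions without batched equivalents.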
Example #30
    def future_v(self, x, boundary_params, segment_params):
        TINY = 1e-6

        v0, q0, rhoNp1 = torch.unbind(boundary_params, dim=2)
        vf, a_var, rhocr, g, omegar, omegas, epsq, epsv = torch.unbind(
            segment_params, dim=2)
        #rhocr = torch.clamp(rhocr, min=30, max=500)
        try:
            if self.print_count % 100 == 0:  #random.random() > 0.9999:
                wandb.log({"vf": wandb.Histogram(vf.cpu().detach().numpy())})
                wandb.log(
                    {"a_var": wandb.Histogram(a_var.cpu().detach().numpy())})
                wandb.log(
                    {"rhocr": wandb.Histogram(rhocr.cpu().detach().numpy())})
                wandb.log({"g": wandb.Histogram(g.cpu().detach().numpy())})
                wandb.log(
                    {"omegar": wandb.Histogram(omegar.cpu().detach().numpy())})
                wandb.log(
                    {"omegas": wandb.Histogram(omegas.cpu().detach().numpy())})
            # tb.add_histogram('vf', vf, epoch)
            # tb.add_histogram('a', a, epoch)
            # tb.add_histogram('rhocr', rhocr, epoch)
            # tb.add_histogram('g', g, epoch)
            # tb.add_histogram('omegar', omegar, epoch)
            # tb.add_histogram('omegas', omegas, epoch)
        except Exception as e:
            print(e)

        current_densities = x[:, :, self.rho_index]
        current_flows = x[:, :, self.q_index]
        if self.calculate_velocity:
            current_velocities = current_flows / (
                current_densities * self.segment_fixed[:, self.lambda_index] +
                TINY)
            if self.print_count % 1000 == 0:  #random.random() > 0.9999:
                wandb.log({
                    "current_velocities":
                    wandb.Histogram(current_velocities.cpu().detach().numpy())
                })
                wandb.log({
                    "current_densities":
                    wandb.Histogram(current_densities.cpu().detach().numpy())
                })
                wandb.log({
                    "current_flows":
                    wandb.Histogram(current_flows.cpu().detach().numpy())
                })
            current_velocities = torch.clamp(current_velocities,
                                             min=5,
                                             max=200)
        else:
            current_velocities = x[:, :, self.v_index]
            current_velocities = torch.clamp(current_velocities,
                                             min=5,
                                             max=200)

        prev_velocities = torch.cat([v0, current_velocities[:, :-1]], dim=1)
        next_densities = torch.cat([current_densities[:, 1:], rhoNp1], dim=1)

        stat_speed = vf * torch.exp(
            torch.div(-1, a_var + TINY) * torch.pow(
                torch.div(current_densities, rhocr + TINY) + TINY, a_var))
        if self.print_count % 1000 == 0:  #random.random() > 0.9999:
            wandb.log({
                "stat_speed":
                wandb.Histogram(stat_speed.cpu().detach().numpy())
            })
            print("stat speed", stat_speed.size(),
                  stat_speed.min().item(),
                  stat_speed.mean().item(),
                  stat_speed.max().item())
            print("v0,q0,rhoNN", v0[0].item(), q0[0].item(), rhoNp1[0].item(),
                  v0.mean().item(),
                  q0.mean().item(),
                  rhoNp1.mean().item())
            print("q1", x[0, 0, self.q_index].item(),
                  x[:, 0, self.q_index].mean().item())
            print("vf, a, rhocr,g, omegar, omegas, epsq, epsv",
                  vf[0].mean().item(), a_var[0].mean().item(),
                  rhocr[0].mean().item(), g[0].mean().item(),
                  omegar[0].mean().item(), omegas[0].mean().item(),
                  epsq[0].mean().item(), epsv[0].mean().item())

        #import pdb; pdb.set_trace()

        return current_velocities + (torch.div(self.T,self.tau+TINY)) * (stat_speed - current_velocities )  \
                + (torch.div(self.T,self.Delta) * current_velocities * (prev_velocities - current_velocities)) \
                - (torch.div(self.nu*self.T, (self.tau*self.Delta)) * torch.div( (next_densities - current_densities), (current_densities+self.kappa)) ) \
                - (torch.div( (self.delta*self.T) , (self.Delta * self.lambda_var) ) * torch.div( (x[:,:,self.r_index]*current_velocities),(current_densities+self.kappa) ) ) \
                + epsv
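Unbinding dim=2 is used here to unpack a stacked parameter tensor into one named (batch, segments) tensor per physical quantity. A minimal sketch with hypothetical sizes:

import torch

batch, segments = 2, 6
segment_params = torch.randn(batch, segments, 8)

# eight per-segment parameters, each of shape (batch, segments)
vf, a_var, rhocr, g, omegar, omegas, epsq, epsv = torch.unbind(segment_params, dim=2)
assert vf.shape == (batch, segments)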
Example #31
 def forward(self, x):
     out = torch.index_select(self.I, 0, x)
     # add some noise - not enough to change anything (I don't think)  + Variable(torch.rand(o_t.size())*1e-5)
     out = torch.stack([o_t for o_t in torch.unbind(out, 0)])
     return out
Example #32
    def forward(self, h_o_prev, y_e_prev, C_o, pos, phase, memor):
        """
        h_o_prev: N * O * dim_h_o
        y_e_prev: N * O * dim_y_e
        C_o:      N * C2_1 * C2_2
        """
        o = self.o
        perm_mat = torch.zeros(o.N, o.O, o.O)
        m = memor.cpu().detach()
        p = pos.cpu()
        for i in range(o.N):
            tree = cKDTree(m[i])
            r = tree.query(p[i], k=10)[1].reshape(-1)
            #r = np.array(r.tolist() + [sum(range(o.O)) - r.sum()])
            #print(r, p[i].tolist(), m[i].tolist())
            if(r.sum() != sum(range(o.O))):
                print(m[i])
                print(p[i])
                print(r)
            perm_mat[i, torch.arange(o.O), r] = 1
        perm_mat = perm_mat.cuda()
        perm_mat_inv = perm_mat.transpose(1, 2)
        
        #pos = pos / 120
        #inp = pos.unsqueeze(1).expand(o.N, o.O, 2).contiguous().view(-1, 2)
        #y_e_prev = self.linear_pos_to_conf(inp).tanh().abs().view(o.N, o.O, o.dim_y_e)
        #if "no_rep" not in o.exp_config:
        #    if o.train == 0:
        #        print(y_e_prev[0].view(-1))
        #    delta = torch.arange(0, o.O).type(torch.FloatTensor).cuda().unsqueeze(0) * 0.0001 # 1 * O
        #    y_e_prev_mdf = y_e_prev.squeeze(2).round() - Variable(delta)
        #    perm_mat = self.permutation_matrix_calculator(y_e_prev_mdf) # N * O * O
        #    perm_mat_inv = perm_mat.transpose(1, 2) # N * O * O
    
        if phase == 'obs':
            #if "no_tem" in o.exp_config:
            #    h_o_prev = Variable(torch.zeros_like(h_o_prev).cuda())
            #    y_e_prev = Variable(torch.zeros_like(y_e_prev).cuda())

            # Sort h_o_prev and y_e_prev
            if "no_rep" not in o.exp_config:
                h_o_prev = perm_mat.bmm(h_o_prev) # N * O * dim_h_o
            #    y_e_prev = perm_mat.bmm(y_e_prev) # N * O * dim_y_e
            #    if o.train == 0:
            #        print('Shuffled:')
            #        print(y_e_prev[0].view(-1))

            # Update h_o
            h_o_prev_split = torch.unbind(h_o_prev, 1) # N * dim_h_o
            h_o_split = {}
            k_split = {}
            r_split = {}
            for i in range(0, o.O):
                self.ntm_cell.i = i
                h_o_split[i], C_o, k_split[i], r_split[i] = self.ntm_cell(h_o_prev_split[i], C_o, pos, phase)
            h_o = torch.stack(tuple(h_o_split.values()), dim=1) # N * O * dim_h_o
            k = torch.stack(tuple(k_split.values()), dim=1) # N * O * C2_2
            r = torch.stack(tuple(r_split.values()), dim=1) # N * O * C2_2
            att = self.ntm_cell.att
            mem = self.ntm_cell.mem

            # Recover the original order of h_o
            if "no_rep" not in o.exp_config:
                #perm_mat_inv = perm_mat.transpose(1, 2) # N * O * O
                h_o = perm_mat_inv.bmm(h_o) # N * O * dim_h_o
                k = perm_mat_inv.bmm(k) # N * O * dim_c_2
                r = perm_mat_inv.bmm(r) # N * O * dim_c_2
                att = perm_mat_inv.data[self.ntm_cell.n].mm(att.view(o.O, -1)).view(o.O, -1, self.ntm_cell.wa) # O * ha * wa
                mem = perm_mat_inv.data[self.ntm_cell.n].mm(mem.view(o.O, -1)).view(o.O, -1, self.ntm_cell.wa) # O * ha * wa
                h_o_prev = perm_mat_inv.bmm(h_o_prev)
                #y_e_prev = perm_mat_inv.bmm(y_e_prev)
            
            if o.v > 0:
                self.att[self.t].copy_(att)
                self.mem[self.t].copy_(mem)
        else:
            h_o = h_o_prev

        y_e, y_l, y_p, Y_s, Y_a = self.generate_outputs(h_o, pos)

        # adaptive computation time
        if "act" in o.exp_config:
            #y_e_perm = perm_mat.bmm(y_e).round() # N * O * dim_y_e
            #y_e_mask = y_e_prev.round() + y_e_perm  # N * O * dim_y_e
            # y_e_mask = y_e_perm
            #y_e_mask = perm_mat.bmm(y_e).round() # N * O * dim_y_e
            #y_e_mask = y_e_mask.lt(0.5).type_as(y_e_mask)
            #y_e_mask = y_e_mask.cumsum(1)
            #y_e_mask = y_e_mask.lt(0.5).type_as(y_e_mask)
            #ones = Variable(torch.ones(y_e_mask.size(0), 1, o.dim_y_e).cuda())  # N * 1 * dim_y_e
            #y_e_mask = torch.cat((ones, y_e_mask[:, 0:o.O-1]), dim=1)
            #y_e_mask = perm_mat_inv.bmm(y_e_mask)  # N * O * dim_y_e
            y_e_mask = y_e.round() # N * O * dim_y_e
            y_e_inv_mask = y_e.round().lt(0.5).type_as(y_e) # N * O * dim_y_e
            h_o = y_e_mask * (h_o - h_o_prev) + h_o_prev  # N * O * dim_h_o
            # h_o = y_e_mask * h_o  # N * O * dim_h_o
            y_e = y_e_mask * y_e  # N * O * dim_y_e
            y_p = y_e_mask * y_p  # N * O * dim_y_p
            Y_a = y_e_mask.view(-1, o.O, o.dim_y_e, 1, 1) * Y_a  # N * O * D * h * w
            global_agent_pos = y_e_mask * pos.unsqueeze(1).expand(o.N, o.O, 2)
            global_tracker_offset = y_p[:, :, -2:] * torch.Tensor([o.H // 2, o.W // 2]).cuda()
            if phase == 'obs':
                memor = (y_e_inv_mask * memor) + global_agent_pos + global_tracker_offset

        if self.t == o.T - 1:
            print(y_e.data.view(-1, o.O)[0:1, 0:min(o.O, 10)])

        return memor, h_o, y_e, y_l, y_p, Y_s, Y_a
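The observation branch unbinds the object dimension so each slot can be pushed through the NTM cell separately, then restacks the results. A minimal sketch of that per-slot loop, with a stand-in for the cell:

import torch

N, O, dim_h = 2, 4, 8
h = torch.randn(N, O, dim_h)

slots = torch.unbind(h, 1)            # O tensors of shape (N, dim_h)
processed = [s * 2.0 for s in slots]  # stand-in for the per-slot NTM cell
h_new = torch.stack(processed, dim=1)
assert h_new.shape == (N, O, dim_h)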
Example #33
    def _build_levels(self, lodft, log_rad, angle, Xrcos, Yrcos, height):

        if height <= 1:

            # Low-pass
            lo0 = math_utils.batch_ifftshift2d(lodft)
            lo0 = torch.ifft(lo0, signal_ndim=2)
            lo0_real = torch.unbind(lo0, -1)[0]
            coeff = [lo0_real]

        else:

            Xrcos = Xrcos - np.log2(self.scale_factor)

            ####################################################################
            ####################### Orientation bandpass #######################
            ####################################################################

            himask = pointOp(log_rad, Yrcos, Xrcos)
            himask = torch.from_numpy(himask[None, :, :,
                                             None]).float().to(self.device)

            order = self.nbands - 1
            const = np.power(2, 2 * order) * np.square(
                factorial(order)) / (self.nbands * factorial(2 * order))
            Ycosn = 2 * np.sqrt(const) * np.power(np.cos(
                self.Xcosn), order) * (np.abs(self.alpha) < np.pi / 2)  # [n,]

            # Loop through all orientation bands
            orientations = []
            for b in range(self.nbands):

                anglemask = pointOp(angle, Ycosn,
                                    self.Xcosn + np.pi * b / self.nbands)
                anglemask = anglemask[None, :, :, None]  # for broadcasting
                anglemask = torch.from_numpy(anglemask).float().to(self.device)

                # Bandpass filtering
                banddft = lodft * anglemask * himask

                # Now multiply with complex number
                # (x+yi)(u+vi) = (xu-yv) + (xv+yu)i
                banddft = torch.unbind(banddft, -1)
                banddft_real = self.complex_fact_construct.real * banddft[
                    0] - self.complex_fact_construct.imag * banddft[1]
                banddft_imag = self.complex_fact_construct.real * banddft[
                    1] + self.complex_fact_construct.imag * banddft[0]
                banddft = torch.stack((banddft_real, banddft_imag), -1)

                band = math_utils.batch_ifftshift2d(banddft)
                band = torch.ifft(band, signal_ndim=2)
                orientations.append(band)

            ####################################################################
            ######################## Subsample lowpass #########################
            ####################################################################

            # Don't consider batch_size and imag/real dim
            dims = np.array(lodft.shape[1:3])

            # Both are tuples of size 2
            low_ind_start = (np.ceil((dims + 0.5) / 2) - np.ceil((np.ceil(
                (dims - 0.5) / 2) + 0.5) / 2)).astype(int)
            low_ind_end = (low_ind_start + np.ceil(
                (dims - 0.5) / 2)).astype(int)

            # Subsampling indices
            log_rad = log_rad[low_ind_start[0]:low_ind_end[0],
                              low_ind_start[1]:low_ind_end[1]]
            angle = angle[low_ind_start[0]:low_ind_end[0],
                          low_ind_start[1]:low_ind_end[1]]

            # Actual subsampling
            lodft = lodft[:, low_ind_start[0]:low_ind_end[0],
                          low_ind_start[1]:low_ind_end[1], :]

            # Filtering
            YIrcos = np.abs(np.sqrt(1 - Yrcos**2))
            lomask = pointOp(log_rad, YIrcos, Xrcos)
            lomask = torch.from_numpy(lomask[None, :, :, None]).float()
            lomask = lomask.to(self.device)

            # Convolution in spatial domain
            lodft = lomask * lodft

            ####################################################################
            ####################### Recursion next level #######################
            ####################################################################

            coeff = self._build_levels(lodft, log_rad, angle, Xrcos, Yrcos,
                                       height - 1)
            coeff.insert(0, orientations)

        return coeff
Example #34
 def forward(self, x):
     cfilter = th.cat([
         f.repeat(i) for f, i in zip(th.unbind(self._filter, 1), self.cat)
     ], 0).unsqueeze(0)
     # print(x.shape, cfilter.shape)
     return x * (self.hard_filter * cfilter).expand_as(x)
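A minimal sketch of the zip-over-unbind pattern above, repeating each filter coefficient a per-group number of times (hypothetical sizes):

import torch as th

filt = th.randn(1, 3)  # one coefficient per group
cat = [2, 1, 3]        # repeat counts per group

cfilter = th.cat([
    f.repeat(i) for f, i in zip(th.unbind(filt, 1), cat)
], 0).unsqueeze(0)
assert cfilter.shape == (1, sum(cat))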
Example #35
    def _reconstruct_levels(self, coeff, log_rad, Xrcos, Yrcos, angle):

        if len(coeff) == 1:
            dft = torch.rfft(coeff[0], signal_ndim=2, onesided=False)
            dft = math_utils.batch_fftshift2d(dft)
            return dft

        Xrcos = Xrcos - np.log2(self.scale_factor)

        ####################################################################
        ####################### Orientation Residue ########################
        ####################################################################

        himask = pointOp(log_rad, Yrcos, Xrcos)
        himask = torch.from_numpy(himask[None, :, :,
                                         None]).float().to(self.device)

        lutsize = 1024
        Xcosn = np.pi * np.array(range(-(2 * lutsize + 1),
                                       (lutsize + 2))) / lutsize
        order = self.nbands - 1
        const = np.power(2, 2 * order) * np.square(
            factorial(order)) / (self.nbands * factorial(2 * order))
        Ycosn = np.sqrt(const) * np.power(np.cos(Xcosn), order)

        orientdft = torch.zeros_like(coeff[0][0])
        for b in range(self.nbands):

            anglemask = pointOp(angle, Ycosn, Xcosn + np.pi * b / self.nbands)
            anglemask = anglemask[None, :, :, None]  # for broadcasting
            anglemask = torch.from_numpy(anglemask).float().to(self.device)

            banddft = torch.fft(coeff[0][b], signal_ndim=2)
            banddft = math_utils.batch_fftshift2d(banddft)

            banddft = banddft * anglemask * himask
            banddft = torch.unbind(banddft, -1)
            banddft_real = self.complex_fact_reconstruct.real * banddft[
                0] - self.complex_fact_reconstruct.imag * banddft[1]
            banddft_imag = self.complex_fact_reconstruct.real * banddft[
                1] + self.complex_fact_reconstruct.imag * banddft[0]
            banddft = torch.stack((banddft_real, banddft_imag), -1)

            orientdft = orientdft + banddft

        ####################################################################
        ########## Lowpass component are upsampled and convolved ##########
        ####################################################################

        dims = np.array(coeff[0][0].shape[1:3])

        lostart = (np.ceil((dims + 0.5) / 2) - np.ceil((np.ceil(
            (dims - 0.5) / 2) + 0.5) / 2)).astype(np.int32)
        loend = lostart + np.ceil((dims - 0.5) / 2).astype(np.int32)

        nlog_rad = log_rad[lostart[0]:loend[0], lostart[1]:loend[1]]
        nangle = angle[lostart[0]:loend[0], lostart[1]:loend[1]]
        YIrcos = np.sqrt(np.abs(1 - Yrcos**2))

        # Filtering
        lomask = pointOp(nlog_rad, YIrcos, Xrcos)
        lomask = torch.from_numpy(lomask[None, :, :, None])
        lomask = lomask.float().to(self.device)

        ################################################################################

        # Recursive call for image reconstruction
        nresdft = self._reconstruct_levels(coeff[1:], nlog_rad, Xrcos, Yrcos,
                                           nangle)

        resdft = torch.zeros_like(coeff[0][0]).to(self.device)
        resdft[:, lostart[0]:loend[0],
               lostart[1]:loend[1], :] = nresdft * lomask

        return resdft + orientdft
Example #36
def _compute_dplstm_grad_sample(layer: DPLSTM,
                                A: torch.Tensor,
                                B: torch.Tensor,
                                batch_dim: int = 0) -> None:
    """
    Computes per sample gradients for ``DPLSTM`` layer

    Parameters
    ----------
    layer : opacus.layers.dp_lstm.DPLSTM
        Layer
    A : torch.Tensor
        Activations
    B : torch.Tensor
        Backpropagations
    batch_dim : int, optional
        Batch dimension position
    """
    lstm_params = [
        layer.weight_ih_l0,
        layer.weight_hh_l0,
        layer.bias_ih_l0,
        layer.bias_hh_l0,
    ]
    lstm_out_dim = layer.hidden_size

    x = torch.unbind(A, dim=1)
    hooks_delta = torch.unbind(B, dim=1)

    SEQ_LENGTH = len(x)
    BATCH_SIZE = B.shape[0]

    h_init = torch.zeros(1, BATCH_SIZE, lstm_out_dim, device=A.device)
    c_init = torch.zeros(1, BATCH_SIZE, lstm_out_dim, device=A.device)

    delta_h = {}
    delta_h[SEQ_LENGTH - 1] = 0
    f_last = 0
    dc_last = 0

    for t in range(SEQ_LENGTH - 1, -1, -1):
        f_next = f_last if t == SEQ_LENGTH - 1 else layer.cells[t + 1].f_t
        dc_next = dc_last if t == SEQ_LENGTH - 1 else layer.cells[t + 1].dc_t
        c_prev = c_init if t == 0 else layer.cells[t - 1].c_t
        delta_h[t - 1] = layer.cells[t].backward(x[t], delta_h[t],
                                                 hooks_delta[t], f_next,
                                                 dc_next, c_prev)

    grad_sample = {param: 0 for param in lstm_params}

    for t in range(0, SEQ_LENGTH):
        h_prev = h_init[0, :] if t == 0 else layer.cells[t - 1].h_t[0, :]
        grad_sample[layer.weight_ih_l0] += torch.einsum(
            "ij,ik->ijk", layer.cells[t].dgates_t, x[t])
        grad_sample[layer.weight_hh_l0] += torch.einsum(
            "ij,ik->ijk", layer.cells[t].dgates_t, h_prev)
        grad_sample[layer.bias_ih_l0] += layer.cells[t].dgates_t
        grad_sample[layer.bias_hh_l0] += layer.cells[t].dgates_t

    for param, grad_value in grad_sample.items():
        # pyre-ignore[6]
        _create_or_extend_grad_sample(param, grad_value, batch_dim)
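Unbinding dim=1 gives the per-timestep activations and backpropagated deltas that the cell-by-cell loop consumes. A minimal sketch of that slicing for batch-first tensors:

import torch

B, T, D = 3, 5, 7
A = torch.randn(B, T, D)  # batch-first activations

x = torch.unbind(A, dim=1)  # T tensors of shape (B, D), one per timestep
assert len(x) == T and x[0].shape == (B, D)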
Example #37
def RadiallyAverageFourierTransform(x, dim=2):
    real, imag = torch.unbind(SpaceToFourier(x, signal_dim=dim), dim=-1)
    xF = (real.pow(2) + imag.pow(2)).sqrt()
    fsc = RadiallyAverage(xF, dim)
    return fsc
Example #38
def batch_ifftshift2d(x):
    real, imag = torch.unbind(x, -1)
    for dim in range(len(real.size()) - 1, 0, -1):
        real = roll_n(real, axis=dim, n=real.size(dim) // 2)
        imag = roll_n(imag, axis=dim, n=imag.size(dim) // 2)
    return torch.stack((real, imag), -1)  # last dim=2 (real&imag)
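On torch >= 1.8 the manual roll loop can be replaced by torch.fft.ifftshift applied to the unbound real and imaginary parts; a minimal sketch, keeping the legacy (..., 2) layout:

import torch

x = torch.randn(2, 8, 8, 2)  # legacy (real, imag) layout
real, imag = torch.unbind(x, -1)

real = torch.fft.ifftshift(real, dim=(1, 2))
imag = torch.fft.ifftshift(imag, dim=(1, 2))
shifted = torch.stack((real, imag), -1)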
Example #39
 def forward(self, v0):
     X0 = torch.stack((self.h0, v0), 1)
     X1 = self.model(X0)
     h1, v1 = torch.unbind(X1, 1)
     return h1, v1
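stack and unbind are exact inverses along the same dimension, which is what this forward relies on. A minimal round-trip sketch:

import torch

h0, v0 = torch.randn(4, 8), torch.randn(4, 8)

X0 = torch.stack((h0, v0), 1)  # (4, 2, 8)
h1, v1 = torch.unbind(X0, 1)   # back to two (4, 8) tensors
assert torch.equal(h1, h0) and torch.equal(v1, v0)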