Example #1
    def __call__(self, class_logits, box_regression):
        if not hasattr(self, "_proposals"):
            raise RuntimeError("subsample needs to be called before")

        proposals = self._proposals

        labels = torch.cat(
            [proposal.get_field("labels") for proposal in proposals], dim=0)
        regression_targets = torch.cat(
            [proposal.get_field("regression_targets") for proposal in proposals],
            dim=0)

        sampled_pos_inds_subset = torch.nonzero(labels > 0).squeeze(1)
        labels_pos = labels[sampled_pos_inds_subset]
        map_inds = 4 * labels_pos[:, None] + self.offset

        valid = labels >= 0
        box_cls_loss = F.cross_entropy(class_logits[valid], labels[valid])
        box_loc_loss = smooth_l1_loss(
            box_regression[sampled_pos_inds_subset[:, None], map_inds],
            regression_targets[sampled_pos_inds_subset],
            size_average=False,
            beta=1,
        )

        if self.sampling_free:
            box_loc_loss = box_loc_loss / (labels_pos.numel() * 4)
            with torch.no_grad():
                ratio = box_loc_loss / box_cls_loss
            box_cls_loss = 2 * ratio * box_cls_loss
        else:
            box_loc_loss = box_loc_loss / labels.numel()
        return dict(box_cls_loss=box_cls_loss, box_loc_loss=box_loc_loss)
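The `map_inds` indexing above selects, for every positive proposal, the four regression values that belong to its ground-truth class out of the class-specific `box_regression` columns. A minimal sketch of that selection, assuming `self.offset` is `torch.tensor([0, 1, 2, 3])` and a hypothetical 3-foreground-class head:

import torch

# Hypothetical setup: 2 positive proposals, 3 foreground classes + background,
# so box_regression has 4 * 4 = 16 class-specific delta columns.
box_regression = torch.arange(2 * 16, dtype=torch.float32).view(2, 16)
labels_pos = torch.tensor([1, 3])        # ground-truth class of each positive proposal
offset = torch.tensor([0, 1, 2, 3])      # assumed value of self.offset

map_inds = 4 * labels_pos[:, None] + offset          # [[4, 5, 6, 7], [12, 13, 14, 15]]
rows = torch.arange(2)
selected = box_regression[rows[:, None], map_inds]   # (2, 4) deltas of the GT class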
Example #2
    def forward(self,
                graph: dgl.DGLGraph,
                graph_state_embedding: torch.Tensor,
                extra_node_state: torch.Tensor = None,
                choose_node=None):
        num_possible_nodes = graph.number_of_nodes()
        possible_nodes = range(graph.number_of_nodes())
        possible_nodes_embed = graph.nodes[possible_nodes].data[
            MessageKey.repr_vertex]
        per_node_extra_node_state = extra_node_state.expand(
            num_possible_nodes, -1)
        per_node_graph_state_embedding = graph_state_embedding.expand(
            num_possible_nodes, -1)
        per_node_choice = self._choose_node(
            torch.cat([
                per_node_graph_state_embedding, per_node_extra_node_state,
                possible_nodes_embed
            ], dim=1))
        choices_logit = per_node_choice.view(-1, num_possible_nodes)
        choices_probs = F.softmax(choices_logit, dim=1)

        if not self.training:
            return Categorical(choices_probs).sample()
        else:
            assert choose_node < num_possible_nodes
            log_loss = F.cross_entropy(
                choices_logit,
                choose_node.view(1)) / np.log(num_possible_nodes)
            self.add_log_loss(log_loss)
            return choose_node
Example #3
    def forward(self, outputs, targets):
        """損失関数の計算

        Args:
            outputs PSPNetの出力(tuple): (output=torch.Size([num_batch, 21, 475, 475]),
                                         outupt_aux=torch.Size([num_batch, 21, 475, 475]))

            targets [num_batch, 475, 475]: 正解のアノテーション情報

        Returns:
            loss : テンソル 損失
        """

        loss = F.cross_entropy(outputs[0], targets, reduction='mean')
        loss_aux = F.cross_entropy(outputs[1], targets, reduction='mean')

        return loss + self.aux_weight * loss_aux
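As a rough usage sketch (not taken from the snippet itself), the auxiliary-loss combination above can be exercised with dummy tensors; `aux_weight=0.4` and the small spatial size are assumptions, only the 21-class shape follows the docstring:

import torch
import torch.nn.functional as F

num_batch, num_classes, size = 2, 21, 64   # small spatial size just for the sketch
output = torch.randn(num_batch, num_classes, size, size)      # main head
output_aux = torch.randn(num_batch, num_classes, size, size)  # auxiliary head
targets = torch.randint(0, num_classes, (num_batch, size, size))

aux_weight = 0.4  # assumed; the snippet reads it from self.aux_weight
loss = F.cross_entropy(output, targets, reduction='mean')
loss_aux = F.cross_entropy(output_aux, targets, reduction='mean')
total_loss = loss + aux_weight * loss_aux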
Example #4
    def train(self, obs: dict) -> None:
        obs, z = self._split_obs(obs)
        z = torch.argmax(z, dim=1, keepdim=True).squeeze(1)
        loss = F.cross_entropy(self.model.forward(obs), z)
        self.log("loss", loss)

        self.optim.zero_grad()
        loss.backward()
        self.optim.step()
Example #5
 def validation_step(self, batch, batch_idx):
     x, y = batch
     y_pred = self(x)
     val_loss = F.cross_entropy(y_pred, y)
     self.val_acc(y_pred, y)
     self.log('val_acc',
              self.val_acc,
              prog_bar=True,
              on_step=False,
              on_epoch=True)
     return val_loss
Example #6
 def test_step(self, batch, batch_idx):
     x, y = batch
     y_pred = self(x)
     loss = F.cross_entropy(y_pred, y)
     self.test_acc(y_pred, y)
     self.log('test_acc',
              self.test_acc,
              prog_bar=True,
              on_step=False,
              on_epoch=True)
     return loss
Example #7
    def forward(self, embedding, action=None):
        next_state_logits = self._next_state_action(embedding)

        if self.training:
            next_state_log_loss = F.cross_entropy(
                next_state_logits, self.map_action_to_index(action))
            self.add_log_loss(next_state_log_loss)
            return action
        else:
            next_state_probs = F.softmax(next_state_logits, dim=1)
            next_action_ix = Categorical(next_state_probs).sample().item()
            return self.map_index_to_action(next_action_ix)
Example #8
    def calc_loss(self, coord, label, loc, conf):
        # step2. hard negative mining
        num_class = conf.size(2)

        coord = coord.view(-1, 4)
        loc = loc.view(-1, 4)

        label = label.view(-1)
        conf = conf.view(-1, num_class)

        # "positive" means label is not background
        pos_mask = label != 0
        pos_conf = conf[pos_mask]
        pos_label = label[pos_mask]

        # sort background confidence by loss in descending order
        tmp = F.cross_entropy(conf, label, reduction='none')
        tmp[pos_mask] = 0.

        _, neg_indices = tmp.sort(descending=True)

        # pick num(positive_samples)*3 of negative samples per batch
        num_pos = pos_conf.size(0)
        num_neg = min(num_pos * 3, conf.size(0) - num_pos)

        neg_conf = conf[neg_indices[0:num_neg]]
        neg_label = label[neg_indices[0:num_neg]]

        conf = torch.cat([pos_conf, neg_conf], 0)
        label = torch.cat([pos_label, neg_label], 0)

        l_conf = F.cross_entropy(conf, label, reduction='sum')

        # - calc l_loc
        coord = coord[pos_mask]
        loc = loc[pos_mask]

        l_loc = F.smooth_l1_loss(loc, coord, reduction='sum')

        return (l_conf + self.alpha * l_loc) / num_pos
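The negative-sample selection in `calc_loss` can be checked on a toy tensor; this is a minimal sketch of the same keep-all-positives, keep-3x-hardest-negatives rule with made-up per-box losses:

import torch

label = torch.tensor([0, 2, 0, 0, 1, 0])           # 0 = background
per_box_loss = torch.tensor([0.3, 1.2, 0.9, 0.1, 2.0, 0.7])

pos_mask = label != 0
neg_loss = per_box_loss.clone()
neg_loss[pos_mask] = 0.                             # positives never compete as negatives

_, neg_indices = neg_loss.sort(descending=True)
num_pos = int(pos_mask.sum())
num_neg = min(num_pos * 3, label.numel() - num_pos)
hard_negatives = neg_indices[:num_neg]              # indices of the kept background boxes
# e.g. tensor([2, 5, 0, 3]) here: the 4 background boxes with the largest losses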
Example #9
    def test_step(self, batch, batch_idx):
        x, y = batch
        y_pred = self(x)
        loss = F.cross_entropy(y_pred, y)
        self.test_acc(y_pred, y)
        self.log('test_acc',
                 self.test_acc,
                 prog_bar=True,
                 on_step=False,
                 on_epoch=True)

        pred = torch.argmax(y_pred, -1)
        return dict(loss=loss, pred=pred, y=y)
Example #10
 def calc_rewards(self, obs: dict, eps: float = 1e-7) -> torch.Tensor:
     with torch.no_grad():
         obs, z = self._split_obs(obs)
         z = torch.argmax(z, dim=1, keepdim=True).squeeze(1)
         logits = self.model.forward(obs)
         rewards = (
             (self.reward_weight * -F.cross_entropy(logits, z, reduction="none"))
             .unsqueeze(1)
             .detach()
         )
         pred_z = pyd.Categorical(logits=logits).sample()
         self.log("accuracy", (pred_z == z).float().mean())
         self.log("rewards", rewards)
         return rewards
Example #11
def cal_loss(pred, gold, PAD, smoothing='1'):
    ''' Calculate cross entropy loss, apply label smoothing if needed. '''

    gold = gold.contiguous().view(-1)

    if smoothing == '0':
        eps = 0.1
        n_class = pred.size(1)

        one_hot = torch.zeros_like(pred).scatter(1, gold.view(-1, 1), 1)
        one_hot = one_hot * (1 - eps) + (1 - one_hot) * eps / (n_class - 1)
        log_prb = F.log_softmax(pred, dim=1)

        non_pad_mask = gold.ne(PAD)
        loss = -(one_hot * log_prb).sum(dim=1)
        loss = loss.masked_select(non_pad_mask).sum()  # average later
    elif smoothing == '1':
        loss = F.cross_entropy(pred, gold, ignore_index=PAD)
    else:
        # loss = F.cross_entropy(pred, gold, ignore_index=PAD)
        loss = F.cross_entropy(pred, gold)

    return loss
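For comparison (not part of the snippet), PyTorch 1.10 and later expose label smoothing directly through `F.cross_entropy`, which covers roughly the same ground as the manual one-hot branch above, up to the reduction (the manual branch sums and averages later) and how the smoothing mass is spread (the built-in spreads eps over all classes, the manual version over the n_class - 1 non-target classes):

import torch
import torch.nn.functional as F

pred = torch.randn(4, 5)              # 5-class logits
gold = torch.tensor([1, 3, 0, 2])
PAD = 0                               # assumed padding index

# Built-in label smoothing (PyTorch >= 1.10); mean reduction, padding ignored.
loss = F.cross_entropy(pred, gold, ignore_index=PAD, label_smoothing=0.1)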
Example #12
 def training_step(self, batch, batch_idx):
     x, y = batch
     y_pred = self(x)
     loss = F.cross_entropy(y_pred, y)
     train_acc_batch = self.train_acc(y_pred, y)
     self.log('train_acc_batch',
              train_acc_batch,
              prog_bar=True,
              on_step=True)
     self.log('train_acc',
              self.train_acc,
              prog_bar=True,
              on_step=False,
              on_epoch=True)
     return loss
Example #13
 def loss(outputs, labels):
     outputs = outputs.reshape(outputs.shape[0], -1)
     labels = labels.reshape(labels.shape[0], -1)
     loss = F.cross_entropy(outputs.double(), labels.double().max(1)[1])
     return loss
Example #14
 def distillation(self, y, labels, teacher_scores, temp, alpha):
     soft_loss = self.KLDivLoss(F.log_softmax(y / temp, dim=1),
                                F.softmax(teacher_scores / temp, dim=1))
     hard_loss = F.cross_entropy(y, labels)
     return soft_loss * (temp * temp * 2.0 * alpha) + hard_loss * (1. - alpha)
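A toy call of the distillation loss above, assuming `self.KLDivLoss` is an `nn.KLDivLoss(reduction='batchmean')` instance; the temperature and alpha values are made up for this sketch:

import torch
import torch.nn as nn
import torch.nn.functional as F

kld = nn.KLDivLoss(reduction='batchmean')   # assumed configuration of self.KLDivLoss
y = torch.randn(8, 10)                      # student logits
teacher_scores = torch.randn(8, 10)         # teacher logits
labels = torch.randint(0, 10, (8,))
temp, alpha = 4.0, 0.7                      # assumed hyperparameters

soft_loss = kld(F.log_softmax(y / temp, dim=1),
                F.softmax(teacher_scores / temp, dim=1)) * (temp * temp * 2.0 * alpha)
hard_loss = F.cross_entropy(y, labels) * (1. - alpha)
loss = soft_loss + hard_loss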
Example #15
def run(rank, size):
    torch.manual_seed(1234)
    train_set, bsz, test_set, fake_set = partition_dataset()
    print(len(train_set))
    model = Simple.Simple()
    optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
    clients = torch.arange(1, dist.get_world_size())

    print('Rank', dist.get_rank())

    if dist.get_rank() != 0:
        num_batches = ceil(len(train_set.dataset) / float(bsz))
        train_set_size = len(train_set)
        number_of_epochs = 0

        losses = []
        mini_batch_loss = []
        train_set_iterator = iter(train_set)
        fake_set_iterator = iter(fake_set)
        epoch_loss = 0.0
        index = 0
        current_batch_loss = []
        turn = torch.tensor(0)
        participating_clients = torch.zeros(participating_counts,
                                            dtype=torch.long)
        mini_batches_to_go = 0
        group = None
        writer = SummaryWriter(
            'mnist-multiplied-100/logs-mnist-fedsgd-3-attacker/')
    number_of_elections = 0
    while True:
        if dist.get_rank() == 0:
            print('============election phase %s started===========' %
                  number_of_elections)
            number_of_elections += 1
            permutation = torch.randperm(dist.get_world_size() - 1)
            clients = clients[permutation]
            participating_clients = clients[0:participating_counts]
            participating_clients_final = participating_clients.tolist()
            print('Participating %s' % participating_clients_final)
            mini_batches_to_go = LOCAL_ITERATIONS
            for index_2, client in enumerate(clients.tolist()):
                if index_2 < participating_counts:
                    dist.send(torch.tensor(1), dst=client)
                    print('Turn sent for client %s' % client)
                else:
                    dist.send(torch.tensor(0), dst=client)
                    print('Turn sent for client %s' % client)
                dist.send(participating_clients, dst=client)
            print('Client %s passed the dist.new_group' % dist.get_rank())
            group = dist.new_group(participating_clients_final)
        else:
            if mini_batches_to_go == 0:
                print('============election phase %s started===========' %
                      number_of_elections)
                number_of_elections += 1
                if number_of_elections == 450 * 5 / LOCAL_ITERATIONS:
                    number_of_elections = 0
                    print("////////// Epoch Changed! %s //////////////" %
                          number_of_epochs)
                    number_of_epochs += 1
                    if number_of_epochs == EPOCHS:
                        break
                    train_set_iterator = iter(train_set)
                    fake_set_iterator = iter(fake_set)
                    epoch_loss = 0.0

                print('Waiting for the current turn...')
                dist.recv(turn, src=0)
                print('Turn is %s' % turn)
                print('Waiting for the winners of the election...')
                dist.recv(participating_clients, src=0)
                participating_clients_final = participating_clients.tolist()
                print('Participating %s' % participating_clients_final)
                print('Client %s passed the dist.new_group' % dist.get_rank())
                # This needs to be fixed!! Race Condition!! Distributed Systems Bug!
                time.sleep(0.03)
                group = dist.new_group(participating_clients_final)
                if turn == 1:
                    mini_batches_to_go = LOCAL_ITERATIONS
        # Only those with the pass can enter the game!
        if dist.get_rank() != 0 and turn == 1:
            if dist.get_rank() not in fake_targets:
                data, target = next(train_set_iterator)
            else:
                data, target = next(fake_set_iterator)
            mini_batches_to_go -= 1
            index += 1
            if index % 250 == 0 and dist.get_rank() == 1:
                model_accuracy = accuracy(model, test_set)
                writer.add_scalar('%s/Validation/Accuracy' % dist.get_rank(),
                                  model_accuracy, index)
                writer.flush()
            optimizer.zero_grad()
            output = model(data)
            loss = F.cross_entropy(output, target)
            writer.add_scalar('%s/Train/Loss' % dist.get_rank(), loss.item(),
                              index)
            writer.flush()
            epoch_loss += loss.item()
            current_batch_loss.append(epoch_loss / index)
            loss.backward()
            print('Processed', str(index) + '/' + str(train_set_size))
            if mini_batches_to_go == 0:
                print('Aggregating...')
                participating_clients_final = participating_clients.tolist()
                print('Participating %s' % participating_clients_final)
                average_gradients(model, group=group)
            optimizer.step()
    #mini_batch_loss.append(epoch_loss)
    #losses.append(epoch_loss)
    #print(losses)
    if dist.get_rank() == 1:
        accuracy(model, test_set)
        torch.save(model.state_dict(), './model-2.pt')
Example #16
    def forward(self, predictions, targets):
        """
        損失関数の計算
        Args:
            predictions: SSD netの訓練時の出力(tuple)
             loc=torch.Size([num_batch, 8732, 4]),
             conf=torch.Size([num_batch, 8732, 21]),
             dbox_list=torch.Size([8732, 4])

            targets: [num_batch, num_jobs, 5]
            5は正解アノテーション情報[xmin, ymin, xmax, ymax, label_index]を示す

        Returns:
            loss_l: locの損失値 SmoothL1Loss
            loss_c: confの損失値 CrossEntropyLoss
        """

        loc_data, conf_data, dbox_list = predictions
        # print("loc_data size: ", loc_data.size())
        num_batch = loc_data.size(0)  # mini-batch size
        num_dbox = loc_data.size(1)  # number of DBoxes (8732)
        num_classes = conf_data.size(2)  # number of classes (21)

        # Variables used in the loss computation
        # conf_t_label: for each of the 8732 DBoxes, the label of the closest ground-truth BBox
        # loc_t: for each of the 8732 DBoxes, the location of the closest ground-truth BBox
        conf_t_label = torch.LongTensor(num_batch,
                                        num_dbox).to(self.device)  # torch.long
        loc_t = torch.Tensor(num_batch, num_dbox,
                             4).to(self.device)  # torch.Tensor defaults to torch.float32
        # print("loc_t size: ", loc_t.size())
        # print("conf_t_label size: ", conf_t_label.size())

        # Overwrite loc_t and conf_t_label with the result of matching the DBoxes to the ground-truth annotations targets (BBoxes)
        for idx in range(num_batch):
            truths_loc = targets[idx][:, :-1].to(self.device)  # BBox
            labels_conf = targets[idx][:, -1].to(self.device)  # Labels
            # print("truths_loc size: ", truths_loc.size())
            # print("labels_conf size: ", labels_conf)

            dbox = dbox_list.to(self.device)

            # Run the match function to update the contents of loc_t and conf_t_label
            # (details)
            # loc_t: overwritten with the location of the closest ground-truth BBox for each DBox
            # conf_t_label: overwritten with the label of the closest ground-truth BBox for each DBox
            # However, if the Jaccard overlap with the closest BBox is below 0.5, conf_t_label is set to the background class 0
            variance = [0.1, 0.2]
            # loc_t[idx], conf_t_label[idx] = match(self.jaccard_thresh, truths_loc, dbox, variance, labels_conf)
            match(self.jaccard_thresh, truths_loc, dbox, variance, labels_conf,
                  loc_t, conf_t_label, idx)

        # At this point,
        # loc_t holds valid values only for those of its 8732 entries that correspond to Positive DBoxes,
        # while conf_t_label keeps all 8732 entries: Positive DBoxes hold the class label of the target BBox, Negative DBoxes hold the background class (0)

        # -----
        # Localization loss: loss_l
        # Smooth L1 function
        # Computed, however, only for the offsets of the DBoxes that detected an object
        # -----

        # Mask that extracts the DBoxes that detected an object (Positive DBoxes)
        pos_mask = conf_t_label > 0  # torch.Size([num_batch, 8732])

        # torch.Size([num_batch, 8732]) -> torch.Size([num_batch, 8732, 4])
        pos_idx = pos_mask.unsqueeze(pos_mask.dim()).expand_as(loc_data)

        # Gather the predicted offsets loc_data and the supervision loc_t for the Positive DBoxes
        loc_p = loc_data[pos_idx].view(
            -1, 4)  # boolean-mask indexing returns a 1-D array, so reshape it
        loc_t = loc_t[pos_idx].view(-1, 4)

        # Compute the loss (error) of the offsets loc_t for the Positive DBoxes that detected an object
        loss_l = F.smooth_l1_loss(loc_p, loc_t, reduction='sum')
        # print("loc_p", loc_p)
        # print("loc_t", loc_t)
        # print("loss_l", loss_l)

        # -----
        # Class-prediction loss: loss_c
        # Cross-entropy error function
        # DBoxes whose ground truth is the background class vastly outnumber the rest,
        # so perform Hard Negative Mining so that the ratio of object DBoxes to background DBoxes becomes 1:3.
        # Among the DBoxes predicted as background, those with small loss are excluded from the class-prediction loss.
        # -----

        batch_conf = conf_data.view(
            -1, num_classes)  # (batch_num,8732,21) -> (batch_num*8732,21)
        # print("batch_conf", batch_conf)
        # print("batch_conf size: ", batch_conf.size())

        # Compute the class-prediction loss (reduction='none' keeps per-element losses instead of summing)
        # batch_conf size: (batch_num*8732,21), conf_t_label size: (batch_num*8732,)
        loss_c = F.cross_entropy(batch_conf,
                                 conf_t_label.view(-1),
                                 reduction='none')  # first compute the loss for every DBox
        # loss_c: (batch_num * 8732,)

        # -----
        # Build the mask that selects which Negative DBoxes to extract via Hard Negative Mining
        # -----

        # Zero out the loss of the Positive DBoxes that detected an object
        # (note) objects have label >= 1; 0 is the background
        num_pos = pos_mask.long().sum(
            dim=1, keepdim=True
        )  # number of Positive Boxes per input image: (batch_num, 8732) -> (batch_num, 1)
        loss_c = loss_c.view(num_batch, -1)  # torch.Size([num_batch, 8732])
        loss_c[pos_mask] = 0  # set the loss of DBoxes that detected an object to 0

        # Perform Hard Negative Mining
        """Compute idx_rank, the rank of each DBox's loss magnitude in loss_c."""
        _, loss_idx = loss_c.sort(dim=1,
                                  descending=True)  # sort the 8732 DBoxes by loss in descending order
        _, idx_rank = loss_idx.sort(dim=1)
        # loss_idx holds the original indices of the DBoxes sorted by loss in descending order
        """
        (Note)
        The two lines above are idiosyncratic and not intuitive.
        The goal is to obtain, for each DBox, its rank by loss magnitude as idx_rank, and to do so quickly.

        The DBox losses are sorted in descending order and the descending index order is stored in loss_idx.
        From this we compute idx_rank, the rank of each loss value in loss_c.
        To restore the descending index array loss_idx to ascending order 0..8732,
        idx_rank indicates which position of loss_idx should be taken.
        For example, to obtain element 0 of idx_rank, idx_rank[0], we look for the element of loss_idx
        whose value is 0, i.e. the ? in loss_idx[?] = 0; that ? equals idx_rank[0].
        The 0 in loss_idx[?] = 0 refers to element 0 of the original loss_c.
        So ? answers the question "at which position of the descending ordering loss_idx does
        element 0 of loss_c appear", and therefore ? = idx_rank[0] is the descending-order rank
        of element 0 of loss_c.

        e.g.
        loss_c                      3.2  5.8  1.3  2.5  4.0
        sorted_loss_c               5.8  4.0  3.2  2.5  1.3
        descending_of_loss_c_index    1    4    0    3    2 (loss_idx)
        sorted_loss_idx               0    1    2    3    4
        ascending_of_loss_idx         2    0    4    3    1 (idx_rank)

        """

        # Determine num_neg, the number of background DBoxes: with Hard Negative Mining it is
        # self.negpos_ratio (3) times num_pos, the number of DBoxes that detected an object.
        # If that would exceed the total number of DBoxes, cap it at the number of DBoxes.
        num_neg = torch.clamp(num_pos * self.negpos_ratio, max=num_dbox)

        # Mask that extracts the DBoxes ranked within num_neg (i.e. with the largest losses)
        neg_mask = idx_rank < num_neg.expand_as(idx_rank)

        # -----
        # (end)
        # -----

        # Build the mask that selects which Negative DBoxes to extract via Hard Negative Mining

        # pos_mask: torch.Size([num_batch, 8732]) -> pos_idx_mask: torch.Size([num_batch, 8732, 21])
        pos_idx_mask = pos_mask.unsqueeze(2).expand_as(conf_data)
        neg_idx_mask = neg_mask.unsqueeze(2).expand_as(conf_data)

        # Extract only pos and neg into conf_hnm. torch.Size([num_pos + num_neg, 21])
        # gt is short for greater than (>); this picks the indices where the mask is 1.
        conf_hnm = conf_data[(pos_idx_mask + neg_idx_mask).gt(0)].view(
            -1, num_classes)

        # conf_t_label restricted to pos and neg: torch.Size([pos + neg])
        conf_t_label_hnm = conf_t_label[(pos_mask + neg_mask).gt(0)]

        # Compute the confidence loss
        loss_c = F.cross_entropy(conf_hnm, conf_t_label_hnm, reduction='sum')
        # print("conf_hnm", conf_hnm)
        # print("conf_t_label_num", conf_t_label_hnm)
        # print("loss_c", loss_c)

        # Divide the losses by N, the number of BBoxes that detected an object (summed over the mini-batch)
        N = num_pos.sum()
        loss_l /= N
        loss_c /= N

        return loss_l, loss_c
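The double argsort used for `idx_rank` above can be verified on the example values from the comment; a quick sketch:

import torch

loss_c = torch.tensor([[3.2, 5.8, 1.3, 2.5, 4.0]])
_, loss_idx = loss_c.sort(dim=1, descending=True)   # tensor([[1, 4, 0, 3, 2]])
_, idx_rank = loss_idx.sort(dim=1)                   # tensor([[2, 0, 4, 3, 1]])
# loss_c[0, 1] = 5.8 is the largest value, and idx_rank[0, 1] == 0 confirms rank 0.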
Example #17
def xloss(logits, labels, ignore=None):
    """
    Cross entropy loss; pixels labeled 255 are ignored.
    """
    return F.cross_entropy(logits, Variable(labels), ignore_index=255)