Пример #1
0
    def set_teacher_signal(self, y):
        """Forward a label-derived teacher signal to the supervised layers.

        ``y`` may be ``None``, a dict (labels are then taken from
        ``y[P.KEY_LABEL_TARGETS]``), or dense class labels.

        - Non-``None`` labels are turned into one-hot form over
          ``self.NUM_CLASSES`` classes by ``utils.dense2onehot`` and handed to
          ``fc6``.
        - ``None`` is forwarded to ``fc6``, ``conv4`` and ``fc5``, in that
          order.
        - When ``self.DEEP_TEACHER_SIGNAL`` is truthy, the one-hot signal is
          also widened and handed to ``fc5``, and to ``conv4`` only when
          ``self.NUM_CLASSES <= 20``.
        - When ``self.DEEP_TEACHER_SIGNAL`` is falsy, non-``None`` labels
          leave ``conv4``/``fc5`` untouched.
        """
        if isinstance(y, dict):
            y = y[P.KEY_LABEL_TARGETS]
        if y is not None:
            y = utils.dense2onehot(y, self.NUM_CLASSES)

        self.fc6.set_teacher_signal(y)

        if y is None:
            # Clearing the signal also clears it on the deep layers.
            for deep_layer in (self.conv4, self.fc5):
                deep_layer.set_teacher_signal(y)
            return
        if not self.DEEP_TEACHER_SIGNAL:
            return

        # Per-class unit counts carved out of fixed budgets of 240 (conv4)
        # and 4000 (fc5) units; presumably these match the layer widths --
        # TODO confirm.
        conv4_per_class = 240 // self.NUM_CLASSES
        fc5_per_class = 4000 // self.NUM_CLASSES

        def widen(signal, n_units, per_class):
            # Build [ones | tiled] along dim 1. `tiled` repeats every class
            # column `per_class` times consecutively ([a, b] -> [a, a, b, b]
            # for per_class=2). The leading block of ones covers the
            # `n_units - per_class * NUM_CLASSES` units not tied to a class.
            # `n_units` comes from the layer's weight.size(0), presumably its
            # number of output units/kernels.
            batch = signal.size(0)
            n_free = n_units - per_class * self.NUM_CLASSES
            free = torch.ones(batch, n_free, device=signal.device)
            tiled = signal.view(batch, signal.size(1), 1).repeat(
                1, 1, per_class).view(batch, -1)
            return torch.cat((free, tiled), dim=1)

        if self.NUM_CLASSES <= 20:
            self.conv4.set_teacher_signal(
                widen(y, self.conv4.weight.size(0), conv4_per_class))
        self.fc5.set_teacher_signal(
            widen(y, self.fc5.weight.size(0), fc5_per_class))
Пример #2
0
    def set_teacher_signal(self, y):
        """Forward a label-derived teacher signal to the supervised layers.

        ``y`` is ``None`` or dense class labels.

        - Labels are converted to one-hot form over ``self.NUM_CLASSES``
          classes by ``utils.dense2onehot`` and handed to ``fc10``.
        - ``None`` is forwarded to ``fc10``, ``conv8`` and ``fc9``, in that
          order.
        - When ``self.DEEP_TEACHER_SIGNAL`` is truthy, the one-hot signal is
          also widened and handed to ``fc9``, and to ``conv8`` only when
          ``self.NUM_CLASSES <= 20``.
        - When ``self.DEEP_TEACHER_SIGNAL`` is falsy, non-``None`` labels
          leave ``conv8``/``fc9`` untouched.
        """
        if y is not None:
            y = utils.dense2onehot(y, self.NUM_CLASSES)

        self.fc10.set_teacher_signal(y)

        if y is None:
            # Clearing the signal also clears it on the deep layers.
            for deep_layer in (self.conv8, self.fc9):
                deep_layer.set_teacher_signal(y)
            return
        if not self.DEEP_TEACHER_SIGNAL:
            return

        # Per-class unit counts carved out of fixed budgets of 500 (conv8)
        # and 4000 (fc9) units; presumably these match the layer widths --
        # TODO confirm.
        conv8_per_class = 500 // self.NUM_CLASSES
        fc9_per_class = 4000 // self.NUM_CLASSES

        def widen(signal, n_units, per_class):
            # Build [ones | tiled] along dim 1. `tiled` repeats every class
            # column `per_class` times consecutively ([a, b] -> [a, a, b, b]
            # for per_class=2). The leading block of ones covers the
            # `n_units - per_class * NUM_CLASSES` units not tied to a class.
            # `n_units` comes from the layer's weight.size(0), presumably its
            # number of output units/kernels.
            batch = signal.size(0)
            n_free = n_units - per_class * self.NUM_CLASSES
            free = torch.ones(batch, n_free, device=signal.device)
            tiled = signal.view(batch, signal.size(1), 1).repeat(
                1, 1, per_class).view(batch, -1)
            return torch.cat((free, tiled), dim=1)

        if self.NUM_CLASSES <= 20:
            self.conv8.set_teacher_signal(
                widen(y, self.conv8.weight.size(0), conv8_per_class))
        self.fc9.set_teacher_signal(
            widen(y, self.fc9.weight.size(0), fc9_per_class))
Пример #3
0
 def set_teacher_signal(self, y):
     """Hand a one-hot teacher signal (or ``None``) to ``fc2``.

     ``y`` may be ``None``, a dict (labels are then taken from
     ``y[P.KEY_LABEL_TARGETS]``), or dense class labels. Non-``None``
     labels are converted with ``utils.dense2onehot`` over
     ``self.NUM_CLASSES`` classes.
     """
     labels = y[P.KEY_LABEL_TARGETS] if isinstance(y, dict) else y
     signal = (None if labels is None
               else utils.dense2onehot(labels, self.NUM_CLASSES))
     self.fc2.set_teacher_signal(signal)
Пример #4
0
 def set_teacher_signal(self, y):
     """Hand a one-hot teacher signal (or ``None``) to ``fc2``.

     Dense labels in ``y`` are converted with ``utils.dense2onehot`` over
     ``self.NUM_CLASSES`` classes; ``None`` is passed through unchanged.
     """
     signal = None
     if y is not None:
         signal = utils.dense2onehot(y, self.NUM_CLASSES)
     self.fc2.set_teacher_signal(signal)
Пример #5
0
 def __call__(self, outputs, targets):
     """Return ``self.mse_loss`` between class scores and one-hot targets.

     ``outputs`` may be a dict holding the scores under
     ``P.KEY_CLASS_SCORES``. ``targets`` may be a dict holding the labels
     under ``P.KEY_LABEL_TARGETS``. Labels are one-hot encoded by
     ``utils.dense2onehot``, using ``scores.size(1)`` as the class count.
     """
     scores = (outputs[P.KEY_CLASS_SCORES] if isinstance(outputs, dict)
               else outputs)
     labels = (targets[P.KEY_LABEL_TARGETS] if isinstance(targets, dict)
               else targets)
     return self.mse_loss(scores, utils.dense2onehot(labels, scores.size(1)))