Ejemplo n.º 1
0
    def reset_parameters(self):
        """Re-initialize the network's parameters in place.

        Steps, in order:
          1. ``weights_init`` is applied to the module tree via ``self.apply``.
          2. The weights of ``conv1`` .. ``conv12``, ``fc1`` and ``fc2`` are
             multiplied by the gain that ``nn.init.calculate_gain`` reports
             for ``'relu'``.
          3. If a ``gru`` submodule exists, its two weight matrices are
             orthogonally initialized and its two bias vectors zeroed.
          4. If ``self.dist`` is a ``DiagGaussian``, its ``fc_mean`` weights
             are scaled by 0.01.
        """
        self.apply(weights_init)

        gain = nn.init.calculate_gain('relu')
        # Looked up by name one at a time, so layers are scaled in the same
        # order as before and a missing layer fails at the same point.
        scaled_layers = ['conv{}'.format(idx) for idx in range(1, 13)]
        scaled_layers += ['fc1', 'fc2']
        for layer_name in scaled_layers:
            getattr(self, layer_name).weight.data.mul_(gain)

        if hasattr(self, 'gru'):
            recurrent = self.gru
            for attr in ('weight_ih', 'weight_hh'):
                orthogonal(getattr(recurrent, attr).data)
            for attr in ('bias_ih', 'bias_hh'):
                getattr(recurrent, attr).data.fill_(0)

        # Presumably keeps the initial Gaussian mean output small; the code
        # only shows the 100x down-scaling of the mean head's weights.
        dist_kind = self.dist.__class__.__name__
        if dist_kind == "DiagGaussian":
            self.dist.fc_mean.weight.data.mul_(0.01)
Ejemplo n.º 2
0
def weight_init(m):
    """Orthogonally initialize convolution and linear layers in place.

    Suitable as a callback for ``nn.Module.apply``. The function acts on any
    module whose class name contains ``'Conv'`` or ``'Linear'``: it
    initializes the module's ``weight`` orthogonally and fills its ``bias``
    with zeros when a bias is present. All other modules are left untouched.

    Args:
        m: The module to (possibly) initialize.
    """
    classname = m.__class__.__name__

    # Substring matching on the class name (rather than isinstance) also
    # covers variants such as Conv1d / Conv2d / ConvTranspose2d.
    if 'Conv' in classname or 'Linear' in classname:
        orthogonal(m.weight.data)

        # A layer may have no bias (e.g. PyTorch layers built with
        # bias=False expose ``bias`` as None).
        if m.bias is not None:
            m.bias.data.fill_(0)
Ejemplo n.º 3
0
    def reset_parameters(self):
        """Re-initialize the model's parameters in place.

        ``weights_init_mlp`` is applied to the module tree via ``self.apply``.
        The GRU then receives orthogonal weight matrices and zeroed biases.
        Finally, when the action distribution is a ``DiagGaussian``, its
        ``fc_mean`` weights are scaled by 0.01.

        This variant assumes ``self.gru`` always exists (no ``hasattr`` guard).
        """
        self.apply(weights_init_mlp)

        cell = self.gru
        for attr in ('weight_ih', 'weight_hh'):
            orthogonal(getattr(cell, attr).data)
        for attr in ('bias_ih', 'bias_hh'):
            getattr(cell, attr).data.fill_(0)

        dist_kind = self.dist.__class__.__name__
        if dist_kind == "DiagGaussian":
            self.dist.fc_mean.weight.data.mul_(0.01)
Ejemplo n.º 4
0
    def reset_parameters(self):
        """Re-initialize the model's parameters in place.

        ``weights_init`` is applied to the module tree, ``linear1``'s weights
        are rescaled by the ``'leaky_relu'`` gain, and an optional ``gru``
        submodule gets orthogonal weights and zero biases. When the action
        distribution is a ``DiagGaussian``, its ``fc_mean`` weights are
        scaled by 0.01.
        """
        self.apply(weights_init)

        # Gain for 'leaky_relu' with calculate_gain's default negative slope.
        leaky_gain = nn.init.calculate_gain('leaky_relu')
        self.linear1.weight.data.mul_(leaky_gain)

        if hasattr(self, 'gru'):
            gru = self.gru
            for attr in ('weight_ih', 'weight_hh'):
                orthogonal(getattr(gru, attr).data)
            for attr in ('bias_ih', 'bias_hh'):
                getattr(gru, attr).data.fill_(0)

        is_diag_gaussian = self.dist.__class__.__name__ == "DiagGaussian"
        if is_diag_gaussian:
            self.dist.fc_mean.weight.data.mul_(0.01)
Ejemplo n.º 5
0
    def reset_parameters(self):
        """Re-initialize the model's parameters in place.

        ``weights_init`` is applied to the whole module tree. Afterwards,
        ``weights_init_head`` is applied to ``self.head`` and, if an ``icm``
        submodule exists, to ``self.icm.head``. An optional ``gru`` gets
        orthogonal weights and zero biases. A ``DiagGaussian`` distribution
        has its ``fc_mean`` weights scaled by 0.01.
        """
        self.apply(weights_init)

        # Heads are re-initialized after the generic pass, so their
        # dedicated initializer has the final say.
        self.head.apply(weights_init_head)
        if hasattr(self, 'icm'):
            self.icm.head.apply(weights_init_head)

        if hasattr(self, 'gru'):
            recurrent = self.gru
            for attr in ('weight_ih', 'weight_hh'):
                orthogonal(getattr(recurrent, attr).data)
            for attr in ('bias_ih', 'bias_hh'):
                getattr(recurrent, attr).data.fill_(0)

        dist_kind = self.dist.__class__.__name__
        if dist_kind == "DiagGaussian":
            self.dist.fc_mean.weight.data.mul_(0.01)
Ejemplo n.º 6
0
    def reset_parameters(self):
        """Re-initialize the model's parameters in place.

        ``weights_init`` is applied to the module tree. The weights of
        ``conv1``-``conv3``, ``linear1`` and ``linear2`` are then multiplied
        by the ``'relu'`` gain. An optional ``gru`` submodule receives
        orthogonal weights and zero biases.
        """
        self.apply(weights_init)

        gain = nn.init.calculate_gain('relu')
        # Resolved by name one at a time to keep the original scaling order.
        for layer_name in ('conv1', 'conv2', 'conv3', 'linear1', 'linear2'):
            getattr(self, layer_name).weight.data.mul_(gain)

        if hasattr(self, 'gru'):
            recurrent = self.gru
            for attr in ('weight_ih', 'weight_hh'):
                orthogonal(getattr(recurrent, attr).data)
            for attr in ('bias_ih', 'bias_hh'):
                getattr(recurrent, attr).data.fill_(0)
Ejemplo n.º 7
0
    def add_item(self, item):
        """Try to add a boundary line between this item's point and ``item``'s.

        The candidate line passes through the midpoint of the two points. Its
        direction is ``utils.orthogonal`` of the vector from this point to
        ``item.point`` (presumably the perpendicular bisector). The line is
        wrapped in an ``OtherItem``. The candidate and every existing entry of
        ``self.others`` are then constrained against each other, in both
        directions, relative to ``self.point``:

        * existing entries that stop being valid are removed;
        * if the candidate itself stops being valid, the loop stops early
          and the candidate is discarded. Removals found so far still apply.

        Returns:
            The new ``OtherItem`` when it was kept, otherwise ``None``.
        """
        here = self.point
        there = item.point

        offset = utils.vector(here, there)
        bisector_dir = utils.orthogonal(offset)
        midpoint = (0.5 * (here[0] + there[0]),
                    0.5 * (here[1] + there[1]))
        candidate = OtherItem(item, shapes.InfiniteLine(midpoint, bisector_dir))

        invalidated = set()
        for neighbour in self.others:
            # Mutual clipping; the exact semantics live in ``constrain``.
            neighbour.line.constrain(candidate, here)
            candidate.line.constrain(neighbour, here)

            if not neighbour.is_valid():
                invalidated.add(neighbour)

            if not candidate.is_valid():
                accepted = False
                break
        else:
            # No early exit (including when there are no others yet).
            accepted = True

        self.others -= invalidated

        if not accepted:
            return None
        self.others |= {candidate}
        return candidate
Ejemplo n.º 8
0
    def item_that_contains(self, point):
        """Return the first bounded point item whose edges all enclose ``point``.

        Unbounded items are skipped. For a bounded item, each edge in its
        ``others`` is tested. The item is rejected as soon as ``point`` and
        the item's own point produce dot products of opposite sign with the
        edge normal, i.e. they are strictly on opposite sides of the edge
        line. A point exactly on an edge line is therefore accepted.

        Raises:
            Exception: if no bounded item accepts ``point``.
        """
        def across_edge(edge, site):
            # Normal of the edge line, measured from the edge's first vertex.
            normal = utils.orthogonal(edge.line.direction)
            anchor, _ = edge.vertices()
            to_site = utils.vector(anchor, site)
            to_query = utils.vector(anchor, point)
            return utils.dot(to_site, normal) * utils.dot(to_query, normal) < 0

        for candidate in self.point_items:
            if not candidate.is_bounded():
                continue
            # any() stops at the first separating edge, like the early break.
            if not any(across_edge(edge, candidate.point)
                       for edge in candidate.others):
                return candidate
        raise Exception("PointSet.item_that_contains: not found")
Ejemplo n.º 9
0
 def reset_parameters(self):
     """Re-initialize the recurrent cell and the critic bias in place.

     The GRU's ``weight_ih`` and ``weight_hh`` are orthogonally initialized.
     Its ``bias_ih`` and ``bias_hh`` are zeroed, and so is the bias of
     ``critic_linear``.
     """
     gru = self.gru
     for attr in ('weight_ih', 'weight_hh'):
         orthogonal(getattr(gru, attr).data)
     for attr in ('bias_ih', 'bias_hh'):
         getattr(gru, attr).data.fill_(0)
     self.critic_linear.bias.data.fill_(0)
Ejemplo n.º 10
0
 def __init__(self, point, direction):
     """Store an anchor point and a direction vector (presumably of a line).

     ``normal`` is derived from the stored direction via
     ``utils.orthogonal``.
     """
     self.point, self.direction = point, direction
     self.normal = utils.orthogonal(self.direction)