예제 #1
0
    def generate_radial_samples(distance_a, radial_sample_comb,
                                radial_neighbor_combinations):
        """Build one summed radial feature tensor per entry of ``distance_a``.

        For every distance tensor, the index tensors stored in
        ``radial_neighbor_combinations`` (looked up by the tensor's length
        along dim 0) are used to gather values from it.  The first gathered
        column is dropped, the remainder is tiled once per row of
        ``radial_sample_comb``, passed through ``exponential_map``, weighted
        by ``f_c`` and summed over dim 1.

        Args:
            distance_a: Sequence of tensors, each indexable along dim 0.
            radial_sample_comb: Tensor whose dim-0 length sets how many
                times each gathered block is tiled; it is also forwarded
                unchanged to ``exponential_map``.
            radial_neighbor_combinations: Container indexed by the dim-0
                length of a distance tensor, yielding an iterable of index
                tensors accepted by ``torch.index_select``.

        Returns:
            A list holding one tensor per element of ``distance_a``.
        """
        sample_count = radial_sample_comb.size()[0]
        print("generating " + str(sample_count) + " radial elements")

        features = []
        for distances in distance_a:
            pair_indices = radial_neighbor_combinations[distances.size()[0]]

            # NOTE(review): zip() stops at the shorter input, so no more
            # index tensors are consumed than `distances` has entries along
            # dim 0 -- confirm this cap is intended.
            gathered = torch.stack([
                torch.index_select(distances, 0, pair)
                for _, pair in zip(distances, pair_indices)
            ])

            # Drop the first gathered column and keep the rest.
            kept = gathered[:, 1:]

            # Tile the kept block once per sample row -> (sample_count, -1).
            tiled = torch.cat([kept] * sample_count).reshape(sample_count, -1)

            mapped = exponential_map(tiled, radial_sample_comb)
            cutoff = f_c(tiled)
            features.append(mapped.mul(cutoff).sum(dim=1))

        return features
예제 #2
0
def _gather_rows(values, index_sets):
    """Gather ``values`` along dim 0 once per index tensor and stack the results."""
    return torch.cat(
        tuple(
            torch.index_select(values, 0, indices).unsqueeze(0)
            for indices in index_sets))


def _tile_rows(block, copies):
    """Repeat ``block`` ``copies`` times along dim 0, reshaped to ``(copies, -1)``."""
    return torch.reshape(torch.cat((block,) * copies), (copies, -1))


def generate_angular_samples(distance_a, angular_sample_comb,
                             angular_neighbor_combinations, neighbor_x,
                             neighbor_y, neighbor_z):
    """Build one summed angular feature tensor per entry of ``distance_a``.

    For every entry ``i`` the index tensors found in
    ``angular_neighbor_combinations`` (looked up by the dim-0 length of
    ``distance_a[i]``) gather distances and x/y/z coordinates.  Three
    factors are then combined for every (sample row, index tensor) pair and
    summed over dim 1:

    1. ``f_c`` of the second gathered distance times ``f_c`` of the
       remaining gathered distances;
    2. ``exp(-c0 * (mean_distance - c1) ** 2)``;
    3. ``(1 + cos(c3) * cos_t + sin(c3) * sin_t) ** c2 * 2 ** (1 - c2)``,

    where ``c0 .. c3`` are columns 0, 1, 2 and 3-onwards of
    ``angular_sample_comb`` and ``cos_t`` is the dot product of the two
    coordinate-difference vectors divided by the two gathered distances.

    ``sin_t`` is computed as ``sqrt(1 - cos_t ** 2)`` with the radicand
    clamped at zero, so a ``cos_t`` that rounding pushes marginally outside
    ``[-1, 1]`` no longer yields NaN.

    NOTE(review): this looks like an ANI / Behler-Parrinello style angular
    symmetry function with the first gathered index acting as the central
    atom -- confirm against the caller.

    Args:
        distance_a: Sequence of distance tensors, indexable along dim 0.
        angular_sample_comb: 2-D tensor of sample parameters; columns are
            read as described above.
        angular_neighbor_combinations: Container indexed by the dim-0
            length of a distance tensor, yielding a re-iterable collection
            of index tensors accepted by ``torch.index_select``.
        neighbor_x: Per-entry x coordinates, indexed like ``distance_a``.
        neighbor_y: Per-entry y coordinates, indexed like ``distance_a``.
        neighbor_z: Per-entry z coordinates, indexed like ``distance_a``.

    Returns:
        A list holding one tensor per element of ``distance_a``.
    """
    sample_count = angular_sample_comb.size()[0]

    # Loop-invariant pieces derived from the sample parameters.
    decay = angular_sample_comb[:, :1]
    radial_shift = angular_sample_comb[:, 1:2]
    exponent = angular_sample_comb[:, 2:3]
    angle_shift = angular_sample_comb[:, 3:]
    cos_shift = torch.cos(angle_shift)
    sin_shift = torch.sin(angle_shift)
    prefactor = 2.0**(1 - exponent)

    result = []
    for i, distances in enumerate(distance_a):
        triples = angular_neighbor_combinations[distances.size()[0]]

        # part 1: product of cutoff values for the gathered distances
        # (the first gathered column is discarded).
        pair_dist = _gather_rows(distances, triples)[:, 1:]
        first_leg = pair_dist[:, :1]
        other_legs = pair_dist[:, 1:]
        cutoff = _tile_rows(f_c(first_leg).mul(f_c(other_legs)), sample_count)

        # part 2: exponential of the shifted mean distance.
        mean_dist = _tile_rows(
            first_leg.add(other_legs).div(2.0), sample_count)
        radial_term = torch.exp(
            ((mean_dist - radial_shift)**2.0).mul(-decay)).mul(cutoff)

        # part 3: angle between the two difference vectors that start at
        # the first gathered point.
        xs = _gather_rows(neighbor_x[i], triples)
        ys = _gather_rows(neighbor_y[i], triples)
        zs = _gather_rows(neighbor_z[i], triples)

        dx1, dx2 = xs[:, 1:2] - xs[:, :1], xs[:, 2:3] - xs[:, :1]
        dy1, dy2 = ys[:, 1:2] - ys[:, :1], ys[:, 2:3] - ys[:, :1]
        dz1, dz2 = zs[:, 1:2] - zs[:, :1], zs[:, 2:3] - zs[:, :1]

        inner_product = dx1 * dx2 + dy1 * dy2 + dz1 * dz2
        cos_angle = inner_product.div(pair_dist[:, 0:1]).div(
            pair_dist[:, 1:2])
        # Clamp the radicand: rounding can leave cos_angle**2 a hair above
        # 1 (e.g. collinear points), and sqrt of a negative value is NaN.
        sin_angle = torch.sqrt(
            torch.clamp(1.0 - cos_angle**2.0, min=0.0))

        cos_tiled = _tile_rows(cos_angle, sample_count)
        sin_tiled = _tile_rows(sin_angle, sample_count)

        # cos(a) * cos(b) + sin(a) * sin(b), evaluated per sample row.
        shifted_cos = cos_shift.mul(cos_tiled).add(sin_shift.mul(sin_tiled))
        angular_term = shifted_cos.add(1.0).pow(exponent).mul(prefactor).mul(
            radial_term)

        result.append(angular_term.sum(dim=1))

    return result