Code Example #1
    def forward(self, x, prev_state):
        """ Apply GRU step channelwise 
        Args: 
            x (Layer): Input (shape: (num_channels, *))
            prev_state (Layer): Sate from previous GRU step  (shape: (num_channels, size))
        Returns:
            (Layer, Layer): The output (out) and state (hn). The state can be passed directly into the next GRU step
        """

        self.step += 1

        name = f"{self.name}_step{self.step}"

        mat_size = self.num_channels * self.size

        prev_state = lbann.Reshape(prev_state,
                                   dims=str_list(
                                       [self.num_channels, self.size]),
                                   name=name + "_prev_state_reshape")

        fc1 = self.ih_fc(x)
        fc2 = self.hh_fc(prev_state)

        fc1_slice = lbann.Slice(
            fc1,
            axis=1,
            slice_points=str_list([0, self.size, 2 * self.size,
                                   3 * self.size]))

        Wir_x = lbann.Reshape(lbann.Identity(fc1_slice),
                              dims=str_list([self.num_channels, self.size]),
                              name=name + '_Wir_x')
        Wiz_z = lbann.Reshape(lbann.Identity(fc1_slice),
                              dims=str_list([self.num_channels, self.size]),
                              name=name + '_Wiz_z')
        Win_x = lbann.Reshape(lbann.Identity(fc1_slice),
                              dims=str_list([self.num_channels, self.size]),
                              name=name + '_Win_x')
        fc2_slice = lbann.Slice(
            fc2,
            axis=1,
            slice_points=str_list([0, self.size, 2 * self.size,
                                   3 * self.size]))

        Whr_x = lbann.Reshape(lbann.Identity(fc2_slice),
                              dims=str_list([self.num_channels, self.size]),
                              name=name + '_Whr_x')
        Whz_z = lbann.Reshape(lbann.Identity(fc2_slice),
                              dims=str_list([self.num_channels, self.size]),
                              name=name + '_Whz_z')
        Whn_x = lbann.Reshape(lbann.Identity(fc2_slice),
                              dims=str_list([self.num_channels, self.size]),
                              name=name + '_Whn_x')

        rt = \
            lbann.Sigmoid(
                lbann.Add(Wir_x, Whr_x, data_layout=self.data_layout),
                name=name + '_reset_gate',
                data_layout=self.data_layout
            )

        zt = \
            lbann.Sigmoid(
                lbann.Add(Wiz_z, Whz_z, data_layout=self.data_layout),
                name=name + '_update_gate',
                data_layout=self.data_layout,
            )

        nt = \
            lbann.Tanh(
                lbann.Add(
                    Win_x,
                    lbann.Multiply(rt, Whn_x, data_layout=self.data_layout),
                    data_layout=self.data_layout,
                ),
                name=name + '_new_gate', data_layout=self.data_layout,
            )

        ht = \
            lbann.Add(
                lbann.Multiply(
                    lbann.WeightedSum(
                        self.ones,
                        zt,
                        scaling_factors='1 -1', data_layout=self.data_layout
                    ),
                    nt,
                    data_layout=self.data_layout
                ),
                lbann.Multiply(zt, prev_state, data_layout=self.data_layout),
                name=name + '_output', data_layout=self.data_layout,
            )

        ht = lbann.Reshape(ht, dims=str_list([self.num_channels, self.size]))

        return ht, ht
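
A minimal usage sketch of this step (an assumption, not from the source: `ChannelwiseGRU` as the module owning this forward, plus made-up `num_channels`, `hidden_size`, and `sequence` names). The returned state feeds the next step directly:

gru = ChannelwiseGRU(num_channels=num_channels, size=hidden_size)
state = lbann.Constant(value=0.0,
                       num_neurons=str_list([num_channels, hidden_size]),
                       name='gru_init_state')  # zero initial state (assumed shape)
outputs = []
for x_t in sequence:  # each x_t has shape (num_channels, *)
    out, state = gru(x_t, state)  # state is passed straight into the next step
    outputs.append(out)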
Code Example #2
        _reader = reader.reader.add()
        _reader.name = 'synthetic'
        _reader.role = role
        _reader.num_samples = 1
        _reader.num_labels = 1
        _reader.synth_dimensions = '1'
        _reader.percent_of_data_to_use = 1.0

    add_data_reader('train')
    add_data_reader('test')
    input_ = lbann.Input()

    # Radial profile
    x = lbann.WeightsLayer(
        weights=lbann.Weights(
            lbann.ValueInitializer(values=str_list(image.flatten())), ),
        dims=str_list(image.shape),
    )
    max_r = image.shape[-1] // 2
    rprof = RadialProfile()(x, image.shape, max_r)
    rprof_slice = lbann.Slice(rprof, slice_points=str_list([0, 1, 2, 3]))
    red = lbann.Identity(rprof_slice, name='red')
    green = lbann.Identity(rprof_slice, name='green')
    blue = lbann.Identity(rprof_slice, name='blue')

    # Construct model
    callbacks = [
        lbann.CallbackDumpOutputs(layers=str_list(['red', 'green', 'blue'])),
    ]
    model = lbann.Model(
        epochs=0,
Code Example #3
File: NNConvModel.py Project: benson31/lbann
def graph_data_splitter(_input, NUM_NODES, NUM_EDGES, NUM_NODE_FEATURES,
                        NUM_EDGE_FEATURES, EMBEDDING_DIM, EDGE_EMBEDDING_DIM):
    """Helper function to split the input data into

			Args:
				NUM_NODES (int): The number of nodes in the largest graph in the dataset (51 for LSC-PPQM4M)
		      	NUM_EDGES (int): The number of edges in the largest graph in the dataset (118 for LSC-PPQM4M)
		      	NUM_NODE_FEATURES (int): The dimensionality of the input node features vector (9 for LSC-PPQM4M)
		      	NUM_EDGE_FEATURES (int): The dimensionality of the input edge feature vectors (3 for LSC-PPQM4M)
		      	EMBEDDING_DIM (int): The embedding dimensionality of the node feature vector

		      	EDGE_EMBEDDING_DIM (int): The embedding dimensionality of the edge feature vector
			Returns:
				(Layer, Layer, Layer, Layer, Layer): Returns 5 Layers. The embedded node feature matrix, the
													 neighbord nodes feature tensor, the embedded edge feature matrix,
													 the source node index vector, and the label
		"""
    split_indices = []

    start_index = 0
    split_indices.append(start_index)

    node_feature = [NUM_NODES for i in range(1, NUM_NODE_FEATURES + 1)]

    split_indices.extend(node_feature)

    edge_features = [NUM_EDGES for i in range(1, NUM_EDGE_FEATURES + 1)]

    split_indices.extend(edge_features)

    edge_indices_sources = NUM_EDGES
    split_indices.append(edge_indices_sources)

    edge_indices_targets = NUM_EDGES
    split_indices.append(edge_indices_targets)

    target = 1
    split_indices.append(target)

    for i in range(1, len(split_indices)):
        split_indices[i] = split_indices[i] + split_indices[i - 1]

    graph_input = lbann.Slice(_input,
                              axis=0,
                              slice_points=str_list(split_indices))

    neighbor_feature_dims = str_list([NUM_EDGES, 1, EMBEDDING_DIM])

    node_feature_columns = [
        lbann.Reshape(lbann.Identity(graph_input),
                      dims=str_list([NUM_NODES]),
                      name="node_ft_{}_col".format(x))
        for x in range(NUM_NODE_FEATURES)
    ]

    edge_feature_columns = [
        lbann.Reshape(lbann.Identity(graph_input),
                      dims=str_list([NUM_EDGES]),
                      name="edge_ft_{}_col".format(x))
        for x in range(NUM_EDGE_FEATURES)
    ]

    source_nodes = lbann.Reshape(lbann.Identity(graph_input),
                                 dims=str_list([NUM_EDGES]),
                                 name="source_nodes")
    target_nodes = lbann.Reshape(lbann.Identity(graph_input),
                                 dims=str_list([NUM_EDGES]),
                                 name="target_nodes")
    label = lbann.Reshape(lbann.Identity(graph_input),
                          dims=str_list([1]),
                          name="Graph_Label")

    embedded_node_features = AtomEncoder(node_feature_columns, EMBEDDING_DIM)

    embedded_edge_features = BondEncoder(edge_feature_columns,
                                         EDGE_EMBEDDING_DIM)

    neighbor_features = lbann.Gather(embedded_node_features,
                                     target_nodes,
                                     axis=0)
    neighbor_feature_mat = lbann.Reshape(neighbor_features,
                                         dims=neighbor_feature_dims)
    return (embedded_node_features, neighbor_feature_mat,
            embedded_edge_features, source_nodes, label)
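
The in-place loop over `split_indices` above is a prefix sum: it converts a leading 0 plus per-tensor sizes into the absolute boundaries `lbann.Slice` expects. A tiny self-contained illustration with toy sizes (not the LSC-PPQM4M values):

sizes = [0, 3, 3, 2, 1]          # start marker followed by section sizes
for i in range(1, len(sizes)):
    sizes[i] = sizes[i] + sizes[i - 1]
print(sizes)                     # [0, 3, 6, 8, 9] -> slice_points boundaries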
Code Example #4
        loss = lbann.Negative(loss, data_layout=self.data_layout)
        return loss


# ----------------------------------------------
# Build and Run Model
# ----------------------------------------------
with open("./config.json") as f:
    config = json.load(f, object_hook=lambda d: SimpleNamespace(**d))
config.input_shape = (16, 32)
config.load_weights = os.path.exists('./pretrained_weights')

# Construct the model
input_ = lbann.Slice(
    lbann.Input(data_field="samples"),
    slice_points=str_list([0, 1, 1 + np.prod(config.input_shape)]),
)
labels = lbann.Identity(input_)
sample = lbann.Reshape(input_, dims=str_list(config.input_shape))
roberta = RobertaModel(config, load_weights=config.load_weights)
out = roberta(sample)
out = lbann.ChannelwiseFullyConnected(out, output_channel_dims=[1000])
loss = CrossEntropyLoss(10, data_layout="model_parallel")
obj = loss(out, labels)
metrics = [lbann.Metric(obj, name="loss")]

model = lbann.Model(
    lbann_params.epochs,
    layers=lbann.traverse_layer_graph(input_),
    objective_function=obj,
    metrics=metrics,
Code Example #5
def construct_macc_surrogate_model(xdim, ydim, zdim, wae_mcf, surrogate_mcf,
                                   lambda_cyc, useCNN, dump_models,
                                   pretrained_dir, ltfb_batch_interval,
                                   num_epochs):
    """Construct MACC surrogate model.

    See https://arxiv.org/pdf/1912.08113.pdf model architecture and other details

    """
    # Layer graph
    input = lbann.Input(data_field='samples', name='inp_data')
    # data is 64*64*4 images + 15 scalar + 5 param
    inp_slice = lbann.Slice(input,
                            axis=0,
                            slice_points=str_list([0, ydim, ydim + xdim]),
                            name='inp_slice')
    gt_y = lbann.Identity(inp_slice, name='gt_y')
    gt_x = lbann.Identity(inp_slice, name='gt_x')  #param not used

    zero = lbann.Constant(value=0.0, num_neurons='1', name='zero')
    one = lbann.Constant(value=1.0, num_neurons='1', name='one')

    z = lbann.Gaussian(mean=0.0, stdev=1.0, neuron_dims="20")
    wae = macc_network_architectures.MACCWAE(
        zdim, ydim, cf=wae_mcf, use_CNN=useCNN)  #pretrained, freeze
    inv = macc_network_architectures.MACCInverse(xdim, cf=surrogate_mcf)
    fwd = macc_network_architectures.MACCForward(zdim, cf=surrogate_mcf)

    y_pred_fwd = wae.encoder(gt_y)

    param_pred_ = wae.encoder(gt_y)
    input_fake = inv(param_pred_)

    output_cyc = fwd(input_fake)
    y_image_re2 = wae.decoder(output_cyc)
    # **** Train cycleGAN input params <--> latent space of (images, scalars) ****
    output_fake = fwd(gt_x)
    y_image_re = wae.decoder(output_fake)

    param_pred2_ = wae.encoder(y_image_re)
    input_cyc = inv(param_pred2_)

    L_l2_x = lbann.MeanSquaredError(input_fake, gt_x)
    L_cyc_x = lbann.MeanSquaredError(input_cyc, gt_x)

    L_l2_y = lbann.MeanSquaredError(output_fake, y_pred_fwd)
    L_cyc_y = lbann.MeanSquaredError(output_cyc, y_pred_fwd)

    #@todo slice here to separate scalar from image
    img_sca_loss = lbann.MeanSquaredError(y_image_re, gt_y)
    #L_cyc = L_cyc_y + L_cyc_x
    L_cyc = lbann.Add(L_cyc_y, L_cyc_x)

    # loss_gen0 = L_l2_y + lambda_cyc*L_cyc
    loss_gen0 = lbann.WeightedSum([L_l2_y, L_cyc],
                                  scaling_factors=f'1 {lambda_cyc}')
    loss_gen1 = lbann.WeightedSum([L_l2_x, L_cyc_y],
                                  scaling_factors=f'1 {lambda_cyc}')
    # loss_gen1 = L_l2_x + lambda_cyc*L_cyc_y

    layers = list(lbann.traverse_layer_graph(input))
    weights = set()
    #Freeze appropriate (pretrained) weights
    pretrained_models = ["wae"]  #add macc?
    for l in layers:
        for idx in range(len(pretrained_models)):
            if (l.weights and pretrained_models[idx] in l.name):
                for w in range(len(l.weights)):
                    l.weights[w].optimizer = lbann.NoOptimizer()
        weights.update(l.weights)

    l2_reg = lbann.L2WeightRegularization(weights=weights, scale=1e-4)
    #d_adv_bce = lbann.LayerTerm(d_adv_bce,scale=0.01)
    # Setup objective function
    obj = lbann.ObjectiveFunction([loss_gen0, loss_gen1, l2_reg])
    # Initialize check metric callback
    metrics = [
        lbann.Metric(img_sca_loss, name='fw_loss'),
        lbann.Metric(L_l2_x, name='inverse loss'),
        lbann.Metric(L_cyc_y, name='output cycle loss'),
        lbann.Metric(L_cyc_x, name='param cycle loss')
    ]

    callbacks = [
        lbann.CallbackPrint(),
        lbann.CallbackSaveModel(dir=dump_models),
        lbann.CallbackLoadModel(dirs=str(pretrained_dir)),
        lbann.CallbackTimer()
    ]

    if (ltfb_batch_interval > 0):
        callbacks.append(
            lbann.CallbackLTFB(batch_interval=ltfb_batch_interval,
                               metric='fw_loss',
                               low_score_wins=True,
                               exchange_hyperparameters=True))
    # Construct model
    return lbann.Model(num_epochs,
                       weights=weights,
                       layers=layers,
                       metrics=metrics,
                       objective_function=obj,
                       callbacks=callbacks)
Code Example #6
File: fftshift.py Project: benson31/lbann
    input_ = lbann.Input()

    # NumPy implementation
    dims = [2, 3, 4, 7]
    np_x = np.random.uniform(size=dims).astype(np.float32)
    np_y = np.zeros_like(np_x)
    for i in range(dims[0]):
        np_y[i] = np.fft.fftshift(np_x[i])
    np_scales = np.random.uniform(size=np.prod(dims)).astype(np.float32)
    np_z = np.inner(np_y.flatten(), np_scales).item()
    tol = 8 * np_z * np.finfo(np.float32).eps

    # LBANN implementation
    lbann_x = lbann.WeightsLayer(
        weights=lbann.Weights(
            lbann.ValueInitializer(values=str_list(np_x.flatten())), ),
        dims=str_list(np_x.shape),
    )
    lbann_y = FFTShift()(lbann_x, dims)
    lbann_scales = lbann.WeightsLayer(
        weights=lbann.Weights(
            lbann.ValueInitializer(values=str_list(np_scales)),
            optimizer=lbann.NoOptimizer(),
        ),
        dims=str_list(np_scales.shape),
    )
    lbann_z = lbann.MatMul(lbann.Reshape(lbann_y, dims=str_list([1, -1])),
                           lbann.Reshape(lbann_scales, dims=str_list([-1, 1])))

    # Construct LBANN model with metric checking and gradient checking
    metric = lbann.Metric(lbann_z, name='metric')
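
For reference, `np.fft.fftshift` rolls each axis by half its length so the zero-frequency bin lands in the center, which is what the test validates channel by channel. A one-line check (illustrative only, not part of the test):

import numpy as np

print(np.fft.fftshift(np.arange(4)))  # [2 3 0 1]
print(np.fft.fftshift(np.arange(5)))  # [3 4 0 1 2]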
Code Example #7
    def forward(
        self,
        input_ids=None,
        token_type_ids=None,
        position_ids=None,
        inputs_embeds=None,
    ):

        if position_ids is None:
            if input_ids is not None:
                position_ids = create_position_ids_from_input_ids(
                    input_ids,
                    self.input_shape,
                    self.padding_idx,
                )
            else:
                position_ids = self.create_position_ids_from_inputs_embeds(
                    inputs_embeds)

        if token_type_ids is None:
            token_type_ids = lbann.Constant(value=0,
                                            num_neurons=str_list(
                                                self.input_shape))

        if inputs_embeds is None:
            inputs_embeds = lbann.Embedding(
                input_ids,
                num_embeddings=self.vocab_size,
                embedding_dim=self.hidden_size,
                padding_idx=self.pad_token_id,
                weights=_load_pretrained_weights(
                    ".".join((self.name, "word_embeddings.weight")),
                    load_weights=self.load_weights,
                ),
                name=".".join((self.name, "word_embeddings")),
            )
        token_type_embeddings = lbann.Embedding(
            token_type_ids,
            num_embeddings=self.type_vocab_size,
            embedding_dim=self.hidden_size,
            weights=_load_pretrained_weights(
                ".".join((self.name, "token_type_embeddings.weight")),
                load_weights=self.load_weights,
            ),
            name=".".join((self.name, "token_type_embeddings")),
        )

        embeddings = lbann.Add(inputs_embeds, token_type_embeddings)
        if self.position_embedding_type == "absolute":
            position_embeddings = lbann.Embedding(
                position_ids,
                num_embeddings=self.max_position_embeddings,
                embedding_dim=self.hidden_size,
                padding_idx=self.pad_token_id,
                weights=_load_pretrained_weights(
                    ".".join((self.name, "position_embeddings.weight")),
                    load_weights=self.load_weights,
                ),
                name=".".join((self.name, "position_embeddings")),
            )
            embeddings = lbann.Add(embeddings, position_embeddings)

        embeddings = lbann.modules.PytorchLayerNorm(
            embeddings,
            self.layer_norm_eps,
            self.input_shape + (self.hidden_size, ),
            weights=_load_pretrained_weights(
                ".".join((self.name, "layernorm.weightbias")),
                load_weights=self.load_weights,
            ),
            name=".".join((self.name, "LayerNorm")),
        )
        embeddings = lbann.Dropout(embeddings,
                                   keep_prob=self.hidden_dropout_prob)
        return embeddings
Code Example #8
def make_model(
    num_epochs,
    embed_dim,
    num_heads,
    label_smoothing,
):

    # Embedding weights
    var = 2 / (embed_dim + vocab_size)  # Glorot initialization
    embedding_weights = lbann.Weights(
        name='embeddings',
        initializer=lbann.NormalInitializer(standard_deviation=math.sqrt(var)),
    )

    # Input is two sequences of token IDs
    input_ = lbann.Input(data_field='samples')

    # Get sequences of embedding vectors
    # Note: Scale embeddings by sqrt(embed_dim).
    # Note: Decoder input is shifted right, so embedding for last
    # token isn't needed.
    embeddings_tokens = lbann.Identity(
        lbann.Slice(
            input_,
            axis=0,
            slice_points=str_list([0, 2 * sequence_length - 1]),
        ))
    embeddings = lbann.Embedding(
        embeddings_tokens,
        weights=embedding_weights,
        num_embeddings=vocab_size,
        embedding_dim=embed_dim,
        padding_idx=pad_index,
    )
    embeddings = lbann.WeightedSum(
        embeddings,
        scaling_factors=str(math.sqrt(embed_dim)),
    )
    embeddings_slice = lbann.Slice(
        embeddings,
        axis=0,
        slice_points=str_list([0, sequence_length, 2 * sequence_length - 1]),
    )
    encoder_input = lbann.Identity(embeddings_slice)
    decoder_input = lbann.Identity(embeddings_slice)

    # Apply transformer model
    transformer = lbann.models.Transformer(
        hidden_size=embed_dim,
        num_heads=num_heads,
        name='transformer',
    )
    result = transformer(
        encoder_input,
        sequence_length,
        decoder_input,
        sequence_length - 1,
    )

    # Reconstruct decoder input
    preds = lbann.ChannelwiseFullyConnected(
        result,
        weights=embedding_weights,
        output_channel_dims=[vocab_size],
        bias=False,
        transpose=True,
    )
    preds = lbann.ChannelwiseSoftmax(preds)
    preds = lbann.Slice(preds,
                        axis=0,
                        slice_points=str_list(range(sequence_length)))
    preds = [lbann.Identity(preds) for _ in range(sequence_length - 1)]

    # Count number of non-pad tokens
    label_tokens = lbann.Identity(
        lbann.Slice(
            input_,
            slice_points=str_list([sequence_length + 1, 2 * sequence_length]),
        ))
    pads = lbann.Constant(value=pad_index,
                          num_neurons=str(sequence_length - 1))
    is_not_pad = lbann.NotEqual(label_tokens, pads)
    num_not_pad = lbann.Reduction(is_not_pad, mode='sum')

    # Cross entropy loss with label smoothing
    label_tokens = lbann.Slice(
        label_tokens,
        slice_points=str_list(range(sequence_length)),
    )
    label_tokens = [
        lbann.Identity(label_tokens) for _ in range(sequence_length - 1)
    ]
    if label_smoothing > 0:
        uniform_label = lbann.Constant(value=1 / vocab_size,
                                       num_neurons=str_list([1, vocab_size]))
    loss = []
    for i in range(sequence_length - 1):
        label = lbann.OneHot(label_tokens[i], size=vocab_size)
        label = lbann.Reshape(label, dims=str_list([1, vocab_size]))
        if label_smoothing > 0:
            label = lbann.WeightedSum(
                label,
                uniform_label,
                scaling_factors=str_list(
                    [1 - label_smoothing, label_smoothing]),
            )
        loss.append(lbann.CrossEntropy(preds[i], label))
    loss = lbann.Concatenation(loss)

    # Average cross entropy over non-pad tokens
    loss_scales = lbann.Divide(
        is_not_pad,
        lbann.Tessellate(num_not_pad, hint_layer=is_not_pad),
    )
    loss = lbann.Multiply(loss, loss_scales)
    loss = lbann.Reduction(loss, mode='sum')

    # Construct model
    metrics = []
    callbacks = [lbann.CallbackPrint(), lbann.CallbackTimer()]
    return lbann.Model(
        num_epochs,
        layers=lbann.traverse_layer_graph(input_),
        objective_function=loss,
        metrics=metrics,
        callbacks=callbacks,
    )
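
The `WeightedSum` inside the loss loop implements standard label smoothing: the hard one-hot target is mixed with a uniform distribution, and the result remains a valid probability distribution. A NumPy sketch of the same arithmetic (illustrative, not part of the script):

import numpy as np

vocab_size, eps = 5, 0.1
onehot = np.eye(vocab_size)[2]                 # hard target for token id 2
uniform = np.full(vocab_size, 1 / vocab_size)  # the uniform_label above
smoothed = (1 - eps) * onehot + eps * uniform
print(smoothed)        # [0.02 0.02 0.92 0.02 0.02]
print(smoothed.sum())  # 1.0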
Code Example #9
File: transformer.py Project: benson31/lbann
    def forward(self, queries, keys, values, mask=None):
        """Apply multi-head attention.

        The input and output tensors are interpreted as sequences of
        vectors, where the first tensor dimension is the sequence
        dimension.

        Args:
            queries (lbann.Layer): Sequence of query vectors.
            keys (lbann.Layer): Sequence of key vectors.
            values (lbann.Layer): Sequence of value vectors.
            mask (lbann.Layer, optional): Additive attention mask. If
                the (i,j) entry is very negative (e.g. -1e9), then the
                ith query does not attend to the jth key/value pair.

        Returns:
            lbann.Layer: Sequence of output vectors. The sequence
                length is the same as `queries`.

        """
        ENABLE_SUBGRAPH = self.ENABLE_SUBGRAPH
        BRANCHES = self.BRANCHES
        if ENABLE_SUBGRAPH and self.num_heads % BRANCHES != 0:
            raise ValueError('Num heads should be divisible by BRANCHES')
        self.instance += 1
        name = f'{self.name}_instance{self.instance}'

        # Apply fully-connected layers to input sequences
        queries_fc = lbann.ChannelwiseFullyConnected(
            queries,
            weights=self.query_weights,
            output_channel_dims=[self.inner_dim],
            name=f'{name}_queries_fc',
        )
        keys_fc = lbann.ChannelwiseFullyConnected(
            keys,
            weights=self.key_weights,
            output_channel_dims=[self.inner_dim],
            name=f'{name}_keys_fc',
        )
        values_fc = lbann.ChannelwiseFullyConnected(
            values,
            weights=self.value_weights,
            output_channel_dims=[self.inner_dim],
            name=f'{name}_values_fc',
        )

        # Slice embedding vectors for each head
        slice_points = str_list(self.head_dim * i
                                for i in range(self.num_heads + 1))
        queries_slice = lbann.Slice(queries_fc,
                                    axis=1,
                                    slice_points=slice_points,
                                    name=f'{name}_queries_slice',
                                    parallel_strategy={
                                        'sub_branch_tag': 0,
                                        'enable_subgraph': ENABLE_SUBGRAPH
                                    })
        keys_slice = lbann.Slice(keys_fc,
                                 axis=1,
                                 slice_points=slice_points,
                                 name=f'{name}_keys_slice',
                                 parallel_strategy={
                                     'sub_branch_tag': 0,
                                     'enable_subgraph': ENABLE_SUBGRAPH
                                 })
        values_slice = lbann.Slice(values_fc,
                                   axis=1,
                                   slice_points=slice_points,
                                   name=f'{name}_values_slice',
                                   parallel_strategy={
                                       'sub_branch_tag': 0,
                                       'enable_subgraph': ENABLE_SUBGRAPH
                                   })

        # Compute scaled dot-product attention for each head
        attentions = []
        tag = 0
        for head in range(self.num_heads):
            head_name = f'{name}_myattention_head{head}'

            # Attention inputs

            if (ENABLE_SUBGRAPH):
                if (head % int(self.num_heads / BRANCHES) == 0):
                    tag += 1

                q = lbann.Identity(queries_slice,
                                   parallel_strategy={
                                       'sub_branch_tag': tag,
                                       'enable_subgraph': ENABLE_SUBGRAPH
                                   })
                k = lbann.Identity(keys_slice,
                                   parallel_strategy={
                                       'sub_branch_tag': tag,
                                       'enable_subgraph': ENABLE_SUBGRAPH
                                   })
                v = lbann.Identity(values_slice,
                                   parallel_strategy={
                                       'sub_branch_tag': tag,
                                       'enable_subgraph': ENABLE_SUBGRAPH
                                   })
            else:
                q = lbann.Identity(queries_slice)
                k = lbann.Identity(keys_slice)
                v = lbann.Identity(values_slice)

            # Multiply queries and keys
            # Note: num_queries x num_keys
            y = lbann.MatMul(
                q,
                k,
                transpose_b=True,
                name=f'{head_name}_matmul',
            )
            y = lbann.WeightedSum(
                y,
                scaling_factors=str(1 / math.sqrt(self.head_dim)),
                name=f'{head_name}_scale',
            )

            if ENABLE_SUBGRAPH:
                if mask is not None:
                    y = lbann.Sum([y, mask[tag]], name=f'{head_name}_mask')
            else:
                if mask is not None:
                    y = lbann.Sum([y, mask], name=f'{head_name}_mask')
            y = lbann.ChannelwiseSoftmax(y, name=f'{head_name}_softmax')

            # Attention output
            # Note: num_queries x head_dim

            attentions.append(lbann.MatMul(y, v, name=head_name))

            #Strong scaling

        # Concatenate heads and apply fully-connected layer
        if (ENABLE_SUBGRAPH):
            attentions = lbann.Concatenation(attentions,
                                             axis=1,
                                             name=f'{name}_heads_concat',
                                             parallel_strategy={
                                                 'sub_branch_tag': 0,
                                                 'enable_subgraph':
                                                 ENABLE_SUBGRAPH
                                             })
        else:
            attentions = lbann.Concatenation(
                attentions,
                axis=1,
                name=f'{name}_heads_concat',
            )

        outputs_fc = lbann.ChannelwiseFullyConnected(
            attentions,
            weights=self.output_weights,
            output_channel_dims=[self.embed_dim],
            name=f'{name}',
        )
        return outputs_fc
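
Each head above computes standard scaled dot-product attention. A NumPy reference of the per-head math (a sketch for clarity, not LBANN code):

import numpy as np

def attention_head(q, k, v, mask=None):
    # q: (num_queries, d); k, v: (num_keys, d); returns (num_queries, d)
    scores = q @ k.T / np.sqrt(q.shape[-1])               # matmul + scale, as above
    if mask is not None:
        scores = scores + mask                            # additive mask (e.g. -1e9 entries)
    scores = scores - scores.max(axis=-1, keepdims=True)  # numerical stability
    probs = np.exp(scores)
    probs = probs / probs.sum(axis=-1, keepdims=True)     # row-wise softmax
    return probs @ v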
Code Example #10
    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
    ):
        mixed_query_layer, query_shape = lbann.modules.PytorchLinear(
            hidden_states,
            self.input_shape,
            self.all_head_size,
            weights=_load_pretrained_weights(
                ".".join((self.name, "query.weight")),
                ".".join((self.name, "query.bias")),
                load_weights=self.load_weights,
            ),
            name=".".join((self.name, "query")),
            return_dims=True,
        )
        query_layer, query_shape = self.transpose_for_scores(
            mixed_query_layer, query_shape)

        key_layer, key_shape = lbann.modules.PytorchLinear(
            hidden_states,
            self.input_shape,
            self.all_head_size,
            weights=_load_pretrained_weights(
                ".".join((self.name, "key.weight")),
                ".".join((self.name, "key.bias")),
                load_weights=self.load_weights,
            ),
            name=".".join((self.name, "key")),
            return_dims=True,
        )
        key_layer, key_shape = self.transpose_for_scores(key_layer, key_shape)

        value_layer, value_shape = lbann.modules.PytorchLinear(
            hidden_states,
            self.input_shape,
            self.all_head_size,
            weights=_load_pretrained_weights(
                ".".join((self.name, "value.weight")),
                ".".join((self.name, "value.bias")),
                load_weights=self.load_weights,
            ),
            name=".".join((self.name, "value")),
            return_dims=True,
        )
        value_layer, value_shape = self.transpose_for_scores(
            value_layer, value_shape)

        # Take the dot product between "query" and "key" to get the raw attention scores.
        key_layer, key_shape = lbann.modules.Permute(key_layer,
                                                     key_shape,
                                                     axes=(0, 1, -1, -2),
                                                     return_dims=True)
        attention_scores, attention_shape = lbann.modules.PytorchMatmul(
            query_layer,
            query_shape,
            key_layer,
            key_shape,
            return_dims=True,
        )

        attention_scores = lbann.Scale(
            attention_scores,
            constant=1 / math.sqrt(self.attention_head_size))

        if attention_mask is not None:
            # Apply the attention mask (precomputed for all layers in RobertaModel forward())
            attention_scores = lbann.Add(attention_scores, attention_mask)

        # Normalize the attention scores to probabilities.
        attention_scores = lbann.Reshape(
            attention_scores,
            dims=str_list([np.prod(attention_shape[:-1]),
                           attention_shape[-1]]),
        )
        attention_probs = lbann.ChannelwiseSoftmax(attention_scores)
        attention_probs = lbann.Reshape(attention_probs,
                                        dims=str_list(attention_shape))

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = lbann.Dropout(
            attention_probs,
            keep_prob=self.attention_probs_dropout_prob,
        )

        # Mask heads if we want to
        if head_mask is not None:
            attention_probs = lbann.Multiply(attention_probs, head_mask)

        context_layer, context_shape = lbann.modules.PytorchMatmul(
            attention_probs,
            attention_shape,
            value_layer,
            value_shape,
            return_dims=True,
        )
        context_layer, context_shape = lbann.modules.Permute(
            context_layer,
            context_shape,
            axes=(0, 2, 1, 3),
            return_dims=True,
        )
        new_context_layer_shape = context_shape[:-2] + (self.all_head_size, )
        context_layer = lbann.Reshape(context_layer,
                                      dims=str_list(self.input_shape))

        return context_layer
Code Example #11
File: transformer.py Project: benson31/lbann
    def forward(self, queries, keys, values, mask=None):
        """Apply multi-head attention.

        The input and output tensors are interpreted as sequences of
        vectors, where the first tensor dimension is the sequence
        dimension.

        Args:
            queries (lbann.Layer): Sequence of query vectors.
            keys (lbann.Layer): Sequence of key vectors.
            values (lbann.Layer): Sequence of value vectors.
            mask (lbann.Layer, optional): Additive attention mask. If
                the (i,j) entry is very negative (e.g. -1e9), then the
                ith query does not attend to the jth key/value pair.

        Returns:
            lbann.Layer: Sequence of output vectors. The sequence
                length is the same as `queries`.

        """
        ENABLE_SUBGRAPH = self.ENABLE_SUBGRAPH
        BRANCHES = self.BRANCHES
        if ENABLE_SUBGRAPH and self.num_heads % BRANCHES != 0:
            raise ValueError('Num heads should be divisible by BRANCHES')
        self.instance += 1
        name = f'{self.name}_instance{self.instance}'

        # Apply fully-connected layers to input sequences
        queries_fc = []
        keys_fc = []
        values_fc = []

        # Slice embedding vectors for each head
        slice_points = str_list(
            self.head_dim * i
            for i in range(int(self.num_heads / self.BRANCHES) + 1))

        #Queries strong scaling in CFC
        attentions = []
        for count, query in enumerate(queries):
            temp = lbann.ChannelwiseFullyConnected(
                query,
                weights=self.query_weights[count],
                output_channel_dims=[self.inner_dim],
                name=f'{name}_subgrid{count}_queries_fc',
            )
            attentions.append(temp)

        grid_sum_slice = lbann.Cross_Grid_Sum_Slice(attentions)

        attentions = []

        for head in range(self.BRANCHES):
            attentions.append(lbann.Identity(grid_sum_slice))

        for head in range(self.BRANCHES):
            temp = lbann.Slice(
                attentions[head],
                axis=1,
                slice_points=slice_points,
                name=f'{name}_subgrid{head}_queries_slice',
            )

            queries_fc.append(temp)

        #keys strong scaling in CFC

        attentions = []
        for count, key in enumerate(keys):
            temp = lbann.ChannelwiseFullyConnected(
                key,
                weights=self.key_weights[count],
                output_channel_dims=[self.inner_dim],
                name=f'{name}_subgrid{count}_keys_fc',
            )

            attentions.append(temp)

        grid_sum_slice = lbann.Cross_Grid_Sum_Slice(attentions)

        attentions = []

        for head in range(self.BRANCHES):
            attentions.append(lbann.Identity(grid_sum_slice))

        for head in range(self.BRANCHES):

            temp = lbann.Slice(
                attentions[head],
                axis=1,
                slice_points=slice_points,
                name=f'{name}_subgrid{head}_keys_slice',
            )

            keys_fc.append(temp)

        #Values strong scaling in CFC
        attentions = []

        for count, value in enumerate(values):
            temp = lbann.ChannelwiseFullyConnected(
                value,
                weights=self.value_weights[count],
                output_channel_dims=[self.inner_dim],
                name=f'{name}_subgrid{count}_values_fc',
            )
            attentions.append(temp)

        grid_sum_slice = lbann.Cross_Grid_Sum_Slice(attentions)

        attentions = []

        for head in range(self.BRANCHES):
            attentions.append(lbann.Identity(grid_sum_slice))

        for head in range(self.BRANCHES):
            temp = lbann.Slice(
                attentions[head],
                axis=1,
                slice_points=slice_points,
                name=f'{name}_subgrid{head}_values_slice',
            )
            values_fc.append(temp)

        queries_slice = []
        keys_slice = []
        values_slice = []

        for branch in range(self.BRANCHES):
            query_slice = queries_fc[branch]
            key_slice = keys_fc[branch]
            value_slice = values_fc[branch]

            for head in range(int(self.num_heads / self.BRANCHES)):
                queries_slice.append(lbann.Identity(query_slice))
                keys_slice.append(lbann.Identity(key_slice))
                values_slice.append(lbann.Identity(value_slice))

        # Compute scaled dot-product attention for each head
        attentions = []

        #variable to combine heads locally in sub-grids
        temp_attentions = []
        tag = 0
        for head in range(self.num_heads):
            head_name = f'{name}_myattention_head{head}'

            # Attention inputs
            if (head % int(self.num_heads / BRANCHES) == 0):
                temp_attentions.append([])
                tag += 1

            q = lbann.Identity(queries_slice[head])
            k = lbann.Identity(keys_slice[head])
            v = lbann.Identity(values_slice[head])

            # Multiply queries and keys
            # Note: num_queries x num_keys
            y = lbann.MatMul(
                q,
                k,
                transpose_b=True,
                name=f'{head_name}_matmul',
            )
            y = lbann.WeightedSum(
                y,
                scaling_factors=str(1 / math.sqrt(self.head_dim)),
                name=f'{head_name}_scale',
            )

            if ENABLE_SUBGRAPH:
                if mask is not None:
                    y = lbann.Sum([y, mask[tag]], name=f'{head_name}_mask')
            else:
                if mask is not None:
                    y = lbann.Sum([y, mask], name=f'{head_name}_mask')
            y = lbann.ChannelwiseSoftmax(y, name=f'{head_name}_softmax')

            # Attention output
            # Note: num_queries x head_dim
            y = lbann.MatMul(y, v, name=head_name)
            # attentions.append(lbann.MatMul(y, v, name=head_name))

            temp_attentions[-1].append(y)

        for count, temp_attention in enumerate(temp_attentions):

            if (self.BRANCHES == self.num_heads):
                # No need to concat the heads at subgrid level
                # if number of subgrids is equal to number of heads
                attention_single_subgrid = temp_attentions[count][0]
            else:
                attention_single_subgrid = lbann.Concatenation(
                    temp_attention,
                    axis=1,
                    name=f'{name}_subgrid_heads_concat{count}',
                    parallel_strategy={
                        'sub_branch_tag': 0,
                        'enable_subgraph': False
                    })

            attention_single_subgrid = lbann.ChannelwiseFullyConnected(
                attention_single_subgrid,
                weights=self.output_weights[count],
                output_channel_dims=[self.embed_dim],
                name=f'{name}_cfc_{count}',
            )

            attentions.append(attention_single_subgrid)

        #Strong scaling

        grid_sum_slice = lbann.Cross_Grid_Sum_Slice(attentions)

        attentions = []

        for head in range(self.BRANCHES):
            attentions.append(lbann.Identity(grid_sum_slice))

        return attentions
Code Example #12
    def __init__(self,
                 input_channels,
                 output_channels,
                 bias=True,
                 activation=lbann.Relu,
                 name=None,
                 data_layout='data_parallel'):
        """Initialize Graph layer

        Args:
            input_channels (int): The size of the input node features
            output_channels (int): The output size of the node features
            bias (bool): Whether to apply biases after MatMul
            activation (type): Activation layer for the node features. If None,
                then no activation is applied. (default: lbann.Relu)
            name (str): Default name of the layer is GCN_{number}
            data_layout (str): Data layout
        """
        super().__init__()

        ## Add variables

        self.input_channels = input_channels
        self.output_channels = output_channels
        self.data_layout = data_layout

        ## Add Name for the components for the layer
        GraphConv.global_count += 1
        self.name = (name
                     if name else 'Graph_{}'.format(GraphConv.global_count))

        ## Initialize weights for the matrix
        value = math.sqrt(6 / (input_channels + output_channels))

        self.mat_weights = lbann.Weights(
            initializer=lbann.UniformInitializer(min=-value, max=value),
            name=self.name + '_Weights')

        self.weights1 = lbann.WeightsLayer(
            dims=str_list([input_channels, output_channels]),
            name=self.name + '_layer',
            weights=self.mat_weights)

        self.id_weights = lbann.Weights(
            initializer=lbann.UniformInitializer(min=-value, max=value),
            name=self.name + '_ID_Weights')

        self.weights2 = lbann.WeightsLayer(
            dims=str_list([input_channels, output_channels]),
            name=self.name + '_ID_layer',
            weights=self.id_weights)

        ## Initialize bias variables
        self.has_bias = bias
        self.bias_weights = None
        self.bias = None

        if (self.has_bias):
            self.bias_weights = lbann.Weights(
                initializer=lbann.ConstantInitializer(value=0.0),
                name=self.name + '_bias_weights')
            self.bias = lbann.WeightsLayer(dims=str_list([1, output_channels]),
                                           weights=self.bias_weights,
                                           name=self.name + '_bias_layer')

        self.activation = None

        if activation:
            if isinstance(activation, type):
                self.activation = activation
            else:
                self.activation = type(activation)
            if not issubclass(self.activation, lbann.Layer):
                raise ValueError('activation must be a layer')
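
The `value` used for both initializers above is the Xavier/Glorot uniform bound sqrt(6 / (fan_in + fan_out)), which keeps activation variance roughly stable across the layer. A quick check of the arithmetic (illustrative):

import math

fan_in, fan_out = 64, 32
limit = math.sqrt(6 / (fan_in + fan_out))
print(limit)  # 0.25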
Code Example #13
File: train_atom_char_rnn.py Project: benson31/lbann
def construct_model(run_args):
    """Construct LBANN model.

    Initial model for ATOM molecular SMILES generation
    Network architecture and training hyperparameters from
    https://github.com/samadejacobs/moses/tree/master/moses/char_rnn

    """

    pad_index = run_args.pad_index
    assert pad_index is not None

    sequence_length = run_args.sequence_length
    assert sequence_length is not None

    print("sequence length is {}".format(sequence_length))
    data_layout = "data_parallel"

    # Layer graph
    _input = lbann.Input(name="inp_tensor", data_field='samples')
    print(sequence_length)
    x_slice = lbann.Slice(
        _input,
        axis=0,
        slice_points=str_list(range(sequence_length + 1)),
        name="inp_slice",
    )

    # embedding layer
    emb = []
    embedding_dim = run_args.embedding_dim
    num_embeddings = run_args.num_embeddings
    assert embedding_dim is not None
    assert num_embeddings is not None

    emb_weights = lbann.Weights(
        initializer=lbann.NormalInitializer(mean=0, standard_deviation=1),
        name="emb_matrix",
    )

    lstm1 = lbann.modules.GRU(size=run_args.hidden, data_layout=data_layout)
    fc = lbann.modules.FullyConnectedModule(size=num_embeddings,
                                            data_layout=data_layout)

    last_output = lbann.Constant(
        value=0.0,
        num_neurons="{}".format(run_args.hidden),
        data_layout=data_layout,
        name="lstm_init_output",
    )

    lstm1_prev_state = [last_output]

    loss = []
    idl = []
    for i in range(sequence_length):
        idl.append(
            lbann.Identity(x_slice, name="slice_idl_" + str(i), device="CPU"))

    for i in range(sequence_length - 1):

        emb_l = lbann.Embedding(
            idl[i],
            name="emb_" + str(i),
            weights=emb_weights,
            embedding_dim=embedding_dim,
            num_embeddings=num_embeddings,
        )

        x, lstm1_prev_state = lstm1(emb_l, lstm1_prev_state)
        fc_l = fc(x)
        y_soft = lbann.Softmax(fc_l, name="soft_" + str(i))
        gt = lbann.OneHot(idl[i + 1], size=num_embeddings)
        ce = lbann.CrossEntropy([y_soft, gt], name="loss_" + str(i))
        # mask padding in input
        pad_mask = lbann.NotEqual(
            [idl[i], lbann.Constant(value=pad_index, num_neurons="1")], )
        ce_mask = lbann.Multiply([pad_mask, ce], name="loss_mask_" + str(i))
        loss.append(lbann.LayerTerm(ce_mask, scale=1 / (sequence_length - 1)))

    layers = list(lbann.traverse_layer_graph(_input))
    # Setup objective function
    weights = set()
    for l in layers:
        weights.update(l.weights)
    obj = lbann.ObjectiveFunction(loss)

    callbacks = [
        lbann.CallbackPrint(),
        lbann.CallbackTimer(),
        lbann.CallbackStepLearningRate(step=run_args.step_size,
                                       amt=run_args.gamma),
        lbann.CallbackDumpWeights(directory=run_args.dump_weights_dir,
                                  epoch_interval=1),
    ]

    # Construct model
    return lbann.Model(run_args.num_epochs,
                       layers=layers,
                       weights=weights,
                       objective_function=obj,
                       callbacks=callbacks)
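
The `NotEqual`/`Multiply` pair in the training loop zeroes the cross entropy wherever the input token is padding. A toy NumPy illustration of the same masking (made-up values):

import numpy as np

pad_index = 0
ids = np.array([5, 7, 0, 0])          # last two tokens are padding
ce = np.array([1.2, 0.8, 0.5, 0.6])   # per-position cross entropy
mask = (ids != pad_index).astype(ce.dtype)
print(ce * mask)                      # [1.2 0.8 0.  0. ]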
Code Example #14
    def forward(self, x, prev_state):
        """Apply GRU step.

        Args:
            x (Layer): Input.
            prev_state: State from previous GRU step.

        Returns:
            (Layer, Layer): The output (out) and state (hn). The state
                can be passed directly into the next GRU step.

        """
        self.step += 1
        name = '{0}_step{1}'.format(self.name, self.step)

        fc1 = self.ih_fc(x)  # input_fc
        fc2 = self.hh_fc(prev_state)  # hidden_fc

        # Get gates and cell update
        fc1_slice = lbann.Slice(fc1,
                                slice_points=str_list([0, self.size, 2*self.size, 3*self.size]),
                                name=name + '_fc1_slice',
                                data_layout=self.data_layout)
        Wir_x = lbann.Identity(fc1_slice, name=name + '_Wrx',
                               data_layout=self.data_layout)
        Wiz_x = lbann.Identity(fc1_slice, name=name + '_Wzx',
                               data_layout=self.data_layout)
        Win_x = lbann.Identity(fc1_slice, name=name + '_Wnx',
                               data_layout=self.data_layout)

        fc2_slice = lbann.Slice(fc2,
                                slice_points=str_list([0, self.size, 2*self.size, 3*self.size]),
                                name=name + '_fc2_slice',
                                data_layout=self.data_layout)
        Whr_prev = lbann.Identity(fc2_slice, name=name + '_Wrh',
                                  data_layout=self.data_layout)
        Whz_prev = lbann.Identity(fc2_slice, name=name + '_Wzh',
                                  data_layout=self.data_layout)
        Whn_prev = lbann.Identity(fc2_slice, name=name + '_Wnh',
                                  data_layout=self.data_layout)

        rt = \
            lbann.Sigmoid(
                lbann.Add(Wir_x, Whr_prev, data_layout=self.data_layout),
                name=name + '_reset_gate',
                data_layout=self.data_layout
            )

        zt = \
            lbann.Sigmoid(
                lbann.Add(Wiz_x, Whz_prev, data_layout=self.data_layout),
                name=name + '_update_gate',
                data_layout=self.data_layout,
            )

        nt = \
            lbann.Tanh(
                lbann.Add(
                    Win_x,
                    lbann.Multiply(rt, Whn_prev, data_layout=self.data_layout),
                    data_layout=self.data_layout,
                ),
                name=name + '_new_gate', data_layout=self.data_layout,
            )

        ht = \
            lbann.Add(
                lbann.Multiply(
                    lbann.WeightedSum(
                        self.ones,
                        zt,
                        scaling_factors='1 -1', data_layout=self.data_layout
                    ),
                    nt,
                    data_layout=self.data_layout
                ),
                lbann.Multiply(zt, prev_state, data_layout=self.data_layout),
                name=name + '_output', data_layout=self.data_layout,
            )

        # Return output
        return ht, ht
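
The gate algebra above matches the standard GRU equations; in particular, the `WeightedSum` of `self.ones` and `zt` with factors '1 -1' realizes the (1 - z_t) term. A NumPy sketch of one step (an assumed reference implementation, not LBANN code):

import numpy as np

def sigmoid(a):
    return 1 / (1 + np.exp(-a))

def gru_step(x, h_prev, W_ih, b_ih, W_hh, b_hh, size):
    gi = W_ih @ x + b_ih        # fused [r|z|n] input projections (ih_fc)
    gh = W_hh @ h_prev + b_hh   # fused [r|z|n] hidden projections (hh_fc)
    i_r, i_z, i_n = np.split(gi, [size, 2 * size])
    h_r, h_z, h_n = np.split(gh, [size, 2 * size])
    r = sigmoid(i_r + h_r)      # reset gate
    z = sigmoid(i_z + h_z)      # update gate
    n = np.tanh(i_n + r * h_n)  # new gate
    return (1 - z) * n + z * h_prev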
Code Example #15
def construct_model():
    """Construct MACC surrogate model.

    See https://arxiv.org/pdf/1912.08113.pdf model architecture and other details

    """
    import lbann

    # Layer graph
    input = lbann.Input(data_field='samples', name='inp_data')
    # data is 64*64*4 images + 15 scalar + 5 param
    inp_slice = lbann.Slice(input,
                            axis=0,
                            slice_points=str_list(
                                [0, args.ydim, args.ydim + args.xdim]),
                            name='inp_slice')
    gt_y = lbann.Identity(inp_slice, name='gt_y')
    gt_x = lbann.Identity(inp_slice, name='gt_x')  #param not used

    zero = lbann.Constant(value=0.0, num_neurons='1', name='zero')
    one = lbann.Constant(value=1.0, num_neurons='1', name='one')

    z = lbann.Gaussian(mean=0.0, stdev=1.0, neuron_dims="20")
    wae = macc_models.MACCWAE(args.zdim,
                              args.ydim,
                              cf=args.wae_mcf,
                              use_CNN=args.useCNN)  #pretrained, freeze
    inv = macc_models.MACCInverse(args.xdim, cf=args.surrogate_mcf)
    fwd = macc_models.MACCForward(args.zdim, cf=args.surrogate_mcf)

    y_pred_fwd = wae.encoder(gt_y)

    param_pred_ = wae.encoder(gt_y)
    input_fake = inv(param_pred_)

    output_cyc = fwd(input_fake)
    y_image_re2 = wae.decoder(output_cyc)
    # **** Train cycleGAN input params <--> latent space of (images, scalars) ****
    output_fake = fwd(gt_x)
    y_image_re = wae.decoder(output_fake)

    y_out = wae.decoder(y_pred_fwd)

    param_pred2_ = wae.encoder(y_image_re)
    input_cyc = inv(param_pred2_)

    L_l2_x = lbann.MeanSquaredError(
        input_fake, gt_x)  #(x,inv(enc(y)), (encoder+)inverse loss
    L_cyc_x = lbann.MeanSquaredError(
        input_cyc, gt_x)  #param, x cycle loss, from latent space

    L_l2_y = lbann.MeanSquaredError(
        output_fake, y_pred_fwd)  #pred error into latent space (enc(y),fw(x))
    L_cyc_y = lbann.MeanSquaredError(
        output_cyc,
        y_pred_fwd)  # pred error into latent space (enc(y), fw(inv(enc(y))))

    #@todo slice here to separate scalar from image
    img_sca_loss = lbann.MeanSquaredError(
        y_image_re,
        gt_y)  # (y,dec(fw(x))) #forward model to decoder, no latent space
    dec_fw_inv_enc_y = lbann.MeanSquaredError(
        y_image_re2, gt_y)  #(y, dec(fw(inv(enc(y))))) y->enc_z->x'->fw_z->y'
    wae_loss = lbann.MeanSquaredError(y_out, gt_y)  #(y, dec(enc(y)) '
    #L_cyc = L_cyc_y + L_cyc_x
    L_cyc = lbann.Add(L_cyc_y, L_cyc_x)

    #loss_gen0  = L_l2_y + lamda_cyc*L_cyc
    loss_gen0 = lbann.WeightedSum([L_l2_y, L_cyc],
                                  scaling_factors=f'1 {args.lamda_cyc}')
    loss_gen1 = lbann.WeightedSum([L_l2_x, L_cyc_y],
                                  scaling_factors=f'1 {args.lamda_cyc}')
    #loss_gen1  =  L_l2_x + lamda_cyc*L_cyc_y

    conc_out = lbann.Concatenation(
        [gt_x, wae_loss, img_sca_loss, dec_fw_inv_enc_y, L_l2_x],
        name='x_errors')
    layers = list(lbann.traverse_layer_graph(input))
    weights = set()
    for l in layers:
        weights.update(l.weights)

    # Setup objective function
    obj = lbann.ObjectiveFunction([loss_gen0, loss_gen1])
    # Initialize check metric callback
    metrics = [
        lbann.Metric(img_sca_loss, name='img_re1'),
        lbann.Metric(dec_fw_inv_enc_y, name='img_re2'),
        lbann.Metric(wae_loss, name='wae_loss'),
        lbann.Metric(L_l2_x, name='inverse loss'),
        lbann.Metric(L_cyc_y, name='output cycle loss'),
        lbann.Metric(L_cyc_x, name='param cycle loss')
    ]

    callbacks = [
        lbann.CallbackPrint(),
        lbann.CallbackDumpOutputs(layers=f'{conc_out.name}',
                                  execution_modes='test',
                                  directory=args.dump_outputs,
                                  batch_interval=1,
                                  format='npy'),
        lbann.CallbackTimer()
    ]

    # Construct model
    num_epochs = 1
    return lbann.Model(num_epochs,
                       weights=weights,
                       layers=layers,
                       serialize_io=True,
                       metrics=metrics,
                       objective_function=obj,
                       callbacks=callbacks)
Code Example #16
    def forward(self, x, prev_state):
        """Apply LSTM step.

        Args:
            x (Layer): Input.
            prev_state (tuple with two `Layer`s): State from previous
                LSTM step. Comprised of LSTM output and cell state.

        Returns:
            (Layer, (Layer, Layer)): The output and state (the output
                and cell state). The state can be passed directly into
                the next LSTM step.

        """
        self.step += 1
        name = '{0}_step{1}'.format(self.name, self.step)

        # Get output and cell state from previous step
        prev_output, prev_cell = prev_state

        # Apply linearity
        input_concat = lbann.Concatenation(x, prev_output,
                                           name=name + '_input',
                                           data_layout=self.data_layout)
        fc = self.fc(input_concat)

        # Get gates and cell update
        slice = lbann.Slice(fc,
                            slice_points=str_list([0, self.size, 4*self.size]),
                            name=name + '_fc_slice',
                            data_layout=self.data_layout)
        cell_update = lbann.Tanh(slice,
                                 name=name + '_cell_update',
                                 data_layout=self.data_layout)
        sigmoid = lbann.Sigmoid(slice,
                                name=name + '_sigmoid',
                                data_layout=self.data_layout)
        slice = lbann.Slice(sigmoid,
                            slice_points=str_list([0, self.size, 2*self.size, 3*self.size]),
                            name=name + '_sigmoid_slice',
                            data_layout=self.data_layout)
        f = lbann.Identity(slice, name=name + '_forget_gate',
                           data_layout=self.data_layout)
        i = lbann.Identity(slice, name=name + '_input_gate',
                           data_layout=self.data_layout)
        o = lbann.Identity(slice, name=name + '_output_gate',
                           data_layout=self.data_layout)

        # Cell state
        cell_forget = lbann.Multiply(f, prev_cell,
                                     name=name + '_cell_forget',
                                     data_layout=self.data_layout)
        cell_input = lbann.Multiply(i, cell_update,
                                    name=name + '_cell_input',
                                    data_layout=self.data_layout)
        cell = lbann.Add(cell_forget, cell_input, name=name + '_cell',
                         data_layout=self.data_layout)

        # Output
        cell_act = lbann.Tanh(cell, name=name + '_cell_activation',
                              data_layout=self.data_layout)
        output = lbann.Multiply(o, cell_act, name=name,
                                data_layout=self.data_layout)

        # Return output and state
        return output, (output, cell)
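
A minimal sketch of driving this step over a sequence, assuming the forward() above belongs to an lbann.modules.LSTMCell-style module (the module path, size, and the sequence_layers list are illustrative, not from the original source):

import lbann
from lbann.modules import LSTMCell  # assumed home of the forward() above

lstm = LSTMCell(size=256)
# Initial state: zero output and zero cell state
h0 = lbann.Constant(value=0.0, num_neurons='256')
c0 = lbann.Constant(value=0.0, num_neurons='256')
state = (h0, c0)
outputs = []
for x_t in sequence_layers:  # hypothetical list of per-step input Layers
    out, state = lstm(x_t, state)
    outputs.append(out)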
Code example #17
File: eval_atom_wae_dec.py  Project: oyamay/lbann
def construct_model(run_args):
    """Construct LBANN model.

    Initial model for ATOM molecular VAE

    """
    import lbann

    pad_index = run_args.pad_index
    assert pad_index is not None

    sequence_length = 102  # hard-coded; run_args.sequence_length is not used here

    print("sequence length is {}".format(sequence_length))
    data_layout = "data_parallel"
    # Layer graph
    input_ = lbann.Input(target_mode='N/A', name='inp_data')
    inp_slice = lbann.Slice(input_, axis=0, slice_points="0 102 230", name='inp_slice')
    inp_smile = lbann.Identity(inp_slice, name='inp_smile')
    z = lbann.Identity(inp_slice, name='z')  # param not used
    vae_loss = []
    input_feature_dims = sequence_length

    embedding_size = run_args.embedding_dim
    dictionary_size = run_args.num_embeddings
    assert embedding_size is not None
    assert dictionary_size is not None

    save_output = bool(run_args.dump_outputs_dir)

    #print("Inp smile len ", len(inp_smile), "z len ",  len(z))
    print("save output? ", save_output, "out dir ",  run_args.dump_outputs_dir)
    #z = lbann.Gaussian(mean=0.0,stdev=1.0, neuron_dims="128")
    x = lbann.Slice(inp_smile, slice_points=str_list([0, input_feature_dims]))
    x = lbann.Identity(x)
    waemodel = molwae.MolWAE(input_feature_dims,
                             dictionary_size,
                             embedding_size,
                             pad_index, save_output)
    x_emb = lbann.Embedding(
            x,
            num_embeddings=waemodel.dictionary_size,
            embedding_dim=waemodel.embedding_size,
            name='emb',
            weights=waemodel.emb_weights
    )

    
    pred, arg_max = waemodel.forward_decoder(x_emb, z)

    recon = waemodel.compute_loss(x, pred)

    vae_loss.append(recon)

    layers = list(lbann.traverse_layer_graph(input_))
    # Setup objective function
    weights = set()
    for l in layers:
        weights.update(l.weights)
    print("LEN vae loss ", len(vae_loss))

    obj = lbann.ObjectiveFunction(vae_loss)

    # Initialize check metric callback
    metrics = [lbann.Metric(recon, name='recon')]

    callbacks = [lbann.CallbackPrint(),
                 lbann.CallbackTimer()]


    # Dump output (activations) for post-processing
    pred_tensor = lbann.Concatenation(arg_max, name='pred_tensor')
    conc_out = lbann.Concatenation([input_, pred_tensor], name='conc_out')
    callbacks.append(lbann.CallbackDumpOutputs(batch_interval=run_args.dump_outputs_interval,
                                               execution_modes='test',
                                               directory=run_args.dump_outputs_dir,
                                               layers=f'{conc_out.name}'))
    # Construct model
    return lbann.Model(run_args.num_epochs,
                       weights=weights,
                       layers=layers,
                       objective_function=obj,
                       metrics=metrics,
                       callbacks=callbacks)
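
As a usage sketch: construct_model() only reads a handful of attributes from run_args, so a hypothetical invocation can build them by hand (all values illustrative):

from types import SimpleNamespace

# Hypothetical arguments mirroring the attributes construct_model() reads
run_args = SimpleNamespace(pad_index=0,
                           embedding_dim=30,
                           num_embeddings=40,
                           num_epochs=1,
                           dump_outputs_dir='dump_outs',
                           dump_outputs_interval=1)
model = construct_model(run_args)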
Code example #18
def construct_jag_wae_model(ydim, zdim, mcf, useCNN, dump_models,
                            ltfb_batch_interval, num_epochs):
    """Construct LBANN model.

    JAG Wasserstein autoencoder  model

    """

    # Layer graph
    input = lbann.Input(data_field='samples', name='inp_data')
    # data is 64*64*4 images + 15 scalar + 5 param
    #inp_slice = lbann.Slice(input, axis=0, slice_points="0 16399 16404",name='inp_slice')
    inp_slice = lbann.Slice(input,
                            axis=0,
                            slice_points=str_list([0, ydim, ydim + 5]),
                            name='inp_slice')
    gt_y = lbann.Identity(inp_slice, name='gt_y')
    gt_x = lbann.Identity(inp_slice, name='gt_x')  #param not used

    zero = lbann.Constant(value=0.0, num_neurons='1', name='zero')
    one = lbann.Constant(value=1.0, num_neurons='1', name='one')

    z_dim = 20  #Latent space dim

    z = lbann.Gaussian(mean=0.0, stdev=1.0, neuron_dims="20")
    model = macc_network_architectures.MACCWAE(zdim,
                                               ydim,
                                               cf=mcf,
                                               use_CNN=useCNN)
    d1_real, d1_fake, d_adv, pred_y = model(z, gt_y)

    d1_real_bce = lbann.SigmoidBinaryCrossEntropy([d1_real, one],
                                                  name='d1_real_bce')
    d1_fake_bce = lbann.SigmoidBinaryCrossEntropy([d1_fake, zero],
                                                  name='d1_fake_bce')
    d_adv_bce = lbann.SigmoidBinaryCrossEntropy([d_adv, one], name='d_adv_bce')
    img_loss = lbann.MeanSquaredError([pred_y, gt_y])
    rec_error = lbann.L2Norm2(
        lbann.WeightedSum([pred_y, gt_y], scaling_factors="1 -1"))

    layers = list(lbann.traverse_layer_graph(input))
    # Setup objective function
    weights = set()
    src_layers = []
    dst_layers = []
    for l in layers:
        if (l.weights and "disc0" in l.name and "instance1" in l.name):
            src_layers.append(l.name)
        # Freeze weights in disc1; CallbackReplaceWeights below refreshes them from disc0
        if (l.weights and "disc1" in l.name):
            dst_layers.append(l.name)
            for idx in range(len(l.weights)):
                l.weights[idx].optimizer = lbann.NoOptimizer()
        weights.update(l.weights)
    l2_reg = lbann.L2WeightRegularization(weights=weights, scale=1e-4)
    d_adv_bce = lbann.LayerTerm(d_adv_bce, scale=0.01)
    obj = lbann.ObjectiveFunction(
        [d1_real_bce, d1_fake_bce, d_adv_bce, img_loss, rec_error, l2_reg])
    # Initialize check metric callback
    metrics = [lbann.Metric(img_loss, name='recon_error')]
    #pred_y = macc_models.MACCWAE.pred_y_name
    callbacks = [
        lbann.CallbackPrint(),
        lbann.CallbackTimer(),
        lbann.CallbackPrintModelDescription(),
        lbann.CallbackSaveModel(dir=dump_models),
        lbann.CallbackReplaceWeights(source_layers=list2str(src_layers),
                                     destination_layers=list2str(dst_layers),
                                     batch_interval=2)
    ]

    if (ltfb_batch_interval > 0):
        callbacks.append(
            lbann.CallbackLTFB(batch_interval=ltfb_batch_interval,
                               metric='recon_error',
                               low_score_wins=True,
                               exchange_hyperparameters=True))

    # Construct model
    return lbann.Model(num_epochs,
                       weights=weights,
                       layers=layers,
                       metrics=metrics,
                       objective_function=obj,
                       callbacks=callbacks)
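
A minimal launch sketch under the usual LBANN Python-frontend pattern; the trainer, optimizer, and job name are illustrative, and the site-specific JAG data reader is elided:

import lbann
import lbann.contrib.launcher

# ydim = 64*64*4 image pixels + 15 scalars = 16399 (see the comment above)
model = construct_jag_wae_model(ydim=16399, zdim=20, mcf=1, useCNN=False,
                                dump_models='dump_models',
                                ltfb_batch_interval=0, num_epochs=4)
trainer = lbann.Trainer(mini_batch_size=128)
opt = lbann.Adam(learn_rate=1e-4, beta1=0.9, beta2=0.99, eps=1e-8)
# data_reader = ...  # site-specific JAG data reader protobuf, elided
# lbann.contrib.launcher.run(trainer, model, data_reader, opt, job_name='jag_wae')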
Code example #19
def make_model(num_vertices=None,
               node_features=None,
               num_classes=None,
               kernel_type='GCN',
               callbacks=None,
               num_epochs=1):
    '''Construct a model DAG using one of the Graph Kernels

    Args:
        num_vertices (int): Number of vertices of each graph (default: None)
        node_features (int): Number of features per node (default: None)
        num_classes (int): Number of classes as targets (default: None)

        kernel_type (str): Graph kernel to use in the model. Expected one of
                           GCN, GIN, Graph, or GatedGraph (default: GCN)
        callbacks (list): Callbacks for the model. If set to None, the model
                          description, GPU usage, training output, and timer
                          are reported. (default: None)
        num_epochs (int): Number of epochs to run (default: 1)
    Returns:
        (lbann.Model): A model object with the supplied callbacks, dataset
                       presets, and graph kernels.
    '''
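
    # NOTE: the arguments above are currently overridden by the hard-coded
    # PROTEINS-style dataset dimensions below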

    num_vertices = 100
    num_classes = 2
    node_feature_size = 3
    max_edges = 415

    #----------------------------------
    # Reshape and Slice Input Tensor
    #----------------------------------

    input_ = lbann.Input(data_field='samples')

    # Input dimensions should be (num_vertices * node_features + num_vertices^2 + num_classes )

    data = Graph_Data_Parser(input_, num_vertices, node_feature_size,
                             max_edges, num_classes)

    feature_matrix = data['node_features']
    source_indices = data['source_indices']
    target_indices = data['target_indices']
    target = data['target']

    #----------------------------------
    # Select Graph Convolution
    #----------------------------------

    output_channels = 16
    graph_kernel_op = None
    if kernel_type == 'GIN':
        graph_kernel_op = GINConvLayer
    elif kernel_type == 'GCN':
        graph_kernel_op = GCNConvLayer
    elif kernel_type == 'Graph':
        graph_kernel_op = GraphConvLayer
    elif kernel_type == 'GatedGraph':
        graph_kernel_op = GATConvLayer
    else:
        raise ValueError('Invalid Graph kernel specifier "{}" received. '
                         'Expected one of: GIN, GCN, Graph, or GatedGraph'.format(kernel_type))
    #----------------------------------
    # Perform Graph Convolution
    #----------------------------------

    x = graph_kernel_op(feature_matrix, source_indices, target_indices,
                        num_vertices, max_edges, node_feature_size,
                        output_channels)
    #----------------------------------
    # Apply Reduction on Node Features
    #----------------------------------

    average_vector = lbann.Constant(value=1 / num_vertices,
                                    num_neurons=str_list([1, num_vertices]),
                                    name="Average_Vector")

    x = lbann.MatMul(average_vector, x, name="Node_Feature_Reduction")

    # x is now a (1 x output_channels) matrix; squeezed to a vector below

    x = lbann.Reshape(x, dims=str_list([output_channels]), name="Squeeze")
    x = lbann.FullyConnected(x, num_neurons=64, name="hidden_layer_1")
    x = lbann.Relu(x, name="hidden_layer_1_activation")
    x = lbann.FullyConnected(x,
                             num_neurons=num_classes,
                             name="Output_Fully_Connected")

    #----------------------------------
    # Loss Function and Accuracy
    #----------------------------------

    probs = lbann.Softmax(x, name="Softmax")
    loss = lbann.CrossEntropy(probs, target, name="Cross_Entropy_Loss")
    accuracy = lbann.CategoricalAccuracy(probs, target, name="Accuracy")

    layers = lbann.traverse_layer_graph(input_)

    if callbacks is None:
        print_model = lbann.CallbackPrintModelDescription()  # Prints initial model after setup
        training_output = lbann.CallbackPrint(interval=1,
                                              print_global_stat_only=False)  # Prints training progress
        gpu_usage = lbann.CallbackGPUMemoryUsage()
        timer = lbann.CallbackTimer()
        callbacks = [print_model, training_output, gpu_usage, timer]

    metrics = [lbann.Metric(accuracy, name='accuracy', unit="%")]

    model = lbann.Model(num_epochs,
                        layers=layers,
                        objective_function=loss,
                        metrics=metrics,
                        callbacks=callbacks)
    return model
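
A hypothetical call: since the graph dimensions are currently hard-coded inside the function, only the kernel choice and epoch count take effect:

model = make_model(kernel_type='GIN', num_epochs=5)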
Code example #20
    def forward(self, queries, keys, values, mask=None):
        """Apply multi-head attention.

        The input and output tensors are interpreted as sequences of
        vectors, where the first tensor dimension is the sequence
        dimension.

        Args:
            queries (lbann.Layer): Sequence of query vectors.
            keys (lbann.Layer): Sequence of key vectors.
            values (lbann.Layer): Sequence of value vectors.
            mask (lbann.Layer, optional): Additive attention mask. If
                the (i,j) entry is very negative (e.g. -1e9), then the
                ith query does not attend to the jth key/value pair.

        Returns:
            lbann.Layer: Sequence of output vectors. The sequence
                length is the same as `queries`.

        """
        self.instance += 1
        name = f'{self.name}_instance{self.instance}'

        # Apply fully-connected layers to input sequences
        queries_fc = lbann.ChannelwiseFullyConnected(
            queries,
            weights=self.query_weights,
            output_channel_dims=[self.embed_dim],
            name=f'{name}_queries_fc',
        )
        keys_fc = lbann.ChannelwiseFullyConnected(
            keys,
            weights=self.key_weights,
            output_channel_dims=[self.embed_dim],
            name=f'{name}_keys_fc',
        )
        values_fc = lbann.ChannelwiseFullyConnected(
            values,
            weights=self.value_weights,
            output_channel_dims=[self.embed_dim],
            name=f'{name}_values_fc',
        )

        # Slice embedding vectors for each head
        slice_points = str_list(self.head_dim * i
                                for i in range(self.num_heads + 1))
        queries_slice = lbann.Slice(
            queries_fc,
            axis=1,
            slice_points=slice_points,
            name=f'{name}_queries_slice',
        )
        keys_slice = lbann.Slice(
            keys_fc,
            axis=1,
            slice_points=slice_points,
            name=f'{name}_keys_slice',
        )
        values_slice = lbann.Slice(
            values_fc,
            axis=1,
            slice_points=slice_points,
            name=f'{name}_values_slice',
        )

        # Compute scaled dot-product attention for each head
        attentions = []
        for head in range(self.num_heads):
            head_name = f'{name}_head{head}'

            # Attention inputs
            q = lbann.Identity(queries_slice)
            k = lbann.Identity(keys_slice)
            v = lbann.Identity(values_slice)

            # Multiply queries and keys
            # Note: num_queries x num_keys
            y = lbann.MatMul(
                q,
                k,
                transpose_b=True,
                name=f'{head_name}_matmul',
            )
            y = lbann.WeightedSum(
                y,
                scaling_factors=str(1 / math.sqrt(self.head_dim)),
                name=f'{head_name}_scale',
            )
            if mask:
                y = lbann.Add(y, mask, name=f'{head_name}_mask')
            y = lbann.ChannelwiseSoftmax(y, name=f'{head_name}_softmax')

            # Attention output
            # Note: num_queries x head_dim
            attentions.append(lbann.MatMul(y, v, name=head_name))

        # Concatenate heads and apply fully-connected layer
        attentions = lbann.Concatenation(attentions,
                                         axis=1,
                                         name=f'{name}_heads_concat')
        outputs_fc = lbann.ChannelwiseFullyConnected(
            attentions,
            weights=self.output_weights,
            output_channel_dims=[self.embed_dim],
            name=f'{name}',
        )
        return outputs_fc
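
A minimal self-attention sketch, assuming the forward() above belongs to an lbann.modules.transformer MultiheadAttention-style module (module path, sequence length, and embedding size are illustrative):

import lbann
from lbann.modules.transformer import MultiheadAttention  # assumed module path
from lbann.util import str_list

attn = MultiheadAttention(embed_dim=64, num_heads=8)
x = lbann.Input(data_field='samples')
x = lbann.Reshape(x, dims=str_list([16, 64]))  # (sequence_length, embed_dim)
y = attn(x, x, x)  # self-attention: queries, keys, and values share one tensor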
Code example #21
File: Dense_Graph_Trainer.py  Project: oyamay/lbann
def make_model(num_vertices=None,
               node_features=None,
               num_classes=None,
               dataset=None,
               kernel_type='GCN',
               callbacks=None,
               num_epochs=1):
    '''Construct a model DAG using one of the Graph Kernels

    Args:
        num_vertices (int): Number of vertices of each graph (default: None)
        node_features (int): Number of features per node (default: None)
        num_classes (int): Number of classes as targets (default: None)
        dataset (str): Preset data set to use. Either a dataset parameter has
                       to be supplied or all of num_vertices, node_features,
                       and num_classes have to be supplied. (default: None)
        kernel_type (str): Graph kernel to use in the model. Expected one of
                           GCN or Graph (default: GCN)
        callbacks (list): Callbacks for the model. If set to None, the model
                          description, GPU usage, training output, and timer
                          are reported. (default: None)
        num_epochs (int): Number of epochs to run (default: 1)
    Returns:
        (lbann.Model): A model object with the supplied callbacks, dataset
                       presets, and graph kernels.
    '''
    
    assert num_vertices != dataset  # Ensure at least one of the two is set

    if dataset is not None:
        assert num_vertices is None

        if dataset == 'MNIST':
            num_vertices = 75
            num_classes = 10
            node_features = 1

        elif dataset == 'PROTEINS':
            num_vertices = 100
            num_classes = 2
            node_features = 3
        else:
            raise Exception("Unkown Dataset")

    assert num_vertices is not None
    assert num_classes is not None 
    assert node_features is not None 
    

    #----------------------------------
    # Reshape and Slice Input Tensor 
    #----------------------------------

    input_ = lbann.Input(target_mode='classification')

    # Input dimensions should be (num_vertices * node_features + num_vertices^2 + num_classes)
    # input should have at least two children since the target is classification

    sample_dims = num_vertices * node_features + (num_vertices ** 2) + num_classes
    graph_dims = num_vertices * node_features + (num_vertices ** 2)
    feature_matrix_size = num_vertices * node_features

    graph_input = lbann.Slice(input_, axis=0,
                              slice_points=str_list([0, feature_matrix_size, graph_dims, sample_dims]),
                              name="Graph_Input")

    feature_matrix = lbann.Reshape(graph_input,
                                   dims=str_list([num_vertices, node_features]),
                                   name="Node_features")

    adj_matrix = lbann.Reshape(graph_input,
                               dims=str_list([num_vertices, num_vertices]),
                               name="Adj_Mat")

    target = lbann.Identity(graph_input, name="Target")
    target = lbann.Reshape(target, dims=str(num_classes))
   
    #----------------------------------
    # Perform Graph Convolution
    #----------------------------------
    
    if kernel_type == 'GCN':
        x = DGCN_layer(feature_matrix, adj_matrix, node_features)
    elif kernel_type == 'Graph':
        x = DGraph_Layer(feature_matrix, adj_matrix, node_features)
    else:
        raise ValueError('Invalid Graph kernel specifier "{}" received. '
                         'Expected one of: GCN or Graph'.format(kernel_type))
    out_channel = 256
    #----------------------------------
    # Apply Reduction on Node Features
    #----------------------------------

    average_vector = lbann.Constant(value=1/num_vertices,
                                    num_neurons=str_list([1, num_vertices]),
                                    name="Average_Vector")
    x = lbann.MatMul(average_vector, x, name="Node_Feature_Reduction")
    # x is now a vector with out_channel entries

    x = lbann.Reshape(x, dims=str_list([out_channel]), name="Squeeze")
    x = lbann.FullyConnected(x, num_neurons=256, name="hidden_layer_1")
    x = lbann.Relu(x, name="hidden_layer_1_activation")
    x = lbann.FullyConnected(x, num_neurons=num_classes, name="Output_Fully_Connected")
    
    #----------------------------------
    # Loss Function and Accuracy
    #----------------------------------
    
    
    probs = lbann.Softmax(x, name="Softmax")
    loss = lbann.CrossEntropy(probs, target, name="Cross_Entropy_Loss")
    accuracy = lbann.CategoricalAccuracy(probs, target, name="Accuracy")

    layers = lbann.traverse_layer_graph(input_)
    if callbacks is None:
        print_model = lbann.CallbackPrintModelDescription()  # Prints initial model after setup
        training_output = lbann.CallbackPrint(interval=1,
                                              print_global_stat_only=False)  # Prints training progress
        gpu_usage = lbann.CallbackGPUMemoryUsage()
        timer = lbann.CallbackTimer()
        callbacks = [print_model, training_output, gpu_usage, timer]
    metrics = [lbann.Metric(accuracy, name='accuracy', unit="%")]

    model = lbann.Model(num_epochs,
                        layers=layers,
                        objective_function=loss,
                        metrics=metrics,
                        callbacks=callbacks)
    return model
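
Since this variant accepts a dataset preset, a hypothetical call can rely on it instead of explicit dimensions:

model = make_model(dataset='PROTEINS', kernel_type='GCN', num_epochs=10)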
Code example #22
def construct_model(run_args):
    """Construct LBANN model.

    Initial model for ATOM molecular VAE

    """
    import lbann

    print("Dump model dir ", run_args.dump_model_dir)
    assert run_args.dump_model_dir, "evaluation script assumes a pretrained WAE model"
    pad_index = run_args.pad_index
    assert pad_index is not None

    sequence_length = run_args.sequence_length
    assert sequence_length is not None

    print("sequence length is {}".format(sequence_length))
    data_layout = "data_parallel"
    # Layer graph
    input_ = lbann.Identity(lbann.Input(name='inp', target_mode="N/A"),
                            name='inp1')
    wae_loss = []
    input_feature_dims = sequence_length

    embedding_size = run_args.embedding_dim
    dictionary_size = run_args.num_embeddings
    assert embedding_size is not None
    assert dictionary_size is not None

    save_output = False

    print("save output? ", save_output, "out dir ", run_args.dump_outputs_dir)
    z = lbann.Gaussian(mean=0.0, stdev=1.0, neuron_dims="128")

    x = lbann.Slice(input_, slice_points=str_list([0, input_feature_dims]))
    x = lbann.Identity(x)
    waemodel = molwae.MolWAE(input_feature_dims, dictionary_size,
                             embedding_size, pad_index, save_output)
    x_emb = lbann.Embedding(x,
                            num_embeddings=waemodel.dictionary_size,
                            embedding_dim=waemodel.embedding_size,
                            name='emb',
                            weights=waemodel.emb_weights)

    latentz = waemodel.forward_encoder(x_emb)

    fake_loss = lbann.MeanAbsoluteError(latentz, z)

    layers = list(lbann.traverse_layer_graph(input_))
    # Setup objective function
    weights = set()
    for l in layers:
        weights.update(l.weights)

    obj = lbann.ObjectiveFunction(fake_loss)

    callbacks = [lbann.CallbackPrint(), lbann.CallbackTimer()]

    # Dump output (activations) for post-processing
    conc_out = lbann.Concatenation([input_, latentz], name='conc_out')
    callbacks.append(
        lbann.CallbackDumpOutputs(
            batch_interval=run_args.dump_outputs_interval,
            execution_modes='test',
            directory=run_args.dump_outputs_dir,
            layers=f'{conc_out.name}'))
    # Construct model
    return lbann.Model(run_args.num_epochs,
                       weights=weights,
                       layers=layers,
                       objective_function=obj,
                       callbacks=callbacks)
Code example #23
def construct_model(lbann):
    """Construct LBANN model.

    Args:
        lbann (module): Module for LBANN Python frontend

    """

    # ------------------------------------------
    # NumPy implementation
    # ------------------------------------------

    # Compute Kronecker factors
    x = _samples[:, :_in_size]
    dy = _samples[:, -_out_size:] / _num_samples
    A = np.matmul(x.T, x)
    G = np.matmul(dy.T, dy)

    # Apply one K-FAC step
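    # Precondition the gradient with the inverse Kronecker factors,
    # dw <- G^{-1} (dy^T x) A^{-1}; the two solves below avoid explicit inverses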
    w = np.ones((_out_size, _in_size))
    dw = np.matmul(dy.T, x)
    dw = np.linalg.solve(A.T, dw.T).T
    dw = np.linalg.solve(G, dw)
    w_next = w - dw

    # ------------------------------------------
    # Basic model with conv layer
    # ------------------------------------------

    # Inputs and error signals are from data reader
    input_ = lbann.Input(data_field='samples')
    input_slice = lbann.Slice(input_,
                              slice_points=str_list(
                                  [0, _in_size, _in_size + _out_size]))
    x = lbann.Reshape(input_slice, dims=str_list([-1, 1, 1]))
    dy = lbann.Reshape(input_slice, dims=str_list([-1, 1, 1]))

    # Convolution layer
    w = lbann.Weights(
        initializer=lbann.ConstantInitializer(value=1),
        optimizer=lbann.SGD(learn_rate=1),
    )
    y = lbann.Convolution(
        x,
        weights=w,
        num_dims=2,
        num_output_channels=_out_size,
        num_groups=1,
        has_vectors=False,
        conv_dims_i=1,
        conv_pads_i=0,
        conv_strides_i=1,
        conv_dilations_i=1,
        has_bias=False,
    )
    obj = lbann.Reduction(lbann.Multiply(y, dy))

    # ------------------------------------------
    # Metric checking
    # ------------------------------------------

    # Extract weights entries
    w = lbann.WeightsLayer(
        weights=w,
        dims=str_list([_out_size, _in_size, 1, 1]),
    )
    w = lbann.Slice(
        lbann.Reshape(w, dims=str_list([-1])),
        slice_points=str_list(range(_out_size * _in_size + 1)),
    )
    w = [lbann.Identity(w) for _ in range(_out_size * _in_size)]

    # Metric checking with values from NumPy implementation
    metrics = []
    callbacks = []
    tol = 8 * np.finfo(np.float32).eps
    for i in range(_out_size * _in_size):
        val = w_next.flatten()[i]
        metrics.append(lbann.Metric(w[i], name=f'w[{i}]'))
        callbacks.append(
            lbann.CallbackCheckMetric(metric=metrics[-1].name,
                                      lower_bound=val - tol,
                                      upper_bound=val + tol,
                                      error_on_failure=True,
                                      execution_modes='test'))

    # ------------------------------------------
    # Construct model
    # ------------------------------------------

    num_epochs = 1
    return lbann.Model(
        num_epochs,
        layers=lbann.traverse_layer_graph([input_] + w),
        objective_function=obj,
        metrics=metrics,
        callbacks=callbacks,
    )