Exemple #1
0
    def forward(self, image, dims, max_r):
        """Build the layer graph for a per-channel radial profile.

        Spatial positions are grouped into radial bins by
        ``self._find_radial_bins``. The pixels of each channel are
        summed per bin with a scatter and every bin sum is multiplied
        by the reciprocal of the bin's pixel count.

        Args:
            image (lbann.Layer): Image
            dims (tuple of int): Image dimensions (dim 0 corresponds
                to channel)
            max_r (int): Maximum radial distance. Pixels whose bin is
                ``max_r`` or greater are ignored.

        Returns:
            Layer: num_channels x max_r radial profile

        """
        num_channels = dims[0]

        # Radial bin of each spatial position and population of each bin
        bins, bin_counts = self._find_radial_bins(dims[1:], max_r)

        # Averaging factor per bin
        # Note: An empty bin gets a factor of 0 rather than 1/0.
        bin_scales = [1 / count if count != 0 else 0 for count in bin_counts]

        # Scatter indices
        # Note: Channels are binned independently, so channel c writes
        # to entries [c*max_r, (c+1)*max_r) of the flattened output.
        # Positions with bin >= max_r receive index -1.
        bcast_shape = [num_channels] + [1] * bins.ndim
        scatter_inds = np.tile(bins, bcast_shape)
        scatter_inds += np.arange(
            0, num_channels * max_r, max_r).reshape(bcast_shape)
        scatter_inds[:, bins >= max_r] = -1
        scatter_inds = scatter_inds.flatten()

        # One copy of the per-bin factors for every channel
        scales_vals = bin_scales * num_channels

        # Layer graph: flatten -> scatter-add into bins -> rescale
        flat_image = lbann.Reshape(image, dims=str_list([np.prod(dims)]))
        inds_weights = lbann.Weights(
            lbann.ValueInitializer(values=str_list(scatter_inds)),
            optimizer=lbann.NoOptimizer(),
        )
        inds_layer = lbann.WeightsLayer(
            weights=inds_weights,
            dims=str_list([len(scatter_inds)]),
        )
        bin_sums = lbann.Scatter(
            flat_image,
            inds_layer,
            dims=str_list([num_channels * max_r]),
        )
        scales_weights = lbann.Weights(
            lbann.ValueInitializer(values=str_list(scales_vals)),
            optimizer=lbann.NoOptimizer(),
        )
        scales_layer = lbann.WeightsLayer(
            weights=scales_weights,
            dims=str_list([len(scales_vals)]),
        )
        bin_means = lbann.Multiply(scales_layer, bin_sums)
        return lbann.Reshape(bin_means, dims=str_list([num_channels, max_r]))
Exemple #2
0
    def forward(self, x, dims):
        """Build the layer graph that applies fftshift to each channel.

        The shift is a fixed permutation of the entries, so it is
        precomputed with NumPy on a tensor of flat indices and applied
        to the flattened input with a gather.

        Args:
            x (lbann.Layer): Input tensor
            dims (tuple of int): Dimensions of x (dim 0 corresponds to
                channel)

        Returns:
            Layer: Output tensor

        """
        num_channels = dims[0]
        spatial_dims = dims[1:]
        channel_size = np.prod(spatial_dims)

        # Flat source index of every output position within one channel
        shifted = np.fft.fftshift(
            np.arange(channel_size).reshape(spatial_dims))

        # Replicate for all channels
        # Note: Channel c reads from the flat range
        # [c*channel_size, (c+1)*channel_size).
        offsets = np.arange(0, num_channels * channel_size, channel_size)
        offsets = offsets.reshape([-1] + [1] * shifted.ndim)
        gather_inds = shifted[np.newaxis] + offsets

        # Layer graph: flatten -> gather -> restore shape
        total_size = np.prod(dims)
        flat_x = lbann.Reshape(x, dims=str_list([total_size]))
        inds_weights = lbann.Weights(
            lbann.ValueInitializer(values=str_list(gather_inds.flatten())),
            optimizer=lbann.NoOptimizer(),
        )
        inds_layer = lbann.WeightsLayer(
            weights=inds_weights,
            dims=str_list([total_size]),
        )
        shifted_x = lbann.Gather(flat_x, inds_layer)
        return lbann.Reshape(shifted_x, dims=str_list(dims))
Exemple #3
0
def Permute(x, dims, axes=None, name="", return_dims=False):
    """Transpose a tensor by gathering from its flattened form.

    The permutation is computed once with NumPy on a tensor of flat
    indices and stored in a non-trainable weights layer. That layer is
    cached in ``_permute_cache`` under ``(dims, axes)`` so repeated
    calls with the same arguments share it.

    Args:
        x (lbann.Layer): Input tensor.
        dims (sequence of int): Dimensions of ``x``.
        axes (sequence of int, optional): Axis order, as accepted by
            ``numpy.ndarray.transpose``. ``None`` reverses the axes.
        name (str): Name given to the output layer.
        return_dims (bool): Whether to also return the output
            dimensions.

    Returns:
        lbann.Layer: Permuted tensor, or a ``(layer, dims)`` pair when
        ``return_dims`` is true, where ``dims`` is a tuple.

    """
    global _permute_cache

    # Normalize to tuples so that lists and arrays are valid cache
    # keys and hit the same entry as the equivalent tuples.
    dims = tuple(dims)
    if axes is not None:
        axes = tuple(axes)
    key = (dims, axes)
    size = np.prod(dims)
    if key not in _permute_cache:
        # Construct gather indices
        inds = np.arange(size).reshape(dims, order="C").transpose(axes)
        inds = lbann.Weights(
            initializer=lbann.ValueInitializer(values=str_list(
                np.nditer(inds, order="C")), ),
            optimizer=lbann.NoOptimizer(),
        )
        inds = lbann.WeightsLayer(dims=str_list([size]), weights=inds)
        _permute_cache[key] = inds

    # Apply transpose with gather
    inds = _permute_cache[key]
    # Note: Identity test; ``axes == None`` would compare elementwise
    # for array-like axes.
    if axes is None:
        new_dims = dims[::-1]
    else:
        new_dims = tuple(dims[axis] for axis in axes)
    x = lbann.Reshape(x, dims=str_list([size]))
    y = lbann.Gather(x, inds)
    y = lbann.Reshape(y, dims=str_list(list(new_dims)), name=name)

    if return_dims:
        return y, new_dims
    return y
Exemple #4
0
def Cumsum(x, dims, axis=0):
    """Cumulative sum of a 2D tensor, expressed as a matrix product.

    Multiplying by a lower-triangular matrix of ones accumulates the
    entries along the chosen axis. The triangular matrix is stored in
    a non-trainable weights layer that is cached in ``_cumsum_cache``
    by shape.

    Args:
        x (lbann.Layer): Input tensor.
        dims (sequence of int): Dimensions of ``x``; must have length 2.
        axis (int): Axis to accumulate along (0 or 1).

    Returns:
        lbann.Layer: Tensor of cumulative sums.

    Raises:
        RuntimeError: If ``dims`` does not have two entries or
            ``axis`` is outside ``[0, 1]``.

    """
    global _cumsum_cache

    # Validate arguments before building any layers
    if len(dims) != 2:
        raise RuntimeError("dims > 2 not tested/supported for cumsum")
    if axis < 0 or axis > 1:
        raise RuntimeError("Unsupported cumsum axis: {}".format(axis))

    # Square lower-triangular mask, shared between calls
    extent = dims[axis]
    shape = (extent, extent)
    if shape not in _cumsum_cache:
        mask = np.tril(np.ones(shape, dtype=int))
        mask_weights = lbann.Weights(
            initializer=lbann.ValueInitializer(
                values=str_list(np.nditer(mask, order="C"))),
            optimizer=lbann.NoOptimizer(),
        )
        _cumsum_cache[shape] = lbann.WeightsLayer(
            dims=str_list(shape),
            weights=mask_weights,
        )
    tril_ones = _cumsum_cache[shape]

    # Left-multiply to accumulate down rows, right-multiply by the
    # transpose to accumulate across columns
    if axis == 0:
        return lbann.MatMul(tril_ones, x)
    if axis == 1:
        return lbann.MatMul(x, tril_ones, transpose_b=True)
Exemple #5
0
def make_batch_script(trainer_params, model_params, script_params):
    """Build the LBANN objects and launch a run with the inference executable.

    Args:
        trainer_params (dict): Must provide ``'mini_batch_size'``.
        model_params (dict): Keyword arguments forwarded to
            ``make_model``.
        script_params (dict): Must provide ``'nodes'`` and
            ``'procs_per_node'``.

    The value returned by the launcher is printed; nothing is
    returned.

    """

    # Use the 'lbann_inf' executable located next to the LBANN executable
    exe_dir = dirname(abspath(lbann.lbann_exe()))
    lbann_exe = join(exe_dir, 'lbann_inf')

    # Create LBANN objects
    trainer = lbann.Trainer(mini_batch_size=trainer_params['mini_batch_size'])
    model = make_model(**model_params)
    reader = make_data_reader()

    # No optimizer is attached
    # Note: The learning rate drop callbacks are still registered on
    # the model.
    opt = lbann.NoOptimizer()
    lr_drops = (
        ([1], 2),
        ([2, 4, 8, 12], 0.75),
    )
    for epochs, factor in lr_drops:
        model.callbacks.append(
            lbann.CallbackDropFixedLearningRate(
                drop_epoch=epochs,
                amt=factor,
            ))

    # Note: No checkpoint or weight-dump callbacks are registered.

    # Run immediately rather than submitting a batch job
    status = lbann.contrib.launcher.run(
        trainer,
        model,
        reader,
        opt,
        lbann_exe,
        nodes=script_params['nodes'],
        procs_per_node=script_params['procs_per_node'],
        time_limit=30,
        setup_only=False,
        batch_job=False,
    )

    print(status)
Exemple #6
0
def random_projection(indices, num_projections, projection_dim):
    """Generate pseudo-random projection vectors from integer indices.

    Every input index is expanded to ``projection_dim`` distinct hash
    inputs, which are hashed, passed through the inverse error
    function, instance-normalized and scaled by ``1/projection_dim``.

    Args:
        indices (lbann.Layer): Indices, one per projection.
        num_projections (int): Number of indices.
        projection_dim (int): Size of each generated vector.

    Returns:
        lbann.Layer: num_projections x projection_dim tensor.

    """
    grid_dims = [num_projections, projection_dim]

    # Give every vector entry its own hash input
    # Note: entry_id(index, i) = index*projection_dim + i
    scaled_indices = lbann.WeightedSum(
        indices,
        scaling_factors=utils.str_list(projection_dim),
    )
    iota_weights = lbann.Weights(
        initializer=lbann.ValueInitializer(
            values=utils.str_list(range(projection_dim))),
        optimizer=lbann.NoOptimizer(),
    )
    iota = lbann.WeightsLayer(
        dims=utils.str_list(projection_dim),
        weights=iota_weights,
    )
    index_grid = lbann.Tessellate(
        lbann.Reshape(scaled_indices,
                      dims=utils.str_list([num_projections, 1])),
        dims=utils.str_list(grid_dims),
    )
    iota_grid = lbann.Tessellate(
        lbann.Reshape(iota, dims=utils.str_list([1, projection_dim])),
        dims=utils.str_list(grid_dims),
    )
    entry_ids = lbann.Sum(index_grid, iota_grid)

    # Hash the entry ids, then map to a Gaussian-like distribution
    # Note: The hash output h is mapped to (1-eps)*(2*h - 1) before
    # the inverse error function is applied.
    hashed = lbann.UniformHash(entry_ids)
    ones = lbann.Constant(
        value=1,
        num_neurons=utils.str_list(grid_dims),
    )
    eps = 0.001
    centered = lbann.WeightedSum(
        hashed,
        ones,
        scaling_factors=utils.str_list([2 * (1 - eps), -(1 - eps)]),
    )
    proj = lbann.InstanceNorm(lbann.ErfInv(centered))
    return lbann.WeightedSum(
        proj,
        scaling_factors=utils.str_list(1 / projection_dim),
    )
Exemple #7
0
 def create_position_ids_from_inputs_embeds(self, input_embeds):
     """Build a layer of sequential position ids.

     The ids start at ``self.padding_idx + 1`` and there is one per
     sequence position (``self.input_shape[1]``). They are stored in a
     non-trainable weights layer and tessellated to
     ``self.input_shape[:-1]``. ``input_embeds`` is not read.

     Returns:
         lbann.Layer: Position ids.

     """
     seq_len = self.input_shape[1]
     first_id = self.padding_idx + 1
     id_values = range(first_id, seq_len + first_id)

     # Constant ids held in weights that are never optimized
     id_weights = lbann.Weights(
         initializer=lbann.ValueInitializer(values=str_list(id_values)),
         optimizer=lbann.NoOptimizer(),
     )
     ids = lbann.WeightsLayer(
         weights=id_weights,
         dims=str_list([seq_len]),
     )

     # Add a leading dim, then tile
     ids_row = lbann.Reshape(ids, dims=str_list([1, seq_len]))
     return lbann.Tessellate(ids_row,
                             dims=str_list(self.input_shape[:-1]))
Exemple #8
0
def mean_squared_error(
    data_dim,
    sequence_length,
    source_sequence,
    target_sequence,
    scale_decay=0.8,
):
    """Weighted mean squared error between two sequences of vectors.

    Args:
        data_dim (int): Size of each vector.
        sequence_length (int): Number of vectors in each sequence.
        source_sequence (lbann.Layer): Source vectors.
        target_sequence (lbann.Layer): Target vectors.
        scale_decay (float): Decay rate of the pair weights with
            distance in the sequence.

    Returns:
        lbann.Layer: Scalar loss.

    """

    # Inner products between every (source, target) pair of vectors
    pair_prods = lbann.MatMul(
        source_sequence,
        target_sequence,
        transpose_b=True,
    )

    # Weight of each pair
    # Note: Weights decay exponentially with the distance |j-i|
    # between the positions. A position is never paired with itself,
    # so the diagonal stays 0.
    scale_dims = (sequence_length, sequence_length)
    pair_scales = np.zeros(scale_dims)
    for i, j in np.ndindex(*scale_dims):
        if i == j:
            continue
        pair_scales[i, j] = ((1 - scale_decay) / (2 * scale_decay) *
                             scale_decay**np.abs(j - i))
    scales_weights = lbann.Weights(
        initializer=lbann.ValueInitializer(
            values=utils.str_list(np.nditer(pair_scales))),
        optimizer=lbann.NoOptimizer(),
    )
    scales_layer = lbann.WeightsLayer(dims=utils.str_list(scale_dims),
                                      weights=scales_weights)

    # Weighted sum of the inner products, as a dot product of the
    # flattened matrices
    weighted_prod = lbann.MatMul(
        lbann.Reshape(pair_prods, dims='1 -1'),
        lbann.Reshape(scales_layer, dims='1 -1'),
        transpose_b=True,
    )
    weighted_prod = lbann.Reshape(weighted_prod, dims='1')

    # MSE(x,y) = ( norm(x)^2 + norm(y)^2 - 2*prod(x,y) ) / dim(x)
    scale = 1 / (data_dim * sequence_length)
    source_norm2 = lbann.L2Norm2(source_sequence)
    target_norm2 = lbann.L2Norm2(target_sequence)
    return lbann.WeightedSum(
        source_norm2,
        target_norm2,
        weighted_prod,
        scaling_factors=utils.str_list([scale, scale, -2 * scale]),
    )
Exemple #9
0
def positive_samples_loss(
        sequence_length,
        encoder_embeddings,
        decoder_embeddings,
        scale_decay=0.8,
):
    """Loss term for positive (encoder, decoder) embedding pairs.

    Args:
        sequence_length (int): Number of embeddings in each sequence.
        encoder_embeddings (lbann.Layer): Encoder embedding vectors.
        decoder_embeddings (lbann.Layer): Decoder embedding vectors.
        scale_decay (float): Decay rate of the pair weights with
            distance in the sequence.

    Returns:
        lbann.Layer: Scalar loss.

    """

    # Log-sigmoid similarity score for every pair of embeddings
    pair_scores = lbann.LogSigmoid(
        lbann.MatMul(
            encoder_embeddings,
            decoder_embeddings,
            transpose_b=True,
        ))

    # Weight of each pair
    # Note: Weights are negative and decay exponentially as the
    # embeddings get further apart in the sequence. A position is
    # never paired with itself, so the diagonal stays 0.
    # Note: The sum of all the weights is approximately -1.
    scale_dims = (sequence_length, sequence_length)
    pair_scales = np.zeros(scale_dims)
    for i, j in np.ndindex(*scale_dims):
        if i == j:
            continue
        pair_scales[i, j] = (
            -(1-scale_decay)/(2*scale_decay*sequence_length)
            * scale_decay**np.abs(j-i)
        )
    scales_weights = lbann.Weights(
        initializer=lbann.ValueInitializer(
            values=utils.str_list(np.nditer(pair_scales))),
        optimizer=lbann.NoOptimizer(),
    )
    scales_layer = lbann.WeightsLayer(
        dims=utils.str_list(scale_dims),
        weights=scales_weights,
    )

    # Weighted sum of the scores, as a dot product of the flattened
    # matrices
    weighted = lbann.MatMul(
        lbann.Reshape(pair_scores, dims='1 -1'),
        lbann.Reshape(scales_layer, dims='1 -1'),
        transpose_b=True,
    )
    return lbann.Reshape(weighted, dims='1')
Exemple #10
0
        np_y[i] = np.fft.fftshift(np_x[i])
    np_scales = np.random.uniform(size=np.prod(dims)).astype(np.float32)
    np_z = np.inner(np_y.flatten(), np_scales).item()
    tol = 8 * np_z * np.finfo(np.float32).eps

    # LBANN implementation
    lbann_x = lbann.WeightsLayer(
        weights=lbann.Weights(
            lbann.ValueInitializer(values=str_list(np_x.flatten())), ),
        dims=str_list(np_x.shape),
    )
    lbann_y = FFTShift()(lbann_x, dims)
    lbann_scales = lbann.WeightsLayer(
        weights=lbann.Weights(
            lbann.ValueInitializer(values=str_list(np_scales)),
            optimizer=lbann.NoOptimizer(),
        ),
        dims=str_list(np_scales.shape),
    )
    lbann_z = lbann.MatMul(lbann.Reshape(lbann_y, dims=str_list([1, -1])),
                           lbann.Reshape(lbann_scales, dims=str_list([-1, 1])))

    # Construct LBANN model with metric checking and gradient checking
    metric = lbann.Metric(lbann_z, name='metric')
    callbacks = [
        lbann.CallbackCheckMetric(
            metric=metric.name,
            lower_bound=np_z - tol,
            upper_bound=np_z + tol,
            error_on_failure=True,
            execution_modes='test',
Exemple #11
0
def construct_macc_surrogate_model(xdim, ydim, zdim, wae_mcf, surrogate_mcf,
                                   lambda_cyc, useCNN, dump_models,
                                   pretrained_dir, ltfb_batch_interval,
                                   num_epochs):
    """Assemble the MACC surrogate model.

    See https://arxiv.org/pdf/1912.08113.pdf for the model
    architecture and other details.

    A forward network and an inverse network are trained against a
    WAE whose weights are frozen (layers with "wae" in their name get
    ``NoOptimizer``) and loaded from ``pretrained_dir``.

    Args:
        xdim (int): Number of input parameters per sample.
        ydim (int): Number of output entries per sample.
        zdim (int): Latent dimension passed to the WAE and the
            forward network.
        wae_mcf: ``cf`` argument of the WAE.
        surrogate_mcf: ``cf`` argument of the forward and inverse
            networks.
        lambda_cyc: Weight of the cycle terms in the objective.
        useCNN: ``use_CNN`` argument of the WAE.
        dump_models: Directory for the save-model callback.
        pretrained_dir: Directory for the load-model callback.
        ltfb_batch_interval (int): LTFB interval; LTFB is enabled
            only when positive.
        num_epochs (int): Number of epochs.

    Returns:
        lbann.Model: The assembled model.

    """
    # Input sample: ydim outputs followed by xdim parameters
    # Note (from the original author): data is 64*64*4 images +
    # 15 scalar + 5 param
    sample = lbann.Input(data_field='samples', name='inp_data')
    sample_slice = lbann.Slice(sample,
                               axis=0,
                               slice_points=str_list([0, ydim, ydim + xdim]),
                               name='inp_slice')
    gt_y = lbann.Identity(sample_slice, name='gt_y')
    gt_x = lbann.Identity(sample_slice, name='gt_x')

    # Note: These three layers are not consumed by anything below.
    zero = lbann.Constant(value=0.0, num_neurons='1', name='zero')
    one = lbann.Constant(value=1.0, num_neurons='1', name='one')
    z = lbann.Gaussian(mean=0.0, stdev=1.0, neuron_dims="20")

    # Networks
    wae = macc_network_architectures.MACCWAE(
        zdim, ydim, cf=wae_mcf, use_CNN=useCNN)  # pretrained, frozen
    inv = macc_network_architectures.MACCInverse(xdim, cf=surrogate_mcf)
    fwd = macc_network_architectures.MACCForward(zdim, cf=surrogate_mcf)

    # Latent encoding of the ground-truth outputs
    latent_gt = wae.encoder(gt_y)

    # Output cycle: outputs -> latent -> params -> latent -> outputs
    latent_for_inv = wae.encoder(gt_y)
    params_from_y = inv(latent_for_inv)
    latent_cycled = fwd(params_from_y)
    y_from_cycle = wae.decoder(latent_cycled)

    # Train cycleGAN input params <--> latent space of (images, scalars)
    # Param cycle: params -> latent -> outputs -> latent -> params
    latent_from_x = fwd(gt_x)
    y_from_x = wae.decoder(latent_from_x)
    latent_reencoded = wae.encoder(y_from_x)
    params_cycled = inv(latent_reencoded)

    # Loss terms
    L_l2_x = lbann.MeanSquaredError(params_from_y, gt_x)
    L_cyc_x = lbann.MeanSquaredError(params_cycled, gt_x)

    L_l2_y = lbann.MeanSquaredError(latent_from_x, latent_gt)
    L_cyc_y = lbann.MeanSquaredError(latent_cycled, latent_gt)

    # TODO: slice here to separate scalar from image
    img_sca_loss = lbann.MeanSquaredError(y_from_x, gt_y)
    L_cyc = lbann.Add(L_cyc_y, L_cyc_x)

    # loss_gen0 = L_l2_y + lambda_cyc*L_cyc
    # loss_gen1 = L_l2_x + lambda_cyc*L_cyc_y
    loss_gen0 = lbann.WeightedSum([L_l2_y, L_cyc],
                                  scaling_factors=f'1 {lambda_cyc}')
    loss_gen1 = lbann.WeightedSum([L_l2_x, L_cyc_y],
                                  scaling_factors=f'1 {lambda_cyc}')

    # Collect all weights, freezing those of the pretrained networks
    layers = list(lbann.traverse_layer_graph(sample))
    weights = set()
    pretrained_tags = ["wae"]
    for layer in layers:
        for tag in pretrained_tags:
            if layer.weights and tag in layer.name:
                for frozen in layer.weights:
                    frozen.optimizer = lbann.NoOptimizer()
        weights.update(layer.weights)

    # Objective function
    l2_reg = lbann.L2WeightRegularization(weights=weights, scale=1e-4)
    obj = lbann.ObjectiveFunction([loss_gen0, loss_gen1, l2_reg])

    # Metrics
    metrics = [
        lbann.Metric(img_sca_loss, name='fw_loss'),
        lbann.Metric(L_l2_x, name='inverse loss'),
        lbann.Metric(L_cyc_y, name='output cycle loss'),
        lbann.Metric(L_cyc_x, name='param cycle loss'),
    ]

    # Callbacks
    callbacks = [
        lbann.CallbackPrint(),
        lbann.CallbackSaveModel(dir=dump_models),
        lbann.CallbackLoadModel(dirs=str(pretrained_dir)),
        lbann.CallbackTimer(),
    ]
    if ltfb_batch_interval > 0:
        ltfb = lbann.CallbackLTFB(batch_interval=ltfb_batch_interval,
                                  metric='fw_loss',
                                  low_score_wins=True,
                                  exchange_hyperparameters=True)
        callbacks.append(ltfb)

    # Construct model
    return lbann.Model(num_epochs,
                       weights=weights,
                       layers=layers,
                       metrics=metrics,
                       objective_function=obj,
                       callbacks=callbacks)
Exemple #12
0
def construct_jag_wae_model(ydim, zdim, mcf, useCNN, dump_models,
                            ltfb_batch_interval, num_epochs):
    """Construct the JAG Wasserstein autoencoder (WAE) model.

    Args:
        ydim (int): Number of output entries per sample.
        zdim (int): Latent space dimension. Used for both the WAE and
            the Gaussian prior sample.
        mcf: ``cf`` argument of the WAE.
        useCNN: ``use_CNN`` argument of the WAE.
        dump_models: Directory for the save-model callback.
        ltfb_batch_interval (int): LTFB interval; LTFB is enabled
            only when positive.
        num_epochs (int): Number of epochs.

    Returns:
        lbann.Model: The assembled model.

    """

    # Layer graph
    input = lbann.Input(data_field='samples', name='inp_data')
    # Note (from the original author): data is 64*64*4 images +
    # 15 scalar + 5 param
    inp_slice = lbann.Slice(input,
                            axis=0,
                            slice_points=str_list([0, ydim, ydim + 5]),
                            name='inp_slice')
    gt_y = lbann.Identity(inp_slice, name='gt_y')
    # Params are not used, but the slice needs a second child
    gt_x = lbann.Identity(inp_slice, name='gt_x')

    zero = lbann.Constant(value=0.0, num_neurons='1', name='zero')
    one = lbann.Constant(value=1.0, num_neurons='1', name='one')

    # Prior sample
    # Note: Sized by zdim (previously hard-coded to 20) so that it
    # matches the latent dimension the WAE is built with.
    z = lbann.Gaussian(mean=0.0, stdev=1.0, neuron_dims=str(zdim))
    model = macc_network_architectures.MACCWAE(zdim,
                                               ydim,
                                               cf=mcf,
                                               use_CNN=useCNN)
    d1_real, d1_fake, d_adv, pred_y = model(z, gt_y)

    # Loss terms
    d1_real_bce = lbann.SigmoidBinaryCrossEntropy([d1_real, one],
                                                  name='d1_real_bce')
    d1_fake_bce = lbann.SigmoidBinaryCrossEntropy([d1_fake, zero],
                                                  name='d1_fake_bce')
    d_adv_bce = lbann.SigmoidBinaryCrossEntropy([d_adv, one], name='d_adv_bce')
    img_loss = lbann.MeanSquaredError([pred_y, gt_y])
    rec_error = lbann.L2Norm2(
        lbann.WeightedSum([pred_y, gt_y], scaling_factors="1 -1"))

    layers = list(lbann.traverse_layer_graph(input))

    # Collect weights
    # Note: Layers with "disc1" in their name are frozen and are the
    # destination of the replace-weights callback. Its sources are
    # the layers with both "disc0" and "instance1" in their name.
    weights = set()
    src_layers = []
    dst_layers = []
    for l in layers:
        if l.weights and "disc0" in l.name and "instance1" in l.name:
            src_layers.append(l.name)
        if l.weights and "disc1" in l.name:
            dst_layers.append(l.name)
            for w in l.weights:
                w.optimizer = lbann.NoOptimizer()
        weights.update(l.weights)

    # Setup objective function
    l2_reg = lbann.L2WeightRegularization(weights=weights, scale=1e-4)
    d_adv_bce = lbann.LayerTerm(d_adv_bce, scale=0.01)
    obj = lbann.ObjectiveFunction(
        [d1_real_bce, d1_fake_bce, d_adv_bce, img_loss, rec_error, l2_reg])

    metrics = [lbann.Metric(img_loss, name='recon_error')]
    callbacks = [
        lbann.CallbackPrint(),
        lbann.CallbackTimer(),
        lbann.CallbackPrintModelDescription(),
        lbann.CallbackSaveModel(dir=dump_models),
        lbann.CallbackReplaceWeights(source_layers=list2str(src_layers),
                                     destination_layers=list2str(dst_layers),
                                     batch_interval=2)
    ]

    if ltfb_batch_interval > 0:
        callbacks.append(
            lbann.CallbackLTFB(batch_interval=ltfb_batch_interval,
                               metric='recon_error',
                               low_score_wins=True,
                               exchange_hyperparameters=True))

    # Construct model
    return lbann.Model(num_epochs,
                       weights=weights,
                       layers=layers,
                       metrics=metrics,
                       objective_function=obj,
                       callbacks=callbacks)
Exemple #13
0
def make_model(num_epochs, embed_dim, num_heads, label_smoothing, branches,
               subgraph_topology, num_encoder_layers, num_decoder_layers,
               filter_size, d_kv, subgraph_num_common_resources,
               ENABLE_ALLSUBGRAPH, ENABLE_Concat):
    """Build a sub-graph parallel transformer model with frozen weights.

    The input sample holds two sequences of token ids. Both are
    embedded with a shared table, passed through a
    ``TransformerSubGraph`` and projected back to the vocabulary with
    the transposed embedding table. The objective is the cross entropy
    of the predictions, averaged over non-pad label tokens. Every
    weight in the graph is given ``NoOptimizer``.

    Reads ``vocab_size``, ``sequence_length`` and ``pad_index`` from
    the enclosing scope.

    Args:
        num_epochs (int): Number of epochs.
        embed_dim (int): Embedding / hidden size.
        num_heads (int): Number of attention heads.
        label_smoothing (float): Weight of the uniform distribution
            mixed into the one-hot labels; disabled unless positive.
        branches, num_encoder_layers, num_decoder_layers, filter_size,
        d_kv, ENABLE_ALLSUBGRAPH, ENABLE_Concat: Forwarded to
            ``TransformerSubGraph``.
        subgraph_topology, subgraph_num_common_resources: Forwarded
            to ``lbann.Model``.

    Returns:
        lbann.Model: The assembled model.

    """
    num_labels = sequence_length - 1

    # Shared embedding table, Glorot-initialized
    glorot_var = 2 / (embed_dim + vocab_size)
    embedding_weights = lbann.Weights(
        name='embeddings',
        initializer=lbann.NormalInitializer(
            standard_deviation=math.sqrt(glorot_var)),
    )

    # Input is two sequences of token IDs
    input_ = lbann.Input(data_field='samples')

    # Embed encoder and decoder tokens together
    # Note: Decoder input is shifted right, so the embedding for the
    # last token isn't needed.
    token_slice = lbann.Slice(
        input_,
        axis=0,
        slice_points=str_list([0, 2 * sequence_length - 1]),
    )
    tokens = lbann.Identity(token_slice)
    embeddings = lbann.Embedding(
        tokens,
        weights=embedding_weights,
        num_embeddings=vocab_size,
        embedding_dim=embed_dim,
        padding_idx=pad_index,
    )

    # Scale embeddings by sqrt(embed_dim)
    embeddings = lbann.WeightedSum(
        embeddings,
        scaling_factors=str(math.sqrt(embed_dim)),
    )

    # Split into encoder and decoder parts
    embeddings_slice = lbann.Slice(
        embeddings,
        axis=0,
        slice_points=str_list([0, sequence_length, 2 * sequence_length - 1]),
    )
    encoder_input = lbann.Identity(embeddings_slice)
    decoder_input = lbann.Identity(embeddings_slice)

    # Apply transformer model
    transformer = lbann.models.subgraph.TransformerSubGraph(
        branches=branches,
        hidden_size=embed_dim,
        num_heads=num_heads,
        num_encoder_layers=num_encoder_layers,
        num_decoder_layers=num_decoder_layers,
        filter_size=filter_size,
        d_kv=d_kv,
        name='transformer',
        ENABLE_ALLSUBGRAPH=ENABLE_ALLSUBGRAPH,
        ENABLE_Concat=ENABLE_Concat)
    result = transformer(
        encoder_input,
        sequence_length,
        decoder_input,
        sequence_length - 1,
    )

    # Project back to the vocabulary with the transposed embedding
    # table, then split into one prediction per position
    logits = lbann.ChannelwiseFullyConnected(
        result,
        weights=embedding_weights,
        output_channel_dims=[vocab_size],
        bias=False,
        transpose=True,
    )
    probs = lbann.ChannelwiseSoftmax(logits)
    probs_slice = lbann.Slice(probs,
                              axis=0,
                              slice_points=str_list(range(sequence_length)))
    preds = [lbann.Identity(probs_slice) for _ in range(num_labels)]

    # Count number of non-pad tokens
    label_slice = lbann.Slice(
        input_,
        slice_points=str_list([sequence_length + 1, 2 * sequence_length]),
    )
    label_tokens = lbann.Identity(label_slice)
    pads = lbann.Constant(value=pad_index,
                          num_neurons=str(sequence_length - 1))
    is_not_pad = lbann.NotEqual(label_tokens, pads)
    num_not_pad = lbann.Reduction(is_not_pad, mode='sum')

    # Cross entropy loss with label smoothing
    per_label_slice = lbann.Slice(
        label_tokens,
        slice_points=str_list(range(sequence_length)),
    )
    per_label = [lbann.Identity(per_label_slice) for _ in range(num_labels)]
    if label_smoothing > 0:
        uniform_label = lbann.Constant(value=1 / vocab_size,
                                       num_neurons=str_list([1, vocab_size]))
    token_losses = []
    for pred, token in zip(preds, per_label):
        label = lbann.OneHot(token, size=vocab_size)
        label = lbann.Reshape(label, dims=str_list([1, vocab_size]))
        if label_smoothing > 0:
            label = lbann.WeightedSum(
                label,
                uniform_label,
                scaling_factors=str_list(
                    [1 - label_smoothing, label_smoothing]),
            )
        token_losses.append(lbann.CrossEntropy(pred, label))
    loss = lbann.Concatenation(token_losses)

    # Average cross entropy over non-pad tokens
    loss_scales = lbann.Divide(
        is_not_pad,
        lbann.Tessellate(num_not_pad, hint_layer=is_not_pad),
    )
    loss = lbann.Multiply(loss, loss_scales)
    loss = lbann.Reduction(loss, mode='sum')

    # Metrics and callbacks
    metrics = []
    callbacks = [
        lbann.CallbackPrint(),
        lbann.CallbackTimer(),
        lbann.CallbackPrintModelDescription(),
    ]

    # Freeze every weight in the graph
    layers = list(lbann.traverse_layer_graph(input_))
    print("Subgrpah subgraph_topology", subgraph_topology)
    for layer in layers:
        for frozen in layer.weights:
            frozen.optimizer = lbann.NoOptimizer()

    # Construct model
    return lbann.Model(
        num_epochs,
        subgraph_communication=lbann.SubgraphCommunication.COLL_OPT,
        subgraph_topology=subgraph_topology,
        subgraph_num_common_resources=subgraph_num_common_resources,
        layers=lbann.traverse_layer_graph(input_),
        objective_function=loss,
        metrics=metrics,
        callbacks=callbacks,
    )
Exemple #14
0
def construct_model(run_args):
    """Construct the LBANN model for evaluating a pretrained molecular WAE.

    Args:
        run_args: Parsed arguments. Reads ``dump_model_dir``,
            ``pad_index``, ``sequence_length``, ``embedding_dim``,
            ``num_embeddings``, ``z_dim``, ``num_epochs``,
            ``dump_outputs_dir`` and ``dump_outputs_interval``.

    Returns:
        lbann.Model: The assembled model.

    """
    import lbann

    # Required arguments
    print("Dump model dir ", run_args.dump_model_dir)
    assert run_args.dump_model_dir, "evaluate script asssumes a pretrained WAE model"
    pad_index = run_args.pad_index
    assert pad_index is not None

    sequence_length = run_args.sequence_length
    assert sequence_length is not None

    print("sequence length is {}".format(sequence_length))
    data_layout = "data_parallel"

    # Layer graph
    raw_input = lbann.Input(name='inp', data_field='samples')
    input_ = lbann.Identity(raw_input, name='inp1')
    input_feature_dims = sequence_length

    embedding_size = run_args.embedding_dim
    dictionary_size = run_args.num_embeddings
    assert embedding_size is not None
    assert dictionary_size is not None

    # Outputs are saved only when a dump directory is given
    save_output = bool(run_args.dump_outputs_dir)

    print("save output? ", save_output, "out dir ", run_args.dump_outputs_dir)

    # Prior sample and WAE
    z = lbann.Gaussian(mean=0.0, stdev=1.0, neuron_dims=run_args.z_dim)
    waemodel = molwae.MolWAE(input_feature_dims, dictionary_size,
                             embedding_size, pad_index, run_args.z_dim,
                             save_output)
    recon, d1_real, d1_fake, d_adv, arg_max = waemodel(input_, z)

    # Discriminator and adversarial losses
    zero = lbann.Constant(value=0.0, num_neurons='1', name='zero')
    one = lbann.Constant(value=1.0, num_neurons='1', name='one')

    d1_real_bce = lbann.SigmoidBinaryCrossEntropy([d1_real, one],
                                                  name='d1_real_bce')
    d1_fake_bce = lbann.SigmoidBinaryCrossEntropy([d1_fake, zero],
                                                  name='d1_fake_bce')
    d_adv_bce = lbann.SigmoidBinaryCrossEntropy([d_adv, one], name='d_adv_bce')

    layers = list(lbann.traverse_layer_graph(input_))

    # Collect weights
    # Note: Layers with "disc1" in their name are frozen and are the
    # destination of the replace-weights callback. Its sources are
    # the layers with both "disc0" and "instance1" in their name.
    weights = set()
    src_layers = []
    dst_layers = []
    for layer in layers:
        has_weights = layer.weights
        if has_weights and "disc0" in layer.name and "instance1" in layer.name:
            src_layers.append(layer.name)
        if has_weights and "disc1" in layer.name:
            dst_layers.append(layer.name)
            for frozen in layer.weights:
                frozen.optimizer = lbann.NoOptimizer()
        weights.update(layer.weights)

    # L2 regularization only on weights that are still trainable
    l2_weights = [
        w for w in weights if not isinstance(w.optimizer, lbann.NoOptimizer)
    ]
    l2_reg = lbann.L2WeightRegularization(weights=l2_weights, scale=1e-4)

    # Objective function
    wae_loss = [recon, d1_real_bce, d_adv_bce, d1_fake_bce, l2_reg]
    print("LEN wae loss ", len(wae_loss))
    obj = lbann.ObjectiveFunction(wae_loss)

    # Metrics
    metrics = [
        lbann.Metric(d_adv_bce, name='adv_loss'),
        lbann.Metric(recon, name='recon'),
    ]

    # Callbacks
    callbacks = [
        lbann.CallbackPrint(),
        lbann.CallbackTimer(),
        lbann.CallbackReplaceWeights(source_layers=list2str(src_layers),
                                     destination_layers=list2str(dst_layers),
                                     batch_interval=2),
    ]

    # Dump outputs (activations) for post processing
    if run_args.dump_outputs_dir:
        pred_tensor = lbann.Concatenation(arg_max, name='pred_tensor')
        dump_outputs = lbann.CallbackDumpOutputs(
            batch_interval=run_args.dump_outputs_interval,
            execution_modes='test',
            directory=run_args.dump_outputs_dir,
            layers=f'inp pred_tensor {waemodel.q_mu.name}')
        callbacks.append(dump_outputs)

    # Construct model
    return lbann.Model(run_args.num_epochs,
                       weights=weights,
                       layers=layers,
                       objective_function=obj,
                       metrics=metrics,
                       callbacks=callbacks)
Exemple #15
0
    def compute_loss(self, x, y):
        """Cross entropy between predictions ``y`` and next-token labels ``x``.

        Row i of ``y`` (last row dropped) is scored against entry i+1
        of ``x``. Labels equal to ``self.label_to_ignore`` contribute
        nothing, and the sum is divided by the number of kept labels
        (at least 1).

        Reference computation:
            F.cross_entropy(
                y[:, :-1].contiguous().view(-1, y.size(-1)),
                x[:, 1:].contiguous().view(-1),
                ignore_index=self.pad
            )

        Args:
            x (lbann.Layer): Label tokens.
            y (lbann.Layer): Predicted scores.

        Returns:
            lbann.Layer: Scalar reconstruction loss.

        """
        num_labels = self.input_feature_dims-1

        # y[:, :-1]
        y_slice = lbann.Slice(
            y,
            axis=0,
            slice_points=str_list([0, self.input_feature_dims-1]),
        )
        y = lbann.Identity(y_slice)

        # x[:, 1:]
        x_slice = lbann.Slice(
            x,
            slice_points=str_list([1, self.input_feature_dims]),
        )
        x = lbann.Identity(x_slice)

        # Masks for ignored and kept labels, and number of kept labels
        ignored_label = self.constant(self.label_to_ignore, hint_layer=x)
        ignore_mask = lbann.Equal(x, ignored_label)
        keep_mask = lbann.LogicalNot(ignore_mask)
        length = lbann.Reduction(keep_mask, mode='sum')
        length = lbann.Max(length, self.constant(1, [1]))

        # Flat index into y of every label's score
        # Note: Label in row r maps to r*dictionary_size + label.
        # Ignored labels get an index of -1.
        row_starts = [
            r*self.dictionary_size for r in range(num_labels)
        ]
        row_starts_weights = lbann.Weights(
            initializer=lbann.ValueInitializer(values=str_list(row_starts)),
            optimizer=lbann.NoOptimizer(),
        )
        row_starts_layer = lbann.WeightsLayer(
            dims=str_list([self.input_feature_dims-1]),
            weights=row_starts_weights,
        )
        y_inds = lbann.Add(x, row_starts_layer)
        kept_inds = lbann.Multiply(keep_mask, y_inds)
        dropped_inds = lbann.Multiply(
            ignore_mask,
            self.constant(-1, hint_layer=y_inds),
        )
        y_inds = lbann.Add(kept_inds, dropped_inds)

        # Shift y for numerical stability
        # Note: We'd prefer to shift by y.max(-1). Instead every row
        # is shifted by the sum of its positive entries divided by
        # sqrt(dictionary_size).
        positive_y = lbann.Max(y, self.constant(0, hint_layer=y))
        shifts = lbann.MatMul(
            positive_y,
            self.constant(
                1 / math.sqrt(self.dictionary_size),
                [self.dictionary_size, self.dictionary_size],
            ),
        )
        y = lbann.Subtract(y, shifts)

        # Log of softmax denominator per row, summed over kept rows
        row_sums = lbann.MatMul(
            lbann.Exp(y),
            self.constant(1, [self.dictionary_size, 1]),
        )
        log_denoms = lbann.Log(row_sums)
        z = lbann.MatMul(
            lbann.Reshape(keep_mask, dims=str_list([1, -1])),
            log_denoms,
        )
        z = lbann.Reshape(z, dims=str_list([1]))

        # Cross entropy: (sum of log denominators - sum of label
        # scores) / number of kept labels
        label_scores = lbann.Gather(
            lbann.Reshape(y, dims=str_list([-1])),
            y_inds,
        )
        score_sum = lbann.Reduction(label_scores, mode='sum')
        recon_loss = lbann.Subtract(z, score_sum)
        return lbann.Divide(recon_loss, length)
Exemple #16
0
def construct_model(run_args):
    """Construct LBANN model.

    Initial model for ATOM molecular VAE: a MolWAE network whose
    reconstruction term is optimized together with three sigmoid
    binary cross-entropy terms and an L2 weight regularization term.

    Args:
        run_args: Run configuration object. The attributes read here
            are ``pad_index``, ``sequence_length``, ``embedding_dim``,
            ``num_embeddings``, ``dump_outputs_dir``,
            ``dump_outputs_interval``, ``dump_weights_interval``,
            ``dump_weights_dir``, ``ltfb``, ``weights_to_send``,
            ``ltfb_batch_interval``, ``dump_model_dir``, ``warmup``,
            ``lr``, ``batch_size`` and ``num_epochs``.

    Returns:
        lbann.Model: Model assembled from the layer graph, objective
        function, metrics and callbacks built in this function.

    """
    import lbann

    pad_index = run_args.pad_index
    assert pad_index is not None

    sequence_length = run_args.sequence_length
    assert sequence_length is not None

    print("sequence length is {}".format(sequence_length))

    # Layer graph
    raw_input = lbann.Input(name='inp', target_mode="N/A")
    input_ = lbann.Identity(raw_input, name='inp1')
    input_feature_dims = sequence_length

    embedding_size = run_args.embedding_dim
    dictionary_size = run_args.num_embeddings
    assert embedding_size is not None
    assert dictionary_size is not None

    save_output = bool(run_args.dump_outputs_dir)

    print("save output? ", save_output, "out dir ", run_args.dump_outputs_dir)
    z = lbann.Gaussian(mean=0.0, stdev=1.0, neuron_dims="128")
    wae = molwae.MolWAE(
        input_feature_dims,
        dictionary_size,
        embedding_size,
        pad_index,
        save_output,
    )
    recon, d1_real, d1_fake, d_adv, arg_max = wae(input_, z)

    # Constant targets for the binary cross-entropy terms
    zero = lbann.Constant(value=0.0, num_neurons='1', name='zero')
    one = lbann.Constant(value=1.0, num_neurons='1', name='one')

    d1_real_bce = lbann.SigmoidBinaryCrossEntropy(
        [d1_real, one], name='d1_real_bce')
    d1_fake_bce = lbann.SigmoidBinaryCrossEntropy(
        [d1_fake, zero], name='d1_fake_bce')
    d_adv_bce = lbann.SigmoidBinaryCrossEntropy(
        [d_adv, one], name='d_adv_bce')

    layers = list(lbann.traverse_layer_graph(input_))

    # Collect every weights object. Layers named "disc0...instance1" are
    # the copy sources; layers named "disc1" are the copy destinations
    # and have their optimizers disabled.
    weights = set()
    src_layers = []
    dst_layers = []
    for layer in layers:
        if not layer.weights:
            continue
        layer_name = layer.name
        if "disc0" in layer_name and "instance1" in layer_name:
            src_layers.append(layer_name)
        if "disc1" in layer_name:
            dst_layers.append(layer_name)
            for frozen in layer.weights:
                frozen.optimizer = lbann.NoOptimizer()
        weights.update(layer.weights)
    l2_reg = lbann.L2WeightRegularization(weights=weights, scale=1e-4)

    # Setup objective function
    vae_loss = [recon, d1_real_bce, d_adv_bce, d1_fake_bce, l2_reg]
    print("LEN vae loss ", len(vae_loss))

    obj = lbann.ObjectiveFunction(vae_loss)

    # Initialize check metric callback
    metrics = [
        lbann.Metric(d_adv_bce, name='adv_loss'),
        lbann.Metric(recon, name='recon'),
    ]

    # A step learning-rate schedule,
    # lbann.CallbackStepLearningRate(step=10, amt=0.5), is left disabled.
    callbacks = [lbann.CallbackPrint(), lbann.CallbackTimer()]

    if run_args.dump_weights_interval > 0:
        weight_dumper = lbann.CallbackDumpWeights(
            directory=run_args.dump_weights_dir,
            epoch_interval=run_args.dump_weights_interval)
        callbacks.append(weight_dumper)

    if run_args.ltfb:
        # 'All' is mapped to the empty string, which is a substring of
        # every weight name (hack for Merlin empty string).
        send_name = run_args.weights_to_send
        if send_name == 'All':
            send_name = ''
        weights_to_ex = [w.name for w in weights if send_name in w.name]
        print("LTFB Weights to exchange ", weights_to_ex)
        ltfb = lbann.CallbackLTFB(
            batch_interval=run_args.ltfb_batch_interval,
            metric='recon',
            weights=list2str(weights_to_ex),
            low_score_wins=True,
            exchange_hyperparameters=True)
        callbacks.append(ltfb)

    # Copy weights from the source layers into the destination layers
    weight_sync = lbann.CallbackReplaceWeights(
        source_layers=list2str(src_layers),
        destination_layers=list2str(dst_layers),
        batch_interval=2)
    callbacks.append(weight_sync)

    # Dump final weight for inference
    if run_args.dump_model_dir:
        callbacks.append(lbann.CallbackSaveModel(dir=run_args.dump_model_dir))

    # Dump output (activation) for post processing
    if run_args.dump_outputs_dir:
        # The layer is referred to by name in the callback below.
        pred_tensor = lbann.Concatenation(arg_max, name='pred_tensor')
        output_dumper = lbann.CallbackDumpOutputs(
            batch_interval=run_args.dump_outputs_interval,
            execution_modes='test',
            directory=run_args.dump_outputs_dir,
            layers='inp pred_tensor')
        callbacks.append(output_dumper)

    if run_args.warmup:
        warmup_target = run_args.lr / 512 * run_args.batch_size
        callbacks.append(
            lbann.CallbackLinearGrowthLearningRate(target=warmup_target,
                                                   num_epochs=5))

    # Construct model
    return lbann.Model(run_args.num_epochs,
                       weights=weights,
                       layers=layers,
                       objective_function=obj,
                       metrics=metrics,
                       callbacks=callbacks)
Exemple #17
0
def construct_model():
    """Construct LBANN model.

    ExaGAN model: a ``CosmoGAN`` network trained on three sigmoid
    binary cross-entropy terms (real, fake and adversarial).

    Returns:
        lbann.Model: Model configured for 20 epochs with the layer
        graph, objective function, metrics and callbacks built here.

    """
    import lbann

    # Layer graph
    input_ = lbann.Input(target_mode='N/A', name='inp_img')

    # Label flipping: 'is_real' is the result of comparing a uniform
    # sample on [0, 1] against 0.01, and 'is_fake' is its logical NOT.
    label_flip_rand = lbann.Uniform(min=0, max=1, neuron_dims='1')
    label_flip_prob = lbann.Constant(value=0.01, num_neurons='1')
    one = lbann.GreaterEqual(label_flip_rand, label_flip_prob, name='is_real')
    zero = lbann.LogicalNot(one, name='is_fake')

    # Noise vector fed to the generator
    noise = lbann.Gaussian(mean=0.0, stdev=1.0, neuron_dims="64",
                           name='noise_vec')
    z = lbann.Reshape(noise, dims='1 64')
    d1_real, d1_fake, d_adv, _gen_img = ExaGAN.CosmoGAN()(input_, z)

    d1_real_bce = lbann.SigmoidBinaryCrossEntropy(
        [d1_real, one], name='d1_real_bce')
    d1_fake_bce = lbann.SigmoidBinaryCrossEntropy(
        [d1_fake, zero], name='d1_fake_bce')
    d_adv_bce = lbann.SigmoidBinaryCrossEntropy(
        [d_adv, one], name='d_adv_bce')

    layers = list(lbann.traverse_layer_graph(input_))

    # Gather weights. "disc1...instance1" layers are the copy sources;
    # "disc2" layers are the destinations and are frozen, analogous to
    # discrim.trainable=False in Keras.
    weights = set()
    src_layers = []
    dst_layers = []
    for layer in layers:
        if not layer.weights:
            continue
        layer_name = layer.name
        if "disc1" in layer_name and "instance1" in layer_name:
            src_layers.append(layer_name)
        if "disc2" in layer_name:
            dst_layers.append(layer_name)
            for frozen in layer.weights:
                frozen.optimizer = lbann.NoOptimizer()
        weights.update(layer.weights)

    # Setup objective function
    # (L2 weight regularization with scale=1e-4 is left disabled.)
    obj = lbann.ObjectiveFunction([d1_real_bce, d1_fake_bce, d_adv_bce])

    # Initialize check metric callback
    metrics = [
        lbann.Metric(d1_real_bce, name='d_real'),
        lbann.Metric(d1_fake_bce, name='d_fake'),
        lbann.Metric(d_adv_bce, name='gen'),
    ]

    callbacks = [lbann.CallbackPrint(), lbann.CallbackTimer()]
    # Uncomment to dump output for plotting and further statistical analysis
    # callbacks.append(
    #     lbann.CallbackDumpOutputs(
    #         layers='inp_img gen_img_instance1_activation',
    #         execution_modes='train validation',
    #         directory='dump_outs',
    #         batch_interval=100,
    #         format='npy'))
    callbacks.append(
        lbann.CallbackReplaceWeights(
            source_layers=list2str(src_layers),
            destination_layers=list2str(dst_layers),
            batch_interval=2))

    # Construct model
    num_epochs = 20
    return lbann.Model(num_epochs,
                       weights=weights,
                       layers=layers,
                       metrics=metrics,
                       objective_function=obj,
                       callbacks=callbacks)
Exemple #18
0
def construct_model(args):
    """Construct LBANN for CosmoGAN 3D model.

    Args:
        args: Parsed run arguments. The attributes read here are
            ``input_width``, ``input_channel``, ``use_distconv``,
            ``mini_batch_size``, ``depth_groups``, ``height_groups``,
            ``compute_mse``, ``use_bn`` and ``num_epochs``.

    Returns:
        lbann.Model: Model whose objective is the sum of the three
        sigmoid binary cross-entropy terms, plus a mean squared error
        term when ``args.compute_mse`` is set.

    """
    obj = []
    metrics = []
    callbacks = []

    # Sample dimensions: channel count followed by three equal widths
    _sample_dims = [args.input_channel] + [args.input_width] * 3

    # Parallel strategy is only built when distconv is requested
    ps = None
    if args.use_distconv:
        ps = get_parallel_strategy_args(
            sample_groups=args.mini_batch_size,
            depth_groups=args.depth_groups,
            height_groups=args.height_groups,
        )

    g_device = 'GPU'
    input_ = lbann.Input(name='input', data_field='samples')
    # Note: no trailing comma here. A trailing comma would turn `input_`
    # into a 1-tuple instead of a layer.
    input_ = lbann.Reshape(input_,
                           dims=list2str(_sample_dims),
                           name='in_reshape',
                           device=g_device)
    x1 = lbann.Identity(input_, parallel_strategy=None, name='x1')
    # Second view of the input, only needed as the MSE target
    x2 = lbann.Identity(input_, name='x2') if args.compute_mse else None

    zero = lbann.Constant(value=0.0, num_neurons='1', name='zero',
                          device=g_device)
    one = lbann.Constant(value=1.0, num_neurons='1', name='one',
                         device=g_device)

    z = lbann.Reshape(
        lbann.Gaussian(mean=0.0, stdev=1.0, neuron_dims="64",
                       name='noise_vec', device=g_device),
        dims='1 64', name='noise_vec_reshape', device=g_device)
    print("RUN ARGS ", args)

    d1_real, d1_fake, d_adv, gen_img = model.Exa3DGAN(
        args.input_width, args.input_channel,
        g_device, ps, use_bn=args.use_bn)(x1, z)

    layers = list(lbann.traverse_layer_graph([d1_real, d1_fake]))

    # Gather weights. "disc1...instance1" layers are the copy sources;
    # "disc2" layers are the destinations and are frozen, analogous to
    # discrim.trainable=False in Keras.
    weights = set()
    src_layers = []
    dst_layers = []
    for layer in layers:
        if layer.weights and "disc1" in layer.name and "instance1" in layer.name:
            src_layers.append(layer.name)
        if layer.weights and "disc2" in layer.name:
            dst_layers.append(layer.name)
            for frozen in layer.weights:
                frozen.optimizer = lbann.NoOptimizer()
        weights.update(layer.weights)

    # Setup objective function
    d1_real_bce = lbann.SigmoidBinaryCrossEntropy([d1_real, one],
                                                  name='d1_real_bce')
    d1_fake_bce = lbann.SigmoidBinaryCrossEntropy([d1_fake, zero],
                                                  name='d1_fake_bce')
    d_adv_bce = lbann.SigmoidBinaryCrossEntropy([d_adv, one],
                                                name='d_adv_bce')
    mse = None
    if args.compute_mse:
        mse = lbann.MeanSquaredError([gen_img, x2], name='MSE')

    obj.append(d1_real_bce)
    obj.append(d1_fake_bce)
    obj.append(d_adv_bce)

    metrics.append(lbann.Metric(d_adv_bce, name='d_adv_bce'))
    metrics.append(lbann.Metric(d1_real_bce, name='d1_real_bce'))
    metrics.append(lbann.Metric(d1_fake_bce, name='d1_fake_bce'))
    if mse is not None:
        obj.append(mse)
        metrics.append(lbann.Metric(mse, name='MSE'))

    callbacks.append(lbann.CallbackPrint())
    callbacks.append(lbann.CallbackTimer())
    callbacks.append(lbann.CallbackGPUMemoryUsage())

    # ------------------------------------------
    # Construct model
    # ------------------------------------------

    return lbann.Model(args.num_epochs,
                       weights=weights,
                       layers=layers,
                       objective_function=obj,
                       metrics=metrics,
                       callbacks=callbacks)
Exemple #19
0
def construct_model(num_epochs,mcr,spectral_loss,save_batch_interval):
    """Construct LBANN model.

    Builds a ``CosmoGAN`` network trained on three sigmoid binary
    cross-entropy terms, optionally with an additional spectral loss.

    Args:
        num_epochs: Number of epochs passed to ``lbann.Model``.
        mcr: Passed through unchanged to both the ``ExaGAN.CosmoGAN``
            constructor and its call.
        spectral_loss: When truthy, the log of the mean squared error
            between the ``DFTAbs`` of the generated and the input image
            is added to the objective (scale 8.0) and to the metrics.
        save_batch_interval: Batch interval for the output-dumping
            callback.

    Returns:
        lbann.Model: Model assembled from the layer graph, objective
        function, metrics and callbacks built here.

    """
    import lbann

    # Layer graph
    input_ = lbann.Input(target_mode='N/A', name='inp_img')

    ### Create expected labels for real and fake data (with label flipping = 0.01)
    prob_flip = 0.01
    label_flip_rand = lbann.Uniform(min=0, max=1, neuron_dims='1')
    label_flip_prob = lbann.Constant(value=prob_flip, num_neurons='1')
    ones = lbann.GreaterEqual(label_flip_rand, label_flip_prob, name='is_real')
    zeros = lbann.LogicalNot(ones, name='is_fake')
    # All ones: no flip. Input for training Generator.
    gen_ones = lbann.Constant(value=1.0, num_neurons='1')

    #==============================================
    ### Implement GAN
    ## Create the noise vector
    z = lbann.Reshape(
        lbann.Gaussian(mean=0.0, stdev=1.0, neuron_dims="64",
                       name='noise_vec'),
        dims='1 64')
    ## Creating the GAN object and implementing forward pass for both networks
    d1_real, d1_fake, d_adv, gen_img, img = ExaGAN.CosmoGAN(mcr)(input_, z, mcr)

    #==============================================
    ### Compute quantities for adding to Loss and Metrics
    d1_real_bce = lbann.SigmoidBinaryCrossEntropy([d1_real, ones],
                                                  name='d1_real_bce')
    d1_fake_bce = lbann.SigmoidBinaryCrossEntropy([d1_fake, zeros],
                                                  name='d1_fake_bce')
    d_adv_bce = lbann.SigmoidBinaryCrossEntropy([d_adv, gen_ones],
                                                name='d_adv_bce')

    #==============================================
    ### Set up source and destination layers
    layers = list(lbann.traverse_layer_graph(input_))
    weights = set()
    src_layers, dst_layers = [], []
    for layer in layers:
        if layer.weights and "disc1" in layer.name and "instance1" in layer.name:
            src_layers.append(layer.name)
        # freeze weights in disc2, analogous to discrim.trainable=False in Keras
        if layer.weights and "disc2" in layer.name:
            dst_layers.append(layer.name)
            for frozen in layer.weights:
                frozen.optimizer = lbann.NoOptimizer()
        weights.update(layer.weights)

    #==============================================
    ### Define Loss and Metrics
    # Define loss (Objective function): usual GAN loss function
    loss_list = [d1_real_bce, d1_fake_bce, d_adv_bce]

    if spectral_loss:
        dft_gen_img = lbann.DFTAbs(gen_img)
        # The input image's spectrum is wrapped in StopGradient
        dft_img = lbann.StopGradient(lbann.DFTAbs(img))
        # Both inputs are passed as one parent list, like the BCE layers
        # above.
        # NOTE(review): a second positional argument is understood to
        # bind to the layer's `children` parameter rather than to its
        # parents - confirm against the lbann.Layer constructor.
        spec_loss = lbann.Log(
            lbann.MeanSquaredError([dft_gen_img, dft_img]))

        loss_list.append(lbann.LayerTerm(spec_loss, scale=8.0))

    loss = lbann.ObjectiveFunction(loss_list)

    # Define metrics
    metrics = [
        lbann.Metric(d1_real_bce, name='d_real'),
        lbann.Metric(d1_fake_bce, name='d_fake'),
        lbann.Metric(d_adv_bce, name='gen_adv'),
    ]
    if spectral_loss:
        metrics.append(lbann.Metric(spec_loss, name='spec_loss'))

    #==============================================
    ### Define callbacks list
    callbacks_list = []
    dump_outputs = True
    save_model = False
    print_model = False

    callbacks_list.append(lbann.CallbackPrint())
    callbacks_list.append(lbann.CallbackTimer())
    callbacks_list.append(
        lbann.CallbackReplaceWeights(
            source_layers=list2str(src_layers),
            destination_layers=list2str(dst_layers),
            batch_interval=1))
    if dump_outputs:
        # Only the generated image is dumped; prepend 'inp_img ' to the
        # layer list to dump the input image as well.
        callbacks_list.append(
            lbann.CallbackDumpOutputs(
                layers='gen_img_instance1_activation',
                execution_modes='train validation',
                directory='dump_outs',
                batch_interval=save_batch_interval,
                format='npy'))

    if save_model:
        callbacks_list.append(lbann.CallbackSaveModel(dir='models'))
    if print_model:
        callbacks_list.append(lbann.CallbackPrintModelDescription())

    ### Construct model
    return lbann.Model(num_epochs,
                       weights=weights,
                       layers=layers,
                       metrics=metrics,
                       objective_function=loss,
                       callbacks=callbacks_list)
Exemple #20
0
def construct_model():
    """Construct LBANN model.

    JAG Wasserstein autoencoder model: a ``jag_models.WAE`` network
    trained on three sigmoid binary cross-entropy terms, two
    reconstruction terms and L2 weight regularization.

    Returns:
        lbann.Model: Model configured for 100 epochs with the layer
        graph, objective function, metrics and callbacks built here.

    """
    import lbann

    # Layer graph
    input_ = lbann.Input(target_mode='N/A', name='inp_data')
    # data is 64*64*4 images + 15 scalar + 5 param
    inp_slice = lbann.Slice(input_, axis=0, slice_points="0 16399 16404",
                            name='inp_slice')
    # Creation order is kept: gt_y is attached to the slice before gt_x.
    gt_y = lbann.Identity(inp_slice, name='gt_y')
    gt_x = lbann.Identity(inp_slice, name='gt_x')  # param not used

    zero = lbann.Constant(value=0.0, num_neurons='1', name='zero')
    one = lbann.Constant(value=1.0, num_neurons='1', name='one')

    y_dim = 16399  # image+scalar shape
    z_dim = 20  # Latent space dim

    z = lbann.Gaussian(mean=0.0, stdev=1.0, neuron_dims="20")
    wae = jag_models.WAE(z_dim, y_dim)
    d1_real, d1_fake, d_adv, pred_y = wae(z, gt_y)

    d1_real_bce = lbann.SigmoidBinaryCrossEntropy(
        [d1_real, one], name='d1_real_bce')
    d1_fake_bce = lbann.SigmoidBinaryCrossEntropy(
        [d1_fake, zero], name='d1_fake_bce')
    d_adv_bce = lbann.SigmoidBinaryCrossEntropy(
        [d_adv, one], name='d_adv_bce')

    # Reconstruction terms: mean squared error, plus the L2Norm2 of the
    # difference (pred_y - gt_y)
    img_loss = lbann.MeanSquaredError([pred_y, gt_y])
    residual = lbann.WeightedSum([pred_y, gt_y], scaling_factors="1 -1")
    rec_error = lbann.L2Norm2(residual)

    layers = list(lbann.traverse_layer_graph(input_))

    # Gather weights. "disc0...instance1" layers are the copy sources;
    # "disc1" layers are the destinations and have optimizers disabled.
    weights = set()
    src_layers = []
    dst_layers = []
    for layer in layers:
        if not layer.weights:
            continue
        layer_name = layer.name
        if "disc0" in layer_name and "instance1" in layer_name:
            src_layers.append(layer_name)
        if "disc1" in layer_name:
            dst_layers.append(layer_name)
            for frozen in layer.weights:
                frozen.optimizer = lbann.NoOptimizer()
        weights.update(layer.weights)

    # Setup objective function
    l2_reg = lbann.L2WeightRegularization(weights=weights, scale=1e-4)
    # The adversarial term enters the objective with scale 0.01
    d_adv_term = lbann.LayerTerm(d_adv_bce, scale=0.01)
    objective_terms = [
        d1_real_bce,
        d1_fake_bce,
        d_adv_term,
        img_loss,
        rec_error,
        l2_reg,
    ]
    obj = lbann.ObjectiveFunction(objective_terms)

    # Initialize check metric callback
    metrics = [lbann.Metric(img_loss, name='recon_error')]

    callbacks = [lbann.CallbackPrint(), lbann.CallbackTimer()]
    callbacks.append(
        lbann.CallbackReplaceWeights(
            source_layers=list2str(src_layers),
            destination_layers=list2str(dst_layers),
            batch_interval=2))

    # Construct model
    num_epochs = 100
    return lbann.Model(num_epochs,
                       weights=weights,
                       layers=layers,
                       metrics=metrics,
                       objective_function=obj,
                       callbacks=callbacks)