Ejemplo n.º 1
0
        elif layer_name == 'Dropout':
            dropout_id += 1
            if dropout_id == args.gather_dropout_id:
                break

        layer.parallel_strategy = parallel_strategy

    # Set up model
    metrics = [lbann.Metric(mse, name='MSE', unit='')]
    callbacks = [
        lbann.CallbackPrint(),
        lbann.CallbackTimer(),
        lbann.CallbackGPUMemoryUsage(),
        lbann.CallbackDumpOutputs(directory='dump_acts/',
                                  layers=' '.join([preds.name, secrets.name]),
                                  execution_modes='test'),
        lbann.CallbackProfiler(skip_init=True)
    ]
    # # TODO: Use polynomial learning rate decay (https://github.com/LLNL/lbann/issues/1581)
    # callbacks.append(lbann.CallbackPolyLearningRate(
    #     power=1.0,
    #     num_epochs=100,
    #     end_lr=1e-7))
    model = lbann.Model(epochs=args.num_epochs,
                        layers=layers,
                        objective_function=obj,
                        metrics=metrics,
                        callbacks=callbacks)

    # Setup optimizer
Ejemplo n.º 2
0
def construct_model(num_epochs, mcr, spectral_loss, save_batch_interval):
    """Assemble the LBANN model for the GAN provided by ``ExaGAN.CosmoGAN``.

    Args:
        num_epochs: Forwarded to ``lbann.Model`` as its first positional
            argument.
        mcr: Forwarded to the ``ExaGAN.CosmoGAN`` constructor and to the
            call on the resulting object.
        spectral_loss: When truthy, a log-MSE term between the ``DFTAbs``
            of the generated and input images is added to the objective
            (scale 8.0) and reported as the ``spec_loss`` metric.
        save_batch_interval: ``batch_interval`` of the output-dumping
            callback.

    Returns:
        The constructed ``lbann.Model``.
    """
    import lbann

    # Root of the layer graph; the graph is traversed from here below.
    input_layer = lbann.Input(target_mode='N/A', name='inp_img')

    # Discriminator targets with label flipping (probability 0.01):
    # 'is_real' is GreaterEqual(uniform draw, flip probability) and
    # 'is_fake' is its logical negation.
    flip_probability = 0.01
    flip_draw = lbann.Uniform(min=0, max=1, neuron_dims='1')
    flip_threshold = lbann.Constant(value=flip_probability, num_neurons='1')
    real_labels = lbann.GreaterEqual(flip_draw, flip_threshold, name='is_real')
    fake_labels = lbann.LogicalNot(real_labels, name='is_fake')
    # Constant 1.0 target (no flipping) used for the generator's
    # adversarial term.
    generator_labels = lbann.Constant(value=1.0, num_neurons='1')

    # Gaussian noise vector, reshaped to '1 64', fed to the GAN together
    # with the input layer.
    noise = lbann.Gaussian(mean=0.0, stdev=1.0, neuron_dims="64",
                           name='noise_vec')
    z = lbann.Reshape(noise, dims='1 64')
    gan = ExaGAN.CosmoGAN(mcr)
    d1_real, d1_fake, d_adv, gen_img, img = gan(input_layer, z, mcr)

    # Binary cross-entropy terms: discriminator on real data,
    # discriminator on generated data, and the adversarial term.
    d1_real_bce = lbann.SigmoidBinaryCrossEntropy([d1_real, real_labels],
                                                  name='d1_real_bce')
    d1_fake_bce = lbann.SigmoidBinaryCrossEntropy([d1_fake, fake_labels],
                                                  name='d1_fake_bce')
    d_adv_bce = lbann.SigmoidBinaryCrossEntropy([d_adv, generator_labels],
                                                name='d_adv_bce')

    # Walk the graph once to (a) collect every weights object, (b) record
    # the source ("disc1" + "instance1") and destination ("disc2") layer
    # names handed to CallbackReplaceWeights, and (c) freeze the "disc2"
    # weights by giving them a NoOptimizer.
    layers = list(lbann.traverse_layer_graph(input_layer))
    weights = set()
    src_layers = []
    dst_layers = []
    for layer in layers:
        layer_weights = layer.weights
        if layer_weights:
            layer_name = layer.name
            if "disc1" in layer_name and "instance1" in layer_name:
                src_layers.append(layer_name)
            if "disc2" in layer_name:
                dst_layers.append(layer_name)
                for frozen in layer_weights:
                    frozen.optimizer = lbann.NoOptimizer()
        weights.update(layer_weights)

    # Objective: the three BCE terms, plus the optional spectral term.
    objective_terms = [d1_real_bce, d1_fake_bce, d_adv_bce]
    if spectral_loss:
        generated_spectrum = lbann.DFTAbs(gen_img)
        target_spectrum = lbann.StopGradient(lbann.DFTAbs(img))
        spec_loss = lbann.Log(
            lbann.MeanSquaredError(generated_spectrum, target_spectrum))
        objective_terms.append(lbann.LayerTerm(spec_loss, scale=8.0))
    loss = lbann.ObjectiveFunction(objective_terms)

    # Metrics mirror the objective terms.
    metrics = [
        lbann.Metric(d1_real_bce, name='d_real'),
        lbann.Metric(d1_fake_bce, name='d_fake'),
        lbann.Metric(d_adv_bce, name='gen_adv'),
    ]
    if spectral_loss:
        metrics.append(lbann.Metric(spec_loss, name='spec_loss'))

    # Local switches selecting the optional callbacks.
    enable_dump_outputs = True
    enable_save_model = False
    enable_print_model = False

    callbacks = [
        lbann.CallbackPrint(),
        lbann.CallbackTimer(),
        lbann.CallbackReplaceWeights(
            source_layers=list2str(src_layers),
            destination_layers=list2str(dst_layers),
            batch_interval=1,
        ),
    ]
    if enable_dump_outputs:
        callbacks.append(
            lbann.CallbackDumpOutputs(
                layers='gen_img_instance1_activation',
                execution_modes='train validation',
                directory='dump_outs',
                batch_interval=save_batch_interval,
                format='npy',
            ))
    if enable_save_model:
        callbacks.append(lbann.CallbackSaveModel(dir='models'))
    if enable_print_model:
        callbacks.append(lbann.CallbackPrintModelDescription())

    return lbann.Model(num_epochs,
                       weights=weights,
                       layers=layers,
                       metrics=metrics,
                       objective_function=loss,
                       callbacks=callbacks)
Ejemplo n.º 3
0
def construct_model(run_args):
    """Construct LBANN model.

    Initial model for ATOM molecular VAE.  This variant runs only the
    decoder of ``molwae.MolWAE``: each input sample is sliced into a
    sequence part and a latent part ``z``, and the decoder prediction is
    dumped next to the raw input.

    Args:
        run_args: Object providing ``pad_index``, ``sequence_length``,
            ``z_dim``, ``embedding_dim``, ``num_embeddings``,
            ``dump_outputs_dir``, ``dump_outputs_interval`` and
            ``num_epochs``.

    Returns:
        The constructed ``lbann.Model``.
    """
    import lbann

    pad_idx = run_args.pad_index
    assert pad_idx is not None

    seq_len = run_args.sequence_length
    assert seq_len is not None, 'should be training seq len + bos + eos'
    print(f"sequence length is {seq_len}, which is training sequence len + bos + eos")

    # Each sample is assumed to be the input SMILES concatenated with z
    # (produced by the encoder script); slice the two parts apart.
    input_ = lbann.Input(data_field='samples', name='inp_data')
    slice_points = str_list([0, seq_len, seq_len + run_args.z_dim])
    inp_slice = lbann.Slice(input_,
                            axis=0,
                            slice_points=slice_points,
                            name='inp_slice')
    inp_smile = lbann.Identity(inp_slice, name='inp_smile')
    z = lbann.Identity(inp_slice, name='z')

    emb_dim = run_args.embedding_dim
    vocab_size = run_args.num_embeddings
    assert emb_dim is not None
    assert vocab_size is not None

    save_output = bool(run_args.dump_outputs_dir)
    print("save output? ", save_output, "out dir ",  run_args.dump_outputs_dir)

    # To sample z randomly instead of reading it from the input, a
    # Gaussian layer with neuron_dims=str(run_args.z_dim) could replace it.
    x = lbann.Identity(
        lbann.Slice(inp_smile, slice_points=str_list([0, seq_len])))

    waemodel = molwae.MolWAE(seq_len,
                             vocab_size,
                             emb_dim,
                             pad_idx,
                             run_args.z_dim,
                             save_output=save_output)
    x_emb = lbann.Embedding(x,
                            num_embeddings=waemodel.dictionary_size,
                            embedding_dim=waemodel.embedding_size,
                            name='emb',
                            weights=waemodel.emb_weights)

    pred, arg_max = waemodel.forward_decoder(x_emb, z)
    recon = waemodel.compute_loss(x, pred)

    # Gather every weights object reachable from the input layer.
    layers = list(lbann.traverse_layer_graph(input_))
    weights = {w for layer in layers for w in layer.weights}

    # The objective holds only the reconstruction loss (an L2 weight
    # regularization term is not added).
    wae_loss = [recon]
    print("LEN wae loss ", len(wae_loss))
    obj = lbann.ObjectiveFunction(wae_loss)

    metrics = [lbann.Metric(recon, name='recon')]

    callbacks = [lbann.CallbackPrint(), lbann.CallbackTimer()]

    # Dump input and prediction side by side for post-processing.
    pred_tensor = lbann.Concatenation(arg_max, name='pred_tensor')
    conc_out = lbann.Concatenation([input_, pred_tensor], name='conc_out')
    callbacks.append(
        lbann.CallbackDumpOutputs(
            batch_interval=run_args.dump_outputs_interval,
            execution_modes='test',
            directory=run_args.dump_outputs_dir,
            layers=str(conc_out.name)))

    return lbann.Model(run_args.num_epochs,
                       weights=weights,
                       layers=layers,
                       objective_function=obj,
                       metrics=metrics,
                       callbacks=callbacks)
Ejemplo n.º 4
0
def construct_model(run_args):
    """Construct LBANN model.

    Initial model for ATOM molecular VAE.  This variant requires
    ``run_args.dump_model_dir`` to be set and builds the full
    ``molwae.MolWAE`` graph together with its discriminator losses.

    Args:
        run_args: Object providing ``dump_model_dir``, ``pad_index``,
            ``sequence_length``, ``embedding_dim``, ``num_embeddings``,
            ``z_dim``, ``dump_outputs_dir``, ``dump_outputs_interval``
            and ``num_epochs``.

    Returns:
        The constructed ``lbann.Model``.
    """
    import lbann

    print("Dump model dir ", run_args.dump_model_dir)
    assert run_args.dump_model_dir, "evaluate script asssumes a pretrained WAE model"

    pad_idx = run_args.pad_index
    assert pad_idx is not None

    seq_len = run_args.sequence_length
    assert seq_len is not None
    print(f"sequence length is {seq_len}")

    # Layer graph root: the data input wrapped in an Identity layer.
    data_input = lbann.Input(name='inp', data_field='samples')
    input_ = lbann.Identity(data_input, name='inp1')

    emb_dim = run_args.embedding_dim
    vocab_size = run_args.num_embeddings
    assert emb_dim is not None
    assert vocab_size is not None

    save_output = bool(run_args.dump_outputs_dir)
    print("save output? ", save_output, "out dir ", run_args.dump_outputs_dir)

    z = lbann.Gaussian(mean=0.0, stdev=1.0, neuron_dims=run_args.z_dim)

    waemodel = molwae.MolWAE(seq_len, vocab_size, emb_dim, pad_idx,
                             run_args.z_dim, save_output)
    recon, d1_real, d1_fake, d_adv, arg_max = waemodel(input_, z)

    # Constant targets for the discriminator cross-entropy terms.
    zero = lbann.Constant(value=0.0, num_neurons='1', name='zero')
    one = lbann.Constant(value=1.0, num_neurons='1', name='one')

    d1_real_bce = lbann.SigmoidBinaryCrossEntropy([d1_real, one],
                                                  name='d1_real_bce')
    d1_fake_bce = lbann.SigmoidBinaryCrossEntropy([d1_fake, zero],
                                                  name='d1_fake_bce')
    d_adv_bce = lbann.SigmoidBinaryCrossEntropy([d_adv, one],
                                                name='d_adv_bce')

    # Collect all weights; remember the "disc0"/"instance1" layers as
    # replacement sources and the "disc1" layers as destinations, and
    # freeze the latter by assigning a NoOptimizer to their weights.
    layers = list(lbann.traverse_layer_graph(input_))
    weights = set()
    src_layers = []
    dst_layers = []
    for layer in layers:
        if layer.weights:
            if "disc0" in layer.name and "instance1" in layer.name:
                src_layers.append(layer.name)
            if "disc1" in layer.name:
                dst_layers.append(layer.name)
                for frozen in layer.weights:
                    frozen.optimizer = lbann.NoOptimizer()
        weights.update(layer.weights)

    # L2 regularization is applied only to weights that are not frozen.
    l2_weights = [
        w for w in weights if not isinstance(w.optimizer, lbann.NoOptimizer)
    ]
    l2_reg = lbann.L2WeightRegularization(weights=l2_weights, scale=1e-4)

    wae_loss = [recon, d1_real_bce, d_adv_bce, d1_fake_bce, l2_reg]
    print("LEN wae loss ", len(wae_loss))
    obj = lbann.ObjectiveFunction(wae_loss)

    metrics = [
        lbann.Metric(d_adv_bce, name='adv_loss'),
        lbann.Metric(recon, name='recon'),
    ]

    callbacks = [
        lbann.CallbackPrint(),
        lbann.CallbackTimer(),
        lbann.CallbackReplaceWeights(
            source_layers=list2str(src_layers),
            destination_layers=list2str(dst_layers),
            batch_interval=2),
    ]

    # Dump outputs (activations) for post-processing when a directory is
    # configured.
    if run_args.dump_outputs_dir:
        pred_tensor = lbann.Concatenation(arg_max, name='pred_tensor')
        dumped_layers = ' '.join(
            ['inp', 'pred_tensor', str(waemodel.q_mu.name)])
        callbacks.append(
            lbann.CallbackDumpOutputs(
                batch_interval=run_args.dump_outputs_interval,
                execution_modes='test',
                directory=run_args.dump_outputs_dir,
                layers=dumped_layers))

    return lbann.Model(run_args.num_epochs,
                       weights=weights,
                       layers=layers,
                       objective_function=obj,
                       metrics=metrics,
                       callbacks=callbacks)
Ejemplo n.º 5
0
# Objective: mean squared error between `x` and `secrets` (both are
# defined earlier in the script, outside this excerpt).
loss = lbann.MeanSquaredError([x, secrets])

# Metrics
# The objective layer is also reported as a metric named "MSE" with an
# empty unit string.
metrics = [lbann.Metric(loss, name="MSE", unit="")]

# Callbacks
callbacks = [
    lbann.CallbackPrint(),
    lbann.CallbackTimer(),
    # Polynomial learning-rate schedule with power 1.0 over a fixed
    # 100 epochs, regardless of args.epochs.
    lbann.CallbackPolyLearningRate(
        power=1.0,
        num_epochs=100,  # TODO: Warn if args.epochs < 100
    ),
    lbann.CallbackGPUMemoryUsage(),
    # Dump the outputs of two layers, selected by name, during the test
    # execution mode.  NOTE(review): "layer3" looks like an auto-generated
    # layer name, so it presumably depends on layer creation order --
    # confirm before reordering layer construction.
    lbann.CallbackDumpOutputs(directory="dump_acts/",
                              layers="cosmoflow_module1_fc3_instance1 layer3",
                              execution_modes="test"),
    lbann.CallbackProfiler(skip_init=True)
]

# ----------------------------------
# Setup experiment
# ----------------------------------

# Setup model
# The layer list is obtained by traversing the graph from `input`
# (a layer defined earlier in the script; it shadows the builtin).
# NOTE(review): the mini-batch size is passed as the first positional
# argument here -- confirm this matches the installed lbann.Model signature.
model = lbann.Model(args.mini_batch_size,
                    args.epochs,
                    layers=lbann.traverse_layer_graph(input),
                    objective_function=loss,
                    metrics=metrics,
                    callbacks=callbacks)
Ejemplo n.º 6
0
def construct_model(run_args):
    """Construct LBANN model.

    Initial model for ATOM molecular VAE.  This variant requires
    ``run_args.dump_model_dir`` to be set, runs only the encoder of
    ``molwae.MolWAE`` and dumps the input concatenated with the encoder
    output.

    Args:
        run_args: Object providing ``dump_model_dir``, ``pad_index``,
            ``sequence_length``, ``embedding_dim``, ``num_embeddings``,
            ``dump_outputs_dir``, ``dump_outputs_interval`` and
            ``num_epochs``.

    Returns:
        The constructed ``lbann.Model`` (built without metrics).
    """
    import lbann

    print("Dump model dir ", run_args.dump_model_dir)
    assert run_args.dump_model_dir, "evaluate script asssumes a pretrained WAE model"

    pad_idx = run_args.pad_index
    assert pad_idx is not None

    seq_len = run_args.sequence_length
    assert seq_len is not None
    print(f"sequence length is {seq_len}")

    # Layer graph root: the data input wrapped in an Identity layer.
    data_input = lbann.Input(name='inp', target_mode="N/A")
    input_ = lbann.Identity(data_input, name='inp1')

    emb_dim = run_args.embedding_dim
    vocab_size = run_args.num_embeddings
    assert emb_dim is not None
    assert vocab_size is not None

    # Output saving inside the WAE module is switched off in this variant.
    save_output = False
    print("save output? ", save_output, "out dir ", run_args.dump_outputs_dir)

    z = lbann.Gaussian(mean=0.0, stdev=1.0, neuron_dims="128")

    x = lbann.Identity(
        lbann.Slice(input_, slice_points=str_list([0, seq_len])))
    waemodel = molwae.MolWAE(seq_len, vocab_size, emb_dim, pad_idx,
                             save_output)
    x_emb = lbann.Embedding(x,
                            num_embeddings=waemodel.dictionary_size,
                            embedding_dim=waemodel.embedding_size,
                            name='emb',
                            weights=waemodel.emb_weights)

    latentz = waemodel.forward_encoder(x_emb)

    # The objective is the mean absolute error layer built from the
    # encoder output and the Gaussian sample.
    fake_loss = lbann.MeanAbsoluteError(latentz, z)

    # Gather every weights object reachable from the input layer.
    layers = list(lbann.traverse_layer_graph(input_))
    weights = {w for layer in layers for w in layer.weights}

    obj = lbann.ObjectiveFunction(fake_loss)

    callbacks = [lbann.CallbackPrint(), lbann.CallbackTimer()]

    # Dump input and latent code side by side for post-processing.
    conc_out = lbann.Concatenation([input_, latentz], name='conc_out')
    callbacks.append(
        lbann.CallbackDumpOutputs(
            batch_interval=run_args.dump_outputs_interval,
            execution_modes='test',
            directory=run_args.dump_outputs_dir,
            layers=str(conc_out.name)))

    return lbann.Model(run_args.num_epochs,
                       weights=weights,
                       layers=layers,
                       objective_function=obj,
                       callbacks=callbacks)
Ejemplo n.º 7
0
    # Radial profile
    x = lbann.WeightsLayer(
        weights=lbann.Weights(
            lbann.ValueInitializer(values=str_list(image.flatten())), ),
        dims=str_list(image.shape),
    )
    max_r = image.shape[-1] // 2
    rprof = RadialProfile()(x, image.shape, max_r)
    rprof_slice = lbann.Slice(rprof, slice_points=str_list([0, 1, 2, 3]))
    red = lbann.Identity(rprof_slice, name='red')
    green = lbann.Identity(rprof_slice, name='green')
    blue = lbann.Identity(rprof_slice, name='blue')

    # Construct model
    callbacks = [
        lbann.CallbackDumpOutputs(layers=str_list(['red', 'green', 'blue'])),
    ]
    model = lbann.Model(
        epochs=0,
        layers=lbann.traverse_layer_graph([input_, rprof]),
        callbacks=callbacks,
    )

    # Run LBANN
    lbann.run(
        trainer=lbann.Trainer(mini_batch_size=1),
        model=model,
        data_reader=reader,
        optimizer=lbann.NoOptimizer(),
        job_name='lbann_radial_profile_test',
    )
Ejemplo n.º 8
0
def construct_model(run_args):
    """Construct LBANN model.

    Initial model for ATOM molecular VAE.  This variant runs only the
    decoder of ``molwae.MolWAE`` with a hard-coded sequence length of 102
    and hard-coded input slice points ``"0 102 230"``.

    Args:
        run_args: Object providing ``pad_index``, ``embedding_dim``,
            ``num_embeddings``, ``dump_outputs_dir``,
            ``dump_outputs_interval`` and ``num_epochs``.
            ``run_args.sequence_length`` is not read.

    Returns:
        The constructed ``lbann.Model``.
    """
    import lbann

    pad_idx = run_args.pad_index
    assert pad_idx is not None

    # Fixed here instead of being taken from run_args.sequence_length.
    seq_len = 102
    assert seq_len is not None
    print(f"sequence length is {seq_len}")

    # Slice each sample into the sequence part and the latent part; both
    # Identity layers are attached to the same Slice layer.
    input_ = lbann.Input(target_mode='N/A', name='inp_data')
    inp_slice = lbann.Slice(input_,
                            axis=0,
                            slice_points="0 102 230",
                            name='inp_slice')
    inp_smile = lbann.Identity(inp_slice, name='inp_smile')
    # z is passed to the decoder below.
    z = lbann.Identity(inp_slice, name='z')

    emb_dim = run_args.embedding_dim
    vocab_size = run_args.num_embeddings
    assert emb_dim is not None
    assert vocab_size is not None

    save_output = bool(run_args.dump_outputs_dir)
    print("save output? ", save_output, "out dir ",  run_args.dump_outputs_dir)

    x = lbann.Identity(
        lbann.Slice(inp_smile, slice_points=str_list([0, seq_len])))
    waemodel = molwae.MolWAE(seq_len, vocab_size, emb_dim, pad_idx,
                             save_output)
    x_emb = lbann.Embedding(x,
                            num_embeddings=waemodel.dictionary_size,
                            embedding_dim=waemodel.embedding_size,
                            name='emb',
                            weights=waemodel.emb_weights)

    pred, arg_max = waemodel.forward_decoder(x_emb, z)
    recon = waemodel.compute_loss(x, pred)

    # Gather every weights object reachable from the input layer.
    layers = list(lbann.traverse_layer_graph(input_))
    weights = {w for layer in layers for w in layer.weights}

    # The objective holds only the reconstruction loss (an L2 weight
    # regularization term is not added).
    vae_loss = [recon]
    print("LEN vae loss ", len(vae_loss))
    obj = lbann.ObjectiveFunction(vae_loss)

    metrics = [lbann.Metric(recon, name='recon')]

    callbacks = [lbann.CallbackPrint(), lbann.CallbackTimer()]

    # Dump input and prediction side by side for post-processing.
    pred_tensor = lbann.Concatenation(arg_max, name='pred_tensor')
    conc_out = lbann.Concatenation([input_, pred_tensor], name='conc_out')
    callbacks.append(
        lbann.CallbackDumpOutputs(
            batch_interval=run_args.dump_outputs_interval,
            execution_modes='test',
            directory=run_args.dump_outputs_dir,
            layers=str(conc_out.name)))

    return lbann.Model(run_args.num_epochs,
                       weights=weights,
                       layers=layers,
                       objective_function=obj,
                       metrics=metrics,
                       callbacks=callbacks)
Ejemplo n.º 9
0
def construct_model():
    """Construct MACC surrogate model.

    See https://arxiv.org/pdf/1912.08113.pdf model architecture and other details

    Reads its configuration from the module-level ``args`` object
    (``ydim``, ``xdim``, ``zdim``, ``wae_mcf``, ``useCNN``,
    ``surrogate_mcf``, ``lamda_cyc``, ``dump_outputs``).

    Returns:
        The constructed ``lbann.Model``, configured for a single epoch.
    """
    import lbann

    # Each sample is sliced into the outputs y (first args.ydim entries)
    # and the input parameters x (next args.xdim entries).  Per the
    # original note, the data is 64*64*4 images + 15 scalars + 5 params.
    input_layer = lbann.Input(data_field='samples', name='inp_data')
    split_points = str_list([0, args.ydim, args.ydim + args.xdim])
    inp_slice = lbann.Slice(input_layer,
                            axis=0,
                            slice_points=split_points,
                            name='inp_slice')
    gt_y = lbann.Identity(inp_slice, name='gt_y')
    gt_x = lbann.Identity(inp_slice, name='gt_x')

    # These three layers are not referenced again in this function; they
    # are still created, in this position, exactly as before.
    zero = lbann.Constant(value=0.0, num_neurons='1', name='zero')
    one = lbann.Constant(value=1.0, num_neurons='1', name='one')
    z = lbann.Gaussian(mean=0.0, stdev=1.0, neuron_dims="20")

    # Sub-models: the autoencoder (noted as pretrained/frozen in the
    # original comment) and the inverse / forward surrogates.
    wae = macc_models.MACCWAE(args.zdim,
                              args.ydim,
                              cf=args.wae_mcf,
                              use_CNN=args.useCNN)
    inv = macc_models.MACCInverse(args.xdim, cf=args.surrogate_mcf)
    fwd = macc_models.MACCForward(args.zdim, cf=args.surrogate_mcf)

    # y -> enc(y)
    y_pred_fwd = wae.encoder(gt_y)

    # y -> enc(y) -> inv -> x' -> fwd -> dec -> y'
    param_pred_ = wae.encoder(gt_y)
    input_fake = inv(param_pred_)
    output_cyc = fwd(input_fake)
    y_image_re2 = wae.decoder(output_cyc)

    # Train cycleGAN: input params <--> latent space of (images, scalars)
    # x -> fwd -> dec -> y'
    output_fake = fwd(gt_x)
    y_image_re = wae.decoder(output_fake)

    # y -> enc -> dec (autoencoder reconstruction)
    y_out = wae.decoder(y_pred_fwd)

    # x -> fwd -> dec -> enc -> inv -> x'
    param_pred2_ = wae.encoder(y_image_re)
    input_cyc = inv(param_pred2_)

    # Losses on the parameter side.
    L_l2_x = lbann.MeanSquaredError(input_fake, gt_x)  # (x, inv(enc(y)))
    L_cyc_x = lbann.MeanSquaredError(input_cyc, gt_x)  # x cycle loss

    # Losses in the latent space.
    L_l2_y = lbann.MeanSquaredError(output_fake, y_pred_fwd)  # (enc(y), fw(x))
    L_cyc_y = lbann.MeanSquaredError(output_cyc,
                                     y_pred_fwd)  # (enc(y), fw(inv(enc(y))))

    # Losses in output space.
    # TODO: slice here to separate scalars from images.
    img_sca_loss = lbann.MeanSquaredError(y_image_re, gt_y)  # (y, dec(fw(x)))
    dec_fw_inv_enc_y = lbann.MeanSquaredError(
        y_image_re2, gt_y)  # (y, dec(fw(inv(enc(y)))))
    wae_loss = lbann.MeanSquaredError(y_out, gt_y)  # (y, dec(enc(y)))

    # L_cyc = L_cyc_y + L_cyc_x
    L_cyc = lbann.Add(L_cyc_y, L_cyc_x)

    # loss_gen0 = L_l2_y + lamda_cyc * L_cyc
    # loss_gen1 = L_l2_x + lamda_cyc * L_cyc_y
    cyc_scaling = f'1 {args.lamda_cyc}'
    loss_gen0 = lbann.WeightedSum([L_l2_y, L_cyc],
                                  scaling_factors=cyc_scaling)
    loss_gen1 = lbann.WeightedSum([L_l2_x, L_cyc_y],
                                  scaling_factors=cyc_scaling)

    # Layer whose output is dumped: ground-truth x followed by four of
    # the error layers.
    conc_out = lbann.Concatenation(
        [gt_x, wae_loss, img_sca_loss, dec_fw_inv_enc_y, L_l2_x],
        name='x_errors')

    # Gather every weights object reachable from the input layer.
    layers = list(lbann.traverse_layer_graph(input_layer))
    weights = {w for layer in layers for w in layer.weights}

    obj = lbann.ObjectiveFunction([loss_gen0, loss_gen1])

    metric_specs = [
        (img_sca_loss, 'img_re1'),
        (dec_fw_inv_enc_y, 'img_re2'),
        (wae_loss, 'wae_loss'),
        (L_l2_x, 'inverse loss'),
        (L_cyc_y, 'output cycle loss'),
        (L_cyc_x, 'param cycle loss'),
    ]
    metrics = [lbann.Metric(layer, name=label) for layer, label in metric_specs]

    callbacks = [
        lbann.CallbackPrint(),
        lbann.CallbackDumpOutputs(layers=str(conc_out.name),
                                  execution_modes='test',
                                  directory=args.dump_outputs,
                                  batch_interval=1,
                                  format='npy'),
        lbann.CallbackTimer(),
    ]

    # The epoch count is fixed to one here.
    num_epochs = 1
    return lbann.Model(num_epochs,
                       weights=weights,
                       layers=layers,
                       serialize_io=True,
                       metrics=metrics,
                       objective_function=obj,
                       callbacks=callbacks)
Ejemplo n.º 10
0
def construct_model(run_args):
    """Construct LBANN model.

    Initial model for ATOM molecular VAE.  This variant builds the full
    ``molwae.MolWAE`` graph with its discriminator losses and attaches
    the optional weight-dumping, LTFB, model-saving, output-dumping and
    learning-rate warm-up callbacks selected by ``run_args``.

    Args:
        run_args: Object providing ``pad_index``, ``sequence_length``,
            ``embedding_dim``, ``num_embeddings``, ``dump_outputs_dir``,
            ``dump_outputs_interval``, ``dump_weights_interval``,
            ``dump_weights_dir``, ``ltfb``, ``weights_to_send``,
            ``ltfb_batch_interval``, ``dump_model_dir``, ``warmup``,
            ``lr``, ``batch_size`` and ``num_epochs``.

    Returns:
        The constructed ``lbann.Model``.
    """
    import lbann

    pad_idx = run_args.pad_index
    assert pad_idx is not None

    seq_len = run_args.sequence_length
    assert seq_len is not None
    print(f"sequence length is {seq_len}")

    # Layer graph root: the data input wrapped in an Identity layer.
    data_input = lbann.Input(name='inp', target_mode="N/A")
    input_ = lbann.Identity(data_input, name='inp1')

    emb_dim = run_args.embedding_dim
    vocab_size = run_args.num_embeddings
    assert emb_dim is not None
    assert vocab_size is not None

    save_output = bool(run_args.dump_outputs_dir)
    print("save output? ", save_output, "out dir ", run_args.dump_outputs_dir)

    z = lbann.Gaussian(mean=0.0, stdev=1.0, neuron_dims="128")
    waemodel = molwae.MolWAE(seq_len, vocab_size, emb_dim, pad_idx,
                             save_output)
    recon, d1_real, d1_fake, d_adv, arg_max = waemodel(input_, z)

    # Constant targets for the discriminator cross-entropy terms.
    zero = lbann.Constant(value=0.0, num_neurons='1', name='zero')
    one = lbann.Constant(value=1.0, num_neurons='1', name='one')

    d1_real_bce = lbann.SigmoidBinaryCrossEntropy([d1_real, one],
                                                  name='d1_real_bce')
    d1_fake_bce = lbann.SigmoidBinaryCrossEntropy([d1_fake, zero],
                                                  name='d1_fake_bce')
    d_adv_bce = lbann.SigmoidBinaryCrossEntropy([d_adv, one],
                                                name='d_adv_bce')

    # Collect all weights; remember the "disc0"/"instance1" layers as
    # replacement sources and the "disc1" layers as destinations, and
    # freeze the latter by assigning a NoOptimizer to their weights.
    layers = list(lbann.traverse_layer_graph(input_))
    weights = set()
    src_layers = []
    dst_layers = []
    for layer in layers:
        if layer.weights:
            if "disc0" in layer.name and "instance1" in layer.name:
                src_layers.append(layer.name)
            if "disc1" in layer.name:
                dst_layers.append(layer.name)
                for frozen in layer.weights:
                    frozen.optimizer = lbann.NoOptimizer()
        weights.update(layer.weights)

    # L2 regularization over every collected weights object.
    l2_reg = lbann.L2WeightRegularization(weights=weights, scale=1e-4)

    vae_loss = [recon, d1_real_bce, d_adv_bce, d1_fake_bce, l2_reg]
    print("LEN vae loss ", len(vae_loss))
    obj = lbann.ObjectiveFunction(vae_loss)

    metrics = [
        lbann.Metric(d_adv_bce, name='adv_loss'),
        lbann.Metric(recon, name='recon'),
    ]

    callbacks = [lbann.CallbackPrint(), lbann.CallbackTimer()]

    # Periodic weight dumps.
    if run_args.dump_weights_interval > 0:
        callbacks.append(
            lbann.CallbackDumpWeights(
                directory=run_args.dump_weights_dir,
                epoch_interval=run_args.dump_weights_interval))

    # LTFB: exchange the weights whose name contains the requested
    # substring; 'All' maps to the empty string, which matches every name
    # (noted in the original as a hack for Merlin's empty-string handling).
    if run_args.ltfb:
        if run_args.weights_to_send == 'All':
            send_name = ''
        else:
            send_name = run_args.weights_to_send
        weights_to_ex = [w.name for w in weights if send_name in w.name]
        print("LTFB Weights to exchange ", weights_to_ex)
        callbacks.append(
            lbann.CallbackLTFB(batch_interval=run_args.ltfb_batch_interval,
                               metric='recon',
                               weights=list2str(weights_to_ex),
                               low_score_wins=True,
                               exchange_hyperparameters=True))

    callbacks.append(
        lbann.CallbackReplaceWeights(
            source_layers=list2str(src_layers),
            destination_layers=list2str(dst_layers),
            batch_interval=2))

    # Save the model (final weights, for inference) when a directory is
    # configured.
    if run_args.dump_model_dir:
        callbacks.append(lbann.CallbackSaveModel(dir=run_args.dump_model_dir))

    # Dump outputs (activations) for post-processing when a directory is
    # configured.
    if run_args.dump_outputs_dir:
        pred_tensor = lbann.Concatenation(arg_max, name='pred_tensor')
        callbacks.append(
            lbann.CallbackDumpOutputs(
                batch_interval=run_args.dump_outputs_interval,
                execution_modes='test',
                directory=run_args.dump_outputs_dir,
                layers='inp pred_tensor'))

    # Optional linear learning-rate growth over 5 epochs.
    if run_args.warmup:
        warmup_target = run_args.lr / 512 * run_args.batch_size
        callbacks.append(
            lbann.CallbackLinearGrowthLearningRate(target=warmup_target,
                                                   num_epochs=5))

    return lbann.Model(run_args.num_epochs,
                       weights=weights,
                       layers=layers,
                       objective_function=obj,
                       metrics=metrics,
                       callbacks=callbacks)
Ejemplo n.º 11
0
# Run configuration taken from the parsed command-line arguments.
# Note: mini_batch_size is not used in this excerpt.
mini_batch_size = args.mini_batch_size
num_epochs = args.num_epochs

# Callbacks for Debug and Running Model

print_model = lbann.CallbackPrintModelDescription(
)  # Prints the initial model after setup

training_output = lbann.CallbackPrint(
    interval=1, print_global_stat_only=False)  # Prints training progress

gpu_usage = lbann.CallbackGPUMemoryUsage()

# Dumps the outputs of the layer named "decoded" every 400 batches, in
# npy format, into the directory that contains this script.
encoded_output = lbann.CallbackDumpOutputs(layers="decoded",
                                           batch_interval=400,
                                           directory=os.path.dirname(
                                               os.path.realpath(__file__)),
                                           format="npy")

# ----------------------------------
# Set up Experiment
# ----------------------------------

# Generate Model
# `layers`, `img_loss` and `metrics` are defined earlier in the script,
# outside this excerpt.
model = lbann.Model(
    num_epochs,
    layers=layers,
    objective_function=img_loss,
    metrics=metrics,
    callbacks=[print_model, training_output, gpu_usage, encoded_output])