Esempio n. 1
0
def setup(data_reader_file,
          name='classifier',
          num_labels=200,
          mini_batch_size=128,
          num_epochs=1000,
          learning_rate=0.1,
          bn_statistics_group_size=2,
          fc_data_layout='model_parallel',
          warmup_epochs=50,
          learning_rate_drop_interval=50,
          learning_rate_drop_factor=0.25,
          checkpoint_interval=None):
    """Assemble the LBANN model, data reader and optimizer for a classifier.

    The network is a ``modules.ResNet`` trunk followed by a single
    fully-connected module with softmax activation. The objective is
    cross-entropy plus an L2 penalty (scale 0.0002) on the weights of all
    convolution and fully-connected layers.

    Args:
        data_reader_file: Path to a text-format protobuf file that is merged
            into an ``LbannPB`` message; its ``data_reader`` field is returned.
        name: Prefix of the fully-connected module name (``'<name>_fc'``).
        num_labels: Size argument given to the final fully-connected module.
        mini_batch_size: Only used to scale the warm-up target learning rate
            (``learning_rate * mini_batch_size / 128``).
        num_epochs: Epoch count given to ``lbann.Model``; also the exclusive
            upper bound of the learning-rate drop schedule.
        learning_rate: SGD learning rate and base of the warm-up target.
        bn_statistics_group_size: Forwarded to ``modules.ResNet``.
        fc_data_layout: ``data_layout`` of the fully-connected module.
        warmup_epochs: Length of the linear learning-rate warm-up; a falsy
            value disables the warm-up callback.
        learning_rate_drop_interval: Step, in epochs, between scheduled
            learning-rate drops.
        learning_rate_drop_factor: ``amt`` of each drop; a falsy value
            disables the drop callback.
        checkpoint_interval: Number of epochs between checkpoints written to
            ``'ckpt'``; a falsy value disables checkpointing.

    Returns:
        Tuple ``(model, data_reader_proto, opt)``.
    """

    # Setup input data
    # NOTE(review): both Identity layers hang off the same Input; presumably
    # LBANN hands out sample and label in attachment order - confirm.
    input_layer = lbann.Input(target_mode='classification')
    images = lbann.Identity(input_layer)
    labels = lbann.Identity(input_layer)

    # Classification network
    head_cnn = modules.ResNet(bn_statistics_group_size=bn_statistics_group_size)
    class_fc = lbann.modules.FullyConnectedModule(num_labels,
                                                  activation=lbann.Softmax,
                                                  name=f'{name}_fc',
                                                  data_layout=fc_data_layout)
    probs = class_fc(head_cnn(images))

    # Loss and metric layers. They are created before the graph is traversed
    # so that a single traversal sees every layer of the model.
    cross_entropy = lbann.CrossEntropy([probs, labels])
    accuracy = lbann.CategoricalAccuracy([probs, labels])
    layers = list(lbann.traverse_layer_graph(input_layer))

    # Setup objective function
    l2_reg_weights = set()
    for layer in layers:
        if isinstance(layer, (lbann.Convolution, lbann.FullyConnected)):
            l2_reg_weights.update(layer.weights)
    l2_reg = lbann.L2WeightRegularization(weights=l2_reg_weights, scale=0.0002)
    obj = lbann.ObjectiveFunction([cross_entropy, l2_reg])

    # Setup model
    metrics = [lbann.Metric(accuracy, name='accuracy', unit='%')]
    callbacks = [lbann.CallbackPrint(), lbann.CallbackTimer()]
    if checkpoint_interval:
        # Honor the requested interval (it used to be ignored in favor of a
        # hard-coded period of 5 epochs).
        callbacks.append(
            lbann.CallbackCheckpoint(
                checkpoint_dir='ckpt',
                checkpoint_epochs=checkpoint_interval
            )
        )

    # Learning rate schedules
    if warmup_epochs:
        callbacks.append(
            lbann.CallbackLinearGrowthLearningRate(
                target=learning_rate * mini_batch_size / 128,
                num_epochs=warmup_epochs
            )
        )
    if learning_rate_drop_factor:
        drop_epochs = list(range(0, num_epochs, learning_rate_drop_interval))
        callbacks.append(
            lbann.CallbackDropFixedLearningRate(
                drop_epoch=drop_epochs,
                amt=learning_rate_drop_factor)
        )

    # Construct model
    model = lbann.Model(num_epochs,
                        layers=layers,
                        objective_function=obj,
                        metrics=metrics,
                        callbacks=callbacks)

    # Setup optimizer
    opt = lbann.SGD(learn_rate=learning_rate, momentum=0.9)

    # Load data reader from prototext
    data_reader_proto = lbann.lbann_pb2.LbannPB()
    with open(data_reader_file, 'r') as f:
        google.protobuf.text_format.Merge(f.read(), data_reader_proto)
    data_reader_proto = data_reader_proto.data_reader
    # Point every reader's Python module directory at the directory that
    # contains this file.
    module_dir = os.path.dirname(os.path.realpath(__file__))
    for reader_proto in data_reader_proto.reader:
        reader_proto.python.module_dir = module_dir

    # Return experiment objects
    return model, data_reader_proto, opt
Esempio n. 2
0
def construct_model(run_args):
    """Construct the LBANN model for the ATOM molecular VAE.

    The layer graph is an ``Input`` -> ``Identity`` pair fed into
    ``molvae.MolVAE``, which yields a KL term and a reconstruction term.
    Both terms form the objective and are also reported as metrics.

    Args:
        run_args: Run options object. The attributes read here are
            ``pad_index``, ``sequence_length``, ``embedding_dim``,
            ``num_embeddings``, ``dump_weights_interval``,
            ``dump_weights_dir``, ``ltfb``, ``weights_to_send``,
            ``ltfb_batch_interval``, ``warmup``, ``lr``, ``batch_size`` and
            ``num_epochs``.

    Returns:
        The assembled ``lbann.Model``.

    Raises:
        AssertionError: If ``pad_index``, ``sequence_length``,
            ``embedding_dim`` or ``num_embeddings`` is ``None``.
    """
    import lbann

    # Validate the required options before any layer is created.
    pad_index = run_args.pad_index
    assert pad_index is not None
    sequence_length = run_args.sequence_length
    assert sequence_length is not None
    embedding_size = run_args.embedding_dim
    dictionary_size = run_args.num_embeddings
    assert embedding_size is not None
    assert dictionary_size is not None

    print(f"sequence length is {sequence_length}")

    # Layer graph
    input_ = lbann.Identity(lbann.Input(name='inp', target_mode="N/A"),
                            name='inp1')
    kl, recon = molvae.MolVAE(sequence_length, dictionary_size,
                              embedding_size, pad_index)(input_)
    vae_loss = [kl, recon]
    print("LEN vae loss ", len(vae_loss))

    layers = list(lbann.traverse_layer_graph(input_))

    # Collect every weights object referenced by the graph; the set is given
    # to the model and filtered for the LTFB exchange below.
    weights = set()
    for layer in layers:
        weights.update(layer.weights)

    # Setup objective function
    # NOTE(review): the objective holds only the KL and reconstruction terms.
    # No L2 weight-regularization term is included; if weight decay is wanted,
    # add lbann.L2WeightRegularization(weights=weights, scale=5e-4) to the
    # list - confirm intent before changing the training objective.
    obj = lbann.ObjectiveFunction(vae_loss)

    # Initialize check metric callback
    metrics = [
        lbann.Metric(kl, name='kl_loss'),
        lbann.Metric(recon, name='recon')
    ]

    callbacks = [lbann.CallbackPrint(), lbann.CallbackTimer()]

    if run_args.dump_weights_interval > 0:
        callbacks.append(
            lbann.CallbackDumpWeights(
                directory=run_args.dump_weights_dir,
                epoch_interval=run_args.dump_weights_interval))

    if run_args.ltfb:
        # 'All' is mapped to the empty string, which is a substring of every
        # weight name (hack for Merlin, which cannot pass an empty string).
        send_name = ('' if run_args.weights_to_send == 'All' else
                     run_args.weights_to_send)
        # Sorted so the exchanged-weights list is reproducible; iterating the
        # set directly gives an arbitrary order.
        weights_to_ex = sorted(w.name for w in weights if send_name in w.name)
        print("LTFB Weights to exchange ", weights_to_ex)
        callbacks.append(
            lbann.CallbackLTFB(batch_interval=run_args.ltfb_batch_interval,
                               metric='recon',
                               weights=list2str(weights_to_ex),
                               low_score_wins=True,
                               exchange_hyperparameters=True))

    if run_args.warmup:
        # Warm-up target scales the base learning rate by batch_size / 512.
        callbacks.append(
            lbann.CallbackLinearGrowthLearningRate(
                target=run_args.lr / 512 * run_args.batch_size,
                num_epochs=5))

    # Construct model
    return lbann.Model(run_args.num_epochs,
                       weights=weights,
                       layers=layers,
                       objective_function=obj,
                       metrics=metrics,
                       callbacks=callbacks)
Esempio n. 3
0
def setup(num_patches=3,
          mini_batch_size=512,
          num_epochs=75,
          learning_rate=0.005,
          bn_statistics_group_size=2,
          fc_data_layout='model_parallel',
          warmup=True,
          checkpoint_interval=None):
    """Assemble the LBANN model, data reader and optimizer for the siamese
    patch classifier.

    Each data sample is sliced into ``num_patches`` patches followed by a
    label vector. Every patch goes through the same ``modules.ResNet``
    instance, the outputs are concatenated and fed to three fully-connected
    stages, the last of which uses softmax activation. The objective is
    cross-entropy plus an L2 penalty (scale 0.0002) on the weights of all
    convolution and fully-connected layers.

    Args:
        num_patches: Number of patches per sample; also forwarded to
            ``patch_generator.num_labels`` and ``make_data_reader``.
        mini_batch_size: Only used to scale the warm-up target learning rate
            (``learning_rate * mini_batch_size / 128``).
        num_epochs: Epoch count given to ``lbann.Model``.
        learning_rate: SGD learning rate and base of the warm-up target.
        bn_statistics_group_size: Forwarded to ``modules.ResNet`` and to the
            two ``modules.FcBnRelu`` stages.
        fc_data_layout: ``data_layout`` of the fully-connected stages.
        warmup: Whether to add a 5-epoch linear learning-rate warm-up.
        checkpoint_interval: Number of epochs between checkpoints written to
            ``'ckpt'``; a falsy value disables checkpointing.

    Returns:
        Tuple ``(model, data_reader, opt)``.
    """

    # Data dimensions
    patch_dims = patch_generator.patch_dims
    num_labels = patch_generator.num_labels(num_patches)

    # Extract tensors from data sample. The flattened patch size is the same
    # for every patch, so it is computed once.
    input_layer = lbann.Input()
    patch_size = functools.reduce(operator.mul, patch_dims)
    slice_points = [i * patch_size for i in range(num_patches + 1)]
    slice_points.append(slice_points[-1] + num_labels)
    sample = lbann.Slice(input_layer, slice_points=str_list(slice_points))
    # NOTE(review): every Reshape and the final Identity take `sample` as
    # parent; presumably LBANN hands out the slices in attachment order, so
    # the patches must be attached before the labels - confirm.
    patches = [
        lbann.Reshape(sample, dims=str_list(patch_dims))
        for _ in range(num_patches)
    ]
    labels = lbann.Identity(sample)

    # Siamese network: one ResNet instance applied to every patch
    head_cnn = modules.ResNet(
        bn_statistics_group_size=bn_statistics_group_size)
    heads = [head_cnn(patch) for patch in patches]
    heads_concat = lbann.Concatenation(heads)

    # Classification network
    class_fc1 = modules.FcBnRelu(
        4096,
        statistics_group_size=bn_statistics_group_size,
        name='siamese_class_fc1',
        data_layout=fc_data_layout)
    class_fc2 = modules.FcBnRelu(
        4096,
        statistics_group_size=bn_statistics_group_size,
        name='siamese_class_fc2',
        data_layout=fc_data_layout)
    class_fc3 = lbann.modules.FullyConnectedModule(num_labels,
                                                   activation=lbann.Softmax,
                                                   name='siamese_class_fc3',
                                                   data_layout=fc_data_layout)
    x = class_fc1(heads_concat)
    x = class_fc2(x)
    probs = class_fc3(x)

    # Loss and metric layers. They are created before the graph is traversed
    # so that a single traversal sees every layer of the model.
    cross_entropy = lbann.CrossEntropy([probs, labels])
    accuracy = lbann.CategoricalAccuracy([probs, labels])
    layers = list(lbann.traverse_layer_graph(input_layer))

    # Setup objective function
    l2_reg_weights = set()
    for layer in layers:
        if isinstance(layer, (lbann.Convolution, lbann.FullyConnected)):
            l2_reg_weights.update(layer.weights)
    l2_reg = lbann.L2WeightRegularization(weights=l2_reg_weights, scale=0.0002)
    obj = lbann.ObjectiveFunction([cross_entropy, l2_reg])

    # Setup model
    metrics = [lbann.Metric(accuracy, name='accuracy', unit='%')]
    callbacks = [lbann.CallbackPrint(), lbann.CallbackTimer()]
    if checkpoint_interval:
        # Honor the requested interval (it used to be ignored in favor of a
        # hard-coded period of 5 epochs).
        callbacks.append(
            lbann.CallbackCheckpoint(checkpoint_dir='ckpt',
                                     checkpoint_epochs=checkpoint_interval))

    # Learning rate schedules
    if warmup:
        callbacks.append(
            lbann.CallbackLinearGrowthLearningRate(
                target=learning_rate * mini_batch_size / 128,
                num_epochs=5))
    callbacks.append(
        lbann.CallbackDropFixedLearningRate(drop_epoch=list(range(0, 100, 15)),
                                            amt=0.25))

    # Construct model
    model = lbann.Model(num_epochs,
                        layers=layers,
                        objective_function=obj,
                        metrics=metrics,
                        callbacks=callbacks)

    # Setup optimizer
    opt = lbann.SGD(learn_rate=learning_rate, momentum=0.9)

    # Setup data reader
    data_reader = make_data_reader(num_patches)

    # Return experiment objects
    return model, data_reader, opt
Esempio n. 4
0
# Objective function built from the cross_entropy and l2_reg terms
obj = lbann.ObjectiveFunction([cross_entropy, l2_reg])

# Metrics reported for the model, both in percent
metrics = [lbann.Metric(top1, name='top-1 accuracy', unit='%')]
metrics.append(lbann.Metric(top5, name='top-5 accuracy', unit='%'))

# Callbacks: printing, timing and a fixed learning-rate drop schedule are
# always present; the linear warm-up is appended only when requested.
callbacks = [lbann.CallbackPrint(), lbann.CallbackTimer()]
callbacks.append(
    lbann.CallbackDropFixedLearningRate(drop_epoch=[30, 60, 80], amt=0.1))
if args.warmup:
    # Warm-up target is 0.1 scaled by mini_batch_size / 256
    callbacks.append(
        lbann.CallbackLinearGrowthLearningRate(
            target=0.1 * args.mini_batch_size / 256,
            num_epochs=5))

# Construct model
model = lbann.Model(
    args.num_epochs,
    layers=layers,
    objective_function=obj,
    metrics=metrics,
    callbacks=callbacks,
)

# Optimizer is derived from the command-line arguments
opt = lbann.contrib.args.create_optimizer(args)

# ImageNet data reader for the requested number of classes
data_reader = data.imagenet.make_data_reader(num_classes=args.num_classes)

# Setup trainer
trainer = lbann.Trainer(mini_batch_size=args.mini_batch_size,
Esempio n. 5
0
def construct_model(run_args):
    """Construct the LBANN model built around ``molwae.MolWAE``.

    The layer graph is an ``Input`` -> ``Identity`` pair plus a Gaussian
    noise layer, both fed into ``molwae.MolWAE``. The objective sums the
    reconstruction term, three sigmoid binary cross-entropy terms on the
    discriminator outputs and an L2 penalty (scale 1e-4) on all weights.

    Args:
        run_args: Run options object. The attributes read here are
            ``pad_index``, ``sequence_length``, ``embedding_dim``,
            ``num_embeddings``, ``dump_outputs_dir``,
            ``dump_outputs_interval``, ``dump_weights_interval``,
            ``dump_weights_dir``, ``dump_model_dir``, ``ltfb``,
            ``weights_to_send``, ``ltfb_batch_interval``, ``warmup``,
            ``lr``, ``batch_size`` and ``num_epochs``.

    Returns:
        The assembled ``lbann.Model``.

    Raises:
        AssertionError: If ``pad_index``, ``sequence_length``,
            ``embedding_dim`` or ``num_embeddings`` is ``None``.
    """
    import lbann

    # Validate the required options before any layer is created.
    pad_index = run_args.pad_index
    assert pad_index is not None
    sequence_length = run_args.sequence_length
    assert sequence_length is not None
    embedding_size = run_args.embedding_dim
    dictionary_size = run_args.num_embeddings
    assert embedding_size is not None
    assert dictionary_size is not None

    print(f"sequence length is {sequence_length}")

    save_output = bool(run_args.dump_outputs_dir)
    print("save output? ", save_output, "out dir ", run_args.dump_outputs_dir)

    # Layer graph
    input_ = lbann.Identity(lbann.Input(name='inp', target_mode="N/A"),
                            name='inp1')
    z = lbann.Gaussian(mean=0.0, stdev=1.0, neuron_dims="128")
    recon, d1_real, d1_fake, d_adv, arg_max = molwae.MolWAE(
        sequence_length, dictionary_size, embedding_size, pad_index,
        save_output)(input_, z)

    # Constant targets for the binary cross-entropy terms
    zero = lbann.Constant(value=0.0, num_neurons='1', name='zero')
    one = lbann.Constant(value=1.0, num_neurons='1', name='one')

    d1_real_bce = lbann.SigmoidBinaryCrossEntropy([d1_real, one],
                                                  name='d1_real_bce')
    d1_fake_bce = lbann.SigmoidBinaryCrossEntropy([d1_fake, zero],
                                                  name='d1_fake_bce')
    d_adv_bce = lbann.SigmoidBinaryCrossEntropy([d_adv, one], name='d_adv_bce')

    # The layer dumped by CallbackDumpOutputs below is created here, before
    # the graph is traversed, so it is part of `layers` without relying on
    # lbann.Model to re-discover it.
    if save_output:
        lbann.Concatenation(arg_max, name='pred_tensor')

    layers = list(lbann.traverse_layer_graph(input_))

    # Collect weights and the layer names used by CallbackReplaceWeights
    weights = set()
    src_layers = []
    dst_layers = []
    for layer in layers:
        if layer.weights:
            if "disc0" in layer.name and "instance1" in layer.name:
                src_layers.append(layer.name)
            # Freeze the weights of every layer whose name contains "disc1":
            # each gets a NoOptimizer and the layer becomes a destination of
            # the weight replacement callback.
            if "disc1" in layer.name:
                dst_layers.append(layer.name)
                for w in layer.weights:
                    w.optimizer = lbann.NoOptimizer()
        weights.update(layer.weights)

    # Setup objective function
    l2_reg = lbann.L2WeightRegularization(weights=weights, scale=1e-4)
    vae_loss = [recon, d1_real_bce, d_adv_bce, d1_fake_bce, l2_reg]
    print("LEN vae loss ", len(vae_loss))
    obj = lbann.ObjectiveFunction(vae_loss)

    # Initialize check metric callback
    metrics = [
        lbann.Metric(d_adv_bce, name='adv_loss'),
        lbann.Metric(recon, name='recon')
    ]

    callbacks = [lbann.CallbackPrint(), lbann.CallbackTimer()]

    if run_args.dump_weights_interval > 0:
        callbacks.append(
            lbann.CallbackDumpWeights(
                directory=run_args.dump_weights_dir,
                epoch_interval=run_args.dump_weights_interval))

    if run_args.ltfb:
        # 'All' is mapped to the empty string, which is a substring of every
        # weight name (hack for Merlin, which cannot pass an empty string).
        send_name = ('' if run_args.weights_to_send == 'All' else
                     run_args.weights_to_send)
        # Sorted so the exchanged-weights list is reproducible; iterating the
        # set directly gives an arbitrary order.
        weights_to_ex = sorted(w.name for w in weights if send_name in w.name)
        print("LTFB Weights to exchange ", weights_to_ex)
        callbacks.append(
            lbann.CallbackLTFB(batch_interval=run_args.ltfb_batch_interval,
                               metric='recon',
                               weights=list2str(weights_to_ex),
                               low_score_wins=True,
                               exchange_hyperparameters=True))

    # Source layers are the "disc0"/"instance1" layers and destinations the
    # frozen "disc1" layers, with a batch interval of 2.
    callbacks.append(
        lbann.CallbackReplaceWeights(source_layers=list2str(src_layers),
                                     destination_layers=list2str(dst_layers),
                                     batch_interval=2))

    # Dump final weight for inference
    if run_args.dump_model_dir:
        callbacks.append(lbann.CallbackSaveModel(dir=run_args.dump_model_dir))

    # Dump output (activation) for post processing
    if save_output:
        callbacks.append(
            lbann.CallbackDumpOutputs(
                batch_interval=run_args.dump_outputs_interval,
                execution_modes='test',
                directory=run_args.dump_outputs_dir,
                layers='inp pred_tensor'))

    if run_args.warmup:
        # Warm-up target scales the base learning rate by batch_size / 512.
        callbacks.append(
            lbann.CallbackLinearGrowthLearningRate(
                target=run_args.lr / 512 * run_args.batch_size,
                num_epochs=5))

    # Construct model
    return lbann.Model(run_args.num_epochs,
                       weights=weights,
                       layers=layers,
                       objective_function=obj,
                       metrics=metrics,
                       callbacks=callbacks)