Example #1
import logging
import os

import pytest


@pytest.fixture
def tarantella_framework():
    # make cuDNN deterministic before tarantella/TensorFlow is initialized
    os.environ['TF_CUDNN_DETERMINISTIC'] = '1'

    import tarantella
    tarantella.init()

    logging.getLogger().info("init tarantella")
    yield tarantella  # provide the fixture value
    logging.getLogger().info("teardown tarantella")
Example #2
import argparse
import tensorflow as tf
from tensorflow import keras

import tarantella as tnt
tnt.init()


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("-bs", "--batch_size", type=int, default=64)
    parser.add_argument("-e", "--number_epochs", type=int, default=1)
    parser.add_argument("-lr", "--learning_rate", type=float, default=0.01)
    parser.add_argument("-train", "--train_size", type=int, default=48000)
    parser.add_argument("-val", "--val_size", type=int, default=6400)
    parser.add_argument("-test", "--test_size", type=int, default=6400)
    args = parser.parse_args()
    return args


def mnist_as_np_arrays(training_samples, validation_samples, test_samples):
    mnist_train_size = 60000
    mnist_test_size = 10000
    assert (training_samples + validation_samples <= mnist_train_size)
    assert (test_samples <= mnist_test_size)

    # load given number of samples
    (x_train_all, y_train_all), (x_test_all, y_test_all) = \
          keras.datasets.mnist.load_data()
    x_train = x_train_all[:training_samples]
    y_train = y_train_all[:training_samples]
    # completing the truncated listing: validation is taken from the remaining
    # training samples, test from the test split
    x_val = x_train_all[training_samples:training_samples + validation_samples]
    y_val = y_train_all[training_samples:training_samples + validation_samples]
    x_test = x_test_all[:test_samples]
    y_test = y_test_all[:test_samples]
    return (x_train, y_train), (x_val, y_val), (x_test, y_test)
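
The helpers above only parse arguments and load data; a minimal training skeleton in the same style might tie them together as follows (sketch only: the network architecture is a placeholder, and `tnt.Model` wraps a Keras model for data-parallel training, as in Example #4 below):

# hypothetical continuation of Example #2
args = parse_args()
(x_train, y_train), (x_val, y_val), _ = mnist_as_np_arrays(
    args.train_size, args.val_size, args.test_size)
x_train, x_val = x_train / 255.0, x_val / 255.0  # scale pixels to [0, 1]

model = keras.Sequential([
    keras.layers.Flatten(input_shape=(28, 28)),
    keras.layers.Dense(128, activation='relu'),
    keras.layers.Dense(10, activation='softmax'),
])
model = tnt.Model(model)  # distribute training across ranks
model.compile(optimizer=keras.optimizers.SGD(learning_rate=args.learning_rate),
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])
model.fit(x_train, y_train,
          validation_data=(x_val, y_val),
          batch_size=args.batch_size,
          epochs=args.number_epochs)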
Example #3
from argparse import ArgumentParser

BATCH_SIZE = 32

parser = ArgumentParser()
parser.add_argument('--hpdlf',
                    dest='use_hpdlf',
                    action='store_true',
                    default=False)

args = parser.parse_args()

use_hpdlf = args.use_hpdlf
print(use_hpdlf)

if use_hpdlf:
    import tarantella
    tarantella.init(0)
    rank = tarantella.get_rank()
    comm_size = tarantella.get_size()
else:
    rank = 0
    comm_size = 1

print(('RANK:      {}\n' 'COMM_SIZE: {}').format(rank, comm_size))

# ReversibleSequential and IN_SHAPE are defined elsewhere in the project (not shown)
model = ReversibleSequential(*IN_SHAPE)

for k in range(5):
    kwargs = {
        'affine_clamping': 1.0,
        'global_affine_init': 0.85,
        'global_affine_type': 'SOFTPLUS',
    }  # the rest of the loop body is truncated in the original listing
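
A common pattern with this optional-import setup is to gate side effects on a single rank, as the next example does with `is_primary_node`. A minimal sketch:

# only the primary rank should log or write checkpoints; all ranks train
if rank == 0:
    print(f'training on {comm_size} rank(s)')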
Example #4
import os

import numpy as np
import tensorflow as tf
from tensorflow import keras as kr


def train(args):
    # `args` is a nested, ConfigParser-style mapping of strings; leaf values
    # are evaluated with eval()
    use_tarantella = eval(args['training']['use_tarantella'])
    ndims_tot = np.prod(eval(args['data']['data_dimensions']))
    output_dir = args['checkpoints']['output_dir']
    sched_milestones = eval(args['training']['milestones_lr_decay'])
    n_epochs = eval(args['training']['N_epochs'])
    optimizer_kwargs = eval(args['training']['optimizer_kwargs'])
    optimizer_type = args['training']['optimizer']
    optimizer_lr = eval(args['training']['lr'])

    if use_tarantella:
        import tarantella
        # no argument (otherwise: ranks per node)
        tarantella.init()
        node_rank = tarantella.get_rank()
        nodes_number = tarantella.get_size()
    else:
        node_rank = 0
        nodes_number = 1
    is_primary_node = (node_rank == 0)

    args['training']['rank'] = repr(node_rank)
    args['training']['comm_size'] = repr(nodes_number)

    # build_model and Dataset are project-specific helpers (not shown)
    model = build_model(args)
    data = Dataset(args)

    print(f'NODE_RANK {node_rank}')
    print(f'N_NODES {nodes_number}')
    print(f'IS_PRIMARY_NODE {str(is_primary_node).upper()}', flush=True)

    # negative log-likelihood of a normalizing flow, split into two Keras
    # losses: a Gaussian prior term on the latent z and a log-det-Jacobian term
    def nll_loss_z_part(y, z):
        zz = tf.math.reduce_mean(z**2)
        return 0.5 * zz

    def nll_loss_jac_part(y, jac):
        return -tf.math.reduce_mean(jac) / ndims_tot

    def lr_sched(ep, lr):
        if ep in sched_milestones:
            return 0.1 * lr
        return lr

    # TODO: should this only be for one node, or for each?
    lr_scheduler_callback = kr.callbacks.LearningRateScheduler(
        lr_sched, verbose=is_primary_node)

    callbacks = [lr_scheduler_callback, kr.callbacks.TerminateOnNaN()]

    if is_primary_node:
        # checkpoint_callback = kr.callbacks.ModelCheckpoint(
        #     filepath=os.path.join(output_dir, 'checkpoint_best.hdf5'),
        #     save_best_only=True,
        #     save_weights_only=True,
        #     mode='min',
        #     verbose=is_primary_node)

        loss_log_callback = kr.callbacks.CSVLogger(
            os.path.join(output_dir, 'losses.dat'), separator=' ')

        # callbacks.append(checkpoint_callback)
        callbacks.append(loss_log_callback)

    try:
        optimizer_type = {
            'ADAM': kr.optimizers.Adam,
            'SGD': kr.optimizers.SGD
        }[optimizer_type]
    except KeyError:
        optimizer_type = eval(optimizer_type)

    optimizer = optimizer_type(optimizer_lr, **optimizer_kwargs)

    if use_tarantella:
        model = tarantella.Model(model)

    model.compile(loss=[nll_loss_z_part, nll_loss_jac_part],
                  optimizer=optimizer,
                  run_eagerly=False)
    model.build((128, 32, 32, 3))

    history = model.fit(
        data.train_dataset,
        epochs=n_epochs,
        verbose=is_primary_node,
        callbacks=callbacks,
        validation_data=(data.test_dataset if is_primary_node else None))
    return history
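
For reference, `train` reads a nested, ConfigParser-style mapping whose leaf values are strings passed through `eval`. A hypothetical configuration covering the keys used above:

args = {
    'training': {
        'use_tarantella': 'True',
        'milestones_lr_decay': '[40, 80]',
        'N_epochs': '100',
        'optimizer_kwargs': '{}',
        'optimizer': 'ADAM',
        'lr': '1e-3',
    },
    'data': {'data_dimensions': '(32, 32, 3)'},
    'checkpoints': {'output_dir': './checkpoints'},
}
train(args)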