Example #1
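(For context: these snippets assume module-level imports along the lines of the OpenAI Glow codebase, roughly: import tensorflow as tf, import horovod.tensorflow as hvd, plus the project-local modules tfops (imported as Z) and optim; memory_saving_gradients is imported lazily inside the function when gradient checkpointing is enabled.)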
def abstract_model_xy(sess, hps, feeds, train_iterator, test_iterator,
                      data_init, lr, f_loss):

    # == Create class with static fields and methods
    class m(object):
        pass

    m.sess = sess
    m.feeds = feeds
    m.lr = lr

    # === Loss and optimizer
    loss_train, stats_train = f_loss(train_iterator, True)
    all_params = tf.trainable_variables()
    if hps.gradient_checkpointing == 1:
        from memory_saving_gradients import gradients
        gs = gradients(loss_train, all_params)
    else:
        gs = tf.gradients(loss_train, all_params)

    optimizer = {
        'adam': optim.adam,
        'adamax': optim.adamax,
        'adam2': optim.adam2
    }[hps.optimizer]

    train_op, polyak_swap_op, ema = optimizer(all_params,
                                              gs,
                                              alpha=lr,
                                              hps=hps)
    if hps.direct_iterator:
        m.train = lambda _lr: sess.run([train_op, stats_train], {lr: _lr})[1]
    else:

        def _train(_lr):
            _x, _y = train_iterator()
            return sess.run([train_op, stats_train], {
                feeds['x']: _x,
                feeds['y']: _y,
                lr: _lr
            })[1]

        m.train = _train

    m.polyak_swap = lambda: sess.run(polyak_swap_op)

    # === Testing
    loss_test, stats_test = f_loss(test_iterator, False, reuse=True)
    if hps.direct_iterator:
        m.test = lambda: sess.run(stats_test)
    else:

        def _test():
            _x, _y = test_iterator()
            return sess.run(stats_test, {feeds['x']: _x, feeds['y']: _y})

        m.test = _test

    # === Saving and restoring
    saver = tf.train.Saver()
    saver_ema = tf.train.Saver(ema.variables_to_restore())
    m.save_ema = lambda path: saver_ema.save(
        sess, path, write_meta_graph=False)
    m.save = lambda path: saver.save(sess, path, write_meta_graph=False)
    m.restore = lambda path: saver.restore(sess, path)

    # === Initialize the parameters
    if hps.restore_path != '':
        m.restore(hps.restore_path)
    else:
        with Z.arg_scope([Z.get_variable_ddi, Z.actnorm], init=True):
            results_init = f_loss(None, True, reuse=True)
        sess.run(tf.global_variables_initializer())
        sess.run(results_init, {
            feeds['x']: data_init['x'],
            feeds['y']: data_init['y']
        })
    # sess.run(hvd.broadcast_global_variables(0))

    return m
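A brief usage sketch for the handle returned above (hypothetical driver code, not part of the original example; it assumes the session, hps, feeds, iterators, data_init, the lr placeholder and f_loss have already been built, and that num_train_steps, the learning-rate value and the checkpoint path are placeholders):

model = abstract_model_xy(sess, hps, feeds, train_iterator, test_iterator,
                          data_init, lr, f_loss)
for _ in range(num_train_steps):       # num_train_steps: placeholder step count
    train_stats = model.train(0.001)   # one optimizer step at the given learning rate
model.polyak_swap()                    # swap live weights with their Polyak/EMA averages
test_stats = model.test()              # evaluate with the swapped-in weights
model.save_ema('/tmp/model_ema.ckpt')  # checkpoint the EMA parameters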
Example #2
def abstract_model_xy(sess, hps, feeds, train_iterators, test_iterators, data_inits, lr, f_loss):

    # == Create class with static fields and methods
    class m(object):
        pass
    m.sess = sess
    m.feeds = feeds
    m.lr = lr

    # === Loss and optimizer
    if hps.joint_train:
        (loss_train_A, stats_train_A, eps_flatten_A, loss_train_B, stats_train_B, eps_flatten_B) \
            = f_loss(train_iterators, is_training=True)
    else:
        (loss_train_A, stats_train_A, loss_train_B, stats_train_B) \
            = f_loss(train_iterators, is_training=True)

    all_params = tf.trainable_variables()

    # Get train data op
    def get_train_data():
        x_A, y_A = train_iterators['A']()
        x_B, y_B = train_iterators['B']()
        return x_A, y_A, x_B, y_B
    m.get_train_data = get_train_data

    # A
    with tf.variable_scope('optim_A'):
        params_A = [param for param in all_params if 'A/' in param.name]
        if hps.gradient_checkpointing == 1:
            from memory_saving_gradients import gradients
            gs_A = gradients(loss_train_A, params_A)
        else:
            gs_A = tf.gradients(loss_train_A, params_A)
        m.optimizer_A = optim.Optimizer()
        train_op_A, polyak_swap_op_A, ema_A = m.optimizer_A.adamax(
            params_A, gs_A, alpha=lr, hps=hps)
        if hps.direct_iterator:
            m.train_A = lambda _lr: sess.run(
                [train_op_A, stats_train_A], {lr: _lr})[1]
        else:
            def _train_A(_lr, _x_A, _y_A, _x_B, _y_B):
                return sess.run([train_op_A, stats_train_A], {feeds['x_A']: _x_A,
                                                              feeds['y_A']: _y_A,
                                                              feeds['x_B']: _x_B,
                                                              feeds['y_B']: _y_B,
                                                              lr: _lr})[1]
            m.train_A = _train_A
        m.polyak_swap_A = lambda: sess.run(polyak_swap_op_A)
    # B
    with tf.variable_scope('optim_B'):
        params_B = [param for param in all_params if 'B/' in param.name]
        if hps.gradient_checkpointing == 1:
            from memory_saving_gradients import gradients
            gs_B = gradients(loss_train_B, params_B)
        else:
            gs_B = tf.gradients(loss_train_B, params_B)
        m.optimizer_B = optim.Optimizer()
        train_op_B, polyak_swap_op_B, ema_B = m.optimizer_B.adamax(
            params_B, gs_B, alpha=lr, hps=hps)
        if hps.direct_iterator:
            m.train_B = lambda _lr: sess.run(
                [train_op_B, stats_train_B], {lr: _lr})[1]
        else:
            def _train_B(_lr, _x_A, _y_A, _x_B, _y_B):
                return sess.run([train_op_B, stats_train_B], {feeds['x_A']: _x_A,
                                                              feeds['y_A']: _y_A,
                                                              feeds['x_B']: _x_B,
                                                              feeds['y_B']: _y_B,
                                                              lr: _lr})[1]
            m.train_B = _train_B
        m.polyak_swap_B = lambda: sess.run(polyak_swap_op_B)

    def _train(_lr, _x_A, _y_A, _x_B, _y_B):
        return sess.run([train_op_A, train_op_B, stats_train_A, stats_train_B],
                        {feeds['x_A']: _x_A, feeds['y_A']: _y_A,
                         feeds['x_B']: _x_B, feeds['y_B']: _y_B,
                         lr: _lr})[-2:]
    m.train = _train

    # === Testing
    loss_test_A, stats_test_A, loss_test_B, stats_test_B = f_loss(
        test_iterators, False, reuse=True)
    if hps.direct_iterator:
        m.test_A = lambda: sess.run(stats_test_A)
        m.test_B = lambda: sess.run(stats_test_B)
    else:
        # Get test data op
        def get_test_data():
            x_A, y_A = test_iterators['A']()
            x_B, y_B = test_iterators['B']()
            return x_A, y_A, x_B, y_B
        m.get_test_data = get_test_data

        def _test_A(_x_A, _y_A, _x_B, _y_B):
            return sess.run(stats_test_A, {feeds['x_A']: _x_A,
                                           feeds['y_A']: _y_A,
                                           feeds['x_B']: _x_B,
                                           feeds['y_B']: _y_B})

        def _test_B(_x_A, _y_A, _x_B, _y_B):
            return sess.run(stats_test_B, {feeds['x_A']: _x_A,
                                           feeds['y_A']: _y_A,
                                           feeds['x_B']: _x_B,
                                           feeds['y_B']: _y_B})
        m.test_A = _test_A
        m.test_B = _test_B

    # === Saving and restoring
    with tf.variable_scope('saver_A'):
        saver_A = tf.train.Saver()
        saver_ema_A = tf.train.Saver(ema_A.variables_to_restore())
        m.save_ema_A = lambda path_A: saver_ema_A.save(
            sess, path_A, write_meta_graph=False)
        m.save_A = lambda path_A: saver_A.save(
            sess, path_A, write_meta_graph=False)
        m.restore_A = lambda path_A: saver_A.restore(sess, path_A)

    with tf.variable_scope('saver_B'):
        saver_B = tf.train.Saver()
        saver_ema_B = tf.train.Saver(ema_B.variables_to_restore())
        m.save_ema_B = lambda path_B: saver_ema_B.save(
            sess, path_B, write_meta_graph=False)
        m.save_B = lambda path_B: saver_B.save(
            sess, path_B, write_meta_graph=False)
        m.restore_B = lambda path_B: saver_B.restore(sess, path_B)
        print("After saver")

    # === Initialize the parameters
    if hps.restore_path_A != '':
        m.restore_A(hps.restore_path_A)
    if hps.restore_path_B != '':
        m.restore_B(hps.restore_path_B)
    if hps.restore_path_A == '' and hps.restore_path_B == '':
        with Z.arg_scope([Z.get_variable_ddi, Z.actnorm], init=True):
            results_init = f_loss(None, False, reuse=True, init=True)

        all_params = tf.global_variables()
        params_A = [param for param in all_params if 'A/' in param.name]
        params_B = [param for param in all_params if 'B/' in param.name]
        sess.run(tf.variables_initializer(params_A))
        sess.run(tf.variables_initializer(params_B))
        feeds_dict = {feeds['x_A']: data_inits['A']['x'],
                      feeds['y_A']: data_inits['A']['y'],
                      feeds['x_B']: data_inits['B']['x'],
                      feeds['y_B']: data_inits['B']['y']}
        sess.run(results_init, feeds_dict)
    sess.run(hvd.broadcast_global_variables(0))

    return m
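A similar hypothetical sketch for this joint A/B variant, assuming hps.direct_iterator is false so the feed-based helpers (get_train_data, get_test_data, and the per-model train/test functions) are available; learning rate and checkpoint paths are placeholders:

model = abstract_model_xy(sess, hps, feeds, train_iterators, test_iterators,
                          data_inits, lr, f_loss)
x_A, y_A, x_B, y_B = model.get_train_data()                 # one batch per domain
stats_A, stats_B = model.train(0.001, x_A, y_A, x_B, y_B)   # joint step on both models
x_A, y_A, x_B, y_B = model.get_test_data()
test_stats_A = model.test_A(x_A, y_A, x_B, y_B)
test_stats_B = model.test_B(x_A, y_A, x_B, y_B)
model.save_A('/tmp/model_A.ckpt')                           # per-model checkpoints
model.save_B('/tmp/model_B.ckpt')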
Example #3
def abstract_model_xy(sess, hps, feeds, train_iterator, test_iterator, data_init, lr, f_loss):

    # == Create class with static fields and methods
    class m(object):
        pass
    m.sess = sess
    m.feeds = feeds
    m.lr = lr

    # === Loss and optimizer
    loss_train, stats_train = f_loss(train_iterator, True)
    all_params = tf.trainable_variables()
    if hps.gradient_checkpointing == 1:
        from memory_saving_gradients import gradients
        gs = gradients(loss_train, all_params)
    else:
        gs = tf.gradients(loss_train, all_params)

    optimizer = {'adam': optim.adam, 'adamax': optim.adamax,
                 'adam2': optim.adam2}[hps.optimizer]

    train_op, polyak_swap_op, ema = optimizer(
        all_params, gs, alpha=lr, hps=hps)
    if hps.direct_iterator:
        m.train = lambda _lr: sess.run([train_op, stats_train], {lr: _lr})[1]
    else:
        def _train(_lr):
            _x, _y = train_iterator()
            return sess.run([train_op, stats_train], {feeds['x']: _x,
                                                      feeds['y']: _y, lr: _lr})[1]
        m.train = _train

    m.polyak_swap = lambda: sess.run(polyak_swap_op)

    # === Testing
    loss_test, stats_test = f_loss(test_iterator, False, reuse=True)
    if hps.direct_iterator:
        m.test = lambda: sess.run(stats_test)
    else:
        def _test():
            _x, _y = test_iterator()
            return sess.run(stats_test, {feeds['x']: _x,
                                         feeds['y']: _y})
        m.test = _test

    # === Saving and restoring
    saver = tf.train.Saver()
    saver_ema = tf.train.Saver(ema.variables_to_restore())
    m.save_ema = lambda path: saver_ema.save(
        sess, path, write_meta_graph=False)
    m.save = lambda path: saver.save(sess, path, write_meta_graph=False)
    m.restore = lambda path: saver.restore(sess, path)

    # === Initialize the parameters
    if hps.restore_path != '':
        m.restore(hps.restore_path)
    else:
        with Z.arg_scope([Z.get_variable_ddi, Z.actnorm], init=True):
            results_init = f_loss(None, True, reuse=True)
        sess.run(tf.global_variables_initializer())
        sess.run(results_init, {feeds['x']: data_init['x'],
                                feeds['y']: data_init['y']})
    sess.run(hvd.broadcast_global_variables(0))

    return m