예제 #1
0
def test_format_config(cfg):
    """Check the text that ``_format_config`` renders for the ``cfg`` fixture.

    Line 0 must start with ``Configuration``; lines 1-9 must each contain
    the fragment asserted for them below, in that order.
    """
    cfg_text = _format_config(cfg, ConfigSummary())
    lines = cfg_text.split('\n')
    # Fail with the full rendered text instead of a bare IndexError when the
    # output has fewer lines than this test inspects.
    assert len(lines) >= 10, cfg_text
    assert lines[0].startswith('Configuration')
    # `fragment in line` is the idiomatic equivalent of `line.find(fragment) > -1`.
    assert ' a = 0' in lines[1]
    assert ' b = {}' in lines[2]
    assert ' c:' in lines[3]
    assert ' cA = 3' in lines[4]
    assert ' cB = 4' in lines[5]
    assert ' cC:' in lines[6]
    assert ' cC1 = 6' in lines[7]
    assert ' d:' in lines[8]
    assert ' dA = 8' in lines[9]
예제 #2
0
def test_format_config(cfg):
    """Verify, line by line, the text ``_format_config`` produces for ``cfg``.

    The first line is the ``Configuration`` header; every following line
    must contain the matching entry of ``expected_fragments`` (line 1 gets
    the first entry, line 2 the second, and so on).
    """
    rendered = _format_config(cfg, ConfigSummary()).split('\n')
    assert rendered[0].startswith('Configuration')

    expected_fragments = (
        ' a = 0',
        ' b = {}',
        ' c:',
        ' cA = 3',
        ' cB = 4',
        ' cC:',
        ' cC1 = 6',
        ' d:',
        ' dA = 8',
    )
    for line_no, fragment in enumerate(expected_fragments, start=1):
        assert fragment in rendered[line_no]
예제 #3
0
def test_format_config(cfg):
    """Assert the header and the per-line contents of the formatted ``cfg``.

    ``checks`` maps a 0-based line index of the rendered text to a fragment
    that line must contain; lines are checked in ascending index order.
    """
    output_lines = _format_config(cfg, ConfigSummary()).split("\n")
    assert output_lines[0].startswith("Configuration")

    checks = {
        1: " a = 0",
        2: " b = {}",
        3: " c:",
        4: " cA = 3",
        5: " cB = 4",
        6: " cC:",
        7: " cC1 = 6",
        8: " d:",
        9: " dA = 8",
    }
    for index in sorted(checks):
        assert checks[index] in output_lines[index]
예제 #4
0
def test_format_config(cfg):
    """Check the text that ``_format_config`` renders for the ``cfg`` fixture.

    Line 0 must start with ``Configuration``; lines 1-9 must each contain
    the fragment asserted for them below, in that order.
    """
    cfg_text = _format_config(cfg, ConfigSummary())
    lines = cfg_text.split('\n')
    # Fail with the full rendered text instead of a bare IndexError when the
    # output has fewer lines than this test inspects.
    assert len(lines) >= 10, cfg_text
    assert lines[0].startswith('Configuration')
    # `fragment in line` is the idiomatic equivalent of `line.find(fragment) > -1`.
    assert ' a = 0' in lines[1]
    assert ' b = {}' in lines[2]
    assert ' c:' in lines[3]
    assert ' cA = 3' in lines[4]
    assert ' cB = 4' in lines[5]
    assert ' cC:' in lines[6]
    assert ' cC1 = 6' in lines[7]
    assert ' d:' in lines[8]
    assert ' dA = 8' in lines[9]
예제 #5
0
def test_format_config(cfg):
    """Check that ``_format_config`` lays out ``cfg`` as expected.

    After a ``Configuration`` header line, each (line index, fragment) pair
    in ``expectations`` must hold: the fragment occurs in that line.
    """
    text = _format_config(cfg, ConfigSummary())
    rows = text.split('\n')
    assert rows[0].startswith('Configuration')

    expectations = [
        (1, ' a = 0'),
        (2, ' b = {}'),
        (3, ' c:'),
        (4, ' cA = 3'),
        (5, ' cB = 4'),
        (6, ' cC:'),
        (7, ' cC1 = 6'),
        (8, ' d:'),
        (9, ' dA = 8'),
    ]
    # all() stops at the first mismatch, like a sequence of separate asserts.
    assert all(fragment in rows[idx] for idx, fragment in expectations)
예제 #6
0
def set_up_loging(exp_path, _config, _run, loglevel='INFO'):
    """Prepare an experiment directory and route the run logger into it.

    In order, this:
      1. copies every ``*.py`` file in the working directory and one level
         below it into ``<exp_path>/scources/``, keeping relative paths;
      2. attaches a ``logging.FileHandler`` writing to ``<exp_path>/log.txt``
         to ``_run.run_logger`` and sets that logger's level to ``loglevel``;
      3. saves ``_run.config`` to ``<exp_path>/config.json`` via
         ``save_config`` and logs the formatted configuration.

    ``_config`` is accepted but not used by this function.
    """
    # NOTE: 'scources' (sic) is the on-disk directory name; kept unchanged so
    # existing experiment layouts stay the same.
    sources_dir = os.path.join(exp_path, 'scources')
    log_file = os.path.join(exp_path, 'log.txt')
    config_file = os.path.join(exp_path, 'config.json')

    # Both globs are evaluated before anything is copied, so files copied by
    # this call are never themselves picked up as sources.
    py_files = glob.glob('./*.py') + glob.glob('./*/*.py')
    for py_file in py_files:
        target = os.path.join(sources_dir, py_file[2:])  # strip the leading './'
        mkdir(target)  # helper defined elsewhere; presumably creates target's parent dir
        shutil.copy(py_file, target)

    mkdir(log_file)
    file_handler = logging.FileHandler(log_file)
    file_handler.setFormatter(logging.Formatter(
        fmt='%(asctime)s %(levelname)s: %(message)s',
        datefmt='%m-%d %H:%M:%S',
    ))
    run_logger = _run.run_logger
    run_logger.setLevel(loglevel)
    run_logger.addHandler(file_handler)

    mkdir(config_file)
    save_config(_run.config, run_logger, config_file)
    run_logger.info(_format_config(_run.config, _run.config_modifications))
예제 #7
0
def main(_run):
    """Train the batched model, then re-evaluate its best checkpoint with batch size 1.

    Every epoch the model is trained on the training split and evaluated
    (batched) on the validation and test splits. Whenever the batched
    validation perplexity improves, the model is saved to
    ``./theano_rhn_<dataset>_<seed>_best_model.pkl``. After the last epoch a
    copy of the configuration with ``batch_size = 1`` and
    ``load_model = <that checkpoint>`` is compiled and evaluated on the
    validation and test splits.

    Args:
        _run: sacred ``Run`` object; its ``config`` and
            ``config_modifications`` are formatted and logged.

    Returns:
        The best batched validation perplexity observed during training
        (``inf`` if no epoch ran).
    """
    config = get_config()
    log = get_logger()

    # Brittle: relies on a private sacred helper to get the same text that
    # ex.commands['print_config']() prints.
    from sacred.commands import _format_config
    log.info(_format_config(_run.config, _run.config_modifications))

    train_data, valid_data, test_data, _ = get_raw_data(config.data_path,
                                                        config.dataset)

    log.info('Compiling (batched) model...')
    model = Model(config)
    log.info('Done. Number of parameters: %d' % model.num_params)

    # Each history starts with an inf sentinel, so history[k] belongs to the
    # 1-based epoch k and np.argmin() over it is directly an epoch number.
    train_history = [np.inf]
    valid_history = [np.inf]
    test_history = [np.inf]
    best_val = np.inf
    save_path = None

    for epoch in range(config.max_max_epoch):
        epoch_no = epoch + 1
        # The learning rate is constant up to epoch index max_epoch - 1; after
        # that it is divided by one more factor of lr_decay every epoch.
        decay = config.lr_decay ** max(epoch - config.max_epoch + 1, 0.0)
        model.assign_lr(config.learning_rate / decay)
        log.info("Epoch: %d Learning rate: %.3f" % (epoch_no, model.lr))

        train_ppl = run_epoch(model, train_data, config,
                              is_train=True, verbose=True, log=log)
        log.info("Epoch: %d Train Perplexity: %.3f, Bits: %.3f" %
                 (epoch_no, train_ppl, np.log2(train_ppl)))

        valid_ppl = run_epoch(model, valid_data, config, is_train=False)
        log.info("Epoch: %d Valid Perplexity (batched): %.3f, Bits: %.3f" %
                 (epoch_no, valid_ppl, np.log2(valid_ppl)))

        test_ppl = run_epoch(model, test_data, config, is_train=False)
        log.info("Epoch: %d Test Perplexity (batched): %.3f, Bits: %.3f" %
                 (epoch_no, test_ppl, np.log2(test_ppl)))

        train_history.append(train_ppl)
        valid_history.append(valid_ppl)
        test_history.append(test_ppl)

        if valid_ppl < best_val:
            best_val = valid_ppl
            log.info("Best Batched Valid Perplexity improved to %.03f" %
                     best_val)
            save_path = ('./theano_rhn_' + config.dataset + '_' +
                         str(config.seed) + '_best_model.pkl')
            model.save(save_path)
            log.info("Saved to: %s" % save_path)

    log.info("Training is over.")
    best_epoch = np.argmin(valid_history)
    best_valid = valid_history[best_epoch]
    log.info("Best Batched Validation Perplexity %.03f (Bits: %.3f) was at Epoch %d" %
             (best_valid, np.log2(best_valid), best_epoch))
    log.info("Training Perplexity at this Epoch was %.03f, Bits: %.3f" %
             (train_history[best_epoch], np.log2(train_history[best_epoch])))
    log.info("Batched Test Perplexity at this Epoch was %.03f, Bits: %.3f" %
             (test_history[best_epoch], np.log2(test_history[best_epoch])))

    # Re-evaluate the saved checkpoint without batching.
    # NOTE(review): if no epoch ever improved on best_val (zero epochs, or NaN
    # perplexities), save_path is still None here; confirm how Model treats
    # load_model=None.
    eval_config = deepcopy(config)
    eval_config.batch_size = 1
    eval_config.load_model = save_path

    log.info('Compiling (non-batched) model...')
    eval_model = Model(eval_config)
    log.info('Done. Number of parameters: %d' % eval_model.num_params)

    log.info("Testing on non-batched Valid ...")
    full_valid_ppl = run_epoch(eval_model, valid_data, eval_config,
                               is_train=False, verbose=True, log=log)
    log.info("Full Valid Perplexity: %.3f, Bits: %.3f" %
             (full_valid_ppl, np.log2(full_valid_ppl)))

    log.info("Testing on non-batched Test ...")
    full_test_ppl = run_epoch(eval_model, test_data, eval_config,
                              is_train=False, verbose=True, log=log)
    log.info("Full Test Perplexity: %.3f, Bits: %.3f" %
             (full_test_ppl, np.log2(full_test_ppl)))

    return best_valid
예제 #8
0
파일: train.py 프로젝트: qmiwang/deeprop
def main(seed, save_checkpoint_path, network_prefix, network, _run):
    """Train the configured network, checkpointing and evaluating every epoch.

    Output directory: ``<save_checkpoint_path>/<network_prefix>-<network>/<seed>``.
    It receives ``data.pkl`` (the train/test split) and one
    ``epoch-<k>.ckpt`` checkpoint per epoch.
    NOTE(review): the directory is assumed to exist already; nothing here
    creates it.

    Args:
        seed: seed for ``np.random``; also the last path component of the
            output directory.
        save_checkpoint_path: root directory for outputs.
        network_prefix: prefix of the output sub-directory name.
        network: suffix of the output sub-directory name.
            NOTE(review): the graph itself is built from
            ``net_config.network``, not this argument; presumably both come
            from the same sacred config — confirm.
        _run: sacred ``Run``; its ``config`` / ``config_modifications`` are
            formatted and logged.
    """
    net_config = get_config()
    log_path = '%s/%s-%s/%d' % (save_checkpoint_path, network_prefix, network,
                                seed)

    log = get_logger()
    from sacred.commands import _format_config
    config_str = _format_config(_run.config, _run.config_modifications)
    log.info(config_str)
    np.random.seed(seed)

    arg_scope = inception_arg_scope(batch_norm_decay=0.9)
    initializer = tf.random_uniform_initializer(-net_config.init_scale,
                                                net_config.init_scale)

    # The fifth value returned by get_data is not used by this function.
    train_datas, train_labels, test_datas, test_labels, _ = get_data(
        net_config)
    # Keep a copy of the exact split next to the checkpoints. The context
    # manager guarantees the pickle is flushed and the file handle closed.
    with open(log_path + '/data.pkl', "wb") as data_file:
        cpkl.dump((train_datas, train_labels, test_datas, test_labels),
                  data_file)

    img_aug = ImageAug(net_config.img_height, net_config.img_width,
                       net_config.img_channels)
    train_data = RetinaDataIter(train_datas,
                                net_config.imgs_per_sample,
                                train_labels,
                                net_config.batch_size,
                                net_config.img_height,
                                net_config.img_width,
                                net_config.img_channels,
                                image_aug=img_aug)
    # NOTE(review): the test iterator is also given img_aug; confirm that
    # RetinaDataIter does not apply random augmentation to evaluation data.
    test_data = RetinaDataIter(test_datas,
                               net_config.imgs_per_sample,
                               test_labels,
                               net_config.batch_size,
                               net_config.img_height,
                               net_config.img_width,
                               net_config.img_channels,
                               image_aug=img_aug)

    gpu_config = tf.ConfigProto()
    gpu_config.gpu_options.allow_growth = True
    gpu_config.gpu_options.visible_device_list = str(net_config.gpu)
    gpu_config.gpu_options.per_process_gpu_memory_fraction = 0.95
    with tf.Graph().as_default(), tf.Session(config=gpu_config) as sess:
        tf.logging.set_verbosity(tf.logging.INFO)
        # Training graph: batch norm in training mode, creates the variables.
        with slim.arg_scope(arg_scope):
            with slim.arg_scope([slim.batch_norm], is_training=True):
                with tf.name_scope('train'):
                    with tf.variable_scope(net_config.network,
                                           reuse=None,
                                           initializer=initializer) as scope:
                        train_model = globals()[net_config.network](
                            is_training=True, config=net_config, scope=scope)
        # Evaluation graph: same variable scope with reuse=True, so it shares
        # the training graph's variables; batch norm in inference mode.
        with slim.arg_scope(arg_scope):
            with slim.arg_scope([slim.batch_norm], is_training=False):
                with tf.name_scope('test'):
                    with tf.variable_scope(net_config.network,
                                           reuse=True,
                                           initializer=initializer) as scope:
                        test_model = globals()[net_config.network](
                            is_training=False, config=net_config, scope=scope)

        log.info('[ Loading checkpoint ... ]')
        init_fn, restore_vars = get_init_fn(
            net_config.checkpoint_path, net_config.checkpoint_exclude_scopes)
        init_fn(sess)

        # Initialize every global variable that is not in restore_vars
        # (presumably the variables init_fn did not restore from the checkpoint).
        log.info('init left...')
        uninitialized_vars = set(tf.global_variables()) - set(restore_vars)
        sess.run(tf.variables_initializer(uninitialized_vars))
        # max_to_keep equals the epoch count, so every per-epoch checkpoint is kept.
        saver = tf.train.Saver(max_to_keep=net_config.num_epochs)

        for epoch in range(net_config.num_epochs):
            time_start = time.time()
            acc, cost = train_epoch(sess, train_model, train_data, net_config,
                                    epoch, log)
            time_end = time.time()
            log.info(
                'Epoch [%d] train acc = [%.4f], cost = [%.6f], time = %.2f' %
                (epoch, acc, cost, time_end - time_start))
            save_path = saver.save(sess,
                                   "%s/epoch-%d.ckpt" % (log_path, epoch))
            log.info("Saved to:%s" % save_path)
            log_msg = test_epoch(sess, test_model, test_data, net_config,
                                 epoch)
            log.info('Epoch [%d] test on test %s' % (epoch, log_msg))