Example #1
File: train_vae.py Project: yyht/igr
def start_all_logging_instruments(hyper, results_path, test_images):
    writer = tf.summary.create_file_writer(logdir=results_path)
    logger = setup_logger(log_file_name=append_timestamp_to_file(file_name=results_path + '/loss.log',
                                                                 termination='.log'),
                          logger_name=append_timestamp_to_file('logger', termination=''))
    log_all_hyperparameters(hyper=hyper, logger=logger)
    plot_originals(test_images=test_images, results_path=results_path)
    return writer, logger
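Every example on this page leans on the helper append_timestamp_to_file, whose definition is not shown. The following is a hypothetical reconstruction, inferred only from the call sites on this page (a file name plus a termination such as '.log', '.h5', or ''); the actual implementation in yyht/igr may differ:

import datetime

def append_timestamp_to_file(file_name, termination='.log'):
    # Hypothetical sketch: insert a timestamp between the file stem and its
    # extension, e.g. 'loss.log' -> 'loss_20210101_123000.log'. With an empty
    # termination the timestamp is simply appended to file_name.
    stem = file_name[:-len(termination)] if termination else file_name
    stamp = datetime.datetime.now().strftime('_%Y%m%d_%H%M%S')
    return stem + stamp + termination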
Example #2
def start_all_logging_instruments(hyper, test_images):
    results_path = determine_path_to_save_results(
        model_type=hyper['model_type'], dataset_name=hyper['dataset_name'])
    if not os.path.exists(results_path):
        os.mkdir(results_path)
    logger = setup_logger(log_file_name=append_timestamp_to_file(
        file_name=results_path + '/loss.log', termination='.log'),
                          logger_name=append_timestamp_to_file('logger',
                                                               termination=''))
    log_all_hyperparameters(hyper=hyper, logger=logger)
    plot_originals(test_images=test_images, results_path=results_path)
    return logger, results_path
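A minimal usage sketch for this variant, assuming hyper carries at least the two keys the function reads (model_type and dataset_name) and that test_images is a batch of already-loaded images; the concrete values are invented:

hyper = {'model_type': 'GS', 'dataset_name': 'mnist'}
logger, results_path = start_all_logging_instruments(hyper=hyper,
                                                     test_images=test_images)
logger.info(f'Results saved under {results_path}')

Note that os.mkdir only creates the final directory and raises FileNotFoundError if './Log' does not yet exist; os.makedirs(results_path, exist_ok=True) would be the more forgiving choice.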
Example #3
File: train_vae.py Project: yyht/igr
def train_vae_model(vae_opt, model, writer, hyper, train_dataset, test_dataset, logger, results_path,
                    test_images, monitor_gradients=False):

    (iteration_counter, results_file, hyper_file,
     cont_c_linspace, disc_c_linspace) = initialize_vae_variables(results_path=results_path, hyper=hyper)
    grad_monitor_dict = {}
    grad_norm = tf.constant(0., dtype=tf.float32)
    with open(file=hyper_file, mode='wb') as f:
        pickle.dump(obj=hyper, file=f)

    with writer.as_default():
        initial_time = time.time()
        for epoch in range(1, hyper['epochs'] + 1):
            t0 = time.time()
            train_loss_mean = tf.keras.metrics.Mean()
            for x_train in train_dataset.take(hyper['iter_per_epoch']):
                output = vae_opt.compute_gradients(x=x_train)
                gradients, loss, log_px_z, kl, kl_n, kl_d = output
                vae_opt.apply_gradients(gradients=gradients)
                iteration_counter += 1
                train_loss_mean(loss)
                append_train_summaries(tracking_losses=output, iteration_counter=iteration_counter)
                update_regularization_channels(vae_opt=vae_opt, iteration_counter=iteration_counter,
                                               disc_c_linspace=disc_c_linspace,
                                               cont_c_linspace=cont_c_linspace)

            if monitor_gradients:
                grad_norm = vae_opt.monitor_parameter_gradients_at_psi(x=x_train)
                grad_monitor_dict.update({iteration_counter: grad_norm.numpy()})
                with open(file='./Results/gradients_' + str(epoch) + '.pkl', mode='wb') as f:
                    pickle.dump(obj=grad_monitor_dict, file=f)

            t1 = time.time()

            evaluate_progress_in_test_set(epoch=epoch, test_dataset=test_dataset, vae_opt=vae_opt,
                                          hyper=hyper, logger=logger, iteration_counter=iteration_counter,
                                          train_loss_mean=train_loss_mean, time_taken=t1 - t0,
                                          grad_norm=grad_norm)

            if epoch % 10 == 0:
                model.save_weights(filepath=append_timestamp_to_file(results_file, '.h5'))
                plot_reconstructions_samples_and_traversals(model=model, hyper=hyper, epoch=epoch,
                                                            results_path=results_path,
                                                            test_images=test_images, vae_opt=vae_opt)
            writer.flush()

        final_time = time.time()
        logger.info(f'Total training time {final_time - initial_time: 4.1f} secs')
        logger.info(f'Final temp {vae_opt.temp.numpy(): 4.5f}')
        results_file = append_timestamp_to_file(file_name=results_file, termination='.h5')
        model.save_weights(filepath=results_file)
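A hedged sketch of how Examples #1 and #3 might be wired together in a driver. Everything here (the hyper dict, the datasets, and the model and vae_opt objects) is assumed to be built elsewhere in the project; only the call pattern follows from the signatures above:

writer, logger = start_all_logging_instruments(hyper=hyper,
                                               results_path=results_path,
                                               test_images=test_images)
train_vae_model(vae_opt=vae_opt, model=model, writer=writer, hyper=hyper,
                train_dataset=train_dataset, test_dataset=test_dataset,
                logger=logger, results_path=results_path,
                test_images=test_images, monitor_gradients=False)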
Example #4
def train_sop(sop_optimizer, hyper, train_dataset, test_dataset, logger):
    initial_time = time.time()
    iteration_counter = 0
    for epoch in range(1, hyper['epochs'] + 1):
        train_mean_loss = tf.keras.metrics.Mean()
        tic = time.time()
        for x_train in train_dataset.take(hyper['iter_per_epoch']):
            x_train_lower = x_train[:, 14:, :, :]
            x_train_upper = x_train[:, :14, :, :]

            gradients, loss = sop_optimizer.compute_gradients_and_loss(
                x_upper=x_train_upper, x_lower=x_train_lower)
            sop_optimizer.apply_gradients(gradients=gradients)
            train_mean_loss(loss)
            iteration_counter += 1
        toc = time.time()
        evaluate_progress_in_test_set(epoch=epoch,
                                      sop_optimizer=sop_optimizer,
                                      test_dataset=test_dataset,
                                      logger=logger,
                                      hyper=hyper,
                                      iteration_counter=iteration_counter,
                                      time_taken=toc - tic,
                                      train_mean_loss=train_mean_loss)

    final_time = time.time()
    logger.info(f'Total training time {final_time - initial_time: 4.1f} secs')
    results_file = f'./Log/model_weights_{sop_optimizer.model.model_type}.h5'
    results_file = append_timestamp_to_file(file_name=results_file,
                                            termination='.h5')
    sop_optimizer.model.save_weights(filepath=results_file)
Example #5
def save_final_results(nets, logger, results_file, initial_time, temp):
    final_time = time.time()
    logger.info(f'Total training time {final_time - initial_time: 4.1f} secs')
    logger.info(f'Final temp {temp: 4.5f}')
    results_file = append_timestamp_to_file(file_name=results_file,
                                            termination='.h5')
    nets.save_weights(filepath=results_file)
Example #6
def run_sop(hyper, results_path, data):
    train_dataset, test_dataset = data

    sop_optimizer = setup_sop_optimizer(hyper=hyper)

    logger = setup_logger(log_file_name=append_timestamp_to_file(
        file_name=results_path + f'/loss_{sop_optimizer.model.model_type}.log',
        termination='.log'))
    log_all_hyperparameters(hyper=hyper, logger=logger)
    train_sop(sop_optimizer=sop_optimizer,
              hyper=hyper,
              train_dataset=train_dataset,
              test_dataset=test_dataset,
              logger=logger)
Example #7
def train_sop(sop_optimizer, hyper, train_dataset, test_dataset, logger):
    initial_time = time.time()
    iteration_counter = 0
    sample_size = hyper['sample_size']
    for epoch in range(1, hyper['epochs'] + 1):
        train_mean_loss = tf.keras.metrics.Mean()
        tic = time.time()
        for x_train in train_dataset.take(hyper['iter_per_epoch']):
            x_train_lower = x_train[:, 14:, :, :]
            x_train_upper = x_train[:, :14, :, :]

            gradients, loss = sop_optimizer.compute_gradients_and_loss(
                x_upper=x_train_upper,
                x_lower=x_train_lower,
                sample_size=sample_size)
            sop_optimizer.apply_gradients(gradients=gradients)
            update_learning_rate(sop_optimizer, epoch, iteration_counter,
                                 hyper)
            train_mean_loss(loss)
            iteration_counter += 1
        time_taken = time.time() - tic
        if epoch % hyper['check_every'] == 0 or epoch == hyper['epochs']:
            evaluate_progress(epoch=epoch,
                              sop_optimizer=sop_optimizer,
                              test_dataset=test_dataset,
                              train_dataset=train_dataset,
                              logger=logger,
                              hyper=hyper,
                              train_mean_loss=train_mean_loss,
                              iteration_counter=iteration_counter,
                              tic=tic)
        else:
            logger.info(
                f'Epoch {epoch:4d} || Test_Recon 0.00000e+00 || '
                f'Train_Recon {train_mean_loss.result().numpy():2.3e} || '
                f'Temp {sop_optimizer.model.temp:2.1e} || '
                f'{sop_optimizer.model.model_type} || '
                f'{sop_optimizer.optimizer.learning_rate.numpy():1.1e} || '
                f'Time {time_taken:4.1f} sec')

    final_time = time.time()
    logger.info(f'Total training time {final_time - initial_time: 4.1f} secs')
    results_file = f'./Log/model_weights_{sop_optimizer.model.model_type}.h5'
    results_file = append_timestamp_to_file(file_name=results_file,
                                            termination='.h5')
    sop_optimizer.model.save_weights(filepath=results_file)
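Compared with Example #4, this variant threads a sample_size through the loss computation and adjusts the learning rate on every step. update_learning_rate is not defined in these excerpts; below is a hypothetical step-decay sketch consistent with its call signature. The key 'lr_decay_every' and the 0.5 factor are assumptions, not the project's actual schedule; only the optimizer.learning_rate attribute is confirmed by the logging line above:

def update_learning_rate(sop_optimizer, epoch, iteration_counter, hyper):
    # Hypothetical step decay: halve the rate every 'lr_decay_every' iterations.
    if iteration_counter > 0 and iteration_counter % hyper['lr_decay_every'] == 0:
        lr = sop_optimizer.optimizer.learning_rate
        lr.assign(lr * 0.5)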
Example #8
def run_sop(hyper, results_path):
    tf.random.set_seed(seed=hyper['seed'])
    data = load_mnist_sop_data(batch_n=hyper['batch_size'])
    train_dataset, test_dataset = data

    sop_optimizer = setup_sop_optimizer(hyper=hyper)

    model_type = sop_optimizer.model.model_type
    log_path = results_path + f'/loss_{model_type}.log'
    logger = setup_logger(log_file_name=append_timestamp_to_file(
        file_name=log_path, termination='.log'),
                          logger_name=model_type + str(hyper['seed']))
    log_all_hyperparameters(hyper=hyper, logger=logger)
    save_hyper(hyper)
    train_sop(sop_optimizer=sop_optimizer,
              hyper=hyper,
              train_dataset=train_dataset,
              test_dataset=test_dataset,
              logger=logger)
Example #9
def determine_path_to_save_results(model_type, dataset_name):
    results_path = './Log/' + dataset_name + '_' + model_type + append_timestamp_to_file(
        '', termination='')
    return results_path
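For concreteness, a hypothetical call with invented arguments; under the reconstruction of append_timestamp_to_file sketched after Example #1, the result would look like:

results_path = determine_path_to_save_results(model_type='GS',
                                              dataset_name='mnist')
# e.g. './Log/mnist_GS_20210101_123000'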
Example #10
def save_intermediate_results(epoch, vae_opt, test_images, hyper, results_file,
                              results_path):
    if epoch % hyper['save_every'] == 0:
        vae_opt.nets.save_weights(
            filepath=append_timestamp_to_file(results_file, '.h5'))
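A minimal sketch of where this checkpoint helper would sit in a training loop; the surrounding names are borrowed from Example #3, and hyper['save_every'] is assumed to be set:

for epoch in range(1, hyper['epochs'] + 1):
    ...  # one epoch of gradient updates, as in Example #3
    save_intermediate_results(epoch=epoch, vae_opt=vae_opt,
                              test_images=test_images, hyper=hyper,
                              results_file=results_file,
                              results_path=results_path)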