Exemplo n.º 1
0
def write_stg2_files(write_dir, samples_to_write, shard_size, batch_size,
                     devices_to_use, model_dir, data_dir, results_dir,
                     model_config, schedule_config, use_mixed_precision,
                     use_xla, show_steps):
    """Generate stage-2 training data with the latest checkpoint and save it in shards.

    If ``write_dir`` already exists and is non-empty, nothing is generated.
    A message is printed and the existing files are left for training to use.
    Otherwise the directory is created (including missing parents) if needed.
    A ``ModelManager`` is then built in ``'DDIM'`` show mode and loaded from
    its last checkpoint. ``samples_to_write`` examples are then written as
    ``data_x_{i}.npy`` / ``data_y_{i}.npy`` pairs of at most ``shard_size``
    examples each. The last shard holds the remainder, if any.

    Args:
        write_dir: Directory the shards are written to.
        samples_to_write: Total number of examples to generate.
        shard_size: Maximum number of examples per saved shard.
        batch_size: Maximum number of examples requested from
            ``manager.generate_samples`` per call.
        devices_to_use, model_dir, data_dir, results_dir, model_config,
        schedule_config, use_mixed_precision, use_xla, show_steps:
            Forwarded unchanged to ``ModelManager``.

    Raises:
        NotADirectoryError: If ``write_dir`` exists but is not a directory.

    NOTE(review): if a previous run was interrupted, ``write_dir`` is already
    non-empty. This call will then return early and training will use the
    partial set of shards.
    """
    if os.path.isdir(write_dir):
        if os.listdir(write_dir):
            print(
                "The stage 2 data directory is not empty, so training will use the files in here."
            )
            return
    elif os.path.exists(write_dir):
        # Fail before the (expensive) model load instead of inside np.save.
        raise NotADirectoryError(
            "Stage 2 data path exists but is not a directory: {}".format(
                write_dir))
    else:
        os.makedirs(write_dir)

    show_mode = 'DDIM'
    manager = ModelManager(devices_to_use, model_dir, data_dir, results_dir,
                           model_config, schedule_config, use_mixed_precision,
                           use_xla, show_mode, show_steps)
    num = manager._get_last_ckpt_num()
    manager.load_models(num)
    h, w = manager.ema_model.h, manager.ema_model.w

    # Full shards first, then one smaller shard for any remainder.
    n_full_shards, remainder = divmod(samples_to_write, shard_size)
    shard_sizes = [shard_size] * n_full_shards
    if remainder != 0:
        shard_sizes.append(remainder)

    n_total_ex = 0
    for i, ss in enumerate(shard_sizes):
        n_full_batches, shard_rem = divmod(ss, batch_size)
        batch_sizes = [batch_size] * n_full_batches
        if shard_rem != 0:
            batch_sizes.append(shard_rem)

        # Collect batches and concatenate once: repeated np.concatenate inside
        # the loop copies the growing shard on every batch (quadratic).
        # The empty seed arrays keep the same dtype promotion as before
        # (float16 for x, uint8 for y, promoted with the generated dtypes).
        x_parts = [np.zeros((0, h, w, 3)).astype('float16')]
        y_parts = [np.zeros((0, h, w, 3)).astype('uint8')]
        for bs in batch_sizes:
            # presumably (num_samples, batch_size) — TODO confirm signature
            inps, outs = manager.generate_samples(bs, bs, verbose=False)
            x_parts.append(inps)
            y_parts.append(outs)
        data_x = np.concatenate(x_parts)
        data_y = np.concatenate(y_parts)
        del x_parts, y_parts

        # Sanity check on the generator output, not input validation.
        assert data_x.shape[0] == ss and data_x.shape == data_y.shape
        x_savepath = os.path.join(write_dir, 'data_x_{}'.format(i))
        y_savepath = os.path.join(write_dir, 'data_y_{}'.format(i))
        np.save(x_savepath, data_x)
        np.save(y_savepath, data_y)
        n_total_ex += len(data_x)
        print("Finished writing {} examples to {}".format(
            n_total_ex, write_dir))
Exemplo n.º 2
0
        print("Training is complete.")

    elif args.option == 'eval':
        stg2 = not bool(args.no_stg2)
        if stg2:
            show_mode = 'ONE'
        print("SHOW MODE: {}".format(show_mode))
        from model_manager import ModelManager
        evaluator = ModelManager(devices_to_use, model_dir, data_dir, results_dir, model_config,
         schedule_config, use_mixed_precision, use_xla, show_mode, show_steps)
        
        num = evaluator._get_last_ckpt_num(stg2=stg2)
        print("Restoring model {}k".format(num))
        evaluator.load_models(num, stg2=stg2)
        

        n_ex = args.eval_examples

        if n_ex > 64:
            print("Evaluation is only supported for 64 or fewer examples. Reducing the number of examples to generate to 64...")
            n_ex = 64
        _, outputs = evaluator.generate_samples(n_ex, n_ex, verbose=False)

        from utils import save_samples
        save_path = args.figure_path
        save_samples(outputs, save_path)        
        print("Evaluation is complete. Samples have been saved to {}.".format(save_path))

    else:
        raise ValueError("When running from command line, the option argument should be either 'train' or 'eval'.")