def _split_sizes(total, chunk):
    """Split ``total`` items into runs of ``chunk`` plus one trailing partial run.

    E.g. ``_split_sizes(10, 4) == [4, 4, 2]``, ``_split_sizes(8, 4) == [4, 4]``
    and ``_split_sizes(0, 4) == []``.
    """
    n_full, rem = divmod(total, chunk)
    return [chunk] * n_full + ([rem] if rem else [])


def write_stg2_files(write_dir, samples_to_write, shard_size, batch_size,
                     devices_to_use, model_dir, data_dir, results_dir,
                     model_config, schedule_config, use_mixed_precision,
                     use_xla, show_steps):
    """Generate stage-2 training pairs with the latest checkpoint and save them as .npy shards.

    If ``write_dir`` already exists and is non-empty, nothing is generated: a
    message is printed and the function returns, so the existing files are
    used for training instead.

    Otherwise a ``ModelManager`` is built in ``'DDIM'`` show mode, the most
    recent checkpoint (``_get_last_ckpt_num()``) is loaded, and
    ``samples_to_write`` examples are generated in calls of at most
    ``batch_size``. They are written in shards of ``shard_size`` examples as
    ``data_x_{i}.npy`` / ``data_y_{i}.npy``. The last shard holds the
    remainder when ``samples_to_write`` is not a multiple of ``shard_size``.

    Args:
        write_dir: Output directory. It is created (including parents) if missing.
        samples_to_write: Total number of examples to generate (>= 0).
        shard_size: Number of examples per saved file pair (> 0).
        batch_size: Maximum number of examples per ``generate_samples`` call (> 0).
        devices_to_use, model_dir, data_dir, results_dir, model_config,
        schedule_config, use_mixed_precision, use_xla, show_steps:
            Forwarded unchanged to ``ModelManager``.

    Raises:
        NotADirectoryError: ``write_dir`` exists but is not a directory.
        ValueError: ``samples_to_write`` is negative, or ``shard_size`` /
            ``batch_size`` is not positive.
        RuntimeError: A generated shard does not contain exactly the requested
            number of examples, or its inputs and outputs differ in shape.
    """
    # Fail fast, before the (expensive) model load, if the path is unusable.
    if os.path.exists(write_dir) and not os.path.isdir(write_dir):
        raise NotADirectoryError(
            "Stage 2 data path exists but is not a directory: {}".format(write_dir))
    if os.path.isdir(write_dir) and os.listdir(write_dir):
        print(
            "The stage 2 data directory is not empty, so training will use the files in here."
        )
        return
    if samples_to_write < 0:
        raise ValueError(
            "samples_to_write must be non-negative, got {}".format(samples_to_write))
    if shard_size <= 0 or batch_size <= 0:
        raise ValueError(
            "shard_size and batch_size must be positive, got {} and {}".format(
                shard_size, batch_size))
    os.makedirs(write_dir, exist_ok=True)

    show_mode = 'DDIM'
    manager = ModelManager(devices_to_use, model_dir, data_dir, results_dir,
                           model_config, schedule_config, use_mixed_precision,
                           use_xla, show_mode, show_steps)
    num = manager._get_last_ckpt_num()
    manager.load_models(num)
    h, w = manager.ema_model.h, manager.ema_model.w

    n_total_ex = 0
    for i, ss in enumerate(_split_sizes(samples_to_write, shard_size)):
        # Zero-length seeds pin the trailing (h, w, 3) shape and take part in
        # dtype promotion, so the result matches the original incremental
        # concatenation. Concatenating once avoids re-copying the shard per batch.
        # NOTE(review): np.concatenate promotes dtypes, so if generate_samples
        # returns e.g. float32 inputs, data_x is saved as float32, not float16.
        # Confirm whether an explicit cast is intended.
        x_parts = [np.zeros((0, h, w, 3), dtype='float16')]
        y_parts = [np.zeros((0, h, w, 3), dtype='uint8')]
        for n in _split_sizes(ss, batch_size):
            inps, outs = manager.generate_samples(n, n, verbose=False)
            x_parts.append(inps)
            y_parts.append(outs)
        data_x = np.concatenate(x_parts)
        data_y = np.concatenate(y_parts)

        # Explicit check (not `assert`) so it still runs under `python -O`.
        if data_x.shape[0] != ss or data_x.shape != data_y.shape:
            raise RuntimeError(
                "Shard {} has inconsistent data: expected {} examples, got x {} and y {}".format(
                    i, ss, data_x.shape, data_y.shape))

        # np.save appends the ".npy" extension to these paths.
        np.save(os.path.join(write_dir, 'data_x_{}'.format(i)), data_x)
        np.save(os.path.join(write_dir, 'data_y_{}'.format(i)), data_y)
        n_total_ex += len(data_x)
    print("Finished writing {} examples to {}".format(n_total_ex, write_dir))
print("Training is complete.") elif args.option == 'eval': stg2 = not bool(args.no_stg2) if stg2: show_mode = 'ONE' print("SHOW MODE: {}".format(show_mode)) from model_manager import ModelManager evaluator = ModelManager(devices_to_use, model_dir, data_dir, results_dir, model_config, schedule_config, use_mixed_precision, use_xla, show_mode, show_steps) num = evaluator._get_last_ckpt_num(stg2=stg2) print("Restoring model {}k".format(num)) evaluator.load_models(num, stg2=stg2) n_ex = args.eval_examples if n_ex > 64: print("Evaluation is only supported for 64 or fewer examples. Reducing the number of examples to generate to 64...") n_ex = 64 _, outputs = evaluator.generate_samples(n_ex, n_ex, verbose=False) from utils import save_samples save_path = args.figure_path save_samples(outputs, save_path) print("Evaluation is complete. Samples have been saved to {}.".format(save_path)) else: raise ValueError("When running from command line, the option argument should be either 'train' or 'eval'.")