def stylize(contents_path, styles_path, output_dir, encoder_path, model_path,
            style_ratio=0.6, repeat_pipeline=1, autoencoder_levels=None):
    if isinstance(contents_path, str):
        contents_path = [contents_path]
    if isinstance(styles_path, str):
        styles_path = [styles_path]

    style_ratio = np.clip(style_ratio, 0, 1)

    with tf.Graph().as_default(), tf.Session() as sess:
        # build the dataflow graph
        content = tf.placeholder(tf.float32, shape=(1, None, None, 3), name='content_input')
        style = tf.placeholder(tf.float32, shape=(1, None, None, 3), name='style_input')

        stn = StyleTransferNet(encoder_path, autoencoder_levels)

        output_image = stn.transform(content, style, style_ratio, repeat_pipeline)

        sess.run(tf.global_variables_initializer())

        # restore the trained model and run the style transfer
        saver = tf.train.Saver(var_list=tf.trainable_variables())
        saver.restore(sess, model_path)

        outputs = []
        for content_path in contents_path:
            content_img = get_images(content_path)
            for style_path in styles_path:
                style_img = get_images(style_path)

                result = sess.run(output_image,
                                  feed_dict={content: content_img, style: style_img})
                outputs.append(result[0])

    save_images(outputs, contents_path, styles_path, output_dir)

    return outputs
def stylize(contents_path, styles_path, output_dir, encoder_path, model_path,
            resize_height=None, resize_width=None, suffix=None):
    if isinstance(contents_path, str):
        contents_path = [contents_path]
    if isinstance(styles_path, str):
        styles_path = [styles_path]

    with tf.Graph().as_default(), tf.Session() as sess:
        # build the dataflow graph
        content = tf.placeholder(tf.float32, shape=(1, None, None, 3), name='content')
        style = tf.placeholder(tf.float32, shape=(1, None, None, 3), name='style')

        stn = StyleTransferNet(encoder_path)

        output_image = stn.transform(content, style)

        sess.run(tf.global_variables_initializer())

        # restore the trained model and run the style transfer
        saver = tf.train.Saver()
        saver.restore(sess, model_path)

        outputs = []
        for content_path in contents_path:
            content_img = get_images(content_path, height=resize_height, width=resize_width)
            for style_path in styles_path:
                style_img = get_images(style_path)

                result = sess.run(output_image,
                                  feed_dict={content: content_img, style: style_img})
                outputs.append(result[0])

    save_images(outputs, contents_path, styles_path, output_dir, suffix=suffix)

    return outputs
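# A minimal usage sketch for the stylize() variant above (not part of the original code):
# the image paths, encoder weights, and checkpoint name below mirror the constants used
# elsewhere in this file, but they are placeholders; substitute your own files as needed.
outputs = stylize(contents_path='images/content/karya.jpg',
                  styles_path='images/style/mosaic.jpg',
                  output_dir='outputs',
                  encoder_path='vgg19_normalised.npz',
                  model_path='models/style_weight_2e0.ckpt',
                  resize_height=512, resize_width=512,
                  suffix='stylized')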
def _handler1(content_path, style_path, encoder_path, model_path,
              resize_height=None, resize_width=None,
              output_path=None, prefix=None, suffix=None):
    # get the actual image data, output shape:
    # (num_images, height, width, color_channels)
    content_img = get_images(content_path, resize_height, resize_width)
    style_img = get_images(style_path)

    with tf.Graph().as_default(), tf.Session() as sess:
        # build the dataflow graph
        content = tf.placeholder(tf.float32, shape=content_img.shape, name='content')
        style = tf.placeholder(tf.float32, shape=style_img.shape, name='style')

        stn = StyleTransferNet(encoder_path)

        output_image = stn.transform(content, style)

        # restore the trained model and run the style transfer
        saver = tf.train.Saver()
        saver.restore(sess, model_path)

        output = sess.run(output_image,
                          feed_dict={content: content_img, style: style_img})

    if output_path is not None:
        save_images(content_path, output, output_path, prefix=prefix, suffix=suffix)

    return output
def _handler2(content_path, style_path, encoder_path, model_path,
              output_path=None, prefix=None, suffix=None):
    style_img = get_images(style_path)

    with tf.Graph().as_default(), tf.Session() as sess:
        # build the dataflow graph
        content = tf.placeholder(tf.float32, shape=(1, None, None, 3), name='content')
        style = tf.placeholder(tf.float32, shape=style_img.shape, name='style')

        stn = StyleTransferNet(encoder_path)

        output_image = stn.transform(content, style)

        # restore the trained model and run the style transfer
        saver = tf.train.Saver()
        saver.restore(sess, model_path)

        output = []
        for path in content_path:
            content_img = get_images(path)

            result = sess.run(output_image,
                              feed_dict={content: content_img, style: style_img})
            output.append(result[0])

    if output_path is not None:
        save_images(content_path, output, output_path, prefix=prefix, suffix=suffix)

    return output
def train(ssim_weight, original_imgs_path_name, source_a_imgs_path, source_b_imgs_path_name,
          encoder_path, save_path, model_pre_path, debug=False, logging_period=100):
    if debug:
        from datetime import datetime
        start_time = datetime.now()

    # num_imgs = len(source_a_imgs_path)
    num_imgs = 10000
    source_a_imgs_path = source_a_imgs_path[:num_imgs]
    mod = num_imgs % BATCH_SIZE

    print('Train images number %d.\n' % num_imgs)
    print('Train images samples %s.\n' % str(num_imgs / BATCH_SIZE))

    if mod > 0:
        print('Train set has been trimmed %d samples...\n' % mod)
        source_a_imgs_path = source_a_imgs_path[:-mod]

    # get the training image shape
    HEIGHT, WIDTH, CHANNELS = TRAINING_IMAGE_SHAPE
    INPUT_SHAPE = (BATCH_SIZE, HEIGHT, WIDTH, CHANNELS)

    HEIGHT_OR, WIDTH_OR, CHANNELS_OR = TRAINING_IMAGE_SHAPE_OR
    INPUT_SHAPE_OR = (BATCH_SIZE, HEIGHT_OR, WIDTH_OR, CHANNELS_OR)

    # create the graph
    with tf.Graph().as_default(), tf.Session() as sess:
        original = tf.placeholder(tf.float32, shape=INPUT_SHAPE_OR, name='original')
        source_a = tf.placeholder(tf.float32, shape=INPUT_SHAPE, name='source_a')
        source_b = tf.placeholder(tf.float32, shape=INPUT_SHAPE, name='source_b')
        print('source:', source_a.shape)

        # create the style transfer net
        stn = StyleTransferNet(encoder_path, model_pre_path)

        # pass the two source images to the stn, getting the generated (fused) image
        generated_img = stn.transform(source_a, source_b)

        # # get the target feature maps which is the output of AdaIN
        # target_features = stn.target_features

        # compute the pixel loss
        pixel_loss = tf.reduce_sum(tf.reduce_mean(tf.square(original - generated_img), axis=[1, 2]))
        pixel_loss = pixel_loss / (HEIGHT * WIDTH)

        # compute the SSIM loss
        ssim_loss = 1 - SSIM.tf_ssim(original, generated_img)

        # compute the total loss
        loss = pixel_loss + ssim_weight * ssim_loss

        # Training step
        train_op = tf.train.AdamOptimizer(LEARNING_RATE).minimize(loss)

        sess.run(tf.global_variables_initializer())

        # saver = tf.train.Saver()
        saver = tf.train.Saver(keep_checkpoint_every_n_hours=1)

        # ** Start Training **
        step = 0
        count_loss = 0
        n_batches = int(len(source_a_imgs_path) // BATCH_SIZE)

        if debug:
            elapsed_time = datetime.now() - start_time
            print('\nElapsed time for preprocessing before actually training the model: %s' % elapsed_time)
            print('Now begin to train the model...\n')
            start_time = datetime.now()

        Loss_all = [i for i in range(EPOCHS * n_batches)]

        for epoch in range(EPOCHS):
            np.random.shuffle(source_a_imgs_path)

            for batch in range(n_batches):
                # retrieve a batch of source_a images; the matching source_b and original
                # images are looked up by file name (Windows-style '\\' path separator)
                source_a_path = source_a_imgs_path[batch * BATCH_SIZE:(batch * BATCH_SIZE + BATCH_SIZE)]
                source_a_str = source_a_path[0]
                name_f = source_a_str.find('\\')
                source_image_name = source_a_str[name_f + 1:]
                source_image_name_comm = source_image_name[2:]

                source_b_path = [source_b_imgs_path_name + source_image_name]
                original_path = [original_imgs_path_name + source_image_name_comm]

                original_batch = get_train_images(original_path, crop_height=HEIGHT, crop_width=WIDTH, flag=False)
                source_a_batch = get_train_images(source_a_path, crop_height=HEIGHT, crop_width=WIDTH)
                source_b_batch = get_train_images(source_b_path, crop_height=HEIGHT, crop_width=WIDTH)
                original_batch = original_batch.reshape([1, 256, 256, 1])

                # run the training step
                sess.run(train_op, feed_dict={original: original_batch,
                                              source_a: source_a_batch,
                                              source_b: source_b_batch})
                step += 1

                # if step % 1000 == 0:
                #     saver.save(sess, save_path, global_step=step)

                if debug:
                    is_last_step = (epoch == EPOCHS - 1) and (batch == n_batches - 1)

                    if is_last_step or step % logging_period == 0:
                        elapsed_time = datetime.now() - start_time
                        _pixel_loss, _ssim_loss, _loss = sess.run(
                            [pixel_loss, ssim_loss, loss],
                            feed_dict={original: original_batch,
                                       source_a: source_a_batch,
                                       source_b: source_b_batch})
                        Loss_all[count_loss] = _loss
                        count_loss += 1

                        print('step: %d, total loss: %.3f, elapsed time: %s' % (step, _loss, elapsed_time))
                        print('pixel loss: %.3f' % (_pixel_loss))
                        print('ssim loss : %.3f\n' % (_ssim_loss))
                        # print('pca or shape : ', _pca_or.shape)
                        # print('pca gen shape : ', _pca_gen.shape)

        # ** Done Training & Save the model **
        saver.save(sess, save_path)

        iter_index = [i for i in range(count_loss)]
        plt.plot(iter_index, Loss_all[:count_loss])
        plt.show()

        if debug:
            elapsed_time = datetime.now() - start_time
            print('Done training! Elapsed time: %s' % elapsed_time)
            print('Model is saved to: %s' % save_path)
def train(style_weight, content_imgs_path, style_imgs_path, encoder_path, model_save_path,
          debug=False, logging_period=100):
    if debug:
        from datetime import datetime
        start_time = datetime.now()

    # guarantee the size of content and style images to be a multiple of BATCH_SIZE
    num_imgs = min(len(content_imgs_path), len(style_imgs_path))
    content_imgs_path = content_imgs_path[:num_imgs]
    style_imgs_path = style_imgs_path[:num_imgs]
    mod = num_imgs % BATCH_SIZE
    if mod > 0:
        print('Train set has been trimmed %d samples...\n' % mod)
        content_imgs_path = content_imgs_path[:-mod]
        style_imgs_path = style_imgs_path[:-mod]

    # get the training image shape
    HEIGHT, WIDTH, CHANNELS = TRAINING_IMAGE_SHAPE
    INPUT_SHAPE = (BATCH_SIZE, HEIGHT, WIDTH, CHANNELS)

    # create the graph
    with tf.Graph().as_default(), tf.Session() as sess:
        content = tf.placeholder(tf.float32, shape=INPUT_SHAPE, name='content')
        style = tf.placeholder(tf.float32, shape=INPUT_SHAPE, name='style')

        # create the style transfer net
        stn = StyleTransferNet(encoder_path)

        # pass content and style to the stn, getting the generated_img
        generated_img = stn.transform(content, style)

        # get the target feature maps which is the output of AdaIN
        target_features = stn.target_features

        # pass the generated_img to the encoder, and use the output to compute loss
        generated_img = tf.reverse(generated_img, axis=[-1])    # switch RGB to BGR
        generated_img = stn.encoder.preprocess(generated_img)   # preprocess image
        enc_gen, enc_gen_layers = stn.encoder.encode(generated_img)

        # compute the content loss
        # content_loss = fft_loss(enc_gen, target_features)
        content_loss = tf.reduce_sum(tf.reduce_mean(tf.square(enc_gen - target_features), axis=[1, 2]))

        # compute the style loss
        style_layer_loss = []
        for layer in STYLE_LAYERS:
            enc_style_feat = stn.encoded_style_layers[layer]
            enc_gen_feat = enc_gen_layers[layer]

            meanS, varS = tf.nn.moments(enc_style_feat, [1, 2])
            meanG, varG = tf.nn.moments(enc_gen_feat, [1, 2])

            # fft_pred = tf.abs(tf.spectral.rfft2d(tf.transpose(enc_gen_feat, (0, 3, 1, 2))))
            # fft_true = tf.abs(tf.spectral.rfft2d(tf.transpose(enc_style_feat, (0, 3, 1, 2))))
            # meanS, varS = tf.nn.moments(fft_pred, [2, 3])
            # meanG, varG = tf.nn.moments(fft_true, [2, 3])

            sigmaS = tf.sqrt(varS + EPSILON)
            sigmaG = tf.sqrt(varG + EPSILON)

            l2_mean = tf.reduce_sum(tf.square(meanG - meanS))
            l2_sigma = tf.reduce_sum(tf.square(sigmaG - sigmaS))

            style_layer_loss.append(l2_mean + l2_sigma)

        style_loss = tf.reduce_sum(style_layer_loss)

        # compute the total loss
        loss = content_loss + style_weight * style_loss

        # Training step
        global_step = tf.Variable(0, trainable=False)
        learning_rate = tf.train.inverse_time_decay(LEARNING_RATE, global_step, DECAY_STEPS, LR_DECAY_RATE)
        train_op = tf.train.AdamOptimizer(learning_rate).minimize(loss, global_step=global_step)

        sess.run(tf.global_variables_initializer())

        # saver
        saver = tf.train.Saver(max_to_keep=10)

        ###### Start Training ######
        step = 0
        n_batches = int(len(content_imgs_path) // BATCH_SIZE)

        if debug:
            elapsed_time = datetime.now() - start_time
            start_time = datetime.now()
            print('\nElapsed time for preprocessing before actually training the model: %s' % elapsed_time)
            print('Now begin to train the model...\n')

        try:
            for epoch in range(EPOCHS):
                np.random.shuffle(content_imgs_path)
                np.random.shuffle(style_imgs_path)

                for batch in range(n_batches):
                    # retrieve a batch of content and style images
                    content_batch_path = content_imgs_path[batch*BATCH_SIZE:(batch*BATCH_SIZE + BATCH_SIZE)]
                    style_batch_path = style_imgs_path[batch*BATCH_SIZE:(batch*BATCH_SIZE + BATCH_SIZE)]

                    content_batch = get_train_images(content_batch_path, crop_height=HEIGHT, crop_width=WIDTH)
                    style_batch = get_train_images(style_batch_path, crop_height=HEIGHT, crop_width=WIDTH)

                    # run the training step
                    sess.run(train_op, feed_dict={content: content_batch, style: style_batch})
                    step += 1

                    if step % 1000 == 0:
                        saver.save(sess, model_save_path, global_step=step, write_meta_graph=False)

                    if debug:
                        is_last_step = (epoch == EPOCHS - 1) and (batch == n_batches - 1)

                        if is_last_step or step == 1 or step % logging_period == 0:
                            elapsed_time = datetime.now() - start_time
                            _content_loss, _style_loss, _loss = sess.run(
                                [content_loss, style_loss, loss],
                                feed_dict={content: content_batch, style: style_batch})

                            print('step: %d, total loss: %.3f, elapsed time: %s' % (step, _loss, elapsed_time))
                            print('content loss: %.3f' % (_content_loss))
                            print('style loss  : %.3f, weighted style loss: %.3f\n' % (_style_loss, style_weight * _style_loss))
        except Exception as ex:
            saver.save(sess, model_save_path, global_step=step)
            print('\nSomething went wrong! Current model is saved to <%s>' % model_save_path)
            print('Error message: %s' % str(ex))

        ###### Done Training & Save the model ######
        saver.save(sess, model_save_path)

        if debug:
            elapsed_time = datetime.now() - start_time
            print('Done training! Elapsed time: %s' % elapsed_time)
            print('Model is saved to: %s' % model_save_path)
def train(style_weight, content_imgs_path, style_imgs_path, encoder_path, save_path,
          debug=False, logging_period=100):
    if debug:
        from datetime import datetime
        start_time = datetime.now()

    # guarantee the size of content and style images to be a multiple of BATCH_SIZE
    num_imgs = min(len(content_imgs_path), len(style_imgs_path))
    content_imgs_path = content_imgs_path[:num_imgs]
    style_imgs_path = style_imgs_path[:num_imgs]
    mod = num_imgs % BATCH_SIZE
    if mod > 0:
        print('Train set has been trimmed %d samples...\n' % mod)
        content_imgs_path = content_imgs_path[:-mod]
        style_imgs_path = style_imgs_path[:-mod]

    # get the training image shape
    HEIGHT, WIDTH, CHANNELS = TRAINING_IMAGE_SHAPE
    INPUT_SHAPE = (BATCH_SIZE, HEIGHT, WIDTH, CHANNELS)

    # create the graph
    with tf.Graph().as_default(), tf.Session() as sess:
        content = tf.placeholder(tf.float32, shape=INPUT_SHAPE, name='content')
        style = tf.placeholder(tf.float32, shape=INPUT_SHAPE, name='style')

        # create the style transfer net
        stn = StyleTransferNet(encoder_path)

        # pass content and style to the stn, getting the generated_img
        generated_img = stn.transform(content, style)

        # get the target feature maps which is the output of AdaIN
        target_features = stn.target_features

        # pass the generated_img to the encoder, and use the output to compute loss
        generated_img = tf.reverse(generated_img, axis=[-1])    # switch RGB to BGR
        generated_img = stn.encoder.preprocess(generated_img)   # preprocess image
        enc_gen, enc_gen_layers = stn.encoder.encode(generated_img)

        # compute the content loss
        content_loss = tf.reduce_sum(tf.reduce_mean(tf.square(enc_gen - target_features), axis=[1, 2]))

        # compute the style loss
        style_layer_loss = []
        for layer in STYLE_LAYERS:
            enc_style_feat = stn.encoded_style_layers[layer]
            enc_gen_feat = enc_gen_layers[layer]

            meanS, varS = tf.nn.moments(enc_style_feat, [1, 2])
            meanG, varG = tf.nn.moments(enc_gen_feat, [1, 2])

            sigmaS = tf.sqrt(varS + EPSILON)
            sigmaG = tf.sqrt(varG + EPSILON)

            l2_mean = tf.reduce_sum(tf.square(meanG - meanS))
            l2_sigma = tf.reduce_sum(tf.square(sigmaG - sigmaS))

            style_layer_loss.append(l2_mean + l2_sigma)

        style_loss = tf.reduce_sum(style_layer_loss)

        # compute the total loss
        loss = content_loss + style_weight * style_loss

        # save losses to TensorBoard
        tf.summary.scalar('content_loss', content_loss)
        tf.summary.scalar('style_loss', style_loss)
        tf.summary.scalar('total_loss', loss)
        merged = tf.summary.merge_all()
        train_writer = tf.summary.FileWriter('runs', sess.graph)

        # Training step
        train_op = tf.train.AdamOptimizer(LEARNING_RATE).minimize(loss)

        sess.run(tf.global_variables_initializer())

        # saver = tf.train.Saver()
        saver = tf.train.Saver(keep_checkpoint_every_n_hours=1)

        """Start Training"""
        step = 0
        n_batches = int(len(content_imgs_path) // BATCH_SIZE)

        if debug:
            elapsed_time = datetime.now() - start_time
            print('\nElapsed time for preprocessing before actually training the model: %s' % elapsed_time)
            print('Now begin to train the model...\n')
            start_time = datetime.now()

        for epoch in range(EPOCHS):
            np.random.shuffle(content_imgs_path)
            np.random.shuffle(style_imgs_path)

            for batch in range(n_batches):
                print("current step: {}/{}".format(batch, n_batches))
                print("select content images: {}~{}/{}, style images: {}~{}/{}".format(
                    batch * BATCH_SIZE, batch * BATCH_SIZE + BATCH_SIZE, len(content_imgs_path),
                    batch * BATCH_SIZE, batch * BATCH_SIZE + BATCH_SIZE, len(style_imgs_path)))

                # retrieve a batch of content and style images
                content_batch_path = content_imgs_path[batch * BATCH_SIZE:(batch * BATCH_SIZE + BATCH_SIZE)]
                style_batch_path = style_imgs_path[batch * BATCH_SIZE:(batch * BATCH_SIZE + BATCH_SIZE)]

                content_batch = get_train_images(content_batch_path, crop_height=HEIGHT, crop_width=WIDTH)
                style_batch = get_train_images(style_batch_path, crop_height=HEIGHT, crop_width=WIDTH)

                # run the training step
                print("start training step")
                c_loss, s_loss, t_loss, summary, _ = sess.run(
                    [content_loss, style_loss, loss, merged, train_op],
                    feed_dict={content: content_batch, style: style_batch})
                train_writer.add_summary(summary, batch + epoch * n_batches)
                print(f'content: {c_loss}, style: {s_loss}, total: {t_loss}')
                print("stop training step")
                step += 1

                if step % 1000 == 0:
                    saver.save(sess, save_path, global_step=step)

                if debug:
                    is_last_step = (epoch == EPOCHS - 1) and (batch == n_batches - 1)

                    if is_last_step or step % logging_period == 0:
                        elapsed_time = datetime.now() - start_time
                        _content_loss, _style_loss, _loss = sess.run(
                            [content_loss, style_loss, loss],
                            feed_dict={content: content_batch, style: style_batch})

                        print('step: %d, total loss: %.3f, elapsed time: %s' % (step, _loss, elapsed_time))
                        print('content loss: %.3f' % (_content_loss))
                        print('style loss  : %.3f, weighted style loss: %.3f\n' % (_style_loss, style_weight * _style_loss))

        """Done training. Save the model."""
        saver.save(sess, save_path)

        if debug:
            elapsed_time = datetime.now() - start_time
            print('Done training! Elapsed time: %s' % elapsed_time)
            print('Model is saved to: %s' % save_path)
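# A hypothetical call sketch for the train() function above (not part of the original code):
# the dataset directories are placeholders, style_weight=2.0 mirrors the checkpoint name
# 'style_weight_2e0.ckpt' used elsewhere in this file, and BATCH_SIZE / EPOCHS / LEARNING_RATE
# are assumed to be module-level constants as referenced above.
from glob import glob

content_imgs_path = glob('datasets/content/*.jpg')   # e.g. MS-COCO images
style_imgs_path = glob('datasets/style/*.jpg')       # e.g. WikiArt images

train(style_weight=2.0,
      content_imgs_path=content_imgs_path,
      style_imgs_path=style_imgs_path,
      encoder_path='vgg19_normalised.npz',
      save_path='models/style_weight_2e0.ckpt',
      debug=True,
      logging_period=100)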
tf_config = tf.ConfigProto()
# tf_config.gpu_options.per_process_gpu_memory_fraction = 0.5
tf_config.gpu_options.allow_growth = True
with tf.Graph().as_default(), tf.Session(config=tf_config) as sess:
    content = tf.placeholder(tf.float32, shape=INPUT_SHAPE, name='content')
    style = tf.placeholder(tf.float32, shape=INPUT_SHAPE, name='style')
    label = tf.placeholder(tf.int64, shape=None, name="label")
    # style = tf.placeholder(tf.float32, shape=INPUT_SHAPE, name='style')

    # create the style transfer net
    stn = StyleTransferNet(encoder_path)

    # pass content and style to the stn, getting the gen_img:
    # decoded image from the normal one, adversarial image, and input
    dec_img, adv_img = stn.transform(content, style)
    img = content
    print(adv_img.shape.as_list())
    stn_vars = []

    # get the target feature maps which is the output of AdaIN
    target_features = stn.target_features

    # pass the gen_img to the encoder, and use the output to compute loss
    enc_gen_adv, enc_gen_layers_adv = stn.encode(adv_img)
    enc_gen, enc_gen_layers = stn.encode(dec_img)

    l2_embed = normalize(enc_gen)[0] - normalize(stn.norm_features)[0]
    l2_embed = tf.reduce_mean(tf.sqrt(tf.reduce_sum((l2_embed * l2_embed), axis=[1, 2, 3])))
# create the graph
tf_config = tf.ConfigProto()
# tf_config.gpu_options.per_process_gpu_memory_fraction = 0.5
tf_config.gpu_options.allow_growth = True
with tf.Graph().as_default(), tf.Session(config=tf_config) as sess:
    content = tf.placeholder(tf.float32, shape=INPUT_SHAPE, name='content')
    style = tf.placeholder(tf.float32, shape=INPUT_SHAPE, name='style')
    label = tf.placeholder(tf.int64, shape=None, name="label")
    # style = tf.placeholder(tf.float32, shape=INPUT_SHAPE, name='style')

    # create the style transfer net
    stn = StyleTransferNet(encoder_path)

    # pass content and style to the stn, getting the generated_img
    generated_img, generated_img_adv = stn.transform(content, style)
    adv_img = generated_img_adv
    img = generated_img
    print(adv_img.shape.as_list())
    stn_vars = []  # get_scope_var("transform")

    # get the target feature maps which is the output of AdaIN
    target_features = stn.target_features

    # pass the generated_img to the encoder, and use the output to compute loss
    generated_img_adv = tf.reverse(generated_img_adv, axis=[-1])     # switch RGB to BGR
    adv_img_bgr = generated_img_adv
    generated_img_adv = stn.encoder.preprocess(generated_img_adv)    # preprocess image
    enc_gen_adv, enc_gen_layers_adv = stn.encoder.encode(generated_img_adv)
def stylize(contents_path, styles_path, output_dir, encoder_path, model_path,
            resize_height=None, resize_width=None, suffix=None):
    if isinstance(contents_path, str):
        contents_path = [contents_path]
    if isinstance(styles_path, str):
        styles_path = [styles_path]

    with tf.Graph().as_default(), tf.Session(config=tf.ConfigProto(log_device_placement=True)) as sess:
        # This snippet is only for checking which device TF runs on; it serves no other purpose.
        # a = tf.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], shape=[2, 3], name='a')
        # b = tf.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], shape=[3, 2], name='b')
        # c = tf.matmul(a, b)
        # print(sess.run(c))

        # build the dataflow graph
        content = tf.placeholder(tf.float32, shape=(1, None, None, 3), name='content')
        style = tf.placeholder(tf.float32, shape=(1, None, None, 3), name='style')

        stn = StyleTransferNet(encoder_path)

        output_image = stn.transform(content, style)

        print(sess.run(tf.global_variables_initializer()))

        # restore the trained model and run the style transfer
        saver = tf.train.Saver()
        saver.restore(sess, model_path)

        outputs = []
        for content_path in contents_path:
            content_img = get_images(content_path, height=resize_height, width=resize_width)
            for style_path in styles_path:
                style_img = get_images(style_path)
                print('--> processing %s with style %s' % (content_path, style_path))

                result = sess.run(output_image,
                                  feed_dict={content: content_img, style: style_img})
                outputs.append(result[0])
                save_image(result[0], content_path, style_path, output_dir, suffix=suffix)

    # save_images(outputs, contents_path, styles_path, output_dir, suffix=suffix)
    return outputs
OUTPUTS_DIR = 'outputs'
ENCODER_WEIGHTS_PATH = 'vgg19_normalised.npz'
MODEL_SAVE_PATH = 'models/style_weight_2e0.ckpt'

content_img = imageio.imread('./images/content/karya.jpg')
content_img = np.expand_dims(content_img, axis=0)
style_img = imageio.imread('./images/style/mosaic.jpg')
style_img = np.expand_dims(style_img, axis=0)

sess = tf.InteractiveSession()

content = tf.placeholder(tf.float32, shape=(1, None, None, 3), name='content')
style = tf.placeholder(tf.float32, shape=(1, None, None, 3), name='style')

model = StyleTransferNet(ENCODER_WEIGHTS_PATH)
output = model.transform(content, style)

sess.run(tf.global_variables_initializer())

saver = tf.train.Saver()
saver.restore(sess, MODEL_SAVE_PATH)

output_img = sess.run(output, feed_dict={content: content_img, style: style_img})
output_img = output_img[0]

print('output_img shape', output_img.shape)
print('output img type', type(output_img))
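# Hypothetical follow-up (not in the original snippet): the code above only prints the
# shape and dtype of the result. Assuming the decoder output is a float array in the
# 0-255 range, it can be clipped and written to OUTPUTS_DIR like this:
import os

os.makedirs(OUTPUTS_DIR, exist_ok=True)
out_uint8 = np.clip(output_img, 0, 255).astype(np.uint8)
imageio.imwrite(os.path.join(OUTPUTS_DIR, 'karya_mosaic_stylized.jpg'), out_uint8)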