Ejemplo n.º 1
0
def train(args):
    """Build, compile and fit the image/audio encoder-decoder model.

    Reads ``audio_fraction``, ``image_fraction`` and ``preprocess`` from
    *args*. The two fractions weight the MSE losses of the
    ``decode_image_model`` and ``decode_audio_model`` outputs;
    ``preprocess`` is passed to ``Encode_Decoder_Model`` and to the data
    generators, and is embedded in the checkpoint file name.
    """
    audio_fraction, image_fraction, preprocess = (
        args.audio_fraction, args.image_fraction, args.preprocess)

    output_heads = ('decode_image_model', 'decode_audio_model')
    head_weights = dict(zip(output_heads, (image_fraction, audio_fraction)))

    model = Encode_Decoder_Model(preprocess)
    model.summary()
    model.compile(optimizer=Adam(lr=0.0001),
                  loss={head: 'mse' for head in output_heads},
                  loss_weights=head_weights)

    # Resulting name: ../models/model_<preprocess><image_fraction>_<audio_fraction>.hdf5
    weights_file = ''.join(['../models/model_', preprocess,
                            str(image_fraction), '_',
                            str(audio_fraction), '.hdf5'])
    fit_callbacks = [
        ModelCheckpoint(weights_file, verbose=True, save_best_only=True),
        EarlyStopping(patience=5, verbose=True),
    ]

    train_batches = image_generator(PATH_IMAGE_TRAIN, PATH_AUDIO_TRAIN,
                                    mode=preprocess, batch_size=8)
    valid_batches = image_generator(PATH_IMAGE_VAL, PATH_AUDIO_VAL,
                                    mode=preprocess)

    model.fit_generator(train_batches,
                        steps_per_epoch=500,
                        epochs=100,
                        validation_data=valid_batches,
                        validation_steps=200,
                        callbacks=fit_callbacks)
Ejemplo n.º 2
0
    def sample_model(self, path, epoch, idx, count):
        """
        Save and log a grid of sample translations.

        Draws one batch from each of ``<datadir>/testA`` and
        ``<datadir>/testB``, evaluates ``fakeA``, ``fakeB``, ``cycA`` and
        ``cycB`` on it, tiles the six image sets with ``visual_grid`` and
        adds the result to ``self.writer`` as an image summary. When
        ``self.args.sample_to_file`` is truthy the grid is also written to
        ``<path>/<epoch:04d>-<idx:04d>.png``.

        :param path: output directory for the PNG
        :param epoch: epoch number (used in the file name)
        :param idx: batch index (used in the file name)
        :param count: global step recorded with the summary
        :return: None
        """
        side = self.args.imsize

        def open_split(split):
            # Fresh generator over one test split; the positional 4 is
            # presumably the batch size — confirm against image_generator.
            return image_generator(os.path.join(self.args.datadir, split),
                                   4,
                                   resize=(side, side),
                                   value_mode='tanh')

        stream_a = open_split('testA')
        stream_b = open_split('testB')
        real_a = next(stream_a)
        real_b = next(stream_b)

        outputs = self.sess.run(
            [self.fakeA, self.fakeB, self.cycA, self.cycB],
            feed_dict={self.realA_ph: real_a, self.realB_ph: real_b})
        fake_a, fake_b, cyc_a, cyc_b = outputs

        # Concatenation order: real_A, fake_B, cyc_A, real_B, fake_A, cyc_B.
        panels = [real_a, fake_b, cyc_a, real_b, fake_a, cyc_b]
        grid = visual_grid(np.concatenate(panels, axis=0), (6, 4))

        if self.args.sample_to_file:
            out_name = os.path.join(path, '%04d-%04d.png' % (epoch, idx))
            imsave(out_name, grid, 'png')

        # The summary placeholder takes a batch, so wrap the single grid.
        grid_batch = np.array([grid])
        image_summary = self.sess.run(self.img_op,
                                      feed_dict={self.p_img: grid_batch})
        self.writer.add_summary(image_summary, count)
Ejemplo n.º 3
0
        optimizer = tf.train.AdamOptimizer()
        with tf.control_dependencies(extra_ops):
            train_op = optimizer.minimize(total_loss, var_list=var_list)

    opt_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                 scope='optimizer')
    model_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                   scope='transformation_net')
    init = tf.variables_initializer(var_list=opt_vars + model_vars)
    sess.run(init)

    print('Training')
    shutil.rmtree('logs', ignore_errors=True)
    summaries = tf.summary.merge_all()
    img_gen = utils.image_generator(args.train,
                                    batch_size=batch_size,
                                    target_shape=(img_width, img_height))
    writer = tf.summary.FileWriter('logs', sess.graph)
    global_step = 0
    for epoch in range(epochs):
        step = 0
        while step * batch_size < num_images:
            images = next(img_gen)
            _, step_loss, summary = sess.run(
                [train_op, total_loss, summaries],
                feed_dict={
                    transformation_model.input: images,
                    tf.keras.backend.learning_phase(): 1
                })
            if step % 500 == 0:
                print('Epoch {}, step {}:   loss: {}'.format(
Ejemplo n.º 4
0
    model_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                   scope='transformation_net')
    init = tf.variables_initializer(var_list=opt_vars + model_vars)
    saver2 = tf.train.Saver(var_list=var3)
    sess.run(init)
    saver2.restore(sess,
                   args.weights + '/' + 'model.ckpt')  #updated the checkpoint

    #transformation_model.load_weights(args.weights,var_list=var3)

    print('Training')
    shutil.rmtree('logs', ignore_errors=True)
    summaries = tf.summary.merge_all()

    img_gen = utils.image_generator(args.train,
                                    batch_size=batch_size,
                                    target_shape=(img_width, img_height))
    style_gen = utils.image_generator(
        images_path='/home/dl/Desktop/neural-style-transfer-master/sty',
        batch_size=batch_size,
        target_shape=(img_width, img_height))
    writer = tf.summary.FileWriter('logs', sess.graph)
    global_step = 0

    for epoch in range(1000):
        step = 0
        while step * batch_size < num_images:
            images = next(img_gen)
            st_images = next(style_gen)
            print(step)
            _, step_loss, summary = sess.run(
Ejemplo n.º 5
0
# Script section: build the train/test data generators and the WideResNet
# model, then compile it. `depth`, `k`, `opt_name` and `lr` are not defined
# in the visible code — presumably set earlier in the script (e.g. from CLI
# arguments); confirm.
print("Loading data...")

# Earlier in-memory loading path, currently disabled; data now comes from
# image_generator over directories (below).
# image, gender, age, _, image_size, _ = load_data(input_path)
# image, gender, age, _, image_size, _ = load_data(input_path)
# image, gender, age, _, image_size, _ = load_adience_data()

# image, gender, age, _, image_size, _ = load_data(input_path)

# X_data = image
# y_data_g = np_utils.to_categorical(gender, 2)
# y_data_a = np_utils.to_categorical(age, 101)

# Hard-coded paths on a GVFS-mounted SMB share; machine-specific — adjust
# (or make configurable) before running elsewhere.
train_path = '/run/user/1000/gvfs/smb-share:server=192.168.43.124,share=project_phase2/data/train/'
test_path = '/run/user/1000/gvfs/smb-share:server=192.168.43.124,share=project_phase2/data/test/'

traindata = image_generator(train_path, batch_size=52)
testdata = image_generator(test_path, batch_size=52)
# y_data_a = np_utils.to_categorical(age, 81)
# Input size handed to WideResNet; presumably must match the image size
# produced by image_generator — confirm.
image_size = 62
model = WideResNet(image_size, depth=depth, k=k)()
opt = get_optimizer(opt_name, lr)
# One categorical cross-entropy loss per model output (two outputs);
# presumably the gender and age heads, per the disabled to_categorical
# lines above — confirm.
model.compile(optimizer=opt,
              loss=["categorical_crossentropy", "categorical_crossentropy"],
              metrics=['accuracy'])

print("Model summary...")
# NOTE(review): the returned parameter count is discarded, so this call has
# no visible effect — print it or remove the line.
model.count_params()
model.summary()

callbacks = [
    LearningRateScheduler(schedule=Schedule(nb_epochs, lr)),
Ejemplo n.º 6
0
    def train(self):
        """
        Run the A <-> B adversarial training loop.

        Creates the optimizers, initializes all global variables, opens a
        summary writer on ``self.args.logdir`` and then tries to restore a
        checkpoint from ``self.args.checkpointdir``. The restore runs after
        initialization, so restored values presumably replace the fresh ones.

        Each step:

        * draws one batch from ``trainA`` and one from ``trainB``;
        * runs both generators;
        * pushes the real and fake images into the ``DataPool`` buffers;
        * runs the ``db``, ``da``, ``g_a2b`` and ``g_b2a`` optimizers, in
          that order, writing each one's summary at the current step.

        A sample grid is written on the first step and then every
        ``self.args.sample_freq`` steps. A checkpoint is saved whenever
        ``counter % 1000 == 2``.

        :return: None
        """
        self._create_opts()

        init_op = tf.global_variables_initializer()
        self.sess.run(init_op)
        self.writer = tf.summary.FileWriter(self.args.logdir, self.sess.graph)

        start_time = time.time()

        if self.load(self.args.checkpointdir):
            print(" [*] Load SUCCESS")
        else:
            print(" [!] Load failed...")

        # The positional 1 is presumably the batch size — confirm against
        # image_generator.
        resize = (self.args.imsize, self.args.imsize)
        realAtrain = image_generator(os.path.join(self.args.datadir, 'trainA'),
                                     1,
                                     resize=resize,
                                     value_mode='tanh')
        realBtrain = image_generator(os.path.join(self.args.datadir, 'trainB'),
                                     1,
                                     resize=resize,
                                     value_mode='tanh')

        fakeApool = DataPool(self.args.pool_size)
        fakeBpool = DataPool(self.args.pool_size)
        realApool = DataPool(self.args.pool_size)
        realBpool = DataPool(self.args.pool_size)

        counter = 0

        for epoch in range(self.args.nb_epoch):

            for idx in range(self.args.nb_batch):

                realAimage = next(realAtrain)
                realBimage = next(realBtrain)
                real_feed = {
                    self.realA_ph: realAimage,
                    self.realB_ph: realBimage
                }

                # Forward both generators on the current real pair.
                fake_A, fake_B = self.sess.run([self.fakeA, self.fakeB],
                                               feed_dict=real_feed)
                realApool.push(realAimage)
                realBpool.push(realBimage)
                fakeApool.push(fake_A)
                fakeBpool.push(fake_B)

                # Discriminator updates are fed everything currently held in
                # the pools (DataPool.all()), not just this step's images.
                _, summary_str = self.sess.run([self.db_optim, self.db_sum],
                                               feed_dict={
                                                   self.realB_ph:
                                                   realBpool.all(),
                                                   self.fakeB_ph:
                                                   fakeBpool.all()
                                               })
                self.writer.add_summary(summary_str, counter)
                _, summary_str = self.sess.run([self.da_optim, self.da_sum],
                                               feed_dict={
                                                   self.realA_ph:
                                                   realApool.all(),
                                                   self.fakeA_ph:
                                                   fakeApool.all()
                                               })
                self.writer.add_summary(summary_str, counter)

                # Generator updates use the current real pair.
                _, summary_str = self.sess.run(
                    [self.g_a2b_optim, self.g_a2b_sum], feed_dict=real_feed)
                self.writer.add_summary(summary_str, counter)
                _, summary_str = self.sess.run(
                    [self.g_b2a_optim, self.g_b2a_sum], feed_dict=real_feed)
                self.writer.add_summary(summary_str, counter)

                print("Epoch: [%2d] [%4d/%4d] time: %4.4f" %
                      (epoch, idx, self.args.nb_batch,
                       time.time() - start_time))

                counter += 1

                # Sample on the first step and every sample_freq steps after.
                # Written as (counter - 1) % freq == 0 because the equivalent
                # `counter % freq == 1` is never true when sample_freq == 1.
                if (counter - 1) % self.args.sample_freq == 0:
                    self.sample_model(self.args.sampledir, epoch, idx, counter)

                if counter % 1000 == 2:
                    self.save(self.args.checkpointdir, counter)
Ejemplo n.º 7
0
    def train(self,
              ART_dir,
              COCO_dir,
              batch_size=10,
              TB=False,
              checkpoint='./checkpoint/'):
        """
        Train on COCO content images against four artists' style images.

        For every content batch and every artist, runs 10 critic updates
        (``D_solver`` + ``clip_D``) followed by one generator update
        (``G_solver``). Every 10 batches it:

        * writes the latest losses to TensorBoard (if enabled);
        * dumps debug images to ``haha*.jpg``.

        Every 1000 batches it saves a checkpoint.

        The method also reads a global ``args`` (``args.epoch``,
        ``args.logdir``, ``args.TB``).

        :param ART_dir: directory containing ``images_256_<artist>`` folders
        :param COCO_dir: directory of content images (regular files only)
        :param batch_size: images per batch for every data generator
        :param TB: enable TensorBoard logging to ``args.logdir``; logging is
                   also enabled when ``args.TB`` is set
        :param checkpoint: directory that ``tf.train.Saver`` writes to
        """
        # Content data: every regular file directly under COCO_dir.
        content_files = [
            join(COCO_dir, f) for f in os.listdir(COCO_dir)
            if isfile(join(COCO_dir, f))
        ]
        content_data_generator = utils.image_generator(content_files,
                                                       batch_size=batch_size)

        # Style data: one generator per artist, <ART_dir>/images_256_<artist>.
        artists_list = ['vincent', 'monet', 'cezanne', 'katsu']
        style_data_generator = []
        for art in artists_list:
            artist_dir = ART_dir + '/images_256_' + art
            art_files = [
                join(artist_dir, f) for f in os.listdir(artist_dir)
                if isfile(join(artist_dir, f))
            ]
            style_data_generator.append(
                utils.image_generator(art_files, batch_size=batch_size))

        nb_batch = int(math.ceil(float(len(content_files)) / batch_size))

        ### log settings
        # Honour the TB argument as well as the global args.TB, so callers
        # that only set args.TB keep their logging.
        use_tb = TB or args.TB
        writer = None
        if use_tb:
            if not os.path.exists(args.logdir):
                os.makedirs(args.logdir)
            writer = tf.summary.FileWriter(args.logdir,
                                           graph=tf.get_default_graph())

        ### saver settings
        # Save into the same directory that is created here.
        saver = tf.train.Saver()
        if not os.path.exists(checkpoint):
            os.makedirs(checkpoint)

        init = tf.global_variables_initializer()
        print('start training')
        with tf.Session() as sess:
            sess.run(init)
            for ep in range(args.epoch):
                for bs in range(nb_batch):
                    for style_gen in style_data_generator:
                        for _ in range(10):  # n critic
                            # Read data
                            content_images = next(content_data_generator)
                            style_images = next(style_gen)

                            # Batches that repeat the first image of each
                            # stream batch_size times.
                            content_img = np.repeat(content_images[0:1],
                                                    batch_size,
                                                    axis=0)
                            style_img = np.repeat(style_images[0:1],
                                                  batch_size,
                                                  axis=0)

                            _, _, D_loss = sess.run(
                                [self.D_solver, self.clip_D, self.D_loss],
                                feed_dict={
                                    self.style_image_x: style_images,
                                    self.real_image_x: content_images,
                                    self.style_image: style_img,
                                    self.real_image: content_img
                                })
                        # Generator step on the last critic step's batch.
                        _, G_loss, Cycle_loss, Gx_out, Gy_out = sess.run(
                            [
                                self.G_solver, self.G_loss, self.cycle_loss,
                                self.Gx_out, self.Gy_out
                            ],
                            feed_dict={
                                self.style_image_x: style_images,
                                self.real_image_x: content_images,
                                self.style_image: style_img,
                                self.real_image: content_img
                            })

                        print(
                            'Epoch:{:5}  Step:{:5}  D_loss:{:f} G_loss:{:f} Cycle_loss:{:f}'
                            .format(ep, bs, D_loss, G_loss, Cycle_loss))

                    # save log
                    if (bs + 1) % 10 == 0:
                        step = ep * nb_batch + bs
                        if writer is not None:
                            # No merged summary op is built in this method,
                            # so write the latest scalar losses explicitly.
                            writer.add_summary(
                                tf.Summary(value=[
                                    tf.Summary.Value(
                                        tag='D_loss',
                                        simple_value=float(D_loss)),
                                    tf.Summary.Value(
                                        tag='G_loss',
                                        simple_value=float(G_loss)),
                                    tf.Summary.Value(
                                        tag='Cycle_loss',
                                        simple_value=float(Cycle_loss)),
                                ]), step)
                        print('save image')  # NOTE debug
                        # Debug snapshots; the same files are overwritten
                        # on every save.
                        utils.save_rgb('haha.jpg', Gx_out[0][np.newaxis, :])
                        utils.save_rgb('haha2.jpg', Gy_out[0][np.newaxis, :])
                        utils.save_rgb('haha_in.jpg',
                                       content_img[0][np.newaxis, :])
                        utils.save_rgb('haha2_in.jpg',
                                       style_img[0][np.newaxis, :])
                    if (bs + 1) % 1000 == 0:
                        saver.save(sess,
                                   join(checkpoint, 'model'),
                                   global_step=ep * nb_batch + bs)
            if writer is not None:
                writer.close()
Ejemplo n.º 8
0
def optimize(style_name, style_path, epochs, batch_size, learning_rate, style_w,
             content_w, tv_w, save_step, checkpoint_path,
             test_image_name, test_image_path, eval_step, debug=False):
    """Train ``transform_net`` to stylize images like one style image.

    Builds a graph that feeds batches of 256x256x3 training images through
    ``transform_net``. It compares the output with:

    * the style image, via ``gram`` of VGG activations over
      ``style_layers``;
    * the input, via activations at ``content_layer``;

    and adds a ``tv_loss`` term. The weighted sum is minimized with Adam.
    Training batches come from ``image_generator('../data/train2014',
    batch_size)``.

    Args:
        style_name: Not referenced in this function body.
        style_path: Path of the style image, loaded with ``load_image``.
        epochs: Number of passes over the training generator.
        batch_size: Images per step; also fixes the first dimension of the
            input placeholder.
        learning_rate: Adam learning rate.
        style_w: Weight of the style loss.
        content_w: Weight of the content loss.
        tv_w: Weight of the ``tv_loss`` term.
        save_step: Every this many global steps, print the losses and save a
            checkpoint under ``checkpoints/model``.
        checkpoint_path: Either a directory (its latest checkpoint is
            restored, if any) or a checkpoint prefix passed directly to
            ``saver.restore``.
        test_image_name: Used in the stylized test image file name.
        test_image_path: Image stylized every ``eval_step`` global steps.
        eval_step: Interval, in global steps, for writing a stylized test
            image to ``../images/stylized/``.
        debug: If true, plot style activations and Gram results with
            matplotlib before training.
    """
    style_image = load_image(style_path, expand_dims=True)
    # style_input = tf.constant(style_image, tf.float32)

    # Compute the outputs for the style image that will be
    # used for the calculation of the loss function. This
    # includes the gram matricies of the activations used
    # for style loss and the activations of the layer used for
    # content loss.
    with tf.name_scope("style_comp"):
        style_image_norm = normalize(style_image)
        style_act_dict = vgg(style_image_norm)
        style_gram_dict = {}
        # Evaluate the style targets once in a throwaway session so that
        # they are held as concrete arrays for the rest of training.
        with tf.Session() as sess:
            # NOTE(review): style_content_layer is not used later in this
            # function.
            style_content_layer = sess.run(style_act_dict[content_layer])
            # Replace each activation tensor with its evaluated value.
            for key, act in style_act_dict.items():
                style_act_dict[key] = sess.run(style_act_dict[key])
            # gram() is applied to the already-evaluated activations here.
            for layer in style_layers:
                style_gram_dict[layer] = sess.run(gram(style_act_dict[layer]))

    if debug:
        for layer, act in style_act_dict.items():
            # Save the first 16 activations
            # (assumes act is (batch, h, w, channels) with >= 16 channels —
            # confirm).
            fig, axes = plt.subplots(4, 4)
            for i in range(16):
                j, k = i % 4, i // 4
                axes[j, k].imshow(act[0][:,:,i])
            fig.suptitle("Activations for Layer: {}".format(layer))
            plt.show()
        for layer, gram_ in style_gram_dict.items():
            # NOTE(review): every one of the 16 panels draws the same
            # gram_[0]; the loop index only selects the subplot.
            fig, axes = plt.subplots(4, 4)
            for i in range(16):
                j, k = i % 4, i // 4
                axes[j, k].imshow(gram_[0][:,:])
            fig.suptitle("Visualize Gram for Layer: {}".format(layer))
            plt.show()

    # Compute the content image activations
    # The placeholder has a fixed batch dimension of batch_size.
    input_image = tf.placeholder(tf.float32, shape=[batch_size, 256, 256, 3], name='image_input')
    input_image_norm = normalize(input_image)
    input_content_layer = vgg(input_image_norm)[content_layer]

    # transform_net receives the input scaled by 1/255.
    output_image = transform_net(input_image/255.0)
    norm_output_image = normalize(output_image)
    output_act_dict = vgg(norm_output_image)
    # gram is built for every returned layer, though only style_layers are
    # used in the loss below.
    output_gram_dict = {}
    for key, act in output_act_dict.items():
        output_gram_dict[key] = gram(act)

    # calculate the losses
    style_losses = []
    for l in style_layers:
        style_losses.append(layer_style_loss(output_gram_dict[l], style_gram_dict[l]))
    total_var_loss = tv_loss(output_image, batch_size)
    style_loss = tf.add_n(style_losses) / batch_size
    content_loss = layer_content_loss(input_content_layer, output_act_dict[content_layer])
    loss = style_w * style_loss + content_w * content_loss + tv_w * total_var_loss
    # NOTE(review): this scalar summary is never merged or written; the
    # writer below only records the graph.
    tf.summary.scalar('loss', loss)

    global_step = tf.Variable(0, name='global_step', trainable=False)
    # `optimizer` is the minimize op (it increments global_step), not the
    # AdamOptimizer object.
    optimizer = tf.train.AdamOptimizer(learning_rate).minimize(loss, global_step=global_step)

    with tf.Session() as sess:
        writer = tf.summary.FileWriter("summaries/")
        writer.add_graph(sess.graph)

        saver = tf.train.Saver()
        sess.run(tf.global_variables_initializer())
        if os.path.isdir(checkpoint_path):
            ckpt = tf.train.get_checkpoint_state(checkpoint_path)
            if ckpt and ckpt.model_checkpoint_path:
                saver.restore(sess, ckpt.model_checkpoint_path)
            else:
                print("No Checkpoint Found. Training From Scratch.")
        else:
            # NOTE(review): a non-directory checkpoint_path is handed straight
            # to saver.restore; if nothing exists at that prefix (e.g. the
            # checkpoint dir was never created) this presumably raises
            # instead of training from scratch.
            saver.restore(sess, checkpoint_path)
        for i in range(epochs):
            start_time = time.time()
            # NOTE(review): if image_generator yields a final partial batch,
            # it will not match the fixed-size placeholder — confirm.
            for batch in image_generator('../data/train2014', batch_size):
                feed_dict = {input_image:batch}
                optimizer.run(feed_dict=feed_dict)
                step = global_step.eval()
                if step % save_step == 0:
                    loss_list = [style_loss, content_loss, total_var_loss, loss]
                    losses = sess.run(loss_list, feed_dict=feed_dict)
                    seconds = time.time() - start_time
                    print('Step {}\n   Loss: {:5.1f}'.format(step, losses[3]))
                    print('   Style Loss: {:5.1f}'.format(losses[0]))
                    print('   Content Loss: {:5.1f}'.format(losses[1]))
                    print('   TV Loss: {:5.1f}'.format(losses[2]))
                    print('   Took: {} seconds'.format(seconds))
                    # Always saves under checkpoints/model, regardless of
                    # checkpoint_path.
                    saver.save(sess, "checkpoints/model", step)
                    start_time = time.time()

                # As a test during training, stylize a fixed test image
                # (replicated to fill the batch) and save the first output,
                # to check visually that training is progressing.
                if step % eval_step == 0:
                    output_path = '../images/stylized/{}_{}.jpg'.format(test_image_name, step)
                    test_image = load_image(test_image_path)
                    test_batch = np.array([test_image for i in range(batch_size)])
                    test_image_out = sess.run(output_image,
                                              feed_dict={input_image:test_batch})
                    save_image(output_path, test_image_out[0, :, :, :])
Ejemplo n.º 9
0
def train(args,
          model,
          Load_numpy=False,
          multi_gpu=False,
          load_augmented_data=False,
          load_mask=False,
          class_mode="categorical"):
    """Build the input pipeline and callbacks, then compile and fit *model*.

    The data source is chosen by the first true flag, in this order:

    * ``load_augmented_data`` -- ``./{train,test}_augmented_data_{img,mask}.npy``;
    * ``Load_numpy`` -- ``./{train,valid}_{img,mask}_dataset_32.pkl.npy``,
      passed through ``utils.PrepareData``;
    * ``load_mask`` -- images in ``<path>/input`` with masks in
      ``<path>/mask``, for ``args.path_train`` and ``args.path_validation``;
    * otherwise -- ``args.path_train`` / ``args.path_validation`` read with
      ``class_mode``.

    Args:
        args: Parsed options. Reads, among others:

            * ``batch_size``, ``epochs``, ``lr``, ``lr_factor``,
              ``change_lr_threshold``, ``min_lr``, ``lam_recon``,
              ``save_dir``, ``debug``, ``input_shape``;
            * depending on the branch: ``color``, ``target_size``,
              ``path_train``, ``path_validation``, ``use_cropping``,
              ``cropping_size``.
        model: Keras model with two outputs. The first is trained with
            ``loss_functions.margin_loss``; the second with MSE weighted by
            ``args.lam_recon``. Metrics are attached to an output named
            ``'ucnet'``.
        Load_numpy: Load preprocessed ``.npy`` arrays (see above). Also
            selects the ``utils.image_generator_flow`` wrapper for fitting.
        multi_gpu: Forwarded to the checkpoint callback as ``multi_gpu_mode``.
        load_augmented_data: Load pre-augmented ``.npy`` arrays (see above).
        load_mask: Read image/mask directory pairs (see above).
        class_mode: ``class_mode`` for the plain directory iterators.

    Returns:
        The fitted *model*. It is also saved to
        ``<args.save_dir>/trained_model.h5``.
    """
    import math  # np.math is a deprecated alias, removed in NumPy 2.0

    print('-' * 30 + 'Begin: training ' + '-' * 30)
    # No augmentation is configured; the options previously tried were
    # horizontal_flip=True, zoom_range=0.2, rotation_range=90., shear_range=0.2
    train_data_datagen = ImageDataGenerator()
    valid_data_datagen = ImageDataGenerator()
    seed = 1
    # Prepare generators..
    if load_augmented_data:
        print("loading the previous augmented data...")
        valid_img_dataset = np.load('./test_augmented_data_img.npy')
        valid_mask_dataset = np.load('./test_augmented_data_mask.npy')
        train_img_dataset = np.load('./train_augmented_data_img.npy')
        train_mask_dataset = np.load('./train_augmented_data_mask.npy')
        print("training set: ", train_img_dataset.shape)
        print("validation set: ", valid_img_dataset.shape)
        train_input_generator = train_data_datagen.flow(
            train_img_dataset, train_mask_dataset, batch_size=args.batch_size)
        validation_input_generator = valid_data_datagen.flow(
            valid_img_dataset, valid_mask_dataset, batch_size=args.batch_size)
    elif Load_numpy:
        print("Proc: Loading the image list...")
        train_img_dataset = np.load('./train_img_dataset_32.pkl.npy')
        train_mask_dataset = np.load('./train_mask_dataset_32.pkl.npy')
        train_img_dataset, train_mask_dataset = utils.PrepareData(
            train_img_dataset,
            train_mask_dataset,
            "train",
            args.input_shape,
            color_normalization=args.color)
        print("train image number:", train_img_dataset.shape[0])
        valid_img_dataset = np.load('./valid_img_dataset_32.pkl.npy')
        valid_mask_dataset = np.load('./valid_mask_dataset_32.pkl.npy')
        # NOTE(review): validation uses args.target_size while training uses
        # args.input_shape — confirm this difference is intended.
        valid_img_dataset, valid_mask_dataset = utils.PrepareData(
            valid_img_dataset,
            valid_mask_dataset,
            "test",
            args.target_size,
            color_normalization=args.color)
        print("validation image number:", valid_img_dataset.shape[0])
        print(train_img_dataset.shape)
        train_input_generator = train_data_datagen.flow(
            train_img_dataset, train_mask_dataset, batch_size=args.batch_size)
        validation_input_generator = valid_data_datagen.flow(
            valid_img_dataset, valid_mask_dataset, batch_size=args.batch_size)

    elif load_mask:
        print("Proc: Generating the image list...")
        train_input_generator = train_data_datagen.flow_from_directory(
            args.path_train + "/input",
            seed=seed,
            mask_directory=args.path_train + "/mask",
            # color_mode="binary",
            equalize_adaphist=False,
            rescale_intensity=False,
            set_random_clipping=True,
            generate_HE=True,
            generate_LAB=True,
            max_image_number=0,
            target_size=args.input_shape,
            batch_size=args.batch_size,
            class_mode='mask')

        validation_input_generator = valid_data_datagen.flow_from_directory(
            args.path_validation + "/input",
            seed=seed,
            mask_directory=args.path_validation + "/mask",
            max_image_number=0,
            equalize_adaphist=False,
            rescale_intensity=False,
            set_random_clipping=True,
            generate_HE=True,
            generate_LAB=True,
            target_size=args.input_shape,
            batch_size=args.batch_size,
            class_mode='mask')
    else:
        print("Proc: Generating the image list...")
        train_input_generator = train_data_datagen.flow_from_directory(
            args.path_train,
            seed=seed,
            mask_directory=None,
            equalize_adaphist=False,
            rescale_intensity=False,
            set_random_clipping=True,
            generate_HE=True,
            generate_LAB=True,
            max_image_number=0,
            target_size=args.input_shape,
            batch_size=args.batch_size,
            class_mode=class_mode)

        validation_input_generator = valid_data_datagen.flow_from_directory(
            args.path_validation,
            seed=seed,
            mask_directory=None,
            max_image_number=0,
            equalize_adaphist=False,
            rescale_intensity=False,
            set_random_clipping=True,
            generate_HE=True,
            generate_LAB=True,
            target_size=args.input_shape,
            batch_size=args.batch_size,
            class_mode=class_mode)

    print("Done: Image lists are created...")
    # callbacks
    print("Proc: Preprare the callbacks...")
    log = callbacks.CSVLogger(args.save_dir + '/log.csv')
    tb = callbacks.TensorBoard(log_dir=args.save_dir + '/tensorboard-logs',
                               batch_size=args.batch_size,
                               histogram_freq=args.debug)
    #lr_decay = callbacks.LearningRateScheduler(schedule=lambda epoch: args.lr * (0.9 ** epoch))
    reduce_lr = keras.callbacks.ReduceLROnPlateau(
        monitor='val_loss',
        factor=args.lr_factor,
        patience=args.change_lr_threshold,
        min_lr=args.min_lr,
        verbose=1)
    history_register = keras.callbacks.History()

    checkpoint = callbacks.ModelCheckpoint(args.save_dir +
                                           '/weights-{epoch:02d}.h5',
                                           monitor='val_loss',
                                           save_best_only=True,
                                           save_weights_only=True,
                                           verbose=1,
                                           multi_gpu_mode=multi_gpu,
                                           name_of_model="model_1")
    print("Done: callbacks are created...")
    # compile the model
    print("Proc: Compile the model...")
    model.compile(
        optimizer=optimizers.Adam(lr=args.lr),
        loss=[loss_functions.margin_loss, "mse"],
        metrics={
            'ucnet':
            ['acc', metrics.precision, metrics.recall, metrics.dice_coef]
        },
        loss_weights=[1., args.lam_recon])

    print("Done: the model was complied...")
    print("Proc: Training the model...")

    # Array-backed iterators are sized from the arrays; directory iterators
    # expose `.samples`.
    if Load_numpy or load_augmented_data:
        train_steps_per_epoch = math.ceil(train_img_dataset.shape[0] /
                                          args.batch_size)
        valid_steps_per_epoch = math.ceil(valid_img_dataset.shape[0] /
                                          args.batch_size)
    else:
        train_steps_per_epoch = math.ceil(train_input_generator.samples /
                                          args.batch_size)
        valid_steps_per_epoch = math.ceil(validation_input_generator.samples /
                                          args.batch_size)

    # Training with data augmentation
    if Load_numpy:
        train_batches = utils.image_generator_flow(train_input_generator,
                                                   reconstruction=False,
                                                   reshape=False,
                                                   generate_weight=False,
                                                   run_one_vs_all_mode=False)
        valid_batches = utils.image_generator_flow(validation_input_generator,
                                                   reconstruction=False,
                                                   reshape=False,
                                                   generate_weight=False,
                                                   run_one_vs_all_mode=False)
    else:
        # NOTE(review): this wrapper is also used for load_augmented_data
        # alone (the original branch selection); confirm it suits array
        # iterators.
        train_batches = utils.image_generator(train_input_generator,
                                              bool(args.use_cropping),
                                              args.input_shape,
                                              cropping_size=args.cropping_size)
        valid_batches = utils.image_generator(validation_input_generator,
                                              bool(args.use_cropping),
                                              args.input_shape,
                                              cropping_size=args.cropping_size)

    model.fit_generator(
        generator=train_batches,
        steps_per_epoch=train_steps_per_epoch,
        epochs=args.epochs,
        use_multiprocessing=True,
        validation_steps=valid_steps_per_epoch,
        validation_data=valid_batches,
        callbacks=[log, tb, checkpoint, reduce_lr])  # lr_decay

    # serialize weights to HDF5
    model.save(args.save_dir + '/trained_model.h5')
    # model.evaluate_generator
    # from utils import plot_log
    # plot_log(args.save_dir + '/log.csv', show=True)
    print('-' * 30 + 'End: training ' + '-' * 30)

    return model